locus_gateway/
orchestration.rs1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use axum::http::HeaderValue;
5use locus_core_rs::application::services::{
6 CalibrationService, ContextQueryService, MonthlyRollupService,
7 MoodCatalogService, RekeyScopeService, StoreContextService,
8};
9use locus_core_rs::application::validation::TreeSitterValidator;
10use locus_core_rs::domain::contracts::{
11 EmbeddingProvider, NodeStore, NodeStoreInitializer, NodeValidator,
12};
13use locus_core_rs::storage::{
14 InMemoryNodeStore, SurrealDbEndpointsSettings, SurrealDbNodeStore, SurrealDbRuntimeOptions,
15 SurrealDbSettings,
16};
17#[cfg(feature = "local-embedding")]
18use locus_sdk::infrastructure::embeddings::LocalEmbeddingProvider;
19use locus_sdk::infrastructure::embeddings::OllamaEmbeddingProvider;
20use tracing::{error, info};
21
22use crate::app_state::AppState;
23use crate::gateway_args::{EmbeddingsProviderKind, GatewayArgs, GatewayBackend};
24use crate::http_models::CorsAllowedOrigins;
25use crate::providers::{AvecScorer, OllamaAvecScorer};
26use crate::surreal_client::RuntimeSurrealDbClient;
27
28pub(crate) async fn build_state(args: &GatewayArgs) -> Result<AppState> {
29 build_state_with_backend(&args.backend, Some(args)).await
30}
31
/// Test helper: builds an in-memory [`AppState`] without any gateway
/// arguments, so neither an embedding provider nor an AVEC scorer is created.
#[cfg(test)]
pub(crate) async fn build_in_memory_state() -> Result<AppState> {
    let no_args: Option<&GatewayArgs> = None;
    build_in_memory_state_with_args(no_args).await
}
36
37pub(crate) fn parse_cors_allowed_origins(value: &str) -> Result<CorsAllowedOrigins> {
38 let trimmed = value.trim();
39 if trimmed.is_empty() {
40 return Err(anyhow!(
41 "CORS allowed origins cannot be empty when CORS is enabled"
42 ));
43 }
44
45 if trimmed == "*" {
46 return Ok(CorsAllowedOrigins::Any);
47 }
48
49 let mut origins = Vec::new();
50 for origin in trimmed
51 .split(',')
52 .map(str::trim)
53 .filter(|part| !part.is_empty())
54 {
55 let header = HeaderValue::from_str(origin)
56 .map_err(|_| anyhow!("Invalid CORS origin value: {origin}"))?;
57 origins.push(header);
58 }
59
60 if origins.is_empty() {
61 return Err(anyhow!(
62 "CORS allowed origins must include at least one origin or '*'"
63 ));
64 }
65
66 Ok(CorsAllowedOrigins::Explicit(origins))
67}
68
69pub(crate) async fn shutdown_signal() {
70 if let Err(err) = tokio::signal::ctrl_c().await {
71 error!(error = %err, "Failed waiting for ctrl_c signal");
72 }
73}
74
75async fn build_state_with_backend(
76 backend: &GatewayBackend,
77 options: Option<&GatewayArgs>,
78) -> Result<AppState> {
79 match backend {
80 GatewayBackend::InMemory => build_in_memory_state_with_args(options).await,
81 GatewayBackend::Surreal => {
82 let options = options.ok_or_else(|| {
83 anyhow!("Surreal backend selected, but no gateway runtime options were provided.")
84 })?;
85 build_surreal_state(options).await
86 }
87 }
88}
89
90async fn build_in_memory_state_with_args(args: Option<&GatewayArgs>) -> Result<AppState> {
91 let store = Arc::new(InMemoryNodeStore::new());
92
93 let initializer: Arc<dyn NodeStoreInitializer> = store.clone();
94 initializer.initialize_async().await?;
95
96 let store_trait: Arc<dyn NodeStore> = store;
97 let validator: Arc<dyn NodeValidator> = Arc::new(TreeSitterValidator);
98 let embedding_provider = build_embedding_provider(args)?;
99 let avec_scorer = build_avec_scorer(args);
100
101 Ok(build_services(
102 store_trait,
103 validator,
104 embedding_provider,
105 avec_scorer,
106 ))
107}
108
109fn build_services(
110 store_trait: Arc<dyn NodeStore>,
111 validator: Arc<dyn NodeValidator>,
112 embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
113 avec_scorer: Option<Arc<dyn AvecScorer>>,
114) -> AppState {
115 let store_context = match embedding_provider.as_ref() {
116 Some(provider) => Arc::new(StoreContextService::with_embedding_provider(
117 store_trait.clone(),
118 validator.clone(),
119 provider.clone(),
120 )),
121 None => Arc::new(StoreContextService::new(
122 store_trait.clone(),
123 validator.clone(),
124 )),
125 };
126
127 AppState {
128 node_store: store_trait.clone(),
129 embedding_provider: embedding_provider.clone(),
130 avec_scorer,
131 calibration: Arc::new(CalibrationService::new(store_trait.clone())),
132 context_query: Arc::new(ContextQueryService::new(store_trait.clone())),
133 mood_catalog: Arc::new(MoodCatalogService::new()),
134 store_context,
135 monthly_rollup: Arc::new(MonthlyRollupService::new(store_trait.clone(), validator)),
136 rekey_scope: Arc::new(RekeyScopeService::new(store_trait)),
137 }
138}
139
140async fn build_surreal_state(args: &GatewayArgs) -> Result<AppState> {
141 let mut settings = SurrealDbSettings::default();
142 settings.endpoints = SurrealDbEndpointsSettings {
143 embedded: args
144 .surreal_embedded_endpoint
145 .clone()
146 .or(settings.endpoints.embedded),
147 remote: args
148 .surreal_remote_endpoint
149 .clone()
150 .or(settings.endpoints.remote),
151 };
152 settings.namespace = args.surreal_namespace.clone();
153 settings.database = args.surreal_database.clone();
154 settings.user = Some(args.surreal_user.clone());
155 settings.password = Some(args.surreal_password.clone());
156
157 let mut runtime_args = Vec::new();
158 if args.remote {
159 runtime_args.push("--remote".to_string());
160 }
161
162 let runtime = SurrealDbRuntimeOptions::from_args(
163 &runtime_args,
164 &settings,
165 Some(args.root_dir_name.as_str()),
166 )?;
167
168 info!(
169 backend = "surreal",
170 root_dir = runtime.root_dir,
171 mode = if runtime.use_remote {
172 "remote"
173 } else {
174 "embedded"
175 },
176 endpoint = runtime.endpoint,
177 namespace = runtime.namespace,
178 database = runtime.database,
179 "Surreal backend requested"
180 );
181
182 let client = Arc::new(
183 RuntimeSurrealDbClient::connect(
184 &runtime,
185 settings.user.as_deref(),
186 settings.password.as_deref(),
187 )
188 .await?,
189 );
190
191 let store = Arc::new(SurrealDbNodeStore::new(client));
192
193 let initializer: Arc<dyn NodeStoreInitializer> = store.clone();
194 initializer.initialize_async().await?;
195
196 let store_trait: Arc<dyn NodeStore> = store;
197 let validator: Arc<dyn NodeValidator> = Arc::new(TreeSitterValidator);
198 let embedding_provider = build_embedding_provider(Some(args))?;
199 let avec_scorer = build_avec_scorer(Some(args));
200
201 Ok(build_services(
202 store_trait,
203 validator,
204 embedding_provider,
205 avec_scorer,
206 ))
207}
208
209fn build_avec_scorer(args: Option<&GatewayArgs>) -> Option<Arc<dyn AvecScorer>> {
210 let args = args?;
211 if !args.avec_scoring_enabled {
212 return None;
213 }
214
215 Some(Arc::new(OllamaAvecScorer::new(
216 args.avec_scoring_endpoint.clone(),
217 args.avec_scoring_model.clone(),
218 )))
219}
220
221fn build_embedding_provider(
222 args: Option<&GatewayArgs>,
223) -> Result<Option<Arc<dyn EmbeddingProvider>>> {
224 let Some(args) = args else {
225 return Ok(None);
226 };
227
228 if !args.embeddings_enabled {
229 return Ok(None);
230 }
231
232 let provider: Arc<dyn EmbeddingProvider> = match args.embeddings_provider {
233 EmbeddingsProviderKind::Ollama => Arc::new(OllamaEmbeddingProvider::new(
234 args.embeddings_endpoint.clone(),
235 args.embeddings_model.clone(),
236 )),
237 #[cfg(feature = "local-embedding")]
238 EmbeddingsProviderKind::Local => Arc::new(LocalEmbeddingProvider::new(
239 args.embeddings_model.clone(),
240 args.embeddings_repo.clone(),
241 )?),
242 };
243
244 Ok(Some(provider))
245}