Skip to main content

locus_gateway/
orchestration.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use axum::http::HeaderValue;
5use locus_core_rs::application::services::{
6    CalibrationService, ContextQueryService, MonthlyRollupService,
7    MoodCatalogService, RekeyScopeService, StoreContextService,
8};
9use locus_core_rs::application::validation::TreeSitterValidator;
10use locus_core_rs::domain::contracts::{
11    EmbeddingProvider, NodeStore, NodeStoreInitializer, NodeValidator,
12};
13use locus_core_rs::storage::{
14    InMemoryNodeStore, SurrealDbEndpointsSettings, SurrealDbNodeStore, SurrealDbRuntimeOptions,
15    SurrealDbSettings,
16};
17#[cfg(feature = "local-embedding")]
18use locus_sdk::infrastructure::embeddings::LocalEmbeddingProvider;
19use locus_sdk::infrastructure::embeddings::OllamaEmbeddingProvider;
20use tracing::{error, info};
21
22use crate::app_state::AppState;
23use crate::gateway_args::{EmbeddingsProviderKind, GatewayArgs, GatewayBackend};
24use crate::http_models::CorsAllowedOrigins;
25use crate::providers::{AvecScorer, OllamaAvecScorer};
26use crate::surreal_client::RuntimeSurrealDbClient;
27
28pub(crate) async fn build_state(args: &GatewayArgs) -> Result<AppState> {
29    build_state_with_backend(&args.backend, Some(args)).await
30}
31
/// Test-only helper that builds an [`AppState`] on the in-memory store.
///
/// No gateway arguments are supplied, so neither an embedding provider nor an AVEC
/// scorer is configured on the resulting state.
#[cfg(test)]
pub(crate) async fn build_in_memory_state() -> Result<AppState> {
    let no_args: Option<&GatewayArgs> = None;
    build_in_memory_state_with_args(no_args).await
}
36
37pub(crate) fn parse_cors_allowed_origins(value: &str) -> Result<CorsAllowedOrigins> {
38    let trimmed = value.trim();
39    if trimmed.is_empty() {
40        return Err(anyhow!(
41            "CORS allowed origins cannot be empty when CORS is enabled"
42        ));
43    }
44
45    if trimmed == "*" {
46        return Ok(CorsAllowedOrigins::Any);
47    }
48
49    let mut origins = Vec::new();
50    for origin in trimmed
51        .split(',')
52        .map(str::trim)
53        .filter(|part| !part.is_empty())
54    {
55        let header = HeaderValue::from_str(origin)
56            .map_err(|_| anyhow!("Invalid CORS origin value: {origin}"))?;
57        origins.push(header);
58    }
59
60    if origins.is_empty() {
61        return Err(anyhow!(
62            "CORS allowed origins must include at least one origin or '*'"
63        ));
64    }
65
66    Ok(CorsAllowedOrigins::Explicit(origins))
67}
68
69pub(crate) async fn shutdown_signal() {
70    if let Err(err) = tokio::signal::ctrl_c().await {
71        error!(error = %err, "Failed waiting for ctrl_c signal");
72    }
73}
74
75async fn build_state_with_backend(
76    backend: &GatewayBackend,
77    options: Option<&GatewayArgs>,
78) -> Result<AppState> {
79    match backend {
80        GatewayBackend::InMemory => build_in_memory_state_with_args(options).await,
81        GatewayBackend::Surreal => {
82            let options = options.ok_or_else(|| {
83                anyhow!("Surreal backend selected, but no gateway runtime options were provided.")
84            })?;
85            build_surreal_state(options).await
86        }
87    }
88}
89
90async fn build_in_memory_state_with_args(args: Option<&GatewayArgs>) -> Result<AppState> {
91    let store = Arc::new(InMemoryNodeStore::new());
92
93    let initializer: Arc<dyn NodeStoreInitializer> = store.clone();
94    initializer.initialize_async().await?;
95
96    let store_trait: Arc<dyn NodeStore> = store;
97    let validator: Arc<dyn NodeValidator> = Arc::new(TreeSitterValidator);
98    let embedding_provider = build_embedding_provider(args)?;
99    let avec_scorer = build_avec_scorer(args);
100
101    Ok(build_services(
102        store_trait,
103        validator,
104        embedding_provider,
105        avec_scorer,
106    ))
107}
108
109fn build_services(
110    store_trait: Arc<dyn NodeStore>,
111    validator: Arc<dyn NodeValidator>,
112    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
113    avec_scorer: Option<Arc<dyn AvecScorer>>,
114) -> AppState {
115    let store_context = match embedding_provider.as_ref() {
116        Some(provider) => Arc::new(StoreContextService::with_embedding_provider(
117            store_trait.clone(),
118            validator.clone(),
119            provider.clone(),
120        )),
121        None => Arc::new(StoreContextService::new(
122            store_trait.clone(),
123            validator.clone(),
124        )),
125    };
126
127    AppState {
128        node_store: store_trait.clone(),
129        embedding_provider: embedding_provider.clone(),
130        avec_scorer,
131        calibration: Arc::new(CalibrationService::new(store_trait.clone())),
132        context_query: Arc::new(ContextQueryService::new(store_trait.clone())),
133        mood_catalog: Arc::new(MoodCatalogService::new()),
134        store_context,
135        monthly_rollup: Arc::new(MonthlyRollupService::new(store_trait.clone(), validator)),
136        rekey_scope: Arc::new(RekeyScopeService::new(store_trait)),
137    }
138}
139
140async fn build_surreal_state(args: &GatewayArgs) -> Result<AppState> {
141    let mut settings = SurrealDbSettings::default();
142    settings.endpoints = SurrealDbEndpointsSettings {
143        embedded: args
144            .surreal_embedded_endpoint
145            .clone()
146            .or(settings.endpoints.embedded),
147        remote: args
148            .surreal_remote_endpoint
149            .clone()
150            .or(settings.endpoints.remote),
151    };
152    settings.namespace = args.surreal_namespace.clone();
153    settings.database = args.surreal_database.clone();
154    settings.user = Some(args.surreal_user.clone());
155    settings.password = Some(args.surreal_password.clone());
156
157    let mut runtime_args = Vec::new();
158    if args.remote {
159        runtime_args.push("--remote".to_string());
160    }
161
162    let runtime = SurrealDbRuntimeOptions::from_args(
163        &runtime_args,
164        &settings,
165        Some(args.root_dir_name.as_str()),
166    )?;
167
168    info!(
169        backend = "surreal",
170        root_dir = runtime.root_dir,
171        mode = if runtime.use_remote {
172            "remote"
173        } else {
174            "embedded"
175        },
176        endpoint = runtime.endpoint,
177        namespace = runtime.namespace,
178        database = runtime.database,
179        "Surreal backend requested"
180    );
181
182    let client = Arc::new(
183        RuntimeSurrealDbClient::connect(
184            &runtime,
185            settings.user.as_deref(),
186            settings.password.as_deref(),
187        )
188        .await?,
189    );
190
191    let store = Arc::new(SurrealDbNodeStore::new(client));
192
193    let initializer: Arc<dyn NodeStoreInitializer> = store.clone();
194    initializer.initialize_async().await?;
195
196    let store_trait: Arc<dyn NodeStore> = store;
197    let validator: Arc<dyn NodeValidator> = Arc::new(TreeSitterValidator);
198    let embedding_provider = build_embedding_provider(Some(args))?;
199    let avec_scorer = build_avec_scorer(Some(args));
200
201    Ok(build_services(
202        store_trait,
203        validator,
204        embedding_provider,
205        avec_scorer,
206    ))
207}
208
209fn build_avec_scorer(args: Option<&GatewayArgs>) -> Option<Arc<dyn AvecScorer>> {
210    let args = args?;
211    if !args.avec_scoring_enabled {
212        return None;
213    }
214
215    Some(Arc::new(OllamaAvecScorer::new(
216        args.avec_scoring_endpoint.clone(),
217        args.avec_scoring_model.clone(),
218    )))
219}
220
221fn build_embedding_provider(
222    args: Option<&GatewayArgs>,
223) -> Result<Option<Arc<dyn EmbeddingProvider>>> {
224    let Some(args) = args else {
225        return Ok(None);
226    };
227
228    if !args.embeddings_enabled {
229        return Ok(None);
230    }
231
232    let provider: Arc<dyn EmbeddingProvider> = match args.embeddings_provider {
233        EmbeddingsProviderKind::Ollama => Arc::new(OllamaEmbeddingProvider::new(
234            args.embeddings_endpoint.clone(),
235            args.embeddings_model.clone(),
236        )),
237        #[cfg(feature = "local-embedding")]
238        EmbeddingsProviderKind::Local => Arc::new(LocalEmbeddingProvider::new(
239            args.embeddings_model.clone(),
240            args.embeddings_repo.clone(),
241        )?),
242    };
243
244    Ok(Some(provider))
245}