Skip to main content

locus_mcp/
composition.rs

1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use locus_core_rs::domain::contracts::EmbeddingProvider;
5use locus_core_rs::{ParseProfile, SurrealDbClient, SurrealDbRuntimeOptions, SurrealDbSettings};
6#[cfg(feature = "local-embedding")]
7use locus_sdk::infrastructure::embeddings::LocalEmbeddingProvider;
8use locus_sdk::infrastructure::embeddings::OllamaEmbeddingProvider;
9use serde_json::Value;
10use surrealdb::engine::any::{Any, connect};
11use surrealdb::opt::auth::Root;
12use tracing::{error, info};
13
/// Embedding backends that `build_embedding_provider` knows how to construct.
///
/// The `Local` variant only exists when the crate is built with the
/// `local-embedding` feature.
#[derive(Debug, Clone)]
enum EmbeddingsProviderKind {
    /// Provider built with `OllamaEmbeddingProvider` (endpoint + model).
    Ollama,
    /// Provider built with `LocalEmbeddingProvider` (model + repo); feature-gated.
    #[cfg(feature = "local-embedding")]
    Local,
}

impl EmbeddingsProviderKind {
    /// Maps a user-supplied provider name to a variant.
    ///
    /// Surrounding whitespace and ASCII case are ignored. Returns `None` for any
    /// unrecognised name, which includes the `local` aliases when the
    /// `local-embedding` feature is disabled.
    fn parse(value: &str) -> Option<Self> {
        let normalized = value.trim().to_ascii_lowercase();
        let kind = match normalized.as_str() {
            "ollama" => Self::Ollama,
            #[cfg(feature = "local-embedding")]
            "local" | "local-embedding" | "candle" => Self::Local,
            _ => return None,
        };
        Some(kind)
    }
}
31
32pub(crate) struct RuntimeSurrealDbClient {
33    db: surrealdb::Surreal<Any>,
34}
35
36impl RuntimeSurrealDbClient {
37    pub(crate) async fn connect(
38        runtime: &SurrealDbRuntimeOptions,
39        user: Option<&str>,
40        password: Option<&str>,
41    ) -> Result<Self> {
42        let db = connect(runtime.endpoint.as_str()).await.with_context(|| {
43            format!(
44                "failed to connect to SurrealDB endpoint '{}'",
45                runtime.endpoint
46            )
47        })?;
48
49        if runtime.use_remote {
50            let username = user
51                .filter(|value| !value.trim().is_empty())
52                .unwrap_or("root");
53            let password = password
54                .filter(|value| !value.trim().is_empty())
55                .unwrap_or("root");
56
57            db.signin(Root {
58                username: username.to_string(),
59                password: password.to_string(),
60            })
61            .await
62            .context("failed to authenticate against remote SurrealDB")?;
63        }
64
65        db.use_ns(runtime.namespace.as_str())
66            .use_db(runtime.database.as_str())
67            .await
68            .with_context(|| {
69                format!(
70                    "failed to select namespace '{}' and database '{}'",
71                    runtime.namespace, runtime.database
72                )
73            })?;
74
75        Ok(Self { db })
76    }
77
78    fn is_read_query(query: &str) -> bool {
79        query
80            .trim_start()
81            .to_ascii_uppercase()
82            .starts_with("SELECT")
83    }
84}
85
86#[async_trait::async_trait]
87impl SurrealDbClient for RuntimeSurrealDbClient {
88    async fn raw_query(
89        &self,
90        query: &str,
91        parameters: locus_core_rs::QueryParams,
92    ) -> Result<Vec<Value>> {
93        let operation = query
94            .split_whitespace()
95            .next()
96            .unwrap_or("UNKNOWN")
97            .to_ascii_uppercase();
98        let is_read_query = Self::is_read_query(query);
99
100        let response = if parameters.is_empty() {
101            self.db.query(query).await?
102        } else {
103            self.db.query(query).bind(parameters).await?
104        };
105
106        let mut response = match response.check() {
107            Ok(value) => value,
108            Err(err) => {
109                error!(operation = %operation, error = %err, "Surreal query failed");
110                return Err(err.into());
111            }
112        };
113
114        if !is_read_query {
115            return Ok(Vec::new());
116        }
117
118        if let Ok(rows) = response.take::<Vec<Value>>(0) {
119            return Ok(rows);
120        }
121
122        if let Ok(Some(row)) = response.take::<Option<Value>>(0) {
123            return Ok(vec![row]);
124        }
125
126        Ok(Vec::new())
127    }
128}
129
130pub(crate) fn init_logging() {
131    let _ = tracing_subscriber::fmt()
132        .with_writer(std::io::stderr)
133        .with_env_filter(
134            tracing_subscriber::EnvFilter::try_from_default_env()
135                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
136        )
137        .try_init();
138}
139
140pub(crate) fn load_surreal_settings(args: &[String]) -> Result<SurrealDbSettings> {
141    let mut settings = SurrealDbSettings::default();
142    settings.endpoints.embedded = Some("surrealkv://data/locus-mcp".to_string());
143    settings.database = "locus_mcp".to_string();
144
145    if let Some(value) = env_or_arg(
146        "LOCUS_MCP_SURREAL_REMOTE_ENDPOINT",
147        args,
148        "--remote-endpoint",
149    ) {
150        settings.endpoints.remote = Some(value);
151    }
152    if let Some(value) = env_or_arg(
153        "LOCUS_MCP_SURREAL_EMBEDDED_ENDPOINT",
154        args,
155        "--embedded-endpoint",
156    ) {
157        settings.endpoints.embedded = Some(value);
158    }
159    if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_ENDPOINT", args, "--endpoint") {
160        settings.endpoints.remote = Some(value.clone());
161        settings.endpoints.embedded = Some(value);
162    }
163    if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_NAMESPACE", args, "--namespace") {
164        settings.namespace = value;
165    }
166    if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_DATABASE", args, "--database") {
167        settings.database = value;
168    }
169    if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_USERNAME", args, "--username") {
170        settings.user = Some(value);
171    }
172    if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_PASSWORD", args, "--password") {
173        settings.password = Some(value);
174    }
175
176    Ok(settings)
177}
178
179pub(crate) fn runtime_args(args: &[String]) -> Vec<String> {
180    let mut runtime_args = args.to_vec();
181    if env_flag("LOCUS_MCP_REMOTE") && !runtime_args.iter().any(|value| value == "--remote") {
182        runtime_args.push("--remote".to_string());
183    }
184    runtime_args
185}
186
187pub(crate) fn build_embedding_provider(args: &[String]) -> Result<Option<Arc<dyn EmbeddingProvider>>> {
188    let embeddings_enabled = env_flag("LOCUS_MCP_EMBEDDINGS_ENABLED")
189        || args
190            .iter()
191            .any(|arg| arg.eq_ignore_ascii_case("--embeddings-enabled"));
192
193    if !embeddings_enabled {
194        return Ok(None);
195    }
196
197    let provider_kind_raw = env_or_arg(
198        "LOCUS_MCP_EMBEDDINGS_PROVIDER",
199        args,
200        "--embeddings-provider",
201    )
202    .unwrap_or_else(|| "ollama".to_string());
203    let provider_kind = EmbeddingsProviderKind::parse(&provider_kind_raw).ok_or_else(|| {
204        anyhow::anyhow!(
205            "unsupported embeddings provider '{}'; expected 'ollama'{}",
206            provider_kind_raw,
207            if cfg!(feature = "local-embedding") {
208                " or 'local'"
209            } else {
210                ""
211            }
212        )
213    })?;
214
215    let endpoint = env_or_arg(
216        "LOCUS_MCP_EMBEDDINGS_ENDPOINT",
217        args,
218        "--embeddings-endpoint",
219    )
220    .unwrap_or_else(|| "http://127.0.0.1:11434/api/embeddings".to_string());
221    let model = env_or_arg("LOCUS_MCP_EMBEDDINGS_MODEL", args, "--embeddings-model")
222        .unwrap_or_else(|| "sttp-encoder".to_string());
223    #[cfg(feature = "local-embedding")]
224    let repo = env_or_arg("LOCUS_MCP_EMBEDDINGS_REPO", args, "--embeddings-repo")
225        .unwrap_or_else(|| "sentence-transformers/all-MiniLM-L6-v2".to_string());
226
227    let provider: Arc<dyn EmbeddingProvider> = match provider_kind {
228        EmbeddingsProviderKind::Ollama => {
229            info!(
230                provider = "ollama",
231                endpoint = %endpoint,
232                model = %model,
233                "auto-embedding enabled for store_context"
234            );
235            Arc::new(OllamaEmbeddingProvider::new(endpoint, model))
236        }
237        #[cfg(feature = "local-embedding")]
238        EmbeddingsProviderKind::Local => {
239            info!(
240                provider = "local",
241                model = %model,
242                repo = %repo,
243                "auto-embedding enabled for store_context"
244            );
245            Arc::new(LocalEmbeddingProvider::new(model, repo)?)
246        }
247    };
248
249    Ok(Some(provider))
250}
251
252pub(crate) fn resolve_parser_profile(args: &[String]) -> Result<ParseProfile> {
253    let raw = env_or_arg("LOCUS_MCP_PARSE_PROFILE", args, "--parse-profile")
254        .unwrap_or_else(|| "strict_typed_ir".to_string());
255
256    parse_profile(raw.as_str()).ok_or_else(|| {
257        anyhow::anyhow!(
258            "unsupported parse profile '{}'; expected one of: strict_typed_ir, strict, tolerant",
259            raw
260        )
261    })
262}
263
264fn parse_profile(value: &str) -> Option<ParseProfile> {
265    match value.trim().to_ascii_lowercase().as_str() {
266        "strict_typed_ir" | "strict-typed-ir" | "stricttypedir" | "typed_ir" | "typed-ir" => {
267            Some(ParseProfile::StrictTypedIr)
268        }
269        "strict" => Some(ParseProfile::Strict),
270        "tolerant" | "default" => Some(ParseProfile::Tolerant),
271        _ => None,
272    }
273}
274
275fn env_or_arg(env_key: &str, args: &[String], arg_name: &str) -> Option<String> {
276    if let Ok(value) = std::env::var(env_key) {
277        let trimmed = value.trim();
278        if !trimmed.is_empty() {
279            return Some(trimmed.to_string());
280        }
281    }
282
283    arg_value(args, arg_name)
284}
285
/// Returns the value supplied for the command-line option `key` (ASCII case-insensitive).
///
/// Both `--key value` and `--key=value` forms are recognised, and the first occurrence in
/// `args` wins. A `--key` that is the last argument has no value and is skipped. The
/// value is returned verbatim: it is neither trimmed nor validated, so in `--key --other`
/// the string `--other` is taken as the value.
fn arg_value(args: &[String], key: &str) -> Option<String> {
    args.iter().enumerate().find_map(|(index, arg)| {
        if arg.eq_ignore_ascii_case(key) {
            // `--key value`: the following argument is the value, if there is one.
            return args.get(index + 1).cloned();
        }
        // `--key=value`: split on the first `=` only, so values may themselves contain `=`.
        let (name, value) = arg.split_once('=')?;
        if name.eq_ignore_ascii_case(key) {
            Some(value.to_string())
        } else {
            None
        }
    })
}
291
/// Reports whether the environment variable `key` holds a truthy value.
///
/// `1`, `true` and `yes` (trimmed, ASCII case-insensitive) count as true. Any other
/// value counts as false, as does an unset or non-Unicode variable.
fn env_flag(key: &str) -> bool {
    match std::env::var(key) {
        Ok(raw) => matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes"
        ),
        Err(_) => false,
    }
}