1use std::sync::Arc;
2
3use anyhow::{Context, Result};
4use locus_core_rs::domain::contracts::EmbeddingProvider;
5use locus_core_rs::{ParseProfile, SurrealDbClient, SurrealDbRuntimeOptions, SurrealDbSettings};
6#[cfg(feature = "local-embedding")]
7use locus_sdk::infrastructure::embeddings::LocalEmbeddingProvider;
8use locus_sdk::infrastructure::embeddings::OllamaEmbeddingProvider;
9use serde_json::Value;
10use surrealdb::engine::any::{Any, connect};
11use surrealdb::opt::auth::Root;
12use tracing::{error, info};
13
/// Embedding backends that `build_embedding_provider` can construct.
///
/// The `Local` variant is only compiled in when the crate is built with the
/// `local-embedding` feature.
#[derive(Clone, Debug)]
enum EmbeddingsProviderKind {
    /// Backed by `OllamaEmbeddingProvider` (endpoint + model name).
    Ollama,
    /// Backed by `LocalEmbeddingProvider` (model name + repository id).
    #[cfg(feature = "local-embedding")]
    Local,
}
20
21impl EmbeddingsProviderKind {
22 fn parse(value: &str) -> Option<Self> {
23 match value.trim().to_ascii_lowercase().as_str() {
24 "ollama" => Some(Self::Ollama),
25 #[cfg(feature = "local-embedding")]
26 "local" | "local-embedding" | "candle" => Some(Self::Local),
27 _ => None,
28 }
29 }
30}
31
32pub(crate) struct RuntimeSurrealDbClient {
33 db: surrealdb::Surreal<Any>,
34}
35
36impl RuntimeSurrealDbClient {
37 pub(crate) async fn connect(
38 runtime: &SurrealDbRuntimeOptions,
39 user: Option<&str>,
40 password: Option<&str>,
41 ) -> Result<Self> {
42 let db = connect(runtime.endpoint.as_str()).await.with_context(|| {
43 format!(
44 "failed to connect to SurrealDB endpoint '{}'",
45 runtime.endpoint
46 )
47 })?;
48
49 if runtime.use_remote {
50 let username = user
51 .filter(|value| !value.trim().is_empty())
52 .unwrap_or("root");
53 let password = password
54 .filter(|value| !value.trim().is_empty())
55 .unwrap_or("root");
56
57 db.signin(Root {
58 username: username.to_string(),
59 password: password.to_string(),
60 })
61 .await
62 .context("failed to authenticate against remote SurrealDB")?;
63 }
64
65 db.use_ns(runtime.namespace.as_str())
66 .use_db(runtime.database.as_str())
67 .await
68 .with_context(|| {
69 format!(
70 "failed to select namespace '{}' and database '{}'",
71 runtime.namespace, runtime.database
72 )
73 })?;
74
75 Ok(Self { db })
76 }
77
78 fn is_read_query(query: &str) -> bool {
79 query
80 .trim_start()
81 .to_ascii_uppercase()
82 .starts_with("SELECT")
83 }
84}
85
86#[async_trait::async_trait]
87impl SurrealDbClient for RuntimeSurrealDbClient {
88 async fn raw_query(
89 &self,
90 query: &str,
91 parameters: locus_core_rs::QueryParams,
92 ) -> Result<Vec<Value>> {
93 let operation = query
94 .split_whitespace()
95 .next()
96 .unwrap_or("UNKNOWN")
97 .to_ascii_uppercase();
98 let is_read_query = Self::is_read_query(query);
99
100 let response = if parameters.is_empty() {
101 self.db.query(query).await?
102 } else {
103 self.db.query(query).bind(parameters).await?
104 };
105
106 let mut response = match response.check() {
107 Ok(value) => value,
108 Err(err) => {
109 error!(operation = %operation, error = %err, "Surreal query failed");
110 return Err(err.into());
111 }
112 };
113
114 if !is_read_query {
115 return Ok(Vec::new());
116 }
117
118 if let Ok(rows) = response.take::<Vec<Value>>(0) {
119 return Ok(rows);
120 }
121
122 if let Ok(Some(row)) = response.take::<Option<Value>>(0) {
123 return Ok(vec![row]);
124 }
125
126 Ok(Vec::new())
127 }
128}
129
130pub(crate) fn init_logging() {
131 let _ = tracing_subscriber::fmt()
132 .with_writer(std::io::stderr)
133 .with_env_filter(
134 tracing_subscriber::EnvFilter::try_from_default_env()
135 .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
136 )
137 .try_init();
138}
139
140pub(crate) fn load_surreal_settings(args: &[String]) -> Result<SurrealDbSettings> {
141 let mut settings = SurrealDbSettings::default();
142 settings.endpoints.embedded = Some("surrealkv://data/locus-mcp".to_string());
143 settings.database = "locus_mcp".to_string();
144
145 if let Some(value) = env_or_arg(
146 "LOCUS_MCP_SURREAL_REMOTE_ENDPOINT",
147 args,
148 "--remote-endpoint",
149 ) {
150 settings.endpoints.remote = Some(value);
151 }
152 if let Some(value) = env_or_arg(
153 "LOCUS_MCP_SURREAL_EMBEDDED_ENDPOINT",
154 args,
155 "--embedded-endpoint",
156 ) {
157 settings.endpoints.embedded = Some(value);
158 }
159 if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_ENDPOINT", args, "--endpoint") {
160 settings.endpoints.remote = Some(value.clone());
161 settings.endpoints.embedded = Some(value);
162 }
163 if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_NAMESPACE", args, "--namespace") {
164 settings.namespace = value;
165 }
166 if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_DATABASE", args, "--database") {
167 settings.database = value;
168 }
169 if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_USERNAME", args, "--username") {
170 settings.user = Some(value);
171 }
172 if let Some(value) = env_or_arg("LOCUS_MCP_SURREAL_PASSWORD", args, "--password") {
173 settings.password = Some(value);
174 }
175
176 Ok(settings)
177}
178
179pub(crate) fn runtime_args(args: &[String]) -> Vec<String> {
180 let mut runtime_args = args.to_vec();
181 if env_flag("LOCUS_MCP_REMOTE") && !runtime_args.iter().any(|value| value == "--remote") {
182 runtime_args.push("--remote".to_string());
183 }
184 runtime_args
185}
186
187pub(crate) fn build_embedding_provider(args: &[String]) -> Result<Option<Arc<dyn EmbeddingProvider>>> {
188 let embeddings_enabled = env_flag("LOCUS_MCP_EMBEDDINGS_ENABLED")
189 || args
190 .iter()
191 .any(|arg| arg.eq_ignore_ascii_case("--embeddings-enabled"));
192
193 if !embeddings_enabled {
194 return Ok(None);
195 }
196
197 let provider_kind_raw = env_or_arg(
198 "LOCUS_MCP_EMBEDDINGS_PROVIDER",
199 args,
200 "--embeddings-provider",
201 )
202 .unwrap_or_else(|| "ollama".to_string());
203 let provider_kind = EmbeddingsProviderKind::parse(&provider_kind_raw).ok_or_else(|| {
204 anyhow::anyhow!(
205 "unsupported embeddings provider '{}'; expected 'ollama'{}",
206 provider_kind_raw,
207 if cfg!(feature = "local-embedding") {
208 " or 'local'"
209 } else {
210 ""
211 }
212 )
213 })?;
214
215 let endpoint = env_or_arg(
216 "LOCUS_MCP_EMBEDDINGS_ENDPOINT",
217 args,
218 "--embeddings-endpoint",
219 )
220 .unwrap_or_else(|| "http://127.0.0.1:11434/api/embeddings".to_string());
221 let model = env_or_arg("LOCUS_MCP_EMBEDDINGS_MODEL", args, "--embeddings-model")
222 .unwrap_or_else(|| "sttp-encoder".to_string());
223 #[cfg(feature = "local-embedding")]
224 let repo = env_or_arg("LOCUS_MCP_EMBEDDINGS_REPO", args, "--embeddings-repo")
225 .unwrap_or_else(|| "sentence-transformers/all-MiniLM-L6-v2".to_string());
226
227 let provider: Arc<dyn EmbeddingProvider> = match provider_kind {
228 EmbeddingsProviderKind::Ollama => {
229 info!(
230 provider = "ollama",
231 endpoint = %endpoint,
232 model = %model,
233 "auto-embedding enabled for store_context"
234 );
235 Arc::new(OllamaEmbeddingProvider::new(endpoint, model))
236 }
237 #[cfg(feature = "local-embedding")]
238 EmbeddingsProviderKind::Local => {
239 info!(
240 provider = "local",
241 model = %model,
242 repo = %repo,
243 "auto-embedding enabled for store_context"
244 );
245 Arc::new(LocalEmbeddingProvider::new(model, repo)?)
246 }
247 };
248
249 Ok(Some(provider))
250}
251
252pub(crate) fn resolve_parser_profile(args: &[String]) -> Result<ParseProfile> {
253 let raw = env_or_arg("LOCUS_MCP_PARSE_PROFILE", args, "--parse-profile")
254 .unwrap_or_else(|| "strict_typed_ir".to_string());
255
256 parse_profile(raw.as_str()).ok_or_else(|| {
257 anyhow::anyhow!(
258 "unsupported parse profile '{}'; expected one of: strict_typed_ir, strict, tolerant",
259 raw
260 )
261 })
262}
263
264fn parse_profile(value: &str) -> Option<ParseProfile> {
265 match value.trim().to_ascii_lowercase().as_str() {
266 "strict_typed_ir" | "strict-typed-ir" | "stricttypedir" | "typed_ir" | "typed-ir" => {
267 Some(ParseProfile::StrictTypedIr)
268 }
269 "strict" => Some(ParseProfile::Strict),
270 "tolerant" | "default" => Some(ParseProfile::Tolerant),
271 _ => None,
272 }
273}
274
275fn env_or_arg(env_key: &str, args: &[String], arg_name: &str) -> Option<String> {
276 if let Ok(value) = std::env::var(env_key) {
277 let trimmed = value.trim();
278 if !trimmed.is_empty() {
279 return Some(trimmed.to_string());
280 }
281 }
282
283 arg_value(args, arg_name)
284}
285
/// Looks up the value supplied for the command-line option `key`.
///
/// Two spellings are recognised:
/// - `--key value`: two separate arguments.
/// - `--key=value`: a single argument, split at the first `=`.
///
/// Option names are compared ASCII case-insensitively, and the first matching
/// occurrence wins. A bare `--key` in the final position has no value and
/// yields nothing. `--key=` yields `Some("")`. Values are returned untrimmed.
fn arg_value(args: &[String], key: &str) -> Option<String> {
    args.iter().enumerate().find_map(|(index, arg)| {
        if arg.eq_ignore_ascii_case(key) {
            // `--key value`: take the following argument, if there is one.
            return args.get(index + 1).cloned();
        }
        // `--key=value`: compare only the part before the first `=`.
        let (name, value) = arg.split_once('=')?;
        if name.eq_ignore_ascii_case(key) {
            Some(value.to_string())
        } else {
            None
        }
    })
}
291
/// Reports whether the environment variable `key` holds a truthy value.
///
/// After trimming and ASCII lower-casing, `1`, `true` and `yes` count as
/// enabled. Any other value yields `false`, as does an unset or non-Unicode
/// variable.
fn env_flag(key: &str) -> bool {
    match std::env::var(key) {
        Ok(raw) => matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes"
        ),
        Err(_) => false,
    }
}