Skip to main content

locus_cli/
main.rs

1use std::fs;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Context, Result, anyhow, bail};
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use clap::{Parser, Subcommand, ValueEnum};
9use locus_core_rs::domain::models::{AvecState, MonthlyRollupRequest, SttpNode};
10use locus_core_rs::{
11    CalibrationService, InMemoryNodeStore, MonthlyRollupService, MoodCatalogService, NodeStore,
12    NodeStoreInitializer, NodeValidator, QueryParams,
13    StoreContextService, SurrealDbClient, SurrealDbEndpointsSettings, SurrealDbNodeStore,
14    SurrealDbRuntimeOptions, SurrealDbSettings, TreeSitterValidator,
15};
16use locus_sdk::application::memory_find::MemoryFindService;
17use locus_sdk::application::memory_recall::MemoryRecallService;
18use locus_sdk::domain::memory::{MemoryFindRequest, MemoryPage, MemoryRecallRequest, MemoryScope};
19use serde_json::{Value, json};
20use surrealdb::engine::any::{Any, connect};
21use surrealdb::opt::auth::Root;
22
// Tenant used when no tenant id is supplied; its session ids are stored without any prefix.
const DEFAULT_TENANT: &str = "default";
// Scoped session ids have the form `tenant:<tenant>::session:<session>`.
const TENANT_SCOPE_PREFIX: &str = "tenant:";
const TENANT_SCOPE_SEPARATOR: &str = "::session:";
26
// Storage backend selector for the CLI.
// Plain `//` comments are used on purpose: clap's derive macros read `///` doc comments
// and would surface them as help text.
#[derive(Copy, Clone, Debug, ValueEnum)]
enum StorageMode {
    // Backed by `InMemoryNodeStore`.
    InMemory,
    // Backed by `SurrealDbNodeStore` through `RuntimeSurrealDbClient`.
    Surreal,
}
32
// Global options plus the selected subcommand.
// Plain `//` comments are used on purpose: clap's derive macros read `///` doc comments
// and would change the generated help output.
#[derive(Parser, Debug)]
#[command(name = "locus-cli", version, about = "SDK-backed CLI for Locus memory operations")]
struct Cli {
    // Storage backend; `build_services` matches on this value.
    #[arg(long, env = "LOCUS_STORAGE", default_value = "surreal")]
    storage: StorageMode,

    // Validated by `resolve_tenant`; when absent the default tenant is used.
    #[arg(long, env = "LOCUS_TENANT_ID", help = "Optional tenant ID")]
    tenant_id: Option<String>,

    // Forwarded to `SurrealDbRuntimeOptions::from_args` as a `--remote` argument.
    #[arg(long, env = "LOCUS_REMOTE", default_value_t = false)]
    remote: bool,

    // Forwarded to `SurrealDbRuntimeOptions::from_args`.
    #[arg(long, env = "LOCUS_ROOT_DIR_NAME", default_value = ".locus-cli")]
    root_dir_name: String,

    // When non-blank, overrides BOTH the embedded and the remote endpoint.
    #[arg(long, env = "LOCUS_SURREAL_ENDPOINT")]
    surreal_endpoint: Option<String>,

    #[arg(long, env = "LOCUS_SURREAL_REMOTE_ENDPOINT")]
    surreal_remote_endpoint: Option<String>,

    // Falls back to "surrealkv://data/locus-cli" when unset (see `surreal_settings_from_cli`).
    #[arg(long, env = "LOCUS_SURREAL_EMBEDDED_ENDPOINT")]
    surreal_embedded_endpoint: Option<String>,

    #[arg(long, env = "LOCUS_SURREAL_NAMESPACE", default_value = "entasis")]
    surreal_namespace: String,

    #[arg(long, env = "LOCUS_SURREAL_DATABASE", default_value = "locus_cli")]
    surreal_database: String,

    // Credentials are only used for remote sign-in; missing or blank values fall back to "root"
    // in `RuntimeSurrealDbClient::connect`.
    #[arg(long, env = "LOCUS_SURREAL_USERNAME")]
    surreal_username: Option<String>,

    #[arg(long, env = "LOCUS_SURREAL_PASSWORD")]
    surreal_password: Option<String>,

    #[arg(long, help = "Pretty-print JSON output")]
    pretty: bool,

    #[command(subcommand)]
    command: Commands,
}
75
// Subcommands handled by `main`.
// Plain `//` comments are used on purpose: clap's derive macros read `///` doc comments
// and would change the generated help output.
#[derive(Subcommand, Debug)]
enum Commands {
    // Prints a status object together with the resolved storage configuration.
    Health,
    // Calls `CalibrationService::calibrate_async` for the tenant-scoped session.
    Calibrate {
        #[arg(long)]
        session_id: String,
        #[arg(long)]
        stability: f32,
        #[arg(long)]
        friction: f32,
        #[arg(long)]
        logic: f32,
        #[arg(long)]
        autonomy: f32,
        #[arg(long, default_value = "manual")]
        trigger: String,
    },
    // Reads `node_file` and passes its text to `StoreContextService::store_async`.
    Store {
        #[arg(long)]
        session_id: String,
        #[arg(long, help = "Path to a file containing one STTP node")]
        node_file: PathBuf,
    },
    // Runs `MemoryRecallService::execute` for one session using the given AVEC values.
    Context {
        #[arg(long)]
        session_id: String,
        #[arg(long)]
        stability: f32,
        #[arg(long)]
        friction: f32,
        #[arg(long)]
        logic: f32,
        #[arg(long)]
        autonomy: f32,
        // Defaults to 5 in `main`.
        #[arg(long)]
        limit: Option<usize>,
        // Parsed with `parse_utc_optional` (RFC 3339 input).
        #[arg(long)]
        from_utc: Option<String>,
        #[arg(long)]
        to_utc: Option<String>,
        // Comma-separated list; cleaned up by `normalize_tiers`.
        #[arg(long, value_delimiter = ',')]
        tiers: Vec<String>,
        #[arg(long)]
        query_text: Option<String>,
        // Scoring weights; default to 0.7 and 0.3 in `main`.
        #[arg(long)]
        alpha: Option<f32>,
        #[arg(long)]
        beta: Option<f32>,
    },
    // Lists nodes through `MemoryFindService::execute`, optionally restricted to one session.
    Nodes {
        // Defaults to 50 and is clamped to 1..=200 in `main`.
        #[arg(long)]
        limit: Option<usize>,
        #[arg(long)]
        session_id: Option<String>,
    },
    // Queries `MoodCatalogService::get`; every argument is optional.
    Moods {
        #[arg(long)]
        target_mood: Option<String>,
        // Defaults to 1.0 in `main`.
        #[arg(long)]
        blend: Option<f32>,
        #[arg(long)]
        current_stability: Option<f32>,
        #[arg(long)]
        current_friction: Option<f32>,
        #[arg(long)]
        current_logic: Option<f32>,
        #[arg(long)]
        current_autonomy: Option<f32>,
    },
    // Builds a `MonthlyRollupRequest` and calls `MonthlyRollupService::create_async`.
    Rollup {
        #[arg(long)]
        session_id: String,
        // Both dates are parsed with `parse_utc_required` (RFC 3339 input).
        #[arg(long)]
        start_date_utc: String,
        #[arg(long)]
        end_date_utc: String,
        #[arg(long)]
        source_session_id: Option<String>,
        #[arg(long)]
        parent_node_id: Option<String>,
        // Defaults to true in `main`.
        #[arg(long)]
        persist: Option<bool>,
        // Defaults to 5000 in `main`.
        #[arg(long)]
        limit: Option<usize>,
    },
}
162
/// Service handles built once per invocation by `build_services`, plus a description of the
/// storage backend they share (reported by the `health` subcommand).
struct Services {
    calibration: CalibrationService,
    store_context: StoreContextService,
    memory_find: MemoryFindService,
    memory_recall: MemoryRecallService,
    moods: MoodCatalogService,
    monthly_rollup: MonthlyRollupService,
    /// One of "in-memory", "surreal-remote" or "surreal-embedded".
    storage_mode: &'static str,
    /// The three fields below are `None` for the in-memory backend.
    storage_endpoint: Option<String>,
    storage_namespace: Option<String>,
    storage_database: Option<String>,
}
175
/// `SurrealDbClient` implementation that wraps a `surrealdb::Surreal<Any>` connection.
pub struct RuntimeSurrealDbClient {
    /// Connection handle; namespace and database are already selected by `connect`.
    db: surrealdb::Surreal<Any>,
}
179
180impl RuntimeSurrealDbClient {
181    pub async fn connect(
182        runtime: &SurrealDbRuntimeOptions,
183        user: Option<&str>,
184        password: Option<&str>,
185    ) -> Result<Self> {
186        let db = connect(runtime.endpoint.as_str()).await.with_context(|| {
187            format!(
188                "failed to connect to SurrealDB endpoint '{}'",
189                runtime.endpoint
190            )
191        })?;
192
193        if runtime.use_remote {
194            let username = user.filter(|v| !v.trim().is_empty()).unwrap_or("root");
195            let password = password.filter(|v| !v.trim().is_empty()).unwrap_or("root");
196
197            db.signin(Root {
198                username: username.to_string(),
199                password: password.to_string(),
200            })
201            .await
202            .context("failed to authenticate against remote SurrealDB")?;
203        }
204
205        db.use_ns(runtime.namespace.as_str())
206            .use_db(runtime.database.as_str())
207            .await
208            .with_context(|| {
209                format!(
210                    "failed to select namespace '{}' and database '{}'",
211                    runtime.namespace, runtime.database
212                )
213            })?;
214
215        Ok(Self { db })
216    }
217
218    fn is_read_query(query: &str) -> bool {
219        query
220            .trim_start()
221            .to_ascii_uppercase()
222            .starts_with("SELECT")
223    }
224}
225
226#[async_trait]
227impl SurrealDbClient for RuntimeSurrealDbClient {
228    async fn raw_query(&self, query: &str, parameters: QueryParams) -> Result<Vec<Value>> {
229        let is_read_query = Self::is_read_query(query);
230
231        let response = if parameters.is_empty() {
232            self.db.query(query).await?
233        } else {
234            self.db.query(query).bind(parameters).await?
235        };
236
237        let mut response = response.check()?;
238
239        if !is_read_query {
240            return Ok(Vec::new());
241        }
242
243        if let Ok(rows) = response.take::<Vec<Value>>(0) {
244            return Ok(rows);
245        }
246
247        if let Ok(Some(row)) = response.take::<Option<Value>>(0) {
248            return Ok(vec![row]);
249        }
250
251        Ok(Vec::new())
252    }
253}
254
/// Entry point: parses the command line, resolves the tenant, builds the services, runs the
/// selected subcommand and prints its result as a single JSON document on stdout.
///
/// Every session id taken from the command line is passed through `scope_session_id` before
/// it reaches a service, and nodes coming back from the SDK are passed through
/// `normalize_nodes_for_tenant` before they are printed.
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    let tenant = resolve_tenant(cli.tenant_id.as_deref())?;
    // Services (and therefore the storage backend) are built for every subcommand,
    // including `health`.
    let services = build_services(&cli).await?;

    let output = match cli.command {
        Commands::Health => {
            json!({
                "status": "ok",
                "transport": "sdk-core",
                "storage": {
                    "mode": services.storage_mode,
                    "endpoint": services.storage_endpoint,
                    "namespace": services.storage_namespace,
                    "database": services.storage_database,
                }
            })
        }
        Commands::Calibrate {
            session_id,
            stability,
            friction,
            logic,
            autonomy,
            trigger,
        } => {
            let session_id = scope_session_id(&tenant, &session_id);
            let result = services
                .calibration
                .calibrate_async(
                    &session_id,
                    stability,
                    friction,
                    logic,
                    autonomy,
                    &trigger,
                )
                .await?;

            json!({
                "previousAvec": avec_to_json(result.previous_avec),
                "delta": result.delta,
                // Rendered through the enum's `Debug` output.
                "driftClassification": format!("{:?}", result.drift_classification),
                "trigger": result.trigger,
                "triggerHistory": result.trigger_history,
                "isFirstCalibration": result.is_first_calibration,
            })
        }
        Commands::Store {
            session_id,
            node_file,
        } => {
            let session_id = scope_session_id(&tenant, &session_id);
            let node = fs::read_to_string(&node_file)
                .with_context(|| format!("failed to read node file: {}", node_file.display()))?;

            // Reject blank files before calling the service.
            if node.trim().is_empty() {
                bail!("node file is empty");
            }

            // The result is used without `?`; validity and any validation error are reported
            // through the JSON fields below.
            let result = services.store_context.store_async(&node, &session_id).await;
            json!({
                "nodeId": result.node_id,
                "psi": result.psi,
                "valid": result.valid,
                "validationError": result.validation_error,
            })
        }
        Commands::Context {
            session_id,
            stability,
            friction,
            logic,
            autonomy,
            limit,
            from_utc,
            to_utc,
            tiers,
            query_text,
            alpha,
            beta,
        } => {
            let session_id = scope_session_id(&tenant, &session_id);
            let from_utc = parse_utc_optional(from_utc.as_deref(), "from_utc")?;
            let to_utc = parse_utc_optional(to_utc.as_deref(), "to_utc")?;
            let tiers = normalize_tiers(tiers);

            let request = MemoryRecallRequest {
                scope: MemoryScope {
                    // The tenant is carried by the scoped session id, not by this field.
                    tenant_id: None,
                    session_ids: Some(vec![session_id]),
                    tiers,
                    from_utc,
                    to_utc,
                },
                page: MemoryPage {
                    limit: limit.unwrap_or(5),
                    cursor: None,
                },
                scoring: locus_sdk::domain::memory::MemoryScoring {
                    alpha: alpha.unwrap_or(0.7),
                    beta: beta.unwrap_or(0.3),
                    ..Default::default()
                },
                current_avec: Some(AvecState {
                    stability,
                    friction,
                    logic,
                    autonomy,
                }),
                query_text,
                query_embedding: None,
                ..Default::default()
            };

            let result = services.memory_recall.execute(&request).await?;
            // Drop nodes of other tenants and strip the tenant prefix from session ids.
            let nodes = normalize_nodes_for_tenant(result.nodes, &tenant);

            json!({
                "nodes": nodes.iter().map(sttp_node_to_json).collect::<Vec<_>>(),
                // Count after tenant filtering.
                "retrieved": nodes.len(),
                "psiRange": {
                    "min": result.psi_range.min,
                    "max": result.psi_range.max,
                    "average": result.psi_range.average,
                },
                "retrievalPath": format!("{:?}", result.retrieval_path),
                "hasMore": result.has_more,
                "nextCursor": result.next_cursor,
            })
        }
        Commands::Nodes { limit, session_id } => {
            let requested_limit = limit.unwrap_or(50).clamp(1, 200);

            let scoped_session = session_id.as_deref().map(|id| scope_session_id(&tenant, id));
            let result = services
                .memory_find
                .execute(&MemoryFindRequest {
                    scope: MemoryScope {
                        session_ids: scoped_session.map(|id| vec![id]),
                        ..Default::default()
                    },
                    page: MemoryPage {
                        // Without a session filter up to 4x the limit (capped at 200) is
                        // requested. NOTE(review): presumably headroom for the tenant
                        // filtering below - confirm.
                        limit: if session_id.is_some() {
                            requested_limit
                        } else {
                            (requested_limit * 4).clamp(1, 200)
                        },
                        cursor: None,
                    },
                    ..Default::default()
                })
                .await?;

            let mut nodes = normalize_nodes_for_tenant(result.nodes, &tenant);
            // Cut back to what the caller asked for after filtering.
            nodes.truncate(requested_limit);

            json!({
                "nodes": nodes.iter().map(sttp_node_to_json).collect::<Vec<_>>(),
                "retrieved": nodes.len(),
            })
        }
        Commands::Moods {
            target_mood,
            blend,
            current_stability,
            current_friction,
            current_logic,
            current_autonomy,
        } => {
            // Not awaited and not tenant-scoped: no session id is involved.
            let result = services.moods.get(
                target_mood.as_deref(),
                blend.unwrap_or(1.0),
                current_stability,
                current_friction,
                current_logic,
                current_autonomy,
            );

            json!({
                "presets": result.presets.iter().map(|preset| json!({
                    "name": preset.name,
                    "description": preset.description,
                    "avec": avec_to_json(preset.avec),
                })).collect::<Vec<_>>(),
                "applyGuide": result.apply_guide,
                "swapPreview": result.swap_preview.map(|preview| json!({
                    "targetMood": preview.target_mood,
                    "blend": preview.blend,
                    "current": avec_to_json(preview.current),
                    "target": avec_to_json(preview.target),
                    "blended": avec_to_json(preview.blended),
                })),
            })
        }
        Commands::Rollup {
            session_id,
            start_date_utc,
            end_date_utc,
            source_session_id,
            parent_node_id,
            persist,
            limit,
        } => {
            // Both the target and the optional source session are tenant-scoped.
            let session_id = scope_session_id(&tenant, &session_id);
            let source_session_id = source_session_id.map(|id| scope_session_id(&tenant, &id));
            let request = MonthlyRollupRequest {
                session_id,
                start_utc: parse_utc_required(&start_date_utc, "start_date_utc")?,
                end_utc: parse_utc_required(&end_date_utc, "end_date_utc")?,
                source_session_id,
                parent_node_id,
                limit: limit.unwrap_or(5000),
                persist: persist.unwrap_or(true),
            };

            // The result is used without `?`; failures are reported through the `success`
            // and `error` fields below.
            let result = services.monthly_rollup.create_async(request).await;
            json!({
                "success": result.success,
                "nodeId": result.node_id,
                "rawNode": result.raw_node,
                "error": result.error,
                "sourceNodes": result.source_nodes,
                "parentReference": result.parent_reference,
                "userAverage": avec_to_json(result.user_average),
                "modelAverage": avec_to_json(result.model_average),
                "compressionAverage": avec_to_json(result.compression_average),
                "rhoRange": {
                    "min": result.rho_range.min,
                    "max": result.rho_range.max,
                    "average": result.rho_range.average,
                },
                "kappaRange": {
                    "min": result.kappa_range.min,
                    "max": result.kappa_range.max,
                    "average": result.kappa_range.average,
                },
                "psiRange": {
                    "min": result.psi_range.min,
                    "max": result.psi_range.max,
                    "average": result.psi_range.average,
                },
                "rhoBands": {
                    "low": result.rho_bands.low,
                    "medium": result.rho_bands.medium,
                    "high": result.rho_bands.high,
                },
                "kappaBands": {
                    "low": result.kappa_bands.low,
                    "medium": result.kappa_bands.medium,
                    "high": result.kappa_bands.high,
                },
            })
        }
    };

    // Emit either indented or compact JSON depending on `--pretty`.
    if cli.pretty {
        println!(
            "{}",
            serde_json::to_string_pretty(&output)
                .context("failed to render pretty JSON output")?
        );
    } else {
        println!("{}", serde_json::to_string(&output)?);
    }

    Ok(())
}
524
/// Builds the node store selected by `cli.storage`, initializes it, and wires every service
/// to that shared store.
///
/// # Errors
/// Fails when the SurrealDB runtime options cannot be resolved, when connecting fails, or
/// when the store initializer returns an error.
async fn build_services(cli: &Cli) -> Result<Services> {
    // The same concrete store is exposed through two trait objects: one for initialization
    // and one for node access.
    let (store, initializer, storage_mode, storage_endpoint, storage_namespace, storage_database) =
        match cli.storage {
            StorageMode::InMemory => {
                let store = Arc::new(InMemoryNodeStore::new());
                let initializer: Arc<dyn NodeStoreInitializer> = store.clone();
                let node_store: Arc<dyn NodeStore> = store;
                (
                    node_store,
                    initializer,
                    "in-memory",
                    None,
                    None,
                    None,
                )
            }
            StorageMode::Surreal => {
                let settings = surreal_settings_from_cli(cli);
                let runtime = surreal_runtime_from_cli(cli, &settings)?;

                let client = Arc::new(
                    RuntimeSurrealDbClient::connect(
                        &runtime,
                        settings.user.as_deref(),
                        settings.password.as_deref(),
                    )
                    .await?,
                );

                let store = Arc::new(SurrealDbNodeStore::new(client));
                let initializer: Arc<dyn NodeStoreInitializer> = store.clone();
                let node_store: Arc<dyn NodeStore> = store;

                (
                    node_store,
                    initializer,
                    // Label reported by the `health` subcommand.
                    if runtime.use_remote {
                        "surreal-remote"
                    } else {
                        "surreal-embedded"
                    },
                    Some(runtime.endpoint),
                    Some(runtime.namespace),
                    Some(runtime.database),
                )
            }
        };

    // Initialization runs before any service is constructed.
    initializer.initialize_async().await?;

    let validator: Arc<dyn NodeValidator> = Arc::new(TreeSitterValidator::new());

    Ok(Services {
        calibration: CalibrationService::new(store.clone()),
        store_context: StoreContextService::new(store.clone(), validator.clone()),
        memory_find: MemoryFindService::new(store.clone()),
        memory_recall: MemoryRecallService::new(store.clone()),
        moods: MoodCatalogService::new(),
        monthly_rollup: MonthlyRollupService::new(store, validator),
        storage_mode,
        storage_endpoint,
        storage_namespace,
        storage_database,
    })
}
590
591fn surreal_settings_from_cli(cli: &Cli) -> SurrealDbSettings {
592    let mut settings = SurrealDbSettings {
593        endpoints: SurrealDbEndpointsSettings {
594            embedded: Some(
595                cli.surreal_embedded_endpoint
596                    .clone()
597                    .unwrap_or_else(|| "surrealkv://data/locus-cli".to_string()),
598            ),
599            remote: cli.surreal_remote_endpoint.clone(),
600        },
601        namespace: cli.surreal_namespace.clone(),
602        database: cli.surreal_database.clone(),
603        user: cli.surreal_username.clone(),
604        password: cli.surreal_password.clone(),
605    };
606
607    if let Some(endpoint) = cli
608        .surreal_endpoint
609        .as_ref()
610        .map(|value| value.trim())
611        .filter(|value| !value.is_empty())
612    {
613        settings.endpoints.embedded = Some(endpoint.to_string());
614        settings.endpoints.remote = Some(endpoint.to_string());
615    }
616
617    settings
618}
619
620fn surreal_runtime_from_cli(cli: &Cli, settings: &SurrealDbSettings) -> Result<SurrealDbRuntimeOptions> {
621    let mut args = Vec::new();
622    if cli.remote {
623        args.push("--remote".to_string());
624    }
625
626    SurrealDbRuntimeOptions::from_args(&args, settings, Some(&cli.root_dir_name))
627}
628
/// Trims and ASCII-lowercases every tier name, discarding entries that are blank.
///
/// Returns `None` when nothing is left, otherwise the cleaned names in their original order.
fn normalize_tiers(tiers: Vec<String>) -> Option<Vec<String>> {
    let mut cleaned = Vec::with_capacity(tiers.len());

    for raw in &tiers {
        let name = raw.trim();
        if !name.is_empty() {
            cleaned.push(name.to_ascii_lowercase());
        }
    }

    (!cleaned.is_empty()).then_some(cleaned)
}
642
643fn resolve_tenant(value: Option<&str>) -> Result<String> {
644    match value.and_then(normalize_tenant_value) {
645        Some(tenant) => Ok(tenant),
646        None => {
647            if value.is_some() {
648                bail!("tenant id can only contain letters, digits, '-' or '_'");
649            }
650            Ok(DEFAULT_TENANT.to_string())
651        }
652    }
653}
654
/// Returns the trimmed, ASCII-lowercased tenant id, or `None` when the input is blank or
/// contains anything other than ASCII letters, ASCII digits, `-` or `_`.
fn normalize_tenant_value(value: &str) -> Option<String> {
    let candidate = value.trim();
    if candidate.is_empty() {
        return None;
    }

    // ASCII lowercasing never changes whether a character is allowed, so the check can run
    // on the trimmed input directly.
    let is_allowed = |ch: char| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_');

    candidate
        .chars()
        .all(is_allowed)
        .then(|| candidate.to_ascii_lowercase())
}
671
672fn scope_session_id(tenant: &str, session_id: &str) -> String {
673    if tenant == DEFAULT_TENANT {
674        session_id.to_string()
675    } else {
676        format!("{TENANT_SCOPE_PREFIX}{tenant}{TENANT_SCOPE_SEPARATOR}{session_id}")
677    }
678}
679
680fn parse_scoped_session_id(session_id: &str) -> Option<(&str, &str)> {
681    let remainder = session_id.strip_prefix(TENANT_SCOPE_PREFIX)?;
682    remainder.split_once(TENANT_SCOPE_SEPARATOR)
683}
684
685fn session_belongs_to_tenant(session_id: &str, tenant: &str) -> bool {
686    match parse_scoped_session_id(session_id) {
687        Some((scoped_tenant, _)) => scoped_tenant == tenant,
688        None => tenant == DEFAULT_TENANT,
689    }
690}
691
692fn display_session_id(session_id: &str) -> String {
693    match parse_scoped_session_id(session_id) {
694        Some((_, base_session_id)) => base_session_id.to_string(),
695        None => session_id.to_string(),
696    }
697}
698
699fn normalize_nodes_for_tenant(nodes: Vec<SttpNode>, tenant: &str) -> Vec<SttpNode> {
700    nodes
701        .into_iter()
702        .filter_map(|mut node| {
703            if !session_belongs_to_tenant(&node.session_id, tenant) {
704                return None;
705            }
706            node.session_id = display_session_id(&node.session_id);
707            Some(node)
708        })
709        .collect()
710}
711
712fn avec_to_json(avec: AvecState) -> Value {
713    json!({
714        "stability": avec.stability,
715        "friction": avec.friction,
716        "logic": avec.logic,
717        "autonomy": avec.autonomy,
718        "psi": avec.psi(),
719    })
720}
721
722fn sttp_node_to_json(node: &SttpNode) -> Value {
723    json!({
724        "raw": node.raw,
725        "sessionId": node.session_id,
726        "tier": node.tier,
727        "timestamp": node.timestamp.to_rfc3339(),
728        "compressionDepth": node.compression_depth,
729        "parentNodeId": node.parent_node_id,
730        "syncKey": node.sync_key,
731        "updatedAt": node.updated_at.to_rfc3339(),
732        "contextSummary": node.context_summary,
733        "embeddingModel": node.embedding_model,
734        "embeddingDimensions": node.embedding_dimensions,
735        "embeddedAt": node.embedded_at.map(|v| v.to_rfc3339()),
736        "userAvec": avec_to_json(node.user_avec),
737        "modelAvec": avec_to_json(node.model_avec),
738        "compressionAvec": node.compression_avec.map(avec_to_json),
739        "rho": node.rho,
740        "kappa": node.kappa,
741        "psi": node.psi,
742    })
743}
744
745fn parse_utc_required(value: &str, field: &str) -> Result<DateTime<Utc>> {
746    DateTime::parse_from_rfc3339(value)
747        .map(|parsed| parsed.with_timezone(&Utc))
748        .map_err(|_| anyhow!("{field} must be an ISO8601 UTC datetime"))
749}
750
751fn parse_utc_optional(value: Option<&str>, field: &str) -> Result<Option<DateTime<Utc>>> {
752    match value {
753        Some(raw) => parse_utc_required(raw, field).map(Some),
754        None => Ok(None),
755    }
756}