Skip to main content

locus_gateway/
providers.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use serde_json::{Value, json};
7
8use locus_core_rs::domain::contracts::EmbeddingProvider;
9use locus_core_rs::domain::models as core_models;
10
11#[derive(Debug, Deserialize)]
12struct ParsedAvecScore {
13    stability: f32,
14    friction: f32,
15    logic: f32,
16    autonomy: f32,
17}
18
19#[derive(Debug, Serialize)]
20struct OllamaChatRequest<'a> {
21    model: &'a str,
22    messages: Vec<OllamaChatMessage<'a>>,
23    stream: bool,
24    format: Value,
25}
26
27#[derive(Debug, Serialize, Deserialize)]
28struct OllamaChatMessage<'a> {
29    role: &'a str,
30    content: &'a str,
31}
32
33#[derive(Debug, Deserialize)]
34struct OllamaChatResponse {
35    message: Option<OllamaChatMessageOwned>,
36}
37
38#[derive(Debug, Deserialize)]
39struct OllamaChatMessageOwned {
40    content: String,
41}
42
43#[async_trait]
44pub(crate) trait AvecScorer: Send + Sync {
45    fn provider_name(&self) -> &str;
46    fn model_name(&self) -> &str;
47    async fn score_async(&self, text: &str) -> Result<core_models::AvecState>;
48}
49
50#[derive(Clone)]
51pub(crate) struct OllamaAvecScorer {
52    client: reqwest::Client,
53    endpoint: String,
54    model: String,
55}
56
57impl OllamaAvecScorer {
58    pub(crate) fn new(endpoint: String, model: String) -> Self {
59        Self {
60            client: reqwest::Client::new(),
61            endpoint,
62            model,
63        }
64    }
65}
66
67#[async_trait]
68impl AvecScorer for OllamaAvecScorer {
69    fn provider_name(&self) -> &str {
70        "ollama"
71    }
72
73    fn model_name(&self) -> &str {
74        &self.model
75    }
76
77    async fn score_async(&self, text: &str) -> Result<core_models::AvecState> {
78        let prompt = "Return ONLY valid compact JSON with numeric fields in [0,1]: stability, friction, logic, autonomy.";
79        let response = self
80            .client
81            .post(&self.endpoint)
82            .json(&OllamaChatRequest {
83                model: &self.model,
84                messages: vec![
85                    OllamaChatMessage {
86                        role: "system",
87                        content: prompt,
88                    },
89                    OllamaChatMessage {
90                        role: "user",
91                        content: text,
92                    },
93                ],
94                stream: false,
95                format: json!("json"),
96            })
97            .send()
98            .await?
99            .error_for_status()?;
100
101        let body: OllamaChatResponse = response.json().await?;
102        let content = body
103            .message
104            .map(|message| message.content)
105            .ok_or_else(|| anyhow!("ollama scoring response missing message content"))?;
106
107        parse_avec_state_from_text(&content)
108    }
109}
110
111pub(crate) async fn resolve_query_embedding(
112    embedding_provider: Option<&Arc<dyn EmbeddingProvider>>,
113    query_text: Option<&str>,
114    provided_embedding: Option<&[f32]>,
115) -> Option<Vec<f32>> {
116    if let Some(embedding) = provided_embedding.filter(|embedding| !embedding.is_empty()) {
117        return Some(embedding.to_vec());
118    }
119
120    let text = match query_text.and_then(|text| {
121        let trimmed = text.trim();
122        if trimmed.is_empty() {
123            None
124        } else {
125            Some(trimmed)
126        }
127    }) {
128        Some(text) => text,
129        None => return None,
130    };
131
132    let provider = embedding_provider?;
133    provider.embed_async(text).await.ok()
134}
135
136pub(crate) fn parse_avec_state_from_text(content: &str) -> Result<core_models::AvecState> {
137    let parsed: ParsedAvecScore = match serde_json::from_str(content) {
138        Ok(value) => value,
139        Err(_) => {
140            let start = content
141                .find('{')
142                .ok_or_else(|| anyhow!("AVEC scorer did not return JSON"))?;
143            let end = content
144                .rfind('}')
145                .ok_or_else(|| anyhow!("AVEC scorer returned malformed JSON"))?;
146            let candidate = &content[start..=end];
147            serde_json::from_str(candidate)
148                .map_err(|err| anyhow!("failed to parse AVEC JSON payload: {err}"))?
149        }
150    };
151
152    Ok(core_models::AvecState {
153        stability: parsed.stability.clamp(0.0, 1.0),
154        friction: parsed.friction.clamp(0.0, 1.0),
155        logic: parsed.logic.clamp(0.0, 1.0),
156        autonomy: parsed.autonomy.clamp(0.0, 1.0),
157    })
158}