Skip to main content

locus_sdk/infrastructure/genai_adapter/
provider.rs

1use anyhow::{Context, Result, anyhow};
2use async_trait::async_trait;
3use genai::Client;
4use genai::chat::ChatRequest;
5use locus_core_rs::domain::models::AvecState;
6
7use crate::domain::ai::{AiCapability, AiProvider, EmbedRequest, ScoreAvecRequest};
8
9pub struct GenaiProviderAdapter {
10    provider_id: String,
11    default_model: Option<String>,
12    client: Client,
13}
14
15impl GenaiProviderAdapter {
16    pub fn new(provider_id: impl Into<String>, default_model: Option<String>) -> Self {
17        Self {
18            provider_id: provider_id.into(),
19            default_model,
20            client: Client::default(),
21        }
22    }
23
24    pub fn default_model(&self) -> Option<&str> {
25        self.default_model.as_deref()
26    }
27
28    fn resolve_model<'a>(&'a self, requested: Option<&'a str>) -> Result<&'a str> {
29        requested
30            .filter(|value| !value.trim().is_empty())
31            .or_else(|| self.default_model.as_deref())
32            .ok_or_else(|| anyhow!("no model provided and no default model configured"))
33    }
34}
35
36#[async_trait]
37impl AiProvider for GenaiProviderAdapter {
38    fn provider_id(&self) -> &str {
39        &self.provider_id
40    }
41
42    fn capabilities(&self) -> &'static [AiCapability] {
43        &[
44            AiCapability::SemanticEmbedding,
45            AiCapability::AvecEmbedding,
46            AiCapability::AvecScoring,
47        ]
48    }
49
50    async fn embed_semantic(&self, request: &EmbedRequest) -> Result<Vec<f32>> {
51        let model = self.resolve_model(request.model.as_deref())?;
52
53        let response = self
54            .client
55            .embed(model, request.text.clone(), None)
56            .await
57            .with_context(|| format!("genai semantic embedding call failed for model '{model}'"))?;
58
59        response
60            .first_embedding()
61            .map(|embedding| embedding.vector().to_vec())
62            .filter(|vector| !vector.is_empty())
63            .ok_or_else(|| anyhow!("genai semantic embedding response is missing vector data"))
64    }
65
66    async fn embed_avec(&self, request: &EmbedRequest) -> Result<Vec<f32>> {
67        let model = self.resolve_model(request.model.as_deref())?;
68
69        let response = self
70            .client
71            .embed(model, request.text.clone(), None)
72            .await
73            .with_context(|| format!("genai AVEC embedding call failed for model '{model}'"))?;
74
75        response
76            .first_embedding()
77            .map(|embedding| embedding.vector().to_vec())
78            .filter(|vector| !vector.is_empty())
79            .ok_or_else(|| anyhow!("genai AVEC embedding response is missing vector data"))
80    }
81
82    async fn score_avec(&self, request: &ScoreAvecRequest) -> Result<AvecState> {
83        let model = self.resolve_model(request.model.as_deref())?;
84        let prompt = "Return only compact JSON with numeric fields in [0,1]: stability, friction, logic, autonomy.";
85
86        let chat_req = ChatRequest::from_system(prompt).append_message(genai::chat::ChatMessage::user(
87            request.text.clone(),
88        ));
89
90        let response = self
91            .client
92            .exec_chat(model, chat_req, None)
93            .await
94            .with_context(|| format!("genai AVEC scoring call failed for model '{model}'"))?;
95
96        let content = response
97            .first_text()
98            .ok_or_else(|| anyhow!("genai AVEC scoring returned no text content"))?;
99
100        parse_avec_state_from_text(content)
101    }
102}
103
104#[derive(Debug, serde::Deserialize)]
105struct ParsedAvecScore {
106    stability: f32,
107    friction: f32,
108    logic: f32,
109    autonomy: f32,
110}
111
112fn parse_avec_state_from_text(raw: &str) -> Result<AvecState> {
113    let parsed: ParsedAvecScore = serde_json::from_str(raw)
114        .with_context(|| "failed to parse AVEC JSON response from model")?;
115
116    let values = [parsed.stability, parsed.friction, parsed.logic, parsed.autonomy];
117    if values.iter().any(|value| !(0.0..=1.0).contains(value)) {
118        return Err(anyhow!(
119            "AVEC response contains values outside [0,1]: {:?}",
120            values
121        ));
122    }
123
124    Ok(AvecState {
125        stability: parsed.stability,
126        friction: parsed.friction,
127        logic: parsed.logic,
128        autonomy: parsed.autonomy,
129    })
130}
131
#[cfg(test)]
mod tests {
    use super::parse_avec_state_from_text;

    /// Asserts `actual` matches `expected` to within `f32::EPSILON`, reporting both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < f32::EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn parse_avec_state_accepts_valid_payload() {
        let raw = r#"{"stability":0.8,"friction":0.2,"logic":0.9,"autonomy":0.7}"#;
        let avec = parse_avec_state_from_text(raw).expect("expected valid AVEC payload");
        assert_close(avec.stability, 0.8);
        assert_close(avec.friction, 0.2);
        assert_close(avec.logic, 0.9);
        assert_close(avec.autonomy, 0.7);
    }

    #[test]
    fn parse_avec_state_accepts_inclusive_bounds() {
        let raw = r#"{"stability":0.0,"friction":1.0,"logic":0.0,"autonomy":1.0}"#;
        let avec = parse_avec_state_from_text(raw).expect("0 and 1 are valid scores");
        assert_close(avec.stability, 0.0);
        assert_close(avec.friction, 1.0);
        assert_close(avec.logic, 0.0);
        assert_close(avec.autonomy, 1.0);
    }

    #[test]
    fn parse_avec_state_ignores_unknown_fields() {
        let raw = r#"{"stability":0.5,"friction":0.5,"logic":0.5,"autonomy":0.5,"note":"ok"}"#;
        let avec = parse_avec_state_from_text(raw).expect("extra keys should be ignored");
        assert_close(avec.stability, 0.5);
    }

    #[test]
    fn parse_avec_state_rejects_out_of_range_values() {
        let raw = r#"{"stability":1.2,"friction":0.2,"logic":0.9,"autonomy":0.7}"#;
        let err = parse_avec_state_from_text(raw).expect_err("expected range validation error");
        assert!(err.to_string().contains("outside [0,1]"));
    }

    #[test]
    fn parse_avec_state_rejects_negative_values() {
        let raw = r#"{"stability":0.5,"friction":-0.1,"logic":0.9,"autonomy":0.7}"#;
        let err = parse_avec_state_from_text(raw).expect_err("negative scores are out of range");
        assert!(err.to_string().contains("outside [0,1]"));
    }

    #[test]
    fn parse_avec_state_rejects_missing_fields() {
        let raw = r#"{"stability":0.5,"friction":0.5}"#;
        assert!(parse_avec_state_from_text(raw).is_err());
    }

    #[test]
    fn parse_avec_state_rejects_non_json_text() {
        assert!(parse_avec_state_from_text("I cannot score this text.").is_err());
        assert!(parse_avec_state_from_text("").is_err());
    }
}