Skip to main content

locus_sdk/application/ai_router.rs

1use anyhow::Result;
2use locus_core_rs::domain::models::AvecState;
3
4use crate::domain::ai::{AiProviderRegistry, EmbedRequest, ScoreAvecRequest};
5
6pub async fn route_embedding(
7    registry: &dyn AiProviderRegistry,
8    request: &EmbedRequest,
9) -> Result<Vec<f32>> {
10    let provider = registry.resolve(request.task, request.provider_id.as_deref(), request.policy)?;
11
12    match request.task {
13        crate::domain::ai::AiTask::SemanticEmbedding => {
14            provider.embed_semantic(request).await
15        }
16        crate::domain::ai::AiTask::AvecEmbedding => {
17            provider.embed_avec(request).await
18        }
19        crate::domain::ai::AiTask::AvecScoring => {
20            anyhow::bail!("AiTask::AvecScoring is not an embedding task")
21        }
22    }
23}
24
25pub async fn route_avec_score(
26    registry: &dyn AiProviderRegistry,
27    request: &ScoreAvecRequest,
28) -> Result<AvecState> {
29    let provider = registry.resolve(
30        crate::domain::ai::AiTask::AvecScoring,
31        request.provider_id.as_deref(),
32        request.policy,
33    )?;
34
35    provider.score_avec(request).await
36}
37
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use anyhow::{Result, anyhow};
    use async_trait::async_trait;
    use locus_core_rs::domain::models::AvecState;

    use super::{route_avec_score, route_embedding};
    use crate::domain::ai::{
        AiCapability, AiProvider, AiProviderRegistry, AiTask, EmbedRequest, ProviderPolicy,
        ScoreAvecRequest,
    };
    use crate::infrastructure::registry::InMemoryAiProviderRegistry;

    // Per-method call counters shared between a test and the mock provider it
    // registers. They are only incremented and read, never held across an
    // `.await`, so plain atomics suffice and no async lock is needed.
    #[derive(Clone, Default)]
    struct CallCounters {
        semantic_calls: Arc<AtomicUsize>,
        avec_embedding_calls: Arc<AtomicUsize>,
        avec_scoring_calls: Arc<AtomicUsize>,
    }

    impl CallCounters {
        /// Current `(semantic, avec_embedding, avec_scoring)` call counts.
        fn snapshot(&self) -> (usize, usize, usize) {
            (
                self.semantic_calls.load(Ordering::SeqCst),
                self.avec_embedding_calls.load(Ordering::SeqCst),
                self.avec_scoring_calls.load(Ordering::SeqCst),
            )
        }
    }

    /// Mock provider that advertises every capability and counts each call.
    struct FullMockProvider {
        id: String,
        counters: CallCounters,
    }

    impl FullMockProvider {
        fn new(id: impl Into<String>, counters: CallCounters) -> Self {
            Self {
                id: id.into(),
                counters,
            }
        }
    }

    #[async_trait]
    impl AiProvider for FullMockProvider {
        fn provider_id(&self) -> &str {
            &self.id
        }

        fn capabilities(&self) -> &'static [AiCapability] {
            &[
                AiCapability::SemanticEmbedding,
                AiCapability::AvecEmbedding,
                AiCapability::AvecScoring,
            ]
        }

        async fn embed_semantic(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            self.counters.semantic_calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.1, 0.2, 0.3])
        }

        async fn embed_avec(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            self.counters.avec_embedding_calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.4, 0.5, 0.6])
        }

        async fn score_avec(&self, _request: &ScoreAvecRequest) -> Result<AvecState> {
            self.counters.avec_scoring_calls.fetch_add(1, Ordering::SeqCst);
            Ok(AvecState {
                stability: 0.8,
                friction: 0.2,
                logic: 0.9,
                autonomy: 0.7,
            })
        }
    }

    /// Mock provider that only advertises semantic embeddings.
    struct SemanticOnlyProvider;

    #[async_trait]
    impl AiProvider for SemanticOnlyProvider {
        fn provider_id(&self) -> &str {
            "semantic-only"
        }

        fn capabilities(&self) -> &'static [AiCapability] {
            &[AiCapability::SemanticEmbedding]
        }

        async fn embed_semantic(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            Ok(vec![1.0])
        }

        async fn embed_avec(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            Err(anyhow!("unsupported"))
        }

        async fn score_avec(&self, _request: &ScoreAvecRequest) -> Result<AvecState> {
            Err(anyhow!("unsupported"))
        }
    }

    /// Builds a registry holding one `FullMockProvider` registered as `"mock"`.
    fn registry_with_full_mock(counters: &CallCounters) -> InMemoryAiProviderRegistry {
        let mut registry = InMemoryAiProviderRegistry::new();
        registry.register(FullMockProvider::new("mock", counters.clone()));
        registry
    }

    /// Builds an embedding request for `task` that explicitly targets `"mock"`.
    fn mock_embed_request(task: AiTask) -> EmbedRequest {
        EmbedRequest {
            text: "hello".to_string(),
            task,
            provider_id: Some("mock".to_string()),
            model: None,
            policy: ProviderPolicy::Preferred,
        }
    }

    #[tokio::test]
    async fn route_embedding_dispatches_semantic_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);
        let request = mock_embed_request(AiTask::SemanticEmbedding);

        let vector = route_embedding(&registry as &dyn AiProviderRegistry, &request)
            .await
            .expect("semantic routing should succeed");

        assert_eq!(vector, vec![0.1, 0.2, 0.3]);
        assert_eq!(counters.snapshot(), (1, 0, 0));
    }

    #[tokio::test]
    async fn route_embedding_dispatches_avec_embedding_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);
        let request = mock_embed_request(AiTask::AvecEmbedding);

        let vector = route_embedding(&registry as &dyn AiProviderRegistry, &request)
            .await
            .expect("AVEC embedding routing should succeed");

        assert_eq!(vector, vec![0.4, 0.5, 0.6]);
        assert_eq!(counters.snapshot(), (0, 1, 0));
    }

    #[tokio::test]
    async fn route_embedding_rejects_scoring_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);
        let request = mock_embed_request(AiTask::AvecScoring);

        let err = route_embedding(&registry as &dyn AiProviderRegistry, &request)
            .await
            .expect_err("AvecScoring must not be routed as an embedding");

        assert!(err.to_string().contains("is not an embedding task"));
        // No provider entry point may be invoked for a rejected task.
        assert_eq!(counters.snapshot(), (0, 0, 0));
    }

    #[tokio::test]
    async fn route_avec_score_dispatches_scoring_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);

        let request = ScoreAvecRequest {
            text: "score this".to_string(),
            provider_id: Some("mock".to_string()),
            model: None,
            policy: ProviderPolicy::Preferred,
        };

        let avec = route_avec_score(&registry as &dyn AiProviderRegistry, &request)
            .await
            .expect("AVEC scoring should succeed");

        assert!((avec.stability - 0.8).abs() < f32::EPSILON);
        assert_eq!(counters.snapshot(), (0, 0, 1));
    }

    #[tokio::test]
    async fn route_avec_score_fails_when_provider_lacks_capability() {
        let mut registry = InMemoryAiProviderRegistry::new();
        registry.register(SemanticOnlyProvider);

        let request = ScoreAvecRequest {
            text: "score this".to_string(),
            provider_id: Some("semantic-only".to_string()),
            model: None,
            policy: ProviderPolicy::Preferred,
        };

        let err = route_avec_score(&registry as &dyn AiProviderRegistry, &request)
            .await
            .expect_err("expected capability mismatch error");

        assert!(
            err.to_string()
                .contains("does not support task 'AvecScoring'")
        );
    }
}