1use anyhow::Result;
2use locus_core_rs::domain::models::AvecState;
3
4use crate::domain::ai::{AiProviderRegistry, EmbedRequest, ScoreAvecRequest};
5
6pub async fn route_embedding(
7 registry: &dyn AiProviderRegistry,
8 request: &EmbedRequest,
9) -> Result<Vec<f32>> {
10 let provider = registry.resolve(request.task, request.provider_id.as_deref(), request.policy)?;
11
12 match request.task {
13 crate::domain::ai::AiTask::SemanticEmbedding => {
14 provider.embed_semantic(request).await
15 }
16 crate::domain::ai::AiTask::AvecEmbedding => {
17 provider.embed_avec(request).await
18 }
19 crate::domain::ai::AiTask::AvecScoring => {
20 anyhow::bail!("AiTask::AvecScoring is not an embedding task")
21 }
22 }
23}
24
25pub async fn route_avec_score(
26 registry: &dyn AiProviderRegistry,
27 request: &ScoreAvecRequest,
28) -> Result<AvecState> {
29 let provider = registry.resolve(
30 crate::domain::ai::AiTask::AvecScoring,
31 request.provider_id.as_deref(),
32 request.policy,
33 )?;
34
35 provider.score_avec(request).await
36}
37
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use anyhow::{Result, anyhow};
    use async_trait::async_trait;
    use locus_core_rs::domain::models::AvecState;

    use super::{route_avec_score, route_embedding};
    use crate::domain::ai::{
        AiCapability, AiProvider, AiProviderRegistry, AiTask, EmbedRequest, ProviderPolicy,
        ScoreAvecRequest,
    };
    use crate::infrastructure::registry::InMemoryAiProviderRegistry;

    /// Provider id under which `FullMockProvider` is registered by the helpers below.
    const MOCK_ID: &str = "mock";

    /// Per-entry-point call counters shared between a test and its mock provider.
    ///
    /// Plain atomics are enough here: the counters are only incremented and read,
    /// never held across an `.await`.
    #[derive(Clone, Default)]
    struct CallCounters {
        semantic_calls: Arc<AtomicUsize>,
        avec_embedding_calls: Arc<AtomicUsize>,
        avec_scoring_calls: Arc<AtomicUsize>,
    }

    impl CallCounters {
        /// Returns the `(semantic, avec_embedding, avec_scoring)` call counts.
        fn snapshot(&self) -> (usize, usize, usize) {
            (
                self.semantic_calls.load(Ordering::SeqCst),
                self.avec_embedding_calls.load(Ordering::SeqCst),
                self.avec_scoring_calls.load(Ordering::SeqCst),
            )
        }
    }

    /// Mock provider advertising every capability; each call bumps a counter.
    struct FullMockProvider {
        id: String,
        counters: CallCounters,
    }

    impl FullMockProvider {
        fn new(id: impl Into<String>, counters: CallCounters) -> Self {
            Self {
                id: id.into(),
                counters,
            }
        }
    }

    #[async_trait]
    impl AiProvider for FullMockProvider {
        fn provider_id(&self) -> &str {
            &self.id
        }

        fn capabilities(&self) -> &'static [AiCapability] {
            &[
                AiCapability::SemanticEmbedding,
                AiCapability::AvecEmbedding,
                AiCapability::AvecScoring,
            ]
        }

        async fn embed_semantic(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            self.counters.semantic_calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.1, 0.2, 0.3])
        }

        async fn embed_avec(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            self.counters.avec_embedding_calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.4, 0.5, 0.6])
        }

        async fn score_avec(&self, _request: &ScoreAvecRequest) -> Result<AvecState> {
            self.counters.avec_scoring_calls.fetch_add(1, Ordering::SeqCst);
            Ok(AvecState {
                stability: 0.8,
                friction: 0.2,
                logic: 0.9,
                autonomy: 0.7,
            })
        }
    }

    /// Provider that only advertises `SemanticEmbedding`; every other call errors.
    struct SemanticOnlyProvider;

    #[async_trait]
    impl AiProvider for SemanticOnlyProvider {
        fn provider_id(&self) -> &str {
            "semantic-only"
        }

        fn capabilities(&self) -> &'static [AiCapability] {
            &[AiCapability::SemanticEmbedding]
        }

        async fn embed_semantic(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            Ok(vec![1.0])
        }

        async fn embed_avec(&self, _request: &EmbedRequest) -> Result<Vec<f32>> {
            Err(anyhow!("unsupported"))
        }

        async fn score_avec(&self, _request: &ScoreAvecRequest) -> Result<AvecState> {
            Err(anyhow!("unsupported"))
        }
    }

    /// Builds a registry holding one `FullMockProvider` (id `MOCK_ID`) that reports to `counters`.
    fn registry_with_full_mock(counters: &CallCounters) -> InMemoryAiProviderRegistry {
        let mut registry = InMemoryAiProviderRegistry::new();
        registry.register(FullMockProvider::new(MOCK_ID, counters.clone()));
        registry
    }

    /// Builds an embedding request for `task` that names the mock provider.
    fn embed_request(task: AiTask) -> EmbedRequest {
        EmbedRequest {
            text: "hello".to_string(),
            task,
            provider_id: Some(MOCK_ID.to_string()),
            model: None,
            policy: ProviderPolicy::Preferred,
        }
    }

    /// Builds a scoring request that names `provider_id`.
    fn score_request(provider_id: &str) -> ScoreAvecRequest {
        ScoreAvecRequest {
            text: "score this".to_string(),
            provider_id: Some(provider_id.to_string()),
            model: None,
            policy: ProviderPolicy::Preferred,
        }
    }

    #[tokio::test]
    async fn route_embedding_dispatches_semantic_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);

        let vector = route_embedding(
            &registry as &dyn AiProviderRegistry,
            &embed_request(AiTask::SemanticEmbedding),
        )
        .await
        .expect("semantic routing should succeed");

        assert_eq!(vector, vec![0.1, 0.2, 0.3]);
        assert_eq!(counters.snapshot(), (1, 0, 0));
    }

    #[tokio::test]
    async fn route_embedding_dispatches_avec_embedding_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);

        let vector = route_embedding(
            &registry as &dyn AiProviderRegistry,
            &embed_request(AiTask::AvecEmbedding),
        )
        .await
        .expect("AVEC embedding routing should succeed");

        assert_eq!(vector, vec![0.4, 0.5, 0.6]);
        assert_eq!(counters.snapshot(), (0, 1, 0));
    }

    #[tokio::test]
    async fn route_embedding_rejects_scoring_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);

        let err = route_embedding(
            &registry as &dyn AiProviderRegistry,
            &embed_request(AiTask::AvecScoring),
        )
        .await
        .expect_err("AvecScoring must not be routed as an embedding");

        assert!(err.to_string().contains("is not an embedding task"));
        // No provider entry point may be invoked for a rejected task.
        assert_eq!(counters.snapshot(), (0, 0, 0));
    }

    #[tokio::test]
    async fn route_avec_score_dispatches_scoring_task() {
        let counters = CallCounters::default();
        let registry = registry_with_full_mock(&counters);

        let avec = route_avec_score(
            &registry as &dyn AiProviderRegistry,
            &score_request(MOCK_ID),
        )
        .await
        .expect("AVEC scoring should succeed");

        assert!((avec.stability - 0.8).abs() < f32::EPSILON);
        assert_eq!(counters.snapshot(), (0, 0, 1));
    }

    #[tokio::test]
    async fn route_avec_score_fails_when_provider_lacks_capability() {
        let mut registry = InMemoryAiProviderRegistry::new();
        registry.register(SemanticOnlyProvider);

        let err = route_avec_score(
            &registry as &dyn AiProviderRegistry,
            &score_request("semantic-only"),
        )
        .await
        .expect_err("expected capability mismatch error");

        assert!(
            err.to_string()
                .contains("does not support task 'AvecScoring'")
        );
    }
}