Skip to main content

locus_core_rs/application/services/
embedding_migration_service.rs

1use std::collections::HashSet;
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use chrono::{DateTime, Utc};
6
7use crate::domain::contracts::{EmbeddingProvider, NodeStore};
8use crate::domain::models::{NodeQuery, NodeUpsertStatus, SttpNode};
9
10#[derive(Debug, Clone, Default)]
11pub struct EmbeddingMigrationFilter {
12    pub session_id: Option<String>,
13    pub from_utc: Option<DateTime<Utc>>,
14    pub to_utc: Option<DateTime<Utc>>,
15    pub tiers: Option<Vec<String>>,
16    pub has_embedding: Option<bool>,
17    pub embedding_model: Option<String>,
18    pub sync_keys: Option<Vec<String>>,
19}
20
21#[derive(Debug, Clone)]
22pub struct EmbeddingMigrationPreviewRequest {
23    pub filter: EmbeddingMigrationFilter,
24    pub sample_limit: usize,
25    pub max_nodes: usize,
26}
27
28#[derive(Debug, Clone)]
29pub struct EmbeddingMigrationSample {
30    pub sync_key: String,
31    pub session_id: String,
32    pub tier: String,
33    pub has_embedding: bool,
34    pub embedding_model: Option<String>,
35    pub embedding_dimensions: Option<usize>,
36    pub embedded_at: Option<DateTime<Utc>>,
37    pub updated_at: DateTime<Utc>,
38    pub context_summary: Option<String>,
39}
40
41#[derive(Debug, Clone)]
42pub struct EmbeddingMigrationPreviewResult {
43    pub total_candidates: usize,
44    pub sample: Vec<EmbeddingMigrationSample>,
45    pub provider_available: bool,
46    pub provider_model: Option<String>,
47}
48
/// Decides which filtered candidates a migration run re-embeds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbeddingMigrationMode {
    /// Only candidates with no embedding, or with an empty embedding vector.
    MissingOnly,
    /// Every candidate, replacing any embedding it already has.
    ReindexAll,
}
54
55#[derive(Debug, Clone)]
56pub struct EmbeddingMigrationRunRequest {
57    pub filter: EmbeddingMigrationFilter,
58    pub mode: EmbeddingMigrationMode,
59    pub dry_run: bool,
60    pub batch_size: usize,
61    pub max_nodes: usize,
62}
63
64#[derive(Debug, Clone)]
65pub struct EmbeddingMigrationRunResult {
66    pub scanned: usize,
67    pub selected: usize,
68    pub updated: usize,
69    pub skipped: usize,
70    pub failed: usize,
71    pub duplicate: usize,
72    pub started_at: DateTime<Utc>,
73    pub completed_at: DateTime<Utc>,
74    pub provider_model: Option<String>,
75    pub failure_reasons: Vec<String>,
76}
77
78pub struct EmbeddingMigrationService {
79    store: Arc<dyn NodeStore>,
80    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
81}
82
83impl EmbeddingMigrationService {
84    pub fn new(
85        store: Arc<dyn NodeStore>,
86        embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
87    ) -> Self {
88        Self {
89            store,
90            embedding_provider,
91        }
92    }
93
94    pub async fn preview_async(
95        &self,
96        request: EmbeddingMigrationPreviewRequest,
97    ) -> Result<EmbeddingMigrationPreviewResult> {
98        let max_nodes = request.max_nodes.clamp(1, 50_000);
99        let sample_limit = request.sample_limit.clamp(1, 200);
100        let candidates = self.fetch_candidates(&request.filter, max_nodes).await?;
101
102        Ok(EmbeddingMigrationPreviewResult {
103            total_candidates: candidates.len(),
104            sample: candidates
105                .into_iter()
106                .take(sample_limit)
107                .map(to_sample)
108                .collect::<Vec<_>>(),
109            provider_available: self.embedding_provider.is_some(),
110            provider_model: self
111                .embedding_provider
112                .as_ref()
113                .map(|provider| provider.model_name().to_string()),
114        })
115    }
116
117    pub async fn run_async(
118        &self,
119        request: EmbeddingMigrationRunRequest,
120    ) -> Result<EmbeddingMigrationRunResult> {
121        let started_at = Utc::now();
122        let max_nodes = request.max_nodes.clamp(1, 50_000);
123        let batch_size = request.batch_size.clamp(1, 500);
124        let mut candidates = self.fetch_candidates(&request.filter, max_nodes).await?;
125        let scanned = candidates.len();
126
127        if request.mode == EmbeddingMigrationMode::MissingOnly {
128            candidates.retain(|node| !node_has_embedding(node));
129        }
130
131        let selected = candidates.len();
132
133        if !request.dry_run && self.embedding_provider.is_none() {
134            return Err(anyhow!(
135                "Embedding provider is not configured. Enable embeddings before running migration."
136            ));
137        }
138
139        let provider_model = self
140            .embedding_provider
141            .as_ref()
142            .map(|provider| provider.model_name().to_string());
143
144        let mut result = EmbeddingMigrationRunResult {
145            scanned,
146            selected,
147            updated: 0,
148            skipped: 0,
149            failed: 0,
150            duplicate: 0,
151            started_at,
152            completed_at: started_at,
153            provider_model,
154            failure_reasons: Vec::new(),
155        };
156
157        if request.dry_run {
158            result.updated = selected;
159            result.completed_at = Utc::now();
160            return Ok(result);
161        }
162
163        let provider = match self.embedding_provider.as_ref() {
164            Some(provider) => provider,
165            None => {
166                return Err(anyhow!(
167                    "Embedding provider is not configured. Enable embeddings before running migration."
168                ));
169            }
170        };
171
172        for batch in candidates.chunks(batch_size) {
173            for mut node in batch.iter().cloned() {
174                let Some(embedding_input) =
175                    build_embedding_input(node.context_summary.as_deref(), &node.session_id)
176                else {
177                    result.skipped += 1;
178                    continue;
179                };
180
181                let embedding = match provider.embed_async(&embedding_input).await {
182                    Ok(values) if !values.is_empty() => values,
183                    Ok(_) => {
184                        result.failed += 1;
185                        push_failure_reason(
186                            &mut result.failure_reasons,
187                            format!(
188                                "{}: embedding provider returned an empty vector",
189                                node.sync_key
190                            ),
191                        );
192                        continue;
193                    }
194                    Err(err) => {
195                        result.failed += 1;
196                        push_failure_reason(
197                            &mut result.failure_reasons,
198                            format!("{}: embedding failed: {err}", node.sync_key),
199                        );
200                        continue;
201                    }
202                };
203
204                node.embedding_dimensions = Some(embedding.len());
205                node.embedding_model = Some(provider.model_name().to_string());
206                node.embedding = Some(embedding);
207                node.embedded_at = Some(Utc::now());
208                node.updated_at = Utc::now();
209
210                match self.store.upsert_node_async(node).await {
211                    Ok(upsert) => match upsert.status {
212                        NodeUpsertStatus::Created | NodeUpsertStatus::Updated => {
213                            result.updated += 1;
214                        }
215                        NodeUpsertStatus::Duplicate => {
216                            result.duplicate += 1;
217                        }
218                        NodeUpsertStatus::Skipped => {
219                            result.skipped += 1;
220                        }
221                    },
222                    Err(err) => {
223                        result.failed += 1;
224                        push_failure_reason(
225                            &mut result.failure_reasons,
226                            format!("store upsert failed: {err}"),
227                        );
228                    }
229                }
230            }
231        }
232
233        result.completed_at = Utc::now();
234        Ok(result)
235    }
236
237    async fn fetch_candidates(
238        &self,
239        filter: &EmbeddingMigrationFilter,
240        max_nodes: usize,
241    ) -> Result<Vec<SttpNode>> {
242        let tiers = normalize_tiers(filter.tiers.as_deref());
243        let model_filter = normalize_model_filter(filter.embedding_model.as_deref());
244        let sync_key_filter = normalize_sync_keys(filter.sync_keys.as_deref());
245
246        let nodes = self
247            .store
248            .query_nodes_async(NodeQuery {
249                limit: max_nodes,
250                session_id: filter.session_id.clone(),
251                from_utc: filter.from_utc,
252                to_utc: filter.to_utc,
253                tiers,
254            })
255            .await?;
256
257        let filtered = nodes
258            .into_iter()
259            .filter(|node| match filter.has_embedding {
260                Some(expected) => node_has_embedding(node) == expected,
261                None => true,
262            })
263            .filter(|node| match model_filter.as_deref() {
264                Some(expected) => node
265                    .embedding_model
266                    .as_ref()
267                    .map(|model| model.eq_ignore_ascii_case(expected))
268                    .unwrap_or(false),
269                None => true,
270            })
271            .filter(|node| match sync_key_filter.as_ref() {
272                Some(sync_keys) => sync_keys.contains(&node.sync_key),
273                None => true,
274            })
275            .collect::<Vec<_>>();
276
277        Ok(filtered)
278    }
279}
280
281fn to_sample(node: SttpNode) -> EmbeddingMigrationSample {
282    let has_embedding = node_has_embedding(&node);
283
284    EmbeddingMigrationSample {
285        sync_key: node.sync_key,
286        session_id: node.session_id,
287        tier: node.tier,
288        has_embedding,
289        embedding_model: node.embedding_model,
290        embedding_dimensions: node.embedding_dimensions,
291        embedded_at: node.embedded_at,
292        updated_at: node.updated_at,
293        context_summary: node.context_summary,
294    }
295}
296
/// Trims and lowercases tier names, dropping blank entries.
///
/// Returns `None` when no usable tier remains, meaning "do not filter by tier".
/// Order and duplicates are preserved.
fn normalize_tiers(tiers: Option<&[String]>) -> Option<Vec<String>> {
    let mut cleaned = Vec::new();
    for raw in tiers.into_iter().flatten() {
        let trimmed = raw.trim();
        if !trimmed.is_empty() {
            cleaned.push(trimmed.to_ascii_lowercase());
        }
    }
    (!cleaned.is_empty()).then_some(cleaned)
}
311
/// Trims and lowercases a model-name filter; blank or missing input yields `None`.
fn normalize_model_filter(value: Option<&str>) -> Option<String> {
    let trimmed = value?.trim();
    if trimmed.is_empty() {
        return None;
    }
    Some(trimmed.to_ascii_lowercase())
}
317
/// Builds a set of trimmed, non-blank sync keys for fast membership checks.
///
/// Returns `None` when no usable key remains, meaning "do not filter by sync key".
/// Keys keep their original case.
fn normalize_sync_keys(values: Option<&[String]>) -> Option<HashSet<String>> {
    let mut keys = HashSet::new();
    for raw in values.unwrap_or_default() {
        let key = raw.trim();
        if !key.is_empty() {
            keys.insert(key.to_string());
        }
    }

    if keys.is_empty() {
        return None;
    }
    Some(keys)
}
332
333fn node_has_embedding(node: &SttpNode) -> bool {
334    node.embedding
335        .as_ref()
336        .map(|values| !values.is_empty())
337        .unwrap_or(false)
338}
339
/// Builds the text sent to the embedding provider for a node.
///
/// Both inputs are trimmed. A non-blank summary comes first; a non-blank session id
/// is appended as a `session_id:` line. Returns `None` when both are blank.
fn build_embedding_input(context_summary: Option<&str>, session_id: &str) -> Option<String> {
    let session = session_id.trim();
    let summary = context_summary.map(str::trim).unwrap_or("");

    match (summary.is_empty(), session.is_empty()) {
        (true, true) => None,
        (true, false) => Some(format!("session_id:{session}")),
        (false, true) => Some(summary.to_string()),
        (false, false) => Some(format!("{summary}\nsession_id:{session}")),
    }
}
357
/// Appends `reason` unless 100 reasons have already been collected.
///
/// The cap keeps a run with many failures from producing an unbounded report.
/// Reasons past the cap are dropped silently.
fn push_failure_reason(reasons: &mut Vec<String>, reason: String) {
    const MAX_REASONS: usize = 100;

    if reasons.len() >= MAX_REASONS {
        return;
    }
    reasons.push(reason);
}