Skip to main content

locus_sdk/infrastructure/
embeddings.rs

1#[cfg(feature = "local-embedding")]
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use async_trait::async_trait;
6use locus_core_rs::domain::contracts::EmbeddingProvider;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Serialize)]
10struct OllamaEmbeddingRequest<'a> {
11    model: &'a str,
12    prompt: &'a str,
13}
14
15#[derive(Debug, Deserialize)]
16struct OllamaEmbeddingResponse {
17    embedding: Option<Vec<f32>>,
18}
19
20#[derive(Clone)]
21pub struct OllamaEmbeddingProvider {
22    client: reqwest::Client,
23    endpoint: String,
24    model: String,
25}
26
27impl OllamaEmbeddingProvider {
28    pub fn new(endpoint: String, model: String) -> Self {
29        Self {
30            client: reqwest::Client::new(),
31            endpoint,
32            model,
33        }
34    }
35}
36
37#[async_trait]
38impl EmbeddingProvider for OllamaEmbeddingProvider {
39    fn model_name(&self) -> &str {
40        &self.model
41    }
42
43    async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
44        let response = self
45            .client
46            .post(&self.endpoint)
47            .json(&OllamaEmbeddingRequest {
48                model: &self.model,
49                prompt: text,
50            })
51            .send()
52            .await?
53            .error_for_status()?;
54
55        let body: OllamaEmbeddingResponse = response.json().await?;
56        match body.embedding {
57            Some(embedding) if !embedding.is_empty() => Ok(embedding),
58            _ => Err(anyhow!("embedding response missing vector")),
59        }
60    }
61}
62
/// [`EmbeddingProvider`] that runs a BERT model in-process through candle.
#[cfg(feature = "local-embedding")]
pub struct LocalEmbeddingProvider {
    /// Name reported by `model_name`, of the form `local-<normalized name>`.
    model_name: String,
    /// Loaded model and tokenizer, shared with blocking worker tasks; the mutex lets only one
    /// worker use the runtime at a time.
    runtime: Arc<std::sync::Mutex<CandleRuntime>>,
}
68
#[cfg(feature = "local-embedding")]
impl LocalEmbeddingProvider {
    /// Loads the candle runtime for the HuggingFace repository `repo_id` and wraps it in a
    /// provider.
    ///
    /// The reported model name is `model_name` trimmed, lower-cased and prefixed with `local-`
    /// (for example `" MiniLM "` becomes `"local-minilm"`).
    ///
    /// # Errors
    ///
    /// Propagates any failure from fetching or loading the model files.
    pub fn new(model_name: String, repo_id: String) -> Result<Self> {
        let shared_runtime = Arc::new(std::sync::Mutex::new(CandleRuntime::new(&repo_id)?));
        let normalized_name = model_name.trim().to_lowercase();

        Ok(Self {
            runtime: shared_runtime,
            model_name: format!("local-{normalized_name}"),
        })
    }
}
80
#[cfg(feature = "local-embedding")]
#[async_trait]
impl EmbeddingProvider for LocalEmbeddingProvider {
    /// Returns the normalized `local-...` name built in the constructor.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Embeds `text` on tokio's blocking thread pool, so the CPU-bound forward pass does not
    /// stall the async executor.
    ///
    /// # Errors
    ///
    /// Fails if the runtime mutex is poisoned, if the model fails to embed the text, or if the
    /// worker task cannot be joined.
    async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        use anyhow::Context;

        // The closure must be `'static`, so give it its own handle and its own copy of the text.
        let shared = Arc::clone(&self.runtime);
        let owned_text = text.to_owned();

        let job = tokio::task::spawn_blocking(move || -> Result<Vec<f32>> {
            let guard = match shared.lock() {
                Ok(guard) => guard,
                Err(_) => return Err(anyhow!("Local embedding runtime lock poisoned")),
            };
            guard.embed(&owned_text)
        });

        // Outer `?` covers the join failure; the inner `Result` is the embedding outcome.
        job.await.context("embedding worker join failure")?
    }
}
104
/// In-process BERT inference state: encoder weights, tokenizer, and the device used for tensors.
#[cfg(feature = "local-embedding")]
struct CandleRuntime {
    /// BERT encoder loaded from the repository's `model.safetensors`.
    model: candle_transformers::models::bert::BertModel,
    /// Tokenizer loaded from the repository's `tokenizer.json`.
    tokenizer: tokenizers::Tokenizer,
    /// Device that input tensors are created on.
    device: candle_core::Device,
}
111
#[cfg(feature = "local-embedding")]
impl CandleRuntime {
    /// Sequence-length cap used when `config.json` has no usable `max_position_embeddings`.
    /// 512 is the position-table size of standard BERT checkpoints.
    const DEFAULT_MAX_SEQUENCE_LENGTH: usize = 512;

    /// Fetches `config.json`, `tokenizer.json` and `model.safetensors` for `repo_id` through the
    /// HuggingFace hub client, then loads a BERT model on the CPU.
    ///
    /// The tokenizer pads each batch to its longest member. It also truncates inputs so they
    /// never exceed the model's `max_position_embeddings`.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - the hub client cannot be built;
    /// - any of the three files cannot be fetched or read;
    /// - the config does not parse;
    /// - the tokenizer cannot be loaded or configured;
    /// - the weights fail to map or load.
    fn new(repo_id: &str) -> Result<Self> {
        use anyhow::Context;
        use candle_core::{DType, Device};
        use candle_nn::VarBuilder;
        use candle_transformers::models::bert::{BertModel, Config};
        use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
        use tokenizers::PaddingParams;

        let device = Device::Cpu;

        let api = ApiBuilder::new()
            .with_endpoint("https://huggingface.co".to_string())
            .build()
            .context("failed to create HuggingFace API client")?;
        let repo = api.repo(Repo::new(repo_id.to_string(), RepoType::Model));

        let config_path = repo
            .get("config.json")
            .with_context(|| format!("failed to fetch config.json from {repo_id}"))?;
        let tokenizer_path = repo
            .get("tokenizer.json")
            .with_context(|| format!("failed to fetch tokenizer.json from {repo_id}"))?;
        let weights_path = repo
            .get("model.safetensors")
            .with_context(|| format!("failed to fetch model.safetensors from {repo_id}"))?;

        // Keep the raw text: it is parsed once as the typed `Config` and once more to read the
        // position-embedding limit.
        let config_text = std::fs::read_to_string(&config_path)
            .with_context(|| format!("failed to read {}", config_path.display()))?;
        let config: Config = serde_json::from_str(&config_text)
            .with_context(|| format!("failed to parse {}", config_path.display()))?;

        let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|err| anyhow!("tokenizer error: {err}"))?;
        tokenizer.with_padding(Some(PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        }));
        Self::cap_truncation(&mut tokenizer, Self::max_sequence_length(&config_text))?;

        // SAFETY: the weights are memory-mapped, so the file must not be modified or truncated
        // while the model is alive. The path comes from the hub client (presumably its local
        // cache), and this code only ever reads it.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
                .context("failed to map safetensors weights")?
        };
        let model = BertModel::load(vb, &config).context("failed to load BERT model")?;

        Ok(Self {
            model,
            tokenizer,
            device,
        })
    }

    /// Reads `max_position_embeddings` from the raw `config.json` text.
    ///
    /// Falls back to [`Self::DEFAULT_MAX_SEQUENCE_LENGTH`] when the key is missing, zero, or
    /// not an unsigned integer.
    fn max_sequence_length(config_json: &str) -> usize {
        serde_json::from_str::<serde_json::Value>(config_json)
            .ok()
            .and_then(|value| value.get("max_position_embeddings")?.as_u64())
            .and_then(|limit| usize::try_from(limit).ok())
            .filter(|&limit| limit > 0)
            .unwrap_or(Self::DEFAULT_MAX_SEQUENCE_LENGTH)
    }

    /// Ensures `tokenizer` truncates its inputs to at most `max_length` tokens.
    ///
    /// A truncation limit already set in `tokenizer.json` is kept if it is no larger than
    /// `max_length`, so checkpoints that ship a shorter limit produce the same vectors as
    /// before. Otherwise the limit is added, or lowered to `max_length`. Without a cap,
    /// sequences longer than the position-embedding table make the forward pass fail.
    fn cap_truncation(tokenizer: &mut tokenizers::Tokenizer, max_length: usize) -> Result<()> {
        let existing = tokenizer.get_truncation().cloned();
        if matches!(&existing, Some(params) if params.max_length <= max_length) {
            return Ok(());
        }

        let params = tokenizers::TruncationParams {
            max_length,
            ..existing.unwrap_or_default()
        };
        tokenizer
            .with_truncation(Some(params))
            .map_err(|err| anyhow!("failed to configure tokenizer truncation: {err}"))?;
        Ok(())
    }

    /// Embeds a single text; see [`Self::embed_batch`] for the pooling and normalization.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("empty embedding output"))
    }

    /// Computes one embedding per input text.
    ///
    /// The texts are tokenized as a single padded batch and run through the BERT encoder. Token
    /// states are then mean-pooled over the attention mask, and each pooled vector is scaled to
    /// unit L2 norm. Results are returned in input order. An empty slice returns an empty `Vec`
    /// without running the model.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        use anyhow::Context;
        use candle_core::{DType, Tensor};

        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|err| anyhow!("tokenization failed: {err}"))?;

        // BatchLongest padding gives every encoding the same length, so the first one is
        // representative.
        let seq_len = encodings[0].get_ids().len();
        let batch_size = texts.len();

        let input_ids: Vec<u32> = encodings.iter().flat_map(|e| e.get_ids().to_vec()).collect();
        let attention_mask: Vec<u32> = encodings
            .iter()
            .flat_map(|e| e.get_attention_mask().to_vec())
            .collect();
        // Single-segment inputs: every token belongs to segment 0.
        let token_type_ids: Vec<u32> = vec![0u32; batch_size * seq_len];

        let input_ids = Tensor::from_vec(input_ids, (batch_size, seq_len), &self.device)?;
        let attention_mask = Tensor::from_vec(attention_mask, (batch_size, seq_len), &self.device)?;
        let token_type_ids = Tensor::from_vec(token_type_ids, (batch_size, seq_len), &self.device)?;

        // Token-level hidden states; presumably shaped (batch, seq_len, hidden).
        let output = self
            .model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask))
            .context("local embedding forward pass failed")?;

        // Mean pooling: zero out padding positions, then divide by the number of real tokens.
        let mask_f32 = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
        let masked = output.broadcast_mul(&mask_f32)?;
        let summed = masked.sum(1)?;
        let counts = mask_f32.sum(1)?;
        let pooled = summed.broadcast_div(&counts)?;

        // L2-normalize each row.
        let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?;
        let normalized = pooled.broadcast_div(&norm)?;

        Ok(normalized.to_vec2::<f32>()?)
    }
}