locus_sdk/infrastructure/
embeddings.rs1#[cfg(feature = "local-embedding")]
2use std::sync::Arc;
3
4use anyhow::{Result, anyhow};
5use async_trait::async_trait;
6use locus_core_rs::domain::contracts::EmbeddingProvider;
7use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Serialize)]
10struct OllamaEmbeddingRequest<'a> {
11 model: &'a str,
12 prompt: &'a str,
13}
14
15#[derive(Debug, Deserialize)]
16struct OllamaEmbeddingResponse {
17 embedding: Option<Vec<f32>>,
18}
19
20#[derive(Clone)]
21pub struct OllamaEmbeddingProvider {
22 client: reqwest::Client,
23 endpoint: String,
24 model: String,
25}
26
27impl OllamaEmbeddingProvider {
28 pub fn new(endpoint: String, model: String) -> Self {
29 Self {
30 client: reqwest::Client::new(),
31 endpoint,
32 model,
33 }
34 }
35}
36
37#[async_trait]
38impl EmbeddingProvider for OllamaEmbeddingProvider {
39 fn model_name(&self) -> &str {
40 &self.model
41 }
42
43 async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
44 let response = self
45 .client
46 .post(&self.endpoint)
47 .json(&OllamaEmbeddingRequest {
48 model: &self.model,
49 prompt: text,
50 })
51 .send()
52 .await?
53 .error_for_status()?;
54
55 let body: OllamaEmbeddingResponse = response.json().await?;
56 match body.embedding {
57 Some(embedding) if !embedding.is_empty() => Ok(embedding),
58 _ => Err(anyhow!("embedding response missing vector")),
59 }
60 }
61}
62
/// Embedding provider that computes vectors in-process with a Candle BERT model.
///
/// Only compiled when the `local-embedding` feature is enabled.
#[cfg(feature = "local-embedding")]
pub struct LocalEmbeddingProvider {
    /// Name reported by `model_name`.
    model_name: String,
    /// Shared model runtime; the mutex is taken for the duration of each embedding call.
    runtime: Arc<std::sync::Mutex<CandleRuntime>>,
}
68
#[cfg(feature = "local-embedding")]
impl LocalEmbeddingProvider {
    /// Builds a local provider backed by the model stored in `repo_id`.
    ///
    /// The reported model name is `local-` followed by `model_name`, trimmed
    /// and lowercased.
    ///
    /// # Errors
    ///
    /// Propagates any failure from constructing the underlying `CandleRuntime`.
    pub fn new(model_name: String, repo_id: String) -> Result<Self> {
        // Load the model first so a failure aborts before anything else is built.
        let loaded = CandleRuntime::new(&repo_id)?;
        let runtime = Arc::new(std::sync::Mutex::new(loaded));

        let normalized = model_name.trim().to_lowercase();
        let model_name = format!("local-{normalized}");

        Ok(Self {
            model_name,
            runtime,
        })
    }
}
80
#[cfg(feature = "local-embedding")]
#[async_trait]
impl EmbeddingProvider for LocalEmbeddingProvider {
    /// Returns the `local-…` model name assigned at construction.
    fn model_name(&self) -> &str {
        self.model_name.as_str()
    }

    /// Embeds `text` with the local model on a blocking worker thread.
    ///
    /// The work is handed to `tokio::task::spawn_blocking`, so the text is
    /// copied and the runtime handle cloned to satisfy the `'static` closure.
    ///
    /// # Errors
    ///
    /// Fails when the runtime mutex is poisoned, when the worker task cannot
    /// be joined, or when the runtime itself fails to embed the text.
    async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
        use anyhow::Context;

        let shared_runtime = Arc::clone(&self.runtime);
        let owned_text = text.to_owned();

        let worker = tokio::task::spawn_blocking(move || -> Result<Vec<f32>> {
            let guard = shared_runtime
                .lock()
                .map_err(|_| anyhow!("Local embedding runtime lock poisoned"))?;
            guard.embed(&owned_text)
        });

        // The outer `?` handles a join failure; the inner result is returned as is.
        let embedding_result = worker.await.context("embedding worker join failure")?;
        embedding_result
    }
}
104
/// Loaded model state used to compute embeddings locally.
#[cfg(feature = "local-embedding")]
struct CandleRuntime {
    /// BERT model used for the forward pass.
    model: candle_transformers::models::bert::BertModel,
    /// Tokenizer that turns input text into token ids and attention masks.
    tokenizer: tokenizers::Tokenizer,
    /// Device on which the input tensors are created.
    device: candle_core::Device,
}
111
#[cfg(feature = "local-embedding")]
impl CandleRuntime {
    /// Loads a BERT model and its tokenizer from the Hugging Face repo `repo_id`.
    ///
    /// Fetches `config.json`, `tokenizer.json` and `model.safetensors` through
    /// the hf-hub sync API, configures the tokenizer to pad each batch to its
    /// longest sequence, and loads the weights as `f32` on the CPU.
    ///
    /// # Errors
    ///
    /// Fails when the API client cannot be created, when any of the three
    /// files cannot be fetched, read or parsed, or when the model cannot be
    /// built from the weights.
    fn new(repo_id: &str) -> Result<Self> {
        use anyhow::Context;
        use candle_core::{DType, Device};
        use candle_nn::VarBuilder;
        use candle_transformers::models::bert::{BertModel, Config};
        use hf_hub::{Repo, RepoType, api::sync::ApiBuilder};
        use tokenizers::PaddingParams;

        let device = Device::Cpu;

        let api = ApiBuilder::new()
            .with_endpoint("https://huggingface.co".to_string())
            .build()
            .context("failed to create HuggingFace API client")?;
        let repo = api.repo(Repo::new(repo_id.to_string(), RepoType::Model));

        let config_path = repo
            .get("config.json")
            .with_context(|| format!("failed to fetch config.json from {repo_id}"))?;
        let tokenizer_path = repo
            .get("tokenizer.json")
            .with_context(|| format!("failed to fetch tokenizer.json from {repo_id}"))?;
        let weights_path = repo
            .get("model.safetensors")
            .with_context(|| format!("failed to fetch model.safetensors from {repo_id}"))?;

        let config: Config = serde_json::from_str(
            &std::fs::read_to_string(&config_path)
                .with_context(|| format!("failed to read {}", config_path.display()))?,
        )
        .with_context(|| format!("failed to parse {}", config_path.display()))?;

        let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|err| anyhow!("tokenizer error: {err}"))?;
        // Pad every sequence in a batch to the longest one so the batch can be
        // packed into a single rectangular tensor.
        tokenizer.with_padding(Some(PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        }));

        // SAFETY: the weights are memory-mapped, which is only sound while the
        // underlying file is not modified or truncated for the lifetime of the
        // mapping. This code never writes to the file it just fetched.
        // NOTE(review): confirm nothing else rewrites the hf-hub cache while
        // the process is running.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
                .context("failed to map safetensors weights")?
        };
        let model = BertModel::load(vb, &config).context("failed to load BERT model")?;

        Ok(Self {
            model,
            tokenizer,
            device,
        })
    }

    /// Embeds a single text and returns its vector.
    ///
    /// # Errors
    ///
    /// Propagates any failure from `embed_batch`, and fails with
    /// "empty embedding output" if the batch result contains no vector.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let embeddings = self.embed_batch(&[text])?;
        embeddings
            .into_iter()
            .next()
            .ok_or_else(|| anyhow!("empty embedding output"))
    }

    /// Embeds every text in `texts`, returning one L2-normalized vector per input.
    ///
    /// The model output is mean-pooled over the sequence dimension using the
    /// attention mask as weights, then divided by its L2 norm. An empty
    /// `texts` slice yields an empty result.
    ///
    /// # Errors
    ///
    /// Fails when tokenization fails, when an input tokenizes to zero active
    /// positions (pooling would otherwise divide by zero and produce NaNs),
    /// or when any tensor operation or the forward pass fails.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        use anyhow::Context;
        use candle_core::{DType, Tensor};

        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let encodings = self
            .tokenizer
            .encode_batch(texts.to_vec(), true)
            .map_err(|err| anyhow!("tokenization failed: {err}"))?;

        // An encoding with no active attention position makes the pooling
        // denominator zero, which would silently yield a vector of NaNs.
        // Report it as an error instead of returning garbage.
        if let Some(index) = encodings
            .iter()
            .position(|enc| enc.get_attention_mask().iter().all(|&bit| bit == 0))
        {
            return Err(anyhow!(
                "tokenization produced no tokens for input at index {index}"
            ));
        }

        // Padding is BatchLongest, so every encoding shares the first one's length.
        let seq_len = encodings[0].get_ids().len();
        let batch_size = texts.len();

        let input_ids: Vec<u32> = encodings
            .iter()
            .flat_map(|enc| enc.get_ids().iter().copied())
            .collect();
        let attention_mask: Vec<u32> = encodings
            .iter()
            .flat_map(|enc| enc.get_attention_mask().iter().copied())
            .collect();
        // All inputs are treated as a single segment.
        let token_type_ids: Vec<u32> = vec![0u32; batch_size * seq_len];

        let input_ids = Tensor::from_vec(input_ids, (batch_size, seq_len), &self.device)?;
        let attention_mask =
            Tensor::from_vec(attention_mask, (batch_size, seq_len), &self.device)?;
        let token_type_ids =
            Tensor::from_vec(token_type_ids, (batch_size, seq_len), &self.device)?;

        let output = self
            .model
            .forward(&input_ids, &token_type_ids, Some(&attention_mask))
            .context("local embedding forward pass failed")?;

        // Mean pooling weighted by the attention mask so padding positions do
        // not contribute. Assumes `output` is (batch, seq_len, hidden).
        let mask_f32 = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
        let masked = output.broadcast_mul(&mask_f32)?;
        let summed = masked.sum(1)?;
        let counts = mask_f32.sum(1)?;
        let pooled = summed.broadcast_div(&counts)?;

        // Scale each pooled vector to unit L2 norm.
        let norm = pooled.sqr()?.sum_keepdim(1)?.sqrt()?;
        let normalized = pooled.broadcast_div(&norm)?;

        Ok(normalized.to_vec2::<f32>()?)
    }
}
217}