From a9a5ee4cb66345de2e4e84315beaa00fd61573bb Mon Sep 17 00:00:00 2001 From: LD-Reborn Date: Fri, 16 Jan 2026 12:52:15 +0100 Subject: [PATCH] Added embeddings prefetching for entities ingest --- src/Server/AIProvider.cs | 11 +++-- src/Server/Datapoint.cs | 62 ++++++++++++++++++++++++- src/Server/Helper/SearchdomainHelper.cs | 6 +++ 3 files changed, 75 insertions(+), 4 deletions(-) diff --git a/src/Server/AIProvider.cs b/src/Server/AIProvider.cs index a166a85..0f83731 100644 --- a/src/Server/AIProvider.cs +++ b/src/Server/AIProvider.cs @@ -31,7 +31,12 @@ public class AIProvider } } - public float[] GenerateEmbeddings(string modelUri, string[] input) + public float[] GenerateEmbeddings(string modelUri, string input) + { + return [.. GenerateEmbeddings(modelUri, [input]).First()]; + } + + public IEnumerable GenerateEmbeddings(string modelUri, string[] input) { Uri uri = new(modelUri); string provider = uri.Scheme; @@ -103,13 +108,13 @@ public class AIProvider try { JObject responseContentJson = JObject.Parse(responseContent); - JToken? responseContentTokens = responseContentJson.SelectToken(embeddingsJsonPath); + List? responseContentTokens = [.. responseContentJson.SelectTokens(embeddingsJsonPath)]; if (responseContentTokens is null) { _logger.LogError("Unable to select tokens using JSONPath {embeddingsJsonPath} for string: {responseContent}.", [embeddingsJsonPath, responseContent]); throw new JSONPathSelectionException(embeddingsJsonPath, responseContent); } - return [.. responseContentTokens.Values()]; + return [.. responseContentTokens.Select(token => token.ToObject() ?? throw new Exception("Unable to cast embeddings response to float[]"))]; } catch (Exception ex) { diff --git a/src/Server/Datapoint.cs b/src/Server/Datapoint.cs index 6325a96..ee17c36 100644 --- a/src/Server/Datapoint.cs +++ b/src/Server/Datapoint.cs @@ -52,9 +52,69 @@ public class Datapoint return embeddings; } + public static Dictionary> GetEmbeddings(string[] content, List models, AIProvider aIProvider, EnumerableLruCache> embeddingCache) + { + Dictionary> embeddings = []; + foreach (string model in models) + { + List toBeGenerated = []; + embeddings[model] = []; + foreach (string value in content) + { + bool generateThisEntry = true; + bool embeddingCacheHasContent = embeddingCache.TryGetValue(value, out var embeddingCacheForContent); + if (embeddingCacheHasContent && embeddingCacheForContent is not null) + { + bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embedding); + if (embeddingCacheHasModel && embedding is not null) + { + embeddings[model][value] = embedding; + generateThisEntry = false; + } + } + if (generateThisEntry) + { + if (!toBeGenerated.Contains(value)) + { + toBeGenerated.Add(value); + } + } + } + IEnumerable generatedEmbeddings = GenerateEmbeddings([.. toBeGenerated], model, aIProvider, embeddingCache); + if (generatedEmbeddings.Count() != toBeGenerated.Count) + { + throw new Exception("Requested embeddings count and generated embeddings count mismatched!"); + } + for (int i = 0; i < toBeGenerated.Count; i++) + { + embeddings[model][toBeGenerated.ElementAt(i)] = generatedEmbeddings.ElementAt(i); + } + } + return embeddings; + } + + public static IEnumerable GenerateEmbeddings(string[] content, string model, AIProvider aIProvider, EnumerableLruCache> embeddingCache) + { + IEnumerable embeddings = aIProvider.GenerateEmbeddings(model, content); + if (embeddings.Count() != content.Length) + { + throw new Exception("Resulting embeddings count does not match up with request count"); + } + for (int i = 0; i < content.Length; i++) + { + if (!embeddingCache.ContainsKey(content[i])) + { + embeddingCache[content[i]] = []; + } + embeddingCache[content[i]][model] = embeddings.ElementAt(i); + } + return embeddings; + } + + public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache> embeddingCache) { - float[] embeddings = aIProvider.GenerateEmbeddings(model, [content]); + float[] embeddings = aIProvider.GenerateEmbeddings(model, content); if (!embeddingCache.ContainsKey(content)) { embeddingCache[content] = []; diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 4aabd6d..7e62bda 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -74,6 +74,12 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp } } } + foreach (var toBeCachedKV in toBeCached) + { + string model = toBeCachedKV.Key; + List uniqueStrings = [.. toBeCachedKV.Value.Distinct()]; + Datapoint.GetEmbeddings([.. uniqueStrings], [model], aIProvider, embeddingCache); + } ConcurrentQueue retVal = []; ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity =>