From a01985d1b8acd984083f026d958b222313f98e16 Mon Sep 17 00:00:00 2001 From: LD-Reborn Date: Fri, 16 Jan 2026 10:35:46 +0100 Subject: [PATCH] Moved embeddingCache to EnumerableLruCache, fixed GenerateEmbeddings not feeding embeddingCache --- src/Server/Controllers/ServerController.cs | 15 ++---- src/Server/Datapoint.cs | 54 ++++++++++++---------- src/Server/Helper/SearchdomainHelper.cs | 9 ++-- src/Server/Searchdomain.cs | 12 ++--- src/Server/SearchdomainManager.cs | 3 +- 5 files changed, 47 insertions(+), 46 deletions(-) diff --git a/src/Server/Controllers/ServerController.cs b/src/Server/Controllers/ServerController.cs index 494dae3..6813148 100644 --- a/src/Server/Controllers/ServerController.cs +++ b/src/Server/Controllers/ServerController.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Options; using Server.Exceptions; using Server.Helper; using Server.Models; +using Shared; using Shared.Models; [ApiController] @@ -61,18 +62,12 @@ public class ServerController : ControllerBase long size = 0; long elementCount = 0; long embeddingsCount = 0; - LRUCache> embeddingCache = _searchdomainManager.embeddingCache; - var cacheListField = embeddingCache.GetType() - .GetField("_cacheList", BindingFlags.Instance | BindingFlags.NonPublic) ?? throw new InvalidOperationException("_cacheList field not found"); // TODO Remove this unsafe reflection atrocity - LinkedList cacheListOriginal = (LinkedList)cacheListField.GetValue(embeddingCache)!; - LinkedList cacheList = new(cacheListOriginal); + EnumerableLruCache> embeddingCache = _searchdomainManager.embeddingCache; - foreach (string key in cacheList) + foreach (KeyValuePair> kv in embeddingCache) { - if (!embeddingCache.TryGet(key, out var entry)) - continue; - - // estimate size + string key = kv.Key; + Dictionary entry = kv.Value; size += EstimateEntrySize(key, entry); elementCount++; embeddingsCount += entry.Keys.Count; diff --git a/src/Server/Datapoint.cs b/src/Server/Datapoint.cs index 42d59a7..6325a96 100644 --- a/src/Server/Datapoint.cs +++ b/src/Server/Datapoint.cs @@ -1,6 +1,7 @@ using AdaptiveExpressions; using OllamaSharp; using OllamaSharp.Models; +using Shared; namespace Server; @@ -26,36 +27,39 @@ public class Datapoint return probMethod.method(probabilities); } - public static Dictionary GenerateEmbeddings(string content, List models, AIProvider aIProvider) + public static Dictionary GetEmbeddings(string content, List models, AIProvider aIProvider, EnumerableLruCache> embeddingCache) { - return GenerateEmbeddings(content, models, aIProvider, new()); + Dictionary embeddings = []; + bool embeddingCacheHasContent = embeddingCache.TryGetValue(content, out var embeddingCacheForContent); + if (!embeddingCacheHasContent || embeddingCacheForContent is null) + { + models.ForEach(model => + embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache) + ); + return embeddings; + } + models.ForEach(model => + { + bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embeddingCacheForModel); + if (embeddingCacheHasModel && embeddingCacheForModel is not null) + { + embeddings[model] = embeddingCacheForModel; + } else + { + embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache); + } + }); + return embeddings; } - public static Dictionary GenerateEmbeddings(string content, List models, AIProvider aIProvider, LRUCache> embeddingCache) + public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache> embeddingCache) { - Dictionary retVal = []; - foreach (string model in models) + float[] embeddings = aIProvider.GenerateEmbeddings(model, [content]); + if (!embeddingCache.ContainsKey(content)) { - bool embeddingCacheHasModel = embeddingCache.TryGet(model, out var embeddingCacheForModel); - if (embeddingCacheHasModel && embeddingCacheForModel.ContainsKey(content)) - { - retVal[model] = embeddingCacheForModel[content]; - continue; - } - var response = aIProvider.GenerateEmbeddings(model, [content]); - if (response is not null) - { - retVal[model] = response; - if (!embeddingCacheHasModel) - { - embeddingCacheForModel = []; - } - if (!embeddingCacheForModel.ContainsKey(content)) - { - embeddingCacheForModel[content] = response; - } - } + embeddingCache[content] = []; } - return retVal; + embeddingCache[content][model] = embeddings; + return embeddings; } } \ No newline at end of file diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 9be053b..4aabd6d 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -4,6 +4,7 @@ using System.Text; using System.Text.Json; using AdaptiveExpressions; using Server.Exceptions; +using Shared; using Shared.Models; namespace Server.Helper; @@ -47,7 +48,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp public List? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json) { - LRUCache> embeddingCache = searchdomainManager.embeddingCache; + EnumerableLruCache> embeddingCache = searchdomainManager.embeddingCache; AIProvider aIProvider = searchdomainManager.aIProvider; SQLHelper helper = searchdomainManager.helper; @@ -92,7 +93,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain); List entityCache = searchdomain.entityCache; AIProvider aIProvider = searchdomain.aIProvider; - LRUCache> embeddingCache = searchdomain.embeddingCache; + EnumerableLruCache> embeddingCache = searchdomain.embeddingCache; Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); bool invalidateSearchCache = false; @@ -274,10 +275,10 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp throw new Exception("jsonDatapoint.Text must not be null at this point"); } using SQLHelper helper = searchdomain.helper.DuplicateConnection(); - LRUCache> embeddingCache = searchdomain.embeddingCache; + EnumerableLruCache> embeddingCache = searchdomain.embeddingCache; hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); - Dictionary embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); + Dictionary embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding); var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod); return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index e856947..70dd895 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -21,12 +21,12 @@ public class Searchdomain public EnumerableLruCache queryCache; // Key: query, Value: Search results for that query (with timestamp) public List entityCache; public List modelsInUse; - public LRUCache> embeddingCache; + public EnumerableLruCache> embeddingCache; private readonly MySqlConnection connection; public SQLHelper helper; private readonly ILogger _logger; - public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, LRUCache> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) + public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, EnumerableLruCache> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) { _connectionString = connectionString; _provider = provider.ToLower(); @@ -194,12 +194,12 @@ public class Searchdomain public Dictionary GetQueryEmbeddings(string query) { - bool hasQuery = embeddingCache.TryGet(query, out Dictionary queryEmbeddings); + bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary? queryEmbeddings); bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model)); if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null) { - queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache); - if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery)) + queryEmbeddings = Datapoint.GetEmbeddings(query, modelsInUse, aIProvider, embeddingCache); + if (!embeddingCache.TryGetValue(query, out var embeddingCacheForCurrentQuery)) { embeddingCache.Set(query, queryEmbeddings); } @@ -207,7 +207,7 @@ public class Searchdomain { foreach (KeyValuePair kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding { - if (!embeddingCache.TryGet(kvp.Key, out var _)) + if (!embeddingCache.TryGetValue(kvp.Key, out var _)) { embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value; } diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index 9cc5019..0ca6d0a 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -8,6 +8,7 @@ using Shared.Models; using System.Text.Json; using Microsoft.Extensions.Options; using Server.Models; +using Shared; namespace Server; @@ -21,7 +22,7 @@ public class SearchdomainManager private readonly string connectionString; private MySqlConnection connection; public SQLHelper helper; - public LRUCache> embeddingCache; + public EnumerableLruCache> embeddingCache; public long EmbeddingCacheMaxCount; public SearchdomainManager(ILogger logger, IOptions options, AIProvider aIProvider, DatabaseHelper databaseHelper)