From 7b4a3bd2c8d0d62d2dce91e20b5a4620bc8ba102 Mon Sep 17 00:00:00 2001 From: LD-Reborn Date: Sat, 27 Dec 2025 18:40:03 +0100 Subject: [PATCH] Moved embeddingCache from Dictionary to LRUCache --- src/Server/Datapoint.cs | 18 ++++++++------- src/Server/Helper/SearchdomainHelper.cs | 7 +++--- src/Server/Searchdomain.cs | 29 +++++++++++-------------- src/Server/SearchdomainManager.cs | 7 ++++-- src/Server/Server.csproj | 1 + src/Server/appsettings.Development.json | 10 +++++++++ 6 files changed, 43 insertions(+), 29 deletions(-) diff --git a/src/Server/Datapoint.cs b/src/Server/Datapoint.cs index 6c515c4..42d59a7 100644 --- a/src/Server/Datapoint.cs +++ b/src/Server/Datapoint.cs @@ -1,3 +1,4 @@ +using AdaptiveExpressions; using OllamaSharp; using OllamaSharp.Models; @@ -27,30 +28,31 @@ public class Datapoint public static Dictionary GenerateEmbeddings(string content, List models, AIProvider aIProvider) { - return GenerateEmbeddings(content, models, aIProvider, []); + return GenerateEmbeddings(content, models, aIProvider, new()); } - public static Dictionary GenerateEmbeddings(string content, List models, AIProvider aIProvider, Dictionary> embeddingCache) + public static Dictionary GenerateEmbeddings(string content, List models, AIProvider aIProvider, LRUCache> embeddingCache) { Dictionary retVal = []; foreach (string model in models) { - if (embeddingCache.ContainsKey(model) && embeddingCache[model].ContainsKey(content)) + bool embeddingCacheHasModel = embeddingCache.TryGet(model, out var embeddingCacheForModel); + if (embeddingCacheHasModel && embeddingCacheForModel.ContainsKey(content)) { - retVal[model] = embeddingCache[model][content]; + retVal[model] = embeddingCacheForModel[content]; continue; } var response = aIProvider.GenerateEmbeddings(model, [content]); if (response is not null) { retVal[model] = response; - if (!embeddingCache.ContainsKey(model)) + if (!embeddingCacheHasModel) { - embeddingCache[model] = []; + embeddingCacheForModel = []; } - if (!embeddingCache[model].ContainsKey(content)) + if (!embeddingCacheForModel.ContainsKey(content)) { - embeddingCache[model][content] = response; + embeddingCacheForModel[content] = response; } } } diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 3409ba4..559e29b 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -2,6 +2,7 @@ using System.Collections.Concurrent; using System.Security.Cryptography; using System.Text; using System.Text.Json; +using AdaptiveExpressions; using Server.Exceptions; using Shared.Models; @@ -46,7 +47,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp public List? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json) { - Dictionary> embeddingCache = searchdomainManager.embeddingCache; + LRUCache> embeddingCache = searchdomainManager.embeddingCache; AIProvider aIProvider = searchdomainManager.aIProvider; SQLHelper helper = searchdomainManager.helper; @@ -91,7 +92,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain); List entityCache = searchdomain.entityCache; AIProvider aIProvider = searchdomain.aIProvider; - Dictionary> embeddingCache = searchdomain.embeddingCache; + LRUCache> embeddingCache = searchdomain.embeddingCache; Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); if (preexistingEntity is not null) @@ -261,7 +262,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp throw new Exception("jsonDatapoint.Text must not be null at this point"); } using SQLHelper helper = searchdomain.helper.DuplicateConnection(); - Dictionary> embeddingCache = searchdomain.embeddingCache; + LRUCache> embeddingCache = searchdomain.embeddingCache; hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); Dictionary embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index 2435914..81502cd 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -5,6 +5,7 @@ using ElmahCore.Mvc.Logger; using MySql.Data.MySqlClient; using Server.Helper; using Shared.Models; +using AdaptiveExpressions; namespace Server; @@ -19,13 +20,12 @@ public class Searchdomain public Dictionary searchCache; // Key: query, Value: Search results for that query (with timestamp) public List entityCache; public List modelsInUse; - public Dictionary> embeddingCache; - public int embeddingCacheMaxSize = 10000000; + public LRUCache> embeddingCache; private readonly MySqlConnection connection; public SQLHelper helper; private readonly ILogger _logger; - public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, Dictionary> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) + public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, LRUCache> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) { _connectionString = connectionString; _provider = provider.ToLower(); @@ -169,25 +169,22 @@ public class Searchdomain return [.. cachedResult.Results.Select(r => (r.Score, r.Name))]; } - bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary? queryEmbeddings); + bool hasQuery = embeddingCache.TryGet(query, out Dictionary? queryEmbeddings); bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model)); if (!(hasQuery && allModelsInQuery)) { queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache); - if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits - { // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries - if (!embeddingCache.ContainsKey(query)) + if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery)) + { + embeddingCache.Set(query, queryEmbeddings); + } + else // embeddingCache already has an entry for this query, so the missing model-embedding pairs have to be filled in + { + foreach (KeyValuePair kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding { - embeddingCache.Add(query, queryEmbeddings); - } - else - { - foreach (KeyValuePair kvp in queryEmbeddings) + if (!embeddingCache.TryGet(kvp.Key, out var _)) { - if (!embeddingCache.ContainsKey(kvp.Key)) - { - embeddingCache[query][kvp.Key] = kvp.Value; - } + embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value; } } } diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index e354f97..c024c2c 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -3,6 +3,7 @@ using System.Data.Common; using Server.Migrations; using Server.Helper; using Server.Exceptions; +using AdaptiveExpressions; namespace Server; @@ -16,7 +17,8 @@ public class SearchdomainManager private readonly string connectionString; private MySqlConnection connection; public SQLHelper helper; - public Dictionary> embeddingCache; + public LRUCache> embeddingCache; + public int EmbeddingCacheMaxCount; public SearchdomainManager(ILogger logger, IConfiguration config, AIProvider aIProvider, DatabaseHelper databaseHelper) { @@ -24,7 +26,8 @@ public class SearchdomainManager _config = config; this.aIProvider = aIProvider; _databaseHelper = databaseHelper; - embeddingCache = []; + EmbeddingCacheMaxCount = config.GetValue("Embeddingsearch:EmbeddingCacheMaxCount") ?? 1000000; + embeddingCache = new(EmbeddingCacheMaxCount); connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? ""; connection = new MySqlConnection(connectionString); connection.Open(); diff --git a/src/Server/Server.csproj b/src/Server/Server.csproj index 178f4c0..6a38352 100644 --- a/src/Server/Server.csproj +++ b/src/Server/Server.csproj @@ -7,6 +7,7 @@ + diff --git a/src/Server/appsettings.Development.json b/src/Server/appsettings.Development.json index 3260092..50694d5 100644 --- a/src/Server/appsettings.Development.json +++ b/src/Server/appsettings.Development.json @@ -24,6 +24,7 @@ "172.17.0.1" ] }, + "EmbeddingCacheMaxCount": 5, "AiProviders": { "ollama": { "handler": "ollama", @@ -35,6 +36,15 @@ "ApiKey": "Some API key here" } }, + "SimpleAuth": { + "Users": [ + { + "Username": "admin", + "Password": "UnsafePractice.67", + "Roles": ["Admin"] + } + ] + }, "ApiKeys": ["Some UUID here", "Another UUID here"], "UseHttpsRedirection": true }