Moved embeddingCache from Dictionary to LRUCache

This commit is contained in:
2025-12-27 18:40:03 +01:00
parent 5eabb0d924
commit 7b4a3bd2c8
6 changed files with 43 additions and 29 deletions

View File

@@ -1,3 +1,4 @@
using AdaptiveExpressions;
using OllamaSharp; using OllamaSharp;
using OllamaSharp.Models; using OllamaSharp.Models;
@@ -27,30 +28,31 @@ public class Datapoint
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider) public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
{ {
return GenerateEmbeddings(content, models, aIProvider, []); return GenerateEmbeddings(content, models, aIProvider, new());
} }
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache) public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache)
{ {
Dictionary<string, float[]> retVal = []; Dictionary<string, float[]> retVal = [];
foreach (string model in models) foreach (string model in models)
{ {
if (embeddingCache.ContainsKey(model) && embeddingCache[model].ContainsKey(content)) bool embeddingCacheHasModel = embeddingCache.TryGet(model, out var embeddingCacheForModel);
if (embeddingCacheHasModel && embeddingCacheForModel.ContainsKey(content))
{ {
retVal[model] = embeddingCache[model][content]; retVal[model] = embeddingCacheForModel[content];
continue; continue;
} }
var response = aIProvider.GenerateEmbeddings(model, [content]); var response = aIProvider.GenerateEmbeddings(model, [content]);
if (response is not null) if (response is not null)
{ {
retVal[model] = response; retVal[model] = response;
if (!embeddingCache.ContainsKey(model)) if (!embeddingCacheHasModel)
{ {
embeddingCache[model] = []; embeddingCacheForModel = [];
} }
if (!embeddingCache[model].ContainsKey(content)) if (!embeddingCacheForModel.ContainsKey(content))
{ {
embeddingCache[model][content] = response; embeddingCacheForModel[content] = response;
} }
} }
} }

View File

@@ -2,6 +2,7 @@ using System.Collections.Concurrent;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using AdaptiveExpressions;
using Server.Exceptions; using Server.Exceptions;
using Shared.Models; using Shared.Models;
@@ -46,7 +47,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json) public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
{ {
Dictionary<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache; LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache;
AIProvider aIProvider = searchdomainManager.aIProvider; AIProvider aIProvider = searchdomainManager.aIProvider;
SQLHelper helper = searchdomainManager.helper; SQLHelper helper = searchdomainManager.helper;
@@ -91,7 +92,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain); Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
List<Entity> entityCache = searchdomain.entityCache; List<Entity> entityCache = searchdomain.entityCache;
AIProvider aIProvider = searchdomain.aIProvider; AIProvider aIProvider = searchdomain.aIProvider;
Dictionary<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
if (preexistingEntity is not null) if (preexistingEntity is not null)
@@ -261,7 +262,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
throw new Exception("jsonDatapoint.Text must not be null at this point"); throw new Exception("jsonDatapoint.Text must not be null at this point");
} }
using SQLHelper helper = searchdomain.helper.DuplicateConnection(); using SQLHelper helper = searchdomain.helper.DuplicateConnection();
Dictionary<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);

View File

@@ -5,6 +5,7 @@ using ElmahCore.Mvc.Logger;
using MySql.Data.MySqlClient; using MySql.Data.MySqlClient;
using Server.Helper; using Server.Helper;
using Shared.Models; using Shared.Models;
using AdaptiveExpressions;
namespace Server; namespace Server;
@@ -19,13 +20,12 @@ public class Searchdomain
public Dictionary<string, DateTimedSearchResult> searchCache; // Key: query, Value: Search results for that query (with timestamp) public Dictionary<string, DateTimedSearchResult> searchCache; // Key: query, Value: Search results for that query (with timestamp)
public List<Entity> entityCache; public List<Entity> entityCache;
public List<string> modelsInUse; public List<string> modelsInUse;
public Dictionary<string, Dictionary<string, float[]>> embeddingCache; public LRUCache<string, Dictionary<string, float[]>> embeddingCache;
public int embeddingCacheMaxSize = 10000000;
private readonly MySqlConnection connection; private readonly MySqlConnection connection;
public SQLHelper helper; public SQLHelper helper;
private readonly ILogger _logger; private readonly ILogger _logger;
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
{ {
_connectionString = connectionString; _connectionString = connectionString;
_provider = provider.ToLower(); _provider = provider.ToLower();
@@ -169,25 +169,22 @@ public class Searchdomain
return [.. cachedResult.Results.Select(r => (r.Score, r.Name))]; return [.. cachedResult.Results.Select(r => (r.Score, r.Name))];
} }
bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings); bool hasQuery = embeddingCache.TryGet(query, out Dictionary<string, float[]>? queryEmbeddings);
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model)); bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
if (!(hasQuery && allModelsInQuery)) if (!(hasQuery && allModelsInQuery))
{ {
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache); queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery))
{ // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries
if (!embeddingCache.ContainsKey(query))
{ {
embeddingCache.Add(query, queryEmbeddings); embeddingCache.Set(query, queryEmbeddings);
} }
else else // embeddingCache already has an entry for this query, so the missing model-embedding pairs have to be filled in
{ {
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding
{ {
if (!embeddingCache.ContainsKey(kvp.Key)) if (!embeddingCache.TryGet(kvp.Key, out var _))
{ {
embeddingCache[query][kvp.Key] = kvp.Value; embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value;
}
} }
} }
} }

View File

@@ -3,6 +3,7 @@ using System.Data.Common;
using Server.Migrations; using Server.Migrations;
using Server.Helper; using Server.Helper;
using Server.Exceptions; using Server.Exceptions;
using AdaptiveExpressions;
namespace Server; namespace Server;
@@ -16,7 +17,8 @@ public class SearchdomainManager
private readonly string connectionString; private readonly string connectionString;
private MySqlConnection connection; private MySqlConnection connection;
public SQLHelper helper; public SQLHelper helper;
public Dictionary<string, Dictionary<string, float[]>> embeddingCache; public LRUCache<string, Dictionary<string, float[]>> embeddingCache;
public int EmbeddingCacheMaxCount;
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config, AIProvider aIProvider, DatabaseHelper databaseHelper) public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config, AIProvider aIProvider, DatabaseHelper databaseHelper)
{ {
@@ -24,7 +26,8 @@ public class SearchdomainManager
_config = config; _config = config;
this.aIProvider = aIProvider; this.aIProvider = aIProvider;
_databaseHelper = databaseHelper; _databaseHelper = databaseHelper;
embeddingCache = []; EmbeddingCacheMaxCount = config.GetValue<int?>("Embeddingsearch:EmbeddingCacheMaxCount") ?? 1000000;
embeddingCache = new(EmbeddingCacheMaxCount);
connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? ""; connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? "";
connection = new MySqlConnection(connectionString); connection = new MySqlConnection(connectionString);
connection.Open(); connection.Open();

View File

@@ -7,6 +7,7 @@
</PropertyGroup> </PropertyGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="AdaptiveExpressions" Version="4.23.0" />
<PackageReference Include="ElmahCore" Version="2.1.2" /> <PackageReference Include="ElmahCore" Version="2.1.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" /> <PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" /> <PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />

View File

@@ -24,6 +24,7 @@
"172.17.0.1" "172.17.0.1"
] ]
}, },
"EmbeddingCacheMaxCount": 5,
"AiProviders": { "AiProviders": {
"ollama": { "ollama": {
"handler": "ollama", "handler": "ollama",
@@ -35,6 +36,15 @@
"ApiKey": "Some API key here" "ApiKey": "Some API key here"
} }
}, },
"SimpleAuth": {
"Users": [
{
"Username": "admin",
"Password": "UnsafePractice.67",
"Roles": ["Admin"]
}
]
},
"ApiKeys": ["Some UUID here", "Another UUID here"], "ApiKeys": ["Some UUID here", "Another UUID here"],
"UseHttpsRedirection": true "UseHttpsRedirection": true
} }