2 Commits

Author SHA1 Message Date
LD50
17cc8f41d5 Merge pull request #93 from LD-Reborn/92-datapointgenerateembeddings-does-not-feed-embedding-cache
Moved embeddingCache to EnumerableLruCache, fixed GenerateEmbeddings not feeding embeddingCache
2026-01-16 10:36:10 +01:00
a01985d1b8 Moved embeddingCache to EnumerableLruCache, fixed GenerateEmbeddings not feeding embeddingCache 2026-01-16 10:35:46 +01:00
5 changed files with 47 additions and 46 deletions

View File

@@ -9,6 +9,7 @@ using Microsoft.Extensions.Options;
using Server.Exceptions; using Server.Exceptions;
using Server.Helper; using Server.Helper;
using Server.Models; using Server.Models;
using Shared;
using Shared.Models; using Shared.Models;
[ApiController] [ApiController]
@@ -61,18 +62,12 @@ public class ServerController : ControllerBase
long size = 0; long size = 0;
long elementCount = 0; long elementCount = 0;
long embeddingsCount = 0; long embeddingsCount = 0;
LRUCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.embeddingCache;
var cacheListField = embeddingCache.GetType()
.GetField("_cacheList", BindingFlags.Instance | BindingFlags.NonPublic) ?? throw new InvalidOperationException("_cacheList field not found"); // TODO Remove this unsafe reflection atrocity
LinkedList<string> cacheListOriginal = (LinkedList<string>)cacheListField.GetValue(embeddingCache)!;
LinkedList<string> cacheList = new(cacheListOriginal);
foreach (string key in cacheList) foreach (KeyValuePair<string, Dictionary<string, float[]>> kv in embeddingCache)
{ {
if (!embeddingCache.TryGet(key, out var entry)) string key = kv.Key;
continue; Dictionary<string, float[]> entry = kv.Value;
// estimate size
size += EstimateEntrySize(key, entry); size += EstimateEntrySize(key, entry);
elementCount++; elementCount++;
embeddingsCount += entry.Keys.Count; embeddingsCount += entry.Keys.Count;

View File

@@ -1,6 +1,7 @@
using AdaptiveExpressions; using AdaptiveExpressions;
using OllamaSharp; using OllamaSharp;
using OllamaSharp.Models; using OllamaSharp.Models;
using Shared;
namespace Server; namespace Server;
@@ -26,36 +27,39 @@ public class Datapoint
return probMethod.method(probabilities); return probMethod.method(probabilities);
} }
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider) public static Dictionary<string, float[]> GetEmbeddings(string content, List<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{ {
return GenerateEmbeddings(content, models, aIProvider, new()); Dictionary<string, float[]> embeddings = [];
bool embeddingCacheHasContent = embeddingCache.TryGetValue(content, out var embeddingCacheForContent);
if (!embeddingCacheHasContent || embeddingCacheForContent is null)
{
models.ForEach(model =>
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache)
);
return embeddings;
}
models.ForEach(model =>
{
bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embeddingCacheForModel);
if (embeddingCacheHasModel && embeddingCacheForModel is not null)
{
embeddings[model] = embeddingCacheForModel;
} else
{
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache);
}
});
return embeddings;
} }
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache) public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{ {
Dictionary<string, float[]> retVal = []; float[] embeddings = aIProvider.GenerateEmbeddings(model, [content]);
foreach (string model in models) if (!embeddingCache.ContainsKey(content))
{ {
bool embeddingCacheHasModel = embeddingCache.TryGet(model, out var embeddingCacheForModel); embeddingCache[content] = [];
if (embeddingCacheHasModel && embeddingCacheForModel.ContainsKey(content))
{
retVal[model] = embeddingCacheForModel[content];
continue;
} }
var response = aIProvider.GenerateEmbeddings(model, [content]); embeddingCache[content][model] = embeddings;
if (response is not null) return embeddings;
{
retVal[model] = response;
if (!embeddingCacheHasModel)
{
embeddingCacheForModel = [];
}
if (!embeddingCacheForModel.ContainsKey(content))
{
embeddingCacheForModel[content] = response;
}
}
}
return retVal;
} }
} }

View File

@@ -4,6 +4,7 @@ using System.Text;
using System.Text.Json; using System.Text.Json;
using AdaptiveExpressions; using AdaptiveExpressions;
using Server.Exceptions; using Server.Exceptions;
using Shared;
using Shared.Models; using Shared.Models;
namespace Server.Helper; namespace Server.Helper;
@@ -47,7 +48,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json) public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
{ {
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache;
AIProvider aIProvider = searchdomainManager.aIProvider; AIProvider aIProvider = searchdomainManager.aIProvider;
SQLHelper helper = searchdomainManager.helper; SQLHelper helper = searchdomainManager.helper;
@@ -92,7 +93,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain); Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
List<Entity> entityCache = searchdomain.entityCache; List<Entity> entityCache = searchdomain.entityCache;
AIProvider aIProvider = searchdomain.aIProvider; AIProvider aIProvider = searchdomain.aIProvider;
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
bool invalidateSearchCache = false; bool invalidateSearchCache = false;
@@ -274,10 +275,10 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
throw new Exception("jsonDatapoint.Text must not be null at this point"); throw new Exception("jsonDatapoint.Text must not be null at this point");
} }
using SQLHelper helper = searchdomain.helper.DuplicateConnection(); using SQLHelper helper = searchdomain.helper.DuplicateConnection();
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding); var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding);
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod); var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod);
return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);

View File

@@ -21,12 +21,12 @@ public class Searchdomain
public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp) public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp)
public List<Entity> entityCache; public List<Entity> entityCache;
public List<string> modelsInUse; public List<string> modelsInUse;
public LRUCache<string, Dictionary<string, float[]>> embeddingCache; public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
private readonly MySqlConnection connection; private readonly MySqlConnection connection;
public SQLHelper helper; public SQLHelper helper;
private readonly ILogger _logger; private readonly ILogger _logger;
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
{ {
_connectionString = connectionString; _connectionString = connectionString;
_provider = provider.ToLower(); _provider = provider.ToLower();
@@ -194,12 +194,12 @@ public class Searchdomain
public Dictionary<string, float[]> GetQueryEmbeddings(string query) public Dictionary<string, float[]> GetQueryEmbeddings(string query)
{ {
bool hasQuery = embeddingCache.TryGet(query, out Dictionary<string, float[]> queryEmbeddings); bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings);
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model)); bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null) if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null)
{ {
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache); queryEmbeddings = Datapoint.GetEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery)) if (!embeddingCache.TryGetValue(query, out var embeddingCacheForCurrentQuery))
{ {
embeddingCache.Set(query, queryEmbeddings); embeddingCache.Set(query, queryEmbeddings);
} }
@@ -207,7 +207,7 @@ public class Searchdomain
{ {
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding
{ {
if (!embeddingCache.TryGet(kvp.Key, out var _)) if (!embeddingCache.TryGetValue(kvp.Key, out var _))
{ {
embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value; embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value;
} }

View File

@@ -8,6 +8,7 @@ using Shared.Models;
using System.Text.Json; using System.Text.Json;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Server.Models; using Server.Models;
using Shared;
namespace Server; namespace Server;
@@ -21,7 +22,7 @@ public class SearchdomainManager
private readonly string connectionString; private readonly string connectionString;
private MySqlConnection connection; private MySqlConnection connection;
public SQLHelper helper; public SQLHelper helper;
public LRUCache<string, Dictionary<string, float[]>> embeddingCache; public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
public long EmbeddingCacheMaxCount; public long EmbeddingCacheMaxCount;
public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper) public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper)