Moved embeddingCache to EnumerableLruCache, fixed GenerateEmbeddings not feeding embeddingCache

This commit is contained in:
2026-01-16 10:35:46 +01:00
parent 4c1f0305fc
commit a01985d1b8
5 changed files with 47 additions and 46 deletions

View File

@@ -1,6 +1,7 @@
using AdaptiveExpressions;
using OllamaSharp;
using OllamaSharp.Models;
using Shared;
namespace Server;
@@ -26,36 +27,39 @@ public class Datapoint
return probMethod.method(probabilities);
}
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
public static Dictionary<string, float[]> GetEmbeddings(string content, List<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{
return GenerateEmbeddings(content, models, aIProvider, new());
Dictionary<string, float[]> embeddings = [];
bool embeddingCacheHasContent = embeddingCache.TryGetValue(content, out var embeddingCacheForContent);
if (!embeddingCacheHasContent || embeddingCacheForContent is null)
{
models.ForEach(model =>
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache)
);
return embeddings;
}
models.ForEach(model =>
{
bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embeddingCacheForModel);
if (embeddingCacheHasModel && embeddingCacheForModel is not null)
{
embeddings[model] = embeddingCacheForModel;
} else
{
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache);
}
});
return embeddings;
}
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache)
public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{
Dictionary<string, float[]> retVal = [];
foreach (string model in models)
float[] embeddings = aIProvider.GenerateEmbeddings(model, [content]);
if (!embeddingCache.ContainsKey(content))
{
bool embeddingCacheHasModel = embeddingCache.TryGet(model, out var embeddingCacheForModel);
if (embeddingCacheHasModel && embeddingCacheForModel.ContainsKey(content))
{
retVal[model] = embeddingCacheForModel[content];
continue;
}
var response = aIProvider.GenerateEmbeddings(model, [content]);
if (response is not null)
{
retVal[model] = response;
if (!embeddingCacheHasModel)
{
embeddingCacheForModel = [];
}
if (!embeddingCacheForModel.ContainsKey(content))
{
embeddingCacheForModel[content] = response;
}
}
embeddingCache[content] = [];
}
return retVal;
embeddingCache[content][model] = embeddings;
return embeddings;
}
}