Merge pull request #93 from LD-Reborn/92-datapointgenerateembeddings-does-not-feed-embedding-cache
Moved embeddingCache to EnumerableLruCache, fixed GenerateEmbeddings …
This commit is contained in:
@@ -9,6 +9,7 @@ using Microsoft.Extensions.Options;
|
|||||||
using Server.Exceptions;
|
using Server.Exceptions;
|
||||||
using Server.Helper;
|
using Server.Helper;
|
||||||
using Server.Models;
|
using Server.Models;
|
||||||
|
using Shared;
|
||||||
using Shared.Models;
|
using Shared.Models;
|
||||||
|
|
||||||
[ApiController]
|
[ApiController]
|
||||||
@@ -61,18 +62,12 @@ public class ServerController : ControllerBase
|
|||||||
long size = 0;
|
long size = 0;
|
||||||
long elementCount = 0;
|
long elementCount = 0;
|
||||||
long embeddingsCount = 0;
|
long embeddingsCount = 0;
|
||||||
LRUCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.embeddingCache;
|
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.embeddingCache;
|
||||||
var cacheListField = embeddingCache.GetType()
|
|
||||||
.GetField("_cacheList", BindingFlags.Instance | BindingFlags.NonPublic) ?? throw new InvalidOperationException("_cacheList field not found"); // TODO Remove this unsafe reflection atrocity
|
|
||||||
LinkedList<string> cacheListOriginal = (LinkedList<string>)cacheListField.GetValue(embeddingCache)!;
|
|
||||||
LinkedList<string> cacheList = new(cacheListOriginal);
|
|
||||||
|
|
||||||
foreach (string key in cacheList)
|
foreach (KeyValuePair<string, Dictionary<string, float[]>> kv in embeddingCache)
|
||||||
{
|
{
|
||||||
if (!embeddingCache.TryGet(key, out var entry))
|
string key = kv.Key;
|
||||||
continue;
|
Dictionary<string, float[]> entry = kv.Value;
|
||||||
|
|
||||||
// estimate size
|
|
||||||
size += EstimateEntrySize(key, entry);
|
size += EstimateEntrySize(key, entry);
|
||||||
elementCount++;
|
elementCount++;
|
||||||
embeddingsCount += entry.Keys.Count;
|
embeddingsCount += entry.Keys.Count;
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
using AdaptiveExpressions;
|
using AdaptiveExpressions;
|
||||||
using OllamaSharp;
|
using OllamaSharp;
|
||||||
using OllamaSharp.Models;
|
using OllamaSharp.Models;
|
||||||
|
using Shared;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -26,36 +27,39 @@ public class Datapoint
|
|||||||
return probMethod.method(probabilities);
|
return probMethod.method(probabilities);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
|
public static Dictionary<string, float[]> GetEmbeddings(string content, List<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
|
||||||
{
|
{
|
||||||
return GenerateEmbeddings(content, models, aIProvider, new());
|
Dictionary<string, float[]> embeddings = [];
|
||||||
|
bool embeddingCacheHasContent = embeddingCache.TryGetValue(content, out var embeddingCacheForContent);
|
||||||
|
if (!embeddingCacheHasContent || embeddingCacheForContent is null)
|
||||||
|
{
|
||||||
|
models.ForEach(model =>
|
||||||
|
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache)
|
||||||
|
);
|
||||||
|
return embeddings;
|
||||||
|
}
|
||||||
|
models.ForEach(model =>
|
||||||
|
{
|
||||||
|
bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embeddingCacheForModel);
|
||||||
|
if (embeddingCacheHasModel && embeddingCacheForModel is not null)
|
||||||
|
{
|
||||||
|
embeddings[model] = embeddingCacheForModel;
|
||||||
|
} else
|
||||||
|
{
|
||||||
|
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return embeddings;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache)
|
public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
|
||||||
{
|
{
|
||||||
Dictionary<string, float[]> retVal = [];
|
float[] embeddings = aIProvider.GenerateEmbeddings(model, [content]);
|
||||||
foreach (string model in models)
|
if (!embeddingCache.ContainsKey(content))
|
||||||
{
|
{
|
||||||
bool embeddingCacheHasModel = embeddingCache.TryGet(model, out var embeddingCacheForModel);
|
embeddingCache[content] = [];
|
||||||
if (embeddingCacheHasModel && embeddingCacheForModel.ContainsKey(content))
|
|
||||||
{
|
|
||||||
retVal[model] = embeddingCacheForModel[content];
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
var response = aIProvider.GenerateEmbeddings(model, [content]);
|
embeddingCache[content][model] = embeddings;
|
||||||
if (response is not null)
|
return embeddings;
|
||||||
{
|
|
||||||
retVal[model] = response;
|
|
||||||
if (!embeddingCacheHasModel)
|
|
||||||
{
|
|
||||||
embeddingCacheForModel = [];
|
|
||||||
}
|
|
||||||
if (!embeddingCacheForModel.ContainsKey(content))
|
|
||||||
{
|
|
||||||
embeddingCacheForModel[content] = response;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return retVal;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -4,6 +4,7 @@ using System.Text;
|
|||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
using AdaptiveExpressions;
|
using AdaptiveExpressions;
|
||||||
using Server.Exceptions;
|
using Server.Exceptions;
|
||||||
|
using Shared;
|
||||||
using Shared.Models;
|
using Shared.Models;
|
||||||
|
|
||||||
namespace Server.Helper;
|
namespace Server.Helper;
|
||||||
@@ -47,7 +48,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
|
|||||||
|
|
||||||
public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
|
public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
|
||||||
{
|
{
|
||||||
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache;
|
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache;
|
||||||
AIProvider aIProvider = searchdomainManager.aIProvider;
|
AIProvider aIProvider = searchdomainManager.aIProvider;
|
||||||
SQLHelper helper = searchdomainManager.helper;
|
SQLHelper helper = searchdomainManager.helper;
|
||||||
|
|
||||||
@@ -92,7 +93,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
|
|||||||
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
|
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
|
||||||
List<Entity> entityCache = searchdomain.entityCache;
|
List<Entity> entityCache = searchdomain.entityCache;
|
||||||
AIProvider aIProvider = searchdomain.aIProvider;
|
AIProvider aIProvider = searchdomain.aIProvider;
|
||||||
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
||||||
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
|
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
|
||||||
bool invalidateSearchCache = false;
|
bool invalidateSearchCache = false;
|
||||||
|
|
||||||
@@ -274,10 +275,10 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
|
|||||||
throw new Exception("jsonDatapoint.Text must not be null at this point");
|
throw new Exception("jsonDatapoint.Text must not be null at this point");
|
||||||
}
|
}
|
||||||
using SQLHelper helper = searchdomain.helper.DuplicateConnection();
|
using SQLHelper helper = searchdomain.helper.DuplicateConnection();
|
||||||
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
||||||
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
|
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
|
||||||
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
|
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
|
||||||
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);
|
Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);
|
||||||
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding);
|
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding);
|
||||||
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod);
|
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod);
|
||||||
return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||||
|
|||||||
@@ -21,12 +21,12 @@ public class Searchdomain
|
|||||||
public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp)
|
public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp)
|
||||||
public List<Entity> entityCache;
|
public List<Entity> entityCache;
|
||||||
public List<string> modelsInUse;
|
public List<string> modelsInUse;
|
||||||
public LRUCache<string, Dictionary<string, float[]>> embeddingCache;
|
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
|
||||||
private readonly MySqlConnection connection;
|
private readonly MySqlConnection connection;
|
||||||
public SQLHelper helper;
|
public SQLHelper helper;
|
||||||
private readonly ILogger _logger;
|
private readonly ILogger _logger;
|
||||||
|
|
||||||
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
||||||
{
|
{
|
||||||
_connectionString = connectionString;
|
_connectionString = connectionString;
|
||||||
_provider = provider.ToLower();
|
_provider = provider.ToLower();
|
||||||
@@ -194,12 +194,12 @@ public class Searchdomain
|
|||||||
|
|
||||||
public Dictionary<string, float[]> GetQueryEmbeddings(string query)
|
public Dictionary<string, float[]> GetQueryEmbeddings(string query)
|
||||||
{
|
{
|
||||||
bool hasQuery = embeddingCache.TryGet(query, out Dictionary<string, float[]> queryEmbeddings);
|
bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings);
|
||||||
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
|
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
|
||||||
if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null)
|
if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null)
|
||||||
{
|
{
|
||||||
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
|
queryEmbeddings = Datapoint.GetEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
|
||||||
if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery))
|
if (!embeddingCache.TryGetValue(query, out var embeddingCacheForCurrentQuery))
|
||||||
{
|
{
|
||||||
embeddingCache.Set(query, queryEmbeddings);
|
embeddingCache.Set(query, queryEmbeddings);
|
||||||
}
|
}
|
||||||
@@ -207,7 +207,7 @@ public class Searchdomain
|
|||||||
{
|
{
|
||||||
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding
|
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding
|
||||||
{
|
{
|
||||||
if (!embeddingCache.TryGet(kvp.Key, out var _))
|
if (!embeddingCache.TryGetValue(kvp.Key, out var _))
|
||||||
{
|
{
|
||||||
embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value;
|
embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ using Shared.Models;
|
|||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
using Microsoft.Extensions.Options;
|
using Microsoft.Extensions.Options;
|
||||||
using Server.Models;
|
using Server.Models;
|
||||||
|
using Shared;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ public class SearchdomainManager
|
|||||||
private readonly string connectionString;
|
private readonly string connectionString;
|
||||||
private MySqlConnection connection;
|
private MySqlConnection connection;
|
||||||
public SQLHelper helper;
|
public SQLHelper helper;
|
||||||
public LRUCache<string, Dictionary<string, float[]>> embeddingCache;
|
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
|
||||||
public long EmbeddingCacheMaxCount;
|
public long EmbeddingCacheMaxCount;
|
||||||
|
|
||||||
public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper)
|
public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper)
|
||||||
|
|||||||
Reference in New Issue
Block a user