Merge pull request #55 from LD-Reborn/54-properly-implement-embeddings-cache-size-limit-global
Moved embeddingCache from Dictionary to LRUCache
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
using AdaptiveExpressions;
|
||||||
using OllamaSharp;
|
using OllamaSharp;
|
||||||
using OllamaSharp.Models;
|
using OllamaSharp.Models;
|
||||||
|
|
||||||
@@ -27,30 +28,31 @@ public class Datapoint
|
|||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
|
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
|
||||||
{
|
{
|
||||||
return GenerateEmbeddings(content, models, aIProvider, []);
|
return GenerateEmbeddings(content, models, aIProvider, new());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache)
|
||||||
{
|
{
|
||||||
Dictionary<string, float[]> retVal = [];
|
Dictionary<string, float[]> retVal = [];
|
||||||
foreach (string model in models)
|
foreach (string model in models)
|
||||||
{
|
{
|
||||||
if (embeddingCache.ContainsKey(model) && embeddingCache[model].ContainsKey(content))
|
bool embeddingCacheHasModel = embeddingCache.TryGet(model, out var embeddingCacheForModel);
|
||||||
|
if (embeddingCacheHasModel && embeddingCacheForModel.ContainsKey(content))
|
||||||
{
|
{
|
||||||
retVal[model] = embeddingCache[model][content];
|
retVal[model] = embeddingCacheForModel[content];
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
var response = aIProvider.GenerateEmbeddings(model, [content]);
|
var response = aIProvider.GenerateEmbeddings(model, [content]);
|
||||||
if (response is not null)
|
if (response is not null)
|
||||||
{
|
{
|
||||||
retVal[model] = response;
|
retVal[model] = response;
|
||||||
if (!embeddingCache.ContainsKey(model))
|
if (!embeddingCacheHasModel)
|
||||||
{
|
{
|
||||||
embeddingCache[model] = [];
|
embeddingCacheForModel = [];
|
||||||
}
|
}
|
||||||
if (!embeddingCache[model].ContainsKey(content))
|
if (!embeddingCacheForModel.ContainsKey(content))
|
||||||
{
|
{
|
||||||
embeddingCache[model][content] = response;
|
embeddingCacheForModel[content] = response;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ using System.Collections.Concurrent;
|
|||||||
using System.Security.Cryptography;
|
using System.Security.Cryptography;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
|
using AdaptiveExpressions;
|
||||||
using Server.Exceptions;
|
using Server.Exceptions;
|
||||||
using Shared.Models;
|
using Shared.Models;
|
||||||
|
|
||||||
@@ -46,7 +47,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
|
|||||||
|
|
||||||
public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
|
public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
|
||||||
{
|
{
|
||||||
Dictionary<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache;
|
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache;
|
||||||
AIProvider aIProvider = searchdomainManager.aIProvider;
|
AIProvider aIProvider = searchdomainManager.aIProvider;
|
||||||
SQLHelper helper = searchdomainManager.helper;
|
SQLHelper helper = searchdomainManager.helper;
|
||||||
|
|
||||||
@@ -91,7 +92,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
|
|||||||
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
|
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
|
||||||
List<Entity> entityCache = searchdomain.entityCache;
|
List<Entity> entityCache = searchdomain.entityCache;
|
||||||
AIProvider aIProvider = searchdomain.aIProvider;
|
AIProvider aIProvider = searchdomain.aIProvider;
|
||||||
Dictionary<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
||||||
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
|
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
|
||||||
|
|
||||||
if (preexistingEntity is not null)
|
if (preexistingEntity is not null)
|
||||||
@@ -261,7 +262,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
|
|||||||
throw new Exception("jsonDatapoint.Text must not be null at this point");
|
throw new Exception("jsonDatapoint.Text must not be null at this point");
|
||||||
}
|
}
|
||||||
using SQLHelper helper = searchdomain.helper.DuplicateConnection();
|
using SQLHelper helper = searchdomain.helper.DuplicateConnection();
|
||||||
Dictionary<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
|
||||||
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
|
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
|
||||||
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
|
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
|
||||||
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);
|
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ using ElmahCore.Mvc.Logger;
|
|||||||
using MySql.Data.MySqlClient;
|
using MySql.Data.MySqlClient;
|
||||||
using Server.Helper;
|
using Server.Helper;
|
||||||
using Shared.Models;
|
using Shared.Models;
|
||||||
|
using AdaptiveExpressions;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -19,13 +20,12 @@ public class Searchdomain
|
|||||||
public Dictionary<string, DateTimedSearchResult> searchCache; // Key: query, Value: Search results for that query (with timestamp)
|
public Dictionary<string, DateTimedSearchResult> searchCache; // Key: query, Value: Search results for that query (with timestamp)
|
||||||
public List<Entity> entityCache;
|
public List<Entity> entityCache;
|
||||||
public List<string> modelsInUse;
|
public List<string> modelsInUse;
|
||||||
public Dictionary<string, Dictionary<string, float[]>> embeddingCache;
|
public LRUCache<string, Dictionary<string, float[]>> embeddingCache;
|
||||||
public int embeddingCacheMaxSize = 10000000;
|
|
||||||
private readonly MySqlConnection connection;
|
private readonly MySqlConnection connection;
|
||||||
public SQLHelper helper;
|
public SQLHelper helper;
|
||||||
private readonly ILogger _logger;
|
private readonly ILogger _logger;
|
||||||
|
|
||||||
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, LRUCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
||||||
{
|
{
|
||||||
_connectionString = connectionString;
|
_connectionString = connectionString;
|
||||||
_provider = provider.ToLower();
|
_provider = provider.ToLower();
|
||||||
@@ -169,25 +169,22 @@ public class Searchdomain
|
|||||||
return [.. cachedResult.Results.Select(r => (r.Score, r.Name))];
|
return [.. cachedResult.Results.Select(r => (r.Score, r.Name))];
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings);
|
bool hasQuery = embeddingCache.TryGet(query, out Dictionary<string, float[]>? queryEmbeddings);
|
||||||
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
|
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
|
||||||
if (!(hasQuery && allModelsInQuery))
|
if (!(hasQuery && allModelsInQuery))
|
||||||
{
|
{
|
||||||
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
|
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
|
||||||
if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits
|
if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery))
|
||||||
{ // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries
|
{
|
||||||
if (!embeddingCache.ContainsKey(query))
|
embeddingCache.Set(query, queryEmbeddings);
|
||||||
|
}
|
||||||
|
else // embeddingCache already has an entry for this query, so the missing model-embedding pairs have to be filled in
|
||||||
|
{
|
||||||
|
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding
|
||||||
{
|
{
|
||||||
embeddingCache.Add(query, queryEmbeddings);
|
if (!embeddingCache.TryGet(kvp.Key, out var _))
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings)
|
|
||||||
{
|
{
|
||||||
if (!embeddingCache.ContainsKey(kvp.Key))
|
embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value;
|
||||||
{
|
|
||||||
embeddingCache[query][kvp.Key] = kvp.Value;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ using System.Data.Common;
|
|||||||
using Server.Migrations;
|
using Server.Migrations;
|
||||||
using Server.Helper;
|
using Server.Helper;
|
||||||
using Server.Exceptions;
|
using Server.Exceptions;
|
||||||
|
using AdaptiveExpressions;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -16,7 +17,8 @@ public class SearchdomainManager
|
|||||||
private readonly string connectionString;
|
private readonly string connectionString;
|
||||||
private MySqlConnection connection;
|
private MySqlConnection connection;
|
||||||
public SQLHelper helper;
|
public SQLHelper helper;
|
||||||
public Dictionary<string, Dictionary<string, float[]>> embeddingCache;
|
public LRUCache<string, Dictionary<string, float[]>> embeddingCache;
|
||||||
|
public int EmbeddingCacheMaxCount;
|
||||||
|
|
||||||
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config, AIProvider aIProvider, DatabaseHelper databaseHelper)
|
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config, AIProvider aIProvider, DatabaseHelper databaseHelper)
|
||||||
{
|
{
|
||||||
@@ -24,7 +26,8 @@ public class SearchdomainManager
|
|||||||
_config = config;
|
_config = config;
|
||||||
this.aIProvider = aIProvider;
|
this.aIProvider = aIProvider;
|
||||||
_databaseHelper = databaseHelper;
|
_databaseHelper = databaseHelper;
|
||||||
embeddingCache = [];
|
EmbeddingCacheMaxCount = config.GetValue<int?>("Embeddingsearch:EmbeddingCacheMaxCount") ?? 1000000;
|
||||||
|
embeddingCache = new(EmbeddingCacheMaxCount);
|
||||||
connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? "";
|
connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? "";
|
||||||
connection = new MySqlConnection(connectionString);
|
connection = new MySqlConnection(connectionString);
|
||||||
connection.Open();
|
connection.Open();
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
<PackageReference Include="AdaptiveExpressions" Version="4.23.0" />
|
||||||
<PackageReference Include="ElmahCore" Version="2.1.2" />
|
<PackageReference Include="ElmahCore" Version="2.1.2" />
|
||||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
|
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
|
||||||
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
|
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
|
||||||
|
|||||||
@@ -24,6 +24,7 @@
|
|||||||
"172.17.0.1"
|
"172.17.0.1"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"EmbeddingCacheMaxCount": 5,
|
||||||
"AiProviders": {
|
"AiProviders": {
|
||||||
"ollama": {
|
"ollama": {
|
||||||
"handler": "ollama",
|
"handler": "ollama",
|
||||||
@@ -35,6 +36,15 @@
|
|||||||
"ApiKey": "Some API key here"
|
"ApiKey": "Some API key here"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"SimpleAuth": {
|
||||||
|
"Users": [
|
||||||
|
{
|
||||||
|
"Username": "admin",
|
||||||
|
"Password": "UnsafePractice.67",
|
||||||
|
"Roles": ["Admin"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
"ApiKeys": ["Some UUID here", "Another UUID here"],
|
"ApiKeys": ["Some UUID here", "Another UUID here"],
|
||||||
"UseHttpsRedirection": true
|
"UseHttpsRedirection": true
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user