Merge pull request #57 from LD-Reborn/54-properly-implement-embeddings-cache-size-limit-global

Implemented cache reconciliation
This commit is contained in:
LD50
2025-12-28 00:22:15 +01:00
committed by GitHub
4 changed files with 100 additions and 34 deletions

View File

@@ -44,7 +44,6 @@ public class EntityController : ControllerBase
&& !invalidatedSearchdomains.Contains(jsonEntitySearchdomainName)) && !invalidatedSearchdomains.Contains(jsonEntitySearchdomainName))
{ {
invalidatedSearchdomains.Add(jsonEntitySearchdomainName); invalidatedSearchdomains.Add(jsonEntitySearchdomainName);
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName);
} }
} }
return Ok(new EntityIndexResult() { Success = true }); return Ok(new EntityIndexResult() { Success = true });
@@ -122,6 +121,7 @@ public class EntityController : ControllerBase
_logger.LogError("Unable to delete the entity {entityName} in {searchdomain} - it was not found under the specified name", [entityName, searchdomain]); _logger.LogError("Unable to delete the entity {entityName} in {searchdomain} - it was not found under the specified name", [entityName, searchdomain]);
return Ok(new EntityDeleteResults() {Success = false, Message = "Entity not found"}); return Ok(new EntityDeleteResults() {Success = false, Message = "Entity not found"});
} }
searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_);
_databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain); _databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain);
searchdomain_.entityCache.RemoveAll(entity => entity.name == entityName); searchdomain_.entityCache.RemoveAll(entity => entity.name == entityName);
return Ok(new EntityDeleteResults() {Success = true}); return Ok(new EntityDeleteResults() {Success = true});

View File

@@ -94,7 +94,8 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
AIProvider aIProvider = searchdomain.aIProvider; AIProvider aIProvider = searchdomain.aIProvider;
LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; LRUCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
bool invalidateSearchCache = false;
if (preexistingEntity is not null) if (preexistingEntity is not null)
{ {
int? preexistingEntityID = _databaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); int? preexistingEntityID = _databaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
@@ -162,6 +163,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
helper.ExecuteSQLNonQuery("DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", parameters); helper.ExecuteSQLNonQuery("DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", parameters);
helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity=@entityId AND name=@datapointName", parameters); helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity=@entityId AND name=@datapointName", parameters);
preexistingEntity.datapoints.Remove(datapoint); preexistingEntity.datapoints.Remove(datapoint);
invalidateSearchCache = true;
} else } else
{ {
JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name); JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name);
@@ -178,7 +180,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
preexistingEntity.datapoints.Remove(datapoint); preexistingEntity.datapoints.Remove(datapoint);
Datapoint newDatapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, newEntityDatapoint, (int)preexistingEntityID); Datapoint newDatapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, newEntityDatapoint, (int)preexistingEntityID);
preexistingEntity.datapoints.Add(newDatapoint); preexistingEntity.datapoints.Add(newDatapoint);
invalidateSearchCache = true;
} }
if (newEntityDatapoint is not null && (newEntityDatapoint.Probmethod_embedding != datapoint.probMethod.probMethodEnum || newEntityDatapoint.SimilarityMethod != datapoint.similarityMethod.similarityMethodEnum)) if (newEntityDatapoint is not null && (newEntityDatapoint.Probmethod_embedding != datapoint.probMethod.probMethodEnum || newEntityDatapoint.SimilarityMethod != datapoint.similarityMethod.similarityMethodEnum))
{ {
@@ -194,6 +196,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
Datapoint preexistingDatapoint = preexistingEntity.datapoints.First(x => x == datapoint); // The for loop is a copy. This retrieves the original such that it can be updated. Datapoint preexistingDatapoint = preexistingEntity.datapoints.First(x => x == datapoint); // The for loop is a copy. This retrieves the original such that it can be updated.
preexistingDatapoint.probMethod = datapoint.probMethod; preexistingDatapoint.probMethod = datapoint.probMethod;
preexistingDatapoint.similarityMethod = datapoint.similarityMethod; preexistingDatapoint.similarityMethod = datapoint.similarityMethod;
invalidateSearchCache = true;
} }
} }
} }
@@ -205,10 +208,14 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
// Datapoint - New // Datapoint - New
Datapoint datapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, jsonDatapoint, (int)preexistingEntityID); Datapoint datapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, jsonDatapoint, (int)preexistingEntityID);
preexistingEntity.datapoints.Add(datapoint); preexistingEntity.datapoints.Add(datapoint);
invalidateSearchCache = true;
} }
} }
if (invalidateSearchCache)
{
searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity);
}
return preexistingEntity; return preexistingEntity;
} }
else else
@@ -233,6 +240,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
id = id_entity id = id_entity
}; };
entityCache.Add(entity); entityCache.Add(entity);
searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(entity);
return entity; return entity;
} }
} }

View File

@@ -169,9 +169,33 @@ public class Searchdomain
return [.. cachedResult.Results.Select(r => (r.Score, r.Name))]; return [.. cachedResult.Results.Select(r => (r.Score, r.Name))];
} }
bool hasQuery = embeddingCache.TryGet(query, out Dictionary<string, float[]>? queryEmbeddings); Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query);
List<(float, string)> result = [];
foreach (Entity entity in entityCache)
{
result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.name));
}
IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1);
if (topN is not null)
{
sortedResults = sortedResults.Take(topN ?? 0);
}
List<(float, string)> results = [.. sortedResults];
List<ResultItem> searchResult = new(
[.. sortedResults.Select(r =>
new ResultItem(r.Item1, r.Item2 ))]
);
searchCache[query] = new DateTimedSearchResult(DateTime.Now, searchResult);
return results;
}
public Dictionary<string, float[]> GetQueryEmbeddings(string query)
{
bool hasQuery = embeddingCache.TryGet(query, out Dictionary<string, float[]> queryEmbeddings);
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model)); bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
if (!(hasQuery && allModelsInQuery)) if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null)
{ {
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache); queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery)) if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery))
@@ -189,38 +213,25 @@ public class Searchdomain
} }
} }
} }
return queryEmbeddings;
}
List<(float, string)> result = []; private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary<string, float[]> queryEmbeddings)
{
foreach (Entity entity in entityCache) List<(string, float)> datapointProbs = [];
foreach (Datapoint datapoint in entity.datapoints)
{ {
List<(string, float)> datapointProbs = []; SimilarityMethod similarityMethod = datapoint.similarityMethod;
foreach (Datapoint datapoint in entity.datapoints) List<(string, float)> list = [];
foreach ((string, float[]) embedding in datapoint.embeddings)
{ {
SimilarityMethod similarityMethod = datapoint.similarityMethod; string key = embedding.Item1;
List<(string, float)> list = []; float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2);
foreach ((string, float[]) embedding in datapoint.embeddings) list.Add((key, value));
{
string key = embedding.Item1;
float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2);
list.Add((key, value));
}
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
} }
result.Add((entity.probMethod(datapointProbs), entity.name)); datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
} }
IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1); return entity.probMethod(datapointProbs);
if (topN is not null)
{
sortedResults = sortedResults.Take(topN ?? 0);
}
List<(float, string)> results = [.. sortedResults];
List<ResultItem> searchResult = new(
[.. sortedResults.Select(r =>
new ResultItem(r.Item1, r.Item2 ))]
);
searchCache[query] = new DateTimedSearchResult(DateTime.Now, searchResult);
return results;
} }
public static List<string> GetModels(List<Entity> entities) public static List<string> GetModels(List<Entity> entities)
@@ -269,6 +280,53 @@ public class Searchdomain
return JsonSerializer.Deserialize<SearchdomainSettings>(settingsString); return JsonSerializer.Deserialize<SearchdomainSettings>(settingsString);
} }
public void ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(Entity entity)
{
if (settings.CacheReconciliation)
{
foreach (KeyValuePair<string, DateTimedSearchResult> element in searchCache)
{
string query = element.Key;
DateTimedSearchResult searchResult = element.Value;
Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query);
float evaluationResult = EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings);
searchResult.Results.RemoveAll(x => x.Name == entity.name); // If entity already exists in that results list: remove it.
ResultItem newItem = new(evaluationResult, entity.name);
int index = searchResult.Results.BinarySearch(
newItem,
Comparer<ResultItem>.Create((a, b) => b.Score.CompareTo(a.Score)) // Invert searching order
);
if (index < 0) // If not found, BinarySearch gives the bitwise complement
index = ~index;
searchResult.Results.Insert(index, newItem);
}
}
else
{
InvalidateSearchCache();
}
}
public void ReconciliateOrInvalidateCacheForDeletedEntity(Entity entity)
{
if (settings.CacheReconciliation)
{
foreach (KeyValuePair<string, DateTimedSearchResult> element in searchCache)
{
string query = element.Key;
DateTimedSearchResult searchResult = element.Value;
searchResult.Results.RemoveAll(x => x.Name == entity.name);
}
}
else
{
InvalidateSearchCache();
}
}
public void InvalidateSearchCache() public void InvalidateSearchCache()
{ {
searchCache = []; searchCache = [];

View File

@@ -69,7 +69,7 @@ public class SearchdomainManager
{ {
var searchdomain = GetSearchdomain(searchdomainName); var searchdomain = GetSearchdomain(searchdomainName);
searchdomain.UpdateEntityCache(); searchdomain.UpdateEntityCache();
searchdomain.InvalidateSearchCache(); // TODO implement cache remediation (Suggestion: searchdomain-wide setting for cache remediation / invalidation - ) searchdomain.InvalidateSearchCache();
} }
public List<string> ListSearchdomains() public List<string> ListSearchdomains()