diff --git a/src/Server/Controllers/EntityController.cs b/src/Server/Controllers/EntityController.cs index 2698f6a..0281c96 100644 --- a/src/Server/Controllers/EntityController.cs +++ b/src/Server/Controllers/EntityController.cs @@ -44,7 +44,6 @@ public class EntityController : ControllerBase && !invalidatedSearchdomains.Contains(jsonEntitySearchdomainName)) { invalidatedSearchdomains.Add(jsonEntitySearchdomainName); - _domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName); } } return Ok(new EntityIndexResult() { Success = true }); @@ -122,6 +121,7 @@ public class EntityController : ControllerBase _logger.LogError("Unable to delete the entity {entityName} in {searchdomain} - it was not found under the specified name", [entityName, searchdomain]); return Ok(new EntityDeleteResults() {Success = false, Message = "Entity not found"}); } + searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_); _databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain); searchdomain_.entityCache.RemoveAll(entity => entity.name == entityName); return Ok(new EntityDeleteResults() {Success = true}); diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 559e29b..4012689 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -94,7 +94,8 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp AIProvider aIProvider = searchdomain.aIProvider; LRUCache> embeddingCache = searchdomain.embeddingCache; Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); - + bool invalidateSearchCache = false; + if (preexistingEntity is not null) { int? preexistingEntityID = _databaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); @@ -162,6 +163,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp helper.ExecuteSQLNonQuery("DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", parameters); helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity=@entityId AND name=@datapointName", parameters); preexistingEntity.datapoints.Remove(datapoint); + invalidateSearchCache = true; } else { JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name); @@ -178,7 +180,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp preexistingEntity.datapoints.Remove(datapoint); Datapoint newDatapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, newEntityDatapoint, (int)preexistingEntityID); preexistingEntity.datapoints.Add(newDatapoint); - + invalidateSearchCache = true; } if (newEntityDatapoint is not null && (newEntityDatapoint.Probmethod_embedding != datapoint.probMethod.probMethodEnum || newEntityDatapoint.SimilarityMethod != datapoint.similarityMethod.similarityMethodEnum)) { @@ -194,6 +196,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp Datapoint preexistingDatapoint = preexistingEntity.datapoints.First(x => x == datapoint); // The for loop is a copy. This retrieves the original such that it can be updated. preexistingDatapoint.probMethod = datapoint.probMethod; preexistingDatapoint.similarityMethod = datapoint.similarityMethod; + invalidateSearchCache = true; } } } @@ -205,10 +208,14 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp // Datapoint - New Datapoint datapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, jsonDatapoint, (int)preexistingEntityID); preexistingEntity.datapoints.Add(datapoint); + invalidateSearchCache = true; } } - + if (invalidateSearchCache) + { + searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity); + } return preexistingEntity; } else @@ -233,6 +240,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp id = id_entity }; entityCache.Add(entity); + searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(entity); return entity; } } diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index 81502cd..421fe8d 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -169,9 +169,33 @@ public class Searchdomain return [.. cachedResult.Results.Select(r => (r.Score, r.Name))]; } - bool hasQuery = embeddingCache.TryGet(query, out Dictionary? queryEmbeddings); + Dictionary queryEmbeddings = GetQueryEmbeddings(query); + + List<(float, string)> result = []; + + foreach (Entity entity in entityCache) + { + result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.name)); + } + IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1); + if (topN is not null) + { + sortedResults = sortedResults.Take(topN ?? 0); + } + List<(float, string)> results = [.. sortedResults]; + List searchResult = new( + [.. sortedResults.Select(r => + new ResultItem(r.Item1, r.Item2 ))] + ); + searchCache[query] = new DateTimedSearchResult(DateTime.Now, searchResult); + return results; + } + + public Dictionary GetQueryEmbeddings(string query) + { + bool hasQuery = embeddingCache.TryGet(query, out Dictionary queryEmbeddings); bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model)); - if (!(hasQuery && allModelsInQuery)) + if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null) { queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider, embeddingCache); if (!embeddingCache.TryGet(query, out var embeddingCacheForCurrentQuery)) @@ -189,38 +213,25 @@ public class Searchdomain } } } + return queryEmbeddings; + } - List<(float, string)> result = []; - - foreach (Entity entity in entityCache) + private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary queryEmbeddings) + { + List<(string, float)> datapointProbs = []; + foreach (Datapoint datapoint in entity.datapoints) { - List<(string, float)> datapointProbs = []; - foreach (Datapoint datapoint in entity.datapoints) + SimilarityMethod similarityMethod = datapoint.similarityMethod; + List<(string, float)> list = []; + foreach ((string, float[]) embedding in datapoint.embeddings) { - SimilarityMethod similarityMethod = datapoint.similarityMethod; - List<(string, float)> list = []; - foreach ((string, float[]) embedding in datapoint.embeddings) - { - string key = embedding.Item1; - float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2); - list.Add((key, value)); - } - datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list))); + string key = embedding.Item1; + float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2); + list.Add((key, value)); } - result.Add((entity.probMethod(datapointProbs), entity.name)); + datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list))); } - IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1); - if (topN is not null) - { - sortedResults = sortedResults.Take(topN ?? 0); - } - List<(float, string)> results = [.. sortedResults]; - List searchResult = new( - [.. sortedResults.Select(r => - new ResultItem(r.Item1, r.Item2 ))] - ); - searchCache[query] = new DateTimedSearchResult(DateTime.Now, searchResult); - return results; + return entity.probMethod(datapointProbs); } public static List GetModels(List entities) @@ -269,6 +280,53 @@ public class Searchdomain return JsonSerializer.Deserialize(settingsString); } + public void ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(Entity entity) + { + if (settings.CacheReconciliation) + { + foreach (KeyValuePair element in searchCache) + { + string query = element.Key; + DateTimedSearchResult searchResult = element.Value; + + Dictionary queryEmbeddings = GetQueryEmbeddings(query); + float evaluationResult = EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings); + + searchResult.Results.RemoveAll(x => x.Name == entity.name); // If entity already exists in that results list: remove it. + + ResultItem newItem = new(evaluationResult, entity.name); + int index = searchResult.Results.BinarySearch( + newItem, + Comparer.Create((a, b) => b.Score.CompareTo(a.Score)) // Invert searching order + ); + if (index < 0) // If not found, BinarySearch gives the bitwise complement + index = ~index; + searchResult.Results.Insert(index, newItem); + } + } + else + { + InvalidateSearchCache(); + } + } + + public void ReconciliateOrInvalidateCacheForDeletedEntity(Entity entity) + { + if (settings.CacheReconciliation) + { + foreach (KeyValuePair element in searchCache) + { + string query = element.Key; + DateTimedSearchResult searchResult = element.Value; + searchResult.Results.RemoveAll(x => x.Name == entity.name); + } + } + else + { + InvalidateSearchCache(); + } + } + public void InvalidateSearchCache() { searchCache = []; diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index c024c2c..81d1edb 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -69,7 +69,7 @@ public class SearchdomainManager { var searchdomain = GetSearchdomain(searchdomainName); searchdomain.UpdateEntityCache(); - searchdomain.InvalidateSearchCache(); // TODO implement cache remediation (Suggestion: searchdomain-wide setting for cache remediation / invalidation - ) + searchdomain.InvalidateSearchCache(); } public List ListSearchdomains()