diff --git a/src/Server/Controllers/EntityController.cs b/src/Server/Controllers/EntityController.cs index ecdb52e..d05a9e7 100644 --- a/src/Server/Controllers/EntityController.cs +++ b/src/Server/Controllers/EntityController.cs @@ -56,10 +56,10 @@ public class EntityController : ControllerBase foreach (var jsonEntity in jsonEntities) { string jsonEntityName = jsonEntity.Name; + string jsonEntitySearchdomainName = jsonEntity.Searchdomain; if (entities.Select(x => x.name == jsonEntityName).Any() - && !invalidatedSearchdomains.Contains(jsonEntityName)) + && !invalidatedSearchdomains.Contains(jsonEntitySearchdomainName)) { - string jsonEntitySearchdomainName = jsonEntity.Searchdomain; invalidatedSearchdomains.Add(jsonEntitySearchdomainName); _domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName); } diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 14e7467..508b254 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -118,8 +118,24 @@ public static class SearchdomainHelper foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints) { string hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); - Dictionary embeddings = embeddingsLUT.ContainsKey(hash) ? embeddingsLUT[hash] : []; - if (embeddings.Count == 0) + Dictionary embeddings = []; + if (embeddingsLUT.ContainsKey(hash)) + { + Dictionary hashLUT = embeddingsLUT[hash]; + foreach (string model in jsonDatapoint.Model) + { + if (hashLUT.ContainsKey(model)) + { + embeddings.Add(model, hashLUT[model]); + } + else + { + var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], ollama, embeddingCache); + embeddings.Add(model, additionalEmbeddings.First().Value); + } + } + } + else { embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache); } diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index e3aa9b3..0a5b84d 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -65,7 +65,6 @@ public class Searchdomain public void UpdateEntityCache() { - entityCache = []; Dictionary parametersIDSearchdomain = new() { ["id"] = this.id @@ -92,7 +91,7 @@ public class Searchdomain } } embeddingReader.Close(); - + DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, hash FROM datapoint", parametersIDSearchdomain); Dictionary> datapoint_unassigned = []; while (datapointReader.Read()) @@ -115,7 +114,7 @@ public class Searchdomain } datapointReader.Close(); - DbDataReader attributeReader = helper.ExecuteSQLCommand("SELECT id, id_entity, attribute, value FROM attribute", parametersIDSearchdomain); + DbDataReader attributeReader = helper.ExecuteSQLCommand("SELECT id, id_entity, attribute, value FROM attribute", parametersIDSearchdomain); Dictionary> attributes_unassigned = []; while (attributeReader.Read()) { @@ -132,6 +131,7 @@ public class Searchdomain } attributeReader.Close(); + entityCache = []; DbDataReader entityReader = helper.ExecuteSQLCommand("SELECT entity.id, name, probmethod FROM entity WHERE id_searchdomain=@id", parametersIDSearchdomain); while (entityReader.Read()) { @@ -155,6 +155,7 @@ public class Searchdomain } entityReader.Close(); modelsInUse = GetModels(entityCache); + embeddingCache = []; // TODO remove this and implement proper remediation to improve performance } public List<(float, string)> Search(string query, bool sort=true) @@ -166,7 +167,7 @@ public class Searchdomain { // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries embeddingCache.Add(query, queryEmbeddings); } - } + } // TODO implement proper cache remediation for embeddingCache here List<(float, string)> result = []; diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index 08e8532..eca4e6e 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -64,10 +64,7 @@ public class SearchdomainManager public void InvalidateSearchdomainCache(string searchdomainName) { - if (searchdomains.TryGetValue(searchdomainName, out var searchdomain)) - { - searchdomain.UpdateEntityCache(); - } + GetSearchdomain(searchdomainName).UpdateEntityCache(); } public List ListSearchdomains()