diff --git a/src/Indexer/Scripts/example.py b/src/Indexer/Scripts/example.py index 6a65b48..8a0dacf 100644 --- a/src/Indexer/Scripts/example.py +++ b/src/Indexer/Scripts/example.py @@ -5,7 +5,7 @@ from dataclasses import asdict import time example_content = "./Scripts/example_content" -probmethod = "DictionaryWeightedAverage" +probmethod = "LVEWAvg" example_searchdomain = "example_" + probmethod example_counter = 0 models = ["bge-m3", "mxbai-embed-large"] diff --git a/src/Server/Controllers/EntityController.cs b/src/Server/Controllers/EntityController.cs index 62336d6..ecdb52e 100644 --- a/src/Server/Controllers/EntityController.cs +++ b/src/Server/Controllers/EntityController.cs @@ -48,6 +48,7 @@ public class EntityController : ControllerBase _domainManager.embeddingCache, _domainManager.client, _domainManager.helper, + _logger, JsonSerializer.Serialize(jsonEntities)); if (entities is not null && jsonEntities is not null) { @@ -58,9 +59,9 @@ public class EntityController : ControllerBase if (entities.Select(x => x.name == jsonEntityName).Any() && !invalidatedSearchdomains.Contains(jsonEntityName)) { - string jsonEntitySearchdomain = jsonEntity.Searchdomain; - invalidatedSearchdomains.Add(jsonEntitySearchdomain); - _domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomain); + string jsonEntitySearchdomainName = jsonEntity.Searchdomain; + invalidatedSearchdomains.Add(jsonEntitySearchdomainName); + _domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName); } } return Ok(new EntityIndexResult() { Success = true }); @@ -103,11 +104,11 @@ public class EntityController : ControllerBase { embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2}); } - datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = embeddingResults}); + datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults}); } else { - datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = null}); + datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null}); } } EntityListResult entityListResult = new() diff --git a/src/Server/Datapoint.cs b/src/Server/Datapoint.cs index 4d7eaf9..e7e6ccf 100644 --- a/src/Server/Datapoint.cs +++ b/src/Server/Datapoint.cs @@ -12,11 +12,11 @@ namespace Server; public class Datapoint { public string name; - public Probmethods.probMethodDelegate probMethod; + public ProbMethod probMethod; public List<(string, float[])> embeddings; public string hash; - public Datapoint(string name, Probmethods.probMethodDelegate probMethod, string hash, List<(string, float[])> embeddings) + public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings) { this.name = name; this.probMethod = probMethod; @@ -24,21 +24,9 @@ public class Datapoint this.embeddings = embeddings; } - // public Datapoint(string name, Probmethods.probMethodDelegate probmethod, string content, List models, OllamaApiClient ollama) - // { - // this.name = name; - // this.probMethod = probmethod; - // embeddings = GenerateEmbeddings(content, models, ollama); - // } - - // public float CalcProbability() - // { - // return probMethod(embeddings); // <--- prob method is not used with the embeddings! - // } - public float CalcProbability(List<(string, float)> probabilities) { - return probMethod(probabilities); + return probMethod.method(probabilities); } public static Dictionary GenerateEmbeddings(string content, List models, OllamaApiClient ollama) @@ -106,11 +94,10 @@ public class Datapoint Input = [content] }; - var response = ollama.GenerateEmbeddingAsync(content, new EmbeddingGenerationOptions(){ModelId=model}).Result; + var response = ollama.EmbedAsync(request).Result; if (response is not null) { - float[] var = new float[response.Vector.Length]; - response.Vector.CopyTo(var); + float[] var = [.. response.Embeddings.First()]; retVal[model] = var; if (!embeddingCache.ContainsKey(model)) { diff --git a/src/Server/Helper/SQLHelper.cs b/src/Server/Helper/SQLHelper.cs index 3e842cb..f4a512c 100644 --- a/src/Server/Helper/SQLHelper.cs +++ b/src/Server/Helper/SQLHelper.cs @@ -3,7 +3,7 @@ using MySql.Data.MySqlClient; namespace Server; -public class SQLHelper +public class SQLHelper:IDisposable { public MySqlConnection connection; public string connectionString; @@ -19,6 +19,12 @@ public class SQLHelper return new SQLHelper(newConnection, connectionString); } + public void Dispose() + { + connection.Close(); + GC.SuppressFinalize(this); + } + public DbDataReader ExecuteSQLCommand(string query, Dictionary parameters) { lock (connection) diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 1486d78..14e7467 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -41,7 +41,7 @@ public static class SearchdomainHelper return null; } - public static List? EntitiesFromJSON(List entityCache, Dictionary> embeddingCache, OllamaApiClient ollama, SQLHelper helper, string json) + public static List? EntitiesFromJSON(List entityCache, Dictionary> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json) { List? jsonEntities = JsonSerializer.Deserialize>(json); if (jsonEntities is null) @@ -67,8 +67,8 @@ public static class SearchdomainHelper ConcurrentQueue retVal = []; Parallel.ForEach(jsonEntities, jSONEntity => { - var tempHelper = helper.DuplicateConnection(); - var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, jSONEntity); + using var tempHelper = helper.DuplicateConnection(); + var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity); if (entity is not null) { retVal.Enqueue(entity); @@ -77,7 +77,7 @@ public static class SearchdomainHelper return [.. retVal]; } - public static Entity? EntityFromJSON(List entityCache, Dictionary> embeddingCache, OllamaApiClient ollama, SQLHelper helper, JSONEntity jsonEntity) //string json) + public static Entity? EntityFromJSON(List entityCache, Dictionary> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json) { Dictionary> embeddingsLUT = []; int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); @@ -123,7 +123,7 @@ public static class SearchdomainHelper { embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache); } - var probMethod_embedding = Probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); + var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries List<(string model, byte[] embedding)> data = []; diff --git a/src/Server/Probmethods.cs b/src/Server/Probmethods.cs index eaa00fc..054d543 100644 --- a/src/Server/Probmethods.cs +++ b/src/Server/Probmethods.cs @@ -3,6 +3,24 @@ using System.Text.Json; namespace Server; +public class ProbMethod +{ + public Probmethods.probMethodDelegate method; + public string name; + + public ProbMethod(string name, ILogger logger) + { + this.name = name; + Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name); + if (probMethod is null) + { + logger.LogError("Unable to retrieve probMethod {name}", [name]); + throw new Exception("Unable to retrieve probMethod"); + } + method = probMethod; + } +} + public static class Probmethods { public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters); diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index 30260ea..e3aa9b3 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -38,16 +38,18 @@ public class Searchdomain public int embeddingCacheMaxSize = 10000000; private readonly MySqlConnection connection; public SQLHelper helper; + private readonly ILogger _logger; // TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain() - public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary> embeddingCache, string provider = "sqlserver", bool runEmpty = false) + public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) { _connectionString = connectionString; _provider = provider.ToLower(); this.searchdomain = searchdomain; this.ollama = ollama; this.embeddingCache = embeddingCache; + this._logger = logger; searchCache = []; entityCache = []; connection = new MySqlConnection(connectionString); @@ -57,12 +59,13 @@ public class Searchdomain if (!runEmpty) { GetID(); - UpdateSearchDomain(); + UpdateEntityCache(); } } - public void UpdateSearchDomain() + public void UpdateEntityCache() { + entityCache = []; Dictionary parametersIDSearchdomain = new() { ["id"] = this.id @@ -99,7 +102,7 @@ public class Searchdomain string name = datapointReader.GetString(2); string probmethodString = datapointReader.GetString(3); string hash = datapointReader.GetString(4); - Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString); + ProbMethod probmethod = new(probmethodString, _logger); if (embedding_unassigned.TryGetValue(id, out Dictionary? embeddings) && probmethod is not null) { embedding_unassigned.Remove(id); @@ -179,7 +182,7 @@ public class Searchdomain float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2); list.Add((key, value)); } - datapointProbs.Add((datapoint.name, datapoint.probMethod(list))); + datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list))); } result.Add((entity.probMethod(datapointProbs), entity.name)); } diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index a6c2878..08e8532 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -53,7 +53,7 @@ public class SearchdomainManager } try { - return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache)); + return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger)); } catch (MySqlException) { @@ -62,9 +62,12 @@ public class SearchdomainManager } } - public void InvalidateSearchdomainCache(string searchdomain) + public void InvalidateSearchdomainCache(string searchdomainName) { - searchdomains.Remove(searchdomain); + if (searchdomains.TryGetValue(searchdomainName, out var searchdomain)) + { + searchdomain.UpdateEntityCache(); + } } public List ListSearchdomains() diff --git a/src/Server/Server.csproj b/src/Server/Server.csproj index ffb8eb7..ed6df4d 100644 --- a/src/Server/Server.csproj +++ b/src/Server/Server.csproj @@ -15,7 +15,7 @@ - +