Fixed broken ProbMethod in Entity/Index, fixed undisposed MySQL connections, dependency update, improved logging, improved cache invalidation

This commit is contained in:
2025-06-22 15:35:45 +02:00
parent 2034a20055
commit de7e145b89
9 changed files with 57 additions and 39 deletions

View File

@@ -5,7 +5,7 @@ from dataclasses import asdict
import time import time
example_content = "./Scripts/example_content" example_content = "./Scripts/example_content"
probmethod = "DictionaryWeightedAverage" probmethod = "LVEWAvg"
example_searchdomain = "example_" + probmethod example_searchdomain = "example_" + probmethod
example_counter = 0 example_counter = 0
models = ["bge-m3", "mxbai-embed-large"] models = ["bge-m3", "mxbai-embed-large"]

View File

@@ -48,6 +48,7 @@ public class EntityController : ControllerBase
_domainManager.embeddingCache, _domainManager.embeddingCache,
_domainManager.client, _domainManager.client,
_domainManager.helper, _domainManager.helper,
_logger,
JsonSerializer.Serialize(jsonEntities)); JsonSerializer.Serialize(jsonEntities));
if (entities is not null && jsonEntities is not null) if (entities is not null && jsonEntities is not null)
{ {
@@ -58,9 +59,9 @@ public class EntityController : ControllerBase
if (entities.Select(x => x.name == jsonEntityName).Any() if (entities.Select(x => x.name == jsonEntityName).Any()
&& !invalidatedSearchdomains.Contains(jsonEntityName)) && !invalidatedSearchdomains.Contains(jsonEntityName))
{ {
string jsonEntitySearchdomain = jsonEntity.Searchdomain; string jsonEntitySearchdomainName = jsonEntity.Searchdomain;
invalidatedSearchdomains.Add(jsonEntitySearchdomain); invalidatedSearchdomains.Add(jsonEntitySearchdomainName);
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomain); _domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName);
} }
} }
return Ok(new EntityIndexResult() { Success = true }); return Ok(new EntityIndexResult() { Success = true });
@@ -103,11 +104,11 @@ public class EntityController : ControllerBase
{ {
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2}); embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
} }
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = embeddingResults}); datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults});
} }
else else
{ {
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = null}); datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null});
} }
} }
EntityListResult entityListResult = new() EntityListResult entityListResult = new()

View File

@@ -12,11 +12,11 @@ namespace Server;
public class Datapoint public class Datapoint
{ {
public string name; public string name;
public Probmethods.probMethodDelegate probMethod; public ProbMethod probMethod;
public List<(string, float[])> embeddings; public List<(string, float[])> embeddings;
public string hash; public string hash;
public Datapoint(string name, Probmethods.probMethodDelegate probMethod, string hash, List<(string, float[])> embeddings) public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings)
{ {
this.name = name; this.name = name;
this.probMethod = probMethod; this.probMethod = probMethod;
@@ -24,21 +24,9 @@ public class Datapoint
this.embeddings = embeddings; this.embeddings = embeddings;
} }
// public Datapoint(string name, Probmethods.probMethodDelegate probmethod, string content, List<string> models, OllamaApiClient ollama)
// {
// this.name = name;
// this.probMethod = probmethod;
// embeddings = GenerateEmbeddings(content, models, ollama);
// }
// public float CalcProbability()
// {
// return probMethod(embeddings); // <--- prob method is not used with the embeddings!
// }
public float CalcProbability(List<(string, float)> probabilities) public float CalcProbability(List<(string, float)> probabilities)
{ {
return probMethod(probabilities); return probMethod.method(probabilities);
} }
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama) public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
@@ -106,11 +94,10 @@ public class Datapoint
Input = [content] Input = [content]
}; };
var response = ollama.GenerateEmbeddingAsync(content, new EmbeddingGenerationOptions(){ModelId=model}).Result; var response = ollama.EmbedAsync(request).Result;
if (response is not null) if (response is not null)
{ {
float[] var = new float[response.Vector.Length]; float[] var = [.. response.Embeddings.First()];
response.Vector.CopyTo(var);
retVal[model] = var; retVal[model] = var;
if (!embeddingCache.ContainsKey(model)) if (!embeddingCache.ContainsKey(model))
{ {

View File

@@ -3,7 +3,7 @@ using MySql.Data.MySqlClient;
namespace Server; namespace Server;
public class SQLHelper public class SQLHelper:IDisposable
{ {
public MySqlConnection connection; public MySqlConnection connection;
public string connectionString; public string connectionString;
@@ -19,6 +19,12 @@ public class SQLHelper
return new SQLHelper(newConnection, connectionString); return new SQLHelper(newConnection, connectionString);
} }
/// <summary>
/// Releases the MySQL connection held by this helper.
/// </summary>
/// <remarks>
/// Disposing the connection (rather than only closing it) also closes it if it is
/// still open. The null-conditional keeps <c>Dispose</c> from throwing when the
/// public <c>connection</c> field was never assigned or has been cleared.
/// </remarks>
public void Dispose()
{
    connection?.Dispose();
    // No finalizer work is left to do once the connection has been released.
    GC.SuppressFinalize(this);
}
public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters) public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters)
{ {
lock (connection) lock (connection)

View File

@@ -41,7 +41,7 @@ public static class SearchdomainHelper
return null; return null;
} }
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, string json) public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json)
{ {
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json); List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
if (jsonEntities is null) if (jsonEntities is null)
@@ -67,8 +67,8 @@ public static class SearchdomainHelper
ConcurrentQueue<Entity> retVal = []; ConcurrentQueue<Entity> retVal = [];
Parallel.ForEach(jsonEntities, jSONEntity => Parallel.ForEach(jsonEntities, jSONEntity =>
{ {
var tempHelper = helper.DuplicateConnection(); using var tempHelper = helper.DuplicateConnection();
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, jSONEntity); var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity);
if (entity is not null) if (entity is not null)
{ {
retVal.Enqueue(entity); retVal.Enqueue(entity);
@@ -77,7 +77,7 @@ public static class SearchdomainHelper
return [.. retVal]; return [.. retVal];
} }
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, JSONEntity jsonEntity) //string json) public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
{ {
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = []; Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
@@ -123,7 +123,7 @@ public static class SearchdomainHelper
{ {
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache); embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
} }
var probMethod_embedding = Probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
List<(string model, byte[] embedding)> data = []; List<(string model, byte[] embedding)> data = [];

View File

@@ -3,6 +3,24 @@ using System.Text.Json;
namespace Server; namespace Server;
/// <summary>
/// Couples a probability-method name with the delegate that
/// <c>Probmethods.GetMethod</c> resolves for that name, so callers can both invoke
/// the method and report the name it was created from.
/// </summary>
public class ProbMethod
{
    /// <summary>Delegate returned by <c>Probmethods.GetMethod</c> for <see cref="name"/>.</summary>
    public Probmethods.probMethodDelegate method;

    /// <summary>The name passed to the constructor, stored unchanged.</summary>
    public string name;

    /// <summary>
    /// Resolves <paramref name="name"/> to a probability-method delegate.
    /// </summary>
    /// <param name="name">Name handed to <c>Probmethods.GetMethod</c>.</param>
    /// <param name="logger">Logger that receives an error entry when the lookup fails.</param>
    /// <exception cref="ArgumentException">
    /// <c>Probmethods.GetMethod</c> returned <c>null</c> for <paramref name="name"/>.
    /// </exception>
    public ProbMethod(string name, ILogger logger)
    {
        this.name = name;
        Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name);
        if (probMethod is null)
        {
            logger.LogError("Unable to retrieve probMethod {name}", name);
            // ArgumentException derives from Exception, so existing catch blocks keep working,
            // and the message now identifies which name failed to resolve.
            throw new ArgumentException($"Unable to retrieve probMethod '{name}'", nameof(name));
        }
        method = probMethod;
    }
}
public static class Probmethods public static class Probmethods
{ {
public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters); public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters);

View File

@@ -38,16 +38,18 @@ public class Searchdomain
public int embeddingCacheMaxSize = 10000000; public int embeddingCacheMaxSize = 10000000;
private readonly MySqlConnection connection; private readonly MySqlConnection connection;
public SQLHelper helper; public SQLHelper helper;
private readonly ILogger _logger;
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain() // TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, string provider = "sqlserver", bool runEmpty = false) public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
{ {
_connectionString = connectionString; _connectionString = connectionString;
_provider = provider.ToLower(); _provider = provider.ToLower();
this.searchdomain = searchdomain; this.searchdomain = searchdomain;
this.ollama = ollama; this.ollama = ollama;
this.embeddingCache = embeddingCache; this.embeddingCache = embeddingCache;
this._logger = logger;
searchCache = []; searchCache = [];
entityCache = []; entityCache = [];
connection = new MySqlConnection(connectionString); connection = new MySqlConnection(connectionString);
@@ -57,12 +59,13 @@ public class Searchdomain
if (!runEmpty) if (!runEmpty)
{ {
GetID(); GetID();
UpdateSearchDomain(); UpdateEntityCache();
} }
} }
public void UpdateSearchDomain() public void UpdateEntityCache()
{ {
entityCache = [];
Dictionary<string, dynamic> parametersIDSearchdomain = new() Dictionary<string, dynamic> parametersIDSearchdomain = new()
{ {
["id"] = this.id ["id"] = this.id
@@ -99,7 +102,7 @@ public class Searchdomain
string name = datapointReader.GetString(2); string name = datapointReader.GetString(2);
string probmethodString = datapointReader.GetString(3); string probmethodString = datapointReader.GetString(3);
string hash = datapointReader.GetString(4); string hash = datapointReader.GetString(4);
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString); ProbMethod probmethod = new(probmethodString, _logger);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null) if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{ {
embedding_unassigned.Remove(id); embedding_unassigned.Remove(id);
@@ -179,7 +182,7 @@ public class Searchdomain
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2); float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
list.Add((key, value)); list.Add((key, value));
} }
datapointProbs.Add((datapoint.name, datapoint.probMethod(list))); datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
} }
result.Add((entity.probMethod(datapointProbs), entity.name)); result.Add((entity.probMethod(datapointProbs), entity.name));
} }

View File

@@ -53,7 +53,7 @@ public class SearchdomainManager
} }
try try
{ {
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache)); return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger));
} }
catch (MySqlException) catch (MySqlException)
{ {
@@ -62,9 +62,12 @@ public class SearchdomainManager
} }
} }
public void InvalidateSearchdomainCache(string searchdomain) public void InvalidateSearchdomainCache(string searchdomainName)
{ {
searchdomains.Remove(searchdomain); if (searchdomains.TryGetValue(searchdomainName, out var searchdomain))
{
searchdomain.UpdateEntityCache();
}
} }
public List<string> ListSearchdomains() public List<string> ListSearchdomains()

View File

@@ -15,7 +15,7 @@
<PackageReference Include="Microsoft.Data.Sqlite" Version="9.0.3" /> <PackageReference Include="Microsoft.Data.Sqlite" Version="9.0.3" />
<PackageReference Include="MySql.Data" Version="9.2.0" /> <PackageReference Include="MySql.Data" Version="9.2.0" />
<PackageReference Include="Npgsql" Version="9.0.3" /> <PackageReference Include="Npgsql" Version="9.0.3" />
<PackageReference Include="OllamaSharp" Version="5.1.9" /> <PackageReference Include="OllamaSharp" Version="5.2.2" />
<PackageReference Include="System.Configuration.ConfigurationManager" Version="9.0.3" /> <PackageReference Include="System.Configuration.ConfigurationManager" Version="9.0.3" />
<PackageReference Include="System.Data.SqlClient" Version="4.9.0" /> <PackageReference Include="System.Data.SqlClient" Version="4.9.0" />
<PackageReference Include="System.Data.Sqlite" Version="1.0.119" /> <PackageReference Include="System.Data.Sqlite" Version="1.0.119" />