Fixed broken ProbMethod in Entity/Index, fixed undisposed MySQL connections, dependency update, improved logging, improved cache invalidation

This commit is contained in:
2025-06-22 15:35:45 +02:00
parent 2034a20055
commit de7e145b89
9 changed files with 57 additions and 39 deletions

View File

@@ -5,7 +5,7 @@ from dataclasses import asdict
import time
example_content = "./Scripts/example_content"
probmethod = "DictionaryWeightedAverage"
probmethod = "LVEWAvg"
example_searchdomain = "example_" + probmethod
example_counter = 0
models = ["bge-m3", "mxbai-embed-large"]

View File

@@ -48,6 +48,7 @@ public class EntityController : ControllerBase
_domainManager.embeddingCache,
_domainManager.client,
_domainManager.helper,
_logger,
JsonSerializer.Serialize(jsonEntities));
if (entities is not null && jsonEntities is not null)
{
@@ -58,9 +59,9 @@ public class EntityController : ControllerBase
if (entities.Select(x => x.name == jsonEntityName).Any()
&& !invalidatedSearchdomains.Contains(jsonEntityName))
{
string jsonEntitySearchdomain = jsonEntity.Searchdomain;
invalidatedSearchdomains.Add(jsonEntitySearchdomain);
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomain);
string jsonEntitySearchdomainName = jsonEntity.Searchdomain;
invalidatedSearchdomains.Add(jsonEntitySearchdomainName);
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName);
}
}
return Ok(new EntityIndexResult() { Success = true });
@@ -103,11 +104,11 @@ public class EntityController : ControllerBase
{
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
}
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = embeddingResults});
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults});
}
else
{
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = null});
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null});
}
}
EntityListResult entityListResult = new()

View File

@@ -12,11 +12,11 @@ namespace Server;
public class Datapoint
{
public string name;
public Probmethods.probMethodDelegate probMethod;
public ProbMethod probMethod;
public List<(string, float[])> embeddings;
public string hash;
public Datapoint(string name, Probmethods.probMethodDelegate probMethod, string hash, List<(string, float[])> embeddings)
public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings)
{
this.name = name;
this.probMethod = probMethod;
@@ -24,21 +24,9 @@ public class Datapoint
this.embeddings = embeddings;
}
// public Datapoint(string name, Probmethods.probMethodDelegate probmethod, string content, List<string> models, OllamaApiClient ollama)
// {
// this.name = name;
// this.probMethod = probmethod;
// embeddings = GenerateEmbeddings(content, models, ollama);
// }
// public float CalcProbability()
// {
// return probMethod(embeddings); // <--- prob method is not used with the embeddings!
// }
public float CalcProbability(List<(string, float)> probabilities)
{
return probMethod(probabilities);
return probMethod.method(probabilities);
}
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
@@ -106,11 +94,10 @@ public class Datapoint
Input = [content]
};
var response = ollama.GenerateEmbeddingAsync(content, new EmbeddingGenerationOptions(){ModelId=model}).Result;
var response = ollama.EmbedAsync(request).Result;
if (response is not null)
{
float[] var = new float[response.Vector.Length];
response.Vector.CopyTo(var);
float[] var = [.. response.Embeddings.First()];
retVal[model] = var;
if (!embeddingCache.ContainsKey(model))
{

View File

@@ -3,7 +3,7 @@ using MySql.Data.MySqlClient;
namespace Server;
public class SQLHelper
public class SQLHelper:IDisposable
{
public MySqlConnection connection;
public string connectionString;
@@ -19,6 +19,12 @@ public class SQLHelper
return new SQLHelper(newConnection, connectionString);
}
public void Dispose()
{
connection.Close();
GC.SuppressFinalize(this);
}
public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters)
{
lock (connection)

View File

@@ -41,7 +41,7 @@ public static class SearchdomainHelper
return null;
}
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, string json)
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json)
{
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
if (jsonEntities is null)
@@ -67,8 +67,8 @@ public static class SearchdomainHelper
ConcurrentQueue<Entity> retVal = [];
Parallel.ForEach(jsonEntities, jSONEntity =>
{
var tempHelper = helper.DuplicateConnection();
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, jSONEntity);
using var tempHelper = helper.DuplicateConnection();
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity);
if (entity is not null)
{
retVal.Enqueue(entity);
@@ -77,7 +77,7 @@ public static class SearchdomainHelper
return [.. retVal];
}
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, JSONEntity jsonEntity) //string json)
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
{
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
@@ -123,7 +123,7 @@ public static class SearchdomainHelper
{
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
}
var probMethod_embedding = Probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
List<(string model, byte[] embedding)> data = [];

View File

@@ -3,6 +3,24 @@ using System.Text.Json;
namespace Server;
public class ProbMethod
{
public Probmethods.probMethodDelegate method;
public string name;
public ProbMethod(string name, ILogger logger)
{
this.name = name;
Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name);
if (probMethod is null)
{
logger.LogError("Unable to retrieve probMethod {name}", [name]);
throw new Exception("Unable to retrieve probMethod");
}
method = probMethod;
}
}
public static class Probmethods
{
public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters);

View File

@@ -38,16 +38,18 @@ public class Searchdomain
public int embeddingCacheMaxSize = 10000000;
private readonly MySqlConnection connection;
public SQLHelper helper;
private readonly ILogger _logger;
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, string provider = "sqlserver", bool runEmpty = false)
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
{
_connectionString = connectionString;
_provider = provider.ToLower();
this.searchdomain = searchdomain;
this.ollama = ollama;
this.embeddingCache = embeddingCache;
this._logger = logger;
searchCache = [];
entityCache = [];
connection = new MySqlConnection(connectionString);
@@ -57,12 +59,13 @@ public class Searchdomain
if (!runEmpty)
{
GetID();
UpdateSearchDomain();
UpdateEntityCache();
}
}
public void UpdateSearchDomain()
public void UpdateEntityCache()
{
entityCache = [];
Dictionary<string, dynamic> parametersIDSearchdomain = new()
{
["id"] = this.id
@@ -99,7 +102,7 @@ public class Searchdomain
string name = datapointReader.GetString(2);
string probmethodString = datapointReader.GetString(3);
string hash = datapointReader.GetString(4);
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
ProbMethod probmethod = new(probmethodString, _logger);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{
embedding_unassigned.Remove(id);
@@ -179,7 +182,7 @@ public class Searchdomain
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
list.Add((key, value));
}
datapointProbs.Add((datapoint.name, datapoint.probMethod(list)));
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
}
result.Add((entity.probMethod(datapointProbs), entity.name));
}

View File

@@ -53,7 +53,7 @@ public class SearchdomainManager
}
try
{
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache));
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger));
}
catch (MySqlException)
{
@@ -62,9 +62,12 @@ public class SearchdomainManager
}
}
public void InvalidateSearchdomainCache(string searchdomain)
public void InvalidateSearchdomainCache(string searchdomainName)
{
searchdomains.Remove(searchdomain);
if (searchdomains.TryGetValue(searchdomainName, out var searchdomain))
{
searchdomain.UpdateEntityCache();
}
}
public List<string> ListSearchdomains()

View File

@@ -15,7 +15,7 @@
<PackageReference Include="Microsoft.Data.Sqlite" Version="9.0.3" />
<PackageReference Include="MySql.Data" Version="9.2.0" />
<PackageReference Include="Npgsql" Version="9.0.3" />
<PackageReference Include="OllamaSharp" Version="5.1.9" />
<PackageReference Include="OllamaSharp" Version="5.2.2" />
<PackageReference Include="System.Configuration.ConfigurationManager" Version="9.0.3" />
<PackageReference Include="System.Data.SqlClient" Version="4.9.0" />
<PackageReference Include="System.Data.Sqlite" Version="1.0.119" />