Fixed broken ProbMethod in Entity/Index, fixed undisposed MySQL connections, dependency update, improved logging, improved cache invalidation
This commit is contained in:
@@ -5,7 +5,7 @@ from dataclasses import asdict
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
example_content = "./Scripts/example_content"
|
example_content = "./Scripts/example_content"
|
||||||
probmethod = "DictionaryWeightedAverage"
|
probmethod = "LVEWAvg"
|
||||||
example_searchdomain = "example_" + probmethod
|
example_searchdomain = "example_" + probmethod
|
||||||
example_counter = 0
|
example_counter = 0
|
||||||
models = ["bge-m3", "mxbai-embed-large"]
|
models = ["bge-m3", "mxbai-embed-large"]
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ public class EntityController : ControllerBase
|
|||||||
_domainManager.embeddingCache,
|
_domainManager.embeddingCache,
|
||||||
_domainManager.client,
|
_domainManager.client,
|
||||||
_domainManager.helper,
|
_domainManager.helper,
|
||||||
|
_logger,
|
||||||
JsonSerializer.Serialize(jsonEntities));
|
JsonSerializer.Serialize(jsonEntities));
|
||||||
if (entities is not null && jsonEntities is not null)
|
if (entities is not null && jsonEntities is not null)
|
||||||
{
|
{
|
||||||
@@ -58,9 +59,9 @@ public class EntityController : ControllerBase
|
|||||||
if (entities.Select(x => x.name == jsonEntityName).Any()
|
if (entities.Select(x => x.name == jsonEntityName).Any()
|
||||||
&& !invalidatedSearchdomains.Contains(jsonEntityName))
|
&& !invalidatedSearchdomains.Contains(jsonEntityName))
|
||||||
{
|
{
|
||||||
string jsonEntitySearchdomain = jsonEntity.Searchdomain;
|
string jsonEntitySearchdomainName = jsonEntity.Searchdomain;
|
||||||
invalidatedSearchdomains.Add(jsonEntitySearchdomain);
|
invalidatedSearchdomains.Add(jsonEntitySearchdomainName);
|
||||||
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomain);
|
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Ok(new EntityIndexResult() { Success = true });
|
return Ok(new EntityIndexResult() { Success = true });
|
||||||
@@ -103,11 +104,11 @@ public class EntityController : ControllerBase
|
|||||||
{
|
{
|
||||||
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
|
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
|
||||||
}
|
}
|
||||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = embeddingResults});
|
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults});
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = null});
|
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EntityListResult entityListResult = new()
|
EntityListResult entityListResult = new()
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ namespace Server;
|
|||||||
public class Datapoint
|
public class Datapoint
|
||||||
{
|
{
|
||||||
public string name;
|
public string name;
|
||||||
public Probmethods.probMethodDelegate probMethod;
|
public ProbMethod probMethod;
|
||||||
public List<(string, float[])> embeddings;
|
public List<(string, float[])> embeddings;
|
||||||
public string hash;
|
public string hash;
|
||||||
|
|
||||||
public Datapoint(string name, Probmethods.probMethodDelegate probMethod, string hash, List<(string, float[])> embeddings)
|
public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings)
|
||||||
{
|
{
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.probMethod = probMethod;
|
this.probMethod = probMethod;
|
||||||
@@ -24,21 +24,9 @@ public class Datapoint
|
|||||||
this.embeddings = embeddings;
|
this.embeddings = embeddings;
|
||||||
}
|
}
|
||||||
|
|
||||||
// public Datapoint(string name, Probmethods.probMethodDelegate probmethod, string content, List<string> models, OllamaApiClient ollama)
|
|
||||||
// {
|
|
||||||
// this.name = name;
|
|
||||||
// this.probMethod = probmethod;
|
|
||||||
// embeddings = GenerateEmbeddings(content, models, ollama);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// public float CalcProbability()
|
|
||||||
// {
|
|
||||||
// return probMethod(embeddings); // <--- prob method is not used with the embeddings!
|
|
||||||
// }
|
|
||||||
|
|
||||||
public float CalcProbability(List<(string, float)> probabilities)
|
public float CalcProbability(List<(string, float)> probabilities)
|
||||||
{
|
{
|
||||||
return probMethod(probabilities);
|
return probMethod.method(probabilities);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
|
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
|
||||||
@@ -106,11 +94,10 @@ public class Datapoint
|
|||||||
Input = [content]
|
Input = [content]
|
||||||
};
|
};
|
||||||
|
|
||||||
var response = ollama.GenerateEmbeddingAsync(content, new EmbeddingGenerationOptions(){ModelId=model}).Result;
|
var response = ollama.EmbedAsync(request).Result;
|
||||||
if (response is not null)
|
if (response is not null)
|
||||||
{
|
{
|
||||||
float[] var = new float[response.Vector.Length];
|
float[] var = [.. response.Embeddings.First()];
|
||||||
response.Vector.CopyTo(var);
|
|
||||||
retVal[model] = var;
|
retVal[model] = var;
|
||||||
if (!embeddingCache.ContainsKey(model))
|
if (!embeddingCache.ContainsKey(model))
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ using MySql.Data.MySqlClient;
|
|||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
public class SQLHelper
|
public class SQLHelper:IDisposable
|
||||||
{
|
{
|
||||||
public MySqlConnection connection;
|
public MySqlConnection connection;
|
||||||
public string connectionString;
|
public string connectionString;
|
||||||
@@ -19,6 +19,12 @@ public class SQLHelper
|
|||||||
return new SQLHelper(newConnection, connectionString);
|
return new SQLHelper(newConnection, connectionString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
/// Releases the MySQL connection held by this helper.
/// </summary>
/// <remarks>
/// The connection is disposed rather than merely closed, so the owned
/// <see cref="IDisposable"/> field is fully released (per ADO.NET, disposing a
/// connection also closes it). The null-conditional call keeps this method from
/// throwing if the public <c>connection</c> field was never assigned or was
/// cleared by a caller.
/// </remarks>
public void Dispose()
{
    connection?.Dispose();

    // No finalizer is declared on this type as far as this hunk shows; the call is kept
    // so derived types that add one are not finalized after an explicit Dispose.
    GC.SuppressFinalize(this);
}
|
||||||
|
|
||||||
public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters)
|
public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters)
|
||||||
{
|
{
|
||||||
lock (connection)
|
lock (connection)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ public static class SearchdomainHelper
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, string json)
|
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json)
|
||||||
{
|
{
|
||||||
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
|
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
|
||||||
if (jsonEntities is null)
|
if (jsonEntities is null)
|
||||||
@@ -67,8 +67,8 @@ public static class SearchdomainHelper
|
|||||||
ConcurrentQueue<Entity> retVal = [];
|
ConcurrentQueue<Entity> retVal = [];
|
||||||
Parallel.ForEach(jsonEntities, jSONEntity =>
|
Parallel.ForEach(jsonEntities, jSONEntity =>
|
||||||
{
|
{
|
||||||
var tempHelper = helper.DuplicateConnection();
|
using var tempHelper = helper.DuplicateConnection();
|
||||||
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, jSONEntity);
|
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity);
|
||||||
if (entity is not null)
|
if (entity is not null)
|
||||||
{
|
{
|
||||||
retVal.Enqueue(entity);
|
retVal.Enqueue(entity);
|
||||||
@@ -77,7 +77,7 @@ public static class SearchdomainHelper
|
|||||||
return [.. retVal];
|
return [.. retVal];
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, JSONEntity jsonEntity) //string json)
|
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
||||||
{
|
{
|
||||||
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
||||||
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
||||||
@@ -123,7 +123,7 @@ public static class SearchdomainHelper
|
|||||||
{
|
{
|
||||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
|
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
|
||||||
}
|
}
|
||||||
var probMethod_embedding = Probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
||||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||||
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
|
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
|
||||||
List<(string model, byte[] embedding)> data = [];
|
List<(string model, byte[] embedding)> data = [];
|
||||||
|
|||||||
@@ -3,6 +3,24 @@ using System.Text.Json;
|
|||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
|
/// <summary>
/// Pairs a probability-method name with the delegate that
/// <c>Probmethods.GetMethod</c> resolves for that name.
/// </summary>
public class ProbMethod
{
    /// <summary>Delegate resolved from <see cref="name"/> in the constructor.</summary>
    public Probmethods.probMethodDelegate method;

    /// <summary>The method name exactly as it was passed to the constructor.</summary>
    public string name;

    /// <summary>
    /// Resolves <paramref name="name"/> to its delegate via <c>Probmethods.GetMethod</c>.
    /// </summary>
    /// <param name="name">Name of the probability method to look up.</param>
    /// <param name="logger">Logger that receives an error entry when the lookup fails.</param>
    /// <exception cref="ArgumentException">
    /// Thrown when <c>Probmethods.GetMethod</c> returns <c>null</c> for <paramref name="name"/>.
    /// </exception>
    public ProbMethod(string name, ILogger logger)
    {
        this.name = name;

        Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name);
        if (probMethod is null)
        {
            logger.LogError("Unable to retrieve probMethod {name}", name);

            // ArgumentException still derives from Exception, so existing catch blocks keep
            // working; the message now carries the offending name for callers that surface it.
            throw new ArgumentException($"Unable to retrieve probMethod {name}", nameof(name));
        }

        method = probMethod;
    }
}
|
||||||
|
|
||||||
public static class Probmethods
|
public static class Probmethods
|
||||||
{
|
{
|
||||||
public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters);
|
public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters);
|
||||||
|
|||||||
@@ -38,16 +38,18 @@ public class Searchdomain
|
|||||||
public int embeddingCacheMaxSize = 10000000;
|
public int embeddingCacheMaxSize = 10000000;
|
||||||
private readonly MySqlConnection connection;
|
private readonly MySqlConnection connection;
|
||||||
public SQLHelper helper;
|
public SQLHelper helper;
|
||||||
|
private readonly ILogger _logger;
|
||||||
|
|
||||||
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
|
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
|
||||||
|
|
||||||
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, string provider = "sqlserver", bool runEmpty = false)
|
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
||||||
{
|
{
|
||||||
_connectionString = connectionString;
|
_connectionString = connectionString;
|
||||||
_provider = provider.ToLower();
|
_provider = provider.ToLower();
|
||||||
this.searchdomain = searchdomain;
|
this.searchdomain = searchdomain;
|
||||||
this.ollama = ollama;
|
this.ollama = ollama;
|
||||||
this.embeddingCache = embeddingCache;
|
this.embeddingCache = embeddingCache;
|
||||||
|
this._logger = logger;
|
||||||
searchCache = [];
|
searchCache = [];
|
||||||
entityCache = [];
|
entityCache = [];
|
||||||
connection = new MySqlConnection(connectionString);
|
connection = new MySqlConnection(connectionString);
|
||||||
@@ -57,12 +59,13 @@ public class Searchdomain
|
|||||||
if (!runEmpty)
|
if (!runEmpty)
|
||||||
{
|
{
|
||||||
GetID();
|
GetID();
|
||||||
UpdateSearchDomain();
|
UpdateEntityCache();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void UpdateSearchDomain()
|
public void UpdateEntityCache()
|
||||||
{
|
{
|
||||||
|
entityCache = [];
|
||||||
Dictionary<string, dynamic> parametersIDSearchdomain = new()
|
Dictionary<string, dynamic> parametersIDSearchdomain = new()
|
||||||
{
|
{
|
||||||
["id"] = this.id
|
["id"] = this.id
|
||||||
@@ -99,7 +102,7 @@ public class Searchdomain
|
|||||||
string name = datapointReader.GetString(2);
|
string name = datapointReader.GetString(2);
|
||||||
string probmethodString = datapointReader.GetString(3);
|
string probmethodString = datapointReader.GetString(3);
|
||||||
string hash = datapointReader.GetString(4);
|
string hash = datapointReader.GetString(4);
|
||||||
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
|
ProbMethod probmethod = new(probmethodString, _logger);
|
||||||
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
||||||
{
|
{
|
||||||
embedding_unassigned.Remove(id);
|
embedding_unassigned.Remove(id);
|
||||||
@@ -179,7 +182,7 @@ public class Searchdomain
|
|||||||
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
|
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
|
||||||
list.Add((key, value));
|
list.Add((key, value));
|
||||||
}
|
}
|
||||||
datapointProbs.Add((datapoint.name, datapoint.probMethod(list)));
|
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
|
||||||
}
|
}
|
||||||
result.Add((entity.probMethod(datapointProbs), entity.name));
|
result.Add((entity.probMethod(datapointProbs), entity.name));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ public class SearchdomainManager
|
|||||||
}
|
}
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache));
|
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger));
|
||||||
}
|
}
|
||||||
catch (MySqlException)
|
catch (MySqlException)
|
||||||
{
|
{
|
||||||
@@ -62,9 +62,12 @@ public class SearchdomainManager
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void InvalidateSearchdomainCache(string searchdomain)
|
public void InvalidateSearchdomainCache(string searchdomainName)
|
||||||
{
|
{
|
||||||
searchdomains.Remove(searchdomain);
|
if (searchdomains.TryGetValue(searchdomainName, out var searchdomain))
|
||||||
|
{
|
||||||
|
searchdomain.UpdateEntityCache();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<string> ListSearchdomains()
|
public List<string> ListSearchdomains()
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
<PackageReference Include="Microsoft.Data.Sqlite" Version="9.0.3" />
|
<PackageReference Include="Microsoft.Data.Sqlite" Version="9.0.3" />
|
||||||
<PackageReference Include="MySql.Data" Version="9.2.0" />
|
<PackageReference Include="MySql.Data" Version="9.2.0" />
|
||||||
<PackageReference Include="Npgsql" Version="9.0.3" />
|
<PackageReference Include="Npgsql" Version="9.0.3" />
|
||||||
<PackageReference Include="OllamaSharp" Version="5.1.9" />
|
<PackageReference Include="OllamaSharp" Version="5.2.2" />
|
||||||
<PackageReference Include="System.Configuration.ConfigurationManager" Version="9.0.3" />
|
<PackageReference Include="System.Configuration.ConfigurationManager" Version="9.0.3" />
|
||||||
<PackageReference Include="System.Data.SqlClient" Version="4.9.0" />
|
<PackageReference Include="System.Data.SqlClient" Version="4.9.0" />
|
||||||
<PackageReference Include="System.Data.Sqlite" Version="1.0.119" />
|
<PackageReference Include="System.Data.Sqlite" Version="1.0.119" />
|
||||||
|
|||||||
Reference in New Issue
Block a user