Fixed broken ProbMethod in Entity/Index, fixed undisposed MySQL connections, dependency update, improved logging, improved cache invalidation
This commit is contained in:
@@ -48,6 +48,7 @@ public class EntityController : ControllerBase
|
||||
_domainManager.embeddingCache,
|
||||
_domainManager.client,
|
||||
_domainManager.helper,
|
||||
_logger,
|
||||
JsonSerializer.Serialize(jsonEntities));
|
||||
if (entities is not null && jsonEntities is not null)
|
||||
{
|
||||
@@ -58,9 +59,9 @@ public class EntityController : ControllerBase
|
||||
if (entities.Select(x => x.name == jsonEntityName).Any()
|
||||
&& !invalidatedSearchdomains.Contains(jsonEntityName))
|
||||
{
|
||||
string jsonEntitySearchdomain = jsonEntity.Searchdomain;
|
||||
invalidatedSearchdomains.Add(jsonEntitySearchdomain);
|
||||
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomain);
|
||||
string jsonEntitySearchdomainName = jsonEntity.Searchdomain;
|
||||
invalidatedSearchdomains.Add(jsonEntitySearchdomainName);
|
||||
_domainManager.InvalidateSearchdomainCache(jsonEntitySearchdomainName);
|
||||
}
|
||||
}
|
||||
return Ok(new EntityIndexResult() { Success = true });
|
||||
@@ -103,11 +104,11 @@ public class EntityController : ControllerBase
|
||||
{
|
||||
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
|
||||
}
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = embeddingResults});
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults});
|
||||
}
|
||||
else
|
||||
{
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.Method.Name, Embeddings = null});
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null});
|
||||
}
|
||||
}
|
||||
EntityListResult entityListResult = new()
|
||||
|
||||
@@ -12,11 +12,11 @@ namespace Server;
|
||||
public class Datapoint
|
||||
{
|
||||
public string name;
|
||||
public Probmethods.probMethodDelegate probMethod;
|
||||
public ProbMethod probMethod;
|
||||
public List<(string, float[])> embeddings;
|
||||
public string hash;
|
||||
|
||||
public Datapoint(string name, Probmethods.probMethodDelegate probMethod, string hash, List<(string, float[])> embeddings)
|
||||
public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings)
|
||||
{
|
||||
this.name = name;
|
||||
this.probMethod = probMethod;
|
||||
@@ -24,21 +24,9 @@ public class Datapoint
|
||||
this.embeddings = embeddings;
|
||||
}
|
||||
|
||||
// public Datapoint(string name, Probmethods.probMethodDelegate probmethod, string content, List<string> models, OllamaApiClient ollama)
|
||||
// {
|
||||
// this.name = name;
|
||||
// this.probMethod = probmethod;
|
||||
// embeddings = GenerateEmbeddings(content, models, ollama);
|
||||
// }
|
||||
|
||||
// public float CalcProbability()
|
||||
// {
|
||||
// return probMethod(embeddings); // <--- prob method is not used with the embeddings!
|
||||
// }
|
||||
|
||||
public float CalcProbability(List<(string, float)> probabilities)
|
||||
{
|
||||
return probMethod(probabilities);
|
||||
return probMethod.method(probabilities);
|
||||
}
|
||||
|
||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
|
||||
@@ -106,11 +94,10 @@ public class Datapoint
|
||||
Input = [content]
|
||||
};
|
||||
|
||||
var response = ollama.GenerateEmbeddingAsync(content, new EmbeddingGenerationOptions(){ModelId=model}).Result;
|
||||
var response = ollama.EmbedAsync(request).Result;
|
||||
if (response is not null)
|
||||
{
|
||||
float[] var = new float[response.Vector.Length];
|
||||
response.Vector.CopyTo(var);
|
||||
float[] var = [.. response.Embeddings.First()];
|
||||
retVal[model] = var;
|
||||
if (!embeddingCache.ContainsKey(model))
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@ using MySql.Data.MySqlClient;
|
||||
|
||||
namespace Server;
|
||||
|
||||
public class SQLHelper
|
||||
public class SQLHelper:IDisposable
|
||||
{
|
||||
public MySqlConnection connection;
|
||||
public string connectionString;
|
||||
@@ -19,6 +19,12 @@ public class SQLHelper
|
||||
return new SQLHelper(newConnection, connectionString);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
connection.Close();
|
||||
GC.SuppressFinalize(this);
|
||||
}
|
||||
|
||||
public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters)
|
||||
{
|
||||
lock (connection)
|
||||
|
||||
@@ -41,7 +41,7 @@ public static class SearchdomainHelper
|
||||
return null;
|
||||
}
|
||||
|
||||
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, string json)
|
||||
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json)
|
||||
{
|
||||
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
|
||||
if (jsonEntities is null)
|
||||
@@ -67,8 +67,8 @@ public static class SearchdomainHelper
|
||||
ConcurrentQueue<Entity> retVal = [];
|
||||
Parallel.ForEach(jsonEntities, jSONEntity =>
|
||||
{
|
||||
var tempHelper = helper.DuplicateConnection();
|
||||
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, jSONEntity);
|
||||
using var tempHelper = helper.DuplicateConnection();
|
||||
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity);
|
||||
if (entity is not null)
|
||||
{
|
||||
retVal.Enqueue(entity);
|
||||
@@ -77,7 +77,7 @@ public static class SearchdomainHelper
|
||||
return [.. retVal];
|
||||
}
|
||||
|
||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, JSONEntity jsonEntity) //string json)
|
||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
||||
{
|
||||
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
||||
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
||||
@@ -123,7 +123,7 @@ public static class SearchdomainHelper
|
||||
{
|
||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
|
||||
}
|
||||
var probMethod_embedding = Probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
||||
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
|
||||
List<(string model, byte[] embedding)> data = [];
|
||||
|
||||
@@ -3,6 +3,24 @@ using System.Text.Json;
|
||||
|
||||
namespace Server;
|
||||
|
||||
public class ProbMethod
|
||||
{
|
||||
public Probmethods.probMethodDelegate method;
|
||||
public string name;
|
||||
|
||||
public ProbMethod(string name, ILogger logger)
|
||||
{
|
||||
this.name = name;
|
||||
Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name);
|
||||
if (probMethod is null)
|
||||
{
|
||||
logger.LogError("Unable to retrieve probMethod {name}", [name]);
|
||||
throw new Exception("Unable to retrieve probMethod");
|
||||
}
|
||||
method = probMethod;
|
||||
}
|
||||
}
|
||||
|
||||
public static class Probmethods
|
||||
{
|
||||
public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters);
|
||||
|
||||
@@ -38,16 +38,18 @@ public class Searchdomain
|
||||
public int embeddingCacheMaxSize = 10000000;
|
||||
private readonly MySqlConnection connection;
|
||||
public SQLHelper helper;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
|
||||
|
||||
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, string provider = "sqlserver", bool runEmpty = false)
|
||||
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
||||
{
|
||||
_connectionString = connectionString;
|
||||
_provider = provider.ToLower();
|
||||
this.searchdomain = searchdomain;
|
||||
this.ollama = ollama;
|
||||
this.embeddingCache = embeddingCache;
|
||||
this._logger = logger;
|
||||
searchCache = [];
|
||||
entityCache = [];
|
||||
connection = new MySqlConnection(connectionString);
|
||||
@@ -57,12 +59,13 @@ public class Searchdomain
|
||||
if (!runEmpty)
|
||||
{
|
||||
GetID();
|
||||
UpdateSearchDomain();
|
||||
UpdateEntityCache();
|
||||
}
|
||||
}
|
||||
|
||||
public void UpdateSearchDomain()
|
||||
public void UpdateEntityCache()
|
||||
{
|
||||
entityCache = [];
|
||||
Dictionary<string, dynamic> parametersIDSearchdomain = new()
|
||||
{
|
||||
["id"] = this.id
|
||||
@@ -99,7 +102,7 @@ public class Searchdomain
|
||||
string name = datapointReader.GetString(2);
|
||||
string probmethodString = datapointReader.GetString(3);
|
||||
string hash = datapointReader.GetString(4);
|
||||
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
|
||||
ProbMethod probmethod = new(probmethodString, _logger);
|
||||
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
||||
{
|
||||
embedding_unassigned.Remove(id);
|
||||
@@ -179,7 +182,7 @@ public class Searchdomain
|
||||
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
|
||||
list.Add((key, value));
|
||||
}
|
||||
datapointProbs.Add((datapoint.name, datapoint.probMethod(list)));
|
||||
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
|
||||
}
|
||||
result.Add((entity.probMethod(datapointProbs), entity.name));
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ public class SearchdomainManager
|
||||
}
|
||||
try
|
||||
{
|
||||
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache));
|
||||
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger));
|
||||
}
|
||||
catch (MySqlException)
|
||||
{
|
||||
@@ -62,9 +62,12 @@ public class SearchdomainManager
|
||||
}
|
||||
}
|
||||
|
||||
public void InvalidateSearchdomainCache(string searchdomain)
|
||||
public void InvalidateSearchdomainCache(string searchdomainName)
|
||||
{
|
||||
searchdomains.Remove(searchdomain);
|
||||
if (searchdomains.TryGetValue(searchdomainName, out var searchdomain))
|
||||
{
|
||||
searchdomain.UpdateEntityCache();
|
||||
}
|
||||
}
|
||||
|
||||
public List<string> ListSearchdomains()
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
<PackageReference Include="Microsoft.Data.Sqlite" Version="9.0.3" />
|
||||
<PackageReference Include="MySql.Data" Version="9.2.0" />
|
||||
<PackageReference Include="Npgsql" Version="9.0.3" />
|
||||
<PackageReference Include="OllamaSharp" Version="5.1.9" />
|
||||
<PackageReference Include="OllamaSharp" Version="5.2.2" />
|
||||
<PackageReference Include="System.Configuration.ConfigurationManager" Version="9.0.3" />
|
||||
<PackageReference Include="System.Data.SqlClient" Version="4.9.0" />
|
||||
<PackageReference Include="System.Data.Sqlite" Version="1.0.119" />
|
||||
|
||||
Reference in New Issue
Block a user