Added hash based optimization

This commit is contained in:
2025-06-11 13:23:02 +02:00
parent 371a16511a
commit e6211a185b
3 changed files with 75 additions and 32 deletions

View File

@@ -14,11 +14,13 @@ public class Datapoint
public string name; public string name;
public Probmethods.probMethodDelegate probMethod; public Probmethods.probMethodDelegate probMethod;
public List<(string, float[])> embeddings; public List<(string, float[])> embeddings;
public string hash;
public Datapoint(string name, Probmethods.probMethodDelegate probMethod, List<(string, float[])> embeddings) public Datapoint(string name, Probmethods.probMethodDelegate probMethod, string hash, List<(string, float[])> embeddings)
{ {
this.name = name; this.name = name;
this.probMethod = probMethod; this.probMethod = probMethod;
this.hash = hash;
this.embeddings = embeddings; this.embeddings = embeddings;
} }

View File

@@ -17,9 +17,13 @@ public static class DatabaseMigrations
databaseVersion = UpdateFrom1(helper); // TODO: Implement reflection based dynamic invocation. databaseVersion = UpdateFrom1(helper); // TODO: Implement reflection based dynamic invocation.
goto case 2; goto case 2;
case 2: case 2:
databaseVersion = UpdateFrom2(helper);
goto case 3;
case 3:
default: default:
break; break;
} }
helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() });
} }
public static int DatabaseGetVersion(SQLHelper helper) public static int DatabaseGetVersion(SQLHelper helper)
{ {
@@ -66,4 +70,11 @@ public static class DatabaseMigrations
helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES (\"DatabaseVersion\", \"2\");", []); helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES (\"DatabaseVersion\", \"2\");", []);
return 2; return 2;
} }
public static int UpdateFrom2(SQLHelper helper)
{
helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD hash VARCHAR(44);", []);
helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []);
return 3;
}
} }

View File

@@ -18,6 +18,9 @@ using System.Collections.Immutable;
using System.Text.Json; using System.Text.Json;
using System.Numerics.Tensors; using System.Numerics.Tensors;
using Server; using Server;
using System.Security.Cryptography;
using System.Text;
using System.Collections.Concurrent;
namespace Server; namespace Server;
@@ -89,7 +92,7 @@ public class Searchdomain
} }
embeddingReader.Close(); embeddingReader.Close();
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding FROM datapoint", parametersIDSearchdomain); DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, hash FROM datapoint", parametersIDSearchdomain);
Dictionary<int, List<Datapoint>> datapoint_unassigned = []; Dictionary<int, List<Datapoint>> datapoint_unassigned = [];
while (datapointReader.Read()) while (datapointReader.Read())
{ {
@@ -97,6 +100,7 @@ public class Searchdomain
int id_entity = datapointReader.GetInt32(1); int id_entity = datapointReader.GetInt32(1);
string name = datapointReader.GetString(2); string name = datapointReader.GetString(2);
string probmethodString = datapointReader.GetString(3); string probmethodString = datapointReader.GetString(3);
string hash = datapointReader.GetString(4);
Probmethods.probMethodDelegate? probmethod = probmethods.GetMethod(probmethodString); Probmethods.probMethodDelegate? probmethod = probmethods.GetMethod(probmethodString);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null) if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{ {
@@ -105,7 +109,7 @@ public class Searchdomain
{ {
datapoint_unassigned[id_entity] = []; datapoint_unassigned[id_entity] = [];
} }
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, [.. embeddings.Select(kv => (kv.Key, kv.Value))])); datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
} }
} }
datapointReader.Close(); datapointReader.Close();
@@ -257,10 +261,12 @@ public class Searchdomain
{ {
return null; return null;
} }
if (HasEntity(jsonEntity.Name)) bool hasPreexistingEntity = HasEntity(jsonEntity.Name);
Entity? preexistingEntity = null;
if (hasPreexistingEntity)
{ {
RemoveEntity(jsonEntity.Name); preexistingEntity = GetEntity(jsonEntity.Name);
RemoveEntity(jsonEntity.Name); // TODO only remove entity if there is actually a change somewhere. Perhaps create 3 datapoint lists to operate with: 1. delete, 2. update, 3. create
} }
int id_entity = DatabaseInsertEntity(jsonEntity.Name, jsonEntity.Probmethod, id); int id_entity = DatabaseInsertEntity(jsonEntity.Name, jsonEntity.Probmethod, id);
foreach (KeyValuePair<string, string> attribute in jsonEntity.Attributes) foreach (KeyValuePair<string, string> attribute in jsonEntity.Attributes)
@@ -269,17 +275,32 @@ public class Searchdomain
} }
List<Datapoint> datapoints = []; List<Datapoint> datapoints = [];
foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints) foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints)
{ {
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache); Dictionary<string, float[]> embeddings = [];
string hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
if (hasPreexistingEntity && preexistingEntity is not null)
{
IEnumerable<Datapoint> preexistingDatapoints = preexistingEntity.datapoints.Where(x => x.name == jsonDatapoint.Name && x.hash == hash);
if (preexistingDatapoints.Any())
{
var preexistingDatapoint = preexistingDatapoints.First();
embeddings = preexistingDatapoint.embeddings.ToDictionary(item => item.Item1, item => item.Item2);
}
}
if (embeddings.Count == 0)
{
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
}
var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, id_entity); int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity);
List<(string model, byte[] embedding)> data = [];
foreach ((string, float[]) embedding in datapoint.embeddings) foreach ((string, float[]) embedding in datapoint.embeddings)
{ {
DatabaseInsertEmbedding(id_datapoint, embedding.Item1, BytesFromFloatArray(embedding.Item2)); data.Add((embedding.Item1, BytesFromFloatArray(embedding.Item2)));
} }
DatabaseInsertEmbeddingBulk(id_datapoint, data);
datapoints.Add(datapoint); datapoints.Add(datapoint);
} }
@@ -315,27 +336,16 @@ public class Searchdomain
} }
} }
} }
Dictionary<string, Dictionary<string, float[]>> cache = []; // local cache ConcurrentQueue<Entity> retVal = [];
foreach (KeyValuePair<string, List<string>> cacheThis in toBeCached) Parallel.ForEach(jsonEntities, jSONEntity =>
{ {
string model = cacheThis.Key; var entity = EntityFromJSON(JsonSerializer.Serialize(jSONEntity));
List<string> contents = cacheThis.Value; if (entity is not null)
if (contents.Count == 0)
{ {
cache[model] = []; retVal.Enqueue(entity);
continue;
} }
cache[model] = Datapoint.GenerateEmbeddings(contents, model, ollama, embeddingCache); });
} return retVal.ToList();
var tempEmbeddingCache = embeddingCache;
embeddingCache = cache;
List<Entity> retVal = [];
foreach (JSONEntity jSONEntity in jsonEntities)
{
retVal.Append(EntityFromJSON(JsonSerializer.Serialize(jSONEntity)));
}
embeddingCache = tempEmbeddingCache;
return retVal;
} }
public void RemoveEntity(string name) public void RemoveEntity(string name)
@@ -384,15 +394,16 @@ public class Searchdomain
} }
public int DatabaseInsertDatapoint(string name, string probmethod_embedding, int id_entity) public int DatabaseInsertDatapoint(string name, string probmethod_embedding, string hash, int id_entity)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{ "name", name }, { "name", name },
{ "probmethod_embedding", probmethod_embedding }, { "probmethod_embedding", probmethod_embedding },
{ "hash", hash },
{ "id_entity", id_entity } { "id_entity", id_entity }
}; };
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, id_entity) VALUES (@name, @probmethod_embedding, @id_entity)", parameters); return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, hash, id_entity) VALUES (@name, @probmethod_embedding, @hash, @id_entity)", parameters);
} }
public int DatabaseInsertEmbedding(int id_datapoint, string model, byte[] embedding) public int DatabaseInsertEmbedding(int id_datapoint, string model, byte[] embedding)
@@ -405,4 +416,23 @@ public class Searchdomain
}; };
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO embedding (id_datapoint, model, embedding) VALUES (@id_datapoint, @model, @embedding)", parameters); return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO embedding (id_datapoint, model, embedding) VALUES (@id_datapoint, @model, @embedding)", parameters);
} }
public void DatabaseInsertEmbeddingBulk(int id_datapoint, List<(string model, byte[] embedding)> data)
{
Dictionary<string, object> parameters = [];
parameters["id_datapoint"] = id_datapoint;
var query = new StringBuilder("INSERT INTO embedding (id_datapoint, model, embedding) VALUES ");
foreach (var (model, embedding) in data)
{
string modelParam = $"model_{Guid.NewGuid()}".Replace("-", "");
string embeddingParam = $"embedding_{Guid.NewGuid()}".Replace("-", "");
parameters[modelParam] = model;
parameters[embeddingParam] = embedding;
query.Append($"(@id_datapoint, @{modelParam}, @{embeddingParam}), ");
}
query.Length -= 2; // remove trailing comma
helper.ExecuteSQLNonQuery(query.ToString(), parameters);
}
} }