Added hash based optimization

This commit is contained in:
2025-06-11 13:23:02 +02:00
parent 371a16511a
commit e6211a185b
3 changed files with 75 additions and 32 deletions

View File

@@ -18,6 +18,9 @@ using System.Collections.Immutable;
using System.Text.Json;
using System.Numerics.Tensors;
using Server;
using System.Security.Cryptography;
using System.Text;
using System.Collections.Concurrent;
namespace Server;
@@ -89,7 +92,7 @@ public class Searchdomain
}
embeddingReader.Close();
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding FROM datapoint", parametersIDSearchdomain);
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, hash FROM datapoint", parametersIDSearchdomain);
Dictionary<int, List<Datapoint>> datapoint_unassigned = [];
while (datapointReader.Read())
{
@@ -97,6 +100,7 @@ public class Searchdomain
int id_entity = datapointReader.GetInt32(1);
string name = datapointReader.GetString(2);
string probmethodString = datapointReader.GetString(3);
string hash = datapointReader.GetString(4);
Probmethods.probMethodDelegate? probmethod = probmethods.GetMethod(probmethodString);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{
@@ -105,7 +109,7 @@ public class Searchdomain
{
datapoint_unassigned[id_entity] = [];
}
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
}
}
datapointReader.Close();
@@ -257,29 +261,46 @@ public class Searchdomain
{
return null;
}
if (HasEntity(jsonEntity.Name))
bool hasPreexistingEntity = HasEntity(jsonEntity.Name);
Entity? preexistingEntity = null;
if (hasPreexistingEntity)
{
RemoveEntity(jsonEntity.Name);
preexistingEntity = GetEntity(jsonEntity.Name);
RemoveEntity(jsonEntity.Name); // TODO only remove entity if there is actually a change somewhere. Perhaps create 3 datapoint lists to operate with: 1. delete, 2. update, 3. create
}
int id_entity = DatabaseInsertEntity(jsonEntity.Name, jsonEntity.Probmethod, id);
foreach (KeyValuePair<string, string> attribute in jsonEntity.Attributes)
{
DatabaseInsertAttribute(attribute.Key, attribute.Value, id_entity);
}
List<Datapoint> datapoints = [];
List<Datapoint> datapoints = [];
foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints)
{
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
Dictionary<string, float[]> embeddings = [];
string hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
if (hasPreexistingEntity && preexistingEntity is not null)
{
IEnumerable<Datapoint> preexistingDatapoints = preexistingEntity.datapoints.Where(x => x.name == jsonDatapoint.Name && x.hash == hash);
if (preexistingDatapoints.Any())
{
var preexistingDatapoint = preexistingDatapoints.First();
embeddings = preexistingDatapoint.embeddings.ToDictionary(item => item.Item1, item => item.Item2);
}
}
if (embeddings.Count == 0)
{
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
}
var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, id_entity);
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity);
List<(string model, byte[] embedding)> data = [];
foreach ((string, float[]) embedding in datapoint.embeddings)
{
DatabaseInsertEmbedding(id_datapoint, embedding.Item1, BytesFromFloatArray(embedding.Item2));
data.Add((embedding.Item1, BytesFromFloatArray(embedding.Item2)));
}
DatabaseInsertEmbeddingBulk(id_datapoint, data);
datapoints.Add(datapoint);
}
@@ -315,27 +336,16 @@ public class Searchdomain
}
}
}
Dictionary<string, Dictionary<string, float[]>> cache = []; // local cache
foreach (KeyValuePair<string, List<string>> cacheThis in toBeCached)
ConcurrentQueue<Entity> retVal = [];
Parallel.ForEach(jsonEntities, jSONEntity =>
{
string model = cacheThis.Key;
List<string> contents = cacheThis.Value;
if (contents.Count == 0)
var entity = EntityFromJSON(JsonSerializer.Serialize(jSONEntity));
if (entity is not null)
{
cache[model] = [];
continue;
retVal.Enqueue(entity);
}
cache[model] = Datapoint.GenerateEmbeddings(contents, model, ollama, embeddingCache);
}
var tempEmbeddingCache = embeddingCache;
embeddingCache = cache;
List<Entity> retVal = [];
foreach (JSONEntity jSONEntity in jsonEntities)
{
retVal.Append(EntityFromJSON(JsonSerializer.Serialize(jSONEntity)));
}
embeddingCache = tempEmbeddingCache;
return retVal;
});
return retVal.ToList();
}
public void RemoveEntity(string name)
@@ -384,15 +394,16 @@ public class Searchdomain
}
public int DatabaseInsertDatapoint(string name, string probmethod_embedding, int id_entity)
public int DatabaseInsertDatapoint(string name, string probmethod_embedding, string hash, int id_entity)
{
Dictionary<string, dynamic> parameters = new()
{
{ "name", name },
{ "probmethod_embedding", probmethod_embedding },
{ "hash", hash },
{ "id_entity", id_entity }
};
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, id_entity) VALUES (@name, @probmethod_embedding, @id_entity)", parameters);
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, hash, id_entity) VALUES (@name, @probmethod_embedding, @hash, @id_entity)", parameters);
}
public int DatabaseInsertEmbedding(int id_datapoint, string model, byte[] embedding)
@@ -405,4 +416,23 @@ public class Searchdomain
};
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO embedding (id_datapoint, model, embedding) VALUES (@id_datapoint, @model, @embedding)", parameters);
}
public void DatabaseInsertEmbeddingBulk(int id_datapoint, List<(string model, byte[] embedding)> data)
{
Dictionary<string, object> parameters = [];
parameters["id_datapoint"] = id_datapoint;
var query = new StringBuilder("INSERT INTO embedding (id_datapoint, model, embedding) VALUES ");
foreach (var (model, embedding) in data)
{
string modelParam = $"model_{Guid.NewGuid()}".Replace("-", "");
string embeddingParam = $"embedding_{Guid.NewGuid()}".Replace("-", "");
parameters[modelParam] = model;
parameters[embeddingParam] = embedding;
query.Append($"(@id_datapoint, @{modelParam}, @{embeddingParam}), ");
}
query.Length -= 2; // remove trailing comma
helper.ExecuteSQLNonQuery(query.ToString(), parameters);
}
}