Fixed DatabaseInsertEmbeddingBulk, added bulk edit and delete for attributes, fixed entityCache not being thread-safe, fixed EntityFromJSON missing bulk inserts, added retry logic for BulkExecuteNonQuery, added MaxRequestBodySize configuration

This commit is contained in:
2026-02-12 20:57:01 +01:00
parent 41fd8a067e
commit 4aabc3bae0
9 changed files with 271 additions and 109 deletions

View File

@@ -153,7 +153,8 @@ public class EntityController : ControllerBase
} }
searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_); searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_);
_databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain); _databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain);
searchdomain_.entityCache.RemoveAll(entity => entity.name == entityName); Entity toBeRemoved = searchdomain_.entityCache.First(entity => entity.name == entityName);
searchdomain_.entityCache = [.. searchdomain_.entityCache.Except([toBeRemoved])];
return Ok(new EntityDeleteResults() {Success = true}); return Ok(new EntityDeleteResults() {Success = true});
} }
} }

View File

@@ -1,4 +1,5 @@
using Shared; using Shared;
using Shared.Models;
namespace Server; namespace Server;
@@ -10,6 +11,15 @@ public class Datapoint
public List<(string, float[])> embeddings; public List<(string, float[])> embeddings;
public string hash; public string hash;
/// <summary>
/// Creates a datapoint from enum selectors, wrapping each enum value in its
/// <see cref="ProbMethod"/> / <see cref="SimilarityMethod"/> object before storing it.
/// </summary>
/// <param name="name">Name of the datapoint.</param>
/// <param name="probMethod">Enum value passed to the <see cref="ProbMethod"/> constructor.</param>
/// <param name="similarityMethod">Enum value passed to the <see cref="SimilarityMethod"/> constructor.</param>
/// <param name="hash">Hash string stored on the datapoint as-is.</param>
/// <param name="embeddings">Embedding tuples stored as-is (presumably model name + vector — confirm against callers).</param>
public Datapoint(string name, ProbMethodEnum probMethod, SimilarityMethodEnum similarityMethod, string hash, List<(string, float[])> embeddings)
{
    // Build the wrappers up front, probability method first and similarity method second,
    // so a failing wrapper constructor surfaces in the same order as before.
    ProbMethod wrappedProbMethod = new(probMethod);
    SimilarityMethod wrappedSimilarityMethod = new(similarityMethod);

    this.name = name;
    this.probMethod = wrappedProbMethod;
    this.similarityMethod = wrappedSimilarityMethod;
    this.hash = hash;
    this.embeddings = embeddings;
}
public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings) public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings)
{ {
this.name = name; this.name = name;

View File

@@ -39,14 +39,15 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
helper.ExecuteSQLNonQuery(query.ToString(), parameters); helper.ExecuteSQLNonQuery(query.ToString(), parameters);
} }
public static int DatabaseInsertEmbeddingBulk(SQLHelper helper, List<(string hash, string model, byte[] embedding)> data) public static int DatabaseInsertEmbeddingBulk(SQLHelper helper, List<(string name, string model, byte[] embedding)> data, int id_entity)
{ {
return helper.BulkExecuteNonQuery( return helper.BulkExecuteNonQuery(
"INSERT INTO embedding (id_datapoint, model, embedding) SELECT d.id, @model, @embedding FROM datapoint d WHERE d.hash = @hash", "INSERT INTO embedding (id_datapoint, model, embedding) SELECT d.id, @model, @embedding FROM datapoint d WHERE d.name = @name AND d.id_entity = @id_entity ORDER BY d.id LIMIT 1", // TODO: fix limitation - entity must not have 2 datapoints with the same content, i.e. hash
data.Select(element => new object[] { data.Select(element => new object[] {
new MySqlParameter("@model", element.model), new MySqlParameter("@model", element.model),
new MySqlParameter("@embedding", element.embedding), new MySqlParameter("@embedding", element.embedding),
new MySqlParameter("@hash", element.hash) new MySqlParameter("@name", element.name),
new MySqlParameter("@id_entity", id_entity)
}) })
); );
} }
@@ -96,6 +97,29 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
); );
} }
/// <summary>
/// Updates the value of existing attributes in bulk: one UPDATE statement is executed
/// per tuple via <c>helper.BulkExecuteNonQuery</c>.
/// </summary>
/// <param name="helper">SQL helper used to run the statements.</param>
/// <param name="values">Tuples of attribute key, new value and owning entity id.</param>
/// <returns>The value returned by <c>BulkExecuteNonQuery</c> (presumably the number of affected rows — confirm).</returns>
public static int DatabaseUpdateAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values)
{
    const string sql = "UPDATE attribute SET value=@value WHERE id_entity=@id_entity AND attribute=@attribute";

    // Deliberately left as a lazy query so fresh parameter objects are produced on every enumeration.
    IEnumerable<object[]> parameterSets =
        from entry in values
        select new object[]
        {
            new MySqlParameter("@attribute", entry.attribute),
            new MySqlParameter("@value", entry.value),
            new MySqlParameter("@id_entity", entry.id_entity)
        };

    return helper.BulkExecuteNonQuery(sql, parameterSets);
}
/// <summary>
/// Deletes attributes in bulk: one DELETE statement is executed per
/// (attribute key, entity id) tuple via <c>helper.BulkExecuteNonQuery</c>.
/// </summary>
/// <param name="helper">SQL helper used to run the statements.</param>
/// <param name="values">Tuples of attribute key and owning entity id identifying the rows to delete.</param>
/// <returns>The value returned by <c>BulkExecuteNonQuery</c> (presumably the number of affected rows — confirm).</returns>
public static int DatabaseDeleteAttributes(SQLHelper helper, List<(string attribute, int id_entity)> values)
{
    const string sql = "DELETE FROM attribute WHERE id_entity=@id_entity AND attribute=@attribute";

    // Deliberately left as a lazy query so fresh parameter objects are produced on every enumeration.
    IEnumerable<object[]> parameterSets =
        from entry in values
        select new object[]
        {
            new MySqlParameter("@attribute", entry.attribute),
            new MySqlParameter("@id_entity", entry.id_entity)
        };

    return helper.BulkExecuteNonQuery(sql, parameterSets);
}
public static int DatabaseInsertDatapoints(SQLHelper helper, List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> values, int id_entity) public static int DatabaseInsertDatapoints(SQLHelper helper, List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> values, int id_entity)
{ {
return helper.BulkExecuteNonQuery( return helper.BulkExecuteNonQuery(
@@ -123,6 +147,38 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters); return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters);
} }
/// <summary>
/// Deletes the named datapoints of one entity together with their embeddings.
/// Two bulk statements are issued: first the embeddings, then the datapoints themselves.
/// </summary>
/// <param name="helper">SQL helper used to run the statements.</param>
/// <param name="values">Names of the datapoints to delete.</param>
/// <param name="id_entity">Id of the entity the datapoints belong to.</param>
/// <returns>
/// The results of the two <c>BulkExecuteNonQuery</c> calls, as (datapoints, embeddings)
/// (presumably affected-row counts — confirm).
/// </returns>
public static (int datapoints, int embeddings) DatabaseDeleteDatapoints(SQLHelper helper, List<string> values, int id_entity)
{
    // Both statements bind the same two parameters; a new lazy sequence is created for each call.
    IEnumerable<object[]> ParameterSets() =>
        values.Select(datapointName => new object[]
        {
            new MySqlParameter("@datapointName", datapointName),
            new MySqlParameter("@entityId", id_entity)
        });

    // Embeddings are removed first: they are located through a join on the datapoint rows,
    // so those rows must still exist at this point.
    int deletedEmbeddings = helper.BulkExecuteNonQuery(
        "DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId",
        ParameterSets()
    );

    int deletedDatapoints = helper.BulkExecuteNonQuery(
        "DELETE FROM datapoint WHERE name=@datapointName AND id_entity=@entityId",
        ParameterSets()
    );

    return (deletedDatapoints, deletedEmbeddings);
}
/// <summary>
/// Updates the probability method and similarity method of existing datapoints of one entity
/// in bulk: one UPDATE statement is executed per tuple via <c>helper.BulkExecuteNonQuery</c>.
/// </summary>
/// <param name="helper">SQL helper used to run the statements.</param>
/// <param name="values">Tuples of datapoint name, new probmethod string and new similarity method string.</param>
/// <param name="id_entity">Id of the entity the datapoints belong to.</param>
/// <returns>The value returned by <c>BulkExecuteNonQuery</c> (presumably the number of affected rows — confirm).</returns>
public static int DatabaseUpdateDatapoint(SQLHelper helper, List<(string name, string probmethod_embedding, string similarityMethod)> values, int id_entity)
{
    const string sql = "UPDATE datapoint SET probmethod_embedding=@probmethod, similaritymethod=@similaritymethod WHERE id_entity=@entityId AND name=@datapointName";

    // Deliberately left as a lazy query so fresh parameter objects are produced on every enumeration.
    IEnumerable<object[]> parameterSets =
        from entry in values
        select new object[]
        {
            new MySqlParameter("@probmethod", entry.probmethod_embedding),
            new MySqlParameter("@similaritymethod", entry.similarityMethod),
            new MySqlParameter("@entityId", id_entity),
            new MySqlParameter("@datapointName", entry.name)
        };

    return helper.BulkExecuteNonQuery(sql, parameterSets);
}
public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding) public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()

View File

@@ -87,14 +87,19 @@ public class SQLHelper:IDisposable
EnsureConnected(); EnsureConnected();
EnsureDbReaderIsClosed(); EnsureDbReaderIsClosed();
int affectedRows = 0;
int retries = 0;
while (retries <= 3)
{
try
{
using var transaction = connection.BeginTransaction(); using var transaction = connection.BeginTransaction();
using var command = connection.CreateCommand(); using var command = connection.CreateCommand();
command.CommandText = sql; command.CommandText = sql;
command.Transaction = transaction; command.Transaction = transaction;
int affectedRows = 0;
foreach (var parameters in parameterSets) foreach (var parameters in parameterSets)
{ {
command.Parameters.Clear(); command.Parameters.Clear();
@@ -103,6 +108,17 @@ public class SQLHelper:IDisposable
} }
transaction.Commit(); transaction.Commit();
break;
}
catch (Exception)
{
retries++;
if (retries > 3)
throw;
Thread.Sleep(10);
}
}
return affectedRows; return affectedRows;
} }
} }

View File

@@ -29,12 +29,12 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
return floatArray; return floatArray;
} }
public static bool CacheHasEntity(List<Entity> entityCache, string name) public static bool CacheHasEntity(ConcurrentBag<Entity> entityCache, string name)
{ {
return CacheGetEntity(entityCache, name) is not null; return CacheGetEntity(entityCache, name) is not null;
} }
public static Entity? CacheGetEntity(List<Entity> entityCache, string name) public static Entity? CacheGetEntity(ConcurrentBag<Entity> entityCache, string name)
{ {
foreach (Entity entity in entityCache) foreach (Entity entity in entityCache)
{ {
@@ -111,10 +111,14 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{ {
using SQLHelper helper = searchdomainManager.helper.DuplicateConnection(); using SQLHelper helper = searchdomainManager.helper.DuplicateConnection();
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain); Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
List<Entity> entityCache = searchdomain.entityCache; ConcurrentBag<Entity> entityCache = searchdomain.entityCache;
AIProvider aIProvider = searchdomain.aIProvider; AIProvider aIProvider = searchdomain.aIProvider;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); Entity? preexistingEntity;
lock (entityCache)
{
preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name);
}
bool invalidateSearchCache = false; bool invalidateSearchCache = false;
if (preexistingEntity is not null) if (preexistingEntity is not null)
@@ -127,7 +131,10 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
} }
Dictionary<string, string> attributes = jsonEntity.Attributes; Dictionary<string, string> attributes = jsonEntity.Attributes;
// Attribute // Attribute - get changes
List<(string attribute, string newValue, int entityId)> updatedAttributes = [];
List<(string attribute, int entityId)> deletedAttributes = [];
List<(string attributeKey, string attribute, int entityId)> addedAttributes = [];
foreach (KeyValuePair<string, string> attributesKV in preexistingEntity.attributes.ToList()) foreach (KeyValuePair<string, string> attributesKV in preexistingEntity.attributes.ToList())
{ {
string oldAttributeKey = attributesKV.Key; string oldAttributeKey = attributesKV.Key;
@@ -135,25 +142,10 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
bool newHasAttribute = jsonEntity.Attributes.TryGetValue(oldAttributeKey, out string? newAttribute); bool newHasAttribute = jsonEntity.Attributes.TryGetValue(oldAttributeKey, out string? newAttribute);
if (newHasAttribute && newAttribute is not null && newAttribute != oldAttribute) if (newHasAttribute && newAttribute is not null && newAttribute != oldAttribute)
{ {
// Attribute - Updated updatedAttributes.Add((attribute: oldAttributeKey, newValue: newAttribute, entityId: (int)preexistingEntityID));
Dictionary<string, dynamic> parameters = new()
{
{ "newValue", newAttribute },
{ "entityId", preexistingEntityID },
{ "attribute", oldAttributeKey}
};
helper.ExecuteSQLNonQuery("UPDATE attribute SET value=@newValue WHERE id_entity=@entityId AND attribute=@attribute", parameters);
preexistingEntity.attributes[oldAttributeKey] = newAttribute;
} else if (!newHasAttribute) } else if (!newHasAttribute)
{ {
// Attribute - Deleted deletedAttributes.Add((attribute: oldAttributeKey, entityId: (int)preexistingEntityID));
Dictionary<string, dynamic> parameters = new()
{
{ "entityId", preexistingEntityID },
{ "attribute", oldAttributeKey}
};
helper.ExecuteSQLNonQuery("DELETE FROM attribute WHERE id_entity=@entityId AND attribute=@attribute", parameters);
preexistingEntity.attributes.Remove(oldAttributeKey);
} }
} }
foreach (var attributesKV in jsonEntity.Attributes) foreach (var attributesKV in jsonEntity.Attributes)
@@ -164,12 +156,48 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
if (!preexistingHasAttribute) if (!preexistingHasAttribute)
{ {
// Attribute - New // Attribute - New
DatabaseHelper.DatabaseInsertAttribute(helper, newAttributeKey, newAttribute, (int)preexistingEntityID); addedAttributes.Add((attributeKey: newAttributeKey, attribute: newAttribute, entityId: (int)preexistingEntityID));
preexistingEntity.attributes.Add(newAttributeKey, newAttribute);
} }
} }
// Datapoint
// Attribute - apply changes
if (updatedAttributes.Count != 0)
{
// Update
DatabaseHelper.DatabaseUpdateAttributes(helper, updatedAttributes);
lock (preexistingEntity.attributes)
{
updatedAttributes.ForEach(attribute => preexistingEntity.attributes[attribute.attribute] = attribute.newValue);
}
}
if (deletedAttributes.Count != 0)
{
// Delete
DatabaseHelper.DatabaseDeleteAttributes(helper, deletedAttributes);
lock (preexistingEntity.attributes)
{
deletedAttributes.ForEach(attribute => preexistingEntity.attributes.Remove(attribute.attribute));
}
}
if (addedAttributes.Count != 0)
{
// Insert
DatabaseHelper.DatabaseInsertAttributes(helper, addedAttributes);
lock (preexistingEntity.attributes)
{
addedAttributes.ForEach(attribute => preexistingEntity.attributes.Add(attribute.attributeKey, attribute.attribute));
}
}
// Datapoint - get changes
List<Datapoint> deletedDatapointInstances = [];
List<string> deletedDatapoints = [];
List<(string datapointName, int entityId, JSONDatapoint jsonDatapoint, string hash)> updatedDatapointsText = [];
List<(string datapointName, string probMethod, string similarityMethod, int entityId, JSONDatapoint jsonDatapoint)> updatedDatapointsNonText = [];
List<Datapoint> createdDatapointInstances = [];
List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, Dictionary<string, float[]> embeddings, JSONDatapoint datapoint)> createdDatapoints = [];
foreach (Datapoint datapoint_ in preexistingEntity.datapoints.ToList()) foreach (Datapoint datapoint_ in preexistingEntity.datapoints.ToList())
{ {
Datapoint datapoint = datapoint_; // To enable replacing the datapoint reference as foreach iterators cannot be overwritten Datapoint datapoint = datapoint_; // To enable replacing the datapoint reference as foreach iterators cannot be overwritten
@@ -177,48 +205,43 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
if (!newEntityHasDatapoint) if (!newEntityHasDatapoint)
{ {
// Datapoint - Deleted // Datapoint - Deleted
Dictionary<string, dynamic> parameters = new() deletedDatapointInstances.Add(datapoint);
{ deletedDatapoints.Add(datapoint.name);
{ "datapointName", datapoint.name },
{ "entityId", preexistingEntityID}
};
helper.ExecuteSQLNonQuery("DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", parameters);
helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity=@entityId AND name=@datapointName", parameters);
preexistingEntity.datapoints.Remove(datapoint);
invalidateSearchCache = true; invalidateSearchCache = true;
} else } else
{ {
JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name); JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name);
if (newEntityDatapoint is not null && newEntityDatapoint.Text is not null) string? hash = newEntityDatapoint?.Text is not null ? GetHash(newEntityDatapoint) : null;
if (
newEntityDatapoint is not null
&& newEntityDatapoint.Text is not null
&& hash is not null
&& hash != datapoint.hash)
{ {
// Datapoint - Updated (text) // Datapoint - Updated (text)
Dictionary<string, dynamic> parameters = new() updatedDatapointsText.Add(new()
{ {
{ "datapointName", datapoint.name }, datapointName = newEntityDatapoint.Name,
{ "entityId", preexistingEntityID} entityId = (int)preexistingEntityID,
}; jsonDatapoint = newEntityDatapoint,
helper.ExecuteSQLNonQuery("DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", parameters); hash = hash
helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity=@entityId AND name=@datapointName", parameters); });
preexistingEntity.datapoints.Remove(datapoint);
Datapoint newDatapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, newEntityDatapoint, (int)preexistingEntityID);
preexistingEntity.datapoints.Add(newDatapoint);
datapoint = newDatapoint;
invalidateSearchCache = true; invalidateSearchCache = true;
} }
if (newEntityDatapoint is not null && (newEntityDatapoint.Probmethod_embedding != datapoint.probMethod.probMethodEnum || newEntityDatapoint.SimilarityMethod != datapoint.similarityMethod.similarityMethodEnum)) if (
newEntityDatapoint is not null
&& (newEntityDatapoint.Probmethod_embedding != datapoint.probMethod.probMethodEnum
|| newEntityDatapoint.SimilarityMethod != datapoint.similarityMethod.similarityMethodEnum))
{ {
// Datapoint - Updated (probmethod or similaritymethod) // Datapoint - Updated (probmethod or similaritymethod)
Dictionary<string, dynamic> parameters = new() updatedDatapointsNonText.Add(new()
{ {
{ "probmethod", newEntityDatapoint.Probmethod_embedding.ToString() }, datapointName = newEntityDatapoint.Name,
{ "similaritymethod", newEntityDatapoint.SimilarityMethod.ToString() }, entityId = (int)preexistingEntityID,
{ "datapointName", datapoint.name }, probMethod = newEntityDatapoint.Probmethod_embedding.ToString(),
{ "entityId", preexistingEntityID} similarityMethod = newEntityDatapoint.SimilarityMethod.ToString(),
}; jsonDatapoint = newEntityDatapoint
helper.ExecuteSQLNonQuery("UPDATE datapoint SET probmethod_embedding=@probmethod, similaritymethod=@similaritymethod WHERE id_entity=@entityId AND name=@datapointName", parameters); });
Datapoint preexistingDatapoint = preexistingEntity.datapoints.First(x => x == datapoint); // The for loop is a copy. This retrieves the original such that it can be updated.
preexistingDatapoint.probMethod = new(newEntityDatapoint.Probmethod_embedding, _logger);
preexistingDatapoint.similarityMethod = new(newEntityDatapoint.SimilarityMethod, _logger);
invalidateSearchCache = true; invalidateSearchCache = true;
} }
} }
@@ -229,12 +252,67 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
if (!oldEntityHasDatapoint) if (!oldEntityHasDatapoint)
{ {
// Datapoint - New // Datapoint - New
Datapoint datapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, jsonDatapoint, (int)preexistingEntityID); createdDatapoints.Add(new()
preexistingEntity.datapoints.Add(datapoint); {
name = jsonDatapoint.Name,
probmethod_embedding = jsonDatapoint.Probmethod_embedding,
similarityMethod = jsonDatapoint.SimilarityMethod,
hash = GetHash(jsonDatapoint),
embeddings = Datapoint.GetEmbeddings(
jsonDatapoint.Text ?? throw new Exception("jsonDatapoint.Text must not be null when retrieving embeddings"),
[.. jsonDatapoint.Model],
aIProvider,
embeddingCache
),
datapoint = jsonDatapoint
});
invalidateSearchCache = true; invalidateSearchCache = true;
} }
} }
// Datapoint - apply changes
// Deleted
if (deletedDatapointInstances.Count != 0)
{
DatabaseHelper.DatabaseDeleteDatapoints(helper, deletedDatapoints, (int)preexistingEntityID);
deletedDatapointInstances.ForEach(datapoint => preexistingEntity.datapoints.Remove(datapoint));
}
// Created
if (createdDatapoints.Count != 0)
{
List<Datapoint> datapoint = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. createdDatapoints.Select(element => (element.datapoint, element.hash))], (int)preexistingEntityID);
createdDatapoints.ForEach(datapoint => preexistingEntity.datapoints.Add(new(
datapoint.name,
datapoint.probmethod_embedding,
datapoint.similarityMethod,
datapoint.hash,
[.. datapoint.embeddings.Select(element => (element.Key, element.Value))])
));
}
// Datapoint - Updated (text)
if (updatedDatapointsText.Count != 0)
{
DatabaseHelper.DatabaseDeleteDatapoints(helper, [.. updatedDatapointsText.Select(datapoint => datapoint.datapointName)], (int)preexistingEntityID);
updatedDatapointsText.ForEach(datapoint => preexistingEntity.datapoints.RemoveAll(x => x.name == datapoint.datapointName));
List<Datapoint> datapoints = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. updatedDatapointsText.Select(element => (datapoint: element.jsonDatapoint, hash: element.hash))], (int)preexistingEntityID);
preexistingEntity.datapoints.AddRange(datapoints);
}
// Datapoint - Updated (probmethod or similaritymethod)
if (updatedDatapointsNonText.Count != 0)
{
DatabaseHelper.DatabaseUpdateDatapoint(
helper,
[.. updatedDatapointsNonText.Select(element => (element.datapointName, element.probMethod, element.similarityMethod))],
(int)preexistingEntityID
);
updatedDatapointsNonText.ForEach(element =>
{
Datapoint preexistingDatapoint = preexistingEntity.datapoints.First(x => x.name == element.datapointName);
preexistingDatapoint.probMethod = new(element.jsonDatapoint.Probmethod_embedding);
preexistingDatapoint.similarityMethod = new(element.jsonDatapoint.SimilarityMethod);
});
}
if (invalidateSearchCache) if (invalidateSearchCache)
{ {
searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity); searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity);
@@ -256,7 +334,6 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
} }
DatabaseHelper.DatabaseInsertAttributes(helper, toBeInsertedAttributes); DatabaseHelper.DatabaseInsertAttributes(helper, toBeInsertedAttributes);
List<Datapoint> datapoints = [];
List<(JSONDatapoint datapoint, string hash)> toBeInsertedDatapoints = []; List<(JSONDatapoint datapoint, string hash)> toBeInsertedDatapoints = [];
foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints) foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints)
{ {
@@ -267,7 +344,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
hash = hash hash = hash
}); });
} }
List<Datapoint> datapoint = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, toBeInsertedDatapoints, id_entity); List<Datapoint> datapoints = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, toBeInsertedDatapoints, id_entity);
var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new ProbMethodNotFoundException(jsonEntity.Probmethod); var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new ProbMethodNotFoundException(jsonEntity.Probmethod);
Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), datapoints, jsonEntity.Name) Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), datapoints, jsonEntity.Name)
@@ -285,7 +362,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{ {
List<Datapoint> result = []; List<Datapoint> result = [];
List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> toBeInsertedDatapoints = []; List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> toBeInsertedDatapoints = [];
List<(string hash, string model, byte[] embedding)> toBeInsertedEmbeddings = []; List<(string name, string model, byte[] embedding)> toBeInsertedEmbeddings = [];
foreach ((JSONDatapoint datapoint, string hash) value in values) foreach ((JSONDatapoint datapoint, string hash) value in values)
{ {
Datapoint datapoint = BuildDatapointFromJsonDatapoint(value.datapoint, id_entity, searchdomain, value.hash); Datapoint datapoint = BuildDatapointFromJsonDatapoint(value.datapoint, id_entity, searchdomain, value.hash);
@@ -300,7 +377,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{ {
toBeInsertedEmbeddings.Add(new() toBeInsertedEmbeddings.Add(new()
{ {
hash = value.hash, name = datapoint.name,
model = embedding.Item1, model = embedding.Item1,
embedding = BytesFromFloatArray(embedding.Item2) embedding = BytesFromFloatArray(embedding.Item2)
}); });
@@ -308,8 +385,8 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
result.Add(datapoint); result.Add(datapoint);
} }
DatabaseHelper.DatabaseInsertDatapoints(helper, toBeInsertedDatapoints, id_entity); int insertedDatapoints = DatabaseHelper.DatabaseInsertDatapoints(helper, toBeInsertedDatapoints, id_entity);
DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, toBeInsertedEmbeddings); int insertedEmbeddings = DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, toBeInsertedEmbeddings, id_entity);
return result; return result;
} }
@@ -319,7 +396,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{ {
throw new Exception("jsonDatapoint.Text must not be null at this point"); throw new Exception("jsonDatapoint.Text must not be null at this point");
} }
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); hash ??= GetHash(jsonDatapoint);
Datapoint datapoint = BuildDatapointFromJsonDatapoint(jsonDatapoint, id_entity, searchdomain, hash); Datapoint datapoint = BuildDatapointFromJsonDatapoint(jsonDatapoint, id_entity, searchdomain, hash);
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
List<(string model, byte[] embedding)> data = []; List<(string model, byte[] embedding)> data = [];
@@ -331,6 +408,11 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
return datapoint; return datapoint;
} }
/// <summary>
/// Computes the content hash of a datapoint: the Base64 encoding of the SHA-256 digest
/// of the UTF-8 bytes of <c>jsonDatapoint.Text</c>.
/// </summary>
/// <param name="jsonDatapoint">Datapoint whose text is hashed.</param>
/// <returns>Base64 string of the SHA-256 digest.</returns>
/// <exception cref="Exception">Thrown when <c>jsonDatapoint.Text</c> is null.</exception>
public string GetHash(JSONDatapoint jsonDatapoint)
{
    // Read the text once so the null check and the hashing see the same value.
    string? text = jsonDatapoint.Text;
    if (text is null)
    {
        throw new Exception("jsonDatapoint.Text must not be null to compute hash");
    }

    byte[] utf8Text = Encoding.UTF8.GetBytes(text);
    byte[] digest = SHA256.HashData(utf8Text);
    return Convert.ToBase64String(digest);
}
public Datapoint BuildDatapointFromJsonDatapoint(JSONDatapoint jsonDatapoint, int entityId, Searchdomain searchdomain, string? hash = null) public Datapoint BuildDatapointFromJsonDatapoint(JSONDatapoint jsonDatapoint, int entityId, Searchdomain searchdomain, string? hash = null)
{ {
if (jsonDatapoint.Text is null) if (jsonDatapoint.Text is null)
@@ -342,8 +424,8 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding); var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding);
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod); var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod);
return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
} }

View File

@@ -12,6 +12,7 @@ public class EmbeddingSearchOptions : ApiKeyOptions
public required SimpleAuthOptions SimpleAuth { get; set; } public required SimpleAuthOptions SimpleAuth { get; set; }
public required CacheOptions Cache { get; set; } public required CacheOptions Cache { get; set; }
public required bool UseHttpsRedirection { get; set; } public required bool UseHttpsRedirection { get; set; }
public int? MaxRequestBodySize { get; set; }
} }
public class AiProvider public class AiProvider

View File

@@ -10,16 +10,11 @@ public class ProbMethod
public ProbMethodEnum probMethodEnum; public ProbMethodEnum probMethodEnum;
public string name; public string name;
public ProbMethod(ProbMethodEnum probMethodEnum, ILogger logger) public ProbMethod(ProbMethodEnum probMethodEnum)
{ {
this.probMethodEnum = probMethodEnum; this.probMethodEnum = probMethodEnum;
this.name = probMethodEnum.ToString(); this.name = probMethodEnum.ToString();
Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name); Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name) ?? throw new ProbMethodNotFoundException(probMethodEnum);
if (probMethod is null)
{
logger.LogError("Unable to retrieve probMethod {name}", [name]);
throw new ProbMethodNotFoundException(probMethodEnum);
}
method = probMethod; method = probMethod;
} }
} }

View File

@@ -7,6 +7,7 @@ using Server.Helper;
using Shared; using Shared;
using Shared.Models; using Shared.Models;
using AdaptiveExpressions; using AdaptiveExpressions;
using System.Collections.Concurrent;
namespace Server; namespace Server;
@@ -19,7 +20,7 @@ public class Searchdomain
public int id; public int id;
public SearchdomainSettings settings; public SearchdomainSettings settings;
public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp) public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp)
public List<Entity> entityCache; public ConcurrentBag<Entity> entityCache;
public List<string> modelsInUse; public List<string> modelsInUse;
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache; public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
private readonly MySqlConnection connection; private readonly MySqlConnection connection;
@@ -105,8 +106,8 @@ public class Searchdomain
typeof(SimilarityMethodEnum), typeof(SimilarityMethodEnum),
similarityMethodString similarityMethodString
); );
ProbMethod probmethod = new(probmethodEnum, _logger); ProbMethod probmethod = new(probmethodEnum);
SimilarityMethod similarityMethod = new(similairtyMethodEnum, _logger); SimilarityMethod similarityMethod = new(similairtyMethodEnum);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null) if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{ {
embedding_unassigned.Remove(id); embedding_unassigned.Remove(id);
@@ -173,7 +174,6 @@ public class Searchdomain
Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query); Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query);
List<(float, string)> result = []; List<(float, string)> result = [];
foreach (Entity entity in entityCache) foreach (Entity entity in entityCache)
{ {
result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.name)); result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.name));
@@ -219,7 +219,10 @@ public class Searchdomain
public void UpdateModelsInUse() public void UpdateModelsInUse()
{ {
modelsInUse = GetModels(entityCache.ToList()); lock (modelsInUse)
{
modelsInUse = GetModels(entityCache);
}
} }
private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary<string, float[]> queryEmbeddings) private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary<string, float[]> queryEmbeddings)
@@ -240,14 +243,16 @@ public class Searchdomain
return entity.probMethod(datapointProbs); return entity.probMethod(datapointProbs);
} }
public static List<string> GetModels(List<Entity> entities) public static List<string> GetModels(ConcurrentBag<Entity> entities)
{ {
List<string> result = []; List<string> result = [];
lock (entities)
{
foreach (Entity entity in entities) foreach (Entity entity in entities)
{
lock (entity)
{ {
foreach (Datapoint datapoint in entity.datapoints) foreach (Datapoint datapoint in entity.datapoints)
{
lock (entity.datapoints)
{ {
foreach ((string, float[]) tuple in datapoint.embeddings) foreach ((string, float[]) tuple in datapoint.embeddings)
{ {
@@ -260,6 +265,7 @@ public class Searchdomain
} }
} }
} }
}
return result; return result;
} }

View File

@@ -9,16 +9,11 @@ public class SimilarityMethod
public SimilarityMethodEnum similarityMethodEnum; public SimilarityMethodEnum similarityMethodEnum;
public string name; public string name;
public SimilarityMethod(SimilarityMethodEnum similarityMethodEnum, ILogger logger) public SimilarityMethod(SimilarityMethodEnum similarityMethodEnum)
{ {
this.similarityMethodEnum = similarityMethodEnum; this.similarityMethodEnum = similarityMethodEnum;
this.name = similarityMethodEnum.ToString(); this.name = similarityMethodEnum.ToString();
SimilarityMethods.similarityMethodDelegate? probMethod = SimilarityMethods.GetMethod(name); SimilarityMethods.similarityMethodDelegate? probMethod = SimilarityMethods.GetMethod(name) ?? throw new Exception($"Unable to retrieve similarityMethod {name}");
if (probMethod is null)
{
logger.LogError("Unable to retrieve similarityMethod {name}", [name]);
throw new Exception("Unable to retrieve similarityMethod");
}
method = probMethod; method = probMethod;
} }
} }