diff --git a/src/Server/Controllers/EntityController.cs b/src/Server/Controllers/EntityController.cs index cc72a33..d103dd1 100644 --- a/src/Server/Controllers/EntityController.cs +++ b/src/Server/Controllers/EntityController.cs @@ -46,7 +46,7 @@ public class EntityController : ControllerBase (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); EntityListResults entityListResults = new() {Results = [], Success = true}; - foreach (Entity entity in searchdomain_.entityCache) + foreach ((string _, Entity entity) in searchdomain_.entityCache) { List attributeResults = []; foreach (KeyValuePair attribute in entity.attributes) @@ -90,11 +90,11 @@ public class EntityController : ControllerBase /// /// Entities to index [HttpPut("/Entities")] - public ActionResult Index([FromBody] List? jsonEntities) + public async Task> Index([FromBody] List? jsonEntities) { try { - List? entities = _searchdomainHelper.EntitiesFromJSON( + List? entities = await _searchdomainHelper.EntitiesFromJSON( _domainManager, _logger, JsonSerializer.Serialize(jsonEntities)); @@ -135,7 +135,7 @@ public class EntityController : ControllerBase /// Name of the searchdomain /// Name of the entity [HttpDelete] - public ActionResult Delete(string searchdomain, string entityName) + public async Task> Delete(string searchdomain, string entityName) { (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); @@ -152,9 +152,10 @@ public class EntityController : ControllerBase return Ok(new EntityDeleteResults() {Success = false, Message = "Entity not found"}); } searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_); - _databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain); - Entity toBeRemoved = searchdomain_.entityCache.First(entity => entity.name == entityName); - searchdomain_.entityCache = [.. searchdomain_.entityCache.Except([toBeRemoved])]; - return Ok(new EntityDeleteResults() {Success = true}); + await _databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain); + + bool success = searchdomain_.entityCache.TryRemove(entityName, out Entity? _); + + return Ok(new EntityDeleteResults() {Success = success}); } } diff --git a/src/Server/Controllers/Frontend/HomeController.cs b/src/Server/Controllers/Frontend/HomeController.cs index a54bd56..f6aa77c 100644 --- a/src/Server/Controllers/Frontend/HomeController.cs +++ b/src/Server/Controllers/Frontend/HomeController.cs @@ -35,11 +35,11 @@ public class HomeController : Controller [Authorize] [HttpGet("Searchdomains")] - public IActionResult Searchdomains() + public async Task Searchdomains() { HomeIndexViewModel viewModel = new() { - Searchdomains = _domainManager.ListSearchdomains() + Searchdomains = await _domainManager.ListSearchdomainsAsync() }; return View(viewModel); } diff --git a/src/Server/Controllers/SearchdomainController.cs b/src/Server/Controllers/SearchdomainController.cs index ffd1e8e..57228e1 100644 --- a/src/Server/Controllers/SearchdomainController.cs +++ b/src/Server/Controllers/SearchdomainController.cs @@ -29,12 +29,12 @@ public class SearchdomainController : ControllerBase /// Lists all searchdomains /// [HttpGet("/Searchdomains")] - public ActionResult List() + public async Task> List() { List results; try { - results = _domainManager.ListSearchdomains(); + results = await _domainManager.ListSearchdomainsAsync(); } catch (Exception) { @@ -51,7 +51,7 @@ public class SearchdomainController : ControllerBase /// Name of the searchdomain /// Optional initial settings [HttpPost] - public ActionResult Create([Required]string searchdomain, [FromBody]SearchdomainSettings settings = new()) + public async Task> Create([Required]string searchdomain, [FromBody]SearchdomainSettings settings = new()) { try { @@ -59,7 +59,7 @@ public class SearchdomainController : ControllerBase { settings.QueryCacheSize = 1_000_000; // TODO get rid of this magic number } - int id = _domainManager.CreateSearchdomain(searchdomain, settings); + int id = await _domainManager.CreateSearchdomain(searchdomain, settings); return Ok(new SearchdomainCreateResults(){Id = id, Success = true}); } catch (Exception) { @@ -73,7 +73,7 @@ public class SearchdomainController : ControllerBase /// /// Name of the searchdomain [HttpDelete] - public ActionResult Delete([Required]string searchdomain) + public async Task> Delete([Required]string searchdomain) { bool success; int deletedEntries; @@ -81,7 +81,7 @@ public class SearchdomainController : ControllerBase try { success = true; - deletedEntries = _domainManager.DeleteSearchdomain(searchdomain); + deletedEntries = await _domainManager.DeleteSearchdomain(searchdomain); } catch (SearchdomainNotFoundException ex) { @@ -165,7 +165,7 @@ public class SearchdomainController : ControllerBase { Name = r.Item2, Value = r.Item1, - Attributes = returnAttributes ? (searchdomain_.entityCache.FirstOrDefault(x => x.name == r.Item2)?.attributes ?? null) : null + Attributes = returnAttributes ? (searchdomain_.entityCache[r.Item2]?.attributes ?? null) : null })]; return Ok(new EntityQueryResults(){Results = queryResults, Success = true }); } diff --git a/src/Server/Controllers/ServerController.cs b/src/Server/Controllers/ServerController.cs index e333aa3..35a0e01 100644 --- a/src/Server/Controllers/ServerController.cs +++ b/src/Server/Controllers/ServerController.cs @@ -75,7 +75,7 @@ public class ServerController : ControllerBase long queryCacheElementCount = 0; long queryCacheMaxElementCountAll = 0; long queryCacheMaxElementCountLoadedSearchdomainsOnly = 0; - foreach (string searchdomain in _searchdomainManager.ListSearchdomains()) + foreach (string searchdomain in await _searchdomainManager.ListSearchdomainsAsync()) { if (SearchdomainHelper.IsSearchdomainLoaded(_searchdomainManager, searchdomain)) { diff --git a/src/Server/Datapoint.cs b/src/Server/Datapoint.cs index 8240252..91f3062 100644 --- a/src/Server/Datapoint.cs +++ b/src/Server/Datapoint.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using Shared; using Shared.Models; @@ -10,23 +11,26 @@ public class Datapoint public SimilarityMethod similarityMethod; public List<(string, float[])> embeddings; public string hash; + public int id; - public Datapoint(string name, ProbMethodEnum probMethod, SimilarityMethodEnum similarityMethod, string hash, List<(string, float[])> embeddings) + public Datapoint(string name, ProbMethodEnum probMethod, SimilarityMethodEnum similarityMethod, string hash, List<(string, float[])> embeddings, int id) { this.name = name; this.probMethod = new ProbMethod(probMethod); this.similarityMethod = new SimilarityMethod(similarityMethod); this.hash = hash; this.embeddings = embeddings; + this.id = id; } - public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings) + public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings, int id) { this.name = name; this.probMethod = probMethod; this.similarityMethod = similarityMethod; this.hash = hash; this.embeddings = embeddings; + this.id = id; } public float CalcProbability(List<(string, float)> probabilities) @@ -34,18 +38,19 @@ public class Datapoint return probMethod.method(probabilities); } - public static Dictionary GetEmbeddings(string content, List models, AIProvider aIProvider, EnumerableLruCache> embeddingCache) + public static Dictionary GetEmbeddings(string content, ConcurrentBag models, AIProvider aIProvider, EnumerableLruCache> embeddingCache) { Dictionary embeddings = []; bool embeddingCacheHasContent = embeddingCache.TryGetValue(content, out var embeddingCacheForContent); if (!embeddingCacheHasContent || embeddingCacheForContent is null) { - models.ForEach(model => - embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache) - ); + foreach (string model in models) + { + embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache); + } return embeddings; } - models.ForEach(model => + foreach (string model in models) { bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embeddingCacheForModel); if (embeddingCacheHasModel && embeddingCacheForModel is not null) @@ -55,7 +60,7 @@ public class Datapoint { embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache); } - }); + } return embeddings; } diff --git a/src/Server/Entity.cs b/src/Server/Entity.cs index 681b96b..ad5cb1c 100644 --- a/src/Server/Entity.cs +++ b/src/Server/Entity.cs @@ -1,11 +1,13 @@ +using System.Collections.Concurrent; + namespace Server; -public class Entity(Dictionary attributes, Probmethods.probMethodDelegate probMethod, string probMethodName, List datapoints, string name) +public class Entity(Dictionary attributes, Probmethods.probMethodDelegate probMethod, string probMethodName, ConcurrentBag datapoints, string name) { public Dictionary attributes = attributes; public Probmethods.probMethodDelegate probMethod = probMethod; public string probMethodName = probMethodName; - public List datapoints = datapoints; + public ConcurrentBag datapoints = datapoints; public int id; public string name = name; } \ No newline at end of file diff --git a/src/Server/Helper/DatabaseHelper.cs b/src/Server/Helper/DatabaseHelper.cs index ed0d5ab..d1113e0 100644 --- a/src/Server/Helper/DatabaseHelper.cs +++ b/src/Server/Helper/DatabaseHelper.cs @@ -20,11 +20,13 @@ public class DatabaseHelper(ILogger logger) return new SQLHelper(connection, connectionString); } - public static void DatabaseInsertEmbeddingBulk(SQLHelper helper, int id_datapoint, List<(string model, byte[] embedding)> data) + public static async Task DatabaseInsertEmbeddingBulk(SQLHelper helper, int id_datapoint, List<(string model, byte[] embedding)> data, int id_entity, int id_searchdomain) { Dictionary parameters = []; parameters["id_datapoint"] = id_datapoint; - var query = new StringBuilder("INSERT INTO embedding (id_datapoint, model, embedding) VALUES "); + parameters["id_entity"] = id_entity; + parameters["id_searchdomain"] = id_searchdomain; + var query = new StringBuilder("INSERT INTO embedding (id_datapoint, model, embedding, id_embedding, id_searchdomain) VALUES "); foreach (var (model, embedding) in data) { string modelParam = $"model_{Guid.NewGuid()}".Replace("-", ""); @@ -32,38 +34,39 @@ public class DatabaseHelper(ILogger logger) parameters[modelParam] = model; parameters[embeddingParam] = embedding; - query.Append($"(@id_datapoint, @{modelParam}, @{embeddingParam}), "); + query.Append($"(@id_datapoint, @{modelParam}, @{embeddingParam}, @id_entity), "); } query.Length -= 2; // remove trailing comma - helper.ExecuteSQLNonQuery(query.ToString(), parameters); + await helper.ExecuteSQLNonQuery(query.ToString(), parameters); } - public static int DatabaseInsertEmbeddingBulk(SQLHelper helper, List<(string name, string model, byte[] embedding)> data, int id_entity) + public static async Task DatabaseInsertEmbeddingBulk(SQLHelper helper, List<(int id_datapoint, string model, byte[] embedding)> data, int id_entity, int id_searchdomain) { - return helper.BulkExecuteNonQuery( - "INSERT INTO embedding (id_datapoint, model, embedding) SELECT d.id, @model, @embedding FROM datapoint d WHERE d.name = @name AND d.id_entity = @id_entity ORDER BY d.id LIMIT 1", // TODO: fix limitation - entity must not have 2 datapoints with the same content, i.e. hash + return await helper.BulkExecuteNonQuery( + "INSERT INTO embedding (id_datapoint, model, embedding, id_entity, id_searchdomain) VALUES (@id_datapoint, @model, @embedding, @id_entity, @id_searchdomain);", data.Select(element => new object[] { new MySqlParameter("@model", element.model), new MySqlParameter("@embedding", element.embedding), - new MySqlParameter("@name", element.name), - new MySqlParameter("@id_entity", id_entity) + new MySqlParameter("@id_datapoint", element.id_datapoint), + new MySqlParameter("@id_entity", id_entity), + new MySqlParameter("@id_searchdomain", id_searchdomain) }) ); } - public static int DatabaseInsertSearchdomain(SQLHelper helper, string name, SearchdomainSettings settings = new()) + public static async Task DatabaseInsertSearchdomain(SQLHelper helper, string name, SearchdomainSettings settings = new()) { Dictionary parameters = new() { { "name", name }, { "settings", settings} }; - return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters); + return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters); } - public static int DatabaseInsertEntity(SQLHelper helper, string name, ProbMethodEnum probmethod, int id_searchdomain) + public static async Task DatabaseInsertEntity(SQLHelper helper, string name, ProbMethodEnum probmethod, int id_searchdomain) { Dictionary parameters = new() { @@ -71,24 +74,13 @@ public class DatabaseHelper(ILogger logger) { "probmethod", probmethod.ToString() }, { "id_searchdomain", id_searchdomain } }; - return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO entity (name, probmethod, id_searchdomain) VALUES (@name, @probmethod, @id_searchdomain)", parameters); + return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO entity (name, probmethod, id_searchdomain) VALUES (@name, @probmethod, @id_searchdomain);", parameters); } - public static int DatabaseInsertAttribute(SQLHelper helper, string attribute, string value, int id_entity) + public static async Task DatabaseInsertAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values) //string[] attribute, string value, int id_entity) { - Dictionary parameters = new() - { - { "attribute", attribute }, - { "value", value }, - { "id_entity", id_entity } - }; - return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", parameters); - } - - public static int DatabaseInsertAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values) //string[] attribute, string value, int id_entity) - { - return helper.BulkExecuteNonQuery( - "INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", + return await helper.BulkExecuteNonQuery( + "INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity);", values.Select(element => new object[] { new MySqlParameter("@attribute", element.attribute), new MySqlParameter("@value", element.value), @@ -97,9 +89,9 @@ public class DatabaseHelper(ILogger logger) ); } - public static int DatabaseUpdateAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values) + public static async Task DatabaseUpdateAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values) { - return helper.BulkExecuteNonQuery( + return await helper.BulkExecuteNonQuery( "UPDATE attribute SET value=@value WHERE id_entity=@id_entity AND attribute=@attribute", values.Select(element => new object[] { new MySqlParameter("@attribute", element.attribute), @@ -109,9 +101,9 @@ public class DatabaseHelper(ILogger logger) ); } - public static int DatabaseDeleteAttributes(SQLHelper helper, List<(string attribute, int id_entity)> values) + public static async Task DatabaseDeleteAttributes(SQLHelper helper, List<(string attribute, int id_entity)> values) { - return helper.BulkExecuteNonQuery( + return await helper.BulkExecuteNonQuery( "DELETE FROM attribute WHERE id_entity=@id_entity AND attribute=@attribute", values.Select(element => new object[] { new MySqlParameter("@attribute", element.attribute), @@ -120,10 +112,10 @@ public class DatabaseHelper(ILogger logger) ); } - public static int DatabaseInsertDatapoints(SQLHelper helper, List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> values, int id_entity) + public static async Task DatabaseInsertDatapoints(SQLHelper helper, List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> values, int id_entity) { - return helper.BulkExecuteNonQuery( - "INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", + return await helper.BulkExecuteNonQuery( + "INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity);", values.Select(element => new object[] { new MySqlParameter("@name", element.name), new MySqlParameter("@probmethod_embedding", element.probmethod_embedding), @@ -134,7 +126,7 @@ public class DatabaseHelper(ILogger logger) ); } - public static int DatabaseInsertDatapoint(SQLHelper helper, string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, int id_entity) + public static async Task DatabaseInsertDatapoint(SQLHelper helper, string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, int id_entity) { Dictionary parameters = new() { @@ -144,19 +136,19 @@ public class DatabaseHelper(ILogger logger) { "hash", hash }, { "id_entity", id_entity } }; - return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters); + return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters); } - public static (int datapoints, int embeddings) DatabaseDeleteDatapoints(SQLHelper helper, List values, int id_entity) + public static async Task<(int datapoints, int embeddings)> DatabaseDeleteEmbeddingsAndDatapoints(SQLHelper helper, List values, int id_entity) { - int embeddings = helper.BulkExecuteNonQuery( - "DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", + int embeddings = await helper.BulkExecuteNonQuery( + "DELETE e FROM embedding e WHERE id_entity = @entityId", values.Select(element => new object[] { new MySqlParameter("@datapointName", element), new MySqlParameter("@entityId", id_entity) }) ); - int datapoints = helper.BulkExecuteNonQuery( + int datapoints = await helper.BulkExecuteNonQuery( "DELETE FROM datapoint WHERE name=@datapointName AND id_entity=@entityId", values.Select(element => new object[] { new MySqlParameter("@datapointName", element), @@ -166,9 +158,9 @@ public class DatabaseHelper(ILogger logger) return (datapoints: datapoints, embeddings: embeddings); } - public static int DatabaseUpdateDatapoint(SQLHelper helper, List<(string name, string probmethod_embedding, string similarityMethod)> values, int id_entity) + public static async Task DatabaseUpdateDatapoint(SQLHelper helper, List<(string name, string probmethod_embedding, string similarityMethod)> values, int id_entity) { - return helper.BulkExecuteNonQuery( + return await helper.BulkExecuteNonQuery( "UPDATE datapoint SET probmethod_embedding=@probmethod, similaritymethod=@similaritymethod WHERE id_entity=@entityId AND name=@datapointName", values.Select(element => new object[] { new MySqlParameter("@probmethod", element.probmethod_embedding), @@ -179,108 +171,120 @@ public class DatabaseHelper(ILogger logger) ); } - public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding) + public static async Task DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding, int id_entity, int id_searchdomain) { Dictionary parameters = new() { { "id_datapoint", id_datapoint }, { "model", model }, - { "embedding", embedding } + { "embedding", embedding }, + { "id_entity", id_entity }, + { "id_searchdomain", id_searchdomain } }; - return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO embedding (id_datapoint, model, embedding) VALUES (@id_datapoint, @model, @embedding)", parameters); + return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO embedding (id_datapoint, model, embedding, id_entity, id_searchdomain) VALUES (@id_datapoint, @model, @embedding, @id_entity, @id_searchdomain)", parameters); } - public int GetSearchdomainID(SQLHelper helper, string searchdomain) + public async Task GetSearchdomainID(SQLHelper helper, string searchdomain) { - Dictionary parameters = new() + Dictionary parameters = new() { { "searchdomain", searchdomain} }; - lock (helper.connection) - { - DbDataReader reader = helper.ExecuteSQLCommand("SELECT id FROM searchdomain WHERE name = @searchdomain", parameters); - bool success = reader.Read(); - int result = success ? reader.GetInt32(0) : 0; - reader.Close(); - if (success) - { - return result; - } - else - { - _logger.LogError("Unable to retrieve searchdomain ID for {searchdomain}", [searchdomain]); - throw new SearchdomainNotFoundException(searchdomain); - } - } + return (await helper.ExecuteQueryAsync("SELECT id FROM searchdomain WHERE name = @searchdomain", parameters, x => x.GetInt32(0))).First(); } - public void RemoveEntity(List entityCache, SQLHelper helper, string name, string searchdomain) + public async Task RemoveEntity(List entityCache, SQLHelper helper, string name, string searchdomain) { Dictionary parameters = new() { { "name", name }, - { "searchdomain", GetSearchdomainID(helper, searchdomain)} + { "searchdomain", await GetSearchdomainID(helper, searchdomain)} }; - helper.ExecuteSQLNonQuery("DELETE embedding.* FROM embedding JOIN datapoint dp ON id_datapoint = dp.id JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); - helper.ExecuteSQLNonQuery("DELETE datapoint.* FROM datapoint JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); - helper.ExecuteSQLNonQuery("DELETE attribute.* FROM attribute JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); - helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE name = @name AND entity.id_searchdomain = @searchdomain", parameters); + await helper.ExecuteSQLNonQuery("DELETE embedding.* FROM embedding JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); + await helper.ExecuteSQLNonQuery("DELETE datapoint.* FROM datapoint JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); + await helper.ExecuteSQLNonQuery("DELETE attribute.* FROM attribute JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); + await helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE name = @name AND entity.id_searchdomain = @searchdomain", parameters); entityCache.RemoveAll(entity => entity.name == name); } - public int RemoveAllEntities(SQLHelper helper, string searchdomain) + public async Task RemoveAllEntities(SQLHelper helper, string searchdomain) { Dictionary parameters = new() { - { "searchdomain", GetSearchdomainID(helper, searchdomain)} + { "searchdomain", await GetSearchdomainID(helper, searchdomain)} }; - - helper.ExecuteSQLNonQuery("DELETE embedding.* FROM embedding JOIN datapoint dp ON id_datapoint = dp.id JOIN entity ON id_entity = entity.id WHERE entity.id_searchdomain = @searchdomain", parameters); - helper.ExecuteSQLNonQuery("DELETE datapoint.* FROM datapoint JOIN entity ON id_entity = entity.id WHERE entity.id_searchdomain = @searchdomain", parameters); - helper.ExecuteSQLNonQuery("DELETE FROM attribute WHERE id_entity IN (SELECT entity.id FROM entity WHERE id_searchdomain = @searchdomain)", parameters); - return helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE entity.id_searchdomain = @searchdomain", parameters); + int count; + do + { + count = await helper.ExecuteSQLNonQuery("DELETE FROM embedding WHERE id_searchdomain = @searchdomain LIMIT 10000", parameters); + } while (count == 10000); + do + { + count = await helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity IN (SELECT id FROM entity WHERE id_searchdomain = @searchdomain) LIMIT 10000", parameters); + } while (count == 10000); + do + { + count = await helper.ExecuteSQLNonQuery("DELETE FROM attribute WHERE id_entity IN (SELECT id FROM entity WHERE id_searchdomain = @searchdomain) LIMIT 10000", parameters); + } while (count == 10000); + int total = 0; + do + { + count = await helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE id_searchdomain = @searchdomain LIMIT 10000", parameters); + total += count; + } while (count == 10000); + return total; } - public bool HasEntity(SQLHelper helper, string name, string searchdomain) + public async Task HasEntity(SQLHelper helper, string name, string searchdomain) { Dictionary parameters = new() { { "name", name }, - { "searchdomain", GetSearchdomainID(helper, searchdomain)} + { "searchdomain", await GetSearchdomainID(helper, searchdomain)} }; lock (helper.connection) { DbDataReader reader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters); - bool success = reader.Read(); - bool result = success && reader.GetInt32(0) > 0; - reader.Close(); - if (success) + try { - return result; - } - else + bool success = reader.Read(); + bool result = success && reader.GetInt32(0) > 0; + if (success) + { + return result; + } + else + { + _logger.LogError("Unable to determine whether an entity named {name} exists for {searchdomain}", [name, searchdomain]); + throw new Exception($"Unable to determine whether an entity named {name} exists for {searchdomain}"); + } + } finally { - _logger.LogError("Unable to determine whether an entity named {name} exists for {searchdomain}", [name, searchdomain]); - throw new Exception($"Unable to determine whether an entity named {name} exists for {searchdomain}"); + reader.Close(); } } } - public int? GetEntityID(SQLHelper helper, string name, string searchdomain) + public async Task GetEntityID(SQLHelper helper, string name, string searchdomain) { Dictionary parameters = new() { { "name", name }, - { "searchdomain", GetSearchdomainID(helper, searchdomain)} + { "searchdomain", await GetSearchdomainID(helper, searchdomain)} }; lock (helper.connection) { DbDataReader reader = helper.ExecuteSQLCommand("SELECT id FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters); - bool success = reader.Read(); - int? result = success ? reader.GetInt32(0) : 0; - reader.Close(); - return result; + try + { + bool success = reader.Read(); + int? result = success ? reader.GetInt32(0) : 0; + return result; + } finally + { + reader.Close(); + } } } @@ -291,29 +295,56 @@ public class DatabaseHelper(ILogger logger) { "searchdomain", searchdomain} }; DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(id) + LENGTH(name) + LENGTH(settings)) AS total_bytes FROM embeddingsearch.searchdomain WHERE name=@searchdomain", parameters); - bool success = searchdomainSumReader.Read(); - long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; - searchdomainSumReader.Close(); + bool success; + long result; + try + { + success = searchdomainSumReader.Read(); + result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; + } finally + { + searchdomainSumReader.Close(); + } DbDataReader entitySumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(e.id) + LENGTH(e.name) + LENGTH(e.probmethod) + LENGTH(e.id_searchdomain)) AS total_bytes FROM embeddingsearch.entity e JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); - success = entitySumReader.Read(); - result += success && !entitySumReader.IsDBNull(0) ? entitySumReader.GetInt64(0) : 0; - entitySumReader.Close(); + try + { + success = entitySumReader.Read(); + result += success && !entitySumReader.IsDBNull(0) ? entitySumReader.GetInt64(0) : 0; + } finally + { + entitySumReader.Close(); + } DbDataReader datapointSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(d.id) + LENGTH(d.name) + LENGTH(d.probmethod_embedding) + LENGTH(d.similaritymethod) + LENGTH(d.id_entity) + LENGTH(d.hash)) AS total_bytes FROM embeddingsearch.datapoint d JOIN embeddingsearch.entity e ON d.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); - success = datapointSumReader.Read(); - result += success && !datapointSumReader.IsDBNull(0) ? datapointSumReader.GetInt64(0) : 0; - datapointSumReader.Close(); + try + { + success = datapointSumReader.Read(); + result += success && !datapointSumReader.IsDBNull(0) ? datapointSumReader.GetInt64(0) : 0; + } finally + { + datapointSumReader.Close(); + } DbDataReader embeddingSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(em.id) + LENGTH(em.id_datapoint) + LENGTH(em.model) + LENGTH(em.embedding)) AS total_bytes FROM embeddingsearch.embedding em JOIN embeddingsearch.datapoint d ON em.id_datapoint = d.id JOIN embeddingsearch.entity e ON d.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); - success = embeddingSumReader.Read(); - result += success && !embeddingSumReader.IsDBNull(0) ? embeddingSumReader.GetInt64(0) : 0; - embeddingSumReader.Close(); + try + { + success = embeddingSumReader.Read(); + result += success && !embeddingSumReader.IsDBNull(0) ? embeddingSumReader.GetInt64(0) : 0; + } finally + { + embeddingSumReader.Close(); + } DbDataReader attributeSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(a.id) + LENGTH(a.id_entity) + LENGTH(a.attribute) + LENGTH(a.value)) AS total_bytes FROM embeddingsearch.attribute a JOIN embeddingsearch.entity e ON a.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); - success = attributeSumReader.Read(); - result += success && !attributeSumReader.IsDBNull(0) ? attributeSumReader.GetInt64(0) : 0; - attributeSumReader.Close(); + try + { + success = attributeSumReader.Read(); + result += success && !attributeSumReader.IsDBNull(0) ? attributeSumReader.GetInt64(0) : 0; + } finally + { + attributeSumReader.Close(); + } return result; } @@ -336,10 +367,15 @@ public class DatabaseHelper(ILogger logger) public static async Task CountEntities(SQLHelper helper) { DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity;", []); - bool success = searchdomainSumReader.Read(); - long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; - searchdomainSumReader.Close(); - return result; + try + { + bool success = searchdomainSumReader.Read(); + long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; + return result; + } finally + { + searchdomainSumReader.Close(); + } } public static long CountEntitiesForSearchdomain(SQLHelper helper, string searchdomain) @@ -349,10 +385,15 @@ public class DatabaseHelper(ILogger logger) { "searchdomain", searchdomain} }; DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity e JOIN searchdomain s on e.id_searchdomain = s.id WHERE e.id_searchdomain = s.id AND s.name = @searchdomain;", parameters); - bool success = searchdomainSumReader.Read(); - long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; - searchdomainSumReader.Close(); - return result; + try + { + bool success = searchdomainSumReader.Read(); + long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; + return result; + } finally + { + searchdomainSumReader.Close(); + } } public static SearchdomainSettings GetSearchdomainSettings(SQLHelper helper, string searchdomain) diff --git a/src/Server/Helper/SQLHelper.cs b/src/Server/Helper/SQLHelper.cs index 0fb1ec4..3158b77 100644 --- a/src/Server/Helper/SQLHelper.cs +++ b/src/Server/Helper/SQLHelper.cs @@ -8,17 +8,22 @@ public class SQLHelper:IDisposable { public MySqlConnection connection; public DbDataReader? dbDataReader; + public MySqlConnectionPoolElement[] connectionPool; public string connectionString; public SQLHelper(MySqlConnection connection, string connectionString) { this.connection = connection; this.connectionString = connectionString; + connectionPool = new MySqlConnectionPoolElement[50]; + for (int i = 0; i < connectionPool.Length; i++) + { + connectionPool[i] = new MySqlConnectionPoolElement(new MySqlConnection(connectionString), new(1, 1)); + } } - public SQLHelper DuplicateConnection() + public SQLHelper DuplicateConnection() // TODO remove this { - MySqlConnection newConnection = new(connectionString); - return new SQLHelper(newConnection, connectionString); + return this; } public void Dispose() @@ -44,12 +49,43 @@ public class SQLHelper:IDisposable } } - public int ExecuteSQLNonQuery(string query, Dictionary parameters) + public async Task> ExecuteQueryAsync( + string sql, + Dictionary parameters, + Func map) { - lock (connection) + var poolElement = await GetMySqlConnectionPoolElement(); + var connection = poolElement.connection; + try + { + await using var command = connection.CreateCommand(); + command.CommandText = sql; + + foreach (var p in parameters) + command.Parameters.AddWithValue($"@{p.Key}", p.Value); + + await using var reader = await command.ExecuteReaderAsync(); + + var result = new List(); + while (await reader.ReadAsync()) + { + result.Add(map(reader)); + } + + return result; + } finally + { + + poolElement.Semaphore.Release(); + } + } + + public async Task ExecuteSQLNonQuery(string query, Dictionary parameters) + { + var poolElement = await GetMySqlConnectionPoolElement(); + var connection = poolElement.connection; + try { - EnsureConnected(); - EnsureDbReaderIsClosed(); using MySqlCommand command = connection.CreateCommand(); command.CommandText = query; @@ -58,15 +94,18 @@ public class SQLHelper:IDisposable command.Parameters.AddWithValue($"@{parameter.Key}", parameter.Value); } return command.ExecuteNonQuery(); + } finally + { + poolElement.Semaphore.Release(); } } - public int ExecuteSQLCommandGetInsertedID(string query, Dictionary parameters) + public async Task ExecuteSQLCommandGetInsertedID(string query, Dictionary parameters) { - lock (connection) + var poolElement = await GetMySqlConnectionPoolElement(); + var connection = poolElement.connection; + try { - EnsureConnected(); - EnsureDbReaderIsClosed(); using MySqlCommand command = connection.CreateCommand(); command.CommandText = query; @@ -77,16 +116,18 @@ public class SQLHelper:IDisposable command.ExecuteNonQuery(); command.CommandText = "SELECT LAST_INSERT_ID();"; return Convert.ToInt32(command.ExecuteScalar()); + } finally + { + poolElement.Semaphore.Release(); } } - public int BulkExecuteNonQuery(string sql, IEnumerable parameterSets) + public async Task BulkExecuteNonQuery(string sql, IEnumerable parameterSets) { - lock (connection) + var poolElement = await GetMySqlConnectionPoolElement(); + var connection = poolElement.connection; + try { - EnsureConnected(); - EnsureDbReaderIsClosed(); - int affectedRows = 0; int retries = 0; @@ -120,9 +161,37 @@ public class SQLHelper:IDisposable } return affectedRows; + } finally + { + poolElement.Semaphore.Release(); } } + public async Task GetMySqlConnectionPoolElement() + { + int counter = 0; + int sleepTime = 10; + do + { + foreach (var element in connectionPool) + { + if (element.Semaphore.Wait(0)) + { + if (element.connection.State == ConnectionState.Closed) + { + await element.connection.CloseAsync(); + await element.connection.OpenAsync(); + } + return element; + } + } + Thread.Sleep(sleepTime); + } while (++counter <= 50); + TimeoutException ex = new("Unable to get MySqlConnection"); + ElmahCore.ElmahExtensions.RaiseError(ex); + throw ex; + } + public bool EnsureConnected() { if (connection.State != System.Data.ConnectionState.Open) @@ -157,4 +226,16 @@ public class SQLHelper:IDisposable Thread.Sleep(sleepTime); } } +} + +public struct MySqlConnectionPoolElement +{ + public MySqlConnection connection; + public SemaphoreSlim Semaphore; + + public MySqlConnectionPoolElement(MySqlConnection connection, SemaphoreSlim semaphore) + { + this.connection = connection; + this.Semaphore = semaphore; + } } \ No newline at end of file diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 8f38047..5a543aa 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -1,4 +1,5 @@ using System.Collections.Concurrent; +using System.Diagnostics; using System.Security.Cryptography; using System.Text; using System.Text.Json; @@ -29,14 +30,14 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp return floatArray; } - public static bool CacheHasEntity(ConcurrentBag entityCache, string name) + public static bool CacheHasEntity(ConcurrentDictionary entityCache, string name) { return CacheGetEntity(entityCache, name) is not null; } - public static Entity? CacheGetEntity(ConcurrentBag entityCache, string name) + public static Entity? CacheGetEntity(ConcurrentDictionary entityCache, string name) { - foreach (Entity entity in entityCache) + foreach ((string _, Entity entity) in entityCache) { if (entity.name == name) { @@ -46,7 +47,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp return null; } - public List? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json) + public async Task?> EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json) { EnumerableLruCache> embeddingCache = searchdomainManager.embeddingCache; AIProvider aIProvider = searchdomainManager.aIProvider; @@ -96,46 +97,58 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp // Index/parse the entities ConcurrentQueue retVal = []; ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool - Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity => + + List entityTasks = []; + foreach (JSONEntity jSONEntity in jsonEntities) { - var entity = EntityFromJSON(searchdomainManager, logger, jSONEntity); - if (entity is not null) + entityTasks.Add(Task.Run(async () => { - retVal.Enqueue(entity); + var entity = await EntityFromJSON(searchdomainManager, logger, jSONEntity); + if (entity is not null) + { + retVal.Enqueue(entity); + } + })); + + if (entityTasks.Count >= parallelOptions.MaxDegreeOfParallelism) + { + await Task.WhenAny(entityTasks); + entityTasks.RemoveAll(t => t.IsCompleted); } - }); + } + + await Task.WhenAll(entityTasks); + return [.. retVal]; } - public Entity? EntityFromJSON(SearchdomainManager searchdomainManager, ILogger logger, JSONEntity jsonEntity) //string json) + public async Task EntityFromJSON(SearchdomainManager searchdomainManager, ILogger logger, JSONEntity jsonEntity) { - using SQLHelper helper = searchdomainManager.helper.DuplicateConnection(); + var stopwatch = Stopwatch.StartNew(); + + SQLHelper helper = searchdomainManager.helper; Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain); - ConcurrentBag entityCache = searchdomain.entityCache; + int id_searchdomain = searchdomain.id; + ConcurrentDictionary entityCache = searchdomain.entityCache; AIProvider aIProvider = searchdomain.aIProvider; EnumerableLruCache> embeddingCache = searchdomain.embeddingCache; - Entity? preexistingEntity; - lock (entityCache) - { - preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); - } bool invalidateSearchCache = false; - if (preexistingEntity is not null) + + bool hasEntity = entityCache.TryGetValue(jsonEntity.Name, out Entity? preexistingEntity); + + if (hasEntity && preexistingEntity is not null) { - int? preexistingEntityID = _databaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); - if (preexistingEntityID is null) - { - _logger.LogCritical("Unable to index entity {jsonEntity.Name} because it already exists in the searchdomain but not in the database.", [jsonEntity.Name]); - throw new Exception($"Unable to index entity {jsonEntity.Name} because it already exists in the searchdomain but not in the database."); - } + + int preexistingEntityID = preexistingEntity.id; + Dictionary attributes = jsonEntity.Attributes; // Attribute - get changes - List<(string attribute, string newValue, int entityId)> updatedAttributes = []; - List<(string attribute, int entityId)> deletedAttributes = []; - List<(string attributeKey, string attribute, int entityId)> addedAttributes = []; - foreach (KeyValuePair attributesKV in preexistingEntity.attributes.ToList()) + List<(string attribute, string newValue, int entityId)> updatedAttributes = new(preexistingEntity.attributes.Count); + List<(string attribute, int entityId)> deletedAttributes = new(preexistingEntity.attributes.Count); + List<(string attributeKey, string attribute, int entityId)> addedAttributes = new(jsonEntity.Attributes.Count); + foreach (KeyValuePair attributesKV in preexistingEntity.attributes) //.ToList()) { string oldAttributeKey = attributesKV.Key; string oldAttribute = attributesKV.Value; @@ -148,6 +161,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp deletedAttributes.Add((attribute: oldAttributeKey, entityId: (int)preexistingEntityID)); } } + foreach (var attributesKV in jsonEntity.Attributes) { string newAttributeKey = attributesKV.Key; @@ -160,12 +174,13 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp } } - + if (updatedAttributes.Count != 0 || deletedAttributes.Count != 0 || addedAttributes.Count != 0) + _logger.LogDebug("EntityFromJSON - Updating existing entity. name: {name}, updatedAttributes: {updatedAttributes}, deletedAttributes: {deletedAttributes}, addedAttributes: {addedAttributes}", [preexistingEntity.name, updatedAttributes.Count, deletedAttributes.Count, addedAttributes.Count]); // Attribute - apply changes if (updatedAttributes.Count != 0) { // Update - DatabaseHelper.DatabaseUpdateAttributes(helper, updatedAttributes); + await DatabaseHelper.DatabaseUpdateAttributes(helper, updatedAttributes); lock (preexistingEntity.attributes) { updatedAttributes.ForEach(attribute => preexistingEntity.attributes[attribute.attribute] = attribute.newValue); @@ -174,7 +189,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp if (deletedAttributes.Count != 0) { // Delete - DatabaseHelper.DatabaseDeleteAttributes(helper, deletedAttributes); + await DatabaseHelper.DatabaseDeleteAttributes(helper, deletedAttributes); lock (preexistingEntity.attributes) { deletedAttributes.ForEach(attribute => preexistingEntity.attributes.Remove(attribute.attribute)); @@ -183,7 +198,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp if (addedAttributes.Count != 0) { // Insert - DatabaseHelper.DatabaseInsertAttributes(helper, addedAttributes); + await DatabaseHelper.DatabaseInsertAttributes(helper, addedAttributes); lock (preexistingEntity.attributes) { addedAttributes.ForEach(attribute => preexistingEntity.attributes.Add(attribute.attributeKey, attribute.attribute)); @@ -191,17 +206,18 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp } // Datapoint - get changes - List deletedDatapointInstances = []; - List deletedDatapoints = []; - List<(string datapointName, int entityId, JSONDatapoint jsonDatapoint, string hash)> updatedDatapointsText = []; - List<(string datapointName, string probMethod, string similarityMethod, int entityId, JSONDatapoint jsonDatapoint)> updatedDatapointsNonText = []; + List deletedDatapointInstances = new(preexistingEntity.datapoints.Count); + List deletedDatapoints = new(preexistingEntity.datapoints.Count); + List<(string datapointName, int entityId, JSONDatapoint jsonDatapoint, string hash)> updatedDatapointsText = new(preexistingEntity.datapoints.Count); + List<(string datapointName, string probMethod, string similarityMethod, int entityId, JSONDatapoint jsonDatapoint)> updatedDatapointsNonText = new(preexistingEntity.datapoints.Count); List createdDatapointInstances = []; - List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, Dictionary embeddings, JSONDatapoint datapoint)> createdDatapoints = []; + List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, Dictionary embeddings, JSONDatapoint datapoint)> createdDatapoints = new(jsonEntity.Datapoints.Length); foreach (Datapoint datapoint_ in preexistingEntity.datapoints.ToList()) { Datapoint datapoint = datapoint_; // To enable replacing the datapoint reference as foreach iterators cannot be overwritten - bool newEntityHasDatapoint = jsonEntity.Datapoints.Any(x => x.Name == datapoint.name); + JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name); + bool newEntityHasDatapoint = newEntityDatapoint is not null; if (!newEntityHasDatapoint) { // Datapoint - Deleted @@ -210,7 +226,6 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp invalidateSearchCache = true; } else { - JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name); string? hash = newEntityDatapoint?.Text is not null ? GetHash(newEntityDatapoint) : null; if ( newEntityDatapoint is not null @@ -246,6 +261,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp } } } + foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints) { bool oldEntityHasDatapoint = preexistingEntity.datapoints.Any(x => x.name == jsonDatapoint.Name); @@ -269,38 +285,51 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp invalidateSearchCache = true; } } - + + + if (deletedDatapointInstances.Count != 0 || createdDatapoints.Count != 0 || addedAttributes.Count != 0 || updatedDatapointsNonText.Count != 0) + _logger.LogDebug( + "EntityFromJSON - Updating existing entity. name: {name}, deletedDatapointInstances: {deletedDatapointInstances}, createdDatapoints: {createdDatapoints}, addedAttributes: {addedAttributes}, updatedDatapointsNonText: {updatedDatapointsNonText}", + [preexistingEntity.name, deletedDatapointInstances.Count, createdDatapoints.Count, addedAttributes.Count, updatedDatapointsNonText.Count]); // Datapoint - apply changes // Deleted if (deletedDatapointInstances.Count != 0) { - DatabaseHelper.DatabaseDeleteDatapoints(helper, deletedDatapoints, (int)preexistingEntityID); - deletedDatapointInstances.ForEach(datapoint => preexistingEntity.datapoints.Remove(datapoint)); + await DatabaseHelper.DatabaseDeleteEmbeddingsAndDatapoints(helper, deletedDatapoints, (int)preexistingEntityID); + preexistingEntity.datapoints = [.. preexistingEntity.datapoints + .Where(x => + !deletedDatapointInstances.Contains(x) + ) + ]; } // Created if (createdDatapoints.Count != 0) { - List datapoint = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. createdDatapoints.Select(element => (element.datapoint, element.hash))], (int)preexistingEntityID); - createdDatapoints.ForEach(datapoint => preexistingEntity.datapoints.Add(new( - datapoint.name, - datapoint.probmethod_embedding, - datapoint.similarityMethod, - datapoint.hash, - [.. datapoint.embeddings.Select(element => (element.Key, element.Value))]) - )); + List datapoint = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. createdDatapoints.Select(element => (element.datapoint, element.hash))], (int)preexistingEntityID, id_searchdomain); + datapoint.ForEach(x => preexistingEntity.datapoints.Add(x)); } // Datapoint - Updated (text) if (updatedDatapointsText.Count != 0) { - DatabaseHelper.DatabaseDeleteDatapoints(helper, [.. updatedDatapointsText.Select(datapoint => datapoint.datapointName)], (int)preexistingEntityID); - updatedDatapointsText.ForEach(datapoint => preexistingEntity.datapoints.RemoveAll(x => x.name == datapoint.datapointName)); - List datapoints = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. updatedDatapointsText.Select(element => (datapoint: element.jsonDatapoint, hash: element.hash))], (int)preexistingEntityID); - preexistingEntity.datapoints.AddRange(datapoints); + await DatabaseHelper.DatabaseDeleteEmbeddingsAndDatapoints(helper, [.. updatedDatapointsText.Select(datapoint => datapoint.datapointName)], (int)preexistingEntityID); + // Remove from datapoints + var namesToRemove = updatedDatapointsText + .Select(d => d.datapointName) + .ToHashSet(); + var newBag = new ConcurrentBag( + preexistingEntity.datapoints + .Where(x => !namesToRemove.Contains(x.name)) + ); + preexistingEntity.datapoints = newBag; + // Insert into database + List datapoints = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. updatedDatapointsText.Select(element => (datapoint: element.jsonDatapoint, hash: element.hash))], (int)preexistingEntityID, id_searchdomain); + // Insert into datapoints + datapoints.ForEach(datapoint => preexistingEntity.datapoints.Add(datapoint)); } // Datapoint - Updated (probmethod or similaritymethod) if (updatedDatapointsNonText.Count != 0) { - DatabaseHelper.DatabaseUpdateDatapoint( + await DatabaseHelper.DatabaseUpdateDatapoint( helper, [.. updatedDatapointsNonText.Select(element => (element.datapointName, element.probMethod, element.similarityMethod))], (int)preexistingEntityID @@ -313,16 +342,19 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp }); } + if (invalidateSearchCache) { + searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity); + searchdomain.UpdateModelsInUse(); } - searchdomain.UpdateModelsInUse(); + return preexistingEntity; } else { - int id_entity = DatabaseHelper.DatabaseInsertEntity(helper, jsonEntity.Name, jsonEntity.Probmethod, _databaseHelper.GetSearchdomainID(helper, jsonEntity.Searchdomain)); + int id_entity = await DatabaseHelper.DatabaseInsertEntity(helper, jsonEntity.Name, jsonEntity.Probmethod, id_searchdomain); List<(string attribute, string value, int id_entity)> toBeInsertedAttributes = []; foreach (KeyValuePair attribute in jsonEntity.Attributes) { @@ -332,9 +364,11 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp id_entity = id_entity }); } - DatabaseHelper.DatabaseInsertAttributes(helper, toBeInsertedAttributes); + + var insertAttributesTask = DatabaseHelper.DatabaseInsertAttributes(helper, toBeInsertedAttributes); List<(JSONDatapoint datapoint, string hash)> toBeInsertedDatapoints = []; + ConcurrentBag usedModels = searchdomain.modelsInUse; foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints) { string hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); @@ -343,29 +377,39 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp datapoint = jsonDatapoint, hash = hash }); + foreach (string model in jsonDatapoint.Model) + { + if (!usedModels.Contains(model)) + { + usedModels.Add(model); + } + } } - List datapoints = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, toBeInsertedDatapoints, id_entity); + + List datapoints = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, toBeInsertedDatapoints, id_entity, id_searchdomain); var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new ProbMethodNotFoundException(jsonEntity.Probmethod); - Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), datapoints, jsonEntity.Name) + Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), new(datapoints), jsonEntity.Name) { id = id_entity }; - entityCache.Add(entity); + entityCache[jsonEntity.Name] = entity; + searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(entity); - searchdomain.UpdateModelsInUse(); + await insertAttributesTask; return entity; } } - public List DatabaseInsertDatapointsWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, List<(JSONDatapoint datapoint, string hash)> values, int id_entity) + public async Task> DatabaseInsertDatapointsWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, List<(JSONDatapoint datapoint, string hash)> values, int id_entity, int id_searchdomain) { List result = []; List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> toBeInsertedDatapoints = []; - List<(string name, string model, byte[] embedding)> toBeInsertedEmbeddings = []; + List<(int id_datapoint, string model, byte[] embedding)> toBeInsertedEmbeddings = []; foreach ((JSONDatapoint datapoint, string hash) value in values) { - Datapoint datapoint = BuildDatapointFromJsonDatapoint(value.datapoint, id_entity, searchdomain, value.hash); + Datapoint datapoint = await BuildDatapointFromJsonDatapoint(value.datapoint, id_entity, searchdomain, value.hash); + toBeInsertedDatapoints.Add(new() { name = datapoint.name, @@ -377,34 +421,34 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp { toBeInsertedEmbeddings.Add(new() { - name = datapoint.name, + id_datapoint = datapoint.id, model = embedding.Item1, embedding = BytesFromFloatArray(embedding.Item2) }); } result.Add(datapoint); + } - int insertedDatapoints = DatabaseHelper.DatabaseInsertDatapoints(helper, toBeInsertedDatapoints, id_entity); - int insertedEmbeddings = DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, toBeInsertedEmbeddings, id_entity); + await DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, toBeInsertedEmbeddings, id_entity, id_searchdomain); return result; } - public Datapoint DatabaseInsertDatapointWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, JSONDatapoint jsonDatapoint, int id_entity, string? hash = null) + public async Task DatabaseInsertDatapointWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, JSONDatapoint jsonDatapoint, int id_entity, int id_searchdomain, string? hash = null) { if (jsonDatapoint.Text is null) { throw new Exception("jsonDatapoint.Text must not be null at this point"); } hash ??= GetHash(jsonDatapoint); - Datapoint datapoint = BuildDatapointFromJsonDatapoint(jsonDatapoint, id_entity, searchdomain, hash); - int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries + Datapoint datapoint = await BuildDatapointFromJsonDatapoint(jsonDatapoint, id_entity, searchdomain, hash); + int id_datapoint = await DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries List<(string model, byte[] embedding)> data = []; foreach ((string, float[]) embedding in datapoint.embeddings) { data.Add((embedding.Item1, BytesFromFloatArray(embedding.Item2))); } - DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, id_datapoint, data); + await DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, id_datapoint, data, id_entity, id_searchdomain); return datapoint; } @@ -413,20 +457,20 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp return Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text ?? throw new Exception("jsonDatapoint.Text must not be null to compute hash")))); } - public Datapoint BuildDatapointFromJsonDatapoint(JSONDatapoint jsonDatapoint, int entityId, Searchdomain searchdomain, string? hash = null) + public async Task BuildDatapointFromJsonDatapoint(JSONDatapoint jsonDatapoint, int entityId, Searchdomain searchdomain, string? hash = null) { if (jsonDatapoint.Text is null) { throw new Exception("jsonDatapoint.Text must not be null at this point"); } - using SQLHelper helper = searchdomain.helper.DuplicateConnection(); + SQLHelper helper = searchdomain.helper; EnumerableLruCache> embeddingCache = searchdomain.embeddingCache; hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); - DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); + int id = await DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); Dictionary embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding); var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod); - return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); + return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))], id); } public static (Searchdomain?, int?, string?) TryGetSearchdomain(SearchdomainManager searchdomainManager, string searchdomain, ILogger logger) diff --git a/src/Server/Migrations/DatabaseMigrations.cs b/src/Server/Migrations/DatabaseMigrations.cs index 3f4d5c3..1b966c9 100644 --- a/src/Server/Migrations/DatabaseMigrations.cs +++ b/src/Server/Migrations/DatabaseMigrations.cs @@ -34,72 +34,113 @@ public static class DatabaseMigrations if (databaseVersion != initialDatabaseVersion) { - helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() }); + var _ = helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() }).Result; } } public static int DatabaseGetVersion(SQLHelper helper) { DbDataReader reader = helper.ExecuteSQLCommand("show tables", []); - bool hasTables = reader.Read(); - reader.Close(); - if (!hasTables) + try { - return 0; + bool hasTables = reader.Read(); + if (!hasTables) + { + return 0; + } + } finally + { + reader.Close(); } reader = helper.ExecuteSQLCommand("show tables like '%settings%'", []); - bool hasSystemTable = reader.Read(); - reader.Close(); - if (!hasSystemTable) + try { - return 1; + bool hasSystemTable = reader.Read(); + if (!hasSystemTable) + { + return 1; + } + } finally + { + reader.Close(); } reader = helper.ExecuteSQLCommand("SELECT value FROM settings WHERE name=\"DatabaseVersion\"", []); - reader.Read(); - string rawVersion = reader.GetString(0); - reader.Close(); - bool success = int.TryParse(rawVersion, out int version); - if (!success) + try { - throw new DatabaseVersionException(); + reader.Read(); + string rawVersion = reader.GetString(0); + bool success = int.TryParse(rawVersion, out int version); + if (!success) + { + throw new DatabaseVersionException(); + } + return version; + } finally + { + reader.Close(); } - return version; } public static int Create(SQLHelper helper) { - helper.ExecuteSQLNonQuery("CREATE TABLE searchdomain (id int PRIMARY KEY auto_increment, name varchar(512), settings JSON);", []); - helper.ExecuteSQLNonQuery("CREATE TABLE entity (id int PRIMARY KEY auto_increment, name varchar(512), probmethod varchar(128), id_searchdomain int, FOREIGN KEY (id_searchdomain) REFERENCES searchdomain(id));", []); - helper.ExecuteSQLNonQuery("CREATE TABLE attribute (id int PRIMARY KEY auto_increment, id_entity int, attribute varchar(512), value longtext, FOREIGN KEY (id_entity) REFERENCES entity(id));", []); - helper.ExecuteSQLNonQuery("CREATE TABLE datapoint (id int PRIMARY KEY auto_increment, name varchar(512), probmethod_embedding varchar(512), id_entity int, FOREIGN KEY (id_entity) REFERENCES entity(id));", []); - helper.ExecuteSQLNonQuery("CREATE TABLE embedding (id int PRIMARY KEY auto_increment, id_datapoint int, model varchar(512), embedding blob, FOREIGN KEY (id_datapoint) REFERENCES datapoint(id));", []); + var _ = helper.ExecuteSQLNonQuery("CREATE TABLE searchdomain (id int PRIMARY KEY auto_increment, name varchar(512), settings JSON);", []).Result; + _ = helper.ExecuteSQLNonQuery("CREATE TABLE entity (id int PRIMARY KEY auto_increment, name varchar(512), probmethod varchar(128), id_searchdomain int, FOREIGN KEY (id_searchdomain) REFERENCES searchdomain(id));", []).Result; + _ = helper.ExecuteSQLNonQuery("CREATE TABLE attribute (id int PRIMARY KEY auto_increment, id_entity int, attribute varchar(512), value longtext, FOREIGN KEY (id_entity) REFERENCES entity(id));", []).Result; + _ = helper.ExecuteSQLNonQuery("CREATE TABLE datapoint (id int PRIMARY KEY auto_increment, name varchar(512), probmethod_embedding varchar(512), id_entity int, FOREIGN KEY (id_entity) REFERENCES entity(id));", []).Result; + _ = helper.ExecuteSQLNonQuery("CREATE TABLE embedding (id int PRIMARY KEY auto_increment, id_datapoint int, model varchar(512), embedding blob, FOREIGN KEY (id_datapoint) REFERENCES datapoint(id));", []).Result; return 1; } public static int UpdateFrom1(SQLHelper helper) { - helper.ExecuteSQLNonQuery("CREATE TABLE settings (name varchar(512), value varchar(8192));", []); - helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES (\"DatabaseVersion\", \"2\");", []); + var _ = helper.ExecuteSQLNonQuery("CREATE TABLE settings (name varchar(512), value varchar(8192));", []).Result; + _ = helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES (\"DatabaseVersion\", \"2\");", []).Result; return 2; } public static int UpdateFrom2(SQLHelper helper) { - helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD hash VARCHAR(44);", []); - helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []); + var _ = helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD hash VARCHAR(44);", []).Result; + _ = helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []).Result; return 3; } public static int UpdateFrom3(SQLHelper helper) { - helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding", []); + var _ = helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding", []).Result; return 4; } public static int UpdateFrom4(SQLHelper helper) { - helper.ExecuteSQLNonQuery("UPDATE searchdomain SET settings = JSON_SET(settings, '$.QueryCacheSize', 1000000) WHERE JSON_EXTRACT(settings, '$.QueryCacheSize') is NULL;", []); // Set QueryCacheSize to a default of 1000000 + var _ = helper.ExecuteSQLNonQuery("UPDATE searchdomain SET settings = JSON_SET(settings, '$.QueryCacheSize', 1000000) WHERE JSON_EXTRACT(settings, '$.QueryCacheSize') is NULL;", []).Result; // Set QueryCacheSize to a default of 1000000 return 5; } + + public static int UpdateFrom5(SQLHelper helper) + { + // Add id_entity to embedding + var _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_entity INT NULL", []).Result; + int count; + do + { + count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN datapoint d ON d.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_entity IS NULL LIMIT 10000) x on x.id = e.id SET e.id_entity = d.id_entity;", []).Result; + } while (count == 10000); + + _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_entity INT NOT NULL;", []).Result; + _ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_entity_model ON embedding (id_entity, model)", []).Result; + + // Add id_searchdomain to embedding + _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_searchdomain INT NULL", []).Result; + do + { + count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN entity en ON en.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_searchdomain IS NULL LIMIT 10000) x on x.id = e.id SET e.id_searchdomain = en.id_searchdomain;", []).Result; + } while (count == 10000); + + _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_searchdomain INT NOT NULL;", []).Result; + _ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_searchdomain_model ON embedding (id_searchdomain)", []).Result; + + return 6; + } } \ No newline at end of file diff --git a/src/Server/Program.cs b/src/Server/Program.cs index 7580210..071635f 100644 --- a/src/Server/Program.cs +++ b/src/Server/Program.cs @@ -42,7 +42,7 @@ builder.WebHost.ConfigureKestrel(options => }); // Migrate database -var helper = new SQLHelper(new MySql.Data.MySqlClient.MySqlConnection(configuration.ConnectionStrings.SQL), configuration.ConnectionStrings.SQL); +SQLHelper helper = new(new MySql.Data.MySqlClient.MySqlConnection(configuration.ConnectionStrings.SQL), configuration.ConnectionStrings.SQL); DatabaseMigrations.Migrate(helper); // Migrate SQLite cache diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index d0c718e..44beaee 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -20,8 +20,8 @@ public class Searchdomain public int id; public SearchdomainSettings settings; public EnumerableLruCache queryCache; // Key: query, Value: Search results for that query (with timestamp) - public ConcurrentBag entityCache; - public List modelsInUse; + public ConcurrentDictionary entityCache; + public ConcurrentBag modelsInUse; public EnumerableLruCache> embeddingCache; private readonly MySqlConnection connection; public SQLHelper helper; @@ -44,7 +44,7 @@ public class Searchdomain modelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that... if (!runEmpty) { - GetID(); + id = GetID().Result; UpdateEntityCache(); } } @@ -56,110 +56,130 @@ public class Searchdomain { ["id"] = this.id }; - DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT e.id, e.id_datapoint, e.model, e.embedding FROM embedding e JOIN datapoint d ON e.id_datapoint = d.id JOIN entity ent ON d.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain); + DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT id, id_datapoint, model, embedding FROM embedding WHERE id_searchdomain = @id", parametersIDSearchdomain); Dictionary> embedding_unassigned = []; - while (embeddingReader.Read()) + try { - int? id_datapoint_debug = null; - try + while (embeddingReader.Read()) { - int id_datapoint = embeddingReader.GetInt32(1); - id_datapoint_debug = id_datapoint; - string model = embeddingReader.GetString(2); - long length = embeddingReader.GetBytes(3, 0, null, 0, 0); - byte[] embedding = new byte[length]; - embeddingReader.GetBytes(3, 0, embedding, 0, (int) length); - if (embedding_unassigned.TryGetValue(id_datapoint, out Dictionary? embedding_unassigned_id_datapoint)) + int? id_datapoint_debug = null; + try { - embedding_unassigned[id_datapoint][model] = SearchdomainHelper.FloatArrayFromBytes(embedding); - } - else - { - embedding_unassigned[id_datapoint] = new() + int id_datapoint = embeddingReader.GetInt32(1); + id_datapoint_debug = id_datapoint; + string model = embeddingReader.GetString(2); + long length = embeddingReader.GetBytes(3, 0, null, 0, 0); + byte[] embedding = new byte[length]; + embeddingReader.GetBytes(3, 0, embedding, 0, (int) length); + if (embedding_unassigned.TryGetValue(id_datapoint, out Dictionary? embedding_unassigned_id_datapoint)) { - [model] = SearchdomainHelper.FloatArrayFromBytes(embedding) - }; + embedding_unassigned[id_datapoint][model] = SearchdomainHelper.FloatArrayFromBytes(embedding); + } + else + { + embedding_unassigned[id_datapoint] = new() + { + [model] = SearchdomainHelper.FloatArrayFromBytes(embedding) + }; + } + } catch (Exception e) + { + _logger.LogError("Error reading embedding (id: {id_datapoint}) from database: {e.Message} - {e.StackTrace}", [id_datapoint_debug, e.Message, e.StackTrace]); + ElmahCore.ElmahExtensions.RaiseError(e); } - } catch (Exception e) - { - _logger.LogError("Error reading embedding (id: {id_datapoint}) from database: {e.Message} - {e.StackTrace}", [id_datapoint_debug, e.Message, e.StackTrace]); - ElmahCore.ElmahExtensions.RaiseError(e); } + } finally + { + embeddingReader.Close(); } - embeddingReader.Close(); DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT d.id, d.id_entity, d.name, d.probmethod_embedding, d.similaritymethod, d.hash FROM datapoint d JOIN entity ent ON d.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain); - Dictionary> datapoint_unassigned = []; - while (datapointReader.Read()) + Dictionary> datapoint_unassigned = []; + try { - int id = datapointReader.GetInt32(0); - int id_entity = datapointReader.GetInt32(1); - string name = datapointReader.GetString(2); - string probmethodString = datapointReader.GetString(3); - string similarityMethodString = datapointReader.GetString(4); - string hash = datapointReader.GetString(5); - ProbMethodEnum probmethodEnum = (ProbMethodEnum)Enum.Parse( - typeof(ProbMethodEnum), - probmethodString - ); - SimilarityMethodEnum similairtyMethodEnum = (SimilarityMethodEnum)Enum.Parse( - typeof(SimilarityMethodEnum), - similarityMethodString - ); - ProbMethod probmethod = new(probmethodEnum); - SimilarityMethod similarityMethod = new(similairtyMethodEnum); - if (embedding_unassigned.TryGetValue(id, out Dictionary? embeddings) && probmethod is not null) + while (datapointReader.Read()) { - embedding_unassigned.Remove(id); - if (!datapoint_unassigned.ContainsKey(id_entity)) + int id = datapointReader.GetInt32(0); + int id_entity = datapointReader.GetInt32(1); + string name = datapointReader.GetString(2); + string probmethodString = datapointReader.GetString(3); + string similarityMethodString = datapointReader.GetString(4); + string hash = datapointReader.GetString(5); + ProbMethodEnum probmethodEnum = (ProbMethodEnum)Enum.Parse( + typeof(ProbMethodEnum), + probmethodString + ); + SimilarityMethodEnum similairtyMethodEnum = (SimilarityMethodEnum)Enum.Parse( + typeof(SimilarityMethodEnum), + similarityMethodString + ); + ProbMethod probmethod = new(probmethodEnum); + SimilarityMethod similarityMethod = new(similairtyMethodEnum); + if (embedding_unassigned.TryGetValue(id, out Dictionary? embeddings) && probmethod is not null) { - datapoint_unassigned[id_entity] = []; + embedding_unassigned.Remove(id); + if (!datapoint_unassigned.ContainsKey(id_entity)) + { + datapoint_unassigned[id_entity] = []; + } + datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))], id)); } - datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))])); } + } finally + { + datapointReader.Close(); } - datapointReader.Close(); DbDataReader attributeReader = helper.ExecuteSQLCommand("SELECT a.id, a.id_entity, a.attribute, a.value FROM attribute a JOIN entity ent ON a.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain); Dictionary> attributes_unassigned = []; - while (attributeReader.Read()) + try { - //"SELECT id, id_entity, attribute, value FROM attribute JOIN entity on attribute.id_entity as en JOIN searchdomain on en.id_searchdomain as sd WHERE sd=@id" - int id = attributeReader.GetInt32(0); - int id_entity = attributeReader.GetInt32(1); - string attribute = attributeReader.GetString(2); - string value = attributeReader.GetString(3); - if (!attributes_unassigned.ContainsKey(id_entity)) + while (attributeReader.Read()) { - attributes_unassigned[id_entity] = []; + //"SELECT id, id_entity, attribute, value FROM attribute JOIN entity on attribute.id_entity as en JOIN searchdomain on en.id_searchdomain as sd WHERE sd=@id" + int id = attributeReader.GetInt32(0); + int id_entity = attributeReader.GetInt32(1); + string attribute = attributeReader.GetString(2); + string value = attributeReader.GetString(3); + if (!attributes_unassigned.ContainsKey(id_entity)) + { + attributes_unassigned[id_entity] = []; + } + attributes_unassigned[id_entity].Add(attribute, value); } - attributes_unassigned[id_entity].Add(attribute, value); + } finally + { + attributeReader.Close(); } - attributeReader.Close(); entityCache = []; DbDataReader entityReader = helper.ExecuteSQLCommand("SELECT entity.id, name, probmethod FROM entity WHERE id_searchdomain=@id", parametersIDSearchdomain); - while (entityReader.Read()) + try { - //SELECT id, name, probmethod FROM entity WHERE id_searchdomain=@id - int id = entityReader.GetInt32(0); - string name = entityReader.GetString(1); - string probmethodString = entityReader.GetString(2); - if (!attributes_unassigned.TryGetValue(id, out Dictionary? attributes)) + while (entityReader.Read()) { - attributes = []; - } - Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString); - if (datapoint_unassigned.TryGetValue(id, out List? datapoints) && probmethod is not null) - { - Entity entity = new(attributes, probmethod, probmethodString, datapoints, name) + //SELECT id, name, probmethod FROM entity WHERE id_searchdomain=@id + int id = entityReader.GetInt32(0); + string name = entityReader.GetString(1); + string probmethodString = entityReader.GetString(2); + if (!attributes_unassigned.TryGetValue(id, out Dictionary? attributes)) { - id = id - }; - entityCache.Add(entity); + attributes = []; + } + Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString); + if (datapoint_unassigned.TryGetValue(id, out ConcurrentBag? datapoints) && probmethod is not null) + { + Entity entity = new(attributes, probmethod, probmethodString, datapoints, name) + { + id = id + }; + entityCache[name] = entity; + } } + } finally + { + entityReader.Close(); } - entityReader.Close(); modelsInUse = GetModels(entityCache); } @@ -174,7 +194,7 @@ public class Searchdomain Dictionary queryEmbeddings = GetQueryEmbeddings(query); List<(float, string)> result = []; - foreach (Entity entity in entityCache) + foreach ((string name, Entity entity) in entityCache) { result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.name)); } @@ -219,10 +239,7 @@ public class Searchdomain public void UpdateModelsInUse() { - lock (modelsInUse) - { - modelsInUse = GetModels(entityCache); - } + modelsInUse = GetModels(entityCache); } private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary queryEmbeddings) @@ -243,24 +260,22 @@ public class Searchdomain return entity.probMethod(datapointProbs); } - public static List GetModels(ConcurrentBag entities) + public static ConcurrentBag GetModels(ConcurrentDictionary entities) { - List result = []; - foreach (Entity entity in entities) + ConcurrentBag result = []; + foreach (KeyValuePair element in entities) { + Entity entity = element.Value; lock (entity) { foreach (Datapoint datapoint in entity.datapoints) { - lock (entity.datapoints) + foreach ((string, float[]) tuple in datapoint.embeddings) { - foreach ((string, float[]) tuple in datapoint.embeddings) + string model = tuple.Item1; + if (!result.Contains(model)) { - string model = tuple.Item1; - if (!result.Contains(model)) - { - result.Add(model); - } + result.Add(model); } } } @@ -269,17 +284,13 @@ public class Searchdomain return result; } - public int GetID() + public async Task GetID() { - Dictionary parameters = new() + Dictionary parameters = new() { - ["name"] = this.searchdomain + { "name", this.searchdomain } }; - DbDataReader reader = helper.ExecuteSQLCommand("SELECT id from searchdomain WHERE name = @name", parameters); - reader.Read(); - this.id = reader.GetInt32(0); - reader.Close(); - return this.id; + return (await helper.ExecuteQueryAsync("SELECT id from searchdomain WHERE name = @name", parameters, x => x.GetInt32(0))).First(); } public SearchdomainSettings GetSettings() diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index 34e6948..47ebada 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -79,32 +79,17 @@ public class SearchdomainManager : IDisposable searchdomain.InvalidateSearchCache(); } - public List ListSearchdomains() + public async Task> ListSearchdomainsAsync() { - lock (helper.connection) - { - DbDataReader reader = helper.ExecuteSQLCommand("SELECT name FROM searchdomain", []); - List results = []; - try - { - while (reader.Read()) - { - results.Add(reader.GetString(0)); - } - return results; - } - finally - { - reader.Close(); - } - } + return await helper.ExecuteQueryAsync("SELECT name FROM searchdomain", [], x => x.GetString(0)); } - public int CreateSearchdomain(string searchdomain, SearchdomainSettings settings) + public async Task CreateSearchdomain(string searchdomain, SearchdomainSettings settings) { - return CreateSearchdomain(searchdomain, JsonSerializer.Serialize(settings)); + return await CreateSearchdomain(searchdomain, JsonSerializer.Serialize(settings)); } - public int CreateSearchdomain(string searchdomain, string settings = "{}") + + public async Task CreateSearchdomain(string searchdomain, string settings = "{}") { if (searchdomains.TryGetValue(searchdomain, out Searchdomain? value)) { @@ -116,18 +101,19 @@ public class SearchdomainManager : IDisposable { "name", searchdomain }, { "settings", settings} }; - return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters); + return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters); } - public int DeleteSearchdomain(string searchdomain) + public async Task DeleteSearchdomain(string searchdomain) { - int counter = _databaseHelper.RemoveAllEntities(helper, searchdomain); + int counter = await _databaseHelper.RemoveAllEntities(helper, searchdomain); _logger.LogDebug($"Number of entities deleted as part of deleting the searchdomain \"{searchdomain}\": {counter}"); - helper.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", searchdomain}}); + await helper.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", searchdomain}}); searchdomains.Remove(searchdomain); _logger.LogDebug($"Searchdomain has been successfully removed"); return counter; } + private Searchdomain SetSearchdomain(string name, Searchdomain searchdomain) { searchdomains[name] = searchdomain;