11 Commits

Author SHA1 Message Date
b5a8eec445 Added reranker exploration setup 2026-03-08 10:49:27 +01:00
LD50
6f6ded1d90 Merge pull request #130 from LD-Reborn/129-post-entity-only-does-upserting
129 post entity only does upserting
2026-02-22 20:00:08 +01:00
cda028f213 Fixed naming convention issues 2026-02-22 19:59:49 +01:00
0582ff9a6c Fixed Putting entities only upserts entities instead of also deleting non-existant ones 2026-02-22 19:48:26 +01:00
LD50
51d34cb06c Merge pull request #128 from LD-Reborn/104-embedding-cache-store-exception-on-shutdown
104 embedding cache store exception on shutdown
2026-02-21 22:23:31 +01:00
dbc5e9e6e8 Fixed UNIQUE constraint failed exception 2026-02-21 22:23:11 +01:00
820ecbc83b Fixed embeddings generation errors not propagating to response model in a user-friendly way, Fixed non-awaited SQL actions, Fixed connection pool filling up, fixed newly created searchdomain not found 2026-02-19 03:00:46 +01:00
LD50
cda8c61429 Merge pull request #127 from LD-Reborn/124-cannot-delete-large-searchdomains
Fixed entityCache not multithreading safe, Reduced expensive table jo…
2026-02-18 13:42:23 +01:00
f537912e4e Fixed entityCache not multithreading safe, Reduced expensive table joins for embedding, Fixed timeouts on large deletes, fixed possible unclosed readers, Improved EntityFromJSON speed, Added connection pool control for fault tolerance, Fixed modelsInUse multithreading safety 2026-02-18 13:41:55 +01:00
LD50
7a0363a470 Merge pull request #125 from LD-Reborn/118-searchdomainhelper-fix-non-bulk-queries
118 searchdomainhelper fix non bulk queries
2026-02-14 17:45:03 +01:00
4aabc3bae0 Fixed DatabaseInsertEmbeddingBulk, Added attributes bulk edit and delete, Fixed entityCache not multithreading safe, fixed EntityFromJSON missing bulk inserts, Added retry logic for BulkExecuteNonQuery, added MaxRequestBodySize configuration 2026-02-12 20:57:01 +01:00
24 changed files with 1445 additions and 636 deletions

View File

@@ -47,15 +47,27 @@ public class Client
return await FetchUrlAndProcessJson<EntityListResults>(HttpMethod.Get, url); return await FetchUrlAndProcessJson<EntityListResults>(HttpMethod.Get, url);
} }
public async Task<EntityIndexResult> EntityIndexAsync(List<JSONEntity> jsonEntity) public async Task<EntityIndexResult> EntityIndexAsync(List<JSONEntity> jsonEntity, string? sessionId = null, bool? sessionComplete = null)
{ {
return await EntityIndexAsync(JsonSerializer.Serialize(jsonEntity)); return await EntityIndexAsync(JsonSerializer.Serialize(jsonEntity), sessionId, sessionComplete);
} }
public async Task<EntityIndexResult> EntityIndexAsync(string jsonEntity) public async Task<EntityIndexResult> EntityIndexAsync(string jsonEntity, string? sessionId = null, bool? sessionComplete = null)
{ {
var content = new StringContent(jsonEntity, Encoding.UTF8, "application/json"); var content = new StringContent(jsonEntity, Encoding.UTF8, "application/json");
return await FetchUrlAndProcessJson<EntityIndexResult>(HttpMethod.Put, GetUrl($"{baseUri}", "Entities", []), content); Dictionary<string, string> parameters = [];
if (sessionId is not null) parameters.Add("sessionId", sessionId);
if (sessionComplete is not null) parameters.Add("sessionComplete", ((bool)sessionComplete).ToString());
return await FetchUrlAndProcessJson<EntityIndexResult>(
HttpMethod.Put,
GetUrl(
$"{baseUri}",
$"Entities",
parameters
),
content
);
} }
public async Task<EntityDeleteResults> EntityDeleteAsync(string entityName) public async Task<EntityDeleteResults> EntityDeleteAsync(string entityName)
@@ -148,6 +160,21 @@ public class Client
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null); return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null);
} }
/// <summary>
/// Sends a reranked query request (POST to the "QueryReranked" action below "{baseUri}/Searchdomain").
/// </summary>
/// <param name="searchdomain">Value of the "searchdomain" query parameter.</param>
/// <param name="query">Value of the "query" query parameter.</param>
/// <param name="rerankerModel">Value of the "rerankerModel" query parameter.</param>
/// <param name="topN">Value of the "topN" query parameter.</param>
/// <param name="topNRetrieval">Value of the "topNRetrieval" query parameter.</param>
/// <param name="returnAttributes">When true, a "returnAttributes" query parameter is added; otherwise it is omitted.</param>
/// <returns>The deserialized response of the request.</returns>
public async Task<EntityQueryResults> SearchdomainQueryRerankedAsync(string searchdomain, string query, string rerankerModel, int topN, int topNRetrieval, bool returnAttributes = false)
{
    // Parameters are added in a fixed order; the optional one always comes last.
    Dictionary<string, string> parameters = [];
    parameters["searchdomain"] = searchdomain;
    parameters["query"] = query;
    parameters["rerankerModel"] = rerankerModel.ToString();
    parameters["topN"] = topN.ToString();
    parameters["topNRetrieval"] = topNRetrieval.ToString();
    if (returnAttributes)
    {
        parameters["returnAttributes"] = returnAttributes.ToString();
    }

    string url = GetUrl($"{baseUri}/Searchdomain", "QueryReranked", parameters);
    return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, url, null);
}
public async Task<SearchdomainDeleteSearchResult> SearchdomainDeleteQueryAsync(string searchdomain, string query) public async Task<SearchdomainDeleteSearchResult> SearchdomainDeleteQueryAsync(string searchdomain, string query)
{ {
Dictionary<string, string> parameters = new() Dictionary<string, string> parameters = new()

View File

@@ -65,6 +65,7 @@ def index_files(toolset: Toolset):
jsonEntities.append(jsonEntity) jsonEntities.append(jsonEntity)
jsonstring = json.dumps(jsonEntities) jsonstring = json.dumps(jsonEntities)
timer_start = time.time() timer_start = time.time()
# Index all entities in one go. If you need to split it into chunks, use the session attributes! See example_chunked.py
result:EntityIndexResult = toolset.Client.EntityIndexAsync(jsonstring).Result result:EntityIndexResult = toolset.Client.EntityIndexAsync(jsonstring).Result
timer_end = time.time() timer_end = time.time()
toolset.Logger.LogInformation(f"Update was successful: {result.Success} - and was done in {timer_end - timer_start} seconds.") toolset.Logger.LogInformation(f"Update was successful: {result.Success} - and was done in {timer_end - timer_start} seconds.")

View File

@@ -0,0 +1,85 @@
import math
import os
from tools import *
import json
from dataclasses import asdict
import time
import uuid
# Directory whose files index_files() reads and turns into entities.
example_content = "./Scripts/example_content"
# Probability method name; also used as suffix of the example searchdomain name.
probmethod = "HVEWAvg"
# Similarity method name handed to every JSONDatapoint created in index_files().
similarityMethod = "Cosine"
# Searchdomain that init() creates if missing and that every entity is assigned to.
example_searchdomain = "example_" + probmethod
# Incremented on every update() call and logged by init() and update().
example_counter = 0
# Model identifiers handed to every JSONDatapoint created in index_files().
models = ["ollama:bge-m3", "ollama:mxbai-embed-large"]
# Probability methods passed to the datapoints and to the entity respectively.
probmethod_datapoint = probmethod
probmethod_entity = probmethod
# Example for a dictionary based weighted average:
# probmethod_datapoint = "DictionaryWeightedAverage:{\"ollama:bge-m3\": 4, \"ollama:mxbai-embed-large\": 1}"
# probmethod_entity = "DictionaryWeightedAverage:{\"title\": 2, \"filename\": 0.1, \"text\": 0.25}"
def init(toolset: Toolset):
    # Logs some example output, makes sure the example searchdomain exists
    # and logs the list of searchdomains known to the server.
    global example_counter
    toolset.Logger.LogInformation("{toolset.Name} - init", toolset.Name)
    toolset.Logger.LogInformation("This is the init function from the python example script")
    toolset.Logger.LogInformation(f"example_counter: {example_counter}")
    searchdomainlist:SearchdomainListResults = toolset.Client.SearchdomainListAsync().Result
    if example_searchdomain not in searchdomainlist.Searchdomains:
        # Create the missing searchdomain, then fetch the list again so the output below includes it.
        # NOTE(review): nesting of the re-fetch under this "if" is reconstructed - confirm against the original file.
        toolset.Client.SearchdomainCreateAsync(example_searchdomain).Result
        searchdomainlist = toolset.Client.SearchdomainListAsync().Result
    output = "Currently these searchdomains exist:\n"
    for searchdomain in searchdomainlist.Searchdomains:
        output += f" - {searchdomain}\n"
    toolset.Logger.LogInformation(output)
def update(toolset: Toolset):
    # Logs which kind of callback triggered the run, bumps the run counter
    # and then re-indexes the example files.
    global example_counter
    toolset.Logger.LogInformation("{toolset.Name} - update", toolset.Name)
    toolset.Logger.LogInformation("This is the update function from the python example script")
    callbackInfos:ICallbackInfos = toolset.CallbackInfos
    # The string form of the callback infos object identifies the trigger type.
    trigger_messages = {
        "Indexer.Models.RunOnceCallbackInfos": "It was triggered by a runonce call",
        "Indexer.Models.IntervalCallbackInfos": "It was triggered by an interval call",
        "Indexer.Models.ScheduleCallbackInfos": "It was triggered by a schedule call",
        "Indexer.Models.FileUpdateCallbackInfos": "It was triggered by a fileupdate call",
    }
    fallback_message = "It was triggered, but the origin of the call could not be determined"
    toolset.Logger.LogInformation(trigger_messages.get(str(callbackInfos), fallback_message))
    example_counter = example_counter + 1
    toolset.Logger.LogInformation(f"example_counter: {example_counter}")
    index_files(toolset)
def index_files(toolset: Toolset):
    # Reads every file in example_content, builds one entity per file
    # (datapoints: filename, title = first line, text = remainder) and uploads
    # the entities in chunks that share one indexing session.
    #
    # Only the final chunk is sent with sessionComplete=True; that is the
    # request on which the server removes entities that were not part of the
    # session, so it must be sent exactly once and only after all other chunks.
    jsonEntities:list = []
    for filename in os.listdir(example_content):
        qualified_filepath = example_content + "/" + filename
        with open(qualified_filepath, "r", encoding='utf-8', errors="replace") as file:
            title = file.readline()
            text = file.read()
        datapoints:list = [
            JSONDatapoint("filename", qualified_filepath, probmethod_datapoint, similarityMethod, models),
            JSONDatapoint("title", title, probmethod_datapoint, similarityMethod, models),
            JSONDatapoint("text", text, probmethod_datapoint, similarityMethod, models)
        ]
        jsonEntity:dict = asdict(JSONEntity(qualified_filepath, probmethod_entity, example_searchdomain, {}, datapoints))
        jsonEntities.append(jsonEntity)
    timer_start = time.time()
    chunkSize = 5
    chunkCount = math.ceil(len(jsonEntities) / chunkSize)
    sessionId = uuid.uuid4().hex
    print(f"indexing {len(jsonEntities)} entities")
    if chunkCount == 0:
        # Nothing to upload: no request is sent, so there is no result to report.
        toolset.Logger.LogInformation("No entities found to index - nothing was sent.")
        return
    success = True
    for i, entities in enumerate(chunk_list(jsonEntities, chunkSize)):
        # enumerate() is 0-based, so the last chunk has index chunkCount - 1.
        isLast = i == chunkCount - 1
        print(f'Processing chunk {i + 1} / {chunkCount}')
        jsonstring = json.dumps(entities)
        result:EntityIndexResult = toolset.Client.EntityIndexAsync(jsonstring, sessionId, isLast).Result
        if not result.Success:
            # Stop here: finalizing a session that misses a chunk would make the
            # server delete the entities of the failed chunk.
            success = False
            break
    timer_end = time.time()
    toolset.Logger.LogInformation(f"Update was successful: {success} - and was done in {timer_end - timer_start} seconds.")
def chunk_list(lst, chunk_size):
    # Lazily yields consecutive slices of lst holding at most chunk_size items each;
    # the final slice may be shorter.
    yield from (
        lst[start: start + chunk_size]
        for start in range(0, len(lst), chunk_size)
    )

View File

@@ -107,6 +107,8 @@ class Client:
# pass # pass
async def EntityIndexAsync(jsonEntity:str) -> EntityIndexResult: async def EntityIndexAsync(jsonEntity:str) -> EntityIndexResult:
pass pass
async def EntityIndexAsync(jsonEntity:str, sessionId:str, sessionComplete:bool) -> EntityIndexResult:
pass
async def EntityIndexAsync(searchdomain:str, jsonEntity:str) -> EntityIndexResult: async def EntityIndexAsync(searchdomain:str, jsonEntity:str) -> EntityIndexResult:
pass pass
async def EntityListAsync(returnEmbeddings:bool = False) -> EntityListResults: async def EntityListAsync(returnEmbeddings:bool = False) -> EntityListResults:

View File

@@ -13,7 +13,7 @@ public class AIProvider
{ {
private readonly ILogger<AIProvider> _logger; private readonly ILogger<AIProvider> _logger;
private readonly EmbeddingSearchOptions _configuration; private readonly EmbeddingSearchOptions _configuration;
public Dictionary<string, AiProvider> aIProvidersConfiguration; public Dictionary<string, AiProvider> AiProvidersConfiguration;
public AIProvider(ILogger<AIProvider> logger, IOptions<EmbeddingSearchOptions> configuration) public AIProvider(ILogger<AIProvider> logger, IOptions<EmbeddingSearchOptions> configuration)
{ {
@@ -27,7 +27,7 @@ public class AIProvider
} }
else else
{ {
aIProvidersConfiguration = retrievedAiProvidersConfiguration; AiProvidersConfiguration = retrievedAiProvidersConfiguration;
} }
} }
@@ -41,7 +41,7 @@ public class AIProvider
Uri uri = new(modelUri); Uri uri = new(modelUri);
string provider = uri.Scheme; string provider = uri.Scheme;
string model = uri.AbsolutePath; string model = uri.AbsolutePath;
AiProvider? aIProvider = aIProvidersConfiguration AiProvider? aIProvider = AiProvidersConfiguration
.FirstOrDefault(x => string.Equals(x.Key.ToLower(), provider.ToLower())) .FirstOrDefault(x => string.Equals(x.Key.ToLower(), provider.ToLower()))
.Value; .Value;
if (aIProvider is null) if (aIProvider is null)
@@ -109,10 +109,19 @@ public class AIProvider
{ {
JObject responseContentJson = JObject.Parse(responseContent); JObject responseContentJson = JObject.Parse(responseContent);
List<JToken>? responseContentTokens = [.. responseContentJson.SelectTokens(embeddingsJsonPath)]; List<JToken>? responseContentTokens = [.. responseContentJson.SelectTokens(embeddingsJsonPath)];
if (responseContentTokens is null) if (responseContentTokens is null || responseContentTokens.Count == 0)
{ {
_logger.LogError("Unable to select tokens using JSONPath {embeddingsJsonPath} for string: {responseContent}.", [embeddingsJsonPath, responseContent]); if (responseContentJson.TryGetValue("error", out JToken? errorMessageJson) && errorMessageJson is not null)
throw new JSONPathSelectionException(embeddingsJsonPath, responseContent); {
string errorMessage = errorMessageJson.Value<string>() ?? "";
_logger.LogError("Unable to retrieve embeddings due to error: {errorMessage}", [errorMessage]);
throw new Exception($"Unable to retrieve embeddings due to error: {errorMessage}");
} else
{
_logger.LogError("Unable to select tokens using JSONPath {embeddingsJsonPath} for string: {responseContent}.", [embeddingsJsonPath, responseContent]);
throw new JSONPathSelectionException(embeddingsJsonPath, responseContent);
}
} }
return [.. responseContentTokens.Select(token => token.ToObject<float[]>() ?? throw new Exception("Unable to cast embeddings response to float[]"))]; return [.. responseContentTokens.Select(token => token.ToObject<float[]>() ?? throw new Exception("Unable to cast embeddings response to float[]"))];
} }
@@ -123,9 +132,110 @@ public class AIProvider
} }
} }
/// <summary>
/// Asks the reranker model addressed by <paramref name="modelUri"/> to score
/// <paramref name="documents"/> against <paramref name="input"/>.
/// </summary>
/// <param name="modelUri">Model URI: the scheme selects the configured AI provider (case-insensitive), the path is sent as the model name.</param>
/// <param name="input">Query text sent to the reranker.</param>
/// <param name="documents">Candidate documents sent to the reranker.</param>
/// <param name="topN">Sent to the provider as <c>top_n</c>.</param>
/// <returns>
/// The (index, score) pairs in the order the provider returned them, fully materialized.
/// The index presumably refers to the position in <paramref name="documents"/> - confirm against the provider API.
/// </returns>
/// <exception cref="ServerConfigurationException">The provider is not configured or its handler does not support reranking.</exception>
/// <exception cref="JSONPathSelectionException">The response contains no (or mismatching) index/score values and no error information.</exception>
/// <exception cref="Exception">The provider answered with an error.</exception>
public IEnumerable<(int index, float score)> Rerank(string modelUri, string input, string[] documents, int topN)
{
    Uri uri = new(modelUri);
    string provider = uri.Scheme;
    string model = uri.AbsolutePath;
    AiProvider? aIProvider = AiProvidersConfiguration
        .FirstOrDefault(x => string.Equals(x.Key, provider, StringComparison.OrdinalIgnoreCase))
        .Value;
    if (aIProvider is null)
    {
        _logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", [provider, modelUri]);
        throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
    }

    string indexJsonPath;
    string scoreJsonPath;
    Uri baseUri = new(aIProvider.BaseURL);
    Uri requestUri;
    IRerankRequestBody rerankRequestBody;
    List<(string Name, string Value)> requestHeaders = [];
    switch (aIProvider.Handler)
    {
        case "openai":
            indexJsonPath = "$.results[*].index";
            scoreJsonPath = "$.results[*].relevance_score";
            requestUri = new Uri(baseUri, "/v1/rerank");
            rerankRequestBody = new OpenAIRerankRequestBody()
            {
                model = model,
                query = input,
                documents = documents,
                top_n = topN
            };
            if (aIProvider.ApiKey is not null)
            {
                requestHeaders.Add(("Authorization", $"Bearer {aIProvider.ApiKey}"));
            }
            break;
        default:
            _logger.LogError("Invalid reranking handler {aIProvider.Handler} in AiProvider {provider}.", [aIProvider.Handler, provider]);
            throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
    }

    // NOTE(review): a new HttpClient per call is kept from the original; consider a shared instance or IHttpClientFactory.
    using var httpClient = new HttpClient();
    httpClient.Timeout = TimeSpan.FromMinutes(150);
    using var request = new HttpRequestMessage(HttpMethod.Post, requestUri)
    {
        Content = new StringContent(
            JsonConvert.SerializeObject(rerankRequestBody),
            Encoding.UTF8,
            "application/json"
        )
    };
    foreach ((string name, string value) in requestHeaders)
    {
        request.Headers.Add(name, value);
    }
    // Send the prepared request message itself; posting only the content would drop the headers (e.g. Authorization).
    using HttpResponseMessage response = httpClient.SendAsync(request).Result;
    string responseContent = response.Content.ReadAsStringAsync().Result;
    try
    {
        JObject responseContentJson = JObject.Parse(responseContent);
        List<JToken> indexTokens = [.. responseContentJson.SelectTokens(indexJsonPath)];
        List<JToken> scoreTokens = [.. responseContentJson.SelectTokens(scoreJsonPath)];
        if (indexTokens.Count == 0 || scoreTokens.Count == 0)
        {
            if (responseContentJson.TryGetValue("error", out JToken? errorJson) && errorJson is not null)
            {
                string errorMessage;
                string errorCode = "";
                string errorType = "";
                if (errorJson is JObject errorObject)
                {
                    errorMessage = errorObject.Value<string>("message") ?? "";
                    errorCode = errorObject.Value<string>("code") ?? "";
                    errorType = errorObject.Value<string>("type") ?? "";
                }
                else
                {
                    // The error is not an object (e.g. a plain string): report it as-is.
                    errorMessage = errorJson.ToString();
                }
                _logger.LogError("Unable to retrieve reranking results due to error: {errorCode} - {errorMessage} - {errorType}", [errorCode, errorMessage, errorType]);
                throw new Exception($"Unable to retrieve reranking results due to error: {errorMessage}");
            }

            // Report the path that actually yielded nothing.
            string missingJsonPath = indexTokens.Count == 0 ? indexJsonPath : scoreJsonPath;
            _logger.LogError("Unable to select tokens using JSONPath {missingJsonPath} for string: {responseContent}.", [missingJsonPath, responseContent]);
            throw new JSONPathSelectionException(missingJsonPath, responseContent);
        }
        if (indexTokens.Count != scoreTokens.Count)
        {
            // Pairing lists of different length would silently attach scores to the wrong indices.
            _logger.LogError("Reranking response contains {indexCount} indices but {scoreCount} scores: {responseContent}.", [indexTokens.Count, scoreTokens.Count, responseContent]);
            throw new JSONPathSelectionException(scoreJsonPath, responseContent);
        }

        // Materialize inside the try block so conversion errors are logged here instead of surfacing later at the caller.
        return indexTokens
            .Zip(scoreTokens, (indexToken, scoreToken) => (index: indexToken.ToObject<int>(), score: scoreToken.ToObject<float>()))
            .ToList();
    }
    catch (Exception ex)
    {
        _logger.LogError("Unable to parse the response to valid reranking results. {ex.Message}", [ex.Message]);
        throw;
    }
}
public string[] GetModels() public string[] GetModels()
{ {
var aIProviders = aIProvidersConfiguration; var aIProviders = AiProvidersConfiguration;
List<string> results = []; List<string> results = [];
foreach (KeyValuePair<string, AiProvider> aIProviderKV in aIProviders) foreach (KeyValuePair<string, AiProvider> aIProviderKV in aIProviders)
{ {
@@ -231,3 +341,15 @@ public class OpenAIEmbedRequestBody : IEmbedRequestBody
public required string model { get; set; } public required string model { get; set; }
public required string[] input { get; set; } public required string[] input { get; set; }
} }
/// <summary>
/// Marker interface for the request body of a rerank call; the concrete body is chosen per provider handler.
/// </summary>
public interface IRerankRequestBody { }
/// <summary>
/// Request body used for the "openai" handler when calling the /v1/rerank endpoint.
/// The object is serialized to JSON as-is, so the lowercase member names are the field names that get sent.
/// </summary>
public class OpenAIRerankRequestBody : IRerankRequestBody
{
    /// <summary>Name of the reranker model.</summary>
    public required string model { get; set; }
    /// <summary>Query text the documents are ranked against.</summary>
    public required string query { get; set; }
    /// <summary>Filled from the topN argument of the rerank call.</summary>
    public required int top_n { get; set; }
    /// <summary>Documents to rerank.</summary>
    public required string[] documents { get; set; }
}

View File

@@ -14,6 +14,9 @@ public class EntityController : ControllerBase
private SearchdomainManager _domainManager; private SearchdomainManager _domainManager;
private readonly SearchdomainHelper _searchdomainHelper; private readonly SearchdomainHelper _searchdomainHelper;
private readonly DatabaseHelper _databaseHelper; private readonly DatabaseHelper _databaseHelper;
// Index sessions span several HTTP requests. MVC controllers are created anew for every request by default,
// so the session store and its lock must be static - as instance fields each request would start with an empty store
// and a chunked upload would lose the entities accumulated by earlier requests.
private static readonly Dictionary<string, EntityIndexSessionData> _sessions = [];
private static readonly object _sessionLock = new();
private const int SessionTimeoutMinutes = 60; // TODO: remove magic number; add an optional configuration option
public EntityController(ILogger<EntityController> logger, IConfiguration config, SearchdomainManager domainManager, SearchdomainHelper searchdomainHelper, DatabaseHelper databaseHelper) public EntityController(ILogger<EntityController> logger, IConfiguration config, SearchdomainManager domainManager, SearchdomainHelper searchdomainHelper, DatabaseHelper databaseHelper)
{ {
@@ -46,34 +49,34 @@ public class EntityController : ControllerBase
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
EntityListResults entityListResults = new() {Results = [], Success = true}; EntityListResults entityListResults = new() {Results = [], Success = true};
foreach (Entity entity in searchdomain_.entityCache) foreach ((string _, Entity entity) in searchdomain_.EntityCache)
{ {
List<AttributeResult> attributeResults = []; List<AttributeResult> attributeResults = [];
foreach (KeyValuePair<string, string> attribute in entity.attributes) foreach (KeyValuePair<string, string> attribute in entity.Attributes)
{ {
attributeResults.Add(new AttributeResult() {Name = attribute.Key, Value = attribute.Value}); attributeResults.Add(new AttributeResult() {Name = attribute.Key, Value = attribute.Value});
} }
List<DatapointResult> datapointResults = []; List<DatapointResult> datapointResults = [];
foreach (Datapoint datapoint in entity.datapoints) foreach (Datapoint datapoint in entity.Datapoints)
{ {
if (returnModels) if (returnModels)
{ {
List<EmbeddingResult> embeddingResults = []; List<EmbeddingResult> embeddingResults = [];
foreach ((string, float[]) embedding in datapoint.embeddings) foreach ((string, float[]) embedding in datapoint.Embeddings)
{ {
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = returnEmbeddings ? embedding.Item2 : []}); embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = returnEmbeddings ? embedding.Item2 : []});
} }
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = embeddingResults}); datapointResults.Add(new DatapointResult() {Name = datapoint.Name, ProbMethod = datapoint.ProbMethod.Name, SimilarityMethod = datapoint.SimilarityMethod.Name, Embeddings = embeddingResults});
} }
else else
{ {
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = null}); datapointResults.Add(new DatapointResult() {Name = datapoint.Name, ProbMethod = datapoint.ProbMethod.Name, SimilarityMethod = datapoint.SimilarityMethod.Name, Embeddings = null});
} }
} }
EntityListResult entityListResult = new() EntityListResult entityListResult = new()
{ {
Name = entity.name, Name = entity.Name,
ProbMethod = entity.probMethodName, ProbMethod = entity.ProbMethodName,
Attributes = attributeResults, Attributes = attributeResults,
Datapoints = datapointResults Datapoints = datapointResults
}; };
@@ -86,31 +89,59 @@ public class EntityController : ControllerBase
/// Index entities /// Index entities
/// </summary> /// </summary>
/// <remarks> /// <remarks>
/// Behavior: Creates new entities, but overwrites existing entities that have the same name /// Behavior: Updates the index using the provided entities. Creates new entities, overwrites existing entities with the same name, and deletes entities that are not part of the index anymore.
///
/// Can be executed in a single request or in multiple chunks using a (self-defined) session UUID string.
///
/// For session-based chunk uploads:
/// - Provide sessionId to accumulate entities across multiple requests
/// - Set sessionComplete=true on the final request to finalize and delete entities that are not in the accumulated list
/// - Without sessionId: Missing entities will be deleted from the searchdomain.
/// - Sessions expire after 60 minutes of inactivity (or as otherwise configured in the appsettings)
/// </remarks> /// </remarks>
/// <param name="jsonEntities">Entities to index</param> /// <param name="jsonEntities">Entities to index</param>
/// <param name="sessionId">Optional session ID for batch uploads across multiple requests</param>
/// <param name="sessionComplete">If true, finalizes the session and deletes entities not in the accumulated list</param>
[HttpPut("/Entities")] [HttpPut("/Entities")]
public ActionResult<EntityIndexResult> Index([FromBody] List<JSONEntity>? jsonEntities) public async Task<ActionResult<EntityIndexResult>> Index(
[FromBody] List<JSONEntity>? jsonEntities,
string? sessionId = null,
bool sessionComplete = false)
{ {
try try
{ {
List<Entity>? entities = _searchdomainHelper.EntitiesFromJSON( if (sessionId is null || string.IsNullOrWhiteSpace(sessionId))
{
sessionId = Guid.NewGuid().ToString(); // Create a short-lived session
sessionComplete = true; // If no sessionId was set, there is no trackable session. The pseudo-session ends here.
}
// Periodic cleanup of expired sessions
CleanupExpiredEntityIndexSessions();
EntityIndexSessionData session = GetOrCreateEntityIndexSession(sessionId);
if (jsonEntities is null && !sessionComplete)
{
return BadRequest(new EntityIndexResult() { Success = false, Message = "jsonEntities can only be null for a complete session" });
} else if (jsonEntities is null && sessionComplete)
{
await EntityIndexSessionDeleteUnindexedEntities(session);
return Ok(new EntityIndexResult() { Success = true });
}
// Standard entity indexing (upsert behavior)
List<Entity>? entities = await _searchdomainHelper.EntitiesFromJSON(
_domainManager, _domainManager,
_logger, _logger,
JsonSerializer.Serialize(jsonEntities)); JsonSerializer.Serialize(jsonEntities));
if (entities is not null && jsonEntities is not null) if (entities is not null && jsonEntities is not null)
{ {
List<string> invalidatedSearchdomains = []; session.AccumulatedEntities.AddRange(entities);
foreach (var jsonEntity in jsonEntities)
if (sessionComplete)
{ {
string jsonEntityName = jsonEntity.Name; await EntityIndexSessionDeleteUnindexedEntities(session);
string jsonEntitySearchdomainName = jsonEntity.Searchdomain;
if (entities.Select(x => x.name == jsonEntityName).Any()
&& !invalidatedSearchdomains.Contains(jsonEntitySearchdomainName))
{
invalidatedSearchdomains.Add(jsonEntitySearchdomainName);
}
} }
return Ok(new EntityIndexResult() { Success = true }); return Ok(new EntityIndexResult() { Success = true });
} }
else else
@@ -129,18 +160,56 @@ public class EntityController : ControllerBase
} }
/// <summary>
/// Finalizes an index session: for every searchdomain that occurs in the session's accumulated entities,
/// all cached entities whose name is not part of the accumulated entities are deleted
/// (query cache reconciliation, database removal, entity cache removal). Afterwards the session is released.
/// </summary>
/// <remarks>
/// Searchdomains that do not occur in the accumulated entities are not touched.
/// A searchdomain that cannot be retrieved is skipped with a warning.
/// </remarks>
/// <param name="session">The session whose accumulated entities describe the desired index content.</param>
private async Task EntityIndexSessionDeleteUnindexedEntities(EntityIndexSessionData session)
{
    var entityGroupsBySearchdomain = session.AccumulatedEntities.GroupBy(e => e.Searchdomain);

    foreach (var entityGroup in entityGroupsBySearchdomain)
    {
        string searchdomainName = entityGroup.Key;
        var entityNamesInRequest = entityGroup.Select(e => e.Name).ToHashSet();

        (Searchdomain? searchdomain_, int? httpStatusCode, string? message) =
            SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomainName, _logger);

        if (searchdomain_ is not null && httpStatusCode is null) // If getting searchdomain was successful
        {
            // Snapshot first: the cache is modified while the entities are being deleted.
            var entitiesToDelete = searchdomain_.EntityCache
                .Where(kvp => !entityNamesInRequest.Contains(kvp.Value.Name))
                .Select(kvp => kvp.Value)
                .ToList();

            foreach (var entity in entitiesToDelete)
            {
                searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity);
                await _databaseHelper.RemoveEntity(
                    [],
                    _domainManager.Helper,
                    entity.Name,
                    searchdomainName);
                searchdomain_.EntityCache.TryRemove(entity.Name, out _);
                _logger.LogInformation("Deleted entity {entityName} from {searchdomain}", entity.Name, searchdomainName);
            }
        }
        else
        {
            _logger.LogWarning("Unable to delete entities for searchdomain {searchdomain}", searchdomainName);
        }
    }

    // The session is complete: release it so it neither keeps its entities in memory until the timeout
    // nor gets reported as an expired, non-closed session by the cleanup.
    lock (_sessionLock)
    {
        var completedSessionIds = _sessions
            .Where(kvp => ReferenceEquals(kvp.Value, session))
            .Select(kvp => kvp.Key)
            .ToList();
        foreach (var completedSessionId in completedSessionIds)
        {
            _sessions.Remove(completedSessionId);
        }
    }
}
/// <summary> /// <summary>
/// Deletes an entity /// Deletes an entity
/// </summary> /// </summary>
/// <param name="searchdomain">Name of the searchdomain</param> /// <param name="searchdomain">Name of the searchdomain</param>
/// <param name="entityName">Name of the entity</param> /// <param name="entityName">Name of the entity</param>
[HttpDelete] [HttpDelete]
public ActionResult<EntityDeleteResults> Delete(string searchdomain, string entityName) public async Task<ActionResult<EntityDeleteResults>> Delete(string searchdomain, string entityName)
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
Entity? entity_ = SearchdomainHelper.CacheGetEntity(searchdomain_.entityCache, entityName); Entity? entity_ = SearchdomainHelper.CacheGetEntity(searchdomain_.EntityCache, entityName);
if (entity_ is null) if (entity_ is null)
{ {
_logger.LogError("Unable to delete the entity {entityName} in {searchdomain} - it was not found under the specified name", [entityName, searchdomain]); _logger.LogError("Unable to delete the entity {entityName} in {searchdomain} - it was not found under the specified name", [entityName, searchdomain]);
@@ -152,8 +221,50 @@ public class EntityController : ControllerBase
return Ok(new EntityDeleteResults() {Success = false, Message = "Entity not found"}); return Ok(new EntityDeleteResults() {Success = false, Message = "Entity not found"});
} }
searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_); searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_);
_databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain); await _databaseHelper.RemoveEntity([], _domainManager.Helper, entityName, searchdomain);
searchdomain_.entityCache.RemoveAll(entity => entity.name == entityName);
return Ok(new EntityDeleteResults() {Success = true}); bool success = searchdomain_.EntityCache.TryRemove(entityName, out Entity? _);
return Ok(new EntityDeleteResults() {Success = success});
}
/// <summary>
/// Drops every index session whose last interaction lies more than
/// <see cref="SessionTimeoutMinutes"/> minutes in the past and logs a warning for each of them.
/// </summary>
private void CleanupExpiredEntityIndexSessions()
{
    lock (_sessionLock)
    {
        // Collect the ids first; the dictionary must not be modified while it is enumerated.
        List<string> staleSessionIds = [];
        foreach ((string candidateId, EntityIndexSessionData candidate) in _sessions)
        {
            TimeSpan idleTime = DateTime.UtcNow - candidate.LastInteractionAt;
            if (idleTime.TotalMinutes > SessionTimeoutMinutes)
            {
                staleSessionIds.Add(candidateId);
            }
        }

        foreach (string staleSessionId in staleSessionIds)
        {
            _sessions.Remove(staleSessionId);
            _logger.LogWarning("Removed expired, non-closed session {sessionId}", staleSessionId);
        }
    }
}
/// <summary>
/// Returns the index session stored under <paramref name="sessionId"/>.
/// A known session gets its last-interaction timestamp refreshed; an unknown id gets a new, empty session registered.
/// </summary>
/// <param name="sessionId">Identifier of the session.</param>
/// <returns>The existing or newly created session data.</returns>
private EntityIndexSessionData GetOrCreateEntityIndexSession(string sessionId)
{
    lock (_sessionLock)
    {
        if (_sessions.TryGetValue(sessionId, out EntityIndexSessionData? existingSession))
        {
            existingSession.LastInteractionAt = DateTime.UtcNow;
            return existingSession;
        }

        EntityIndexSessionData createdSession = new();
        _sessions[sessionId] = createdSession;
        return createdSession;
    }
}
} }
/// <summary>
/// State of one entity index session.
/// </summary>
public class EntityIndexSessionData
{
    /// <summary>Entities collected so far for this session; starts out empty.</summary>
    public List<Entity> AccumulatedEntities { get; set; } = new();

    /// <summary>UTC time of the last interaction with this session; initialized to the creation time.</summary>
    public DateTime LastInteractionAt { get; set; } = DateTime.UtcNow;
}

View File

@@ -35,11 +35,11 @@ public class HomeController : Controller
[Authorize] [Authorize]
[HttpGet("Searchdomains")] [HttpGet("Searchdomains")]
public IActionResult Searchdomains() public async Task<ActionResult> Searchdomains()
{ {
HomeIndexViewModel viewModel = new() HomeIndexViewModel viewModel = new()
{ {
Searchdomains = _domainManager.ListSearchdomains() Searchdomains = await _domainManager.ListSearchdomainsAsync()
}; };
return View(viewModel); return View(viewModel);
} }

View File

@@ -29,12 +29,12 @@ public class SearchdomainController : ControllerBase
/// Lists all searchdomains /// Lists all searchdomains
/// </summary> /// </summary>
[HttpGet("/Searchdomains")] [HttpGet("/Searchdomains")]
public ActionResult<SearchdomainListResults> List() public async Task<ActionResult<SearchdomainListResults>> List()
{ {
List<string> results; List<string> results;
try try
{ {
results = _domainManager.ListSearchdomains(); results = await _domainManager.ListSearchdomainsAsync();
} }
catch (Exception) catch (Exception)
{ {
@@ -51,7 +51,7 @@ public class SearchdomainController : ControllerBase
/// <param name="searchdomain">Name of the searchdomain</param> /// <param name="searchdomain">Name of the searchdomain</param>
/// <param name="settings">Optional initial settings</param> /// <param name="settings">Optional initial settings</param>
[HttpPost] [HttpPost]
public ActionResult<SearchdomainCreateResults> Create([Required]string searchdomain, [FromBody]SearchdomainSettings settings = new()) public async Task<ActionResult<SearchdomainCreateResults>> Create([Required]string searchdomain, [FromBody]SearchdomainSettings settings = new())
{ {
try try
{ {
@@ -59,7 +59,7 @@ public class SearchdomainController : ControllerBase
{ {
settings.QueryCacheSize = 1_000_000; // TODO get rid of this magic number settings.QueryCacheSize = 1_000_000; // TODO get rid of this magic number
} }
int id = _domainManager.CreateSearchdomain(searchdomain, settings); int id = await _domainManager.CreateSearchdomain(searchdomain, settings);
return Ok(new SearchdomainCreateResults(){Id = id, Success = true}); return Ok(new SearchdomainCreateResults(){Id = id, Success = true});
} catch (Exception) } catch (Exception)
{ {
@@ -73,7 +73,7 @@ public class SearchdomainController : ControllerBase
/// </summary> /// </summary>
/// <param name="searchdomain">Name of the searchdomain</param> /// <param name="searchdomain">Name of the searchdomain</param>
[HttpDelete] [HttpDelete]
public ActionResult<SearchdomainDeleteResults> Delete([Required]string searchdomain) public async Task<ActionResult<SearchdomainDeleteResults>> Delete([Required]string searchdomain)
{ {
bool success; bool success;
int deletedEntries; int deletedEntries;
@@ -81,7 +81,7 @@ public class SearchdomainController : ControllerBase
try try
{ {
success = true; success = true;
deletedEntries = _domainManager.DeleteSearchdomain(searchdomain); deletedEntries = await _domainManager.DeleteSearchdomain(searchdomain);
} }
catch (SearchdomainNotFoundException ex) catch (SearchdomainNotFoundException ex)
{ {
@@ -109,7 +109,7 @@ public class SearchdomainController : ControllerBase
/// <param name="newName">Updated name of the searchdomain</param> /// <param name="newName">Updated name of the searchdomain</param>
/// <param name="settings">Updated settings of searchdomain</param> /// <param name="settings">Updated settings of searchdomain</param>
[HttpPut] [HttpPut]
public ActionResult<SearchdomainUpdateResults> Update([Required]string searchdomain, string newName, [FromBody]SearchdomainSettings? settings) public async Task<ActionResult<SearchdomainUpdateResults>> Update([Required]string searchdomain, string newName, [FromBody]SearchdomainSettings? settings)
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
@@ -118,18 +118,18 @@ public class SearchdomainController : ControllerBase
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{"name", newName}, {"name", newName},
{"id", searchdomain_.id} {"id", searchdomain_.Id}
}; };
searchdomain_.helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name WHERE id = @id", parameters); await searchdomain_.Helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name WHERE id = @id", parameters);
} else } else
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{"name", newName}, {"name", newName},
{"settings", settings}, {"settings", settings},
{"id", searchdomain_.id} {"id", searchdomain_.Id}
}; };
searchdomain_.helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name, settings = @settings WHERE id = @id", parameters); await searchdomain_.Helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name, settings = @settings WHERE id = @id", parameters);
} }
return Ok(new SearchdomainUpdateResults(){Success = true}); return Ok(new SearchdomainUpdateResults(){Success = true});
} }
@@ -143,11 +143,83 @@ public class SearchdomainController : ControllerBase
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
Dictionary<string, DateTimedSearchResult> searchCache = searchdomain_.queryCache.AsDictionary(); Dictionary<string, DateTimedSearchResult> searchCache = searchdomain_.QueryCache.AsDictionary();
return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true }); return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
} }
/// <summary>
/// Executes a query in the searchdomain and reranks the result using a specified reranker
/// </summary>
/// <remarks>
/// The entities found by the search are reranked once per attribute name: the values of that
/// attribute across all retrieved entities are sent to the reranker as documents. The per-attribute
/// scores of each entity are then combined into one score using <paramref name="probMethod"/>.
/// </remarks>
/// <param name="searchdomain">Name of the searchdomain</param>
/// <param name="query">Query to execute</param>
/// <param name="rerankerModel">Name of the reranker model passed to the AI provider</param>
/// <param name="topN">Return only the top N results (also forwarded to every per-attribute rerank call)</param>
/// <param name="topNRetrieval">Number of results requested from the initial search before reranking</param>
/// <param name="probMethod">Method used to combine the per-attribute rerank scores of an entity</param>
/// <param name="returnAttributes">Return the attributes of the object</param>
[HttpPost("QueryReranked")]
public ActionResult<EntityRerankResults> QueryReranked([Required]string searchdomain, [Required]string query, [Required]string rerankerModel, int topN, int topNRetrieval, ProbMethodEnum probMethod = ProbMethodEnum.HVEWAvg, bool returnAttributes = false)
{
    (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
    if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});

    List<(float, string)> results = searchdomain_.Search(query, topNRetrieval);
    List<(string Name, Dictionary<string, string> Attributes)> queryResults = [.. results.Select(r => (
        Name: r.Item2,
        Attributes: searchdomain_.EntityCache[r.Item2]?.Attributes ?? []
    ))];

    // Key: Attribute name
    Dictionary<string, List<(string EntityName, string AttributeValue)>> resultsByAttribute = [];
    foreach ((string Name, Dictionary<string, string> Attributes) queryResult in queryResults)
    {
        foreach (var kv in queryResult.Attributes)
        {
            if (!resultsByAttribute.TryGetValue(kv.Key, out List<(string EntityName, string AttributeValue)>? values) || values is null)
            {
                values = [];
                resultsByAttribute[kv.Key] = values;
            }
            values.Add((queryResult.Name, kv.Value));
        }
    }

    // Key: EntityName
    Dictionary<string, List<(string attribute, float score)>> scoresByEntity = [];
    foreach (var kv in resultsByAttribute)
    {
        string attributeName = kv.Key;
        List<(string EntityName, string AttributeValue)> nameValuePairs = kv.Value;
        List<string> documents = [.. nameValuePairs.Select(r => r.AttributeValue)];
        List<(int index, float score)> rerankResults = [.. searchdomain_.AiProvider.Rerank(rerankerModel, query, [.. documents], topN)];
        foreach ((int index, float score) in rerankResults)
        {
            // The reranker answers with positions into `documents`, which is index-aligned with nameValuePairs.
            string entityName = nameValuePairs[index].EntityName;
            if (!scoresByEntity.TryGetValue(entityName, out List<(string attribute, float score)>? values) || values is null)
            {
                values = [];
                scoresByEntity[entityName] = values;
            }
            values.Add((attributeName, score));
        }
    }

    // One aggregator for the whole request instead of one per entity.
    ProbMethod scoreAggregator = new(probMethod);

    // NOTE(review): assumes a higher aggregated score means a better match - confirm against ProbMethod.
    IEnumerable<EntityRerankResult> rankedResults = scoresByEntity
        .Select(scoreKV => new EntityRerankResult()
        {
            Name = scoreKV.Key,
            Value = scoreAggregator.Method(scoreKV.Value),
            Attributes = returnAttributes ? (searchdomain_.EntityCache[scoreKV.Key]?.Attributes ?? []) : null
        })
        .OrderByDescending(result => result.Value);

    // Each attribute is reranked separately, so the union of the per-attribute top N lists can
    // contain more than topN entities. Enforce the documented limit on the final list.
    // A non-positive topN (e.g. parameter omitted) leaves the list untruncated.
    if (topN > 0)
    {
        rankedResults = rankedResults.Take(topN);
    }

    List<EntityRerankResult> entityRerankResults = [.. rankedResults];
    return Ok(new EntityRerankResults(){Results = entityRerankResults, Success = true });
}
/// <summary> /// <summary>
/// Executes a query in the searchdomain /// Executes a query in the searchdomain
/// </summary> /// </summary>
@@ -165,7 +237,7 @@ public class SearchdomainController : ControllerBase
{ {
Name = r.Item2, Name = r.Item2,
Value = r.Item1, Value = r.Item1,
Attributes = returnAttributes ? (searchdomain_.entityCache.FirstOrDefault(x => x.name == r.Item2)?.attributes ?? null) : null Attributes = returnAttributes ? (searchdomain_.EntityCache[r.Item2]?.Attributes ?? null) : null
})]; })];
return Ok(new EntityQueryResults(){Results = queryResults, Success = true }); return Ok(new EntityQueryResults(){Results = queryResults, Success = true });
} }
@@ -180,7 +252,7 @@ public class SearchdomainController : ControllerBase
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.queryCache; EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.QueryCache;
bool containsKey = searchCache.ContainsKey(query); bool containsKey = searchCache.ContainsKey(query);
if (containsKey) if (containsKey)
{ {
@@ -201,7 +273,7 @@ public class SearchdomainController : ControllerBase
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.queryCache; EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.QueryCache;
bool containsKey = searchCache.ContainsKey(query); bool containsKey = searchCache.ContainsKey(query);
if (containsKey) if (containsKey)
{ {
@@ -222,7 +294,7 @@ public class SearchdomainController : ControllerBase
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
SearchdomainSettings settings = searchdomain_.settings; SearchdomainSettings settings = searchdomain_.Settings;
return Ok(new SearchdomainSettingsResults() { Settings = settings, Success = true }); return Ok(new SearchdomainSettingsResults() { Settings = settings, Success = true });
} }
@@ -230,19 +302,20 @@ public class SearchdomainController : ControllerBase
/// Update the settings of a searchdomain /// Update the settings of a searchdomain
/// </summary> /// </summary>
/// <param name="searchdomain">Name of the searchdomain</param> /// <param name="searchdomain">Name of the searchdomain</param>
/// <param name="request">Settings to apply to the searchdomain</param>
[HttpPut("Settings")] [HttpPut("Settings")]
public ActionResult<SearchdomainUpdateResults> UpdateSettings([Required]string searchdomain, [Required][FromBody] SearchdomainSettings request) public async Task<ActionResult<SearchdomainUpdateResults>> UpdateSettings([Required]string searchdomain, [Required][FromBody] SearchdomainSettings request)
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{"settings", JsonSerializer.Serialize(request)}, {"settings", JsonSerializer.Serialize(request)},
{"id", searchdomain_.id} {"id", searchdomain_.Id}
}; };
searchdomain_.helper.ExecuteSQLNonQuery("UPDATE searchdomain set settings = @settings WHERE id = @id", parameters); await searchdomain_.Helper.ExecuteSQLNonQuery("UPDATE searchdomain set settings = @settings WHERE id = @id", parameters);
searchdomain_.settings = request; searchdomain_.Settings = request;
searchdomain_.queryCache.Capacity = request.QueryCacheSize; searchdomain_.QueryCache.Capacity = request.QueryCacheSize;
return Ok(new SearchdomainUpdateResults(){Success = true}); return Ok(new SearchdomainUpdateResults(){Success = true});
} }
@@ -259,8 +332,8 @@ public class SearchdomainController : ControllerBase
} }
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
int elementCount = searchdomain_.queryCache.Count; int elementCount = searchdomain_.QueryCache.Count;
int ElementMaxCount = searchdomain_.settings.QueryCacheSize; int ElementMaxCount = searchdomain_.Settings.QueryCacheSize;
return Ok(new SearchdomainQueryCacheSizeResults() { SizeBytes = searchdomain_.GetSearchCacheSize(), ElementCount = elementCount, ElementMaxCount = ElementMaxCount, Success = true }); return Ok(new SearchdomainQueryCacheSizeResults() { SizeBytes = searchdomain_.GetSearchCacheSize(), ElementCount = elementCount, ElementMaxCount = ElementMaxCount, Success = true });
} }
@@ -286,7 +359,7 @@ public class SearchdomainController : ControllerBase
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
long EmbeddingCacheUtilization = DatabaseHelper.GetSearchdomainDatabaseSize(searchdomain_.helper, searchdomain); long EmbeddingCacheUtilization = DatabaseHelper.GetSearchdomainDatabaseSize(searchdomain_.Helper, searchdomain);
return Ok(new SearchdomainGetDatabaseSizeResult() { SearchdomainDatabaseSizeBytes = EmbeddingCacheUtilization, Success = true }); return Ok(new SearchdomainGetDatabaseSizeResult() { SearchdomainDatabaseSizeBytes = EmbeddingCacheUtilization, Success = true });
} }
} }

View File

@@ -58,7 +58,7 @@ public class ServerController : ControllerBase
long size = 0; long size = 0;
long elementCount = 0; long elementCount = 0;
long embeddingsCount = 0; long embeddingsCount = 0;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.EmbeddingCache;
foreach (KeyValuePair<string, Dictionary<string, float[]>> kv in embeddingCache) foreach (KeyValuePair<string, Dictionary<string, float[]>> kv in embeddingCache)
{ {
@@ -68,23 +68,23 @@ public class ServerController : ControllerBase
elementCount++; elementCount++;
embeddingsCount += entry.Keys.Count; embeddingsCount += entry.Keys.Count;
} }
var sqlHelper = DatabaseHelper.GetSQLHelper(_options.Value); var sqlHelper = _searchdomainManager.Helper;
var databaseTotalSize = DatabaseHelper.GetTotalDatabaseSize(sqlHelper); var databaseTotalSize = DatabaseHelper.GetTotalDatabaseSize(sqlHelper);
Task<long> entityCountTask = DatabaseHelper.CountEntities(sqlHelper); Task<long> entityCountTask = DatabaseHelper.CountEntities(sqlHelper);
long queryCacheUtilization = 0; long queryCacheUtilization = 0;
long queryCacheElementCount = 0; long queryCacheElementCount = 0;
long queryCacheMaxElementCountAll = 0; long queryCacheMaxElementCountAll = 0;
long queryCacheMaxElementCountLoadedSearchdomainsOnly = 0; long queryCacheMaxElementCountLoadedSearchdomainsOnly = 0;
foreach (string searchdomain in _searchdomainManager.ListSearchdomains()) foreach (string searchdomain in await _searchdomainManager.ListSearchdomainsAsync())
{ {
if (SearchdomainHelper.IsSearchdomainLoaded(_searchdomainManager, searchdomain)) if (SearchdomainHelper.IsSearchdomainLoaded(_searchdomainManager, searchdomain))
{ {
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_searchdomainManager, searchdomain, _logger); (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_searchdomainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new ServerGetStatsResult(){Success = false, Message = message}); if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new ServerGetStatsResult(){Success = false, Message = message});
queryCacheUtilization += searchdomain_.GetSearchCacheSize(); queryCacheUtilization += searchdomain_.GetSearchCacheSize();
queryCacheElementCount += searchdomain_.queryCache.Count; queryCacheElementCount += searchdomain_.QueryCache.Count;
queryCacheMaxElementCountAll += searchdomain_.queryCache.Capacity; queryCacheMaxElementCountAll += searchdomain_.QueryCache.Capacity;
queryCacheMaxElementCountLoadedSearchdomainsOnly += searchdomain_.queryCache.Capacity; queryCacheMaxElementCountLoadedSearchdomainsOnly += searchdomain_.QueryCache.Capacity;
} else } else
{ {
var searchdomainSettings = DatabaseHelper.GetSearchdomainSettings(sqlHelper, searchdomain); var searchdomainSettings = DatabaseHelper.GetSearchdomainSettings(sqlHelper, searchdomain);

View File

@@ -1,41 +1,56 @@
using System.Collections.Concurrent;
using Shared; using Shared;
using Shared.Models;
namespace Server; namespace Server;
public class Datapoint public class Datapoint
{ {
public string name; public string Name;
public ProbMethod probMethod; public ProbMethod ProbMethod;
public SimilarityMethod similarityMethod; public SimilarityMethod SimilarityMethod;
public List<(string, float[])> embeddings; public List<(string, float[])> Embeddings;
public string hash; public string Hash;
public int Id;
public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings) public Datapoint(string name, ProbMethodEnum probMethod, SimilarityMethodEnum similarityMethod, string hash, List<(string, float[])> embeddings, int id)
{ {
this.name = name; Name = name;
this.probMethod = probMethod; ProbMethod = new ProbMethod(probMethod);
this.similarityMethod = similarityMethod; SimilarityMethod = new SimilarityMethod(similarityMethod);
this.hash = hash; Hash = hash;
this.embeddings = embeddings; Embeddings = embeddings;
Id = id;
}
public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings, int id)
{
Name = name;
ProbMethod = probMethod;
SimilarityMethod = similarityMethod;
Hash = hash;
Embeddings = embeddings;
Id = id;
} }
public float CalcProbability(List<(string, float)> probabilities) public float CalcProbability(List<(string, float)> probabilities)
{ {
return probMethod.method(probabilities); return ProbMethod.Method(probabilities);
} }
public static Dictionary<string, float[]> GetEmbeddings(string content, List<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache) public static Dictionary<string, float[]> GetEmbeddings(string content, ConcurrentBag<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{ {
Dictionary<string, float[]> embeddings = []; Dictionary<string, float[]> embeddings = [];
bool embeddingCacheHasContent = embeddingCache.TryGetValue(content, out var embeddingCacheForContent); bool embeddingCacheHasContent = embeddingCache.TryGetValue(content, out var embeddingCacheForContent);
if (!embeddingCacheHasContent || embeddingCacheForContent is null) if (!embeddingCacheHasContent || embeddingCacheForContent is null)
{ {
models.ForEach(model => foreach (string model in models)
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache) {
); embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache);
}
return embeddings; return embeddings;
} }
models.ForEach(model => foreach (string model in models)
{ {
bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embeddingCacheForModel); bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embeddingCacheForModel);
if (embeddingCacheHasModel && embeddingCacheForModel is not null) if (embeddingCacheHasModel && embeddingCacheForModel is not null)
@@ -45,7 +60,7 @@ public class Datapoint
{ {
embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache); embeddings[model] = GenerateEmbeddings(content, model, aIProvider, embeddingCache);
} }
}); }
return embeddings; return embeddings;
} }

View File

@@ -1,11 +1,14 @@
using System.Collections.Concurrent;
namespace Server; namespace Server;
public class Entity(Dictionary<string, string> attributes, Probmethods.probMethodDelegate probMethod, string probMethodName, List<Datapoint> datapoints, string name) public class Entity(Dictionary<string, string> attributes, Probmethods.ProbMethodDelegate probMethod, string probMethodName, ConcurrentBag<Datapoint> datapoints, string name, string searchdomain)
{ {
public Dictionary<string, string> attributes = attributes; public Dictionary<string, string> Attributes = attributes;
public Probmethods.probMethodDelegate probMethod = probMethod; public Probmethods.ProbMethodDelegate ProbMethod = probMethod;
public string probMethodName = probMethodName; public string ProbMethodName = probMethodName;
public List<Datapoint> datapoints = datapoints; public ConcurrentBag<Datapoint> Datapoints = datapoints;
public int id; public int Id;
public string name = name; public string Name = name;
public string Searchdomain = searchdomain;
} }

View File

@@ -12,33 +12,33 @@ public class DatabaseHealthCheck : IHealthCheck
_searchdomainManager = searchdomainManager; _searchdomainManager = searchdomainManager;
_logger = logger; _logger = logger;
} }
public Task<HealthCheckResult> CheckHealthAsync( public async Task<HealthCheckResult> CheckHealthAsync(
HealthCheckContext context, CancellationToken cancellationToken = default) HealthCheckContext context, CancellationToken cancellationToken = default)
{ {
try try
{ {
DatabaseMigrations.DatabaseGetVersion(_searchdomainManager.helper); DatabaseMigrations.DatabaseGetVersion(_searchdomainManager.Helper);
} }
catch (Exception ex) catch (Exception ex)
{ {
_logger.LogCritical("DatabaseHealthCheck - Exception occurred when retrieving and parsing database version: {ex}", ex.Message); _logger.LogCritical("DatabaseHealthCheck - Exception occurred when retrieving and parsing database version: {ex}", ex.Message);
return Task.FromResult( return await Task.FromResult(
HealthCheckResult.Unhealthy()); HealthCheckResult.Unhealthy());
} }
try try
{ {
_searchdomainManager.helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES ('test', 'x');", []); await _searchdomainManager.Helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES ('test', 'x');", []);
_searchdomainManager.helper.ExecuteSQLNonQuery("DELETE FROM settings WHERE name = 'test';", []); await _searchdomainManager.Helper.ExecuteSQLNonQuery("DELETE FROM settings WHERE name = 'test';", []);
} }
catch (Exception ex) catch (Exception ex)
{ {
_logger.LogCritical("DatabaseHealthCheck - Exception occurred when executing INSERT/DELETE query: {ex}", ex.Message); _logger.LogCritical("DatabaseHealthCheck - Exception occurred when executing INSERT/DELETE query: {ex}", ex.Message);
return Task.FromResult( return await Task.FromResult(
HealthCheckResult.Unhealthy()); HealthCheckResult.Unhealthy());
} }
return Task.FromResult( return await Task.FromResult(
HealthCheckResult.Healthy()); HealthCheckResult.Healthy());
} }
} }

View File

@@ -72,7 +72,7 @@ public static class CacheHelper
deletedEntries.Add(storeEntryIndex); deletedEntries.Add(storeEntryIndex);
} }
} }
Task removeEntriesFromStoreTask = RemoveEntriesFromStore(helper, deletedEntries); await RemoveEntriesFromStore(helper, deletedEntries);
List<(int Index, KeyValuePair<string, Dictionary<string, float[]>> Entry)> createdEntries = []; List<(int Index, KeyValuePair<string, Dictionary<string, float[]>> Entry)> createdEntries = [];
@@ -127,7 +127,6 @@ public static class CacheHelper
var taskSet = new List<Task> var taskSet = new List<Task>
{ {
removeEntriesFromStoreTask,
CreateEntriesInStore(helper, createdEntries), CreateEntriesInStore(helper, createdEntries),
UpdateEntryIndicesInStore(helper, changedEntries), UpdateEntryIndicesInStore(helper, changedEntries),
AddModelsToIndices(helper, AddedModels), AddModelsToIndices(helper, AddedModels),

View File

@@ -20,11 +20,13 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
return new SQLHelper(connection, connectionString); return new SQLHelper(connection, connectionString);
} }
public static void DatabaseInsertEmbeddingBulk(SQLHelper helper, int id_datapoint, List<(string model, byte[] embedding)> data) public static async Task DatabaseInsertEmbeddingBulk(SQLHelper helper, int id_datapoint, List<(string model, byte[] embedding)> data, int id_entity, int id_searchdomain)
{ {
Dictionary<string, object> parameters = []; Dictionary<string, object> parameters = [];
parameters["id_datapoint"] = id_datapoint; parameters["id_datapoint"] = id_datapoint;
var query = new StringBuilder("INSERT INTO embedding (id_datapoint, model, embedding) VALUES "); parameters["id_entity"] = id_entity;
parameters["id_searchdomain"] = id_searchdomain;
var query = new StringBuilder("INSERT INTO embedding (id_datapoint, model, embedding, id_embedding, id_searchdomain) VALUES ");
foreach (var (model, embedding) in data) foreach (var (model, embedding) in data)
{ {
string modelParam = $"model_{Guid.NewGuid()}".Replace("-", ""); string modelParam = $"model_{Guid.NewGuid()}".Replace("-", "");
@@ -32,37 +34,39 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
parameters[modelParam] = model; parameters[modelParam] = model;
parameters[embeddingParam] = embedding; parameters[embeddingParam] = embedding;
query.Append($"(@id_datapoint, @{modelParam}, @{embeddingParam}), "); query.Append($"(@id_datapoint, @{modelParam}, @{embeddingParam}, @id_entity), ");
} }
query.Length -= 2; // remove trailing comma query.Length -= 2; // remove trailing comma
helper.ExecuteSQLNonQuery(query.ToString(), parameters); await helper.ExecuteSQLNonQuery(query.ToString(), parameters);
} }
public static int DatabaseInsertEmbeddingBulk(SQLHelper helper, List<(string hash, string model, byte[] embedding)> data) public static async Task<int> DatabaseInsertEmbeddingBulk(SQLHelper helper, List<(int id_datapoint, string model, byte[] embedding)> data, int id_entity, int id_searchdomain)
{ {
return helper.BulkExecuteNonQuery( return await helper.BulkExecuteNonQuery(
"INSERT INTO embedding (id_datapoint, model, embedding) SELECT d.id, @model, @embedding FROM datapoint d WHERE d.hash = @hash", "INSERT INTO embedding (id_datapoint, model, embedding, id_entity, id_searchdomain) VALUES (@id_datapoint, @model, @embedding, @id_entity, @id_searchdomain);",
data.Select(element => new object[] { data.Select(element => new object[] {
new MySqlParameter("@model", element.model), new MySqlParameter("@model", element.model),
new MySqlParameter("@embedding", element.embedding), new MySqlParameter("@embedding", element.embedding),
new MySqlParameter("@hash", element.hash) new MySqlParameter("@id_datapoint", element.id_datapoint),
new MySqlParameter("@id_entity", id_entity),
new MySqlParameter("@id_searchdomain", id_searchdomain)
}) })
); );
} }
public static int DatabaseInsertSearchdomain(SQLHelper helper, string name, SearchdomainSettings settings = new()) public static async Task<int> DatabaseInsertSearchdomain(SQLHelper helper, string name, SearchdomainSettings settings = new())
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{ "name", name }, { "name", name },
{ "settings", settings} { "settings", settings}
}; };
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters); return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters);
} }
public static int DatabaseInsertEntity(SQLHelper helper, string name, ProbMethodEnum probmethod, int id_searchdomain) public static async Task<int> DatabaseInsertEntity(SQLHelper helper, string name, ProbMethodEnum probmethod, int id_searchdomain)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
@@ -70,24 +74,13 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
{ "probmethod", probmethod.ToString() }, { "probmethod", probmethod.ToString() },
{ "id_searchdomain", id_searchdomain } { "id_searchdomain", id_searchdomain }
}; };
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO entity (name, probmethod, id_searchdomain) VALUES (@name, @probmethod, @id_searchdomain)", parameters); return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO entity (name, probmethod, id_searchdomain) VALUES (@name, @probmethod, @id_searchdomain);", parameters);
} }
public static int DatabaseInsertAttribute(SQLHelper helper, string attribute, string value, int id_entity) public static async Task<int> DatabaseInsertAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values) //string[] attribute, string value, int id_entity)
{ {
Dictionary<string, dynamic> parameters = new() return await helper.BulkExecuteNonQuery(
{ "INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity);",
{ "attribute", attribute },
{ "value", value },
{ "id_entity", id_entity }
};
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", parameters);
}
public static int DatabaseInsertAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values) //string[] attribute, string value, int id_entity)
{
return helper.BulkExecuteNonQuery(
"INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)",
values.Select(element => new object[] { values.Select(element => new object[] {
new MySqlParameter("@attribute", element.attribute), new MySqlParameter("@attribute", element.attribute),
new MySqlParameter("@value", element.value), new MySqlParameter("@value", element.value),
@@ -96,10 +89,33 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
); );
} }
public static int DatabaseInsertDatapoints(SQLHelper helper, List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> values, int id_entity) public static async Task<int> DatabaseUpdateAttributes(SQLHelper helper, List<(string attribute, string value, int id_entity)> values)
{ {
return helper.BulkExecuteNonQuery( return await helper.BulkExecuteNonQuery(
"INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", "UPDATE attribute SET value=@value WHERE id_entity=@id_entity AND attribute=@attribute",
values.Select(element => new object[] {
new MySqlParameter("@attribute", element.attribute),
new MySqlParameter("@value", element.value),
new MySqlParameter("@id_entity", element.id_entity)
})
);
}
/// <summary>
/// Deletes the given attributes from the <c>attribute</c> table, one DELETE per (attribute, id_entity) pair.
/// </summary>
/// <param name="helper">SQL helper used to run the bulk statement.</param>
/// <param name="values">Attribute name / entity id pairs identifying the rows to delete.</param>
/// <returns>The value returned by <c>BulkExecuteNonQuery</c> for the executed statements.</returns>
public static async Task<int> DatabaseDeleteAttributes(SQLHelper helper, List<(string attribute, int id_entity)> values)
{
    const string sql = "DELETE FROM attribute WHERE id_entity=@id_entity AND attribute=@attribute";

    // Kept as a lazy query so that every enumeration produces fresh parameter objects.
    IEnumerable<object[]> parameterSets =
        from row in values
        select new object[]
        {
            new MySqlParameter("@attribute", row.attribute),
            new MySqlParameter("@id_entity", row.id_entity)
        };

    int affectedRows = await helper.BulkExecuteNonQuery(sql, parameterSets);
    return affectedRows;
}
public static async Task<int> DatabaseInsertDatapoints(SQLHelper helper, List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> values, int id_entity)
{
return await helper.BulkExecuteNonQuery(
"INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity);",
values.Select(element => new object[] { values.Select(element => new object[] {
new MySqlParameter("@name", element.name), new MySqlParameter("@name", element.name),
new MySqlParameter("@probmethod_embedding", element.probmethod_embedding), new MySqlParameter("@probmethod_embedding", element.probmethod_embedding),
@@ -110,7 +126,7 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
); );
} }
public static int DatabaseInsertDatapoint(SQLHelper helper, string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, int id_entity) public static async Task<int> DatabaseInsertDatapoint(SQLHelper helper, string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, int id_entity)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
@@ -120,111 +136,155 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
{ "hash", hash }, { "hash", hash },
{ "id_entity", id_entity } { "id_entity", id_entity }
}; };
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters); return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters);
} }
/// <summary>
/// Deletes the named datapoints of one entity together with the embeddings that belong to them.
/// </summary>
/// <param name="helper">SQL helper used to run the bulk statements.</param>
/// <param name="values">Names of the datapoints to delete.</param>
/// <param name="id_entity">Id of the entity that owns the datapoints.</param>
/// <returns>The values returned by <c>BulkExecuteNonQuery</c> for the datapoint and embedding deletes.</returns>
public static async Task<(int datapoints, int embeddings)> DatabaseDeleteEmbeddingsAndDatapoints(SQLHelper helper, List<string> values, int id_entity)
{
    // Lazy on purpose: each call (and each enumeration) yields fresh parameter objects,
    // so the two statements never share MySqlParameter instances.
    IEnumerable<object[]> BuildParameterSets() => values.Select(datapointName => new object[] {
        new MySqlParameter("@datapointName", datapointName),
        new MySqlParameter("@entityId", id_entity)
    });

    // Embeddings must go first: they are located through the datapoint rows removed below.
    // The join limits the delete to embeddings of the named datapoint; filtering on id_entity
    // alone would wipe the embeddings of every other datapoint of the entity as well.
    int embeddings = await helper.BulkExecuteNonQuery(
        "DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint = d.id WHERE e.id_entity = @entityId AND d.id_entity = @entityId AND d.name = @datapointName",
        BuildParameterSets()
    );
    int datapoints = await helper.BulkExecuteNonQuery(
        "DELETE FROM datapoint WHERE name=@datapointName AND id_entity=@entityId",
        BuildParameterSets()
    );
    return (datapoints: datapoints, embeddings: embeddings);
}
/// <summary>
/// Updates <c>probmethod_embedding</c> and <c>similaritymethod</c> of the named datapoints of one entity.
/// </summary>
/// <param name="helper">SQL helper used to run the bulk statement.</param>
/// <param name="values">Datapoint name plus the new probmethod / similarity method values.</param>
/// <param name="id_entity">Id of the entity that owns the datapoints.</param>
/// <returns>The value returned by <c>BulkExecuteNonQuery</c> for the executed statements.</returns>
public static async Task<int> DatabaseUpdateDatapoint(SQLHelper helper, List<(string name, string probmethod_embedding, string similarityMethod)> values, int id_entity)
{
    const string sql = "UPDATE datapoint SET probmethod_embedding=@probmethod, similaritymethod=@similaritymethod WHERE id_entity=@entityId AND name=@datapointName";

    // Kept as a lazy query so that every enumeration produces fresh parameter objects.
    IEnumerable<object[]> parameterSets =
        from datapoint in values
        select new object[]
        {
            new MySqlParameter("@probmethod", datapoint.probmethod_embedding),
            new MySqlParameter("@similaritymethod", datapoint.similarityMethod),
            new MySqlParameter("@entityId", id_entity),
            new MySqlParameter("@datapointName", datapoint.name)
        };

    int affectedRows = await helper.BulkExecuteNonQuery(sql, parameterSets);
    return affectedRows;
}
public static async Task<int> DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding, int id_entity, int id_searchdomain)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{ "id_datapoint", id_datapoint }, { "id_datapoint", id_datapoint },
{ "model", model }, { "model", model },
{ "embedding", embedding } { "embedding", embedding },
{ "id_entity", id_entity },
{ "id_searchdomain", id_searchdomain }
}; };
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO embedding (id_datapoint, model, embedding) VALUES (@id_datapoint, @model, @embedding)", parameters); return await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO embedding (id_datapoint, model, embedding, id_entity, id_searchdomain) VALUES (@id_datapoint, @model, @embedding, @id_entity, @id_searchdomain)", parameters);
} }
public int GetSearchdomainID(SQLHelper helper, string searchdomain) public async Task<int> GetSearchdomainID(SQLHelper helper, string searchdomain)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, object?> parameters = new()
{ {
{ "searchdomain", searchdomain} { "searchdomain", searchdomain}
}; };
lock (helper.connection) return (await helper.ExecuteQueryAsync("SELECT id FROM searchdomain WHERE name = @searchdomain", parameters, x => x.GetInt32(0))).First();
{
DbDataReader reader = helper.ExecuteSQLCommand("SELECT id FROM searchdomain WHERE name = @searchdomain", parameters);
bool success = reader.Read();
int result = success ? reader.GetInt32(0) : 0;
reader.Close();
if (success)
{
return result;
}
else
{
_logger.LogError("Unable to retrieve searchdomain ID for {searchdomain}", [searchdomain]);
throw new SearchdomainNotFoundException(searchdomain);
}
}
} }
public void RemoveEntity(List<Entity> entityCache, SQLHelper helper, string name, string searchdomain) public async Task RemoveEntity(List<Entity> entityCache, SQLHelper helper, string name, string searchdomain)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{ "name", name }, { "name", name },
{ "searchdomain", GetSearchdomainID(helper, searchdomain)} { "searchdomain", await GetSearchdomainID(helper, searchdomain)}
}; };
helper.ExecuteSQLNonQuery("DELETE embedding.* FROM embedding JOIN datapoint dp ON id_datapoint = dp.id JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); await helper.ExecuteSQLNonQuery("DELETE embedding.* FROM embedding JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters);
helper.ExecuteSQLNonQuery("DELETE datapoint.* FROM datapoint JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); await helper.ExecuteSQLNonQuery("DELETE datapoint.* FROM datapoint JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters);
helper.ExecuteSQLNonQuery("DELETE attribute.* FROM attribute JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters); await helper.ExecuteSQLNonQuery("DELETE attribute.* FROM attribute JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters);
helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE name = @name AND entity.id_searchdomain = @searchdomain", parameters); await helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE name = @name AND entity.id_searchdomain = @searchdomain", parameters);
entityCache.RemoveAll(entity => entity.name == name); entityCache.RemoveAll(entity => entity.Name == name);
} }
public int RemoveAllEntities(SQLHelper helper, string searchdomain) public async Task<int> RemoveAllEntities(SQLHelper helper, string searchdomain)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{ "searchdomain", GetSearchdomainID(helper, searchdomain)} { "searchdomain", await GetSearchdomainID(helper, searchdomain)}
}; };
int count;
helper.ExecuteSQLNonQuery("DELETE embedding.* FROM embedding JOIN datapoint dp ON id_datapoint = dp.id JOIN entity ON id_entity = entity.id WHERE entity.id_searchdomain = @searchdomain", parameters); do
helper.ExecuteSQLNonQuery("DELETE datapoint.* FROM datapoint JOIN entity ON id_entity = entity.id WHERE entity.id_searchdomain = @searchdomain", parameters); {
helper.ExecuteSQLNonQuery("DELETE FROM attribute WHERE id_entity IN (SELECT entity.id FROM entity WHERE id_searchdomain = @searchdomain)", parameters); count = await helper.ExecuteSQLNonQuery("DELETE FROM embedding WHERE id_searchdomain = @searchdomain LIMIT 10000", parameters);
return helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE entity.id_searchdomain = @searchdomain", parameters); } while (count == 10000);
do
{
count = await helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity IN (SELECT id FROM entity WHERE id_searchdomain = @searchdomain) LIMIT 10000", parameters);
} while (count == 10000);
do
{
count = await helper.ExecuteSQLNonQuery("DELETE FROM attribute WHERE id_entity IN (SELECT id FROM entity WHERE id_searchdomain = @searchdomain) LIMIT 10000", parameters);
} while (count == 10000);
int total = 0;
do
{
count = await helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE id_searchdomain = @searchdomain LIMIT 10000", parameters);
total += count;
} while (count == 10000);
return total;
} }
public bool HasEntity(SQLHelper helper, string name, string searchdomain) public async Task<bool> HasEntity(SQLHelper helper, string name, string searchdomain)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{ "name", name }, { "name", name },
{ "searchdomain", GetSearchdomainID(helper, searchdomain)} { "searchdomain", await GetSearchdomainID(helper, searchdomain)}
}; };
lock (helper.connection) lock (helper.Connection)
{ {
DbDataReader reader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters); DbDataReader reader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters);
bool success = reader.Read(); try
bool result = success && reader.GetInt32(0) > 0;
reader.Close();
if (success)
{ {
return result; bool success = reader.Read();
} bool result = success && reader.GetInt32(0) > 0;
else if (success)
{
return result;
}
else
{
_logger.LogError("Unable to determine whether an entity named {name} exists for {searchdomain}", [name, searchdomain]);
throw new Exception($"Unable to determine whether an entity named {name} exists for {searchdomain}");
}
} finally
{ {
_logger.LogError("Unable to determine whether an entity named {name} exists for {searchdomain}", [name, searchdomain]); reader.Close();
throw new Exception($"Unable to determine whether an entity named {name} exists for {searchdomain}");
} }
} }
} }
public int? GetEntityID(SQLHelper helper, string name, string searchdomain) public async Task<int?> GetEntityID(SQLHelper helper, string name, string searchdomain)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()
{ {
{ "name", name }, { "name", name },
{ "searchdomain", GetSearchdomainID(helper, searchdomain)} { "searchdomain", await GetSearchdomainID(helper, searchdomain)}
}; };
lock (helper.connection) lock (helper.Connection)
{ {
DbDataReader reader = helper.ExecuteSQLCommand("SELECT id FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters); DbDataReader reader = helper.ExecuteSQLCommand("SELECT id FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters);
bool success = reader.Read(); try
int? result = success ? reader.GetInt32(0) : 0; {
reader.Close(); bool success = reader.Read();
return result; int? result = success ? reader.GetInt32(0) : 0;
return result;
} finally
{
reader.Close();
}
} }
} }
@@ -235,29 +295,56 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
{ "searchdomain", searchdomain} { "searchdomain", searchdomain}
}; };
DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(id) + LENGTH(name) + LENGTH(settings)) AS total_bytes FROM embeddingsearch.searchdomain WHERE name=@searchdomain", parameters); DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(id) + LENGTH(name) + LENGTH(settings)) AS total_bytes FROM embeddingsearch.searchdomain WHERE name=@searchdomain", parameters);
bool success = searchdomainSumReader.Read(); bool success;
long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; long result;
searchdomainSumReader.Close(); try
{
success = searchdomainSumReader.Read();
result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0;
} finally
{
searchdomainSumReader.Close();
}
DbDataReader entitySumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(e.id) + LENGTH(e.name) + LENGTH(e.probmethod) + LENGTH(e.id_searchdomain)) AS total_bytes FROM embeddingsearch.entity e JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); DbDataReader entitySumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(e.id) + LENGTH(e.name) + LENGTH(e.probmethod) + LENGTH(e.id_searchdomain)) AS total_bytes FROM embeddingsearch.entity e JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters);
success = entitySumReader.Read(); try
result += success && !entitySumReader.IsDBNull(0) ? entitySumReader.GetInt64(0) : 0; {
entitySumReader.Close(); success = entitySumReader.Read();
result += success && !entitySumReader.IsDBNull(0) ? entitySumReader.GetInt64(0) : 0;
} finally
{
entitySumReader.Close();
}
DbDataReader datapointSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(d.id) + LENGTH(d.name) + LENGTH(d.probmethod_embedding) + LENGTH(d.similaritymethod) + LENGTH(d.id_entity) + LENGTH(d.hash)) AS total_bytes FROM embeddingsearch.datapoint d JOIN embeddingsearch.entity e ON d.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); DbDataReader datapointSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(d.id) + LENGTH(d.name) + LENGTH(d.probmethod_embedding) + LENGTH(d.similaritymethod) + LENGTH(d.id_entity) + LENGTH(d.hash)) AS total_bytes FROM embeddingsearch.datapoint d JOIN embeddingsearch.entity e ON d.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters);
success = datapointSumReader.Read(); try
result += success && !datapointSumReader.IsDBNull(0) ? datapointSumReader.GetInt64(0) : 0; {
datapointSumReader.Close(); success = datapointSumReader.Read();
result += success && !datapointSumReader.IsDBNull(0) ? datapointSumReader.GetInt64(0) : 0;
} finally
{
datapointSumReader.Close();
}
DbDataReader embeddingSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(em.id) + LENGTH(em.id_datapoint) + LENGTH(em.model) + LENGTH(em.embedding)) AS total_bytes FROM embeddingsearch.embedding em JOIN embeddingsearch.datapoint d ON em.id_datapoint = d.id JOIN embeddingsearch.entity e ON d.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); DbDataReader embeddingSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(em.id) + LENGTH(em.id_datapoint) + LENGTH(em.model) + LENGTH(em.embedding)) AS total_bytes FROM embeddingsearch.embedding em JOIN embeddingsearch.datapoint d ON em.id_datapoint = d.id JOIN embeddingsearch.entity e ON d.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters);
success = embeddingSumReader.Read(); try
result += success && !embeddingSumReader.IsDBNull(0) ? embeddingSumReader.GetInt64(0) : 0; {
embeddingSumReader.Close(); success = embeddingSumReader.Read();
result += success && !embeddingSumReader.IsDBNull(0) ? embeddingSumReader.GetInt64(0) : 0;
} finally
{
embeddingSumReader.Close();
}
DbDataReader attributeSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(a.id) + LENGTH(a.id_entity) + LENGTH(a.attribute) + LENGTH(a.value)) AS total_bytes FROM embeddingsearch.attribute a JOIN embeddingsearch.entity e ON a.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters); DbDataReader attributeSumReader = helper.ExecuteSQLCommand("SELECT SUM(LENGTH(a.id) + LENGTH(a.id_entity) + LENGTH(a.attribute) + LENGTH(a.value)) AS total_bytes FROM embeddingsearch.attribute a JOIN embeddingsearch.entity e ON a.id_entity = e.id JOIN embeddingsearch.searchdomain s ON e.id_searchdomain = s.id WHERE s.name=@searchdomain", parameters);
success = attributeSumReader.Read(); try
result += success && !attributeSumReader.IsDBNull(0) ? attributeSumReader.GetInt64(0) : 0; {
attributeSumReader.Close(); success = attributeSumReader.Read();
result += success && !attributeSumReader.IsDBNull(0) ? attributeSumReader.GetInt64(0) : 0;
} finally
{
attributeSumReader.Close();
}
return result; return result;
} }
@@ -280,10 +367,15 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
public static async Task<long> CountEntities(SQLHelper helper) public static async Task<long> CountEntities(SQLHelper helper)
{ {
DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity;", []); DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity;", []);
bool success = searchdomainSumReader.Read(); try
long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; {
searchdomainSumReader.Close(); bool success = searchdomainSumReader.Read();
return result; long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0;
return result;
} finally
{
searchdomainSumReader.Close();
}
} }
public static long CountEntitiesForSearchdomain(SQLHelper helper, string searchdomain) public static long CountEntitiesForSearchdomain(SQLHelper helper, string searchdomain)
@@ -293,10 +385,15 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
{ "searchdomain", searchdomain} { "searchdomain", searchdomain}
}; };
DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity e JOIN searchdomain s on e.id_searchdomain = s.id WHERE e.id_searchdomain = s.id AND s.name = @searchdomain;", parameters); DbDataReader searchdomainSumReader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity e JOIN searchdomain s on e.id_searchdomain = s.id WHERE e.id_searchdomain = s.id AND s.name = @searchdomain;", parameters);
bool success = searchdomainSumReader.Read(); try
long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0; {
searchdomainSumReader.Close(); bool success = searchdomainSumReader.Read();
return result; long result = success && !searchdomainSumReader.IsDBNull(0) ? searchdomainSumReader.GetInt64(0) : 0;
return result;
} finally
{
searchdomainSumReader.Close();
}
} }
public static SearchdomainSettings GetSearchdomainSettings(SQLHelper helper, string searchdomain) public static SearchdomainSettings GetSearchdomainSettings(SQLHelper helper, string searchdomain)

View File

@@ -6,50 +6,86 @@ namespace Server.Helper;
public class SQLHelper:IDisposable public class SQLHelper:IDisposable
{ {
public MySqlConnection connection; public MySqlConnection Connection;
public DbDataReader? dbDataReader; public DbDataReader? DbDataReader;
public string connectionString; public MySqlConnectionPoolElement[] ConnectionPool;
public string ConnectionString;
public SQLHelper(MySqlConnection connection, string connectionString) public SQLHelper(MySqlConnection connection, string connectionString)
{ {
this.connection = connection; Connection = connection;
this.connectionString = connectionString; ConnectionString = connectionString;
ConnectionPool = new MySqlConnectionPoolElement[50];
for (int i = 0; i < ConnectionPool.Length; i++)
{
ConnectionPool[i] = new MySqlConnectionPoolElement(new MySqlConnection(connectionString), new(1, 1));
}
} }
public SQLHelper DuplicateConnection() public SQLHelper DuplicateConnection() // TODO remove this
{ {
MySqlConnection newConnection = new(connectionString); return this;
return new SQLHelper(newConnection, connectionString);
} }
public void Dispose() public void Dispose()
{ {
connection.Close(); Connection.Close();
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters) public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters)
{ {
lock (connection) lock (Connection)
{ {
EnsureConnected(); EnsureConnected();
EnsureDbReaderIsClosed(); EnsureDbReaderIsClosed();
using MySqlCommand command = connection.CreateCommand(); using MySqlCommand command = Connection.CreateCommand();
command.CommandText = query; command.CommandText = query;
foreach (KeyValuePair<string, dynamic> parameter in parameters) foreach (KeyValuePair<string, dynamic> parameter in parameters)
{ {
command.Parameters.AddWithValue($"@{parameter.Key}", parameter.Value); command.Parameters.AddWithValue($"@{parameter.Key}", parameter.Value);
} }
dbDataReader = command.ExecuteReader(); DbDataReader = command.ExecuteReader();
return dbDataReader; return DbDataReader;
} }
} }
public int ExecuteSQLNonQuery(string query, Dictionary<string, dynamic> parameters) public async Task<List<T>> ExecuteQueryAsync<T>(
string sql,
Dictionary<string, object?> parameters,
Func<DbDataReader, T> map)
{ {
lock (connection) var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.Connection;
try
{
await using var command = connection.CreateCommand();
command.CommandText = sql;
foreach (var p in parameters)
command.Parameters.AddWithValue($"@{p.Key}", p.Value);
await using var reader = await command.ExecuteReaderAsync();
var result = new List<T>();
while (await reader.ReadAsync())
{
result.Add(map(reader));
}
return result;
} finally
{
poolElement.Semaphore.Release();
}
}
public async Task<int> ExecuteSQLNonQuery(string query, Dictionary<string, dynamic> parameters)
{
var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.Connection;
try
{ {
EnsureConnected();
EnsureDbReaderIsClosed();
using MySqlCommand command = connection.CreateCommand(); using MySqlCommand command = connection.CreateCommand();
command.CommandText = query; command.CommandText = query;
@@ -58,15 +94,18 @@ public class SQLHelper:IDisposable
command.Parameters.AddWithValue($"@{parameter.Key}", parameter.Value); command.Parameters.AddWithValue($"@{parameter.Key}", parameter.Value);
} }
return command.ExecuteNonQuery(); return command.ExecuteNonQuery();
} finally
{
poolElement.Semaphore.Release();
} }
} }
public int ExecuteSQLCommandGetInsertedID(string query, Dictionary<string, dynamic> parameters) public async Task<int> ExecuteSQLCommandGetInsertedID(string query, Dictionary<string, dynamic> parameters)
{ {
lock (connection) var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.Connection;
try
{ {
EnsureConnected();
EnsureDbReaderIsClosed();
using MySqlCommand command = connection.CreateCommand(); using MySqlCommand command = connection.CreateCommand();
command.CommandText = query; command.CommandText = query;
@@ -77,44 +116,90 @@ public class SQLHelper:IDisposable
command.ExecuteNonQuery(); command.ExecuteNonQuery();
command.CommandText = "SELECT LAST_INSERT_ID();"; command.CommandText = "SELECT LAST_INSERT_ID();";
return Convert.ToInt32(command.ExecuteScalar()); return Convert.ToInt32(command.ExecuteScalar());
} finally
{
poolElement.Semaphore.Release();
} }
} }
public int BulkExecuteNonQuery(string sql, IEnumerable<object[]> parameterSets) public async Task<int> BulkExecuteNonQuery(string sql, IEnumerable<object[]> parameterSets)
{ {
lock (connection) var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.Connection;
try
{ {
EnsureConnected();
EnsureDbReaderIsClosed();
using var transaction = connection.BeginTransaction();
using var command = connection.CreateCommand();
command.CommandText = sql;
command.Transaction = transaction;
int affectedRows = 0; int affectedRows = 0;
int retries = 0;
foreach (var parameters in parameterSets) while (retries <= 3)
{ {
command.Parameters.Clear(); try
command.Parameters.AddRange(parameters); {
affectedRows += command.ExecuteNonQuery(); using var transaction = connection.BeginTransaction();
using var command = connection.CreateCommand();
command.CommandText = sql;
command.Transaction = transaction;
foreach (var parameters in parameterSets)
{
command.Parameters.Clear();
command.Parameters.AddRange(parameters);
affectedRows += command.ExecuteNonQuery();
}
transaction.Commit();
break;
}
catch (Exception)
{
retries++;
if (retries > 3)
throw;
Thread.Sleep(10);
}
} }
transaction.Commit();
return affectedRows; return affectedRows;
} finally
{
poolElement.Semaphore.Release();
} }
} }
/// <summary>
/// Reserves a free element of <see cref="ConnectionPool"/> and makes sure its connection is open.
/// </summary>
/// <remarks>
/// The element's semaphore is held when this method returns; the caller must call
/// <c>Semaphore.Release()</c> once it is done with the connection.
/// </remarks>
/// <returns>The reserved pool element.</returns>
/// <exception cref="TimeoutException">No element became free within the retry budget.</exception>
public async Task<MySqlConnectionPoolElement> GetMySqlConnectionPoolElement()
{
    const int maxRetries = 50;
    const int sleepTime = 10; // milliseconds between two scans of the pool

    int counter = 0;
    do
    {
        foreach (var element in ConnectionPool)
        {
            // Wait(0) never blocks: it only succeeds when the element is currently unused.
            if (!element.Semaphore.Wait(0))
            {
                continue;
            }
            try
            {
                // Same condition as EnsureConnected: anything but Open (e.g. Closed or Broken) is reopened.
                if (element.Connection.State != ConnectionState.Open)
                {
                    await element.Connection.CloseAsync();
                    await element.Connection.OpenAsync();
                }
            }
            catch
            {
                // Hand the slot back; otherwise a failed (re)connect would remove it from the pool for good.
                element.Semaphore.Release();
                throw;
            }
            return element;
        }
        // Yield the thread while waiting instead of blocking it.
        await Task.Delay(sleepTime);
    } while (++counter <= maxRetries);

    TimeoutException ex = new("Unable to get MySqlConnection");
    ElmahCore.ElmahExtensions.RaiseError(ex);
    throw ex;
}
public bool EnsureConnected() public bool EnsureConnected()
{ {
if (connection.State != System.Data.ConnectionState.Open) if (Connection.State != System.Data.ConnectionState.Open)
{ {
try try
{ {
connection.Close(); Connection.Close();
connection.Open(); Connection.Open();
} }
catch (Exception ex) catch (Exception ex)
{ {
@@ -130,7 +215,7 @@ public class SQLHelper:IDisposable
int counter = 0; int counter = 0;
int sleepTime = 10; int sleepTime = 10;
int timeout = 5000; int timeout = 5000;
while (!(dbDataReader?.IsClosed ?? true)) while (!(DbDataReader?.IsClosed ?? true))
{ {
if (counter > timeout / sleepTime) if (counter > timeout / sleepTime)
{ {
@@ -142,3 +227,15 @@ public class SQLHelper:IDisposable
} }
} }
} }
/// <summary>
/// One slot of the SQL connection pool: a connection paired with the semaphore
/// that is acquired before the connection is used and released afterwards.
/// </summary>
public struct MySqlConnectionPoolElement
{
    /// <summary>Semaphore guarding access to <see cref="Connection"/>.</summary>
    public SemaphoreSlim Semaphore;

    /// <summary>The pooled MySQL connection.</summary>
    public MySqlConnection Connection;

    /// <summary>Creates a pool element from a connection and its guarding semaphore.</summary>
    public MySqlConnectionPoolElement(MySqlConnection connection, SemaphoreSlim semaphore)
        => (Connection, Semaphore) = (connection, semaphore);
}

View File

@@ -1,4 +1,5 @@
using System.Collections.Concurrent; using System.Collections.Concurrent;
using System.Diagnostics;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
@@ -29,16 +30,16 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
return floatArray; return floatArray;
} }
public static bool CacheHasEntity(List<Entity> entityCache, string name) public static bool CacheHasEntity(ConcurrentDictionary<string, Entity> entityCache, string name)
{ {
return CacheGetEntity(entityCache, name) is not null; return CacheGetEntity(entityCache, name) is not null;
} }
public static Entity? CacheGetEntity(List<Entity> entityCache, string name) public static Entity? CacheGetEntity(ConcurrentDictionary<string, Entity> entityCache, string name)
{ {
foreach (Entity entity in entityCache) foreach ((string _, Entity entity) in entityCache)
{ {
if (entity.name == name) if (entity.Name == name)
{ {
return entity; return entity;
} }
@@ -46,11 +47,11 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
return null; return null;
} }
public List<Entity>? EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json) public async Task<List<Entity>?> EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
{ {
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.EmbeddingCache;
AIProvider aIProvider = searchdomainManager.aIProvider; AIProvider aIProvider = searchdomainManager.AiProvider;
SQLHelper helper = searchdomainManager.helper; SQLHelper helper = searchdomainManager.Helper;
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json); List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
if (jsonEntities is null) if (jsonEntities is null)
@@ -64,7 +65,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
foreach (JSONEntity jSONEntity in jsonEntities) foreach (JSONEntity jSONEntity in jsonEntities)
{ {
Dictionary<string, List<string>> targetDictionary = toBeCached; Dictionary<string, List<string>> targetDictionary = toBeCached;
if (searchdomainManager.GetSearchdomain(jSONEntity.Searchdomain).settings.ParallelEmbeddingsPrefetch) if (searchdomainManager.GetSearchdomain(jSONEntity.Searchdomain).Settings.ParallelEmbeddingsPrefetch)
{ {
targetDictionary = toBeCachedParallel; targetDictionary = toBeCachedParallel;
} }
@@ -96,155 +97,264 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
// Index/parse the entities // Index/parse the entities
ConcurrentQueue<Entity> retVal = []; ConcurrentQueue<Entity> retVal = [];
ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool
Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity =>
List<Task> entityTasks = [];
foreach (JSONEntity jSONEntity in jsonEntities)
{ {
var entity = EntityFromJSON(searchdomainManager, logger, jSONEntity); entityTasks.Add(Task.Run(async () =>
if (entity is not null)
{ {
retVal.Enqueue(entity); var entity = await EntityFromJSON(searchdomainManager, logger, jSONEntity);
if (entity is not null)
{
retVal.Enqueue(entity);
}
}));
if (entityTasks.Count >= parallelOptions.MaxDegreeOfParallelism)
{
await Task.WhenAny(entityTasks);
entityTasks.RemoveAll(t => t.IsCompleted);
} }
}); }
await Task.WhenAll(entityTasks);
return [.. retVal]; return [.. retVal];
} }
public Entity? EntityFromJSON(SearchdomainManager searchdomainManager, ILogger logger, JSONEntity jsonEntity) //string json) public async Task<Entity?> EntityFromJSON(SearchdomainManager searchdomainManager, ILogger logger, JSONEntity jsonEntity)
{ {
using SQLHelper helper = searchdomainManager.helper.DuplicateConnection(); var stopwatch = Stopwatch.StartNew();
SQLHelper helper = searchdomainManager.Helper;
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain); Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
List<Entity> entityCache = searchdomain.entityCache; int id_searchdomain = searchdomain.Id;
AIProvider aIProvider = searchdomain.aIProvider; ConcurrentDictionary<string, Entity> entityCache = searchdomain.EntityCache;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; AIProvider aIProvider = searchdomain.AiProvider;
Entity? preexistingEntity = entityCache.FirstOrDefault(entity => entity.name == jsonEntity.Name); EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.EmbeddingCache;
bool invalidateSearchCache = false; bool invalidateSearchCache = false;
if (preexistingEntity is not null)
bool hasEntity = entityCache.TryGetValue(jsonEntity.Name, out Entity? preexistingEntity);
if (hasEntity && preexistingEntity is not null)
{ {
int? preexistingEntityID = _databaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
if (preexistingEntityID is null) int preexistingEntityID = preexistingEntity.Id;
{
_logger.LogCritical("Unable to index entity {jsonEntity.Name} because it already exists in the searchdomain but not in the database.", [jsonEntity.Name]);
throw new Exception($"Unable to index entity {jsonEntity.Name} because it already exists in the searchdomain but not in the database.");
}
Dictionary<string, string> attributes = jsonEntity.Attributes; Dictionary<string, string> attributes = jsonEntity.Attributes;
// Attribute // Attribute - get changes
foreach (KeyValuePair<string, string> attributesKV in preexistingEntity.attributes.ToList()) List<(string attribute, string newValue, int entityId)> updatedAttributes = new(preexistingEntity.Attributes.Count);
List<(string attribute, int entityId)> deletedAttributes = new(preexistingEntity.Attributes.Count);
List<(string attributeKey, string attribute, int entityId)> addedAttributes = new(jsonEntity.Attributes.Count);
foreach (KeyValuePair<string, string> attributesKV in preexistingEntity.Attributes) //.ToList())
{ {
string oldAttributeKey = attributesKV.Key; string oldAttributeKey = attributesKV.Key;
string oldAttribute = attributesKV.Value; string oldAttribute = attributesKV.Value;
bool newHasAttribute = jsonEntity.Attributes.TryGetValue(oldAttributeKey, out string? newAttribute); bool newHasAttribute = jsonEntity.Attributes.TryGetValue(oldAttributeKey, out string? newAttribute);
if (newHasAttribute && newAttribute is not null && newAttribute != oldAttribute) if (newHasAttribute && newAttribute is not null && newAttribute != oldAttribute)
{ {
// Attribute - Updated updatedAttributes.Add((attribute: oldAttributeKey, newValue: newAttribute, entityId: (int)preexistingEntityID));
Dictionary<string, dynamic> parameters = new()
{
{ "newValue", newAttribute },
{ "entityId", preexistingEntityID },
{ "attribute", oldAttributeKey}
};
helper.ExecuteSQLNonQuery("UPDATE attribute SET value=@newValue WHERE id_entity=@entityId AND attribute=@attribute", parameters);
preexistingEntity.attributes[oldAttributeKey] = newAttribute;
} else if (!newHasAttribute) } else if (!newHasAttribute)
{ {
// Attribute - Deleted deletedAttributes.Add((attribute: oldAttributeKey, entityId: (int)preexistingEntityID));
Dictionary<string, dynamic> parameters = new()
{
{ "entityId", preexistingEntityID },
{ "attribute", oldAttributeKey}
};
helper.ExecuteSQLNonQuery("DELETE FROM attribute WHERE id_entity=@entityId AND attribute=@attribute", parameters);
preexistingEntity.attributes.Remove(oldAttributeKey);
} }
} }
foreach (var attributesKV in jsonEntity.Attributes) foreach (var attributesKV in jsonEntity.Attributes)
{ {
string newAttributeKey = attributesKV.Key; string newAttributeKey = attributesKV.Key;
string newAttribute = attributesKV.Value; string newAttribute = attributesKV.Value;
bool preexistingHasAttribute = preexistingEntity.attributes.TryGetValue(newAttributeKey, out string? preexistingAttribute); bool preexistingHasAttribute = preexistingEntity.Attributes.TryGetValue(newAttributeKey, out string? preexistingAttribute);
if (!preexistingHasAttribute) if (!preexistingHasAttribute)
{ {
// Attribute - New // Attribute - New
DatabaseHelper.DatabaseInsertAttribute(helper, newAttributeKey, newAttribute, (int)preexistingEntityID); addedAttributes.Add((attributeKey: newAttributeKey, attribute: newAttribute, entityId: (int)preexistingEntityID));
preexistingEntity.attributes.Add(newAttributeKey, newAttribute);
} }
} }
// Datapoint if (updatedAttributes.Count != 0 || deletedAttributes.Count != 0 || addedAttributes.Count != 0)
foreach (Datapoint datapoint_ in preexistingEntity.datapoints.ToList()) _logger.LogDebug("EntityFromJSON - Updating existing entity. name: {name}, updatedAttributes: {updatedAttributes}, deletedAttributes: {deletedAttributes}, addedAttributes: {addedAttributes}", [preexistingEntity.Name, updatedAttributes.Count, deletedAttributes.Count, addedAttributes.Count]);
// Attribute - apply changes
if (updatedAttributes.Count != 0)
{
// Update
await DatabaseHelper.DatabaseUpdateAttributes(helper, updatedAttributes);
lock (preexistingEntity.Attributes)
{
updatedAttributes.ForEach(attribute => preexistingEntity.Attributes[attribute.attribute] = attribute.newValue);
}
}
if (deletedAttributes.Count != 0)
{
// Delete
await DatabaseHelper.DatabaseDeleteAttributes(helper, deletedAttributes);
lock (preexistingEntity.Attributes)
{
deletedAttributes.ForEach(attribute => preexistingEntity.Attributes.Remove(attribute.attribute));
}
}
if (addedAttributes.Count != 0)
{
// Insert
await DatabaseHelper.DatabaseInsertAttributes(helper, addedAttributes);
lock (preexistingEntity.Attributes)
{
addedAttributes.ForEach(attribute => preexistingEntity.Attributes.Add(attribute.attributeKey, attribute.attribute));
}
}
// Datapoint - get changes
List<Datapoint> deletedDatapointInstances = new(preexistingEntity.Datapoints.Count);
List<string> deletedDatapoints = new(preexistingEntity.Datapoints.Count);
List<(string datapointName, int entityId, JSONDatapoint jsonDatapoint, string hash)> updatedDatapointsText = new(preexistingEntity.Datapoints.Count);
List<(string datapointName, string probMethod, string similarityMethod, int entityId, JSONDatapoint jsonDatapoint)> updatedDatapointsNonText = new(preexistingEntity.Datapoints.Count);
List<Datapoint> createdDatapointInstances = [];
List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, Dictionary<string, float[]> embeddings, JSONDatapoint datapoint)> createdDatapoints = new(jsonEntity.Datapoints.Length);
foreach (Datapoint datapoint_ in preexistingEntity.Datapoints.ToList())
{ {
Datapoint datapoint = datapoint_; // To enable replacing the datapoint reference as foreach iterators cannot be overwritten Datapoint datapoint = datapoint_; // To enable replacing the datapoint reference as foreach iterators cannot be overwritten
bool newEntityHasDatapoint = jsonEntity.Datapoints.Any(x => x.Name == datapoint.name); JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.Name);
bool newEntityHasDatapoint = newEntityDatapoint is not null;
if (!newEntityHasDatapoint) if (!newEntityHasDatapoint)
{ {
// Datapoint - Deleted // Datapoint - Deleted
Dictionary<string, dynamic> parameters = new() deletedDatapointInstances.Add(datapoint);
{ deletedDatapoints.Add(datapoint.Name);
{ "datapointName", datapoint.name },
{ "entityId", preexistingEntityID}
};
helper.ExecuteSQLNonQuery("DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", parameters);
helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity=@entityId AND name=@datapointName", parameters);
preexistingEntity.datapoints.Remove(datapoint);
invalidateSearchCache = true; invalidateSearchCache = true;
} else } else
{ {
JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name); string? hash = newEntityDatapoint?.Text is not null ? GetHash(newEntityDatapoint) : null;
if (newEntityDatapoint is not null && newEntityDatapoint.Text is not null) if (
newEntityDatapoint is not null
&& newEntityDatapoint.Text is not null
&& hash is not null
&& hash != datapoint.Hash)
{ {
// Datapoint - Updated (text) // Datapoint - Updated (text)
Dictionary<string, dynamic> parameters = new() updatedDatapointsText.Add(new()
{ {
{ "datapointName", datapoint.name }, datapointName = newEntityDatapoint.Name,
{ "entityId", preexistingEntityID} entityId = (int)preexistingEntityID,
}; jsonDatapoint = newEntityDatapoint,
helper.ExecuteSQLNonQuery("DELETE e FROM embedding e JOIN datapoint d ON e.id_datapoint=d.id WHERE d.name=@datapointName AND d.id_entity=@entityId", parameters); hash = hash
helper.ExecuteSQLNonQuery("DELETE FROM datapoint WHERE id_entity=@entityId AND name=@datapointName", parameters); });
preexistingEntity.datapoints.Remove(datapoint);
Datapoint newDatapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, newEntityDatapoint, (int)preexistingEntityID);
preexistingEntity.datapoints.Add(newDatapoint);
datapoint = newDatapoint;
invalidateSearchCache = true; invalidateSearchCache = true;
} }
if (newEntityDatapoint is not null && (newEntityDatapoint.Probmethod_embedding != datapoint.probMethod.probMethodEnum || newEntityDatapoint.SimilarityMethod != datapoint.similarityMethod.similarityMethodEnum)) if (
newEntityDatapoint is not null
&& (newEntityDatapoint.Probmethod_embedding != datapoint.ProbMethod.ProbMethodEnum
|| newEntityDatapoint.SimilarityMethod != datapoint.SimilarityMethod.SimilarityMethodEnum))
{ {
// Datapoint - Updated (probmethod or similaritymethod) // Datapoint - Updated (probmethod or similaritymethod)
Dictionary<string, dynamic> parameters = new() updatedDatapointsNonText.Add(new()
{ {
{ "probmethod", newEntityDatapoint.Probmethod_embedding.ToString() }, datapointName = newEntityDatapoint.Name,
{ "similaritymethod", newEntityDatapoint.SimilarityMethod.ToString() }, entityId = (int)preexistingEntityID,
{ "datapointName", datapoint.name }, probMethod = newEntityDatapoint.Probmethod_embedding.ToString(),
{ "entityId", preexistingEntityID} similarityMethod = newEntityDatapoint.SimilarityMethod.ToString(),
}; jsonDatapoint = newEntityDatapoint
helper.ExecuteSQLNonQuery("UPDATE datapoint SET probmethod_embedding=@probmethod, similaritymethod=@similaritymethod WHERE id_entity=@entityId AND name=@datapointName", parameters); });
Datapoint preexistingDatapoint = preexistingEntity.datapoints.First(x => x == datapoint); // The for loop is a copy. This retrieves the original such that it can be updated.
preexistingDatapoint.probMethod = new(newEntityDatapoint.Probmethod_embedding, _logger);
preexistingDatapoint.similarityMethod = new(newEntityDatapoint.SimilarityMethod, _logger);
invalidateSearchCache = true; invalidateSearchCache = true;
} }
} }
} }
foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints) foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints)
{ {
bool oldEntityHasDatapoint = preexistingEntity.datapoints.Any(x => x.name == jsonDatapoint.Name); bool oldEntityHasDatapoint = preexistingEntity.Datapoints.Any(x => x.Name == jsonDatapoint.Name);
if (!oldEntityHasDatapoint) if (!oldEntityHasDatapoint)
{ {
// Datapoint - New // Datapoint - New
Datapoint datapoint = DatabaseInsertDatapointWithEmbeddings(helper, searchdomain, jsonDatapoint, (int)preexistingEntityID); createdDatapoints.Add(new()
preexistingEntity.datapoints.Add(datapoint); {
name = jsonDatapoint.Name,
probmethod_embedding = jsonDatapoint.Probmethod_embedding,
similarityMethod = jsonDatapoint.SimilarityMethod,
hash = GetHash(jsonDatapoint),
embeddings = Datapoint.GetEmbeddings(
jsonDatapoint.Text ?? throw new Exception("jsonDatapoint.Text must not be null when retrieving embeddings"),
[.. jsonDatapoint.Model],
aIProvider,
embeddingCache
),
datapoint = jsonDatapoint
});
invalidateSearchCache = true; invalidateSearchCache = true;
} }
} }
if (deletedDatapointInstances.Count != 0 || createdDatapoints.Count != 0 || addedAttributes.Count != 0 || updatedDatapointsNonText.Count != 0)
_logger.LogDebug(
"EntityFromJSON - Updating existing entity. name: {name}, deletedDatapointInstances: {deletedDatapointInstances}, createdDatapoints: {createdDatapoints}, addedAttributes: {addedAttributes}, updatedDatapointsNonText: {updatedDatapointsNonText}",
[preexistingEntity.Name, deletedDatapointInstances.Count, createdDatapoints.Count, addedAttributes.Count, updatedDatapointsNonText.Count]);
// Datapoint - apply changes
// Deleted
if (deletedDatapointInstances.Count != 0)
{
await DatabaseHelper.DatabaseDeleteEmbeddingsAndDatapoints(helper, deletedDatapoints, (int)preexistingEntityID);
preexistingEntity.Datapoints = [.. preexistingEntity.Datapoints
.Where(x =>
!deletedDatapointInstances.Contains(x)
)
];
}
// Created
if (createdDatapoints.Count != 0)
{
List<Datapoint> datapoint = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. createdDatapoints.Select(element => (element.datapoint, element.hash))], (int)preexistingEntityID, id_searchdomain);
datapoint.ForEach(x => preexistingEntity.Datapoints.Add(x));
}
// Datapoint - Updated (text)
if (updatedDatapointsText.Count != 0)
{
await DatabaseHelper.DatabaseDeleteEmbeddingsAndDatapoints(helper, [.. updatedDatapointsText.Select(datapoint => datapoint.datapointName)], (int)preexistingEntityID);
// Remove from datapoints
var namesToRemove = updatedDatapointsText
.Select(d => d.datapointName)
.ToHashSet();
var newBag = new ConcurrentBag<Datapoint>(
preexistingEntity.Datapoints
.Where(x => !namesToRemove.Contains(x.Name))
);
preexistingEntity.Datapoints = newBag;
// Insert into database
List<Datapoint> datapoints = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. updatedDatapointsText.Select(element => (datapoint: element.jsonDatapoint, hash: element.hash))], (int)preexistingEntityID, id_searchdomain);
// Insert into datapoints
datapoints.ForEach(datapoint => preexistingEntity.Datapoints.Add(datapoint));
}
// Datapoint - Updated (probmethod or similaritymethod)
if (updatedDatapointsNonText.Count != 0)
{
await DatabaseHelper.DatabaseUpdateDatapoint(
helper,
[.. updatedDatapointsNonText.Select(element => (element.datapointName, element.probMethod, element.similarityMethod))],
(int)preexistingEntityID
);
updatedDatapointsNonText.ForEach(element =>
{
Datapoint preexistingDatapoint = preexistingEntity.Datapoints.First(x => x.Name == element.datapointName);
preexistingDatapoint.ProbMethod = new(element.jsonDatapoint.Probmethod_embedding);
preexistingDatapoint.SimilarityMethod = new(element.jsonDatapoint.SimilarityMethod);
});
}
if (invalidateSearchCache) if (invalidateSearchCache)
{ {
searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity); searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity);
searchdomain.UpdateModelsInUse();
} }
searchdomain.UpdateModelsInUse();
return preexistingEntity; return preexistingEntity;
} }
else else
{ {
int id_entity = DatabaseHelper.DatabaseInsertEntity(helper, jsonEntity.Name, jsonEntity.Probmethod, _databaseHelper.GetSearchdomainID(helper, jsonEntity.Searchdomain)); int id_entity = await DatabaseHelper.DatabaseInsertEntity(helper, jsonEntity.Name, jsonEntity.Probmethod, id_searchdomain);
List<(string attribute, string value, int id_entity)> toBeInsertedAttributes = []; List<(string attribute, string value, int id_entity)> toBeInsertedAttributes = [];
foreach (KeyValuePair<string, string> attribute in jsonEntity.Attributes) foreach (KeyValuePair<string, string> attribute in jsonEntity.Attributes)
{ {
@@ -254,10 +364,11 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
id_entity = id_entity id_entity = id_entity
}); });
} }
DatabaseHelper.DatabaseInsertAttributes(helper, toBeInsertedAttributes);
List<Datapoint> datapoints = []; var insertAttributesTask = DatabaseHelper.DatabaseInsertAttributes(helper, toBeInsertedAttributes);
List<(JSONDatapoint datapoint, string hash)> toBeInsertedDatapoints = []; List<(JSONDatapoint datapoint, string hash)> toBeInsertedDatapoints = [];
ConcurrentBag<string> usedModels = searchdomain.ModelsInUse;
foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints) foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints)
{ {
string hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); string hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
@@ -266,85 +377,100 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
datapoint = jsonDatapoint, datapoint = jsonDatapoint,
hash = hash hash = hash
}); });
foreach (string model in jsonDatapoint.Model)
{
if (!usedModels.Contains(model))
{
usedModels.Add(model);
}
}
} }
List<Datapoint> datapoint = DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, toBeInsertedDatapoints, id_entity);
List<Datapoint> datapoints = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, toBeInsertedDatapoints, id_entity, id_searchdomain);
var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new ProbMethodNotFoundException(jsonEntity.Probmethod); var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new ProbMethodNotFoundException(jsonEntity.Probmethod);
Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), datapoints, jsonEntity.Name) Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), [.. datapoints], jsonEntity.Name, jsonEntity.Searchdomain)
{ {
id = id_entity Id = id_entity
}; };
entityCache.Add(entity); entityCache[jsonEntity.Name] = entity;
searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(entity); searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(entity);
searchdomain.UpdateModelsInUse(); await insertAttributesTask;
return entity; return entity;
} }
} }
public List<Datapoint> DatabaseInsertDatapointsWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, List<(JSONDatapoint datapoint, string hash)> values, int id_entity) public async Task<List<Datapoint>> DatabaseInsertDatapointsWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, List<(JSONDatapoint datapoint, string hash)> values, int id_entity, int id_searchdomain)
{ {
List<Datapoint> result = []; List<Datapoint> result = [];
List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> toBeInsertedDatapoints = []; List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash)> toBeInsertedDatapoints = [];
List<(string hash, string model, byte[] embedding)> toBeInsertedEmbeddings = []; List<(int id_datapoint, string model, byte[] embedding)> toBeInsertedEmbeddings = [];
foreach ((JSONDatapoint datapoint, string hash) value in values) foreach ((JSONDatapoint datapoint, string hash) value in values)
{ {
Datapoint datapoint = BuildDatapointFromJsonDatapoint(value.datapoint, id_entity, searchdomain, value.hash); Datapoint datapoint = await BuildDatapointFromJsonDatapoint(value.datapoint, id_entity, searchdomain, value.hash);
toBeInsertedDatapoints.Add(new() toBeInsertedDatapoints.Add(new()
{ {
name = datapoint.name, name = datapoint.Name,
probmethod_embedding = datapoint.probMethod.probMethodEnum, probmethod_embedding = datapoint.ProbMethod.ProbMethodEnum,
similarityMethod = datapoint.similarityMethod.similarityMethodEnum, similarityMethod = datapoint.SimilarityMethod.SimilarityMethodEnum,
hash = value.hash hash = value.hash
}); });
foreach ((string, float[]) embedding in datapoint.embeddings) foreach ((string, float[]) embedding in datapoint.Embeddings)
{ {
toBeInsertedEmbeddings.Add(new() toBeInsertedEmbeddings.Add(new()
{ {
hash = value.hash, id_datapoint = datapoint.Id,
model = embedding.Item1, model = embedding.Item1,
embedding = BytesFromFloatArray(embedding.Item2) embedding = BytesFromFloatArray(embedding.Item2)
}); });
} }
result.Add(datapoint); result.Add(datapoint);
} }
DatabaseHelper.DatabaseInsertDatapoints(helper, toBeInsertedDatapoints, id_entity); await DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, toBeInsertedEmbeddings, id_entity, id_searchdomain);
DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, toBeInsertedEmbeddings);
return result; return result;
} }
public Datapoint DatabaseInsertDatapointWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, JSONDatapoint jsonDatapoint, int id_entity, string? hash = null) public async Task<Datapoint> DatabaseInsertDatapointWithEmbeddings(SQLHelper helper, Searchdomain searchdomain, JSONDatapoint jsonDatapoint, int id_entity, int id_searchdomain, string? hash = null)
{ {
if (jsonDatapoint.Text is null) if (jsonDatapoint.Text is null)
{ {
throw new Exception("jsonDatapoint.Text must not be null at this point"); throw new Exception("jsonDatapoint.Text must not be null at this point");
} }
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); hash ??= GetHash(jsonDatapoint);
Datapoint datapoint = BuildDatapointFromJsonDatapoint(jsonDatapoint, id_entity, searchdomain, hash); Datapoint datapoint = await BuildDatapointFromJsonDatapoint(jsonDatapoint, id_entity, searchdomain, hash);
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries int id_datapoint = await DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
List<(string model, byte[] embedding)> data = []; List<(string model, byte[] embedding)> data = [];
foreach ((string, float[]) embedding in datapoint.embeddings) foreach ((string, float[]) embedding in datapoint.Embeddings)
{ {
data.Add((embedding.Item1, BytesFromFloatArray(embedding.Item2))); data.Add((embedding.Item1, BytesFromFloatArray(embedding.Item2)));
} }
DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, id_datapoint, data); await DatabaseHelper.DatabaseInsertEmbeddingBulk(helper, id_datapoint, data, id_entity, id_searchdomain);
return datapoint; return datapoint;
} }
public Datapoint BuildDatapointFromJsonDatapoint(JSONDatapoint jsonDatapoint, int entityId, Searchdomain searchdomain, string? hash = null) public string GetHash(JSONDatapoint jsonDatapoint)
{
return Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text ?? throw new Exception("jsonDatapoint.Text must not be null to compute hash"))));
}
public async Task<Datapoint> BuildDatapointFromJsonDatapoint(JSONDatapoint jsonDatapoint, int entityId, Searchdomain searchdomain, string? hash = null)
{ {
if (jsonDatapoint.Text is null) if (jsonDatapoint.Text is null)
{ {
throw new Exception("jsonDatapoint.Text must not be null at this point"); throw new Exception("jsonDatapoint.Text must not be null at this point");
} }
using SQLHelper helper = searchdomain.helper.DuplicateConnection(); SQLHelper helper = searchdomain.Helper;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache; EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.EmbeddingCache;
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text))); hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId); int id = await DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache); Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.AiProvider, embeddingCache);
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding); var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding);
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod); var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod);
return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))], id);
} }
public static (Searchdomain?, int?, string?) TryGetSearchdomain(SearchdomainManager searchdomainManager, string searchdomain, ILogger logger) public static (Searchdomain?, int?, string?) TryGetSearchdomain(SearchdomainManager searchdomainManager, string searchdomain, ILogger logger)

View File

@@ -34,72 +34,113 @@ public static class DatabaseMigrations
if (databaseVersion != initialDatabaseVersion) if (databaseVersion != initialDatabaseVersion)
{ {
helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() }); var _ = helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() }).Result;
} }
} }
public static int DatabaseGetVersion(SQLHelper helper) public static int DatabaseGetVersion(SQLHelper helper)
{ {
DbDataReader reader = helper.ExecuteSQLCommand("show tables", []); DbDataReader reader = helper.ExecuteSQLCommand("show tables", []);
bool hasTables = reader.Read(); try
reader.Close();
if (!hasTables)
{ {
return 0; bool hasTables = reader.Read();
if (!hasTables)
{
return 0;
}
} finally
{
reader.Close();
} }
reader = helper.ExecuteSQLCommand("show tables like '%settings%'", []); reader = helper.ExecuteSQLCommand("show tables like '%settings%'", []);
bool hasSystemTable = reader.Read(); try
reader.Close();
if (!hasSystemTable)
{ {
return 1; bool hasSystemTable = reader.Read();
if (!hasSystemTable)
{
return 1;
}
} finally
{
reader.Close();
} }
reader = helper.ExecuteSQLCommand("SELECT value FROM settings WHERE name=\"DatabaseVersion\"", []); reader = helper.ExecuteSQLCommand("SELECT value FROM settings WHERE name=\"DatabaseVersion\"", []);
reader.Read(); try
string rawVersion = reader.GetString(0);
reader.Close();
bool success = int.TryParse(rawVersion, out int version);
if (!success)
{ {
throw new DatabaseVersionException(); reader.Read();
string rawVersion = reader.GetString(0);
bool success = int.TryParse(rawVersion, out int version);
if (!success)
{
throw new DatabaseVersionException();
}
return version;
} finally
{
reader.Close();
} }
return version;
} }
public static int Create(SQLHelper helper) public static int Create(SQLHelper helper)
{ {
helper.ExecuteSQLNonQuery("CREATE TABLE searchdomain (id int PRIMARY KEY auto_increment, name varchar(512), settings JSON);", []); var _ = helper.ExecuteSQLNonQuery("CREATE TABLE searchdomain (id int PRIMARY KEY auto_increment, name varchar(512), settings JSON);", []).Result;
helper.ExecuteSQLNonQuery("CREATE TABLE entity (id int PRIMARY KEY auto_increment, name varchar(512), probmethod varchar(128), id_searchdomain int, FOREIGN KEY (id_searchdomain) REFERENCES searchdomain(id));", []); _ = helper.ExecuteSQLNonQuery("CREATE TABLE entity (id int PRIMARY KEY auto_increment, name varchar(512), probmethod varchar(128), id_searchdomain int, FOREIGN KEY (id_searchdomain) REFERENCES searchdomain(id));", []).Result;
helper.ExecuteSQLNonQuery("CREATE TABLE attribute (id int PRIMARY KEY auto_increment, id_entity int, attribute varchar(512), value longtext, FOREIGN KEY (id_entity) REFERENCES entity(id));", []); _ = helper.ExecuteSQLNonQuery("CREATE TABLE attribute (id int PRIMARY KEY auto_increment, id_entity int, attribute varchar(512), value longtext, FOREIGN KEY (id_entity) REFERENCES entity(id));", []).Result;
helper.ExecuteSQLNonQuery("CREATE TABLE datapoint (id int PRIMARY KEY auto_increment, name varchar(512), probmethod_embedding varchar(512), id_entity int, FOREIGN KEY (id_entity) REFERENCES entity(id));", []); _ = helper.ExecuteSQLNonQuery("CREATE TABLE datapoint (id int PRIMARY KEY auto_increment, name varchar(512), probmethod_embedding varchar(512), id_entity int, FOREIGN KEY (id_entity) REFERENCES entity(id));", []).Result;
helper.ExecuteSQLNonQuery("CREATE TABLE embedding (id int PRIMARY KEY auto_increment, id_datapoint int, model varchar(512), embedding blob, FOREIGN KEY (id_datapoint) REFERENCES datapoint(id));", []); _ = helper.ExecuteSQLNonQuery("CREATE TABLE embedding (id int PRIMARY KEY auto_increment, id_datapoint int, model varchar(512), embedding blob, FOREIGN KEY (id_datapoint) REFERENCES datapoint(id));", []).Result;
return 1; return 1;
} }
public static int UpdateFrom1(SQLHelper helper) public static int UpdateFrom1(SQLHelper helper)
{ {
helper.ExecuteSQLNonQuery("CREATE TABLE settings (name varchar(512), value varchar(8192));", []); var _ = helper.ExecuteSQLNonQuery("CREATE TABLE settings (name varchar(512), value varchar(8192));", []).Result;
helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES (\"DatabaseVersion\", \"2\");", []); _ = helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES (\"DatabaseVersion\", \"2\");", []).Result;
return 2; return 2;
} }
public static int UpdateFrom2(SQLHelper helper) public static int UpdateFrom2(SQLHelper helper)
{ {
helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD hash VARCHAR(44);", []); var _ = helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD hash VARCHAR(44);", []).Result;
helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []); _ = helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []).Result;
return 3; return 3;
} }
public static int UpdateFrom3(SQLHelper helper) public static int UpdateFrom3(SQLHelper helper)
{ {
helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding", []); var _ = helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding", []).Result;
return 4; return 4;
} }
public static int UpdateFrom4(SQLHelper helper) public static int UpdateFrom4(SQLHelper helper)
{ {
helper.ExecuteSQLNonQuery("UPDATE searchdomain SET settings = JSON_SET(settings, '$.QueryCacheSize', 1000000) WHERE JSON_EXTRACT(settings, '$.QueryCacheSize') is NULL;", []); // Set QueryCacheSize to a default of 1000000 var _ = helper.ExecuteSQLNonQuery("UPDATE searchdomain SET settings = JSON_SET(settings, '$.QueryCacheSize', 1000000) WHERE JSON_EXTRACT(settings, '$.QueryCacheSize') is NULL;", []).Result; // Set QueryCacheSize to a default of 1000000
return 5; return 5;
} }
/// <summary>
/// Migrates the database from version 5 to version 6 by denormalizing
/// <c>id_entity</c> and <c>id_searchdomain</c> onto the <c>embedding</c> table,
/// so embeddings can be selected without joining through datapoint and entity.
/// </summary>
/// <param name="helper">SQL helper used to execute the migration statements.</param>
/// <returns>6, the value returned to the caller as the next version number.</returns>
public static int UpdateFrom5(SQLHelper helper)
{
    // Rows are backfilled in batches of this size to keep each UPDATE short.
    // The same value is used for the LIMIT clause and the loop condition.
    const int BatchSize = 10000;
    int count;

    // Add id_entity to embedding
    _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_entity INT NULL", []).Result;
    do
    {
        // embedding.id_datapoint -> datapoint.id; copy the datapoint's entity onto the embedding.
        count = helper.ExecuteSQLNonQuery($"UPDATE embedding e JOIN datapoint d ON d.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_entity IS NULL LIMIT {BatchSize}) x on x.id = e.id SET e.id_entity = d.id_entity;", []).Result;
    } while (count == BatchSize); // A short batch means nothing (updatable) is left.
    _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_entity INT NOT NULL;", []).Result;
    _ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_entity_model ON embedding (id_entity, model)", []).Result;

    // Add id_searchdomain to embedding
    _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_searchdomain INT NULL", []).Result;
    do
    {
        // The entity must be matched through e.id_entity (populated above).
        // e.id_datapoint holds a datapoint id, so joining it against entity.id
        // would assign the searchdomain of an unrelated entity or leave NULLs behind.
        count = helper.ExecuteSQLNonQuery($"UPDATE embedding e JOIN entity en ON en.id = e.id_entity JOIN (SELECT id FROM embedding WHERE id_searchdomain IS NULL LIMIT {BatchSize}) x on x.id = e.id SET e.id_searchdomain = en.id_searchdomain;", []).Result;
    } while (count == BatchSize);
    _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_searchdomain INT NOT NULL;", []).Result;
    // NOTE(review): the index name ends in "_model" but only covers id_searchdomain - confirm whether (id_searchdomain, model) was intended.
    _ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_searchdomain_model ON embedding (id_searchdomain)", []).Result;
    return 6;
}
} }

View File

@@ -12,6 +12,7 @@ public class EmbeddingSearchOptions : ApiKeyOptions
public required SimpleAuthOptions SimpleAuth { get; set; } public required SimpleAuthOptions SimpleAuth { get; set; }
public required CacheOptions Cache { get; set; } public required CacheOptions Cache { get; set; }
public required bool UseHttpsRedirection { get; set; } public required bool UseHttpsRedirection { get; set; }
public int? MaxRequestBodySize { get; set; }
} }
public class AiProvider public class AiProvider

View File

@@ -6,34 +6,29 @@ namespace Server;
public class ProbMethod public class ProbMethod
{ {
public Probmethods.probMethodDelegate method; public Probmethods.ProbMethodDelegate Method;
public ProbMethodEnum probMethodEnum; public ProbMethodEnum ProbMethodEnum;
public string name; public string Name;
public ProbMethod(ProbMethodEnum probMethodEnum, ILogger logger) public ProbMethod(ProbMethodEnum probMethodEnum)
{ {
this.probMethodEnum = probMethodEnum; this.ProbMethodEnum = probMethodEnum;
this.name = probMethodEnum.ToString(); this.Name = probMethodEnum.ToString();
Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name); Probmethods.ProbMethodDelegate? probMethod = Probmethods.GetMethod(Name) ?? throw new ProbMethodNotFoundException(probMethodEnum);
if (probMethod is null) Method = probMethod;
{
logger.LogError("Unable to retrieve probMethod {name}", [name]);
throw new ProbMethodNotFoundException(probMethodEnum);
}
method = probMethod;
} }
} }
public static class Probmethods public static class Probmethods
{ {
public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters); public delegate float ProbMethodProtoDelegate(List<(string, float)> list, string parameters);
public delegate float probMethodDelegate(List<(string, float)> list); public delegate float ProbMethodDelegate(List<(string, float)> list);
public static readonly Dictionary<ProbMethodEnum, probMethodProtoDelegate> probMethods; public static readonly Dictionary<ProbMethodEnum, ProbMethodProtoDelegate> ProbMethods;
static Probmethods() static Probmethods()
{ {
probMethods = new Dictionary<ProbMethodEnum, probMethodProtoDelegate> ProbMethods = new Dictionary<ProbMethodEnum, ProbMethodProtoDelegate>
{ {
[ProbMethodEnum.Mean] = Mean, [ProbMethodEnum.Mean] = Mean,
[ProbMethodEnum.HarmonicMean] = HarmonicMean, [ProbMethodEnum.HarmonicMean] = HarmonicMean,
@@ -46,12 +41,12 @@ public static class Probmethods
}; };
} }
public static probMethodDelegate? GetMethod(ProbMethodEnum probMethodEnum) public static ProbMethodDelegate? GetMethod(ProbMethodEnum probMethodEnum)
{ {
return GetMethod(probMethodEnum.ToString()); return GetMethod(probMethodEnum.ToString());
} }
public static probMethodDelegate? GetMethod(string name) public static ProbMethodDelegate? GetMethod(string name)
{ {
string methodName = name; string methodName = name;
string? jsonArg = ""; string? jsonArg = "";
@@ -68,7 +63,7 @@ public static class Probmethods
methodName methodName
); );
if (!probMethods.TryGetValue(probMethodEnum, out probMethodProtoDelegate? method)) if (!ProbMethods.TryGetValue(probMethodEnum, out ProbMethodProtoDelegate? method))
{ {
return null; return null;
} }

View File

@@ -42,7 +42,7 @@ builder.WebHost.ConfigureKestrel(options =>
}); });
// Migrate database // Migrate database
var helper = new SQLHelper(new MySql.Data.MySqlClient.MySqlConnection(configuration.ConnectionStrings.SQL), configuration.ConnectionStrings.SQL); SQLHelper helper = new(new MySql.Data.MySqlClient.MySqlConnection(configuration.ConnectionStrings.SQL), configuration.ConnectionStrings.SQL);
DatabaseMigrations.Migrate(helper); DatabaseMigrations.Migrate(helper);
// Migrate SQLite cache // Migrate SQLite cache

View File

@@ -7,6 +7,7 @@ using Server.Helper;
using Shared; using Shared;
using Shared.Models; using Shared.Models;
using AdaptiveExpressions; using AdaptiveExpressions;
using System.Collections.Concurrent;
namespace Server; namespace Server;
@@ -14,36 +15,33 @@ public class Searchdomain
{ {
private readonly string _connectionString; private readonly string _connectionString;
private readonly string _provider; private readonly string _provider;
public AIProvider aIProvider; public AIProvider AiProvider;
public string searchdomain; public string SearchdomainName;
public int id; public int Id;
public SearchdomainSettings settings; public SearchdomainSettings Settings;
public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp) public EnumerableLruCache<string, DateTimedSearchResult> QueryCache; // Key: query, Value: Search results for that query (with timestamp)
public List<Entity> entityCache; public ConcurrentDictionary<string, Entity> EntityCache;
public List<string> modelsInUse; public ConcurrentBag<string> ModelsInUse;
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache; public EnumerableLruCache<string, Dictionary<string, float[]>> EmbeddingCache;
private readonly MySqlConnection connection; public SQLHelper Helper;
public SQLHelper helper;
private readonly ILogger _logger; private readonly ILogger _logger;
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) public Searchdomain(string searchdomain, string connectionString, SQLHelper sqlHelper, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
{ {
_connectionString = connectionString; _connectionString = connectionString;
_provider = provider.ToLower(); _provider = provider.ToLower();
this.searchdomain = searchdomain; this.SearchdomainName = searchdomain;
this.aIProvider = aIProvider; this.AiProvider = aIProvider;
this.embeddingCache = embeddingCache; this.EmbeddingCache = embeddingCache;
this._logger = logger; this._logger = logger;
entityCache = []; EntityCache = [];
connection = new MySqlConnection(connectionString); Helper = sqlHelper;
connection.Open(); Settings = GetSettings();
helper = new SQLHelper(connection, connectionString); QueryCache = new(Settings.QueryCacheSize);
settings = GetSettings(); ModelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that...
queryCache = new(settings.QueryCacheSize);
modelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that...
if (!runEmpty) if (!runEmpty)
{ {
GetID(); Id = GetID().Result;
UpdateEntityCache(); UpdateEntityCache();
} }
} }
@@ -53,118 +51,138 @@ public class Searchdomain
InvalidateSearchCache(); InvalidateSearchCache();
Dictionary<string, dynamic> parametersIDSearchdomain = new() Dictionary<string, dynamic> parametersIDSearchdomain = new()
{ {
["id"] = this.id ["id"] = this.Id
}; };
DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT e.id, e.id_datapoint, e.model, e.embedding FROM embedding e JOIN datapoint d ON e.id_datapoint = d.id JOIN entity ent ON d.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain); DbDataReader embeddingReader = Helper.ExecuteSQLCommand("SELECT id, id_datapoint, model, embedding FROM embedding WHERE id_searchdomain = @id", parametersIDSearchdomain);
Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = []; Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = [];
while (embeddingReader.Read()) try
{ {
int? id_datapoint_debug = null; while (embeddingReader.Read())
try
{ {
int id_datapoint = embeddingReader.GetInt32(1); int? id_datapoint_debug = null;
id_datapoint_debug = id_datapoint; try
string model = embeddingReader.GetString(2);
long length = embeddingReader.GetBytes(3, 0, null, 0, 0);
byte[] embedding = new byte[length];
embeddingReader.GetBytes(3, 0, embedding, 0, (int) length);
if (embedding_unassigned.TryGetValue(id_datapoint, out Dictionary<string, float[]>? embedding_unassigned_id_datapoint))
{ {
embedding_unassigned[id_datapoint][model] = SearchdomainHelper.FloatArrayFromBytes(embedding); int id_datapoint = embeddingReader.GetInt32(1);
} id_datapoint_debug = id_datapoint;
else string model = embeddingReader.GetString(2);
{ long length = embeddingReader.GetBytes(3, 0, null, 0, 0);
embedding_unassigned[id_datapoint] = new() byte[] embedding = new byte[length];
embeddingReader.GetBytes(3, 0, embedding, 0, (int) length);
if (embedding_unassigned.TryGetValue(id_datapoint, out Dictionary<string, float[]>? embedding_unassigned_id_datapoint))
{ {
[model] = SearchdomainHelper.FloatArrayFromBytes(embedding) embedding_unassigned[id_datapoint][model] = SearchdomainHelper.FloatArrayFromBytes(embedding);
}; }
} else
} catch (Exception e) {
{ embedding_unassigned[id_datapoint] = new()
_logger.LogError("Error reading embedding (id: {id_datapoint}) from database: {e.Message} - {e.StackTrace}", [id_datapoint_debug, e.Message, e.StackTrace]); {
ElmahCore.ElmahExtensions.RaiseError(e); [model] = SearchdomainHelper.FloatArrayFromBytes(embedding)
} };
} }
embeddingReader.Close(); } catch (Exception e)
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT d.id, d.id_entity, d.name, d.probmethod_embedding, d.similaritymethod, d.hash FROM datapoint d JOIN entity ent ON d.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain);
Dictionary<int, List<Datapoint>> datapoint_unassigned = [];
while (datapointReader.Read())
{
int id = datapointReader.GetInt32(0);
int id_entity = datapointReader.GetInt32(1);
string name = datapointReader.GetString(2);
string probmethodString = datapointReader.GetString(3);
string similarityMethodString = datapointReader.GetString(4);
string hash = datapointReader.GetString(5);
ProbMethodEnum probmethodEnum = (ProbMethodEnum)Enum.Parse(
typeof(ProbMethodEnum),
probmethodString
);
SimilarityMethodEnum similairtyMethodEnum = (SimilarityMethodEnum)Enum.Parse(
typeof(SimilarityMethodEnum),
similarityMethodString
);
ProbMethod probmethod = new(probmethodEnum, _logger);
SimilarityMethod similarityMethod = new(similairtyMethodEnum, _logger);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{
embedding_unassigned.Remove(id);
if (!datapoint_unassigned.ContainsKey(id_entity))
{ {
datapoint_unassigned[id_entity] = []; _logger.LogError("Error reading embedding (id: {id_datapoint}) from database: {e.Message} - {e.StackTrace}", [id_datapoint_debug, e.Message, e.StackTrace]);
ElmahCore.ElmahExtensions.RaiseError(e);
} }
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
} }
} finally
{
embeddingReader.Close();
} }
datapointReader.Close();
DbDataReader attributeReader = helper.ExecuteSQLCommand("SELECT a.id, a.id_entity, a.attribute, a.value FROM attribute a JOIN entity ent ON a.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain); DbDataReader datapointReader = Helper.ExecuteSQLCommand("SELECT d.id, d.id_entity, d.name, d.probmethod_embedding, d.similaritymethod, d.hash FROM datapoint d JOIN entity ent ON d.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain);
Dictionary<int, ConcurrentBag<Datapoint>> datapoint_unassigned = [];
try
{
while (datapointReader.Read())
{
int id = datapointReader.GetInt32(0);
int id_entity = datapointReader.GetInt32(1);
string name = datapointReader.GetString(2);
string probmethodString = datapointReader.GetString(3);
string similarityMethodString = datapointReader.GetString(4);
string hash = datapointReader.GetString(5);
ProbMethodEnum probmethodEnum = (ProbMethodEnum)Enum.Parse(
typeof(ProbMethodEnum),
probmethodString
);
SimilarityMethodEnum similairtyMethodEnum = (SimilarityMethodEnum)Enum.Parse(
typeof(SimilarityMethodEnum),
similarityMethodString
);
ProbMethod probmethod = new(probmethodEnum);
SimilarityMethod similarityMethod = new(similairtyMethodEnum);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{
embedding_unassigned.Remove(id);
if (!datapoint_unassigned.ContainsKey(id_entity))
{
datapoint_unassigned[id_entity] = [];
}
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))], id));
}
}
} finally
{
datapointReader.Close();
}
DbDataReader attributeReader = Helper.ExecuteSQLCommand("SELECT a.id, a.id_entity, a.attribute, a.value FROM attribute a JOIN entity ent ON a.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain);
Dictionary<int, Dictionary<string, string>> attributes_unassigned = []; Dictionary<int, Dictionary<string, string>> attributes_unassigned = [];
while (attributeReader.Read()) try
{ {
//"SELECT id, id_entity, attribute, value FROM attribute JOIN entity on attribute.id_entity as en JOIN searchdomain on en.id_searchdomain as sd WHERE sd=@id" while (attributeReader.Read())
int id = attributeReader.GetInt32(0);
int id_entity = attributeReader.GetInt32(1);
string attribute = attributeReader.GetString(2);
string value = attributeReader.GetString(3);
if (!attributes_unassigned.ContainsKey(id_entity))
{ {
attributes_unassigned[id_entity] = []; //"SELECT id, id_entity, attribute, value FROM attribute JOIN entity on attribute.id_entity as en JOIN searchdomain on en.id_searchdomain as sd WHERE sd=@id"
} int id = attributeReader.GetInt32(0);
attributes_unassigned[id_entity].Add(attribute, value); int id_entity = attributeReader.GetInt32(1);
} string attribute = attributeReader.GetString(2);
attributeReader.Close(); string value = attributeReader.GetString(3);
if (!attributes_unassigned.ContainsKey(id_entity))
entityCache = [];
DbDataReader entityReader = helper.ExecuteSQLCommand("SELECT entity.id, name, probmethod FROM entity WHERE id_searchdomain=@id", parametersIDSearchdomain);
while (entityReader.Read())
{
//SELECT id, name, probmethod FROM entity WHERE id_searchdomain=@id
int id = entityReader.GetInt32(0);
string name = entityReader.GetString(1);
string probmethodString = entityReader.GetString(2);
if (!attributes_unassigned.TryGetValue(id, out Dictionary<string, string>? attributes))
{
attributes = [];
}
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
if (datapoint_unassigned.TryGetValue(id, out List<Datapoint>? datapoints) && probmethod is not null)
{
Entity entity = new(attributes, probmethod, probmethodString, datapoints, name)
{ {
id = id attributes_unassigned[id_entity] = [];
}; }
entityCache.Add(entity); attributes_unassigned[id_entity].Add(attribute, value);
} }
} finally
{
attributeReader.Close();
} }
entityReader.Close();
modelsInUse = GetModels(entityCache); EntityCache = [];
DbDataReader entityReader = Helper.ExecuteSQLCommand("SELECT entity.id, name, probmethod FROM entity WHERE id_searchdomain=@id", parametersIDSearchdomain);
try
{
while (entityReader.Read())
{
//SELECT id, name, probmethod FROM entity WHERE id_searchdomain=@id
int id = entityReader.GetInt32(0);
string name = entityReader.GetString(1);
string probmethodString = entityReader.GetString(2);
if (!attributes_unassigned.TryGetValue(id, out Dictionary<string, string>? attributes))
{
attributes = [];
}
Probmethods.ProbMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
if (datapoint_unassigned.TryGetValue(id, out ConcurrentBag<Datapoint>? datapoints) && probmethod is not null)
{
Entity entity = new(attributes, probmethod, probmethodString, datapoints, name, SearchdomainName)
{
Id = id
};
EntityCache[name] = entity;
}
}
} finally
{
entityReader.Close();
}
ModelsInUse = GetModels(EntityCache);
} }
public List<(float, string)> Search(string query, int? topN = null) public List<(float, string)> Search(string query, int? topN = null)
{ {
if (queryCache.TryGetValue(query, out DateTimedSearchResult cachedResult)) if (QueryCache.TryGetValue(query, out DateTimedSearchResult cachedResult))
{ {
cachedResult.AccessDateTimes.Add(DateTime.Now); cachedResult.AccessDateTimes.Add(DateTime.Now);
return [.. cachedResult.Results.Select(r => (r.Score, r.Name))]; return [.. cachedResult.Results.Select(r => (r.Score, r.Name))];
@@ -173,10 +191,9 @@ public class Searchdomain
Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query); Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query);
List<(float, string)> result = []; List<(float, string)> result = [];
foreach ((string name, Entity entity) in EntityCache)
foreach (Entity entity in entityCache)
{ {
result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.name)); result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.Name));
} }
IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1); IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1);
if (topN is not null) if (topN is not null)
@@ -188,26 +205,26 @@ public class Searchdomain
[.. sortedResults.Select(r => [.. sortedResults.Select(r =>
new ResultItem(r.Item1, r.Item2 ))] new ResultItem(r.Item1, r.Item2 ))]
); );
queryCache.Set(query, new DateTimedSearchResult(DateTime.Now, searchResult)); QueryCache.Set(query, new DateTimedSearchResult(DateTime.Now, searchResult));
return results; return results;
} }
public Dictionary<string, float[]> GetQueryEmbeddings(string query) public Dictionary<string, float[]> GetQueryEmbeddings(string query)
{ {
bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings); bool hasQuery = EmbeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings);
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model)); bool allModelsInQuery = queryEmbeddings is not null && ModelsInUse.All(model => queryEmbeddings.ContainsKey(model));
if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null) if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null)
{ {
queryEmbeddings = Datapoint.GetEmbeddings(query, modelsInUse, aIProvider, embeddingCache); queryEmbeddings = Datapoint.GetEmbeddings(query, ModelsInUse, AiProvider, EmbeddingCache);
if (!embeddingCache.TryGetValue(query, out var embeddingCacheForCurrentQuery)) if (!EmbeddingCache.TryGetValue(query, out var embeddingCacheForCurrentQuery))
{ {
embeddingCache.Set(query, queryEmbeddings); EmbeddingCache.Set(query, queryEmbeddings);
} }
else // embeddingCache already has an entry for this query, so the missing model-embedding pairs have to be filled in else // embeddingCache already has an entry for this query, so the missing model-embedding pairs have to be filled in
{ {
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding
{ {
if (!embeddingCache.TryGetValue(kvp.Key, out var _)) if (!EmbeddingCache.TryGetValue(kvp.Key, out var _))
{ {
embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value; embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value;
} }
@@ -219,37 +236,38 @@ public class Searchdomain
public void UpdateModelsInUse() public void UpdateModelsInUse()
{ {
modelsInUse = GetModels(entityCache.ToList()); ModelsInUse = GetModels(EntityCache);
} }
private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary<string, float[]> queryEmbeddings) private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary<string, float[]> queryEmbeddings)
{ {
List<(string, float)> datapointProbs = []; List<(string, float)> datapointProbs = [];
foreach (Datapoint datapoint in entity.datapoints) foreach (Datapoint datapoint in entity.Datapoints)
{ {
SimilarityMethod similarityMethod = datapoint.similarityMethod; SimilarityMethod similarityMethod = datapoint.SimilarityMethod;
List<(string, float)> list = []; List<(string, float)> list = [];
foreach ((string, float[]) embedding in datapoint.embeddings) foreach ((string, float[]) embedding in datapoint.Embeddings)
{ {
string key = embedding.Item1; string key = embedding.Item1;
float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2); float value = similarityMethod.Method(queryEmbeddings[embedding.Item1], embedding.Item2);
list.Add((key, value)); list.Add((key, value));
} }
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list))); datapointProbs.Add((datapoint.Name, datapoint.ProbMethod.Method(list)));
} }
return entity.probMethod(datapointProbs); return entity.ProbMethod(datapointProbs);
} }
public static List<string> GetModels(List<Entity> entities) public static ConcurrentBag<string> GetModels(ConcurrentDictionary<string, Entity> entities)
{ {
List<string> result = []; ConcurrentBag<string> result = [];
lock (entities) foreach (KeyValuePair<string, Entity> element in entities)
{ {
foreach (Entity entity in entities) Entity entity = element.Value;
lock (entity)
{ {
foreach (Datapoint datapoint in entity.datapoints) foreach (Datapoint datapoint in entity.Datapoints)
{ {
foreach ((string, float[]) tuple in datapoint.embeddings) foreach ((string, float[]) tuple in datapoint.Embeddings)
{ {
string model = tuple.Item1; string model = tuple.Item1;
if (!result.Contains(model)) if (!result.Contains(model))
@@ -263,29 +281,25 @@ public class Searchdomain
return result; return result;
} }
public int GetID() public async Task<int> GetID()
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, object?> parameters = new()
{ {
["name"] = this.searchdomain { "name", this.SearchdomainName }
}; };
DbDataReader reader = helper.ExecuteSQLCommand("SELECT id from searchdomain WHERE name = @name", parameters); return (await Helper.ExecuteQueryAsync("SELECT id from searchdomain WHERE name = @name", parameters, x => x.GetInt32(0))).First();
reader.Read();
this.id = reader.GetInt32(0);
reader.Close();
return this.id;
} }
public SearchdomainSettings GetSettings() public SearchdomainSettings GetSettings()
{ {
return DatabaseHelper.GetSearchdomainSettings(helper, searchdomain); return DatabaseHelper.GetSearchdomainSettings(Helper, SearchdomainName);
} }
public void ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(Entity entity) public void ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(Entity entity)
{ {
if (settings.CacheReconciliation) if (Settings.CacheReconciliation)
{ {
foreach (var element in queryCache) foreach (var element in QueryCache)
{ {
string query = element.Key; string query = element.Key;
DateTimedSearchResult searchResult = element.Value; DateTimedSearchResult searchResult = element.Value;
@@ -293,9 +307,9 @@ public class Searchdomain
Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query); Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query);
float evaluationResult = EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings); float evaluationResult = EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings);
searchResult.Results.RemoveAll(x => x.Name == entity.name); // If entity already exists in that results list: remove it. searchResult.Results.RemoveAll(x => x.Name == entity.Name); // If entity already exists in that results list: remove it.
ResultItem newItem = new(evaluationResult, entity.name); ResultItem newItem = new(evaluationResult, entity.Name);
int index = searchResult.Results.BinarySearch( int index = searchResult.Results.BinarySearch(
newItem, newItem,
Comparer<ResultItem>.Create((a, b) => b.Score.CompareTo(a.Score)) // Invert searching order Comparer<ResultItem>.Create((a, b) => b.Score.CompareTo(a.Score)) // Invert searching order
@@ -313,13 +327,13 @@ public class Searchdomain
public void ReconciliateOrInvalidateCacheForDeletedEntity(Entity entity) public void ReconciliateOrInvalidateCacheForDeletedEntity(Entity entity)
{ {
if (settings.CacheReconciliation) if (Settings.CacheReconciliation)
{ {
foreach (KeyValuePair<string, DateTimedSearchResult> element in queryCache) foreach (KeyValuePair<string, DateTimedSearchResult> element in QueryCache)
{ {
string query = element.Key; string query = element.Key;
DateTimedSearchResult searchResult = element.Value; DateTimedSearchResult searchResult = element.Value;
searchResult.Results.RemoveAll(x => x.Name == entity.name); searchResult.Results.RemoveAll(x => x.Name == entity.Name);
} }
} }
else else
@@ -330,13 +344,13 @@ public class Searchdomain
public void InvalidateSearchCache() public void InvalidateSearchCache()
{ {
queryCache = new(settings.QueryCacheSize); QueryCache = new(Settings.QueryCacheSize);
} }
public long GetSearchCacheSize() public long GetSearchCacheSize()
{ {
long EmbeddingCacheUtilization = 0; long EmbeddingCacheUtilization = 0;
foreach (var entry in queryCache) foreach (var entry in QueryCache)
{ {
EmbeddingCacheUtilization += sizeof(int); // string length prefix EmbeddingCacheUtilization += sizeof(int); // string length prefix
EmbeddingCacheUtilization += entry.Key.Length * sizeof(char); // string characters EmbeddingCacheUtilization += entry.Key.Length * sizeof(char); // string characters

View File

@@ -15,50 +15,50 @@ namespace Server;
public class SearchdomainManager : IDisposable public class SearchdomainManager : IDisposable
{ {
private Dictionary<string, Searchdomain> searchdomains = []; private Dictionary<string, Searchdomain> _searchdomains = [];
private readonly ILogger<SearchdomainManager> _logger; private readonly ILogger<SearchdomainManager> _logger;
private readonly EmbeddingSearchOptions _options; private readonly EmbeddingSearchOptions _options;
public readonly AIProvider aIProvider; public readonly AIProvider AiProvider;
private readonly DatabaseHelper _databaseHelper; private readonly DatabaseHelper _databaseHelper;
private readonly string connectionString; private readonly string connectionString;
private MySqlConnection connection; private MySqlConnection _connection;
public SQLHelper helper; public SQLHelper Helper;
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache; public EnumerableLruCache<string, Dictionary<string, float[]>> EmbeddingCache;
public long EmbeddingCacheMaxCount; public long EmbeddingCacheMaxCount;
private bool disposed = false; private bool _disposed = false;
public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper) public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper)
{ {
_logger = logger; _logger = logger;
_options = options.Value; _options = options.Value;
this.aIProvider = aIProvider; this.AiProvider = aIProvider;
_databaseHelper = databaseHelper; _databaseHelper = databaseHelper;
EmbeddingCacheMaxCount = _options.Cache.CacheTopN; EmbeddingCacheMaxCount = _options.Cache.CacheTopN;
if (options.Value.Cache.StoreEmbeddingCache) if (options.Value.Cache.StoreEmbeddingCache)
{ {
var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
embeddingCache = CacheHelper.GetEmbeddingStore(options.Value); EmbeddingCache = CacheHelper.GetEmbeddingStore(options.Value);
stopwatch.Stop(); stopwatch.Stop();
_logger.LogInformation("GetEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds); _logger.LogInformation("GetEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds);
} else } else
{ {
embeddingCache = new((int)EmbeddingCacheMaxCount); EmbeddingCache = new((int)EmbeddingCacheMaxCount);
} }
connectionString = _options.ConnectionStrings.SQL; connectionString = _options.ConnectionStrings.SQL;
connection = new MySqlConnection(connectionString); _connection = new MySqlConnection(connectionString);
connection.Open(); _connection.Open();
helper = new SQLHelper(connection, connectionString); Helper = new SQLHelper(_connection, connectionString);
} }
public Searchdomain GetSearchdomain(string searchdomain) public Searchdomain GetSearchdomain(string searchdomain)
{ {
if (searchdomains.TryGetValue(searchdomain, out Searchdomain? value)) if (_searchdomains.TryGetValue(searchdomain, out Searchdomain? value))
{ {
return value; return value;
} }
try try
{ {
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, aIProvider, embeddingCache, _logger)); return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, Helper, AiProvider, EmbeddingCache, _logger));
} }
catch (MySqlException) catch (MySqlException)
{ {
@@ -79,34 +79,19 @@ public class SearchdomainManager : IDisposable
searchdomain.InvalidateSearchCache(); searchdomain.InvalidateSearchCache();
} }
public List<string> ListSearchdomains() public async Task<List<string>> ListSearchdomainsAsync()
{ {
lock (helper.connection) return await Helper.ExecuteQueryAsync("SELECT name FROM searchdomain", [], x => x.GetString(0));
{
DbDataReader reader = helper.ExecuteSQLCommand("SELECT name FROM searchdomain", []);
List<string> results = [];
try
{
while (reader.Read())
{
results.Add(reader.GetString(0));
}
return results;
}
finally
{
reader.Close();
}
}
} }
public int CreateSearchdomain(string searchdomain, SearchdomainSettings settings) public async Task<int> CreateSearchdomain(string searchdomain, SearchdomainSettings settings)
{ {
return CreateSearchdomain(searchdomain, JsonSerializer.Serialize(settings)); return await CreateSearchdomain(searchdomain, JsonSerializer.Serialize(settings));
} }
public int CreateSearchdomain(string searchdomain, string settings = "{}")
public async Task<int> CreateSearchdomain(string searchdomain, string settings = "{}")
{ {
if (searchdomains.TryGetValue(searchdomain, out Searchdomain? value)) if (_searchdomains.TryGetValue(searchdomain, out Searchdomain? value))
{ {
_logger.LogError("Searchdomain {searchdomain} could not be created, as it already exists", [searchdomain]); _logger.LogError("Searchdomain {searchdomain} could not be created, as it already exists", [searchdomain]);
throw new SearchdomainAlreadyExistsException(searchdomain); throw new SearchdomainAlreadyExistsException(searchdomain);
@@ -116,27 +101,30 @@ public class SearchdomainManager : IDisposable
{ "name", searchdomain }, { "name", searchdomain },
{ "settings", settings} { "settings", settings}
}; };
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters); int id = await Helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters);
_searchdomains.Add(searchdomain, new(searchdomain, connectionString, Helper, AiProvider, EmbeddingCache, _logger));
return id;
} }
public int DeleteSearchdomain(string searchdomain) public async Task<int> DeleteSearchdomain(string searchdomain)
{ {
int counter = _databaseHelper.RemoveAllEntities(helper, searchdomain); int counter = await _databaseHelper.RemoveAllEntities(Helper, searchdomain);
_logger.LogDebug($"Number of entities deleted as part of deleting the searchdomain \"{searchdomain}\": {counter}"); _logger.LogDebug($"Number of entities deleted as part of deleting the searchdomain \"{searchdomain}\": {counter}");
helper.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", searchdomain}}); await Helper.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", searchdomain}});
searchdomains.Remove(searchdomain); _searchdomains.Remove(searchdomain);
_logger.LogDebug($"Searchdomain has been successfully removed"); _logger.LogDebug($"Searchdomain has been successfully removed");
return counter; return counter;
} }
private Searchdomain SetSearchdomain(string name, Searchdomain searchdomain) private Searchdomain SetSearchdomain(string name, Searchdomain searchdomain)
{ {
searchdomains[name] = searchdomain; _searchdomains[name] = searchdomain;
return searchdomain; return searchdomain;
} }
public bool IsSearchdomainLoaded(string name) public bool IsSearchdomainLoaded(string name)
{ {
return searchdomains.ContainsKey(name); return _searchdomains.ContainsKey(name);
} }
// Cleanup procedure // Cleanup procedure
@@ -147,7 +135,7 @@ public class SearchdomainManager : IDisposable
if (_options.Cache.StoreEmbeddingCache) if (_options.Cache.StoreEmbeddingCache)
{ {
var stopwatch = Stopwatch.StartNew(); var stopwatch = Stopwatch.StartNew();
await CacheHelper.UpdateEmbeddingStore(embeddingCache, _options); await CacheHelper.UpdateEmbeddingStore(EmbeddingCache, _options);
stopwatch.Stop(); stopwatch.Stop();
_logger.LogInformation("UpdateEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds); _logger.LogInformation("UpdateEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds);
} }
@@ -167,10 +155,10 @@ public class SearchdomainManager : IDisposable
protected virtual async Task Dispose(bool disposing) protected virtual async Task Dispose(bool disposing)
{ {
if (!disposed && disposing) if (!_disposed && disposing)
{ {
await Cleanup(); await Cleanup();
disposed = true; _disposed = true;
} }
} }
} }

View File

@@ -5,21 +5,16 @@ namespace Server;
public class SimilarityMethod public class SimilarityMethod
{ {
public SimilarityMethods.similarityMethodDelegate method; public SimilarityMethods.similarityMethodDelegate Method;
public SimilarityMethodEnum similarityMethodEnum; public SimilarityMethodEnum SimilarityMethodEnum;
public string name; public string Name;
public SimilarityMethod(SimilarityMethodEnum similarityMethodEnum, ILogger logger) public SimilarityMethod(SimilarityMethodEnum similarityMethodEnum)
{ {
this.similarityMethodEnum = similarityMethodEnum; SimilarityMethodEnum = similarityMethodEnum;
this.name = similarityMethodEnum.ToString(); Name = similarityMethodEnum.ToString();
SimilarityMethods.similarityMethodDelegate? probMethod = SimilarityMethods.GetMethod(name); SimilarityMethods.similarityMethodDelegate? probMethod = SimilarityMethods.GetMethod(Name) ?? throw new Exception($"Unable to retrieve similarityMethod {Name}");
if (probMethod is null) Method = probMethod;
{
logger.LogError("Unable to retrieve similarityMethod {name}", [name]);
throw new Exception("Unable to retrieve similarityMethod");
}
method = probMethod;
} }
} }
@@ -27,11 +22,11 @@ public static class SimilarityMethods
{ {
public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2); public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2);
public delegate float similarityMethodDelegate(float[] vector1, float[] vector2); public delegate float similarityMethodDelegate(float[] vector1, float[] vector2);
public static readonly Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate> probMethods; public static readonly Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate> ProbMethods;
static SimilarityMethods() static SimilarityMethods()
{ {
probMethods = new Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate> ProbMethods = new Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate>
{ {
[SimilarityMethodEnum.Cosine] = CosineSimilarity, [SimilarityMethodEnum.Cosine] = CosineSimilarity,
[SimilarityMethodEnum.Euclidian] = EuclidianDistance, [SimilarityMethodEnum.Euclidian] = EuclidianDistance,
@@ -49,7 +44,7 @@ public static class SimilarityMethods
methodName methodName
); );
if (!probMethods.TryGetValue(probMethodEnum, out similarityMethodProtoDelegate? method)) if (!ProbMethods.TryGetValue(probMethodEnum, out similarityMethodProtoDelegate? method))
{ {
return null; return null;
} }

View File

@@ -20,6 +20,23 @@ public class EntityQueryResult
public Dictionary<string, string>? Attributes { get; set; } public Dictionary<string, string>? Attributes { get; set; }
} }
public class EntityRerankResults : SuccesMessageBaseModel
{
[JsonPropertyName("Results")]
public required List<EntityRerankResult> Results { get; set; }
}
public class EntityRerankResult
{
[JsonPropertyName("Name")]
public required string Name { get; set; }
[JsonPropertyName("Value")]
public float Value { get; set; }
[JsonPropertyName("Attributes")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public Dictionary<string, string>? Attributes { get; set; }
}
public class EntityIndexResult : SuccesMessageBaseModel {} public class EntityIndexResult : SuccesMessageBaseModel {}
public class EntityListResults public class EntityListResults