Merge pull request #130 from LD-Reborn/129-post-entity-only-does-upserting

129 post entity only does upserting
This commit is contained in:
LD50
2026-02-22 20:00:08 +01:00
committed by GitHub
18 changed files with 487 additions and 277 deletions

View File

@@ -47,15 +47,27 @@ public class Client
return await FetchUrlAndProcessJson<EntityListResults>(HttpMethod.Get, url);
}
public async Task<EntityIndexResult> EntityIndexAsync(List<JSONEntity> jsonEntity)
public async Task<EntityIndexResult> EntityIndexAsync(List<JSONEntity> jsonEntity, string? sessionId = null, bool? sessionComplete = null)
{
return await EntityIndexAsync(JsonSerializer.Serialize(jsonEntity));
return await EntityIndexAsync(JsonSerializer.Serialize(jsonEntity), sessionId, sessionComplete);
}
public async Task<EntityIndexResult> EntityIndexAsync(string jsonEntity)
public async Task<EntityIndexResult> EntityIndexAsync(string jsonEntity, string? sessionId = null, bool? sessionComplete = null)
{
var content = new StringContent(jsonEntity, Encoding.UTF8, "application/json");
return await FetchUrlAndProcessJson<EntityIndexResult>(HttpMethod.Put, GetUrl($"{baseUri}", "Entities", []), content);
Dictionary<string, string> parameters = [];
if (sessionId is not null) parameters.Add("sessionId", sessionId);
if (sessionComplete is not null) parameters.Add("sessionComplete", ((bool)sessionComplete).ToString());
return await FetchUrlAndProcessJson<EntityIndexResult>(
HttpMethod.Put,
GetUrl(
$"{baseUri}",
$"Entities",
parameters
),
content
);
}
public async Task<EntityDeleteResults> EntityDeleteAsync(string entityName)

View File

@@ -65,6 +65,7 @@ def index_files(toolset: Toolset):
jsonEntities.append(jsonEntity)
jsonstring = json.dumps(jsonEntities)
timer_start = time.time()
# Index all entities in one go. If you need to split it into chunks, use the session attributes! See example_chunked.py
result:EntityIndexResult = toolset.Client.EntityIndexAsync(jsonstring).Result
timer_end = time.time()
toolset.Logger.LogInformation(f"Update was successful: {result.Success} - and was done in {timer_end - timer_start} seconds.")

View File

@@ -0,0 +1,85 @@
import math
import os
from tools import *
import json
from dataclasses import asdict
import time
import uuid
# Directory whose files index_files() reads and turns into entities.
example_content = "./Scripts/example_content"
# Probability-method name; it is part of the searchdomain name and the default
# for both the datapoint-level and the entity-level method below.
probmethod = "HVEWAvg"
# Similarity-method name handed to every JSONDatapoint.
similarityMethod = "Cosine"
# Searchdomain that init() creates when it is missing and that every entity is assigned to.
example_searchdomain = "example_" + probmethod
# Incremented and logged on every update() call.
example_counter = 0
# Model identifiers handed to every JSONDatapoint.
models = ["ollama:bge-m3", "ollama:mxbai-embed-large"]
# Method used for each JSONDatapoint.
probmethod_datapoint = probmethod
# Method used for each JSONEntity.
probmethod_entity = probmethod
# Example for a dictionary based weighted average:
# probmethod_datapoint = "DictionaryWeightedAverage:{\"ollama:bge-m3\": 4, \"ollama:mxbai-embed-large\": 1}"
# probmethod_entity = "DictionaryWeightedAverage:{\"title\": 2, \"filename\": 0.1, \"text\": 0.25}"
def init(toolset: Toolset):
    """Log start-up information and make sure the example searchdomain exists.

    The searchdomain list is fetched through ``toolset.Client``; when
    ``example_searchdomain`` is not in it, the searchdomain is created and the
    list is fetched again. The resulting list is written to the log.
    """
    global example_counter
    logger = toolset.Logger
    client = toolset.Client
    logger.LogInformation("{toolset.Name} - init", toolset.Name)
    logger.LogInformation("This is the init function from the python example script")
    logger.LogInformation(f"example_counter: {example_counter}")
    known_domains = client.SearchdomainListAsync().Result
    if example_searchdomain not in known_domains.Searchdomains:
        # Reading .Result waits for the create call before the list is refreshed.
        client.SearchdomainCreateAsync(example_searchdomain).Result
        known_domains = client.SearchdomainListAsync().Result
    listing = "".join(f" - {name}\n" for name in known_domains.Searchdomains)
    logger.LogInformation("Currently these searchdomains exist:\n" + listing)
def update(toolset: Toolset):
    """Log what triggered this run, bump the run counter and index the example files.

    The trigger is identified by the string form of ``toolset.CallbackInfos``;
    unknown values fall back to a generic message.
    """
    global example_counter
    logger = toolset.Logger
    logger.LogInformation("{toolset.Name} - update", toolset.Name)
    logger.LogInformation("This is the update function from the python example script")
    # Maps the string form of the callback info object to the message that is logged.
    trigger_messages = {
        "Indexer.Models.RunOnceCallbackInfos": "It was triggered by a runonce call",
        "Indexer.Models.IntervalCallbackInfos": "It was triggered by an interval call",
        "Indexer.Models.ScheduleCallbackInfos": "It was triggered by a schedule call",
        "Indexer.Models.FileUpdateCallbackInfos": "It was triggered by a fileupdate call",
    }
    trigger_name = str(toolset.CallbackInfos)
    if trigger_name in trigger_messages:
        logger.LogInformation(trigger_messages[trigger_name])
    else:
        logger.LogInformation("It was triggered, but the origin of the call could not be determined")
    example_counter += 1
    logger.LogInformation(f"example_counter: {example_counter}")
    index_files(toolset)
def index_files(toolset: Toolset):
    """Index every file in ``example_content`` in chunks that share one session.

    Each file becomes one entity with three datapoints: its path, its first
    line (``title``, as returned by ``readline`` and therefore including the
    trailing newline) and the rest of the file (``text``).

    The entities are sent in chunks of ``chunkSize`` under a freshly generated
    session id. Only the final chunk is sent with ``sessionComplete=True``,
    which tells the server to finalize the session.

    Nothing is sent when the directory contains no files.
    """
    jsonEntities: list = []
    for filename in os.listdir(example_content):
        qualified_filepath = example_content + "/" + filename
        with open(qualified_filepath, "r", encoding='utf-8', errors="replace") as file:
            title = file.readline()
            text = file.read()
        datapoints: list = [
            JSONDatapoint("filename", qualified_filepath, probmethod_datapoint, similarityMethod, models),
            JSONDatapoint("title", title, probmethod_datapoint, similarityMethod, models),
            JSONDatapoint("text", text, probmethod_datapoint, similarityMethod, models)
        ]
        jsonEntity: dict = asdict(JSONEntity(qualified_filepath, probmethod_entity, example_searchdomain, {}, datapoints))
        jsonEntities.append(jsonEntity)

    if not jsonEntities:
        # Without this guard no request is made and the summary log below
        # would fail because there is no result to report.
        toolset.Logger.LogInformation(f"No files found in {example_content} - nothing to index.")
        return

    timer_start = time.time()
    chunkSize = 5
    chunkCount = math.ceil(len(jsonEntities) / chunkSize)
    sessionId = uuid.uuid4().hex
    print(f"indexing {len(jsonEntities)} entities")
    all_succeeded = True
    # Count chunks from 1 so that the last chunk is exactly number chunkCount.
    # With the default 0-based enumerate the comparison below is never true and
    # the session would never be marked complete.
    for chunkNumber, entities in enumerate(chunk_list(jsonEntities, chunkSize), start=1):
        isLast = chunkNumber == chunkCount
        print(f'Processing chunk {chunkNumber} / {chunkCount}')
        jsonstring = json.dumps(entities)
        result: EntityIndexResult = toolset.Client.EntityIndexAsync(jsonstring, sessionId, isLast).Result
        # Report failure if any chunk failed, not just the last one.
        all_succeeded = all_succeeded and bool(result.Success)
    timer_end = time.time()
    toolset.Logger.LogInformation(f"Update was successful: {all_succeeded} - and was done in {timer_end - timer_start} seconds.")
def chunk_list(lst, chunk_size):
    """Lazily yield consecutive slices of ``lst`` holding at most ``chunk_size`` items.

    The last slice is shorter when ``len(lst)`` is not a multiple of
    ``chunk_size``; an empty ``lst`` yields nothing.
    """
    starts = range(0, len(lst), chunk_size)
    yield from (lst[start: start + chunk_size] for start in starts)

View File

@@ -107,6 +107,8 @@ class Client:
# pass
async def EntityIndexAsync(jsonEntity:str) -> EntityIndexResult:
    # Signature-only stub (the body is just `pass`) for indexing entities from a JSON string.
    # NOTE(review): presumably mirrors the .NET client method of the same name - confirm.
    pass
async def EntityIndexAsync(jsonEntity:str, sessionId:str, sessionComplete:bool) -> EntityIndexResult:
    # Signature-only stub (the body is just `pass`) for the session-aware variant:
    # jsonEntity is the JSON string of entities, sessionId names the upload session and
    # sessionComplete marks the final request of that session.
    # NOTE(review): presumably mirrors the .NET client overload with the same parameters - confirm.
    pass
async def EntityIndexAsync(searchdomain:str, jsonEntity:str) -> EntityIndexResult:
    # Signature-only stub (the body is just `pass`) taking a searchdomain name and a JSON string of entities.
    # NOTE(review): several stubs share this name; in Python the last definition in a scope
    # replaces the earlier ones, so only one signature is visible to tooling - confirm this is intended.
    pass
async def EntityListAsync(returnEmbeddings:bool = False) -> EntityListResults:

View File

@@ -13,7 +13,7 @@ public class AIProvider
{
private readonly ILogger<AIProvider> _logger;
private readonly EmbeddingSearchOptions _configuration;
public Dictionary<string, AiProvider> aIProvidersConfiguration;
public Dictionary<string, AiProvider> AiProvidersConfiguration;
public AIProvider(ILogger<AIProvider> logger, IOptions<EmbeddingSearchOptions> configuration)
{
@@ -27,7 +27,7 @@ public class AIProvider
}
else
{
aIProvidersConfiguration = retrievedAiProvidersConfiguration;
AiProvidersConfiguration = retrievedAiProvidersConfiguration;
}
}
@@ -41,7 +41,7 @@ public class AIProvider
Uri uri = new(modelUri);
string provider = uri.Scheme;
string model = uri.AbsolutePath;
AiProvider? aIProvider = aIProvidersConfiguration
AiProvider? aIProvider = AiProvidersConfiguration
.FirstOrDefault(x => string.Equals(x.Key.ToLower(), provider.ToLower()))
.Value;
if (aIProvider is null)
@@ -134,7 +134,7 @@ public class AIProvider
public string[] GetModels()
{
var aIProviders = aIProvidersConfiguration;
var aIProviders = AiProvidersConfiguration;
List<string> results = [];
foreach (KeyValuePair<string, AiProvider> aIProviderKV in aIProviders)
{

View File

@@ -14,6 +14,9 @@ public class EntityController : ControllerBase
private SearchdomainManager _domainManager;
private readonly SearchdomainHelper _searchdomainHelper;
private readonly DatabaseHelper _databaseHelper;
// Session state is static on purpose. ASP.NET Core creates a new controller instance for
// every request, so an instance-level dictionary would be empty again on the next chunk of
// a session and the accumulated entities would be lost.
// All access to _sessions must happen while holding _sessionLock.
// NOTE(review): entries now live until they expire; make sure finished sessions are removed
// once they have been finalized so that completed uploads do not stay in memory.
private static readonly Dictionary<string, EntityIndexSessionData> _sessions = [];
private static readonly object _sessionLock = new();
private const int SessionTimeoutMinutes = 60; // TODO: remove magic number; add an optional configuration option
public EntityController(ILogger<EntityController> logger, IConfiguration config, SearchdomainManager domainManager, SearchdomainHelper searchdomainHelper, DatabaseHelper databaseHelper)
{
@@ -46,34 +49,34 @@ public class EntityController : ControllerBase
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
EntityListResults entityListResults = new() {Results = [], Success = true};
foreach ((string _, Entity entity) in searchdomain_.entityCache)
foreach ((string _, Entity entity) in searchdomain_.EntityCache)
{
List<AttributeResult> attributeResults = [];
foreach (KeyValuePair<string, string> attribute in entity.attributes)
foreach (KeyValuePair<string, string> attribute in entity.Attributes)
{
attributeResults.Add(new AttributeResult() {Name = attribute.Key, Value = attribute.Value});
}
List<DatapointResult> datapointResults = [];
foreach (Datapoint datapoint in entity.datapoints)
foreach (Datapoint datapoint in entity.Datapoints)
{
if (returnModels)
{
List<EmbeddingResult> embeddingResults = [];
foreach ((string, float[]) embedding in datapoint.embeddings)
foreach ((string, float[]) embedding in datapoint.Embeddings)
{
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = returnEmbeddings ? embedding.Item2 : []});
}
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = embeddingResults});
datapointResults.Add(new DatapointResult() {Name = datapoint.Name, ProbMethod = datapoint.ProbMethod.Name, SimilarityMethod = datapoint.SimilarityMethod.Name, Embeddings = embeddingResults});
}
else
{
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = null});
datapointResults.Add(new DatapointResult() {Name = datapoint.Name, ProbMethod = datapoint.ProbMethod.Name, SimilarityMethod = datapoint.SimilarityMethod.Name, Embeddings = null});
}
}
EntityListResult entityListResult = new()
{
Name = entity.name,
ProbMethod = entity.probMethodName,
Name = entity.Name,
ProbMethod = entity.ProbMethodName,
Attributes = attributeResults,
Datapoints = datapointResults
};
@@ -86,31 +89,59 @@ public class EntityController : ControllerBase
/// Index entities
/// </summary>
/// <remarks>
/// Behavior: Creates new entities, but overwrites existing entities that have the same name
/// Behavior: Updates the index using the provided entities. Creates new entities, overwrites existing entities with the same name, and deletes entities that are not part of the index anymore.
///
/// Can be executed in a single request or in multiple chunks using a (self-defined) session UUID string.
///
/// For session-based chunk uploads:
/// - Provide sessionId to accumulate entities across multiple requests
/// - Set sessionComplete=true on the final request to finalize and delete entities that are not in the accumulated list
/// - Without sessionId: Missing entities will be deleted from the searchdomain.
/// - Sessions expire after 60 minutes of inactivity (currently a fixed value; a configuration option is not implemented yet)
/// </remarks>
/// <param name="jsonEntities">Entities to index</param>
/// <param name="sessionId">Optional session ID for batch uploads across multiple requests</param>
/// <param name="sessionComplete">If true, finalizes the session and deletes entities not in the accumulated list</param>
[HttpPut("/Entities")]
public async Task<ActionResult<EntityIndexResult>> Index([FromBody] List<JSONEntity>? jsonEntities)
public async Task<ActionResult<EntityIndexResult>> Index(
[FromBody] List<JSONEntity>? jsonEntities,
string? sessionId = null,
bool sessionComplete = false)
{
try
{
if (sessionId is null || string.IsNullOrWhiteSpace(sessionId))
{
sessionId = Guid.NewGuid().ToString(); // Create a short-lived session
sessionComplete = true; // If no sessionId was set, there is no trackable session. The pseudo-session ends here.
}
// Periodic cleanup of expired sessions
CleanupExpiredEntityIndexSessions();
EntityIndexSessionData session = GetOrCreateEntityIndexSession(sessionId);
if (jsonEntities is null && !sessionComplete)
{
return BadRequest(new EntityIndexResult() { Success = false, Message = "jsonEntities can only be null for a complete session" });
} else if (jsonEntities is null && sessionComplete)
{
await EntityIndexSessionDeleteUnindexedEntities(session);
return Ok(new EntityIndexResult() { Success = true });
}
// Standard entity indexing (upsert behavior)
List<Entity>? entities = await _searchdomainHelper.EntitiesFromJSON(
_domainManager,
_logger,
JsonSerializer.Serialize(jsonEntities));
if (entities is not null && jsonEntities is not null)
{
List<string> invalidatedSearchdomains = [];
foreach (var jsonEntity in jsonEntities)
session.AccumulatedEntities.AddRange(entities);
if (sessionComplete)
{
string jsonEntityName = jsonEntity.Name;
string jsonEntitySearchdomainName = jsonEntity.Searchdomain;
if (entities.Select(x => x.name == jsonEntityName).Any()
&& !invalidatedSearchdomains.Contains(jsonEntitySearchdomainName))
{
invalidatedSearchdomains.Add(jsonEntitySearchdomainName);
}
await EntityIndexSessionDeleteUnindexedEntities(session);
}
return Ok(new EntityIndexResult() { Success = true });
}
else
@@ -129,6 +160,44 @@ public class EntityController : ControllerBase
}
/// <summary>
/// For every searchdomain that appears in the session's accumulated entities, deletes the
/// cached entities whose names were not accumulated for that searchdomain.
/// </summary>
/// <param name="session">Session whose accumulated entities define what is kept.</param>
private async Task EntityIndexSessionDeleteUnindexedEntities(EntityIndexSessionData session)
{
    foreach (var domainGroup in session.AccumulatedEntities.GroupBy(e => e.Searchdomain))
    {
        string searchdomainName = domainGroup.Key;
        var keptNames = domainGroup.Select(e => e.Name).ToHashSet();

        (Searchdomain? searchdomain_, int? httpStatusCode, string? message) =
            SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomainName, _logger);

        // The lookup counts as failed when no searchdomain or an error status code is returned.
        if (searchdomain_ is null || httpStatusCode is not null)
        {
            _logger.LogWarning("Unable to delete entities for searchdomain {searchdomain}", searchdomainName);
            continue;
        }

        // Collect the stale entities first so the cache is not modified while it is enumerated.
        List<Entity> staleEntities = [];
        foreach (var cached in searchdomain_.EntityCache)
        {
            if (!keptNames.Contains(cached.Value.Name))
            {
                staleEntities.Add(cached.Value);
            }
        }

        foreach (Entity stale in staleEntities)
        {
            searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(stale);
            await _databaseHelper.RemoveEntity(
                [],
                _domainManager.Helper,
                stale.Name,
                searchdomainName);
            searchdomain_.EntityCache.TryRemove(stale.Name, out _);
            _logger.LogInformation("Deleted entity {entityName} from {searchdomain}", stale.Name, searchdomainName);
        }
    }
}
/// <summary>
/// Deletes an entity
/// </summary>
@@ -140,7 +209,7 @@ public class EntityController : ControllerBase
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
Entity? entity_ = SearchdomainHelper.CacheGetEntity(searchdomain_.entityCache, entityName);
Entity? entity_ = SearchdomainHelper.CacheGetEntity(searchdomain_.EntityCache, entityName);
if (entity_ is null)
{
_logger.LogError("Unable to delete the entity {entityName} in {searchdomain} - it was not found under the specified name", [entityName, searchdomain]);
@@ -152,10 +221,50 @@ public class EntityController : ControllerBase
return Ok(new EntityDeleteResults() {Success = false, Message = "Entity not found"});
}
searchdomain_.ReconciliateOrInvalidateCacheForDeletedEntity(entity_);
await _databaseHelper.RemoveEntity([], _domainManager.helper, entityName, searchdomain);
await _databaseHelper.RemoveEntity([], _domainManager.Helper, entityName, searchdomain);
bool success = searchdomain_.entityCache.TryRemove(entityName, out Entity? _);
bool success = searchdomain_.EntityCache.TryRemove(entityName, out Entity? _);
return Ok(new EntityDeleteResults() {Success = success});
}
/// <summary>
/// Removes every session whose last interaction is more than
/// <c>SessionTimeoutMinutes</c> minutes in the past and logs a warning for each of them.
/// </summary>
private void CleanupExpiredEntityIndexSessions()
{
    lock (_sessionLock)
    {
        // Gather the ids first; the dictionary must not be modified while it is enumerated.
        List<string> expiredIds = [];
        foreach (var entry in _sessions)
        {
            TimeSpan idleTime = DateTime.UtcNow - entry.Value.LastInteractionAt;
            if (idleTime.TotalMinutes > SessionTimeoutMinutes)
            {
                expiredIds.Add(entry.Key);
            }
        }

        foreach (string expiredId in expiredIds)
        {
            _sessions.Remove(expiredId);
            _logger.LogWarning("Removed expired, non-closed session {sessionId}", expiredId);
        }
    }
}
/// <summary>
/// Returns the session stored under <paramref name="sessionId"/>, refreshing its
/// last-interaction timestamp, or stores and returns a new session when none exists.
/// </summary>
/// <param name="sessionId">Key of the session to look up or create.</param>
/// <returns>The existing or newly created session data.</returns>
private EntityIndexSessionData GetOrCreateEntityIndexSession(string sessionId)
{
    lock (_sessionLock)
    {
        if (_sessions.TryGetValue(sessionId, out var existing))
        {
            existing.LastInteractionAt = DateTime.UtcNow;
            return existing;
        }

        EntityIndexSessionData created = new();
        _sessions[sessionId] = created;
        return created;
    }
}
}
/// <summary>
/// State kept for one entity index session.
/// </summary>
public class EntityIndexSessionData
{
    /// <summary>UTC timestamp of the latest interaction; starts at the time the object is created.</summary>
    public DateTime LastInteractionAt { get; set; } = DateTime.UtcNow;

    /// <summary>Entities collected for this session so far; starts empty.</summary>
    public List<Entity> AccumulatedEntities { get; set; } = new();
}

View File

@@ -118,18 +118,18 @@ public class SearchdomainController : ControllerBase
Dictionary<string, dynamic> parameters = new()
{
{"name", newName},
{"id", searchdomain_.id}
{"id", searchdomain_.Id}
};
await searchdomain_.helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name WHERE id = @id", parameters);
await searchdomain_.Helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name WHERE id = @id", parameters);
} else
{
Dictionary<string, dynamic> parameters = new()
{
{"name", newName},
{"settings", settings},
{"id", searchdomain_.id}
{"id", searchdomain_.Id}
};
await searchdomain_.helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name, settings = @settings WHERE id = @id", parameters);
await searchdomain_.Helper.ExecuteSQLNonQuery("UPDATE searchdomain set name = @name, settings = @settings WHERE id = @id", parameters);
}
return Ok(new SearchdomainUpdateResults(){Success = true});
}
@@ -143,7 +143,7 @@ public class SearchdomainController : ControllerBase
{
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
Dictionary<string, DateTimedSearchResult> searchCache = searchdomain_.queryCache.AsDictionary();
Dictionary<string, DateTimedSearchResult> searchCache = searchdomain_.QueryCache.AsDictionary();
return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
}
@@ -165,7 +165,7 @@ public class SearchdomainController : ControllerBase
{
Name = r.Item2,
Value = r.Item1,
Attributes = returnAttributes ? (searchdomain_.entityCache[r.Item2]?.attributes ?? null) : null
Attributes = returnAttributes ? (searchdomain_.EntityCache[r.Item2]?.Attributes ?? null) : null
})];
return Ok(new EntityQueryResults(){Results = queryResults, Success = true });
}
@@ -180,7 +180,7 @@ public class SearchdomainController : ControllerBase
{
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.queryCache;
EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.QueryCache;
bool containsKey = searchCache.ContainsKey(query);
if (containsKey)
{
@@ -201,7 +201,7 @@ public class SearchdomainController : ControllerBase
{
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.queryCache;
EnumerableLruCache<string, DateTimedSearchResult> searchCache = searchdomain_.QueryCache;
bool containsKey = searchCache.ContainsKey(query);
if (containsKey)
{
@@ -222,7 +222,7 @@ public class SearchdomainController : ControllerBase
{
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
SearchdomainSettings settings = searchdomain_.settings;
SearchdomainSettings settings = searchdomain_.Settings;
return Ok(new SearchdomainSettingsResults() { Settings = settings, Success = true });
}
@@ -239,11 +239,11 @@ public class SearchdomainController : ControllerBase
Dictionary<string, dynamic> parameters = new()
{
{"settings", JsonSerializer.Serialize(request)},
{"id", searchdomain_.id}
{"id", searchdomain_.Id}
};
await searchdomain_.helper.ExecuteSQLNonQuery("UPDATE searchdomain set settings = @settings WHERE id = @id", parameters);
searchdomain_.settings = request;
searchdomain_.queryCache.Capacity = request.QueryCacheSize;
await searchdomain_.Helper.ExecuteSQLNonQuery("UPDATE searchdomain set settings = @settings WHERE id = @id", parameters);
searchdomain_.Settings = request;
searchdomain_.QueryCache.Capacity = request.QueryCacheSize;
return Ok(new SearchdomainUpdateResults(){Success = true});
}
@@ -260,8 +260,8 @@ public class SearchdomainController : ControllerBase
}
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
int elementCount = searchdomain_.queryCache.Count;
int ElementMaxCount = searchdomain_.settings.QueryCacheSize;
int elementCount = searchdomain_.QueryCache.Count;
int ElementMaxCount = searchdomain_.Settings.QueryCacheSize;
return Ok(new SearchdomainQueryCacheSizeResults() { SizeBytes = searchdomain_.GetSearchCacheSize(), ElementCount = elementCount, ElementMaxCount = ElementMaxCount, Success = true });
}
@@ -287,7 +287,7 @@ public class SearchdomainController : ControllerBase
{
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
long EmbeddingCacheUtilization = DatabaseHelper.GetSearchdomainDatabaseSize(searchdomain_.helper, searchdomain);
long EmbeddingCacheUtilization = DatabaseHelper.GetSearchdomainDatabaseSize(searchdomain_.Helper, searchdomain);
return Ok(new SearchdomainGetDatabaseSizeResult() { SearchdomainDatabaseSizeBytes = EmbeddingCacheUtilization, Success = true });
}
}

View File

@@ -58,7 +58,7 @@ public class ServerController : ControllerBase
long size = 0;
long elementCount = 0;
long embeddingsCount = 0;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.embeddingCache;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = _searchdomainManager.EmbeddingCache;
foreach (KeyValuePair<string, Dictionary<string, float[]>> kv in embeddingCache)
{
@@ -68,7 +68,7 @@ public class ServerController : ControllerBase
elementCount++;
embeddingsCount += entry.Keys.Count;
}
var sqlHelper = _searchdomainManager.helper;
var sqlHelper = _searchdomainManager.Helper;
var databaseTotalSize = DatabaseHelper.GetTotalDatabaseSize(sqlHelper);
Task<long> entityCountTask = DatabaseHelper.CountEntities(sqlHelper);
long queryCacheUtilization = 0;
@@ -82,9 +82,9 @@ public class ServerController : ControllerBase
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_searchdomainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new ServerGetStatsResult(){Success = false, Message = message});
queryCacheUtilization += searchdomain_.GetSearchCacheSize();
queryCacheElementCount += searchdomain_.queryCache.Count;
queryCacheMaxElementCountAll += searchdomain_.queryCache.Capacity;
queryCacheMaxElementCountLoadedSearchdomainsOnly += searchdomain_.queryCache.Capacity;
queryCacheElementCount += searchdomain_.QueryCache.Count;
queryCacheMaxElementCountAll += searchdomain_.QueryCache.Capacity;
queryCacheMaxElementCountLoadedSearchdomainsOnly += searchdomain_.QueryCache.Capacity;
} else
{
var searchdomainSettings = DatabaseHelper.GetSearchdomainSettings(sqlHelper, searchdomain);

View File

@@ -6,36 +6,36 @@ namespace Server;
public class Datapoint
{
public string name;
public ProbMethod probMethod;
public SimilarityMethod similarityMethod;
public List<(string, float[])> embeddings;
public string hash;
public int id;
public string Name;
public ProbMethod ProbMethod;
public SimilarityMethod SimilarityMethod;
public List<(string, float[])> Embeddings;
public string Hash;
public int Id;
public Datapoint(string name, ProbMethodEnum probMethod, SimilarityMethodEnum similarityMethod, string hash, List<(string, float[])> embeddings, int id)
{
this.name = name;
this.probMethod = new ProbMethod(probMethod);
this.similarityMethod = new SimilarityMethod(similarityMethod);
this.hash = hash;
this.embeddings = embeddings;
this.id = id;
Name = name;
ProbMethod = new ProbMethod(probMethod);
SimilarityMethod = new SimilarityMethod(similarityMethod);
Hash = hash;
Embeddings = embeddings;
Id = id;
}
public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings, int id)
{
this.name = name;
this.probMethod = probMethod;
this.similarityMethod = similarityMethod;
this.hash = hash;
this.embeddings = embeddings;
this.id = id;
Name = name;
ProbMethod = probMethod;
SimilarityMethod = similarityMethod;
Hash = hash;
Embeddings = embeddings;
Id = id;
}
public float CalcProbability(List<(string, float)> probabilities)
{
return probMethod.method(probabilities);
return ProbMethod.Method(probabilities);
}
public static Dictionary<string, float[]> GetEmbeddings(string content, ConcurrentBag<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)

View File

@@ -2,12 +2,13 @@ using System.Collections.Concurrent;
namespace Server;
public class Entity(Dictionary<string, string> attributes, Probmethods.probMethodDelegate probMethod, string probMethodName, ConcurrentBag<Datapoint> datapoints, string name)
public class Entity(Dictionary<string, string> attributes, Probmethods.ProbMethodDelegate probMethod, string probMethodName, ConcurrentBag<Datapoint> datapoints, string name, string searchdomain)
{
public Dictionary<string, string> attributes = attributes;
public Probmethods.probMethodDelegate probMethod = probMethod;
public string probMethodName = probMethodName;
public ConcurrentBag<Datapoint> datapoints = datapoints;
public int id;
public string name = name;
public Dictionary<string, string> Attributes = attributes;
public Probmethods.ProbMethodDelegate ProbMethod = probMethod;
public string ProbMethodName = probMethodName;
public ConcurrentBag<Datapoint> Datapoints = datapoints;
public int Id;
public string Name = name;
public string Searchdomain = searchdomain;
}

View File

@@ -17,7 +17,7 @@ public class DatabaseHealthCheck : IHealthCheck
{
try
{
DatabaseMigrations.DatabaseGetVersion(_searchdomainManager.helper);
DatabaseMigrations.DatabaseGetVersion(_searchdomainManager.Helper);
}
catch (Exception ex)
{
@@ -28,8 +28,8 @@ public class DatabaseHealthCheck : IHealthCheck
try
{
await _searchdomainManager.helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES ('test', 'x');", []);
await _searchdomainManager.helper.ExecuteSQLNonQuery("DELETE FROM settings WHERE name = 'test';", []);
await _searchdomainManager.Helper.ExecuteSQLNonQuery("INSERT INTO settings (name, value) VALUES ('test', 'x');", []);
await _searchdomainManager.Helper.ExecuteSQLNonQuery("DELETE FROM settings WHERE name = 'test';", []);
}
catch (Exception ex)
{

View File

@@ -205,7 +205,7 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
await helper.ExecuteSQLNonQuery("DELETE datapoint.* FROM datapoint JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters);
await helper.ExecuteSQLNonQuery("DELETE attribute.* FROM attribute JOIN entity ON id_entity = entity.id WHERE entity.name = @name AND entity.id_searchdomain = @searchdomain", parameters);
await helper.ExecuteSQLNonQuery("DELETE FROM entity WHERE name = @name AND entity.id_searchdomain = @searchdomain", parameters);
entityCache.RemoveAll(entity => entity.name == name);
entityCache.RemoveAll(entity => entity.Name == name);
}
public async Task<int> RemoveAllEntities(SQLHelper helper, string searchdomain)
@@ -243,7 +243,7 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
{ "name", name },
{ "searchdomain", await GetSearchdomainID(helper, searchdomain)}
};
lock (helper.connection)
lock (helper.Connection)
{
DbDataReader reader = helper.ExecuteSQLCommand("SELECT COUNT(*) FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters);
try
@@ -273,7 +273,7 @@ public class DatabaseHelper(ILogger<DatabaseHelper> logger)
{ "name", name },
{ "searchdomain", await GetSearchdomainID(helper, searchdomain)}
};
lock (helper.connection)
lock (helper.Connection)
{
DbDataReader reader = helper.ExecuteSQLCommand("SELECT id FROM entity WHERE name = @name AND id_searchdomain = @searchdomain", parameters);
try

View File

@@ -6,18 +6,18 @@ namespace Server.Helper;
public class SQLHelper:IDisposable
{
public MySqlConnection connection;
public DbDataReader? dbDataReader;
public MySqlConnectionPoolElement[] connectionPool;
public string connectionString;
public MySqlConnection Connection;
public DbDataReader? DbDataReader;
public MySqlConnectionPoolElement[] ConnectionPool;
public string ConnectionString;
public SQLHelper(MySqlConnection connection, string connectionString)
{
this.connection = connection;
this.connectionString = connectionString;
connectionPool = new MySqlConnectionPoolElement[50];
for (int i = 0; i < connectionPool.Length; i++)
Connection = connection;
ConnectionString = connectionString;
ConnectionPool = new MySqlConnectionPoolElement[50];
for (int i = 0; i < ConnectionPool.Length; i++)
{
connectionPool[i] = new MySqlConnectionPoolElement(new MySqlConnection(connectionString), new(1, 1));
ConnectionPool[i] = new MySqlConnectionPoolElement(new MySqlConnection(connectionString), new(1, 1));
}
}
@@ -28,24 +28,24 @@ public class SQLHelper:IDisposable
public void Dispose()
{
connection.Close();
Connection.Close();
GC.SuppressFinalize(this);
}
public DbDataReader ExecuteSQLCommand(string query, Dictionary<string, dynamic> parameters)
{
lock (connection)
lock (Connection)
{
EnsureConnected();
EnsureDbReaderIsClosed();
using MySqlCommand command = connection.CreateCommand();
using MySqlCommand command = Connection.CreateCommand();
command.CommandText = query;
foreach (KeyValuePair<string, dynamic> parameter in parameters)
{
command.Parameters.AddWithValue($"@{parameter.Key}", parameter.Value);
}
dbDataReader = command.ExecuteReader();
return dbDataReader;
DbDataReader = command.ExecuteReader();
return DbDataReader;
}
}
@@ -55,7 +55,7 @@ public class SQLHelper:IDisposable
Func<DbDataReader, T> map)
{
var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.connection;
var connection = poolElement.Connection;
try
{
await using var command = connection.CreateCommand();
@@ -83,7 +83,7 @@ public class SQLHelper:IDisposable
public async Task<int> ExecuteSQLNonQuery(string query, Dictionary<string, dynamic> parameters)
{
var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.connection;
var connection = poolElement.Connection;
try
{
using MySqlCommand command = connection.CreateCommand();
@@ -103,7 +103,7 @@ public class SQLHelper:IDisposable
public async Task<int> ExecuteSQLCommandGetInsertedID(string query, Dictionary<string, dynamic> parameters)
{
var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.connection;
var connection = poolElement.Connection;
try
{
using MySqlCommand command = connection.CreateCommand();
@@ -125,7 +125,7 @@ public class SQLHelper:IDisposable
public async Task<int> BulkExecuteNonQuery(string sql, IEnumerable<object[]> parameterSets)
{
var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.connection;
var connection = poolElement.Connection;
try
{
int affectedRows = 0;
@@ -173,14 +173,14 @@ public class SQLHelper:IDisposable
int sleepTime = 10;
do
{
foreach (var element in connectionPool)
foreach (var element in ConnectionPool)
{
if (element.Semaphore.Wait(0))
{
if (element.connection.State == ConnectionState.Closed)
if (element.Connection.State == ConnectionState.Closed)
{
await element.connection.CloseAsync();
await element.connection.OpenAsync();
await element.Connection.CloseAsync();
await element.Connection.OpenAsync();
}
return element;
}
@@ -194,12 +194,12 @@ public class SQLHelper:IDisposable
public bool EnsureConnected()
{
if (connection.State != System.Data.ConnectionState.Open)
if (Connection.State != System.Data.ConnectionState.Open)
{
try
{
connection.Close();
connection.Open();
Connection.Close();
Connection.Open();
}
catch (Exception ex)
{
@@ -215,7 +215,7 @@ public class SQLHelper:IDisposable
int counter = 0;
int sleepTime = 10;
int timeout = 5000;
while (!(dbDataReader?.IsClosed ?? true))
while (!(DbDataReader?.IsClosed ?? true))
{
if (counter > timeout / sleepTime)
{
@@ -230,12 +230,12 @@ public class SQLHelper:IDisposable
// One slot of the SQLHelper connection pool: a connection paired with the semaphore
// that the pool acquires (Semaphore.Wait(0)) before handing the connection out.
public struct MySqlConnectionPoolElement
{
public MySqlConnection connection;
public MySqlConnection Connection;
public SemaphoreSlim Semaphore;
// Stores the given connection and semaphore as-is; nothing is opened or acquired here.
public MySqlConnectionPoolElement(MySqlConnection connection, SemaphoreSlim semaphore)
{
this.connection = connection;
this.Semaphore = semaphore;
Connection = connection;
Semaphore = semaphore;
}
}

View File

@@ -39,7 +39,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{
foreach ((string _, Entity entity) in entityCache)
{
if (entity.name == name)
if (entity.Name == name)
{
return entity;
}
@@ -49,9 +49,9 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
public async Task<List<Entity>?> EntitiesFromJSON(SearchdomainManager searchdomainManager, ILogger logger, string json)
{
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.embeddingCache;
AIProvider aIProvider = searchdomainManager.aIProvider;
SQLHelper helper = searchdomainManager.helper;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomainManager.EmbeddingCache;
AIProvider aIProvider = searchdomainManager.AiProvider;
SQLHelper helper = searchdomainManager.Helper;
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
if (jsonEntities is null)
@@ -65,7 +65,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
foreach (JSONEntity jSONEntity in jsonEntities)
{
Dictionary<string, List<string>> targetDictionary = toBeCached;
if (searchdomainManager.GetSearchdomain(jSONEntity.Searchdomain).settings.ParallelEmbeddingsPrefetch)
if (searchdomainManager.GetSearchdomain(jSONEntity.Searchdomain).Settings.ParallelEmbeddingsPrefetch)
{
targetDictionary = toBeCachedParallel;
}
@@ -126,12 +126,12 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{
var stopwatch = Stopwatch.StartNew();
SQLHelper helper = searchdomainManager.helper;
SQLHelper helper = searchdomainManager.Helper;
Searchdomain searchdomain = searchdomainManager.GetSearchdomain(jsonEntity.Searchdomain);
int id_searchdomain = searchdomain.id;
ConcurrentDictionary<string, Entity> entityCache = searchdomain.entityCache;
AIProvider aIProvider = searchdomain.aIProvider;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
int id_searchdomain = searchdomain.Id;
ConcurrentDictionary<string, Entity> entityCache = searchdomain.EntityCache;
AIProvider aIProvider = searchdomain.AiProvider;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.EmbeddingCache;
bool invalidateSearchCache = false;
@@ -140,15 +140,15 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
if (hasEntity && preexistingEntity is not null)
{
int preexistingEntityID = preexistingEntity.id;
int preexistingEntityID = preexistingEntity.Id;
Dictionary<string, string> attributes = jsonEntity.Attributes;
// Attribute - get changes
List<(string attribute, string newValue, int entityId)> updatedAttributes = new(preexistingEntity.attributes.Count);
List<(string attribute, int entityId)> deletedAttributes = new(preexistingEntity.attributes.Count);
List<(string attribute, string newValue, int entityId)> updatedAttributes = new(preexistingEntity.Attributes.Count);
List<(string attribute, int entityId)> deletedAttributes = new(preexistingEntity.Attributes.Count);
List<(string attributeKey, string attribute, int entityId)> addedAttributes = new(jsonEntity.Attributes.Count);
foreach (KeyValuePair<string, string> attributesKV in preexistingEntity.attributes) //.ToList())
foreach (KeyValuePair<string, string> attributesKV in preexistingEntity.Attributes) //.ToList())
{
string oldAttributeKey = attributesKV.Key;
string oldAttribute = attributesKV.Value;
@@ -166,7 +166,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{
string newAttributeKey = attributesKV.Key;
string newAttribute = attributesKV.Value;
bool preexistingHasAttribute = preexistingEntity.attributes.TryGetValue(newAttributeKey, out string? preexistingAttribute);
bool preexistingHasAttribute = preexistingEntity.Attributes.TryGetValue(newAttributeKey, out string? preexistingAttribute);
if (!preexistingHasAttribute)
{
// Attribute - New
@@ -175,54 +175,54 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
}
if (updatedAttributes.Count != 0 || deletedAttributes.Count != 0 || addedAttributes.Count != 0)
_logger.LogDebug("EntityFromJSON - Updating existing entity. name: {name}, updatedAttributes: {updatedAttributes}, deletedAttributes: {deletedAttributes}, addedAttributes: {addedAttributes}", [preexistingEntity.name, updatedAttributes.Count, deletedAttributes.Count, addedAttributes.Count]);
_logger.LogDebug("EntityFromJSON - Updating existing entity. name: {name}, updatedAttributes: {updatedAttributes}, deletedAttributes: {deletedAttributes}, addedAttributes: {addedAttributes}", [preexistingEntity.Name, updatedAttributes.Count, deletedAttributes.Count, addedAttributes.Count]);
// Attribute - apply changes
if (updatedAttributes.Count != 0)
{
// Update
await DatabaseHelper.DatabaseUpdateAttributes(helper, updatedAttributes);
lock (preexistingEntity.attributes)
lock (preexistingEntity.Attributes)
{
updatedAttributes.ForEach(attribute => preexistingEntity.attributes[attribute.attribute] = attribute.newValue);
updatedAttributes.ForEach(attribute => preexistingEntity.Attributes[attribute.attribute] = attribute.newValue);
}
}
if (deletedAttributes.Count != 0)
{
// Delete
await DatabaseHelper.DatabaseDeleteAttributes(helper, deletedAttributes);
lock (preexistingEntity.attributes)
lock (preexistingEntity.Attributes)
{
deletedAttributes.ForEach(attribute => preexistingEntity.attributes.Remove(attribute.attribute));
deletedAttributes.ForEach(attribute => preexistingEntity.Attributes.Remove(attribute.attribute));
}
}
if (addedAttributes.Count != 0)
{
// Insert
await DatabaseHelper.DatabaseInsertAttributes(helper, addedAttributes);
lock (preexistingEntity.attributes)
lock (preexistingEntity.Attributes)
{
addedAttributes.ForEach(attribute => preexistingEntity.attributes.Add(attribute.attributeKey, attribute.attribute));
addedAttributes.ForEach(attribute => preexistingEntity.Attributes.Add(attribute.attributeKey, attribute.attribute));
}
}
// Datapoint - get changes
List<Datapoint> deletedDatapointInstances = new(preexistingEntity.datapoints.Count);
List<string> deletedDatapoints = new(preexistingEntity.datapoints.Count);
List<(string datapointName, int entityId, JSONDatapoint jsonDatapoint, string hash)> updatedDatapointsText = new(preexistingEntity.datapoints.Count);
List<(string datapointName, string probMethod, string similarityMethod, int entityId, JSONDatapoint jsonDatapoint)> updatedDatapointsNonText = new(preexistingEntity.datapoints.Count);
List<Datapoint> deletedDatapointInstances = new(preexistingEntity.Datapoints.Count);
List<string> deletedDatapoints = new(preexistingEntity.Datapoints.Count);
List<(string datapointName, int entityId, JSONDatapoint jsonDatapoint, string hash)> updatedDatapointsText = new(preexistingEntity.Datapoints.Count);
List<(string datapointName, string probMethod, string similarityMethod, int entityId, JSONDatapoint jsonDatapoint)> updatedDatapointsNonText = new(preexistingEntity.Datapoints.Count);
List<Datapoint> createdDatapointInstances = [];
List<(string name, ProbMethodEnum probmethod_embedding, SimilarityMethodEnum similarityMethod, string hash, Dictionary<string, float[]> embeddings, JSONDatapoint datapoint)> createdDatapoints = new(jsonEntity.Datapoints.Length);
foreach (Datapoint datapoint_ in preexistingEntity.datapoints.ToList())
foreach (Datapoint datapoint_ in preexistingEntity.Datapoints.ToList())
{
Datapoint datapoint = datapoint_; // To enable replacing the datapoint reference as foreach iterators cannot be overwritten
JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.name);
JSONDatapoint? newEntityDatapoint = jsonEntity.Datapoints.FirstOrDefault(x => x.Name == datapoint.Name);
bool newEntityHasDatapoint = newEntityDatapoint is not null;
if (!newEntityHasDatapoint)
{
// Datapoint - Deleted
deletedDatapointInstances.Add(datapoint);
deletedDatapoints.Add(datapoint.name);
deletedDatapoints.Add(datapoint.Name);
invalidateSearchCache = true;
} else
{
@@ -231,7 +231,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
newEntityDatapoint is not null
&& newEntityDatapoint.Text is not null
&& hash is not null
&& hash != datapoint.hash)
&& hash != datapoint.Hash)
{
// Datapoint - Updated (text)
updatedDatapointsText.Add(new()
@@ -245,8 +245,8 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
}
if (
newEntityDatapoint is not null
&& (newEntityDatapoint.Probmethod_embedding != datapoint.probMethod.probMethodEnum
|| newEntityDatapoint.SimilarityMethod != datapoint.similarityMethod.similarityMethodEnum))
&& (newEntityDatapoint.Probmethod_embedding != datapoint.ProbMethod.ProbMethodEnum
|| newEntityDatapoint.SimilarityMethod != datapoint.SimilarityMethod.SimilarityMethodEnum))
{
// Datapoint - Updated (probmethod or similaritymethod)
updatedDatapointsNonText.Add(new()
@@ -264,7 +264,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints)
{
bool oldEntityHasDatapoint = preexistingEntity.datapoints.Any(x => x.name == jsonDatapoint.Name);
bool oldEntityHasDatapoint = preexistingEntity.Datapoints.Any(x => x.Name == jsonDatapoint.Name);
if (!oldEntityHasDatapoint)
{
// Datapoint - New
@@ -290,13 +290,13 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
if (deletedDatapointInstances.Count != 0 || createdDatapoints.Count != 0 || addedAttributes.Count != 0 || updatedDatapointsNonText.Count != 0)
_logger.LogDebug(
"EntityFromJSON - Updating existing entity. name: {name}, deletedDatapointInstances: {deletedDatapointInstances}, createdDatapoints: {createdDatapoints}, addedAttributes: {addedAttributes}, updatedDatapointsNonText: {updatedDatapointsNonText}",
[preexistingEntity.name, deletedDatapointInstances.Count, createdDatapoints.Count, addedAttributes.Count, updatedDatapointsNonText.Count]);
[preexistingEntity.Name, deletedDatapointInstances.Count, createdDatapoints.Count, addedAttributes.Count, updatedDatapointsNonText.Count]);
// Datapoint - apply changes
// Deleted
if (deletedDatapointInstances.Count != 0)
{
await DatabaseHelper.DatabaseDeleteEmbeddingsAndDatapoints(helper, deletedDatapoints, (int)preexistingEntityID);
preexistingEntity.datapoints = [.. preexistingEntity.datapoints
preexistingEntity.Datapoints = [.. preexistingEntity.Datapoints
.Where(x =>
!deletedDatapointInstances.Contains(x)
)
@@ -306,7 +306,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
if (createdDatapoints.Count != 0)
{
List<Datapoint> datapoint = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. createdDatapoints.Select(element => (element.datapoint, element.hash))], (int)preexistingEntityID, id_searchdomain);
datapoint.ForEach(x => preexistingEntity.datapoints.Add(x));
datapoint.ForEach(x => preexistingEntity.Datapoints.Add(x));
}
// Datapoint - Updated (text)
if (updatedDatapointsText.Count != 0)
@@ -317,14 +317,14 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
.Select(d => d.datapointName)
.ToHashSet();
var newBag = new ConcurrentBag<Datapoint>(
preexistingEntity.datapoints
.Where(x => !namesToRemove.Contains(x.name))
preexistingEntity.Datapoints
.Where(x => !namesToRemove.Contains(x.Name))
);
preexistingEntity.datapoints = newBag;
preexistingEntity.Datapoints = newBag;
// Insert into database
List<Datapoint> datapoints = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, [.. updatedDatapointsText.Select(element => (datapoint: element.jsonDatapoint, hash: element.hash))], (int)preexistingEntityID, id_searchdomain);
// Insert into datapoints
datapoints.ForEach(datapoint => preexistingEntity.datapoints.Add(datapoint));
datapoints.ForEach(datapoint => preexistingEntity.Datapoints.Add(datapoint));
}
// Datapoint - Updated (probmethod or similaritymethod)
if (updatedDatapointsNonText.Count != 0)
@@ -336,9 +336,9 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
);
updatedDatapointsNonText.ForEach(element =>
{
Datapoint preexistingDatapoint = preexistingEntity.datapoints.First(x => x.name == element.datapointName);
preexistingDatapoint.probMethod = new(element.jsonDatapoint.Probmethod_embedding);
preexistingDatapoint.similarityMethod = new(element.jsonDatapoint.SimilarityMethod);
Datapoint preexistingDatapoint = preexistingEntity.Datapoints.First(x => x.Name == element.datapointName);
preexistingDatapoint.ProbMethod = new(element.jsonDatapoint.Probmethod_embedding);
preexistingDatapoint.SimilarityMethod = new(element.jsonDatapoint.SimilarityMethod);
});
}
@@ -368,7 +368,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
var insertAttributesTask = DatabaseHelper.DatabaseInsertAttributes(helper, toBeInsertedAttributes);
List<(JSONDatapoint datapoint, string hash)> toBeInsertedDatapoints = [];
ConcurrentBag<string> usedModels = searchdomain.modelsInUse;
ConcurrentBag<string> usedModels = searchdomain.ModelsInUse;
foreach (JSONDatapoint jsonDatapoint in jsonEntity.Datapoints)
{
string hash = Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
@@ -389,9 +389,9 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
List<Datapoint> datapoints = await DatabaseInsertDatapointsWithEmbeddings(helper, searchdomain, toBeInsertedDatapoints, id_entity, id_searchdomain);
var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new ProbMethodNotFoundException(jsonEntity.Probmethod);
Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), new(datapoints), jsonEntity.Name)
Entity entity = new(jsonEntity.Attributes, probMethod, jsonEntity.Probmethod.ToString(), [.. datapoints], jsonEntity.Name, jsonEntity.Searchdomain)
{
id = id_entity
Id = id_entity
};
entityCache[jsonEntity.Name] = entity;
@@ -412,16 +412,16 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
toBeInsertedDatapoints.Add(new()
{
name = datapoint.name,
probmethod_embedding = datapoint.probMethod.probMethodEnum,
similarityMethod = datapoint.similarityMethod.similarityMethodEnum,
name = datapoint.Name,
probmethod_embedding = datapoint.ProbMethod.ProbMethodEnum,
similarityMethod = datapoint.SimilarityMethod.SimilarityMethodEnum,
hash = value.hash
});
foreach ((string, float[]) embedding in datapoint.embeddings)
foreach ((string, float[]) embedding in datapoint.Embeddings)
{
toBeInsertedEmbeddings.Add(new()
{
id_datapoint = datapoint.id,
id_datapoint = datapoint.Id,
model = embedding.Item1,
embedding = BytesFromFloatArray(embedding.Item2)
});
@@ -444,7 +444,7 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
Datapoint datapoint = await BuildDatapointFromJsonDatapoint(jsonDatapoint, id_entity, searchdomain, hash);
int id_datapoint = await DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
List<(string model, byte[] embedding)> data = [];
foreach ((string, float[]) embedding in datapoint.embeddings)
foreach ((string, float[]) embedding in datapoint.Embeddings)
{
data.Add((embedding.Item1, BytesFromFloatArray(embedding.Item2)));
}
@@ -463,11 +463,11 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
{
throw new Exception("jsonDatapoint.Text must not be null at this point");
}
SQLHelper helper = searchdomain.helper;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.embeddingCache;
SQLHelper helper = searchdomain.Helper;
EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = searchdomain.EmbeddingCache;
hash ??= Convert.ToBase64String(SHA256.HashData(Encoding.UTF8.GetBytes(jsonDatapoint.Text)));
int id = await DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, entityId);
Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.aIProvider, embeddingCache);
Dictionary<string, float[]> embeddings = Datapoint.GetEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], searchdomain.AiProvider, embeddingCache);
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding) ?? throw new ProbMethodNotFoundException(jsonDatapoint.Probmethod_embedding);
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod) ?? throw new SimilarityMethodNotFoundException(jsonDatapoint.SimilarityMethod);
return new Datapoint(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))], id);

View File

@@ -6,29 +6,29 @@ namespace Server;
// Bundles a ProbMethodEnum value with its string name and the delegate that
// Probmethods.GetMethod resolves for that name.
public class ProbMethod
{
public Probmethods.probMethodDelegate method;
public ProbMethodEnum probMethodEnum;
public string name;
public Probmethods.ProbMethodDelegate Method;
public ProbMethodEnum ProbMethodEnum;
public string Name;
// Resolves the delegate for the given enum value by its ToString() name.
// Throws ProbMethodNotFoundException when Probmethods.GetMethod returns null.
public ProbMethod(ProbMethodEnum probMethodEnum)
{
this.probMethodEnum = probMethodEnum;
this.name = probMethodEnum.ToString();
Probmethods.probMethodDelegate? probMethod = Probmethods.GetMethod(name) ?? throw new ProbMethodNotFoundException(probMethodEnum);
method = probMethod;
this.ProbMethodEnum = probMethodEnum;
this.Name = probMethodEnum.ToString();
Probmethods.ProbMethodDelegate? probMethod = Probmethods.GetMethod(Name) ?? throw new ProbMethodNotFoundException(probMethodEnum);
Method = probMethod;
}
}
public static class Probmethods
{
public delegate float probMethodProtoDelegate(List<(string, float)> list, string parameters);
public delegate float probMethodDelegate(List<(string, float)> list);
public static readonly Dictionary<ProbMethodEnum, probMethodProtoDelegate> probMethods;
public delegate float ProbMethodProtoDelegate(List<(string, float)> list, string parameters);
public delegate float ProbMethodDelegate(List<(string, float)> list);
public static readonly Dictionary<ProbMethodEnum, ProbMethodProtoDelegate> ProbMethods;
static Probmethods()
{
probMethods = new Dictionary<ProbMethodEnum, probMethodProtoDelegate>
ProbMethods = new Dictionary<ProbMethodEnum, ProbMethodProtoDelegate>
{
[ProbMethodEnum.Mean] = Mean,
[ProbMethodEnum.HarmonicMean] = HarmonicMean,
@@ -41,12 +41,12 @@ public static class Probmethods
};
}
// Enum overload: forwards to GetMethod(string) using the enum value's ToString() name
// and returns whatever that overload returns (nullable).
public static probMethodDelegate? GetMethod(ProbMethodEnum probMethodEnum)
public static ProbMethodDelegate? GetMethod(ProbMethodEnum probMethodEnum)
{
return GetMethod(probMethodEnum.ToString());
}
public static probMethodDelegate? GetMethod(string name)
public static ProbMethodDelegate? GetMethod(string name)
{
string methodName = name;
string? jsonArg = "";
@@ -63,7 +63,7 @@ public static class Probmethods
methodName
);
if (!probMethods.TryGetValue(probMethodEnum, out probMethodProtoDelegate? method))
if (!ProbMethods.TryGetValue(probMethodEnum, out ProbMethodProtoDelegate? method))
{
return null;
}

View File

@@ -15,33 +15,33 @@ public class Searchdomain
{
private readonly string _connectionString;
private readonly string _provider;
public AIProvider aIProvider;
public string searchdomain;
public int id;
public SearchdomainSettings settings;
public EnumerableLruCache<string, DateTimedSearchResult> queryCache; // Key: query, Value: Search results for that query (with timestamp)
public ConcurrentDictionary<string, Entity> entityCache;
public ConcurrentBag<string> modelsInUse;
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
public SQLHelper helper;
public AIProvider AiProvider;
public string SearchdomainName;
public int Id;
public SearchdomainSettings Settings;
public EnumerableLruCache<string, DateTimedSearchResult> QueryCache; // Key: query, Value: Search results for that query (with timestamp)
public ConcurrentDictionary<string, Entity> EntityCache;
public ConcurrentBag<string> ModelsInUse;
public EnumerableLruCache<string, Dictionary<string, float[]>> EmbeddingCache;
public SQLHelper Helper;
private readonly ILogger _logger;
// Builds a searchdomain object around shared infrastructure (SQL helper, AI provider,
// embedding cache, logger) and loads its settings.
//   searchdomain - name of the searchdomain; used by GetSettings()/GetID() lookups.
//   provider     - stored lower-cased; defaults to "sqlserver".
//   runEmpty     - when true, the database id lookup and UpdateEntityCache() are skipped.
public Searchdomain(string searchdomain, string connectionString, SQLHelper sqlHelper, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
{
_connectionString = connectionString;
_provider = provider.ToLower();
this.searchdomain = searchdomain;
this.aIProvider = aIProvider;
this.embeddingCache = embeddingCache;
this.SearchdomainName = searchdomain;
this.AiProvider = aIProvider;
this.EmbeddingCache = embeddingCache;
this._logger = logger;
// The helper and the name must be assigned before GetSettings() is called, because
// GetSettings() passes both to DatabaseHelper; the query cache is then sized from
// the loaded settings.
entityCache = [];
helper = sqlHelper;
settings = GetSettings();
queryCache = new(settings.QueryCacheSize);
modelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that...
EntityCache = [];
Helper = sqlHelper;
Settings = GetSettings();
QueryCache = new(Settings.QueryCacheSize);
ModelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that...
if (!runEmpty)
{
// .Result blocks the constructor until the async id lookup has completed.
id = GetID().Result;
Id = GetID().Result;
UpdateEntityCache();
}
}
@@ -51,9 +51,9 @@ public class Searchdomain
InvalidateSearchCache();
Dictionary<string, dynamic> parametersIDSearchdomain = new()
{
["id"] = this.id
["id"] = this.Id
};
DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT id, id_datapoint, model, embedding FROM embedding WHERE id_searchdomain = @id", parametersIDSearchdomain);
DbDataReader embeddingReader = Helper.ExecuteSQLCommand("SELECT id, id_datapoint, model, embedding FROM embedding WHERE id_searchdomain = @id", parametersIDSearchdomain);
Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = [];
try
{
@@ -90,7 +90,7 @@ public class Searchdomain
embeddingReader.Close();
}
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT d.id, d.id_entity, d.name, d.probmethod_embedding, d.similaritymethod, d.hash FROM datapoint d JOIN entity ent ON d.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain);
DbDataReader datapointReader = Helper.ExecuteSQLCommand("SELECT d.id, d.id_entity, d.name, d.probmethod_embedding, d.similaritymethod, d.hash FROM datapoint d JOIN entity ent ON d.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain);
Dictionary<int, ConcurrentBag<Datapoint>> datapoint_unassigned = [];
try
{
@@ -127,7 +127,7 @@ public class Searchdomain
datapointReader.Close();
}
DbDataReader attributeReader = helper.ExecuteSQLCommand("SELECT a.id, a.id_entity, a.attribute, a.value FROM attribute a JOIN entity ent ON a.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain);
DbDataReader attributeReader = Helper.ExecuteSQLCommand("SELECT a.id, a.id_entity, a.attribute, a.value FROM attribute a JOIN entity ent ON a.id_entity = ent.id JOIN searchdomain s ON ent.id_searchdomain = s.id WHERE s.id = @id", parametersIDSearchdomain);
Dictionary<int, Dictionary<string, string>> attributes_unassigned = [];
try
{
@@ -149,8 +149,8 @@ public class Searchdomain
attributeReader.Close();
}
entityCache = [];
DbDataReader entityReader = helper.ExecuteSQLCommand("SELECT entity.id, name, probmethod FROM entity WHERE id_searchdomain=@id", parametersIDSearchdomain);
EntityCache = [];
DbDataReader entityReader = Helper.ExecuteSQLCommand("SELECT entity.id, name, probmethod FROM entity WHERE id_searchdomain=@id", parametersIDSearchdomain);
try
{
while (entityReader.Read())
@@ -163,26 +163,26 @@ public class Searchdomain
{
attributes = [];
}
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
Probmethods.ProbMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
if (datapoint_unassigned.TryGetValue(id, out ConcurrentBag<Datapoint>? datapoints) && probmethod is not null)
{
Entity entity = new(attributes, probmethod, probmethodString, datapoints, name)
Entity entity = new(attributes, probmethod, probmethodString, datapoints, name, SearchdomainName)
{
id = id
Id = id
};
entityCache[name] = entity;
EntityCache[name] = entity;
}
}
} finally
{
entityReader.Close();
}
modelsInUse = GetModels(entityCache);
ModelsInUse = GetModels(EntityCache);
}
public List<(float, string)> Search(string query, int? topN = null)
{
if (queryCache.TryGetValue(query, out DateTimedSearchResult cachedResult))
if (QueryCache.TryGetValue(query, out DateTimedSearchResult cachedResult))
{
cachedResult.AccessDateTimes.Add(DateTime.Now);
return [.. cachedResult.Results.Select(r => (r.Score, r.Name))];
@@ -191,9 +191,9 @@ public class Searchdomain
Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query);
List<(float, string)> result = [];
foreach ((string name, Entity entity) in entityCache)
foreach ((string name, Entity entity) in EntityCache)
{
result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.name));
result.Add((EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings), entity.Name));
}
IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1);
if (topN is not null)
@@ -205,26 +205,26 @@ public class Searchdomain
[.. sortedResults.Select(r =>
new ResultItem(r.Item1, r.Item2 ))]
);
queryCache.Set(query, new DateTimedSearchResult(DateTime.Now, searchResult));
QueryCache.Set(query, new DateTimedSearchResult(DateTime.Now, searchResult));
return results;
}
public Dictionary<string, float[]> GetQueryEmbeddings(string query)
{
bool hasQuery = embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings);
bool allModelsInQuery = queryEmbeddings is not null && modelsInUse.All(model => queryEmbeddings.ContainsKey(model));
bool hasQuery = EmbeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings);
bool allModelsInQuery = queryEmbeddings is not null && ModelsInUse.All(model => queryEmbeddings.ContainsKey(model));
if (!(hasQuery && allModelsInQuery) || queryEmbeddings is null)
{
queryEmbeddings = Datapoint.GetEmbeddings(query, modelsInUse, aIProvider, embeddingCache);
if (!embeddingCache.TryGetValue(query, out var embeddingCacheForCurrentQuery))
queryEmbeddings = Datapoint.GetEmbeddings(query, ModelsInUse, AiProvider, EmbeddingCache);
if (!EmbeddingCache.TryGetValue(query, out var embeddingCacheForCurrentQuery))
{
embeddingCache.Set(query, queryEmbeddings);
EmbeddingCache.Set(query, queryEmbeddings);
}
else // embeddingCache already has an entry for this query, so the missing model-embedding pairs have to be filled in
{
foreach (KeyValuePair<string, float[]> kvp in queryEmbeddings) // kvp.Key = model, kvp.Value = embedding
{
if (!embeddingCache.TryGetValue(kvp.Key, out var _))
if (!EmbeddingCache.TryGetValue(kvp.Key, out var _))
{
embeddingCacheForCurrentQuery[kvp.Key] = kvp.Value;
}
@@ -236,25 +236,25 @@ public class Searchdomain
// Recomputes the collection of models in use from the current entity cache
// by delegating to GetModels.
public void UpdateModelsInUse()
{
modelsInUse = GetModels(entityCache);
ModelsInUse = GetModels(EntityCache);
}
// Scores one entity against a query in two aggregation steps:
//   1. per datapoint: compare every stored (model, embedding) pair with the query's
//      embedding for the same model using the datapoint's similarity method, then
//      combine those per-model values with the datapoint's prob method;
//   2. per entity: combine the per-datapoint values with the entity's prob method.
// queryEmbeddings is indexed directly by model name, so it must contain an entry for
// every model that appears in the entity's datapoints.
private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary<string, float[]> queryEmbeddings)
{
List<(string, float)> datapointProbs = [];
foreach (Datapoint datapoint in entity.datapoints)
foreach (Datapoint datapoint in entity.Datapoints)
{
SimilarityMethod similarityMethod = datapoint.similarityMethod;
SimilarityMethod similarityMethod = datapoint.SimilarityMethod;
List<(string, float)> list = [];
foreach ((string, float[]) embedding in datapoint.embeddings)
foreach ((string, float[]) embedding in datapoint.Embeddings)
{
// Item1 = model name, Item2 = stored embedding vector for that model.
string key = embedding.Item1;
float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2);
float value = similarityMethod.Method(queryEmbeddings[embedding.Item1], embedding.Item2);
list.Add((key, value));
}
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
datapointProbs.Add((datapoint.Name, datapoint.ProbMethod.Method(list)));
}
return entity.probMethod(datapointProbs);
return entity.ProbMethod(datapointProbs);
}
public static ConcurrentBag<string> GetModels(ConcurrentDictionary<string, Entity> entities)
@@ -265,9 +265,9 @@ public class Searchdomain
Entity entity = element.Value;
lock (entity)
{
foreach (Datapoint datapoint in entity.datapoints)
foreach (Datapoint datapoint in entity.Datapoints)
{
foreach ((string, float[]) tuple in datapoint.embeddings)
foreach ((string, float[]) tuple in datapoint.Embeddings)
{
string model = tuple.Item1;
if (!result.Contains(model))
@@ -285,21 +285,21 @@ public class Searchdomain
{
Dictionary<string, object?> parameters = new()
{
{ "name", this.searchdomain }
{ "name", this.SearchdomainName }
};
return (await helper.ExecuteQueryAsync("SELECT id from searchdomain WHERE name = @name", parameters, x => x.GetInt32(0))).First();
return (await Helper.ExecuteQueryAsync("SELECT id from searchdomain WHERE name = @name", parameters, x => x.GetInt32(0))).First();
}
// Returns the settings for this searchdomain as provided by
// DatabaseHelper.GetSearchdomainSettings for the current SQL helper and domain name.
public SearchdomainSettings GetSettings()
{
return DatabaseHelper.GetSearchdomainSettings(helper, searchdomain);
return DatabaseHelper.GetSearchdomainSettings(Helper, SearchdomainName);
}
public void ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(Entity entity)
{
if (settings.CacheReconciliation)
if (Settings.CacheReconciliation)
{
foreach (var element in queryCache)
foreach (var element in QueryCache)
{
string query = element.Key;
DateTimedSearchResult searchResult = element.Value;
@@ -307,9 +307,9 @@ public class Searchdomain
Dictionary<string, float[]> queryEmbeddings = GetQueryEmbeddings(query);
float evaluationResult = EvaluateEntityAgainstQueryEmbeddings(entity, queryEmbeddings);
searchResult.Results.RemoveAll(x => x.Name == entity.name); // If entity already exists in that results list: remove it.
searchResult.Results.RemoveAll(x => x.Name == entity.Name); // If entity already exists in that results list: remove it.
ResultItem newItem = new(evaluationResult, entity.name);
ResultItem newItem = new(evaluationResult, entity.Name);
int index = searchResult.Results.BinarySearch(
newItem,
Comparer<ResultItem>.Create((a, b) => b.Score.CompareTo(a.Score)) // Invert searching order
@@ -327,13 +327,13 @@ public class Searchdomain
public void ReconciliateOrInvalidateCacheForDeletedEntity(Entity entity)
{
if (settings.CacheReconciliation)
if (Settings.CacheReconciliation)
{
foreach (KeyValuePair<string, DateTimedSearchResult> element in queryCache)
foreach (KeyValuePair<string, DateTimedSearchResult> element in QueryCache)
{
string query = element.Key;
DateTimedSearchResult searchResult = element.Value;
searchResult.Results.RemoveAll(x => x.Name == entity.name);
searchResult.Results.RemoveAll(x => x.Name == entity.Name);
}
}
else
@@ -344,13 +344,13 @@ public class Searchdomain
/// <summary>
/// Discards the current query cache by replacing it with a newly constructed instance
/// that is sized from Settings.QueryCacheSize.
/// </summary>
// NOTE(review): rendered diff - the first assignment is the pre-rename line, the second its replacement.
public void InvalidateSearchCache()
{
queryCache = new(settings.QueryCacheSize);
QueryCache = new(Settings.QueryCacheSize);
}
public long GetSearchCacheSize()
{
long EmbeddingCacheUtilization = 0;
foreach (var entry in queryCache)
foreach (var entry in QueryCache)
{
EmbeddingCacheUtilization += sizeof(int); // string length prefix
EmbeddingCacheUtilization += entry.Key.Length * sizeof(char); // string characters

View File

@@ -15,50 +15,50 @@ namespace Server;
public class SearchdomainManager : IDisposable
{
// NOTE(review): rendered diff - renamed fields appear twice (old camelCase name, then the new
// PascalCase / _underscore name). For connection/helper/embeddingCache the three old lines come
// first, followed by the three new lines.

// Searchdomain instances held in memory, keyed by searchdomain name.
private Dictionary<string, Searchdomain> searchdomains = [];
private Dictionary<string, Searchdomain> _searchdomains = [];
private readonly ILogger<SearchdomainManager> _logger;
private readonly EmbeddingSearchOptions _options;
public readonly AIProvider aIProvider;
public readonly AIProvider AiProvider;
private readonly DatabaseHelper _databaseHelper;
// SQL connection string, read from the options in the constructor.
private readonly string connectionString;
private MySqlConnection connection;
public SQLHelper helper;
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
// Connection created and opened in the constructor; it is handed to the SQL helper.
private MySqlConnection _connection;
public SQLHelper Helper;
// Embedding cache that is passed to every Searchdomain this manager constructs.
public EnumerableLruCache<string, Dictionary<string, float[]>> EmbeddingCache;
// Initialised from the Cache.CacheTopN option; used as the size argument of a freshly created EmbeddingCache.
public long EmbeddingCacheMaxCount;
// Set once Dispose(true) has run the cleanup, so it is not run again.
private bool disposed = false;
private bool _disposed = false;
/// <summary>
/// Creates the manager: stores the injected services, initialises the embedding cache
/// (loaded through CacheHelper.GetEmbeddingStore when Cache.StoreEmbeddingCache is enabled,
/// otherwise created with EmbeddingCacheMaxCount as its size argument), and opens the MySQL
/// connection that backs the SQL helper.
/// </summary>
/// <param name="logger">Logger used by the manager and passed on to created searchdomains.</param>
/// <param name="options">Options providing cache settings and the SQL connection string.</param>
/// <param name="aIProvider">AI provider passed on to created searchdomains.</param>
/// <param name="databaseHelper">Database helper used for entity removal.</param>
// NOTE(review): rendered diff - renamed statements appear twice (old line, then its replacement).
public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper)
{
_logger = logger;
_options = options.Value;
this.aIProvider = aIProvider;
this.AiProvider = aIProvider;
_databaseHelper = databaseHelper;
EmbeddingCacheMaxCount = _options.Cache.CacheTopN;
if (options.Value.Cache.StoreEmbeddingCache)
{
// Load the embedding cache through CacheHelper and log how long that took.
var stopwatch = Stopwatch.StartNew();
embeddingCache = CacheHelper.GetEmbeddingStore(options.Value);
EmbeddingCache = CacheHelper.GetEmbeddingStore(options.Value);
stopwatch.Stop();
_logger.LogInformation("GetEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds);
} else
{
// Store disabled: create a new cache; note the long -> int cast of the configured size.
embeddingCache = new((int)EmbeddingCacheMaxCount);
EmbeddingCache = new((int)EmbeddingCacheMaxCount);
}
connectionString = _options.ConnectionStrings.SQL;
connection = new MySqlConnection(connectionString);
connection.Open();
helper = new SQLHelper(connection, connectionString);
// The connection is opened right here in the constructor, so a failing Open() surfaces at construction time.
_connection = new MySqlConnection(connectionString);
_connection.Open();
Helper = new SQLHelper(_connection, connectionString);
}
public Searchdomain GetSearchdomain(string searchdomain)
{
if (searchdomains.TryGetValue(searchdomain, out Searchdomain? value))
if (_searchdomains.TryGetValue(searchdomain, out Searchdomain? value))
{
return value;
}
try
{
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, helper, aIProvider, embeddingCache, _logger));
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, Helper, AiProvider, EmbeddingCache, _logger));
}
catch (MySqlException)
{
@@ -81,7 +81,7 @@ public class SearchdomainManager : IDisposable
/// <summary>
/// Returns the names of all rows in the searchdomain table.
/// </summary>
/// <returns>One string per row, read from column 0 (the selected name column).</returns>
// NOTE(review): rendered diff - the first return line is the pre-rename version, the second its replacement.
public async Task<List<string>> ListSearchdomainsAsync()
{
return await helper.ExecuteQueryAsync("SELECT name FROM searchdomain", [], x => x.GetString(0));
return await Helper.ExecuteQueryAsync("SELECT name FROM searchdomain", [], x => x.GetString(0));
}
public async Task<int> CreateSearchdomain(string searchdomain, SearchdomainSettings settings)
@@ -91,7 +91,7 @@ public class SearchdomainManager : IDisposable
public async Task<int> CreateSearchdomain(string searchdomain, string settings = "{}")
{
if (searchdomains.TryGetValue(searchdomain, out Searchdomain? value))
if (_searchdomains.TryGetValue(searchdomain, out Searchdomain? value))
{
_logger.LogError("Searchdomain {searchdomain} could not be created, as it already exists", [searchdomain]);
throw new SearchdomainAlreadyExistsException(searchdomain);
@@ -101,30 +101,30 @@ public class SearchdomainManager : IDisposable
{ "name", searchdomain },
{ "settings", settings}
};
int id = await helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters);
searchdomains.Add(searchdomain, new(searchdomain, connectionString, helper, aIProvider, embeddingCache, _logger));
int id = await Helper.ExecuteSQLCommandGetInsertedID("INSERT INTO searchdomain (name, settings) VALUES (@name, @settings)", parameters);
_searchdomains.Add(searchdomain, new(searchdomain, connectionString, Helper, AiProvider, EmbeddingCache, _logger));
return id;
}
/// <summary>
/// Deletes a searchdomain: first removes all of its entities through the database helper,
/// then deletes its row from the searchdomain table, and finally drops it from the
/// in-memory dictionary of searchdomains.
/// </summary>
/// <param name="searchdomain">Name of the searchdomain to delete.</param>
/// <returns>The value returned by RemoveAllEntities (logged as the number of deleted entities).</returns>
// NOTE(review): rendered diff - renamed statements appear twice (old line(s), then the replacement(s)).
public async Task<int> DeleteSearchdomain(string searchdomain)
{
int counter = await _databaseHelper.RemoveAllEntities(helper, searchdomain);
int counter = await _databaseHelper.RemoveAllEntities(Helper, searchdomain);
_logger.LogDebug($"Number of entities deleted as part of deleting the searchdomain \"{searchdomain}\": {counter}");
await helper.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", searchdomain}});
searchdomains.Remove(searchdomain);
// The name is bound as the @name parameter rather than concatenated into the SQL text.
await Helper.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", searchdomain}});
_searchdomains.Remove(searchdomain);
_logger.LogDebug($"Searchdomain has been successfully removed");
return counter;
}
/// <summary>
/// Stores the given searchdomain in the in-memory dictionary under <paramref name="name"/>,
/// replacing any entry already stored under that key, and returns the same instance.
/// </summary>
// NOTE(review): rendered diff - the first assignment is the pre-rename line, the second its replacement.
private Searchdomain SetSearchdomain(string name, Searchdomain searchdomain)
{
searchdomains[name] = searchdomain;
_searchdomains[name] = searchdomain;
return searchdomain;
}
/// <summary>
/// Reports whether a searchdomain with the given name is present in the in-memory dictionary.
/// Only the dictionary is consulted; no database query is made here.
/// </summary>
// NOTE(review): rendered diff - the first return line is the pre-rename version, the second its replacement.
public bool IsSearchdomainLoaded(string name)
{
return searchdomains.ContainsKey(name);
return _searchdomains.ContainsKey(name);
}
// Cleanup procedure
@@ -135,7 +135,7 @@ public class SearchdomainManager : IDisposable
if (_options.Cache.StoreEmbeddingCache)
{
var stopwatch = Stopwatch.StartNew();
await CacheHelper.UpdateEmbeddingStore(embeddingCache, _options);
await CacheHelper.UpdateEmbeddingStore(EmbeddingCache, _options);
stopwatch.Stop();
_logger.LogInformation("UpdateEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds);
}
@@ -155,10 +155,10 @@ public class SearchdomainManager : IDisposable
/// <summary>
/// Runs Cleanup() once: only when <paramref name="disposing"/> is true and the disposed
/// flag has not been set yet. The flag is set after Cleanup() has completed.
/// </summary>
// NOTE(review): this overload returns Task instead of void, so unlike the conventional
// Dispose(bool) the cleanup has only finished once the returned task is awaited - confirm
// how the public Dispose() consumes it.
// NOTE(review): rendered diff - renamed lines appear twice (old line, then its replacement).
protected virtual async Task Dispose(bool disposing)
{
if (!disposed && disposing)
if (!_disposed && disposing)
{
await Cleanup();
disposed = true;
_disposed = true;
}
}
}

View File

@@ -5,16 +5,16 @@ namespace Server;
/// <summary>
/// Bundles a similarity method: the enum value it was created from, that value's string
/// name, and the delegate resolved for that name through SimilarityMethods.GetMethod.
/// </summary>
// NOTE(review): rendered diff - the three old camelCase members/statements are listed first,
// followed by their PascalCase replacements.
public class SimilarityMethod
{
public SimilarityMethods.similarityMethodDelegate method;
public SimilarityMethodEnum similarityMethodEnum;
public string name;
// Delegate taking two float vectors and returning a float.
public SimilarityMethods.similarityMethodDelegate Method;
public SimilarityMethodEnum SimilarityMethodEnum;
// String form of the enum value (similarityMethodEnum.ToString()).
public string Name;
/// <summary>
/// Resolves the delegate for the given enum value.
/// </summary>
/// <exception cref="Exception">Thrown when SimilarityMethods.GetMethod returns null for the name.</exception>
public SimilarityMethod(SimilarityMethodEnum similarityMethodEnum)
{
this.similarityMethodEnum = similarityMethodEnum;
this.name = similarityMethodEnum.ToString();
SimilarityMethods.similarityMethodDelegate? probMethod = SimilarityMethods.GetMethod(name) ?? throw new Exception($"Unable to retrieve similarityMethod {name}");
method = probMethod;
SimilarityMethodEnum = similarityMethodEnum;
Name = similarityMethodEnum.ToString();
// The local is called probMethod but holds the similarity delegate looked up by name.
SimilarityMethods.similarityMethodDelegate? probMethod = SimilarityMethods.GetMethod(Name) ?? throw new Exception($"Unable to retrieve similarityMethod {Name}");
Method = probMethod;
}
}
@@ -22,11 +22,11 @@ public static class SimilarityMethods
{
public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2);
public delegate float similarityMethodDelegate(float[] vector1, float[] vector2);
public static readonly Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate> probMethods;
public static readonly Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate> ProbMethods;
static SimilarityMethods()
{
probMethods = new Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate>
ProbMethods = new Dictionary<SimilarityMethodEnum, similarityMethodProtoDelegate>
{
[SimilarityMethodEnum.Cosine] = CosineSimilarity,
[SimilarityMethodEnum.Euclidian] = EuclidianDistance,
@@ -44,7 +44,7 @@ public static class SimilarityMethods
methodName
);
if (!probMethods.TryGetValue(probMethodEnum, out similarityMethodProtoDelegate? method))
if (!ProbMethods.TryGetValue(probMethodEnum, out similarityMethodProtoDelegate? method))
{
return null;
}