4 Commits

Author SHA1 Message Date
LD50
1b88bd1960 Merge pull request #133 from LD-Reborn/132-migrations-break-database-on-failure-due-to-lack-of-transactions
Fixed migrations breaking because of IIS, added MySQL transaction method
2026-02-23 21:13:20 +01:00
1aa2476779 Fixed migrations breaking because of IIS, added MySQL transaction method 2026-02-23 21:08:46 +01:00
LD50
7ed144bc39 Merge pull request #131 from LD-Reborn/20260223_mysqlfux
Fixed MySQL migration error
2026-02-23 07:41:33 +01:00
3b42a73b73 Fixed MySQL migration error 2026-02-23 07:41:03 +01:00
6 changed files with 69 additions and 227 deletions

View File

@@ -160,21 +160,6 @@ public class Client
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null);
}
public async Task<EntityQueryResults> SearchdomainQueryRerankedAsync(string searchdomain, string query, string rerankerModel, int topN, int topNRetrieval, bool returnAttributes = false)
{
Dictionary<string, string> parameters = new()
{
{ "searchdomain", searchdomain },
{ "query", query },
{ "rerankerModel", (rerankerModel).ToString() },
{ "topN", (topN).ToString() },
{ "topNRetrieval", (topNRetrieval).ToString() }
};
if (returnAttributes) parameters.Add("returnAttributes", returnAttributes.ToString());
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "QueryReranked", parameters), null);
}
public async Task<SearchdomainDeleteSearchResult> SearchdomainDeleteQueryAsync(string searchdomain, string query)
{
Dictionary<string, string> parameters = new()

View File

@@ -132,107 +132,6 @@ public class AIProvider
}
}
public IEnumerable<(int index, float score)> Rerank(string modelUri, string input, string[] documents, int topN)
{
Uri uri = new(modelUri);
string provider = uri.Scheme;
string model = uri.AbsolutePath;
AiProvider? aIProvider = AiProvidersConfiguration
.FirstOrDefault(x => string.Equals(x.Key.ToLower(), provider.ToLower()))
.Value;
if (aIProvider is null)
{
_logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", [provider, modelUri]);
throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
}
using var httpClient = new HttpClient();
httpClient.Timeout = TimeSpan.FromMinutes(150);
string indexJsonPath = "";
string scoreJsonPath = "";
IEnumerable<(string, float)> values = [];
Uri baseUri = new(aIProvider.BaseURL);
Uri requestUri;
IRerankRequestBody rerankRequestBody;
string[][] requestHeaders = [];
switch (aIProvider.Handler)
{
case "openai":
indexJsonPath = "$.results[*].index";
scoreJsonPath = "$.results[*].relevance_score";
requestUri = new Uri(baseUri, "/v1/rerank");
rerankRequestBody = new OpenAIRerankRequestBody()
{
model = model,
query = input,
documents = documents,
top_n = topN
};
if (aIProvider.ApiKey is not null)
{
requestHeaders = [
["Authorization", $"Bearer {aIProvider.ApiKey}"]
];
}
break;
default:
_logger.LogError("Invalid reranking handler {aIProvider.Handler} in AiProvider {provider}.", [aIProvider.Handler, provider]);
throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
}
var requestContent = new StringContent(
JsonConvert.SerializeObject(rerankRequestBody),
Encoding.UTF8,
"application/json"
);
var request = new HttpRequestMessage()
{
RequestUri = requestUri,
Method = HttpMethod.Post,
Content = requestContent
};
foreach (var header in requestHeaders)
{
request.Headers.Add(header[0], header[1]);
}
HttpResponseMessage response = httpClient.PostAsync(requestUri, requestContent).Result;
string responseContent = response.Content.ReadAsStringAsync().Result;
try
{
JObject responseContentJson = JObject.Parse(responseContent);
List<JToken>? responseContentIndexTokens = [.. responseContentJson.SelectTokens(indexJsonPath)];
List<JToken>? responseContentScoreTokens = [.. responseContentJson.SelectTokens(scoreJsonPath)];
if (responseContentIndexTokens is null || responseContentIndexTokens.Count == 0
|| responseContentScoreTokens is null || responseContentScoreTokens.Count == 0)
{
if (responseContentJson.TryGetValue("error", out JToken? errorMessageJson) && errorMessageJson is not null)
{
string errorMessage = (string?)errorMessageJson.Value<string>("message") ?? "";
string errorCode = (string?)errorMessageJson.Value<string>("code") ?? "";
string errorType = (string?)errorMessageJson.Value<string>("type") ?? "";
_logger.LogError("Unable to retrieve reranking results due to error: {errorCode} - {errorMessage} - {errorType}", [errorCode, errorMessage, errorType]);
throw new Exception($"Unable to retrieve reranking results due to error: {errorMessage}");
} else
{
_logger.LogError("Unable to select tokens using JSONPath {indexJsonPath} for string: {responseContent}.", [indexJsonPath, responseContent]);
throw new JSONPathSelectionException(indexJsonPath, responseContent);
}
}
IEnumerable<int> indices = responseContentIndexTokens.Select(token => token.ToObject<int>());
IEnumerable<float> scores = responseContentScoreTokens.Select(token => token.ToObject<float>());
IEnumerable<(int index, float score)> zipped = indices.Zip(scores, (index, score) => (index, score));
return zipped;
}
catch (Exception ex)
{
_logger.LogError("Unable to parse the response to valid embeddings. {ex.Message}", [ex.Message]);
throw;
}
}
public string[] GetModels()
{
var aIProviders = AiProvidersConfiguration;
@@ -341,15 +240,3 @@ public class OpenAIEmbedRequestBody : IEmbedRequestBody
public required string model { get; set; }
public required string[] input { get; set; }
}
public interface IRerankRequestBody { }
public class OpenAIRerankRequestBody : IRerankRequestBody
{
public required string model { get; set; }
public required string query { get; set; }
public required int top_n { get; set; }
public required string[] documents { get; set; }
}

View File

@@ -148,78 +148,6 @@ public class SearchdomainController : ControllerBase
return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
}
/// <summary>
/// Executes a query in the searchdomain and reranks the result using a specified reranker
/// </summary>
/// <param name="searchdomain">Name of the searchdomain</param>
/// <param name="query">Query to execute</param>
/// <param name="topN">Return only the top N results</param>
/// <param name="returnAttributes">Return the attributes of the object</param>
[HttpPost("QueryReranked")]
public ActionResult<EntityRerankResults> QueryReranked([Required]string searchdomain, [Required]string query, [Required]string rerankerModel, int topN, int topNRetrieval, ProbMethodEnum probMethod = ProbMethodEnum.HVEWAvg, bool returnAttributes = false)
{
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
List<(float, string)> results = searchdomain_.Search(query, topNRetrieval);
List<(string Name, Dictionary<string, string> Attributes)> queryResults = [.. results.Select(r => (
Name: r.Item2,
Attributes: searchdomain_.EntityCache[r.Item2]?.Attributes ?? []
))];
// Key: Attribute name
Dictionary<string, List<(string EntityName, string AttributeValue)>> resultsByAttribute = [];
queryResults.ForEach(r =>
{
foreach (var kv in r.Attributes)
{
if (!resultsByAttribute.TryGetValue(kv.Key, out List<(string EntityName, string AttributeValue)>? values) || values is null)
{
values = [];
resultsByAttribute[kv.Key] = values;
}
values.Add((r.Name, kv.Value));
}
});
// Key: EntityName
Dictionary<string, List<(string attribute, float score)>> scoresByEntity = [];
foreach (var kv in resultsByAttribute)
{
string attributeName = kv.Key;
List<(string EntityName, string AttributeValue)> nameValuePairs = kv.Value;
List<string> documents = [.. nameValuePairs.Select(r => r.AttributeValue)];
List<(int index, float score)> rerankResults = [.. searchdomain_.AiProvider.Rerank(rerankerModel, query, [.. documents], topN)];
List<(string entityName, float score)> rerankedScores = [.. rerankResults.Select(r => (nameValuePairs.ElementAt(r.index).EntityName, r.score))];
foreach ((string entityName, float score) in rerankedScores)
{
if (!scoresByEntity.TryGetValue(entityName, out List<(string attribute, float score)>? values) || values is null)
{
values = [];
scoresByEntity[entityName] = values;
}
values.Add((attributeName, score));
}
}
List<EntityRerankResult> entityRerankResults = [.. scoresByEntity.Select(scoreKV =>
{
string entityName = scoreKV.Key;
float score = new ProbMethod(probMethod).Method(scoreKV.Value);
return new EntityRerankResult()
{
Name = entityName,
Value = score,
Attributes = returnAttributes ? (searchdomain_.EntityCache[entityName]?.Attributes ?? []) : null
};
})];
return Ok(new EntityRerankResults(){Results = entityRerankResults, Success = true });
}
/// <summary>
/// Executes a query in the searchdomain
/// </summary>

View File

@@ -226,6 +226,53 @@ public class SQLHelper:IDisposable
Thread.Sleep(sleepTime);
}
}
public async Task ExecuteInTransactionAsync(Func<MySqlConnection, DbTransaction, Task> operation)
{
var poolElement = await GetMySqlConnectionPoolElement();
var connection = poolElement.Connection;
try
{
using var transaction = connection.BeginTransaction();
try
{
await operation(connection, transaction);
await transaction.CommitAsync();
}
catch
{
await transaction.RollbackAsync();
throw;
}
}
finally
{
poolElement.Semaphore.Release();
}
}
public void ExecuteInTransaction(Action<MySqlConnection, MySqlTransaction> operation)
{
var poolElement = GetMySqlConnectionPoolElement().Result;
var connection = poolElement.Connection;
try
{
using var transaction = connection.BeginTransaction();
try
{
operation(connection, transaction);
transaction.Commit();
}
catch
{
transaction.Rollback();
throw;
}
}
finally
{
poolElement.Semaphore.Release();
}
}
}
public struct MySqlConnectionPoolElement

View File

@@ -29,14 +29,10 @@ public static class DatabaseMigrations
if (version >= databaseVersion)
{
databaseVersion = (int)method.Invoke(null, new object[] { helper });
}
}
if (databaseVersion != initialDatabaseVersion)
{
var _ = helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() }).Result;
}
}
}
public static int DatabaseGetVersion(SQLHelper helper)
{
@@ -122,25 +118,41 @@ public static class DatabaseMigrations
{
// Add id_entity to embedding
var _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_entity INT NULL", []).Result;
return 6;
}
public static int UpdateFrom6(SQLHelper helper)
{
int count;
do
{
count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN datapoint d ON d.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_entity IS NULL LIMIT 10000) x on x.id = e.id SET e.id_entity = d.id_entity;", []).Result;
} while (count == 10000);
return 7;
}
public static int UpdateFrom7(SQLHelper helper)
{
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_entity INT NOT NULL;", []).Result;
_ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_entity_model ON embedding (id_entity, model)", []).Result;
// Add id_searchdomain to embedding
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_searchdomain INT NULL", []).Result;
return 8;
}
public static int UpdateFrom8(SQLHelper helper)
{
int count = 0;
do
{
count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN entity en ON en.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_searchdomain IS NULL LIMIT 10000) x on x.id = e.id SET e.id_searchdomain = en.id_searchdomain;", []).Result;
count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN entity en ON en.id = e.id_entity JOIN (SELECT id FROM embedding WHERE id_searchdomain IS NULL LIMIT 10000) x on x.id = e.id SET e.id_searchdomain = en.id_searchdomain;", []).Result;
} while (count == 10000);
return 9;
}
public static int UpdateFrom9(SQLHelper helper)
{
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_searchdomain INT NOT NULL;", []).Result;
_ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_searchdomain_model ON embedding (id_searchdomain)", []).Result;
return 6;
return 10;
}
}

View File

@@ -20,23 +20,6 @@ public class EntityQueryResult
public Dictionary<string, string>? Attributes { get; set; }
}
public class EntityRerankResults : SuccesMessageBaseModel
{
[JsonPropertyName("Results")]
public required List<EntityRerankResult> Results { get; set; }
}
public class EntityRerankResult
{
[JsonPropertyName("Name")]
public required string Name { get; set; }
[JsonPropertyName("Value")]
public float Value { get; set; }
[JsonPropertyName("Attributes")]
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public Dictionary<string, string>? Attributes { get; set; }
}
public class EntityIndexResult : SuccesMessageBaseModel {}
public class EntityListResults