Compare commits
4 Commits
108-add-re
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1b88bd1960 | ||
| 1aa2476779 | |||
|
|
7ed144bc39 | ||
| 3b42a73b73 |
@@ -160,21 +160,6 @@ public class Client
|
|||||||
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null);
|
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
/// Sends a POST request to the <c>QueryReranked</c> action under <c>{baseUri}/Searchdomain</c>
/// and returns the response deserialized as <see cref="EntityQueryResults"/>.
/// </summary>
/// <param name="searchdomain">Value sent as the <c>searchdomain</c> parameter.</param>
/// <param name="query">Value sent as the <c>query</c> parameter.</param>
/// <param name="rerankerModel">Value sent as the <c>rerankerModel</c> parameter.</param>
/// <param name="topN">Value sent as the <c>topN</c> parameter.</param>
/// <param name="topNRetrieval">Value sent as the <c>topNRetrieval</c> parameter.</param>
/// <param name="returnAttributes">
/// When <c>true</c>, a <c>returnAttributes</c> parameter is added; when <c>false</c> it is omitted entirely.
/// </param>
/// <returns>The result produced by <c>FetchUrlAndProcessJson</c> for the request.</returns>
public async Task<EntityQueryResults> SearchdomainQueryRerankedAsync(string searchdomain, string query, string rerankerModel, int topN, int topNRetrieval, bool returnAttributes = false)
{
    Dictionary<string, string> parameters = new()
    {
        ["searchdomain"] = searchdomain,
        ["query"] = query,
        ["rerankerModel"] = rerankerModel.ToString(),
        ["topN"] = topN.ToString(),
        ["topNRetrieval"] = topNRetrieval.ToString(),
    };

    // The flag is only transmitted when it is set.
    if (returnAttributes)
    {
        parameters["returnAttributes"] = returnAttributes.ToString();
    }

    var requestUrl = GetUrl($"{baseUri}/Searchdomain", "QueryReranked", parameters);
    return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, requestUrl, null);
}
|
|
||||||
|
|
||||||
public async Task<SearchdomainDeleteSearchResult> SearchdomainDeleteQueryAsync(string searchdomain, string query)
|
public async Task<SearchdomainDeleteSearchResult> SearchdomainDeleteQueryAsync(string searchdomain, string query)
|
||||||
{
|
{
|
||||||
Dictionary<string, string> parameters = new()
|
Dictionary<string, string> parameters = new()
|
||||||
|
|||||||
@@ -132,107 +132,6 @@ public class AIProvider
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
/// Asks the configured AI provider to rerank <paramref name="documents"/> against <paramref name="input"/>.
/// </summary>
/// <param name="modelUri">
/// Model reference parsed as a URI: the scheme selects the provider entry in
/// <c>AiProvidersConfiguration</c> and the absolute path is sent as the model name.
/// </param>
/// <param name="input">Query text sent as the <c>query</c> field of the request body.</param>
/// <param name="documents">Documents sent as the <c>documents</c> field of the request body.</param>
/// <param name="topN">Value sent as the <c>top_n</c> field of the request body.</param>
/// <returns>
/// Pairs of (index, score) read from <c>$.results[*].index</c> and
/// <c>$.results[*].relevance_score</c> of the response, in response order.
/// </returns>
/// <exception cref="ServerConfigurationException">
/// The provider is not configured, or its handler is not supported (only <c>openai</c> is handled here).
/// </exception>
/// <exception cref="JSONPathSelectionException">
/// The response contains neither the expected result tokens nor an <c>error</c> member.
/// </exception>
/// <exception cref="Exception">The response contains an <c>error</c> member instead of results.</exception>
public IEnumerable<(int index, float score)> Rerank(string modelUri, string input, string[] documents, int topN)
{
    Uri uri = new(modelUri);
    string provider = uri.Scheme;
    string model = uri.AbsolutePath;

    // Case-insensitive provider lookup without culture-sensitive lower-casing.
    AiProvider? aIProvider = AiProvidersConfiguration
        .FirstOrDefault(x => string.Equals(x.Key, provider, StringComparison.OrdinalIgnoreCase))
        .Value;
    if (aIProvider is null)
    {
        _logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", provider, modelUri);
        throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
    }

    // NOTE(review): a new HttpClient per call is kept from the original; consider a shared instance.
    using var httpClient = new HttpClient();
    httpClient.Timeout = TimeSpan.FromMinutes(150);

    string indexJsonPath;
    string scoreJsonPath;
    Uri baseUri = new(aIProvider.BaseURL);
    Uri requestUri;
    IRerankRequestBody rerankRequestBody;
    string[][] requestHeaders = [];
    switch (aIProvider.Handler)
    {
        case "openai":
            indexJsonPath = "$.results[*].index";
            scoreJsonPath = "$.results[*].relevance_score";
            requestUri = new Uri(baseUri, "/v1/rerank");
            rerankRequestBody = new OpenAIRerankRequestBody()
            {
                model = model,
                query = input,
                documents = documents,
                top_n = topN
            };
            if (aIProvider.ApiKey is not null)
            {
                requestHeaders = [
                    ["Authorization", $"Bearer {aIProvider.ApiKey}"]
                ];
            }
            break;
        default:
            _logger.LogError("Invalid reranking handler {aIProvider.Handler} in AiProvider {provider}.", aIProvider.Handler, provider);
            throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
    }

    using var request = new HttpRequestMessage(HttpMethod.Post, requestUri)
    {
        Content = new StringContent(
            JsonConvert.SerializeObject(rerankRequestBody),
            Encoding.UTF8,
            "application/json"
        )
    };
    foreach (string[] header in requestHeaders)
    {
        request.Headers.Add(header[0], header[1]);
    }

    // Send the prepared request itself so that its headers (e.g. Authorization) actually
    // reach the provider; posting the URI and content separately silently dropped them.
    // The method is synchronous by contract, so the calls are blocked on as before.
    using HttpResponseMessage response = httpClient.SendAsync(request).Result;
    string responseContent = response.Content.ReadAsStringAsync().Result;

    try
    {
        JObject responseContentJson = JObject.Parse(responseContent);
        List<JToken> responseContentIndexTokens = [.. responseContentJson.SelectTokens(indexJsonPath)];
        List<JToken> responseContentScoreTokens = [.. responseContentJson.SelectTokens(scoreJsonPath)];
        if (responseContentIndexTokens.Count == 0 || responseContentScoreTokens.Count == 0)
        {
            if (responseContentJson.TryGetValue("error", out JToken? errorMessageJson) && errorMessageJson is not null)
            {
                string errorMessage;
                string errorCode = "";
                string errorType = "";
                if (errorMessageJson is JObject errorObject)
                {
                    errorMessage = errorObject.Value<string>("message") ?? "";
                    errorCode = errorObject.Value<string>("code") ?? "";
                    errorType = errorObject.Value<string>("type") ?? "";
                }
                else
                {
                    // "error" is a plain value (e.g. a string): report it verbatim instead of
                    // failing while trying to read child members from it.
                    errorMessage = errorMessageJson.ToString();
                }
                _logger.LogError("Unable to retrieve reranking results due to error: {errorCode} - {errorMessage} - {errorType}", errorCode, errorMessage, errorType);
                throw new Exception($"Unable to retrieve reranking results due to error: {errorMessage}");
            }

            _logger.LogError("Unable to select tokens using JSONPath {indexJsonPath} for string: {responseContent}.", indexJsonPath, responseContent);
            throw new JSONPathSelectionException(indexJsonPath, responseContent);
        }

        // Materialize inside the try block so that token conversion failures are logged here
        // rather than surfacing later while the caller enumerates a lazy sequence.
        return responseContentIndexTokens
            .Select(token => token.ToObject<int>())
            .Zip(responseContentScoreTokens.Select(token => token.ToObject<float>()), (index, score) => (index, score))
            .ToList();
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Unable to parse the response to valid reranking results. {ex.Message}", ex.Message);
        throw;
    }
}
|
|
||||||
|
|
||||||
public string[] GetModels()
|
public string[] GetModels()
|
||||||
{
|
{
|
||||||
var aIProviders = AiProvidersConfiguration;
|
var aIProviders = AiProvidersConfiguration;
|
||||||
@@ -340,16 +239,4 @@ public class OpenAIEmbedRequestBody : IEmbedRequestBody
|
|||||||
{
|
{
|
||||||
public required string model { get; set; }
|
public required string model { get; set; }
|
||||||
public required string[] input { get; set; }
|
public required string[] input { get; set; }
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/// <summary>
/// Marker interface for request bodies sent to a reranking endpoint; it declares no members.
/// </summary>
public interface IRerankRequestBody
{
}

/// <summary>
/// Rerank request body with the fields <c>model</c>, <c>query</c>, <c>top_n</c> and <c>documents</c>.
/// The lower-case member names and their declaration order are kept as they are.
/// </summary>
public class OpenAIRerankRequestBody : IRerankRequestBody
{
    /// <summary>Name of the reranking model.</summary>
    public required string model { get; set; }

    /// <summary>Query text the documents are ranked against.</summary>
    public required string query { get; set; }

    /// <summary>The <c>top_n</c> value of the request.</summary>
    public required int top_n { get; set; }

    /// <summary>The documents to rerank.</summary>
    public required string[] documents { get; set; }
}
|
}
|
||||||
@@ -148,78 +148,6 @@ public class SearchdomainController : ControllerBase
|
|||||||
return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
|
return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/// <summary>
/// Executes a query in the searchdomain and reranks the result using a specified reranker
/// </summary>
/// <remarks>
/// The retrieved entities are grouped by attribute name, each group's attribute values are
/// reranked separately, and the per-attribute scores of every entity are then combined into
/// a single value using the selected probability method.
/// </remarks>
/// <param name="searchdomain">Name of the searchdomain</param>
/// <param name="query">Query to execute</param>
/// <param name="rerankerModel">Reranker model passed to the searchdomain's AI provider</param>
/// <param name="topN">Top-N value passed to the reranker for each attribute group</param>
/// <param name="topNRetrieval">Number of results requested from the initial search</param>
/// <param name="probMethod">Method used to combine an entity's per-attribute scores</param>
/// <param name="returnAttributes">Return the attributes of the object</param>
[HttpPost("QueryReranked")]
public ActionResult<EntityRerankResults> QueryReranked([Required]string searchdomain, [Required]string query, [Required]string rerankerModel, int topN, int topNRetrieval, ProbMethodEnum probMethod = ProbMethodEnum.HVEWAvg, bool returnAttributes = false)
{
    (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
    if (searchdomain_ is null || httpStatusCode is not null)
    {
        return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
    }

    // Initial retrieval: look up the attributes of every hit before any grouping happens.
    List<(float, string)> results = searchdomain_.Search(query, topNRetrieval);
    List<(string Name, Dictionary<string, string> Attributes)> queryResults = [];
    foreach ((float, string) hit in results)
    {
        queryResults.Add((hit.Item2, searchdomain_.EntityCache[hit.Item2]?.Attributes ?? []));
    }

    // Key: Attribute name
    Dictionary<string, List<(string EntityName, string AttributeValue)>> resultsByAttribute = [];
    foreach ((string Name, Dictionary<string, string> Attributes) entity in queryResults)
    {
        foreach (KeyValuePair<string, string> attribute in entity.Attributes)
        {
            if (!resultsByAttribute.TryGetValue(attribute.Key, out List<(string EntityName, string AttributeValue)>? bucket) || bucket is null)
            {
                bucket = [];
                resultsByAttribute[attribute.Key] = bucket;
            }
            bucket.Add((entity.Name, attribute.Value));
        }
    }

    // Key: EntityName
    Dictionary<string, List<(string attribute, float score)>> scoresByEntity = [];
    foreach (KeyValuePair<string, List<(string EntityName, string AttributeValue)>> attributeGroup in resultsByAttribute)
    {
        string attributeName = attributeGroup.Key;
        List<(string EntityName, string AttributeValue)> nameValuePairs = attributeGroup.Value;

        string[] documents = [.. nameValuePairs.Select(pair => pair.AttributeValue)];
        List<(int index, float score)> rerankResults = [.. searchdomain_.AiProvider.Rerank(rerankerModel, query, documents, topN)];

        // The reranker reports positions within the submitted documents; map them back to entity names.
        List<(string entityName, float score)> rerankedScores = [];
        foreach ((int index, float score) ranked in rerankResults)
        {
            rerankedScores.Add((nameValuePairs[ranked.index].EntityName, ranked.score));
        }

        foreach ((string entityName, float score) in rerankedScores)
        {
            if (!scoresByEntity.TryGetValue(entityName, out List<(string attribute, float score)>? entityScores) || entityScores is null)
            {
                entityScores = [];
                scoresByEntity[entityName] = entityScores;
            }
            entityScores.Add((attributeName, score));
        }
    }

    List<EntityRerankResult> entityRerankResults = [];
    foreach (KeyValuePair<string, List<(string attribute, float score)>> scoreKV in scoresByEntity)
    {
        string entityName = scoreKV.Key;
        float combinedScore = new ProbMethod(probMethod).Method(scoreKV.Value);
        entityRerankResults.Add(new EntityRerankResult()
        {
            Name = entityName,
            Value = combinedScore,
            Attributes = returnAttributes ? (searchdomain_.EntityCache[entityName]?.Attributes ?? []) : null
        });
    }

    return Ok(new EntityRerankResults(){Results = entityRerankResults, Success = true });
}
|
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Executes a query in the searchdomain
|
/// Executes a query in the searchdomain
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|||||||
@@ -226,6 +226,53 @@ public class SQLHelper:IDisposable
|
|||||||
Thread.Sleep(sleepTime);
|
Thread.Sleep(sleepTime);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/// <summary>
/// Runs <paramref name="operation"/> inside a transaction on a pooled connection.
/// The transaction is committed when the operation completes and rolled back when it throws;
/// the pool element's semaphore is released in either case.
/// </summary>
/// <param name="operation">
/// Callback that receives the pooled connection and the open transaction.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="operation"/> is <c>null</c>.</exception>
public async Task ExecuteInTransactionAsync(Func<MySqlConnection, DbTransaction, Task> operation)
{
    // Reject a missing callback before a pooled connection is taken.
    ArgumentNullException.ThrowIfNull(operation);

    var poolElement = await GetMySqlConnectionPoolElement();
    var connection = poolElement.Connection;
    try
    {
        // Begin and dispose the transaction asynchronously so this async method does not
        // block a thread on database I/O.
        await using var transaction = await connection.BeginTransactionAsync();
        try
        {
            await operation(connection, transaction);
            await transaction.CommitAsync();
        }
        catch
        {
            await transaction.RollbackAsync();
            throw;
        }
    }
    finally
    {
        poolElement.Semaphore.Release();
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Synchronous counterpart of the transactional helper: runs <paramref name="operation"/>
/// inside a transaction on a pooled connection, committing on success and rolling back
/// (then rethrowing) on failure. The pool element's semaphore is always released.
/// </summary>
/// <param name="operation">
/// Callback that receives the pooled connection and the open transaction.
/// </param>
public void ExecuteInTransaction(Action<MySqlConnection, MySqlTransaction> operation)
{
    // Blocks until a pool element is available.
    var pooled = GetMySqlConnectionPoolElement().Result;
    var pooledConnection = pooled.Connection;
    try
    {
        using (var tx = pooledConnection.BeginTransaction())
        {
            try
            {
                operation(pooledConnection, tx);
                tx.Commit();
            }
            catch
            {
                // Undo whatever the callback did, then let the original failure propagate.
                tx.Rollback();
                throw;
            }
        }
    }
    finally
    {
        pooled.Semaphore.Release();
    }
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public struct MySqlConnectionPoolElement
|
public struct MySqlConnectionPoolElement
|
||||||
|
|||||||
@@ -29,13 +29,9 @@ public static class DatabaseMigrations
|
|||||||
if (version >= databaseVersion)
|
if (version >= databaseVersion)
|
||||||
{
|
{
|
||||||
databaseVersion = (int)method.Invoke(null, new object[] { helper });
|
databaseVersion = (int)method.Invoke(null, new object[] { helper });
|
||||||
|
var _ = helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() }).Result;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (databaseVersion != initialDatabaseVersion)
|
|
||||||
{
|
|
||||||
var _ = helper.ExecuteSQLNonQuery("UPDATE settings SET value = @databaseVersion", new() { ["databaseVersion"] = databaseVersion.ToString() }).Result;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int DatabaseGetVersion(SQLHelper helper)
|
public static int DatabaseGetVersion(SQLHelper helper)
|
||||||
@@ -122,25 +118,41 @@ public static class DatabaseMigrations
|
|||||||
{
|
{
|
||||||
// Add id_entity to embedding
|
// Add id_entity to embedding
|
||||||
var _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_entity INT NULL", []).Result;
|
var _ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_entity INT NULL", []).Result;
|
||||||
|
return 6;
|
||||||
|
}
|
||||||
|
public static int UpdateFrom6(SQLHelper helper)
|
||||||
|
{
|
||||||
int count;
|
int count;
|
||||||
do
|
do
|
||||||
{
|
{
|
||||||
count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN datapoint d ON d.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_entity IS NULL LIMIT 10000) x on x.id = e.id SET e.id_entity = d.id_entity;", []).Result;
|
count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN datapoint d ON d.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_entity IS NULL LIMIT 10000) x on x.id = e.id SET e.id_entity = d.id_entity;", []).Result;
|
||||||
} while (count == 10000);
|
} while (count == 10000);
|
||||||
|
return 7;
|
||||||
|
}
|
||||||
|
public static int UpdateFrom7(SQLHelper helper)
|
||||||
|
{
|
||||||
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_entity INT NOT NULL;", []).Result;
|
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_entity INT NOT NULL;", []).Result;
|
||||||
_ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_entity_model ON embedding (id_entity, model)", []).Result;
|
_ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_entity_model ON embedding (id_entity, model)", []).Result;
|
||||||
|
|
||||||
// Add id_searchdomain to embedding
|
// Add id_searchdomain to embedding
|
||||||
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_searchdomain INT NULL", []).Result;
|
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding ADD COLUMN id_searchdomain INT NULL", []).Result;
|
||||||
|
return 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int UpdateFrom8(SQLHelper helper)
|
||||||
|
{
|
||||||
|
int count = 0;
|
||||||
do
|
do
|
||||||
{
|
{
|
||||||
count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN entity en ON en.id = e.id_datapoint JOIN (SELECT id FROM embedding WHERE id_searchdomain IS NULL LIMIT 10000) x on x.id = e.id SET e.id_searchdomain = en.id_searchdomain;", []).Result;
|
count = helper.ExecuteSQLNonQuery("UPDATE embedding e JOIN entity en ON en.id = e.id_entity JOIN (SELECT id FROM embedding WHERE id_searchdomain IS NULL LIMIT 10000) x on x.id = e.id SET e.id_searchdomain = en.id_searchdomain;", []).Result;
|
||||||
} while (count == 10000);
|
} while (count == 10000);
|
||||||
|
return 9;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static int UpdateFrom9(SQLHelper helper)
|
||||||
|
{
|
||||||
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_searchdomain INT NOT NULL;", []).Result;
|
_ = helper.ExecuteSQLNonQuery("ALTER TABLE embedding MODIFY id_searchdomain INT NOT NULL;", []).Result;
|
||||||
_ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_searchdomain_model ON embedding (id_searchdomain)", []).Result;
|
_ = helper.ExecuteSQLNonQuery("CREATE INDEX idx_embedding_searchdomain_model ON embedding (id_searchdomain)", []).Result;
|
||||||
|
return 10;
|
||||||
return 6;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -20,23 +20,6 @@ public class EntityQueryResult
|
|||||||
public Dictionary<string, string>? Attributes { get; set; }
|
public Dictionary<string, string>? Attributes { get; set; }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
/// Response model for a reranked query: the members inherited from
/// <see cref="SuccesMessageBaseModel"/> plus the list of reranked entities.
/// </summary>
public class EntityRerankResults : SuccesMessageBaseModel
{
    /// <summary>The reranked entities; serialized under the JSON name <c>Results</c>.</summary>
    [JsonPropertyName("Results")]
    public required List<EntityRerankResult> Results { get; set; }
}
|
|
||||||
|
|
||||||
/// <summary>
/// A single reranked entity: its name, its score and, optionally, its attributes.
/// </summary>
public class EntityRerankResult
{
    /// <summary>Name of the entity; serialized as <c>Name</c>.</summary>
    [JsonPropertyName("Name")]
    public required string Name { get; set; }

    /// <summary>Score of the entity; serialized as <c>Value</c>.</summary>
    [JsonPropertyName("Value")]
    public float Value { get; set; }

    /// <summary>
    /// Attributes of the entity; serialized as <c>Attributes</c> and left out of the
    /// JSON output when <c>null</c>.
    /// </summary>
    [JsonPropertyName("Attributes"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
    public Dictionary<string, string>? Attributes { get; set; }
}
|
|
||||||
|
|
||||||
public class EntityIndexResult : SuccesMessageBaseModel {}
|
public class EntityIndexResult : SuccesMessageBaseModel {}
|
||||||
|
|
||||||
public class EntityListResults
|
public class EntityListResults
|
||||||
|
|||||||
Reference in New Issue
Block a user