From b5a8eec445fc43df103ec6635ee2502e45a5be8f Mon Sep 17 00:00:00 2001
From: LD-Reborn
Date: Sun, 8 Mar 2026 10:49:27 +0100
Subject: [PATCH] Added reranker exploration setup

---
 src/Client/Client.cs                          |  18 +++
 src/Server/AIProvider.cs                      | 127 ++++++++++++++++++
 .../Controllers/SearchdomainController.cs     |  76 +++++++++++
 src/Shared/Models/EntityResults.cs            |  19 +++
 4 files changed, 240 insertions(+)

diff --git a/src/Client/Client.cs b/src/Client/Client.cs
index a00537e..1853906 100644
--- a/src/Client/Client.cs
+++ b/src/Client/Client.cs
@@ -160,6 +160,24 @@ public class Client
         return await FetchUrlAndProcessJson(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null);
     }
 
+    /// <summary>
+    /// Posts a query to the Searchdomain/QueryReranked endpoint and returns the deserialized response
+    /// </summary>
+    public async Task<EntityRerankResults> SearchdomainQueryRerankedAsync(string searchdomain, string query, string rerankerModel, int topN, int topNRetrieval, bool returnAttributes = false)
+    {
+        Dictionary<string, string> parameters = new()
+        {
+            { "searchdomain", searchdomain },
+            { "query", query },
+            { "rerankerModel", rerankerModel },
+            { "topN", topN.ToString() },
+            { "topNRetrieval", topNRetrieval.ToString() }
+        };
+        if (returnAttributes) parameters.Add("returnAttributes", returnAttributes.ToString());
+
+        return await FetchUrlAndProcessJson<EntityRerankResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "QueryReranked", parameters), null);
+    }
+
     public async Task SearchdomainDeleteQueryAsync(string searchdomain, string query)
     {
         Dictionary parameters = new()
diff --git a/src/Server/AIProvider.cs b/src/Server/AIProvider.cs
index d6bf09d..49ff36d 100644
--- a/src/Server/AIProvider.cs
+++ b/src/Server/AIProvider.cs
@@ -132,6 +132,119 @@ public class AIProvider
         }
     }
 
+    /// <summary>
+    /// Scores <paramref name="documents"/> against <paramref name="input"/> using the reranker model named by <paramref name="modelUri"/>.
+    /// </summary>
+    /// <param name="modelUri">Model URI; its scheme selects the configured AI provider and its path is sent as the model name</param>
+    /// <param name="input">Query the documents are scored against</param>
+    /// <param name="documents">Candidate documents sent to the reranker</param>
+    /// <param name="topN">Number of results requested from the reranker (sent as top_n)</param>
+    /// <returns>Pairs of result index and relevance score as reported by the provider</returns>
+    /// <exception cref="ServerConfigurationException">The provider is not configured or its handler is unknown</exception>
+    /// <exception cref="JSONPathSelectionException">The response contains neither results nor an error object</exception>
+    /// <exception cref="Exception">The provider answered with an error object</exception>
+    public IEnumerable<(int index, float score)> Rerank(string modelUri, string input, string[] documents, int topN)
+    {
+        Uri uri = new(modelUri);
+        string provider = uri.Scheme;
+        string model = uri.AbsolutePath;
+        AiProvider? aIProvider = AiProvidersConfiguration
+            .FirstOrDefault(x => string.Equals(x.Key, provider, StringComparison.OrdinalIgnoreCase))
+            .Value;
+        if (aIProvider is null)
+        {
+            _logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", [provider, modelUri]);
+            throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
+        }
+        using var httpClient = new HttpClient();
+        httpClient.Timeout = TimeSpan.FromMinutes(150);
+
+        string indexJsonPath = "";
+        string scoreJsonPath = "";
+        Uri baseUri = new(aIProvider.BaseURL);
+        Uri requestUri;
+        IRerankRequestBody rerankRequestBody;
+        string[][] requestHeaders = [];
+        switch (aIProvider.Handler)
+        {
+            case "openai":
+                indexJsonPath = "$.results[*].index";
+                scoreJsonPath = "$.results[*].relevance_score";
+                requestUri = new Uri(baseUri, "/v1/rerank");
+                rerankRequestBody = new OpenAIRerankRequestBody()
+                {
+                    model = model,
+                    query = input,
+                    documents = documents,
+                    top_n = topN
+                };
+                if (aIProvider.ApiKey is not null)
+                {
+                    requestHeaders = [
+                        ["Authorization", $"Bearer {aIProvider.ApiKey}"]
+                    ];
+                }
+                break;
+            default:
+                _logger.LogError("Invalid reranking handler {aIProvider.Handler} in AiProvider {provider}.", [aIProvider.Handler, provider]);
+                throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
+        }
+        var requestContent = new StringContent(
+            JsonConvert.SerializeObject(rerankRequestBody),
+            Encoding.UTF8,
+            "application/json"
+        );
+
+        using var request = new HttpRequestMessage()
+        {
+            RequestUri = requestUri,
+            Method = HttpMethod.Post,
+            Content = requestContent
+        };
+
+        foreach (var header in requestHeaders)
+        {
+            request.Headers.Add(header[0], header[1]);
+        }
+        // Send the prepared request itself so that its headers (e.g. Authorization) are transmitted
+        using HttpResponseMessage response = httpClient.SendAsync(request).Result;
+        string responseContent = response.Content.ReadAsStringAsync().Result;
+        try
+        {
+            JObject responseContentJson = JObject.Parse(responseContent);
+            List<JToken>? responseContentIndexTokens = [.. responseContentJson.SelectTokens(indexJsonPath)];
+            List<JToken>? responseContentScoreTokens = [.. responseContentJson.SelectTokens(scoreJsonPath)];
+            if (responseContentIndexTokens is null || responseContentIndexTokens.Count == 0
+                || responseContentScoreTokens is null || responseContentScoreTokens.Count == 0)
+            {
+                if (responseContentJson.TryGetValue("error", out JToken? errorMessageJson) && errorMessageJson is not null)
+                {
+                    string errorMessage = (string?)errorMessageJson.Value<string>("message") ?? "";
+                    string errorCode = (string?)errorMessageJson.Value<string>("code") ?? "";
+                    string errorType = (string?)errorMessageJson.Value<string>("type") ?? "";
+                    _logger.LogError("Unable to retrieve reranking results due to error: {errorCode} - {errorMessage} - {errorType}", [errorCode, errorMessage, errorType]);
+                    throw new Exception($"Unable to retrieve reranking results due to error: {errorMessage}");
+                }
+                else
+                {
+                    _logger.LogError("Unable to select tokens using JSONPath {indexJsonPath} for string: {responseContent}.", [indexJsonPath, responseContent]);
+                    throw new JSONPathSelectionException(indexJsonPath, responseContent);
+                }
+            }
+            IEnumerable<int> indices = responseContentIndexTokens.Select(token => token.ToObject<int>());
+            IEnumerable<float> scores = responseContentScoreTokens.Select(token => token.ToObject<float>());
+            // Materialize inside the try block so that token conversion errors are logged by the catch below
+            List<(int index, float score)> zipped = [.. indices.Zip(scores, (index, score) => (index, score))];
+
+            return zipped;
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError("Unable to parse the response to valid reranking results. {ex.Message}", [ex.Message]);
+            throw;
+        }
+    }
+
     public string[] GetModels()
     {
         var aIProviders = AiProvidersConfiguration;
@@ -239,4 +352,18 @@ public class OpenAIEmbedRequestBody : IEmbedRequestBody
 {
     public required string model { get; set; }
     public required string[] input { get; set; }
+}
+
+
+/// <summary>Marker interface for handler-specific rerank request payloads</summary>
+public interface IRerankRequestBody { }
+
+
+/// <summary>Request payload serialized for the "openai" handler's /v1/rerank endpoint</summary>
+public class OpenAIRerankRequestBody : IRerankRequestBody
+{
+    public required string model { get; set; }
+    public required string query { get; set; }
+    public required int top_n { get; set; }
+    public required string[] documents { get; set; }
 }
\ No newline at end of file
diff --git a/src/Server/Controllers/SearchdomainController.cs b/src/Server/Controllers/SearchdomainController.cs
index 436e32d..a6ad482 100644
--- a/src/Server/Controllers/SearchdomainController.cs
+++ b/src/Server/Controllers/SearchdomainController.cs
@@ -148,6 +148,82 @@ public class SearchdomainController : ControllerBase
         return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
     }
 
+
+
+    /// <summary>
+    /// Executes a query in the searchdomain and reranks the result using a specified reranker
+    /// </summary>
+    /// <param name="searchdomain">Name of the searchdomain</param>
+    /// <param name="query">Query to execute</param>
+    /// <param name="rerankerModel">Model URI of the reranker; passed on to the AI provider</param>
+    /// <param name="topN">Return only the top N results</param>
+    /// <param name="topNRetrieval">Result count passed to the searchdomain search that runs before reranking</param>
+    /// <param name="probMethod">Method used to combine an entity's per-attribute rerank scores into a single value</param>
+    /// <param name="returnAttributes">Return the attributes of the object</param>
+    [HttpPost("QueryReranked")]
+    public ActionResult<EntityRerankResults> QueryReranked([Required]string searchdomain, [Required]string query, [Required]string rerankerModel, int topN, int topNRetrieval, ProbMethodEnum probMethod = ProbMethodEnum.HVEWAvg, bool returnAttributes = false)
+    {
+        (Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
+        if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
+        List<(float, string)> results = searchdomain_.Search(query, topNRetrieval);
+        List<(string Name, Dictionary<string, string> Attributes)> queryResults = [.. results.Select(r => (
+            Name: r.Item2,
+            Attributes: searchdomain_.EntityCache[r.Item2]?.Attributes ?? []
+        ))];
+
+        // Key: Attribute name
+        Dictionary<string, List<(string EntityName, string AttributeValue)>> resultsByAttribute = [];
+        queryResults.ForEach(r =>
+        {
+            foreach (var kv in r.Attributes)
+            {
+                if (!resultsByAttribute.TryGetValue(kv.Key, out List<(string EntityName, string AttributeValue)>? values) || values is null)
+                {
+                    values = [];
+                    resultsByAttribute[kv.Key] = values;
+                }
+                values.Add((r.Name, kv.Value));
+            }
+        });
+
+        // Key: EntityName
+        Dictionary<string, List<(string attribute, float score)>> scoresByEntity = [];
+        foreach (var kv in resultsByAttribute)
+        {
+            string attributeName = kv.Key;
+            List<(string EntityName, string AttributeValue)> nameValuePairs = kv.Value;
+
+            string[] documents = [.. nameValuePairs.Select(r => r.AttributeValue)];
+            foreach ((int index, float score) in searchdomain_.AiProvider.Rerank(rerankerModel, query, documents, topN))
+            {
+                // The reranker's index is used as the position of the document within the submitted batch
+                string entityName = nameValuePairs[index].EntityName;
+                if (!scoresByEntity.TryGetValue(entityName, out List<(string attribute, float score)>? values) || values is null)
+                {
+                    values = [];
+                    scoresByEntity[entityName] = values;
+                }
+                values.Add((attributeName, score));
+            }
+        }
+
+        // One combiner instance serves all entities of this request
+        var scoreCombiner = new ProbMethod(probMethod);
+        // Highest combined score first
+        IEnumerable<EntityRerankResult> rankedResults = scoresByEntity
+            .Select(scoreKV => new EntityRerankResult()
+            {
+                Name = scoreKV.Key,
+                Value = scoreCombiner.Method(scoreKV.Value),
+                Attributes = returnAttributes ? (searchdomain_.EntityCache[scoreKV.Key]?.Attributes ?? []) : null
+            })
+            .OrderByDescending(result => result.Value);
+        // Apply the top-N limit to the merged result; a non-positive topN leaves the result unlimited
+        if (topN > 0) rankedResults = rankedResults.Take(topN);
+
+        return Ok(new EntityRerankResults(){Results = [.. rankedResults], Success = true });
+    }
+
     ///
     /// Executes a query in the searchdomain
     ///
diff --git a/src/Shared/Models/EntityResults.cs b/src/Shared/Models/EntityResults.cs
index ed9b7a4..dedbeed 100644
--- a/src/Shared/Models/EntityResults.cs
+++ b/src/Shared/Models/EntityResults.cs
@@ -20,6 +20,25 @@ public class EntityQueryResult
     public Dictionary? Attributes { get; set; }
 }
 
+/// <summary>Response model of the QueryReranked endpoint</summary>
+public class EntityRerankResults : SuccesMessageBaseModel
+{
+    [JsonPropertyName("Results")]
+    public required List<EntityRerankResult> Results { get; set; }
+}
+
+/// <summary>A single reranked entity: its name, combined score and (optionally) its attributes</summary>
+public class EntityRerankResult
+{
+    [JsonPropertyName("Name")]
+    public required string Name { get; set; }
+    [JsonPropertyName("Value")]
+    public float Value { get; set; }
+    [JsonPropertyName("Attributes")]
+    [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+    public Dictionary<string, string>? Attributes { get; set; }
+}
+
 public class EntityIndexResult : SuccesMessageBaseModel {}
 
 public class EntityListResults