Added reranker exploration setup
This commit is contained in:
@@ -160,6 +160,21 @@ public class Client
|
|||||||
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null);
|
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async Task<EntityQueryResults> SearchdomainQueryRerankedAsync(string searchdomain, string query, string rerankerModel, int topN, int topNRetrieval, bool returnAttributes = false)
|
||||||
|
{
|
||||||
|
Dictionary<string, string> parameters = new()
|
||||||
|
{
|
||||||
|
{ "searchdomain", searchdomain },
|
||||||
|
{ "query", query },
|
||||||
|
{ "rerankerModel", (rerankerModel).ToString() },
|
||||||
|
{ "topN", (topN).ToString() },
|
||||||
|
{ "topNRetrieval", (topNRetrieval).ToString() }
|
||||||
|
};
|
||||||
|
if (returnAttributes) parameters.Add("returnAttributes", returnAttributes.ToString());
|
||||||
|
|
||||||
|
return await FetchUrlAndProcessJson<EntityQueryResults>(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "QueryReranked", parameters), null);
|
||||||
|
}
|
||||||
|
|
||||||
public async Task<SearchdomainDeleteSearchResult> SearchdomainDeleteQueryAsync(string searchdomain, string query)
|
public async Task<SearchdomainDeleteSearchResult> SearchdomainDeleteQueryAsync(string searchdomain, string query)
|
||||||
{
|
{
|
||||||
Dictionary<string, string> parameters = new()
|
Dictionary<string, string> parameters = new()
|
||||||
|
|||||||
@@ -132,6 +132,107 @@ public class AIProvider
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public IEnumerable<(int index, float score)> Rerank(string modelUri, string input, string[] documents, int topN)
|
||||||
|
{
|
||||||
|
Uri uri = new(modelUri);
|
||||||
|
string provider = uri.Scheme;
|
||||||
|
string model = uri.AbsolutePath;
|
||||||
|
AiProvider? aIProvider = AiProvidersConfiguration
|
||||||
|
.FirstOrDefault(x => string.Equals(x.Key.ToLower(), provider.ToLower()))
|
||||||
|
.Value;
|
||||||
|
if (aIProvider is null)
|
||||||
|
{
|
||||||
|
_logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", [provider, modelUri]);
|
||||||
|
throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
|
||||||
|
}
|
||||||
|
using var httpClient = new HttpClient();
|
||||||
|
httpClient.Timeout = TimeSpan.FromMinutes(150);
|
||||||
|
|
||||||
|
string indexJsonPath = "";
|
||||||
|
string scoreJsonPath = "";
|
||||||
|
IEnumerable<(string, float)> values = [];
|
||||||
|
Uri baseUri = new(aIProvider.BaseURL);
|
||||||
|
Uri requestUri;
|
||||||
|
IRerankRequestBody rerankRequestBody;
|
||||||
|
string[][] requestHeaders = [];
|
||||||
|
switch (aIProvider.Handler)
|
||||||
|
{
|
||||||
|
case "openai":
|
||||||
|
indexJsonPath = "$.results[*].index";
|
||||||
|
scoreJsonPath = "$.results[*].relevance_score";
|
||||||
|
requestUri = new Uri(baseUri, "/v1/rerank");
|
||||||
|
rerankRequestBody = new OpenAIRerankRequestBody()
|
||||||
|
{
|
||||||
|
model = model,
|
||||||
|
query = input,
|
||||||
|
documents = documents,
|
||||||
|
top_n = topN
|
||||||
|
};
|
||||||
|
if (aIProvider.ApiKey is not null)
|
||||||
|
{
|
||||||
|
requestHeaders = [
|
||||||
|
["Authorization", $"Bearer {aIProvider.ApiKey}"]
|
||||||
|
];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
_logger.LogError("Invalid reranking handler {aIProvider.Handler} in AiProvider {provider}.", [aIProvider.Handler, provider]);
|
||||||
|
throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
|
||||||
|
}
|
||||||
|
var requestContent = new StringContent(
|
||||||
|
JsonConvert.SerializeObject(rerankRequestBody),
|
||||||
|
Encoding.UTF8,
|
||||||
|
"application/json"
|
||||||
|
);
|
||||||
|
|
||||||
|
var request = new HttpRequestMessage()
|
||||||
|
{
|
||||||
|
RequestUri = requestUri,
|
||||||
|
Method = HttpMethod.Post,
|
||||||
|
Content = requestContent
|
||||||
|
};
|
||||||
|
|
||||||
|
foreach (var header in requestHeaders)
|
||||||
|
{
|
||||||
|
request.Headers.Add(header[0], header[1]);
|
||||||
|
}
|
||||||
|
HttpResponseMessage response = httpClient.PostAsync(requestUri, requestContent).Result;
|
||||||
|
string responseContent = response.Content.ReadAsStringAsync().Result;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
JObject responseContentJson = JObject.Parse(responseContent);
|
||||||
|
List<JToken>? responseContentIndexTokens = [.. responseContentJson.SelectTokens(indexJsonPath)];
|
||||||
|
List<JToken>? responseContentScoreTokens = [.. responseContentJson.SelectTokens(scoreJsonPath)];
|
||||||
|
if (responseContentIndexTokens is null || responseContentIndexTokens.Count == 0
|
||||||
|
|| responseContentScoreTokens is null || responseContentScoreTokens.Count == 0)
|
||||||
|
{
|
||||||
|
if (responseContentJson.TryGetValue("error", out JToken? errorMessageJson) && errorMessageJson is not null)
|
||||||
|
{
|
||||||
|
string errorMessage = (string?)errorMessageJson.Value<string>("message") ?? "";
|
||||||
|
string errorCode = (string?)errorMessageJson.Value<string>("code") ?? "";
|
||||||
|
string errorType = (string?)errorMessageJson.Value<string>("type") ?? "";
|
||||||
|
_logger.LogError("Unable to retrieve reranking results due to error: {errorCode} - {errorMessage} - {errorType}", [errorCode, errorMessage, errorType]);
|
||||||
|
throw new Exception($"Unable to retrieve reranking results due to error: {errorMessage}");
|
||||||
|
|
||||||
|
} else
|
||||||
|
{
|
||||||
|
_logger.LogError("Unable to select tokens using JSONPath {indexJsonPath} for string: {responseContent}.", [indexJsonPath, responseContent]);
|
||||||
|
throw new JSONPathSelectionException(indexJsonPath, responseContent);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
IEnumerable<int> indices = responseContentIndexTokens.Select(token => token.ToObject<int>());
|
||||||
|
IEnumerable<float> scores = responseContentScoreTokens.Select(token => token.ToObject<float>());
|
||||||
|
IEnumerable<(int index, float score)> zipped = indices.Zip(scores, (index, score) => (index, score));
|
||||||
|
|
||||||
|
return zipped;
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
_logger.LogError("Unable to parse the response to valid embeddings. {ex.Message}", [ex.Message]);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public string[] GetModels()
|
public string[] GetModels()
|
||||||
{
|
{
|
||||||
var aIProviders = AiProvidersConfiguration;
|
var aIProviders = AiProvidersConfiguration;
|
||||||
@@ -239,4 +340,16 @@ public class OpenAIEmbedRequestBody : IEmbedRequestBody
|
|||||||
{
|
{
|
||||||
public required string model { get; set; }
|
public required string model { get; set; }
|
||||||
public required string[] input { get; set; }
|
public required string[] input { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public interface IRerankRequestBody { }
|
||||||
|
|
||||||
|
|
||||||
|
public class OpenAIRerankRequestBody : IRerankRequestBody
|
||||||
|
{
|
||||||
|
public required string model { get; set; }
|
||||||
|
public required string query { get; set; }
|
||||||
|
public required int top_n { get; set; }
|
||||||
|
public required string[] documents { get; set; }
|
||||||
}
|
}
|
||||||
@@ -148,6 +148,78 @@ public class SearchdomainController : ControllerBase
|
|||||||
return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
|
return Ok(new SearchdomainQueriesResults() { Searches = searchCache, Success = true });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Executes a query in the searchdomain and reranks the result using a specified reranker
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="searchdomain">Name of the searchdomain</param>
|
||||||
|
/// <param name="query">Query to execute</param>
|
||||||
|
/// <param name="topN">Return only the top N results</param>
|
||||||
|
/// <param name="returnAttributes">Return the attributes of the object</param>
|
||||||
|
[HttpPost("QueryReranked")]
|
||||||
|
public ActionResult<EntityRerankResults> QueryReranked([Required]string searchdomain, [Required]string query, [Required]string rerankerModel, int topN, int topNRetrieval, ProbMethodEnum probMethod = ProbMethodEnum.HVEWAvg, bool returnAttributes = false)
|
||||||
|
{
|
||||||
|
|
||||||
|
(Searchdomain? searchdomain_, int? httpStatusCode, string? message) = SearchdomainHelper.TryGetSearchdomain(_domainManager, searchdomain, _logger);
|
||||||
|
if (searchdomain_ is null || httpStatusCode is not null) return StatusCode(httpStatusCode ?? 500, new SearchdomainUpdateResults(){Success = false, Message = message});
|
||||||
|
List<(float, string)> results = searchdomain_.Search(query, topNRetrieval);
|
||||||
|
List<(string Name, Dictionary<string, string> Attributes)> queryResults = [.. results.Select(r => (
|
||||||
|
Name: r.Item2,
|
||||||
|
Attributes: searchdomain_.EntityCache[r.Item2]?.Attributes ?? []
|
||||||
|
))];
|
||||||
|
|
||||||
|
|
||||||
|
// Key: Attribute name
|
||||||
|
Dictionary<string, List<(string EntityName, string AttributeValue)>> resultsByAttribute = [];
|
||||||
|
queryResults.ForEach(r =>
|
||||||
|
{
|
||||||
|
foreach (var kv in r.Attributes)
|
||||||
|
{
|
||||||
|
if (!resultsByAttribute.TryGetValue(kv.Key, out List<(string EntityName, string AttributeValue)>? values) || values is null)
|
||||||
|
{
|
||||||
|
values = [];
|
||||||
|
resultsByAttribute[kv.Key] = values;
|
||||||
|
}
|
||||||
|
values.Add((r.Name, kv.Value));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Key: EntityName
|
||||||
|
Dictionary<string, List<(string attribute, float score)>> scoresByEntity = [];
|
||||||
|
foreach (var kv in resultsByAttribute)
|
||||||
|
{
|
||||||
|
string attributeName = kv.Key;
|
||||||
|
List<(string EntityName, string AttributeValue)> nameValuePairs = kv.Value;
|
||||||
|
|
||||||
|
List<string> documents = [.. nameValuePairs.Select(r => r.AttributeValue)];
|
||||||
|
List<(int index, float score)> rerankResults = [.. searchdomain_.AiProvider.Rerank(rerankerModel, query, [.. documents], topN)];
|
||||||
|
List<(string entityName, float score)> rerankedScores = [.. rerankResults.Select(r => (nameValuePairs.ElementAt(r.index).EntityName, r.score))];
|
||||||
|
foreach ((string entityName, float score) in rerankedScores)
|
||||||
|
{
|
||||||
|
if (!scoresByEntity.TryGetValue(entityName, out List<(string attribute, float score)>? values) || values is null)
|
||||||
|
{
|
||||||
|
values = [];
|
||||||
|
scoresByEntity[entityName] = values;
|
||||||
|
}
|
||||||
|
values.Add((attributeName, score));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
List<EntityRerankResult> entityRerankResults = [.. scoresByEntity.Select(scoreKV =>
|
||||||
|
{
|
||||||
|
string entityName = scoreKV.Key;
|
||||||
|
float score = new ProbMethod(probMethod).Method(scoreKV.Value);
|
||||||
|
return new EntityRerankResult()
|
||||||
|
{
|
||||||
|
Name = entityName,
|
||||||
|
Value = score,
|
||||||
|
Attributes = returnAttributes ? (searchdomain_.EntityCache[entityName]?.Attributes ?? []) : null
|
||||||
|
};
|
||||||
|
})];
|
||||||
|
|
||||||
|
return Ok(new EntityRerankResults(){Results = entityRerankResults, Success = true });
|
||||||
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Executes a query in the searchdomain
|
/// Executes a query in the searchdomain
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
|||||||
@@ -20,6 +20,23 @@ public class EntityQueryResult
|
|||||||
public Dictionary<string, string>? Attributes { get; set; }
|
public Dictionary<string, string>? Attributes { get; set; }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class EntityRerankResults : SuccesMessageBaseModel
|
||||||
|
{
|
||||||
|
[JsonPropertyName("Results")]
|
||||||
|
public required List<EntityRerankResult> Results { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class EntityRerankResult
|
||||||
|
{
|
||||||
|
[JsonPropertyName("Name")]
|
||||||
|
public required string Name { get; set; }
|
||||||
|
[JsonPropertyName("Value")]
|
||||||
|
public float Value { get; set; }
|
||||||
|
[JsonPropertyName("Attributes")]
|
||||||
|
[JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
|
||||||
|
public Dictionary<string, string>? Attributes { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
public class EntityIndexResult : SuccesMessageBaseModel {}
|
public class EntityIndexResult : SuccesMessageBaseModel {}
|
||||||
|
|
||||||
public class EntityListResults
|
public class EntityListResults
|
||||||
|
|||||||
Reference in New Issue
Block a user