Added filtering query output to top n elements

This commit is contained in:
2025-12-20 20:13:28 +01:00
parent 1375e45f59
commit e4a711fcbd
2 changed files with 10 additions and 5 deletions

View File

@@ -25,7 +25,7 @@ public class EntityController : ControllerBase
} }
[HttpGet("Query")] [HttpGet("Query")]
public ActionResult<EntityQueryResults> Query(string searchdomain, string query) public ActionResult<EntityQueryResults> Query(string searchdomain, string query, int? topN)
{ {
Searchdomain searchdomain_; Searchdomain searchdomain_;
try try
@@ -40,7 +40,7 @@ public class EntityController : ControllerBase
_logger.LogError("Unable to retrieve the searchdomain {searchdomain} - {ex.Message} - {ex.StackTrace}", [searchdomain, ex.Message, ex.StackTrace]); _logger.LogError("Unable to retrieve the searchdomain {searchdomain} - {ex.Message} - {ex.StackTrace}", [searchdomain, ex.Message, ex.StackTrace]);
return Ok(new EntityQueryResults() {Results = [], Success = false, Message = "Unable to retrieve the searchdomain - it likely exists, but some other error happened." }); return Ok(new EntityQueryResults() {Results = [], Success = false, Message = "Unable to retrieve the searchdomain - it likely exists, but some other error happened." });
} }
var results = searchdomain_.Search(query); List<(float, string)> results = searchdomain_.Search(query, topN);
List<EntityQueryResult> queryResults = [.. results.Select(r => new EntityQueryResult List<EntityQueryResult> queryResults = [.. results.Select(r => new EntityQueryResult
{ {
Name = r.Item2, Name = r.Item2,

View File

@@ -154,7 +154,7 @@ public class Searchdomain
embeddingCache = []; // TODO remove this and implement proper remediation to improve performance embeddingCache = []; // TODO remove this and implement proper remediation to improve performance
} }
public List<(float, string)> Search(string query) public List<(float, string)> Search(string query, int? topN = null)
{ {
if (searchCache.TryGetValue(query, out DateTimedSearchResult cachedResult)) if (searchCache.TryGetValue(query, out DateTimedSearchResult cachedResult))
{ {
@@ -190,9 +190,14 @@ public class Searchdomain
} }
result.Add((entity.probMethod(datapointProbs), entity.name)); result.Add((entity.probMethod(datapointProbs), entity.name));
} }
List<(float, string)> results = [.. result.OrderByDescending(s => s.Item1)]; IEnumerable<(float, string)> sortedResults = result.OrderByDescending(s => s.Item1);
if (topN is not null)
{
sortedResults = sortedResults.Take(topN ?? 0);
}
List<(float, string)> results = [.. sortedResults];
List<ResultItem> searchResult = new( List<ResultItem> searchResult = new(
[.. results.Select(r => [.. sortedResults.Select(r =>
new ResultItem(r.Item1, r.Item2 ))] new ResultItem(r.Item1, r.Item2 ))]
); );
searchCache[query] = new DateTimedSearchResult(DateTime.Now, searchResult); searchCache[query] = new DateTimedSearchResult(DateTime.Now, searchResult);