Added AIProvider and support for OpenAI compatible APIs
This commit is contained in:
@@ -8,11 +8,11 @@ example_content = "./Scripts/example_content"
|
|||||||
probmethod = "LVEWAvg"
|
probmethod = "LVEWAvg"
|
||||||
example_searchdomain = "example_" + probmethod
|
example_searchdomain = "example_" + probmethod
|
||||||
example_counter = 0
|
example_counter = 0
|
||||||
models = ["bge-m3", "mxbai-embed-large"]
|
models = ["ollama:bge-m3", "ollama:mxbai-embed-large"]
|
||||||
probmethod_datapoint = probmethod
|
probmethod_datapoint = probmethod
|
||||||
probmethod_entity = probmethod
|
probmethod_entity = probmethod
|
||||||
# Example for a dictionary based weighted average:
|
# Example for a dictionary based weighted average:
|
||||||
# probmethod_datapoint = "DictionaryWeightedAverage:{\"bge-m3\": 4, \"mxbai-embed-large\": 1}"
|
# probmethod_datapoint = "DictionaryWeightedAverage:{\"ollama:bge-m3\": 4, \"ollama:mxbai-embed-large\": 1}"
|
||||||
# probmethod_entity = "DictionaryWeightedAverage:{\"title\": 2, \"filename\": 0.1, \"text\": 0.25}"
|
# probmethod_entity = "DictionaryWeightedAverage:{\"title\": 2, \"filename\": 0.1, \"text\": 0.25}"
|
||||||
|
|
||||||
def init(toolset: Toolset):
|
def init(toolset: Toolset):
|
||||||
|
|||||||
147
src/Server/AIProvider.cs
Normal file
147
src/Server/AIProvider.cs
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
|
||||||
|
using System.Collections.Frozen;
|
||||||
|
using System.Collections.Immutable;
|
||||||
|
using System.Text;
|
||||||
|
using System.Text.Json;
|
||||||
|
using System.Text.Json.Serialization;
|
||||||
|
using Newtonsoft.Json;
|
||||||
|
using Newtonsoft.Json.Linq;
|
||||||
|
using Server.Exceptions;
|
||||||
|
|
||||||
|
namespace server;
|
||||||
|
|
||||||
|
public class AIProvider
|
||||||
|
{
|
||||||
|
private readonly ILogger<AIProvider> _logger;
|
||||||
|
private readonly IConfiguration _configuration;
|
||||||
|
public AIProvidersConfiguration aIProvidersConfiguration;
|
||||||
|
|
||||||
|
public AIProvider(ILogger<AIProvider> logger, IConfiguration configuration)
|
||||||
|
{
|
||||||
|
_logger = logger;
|
||||||
|
_configuration = configuration;
|
||||||
|
AIProvidersConfiguration? retrievedAiProvidersConfiguration = _configuration
|
||||||
|
.GetSection("Embeddingsearch")
|
||||||
|
.Get<AIProvidersConfiguration>();
|
||||||
|
if (retrievedAiProvidersConfiguration is null)
|
||||||
|
{
|
||||||
|
_logger.LogCritical("Unable to build AIProvidersConfiguration. Please check your configuration.");
|
||||||
|
throw new ServerConfigurationException("Unable to build AIProvidersConfiguration. Please check your configuration.");
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
aIProvidersConfiguration = retrievedAiProvidersConfiguration;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public float[] GenerateEmbeddings(string modelUri, string[] input)
|
||||||
|
{
|
||||||
|
Uri uri = new(modelUri);
|
||||||
|
string provider = uri.Scheme;
|
||||||
|
string model = uri.AbsolutePath;
|
||||||
|
AIProviderConfiguration? aIProvider = aIProvidersConfiguration.AiProviders
|
||||||
|
.FirstOrDefault(x => String.Equals(x.Key.ToLower(), provider.ToLower()))
|
||||||
|
.Value;
|
||||||
|
if (aIProvider is null)
|
||||||
|
{
|
||||||
|
_logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", [provider, modelUri]);
|
||||||
|
throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
|
||||||
|
}
|
||||||
|
using var httpClient = new HttpClient();
|
||||||
|
|
||||||
|
string embeddingsJsonPath = "";
|
||||||
|
Uri baseUri = new(aIProvider.BaseURL);
|
||||||
|
Uri requestUri;
|
||||||
|
IEmbedRequestBody embedRequest;
|
||||||
|
string[][] requestHeaders = [];
|
||||||
|
switch (aIProvider.Handler)
|
||||||
|
{
|
||||||
|
case "ollama":
|
||||||
|
embeddingsJsonPath = "$.embeddings[*]";
|
||||||
|
requestUri = new Uri(baseUri, "/api/embed");
|
||||||
|
embedRequest = new OllamaEmbedRequestBody()
|
||||||
|
{
|
||||||
|
input = input,
|
||||||
|
model = model
|
||||||
|
};
|
||||||
|
break;
|
||||||
|
case "openai":
|
||||||
|
embeddingsJsonPath = "$.data[*].embedding";
|
||||||
|
requestUri = new Uri(baseUri, "/v1/embeddings");
|
||||||
|
embedRequest = new OpenAIEmbedRequestBody()
|
||||||
|
{
|
||||||
|
input = input,
|
||||||
|
model = model
|
||||||
|
};
|
||||||
|
if (aIProvider.ApiKey is not null)
|
||||||
|
{
|
||||||
|
requestHeaders = [
|
||||||
|
["Authorization", $"Bearer {aIProvider.ApiKey}"]
|
||||||
|
];
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
_logger.LogError("Unknown handler {aIProvider.Handler} in AiProvider {provider}.", [aIProvider.Handler, provider]);
|
||||||
|
throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
|
||||||
|
}
|
||||||
|
var requestContent = new StringContent(
|
||||||
|
JsonConvert.SerializeObject(embedRequest),
|
||||||
|
UnicodeEncoding.UTF8,
|
||||||
|
"application/json"
|
||||||
|
);
|
||||||
|
|
||||||
|
var request = new HttpRequestMessage()
|
||||||
|
{
|
||||||
|
RequestUri = requestUri,
|
||||||
|
Method = HttpMethod.Post,
|
||||||
|
Content = requestContent
|
||||||
|
};
|
||||||
|
|
||||||
|
foreach (var header in requestHeaders)
|
||||||
|
{
|
||||||
|
request.Headers.Add(header[0], header[1]);
|
||||||
|
}
|
||||||
|
HttpResponseMessage response = httpClient.PostAsync(requestUri, requestContent).Result;
|
||||||
|
string responseContent = response.Content.ReadAsStringAsync().Result;
|
||||||
|
try
|
||||||
|
{
|
||||||
|
JObject responseContentJson = JObject.Parse(responseContent);
|
||||||
|
JToken? responseContentTokens = responseContentJson.SelectToken(embeddingsJsonPath);
|
||||||
|
if (responseContentTokens is null)
|
||||||
|
{
|
||||||
|
throw new Exception($"Unable to select tokens using JSONPath {embeddingsJsonPath}."); // TODO add proper exception
|
||||||
|
}
|
||||||
|
return [.. responseContentTokens.Values<float>()];
|
||||||
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
_logger.LogError("Unable to parse the response to valid embeddings. {ex.Message}", [ex.Message]);
|
||||||
|
throw new Exception($"Unable to parse the response to valid embeddings. {ex.Message}"); // TODO add proper exception
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class AIProvidersConfiguration
|
||||||
|
{
|
||||||
|
public required Dictionary<string, AIProviderConfiguration> AiProviders { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class AIProviderConfiguration
|
||||||
|
{
|
||||||
|
public required string Handler { get; set; }
|
||||||
|
public required string BaseURL { get; set; }
|
||||||
|
public string? ApiKey { get; set; }
|
||||||
|
}
|
||||||
|
public interface IEmbedRequestBody { }
|
||||||
|
|
||||||
|
public class OllamaEmbedRequestBody : IEmbedRequestBody
|
||||||
|
{
|
||||||
|
public required string model { get; set; }
|
||||||
|
public required string[] input { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class OpenAIEmbedRequestBody : IEmbedRequestBody
|
||||||
|
{
|
||||||
|
public required string model { get; set; }
|
||||||
|
public required string[] input { get; set; }
|
||||||
|
}
|
||||||
@@ -46,7 +46,7 @@ public class EntityController : ControllerBase
|
|||||||
List<Entity>? entities = SearchdomainHelper.EntitiesFromJSON(
|
List<Entity>? entities = SearchdomainHelper.EntitiesFromJSON(
|
||||||
[],
|
[],
|
||||||
_domainManager.embeddingCache,
|
_domainManager.embeddingCache,
|
||||||
_domainManager.client,
|
_domainManager.aIProvider,
|
||||||
_domainManager.helper,
|
_domainManager.helper,
|
||||||
_logger,
|
_logger,
|
||||||
JsonSerializer.Serialize(jsonEntities));
|
JsonSerializer.Serialize(jsonEntities));
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ using System.Threading.Tasks;
|
|||||||
using Microsoft.Extensions.AI;
|
using Microsoft.Extensions.AI;
|
||||||
using OllamaSharp;
|
using OllamaSharp;
|
||||||
using OllamaSharp.Models;
|
using OllamaSharp.Models;
|
||||||
|
using server;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -29,9 +30,9 @@ public class Datapoint
|
|||||||
return probMethod.method(probabilities);
|
return probMethod.method(probabilities);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
|
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
|
||||||
{
|
{
|
||||||
return GenerateEmbeddings(content, models, ollama, []);
|
return GenerateEmbeddings(content, models, aIProvider, []);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(List<string> contents, string model, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
public static Dictionary<string, float[]> GenerateEmbeddings(List<string> contents, string model, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
||||||
@@ -78,7 +79,7 @@ public class Datapoint
|
|||||||
return retVal;
|
return retVal;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
||||||
{
|
{
|
||||||
Dictionary<string, float[]> retVal = [];
|
Dictionary<string, float[]> retVal = [];
|
||||||
foreach (string model in models)
|
foreach (string model in models)
|
||||||
@@ -88,24 +89,17 @@ public class Datapoint
|
|||||||
retVal[model] = embeddingCache[model][content];
|
retVal[model] = embeddingCache[model][content];
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
EmbedRequest request = new()
|
var response = aIProvider.GenerateEmbeddings(model, [content]);
|
||||||
{
|
|
||||||
Model = model,
|
|
||||||
Input = [content]
|
|
||||||
};
|
|
||||||
|
|
||||||
var response = ollama.EmbedAsync(request).Result;
|
|
||||||
if (response is not null)
|
if (response is not null)
|
||||||
{
|
{
|
||||||
float[] var = [.. response.Embeddings.First()];
|
retVal[model] = response;
|
||||||
retVal[model] = var;
|
|
||||||
if (!embeddingCache.ContainsKey(model))
|
if (!embeddingCache.ContainsKey(model))
|
||||||
{
|
{
|
||||||
embeddingCache[model] = [];
|
embeddingCache[model] = [];
|
||||||
}
|
}
|
||||||
if (!embeddingCache[model].ContainsKey(content))
|
if (!embeddingCache[model].ContainsKey(content))
|
||||||
{
|
{
|
||||||
embeddingCache[model][content] = var;
|
embeddingCache[model][content] = response;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ public class AIProviderHealthCheck : IHealthCheck
|
|||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result;
|
//var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result; // TODO reimplement this
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ using System.Text;
|
|||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
using MySql.Data.MySqlClient;
|
using MySql.Data.MySqlClient;
|
||||||
using OllamaSharp;
|
using OllamaSharp;
|
||||||
|
using server;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -41,7 +42,7 @@ public static class SearchdomainHelper
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json)
|
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, string json)
|
||||||
{
|
{
|
||||||
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
|
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
|
||||||
if (jsonEntities is null)
|
if (jsonEntities is null)
|
||||||
@@ -65,10 +66,11 @@ public static class SearchdomainHelper
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
ConcurrentQueue<Entity> retVal = [];
|
ConcurrentQueue<Entity> retVal = [];
|
||||||
Parallel.ForEach(jsonEntities, jSONEntity =>
|
ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool
|
||||||
|
Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity =>
|
||||||
{
|
{
|
||||||
using var tempHelper = helper.DuplicateConnection();
|
using var tempHelper = helper.DuplicateConnection();
|
||||||
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity);
|
var entity = EntityFromJSON(entityCache, embeddingCache, aIProvider, tempHelper, logger, jSONEntity);
|
||||||
if (entity is not null)
|
if (entity is not null)
|
||||||
{
|
{
|
||||||
retVal.Enqueue(entity);
|
retVal.Enqueue(entity);
|
||||||
@@ -77,7 +79,7 @@ public static class SearchdomainHelper
|
|||||||
return [.. retVal];
|
return [.. retVal];
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
||||||
{
|
{
|
||||||
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
||||||
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
||||||
@@ -130,14 +132,14 @@ public static class SearchdomainHelper
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], ollama, embeddingCache);
|
var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], aIProvider, embeddingCache);
|
||||||
embeddings.Add(model, additionalEmbeddings.First().Value);
|
embeddings.Add(model, additionalEmbeddings.First().Value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
|
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache);
|
||||||
}
|
}
|
||||||
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
||||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
using ElmahCore;
|
using ElmahCore;
|
||||||
using ElmahCore.Mvc;
|
using ElmahCore.Mvc;
|
||||||
using Serilog;
|
using Serilog;
|
||||||
|
using server;
|
||||||
using Server;
|
using Server;
|
||||||
using Server.HealthChecks;
|
using Server.HealthChecks;
|
||||||
|
|
||||||
@@ -17,6 +18,7 @@ Log.Logger = new LoggerConfiguration()
|
|||||||
.CreateLogger();
|
.CreateLogger();
|
||||||
builder.Logging.AddSerilog();
|
builder.Logging.AddSerilog();
|
||||||
builder.Services.AddSingleton<SearchdomainManager>();
|
builder.Services.AddSingleton<SearchdomainManager>();
|
||||||
|
builder.Services.AddSingleton<AIProvider>();
|
||||||
builder.Services.AddHealthChecks()
|
builder.Services.AddHealthChecks()
|
||||||
.AddCheck<DatabaseHealthCheck>("DatabaseHealthCheck")
|
.AddCheck<DatabaseHealthCheck>("DatabaseHealthCheck")
|
||||||
.AddCheck<AIProviderHealthCheck>("AIProviderHealthChecck");
|
.AddCheck<AIProviderHealthCheck>("AIProviderHealthChecck");
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ using Server;
|
|||||||
using System.Security.Cryptography;
|
using System.Security.Cryptography;
|
||||||
using System.Text;
|
using System.Text;
|
||||||
using System.Collections.Concurrent;
|
using System.Collections.Concurrent;
|
||||||
|
using server;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ public class Searchdomain
|
|||||||
{
|
{
|
||||||
private readonly string _connectionString;
|
private readonly string _connectionString;
|
||||||
private readonly string _provider;
|
private readonly string _provider;
|
||||||
public OllamaApiClient ollama;
|
public AIProvider aIProvider;
|
||||||
public string searchdomain;
|
public string searchdomain;
|
||||||
public int id;
|
public int id;
|
||||||
public Dictionary<string, List<(DateTime, List<(float, string)>)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x
|
public Dictionary<string, List<(DateTime, List<(float, string)>)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x
|
||||||
@@ -42,12 +43,12 @@ public class Searchdomain
|
|||||||
|
|
||||||
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
|
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
|
||||||
|
|
||||||
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
||||||
{
|
{
|
||||||
_connectionString = connectionString;
|
_connectionString = connectionString;
|
||||||
_provider = provider.ToLower();
|
_provider = provider.ToLower();
|
||||||
this.searchdomain = searchdomain;
|
this.searchdomain = searchdomain;
|
||||||
this.ollama = ollama;
|
this.aIProvider = aIProvider;
|
||||||
this.embeddingCache = embeddingCache;
|
this.embeddingCache = embeddingCache;
|
||||||
this._logger = logger;
|
this._logger = logger;
|
||||||
searchCache = [];
|
searchCache = [];
|
||||||
@@ -69,7 +70,7 @@ public class Searchdomain
|
|||||||
{
|
{
|
||||||
["id"] = this.id
|
["id"] = this.id
|
||||||
};
|
};
|
||||||
DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain);
|
DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain); // TODO fix: parametersIDSearchdomain defined, but not used
|
||||||
Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = [];
|
Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = [];
|
||||||
while (embeddingReader.Read())
|
while (embeddingReader.Read())
|
||||||
{
|
{
|
||||||
@@ -162,7 +163,7 @@ public class Searchdomain
|
|||||||
{
|
{
|
||||||
if (!embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings))
|
if (!embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings))
|
||||||
{
|
{
|
||||||
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, ollama);
|
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider);
|
||||||
if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits
|
if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits
|
||||||
{ // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries
|
{ // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries
|
||||||
embeddingCache.Add(query, queryEmbeddings);
|
embeddingCache.Add(query, queryEmbeddings);
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ using OllamaSharp;
|
|||||||
using Microsoft.IdentityModel.Tokens;
|
using Microsoft.IdentityModel.Tokens;
|
||||||
using Server.Exceptions;
|
using Server.Exceptions;
|
||||||
using Server.Migrations;
|
using Server.Migrations;
|
||||||
|
using server;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
@@ -12,25 +13,20 @@ public class SearchdomainManager
|
|||||||
private Dictionary<string, Searchdomain> searchdomains = [];
|
private Dictionary<string, Searchdomain> searchdomains = [];
|
||||||
private readonly ILogger<SearchdomainManager> _logger;
|
private readonly ILogger<SearchdomainManager> _logger;
|
||||||
private readonly IConfiguration _config;
|
private readonly IConfiguration _config;
|
||||||
|
public readonly AIProvider aIProvider;
|
||||||
private readonly string ollamaURL;
|
private readonly string ollamaURL;
|
||||||
private readonly string connectionString;
|
private readonly string connectionString;
|
||||||
public OllamaApiClient client;
|
|
||||||
private MySqlConnection connection;
|
private MySqlConnection connection;
|
||||||
public SQLHelper helper;
|
public SQLHelper helper;
|
||||||
public Dictionary<string, Dictionary<string, float[]>> embeddingCache;
|
public Dictionary<string, Dictionary<string, float[]>> embeddingCache;
|
||||||
|
|
||||||
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config)
|
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config, AIProvider aIProvider)
|
||||||
{
|
{
|
||||||
_logger = logger;
|
_logger = logger;
|
||||||
_config = config;
|
_config = config;
|
||||||
|
this.aIProvider = aIProvider;
|
||||||
embeddingCache = [];
|
embeddingCache = [];
|
||||||
ollamaURL = _config.GetSection("Embeddingsearch")["OllamaURL"] ?? "";
|
|
||||||
connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? "";
|
connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? "";
|
||||||
if (ollamaURL.IsNullOrEmpty() || connectionString.IsNullOrEmpty())
|
|
||||||
{
|
|
||||||
throw new ServerConfigurationException("Ollama URL or connection string is empty");
|
|
||||||
}
|
|
||||||
client = new(new Uri(ollamaURL));
|
|
||||||
connection = new MySqlConnection(connectionString);
|
connection = new MySqlConnection(connectionString);
|
||||||
connection.Open();
|
connection.Open();
|
||||||
helper = new SQLHelper(connection, connectionString);
|
helper = new SQLHelper(connection, connectionString);
|
||||||
@@ -53,13 +49,18 @@ public class SearchdomainManager
|
|||||||
}
|
}
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger));
|
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, aIProvider, embeddingCache, _logger));
|
||||||
}
|
}
|
||||||
catch (MySqlException)
|
catch (MySqlException)
|
||||||
{
|
{
|
||||||
_logger.LogError("Unable to find the searchdomain {searchdomain}", searchdomain);
|
_logger.LogError("Unable to find the searchdomain {searchdomain}", searchdomain);
|
||||||
throw new Exception($"Unable to find the searchdomain {searchdomain}");
|
throw new Exception($"Unable to find the searchdomain {searchdomain}");
|
||||||
}
|
}
|
||||||
|
catch (Exception ex)
|
||||||
|
{
|
||||||
|
_logger.LogError("Unable to load the searchdomain {searchdomain} due to the following exception: {ex}", [searchdomain, ex.Message]);
|
||||||
|
throw;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void InvalidateSearchdomainCache(string searchdomainName)
|
public void InvalidateSearchdomainCache(string searchdomainName)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
<PackageReference Include="ElmahCore" Version="2.1.2" />
|
<PackageReference Include="ElmahCore" Version="2.1.2" />
|
||||||
|
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
|
||||||
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
|
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
|
||||||
<PackageReference Include="Serilog.Sinks.File" Version="7.0.0" />
|
<PackageReference Include="Serilog.Sinks.File" Version="7.0.0" />
|
||||||
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />
|
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />
|
||||||
|
|||||||
@@ -24,7 +24,17 @@
|
|||||||
"172.17.0.1"
|
"172.17.0.1"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"OllamaURL": "http://localhost:11434",
|
"AiProviders": {
|
||||||
|
"ollama": {
|
||||||
|
"handler": "ollama",
|
||||||
|
"baseURL": "http://192.168.0.101:11434"
|
||||||
|
},
|
||||||
|
"localAI": {
|
||||||
|
"handler": "openai",
|
||||||
|
"baseURL": "http://localhost:8080",
|
||||||
|
"ApiKey": "Some API key here"
|
||||||
|
}
|
||||||
|
},
|
||||||
"ApiKeys": ["Some UUID here", "Another UUID here"],
|
"ApiKeys": ["Some UUID here", "Another UUID here"],
|
||||||
"UseHttpsRedirection": true
|
"UseHttpsRedirection": true
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user