Added AIProvider and support for OpenAI compatible APIs

This commit is contained in:
2025-07-06 22:28:45 +02:00
parent 84a4a9d51e
commit b596695fd9
11 changed files with 196 additions and 38 deletions

View File

@@ -8,11 +8,11 @@ example_content = "./Scripts/example_content"
probmethod = "LVEWAvg" probmethod = "LVEWAvg"
example_searchdomain = "example_" + probmethod example_searchdomain = "example_" + probmethod
example_counter = 0 example_counter = 0
models = ["bge-m3", "mxbai-embed-large"] models = ["ollama:bge-m3", "ollama:mxbai-embed-large"]
probmethod_datapoint = probmethod probmethod_datapoint = probmethod
probmethod_entity = probmethod probmethod_entity = probmethod
# Example for a dictionary based weighted average: # Example for a dictionary based weighted average:
# probmethod_datapoint = "DictionaryWeightedAverage:{\"bge-m3\": 4, \"mxbai-embed-large\": 1}" # probmethod_datapoint = "DictionaryWeightedAverage:{\"ollama:bge-m3\": 4, \"ollama:mxbai-embed-large\": 1}"
# probmethod_entity = "DictionaryWeightedAverage:{\"title\": 2, \"filename\": 0.1, \"text\": 0.25}" # probmethod_entity = "DictionaryWeightedAverage:{\"title\": 2, \"filename\": 0.1, \"text\": 0.25}"
def init(toolset: Toolset): def init(toolset: Toolset):

147
src/Server/AIProvider.cs Normal file
View File

@@ -0,0 +1,147 @@
using System.Collections.Frozen;
using System.Collections.Immutable;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Server.Exceptions;
namespace server;
public class AIProvider
{
private readonly ILogger<AIProvider> _logger;
private readonly IConfiguration _configuration;
public AIProvidersConfiguration aIProvidersConfiguration;
public AIProvider(ILogger<AIProvider> logger, IConfiguration configuration)
{
_logger = logger;
_configuration = configuration;
AIProvidersConfiguration? retrievedAiProvidersConfiguration = _configuration
.GetSection("Embeddingsearch")
.Get<AIProvidersConfiguration>();
if (retrievedAiProvidersConfiguration is null)
{
_logger.LogCritical("Unable to build AIProvidersConfiguration. Please check your configuration.");
throw new ServerConfigurationException("Unable to build AIProvidersConfiguration. Please check your configuration.");
}
else
{
aIProvidersConfiguration = retrievedAiProvidersConfiguration;
}
}
public float[] GenerateEmbeddings(string modelUri, string[] input)
{
Uri uri = new(modelUri);
string provider = uri.Scheme;
string model = uri.AbsolutePath;
AIProviderConfiguration? aIProvider = aIProvidersConfiguration.AiProviders
.FirstOrDefault(x => String.Equals(x.Key.ToLower(), provider.ToLower()))
.Value;
if (aIProvider is null)
{
_logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", [provider, modelUri]);
throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
}
using var httpClient = new HttpClient();
string embeddingsJsonPath = "";
Uri baseUri = new(aIProvider.BaseURL);
Uri requestUri;
IEmbedRequestBody embedRequest;
string[][] requestHeaders = [];
switch (aIProvider.Handler)
{
case "ollama":
embeddingsJsonPath = "$.embeddings[*]";
requestUri = new Uri(baseUri, "/api/embed");
embedRequest = new OllamaEmbedRequestBody()
{
input = input,
model = model
};
break;
case "openai":
embeddingsJsonPath = "$.data[*].embedding";
requestUri = new Uri(baseUri, "/v1/embeddings");
embedRequest = new OpenAIEmbedRequestBody()
{
input = input,
model = model
};
if (aIProvider.ApiKey is not null)
{
requestHeaders = [
["Authorization", $"Bearer {aIProvider.ApiKey}"]
];
}
break;
default:
_logger.LogError("Unknown handler {aIProvider.Handler} in AiProvider {provider}.", [aIProvider.Handler, provider]);
throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
}
var requestContent = new StringContent(
JsonConvert.SerializeObject(embedRequest),
UnicodeEncoding.UTF8,
"application/json"
);
var request = new HttpRequestMessage()
{
RequestUri = requestUri,
Method = HttpMethod.Post,
Content = requestContent
};
foreach (var header in requestHeaders)
{
request.Headers.Add(header[0], header[1]);
}
HttpResponseMessage response = httpClient.PostAsync(requestUri, requestContent).Result;
string responseContent = response.Content.ReadAsStringAsync().Result;
try
{
JObject responseContentJson = JObject.Parse(responseContent);
JToken? responseContentTokens = responseContentJson.SelectToken(embeddingsJsonPath);
if (responseContentTokens is null)
{
throw new Exception($"Unable to select tokens using JSONPath {embeddingsJsonPath}."); // TODO add proper exception
}
return [.. responseContentTokens.Values<float>()];
}
catch (Exception ex)
{
_logger.LogError("Unable to parse the response to valid embeddings. {ex.Message}", [ex.Message]);
throw new Exception($"Unable to parse the response to valid embeddings. {ex.Message}"); // TODO add proper exception
}
}
}
public class AIProvidersConfiguration
{
public required Dictionary<string, AIProviderConfiguration> AiProviders { get; set; }
}
public class AIProviderConfiguration
{
public required string Handler { get; set; }
public required string BaseURL { get; set; }
public string? ApiKey { get; set; }
}
public interface IEmbedRequestBody { }
public class OllamaEmbedRequestBody : IEmbedRequestBody
{
public required string model { get; set; }
public required string[] input { get; set; }
}
public class OpenAIEmbedRequestBody : IEmbedRequestBody
{
public required string model { get; set; }
public required string[] input { get; set; }
}

View File

@@ -46,7 +46,7 @@ public class EntityController : ControllerBase
List<Entity>? entities = SearchdomainHelper.EntitiesFromJSON( List<Entity>? entities = SearchdomainHelper.EntitiesFromJSON(
[], [],
_domainManager.embeddingCache, _domainManager.embeddingCache,
_domainManager.client, _domainManager.aIProvider,
_domainManager.helper, _domainManager.helper,
_logger, _logger,
JsonSerializer.Serialize(jsonEntities)); JsonSerializer.Serialize(jsonEntities));

View File

@@ -6,6 +6,7 @@ using System.Threading.Tasks;
using Microsoft.Extensions.AI; using Microsoft.Extensions.AI;
using OllamaSharp; using OllamaSharp;
using OllamaSharp.Models; using OllamaSharp.Models;
using server;
namespace Server; namespace Server;
@@ -29,9 +30,9 @@ public class Datapoint
return probMethod.method(probabilities); return probMethod.method(probabilities);
} }
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama) public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
{ {
return GenerateEmbeddings(content, models, ollama, []); return GenerateEmbeddings(content, models, aIProvider, []);
} }
public static Dictionary<string, float[]> GenerateEmbeddings(List<string> contents, string model, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache) public static Dictionary<string, float[]> GenerateEmbeddings(List<string> contents, string model, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
@@ -78,7 +79,7 @@ public class Datapoint
return retVal; return retVal;
} }
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache) public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
{ {
Dictionary<string, float[]> retVal = []; Dictionary<string, float[]> retVal = [];
foreach (string model in models) foreach (string model in models)
@@ -88,24 +89,17 @@ public class Datapoint
retVal[model] = embeddingCache[model][content]; retVal[model] = embeddingCache[model][content];
continue; continue;
} }
EmbedRequest request = new() var response = aIProvider.GenerateEmbeddings(model, [content]);
{
Model = model,
Input = [content]
};
var response = ollama.EmbedAsync(request).Result;
if (response is not null) if (response is not null)
{ {
float[] var = [.. response.Embeddings.First()]; retVal[model] = response;
retVal[model] = var;
if (!embeddingCache.ContainsKey(model)) if (!embeddingCache.ContainsKey(model))
{ {
embeddingCache[model] = []; embeddingCache[model] = [];
} }
if (!embeddingCache[model].ContainsKey(content)) if (!embeddingCache[model].ContainsKey(content))
{ {
embeddingCache[model][content] = var; embeddingCache[model][content] = response;
} }
} }
} }

View File

@@ -17,7 +17,7 @@ public class AIProviderHealthCheck : IHealthCheck
{ {
try try
{ {
var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result; //var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result; // TODO reimplement this
} }
catch (Exception ex) catch (Exception ex)
{ {

View File

@@ -4,6 +4,7 @@ using System.Text;
using System.Text.Json; using System.Text.Json;
using MySql.Data.MySqlClient; using MySql.Data.MySqlClient;
using OllamaSharp; using OllamaSharp;
using server;
namespace Server; namespace Server;
@@ -41,7 +42,7 @@ public static class SearchdomainHelper
return null; return null;
} }
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json) public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, string json)
{ {
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json); List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
if (jsonEntities is null) if (jsonEntities is null)
@@ -65,10 +66,11 @@ public static class SearchdomainHelper
} }
} }
ConcurrentQueue<Entity> retVal = []; ConcurrentQueue<Entity> retVal = [];
Parallel.ForEach(jsonEntities, jSONEntity => ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool
Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity =>
{ {
using var tempHelper = helper.DuplicateConnection(); using var tempHelper = helper.DuplicateConnection();
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity); var entity = EntityFromJSON(entityCache, embeddingCache, aIProvider, tempHelper, logger, jSONEntity);
if (entity is not null) if (entity is not null)
{ {
retVal.Enqueue(entity); retVal.Enqueue(entity);
@@ -77,7 +79,7 @@ public static class SearchdomainHelper
return [.. retVal]; return [.. retVal];
} }
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json) public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
{ {
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = []; Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
@@ -130,14 +132,14 @@ public static class SearchdomainHelper
} }
else else
{ {
var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], ollama, embeddingCache); var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], aIProvider, embeddingCache);
embeddings.Add(model, additionalEmbeddings.First().Value); embeddings.Add(model, additionalEmbeddings.First().Value);
} }
} }
} }
else else
{ {
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache); embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache);
} }
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);

View File

@@ -1,6 +1,7 @@
using ElmahCore; using ElmahCore;
using ElmahCore.Mvc; using ElmahCore.Mvc;
using Serilog; using Serilog;
using server;
using Server; using Server;
using Server.HealthChecks; using Server.HealthChecks;
@@ -17,6 +18,7 @@ Log.Logger = new LoggerConfiguration()
.CreateLogger(); .CreateLogger();
builder.Logging.AddSerilog(); builder.Logging.AddSerilog();
builder.Services.AddSingleton<SearchdomainManager>(); builder.Services.AddSingleton<SearchdomainManager>();
builder.Services.AddSingleton<AIProvider>();
builder.Services.AddHealthChecks() builder.Services.AddHealthChecks()
.AddCheck<DatabaseHealthCheck>("DatabaseHealthCheck") .AddCheck<DatabaseHealthCheck>("DatabaseHealthCheck")
.AddCheck<AIProviderHealthCheck>("AIProviderHealthChecck"); .AddCheck<AIProviderHealthCheck>("AIProviderHealthChecck");

View File

@@ -21,6 +21,7 @@ using Server;
using System.Security.Cryptography; using System.Security.Cryptography;
using System.Text; using System.Text;
using System.Collections.Concurrent; using System.Collections.Concurrent;
using server;
namespace Server; namespace Server;
@@ -28,7 +29,7 @@ public class Searchdomain
{ {
private readonly string _connectionString; private readonly string _connectionString;
private readonly string _provider; private readonly string _provider;
public OllamaApiClient ollama; public AIProvider aIProvider;
public string searchdomain; public string searchdomain;
public int id; public int id;
public Dictionary<string, List<(DateTime, List<(float, string)>)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x public Dictionary<string, List<(DateTime, List<(float, string)>)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x
@@ -42,12 +43,12 @@ public class Searchdomain
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain() // TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
{ {
_connectionString = connectionString; _connectionString = connectionString;
_provider = provider.ToLower(); _provider = provider.ToLower();
this.searchdomain = searchdomain; this.searchdomain = searchdomain;
this.ollama = ollama; this.aIProvider = aIProvider;
this.embeddingCache = embeddingCache; this.embeddingCache = embeddingCache;
this._logger = logger; this._logger = logger;
searchCache = []; searchCache = [];
@@ -69,7 +70,7 @@ public class Searchdomain
{ {
["id"] = this.id ["id"] = this.id
}; };
DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain); DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain); // TODO fix: parametersIDSearchdomain defined, but not used
Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = []; Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = [];
while (embeddingReader.Read()) while (embeddingReader.Read())
{ {
@@ -162,7 +163,7 @@ public class Searchdomain
{ {
if (!embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings)) if (!embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings))
{ {
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, ollama); queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider);
if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits
{ // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries { // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries
embeddingCache.Add(query, queryEmbeddings); embeddingCache.Add(query, queryEmbeddings);

View File

@@ -4,6 +4,7 @@ using OllamaSharp;
using Microsoft.IdentityModel.Tokens; using Microsoft.IdentityModel.Tokens;
using Server.Exceptions; using Server.Exceptions;
using Server.Migrations; using Server.Migrations;
using server;
namespace Server; namespace Server;
@@ -12,25 +13,20 @@ public class SearchdomainManager
private Dictionary<string, Searchdomain> searchdomains = []; private Dictionary<string, Searchdomain> searchdomains = [];
private readonly ILogger<SearchdomainManager> _logger; private readonly ILogger<SearchdomainManager> _logger;
private readonly IConfiguration _config; private readonly IConfiguration _config;
public readonly AIProvider aIProvider;
private readonly string ollamaURL; private readonly string ollamaURL;
private readonly string connectionString; private readonly string connectionString;
public OllamaApiClient client;
private MySqlConnection connection; private MySqlConnection connection;
public SQLHelper helper; public SQLHelper helper;
public Dictionary<string, Dictionary<string, float[]>> embeddingCache; public Dictionary<string, Dictionary<string, float[]>> embeddingCache;
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config) public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config, AIProvider aIProvider)
{ {
_logger = logger; _logger = logger;
_config = config; _config = config;
this.aIProvider = aIProvider;
embeddingCache = []; embeddingCache = [];
ollamaURL = _config.GetSection("Embeddingsearch")["OllamaURL"] ?? "";
connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? ""; connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? "";
if (ollamaURL.IsNullOrEmpty() || connectionString.IsNullOrEmpty())
{
throw new ServerConfigurationException("Ollama URL or connection string is empty");
}
client = new(new Uri(ollamaURL));
connection = new MySqlConnection(connectionString); connection = new MySqlConnection(connectionString);
connection.Open(); connection.Open();
helper = new SQLHelper(connection, connectionString); helper = new SQLHelper(connection, connectionString);
@@ -53,13 +49,18 @@ public class SearchdomainManager
} }
try try
{ {
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger)); return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, aIProvider, embeddingCache, _logger));
} }
catch (MySqlException) catch (MySqlException)
{ {
_logger.LogError("Unable to find the searchdomain {searchdomain}", searchdomain); _logger.LogError("Unable to find the searchdomain {searchdomain}", searchdomain);
throw new Exception($"Unable to find the searchdomain {searchdomain}"); throw new Exception($"Unable to find the searchdomain {searchdomain}");
} }
catch (Exception ex)
{
_logger.LogError("Unable to load the searchdomain {searchdomain} due to the following exception: {ex}", [searchdomain, ex.Message]);
throw;
}
} }
public void InvalidateSearchdomainCache(string searchdomainName) public void InvalidateSearchdomainCache(string searchdomainName)

View File

@@ -8,6 +8,7 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="ElmahCore" Version="2.1.2" /> <PackageReference Include="ElmahCore" Version="2.1.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" /> <PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
<PackageReference Include="Serilog.Sinks.File" Version="7.0.0" /> <PackageReference Include="Serilog.Sinks.File" Version="7.0.0" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" /> <PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />

View File

@@ -24,7 +24,17 @@
"172.17.0.1" "172.17.0.1"
] ]
}, },
"OllamaURL": "http://localhost:11434", "AiProviders": {
"ollama": {
"handler": "ollama",
"baseURL": "http://192.168.0.101:11434"
},
"localAI": {
"handler": "openai",
"baseURL": "http://localhost:8080",
"ApiKey": "Some API key here"
}
},
"ApiKeys": ["Some UUID here", "Another UUID here"], "ApiKeys": ["Some UUID here", "Another UUID here"],
"UseHttpsRedirection": true "UseHttpsRedirection": true
} }