From b596695fd95e72a1252599f8b82649bd54e463ad Mon Sep 17 00:00:00 2001
From: LD-Reborn
Date: Sun, 6 Jul 2025 22:28:45 +0200
Subject: [PATCH] Added AIProvider and support for OpenAI compatible APIs

---
 src/Indexer/Scripts/example.py                |   4 +-
 src/Server/AIProvider.cs                      | 176 ++++++++++++++++++
 src/Server/Controllers/EntityController.cs    |   2 +-
 src/Server/Datapoint.cs                       |  20 +--
 .../HealthChecks/AIProviderHealthChecks.cs    |   2 +-
 src/Server/Helper/SearchdomainHelper.cs       |  14 +-
 src/Server/Program.cs                         |   2 +
 src/Server/Searchdomain.cs                    |  11 +-
 src/Server/SearchdomainManager.cs             |  19 +--
 src/Server/Server.csproj                      |   1 +
 src/Server/appsettings.Development.json       |  12 +-
 11 files changed, 225 insertions(+), 38 deletions(-)
 create mode 100644 src/Server/AIProvider.cs

diff --git a/src/Indexer/Scripts/example.py b/src/Indexer/Scripts/example.py
index 8a0dacf..9874ad1 100644
--- a/src/Indexer/Scripts/example.py
+++ b/src/Indexer/Scripts/example.py
@@ -8,11 +8,11 @@ example_content = "./Scripts/example_content"
 probmethod = "LVEWAvg"
 example_searchdomain = "example_" + probmethod
 example_counter = 0
-models = ["bge-m3", "mxbai-embed-large"]
+models = ["ollama:bge-m3", "ollama:mxbai-embed-large"]
 probmethod_datapoint = probmethod
 probmethod_entity = probmethod
 # Example for a dictionary based weighted average:
-# probmethod_datapoint = "DictionaryWeightedAverage:{\"bge-m3\": 4, \"mxbai-embed-large\": 1}"
+# probmethod_datapoint = "DictionaryWeightedAverage:{\"ollama:bge-m3\": 4, \"ollama:mxbai-embed-large\": 1}"
 # probmethod_entity = "DictionaryWeightedAverage:{\"title\": 2, \"filename\": 0.1, \"text\": 0.25}"
 
 def init(toolset: Toolset):
diff --git a/src/Server/AIProvider.cs b/src/Server/AIProvider.cs
new file mode 100644
index 0000000..a84ef6b
--- /dev/null
+++ b/src/Server/AIProvider.cs
@@ -0,0 +1,176 @@
+
+using System.Collections.Frozen;
+using System.Collections.Immutable;
+using System.Net.Http.Headers;
+using System.Text;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+using Newtonsoft.Json;
+using Newtonsoft.Json.Linq;
+using Server.Exceptions;
+
+namespace server;
+
+/// <summary>
+/// Generates embeddings through the AI providers configured under
+/// <c>Embeddingsearch:AiProviders</c>. Supported handlers: <c>ollama</c> and
+/// <c>openai</c> (any OpenAI compatible API).
+/// </summary>
+public class AIProvider
+{
+    // Shared by all requests; a new HttpClient per call can exhaust sockets under load.
+    private static readonly HttpClient s_httpClient = new();
+
+    private readonly ILogger<AIProvider> _logger;
+    private readonly IConfiguration _configuration;
+    public AIProvidersConfiguration aIProvidersConfiguration;
+
+    /// <summary>Binds the provider configuration from the <c>Embeddingsearch</c> section.</summary>
+    /// <exception cref="ServerConfigurationException">The section cannot be bound or has no <c>AiProviders</c>.</exception>
+    public AIProvider(ILogger<AIProvider> logger, IConfiguration configuration)
+    {
+        _logger = logger;
+        _configuration = configuration;
+        AIProvidersConfiguration? retrievedAiProvidersConfiguration = _configuration
+            .GetSection("Embeddingsearch")
+            .Get<AIProvidersConfiguration>();
+        // NOTE(review): presumably the binder leaves AiProviders null when the key is absent; fail early.
+        if (retrievedAiProvidersConfiguration?.AiProviders is null)
+        {
+            _logger.LogCritical("Unable to build AIProvidersConfiguration. Please check your configuration.");
+            throw new ServerConfigurationException("Unable to build AIProvidersConfiguration. Please check your configuration.");
+        }
+        aIProvidersConfiguration = retrievedAiProvidersConfiguration;
+    }
+
+    /// <summary>
+    /// Requests an embedding for <paramref name="input"/> from the provider named in <paramref name="modelUri"/>.
+    /// </summary>
+    /// <param name="modelUri">Model reference of the form <c>provider:model</c>, e.g. <c>ollama:bge-m3</c>.</param>
+    /// <param name="input">
+    /// Texts to embed. The response is read with a single-token JSONPath selection, so pass one
+    /// text per call (NOTE(review): several texts presumably yield several matches and fail - confirm).
+    /// </param>
+    /// <returns>The embedding values selected from the provider response.</returns>
+    /// <exception cref="ServerConfigurationException">Unknown provider, invalid base URL or unknown handler.</exception>
+    /// <exception cref="InvalidOperationException">The provider returned an error status or an unparsable response.</exception>
+    public float[] GenerateEmbeddings(string modelUri, string[] input)
+    {
+        Uri uri = new(modelUri);
+        string provider = uri.Scheme;
+        string model = uri.AbsolutePath;
+        AIProviderConfiguration? aIProvider = aIProvidersConfiguration.AiProviders
+            .FirstOrDefault(x => string.Equals(x.Key, provider, StringComparison.OrdinalIgnoreCase))
+            .Value;
+        if (aIProvider is null)
+        {
+            _logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", provider, modelUri);
+            throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
+        }
+        if (!Uri.TryCreate(aIProvider.BaseURL, UriKind.Absolute, out Uri? baseUri))
+        {
+            _logger.LogError("Invalid BaseURL {baseURL} in AiProvider {provider}.", aIProvider.BaseURL, provider);
+            throw new ServerConfigurationException($"Invalid BaseURL {aIProvider.BaseURL} in AiProvider {provider}.");
+        }
+
+        string embeddingsJsonPath;
+        Uri requestUri;
+        IEmbedRequestBody embedRequest;
+        string? bearerToken = null;
+        switch (aIProvider.Handler)
+        {
+            case "ollama":
+                embeddingsJsonPath = "$.embeddings[*]";
+                requestUri = new Uri(baseUri, "/api/embed");
+                embedRequest = new OllamaEmbedRequestBody()
+                {
+                    input = input,
+                    model = model
+                };
+                break;
+            case "openai":
+                embeddingsJsonPath = "$.data[*].embedding";
+                requestUri = new Uri(baseUri, "/v1/embeddings");
+                embedRequest = new OpenAIEmbedRequestBody()
+                {
+                    input = input,
+                    model = model
+                };
+                bearerToken = aIProvider.ApiKey;
+                break;
+            default:
+                _logger.LogError("Unknown handler {aIProvider.Handler} in AiProvider {provider}.", aIProvider.Handler, provider);
+                throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
+        }
+
+        using HttpRequestMessage request = new(HttpMethod.Post, requestUri)
+        {
+            Content = new StringContent(JsonConvert.SerializeObject(embedRequest), Encoding.UTF8, "application/json")
+        };
+        if (!string.IsNullOrEmpty(bearerToken))
+        {
+            request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", bearerToken);
+        }
+
+        // Send the prepared message itself: posting only the URI and content drops the headers set above.
+        using HttpResponseMessage response = s_httpClient.Send(request);
+        using StreamReader responseReader = new(response.Content.ReadAsStream());
+        string responseContent = responseReader.ReadToEnd();
+        if (!response.IsSuccessStatusCode)
+        {
+            _logger.LogError("Embedding request to {requestUri} failed with status {statusCode}: {responseContent}", requestUri, (int)response.StatusCode, responseContent);
+            throw new InvalidOperationException($"Embedding request to {requestUri} failed with status {(int)response.StatusCode}.");
+        }
+
+        try
+        {
+            JObject responseContentJson = JObject.Parse(responseContent);
+            JToken? responseContentTokens = responseContentJson.SelectToken(embeddingsJsonPath);
+            if (responseContentTokens is null)
+            {
+                throw new InvalidOperationException($"Unable to select tokens using JSONPath {embeddingsJsonPath}.");
+            }
+            return [.. responseContentTokens.Values<float>()];
+        }
+        catch (Exception ex)
+        {
+            _logger.LogError(ex, "Unable to parse the response to valid embeddings. {ex.Message}", ex.Message);
+            // Keep the original exception as InnerException so the root cause is not lost.
+            throw new InvalidOperationException($"Unable to parse the response to valid embeddings. {ex.Message}", ex);
+        }
+    }
+}
+
+/// <summary>Configuration root holding all AI providers, keyed by provider name.</summary>
+public class AIProvidersConfiguration
+{
+    public required Dictionary<string, AIProviderConfiguration> AiProviders { get; set; }
+}
+
+/// <summary>Settings of one AI provider entry.</summary>
+public class AIProviderConfiguration
+{
+    /// <summary>Request/response dialect: <c>ollama</c> or <c>openai</c>.</summary>
+    public required string Handler { get; set; }
+    /// <summary>Base address of the provider; the handler's endpoint path is resolved against it.</summary>
+    public required string BaseURL { get; set; }
+    /// <summary>Optional key, sent as a bearer token by the <c>openai</c> handler.</summary>
+    public string? ApiKey { get; set; }
+}
+
+/// <summary>Marker interface for handler-specific embed request bodies.</summary>
+public interface IEmbedRequestBody { }
+
+/// <summary>Request body posted to <c>/api/embed</c> by the <c>ollama</c> handler.</summary>
+public class OllamaEmbedRequestBody : IEmbedRequestBody
+{
+    public required string model { get; set; }
+    public required string[] input { get; set; }
+}
+
+/// <summary>Request body posted to <c>/v1/embeddings</c> by the <c>openai</c> handler.</summary>
+public class OpenAIEmbedRequestBody : IEmbedRequestBody
+{
+    public required string model { get; set; }
+    public required string[] input { get; set; }
+}
\ No newline at end of file
diff --git a/src/Server/Controllers/EntityController.cs b/src/Server/Controllers/EntityController.cs
index d05a9e7..d896956 100644
--- a/src/Server/Controllers/EntityController.cs
+++ b/src/Server/Controllers/EntityController.cs
@@ -46,7 +46,7 @@ public class EntityController : ControllerBase
         List? 
entities = SearchdomainHelper.EntitiesFromJSON( [], _domainManager.embeddingCache, - _domainManager.client, + _domainManager.aIProvider, _domainManager.helper, _logger, JsonSerializer.Serialize(jsonEntities)); diff --git a/src/Server/Datapoint.cs b/src/Server/Datapoint.cs index e7e6ccf..5e54841 100644 --- a/src/Server/Datapoint.cs +++ b/src/Server/Datapoint.cs @@ -6,6 +6,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.AI; using OllamaSharp; using OllamaSharp.Models; +using server; namespace Server; @@ -29,9 +30,9 @@ public class Datapoint return probMethod.method(probabilities); } - public static Dictionary GenerateEmbeddings(string content, List models, OllamaApiClient ollama) + public static Dictionary GenerateEmbeddings(string content, List models, AIProvider aIProvider) { - return GenerateEmbeddings(content, models, ollama, []); + return GenerateEmbeddings(content, models, aIProvider, []); } public static Dictionary GenerateEmbeddings(List contents, string model, OllamaApiClient ollama, Dictionary> embeddingCache) @@ -78,7 +79,7 @@ public class Datapoint return retVal; } - public static Dictionary GenerateEmbeddings(string content, List models, OllamaApiClient ollama, Dictionary> embeddingCache) + public static Dictionary GenerateEmbeddings(string content, List models, AIProvider aIProvider, Dictionary> embeddingCache) { Dictionary retVal = []; foreach (string model in models) @@ -88,24 +89,17 @@ public class Datapoint retVal[model] = embeddingCache[model][content]; continue; } - EmbedRequest request = new() - { - Model = model, - Input = [content] - }; - - var response = ollama.EmbedAsync(request).Result; + var response = aIProvider.GenerateEmbeddings(model, [content]); if (response is not null) { - float[] var = [.. 
response.Embeddings.First()]; - retVal[model] = var; + retVal[model] = response; if (!embeddingCache.ContainsKey(model)) { embeddingCache[model] = []; } if (!embeddingCache[model].ContainsKey(content)) { - embeddingCache[model][content] = var; + embeddingCache[model][content] = response; } } } diff --git a/src/Server/HealthChecks/AIProviderHealthChecks.cs b/src/Server/HealthChecks/AIProviderHealthChecks.cs index bd8a0fc..091c84f 100644 --- a/src/Server/HealthChecks/AIProviderHealthChecks.cs +++ b/src/Server/HealthChecks/AIProviderHealthChecks.cs @@ -17,7 +17,7 @@ public class AIProviderHealthCheck : IHealthCheck { try { - var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result; + //var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result; // TODO reimplement this } catch (Exception ex) { diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 508b254..334401d 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -4,6 +4,7 @@ using System.Text; using System.Text.Json; using MySql.Data.MySqlClient; using OllamaSharp; +using server; namespace Server; @@ -41,7 +42,7 @@ public static class SearchdomainHelper return null; } - public static List? EntitiesFromJSON(List entityCache, Dictionary> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json) + public static List? EntitiesFromJSON(List entityCache, Dictionary> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, string json) { List? jsonEntities = JsonSerializer.Deserialize>(json); if (jsonEntities is null) @@ -65,10 +66,11 @@ public static class SearchdomainHelper } } ConcurrentQueue retVal = []; - Parallel.ForEach(jsonEntities, jSONEntity => + ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! 
Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool + Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity => { using var tempHelper = helper.DuplicateConnection(); - var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity); + var entity = EntityFromJSON(entityCache, embeddingCache, aIProvider, tempHelper, logger, jSONEntity); if (entity is not null) { retVal.Enqueue(entity); @@ -77,7 +79,7 @@ public static class SearchdomainHelper return [.. retVal]; } - public static Entity? EntityFromJSON(List entityCache, Dictionary> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json) + public static Entity? EntityFromJSON(List entityCache, Dictionary> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json) { Dictionary> embeddingsLUT = []; int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); @@ -130,14 +132,14 @@ public static class SearchdomainHelper } else { - var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], ollama, embeddingCache); + var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], aIProvider, embeddingCache); embeddings.Add(model, additionalEmbeddings.First().Value); } } } else { - embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache); + embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache); } var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. 
embeddings.Select(kv => (kv.Key, kv.Value))]); diff --git a/src/Server/Program.cs b/src/Server/Program.cs index 5b08e35..d8b1024 100644 --- a/src/Server/Program.cs +++ b/src/Server/Program.cs @@ -1,6 +1,7 @@ using ElmahCore; using ElmahCore.Mvc; using Serilog; +using server; using Server; using Server.HealthChecks; @@ -17,6 +18,7 @@ Log.Logger = new LoggerConfiguration() .CreateLogger(); builder.Logging.AddSerilog(); builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddHealthChecks() .AddCheck("DatabaseHealthCheck") .AddCheck("AIProviderHealthChecck"); diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index 0a5b84d..8b076b0 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -21,6 +21,7 @@ using Server; using System.Security.Cryptography; using System.Text; using System.Collections.Concurrent; +using server; namespace Server; @@ -28,7 +29,7 @@ public class Searchdomain { private readonly string _connectionString; private readonly string _provider; - public OllamaApiClient ollama; + public AIProvider aIProvider; public string searchdomain; public int id; public Dictionary)>> searchCache; // Yeah look at this abomination. 
searchCache[x][0] = last accessed time, searchCache[x][1] = results for x @@ -42,12 +43,12 @@ public class Searchdomain // TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain() - public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) + public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, Dictionary> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false) { _connectionString = connectionString; _provider = provider.ToLower(); this.searchdomain = searchdomain; - this.ollama = ollama; + this.aIProvider = aIProvider; this.embeddingCache = embeddingCache; this._logger = logger; searchCache = []; @@ -69,7 +70,7 @@ public class Searchdomain { ["id"] = this.id }; - DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain); + DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain); // TODO fix: parametersIDSearchdomain defined, but not used Dictionary> embedding_unassigned = []; while (embeddingReader.Read()) { @@ -162,7 +163,7 @@ public class Searchdomain { if (!embeddingCache.TryGetValue(query, out Dictionary? queryEmbeddings)) { - queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, ollama); + queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider); if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits { // Idea: Add access count to each entry. 
On limit hit, sort the entries by access count and remove the bottom 10% of entries embeddingCache.Add(query, queryEmbeddings); diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index 6a07364..b4f1623 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -4,6 +4,7 @@ using OllamaSharp; using Microsoft.IdentityModel.Tokens; using Server.Exceptions; using Server.Migrations; +using server; namespace Server; @@ -12,25 +13,20 @@ public class SearchdomainManager private Dictionary searchdomains = []; private readonly ILogger _logger; private readonly IConfiguration _config; + public readonly AIProvider aIProvider; private readonly string ollamaURL; private readonly string connectionString; - public OllamaApiClient client; private MySqlConnection connection; public SQLHelper helper; public Dictionary> embeddingCache; - public SearchdomainManager(ILogger logger, IConfiguration config) + public SearchdomainManager(ILogger logger, IConfiguration config, AIProvider aIProvider) { _logger = logger; _config = config; + this.aIProvider = aIProvider; embeddingCache = []; - ollamaURL = _config.GetSection("Embeddingsearch")["OllamaURL"] ?? ""; connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? 
""; - if (ollamaURL.IsNullOrEmpty() || connectionString.IsNullOrEmpty()) - { - throw new ServerConfigurationException("Ollama URL or connection string is empty"); - } - client = new(new Uri(ollamaURL)); connection = new MySqlConnection(connectionString); connection.Open(); helper = new SQLHelper(connection, connectionString); @@ -53,13 +49,18 @@ public class SearchdomainManager } try { - return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger)); + return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, aIProvider, embeddingCache, _logger)); } catch (MySqlException) { _logger.LogError("Unable to find the searchdomain {searchdomain}", searchdomain); throw new Exception($"Unable to find the searchdomain {searchdomain}"); } + catch (Exception ex) + { + _logger.LogError("Unable to load the searchdomain {searchdomain} due to the following exception: {ex}", [searchdomain, ex.Message]); + throw; + } } public void InvalidateSearchdomainCache(string searchdomainName) diff --git a/src/Server/Server.csproj b/src/Server/Server.csproj index ed6df4d..e1ceed5 100644 --- a/src/Server/Server.csproj +++ b/src/Server/Server.csproj @@ -8,6 +8,7 @@ + diff --git a/src/Server/appsettings.Development.json b/src/Server/appsettings.Development.json index 6d2a5ea..eb06731 100644 --- a/src/Server/appsettings.Development.json +++ b/src/Server/appsettings.Development.json @@ -24,7 +24,17 @@ "172.17.0.1" ] }, - "OllamaURL": "http://localhost:11434", + "AiProviders": { + "ollama": { + "handler": "ollama", + "baseURL": "http://192.168.0.101:11434" + }, + "localAI": { + "handler": "openai", + "baseURL": "http://localhost:8080", + "ApiKey": "Some API key here" + } + }, "ApiKeys": ["Some UUID here", "Another UUID here"], "UseHttpsRedirection": true }