Added AIProvider and support for OpenAI compatible APIs
This commit is contained in:
147
src/Server/AIProvider.cs
Normal file
147
src/Server/AIProvider.cs
Normal file
@@ -0,0 +1,147 @@
|
||||
|
||||
using System.Collections.Frozen;
|
||||
using System.Collections.Immutable;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
using System.Text.Json.Serialization;
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using Server.Exceptions;
|
||||
|
||||
namespace server;
|
||||
|
||||
public class AIProvider
|
||||
{
|
||||
private readonly ILogger<AIProvider> _logger;
|
||||
private readonly IConfiguration _configuration;
|
||||
public AIProvidersConfiguration aIProvidersConfiguration;
|
||||
|
||||
public AIProvider(ILogger<AIProvider> logger, IConfiguration configuration)
|
||||
{
|
||||
_logger = logger;
|
||||
_configuration = configuration;
|
||||
AIProvidersConfiguration? retrievedAiProvidersConfiguration = _configuration
|
||||
.GetSection("Embeddingsearch")
|
||||
.Get<AIProvidersConfiguration>();
|
||||
if (retrievedAiProvidersConfiguration is null)
|
||||
{
|
||||
_logger.LogCritical("Unable to build AIProvidersConfiguration. Please check your configuration.");
|
||||
throw new ServerConfigurationException("Unable to build AIProvidersConfiguration. Please check your configuration.");
|
||||
}
|
||||
else
|
||||
{
|
||||
aIProvidersConfiguration = retrievedAiProvidersConfiguration;
|
||||
}
|
||||
}
|
||||
|
||||
public float[] GenerateEmbeddings(string modelUri, string[] input)
|
||||
{
|
||||
Uri uri = new(modelUri);
|
||||
string provider = uri.Scheme;
|
||||
string model = uri.AbsolutePath;
|
||||
AIProviderConfiguration? aIProvider = aIProvidersConfiguration.AiProviders
|
||||
.FirstOrDefault(x => String.Equals(x.Key.ToLower(), provider.ToLower()))
|
||||
.Value;
|
||||
if (aIProvider is null)
|
||||
{
|
||||
_logger.LogError("Model provider {provider} not found in configuration. Requested model: {modelUri}", [provider, modelUri]);
|
||||
throw new ServerConfigurationException($"Model provider {provider} not found in configuration. Requested model: {modelUri}");
|
||||
}
|
||||
using var httpClient = new HttpClient();
|
||||
|
||||
string embeddingsJsonPath = "";
|
||||
Uri baseUri = new(aIProvider.BaseURL);
|
||||
Uri requestUri;
|
||||
IEmbedRequestBody embedRequest;
|
||||
string[][] requestHeaders = [];
|
||||
switch (aIProvider.Handler)
|
||||
{
|
||||
case "ollama":
|
||||
embeddingsJsonPath = "$.embeddings[*]";
|
||||
requestUri = new Uri(baseUri, "/api/embed");
|
||||
embedRequest = new OllamaEmbedRequestBody()
|
||||
{
|
||||
input = input,
|
||||
model = model
|
||||
};
|
||||
break;
|
||||
case "openai":
|
||||
embeddingsJsonPath = "$.data[*].embedding";
|
||||
requestUri = new Uri(baseUri, "/v1/embeddings");
|
||||
embedRequest = new OpenAIEmbedRequestBody()
|
||||
{
|
||||
input = input,
|
||||
model = model
|
||||
};
|
||||
if (aIProvider.ApiKey is not null)
|
||||
{
|
||||
requestHeaders = [
|
||||
["Authorization", $"Bearer {aIProvider.ApiKey}"]
|
||||
];
|
||||
}
|
||||
break;
|
||||
default:
|
||||
_logger.LogError("Unknown handler {aIProvider.Handler} in AiProvider {provider}.", [aIProvider.Handler, provider]);
|
||||
throw new ServerConfigurationException($"Unknown handler {aIProvider.Handler} in AiProvider {provider}.");
|
||||
}
|
||||
var requestContent = new StringContent(
|
||||
JsonConvert.SerializeObject(embedRequest),
|
||||
UnicodeEncoding.UTF8,
|
||||
"application/json"
|
||||
);
|
||||
|
||||
var request = new HttpRequestMessage()
|
||||
{
|
||||
RequestUri = requestUri,
|
||||
Method = HttpMethod.Post,
|
||||
Content = requestContent
|
||||
};
|
||||
|
||||
foreach (var header in requestHeaders)
|
||||
{
|
||||
request.Headers.Add(header[0], header[1]);
|
||||
}
|
||||
HttpResponseMessage response = httpClient.PostAsync(requestUri, requestContent).Result;
|
||||
string responseContent = response.Content.ReadAsStringAsync().Result;
|
||||
try
|
||||
{
|
||||
JObject responseContentJson = JObject.Parse(responseContent);
|
||||
JToken? responseContentTokens = responseContentJson.SelectToken(embeddingsJsonPath);
|
||||
if (responseContentTokens is null)
|
||||
{
|
||||
throw new Exception($"Unable to select tokens using JSONPath {embeddingsJsonPath}."); // TODO add proper exception
|
||||
}
|
||||
return [.. responseContentTokens.Values<float>()];
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError("Unable to parse the response to valid embeddings. {ex.Message}", [ex.Message]);
|
||||
throw new Exception($"Unable to parse the response to valid embeddings. {ex.Message}"); // TODO add proper exception
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class AIProvidersConfiguration
|
||||
{
|
||||
public required Dictionary<string, AIProviderConfiguration> AiProviders { get; set; }
|
||||
}
|
||||
|
||||
public class AIProviderConfiguration
|
||||
{
|
||||
public required string Handler { get; set; }
|
||||
public required string BaseURL { get; set; }
|
||||
public string? ApiKey { get; set; }
|
||||
}
|
||||
public interface IEmbedRequestBody { }
|
||||
|
||||
public class OllamaEmbedRequestBody : IEmbedRequestBody
|
||||
{
|
||||
public required string model { get; set; }
|
||||
public required string[] input { get; set; }
|
||||
}
|
||||
|
||||
public class OpenAIEmbedRequestBody : IEmbedRequestBody
|
||||
{
|
||||
public required string model { get; set; }
|
||||
public required string[] input { get; set; }
|
||||
}
|
||||
@@ -46,7 +46,7 @@ public class EntityController : ControllerBase
|
||||
List<Entity>? entities = SearchdomainHelper.EntitiesFromJSON(
|
||||
[],
|
||||
_domainManager.embeddingCache,
|
||||
_domainManager.client,
|
||||
_domainManager.aIProvider,
|
||||
_domainManager.helper,
|
||||
_logger,
|
||||
JsonSerializer.Serialize(jsonEntities));
|
||||
|
||||
@@ -6,6 +6,7 @@ using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
using OllamaSharp;
|
||||
using OllamaSharp.Models;
|
||||
using server;
|
||||
|
||||
namespace Server;
|
||||
|
||||
@@ -29,9 +30,9 @@ public class Datapoint
|
||||
return probMethod.method(probabilities);
|
||||
}
|
||||
|
||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
|
||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider)
|
||||
{
|
||||
return GenerateEmbeddings(content, models, ollama, []);
|
||||
return GenerateEmbeddings(content, models, aIProvider, []);
|
||||
}
|
||||
|
||||
public static Dictionary<string, float[]> GenerateEmbeddings(List<string> contents, string model, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
||||
@@ -78,7 +79,7 @@ public class Datapoint
|
||||
return retVal;
|
||||
}
|
||||
|
||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
||||
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
|
||||
{
|
||||
Dictionary<string, float[]> retVal = [];
|
||||
foreach (string model in models)
|
||||
@@ -88,24 +89,17 @@ public class Datapoint
|
||||
retVal[model] = embeddingCache[model][content];
|
||||
continue;
|
||||
}
|
||||
EmbedRequest request = new()
|
||||
{
|
||||
Model = model,
|
||||
Input = [content]
|
||||
};
|
||||
|
||||
var response = ollama.EmbedAsync(request).Result;
|
||||
var response = aIProvider.GenerateEmbeddings(model, [content]);
|
||||
if (response is not null)
|
||||
{
|
||||
float[] var = [.. response.Embeddings.First()];
|
||||
retVal[model] = var;
|
||||
retVal[model] = response;
|
||||
if (!embeddingCache.ContainsKey(model))
|
||||
{
|
||||
embeddingCache[model] = [];
|
||||
}
|
||||
if (!embeddingCache[model].ContainsKey(content))
|
||||
{
|
||||
embeddingCache[model][content] = var;
|
||||
embeddingCache[model][content] = response;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ public class AIProviderHealthCheck : IHealthCheck
|
||||
{
|
||||
try
|
||||
{
|
||||
var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result;
|
||||
//var _ = _searchdomainManager.client.ListLocalModelsAsync(cancellationToken).Result; // TODO reimplement this
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@ using System.Text;
|
||||
using System.Text.Json;
|
||||
using MySql.Data.MySqlClient;
|
||||
using OllamaSharp;
|
||||
using server;
|
||||
|
||||
namespace Server;
|
||||
|
||||
@@ -41,7 +42,7 @@ public static class SearchdomainHelper
|
||||
return null;
|
||||
}
|
||||
|
||||
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, string json)
|
||||
public static List<Entity>? EntitiesFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, string json)
|
||||
{
|
||||
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
|
||||
if (jsonEntities is null)
|
||||
@@ -65,10 +66,11 @@ public static class SearchdomainHelper
|
||||
}
|
||||
}
|
||||
ConcurrentQueue<Entity> retVal = [];
|
||||
Parallel.ForEach(jsonEntities, jSONEntity =>
|
||||
ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool
|
||||
Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity =>
|
||||
{
|
||||
using var tempHelper = helper.DuplicateConnection();
|
||||
var entity = EntityFromJSON(entityCache, embeddingCache, ollama, tempHelper, logger, jSONEntity);
|
||||
var entity = EntityFromJSON(entityCache, embeddingCache, aIProvider, tempHelper, logger, jSONEntity);
|
||||
if (entity is not null)
|
||||
{
|
||||
retVal.Enqueue(entity);
|
||||
@@ -77,7 +79,7 @@ public static class SearchdomainHelper
|
||||
return [.. retVal];
|
||||
}
|
||||
|
||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, OllamaApiClient ollama, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
||||
{
|
||||
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
||||
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
||||
@@ -130,14 +132,14 @@ public static class SearchdomainHelper
|
||||
}
|
||||
else
|
||||
{
|
||||
var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], ollama, embeddingCache);
|
||||
var additionalEmbeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [model], aIProvider, embeddingCache);
|
||||
embeddings.Add(model, additionalEmbeddings.First().Value);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
|
||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache);
|
||||
}
|
||||
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
using ElmahCore;
|
||||
using ElmahCore.Mvc;
|
||||
using Serilog;
|
||||
using server;
|
||||
using Server;
|
||||
using Server.HealthChecks;
|
||||
|
||||
@@ -17,6 +18,7 @@ Log.Logger = new LoggerConfiguration()
|
||||
.CreateLogger();
|
||||
builder.Logging.AddSerilog();
|
||||
builder.Services.AddSingleton<SearchdomainManager>();
|
||||
builder.Services.AddSingleton<AIProvider>();
|
||||
builder.Services.AddHealthChecks()
|
||||
.AddCheck<DatabaseHealthCheck>("DatabaseHealthCheck")
|
||||
.AddCheck<AIProviderHealthCheck>("AIProviderHealthChecck");
|
||||
|
||||
@@ -21,6 +21,7 @@ using Server;
|
||||
using System.Security.Cryptography;
|
||||
using System.Text;
|
||||
using System.Collections.Concurrent;
|
||||
using server;
|
||||
|
||||
namespace Server;
|
||||
|
||||
@@ -28,7 +29,7 @@ public class Searchdomain
|
||||
{
|
||||
private readonly string _connectionString;
|
||||
private readonly string _provider;
|
||||
public OllamaApiClient ollama;
|
||||
public AIProvider aIProvider;
|
||||
public string searchdomain;
|
||||
public int id;
|
||||
public Dictionary<string, List<(DateTime, List<(float, string)>)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x
|
||||
@@ -42,12 +43,12 @@ public class Searchdomain
|
||||
|
||||
// TODO Add settings and update cli/program.cs, as well as DatabaseInsertSearchdomain()
|
||||
|
||||
public Searchdomain(string searchdomain, string connectionString, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
||||
public Searchdomain(string searchdomain, string connectionString, AIProvider aIProvider, Dictionary<string, Dictionary<string, float[]>> embeddingCache, ILogger logger, string provider = "sqlserver", bool runEmpty = false)
|
||||
{
|
||||
_connectionString = connectionString;
|
||||
_provider = provider.ToLower();
|
||||
this.searchdomain = searchdomain;
|
||||
this.ollama = ollama;
|
||||
this.aIProvider = aIProvider;
|
||||
this.embeddingCache = embeddingCache;
|
||||
this._logger = logger;
|
||||
searchCache = [];
|
||||
@@ -69,7 +70,7 @@ public class Searchdomain
|
||||
{
|
||||
["id"] = this.id
|
||||
};
|
||||
DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain);
|
||||
DbDataReader embeddingReader = helper.ExecuteSQLCommand("SELECT embedding.id, id_datapoint, model, embedding FROM embedding", parametersIDSearchdomain); // TODO fix: parametersIDSearchdomain defined, but not used
|
||||
Dictionary<int, Dictionary<string, float[]>> embedding_unassigned = [];
|
||||
while (embeddingReader.Read())
|
||||
{
|
||||
@@ -162,7 +163,7 @@ public class Searchdomain
|
||||
{
|
||||
if (!embeddingCache.TryGetValue(query, out Dictionary<string, float[]>? queryEmbeddings))
|
||||
{
|
||||
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, ollama);
|
||||
queryEmbeddings = Datapoint.GenerateEmbeddings(query, modelsInUse, aIProvider);
|
||||
if (embeddingCache.Count < embeddingCacheMaxSize) // TODO add better way of managing cache limit hits
|
||||
{ // Idea: Add access count to each entry. On limit hit, sort the entries by access count and remove the bottom 10% of entries
|
||||
embeddingCache.Add(query, queryEmbeddings);
|
||||
|
||||
@@ -4,6 +4,7 @@ using OllamaSharp;
|
||||
using Microsoft.IdentityModel.Tokens;
|
||||
using Server.Exceptions;
|
||||
using Server.Migrations;
|
||||
using server;
|
||||
|
||||
namespace Server;
|
||||
|
||||
@@ -12,25 +13,20 @@ public class SearchdomainManager
|
||||
private Dictionary<string, Searchdomain> searchdomains = [];
|
||||
private readonly ILogger<SearchdomainManager> _logger;
|
||||
private readonly IConfiguration _config;
|
||||
public readonly AIProvider aIProvider;
|
||||
private readonly string ollamaURL;
|
||||
private readonly string connectionString;
|
||||
public OllamaApiClient client;
|
||||
private MySqlConnection connection;
|
||||
public SQLHelper helper;
|
||||
public Dictionary<string, Dictionary<string, float[]>> embeddingCache;
|
||||
|
||||
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config)
|
||||
public SearchdomainManager(ILogger<SearchdomainManager> logger, IConfiguration config, AIProvider aIProvider)
|
||||
{
|
||||
_logger = logger;
|
||||
_config = config;
|
||||
this.aIProvider = aIProvider;
|
||||
embeddingCache = [];
|
||||
ollamaURL = _config.GetSection("Embeddingsearch")["OllamaURL"] ?? "";
|
||||
connectionString = _config.GetSection("Embeddingsearch").GetConnectionString("SQL") ?? "";
|
||||
if (ollamaURL.IsNullOrEmpty() || connectionString.IsNullOrEmpty())
|
||||
{
|
||||
throw new ServerConfigurationException("Ollama URL or connection string is empty");
|
||||
}
|
||||
client = new(new Uri(ollamaURL));
|
||||
connection = new MySqlConnection(connectionString);
|
||||
connection.Open();
|
||||
helper = new SQLHelper(connection, connectionString);
|
||||
@@ -53,13 +49,18 @@ public class SearchdomainManager
|
||||
}
|
||||
try
|
||||
{
|
||||
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, client, embeddingCache, _logger));
|
||||
return SetSearchdomain(searchdomain, new Searchdomain(searchdomain, connectionString, aIProvider, embeddingCache, _logger));
|
||||
}
|
||||
catch (MySqlException)
|
||||
{
|
||||
_logger.LogError("Unable to find the searchdomain {searchdomain}", searchdomain);
|
||||
throw new Exception($"Unable to find the searchdomain {searchdomain}");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError("Unable to load the searchdomain {searchdomain} due to the following exception: {ex}", [searchdomain, ex.Message]);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
public void InvalidateSearchdomainCache(string searchdomainName)
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="ElmahCore" Version="2.1.2" />
|
||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
|
||||
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
|
||||
<PackageReference Include="Serilog.Sinks.File" Version="7.0.0" />
|
||||
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.6.2" />
|
||||
|
||||
@@ -24,7 +24,17 @@
|
||||
"172.17.0.1"
|
||||
]
|
||||
},
|
||||
"OllamaURL": "http://localhost:11434",
|
||||
"AiProviders": {
|
||||
"ollama": {
|
||||
"handler": "ollama",
|
||||
"baseURL": "http://192.168.0.101:11434"
|
||||
},
|
||||
"localAI": {
|
||||
"handler": "openai",
|
||||
"baseURL": "http://localhost:8080",
|
||||
"ApiKey": "Some API key here"
|
||||
}
|
||||
},
|
||||
"ApiKeys": ["Some UUID here", "Another UUID here"],
|
||||
"UseHttpsRedirection": true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user