diff --git a/src/Client/Client.cs b/src/Client/Client.cs index 85718a0..b4df755 100644 --- a/src/Client/Client.cs +++ b/src/Client/Client.cs @@ -8,6 +8,8 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Configuration; using System.Reflection.Metadata.Ecma335; using Shared.Models; +using System.Net; +using Microsoft.Extensions.Options; namespace Client; @@ -24,12 +26,12 @@ public class Client this.searchdomain = searchdomain; } - public Client(IConfiguration configuration) + public Client(IOptions configuration) { - string? baseUri = configuration.GetSection("Embeddingsearch").GetValue("BaseUri"); - string? apiKey = configuration.GetSection("Embeddingsearch").GetValue("ApiKey"); - string? searchdomain = configuration.GetSection("Embeddingsearch").GetValue("Searchdomain"); - this.baseUri = baseUri ?? ""; + string baseUri = configuration.Value.BaseUri; + string? apiKey = configuration.Value.ApiKey; + string? searchdomain = configuration.Value.Searchdomain; + this.baseUri = baseUri; this.apiKey = apiKey ?? ""; this.searchdomain = searchdomain ?? ""; } @@ -41,8 +43,8 @@ public class Client public async Task EntityListAsync(string searchdomain, bool returnEmbeddings = false) { - var url = $"{baseUri}/Entities?apiKey={HttpUtility.UrlEncode(apiKey)}&searchdomain={HttpUtility.UrlEncode(searchdomain)}&returnEmbeddings={HttpUtility.UrlEncode(returnEmbeddings.ToString())}"; - return await GetUrlAndProcessJson(url); + var url = $"{baseUri}/Entities?searchdomain={HttpUtility.UrlEncode(searchdomain)}&returnEmbeddings={HttpUtility.UrlEncode(returnEmbeddings.ToString())}"; + return await FetchUrlAndProcessJson(HttpMethod.Get, url); } public async Task EntityIndexAsync(List jsonEntity) @@ -53,7 +55,7 @@ public class Client public async Task EntityIndexAsync(string jsonEntity) { var content = new StringContent(jsonEntity, Encoding.UTF8, "application/json"); - return await PutUrlAndProcessJson(GetUrl($"{baseUri}", "Entities", apiKey, []), content); + return await FetchUrlAndProcessJson(HttpMethod.Put, GetUrl($"{baseUri}", "Entities", []), content); } public async Task EntityDeleteAsync(string entityName) @@ -64,12 +66,12 @@ public class Client public async Task EntityDeleteAsync(string searchdomain, string entityName) { var url = $"{baseUri}/Entity?apiKey={HttpUtility.UrlEncode(apiKey)}&searchdomain={HttpUtility.UrlEncode(searchdomain)}&entity={HttpUtility.UrlEncode(entityName)}"; - return await DeleteUrlAndProcessJson(url); + return await FetchUrlAndProcessJson(HttpMethod.Delete, url); } public async Task SearchdomainListAsync() { - return await GetUrlAndProcessJson(GetUrl($"{baseUri}", "Searchdomains", apiKey, [])); + return await FetchUrlAndProcessJson(HttpMethod.Get, GetUrl($"{baseUri}", "Searchdomains", [])); } public async Task SearchdomainCreateAsync() @@ -79,7 +81,7 @@ public class Client public async Task SearchdomainCreateAsync(string searchdomain, SearchdomainSettings searchdomainSettings = new()) { - return await PostUrlAndProcessJson(GetUrl($"{baseUri}", "Searchdomain", apiKey, new Dictionary() + return await FetchUrlAndProcessJson(HttpMethod.Post, GetUrl($"{baseUri}", "Searchdomain", new Dictionary() { {"searchdomain", searchdomain} }), new StringContent(JsonSerializer.Serialize(searchdomainSettings), Encoding.UTF8, "application/json")); @@ -92,7 +94,7 @@ public class Client public async Task SearchdomainDeleteAsync(string searchdomain) { - return await DeleteUrlAndProcessJson(GetUrl($"{baseUri}", "Searchdomain", apiKey, new Dictionary() + return await FetchUrlAndProcessJson(HttpMethod.Delete, GetUrl($"{baseUri}", "Searchdomain", new Dictionary() { {"searchdomain", searchdomain} })); @@ -112,7 +114,7 @@ public class Client public async Task SearchdomainUpdateAsync(string searchdomain, string newName, string settings = "{}") { - return await PutUrlAndProcessJson(GetUrl($"{baseUri}", "Searchdomain", apiKey, new Dictionary() + return await FetchUrlAndProcessJson(HttpMethod.Put, GetUrl($"{baseUri}", "Searchdomain", new Dictionary() { {"searchdomain", searchdomain}, {"newName", newName} @@ -125,7 +127,7 @@ public class Client { {"searchdomain", searchdomain} }; - return await GetUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain", "Queries", apiKey, parameters)); + return await FetchUrlAndProcessJson(HttpMethod.Get, GetUrl($"{baseUri}/Searchdomain", "Queries", parameters)); } public async Task SearchdomainQueryAsync(string query) @@ -143,7 +145,7 @@ public class Client if (topN is not null) parameters.Add("topN", ((int)topN).ToString()); if (returnAttributes) parameters.Add("returnAttributes", returnAttributes.ToString()); - return await PostUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain", "Query", apiKey, parameters), null); + return await FetchUrlAndProcessJson(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain", "Query", parameters), null); } public async Task SearchdomainDeleteQueryAsync(string searchdomain, string query) @@ -153,7 +155,7 @@ public class Client {"searchdomain", searchdomain}, {"query", query} }; - return await DeleteUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain", "Query", apiKey, parameters)); + return await FetchUrlAndProcessJson(HttpMethod.Delete, GetUrl($"{baseUri}/Searchdomain", "Query", parameters)); } public async Task SearchdomainUpdateQueryAsync(string searchdomain, string query, List results) @@ -163,8 +165,9 @@ public class Client {"searchdomain", searchdomain}, {"query", query} }; - return await PatchUrlAndProcessJson( - GetUrl($"{baseUri}/Searchdomain", "Query", apiKey, parameters), + return await FetchUrlAndProcessJson( + HttpMethod.Patch, + GetUrl($"{baseUri}/Searchdomain", "Query", parameters), new StringContent(JsonSerializer.Serialize(results), Encoding.UTF8, "application/json")); } @@ -174,7 +177,7 @@ public class Client { {"searchdomain", searchdomain} }; - return await GetUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain", "Settings", apiKey, parameters)); + return await FetchUrlAndProcessJson(HttpMethod.Get, GetUrl($"{baseUri}/Searchdomain", "Settings", parameters)); } public async Task SearchdomainUpdateSettingsAsync(string searchdomain, SearchdomainSettings searchdomainSettings) @@ -184,7 +187,7 @@ public class Client {"searchdomain", searchdomain} }; StringContent content = new(JsonSerializer.Serialize(searchdomainSettings), Encoding.UTF8, "application/json"); - return await PutUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain", "Settings", apiKey, parameters), content); + return await FetchUrlAndProcessJson(HttpMethod.Put, GetUrl($"{baseUri}/Searchdomain", "Settings", parameters), content); } public async Task SearchdomainGetQueryCacheSizeAsync(string searchdomain) @@ -193,7 +196,7 @@ public class Client { {"searchdomain", searchdomain} }; - return await GetUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain/QueryCache", "Size", apiKey, parameters)); + return await FetchUrlAndProcessJson(HttpMethod.Get, GetUrl($"{baseUri}/Searchdomain/QueryCache", "Size", parameters)); } public async Task SearchdomainClearQueryCache(string searchdomain) @@ -202,7 +205,7 @@ public class Client { {"searchdomain", searchdomain} }; - return await PostUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain/QueryCache", "Clear", apiKey, parameters), null); + return await FetchUrlAndProcessJson(HttpMethod.Post, GetUrl($"{baseUri}/Searchdomain/QueryCache", "Clear", parameters), null); } public async Task SearchdomainGetDatabaseSizeAsync(string searchdomain) @@ -211,74 +214,40 @@ public class Client { {"searchdomain", searchdomain} }; - return await GetUrlAndProcessJson(GetUrl($"{baseUri}/Searchdomain/Database", "Size", apiKey, parameters)); + return await FetchUrlAndProcessJson(HttpMethod.Get, GetUrl($"{baseUri}/Searchdomain/Database", "Size", parameters)); } public async Task ServerGetModelsAsync() { - return await GetUrlAndProcessJson(GetUrl($"{baseUri}/Server", "Models", apiKey, [])); + return await FetchUrlAndProcessJson(HttpMethod.Get, GetUrl($"{baseUri}/Server", "Models", [])); } public async Task ServerGetEmbeddingCacheSizeAsync() { - return await GetUrlAndProcessJson(GetUrl($"{baseUri}/Server/EmbeddingCache", "Size", apiKey, [])); + return await FetchUrlAndProcessJson(HttpMethod.Get, GetUrl($"{baseUri}/Server/EmbeddingCache", "Size", [])); } - private static async Task GetUrlAndProcessJson(string url) + private async Task FetchUrlAndProcessJson(HttpMethod httpMethod, string url, HttpContent? content = null) { + HttpRequestMessage requestMessage = new(httpMethod, url) + { + Content = content, + }; + requestMessage.Headers.Add("X-API-KEY", apiKey); using var client = new HttpClient(); - var response = await client.GetAsync(url); + var response = await client.SendAsync(requestMessage); string responseContent = await response.Content.ReadAsStringAsync(); + if (response.StatusCode == HttpStatusCode.Forbidden || response.StatusCode == HttpStatusCode.Unauthorized) throw new UnauthorizedAccessException(responseContent); // TODO implement distinct exceptions + if (response.StatusCode == HttpStatusCode.InternalServerError) throw new Exception($"Request was unsuccessful due to an internal server error: {responseContent}"); // TODO implement proper InternalServerErrorException var result = JsonSerializer.Deserialize(responseContent) ?? throw new Exception($"Failed to deserialize JSON to type {typeof(T).Name}"); return result; } - private static async Task PostUrlAndProcessJson(string url, HttpContent? content) - { - using var client = new HttpClient(); - var response = await client.PostAsync(url, content); - string responseContent = await response.Content.ReadAsStringAsync(); - var result = JsonSerializer.Deserialize(responseContent) - ?? throw new Exception($"Failed to deserialize JSON to type {typeof(T).Name}"); - return result; - } - - private static async Task PutUrlAndProcessJson(string url, HttpContent content) - { - using var client = new HttpClient(); - var response = await client.PutAsync(url, content); - string responseContent = await response.Content.ReadAsStringAsync(); - var result = JsonSerializer.Deserialize(responseContent) - ?? throw new Exception($"Failed to deserialize JSON to type {typeof(T).Name}"); - return result; - } - - private static async Task PatchUrlAndProcessJson(string url, HttpContent content) - { - using var client = new HttpClient(); - var response = await client.PatchAsync(url, content); - string responseContent = await response.Content.ReadAsStringAsync(); - var result = JsonSerializer.Deserialize(responseContent) - ?? throw new Exception($"Failed to deserialize JSON to type {typeof(T).Name}"); - return result; - } - - private static async Task DeleteUrlAndProcessJson(string url) - { - using var client = new HttpClient(); - var response = await client.DeleteAsync(url); - string responseContent = await response.Content.ReadAsStringAsync(); - var result = JsonSerializer.Deserialize(responseContent) - ?? throw new Exception($"Failed to deserialize JSON to type {typeof(T).Name}"); - return result; - } - - public static string GetUrl(string baseUri, string endpoint, string apiKey, Dictionary parameters) + public static string GetUrl(string baseUri, string endpoint, Dictionary parameters) { var uriBuilder = new UriBuilder($"{baseUri}/{endpoint}"); var query = HttpUtility.ParseQueryString(uriBuilder.Query); - if (apiKey.Length > 0) query["apiKey"] = apiKey; foreach (var param in parameters) { query[param.Key] = param.Value; diff --git a/src/Indexer/Indexer.csproj b/src/Indexer/Indexer.csproj index 8c2a30b..c3d7d71 100644 --- a/src/Indexer/Indexer.csproj +++ b/src/Indexer/Indexer.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Indexer/Models/OptionModels.cs b/src/Indexer/Models/OptionModels.cs new file mode 100644 index 0000000..6b33793 --- /dev/null +++ b/src/Indexer/Models/OptionModels.cs @@ -0,0 +1,9 @@ +using Shared.Models; +namespace Indexer.Models; + +public class IndexerOptions : ApiKeyOptions +{ + public required WorkerConfig[] Workers { get; set; } + public required ServerOptions Server { get; set;} + public required string PythonRuntime { get; set; } = "libpython3.13.so"; +} diff --git a/src/Indexer/Models/ScriptModels.cs b/src/Indexer/Models/ScriptModels.cs index 3f3043e..252269a 100644 --- a/src/Indexer/Models/ScriptModels.cs +++ b/src/Indexer/Models/ScriptModels.cs @@ -15,11 +15,11 @@ public class ScriptToolSet public Client.Client Client; public LoggerWrapper Logger; public ICallbackInfos? CallbackInfos; - public IConfiguration Configuration; + public IndexerOptions Configuration; public CancellationToken CancellationToken; public string Name; - public ScriptToolSet(string filePath, Client.Client client, ILogger logger, IConfiguration configuration, CancellationToken cancellationToken, string name) + public ScriptToolSet(string filePath, Client.Client client, ILogger logger, IndexerOptions configuration, CancellationToken cancellationToken, string name) { Configuration = configuration; Name = name; diff --git a/src/Indexer/Program.cs b/src/Indexer/Program.cs index a78af58..fbef06d 100644 --- a/src/Indexer/Program.cs +++ b/src/Indexer/Program.cs @@ -6,6 +6,8 @@ using ElmahCore.Mvc; using ElmahCore.Mvc.Logger; using Serilog; using Quartz; +using System.Configuration; +using Shared.Models; var builder = WebApplication.CreateBuilder(args); @@ -21,6 +23,12 @@ Log.Logger = new LoggerConfiguration() builder.Logging.AddSerilog(); builder.Services.AddHttpContextAccessor(); builder.Services.AddSingleton(builder.Configuration); + +IConfigurationSection configurationSection = builder.Configuration.GetSection("Indexer"); +IndexerOptions configuration = configurationSection.Get() ?? throw new ConfigurationErrorsException("Unable to start server due to an invalid configration"); +builder.Services.Configure(configurationSection); +builder.Services.Configure(configurationSection.GetSection("Server")); +builder.Services.Configure(configurationSection); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddHostedService(); diff --git a/src/Indexer/ScriptContainers/PythonScriptContainer.cs b/src/Indexer/ScriptContainers/PythonScriptContainer.cs index 5202a85..0e186e1 100644 --- a/src/Indexer/ScriptContainers/PythonScriptContainer.cs +++ b/src/Indexer/ScriptContainers/PythonScriptContainer.cs @@ -15,11 +15,8 @@ public class PythonScriptable : IScriptContainer public ILogger _logger { get; set; } public PythonScriptable(ScriptToolSet toolSet, ILogger logger) { - string? runtime = toolSet.Configuration.GetValue("EmbeddingsearchIndexer:PythonRuntime"); - if (runtime is not null) - { - Runtime.PythonDLL ??= runtime; - } + string runtime = toolSet.Configuration.PythonRuntime; + Runtime.PythonDLL ??= runtime; _logger = logger; SourceLoaded = false; if (!PythonEngine.IsInitialized) diff --git a/src/Indexer/WorkerManager.cs b/src/Indexer/WorkerManager.cs index 9a2e2c1..1549078 100644 --- a/src/Indexer/WorkerManager.cs +++ b/src/Indexer/WorkerManager.cs @@ -1,21 +1,22 @@ using Indexer.Exceptions; using Indexer.Models; using Indexer.ScriptContainers; +using Microsoft.Extensions.Options; public class WorkerManager { public Dictionary Workers; public List types; private readonly ILogger _logger; - private readonly IConfiguration _configuration; + private readonly IndexerOptions _configuration; private readonly Client.Client client; - public WorkerManager(ILogger logger, IConfiguration configuration, Client.Client client) + public WorkerManager(ILogger logger, IOptions configuration, Client.Client client) { Workers = []; types = [typeof(PythonScriptable), typeof(CSharpScriptable)]; _logger = logger; - _configuration = configuration; + _configuration = configuration.Value; this.client = client; } @@ -23,27 +24,12 @@ public class WorkerManager { _logger.LogInformation("Initializing workers"); // Load and configure all workers - var sectionMain = _configuration.GetSection("EmbeddingsearchIndexer"); - if (!sectionMain.Exists()) - { - _logger.LogCritical("Unable to load section \"EmbeddingsearchIndexer\""); - throw new IndexerConfigurationException("Unable to load section \"EmbeddingsearchIndexer\""); - } - WorkerCollectionConfig? sectionWorker = (WorkerCollectionConfig?)sectionMain.Get(typeof(WorkerCollectionConfig)); //GetValue("Worker"); - if (sectionWorker is not null) + foreach (WorkerConfig workerConfig in _configuration.Workers) { - foreach (WorkerConfig workerConfig in sectionWorker.Worker) - { - CancellationTokenSource cancellationTokenSource = new(); - ScriptToolSet toolSet = new(workerConfig.Script, client, _logger, _configuration, cancellationTokenSource.Token, workerConfig.Name); - InitializeWorker(toolSet, workerConfig, cancellationTokenSource); - } - } - else - { - _logger.LogCritical("Unable to load section \"Worker\""); - throw new IndexerConfigurationException("Unable to load section \"Worker\""); + CancellationTokenSource cancellationTokenSource = new(); + ScriptToolSet toolSet = new(workerConfig.Script, client, _logger, _configuration, cancellationTokenSource.Token, workerConfig.Name); + InitializeWorker(toolSet, workerConfig, cancellationTokenSource); } _logger.LogInformation("Initialized workers"); } diff --git a/src/Server/AIProvider.cs b/src/Server/AIProvider.cs index 1e5c485..a166a85 100644 --- a/src/Server/AIProvider.cs +++ b/src/Server/AIProvider.cs @@ -1,24 +1,25 @@ using System.Text; +using System.Text.RegularExpressions; +using Microsoft.Extensions.Options; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Server.Exceptions; +using Server.Models; namespace Server; public class AIProvider { private readonly ILogger _logger; - private readonly IConfiguration _configuration; - public AIProvidersConfiguration aIProvidersConfiguration; + private readonly EmbeddingSearchOptions _configuration; + public Dictionary aIProvidersConfiguration; - public AIProvider(ILogger logger, IConfiguration configuration) + public AIProvider(ILogger logger, IOptions configuration) { _logger = logger; - _configuration = configuration; - AIProvidersConfiguration? retrievedAiProvidersConfiguration = _configuration - .GetSection("Embeddingsearch") - .Get(); + _configuration = configuration.Value; + Dictionary? retrievedAiProvidersConfiguration = _configuration.AiProviders; if (retrievedAiProvidersConfiguration is null) { _logger.LogCritical("Unable to build AIProvidersConfiguration. Please check your configuration."); @@ -35,8 +36,8 @@ public class AIProvider Uri uri = new(modelUri); string provider = uri.Scheme; string model = uri.AbsolutePath; - AIProviderConfiguration? aIProvider = aIProvidersConfiguration.AiProviders - .FirstOrDefault(x => String.Equals(x.Key.ToLower(), provider.ToLower())) + AiProvider? aIProvider = aIProvidersConfiguration + .FirstOrDefault(x => string.Equals(x.Key.ToLower(), provider.ToLower())) .Value; if (aIProvider is null) { @@ -119,12 +120,12 @@ public class AIProvider public string[] GetModels() { - var aIProviders = aIProvidersConfiguration.AiProviders; + var aIProviders = aIProvidersConfiguration; List results = []; - foreach (KeyValuePair aIProviderKV in aIProviders) + foreach (KeyValuePair aIProviderKV in aIProviders) { string aIProviderName = aIProviderKV.Key; - AIProviderConfiguration aIProvider = aIProviderKV.Value; + AiProvider aIProvider = aIProviderKV.Value; using var httpClient = new HttpClient(); @@ -178,7 +179,12 @@ public class AIProvider foreach (string? result in aIProviderResult) { if (result is null) continue; - results.Add(aIProviderName + ":" + result); + bool isInAllowList = ElementMatchesAnyRegexInList(result, aIProvider.Allowlist); + bool isInDenyList = ElementMatchesAnyRegexInList(result, aIProvider.Denylist); + if (isInAllowList && !isInDenyList) + { + results.Add(aIProviderName + ":" + result); + } } } catch (Exception ex) @@ -189,6 +195,11 @@ public class AIProvider } return [.. results]; } + + private static bool ElementMatchesAnyRegexInList(string element, string[] list) + { + return list?.Any(pattern => pattern != null && Regex.IsMatch(element, pattern)) ?? false; + } } public class AIProvidersConfiguration diff --git a/src/Server/Controllers/AccountController.cs b/src/Server/Controllers/AccountController.cs index af218ec..bfa0a6f 100644 --- a/src/Server/Controllers/AccountController.cs +++ b/src/Server/Controllers/AccountController.cs @@ -12,9 +12,9 @@ public class AccountController : Controller { private readonly SimpleAuthOptions _options; - public AccountController(IOptions options) + public AccountController(IOptions options) { - _options = options.Value; + _options = options.Value.SimpleAuth; } [HttpGet("Login")] diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 2e0d6b4..69b271f 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -218,6 +218,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp { searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(preexistingEntity); } + searchdomain.UpdateModelsInUse(); return preexistingEntity; } else @@ -243,6 +244,7 @@ public class SearchdomainHelper(ILogger logger, DatabaseHelp }; entityCache.Add(entity); searchdomain.ReconciliateOrInvalidateCacheForNewOrUpdatedEntity(entity); + searchdomain.UpdateModelsInUse(); return entity; } } diff --git a/src/Server/Models/Auth.cs b/src/Server/Models/Auth.cs deleted file mode 100644 index 1bd0fa4..0000000 --- a/src/Server/Models/Auth.cs +++ /dev/null @@ -1,13 +0,0 @@ -namespace Server.Models; - -public class SimpleAuthOptions -{ - public List Users { get; set; } = new(); -} - -public class SimpleUser -{ - public string Username { get; set; } = ""; - public string Password { get; set; } = ""; - public string[] Roles { get; set; } = Array.Empty(); -} diff --git a/src/Server/Models/ConfigModels.cs b/src/Server/Models/ConfigModels.cs new file mode 100644 index 0000000..af55494 --- /dev/null +++ b/src/Server/Models/ConfigModels.cs @@ -0,0 +1,36 @@ +using System.Configuration; +using ElmahCore; +using Shared.Models; + +namespace Server.Models; + +public class EmbeddingSearchOptions : ApiKeyOptions +{ + public required ConnectionStringsSection ConnectionStrings { get; set; } + public ElmahOptions? Elmah { get; set; } + public required long EmbeddingCacheMaxCount { get; set; } + public required Dictionary AiProviders { get; set; } + public required SimpleAuthOptions SimpleAuth { get; set; } + public required bool UseHttpsRedirection { get; set; } +} + +public class AiProvider +{ + public required string Handler { get; set; } + public required string BaseURL { get; set; } + public string? ApiKey { get; set; } + public required string[] Allowlist { get; set; } + public required string[] Denylist { get; set; } +} + +public class SimpleAuthOptions +{ + public List Users { get; set; } = []; +} + +public class SimpleUser +{ + public string Username { get; set; } = ""; + public string Password { get; set; } = ""; + public string[] Roles { get; set; } = []; +} diff --git a/src/Server/Program.cs b/src/Server/Program.cs index 2fa7ccc..37b6bdc 100644 --- a/src/Server/Program.cs +++ b/src/Server/Program.cs @@ -10,11 +10,13 @@ using Server.Models; using Server.Services; using System.Text.Json.Serialization; using System.Reflection; +using System.Configuration; +using Microsoft.OpenApi.Models; +using Shared.Models; var builder = WebApplication.CreateBuilder(args); -// Add services to the container. - +// Add Controllers with views & string conversion for enums builder.Services.AddControllersWithViews() .AddJsonOptions(options => { @@ -23,6 +25,13 @@ builder.Services.AddControllersWithViews() ); }); +// Add Configuration +IConfigurationSection configurationSection = builder.Configuration.GetSection("Embeddingsearch"); +EmbeddingSearchOptions configuration = configurationSection.Get() ?? throw new ConfigurationErrorsException("Unable to start server due to an invalid configration"); + +builder.Services.Configure(configurationSection); +builder.Services.Configure(configurationSection); + // Add Localization builder.Services.AddLocalization(options => options.ResourcesPath = "Resources"); builder.Services.Configure(options => @@ -43,6 +52,31 @@ builder.Services.AddSwaggerGen(c => var xmlFile = $"{Assembly.GetExecutingAssembly().GetName().Name}.xml"; var xmlPath = Path.Combine(AppContext.BaseDirectory, xmlFile); c.IncludeXmlComments(xmlPath); + if (configuration.ApiKeys is not null) + { + c.AddSecurityDefinition("ApiKey", new OpenApiSecurityScheme + { + Description = "ApiKey must appear in header", + Type = SecuritySchemeType.ApiKey, + Name = "X-API-KEY", + In = ParameterLocation.Header, + Scheme = "ApiKeyScheme" + }); + var key = new OpenApiSecurityScheme() + { + Reference = new OpenApiReference + { + Type = ReferenceType.SecurityScheme, + Id = "ApiKey" + }, + In = ParameterLocation.Header + }; + var requirement = new OpenApiSecurityRequirement + { + { key, []} + }; + c.AddSecurityRequirement(requirement); + } }); Log.Logger = new LoggerConfiguration() .ReadFrom.Configuration(builder.Configuration) @@ -58,7 +92,12 @@ builder.Services.AddHealthChecks() builder.Services.AddElmah(Options => { - Options.LogPath = builder.Configuration.GetValue("Embeddingsearch:Elmah:LogFolder") ?? "~/logs"; + Options.OnPermissionCheck = context => + context.User.Claims.Any(claim => + claim.Value.Equals("Admin", StringComparison.OrdinalIgnoreCase) + || claim.Value.Equals("Elmah", StringComparison.OrdinalIgnoreCase) + ); + Options.LogPath = configuration.Elmah?.LogPath ?? "~/logs"; }); builder.Services @@ -76,35 +115,11 @@ builder.Services.AddAuthorization(options => policy => policy.RequireRole("Admin")); }); -IConfigurationSection simpleAuthSection = builder.Configuration.GetSection("Embeddingsearch:SimpleAuth"); -if (simpleAuthSection.Exists()) builder.Services.Configure(simpleAuthSection); var app = builder.Build(); -List? allowedIps = builder.Configuration.GetSection("Embeddingsearch:Elmah:AllowedHosts") - .Get>(); - -app.Use(async (context, next) => -{ - bool requestIsElmah = context.Request.Path.StartsWithSegments("/elmah"); - bool requestIsSwagger = context.Request.Path.StartsWithSegments("/swagger"); - - if (requestIsElmah || requestIsSwagger) - { - var remoteIp = context.Connection.RemoteIpAddress?.ToString(); - bool blockRequest = allowedIps is null - || remoteIp is null - || !allowedIps.Contains(remoteIp); - if (blockRequest) - { - context.Response.StatusCode = 403; - await context.Response.WriteAsync("Forbidden"); - return; - } - } - - await next(); -}); +app.UseAuthentication(); +app.UseAuthorization(); app.UseElmah(); @@ -120,19 +135,49 @@ app.MapHealthChecks("/healthz/AIProvider", new Microsoft.AspNetCore.Diagnostics. }); bool IsDevelopment = app.Environment.IsDevelopment(); -bool useSwagger = app.Configuration.GetValue("UseSwagger"); -bool? UseMiddleware = app.Configuration.GetValue("UseMiddleware"); -// Configure the HTTP request pipeline. -if (IsDevelopment || useSwagger) +app.Use(async (context, next) => { - app.UseSwagger(); - app.UseSwaggerUI(); - //app.UseElmahExceptionPage(); // Messes with JSON response for API calls. Leaving this here so I don't accidentally put this in again later on. -} -if (UseMiddleware == true && !IsDevelopment) + if (context.Request.Path.StartsWithSegments("/swagger")) + { + if (!context.User.Identity?.IsAuthenticated ?? true) + { + context.Response.Redirect("/Account/Login"); + return; + } + + if (!context.User.IsInRole("Admin")) + { + context.Response.StatusCode = StatusCodes.Status403Forbidden; + return; + } + } + + await next(); +}); + +app.UseSwagger(); +app.UseSwaggerUI(options => { - app.UseMiddleware(); + options.EnablePersistAuthorization(); +}); +//app.UseElmahExceptionPage(); // Messes with JSON response for API calls. Leaving this here so I don't accidentally put this in again later on. + +if (configuration.ApiKeys is not null) +{ + app.UseWhen(context => + { + RouteData routeData = context.GetRouteData(); + string controllerName = routeData.Values["controller"]?.ToString() ?? "StaticFile"; + if (controllerName == "Account" || controllerName == "Home" || controllerName == "StaticFile") + { + return false; + } + return true; + }, appBuilder => + { + appBuilder.UseMiddleware(); + }); } // Add localization @@ -143,9 +188,6 @@ var localizationOptions = new RequestLocalizationOptions() .AddSupportedUICultures(supportedCultures); app.UseRequestLocalization(localizationOptions); -app.UseAuthentication(); -app.UseAuthorization(); - app.MapControllers(); app.UseStaticFiles(); diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index 421fe8d..c764203 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -216,6 +216,11 @@ public class Searchdomain return queryEmbeddings; } + public void UpdateModelsInUse() + { + modelsInUse = GetModels([.. entityCache]); + } + private static float EvaluateEntityAgainstQueryEmbeddings(Entity entity, Dictionary queryEmbeddings) { List<(string, float)> datapointProbs = []; @@ -237,16 +242,19 @@ public class Searchdomain public static List GetModels(List entities) { List result = []; - foreach (Entity entity in entities) + lock (entities) { - foreach (Datapoint datapoint in entity.datapoints) + foreach (Entity entity in entities) { - foreach ((string, float[]) tuple in datapoint.embeddings) + foreach (Datapoint datapoint in entity.datapoints) { - string model = tuple.Item1; - if (!result.Contains(model)) + foreach ((string, float[]) tuple in datapoint.embeddings) { - result.Add(model); + string model = tuple.Item1; + if (!result.Contains(model)) + { + result.Add(model); + } } } } diff --git a/src/Server/appsettings.Development.json b/src/Server/appsettings.Development.json index 7c373a1..6fa3313 100644 --- a/src/Server/appsettings.Development.json +++ b/src/Server/appsettings.Development.json @@ -18,22 +18,21 @@ "SQL": "server=localhost;database=embeddingsearch;uid=embeddingsearch;pwd=somepassword!;" }, "Elmah": { - "AllowedHosts": [ - "127.0.0.1", - "::1", - "172.17.0.1" - ] + "LogPath": "~/logs" }, "EmbeddingCacheMaxCount": 10000000, "AiProviders": { "ollama": { "handler": "ollama", - "baseURL": "http://localhost:11434" + "baseURL": "http://192.168.0.101:11434", + "Allowlist": ["*"], + "Denylist": ["qwen3-coder:latest", "qwen3:0.6b", "deepseek-v3.1:671b-cloud", "qwen3-vl", "deepseek-ocr"] }, "localAI": { "handler": "openai", - "baseURL": "http://localhost:8080", - "ApiKey": "Some API key here" + "ApiKey": "Some API key here", + "Allowlist": ["*"], + "Denylist": ["cross-encoder", "kitten-tts", "jina-reranker-v1-tiny-en", "whisper-small", "qwen3-vl-2b-instruct"] } }, "SimpleAuth": { diff --git a/src/Server/appsettings.json b/src/Server/appsettings.json index 657168d..e6db786 100644 --- a/src/Server/appsettings.json +++ b/src/Server/appsettings.json @@ -16,14 +16,5 @@ "Application": "Embeddingsearch.Server" } }, - "EmbeddingsearchIndexer": { - "Elmah": { - "AllowedHosts": [ - "127.0.0.1", - "::1" - ], - "LogFolder": "./logs" - } - }, "AllowedHosts": "*" } diff --git a/src/Shared/ApiKeyMiddleware.cs b/src/Shared/ApiKeyMiddleware.cs index 007a24b..775860b 100644 --- a/src/Shared/ApiKeyMiddleware.cs +++ b/src/Shared/ApiKeyMiddleware.cs @@ -1,38 +1,41 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; +using Shared.Models; namespace Shared; public class ApiKeyMiddleware { private readonly RequestDelegate _next; - private readonly IConfiguration _configuration; + private readonly ApiKeyOptions _configuration; - public ApiKeyMiddleware(RequestDelegate next, IConfiguration configuration) + public ApiKeyMiddleware(RequestDelegate next, IOptions configuration) { _next = next; - _configuration = configuration; + _configuration = configuration.Value; } public async Task InvokeAsync(HttpContext context) { - if (!context.Request.Headers.TryGetValue("X-API-KEY", out StringValues extractedApiKey)) + if (!(context.User.Identity?.IsAuthenticated ?? false)) { - context.Response.StatusCode = 401; - await context.Response.WriteAsync("API Key is missing."); - return; - } + if (!context.Request.Headers.TryGetValue("X-API-KEY", out StringValues extractedApiKey)) + { + context.Response.StatusCode = 401; + await context.Response.WriteAsync("API Key is missing."); + return; + } - var validApiKeys = _configuration.GetSection("Embeddingsearch").GetSection("ApiKeys").Get>(); -#pragma warning disable CS8604 - if (validApiKeys == null || !validApiKeys.Contains(extractedApiKey)) // CS8604 extractedApiKey is not null here, but the compiler still thinks that it might be. - { - context.Response.StatusCode = 403; - await context.Response.WriteAsync("Invalid API Key."); - return; + string[]? validApiKeys = _configuration.ApiKeys; + if (validApiKeys == null || !validApiKeys.ToList().Contains(extractedApiKey)) + { + context.Response.StatusCode = 403; + await context.Response.WriteAsync("Invalid API Key."); + return; + } } -#pragma warning restore CS8604 await _next(context); } diff --git a/src/Shared/Models/OptionModels.cs b/src/Shared/Models/OptionModels.cs new file mode 100644 index 0000000..f9a0196 --- /dev/null +++ b/src/Shared/Models/OptionModels.cs @@ -0,0 +1,13 @@ +namespace Shared.Models; + +public class ApiKeyOptions +{ + public string[]? ApiKeys { get; set; } +} + +public class ServerOptions +{ + public required string BaseUri { get; set; } + public string? ApiKey { get; set; } + public string? Searchdomain { get; set; } +} \ No newline at end of file