From aa95308f61462e6201afcf7f77069a4fe2f9044c Mon Sep 17 00:00:00 2001 From: LD-Reborn Date: Wed, 31 Dec 2025 03:47:40 +0100 Subject: [PATCH] Added allowlist and denylist, fixed patchy configuration with proper options models, fixed api middleware authorization issues --- src/Client/Client.cs | 11 +++--- src/Indexer/Indexer.csproj | 1 + src/Indexer/Models/OptionModels.cs | 9 +++++ src/Indexer/Models/ScriptModels.cs | 4 +- src/Indexer/Program.cs | 8 ++++ .../ScriptContainers/PythonScriptContainer.cs | 7 +--- src/Indexer/WorkerManager.cs | 30 ++++----------- src/Server/AIProvider.cs | 37 ++++++++++++------- src/Server/Controllers/AccountController.cs | 4 +- src/Server/Models/ConfigModels.cs | 8 ++-- src/Server/Program.cs | 18 +++++++-- src/Shared/ApiKeyMiddleware.cs | 35 ++++++++++-------- src/Shared/Models/OptionModels.cs | 13 +++++++ 13 files changed, 113 insertions(+), 72 deletions(-) create mode 100644 src/Indexer/Models/OptionModels.cs create mode 100644 src/Shared/Models/OptionModels.cs diff --git a/src/Client/Client.cs b/src/Client/Client.cs index be7b92e..b4df755 100644 --- a/src/Client/Client.cs +++ b/src/Client/Client.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Configuration; using System.Reflection.Metadata.Ecma335; using Shared.Models; using System.Net; +using Microsoft.Extensions.Options; namespace Client; @@ -25,12 +26,12 @@ public class Client this.searchdomain = searchdomain; } - public Client(IConfiguration configuration) + public Client(IOptions configuration) { - string? baseUri = configuration.GetSection("Embeddingsearch").GetValue("BaseUri"); - string? apiKey = configuration.GetSection("Embeddingsearch").GetValue("ApiKey"); - string? searchdomain = configuration.GetSection("Embeddingsearch").GetValue("Searchdomain"); - this.baseUri = baseUri ?? ""; + string baseUri = configuration.Value.BaseUri; + string? apiKey = configuration.Value.ApiKey; + string? searchdomain = configuration.Value.Searchdomain; + this.baseUri = baseUri; this.apiKey = apiKey ?? ""; this.searchdomain = searchdomain ?? ""; } diff --git a/src/Indexer/Indexer.csproj b/src/Indexer/Indexer.csproj index 8c2a30b..c3d7d71 100644 --- a/src/Indexer/Indexer.csproj +++ b/src/Indexer/Indexer.csproj @@ -15,6 +15,7 @@ + diff --git a/src/Indexer/Models/OptionModels.cs b/src/Indexer/Models/OptionModels.cs new file mode 100644 index 0000000..6b33793 --- /dev/null +++ b/src/Indexer/Models/OptionModels.cs @@ -0,0 +1,9 @@ +using Shared.Models; +namespace Indexer.Models; + +public class IndexerOptions : ApiKeyOptions +{ + public required WorkerConfig[] Workers { get; set; } + public required ServerOptions Server { get; set;} + public required string PythonRuntime { get; set; } = "libpython3.13.so"; +} diff --git a/src/Indexer/Models/ScriptModels.cs b/src/Indexer/Models/ScriptModels.cs index 3f3043e..252269a 100644 --- a/src/Indexer/Models/ScriptModels.cs +++ b/src/Indexer/Models/ScriptModels.cs @@ -15,11 +15,11 @@ public class ScriptToolSet public Client.Client Client; public LoggerWrapper Logger; public ICallbackInfos? CallbackInfos; - public IConfiguration Configuration; + public IndexerOptions Configuration; public CancellationToken CancellationToken; public string Name; - public ScriptToolSet(string filePath, Client.Client client, ILogger logger, IConfiguration configuration, CancellationToken cancellationToken, string name) + public ScriptToolSet(string filePath, Client.Client client, ILogger logger, IndexerOptions configuration, CancellationToken cancellationToken, string name) { Configuration = configuration; Name = name; diff --git a/src/Indexer/Program.cs b/src/Indexer/Program.cs index a78af58..fbef06d 100644 --- a/src/Indexer/Program.cs +++ b/src/Indexer/Program.cs @@ -6,6 +6,8 @@ using ElmahCore.Mvc; using ElmahCore.Mvc.Logger; using Serilog; using Quartz; +using System.Configuration; +using Shared.Models; var builder = WebApplication.CreateBuilder(args); @@ -21,6 +23,12 @@ Log.Logger = new LoggerConfiguration() builder.Logging.AddSerilog(); builder.Services.AddHttpContextAccessor(); builder.Services.AddSingleton(builder.Configuration); + +IConfigurationSection configurationSection = builder.Configuration.GetSection("Indexer"); +IndexerOptions configuration = configurationSection.Get() ?? throw new ConfigurationErrorsException("Unable to start server due to an invalid configration"); +builder.Services.Configure(configurationSection); +builder.Services.Configure(configurationSection.GetSection("Server")); +builder.Services.Configure(configurationSection); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddHostedService(); diff --git a/src/Indexer/ScriptContainers/PythonScriptContainer.cs b/src/Indexer/ScriptContainers/PythonScriptContainer.cs index 5202a85..0e186e1 100644 --- a/src/Indexer/ScriptContainers/PythonScriptContainer.cs +++ b/src/Indexer/ScriptContainers/PythonScriptContainer.cs @@ -15,11 +15,8 @@ public class PythonScriptable : IScriptContainer public ILogger _logger { get; set; } public PythonScriptable(ScriptToolSet toolSet, ILogger logger) { - string? runtime = toolSet.Configuration.GetValue("EmbeddingsearchIndexer:PythonRuntime"); - if (runtime is not null) - { - Runtime.PythonDLL ??= runtime; - } + string runtime = toolSet.Configuration.PythonRuntime; + Runtime.PythonDLL ??= runtime; _logger = logger; SourceLoaded = false; if (!PythonEngine.IsInitialized) diff --git a/src/Indexer/WorkerManager.cs b/src/Indexer/WorkerManager.cs index 9a2e2c1..1549078 100644 --- a/src/Indexer/WorkerManager.cs +++ b/src/Indexer/WorkerManager.cs @@ -1,21 +1,22 @@ using Indexer.Exceptions; using Indexer.Models; using Indexer.ScriptContainers; +using Microsoft.Extensions.Options; public class WorkerManager { public Dictionary Workers; public List types; private readonly ILogger _logger; - private readonly IConfiguration _configuration; + private readonly IndexerOptions _configuration; private readonly Client.Client client; - public WorkerManager(ILogger logger, IConfiguration configuration, Client.Client client) + public WorkerManager(ILogger logger, IOptions configuration, Client.Client client) { Workers = []; types = [typeof(PythonScriptable), typeof(CSharpScriptable)]; _logger = logger; - _configuration = configuration; + _configuration = configuration.Value; this.client = client; } @@ -23,27 +24,12 @@ public class WorkerManager { _logger.LogInformation("Initializing workers"); // Load and configure all workers - var sectionMain = _configuration.GetSection("EmbeddingsearchIndexer"); - if (!sectionMain.Exists()) - { - _logger.LogCritical("Unable to load section \"EmbeddingsearchIndexer\""); - throw new IndexerConfigurationException("Unable to load section \"EmbeddingsearchIndexer\""); - } - WorkerCollectionConfig? sectionWorker = (WorkerCollectionConfig?)sectionMain.Get(typeof(WorkerCollectionConfig)); //GetValue("Worker"); - if (sectionWorker is not null) + foreach (WorkerConfig workerConfig in _configuration.Workers) { - foreach (WorkerConfig workerConfig in sectionWorker.Worker) - { - CancellationTokenSource cancellationTokenSource = new(); - ScriptToolSet toolSet = new(workerConfig.Script, client, _logger, _configuration, cancellationTokenSource.Token, workerConfig.Name); - InitializeWorker(toolSet, workerConfig, cancellationTokenSource); - } - } - else - { - _logger.LogCritical("Unable to load section \"Worker\""); - throw new IndexerConfigurationException("Unable to load section \"Worker\""); + CancellationTokenSource cancellationTokenSource = new(); + ScriptToolSet toolSet = new(workerConfig.Script, client, _logger, _configuration, cancellationTokenSource.Token, workerConfig.Name); + InitializeWorker(toolSet, workerConfig, cancellationTokenSource); } _logger.LogInformation("Initialized workers"); } diff --git a/src/Server/AIProvider.cs b/src/Server/AIProvider.cs index 1e5c485..a166a85 100644 --- a/src/Server/AIProvider.cs +++ b/src/Server/AIProvider.cs @@ -1,24 +1,25 @@ using System.Text; +using System.Text.RegularExpressions; +using Microsoft.Extensions.Options; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Server.Exceptions; +using Server.Models; namespace Server; public class AIProvider { private readonly ILogger _logger; - private readonly IConfiguration _configuration; - public AIProvidersConfiguration aIProvidersConfiguration; + private readonly EmbeddingSearchOptions _configuration; + public Dictionary aIProvidersConfiguration; - public AIProvider(ILogger logger, IConfiguration configuration) + public AIProvider(ILogger logger, IOptions configuration) { _logger = logger; - _configuration = configuration; - AIProvidersConfiguration? retrievedAiProvidersConfiguration = _configuration - .GetSection("Embeddingsearch") - .Get(); + _configuration = configuration.Value; + Dictionary? retrievedAiProvidersConfiguration = _configuration.AiProviders; if (retrievedAiProvidersConfiguration is null) { _logger.LogCritical("Unable to build AIProvidersConfiguration. Please check your configuration."); @@ -35,8 +36,8 @@ public class AIProvider Uri uri = new(modelUri); string provider = uri.Scheme; string model = uri.AbsolutePath; - AIProviderConfiguration? aIProvider = aIProvidersConfiguration.AiProviders - .FirstOrDefault(x => String.Equals(x.Key.ToLower(), provider.ToLower())) + AiProvider? aIProvider = aIProvidersConfiguration + .FirstOrDefault(x => string.Equals(x.Key.ToLower(), provider.ToLower())) .Value; if (aIProvider is null) { @@ -119,12 +120,12 @@ public class AIProvider public string[] GetModels() { - var aIProviders = aIProvidersConfiguration.AiProviders; + var aIProviders = aIProvidersConfiguration; List results = []; - foreach (KeyValuePair aIProviderKV in aIProviders) + foreach (KeyValuePair aIProviderKV in aIProviders) { string aIProviderName = aIProviderKV.Key; - AIProviderConfiguration aIProvider = aIProviderKV.Value; + AiProvider aIProvider = aIProviderKV.Value; using var httpClient = new HttpClient(); @@ -178,7 +179,12 @@ public class AIProvider foreach (string? result in aIProviderResult) { if (result is null) continue; - results.Add(aIProviderName + ":" + result); + bool isInAllowList = ElementMatchesAnyRegexInList(result, aIProvider.Allowlist); + bool isInDenyList = ElementMatchesAnyRegexInList(result, aIProvider.Denylist); + if (isInAllowList && !isInDenyList) + { + results.Add(aIProviderName + ":" + result); + } } } catch (Exception ex) @@ -189,6 +195,11 @@ public class AIProvider } return [.. results]; } + + private static bool ElementMatchesAnyRegexInList(string element, string[] list) + { + return list?.Any(pattern => pattern != null && Regex.IsMatch(element, pattern)) ?? false; + } } public class AIProvidersConfiguration diff --git a/src/Server/Controllers/AccountController.cs b/src/Server/Controllers/AccountController.cs index 708cfe8..bfa0a6f 100644 --- a/src/Server/Controllers/AccountController.cs +++ b/src/Server/Controllers/AccountController.cs @@ -12,9 +12,9 @@ public class AccountController : Controller { private readonly SimpleAuthOptions _options; - public AccountController(EmbeddingSearchOptions options) + public AccountController(IOptions options) { - _options = options.SimpleAuth; + _options = options.Value.SimpleAuth; } [HttpGet("Login")] diff --git a/src/Server/Models/ConfigModels.cs b/src/Server/Models/ConfigModels.cs index 4ea56ff..af55494 100644 --- a/src/Server/Models/ConfigModels.cs +++ b/src/Server/Models/ConfigModels.cs @@ -1,16 +1,16 @@ using System.Configuration; using ElmahCore; +using Shared.Models; namespace Server.Models; -public class EmbeddingSearchOptions +public class EmbeddingSearchOptions : ApiKeyOptions { public required ConnectionStringsSection ConnectionStrings { get; set; } public ElmahOptions? Elmah { get; set; } public required long EmbeddingCacheMaxCount { get; set; } - public required AiProvider[] AiProviders { get; set; } + public required Dictionary AiProviders { get; set; } public required SimpleAuthOptions SimpleAuth { get; set; } - public string[]? ApiKeys { get; set; } public required bool UseHttpsRedirection { get; set; } } @@ -18,7 +18,7 @@ public class AiProvider { public required string Handler { get; set; } public required string BaseURL { get; set; } - public required string ApiKey { get; set; } + public string? ApiKey { get; set; } public required string[] Allowlist { get; set; } public required string[] Denylist { get; set; } } diff --git a/src/Server/Program.cs b/src/Server/Program.cs index 770b7f6..35dbbb1 100644 --- a/src/Server/Program.cs +++ b/src/Server/Program.cs @@ -12,6 +12,7 @@ using System.Text.Json.Serialization; using System.Reflection; using System.Configuration; using Microsoft.OpenApi.Models; +using Shared.Models; var builder = WebApplication.CreateBuilder(args); @@ -29,6 +30,7 @@ IConfigurationSection configurationSection = builder.Configuration.GetSection("E EmbeddingSearchOptions configuration = configurationSection.Get() ?? throw new ConfigurationErrorsException("Unable to start server due to an invalid configration"); builder.Services.Configure(configurationSection); +builder.Services.Configure(configurationSection); // Add Localization builder.Services.AddLocalization(options => options.ResourcesPath = "Resources"); @@ -133,8 +135,6 @@ app.MapHealthChecks("/healthz/AIProvider", new Microsoft.AspNetCore.Diagnostics. }); bool IsDevelopment = app.Environment.IsDevelopment(); -bool useSwagger = app.Configuration.GetValue("UseSwagger"); -bool? UseMiddleware = app.Configuration.GetValue("UseMiddleware"); app.UseSwagger(); app.UseSwaggerUI(options => @@ -145,7 +145,19 @@ app.UseSwaggerUI(options => if (configuration.ApiKeys is not null) { - app.UseMiddleware(); + app.UseWhen(context => + { + RouteData routeData = context.GetRouteData(); + string controllerName = routeData.Values["controller"]?.ToString() ?? "StaticFile"; + if (controllerName == "Account" || controllerName == "Home" || controllerName == "StaticFile") + { + return false; + } + return true; + }, appBuilder => + { + appBuilder.UseMiddleware(); + }); } // Add localization diff --git a/src/Shared/ApiKeyMiddleware.cs b/src/Shared/ApiKeyMiddleware.cs index 007a24b..775860b 100644 --- a/src/Shared/ApiKeyMiddleware.cs +++ b/src/Shared/ApiKeyMiddleware.cs @@ -1,38 +1,41 @@ using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; +using Shared.Models; namespace Shared; public class ApiKeyMiddleware { private readonly RequestDelegate _next; - private readonly IConfiguration _configuration; + private readonly ApiKeyOptions _configuration; - public ApiKeyMiddleware(RequestDelegate next, IConfiguration configuration) + public ApiKeyMiddleware(RequestDelegate next, IOptions configuration) { _next = next; - _configuration = configuration; + _configuration = configuration.Value; } public async Task InvokeAsync(HttpContext context) { - if (!context.Request.Headers.TryGetValue("X-API-KEY", out StringValues extractedApiKey)) + if (!(context.User.Identity?.IsAuthenticated ?? false)) { - context.Response.StatusCode = 401; - await context.Response.WriteAsync("API Key is missing."); - return; - } + if (!context.Request.Headers.TryGetValue("X-API-KEY", out StringValues extractedApiKey)) + { + context.Response.StatusCode = 401; + await context.Response.WriteAsync("API Key is missing."); + return; + } - var validApiKeys = _configuration.GetSection("Embeddingsearch").GetSection("ApiKeys").Get>(); -#pragma warning disable CS8604 - if (validApiKeys == null || !validApiKeys.Contains(extractedApiKey)) // CS8604 extractedApiKey is not null here, but the compiler still thinks that it might be. - { - context.Response.StatusCode = 403; - await context.Response.WriteAsync("Invalid API Key."); - return; + string[]? validApiKeys = _configuration.ApiKeys; + if (validApiKeys == null || !validApiKeys.ToList().Contains(extractedApiKey)) + { + context.Response.StatusCode = 403; + await context.Response.WriteAsync("Invalid API Key."); + return; + } } -#pragma warning restore CS8604 await _next(context); } diff --git a/src/Shared/Models/OptionModels.cs b/src/Shared/Models/OptionModels.cs new file mode 100644 index 0000000..f9a0196 --- /dev/null +++ b/src/Shared/Models/OptionModels.cs @@ -0,0 +1,13 @@ +namespace Shared.Models; + +public class ApiKeyOptions +{ + public string[]? ApiKeys { get; set; } +} + +public class ServerOptions +{ + public required string BaseUri { get; set; } + public string? ApiKey { get; set; } + public string? Searchdomain { get; set; } +} \ No newline at end of file