From 5f05aac9095ccb92bee099e9eaaf8dce6feb6e79 Mon Sep 17 00:00:00 2001 From: LD-Reborn Date: Wed, 21 Jan 2026 23:54:08 +0100 Subject: [PATCH] Added persistent embedding cache --- .gitignore | 1 + src/Server/Helper/CacheHelper.cs | 242 ++++++++++++++++++++++ src/Server/Helper/SQLiteHelper.cs | 76 +++++++ src/Server/Migrations/SQLiteMigrations.cs | 65 ++++++ src/Server/Models/ConfigModels.cs | 10 +- src/Server/Models/SQLHelper.cs | 109 ++++++++++ src/Server/Program.cs | 10 + src/Server/SearchdomainManager.cs | 54 ++++- src/Server/appsettings.Development.json | 11 +- 9 files changed, 570 insertions(+), 8 deletions(-) create mode 100644 src/Server/Helper/CacheHelper.cs create mode 100644 src/Server/Helper/SQLiteHelper.cs create mode 100644 src/Server/Migrations/SQLiteMigrations.cs create mode 100644 src/Server/Models/SQLHelper.cs diff --git a/.gitignore b/.gitignore index f8595c4..6c181cd 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ src/Shared/obj src/Server/wwwroot/logs/* src/Server/Tools/CriticalCSS/node_modules src/Server/Tools/CriticalCSS/package*.json +*.db* diff --git a/src/Server/Helper/CacheHelper.cs b/src/Server/Helper/CacheHelper.cs new file mode 100644 index 0000000..a80b6db --- /dev/null +++ b/src/Server/Helper/CacheHelper.cs @@ -0,0 +1,242 @@ +using System.Configuration; +using Microsoft.Data.Sqlite; +using Microsoft.Extensions.Options; +using OllamaSharp.Models; +using Server.Models; +using Shared; + +namespace Server.Helper; + +public static class CacheHelper +{ + public static EnumerableLruCache> GetEmbeddingStore(EmbeddingSearchOptions options) + { + SQLiteHelper helper = new(options); + EnumerableLruCache> embeddingCache = new((int)(options.Cache.StoreTopN ?? options.Cache.CacheTopN)); + helper.ExecuteQuery( + "SELECT cache_key, model_key, embedding, idx FROM embedding_cache ORDER BY idx ASC", [], r => + { + int embeddingOrdinal = r.GetOrdinal("embedding"); + int length = (int)r.GetBytes(embeddingOrdinal, 0, null, 0, 0); + byte[] buffer = new byte[length]; + r.GetBytes(embeddingOrdinal, 0, buffer, 0, length); + var cache_key = r.GetString(r.GetOrdinal("cache_key")); + var model_key = r.GetString(r.GetOrdinal("model_key")); + var embedding = SearchdomainHelper.FloatArrayFromBytes(buffer); + var index = r.GetInt32(r.GetOrdinal("idx")); + if (cache_key is null || model_key is null || embedding is null) + { + throw new Exception("Unable to get the embedding store due to a returned element being null"); + } + if (!embeddingCache.TryGetValue(cache_key, out Dictionary? keyElement) || keyElement is null) + { + keyElement = []; + embeddingCache[cache_key] = keyElement; + } + keyElement[model_key] = embedding; + return 0; + } + ); + embeddingCache.Capacity = (int)options.Cache.CacheTopN; + return embeddingCache; + } + + public static async Task UpdateEmbeddingStore(EnumerableLruCache> embeddingCache, EmbeddingSearchOptions options) + { + if (options.Cache.StoreTopN is not null) + { + embeddingCache.Capacity = (int)options.Cache.StoreTopN; + } + SQLiteHelper helper = new(options); + EnumerableLruCache> embeddingStore = GetEmbeddingStore(options); + + + var embeddingCacheMappings = GetCacheMappings(embeddingCache); + var embeddingCacheIndexMap = embeddingCacheMappings.positionToEntry; + var embeddingCacheObjectMap = embeddingCacheMappings.entryToPosition; + + var embeddingStoreMappings = GetCacheMappings(embeddingStore); + var embeddingStoreIndexMap = embeddingStoreMappings.positionToEntry; + var embeddingStoreObjectMap = embeddingStoreMappings.entryToPosition; + + List deletedEntries = []; + + foreach (KeyValuePair>> kv in embeddingStoreIndexMap) + { + int storeEntryIndex = kv.Key; + string storeEntryString = kv.Value.Key; + bool cacheEntryExists = embeddingCacheObjectMap.TryGetValue(storeEntryString, out int cacheEntryIndex); + + if (!cacheEntryExists) // Deleted + { + deletedEntries.Add(storeEntryIndex); + } + } + Task removeEntriesFromStoreTask = RemoveEntriesFromStore(helper, deletedEntries); + + + List<(int Index, KeyValuePair> Entry)> createdEntries = []; + List<(int Index, int NewIndex)> changedEntries = []; + List<(int Index, string Model, string Key, float[] Embedding)> AddedModels = []; + List<(int Index, string Model)> RemovedModels = []; + foreach (KeyValuePair>> kv in embeddingCacheIndexMap) + { + int cacheEntryIndex = kv.Key; + string cacheEntryString = kv.Value.Key; + + bool storeEntryExists = embeddingStoreObjectMap.TryGetValue(cacheEntryString, out int storeEntryIndex); + + if (!storeEntryExists) // Created + { + createdEntries.Add(( + Index: cacheEntryIndex, + Entry: kv.Value + )); + continue; + } + if (cacheEntryIndex != storeEntryIndex) // Changed + { + changedEntries.Add(( + Index: cacheEntryIndex, + NewIndex: storeEntryIndex + )); + } + + // Check for new/removed models + var storeModels = embeddingStoreIndexMap[storeEntryIndex].Value; + var cacheModels = kv.Value.Value; + // New models + foreach (var model in storeModels.Keys.Except(cacheModels.Keys)) + { + RemovedModels.Add(( + Index: cacheEntryIndex, + Model: model + )); + } + // Removed models + foreach (var model in cacheModels.Keys.Except(storeModels.Keys)) + { + AddedModels.Add(( + Index: cacheEntryIndex, + Model: model, + Key: cacheEntryString, + Embedding: cacheModels[model] + )); + } + } + + var taskSet = new List + { + removeEntriesFromStoreTask, + CreateEntriesInStore(helper, createdEntries), + UpdateEntryIndicesInStore(helper, changedEntries), + AddModelsToIndices(helper, AddedModels), + RemoveModelsFromIndices(helper, RemovedModels) + }; + + await Task.WhenAll(taskSet); + } + + private static async Task CreateEntriesInStore( + SQLiteHelper helper, + List<(int Index, KeyValuePair> Entry)> createdEntries) + { + helper.BulkExecuteNonQuery( + "INSERT INTO embedding_cache (cache_key, model_key, embedding, idx) VALUES (@cache_key, @model_key, @embedding, @index)", + createdEntries.SelectMany(element => { + return element.Entry.Value.Select(model => new object[] + { + new SqliteParameter("@cache_key", element.Entry.Key), + new SqliteParameter("@model_key", model.Key), + new SqliteParameter("@embedding", SearchdomainHelper.BytesFromFloatArray(model.Value)), + new SqliteParameter("@index", element.Index) + }); + }) + ); + } + + private static async Task UpdateEntryIndicesInStore( + SQLiteHelper helper, + List<(int Index, int NewIndex)> changedEntries) + { + helper.BulkExecuteNonQuery( + "UPDATE embedding_cache SET idx = @newIndex WHERE idx = @index", + changedEntries.Select(element => new object[] + { + new SqliteParameter("@index", element.Index), + new SqliteParameter("@newIndex", -element.NewIndex) // The "-" prevents in-place update collisions + }) + ); + helper.BulkExecuteNonQuery( + "UPDATE embedding_cache SET idx = @newIndex WHERE idx = @index", + changedEntries.Select(element => new object[] + { + new SqliteParameter("@index", -element.NewIndex), + new SqliteParameter("@newIndex", element.NewIndex) // Flip the negative prefix + }) + ); + } + + private static async Task RemoveEntriesFromStore( + SQLiteHelper helper, + List deletedEntries) + { + helper.BulkExecuteNonQuery( + "DELETE FROM embedding_cache WHERE idx = @index", + deletedEntries.Select(index => new object[] + { + new SqliteParameter("@index", index) + }) + ); + } + + private static async Task AddModelsToIndices( + SQLiteHelper helper, + List<(int Index, string Model, string Key, float[] Embedding)> addedModels) + { + helper.BulkExecuteNonQuery( + "INSERT INTO embedding_cache (cache_key, model_key, embedding, idx) VALUES (@cache_key, @model_key, @embedding, @index)", + addedModels.Select(element => new object[] + { + new SqliteParameter("@cache_key", element.Key), + new SqliteParameter("@model_key", element.Model), + new SqliteParameter("@embedding", SearchdomainHelper.BytesFromFloatArray(element.Embedding)), + new SqliteParameter("@index", element.Index) + }) + ); + } + + private static async Task RemoveModelsFromIndices( + SQLiteHelper helper, + List<(int Index, string Model)> removedModels) + { + helper.BulkExecuteNonQuery( + "DELETE FROM embedding_cache WHERE idx = @index AND model_key = @model", + removedModels.Select(element => new object[] + { + new SqliteParameter("@index", element.Index), + new SqliteParameter("@model", element.Model) + }) + ); + } + + + private static (Dictionary>> positionToEntry, + Dictionary entryToPosition) + GetCacheMappings(EnumerableLruCache> embeddingCache) + { + var positionToEntry = new Dictionary>>(); + var entryToPosition = new Dictionary(); + + int position = 0; + + foreach (var entry in embeddingCache) + { + positionToEntry[position] = entry; + entryToPosition[entry.Key] = position; + position++; + } + + return (positionToEntry, entryToPosition); + } +} \ No newline at end of file diff --git a/src/Server/Helper/SQLiteHelper.cs b/src/Server/Helper/SQLiteHelper.cs new file mode 100644 index 0000000..4d63752 --- /dev/null +++ b/src/Server/Helper/SQLiteHelper.cs @@ -0,0 +1,76 @@ +using System.Data; +using System.Data.Common; +using Microsoft.Data.Sqlite; +using Server.Models; +using MySql.Data.MySqlClient; +using System.Configuration; + +namespace Server.Helper; + +public class SQLiteHelper : SqlHelper, IDisposable +{ + public SQLiteHelper(DbConnection connection, string connectionString) : base(connection, connectionString) + { + Connection = connection; + ConnectionString = connectionString; + } + + public SQLiteHelper(EmbeddingSearchOptions options) : base(new SqliteConnection(options.ConnectionStrings.Cache), options.ConnectionStrings.Cache ?? "") + { + if (options.ConnectionStrings.Cache is null) + { + throw new ConfigurationErrorsException("Cache options must not be null when instantiating SQLiteHelper"); + } + ConnectionString = options.ConnectionStrings.Cache; + Connection = new SqliteConnection(ConnectionString); + } + + public override SQLiteHelper DuplicateConnection() + { + SqliteConnection newConnection = new(ConnectionString); + return new SQLiteHelper(newConnection, ConnectionString); + } + + public override int ExecuteSQLCommandGetInsertedID(string query, object[] parameters) + { + lock (Connection) + { + EnsureConnected(); + EnsureDbReaderIsClosed(); + using DbCommand command = Connection.CreateCommand(); + + command.CommandText = query; + command.Parameters.AddRange(parameters); + command.ExecuteNonQuery(); + command.CommandText = "SELECT last_insert_rowid();"; + return Convert.ToInt32(command.ExecuteScalar()); + } + } + + public int BulkExecuteNonQuery(string sql, IEnumerable parameterSets) + { + lock (Connection) + { + EnsureConnected(); + EnsureDbReaderIsClosed(); + + using var transaction = Connection.BeginTransaction(); + using var command = Connection.CreateCommand(); + + command.CommandText = sql; + command.Transaction = transaction; + + int affectedRows = 0; + + foreach (var parameters in parameterSets) + { + command.Parameters.Clear(); + command.Parameters.AddRange(parameters); + affectedRows += command.ExecuteNonQuery(); + } + + transaction.Commit(); + return affectedRows; + } + } +} \ No newline at end of file diff --git a/src/Server/Migrations/SQLiteMigrations.cs b/src/Server/Migrations/SQLiteMigrations.cs new file mode 100644 index 0000000..c7f2614 --- /dev/null +++ b/src/Server/Migrations/SQLiteMigrations.cs @@ -0,0 +1,65 @@ +using System.Data.Common; + +public static class SQLiteMigrations +{ + public static void Migrate(DbConnection conn) + { + EnableWal(conn); + + using var cmd = conn.CreateCommand(); + + cmd.CommandText = "PRAGMA user_version;"; + var version = Convert.ToInt32(cmd.ExecuteScalar()); + + if (version == 0) + { + CreateV1(conn); + SetVersion(conn, 1); + version = 1; + } + + if (version == 1) + { + // future migration + // UpdateFrom1To2(conn); + // SetVersion(conn, 2); + } + } + + private static void EnableWal(DbConnection conn) + { + using var cmd = conn.CreateCommand(); + cmd.CommandText = "PRAGMA journal_mode = WAL;"; + cmd.ExecuteNonQuery(); + } + + + private static void CreateV1(DbConnection conn) + { + using var tx = conn.BeginTransaction(); + using var cmd = conn.CreateCommand(); + + cmd.CommandText = """ + CREATE TABLE embedding_cache ( + cache_key TEXT NOT NULL, + model_key TEXT NOT NULL, + embedding BLOB NOT NULL, + idx INTEGER NOT NULL, + PRIMARY KEY (cache_key, model_key) + ); + + CREATE INDEX idx_index + ON embedding_cache(idx); + """; + + cmd.ExecuteNonQuery(); + tx.Commit(); + } + + private static void SetVersion(DbConnection conn, int version) + { + using var cmd = conn.CreateCommand(); + cmd.CommandText = $"PRAGMA user_version = {version};"; + cmd.ExecuteNonQuery(); + } +} diff --git a/src/Server/Models/ConfigModels.cs b/src/Server/Models/ConfigModels.cs index 321c642..1c318c8 100644 --- a/src/Server/Models/ConfigModels.cs +++ b/src/Server/Models/ConfigModels.cs @@ -8,9 +8,9 @@ public class EmbeddingSearchOptions : ApiKeyOptions { public required ConnectionStringsOptions ConnectionStrings { get; set; } public ElmahOptions? Elmah { get; set; } - public required long EmbeddingCacheMaxCount { get; set; } public required Dictionary AiProviders { get; set; } public required SimpleAuthOptions SimpleAuth { get; set; } + public required CacheOptions Cache { get; set; } public required bool UseHttpsRedirection { get; set; } } @@ -38,4 +38,12 @@ public class SimpleUser public class ConnectionStringsOptions { public required string SQL { get; set; } + public string? Cache { get; set; } +} + +public class CacheOptions +{ + public required long CacheTopN { get; set; } + public bool StoreEmbeddingCache { get; set; } = false; + public int? StoreTopN { get; set; } } \ No newline at end of file diff --git a/src/Server/Models/SQLHelper.cs b/src/Server/Models/SQLHelper.cs new file mode 100644 index 0000000..0edeb7f --- /dev/null +++ b/src/Server/Models/SQLHelper.cs @@ -0,0 +1,109 @@ +namespace Server.Models; +using System.Data.Common; + +public abstract partial class SqlHelper : IDisposable +{ + public DbConnection Connection { get; set; } + public DbDataReader? DbDataReader { get; set; } + public string ConnectionString { get; set; } + public SqlHelper(DbConnection connection, string connectionString) + { + Connection = connection; + ConnectionString = connectionString; + } + + public abstract SqlHelper DuplicateConnection(); + + public void Dispose() + { + Connection.Close(); + GC.SuppressFinalize(this); + } + + public DbDataReader ExecuteSQLCommand(string query, object[] parameters) + { + lock (Connection) + { + EnsureConnected(); + EnsureDbReaderIsClosed(); + using DbCommand command = Connection.CreateCommand(); + command.CommandText = query; + command.Parameters.AddRange(parameters); + DbDataReader = command.ExecuteReader(); + return DbDataReader; + } + } + + public void ExecuteQuery(string query, object[] parameters, Func map) + { + lock (Connection) + { + EnsureConnected(); + EnsureDbReaderIsClosed(); + + using var command = Connection.CreateCommand(); + command.CommandText = query; + command.Parameters.AddRange(parameters); + + using var reader = command.ExecuteReader(); + + while (reader.Read()) + { + map(reader); + } + + return; + } + } + + public int ExecuteSQLNonQuery(string query, object[] parameters) + { + lock (Connection) + { + EnsureConnected(); + EnsureDbReaderIsClosed(); + using DbCommand command = Connection.CreateCommand(); + + command.CommandText = query; + command.Parameters.AddRange(parameters); + return command.ExecuteNonQuery(); + } + } + + public abstract int ExecuteSQLCommandGetInsertedID(string query, object[] parameters); + + public bool EnsureConnected() + { + if (Connection.State != System.Data.ConnectionState.Open) + { + try + { + Connection.Close(); + Connection.Open(); + } + catch (Exception ex) + { + ElmahCore.ElmahExtensions.RaiseError(ex); + throw; + } + } + return true; + } + + public void EnsureDbReaderIsClosed() + { + int counter = 0; + int sleepTime = 10; + int timeout = 5000; + while (!(DbDataReader?.IsClosed ?? true)) + { + if (counter > timeout / sleepTime) + { + TimeoutException ex = new("Unable to ensure dbDataReader is closed"); + ElmahCore.ElmahExtensions.RaiseError(ex); + throw ex; + } + Thread.Sleep(sleepTime); + } + } +} \ No newline at end of file diff --git a/src/Server/Program.cs b/src/Server/Program.cs index a1c9244..a400de8 100644 --- a/src/Server/Program.cs +++ b/src/Server/Program.cs @@ -17,6 +17,7 @@ using Microsoft.AspNetCore.ResponseCompression; using System.Net; using System.Text; using Server.Migrations; +using Microsoft.Data.Sqlite; var builder = WebApplication.CreateBuilder(args); @@ -39,6 +40,15 @@ builder.Services.Configure(configurationSection); var helper = new SQLHelper(new MySql.Data.MySqlClient.MySqlConnection(configuration.ConnectionStrings.SQL), configuration.ConnectionStrings.SQL); DatabaseMigrations.Migrate(helper); +// Migrate SQLite cache +if (configuration.ConnectionStrings.Cache is not null) +{ + + var SqliteConnection = new SqliteConnection(configuration.ConnectionStrings.Cache); + SqliteConnection.Open(); + SQLiteMigrations.Migrate(SqliteConnection); +} + // Add Localization builder.Services.AddLocalization(options => options.ResourcesPath = "Resources"); builder.Services.Configure(options => diff --git a/src/Server/SearchdomainManager.cs b/src/Server/SearchdomainManager.cs index 432550c..34e6948 100644 --- a/src/Server/SearchdomainManager.cs +++ b/src/Server/SearchdomainManager.cs @@ -9,10 +9,11 @@ using System.Text.Json; using Microsoft.Extensions.Options; using Server.Models; using Shared; +using System.Diagnostics; namespace Server; -public class SearchdomainManager +public class SearchdomainManager : IDisposable { private Dictionary searchdomains = []; private readonly ILogger _logger; @@ -24,6 +25,7 @@ public class SearchdomainManager public SQLHelper helper; public EnumerableLruCache> embeddingCache; public long EmbeddingCacheMaxCount; + private bool disposed = false; public SearchdomainManager(ILogger logger, IOptions options, AIProvider aIProvider, DatabaseHelper databaseHelper) { @@ -31,8 +33,17 @@ public class SearchdomainManager _options = options.Value; this.aIProvider = aIProvider; _databaseHelper = databaseHelper; - EmbeddingCacheMaxCount = _options.EmbeddingCacheMaxCount; - embeddingCache = new((int)EmbeddingCacheMaxCount); + EmbeddingCacheMaxCount = _options.Cache.CacheTopN; + if (options.Value.Cache.StoreEmbeddingCache) + { + var stopwatch = Stopwatch.StartNew(); + embeddingCache = CacheHelper.GetEmbeddingStore(options.Value); + stopwatch.Stop(); + _logger.LogInformation("GetEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds); + } else + { + embeddingCache = new((int)EmbeddingCacheMaxCount); + } connectionString = _options.ConnectionStrings.SQL; connection = new MySqlConnection(connectionString); connection.Open(); @@ -80,7 +91,7 @@ public class SearchdomainManager { results.Add(reader.GetString(0)); } - return results; + return results; } finally { @@ -127,4 +138,39 @@ public class SearchdomainManager { return searchdomains.ContainsKey(name); } + + // Cleanup procedure + private async Task Cleanup() + { + try + { + if (_options.Cache.StoreEmbeddingCache) + { + var stopwatch = Stopwatch.StartNew(); + await CacheHelper.UpdateEmbeddingStore(embeddingCache, _options); + stopwatch.Stop(); + _logger.LogInformation("UpdateEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds); + } + _logger.LogInformation("SearchdomainManager cleanup completed"); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error during SearchdomainManager cleanup"); + } + } + + public void Dispose() + { + Dispose(true).Wait(); + GC.SuppressFinalize(this); + } + + protected virtual async Task Dispose(bool disposing) + { + if (!disposed && disposing) + { + await Cleanup(); + disposed = true; + } + } } diff --git a/src/Server/appsettings.Development.json b/src/Server/appsettings.Development.json index 6387c27..723e4f3 100644 --- a/src/Server/appsettings.Development.json +++ b/src/Server/appsettings.Development.json @@ -15,12 +15,12 @@ "Embeddingsearch": { "ConnectionStrings": { - "SQL": "server=localhost;database=embeddingsearch;uid=embeddingsearch;pwd=somepassword!;" + "SQL": "server=localhost;database=embeddingsearch;uid=embeddingsearch;pwd=somepassword!;", + "Cache": "Data Source=embeddings.db;Mode=ReadWriteCreate;Cache=Shared" }, "Elmah": { "LogPath": "~/logs" }, - "EmbeddingCacheMaxCount": 10000000, "AiProviders": { "ollama": { "handler": "ollama", @@ -46,6 +46,11 @@ ] }, "ApiKeys": ["Some UUID here", "Another UUID here"], - "UseHttpsRedirection": true + "UseHttpsRedirection": true, + "Cache": { + "CacheTopN": 100000, + "StoreEmbeddingCache": true, + "StoreTopN": 20000 + } } }