Added persistent embedding cache

This commit is contained in:
2026-01-21 23:54:08 +01:00
parent 76c9913485
commit 5f05aac909
9 changed files with 570 additions and 8 deletions

1
.gitignore vendored
View File

@@ -20,3 +20,4 @@ src/Shared/obj
src/Server/wwwroot/logs/*
src/Server/Tools/CriticalCSS/node_modules
src/Server/Tools/CriticalCSS/package*.json
*.db*

View File

@@ -0,0 +1,242 @@
using System.Configuration;
using Microsoft.Data.Sqlite;
using Microsoft.Extensions.Options;
using OllamaSharp.Models;
using Server.Models;
using Shared;
namespace Server.Helper;
public static class CacheHelper
{
/// <summary>
/// Loads the persisted embedding cache from the SQLite store into a new in-memory LRU cache.
/// </summary>
/// <param name="options">
/// Application options; <c>ConnectionStrings.Cache</c> selects the SQLite database,
/// <c>Cache.StoreTopN</c> (falling back to <c>Cache.CacheTopN</c>) sizes the cache while loading,
/// and <c>Cache.CacheTopN</c> becomes the final capacity.
/// </param>
/// <returns>A cache mapping each cache key to its per-model embeddings, filled in ascending <c>idx</c> order.</returns>
/// <exception cref="InvalidOperationException">A row yielded a null key, model or embedding.</exception>
public static EnumerableLruCache<string, Dictionary<string, float[]>> GetEmbeddingStore(EmbeddingSearchOptions options)
{
    // The helper owns a database connection; dispose it as soon as loading is finished.
    using SQLiteHelper helper = new(options);
    EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache = new((int)(options.Cache.StoreTopN ?? options.Cache.CacheTopN));
    helper.ExecuteQuery(
        "SELECT cache_key, model_key, embedding, idx FROM embedding_cache ORDER BY idx ASC", [], r =>
        {
            // First GetBytes call with a null buffer only reports the blob length.
            int embeddingOrdinal = r.GetOrdinal("embedding");
            int length = (int)r.GetBytes(embeddingOrdinal, 0, null, 0, 0);
            byte[] buffer = new byte[length];
            r.GetBytes(embeddingOrdinal, 0, buffer, 0, length);

            var cache_key = r.GetString(r.GetOrdinal("cache_key"));
            var model_key = r.GetString(r.GetOrdinal("model_key"));
            var embedding = SearchdomainHelper.FloatArrayFromBytes(buffer);
            if (cache_key is null || model_key is null || embedding is null)
            {
                throw new InvalidOperationException("Unable to get the embedding store due to a returned element being null");
            }

            // One dictionary per cache key; rows for the same key add further models to it.
            if (!embeddingCache.TryGetValue(cache_key, out Dictionary<string, float[]>? keyElement) || keyElement is null)
            {
                keyElement = [];
                embeddingCache[cache_key] = keyElement;
            }
            keyElement[model_key] = embedding;
            return 0;
        }
    );
    // Loading used the (possibly different) store size; switch to the runtime capacity afterwards.
    embeddingCache.Capacity = (int)options.Cache.CacheTopN;
    return embeddingCache;
}
/// <summary>
/// Brings the persistent SQLite embedding store in line with the given in-memory embedding cache.
/// </summary>
/// <remarks>
/// The persisted store is loaded again and compared with the cache by enumeration position:
/// entries missing from the cache are deleted, entries whose position differs are re-indexed,
/// entries missing from the store are inserted, and per-entry model sets are reconciled.
/// </remarks>
/// <param name="embeddingCache">
/// The in-memory cache to persist. Its capacity is set to <c>Cache.StoreTopN</c> when that option is configured.
/// </param>
/// <param name="options">Application options providing the cache connection string and cache sizes.</param>
/// <returns>A task that completes once all store modifications have been applied.</returns>
public static async Task UpdateEmbeddingStore(EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache, EmbeddingSearchOptions options)
{
    if (options.Cache.StoreTopN is not null)
    {
        embeddingCache.Capacity = (int)options.Cache.StoreTopN;
    }

    // The helper owns a database connection; release it when the synchronization is done.
    using SQLiteHelper helper = new(options);
    EnumerableLruCache<string, Dictionary<string, float[]>> embeddingStore = GetEmbeddingStore(options);

    var (cacheIndexMap, cacheObjectMap) = GetCacheMappings(embeddingCache);
    var (storeIndexMap, storeObjectMap) = GetCacheMappings(embeddingStore);

    // NOTE(review): store positions are used as the persisted idx values below. This assumes the
    // cache enumerates in the same order the rows were inserted (ascending idx) - confirm against
    // EnumerableLruCache.

    // Entries that are persisted but no longer cached.
    List<int> deletedEntries = [];
    foreach (var (storeEntryIndex, storeEntry) in storeIndexMap)
    {
        if (!cacheObjectMap.ContainsKey(storeEntry.Key))
        {
            deletedEntries.Add(storeEntryIndex);
        }
    }

    List<(int Index, KeyValuePair<string, Dictionary<string, float[]>> Entry)> createdEntries = [];
    List<(int Index, int NewIndex)> changedEntries = [];
    List<(int Index, string Model, string Key, float[] Embedding)> addedModels = [];
    List<(int Index, string Model)> removedModels = [];
    foreach (var (cacheEntryIndex, cacheEntry) in cacheIndexMap)
    {
        string cacheEntryString = cacheEntry.Key;
        if (!storeObjectMap.TryGetValue(cacheEntryString, out int storeEntryIndex)) // Created
        {
            createdEntries.Add((
                Index: cacheEntryIndex,
                Entry: cacheEntry
            ));
            continue;
        }
        if (cacheEntryIndex != storeEntryIndex) // Changed
        {
            // The persisted row currently carries the store position and must move to the cache position.
            changedEntries.Add((
                Index: storeEntryIndex,
                NewIndex: cacheEntryIndex
            ));
        }

        var storeModels = storeIndexMap[storeEntryIndex].Value;
        var cacheModels = cacheEntry.Value;
        // Removed models: persisted for this entry but absent from the cache.
        foreach (var model in storeModels.Keys.Except(cacheModels.Keys))
        {
            removedModels.Add((
                Index: cacheEntryIndex,
                Model: model
            ));
        }
        // Added models: cached for this entry but not yet persisted.
        foreach (var model in cacheModels.Keys.Except(storeModels.Keys))
        {
            addedModels.Add((
                Index: cacheEntryIndex,
                Model: model,
                Key: cacheEntryString,
                Embedding: cacheModels[model]
            ));
        }
    }

    // Order matters because rows are addressed by idx:
    //  1. drop stale rows so their idx values are free,
    //  2. move surviving rows to their new idx,
    //  3. drop stale models (addressed by the new idx),
    //  4. only then insert new rows, so they are never caught by the re-indexing in step 2.
    await RemoveEntriesFromStore(helper, deletedEntries);
    await UpdateEntryIndicesInStore(helper, changedEntries);
    await RemoveModelsFromIndices(helper, removedModels);
    await CreateEntriesInStore(helper, createdEntries);
    await AddModelsToIndices(helper, addedModels);
}
/// <summary>
/// Inserts one row per (entry, model) pair for cache entries that are not yet persisted.
/// </summary>
/// <param name="helper">Helper used to run the bulk insert.</param>
/// <param name="createdEntries">New entries together with the idx value to store them under.</param>
private static async Task CreateEntriesInStore(
    SQLiteHelper helper,
    List<(int Index, KeyValuePair<string, Dictionary<string, float[]>> Entry)> createdEntries)
{
    const string insertSql = "INSERT INTO embedding_cache (cache_key, model_key, embedding, idx) VALUES (@cache_key, @model_key, @embedding, @index)";

    // Flatten every entry into one parameter set per model it holds (evaluated lazily by the helper).
    IEnumerable<object[]> parameterSets =
        from created in createdEntries
        from modelEmbedding in created.Entry.Value
        select new object[]
        {
            new SqliteParameter("@cache_key", created.Entry.Key),
            new SqliteParameter("@model_key", modelEmbedding.Key),
            new SqliteParameter("@embedding", SearchdomainHelper.BytesFromFloatArray(modelEmbedding.Value)),
            new SqliteParameter("@index", created.Index)
        };

    helper.BulkExecuteNonQuery(insertSql, parameterSets);
}
/// <summary>
/// Re-indexes persisted rows: every row whose idx equals <c>Index</c> is moved to <c>NewIndex</c>.
/// </summary>
/// <remarks>
/// The move is done in two passes so that a row arriving at an idx cannot be mistaken for the row
/// that still has to leave that idx. Pass one parks each row at a temporary idx, pass two moves it
/// to its final idx.
/// </remarks>
/// <param name="helper">Helper used to run the bulk updates.</param>
/// <param name="changedEntries">Pairs of current idx and target idx.</param>
private static async Task UpdateEntryIndicesInStore(
    SQLiteHelper helper,
    List<(int Index, int NewIndex)> changedEntries)
{
    if (changedEntries.Count == 0)
    {
        return;
    }

    const string updateSql = "UPDATE embedding_cache SET idx = @newIndex WHERE idx = @index";

    // Temporary idx for a target n is -(n + 1): always strictly negative, so it can never equal a
    // real idx. Plain negation would map target 0 onto itself and collide with the row at idx 0.
    static int StagingIndex(int newIndex) => -newIndex - 1;

    helper.BulkExecuteNonQuery(
        updateSql,
        changedEntries.Select(element => new object[]
        {
            new SqliteParameter("@index", element.Index),
            new SqliteParameter("@newIndex", StagingIndex(element.NewIndex))
        })
    );
    helper.BulkExecuteNonQuery(
        updateSql,
        changedEntries.Select(element => new object[]
        {
            new SqliteParameter("@index", StagingIndex(element.NewIndex)),
            new SqliteParameter("@newIndex", element.NewIndex)
        })
    );
}
/// <summary>
/// Deletes all persisted rows stored under any of the given idx values.
/// </summary>
/// <param name="helper">Helper used to run the bulk delete.</param>
/// <param name="deletedEntries">The idx values whose rows are removed.</param>
private static async Task RemoveEntriesFromStore(
    SQLiteHelper helper,
    List<int> deletedEntries)
{
    const string deleteSql = "DELETE FROM embedding_cache WHERE idx = @index";

    // One single-parameter set per idx (evaluated lazily by the helper).
    IEnumerable<object[]> parameterSets =
        from staleIndex in deletedEntries
        select new object[] { new SqliteParameter("@index", staleIndex) };

    helper.BulkExecuteNonQuery(deleteSql, parameterSets);
}
/// <summary>
/// Inserts rows for models that were added to entries which are already persisted.
/// </summary>
/// <param name="helper">Helper used to run the bulk insert.</param>
/// <param name="addedModels">The idx, model key, cache key and embedding of every row to insert.</param>
private static async Task AddModelsToIndices(
    SQLiteHelper helper,
    List<(int Index, string Model, string Key, float[] Embedding)> addedModels)
{
    const string insertSql = "INSERT INTO embedding_cache (cache_key, model_key, embedding, idx) VALUES (@cache_key, @model_key, @embedding, @index)";

    // One parameter set per added model (evaluated lazily by the helper).
    IEnumerable<object[]> parameterSets =
        from added in addedModels
        select new object[]
        {
            new SqliteParameter("@cache_key", added.Key),
            new SqliteParameter("@model_key", added.Model),
            new SqliteParameter("@embedding", SearchdomainHelper.BytesFromFloatArray(added.Embedding)),
            new SqliteParameter("@index", added.Index)
        };

    helper.BulkExecuteNonQuery(insertSql, parameterSets);
}
/// <summary>
/// Deletes the rows of individual models, addressed by idx and model key.
/// </summary>
/// <param name="helper">Helper used to run the bulk delete.</param>
/// <param name="removedModels">The idx and model key of every row to delete.</param>
private static async Task RemoveModelsFromIndices(
    SQLiteHelper helper,
    List<(int Index, string Model)> removedModels)
{
    const string deleteSql = "DELETE FROM embedding_cache WHERE idx = @index AND model_key = @model";

    // One parameter set per removed model (evaluated lazily by the helper).
    IEnumerable<object[]> parameterSets =
        from removed in removedModels
        select new object[]
        {
            new SqliteParameter("@index", removed.Index),
            new SqliteParameter("@model", removed.Model)
        };

    helper.BulkExecuteNonQuery(deleteSql, parameterSets);
}
/// <summary>
/// Builds two lookup tables over a cache: enumeration position to entry, and cache key to position.
/// </summary>
/// <param name="embeddingCache">The cache to enumerate; positions start at 0 and follow its enumeration order.</param>
/// <returns>The position-to-entry map and the key-to-position map.</returns>
private static (Dictionary<int, KeyValuePair<string, Dictionary<string, float[]>>> positionToEntry,
    Dictionary<string, int> entryToPosition)
    GetCacheMappings(EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{
    Dictionary<int, KeyValuePair<string, Dictionary<string, float[]>>> byPosition = [];
    Dictionary<string, int> byKey = [];
    foreach (var cached in embeddingCache)
    {
        // Every iteration adds exactly one position, so the current count is the next free position.
        int slot = byPosition.Count;
        byPosition[slot] = cached;
        byKey[cached.Key] = slot;
    }
    return (byPosition, byKey);
}
}

View File

@@ -0,0 +1,76 @@
using System.Data;
using System.Data.Common;
using Microsoft.Data.Sqlite;
using Server.Models;
using MySql.Data.MySqlClient;
using System.Configuration;
namespace Server.Helper;
/// <summary>
/// <see cref="SqlHelper"/> implementation for the SQLite embedding-cache database.
/// </summary>
public class SQLiteHelper : SqlHelper, IDisposable
{
    /// <summary>
    /// Wraps an existing connection.
    /// </summary>
    /// <param name="connection">The connection to run commands on.</param>
    /// <param name="connectionString">The connection string, kept for <see cref="DuplicateConnection"/>.</param>
    public SQLiteHelper(DbConnection connection, string connectionString) : base(connection, connectionString)
    {
        // The base constructor already stores both values.
    }

    /// <summary>
    /// Creates a helper with a new connection built from <c>options.ConnectionStrings.Cache</c>.
    /// </summary>
    /// <param name="options">Application options holding the cache connection string.</param>
    /// <exception cref="ConfigurationErrorsException">The cache connection string is not configured.</exception>
    public SQLiteHelper(EmbeddingSearchOptions options)
        : this(CreateConnection(options), options.ConnectionStrings.Cache!)
    {
        // CreateConnection runs first and rejects a missing connection string, so the value is non-null here.
    }

    /// <summary>
    /// Validates the cache connection string and creates the single connection the helper will own.
    /// </summary>
    private static SqliteConnection CreateConnection(EmbeddingSearchOptions options)
    {
        if (options.ConnectionStrings.Cache is null)
        {
            throw new ConfigurationErrorsException("Cache options must not be null when instantiating SQLiteHelper");
        }
        return new SqliteConnection(options.ConnectionStrings.Cache);
    }

    /// <summary>
    /// Creates another helper with its own new connection to the same database.
    /// </summary>
    public override SQLiteHelper DuplicateConnection()
    {
        SqliteConnection newConnection = new(ConnectionString);
        return new SQLiteHelper(newConnection, ConnectionString);
    }

    /// <summary>
    /// Executes a statement and returns the result of <c>last_insert_rowid()</c> on the same connection.
    /// </summary>
    /// <param name="query">The SQL statement to execute.</param>
    /// <param name="parameters">The parameters added to the command.</param>
    /// <returns>The last inserted row id, converted to <see cref="int"/>.</returns>
    public override int ExecuteSQLCommandGetInsertedID(string query, object[] parameters)
    {
        lock (Connection)
        {
            EnsureConnected();
            EnsureDbReaderIsClosed();
            using DbCommand command = Connection.CreateCommand();
            command.CommandText = query;
            command.Parameters.AddRange(parameters);
            command.ExecuteNonQuery();
            // Same command, same connection: the row id belongs to the statement executed above.
            command.CommandText = "SELECT last_insert_rowid();";
            return Convert.ToInt32(command.ExecuteScalar());
        }
    }

    /// <summary>
    /// Executes one statement once per parameter set inside a single transaction.
    /// </summary>
    /// <param name="sql">The SQL statement to execute.</param>
    /// <param name="parameterSets">One parameter array per execution; enumerated while the transaction is open.</param>
    /// <returns>The sum of the affected-row counts of all executions.</returns>
    public int BulkExecuteNonQuery(string sql, IEnumerable<object[]> parameterSets)
    {
        lock (Connection)
        {
            EnsureConnected();
            EnsureDbReaderIsClosed();
            // If an execution throws, the transaction is disposed without having been committed.
            using var transaction = Connection.BeginTransaction();
            using var command = Connection.CreateCommand();
            command.CommandText = sql;
            command.Transaction = transaction;
            int affectedRows = 0;
            foreach (var parameters in parameterSets)
            {
                command.Parameters.Clear();
                command.Parameters.AddRange(parameters);
                affectedRows += command.ExecuteNonQuery();
            }
            transaction.Commit();
            return affectedRows;
        }
    }
}

View File

@@ -0,0 +1,65 @@
using System.Data.Common;
/// <summary>
/// Schema migrations for the SQLite embedding-cache database, tracked through <c>PRAGMA user_version</c>.
/// </summary>
public static class SQLiteMigrations
{
    /// <summary>
    /// Enables WAL journaling and applies all pending migrations.
    /// </summary>
    /// <param name="conn">The connection to migrate; commands are executed on it directly.</param>
    /// <exception cref="ArgumentNullException"><paramref name="conn"/> is null.</exception>
    public static void Migrate(DbConnection conn)
    {
        ArgumentNullException.ThrowIfNull(conn);

        EnableWal(conn);
        using var cmd = conn.CreateCommand();
        cmd.CommandText = "PRAGMA user_version;";
        var version = Convert.ToInt32(cmd.ExecuteScalar());
        if (version == 0)
        {
            CreateV1(conn);
            SetVersion(conn, 1);
            version = 1;
        }
        if (version == 1)
        {
            // future migration
            // UpdateFrom1To2(conn);
            // SetVersion(conn, 2);
        }
    }

    /// <summary>
    /// Switches the database to write-ahead logging.
    /// </summary>
    private static void EnableWal(DbConnection conn)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = "PRAGMA journal_mode = WAL;";
        cmd.ExecuteNonQuery();
    }

    /// <summary>
    /// Creates the version 1 schema: the <c>embedding_cache</c> table and an index on <c>idx</c>.
    /// </summary>
    /// <remarks>
    /// Uses <c>IF NOT EXISTS</c> so that a run interrupted after the schema was created, but before
    /// the version was recorded, can be repeated instead of failing on the existing table.
    /// </remarks>
    private static void CreateV1(DbConnection conn)
    {
        using var tx = conn.BeginTransaction();
        using var cmd = conn.CreateCommand();
        // Enlist explicitly instead of relying on the provider to attach the pending transaction.
        cmd.Transaction = tx;
        cmd.CommandText = """
            CREATE TABLE IF NOT EXISTS embedding_cache (
            cache_key TEXT NOT NULL,
            model_key TEXT NOT NULL,
            embedding BLOB NOT NULL,
            idx INTEGER NOT NULL,
            PRIMARY KEY (cache_key, model_key)
            );
            CREATE INDEX IF NOT EXISTS idx_index
            ON embedding_cache(idx);
            """;
        cmd.ExecuteNonQuery();
        tx.Commit();
    }

    /// <summary>
    /// Records the schema version in <c>PRAGMA user_version</c>.
    /// </summary>
    private static void SetVersion(DbConnection conn, int version)
    {
        using var cmd = conn.CreateCommand();
        // The value is an int supplied by this class, so interpolating it cannot inject SQL.
        cmd.CommandText = $"PRAGMA user_version = {version};";
        cmd.ExecuteNonQuery();
    }
}

View File

@@ -8,9 +8,9 @@ public class EmbeddingSearchOptions : ApiKeyOptions
{
public required ConnectionStringsOptions ConnectionStrings { get; set; }
public ElmahOptions? Elmah { get; set; }
public required long EmbeddingCacheMaxCount { get; set; }
public required Dictionary<string, AiProvider> AiProviders { get; set; }
public required SimpleAuthOptions SimpleAuth { get; set; }
public required CacheOptions Cache { get; set; }
public required bool UseHttpsRedirection { get; set; }
}
@@ -38,4 +38,12 @@ public class SimpleUser
/// <summary>
/// Connection strings used by the server.
/// </summary>
public class ConnectionStringsOptions
{
/// <summary>Connection string of the main SQL database.</summary>
public required string SQL { get; set; }
/// <summary>
/// Connection string of the SQLite embedding-cache database; optional.
/// </summary>
public string? Cache { get; set; }
}
/// <summary>
/// Settings for the embedding cache.
/// </summary>
public class CacheOptions
{
/// <summary>Capacity applied to the in-memory embedding cache.</summary>
public required long CacheTopN { get; set; }
/// <summary>Whether the embedding cache is loaded from and written back to the persistent store. Defaults to false.</summary>
public bool StoreEmbeddingCache { get; set; } = false;
/// <summary>
/// Capacity applied to the cache before it is written to the persistent store;
/// when null, the cache capacity is left unchanged.
/// </summary>
public int? StoreTopN { get; set; }
}

View File

@@ -0,0 +1,109 @@
namespace Server.Models;
using System.Data.Common;
/// <summary>
/// Base class for database helpers: owns a connection and offers locked, connection-checked
/// helpers for executing commands on it.
/// </summary>
public abstract partial class SqlHelper : IDisposable
{
    /// <summary>The connection all commands run on; also used as the lock object.</summary>
    public DbConnection Connection { get; set; }

    /// <summary>The reader returned by the most recent <see cref="ExecuteSQLCommand"/> call, if any.</summary>
    public DbDataReader? DbDataReader { get; set; }

    /// <summary>The connection string the connection was created from.</summary>
    public string ConnectionString { get; set; }

    /// <summary>
    /// Stores the connection and its connection string.
    /// </summary>
    public SqlHelper(DbConnection connection, string connectionString)
    {
        Connection = connection;
        ConnectionString = connectionString;
    }

    /// <summary>
    /// Creates another helper for the same database.
    /// </summary>
    public abstract SqlHelper DuplicateConnection();

    /// <summary>
    /// Closes the connection.
    /// </summary>
    public void Dispose()
    {
        Connection.Close();
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// Executes a query and returns its reader, which is also remembered in <see cref="DbDataReader"/>.
    /// </summary>
    /// <param name="query">The SQL text to execute.</param>
    /// <param name="parameters">The parameters added to the command.</param>
    /// <returns>The open reader; later commands wait until it has been closed.</returns>
    public DbDataReader ExecuteSQLCommand(string query, object[] parameters)
    {
        lock (Connection)
        {
            EnsureConnected();
            EnsureDbReaderIsClosed();
            // NOTE(review): the command is disposed when this method returns while its reader is
            // handed to the caller; some providers close the reader together with the command - confirm.
            using DbCommand command = Connection.CreateCommand();
            command.CommandText = query;
            command.Parameters.AddRange(parameters);
            DbDataReader = command.ExecuteReader();
            return DbDataReader;
        }
    }

    /// <summary>
    /// Executes a query and invokes <paramref name="map"/> once for every result row.
    /// </summary>
    /// <param name="query">The SQL text to execute.</param>
    /// <param name="parameters">The parameters added to the command.</param>
    /// <param name="map">Callback receiving the reader positioned on the current row; its return value is ignored.</param>
    public void ExecuteQuery<T>(string query, object[] parameters, Func<DbDataReader, T> map)
    {
        lock (Connection)
        {
            EnsureConnected();
            EnsureDbReaderIsClosed();
            using var command = Connection.CreateCommand();
            command.CommandText = query;
            command.Parameters.AddRange(parameters);
            using var reader = command.ExecuteReader();
            while (reader.Read())
            {
                map(reader);
            }
            return;
        }
    }

    /// <summary>
    /// Executes a statement that returns no rows.
    /// </summary>
    /// <param name="query">The SQL text to execute.</param>
    /// <param name="parameters">The parameters added to the command.</param>
    /// <returns>The number of affected rows reported by the provider.</returns>
    public int ExecuteSQLNonQuery(string query, object[] parameters)
    {
        lock (Connection)
        {
            EnsureConnected();
            EnsureDbReaderIsClosed();
            using DbCommand command = Connection.CreateCommand();
            command.CommandText = query;
            command.Parameters.AddRange(parameters);
            return command.ExecuteNonQuery();
        }
    }

    /// <summary>
    /// Executes an insert statement and returns the id generated for the inserted row.
    /// </summary>
    public abstract int ExecuteSQLCommandGetInsertedID(string query, object[] parameters);

    /// <summary>
    /// Makes sure the connection is open, reopening it when it is in any other state.
    /// </summary>
    /// <returns>Always true; failures are reported to Elmah and rethrown.</returns>
    public bool EnsureConnected()
    {
        if (Connection.State != System.Data.ConnectionState.Open)
        {
            try
            {
                // Close first so a connection in a non-open state is reset before reopening.
                Connection.Close();
                Connection.Open();
            }
            catch (Exception ex)
            {
                ElmahCore.ElmahExtensions.RaiseError(ex);
                throw;
            }
        }
        return true;
    }

    /// <summary>
    /// Waits until the reader from the last <see cref="ExecuteSQLCommand"/> call has been closed.
    /// </summary>
    /// <exception cref="TimeoutException">The reader was still open after roughly five seconds.</exception>
    public void EnsureDbReaderIsClosed()
    {
        const int sleepTime = 10;
        const int timeout = 5000;
        int counter = 0;
        while (!(DbDataReader?.IsClosed ?? true))
        {
            if (counter > timeout / sleepTime)
            {
                TimeoutException ex = new("Unable to ensure dbDataReader is closed");
                ElmahCore.ElmahExtensions.RaiseError(ex);
                throw ex;
            }
            Thread.Sleep(sleepTime);
            // Count the completed wait; without this the timeout above could never be reached.
            counter++;
        }
    }
}

View File

@@ -17,6 +17,7 @@ using Microsoft.AspNetCore.ResponseCompression;
using System.Net;
using System.Text;
using Server.Migrations;
using Microsoft.Data.Sqlite;
var builder = WebApplication.CreateBuilder(args);
@@ -39,6 +40,15 @@ builder.Services.Configure<ApiKeyOptions>(configurationSection);
var helper = new SQLHelper(new MySql.Data.MySqlClient.MySqlConnection(configuration.ConnectionStrings.SQL), configuration.ConnectionStrings.SQL);
DatabaseMigrations.Migrate(helper);
// Migrate SQLite cache (skipped when no cache connection string is configured)
if (configuration.ConnectionStrings.Cache is not null)
{
    // This connection is only needed for the migration; dispose it when the block ends.
    using var sqliteConnection = new SqliteConnection(configuration.ConnectionStrings.Cache);
    sqliteConnection.Open();
    SQLiteMigrations.Migrate(sqliteConnection);
}
// Add Localization
builder.Services.AddLocalization(options => options.ResourcesPath = "Resources");
builder.Services.Configure<RequestLocalizationOptions>(options =>

View File

@@ -9,10 +9,11 @@ using System.Text.Json;
using Microsoft.Extensions.Options;
using Server.Models;
using Shared;
using System.Diagnostics;
namespace Server;
public class SearchdomainManager
public class SearchdomainManager : IDisposable
{
private Dictionary<string, Searchdomain> searchdomains = [];
private readonly ILogger<SearchdomainManager> _logger;
@@ -24,6 +25,7 @@ public class SearchdomainManager
public SQLHelper helper;
public EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache;
public long EmbeddingCacheMaxCount;
private bool disposed = false;
public SearchdomainManager(ILogger<SearchdomainManager> logger, IOptions<EmbeddingSearchOptions> options, AIProvider aIProvider, DatabaseHelper databaseHelper)
{
@@ -31,8 +33,17 @@ public class SearchdomainManager
_options = options.Value;
this.aIProvider = aIProvider;
_databaseHelper = databaseHelper;
EmbeddingCacheMaxCount = _options.EmbeddingCacheMaxCount;
embeddingCache = new((int)EmbeddingCacheMaxCount);
EmbeddingCacheMaxCount = _options.Cache.CacheTopN;
if (options.Value.Cache.StoreEmbeddingCache)
{
var stopwatch = Stopwatch.StartNew();
embeddingCache = CacheHelper.GetEmbeddingStore(options.Value);
stopwatch.Stop();
_logger.LogInformation("GetEmbeddingStore completed in {ElapsedMilliseconds} ms", stopwatch.ElapsedMilliseconds);
} else
{
embeddingCache = new((int)EmbeddingCacheMaxCount);
}
connectionString = _options.ConnectionStrings.SQL;
connection = new MySqlConnection(connectionString);
connection.Open();
@@ -80,7 +91,7 @@ public class SearchdomainManager
{
results.Add(reader.GetString(0));
}
return results;
return results;
}
finally
{
@@ -127,4 +138,39 @@ public class SearchdomainManager
{
return searchdomains.ContainsKey(name);
}
// Cleanup procedure: persists the embedding cache when enabled. Failures are logged, never thrown.
private async Task Cleanup()
{
    try
    {
        bool persistCache = _options.Cache.StoreEmbeddingCache;
        if (persistCache)
        {
            Stopwatch timer = new();
            timer.Start();
            await CacheHelper.UpdateEmbeddingStore(embeddingCache, _options);
            timer.Stop();
            _logger.LogInformation("UpdateEmbeddingStore completed in {ElapsedMilliseconds} ms", timer.ElapsedMilliseconds);
        }
        _logger.LogInformation("SearchdomainManager cleanup completed");
    }
    catch (Exception cleanupError)
    {
        _logger.LogError(cleanupError, "Error during SearchdomainManager cleanup");
    }
}
/// <summary>
/// Runs the cleanup and blocks until it has finished.
/// </summary>
public void Dispose()
{
    Task cleanupTask = Dispose(disposing: true);
    cleanupTask.Wait();
    GC.SuppressFinalize(this);
}
/// <summary>
/// Runs the cleanup once; does nothing when already disposed or when <paramref name="disposing"/> is false.
/// </summary>
/// <param name="disposing">True when called from <see cref="Dispose()"/>.</param>
protected virtual async Task Dispose(bool disposing)
{
    if (disposed || !disposing)
    {
        return;
    }

    await Cleanup();
    // Only marked as disposed after the cleanup has completed.
    disposed = true;
}
}

View File

@@ -15,12 +15,12 @@
"Embeddingsearch": {
"ConnectionStrings": {
"SQL": "server=localhost;database=embeddingsearch;uid=embeddingsearch;pwd=somepassword!;"
"SQL": "server=localhost;database=embeddingsearch;uid=embeddingsearch;pwd=somepassword!;",
"Cache": "Data Source=embeddings.db;Mode=ReadWriteCreate;Cache=Shared"
},
"Elmah": {
"LogPath": "~/logs"
},
"EmbeddingCacheMaxCount": 10000000,
"AiProviders": {
"ollama": {
"handler": "ollama",
@@ -46,6 +46,11 @@
]
},
"ApiKeys": ["Some UUID here", "Another UUID here"],
"UseHttpsRedirection": true
"UseHttpsRedirection": true,
"Cache": {
"CacheTopN": 100000,
"StoreEmbeddingCache": true,
"StoreTopN": 20000
}
}
}