continued cache implementation, massively improved entity creation

This commit is contained in:
EzFeDezy
2025-04-20 03:29:39 +02:00
parent 473bfd1a8b
commit a724ef80a2
6 changed files with 178 additions and 16 deletions

1
.gitignore vendored
View File

@@ -3,3 +3,4 @@ src/cli/obj
src/embeddingsearch/bin src/embeddingsearch/bin
src/embeddingsearch/obj src/embeddingsearch/obj
src/server src/server
src/debug

View File

@@ -115,8 +115,9 @@ public class OptionsEntityIndex : OptionsEntity // Example: -i -e {"name": "myfi
[Option('s', Required = true, HelpText = "Searchdomain the entity belongs to")] [Option('s', Required = true, HelpText = "Searchdomain the entity belongs to")]
public required string Searchdomain { get; set; } public required string Searchdomain { get; set; }
[Option('e', Required = true, HelpText = "Entity (as JSON) to be inserted")] [Option('e', Required = false, HelpText = "Entity (as JSON) to be inserted")]
public required string EntityJSON { get; set; } public string? EntityJSON { get; set; }
/* Example for an entity: /* Example for an entity:
{ {
"name": "myfile.txt", "name": "myfile.txt",

View File

@@ -159,6 +159,7 @@ parser.ParseArguments<OptionsCommand>(args).WithParsed<OptionsCommand>(opts =>
counter += 1; counter += 1;
} }
Console.WriteLine($"Number of entities deleted as part of deleting the searchdomain: {counter}"); Console.WriteLine($"Number of entities deleted as part of deleting the searchdomain: {counter}");
searchdomain.ExecuteSQLNonQuery("DELETE FROM entity WHERE id_searchdomain = @id", new() {{"id", searchdomain.id}}); // Cleanup // TODO add rows affected
searchdomain.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", opts.Searchdomain}}); searchdomain.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", opts.Searchdomain}});
Console.WriteLine("Searchdomain has been successfully removed."); Console.WriteLine("Searchdomain has been successfully removed.");
}) })
@@ -204,14 +205,44 @@ parser.ParseArguments<OptionsCommand>(args).WithParsed<OptionsCommand>(opts =>
{ {
parser.ParseArguments<OptionsEntityIndex>(args).WithParsed<OptionsEntityIndex>(opts => parser.ParseArguments<OptionsEntityIndex>(args).WithParsed<OptionsEntityIndex>(opts =>
{ {
if (opts.EntityJSON is null)
{
opts.EntityJSON = Console.In.ReadToEnd();
}
Searchdomain searchdomain = GetSearchdomain(opts.OllamaURL, opts.Searchdomain, opts.IP, opts.Username, opts.Password); Searchdomain searchdomain = GetSearchdomain(opts.OllamaURL, opts.Searchdomain, opts.IP, opts.Username, opts.Password);
Entity? entity = searchdomain.EntityFromJSON(opts.EntityJSON); try
if (entity is not null)
{ {
Console.WriteLine("Successfully created/updated the entity"); if (opts.EntityJSON.StartsWith('[')) // multiple entities
} else {
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>?>(opts.EntityJSON);
if (jsonEntities is not null)
{
List<Entity>? entities = searchdomain.EntitiesFromJSON(opts.EntityJSON);
if (entities is not null)
{
Console.WriteLine("Successfully created/updated the entity");
} else
{
Console.Error.WriteLine("Unable to create the entity using the provided JSON.");
retval = 1;
}
}
} else
{
Entity? entity = searchdomain.EntityFromJSON(opts.EntityJSON);
if (entity is not null)
{
Console.WriteLine("Successfully created/updated the entity");
} else
{
Console.Error.WriteLine("Unable to create the entity using the provided JSON.");
retval = 1;
}
}
} catch (Exception e)
{ {
Console.Error.WriteLine("Unable to create the entity using the provided JSON."); Console.Error.WriteLine($"Unable to create the entity using the provided JSON.\nException: {e}");
retval = 1; retval = 1;
} }
}) })
@@ -287,11 +318,16 @@ static List<(float, string)> Search(OptionsEntityEvaluate optionsEntityIndex)
static Searchdomain GetSearchdomain(string ollamaURL, string searchdomain, string ip, string username, string password, bool runEmpty = false) static Searchdomain GetSearchdomain(string ollamaURL, string searchdomain, string ip, string username, string password, bool runEmpty = false)
{ {
string connectionString = $"server={ip};database=embeddingsearch;uid={username};pwd={password};"; string connectionString = $"server={ip};database=embeddingsearch;uid={username};pwd={password};";
var ollamaConfig = new OllamaApiClient.Configuration // var ollamaConfig = new OllamaApiClient.Configuration
// {
// Uri = new Uri(ollamaURL)
// };
var httpClient = new HttpClient
{ {
Uri = new Uri(ollamaURL) BaseAddress = new Uri(ollamaURL),
Timeout = TimeSpan.FromSeconds(36000) //.MaxValue //FromSeconds(timeout)
}; };
var ollama = new OllamaApiClient(ollamaConfig); var ollama = new OllamaApiClient(httpClient);
return new Searchdomain(searchdomain, connectionString, ollama, "sqlserver", runEmpty); return new Searchdomain(searchdomain, connectionString, ollama, "sqlserver", runEmpty);
} }

View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text.Json;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.Extensions.AI; using Microsoft.Extensions.AI;
using OllamaSharp; using OllamaSharp;
@@ -39,10 +40,64 @@ public class Datapoint
} }
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama) public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama)
{
return GenerateEmbeddings(content, models, ollama, []);
}
public static Dictionary<string, float[]> GenerateEmbeddings(List<string> contents, string model, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
{
Dictionary<string, float[]> retVal = [];
List<string> remainingContents = new List<string>(contents);
for (int i = contents.Count - 1; i >= 0; i--) // Compare against cache and remove accordingly
{
string content = contents[i];
if (embeddingCache.ContainsKey(model) && embeddingCache[model].ContainsKey(content))
{
retVal[content] = embeddingCache[model][content];
remainingContents.RemoveAt(i);
}
}
if (remainingContents.Count == 0)
{
return retVal;
}
EmbedRequest request = new()
{
Model = model,
Input = remainingContents
};
EmbedResponse response = ollama.EmbedAsync(request).Result;
for (int i = 0; i < response.Embeddings.Count; i++)
{
string content = remainingContents.ElementAt(i);
float[] embeddings = response.Embeddings.ElementAt(i);
retVal[content] = embeddings;
if (!embeddingCache.ContainsKey(model))
{
embeddingCache[model] = [];
}
if (!embeddingCache[model].ContainsKey(content))
{
embeddingCache[model][content] = embeddings;
}
}
return retVal;
}
public static Dictionary<string, float[]> GenerateEmbeddings(string content, List<string> models, OllamaApiClient ollama, Dictionary<string, Dictionary<string, float[]>> embeddingCache)
{ {
Dictionary<string, float[]> retVal = []; Dictionary<string, float[]> retVal = [];
foreach (string model in models) foreach (string model in models)
{ {
if (embeddingCache.ContainsKey(model) && embeddingCache[model].ContainsKey(content))
{
retVal[model] = embeddingCache[model][content];
continue;
}
EmbedRequest request = new() EmbedRequest request = new()
{ {
Model = model, Model = model,
@@ -55,6 +110,14 @@ public class Datapoint
float[] var = new float[response.Vector.Length]; float[] var = new float[response.Vector.Length];
response.Vector.CopyTo(var); response.Vector.CopyTo(var);
retVal[model] = var; retVal[model] = var;
if (!embeddingCache.ContainsKey(model))
{
embeddingCache[model] = [];
}
if (!embeddingCache[model].ContainsKey(content))
{
embeddingCache[model][content] = var;
}
} }
} }
return retVal; return retVal;

View File

@@ -1,4 +1,6 @@
class JSONEntity namespace embeddingsearch;
public class JSONEntity
{ {
public required string name { get; set; } public required string name { get; set; }
public required string probmethod { get; set; } public required string probmethod { get; set; }
@@ -7,7 +9,7 @@ class JSONEntity
public required JSONDatapoint[] datapoints { get; set; } public required JSONDatapoint[] datapoints { get; set; }
} }
class JSONDatapoint public class JSONDatapoint
{ {
public required string name { get; set; } public required string name { get; set; }
public required string text { get; set; } public required string text { get; set; }

View File

@@ -282,7 +282,7 @@ public class Searchdomain
foreach (JSONDatapoint jsonDatapoint in jsonEntity.datapoints) foreach (JSONDatapoint jsonDatapoint in jsonEntity.datapoints)
{ {
Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.text, [.. jsonDatapoint.model], ollama); Dictionary<string, float[]> embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.text, [.. jsonDatapoint.model], ollama, embeddingCache);
var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.probmethod_embedding}"); var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.probmethod_embedding}");
Datapoint datapoint = new(jsonDatapoint.name, probMethod_embedding, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); Datapoint datapoint = new(jsonDatapoint.name, probMethod_embedding, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.name, jsonDatapoint.probmethod_embedding, id_entity); int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.name, jsonDatapoint.probmethod_embedding, id_entity);
@@ -302,6 +302,65 @@ public class Searchdomain
return entity; return entity;
} }
public List<Entity>? EntitiesFromJSON(string json)
{
List<JSONEntity>? jsonEntities = JsonSerializer.Deserialize<List<JSONEntity>>(json);
if (jsonEntities is null)
{
return null;
}
Dictionary<string, List<string>> toBeCached = [];
foreach (JSONEntity jSONEntity in jsonEntities)
{
foreach (JSONDatapoint datapoint in jSONEntity.datapoints)
{
foreach (string model in datapoint.model)
{
if (!toBeCached.ContainsKey(model))
{
toBeCached[model] = [];
}
toBeCached[model].Add(datapoint.text);
}
}
}
//Console.WriteLine(JsonSerializer.Serialize(toBeCached));
//return new List<Entity>();
Dictionary<string, Dictionary<string, float[]>> cache = []; // local cache
foreach (KeyValuePair<string, List<string>> cacheThis in toBeCached)
{
string model = cacheThis.Key;
List<string> contents = cacheThis.Value;
Console.WriteLine("DEBUG@searchdomain-1");
Console.WriteLine(model);
Console.WriteLine(contents);
Console.WriteLine(contents.Count);
if (contents.Count == 0)
{
Console.WriteLine("DEBUG@searchdomain-2-no");
cache[model] = [];
continue;
}
Console.WriteLine("DEBUG@searchdomain-2-yes");
//Console.WriteLine("DEBUG@searchdomain[[");
//Console.WriteLine(model);
//Console.WriteLine(JsonSerializer.Serialize(contents));
//Console.WriteLine("]]");
cache[model] = Datapoint.GenerateEmbeddings(contents, model, ollama, embeddingCache);
//Console.WriteLine(JsonSerializer.Serialize(cache[model]));
}
var tempEmbeddingCache = embeddingCache;
embeddingCache = cache;
List<Entity> retVal = [];
foreach (JSONEntity jSONEntity in jsonEntities)
{
retVal.Append(EntityFromJSON(JsonSerializer.Serialize(jSONEntity)));
}
embeddingCache = tempEmbeddingCache;
return retVal;
}
public void DatabaseRemoveEntity(string name) public void DatabaseRemoveEntity(string name)
{ {
Dictionary<string, dynamic> parameters = new() Dictionary<string, dynamic> parameters = new()