diff --git a/.gitignore b/.gitignore index 2cd7170..8fcc3ce 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ src/cli/bin src/cli/obj src/embeddingsearch/bin src/embeddingsearch/obj -src/server \ No newline at end of file +src/server +src/debug \ No newline at end of file diff --git a/src/cli/Options.cs b/src/cli/Options.cs index 417eaa3..ea428e3 100644 --- a/src/cli/Options.cs +++ b/src/cli/Options.cs @@ -115,8 +115,9 @@ public class OptionsEntityIndex : OptionsEntity // Example: -i -e {"name": "myfi [Option('s', Required = true, HelpText = "Searchdomain the entity belongs to")] public required string Searchdomain { get; set; } - [Option('e', Required = true, HelpText = "Entity (as JSON) to be inserted")] - public required string EntityJSON { get; set; } + [Option('e', Required = false, HelpText = "Entity (as JSON) to be inserted")] + public string? EntityJSON { get; set; } + /* Example for an entity: { "name": "myfile.txt", diff --git a/src/cli/Program.cs b/src/cli/Program.cs index 178af0e..ef2444a 100644 --- a/src/cli/Program.cs +++ b/src/cli/Program.cs @@ -159,6 +159,7 @@ parser.ParseArguments(args).WithParsed(opts => counter += 1; } Console.WriteLine($"Number of entities deleted as part of deleting the searchdomain: {counter}"); + searchdomain.ExecuteSQLNonQuery("DELETE FROM entity WHERE id_searchdomain = @id", new() {{"id", searchdomain.id}}); // Cleanup // TODO add rows affected searchdomain.ExecuteSQLNonQuery("DELETE FROM searchdomain WHERE name = @name", new() {{"name", opts.Searchdomain}}); Console.WriteLine("Searchdomain has been successfully removed."); }) @@ -204,14 +205,44 @@ parser.ParseArguments(args).WithParsed(opts => { parser.ParseArguments(args).WithParsed(opts => { + if (opts.EntityJSON is null) + { + opts.EntityJSON = Console.In.ReadToEnd(); + } Searchdomain searchdomain = GetSearchdomain(opts.OllamaURL, opts.Searchdomain, opts.IP, opts.Username, opts.Password); - Entity? entity = searchdomain.EntityFromJSON(opts.EntityJSON); - if (entity is not null) + try { - Console.WriteLine("Successfully created/updated the entity"); - } else + if (opts.EntityJSON.StartsWith('[')) // multiple entities + { + List? jsonEntities = JsonSerializer.Deserialize?>(opts.EntityJSON); + if (jsonEntities is not null) + { + + List? entities = searchdomain.EntitiesFromJSON(opts.EntityJSON); + if (entities is not null) + { + Console.WriteLine("Successfully created/updated the entity"); + } else + { + Console.Error.WriteLine("Unable to create the entity using the provided JSON."); + retval = 1; + } + } + } else + { + Entity? entity = searchdomain.EntityFromJSON(opts.EntityJSON); + if (entity is not null) + { + Console.WriteLine("Successfully created/updated the entity"); + } else + { + Console.Error.WriteLine("Unable to create the entity using the provided JSON."); + retval = 1; + } + } + } catch (Exception e) { - Console.Error.WriteLine("Unable to create the entity using the provided JSON."); + Console.Error.WriteLine($"Unable to create the entity using the provided JSON.\nException: {e}"); retval = 1; } }) @@ -287,11 +318,16 @@ static List<(float, string)> Search(OptionsEntityEvaluate optionsEntityIndex) static Searchdomain GetSearchdomain(string ollamaURL, string searchdomain, string ip, string username, string password, bool runEmpty = false) { string connectionString = $"server={ip};database=embeddingsearch;uid={username};pwd={password};"; - var ollamaConfig = new OllamaApiClient.Configuration + // var ollamaConfig = new OllamaApiClient.Configuration + // { + // Uri = new Uri(ollamaURL) + // }; + var httpClient = new HttpClient { - Uri = new Uri(ollamaURL) + BaseAddress = new Uri(ollamaURL), + Timeout = TimeSpan.FromSeconds(36000) //.MaxValue //FromSeconds(timeout) }; - var ollama = new OllamaApiClient(ollamaConfig); + var ollama = new OllamaApiClient(httpClient); return new Searchdomain(searchdomain, connectionString, ollama, "sqlserver", runEmpty); } diff --git a/src/embeddingsearch/Datapoint.cs b/src/embeddingsearch/Datapoint.cs index b7cda16..5f8eff0 100644 --- a/src/embeddingsearch/Datapoint.cs +++ b/src/embeddingsearch/Datapoint.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.Extensions.AI; using OllamaSharp; @@ -39,22 +40,84 @@ public class Datapoint } public static Dictionary GenerateEmbeddings(string content, List models, OllamaApiClient ollama) + { + return GenerateEmbeddings(content, models, ollama, []); + } + + public static Dictionary GenerateEmbeddings(List contents, string model, OllamaApiClient ollama, Dictionary> embeddingCache) + { + Dictionary retVal = []; + + List remainingContents = new List(contents); + for (int i = contents.Count - 1; i >= 0; i--) // Compare against cache and remove accordingly + { + string content = contents[i]; + if (embeddingCache.ContainsKey(model) && embeddingCache[model].ContainsKey(content)) + { + retVal[content] = embeddingCache[model][content]; + remainingContents.RemoveAt(i); + } + } + if (remainingContents.Count == 0) + { + return retVal; + } + + EmbedRequest request = new() + { + Model = model, + Input = remainingContents + }; + + EmbedResponse response = ollama.EmbedAsync(request).Result; + for (int i = 0; i < response.Embeddings.Count; i++) + { + string content = remainingContents.ElementAt(i); + float[] embeddings = response.Embeddings.ElementAt(i); + retVal[content] = embeddings; + if (!embeddingCache.ContainsKey(model)) + { + embeddingCache[model] = []; + } + if (!embeddingCache[model].ContainsKey(content)) + { + embeddingCache[model][content] = embeddings; + } + } + + return retVal; + } + + public static Dictionary GenerateEmbeddings(string content, List models, OllamaApiClient ollama, Dictionary> embeddingCache) { Dictionary retVal = []; foreach (string model in models) { + if (embeddingCache.ContainsKey(model) && embeddingCache[model].ContainsKey(content)) + { + retVal[model] = embeddingCache[model][content]; + continue; + } EmbedRequest request = new() { Model = model, Input = [content] }; - + var response = ollama.GenerateEmbeddingAsync(content, new EmbeddingGenerationOptions(){ModelId=model}).Result; if (response is not null) { float[] var = new float[response.Vector.Length]; response.Vector.CopyTo(var); retVal[model] = var; + if (!embeddingCache.ContainsKey(model)) + { + embeddingCache[model] = []; + } + if (!embeddingCache[model].ContainsKey(content)) + { + embeddingCache[model][content] = var; + } } } return retVal; diff --git a/src/embeddingsearch/JSONModels.cs b/src/embeddingsearch/JSONModels.cs index b5d8785..5e8fd2f 100644 --- a/src/embeddingsearch/JSONModels.cs +++ b/src/embeddingsearch/JSONModels.cs @@ -1,4 +1,6 @@ -class JSONEntity +namespace embeddingsearch; + +public class JSONEntity { public required string name { get; set; } public required string probmethod { get; set; } @@ -7,10 +9,10 @@ class JSONEntity public required JSONDatapoint[] datapoints { get; set; } } -class JSONDatapoint +public class JSONDatapoint { public required string name { get; set; } public required string text { get; set; } public required string probmethod_embedding { get; set; } public required string[] model { get; set; } -} +} \ No newline at end of file diff --git a/src/embeddingsearch/Searchdomain.cs b/src/embeddingsearch/Searchdomain.cs index 6a5520d..5093a91 100644 --- a/src/embeddingsearch/Searchdomain.cs +++ b/src/embeddingsearch/Searchdomain.cs @@ -282,7 +282,7 @@ public class Searchdomain foreach (JSONDatapoint jsonDatapoint in jsonEntity.datapoints) { - Dictionary embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.text, [.. jsonDatapoint.model], ollama); + Dictionary embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.text, [.. jsonDatapoint.model], ollama, embeddingCache); var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.probmethod_embedding}"); Datapoint datapoint = new(jsonDatapoint.name, probMethod_embedding, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.name, jsonDatapoint.probmethod_embedding, id_entity); @@ -302,6 +302,65 @@ public class Searchdomain return entity; } + public List? EntitiesFromJSON(string json) + { + List? jsonEntities = JsonSerializer.Deserialize>(json); + if (jsonEntities is null) + { + return null; + } + + Dictionary> toBeCached = []; + foreach (JSONEntity jSONEntity in jsonEntities) + { + foreach (JSONDatapoint datapoint in jSONEntity.datapoints) + { + foreach (string model in datapoint.model) + { + if (!toBeCached.ContainsKey(model)) + { + toBeCached[model] = []; + } + toBeCached[model].Add(datapoint.text); + } + } + } + //Console.WriteLine(JsonSerializer.Serialize(toBeCached)); + //return new List(); + Dictionary> cache = []; // local cache + foreach (KeyValuePair> cacheThis in toBeCached) + { + string model = cacheThis.Key; + List contents = cacheThis.Value; + Console.WriteLine("DEBUG@searchdomain-1"); + Console.WriteLine(model); + Console.WriteLine(contents); + Console.WriteLine(contents.Count); + if (contents.Count == 0) + { + Console.WriteLine("DEBUG@searchdomain-2-no"); + cache[model] = []; + continue; + } + Console.WriteLine("DEBUG@searchdomain-2-yes"); + //Console.WriteLine("DEBUG@searchdomain[["); + //Console.WriteLine(model); + //Console.WriteLine(JsonSerializer.Serialize(contents)); + //Console.WriteLine("]]"); + cache[model] = Datapoint.GenerateEmbeddings(contents, model, ollama, embeddingCache); + //Console.WriteLine(JsonSerializer.Serialize(cache[model])); + } + var tempEmbeddingCache = embeddingCache; + embeddingCache = cache; + List retVal = []; + foreach (JSONEntity jSONEntity in jsonEntities) + { + retVal.Append(EntityFromJSON(JsonSerializer.Serialize(jSONEntity))); + } + embeddingCache = tempEmbeddingCache; + return retVal; + } + public void DatabaseRemoveEntity(string name) { Dictionary parameters = new()