From e74ed1f9ea7be9a9741a9512df3131e5e663f4b6 Mon Sep 17 00:00:00 2001 From: LD-Reborn Date: Sat, 23 Aug 2025 21:34:48 +0200 Subject: [PATCH] Added SimilarityMethod to datapoint; Added euclidian distance, manhattan distance, pearson correlation; improved CosineSimilarity result using a remap --- src/Indexer/Scripts/tools.py | 1 + src/Server/Controllers/EntityController.cs | 4 +- src/Server/Datapoint.cs | 4 +- src/Server/Helper/DatabaseHelper.cs | 5 +- src/Server/Helper/SearchdomainHelper.cs | 10 +- src/Server/Migrations/DatabaseMigrations.cs | 6 ++ src/Server/Probmethods.cs | 5 - src/Server/Searchdomain.cs | 11 +- src/Server/SimilarityMethods.cs | 112 ++++++++++++++++++++ src/Shared/Models/EntityResults.cs | 2 + src/Shared/Models/JSONModels.cs | 1 + 11 files changed, 143 insertions(+), 18 deletions(-) create mode 100644 src/Server/SimilarityMethods.cs diff --git a/src/Indexer/Scripts/tools.py b/src/Indexer/Scripts/tools.py index fc5722e..76bd838 100644 --- a/src/Indexer/Scripts/tools.py +++ b/src/Indexer/Scripts/tools.py @@ -7,6 +7,7 @@ class JSONDatapoint: Name:str Text:str Probmethod_embedding:str + SimilarityMethod:str Model:list[str] @dataclass diff --git a/src/Server/Controllers/EntityController.cs b/src/Server/Controllers/EntityController.cs index b999be3..f8a17ce 100644 --- a/src/Server/Controllers/EntityController.cs +++ b/src/Server/Controllers/EntityController.cs @@ -104,11 +104,11 @@ public class EntityController : ControllerBase { embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2}); } - datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults}); + datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = embeddingResults}); } else { - datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null}); + datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = null}); } } EntityListResult entityListResult = new() diff --git a/src/Server/Datapoint.cs b/src/Server/Datapoint.cs index 707238b..2389d42 100644 --- a/src/Server/Datapoint.cs +++ b/src/Server/Datapoint.cs @@ -7,13 +7,15 @@ public class Datapoint { public string name; public ProbMethod probMethod; + public SimilarityMethod similarityMethod; public List<(string, float[])> embeddings; public string hash; - public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings) + public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings) { this.name = name; this.probMethod = probMethod; + this.similarityMethod = similarityMethod; this.hash = hash; this.embeddings = embeddings; } diff --git a/src/Server/Helper/DatabaseHelper.cs b/src/Server/Helper/DatabaseHelper.cs index ea6c6ef..725345c 100644 --- a/src/Server/Helper/DatabaseHelper.cs +++ b/src/Server/Helper/DatabaseHelper.cs @@ -56,16 +56,17 @@ public static class DatabaseHelper return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", parameters); } - public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string hash, int id_entity) + public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string similarityMethod, string hash, int id_entity) { Dictionary parameters = new() { { "name", name }, { "probmethod_embedding", probmethod_embedding }, + { "similaritymethod", similarityMethod }, { "hash", hash }, { "id_entity", id_entity } }; - return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, hash, id_entity) VALUES (@name, @probmethod_embedding, @hash, @id_entity)", parameters); + return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters); } public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding) diff --git a/src/Server/Helper/SearchdomainHelper.cs b/src/Server/Helper/SearchdomainHelper.cs index 79edea3..fdccd11 100644 --- a/src/Server/Helper/SearchdomainHelper.cs +++ b/src/Server/Helper/SearchdomainHelper.cs @@ -48,6 +48,7 @@ public static class SearchdomainHelper return null; } + // toBeCached: model -> [datapoint.text * n] Dictionary> toBeCached = []; foreach (JSONEntity jSONEntity in jsonEntities) { @@ -79,7 +80,7 @@ public static class SearchdomainHelper public static Entity? EntityFromJSON(List entityCache, Dictionary> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json) { - Dictionary> embeddingsLUT = []; + Dictionary> embeddingsLUT = []; // embeddingsLUT: hash -> model -> [embeddingValues * n] int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain); if (preexistingEntityID is not null) { @@ -139,9 +140,10 @@ public static class SearchdomainHelper { embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache); } - var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); - Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); - int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries + var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probMethod name {jsonDatapoint.Probmethod_embedding}"); + var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new Exception($"Unknown similarityMethod name {jsonDatapoint.SimilarityMethod}"); + Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); + int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries List<(string model, byte[] embedding)> data = []; foreach ((string, float[]) embedding in datapoint.embeddings) { diff --git a/src/Server/Migrations/DatabaseMigrations.cs b/src/Server/Migrations/DatabaseMigrations.cs index 2d50668..800396b 100644 --- a/src/Server/Migrations/DatabaseMigrations.cs +++ b/src/Server/Migrations/DatabaseMigrations.cs @@ -85,4 +85,10 @@ public static class DatabaseMigrations helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []); return 3; } + + public static int UpdateFrom3(SQLHelper helper) + { + helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding", []); + return 4; + } } \ No newline at end of file diff --git a/src/Server/Probmethods.cs b/src/Server/Probmethods.cs index 054d543..d094ff2 100644 --- a/src/Server/Probmethods.cs +++ b/src/Server/Probmethods.cs @@ -169,9 +169,4 @@ public static class Probmethods } return f / fm; } - - public static float Similarity(float[] vector1, float[] vector2) - { - return TensorPrimitives.CosineSimilarity(vector1, vector2); - } } diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index 75e58e5..b7320ca 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -73,7 +73,7 @@ public class Searchdomain } embeddingReader.Close(); - DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, hash FROM datapoint", parametersIDSearchdomain); + DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, similaritymethod, hash FROM datapoint", parametersIDSearchdomain); Dictionary> datapoint_unassigned = []; while (datapointReader.Read()) { @@ -81,8 +81,10 @@ public class Searchdomain int id_entity = datapointReader.GetInt32(1); string name = datapointReader.GetString(2); string probmethodString = datapointReader.GetString(3); - string hash = datapointReader.GetString(4); + string similarityMethodString = datapointReader.GetString(4); + string hash = datapointReader.GetString(5); ProbMethod probmethod = new(probmethodString, _logger); + SimilarityMethod similarityMethod = new(similarityMethodString, _logger); if (embedding_unassigned.TryGetValue(id, out Dictionary? embeddings) && probmethod is not null) { embedding_unassigned.Remove(id); @@ -90,7 +92,7 @@ public class Searchdomain { datapoint_unassigned[id_entity] = []; } - datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))])); + datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))])); } } datapointReader.Close(); @@ -157,11 +159,12 @@ public class Searchdomain List<(string, float)> datapointProbs = []; foreach (Datapoint datapoint in entity.datapoints) { + SimilarityMethod similarityMethod = datapoint.similarityMethod; List<(string, float)> list = []; foreach ((string, float[]) embedding in datapoint.embeddings) { string key = embedding.Item1; - float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2); + float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2); list.Add((key, value)); } datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list))); diff --git a/src/Server/SimilarityMethods.cs b/src/Server/SimilarityMethods.cs new file mode 100644 index 0000000..9521825 --- /dev/null +++ b/src/Server/SimilarityMethods.cs @@ -0,0 +1,112 @@ +using System.Numerics.Tensors; +using System.Text.Json; + +namespace Server; + +public class SimilarityMethod +{ + public SimilarityMethods.similarityMethodDelegate method; + public string name; + + public SimilarityMethod(string name, ILogger logger) + { + this.name = name; + SimilarityMethods.similarityMethodDelegate? probMethod = SimilarityMethods.GetMethod(name); + if (probMethod is null) + { + logger.LogError("Unable to retrieve similarityMethod {name}", [name]); + throw new Exception("Unable to retrieve similarityMethod"); + } + method = probMethod; + } +} + +public static class SimilarityMethods +{ + public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2); + public delegate float similarityMethodDelegate(float[] vector1, float[] vector2); + public static readonly Dictionary probMethods; + + static SimilarityMethods() + { + probMethods = new Dictionary + { + ["Cosine"] = CosineSimilarity, + ["Euclidian"] = EuclidianDistance, + ["Manhattan"] = ManhattanDistance, + ["Pearson"] = PearsonCorrelation + }; + } + + public static similarityMethodDelegate? GetMethod(string name) + { + string methodName = name; + + if (!probMethods.TryGetValue(methodName, out similarityMethodProtoDelegate? method)) + { + return null; + } + return (vector1, vector2) => method(vector1, vector2); + } + + + public static float CosineSimilarity(float[] vector1, float[] vector2) + { + return (TensorPrimitives.CosineSimilarity(vector1, vector2) + 1) / 2; + } + + public static float EuclidianDistance(float[] vector1, float[] vector2) + { + if (vector1.Length != vector2.Length) + { + throw new ArgumentException("Unable to calculate Euclidian distance - Vectors must have the same length"); + } + float sum = 0; + for (int i = 0; i < vector1.Length; i++) + { + float diff = vector1[i] - vector2[i]; + sum += diff * diff; + } + return RationalRemap((float)Math.Sqrt(sum)); + } + + public static float ManhattanDistance(float[] vector1, float[] vector2) + { + if (vector1.Length != vector2.Length) + throw new ArgumentException("Unable to calculate Manhattan distance - Vectors must have the same length"); + + float sum = 0; + for (int i = 0; i < vector1.Length; i++) + { + sum += Math.Abs(vector1[i] - vector2[i]); + } + return RationalRemap(sum); + } + + public static float PearsonCorrelation(float[] vector1, float[] vector2) + { + if (vector1.Length != vector2.Length) + throw new ArgumentException("Unable to calculate Pearson correlation - Vectors must have the same length"); + + int n = vector1.Length; + double sum1 = vector1.Sum(); + double sum2 = vector2.Sum(); + double sum1Sq = vector1.Select(x => x * x).Sum(); + double sum2Sq = vector2.Select(x => x * x).Sum(); + double pSum = vector1.Zip(vector2, (x, y) => x * y).Sum(); + + double num = pSum - (sum1 * sum2 / n); + double den = Math.Sqrt((sum1Sq - (sum1 * sum1) / n) * (sum2Sq - (sum2 * sum2) / n)); + + return den == 0 ? 0 : (float)(num / den); + } + + public static float RationalRemap(float x) + { + if (x == float.PositiveInfinity) + { + return 0; + } + return 1 / (1 + x); + } +} diff --git a/src/Shared/Models/EntityResults.cs b/src/Shared/Models/EntityResults.cs index 1ad54bf..65692bc 100644 --- a/src/Shared/Models/EntityResults.cs +++ b/src/Shared/Models/EntityResults.cs @@ -55,6 +55,8 @@ public class DatapointResult public required string Name { get; set; } [JsonPropertyName("ProbMethod")] public required string ProbMethod { get; set; } + [JsonPropertyName("SimilarityMethod")] + public required string SimilarityMethod { get; set; } [JsonPropertyName("Embeddings")] public required List? Embeddings { get; set; } } diff --git a/src/Shared/Models/JSONModels.cs b/src/Shared/Models/JSONModels.cs index d74e4ac..86ca7b3 100644 --- a/src/Shared/Models/JSONModels.cs +++ b/src/Shared/Models/JSONModels.cs @@ -14,5 +14,6 @@ public class JSONDatapoint public required string Name { get; set; } public required string Text { get; set; } public required string Probmethod_embedding { get; set; } + public required string SimilarityMethod { get; set; } public required string[] Model { get; set; } } \ No newline at end of file