Added SimilarityMethod to datapoint; added Euclidean distance, Manhattan distance, and Pearson correlation; improved the CosineSimilarity result by remapping it to the range [0, 1]

This commit is contained in:
2025-08-23 21:34:48 +02:00
parent 631aafe68f
commit e74ed1f9ea
11 changed files with 143 additions and 18 deletions

View File

@@ -7,6 +7,7 @@ class JSONDatapoint:
Name:str
Text:str
Probmethod_embedding:str
SimilarityMethod:str
Model:list[str]
@dataclass

View File

@@ -104,11 +104,11 @@ public class EntityController : ControllerBase
{
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
}
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults});
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = embeddingResults});
}
else
{
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null});
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = null});
}
}
EntityListResult entityListResult = new()

View File

@@ -7,13 +7,15 @@ public class Datapoint
{
public string name;
public ProbMethod probMethod;
public SimilarityMethod similarityMethod;
public List<(string, float[])> embeddings;
public string hash;
public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings)
public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings)
{
this.name = name;
this.probMethod = probMethod;
this.similarityMethod = similarityMethod;
this.hash = hash;
this.embeddings = embeddings;
}

View File

@@ -56,16 +56,17 @@ public static class DatabaseHelper
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", parameters);
}
public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string hash, int id_entity)
public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string similarityMethod, string hash, int id_entity)
{
Dictionary<string, dynamic> parameters = new()
{
{ "name", name },
{ "probmethod_embedding", probmethod_embedding },
{ "similaritymethod", similarityMethod },
{ "hash", hash },
{ "id_entity", id_entity }
};
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, hash, id_entity) VALUES (@name, @probmethod_embedding, @hash, @id_entity)", parameters);
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters);
}
public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding)

View File

@@ -48,6 +48,7 @@ public static class SearchdomainHelper
return null;
}
// toBeCached: model -> [datapoint.text * n]
Dictionary<string, List<string>> toBeCached = [];
foreach (JSONEntity jSONEntity in jsonEntities)
{
@@ -79,7 +80,7 @@ public static class SearchdomainHelper
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
{
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = []; // embeddingsLUT: hash -> model -> [embeddingValues * n]
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
if (preexistingEntityID is not null)
{
@@ -139,9 +140,10 @@ public static class SearchdomainHelper
{
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache);
}
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probMethod name {jsonDatapoint.Probmethod_embedding}");
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new Exception($"Unknown similarityMethod name {jsonDatapoint.SimilarityMethod}");
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
List<(string model, byte[] embedding)> data = [];
foreach ((string, float[]) embedding in datapoint.embeddings)
{

View File

@@ -85,4 +85,10 @@ public static class DatabaseMigrations
helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []);
return 3;
}
/// <summary>
/// Migration step that adds the nullable <c>similaritymethod</c> column
/// (default <c>'Cosine'</c>) to the <c>datapoint</c> table, positioned after
/// <c>probmethod_embedding</c>.
/// </summary>
/// <param name="helper">SQL helper used to execute the schema change.</param>
/// <returns>Always 4 (presumably the schema version reached after this step - confirm against the migration runner).</returns>
public static int UpdateFrom3(SQLHelper helper)
{
    const string addSimilarityMethodColumn = "ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding";
    helper.ExecuteSQLNonQuery(addSimilarityMethodColumn, []);
    return 4;
}
}

View File

@@ -169,9 +169,4 @@ public static class Probmethods
}
return f / fm;
}
public static float Similarity(float[] vector1, float[] vector2)
{
return TensorPrimitives.CosineSimilarity(vector1, vector2);
}
}

View File

@@ -73,7 +73,7 @@ public class Searchdomain
}
embeddingReader.Close();
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, hash FROM datapoint", parametersIDSearchdomain);
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, similaritymethod, hash FROM datapoint", parametersIDSearchdomain);
Dictionary<int, List<Datapoint>> datapoint_unassigned = [];
while (datapointReader.Read())
{
@@ -81,8 +81,10 @@ public class Searchdomain
int id_entity = datapointReader.GetInt32(1);
string name = datapointReader.GetString(2);
string probmethodString = datapointReader.GetString(3);
string hash = datapointReader.GetString(4);
string similarityMethodString = datapointReader.GetString(4);
string hash = datapointReader.GetString(5);
ProbMethod probmethod = new(probmethodString, _logger);
SimilarityMethod similarityMethod = new(similarityMethodString, _logger);
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
{
embedding_unassigned.Remove(id);
@@ -90,7 +92,7 @@ public class Searchdomain
{
datapoint_unassigned[id_entity] = [];
}
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
}
}
datapointReader.Close();
@@ -157,11 +159,12 @@ public class Searchdomain
List<(string, float)> datapointProbs = [];
foreach (Datapoint datapoint in entity.datapoints)
{
SimilarityMethod similarityMethod = datapoint.similarityMethod;
List<(string, float)> list = [];
foreach ((string, float[]) embedding in datapoint.embeddings)
{
string key = embedding.Item1;
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2);
list.Add((key, value));
}
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));

View File

@@ -0,0 +1,112 @@
using System.Numerics.Tensors;
using System.Text.Json;
namespace Server;
/// <summary>
/// Pairs the name of a similarity method with the delegate that implements it.
/// The delegate is resolved through <see cref="SimilarityMethods.GetMethod"/>.
/// </summary>
public class SimilarityMethod
{
    /// <summary>Delegate that scores two embedding vectors.</summary>
    public SimilarityMethods.similarityMethodDelegate method;

    /// <summary>Name the method was looked up by (stored exactly as passed in).</summary>
    public string name;

    /// <summary>
    /// Resolves <paramref name="name"/> to a similarity delegate.
    /// </summary>
    /// <param name="name">Name of the similarity method to look up.</param>
    /// <param name="logger">Logger that receives an error entry when the lookup fails.</param>
    /// <exception cref="ArgumentException">No similarity method is registered under <paramref name="name"/>.</exception>
    public SimilarityMethod(string name, ILogger logger)
    {
        this.name = name;
        SimilarityMethods.similarityMethodDelegate? resolvedMethod = SimilarityMethods.GetMethod(name);
        if (resolvedMethod is null)
        {
            logger.LogError("Unable to retrieve similarityMethod {name}", [name]);
            // ArgumentException derives from Exception, so existing catch (Exception) handlers keep working,
            // and the message now states which name could not be resolved.
            throw new ArgumentException($"Unable to retrieve similarityMethod '{name}'", nameof(name));
        }
        method = resolvedMethod;
    }
}
/// <summary>
/// Registry and implementations of the vector similarity methods that can be
/// selected by name ("Cosine", "Euclidian", "Manhattan", "Pearson").
/// </summary>
public static class SimilarityMethods
{
    /// <summary>Signature of the implementations stored in <see cref="probMethods"/>.</summary>
    public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2);

    /// <summary>Signature of the delegates handed out by <see cref="GetMethod"/>.</summary>
    public delegate float similarityMethodDelegate(float[] vector1, float[] vector2);

    /// <summary>
    /// Lookup table: method name -> implementation. Despite its name, this
    /// dictionary holds similarity methods; the identifier is kept because it is public.
    /// </summary>
    public static readonly Dictionary<string, similarityMethodProtoDelegate> probMethods;

    static SimilarityMethods()
    {
        probMethods = new Dictionary<string, similarityMethodProtoDelegate>
        {
            ["Cosine"] = CosineSimilarity,
            ["Euclidian"] = EuclidianDistance,
            ["Manhattan"] = ManhattanDistance,
            ["Pearson"] = PearsonCorrelation
        };
    }

    /// <summary>
    /// Looks up a similarity method by name. The lookup uses the dictionary's
    /// default string comparer, so the name must match a registered key exactly.
    /// </summary>
    /// <param name="name">Registered method name.</param>
    /// <returns>A delegate wrapping the implementation, or <c>null</c> when the name is not registered.</returns>
    public static similarityMethodDelegate? GetMethod(string name)
    {
        if (!probMethods.TryGetValue(name, out similarityMethodProtoDelegate? method))
        {
            return null;
        }
        // The two delegate types are distinct, so wrap the stored one in a lambda to convert.
        return (vector1, vector2) => method(vector1, vector2);
    }

    /// <summary>
    /// Cosine similarity, linearly remapped from [-1, 1] to [0, 1] via (cos + 1) / 2.
    /// </summary>
    public static float CosineSimilarity(float[] vector1, float[] vector2)
    {
        return (TensorPrimitives.CosineSimilarity(vector1, vector2) + 1) / 2;
    }

    /// <summary>
    /// Euclidean distance between the vectors, converted to a similarity with
    /// <see cref="RationalRemap"/> (distance 0 -> 1, larger distances approach 0).
    /// </summary>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public static float EuclidianDistance(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Euclidian distance - Vectors must have the same length");
        }
        float sum = 0;
        for (int i = 0; i < vector1.Length; i++)
        {
            float diff = vector1[i] - vector2[i];
            sum += diff * diff;
        }
        return RationalRemap((float)Math.Sqrt(sum));
    }

    /// <summary>
    /// Manhattan (L1) distance between the vectors, converted to a similarity with
    /// <see cref="RationalRemap"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public static float ManhattanDistance(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Manhattan distance - Vectors must have the same length");
        }
        float sum = 0;
        for (int i = 0; i < vector1.Length; i++)
        {
            sum += Math.Abs(vector1[i] - vector2[i]);
        }
        return RationalRemap(sum);
    }

    /// <summary>
    /// Pearson correlation coefficient of the two vectors.
    /// Unlike the other methods in this class the result is NOT remapped: it lies in [-1, 1].
    /// NOTE(review): confirm that consumers of this value cope with negative scores.
    /// </summary>
    /// <returns>
    /// The correlation in [-1, 1]; 0 when it is undefined (empty vectors, or at
    /// least one vector without variance).
    /// </returns>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public static float PearsonCorrelation(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Pearson correlation - Vectors must have the same length");
        }
        int n = vector1.Length;
        if (n == 0)
        {
            // Without this guard the divisions by n below evaluate to 0/0 = NaN.
            return 0;
        }

        // Single pass, accumulating in double so products and sums are not rounded to float.
        double sum1 = 0, sum2 = 0, sum1Sq = 0, sum2Sq = 0, pSum = 0;
        for (int i = 0; i < n; i++)
        {
            double a = vector1[i];
            double b = vector2[i];
            sum1 += a;
            sum2 += b;
            sum1Sq += a * a;
            sum2Sq += b * b;
            pSum += a * b;
        }

        double num = pSum - (sum1 * sum2 / n);
        double varianceProduct = (sum1Sq - (sum1 * sum1) / n) * (sum2Sq - (sum2 * sum2) / n);

        // Written as !(x > 0) so that zero, a slightly negative product caused by
        // rounding, and NaN all take the "undefined -> 0" path instead of yielding NaN.
        if (!(varianceProduct > 0))
        {
            return 0;
        }

        // Rounding can push the quotient marginally outside the valid range.
        return (float)Math.Clamp(num / Math.Sqrt(varianceProduct), -1.0, 1.0);
    }

    /// <summary>
    /// Maps a non-negative distance to a similarity using 1 / (1 + x):
    /// 0 maps to 1 and positive infinity maps to 0.
    /// </summary>
    public static float RationalRemap(float x)
    {
        if (x == float.PositiveInfinity)
        {
            return 0;
        }
        return 1 / (1 + x);
    }
}

View File

@@ -55,6 +55,8 @@ public class DatapointResult
public required string Name { get; set; }
[JsonPropertyName("ProbMethod")]
public required string ProbMethod { get; set; }
[JsonPropertyName("SimilarityMethod")]
public required string SimilarityMethod { get; set; }
[JsonPropertyName("Embeddings")]
public required List<EmbeddingResult>? Embeddings { get; set; }
}

View File

@@ -14,5 +14,6 @@ public class JSONDatapoint
public required string Name { get; set; }
public required string Text { get; set; }
public required string Probmethod_embedding { get; set; }
public required string SimilarityMethod { get; set; }
public required string[] Model { get; set; }
}