Added SimilarityMethod to datapoint; added Euclidean distance, Manhattan distance, and Pearson correlation; improved the CosineSimilarity result using a remap
This commit is contained in:
@@ -104,11 +104,11 @@ public class EntityController : ControllerBase
|
||||
{
|
||||
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
|
||||
}
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults});
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = embeddingResults});
|
||||
}
|
||||
else
|
||||
{
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null});
|
||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = null});
|
||||
}
|
||||
}
|
||||
EntityListResult entityListResult = new()
|
||||
|
||||
@@ -7,13 +7,15 @@ public class Datapoint
|
||||
{
|
||||
public string name;
|
||||
public ProbMethod probMethod;
|
||||
public SimilarityMethod similarityMethod;
|
||||
public List<(string, float[])> embeddings;
|
||||
public string hash;
|
||||
|
||||
public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings)
|
||||
public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings)
|
||||
{
|
||||
this.name = name;
|
||||
this.probMethod = probMethod;
|
||||
this.similarityMethod = similarityMethod;
|
||||
this.hash = hash;
|
||||
this.embeddings = embeddings;
|
||||
}
|
||||
|
||||
@@ -56,16 +56,17 @@ public static class DatabaseHelper
|
||||
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", parameters);
|
||||
}
|
||||
|
||||
public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string hash, int id_entity)
|
||||
public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string similarityMethod, string hash, int id_entity)
|
||||
{
|
||||
Dictionary<string, dynamic> parameters = new()
|
||||
{
|
||||
{ "name", name },
|
||||
{ "probmethod_embedding", probmethod_embedding },
|
||||
{ "similaritymethod", similarityMethod },
|
||||
{ "hash", hash },
|
||||
{ "id_entity", id_entity }
|
||||
};
|
||||
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, hash, id_entity) VALUES (@name, @probmethod_embedding, @hash, @id_entity)", parameters);
|
||||
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters);
|
||||
}
|
||||
|
||||
public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding)
|
||||
|
||||
@@ -48,6 +48,7 @@ public static class SearchdomainHelper
|
||||
return null;
|
||||
}
|
||||
|
||||
// toBeCached: model -> [datapoint.text * n]
|
||||
Dictionary<string, List<string>> toBeCached = [];
|
||||
foreach (JSONEntity jSONEntity in jsonEntities)
|
||||
{
|
||||
@@ -79,7 +80,7 @@ public static class SearchdomainHelper
|
||||
|
||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
||||
{
|
||||
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
||||
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = []; // embeddingsLUT: hash -> model -> [embeddingValues * n]
|
||||
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
||||
if (preexistingEntityID is not null)
|
||||
{
|
||||
@@ -139,9 +140,10 @@ public static class SearchdomainHelper
|
||||
{
|
||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache);
|
||||
}
|
||||
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
|
||||
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probMethod name {jsonDatapoint.Probmethod_embedding}");
|
||||
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new Exception($"Unknown similarityMethod name {jsonDatapoint.SimilarityMethod}");
|
||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
|
||||
List<(string model, byte[] embedding)> data = [];
|
||||
foreach ((string, float[]) embedding in datapoint.embeddings)
|
||||
{
|
||||
|
||||
@@ -85,4 +85,10 @@ public static class DatabaseMigrations
|
||||
helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []);
|
||||
return 3;
|
||||
}
|
||||
|
||||
/// <summary>
/// Migration step applied to a version-3 database: adds the nullable
/// <c>similaritymethod</c> VARCHAR(512) column (default <c>'Cosine'</c>) to the
/// <c>datapoint</c> table, positioned after <c>probmethod_embedding</c>.
/// </summary>
/// <param name="helper">SQL helper used to run the ALTER TABLE statement.</param>
/// <returns>Always 4 (presumably the schema version reached by this step).</returns>
public static int UpdateFrom3(SQLHelper helper)
{
    const string addSimilarityMethodColumn = "ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding";

    helper.ExecuteSQLNonQuery(addSimilarityMethodColumn, []);

    return 4;
}
|
||||
}
|
||||
@@ -169,9 +169,4 @@ public static class Probmethods
|
||||
}
|
||||
return f / fm;
|
||||
}
|
||||
|
||||
public static float Similarity(float[] vector1, float[] vector2)
|
||||
{
|
||||
return TensorPrimitives.CosineSimilarity(vector1, vector2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ public class Searchdomain
|
||||
}
|
||||
embeddingReader.Close();
|
||||
|
||||
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, hash FROM datapoint", parametersIDSearchdomain);
|
||||
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, similaritymethod, hash FROM datapoint", parametersIDSearchdomain);
|
||||
Dictionary<int, List<Datapoint>> datapoint_unassigned = [];
|
||||
while (datapointReader.Read())
|
||||
{
|
||||
@@ -81,8 +81,10 @@ public class Searchdomain
|
||||
int id_entity = datapointReader.GetInt32(1);
|
||||
string name = datapointReader.GetString(2);
|
||||
string probmethodString = datapointReader.GetString(3);
|
||||
string hash = datapointReader.GetString(4);
|
||||
string similarityMethodString = datapointReader.GetString(4);
|
||||
string hash = datapointReader.GetString(5);
|
||||
ProbMethod probmethod = new(probmethodString, _logger);
|
||||
SimilarityMethod similarityMethod = new(similarityMethodString, _logger);
|
||||
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
||||
{
|
||||
embedding_unassigned.Remove(id);
|
||||
@@ -90,7 +92,7 @@ public class Searchdomain
|
||||
{
|
||||
datapoint_unassigned[id_entity] = [];
|
||||
}
|
||||
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
|
||||
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
|
||||
}
|
||||
}
|
||||
datapointReader.Close();
|
||||
@@ -157,11 +159,12 @@ public class Searchdomain
|
||||
List<(string, float)> datapointProbs = [];
|
||||
foreach (Datapoint datapoint in entity.datapoints)
|
||||
{
|
||||
SimilarityMethod similarityMethod = datapoint.similarityMethod;
|
||||
List<(string, float)> list = [];
|
||||
foreach ((string, float[]) embedding in datapoint.embeddings)
|
||||
{
|
||||
string key = embedding.Item1;
|
||||
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
|
||||
float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2);
|
||||
list.Add((key, value));
|
||||
}
|
||||
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
|
||||
|
||||
112
src/Server/SimilarityMethods.cs
Normal file
112
src/Server/SimilarityMethods.cs
Normal file
@@ -0,0 +1,112 @@
|
||||
using System.Numerics.Tensors;
|
||||
using System.Text.Json;
|
||||
|
||||
namespace Server;
|
||||
|
||||
/// <summary>
/// Pairs the name of a similarity measure with the delegate that evaluates it.
/// </summary>
public class SimilarityMethod
{
    /// <summary>Delegate that scores two vectors with the selected measure.</summary>
    public SimilarityMethods.similarityMethodDelegate method;

    /// <summary>Name the measure was looked up under.</summary>
    public string name;

    /// <summary>
    /// Resolves <paramref name="name"/> through <see cref="SimilarityMethods.GetMethod"/>
    /// and stores the resulting delegate.
    /// </summary>
    /// <param name="name">Name of the similarity measure to resolve.</param>
    /// <param name="logger">Logger that receives an error entry when the lookup fails.</param>
    /// <exception cref="Exception">Thrown when the lookup yields no delegate.</exception>
    public SimilarityMethod(string name, ILogger logger)
    {
        this.name = name;

        // Fail loudly on a failed lookup: log first, then abort construction.
        if (SimilarityMethods.GetMethod(name) is not { } resolvedMethod)
        {
            logger.LogError("Unable to retrieve similarityMethod {name}", [name]);
            throw new Exception("Unable to retrieve similarityMethod");
        }

        method = resolvedMethod;
    }
}
|
||||
|
||||
/// <summary>
/// Registry and implementations of the vector similarity measures that can be
/// selected by name.
/// </summary>
public static class SimilarityMethods
{
    /// <summary>Signature of the implementations stored in <see cref="probMethods"/>.</summary>
    public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2);

    /// <summary>Signature of the delegate handed out by <see cref="GetMethod"/>.</summary>
    public delegate float similarityMethodDelegate(float[] vector1, float[] vector2);

    /// <summary>
    /// Lookup table from method name to implementation. Despite the field name it
    /// holds similarity methods; the name is kept so existing references still compile.
    /// </summary>
    public static readonly Dictionary<string, similarityMethodProtoDelegate> probMethods;

    static SimilarityMethods()
    {
        probMethods = new Dictionary<string, similarityMethodProtoDelegate>
        {
            ["Cosine"] = CosineSimilarity,
            ["Euclidian"] = EuclidianDistance,
            ["Manhattan"] = ManhattanDistance,
            ["Pearson"] = PearsonCorrelation
        };
    }

    /// <summary>
    /// Looks up a similarity method by its registered name.
    /// </summary>
    /// <param name="name">Registered name, e.g. "Cosine". The lookup is case-sensitive.</param>
    /// <returns>
    /// A delegate invoking the registered implementation, or <c>null</c> when
    /// <paramref name="name"/> is <c>null</c> or not registered.
    /// </returns>
    public static similarityMethodDelegate? GetMethod(string name)
    {
        // A null key would make the dictionary throw ArgumentNullException;
        // report it like any other unknown name instead.
        if (name is null || !probMethods.TryGetValue(name, out similarityMethodProtoDelegate? method))
        {
            return null;
        }

        return (vector1, vector2) => method(vector1, vector2);
    }

    /// <summary>
    /// Cosine similarity shifted and halved, so a raw value of -1 becomes 0 and
    /// a raw value of 1 becomes 1.
    /// </summary>
    public static float CosineSimilarity(float[] vector1, float[] vector2)
    {
        return (TensorPrimitives.CosineSimilarity(vector1, vector2) + 1) / 2;
    }

    /// <summary>
    /// Euclidean distance between the vectors, passed through <see cref="RationalRemap"/>
    /// so identical vectors score 1 and the score shrinks as the distance grows.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the vectors differ in length.</exception>
    public static float EuclidianDistance(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Euclidian distance - Vectors must have the same length");
        }

        float sum = 0;
        for (int i = 0; i < vector1.Length; i++)
        {
            float diff = vector1[i] - vector2[i];
            sum += diff * diff;
        }

        return RationalRemap((float)Math.Sqrt(sum));
    }

    /// <summary>
    /// Manhattan (sum of absolute differences) distance between the vectors,
    /// passed through <see cref="RationalRemap"/>.
    /// </summary>
    /// <exception cref="ArgumentException">Thrown when the vectors differ in length.</exception>
    public static float ManhattanDistance(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Manhattan distance - Vectors must have the same length");
        }

        float sum = 0;
        for (int i = 0; i < vector1.Length; i++)
        {
            sum += Math.Abs(vector1[i] - vector2[i]);
        }

        return RationalRemap(sum);
    }

    /// <summary>
    /// Pearson correlation coefficient of the two vectors.
    /// </summary>
    /// <remarks>
    /// The result lies in [-1, 1] and is not remapped to [0, 1].
    /// NOTE(review): the other registered methods produce values in [0, 1];
    /// confirm whether callers expect the same range here.
    /// </remarks>
    /// <returns>
    /// The correlation, or 0 when it is undefined: empty vectors, a vector with
    /// zero variance, or inputs containing NaN/infinity.
    /// </returns>
    /// <exception cref="ArgumentException">Thrown when the vectors differ in length.</exception>
    public static float PearsonCorrelation(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Pearson correlation - Vectors must have the same length");
        }

        int n = vector1.Length;
        if (n == 0)
        {
            // No data points: the correlation is undefined (0/0), report "no correlation".
            return 0;
        }

        // First pass: means, accumulated in double to limit rounding error.
        double mean1 = 0;
        double mean2 = 0;
        for (int i = 0; i < n; i++)
        {
            mean1 += vector1[i];
            mean2 += vector2[i];
        }
        mean1 /= n;
        mean2 /= n;

        // Second pass: centred sums. Working on deviations from the mean avoids the
        // cancellation of the "sum of squares minus squared sum" formula, which can
        // produce a slightly negative variance and therefore a NaN square root.
        double covariance = 0;
        double variance1 = 0;
        double variance2 = 0;
        for (int i = 0; i < n; i++)
        {
            double d1 = vector1[i] - mean1;
            double d2 = vector2[i] - mean2;
            covariance += d1 * d2;
            variance1 += d1 * d1;
            variance2 += d2 * d2;
        }

        double den = Math.Sqrt(variance1 * variance2);

        // "!(den > 0)" is true for both a zero and a NaN denominator.
        if (!(den > 0))
        {
            return 0;
        }

        // Clamp away rounding overshoot so callers never see values outside [-1, 1].
        return (float)Math.Clamp(covariance / den, -1.0, 1.0);
    }

    /// <summary>
    /// Maps a distance to a score via 1 / (1 + x): 0 maps to 1 and the score
    /// approaches 0 as <paramref name="x"/> grows; positive infinity maps to 0.
    /// </summary>
    /// <param name="x">Distance value. The callers in this class pass non-negative distances.</param>
    public static float RationalRemap(float x)
    {
        if (x == float.PositiveInfinity)
        {
            return 0;
        }

        return 1 / (1 + x);
    }
}
|
||||
Reference in New Issue
Block a user