Added SimilarityMethod to datapoint; added Euclidean distance, Manhattan distance, and Pearson correlation; improved the CosineSimilarity result using a remap
This commit is contained in:
@@ -7,6 +7,7 @@ class JSONDatapoint:
|
|||||||
Name:str
|
Name:str
|
||||||
Text:str
|
Text:str
|
||||||
Probmethod_embedding:str
|
Probmethod_embedding:str
|
||||||
|
SimilarityMethod:str
|
||||||
Model:list[str]
|
Model:list[str]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -104,11 +104,11 @@ public class EntityController : ControllerBase
|
|||||||
{
|
{
|
||||||
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
|
embeddingResults.Add(new EmbeddingResult() {Model = embedding.Item1, Embeddings = embedding.Item2});
|
||||||
}
|
}
|
||||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = embeddingResults});
|
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = embeddingResults});
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, Embeddings = null});
|
datapointResults.Add(new DatapointResult() {Name = datapoint.name, ProbMethod = datapoint.probMethod.name, SimilarityMethod = datapoint.similarityMethod.name, Embeddings = null});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EntityListResult entityListResult = new()
|
EntityListResult entityListResult = new()
|
||||||
|
|||||||
@@ -7,13 +7,15 @@ public class Datapoint
|
|||||||
{
|
{
|
||||||
public string name;
|
public string name;
|
||||||
public ProbMethod probMethod;
|
public ProbMethod probMethod;
|
||||||
|
public SimilarityMethod similarityMethod;
|
||||||
public List<(string, float[])> embeddings;
|
public List<(string, float[])> embeddings;
|
||||||
public string hash;
|
public string hash;
|
||||||
|
|
||||||
public Datapoint(string name, ProbMethod probMethod, string hash, List<(string, float[])> embeddings)
|
public Datapoint(string name, ProbMethod probMethod, SimilarityMethod similarityMethod, string hash, List<(string, float[])> embeddings)
|
||||||
{
|
{
|
||||||
this.name = name;
|
this.name = name;
|
||||||
this.probMethod = probMethod;
|
this.probMethod = probMethod;
|
||||||
|
this.similarityMethod = similarityMethod;
|
||||||
this.hash = hash;
|
this.hash = hash;
|
||||||
this.embeddings = embeddings;
|
this.embeddings = embeddings;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,16 +56,17 @@ public static class DatabaseHelper
|
|||||||
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", parameters);
|
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO attribute (attribute, value, id_entity) VALUES (@attribute, @value, @id_entity)", parameters);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string hash, int id_entity)
|
public static int DatabaseInsertDatapoint(SQLHelper helper, string name, string probmethod_embedding, string similarityMethod, string hash, int id_entity)
|
||||||
{
|
{
|
||||||
Dictionary<string, dynamic> parameters = new()
|
Dictionary<string, dynamic> parameters = new()
|
||||||
{
|
{
|
||||||
{ "name", name },
|
{ "name", name },
|
||||||
{ "probmethod_embedding", probmethod_embedding },
|
{ "probmethod_embedding", probmethod_embedding },
|
||||||
|
{ "similaritymethod", similarityMethod },
|
||||||
{ "hash", hash },
|
{ "hash", hash },
|
||||||
{ "id_entity", id_entity }
|
{ "id_entity", id_entity }
|
||||||
};
|
};
|
||||||
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, hash, id_entity) VALUES (@name, @probmethod_embedding, @hash, @id_entity)", parameters);
|
return helper.ExecuteSQLCommandGetInsertedID("INSERT INTO datapoint (name, probmethod_embedding, similaritymethod, hash, id_entity) VALUES (@name, @probmethod_embedding, @similaritymethod, @hash, @id_entity)", parameters);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding)
|
public static int DatabaseInsertEmbedding(SQLHelper helper, int id_datapoint, string model, byte[] embedding)
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ public static class SearchdomainHelper
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toBeCached: model -> [datapoint.text * n]
|
||||||
Dictionary<string, List<string>> toBeCached = [];
|
Dictionary<string, List<string>> toBeCached = [];
|
||||||
foreach (JSONEntity jSONEntity in jsonEntities)
|
foreach (JSONEntity jSONEntity in jsonEntities)
|
||||||
{
|
{
|
||||||
@@ -79,7 +80,7 @@ public static class SearchdomainHelper
|
|||||||
|
|
||||||
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
public static Entity? EntityFromJSON(List<Entity> entityCache, Dictionary<string, Dictionary<string, float[]>> embeddingCache, AIProvider aIProvider, SQLHelper helper, ILogger logger, JSONEntity jsonEntity) //string json)
|
||||||
{
|
{
|
||||||
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = [];
|
Dictionary<string, Dictionary<string, float[]>> embeddingsLUT = []; // embeddingsLUT: hash -> model -> [embeddingValues * n]
|
||||||
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
int? preexistingEntityID = DatabaseHelper.GetEntityID(helper, jsonEntity.Name, jsonEntity.Searchdomain);
|
||||||
if (preexistingEntityID is not null)
|
if (preexistingEntityID is not null)
|
||||||
{
|
{
|
||||||
@@ -139,9 +140,10 @@ public static class SearchdomainHelper
|
|||||||
{
|
{
|
||||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache);
|
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], aIProvider, embeddingCache);
|
||||||
}
|
}
|
||||||
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
var probMethod_embedding = new ProbMethod(jsonDatapoint.Probmethod_embedding, logger) ?? throw new Exception($"Unknown probMethod name {jsonDatapoint.Probmethod_embedding}");
|
||||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
var similarityMethod = new SimilarityMethod(jsonDatapoint.SimilarityMethod, logger) ?? throw new Exception($"Unknown similarityMethod name {jsonDatapoint.SimilarityMethod}");
|
||||||
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
|
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||||
|
int id_datapoint = DatabaseHelper.DatabaseInsertDatapoint(helper, jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, jsonDatapoint.SimilarityMethod, hash, id_entity); // TODO make this a bulk add action to reduce number of queries
|
||||||
List<(string model, byte[] embedding)> data = [];
|
List<(string model, byte[] embedding)> data = [];
|
||||||
foreach ((string, float[]) embedding in datapoint.embeddings)
|
foreach ((string, float[]) embedding in datapoint.embeddings)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -85,4 +85,10 @@ public static class DatabaseMigrations
|
|||||||
helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []);
|
helper.ExecuteSQLNonQuery("UPDATE datapoint SET hash='';", []);
|
||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static int UpdateFrom3(SQLHelper helper)
|
||||||
|
{
|
||||||
|
helper.ExecuteSQLNonQuery("ALTER TABLE datapoint ADD COLUMN similaritymethod VARCHAR(512) NULL DEFAULT 'Cosine' AFTER probmethod_embedding", []);
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -169,9 +169,4 @@ public static class Probmethods
|
|||||||
}
|
}
|
||||||
return f / fm;
|
return f / fm;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static float Similarity(float[] vector1, float[] vector2)
|
|
||||||
{
|
|
||||||
return TensorPrimitives.CosineSimilarity(vector1, vector2);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ public class Searchdomain
|
|||||||
}
|
}
|
||||||
embeddingReader.Close();
|
embeddingReader.Close();
|
||||||
|
|
||||||
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, hash FROM datapoint", parametersIDSearchdomain);
|
DbDataReader datapointReader = helper.ExecuteSQLCommand("SELECT id, id_entity, name, probmethod_embedding, similaritymethod, hash FROM datapoint", parametersIDSearchdomain);
|
||||||
Dictionary<int, List<Datapoint>> datapoint_unassigned = [];
|
Dictionary<int, List<Datapoint>> datapoint_unassigned = [];
|
||||||
while (datapointReader.Read())
|
while (datapointReader.Read())
|
||||||
{
|
{
|
||||||
@@ -81,8 +81,10 @@ public class Searchdomain
|
|||||||
int id_entity = datapointReader.GetInt32(1);
|
int id_entity = datapointReader.GetInt32(1);
|
||||||
string name = datapointReader.GetString(2);
|
string name = datapointReader.GetString(2);
|
||||||
string probmethodString = datapointReader.GetString(3);
|
string probmethodString = datapointReader.GetString(3);
|
||||||
string hash = datapointReader.GetString(4);
|
string similarityMethodString = datapointReader.GetString(4);
|
||||||
|
string hash = datapointReader.GetString(5);
|
||||||
ProbMethod probmethod = new(probmethodString, _logger);
|
ProbMethod probmethod = new(probmethodString, _logger);
|
||||||
|
SimilarityMethod similarityMethod = new(similarityMethodString, _logger);
|
||||||
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
||||||
{
|
{
|
||||||
embedding_unassigned.Remove(id);
|
embedding_unassigned.Remove(id);
|
||||||
@@ -90,7 +92,7 @@ public class Searchdomain
|
|||||||
{
|
{
|
||||||
datapoint_unassigned[id_entity] = [];
|
datapoint_unassigned[id_entity] = [];
|
||||||
}
|
}
|
||||||
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
|
datapoint_unassigned[id_entity].Add(new Datapoint(name, probmethod, similarityMethod, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
datapointReader.Close();
|
datapointReader.Close();
|
||||||
@@ -157,11 +159,12 @@ public class Searchdomain
|
|||||||
List<(string, float)> datapointProbs = [];
|
List<(string, float)> datapointProbs = [];
|
||||||
foreach (Datapoint datapoint in entity.datapoints)
|
foreach (Datapoint datapoint in entity.datapoints)
|
||||||
{
|
{
|
||||||
|
SimilarityMethod similarityMethod = datapoint.similarityMethod;
|
||||||
List<(string, float)> list = [];
|
List<(string, float)> list = [];
|
||||||
foreach ((string, float[]) embedding in datapoint.embeddings)
|
foreach ((string, float[]) embedding in datapoint.embeddings)
|
||||||
{
|
{
|
||||||
string key = embedding.Item1;
|
string key = embedding.Item1;
|
||||||
float value = Probmethods.Similarity(queryEmbeddings[embedding.Item1], embedding.Item2);
|
float value = similarityMethod.method(queryEmbeddings[embedding.Item1], embedding.Item2);
|
||||||
list.Add((key, value));
|
list.Add((key, value));
|
||||||
}
|
}
|
||||||
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
|
datapointProbs.Add((datapoint.name, datapoint.probMethod.method(list)));
|
||||||
|
|||||||
112
src/Server/SimilarityMethods.cs
Normal file
112
src/Server/SimilarityMethods.cs
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
using System.Numerics.Tensors;
|
||||||
|
using System.Text.Json;
|
||||||
|
|
||||||
|
namespace Server;
|
||||||
|
|
||||||
|
/// <summary>
/// Pairs a similarity method name with the delegate that implements it.
/// The delegate is resolved once, in the constructor, via
/// <see cref="SimilarityMethods.GetMethod(string)"/>.
/// </summary>
public class SimilarityMethod
{
    /// <summary>Delegate that scores two vectors against each other.</summary>
    public SimilarityMethods.similarityMethodDelegate method;

    /// <summary>The name this instance was constructed with.</summary>
    public string name;

    /// <summary>
    /// Resolves <paramref name="name"/> to a similarity delegate.
    /// </summary>
    /// <param name="name">Lookup key handed to <see cref="SimilarityMethods.GetMethod(string)"/>.</param>
    /// <param name="logger">Receives an error entry when the lookup fails.</param>
    /// <exception cref="ArgumentException">
    /// No similarity method is registered under <paramref name="name"/>.
    /// </exception>
    public SimilarityMethod(string name, ILogger logger)
    {
        this.name = name;

        SimilarityMethods.similarityMethodDelegate? resolved = SimilarityMethods.GetMethod(name);
        if (resolved is null)
        {
            logger.LogError("Unable to retrieve similarityMethod {name}", [name]);
            // ArgumentException derives from Exception, so existing catch (Exception)
            // handlers keep working; the message now says which name was rejected.
            throw new ArgumentException($"Unable to retrieve similarityMethod '{name}'", nameof(name));
        }

        method = resolved;
    }
}
|
||||||
|
|
||||||
|
/// <summary>
/// Registry of vector similarity functions, looked up by name.
/// Cosine, Euclidian and Manhattan results are remapped so that a larger value
/// means "more similar"; Pearson is returned as the raw correlation coefficient.
/// </summary>
public static class SimilarityMethods
{
    /// <summary>Signature of the functions stored in <see cref="probMethods"/>.</summary>
    public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2);

    /// <summary>Signature of the delegates handed out by <see cref="GetMethod(string)"/>.</summary>
    public delegate float similarityMethodDelegate(float[] vector1, float[] vector2);

    /// <summary>
    /// Name -> implementation table. The field name is kept as-is because it is public.
    /// </summary>
    public static readonly Dictionary<string, similarityMethodProtoDelegate> probMethods;

    static SimilarityMethods()
    {
        probMethods = new Dictionary<string, similarityMethodProtoDelegate>
        {
            ["Cosine"] = CosineSimilarity,
            ["Euclidian"] = EuclidianDistance,
            ["Manhattan"] = ManhattanDistance,
            ["Pearson"] = PearsonCorrelation
        };
    }

    /// <summary>
    /// Looks up a similarity function by its registered name.
    /// </summary>
    /// <param name="name">Key into <see cref="probMethods"/> (case-sensitive dictionary lookup).</param>
    /// <returns>A delegate that forwards to the registered function, or <c>null</c> if the name is not registered.</returns>
    public static similarityMethodDelegate? GetMethod(string name)
    {
        if (!probMethods.TryGetValue(name, out similarityMethodProtoDelegate? method))
        {
            return null;
        }

        // Wrap to convert between the two (structurally identical) delegate types.
        return (vector1, vector2) => method(vector1, vector2);
    }

    /// <summary>
    /// Cosine similarity remapped from [-1, 1] to [0, 1] via <c>(cos + 1) / 2</c>.
    /// </summary>
    public static float CosineSimilarity(float[] vector1, float[] vector2)
    {
        return (TensorPrimitives.CosineSimilarity(vector1, vector2) + 1) / 2;
    }

    /// <summary>
    /// Euclidean distance between the vectors, passed through <see cref="RationalRemap(float)"/>
    /// so that identical vectors score 1 and the score falls towards 0 as distance grows.
    /// </summary>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public static float EuclidianDistance(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Euclidian distance - Vectors must have the same length");
        }

        float sum = 0;
        for (int i = 0; i < vector1.Length; i++)
        {
            float diff = vector1[i] - vector2[i];
            sum += diff * diff;
        }

        return RationalRemap((float)Math.Sqrt(sum));
    }

    /// <summary>
    /// Manhattan (L1) distance between the vectors, passed through
    /// <see cref="RationalRemap(float)"/> so that identical vectors score 1.
    /// </summary>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public static float ManhattanDistance(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Manhattan distance - Vectors must have the same length");
        }

        float sum = 0;
        for (int i = 0; i < vector1.Length; i++)
        {
            sum += Math.Abs(vector1[i] - vector2[i]);
        }

        return RationalRemap(sum);
    }

    /// <summary>
    /// Pearson correlation coefficient of the two vectors, in [-1, 1].
    /// Unlike the other methods in this class the result is not remapped to [0, 1].
    /// </summary>
    /// <returns>
    /// The correlation, or 0 when it is undefined (empty input, or either vector has zero variance).
    /// </returns>
    /// <exception cref="ArgumentException">The vectors differ in length.</exception>
    public static float PearsonCorrelation(float[] vector1, float[] vector2)
    {
        if (vector1.Length != vector2.Length)
        {
            throw new ArgumentException("Unable to calculate Pearson correlation - Vectors must have the same length");
        }

        int n = vector1.Length;
        if (n == 0)
        {
            // Correlation of nothing is undefined; report "no correlation" instead of NaN (0/0).
            return 0;
        }

        // Pass 1: means, accumulated in double so long vectors do not lose precision.
        double mean1 = 0;
        double mean2 = 0;
        for (int i = 0; i < n; i++)
        {
            mean1 += vector1[i];
            mean2 += vector2[i];
        }
        mean1 /= n;
        mean2 /= n;

        // Pass 2: mean-centred sums. Algebraically equal to the single-pass
        // (sumXY - sumX*sumY/n) form, but without its catastrophic cancellation,
        // and the variance terms can never become negative.
        double covariance = 0;
        double variance1 = 0;
        double variance2 = 0;
        for (int i = 0; i < n; i++)
        {
            double d1 = vector1[i] - mean1;
            double d2 = vector2[i] - mean2;
            covariance += d1 * d2;
            variance1 += d1 * d1;
            variance2 += d2 * d2;
        }

        double den = Math.Sqrt(variance1 * variance2);
        if (den == 0)
        {
            // Exact-zero test is intentional: it only guards the division below.
            return 0;
        }

        // Clamp away rounding excursions just outside the mathematically valid range.
        return (float)Math.Clamp(covariance / den, -1.0, 1.0);
    }

    /// <summary>
    /// Maps a distance <paramref name="x"/> to <c>1 / (1 + x)</c>: 0 becomes 1,
    /// and positive infinity becomes 0.
    /// </summary>
    public static float RationalRemap(float x)
    {
        if (x == float.PositiveInfinity)
        {
            return 0;
        }

        return 1 / (1 + x);
    }
}
|
||||||
@@ -55,6 +55,8 @@ public class DatapointResult
|
|||||||
public required string Name { get; set; }
|
public required string Name { get; set; }
|
||||||
[JsonPropertyName("ProbMethod")]
|
[JsonPropertyName("ProbMethod")]
|
||||||
public required string ProbMethod { get; set; }
|
public required string ProbMethod { get; set; }
|
||||||
|
[JsonPropertyName("SimilarityMethod")]
|
||||||
|
public required string SimilarityMethod { get; set; }
|
||||||
[JsonPropertyName("Embeddings")]
|
[JsonPropertyName("Embeddings")]
|
||||||
public required List<EmbeddingResult>? Embeddings { get; set; }
|
public required List<EmbeddingResult>? Embeddings { get; set; }
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,5 +14,6 @@ public class JSONDatapoint
|
|||||||
public required string Name { get; set; }
|
public required string Name { get; set; }
|
||||||
public required string Text { get; set; }
|
public required string Text { get; set; }
|
||||||
public required string Probmethod_embedding { get; set; }
|
public required string Probmethod_embedding { get; set; }
|
||||||
|
public required string SimilarityMethod { get; set; }
|
||||||
public required string[] Model { get; set; }
|
public required string[] Model { get; set; }
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user