Added new probmethods, Made Probmethods static
This commit is contained in:
@@ -8,6 +8,8 @@ example_content = "./Scripts/example_content"
|
|||||||
example_searchdomain = "example"
|
example_searchdomain = "example"
|
||||||
example_counter = 0
|
example_counter = 0
|
||||||
models = ["bge-m3", "mxbai-embed-large"]
|
models = ["bge-m3", "mxbai-embed-large"]
|
||||||
|
probmethod_datapoint = "HighValueEmphasisWeightedAverage"
|
||||||
|
probmethod_entity = "HighValueEmphasisWeightedAverage"
|
||||||
|
|
||||||
def init(toolset: Toolset):
|
def init(toolset: Toolset):
|
||||||
global example_counter
|
global example_counter
|
||||||
@@ -44,11 +46,11 @@ def index_files(toolset: Toolset):
|
|||||||
title = file.readline()
|
title = file.readline()
|
||||||
text = file.read()
|
text = file.read()
|
||||||
datapoints:list = [
|
datapoints:list = [
|
||||||
JSONDatapoint("filename", qualified_filepath, "wavg", models),
|
JSONDatapoint("filename", qualified_filepath, probmethod_datapoint, models),
|
||||||
JSONDatapoint("title", title, "wavg", models),
|
JSONDatapoint("title", title, probmethod_datapoint, models),
|
||||||
JSONDatapoint("text", text, "wavg", models)
|
JSONDatapoint("text", text, probmethod_datapoint, models)
|
||||||
]
|
]
|
||||||
jsonEntity:dict = asdict(JSONEntity(qualified_filepath, "wavg", example_searchdomain, {}, datapoints))
|
jsonEntity:dict = asdict(JSONEntity(qualified_filepath, probmethod_entity, example_searchdomain, {}, datapoints))
|
||||||
jsonEntities.append(jsonEntity)
|
jsonEntities.append(jsonEntity)
|
||||||
jsonstring = json.dumps(jsonEntities)
|
jsonstring = json.dumps(jsonEntities)
|
||||||
timer_start = time.time()
|
timer_start = time.time()
|
||||||
|
|||||||
@@ -1,72 +1,149 @@
|
|||||||
|
|
||||||
|
|
||||||
using System.Numerics.Tensors;
|
using System.Numerics.Tensors;
|
||||||
|
using System.Text.Json;
|
||||||
|
|
||||||
namespace Server;
|
namespace Server;
|
||||||
|
|
||||||
|
public static class Probmethods
|
||||||
public class Probmethods
|
|
||||||
{
|
{
|
||||||
public delegate float probMethodDelegate(List<(string, float)> list);
|
public delegate float probMethodDelegate(List<(string, float)> list);
|
||||||
public Dictionary<string, probMethodDelegate> probMethods;
|
public static readonly Dictionary<string, probMethodDelegate> probMethods;
|
||||||
|
|
||||||
public Probmethods(Dictionary<string, probMethodDelegate> probMethods)
|
static Probmethods()
|
||||||
{
|
{
|
||||||
this.probMethods = probMethods;
|
probMethods = new Dictionary<string, probMethodDelegate>
|
||||||
}
|
{
|
||||||
|
["Mean"] = Mean,
|
||||||
public Probmethods()
|
["HarmonicMean"] = HarmonicMean,
|
||||||
{
|
["QuadraticMean"] = QuadraticMean,
|
||||||
probMethods = [];
|
["GeometricMean"] = GeometricMean,
|
||||||
probMethods["wavg"] = WavgList;
|
["ExtremeValuesEmphasisWeightedAverage"] = ExtremeValuesEmphasisWeightedAverage,
|
||||||
probMethods["weighted_average"] = WavgList;
|
["EVEWAvg"] = ExtremeValuesEmphasisWeightedAverage,
|
||||||
|
["HighValueEmphasisWeightedAverage"] = HighValueEmphasisWeightedAverage,
|
||||||
|
["HVEWAvg"] = HighValueEmphasisWeightedAverage,
|
||||||
|
["LowValueEmphasisWeightedAverage"] = LowValueEmphasisWeightedAverage,
|
||||||
|
["LVEWAvg"] = LowValueEmphasisWeightedAverage
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
public probMethodDelegate? GetMethod(string name)
|
public static probMethodDelegate? GetMethod(string name)
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
return probMethods[name];
|
return probMethods[name];
|
||||||
} catch (Exception)
|
}
|
||||||
|
catch
|
||||||
{
|
{
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static float Fact(float x)
|
public static float Mean(List<(string, float)> list)
|
||||||
{
|
{
|
||||||
return 1 / (1 - x);
|
if (list.Count == 0) return 0;
|
||||||
|
float sum = 0;
|
||||||
|
foreach ((_, float value) in list)
|
||||||
|
{
|
||||||
|
sum += value;
|
||||||
|
}
|
||||||
|
return sum / list.Count;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static float WavgList(List<(string, float)> list)
|
public static float HarmonicMean(List<(string, float)> list)
|
||||||
{
|
{
|
||||||
float[] arr = new float[list.Count];
|
int n_T = list.Count;
|
||||||
for (int i = 0; i < list.Count; i++)
|
float[] nonzeros = [.. list.Select(t => t.Item2).Where(t => t != 0)];
|
||||||
{
|
int n_nz = nonzeros.Length;
|
||||||
arr[i] = list.ElementAt(i).Item2;
|
if (n_nz == 0) return 0;
|
||||||
}
|
|
||||||
return Wavg(arr);
|
float nzSum = nonzeros.Sum(x => 1 / x);
|
||||||
|
return n_nz / nzSum * (n_nz / (float)n_T);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static float Wavg(float[] arr)
|
public static float QuadraticMean(List<(string, float)> list)
|
||||||
{
|
{
|
||||||
if (arr.Contains(1))
|
float sum = 0;
|
||||||
|
foreach (var (_, value) in list)
|
||||||
{
|
{
|
||||||
return 1;
|
sum += value * value;
|
||||||
}
|
}
|
||||||
float f = 0;
|
return (float)Math.Sqrt(sum / list.Count);
|
||||||
float fm = 0;
|
}
|
||||||
for (int i = 0; i < arr.Length; i++)
|
|
||||||
|
public static float GeometricMean(List<(string, float)> list)
|
||||||
|
{
|
||||||
|
if (list.Count == 0) return 0;
|
||||||
|
float product = 1;
|
||||||
|
foreach ((_, float value) in list)
|
||||||
{
|
{
|
||||||
float x = arr[i];
|
product *= value;
|
||||||
f += Fact(x);
|
}
|
||||||
fm += x * Fact(x);
|
return (float)Math.Pow(product, 1f / list.Count);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static float ExtremeValuesEmphasisWeightedAverage(List<(string, float)> list)
|
||||||
|
{
|
||||||
|
float[] arr = [.. list.Select(x => x.Item2)];
|
||||||
|
if (arr.Contains(1)) return 1;
|
||||||
|
if (arr.Contains(0)) return 0;
|
||||||
|
|
||||||
|
float f = 0, fm = 0;
|
||||||
|
foreach (float x in arr)
|
||||||
|
{
|
||||||
|
f += x / (x * (1 - x));
|
||||||
|
fm += 1 / (x * (1 - x));
|
||||||
|
}
|
||||||
|
return f / fm;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static float HighValueEmphasisWeightedAverage(List<(string, float)> list)
|
||||||
|
{
|
||||||
|
float[] arr = [.. list.Select(x => x.Item2)];
|
||||||
|
if (arr.Contains(1)) return 1;
|
||||||
|
|
||||||
|
float f = 0, fm = 0;
|
||||||
|
foreach (float x in arr)
|
||||||
|
{
|
||||||
|
f += x / (1 - x);
|
||||||
|
fm += 1 / (1 - x);
|
||||||
|
}
|
||||||
|
return f / fm;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static float LowValueEmphasisWeightedAverage(List<(string, float)> list)
|
||||||
|
{
|
||||||
|
float[] arr = [.. list.Select(x => x.Item2)];
|
||||||
|
if (arr.Contains(0)) return 0;
|
||||||
|
|
||||||
|
float f = 0, fm = 0;
|
||||||
|
foreach (float x in arr)
|
||||||
|
{
|
||||||
|
f += 1;
|
||||||
|
fm += 1 / x;
|
||||||
|
}
|
||||||
|
return f / fm;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static float DictionaryWeightedAverage(List<(string, float)> list, string jsonValues)
|
||||||
|
{
|
||||||
|
var values = JsonSerializer.Deserialize<Dictionary<string, float>>(jsonValues)
|
||||||
|
?? throw new Exception($"Unable to convert the string to a Dictionary<string,float>: {jsonValues}");
|
||||||
|
|
||||||
|
float f = 0, fm = 0;
|
||||||
|
foreach (var (key, value) in list)
|
||||||
|
{
|
||||||
|
float fact = 1;
|
||||||
|
if (values.TryGetValue(key, out float factor))
|
||||||
|
{
|
||||||
|
fact *= factor;
|
||||||
|
}
|
||||||
|
f += fact * value;
|
||||||
|
fm += fact;
|
||||||
}
|
}
|
||||||
return f / fm;
|
return f / fm;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static float Similarity(float[] vector1, float[] vector2)
|
public static float Similarity(float[] vector1, float[] vector2)
|
||||||
{
|
{
|
||||||
return (float) TensorPrimitives.CosineSimilarity(vector1, vector2);
|
return TensorPrimitives.CosineSimilarity(vector1, vector2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ public class Searchdomain
|
|||||||
private readonly string _connectionString;
|
private readonly string _connectionString;
|
||||||
private readonly string _provider;
|
private readonly string _provider;
|
||||||
public OllamaApiClient ollama;
|
public OllamaApiClient ollama;
|
||||||
public Probmethods probmethods;
|
|
||||||
public string searchdomain;
|
public string searchdomain;
|
||||||
public int id;
|
public int id;
|
||||||
public Dictionary<string, List<(DateTime, List<(float, string)>)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x
|
public Dictionary<string, List<(DateTime, List<(float, string)>)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x
|
||||||
@@ -54,7 +53,6 @@ public class Searchdomain
|
|||||||
connection = new MySqlConnection(connectionString);
|
connection = new MySqlConnection(connectionString);
|
||||||
connection.Open();
|
connection.Open();
|
||||||
helper = new SQLHelper(connection);
|
helper = new SQLHelper(connection);
|
||||||
probmethods = new();
|
|
||||||
modelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that...
|
modelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that...
|
||||||
if (!runEmpty)
|
if (!runEmpty)
|
||||||
{
|
{
|
||||||
@@ -101,7 +99,7 @@ public class Searchdomain
|
|||||||
string name = datapointReader.GetString(2);
|
string name = datapointReader.GetString(2);
|
||||||
string probmethodString = datapointReader.GetString(3);
|
string probmethodString = datapointReader.GetString(3);
|
||||||
string hash = datapointReader.GetString(4);
|
string hash = datapointReader.GetString(4);
|
||||||
Probmethods.probMethodDelegate? probmethod = probmethods.GetMethod(probmethodString);
|
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
|
||||||
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
if (embedding_unassigned.TryGetValue(id, out Dictionary<string, float[]>? embeddings) && probmethod is not null)
|
||||||
{
|
{
|
||||||
embedding_unassigned.Remove(id);
|
embedding_unassigned.Remove(id);
|
||||||
@@ -142,7 +140,7 @@ public class Searchdomain
|
|||||||
{
|
{
|
||||||
attributes = [];
|
attributes = [];
|
||||||
}
|
}
|
||||||
Probmethods.probMethodDelegate? probmethod = probmethods.GetMethod(probmethodString);
|
Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString);
|
||||||
if (datapoint_unassigned.TryGetValue(id, out List<Datapoint>? datapoints) && probmethod is not null)
|
if (datapoint_unassigned.TryGetValue(id, out List<Datapoint>? datapoints) && probmethod is not null)
|
||||||
{
|
{
|
||||||
Entity entity = new(attributes, probmethod, datapoints, name)
|
Entity entity = new(attributes, probmethod, datapoints, name)
|
||||||
@@ -292,7 +290,7 @@ public class Searchdomain
|
|||||||
{
|
{
|
||||||
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
|
embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache);
|
||||||
}
|
}
|
||||||
var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
var probMethod_embedding = Probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}");
|
||||||
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]);
|
||||||
int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity);
|
int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity);
|
||||||
List<(string model, byte[] embedding)> data = [];
|
List<(string model, byte[] embedding)> data = [];
|
||||||
@@ -304,7 +302,7 @@ public class Searchdomain
|
|||||||
datapoints.Add(datapoint);
|
datapoints.Add(datapoint);
|
||||||
}
|
}
|
||||||
|
|
||||||
var probMethod = probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new Exception($"Unknown probmethod name {jsonEntity.Probmethod}");
|
var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new Exception($"Unknown probmethod name {jsonEntity.Probmethod}");
|
||||||
Entity entity = new(jsonEntity.Attributes, probMethod, datapoints, jsonEntity.Name)
|
Entity entity = new(jsonEntity.Attributes, probMethod, datapoints, jsonEntity.Name)
|
||||||
{
|
{
|
||||||
id = id_entity
|
id = id_entity
|
||||||
|
|||||||
Reference in New Issue
Block a user