From bc8ba893e0b653c78b64c29f6afd2faf72cc203b Mon Sep 17 00:00:00 2001 From: LD-Reborn Date: Sat, 21 Jun 2025 14:22:25 +0200 Subject: [PATCH] Added new probmethods, Made Probmethods static --- src/Indexer/Scripts/example.py | 10 ++- src/Server/Probmethods.cs | 149 +++++++++++++++++++++++++-------- src/Server/Searchdomain.cs | 10 +-- 3 files changed, 123 insertions(+), 46 deletions(-) diff --git a/src/Indexer/Scripts/example.py b/src/Indexer/Scripts/example.py index 79c83ac..da279bf 100644 --- a/src/Indexer/Scripts/example.py +++ b/src/Indexer/Scripts/example.py @@ -8,6 +8,8 @@ example_content = "./Scripts/example_content" example_searchdomain = "example" example_counter = 0 models = ["bge-m3", "mxbai-embed-large"] +probmethod_datapoint = "HighValueEmphasisWeightedAverage" +probmethod_entity = "HighValueEmphasisWeightedAverage" def init(toolset: Toolset): global example_counter @@ -44,11 +46,11 @@ def index_files(toolset: Toolset): title = file.readline() text = file.read() datapoints:list = [ - JSONDatapoint("filename", qualified_filepath, "wavg", models), - JSONDatapoint("title", title, "wavg", models), - JSONDatapoint("text", text, "wavg", models) + JSONDatapoint("filename", qualified_filepath, probmethod_datapoint, models), + JSONDatapoint("title", title, probmethod_datapoint, models), + JSONDatapoint("text", text, probmethod_datapoint, models) ] - jsonEntity:dict = asdict(JSONEntity(qualified_filepath, "wavg", example_searchdomain, {}, datapoints)) + jsonEntity:dict = asdict(JSONEntity(qualified_filepath, probmethod_entity, example_searchdomain, {}, datapoints)) jsonEntities.append(jsonEntity) jsonstring = json.dumps(jsonEntities) timer_start = time.time() diff --git a/src/Server/Probmethods.cs b/src/Server/Probmethods.cs index ae59d92..adcf237 100644 --- a/src/Server/Probmethods.cs +++ b/src/Server/Probmethods.cs @@ -1,72 +1,149 @@ - - using System.Numerics.Tensors; +using System.Text.Json; namespace Server; - -public class Probmethods +public static class Probmethods { public delegate float probMethodDelegate(List<(string, float)> list); - public Dictionary probMethods; + public static readonly Dictionary probMethods; - public Probmethods(Dictionary probMethods) + static Probmethods() { - this.probMethods = probMethods; - } - - public Probmethods() - { - probMethods = []; - probMethods["wavg"] = WavgList; - probMethods["weighted_average"] = WavgList; + probMethods = new Dictionary + { + ["Mean"] = Mean, + ["HarmonicMean"] = HarmonicMean, + ["QuadraticMean"] = QuadraticMean, + ["GeometricMean"] = GeometricMean, + ["ExtremeValuesEmphasisWeightedAverage"] = ExtremeValuesEmphasisWeightedAverage, + ["EVEWAvg"] = ExtremeValuesEmphasisWeightedAverage, + ["HighValueEmphasisWeightedAverage"] = HighValueEmphasisWeightedAverage, + ["HVEWAvg"] = HighValueEmphasisWeightedAverage, + ["LowValueEmphasisWeightedAverage"] = LowValueEmphasisWeightedAverage, + ["LVEWAvg"] = LowValueEmphasisWeightedAverage + }; } - public probMethodDelegate? GetMethod(string name) + public static probMethodDelegate? GetMethod(string name) { try { return probMethods[name]; - } catch (Exception) + } + catch { return null; } } - public static float Fact(float x) + public static float Mean(List<(string, float)> list) { - return 1 / (1 - x); + if (list.Count == 0) return 0; + float sum = 0; + foreach ((_, float value) in list) + { + sum += value; + } + return sum / list.Count; } - public static float WavgList(List<(string, float)> list) + public static float HarmonicMean(List<(string, float)> list) { - float[] arr = new float[list.Count]; - for (int i = 0; i < list.Count; i++) - { - arr[i] = list.ElementAt(i).Item2; - } - return Wavg(arr); + int n_T = list.Count; + float[] nonzeros = [.. list.Select(t => t.Item2).Where(t => t != 0)]; + int n_nz = nonzeros.Length; + if (n_nz == 0) return 0; + + float nzSum = nonzeros.Sum(x => 1 / x); + return n_nz / nzSum * (n_nz / (float)n_T); } - public static float Wavg(float[] arr) + public static float QuadraticMean(List<(string, float)> list) { - if (arr.Contains(1)) + float sum = 0; + foreach (var (_, value) in list) { - return 1; + sum += value * value; } - float f = 0; - float fm = 0; - for (int i = 0; i < arr.Length; i++) + return (float)Math.Sqrt(sum / list.Count); + } + + public static float GeometricMean(List<(string, float)> list) + { + if (list.Count == 0) return 0; + float product = 1; + foreach ((_, float value) in list) { - float x = arr[i]; - f += Fact(x); - fm += x * Fact(x); + product *= value; + } + return (float)Math.Pow(product, 1f / list.Count); + } + + public static float ExtremeValuesEmphasisWeightedAverage(List<(string, float)> list) + { + float[] arr = [.. list.Select(x => x.Item2)]; + if (arr.Contains(1)) return 1; + if (arr.Contains(0)) return 0; + + float f = 0, fm = 0; + foreach (float x in arr) + { + f += x / (x * (1 - x)); + fm += 1 / (x * (1 - x)); + } + return f / fm; + } + + public static float HighValueEmphasisWeightedAverage(List<(string, float)> list) + { + float[] arr = [.. list.Select(x => x.Item2)]; + if (arr.Contains(1)) return 1; + + float f = 0, fm = 0; + foreach (float x in arr) + { + f += x / (1 - x); + fm += 1 / (1 - x); + } + return f / fm; + } + + public static float LowValueEmphasisWeightedAverage(List<(string, float)> list) + { + float[] arr = [.. list.Select(x => x.Item2)]; + if (arr.Contains(0)) return 0; + + float f = 0, fm = 0; + foreach (float x in arr) + { + f += 1; + fm += 1 / x; + } + return f / fm; + } + + public static float DictionaryWeightedAverage(List<(string, float)> list, string jsonValues) + { + var values = JsonSerializer.Deserialize>(jsonValues) + ?? throw new Exception($"Unable to convert the string to a Dictionary: {jsonValues}"); + + float f = 0, fm = 0; + foreach (var (key, value) in list) + { + float fact = 1; + if (values.TryGetValue(key, out float factor)) + { + fact *= factor; + } + f += fact * value; + fm += fact; } return f / fm; } public static float Similarity(float[] vector1, float[] vector2) { - return (float) TensorPrimitives.CosineSimilarity(vector1, vector2); + return TensorPrimitives.CosineSimilarity(vector1, vector2); } -} \ No newline at end of file +} diff --git a/src/Server/Searchdomain.cs b/src/Server/Searchdomain.cs index d0157ad..babc300 100644 --- a/src/Server/Searchdomain.cs +++ b/src/Server/Searchdomain.cs @@ -29,7 +29,6 @@ public class Searchdomain private readonly string _connectionString; private readonly string _provider; public OllamaApiClient ollama; - public Probmethods probmethods; public string searchdomain; public int id; public Dictionary)>> searchCache; // Yeah look at this abomination. searchCache[x][0] = last accessed time, searchCache[x][1] = results for x @@ -54,7 +53,6 @@ public class Searchdomain connection = new MySqlConnection(connectionString); connection.Open(); helper = new SQLHelper(connection); - probmethods = new(); modelsInUse = []; // To make the compiler shut up - it is set in UpdateSearchDomain() don't worry // yeah, about that... if (!runEmpty) { @@ -101,7 +99,7 @@ public class Searchdomain string name = datapointReader.GetString(2); string probmethodString = datapointReader.GetString(3); string hash = datapointReader.GetString(4); - Probmethods.probMethodDelegate? probmethod = probmethods.GetMethod(probmethodString); + Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString); if (embedding_unassigned.TryGetValue(id, out Dictionary? embeddings) && probmethod is not null) { embedding_unassigned.Remove(id); @@ -142,7 +140,7 @@ public class Searchdomain { attributes = []; } - Probmethods.probMethodDelegate? probmethod = probmethods.GetMethod(probmethodString); + Probmethods.probMethodDelegate? probmethod = Probmethods.GetMethod(probmethodString); if (datapoint_unassigned.TryGetValue(id, out List? datapoints) && probmethod is not null) { Entity entity = new(attributes, probmethod, datapoints, name) @@ -292,7 +290,7 @@ public class Searchdomain { embeddings = Datapoint.GenerateEmbeddings(jsonDatapoint.Text, [.. jsonDatapoint.Model], ollama, embeddingCache); } - var probMethod_embedding = probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); + var probMethod_embedding = Probmethods.GetMethod(jsonDatapoint.Probmethod_embedding) ?? throw new Exception($"Unknown probmethod name {jsonDatapoint.Probmethod_embedding}"); Datapoint datapoint = new(jsonDatapoint.Name, probMethod_embedding, hash, [.. embeddings.Select(kv => (kv.Key, kv.Value))]); int id_datapoint = DatabaseInsertDatapoint(jsonDatapoint.Name, jsonDatapoint.Probmethod_embedding, hash, id_entity); List<(string model, byte[] embedding)> data = []; @@ -304,7 +302,7 @@ public class Searchdomain datapoints.Add(datapoint); } - var probMethod = probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new Exception($"Unknown probmethod name {jsonEntity.Probmethod}"); + var probMethod = Probmethods.GetMethod(jsonEntity.Probmethod) ?? throw new Exception($"Unknown probmethod name {jsonEntity.Probmethod}"); Entity entity = new(jsonEntity.Attributes, probMethod, datapoints, jsonEntity.Name) { id = id_entity