diff --git a/src/Server/SimilarityMethods.cs b/src/Server/SimilarityMethods.cs index 9521825..9d02c22 100644 --- a/src/Server/SimilarityMethods.cs +++ b/src/Server/SimilarityMethods.cs @@ -21,20 +21,28 @@ public class SimilarityMethod } } +public enum SimilarityMethodEnum +{ + Cosine, + Euclidian, + Manhattan, + Pearson +} + public static class SimilarityMethods { public delegate float similarityMethodProtoDelegate(float[] vector1, float[] vector2); public delegate float similarityMethodDelegate(float[] vector1, float[] vector2); - public static readonly Dictionary probMethods; + public static readonly Dictionary probMethods; static SimilarityMethods() { - probMethods = new Dictionary + probMethods = new Dictionary { - ["Cosine"] = CosineSimilarity, - ["Euclidian"] = EuclidianDistance, - ["Manhattan"] = ManhattanDistance, - ["Pearson"] = PearsonCorrelation + [SimilarityMethodEnum.Cosine] = CosineSimilarity, + [SimilarityMethodEnum.Euclidian] = EuclidianDistance, + [SimilarityMethodEnum.Manhattan] = ManhattanDistance, + [SimilarityMethodEnum.Pearson] = PearsonCorrelation }; } @@ -42,7 +50,12 @@ public static class SimilarityMethods { string methodName = name; - if (!probMethods.TryGetValue(methodName, out similarityMethodProtoDelegate? method)) + SimilarityMethodEnum probMethodEnum = (SimilarityMethodEnum)Enum.Parse( + typeof(SimilarityMethodEnum), + methodName + ); + + if (!probMethods.TryGetValue(probMethodEnum, out similarityMethodProtoDelegate? method)) { return null; }