Added embeddings prefetching for entities ingest
This commit is contained in:
@@ -52,9 +52,69 @@ public class Datapoint
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
public static Dictionary<string, Dictionary<string, float[]>> GetEmbeddings(string[] content, List<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
|
||||
{
|
||||
Dictionary<string, Dictionary<string, float[]>> embeddings = [];
|
||||
foreach (string model in models)
|
||||
{
|
||||
List<string> toBeGenerated = [];
|
||||
embeddings[model] = [];
|
||||
foreach (string value in content)
|
||||
{
|
||||
bool generateThisEntry = true;
|
||||
bool embeddingCacheHasContent = embeddingCache.TryGetValue(value, out var embeddingCacheForContent);
|
||||
if (embeddingCacheHasContent && embeddingCacheForContent is not null)
|
||||
{
|
||||
bool embeddingCacheHasModel = embeddingCacheForContent.TryGetValue(model, out float[]? embedding);
|
||||
if (embeddingCacheHasModel && embedding is not null)
|
||||
{
|
||||
embeddings[model][value] = embedding;
|
||||
generateThisEntry = false;
|
||||
}
|
||||
}
|
||||
if (generateThisEntry)
|
||||
{
|
||||
if (!toBeGenerated.Contains(value))
|
||||
{
|
||||
toBeGenerated.Add(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
IEnumerable<float[]> generatedEmbeddings = GenerateEmbeddings([.. toBeGenerated], model, aIProvider, embeddingCache);
|
||||
if (generatedEmbeddings.Count() != toBeGenerated.Count)
|
||||
{
|
||||
throw new Exception("Requested embeddings count and generated embeddings count mismatched!");
|
||||
}
|
||||
for (int i = 0; i < toBeGenerated.Count; i++)
|
||||
{
|
||||
embeddings[model][toBeGenerated.ElementAt(i)] = generatedEmbeddings.ElementAt(i);
|
||||
}
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
public static IEnumerable<float[]> GenerateEmbeddings(string[] content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
|
||||
{
|
||||
IEnumerable<float[]> embeddings = aIProvider.GenerateEmbeddings(model, content);
|
||||
if (embeddings.Count() != content.Length)
|
||||
{
|
||||
throw new Exception("Resulting embeddings count does not match up with request count");
|
||||
}
|
||||
for (int i = 0; i < content.Length; i++)
|
||||
{
|
||||
if (!embeddingCache.ContainsKey(content[i]))
|
||||
{
|
||||
embeddingCache[content[i]] = [];
|
||||
}
|
||||
embeddingCache[content[i]][model] = embeddings.ElementAt(i);
|
||||
}
|
||||
return embeddings;
|
||||
}
|
||||
|
||||
|
||||
public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
|
||||
{
|
||||
float[] embeddings = aIProvider.GenerateEmbeddings(model, [content]);
|
||||
float[] embeddings = aIProvider.GenerateEmbeddings(model, content);
|
||||
if (!embeddingCache.ContainsKey(content))
|
||||
{
|
||||
embeddingCache[content] = [];
|
||||
|
||||
Reference in New Issue
Block a user