Merge pull request #96 from LD-Reborn/94-implement-datapoint-embeddings-generation-reordering
Added embeddings prefetching for entities ingest
This commit is contained in:
@@ -31,7 +31,12 @@ public class AIProvider
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public float[] GenerateEmbeddings(string modelUri, string[] input)
|
/// <summary>
/// Generates the embedding for a single input string by delegating to the batch overload.
/// </summary>
/// <param name="modelUri">URI identifying the provider and model; passed through unchanged to the batch overload.</param>
/// <param name="input">The single text to embed.</param>
/// <returns>A new array holding a copy of the first embedding returned by the batch overload.</returns>
public float[] GenerateEmbeddings(string modelUri, string input)
{
    // Wrap the single input in a one-element batch so the string[] overload is selected.
    string[] singleItemBatch = [input];
    float[] firstEmbedding = GenerateEmbeddings(modelUri, singleItemBatch).First();

    // Spread into a fresh array, exactly as the one-line form did.
    return [.. firstEmbedding];
}
|
||||||
|
|
||||||
|
public IEnumerable<float[]> GenerateEmbeddings(string modelUri, string[] input)
|
||||||
{
|
{
|
||||||
Uri uri = new(modelUri);
|
Uri uri = new(modelUri);
|
||||||
string provider = uri.Scheme;
|
string provider = uri.Scheme;
|
||||||
@@ -103,13 +108,13 @@ public class AIProvider
|
|||||||
try
|
try
|
||||||
{
|
{
|
||||||
JObject responseContentJson = JObject.Parse(responseContent);
|
JObject responseContentJson = JObject.Parse(responseContent);
|
||||||
JToken? responseContentTokens = responseContentJson.SelectToken(embeddingsJsonPath);
|
List<JToken>? responseContentTokens = [.. responseContentJson.SelectTokens(embeddingsJsonPath)];
|
||||||
if (responseContentTokens is null)
|
if (responseContentTokens is null)
|
||||||
{
|
{
|
||||||
_logger.LogError("Unable to select tokens using JSONPath {embeddingsJsonPath} for string: {responseContent}.", [embeddingsJsonPath, responseContent]);
|
_logger.LogError("Unable to select tokens using JSONPath {embeddingsJsonPath} for string: {responseContent}.", [embeddingsJsonPath, responseContent]);
|
||||||
throw new JSONPathSelectionException(embeddingsJsonPath, responseContent);
|
throw new JSONPathSelectionException(embeddingsJsonPath, responseContent);
|
||||||
}
|
}
|
||||||
return [.. responseContentTokens.Values<float>()];
|
return [.. responseContentTokens.Select(token => token.ToObject<float[]>() ?? throw new Exception("Unable to cast embeddings response to float[]"))];
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -52,9 +52,69 @@ public class Datapoint
|
|||||||
return embeddings;
|
return embeddings;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// <summary>
/// Resolves an embedding for every string in <paramref name="content"/> under every model in
/// <paramref name="models"/>. Entries already present in <paramref name="embeddingCache"/> are reused;
/// the remaining ones are generated in a single batch per model.
/// </summary>
/// <param name="content">The strings to embed. Duplicates are requested from the provider only once.</param>
/// <param name="models">The model identifiers to produce embeddings for.</param>
/// <param name="aIProvider">Provider used to generate the embeddings that are not cached.</param>
/// <param name="embeddingCache">Cache keyed by content string, then by model; read here and filled by the batch generator.</param>
/// <returns>A dictionary keyed by model, then by content string, holding the embedding vectors.</returns>
/// <exception cref="Exception">Thrown when the number of generated embeddings differs from the number requested.</exception>
public static Dictionary<string, Dictionary<string, float[]>> GetEmbeddings(string[] content, List<string> models, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{
    Dictionary<string, Dictionary<string, float[]>> embeddings = [];
    foreach (string model in models)
    {
        Dictionary<string, float[]> embeddingsForModel = [];
        embeddings[model] = embeddingsForModel;

        // The list keeps request order (results are matched back by index);
        // the set gives O(1) duplicate detection instead of List.Contains.
        List<string> toBeGenerated = [];
        HashSet<string> queued = [];
        foreach (string value in content)
        {
            if (embeddingCache.TryGetValue(value, out var embeddingCacheForContent)
                && embeddingCacheForContent is not null
                && embeddingCacheForContent.TryGetValue(model, out float[]? embedding)
                && embedding is not null)
            {
                embeddingsForModel[value] = embedding;
                continue;
            }

            if (queued.Add(value))
            {
                toBeGenerated.Add(value);
            }
        }

        // Everything was served from the cache: do not call the provider with an empty batch.
        if (toBeGenerated.Count == 0)
        {
            continue;
        }

        // Materialize once so the count check and the indexed reads do not re-enumerate the result.
        float[][] generatedEmbeddings = [.. GenerateEmbeddings([.. toBeGenerated], model, aIProvider, embeddingCache)];
        if (generatedEmbeddings.Length != toBeGenerated.Count)
        {
            throw new Exception("Requested embeddings count and generated embeddings count mismatched!");
        }

        for (int i = 0; i < toBeGenerated.Count; i++)
        {
            embeddingsForModel[toBeGenerated[i]] = generatedEmbeddings[i];
        }
    }
    return embeddings;
}
|
||||||
|
|
||||||
|
/// <summary>
/// Generates embeddings for a batch of strings with one provider call and stores each result in
/// <paramref name="embeddingCache"/> under its content string and <paramref name="model"/>.
/// </summary>
/// <param name="content">The strings to embed; the i-th result is stored for the i-th string.</param>
/// <param name="model">The model identifier passed to the provider and used as the inner cache key.</param>
/// <param name="aIProvider">Provider that performs the embedding generation.</param>
/// <param name="embeddingCache">Cache keyed by content string, then by model; updated with every generated embedding.</param>
/// <returns>The generated embeddings, in the same order as <paramref name="content"/>.</returns>
/// <exception cref="Exception">Thrown when the provider returns a different number of embeddings than requested.</exception>
public static IEnumerable<float[]> GenerateEmbeddings(string[] content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
{
    // Materialize once: the result is counted, indexed per entry and returned to the caller,
    // so a lazily evaluated sequence would otherwise be enumerated many times.
    float[][] embeddings = [.. aIProvider.GenerateEmbeddings(model, content)];
    if (embeddings.Length != content.Length)
    {
        throw new Exception("Resulting embeddings count does not match up with request count");
    }

    for (int i = 0; i < content.Length; i++)
    {
        string key = content[i];

        // One lookup instead of ContainsKey followed by two indexer reads.
        if (!embeddingCache.TryGetValue(key, out var embeddingsByModel) || embeddingsByModel is null)
        {
            embeddingsByModel = new Dictionary<string, float[]>();
            embeddingCache[key] = embeddingsByModel;
        }
        embeddingsByModel[model] = embeddings[i];
    }
    return embeddings;
}
|
||||||
|
|
||||||
|
|
||||||
public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
|
public static float[] GenerateEmbeddings(string content, string model, AIProvider aIProvider, EnumerableLruCache<string, Dictionary<string, float[]>> embeddingCache)
|
||||||
{
|
{
|
||||||
float[] embeddings = aIProvider.GenerateEmbeddings(model, [content]);
|
float[] embeddings = aIProvider.GenerateEmbeddings(model, content);
|
||||||
if (!embeddingCache.ContainsKey(content))
|
if (!embeddingCache.ContainsKey(content))
|
||||||
{
|
{
|
||||||
embeddingCache[content] = [];
|
embeddingCache[content] = [];
|
||||||
|
|||||||
@@ -74,6 +74,12 @@ public class SearchdomainHelper(ILogger<SearchdomainHelper> logger, DatabaseHelp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
foreach (var toBeCachedKV in toBeCached)
|
||||||
|
{
|
||||||
|
string model = toBeCachedKV.Key;
|
||||||
|
List<string> uniqueStrings = [.. toBeCachedKV.Value.Distinct()];
|
||||||
|
Datapoint.GetEmbeddings([.. uniqueStrings], [model], aIProvider, embeddingCache);
|
||||||
|
}
|
||||||
ConcurrentQueue<Entity> retVal = [];
|
ConcurrentQueue<Entity> retVal = [];
|
||||||
ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool
|
ParallelOptions parallelOptions = new() { MaxDegreeOfParallelism = 16 }; // <-- This is needed! Otherwise if we try to index 100+ entities at once, it spawns 100 threads, exploding the SQL pool
|
||||||
Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity =>
|
Parallel.ForEach(jsonEntities, parallelOptions, jSONEntity =>
|
||||||
|
|||||||
Reference in New Issue
Block a user