diff --git a/backend/app/core/database.py b/backend/app/core/database.py index 1a19157..bb69071 100644 --- a/backend/app/core/database.py +++ b/backend/app/core/database.py @@ -14,6 +14,12 @@ class _EmbeddingFunctionWrapper: return "custom_embedding_wrapper" def __call__(self, input): + return self._embed(input) + + def embed_query(self, input): + return self._embed(input) + + def _embed(self, texts): from app.services.embedding_client import EmbeddingClient import asyncio from concurrent.futures import ThreadPoolExecutor @@ -28,7 +34,7 @@ class _EmbeddingFunctionWrapper: loop.close() with ThreadPoolExecutor(max_workers=1) as executor: - return executor.submit(_run_in_thread, input).result() + return executor.submit(_run_in_thread, texts).result() def get_embedding_function_settings(settings):