"""ChromaDB client helpers and a synchronous embedding-function adapter."""

import asyncio
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Callable, Optional

import chromadb

from app.core.config import get_settings


class _EmbeddingFunctionWrapper:
    """Synchronous callable that adapts the async ``EmbeddingClient``.

    Instances are meant to be passed as the ``embedding_function`` of a
    ChromaDB collection. Each call builds a fresh ``EmbeddingClient`` from
    the stored settings and blocks until ``client.embed(...)`` completes.
    """

    def __init__(self, settings):
        """Store the settings used to construct the embedding client."""
        self.settings = settings

    def __call__(self, input):  # noqa: A002
        """Embed ``input`` and return whatever ``EmbeddingClient.embed`` yields.

        The parameter name shadows the ``input`` builtin on purpose:
        presumably ChromaDB's embedding-function protocol expects exactly
        this name, so do not rename it without verifying against ChromaDB.

        Any exception raised while embedding propagates to the caller.
        """
        # Project-local import kept lazy so importing this module does not
        # pull in the embedding client.
        from app.services.embedding_client import EmbeddingClient

        def _embed_in_worker(texts):
            client = EmbeddingClient(self.settings)

            async def _embed():
                # Awaiting inside a coroutine lets asyncio.run() drive the
                # call regardless of which awaitable type embed() returns.
                return await client.embed(texts)

            # asyncio.run() creates the loop, runs the coroutine, cancels
            # leftover tasks, shuts down async generators and closes the
            # loop, which the hand-rolled loop management did not do.
            return asyncio.run(_embed())

        # asyncio.run() cannot be called from a thread that already has a
        # running event loop, so the work is done on a dedicated worker
        # thread and this thread simply blocks on the result.
        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(_embed_in_worker, input).result()


def get_embedding_function_settings(settings) -> _EmbeddingFunctionWrapper:
    """Return a synchronous wrapper suitable for embedding functions in ChromaDB.

    Args:
        settings: Settings object forwarded to ``EmbeddingClient`` on every
            embedding call.

    Returns:
        A callable wrapper bound to ``settings``.
    """
    # Constructing the wrapper only stores ``settings``; there is nothing to
    # catch here, and swallowing errors into ``None`` would hide any future
    # failure from the caller.
    return _EmbeddingFunctionWrapper(settings)


def get_chroma_client() -> chromadb.Client:
    """Create a persistent ChromaDB client rooted at the configured path.

    The storage location is read from ``settings.chroma_db_path``.
    """
    settings = get_settings()
    persist_dir = Path(settings.chroma_db_path)
    # NOTE(review): the diff this module was reconstructed from elides one
    # line at this position (between two hunks). It is presumably the
    # directory creation below; confirm against the original file.
    persist_dir.mkdir(parents=True, exist_ok=True)
    return chromadb.PersistentClient(path=str(persist_dir))


def get_or_create_collection(
    client: chromadb.Client,
    name: str,
    embedding_function: Optional[Callable] = None,
):
    """Fetch the collection called ``name``, creating it if necessary.

    Args:
        client: ChromaDB client used to look up or create the collection.
        name: Collection name.
        embedding_function: Optional callable used by the collection to
            embed documents. When ``None`` the keyword is not passed at
            all, so the client's own default applies.

    Returns:
        The collection object returned by the client.
    """
    collection_kwargs = {"name": name}
    if embedding_function is not None:
        collection_kwargs["embedding_function"] = embedding_function
    return client.get_or_create_collection(**collection_kwargs)