60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
from pathlib import Path
|
|
|
|
import chromadb
|
|
from typing import Callable, Optional
|
|
|
|
from app.core.config import get_settings
|
|
|
|
|
|
class _EmbeddingFunctionWrapper:
|
|
def __init__(self, settings):
|
|
self.settings = settings
|
|
|
|
def name(self) -> str:
|
|
return "custom_embedding_wrapper"
|
|
|
|
def __call__(self, input):
|
|
return self._embed(input)
|
|
|
|
def embed_query(self, input):
|
|
return self._embed(input)
|
|
|
|
def _embed(self, texts):
|
|
from app.services.embedding_client import EmbeddingClient
|
|
import asyncio
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
def _run_in_thread(texts):
|
|
client = EmbeddingClient(self.settings)
|
|
loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(loop)
|
|
try:
|
|
return loop.run_until_complete(client.embed(texts))
|
|
finally:
|
|
loop.close()
|
|
|
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
|
return executor.submit(_run_in_thread, texts).result()
|
|
|
|
|
|
def get_embedding_function_settings(settings) -> "_EmbeddingFunctionWrapper":
    """Return a synchronous wrapper suitable for embedding functions in ChromaDB.

    Construction only stores ``settings``. The embedding client is imported
    and created lazily inside the wrapper each time it embeds, so calling this
    function has no import-time side effects and cannot fail on its own.

    The previous ``try/except Exception: return None`` guarded a constructor
    that cannot raise, so it was removed. Callers that still check for
    ``None`` keep working; they simply never see it.
    """
    return _EmbeddingFunctionWrapper(settings)
|
|
|
|
|
|
def get_chroma_client() -> "chromadb.api.ClientAPI":
    """Return a persistent ChromaDB client rooted at ``settings.chroma_db_path``.

    The directory (including missing parents) is created first if absent.
    """
    # ``chromadb.Client`` is a factory function, not a type; ``ClientAPI`` is
    # the interface PersistentClient returns. The string annotation keeps
    # runtime behavior unchanged.
    settings = get_settings()
    persist_dir = Path(settings.chroma_db_path)
    persist_dir.mkdir(parents=True, exist_ok=True)
    return chromadb.PersistentClient(path=str(persist_dir))
|
|
|
|
|
|
def get_or_create_collection(
    client: "chromadb.api.ClientAPI",
    name: str,
    embedding_function: Optional[Callable] = None,
):
    """Return collection ``name`` from ``client``, creating it if missing.

    Args:
        client: ChromaDB client (annotated as ``ClientAPI``; ``chromadb.Client``
            is a factory function, not a type).
        name: Collection name.
        embedding_function: Optional embedding callable. When ``None`` the
            keyword is omitted entirely rather than passed as ``None``, so
            ChromaDB's own default embedding function applies.

    Returns:
        Whatever ``client.get_or_create_collection`` returns.
    """
    kwargs = {"name": name}
    if embedding_function is not None:
        kwargs["embedding_function"] = embedding_function
    return client.get_or_create_collection(**kwargs)
|