# legco_ai_assistant/backend/app/core/database.py (54 lines, 1.7 KiB, Python)

from pathlib import Path
import chromadb
from typing import Callable, Optional
from app.core.config import get_settings
class _EmbeddingFunctionWrapper:
    """Synchronous embedding-function adapter for ChromaDB.

    ChromaDB calls embedding functions synchronously, but
    ``EmbeddingClient.embed`` must be awaited. Each call therefore runs the
    embedding on a fresh event loop inside a short-lived worker thread. Using a
    separate thread means the call also works when the caller's thread already
    has a running event loop (e.g. inside an async request handler). In that
    case, driving a second loop on the same thread would raise ``RuntimeError``.
    """

    def __init__(self, settings):
        # Application settings, forwarded to a new EmbeddingClient on every call.
        self.settings = settings

    def name(self) -> str:
        """Return the identifier reported to ChromaDB for this embedding function."""
        return "custom_embedding_wrapper"

    def __call__(self, input):
        """Embed ``input`` (the texts ChromaDB hands over) and return the result.

        The parameter keeps the name ``input`` to match ChromaDB's
        embedding-function protocol, which may pass it by keyword.

        Returns:
            Whatever ``EmbeddingClient.embed`` resolves to. Presumably this is
            one vector per input text; confirm against ``EmbeddingClient``.

        Raises:
            Any exception raised while creating the client or embedding. It is
            re-raised in the calling thread by ``Future.result()``.
        """
        # Imported here, not at module level, so that importing this module
        # does not import the embedding service.
        from app.services.embedding_client import EmbeddingClient
        import asyncio
        from concurrent.futures import ThreadPoolExecutor

        async def _embed(texts):
            # Await inside a real coroutine so asyncio.run accepts any awaitable
            # that embed() returns.
            client = EmbeddingClient(self.settings)
            return await client.embed(texts)

        def _run_in_thread(texts):
            # asyncio.run creates the loop and cancels leftover tasks. It also
            # finalizes async generators and closes the loop. Afterwards it
            # unregisters the loop, so no closed loop stays attached to the
            # worker thread.
            return asyncio.run(_embed(texts))

        with ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(_run_in_thread, input).result()
def get_embedding_function_settings(settings):
    """Return a synchronous ChromaDB embedding function bound to ``settings``.

    Building the wrapper only stores ``settings``. The embedding client is
    imported and created later, inside ``_EmbeddingFunctionWrapper.__call__``.
    So this call has no import-time side effects and cannot fail. Import or
    configuration problems surface on the first embedding call, not as a
    ``None`` return.

    Args:
        settings: Application settings forwarded to the embedding client.

    Returns:
        A new ``_EmbeddingFunctionWrapper``; never ``None``. Callers that still
        check for ``None`` keep working, because that branch is simply never
        taken.
    """
    return _EmbeddingFunctionWrapper(settings)
def get_chroma_client() -> "chromadb.api.ClientAPI":
    """Open a persistent ChromaDB client at ``settings.chroma_db_path``.

    The directory, including any missing parents, is created first. This
    means a fresh checkout works without manual setup.

    Returns:
        The client produced by ``chromadb.PersistentClient`` for that directory.
    """
    # NOTE: ``chromadb.Client`` is a factory function, not a type, so it is not
    # a valid return annotation. Client objects implement
    # ``chromadb.api.ClientAPI``. The annotation is a string, so it is never
    # evaluated at runtime.
    persist_dir = Path(get_settings().chroma_db_path)
    persist_dir.mkdir(parents=True, exist_ok=True)
    return chromadb.PersistentClient(path=str(persist_dir))
def get_or_create_collection(
    client: "chromadb.api.ClientAPI",
    name: str,
    embedding_function: Optional[Callable] = None,
):
    """Fetch the collection ``name`` from ``client``, creating it if it is absent.

    Args:
        client: A ChromaDB client, such as the one returned by
            ``get_chroma_client()``.
        name: Collection name.
        embedding_function: Callable the collection uses to embed documents.
            When it is ``None``, the keyword is omitted entirely rather than
            forwarded as ``None``. ChromaDB treats an omitted argument, which
            falls back to its default embedding function, differently from an
            explicit ``None``.

    Returns:
        Whatever ``client.get_or_create_collection`` returns.
    """
    # ``chromadb.Client`` is a factory function, not a type. Client objects
    # implement ``chromadb.api.ClientAPI``, hence the string annotation above.
    kwargs = {"name": name}
    if embedding_function is not None:
        kwargs["embedding_function"] = embedding_function
    return client.get_or_create_collection(**kwargs)