154 lines
4.9 KiB
Python
154 lines
4.9 KiB
Python
import logging
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
|
|
from app.core.dependencies import get_prompt_service
|
|
from app.models.prompts import (
|
|
ProfileListResponse,
|
|
ProfileItem,
|
|
PromptSetResponse,
|
|
PromptUpdateRequest,
|
|
PromptBatchUpdateRequest,
|
|
ResetToDefaultsRequest,
|
|
ProfileExportResponse,
|
|
AllProfilesExportResponse,
|
|
ProfileImportRequest,
|
|
ProfileImportResponse,
|
|
)
|
|
|
|
# Module-level logger named after this module. None of the handlers visible
# in this file currently log through it.
logger = logging.getLogger(__name__)
|
|
|
|
# Router that owns every endpoint declared below; all routes are served under
# the /api/v1/prompts prefix and tagged "prompts".
router = APIRouter(prefix="/api/v1/prompts", tags=["prompts"])
|
|
|
|
# Profile names accepted by the {name} path parameter of the routes below.
_VALID_NAMES = {"A", "B", "C"}

# Step identifiers accepted by the {step} path parameter, by the reset request
# body, and required (all of them, no extras) by a profile import.
_VALID_STEPS = {
    # decompose steps
    "decompose",
    "decompose_format",
    # generate steps
    "generate",
    "generate_per_subq",
    # filter steps
    "filter",
    "filter_intro",
    "filter_section",
    "filter_outro",
}
|
|
|
|
|
|
def _ensure_valid_name(name: str) -> None:
    """Reject profile names that are not in ``_VALID_NAMES``.

    Args:
        name: Profile name taken from the request path.

    Raises:
        HTTPException: 400 when ``name`` is not one of the known profiles.
    """
    if name in _VALID_NAMES:
        return
    raise HTTPException(
        status_code=400,
        detail=f"Invalid profile name '{name}'. Must be one of A, B, C.",
    )
|
|
|
|
|
|
def _ensure_valid_step(step: str) -> None:
    """Reject step identifiers that are not in ``_VALID_STEPS``.

    Args:
        step: Step identifier taken from the request path or body.

    Raises:
        HTTPException: 400 when ``step`` is not a known step. The detail
            message lists every accepted step.
    """
    if step in _VALID_STEPS:
        return
    # Build the list from _VALID_STEPS so the message cannot drift from the
    # set that is actually enforced (the previous hard-coded text named only
    # three of the accepted steps). Sorted for a stable, readable message.
    allowed = ", ".join(sorted(_VALID_STEPS))
    raise HTTPException(
        status_code=400,
        detail=f"Invalid step '{step}'. Must be one of {allowed}.",
    )
|
|
|
|
|
|
# Format identifier stamped into export payloads; import_profile rejects any
# request whose `format` field differs from this value.
_EXPORT_FORMAT = "legco-reranker-profile/v1"
|
|
|
|
|
|
@router.get("/export/all", response_model=AllProfilesExportResponse)
def export_all_profiles():
    """Export the prompts of every listed profile in a single payload.

    Returns:
        AllProfilesExportResponse carrying the export format identifier, a UTC
        timestamp, the active profile name, and a mapping of profile name to
        ``{"prompts": ...}`` for each profile the service lists.
    """
    service = get_prompt_service()
    listed = service.list_profiles()
    active = service.get_active_profile_name()

    # One entry per listed profile, keyed by its "name" field.
    bundle: dict[str, dict] = {
        entry["name"]: {"prompts": service.get_profile_prompts(entry["name"])}
        for entry in listed
    }

    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    return AllProfilesExportResponse(
        format=_EXPORT_FORMAT,
        exported_at=stamp,
        active_profile=active,
        profiles=bundle,
    )
|
|
|
|
|
|
@router.get("/profiles/{name}/export", response_model=ProfileExportResponse)
def export_profile(name: str):
    """Export one profile's prompts as a downloadable JSON attachment.

    Args:
        name: Profile to export; must pass ``_ensure_valid_name``.

    Returns:
        JSONResponse whose body is the dumped ProfileExportResponse and whose
        Content-Disposition header suggests ``legco-profile-<name>.json``.

    Raises:
        HTTPException: 400 when ``name`` is not a valid profile name.
    """
    _ensure_valid_name(name)
    stored_prompts = get_prompt_service().get_profile_prompts(name)

    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    payload = ProfileExportResponse(
        format=_EXPORT_FORMAT,
        profile_name=name,
        exported_at=stamp,
        prompts=stored_prompts,
    )

    # A response object is returned directly so the download header can be set.
    download_headers = {"Content-Disposition": f'attachment; filename="legco-profile-{name}.json"'}
    return JSONResponse(content=payload.model_dump(), headers=download_headers)
|
|
|
|
|
|
@router.post("/profiles/{name}/import", response_model=ProfileImportResponse)
def import_profile(name: str, body: ProfileImportRequest):
    """Replace all prompts of a profile with the contents of an export payload.

    The payload must carry the expected format identifier and exactly the set
    of steps in ``_VALID_STEPS``. Missing steps are reported before unknown
    ones.

    Args:
        name: Target profile; must pass ``_ensure_valid_name``.
        body: Import request providing ``format``, ``prompts`` and
            ``profile_name``.

    Returns:
        ProfileImportResponse with the target profile, the number of imported
        steps, and the profile name recorded in the payload.

    Raises:
        HTTPException: 400 for an invalid profile name, an unsupported format,
            missing steps, or unknown steps.
    """
    _ensure_valid_name(name)

    if body.format != _EXPORT_FORMAT:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format '{body.format}'. Expected '{_EXPORT_FORMAT}'.",
        )

    supplied_steps = set(body.prompts.keys())

    absent_steps = _VALID_STEPS - supplied_steps
    if absent_steps:
        absent_listing = ", ".join(sorted(absent_steps))
        raise HTTPException(status_code=400, detail=f"Missing required steps: {absent_listing}")

    extra_steps = supplied_steps - _VALID_STEPS
    if extra_steps:
        extra_listing = ", ".join(sorted(extra_steps))
        raise HTTPException(status_code=400, detail=f"Unknown steps: {extra_listing}")

    get_prompt_service().update_all_prompts(name, body.prompts)

    return ProfileImportResponse(
        status="ok",
        profile=name,
        imported_steps=len(body.prompts),
        source_profile=body.profile_name,
    )
|
|
|
|
|
|
@router.get("/profiles", response_model=ProfileListResponse)
def list_profiles():
    """Return every profile reported by the prompt service.

    Returns:
        ProfileListResponse wrapping one ProfileItem per entry returned by the
        service's ``list_profiles``; each entry is expanded as keyword
        arguments.
    """
    raw_entries = get_prompt_service().list_profiles()
    items = [ProfileItem(**entry) for entry in raw_entries]
    return ProfileListResponse(profiles=items)
|
|
|
|
|
|
@router.get("/profiles/{name}", response_model=PromptSetResponse)
def get_profile_prompts(name: str):
    """Return the prompts stored for a single profile.

    Args:
        name: Profile to read; must pass ``_ensure_valid_name``.

    Returns:
        PromptSetResponse carrying the profile name and its prompts.

    Raises:
        HTTPException: 400 when ``name`` is not a valid profile name.
    """
    _ensure_valid_name(name)
    stored_prompts = get_prompt_service().get_profile_prompts(name)
    return PromptSetResponse(prompts=stored_prompts, profile_name=name)
|
|
|
|
|
|
@router.put("/profiles/{name}/activate")
def activate_profile(name: str):
    """Mark the given profile as the active one.

    Args:
        name: Profile to activate; must pass ``_ensure_valid_name``.

    Returns:
        Dict with ``status`` set to "ok" and ``active_profile`` set to ``name``.

    Raises:
        HTTPException: 400 when ``name`` is not a valid profile name.
    """
    _ensure_valid_name(name)
    get_prompt_service().activate_profile(name)
    outcome = {"status": "ok", "active_profile": name}
    return outcome
|
|
|
|
|
|
@router.put("/profiles/{name}/all")
def update_all_prompts(name: str, body: PromptBatchUpdateRequest):
    """Overwrite a profile's prompts with the batch supplied in the request.

    Args:
        name: Target profile; must pass ``_ensure_valid_name``.
        body: Batch request whose ``prompts`` are handed to the service as-is.

    Returns:
        Dict with ``status`` set to "ok" and ``profile`` set to ``name``.

    Raises:
        HTTPException: 400 when ``name`` is not a valid profile name.
    """
    _ensure_valid_name(name)
    # NOTE(review): unlike import_profile, the step keys in body.prompts are
    # not checked against _VALID_STEPS here -- confirm the request model or
    # the service validates them.
    get_prompt_service().update_all_prompts(name, body.prompts)
    outcome = {"status": "ok", "profile": name}
    return outcome
|
|
|
|
|
|
@router.put("/profiles/{name}/reset")
def reset_to_defaults(name: str, body: ResetToDefaultsRequest | None = None):
    """Reset one step, or every step, of a profile to its defaults.

    Args:
        name: Target profile; must pass ``_ensure_valid_name``.
        body: Optional request; when present, its ``step`` selects the step to
            reset. With no body or no step, ``step=None`` is passed to the
            service.

    Returns:
        Dict with ``status``, ``profile`` and ``reset_step`` (the step name,
        or "all" when no step was given).

    Raises:
        HTTPException: 400 for an invalid profile name or an invalid step.
    """
    _ensure_valid_name(name)

    if body:
        step = body.step
    else:
        step = None

    # Only a supplied step needs validating; None is forwarded unchanged.
    if step is not None:
        _ensure_valid_step(step)

    get_prompt_service().reset_to_defaults(name, step=step)
    return {"status": "ok", "profile": name, "reset_step": step or "all"}
|
|
|
|
|
|
@router.put("/profiles/{name}/{step}")
def update_prompt(name: str, step: str, body: PromptUpdateRequest):
    """Replace the template of a single step within a profile.

    Args:
        name: Target profile; must pass ``_ensure_valid_name``.
        step: Step to update; must pass ``_ensure_valid_step``.
        body: Request whose ``template`` becomes the new prompt text.

    Returns:
        Dict with ``status`` set to "ok" plus the ``profile`` and ``step``
        that were updated.

    Raises:
        HTTPException: 400 for an invalid profile name or an invalid step.
    """
    # Validate the name first, then the step, before touching the service.
    _ensure_valid_name(name)
    _ensure_valid_step(step)

    get_prompt_service().update_prompt(name, step, body.template)
    outcome = {"status": "ok", "profile": name, "step": step}
    return outcome
|