fix(backend): extract JSON from markdown code blocks in LLM responses
The LLM (Qwen3.5 via OpenRouter) returns JSON wrapped in markdown code blocks: ```json ["project manager", "limits", ...] ``` But the code was trying to parse this directly with json.loads(), causing: - QueryDecomposer to return empty keywords - RelevanceFilter to fail with "Expecting value: line 1 column 1" Changes: - Added _extract_json_from_markdown() helper function to both modules - Strips markdown code block markers (```json and ```) before JSON parsing - Added unit tests for markdown code block handling Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
This commit is contained in:
parent
675b1d573b
commit
33b960f786
|
|
@ -8,12 +8,23 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_from_markdown(response: str) -> str:
|
||||||
|
if not isinstance(response, str):
|
||||||
|
return str(response)
|
||||||
|
pattern = r"```(?:json)?\s*\n?(.*?)\n?```"
|
||||||
|
match = re.search(pattern, response, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
class QueryDecomposer:
|
class QueryDecomposer:
|
||||||
"""Decompose a natural language question into a list of keywords.
|
"""Decompose a natural language question into a list of keywords.
|
||||||
|
|
||||||
|
|
@ -51,6 +62,8 @@ class QueryDecomposer:
|
||||||
if not isinstance(response, str):
|
if not isinstance(response, str):
|
||||||
response = str(response)
|
response = str(response)
|
||||||
|
|
||||||
|
response = _extract_json_from_markdown(response)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(response)
|
data = json.loads(response)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,23 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import List, Tuple, Dict
|
from typing import List, Tuple, Dict
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_from_markdown(response: str) -> str:
|
||||||
|
if not isinstance(response, str):
|
||||||
|
return str(response)
|
||||||
|
pattern = r"```(?:json)?\s*\n?(.*?)\n?```"
|
||||||
|
match = re.search(pattern, response, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
class RelevanceFilter:
|
class RelevanceFilter:
|
||||||
"""RelevanceFilter batches chunk texts to an LLM and selects those with
|
"""RelevanceFilter batches chunk texts to an LLM and selects those with
|
||||||
relevance scores above a threshold.
|
relevance scores above a threshold.
|
||||||
|
|
@ -43,6 +54,7 @@ class RelevanceFilter:
|
||||||
|
|
||||||
scores: List[float] = []
|
scores: List[float] = []
|
||||||
try:
|
try:
|
||||||
|
response = _extract_json_from_markdown(response)
|
||||||
parsed = json.loads(response)
|
parsed = json.loads(response)
|
||||||
if not isinstance(parsed, list):
|
if not isinstance(parsed, list):
|
||||||
return []
|
return []
|
||||||
|
|
|
||||||
|
|
@ -55,3 +55,17 @@ async def test_decompose_mixed_types_coerced_to_strings():
|
||||||
decomposer = QueryDecomposer(llm)
|
decomposer = QueryDecomposer(llm)
|
||||||
result = await decomposer.decompose("Question?")
|
result = await decomposer.decompose("Question?")
|
||||||
assert result == ["a", "2", "None"]
|
assert result == ["a", "2", "None"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_decompose_json_in_markdown_code_block():
    """A keyword list wrapped in a ``json``-tagged fence is decomposed."""
    fenced_reply = '```json\n["project", "manager", "limits"]\n```'
    decomposer = QueryDecomposer(MockLLMClient(fenced_reply))

    keywords = await decomposer.decompose("What are the limits?")

    assert keywords == ["project", "manager", "limits"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_decompose_json_in_plain_code_block():
    """A keyword list wrapped in an untagged fence is decomposed."""
    fenced_reply = '```\n["alpha", "beta"]\n```'
    decomposer = QueryDecomposer(MockLLMClient(fenced_reply))

    keywords = await decomposer.decompose("Keywords?")

    assert keywords == ["alpha", "beta"]
|
||||||
|
|
|
||||||
|
|
@ -66,3 +66,15 @@ async def test_filter_all_outside_threshold():
|
||||||
rf = RelevanceFilter(llm)
|
rf = RelevanceFilter(llm)
|
||||||
result = await rf.filter("Question", chunks, threshold=5.0)
|
result = await rf.filter("Question", chunks, threshold=5.0)
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_filter_json_in_markdown_code_block():
    """Fenced scores 8.0/3.0/9.0 at threshold 7.0 keep the first and third chunk."""
    chunks = _make_chunks()
    fenced_scores = "```json\n[8.0, 3.0, 9.0]\n```"
    llm = MagicMock()
    llm.complete = AsyncMock(return_value=fenced_scores)

    kept = await RelevanceFilter(llm).filter("Question", chunks, threshold=7.0)

    assert kept == [chunks[0], chunks[2]]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue