diff --git a/apps/llm/function.py b/apps/llm/function.py
index 2f231b1f339e6e770a1a9dc54727c4457185ae99..4aac7e1b8bd17f0dd820896a14d180b511bd63b6 100644
--- a/apps/llm/function.py
+++ b/apps/llm/function.py
@@ -10,7 +10,8 @@ from typing import Any
 from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from jsonschema import Draft7Validator
+from jsonschema import validate
 
 from apps.common.config import Config
 from apps.constants import JSON_GEN_MAX_TRIAL, REASONING_END_TOKEN
 from apps.llm.prompt import JSON_GEN_BASIC
@@ -233,6 +234,55 @@ class FunctionLLM:
 
 class JsonGenerator:
     """JSON生成器"""
+
+    @staticmethod
+    async def _parse_result_by_stack(result: str, schema: dict[str, Any]) -> "dict[str, Any] | None":
+        """
+        Extract the first JSON object embedded in ``result`` that satisfies ``schema``.
+
+        The widest ``{...}`` span is tried first. If it fails to parse or validate,
+        each balanced outermost ``{...}`` span found by bracket matching is tried in
+        order of appearance.
+
+        :param result: raw model output that may wrap JSON in free text
+        :param schema: JSON Schema the extracted object must validate against
+        :return: the parsed object, or None if no candidate parses and validates
+        """
+        candidates = []
+        left_index = result.find("{")
+        right_index = result.rfind("}")
+        if left_index != -1 and left_index < right_index:
+            candidates.append(result[left_index:right_index + 1])
+
+        # Closing bracket -> the opening bracket it must pair with.
+        bracket_map = {")": "(", "]": "[", "}": "{"}
+        openers = set(bracket_map.values())
+        stack = []
+        for i, char in enumerate(result):
+            if char in openers:
+                stack.append((char, i))
+            elif char in bracket_map and stack:
+                top_char, top_index = stack[-1]
+                if top_char != bracket_map[char]:
+                    # Mismatched pair: discard everything opened so far and start over.
+                    stack.clear()
+                    continue
+                stack.pop()
+                # An emptied stack on "}" means an outermost balanced object just closed.
+                if not stack and char == "}":
+                    candidates.append(result[top_index:i + 1])
+
+        # dict.fromkeys drops repeated candidates while keeping first-seen order,
+        # so a span that already failed is never parsed twice.
+        for json_str in dict.fromkeys(candidates):
+            try:
+                parsed = json.loads(json_str)
+                validate(instance=parsed, schema=schema)
+            except Exception as e:  # best effort: a bad candidate just means "try the next"
+                logger.error("[JsonGenerator] 解析结果失败: %s", e)
+            else:
+                return parsed
+        return None
 
     def __init__(self, query: str, conversation: list[dict[str, str]], schema: dict[str, Any]) -> None:
         """初始化JSON生成器"""
diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py
index 0b495d034586e0bc24d8469935d33a2914a95cc6..e7004a30c0602c073f793c269f172c0d8a8863f2 100644
--- a/apps/llm/patterns/rewrite.py
+++ b/apps/llm/patterns/rewrite.py
@@ -201,7 +201,9 @@ class QuestionRewrite(CorePattern):
             result += chunk
         self.input_tokens = llm.input_tokens
         self.output_tokens = llm.output_tokens
-
+        tmp_js = await JsonGenerator._parse_result_by_stack(result, QuestionRewriteResult.model_json_schema())
+        if tmp_js is not None:
+            return tmp_js['question']
         messages += [{"role": "assistant", "content": result}]
         json_gen = JsonGenerator(
             query="根据给定的背景信息,生成预测问题",
diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py
index 07760953b38784807ebbb479af0dd8aaf2d2fc7a..103ec60daaccbdaa401e746360354b73341d8d08 100644
--- a/apps/scheduler/mcp_agent/base.py
+++ b/apps/scheduler/mcp_agent/base.py
@@ -30,63 +30,10 @@ class MCPBase:
 
         return result
 
-    @staticmethod
-    async def _parse_result_by_stack(result: str, schema: dict[str, Any]) -> str:
-        """解析推理结果"""
-        left_index = result.find('{')
-        right_index = result.rfind('}')
-        if left_index != -1 and right_index != -1 and left_index < right_index:
-            try:
-                tmp_js = json.loads(result[left_index:right_index + 1])
-                validate(instance=tmp_js, schema=schema)
-                return tmp_js
-            except Exception as e:
-                logger.error("[McpBase] 解析结果失败: %s", e)
-        stack = []
-        json_candidates = []
-        # 定义括号匹配关系
-        bracket_map = {')': '(', ']': '[', '}': '{'}
-
-        for i, char in enumerate(result):
-            # 遇到左括号则入栈
-            if char in bracket_map.values():
-                stack.append((char, i))
-            # 遇到右括号且栈不为空时检查匹配
-            elif char in bracket_map.keys() and stack:
-                if not stack:
-                    continue
-                top_char, top_index = stack[-1]
-                # 检查是否匹配当前右括号
-                if top_char == bracket_map[char]:
-                    stack.pop()
-                    # 当栈为空且当前是右花括号时,认为找到一个完整JSON
-                    if not stack and char == '}':
-                        json_str = result[top_index:i+1]
-                        json_candidates.append(json_str)
-                else:
-                    # 如果不匹配,清空栈
-                    stack.clear()
-        # 移除重复项并保持顺序
-        seen = set()
-        unique_jsons = []
-        for json_str in json_candidates[::]:
-            if json_str not in seen:
-                seen.add(json_str)
-                unique_jsons.append(json_str)
-
-        for json_str in unique_jsons:
-            try:
-                tmp_js = json.loads(json_str)
-                validate(instance=tmp_js, schema=schema)
-                return tmp_js
-            except Exception as e:
-                logger.error("[McpBase] 解析结果失败: %s", e)
-        return None
-
     @staticmethod
     async def _parse_result(result: str, schema: dict[str, Any]) -> str:
         """解析推理结果"""
-        json_result = await MCPBase._parse_result_by_stack(result, schema)
+        json_result = await JsonGenerator._parse_result_by_stack(result, schema)
         if json_result is not None:
             return json_result
         json_generator = JsonGenerator(