From d0c26e1048b8e4a134d3d56227b6700ed6ae3a2f Mon Sep 17 00:00:00 2001 From: z30057876 Date: Thu, 21 Aug 2025 16:21:09 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E5=90=88agent=E5=88=86=E6=94=AF?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20&=20=E4=BF=AE=E6=AD=A3=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/function.py | 59 +++++++- apps/llm/patterns/rewrite.py | 64 ++++++--- apps/llm/prompt.py | 2 - apps/models/service.py | 2 - apps/routers/mcp_service.py | 6 +- apps/routers/parameter.py | 3 +- apps/routers/record.py | 3 +- apps/routers/service.py | 11 +- apps/scheduler/executor/agent.py | 219 ++++++++++++++--------------- apps/scheduler/executor/base.py | 4 +- apps/scheduler/executor/flow.py | 41 +++--- apps/scheduler/executor/qa.py | 1 - apps/scheduler/executor/step.py | 2 +- apps/scheduler/mcp_agent/base.py | 42 +----- apps/scheduler/mcp_agent/plan.py | 12 +- apps/scheduler/mcp_agent/prompt.py | 32 +++-- apps/scheduler/pool/mcp/client.py | 5 +- apps/scheduler/pool/mcp/pool.py | 4 +- apps/schemas/parameters.py | 2 + apps/schemas/response_data.py | 7 +- apps/services/task.py | 4 +- 21 files changed, 271 insertions(+), 254 deletions(-) diff --git a/apps/llm/function.py b/apps/llm/function.py index cd000c665..869dbb43a 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -234,6 +234,63 @@ class FunctionLLM: class JsonGenerator: """JSON生成器""" + @staticmethod + async def parse_result_by_stack(result: str, schema: dict[str, Any]) -> str: # noqa: C901, PLR0912 + """解析推理结果""" + validator = Draft7Validator(schema) + left_index = result.find("{") + right_index = result.rfind("}") + if left_index != -1 and right_index != -1 and left_index < right_index: + try: + tmp_js = json.loads(result[left_index:right_index + 1]) + validator.validate(tmp_js) + except Exception: + logger.exception("[JsonGenerator] 解析结果失败") + else: + return tmp_js + stack = [] + json_candidates = [] + # 定义括号匹配关系 + bracket_map = {")": "(", "]": "[", "}": "{"} + + for i, char in enumerate(result): + # 遇到左括号则入栈 + if char in bracket_map.values(): + stack.append((char, i)) + # 遇到右括号且栈不为空时检查匹配 + elif char in bracket_map and stack: + if not stack: + continue + top_char, top_index = stack[-1] + # 检查是否匹配当前右括号 + if top_char == bracket_map[char]: + stack.pop() + # 当栈为空且当前是右花括号时,认为找到一个完整JSON + if not stack and char == "}": + json_str = result[top_index:i+1] + json_candidates.append(json_str) + else: + # 如果不匹配,清空栈 + stack.clear() + # 移除重复项并保持顺序 + seen = set() + unique_jsons = [] + for json_str in json_candidates[::]: + if json_str not in seen: + seen.add(json_str) + unique_jsons.append(json_str) + + for json_str in unique_jsons: + try: + tmp_js = json.loads(json_str) + validator.validate(tmp_js) + except Exception: + logger.exception("[JsonGenerator] 解析结果失败") + else: + return tmp_js + + return "" + def __init__(self, query: str, conversation: list[dict[str, str]], schema: dict[str, Any]) -> None: """初始化JSON生成器""" self._query = query @@ -282,12 +339,10 @@ class JsonGenerator: """生成JSON""" Draft7Validator.check_schema(self._schema) validator = Draft7Validator(self._schema) - logger.info("[JSONGenerator] Schema:%s", self._schema) while self._count < JSON_GEN_MAX_TRIAL: self._count += 1 result = await self._single_trial() - logger.info("[JSONGenerator] 得到:%s", result) try: validator.validate(result) except Exception as err: # noqa: BLE001 diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py index 23d20c7a6..0c3fb2ffd 100644 --- a/apps/llm/patterns/rewrite.py +++ b/apps/llm/patterns/rewrite.py @@ -2,7 +2,10 @@ """问题改写""" import logging +from textwrap import dedent +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field from apps.llm.function import JsonGenerator @@ -13,6 +16,12 @@ from apps.schemas.enum_var import LanguageType from .core import CorePattern logger = logging.getLogger(__name__) +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, +) class QuestionRewriteResult(BaseModel): @@ -31,7 +40,7 @@ class QuestionRewrite(CorePattern): LanguageType.CHINESE: r"You are a helpful assistant.", LanguageType.ENGLISH: r"You are a helpful assistant.", }, { - LanguageType.CHINESE: r""" + LanguageType.CHINESE: dedent(r""" 根据历史对话,推断用户的实际意图并补全用户的提问内容,历史对话被包含在标签中,用户意图被包含在标签中。 @@ -41,9 +50,11 @@ class QuestionRewrite(CorePattern): 3. 补全内容必须精准、恰当,不要编造任何内容。 4. 请输出补全后的问题,不要输出其他内容。 输出格式样例: - {{ - "question": "补全后的问题" - }} + ```json + { + "question": "补全后的问题" + } + ``` @@ -62,24 +73,26 @@ class QuestionRewrite(CorePattern): 详细点? - {{ - "question": "详细说明openEuler操作系统的优势和应用场景" - }} + ```json + { + "question": "详细说明openEuler操作系统的优势和应用场景" + } + ``` - {history} + {{history}} - {question} + {{question}} 现在,请输出补全后的问题: - """, - LanguageType.ENGLISH: r""" + """).strip("\n"), + LanguageType.ENGLISH: dedent(r""" Based on the historical dialogue, infer the user's actual intent and complete the user's question. \ @@ -93,9 +106,11 @@ user's question is already complete enough, directly output the user's question. 3. The completed content must be precise and appropriate; do not fabricate any content. 4. Output only the completed question; do not include any other content. Example output format: - {{ - "question": "The completed question" - }} + ```json + { + "question": "The completed question" + } + ``` @@ -124,20 +139,22 @@ and optimizations for cloud and edge computing. More details? - {{ - "question": "What are the features of openEuler? Please elaborate on its advantages and \ + ```json + { + "question": "What are the features of openEuler? Please elaborate on its advantages and \ application scenarios." - }} + } + ``` - {history} + {{history}} - {question} + {{question}} - """, + """).strip("\n"), } async def generate(self, **kwargs) -> str: # noqa: ANN003 @@ -148,7 +165,9 @@ application scenarios." messages = [ {"role": "system", "content": self.system_prompt[language]}, - {"role": "user", "content": self.user_prompt[language].format(history="", question=question)}, + {"role": "user", "content": _env.from_string( + self.user_prompt[language], + ).render(history="", question=question)}, ] llm = kwargs.get("llm") if not llm: @@ -181,6 +200,9 @@ application scenarios." self.input_tokens = llm.input_tokens self.output_tokens = llm.output_tokens + tmp_js = await JsonGenerator.parse_result_by_stack(result, QuestionRewriteResult.model_json_schema()) + if tmp_js is not None: + return tmp_js["question"] messages += [{"role": "assistant", "content": result}] json_gen = JsonGenerator( query="根据给定的背景信息,生成预测问题", diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py index 7702f6d09..40fc01588 100644 --- a/apps/llm/prompt.py +++ b/apps/llm/prompt.py @@ -48,7 +48,6 @@ JSON_GEN_BASIC = dedent(r""" {% endif %} - {% if not function_call %} # Tools You must call one function to assist with the user query. @@ -67,5 +66,4 @@ JSON_GEN_BASIC = dedent(r""" # Output - {% endif %} """) diff --git a/apps/models/service.py b/apps/models/service.py index 8fe23f846..2341a5a28 100644 --- a/apps/models/service.py +++ b/apps/models/service.py @@ -1,10 +1,8 @@ """插件 数据库表""" -import uuid from datetime import UTC, datetime from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String -from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column from apps.schemas.enum_var import PermissionType diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index db0e64f67..9389587b5 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -198,11 +198,7 @@ async def get_service_detail( name=data.name, description=data.description, overview=config.overview, - data=json.dumps( - config.config.model_dump(by_alias=True, exclude_none=True), - indent=4, - ensure_ascii=False, - ), + data=config.config.model_dump(by_alias=True, exclude_none=True), mcpType=config.mcpType, ) else: diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index a409e12e1..02ac6d637 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import JSONResponse from apps.dependency.user import verify_personal_token, verify_session +from apps.schemas.parameters import Type from apps.schemas.response_data import GetOperaRsp, GetParamsRsp from apps.services.appcenter import AppCenterManager from apps.services.flow import FlowManager @@ -57,7 +58,7 @@ async def get_parameters( @router.get("/operate", response_model=GetOperaRsp) -async def get_operate_parameters(paramType: str) -> JSONResponse: # noqa: N803 +async def get_operate_parameters(paramType: Type) -> JSONResponse: # noqa: N803 """Get parameters for node choice.""" result = await ParameterManager.get_operate_and_bind_type(paramType) return JSONResponse( diff --git a/apps/routers/record.py b/apps/routers/record.py index 9277b5417..9849fc13b 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -10,6 +10,7 @@ from fastapi.responses import JSONResponse from apps.common.security import Security from apps.dependency import verify_personal_token, verify_session +from apps.models.task import ExecutorHistory from apps.schemas.record import ( RecordContent, RecordData, @@ -22,7 +23,6 @@ from apps.schemas.response_data import ( RecordListRsp, ResponseData, ) -from apps.schemas.task import FlowStepHistory from apps.services.conversation import ConversationManager from apps.services.document import DocumentManager from apps.services.record import RecordManager @@ -67,7 +67,6 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path tmp_record = RecordData( id=record.id, - groupId=record_group.id, taskId=record.task_id, conversationId=conversationId, content=record_data, diff --git a/apps/routers/service.py b/apps/routers/service.py index 902919e6d..7e1054140 100644 --- a/apps/routers/service.py +++ b/apps/routers/service.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, Depends, Path, Request, status from fastapi.responses import JSONResponse from apps.dependency.user import verify_personal_token, verify_session -from apps.exceptions import InstancePermissionError, ServiceIDError +from apps.exceptions import InstancePermissionError from apps.schemas.enum_var import SearchType from apps.schemas.request_data import ChangeFavouriteServiceRequest, UpdateServiceRequest from apps.schemas.response_data import ( @@ -135,15 +135,6 @@ async def update_service(request: Request, data: UpdateServiceRequest) -> JSONRe else: try: service_id = await ServiceCenterManager.update_service(request.state.user_sub, data.service_id, data.data) - except ServiceIDError: - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="Service ID错误", - result={}, - ).model_dump(exclude_none=True, by_alias=True), - ) except InstancePermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 285aa954a..2fe3b88ac 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -9,18 +9,15 @@ from mcp.types import TextContent from pydantic import Field from apps.llm.reasoning import ReasoningLLM +from apps.models.mcp import MCPInfo, MCPTools +from apps.models.task import ExecutorHistory from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.host import MCPHost from apps.scheduler.mcp_agent.plan import MCPPlanner from apps.scheduler.pool.mcp.pool import MCPPool -from apps.schemas.enum_var import EventType, FlowStatus, LanguageType, StepStatus -from apps.schemas.mcp import ( - MCPCollection, - MCPTool, - Step, -) +from apps.schemas.enum_var import EventType, ExecutorStatus, LanguageType, StepStatus +from apps.schemas.mcp import Step from apps.schemas.message import FlowParams -from apps.schemas.task import FlowStepHistory from apps.services.appcenter import AppCenterManager from apps.services.mcp_service import MCPServiceManager from apps.services.task import TaskManager @@ -36,13 +33,13 @@ class MCPAgentExecutor(BaseExecutor): servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[]) + mcp_list: list[MCPInfo] = Field(description="MCP服务器列表", default=[]) mcp_pool: MCPPool = Field(description="MCP池", default=MCPPool()) - tools: dict[str, MCPTool] = Field( + tools: dict[str, MCPTools] = Field( description="MCP工具列表,key为tool_id", default={}, ) - tool_list: list[MCPTool] = Field( + tool_list: list[MCPTools] = Field( description="MCP工具列表,包含所有MCP工具", default=[], ) @@ -55,19 +52,18 @@ class MCPAgentExecutor(BaseExecutor): default=ReasoningLLM(), description="推理大模型", ) + step_cnt: int = Field(default=0, description="当前已执行步骤数") - async def update_tokens(self) -> None: - """更新令牌数""" - self.task.tokens.input_tokens = self.resoning_llm.input_tokens - self.task.tokens.output_tokens = self.resoning_llm.output_tokens - await TaskManager.save_task(self.task.id, self.task) + async def init(self) -> None: + """初始化MCP Agent""" + self.planner = MCPPlanner(self.runtime.userInput, self.resoning_llm, self.runtime.language) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" - logger.info("[FlowExecutor] 加载Executor状态") + logger.info("[MCPAgentExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state and self.task.state.flow_status != FlowStatus.INIT: - self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + if self.state and self.state.executorStatus != ExecutorStatus.INIT: + self.context = await TaskManager.get_context_by_task_id(self.task.id) async def load_mcp(self) -> None: """加载MCP服务器列表""" @@ -77,32 +73,35 @@ class MCPAgentExecutor(BaseExecutor): mcp_ids = app.mcp_service for mcp_id in mcp_ids: mcp_service = await MCPServiceManager.get_mcp_service(mcp_id) - if self.task.ids.user_sub not in mcp_service.activated: + if self.task.userSub not in mcp_service.activated: logger.warning( "[MCPAgentExecutor] 用户 %s 未启用MCP %s", - self.task.ids.user_sub, + self.task.userSub, mcp_id, ) continue self.mcp_list.append(mcp_service) - await self.mcp_pool._init_mcp(mcp_id, self.task.ids.user_sub) + await self.mcp_pool.init_mcp(mcp_id, self.task.userSub) for tool in mcp_service.tools: self.tools[tool.id] = tool self.tool_list.extend(mcp_service.tools) - self.tools[FINAL_TOOL_ID] = MCPTool( - id=FINAL_TOOL_ID, name="Final Tool", description="结束流程的工具", mcp_id="", input_schema={}, + self.tools[FINAL_TOOL_ID] = MCPTools( + id=FINAL_TOOL_ID, mcpId="", toolName="Final Tool", description="结束流程的工具", + inputSchema={}, outputSchema={}, ) - self.tool_list.append( - MCPTool(id=FINAL_TOOL_ID, name="Final Tool", description="结束流程的工具", mcp_id="", input_schema={}), + self.tool_list.append(MCPTools( + id=FINAL_TOOL_ID, mcpId="", toolName="Final Tool", description="结束流程的工具", + inputSchema={}, outputSchema={}), ) - async def get_tool_input_param(self, is_first: bool) -> None: + async def get_tool_input_param(self, *, is_first: bool) -> None: + """获取工具输入参数""" if is_first: # 获取第一个输入参数 mcp_tool = self.tools[self.task.state.tool_id] - self.task.state.current_input = await MCPHost._get_first_input_params( - mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task + self.state.currentInput = await MCPHost._get_first_input_params( + mcp_tool, self.runtime.userInput, self.state.stepDescription, self.task, ) else: # 获取后续输入参数 @@ -113,15 +112,15 @@ class MCPAgentExecutor(BaseExecutor): params = {} params_description = "" mcp_tool = self.tools[self.task.state.tool_id] - self.task.state.current_input = await MCPHost._fill_params( + self.state.currentInput = await MCPHost.fill_params( mcp_tool, - self.task.runtime.question, - self.task.state.step_description, - self.task.state.current_input, - self.task.state.error_message, + self.runtime.userInput, + self.state.stepDescription, + self.state.currentInput, + self.state.errorMessage, params, params_description, - self.task.language, + self.runtime.language, ) async def confirm_before_step(self) -> None: @@ -139,7 +138,7 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.flow_status = FlowStatus.WAITING self.task.state.step_status = StepStatus.WAITING self.task.context.append( - FlowStepHistory( + ExecutorHistory( task_id=self.task.id, step_id=self.task.state.step_id, step_name=self.task.state.step_name, @@ -169,8 +168,7 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.step_status = StepStatus.ERROR return except Exception as e: - import traceback - logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误: %s", mcp_tool.name, traceback.format_exc()) + logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误", mcp_tool.name) self.task.state.step_status = StepStatus.ERROR self.task.state.error_message = str(e) return @@ -180,8 +178,11 @@ class MCPAgentExecutor(BaseExecutor): for output in output_params.content: if isinstance(output, TextContent): err += output.text - self.task.state.step_status = StepStatus.ERROR - self.task.state.error_message = err + self.state.stepStatus = StepStatus.ERROR + self.state.errorMessage = { + "err_msg": err, + "data": {}, + } return message = "" for output in output_params.content: @@ -192,59 +193,55 @@ class MCPAgentExecutor(BaseExecutor): } await self.update_tokens() - await self.push_message(EventType.STEP_INPUT, self.task.state.current_input) + await self.push_message(EventType.STEP_INPUT, self.state.currentInput) await self.push_message(EventType.STEP_OUTPUT, output_params) - self.task.context.append( - FlowStepHistory( - task_id=self.task.id, - step_id=self.task.state.step_id, - step_name=self.task.state.step_name, - step_description=self.task.state.step_description, - step_status=StepStatus.SUCCESS, - flow_id=self.task.state.flow_id, - flow_name=self.task.state.flow_name, - flow_status=self.task.state.flow_status, - input_data=self.task.state.current_input, - output_data=output_params, - ) + self.context.append( + ExecutorHistory( + taskId=self.task.id, + stepId=self.state.stepId, + stepName=self.state.stepName, + stepDescription=self.state.stepDescription, + stepStatus=StepStatus.SUCCESS, + executorId=self.state.executorId, + executorName=self.state.executorName, + executorStatus=self.state.executorStatus, + inputData=self.state.currentInput, + outputData=output_params, + ), ) - self.task.state.step_status = StepStatus.SUCCESS + self.state.stepStatus = StepStatus.SUCCESS async def generate_params_with_null(self) -> None: """生成参数补充""" - mcp_tool = self.tools[self.task.state.tool_id] - params_with_null = await MCPPlanner.get_missing_param( + mcp_tool = self.tools[self.state.toolId] + params_with_null = await self.planner.get_missing_param( mcp_tool, - self.task.state.current_input, - self.task.state.error_message, - self.resoning_llm, - self.task.language, + self.state.currentInput, + self.state.errorMessage, ) await self.update_tokens() - error_message = await MCPPlanner.change_err_message_to_description( - error_message=self.task.state.error_message, + error_message = await self.planner.change_err_message_to_description( + error_message=self.state.errorMessage, tool=mcp_tool, input_params=self.task.state.current_input, - reasoning_llm=self.resoning_llm, - language=self.task.language, ) await self.push_message( EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null}, ) await self.push_message(EventType.FLOW_STOP, data={}) - self.task.state.flow_status = FlowStatus.WAITING - self.task.state.step_status = StepStatus.PARAM - self.task.context.append( - FlowStepHistory( - task_id=self.task.id, - step_id=self.task.state.step_id, - step_name=self.task.state.step_name, - step_description=self.task.state.step_description, - step_status=self.task.state.step_status, - flow_id=self.task.state.flow_id, - flow_name=self.task.state.flow_name, - flow_status=self.task.state.flow_status, - input_data={}, + self.state.executorStatus = ExecutorStatus.WAITING + self.state.stepStatus = StepStatus.PARAM + self.context.append( + ExecutorHistory( + taskId=self.task.id, + stepId=self.state.stepId, + stepName=self.state.stepName, + stepDescription=self.state.stepDescription, + stepStatus=self.state.stepStatus, + executorId=self.state.executorId, + executorName=self.state.executorName, + executorStatus=self.state.executorStatus, + inputData={}, output_data={}, ex_data={ "message": error_message, @@ -255,50 +252,46 @@ class MCPAgentExecutor(BaseExecutor): async def get_next_step(self) -> None: """获取下一步""" - if self.task.state.step_cnt < self.max_steps: - self.task.state.step_cnt += 1 + if self.step_cnt < self.max_steps: + self.step_cnt += 1 history = await MCPHost.assemble_memory(self.task) max_retry = 3 step = None - for i in range(max_retry): + for _ in range(max_retry): try: - step = await MCPPlanner.create_next_step(self.task.runtime.question, history, self.tool_list, language=self.task.language) - if step.tool_id in self.tools.keys(): + step = await self.planner.create_next_step(history, self.tool_list) + if step.tool_id in self.tools: break - except Exception as e: - logger.warning("[MCPAgentExecutor] 获取下一步失败,重试中: %s", str(e)) - if step is None or step.tool_id not in self.tools.keys(): + except Exception: + logger.exception("[MCPAgentExecutor] 获取下一步失败,重试中...") + if step is None or step.tool_id not in self.tools: step = Step( tool_id=FINAL_TOOL_ID, - description=FINAL_TOOL_ID + description=FINAL_TOOL_ID, ) tool_id = step.tool_id - if tool_id == FINAL_TOOL_ID: - step_name = FINAL_TOOL_ID - else: - step_name = self.tools[tool_id].name + step_name = FINAL_TOOL_ID if tool_id == FINAL_TOOL_ID else self.tools[tool_id].name step_description = step.description - self.task.state.step_id = str(uuid.uuid4()) - self.task.state.tool_id = tool_id - self.task.state.step_name = step_name - self.task.state.step_description = step_description - self.task.state.step_status = StepStatus.INIT - self.task.state.current_input = {} + self.state.stepId = uuid.uuid4() + self.state.toolId = tool_id + self.state.stepName = step_name + self.state.stepDescription = step_description + self.state.stepStatus = StepStatus.INIT + self.state.currentInput = {} else: # 没有下一步了,结束流程 - self.task.state.tool_id = FINAL_TOOL_ID - return + self.state.toolId = FINAL_TOOL_ID async def error_handle_after_step(self) -> None: """步骤执行失败后的错误处理""" - self.task.state.step_status = StepStatus.ERROR - self.task.state.flow_status = FlowStatus.ERROR + self.state.stepStatus = StepStatus.ERROR + self.state.executorStatus = ExecutorStatus.ERROR await self.push_message( EventType.FLOW_FAILED, - data={} + data={}, ) - if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: - del self.task.context[-1] + if len(self.context) and self.context[-1].step_id == self.state.step_id: + del self.context[-1] self.task.context.append( FlowStepHistory( task_id=self.task.id, @@ -316,26 +309,26 @@ class MCPAgentExecutor(BaseExecutor): async def work(self) -> None: """执行当前步骤""" - if self.task.state.step_status == StepStatus.INIT: + if self.state.stepStatus == StepStatus.INIT: await self.push_message( EventType.STEP_INIT, data={}, ) await self.get_tool_input_param(is_first=True) - user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + user_info = await UserManager.get_userinfo_by_user_sub(self.task.userSub) if not user_info.auto_execute: # 等待用户确认 await self.confirm_before_step() return - self.task.state.step_status = StepStatus.RUNNING - elif self.task.state.step_status in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: - if self.task.state.step_status == StepStatus.PARAM: - if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: - del self.task.context[-1] - elif self.task.state.step_status == StepStatus.WAITING: + self.state.stepStatus = StepStatus.RUNNING + elif self.state.stepStatus in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: + if self.state.stepStatus == StepStatus.PARAM: + if len(self.context) and self.context[-1].stepId == self.state.stepId: + del self.context[-1] + elif self.state.stepStatus == StepStatus.WAITING: if self.params: - if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: - del self.task.context[-1] + if len(self.context) and self.context[-1].stepId == self.state.stepId: + del self.context[-1] else: self.task.state.flow_status = FlowStatus.CANCELLED self.task.state.step_status = StepStatus.CANCELLED diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 44751760a..1a1b12b96 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -12,7 +12,7 @@ from apps.schemas.message import TextAddContent if TYPE_CHECKING: from apps.common.queue import MessageQueue - from apps.models.task import ExecutorCheckpoint, Task, TaskRuntime + from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime from apps.schemas.scheduler import ExecutorBackground logger = logging.getLogger(__name__) @@ -24,7 +24,9 @@ class BaseExecutor(BaseModel, ABC): task: "Task" runtime: "TaskRuntime" state: "ExecutorCheckpoint" + context: list["ExecutorHistory"] msg_queue: "MessageQueue" + background: "ExecutorBackground" question: str diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 083625c68..9c41d27e1 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -8,7 +8,7 @@ from datetime import UTC, datetime from pydantic import Field -from apps.models.task import ExecutorHistory, Task +from apps.models.task import ExecutorCheckpoint from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.schemas.enum_var import EventType, ExecutorStatus, LanguageType, SpecialCallType, StepStatus from apps.schemas.flow import Flow, Step @@ -76,24 +76,22 @@ class FlowExecutor(BaseExecutor): logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State if ( - self.task.state - and self.task.state.flow_status != FlowStatus.INIT - and self.task.state.flow_status != FlowStatus.UNKNOWN + self.state + and self.state.executorStatus not in [ExecutorStatus.INIT, ExecutorStatus.UNKNOWN] ): - self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + self.context = await TaskManager.get_context_by_task_id(self.task.id) else: # 创建ExecutorState - self.state = ExecutorHistory( + self.state = ExecutorCheckpoint( + taskId=self.task.id, + appId=self.post_body_app.app_id, executorId=str(self.flow_id), executorName=self.flow.name, executorStatus=ExecutorStatus.RUNNING, - description=str(self.flow.description), - step_status=StepStatus.RUNNING, - app_id=str(self.post_body_app.app_id), - step_id="start", - step_name="开始" if self.task.language == LanguageType.CHINESE else "Start", + stepStatus=StepStatus.RUNNING, + stepId="start", + stepName="开始" if self.runtime.language == LanguageType.CHINESE else "Start", ) - self.validate_flow_state(self.task) # 是否到达Flow结束终点(变量) self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() @@ -110,6 +108,7 @@ class FlowExecutor(BaseExecutor): question=self.question, runtime=self.runtime, state=self.state, + context=self.context, ) # 初始化步骤 @@ -149,9 +148,9 @@ class FlowExecutor(BaseExecutor): return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1].output_data["branch_id"] + branch_id = self.context[-1].outputData["branch_id"] if branch_id: - next_steps = await self._find_next_id(self.state.stepId + "." + branch_id) + next_steps = await self._find_next_id(str(self.state.stepId) + "." + branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) else: logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") @@ -198,7 +197,7 @@ class FlowExecutor(BaseExecutor): self.step_queue.append( StepQueueItem( step_id=uuid.uuid4(), - step=step.get(self.task.language, step[LanguageType.CHINESE]), + step=step.get(self.runtime.language, step[LanguageType.CHINESE]), enable_filling=False, to_user=False, ), @@ -218,20 +217,20 @@ class FlowExecutor(BaseExecutor): self.step_queue.clear() self.step_queue.appendleft( StepQueueItem( - step_id=str(uuid.uuid4()), + step_id=uuid.uuid4(), step=Step( name=( - "错误处理" if self.task.language == LanguageType.CHINESE else "Error Handling" + "错误处理" if self.runtime.language == LanguageType.CHINESE else "Error Handling" ), description=( - "错误处理" if self.task.language == LanguageType.CHINESE else "Error Handling" + "错误处理" if self.runtime.language == LanguageType.CHINESE else "Error Handling" ), node=SpecialCallType.LLM.value, type=SpecialCallType.LLM.value, params={ - "user_prompt": LLM_ERROR_PROMPT[self.task.language].replace( + "user_prompt": LLM_ERROR_PROMPT[self.runtime.language].replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.state.errorMessage["err_msg"], ), }, ), @@ -265,7 +264,7 @@ class FlowExecutor(BaseExecutor): self.step_queue.append( StepQueueItem( step_id=uuid.uuid4(), - step=step.get(self.task.language, step[LanguageType.CHINESE]), + step=step.get(self.runtime.language, step[LanguageType.CHINESE]), ), ) await self._step_process() diff --git a/apps/scheduler/executor/qa.py b/apps/scheduler/executor/qa.py index e3d75e9d2..d7ffcfb1d 100644 --- a/apps/scheduler/executor/qa.py +++ b/apps/scheduler/executor/qa.py @@ -11,4 +11,3 @@ class QAExecutor(BaseExecutor): async def run(self) -> None: """运行QA""" pass - diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index cda0b7082..a3ddab274 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -271,7 +271,7 @@ class StepExecutor(BaseExecutor): inputData=self.obj.input, outputData=output_data, ) - self.task.context.append(history) + self.context.append(history) # 推送输出 await self.push_message(EventType.STEP_OUTPUT.value, output_data) diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 0dca0eb3a..4b9807ad2 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -1,13 +1,8 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP基类""" -import json import logging -from typing import Any -from jsonschema import validate - -from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM logger = logging.getLogger(__name__) @@ -30,43 +25,8 @@ class MCPBase: message, streaming=False, temperature=0.07, - result_only=True, + result_only=False, ): result += chunk return result - - @staticmethod - async def _parse_result( - result: str, - schema: dict[str, Any], left_str: str = "{", right_str: str = "}", - ) -> dict[str, Any]: - """解析推理结果""" - left_index = result.find(left_str) - right_index = result.rfind(right_str) - flag = True - if left_str == -1 or right_str == -1: - flag = False - - if left_index > right_index: - flag = False - if flag: - try: - tmp_js = json.loads(result[left_index : right_index + 1]) - validate(instance=tmp_js, schema=schema) - except Exception: - logger.exception("[McpBase] 解析结果失败") - flag = False - if not flag: - json_generator = JsonGenerator( - "Please provide a JSON response based on the above information and schema.\n\n", - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": result}, - ], - schema, - ) - json_result = await json_generator.generate() - else: - json_result = json.loads(result[left_index : right_index + 1]) - return json_result diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 44d088f40..98b76946e 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -9,6 +9,7 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from apps.llm.reasoning import ReasoningLLM +from apps.models.mcp import MCPTools from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import ( CHANGE_ERROR_MESSAGE_TO_DESCRIPTION, @@ -24,7 +25,6 @@ from apps.schemas.enum_var import LanguageType from apps.schemas.mcp import ( FlowName, IsParamError, - MCPTool, Step, ToolRisk, ) @@ -59,7 +59,7 @@ class MCPPlanner(MCPBase): result = await self._parse_result(result, FlowName.model_json_schema()) return FlowName.model_validate(result) - async def create_next_step(self, history: str, tools: list[MCPTool]) -> Step: + async def create_next_step(self, history: str, tools: list[MCPTools]) -> Step: """创建下一步的执行步骤""" # 获取推理结果 template = _env.from_string(GEN_STEP[self.language]) @@ -79,7 +79,7 @@ class MCPPlanner(MCPBase): async def get_tool_risk( self, - tool: MCPTool, + tool: MCPTools, input_param: dict[str, Any], additional_info: str = "", ) -> ToolRisk: @@ -104,7 +104,7 @@ class MCPPlanner(MCPBase): self, history: str, error_message: str, - tool: MCPTool, + tool: MCPTools, step_description: str, input_params: dict[str, Any], ) -> IsParamError: @@ -127,7 +127,7 @@ class MCPPlanner(MCPBase): return IsParamError.model_validate(is_param_error) async def change_err_message_to_description( - self, error_message: str, tool: MCPTool, input_params: dict[str, Any], + self, error_message: str, tool: MCPTools, input_params: dict[str, Any], ) -> str: """将错误信息转换为工具描述""" template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION[self.language]) @@ -140,7 +140,7 @@ class MCPPlanner(MCPBase): ) return await self.get_resoning_result(prompt) - async def get_missing_param(self, tool: MCPTool, input_param: dict[str, Any], error_message: str) -> dict[str, Any]: + async def get_missing_param(self, tool: MCPTools, input_param: dict[str, Any], error_message: str) -> dict[str, Any]: """获取缺失的参数""" slot = Slot(schema=tool.input_schema) template = _env.from_string(GET_MISSING_PARAMS[self.language]) diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 487d7ca47..4577e1789 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -66,6 +66,8 @@ GEN_STEP: dict[LanguageType, str] = { 2.能够基于当前的计划和历史,完成阶段性的任务。 3.不要选择不存在的工具。 4.如果你认为当前已经达成了用户的目标,可以直接返回Final工具,表示计划执行结束。 + 5.tool_id中的工具ID必须是当前工具集合中存在的工具ID,而不是工具的名称。 + 6.工具在 XML标签中给出,工具的id在 下的 XML标签中给出。 # 样例 1 # 目标 @@ -82,16 +84,16 @@ GEN_STEP: dict[LanguageType, str] = { - 得到数据:`{"result": "success"}` # 工具 - - mcp_tool_1 mysql_analyzer;用于分析数据库性能/description> - - mcp_tool_2 文件存储工具;用于存储文件 - - mcp_tool_3 mongoDB工具;用于操作MongoDB数据库 + - mcp_tool_1 mysql分析工具,用于分析数据库性能/description> + - mcp_tool_2 文件存储工具,用于存储文件 + - mcp_tool_3 mongoDB工具,用于操作MongoDB数据库 - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终\ 结果。 # 输出 ```json { - "tool_id": "mcp_tool_1", // 选择的工具ID + "tool_id": "mcp_tool_1", "description": "扫描ip为192.168.1.1的MySQL数据库,端口为3306,用户名为root,密码为password的数据库性能", } ``` @@ -100,19 +102,19 @@ GEN_STEP: dict[LanguageType, str] = { 计划从杭州到北京的旅游计划 # 历史记录 第1步:将杭州转换为经纬度坐标 - - 调用工具 `maps_geo_planner`,并提供参数 `{"city_from": "杭州", "address": "西湖"}` - - 执行状态:成功 - - 得到数据:`{"location": "123.456, 78.901"}` + - 调用工具 `经纬度工具`,并提供参数 `{"city_from": "杭州", "address": "西湖"}` + - 执行状态:成功 + - 得到数据:`{"location": "123.456, 78.901"}` 第2步:查询杭州的天气 - - 调用工具 `weather_query`,并提供参数 `{"location": "123.456, 78.901"}` + - 调用工具 `天气查询工具`,并提供参数 `{"location": "123.456, 78.901"}` - 执行状态:成功 - 得到数据:`{"weather": "晴", "temperature": "25°C"}` 第3步:将北京转换为经纬度坐标 - - 调用工具 `maps_geo_planner`,并提供参数 `{"city_from": "北京", "address": "天安门"}` + - 调用工具 `经纬度工具`,并提供参数 `{"city_from": "北京", "address": "天安门"}` - 执行状态:成功 - 得到数据:`{"location": "123.456, 78.901"}` 第4步:查询北京的天气 - - 调用工具 `weather_query`,并提供参数 `{"location": "123.456, 78.901"}` + - 调用工具 `天气查询工具`,并提供参数 `{"location": "123.456, 78.901"}` - 执行状态:成功 - 得到数据:`{"weather": "晴", "temperature": "25°C"}` # 工具 @@ -128,7 +130,7 @@ GEN_STEP: dict[LanguageType, str] = { # 输出 ```json { - "tool_id": "mcp_tool_6", // 选择的工具ID + "tool_id": "mcp_tool_6", "description": "规划从杭州到北京的综合公共交通方式的通勤方案" } ``` @@ -140,7 +142,7 @@ GEN_STEP: dict[LanguageType, str] = { # 工具 {% for tool in tools %} - - {{tool.id}} {{tool.name}};{{tool.description}} + - {{tool.id}} {{tool.description}} {% endfor %} """, @@ -183,7 +185,7 @@ final result. # Output ```json { - "tool_id": "mcp_tool_1", // Selected tool ID + "tool_id": "mcp_tool_1", "description": "Scan the database performance of the MySQL database with IP address 192.168.1.1, \ port 3306, username root, and password password", } @@ -225,7 +227,7 @@ is complete, and the resulting result is used as the final result. {% for tool in tools %} - - {{tool.id}} {{tool.name}}; {{tool.description}} + - {{tool.id}} {{tool.description}} {% endfor %} """, diff --git a/apps/scheduler/pool/mcp/client.py b/apps/scheduler/pool/mcp/client.py index 297ed3e2e..fc26249f2 100644 --- a/apps/scheduler/pool/mcp/client.py +++ b/apps/scheduler/pool/mcp/client.py @@ -136,8 +136,9 @@ class MCPClient: ) if self.error_sign.is_set(): self.status = MCPStatus.ERROR - logger.error("[MCPClient] MCP %s:初始化失败", mcp_id) - raise Exception(f"MCP {mcp_id} 初始化失败") + err_msg = f"[MCPClient] MCP {mcp_id} 初始化失败" + logger.error(err_msg) + raise RuntimeError(err_msg) # 获取工具列表 self.tools = (await self.client.list_tools()).tools diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index 91413db61..cea8d6b57 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -22,7 +22,7 @@ class MCPPool(metaclass=SingletonMeta): self.pool = {} - async def _init_mcp(self, mcp_id: str, user_sub: str) -> MCPClient | None: + async def init_mcp(self, mcp_id: str, user_sub: str) -> MCPClient | None: """初始化MCP池""" config_path = MCP_USER_PATH / user_sub / mcp_id / "config.json" @@ -75,7 +75,7 @@ class MCPPool(metaclass=SingletonMeta): return None # 初始化进程 - item = await self._init_mcp(mcp_id, user_sub) + item = await self.init_mcp(mcp_id, user_sub) if item is None: return None diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py index bd908d237..31714faa8 100644 --- a/apps/schemas/parameters.py +++ b/apps/schemas/parameters.py @@ -1,3 +1,5 @@ +"""步骤参数相关""" + from enum import Enum diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 201e1b599..2ba8f82f8 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -7,8 +7,7 @@ from typing import Any from pydantic import BaseModel, Field from apps.models.blacklist import Blacklist -from apps.models.mcp import MCPInstallStatus -from apps.templates.generate_llm_operator_config import llm_provider_dict +from apps.models.mcp import MCPInstallStatus, MCPTools from .appcenter import AppCenterCardItem, AppData from .enum_var import DocumentStatus @@ -467,7 +466,7 @@ class GetMCPServiceDetailMsg(BaseModel): name: str = Field(..., description="MCP服务名称") description: str = Field(description="MCP服务描述") overview: str = Field(description="MCP服务概述") - tools: list[MCPTool] = Field(description="MCP服务Tools列表", default=[]) + tools: list[MCPTools] = Field(description="MCP服务Tools列表", default=[]) status: MCPInstallStatus = Field( description="MCP服务状态", default=MCPInstallStatus.INIT, @@ -482,7 +481,7 @@ class EditMCPServiceMsg(BaseModel): name: str = Field(..., description="MCP服务名称") description: str = Field(description="MCP服务描述") overview: str = Field(description="MCP服务概述") - data: str = Field(description="MCP服务配置") + data: dict[str, Any] = Field(description="MCP服务配置") mcp_type: MCPType = Field(alias="mcpType", description="MCP 类型") diff --git a/apps/services/task.py b/apps/services/task.py index bc009c335..a2771f26c 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -75,7 +75,7 @@ class TaskManager: @staticmethod - async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[ExecutorHistory]: + async def get_context_by_task_id(task_id: uuid.UUID, length: int | None = None) -> list[ExecutorHistory]: """根据task_id获取flow信息""" async with postgres.session() as session: return list((await session.scalars( @@ -179,7 +179,7 @@ class TaskManager: @classmethod - async def save_task(cls, task_id: str, task: Task) -> None: + async def save_task(cls, task_id: uuid.UUID, task: Task) -> None: """保存任务块""" task_collection = MongoDB().get_collection("task") -- Gitee