diff --git a/apps/common/minio.py b/apps/common/minio.py
index b5187dbb5b63674879c53def5a125cdd60348544..4974fc574b883a1c4274b3565ee46b753f241623 100644
--- a/apps/common/minio.py
+++ b/apps/common/minio.py
@@ -32,6 +32,7 @@ class MinioClient:
     @classmethod
     def download_file(cls, bucket_name: str, file_path: str) -> tuple[dict[str, Any], bytes]:
         """下载文件"""
+        response = None
         try:
             obj_stat = cls.client.stat_object(bucket_name, file_path)
             metadata = obj_stat.metadata if isinstance(obj_stat.metadata, dict) else {}
diff --git a/apps/common/queue.py b/apps/common/queue.py
index 13165c5d3d8b9c0ad416656b88697b66c78b9c8d..95b6f44989abcd409a31d8ed1717e2f84f689376 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -56,9 +56,12 @@ class MessageQueue:
             flow = MessageFlow(
                 appId=task.state.app_id,
                 flowId=task.state.flow_id,
+                flowName=task.state.flow_name,
+                flowStatus=task.state.flow_status,
                 stepId=task.state.step_id,
                 stepName=task.state.step_name,
-                stepStatus=task.state.status,
+                stepDescription=task.state.step_description,
+                stepStatus=task.state.step_status
             )
         else:
             flow = None
diff --git a/apps/constants.py b/apps/constants.py
index 58158b33c3aa4ea5aa3595f5642bf6444e7d76cb..20cb79b54ac8db5450fdbb296cb400b47a29adf0 100644
--- a/apps/constants.py
+++ b/apps/constants.py
@@ -11,7 +11,7 @@ from apps.common.config import Config
 # 新对话默认标题
 NEW_CHAT = "新对话"
 # 滑动窗口限流 默认窗口期
-SLIDE_WINDOW_TIME = 60
+SLIDE_WINDOW_TIME = 15
 # OIDC 访问Token 过期时间(分钟)
 OIDC_ACCESS_TOKEN_EXPIRE_TIME = 30
 # OIDC 刷新Token 过期时间(分钟)
diff --git a/apps/dependency/user.py b/apps/dependency/user.py
index 87cbd290283507aa911e2c0d188bea2099bd150f..7841a96ddf5240c0c50aa9615728db82f12ee282 100644
--- a/apps/dependency/user.py
+++ b/apps/dependency/user.py
@@ -1,6 +1,7 @@
 # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
 """用户鉴权"""
+import os
 import logging
 
 from fastapi import Depends
@@ -75,8 +76,11 @@ async def get_user(request: HTTPConnection) -> str:
     :return: 用户sub
     """
     if Config().get_config().no_auth.enable:
-        # 如果启用了无认证访问,直接返回调试用户
-        return Config().get_config().no_auth.user_sub
+        # 如果启用了无认证访问,直接返回当前操作系统用户的名称
+        username = os.environ.get('USERNAME')  # 适用于 Windows 系统
+        if not username:
+            username = os.environ.get('USER')  # 适用于 Linux 和 macOS 系统
+        return username or "admin"
     session_id = await _get_session_id_from_request(request)
     if not session_id:
         raise HTTPException(
diff --git a/apps/llm/embedding.py b/apps/llm/embedding.py
index 28ab86b49a96d19cc6e2db83930d5aff48f82260..b56723e08161c37dcb34aa8a5d44f4f4b274be3e 100644
--- a/apps/llm/embedding.py
+++ b/apps/llm/embedding.py
@@ -2,7 +2,9 @@
 
 import httpx
+import logging
 
 from apps.common.config import Config
 
+logger = logging.getLogger(__name__)
 
 class Embedding:
@@ -75,10 +77,18 @@ class Embedding:
         :param text: 待向量化文本(多条文本组成List)
         :return: 文本对应的向量(顺序与text一致,也为List)
         """
-        if Config().get_config().embedding.type == "openai":
-            return await cls._get_openai_embedding(text)
-        if Config().get_config().embedding.type == "mindie":
-            return await cls._get_tei_embedding(text)
+        try:
+            if Config().get_config().embedding.type == "openai":
+                return await cls._get_openai_embedding(text)
+            if Config().get_config().embedding.type == "mindie":
+                return await cls._get_tei_embedding(text)
 
-        err = f"不支持的Embedding API类型: {Config().get_config().embedding.type}"
-        raise ValueError(err)
+            err = f"不支持的Embedding API类型: {Config().get_config().embedding.type}"
+            raise ValueError(err)
+        except Exception as e:
+            err = f"获取Embedding失败: {e}"
+            logger.error(err)
+            rt = []
+            for i in range(len(text)):
+                rt.append([0.0]*1024)
+            return rt
diff --git a/apps/llm/function.py b/apps/llm/function.py
index 1f995fe7ba187cead03aa6fc62a4cbce1ec05a65..efaa11548886b21a07bbedb372572cbba021feac 100644
--- a/apps/llm/function.py
+++ b/apps/llm/function.py
@@ -11,6 +11,7 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from jsonschema import Draft7Validator
+from jsonschema import validate
 
 from apps.common.config import Config
 from apps.constants import JSON_GEN_MAX_TRIAL, REASONING_END_TOKEN
 from apps.llm.prompt import JSON_GEN_BASIC
@@ -42,6 +43,7 @@ class FunctionLLM:
         self._params = {
             "model": self._config.model,
             "messages": [],
+            "timeout": 300
         }
 
         if self._config.backend == "ollama":
@@ -237,6 +239,58 @@ class FunctionLLM:
 
 class JsonGenerator:
     """JSON生成器"""
+    @staticmethod
+    async def _parse_result_by_stack(result: str, schema: dict[str, Any]) -> str:
+        """解析推理结果"""
+        left_index = result.find('{')
+        right_index = result.rfind('}')
+        if left_index != -1 and right_index != -1 and left_index < right_index:
+            try:
+                tmp_js = json.loads(result[left_index:right_index + 1])
+                validate(instance=tmp_js, schema=schema)
+                return tmp_js
+            except Exception as e:
+                logger.error("[JsonGenerator] 解析结果失败: %s", e)
+        stack = []
+        json_candidates = []
+        # 定义括号匹配关系
+        bracket_map = {')': '(', ']': '[', '}': '{'}
+
+        for i, char in enumerate(result):
+            # 遇到左括号则入栈
+            if char in bracket_map.values():
+                stack.append((char, i))
+            # 遇到右括号且栈不为空时检查匹配
+            elif char in bracket_map.keys() and stack:
+                if not stack:
+                    continue
+                top_char, top_index = stack[-1]
+                # 检查是否匹配当前右括号
+                if top_char == bracket_map[char]:
+                    stack.pop()
+                    # 当栈为空且当前是右花括号时,认为找到一个完整JSON
+                    if not stack and char == '}':
+                        json_str = result[top_index:i+1]
+                        json_candidates.append(json_str)
+                else:
+                    # 如果不匹配,清空栈
+                    stack.clear()
+        # 移除重复项并保持顺序
+        seen = set()
+        unique_jsons = []
+        for json_str in json_candidates[::]:
+            if json_str not in seen:
+                seen.add(json_str)
+                unique_jsons.append(json_str)
+
+        for json_str in unique_jsons:
+            try:
+                tmp_js = json.loads(json_str)
+                validate(instance=tmp_js, schema=schema)
+                return tmp_js
+            except Exception as e:
+                logger.error("[JsonGenerator] 解析结果失败: %s", e)
+        return None
 
     def __init__(self, query: str, conversation: list[dict[str, str]], schema: dict[str, Any]) -> None:
         """初始化JSON生成器"""
@@ -275,8 +329,8 @@ class JsonGenerator:
         """单次尝试"""
         prompt = await self._assemble_message()
         messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {"role": "user", "content": prompt},
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": "please generate a JSON response based on the above information and schema."},
         ]
         function = FunctionLLM()
         return await function.call(messages, self._schema, max_tokens, temperature)
@@ -291,14 +345,12 @@ class JsonGenerator:
         while self._count < JSON_GEN_MAX_TRIAL:
             self._count += 1
             result = await self._single_trial()
-            logger.info("[JSONGenerator] 得到:%s", result)
             try:
                 validator.validate(result)
             except Exception as err:  # noqa: BLE001
                 err_info = str(err)
                 err_info = err_info.split("\n\n")[0]
                 self._err_info = err_info
-                logger.info("[JSONGenerator] 验证失败:%s", self._err_info)
                 continue
 
             return result
diff --git a/apps/llm/patterns/core.py b/apps/llm/patterns/core.py
index 4ef8133a9fed1b1e62f1ceb578c6bdb5a93b12a5..c4a58364ec7c21a132f401d5ebfb98e3d7c58b82 100644
--- a/apps/llm/patterns/core.py
+++ b/apps/llm/patterns/core.py
@@ -3,40 +3,52 @@
 from abc import ABC, abstractmethod
 from textwrap import dedent
 
+from pydantic import BaseModel, Field
+from apps.schemas.enum_var import LanguageType
 
 
 class CorePattern(ABC):
     """基础大模型范式抽象类"""
 
-    system_prompt: str = ""
-    """系统提示词"""
-
     user_prompt: str = ""
     """用户提示词"""
 
     input_tokens: int = 0
     """输入Token数量"""
 
     output_tokens: int = 0
     """输出Token数量"""
 
+    def get_default_prompt(self) -> dict[LanguageType, str]:
+        """
+        获取默认的用户提示词
+
+        :return: 默认的用户提示词
+        :rtype: dict[LanguageType, str]
+        """
+        return {}, {}
 
-    def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None:
+    def __init__(
+        self,
+        system_prompt: dict[LanguageType, str] | None = None,
+        user_prompt: dict[LanguageType, str] | None = None,
+    ) -> None:
         """
         检查是否已经自定义了Prompt;有的话就用自定义的;同时对Prompt进行空格清除
 
         :param system_prompt: 系统提示词,f-string格式
        :param user_prompt: 用户提示词,f-string格式
         """
+        default_system_prompt, default_user_prompt = self.get_default_prompt()
         if system_prompt is not None:
             self.system_prompt = system_prompt
-
+        else:
+            self.system_prompt = default_system_prompt
         if user_prompt is not None:
             self.user_prompt = user_prompt
+        else:
+            self.user_prompt = default_user_prompt
 
-        if not self.user_prompt:
-            err = "必须设置用户提示词!"
- raise ValueError(err) + self.system_prompt = {lang: dedent(prompt).strip("\n") for lang, prompt in self.system_prompt.items()} - self.system_prompt = dedent(self.system_prompt).strip("\n") - self.user_prompt = dedent(self.user_prompt).strip("\n") + self.user_prompt = {lang: dedent(prompt).strip("\n") for lang, prompt in self.user_prompt.items()} @abstractmethod async def generate(self, **kwargs): # noqa: ANN003, ANN201 diff --git a/apps/llm/patterns/executor.py b/apps/llm/patterns/executor.py index f872fd2ac8d691b4079756a56a6107d6b6556585..e2153487a568eaa1289677b14111bbcbcc7b68ea 100644 --- a/apps/llm/patterns/executor.py +++ b/apps/llm/patterns/executor.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM from apps.llm.snippet import convert_context_to_prompt, facts_to_prompt - +from apps.schemas.enum_var import LanguageType if TYPE_CHECKING: from apps.schemas.scheduler import ExecutorBackground @@ -14,40 +14,79 @@ if TYPE_CHECKING: class ExecutorThought(CorePattern): """通过大模型生成Executor的思考内容""" - user_prompt: str = r""" - - - 你是一个可以使用工具的智能助手。 - 在回答用户的问题时,你为了获取更多的信息,使用了一个工具。 - 请简明扼要地总结工具的使用过程,提供你的见解,并给出下一步的行动。 - - 注意: - 工具的相关信息在标签中给出。 - 为了使你更好的理解发生了什么,你之前的思考过程在标签中给出。 - 输出时请不要包含XML标签,输出时请保持简明和清晰。 - - - - - {tool_name} - {tool_description} - {tool_output} - - - - {last_thought} - - - - 你当前需要解决的问题是: - {user_question} - - - 请综合以上信息,再次一步一步地进行思考,并给出见解和行动: - """ - """用户提示词""" - - def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: + def get_default_prompt(self) -> dict[LanguageType, str]: + user_prompt = { + LanguageType.CHINESE: r""" + + + 你是一个可以使用工具的智能助手。 + 在回答用户的问题时,你为了获取更多的信息,使用了一个工具。 + 请简明扼要地总结工具的使用过程,提供你的见解,并给出下一步的行动。 + + 注意: + 工具的相关信息在标签中给出。 + 为了使你更好的理解发生了什么,你之前的思考过程在标签中给出。 + 输出时请不要包含XML标签,输出时请保持简明和清晰。 + + + + + {tool_name} + {tool_description} + {tool_output} + + + + {last_thought} + + + + 你当前需要解决的问题是: + {user_question} + + + 请综合以上信息,再次一步一步地进行思考,并给出见解和行动: + """, + LanguageType.ENGLISH: r""" + + + You are an intelligent assistant who can use tools. + When answering user questions, you use a tool to get more information. + Please summarize the process of using the tool briefly, provide your insights, and give the next action. + + Note: + The information about the tool is given in the tag. + To help you better understand what happened, your previous thought process is given in the tag. + Do not include XML tags in the output, and keep the output brief and clear. + + + + + {tool_name} + {tool_description} + {tool_output} + + + + {last_thought} + + + + The question you need to solve is: + {user_question} + + + Please integrate the above information, think step by step again, provide insights, and give actions: + """, + } + """用户提示词""" + return {}, user_prompt + + def __init__( + self, + system_prompt: dict[LanguageType, str] | None = None, + user_prompt: dict[LanguageType, str] | None = None, + ) -> None: """处理Prompt""" super().__init__(system_prompt, user_prompt) @@ -57,19 +96,23 @@ class ExecutorThought(CorePattern): last_thought: str = kwargs["last_thought"] user_question: str = kwargs["user_question"] tool_info: dict[str, Any] = kwargs["tool_info"] + language: LanguageType = kwargs.get("language", LanguageType.CHINESE) except Exception as e: err = "参数不正确!" 
raise ValueError(err) from e messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": self.user_prompt.format( - last_thought=last_thought, - user_question=user_question, - tool_name=tool_info["name"], - tool_description=tool_info["description"], - tool_output=tool_info["output"], - )}, + { + "role": "user", + "content": self.user_prompt[language].format( + last_thought=last_thought, + user_question=user_question, + tool_name=tool_info["name"], + tool_description=tool_info["description"], + tool_output=tool_info["output"], + ), + }, ] llm = ReasoningLLM() @@ -85,30 +128,59 @@ class ExecutorThought(CorePattern): class ExecutorSummary(CorePattern): """使用大模型进行生成Executor初始背景""" - user_prompt: str = r""" - - 根据给定的对话记录和关键事实,生成一个三句话背景总结。这个总结将用于后续对话的上下文理解。 - - 生成总结的要求如下: - 1. 突出重要信息点,例如时间、地点、人物、事件等。 - 2. “关键事实”中的内容可在生成总结时作为已知信息。 - 3. 输出时请不要包含XML标签,确保信息准确性,不得编造信息。 - 4. 总结应少于3句话,应少于300个字。 - - 对话记录将在标签中给出,关键事实将在标签中给出。 - - - {conversation} - - - {facts} - - - 现在,请开始生成背景总结: - """ - """用户提示词""" - - def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: + def get_default_prompt(self) -> dict[LanguageType, str]: + user_prompt = { + LanguageType.CHINESE: r""" + + 根据给定的对话记录和关键事实,生成一个三句话背景总结。这个总结将用于后续对话的上下文理解。 + + 生成总结的要求如下: + 1. 突出重要信息点,例如时间、地点、人物、事件等。 + 2. “关键事实”中的内容可在生成总结时作为已知信息。 + 3. 输出时请不要包含XML标签,确保信息准确性,不得编造信息。 + 4. 总结应少于3句话,应少于300个字。 + + 对话记录将在标签中给出,关键事实将在标签中给出。 + + + {conversation} + + + {facts} + + + 现在,请开始生成背景总结: + """, + LanguageType.ENGLISH: r""" + + Based on the given conversation records and key facts, generate a three-sentence background summary. This summary will be used for context understanding in subsequent conversations. + + The requirements for generating the summary are as follows: + 1. Highlight important information points, such as time, location, people, events, etc. + 2. The content in the "key facts" can be used as known information when generating the summary. + 3. Do not include XML tags in the output, ensure the accuracy of the information, and do not make up information. + 4. The summary should be less than 3 sentences and less than 300 words. + + The conversation records will be given in the tag, and the key facts will be given in the tag. 
+ + + {conversation} + + + {facts} + + + Now, please start generating the background summary: + """, + } + """用户提示词""" + return {}, user_prompt + + def __init__( + self, + system_prompt: dict[LanguageType, str] | None = None, + user_prompt: dict[LanguageType, str] | None = None, + ) -> None: """初始化Background模式""" super().__init__(system_prompt, user_prompt) @@ -117,13 +189,17 @@ class ExecutorSummary(CorePattern): background: ExecutorBackground = kwargs["background"] conversation_str = convert_context_to_prompt(background.conversation) facts_str = facts_to_prompt(background.facts) + language = kwargs.get("language", LanguageType.CHINESE) messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": self.user_prompt.format( - facts=facts_str, - conversation=conversation_str, - )}, + { + "role": "user", + "content": self.user_prompt[language].format( + facts=facts_str, + conversation=conversation_str, + ), + }, ] result = "" diff --git a/apps/llm/patterns/facts.py b/apps/llm/patterns/facts.py index 0b0381ff40a0e6632fd204c9efcf834a13c4711f..a8fa09442ec785c3fe5cf2d1b09acaa0fae6f49b 100644 --- a/apps/llm/patterns/facts.py +++ b/apps/llm/patterns/facts.py @@ -9,6 +9,7 @@ from apps.llm.function import JsonGenerator from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM from apps.llm.snippet import convert_context_to_prompt +from apps.schemas.enum_var import LanguageType logger = logging.getLogger(__name__) @@ -22,52 +23,100 @@ class FactsResult(BaseModel): class Facts(CorePattern): """事实提取""" - system_prompt: str = "You are a helpful assistant." - """系统提示词(暂不使用)""" - - user_prompt: str = r""" - - - 从对话中提取关键信息,并将它们组织成独一无二的、易于理解的事实,包含用户偏好、关系、实体等有用信息。 - 以下是需要关注的信息类型以及有关如何处理输入数据的详细说明。 - - **你需要关注的信息类型** - 1. 实体:对话中涉及到的实体。例如:姓名、地点、组织、事件等。 - 2. 偏好:对待实体的态度。例如喜欢、讨厌等。 - 3. 关系:用户与实体之间,或两个实体之间的关系。例如包含、并列、互斥等。 - 4. 动作:对实体产生影响的具体动作。例如查询、搜索、浏览、点击等。 - - **要求** - 1. 事实必须准确,只能从对话中提取。不要将样例中的信息体现在输出中。 - 2. 事实必须清晰、简洁、易于理解。必须少于30个字。 - 3. 必须按照以下JSON格式输出: - - {{ - "facts": ["事实1", "事实2", "事实3"] - }} - - - - - 杭州西湖有哪些景点? - 杭州西湖是中国浙江省杭州市的一个著名景点,以其美丽的自然风光和丰富的文化遗产而闻名。西湖周围有许多著名的景点,包括著名的苏堤、白堤、断桥、三潭印月等。西湖以其清澈的湖水和周围的山脉而著名,是中国最著名的湖泊之一。 - - - + def get_default_prompt(self) -> dict[LanguageType, str]: + system_prompt = { + LanguageType.CHINESE: "你是一个有用的助手。", + LanguageType.ENGLISH: "You are a helpful assistant." + } + user_prompt = { + LanguageType.CHINESE: r""" + + + 从对话中提取关键信息,并将它们组织成独一无二的、易于理解的事实,包含用户偏好、关系、实体等有用信息。 + 以下是需要关注的信息类型以及有关如何处理输入数据的详细说明。 + + **你需要关注的信息类型** + 1. 实体:对话中涉及到的实体。例如:姓名、地点、组织、事件等。 + 2. 偏好:对待实体的态度。例如喜欢、讨厌等。 + 3. 关系:用户与实体之间,或两个实体之间的关系。例如包含、并列、互斥等。 + 4. 动作:对实体产生影响的具体动作。例如查询、搜索、浏览、点击等。 + + **要求** + 1. 事实必须准确,只能从对话中提取。不要将样例中的信息体现在输出中。 + 2. 事实必须清晰、简洁、易于理解。必须少于30个字。 + 3. 必须按照以下JSON格式输出: + {{ - "facts": ["杭州西湖有苏堤、白堤、断桥、三潭印月等景点"] + "facts": ["事实1", "事实2", "事实3"] }} - - - - - {conversation} - - """ - """用户提示词""" - + + + + + 杭州西湖有哪些景点? + 杭州西湖是中国浙江省杭州市的一个著名景点,以其美丽的自然风光和丰富的文化遗产而闻名。西湖周围有许多著名的景点,包括著名的苏堤、白堤、断桥、三潭印月等。西湖以其清澈的湖水和周围的山脉而著名,是中国最著名的湖泊之一。 + + + + {{ + "facts": ["杭州西湖有苏堤、白堤、断桥、三潭印月等景点"] + }} + + + + + {conversation} + + """, + LanguageType.ENGLISH: r""" + + + Extract key information from the conversation and organize it into unique, easily understandable facts that include user preferences, relationships, entities, etc. + The following are the types of information to be paid attention to and detailed instructions on how to handle the input data. + + **Types of information to be paid attention to** + 1. 
Entities: Entities involved in the conversation. For example: names, locations, organizations, events, etc. + 2. Preferences: Attitudes towards entities. For example: like, dislike, etc. + 3. Relationships: Relationships between the user and entities, or between two entities. For example: include, parallel, exclusive, etc. + 4. Actions: Specific actions that affect entities. For example: query, search, browse, click, etc. + + **Requirements** + 1. Facts must be accurate and can only be extracted from the conversation. Do not include information from the sample in the output. + 2. Facts must be clear, concise, and easy to understand. Must be less than 30 words. + 3. Output in the following JSON format: - def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: + {{ + "facts": ["fact1", "fact2", "fact3"] + }} + + + + + What are the attractions in West Lake, Hangzhou? + West Lake in Hangzhou is a famous scenic spot in Hangzhou, Zhejiang Province, China, famous for its beautiful natural scenery and rich cultural heritage. There are many famous attractions around West Lake, including the famous Su Causeway, Bai Causeway, Broken Bridge, Three Pools Mirroring the Moon, etc. West Lake is famous for its clear water and surrounding mountains, and is one of the most famous lakes in China. + + + + {{ + "facts": ["West Lake has the famous attractions of Suzhou Embankment, Bai Embankment, Qiantang Bridge, San Tang Yin Yue, etc."] + }} + + + + + {conversation} + + """, + } + """用户提示词""" + return system_prompt, user_prompt + + + def __init__( + self, + system_prompt: dict[LanguageType, str] | None = None, + user_prompt: dict[LanguageType, str] | None = None, + ) -> None: """初始化Prompt""" super().__init__(system_prompt, user_prompt) @@ -75,9 +124,11 @@ class Facts(CorePattern): async def generate(self, **kwargs) -> list[str]: # noqa: ANN003 """事实提取""" conversation = convert_context_to_prompt(kwargs["conversation"]) + language = kwargs.get("language", LanguageType.CHINESE) + messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.user_prompt.format(conversation=conversation)}, + {"role": "system", "content": self.system_prompt[language]}, + {"role": "user", "content": self.user_prompt[language].format(conversation=conversation)}, ] result = "" llm = ReasoningLLM() @@ -88,7 +139,7 @@ class Facts(CorePattern): messages += [{"role": "assistant", "content": result}] json_gen = JsonGenerator( - query="根据给定的背景信息,提取事实条目", + query="Extract fact entries based on the given background information", conversation=messages, schema=FactsResult.model_json_schema(), ) diff --git a/apps/llm/patterns/rewoo.py b/apps/llm/patterns/rewoo.py index ef78d92667d30fbd3b26d55ba4d87961181f3d48..66c3c65114a08b7824352ec629b10c748bcf7cda 100644 --- a/apps/llm/patterns/rewoo.py +++ b/apps/llm/patterns/rewoo.py @@ -3,59 +3,122 @@ from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM +from apps.schemas.enum_var import LanguageType class InitPlan(CorePattern): """规划生成命令行""" - system_prompt: str = r""" - 你是一个计划生成器。对于给定的目标,**制定一个简单的计划**,该计划可以逐步生成合适的命令行参数和标志。 - - 你会收到一个"命令前缀",这是已经确定和生成的命令部分。你需要基于这个前缀使用标志和参数来完成命令。 - - 在每一步中,指明使用哪个外部工具以及工具输入来获取证据。 - - 工具可以是以下之一: - (1) Option["指令"]:查询最相似的命令行标志。只接受一个输入参数,"指令"必须是搜索字符串。搜索字符串应该详细且包含必要的数据。 - (2) Argument[名称]<值>:将任务中的数据放置到命令行的特定位置。接受两个输入参数。 - - 所有步骤必须以"Plan: "开头,且少于150个单词。 - 不要添加任何多余的步骤。 - 确保每个步骤都包含所需的所有信息 - 不要跳过步骤。 - 不要在证据后面添加任何额外数据。 - - 开始示例 - - 
任务:在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 - 前缀:`docker run` - 用法:`docker run ${OPTS} ${image} ${command}`。这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 \ - ["image", "command"] 其中之一。 - 前缀描述:二进制程序`docker`的描述为"Docker容器平台",`run`子命令的描述为"从镜像创建并运行一个新的容器"。 - - Plan: 我需要一个标志使容器在后台运行。 #E1 = Option[在后台运行单个容器] - Plan: 我需要一个标志,将主机/root目录挂载至容器内/data目录。 #E2 = Option[挂载主机/root目录至/data目录] - Plan: 我需要从任务中解析出镜像名称。 #E3 = Argument[image] - Plan: 我需要指定容器中运行的命令。 #E4 = Argument[command] - Final: 组装上述线索,生成最终命令。 #F - - 示例结束 - - 让我们开始! - """ - """系统提示词""" - - user_prompt: str = r""" - 任务:{instruction} - 前缀:`{binary_name} {subcmd_name}` - 用法:`{subcmd_usage}`。这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 {argument_list} 其中之一。 - 前缀描述:二进制程序`{binary_name}`的描述为"{binary_description}",`{subcmd_name}`子命令的描述为\ - "{subcmd_description}"。 - - 请生成相应的计划。 - """ - """用户提示词""" - - def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: + def get_default_prompt(self) -> dict[LanguageType, str]: + system_prompt = { + LanguageType.CHINESE: r""" + 你是一个计划生成器。对于给定的目标,**制定一个简单的计划**,该计划可以逐步生成合适的命令行参数和标志。 + + 你会收到一个"命令前缀",这是已经确定和生成的命令部分。你需要基于这个前缀使用标志和参数来完成命令。 + + 在每一步中,指明使用哪个外部工具以及工具输入来获取证据。 + + 工具可以是以下之一: + (1) Option["指令"]:查询最相似的命令行标志。只接受一个输入参数,"指令"必须是搜索字符串。搜索字符串应该详细且包含必要的数据。 + (2) Argument[名称]<值>:将任务中的数据放置到命令行的特定位置。接受两个输入参数。 + + 所有步骤必须以"Plan: "开头,且少于150个单词。 + 不要添加任何多余的步骤。 + 确保每个步骤都包含所需的所有信息 - 不要跳过步骤。 + 不要在证据后面添加任何额外数据。 + + 开始示例 + + 任务:在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + 前缀:`docker run` + 用法:`docker run ${OPTS} ${image} ${command}`。这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 \ + ["image", "command"] 其中之一。 + 前缀描述:二进制程序`docker`的描述为"Docker容器平台",`run`子命令的描述为"从镜像创建并运行一个新的容器"。 + + Plan: 我需要一个标志使容器在后台运行。 #E1 = Option[在后台运行单个容器] + Plan: 我需要一个标志,将主机/root目录挂载至容器内/data目录。 #E2 = Option[挂载主机/root目录至/data目录] + Plan: 我需要从任务中解析出镜像名称。 #E3 = Argument[image] + Plan: 我需要指定容器中运行的命令。 #E4 = Argument[command] + Final: 组装上述线索,生成最终命令。 #F + + 示例结束 + + 让我们开始! + """, + LanguageType.ENGLISH: r""" + You are a plan generator. For a given goal, **draft a simple plan** that can step-by-step generate the \ + appropriate command line arguments and flags. + + You will receive a "command prefix", which is the already determined and generated command part. You need to \ + use the flags and arguments based on this prefix to complete the command. + + In each step, specify which external tool to use and the tool input to get the evidence. + + The tool can be one of the following: + (1) Option["instruction"]: Query the most similar command line flag. Only accepts one input parameter, \ + "instruction" must be a search string. The search string should be detailed and contain necessary data. + (2) Argument["name"]: Place the data from the task into a specific position in the command line. \ + Accepts two input parameters. + + All steps must start with "Plan: " and be less than 150 words. + Do not add any extra steps. + Ensure each step contains all the required information - do not skip steps. + Do not add any extra data after the evidence. + + Start example + + Task: Run a new alpine:latest container in the background, mount the host /root folder to /data, and execute \ + the top command. + Prefix: `docker run` + Usage: `docker run ${OPTS} ${image} ${command}`. This is a Python template string. OPTS is a placeholder for all \ + flags. The arguments must be one of ["image", "command"]. 
+ Prefix description: The description of binary program `docker` is "Docker container platform", and the \ + description of `run` subcommand is "Create and run a new container from an image". + + Plan: I need a flag to make the container run in the background. #E1 = Option[Run a single container in the \ + background] + Plan: I need a flag to mount the host /root directory to /data directory in the container. #E2 = Option[Mount \ + host /root directory to /data directory] + Plan: I need to parse the image name from the task. #E3 = Argument[image] + Plan: I need to specify the command to be run in the container. #E4 = Argument[command] + Final: Assemble the above clues to generate the final command. #F + + End example + + Let's get started! + """, + } + """系统提示词""" + + user_prompt = { + LanguageType.CHINESE: r""" + 任务:{instruction} + 前缀:`{binary_name} {subcmd_name}` + 用法:`{subcmd_usage}`。这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 {argument_list} 其中之一。 + 前缀描述:二进制程序`{binary_name}`的描述为"{binary_description}",`{subcmd_name}`子命令的描述为\ + "{subcmd_description}"。 + + 请生成相应的计划。 + """, + LanguageType.ENGLISH: r""" + Task: {instruction} + Prefix: `{binary_name} {subcmd_name}` + Usage: `{subcmd_usage}`. This is a Python template string. OPTS is a placeholder for all flags. The arguments \ + must be one of {argument_list}. + Prefix description: The description of binary program `{binary_name}` is "{binary_description}", and the \ + description of `{subcmd_name}` subcommand is "{subcmd_description}". + + Please generate the corresponding plan. + """, + } + """用户提示词""" + return system_prompt, user_prompt + + def __init__( + self, + system_prompt: dict[LanguageType, str] | None = None, + user_prompt: dict[LanguageType, str] | None = None, + ) -> None: """处理Prompt""" super().__init__(system_prompt, user_prompt) @@ -64,6 +127,7 @@ class InitPlan(CorePattern): spec = kwargs["spec"] binary_name = kwargs["binary_name"] subcmd_name = kwargs["subcmd_name"] + language = kwargs.get("language", LanguageType.CHINESE) binary_description = spec[binary_name][0] subcmd_usage = spec[binary_name][2][subcmd_name][1] subcmd_description = spec[binary_name][2][subcmd_name][0] @@ -73,16 +137,19 @@ class InitPlan(CorePattern): argument_list += [key] messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.user_prompt.format( - instruction=kwargs["instruction"], - binary_name=binary_name, - subcmd_name=subcmd_name, - binary_description=binary_description, - subcmd_description=subcmd_description, - subcmd_usage=subcmd_usage, - argument_list=argument_list, - )}, + {"role": "system", "content": self.system_prompt[language]}, + { + "role": "user", + "content": self.user_prompt[language].format( + instruction=kwargs["instruction"], + binary_name=binary_name, + subcmd_name=subcmd_name, + binary_description=binary_description, + subcmd_description=subcmd_description, + subcmd_usage=subcmd_usage, + argument_list=argument_list, + ), + }, ] result = "" @@ -98,46 +165,84 @@ class InitPlan(CorePattern): class PlanEvaluator(CorePattern): """计划评估器""" - system_prompt: str = r""" - 你是一个计划评估器。你的任务是评估给定的计划是否合理和完整。 - - 一个好的计划应该: - 1. 涵盖原始任务的所有要求 - 2. 使用适当的工具收集必要的信息 - 3. 具有清晰和逻辑的步骤 - 4. 没有冗余或不必要的步骤 - - 对于计划中的每个步骤,评估: - 1. 工具选择是否适当 - 2. 输入参数是否清晰和充分 - 3. 
该步骤是否有助于实现最终目标 - - 请回复: - "VALID" - 如果计划良好且完整 - "INVALID: <原因>" - 如果计划有问题,请解释原因 - """ - """系统提示词""" - - user_prompt: str = r""" - 任务:{instruction} - 计划:{plan} - - 评估计划并回复"VALID"或"INVALID: <原因>"。 - """ - """用户提示词""" - - def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: + def get_default_prompt(self) -> dict[LanguageType, str]: + system_prompt = { + LanguageType.CHINESE: r""" + 你是一个计划评估器。你的任务是评估给定的计划是否合理和完整。 + + 一个好的计划应该: + 1. 涵盖原始任务的所有要求 + 2. 使用适当的工具收集必要的信息 + 3. 具有清晰和逻辑的步骤 + 4. 没有冗余或不必要的步骤 + + 对于计划中的每个步骤,评估: + 1. 工具选择是否适当 + 2. 输入参数是否清晰和充分 + 3. 该步骤是否有助于实现最终目标 + + 请回复: + "VALID" - 如果计划良好且完整 + "INVALID: <原因>" - 如果计划有问题,请解释原因 + """, + LanguageType.ENGLISH: r""" + You are a plan evaluator. Your task is to evaluate whether the given plan is reasonable and complete. + + A good plan should: + 1. Cover all requirements of the original task + 2. Use appropriate tools to collect necessary information + 3. Have clear and logical steps + 4. Have no redundant or unnecessary steps + + For each step in the plan, evaluate: + 1. Whether the tool selection is appropriate + 2. Whether the input parameters are clear and sufficient + 3. Whether this step helps achieve the final goal + + Please reply: + "VALID" - If the plan is good and complete + "INVALID: <原因>" - If the plan has problems, please explain the reason + """, + } + """系统提示词""" + + user_prompt = { + LanguageType.CHINESE: r""" + 任务:{instruction} + 计划:{plan} + + 评估计划并回复"VALID"或"INVALID: <原因>"。 + """, + LanguageType.ENGLISH: r""" + Task: {instruction} + Plan: {plan} + + Evaluate the plan and reply with "VALID" or "INVALID: <原因>". + """, + } + """用户提示词""" + return system_prompt, user_prompt + + def __init__( + self, + system_prompt: dict[LanguageType, str] | None = None, + user_prompt: dict[LanguageType, str] | None = None, + ) -> None: """初始化Prompt""" super().__init__(system_prompt, user_prompt) async def generate(self, **kwargs) -> str: """生成计划评估结果""" + language = kwargs.get("language", LanguageType.CHINESE) messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.user_prompt.format( - instruction=kwargs["instruction"], - plan=kwargs["plan"], - )}, + {"role": "system", "content": self.system_prompt[language]}, + { + "role": "user", + "content": self.user_prompt[language].format( + instruction=kwargs["instruction"], + plan=kwargs["plan"], + ), + }, ] result = "" @@ -153,45 +258,81 @@ class PlanEvaluator(CorePattern): class RePlanner(CorePattern): """重新规划器""" - system_prompt: str = r""" - 你是一个计划重新规划器。当计划被评估为无效时,你需要生成一个新的、改进的计划。 - - 新计划应该: - 1. 解决评估中提到的所有问题 - 2. 保持与原始计划相同的格式 - 3. 更加精确和完整 - 4. 为每个步骤使用适当的工具 - - 遵循与原始计划相同的格式: - - 每个步骤应以"Plan: "开头 - - 包含带有适当参数的工具使用 - - 保持步骤简洁和重点突出 - - 以"Final"步骤结束 - """ - """系统提示词""" - - user_prompt: str = r""" - 任务:{instruction} - 原始计划:{plan} - 评估:{evaluation} - - 生成一个新的、改进的计划,解决评估中提到的所有问题。 - """ - """用户提示词""" - - def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: + def get_default_prompt(self) -> dict[LanguageType, str]: + system_prompt = { + LanguageType.CHINESE: r""" + 你是一个计划重新规划器。当计划被评估为无效时,你需要生成一个新的、改进的计划。 + + 新计划应该: + 1. 解决评估中提到的所有问题 + 2. 保持与原始计划相同的格式 + 3. 更加精确和完整 + 4. 为每个步骤使用适当的工具 + + 遵循与原始计划相同的格式: + - 每个步骤应以"Plan: "开头 + - 包含带有适当参数的工具使用 + - 保持步骤简洁和重点突出 + - 以"Final"步骤结束 + """, + LanguageType.ENGLISH: r""" + You are a plan replanner. When the plan is evaluated as invalid, you need to generate a new, improved plan. + + The new plan should: + 1. 
Solve all problems mentioned in the evaluation + 2. Keep the same format as the original plan + 3. Be more precise and complete + 4. Use appropriate tools for each step + + Follow the same format as the original plan: + - Each step should start with "Plan: " + - Include tool usage with appropriate parameters + - Keep steps concise and focused + - End with the "Final" step + """, + } + """系统提示词""" + + user_prompt = { + LanguageType.CHINESE: r""" + 任务:{instruction} + 原始计划:{plan} + 评估:{evaluation} + + 生成一个新的、改进的计划,解决评估中提到的所有问题。 + """, + LanguageType.ENGLISH: r""" + Task: {instruction} + Original Plan: {plan} + Evaluation: {evaluation} + + Generate a new, improved plan that solves all problems mentioned in the evaluation. + """, + } + """用户提示词""" + return system_prompt, user_prompt + + def __init__( + self, + system_prompt: dict[LanguageType, str] | None = None, + user_prompt: dict[LanguageType, str] | None = None, + ) -> None: """初始化Prompt""" super().__init__(system_prompt, user_prompt) async def generate(self, **kwargs) -> str: """生成重新规划结果""" + language = kwargs.get("language", LanguageType.CHINESE) messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.user_prompt.format( - instruction=kwargs["instruction"], - plan=kwargs["plan"], - evaluation=kwargs["evaluation"], - )}, + {"role": "system", "content": self.system_prompt[language]}, + { + "role": "user", + "content": self.user_prompt[language].format( + instruction=kwargs["instruction"], + plan=kwargs["plan"], + evaluation=kwargs["evaluation"], + ), + }, ] result = "" diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py index 15d52ab288b4c964e81ee290894bd29a590bbab2..3fc53f811d033a79c180671feb6d74ecca9add22 100644 --- a/apps/llm/patterns/rewrite.py +++ b/apps/llm/patterns/rewrite.py @@ -4,11 +4,16 @@ import logging from pydantic import BaseModel, Field +from textwrap import dedent + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment from apps.llm.function import JsonGenerator from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM from apps.llm.token import TokenCalculator +from apps.schemas.enum_var import LanguageType logger = logging.getLogger(__name__) @@ -18,82 +23,159 @@ class QuestionRewriteResult(BaseModel): question: str = Field(description="补全后的问题") +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, +) + + class QuestionRewrite(CorePattern): """问题补全与重写""" - system_prompt: str = r""" + def get_default_prompt(self) -> dict[LanguageType, str]: + system_prompt = { + LanguageType.CHINESE: dedent(r""" + + + 根据历史对话,推断用户的实际意图并补全用户的提问内容,历史对话被包含在标签中,用户意图被包含在标签中。 + 要求: + 1. 请使用JSON格式输出,参考下面给出的样例;不要包含任何XML标签,不要包含任何解释说明; + 2. 若用户当前提问内容与对话上文不相关,或你认为用户的提问内容已足够完整,请直接输出用户的提问内容。 + 3. 补全内容必须精准、恰当,不要编造任何内容。 + 4. 请输出补全后的问题,不要输出其他内容。 + 输出格式样例: + ```json + { + "question": "补全后的问题" + } + ``` + + + + + + + openEuler的特点是什么? + + + openEuler相较于其他操作系统,其特点是支持多种硬件架构,并且提供稳定、安全、高效的操作系统平台。 + + + + + openEuler的优势有哪些? + + + openEuler的优势包括开源、社区支持、以及对云计算和边缘计算的优化。 + + + + + + 详细点? + + + ```json + { + "question": "openEuler的特点是什么?请详细说明其优势和应用场景。" + } + ``` + + + + + {{history}} + + + {{question}} + + """), + LanguageType.ENGLISH: dedent(r""" + + + Based on the historical dialogue, infer the user's actual intent and complete the user's question. The historical dialogue is contained within the tags, and the user's intent is contained within the tags. + Requirements: + 1. 
Please output in JSON format, referring to the example provided below; do not include any XML tags or any explanatory notes; + 2. If the user's current question is unrelated to the previous dialogue or you believe the user's question is already complete enough, directly output the user's question. + 3. The completed content must be precise and appropriate; do not fabricate any content. + 4. Output only the completed question; do not include any other content. + Example output format: + ```json + { + "question": "The completed question" + } + ``` + + + + + + + What are the features of openEuler? + + + Compared to other operating systems, openEuler's features include support for multiple hardware architectures and providing a stable, secure, and efficient operating system platform. + + + + + What are the advantages of openEuler? + + + The advantages of openEuler include being open-source, having community support, and optimizations for cloud and edge computing. + + + + + + More details? + + + ```json + { + "question": "What are the features of openEuler? Please elaborate on its advantages and application scenarios." + } + + + + + {{history}} + + + {{question}} + + """) + } + + """用户提示词""" + user_prompt = { + LanguageType.CHINESE: r""" - - 根据历史对话,推断用户的实际意图并补全用户的提问内容,历史对话被包含在标签中,用户意图被包含在标签中。 - 要求: - 1. 请使用JSON格式输出,参考下面给出的样例;不要包含任何XML标签,不要包含任何解释说明; - 2. 若用户当前提问内容与对话上文不相关,或你认为用户的提问内容已足够完整,请直接输出用户的提问内容。 - 3. 补全内容必须精准、恰当,不要编造任何内容。 - 4. 请输出补全后的问题,不要输出其他内容。 - 输出格式样例: - {{ - "question": "补全后的问题" - }} - - - - - - - openEuler的特点是什么? - - - openEuler相较于其他操作系统,其特点是支持多种硬件架构,并且提供稳定、安全、高效的操作系统平台。 - - - - - openEuler的优势有哪些? - - - openEuler的优势包括开源、社区支持、以及对云计算和边缘计算的优化。 - - - - - - 详细点? - - - {{ - "question": "openEuler的特点是什么?请详细说明其优势和应用场景。" - }} - - + 请输出补全后的问题 - - {history} - - - {question} - - """ - """用户提示词""" - user_prompt: str = """ - - 请输出补全后的问题 - - """ + """, + LanguageType.ENGLISH: r""" + + Please output the completed question + + """} + return system_prompt, user_prompt async def generate(self, **kwargs) -> str: # noqa: ANN003 """问题补全与重写""" history = kwargs.get("history", []) question = kwargs["question"] llm = kwargs.get("llm", None) + language = kwargs.get("language", LanguageType.CHINESE) if not llm: llm = ReasoningLLM() leave_tokens = llm._config.max_tokens leave_tokens -= TokenCalculator().calculate_token_length( - messages=[ - {"role": "system", "content": self.system_prompt.format(history="", question=question)}, - {"role": "user", "content": self.user_prompt} - ] - ) + messages=[{"role": "system", "content": _env.from_string(self.system_prompt[language]).render( + history="", question=question)}, + {"role": "user", "content": _env.from_string(self.user_prompt[language]).render()}]) if leave_tokens <= 0: logger.error("[QuestionRewrite] 大模型上下文窗口不足,无法进行问题补全与重写") return question @@ -112,16 +194,17 @@ class QuestionRewrite(CorePattern): if leave_tokens >= 0: qa = sub_qa + qa index += 2 - messages = [ - {"role": "system", "content": self.system_prompt.format(history=qa, question=question)}, - {"role": "user", "content": self.user_prompt} - ] + messages = [{"role": "system", "content": _env.from_string(self.system_prompt[language]).render( + history=qa, question=question)}, {"role": "user", "content": _env.from_string(self.user_prompt[language]).render()}] result = "" async for chunk in llm.call(messages, streaming=False): result += chunk self.input_tokens = llm.input_tokens self.output_tokens = llm.output_tokens + tmp_js = await JsonGenerator._parse_result_by_stack(result, 
QuestionRewriteResult.model_json_schema()) + if tmp_js is not None: + return tmp_js['question'] messages += [{"role": "assistant", "content": result}] json_gen = JsonGenerator( query="根据给定的背景信息,生成预测问题", diff --git a/apps/llm/patterns/select.py b/apps/llm/patterns/select.py index a6c496bdd0ef79631bbdc41935e717ca3e3668ea..4d9ab3a08832b120e87a9e7f9ac0dde7980fee76 100644 --- a/apps/llm/patterns/select.py +++ b/apps/llm/patterns/select.py @@ -11,6 +11,7 @@ from apps.llm.function import JsonGenerator from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM from apps.llm.snippet import choices_to_prompt +from apps.schemas.enum_var import LanguageType logger = logging.getLogger(__name__) @@ -18,61 +19,123 @@ logger = logging.getLogger(__name__) class Select(CorePattern): """通过投票选择最佳答案""" - system_prompt: str = "You are a helpful assistant." - """系统提示词""" - - user_prompt: str = r""" - - - 根据历史对话(包括工具调用结果)和用户问题,从给出的选项列表中,选出最符合要求的那一项。 - 在输出之前,请先思考,并使用“”标签给出思考过程。 - 结果需要使用JSON格式输出,输出格式为:{{ "choice": "选项名称" }} - - - - - 使用天气API,查询明天杭州的天气信息 + def get_default_prompt(self) -> dict[LanguageType, str]: + system_prompt = { + LanguageType.CHINESE: "你是一个有用的助手。", + LanguageType.ENGLISH: "You are a helpful assistant.", + } + """系统提示词""" + + user_prompt = { + LanguageType.CHINESE: r""" + + + 根据历史对话(包括工具调用结果)和用户问题,从给出的选项列表中,选出最符合要求的那一项。 + 在输出之前,请先思考,并使用“”标签给出思考过程。 + 结果需要使用JSON格式输出,输出格式为:{{ "choice": "选项名称" }} + + + + + 使用天气API,查询明天杭州的天气信息 + + + + API + HTTP请求,获得返回的JSON数据 + + + SQL + 查询数据库,获得数据库表中的数据 + + + + + + API 工具可以通过 API 来获取外部数据,而天气信息可能就存储在外部数据中,由于用户说明中明确提到了 \ + 天气 API 的使用,因此应该优先使用 API 工具。\ + SQL 工具用于从数据库中获取信息,考虑到天气数据的可变性和动态性,不太可能存储在数据库中,因此 \ + SQL 工具的优先级相对较低,\ + 最佳选择似乎是“API:请求特定 API,获取返回的 JSON 数据”。 + + + + {{ "choice": "API" }} + + + + + + + {question} + + + + {choice_list} + + + + + 让我们一步一步思考。 + """, + LanguageType.ENGLISH: r""" + + + Based on the historical dialogue (including tool call results) and user question, select the most \ + suitable option from the given option list. + Before outputting, please think carefully and use the "" tag to give the thinking process. + The output needs to be in JSON format, the output format is: {{ "choice": "option name" }} + + + + + Use the weather API to query the weather information of Hangzhou tomorrow + + + + API + HTTP request, get the returned JSON data + + + SQL + Query the database, get the data in the database table + + + + + + The API tool can get external data through API, and the weather information may be stored in \ + external data. Since the user clearly mentioned the use of weather API, it should be given \ + priority to the API tool.\ + The SQL tool is used to get information from the database, considering the variability and \ + dynamism of weather data, it is unlikely to be stored in the database, so the priority of \ + the SQL tool is relatively low, \ + The best choice seems to be "API: request a specific API, get the returned JSON data". + + + + {{ "choice": "API" }} + + + + + + {question} + - - API - HTTP请求,获得返回的JSON数据 - - - SQL - 查询数据库,获得数据库表中的数据 - + {choice_list} - API 工具可以通过 API 来获取外部数据,而天气信息可能就存储在外部数据中,由于用户说明中明确提到了 \ - 天气 API 的使用,因此应该优先使用 API 工具。\ - SQL 工具用于从数据库中获取信息,考虑到天气数据的可变性和动态性,不太可能存储在数据库中,因此 \ - SQL 工具的优先级相对较低,\ - 最佳选择似乎是“API:请求特定 API,获取返回的 JSON 数据”。 + Let's think step by step. 
- - - {{ "choice": "API" }} - - - - - - - {question} - - - - {choice_list} - - - - - 让我们一步一步思考。 - """ - """用户提示词""" + + """, + } + """用户提示词""" + return system_prompt, user_prompt slot_schema: ClassVar[dict[str, Any]] = { "type": "object", @@ -86,17 +149,19 @@ class Select(CorePattern): } """最终输出的JSON Schema""" - - def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: - """初始化Prompt""" + def __init__( + self, + system_prompt: dict[LanguageType, str] | None = None, + user_prompt: dict[str, str] | None = None, + ) -> None: + """处理Prompt""" super().__init__(system_prompt, user_prompt) - async def _generate_single_attempt(self, user_input: str, choice_list: list[str]) -> str: """使用ReasoningLLM进行单次尝试""" logger.info("[Select] 单次选择尝试: %s", user_input) messages = [ - {"role": "system", "content": self.system_prompt}, + {"role": "system", "content": self.system_prompt[self.language]}, {"role": "user", "content": user_input}, ] result = "" @@ -120,7 +185,6 @@ class Select(CorePattern): function_result = await json_gen.generate() return function_result["choice"] - async def generate(self, **kwargs) -> str: # noqa: ANN003 """使用大模型做出选择""" logger.info("[Select] 使用LLM选择") @@ -128,6 +192,7 @@ class Select(CorePattern): result_list = [] background = kwargs.get("background", "无背景信息。") + language = kwargs.get("language", LanguageType.CHINESE) data_str = json.dumps(kwargs.get("data", {}), ensure_ascii=False) choice_prompt, choices_list = choices_to_prompt(kwargs["choices"]) @@ -141,7 +206,7 @@ class Select(CorePattern): return choices_list[0] logger.info("[Select] 选项列表: %s", choice_prompt) - user_input = self.user_prompt.format( + user_input = self.user_prompt[language].format( question=kwargs["question"], background=background, data=data_str, diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py index 7cf555be390bc72301323cd103c1a7ec0ab7a085..40fc015881b908cdf24c4fc8d2ac9c84b0efec73 100644 --- a/apps/llm/prompt.py +++ b/apps/llm/prompt.py @@ -19,7 +19,7 @@ JSON_GEN_BASIC = dedent(r""" Background information is given in XML tags. - Here are the conversations between you and the user: + Here are the background information between you and the user: {% if conversation|length > 0 %} {% for message in conversation %} @@ -48,7 +48,6 @@ JSON_GEN_BASIC = dedent(r""" {% endif %} - {% if not function_call %} # Tools You must call one function to assist with the user query. 
@@ -67,5 +66,4 @@ JSON_GEN_BASIC = dedent(r""" # Output - {% endif %} """) diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py index fdb36fc05adf38920bcce0d962b6aafc21e44b71..fddc84db89d98bb8c29d96b4dee71aeed49804ea 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/reasoning.py @@ -61,15 +61,16 @@ class ReasoningContent: return reason, text if self.reasoning_type == "args": - if hasattr(chunk.choices[0].delta, "reasoning_content"): + if hasattr( + chunk.choices[0].delta, "reasoning_content") and chunk.choices[0].delta.reasoning_content is not None: # type: ignore[attr-defined] + # 仍在推理中,继续添加推理内容 reason = chunk.choices[0].delta.reasoning_content or "" # type: ignore[attr-defined] else: # 推理结束,设置标志并添加结束标签 self.is_reasoning = False reason = "" # 如果当前内容不是推理内容标签,将其作为文本返回 - if content and not content.startswith(""): - text = content + text = content.lstrip("") elif self.reasoning_type == "tokens": for token in REASONING_END_TOKEN: if token == content: @@ -141,10 +142,11 @@ class ReasoningLLM: return await self._client.chat.completions.create( model=model, messages=messages, # type: ignore[] - max_tokens=max_tokens or self._config.max_tokens, + max_completion_tokens=max_tokens or self._config.max_tokens, temperature=temperature or self._config.temperature, stream=True, stream_options={"include_usage": True}, + timeout=300 ) # type: ignore[] async def call( # noqa: C901, PLR0912, PLR0913 diff --git a/apps/llm/token.py b/apps/llm/token.py index 532c575f71dbce95eaa55c99f8a2c956053c8dbc..3536da2b4bcc50d86b7b441f769a4a8b02dd027c 100644 --- a/apps/llm/token.py +++ b/apps/llm/token.py @@ -9,6 +9,33 @@ logger = logging.getLogger(__name__) class TokenCalculator(metaclass=SingletonMeta): """用于计算Token消耗量""" + @staticmethod + def get_k_tokens_words_from_content(content: str, k: int | None = None) -> str: + """获取k个token的词""" + if k is None: + return content + if k <= 0: + return "" + try: + if TokenCalculator().calculate_token_length(messages=[ + {"role": "user", "content": content}, + ], pure_text=True) <= k: + return content + l = 0 + r = len(content) + while l + 1 < r: + mid = (l + r) // 2 + if TokenCalculator().calculate_token_length(messages=[ + {"role": "user", "content": content[:mid]}, + ], pure_text=True) <= k: + l = mid + else: + r = mid + return content[:l] + except Exception: + logger.exception("[RAG] 获取k个token的词失败") + return "" + def __init__(self) -> None: """初始化Tokenizer""" import tiktoken diff --git a/apps/main.py b/apps/main.py index cdc844ab4bcb9e56717c3de4f6a769335bf6fd7d..1389af0ed51bbbf0741f0fcad711c2b6af87b018 100644 --- a/apps/main.py +++ b/apps/main.py @@ -8,7 +8,6 @@ from __future__ import annotations import asyncio import logging -import logging.config import signal import sys from contextlib import asynccontextmanager @@ -135,17 +134,52 @@ logging.basicConfig( width=160, ))], ) +logger = logging.getLogger(__name__) +async def add_no_auth_user() -> None: + """ + 添加无认证用户 + """ + from apps.common.mongo import MongoDB + from apps.schemas.collection import User + import os + mongo = MongoDB() + user_collection = mongo.get_collection("user") + username = os.environ.get('USERNAME') # 适用于 Windows 系统 + if not username: + username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统 + if not username: + username = "admin" + try: + await user_collection.insert_one(User( + _id=username, + is_admin=True, + auto_execute=False + ).model_dump(by_alias=True)) + except Exception as e: + logger.error(f"[add_no_auth_user] 默认用户 {username} 已存在") + +async def clear_user_activity() -> None: + 
"""清除所有用户的活跃状态""" + from apps.services.activity import Activity + from apps.common.mongo import MongoDB + mongo = MongoDB() + activity_collection = mongo.get_collection("activity") + await activity_collection.delete_many({}) + logging.info("清除所有用户活跃状态完成") async def init_resources() -> None: """初始化必要资源""" - logger = logging.getLogger(__name__) WordsCheck() await LanceDB().init() await Pool.init() TokenCalculator() + if Config().get_config().no_auth.enable: + await add_no_auth_user() + await clear_user_activity() + # 初始化变量池管理器 from apps.scheduler.variable.pool_manager import initialize_pool_manager await initialize_pool_manager() @@ -174,7 +208,6 @@ async def init_resources() -> None: except Exception as e: logging.warning(f"前置节点变量缓存服务初始化失败(将降级使用实时解析): {e}") - async def startup_file_cleanup(): """启动时清理遗留文件(除了已绑定历史记录的文件)""" logger = logging.getLogger(__name__) @@ -258,7 +291,6 @@ async def startup_file_cleanup(): except Exception as e: logger.error(f"启动时文件清理失败: {e}") - async def cleanup_orphaned_files(): """清理孤儿文件(不被任何变量引用且未绑定历史记录的文件)""" logger = logging.getLogger(__name__) diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 0ec4db9155a06fe16cbeddff223ccd9a764d8b3d..ca2d93d5d0e11f69713a4130262de08d8191fd20 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -13,6 +13,8 @@ from apps.schemas.appcenter import AppFlowInfo, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType from apps.schemas.request_data import CreateAppRequest, ModFavAppRequest from apps.schemas.response_data import ( + AppMcpServiceInfo, + LLMIteam, BaseAppOperationMsg, BaseAppOperationRsp, GetAppListMsg, @@ -25,7 +27,8 @@ from apps.schemas.response_data import ( ResponseData, ) from apps.services.appcenter import AppCenterManager - +from apps.services.llm import LLMManager +from apps.services.mcp_service import MCPServiceManager logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/app", @@ -180,6 +183,7 @@ async def get_recently_used_applications( @router.get("/{appId}", response_model=GetAppPropertyRsp | ResponseData) async def get_application( + user_sub: Annotated[str, Depends(get_user)], app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], ) -> JSONResponse: """获取应用详情""" @@ -214,6 +218,24 @@ async def get_application( ) for flow in app_data.flows ] + mcp_service = [] + if app_data.mcp_service: + for service in app_data.mcp_service: + mcp_collection = await MCPServiceManager.get_mcp_service(service) + mcp_service.append(AppMcpServiceInfo( + id=mcp_collection.id, + name=mcp_collection.name, + description=mcp_collection.description, + )) + if app_data.llm_id == "empty": + llm_item = LLMIteam() + else: + llm_collection = await LLMManager.get_llm_by_id(user_sub, app_data.llm_id) + llm_item = LLMIteam( + llmId=llm_collection.id, + modelName=llm_collection.model_name, + icon=llm_collection.icon + ) return JSONResponse( status_code=status.HTTP_200_OK, content=GetAppPropertyRsp( @@ -234,7 +256,8 @@ async def get_application( authorizedUsers=app_data.permission.users, ), workflows=workflows, - mcpService=app_data.mcp_service, + mcpService=mcp_service, + llm=llm_item, ), ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 42e770ae55e99ee81f51afdbccf320e83c249363..c2d1a6a741b6ef3989f45019db9c9fd581a5f25a 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -74,7 +74,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: 
status_code=status.HTTP_403_FORBIDDEN, ) - await UserManager.update_userinfo_by_user_sub(user_sub) + await UserManager.update_refresh_revision_by_user_sub(user_sub) current_session = await SessionManager.create_session(user_host, user_sub) @@ -177,6 +177,7 @@ async def userinfo( user_sub=user_sub, revision=user.is_active, is_admin=user.is_admin, + auto_execute=user.auto_execute, ), ).model_dump(exclude_none=True, by_alias=True), ) @@ -192,7 +193,7 @@ async def userinfo( ) async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001 """更新用户协议信息""" - ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True) + ret: bool = await UserManager.update_refresh_revision_by_user_sub(user_sub, refresh_revision=True) if not ret: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 5bca9336a64655a362393c659216667c7b0ef3d7..9391f021785bea6b660309d4663144da39324a19 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -7,20 +7,24 @@ import uuid from collections.abc import AsyncGenerator from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue from apps.common.wordscheck import WordsCheck from apps.dependency import get_session, get_user +from apps.schemas.enum_var import FlowStatus from apps.scheduler.scheduler import Scheduler from apps.scheduler.scheduler.context import save_data -from apps.schemas.request_data import RequestData +from apps.schemas.request_data import RequestData, RequestDataApp from apps.schemas.response_data import ResponseData +from apps.schemas.enum_var import LanguageType from apps.schemas.task import Task from apps.services.activity import Activity from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager from apps.services.flow import FlowManager +from apps.services.conversation import ConversationManager +from apps.services.record import RecordManager from apps.services.task import TaskManager from apps.services.appcenter import AppCenterManager from apps.scheduler.variable.pool_manager import get_pool_manager @@ -146,32 +150,50 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - if post_body.new_task: - # 创建或还原Task - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) - if task: - await TaskManager.delete_task_by_task_id(task.id) - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + # 更改信息并刷新数据库 - if post_body.new_task: + if post_body.task_id is None: + conversation = await ConversationManager.get_conversation_by_conversation_id( + user_sub=user_sub, + conversation_id=post_body.conversation_id, + ) + if not conversation: + err = "[Chat] 用户没有权限访问该对话!" 
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err) + task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id) + await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids) + task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body) task.runtime.question = post_body.question task.ids.group_id = post_body.group_id + task.state.app_id = post_body.app.app_id if post_body.app else "" + else: + if not post_body.task_id: + err = "[Chat] task_id 不可为空!" + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty") + task = await TaskManager.get_task_by_task_id(post_body.task_id) + post_body.app = RequestDataApp(appId=task.state.app_id) + post_body.group_id = task.ids.group_id + post_body.conversation_id = task.ids.conversation_id + post_body.language = task.language + post_body.question = task.runtime.question + task.language = post_body.language return task async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" try: - await Activity.set_active(user_sub) + active_id = await Activity.set_active(user_sub) # 敏感词检查 if await WordsCheck().check(post_body.question) != 1: yield "data: [SENSITIVE]\n\n" logger.info("[Chat] 问题包含敏感词!") - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) return task = await init_task(post_body, user_sub, session_id) + task.ids.active_id = active_id # 检查必填文件变量 flow_id_for_check = None @@ -224,23 +246,17 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 获取最终答案 task = scheduler.task - # 🔑 修复:对于工作流调试模式或纯逻辑节点,允许答案为空 - is_flow_debug = post_body.app and post_body.app.flow_id - if not task.runtime.answer and not is_flow_debug: - logger.error("[Chat] 答案为空且非工作流调试模式") + if task.state.flow_status == FlowStatus.ERROR: + logger.error("[Chat] 生成答案失败") yield "data: [ERROR]\n\n" - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) return - elif not task.runtime.answer and is_flow_debug: - logger.info("[Chat] 工作流调试模式,答案为空是正常的(可能是纯逻辑节点)") - # 为工作流调试提供默认响应 - task.runtime.answer = "工作流执行完成" # 对结果进行敏感词检查 if await WordsCheck().check(task.runtime.answer) != 1: yield "data: [SENSITIVE]\n\n" logger.info("[Chat] 答案包含敏感词!") - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) return # 创建新Record,存入数据库 @@ -260,7 +276,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) yield "data: [ERROR]\n\n" finally: - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) @router.post("/chat") @@ -270,16 +286,13 @@ async def chat( session_id: Annotated[str, Depends(get_session)], ) -> StreamingResponse: """LLM流式对话接口""" + post_body.language = LanguageType.CHINESE if post_body.language in {"zh", LanguageType.CHINESE} else LanguageType.ENGLISH # 前端 Flow-Debug 传输为“zh" # 问题黑名单检测 - if not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question): + if post_body.question is not None and not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question): # 用户扣分 await UserBlacklistManager.change_blacklisted_users(user_sub, -10) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") - # 限流检查 - if await Activity.is_active(user_sub): - raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, 
detail="Too many requests") - res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, @@ -291,9 +304,12 @@ async def chat( @router.post("/stop", response_model=ResponseData) -async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 +async def stop_generation(user_sub: Annotated[str, Depends(get_user)], + task_id: Annotated[str, Query(..., alias="taskId")] = "") -> JSONResponse: """停止生成""" - await Activity.remove_active(user_sub) + task = await TaskManager.get_task_by_task_id(task_id) + if task: + await Activity.remove_active(task.ids.active_id) return JSONResponse( status_code=status.HTTP_200_OK, content=ResponseData( diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 1620ab5421a4018fd316a6d55939d859befd5282..c152b427133222b302a75a2199932e8dc20ffc49 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -44,6 +44,7 @@ logger = logging.getLogger(__name__) async def create_new_conversation( + title: str, user_sub: str, app_id: str = "", llm_id: str = "empty", @@ -57,7 +58,8 @@ async def create_new_conversation( err = "Invalid app_id." raise RuntimeError(err) new_conv = await ConversationManager.add_conversation_by_user_sub( - user_sub, + title=title, + user_sub=user_sub, app_id=app_id, llm_id=llm_id, kb_ids=kb_ids or [], @@ -127,6 +129,7 @@ async def get_conversation_list(user_sub: Annotated[str, Depends(get_user)]) -> async def add_conversation( user_sub: Annotated[str, Depends(get_user)], app_id: Annotated[str, Query(..., alias="appId")] = "", + title: Annotated[str, Body(...)] = "New Chat", llm_id: Annotated[str, Body(..., alias="llmId")] = "empty", kb_ids: Annotated[list[str] | None, Body(..., alias="kbIds")] = None, *, @@ -138,7 +141,8 @@ async def add_conversation( app_id = app_id if app_id else "" debug = debug if debug is not None else False new_conv = await create_new_conversation( - user_sub, + title=title, + user_sub=user_sub, app_id=app_id, llm_id=llm_id, kb_ids=kb_ids or [], @@ -185,15 +189,15 @@ async def update_conversation( ) # 更新Conversation数据 - change_status = await ConversationManager.update_conversation_by_conversation_id( - user_sub, - conversation_id, - { - "title": post_body.title, - }, - ) - - if not change_status: + try: + await ConversationManager.update_conversation_by_conversation_id( + user_sub, + conversation_id, + { + "title": post_body.title, + }, + ) + except Exception as e: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( diff --git a/apps/routers/flow.py b/apps/routers/flow.py index 256321ac6b4781e9cfeeb2f8506daff0e2bad056..620777d769ac3383f7ca45eced6ba0f3566929b8 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -21,6 +21,7 @@ from apps.schemas.response_data import ( NodeServiceListRsp, ResponseData, ) +from apps.schemas.enum_var import LanguageType from apps.services.appcenter import AppCenterManager from apps.services.application import AppManager from apps.services.flow import FlowManager @@ -46,9 +47,10 @@ router = APIRouter( ) async def get_services( user_sub: Annotated[str, Depends(get_user)], + language: LanguageType = Query(LanguageType.CHINESE, description="语言参数,默认为中文") ) -> NodeServiceListRsp: """获取用户可访问的节点元数据所在服务的信息""" - services = await FlowManager.get_service_by_user_id(user_sub) + services = await FlowManager.get_service_by_user_id(user_sub, language) if services is None: return NodeServiceListRsp( code=status.HTTP_404_NOT_FOUND, diff --git 
a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index a845a3761ea2ab179a1b71d4674d0d5fd3b5a760..b238a5cda115e880d75ea91f6aa90a34a0360815 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -36,6 +36,7 @@ router = APIRouter( dependencies=[Depends(verify_user)], ) + async def _check_user_admin(user_sub: str) -> None: user = await UserManager.get_userinfo_by_user_sub(user_sub) if not user: @@ -52,6 +53,8 @@ async def get_mcpservice_list( ] = SearchType.ALL, keyword: Annotated[str | None, Query(..., alias="keyword", description="搜索关键字")] = None, page: Annotated[int, Query(..., alias="page", ge=1, description="页码")] = 1, + is_install: Annotated[bool | None, Query(..., alias="isInstall", description="是否已安装")] = None, + is_active: Annotated[bool | None, Query(..., alias="isActive", description="是否激活")] = None, ) -> JSONResponse: """获取服务列表""" try: @@ -60,6 +63,8 @@ async def get_mcpservice_list( user_sub, keyword, page, + is_install, + is_active ) except Exception as e: err = f"[MCPServiceCenter] 获取MCP服务列表失败: {e}" @@ -89,7 +94,7 @@ async def get_mcpservice_list( @router.post("", response_model=UpdateMCPServiceRsp) async def create_or_update_mcpservice( user_sub: Annotated[str, Depends(get_user)], # TODO: get_user直接获取所有用户信息 - data: UpdateMCPServiceRequest, + data: UpdateMCPServiceRequest ) -> JSONResponse: """新建或更新MCP服务""" await _check_user_admin(user_sub) @@ -130,6 +135,35 @@ async def create_or_update_mcpservice( ).model_dump(exclude_none=True, by_alias=True)) +@router.post("/{serviceId}/install") +async def install_mcp_service( + user_sub: Annotated[str, Depends(get_user)], + service_id: Annotated[str, Path(..., alias="serviceId", description="服务ID")], + install: Annotated[bool, Query(..., description="是否安装")] = True, +) -> JSONResponse: + try: + await MCPServiceManager.install_mcpservice(user_sub, service_id, install) + except Exception as e: + err = f"[MCPService] 安装mcp服务失败: {e!s}" if install else f"[MCPService] 卸载mcp服务失败: {e!s}" + logger.exception(err) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=err, + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="OK", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + + @router.get("/{serviceId}", response_model=GetMCPServiceDetailRsp) async def get_service_detail( user_sub: Annotated[str, Depends(get_user)], @@ -166,11 +200,8 @@ async def get_service_detail( name=data.name, description=data.description, overview=config.overview, - data=json.dumps( - config.config.model_dump(by_alias=True, exclude_none=True), - indent=4, - ensure_ascii=False, - ), + data=config.config.model_dump( + exclude_none=True, by_alias=True), mcpType=config.type, ) else: @@ -181,6 +212,7 @@ async def get_service_detail( name=data.name, description=data.description, overview=config.overview, + status=data.status, tools=data.tools, ) @@ -225,7 +257,7 @@ async def delete_service( ) -@router.post("/icon", response_model=UpdateMCPServiceRsp) +@router.post("/icon/{serviceId}", response_model=UpdateMCPServiceRsp) async def update_mcp_icon( user_sub: Annotated[str, Depends(get_user)], service_id: Annotated[str, Path(..., alias="serviceId", description="服务ID")], @@ -282,7 +314,7 @@ async def active_or_deactivate_mcp_service( """激活/取消激活mcp""" try: if data.active: - await 
MCPServiceManager.active_mcpservice(user_sub, service_id) + await MCPServiceManager.active_mcpservice(user_sub, service_id, data.mcp_env) else: await MCPServiceManager.deactive_mcpservice(user_sub, service_id) except Exception as e: diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..6edbe2e142cb6589be8947c28ae2eb4a7287baa1 --- /dev/null +++ b/apps/routers/parameter.py @@ -0,0 +1,77 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, status +from fastapi.responses import JSONResponse + +from apps.dependency import get_user +from apps.dependency.user import verify_user +from apps.services.parameter import ParameterManager +from apps.schemas.response_data import ( + GetOperaRsp, + GetParamsRsp +) +from apps.services.application import AppManager +from apps.services.flow import FlowManager + +router = APIRouter( + prefix="/api/parameter", + tags=["parameter"], + dependencies=[ + Depends(verify_user), + ], +) + + +@router.get("", response_model=GetParamsRsp) +async def get_parameters( + user_sub: Annotated[str, Depends(get_user)], + app_id: Annotated[str, Query(alias="appId")], + flow_id: Annotated[str, Query(alias="flowId")], + step_id: Annotated[str, Query(alias="stepId")], +) -> JSONResponse: + """Get parameters for node choice.""" + if not await AppManager.validate_user_app_access(user_sub, app_id): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content=GetParamsRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + if not flow: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=GetParamsRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetParamsRsp( + code=status.HTTP_200_OK, + message="获取参数成功", + result=result + ).model_dump(exclude_none=True, by_alias=True), + ) + + +@router.get("/operate", response_model=GetOperaRsp) +async def get_operate_parameters( + user_sub: Annotated[str, Depends(get_user)], + param_type: Annotated[str, Query(alias="ParamType")], +) -> JSONResponse: + """Get parameters for node choice.""" + result = await ParameterManager.get_operate_and_bind_type(param_type) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetOperaRsp( + code=status.HTTP_200_OK, + message="获取操作成功", + result=result + ).model_dump(exclude_none=True, by_alias=True), + ) diff --git a/apps/routers/record.py b/apps/routers/record.py index 7384793b4b2abea5117462cb630c5b13c8f76071..663708b86b2dc57a0973c8b73a3d4cb02298b08b 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -65,7 +65,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record = RecordData( id=record.id, groupId=record_group.id, - taskId=record_group.task_id, + taskId=record.task_id, conversationId=conversation_id, content=record_data, metadata=record.metadata @@ -81,26 +81,27 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ # 获得Record关联的文档 tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) - # 获得Record关联的flow数据 - flow_list = await 
TaskManager.get_context_by_record_id(record_group.id, record.id)
-        if flow_list:
-            first_flow = FlowStepHistory.model_validate(flow_list[0])
+        flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id)
+        if flow_step_list:
             tmp_record.flow = RecordFlow(
-                id=first_flow.flow_name, #TODO: 此处前端应该用name
+                id=record.flow.flow_id,  # TODO: 此处前端应该用name
                 recordId=record.id,
-                flowId=first_flow.id,
-                stepNum=len(flow_list),
+                flowId=record.flow.flow_id,
+                flowName=record.flow.flow_name,
+                flowStatus=record.flow.flow_status,
+                stepNum=len(flow_step_list),
                 steps=[],
             )
-            for flow in flow_list:
-                flow_step = FlowStepHistory.model_validate(flow)
+            for flow_step in flow_step_list:
                 tmp_record.flow.steps.append(
                     RecordFlowStep(
-                        stepId=flow_step.step_name, #TODO: 此处前端应该用name
-                        stepStatus=flow_step.status,
+                        stepId=flow_step.step_id,
+                        stepName=flow_step.step_name,
+                        stepStatus=flow_step.step_status,
                         input=flow_step.input_data,
                         output=flow_step.output_data,
+                        exData=flow_step.ex_data
                     ),
                 )
diff --git a/apps/routers/user.py b/apps/routers/user.py
index 54e12f444181b56e408f7face5c4ed37be76008b..537f1bf3adb95e8b48ef1b3d384376bae86ac802 100644
--- a/apps/routers/user.py
+++ b/apps/routers/user.py
@@ -3,10 +3,11 @@ from typing import Annotated
-from fastapi import APIRouter, Depends, status
+from fastapi import APIRouter, Body, Depends, status, Query
 from fastapi.responses import JSONResponse
 
 from apps.dependency import get_user
+from apps.schemas.request_data import UserUpdateRequest
 from apps.schemas.response_data import UserGetMsp, UserGetRsp
 from apps.schemas.user import UserInfo
 from apps.services.user import UserManager
@@ -17,12 +18,14 @@ router = APIRouter(
 )
 
-@router.get("")
-async def chat(
+@router.get("", response_model=UserGetRsp)
+async def get_user_sub(
     user_sub: Annotated[str, Depends(get_user)],
+    page_size: Annotated[int, Query(description="每页用户数量")] = 20,
+    page_cnt: Annotated[int, Query(description="当前页码")] = 1,
 ) -> JSONResponse:
     """查询所有用户接口"""
-    user_list = await UserManager.get_all_user_sub()
+    user_list, total = await UserManager.get_all_user_sub(page_cnt=page_cnt, page_size=page_size, filter_user_subs=[user_sub])
     user_info_list = []
     for user in user_list:
         # user_info = await UserManager.get_userinfo_by_user_sub(user) 暂时不需要查询user_name
@@ -39,6 +42,22 @@ async def chat(
         content=UserGetRsp(
             code=status.HTTP_200_OK,
             message="用户数据详细信息获取成功",
-            result=UserGetMsp(userInfoList=user_info_list),
+            result=UserGetMsp(userInfoList=user_info_list, total=total),
         ).model_dump(exclude_none=True, by_alias=True),
     )
+
+
+@router.post("")
+async def update_user_info(
+    user_sub: Annotated[str, Depends(get_user)],
+    *,
+    data: Annotated[UserUpdateRequest, Body(..., description="用户更新信息")],
+) -> JSONResponse:
+    """更新用户信息接口"""
+    # 更新用户信息
+
+    await UserManager.update_userinfo_by_user_sub(user_sub, data)
+    return JSONResponse(
+        status_code=status.HTTP_200_OK,
+        content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"},
+    )
diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py
index 1aec22c82b4754e34755d346d863e09f08f8092a..47337d385d1b525f405af42fdb65a059ae994aa6 100644
--- a/apps/scheduler/call/api/api.py
+++ b/apps/scheduler/call/api/api.py
@@ -5,7 +5,7 @@ import json
 import logging
 from collections.abc import AsyncGenerator
 from functools import partial
-from typing import Any
+from typing import Any, ClassVar
 
 import httpx
 from fastapi import status
@@ -15,7 +15,7 @@ from pydantic.json_schema import SkipJsonSchema
 
 from apps.common.oidc import oidc_provider
 from 
apps.scheduler.call.api.schema import APIInput, APIOutput from apps.scheduler.call.core import CoreCall -from apps.schemas.enum_var import CallOutputType, CallType, ContentType, HTTPMethod +from apps.schemas.enum_var import CallOutputType, CallType, ContentType, HTTPMethod, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -59,14 +59,18 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): body: dict[str, Any] = Field(description="已知的部分请求体", default={}) query: dict[str, Any] = Field(description="已知的部分请求参数", default={}) - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="API调用", - type=CallType.TOOL, - description="向某一个API接口发送HTTP请求,获取数据。" - ) + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "API调用", + "type": CallType.TOOL, + "description": "向某一个API接口发送HTTP请求,获取数据", + }, + LanguageType.ENGLISH: { + "name": "API Call", + "type": CallType.TOOL, + "description": "Send an HTTP request to an API to obtain data", + }, + } async def _init(self, call_vars: CallVars) -> APIInput: """初始化API调用工具""" @@ -103,7 +107,9 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): body=self.body, ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """调用API,然后返回LLM解析后的数据""" self._client = httpx.AsyncClient(timeout=self.timeout) input_obj = APIInput.model_validate(input_data) @@ -116,7 +122,9 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): finally: await self._client.aclose() - async def _make_api_call(self, data: APIInput, files: dict[str, tuple[str, bytes, str]]) -> httpx.Response: + async def _make_api_call( + self, data: APIInput, files: dict[str, tuple[str, bytes, str]] + ) -> httpx.Response: """组装API请求""" # 获取必要参数 if self._auth: diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 50485105b8d814818801b313e72d90d4d46d418b..27184963730898d4d180c88976e18cdefaef4e43 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -5,7 +5,7 @@ import ast import copy import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar from pydantic import Field @@ -20,7 +20,7 @@ from apps.scheduler.call.choice.schema import ( ) from apps.schemas.parameters import ValueType from apps.scheduler.call.core import CoreCall -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -35,17 +35,23 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" to_user: bool = Field(default=False) - choices: list[ChoiceBranch] = Field(description="分支", default=[ChoiceBranch(), - ChoiceBranch(conditions=[Condition()], is_default=False)]) + controlled_output: bool = Field(default=True) + choices: list[ChoiceBranch] = Field( + description="分支", default=[ChoiceBranch(), ChoiceBranch(conditions=[Condition()], is_default=False)] + ) + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "条件分支", + "type": CallType.LOGIC, + "description": "使用大模型或使用程序做出判断", + }, + LanguageType.ENGLISH: { + "name": "Choice", + "type": CallType.LOGIC, + "description": "Use a large model or a program to make a 
decision", + }, - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="条件分支", - type=CallType.LOGIC, - description="使用大模型或使用程序做出条件判断,决定后续分支" - ) + } def _validate_branch_logic(self, choice: ChoiceBranch) -> bool: """验证分支的逻辑运算符是否有效 diff --git a/apps/scheduler/call/cmd/cmd.py b/apps/scheduler/call/cmd/cmd.py index 7f9d9f8c4f47e1c0cf3e4bcceea1b619014598fa..d12748814a4de48852e7467707d3d7634d3e6f9a 100644 --- a/apps/scheduler/call/cmd/cmd.py +++ b/apps/scheduler/call/cmd/cmd.py @@ -1,11 +1,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """自然语言生成命令""" -from typing import Any +from typing import Any, ClassVar + from pydantic import BaseModel, Field from apps.scheduler.call.core import CoreCall +from apps.schemas.enum_var import CallType, LanguageType class _CmdParams(BaseModel): @@ -19,12 +21,21 @@ class _CmdOutput(BaseModel): """Cmd工具的输出""" - class Cmd(CoreCall): """Cmd工具。用于根据BTDL描述文件,生成命令。""" - name: str = "cmd" - description: str = "根据BTDL描述文件,生成命令。" + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "命令生成", + "type": CallType.TOOL, + "description": "根据BTDL描述文件,生成命令", + }, + LanguageType.ENGLISH: { + "name": "Command Generation", + "type": CallType.TOOL, + "description": "Generate commands based on BTDL description files", + }, + } async def _exec(self, _slot_data: dict[str, Any]) -> _CmdOutput: """调用Cmd工具""" diff --git a/apps/scheduler/call/code/code.py b/apps/scheduler/call/code/code.py index 5a331941eb94eeab47e322d7865ad7f5ec21e9af..fe6095a016633d0a9e533c17eda2316db8583175 100644 --- a/apps/scheduler/call/code/code.py +++ b/apps/scheduler/call/code/code.py @@ -3,7 +3,7 @@ import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar import httpx from pydantic import Field @@ -11,7 +11,7 @@ from pydantic import Field from apps.common.config import Config from apps.scheduler.call.code.schema import CodeInput, CodeOutput from apps.scheduler.call.core import CoreCall -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -26,6 +26,7 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): """代码执行工具""" to_user: bool = Field(default=True) + controlled_output: bool = Field(default=True) # 代码执行参数 code: str = Field(description="要执行的代码", default="") @@ -36,17 +37,19 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0) input_parameters: dict[str, Any] = Field(description="输入参数配置", default={}) output_parameters: dict[str, Any] = Field(description="输出参数配置", default={}) - - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="代码执行", - type=CallType.TRANSFORM, - description="在安全的沙箱环境中执行Python、JavaScript、Bash代码。" - ) - + + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "代码执行", + "type": CallType.TOOL, + "description": "在安全的沙箱环境中执行Python、JavaScript、Bash代码", + }, + LanguageType.ENGLISH: { + "name": "Code", + "type": CallType.TOOL, + "description": "Executing Python, JavaScript, Bash in secure sandbox", + }, + } async def _init(self, call_vars: CallVars) -> CodeInput: """初始化代码执行工具""" diff --git a/apps/scheduler/call/convert/convert.py b/apps/scheduler/call/convert/convert.py index 
bbe0dbe80217ba5e24c629fd02279086e08230fc..6375ed8907dc2da79f24edff2b16ccdebb5e975e 100644 --- a/apps/scheduler/call/convert/convert.py +++ b/apps/scheduler/call/convert/convert.py @@ -3,7 +3,7 @@ from collections.abc import AsyncGenerator from datetime import datetime -from typing import Any +from typing import Any, ClassVar import pytz from jinja2 import BaseLoader @@ -12,7 +12,7 @@ from pydantic import Field from apps.scheduler.call.convert.schema import ConvertInput, ConvertOutput from apps.scheduler.call.core import CallOutputChunk, CoreCall -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallInfo, CallOutputChunk, @@ -25,16 +25,18 @@ class Convert(CoreCall, input_model=ConvertInput, output_model=ConvertOutput): text_template: str | None = Field(description="自然语言信息的格式化模板,jinja2语法", default=None) data_template: str | None = Field(description="原始数据的格式化模板,jinja2语法", default=None) - - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="模板转换", - type=CallType.TRANSFORM, - description="使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。" - ) + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "转换工具", + "type": CallType.TRANSFORM, + "description": "提取或格式化Step输出", + }, + LanguageType.ENGLISH: { + "name": "Convert Tool", + "type": CallType.TRANSFORM, + "description": "Extract or format Step output", + }, + } async def _init(self, call_vars: CallVars) -> ConvertInput: """初始化工具""" diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index c1a65d1615688c2f5a128e9dd5317404546bef34..12bf9dff0aa0740fc5a2760f23bebb9a80b90699 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -16,7 +16,7 @@ from pydantic.json_schema import SkipJsonSchema from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM from apps.scheduler.variable.integration import VariableIntegration -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, LanguageType from apps.schemas.pool import NodePool from apps.schemas.parameters import ValueType from apps.schemas.scheduler import ( @@ -56,7 +56,9 @@ class CoreCall(BaseModel): name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True) description: SkipJsonSchema[str] = Field(description="Step的描述", exclude=True) node: SkipJsonSchema[NodePool | None] = Field(description="节点信息", exclude=True) - enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False, exclude=True) + enable_filling: SkipJsonSchema[bool] = Field( + description="是否需要进行自动参数填充", default=False, exclude=True + ) tokens: SkipJsonSchema[CallTokens] = Field( description="Call的输入输出Tokens信息", default=CallTokens(), @@ -72,27 +74,35 @@ class CoreCall(BaseModel): exclude=True, frozen=True, ) - to_user: bool = Field(description="是否需要将输出返回给用户", default=False) enable_variable_resolution: bool = Field(description="是否启用自动变量解析", default=True) + controlled_output: bool = Field(description="是否允许用户定义输出参数", default=False) + i18n_info: ClassVar[SkipJsonSchema[dict[str, dict]]] = {} model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", ) - def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None: + @classmethod + def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo: + """ + 返回Call的名称和描述 + + :return: Call的名称和描述 + :rtype: CallInfo 
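+        :param language: 返回名称与描述所使用的语言,默认为中文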
+ """ + lang_info = cls.i18n_info.get(language, cls.i18n_info[LanguageType.CHINESE]) + return CallInfo(name=lang_info["name"], type=lang_info["type"], description=lang_info["description"]) + + def __init_subclass__( + cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any + ) -> None: """初始化子类""" super().__init_subclass__(**kwargs) cls.input_model = input_model cls.output_model = output_model - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - err = "[CoreCall] 必须手动实现info方法" - raise NotImplementedError(err) - @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" @@ -133,7 +143,7 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") - if len(split_path) < 2: + if len(split_path) < 1: err = f"[CoreCall] 路径格式错误: {path}" logger.error(err) return None @@ -146,13 +156,7 @@ class CoreCall(BaseModel): if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" logger.error(err) - raise CallError( - message=err, - data={ - "step_id": split_path[0], - "key": key, - }, - ) + return None data = data[key] return data @@ -350,7 +354,12 @@ class CoreCall(BaseModel): """Call类实例的执行后方法""" - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def exec( + self, + executor: "StepExecutor", + input_data: dict[str, Any], + language: LanguageType = LanguageType.CHINESE, + ) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" self._last_output_data = {} # 初始化输出数据存储 diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py index a66aac5251ab580942c583576e9ec9bba3204951..47568198e858a6571dc7544fb66f57fa73002959 100644 --- a/apps/scheduler/call/empty.py +++ b/apps/scheduler/call/empty.py @@ -2,30 +2,27 @@ """空白Call""" from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar from apps.scheduler.call.core import CoreCall, DataBase -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars class Empty(CoreCall, input_model=DataBase, output_model=DataBase): """空Call""" - - @classmethod - def info(cls) -> CallInfo: - """ - 返回Call的名称和描述 - - :return: Call的名称和描述 - :rtype: CallInfo - """ - return CallInfo( - name="空白", - type=CallType.DEFAULT, - description="空白节点,用于占位" - ) - + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "空白节点", + "type": CallType.DEFAULT, + "description": "空白节点,用于占位", + }, + LanguageType.ENGLISH: { + "name": "Empty Node", + "type": CallType.DEFAULT, + "description": "Empty node for placeholder", + }, + } async def _init(self, call_vars: CallVars) -> DataBase: """ @@ -38,7 +35,9 @@ class Empty(CoreCall, input_model=DataBase, output_model=DataBase): return DataBase() - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """ 执行Call diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index 33556a8dd9487528e60cacca39697fd4120c8b9b..a7cbf58e21ded36121c16c8ac0bb0f5e315f53ed 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -3,7 +3,7 @@ import logging from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, Self +from typing 
import TYPE_CHECKING, Any, Self, ClassVar from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -17,7 +17,7 @@ from apps.scheduler.call.facts.schema import ( FactsInput, FactsOutput, ) -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.pool import NodePool from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars from apps.services.user_domain import UserDomainManager @@ -28,19 +28,19 @@ if TYPE_CHECKING: class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): """提取事实工具""" - answer: str = Field(description="用户输入") - - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="提取事实", - type=CallType.DEFAULT, - description="从对话上下文和文档片段中提取事实。" - ) - + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "提取事实", + "type": CallType.TOOL, + "description": "从对话上下文和文档片段中提取事实。", + }, + LanguageType.ENGLISH: { + "name": "Fact Extraction", + "type": CallType.TOOL, + "description": "Extract facts from the conversation context and document snippets.", + }, + } @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: @@ -71,7 +71,9 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" data = FactsInput(**input_data) # jinja2 环境 @@ -83,7 +85,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ) # 提取事实信息 - facts_tpl = env.from_string(FACTS_PROMPT) + facts_tpl = env.from_string(FACTS_PROMPT[language]) facts_prompt = facts_tpl.render(conversation=data.message) try: facts_obj: FactsGen = await self._json([ @@ -96,7 +98,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): facts_obj = FactsGen(facts=[]) # 更新用户画像 - domain_tpl = env.from_string(DOMAIN_PROMPT) + domain_tpl = env.from_string(DOMAIN_PROMPT[language]) domain_prompt = domain_tpl.render(conversation=data.message) try: domain_list: DomainGen = await self._json([ @@ -120,9 +122,14 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ) - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def exec( + self, + executor: "StepExecutor", + input_data: dict[str, Any], + language: LanguageType = LanguageType.CHINESE, + ) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" - async for chunk in self._exec(input_data): + async for chunk in self._exec(input_data, language): content = chunk.content if not isinstance(content, dict): err = "[FactsCall] 工具输出格式错误" diff --git a/apps/scheduler/call/facts/prompt.py b/apps/scheduler/call/facts/prompt.py index b2b2513f2c28feb4d19f5fb4ee3eba1bd616a4dc..02e134391fa33d3dbe3b3e71cc63947acc739777 100644 --- a/apps/scheduler/call/facts/prompt.py +++ b/apps/scheduler/call/facts/prompt.py @@ -2,8 +2,10 @@ """记忆提取工具的提示词""" from textwrap import dedent - -DOMAIN_PROMPT: str = dedent(r""" +from apps.schemas.enum_var import LanguageType +DOMAIN_PROMPT: dict[str, str] = { + LanguageType.CHINESE: dedent( + r""" 根据对话上文,提取推荐系统所需的关键词标签,要求: @@ -35,8 +37,48 @@ DOMAIN_PROMPT: str = dedent(r""" {% endfor %} -""") -FACTS_PROMPT: str = dedent(r""" 
+""" + ), + LanguageType.ENGLISH: dedent( + r""" + + + Extract keywords for recommendation system based on the previous conversation, requirements: + 1. Entity nouns, technical terms, time range, location, product, etc. can be keyword tags + 2. At least one keyword is related to the topic of the conversation + 3. Tags should be concise and not repeated, not exceeding 10 characters + 4. Output in JSON format, do not include XML tags, do not include any explanatory notes + + + + + What's the weather like in Beijing? + Beijing is sunny today. + + + + { + "keywords": ["Beijing", "weather"] + } + + + + + + {% for item in conversation %} + <{{item['role']}}> + {{item['content']}} + + {% endfor %} + + +""" + ), +} + +FACTS_PROMPT: dict[str, str] = { + LanguageType.CHINESE: dedent( + r""" 从对话中提取关键信息,并将它们组织成独一无二的、易于理解的事实,包含用户偏好、关系、实体等有用信息。 @@ -80,4 +122,53 @@ FACTS_PROMPT: str = dedent(r""" {% endfor %} -""") +""" + ), + LanguageType.ENGLISH: dedent( + r""" + + + Extract key information from the conversation and organize it into unique, easily understandable facts, including user preferences, relationships, entities, etc. + The following are the types of information you need to pay attention to and detailed instructions on how to handle input data. + + **Types of information you need to pay attention to** + 1. Entities: Entities involved in the conversation. For example: names, locations, organizations, events, etc. + 2. Preferences: Attitudes towards entities. For example: like, dislike, etc. + 3. Relationships: Relationships between users and entities, or between two entities. For example: include, parallel, mutually exclusive, etc. + 4. Actions: Specific actions that affect entities. For example: query, search, browse, click, etc. + + **Requirements** + 1. Facts must be accurate and can only be extracted from the conversation. Do not include the information in the example in the output. + 2. Facts must be clear, concise, and easy to understand. Must be less than 30 words. + 3. Output in the following JSON format: + + { + "facts": ["Fact 1", "Fact 2", "Fact 3"] + } + + + + + What are the attractions in Hangzhou West Lake? + West Lake in Hangzhou, Zhejiang Province, China, is a famous scenic spot known for its beautiful natural scenery and rich cultural heritage. Many notable attractions surround West Lake, including the renowned Su Causeway, Bai Causeway, Broken Bridge, and the Three Pools Mirroring the Moon. Famous for its crystal-clear waters and the surrounding mountains, West Lake is one of China's most famous lakes. 
+ + + + { + "facts": ["Hangzhou West Lake has famous attractions such as Suzhou Embankment, Bai Budi, Qiantang Bridge, San Tang Yue, etc."] + } + + + + + + {% for item in conversation %} + <{{item['role']}}> + {{item['content']}} + + {% endfor %} + + +""" + ), +} diff --git a/apps/scheduler/call/file_extract/file_extract.py b/apps/scheduler/call/file_extract/file_extract.py index e02e4089169885232386487f124c294ccfbc51e8..542cc46e360d6e3d0ee16074d2d2647ad7482ae0 100644 --- a/apps/scheduler/call/file_extract/file_extract.py +++ b/apps/scheduler/call/file_extract/file_extract.py @@ -9,7 +9,7 @@ import tempfile import uuid from collections.abc import AsyncGenerator from pathlib import Path -from typing import Any +from typing import Any, ClassVar import httpx @@ -19,7 +19,7 @@ from apps.scheduler.call.core import CoreCall from apps.scheduler.call.file_extract.schema import FileExtractInput, FileExtractOutput from apps.scheduler.variable.pool_manager import get_pool_manager from apps.scheduler.variable.type import VariableType -from apps.schemas.enum_var import CallType +from apps.schemas.enum_var import CallType, LanguageType from apps.schemas.scheduler import CallInfo, CallVars, CallOutputChunk, CallOutputType, CallError from pydantic import Field @@ -28,7 +28,20 @@ logger = logging.getLogger(__name__) class FileExtract(CoreCall, input_model=FileExtractInput, output_model=FileExtractOutput): """文件提取器Call""" - + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "文件提取器", + "type": CallType.TRANSFORM, + "description": "从文件中提取文本内容,支持多种文件格式的解析", + }, + LanguageType.ENGLISH: { + "name": "FileExtract", + "type": CallType.TRANSFORM, + "description": "Extract text content from different kinds of document", + }, + } + + controlled_output: bool = Field(default=True) # 添加output_parameters字段支持 output_parameters: dict[str, Any] = Field(description="输出参数配置", default={ "text": {"type": "string", "description": "提取的文本内容"}, @@ -142,20 +155,6 @@ class FileExtract(CoreCall, input_model=FileExtractInput, output_model=FileExtra return error_msg, error_details - @classmethod - def info(cls) -> CallInfo: - """ - 返回Call的名称和描述 - - :return: Call的名称和描述 - :rtype: CallInfo - """ - return CallInfo( - name="文件提取器", - type=CallType.TRANSFORM, - description="从文件中提取文本内容,支持多种文件格式的解析" - ) - async def _init(self, call_vars: CallVars) -> FileExtractInput: """ 初始化Call @@ -219,8 +218,24 @@ class FileExtract(CoreCall, input_model=FileExtractInput, output_model=FileExtra # 处理文件 report_info = "" if var_type == VariableType.FILE: - # 单个文件 - 直接使用文件ID - file_id = file_variable.value + # 单个文件 - 需要从变量值中提取文件ID + if isinstance(file_variable.value, dict): + # 如果是字典格式,提取file_id字段 + file_id = file_variable.value.get('file_id') + if not file_id: + raise CallError( + message="文件变量格式错误:缺少file_id字段", + data={"file_variable_value": file_variable.value} + ) + elif isinstance(file_variable.value, str): + # 如果是字符串格式,直接使用 + file_id = file_variable.value + else: + raise CallError( + message=f"文件变量格式不支持:{type(file_variable.value)}", + data={"file_variable_value": file_variable.value} + ) + text_content, report_info = await self._process_single_file(file_id, parse_method, call_vars, file_variable) else: # 文件数组 - 从变量值中提取file_ids @@ -228,8 +243,19 @@ class FileExtract(CoreCall, input_model=FileExtractInput, output_model=FileExtra file_ids = file_variable.value['file_ids'] text_content, report_info = await self._process_file_array(file_ids, parse_method, call_vars, file_variable) else: - # 兼容性处理:如果值就是文件ID列表 - file_ids = 
file_variable.value if isinstance(file_variable.value, list) else [file_variable.value] + # 兼容性处理:如果值就是文件ID列表或单个文件ID + if isinstance(file_variable.value, list): + file_ids = file_variable.value + elif isinstance(file_variable.value, str): + file_ids = [file_variable.value] + elif isinstance(file_variable.value, dict) and 'file_id' in file_variable.value: + # 如果是单个文件字典格式,提取file_id并转为列表 + file_ids = [file_variable.value['file_id']] + else: + raise CallError( + message=f"数组文件变量格式不支持:{type(file_variable.value)}", + data={"file_variable_value": file_variable.value} + ) text_content, report_info = await self._process_file_array(file_ids, parse_method, call_vars, file_variable) # 将结果存储到对话级变量池(基于conversation_id的实际对话变量) @@ -706,7 +732,7 @@ class FileExtract(CoreCall, input_model=FileExtractInput, output_model=FileExtra # 使用新的/doc/full_text接口获取文档全文 full_text_url = f"{rag_host}/doc/temporary/text" - full_text_params = {"docId": task_id} + full_text_params = {"id": task_id} logger.info(f"获取文档全文: {full_text_url}") full_text_resp = await client.get(full_text_url, headers=headers, params=full_text_params, timeout=30.0) diff --git a/apps/scheduler/call/graph/graph.py b/apps/scheduler/call/graph/graph.py index 7383b2f2c85aacf78db11b00f95f22c2996f071f..893d1a3209d23fea078116735794993241271742 100644 --- a/apps/scheduler/call/graph/graph.py +++ b/apps/scheduler/call/graph/graph.py @@ -3,7 +3,7 @@ import json from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar from anyio import Path from pydantic import Field @@ -11,7 +11,7 @@ from pydantic import Field from apps.scheduler.call.core import CoreCall from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput from apps.scheduler.call.graph.style import RenderStyle -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -24,17 +24,18 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): """Render Call,用于将SQL Tool查询出的数据转换为图表""" dataset_key: str = Field(description="图表的数据来源(字段名)", default="") - - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="图表", - type=CallType.TRANSFORM, - description="将SQL查询出的数据转换为图表" - ) - + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "图表", + "type": CallType.TRANSFORM, + "description": "将SQL查询出的数据转换为图表。", + }, + LanguageType.ENGLISH: { + "name": "Chart", + "type": CallType.TRANSFORM, + "description": "Convert the data queried by SQL into a chart.", + }, + } async def _init(self, call_vars: CallVars) -> RenderInput: """初始化Render Call,校验参数,读取option模板""" @@ -59,7 +60,9 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """运行Render Call""" data = RenderInput(**input_data) @@ -88,7 +91,7 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): try: style_obj = RenderStyle() - llm_output = await style_obj.generate(question=data.question) + llm_output = await style_obj.generate(question=data.question, language=language) self.tokens.input_tokens += style_obj.input_tokens self.tokens.output_tokens += style_obj.output_tokens @@ -122,7 +125,9 @@ class 
Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): return result - def _parse_options(self, column_num: int, chart_style: str, additional_style: str, scale_style: str) -> None: + def _parse_options( + self, column_num: int, chart_style: str, additional_style: str, scale_style: str + ) -> None: """解析LLM做出的图表样式选择""" series_template = {} diff --git a/apps/scheduler/call/graph/style.py b/apps/scheduler/call/graph/style.py index 631ea88acb9c0ef6b851e35cb3fffdecc567a901..e9fef038635ef7569147335d3289b46341675927 100644 --- a/apps/scheduler/call/graph/style.py +++ b/apps/scheduler/call/graph/style.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from apps.llm.function import JsonGenerator from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM +from apps.schemas.enum_var import LanguageType logger = logging.getLogger(__name__) @@ -24,53 +25,95 @@ class RenderStyleResult(BaseModel): class RenderStyle(CorePattern): """选择图表样式""" - system_prompt = r""" - You are a helpful assistant. Help the user make style choices when drawing a chart. - Chart title should be short and less than 3 words. - - Available types: - - `bar`: Bar graph - - `pie`: Pie graph - - `line`: Line graph - - `scatter`: Scatter graph - - Available bar additional styles: - - `normal`: Normal bar graph - - `stacked`: Stacked bar graph - - Available pie additional styles: - - `normal`: Normal pie graph - - `ring`: Ring pie graph - - Available scales: - - `linear`: Linear scale - - `log`: Logarithmic scale - - EXAMPLE - ## Question - 查询数据库中的数据,并绘制堆叠柱状图。 - - ## Thought - Let's think step by step. The user requires drawing a stacked bar chart, so the chart type should be `bar`, \ - i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form. - - ## Answer - The chart type should be: bar - The chart style should be: stacked - The scale should be: linear - - END OF EXAMPLE - - Let's begin. - """ - - user_prompt = r""" - ## Question - {question} - - ## Thought - Let's think step by step. - """ + def get_default_prompt(self) -> dict[LanguageType, str]: + system_prompt = { + LanguageType.CHINESE: r""" + 你是一个有用的助手。帮助用户在绘制图表时做出样式选择。 + 图表标题应简短且少于3个字。 + 可用类型: + - `bar`: 柱状图 + - `pie`: 饼图 + - `line`: 折线图 + - `scatter`: 散点图 + 可用柱状图附加样式: + - `normal`: 普通柱状图 + - `stacked`: 堆叠柱状图 + 可用饼图附加样式: + - `normal`: 普通饼图 + - `ring`: 环形饼图 + 可用比例: + - `linear`: 线性比例 + - `log`: 对数比例 + EXAMPLE + ## 问题 + 查询数据库中的数据,并绘制堆叠柱状图。 + ## 思考 + 让我们一步步思考。用户要求绘制堆叠柱状图,因此图表类型应为 `bar`,即柱状图;图表样式 + 应为 `stacked`,即堆叠形式。 + ## 答案 + 图表类型应为:bar + 图表样式应为:stacked + 比例应为:linear + END OF EXAMPLE + + 让我们开始吧。 + """, + LanguageType.ENGLISH: r""" + You are a helpful assistant. Help the user make style choices when drawing a chart. + Chart title should be short and less than 3 words. + + Available types: + - `bar`: Bar graph + - `pie`: Pie graph + - `line`: Line graph + - `scatter`: Scatter graph + + Available bar additional styles: + - `normal`: Normal bar graph + - `stacked`: Stacked bar graph + + Available pie additional styles: + - `normal`: Normal pie graph + - `ring`: Ring pie graph + + Available scales: + - `linear`: Linear scale + - `log`: Logarithmic scale + + EXAMPLE + ## Question + 查询数据库中的数据,并绘制堆叠柱状图。 + + ## Thought + Let's think step by step. The user requires drawing a stacked bar chart, so the chart type should be `bar`, \ + i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form. 
+ + ## Answer + The chart type should be: bar + The chart style should be: stacked + The scale should be: linear + + END OF EXAMPLE + + Let's begin. + """ + } + user_prompt = { + LanguageType.CHINESE: r""" + ## 问题 + {question} + ## 思考 + 让我们一步步思考。根据用户问题,选择合适的图表类型、样式和比例。 + """, + LanguageType.ENGLISH: r""" + ## Question + {question} + + ## Thought + Let's think step by step. + """ + } + return system_prompt, user_prompt def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: """初始化RenderStyle Prompt""" @@ -79,11 +122,11 @@ class RenderStyle(CorePattern): async def generate(self, **kwargs) -> dict[str, Any]: # noqa: ANN003 """使用LLM选择图表样式""" question = kwargs["question"] - + language = kwargs.get("language", LanguageType.CHINESE) # 使用Reasoning模型进行推理 messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.user_prompt.format(question=question)}, + {"role": "system", "content": self.system_prompt[language]}, + {"role": "user", "content": self.user_prompt[language].format(question=question)}, ] result = "" llm = ReasoningLLM() diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 7f35310e5e84eb1e7b8245b94b3ddde258e98b41..7c8e0a8d2f30171721af5b0767493c5672915c9f 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -4,7 +4,7 @@ import logging from collections.abc import AsyncGenerator from datetime import datetime -from typing import Any +from typing import Any, ClassVar import pytz from jinja2 import BaseLoader @@ -15,7 +15,7 @@ from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall from apps.scheduler.call.llm.prompt import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT from apps.scheduler.call.llm.schema import LLMInput, LLMOutput -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -38,16 +38,18 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): system_prompt: str = Field(description="大模型系统提示词", default="You are a helpful assistant.") user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT) - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="大模型", - type=CallType.DEFAULT, - description="以指定的提示词和上下文信息调用大模型,并获得输出。" - ) - + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "大模型", + "type": CallType.DEFAULT, + "description": "以指定的提示词和上下文信息调用大模型,并获得输出。", + }, + LanguageType.ENGLISH: { + "name": "Foundation Model", + "type": CallType.DEFAULT, + "description": "Call the foundation model with specified prompt and context, and obtain the output.", + }, + } async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: """准备消息""" @@ -105,7 +107,9 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """运行LLM Call""" data = LLMInput(**input_data) try: diff --git a/apps/scheduler/call/llm/prompt.py b/apps/scheduler/call/llm/prompt.py index 0f227dcaa618b11a2c888f55a61fa51f349d7d8b..20536ff27b28b17979bbf7eed3b94f5e15f2139a 100644 --- a/apps/scheduler/call/llm/prompt.py +++ b/apps/scheduler/call/llm/prompt.py @@ -2,16 +2,34 @@ 
"""大模型工具的提示词""" from textwrap import dedent +from apps.schemas.enum_var import LanguageType LLM_CONTEXT_PROMPT = dedent( + # r""" + # 以下是对用户和AI间对话的简短总结,在中给出: + # + # {{ summary }} + # + # 你作为AI,在回答用户的问题前,需要获取必要的信息。为此,你调用了一些工具,并获得了它们的输出: + # 工具的输出数据将在中给出, 其中为工具的名称,为工具的输出数据。 + # + # {% for tool in history_data %} + # + # {{ tool.step_name }} + # {{ tool.step_description }} + # {{ tool.output_data }} + # + # {% endfor %} + # + # """, r""" - 以下是对用户和AI间对话的简短总结,在中给出: + The following is a brief summary of the user and AI conversation, given in : {{ summary }} - 你作为AI,在回答用户的问题前,需要获取必要的信息。为此,你调用了一些工具,并获得了它们的输出: - 工具的输出数据将在中给出, 其中为工具的名称,为工具的输出数据。 + As an AI, before answering the user's question, you need to obtain necessary information. For this purpose, you have called some tools and obtained their outputs: + The output data of the tools will be given in , where is the name of the tool and is the output data of the tool. {% for tool in history_data %} @@ -24,12 +42,29 @@ LLM_CONTEXT_PROMPT = dedent( """, ).strip("\n") LLM_DEFAULT_PROMPT = dedent( + # r""" + # + # 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。 + # 当前时间:{{ time }},可以作为时间参照。 + # 用户的问题将在中给出,上下文背景信息将在中给出。 + # 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 + # + # + # {{ question }} + # + # + # {{ context }} + # + # 现在,输出你的回答: + # """, r""" - 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。 - 当前时间:{{ time }},可以作为时间参照。 - 用户的问题将在中给出,上下文背景信息将在中给出。 - 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 + You are a helpful AI assistant. Please answer the user's question based on the given background information. + Current time: {{ time }}, which can be used as a reference. + The user's question will be given in , and the context background information will be given in . + + Respond using the same language as the user's question, unless the user explicitly requests a specific language—then follow that request. + Note: Do not include any XML tags in the output. Do not make up any information. If you think the user's question is unrelated to the background information, please ignore the background information and answer directly. @@ -39,12 +74,13 @@ LLM_DEFAULT_PROMPT = dedent( {{ context }} - - 现在,输出你的回答: - """, + Now, please output your answer: + """ ).strip("\n") -LLM_ERROR_PROMPT = dedent( - r""" + +LLM_ERROR_PROMPT = { + LanguageType.CHINESE: dedent( + r""" 你是一位智能助手,能够根据用户的问题,使用Python工具获取信息,并作出回答。你在使用工具解决回答用户的问题时,发生了错误。 你的任务是:分析工具(Python程序)的异常信息,分析造成该异常可能的原因,并以通俗易懂的方式,将原因告知用户。 @@ -67,8 +103,36 @@ LLM_ERROR_PROMPT = dedent( 现在,输出你的回答: - """, -).strip("\n") + """ + ).strip("\n"), + LanguageType.ENGLISH: dedent( + r""" + + You are an intelligent assistant. When using Python tools to answer user questions, an error occurred. + Your task is: Analyze the exception information of the tool (Python program), analyze the possible causes of the error, and inform the user in an easy-to-understand way. + + Current time: {{ time }}, which can be used as a reference. + The program exception information that occurred will be given in , the user's question will be given in , and the context background information will be given in . + Note: Do not include any XML tags in the output. Do not make up any information. If you think the user's question is unrelated to the background information, please ignore the background information. 
+
+
+
+    {{ error_info }}
+
+
+
+    {{ question }}
+
+
+
+    {{ context }}
+
+
+    Now, please output your answer:
+    """
+    ).strip("\n"),
+}
+
 RAG_ANSWER_PROMPT = dedent(
     r"""
diff --git a/apps/scheduler/call/loop/loop.py b/apps/scheduler/call/loop/loop.py
index 001aa6ffe03f24486202192ca278b4a6203753cc..896fc602c98e754b1adcfe6374996cb7914d4924 100644
--- a/apps/scheduler/call/loop/loop.py
+++ b/apps/scheduler/call/loop/loop.py
@@ -6,7 +6,7 @@ import logging
 import uuid
 import asyncio
 from collections.abc import AsyncGenerator
-from typing import Any, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING, ClassVar
 
 from pydantic import Field
 
@@ -18,7 +18,7 @@ from apps.scheduler.call.choice.condition_handler import ConditionHandler
 from apps.scheduler.call.loop.schema import LoopInput, LoopOutput, LoopStopCondition
 from apps.scheduler.pool.loader.flow import FlowLoader
 from apps.scheduler.variable.integration import VariableIntegration
-from apps.schemas.enum_var import CallOutputType, CallType
+from apps.schemas.enum_var import CallOutputType, CallType, LanguageType
 from apps.schemas.flow import Flow, Step, Edge
 from apps.schemas.flow_topology import PositionItem
 from apps.schemas.scheduler import (
@@ -49,14 +49,18 @@ class Loop(CoreCall, input_model=LoopInput, output_model=LoopOutput):
     # 保存StepExecutor引用用于子工作流执行
     step_executor: Any = Field(default=None, exclude=True)
 
-    @classmethod
-    def info(cls) -> CallInfo:
-        """返回Call的名称和描述"""
-        return CallInfo(
-            name="循环",
-            type=CallType.LOGIC,
-            description="直到循环终止条件达成或最大循环次数到达之前,子工作流将不断循环执行"
-        )
+    i18n_info: ClassVar[dict[str, dict]] = {
+        LanguageType.CHINESE: {
+            "name": "循环",
+            "type": CallType.LOGIC,
+            "description": "直到循环终止条件达成或最大循环次数到达之前,子工作流将不断循环执行",
+        },
+        LanguageType.ENGLISH: {
+            "name": "Loop",
+            "type": CallType.LOGIC,
+            "description": "The subflow runs repeatedly until the stop condition is met or the maximum number of iterations is reached",
+        },
+    }
 
     async def _process_stop_condition(self, call_vars: CallVars) -> tuple[bool, str]:
         """处理停止条件
@@ -573,13 +577,13 @@ class Loop(CoreCall, input_model=LoopInput, output_model=LoopOutput):
             sub_flow_id=sub_flow_id,
         )
 
-    async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+    async def exec(self, executor: "StepExecutor", input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE) -> AsyncGenerator[CallOutputChunk, None]:
         """重写exec方法来保存executor引用"""
         # 保存executor引用
         self.step_executor = executor
 
         # 调用父类的exec方法
-        async for chunk in super().exec(executor, input_data):
+        async for chunk in super().exec(executor, input_data, language):
             yield chunk
 
     async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py
index 9c78a183695f197900c9e8fbdba08d76eb4e0b6e..22027d2f63a1b3f6e60bf8041efc97258b3a8fcf 100644
--- a/apps/scheduler/call/mcp/mcp.py
+++ b/apps/scheduler/call/mcp/mcp.py
@@ -4,7 +4,7 @@ import logging
 from collections.abc import AsyncGenerator
 from copy import deepcopy
-from typing import Any
+from typing import Any, ClassVar
 
 from pydantic import Field
 
@@ -16,7 +16,7 @@ from apps.scheduler.call.mcp.schema import (
     MCPOutput,
 )
 from apps.scheduler.mcp import MCPHost, MCPPlanner, MCPSelector
-from apps.schemas.enum_var import CallOutputType, CallType
+from apps.schemas.enum_var import CallOutputType, CallType, LanguageType
 from apps.schemas.mcp import MCPPlanItem
 from apps.schemas.scheduler import (
     CallInfo,
@@ -26,28 +26,49 @@ from 
apps.schemas.scheduler import ( logger = logging.getLogger(__name__) +MCP_GENERATE: dict[str, dict[LanguageType, str]] = { + "START": { + LanguageType.CHINESE: "[MCP] 开始生成计划...\n\n\n\n", + LanguageType.ENGLISH: "[MCP] Start generating plan...\n\n\n\n", + }, + "END": { + LanguageType.CHINESE: "[MCP] 计划生成完成:\n\n{plan_str}\n\n\n\n", + LanguageType.ENGLISH: "[MCP] Plan generation completed: \n\n{plan_str}\n\n\n\n", + }, +} + +MCP_SUMMARY: dict[str, dict[LanguageType, str]] = { + "START": { + LanguageType.CHINESE: "[MCP] 正在总结任务结果...\n\n", + LanguageType.ENGLISH: "[MCP] Start summarizing task results...\n\n", + }, + "END": { + LanguageType.CHINESE: "[MCP] 任务完成\n\n---\n\n{answer}\n\n", + LanguageType.ENGLISH: "[MCP] Task summary completed\n\n{answer}\n\n", + }, +} + class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """MCP工具""" mcp_list: list[str] = Field(description="MCP Server ID列表", max_length=5, min_length=1) - max_steps: int = Field(description="最大步骤数", default=6) + max_steps: int = Field(description="最大步骤数", default=20) text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) - @classmethod - def info(cls) -> CallInfo: - """ - 返回Call的名称和描述 - - :return: Call的名称和描述 - :rtype: CallInfo - """ - return CallInfo( - name="MCP", - type=CallType.DEFAULT, - description="调用MCP Server,执行工具" - ) + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "MCP", + "type": CallType.DEFAULT, + "description": "调用MCP Server,执行工具", + }, + LanguageType.ENGLISH: { + "name": "MCP", + "type": CallType.DEFAULT, + "description": "Call the MCP Server to execute tools", + }, + } async def _init(self, call_vars: CallVars) -> MCPInput: """初始化MCP""" @@ -65,31 +86,33 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """执行MCP""" # 生成计划 - async for chunk in self._generate_plan(): + async for chunk in self._generate_plan(language): yield chunk # 执行计划 plan_list = deepcopy(self._plan.plans) while len(plan_list) > 0: - async for chunk in self._execute_plan_item(plan_list.pop(0)): + async for chunk in self._execute_plan_item(plan_list.pop(0), language): yield chunk # 生成总结 - async for chunk in self._generate_answer(): + async for chunk in self._generate_answer(language): yield chunk - async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]: + async def _generate_plan(self, language) -> AsyncGenerator[CallOutputChunk, None]: """生成执行计划""" # 开始提示 - yield self._create_output("[MCP] 开始生成计划...\n\n\n\n", MCPMessageType.PLAN_BEGIN) + yield self._create_output(MCP_GENERATE["START"][language], MCPMessageType.PLAN_BEGIN) # 选择工具并生成计划 selector = MCPSelector() top_tool = await selector.select_top_tool(self._call_vars.question, self.mcp_list) - planner = MCPPlanner(self._call_vars.question) + planner = MCPPlanner(self._call_vars.question, language) self._plan = await planner.create_plan(top_tool, self.max_steps) # 输出计划 @@ -98,12 +121,14 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): plan_str += f"[+] {plan_item.content}; {plan_item.tool}[{plan_item.instruction}]\n\n" yield self._create_output( - f"[MCP] 计划生成完成:\n\n{plan_str}\n\n\n\n", + 
MCP_GENERATE["END"][language].format(plan_str=plan_str), MCPMessageType.PLAN_END, data=self._plan.model_dump(), ) - async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]: + async def _execute_plan_item( + self, plan_item: MCPPlanItem, language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """执行单个计划项""" # 判断是否为Final if plan_item.tool == "Final": @@ -124,7 +149,7 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): # 调用工具 try: - result = await self._host.call_tool(tool, plan_item) + result = await self._host.call_tool(tool, plan_item, language) except Exception as e: err = f"[MCP] 工具 {tool.name} 调用失败: {e!s}" logger.exception(err) @@ -140,21 +165,21 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): }, ) - async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]: + async def _generate_answer(self, language) -> AsyncGenerator[CallOutputChunk, None]: """生成总结""" # 提示开始总结 yield self._create_output( - "[MCP] 正在总结任务结果...\n\n", + MCP_SUMMARY["START"][language], MCPMessageType.FINISH_BEGIN, ) # 生成答案 - planner = MCPPlanner(self._call_vars.question) + planner = MCPPlanner(self._call_vars.question, language) answer = await planner.generate_answer(self._plan, await self._host.assemble_memory()) # 输出结果 yield self._create_output( - f"[MCP] 任务完成\n\n---\n\n{answer}\n\n", + MCP_SUMMARY["END"][language].format(answer=answer), MCPMessageType.FINISH_END, data=MCPOutput( message=answer, @@ -170,8 +195,11 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """创建输出""" if self.text_output: return CallOutputChunk(type=CallOutputType.TEXT, content=text) - return CallOutputChunk(type=CallOutputType.DATA, content=MCPMessage( - msg_type=msg_type, - message=text.strip(), - data=data or {}, - ).model_dump_json()) + return CallOutputChunk( + type=CallOutputType.DATA, + content=MCPMessage( + msg_type=msg_type, + message=text.strip(), + data=data or {}, + ).model_dump_json(), + ) diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index 7c1c48940c9e4cb399dd4453ec702797883fae55..c81a22fe5626de254d09cc1c7312d29c241acf42 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -3,7 +3,7 @@ import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar import httpx from fastapi import status @@ -13,7 +13,7 @@ from apps.common.config import Config from apps.llm.patterns.rewrite import QuestionRewrite from apps.scheduler.call.core import CoreCall from apps.scheduler.call.rag.schema import RAGInput, RAGOutput, SearchMethod -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -36,15 +36,18 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): is_rerank: bool = Field(description="是否重新排序", default=False) is_compress: bool = Field(description="是否压缩", default=False) tokens_limit: int = Field(description="token限制", default=8192) - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="知识库", - type=CallType.DEFAULT, - description="查询知识库,从文档中获取必要信息" - ) + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "知识库", + "type": CallType.DEFAULT, + "description": "查询知识库,从文档中获取必要信息", + }, + LanguageType.ENGLISH: { + "name": "Knowledge Base", + "type": CallType.DEFAULT, + 
"description": "Query the knowledge base and obtain necessary information from documents", + }, + } async def _init(self, call_vars: CallVars) -> RAGInput: """初始化RAG工具""" @@ -62,7 +65,9 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): tokensLimit=self.tokens_limit, ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """调用RAG工具""" data = RAGInput(**input_data) question_obj = QuestionRewrite() diff --git a/apps/scheduler/call/reply/direct_reply.py b/apps/scheduler/call/reply/direct_reply.py index cd9f0818f0c3506f1db4a3fb15d8b2380eff6e32..171bdbb1f08d2506ba1e73943b075359ccce71c3 100644 --- a/apps/scheduler/call/reply/direct_reply.py +++ b/apps/scheduler/call/reply/direct_reply.py @@ -5,13 +5,13 @@ import logging import re import base64 from collections.abc import AsyncGenerator -from typing import Any, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional from pydantic import Field from apps.scheduler.call.core import CoreCall from apps.scheduler.call.reply.schema import DirectReplyInput, DirectReplyOutput -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -30,15 +30,19 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep """直接回复工具,支持变量引用语法""" to_user: bool = Field(default=True) - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="直接回复", - type=CallType.DEFAULT, - description="直接回复用户输入的内容,支持变量插入" - ) + controlled_output: bool = Field(default=True) + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "直接回复", + "type": CallType.DEFAULT, + "description": "直接回复用户输入的内容,支持变量插入", + }, + LanguageType.ENGLISH: { + "name": "DirectReply", + "type": CallType.DEFAULT, + "description": "Reply contents defined by user, with inserted reference variables", + }, + } async def _init(self, call_vars: CallVars) -> DirectReplyInput: """初始化DirectReply工具""" diff --git a/apps/scheduler/call/search/search.py b/apps/scheduler/call/search/search.py index 73d21d7b9956a19b6eaaf23467b4286d7fbb3d79..1c1534e7eba6cec592acb161f18b129b0079255c 100644 --- a/apps/scheduler/call/search/search.py +++ b/apps/scheduler/call/search/search.py @@ -1,10 +1,11 @@ """搜索工具""" from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar from apps.scheduler.call.core import CoreCall from apps.scheduler.call.search.schema import SearchInput, SearchOutput +from apps.schemas.enum_var import CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -15,6 +16,18 @@ from apps.schemas.scheduler import ( class Search(CoreCall, input_model=SearchInput, output_model=SearchOutput): """搜索工具""" + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "搜索", + "type": CallType.TOOL, + "description": "获取搜索引擎的结果。", + }, + LanguageType.ENGLISH: { + "name": "Search", + "type": CallType.TOOL, + "description": "Get the results of the search engine.", + }, + } @classmethod def info(cls) -> CallInfo: diff --git a/apps/scheduler/call/slot/prompt.py b/apps/scheduler/call/slot/prompt.py index e5650a4c5764a4ab739c6b66ad00838597c134a3..8a3f7ae8535f3420f05ca66b02d02429f143ca67 100644 --- 
a/apps/scheduler/call/slot/prompt.py +++ b/apps/scheduler/call/slot/prompt.py @@ -1,7 +1,9 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """自动参数填充工具的提示词""" +from apps.schemas.enum_var import LanguageType -SLOT_GEN_PROMPT = r""" +SLOT_GEN_PROMPT:dict[LanguageType, str] = { + LanguageType.CHINESE: r""" 你是一个可以使用工具的AI助手,正尝试使用工具来完成任务。 目前,你正在生成一个JSON参数对象,以作为调用工具的输入。 @@ -85,4 +87,91 @@ SLOT_GEN_PROMPT = r""" {{schema}} - """ + """, + LanguageType.ENGLISH: r""" + + You are an AI assistant capable of using tools to complete tasks. + Currently, you are generating a JSON parameter object as input for calling a tool. + Please generate a compliant JSON object based on user input, background information, tool information, and JSON Schema content. + + Background information will be provided in , tool information in , JSON Schema in , \ + and the user's question in . + Output the generated JSON object in . + + Requirements: + 1. Strictly follow the JSON format described in the JSON Schema. Do not fabricate non-existent fields. + 2. Prioritize using values from user input for JSON fields. If not available, use content from background information. + 3. Only output the JSON object. Do not include any explanations or additional content. + 4. Optional fields in the JSON Schema may be omitted. + 5. Examples are for illustration only. Do not copy content from examples or use them as output. + 6. Respond in the same language as the user's question by default, unless explicitly requested otherwise. + + + + + User asked about today's weather in Hangzhou. AI replied it's sunny, 20℃. User then asks about tomorrow's weather in Hangzhou. + + + What's the weather like in Hangzhou tomorrow? + + + Tool name: check_weather + Tool description: Query weather information for specified cities + + + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name" + }, + "date": { + "type": "string", + "description": "Query date" + }, + "required": ["city", "date"] + } + } + + + { + "city": "Hangzhou", + "date": "tomorrow" + } + + + + + Historical summary of tasks given by user, provided in : + + {{summary}} + + Additional itemized information: + {{ facts }} + + + During this task, you have called some tools and obtained their outputs, provided in : + + {% for tool in history_data %} + + {{ tool.step_name }} + {{ tool.step_description }} + {{ tool.output_data }} + + {% endfor %} + + + + {{question}} + + + Tool name: {{current_tool["name"]}} + Tool description: {{current_tool["description"]}} + + + {{schema}} + + + """, +} diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 69c2bfbcd94c04adab5396a9f27763633715e182..71ed9afeae9e211fd371d092645d15954578be6a 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -3,7 +3,7 @@ import json from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any, Self, ClassVar from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -15,7 +15,7 @@ from apps.scheduler.call.core import CoreCall from apps.scheduler.call.slot.prompt import SLOT_GEN_PROMPT from apps.scheduler.call.slot.schema import SlotInput, SlotOutput from apps.scheduler.slot.slot import Slot as SlotProcessor -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.pool import NodePool from apps.schemas.scheduler 
import CallInfo, CallOutputChunk, CallVars @@ -31,19 +31,22 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): summary: str = Field(description="背景信息总结", default="") facts: list[str] = Field(description="事实信息", default=[]) step_num: int = Field(description="历史步骤数", default=1) - - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="参数自动填充", - type=CallType.TRANSFORM, - description="根据步骤历史,自动填充参数" - ) - - - async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]: + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "参数自动填充", + "type": CallType.TOOL, + "description": "根据步骤历史,自动填充参数", + }, + LanguageType.ENGLISH: { + "name": "Parameter Auto-Fill", + "type": CallType.TOOL, + "description": "Auto-fill parameters based on step history.", + }, + } + + async def _llm_slot_fill( + self, remaining_schema: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> tuple[str, dict[str, Any]]: """使用大模型填充参数;若大模型解析度足够,则直接返回结果""" env = SandboxedEnvironment( loader=BaseLoader(), @@ -51,21 +54,24 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): trim_blocks=True, lstrip_blocks=True, ) - template = env.from_string(SLOT_GEN_PROMPT) + template = env.from_string(SLOT_GEN_PROMPT[language]) conversation = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": template.render( - current_tool={ - "name": self.name, - "description": self.description, - }, - schema=remaining_schema, - history_data=self._flow_history, - summary=self.summary, - question=self._question, - facts=self.facts, - )}, + { + "role": "user", + "content": template.render( + current_tool={ + "name": self.name, + "description": self.description, + }, + schema=remaining_schema, + history_data=self._flow_history, + summary=self.summary, + question=self._question, + facts=self.facts, + ), + }, ] # 使用大模型进行尝试 @@ -131,7 +137,9 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """执行参数填充""" data = SlotInput(**input_data) @@ -145,7 +153,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): ).model_dump(by_alias=True, exclude_none=True), ) return - answer, slot_data = await self._llm_slot_fill(data.remaining_schema) + answer, slot_data = await self._llm_slot_fill(data.remaining_schema, language) slot_data = self._processor.convert_json(slot_data) remaining_schema = self._processor.check_json(slot_data) diff --git a/apps/scheduler/call/sql/schema.py b/apps/scheduler/call/sql/schema.py index 06ffb4f0611bb12a3d6080a166c2bead52d24e10..c3b57c5245dd1ad2cbb1fd24fb80f2f96959fea8 100644 --- a/apps/scheduler/call/sql/schema.py +++ b/apps/scheduler/call/sql/schema.py @@ -1,7 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
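# A minimal, self-contained sketch (not part of the patch above) of how the language-keyed
# prompt tables such as SLOT_GEN_PROMPT are consumed: pick the template for the requested
# language, fall back to Chinese, and render it in a sandboxed Jinja2 environment, mirroring
# Slot._llm_slot_fill. LanguageType is stubbed as a plain Enum here; the real enum lives in
# apps.schemas.enum_var and its member values are an assumption.
from enum import Enum

from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment


class LanguageType(str, Enum):
    CHINESE = "zh"  # assumed values; only the member names appear in the patch
    ENGLISH = "en"


PROMPTS: dict[LanguageType, str] = {
    LanguageType.CHINESE: "问题:{{ question }}",
    LanguageType.ENGLISH: "Question: {{ question }}",
}


def render_prompt(language: LanguageType, question: str) -> str:
    """Select the prompt for `language` (falling back to Chinese) and render it."""
    env = SandboxedEnvironment(loader=BaseLoader(), trim_blocks=True, lstrip_blocks=True)
    template = env.from_string(PROMPTS.get(language, PROMPTS[LanguageType.CHINESE]))
    return template.render(question=question)


if __name__ == "__main__":
    print(render_prompt(LanguageType.ENGLISH, "What does the Slot call do?"))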
"""SQL工具的输入输出""" -from typing import Any +from typing import Any, Optional from pydantic import Field @@ -17,5 +17,5 @@ class SQLInput(DataBase): class SQLOutput(DataBase): """SQL工具的输出""" - dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") + result: list[dict[str, Any]] = Field(description="SQL工具的执行结果") sql: str = Field(description="SQL语句") diff --git a/apps/scheduler/call/sql/sql.py b/apps/scheduler/call/sql/sql.py index 2f4ef97fb095baeac0b89ceb2dff1c98338fed18..0b732cc2966af774593ede1c1ce994081ec745be 100644 --- a/apps/scheduler/call/sql/sql.py +++ b/apps/scheduler/call/sql/sql.py @@ -3,8 +3,9 @@ import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar +from urllib.parse import urlparse import httpx from fastapi import status from pydantic import Field @@ -12,7 +13,7 @@ from pydantic import Field from apps.common.config import Config from apps.scheduler.call.core import CoreCall from apps.scheduler.call.sql.schema import SQLInput, SQLOutput -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -22,25 +23,40 @@ from apps.schemas.scheduler import ( logger = logging.getLogger(__name__) +MESSAGE = { + "invaild": { + LanguageType.CHINESE: "SQL查询错误:无法生成有效的SQL语句!", + LanguageType.ENGLISH: "SQL query error: Unable to generate valid SQL statements!", + }, + "fail": { + LanguageType.CHINESE: "SQL查询错误:SQL语句执行失败!", + LanguageType.ENGLISH: "SQL query error: SQL statement execution failed!", + }, +} + class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): """SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。""" - database_url: str = Field(description="数据库连接地址") + database_type: str = Field(description="数据库类型",default="postgres") # mysql mongodb opengauss postgres + host: str = Field(description="数据库地址",default="localhost") + port: int = Field(description="数据库端口",default=5432) + username: str = Field(description="数据库用户名",default="root") + password: str = Field(description="数据库密码",default="root") + database: str = Field(description="数据库名称",default="postgres") table_name_list: list[str] = Field(description="表名列表",default=[]) - top_k: int = Field(description="生成SQL语句数量",default=5) - use_llm_enhancements: bool = Field(description="是否使用大模型增强", default=False) - - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="SQL查询", - type=CallType.TOOL, - description="使用大模型生成SQL语句,用于查询数据库中的结构化数据" - ) - + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "SQL查询", + "type": CallType.TOOL, + "description": "使用大模型生成SQL语句,用于查询数据库中的结构化数据", + }, + LanguageType.ENGLISH: { + "name": "SQL Query", + "type": CallType.TOOL, + "description": "Use the foundation model to generate SQL statements to query structured data in the databases", + }, + } async def _init(self, call_vars: CallVars) -> SQLInput: """初始化SQL工具。""" @@ -49,98 +65,55 @@ class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): ) - async def _generate_sql(self, data: SQLInput) -> list[dict[str, Any]]: - """生成SQL语句列表""" + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: + """运行SQL工具, 支持MySQL, MongoDB, PostgreSQL, OpenGauss""" + + data = SQLInput(**input_data) + + headers = {"Content-Type": "application/json"} + post_data = { - "database_url": self.database_url, - 
"table_name_list": self.table_name_list, - "question": data.question, - "topk": self.top_k, - "use_llm_enhancements": self.use_llm_enhancements, + "type": self.database_type, + "host": self.host, + "port": self.port, + "username": self.username, + "password": self.password, + "database": self.database, + "goal": data.question, + "table_list": self.table_name_list, } - headers = {"Content-Type": "application/json"} - sql_list = [] - request_num = 0 - max_request = 5 - - while request_num < max_request and len(sql_list) < self.top_k: - try: - async with httpx.AsyncClient() as client: - response = await client.post( - Config().get_config().extra.sql_url + "/database/sql", - headers=headers, - json=post_data, - timeout=60.0, - ) - request_num += 1 - if response.status_code == status.HTTP_200_OK: - result = response.json() - if result["code"] == status.HTTP_200_OK: - sql_list.extend(result["result"]["sql_list"]) - else: - logger.error("[SQL] 生成失败:%s", response.text) - except Exception: - logger.exception("[SQL] 生成失败") - request_num += 1 - - return sql_list - - - async def _execute_sql( - self, - sql_list: list[dict[str, Any]], - ) -> tuple[list[dict[str, Any]] | None, str | None]: - """执行SQL语句并返回结果""" - headers = {"Content-Type": "application/json"} + try: + async with httpx.AsyncClient() as client: + response = await client.post( + Config().get_config().extra.sql_url + "/sql/handler", + headers=headers, + json=post_data, + timeout=60.0, + ) + + result = response.json() + if response.status_code == status.HTTP_200_OK: + if result["code"] == status.HTTP_200_OK: + result_data = result["result"] + sql_exec_results = result_data.get("execute_result") + sql_exec = result_data.get("sql") + sql_exec_risk = result_data.get("risk") + logger.info("[SQL] 调用成功\n[SQL 语句]: %s\n[SQL 结果]: %s\n[SQL 风险]: %s", sql_exec, sql_exec_results, sql_exec_risk) + + else: + logger.error("[SQL] 调用失败:%s", response.text) + logger.error("[SQL] 错误信息:%s", response["result"]) + except Exception: + logger.exception("[SQL] 调用失败") - for sql_dict in sql_list: - try: - async with httpx.AsyncClient() as client: - response = await client.post( - Config().get_config().extra.sql_url + "/sql/execute", - headers=headers, - json={ - "database_id": sql_dict["database_id"], - "sql": sql_dict["sql"], - }, - timeout=60.0, - ) - if response.status_code == status.HTTP_200_OK: - result = response.json() - if result["code"] == status.HTTP_200_OK: - return result["result"], sql_dict["sql"] - else: - logger.error("[SQL] 调用失败:%s", response.text) - except Exception: - logger.exception("[SQL] 调用失败") - - return None, None - - - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: - """运行SQL工具""" - data = SQLInput(**input_data) - # 生成SQL语句 - sql_list = await self._generate_sql(data) - if not sql_list: - raise CallError( - message="SQL查询错误:无法生成有效的SQL语句!", - data={}, - ) - - # 执行SQL语句 - sql_exec_results, sql_exec = await self._execute_sql(sql_list) - if sql_exec_results is None or sql_exec is None: - raise CallError( - message="SQL查询错误:SQL语句执行失败!", - data={}, - ) # 返回结果 data = SQLOutput( - dataset=sql_exec_results, + result=sql_exec_results, sql=sql_exec, ).model_dump(exclude_none=True, by_alias=True) diff --git a/apps/scheduler/call/suggest/prompt.py b/apps/scheduler/call/suggest/prompt.py index abc5e7d186500f917b4610b9c0e893dec7ef3c12..a9f61d7c1c5df6bc7fcee488d6ab05781b685d72 100644 --- a/apps/scheduler/call/suggest/prompt.py +++ b/apps/scheduler/call/suggest/prompt.py @@ -2,8 +2,11 @@ """问题推荐工具的提示词""" from textwrap 
import dedent +from apps.schemas.enum_var import LanguageType -SUGGEST_PROMPT = dedent(r""" +SUGGEST_PROMPT: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" 根据提供的对话和附加信息(用户倾向、历史问题列表、工具信息等),生成三个预测问题。 @@ -89,4 +92,98 @@ SUGGEST_PROMPT = dedent(r""" 现在,进行问题生成: -""") +""" + ), + LanguageType.ENGLISH: dedent( + r""" + + + Generate three predicted questions based on the provided conversation and additional information (user preferences, historical question list, tool information, etc.). + The historical question list displays questions asked by the user before the historical conversation and is for background reference only. + The conversation will be given in the tag, the user preferences will be given in the tag, + the historical question list will be given in the tag, and the tool information will be given in the tag. + + Requirements for generating predicted questions: + + 1. Generate three predicted questions in the user's voice. They must be interrogative or imperative sentences and must be less than 30 words. + + 2. Predicted questions must be concise, without repetition, unnecessary information, or text other than the question. + + 3. Output must be in the following format: + + ```json + { + "predicted_questions": [ + "Predicted question 1", + "Predicted question 2", + "Predicted question 3" + ] + } + ``` + + + + What are the famous attractions in Hangzhou? + Hangzhou West Lake is a famous scenic spot in Hangzhou, Zhejiang Province, China, known for its beautiful natural scenery and rich cultural heritage. There are many famous attractions around West Lake, including the renowned Su Causeway, Bai Causeway, Broken Bridge, and the Three Pools Mirroring the Moon. West Lake is renowned for its clear waters and surrounding mountains, making it one of China's most famous lakes. + + + Briefly introduce Hangzhou + What are the famous attractions in Hangzhou? + + + Scenic Spot Search + Scenic Spot Information Search + + ["Hangzhou", "Tourism"] + + Now, generate questions: + + { + "predicted_questions": [ + "What is the ticket price for the West Lake Scenic Area in Hangzhou?", + "What are the famous attractions in Hangzhou?", + "What's the weather like in Hangzhou?" 
+ ] + } + + + + Here's the actual data: + + + {% for message in conversation %} + <{{ message.role }}>{{ message.content }} + {% endfor %} + + + + {% if history %} + {% for question in history %} + {{ question }} + {% endfor %} + {% else %} + (No history question) + {% endif %} + + + + {% if tool %} + {{ tool.name }} + {{ tool.description }} + {% else %} + (No tool information) + {% endif %} + + + + {% if preference %} + {{ preference }} + {% else %} + (no user preference) + {% endif %} + + + Now, generate the question: + """ + ), +} diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index 663e1d9c95d3faae2b51d75de0a7bab3c54ca66c..77ca0550b45e516037028988bd75c7ab54bcf326 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -3,7 +3,7 @@ import random from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any, Self, ClassVar from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -20,7 +20,7 @@ from apps.scheduler.call.suggest.schema import ( SuggestionInput, SuggestionOutput, ) -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.pool import NodePool from apps.schemas.record import RecordContent from apps.schemas.scheduler import ( @@ -47,15 +47,18 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO context: SkipJsonSchema[list[dict[str, str]]] = Field(description="Executor的上下文", exclude=True) conversation_id: SkipJsonSchema[str] = Field(description="对话ID", exclude=True) - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="问题推荐", - type=CallType.DEFAULT, - description="在答案下方显示推荐的下一个问题" - ) - + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "问题推荐", + "type": CallType.DEFAULT, + "description": "在答案下方显示推荐的下一个问题", + }, + LanguageType.ENGLISH: { + "name": "Question Suggestion", + "type": CallType.DEFAULT, + "description": "Display the suggested next question under the answer", + }, + } @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: @@ -128,8 +131,9 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO history_questions.append(record_data.question) return history_questions - - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """运行问题推荐""" data = SuggestionInput(**input_data) @@ -145,7 +149,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO # 已推送问题数量 pushed_questions = 0 # 初始化Prompt - prompt_tpl = self._env.from_string(SUGGEST_PROMPT) + prompt_tpl = self._env.from_string(SUGGEST_PROMPT[language]) # 先处理configs for config in self.configs: diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py index 7f6ff062d522efa25af8a5fdc5ceec38d7ba967d..9cbf8b1165ecf24dffb7f0bbc5c1db82be61d9b1 100644 --- a/apps/scheduler/call/summary/summary.py +++ b/apps/scheduler/call/summary/summary.py @@ -2,14 +2,14 @@ """总结上下文工具""" from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any, Self, ClassVar from pydantic import Field 
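# A sketch under stated assumptions (not code from this patch): the removed info() classmethods
# are replaced throughout by a language-keyed i18n_info ClassVar, so a call registry could
# rebuild a localized CallInfo with a simple lookup plus a Chinese fallback. CallInfo, CallType
# and LanguageType are stubbed minimally here; the real definitions live in
# apps.schemas.scheduler and apps.schemas.enum_var.
from dataclasses import dataclass
from enum import Enum
from typing import ClassVar


class LanguageType(str, Enum):
    CHINESE = "zh"
    ENGLISH = "en"


class CallType(str, Enum):
    DEFAULT = "default"


@dataclass
class CallInfo:
    name: str
    type: CallType
    description: str


class DemoCall:
    i18n_info: ClassVar[dict[LanguageType, dict]] = {
        LanguageType.CHINESE: {
            "name": "问题推荐",
            "type": CallType.DEFAULT,
            "description": "在答案下方显示推荐的下一个问题",
        },
        LanguageType.ENGLISH: {
            "name": "Question Suggestion",
            "type": CallType.DEFAULT,
            "description": "Display the suggested next question under the answer",
        },
    }

    @classmethod
    def info(cls, language: LanguageType = LanguageType.CHINESE) -> CallInfo:
        """Resolve the localized metadata, defaulting to Chinese when the language is unknown."""
        entry = cls.i18n_info.get(language, cls.i18n_info[LanguageType.CHINESE])
        return CallInfo(**entry)


if __name__ == "__main__":
    print(DemoCall.info(LanguageType.ENGLISH).name)  # -> Question Suggestion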
from apps.llm.patterns.executor import ExecutorSummary from apps.scheduler.call.core import CoreCall, DataBase from apps.scheduler.call.summary.schema import SummaryOutput -from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.enum_var import CallOutputType, CallType, LanguageType from apps.schemas.pool import NodePool from apps.schemas.scheduler import ( CallInfo, @@ -27,15 +27,18 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): """总结工具""" context: ExecutorBackground = Field(description="对话上下文") - - @classmethod - def info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo( - name="理解上下文", - type=CallType.DEFAULT, - description="使用大模型,理解对话上下文" - ) + i18n_info: ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "理解上下文", + "type": CallType.DEFAULT, + "description": "使用大模型,理解对话上下文", + }, + LanguageType.ENGLISH: { + "name": "Context Understanding", + "type": CallType.DEFAULT, + "description": "Use the foundation model to understand the conversation context", + }, + } @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: @@ -56,19 +59,26 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): return DataBase() - async def _exec(self, _input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def _exec( + self, _input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" summary_obj = ExecutorSummary() - summary = await summary_obj.generate(background=self.context) + summary = await summary_obj.generate(background=self.context, language=language) self.tokens.input_tokens += summary_obj.input_tokens self.tokens.output_tokens += summary_obj.output_tokens yield CallOutputChunk(type=CallOutputType.TEXT, content=summary) - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + async def exec( + self, + executor: "StepExecutor", + input_data: dict[str, Any], + language: LanguageType = LanguageType.CHINESE, + ) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" - async for chunk in self._exec(input_data): + async for chunk in self._exec(input_data, language): content = chunk.content if not isinstance(content, str): err = "[SummaryCall] 工具输出格式错误" diff --git a/apps/scheduler/call/variable_assign/variable_assign.py b/apps/scheduler/call/variable_assign/variable_assign.py index 583109a24e33999bf2d17eebff034d617aa730ce..dec7cfcc8d816443a1d5b198e92cceff2e1874f3 100644 --- a/apps/scheduler/call/variable_assign/variable_assign.py +++ b/apps/scheduler/call/variable_assign/variable_assign.py @@ -4,7 +4,7 @@ import logging import math from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, ClassVar from apps.scheduler.call.core import CoreCall from apps.scheduler.call.variable_assign.schema import ( @@ -17,7 +17,7 @@ from apps.scheduler.call.variable_assign.schema import ( ) from apps.scheduler.variable.type import VariableType from apps.scheduler.variable.pool_manager import get_pool_manager -from apps.schemas.enum_var import CallType +from apps.schemas.enum_var import CallType, LanguageType from apps.schemas.scheduler import CallInfo, CallVars, CallOutputChunk, CallOutputType, CallError @@ -26,25 +26,23 @@ logger = logging.getLogger(__name__) class VariableAssign(CoreCall, input_model=VariableAssignInput, output_model=VariableAssignOutput): """变量赋值Call""" + i18n_info: 
ClassVar[dict[str, dict]] = { + LanguageType.CHINESE: { + "name": "变量赋值", + "type": CallType.TRANSFORM, + "description": "对已有变量进行值的操作,支持字符串、数值和数组类型变量的多种操作", + }, + LanguageType.ENGLISH: { + "name": "VariableAssign", + "type": CallType.TRANSFORM, + "description": "Assign value for exist vairables, supporting string, number, and array, etc.", + }, + } def __init__(self, **kwargs): super().__init__(**kwargs) self._call_vars = None - @classmethod - def info(cls) -> CallInfo: - """ - 返回Call的名称和描述 - - :return: Call的名称和描述 - :rtype: CallInfo - """ - return CallInfo( - name="变量赋值", - type=CallType.TRANSFORM, - description="对已有变量进行值的操作,支持字符串、数值和数组类型变量的多种操作" - ) - async def _init(self, call_vars: CallVars) -> VariableAssignInput: """ 初始化Call diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 1a3876ec35885c3570bd12245f639d90c5d3831c..c34ad1eb3e9647bc706b759f996f9468eafda90a 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -1,14 +1,32 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP Agent执行器""" - +from datetime import datetime,UTC import logging +import uuid +import anyio +from mcp.types import TextContent from pydantic import Field +from apps.llm.reasoning import ReasoningLLM from apps.scheduler.executor.base import BaseExecutor -from apps.scheduler.mcp_agent.agent.mcp import MCPAgent -from apps.schemas.task import ExecutorState, StepQueueItem +from apps.schemas.enum_var import LanguageType +from apps.scheduler.mcp_agent.host import MCPHost +from apps.scheduler.mcp_agent.plan import MCPPlanner +from apps.scheduler.mcp_agent.select import FINAL_TOOL_ID, SELF_DESC_TOOL_ID +from apps.scheduler.pool.mcp.pool import MCPPool +from apps.schemas.enum_var import EventType, FlowStatus, StepStatus +from apps.schemas.mcp import ( + MCPCollection, + MCPTool, + Step, +) +from apps.schemas.message import FlowParams +from apps.schemas.task import FlowStepHistory +from apps.services.appcenter import AppCenterManager +from apps.services.mcp_service import MCPServiceManager from apps.services.task import TaskManager +from apps.services.user import UserManager logger = logging.getLogger(__name__) @@ -16,17 +34,581 @@ logger = logging.getLogger(__name__) class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" - question: str = Field(description="用户输入") - max_steps: int = Field(default=20, description="最大步数") + max_steps: int = Field(default=40, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") + mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[]) + mcp_pool: MCPPool = Field(description="MCP池", default_factory=MCPPool) + tools: dict[str, MCPTool] = Field( + description="MCP工具列表,key为tool_id", + default={}, + ) + tool_list: list[MCPTool] = Field( + description="MCP工具列表,包含所有MCP工具", + default=[], + ) + params: FlowParams | bool | None = Field( + default=None, + description="流执行过程中的参数补充", + alias="params", + ) + resoning_llm: ReasoningLLM = Field( + default_factory=ReasoningLLM, + description="推理大模型", + ) + + async def update_tokens(self) -> None: + """更新令牌数""" + self.task.tokens.input_tokens = self.resoning_llm.input_tokens + self.task.tokens.output_tokens = self.resoning_llm.output_tokens + await TaskManager.save_task(self.task.id, self.task) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 
加载Executor状态") # 尝试恢复State - if self.task.state: - context_objects = await TaskManager.get_context_by_task_id(self.task.id) - # 将对象转换为字典以保持与系统其他部分的一致性 - self.task.context = [context.model_dump(exclude_none=True, by_alias=True) for context in context_objects] + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + + async def load_mcp(self) -> None: + """加载MCP服务器列表""" + logger.info("[MCPAgentExecutor] 加载MCP服务器列表") + # 获取MCP服务器列表 + app = await AppCenterManager.fetch_app_data_by_id(self.agent_id) + mcp_ids = app.mcp_service + for mcp_id in mcp_ids: + mcp_service = await MCPServiceManager.get_mcp_service(mcp_id) + if self.task.ids.user_sub not in mcp_service.activated: + logger.warning( + "[MCPAgentExecutor] 用户 %s 未启用MCP %s", + self.task.ids.user_sub, + mcp_id, + ) + continue + + self.mcp_list.append(mcp_service) + await self.mcp_pool._init_mcp(mcp_id, self.task.ids.user_sub) + for tool in mcp_service.tools: + self.tools[tool.id] = tool + self.tool_list.extend(mcp_service.tools) + if self.task.language == LanguageType.CHINESE: + self.tools[FINAL_TOOL_ID] = MCPTool( + id=FINAL_TOOL_ID, name="Final Tool", description="结束流程的工具", mcp_id="", input_schema={}, + ) + self.tool_list.append( + MCPTool(id=FINAL_TOOL_ID, name="Final Tool", description="结束流程的工具", mcp_id="", input_schema={}), + ) + self.tools[SELF_DESC_TOOL_ID] = MCPTool( + id=SELF_DESC_TOOL_ID, + name="Self Description", + description="用于描述自身能力和背景信息的工具", + mcp_id="", + input_schema={}, + ) + self.tool_list.append( + MCPTool( + id=SELF_DESC_TOOL_ID, + name="Self Description", + description="用于描述自身能力和背景信息的工具", + mcp_id="", + input_schema={}, + ) + ) + else: + self.tools[FINAL_TOOL_ID] = MCPTool(id=FINAL_TOOL_ID, name="Final Tool", + description="The tool to end the process", mcp_id="", input_schema={},) + self.tool_list.append( + MCPTool( + id=FINAL_TOOL_ID, name="Final Tool", description="The tool to end the process", + mcp_id="", input_schema={}),) + self.tools[SELF_DESC_TOOL_ID] = MCPTool( + id=SELF_DESC_TOOL_ID, + name="Self Description", + description="A tool used to describe one's own abilities and background information", + mcp_id="", + input_schema={}, + ) + self.tool_list.append( + MCPTool( + id=SELF_DESC_TOOL_ID, + name="Self Description", + description="A tool used to describe one's own abilities and background information", + mcp_id="", + input_schema={}, + ) + ) + + async def get_tool_input_param(self, is_first: bool) -> None: + # 工具的入参是 {} ,不需要填充 + if self.task.state.tool_id in [FINAL_TOOL_ID, SELF_DESC_TOOL_ID]: + self.task.state.current_input = {} + return + if is_first: + # 获取第一个输入参数 + mcp_tool = self.tools[self.task.state.tool_id] + self.task.state.current_input = await MCPHost._get_first_input_params( + mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task + ) + else: + # 获取后续输入参数 + if isinstance(self.params, FlowParams): + params = self.params.content + params_description = self.params.description + else: + params = {} + params_description = "" + mcp_tool = self.tools[self.task.state.tool_id] + self.task.state.current_input = await MCPHost._fill_params( + mcp_tool, + self.task.runtime.question, + self.task.state.step_description, + self.task.state.current_input, + self.task.state.error_message, + params, + params_description, + self.task.language, + ) + + async def confirm_before_step(self) -> None: + """确认前步骤""" + # 发送确认消息 + mcp_tool = self.tools[self.task.state.tool_id] + confirm_message = await 
MCPPlanner.get_tool_risk( + mcp_tool, self.task.state.current_input, "", self.resoning_llm, self.task.language + ) + await self.update_tokens() + await self.push_message( + EventType.STEP_WAITING_FOR_START, confirm_message.model_dump(exclude_none=True, by_alias=True), + ) + await self.push_message(EventType.FLOW_STOP, {}) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.WAITING + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True), + ) + ) + + async def run_step(self) -> None: + """执行步骤""" + self.task.state.flow_status = FlowStatus.RUNNING + self.task.state.step_status = StepStatus.RUNNING + mcp_tool = self.tools[self.task.state.tool_id] + result_exchange = True + try: + if self.task.state.tool_id == SELF_DESC_TOOL_ID: + tools = [] + for tool in self.tool_list: + if tool.id not in [SELF_DESC_TOOL_ID, FINAL_TOOL_ID]: + tools.append(f"{tool.name}: {tool.description}") + output_params = { + "message": tools + } + result_exchange = False + else: + mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.ids.user_sub)) + output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input) + except anyio.ClosedResourceError: + logger.exception("[MCPAgentExecutor] MCP客户端连接已关闭: %s", mcp_tool.mcp_id) + await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.ids.user_sub) + await self.mcp_pool._init_mcp(mcp_tool.mcp_id, self.task.ids.user_sub) + self.task.state.step_status = StepStatus.ERROR + return + except Exception as e: + import traceback + logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误: %s", mcp_tool.name, traceback.format_exc()) + self.task.state.step_status = StepStatus.ERROR + self.task.state.error_message = str(e) + return + if result_exchange: + if output_params.isError: + err = "" + for output in output_params.content: + if isinstance(output, TextContent): + err += output.text + self.task.state.step_status = StepStatus.ERROR + self.task.state.error_message = err + return + message = "" + for output in output_params.content: + if isinstance(output, TextContent): + message += output.text + output_params = { + "message": message, + } + + await self.update_tokens() + await self.push_message(EventType.STEP_INPUT, self.task.state.current_input) + await self.push_message(EventType.STEP_OUTPUT, output_params) + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=StepStatus.SUCCESS, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data=self.task.state.current_input, + output_data=output_params, + ) + ) + self.task.state.step_status = StepStatus.SUCCESS + + async def generate_params_with_null(self) -> None: + """生成参数补充""" + mcp_tool = self.tools[self.task.state.tool_id] + params_with_null = await MCPPlanner.get_missing_param( + mcp_tool, + self.task.state.current_input, + self.task.state.error_message, + self.resoning_llm, + self.task.language, + ) + await self.update_tokens() + error_message = 
await MCPPlanner.change_err_message_to_description( + error_message=self.task.state.error_message, + tool=mcp_tool, + input_params=self.task.state.current_input, + reasoning_llm=self.resoning_llm, + language=self.task.language, + ) + await self.push_message( + EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null} + ) + await self.push_message(EventType.FLOW_STOP, data={}) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.PARAM + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data={ + "message": error_message, + "params": params_with_null + } + ) + ) + + async def get_next_step(self) -> None: + """获取下一步""" + self.task.tokens.time=datetime.now(UTC).timestamp() + self.task.state.retry_times = 0 + if self.task.state.step_cnt < self.max_steps: + self.task.state.step_cnt += 1 + history = await MCPHost.assemble_memory(self.task) + max_retry = 3 + step = None + for i in range(max_retry): + try: + step = await MCPPlanner.create_next_step(self.task.runtime.question, history, self.tool_list, language=self.task.language) + if step.tool_id in self.tools.keys(): + break + except Exception as e: + logger.warning("[MCPAgentExecutor] 获取下一步失败,重试中: %s", str(e)) + if step is None or step.tool_id not in self.tools.keys(): + step = Step( + tool_id=FINAL_TOOL_ID, + description=FINAL_TOOL_ID + ) + tool_id = step.tool_id + if tool_id == FINAL_TOOL_ID: + step_name = FINAL_TOOL_ID + else: + step_name = self.tools[tool_id].name + step_description = step.description + self.task.state.step_id = str(uuid.uuid4()) + self.task.state.tool_id = tool_id + self.task.state.step_name = step_name + self.task.state.step_description = step_description + self.task.state.step_status = StepStatus.INIT + self.task.state.current_input = {} + else: + # 没有下一步了,结束流程 + self.task.state.tool_id = FINAL_TOOL_ID + return + + async def error_handle_after_step(self) -> None: + """步骤执行失败后的错误处理""" + self.task.state.step_status = StepStatus.ERROR + self.task.state.flow_status = FlowStatus.ERROR + await self.push_message( + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) + self.task.state.tool_id = FINAL_TOOL_ID + + async def work(self) -> None: + """执行当前步骤""" + if self.task.state.step_status == StepStatus.INIT: + await self.push_message( + EventType.STEP_INIT, + data={} + ) + await self.get_tool_input_param(is_first=True) + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + if not user_info.auto_execute: + # 等待用户确认 + await self.confirm_before_step() + return + self.task.state.step_status = StepStatus.RUNNING + elif self.task.state.step_status in [StepStatus.PARAM, StepStatus.WAITING, 
StepStatus.RUNNING]: + if self.task.state.step_status == StepStatus.PARAM: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + await self.get_tool_input_param(is_first=False) + elif self.task.state.step_status == StepStatus.WAITING: + if self.params: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + else: + self.task.state.flow_status = FlowStatus.CANCELLED + self.task.state.step_status = StepStatus.CANCELLED + await self.push_message( + EventType.STEP_CANCEL, + data={} + ) + await self.push_message( + EventType.FLOW_CANCEL, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.CANCELLED + return + max_retry = 5 + for i in range(max_retry): + if i != 0: + await self.get_tool_input_param(is_first=True) + await self.run_step() + if self.task.state.step_status == StepStatus.SUCCESS: + break + elif self.task.state.step_status == StepStatus.ERROR: + # 错误处理 + self.task.state.retry_times += 1 + if self.task.state.retry_times >= 3: + await self.error_handle_after_step() + else: + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + if user_info.auto_execute: + await self.push_message( + EventType.STEP_ERROR, + data={ + "message": self.task.state.error_message, + } + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.ERROR + self.task.context[-1].output_data = { + "message": self.task.state.error_message, + } + else: + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=StepStatus.ERROR, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data=self.task.state.current_input, + output_data={ + "message": self.task.state.error_message, + }, + ) + ) + await self.get_next_step() + else: + mcp_tool = self.tools[self.task.state.tool_id] + is_param_error = await MCPPlanner.is_param_error( + self.task.runtime.question, + await MCPHost.assemble_memory(self.task), + self.task.state.error_message, + mcp_tool, + self.task.state.step_description, + self.task.state.current_input, + language=self.task.language, + ) + if is_param_error.is_param_error: + # 如果是参数错误,生成参数补充 + await self.generate_params_with_null() + else: + await self.push_message( + EventType.STEP_ERROR, + data={ + "message": self.task.state.error_message, + } + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.ERROR + self.task.context[-1].output_data = { + "message": self.task.state.error_message, + } + else: + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=StepStatus.ERROR, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data=self.task.state.current_input, + output_data={ + "message": self.task.state.error_message, + }, + ) + ) + await self.get_next_step() + elif self.task.state.step_status == StepStatus.SUCCESS: + await self.get_next_step() + + 
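# A condensed sketch (assumption, not part of this patch) of the outer retry loop implemented
# in MCPAgentExecutor.work() above: a step is attempted up to max_retry times, the tool input
# is regenerated before each retry, and the loop stops at the first success. The surrounding
# bookkeeping (retry_times, parameter-error detection, user confirmation) is omitted; the real
# statuses come from apps.schemas.enum_var.
import asyncio
from collections.abc import Awaitable, Callable
from enum import Enum


class StepStatus(str, Enum):  # stand-in for apps.schemas.enum_var.StepStatus
    SUCCESS = "success"
    ERROR = "error"


async def run_step_with_retry(
    run_step: Callable[[], Awaitable[StepStatus]],
    refill_input: Callable[[], Awaitable[None]],
    max_retry: int = 5,
) -> StepStatus:
    """Retry a step, regenerating its input between attempts, until it succeeds or retries run out."""
    status = StepStatus.ERROR
    for attempt in range(max_retry):
        if attempt:  # the first attempt reuses the input that was already filled in
            await refill_input()
        status = await run_step()
        if status == StepStatus.SUCCESS:
            break
    return status


async def _demo() -> None:
    calls = {"n": 0}

    async def flaky_step() -> StepStatus:
        calls["n"] += 1
        return StepStatus.SUCCESS if calls["n"] >= 3 else StepStatus.ERROR

    async def refill() -> None:
        pass

    print(await run_step_with_retry(flaky_step, refill))  # StepStatus.SUCCESS on the third attempt


if __name__ == "__main__":
    asyncio.run(_demo())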
async def summarize(self) -> None: + """总结""" + async for chunk in MCPPlanner.generate_answer( + self.task.runtime.question, + (await MCPHost.assemble_memory(self.task)), + self.resoning_llm, + self.task.language, + ): + await self.push_message( + EventType.TEXT_ADD, + data=chunk + ) + self.task.runtime.answer += chunk + + async def run(self) -> None: + """执行MCP Agent的主逻辑""" + # 初始化MCP服务 + await self.load_state() + await self.load_mcp() + data = {} + if self.task.state.flow_status == FlowStatus.INIT: + # 初始化状态 + try: + self.task.state.flow_id = str(uuid.uuid4()) + self.task.state.flow_name = (await MCPPlanner.get_flow_name( + self.task.runtime.question, self.resoning_llm, self.task.language + )).flow_name + flow_risk = await MCPPlanner.get_flow_excute_risk( + self.task.runtime.question, self.tool_list, self.resoning_llm, self.task.language + ) + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + if user_info.auto_execute: + data = flow_risk.model_dump(exclude_none=True, by_alias=True) + await TaskManager.save_task(self.task.id, self.task) + await self.get_next_step() + except Exception as e: + logger.exception("[MCPAgentExecutor] 初始化失败") + self.task.state.flow_status = FlowStatus.ERROR + self.task.state.error_message = str(e) + await self.push_message( + EventType.FLOW_FAILED, + data={} + ) + return + self.task.state.flow_status = FlowStatus.RUNNING + await self.push_message( + EventType.FLOW_START, + data=data + ) + if self.task.state.tool_id == FINAL_TOOL_ID: + # 如果已经是最后一步,直接结束 + self.task.state.flow_status = FlowStatus.SUCCESS + await self.push_message( + EventType.FLOW_SUCCESS, + data={} + ) + await self.summarize() + return + try: + while self.task.state.flow_status == FlowStatus.RUNNING: + if self.task.state.tool_id == FINAL_TOOL_ID: + break + await self.work() + await TaskManager.save_task(self.task.id, self.task) + tool_id = self.task.state.tool_id + if tool_id == FINAL_TOOL_ID: + # 如果已经是最后一步,直接结束 + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_status = StepStatus.SUCCESS + await self.push_message( + EventType.FLOW_SUCCESS, + data={} + ) + await self.summarize() + except Exception as e: + logger.exception("[MCPAgentExecutor] 执行过程中发生错误") + self.task.state.flow_status = FlowStatus.ERROR + self.task.state.error_message = str(e) + self.task.state.step_status = StepStatus.ERROR + await self.push_message( + EventType.STEP_ERROR, + data={} + ) + await self.push_message( + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) + finally: + for mcp_service in self.mcp_list: + try: + await self.mcp_pool.stop(mcp_service.id, self.task.ids.user_sub) + except Exception as e: + import traceback + logger.error("[MCPAgentExecutor] 停止MCP客户端时发生错误: %s", traceback.format_exc()) \ No newline at end of file diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 7ae03b73320f9bb7a6c8ac7febb14c52704bec4c..56839ee0a9a4133606e98aad5bd0c59909c73907 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ 
-44,15 +44,8 @@ class BaseExecutor(BaseModel, ABC): :param event_type: 事件类型 :param data: 消息数据,如果是FLOW_START事件且data为None,则自动构建FlowStartContent """ - if event_type == EventType.FLOW_START.value and isinstance(data, dict): - data = FlowStartContent( - question=self.question, - params=self.task.runtime.filled, - ).model_dump(exclude_none=True, by_alias=True) - elif event_type == EventType.FLOW_STOP.value: - data = {} - elif event_type == EventType.TEXT_ADD.value and isinstance(data, str): - data=TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) + if event_type == EventType.TEXT_ADD.value and isinstance(data, str): + data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) if data is None: data = {} diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a3cd2b01308c41d1234250795703fe84d2e8a953..e70e4e25de3e6c8ada39c8c7fdc08d01311856ee 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -12,7 +12,7 @@ from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.executor.step import StepExecutor from apps.scheduler.variable.integration import VariableIntegration -from apps.schemas.enum_var import EventType, SpecialCallType, StepStatus +from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus, LanguageType from apps.schemas.flow import Flow, Step from apps.schemas.request_data import RequestDataApp from apps.schemas.task import ExecutorState, StepQueueItem @@ -21,21 +21,37 @@ from apps.services.task import TaskManager logger = logging.getLogger(__name__) # 开始前的固定步骤 FIXED_STEPS_BEFORE_START = [ - Step( - name="理解上下文", - description="使用大模型,理解对话上下文", - node=SpecialCallType.SUMMARY.value, - type=SpecialCallType.SUMMARY.value, - ), + { + LanguageType.CHINESE: Step( + name="理解上下文", + description="使用大模型,理解对话上下文", + node=SpecialCallType.SUMMARY.value, + type=SpecialCallType.SUMMARY.value, + ), + LanguageType.ENGLISH: Step( + name="Understand context", + description="Use large model to understand the context of the dialogue", + node=SpecialCallType.SUMMARY.value, + type=SpecialCallType.SUMMARY.value, + ), + }, ] # 结束后的固定步骤 FIXED_STEPS_AFTER_END = [ - Step( - name="记忆存储", - description="理解对话答案,并存储到记忆中", - node=SpecialCallType.FACTS.value, - type=SpecialCallType.FACTS.value, - ), + { + LanguageType.CHINESE: Step( + name="记忆存储", + description="理解对话答案,并存储到记忆中", + node=SpecialCallType.FACTS.value, + type=SpecialCallType.FACTS.value, + ), + LanguageType.ENGLISH: Step( + name="Memory storage", + description="Understand the answer of the dialogue and store it in the memory", + node=SpecialCallType.FACTS.value, + type=SpecialCallType.FACTS.value, + ), + }, ] @@ -47,13 +63,20 @@ class FlowExecutor(BaseExecutor): flow_id: str = Field(description="Flow ID") question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - current_step: StepQueueItem | None = Field(default=None, description="当前执行的步骤") + current_step: StepQueueItem | None = Field( + description="当前执行的步骤", + default=None + ) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if ( + self.task.state + and self.task.state.flow_status != FlowStatus.INIT + and self.task.state.flow_status != FlowStatus.UNKNOWN + ): context_objects = await TaskManager.get_context_by_task_id(self.task.id) # 将对象转换为字典以保持与系统其他部分的一致性 
self.task.context = [context.model_dump(exclude_none=True, by_alias=True) for context in context_objects] @@ -62,11 +85,13 @@ class FlowExecutor(BaseExecutor): self.task.state = ExecutorState( flow_id=str(self.flow_id), flow_name=self.flow.name, + flow_status=FlowStatus.RUNNING, description=str(self.flow.description), - status=StepStatus.RUNNING, + step_status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), step_id="start", - step_name="开始", + # TODO 这种写法不利于多语言扩展建议改写 + step_name="开始" if self.task.language == LanguageType.CHINESE else "Start", ) self.validate_flow_state(self.task) # 是否到达Flow结束终点(变量) @@ -199,40 +224,52 @@ class FlowExecutor(BaseExecutor): # 头插开始前的系统步骤,并执行 for step in FIXED_STEPS_BEFORE_START: - self.step_queue.append(StepQueueItem( - step_id=str(uuid.uuid4()), - step=step, - enable_filling=False, - to_user=False, - )) + self.step_queue.append( + StepQueueItem( + step_id=str(uuid.uuid4()), + step=step.get(self.task.language, step[LanguageType.CHINESE]), + enable_filling=False, + to_user=False, + ) + ) await self._step_process() - + # 插入首个步骤 self.step_queue.append(first_step) - + + self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] + # 运行Flow(未达终点) + is_error = False while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() - self.step_queue.appendleft(StepQueueItem( - step_id=str(uuid.uuid4()), - step=Step( - name="错误处理", - description="错误处理", - node=SpecialCallType.LLM.value, - type=SpecialCallType.LLM.value, - params={ - "user_prompt": LLM_ERROR_PROMPT.replace( - "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.step_queue.appendleft( + StepQueueItem( + step_id=str(uuid.uuid4()), + step=Step( + name=( + "错误处理" if self.task.language == LanguageType.CHINESE else "Error Handling" + ), + description=( + "错误处理" if self.task.language == LanguageType.CHINESE else "Error Handling" ), - }, - ), - enable_filling=False, - to_user=False, - )) + node=SpecialCallType.LLM.value, + type=SpecialCallType.LLM.value, + params={ + "user_prompt": LLM_ERROR_PROMPT[self.task.language].replace( + "{{ error_info }}", + self.task.state.error_info["err_msg"], # type: ignore[arg-type] + ), + }, + ), + enable_filling=False, + to_user=False, + ) + ) + is_error = True # 错误处理后结束 self._reached_end = True @@ -251,15 +288,26 @@ class FlowExecutor(BaseExecutor): else: logger.info("[FlowExecutor] 步骤 %s 已经执行过,不再添加到队列中", step.step_id) + # 更新Task状态 + if is_error: + self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] + else: + self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: - self.step_queue.append(StepQueueItem( - step_id=str(uuid.uuid4()), - step=step, - )) + self.step_queue.append( + StepQueueItem( + step_id=str(uuid.uuid4()), + step=step.get(self.task.language, step[LanguageType.CHINESE]), + ) + ) await self._step_process() # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time # 推送Flow停止消息 - await self.push_message(EventType.FLOW_STOP.value) + if is_error: + await self.push_message(EventType.FLOW_FAILED.value) + else: + await self.push_message(EventType.FLOW_SUCCESS.value) diff --git a/apps/scheduler/executor/step.py 
b/apps/scheduler/executor/step.py index 7f43703aa83ba02f96683dbc2cc066c0c2c1f4ca..62ee6a49856e0efaf6552c66943916c822213f23 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -76,6 +76,10 @@ class StepExecutor(BaseExecutor): return Slot if call_id == SpecialCallType.DIRECT_REPLY.value: return DirectReply + if call_id == SpecialCallType.START.value: + return Empty # start节点使用Empty类 + if call_id == SpecialCallType.END.value: + return Empty # end节点使用Empty类 # 从Pool中获取对应的Call call_cls: type[CoreCall] = await Pool().get_call(call_id) @@ -173,9 +177,9 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.step_status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 @@ -422,19 +426,19 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) # 执行步骤 - iterator = self.obj.exec(self, self.obj.input) + iterator = self.obj.exec(self, self.obj.input, language=self.task.language) try: content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.step_status = StepStatus.ERROR # type: ignore[arg-type] # 构建错误输出数据 if isinstance(e, CallError): @@ -463,7 +467,7 @@ class StepExecutor(BaseExecutor): return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -482,14 +486,15 @@ class StepExecutor(BaseExecutor): task_id=self.task.id, flow_id=self.task.state.flow_id, # type: ignore[arg-type] flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_status=self.task.state.flow_status, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + step_status=self.task.state.step_status, # type: ignore[arg-type] input_data=self.obj.input, output_data=output_data, ) - self.task.context.append(history.model_dump(exclude_none=True, by_alias=True)) + self.task.context.append(history) try: await self.push_message(EventType.STEP_OUTPUT.value, output_data) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 78aa7bc3ee869e8710e1fb02a2d9fb438d04be34..8e7b26e38e2eba24b9b9c735b93e6a5a12407970 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -14,7 +14,7 @@ from apps.llm.function import JsonGenerator from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.pool.mcp.client import MCPClient from apps.scheduler.pool.mcp.pool import MCPPool -from 
apps.schemas.enum_var import StepStatus +from apps.schemas.enum_var import StepStatus, LanguageType from apps.schemas.mcp import MCPPlanItem, MCPTool from apps.schemas.task import FlowStepHistory from apps.services.task import TaskManager @@ -25,10 +25,18 @@ logger = logging.getLogger(__name__) class MCPHost: """MCP宿主服务""" - def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None: + def __init__( + self, + user_sub: str, + task_id: str, + runtime_id: str, + runtime_name: str, + language: LanguageType = LanguageType.CHINESE, + ) -> None: """初始化MCP宿主""" self._user_sub = user_sub self._task_id = task_id + self.language = language # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description self._runtime_id = runtime_id self._runtime_name = runtime_name @@ -40,7 +48,6 @@ class MCPHost: lstrip_blocks=True, ) - async def get_client(self, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" mongo = MongoDB() @@ -59,7 +66,6 @@ class MCPHost: logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) return None - async def assemble_memory(self) -> str: """组装记忆""" task = await TaskManager.get_task_by_task_id(self._task_id) @@ -69,16 +75,15 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue context_list.append(context) - return self._env.from_string(MEMORY_TEMPLATE).render( + return self._env.from_string(MEMORY_TEMPLATE[self.language]).render( context_list=context_list, ) - async def _save_memory( self, tool: MCPTool, @@ -105,11 +110,12 @@ class MCPHost: task_id=self._task_id, flow_id=self._runtime_id, flow_name=self._runtime_name, + flow_status=StepStatus.RUNNING, step_id=tool.name, step_name=tool.name, # description是规划的实际内容 step_description=plan_item.content, - status=StepStatus.SUCCESS, + step_status=StepStatus.SUCCESS, input_data=input_data, output_data=output_data, ) @@ -120,12 +126,11 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context.model_dump(exclude_none=True, by_alias=True)) await TaskManager.save_task(self._task_id, task) return output_data - async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate @@ -146,7 +151,6 @@ class MCPHost: ) return await json_generator.generate() - async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client @@ -170,7 +174,6 @@ class MCPHost: return processed_result - async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: """获取工具列表""" mongo = MongoDB() diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py index cd4f5975eea3f023a92626966081c2d1eb33bdb7..78d695f2cc9fc47c995245f0e8cf059b60e76a3d 100644 --- a/apps/scheduler/mcp/plan.py +++ b/apps/scheduler/mcp/plan.py @@ -8,14 +8,16 @@ from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER from apps.schemas.mcp import MCPPlan, MCPTool +from apps.schemas.enum_var import LanguageType class MCPPlanner: """MCP 用户目标拆解与规划""" - def __init__(self, user_goal: str) -> None: + def __init__(self, user_goal: str, language: LanguageType = LanguageType.CHINESE) -> None: """初始化MCP规划器""" 
self.user_goal = user_goal + self.language = language self._env = SandboxedEnvironment( loader=BaseLoader, autoescape=True, @@ -25,7 +27,6 @@ class MCPPlanner: self.input_tokens = 0 self.output_tokens = 0 - async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: """规划下一步的执行流程,并输出""" # 获取推理结果 @@ -38,7 +39,7 @@ class MCPPlanner: async def _get_reasoning_plan(self, tool_list: list[MCPTool], max_steps: int) -> str: """获取推理大模型的结果""" # 格式化Prompt - template = self._env.from_string(CREATE_PLAN) + template = self._env.from_string(CREATE_PLAN[self.language]) prompt = template.render( goal=self.user_goal, tools=tool_list, @@ -88,7 +89,7 @@ class MCPPlanner: async def generate_answer(self, plan: MCPPlan, memory: str) -> str: """生成最终回答""" - template = self._env.from_string(FINAL_ANSWER) + template = self._env.from_string(FINAL_ANSWER[self.language]) prompt = template.render( plan=plan, memory=memory, diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index b322fb0883e8ed935243389cb86066845a549631..29721b31c92a84d875052f3097af6e3e5c6db82d 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -2,8 +2,11 @@ """MCP相关的大模型Prompt""" from textwrap import dedent +from apps.schemas.enum_var import LanguageType -MCP_SELECT = dedent(r""" +MCP_SELECT: dict[str, str] = { + LanguageType.CHINESE: dedent( + r""" 你是一个乐于助人的智能助手。 你的任务是:根据当前目标,选择最合适的MCP Server。 @@ -61,8 +64,73 @@ MCP_SELECT = dedent(r""" ### 请一步一步思考: -""") -CREATE_PLAN = dedent(r""" +""" + ), + LanguageType.ENGLISH: dedent( + r""" + You are a helpful intelligent assistant. + Your task is to select the most appropriate MCP server based on your current goals. + + ## Things to note when selecting an MCP server: + + 1. Ensure you fully understand your current goals and select the most appropriate MCP server. + 2. Please select from the provided list of MCP servers; do not generate your own. + 3. Please provide the rationale for your choice before making your selection. + 4. Your current goals will be listed below, along with the list of MCP servers. + Please include your thought process in the "Thought Process" section and your selection in the "Selection Results" section. + 5. Your selection must be in JSON format, strictly following the template below. Do not output any additional content: + + ```json + { + "mcp": "The name of your selected MCP server" + } + ``` + + 6. The following example is for reference only. Do not use it as a basis for selecting an MCP server. + + ## Example + + ### Goal + + I need an MCP server to complete a task. + + ### MCP Server List + + - **mcp_1**: "MCP Server 1"; Description of MCP Server 1 + - **mcp_2**: "MCP Server 2"; Description of MCP Server 2 + + ### Think step by step: + + Because the current goal requires an MCP server to complete a task, select mcp_1. + + ### Select Result + + ```json + { + "mcp": "mcp_1" + } + ``` + + ## Let's get started! + + ### Goal + + {{goal}} + + ### MCP Server List + + {% for mcp in mcp_list %} + - **{{mcp.id}}**: "{{mcp.name}}"; {{mcp.description}} + {% endfor %} + + ### Think step by step: + +""" + ), +} +CREATE_PLAN: dict[str, str] = { + LanguageType.CHINESE: dedent( + r""" 你是一个计划生成器。 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 @@ -72,7 +140,8 @@ CREATE_PLAN = dedent(r""" 2. 计划中的每一个步骤必须且只能使用一个工具。 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 4. 
计划中的最后一步必须是Final工具,以确保计划执行结束。 - + 5.生成的计划必须要覆盖用户的目标,不能遗漏任何用户目标中的内容。 + # 生成计划时的注意事项: - 每一条计划包含3个部分: @@ -93,8 +162,7 @@ CREATE_PLAN = dedent(r""" } ``` - - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ -思考过程应放置在 XML标签中。 + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应按步骤顺序放置在 XML标签中。 - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 @@ -106,8 +174,7 @@ CREATE_PLAN = dedent(r""" {% for tool in tools %} - {{ tool.id }}{{tool.name}};{{ tool.description }} {% endfor %} - - Final结束步骤,当执行到这一步时,\ -表示计划执行结束,所得到的结果将作为最终结果。 + - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 # 样例 @@ -162,8 +229,114 @@ CREATE_PLAN = dedent(r""" {{goal}} # 计划 -""") -EVALUATE_PLAN = dedent(r""" +""" + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan generator. + Please analyze the user's goal and generate a plan. You will then follow this plan to achieve the user's goal step by step. + + # A good plan should: + + 1. Be able to successfully achieve the user's goal. + 2. Each step in the plan must use only one tool. + 3. The steps in the plan must have clear and logical steps, without redundant or unnecessary steps. + 4. The last step in the plan must be a Final tool to ensure that the plan is executed. + + # Things to note when generating plans: + + - Each plan contains three parts: + - Plan content: Describes the general content of a single plan step + - Tool ID: Must be selected from the tool list below + - Tool instructions: Rewrite the user's goal to make it more consistent with the tool's input requirements + - Plans must be generated in the following format. Do not output any additional data: + + ```json + { + "plans": [ + { + "content":"Plan content", + "tool":"Tool ID", + "instruction":"Tool instructions" + } + ] + } + ``` + + - Before generating a plan, please think step by step, analyze the user's goal, and guide your next steps. The thought process should be placed in sequential steps within XML tags. + - In the plan content, you can use "Result[]" to reference the results of the previous plan steps. For example: "Result[3]" refers to the result after the third plan is executed. + - The plan should not have more than {{ max_num }} items, and each plan content should be less than 150 words. + + # Tools + + You can access and use some tools, which will be given in the XML tags. + + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}}; {{ tool.description }} + {% endfor %} + - FinalEnd step. When this step is executed, \ + Indicates that the plan execution is completed and the result obtained will be used as the final result. + + + # Example + + ## Target + + Run a new alpine:latest container in the background, mount the host/root folder to /data, and execute the top command. + + ## Plan + + + 1. This goal needs to be completed using Docker. First, you need to select a suitable MCP Server. + 2. The goal can be broken down into the following parts: + - Run the alpine:latest container + - Mount the host directory + - Run in the background + - Execute the top command + 3. You need to select an MCP Server first, then generate the Docker command, and finally execute the command. 
+ + + ```json + { + "plans": [ + { + "content": "Select an MCP Server that supports Docker", + "tool": "mcp_selector", + "instruction": "You need an MCP Server that supports running Docker containers" + }, + { + "content": "Use the MCP Server selected in Result[0] to generate Docker commands", + "tool": "command_generator", + "instruction": "Generate Docker command: Run the alpine:latest container in the background, mount /root to /data, and execute the top command" + }, + { + "content": "Execute the command generated by Result[1] on the MCP Server of Result[0]", + "tool": "command_executor", + "instruction": "Execute Docker command" + }, + { + "content": "Task execution completed, the container is running in the background, the result is Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # Now start generating the plan: + + ## Goal + + {{goal}} + + # Plan +""" + ), +} +EVALUATE_PLAN: dict[str, str] = { + LanguageType.CHINESE: dedent( + r""" 你是一个计划评估器。 请根据给定的计划,和当前计划执行的实际情况,分析当前计划是否合理和完整,并生成改进后的计划。 @@ -209,8 +382,61 @@ EVALUATE_PLAN = dedent(r""" # 现在开始评估计划: -""") -FINAL_ANSWER = dedent(r""" +""" + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan evaluator. + Based on the given plan and the actual execution of the current plan, analyze whether the current plan is reasonable and complete, and generate an improved plan. + + # A good plan should: + + 1. Be able to successfully achieve the user's goal. + 2. Each step in the plan must use only one tool. + 3. The steps in the plan must have clear and logical steps, without redundant or unnecessary steps. + 4. The last step in the plan must be a Final tool to ensure the completion of the plan execution. + + # Your previous plan was: + + {{ plan }} + + # The execution status of this plan is: + + The execution status of the plan will be placed in the XML tags. + + + {{ memory }} + + + # Notes when conducting the evaluation: + + - Please think step by step, analyze the user's goal, and guide your subsequent generation. The thinking process should be placed in the XML tags. + - The evaluation results are divided into two parts: + - Conclusion of the plan evaluation + - Improved plan + - Please output the evaluation results in the following JSON format: + + ```json + { + "evaluation": "Evaluation results", + "plans": [ + { + "content": "Improved plan content", + "tool": "Tool ID", + "instruction": "Tool instructions" + } + ] + } + ``` + + # Start evaluating the plan now: + +""" + ), +} +FINAL_ANSWER: dict[str, str] = { + LanguageType.CHINESE: dedent( + r""" 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 # 用户目标 @@ -229,12 +455,50 @@ FINAL_ANSWER = dedent(r""" # 现在,请根据以上信息,向用户报告目标的完成情况: -""") -MEMORY_TEMPLATE = dedent(r""" +""" + ), + LanguageType.ENGLISH: dedent( + r""" + Based on the understanding of the plan execution results and background information, report to the user the completion status of the goal. 
+
+    # User goal
+
+    {{ goal }}
+
+    # Plan execution status
+
+    To achieve the above goal, you implemented the following plan:
+
+    {{ memory }}
+
+    # Other background information:
+
+    {{ status }}
+
+    # Now, based on the above information, please report to the user the completion status of the goal:
+
+"""
+    ),
+}
+MEMORY_TEMPLATE: dict[str, str] = {
+    LanguageType.CHINESE: dedent(
+        r"""
     {% for ctx in context_list %}
     - 第{{ loop.index }}步:{{ ctx.step_description }}
-      - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}`
-      - 执行状态:{{ ctx.status }}
-      - 得到数据:`{{ ctx.output_data }}`
+      - 调用工具 `{{ ctx.step_name }}`,并提供参数 `{{ ctx.input_data|tojson }}`。
+      - 执行状态:{{ ctx.step_status }}
+      - 得到数据:`{{ ctx.output_data|tojson }}`
+    {% endfor %}
+"""
+    ),
+    LanguageType.ENGLISH: dedent(
+        r"""
+    {% for ctx in context_list %}
+    - Step {{ loop.index }}: {{ ctx.step_description }}
+      - Called tool `{{ ctx.step_name }}` and provided parameters `{{ ctx.input_data|tojson }}`
+      - Execution status: {{ ctx.step_status }}
+      - Got data: `{{ ctx.output_data|tojson }}`
     {% endfor %}
-""")
+"""
+    ),
+}
diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py
index 2ff5034471c5e9c38f166c6187b76dfb4596f734..e8b0e88c09ac1d1ac63f995b779da8483a2fce22 100644
--- a/apps/scheduler/mcp/select.py
+++ b/apps/scheduler/mcp/select.py
@@ -14,6 +14,7 @@ from apps.llm.reasoning import ReasoningLLM
 from apps.scheduler.mcp.prompt import (
     MCP_SELECT,
 )
+from apps.schemas.enum_var import LanguageType
 from apps.schemas.mcp import (
     MCPCollection,
     MCPSelectResult,
@@ -39,7 +40,6 @@ class MCPSelector:
         sql += f"'{mcp_id}', "
         return sql.rstrip(", ") + ")"

-
     async def _get_top_mcp_by_embedding(
         self,
         query: str,
@@ -49,10 +49,17 @@
         logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list)
         mcp_table = await LanceDB().get_table("mcp")
         query_embedding = await Embedding.get_embedding([query])
-        mcp_vecs = await (await mcp_table.search(
-            query=query_embedding,
-            vector_column_name="embedding",
-        )).where(f"id IN {MCPSelector._assemble_sql(mcp_list)}").limit(5).to_list()
+        mcp_vecs = (
+            await (
+                await mcp_table.search(
+                    query=query_embedding,
+                    vector_column_name="embedding",
+                )
+            )
+            .where(f"id IN {MCPSelector._assemble_sql(mcp_list)}")
+            .limit(5)
+            .to_list()
+        )

         # 拿到名称和description
         logger.info("[MCPHelper] 查询MCP Server名称和描述: %s", mcp_vecs)
@@ -72,12 +79,8 @@
         }])
         return llm_mcp_list

-
     async def _get_mcp_by_llm(
-        self,
-        query: str,
-        mcp_list: list[dict[str, str]],
-        mcp_ids: list[str],
+        self, query: str, mcp_list: list[dict[str, str]], mcp_ids: list[str], language
     ) -> MCPSelectResult:
         """通过LLM选择最合适的MCP Server"""
         # 初始化jinja2环境
@@ -87,7 +90,7 @@
             trim_blocks=True,
             lstrip_blocks=True,
         )
-        template = env.from_string(MCP_SELECT)
+        template = env.from_string(MCP_SELECT[language])
         # 渲染模板
         mcp_prompt = template.render(
             mcp_list=mcp_list,
@@ -100,7 +103,6 @@
         # 使用小模型提取JSON
         return await self._call_function_mcp(result, mcp_ids)

-
     async def _call_reasoning(self, prompt: str) -> str:
         """调用大模型进行推理"""
         logger.info("[MCPHelper] 调用推理大模型")
@@ -116,7 +118,6 @@
         self.output_tokens += llm.output_tokens
         return result

-
     async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult:
         """调用结构化输出小模型提取JSON"""
         logger.info("[MCPHelper] 调用结构化输出小模型")
@@ -136,11 +137,8 @@
             raise
         return result

-
     async def select_top_mcp(
-        self,
-        query: str,
-        mcp_list: list[str],
+        self, query: str, mcp_list: list[str], language: LanguageType
= LanguageType.CHINESE ) -> MCPSelectResult: """ 选择最合适的MCP Server @@ -151,11 +149,12 @@ class MCPSelector: llm_mcp_list = await self._get_top_mcp_by_embedding(query, mcp_list) # 通过LLM选择最合适的 - return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) - + return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list, language) @staticmethod - async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: + async def select_top_tool( + query: str, mcp_list: list[str], top_n: int = 10 + ) -> list[MCPTool]: """选择最合适的工具""" tool_vector = await LanceDB().get_table("mcp_tool") query_embedding = await Embedding.get_embedding([query]) diff --git a/apps/scheduler/mcp_agent/__init__.py b/apps/scheduler/mcp_agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12f5cb68c12e4d19d830a8155eaeb0851fce897d --- /dev/null +++ b/apps/scheduler/mcp_agent/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Scheduler MCP 模块""" + +from apps.scheduler.mcp.host import MCPHost +from apps.scheduler.mcp.plan import MCPPlanner +from apps.scheduler.mcp.select import MCPSelector + +__all__ = ["MCPHost", "MCPPlanner", "MCPSelector"] diff --git a/apps/scheduler/mcp_agent/agent/base.py b/apps/scheduler/mcp_agent/agent/base.py deleted file mode 100644 index eccb58a9ce8466f67ac76040a53469edac455109..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/base.py +++ /dev/null @@ -1,196 +0,0 @@ -"""MCP Agent基类""" -import logging -from abc import ABC, abstractmethod -from contextlib import asynccontextmanager - -from pydantic import BaseModel, Field, model_validator - -from apps.common.queue import MessageQueue -from apps.schemas.enum_var import AgentState -from apps.schemas.task import Task -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.schema import Memory, Message, Role -from apps.services.activity import Activity - -logger = logging.getLogger(__name__) - - -class BaseAgent(BaseModel, ABC): - """ - 用于管理代理状态和执行的抽象基类。 - - 为状态转换、内存管理、 - 以及分步执行循环。子类必须实现`step`方法。 - """ - - msg_queue: MessageQueue - task: Task - name: str = Field(..., description="Agent名称") - agent_id: str = Field(default="", description="Agent ID") - description: str = Field(default="", description="Agent描述") - question: str - # Prompts - next_step_prompt: str | None = Field( - None, description="判断下一步动作的提示" - ) - - # Dependencies - llm: ReasoningLLM = Field(default_factory=ReasoningLLM, description="大模型实例") - memory: Memory = Field(default_factory=Memory, description="Agent记忆库") - state: AgentState = Field( - default=AgentState.IDLE, description="Agent状态" - ) - servers_id: list[str] = Field(default_factory=list, description="MCP server id") - - # Execution control - max_steps: int = Field(default=10, description="终止前的最大步长") - current_step: int = Field(default=0, description="执行中的当前步骤") - - duplicate_threshold: int = 2 - - user_prompt: str = r""" - 当前步骤:{step} 工具输出结果:{result} - 请总结当前正在执行的步骤和对应的工具输出结果,内容包括当前步骤是多少,执行的工具是什么,输出是什么。 - 最终以报告的形式展示。 - 如果工具输出结果中执行的工具为terminate,请按照状态输出本次交互过程最终结果并完成对整个报告的总结,不需要输出你的分析过程。 - """ - """用户提示词""" - - class Config: - arbitrary_types_allowed = True - extra = "allow" # Allow extra fields for flexibility in subclasses - - @model_validator(mode="after") - def initialize_agent(self) -> "BaseAgent": - """初始化Agent""" - if self.llm is None or not isinstance(self.llm, ReasoningLLM): - self.llm = ReasoningLLM() - if not isinstance(self.memory, Memory): - 
self.memory = Memory() - return self - - @asynccontextmanager - async def state_context(self, new_state: AgentState): - """ - Agent状态转换上下文管理器 - - Args: - new_state: 要转变的状态 - - :return: None - :raise ValueError: 如果new_state无效 - """ - if not isinstance(new_state, AgentState): - raise ValueError(f"无效状态: {new_state}") - - previous_state = self.state - self.state = new_state - try: - yield - except Exception as e: - self.state = AgentState.ERROR # Transition to ERROR on failure - raise e - finally: - self.state = previous_state # Revert to previous state - - def update_memory( - self, - role: Role, - content: str, - **kwargs, - ) -> None: - """添加信息到Agent的memory中""" - message_map = { - "user": Message.user_message, - "system": Message.system_message, - "assistant": Message.assistant_message, - "tool": lambda content, **kw: Message.tool_message(content, **kw), - } - - if role not in message_map: - raise ValueError(f"不支持的消息角色: {role}") - - # Create message with appropriate parameters based on role - kwargs = {**(kwargs if role == "tool" else {})} - self.memory.add_message(message_map[role](content, **kwargs)) - - async def run(self, request: str | None = None) -> str: - """异步执行Agent的主循环""" - self.task.runtime.question = request - if self.state != AgentState.IDLE: - raise RuntimeError(f"无法从以下状态运行智能体: {self.state}") - - if request: - self.update_memory("user", request) - - results: list[str] = [] - async with self.state_context(AgentState.RUNNING): - while ( - self.current_step < self.max_steps and self.state != AgentState.FINISHED - ): - if not await Activity.is_active(self.task.ids.user_sub): - logger.info("用户终止会话,任务停止!") - return "" - self.current_step += 1 - logger.info(f"执行步骤{self.current_step}/{self.max_steps}") - step_result = await self.step() - - # Check for stuck state - if self.is_stuck(): - self.handle_stuck_state() - result = f"Step {self.current_step}: {step_result}" - results.append(result) - - if self.current_step >= self.max_steps: - self.current_step = 0 - self.state = AgentState.IDLE - result = f"任务终止: 已达到最大步数 ({self.max_steps})" - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": result}, # type: ignore[arg-type] - ) - results.append(result) - return "\n".join(results) if results else "未执行任何步骤" - - @abstractmethod - async def step(self) -> str: - """ - 执行代理工作流程中的单个步骤。 - - 必须由子类实现,以定义具体的行为。 - """ - - def handle_stuck_state(self): - """通过添加更改策略的提示来处理卡住状态""" - stuck_prompt = "\ - 观察到重复响应。考虑新策略,避免重复已经尝试过的无效路径" - self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}" - logger.warning(f"检测到智能体处于卡住状态。新增提示:{stuck_prompt}") - - def is_stuck(self) -> bool: - """通过检测重复内容来检查代理是否卡在循环中""" - if len(self.memory.messages) < 2: - return False - - last_message = self.memory.messages[-1] - if not last_message.content: - return False - - duplicate_count = sum( - 1 - for msg in reversed(self.memory.messages[:-1]) - if msg.role == "assistant" and msg.content == last_message.content - ) - - return duplicate_count >= self.duplicate_threshold - - @property - def messages(self) -> list[Message]: - """从Agent memory中检索消息列表""" - return self.memory.messages - - @messages.setter - def messages(self, value: list[Message]) -> None: - """设置Agent memory的消息列表""" - self.memory.messages = value diff --git a/apps/scheduler/mcp_agent/agent/mcp.py b/apps/scheduler/mcp_agent/agent/mcp.py deleted file mode 100644 index 378da368aca02d0d628352fcc4816b98b2921d01..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/mcp.py +++ /dev/null @@ -1,81 +0,0 @@ 
-"""MCP Agent""" -import logging - -from pydantic import Field - -from apps.scheduler.mcp.host import MCPHost -from apps.scheduler.mcp_agent.agent.toolcall import ToolCallAgent -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class MCPAgent(ToolCallAgent): - """ - 用于与MCP(模型上下文协议)服务器交互。 - - 使用SSE或stdio传输连接到MCP服务器 - 并使服务器的工具 - """ - - name: str = "MCPAgent" - description: str = "一个多功能的智能体,能够使用多种工具(包括基于MCP的工具)解决各种任务" - - # Add general-purpose tools to the tool collection - available_tools: ToolCollection = Field( - default_factory=lambda: ToolCollection( - Terminate(), - ), - ) - - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - _initialized: bool = False - - @classmethod - async def create(cls, **kwargs) -> "MCPAgent": # noqa: ANN003 - """创建并初始化MCP Agent实例""" - instance = cls(**kwargs) - await instance.initialize_mcp_servers() - instance._initialized = True - return instance - - async def initialize_mcp_servers(self) -> None: - """初始化与已配置的MCP服务器的连接""" - mcp_host = MCPHost( - self.task.ids.user_sub, - self.task.id, - self.agent_id, - self.description, - ) - mcps = {} - for mcp_id in self.servers_id: - client = await mcp_host.get_client(mcp_id) - if client: - mcps[mcp_id] = client - - for mcp_id, mcp_client in mcps.items(): - new_tools = [] - for tool in mcp_client.tools: - original_name = tool.name - # Always prefix with server_id to ensure uniqueness - tool_name = f"mcp_{mcp_id}_{original_name}" - - server_tool = MCPClientTool( - name=tool_name, - description=tool.description, - parameters=tool.inputSchema, - session=mcp_client.session, - server_id=mcp_id, - original_name=original_name, - ) - new_tools.append(server_tool) - self.available_tools.add_tools(*new_tools) - - async def think(self) -> bool: - """使用适当的上下文处理当前状态并决定下一步操作""" - if not self._initialized: - await self.initialize_mcp_servers() - self._initialized = True - - return await super().think() diff --git a/apps/scheduler/mcp_agent/agent/react.py b/apps/scheduler/mcp_agent/agent/react.py deleted file mode 100644 index b56efd8b195eb36c4d5718711cc5f07b5a49812f..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/react.py +++ /dev/null @@ -1,35 +0,0 @@ -from abc import ABC, abstractmethod - -from pydantic import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.agent.base import BaseAgent -from apps.scheduler.mcp_agent.schema import Memory - - -class ReActAgent(BaseAgent, ABC): - name: str - description: str | None = None - - system_prompt: str | None = None - next_step_prompt: str | None = None - - llm: ReasoningLLM | None = Field(default_factory=ReasoningLLM) - memory: Memory = Field(default_factory=Memory) - state: AgentState = AgentState.IDLE - - @abstractmethod - async def think(self) -> bool: - """处理当前状态并决定下一步操作""" - - @abstractmethod - async def act(self) -> str: - """执行已决定的行动""" - - async def step(self) -> str: - """执行一个步骤:思考和行动""" - should_act = await self.think() - if not should_act: - return "思考完成-无需采取任何行动" - return await self.act() diff --git a/apps/scheduler/mcp_agent/agent/toolcall.py b/apps/scheduler/mcp_agent/agent/toolcall.py deleted file mode 100644 index 1e22099ce1d2e2f2f54a3bc018511acf887a91a1..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/toolcall.py +++ /dev/null @@ -1,238 +0,0 @@ -import asyncio -import json -import logging -from typing import Any, Optional - -from pydantic 
import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.function import JsonGenerator -from apps.llm.patterns import Select -from apps.scheduler.mcp_agent.agent.react import ReActAgent -from apps.scheduler.mcp_agent.schema import Function, Message, ToolCall -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class ToolCallAgent(ReActAgent): - """用于处理工具/函数调用的基本Agent类""" - - name: str = "toolcall" - description: str = "可以执行工具调用的智能体" - - available_tools: ToolCollection = ToolCollection( - Terminate(), - ) - tool_choices: str = "auto" - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - tool_calls: list[ToolCall] = Field(default_factory=list) - _current_base64_image: str | None = None - - max_observe: int | bool | None = None - - async def think(self) -> bool: - """使用工具处理当前状态并决定下一步行动""" - messages = [] - for message in self.messages: - if isinstance(message, Message): - message = message.to_dict() - messages.append(message) - try: - # 通过工具获得响应 - select_obj = Select() - choices = [] - for available_tool in self.available_tools.to_params(): - choices.append(available_tool.get("function")) - - tool = await select_obj.generate(question=self.question, choices=choices) - if tool in self.available_tools.tool_map: - schema = self.available_tools.tool_map[tool].parameters - json_generator = JsonGenerator( - query="根据跟定的信息,获取工具参数", - conversation=messages, - schema=schema, - ) # JsonGenerator - parameters = await json_generator.generate() - - else: - raise ValueError(f"尝试调用不存在的工具: {tool}") - except Exception as e: - raise - self.tool_calls = tool_calls = [ToolCall(id=tool, function=Function(name=tool, arguments=parameters))] - content = f"选择的执行工具为:{tool}, 参数为{parameters}" - - logger.info( - f"{self.name} 选择 {len(tool_calls) if tool_calls else 0}个工具执行" - ) - if tool_calls: - logger.info( - f"准备使用的工具: {[call.function.name for call in tool_calls]}" - ) - logger.info(f"工具参数: {tool_calls[0].function.arguments}") - - try: - - assistant_msg = ( - Message.from_tool_calls(content=content, tool_calls=self.tool_calls) - if self.tool_calls - else Message.assistant_message(content) - ) - self.memory.add_message(assistant_msg) - - if not self.tool_calls: - return bool(content) - - return bool(self.tool_calls) - except Exception as e: - logger.error(f"{self.name}的思考过程遇到了问题:: {e}") - self.memory.add_message( - Message.assistant_message( - f"处理时遇到错误: {str(e)}" - ) - ) - return False - - async def act(self) -> str: - """执行工具调用并处理其结果""" - if not self.tool_calls: - # 如果没有工具调用,则返回最后的消息内容 - return self.messages[-1].content or "没有要执行的内容或命令" - - results = [] - for command in self.tool_calls: - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"正在执行工具{command.function.name}"} - ) - - self._current_base64_image = None - - result = await self.execute_tool(command) - - if self.max_observe: - result = result[: self.max_observe] - - push_result = "" - async for chunk in self.llm.call( - messages=[{"role": "system", "content": "You are a helpful asistant."}, - {"role": "user", "content": self.user_prompt.format( - step=self.current_step, - result=result, - )}, ], streaming=False - ): - push_result += chunk - self.task.tokens.input_tokens += self.llm.input_tokens - self.task.tokens.output_tokens += self.llm.output_tokens - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": push_result}, # type: ignore[arg-type] - ) - - await 
self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"工具{command.function.name}执行完成"}, # type: ignore[arg-type] - ) - - logger.info( - f"工具'{command.function.name}'执行完成! 执行结果为: {result}" - ) - - # 将工具响应添加到内存 - tool_msg = Message.tool_message( - content=result, - tool_call_id=command.id, - name=command.function.name, - ) - self.memory.add_message(tool_msg) - results.append(result) - self.question += ( - f"\n已执行工具{command.function.name}, " - f"作用是{self.available_tools.tool_map[command.function.name].description},结果为{result}" - ) - - return "\n\n".join(results) - - async def execute_tool(self, command: ToolCall) -> str: - """执行单个工具调用""" - if not command or not command.function or not command.function.name: - return "错误:无效的命令格式" - - name = command.function.name - if name not in self.available_tools.tool_map: - return f"错误:未知工具 '{name}'" - - try: - # 解析参数 - args = command.function.arguments - # 执行工具 - logger.info(f"激活工具:'{name}'...") - result = await self.available_tools.execute(name=name, tool_input=args) - - # 执行特殊工具 - await self._handle_special_tool(name=name, result=result) - - # 格式化结果 - observation = ( - f"观察到执行的工具 `{name}`的输出:\n{str(result)}" - if result - else f"工具 `{name}` 已完成,无输出" - ) - - return observation - except json.JSONDecodeError: - error_msg = f"解析{name}的参数时出错:JSON格式无效" - logger.error( - f"{name}”的参数没有意义-无效的JSON,参数:{command.function.arguments}" - ) - return f"错误: {error_msg}" - except Exception as e: - error_msg = f"工具 '{name}' 遇到问题: {str(e)}" - logger.exception(error_msg) - return f"错误: {error_msg}" - - async def _handle_special_tool(self, name: str, result: Any, **kwargs): - """处理特殊工具的执行和状态变化""" - if not self._is_special_tool(name): - return - - if self._should_finish_execution(name=name, result=result, **kwargs): - # 将智能体状态设为finished - logger.info(f"特殊工具'{name}'已完成任务!") - self.state = AgentState.FINISHED - - @staticmethod - def _should_finish_execution(**kwargs) -> bool: - """确定工具执行是否应完成""" - return True - - def _is_special_tool(self, name: str) -> bool: - """检查工具名称是否在特殊工具列表中""" - return name.lower() in [n.lower() for n in self.special_tool_names] - - async def cleanup(self): - """清理Agent工具使用的资源。""" - logger.info(f"正在清理智能体的资源'{self.name}'...") - for tool_name, tool_instance in self.available_tools.tool_map.items(): - if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction( - tool_instance.cleanup - ): - try: - logger.debug(f"清理工具: {tool_name}") - await tool_instance.cleanup() - except Exception as e: - logger.error( - f"清理工具时发生错误'{tool_name}': {e}", exc_info=True - ) - logger.info(f"智能体清理完成'{self.name}'.") - - async def run(self, request: Optional[str] = None) -> str: - """运行Agent""" - try: - return await super().run(request) - finally: - await self.cleanup() diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py new file mode 100644 index 0000000000000000000000000000000000000000..103ec60daaccbdaa401e746360354b73341d8d08 --- /dev/null +++ b/apps/scheduler/mcp_agent/base.py @@ -0,0 +1,48 @@ +from typing import Any +import json +from jsonschema import validate +import logging +from apps.llm.function import JsonGenerator +from apps.llm.reasoning import ReasoningLLM + +logger = logging.getLogger(__name__) + + +class MCPBase: + """MCP基类""" + + @staticmethod + async def get_resoning_result(prompt: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理结果""" + # 调用推理大模型 + message = [ + {"role": "system", "content": prompt}, + {"role": "user", "content": "Please provide a JSON response based on the 
above information and schema."}, + ] + result = "" + async for chunk in resoning_llm.call( + message, + streaming=False, + temperature=0.07, + result_only=False, + ): + result += chunk + + return result + + @staticmethod + async def _parse_result(result: str, schema: dict[str, Any]) -> str: + """解析推理结果""" + json_result = await JsonGenerator._parse_result_by_stack(result, schema) + if json_result is not None: + return json_result + json_generator = JsonGenerator( + "Please provide a JSON response based on the above information and schema.\n\n", + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": result}, + ], + schema, + ) + json_result = await json_generator.generate() + return json_result diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8048c6eba1bda055732935ebe3ae72c0379f49 --- /dev/null +++ b/apps/scheduler/mcp_agent/host.py @@ -0,0 +1,110 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP宿主""" + +import json +import logging +from typing import Any + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.llm.function import JsonGenerator +from apps.llm.reasoning import ReasoningLLM +from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE +from apps.scheduler.mcp_agent.base import MCPBase +from apps.scheduler.mcp_agent.prompt import GEN_PARAMS, REPAIR_PARAMS +from apps.schemas.mcp import MCPTool +from apps.schemas.task import Task +from apps.schemas.enum_var import LanguageType + +logger = logging.getLogger(__name__) + +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, +) + + +def tojson_filter(value): + return json.dumps(value, ensure_ascii=False, separators=(',', ':')) + + +_env.filters["tojson"] = tojson_filter + +LLM_QUERY_FIX = { + LanguageType.CHINESE: "请生成修复之后的工具参数", + LanguageType.ENGLISH: "Please generate the tool parameters after repair", +} + + +class MCPHost(MCPBase): + """MCP宿主服务""" + + @staticmethod + async def assemble_memory(task: Task) -> str: + """组装记忆""" + + return _env.from_string(MEMORY_TEMPLATE[task.language]).render( + context_list=task.context, + ) + + @staticmethod + async def _get_first_input_params( + mcp_tool: MCPTool, + goal: str, + current_goal: str, + task: Task, + resoning_llm: ReasoningLLM = ReasoningLLM(), + ) -> dict[str, Any]: + """填充工具参数""" + # 更清晰的输入·指令,这样可以调用generate + prompt = _env.from_string(GEN_PARAMS[task.language]).render( + tool_name=mcp_tool.name, + tool_description=mcp_tool.description, + goal=goal, + current_goal=current_goal, + input_schema=mcp_tool.input_schema, + background_info=await MCPHost.assemble_memory(task), + ) + result = await MCPHost.get_resoning_result(prompt, resoning_llm) + # 使用JsonGenerator解析结果 + result = await MCPHost._parse_result( + result, + mcp_tool.input_schema, + ) + return result + + @staticmethod + async def _fill_params( + mcp_tool: MCPTool, + goal: str, + current_goal: str, + current_input: dict[str, Any], + error_message: str = "", + params: dict[str, Any] = {}, + params_description: str = "", + language: LanguageType = LanguageType.CHINESE, + ) -> dict[str, Any]: + llm_query = LLM_QUERY_FIX[language] + prompt = _env.from_string(REPAIR_PARAMS[language]).render( + tool_name=mcp_tool.name, + goal=goal, + current_goal=current_goal, + tool_description=mcp_tool.description, + input_schema=mcp_tool.input_schema, + 
input_params=current_input, + error_message=error_message, + params=params, + params_description=params_description, + ) + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + mcp_tool.input_schema, + ) + return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py new file mode 100644 index 0000000000000000000000000000000000000000..b85043eaac208b4b2b59c788336b034ed93733df --- /dev/null +++ b/apps/scheduler/mcp_agent/plan.py @@ -0,0 +1,471 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP 用户目标拆解与规划""" + +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.llm.reasoning import ReasoningLLM +from apps.scheduler.mcp_agent.base import MCPBase +from apps.scheduler.mcp_agent.prompt import ( + CHANGE_ERROR_MESSAGE_TO_DESCRIPTION, + CREATE_PLAN, + EVALUATE_GOAL, + FINAL_ANSWER, + GEN_STEP, + GENERATE_FLOW_NAME, + GENERATE_FLOW_EXCUTE_RISK, + GET_MISSING_PARAMS, + GET_REPLAN_START_STEP_INDEX, + IS_PARAM_ERROR, + RECREATE_PLAN, + RISK_EVALUATE, + TOOL_EXECUTE_ERROR_TYPE_ANALYSIS, + TOOL_SKIP, +) +from apps.schemas.enum_var import LanguageType +from apps.scheduler.slot.slot import Slot +from apps.schemas.mcp import ( + GoalEvaluationResult, + FlowName, + FlowRisk, + IsParamError, + MCPPlan, + MCPTool, + RestartStepIndex, + Step, + ToolExcutionErrorType, + ToolRisk, + ToolSkip, +) +from apps.schemas.task import Task + +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, +) +logger = logging.getLogger(__name__) + + +class MCPPlanner(MCPBase): + """MCP 用户目标拆解与规划""" + + @staticmethod + async def evaluate_goal( + goal: str, tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE,) -> GoalEvaluationResult: + """评估用户目标的可行性""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_evaluation(goal, tool_list, resoning_llm, language) + + # 返回评估结果 + return await MCPPlanner._parse_evaluation_result(result) + + @staticmethod + async def _get_reasoning_evaluation( + goal, tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE,) -> str: + """获取推理大模型的评估结果""" + template = _env.from_string(EVALUATE_GOAL[language]) + prompt = template.render( + goal=goal, + tools=tool_list, + ) + return await MCPPlanner.get_resoning_result(prompt, resoning_llm) + + @staticmethod + async def _parse_evaluation_result(result: str) -> GoalEvaluationResult: + """将推理结果解析为结构化数据""" + schema = GoalEvaluationResult.model_json_schema() + evaluation = await MCPPlanner._parse_result(result, schema) + # 使用GoalEvaluationResult模型解析结果 + return GoalEvaluationResult.model_validate(evaluation) + + async def get_flow_name( + user_goal: str, + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> FlowName: + """获取当前流程的名称""" + + result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm, language) + result = await MCPPlanner._parse_result(result, FlowName.model_json_schema()) + # 使用FlowName模型解析结果 + return FlowName.model_validate(result) + + @staticmethod + async def _get_reasoning_flow_name( + user_goal: str, + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = 
LanguageType.CHINESE, + ) -> str: + """获取推理大模型的流程名称""" + template = _env.from_string(GENERATE_FLOW_NAME[language]) + prompt = template.render(goal=user_goal) + return await MCPPlanner.get_resoning_result(prompt, resoning_llm) + + @staticmethod + async def get_flow_excute_risk( + user_goal: str, + tools: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> FlowRisk: + """获取当前流程的风险评估结果""" + result = await MCPPlanner._get_reasoning_flow_risk(user_goal, tools, resoning_llm, language) + result = await MCPPlanner._parse_result(result, FlowRisk.model_json_schema()) + # 使用FlowRisk模型解析结果 + return FlowRisk.model_validate(result) + + @staticmethod + async def _get_reasoning_flow_risk( + user_goal: str, + tools: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: + """获取推理大模型的流程风险""" + template = _env.from_string(GENERATE_FLOW_EXCUTE_RISK[language]) + prompt = template.render( + goal=user_goal, + tools=tools, + ) + return await MCPPlanner.get_resoning_result(prompt, resoning_llm) + + @staticmethod + async def get_replan_start_step_index( + user_goal: str, + error_message: str, + current_plan: MCPPlan | None = None, + history: str = "", + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> RestartStepIndex: + """获取重新规划的步骤索引""" + # 获取推理结果 + template = _env.from_string(GET_REPLAN_START_STEP_INDEX[language]) + prompt = template.render( + goal=user_goal, + error_message=error_message, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + history=history, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + schema = RestartStepIndex.model_json_schema() + schema["properties"]["start_index"]["maximum"] = len(current_plan.plans) - 1 + schema["properties"]["start_index"]["minimum"] = 0 + restart_index = await MCPPlanner._parse_result(result, schema) + # 使用RestartStepIndex模型解析结果 + return RestartStepIndex.model_validate(restart_index) + + @staticmethod + async def create_plan( + user_goal: str, + is_replan: bool = False, + error_message: str = "", + current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 6, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> MCPPlan: + """规划下一步的执行流程,并输出""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_plan( + user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm, language + ) + + # 解析为结构化数据 + return await MCPPlanner._parse_plan_result(result, max_steps) + + @staticmethod + async def _get_reasoning_plan( + user_goal: str, + is_replan: bool = False, + error_message: str = "", + current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 10, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: + """获取推理大模型的结果""" + # 格式化Prompt + tool_ids = [tool.id for tool in tool_list] + if is_replan: + template = _env.from_string(RECREATE_PLAN[language]) + prompt = template.render( + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + error_message=error_message, + goal=user_goal, + tools=tool_list, + max_num=max_steps, + ) + else: + template = _env.from_string(CREATE_PLAN[language]) + prompt = template.render( + goal=user_goal, + tools=tool_list, + max_num=max_steps, + ) + return await MCPPlanner.get_resoning_result(prompt, 
reasoning_llm) + + @staticmethod + async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan: + """将推理结果解析为结构化数据""" + # 格式化Prompt + schema = MCPPlan.model_json_schema() + schema["properties"]["plans"]["maxItems"] = max_steps + plan = await MCPPlanner._parse_result(result, schema) + # 使用Function模型解析结果 + return MCPPlan.model_validate(plan) + + @staticmethod + async def create_next_step( + goal: str, + history: str, + tools: list[MCPTool], + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> Step: + """创建下一步的执行步骤""" + # 获取推理结果 + template = _env.from_string(GEN_STEP[language]) + prompt = template.render(goal=goal, history=history, tools=tools) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + + # 解析为结构化数据 + schema = Step.model_json_schema() + if "enum" not in schema["properties"]["tool_id"]: + schema["properties"]["tool_id"]["enum"] = [] + for tool in tools: + schema["properties"]["tool_id"]["enum"].append(tool.id) + step = await MCPPlanner._parse_result(result, schema) + logger.info("[MCPPlanner] 创建下一步的执行步骤: %s", step) + # 使用Step模型解析结果 + + step = Step.model_validate(step) + return step + + @staticmethod + async def tool_skip( + task: Task, + step_id: str, + step_name: str, + step_instruction: str, + step_content: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> ToolSkip: + """判断当前步骤是否需要跳过""" + # 获取推理结果 + template = _env.from_string(TOOL_SKIP[language]) + from apps.scheduler.mcp_agent.host import MCPHost + history = await MCPHost.assemble_memory(task) + prompt = template.render( + step_id=step_id, + step_name=step_name, + step_instruction=step_instruction, + step_content=step_content, + history=history, + goal=task.runtime.question + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + + # 解析为结构化数据 + schema = ToolSkip.model_json_schema() + skip_result = await MCPPlanner._parse_result(result, schema) + # 使用ToolSkip模型解析结果 + return ToolSkip.model_validate(skip_result) + + @staticmethod + async def get_tool_risk( + tool: MCPTool, + input_parm: dict[str, Any], + additional_info: str = "", + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> ToolRisk: + """获取MCP工具的风险评估结果""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_risk( + tool, input_parm, additional_info, resoning_llm, language + ) + + # 返回风险评估结果 + return await MCPPlanner._parse_risk_result(result) + + @staticmethod + async def _get_reasoning_risk( + tool: MCPTool, + input_param: dict[str, Any], + additional_info: str, + resoning_llm: ReasoningLLM, + language: LanguageType = LanguageType.CHINESE, + ) -> str: + """获取推理大模型的风险评估结果""" + template = _env.from_string(RISK_EVALUATE[language]) + prompt = template.render( + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + additional_info=additional_info, + ) + return await MCPPlanner.get_resoning_result(prompt, resoning_llm) + + @staticmethod + async def _parse_risk_result(result: str) -> ToolRisk: + """将推理结果解析为结构化数据""" + schema = ToolRisk.model_json_schema() + risk = await MCPPlanner._parse_result(result, schema) + # 使用ToolRisk模型解析结果 + return ToolRisk.model_validate(risk) + + @staticmethod + async def _get_reasoning_tool_execute_error_type( + user_goal: str, + current_plan: MCPPlan, + tool: MCPTool, + input_param: dict[str, Any], + error_message: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = 
LanguageType.CHINESE, + ) -> str: + """获取推理大模型的工具执行错误类型""" + template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS[language]) + prompt = template.render( + goal=user_goal, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + error_message=error_message, + ) + return await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + + @staticmethod + async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType: + """将推理结果解析为工具执行错误类型""" + schema = ToolExcutionErrorType.model_json_schema() + error_type = await MCPPlanner._parse_result(result, schema) + # 使用ToolExcutionErrorType模型解析结果 + return ToolExcutionErrorType.model_validate(error_type) + + @staticmethod + async def get_tool_execute_error_type( + user_goal: str, + current_plan: MCPPlan, + tool: MCPTool, + input_param: dict[str, Any], + error_message: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> ToolExcutionErrorType: + """获取MCP工具执行错误类型""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_tool_execute_error_type( + user_goal, current_plan, tool, input_param, error_message, reasoning_llm, language + ) + # 返回工具执行错误类型 + return await MCPPlanner._parse_tool_execute_error_type_result(result) + + @staticmethod + async def is_param_error( + goal: str, + history: str, + error_message: str, + tool: MCPTool, + step_description: str, + input_params: dict[str, Any], + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> IsParamError: + """判断错误信息是否是参数错误""" + tmplate = _env.from_string(IS_PARAM_ERROR[language]) + prompt = tmplate.render( + goal=goal, + history=history, + step_id=tool.id, + step_name=tool.name, + step_description=step_description, + input_params=input_params, + error_message=error_message, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + schema = IsParamError.model_json_schema() + is_param_error = await MCPPlanner._parse_result(result, schema) + # 使用IsParamError模型解析结果 + return IsParamError.model_validate(is_param_error) + + @staticmethod + async def change_err_message_to_description( + error_message: str, + tool: MCPTool, + input_params: dict[str, Any], + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: + """将错误信息转换为工具描述""" + template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION[language]) + prompt = template.render( + error_message=error_message, + tool_name=tool.name, + tool_description=tool.description, + input_schema=tool.input_schema, + input_params=input_params, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + return result + + @staticmethod + async def get_missing_param( + tool: MCPTool, + input_param: dict[str, Any], + error_message: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> list[str]: + """获取缺失的参数""" + slot = Slot(schema=tool.input_schema) + template = _env.from_string(GET_MISSING_PARAMS[language]) + schema_with_null = slot.add_null_to_basic_types() + prompt = template.render( + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + schema=schema_with_null, + error_message=error_message, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + input_param_with_null = await MCPPlanner._parse_result(result, 
schema_with_null) + return input_param_with_null + + @staticmethod + async def generate_answer( + user_goal: str, + memory: str, + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> AsyncGenerator[str, None]: + """生成最终回答""" + template = _env.from_string(FINAL_ANSWER[language]) + prompt = template.render( + memory=memory, + goal=user_goal, + ) + async for chunk in resoning_llm.call( + [{"role": "user", "content": prompt}], + streaming=True, + temperature=0.07, + ): + yield chunk diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..d030c6f5aa9d06a6ea877d257f5b334c3f9c392c --- /dev/null +++ b/apps/scheduler/mcp_agent/prompt.py @@ -0,0 +1,2471 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP相关的大模型Prompt""" +from apps.schemas.enum_var import LanguageType +from textwrap import dedent + +MCP_SELECT: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,选择最合适的MCP Server。 + + ## 选择MCP Server时的注意事项: + + 1. 确保充分理解当前目标,选择最合适的MCP Server。 + 2. 请在给定的MCP Server列表中选择,不要自己生成MCP Server。 + 3. 请先给出你选择的理由,再给出你的选择。 + 4. 当前目标将在下面给出,MCP Server列表也会在下面给出。 + 请将你的思考过程放在"思考过程"部分,将你的选择放在"选择结果"部分。 + 5. 选择必须是JSON格式,严格按照下面的模板,不要输出任何其他内容: + + ```json + { + "mcp": "你选择的MCP Server的名称" + } + ``` + + 6. 下面的示例仅供参考,不要将示例中的内容作为选择MCP Server的依据。 + + ## 示例 + + ### 目标 + + 我需要一个MCP Server来完成一个任务。 + + ### MCP Server列表 + + - **mcp_1**: "MCP Server 1";MCP Server 1的描述 + - **mcp_2**: "MCP Server 2";MCP Server 2的描述 + + ### 请一步一步思考: + + 因为当前目标需要一个MCP Server来完成一个任务,所以选择mcp_1。 + + ### 选择结果 + + ```json + { + "mcp": "mcp_1" + } + ``` + + ## 现在开始! + + ### 目标 + + {{goal}} + + ### MCP Server列表 + + {% for mcp in mcp_list %} + - **{{mcp.id}}**: "{{mcp.name}}";{{mcp.description}} + {% endfor %} + + ### 请一步一步思考: + + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are an intelligent assistant who is willing to help. + Your task is: according to the current goal, select the most suitable MCP Server. + + ## Notes when selecting MCP Server: + + 1. Make sure to fully understand the current goal and select the most suitable MCP Server. + 2. Please select from the given MCP Server list, do not generate MCP Server by yourself. + 3. Please first give your reason for selection, then give your selection. + 4. The current goal will be given below, and the MCP Server list will also be given below. + Please put your thinking process in the "Thinking Process" part, and put your selection in the "Selection Result" part. + 5. The selection must be in JSON format, strictly follow the template below, do not output any other content: + + ```json + { + "mcp": "The name of the MCP Server you selected" + } + ``` + 6. The example below is for reference only, do not use the content in the example as the basis for selecting MCP Server. + + ## Example + + ### Goal + + I need an MCP Server to complete a task. + + ### MCP Server List + + - **mcp_1**: "MCP Server 1";Description of MCP Server 1 + - **mcp_2**: "MCP Server 2";Description of MCP Server 2 + + ### Please think step by step: + + Because the current goal needs an MCP Server to complete a task, so select mcp_1. + + ### Selection Result + + ```json + { + "mcp": "mcp_1" + } + ``` + + ## Now start! 
+ ### Goal + + {{goal}} + + ### MCP Server List + + {% for mcp in mcp_list %} + - **{{mcp.id}}**: "{{mcp.name}}";{{mcp.description}} + {% endfor %} + + ### Please think step by step: + """ + ), +} +TOOL_SELECT: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,附加信息,选择最合适的MCP工具。 + ## 选择MCP工具时的注意事项: + 1. 确保充分理解当前目标,选择实现目标所需的MCP工具。 + 2. 请在给定的MCP工具列表中选择,不要自己生成MCP工具。 + 3. 可以选择一些辅助工具,但必须确保这些工具与当前目标相关。 + 4. 注意,返回的工具ID必须是MCP工具的ID,而不是名称。 + 5. 不要选择不存在的工具。 + 必须按照以下格式生成选择结果,不要输出任何其他内容: + ```json + { + "tool_ids": ["工具ID1", "工具ID2", ...] + } + ``` + + # 示例 + ## 目标 + 调优mysql性能 + ## MCP工具列表 + + - mcp_tool_1 MySQL链接池工具;用于优化MySQL链接池 + - mcp_tool_2 MySQL性能调优工具;用于分析MySQL性能瓶颈 + - mcp_tool_3 MySQL查询优化工具;用于优化MySQL查询语句 + - mcp_tool_4 MySQL索引优化工具;用于优化MySQL索引 + - mcp_tool_5 文件存储工具;用于存储文件 + - mcp_tool_6 mongoDB工具;用于操作MongoDB数据库 + + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```json + { + "max_connections": 1000, + "innodb_buffer_pool_size": "1G", + "query_cache_size": "64M" + } + ##输出 + ```json + { + "tool_ids": ["mcp_tool_1", "mcp_tool_2", "mcp_tool_3", "mcp_tool_4"] + } + ``` + # 现在开始! + ## 目标 + {{goal}} + ## MCP工具列表 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + ## 附加信息 + {{additional_info}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are an intelligent assistant who is willing to help. + Your task is: according to the current goal, additional information, select the most suitable MCP tool. + ## Notes when selecting MCP tool: + 1. Make sure to fully understand the current goal and select the MCP tool that can achieve the goal. + 2. Please select from the given MCP tool list, do not generate MCP tool by yourself. + 3. You can select some auxiliary tools, but you must ensure that these tools are related to the current goal. + 4. Note that the returned tool ID must be the ID of the MCP tool, not the name. + 5. Do not select non-existent tools. + Must generate the selection result in the following format, do not output any other content: + ```json + { + "tool_ids": ["tool_id1", "tool_id2", ...] + } + ``` + + # Example + ## Goal + Optimize MySQL performance + ## MCP Tool List + + - mcp_tool_1 MySQL connection pool tool;used to optimize MySQL connection pool + - mcp_tool_2 MySQL performance tuning tool;used to analyze MySQL performance bottlenecks + - mcp_tool_3 MySQL query optimization tool;used to optimize MySQL query statements + - mcp_tool_4 MySQL index optimization tool;used to optimize MySQL index + - mcp_tool_5 File storage tool;used to store files + - mcp_tool_6 MongoDB tool;used to operate MongoDB database + + ## Additional Information + 1. The current MySQL database version is 8.0.26 + 2. The current MySQL database configuration file path is /etc/my.cnf, and contains the following configuration items + ```json + { + "max_connections": 1000, + "innodb_buffer_pool_size": "1G", + "query_cache_size": "64M" + } + ## Output + ```json + { + "tool_ids": ["mcp_tool_1", "mcp_tool_2", "mcp_tool_3", "mcp_tool_4"] + } + ``` + # Now start! 
+ ## Goal + {{goal}} + ## MCP Tool List + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + ## Additional Information + {{additional_info}} + # Output + """ + ), +} +EVALUATE_GOAL: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划评估器。 + 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。 + 如果能够完成,请返回`true`,否则返回`false`。 + 推理过程必须清晰明了,能够让人理解你的判断依据。 + 必须按照以下格式回答: + ```json + { + "can_complete": true/false, + "resoning": "你的推理过程" + } + ``` + + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + + # 工具集合 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - mysql_analyzer 分析MySQL数据库性能 + - performance_tuner 调优数据库性能 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + + # 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf + + ## + ```json + { + "can_complete": true, + "resoning": "当前的工具集合中包含mysql_analyzer和performance_tuner,能够完成对MySQL数据库的性能分析和调优,因此可以完成用户的目标。" + } + ``` + + # 目标 + {{goal}} + + # 工具集合 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + + # 附加信息 + {{additional_info}} + + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan evaluator. + Please judge whether the current tool set can complete the user's goal based on the user's goal and the current tool set and some additional information. + If it can be completed, return `true`, otherwise return `false`. + The reasoning process must be clear and understandable, so that people can understand your judgment basis. + Must answer in the following format: + ```json + { + "can_complete": true/false, + "resoning": "Your reasoning process" + } + ``` + + # Example + # Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. + + # Tool Set + You can access and use some tools, which will be given in the XML tag. + + - mysql_analyzer Analyze MySQL database performance + - performance_tuner Tune database performance + - Final End step, when executing this step, it means that the plan execution is over, and the result obtained will be the final result. + + + # Additional Information + 1. The current MySQL database version is 8.0.26 + 2. The current MySQL database configuration file path is /etc/my.cnf + + ## + ```json + { + "can_complete": true, + "resoning": "The current tool set contains mysql_analyzer and performance_tuner, which can complete the performance analysis and optimization of MySQL database, so the user's goal can be completed." + } + ``` + + # Goal + {{goal}} + + # Tool Set + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + + # Additional Information + {{additional_info}} + + """ + ), +} +GENERATE_FLOW_NAME: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。 + + # 生成流程名称时的注意事项: + 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。 + 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。 + 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。 + 4. 流程名称应该尽量简短,小于20个字或者单词。 + 5. 只输出流程名称,不要输出其他内容。 + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 输出 + { + "flow_name": "扫描MySQL数据库并分析性能瓶颈,进行调优" + } + # 现在开始生成流程名称: + # 目标 + {{goal}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are an intelligent assistant, your task is to generate a suitable flow name based on the user's goal. + + # Notes when generating flow names: + 1. The flow name should be concise and clear, accurately expressing the process of achieving the user's goal. + 2. The flow name should include key operations or steps, such as "scan", "analyze", "tune", etc. + 3. 
The flow name should avoid using overly complex or professional terms, so that users can understand. + 4. The flow name should be as short as possible, less than 20 characters or words. + 5. Only output the flow name, do not output other content. + # Example + # Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. + # Output + { + "flow_name": "Scan MySQL database and analyze performance bottlenecks, and optimize it." + } + # Now start generating the flow name: + # Goal + {{goal}} + # Output + """ + ), +} +GENERATE_FLOW_EXCUTE_RISK: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个智能助手,你的任务是根据用户的目标和当前的工具集合,评估当前流程的风险。 + + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 工具集合 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - mysql_analyzer 分析MySQL数据库性能 + - performance_tuner 调优数据库性能 + + # 输出 + { + "risk": "high", + "reason": "当前目标实现带来的风险较高,因为需要通过performance_tuner工具对MySQL数据库进行调优,而该工具可能会对数据库的性能和稳定性产生较大的影响,因此风险评估为高。" + } + # 现在开始评估当前流程的风险: + # 目标 + {{goal}} + # 工具集合 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are an intelligent assistant, your task is to evaluate the risk of the current process based on the user's goal and the current tool set. + # Example + # Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. + # Tool Set + You can access and use some tools, which will be given in the XML tag. + + - mysql_analyzer Analyze MySQL database performance + - performance_tuner Tune database performance + + # Output + { + "risk": "high", + "reason": "The risk brought by the realization of the current goal is relatively high, because it is necessary to tune the MySQL database through the performance_tuner tool, which may have a greater impact on the performance and stability of the database. Therefore, the risk assessment is high." 
+ } + # Now start evaluating the risk of the current process: + # Goal + {{goal}} + # Tool Set + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + # Output + """ + ) +} +GET_REPLAN_START_STEP_INDEX: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个智能助手,你的任务是根据用户的目标、报错信息和当前计划和历史,获取重新规划的步骤起始索引。 + + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 报错信息 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 当前计划 + ```json + { + "plans": [ + { + "step_id": "step_1", + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描" + }, + { + "step_id": "step_2", + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + } + ] + } + # 历史 + [ + { + id: "0", + task_id: "task_1", + flow_id: "flow_1", + flow_name: "MYSQL性能调优", + flow_status: "RUNNING", + step_id: "step_1", + step_name: "生成端口扫描命令", + step_description: "生成端口扫描命令:扫描当前MySQL数据库的端口", + step_status: "FAILED", + input_data: { + "command": "nmap -p 3306 + "target": "localhost" + }, + output_data: { + "error": "- bash: curl: command not found" + } + } + ] + # 输出 + { + "start_index": 0, + "reasoning": "当前计划的第一步就失败了,报错信息显示curl命令未找到,可能是因为没有安装curl工具,因此需要从第一步重新规划。" + } + # 现在开始获取重新规划的步骤起始索引: + # 目标 + {{goal}} + # 报错信息 + {{error_message}} + # 当前计划 + {{current_plan}} + # 历史 + {{history}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are an intelligent assistant, your task is to get the starting index of the step to be replanned based on the user's goal, error message, and current plan and history. + + # Example + # Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. + # Error message + An error occurred while executing the port scan command: `- bash: curl: command not found`. + # Current plan + ```json + { + "plans": [ + { + "step_id": "step_1", + "content": "Generate port scan command", + "tool": "command_generator", + "instruction": "Generate port scan command: scan" + }, + { + "step_id": "step_2", + "content": "Execute the command generated by Result[0]", + "tool": "command_executor", + "instruction": "Execute port scan command" + } + ] + } + # History + [ + { + id: "0", + task_id: "task_1", + flow_id: "flow_1", + flow_name: "MYSQL Performance Tuning", + flow_status: "RUNNING", + step_id: "step_1", + step_name: "Generate port scan command", + step_description: "Generate port scan command: scan the port of the current MySQL database", + step_status: "FAILED", + input_data: { + "command": "nmap -p 3306 + "target": "localhost" + }, + output_data: { + "error": "- bash: curl: command not found" + } + } + ] + # Output + { + "start_index": 0, + "reasoning": "The first step of the current plan failed, the error message shows that the curl command was not found, which may be because the curl tool was not installed. Therefore, it is necessary to replan from the first step." + } + # Now start getting the starting index of the step to be replanned: + # Goal + {{goal}} + # Error message + {{error_message}} + # Current plan + {{current_plan}} + # History + {{history}} + # Output + """ + ), +} +CREATE_PLAN: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划生成器。 + 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 
计划中的最后一步必须是Final工具,以确保计划执行结束。 + + # 生成计划时的注意事项: + + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: + + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + + # 样例 + + # 目标 + + 在后台运行一个新的alpine: latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + + # 计划 + + + 1. 这个目标需要使用Docker来完成, 首先需要选择合适的MCP Server + 2. 目标可以拆解为以下几个部分: + - 运行alpine: latest容器 + - 挂载主机目录 + - 在后台运行 + - 执行top命令 + 3. 需要先选择MCP Server, 然后生成Docker命令, 最后执行命令 + + ```json + { + "plans": [ + { + "content": "选择一个支持Docker的MCP Server", + "tool": "mcp_selector", + "instruction": "需要一个支持Docker容器运行的MCP Server" + }, + { + "content": "使用Result[0]中选择的MCP Server,生成Docker命令", + "tool": "command_generator", + "instruction": "生成Docker命令:在后台运行alpine:latest容器,挂载/root到/data,执行top命令" + }, + { + "content": "在Result[0]的MCP Server上执行Result[1]生成的命令", + "tool": "command_executor", + "instruction": "执行Docker命令" + }, + { + "content": "任务执行完成,容器已在后台运行,结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始生成计划: + + # 目标 + + {{goal}} + + # 计划 +""" + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan builder. + Analyze the user's goals and generate a plan. You will then follow this plan, step by step, to achieve the user's goals. + + # A good plan should: + + 1. Be able to successfully achieve the user's goals. + 2. Each step in the plan must use only one tool. + 3. The steps in the plan must have clear and logical progression, without redundant or unnecessary steps. + 4. The last step in the plan must be the Final tool to ensure the plan execution is complete. + + # Things to note when generating a plan: + + - Each plan contains 3 parts: + - Plan content: describes the general content of a single plan step + - Tool ID: must be selected from the tool list below + - Tool instructions: rewrite the user's goal to make it more consistent with the tool's input requirements + - The plan must be generated in the following format, and no additional data should be output: + + ```json + { + "plans": [ + { + "content": "Plan content", + "tool": "Tool ID", + "instruction": "Tool instruction" + } + ] + } + ``` + + - Before generating a plan, please think step by step, analyze the user's goals, and guide your subsequent generation. + The thinking process should be placed in the XML tags. + - In the plan content, you can use "Result[]" to reference the results of the previous plan step. For example: "Result[3]" refers to the result after the third plan is executed. + - There should be no more than {{max_num}} plans, and each plan content should be less than 150 words. + + # Tools + + You can access and use a number of tools, listed within the XML tags. + + + {% for tool in tools %} + - {{tool.id}} {{tool.name}}; {{tool.description}} + {% endfor %} + + + # Example + + # Goal + + Run a new alpine:latest container in the background, mount the host's /root folder to /data, and execute the top command. + + # Plan + + + 1. This goal needs to be completed using Docker. First, you need to select a suitable MCP Server. + 2. 
The goal can be broken down into the following parts: + - Run the alpine:latest container + - Mount the host directory + - Run in the background + - Execute the top command + 3. You need to select the MCP Server first, then generate the Docker command, and finally execute the command. + + ```json + { + "plans": [ + { + "content": "Select an MCP Server that supports Docker", + "tool": "mcp_selector", + "instruction": "You need an MCP Server that supports running Docker containers" + }, + { + "content": "Use the MCP Server selected in Result[0] to generate Docker commands", + "tool": "command_generator", + "instruction": "Generate Docker commands: run the alpine:latest container in the background, mount /root to /data, and execute the top command" + }, + { + "content": "In the MCP of Result[0] Execute the command generated by Result[1] on the server", + "tool": "command_executor", + "instruction": "Execute Docker command" + }, + { + "content": "Task execution completed, the container is running in the background, the result is Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # Now start generating the plan: + + # Goal + + {{goal}} + + # Plan + """ + ), +} +RECREATE_PLAN: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划重建器。 + 请根据用户的目标、当前计划和运行报错,重新生成一个计划。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 你的计划必须避免之前的错误,并且能够成功执行。 + 5. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + + # 生成计划时的注意事项: + + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: + + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 + + # 样例 + + # 目标 + + 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。 + # 工具 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - command_generator 生成命令行指令 + - tool_selector 选择合适的工具 + - command_executor 执行命令行指令 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + # 当前计划 + ```json + { + "plans": [ + { + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行第一步生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + # 运行报错 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 重新生成的计划 + + + 1. 这个目标需要使用网络扫描工具来完成, 首先需要选择合适的网络扫描工具 + 2. 
目标可以拆解为以下几个部分: + - 生成端口扫描命令 + - 执行端口扫描命令 + 3.但是在执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + 4.我将计划调整为: + - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具 + - 执行这个命令,查看当前机器支持哪些网络扫描工具 + - 然后从中选择一个网络扫描工具 + - 基于选择的网络扫描工具,生成端口扫描命令 + - 执行端口扫描命令 + ```json + { + "plans": [ + { + "content": "需要生成一条命令查看当前机器支持哪些网络扫描工具", + "tool": "command_generator", + "instruction": "选择一个前机器支持哪些网络扫描工具" + }, + { + "content": "执行第一步中生成的命令,查看当前机器支持哪些网络扫描工具", + "tool": "command_executor", + "instruction": "执行第一步中生成的命令" + }, + { + "content": "从第二步执行结果中选择一个网络扫描工具,生成端口扫描命令", + "tool": "tool_selector", + "instruction": "选择一个网络扫描工具,生成端口扫描命令" + }, + { + "content": "基于第三步中选择的网络扫描工具,生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "执行第四步中生成的端口扫描命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始重新生成计划: + + # 目标 + + {{goal}} + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + + # 当前计划 + {{current_plan}} + + # 运行报错 + {{error_message}} + + # 重新生成的计划 +""" + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan rebuilder. + Please regenerate a plan based on the user's goals, current plan, and runtime errors. + + # A good plan should: + + 1. Successfully achieve the user's goals. + 2. Each step in the plan must use only one tool. + 3. The steps in the plan must have clear and logical progression, without redundant or unnecessary steps. + 4. Your plan must avoid previous errors and be able to be successfully executed. + 5. The last step in the plan must be the Final tool to ensure that the plan is complete. + + # Things to note when generating a plan: + + - Each plan contains 3 parts: + - Plan content: describes the general content of a single plan step + - Tool ID: must be selected from the tool list below + - Tool instructions: rewrite the user's goal to make it more consistent with the tool's input requirements + - The plan must be generated in the following format, and no additional data should be output: + + ```json + { + "plans": [ + { + "content": "Plan content", + "tool": "Tool ID", + "instruction": "Tool instruction" + } + ] + } + ``` + + - Before generating a plan, please think step by step, analyze the user's goals, and guide your subsequent generation. + The thinking process should be placed in the XML tags. + - In the plan content, you can use "Result[]" to reference the results of the previous plan step. For example: "Result[3]" refers to the result after the third plan is executed. + - There should be no more than {{max_num}} plans, and each plan content should be less than 150 words. + + # Objective + + Please scan the ports of the machine at 192.168.1.1 to see which ports are open. + # Tools + You can access and use a number of tools, which are listed within the XML tags. + + - command_generator Generates command line instructions + - tool_selector Selects the appropriate tool + - command_executor Executes command line instructions + - Final This is the final step. When this step is reached, the plan execution ends, and the result is used as the final result. 
+ # Current plan + ```json + { + "plans": [ + { + "content": "Generate port scan command", + "tool": "command_generator", + "instruction": "Generate port scan command: Scan open ports on 192.168.1.1" + }, + { + "content": "Execute the command generated in the first step", + "tool": "command_executor", + "instruction": "Execute the port scan command" + }, + { + "content": "Task execution completed", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + # Run error + When executing the port scan command, an error occurred: `- bash: curl: command not found`. + # Regenerate the plan + + + 1. This goal requires a network scanning tool. First, select the appropriate network scanning tool. + 2. The goal can be broken down into the following parts: + - Generate the port scanning command + - Execute the port scanning command + 3. However, when executing the port scanning command, an error occurred: `- bash: curl: command not found`. + 4. I adjusted the plan to: + - Generate a command to check which network scanning tools the current machine supports + - Execute this command to check which network scanning tools the current machine supports + - Then select a network scanning tool + - Generate a port scanning command based on the selected network scanning tool + - Execute the port scanning command + ```json + { + "plans": [ + { + "content": "You need to generate a command to check which network scanning tools the current machine supports", + "tool": "command_generator", + "instruction": "Select which network scanning tools the current machine supports" + + }, + { + "content": "Execute the command generated in the first step to check which network scanning tools the current machine supports", + "tool": "command_executor", + "instruction": "Execute the command generated in the first step" + + }, + { + "content": "Select a network scanning tool from the results of the second step and generate a port scanning command", + "tool": "tool_selector", + "instruction": "Select a network scanning tool and generate a port scanning command" + + }, + { + "content": "Generate a port scan command based on the network scanning tool selected in step 3", + "tool": "command_generator", + "instruction": "Generate a port scan command: Scan the open ports on 192.168.1.1" + }, + { + "content": "Execute the port scan command generated in step 4", + "tool": "command_executor", + "instruction": "Execute the port scan command" + }, + { + "content": "Task execution completed", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # Now start regenerating the plan: + + # Goal + + {{goal}} + + # Tools + + You can access and use a number of tools, which are listed within the XML tags. 
+ + + {% for tool in tools %} + - {{tool.id}} {{tool.name}}; {{tool.description}} + {% endfor %} + + + # Current plan + {{current_plan}} + + # Run error + {{error_message}} + + # Regenerated plan + """ + ), +} +GEN_STEP: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划生成器。 + 请根据用户的目标、当前计划和历史,生成一个新的步骤。 + + # 一个好的计划步骤应该: + 1.使用最适合的工具来完成当前步骤。 + 2.能够基于当前的计划和历史,完成阶段性的任务。 + 3.不要选择不存在的工具。 + 4.如果你认为当前已经达成了用户的目标,可以直接返回Final工具,表示计划执行结束。 + 5.tool_id中的工具ID必须是当前工具集合中存在的工具ID,而不是工具的名称。 + 6.工具在 XML标签中给出,工具的id在 下的 XML标签中给出。 + + # 样例 1 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优,我的ip是192.168.1.1,数据库端口是3306,用户名是root,密码是password + # 历史记录 + 第1步:生成端口扫描命令 + - 调用工具 `command_generator`,并提供参数 `帮我生成一个mysql端口扫描命令` + - 执行状态:成功 + - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` + 第2步:执行端口扫描命令 + - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"result": "success"}` + # 工具 + + - DuDlgP mysql分析工具,用于分析数据库性能/description> + - ADsxSX 文件存储工具,用于存储文件 + - ySASDZ mongoDB工具,用于操作MongoDB数据库 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + # 输出 + ```json + { + "tool_id": "DuDlgP", + "description": "扫描ip为192.168.1.1的MySQL数据库,端口为3306,用户名为root,密码为password的数据库性能", + } + ``` + # 样例二 + # 目标 + 计划从杭州到北京的旅游计划 + # 历史记录 + 第1步:将杭州转换为经纬度坐标 + - 调用工具 `经纬度工具`,并提供参数 `{"city_from": "杭州", "address": "西湖"}` + - 执行状态:成功 + - 得到数据:`{"location": "123.456, 78.901"}` + 第2步:查询杭州的天气 + - 调用工具 `天气查询工具`,并提供参数 `{"location": "123.456, 78.901"}` + - 执行状态:成功 + - 得到数据:`{"weather": "晴", "temperature": "25°C"}` + 第3步:将北京转换为经纬度坐标 + - 调用工具 `经纬度工具`,并提供参数 `{"city_from": "北京", "address": "天安门"}` + - 执行状态:成功 + - 得到数据:`{"location": "123.456, 78.901"}` + 第4步:查询北京的天气 + - 调用工具 `天气查询工具`,并提供参数 `{"location": "123.456, 78.901"}` + - 执行状态:成功 + - 得到数据:`{"weather": "晴", "temperature": "25°C"}` + # 工具 + + - cSAads 经纬度工具,将详细的结构化地址转换为经纬度坐标。支持对地标性名胜景区、建筑物名称解析为经纬度坐标 + - sScseS 天气查询工具,用于查询天气信息 + - pcSEsx 路径规划工具,根据用户起终点经纬度坐标规划综合各类公共(火车、公交、地铁)交通方式的通勤方案,并且返回通勤方案的数据,跨城场景下必须传起点城市与终点城市 + - Final Final;结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + # 输出 + ```json + { + "tool_id": "pcSEsx", + "description": "规划从杭州到北京的综合公共交通方式的通勤方案" + } + ``` + # 现在开始生成步骤: + # 目标 + {{goal}} + # 历史记录 + {{history}} + # 工具 + + {% for tool in tools %} + - {{tool.id}} {{tool.description}} + {% endfor %} + +""" + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan generator. + Please generate a new step based on the user's goal, current plan, and history. + + # A good plan step should: + 1. Use the most appropriate tool for the current step. + 2. Complete the tasks at each stage based on the current plan and history. + 3. Do not select a tool that does not exist. + 4. If you believe the user's goal has been achieved, return to the Final tool to complete the plan execution. + + # Example 1 + # Objective + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. My IP address is 192.168.1.1, the database port is 3306, my username is root, and my password is password. + # History + Step 1: Generate a port scan command + - Call the `command_generator` tool and provide the `help me generate a MySQL port scan command` parameter. + - Execution status: Success. + - Result: `{"command": "nmap -sS -p --open 192.168.1.1"}` + Step 2: Execute the port scan command + - Call the `command_executor` tool and provide the `{"command": "nmap -sS -p --open 192.168.1.1"}` parameter. + - Execution status: Success. + - Result: `{"result": "success"}` + # Tools + + - mcp_tool_1 mysql_analyzer; used for analyzing database performance. 
+ - mcp_tool_2 File storage tool; used for storing files. + - mcp_tool_3 MongoDB tool; used for operating MongoDB databases. + - Final This step completes the plan execution and the result is used as the final result. + + # Output + ```json + { + "tool_id": "mcp_tool_1", + "description": "Scan the database performance of the MySQL database with IP address 192.168.1.1, port 3306, username root, and password password", + } + ``` + # Example 2 + # Objective + Plan a trip from Hangzhou to Beijing + # History + Step 1: Convert Hangzhou to latitude and longitude coordinates + - Call the `maps_geo_planner` tool and provide `{"city_from": "Hangzhou", "address": "West Lake"}` + - Execution status: Success + - Result: `{"location": "123.456, 78.901"}` + Step 2: Query the weather in Hangzhou + - Call the `weather_query` tool and provide `{"location": "123.456, 78.901"}` + - Execution Status: Success + - Result: `{"weather": "Sunny", "temperature": "25°C"}` + Step 3: Convert Beijing to latitude and longitude coordinates + - Call the `maps_geo_planner` tool and provide `{"city_from": "Beijing", "address": "Tiananmen"}` + - Execution Status: Success + - Result: `{"location": "123.456, 78.901"}` + Step 4: Query the weather in Beijing + - Call the `weather_query` tool and provide `{"location": "123.456, 78.901"}` + - Execution Status: Success + - Result: `{"weather": "Sunny", "temperature": "25°C"}` + # Tools + + - mcp_tool_4 maps_geo_planner; Converts a detailed structured address into longitude and latitude coordinates. Supports parsing landmarks, scenic spots, and building names into longitude and latitude coordinates. + - mcp_tool_5 weather_query; Weather query, used to query weather information. + - mcp_tool_6 maps_direction_transit_integrated; Plans a commuting plan based on the user's starting and ending longitude and latitude coordinates, integrating various public transportation modes (train, bus, subway), and returns the commuting plan data. For cross-city scenarios, both the starting and ending cities must be provided. + - Final Final; Final step. When this step is reached, plan execution is complete, and the resulting result is used as the final result. 
+ + # Output + ```json + { + "tool_id": "mcp_tool_6", + "description": "Plan a comprehensive public transportation commute from Hangzhou to Beijing" + } + ``` + # Now start generating steps: + # Goal + {{goal}} + # History + {{history}} + # Tools + + {% for tool in tools %} + - {{tool.id}} {{tool.description}} + {% endfor %} + + """ + ), +} + +TOOL_SKIP: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划执行器。 + 你的任务是根据当前的计划和用户目标,判断当前步骤是否需要跳过。 + 如果需要跳过,请返回`true`,否则返回`false`。 + 必须按照以下格式回答: + ```json + { + "skip": true/false, + } + ``` + 注意: + 1.你的判断要谨慎,在历史消息中有足够的上下文信息时,才可以判断是否跳过当前步骤。 + # 样例 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 历史 + 第1步:生成端口扫描命令 + - 调用工具 `command_generator`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` + 第2步:执行端口扫描命令 + - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"result": "success"}` + 第3步:分析端口扫描结果 + - 调用工具 `mysql_analyzer`,并提供参数 `{"host": "192.168.1.1", "port": 3306, "username": "root", "password": "password"}` + - 执行状态:成功 + - 得到数据:`{"performance": "good", "bottleneck": "none"}` + # 当前步骤 + + step_4 + command_generator + 生成MySQL性能调优命令 + 生成MySQL性能调优命令:调优MySQL数据库性能 + + # 输出 + ```json + { + "skip": true + } + ``` + # 用户目标 + {{goal}} + # 历史 + {{history}} + # 当前步骤 + + {{step_id}} + {{step_name}} + {{step_instruction}} + {{step_content}} + + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan executor. + Your task is to determine whether the current step should be skipped based on the current plan and the user's goal. + If skipping is required, return `true`; otherwise, return `false`. + The answer must follow the following format: + ```json + { + "skip": true/false, + } + ``` + Note: + 1. Be cautious in your judgment and only decide whether to skip the current step when there is sufficient context in the historical messages. + # Example + # User Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. + # History + Step 1: Generate a port scan command + - Call the `command_generator` tool with `{"command": "nmap -sS -p--open 192.168.1.1"}` + - Execution Status: Success + - Result: `{"command": "nmap -sS -p--open 192.168.1.1"}` + Step 2: Execute the port scan command + - Call the `command_executor` tool with `{"command": "nmap -sS -p--open 192.168.1.1"}` + - Execution Status: Success + - Result: `{"result": "success"}` + Step 3: Analyze the port scan results + - Call the `mysql_analyzer` tool with `{"host": "192.168.1.1", "port": 3306, "username": "root", "password": "password"}` + - Execution status: Success + - Result: `{"performance": "good", "bottleneck": "none"}` + # Current step + + step_4 + command_generator + Generate MySQL performance tuning commands + Generate MySQL performance tuning commands: Tune MySQL database performance + + # Output + ```json + { + "skip": true + } + ``` + # User goal + {{goal}} + # History + {{history}} + # Current step + + {{step_id}} + {{step_name}} + {{step_instruction}} + {{step_content}} + + # output + """ + ), +} +RISK_EVALUATE: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个工具执行计划评估器。 + 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。 + ```json + { + "risk": "low/medium/high", + "reason": "提示信息" + } + ``` + # 样例 + # 工具名称 + mysql_analyzer + # 工具描述 + 分析MySQL数据库性能 + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 附加信息 + 1. 
当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```ini + [mysqld] + innodb_buffer_pool_size=1G + innodb_log_file_size=256M + ``` + # 输出 + ```json + { + "risk": "medium", + "reason": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。" + } + ``` + # 工具 + + {{tool_name}} + {{tool_description}} + + # 工具入参 + {{input_param}} + # 附加信息 + {{additional_info}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a tool execution plan evaluator. + Your task is to determine the risk of executing the current tool based on its name, description, input parameters, and additional information, and output a warning. + ```json + { + "risk": "low/medium/high", + "reason": "prompt message" + } + ``` + # Example + # Tool name + mysql_analyzer + # Tool description + Analyzes MySQL database performance + # Tool input + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # Additional information + 1. The current MySQL database version is 8.0.26 + 2. The current MySQL database configuration file path is /etc/my.cnf and contains the following configuration items + ```ini + [mysqld] + innodb_buffer_pool_size=1G + innodb_log_file_size=256M + ``` + # Output + ```json + { + "risk": "medium", + "reason": "This tool will connect to a MySQL database and analyze performance, which may impact database performance. This operation should only be performed in a non-production environment." + } + ``` + # Tool + + {{tool_name}} + {{tool_description}} + + # Tool Input Parameters + {{input_param}} + # Additional Information + {{additional_info}} + # Output + + """ + ), +} +# 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划 +TOOL_EXECUTE_ERROR_TYPE_ANALYSIS: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划决策器。 + + 你的任务是根据用户目标、当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 + 请根据以下规则进行判断: + 1. 仅通过补充工具入参来解决问题的,返回 missing_param; + 2. 需要重计划当前步骤的,返回 decorrect_plan + 3.推理过程必须清晰明了,能够让人理解你的判断依据,并且不超过100字。 + 你的输出要以json格式返回,格式如下: + + ```json + { + "error_type": "missing_param/decorrect_plan, + "reason": "你的推理过程" + } + ``` + + # 样例 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 当前计划 + { + "plans": [ + { + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + # 当前使用的工具 + + command_executor + 执行命令行指令 + + # 工具入参 + { + "command": "nmap -sS -p--open 192.168.1.1" + } + # 工具运行报错 + 执行端口扫描命令时,出现了错误:`- bash: nmap: command not found`。 + # 输出 + ```json + { + "error_type": "decorrect_plan", + "reason": "当前计划的第二步执行失败,报错信息显示nmap命令未找到,可能是因为没有安装nmap工具,因此需要重计划当前步骤。" + } + ``` + # 用户目标 + {{goal}} + # 当前计划 + {{current_plan}} + # 当前使用的工具 + + {{tool_name}} + {{tool_description}} + + # 工具入参 + {{input_param}} + # 工具运行报错 + {{error_message}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan decider. + + Your task is to decide the next action based on the user's goal, the current plan, the tool being used, tool inputs, and tool errors. + Please make your decision based on the following rules: + 1. If the problem can be solved by simply adding tool inputs, return missing_param; + 2. If the current step needs to be replanned, return decorrect_plan. + 3. Your reasoning must be clear and concise, allowing the user to understand your decision. It should not exceed 100 words. 
+ Your output should be returned in JSON format, as follows: + + ```json + { + "error_type": "missing_param/decorrect_plan, + "reason": "Your reasoning" + } + ``` + + # Example + # User Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. + # Current Plan + { + "plans": [ + { + "content": "Generate port scan command", + "tool": "command_generator", + "instruction": "Generate port scan command: Scan the open ports of 192.168.1.1" + }, + { + "content": "Execute the command generated by Result[0]", + "tool": "command_executor", + "instruction": "Execute the port scan command" + }, + { + "content": "Task execution completed, the port scan result is Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + # Currently used tool + + command_executor + Execute command line instructions + + # Tool input parameters + { + "command": "nmap -sS -p--open 192.168.1.1" + } + # Tool running error + When executing the port scan command, an error occurred: `- bash: nmap: command not found`. + # Output + ```json + { + "error_type": "decorrect_plan", + "reason": "The second step of the current plan failed. The error message shows that the nmap command was not found. This may be because the nmap tool is not installed. Therefore, the current step needs to be replanned." + } + ``` + # User goal + {{goal}} + # Current plan + {{current_plan}} + # Currently used tool + + {{tool_name}} + {{tool_description}} + + # Tool input parameters + {{input_param}} + # Tool execution error + {{error_message}} + # Output + """ + ), +} + +IS_PARAM_ERROR: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划执行专家,你的任务是判断当前的步骤执行失败是否是因为参数错误导致的, + 如果是,请返回`true`,否则返回`false`。 + 必须按照以下格式回答: + ```json + { + "is_param_error": true/false, + } + ``` + # 样例 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 历史 + 第1步:生成端口扫描命令 + - 调用工具 `command_generator`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` + 第2步:执行端口扫描命令 + - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"result": "success"}` + # 当前步骤 + + step_3 + mysql_analyzer + 分析MySQL数据库性能 + + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具运行报错 + 执行MySQL性能分析命令时,出现了错误:`host is not correct`。 + + # 输出 + ```json + { + "is_param_error": true + } + ``` + # 用户目标 + {{goal}} + # 历史 + {{history}} + # 当前步骤 + + {{step_id}} + {{step_name}} + {{step_instruction}} + + # 工具入参 + {{input_param}} + # 工具运行报错 + {{error_message}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan execution expert. Your task is to determine whether the current step execution failure is due to parameter errors. + If so, return `true`; otherwise, return `false`. + The answer must be in the following format: + ```json + { + "is_param_error": true/false, + } + ``` + # Example + # User Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. 
+ # History + Step 1: Generate a port scan command + - Call the `command_generator` tool and provide `{"command": "nmap -sS -p--open 192.168.1.1"}` + - Execution Status: Success + - Result: `{"command": "nmap -sS -p--open 192.168.1.1"}` + Step 2: Execute the port scan command + - Call the `command_executor` tool and provide `{"command": "nmap -sS -p--open 192.168.1.1"}` + - Execution Status: Success + - Result: `{"result": "success"}` + # Current step + + step_3 + mysql_analyzer + Analyze MySQL database performance + + # Tool input parameters + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # Tool execution error + When executing the MySQL performance analysis command, an error occurred: `host is not correct`. + + # Output + ```json + { + "is_param_error": true + } + ``` + # User goal + {{goal}} + # History + {{history}} + # Current step + + {{step_id}} + {{step_name}} + {{step_instruction}} + + # Tool input parameters + {{input_param}} + # Tool error + {{error_message}} + # Output + """ + ), +} + +# 将当前程序运行的报错转换为自然语言 +CHANGE_ERROR_MESSAGE_TO_DESCRIPTION: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个智能助手,你的任务是将当前程序运行的报错转换为自然语言描述。 + 请根据以下规则进行转换: + 1. 将报错信息转换为自然语言描述,描述应该简洁明了,能够让人理解报错的原因和影响。 + 2. 描述应该包含报错的具体内容和可能的解决方案。 + 3. 描述应该避免使用过于专业的术语,以便用户能够理解。 + 4. 描述应该尽量简短,控制在50字以内。 + 5. 只输出自然语言描述,不要输出其他内容。 + # 样例 + # 工具信息 + + port_scanner + 扫描主机端口 + + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "主机地址" + }, + "port": { + "type": "integer", + "description": "端口号" + }, + "username": { + "type": "string", + "description": "用户名" + }, + "password": { + "type": "string", + "description": "密码" + } + }, + "required": ["host", "port", "username", "password"] + } + + + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 报错信息 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 输出 + 扫描端口时发生错误:密码不正确。请检查输入的密码是否正确,并重试。 + # 现在开始转换报错信息: + # 工具信息 + + {{tool_name}} + {{tool_description}} + + {{input_schema}} + + + # 工具入参 + {{input_params}} + # 报错信息 + {{error_message}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are an intelligent assistant. Your task is to convert the error message generated by the current program into a natural language description. + Please follow the following rules for conversion: + 1. Convert the error message into a natural language description. The description should be concise and clear, allowing users to understand the cause and impact of the error. + 2. The description should include the specific content of the error and possible solutions. + 3. The description should avoid using overly technical terms so that users can understand it. + 4. The description should be as brief as possible, within 50 words. + 5. Only output the natural language description, do not output other content. 
+ # Example + # Tool Information + + port_scanner + Scan host ports + + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "Host address" + }, + "port": { + "type": "integer", + "description": "Port number" + }, + "username": { + "type": "string", + "description": "Username" + }, + "password": { + "type": "string", + "description": "Password" + } + }, + "required": ["host", "port", "username", "password"] + } + + + # Tool input + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # Error message + An error occurred while executing the port scan command: `password is not correct`. + # Output + An error occurred while scanning the port: The password is incorrect. Please check that the password you entered is correct and try again. + # Now start converting the error message: + # Tool information + + {{tool_name}} + {{tool_description}} + + {{input_schema}} + + + # Tool input parameters + {{input_params}} + # Error message + {{error_message}} + # Output + """ + ), +} +# 获取缺失的参数的json结构体 +GET_MISSING_PARAMS: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个工具参数获取器。 + 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,将当前缺失的参数设置为null,并输出一个JSON格式的字符串。 + ```json + { + "host": "请补充主机地址", + "port": "请补充端口号", + "username": "请补充用户名", + "password": "请补充密码" + } + ``` + # 样例 + # 工具名称 + mysql_analyzer + # 工具描述 + 分析MySQL数据库性能 + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具入参schema + { + "type": "object", + "properties": { + "host": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的主机地址(可以为字符串或null)" + }, + "port": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的端口号(可以是数字、字符串或null)" + }, + "username": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的用户名(可以为字符串或null)" + }, + "password": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的密码(可以为字符串或null)" + } + }, + "required": ["host", "port", "username", "password"] + } + # 运行报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": null, + "password": null + } + ``` + # 工具 + + {{tool_name}} + {{tool_description}} + + # 工具入参 + {{input_param}} + # 工具入参schema(部分字段允许为null) + {{input_schema}} + # 运行报错 + {{error_message}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a tool parameter getter. + Your task is to set missing parameters to null based on the current tool's name, description, input parameters, input parameter schema, and runtime errors, and output a JSON-formatted string. 
+ ```json + { + "host": "Please provide the host address", + "port": "Please provide the port number", + "username": "Please provide the username", + "password": "Please provide the password" + } + ``` + # Example + # Tool Name + mysql_analyzer + # Tool Description + Analyze MySQL database performance + # Tool Input Parameters + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # Tool Input Parameter Schema + { + "type": "object", + "properties": { + "host": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL database host address (can be a string or null)" + }, + "port": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL database port number (can be a number, a string, or null)" + }, + "username": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL database username (can be a string or null)" + }, + "password": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL database password (can be a string or null)" + } + }, + "required": ["host", "port", "username", "password"] + } + # Run error + When executing the port scan command, an error occurred: `password is not correct`. + # Output + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": null, + "password": null + } + ``` + # Tool + + {{tool_name}} + {{tool_description}} + + # Tool input parameters + {{input_param}} + # Tool input parameter schema (some fields can be null) + {{input_schema}} + # Run error + {{error_message}} + # Output + """ + ), +} + +GEN_PARAMS: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个工具参数生成器。 + 你的任务是根据总的目标、阶段性的目标、工具信息、工具入参的schema和背景信息生成工具的入参。 + 注意: + 1.生成的参数在格式上必须符合工具入参的schema。 + 2.总的目标、阶段性的目标和背景信息必须被充分理解,利用其中的信息来生成工具入参。 + 3.生成的参数必须符合阶段性目标。 + + # 样例 + # 工具信息 + < tool > + < name > mysql_analyzer < /name > + < description > 分析MySQL数据库性能 < /description > + < / tool > + # 总目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优,ip地址是192.168.1.1,端口是3306,用户名是root,密码是password。 + # 当前阶段目标 + 我要连接MySQL数据库,分析性能瓶颈,并调优。 + # 工具入参的schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL数据库的主机地址" + }, + "port": { + "type": "integer", + "description": "MySQL数据库的端口号" + }, + "username": { + "type": "string", + "description": "MySQL数据库的用户名" + }, + "password": { + "type": "string", + "description": "MySQL数据库的密码" + } + }, + "required": ["host", "port", "username", "password"] + } + # 背景信息 + 第1步:生成端口扫描命令 + - 调用工具 `command_generator`,并提供参数 `帮我生成一个mysql端口扫描命令` + - 执行状态:成功 + - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` + 第2步:执行端口扫描命令 + - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"result": "success"}` + # 输出 + ```json + { + "host": "192.168.1.1", + "port": 3306, + "username": "root", + "password": "password" + } + ``` + # 工具 + < tool > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > + < / tool > + # 总目标 + {{goal}} + # 当前阶段目标 + {{current_goal}} + # 工具入参scheme + {{input_schema}} + # 背景信息 + {{background_info}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a tool parameter generator. + Your task is to generate tool input parameters based on the overall goal, phased goals, tool information, tool input parameter schema, and background information. + Note: + 1. The generated parameters must conform to the tool input parameter schema. + 2. 
The overall goal, phased goals, and background information must be fully understood and used to generate tool input parameters. + 3. The generated parameters must conform to the phased goals. + + # Example + # Tool Information + < tool > + < name >mysql_analyzer < /name > + < description > Analyze MySQL Database Performance < /description > + < / tool > + # Overall Goal + I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. The IP address is 192.168.1.1, the port is 3306, the username is root, and the password is password. + # Current Phase Goal + I need to connect to the MySQL database, analyze performance bottlenecks, and optimize it. # Tool input schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL database host address" + }, + "port": { + "type": "integer", + "description": "MySQL database port number" + }, + "username": { + "type": "string", + "description": "MySQL database username" + }, + "password": { + "type": "string", + "description": "MySQL database password" + } + }, + "required": ["host", "port", "username", "password"] + } + # Background information + Step 1: Generate a port scan command + - Call the `command_generator` tool and provide the `Help me generate a MySQL port scan command` parameter + - Execution status: Success + - Received data: `{"command": "nmap -sS -p --open 192.168.1.1"}` + + Step 2: Execute the port scan command + - Call the `command_executor` tool and provide the parameters `{"command": "nmap -sS -p --open 192.168.1.1"}` + - Execution status: Success + - Received data: `{"result": "success"}` + # Output + ```json + { + "host": "192.168.1.1", + "port": 3306, + "username": "root", + "password": "password" + } + ``` + # Tool + < tool > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > + < / tool > + # Overall goal + {{goal}} + # Current stage goal + {{current_goal}} + # Tool input scheme + {{input_schema}} + # Background information + {{background_info}} + # Output + """ + ), +} + +REPAIR_PARAMS: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个工具参数修复器。 + 你的任务是根据当前的工具信息、目标、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。 + + 注意: + 1.最终修复的参数要符合目标和工具入参的schema。 + + # 样例 + # 工具信息 + + mysql_analyzer + 分析MySQL数据库性能 + + # 总目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 当前阶段目标 + 我要连接MySQL数据库,分析性能瓶颈,并调优。 + # 工具入参的schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL数据库的主机地址" + }, + "port": { + "type": "integer", + "description": "MySQL数据库的端口号" + }, + "username": { + "type": "string", + "description": "MySQL数据库的用户名" + }, + "password": { + "type": "string", + "description": "MySQL数据库的密码" + } + }, + "required": ["host", "port", "username", "password"] + } + # 工具当前的入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具的报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 补充的参数 + { + "username": "admin", + "password": "admin123" + } + # 补充的参数描述 + 用户希望使用admin用户和admin123密码来连接MySQL数据库。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": "admin", + "password": "admin123" + } + ``` + # 工具 + + {{tool_name}} + {{tool_description}} + + # 总目标 + {{goal}} + # 当前阶段目标 + {{current_goal}} + # 工具入参scheme + {{input_schema}} + # 工具当前的入参 + {{input_params}} + # 运行报错 + {{error_message}} + # 补充的参数 + {{params}} + # 补充的参数描述 + {{params_description}} + # 输出 + """ + ), + LanguageType.ENGLISH: dedent( + r""" + You are a tool parameter 
fixer. + Your task is to fix the current tool input parameters based on the current tool information, tool input parameter schema, tool current input parameters, tool error, supplemented parameters, and supplemented parameter descriptions. + + # Example + # Tool information + + mysql_analyzer + Analyze MySQL database performance + + # Tool input parameter schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL database host address" + }, + "port": { + "type": "integer", + "description": "MySQL database port number" + }, + "username": { + "type": "string", + "description": "MySQL database username" + }, + "password": { + "type": "string", + "description": "MySQL database password" + } + }, + "required": ["host", "port", "username", "password"] + } + # Current tool input parameters + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # Tool error + When executing the port scan command, an error occurred: `password is not correct`. + # Supplementary parameters + { + "username": "admin", + "password": "admin123" + } + # Supplementary parameter description + The user wants to use the admin user and the admin123 password to connect to the MySQL database. + # Output + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": "admin", + "password": "admin123" + } + ``` + # Tool + + {{tool_name}} + {{tool_description}} + + # Tool input schema + {{input_schema}} + # Tool current input parameters + {{input_params}} + # Runtime error + {{error_message}} + # Supplementary parameters + {{params}} + # Supplementary parameter descriptions + {{params_description}} + # Output + """ + ), +} +FINAL_ANSWER: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 + + # 注意 + 1.输出的图片链接需要设置高为400px,且使用html的img标签进行展示,不能直接输出链接。 + 1.1 例如:图片描述(可选) + 2.不要输出模型相关的信息,例如“作为一个AI模型,我无法...”等。 + # 用户目标 + + {{goal}} + + # 计划执行情况 + + 为了完成上述目标,你实施了以下计划: + + {{memory}} + + # 其他背景信息: + + {{status}} + + # 现在,请根据以上信息,向用户报告目标的完成情况: + + """ + ), + LanguageType.ENGLISH: dedent( + r""" + Comprehensively understand the plan execution results and background information, and report the goal completion status to the user. + + # Note + 1. The output image link needs to be set to a height of 400px and displayed using the HTML img tag, not directly outputting the link. + 1.1 For example: Image description (optional) + 2. Do not output model-related information, such as "As an AI model, I cannot..." etc. 
+ # User Goal + + {{goal}} + + # Plan Execution Status + + To achieve the above goal, you implemented the following plan: + + {{memory}} + + # Additional Background Information: + + {{status}} + + # Now, based on the above information, report the goal completion status to the user: + + """ + ), +} + +MEMORY_TEMPLATE: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + {% for ctx in context_list %} + - 第{{loop.index}}步:{{ctx.step_description}} + - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}` + - 执行状态:{{ctx.status}} + - 得到数据:`{{ctx.output_data}}` + {% endfor %} + """ + ), + LanguageType.ENGLISH: dedent( + r""" + {% for ctx in context_list %} + - Step {{loop.index}}: {{ctx.step_description}} + - Call the tool `{{ctx.step_id}}` and provide the parameter `{{ctx.input_data}}` + - Execution status: {{ctx.status}} + - Receive data: `{{ctx.output_data}}` + {% endfor %} + """ + ), +} diff --git a/apps/scheduler/mcp_agent/schema.py b/apps/scheduler/mcp_agent/schema.py deleted file mode 100644 index 614139074382daf128a19a320c949e5e46803c4d..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/schema.py +++ /dev/null @@ -1,148 +0,0 @@ -"""MCP Agent执行数据结构""" -from typing import Any, Self - -from pydantic import BaseModel, Field - -from apps.schemas.enum_var import Role - - -class Function(BaseModel): - """工具函数""" - - name: str - arguments: dict[str, Any] - - -class ToolCall(BaseModel): - """Represents a tool/function call in a message""" - - id: str - type: str = "function" - function: Function - - -class Message(BaseModel): - """Represents a chat message in the conversation""" - - role: Role = Field(...) - content: str | None = Field(default=None) - tool_calls: list[ToolCall] | None = Field(default=None) - name: str | None = Field(default=None) - tool_call_id: str | None = Field(default=None) - - def __add__(self, other) -> list["Message"]: - """支持 Message + list 或 Message + Message 的操作""" - if isinstance(other, list): - return [self] + other - elif isinstance(other, Message): - return [self, other] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'" - ) - - def __radd__(self, other) -> list["Message"]: - """支持 list + Message 的操作""" - if isinstance(other, list): - return other + [self] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'" - ) - - def to_dict(self) -> dict: - """Convert message to dictionary format""" - message = {"role": self.role} - if self.content is not None: - message["content"] = self.content - if self.tool_calls is not None: - message["tool_calls"] = [tool_call.dict() for tool_call in self.tool_calls] - if self.name is not None: - message["name"] = self.name - if self.tool_call_id is not None: - message["tool_call_id"] = self.tool_call_id - return message - - @classmethod - def user_message(cls, content: str) -> Self: - """Create a user message""" - return cls(role=Role.USER, content=content) - - @classmethod - def system_message(cls, content: str) -> Self: - """Create a system message""" - return cls(role=Role.SYSTEM, content=content) - - @classmethod - def assistant_message( - cls, content: str | None = None, - ) -> Self: - """Create an assistant message""" - return cls(role=Role.ASSISTANT, content=content) - - @classmethod - def tool_message( - cls, content: str, name: str, tool_call_id: str, - ) -> Self: - """Create a tool message""" - return cls( - role=Role.TOOL, - content=content, - name=name, - 
tool_call_id=tool_call_id, - ) - - @classmethod - def from_tool_calls( - cls, - tool_calls: list[Any], - content: str | list[str] = "", - **kwargs, # noqa: ANN003 - ) -> Self: - """Create ToolCallsMessage from raw tool calls. - - Args: - tool_calls: Raw tool calls from LLM - content: Optional message content - """ - formatted_calls = [ - {"id": call.id, "function": call.function.model_dump(), "type": "function"} - for call in tool_calls - ] - return cls( - role=Role.ASSISTANT, - content=content, - tool_calls=formatted_calls, - **kwargs, - ) - - -class Memory(BaseModel): - messages: list[Message] = Field(default_factory=list) - max_messages: int = Field(default=100) - - def add_message(self, message: Message) -> None: - """Add a message to memory""" - self.messages.append(message) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def add_messages(self, messages: list[Message]) -> None: - """Add multiple messages to memory""" - self.messages.extend(messages) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def clear(self) -> None: - """Clear all messages""" - self.messages.clear() - - def get_recent_messages(self, n: int) -> list[Message]: - """Get n most recent messages""" - return self.messages[-n:] - - def to_dict_list(self) -> list[dict]: - """Convert messages to list of dicts""" - return [msg.to_dict() for msg in self.messages] diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py new file mode 100644 index 0000000000000000000000000000000000000000..a62af7ce8a28c285403408c00bcb2b4aa57bbfa4 --- /dev/null +++ b/apps/scheduler/mcp_agent/select.py @@ -0,0 +1,127 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""选择MCP Server及其工具""" + +import logging +import random + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.llm.reasoning import ReasoningLLM +from apps.llm.token import TokenCalculator +from apps.scheduler.mcp_agent.base import MCPBase +from apps.scheduler.mcp_agent.prompt import TOOL_SELECT +from apps.schemas.mcp import MCPTool, MCPToolIdsSelectResult +from apps.schemas.enum_var import LanguageType + +logger = logging.getLogger(__name__) + +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) + +FINAL_TOOL_ID = "FIANL" +SUMMARIZE_TOOL_ID = "SUMMARIZE" +SELF_DESC_TOOL_ID = "SELF_DESC" + + +class MCPSelector(MCPBase): + """MCP选择器""" + + @staticmethod + async def select_top_tool( + goal: str, + tool_list: list[MCPTool], + additional_info: str | None = None, + top_n: int | None = None, + reasoning_llm: ReasoningLLM | None = None, + language: LanguageType = LanguageType.CHINESE, + ) -> list[MCPTool]: + """选择最合适的工具""" + random.shuffle(tool_list) + max_tokens = reasoning_llm._config.max_tokens + template = _env.from_string(TOOL_SELECT[language]) + token_calculator = TokenCalculator() + if ( + token_calculator.calculate_token_length( + messages=[ + { + "role": "user", + "content": template.render(goal=goal, tools=[], additional_info=additional_info), + } + ], + pure_text=True, + ) + > max_tokens + ): + logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择") + return [] + current_index = 0 + tool_ids = [] + while current_index < len(tool_list): + index = current_index + sub_tools = [] + while index < len(tool_list): + tool = tool_list[index] + tokens = token_calculator.calculate_token_length( + messages=[ + { + "role": "user", + "content": template.render( + goal=goal, tools=[tool], additional_info=additional_info + ), + } + ], + pure_text=True, + ) + if tokens > max_tokens: + continue + sub_tools.append(tool) + + tokens = token_calculator.calculate_token_length( + messages=[ + { + "role": "user", + "content": template.render( + goal=goal, tools=sub_tools, additional_info=additional_info + ), + }, + ], + pure_text=True, + ) + if tokens > max_tokens: + del sub_tools[-1] + break + else: + index += 1 + current_index = index + if sub_tools: + schema = MCPToolIdsSelectResult.model_json_schema() + if "items" not in schema["properties"]["tool_ids"]: + schema["properties"]["tool_ids"]["items"] = {} + # 将enum添加到items中,限制数组元素的可选值 + schema["properties"]["tool_ids"]["items"]["enum"] = [tool.id for tool in sub_tools] + result = await MCPSelector.get_resoning_result( + template.render(goal=goal, tools=sub_tools, additional_info="请根据目标选择对应的工具"), + reasoning_llm, + ) + result = await MCPSelector._parse_result(result, schema) + try: + result = MCPToolIdsSelectResult.model_validate(result) + tool_ids.extend(result.tool_ids) + except Exception: + logger.exception("[MCPSelector] 解析MCP工具ID选择结果失败") + continue + mcp_tools = [tool for tool in tool_list if tool.id in tool_ids] + + if top_n is not None: + mcp_tools = mcp_tools[:top_n] + mcp_tools.append( + MCPTool(id=FINAL_TOOL_ID, name="Final", description="终止", mcp_id=FINAL_TOOL_ID, input_schema={}) + ) + # mcp_tools.append(MCPTool(id=SUMMARIZE_TOOL_ID, name="Summarize", + # description="总结工具", mcp_id=SUMMARIZE_TOOL_ID, input_schema={})) + return mcp_tools diff --git a/apps/scheduler/mcp_agent/tool/__init__.py b/apps/scheduler/mcp_agent/tool/__init__.py deleted file mode 100644 index 4593f31742fee21b2e3ec1c7c18ff8e3cfea2110..0000000000000000000000000000000000000000 
--- a/apps/scheduler/mcp_agent/tool/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool -from apps.scheduler.mcp_agent.tool.terminate import Terminate -from apps.scheduler.mcp_agent.tool.tool_collection import ToolCollection - -__all__ = [ - "BaseTool", - "Terminate", - "ToolCollection", -] diff --git a/apps/scheduler/mcp_agent/tool/base.py b/apps/scheduler/mcp_agent/tool/base.py deleted file mode 100644 index 04ad45c47a3eecb25efdf5b2ce52beb6965b2fbd..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/tool/base.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional - -from pydantic import BaseModel, Field - - -class BaseTool(ABC, BaseModel): - name: str - description: str - parameters: Optional[dict] = None - - class Config: - arbitrary_types_allowed = True - - async def __call__(self, **kwargs) -> Any: - return await self.execute(**kwargs) - - @abstractmethod - async def execute(self, **kwargs) -> Any: - """使用给定的参数执行工具""" - - def to_param(self) -> Dict: - """将工具转换为函数调用格式""" - return { - "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - }, - } - - -class ToolResult(BaseModel): - """表示工具执行的结果""" - - output: Any = Field(default=None) - error: Optional[str] = Field(default=None) - system: Optional[str] = Field(default=None) - - class Config: - arbitrary_types_allowed = True - - def __bool__(self): - return any(getattr(self, field) for field in self.__fields__) - - def __add__(self, other: "ToolResult"): - def combine_fields( - field: Optional[str], other_field: Optional[str], concatenate: bool = True - ): - if field and other_field: - if concatenate: - return field + other_field - raise ValueError("Cannot combine tool results") - return field or other_field - - return ToolResult( - output=combine_fields(self.output, other.output), - error=combine_fields(self.error, other.error), - system=combine_fields(self.system, other.system), - ) - - def __str__(self): - return f"Error: {self.error}" if self.error else self.output - - def replace(self, **kwargs): - """返回一个新的ToolResult,其中替换了给定的字段""" - # return self.copy(update=kwargs) - return type(self)(**{**self.dict(), **kwargs}) - - -class ToolFailure(ToolResult): - """表示失败的ToolResult""" diff --git a/apps/scheduler/mcp_agent/tool/terminate.py b/apps/scheduler/mcp_agent/tool/terminate.py deleted file mode 100644 index 84aa120316de1123f985eebfd82c471e8ceec990..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/tool/terminate.py +++ /dev/null @@ -1,25 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool - - -_TERMINATE_DESCRIPTION = """当请求得到满足或助理无法继续处理任务时,终止交互。 -当您完成所有任务后,调用此工具结束工作。""" - - -class Terminate(BaseTool): - name: str = "terminate" - description: str = _TERMINATE_DESCRIPTION - parameters: dict = { - "type": "object", - "properties": { - "status": { - "type": "string", - "description": "交互的完成状态", - "enum": ["success", "failure"], - } - }, - "required": ["status"], - } - - async def execute(self, status: str) -> str: - """Finish the current execution""" - return f"交互已完成,状态为: {status}" diff --git a/apps/scheduler/mcp_agent/tool/tool_collection.py b/apps/scheduler/mcp_agent/tool/tool_collection.py deleted file mode 100644 index 95bda317805abdecc256af0737091b46adf77b1a..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/tool/tool_collection.py +++ /dev/null @@ -1,55 +0,0 @@ -"""用于管理多个工具的集合类""" 
-import logging -from typing import Any - -from apps.scheduler.mcp_agent.tool.base import BaseTool, ToolFailure, ToolResult - -logger = logging.getLogger(__name__) - - -class ToolCollection: - """定义工具的集合""" - - class Config: - arbitrary_types_allowed = True - - def __init__(self, *tools: BaseTool): - self.tools = tools - self.tool_map = {tool.name: tool for tool in tools} - - def __iter__(self): - return iter(self.tools) - - def to_params(self) -> list[dict[str, Any]]: - return [tool.to_param() for tool in self.tools] - - async def execute( - self, *, name: str, tool_input: dict[str, Any] = None - ) -> ToolResult: - tool = self.tool_map.get(name) - if not tool: - return ToolFailure(error=f"Tool {name} is invalid") - try: - result = await tool(**tool_input) - return result - except Exception as e: - return ToolFailure(error=f"Failed to execute tool {name}: {e}") - - def add_tool(self, tool: BaseTool): - """ - 将单个工具添加到集合中。 - - 如果已存在同名工具,则将跳过该工具并记录警告。 - """ - if tool.name in self.tool_map: - logger.warning(f"Tool {tool.name} already exists in collection, skipping") - return self - - self.tools += (tool,) - self.tool_map[tool.name] = tool - return self - - def add_tools(self, *tools: BaseTool): - for tool in tools: - self.add_tool(tool) - return self diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 678bbd379b89ded869abda0a2b5eb4f8778d030d..aefeda8219fdb02f939f550ca2e3264d50cc8f2b 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -106,7 +106,6 @@ class AppLoader: await file_checker.diff_one(app_path) await self.load(app_id, file_checker.hashes[f"app/{app_id}"]) - @staticmethod async def delete(app_id: str, *, is_reload: bool = False) -> None: """ @@ -161,5 +160,6 @@ class AppLoader: }, upsert=True, ) + app_pool = await app_collection.find_one({"_id": metadata.id}) except Exception: logger.exception("[AppLoader] 更新 MongoDB 失败") diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 47a08152df133724fadf4ece0878a35869e5c496..975c7cea4bf2a0577307849c6d41449ee4c35ce1 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -11,7 +11,7 @@ import yaml from anyio import Path from apps.common.config import Config -from apps.schemas.enum_var import EdgeType +from apps.schemas.enum_var import NodeType,EdgeType from apps.schemas.flow import AppFlow, Flow from apps.schemas.pool import AppPool from apps.models.vector import FlowPoolVector @@ -82,25 +82,18 @@ class FlowLoader: err = f"[FlowLoader] 步骤名称不能以下划线开头:{key}" logger.error(err) raise ValueError(err) - if key == "start": - step["name"] = "开始" - step["description"] = "开始节点" - step["type"] = "start" - elif key == "end": - step["name"] = "结束" - step["description"] = "结束节点" - step["type"] = "end" - else: - try: - step["type"] = await NodeManager.get_node_call_id(step["node"]) - except ValueError as e: - logger.warning("[FlowLoader] 获取节点call_id失败:%s,错误信息:%s", step["node"], e) - step["type"] = "Empty" - step["name"] = ( - (await NodeManager.get_node_name(step["node"])) - if "name" not in step or step["name"] == "" - else step["name"] - ) + if step["type"]==NodeType.START.value or step["type"]==NodeType.END.value: + continue + try: + step["type"] = await NodeManager.get_node_call_id(step["node"]) + except ValueError as e: + logger.warning("[FlowLoader] 获取节点call_id失败:%s,错误信息:%s", step["node"], e) + step["type"] = "Empty" + step["name"] = ( + (await NodeManager.get_node_name(step["node"])) + if "name" not in step 
or step["name"] == "" + else step["name"] + ) return flow_yaml async def load(self, app_id: str, flow_id: str) -> Flow | None: diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 1463d0a1a9141bca271e29d7387759748aa8bd43..bb4cb8a10408118d7614145259e1ef0fd9de8d46 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -11,6 +11,7 @@ import shutil import asyncer from anyio import Path from sqids.sqids import Sqids +from typing import Any from apps.common.lance import LanceDB from apps.common.mongo import MongoDB @@ -91,48 +92,47 @@ class MCPLoader(metaclass=SingletonMeta): :param MCPServerConfig config: MCP配置 :return: 无 """ - if not config.config.auto_install: - print(f"[Installer] MCP模板无需安装: {mcp_id}") # noqa: T201 - - elif isinstance(config.config, MCPServerStdioConfig): - print(f"[Installer] Stdio方式的MCP模板,开始自动安装: {mcp_id}") # noqa: T201 - if "uv" in config.config.command: - new_config = await install_uvx(mcp_id, config.config) - elif "npx" in config.config.command: - new_config = await install_npx(mcp_id, config.config) - - if new_config is None: - logger.error("[MCPLoader] MCP模板安装失败: %s", mcp_id) - await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) - return - - config.config = new_config - - # 重新保存config - template_config = MCP_PATH / "template" / mcp_id / "config.json" - f = await template_config.open("w+", encoding="utf-8") - config_data = config.model_dump(by_alias=True, exclude_none=True) - await f.write(json.dumps(config_data, indent=4, ensure_ascii=False)) - await f.aclose() - - else: - print(f"[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: {mcp_id}") # noqa: T201 - config.config.auto_install = False - - print(f"[Installer] MCP模板安装成功: {mcp_id}") # noqa: T201 - await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY) - await MCPLoader._insert_template_tool(mcp_id, config) + try: + if not config.config.auto_install: + print(f"[Installer] MCP模板无需安装: {mcp_id}") # noqa: T201 + elif isinstance(config.config, MCPServerStdioConfig): + print(f"[Installer] Stdio方式的MCP模板,开始自动安装: {mcp_id}") # noqa: T201 + if "uv" in config.config.command: + new_config = await install_uvx(mcp_id, config.config) + elif "npx" in config.config.command: + new_config = await install_npx(mcp_id, config.config) + + if new_config is None: + logger.error("[MCPLoader] MCP模板安装失败: %s", mcp_id) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) + return + + config.config = new_config + + # 重新保存config + template_config = MCP_PATH / "template" / mcp_id / "config.json" + f = await template_config.open("w+", encoding="utf-8") + config_data = config.model_dump(by_alias=True, exclude_none=True) + await f.write(json.dumps(config_data, indent=4, ensure_ascii=False)) + await f.aclose() + + else: + logger.info(f"[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: {mcp_id}") # noqa: T201 + config.config.auto_install = False + + await MCPLoader._insert_template_tool(mcp_id, config) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY) + logger.info(f"[Installer] MCP模板安装成功: {mcp_id}") # noqa: T201 + except Exception as e: + logger.error("[MCPLoader] MCP模板安装失败: %s, 错误: %s", mcp_id, e) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) + raise @staticmethod - async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None: + async def clear_ready_or_failed_mcp_installation() -> None: """ - 初始化单个MCP模板 - - :param str mcp_id: MCP模板ID - :param MCPServerConfig config: MCP配置 - 
:return: 无 + 清除状态为ready或failed的MCP安装任务 """ - # 删除完成或者失败的MCP安装任务 mcp_collection = MongoDB().get_collection("mcp") mcp_ids = ProcessHandler.get_all_task_ids() # 检索_id在mcp_ids且状态为ready或者failed的MCP的内容 @@ -147,48 +147,52 @@ class MCPLoader(metaclass=SingletonMeta): continue ProcessHandler.remove_task(item.id) logger.info("[MCPLoader] 删除已完成或失败的MCP安装进程: %s", item.id) - # 插入数据库;这里用旧的config就可以 - await MCPLoader._insert_template_db(mcp_id, config) + + @staticmethod + async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None: + """ + 初始化单个MCP模板 + + :param str mcp_id: MCP模板ID + :param MCPServerConfig config: MCP配置 + :return: 无 + """ + await MCPLoader.clear_ready_or_failed_mcp_installation() # 检查目录 template_path = MCP_PATH / "template" / mcp_id await Path.mkdir(template_path, parents=True, exist_ok=True) # 安装MCP模板 + ProcessHandler.remove_task(mcp_id) if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, mcp_id, config): err = f"安装任务无法执行,请稍后重试: {mcp_id}" logger.error(err) raise RuntimeError(err) + # 将installing状态的安装任务的状态变为cancelled @staticmethod - async def _init_all_template() -> None: + async def cancel_all_installing_task() -> None: """ - 初始化所有MCP模板 - - 遍历 ``template`` 目录下的所有MCP模板,并初始化。在Framework启动时进行此流程,确保所有MCP均可正常使用。 - 这一过程会与数据库内的条目进行对比,若发生修改,则重新创建数据库条目。 + 取消正在安装的MCP模板任务 """ template_path = MCP_PATH / "template" logger.info("[MCPLoader] 初始化所有MCP模板: %s", template_path) - + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") # 遍历所有模板 + mcp_ids = [] async for mcp_dir in template_path.iterdir(): # 不是目录 if not await mcp_dir.is_dir(): logger.warning("[MCPLoader] 跳过非目录: %s", mcp_dir.as_posix()) continue - # 检查配置文件是否存在 - config_path = mcp_dir / "config.json" - if not await config_path.exists(): - logger.warning("[MCPLoader] 跳过没有配置文件的MCP模板: %s", mcp_dir.as_posix()) - continue - - # 读取配置并加载 - config = await MCPLoader._load_config(config_path) - - # 初始化第一个MCP Server - logger.info("[MCPLoader] 初始化MCP模板: %s", mcp_dir.as_posix()) - await MCPLoader.init_one_template(mcp_dir.name, config) + mcp_ids.append(mcp_dir.name) + # 更新数据库状态 + await mcp_collection.update_many( + {"_id": {"$in": mcp_ids}, "status": MCPInstallStatus.INSTALLING}, + {"$set": {"status": MCPInstallStatus.CANCELLED}}, + ) @staticmethod async def _get_template_tool( @@ -263,6 +267,12 @@ class MCPLoader(metaclass=SingletonMeta): # 基本信息插入数据库 mcp_collection = MongoDB().get_collection("mcp") + # 清空当前工具列表 + await mcp_collection.update_one( + {"_id": mcp_id}, + {"$set": {"tools": []}}, + upsert=True, + ) await mcp_collection.update_one( {"_id": mcp_id}, { @@ -345,14 +355,15 @@ class MCPLoader(metaclass=SingletonMeta): :return: 图标 :rtype: str """ - icon_path = MCP_PATH / "template" / mcp_id / "icon.png" + icon_path = MCP_PATH / "template" / mcp_id / "icon" / f"{mcp_id}.png" if not await icon_path.exists(): logger.warning("[MCPLoader] MCP模板图标不存在: %s", mcp_id) return "" f = await icon_path.open("rb") icon = await f.read() await f.aclose() - return base64.b64encode(icon).decode("utf-8") + header = "data:image/png;base64," + return header + base64.b64encode(icon).decode("utf-8") @staticmethod async def get_config(mcp_id: str) -> MCPServerConfig: @@ -384,6 +395,7 @@ class MCPLoader(metaclass=SingletonMeta): # 更新数据库 mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") + logger.info("[MCPLoader] 更新MCP模板状态: %s -> %s", mcp_id, status) await mcp_collection.update_one( {"_id": mcp_id}, {"$set": {"status": status}}, @@ -391,7 +403,7 @@ class MCPLoader(metaclass=SingletonMeta): ) @staticmethod - async def 
user_active_template(user_sub: str, mcp_id: str) -> None: + async def user_active_template(user_sub: str, mcp_id: str, mcp_env: dict[str, Any] | None = None) -> None: """ 用户激活MCP模板 @@ -409,7 +421,7 @@ class MCPLoader(metaclass=SingletonMeta): if await user_path.exists(): err = f"MCP模板“{mcp_id}”已存在或有同名文件,无法激活" raise FileExistsError(err) - + mcp_config = await MCPLoader.get_config(mcp_id) # 拷贝文件 await asyncer.asyncify(shutil.copytree)( template_path.as_posix(), @@ -417,7 +429,35 @@ class MCPLoader(metaclass=SingletonMeta): dirs_exist_ok=True, symlinks=True, ) - + if mcp_env is not None: + if mcp_config.type == MCPType.STDIO: + mcp_config.config.env.update(mcp_env) + else: + mcp_config.config.headers.update(mcp_env) + if mcp_config.type == MCPType.STDIO: + index = None + for i in range(len(mcp_config.config.args)): + if mcp_config.config.args[i] == "--directory": + index = i + 1 + break + if index is not None: + if index < len(mcp_config.config.args): + mcp_config.config.args[index] = str(user_path)+'/project' + else: + mcp_config.config.args.append(str(user_path)+'/project') + else: + mcp_config.config.args = ["--directory", str(user_path)+'/project'] + mcp_config.config.args + user_config_path = user_path / "config.json" + # 更新用户配置 + f = await user_config_path.open("w", encoding="utf-8", errors="ignore") + await f.write( + json.dumps( + mcp_config.model_dump(by_alias=True, exclude_none=True), + indent=4, + ensure_ascii=False, + ) + ) + await f.aclose() # 更新数据库 mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") @@ -468,6 +508,26 @@ class MCPLoader(metaclass=SingletonMeta): logger.info("[MCPLoader] 这些MCP在文件系统中被删除: %s", deleted_mcp_list) return deleted_mcp_list + @staticmethod + async def cancel_installing_task(cancel_mcp_list: list[str]) -> None: + """ + 取消正在安装的MCP模板任务 + + :param list[str] cancel_mcp_list: 需要取消的MCP列表 + :return: 无 + """ + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") + # 更新数据库状态 + cancel_mcp_list = await mcp_collection.distinct("_id", {"_id": {"$in": cancel_mcp_list}, "status": MCPInstallStatus.INSTALLING}) + await mcp_collection.update_many( + {"_id": {"$in": cancel_mcp_list}, "status": MCPInstallStatus.INSTALLING}, + {"$set": {"status": MCPInstallStatus.CANCELLED}}, + ) + for mcp_id in cancel_mcp_list: + ProcessHandler.remove_task(mcp_id) + logger.info("[MCPLoader] 取消这些正在安装的MCP模板任务: %s", cancel_mcp_list) + @staticmethod async def remove_deleted_mcp(deleted_mcp_list: list[str]) -> None: """ @@ -573,8 +633,8 @@ class MCPLoader(metaclass=SingletonMeta): # 检查目录 await MCPLoader._check_dir() - # 初始化所有模板 - await MCPLoader._init_all_template() + # 暂停所有安装任务 + await MCPLoader.cancel_all_installing_task() # 加载用户MCP await MCPLoader._load_user_mcp() diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 2b9060461fc0f3baaece19e88c71776449e9752a..2d84069c434de99e45156452a1e5156ba03b2b3e 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -3,6 +3,7 @@ import asyncio import logging +import os import shutil from anyio import Path @@ -30,6 +31,9 @@ class ServiceLoader: """加载单个Service""" service_path = BASE_PATH / service_id # 载入元数据 + if not os.path.exists(service_path / "metadata.yaml"): + logger.error("[ServiceLoader] Service %s 的元数据不存在", service_id) + return metadata = await MetadataLoader().load_one(service_path / "metadata.yaml") if not isinstance(metadata, ServiceMetadata): err = f"[ServiceLoader] 元数据类型错误: {service_path}/metadata.yaml" @@ -48,7 +52,6 @@ class 
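Note: when a Stdio template is activated for a user, the `user_active_template` hunk above rewrites the `--directory` argument so `uv run` targets the user's own copy of the project. A standalone sketch of that argument rewrite; the paths are illustrative only.

```python
def point_uv_at_project(args: list[str], project_dir: str) -> list[str]:
    """Ensure the uv argument list carries `--directory <project_dir>`."""
    args = list(args)  # work on a copy
    if "--directory" in args:
        i = args.index("--directory") + 1
        if i < len(args):
            args[i] = project_dir      # replace the existing target
        else:
            args.append(project_dir)   # `--directory` was the last token
    else:
        args = ["--directory", project_dir, *args]
    return args


print(point_uv_at_project(["run", "some-mcp-server"], "/opt/mcp/user/alice/demo/project"))
print(point_uv_at_project(
    ["--directory", "/old/path", "run", "some-mcp-server"],
    "/opt/mcp/user/alice/demo/project",
))
```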
ServiceLoader: # 更新数据库 await self._update_db(nodes, metadata) - async def save(self, service_id: str, metadata: ServiceMetadata, data: dict) -> None: """在文件系统上保存Service,并更新数据库""" service_path = BASE_PATH / service_id @@ -67,7 +70,6 @@ class ServiceLoader: await file_checker.diff_one(service_path) await self.load(service_id, file_checker.hashes[f"service/{service_id}"]) - async def delete(self, service_id: str, *, is_reload: bool = False) -> None: """删除Service,并更新数据库""" mongo = MongoDB() @@ -95,7 +97,6 @@ class ServiceLoader: if await path.exists(): shutil.rmtree(path) - async def _update_db(self, nodes: list[NodePool], metadata: ServiceMetadata) -> None: # noqa: C901, PLR0912, PLR0915 """更新数据库""" if not metadata.hashes: @@ -197,4 +198,3 @@ class ServiceLoader: await asyncio.sleep(0.01) else: raise - diff --git a/apps/scheduler/pool/mcp/client.py b/apps/scheduler/pool/mcp/client.py index 092bac8909635a5e0c846dddef3456d8ad3be43c..b672690536bc1f03923de51bef6b1ed88c785a04 100644 --- a/apps/scheduler/pool/mcp/client.py +++ b/apps/scheduler/pool/mcp/client.py @@ -29,6 +29,7 @@ class MCPClient: mcp_id: str task: asyncio.Task ready_sign: asyncio.Event + error_sign: asyncio.Event stop_sign: asyncio.Event client: ClientSession status: MCPStatus @@ -54,9 +55,10 @@ class MCPClient: """ # 创建Client if isinstance(config, MCPServerSSEConfig): + headers = config.headers or {} client = sse_client( url=config.url, - headers=config.env, + headers=headers ) elif isinstance(config, MCPServerStdioConfig): if user_sub: @@ -64,7 +66,6 @@ class MCPClient: else: cwd = MCP_PATH / "template" / mcp_id / "project" await cwd.mkdir(parents=True, exist_ok=True) - client = stdio_client(server=StdioServerParameters( command=config.command, args=config.args, @@ -72,6 +73,7 @@ class MCPClient: cwd=cwd.as_posix(), )) else: + self.error_sign.set() err = f"[MCPClient] MCP {mcp_id}:未知的MCP服务类型“{config.type}”" logger.error(err) raise TypeError(err) @@ -85,23 +87,24 @@ class MCPClient: # 初始化Client await session.initialize() except Exception: + self.error_sign.set() + self.status = MCPStatus.STOPPED logger.exception("[MCPClient] MCP %s:初始化失败", mcp_id) raise self.ready_sign.set() self.status = MCPStatus.RUNNING - # 等待关闭信号 await self.stop_sign.wait() + logger.info("[MCPClient] MCP %s:收到停止信号,正在关闭", mcp_id) # 关闭Client try: - await exit_stack.aclose() # type: ignore[attr-defined] + await exit_stack.aclose() # type: ignore[attr-defined] self.status = MCPStatus.STOPPED except Exception: logger.exception("[MCPClient] MCP %s:关闭失败", mcp_id) - async def init(self, user_sub: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig) -> None: """ 初始化 MCP Client类 @@ -116,27 +119,34 @@ class MCPClient: # 初始化变量 self.mcp_id = mcp_id self.ready_sign = asyncio.Event() + self.error_sign = asyncio.Event() self.stop_sign = asyncio.Event() # 创建协程 self.task = asyncio.create_task(self._main_loop(user_sub, mcp_id, config)) # 等待初始化完成 - await self.ready_sign.wait() - - # 获取工具列表 + done, pending = await asyncio.wait( + [asyncio.create_task(self.ready_sign.wait()), + asyncio.create_task(self.error_sign.wait())], + return_when=asyncio.FIRST_COMPLETED + ) + if self.error_sign.is_set(): + self.status = MCPStatus.ERROR + logger.error("[MCPClient] MCP %s:初始化失败", mcp_id) + raise Exception(f"MCP {mcp_id} 初始化失败") + + # 获取工具列表 self.tools = (await self.client.list_tools()).tools - async def call_tool(self, tool_name: str, params: dict) -> "CallToolResult": """调用MCP Server的工具""" return await self.client.call_tool(tool_name, params) - async def stop(self) -> 
None: """停止MCP Client""" self.stop_sign.set() try: await self.task - except Exception: - logger.exception("[MCPClient] MCP %s:停止失败", self.mcp_id) + except Exception as e: + logger.warning("[MCPClient] MCP %s:停止时发生异常:%s", self.mcp_id, e) diff --git a/apps/scheduler/pool/mcp/install.py b/apps/scheduler/pool/mcp/install.py index 1b6c3edeb3a042b7716d0d7748e9c3e50b01af74..02392b366e17b59110ce6a467e63daa081e8a427 100644 --- a/apps/scheduler/pool/mcp/install.py +++ b/apps/scheduler/pool/mcp/install.py @@ -3,12 +3,16 @@ from asyncio import subprocess from typing import TYPE_CHECKING - +import logging +import os +import shutil from apps.constants import MCP_PATH if TYPE_CHECKING: from apps.schemas.mcp import MCPServerStdioConfig +logger = logging.getLogger(__name__) + async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServerStdioConfig | None": """ @@ -23,27 +27,35 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer :rtype: MCPServerStdioConfig :raises ValueError: 未找到MCP Server对应的Python包 """ - # 创建文件夹 - mcp_path = MCP_PATH / "template" / mcp_id / "project" - await mcp_path.mkdir(parents=True, exist_ok=True) - + uv_path = shutil.which('uv') + if uv_path is None: + error = "[Installer] 未找到uv命令,请先安装uv包管理器: pip install uv" + logging.error(error) + raise Exception(error) # 找到包名 - package = "" + package = None for arg in config.args: - if not arg.startswith("-"): + if not arg.startswith("-") and arg != "run": package = arg break - + logger.error(f"[Installer] MCP包名: {package}") if not package: print("[Installer] 未找到包名") # noqa: T201 return None - + # 创建文件夹 + mcp_path = MCP_PATH / "template" / mcp_id / "project" + logger.error(f"[Installer] MCP安装路径: {mcp_path}") + await mcp_path.mkdir(parents=True, exist_ok=True) # 如果有pyproject.toml文件,则使用sync + flag = await (mcp_path / "pyproject.toml").exists() + logger.error(f"[Installer] MCP安装标志: {flag}") if await (mcp_path / "pyproject.toml").exists(): + shell_command = f"{uv_path} venv; {uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active --no-install-project --no-cache" + logger.error(f"[Installer] MCP安装命令: {shell_command}") pipe = await subprocess.create_subprocess_shell( ( - "uv venv; " - "uv sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " + f"{uv_path} venv; " + f"{uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " "--no-install-project --no-cache" ), stdout=subprocess.PIPE, @@ -57,19 +69,20 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer return None print(f"[Installer] 检查依赖成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201 - config.command = "uv" - config.args = ["run", *config.args] + config.command = uv_path + if "run" not in config.args: + config.args = ["run", *config.args] config.auto_install = False - + logger.error(f"[Installer] MCP安装配置更新成功: {config}") return config # 否则,初始化uv项目 pipe = await subprocess.create_subprocess_shell( ( - f"uv init; " - f"uv venv; " - f"uv add --index-url https://pypi.tuna.tsinghua.edu.cn/simple {package}; " - f"uv sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " + f"{uv_path} init; " + f"{uv_path} venv; " + f"{uv_path} add --index-url https://pypi.tuna.tsinghua.edu.cn/simple {package}; " + f"{uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " f"--no-install-project --no-cache" ), stdout=subprocess.PIPE, @@ -84,8 +97,9 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer 
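Note: `install_uvx` now resolves the `uv` binary with `shutil.which` before composing the shell command. A trimmed sketch of that pattern, assuming only that `uv` is on PATH; the index URL and flags are taken from the patch, but the steps are joined with `&&` here so a failed step stops the chain.

```python
import asyncio
import shutil


async def sync_uv_project(project_dir: str) -> bool:
    """Run `uv venv` + `uv sync` inside project_dir, returning True on success."""
    uv = shutil.which("uv")
    if uv is None:
        raise RuntimeError("uv is not installed or not on PATH")
    proc = await asyncio.create_subprocess_shell(
        f"{uv} venv && {uv} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple "
        "--active --no-install-project --no-cache",
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=project_dir,
    )
    stdout, stderr = await proc.communicate()
    if proc.returncode != 0:
        print(f"sync failed: {stderr.decode(errors='ignore')}")
        return False
    print(stdout.decode(errors="ignore") or "(no output)")
    return True


# asyncio.run(sync_uv_project("/tmp/mcp-demo/project"))  # hypothetical project path
```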
print(f"[Installer] 安装 {package} 成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201 # 更新配置 - config.command = "uv" - config.args = ["run", *config.args] + config.command = uv_path + if "run" not in config.args: + config.args = ["run", *config.args] config.auto_install = False return config @@ -103,17 +117,13 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer :rtype: MCPServerStdioConfig :raises ValueError: 未找到MCP Server对应的npm包 """ - mcp_path = MCP_PATH / "template" / mcp_id / "project" - await mcp_path.mkdir(parents=True, exist_ok=True) - - # 如果有node_modules文件夹,则认为已安装 - if await (mcp_path / "node_modules").exists(): - config.command = "npm" - config.args = ["exec", *config.args] - return config - + npm_path = shutil.which('npm') + if npm_path is None: + error = "[Installer] 未找到npm命令,请先安装Node.js和npm" + logging.error(error) + raise Exception(error) # 查找package name - package = "" + package = None for arg in config.args: if not arg.startswith("-"): package = arg @@ -122,10 +132,18 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer if not package: print("[Installer] 未找到包名") # noqa: T201 return None + mcp_path = MCP_PATH / "template" / mcp_id / "project" + await mcp_path.mkdir(parents=True, exist_ok=True) + # 如果有node_modules文件夹,则认为已安装 + if await (mcp_path / "node_modules").exists(): + config.command = npm_path + if "exec" not in config.args: + config.args = ["exec", *config.args] + return config # 安装NPM包 pipe = await subprocess.create_subprocess_shell( - f"npm install {package}", + f"{npm_path} install {package}", stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=mcp_path, @@ -137,8 +155,9 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer print(f"[Installer] 安装 {package} 成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201 # 更新配置 - config.command = "npm" - config.args = ["exec", *config.args] + config.command = npm_path + if "exec" not in config.args: + config.args = ["exec", *config.args] config.auto_install = False return config diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index 91cde4d9d6c83fd5088328e12cfbb6bf06d3c7cb..cd30333f39c496ee643dd11eab7349a0f17a4da7 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -21,16 +21,13 @@ class MCPPool(metaclass=SingletonMeta): """初始化MCP池""" self.pool = {} - async def _init_mcp(self, mcp_id: str, user_sub: str) -> MCPClient | None: """初始化MCP池""" - mcp_math = MCP_USER_PATH / user_sub / mcp_id / "project" config_path = MCP_USER_PATH / user_sub / mcp_id / "config.json" - - if not await mcp_math.exists() or not await mcp_math.is_dir(): - logger.warning("[MCPPool] 用户 %s 的MCP %s 未激活", user_sub, mcp_id) + flag = (await config_path.exists()) + if not flag: + logger.warning("[MCPPool] 用户 %s 的MCP %s 配置文件不存在", user_sub, mcp_id) return None - config = MCPServerConfig.model_validate_json(await config_path.read_text()) if config.type in (MCPType.SSE, MCPType.STDIO): @@ -40,9 +37,11 @@ class MCPPool(metaclass=SingletonMeta): return None await client.init(user_sub, mcp_id, config.config) + if user_sub not in self.pool: + self.pool[user_sub] = {} + self.pool[user_sub][mcp_id] = client return client - async def _get_from_dict(self, mcp_id: str, user_sub: str) -> MCPClient | None: """从字典中获取MCP客户端""" if user_sub not in self.pool: @@ -53,7 +52,6 @@ class MCPPool(metaclass=SingletonMeta): return self.pool[user_sub][mcp_id] - async def _validate_user(self, mcp_id: 
str, user_sub: str) -> bool: """验证用户是否已激活""" mongo = MongoDB() @@ -61,7 +59,6 @@ class MCPPool(metaclass=SingletonMeta): mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub}) return mcp_db_result is not None - async def get(self, mcp_id: str, user_sub: str) -> MCPClient | None: """获取MCP客户端""" item = await self._get_from_dict(mcp_id, user_sub) @@ -83,8 +80,7 @@ class MCPPool(metaclass=SingletonMeta): return item - async def stop(self, mcp_id: str, user_sub: str) -> None: """停止MCP客户端""" await self.pool[user_sub][mcp_id].stop() - del self.pool[user_sub][mcp_id] + del self.pool[user_sub][mcp_id] \ No newline at end of file diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 7710d24dc102fe02c06d8cbc14f126bd088db2d4..ead552fc2994150137121b103ff013d8ea4d66a5 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -60,7 +60,6 @@ class Pool: await Path(root_dir + "mcp").unlink(missing_ok=True) await Path(root_dir + "mcp").mkdir(parents=True, exist_ok=True) - @staticmethod async def init() -> None: """ @@ -121,13 +120,15 @@ class Pool: for app in changed_app: hash_key = Path("app/" + app).as_posix() if hash_key in checker.hashes: - await app_loader.load(app, checker.hashes[hash_key]) - + try: + await app_loader.load(app, checker.hashes[hash_key]) + except Exception as e: + await app_loader.delete(app, is_reload=True) + logger.warning("[Pool] 加载App %s 失败: %s", app, e) # 载入MCP logger.info("[Pool] 载入MCP") await MCPLoader.init() - async def get_flow_metadata(self, app_id: str) -> list[AppFlow]: """从数据库中获取特定App的全部Flow的元数据""" mongo = MongoDB() @@ -145,14 +146,12 @@ class Pool: else: return flow_metadata_list - async def get_flow(self, app_id: str, flow_id: str) -> Flow | None: """从文件系统中获取单个Flow的全部数据""" logger.info("[Pool] 获取工作流 %s", flow_id) flow_loader = FlowLoader() return await flow_loader.load(app_id, flow_id) - async def get_call(self, call_id: str) -> Any: """[Exception] 拿到Call的信息""" # 从MongoDB里拿到数据 diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 5d83e01cdada76d4d040553ad0a4c47e37e0fe0d..382077360f7134f1c15b8224c77551c1c86190fa 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -8,8 +8,9 @@ import re from apps.common.security import Security from apps.llm.patterns.facts import Facts from apps.schemas.collection import Document -from apps.schemas.enum_var import StepStatus +from apps.schemas.enum_var import StepStatus, FlowStatus from apps.schemas.record import ( + FlowHistory, Record, RecordContent, RecordDocument, @@ -188,34 +189,38 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: feature={}, ), createdAt=current_time, - flow=[context.id if hasattr(context, 'id') else context.get('_id', context.get('id', '')) for context in task.context], + flow=FlowHistory( + flow_id=task.state.flow_id, + flow_name=task.state.flow_name, + flow_status=task.state.flow_status, + history_ids=[context.id for context in task.context], + ) ) # 检查是否存在group_id if not await RecordManager.check_group_id(task.ids.group_id, user_sub): - record_group = await RecordManager.create_record_group( - task.ids.group_id, user_sub, post_body.conversation_id, task.id, - ) - if not record_group: + record_group_id = await RecordManager.create_record_group( + task.ids.group_id, user_sub, post_body.conversation_id + ) + if not record_group_id: logger.error("[Scheduler] 创建问答组失败") return else: - record_group = task.ids.group_id + record_group_id 
= task.ids.group_id # 修改文件状态 - await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) + await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group_id) # 保存Record - await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record) + await RecordManager.insert_record_data_into_record_group(user_sub, record_group_id, record) # 保存与答案关联的文件 - await DocumentManager.save_answer_doc(user_sub, record_group, used_docs) + await DocumentManager.save_answer_doc(user_sub, record_group_id, used_docs) if post_body.app and post_body.app.app_id: # 更新最近使用的应用 await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) # 若状态为成功,删除Task - if not task.state or task.state.status == StepStatus.SUCCESS: + if not task.state or task.state.flow_status == FlowStatus.SUCCESS or task.state.flow_status == FlowStatus.ERROR or task.state.flow_status == FlowStatus.CANCELLED: await TaskManager.delete_task_by_task_id(task.id) else: - # 更新Task await TaskManager.save_task(task.id, task) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index a2a45e41494e3ddb8d291c0a907da06bb479d4a5..b2725d6cafc32abc6e4e88b1e4054767708dde9b 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -15,6 +15,7 @@ from apps.schemas.message import ( InitContentFeature, TextAddContent, ) +from apps.schemas.enum_var import FlowStatus from apps.schemas.rag_data import RAGEventData, RAGQueryReq from apps.schemas.record import RecordDocument from apps.schemas.task import Task @@ -59,23 +60,34 @@ async def push_init_message( async def push_rag_message( - task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]], - doc_ids: list[str], - rag_data: RAGQueryReq,) -> Task: + task: Task, + queue: MessageQueue, + user_sub: str, + llm: LLM, + history: list[dict[str, str]], + doc_ids: list[str], + rag_data: RAGQueryReq, +) -> None: """推送RAG消息""" full_answer = "" - - async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): - task, content_obj = await _push_rag_chunk(task, queue, chunk) - if content_obj.event_type == EventType.TEXT_ADD.value: - # 如果是文本消息,直接拼接到答案中 - full_answer += content_obj.content - elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append(content_obj.content) + try: + async for chunk in RAG.chat_with_llm_base_on_rag( + user_sub, llm, history, doc_ids, rag_data, task.language + ): + task, content_obj = await _push_rag_chunk(task, queue, chunk) + if content_obj.event_type == EventType.TEXT_ADD.value: + # 如果是文本消息,直接拼接到答案中 + full_answer += content_obj.content + elif content_obj.event_type == EventType.DOCUMENT_ADD.value: + task.runtime.documents.append(content_obj.content) + task.state.flow_status = FlowStatus.SUCCESS + except Exception as e: + logger.error(f"[Scheduler] RAG服务发生错误: {e}") + task.state.flow_status = FlowStatus.ERROR # 保存答案 task.runtime.answer = full_answer + task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time await TaskManager.save_task(task.id, task) - return task async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData]: @@ -113,11 +125,11 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), - 
createdAt=round(datetime.now(tz=UTC).timestamp(), 3), + createdAt=round(content_obj.content.get("created_at", datetime.now(tz=UTC).timestamp()), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: logger.exception("[Scheduler] RAG服务返回错误数据") return task, "" else: - return task, content_obj + return task, content_obj \ No newline at end of file diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 417f93d28a147cd0cf65b124624484a02a49ab06..49953877c87290f0e85f5b901ec56bc2ca15a5b9 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -1,9 +1,13 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Scheduler模块""" +import asyncio import logging from datetime import UTC, datetime +from apps.llm.reasoning import ReasoningLLM +from apps.schemas.config import LLMConfig +from apps.llm.patterns.rewrite import QuestionRewrite from apps.common.config import Config from apps.common.mongo import MongoDB from apps.common.queue import MessageQueue @@ -17,11 +21,12 @@ from apps.scheduler.scheduler.message import ( push_rag_message, ) from apps.schemas.collection import LLM -from apps.schemas.enum_var import AppType, EventType +from apps.schemas.enum_var import FlowStatus, AppType, EventType from apps.schemas.pool import AppPool from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData from apps.schemas.scheduler import ExecutorBackground +from apps.services.activity import Activity from apps.schemas.task import Task from apps.services.appcenter import AppCenterManager from apps.services.knowledge import KnowledgeBaseManager @@ -45,8 +50,28 @@ class Scheduler: self.queue = queue self.post_body = post_body - async def run(self) -> None: # noqa: PLR0911 - """运行调度器""" + async def _monitor_activity(self, kill_event): + """监控用户活动状态,不活跃时终止工作流""" + try: + check_interval = 0.5 # 每0.5秒检查一次 + + while not kill_event.is_set(): + # 检查用户活动状态 + is_active = await Activity.is_active(self.task.ids.active_id) + if not is_active: + logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", self.task.ids.user_sub) + kill_event.set() + break + + # 控制检查频率 + await asyncio.sleep(check_interval) + except asyncio.CancelledError: + logger.info("[Scheduler] 活动监控任务已取消") + except Exception as e: + logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}") + + async def get_llm_use_in_chat_with_rag(self) -> LLM: + """获取RAG大模型""" try: # 获取当前会话使用的大模型 llm_id = await LLMManager.get_llm_id_by_conversation_id( @@ -54,8 +79,7 @@ class Scheduler: ) if not llm_id: logger.error("[Scheduler] 获取大模型ID失败") - await self.queue.close() - return + return None if llm_id == "empty": llm = LLM( _id="empty", @@ -65,24 +89,30 @@ class Scheduler: model_name=Config().get_config().llm.model, max_tokens=Config().get_config().llm.max_tokens, ) + return llm else: llm = await LLMManager.get_llm_by_id(self.task.ids.user_sub, llm_id) if not llm: logger.error("[Scheduler] 获取大模型失败") - await self.queue.close() - return + return None + return llm except Exception: logger.exception("[Scheduler] 获取大模型失败") - await self.queue.close() - return + return None + + async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]: + """获取知识库ID列表""" try: - # 获取当前会话使用的知识库 kb_ids = await KnowledgeBaseManager.get_kb_ids_by_conversation_id( - self.task.ids.user_sub, self.task.ids.conversation_id) + self.task.ids.user_sub, self.task.ids.conversation_id + ) except Exception: logger.exception("[Scheduler] 获取知识库ID失败") await self.queue.close() - return 
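Note: `_monitor_activity` polls the user's activity flag and raises a kill event; the `run()` hunk that follows races it against the main task with `asyncio.wait` and cancels the workflow once the user goes idle. A self-contained sketch of that pattern, with a fake activity check standing in for `Activity.is_active`:

```python
import asyncio


async def monitor(is_active, kill_event: asyncio.Event, interval: float = 0.2) -> None:
    """Poll the activity check and raise the kill flag once the user goes idle."""
    while not kill_event.is_set():
        if not await is_active():
            kill_event.set()
            break
        await asyncio.sleep(interval)


async def long_running_work() -> str:
    await asyncio.sleep(10)  # stand-in for push_rag_message / run_executor
    return "done"


async def main() -> None:
    loop = asyncio.get_running_loop()
    deadline = loop.time() + 1.0  # pretend the user walks away after ~1 second

    async def is_active() -> bool:  # stand-in for Activity.is_active(...)
        return loop.time() < deadline

    kill_event = asyncio.Event()
    watchdog = asyncio.create_task(monitor(is_active, kill_event))
    main_task = asyncio.create_task(long_running_work())

    await asyncio.wait([main_task, watchdog], return_when=asyncio.FIRST_COMPLETED)
    if kill_event.is_set():
        main_task.cancel()
        try:
            await main_task
        except asyncio.CancelledError:
            print("workflow cancelled because the user went inactive")
    watchdog.cancel()


asyncio.run(main())
```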
+ return [] + + async def run(self) -> None: # noqa: PLR0911 + """运行调度器""" try: # 获取当前问答可供关联的文档 docs, doc_ids = await get_docs(self.task.ids.user_sub, self.post_body) @@ -92,18 +122,30 @@ class Scheduler: return history, _ = await get_context(self.task.ids.user_sub, self.post_body, 3) # 已使用文档 - # 如果是智能问答,直接执行 logger.info("[Scheduler] 开始执行") - if not self.post_body.app or self.post_body.app.app_id == "": + # 创建用于通信的事件 + kill_event = asyncio.Event() + monitor = asyncio.create_task(self._monitor_activity(kill_event)) + rag_method = True + if self.post_body.app and self.post_body.app.app_id: + rag_method = False + if self.task.state.app_id: + rag_method = False + if rag_method: + llm = await self.get_llm_use_in_chat_with_rag() + kb_ids = await self.get_kb_ids_use_in_chat_with_rag() self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( kbIds=kb_ids, query=self.post_body.question, tokensLimit=llm.max_tokens, ) - self.task = await push_rag_message(self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data) - self.task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time + + # 启动监控任务和主任务 + main_task = asyncio.create_task(push_rag_message( + self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data)) + else: # 查找对应的App元数据 app_data = await AppCenterManager.fetch_app_data_by_id(self.post_body.app.app_id) @@ -127,8 +169,27 @@ class Scheduler: conversation=context, facts=facts, ) - await self.run_executor(self.queue, self.post_body, executor_background) + # 启动监控任务和主任务 + main_task = asyncio.create_task(self.run_executor(self.queue, self.post_body, executor_background)) + # 等待任一任务完成 + done, pending = await asyncio.wait( + [main_task, monitor], + return_when=asyncio.FIRST_COMPLETED + ) + + # 如果是监控任务触发,终止主任务 + if kill_event.is_set(): + logger.warning("[Scheduler] 用户活动状态检测不活跃,正在终止工作流执行...") + main_task.cancel() + need_change_cancel_flow_state = [FlowStatus.RUNNING, FlowStatus.WAITING] + if self.task.state.flow_status in need_change_cancel_flow_state: + self.task.state.flow_status = FlowStatus.CANCELLED + try: + await main_task + logger.info("[Scheduler] 工作流执行已被终止") + except Exception as e: + logger.error(f"[Scheduler] 终止工作流时发生错误: {e}") # 更新Task,发送结束消息 logger.info("[Scheduler] 发送结束消息") await self.queue.push_output(self.task, event_type=EventType.DONE.value, data={}) @@ -152,6 +213,42 @@ class Scheduler: if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") return + if app_metadata.llm_id == "empty": + llm = LLM( + _id="empty", + user_sub=self.task.ids.user_sub, + openai_base_url=Config().get_config().llm.endpoint, + openai_api_key=Config().get_config().llm.key, + model_name=Config().get_config().llm.model, + max_tokens=Config().get_config().llm.max_tokens, + ) + else: + llm = await LLMManager.get_llm_by_id( + self.task.ids.user_sub, app_metadata.llm_id, + ) + if not llm: + logger.error("[Scheduler] 获取大模型失败") + await self.queue.close() + return + reasion_llm = ReasoningLLM( + LLMConfig( + endpoint=llm.openai_base_url, + key=llm.openai_api_key, + model=llm.model_name, + max_tokens=llm.max_tokens, + ) + ) + if background.conversation and self.task.state.flow_status == FlowStatus.INIT: + try: + question_obj = QuestionRewrite() + post_body.question = await question_obj.generate( + history=background.conversation, + question=post_body.question, + llm=reasion_llm, + language=post_body.language, + ) + except Exception: + logger.exception("[Scheduler] 问题重写失败") if app_metadata.app_type == 
AppType.FLOW.value: logger.info("[Scheduler] 获取工作流元数据") flow_info = await Pool().get_flow_metadata(app_info.app_id) @@ -182,7 +279,6 @@ class Scheduler: # 初始化Executor logger.info("[Scheduler] 初始化Executor") - flow_exec = FlowExecutor( flow_id=flow_id, flow=flow_data, @@ -210,6 +306,7 @@ class Scheduler: servers_id=servers_id, background=background, agent_id=app_info.app_id, + params=post_body.params ) # 开始运行 logger.info("[Scheduler] 运行Executor") @@ -218,4 +315,4 @@ class Scheduler: else: logger.error("[Scheduler] 无效的应用类型") - return + return \ No newline at end of file diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 4c9453c47710d7cd9998364559416551b33eb05d..7d34a7bcce89380475c1241906939bbf48fa1e53 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """参数槽位管理""" +import copy import json import logging import traceback @@ -13,7 +14,7 @@ from jsonschema.protocols import Validator from jsonschema.validators import extend from apps.schemas.response_data import ParamsNode -from apps.scheduler.call.choice.schema import ValueType +from apps.schemas.parameters import ValueType from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, @@ -128,7 +129,7 @@ class Slot: # Schema标准 return [_process_json_value(item, spec_data["items"]) for item in json_value] if spec_data["type"] == "object" and isinstance(json_value, dict): - # 若Schema不标准,则不进行处理 + # 若Schema不标准,则不进行处理F if "properties" not in spec_data: return json_value # Schema标准 @@ -156,35 +157,60 @@ class Slot: @staticmethod def _generate_example(schema_node: dict) -> Any: # noqa: PLR0911 """根据schema生成示例值""" + if "anyOf" in schema_node or "oneOf" in schema_node: + # 如果有anyOf,随机返回一个示例 + for item in schema_node["anyOf"] if "anyOf" in schema_node else schema_node["oneOf"]: + example = Slot._generate_example(item) + if example is not None: + return example + + if "allOf" in schema_node: + # 如果有allOf,返回所有示例的合并 + example = None + for item in schema_node["allOf"]: + if example is None: + example = Slot._generate_example(item) + else: + other_example = Slot._generate_example(item) + if isinstance(example, dict) and isinstance(other_example, dict): + example.update(other_example) + else: + example = None + break + return example + if "default" in schema_node: return schema_node["default"] if "type" not in schema_node: return None - + type_value = schema_node["type"] + if isinstance(type_value, list): + # 如果是多类型,随机返回一个示例 + if len(type_value) > 1: + type_value = type_value[0] # 处理类型为 object 的节点 - if schema_node["type"] == "object": + if type_value == "object": data = {} properties = schema_node.get("properties", {}) for name, schema in properties.items(): data[name] = Slot._generate_example(schema) return data - # 处理类型为 array 的节点 - if schema_node["type"] == "array": + elif type_value == "array": items_schema = schema_node.get("items", {}) return [Slot._generate_example(items_schema)] # 处理类型为 string 的节点 - if schema_node["type"] == "string": + elif type_value == "string": return "" # 处理类型为 number 或 integer 的节点 - if schema_node["type"] in ["number", "integer"]: + elif type_value in ["number", "integer"]: return 0 # 处理类型为 boolean 的节点 - if schema_node["type"] == "boolean": + elif type_value == "boolean": return False # 处理其他类型或未定义类型 @@ -198,29 +224,69 @@ class Slot: """从JSON Schema中提取类型描述""" def _extract_type_desc(schema_node: dict[str, Any]) -> dict[str, Any]: - if "type" not in schema_node and "anyOf" not in 
schema_node: - return {} - data = {"type": schema_node.get("type", ""), "description": schema_node.get("description", "")} - if "anyOf" in schema_node: - data["type"] = "anyOf" - # 处理类型为 object 的节点 - if "anyOf" in schema_node: - data["items"] = {} - type_index = 0 - for type_index, sub_schema in enumerate(schema_node["anyOf"]): - sub_result = _extract_type_desc(sub_schema) - if sub_result: - data["items"]["type_"+str(type_index)] = sub_result - if schema_node.get("type", "") == "object": - data["items"] = {} + # 处理组合关键字 + special_keys = ["anyOf", "allOf", "oneOf"] + for key in special_keys: + if key in schema_node: + data = { + "type": key, + "description": schema_node.get("description", ""), + "items": {}, + } + type_index = 0 + for item in schema_node[key]: + if isinstance(item, dict): + data["items"][f"item_{type_index}"] = _extract_type_desc(item) + else: + data["items"][f"item_{type_index}"] = {"type": item, "description": ""} + type_index += 1 + return data + # 处理基本类型 + type_val = schema_node.get("type", "") + description = schema_node.get("description", "") + + # 处理多类型数组 + if isinstance(type_val, list): + if len(type_val) > 1: + data = {"type": "union", "description": description, "items": {}} + type_index = 0 + for t in type_val: + if t == "object": + tmp_dict = {} + for key, val in schema_node.get("properties", {}).items(): + tmp_dict[key] = _extract_type_desc(val) + data["items"][f"item_{type_index}"] = tmp_dict + elif t == "array": + items_schema = schema_node.get("items", {}) + data["items"][f"item_{type_index}"] = _extract_type_desc(items_schema) + else: + data["items"][f"item_{type_index}"] = {"type": t, "description": description} + type_index += 1 + return data + elif len(type_val) == 1: + type_val = type_val[0] + else: + type_val = "" + + data = {"type": type_val, "description": description, "items": {}} + + # 递归处理对象和数组 + if type_val == "object": for key, val in schema_node.get("properties", {}).items(): data["items"][key] = _extract_type_desc(val) - - # 处理类型为 array 的节点 - if schema_node.get("type", "") == "array": + elif type_val == "array": items_schema = schema_node.get("items", {}) - data["items"] = _extract_type_desc(items_schema) + if isinstance(items_schema, list): + item_index = 0 + for item in items_schema: + data["items"][f"item_{item_index}"] = _extract_type_desc(item) + item_index += 1 + else: + data["items"]["item"] = _extract_type_desc(items_schema) + if data["items"] == {}: + del data["items"] return data + return _extract_type_desc(self._schema) def get_params_node_from_schema(self, root: str = "") -> ParamsNode: @@ -231,13 +297,15 @@ class Slot: return None param_type = schema_node["type"] + if isinstance(param_type, list): + return None # 不支持多类型 if param_type == "object": param_type = ValueType.DICT elif param_type == "array": param_type = ValueType.LIST elif param_type == "string": param_type = ValueType.STRING - elif param_type == "number": + elif param_type in ["number", "integer"]: param_type = ValueType.NUMBER elif param_type == "boolean": param_type = ValueType.BOOL @@ -246,9 +314,11 @@ class Slot: return None sub_params = [] - if param_type == "object" and "properties" in schema_node: + if param_type == ValueType.DICT and "properties" in schema_node: for key, value in schema_node["properties"].items(): - sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}")) + sub_param = _extract_params_node(value, name=key, path=f"{path}/{key}") + if sub_param: + sub_params.append(sub_param) else: # 对于非对象类型,直接返回空子参数 sub_params = None @@ 
-412,3 +482,54 @@ class Slot: return schema_template return {} + + def add_null_to_basic_types(self) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + """ + def add_null_to_basic_types(schema: dict[str, Any]) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + + 参数: + schema (dict): 原始 JSON Schema + + 返回: + dict: 修改后的 JSON Schema + """ + # 如果不是字典类型(schema),直接返回 + if not isinstance(schema, dict): + return schema + + # 处理当前节点的 type 字段 + if 'type' in schema: + # 处理单一类型字符串 + if isinstance(schema['type'], str): + if schema['type'] in ['boolean', 'number', 'string', 'integer']: + schema['type'] = [schema['type'], 'null'] + + # 处理类型数组 + elif isinstance(schema['type'], list): + for i, t in enumerate(schema['type']): + if isinstance(t, str) and t in ['boolean', 'number', 'string', 'integer']: + if 'null' not in schema['type']: + schema['type'].append('null') + break + + # 递归处理 properties 字段(对象类型) + if 'properties' in schema: + for prop, prop_schema in schema['properties'].items(): + schema['properties'][prop] = add_null_to_basic_types(prop_schema) + + # 递归处理 items 字段(数组类型) + if 'items' in schema: + schema['items'] = add_null_to_basic_types(schema['items']) + + # 递归处理 anyOf, oneOf, allOf 字段 + for keyword in ['anyOf', 'oneOf', 'allOf']: + if keyword in schema: + schema[keyword] = [add_null_to_basic_types(sub_schema) for sub_schema in schema[keyword]] + + return schema + schema_copy = copy.deepcopy(self._schema) + return add_null_to_basic_types(schema_copy) \ No newline at end of file diff --git a/apps/schemas/agent.py b/apps/schemas/agent.py index b52f5e1c3315fb873acddb2bcc4e27937498d561..16e818e4d588d4f25ea4854a7979e2dc5fe90ecb 100644 --- a/apps/schemas/agent.py +++ b/apps/schemas/agent.py @@ -17,6 +17,7 @@ class AgentAppMetadata(MetadataBase): app_type: AppType = Field(default=AppType.AGENT, description="应用类型", frozen=True) published: bool = Field(description="是否发布", default=False) history_len: int = Field(description="对话轮次", default=3, le=10) - mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") + mcp_service: list[str] = Field(default=[], description="MCP服务id列表") + llm_id: str = Field(default="empty", description="大模型ID") permission: Permission | None = Field(description="应用权限配置", default=None) version: str = Field(description="元数据版本") diff --git a/apps/schemas/appcenter.py b/apps/schemas/appcenter.py index a89f39df18083d90b988cc764c0e88f90500d1f3..e3bb896eba2361e0c28ee914b84f387d7da76b87 100644 --- a/apps/schemas/appcenter.py +++ b/apps/schemas/appcenter.py @@ -45,9 +45,9 @@ class AppFlowInfo(BaseModel): """应用工作流数据结构""" id: str = Field(..., description="工作流ID") - name: str = Field(..., description="工作流名称") - description: str = Field(..., description="工作流简介") - debug: bool = Field(..., description="是否经过调试") + name: str = Field(default="", description="工作流名称") + description: str = Field(default="", description="工作流简介") + debug: bool = Field(default=False, description="是否经过调试") class AppData(BaseModel): @@ -61,6 +61,7 @@ class AppData(BaseModel): first_questions: list[str] = Field( default=[], alias="recommendedQuestions", description="推荐问题", max_length=3) history_len: int = Field(3, alias="dialogRounds", ge=1, le=10, description="对话轮次(1~10)") + llm: str = Field(default="empty", description="大模型ID") permission: AppPermissionData = Field( default_factory=lambda: AppPermissionData(authorizedUsers=None), description="权限配置") workflows: list[AppFlowInfo] = Field(default=[], description="工作流信息列表") diff --git 
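Note: `Slot.add_null_to_basic_types` deep-copies the schema and widens every scalar `type` so it also accepts `null`. A compact sketch of the same transform on a toy schema, checked with `jsonschema`; only the example schema is invented.

```python
import copy

from jsonschema import validate

BASIC = {"boolean", "number", "string", "integer"}


def allow_null(schema):
    """Recursively let scalar-typed fields also accept null."""
    if not isinstance(schema, dict):
        return schema
    t = schema.get("type")
    if isinstance(t, str) and t in BASIC:
        schema["type"] = [t, "null"]
    elif isinstance(t, list) and any(x in BASIC for x in t) and "null" not in t:
        schema["type"] = [*t, "null"]
    for prop in schema.get("properties", {}).values():
        allow_null(prop)
    if isinstance(schema.get("items"), dict):
        allow_null(schema["items"])
    for key in ("anyOf", "oneOf", "allOf"):
        for sub in schema.get(key, []):
            allow_null(sub)
    return schema


original = {
    "type": "object",
    "properties": {
        "port": {"type": "integer"},
        "tags": {"type": "array", "items": {"type": "string"}},
    },
}
nullable = allow_null(copy.deepcopy(original))
validate(instance={"port": None, "tags": [None]}, schema=nullable)  # now passes
print(nullable["properties"]["port"]["type"])                        # ['integer', 'null']
```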
a/apps/schemas/collection.py b/apps/schemas/collection.py index 0ff66c72bbe30b7cbb55517952c8b14744d69cf2..20bdfc7c9f1cc7357365baef9c41fa4971c51957 100644 --- a/apps/schemas/collection.py +++ b/apps/schemas/collection.py @@ -61,6 +61,7 @@ class User(BaseModel): fav_apps: list[str] = [] fav_services: list[str] = [] is_admin: bool = Field(default=False, description="是否为管理员") + auto_execute: bool = Field(default=True, description="是否自动执行任务") class LLM(BaseModel): @@ -72,6 +73,7 @@ class LLM(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") user_sub: str = Field(default="", description="用户ID") + title: str = Field(default=NEW_CHAT) icon: str = Field(default=llm_provider_dict["ollama"]["icon"], description="图标") openai_base_url: str = Field(default=Config().get_config().llm.endpoint) openai_api_key: str = Field(default=Config().get_config().llm.key) diff --git a/apps/schemas/config.py b/apps/schemas/config.py index 3205f062bf8170529691fd812dbc7dd0c8009f6a..eb9a4593aca942bdafa1d3b9f6de7a75dcd96da6 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -10,7 +10,6 @@ class NoauthConfig(BaseModel): """无认证配置""" enable: bool = Field(description="是否启用无认证访问", default=False) - user_sub: str = Field(description="调试用户的sub", default="admin") class DeployConfig(BaseModel): diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index e4670d77637ef4784592cf677cd0bcfa9cb40192..7d024c88a0608cfa0a999da5824620a1f4aa2bb6 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,11 +15,26 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" + UNKNOWN = "unknown" + INIT = "init" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" ERROR = "error" PARAM = "param" + CANCELLED = "cancelled" + + +class FlowStatus(str, Enum): + """Flow状态""" + + UNKNOWN = "unknown" + INIT = "init" + WAITING = "waiting" + RUNNING = "running" + SUCCESS = "success" + ERROR = "error" + CANCELLED = "cancelled" class DocumentStatus(str, Enum): @@ -35,16 +50,22 @@ class EventType(str, Enum): """事件类型""" HEARTBEAT = "heartbeat" - INIT = "init", + INIT = "init" TEXT_ADD = "text.add" GRAPH = "graph" DOCUMENT_ADD = "document.add" STEP_WAITING_FOR_START = "step.waiting_for_start" STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" + STEP_INIT = "step.init" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" + STEP_CANCEL = "step.cancel" + STEP_ERROR = "step.error" FLOW_STOP = "flow.stop" + FLOW_FAILED = "flow.failed" + FLOW_SUCCESS = "flow.success" + FLOW_CANCEL = "flow.cancel" DONE = "done" @@ -196,3 +217,10 @@ class AgentState(str, Enum): RUNNING = "RUNNING" FINISHED = "FINISHED" ERROR = "ERROR" + + +class LanguageType(str, Enum): + """语言类型""" + + CHINESE = "zh" + ENGLISH = "en" \ No newline at end of file diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py index 2646d04390099fd92988d78c00d1b1471780d6a9..dfffd1f14c779952d9bb761d2e2bccd85528bc0f 100644 --- a/apps/schemas/flow.py +++ b/apps/schemas/flow.py @@ -136,6 +136,7 @@ class AppMetadata(MetadataBase): published: bool = Field(description="是否发布", default=False) links: list[AppLink] = Field(description="相关链接", default=[]) first_questions: list[str] = Field(description="首次提问", default=[]) + llm_id: str = Field(description="大模型ID", default="empty") history_len: int = Field(description="对话轮次", default=3, le=10) permission: Permission | None = Field(description="应用权限配置", default=None) flows: list[AppFlow] = Field(description="Flow列表", default=[]) diff --git 
a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index aa5da0d585d6d0e49d5b1d614a6fc8e1d7024d06..b7e1175eff4bcc0f728b3b7a36e33e55705570b9 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -5,6 +5,7 @@ from typing import Any from pydantic import BaseModel, Field +from apps.schemas.enum_var import SpecialCallType from apps.schemas.enum_var import CallType, EdgeType @@ -52,7 +53,7 @@ class NodeItem(BaseModel): service_id: str = Field(alias="serviceId", default="") node_id: str = Field(alias="nodeId", default="") name: str = Field(default="") - call_id: str = Field(alias="callId", default="Empty") + call_id: str = Field(alias="callId", default=SpecialCallType.EMPTY.value) description: str = Field(default="") enable: bool = Field(default=True) parameters: dict[str, Any] = Field(default={}) @@ -82,6 +83,6 @@ class FlowItem(BaseModel): nodes: list[NodeItem] = Field(default=[]) edges: list[EdgeItem] = Field(default=[]) created_at: float | None = Field(alias="createdAt", default=0) - connectivity: bool = Field(default=False,description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") + connectivity: bool = Field(default=False, description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem()) debug: bool = Field(default=False) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 60c8f17b4adc4f53f21b0cacc02a86e09b495d06..cecd4a1029eca66ae45da68d32efa839f76e9b62 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,7 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" -import uuid from enum import Enum from typing import Any @@ -11,8 +10,9 @@ from pydantic import BaseModel, Field class MCPInstallStatus(str, Enum): """MCP 服务状态""" - + INIT = "init" INSTALLING = "installing" + CANCELLED = "cancelled" READY = "ready" FAILED = "failed" @@ -23,6 +23,7 @@ class MCPStatus(str, Enum): UNINITIALIZED = "uninitialized" RUNNING = "running" STOPPED = "stopped" + ERROR = "error" class MCPType(str, Enum): @@ -36,23 +37,25 @@ class MCPType(str, Enum): class MCPBasicConfig(BaseModel): """MCP 基本配置""" - env: dict[str, str] = Field(description="MCP 服务器环境变量", default={}) auto_approve: list[str] = Field(description="自动批准的MCP工具ID列表", default=[], alias="autoApprove") disabled: bool = Field(description="MCP 服务器是否禁用", default=False) - auto_install: bool = Field(description="是否自动安装MCP服务器", default=True, alias="autoInstall") + auto_install: bool = Field(description="是否自动安装MCP服务器", default=True) + timeout: int = Field(description="MCP 服务器超时时间(秒)", default=60, alias="timeout") + description: str = Field(description="MCP 服务器自然语言描述", default="") class MCPServerStdioConfig(MCPBasicConfig): """MCP 服务器配置""" + env: dict[str, Any] = Field(description="MCP 服务器环境变量", default={}) command: str = Field(description="MCP 服务器命令") args: list[str] = Field(description="MCP 服务器命令参数") class MCPServerSSEConfig(MCPBasicConfig): """MCP 服务器配置""" - - url: str = Field(description="MCP 服务器地址", default="") + headers: dict[str, str] = Field(description="MCP 服务器请求头", default={}) + url: str = Field(description="MCP 服务器地址", default="http://example.com/sse", pattern=r"^https?://.*$") class MCPServerConfig(BaseModel): @@ -85,7 +88,7 @@ class MCPCollection(BaseModel): type: MCPType = Field(description="MCP 类型", default=MCPType.SSE) activated: list[str] = Field(description="激活该MCP的用户ID列表", default=[]) tools: list[MCPTool] = Field(description="MCP工具列表", default=[]) - status: MCPInstallStatus = 
Field(description="MCP服务状态", default=MCPInstallStatus.INSTALLING) + status: MCPInstallStatus = Field(description="MCP服务状态", default=MCPInstallStatus.INIT) author: str = Field(description="MCP作者", default="") @@ -104,6 +107,74 @@ class MCPToolVector(LanceModel): embedding: Vector(dim=1024) = Field(description="MCP工具描述的向量信息") # type: ignore[call-arg] +class GoalEvaluationResult(BaseModel): + """MCP 目标评估结果""" + + can_complete: bool = Field(description="是否可以完成目标") + reason: str = Field(description="评估原因") + + +class Risk(str, Enum): + """MCP工具风险类型""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class FlowName(BaseModel): + """MCP 流程名称""" + + flow_name: str = Field(description="MCP 流程名称", default="") + + +class FlowRisk(BaseModel): + """MCP 流程风险评估结果""" + + risk: Risk = Field(description="风险类型", default=Risk.LOW) + reason: str = Field(description="风险原因", default="") + + +class RestartStepIndex(BaseModel): + """MCP重新规划的步骤索引""" + + start_index: int = Field(description="重新规划的起始步骤索引") + reasoning: str = Field(description="重新规划的原因") + + +class ToolSkip(BaseModel): + """MCP工具跳过执行结果""" + + skip: bool = Field(description="是否跳过当前步骤", default=False) + + +class ToolRisk(BaseModel): + """MCP工具风险评估结果""" + + risk: Risk = Field(description="风险类型", default=Risk.LOW) + reason: str = Field(description="风险原因", default="") + + +class ErrorType(str, Enum): + """MCP工具错误类型""" + + MISSING_PARAM = "missing_param" + DECORRECT_PLAN = "decorrect_plan" + + +class ToolExcutionErrorType(BaseModel): + """MCP工具执行错误""" + + type: ErrorType = Field(description="错误类型", default=ErrorType.MISSING_PARAM) + reason: str = Field(description="错误原因", default="") + + +class IsParamError(BaseModel): + """MCP工具参数错误""" + + is_param_error: bool = Field(description="是否是参数错误", default=False) + + class MCPSelectResult(BaseModel): """MCP选择结果""" @@ -116,9 +187,15 @@ class MCPToolSelectResult(BaseModel): name: str = Field(description="工具名称") +class MCPToolIdsSelectResult(BaseModel): + """MCP工具ID选择结果""" + + tool_ids: list[str] = Field(description="工具ID列表") + + class MCPPlanItem(BaseModel): """MCP 计划""" - id: str = Field(default_factory=lambda: str(uuid.uuid4())) + step_id: str = Field(description="步骤的ID", default="") content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") @@ -127,4 +204,11 @@ class MCPPlanItem(BaseModel): class MCPPlan(BaseModel): """MCP 计划""" - plans: list[MCPPlanItem] = Field(description="计划列表") + plans: list[MCPPlanItem] = Field(description="计划列表", default=[]) + + +class Step(BaseModel): + """MCP步骤""" + + tool_id: str = Field(description="工具ID") + description: str = Field(description="步骤描述,15个字以下") \ No newline at end of file diff --git a/apps/schemas/message.py b/apps/schemas/message.py index cf70a82b8573d2e3e34564b8f0ec5280195980d9..28d2a92d4f8832f5a1f205fb7a857e7ea400efbc 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -1,14 +1,21 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""队列中的消息结构""" -from typing import Any from datetime import UTC, datetime +from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import EventType, StepStatus +from apps.schemas.enum_var import EventType, FlowStatus, StepStatus from apps.schemas.record import RecordMetadata +class FlowParams(BaseModel): + """流执行过程中的参数补充""" + + content: dict[str, Any] = Field(default={}, description="流执行过程中的参数补充内容") + description: str = Field(default="", description="流执行过程中的参数补充描述") + + class HeartbeatData(BaseModel): """心跳事件的数据结构""" @@ -22,10 +29,17 @@ class MessageFlow(BaseModel): app_id: str = Field(description="插件ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") + flow_name: str = Field(description="Flow名称", alias="flowName") + flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN) step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None) + step_description: str | None = Field( + description="当前步骤描述", + alias="stepDescription", + default=None, + ) step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") @@ -76,8 +90,7 @@ class FlowStartContent(BaseModel): """flow.start消息的content""" question: str = Field(description="用户问题") - params: dict[str, Any] = Field(description="预先提供的参数") - + params: dict[str, Any] | None = Field(description="预先提供的参数", default=None) class MessageBase(HeartbeatData): """基础消息事件结构""" @@ -87,5 +100,5 @@ class MessageBase(HeartbeatData): conversation_id: str = Field(min_length=36, max_length=36, alias="conversationId") task_id: str = Field(min_length=36, max_length=36, alias="taskId") flow: MessageFlow | None = None - content: dict[str, Any] = {} + content: Any | None = Field(default=None, description="消息内容") metadata: MessageMetadata diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py index 009c8206d967bd7cdb61ba1a41ebf6cf9e87dd68..2d56d60855ba1bcc7209a42642a1ddd3f83d8d81 100644 --- a/apps/schemas/pool.py +++ b/apps/schemas/pool.py @@ -107,4 +107,7 @@ class AppPool(BaseData): permission: Permission = Field(description="应用权限配置", default=Permission()) flows: list[AppFlow] = Field(description="Flow列表", default=[]) hashes: dict[str, str] = Field(description="关联文件的hash值", default={}) - mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") + mcp_service: list[str] = Field(default=[], description="MCP服务id列表") + llm_id: str = Field( + default="empty", description="应用使用的大模型ID(如果有的话)" + ) diff --git a/apps/schemas/rag_data.py b/apps/schemas/rag_data.py index ed097c903bf349ad43176df4762b96ac123b8539..a26f3680752e3fc43be3c0149a98fd3219e71e6b 100644 --- a/apps/schemas/rag_data.py +++ b/apps/schemas/rag_data.py @@ -14,7 +14,7 @@ class RAGQueryReq(BaseModel): top_k: int = Field(default=5, description="返回的结果数量", alias="topK") doc_ids: list[str] | None = Field(default=None, description="文档id", alias="docIds") search_method: str = Field(default="dynamic_weighted_keyword_and_vector", - description="检索方法", alias="searchMethod") + description="检索方法", alias="searchMethod") is_related_surrounding: bool = Field(default=True, description="是否关联上下文", alias="isRelatedSurrounding") is_classify_by_doc: bool = Field(default=True, description="是否按文档分类", alias="isClassifyByDoc") is_rerank: bool = 
Field(default=False, description="是否重新排序", alias="isRerank")
diff --git a/apps/schemas/record.py b/apps/schemas/record.py
index b5e1b0c55ad60d0569b57377a6a06a84e7a2920e..7f0c79a9ef7a2d2457d8c789f60081a2b150a0dd 100644
--- a/apps/schemas/record.py
+++ b/apps/schemas/record.py
@@ -10,7 +10,7 @@ from pydantic import BaseModel, Field
 from apps.schemas.collection import (
     Document,
 )
-from apps.schemas.enum_var import CommentType, StepStatus
+from apps.schemas.enum_var import CommentType, FlowStatus, StepStatus


 class RecordDocument(Document):
@@ -33,9 +33,11 @@ class RecordFlowStep(BaseModel):
     """Record表子项:flow的单步数据结构"""

     step_id: str = Field(alias="stepId")
+    step_name: str = Field(alias="stepName", default="")
     step_status: StepStatus = Field(alias="stepStatus")
     input: dict[str, Any]
     output: dict[str, Any]
+    ex_data: dict[str, Any] | None = Field(default=None, alias="exData")


 class RecordFlow(BaseModel):
@@ -44,6 +46,8 @@
     id: str
     record_id: str = Field(alias="recordId")
     flow_id: str = Field(alias="flowId")
+    flow_name: str = Field(alias="flowName", default="")
+    flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS)
     step_num: int = Field(alias="stepNum")
     steps: list[RecordFlowStep]

@@ -92,7 +96,7 @@ class RecordData(BaseModel):
     id: str
     group_id: str = Field(alias="groupId")
     conversation_id: str = Field(alias="conversationId")
-    task_id: str = Field(alias="taskId")
+    task_id: str | None = Field(default=None, alias="taskId")
     document: list[RecordDocument] = []
     flow: RecordFlow | None = None
     content: RecordContent
@@ -115,14 +119,24 @@ class RecordGroupDocument(BaseModel):
     created_at: float = Field(default=0.0, description="文档创建时间")


+class FlowHistory(BaseModel):
+    """Flow执行历史"""
+    flow_id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
+    flow_name: str = Field(default="", description="Flow名称")
+    flow_status: FlowStatus = Field(default=FlowStatus.SUCCESS, description="Flow执行状态")
+    history_ids: list[str] = Field(default=[], description="Flow执行历史ID列表")
+
+
 class Record(RecordData):
     """问答,用于保存在MongoDB中"""

     user_sub: str
     key: dict[str, Any] = {}
-    content: str
+    task_id: str | None = Field(default=None, description="任务ID")
+    content: str = Field(default="", description="Record内容,已加密")
     comment: RecordComment = Field(default=RecordComment())
-    flow: list[str] = Field(default=[])
+    flow: FlowHistory = Field(
+        default=FlowHistory(), description="Flow执行历史信息")


 class RecordGroup(BaseModel):
@@ -139,5 +153,4 @@
     records: list[Record] = []
     docs: list[RecordGroupDocument] = []  # 问题不变,所用到的文档不变
     conversation_id: str
-    task_id: str
     created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py
index a3a8848c32e898364df671bbf3b26cb40c3f6a0a..2cf85ecc9650da4ce880e1948facf35ff740b3dd 100644
--- a/apps/schemas/request_data.py
+++ b/apps/schemas/request_data.py
@@ -7,9 +7,10 @@ from pydantic import BaseModel, Field
 from apps.common.config import Config
 from apps.schemas.appcenter import AppData
-from apps.schemas.enum_var import CommentType
+from apps.schemas.enum_var import CommentType, LanguageType
 from apps.schemas.flow_topology import FlowItem
 from apps.schemas.mcp import MCPType
+from apps.schemas.message import FlowParams


 class RequestDataApp(BaseModel):
     app_id: str = Field(description="应用ID", alias="appId")
     flow_id: str | None = Field(default=None,
description="Flow ID", alias="flowId") - params: dict[str, Any] | None = Field(default=None, description="插件参数") class MockRequestData(BaseModel): @@ -39,14 +39,15 @@ class RequestDataFeatures(BaseModel): class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" - question: str = Field(max_length=2000, description="用户输入") - conversation_id: str = Field(default="", alias="conversationId", description="聊天ID") + question: str | None = Field(default=None, max_length=2000, description="用户输入") + conversation_id: str | None = Field(default=None, alias="conversationId", description="聊天ID") group_id: str | None = Field(default=None, alias="groupId", description="问答组ID") - language: str = Field(default="zh", description="语言") + language: LanguageType = Field(default=LanguageType.CHINESE, description="语言") files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") - new_task: bool = Field(default=True, description="是否新建任务") + task_id: str | None = Field(default=None, alias="taskId", description="任务ID") + params: FlowParams | bool | None = Field(default=None, description="流执行过程中的参数补充", alias="params") class QuestionBlacklistRequest(BaseModel): @@ -99,7 +100,7 @@ class UpdateMCPServiceRequest(BaseModel): name: str = Field(..., description="MCP服务名称") description: str = Field(..., description="MCP服务描述") overview: str = Field(..., description="MCP服务概述") - config: str = Field(..., description="MCP服务配置") + config: dict[str, Any] = Field(..., description="MCP服务配置") mcp_type: MCPType = Field(description="MCP传输协议(Stdio/SSE/Streamable)", default=MCPType.STDIO, alias="mcpType") @@ -107,6 +108,7 @@ class ActiveMCPServiceRequest(BaseModel): """POST /api/mcp/{serviceId} 请求数据结构""" active: bool = Field(description="是否激活mcp服务") + mcp_env: dict[str, Any] | None = Field(default=None, description="MCP服务环境变量", alias="mcpEnv") class UpdateServiceRequest(BaseModel): @@ -184,3 +186,9 @@ class UpdateKbReq(BaseModel): """更新知识库请求体""" kb_ids: list[str] = Field(description="知识库ID列表", alias="kbIds", default=[]) + + +class UserUpdateRequest(BaseModel): + """更新用户信息请求体""" + + auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") \ No newline at end of file diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index eac430d11d48f1d0f9197706999af30aaf85d399..5db37ef881a4ee6ae747f4a9c60169a807852ead 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -26,6 +26,7 @@ from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType from apps.schemas.record import RecordData from apps.schemas.user import UserInfo from apps.templates.generate_llm_operator_config import llm_provider_dict +from apps.common.config import Config class ResponseData(BaseModel): @@ -55,6 +56,7 @@ class AuthUserMsg(BaseModel): user_sub: str revision: bool is_admin: bool + auto_execute: bool class AuthUserRsp(ResponseData): @@ -98,7 +100,7 @@ class LLMIteam(BaseModel): icon: str = Field(default=llm_provider_dict["ollama"]["icon"]) llm_id: str = Field(alias="llmId", default="empty") - model_name: str = Field(alias="modelName", default="Ollama LLM") + model_name: str = Field(alias="modelName", default=Config().get_config().llm.model) class KbIteam(BaseModel): @@ -108,6 +110,14 @@ class KbIteam(BaseModel): kb_name: str = Field(alias="kbName") +class AppMcpServiceInfo(BaseModel): + """应用关联的MCP服务信息""" + + id: str = Field(..., description="MCP服务ID") + name: str = 
Field(default="", description="MCP服务名称") + description: str = Field(default="", description="MCP服务简介") + + class ConversationListItem(BaseModel): """GET /api/conversation Result数据结构""" @@ -285,6 +295,8 @@ class GetAppPropertyMsg(AppData): app_id: str = Field(..., alias="appId", description="应用ID") published: bool = Field(..., description="是否已发布") + mcp_service: list[AppMcpServiceInfo] = Field(default=[], alias="mcpService", description="MCP服务信息列表") + llm: LLMIteam | None = Field(alias="llm", default=None) class GetAppPropertyRsp(ResponseData): @@ -503,6 +515,10 @@ class GetMCPServiceDetailMsg(BaseModel): name: str = Field(..., description="MCP服务名称") description: str = Field(description="MCP服务描述") overview: str = Field(description="MCP服务概述") + status: MCPInstallStatus = Field( + description="MCP服务状态", + default=MCPInstallStatus.INIT, + ) tools: list[MCPTool] = Field(description="MCP服务Tools列表", default=[]) @@ -514,7 +530,7 @@ class EditMCPServiceMsg(BaseModel): name: str = Field(..., description="MCP服务名称") description: str = Field(description="MCP服务描述") overview: str = Field(description="MCP服务概述") - data: str = Field(description="MCP服务配置") + data: dict[str, Any] = Field(description="MCP服务配置") mcp_type: MCPType = Field(alias="mcpType", description="MCP 类型") @@ -582,6 +598,7 @@ class FlowStructureDeleteRsp(ResponseData): class UserGetMsp(BaseModel): """GET /api/user result""" + total: int = Field(default=0) user_info_list: list[UserInfo] = Field(alias="userInfoList", default=[]) diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 3a37126053e693226da63a07283832c9cf3c0fae..1b1c6540a56c8959f42fc65baf55c1c5869dd068 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -7,7 +7,7 @@ from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import StepStatus +from apps.schemas.enum_var import FlowStatus, StepStatus, LanguageType from apps.schemas.flow import Step from apps.schemas.mcp import MCPPlan @@ -23,12 +23,14 @@ class FlowStepHistory(BaseModel): task_id: str = Field(description="任务ID") flow_id: str = Field(description="FlowID") flow_name: str = Field(description="Flow名称") + flow_status: FlowStatus = Field(description="Flow状态") step_id: str = Field(description="当前步骤名称") step_name: str = Field(description="当前步骤名称") - step_description: str = Field(description="当前步骤描述") - status: StepStatus = Field(description="当前步骤状态") + step_description: str = Field(description="当前步骤描述", default="") + step_status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) + ex_data: dict[str, Any] | None = Field(description="额外数据", default=None) created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) @@ -36,17 +38,22 @@ class ExecutorState(BaseModel): """FlowExecutor状态""" # 执行器级数据 - flow_id: str = Field(description="Flow ID") - flow_name: str = Field(description="Flow名称") - description: str = Field(description="Flow描述") - status: StepStatus = Field(description="Flow执行状态") - # 附加信息 - step_id: str = Field(description="当前步骤ID") - step_name: str = Field(description="当前步骤名称") + flow_id: str = Field(description="Flow ID", default="") + flow_name: str = Field(description="Flow名称", default="") + description: str = Field(description="Flow描述", default="") + flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT) + # 任务级数据 + step_cnt: int = Field(description="当前步骤数量", default=0) + 
step_id: str = Field(description="当前步骤ID", default="")
+    tool_id: str = Field(description="当前工具ID", default="")
+    step_name: str = Field(description="当前步骤名称", default="")
+    step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN)
     step_description: str = Field(description="当前步骤描述", default="")
-    app_id: str = Field(description="应用ID")
-    slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
-    error_info: dict[str, Any] = Field(description="错误信息", default={})
+    app_id: str = Field(description="应用ID", default="")
+    current_input: dict[str, Any] = Field(description="当前输入数据", default={})
+    error_message: str = Field(description="错误信息", default="")
+    error_info: dict[str, Any] | None = Field(description="详细错误信息", default=None)
+    retry_times: int = Field(description="当前步骤重试次数", default=0)


 class TaskIds(BaseModel):
@@ -57,6 +64,7 @@
     conversation_id: str = Field(description="对话ID")
     record_id: str = Field(description="记录ID", default_factory=lambda: str(uuid.uuid4()))
     user_sub: str = Field(description="用户ID")
+    active_id: str = Field(description="活动ID", default_factory=lambda: str(uuid.uuid4()))


 class TaskTokens(BaseModel):
@@ -66,6 +74,7 @@
     output_tokens: int = Field(description="输出Token", default=0)
     time: float = Field(description="时间点", default=0.0)
     full_time: float = Field(description="完整时间成本", default=0.0)
+    documents: list[dict[str, Any]] = Field(description="文档列表", default=[])


 class TaskRuntime(BaseModel):
@@ -90,10 +99,11 @@
     id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
     ids: TaskIds = Field(description="任务涉及的各种ID")
     context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[])
-    state: ExecutorState | None = Field(description="Flow的状态", default=None)
+    state: ExecutorState = Field(description="Flow的状态", default=ExecutorState())
     tokens: TaskTokens = Field(description="Token信息")
     runtime: TaskRuntime = Field(description="任务运行时数据")
     created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
+    language: LanguageType = Field(description="语言", default=LanguageType.CHINESE)


 class StepQueueItem(BaseModel):
diff --git a/apps/schemas/user.py b/apps/schemas/user.py
index 61aa2587b8ae6255dc3885b04a96f12310f70773..ebea66924c062f4e680ad2e414768b307b056f4a 100644
--- a/apps/schemas/user.py
+++ b/apps/schemas/user.py
@@ -9,3 +9,4 @@ class UserInfo(BaseModel):

     user_sub: str = Field(alias="userSub", default="")
     user_name: str = Field(alias="userName", default="")
+    auto_execute: bool | None = Field(alias="autoExecute", default=False)
diff --git a/apps/services/activity.py b/apps/services/activity.py
index 299a49a640664b29a3e4724c5842f12bd289e3bf..7ba20cbb4300a5fa5cb60e37cdb887f38d54232f 100644
--- a/apps/services/activity.py
+++ b/apps/services/activity.py
@@ -3,70 +3,65 @@
 import uuid
 from datetime import UTC, datetime
-
+import logging
 from apps.common.mongo import MongoDB
 from apps.constants import SLIDE_WINDOW_QUESTION_COUNT, SLIDE_WINDOW_TIME
 from apps.exceptions import ActivityError

+logger = logging.getLogger(__name__)
+

 class Activity:
-    """用户活动控制,限制单用户同一时间只能提问一个问题"""
+    """用户活动控制,限制单用户在滑动窗口期内最多发起SLIDE_WINDOW_QUESTION_COUNT个请求"""

     @staticmethod
-    async def is_active(user_sub: str) -> bool:
+    async def is_active(active_id: str) -> bool:
         """
         判断当前用户是否正在提问(占用GPU资源)

-        :param user_sub: 用户实体ID
+        :param active_id: 活动ID
         :return: 判断结果,正在提问则返回True
         """
-        time = round(datetime.now(UTC).timestamp(), 3)
-
-        # 检查窗口内总请求数
-        count = await MongoDB().get_collection("activity").count_documents(
-            {"timestamp": {"$gte": time - SLIDE_WINDOW_TIME, "$lte": time}},
-        )
-        if count >= SLIDE_WINDOW_QUESTION_COUNT:
-            return True

         # 检查用户是否正在提问
         active = await MongoDB().get_collection("activity").find_one(
-            {"user_sub": user_sub},
+            {"_id": active_id},
         )
         return bool(active)

     @staticmethod
-    async def set_active(user_sub: str) -> None:
+    async def set_active(user_sub: str) -> str:
         """设置用户的活跃标识"""
         time = round(datetime.now(UTC).timestamp(), 3)

         # 设置用户活跃状态
         collection = MongoDB().get_collection("activity")
-        active = await collection.find_one({"user_sub": user_sub})
-        if active:
-            err = "用户正在提问"
+        # 查看用户活跃标识是否在滑动窗口内
+
+        if await collection.count_documents({"user_sub": user_sub, "timestamp": {"$gt": time - SLIDE_WINDOW_TIME}}) >= SLIDE_WINDOW_QUESTION_COUNT:
+            err = "[Activity] 用户在滑动窗口内提问次数超过限制,请稍后再试。"
             raise ActivityError(err)
+        await collection.delete_many(
+            {"user_sub": user_sub, "timestamp": {"$lte": time - SLIDE_WINDOW_TIME}},
+        )
+        # 插入新的活跃记录
+        tmp_record = {
+            "_id": str(uuid.uuid4()),
+            "user_sub": user_sub,
+            "timestamp": time,
+        }
         await collection.insert_one(
-            {
-                "_id": str(uuid.uuid4()),
-                "user_sub": user_sub,
-                "timestamp": time,
-            },
+            tmp_record
         )
+        return tmp_record["_id"]

     @staticmethod
-    async def remove_active(user_sub: str) -> None:
+    async def remove_active(active_id: str) -> None:
         """
         清除用户的活跃标识,释放GPU资源

-        :param user_sub: 用户实体ID
+        :param active_id: 活动ID
         """
-        time = round(datetime.now(UTC).timestamp(), 3)

         # 清除用户当前活动标识
         await MongoDB().get_collection("activity").delete_one(
-            {"user_sub": user_sub},
-        )
-
-        # 清除超出窗口范围的请求记录
-        await MongoDB().get_collection("activity").delete_many(
-            {"timestamp": {"$lte": time - SLIDE_WINDOW_TIME}},
+            {"_id": active_id},
         )
diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py
index e256ab55ab13bf8c5b9e40cb36a3c616a1d39596..3e20f12721a34dbda2b58fc72d3b380c4da3a4aa 100644
--- a/apps/services/appcenter.py
+++ b/apps/services/appcenter.py
@@ -59,7 +59,6 @@ class AppCenterManager:
         }
         user_favorite_app_ids = await AppCenterManager._get_favorite_app_ids_by_user(user_sub)
-
         if filter_type == AppFilterType.ALL:
             # 获取所有已发布的应用
             filters["published"] = True
@@ -72,7 +71,6 @@
                 "_id": {"$in": user_favorite_app_ids},
                 "published": True,
             }
-
         # 添加关键字搜索条件
         if keyword:
             filters["$or"] = [
@@ -84,7 +82,6 @@
         # 添加应用类型过滤条件
         if app_type is not None:
             filters["app_type"] = app_type.value
-
         # 获取应用列表
         apps, total_apps = await AppCenterManager._search_apps_by_filter(filters, page, SERVICE_PAGE_SIZE)

@@ -420,7 +417,7 @@
         )

     @staticmethod
-    def _create_flow_metadata(
+    async def _create_flow_metadata(
         common_params: dict,
         data: AppData | None = None,
         app_data: AppPool | None = None,
@@ -461,7 +458,7 @@
         return metadata

     @staticmethod
-    def _create_agent_metadata(
+    async def _create_agent_metadata(
         common_params: dict,
         user_sub: str,
         data: AppData | None = None,
@@ -474,7 +471,12 @@
         # mcp_service 逻辑
         if data is not None and hasattr(data, "mcp_service") and data.mcp_service:
             # 创建应用场景,验证传入的 mcp_service 状态,确保只使用已经激活的 (create_app)
-            metadata.mcp_service = [svc for svc in data.mcp_service if MCPServiceManager.is_active(user_sub, svc)]
+            activated_mcp_ids = []
+            for svc in data.mcp_service:
+                is_activated = await MCPServiceManager.is_active(user_sub, svc)
+                if is_activated:
+                    activated_mcp_ids.append(svc)
+            metadata.mcp_service = activated_mcp_ids
         elif data is not None and hasattr(data, "mcp_service"):
             # 更新应用场景,使用 data 中的 mcp_service
(update_app) metadata.mcp_service = data.mcp_service if data.mcp_service is not None else [] @@ -484,7 +486,16 @@ class AppCenterManager: else: # 在预期的条件下,如果在 data 或 app_data 中找不到 mcp_service,则默认回退为空列表。 metadata.mcp_service = [] - + # 处理llm_id字段 + if data is not None and hasattr(data, "llm"): + # 创建应用场景,验证传入的 llm_id 状态 (create_app) + metadata.llm_id = data.llm if data.llm else "empty" + elif app_data is not None and hasattr(app_data, "llm_id"): + # 更新应用发布状态场景,使用 app_data 中的 llm_id (update_app_publish_status) + metadata.llm_id = app_data.llm_id if app_data.llm_id else "empty" + else: + # 在预期的条件下,如果在 data 或 app_data 中找不到 llm_id,则默认回退为 "empty"。 + metadata.llm_id = "empty" # Agent 应用的发布状态逻辑 if published is not None: # 从 update_app_publish_status 调用,'published' 参数已提供 metadata.published = published @@ -548,10 +559,10 @@ class AppCenterManager: # 根据应用类型创建不同的元数据 if app_type == AppType.FLOW: - return AppCenterManager._create_flow_metadata(common_params, data, app_data, published) + return (await AppCenterManager._create_flow_metadata(common_params, data, app_data, published)) if app_type == AppType.AGENT: - return AppCenterManager._create_agent_metadata(common_params, user_sub, data, app_data, published) + return (await AppCenterManager._create_agent_metadata(common_params, user_sub, data, app_data, published)) msg = "无效的应用类型" raise ValueError(msg) diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 7edd3e37cb9f15c6d6b2ac853307ceaeafc5a756..ee15e389995f5859871d08e5ea68ce515786ba10 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -40,6 +40,7 @@ class ConversationManager: @staticmethod async def add_conversation_by_user_sub( + title: str, user_sub: str, app_id: str, llm_id: str, kb_ids: list[str], *, debug: bool) -> Conversation | None: """通过用户ID新建对话""" @@ -75,6 +76,7 @@ class ConversationManager: conversation_id = str(uuid.uuid4()) conv = Conversation( _id=conversation_id, + title=title, user_sub=user_sub, app_id=app_id, llm=llm_item, @@ -183,7 +185,7 @@ class ConversationManager: except Exception as e: logger.error(f"清理对话变量池失败: {e}") - await TaskManager.delete_tasks_by_conversation_id(conversation_id) + await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id) async def _cleanup_transient_file_variables_in_pool(pool, user_sub: str, already_cleaned_files: list[str]) -> None: diff --git a/apps/services/document.py b/apps/services/document.py index 1d53c001913bb53f2cf209b5eccab4deaddd72ee..91726101525a7218075c0e72350b1c421b5caeda 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -190,7 +190,7 @@ class DocumentManager: logger.exception(f"[DocumentManager] 保存文件元数据到MongoDB失败: file_id={file_id}, error={e}") # 尝试清理MinIO中的文件 try: - from apps.common.minio_client import MinioClient + from apps.common.minio import MinioClient MinioClient.remove_object("document", file_id) logger.info(f"已清理MinIO中的文件: {file_id}") except Exception: @@ -227,7 +227,7 @@ class DocumentManager: # 清理MinIO文件 for doc in uploaded_files: try: - from apps.common.minio_client import MinioClient + from apps.common.minio import MinioClient MinioClient.remove_object("document", doc.id) except Exception as cleanup_error: cleanup_errors.append(f"MinIO清理失败 {doc.id}: {cleanup_error}") diff --git a/apps/services/flow.py b/apps/services/flow.py index 2e93d536112b6a673c129c66151d01a98e2bf06e..8c79bc5516330c0946137c9f2634509a35d3e802 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -4,13 +4,14 @@ import logging from typing 
import Any +from pydantic import BaseModel, Field from pymongo import ASCENDING from apps.common.mongo import MongoDB from apps.scheduler.pool.loader.flow import FlowLoader from apps.scheduler.slot.slot import Slot from apps.schemas.collection import User -from apps.schemas.enum_var import EdgeType, PermissionType +from apps.schemas.enum_var import EdgeType, PermissionType, LanguageType from apps.schemas.flow import Edge, Flow, Step from apps.schemas.flow_topology import ( EdgeItem, @@ -20,7 +21,10 @@ from apps.schemas.flow_topology import ( NodeServiceItem, PositionItem, ) +from apps.scheduler.pool.pool import Pool from apps.services.node import NodeManager +from apps.scheduler.executor.step import StepExecutor + logger = logging.getLogger(__name__) @@ -69,7 +73,9 @@ class FlowManager: return (result > 0) @staticmethod - async def get_node_id_by_service_id(service_id: str) -> list[NodeMetaDataItem] | None: + async def get_node_id_by_service_id( + service_id: str, language: LanguageType = LanguageType.CHINESE + ) -> list[NodeMetaDataItem] | None: """ serviceId获取service的接口数据,并将接口转换为节点元数据 @@ -98,12 +104,21 @@ class FlowManager: except Exception: logger.exception("[FlowManager] generate_from_schema 失败") continue + + if service_id == "": + call_class: type[BaseModel] = await Pool().get_call(node_pool_record["_id"]) + node_name = call_class.info(language).name + node_description = call_class.info().description + else: + node_name = node_pool_record["name"] + node_description = node_pool_record["description"] + node_meta_data_item = NodeMetaDataItem( nodeId=node_pool_record["_id"], callId=node_pool_record["call_id"], - name=node_pool_record["name"], + name=node_name, type=node_pool_record["type"], - description=node_pool_record["description"], + description=node_description, editable=True, createdAt=node_pool_record["created_at"], parameters=parameters, # 添加 parametersTemplate 参数 @@ -116,7 +131,9 @@ class FlowManager: return nodes_meta_data_items @staticmethod - async def get_service_by_user_id(user_sub: str) -> list[NodeServiceItem] | None: + async def get_service_by_user_id( + user_sub: str, language: LanguageType = LanguageType.CHINESE + ) -> list[NodeServiceItem] | None: """ 通过user_id获取用户自己上传的、其他人公开的且收藏的、受保护且有权限访问并收藏的service @@ -156,7 +173,14 @@ class FlowManager: sort=[("created_at", ASCENDING)], ) service_records = await service_records_cursor.to_list(length=None) - service_items = [NodeServiceItem(serviceId="", name="系统", type="system", nodeMetaDatas=[])] + service_items = [ + NodeServiceItem( + serviceId="", + name="系统" if language == LanguageType.CHINESE else "System", + type="system", + nodeMetaDatas=[], + ) + ] service_items += [ NodeServiceItem( serviceId=record["_id"], @@ -168,7 +192,9 @@ class FlowManager: for record in service_records ] for service_item in service_items: - node_meta_datas = await FlowManager.get_node_id_by_service_id(service_item.service_id) + node_meta_datas = await FlowManager.get_node_id_by_service_id( + service_item.service_id, language + ) if node_meta_datas is None: node_meta_datas = [] service_item.node_meta_datas = node_meta_datas @@ -264,28 +290,26 @@ class FlowManager: debug=flow_config.debug, ) for node_id, node_config in flow_config.steps.items(): - # TODO 新增标识位区分Node是否允许用户定义output parameters,如果output固定由节点生成,则走else逻辑 - if node_config.type == "Code" or node_config.type == "DirectReply" or node_config.type == "Choice" or node_config.type == "FileExtract": + # 根据Call的controlled_output属性判断是否允许用户定义output parameters + # TODO 两种处理分支应该有办法统一 + try: + call_cls = 
await StepExecutor.get_call_cls(node_config.type) + # 获取controlled_output属性值,默认为False + controlled_output = getattr(call_cls, 'controlled_output', False) + # 如果是类字段而不是实例属性,需要从字段定义中获取默认值 + if hasattr(call_cls, 'model_fields') and 'controlled_output' in call_cls.model_fields: + field_info = call_cls.model_fields['controlled_output'] + controlled_output = getattr(field_info, 'default', False) + except Exception as e: + logger.warning(f"[FlowManager] 获取Call类型 {node_config.type} 失败: {e}") + controlled_output = False + + if controlled_output: parameters = node_config.params # 直接使用保存的完整params - - # 为FileExtract节点确保有默认的text输出参数 - if node_config.type == "FileExtract": - if not parameters: - parameters = {} - if "output_parameters" not in parameters: - parameters["output_parameters"] = {} - if "text" not in parameters["output_parameters"]: - parameters["output_parameters"]["text"] = { - "type": "string", - "description": "文件提取的文本内容" - } else: # 其他节点:使用原有逻辑 input_parameters = node_config.params.get("input_parameters") - if node_config.node not in ("Empty"): - _, output_parameters = await NodeManager.get_node_params(node_config.node) - else: - output_parameters = {} + _, output_parameters = await NodeManager.get_node_params(node_config.node) # 对于循环节点,输出参数已经是扁平化格式,不需要再次处理 if hasattr(node_config, 'type') and node_config.type == "Loop": @@ -883,18 +907,6 @@ class FlowManager: # 参数处理逻辑与主工作流保持一致 if node_config.type == "Code" or node_config.type == "DirectReply" or node_config.type == "Choice" or node_config.type == "FileExtract": parameters = node_config.params # 直接使用保存的完整params - - # 为FileExtract节点确保有默认的text输出参数 - if node_config.type == "FileExtract": - if not parameters: - parameters = {} - if "output_parameters" not in parameters: - parameters["output_parameters"] = {} - if "text" not in parameters["output_parameters"]: - parameters["output_parameters"]["text"] = { - "type": "string", - "description": "文件提取的文本内容" - } else: # 其他节点:使用原有逻辑 input_parameters = node_config.params.get("input_parameters") diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index 483eb226349bef029c5ca64f5e2956b0fda683de..88b6b7894de58fba8532b1c9601b3a4bce411599 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -4,6 +4,7 @@ import collections import logging +from apps.schemas.enum_var import SpecialCallType from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError from apps.schemas.enum_var import NodeType from apps.schemas.flow_topology import EdgeItem, FlowItem, NodeItem @@ -39,14 +40,11 @@ class FlowService: for node in flow_item.nodes: from apps.scheduler.pool.pool import Pool from pydantic import BaseModel - if node.node_id != 'start' and node.node_id != 'end' and node.node_id != 'Empty': + if node.node_id != 'start' and node.node_id != 'end' and node.node_id != SpecialCallType.EMPTY.value: try: - call_class: type[BaseModel] = await Pool().get_call(node.call_id) - if not call_class: - node.node_id = 'Empty' - node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description + await Pool().get_call(node.call_id) except Exception as e: - node.node_id = 'Empty' + node.node_id = SpecialCallType.EMPTY.value node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description logger.error(f"[FlowService] 获取步骤的call_id失败{node.call_id}由于:{e}") node_branch_map[node.step_id] = set() diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 
7cb880c0f08b1cc7727bde528be39ac748f07ae2..f6fec0ebd2f7053d3f8bc72a9d97ba077019528b 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -2,6 +2,7 @@ """MCP服务管理器""" import logging +from logging import config import random import re from typing import Any @@ -28,8 +29,10 @@ from apps.schemas.mcp import ( MCPTool, MCPType, ) +from apps.services.user import UserManager from apps.schemas.request_data import UpdateMCPServiceRequest from apps.schemas.response_data import MCPServiceCardItem +from apps.constants import MCP_PATH logger = logging.getLogger(__name__) sqids = Sqids(min_length=6) @@ -66,10 +69,8 @@ class MCPServiceManager: mcp_list = await mcp_collection.find({"_id": mcp_id}, {"status": True}).to_list(None) for db_item in mcp_list: status = db_item.get("status") - if MCPInstallStatus.READY.value == status: - return MCPInstallStatus.READY - if MCPInstallStatus.INSTALLING.value == status: - return MCPInstallStatus.INSTALLING + if status in MCPInstallStatus.__members__.values(): + return status return MCPInstallStatus.FAILED @staticmethod @@ -78,6 +79,8 @@ class MCPServiceManager: user_sub: str, keyword: str | None, page: int, + is_install: bool | None = None, + is_active: bool | None = None, ) -> list[MCPServiceCardItem]: """ 获取所有MCP服务列表 @@ -89,6 +92,20 @@ class MCPServiceManager: :return: MCP服务列表 """ filters = MCPServiceManager._build_filters(search_type, keyword) + if is_active is not None: + if is_active: + filters["activated"] = {"$in": [user_sub]} + else: + filters["activated"] = {"$nin": [user_sub]} + user_info = await UserManager.get_userinfo_by_user_sub(user_sub) + if not user_info.is_admin: + filters["status"] = MCPInstallStatus.READY.value + else: + if is_install is not None: + if is_install: + filters["status"] = MCPInstallStatus.READY.value + else: + filters["status"] = {"$ne": MCPInstallStatus.READY.value} mcpservice_pools = await MCPServiceManager._search_mcpservice(filters, page) return [ MCPServiceCardItem( @@ -198,7 +215,6 @@ class MCPServiceManager: base_filters = {"author": {"$regex": keyword, "$options": "i"}} return base_filters - @staticmethod async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: """ @@ -209,9 +225,9 @@ class MCPServiceManager: """ # 检查config if data.mcp_type == MCPType.SSE: - config = MCPServerSSEConfig.model_validate_json(data.config) + config = MCPServerSSEConfig.model_validate(data.config) else: - config = MCPServerStdioConfig.model_validate_json(data.config) + config = MCPServerStdioConfig.model_validate(data.config) # 构造Server mcp_server = MCPServerConfig( @@ -233,8 +249,24 @@ class MCPServiceManager: # 保存并载入配置 logger.info("[MCPServiceManager] 创建mcp:%s", mcp_server.name) + mcp_path = MCP_PATH / "template" / mcp_id / "project" + if isinstance(config, MCPServerStdioConfig): + index = None + for i in range(len(config.args)): + if not config.args[i] == "--directory": + continue + index = i + 1 + break + if index is not None: + if index >= len(config.args): + config.args.append(str(mcp_path)) + else: + config.args[index] = str(mcp_path) + else: + config.args += ["--directory", str(mcp_path)] + await MCPLoader._insert_template_db(mcp_id=mcp_id, config=mcp_server) await MCPLoader.save_one(mcp_id, mcp_server) - await MCPLoader.init_one_template(mcp_id=mcp_id, config=mcp_server) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.INIT) return mcp_id @staticmethod @@ -256,21 +288,25 @@ class MCPServiceManager: raise ValueError(msg) db_service = MCPCollection.model_validate(db_service) - 
for user_id in db_service.activated: - await MCPServiceManager.deactive_mcpservice(user_sub=user_id, service_id=data.service_id) - - await MCPLoader.init_one_template(mcp_id=data.service_id, config=MCPServerConfig( + mcp_config = MCPServerConfig( name=data.name, overview=data.overview, description=data.description, - config=MCPServerStdioConfig.model_validate_json( + config=MCPServerStdioConfig.model_validate( data.config, - ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate_json( + ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate( data.config, ), type=data.mcp_type, author=user_sub, - )) + ) + old_mcp_config = await MCPLoader.get_config(data.service_id) + await MCPLoader._insert_template_db(mcp_id=data.service_id, config=mcp_config) + await MCPLoader.save_one(mcp_id=data.service_id, config=mcp_config) + if old_mcp_config.type != mcp_config.type or old_mcp_config.config != mcp_config.config: + for user_id in db_service.activated: + await MCPServiceManager.deactive_mcpservice(user_sub=user_id, service_id=data.service_id) + await MCPLoader.update_template_status(data.service_id, MCPInstallStatus.INIT) # 返回服务ID return data.service_id @@ -297,6 +333,7 @@ class MCPServiceManager: async def active_mcpservice( user_sub: str, service_id: str, + mcp_env: dict[str, Any] | None = None, ) -> None: """ 激活MCP服务 @@ -310,7 +347,7 @@ class MCPServiceManager: for item in status: mcp_status = item.get("status", MCPInstallStatus.INSTALLING) if mcp_status == MCPInstallStatus.READY: - await MCPLoader.user_active_template(user_sub, service_id) + await MCPLoader.user_active_template(user_sub, service_id, mcp_env) else: err = "[MCPServiceManager] MCP服务未准备就绪" raise RuntimeError(err) @@ -365,9 +402,36 @@ class MCPServiceManager: image = image.convert("RGB") image = image.resize((64, 64), resample=Image.Resampling.LANCZOS) # 检查文件夹 - if not await MCP_ICON_PATH.exists(): - await MCP_ICON_PATH.mkdir(parents=True, exist_ok=True) + image_path = MCP_PATH / "template" / service_id / "icon" + if not await image_path.exists(): + await image_path.mkdir(parents=True, exist_ok=True) # 保存 - image.save(MCP_ICON_PATH / f"{service_id}.png", format="PNG", optimize=True, compress_level=9) + image.save(image_path / f"{service_id}.png", format="PNG", optimize=True, compress_level=9) + + return f"{image_path / f'{service_id}.png'}" + + @staticmethod + async def install_mcpservice(user_sub: str, service_id: str, install: bool) -> None: + """ + 安装或卸载MCP服务 - return f"/static/mcp/{service_id}.png" + :param user_sub: str: 用户ID + :param service_id: str: MCP服务ID + :param install: bool: 是否安装 + :return: 无 + """ + service_collection = MongoDB().get_collection("mcp") + db_service = await service_collection.find_one({"_id": service_id, "author": user_sub}) + db_service = MCPCollection.model_validate(db_service) + if install: + if db_service.status == MCPInstallStatus.INSTALLING: + err = "[MCPServiceManager] MCP服务已处于安装中" + raise Exception(err) + await service_collection.update_one( + {"_id": service_id}, + {"$set": {"status": MCPInstallStatus.INSTALLING}}, + ) + mcp_config = await MCPLoader.get_config(service_id) + await MCPLoader.init_one_template(mcp_id=service_id, config=mcp_config) + else: + await MCPLoader.cancel_installing_task([service_id]) diff --git a/apps/services/node.py b/apps/services/node.py index eb03c5f67ce7b6b9932e66b206a27d829fbd5fe1..bd98a41b018d2ee6e59b65092716226f0cbaea16 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -4,6 +4,7 @@ import logging from typing 
import TYPE_CHECKING, Any +from apps.schemas.enum_var import SpecialCallType from apps.common.mongo import MongoDB from apps.schemas.node import APINode from apps.schemas.pool import NodePool @@ -78,6 +79,10 @@ class NodeManager: """获取Node数据""" from apps.scheduler.pool.pool import Pool + if node_id == SpecialCallType.EMPTY.value: + # 如果是空节点,返回空Schema + return {}, {} + # 查找Node信息 logger.info("[NodeManager] 获取节点 %s", node_id) node_collection = MongoDB().get_collection("node") diff --git a/apps/services/rag.py b/apps/services/rag.py index efbdfe94031dd9ac289ac541022ef6d299cd1669..b0db7989e7cb6726060ea113e777ad4bcaf97283 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -6,6 +6,7 @@ import json import logging from collections.abc import AsyncGenerator +import re import httpx from typing import Any from fastapi import status @@ -16,9 +17,8 @@ from apps.llm.reasoning import ReasoningLLM from apps.llm.token import TokenCalculator from apps.schemas.collection import LLM from apps.schemas.config import LLMConfig -from apps.schemas.enum_var import EventType +from apps.schemas.enum_var import EventType, LanguageType from apps.schemas.rag_data import RAGQueryReq -from apps.services.activity import Activity from apps.services.session import SessionManager logger = logging.getLogger(__name__) @@ -29,59 +29,106 @@ class RAG: system_prompt: str = "You are a helpful assistant." """系统提示词""" - user_prompt = """' - - 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后进行脚注。 - 一个例子将在中给出。 - 上下文背景信息将在中给出。 - 用户的提问将在中给出。 - 注意: - 1.输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 - 2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。 - 3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。 - 4.不要对脚注本身进行解释或说明。 - 5.请不要使用中的文档的id作为脚注。 - - + user_prompt: dict[LanguageType, str] = { + LanguageType.CHINESE: r""" + + 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后进行脚注。 + 一个例子将在中给出。 + 上下文背景信息将在中给出。 + 用户的提问将在中给出。 + 注意: + 1.输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 + 2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。 + 3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。 + 4.不要对脚注本身进行解释或说明。 + 5.请不要使用中的文档的id作为脚注。 + + + + + + openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。 + + + openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。 + + + + + openEuler社区的成员来自世界各地,包括开发者、用户和企业。 + + + openEuler社区的成员共同努力,推动开源操作系统的发展,并且为用户提供支持和帮助。 + + + + + openEuler社区的目标是什么? + + + openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。[[1]] + openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]] + + + - - - openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。 - - - openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。 - - - - - openEuler社区的成员来自世界各地,包括开发者、用户和企业。 - - - openEuler社区的成员共同努力,推动开源操作系统的发展,并且为用户提供支持和帮助。 - - + {bac_info} - openEuler社区的目标是什么? + {user_question} - - openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。[[1]] - openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]] - - - - - {bac_info} - - - {user_question} - - """ + """, + LanguageType.ENGLISH: r""" + + You are a helpful assistant of openEuler community. Please answer the user's question based on the given background information and add footnotes after the related sentences. + An example will be given in . + The background information will be given in . + The user's question will be given in . + Note: + 1. Do not include any XML tags in the output, and do not make up any information. If you think the user's question is unrelated to the background information, please ignore the background information and directly answer. + 2. Your response should not exceed 250 words. 
+ + + + + + openEuler community is an open source operating system community, committed to promoting the development of the Linux operating system. + + + openEuler community aims to provide users with a stable, secure, and efficient operating system platform, and support multiple hardware architectures. + + + + + Members of the openEuler community come from all over the world, including developers, users, and enterprises. + + + Members of the openEuler community work together to promote the development of open source operating systems, and provide support and assistance to users. + + + + + What is the goal of openEuler community? + + + openEuler community is an open source operating system community, committed to promoting the development of the Linux operating system. [[1]] + openEuler community aims to provide users with a stable, secure, and efficient operating system platform, and support multiple hardware architectures. [[1]] + + + + + {bac_info} + + + {user_question} + + """, + } @staticmethod - async def get_doc_info_from_rag(user_sub: str, max_tokens: int, - doc_ids: list[str], - data: RAGQueryReq) -> list[dict[str, Any]]: + async def get_doc_info_from_rag( + user_sub: str, max_tokens: int, doc_ids: list[str], data: RAGQueryReq + ) -> list[dict[str, Any]]: """获取RAG服务的文档信息""" session_id = await SessionManager.get_session_by_user_sub(user_sub) url = Config().get_config().rag.rag_service.rstrip("/") + "/chunk/search" @@ -139,20 +186,34 @@ class RAG: doc_cnt += 1 doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] - leave_tokens -= token_calculator.calculate_token_length(messages=[ - {"role": "user", "content": f''''''}, - {"role": "user", "content": ""} + leave_tokens -= token_calculator.calculate_token_length( + messages=[ + { + "role": "user", + "content": f"""""", + }, + {"role": "user", "content": ""}, + ], + pure_text=True, + ) + tokens_of_chunk_element = token_calculator.calculate_token_length( + messages=[ + {"role": "user", "content": ""}, + {"role": "user", "content": ""}, ], - pure_text=True) - tokens_of_chunk_element = token_calculator.calculate_token_length(messages=[ - {"role": "user", "content": ""}, - {"role": "user", "content": ""}, - ], pure_text=True) + pure_text=True, + ) doc_cnt = 0 doc_id_map = {} for doc_chunk in doc_chunk_list: if doc_chunk["docId"] not in doc_id_map: doc_cnt += 1 + t = doc_chunk.get("docCreatedAt", None) + if isinstance(t, str): + t = datetime.strptime(t, '%Y-%m-%d %H:%M') + t = round(t.replace(tzinfo=UTC).timestamp(), 3) + else: + t = round(datetime.now(UTC).timestamp(), 3) doc_info_list.append({ "id": doc_chunk["docId"], "order": doc_cnt, @@ -161,7 +222,7 @@ class RAG: "extension": doc_chunk.get("docExtension", ""), "abstract": doc_chunk.get("docAbstract", ""), "size": doc_chunk.get("docSize", 0), - "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)), + "created_at": t, }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] @@ -191,7 +252,12 @@ class RAG: @staticmethod async def chat_with_llm_base_on_rag( - user_sub: str, llm: LLM, history: list[dict[str, str]], doc_ids: list[str], data: RAGQueryReq + user_sub: str, + llm: LLM, + history: list[dict[str, str]], + doc_ids: list[str], + data: RAGQueryReq, + language: LanguageType = LanguageType.CHINESE, ) -> AsyncGenerator[str, None]: """获取RAG服务的结果""" reasion_llm = ReasoningLLM( @@ -205,7 +271,9 @@ class RAG: if history: try: question_obj = QuestionRewrite() - data.query = await 
question_obj.generate(history=history, question=data.query, llm=reasion_llm) + data.query = await question_obj.generate( + history=history, question=data.query, llm=reasion_llm, language=language + ) except Exception: logger.exception("[RAG] 问题重写失败") doc_chunk_list = await RAG.get_doc_info_from_rag( @@ -220,7 +288,7 @@ class RAG: }, { "role": "user", - "content": RAG.user_prompt.format( + "content": RAG.user_prompt[language].format( bac_info=bac_info, user_question=data.query, ), @@ -245,8 +313,9 @@ class RAG: + "\n\n" ) max_footnote_length = 4 - while doc_cnt > 0: - doc_cnt //= 10 + tmp_doc_cnt = doc_cnt + while tmp_doc_cnt > 0: + tmp_doc_cnt //= 10 max_footnote_length += 1 buffer = "" async for chunk in reasion_llm.call( @@ -257,8 +326,6 @@ class RAG: result_only=False, model=llm.model_name, ): - if not await Activity.is_active(user_sub): - return chunk = buffer + chunk # 防止脚注被截断 if len(chunk) >= 2 and chunk[-2:] != "]]": @@ -270,6 +337,14 @@ class RAG: chunk = chunk[:index + 1] else: buffer = "" + # 匹配脚注 + footnotes = re.findall(r"\[\[\d+\]\]", chunk) + # 去除编号大于doc_cnt的脚注 + footnotes = [fn for fn in footnotes if int(fn[2:-2]) > doc_cnt] + footnotes = list(set(footnotes)) # 去重 + if footnotes: + for fn in footnotes: + chunk = chunk.replace(fn, "") output_tokens += TokenCalculator().calculate_token_length( messages=[ {"role": "assistant", "content": chunk}, @@ -308,4 +383,4 @@ class RAG: ensure_ascii=False, ) + "\n\n" - ) + ) \ No newline at end of file diff --git a/apps/services/record.py b/apps/services/record.py index 5c8a89dfd224166ac822e2eeb47fb79f1fefe7e4..b58df7453c0d02e93de1230560f02f98891e06fe 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -8,16 +8,19 @@ from apps.common.mongo import MongoDB from apps.schemas.record import ( Record, RecordGroup, + FlowHistory, ) - +from apps.schemas.enum_var import FlowStatus logger = logging.getLogger(__name__) + + class RecordManager: """问答对相关操作""" @staticmethod - async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None: + async def create_record_group(group_id: str, user_sub: str, conversation_id: str) -> str | None: """创建问答组""" mongo = MongoDB() record_group_collection = mongo.get_collection("record_group") @@ -26,7 +29,6 @@ class RecordManager: _id=group_id, user_sub=user_sub, conversation_id=conversation_id, - task_id=task_id, ) try: @@ -49,6 +51,10 @@ class RecordManager: mongo = MongoDB() group_collection = mongo.get_collection("record_group") try: + await group_collection.update_one( + {"_id": group_id, "user_sub": user_sub}, + {"$pull": {"records": {"id": record.id}}} + ) await group_collection.update_one( {"_id": group_id, "user_sub": user_sub}, {"$push": {"records": record.model_dump(by_alias=True)}}, @@ -128,11 +134,25 @@ class RecordManager: pipeline.append({"$limit": total_pairs}) records = await record_group_collection.aggregate(pipeline) + return [RecordGroup.model_validate(record) async for record in records] except Exception: logger.exception("[RecordManager] 查询问答组失败") return [] + @staticmethod + async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None: + """更新Record关联的Flow状态""" + record_group_collection = MongoDB().get_collection("record_group") + try: + await record_group_collection.update_many( + {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, + {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, + 
array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], + ) + except Exception: + logger.exception("[RecordManager] 更新Record关联的Flow状态失败") + @staticmethod async def verify_record_in_group(group_id: str, record_id: str, user_sub: str) -> bool: """ @@ -151,7 +171,6 @@ class RecordManager: logger.exception("[RecordManager] 验证记录是否在组中失败") return False - @staticmethod async def check_group_id(group_id: str, user_sub: str) -> bool: """检查group_id是否存在""" diff --git a/apps/services/task.py b/apps/services/task.py index df37dee55bfbf9a809fb58616e5a61059d8a0135..bb10eafd20c056c76d4c7ef9ca6d226c018b11a1 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -83,7 +83,7 @@ class TaskManager: return [] flow_context_list = [] - for flow_context_id in records[0]["records"]["flow"]: + for flow_context_id in records[0]["records"]["flow"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: flow_context_list.append(FlowStepHistory.model_validate(flow_context)) @@ -94,19 +94,18 @@ class TaskManager: return flow_context_list @staticmethod - async def get_context_by_task_id(task_id: str, length: int = 0) -> List[FlowStepHistory]: + async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[FlowStepHistory]: """根据task_id获取flow信息""" flow_context_collection = MongoDB().get_collection("flow_context") flow_context = [] try: - async for history in flow_context_collection.find( - {"task_id": task_id}, - ).sort( - "created_at", -1, - ).limit(length): - # 将字典转换为 FlowStepHistory 对象 - flow_context.append(FlowStepHistory.model_validate(history)) + if length is None: + async for context in flow_context_collection.find({"task_id": task_id}): + flow_context.append(FlowStepHistory.model_validate(context)) + else: + async for context in flow_context_collection.find({"task_id": task_id}).limit(length): + flow_context.append(FlowStepHistory.model_validate(context)) except Exception: logger.exception("[TaskManager] 获取task_id的flow信息失败") return [] @@ -114,23 +113,39 @@ class TaskManager: return flow_context @staticmethod - async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: + async def init_new_task( + user_sub: str, + session_id: str | None = None, + post_body: RequestData | None = None, + ) -> Task: + """获取任务块""" + return Task( + _id=str(uuid.uuid4()), + ids=TaskIds( + user_sub=user_sub if user_sub else "", + session_id=session_id if session_id else "", + conversation_id=post_body.conversation_id, + group_id=post_body.group_id if post_body.group_id else "", + ), + question=post_body.question if post_body else "", + group_id=post_body.group_id if post_body else "", + tokens=TaskTokens(), + runtime=TaskRuntime(), + ) + + @staticmethod + async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB().get_collection("flow_context") try: - for history in flow_context: - # 查找是否存在 - current_context = await flow_context_collection.find_one({ - "task_id": task_id, - "_id": history["_id"], - }) - if current_context: - await flow_context_collection.update_one( - {"_id": current_context["_id"]}, - {"$set": history}, - ) - else: - await flow_context_collection.insert_one(history) + # 删除旧的flow_context + await flow_context_collection.delete_many({"task_id": task_id}) + if not flow_context: + return + await flow_context_collection.insert_many( + [history.model_dump(exclude_none=True, by_alias=True) for history in flow_context], + 
ordered=False, + ) except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") @@ -145,7 +160,26 @@ class TaskManager: await task_collection.delete_one({"_id": task_id}) @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: str) -> None: + async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]: + """通过ConversationID删除Task信息""" + mongo = MongoDB() + task_collection = mongo.get_collection("task") + task_ids = [] + try: + async for task in task_collection.find( + {"conversation_id": conversation_id}, + {"_id": 1}, + ): + task_ids.append(task["_id"]) + if task_ids: + await task_collection.delete_many({"conversation_id": conversation_id}) + return task_ids + except Exception: + logger.exception("[TaskManager] 删除ConversationID的Task信息失败") + return [] + + @staticmethod + async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" mongo = MongoDB() task_collection = mongo.get_collection("task") @@ -162,50 +196,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod - async def get_task( - cls, - task_id: str | None = None, - session_id: str | None = None, - post_body: RequestData | None = None, - user_sub: str | None = None, - ) -> Task: - """获取任务块""" - if task_id: - try: - task = await cls.get_task_by_task_id(task_id) - if task: - return task - except Exception: - logger.exception("[TaskManager] 通过task_id获取任务失败") - - logger.info("[TaskManager] 未提供task_id,通过session_id获取任务") - if not session_id or not post_body: - err = ( - "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。" - ) - raise ValueError(err) - - if post_body.group_id: - task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id) - else: - task = await cls.get_task_by_conversation_id(post_body.conversation_id) - - if task: - return task - return Task( - _id=str(uuid.uuid4()), - ids=TaskIds( - user_sub=user_sub if user_sub else "", - session_id=session_id if session_id else "", - conversation_id=post_body.conversation_id, - group_id=post_body.group_id if post_body.group_id else "", - ), - state=None, - tokens=TaskTokens(), - runtime=TaskRuntime(), - ) - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" diff --git a/apps/services/user.py b/apps/services/user.py index 2721d3773bd43e2a8fda5b371d6336fcf0b9b7f3..1b96df18143f5a87b102f68e2bb4e2df62753f05 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -4,6 +4,7 @@ import logging from datetime import UTC, datetime +from apps.schemas.request_data import UserUpdateRequest from apps.common.mongo import MongoDB from apps.schemas.collection import User from apps.services.conversation import ConversationManager @@ -28,7 +29,7 @@ class UserManager: ).model_dump(by_alias=True)) @staticmethod - async def get_all_user_sub() -> list[str]: + async def get_all_user_sub(page_size: int = 20, page_cnt: int = 1, filter_user_subs: list[str] = []) -> tuple[list[str], int]: """ 获取所有用户的sub @@ -36,7 +37,13 @@ class UserManager: """ mongo = MongoDB() user_collection = mongo.get_collection("user") - return [user["_id"] async for user in user_collection.find({}, {"_id": 1})] + total = await user_collection.count_documents({}) - len(filter_user_subs) + + users = await user_collection.find( + {"_id": {"$nin": filter_user_subs}}, + {"_id": 1}, + 
).skip((page_cnt - 1) * page_size).limit(page_size).to_list(length=page_size)
+        return [user["_id"] for user in users], total
 
     @staticmethod
     async def get_userinfo_by_user_sub(user_sub: str) -> User | None:
@@ -52,7 +59,25 @@ class UserManager:
         return User(**user_data) if user_data else None
 
     @staticmethod
-    async def update_userinfo_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool:
+    async def update_userinfo_by_user_sub(user_sub: str, data: UserUpdateRequest) -> None:
+        """
+        根据用户sub更新用户信息
+
+        :param user_sub: 用户sub
+        :param data: 用户更新信息
+        :return: 无
+        """
+        mongo = MongoDB()
+        user_collection = mongo.get_collection("user")
+        update_dict = {
+            "$set": {
+                "auto_execute": data.auto_execute,
+            }
+        }
+        await user_collection.update_one({"_id": user_sub}, update_dict)
+
+    @staticmethod
+    async def update_refresh_revision_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool:
+        """
+        根据用户sub更新用户信息
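
For reference, the paginated get_all_user_sub introduced above combines a $nin exclusion filter with 1-based skip/limit paging and returns the page together with a total count. The snippet below is a minimal, self-contained sketch of the same pattern written directly against motor rather than the project's MongoDB wrapper; the connection URI, the database name and the helper name list_user_subs_paged are illustrative placeholders and not part of the patch, and the total is computed with the same $nin filter instead of subtracting len(filter_user_subs).

# Illustrative sketch only -- not part of the patch above. It mirrors the
# 1-based skip/limit pagination plus the $nin exclusion filter used by the
# new UserManager.get_all_user_sub, but talks to motor directly.
# Connection URI, database name and helper name are placeholders.
import asyncio

from motor.motor_asyncio import AsyncIOMotorClient


async def list_user_subs_paged(
    page_size: int = 20,
    page_cnt: int = 1,
    filter_user_subs: list[str] | None = None,
) -> tuple[list[str], int]:
    """Return one page of user subs and the total count, excluding the given subs."""
    filter_user_subs = filter_user_subs or []
    collection = AsyncIOMotorClient("mongodb://localhost:27017")["example_db"]["user"]

    query = {"_id": {"$nin": filter_user_subs}}
    # Counting with the same $nin filter avoids the assumption (made by the
    # count_documents({}) - len(filter_user_subs) shortcut in the patch) that
    # every excluded sub actually exists in the collection.
    total = await collection.count_documents(query)

    cursor = (
        collection.find(query, {"_id": 1})
        .skip((page_cnt - 1) * page_size)  # page_cnt is 1-based
        .limit(page_size)
    )
    users = await cursor.to_list(length=page_size)
    return [user["_id"] for user in users], total


if __name__ == "__main__":
    # Usage example; requires a reachable MongoDB at the placeholder URI.
    print(asyncio.run(list_user_subs_paged(page_size=10, page_cnt=1)))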