From 8b237cef009f84cac3383f360a3a33333e52661d Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 27 Aug 2025 16:02:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=95=B4=E7=90=86Task=E5=92=8C=E5=A4=A7?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/token.py | 21 +-- apps/models/task.py | 6 +- apps/models/user.py | 8 +- apps/routers/chat.py | 2 +- apps/scheduler/call/facts/facts.py | 10 +- apps/scheduler/call/graph/graph.py | 5 +- apps/scheduler/call/graph/prompt.py | 101 +++++++++++++ apps/scheduler/call/graph/schema.py | 10 +- apps/scheduler/call/graph/style.py | 164 --------------------- apps/scheduler/call/slot/slot.py | 4 +- apps/scheduler/executor/agent.py | 216 +++++++++++++++------------- apps/scheduler/executor/base.py | 9 +- apps/scheduler/executor/flow.py | 23 ++- apps/scheduler/executor/step.py | 75 ++++++---- apps/scheduler/mcp/host.py | 10 +- apps/scheduler/mcp_agent/base.py | 11 +- apps/scheduler/pool/check.py | 2 +- apps/scheduler/pool/pool.py | 8 +- apps/scheduler/scheduler/flow.py | 25 ---- apps/schemas/task.py | 16 ++- apps/services/conversation.py | 2 +- apps/services/flow_validate.py | 5 +- apps/services/llm.py | 6 +- apps/services/rag.py | 70 ++++----- apps/services/task.py | 143 ++++++++---------- 25 files changed, 433 insertions(+), 519 deletions(-) create mode 100644 apps/scheduler/call/graph/prompt.py delete mode 100644 apps/scheduler/call/graph/style.py diff --git a/apps/llm/token.py b/apps/llm/token.py index b76cd744..55e60c81 100644 --- a/apps/llm/token.py +++ b/apps/llm/token.py @@ -29,26 +29,13 @@ class TokenCalculator(metaclass=SingletonMeta): return result - @staticmethod - def get_k_tokens_words_from_content(content: str, k: int | None = None) -> str: + def get_k_tokens_words_from_content(self, content: str, k: int | None = None) -> str: """获取k个token的词""" if k is None: return content if k <= 0: return "" - if 
TokenCalculator().calculate_token_length(messages=[ - {"role": "user", "content": content}, - ], pure_text=True) <= k: - return content - left = 0 - right = len(content) - while left + 1 < right: - mid = (left + right) // 2 - if TokenCalculator().calculate_token_length(messages=[ - {"role": "user", "content": content[:mid]}, - ], pure_text=True) <= k: - left = mid - else: - right = mid - return content[:left] + encodings = self._encoder.encode(content) + encodings = encodings[:k] + return self._encoder.decode(encodings, errors="ignore") diff --git a/apps/models/task.py b/apps/models/task.py index c21975e9..2e72a1a0 100644 --- a/apps/models/task.py +++ b/apps/models/task.py @@ -23,8 +23,6 @@ class Task(Base): UUID(as_uuid=True), ForeignKey("framework_conversation.id"), nullable=False, ) """对话ID""" - sessionId: Mapped[str] = mapped_column(String(255), ForeignKey("framework_session.id"), nullable=False) # noqa: N815 - """会话ID""" checkpointId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815 UUID(as_uuid=True), ForeignKey("framework_executor_checkpoint.id"), nullable=True, default=None, @@ -56,6 +54,10 @@ class TaskRuntime(Base): """时间""" fullTime: Mapped[float] = mapped_column(Float, default=0.0, nullable=False) # noqa: N815 """完整时间成本""" + sessionId: Mapped[str | None] = mapped_column( # noqa: N815 + String(255), ForeignKey("framework_session.id"), nullable=True, default=None, + ) + """会话ID""" userInput: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815 """用户输入""" fullAnswer: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815 diff --git a/apps/models/user.py b/apps/models/user.py index e18918c9..abd844b3 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -36,8 +36,12 @@ class User(Base): """用户个人令牌""" selectedKB: Mapped[list[uuid.UUID]] = mapped_column(ARRAY(UUID), default=[], nullable=False) # noqa: N815 """用户选择的知识库的ID""" - defaultLLM: Mapped[str | None] = mapped_column(String(255), default=None, 
nullable=True) # noqa: N815 - """用户选择的大模型ID""" + reasoningLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815 + """用户选择的问答模型ID""" + functionLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815 + """用户选择的函数模型ID""" + embeddingLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815 + """用户选择的向量模型ID""" autoExecute: Mapped[bool | None] = mapped_column(Boolean, default=False, nullable=True) # noqa: N815 """Agent是否自动执行""" diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7bba1298..0c5ab7e9 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -57,7 +57,7 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T if not post_body.task_id: err = "[Chat] task_id 不可为空!" raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty") - task = await TaskManager.get_task_by_task_id(post_body.task_id) + task = await TaskManager.get_task_data_by_task_id(post_body.task_id) post_body.app = RequestDataApp(appId=task.state.app_id) post_body.conversation_id = task.ids.conversation_id post_body.language = task.language diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index d317011e..48d2531f 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -88,18 +88,20 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): # 提取事实信息 facts_tpl = env.from_string(FACTS_PROMPT[self._sys_vars.language]) facts_prompt = facts_tpl.render(conversation=data.message) - facts_obj: FactsGen = await self._json([ + facts_obj = await self._json([ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": facts_prompt}, - ], FactsGen) # type: ignore[arg-type] + ], FactsGen.model_json_schema()) + facts_obj = FactsGen.model_validate(facts_obj) # 更新用户画像 domain_tpl = 
env.from_string(DOMAIN_PROMPT[self._sys_vars.language]) domain_prompt = domain_tpl.render(conversation=data.message) - domain_list: DomainGen = await self._json([ + domain_list = await self._json([ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": domain_prompt}, - ], DomainGen) # type: ignore[arg-type] + ], DomainGen.model_json_schema()) + domain_list = DomainGen.model_validate(domain_list) for domain in domain_list.keywords: await UserTagManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain) diff --git a/apps/scheduler/call/graph/graph.py b/apps/scheduler/call/graph/graph.py index 46e55e13..dc5986c0 100644 --- a/apps/scheduler/call/graph/graph.py +++ b/apps/scheduler/call/graph/graph.py @@ -17,8 +17,8 @@ from apps.schemas.scheduler import ( CallVars, ) +from .prompt import GENERATE_STYLE_PROMPT from .schema import RenderFormat, RenderInput, RenderOutput -from .style import RenderStyle class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): @@ -88,10 +88,7 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): self._option_template["dataset"]["source"] = processed_data try: - style_obj = RenderStyle() llm_output = await style_obj.generate(question=data.question) - self.tokens.input_tokens += style_obj.input_tokens - self.tokens.output_tokens += style_obj.output_tokens add_style = llm_output.get("additional_style", "") self._parse_options(column_num, llm_output["chart_type"], add_style, llm_output["scale_type"]) diff --git a/apps/scheduler/call/graph/prompt.py b/apps/scheduler/call/graph/prompt.py new file mode 100644 index 00000000..521a432d --- /dev/null +++ b/apps/scheduler/call/graph/prompt.py @@ -0,0 +1,101 @@ +"""图表相关提示词""" + +from apps.schemas.enum_var import LanguageType + +GENERATE_STYLE_PROMPT: dict[LanguageType, str] = { + LanguageType.CHINESE: r""" + + + 你的目标是:帮助用户在绘制图表时做出样式选择。 + 请以JSON格式输出你的选择。 + + 图表类型: + - `bar`: 柱状图 + - `pie`: 饼图 + - `line`: 折线图 
+ - `scatter`: 散点图 + 柱状图的附加样式: + - `normal`: 普通柱状图 + - `stacked`: 堆叠柱状图 + 饼图的附加样式: + - `normal`: 普通饼图 + - `ring`: 环形饼图 + 可用坐标比例: + - `linear`: 线性比例 + - `log`: 对数比例 + + + + ## 问题 + 查询数据库中的数据,并绘制堆叠柱状图。 + + ## 思考 + 让我们一步步思考。用户要求绘制堆叠柱状图,因此图表类型应为 `bar`,即柱状图;图表样式\ +应为 `stacked`,即堆叠形式。 + + ## 答案 + { + "chart_type": "bar", + "additional_style": "stacked", + "scale_type": "linear" + } + + + + ## 问题 + {question} + + ## 思考 + 让我们一步步思考。 + """, + LanguageType.ENGLISH: r""" + + + Your mission is: help the user make style choices when drawing a chart. + Please output your choices in JSON format. + + Chart types: + - `bar`: Bar chart + - `pie`: Pie chart + - `line`: Line chart + - `scatter`: Scatter chart + + Bar chart additional styles: + - `normal`: Normal bar chart + - `stacked`: Stacked bar chart + + Pie chart additional styles: + - `normal`: Normal pie chart + - `ring`: Ring pie chart + + Axis scaling: + - `linear`: Linear scaling + - `log`: Logarithmic scaling + + + + ## Question + Query the data from the database and draw a stacked bar chart. + + ## Thought + Let's think step by step. The user requires drawing a stacked bar chart, so the chart type \ +should be `bar`, i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form. + + ## Answer + { + "chart_type": "bar", + "additional_style": "stacked", + "scale_type": "linear" + } + + + + ## Question + + {question} + + ## Thought + + Let's think step by step. + """, +} diff --git a/apps/scheduler/call/graph/schema.py b/apps/scheduler/call/graph/schema.py index 0674a4c4..54048fa2 100644 --- a/apps/scheduler/call/graph/schema.py +++ b/apps/scheduler/call/graph/schema.py @@ -1,13 +1,21 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""图表工具的输入输出""" -from typing import Any +from typing import Any, Literal from pydantic import BaseModel, Field from apps.scheduler.call.core import DataBase +class RenderStyleResult(BaseModel): + """选择图表样式结果""" + + chart_type: Literal["bar", "pie", "line", "scatter"] = Field(description="图表类型") + additional_style: Literal["normal", "stacked", "ring"] | None = Field(description="图表样式") + scale_type: Literal["linear", "log"] = Field(description="图表比例") + + class RenderAxis(BaseModel): """ECharts图表的轴配置""" diff --git a/apps/scheduler/call/graph/style.py b/apps/scheduler/call/graph/style.py deleted file mode 100644 index 20f6291c..00000000 --- a/apps/scheduler/call/graph/style.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -"""选择图表样式""" - -import logging -from typing import Any, Literal - -from pydantic import BaseModel, Field - -from apps.llm.function import JsonGenerator -from apps.llm.patterns.core import CorePattern -from apps.llm.reasoning import ReasoningLLM -from apps.schemas.enum_var import LanguageType - -logger = logging.getLogger(__name__) - - -class RenderStyleResult(BaseModel): - """选择图表样式结果""" - - chart_type: Literal["bar", "pie", "line", "scatter"] = Field(description="图表类型") - additional_style: Literal["normal", "stacked", "ring"] | None = Field(description="图表样式") - scale_type: Literal["linear", "log"] = Field(description="图表比例") - - -class RenderStyle(CorePattern): - """选择图表样式""" - - @staticmethod - def _default() -> tuple[dict[LanguageType, str], dict[LanguageType, str]]: - """默认的Prompt内容""" - return { - LanguageType.CHINESE: r"You are a helpful assistant.", - LanguageType.ENGLISH: r"You are a helpful assistant.", - }, { - LanguageType.CHINESE: r""" - - - 你的目标是:帮助用户在绘制图表时做出样式选择。 - 请以JSON格式输出你的选择。 - - 图标类型: - - `bar`: 柱状图 - - `pie`: 饼图 - - `line`: 折线图 - - `scatter`: 散点图 - 柱状图的附加样式: - - `normal`: 普通柱状图 - - `stacked`: 堆叠柱状图 - 饼图的附加样式: - - `normal`: 普通饼图 - - `ring`: 环形饼图 - 可用坐标比例: - - 
`linear`: 线性比例 - - `log`: 对数比例 - - - - ## 问题 - 查询数据库中的数据,并绘制堆叠柱状图。 - - ## 思考 - 让我们一步步思考。用户要求绘制堆叠柱状图,因此图表类型应为 `bar`,即柱状图;图表样式\ -应为 `stacked`,即堆叠形式。 - - ## 答案 - { - "chart_type": "bar", - "additional_style": "stacked", - "scale_type": "linear" - } - - - - ## 问题 - {question} - - ## 思考 - 让我们一步步思考。 - """, - LanguageType.ENGLISH: r""" - - - Your mission is: help the user make style choices when drawing a chart. - Please output your choices in JSON format. - - Chart types: - - `bar`: Bar chart - - `pie`: Pie chart - - `line`: Line chart - - `scatter`: Scatter chart - - Bar chart additional styles: - - `normal`: Normal bar chart - - `stacked`: Stacked bar chart - - Pie chart additional styles: - - `normal`: Normal pie chart - - `ring`: Ring pie chart - - Axis scaling: - - `linear`: Linear scaling - - `log`: Logarithmic scaling - - - - ## Question - Query the data from the database and draw a stacked bar chart. - - ## Thought - Let's think step by step. The user requires drawing a stacked bar chart, so the chart type \ -should be `bar`, i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form. - - ## Answer - { - "chart_type": "bar", - "additional_style": "stacked", - "scale_type": "linear" - } - - - - ## Question - {question} - - ## Thought - Let's think step by step. 
- """, - } - - async def generate(self, **kwargs) -> dict[str, Any]: # noqa: ANN003 - """使用LLM选择图表样式""" - question = kwargs["question"] - language = kwargs.get("language", LanguageType.CHINESE) - - # 使用Reasoning模型进行推理 - messages = [ - {"role": "system", "content": self.system_prompt[language]}, - {"role": "user", "content": self.user_prompt[language].format(question=question)}, - ] - result = "" - llm = ReasoningLLM() - async for chunk in llm.call(messages, streaming=False): - result += chunk - self.input_tokens = llm.input_tokens - self.output_tokens = llm.output_tokens - - messages += [ - {"role": "assistant", "content": result}, - ] - - # 使用FunctionLLM模型进行提取参数 - json_gen = JsonGenerator( - query="根据给定的背景信息,生成预测问题", - conversation=messages, - schema=RenderStyleResult.model_json_schema(), - ) - try: - result_dict = await json_gen.generate() - RenderStyleResult.model_validate(result_dict) - except Exception: - logger.exception("[RenderStyle] 选择图表样式失败") - return {} - - return result_dict diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 3b88fb4f..ddfcaa01 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -72,7 +72,9 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): ] # 使用大模型进行尝试 - answer = await self._llm(messages=conversation, streaming=False) + answer = "" + async for chunk in self._llm(messages=conversation, streaming=True): + answer += chunk answer = await FunctionLLM.process_response(answer) try: data = json.loads(answer) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index d2730c4d..3a63914d 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -31,7 +31,7 @@ class MCPAgentExecutor(BaseExecutor): max_steps: int = Field(default=40, description="最大步数") servers_id: list[str] = Field(description="MCP server id") - agent_id: str = Field(default="", description="Agent ID") + agent_id: uuid.UUID = 
Field(default=uuid.uuid4(), description="App ID作为Agent ID") agent_description: str = Field(default="", description="Agent描述") mcp_list: list[MCPInfo] = Field(description="MCP服务器列表", default=[]) mcp_pool: MCPPool = Field(description="MCP池", default=MCPPool()) @@ -56,15 +56,15 @@ class MCPAgentExecutor(BaseExecutor): async def init(self) -> None: """初始化MCP Agent""" - self.planner = MCPPlanner(self.runtime.userInput, self.resoning_llm, self.runtime.language) - self.host = MCPHost(self.task.userSub, self.resoning_llm) + self.planner = MCPPlanner(self.task.runtime.userInput, self.resoning_llm, self.task.runtime.language) + self.host = MCPHost(self.task.metadata.userSub, self.resoning_llm) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[MCPAgentExecutor] 加载Executor状态") # 尝试恢复State - if self.state and self.state.executorStatus != ExecutorStatus.INIT: - self.context = await TaskManager.get_context_by_task_id(self.task.id) + if self.task.state and self.task.state.executorStatus != ExecutorStatus.INIT: + self.task.context = await TaskManager.get_context_by_task_id(self.task.metadata.id) async def load_mcp(self) -> None: """加载MCP服务器列表""" @@ -74,16 +74,16 @@ class MCPAgentExecutor(BaseExecutor): mcp_ids = app.mcp_service for mcp_id in mcp_ids: mcp_service = await MCPServiceManager.get_mcp_service(mcp_id) - if self.task.userSub not in mcp_service.activated: + if self.task.metadata.userSub not in mcp_service.activated: logger.warning( "[MCPAgentExecutor] 用户 %s 未启用MCP %s", - self.task.userSub, + self.task.metadata.userSub, mcp_id, ) continue self.mcp_list.append(mcp_service) - await self.mcp_pool.init_mcp(mcp_id, self.task.userSub) + await self.mcp_pool.init_mcp(mcp_id, self.task.metadata.userSub) for tool in mcp_service.tools: self.tools[tool.id] = tool self.tool_list.extend(mcp_service.tools) @@ -100,9 +100,9 @@ class MCPAgentExecutor(BaseExecutor): """获取工具输入参数""" if is_first: # 获取第一个输入参数 - mcp_tool = self.tools[self.state.toolId] - 
self.state.currentInput = await MCPHost._get_first_input_params( - mcp_tool, self.runtime.userInput, self.state.stepDescription, self.task, + mcp_tool = self.tools[self.task.state.toolId] + self.task.state.currentInput = await self.host.get_first_input_params( + mcp_tool, self.task.runtime.userInput, self.task, ) else: # 获取后续输入参数 @@ -113,65 +113,74 @@ class MCPAgentExecutor(BaseExecutor): params = {} params_description = "" mcp_tool = self.tools[self.task.state.tool_id] - self.state.currentInput = await MCPHost.fill_params( + self.task.state.currentInput = await self.host.fill_params( mcp_tool, - self.runtime.userInput, - self.state.stepDescription, - self.state.currentInput, - self.state.errorMessage, + self.task.runtime.userInput, + self.task.state.currentInput, + self.task.state.errorMessage, params, params_description, - self.runtime.language, + self.task.runtime.language, ) async def confirm_before_step(self) -> None: """确认前步骤""" + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + # 发送确认消息 mcp_tool = self.tools[self.task.state.tool_id] - confirm_message = await MCPPlanner.get_tool_risk( - mcp_tool, self.task.state.current_input, "", self.resoning_llm, self.task.language + confirm_message = await self.planner.get_tool_risk( + mcp_tool, self.task.state.currentInput, "", self.resoning_llm, self.task.runtime.language, ) await self.update_tokens() await self.push_message( EventType.STEP_WAITING_FOR_START, confirm_message.model_dump(exclude_none=True, by_alias=True), ) await self.push_message(EventType.FLOW_STOP, {}) - self.state.executorStatus = ExecutorStatus.WAITING - self.state.stepStatus = StepStatus.WAITING - self.context.append( + self.task.state.executorStatus = ExecutorStatus.WAITING + self.task.state.stepStatus = StepStatus.WAITING + self.task.context.append( ExecutorHistory( - task_id=self.task.id, - step_id=self.task.state.step_id, - step_name=self.task.state.step_name, - 
step_description=self.task.state.step_description, - step_status=self.task.state.step_status, - flow_id=self.task.state.flow_id, - flow_name=self.task.state.flow_name, - flow_status=self.task.state.flow_status, - input_data={}, - output_data={}, - ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True), - ) + taskId=self.task.metadata.id, + stepId=self.task.state.stepId, + stepName=self.task.state.stepName, + stepDescription=self.task.state.stepDescription, + stepStatus=self.task.state.stepStatus, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, + inputData={}, + outputData={}, + extraData=confirm_message.model_dump(exclude_none=True, by_alias=True), + ), ) async def run_step(self) -> None: """执行步骤""" - self.state.executorStatus = ExecutorStatus.RUNNING - self.state.stepStatus = StepStatus.RUNNING - mcp_tool = self.tools[self.state.toolId] - mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.ids.user_sub)) + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + + self.task.state.executorStatus = ExecutorStatus.RUNNING + self.task.state.stepStatus = StepStatus.RUNNING + mcp_tool = self.tools[self.task.state.toolId] + mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.metadata.userSub)) try: - output_params = await mcp_client.call_tool(mcp_tool.name, self.state.currentInput) + output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.currentInput) except anyio.ClosedResourceError: logger.exception("[MCPAgentExecutor] MCP客户端连接已关闭: %s", mcp_tool.mcp_id) - await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.ids.user_sub) - await self.mcp_pool.init_mcp(mcp_tool.mcp_id, self.task.ids.user_sub) - self.state.stepStatus = StepStatus.ERROR + await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.metadata.userSub) + await self.mcp_pool.init_mcp(mcp_tool.mcp_id, 
self.task.metadata.userSub) + self.task.state.stepStatus = StepStatus.ERROR return except Exception as e: logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误", mcp_tool.name) - self.state.stepStatus = StepStatus.ERROR - self.state.errorMessage = str(e) + self.task.state.stepStatus = StepStatus.ERROR + self.task.state.errorMessage = str(e) return logger.error(f"当前工具名称: {mcp_tool.name}, 输出参数: {output_params}") if output_params.isError: @@ -179,8 +188,8 @@ class MCPAgentExecutor(BaseExecutor): for output in output_params.content: if isinstance(output, TextContent): err += output.text - self.state.stepStatus = StepStatus.ERROR - self.state.errorMessage = { + self.task.state.stepStatus = StepStatus.ERROR + self.task.state.errorMessage = { "err_msg": err, "data": {}, } @@ -194,54 +203,59 @@ class MCPAgentExecutor(BaseExecutor): } await self.update_tokens() - await self.push_message(EventType.STEP_INPUT, self.state.currentInput) + await self.push_message(EventType.STEP_INPUT, self.task.state.currentInput) await self.push_message(EventType.STEP_OUTPUT, output_params) - self.context.append( + self.task.context.append( ExecutorHistory( - taskId=self.task.id, - stepId=self.state.stepId, - stepName=self.state.stepName, - stepDescription=self.state.stepDescription, + taskId=self.task.metadata.id, + stepId=self.task.state.stepId, + stepName=self.task.state.stepName, + stepDescription=self.task.state.stepDescription, stepStatus=StepStatus.SUCCESS, - executorId=self.state.executorId, - executorName=self.state.executorName, - executorStatus=self.state.executorStatus, - inputData=self.state.currentInput, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, + inputData=self.task.state.currentInput, outputData=output_params, ), ) - self.state.stepStatus = StepStatus.SUCCESS + self.task.state.stepStatus = StepStatus.SUCCESS async def generate_params_with_null(self) -> None: """生成参数补充""" - mcp_tool = 
self.tools[self.state.toolId] + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + + mcp_tool = self.tools[self.task.state.toolId] params_with_null = await self.planner.get_missing_param( mcp_tool, - self.state.currentInput, - self.state.errorMessage, + self.task.state.currentInput, + self.task.state.errorMessage, ) await self.update_tokens() error_message = await self.planner.change_err_message_to_description( - error_message=self.state.errorMessage, + error_message=self.task.state.errorMessage, tool=mcp_tool, - input_params=self.state.currentInput, + input_params=self.task.state.currentInput, ) await self.push_message( EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null}, ) await self.push_message(EventType.FLOW_STOP, data={}) - self.state.executorStatus = ExecutorStatus.WAITING - self.state.stepStatus = StepStatus.PARAM - self.context.append( + self.task.state.executorStatus = ExecutorStatus.WAITING + self.task.state.stepStatus = StepStatus.PARAM + self.task.context.append( ExecutorHistory( - taskId=self.task.id, - stepId=self.state.stepId, - stepName=self.state.stepName, - stepDescription=self.state.stepDescription, - stepStatus=self.state.stepStatus, - executorId=self.state.executorId, - executorName=self.state.executorName, - executorStatus=self.state.executorStatus, + taskId=self.task.metadata.id, + stepId=self.task.state.stepId, + stepName=self.task.state.stepName, + stepDescription=self.task.state.stepDescription, + stepStatus=self.task.state.stepStatus, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, inputData={}, outputData={}, extraData={ @@ -253,9 +267,14 @@ class MCPAgentExecutor(BaseExecutor): async def get_next_step(self) -> None: """获取下一步""" + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + if 
self.step_cnt < self.max_steps: self.step_cnt += 1 - history = await MCPHost.assemble_memory(self.task) + history = await self.host.assemble_memory(self.task.runtime, self.task.context) max_retry = 3 step = None for _ in range(max_retry): @@ -273,38 +292,43 @@ class MCPAgentExecutor(BaseExecutor): tool_id = step.tool_id step_name = FINAL_TOOL_ID if tool_id == FINAL_TOOL_ID else self.tools[tool_id].name step_description = step.description - self.state.stepId = uuid.uuid4() - self.state.toolId = tool_id - self.state.stepName = step_name - self.state.stepDescription = step_description - self.state.stepStatus = StepStatus.INIT - self.state.currentInput = {} + self.task.state.stepId = uuid.uuid4() + self.task.state.toolId = tool_id + self.task.state.stepName = step_name + self.task.state.stepDescription = step_description + self.task.state.stepStatus = StepStatus.INIT + self.task.state.currentInput = {} else: # 没有下一步了,结束流程 - self.state.toolId = FINAL_TOOL_ID + self.task.state.toolId = FINAL_TOOL_ID async def error_handle_after_step(self) -> None: """步骤执行失败后的错误处理""" - self.state.stepStatus = StepStatus.ERROR - self.state.executorStatus = ExecutorStatus.ERROR + if not self.task.state: + err = "[MCPAgentExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + + self.task.state.stepStatus = StepStatus.ERROR + self.task.state.executorStatus = ExecutorStatus.ERROR await self.push_message( EventType.FLOW_FAILED, data={}, ) - if len(self.context) and self.context[-1].step_id == self.state.step_id: - del self.context[-1] - self.context.append( + if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId: + del self.task.context[-1] + self.task.context.append( ExecutorHistory( - task_id=self.task.id, - step_id=self.task.state.step_id, - step_name=self.task.state.step_name, - step_description=self.task.state.step_description, - step_status=self.task.state.step_status, - flow_id=self.task.state.flow_id, - flow_name=self.task.state.flow_name, - 
flow_status=self.task.state.flow_status, - input_data={}, - output_data={}, + taskId=self.task.metadata.id, + stepId=self.task.state.stepId, + stepName=self.task.state.stepName, + stepDescription=self.task.state.stepDescription, + stepStatus=self.task.state.stepStatus, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, + inputData={}, + outputData={}, ), ) diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 1a1b12b9..0cd0d837 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -12,8 +12,8 @@ from apps.schemas.message import TextAddContent if TYPE_CHECKING: from apps.common.queue import MessageQueue - from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime from apps.schemas.scheduler import ExecutorBackground + from apps.schemas.task import TaskData logger = logging.getLogger(__name__) @@ -21,10 +21,7 @@ logger = logging.getLogger(__name__) class BaseExecutor(BaseModel, ABC): """Executor基类""" - task: "Task" - runtime: "TaskRuntime" - state: "ExecutorCheckpoint" - context: list["ExecutorHistory"] + task: "TaskData" msg_queue: "MessageQueue" background: "ExecutorBackground" @@ -51,7 +48,7 @@ class BaseExecutor(BaseModel, ABC): data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) await self.msg_queue.push_output( - self.task, + self.task.metadata, event_type=event_type, data=data, ) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 9c41d27e..8f75c287 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -89,8 +89,8 @@ class FlowExecutor(BaseExecutor): executorName=self.flow.name, executorStatus=ExecutorStatus.RUNNING, stepStatus=StepStatus.RUNNING, - stepId="start", - stepName="开始" if self.runtime.language == LanguageType.CHINESE else "Start", + stepId=self.flow.basicConfig.startStep, + 
stepName=self.flow.steps[self.flow.basicConfig.startStep].name, ) # 是否到达Flow结束终点(变量) self._reached_end: bool = False @@ -106,9 +106,6 @@ class FlowExecutor(BaseExecutor): step=self.current_step, background=self.background, question=self.question, - runtime=self.runtime, - state=self.state, - context=self.context, ) # 初始化步骤 @@ -161,8 +158,8 @@ class FlowExecutor(BaseExecutor): if not next_steps: return [ StepQueueItem( - step_id="end", - step=self.flow.steps["end"], + step_id=self.flow.basicConfig.endStep, + step=self.flow.steps[self.flow.basicConfig.endStep], ), ] @@ -197,7 +194,7 @@ class FlowExecutor(BaseExecutor): self.step_queue.append( StepQueueItem( step_id=uuid.uuid4(), - step=step.get(self.runtime.language, step[LanguageType.CHINESE]), + step=step.get(self.task.runtime.language, step[LanguageType.CHINESE]), enable_filling=False, to_user=False, ), @@ -220,15 +217,15 @@ class FlowExecutor(BaseExecutor): step_id=uuid.uuid4(), step=Step( name=( - "错误处理" if self.runtime.language == LanguageType.CHINESE else "Error Handling" + "错误处理" if self.task.runtime.language == LanguageType.CHINESE else "Error Handling" ), description=( - "错误处理" if self.runtime.language == LanguageType.CHINESE else "Error Handling" + "错误处理" if self.task.runtime.language == LanguageType.CHINESE else "Error Handling" ), node=SpecialCallType.LLM.value, type=SpecialCallType.LLM.value, params={ - "user_prompt": LLM_ERROR_PROMPT[self.runtime.language].replace( + "user_prompt": LLM_ERROR_PROMPT[self.task.runtime.language].replace( "{{ error_info }}", self.state.errorMessage["err_msg"], ), @@ -264,13 +261,13 @@ class FlowExecutor(BaseExecutor): self.step_queue.append( StepQueueItem( step_id=uuid.uuid4(), - step=step.get(self.runtime.language, step[LanguageType.CHINESE]), + step=step.get(self.task.runtime.language, step[LanguageType.CHINESE]), ), ) await self._step_process() # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) - self.runtime.time = round(datetime.now(UTC).timestamp(), 2) - 
self.runtime.fullTime + self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.fullTime # 推送Flow停止消息 if is_error: await self.push_message(EventType.FLOW_FAILED.value) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index a3ddab27..0a1bbf24 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -85,10 +85,15 @@ class StepExecutor(BaseExecutor): """初始化步骤""" logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) + if not self.task.state: + err = "[StepExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + # State写入ID和运行状态 - self.state.stepId = self.step.step_id - self.state.stepDescription = self.step.step.description - self.state.stepName = self.step.step.name + self.task.state.stepId = self.step.step_id + self.task.state.stepDescription = self.step.step.description + self.task.state.stepName = self.step.step.name # 获取并验证Call类 node_id = self.step.step.node @@ -127,15 +132,20 @@ class StepExecutor(BaseExecutor): if not self.obj.enable_filling: return + if not self.task.state: + err = "[StepExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + # 暂存旧数据 - current_step_id = self.state.stepId - current_step_name = self.state.stepName + current_step_id = self.task.state.stepId + current_step_name = self.task.state.stepName # 更新State - self.state.stepId = uuid.uuid4() - self.state.stepName = "自动参数填充" - self.state.stepStatus = StepStatus.RUNNING - self.runtime.time = round(datetime.now(UTC).timestamp(), 2) + self.task.state.stepId = uuid.uuid4() + self.task.state.stepName = "自动参数填充" + self.task.state.stepStatus = StepStatus.RUNNING + self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 slot_obj = await Slot.instance( @@ -153,26 +163,26 @@ class StepExecutor(BaseExecutor): async for chunk in iterator: result: SlotOutput = SlotOutput.model_validate(chunk.content) await TaskManager.update_task_token( - self.task.id, + 
self.task.metadata.id, input_token=slot_obj.tokens.input_tokens, output_token=slot_obj.tokens.output_tokens, ) # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.state.stepStatus = StepStatus.PARAM + self.task.state.stepStatus = StepStatus.PARAM else: - self.state.stepStatus = StepStatus.SUCCESS + self.task.state.stepStatus = StepStatus.SUCCESS await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.state.stepId = current_step_id - self.state.stepName = current_step_name + self.task.state.stepId = current_step_id + self.task.state.stepName = current_step_name await TaskManager.update_task_token( - self.task.id, + self.task.metadata.id, input_token=self.obj.tokens.input_tokens, output_token=self.obj.tokens.output_tokens, ) @@ -203,7 +213,7 @@ class StepExecutor(BaseExecutor): if to_user: if isinstance(chunk.content, str): await self.push_message(EventType.TEXT_ADD.value, chunk.content) - self.runtime.fullAnswer += chunk.content + self.task.runtime.fullAnswer += chunk.content else: await self.push_message(self.step.step.type, chunk.content) @@ -214,12 +224,17 @@ class StepExecutor(BaseExecutor): """运行单个步骤""" logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name) + if not self.task.state: + err = "[StepExecutor] 任务状态不存在" + logger.error(err) + raise RuntimeError(err) + # 进行自动参数填充 await self._run_slot_filling() # 更新状态 - self.state.stepStatus = StepStatus.RUNNING - self.runtime.time = round(datetime.now(UTC).timestamp(), 2) + self.task.state.stepStatus = StepStatus.RUNNING + self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -230,27 +245,27 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.state.stepStatus = StepStatus.ERROR 
+ self.task.state.stepStatus = StepStatus.ERROR await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.state.errorMessage = { + self.task.state.errorMessage = { "err_msg": e.message, "data": e.data, } else: - self.state.errorMessage = { + self.task.state.errorMessage = { "data": {}, } return # 更新执行状态 - self.state.stepStatus = StepStatus.SUCCESS + self.task.state.stepStatus = StepStatus.SUCCESS await TaskManager.update_task_token( - self.task.id, + self.task.metadata.id, input_token=self.obj.tokens.input_tokens, output_token=self.obj.tokens.output_tokens, ) - self.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.runtime.time + self.task.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.time # 更新history if isinstance(content, str): @@ -260,18 +275,18 @@ class StepExecutor(BaseExecutor): # 更新context history = ExecutorHistory( - taskId=self.task.id, - executorId=self.state.executorId, - executorName=self.state.executorName, - executorStatus=self.state.executorStatus, + taskId=self.task.metadata.id, + executorId=self.task.state.executorId, + executorName=self.task.state.executorName, + executorStatus=self.task.state.executorStatus, stepId=self.step.step_id, stepName=self.step.step.name, stepDescription=self.step.step.description, - stepStatus=self.state.stepStatus, + stepStatus=self.task.state.stepStatus, inputData=self.obj.input, outputData=output_data, ) - self.context.append(history) + self.task.context.append(history) # 推送输出 await self.push_message(EventType.STEP_OUTPUT.value, output_data) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index d28dc684..ec2e004d 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -28,7 +28,7 @@ class MCPHost: """MCP宿主服务""" def __init__( - self, user_sub: str, task_id: uuid.UUID, runtime_id: uuid.UUID, runtime_name: str, + self, user_sub: str, task_id: uuid.UUID, runtime_id: str, runtime_name: str, language: 
LanguageType, ) -> None: """初始化MCP宿主""" @@ -63,7 +63,7 @@ class MCPHost: async def assemble_memory(self) -> str: """组装记忆""" - task = await TaskManager.get_task_by_task_id(self._task_id) + task = await TaskManager.get_task_data_by_task_id(self._task_id) if not task: logger.error("任务 %s 不存在", self._task_id) return "" @@ -107,7 +107,7 @@ class MCPHost: executorId=self._runtime_id, executorName=self._runtime_name, executorStatus=ExecutorStatus.RUNNING, - stepId=tool.id, + stepId=uuid.uuid4(), stepName=tool.toolName, # description是规划的实际内容 stepDescription=plan_item.content, @@ -117,12 +117,12 @@ class MCPHost: ) # 保存到task - task = await TaskManager.get_task_by_task_id(self._task_id) + task = await TaskManager.get_task_data_by_task_id(self._task_id) if not task: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - context.append(context.model_dump(exclude_none=True, by_alias=True)) + task.context.append(context) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 1b273ba8..cad4256b 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -5,7 +5,7 @@ import logging from typing import Any from apps.llm.function import JsonGenerator -from apps.llm.reasoning import ReasoningLLM +from apps.models.llm import LLMData logger = logging.getLogger(__name__) @@ -13,7 +13,14 @@ logger = logging.getLogger(__name__) class MCPBase: """MCP基类""" - llm: ReasoningLLM + user_sub: str + + def __init__(self, user_sub: str) -> None: + """初始化MCP基类""" + self.user_sub = user_sub + + async def init(self) -> None: + pass async def get_resoning_result(self, prompt: str) -> str: """获取推理结果""" diff --git a/apps/scheduler/pool/check.py b/apps/scheduler/pool/check.py index 55f8d8c7..2e1d5f0e 100644 --- a/apps/scheduler/pool/check.py +++ b/apps/scheduler/pool/check.py @@ -55,7 +55,7 @@ class FileChecker: return
self.hashes[path_diff.as_posix()] != previous_hashes - async def diff(self, check_type: MetadataType) -> tuple[list[str], list[str]]: + async def diff(self, check_type: MetadataType) -> tuple[list[uuid.UUID], list[uuid.UUID]]: """生成更新列表和删除列表""" async with postgres.session() as session: # 判断类型 diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index fce4344e..e41363e9 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -111,18 +111,18 @@ class Pool: # 批量删除App for app in changed_app: - await AppLoader.delete(uuid.UUID(app), is_reload=True) + await AppLoader.delete(app, is_reload=True) for app in deleted_app: - await AppLoader.delete(uuid.UUID(app)) + await AppLoader.delete(app) # 批量加载App for app in changed_app: hash_key = Path("app/" + str(app)).as_posix() if hash_key in checker.hashes: try: - await AppLoader.load(uuid.UUID(app), checker.hashes[hash_key]) + await AppLoader.load(app, checker.hashes[hash_key]) except Exception as e: # noqa: BLE001 - await AppLoader.delete(uuid.UUID(app), is_reload=True) + await AppLoader.delete(app, is_reload=True) logger.warning("[Pool] 加载App %s 失败: %s", app, str(e)) # 载入MCP diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 9d46a3fc..da40cc71 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -42,28 +42,3 @@ class FlowChooser: await TaskManager.update_task_token(self.task_id, select_obj.input_tokens, select_obj.output_tokens) return top_flow - - - async def choose_flow(self) -> RequestDataApp | None: - """ - 依据用户的输入和选择,构造对应的Flow。 - - - 当用户没有选择任何app时,直接进行智能问答 - - 当用户选择了特定的app时,在plugin内挑选最适合的flow - """ - if not self._user_selected or not self._user_selected.app_id: - return None - - if self._user_selected.flow_id: - return self._user_selected - - top_flow = await self.get_top_flow() - # FIXME KnowledgeBase不是UUID,要改个值 - if top_flow == "KnowledgeBase": - return None - - return RequestDataApp( - 
appId=self._user_selected.app_id, - flowId=top_flow, - params=None, - ) diff --git a/apps/schemas/task.py b/apps/schemas/task.py index ad8244af..8b74073d 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -6,9 +6,20 @@ from typing import Any from pydantic import BaseModel, Field +from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime + from .flow import Step +class TaskData(BaseModel): + """任务数据""" + + metadata: Task = Field(description="任务") + runtime: TaskRuntime = Field(description="任务运行时数据") + state: ExecutorCheckpoint | None = Field(description="执行状态") + context: list[ExecutorHistory] = Field(description="执行历史") + + class CheckpointExtra(BaseModel): """Executor额外数据""" @@ -16,11 +27,6 @@ class CheckpointExtra(BaseModel): error_message: str = Field(description="错误信息", default="") retry_times: int = Field(description="当前步骤重试次数", default=0) - -class TaskExtra(BaseModel): - """任务额外数据""" - - class StepQueueItem(BaseModel): """步骤栈中的元素""" diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 131c7a0e..aa251e0e 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -158,7 +158,7 @@ class ConversationManager: await session.delete(conv) await session.commit() - await TaskManager.delete_task_history_checkpoint_by_conversation_id(conversation_id) + await TaskManager.delete_tasks_by_conversation_id(conversation_id) @staticmethod diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index ed1eb695..b44dc729 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -3,6 +3,7 @@ import collections import logging +import uuid from typing import TYPE_CHECKING from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError @@ -94,7 +95,7 @@ class FlowService: return flow_item @staticmethod - async def _validate_node_ids(nodes: list[NodeItem]) -> tuple[str, str]: + async def 
_validate_node_ids(nodes: list[NodeItem]) -> tuple[uuid.UUID, uuid.UUID]: """验证节点ID的唯一性并获取起始和终止节点ID,当节点ID重复或起始/终止节点数量不为1时抛出异常""" ids = set() start_cnt = 0 @@ -128,7 +129,7 @@ class FlowService: return start_id, end_id @staticmethod - async def validate_flow_illegal(flow_item: FlowItem) -> tuple[str, str]: + async def validate_flow_illegal(flow_item: FlowItem) -> tuple[uuid.UUID, uuid.UUID]: """验证流程图是否合法;当流程图不合法时抛出异常""" # 验证节点ID并获取起始和终止节点 start_id, end_id = await FlowService._validate_node_ids(flow_item.nodes) diff --git a/apps/services/llm.py b/apps/services/llm.py index 0cd31fef..ceb76c6d 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -56,7 +56,7 @@ class LLMManager: logger.error("[LLMManager] 用户 %s 不存在", user_sub) return None - return user.defaultLLM + return user.reasoningLLM @staticmethod @@ -186,7 +186,7 @@ class LLMManager: if not user: err = f"[LLMManager] 用户 {user_sub} 不存在" raise ValueError(err) - user.defaultLLM = None + user.reasoningLLM = None await session.commit() @@ -203,7 +203,7 @@ class LLMManager: if not user: err = f"[LLMManager] 用户 {user_sub} 不存在" raise ValueError(err) - user.defaultLLM = llm_id + user.reasoningLLM = llm_id await session.commit() diff --git a/apps/services/rag.py b/apps/services/rag.py index caf2930c..af5bd506 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -16,12 +16,13 @@ from apps.llm.patterns.rewrite import QuestionRewrite from apps.llm.reasoning import ReasoningLLM from apps.llm.token import TokenCalculator from apps.models.llm import LLMData -from apps.schemas.config import LLMConfig from apps.schemas.enum_var import EventType, LanguageType from apps.schemas.rag_data import RAGQueryReq +from apps.services.llm import LLMManager from apps.services.session import SessionManager logger = logging.getLogger(__name__) +CHUNK_ELEMENT_TOKENS = 5 class RAG: @@ -127,7 +128,7 @@ class RAG: @staticmethod async def get_doc_info_from_rag( - user_sub: str, max_tokens: int, doc_ids: list[str], data: 
RAGQueryReq, + user_sub: str, max_tokens: int | None, doc_ids: list[str], data: RAGQueryReq, ) -> list[dict[str, Any]]: """获取RAG服务的文档信息""" session_id = await SessionManager.get_session_by_user_sub(user_sub) @@ -182,32 +183,8 @@ class RAG: doc_info_list = [] doc_cnt = 0 doc_id_map = {} - leave_tokens = max_tokens - token_calculator = TokenCalculator() - for doc_chunk in doc_chunk_list: - if doc_chunk["docId"] not in doc_id_map: - doc_cnt += 1 - doc_id_map[doc_chunk["docId"]] = doc_cnt - doc_index = doc_id_map[doc_chunk["docId"]] - leave_tokens -= token_calculator.calculate_token_length( - messages=[ - { - "role": "user", - "content": f"""""", - }, - {"role": "user", "content": ""}, - ], - pure_text=True, - ) - tokens_of_chunk_element = token_calculator.calculate_token_length( - messages=[ - {"role": "user", "content": ""}, - {"role": "user", "content": ""}, - ], - pure_text=True, - ) - doc_cnt = 0 - doc_id_map = {} + remaining_tokens = int(max_tokens * 0.8) + for doc_chunk in doc_chunk_list: if doc_chunk["docId"] not in doc_id_map: doc_cnt += 1 @@ -231,16 +208,18 @@ class RAG: }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] + if bac_info: bac_info += "\n\n" bac_info += f"""""" + for chunk in doc_chunk["chunks"]: - if leave_tokens <= tokens_of_chunk_element: + if remaining_tokens <= CHUNK_ELEMENT_TOKENS: break chunk_text = chunk["text"] - chunk_text = TokenCalculator.get_k_tokens_words_from_content( - content=chunk_text, k=leave_tokens) - leave_tokens -= token_calculator.calculate_token_length(messages=[ + chunk_text = TokenCalculator().get_k_tokens_words_from_content( + content=chunk_text, k=remaining_tokens) + remaining_tokens -= TokenCalculator().calculate_token_length(messages=[ {"role": "user", "content": ""}, {"role": "user", "content": chunk_text}, {"role": "user", "content": ""}, @@ -256,21 +235,20 @@ class RAG: @staticmethod async def chat_with_llm_base_on_rag( # noqa: C901, PLR0913 user_sub: str, - llm: LLMData, + llm_id:
str, history: list[dict[str, str]], doc_ids: list[str], data: RAGQueryReq, language: LanguageType = LanguageType.CHINESE, ) -> AsyncGenerator[str, None]: """获取RAG服务的结果""" - reasion_llm = ReasoningLLM( - LLMConfig( - endpoint=llm.openaiBaseUrl, - key=llm.openaiAPIKey, - model=llm.modelName, - max_tokens=llm.maxToken, - ), - ) + llm_config = await LLMManager.get_llm(llm_id) + if not llm_config: + err = "[RAG] 未设置问答所用LLM" + logger.error(err) + raise RuntimeError(err) + reasion_llm = ReasoningLLM(llm_config) + if history: try: question_obj = QuestionRewrite() @@ -280,9 +258,11 @@ class RAG: except Exception: logger.exception("[RAG] 问题重写失败") doc_chunk_list = await RAG.get_doc_info_from_rag( - user_sub=user_sub, max_tokens=llm.maxToken, doc_ids=doc_ids, data=data) + user_sub=user_sub, max_tokens=llm_config.maxToken, doc_ids=doc_ids, data=data, + ) bac_info, doc_info_list = await RAG.assemble_doc_info( - doc_chunk_list=doc_chunk_list, max_tokens=llm.maxToken) + doc_chunk_list=doc_chunk_list, max_tokens=llm_config.maxToken, + ) messages = [ *history, { @@ -323,11 +303,11 @@ class RAG: buffer = "" async for chunk in reasion_llm.call( messages, - max_tokens=llm.maxToken, + max_tokens=llm_config.maxToken, streaming=True, temperature=0.7, result_only=False, - model=llm.modelName, + model=llm_config.modelName, ): tmp_chunk = buffer + chunk # 防止脚注被截断 diff --git a/apps/services/task.py b/apps/services/task.py index a2771f26..a3c1bff1 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -9,7 +9,7 @@ from sqlalchemy import and_, delete, select, update from apps.common.postgres import postgres from apps.models.conversation import Conversation from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime -from apps.schemas.request_data import RequestData +from apps.schemas.task import TaskData logger = logging.getLogger(__name__) @@ -48,115 +48,65 @@ class TaskManager: @staticmethod - async def get_task_by_task_id(task_id: uuid.UUID) -> Task | None: 
+ async def get_task_data_by_task_id(task_id: uuid.UUID, context_length: int | None = None) -> TaskData | None: """根据task_id获取任务""" async with postgres.session() as session: - return (await session.scalars( + task_data = (await session.scalars( select(Task).where(Task.id == task_id), )).one_or_none() + if not task_data: + logger.error("[TaskManager] 任务不存在 %s", task_id) + return None - - @staticmethod - async def get_task_runtime_by_task_id(task_id: uuid.UUID) -> TaskRuntime | None: - """根据task_id获取任务运行时""" - async with postgres.session() as session: - return (await session.scalars( + runtime = (await session.scalars( select(TaskRuntime).where(TaskRuntime.taskId == task_id), )).one_or_none() + if not runtime: + runtime = TaskRuntime( + taskId=task_id, + inputToken=0, + outputToken=0, + ) - - @staticmethod - async def get_task_state_by_task_id(task_id: uuid.UUID) -> ExecutorCheckpoint | None: - """根据task_id获取任务状态""" - async with postgres.session() as session: - return (await session.scalars( + state = (await session.scalars( select(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == task_id), )).one_or_none() - - @staticmethod - async def get_context_by_task_id(task_id: uuid.UUID, length: int | None = None) -> list[ExecutorHistory]: - """根据task_id获取flow信息""" - async with postgres.session() as session: - return list((await session.scalars( + if context_length == 0: + context = [] + else: + context = list((await session.scalars( select(ExecutorHistory).where( ExecutorHistory.taskId == task_id, - ).order_by(ExecutorHistory.updatedAt.desc()).limit(length), + ).order_by(ExecutorHistory.updatedAt.desc()).limit(context_length), )).all()) - - @staticmethod - async def init_new_task( - user_sub: str, - session_id: str | None = None, - post_body: RequestData | None = None, - ) -> Task: - """获取任务块""" - return Task( - _id=str(uuid.uuid4()), - ids=TaskIds( - user_sub=user_sub if user_sub else "", - session_id=session_id if session_id else "", - 
conversation_id=post_body.conversation_id, - ), - question=post_body.question if post_body else "", - tokens=TaskTokens(), - runtime=TaskRuntime(), - ) - - @staticmethod - async def save_flow_context(task_id: str, flow_context: list[ExecutorHistory]) -> None: - """保存flow信息到flow_context""" - if not flow_context: - return - - flow_context_collection = MongoDB().get_collection("flow_context") - try: - for history in flow_context: - # 查找是否存在 - current_context = await flow_context_collection.find_one({ - "task_id": task_id, - "_id": history.id, - }) - if current_context: - await flow_context_collection.update_one( - {"_id": current_context["_id"]}, - {"$set": history.model_dump(exclude_none=True, by_alias=True)}, - ) - else: - await flow_context_collection.insert_one(history.model_dump(exclude_none=True, by_alias=True)) - except Exception: - logger.exception("[TaskManager] 保存flow执行记录失败") + return TaskData( + metadata=task_data, + runtime=runtime, + state=state, + context=context, + ) @staticmethod async def delete_task_by_task_id(task_id: uuid.UUID) -> None: """通过task_id删除Task信息""" async with postgres.session() as session: - task = (await session.scalars( - select(Task).where(Task.id == task_id), - )).one_or_none() - if task: - await session.delete(task) - - - @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> list[uuid.UUID]: - """通过ConversationID删除Task信息""" - async with postgres.session() as session: - task_ids = [] - tasks = (await session.scalars( - select(Task).where(Task.conversationId == conversation_id), - )).all() - for task in tasks: - task_ids.append(str(task.id)) - await session.delete(task) + await session.execute( + delete(Task).where(Task.id == task_id), + ) + await session.execute( + delete(TaskRuntime).where(TaskRuntime.taskId == task_id), + ) + await session.execute( + delete(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == task_id), + ) await session.commit() - return task_ids @staticmethod - async def 
delete_task_history_checkpoint_by_conversation_id(conversation_id: uuid.UUID) -> None: + async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> None: """通过ConversationID删除Task信息""" # 删除Task task_ids = [] @@ -172,6 +122,29 @@ class TaskManager: await session.execute( delete(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId.in_(task_ids)), ) + await session.execute( + delete(TaskRuntime).where(TaskRuntime.taskId.in_(task_ids)), + ) + await session.commit() + + + @staticmethod + async def delete_task_context_by_task_id(task_id: uuid.UUID) -> None: + """通过task_id删除TaskContext信息""" + async with postgres.session() as session: + await session.execute( + delete(ExecutorHistory).where(ExecutorHistory.taskId == task_id), + ) + await session.commit() + + + @staticmethod + async def delete_task_context_by_conversation_id(conversation_id: uuid.UUID) -> None: + """通过ConversationID删除TaskContext信息""" + async with postgres.session() as session: + task_ids = list((await session.scalars( + select(Task.id).where(Task.conversationId == conversation_id), + )).all()) await session.execute( delete(ExecutorHistory).where(ExecutorHistory.taskId.in_(task_ids)), ) -- Gitee