diff --git a/apps/llm/token.py b/apps/llm/token.py
index b76cd744683e076a59a790057c9d2d77eca50b3b..55e60c819e0f0a8189dcdd9b4a38809bf8070dfa 100644
--- a/apps/llm/token.py
+++ b/apps/llm/token.py
@@ -29,26 +29,13 @@ class TokenCalculator(metaclass=SingletonMeta):
return result
- @staticmethod
- def get_k_tokens_words_from_content(content: str, k: int | None = None) -> str:
+ def get_k_tokens_words_from_content(self, content: str, k: int | None = None) -> str:
"""获取k个token的词"""
if k is None:
return content
if k <= 0:
return ""
- if TokenCalculator().calculate_token_length(messages=[
- {"role": "user", "content": content},
- ], pure_text=True) <= k:
- return content
- left = 0
- right = len(content)
- while left + 1 < right:
- mid = (left + right) // 2
- if TokenCalculator().calculate_token_length(messages=[
- {"role": "user", "content": content[:mid]},
- ], pure_text=True) <= k:
- left = mid
- else:
- right = mid
- return content[:left]
+ encodings = self._encoder.encode(content)
+ encodings = encodings[:k]
+ return self._encoder.decode(encodings, errors="ignore")
diff --git a/apps/models/task.py b/apps/models/task.py
index c21975e9858ab42a4c0d2aa7347143e286498ff0..2e72a1a0f7722f974a3ffd95b44211c07542c441 100644
--- a/apps/models/task.py
+++ b/apps/models/task.py
@@ -23,8 +23,6 @@ class Task(Base):
UUID(as_uuid=True), ForeignKey("framework_conversation.id"), nullable=False,
)
"""对话ID"""
- sessionId: Mapped[str] = mapped_column(String(255), ForeignKey("framework_session.id"), nullable=False) # noqa: N815
- """会话ID"""
checkpointId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815
UUID(as_uuid=True), ForeignKey("framework_executor_checkpoint.id"),
nullable=True, default=None,
@@ -56,6 +54,10 @@ class TaskRuntime(Base):
"""时间"""
fullTime: Mapped[float] = mapped_column(Float, default=0.0, nullable=False) # noqa: N815
"""完整时间成本"""
+ sessionId: Mapped[str | None] = mapped_column( # noqa: N815
+ String(255), ForeignKey("framework_session.id"), nullable=True, default=None,
+ )
+ """会话ID"""
userInput: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815
"""用户输入"""
fullAnswer: Mapped[str] = mapped_column(Text, nullable=False, default="") # noqa: N815
diff --git a/apps/models/user.py b/apps/models/user.py
index e18918c9457e1af018000c5567c0e15357b6ed81..abd844b3b06f6adbdf34e39f01c092421e741561 100644
--- a/apps/models/user.py
+++ b/apps/models/user.py
@@ -36,8 +36,12 @@ class User(Base):
"""用户个人令牌"""
selectedKB: Mapped[list[uuid.UUID]] = mapped_column(ARRAY(UUID), default=[], nullable=False) # noqa: N815
"""用户选择的知识库的ID"""
- defaultLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815
- """用户选择的大模型ID"""
+ reasoningLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815
+ """用户选择的问答模型ID"""
+ functionLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815
+ """用户选择的函数模型ID"""
+ embeddingLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815
+ """用户选择的向量模型ID"""
autoExecute: Mapped[bool | None] = mapped_column(Boolean, default=False, nullable=True) # noqa: N815
"""Agent是否自动执行"""
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index 7bba12989b8aa115ea0f468d70f7e161408f62ea..0c5ab7e92cff8c9509bc80fa1cedb6e685b4a42a 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -57,7 +57,7 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T
if not post_body.task_id:
err = "[Chat] task_id 不可为空!"
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty")
- task = await TaskManager.get_task_by_task_id(post_body.task_id)
+ task = await TaskManager.get_task_data_by_task_id(post_body.task_id)
post_body.app = RequestDataApp(appId=task.state.app_id)
post_body.conversation_id = task.ids.conversation_id
post_body.language = task.language
diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py
index d317011e07a9c182b9c45965510b09d9abefa801..48d2531f3ccfebbbb2a14192994a0cb74c190d62 100644
--- a/apps/scheduler/call/facts/facts.py
+++ b/apps/scheduler/call/facts/facts.py
@@ -88,18 +88,20 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
# 提取事实信息
facts_tpl = env.from_string(FACTS_PROMPT[self._sys_vars.language])
facts_prompt = facts_tpl.render(conversation=data.message)
- facts_obj: FactsGen = await self._json([
+ facts_obj = await self._json([
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": facts_prompt},
- ], FactsGen) # type: ignore[arg-type]
+ ], FactsGen.model_json_schema())
+ facts_obj = FactsGen.model_validate(facts_obj)
# 更新用户画像
domain_tpl = env.from_string(DOMAIN_PROMPT[self._sys_vars.language])
domain_prompt = domain_tpl.render(conversation=data.message)
- domain_list: DomainGen = await self._json([
+ domain_list = await self._json([
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": domain_prompt},
- ], DomainGen) # type: ignore[arg-type]
+ ], DomainGen.model_json_schema())
+ domain_list = DomainGen.model_validate(domain_list)
for domain in domain_list.keywords:
await UserTagManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain)
diff --git a/apps/scheduler/call/graph/graph.py b/apps/scheduler/call/graph/graph.py
index 46e55e1382e25d0f62b6643a98abf3ff9f1f15d1..dc5986c03e6b67668ef9994b0ef788880b705bc2 100644
--- a/apps/scheduler/call/graph/graph.py
+++ b/apps/scheduler/call/graph/graph.py
@@ -17,8 +17,8 @@ from apps.schemas.scheduler import (
CallVars,
)
+from .prompt import GENERATE_STYLE_PROMPT
-from .schema import RenderFormat, RenderInput, RenderOutput
+from .schema import RenderFormat, RenderInput, RenderOutput, RenderStyleResult
-from .style import RenderStyle


class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput):
@@ -88,10 +88,7 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput):
         self._option_template["dataset"]["source"] = processed_data

         try:
-            style_obj = RenderStyle()
-            llm_output = await style_obj.generate(question=data.question)
-            self.tokens.input_tokens += style_obj.input_tokens
-            self.tokens.output_tokens += style_obj.output_tokens
+            llm_output = RenderStyleResult.model_validate(await self._json([{"role": "user", "content": GENERATE_STYLE_PROMPT[self._sys_vars.language].replace("{question}", data.question)}], RenderStyleResult.model_json_schema())).model_dump()
add_style = llm_output.get("additional_style", "")
self._parse_options(column_num, llm_output["chart_type"], add_style, llm_output["scale_type"])
diff --git a/apps/scheduler/call/graph/prompt.py b/apps/scheduler/call/graph/prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..521a432de1f984d747b561152cff6a2dc9164f79
--- /dev/null
+++ b/apps/scheduler/call/graph/prompt.py
@@ -0,0 +1,101 @@
+"""图表相关提示词"""
+
+from apps.schemas.enum_var import LanguageType
+
+GENERATE_STYLE_PROMPT: dict[LanguageType, str] = {
+ LanguageType.CHINESE: r"""
+
+
+ 你的目标是:帮助用户在绘制图表时做出样式选择。
+ 请以JSON格式输出你的选择。
+
+ 图表类型:
+ - `bar`: 柱状图
+ - `pie`: 饼图
+ - `line`: 折线图
+ - `scatter`: 散点图
+ 柱状图的附加样式:
+ - `normal`: 普通柱状图
+ - `stacked`: 堆叠柱状图
+ 饼图的附加样式:
+ - `normal`: 普通饼图
+ - `ring`: 环形饼图
+ 可用坐标比例:
+ - `linear`: 线性比例
+ - `log`: 对数比例
+
+
+
+ ## 问题
+ 查询数据库中的数据,并绘制堆叠柱状图。
+
+ ## 思考
+ 让我们一步步思考。用户要求绘制堆叠柱状图,因此图表类型应为 `bar`,即柱状图;图表样式\
+应为 `stacked`,即堆叠形式。
+
+ ## 答案
+ {
+ "chart_type": "bar",
+ "additional_style": "stacked",
+ "scale_type": "linear"
+ }
+
+
+
+ ## 问题
+ {question}
+
+ ## 思考
+ 让我们一步步思考。
+ """,
+ LanguageType.ENGLISH: r"""
+
+
+ Your mission is: help the user make style choices when drawing a chart.
+ Please output your choices in JSON format.
+
+ Chart types:
+ - `bar`: Bar chart
+ - `pie`: Pie chart
+ - `line`: Line chart
+ - `scatter`: Scatter chart
+
+ Bar chart additional styles:
+ - `normal`: Normal bar chart
+ - `stacked`: Stacked bar chart
+
+ Pie chart additional styles:
+ - `normal`: Normal pie chart
+ - `ring`: Ring pie chart
+
+ Axis scaling:
+ - `linear`: Linear scaling
+ - `log`: Logarithmic scaling
+
+
+
+ ## Question
+ Query the data from the database and draw a stacked bar chart.
+
+ ## Thought
+ Let's think step by step. The user requires drawing a stacked bar chart, so the chart type \
+should be `bar`, i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form.
+
+ ## Answer
+ {
+ "chart_type": "bar",
+ "additional_style": "stacked",
+ "scale_type": "linear"
+ }
+
+
+
+ ## Question
+
+ {question}
+
+ ## Thought
+
+ Let's think step by step.
+ """,
+}
diff --git a/apps/scheduler/call/graph/schema.py b/apps/scheduler/call/graph/schema.py
index 0674a4c48e5d2c04bc64fb60433c89dbbc30f400..54048fa28b9da7affcbbf515aef517980000b306 100644
--- a/apps/scheduler/call/graph/schema.py
+++ b/apps/scheduler/call/graph/schema.py
@@ -1,13 +1,21 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""图表工具的输入输出"""
-from typing import Any
+from typing import Any, Literal
from pydantic import BaseModel, Field
from apps.scheduler.call.core import DataBase
+class RenderStyleResult(BaseModel):
+ """选择图表样式结果"""
+
+ chart_type: Literal["bar", "pie", "line", "scatter"] = Field(description="图表类型")
+ additional_style: Literal["normal", "stacked", "ring"] | None = Field(description="图表样式")
+ scale_type: Literal["linear", "log"] = Field(description="图表比例")
+
+
class RenderAxis(BaseModel):
"""ECharts图表的轴配置"""
diff --git a/apps/scheduler/call/graph/style.py b/apps/scheduler/call/graph/style.py
deleted file mode 100644
index 20f6291cd7060a8fd8364c6aaed579ae3c2c82bc..0000000000000000000000000000000000000000
--- a/apps/scheduler/call/graph/style.py
+++ /dev/null
@@ -1,164 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""选择图表样式"""
-
-import logging
-from typing import Any, Literal
-
-from pydantic import BaseModel, Field
-
-from apps.llm.function import JsonGenerator
-from apps.llm.patterns.core import CorePattern
-from apps.llm.reasoning import ReasoningLLM
-from apps.schemas.enum_var import LanguageType
-
-logger = logging.getLogger(__name__)
-
-
-class RenderStyleResult(BaseModel):
- """选择图表样式结果"""
-
- chart_type: Literal["bar", "pie", "line", "scatter"] = Field(description="图表类型")
- additional_style: Literal["normal", "stacked", "ring"] | None = Field(description="图表样式")
- scale_type: Literal["linear", "log"] = Field(description="图表比例")
-
-
-class RenderStyle(CorePattern):
- """选择图表样式"""
-
- @staticmethod
- def _default() -> tuple[dict[LanguageType, str], dict[LanguageType, str]]:
- """默认的Prompt内容"""
- return {
- LanguageType.CHINESE: r"You are a helpful assistant.",
- LanguageType.ENGLISH: r"You are a helpful assistant.",
- }, {
- LanguageType.CHINESE: r"""
-
-
- 你的目标是:帮助用户在绘制图表时做出样式选择。
- 请以JSON格式输出你的选择。
-
- 图标类型:
- - `bar`: 柱状图
- - `pie`: 饼图
- - `line`: 折线图
- - `scatter`: 散点图
- 柱状图的附加样式:
- - `normal`: 普通柱状图
- - `stacked`: 堆叠柱状图
- 饼图的附加样式:
- - `normal`: 普通饼图
- - `ring`: 环形饼图
- 可用坐标比例:
- - `linear`: 线性比例
- - `log`: 对数比例
-
-
-
- ## 问题
- 查询数据库中的数据,并绘制堆叠柱状图。
-
- ## 思考
- 让我们一步步思考。用户要求绘制堆叠柱状图,因此图表类型应为 `bar`,即柱状图;图表样式\
-应为 `stacked`,即堆叠形式。
-
- ## 答案
- {
- "chart_type": "bar",
- "additional_style": "stacked",
- "scale_type": "linear"
- }
-
-
-
- ## 问题
- {question}
-
- ## 思考
- 让我们一步步思考。
- """,
- LanguageType.ENGLISH: r"""
-
-
- Your mission is: help the user make style choices when drawing a chart.
- Please output your choices in JSON format.
-
- Chart types:
- - `bar`: Bar chart
- - `pie`: Pie chart
- - `line`: Line chart
- - `scatter`: Scatter chart
-
- Bar chart additional styles:
- - `normal`: Normal bar chart
- - `stacked`: Stacked bar chart
-
- Pie chart additional styles:
- - `normal`: Normal pie chart
- - `ring`: Ring pie chart
-
- Axis scaling:
- - `linear`: Linear scaling
- - `log`: Logarithmic scaling
-
-
-
- ## Question
- Query the data from the database and draw a stacked bar chart.
-
- ## Thought
- Let's think step by step. The user requires drawing a stacked bar chart, so the chart type \
-should be `bar`, i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form.
-
- ## Answer
- {
- "chart_type": "bar",
- "additional_style": "stacked",
- "scale_type": "linear"
- }
-
-
-
- ## Question
- {question}
-
- ## Thought
- Let's think step by step.
- """,
- }
-
- async def generate(self, **kwargs) -> dict[str, Any]: # noqa: ANN003
- """使用LLM选择图表样式"""
- question = kwargs["question"]
- language = kwargs.get("language", LanguageType.CHINESE)
-
- # 使用Reasoning模型进行推理
- messages = [
- {"role": "system", "content": self.system_prompt[language]},
- {"role": "user", "content": self.user_prompt[language].format(question=question)},
- ]
- result = ""
- llm = ReasoningLLM()
- async for chunk in llm.call(messages, streaming=False):
- result += chunk
- self.input_tokens = llm.input_tokens
- self.output_tokens = llm.output_tokens
-
- messages += [
- {"role": "assistant", "content": result},
- ]
-
- # 使用FunctionLLM模型进行提取参数
- json_gen = JsonGenerator(
- query="根据给定的背景信息,生成预测问题",
- conversation=messages,
- schema=RenderStyleResult.model_json_schema(),
- )
- try:
- result_dict = await json_gen.generate()
- RenderStyleResult.model_validate(result_dict)
- except Exception:
- logger.exception("[RenderStyle] 选择图表样式失败")
- return {}
-
- return result_dict
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index 3b88fb4ff14b6a554a6d2010f5d195be4503556d..ddfcaa01ccca519611f5f8e967dbd6621d96ff70 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -72,7 +72,9 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
]
# 使用大模型进行尝试
- answer = await self._llm(messages=conversation, streaming=False)
+ answer = ""
+ async for chunk in self._llm(messages=conversation, streaming=True):
+ answer += chunk
answer = await FunctionLLM.process_response(answer)
try:
data = json.loads(answer)
diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py
index d2730c4d92ebdb71204a1c163bdd81a42640d159..3a63914dbdf3a5b4035767a0e8fd112d54e1531a 100644
--- a/apps/scheduler/executor/agent.py
+++ b/apps/scheduler/executor/agent.py
@@ -31,7 +31,7 @@ class MCPAgentExecutor(BaseExecutor):
max_steps: int = Field(default=40, description="最大步数")
servers_id: list[str] = Field(description="MCP server id")
- agent_id: str = Field(default="", description="Agent ID")
+    agent_id: uuid.UUID = Field(default_factory=uuid.uuid4, description="App ID作为Agent ID")
agent_description: str = Field(default="", description="Agent描述")
mcp_list: list[MCPInfo] = Field(description="MCP服务器列表", default=[])
mcp_pool: MCPPool = Field(description="MCP池", default=MCPPool())
@@ -56,15 +56,15 @@ class MCPAgentExecutor(BaseExecutor):
async def init(self) -> None:
"""初始化MCP Agent"""
- self.planner = MCPPlanner(self.runtime.userInput, self.resoning_llm, self.runtime.language)
- self.host = MCPHost(self.task.userSub, self.resoning_llm)
+ self.planner = MCPPlanner(self.task.runtime.userInput, self.resoning_llm, self.task.runtime.language)
+ self.host = MCPHost(self.task.metadata.userSub, self.resoning_llm)
async def load_state(self) -> None:
"""从数据库中加载FlowExecutor的状态"""
logger.info("[MCPAgentExecutor] 加载Executor状态")
# 尝试恢复State
- if self.state and self.state.executorStatus != ExecutorStatus.INIT:
- self.context = await TaskManager.get_context_by_task_id(self.task.id)
+ if self.task.state and self.task.state.executorStatus != ExecutorStatus.INIT:
+ self.task.context = await TaskManager.get_context_by_task_id(self.task.metadata.id)
async def load_mcp(self) -> None:
"""加载MCP服务器列表"""
@@ -74,16 +74,16 @@ class MCPAgentExecutor(BaseExecutor):
mcp_ids = app.mcp_service
for mcp_id in mcp_ids:
mcp_service = await MCPServiceManager.get_mcp_service(mcp_id)
- if self.task.userSub not in mcp_service.activated:
+ if self.task.metadata.userSub not in mcp_service.activated:
logger.warning(
"[MCPAgentExecutor] 用户 %s 未启用MCP %s",
- self.task.userSub,
+ self.task.metadata.userSub,
mcp_id,
)
continue
self.mcp_list.append(mcp_service)
- await self.mcp_pool.init_mcp(mcp_id, self.task.userSub)
+ await self.mcp_pool.init_mcp(mcp_id, self.task.metadata.userSub)
for tool in mcp_service.tools:
self.tools[tool.id] = tool
self.tool_list.extend(mcp_service.tools)
@@ -100,9 +100,9 @@ class MCPAgentExecutor(BaseExecutor):
"""获取工具输入参数"""
if is_first:
# 获取第一个输入参数
- mcp_tool = self.tools[self.state.toolId]
- self.state.currentInput = await MCPHost._get_first_input_params(
- mcp_tool, self.runtime.userInput, self.state.stepDescription, self.task,
+ mcp_tool = self.tools[self.task.state.toolId]
+ self.task.state.currentInput = await self.host.get_first_input_params(
+ mcp_tool, self.task.runtime.userInput, self.task,
)
else:
# 获取后续输入参数
@@ -113,65 +113,74 @@ class MCPAgentExecutor(BaseExecutor):
params = {}
params_description = ""
mcp_tool = self.tools[self.task.state.tool_id]
- self.state.currentInput = await MCPHost.fill_params(
+ self.task.state.currentInput = await self.host.fill_params(
mcp_tool,
- self.runtime.userInput,
- self.state.stepDescription,
- self.state.currentInput,
- self.state.errorMessage,
+ self.task.runtime.userInput,
+ self.task.state.currentInput,
+ self.task.state.errorMessage,
params,
params_description,
- self.runtime.language,
+ self.task.runtime.language,
)
async def confirm_before_step(self) -> None:
"""确认前步骤"""
+ if not self.task.state:
+ err = "[MCPAgentExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
# 发送确认消息
-        mcp_tool = self.tools[self.task.state.tool_id]
+        mcp_tool = self.tools[self.task.state.toolId]
- confirm_message = await MCPPlanner.get_tool_risk(
- mcp_tool, self.task.state.current_input, "", self.resoning_llm, self.task.language
+ confirm_message = await self.planner.get_tool_risk(
+ mcp_tool, self.task.state.currentInput, "", self.resoning_llm, self.task.runtime.language,
)
await self.update_tokens()
await self.push_message(
EventType.STEP_WAITING_FOR_START, confirm_message.model_dump(exclude_none=True, by_alias=True),
)
await self.push_message(EventType.FLOW_STOP, {})
- self.state.executorStatus = ExecutorStatus.WAITING
- self.state.stepStatus = StepStatus.WAITING
- self.context.append(
+ self.task.state.executorStatus = ExecutorStatus.WAITING
+ self.task.state.stepStatus = StepStatus.WAITING
+ self.task.context.append(
ExecutorHistory(
- task_id=self.task.id,
- step_id=self.task.state.step_id,
- step_name=self.task.state.step_name,
- step_description=self.task.state.step_description,
- step_status=self.task.state.step_status,
- flow_id=self.task.state.flow_id,
- flow_name=self.task.state.flow_name,
- flow_status=self.task.state.flow_status,
- input_data={},
- output_data={},
- ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True),
- )
+ taskId=self.task.metadata.id,
+ stepId=self.task.state.stepId,
+ stepName=self.task.state.stepName,
+ stepDescription=self.task.state.stepDescription,
+ stepStatus=self.task.state.stepStatus,
+ executorId=self.task.state.executorId,
+ executorName=self.task.state.executorName,
+ executorStatus=self.task.state.executorStatus,
+ inputData={},
+ outputData={},
+ extraData=confirm_message.model_dump(exclude_none=True, by_alias=True),
+ ),
)
async def run_step(self) -> None:
"""执行步骤"""
- self.state.executorStatus = ExecutorStatus.RUNNING
- self.state.stepStatus = StepStatus.RUNNING
- mcp_tool = self.tools[self.state.toolId]
- mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.ids.user_sub))
+ if not self.task.state:
+ err = "[MCPAgentExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
+ self.task.state.executorStatus = ExecutorStatus.RUNNING
+ self.task.state.stepStatus = StepStatus.RUNNING
+ mcp_tool = self.tools[self.task.state.toolId]
+ mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.metadata.userSub))
try:
- output_params = await mcp_client.call_tool(mcp_tool.name, self.state.currentInput)
+ output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.currentInput)
except anyio.ClosedResourceError:
logger.exception("[MCPAgentExecutor] MCP客户端连接已关闭: %s", mcp_tool.mcp_id)
- await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.ids.user_sub)
- await self.mcp_pool.init_mcp(mcp_tool.mcp_id, self.task.ids.user_sub)
- self.state.stepStatus = StepStatus.ERROR
+ await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.metadata.userSub)
+ await self.mcp_pool.init_mcp(mcp_tool.mcp_id, self.task.metadata.userSub)
+ self.task.state.stepStatus = StepStatus.ERROR
return
except Exception as e:
logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误", mcp_tool.name)
- self.state.stepStatus = StepStatus.ERROR
- self.state.errorMessage = str(e)
+ self.task.state.stepStatus = StepStatus.ERROR
+ self.task.state.errorMessage = str(e)
return
logger.error(f"当前工具名称: {mcp_tool.name}, 输出参数: {output_params}")
if output_params.isError:
@@ -179,8 +188,8 @@ class MCPAgentExecutor(BaseExecutor):
for output in output_params.content:
if isinstance(output, TextContent):
err += output.text
- self.state.stepStatus = StepStatus.ERROR
- self.state.errorMessage = {
+ self.task.state.stepStatus = StepStatus.ERROR
+ self.task.state.errorMessage = {
"err_msg": err,
"data": {},
}
@@ -194,54 +203,59 @@ class MCPAgentExecutor(BaseExecutor):
}
await self.update_tokens()
- await self.push_message(EventType.STEP_INPUT, self.state.currentInput)
+ await self.push_message(EventType.STEP_INPUT, self.task.state.currentInput)
await self.push_message(EventType.STEP_OUTPUT, output_params)
- self.context.append(
+ self.task.context.append(
ExecutorHistory(
- taskId=self.task.id,
- stepId=self.state.stepId,
- stepName=self.state.stepName,
- stepDescription=self.state.stepDescription,
+ taskId=self.task.metadata.id,
+ stepId=self.task.state.stepId,
+ stepName=self.task.state.stepName,
+ stepDescription=self.task.state.stepDescription,
stepStatus=StepStatus.SUCCESS,
- executorId=self.state.executorId,
- executorName=self.state.executorName,
- executorStatus=self.state.executorStatus,
- inputData=self.state.currentInput,
+ executorId=self.task.state.executorId,
+ executorName=self.task.state.executorName,
+ executorStatus=self.task.state.executorStatus,
+ inputData=self.task.state.currentInput,
outputData=output_params,
),
)
- self.state.stepStatus = StepStatus.SUCCESS
+ self.task.state.stepStatus = StepStatus.SUCCESS
async def generate_params_with_null(self) -> None:
"""生成参数补充"""
- mcp_tool = self.tools[self.state.toolId]
+ if not self.task.state:
+ err = "[MCPAgentExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
+ mcp_tool = self.tools[self.task.state.toolId]
params_with_null = await self.planner.get_missing_param(
mcp_tool,
- self.state.currentInput,
- self.state.errorMessage,
+ self.task.state.currentInput,
+ self.task.state.errorMessage,
)
await self.update_tokens()
error_message = await self.planner.change_err_message_to_description(
- error_message=self.state.errorMessage,
+ error_message=self.task.state.errorMessage,
tool=mcp_tool,
- input_params=self.state.currentInput,
+ input_params=self.task.state.currentInput,
)
await self.push_message(
EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null},
)
await self.push_message(EventType.FLOW_STOP, data={})
- self.state.executorStatus = ExecutorStatus.WAITING
- self.state.stepStatus = StepStatus.PARAM
- self.context.append(
+ self.task.state.executorStatus = ExecutorStatus.WAITING
+ self.task.state.stepStatus = StepStatus.PARAM
+ self.task.context.append(
ExecutorHistory(
- taskId=self.task.id,
- stepId=self.state.stepId,
- stepName=self.state.stepName,
- stepDescription=self.state.stepDescription,
- stepStatus=self.state.stepStatus,
- executorId=self.state.executorId,
- executorName=self.state.executorName,
- executorStatus=self.state.executorStatus,
+ taskId=self.task.metadata.id,
+ stepId=self.task.state.stepId,
+ stepName=self.task.state.stepName,
+ stepDescription=self.task.state.stepDescription,
+ stepStatus=self.task.state.stepStatus,
+ executorId=self.task.state.executorId,
+ executorName=self.task.state.executorName,
+ executorStatus=self.task.state.executorStatus,
inputData={},
outputData={},
extraData={
@@ -253,9 +267,14 @@ class MCPAgentExecutor(BaseExecutor):
async def get_next_step(self) -> None:
"""获取下一步"""
+ if not self.task.state:
+ err = "[MCPAgentExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
if self.step_cnt < self.max_steps:
self.step_cnt += 1
- history = await MCPHost.assemble_memory(self.task)
+ history = await self.host.assemble_memory(self.task.runtime, self.task.context)
max_retry = 3
step = None
for _ in range(max_retry):
@@ -273,38 +292,43 @@ class MCPAgentExecutor(BaseExecutor):
tool_id = step.tool_id
step_name = FINAL_TOOL_ID if tool_id == FINAL_TOOL_ID else self.tools[tool_id].name
step_description = step.description
- self.state.stepId = uuid.uuid4()
- self.state.toolId = tool_id
- self.state.stepName = step_name
- self.state.stepDescription = step_description
- self.state.stepStatus = StepStatus.INIT
- self.state.currentInput = {}
+ self.task.state.stepId = uuid.uuid4()
+ self.task.state.toolId = tool_id
+ self.task.state.stepName = step_name
+ self.task.state.stepDescription = step_description
+ self.task.state.stepStatus = StepStatus.INIT
+ self.task.state.currentInput = {}
else:
# 没有下一步了,结束流程
- self.state.toolId = FINAL_TOOL_ID
+ self.task.state.toolId = FINAL_TOOL_ID
async def error_handle_after_step(self) -> None:
"""步骤执行失败后的错误处理"""
- self.state.stepStatus = StepStatus.ERROR
- self.state.executorStatus = ExecutorStatus.ERROR
+ if not self.task.state:
+ err = "[MCPAgentExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
+ self.task.state.stepStatus = StepStatus.ERROR
+ self.task.state.executorStatus = ExecutorStatus.ERROR
await self.push_message(
EventType.FLOW_FAILED,
data={},
)
- if len(self.context) and self.context[-1].step_id == self.state.step_id:
- del self.context[-1]
- self.context.append(
+ if len(self.task.context) and self.task.context[-1].stepId == self.task.state.stepId:
+ del self.task.context[-1]
+ self.task.context.append(
ExecutorHistory(
- task_id=self.task.id,
- step_id=self.task.state.step_id,
- step_name=self.task.state.step_name,
- step_description=self.task.state.step_description,
- step_status=self.task.state.step_status,
- flow_id=self.task.state.flow_id,
- flow_name=self.task.state.flow_name,
- flow_status=self.task.state.flow_status,
- input_data={},
- output_data={},
+ taskId=self.task.metadata.id,
+ stepId=self.task.state.stepId,
+ stepName=self.task.state.stepName,
+ stepDescription=self.task.state.stepDescription,
+ stepStatus=self.task.state.stepStatus,
+ executorId=self.task.state.executorId,
+ executorName=self.task.state.executorName,
+ executorStatus=self.task.state.executorStatus,
+ inputData={},
+ outputData={},
),
)
diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py
index 1a1b12b96f6acd75837b905e7f53acefbc931c3d..0cd0d837b98db28efef1aae3236fcf14fab3b80d 100644
--- a/apps/scheduler/executor/base.py
+++ b/apps/scheduler/executor/base.py
@@ -12,8 +12,8 @@ from apps.schemas.message import TextAddContent
if TYPE_CHECKING:
from apps.common.queue import MessageQueue
- from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime
from apps.schemas.scheduler import ExecutorBackground
+ from apps.schemas.task import TaskData
logger = logging.getLogger(__name__)
@@ -21,10 +21,7 @@ logger = logging.getLogger(__name__)
class BaseExecutor(BaseModel, ABC):
"""Executor基类"""
- task: "Task"
- runtime: "TaskRuntime"
- state: "ExecutorCheckpoint"
- context: list["ExecutorHistory"]
+ task: "TaskData"
msg_queue: "MessageQueue"
background: "ExecutorBackground"
@@ -51,7 +48,7 @@ class BaseExecutor(BaseModel, ABC):
data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True)
await self.msg_queue.push_output(
- self.task,
+ self.task.metadata,
event_type=event_type,
data=data,
)
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index 9c41d27e12bdced2567c4773ab21d175769bc201..8f75c2877f025cad96eb47a82f544f8930848cd0 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -89,8 +89,8 @@ class FlowExecutor(BaseExecutor):
executorName=self.flow.name,
executorStatus=ExecutorStatus.RUNNING,
stepStatus=StepStatus.RUNNING,
- stepId="start",
- stepName="开始" if self.runtime.language == LanguageType.CHINESE else "Start",
+ stepId=self.flow.basicConfig.startStep,
+ stepName=self.flow.steps[self.flow.basicConfig.startStep].name,
)
# 是否到达Flow结束终点(变量)
self._reached_end: bool = False
@@ -106,9 +106,6 @@ class FlowExecutor(BaseExecutor):
step=self.current_step,
background=self.background,
question=self.question,
- runtime=self.runtime,
- state=self.state,
- context=self.context,
)
# 初始化步骤
@@ -161,8 +158,8 @@ class FlowExecutor(BaseExecutor):
if not next_steps:
return [
StepQueueItem(
- step_id="end",
- step=self.flow.steps["end"],
+ step_id=self.flow.basicConfig.endStep,
+ step=self.flow.steps[self.flow.basicConfig.endStep],
),
]
@@ -197,7 +194,7 @@ class FlowExecutor(BaseExecutor):
self.step_queue.append(
StepQueueItem(
step_id=uuid.uuid4(),
- step=step.get(self.runtime.language, step[LanguageType.CHINESE]),
+ step=step.get(self.task.runtime.language, step[LanguageType.CHINESE]),
enable_filling=False,
to_user=False,
),
@@ -220,15 +217,15 @@ class FlowExecutor(BaseExecutor):
step_id=uuid.uuid4(),
step=Step(
name=(
- "错误处理" if self.runtime.language == LanguageType.CHINESE else "Error Handling"
+ "错误处理" if self.task.runtime.language == LanguageType.CHINESE else "Error Handling"
),
description=(
- "错误处理" if self.runtime.language == LanguageType.CHINESE else "Error Handling"
+ "错误处理" if self.task.runtime.language == LanguageType.CHINESE else "Error Handling"
),
node=SpecialCallType.LLM.value,
type=SpecialCallType.LLM.value,
params={
- "user_prompt": LLM_ERROR_PROMPT[self.runtime.language].replace(
+ "user_prompt": LLM_ERROR_PROMPT[self.task.runtime.language].replace(
"{{ error_info }}",
-                            self.state.errorMessage["err_msg"],
+                            self.task.state.errorMessage["err_msg"],
),
@@ -264,13 +261,13 @@ class FlowExecutor(BaseExecutor):
self.step_queue.append(
StepQueueItem(
step_id=uuid.uuid4(),
- step=step.get(self.runtime.language, step[LanguageType.CHINESE]),
+ step=step.get(self.task.runtime.language, step[LanguageType.CHINESE]),
),
)
await self._step_process()
# FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间)
- self.runtime.time = round(datetime.now(UTC).timestamp(), 2) - self.runtime.fullTime
+ self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.fullTime
# 推送Flow停止消息
if is_error:
await self.push_message(EventType.FLOW_FAILED.value)
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
index a3ddab274faac4f2023b2e8fed078c5673f8c6df..0a1bbf24cb838cf6a1cc3fe1fb3ed00991688fca 100644
--- a/apps/scheduler/executor/step.py
+++ b/apps/scheduler/executor/step.py
@@ -85,10 +85,15 @@ class StepExecutor(BaseExecutor):
"""初始化步骤"""
logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name)
+ if not self.task.state:
+ err = "[StepExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
# State写入ID和运行状态
- self.state.stepId = self.step.step_id
- self.state.stepDescription = self.step.step.description
- self.state.stepName = self.step.step.name
+ self.task.state.stepId = self.step.step_id
+ self.task.state.stepDescription = self.step.step.description
+ self.task.state.stepName = self.step.step.name
# 获取并验证Call类
node_id = self.step.step.node
@@ -127,15 +132,20 @@ class StepExecutor(BaseExecutor):
if not self.obj.enable_filling:
return
+ if not self.task.state:
+ err = "[StepExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
# 暂存旧数据
- current_step_id = self.state.stepId
- current_step_name = self.state.stepName
+ current_step_id = self.task.state.stepId
+ current_step_name = self.task.state.stepName
# 更新State
- self.state.stepId = uuid.uuid4()
- self.state.stepName = "自动参数填充"
- self.state.stepStatus = StepStatus.RUNNING
- self.runtime.time = round(datetime.now(UTC).timestamp(), 2)
+ self.task.state.stepId = uuid.uuid4()
+ self.task.state.stepName = "自动参数填充"
+ self.task.state.stepStatus = StepStatus.RUNNING
+ self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2)
# 初始化填参
slot_obj = await Slot.instance(
@@ -153,26 +163,26 @@ class StepExecutor(BaseExecutor):
async for chunk in iterator:
result: SlotOutput = SlotOutput.model_validate(chunk.content)
await TaskManager.update_task_token(
- self.task.id,
+ self.task.metadata.id,
input_token=slot_obj.tokens.input_tokens,
output_token=slot_obj.tokens.output_tokens,
)
# 如果没有填全,则状态设置为待填参
if result.remaining_schema:
- self.state.stepStatus = StepStatus.PARAM
+ self.task.state.stepStatus = StepStatus.PARAM
else:
- self.state.stepStatus = StepStatus.SUCCESS
+ self.task.state.stepStatus = StepStatus.SUCCESS
await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True))
# 更新输入
self.obj.input.update(result.slot_data)
# 恢复State
- self.state.stepId = current_step_id
- self.state.stepName = current_step_name
+ self.task.state.stepId = current_step_id
+ self.task.state.stepName = current_step_name
await TaskManager.update_task_token(
- self.task.id,
+ self.task.metadata.id,
input_token=self.obj.tokens.input_tokens,
output_token=self.obj.tokens.output_tokens,
)
@@ -203,7 +213,7 @@ class StepExecutor(BaseExecutor):
if to_user:
if isinstance(chunk.content, str):
await self.push_message(EventType.TEXT_ADD.value, chunk.content)
- self.runtime.fullAnswer += chunk.content
+ self.task.runtime.fullAnswer += chunk.content
else:
await self.push_message(self.step.step.type, chunk.content)
@@ -214,12 +224,17 @@ class StepExecutor(BaseExecutor):
"""运行单个步骤"""
logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name)
+ if not self.task.state:
+ err = "[StepExecutor] 任务状态不存在"
+ logger.error(err)
+ raise RuntimeError(err)
+
# 进行自动参数填充
await self._run_slot_filling()
# 更新状态
- self.state.stepStatus = StepStatus.RUNNING
- self.runtime.time = round(datetime.now(UTC).timestamp(), 2)
+ self.task.state.stepStatus = StepStatus.RUNNING
+ self.task.runtime.time = round(datetime.now(UTC).timestamp(), 2)
# 推送输入
await self.push_message(EventType.STEP_INPUT.value, self.obj.input)
@@ -230,27 +245,27 @@ class StepExecutor(BaseExecutor):
content = await self._process_chunk(iterator, to_user=self.obj.to_user)
except Exception as e:
logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤")
- self.state.stepStatus = StepStatus.ERROR
+ self.task.state.stepStatus = StepStatus.ERROR
await self.push_message(EventType.STEP_OUTPUT.value, {})
if isinstance(e, CallError):
- self.state.errorMessage = {
+ self.task.state.errorMessage = {
"err_msg": e.message,
"data": e.data,
}
else:
- self.state.errorMessage = {
+ self.task.state.errorMessage = {
"data": {},
}
return
# 更新执行状态
- self.state.stepStatus = StepStatus.SUCCESS
+ self.task.state.stepStatus = StepStatus.SUCCESS
await TaskManager.update_task_token(
- self.task.id,
+ self.task.metadata.id,
input_token=self.obj.tokens.input_tokens,
output_token=self.obj.tokens.output_tokens,
)
- self.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.runtime.time
+ self.task.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.time
# 更新history
if isinstance(content, str):
@@ -260,18 +275,18 @@ class StepExecutor(BaseExecutor):
# 更新context
history = ExecutorHistory(
- taskId=self.task.id,
- executorId=self.state.executorId,
- executorName=self.state.executorName,
- executorStatus=self.state.executorStatus,
+ taskId=self.task.metadata.id,
+ executorId=self.task.state.executorId,
+ executorName=self.task.state.executorName,
+ executorStatus=self.task.state.executorStatus,
stepId=self.step.step_id,
stepName=self.step.step.name,
stepDescription=self.step.step.description,
- stepStatus=self.state.stepStatus,
+ stepStatus=self.task.state.stepStatus,
inputData=self.obj.input,
outputData=output_data,
)
- self.context.append(history)
+ self.task.context.append(history)
# 推送输出
await self.push_message(EventType.STEP_OUTPUT.value, output_data)
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index d28dc68497fbcc9d0ceefdce9b81b059e3dfea2b..ec2e004d55a4bfdbf66f67aedeb81b68bf9ba02a 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -28,7 +28,7 @@ class MCPHost:
"""MCP宿主服务"""
def __init__(
- self, user_sub: str, task_id: uuid.UUID, runtime_id: uuid.UUID, runtime_name: str,
+ self, user_sub: str, task_id: uuid.UUID, runtime_id: str, runtime_name: str,
language: LanguageType,
) -> None:
"""初始化MCP宿主"""
@@ -63,7 +63,7 @@ class MCPHost:
async def assemble_memory(self) -> str:
"""组装记忆"""
- task = await TaskManager.get_task_by_task_id(self._task_id)
+ task = await TaskManager.get_task_data_by_task_id(self._task_id)
if not task:
logger.error("任务 %s 不存在", self._task_id)
return ""
@@ -107,7 +107,7 @@ class MCPHost:
executorId=self._runtime_id,
executorName=self._runtime_name,
executorStatus=ExecutorStatus.RUNNING,
- stepId=tool.id,
+ stepId=uuid.uuid4(),
stepName=tool.toolName,
# description是规划的实际内容
stepDescription=plan_item.content,
@@ -117,12 +117,12 @@ class MCPHost:
)
# 保存到task
- task = await TaskManager.get_task_by_task_id(self._task_id)
+ task = await TaskManager.get_task_data_by_task_id(self._task_id)
if not task:
logger.error("任务 %s 不存在", self._task_id)
return {}
self._context_list.append(context.id)
- context.append(context.model_dump(exclude_none=True, by_alias=True))
+        task.context.append(context)
await TaskManager.save_task(self._task_id, task)
return output_data
diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py
index 1b273ba897d02d3ab4d229c50cec776721cdb1a3..cad4256bae6d481bd496cddf1a5b98a983d198f8 100644
--- a/apps/scheduler/mcp_agent/base.py
+++ b/apps/scheduler/mcp_agent/base.py
@@ -5,7 +5,7 @@ import logging
from typing import Any
from apps.llm.function import JsonGenerator
-from apps.llm.reasoning import ReasoningLLM
+from apps.models.llm import LLMData
logger = logging.getLogger(__name__)
@@ -13,7 +13,14 @@ logger = logging.getLogger(__name__)
class MCPBase:
"""MCP基类"""
- llm: ReasoningLLM
+ user_sub: str
+
+ def __init__(self, user_sub: str) -> None:
+ """初始化MCP基类"""
+ self.user_sub = user_sub
+
+ async def init(self) -> None:
+ pass
async def get_resoning_result(self, prompt: str) -> str:
"""获取推理结果"""
diff --git a/apps/scheduler/pool/check.py b/apps/scheduler/pool/check.py
index 55f8d8c76783d5b4e9fc87b797071e2f142e3368..2e1d5f0e571076d7732b1888ccbbc4e638dc606e 100644
--- a/apps/scheduler/pool/check.py
+++ b/apps/scheduler/pool/check.py
@@ -55,7 +55,7 @@ class FileChecker:
return self.hashes[path_diff.as_posix()] != previous_hashes
- async def diff(self, check_type: MetadataType) -> tuple[list[str], list[str]]:
+ async def diff(self, check_type: MetadataType) -> tuple[list[uuid.UUID], list[uuid.UUID]]:
"""生成更新列表和删除列表"""
async with postgres.session() as session:
# 判断类型
diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py
index fce4344e159f3d2d6a909df26d834b8c051521f1..e41363e917102d582db62d0fefe32cb4e7b39ece 100644
--- a/apps/scheduler/pool/pool.py
+++ b/apps/scheduler/pool/pool.py
@@ -111,18 +111,18 @@ class Pool:
# 批量删除App
for app in changed_app:
- await AppLoader.delete(uuid.UUID(app), is_reload=True)
+ await AppLoader.delete(app, is_reload=True)
for app in deleted_app:
- await AppLoader.delete(uuid.UUID(app))
+ await AppLoader.delete(app)
# 批量加载App
for app in changed_app:
hash_key = Path("app/" + str(app)).as_posix()
if hash_key in checker.hashes:
try:
- await AppLoader.load(uuid.UUID(app), checker.hashes[hash_key])
+ await AppLoader.load(app, checker.hashes[hash_key])
except Exception as e: # noqa: BLE001
- await AppLoader.delete(uuid.UUID(app), is_reload=True)
+ await AppLoader.delete(app, is_reload=True)
logger.warning("[Pool] 加载App %s 失败: %s", app, str(e))
# 载入MCP
diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py
index 9d46a3fc51ba71dc8556398ac42c678d24b570de..da40cc718487b116dc2592f51ff4cb7644542700 100644
--- a/apps/scheduler/scheduler/flow.py
+++ b/apps/scheduler/scheduler/flow.py
@@ -42,28 +42,3 @@ class FlowChooser:
await TaskManager.update_task_token(self.task_id, select_obj.input_tokens, select_obj.output_tokens)
return top_flow
-
-
- async def choose_flow(self) -> RequestDataApp | None:
- """
- 依据用户的输入和选择,构造对应的Flow。
-
- - 当用户没有选择任何app时,直接进行智能问答
- - 当用户选择了特定的app时,在plugin内挑选最适合的flow
- """
- if not self._user_selected or not self._user_selected.app_id:
- return None
-
- if self._user_selected.flow_id:
- return self._user_selected
-
- top_flow = await self.get_top_flow()
- # FIXME KnowledgeBase不是UUID,要改个值
- if top_flow == "KnowledgeBase":
- return None
-
- return RequestDataApp(
- appId=self._user_selected.app_id,
- flowId=top_flow,
- params=None,
- )
diff --git a/apps/schemas/task.py b/apps/schemas/task.py
index ad8244af9746460e7e37995c1bea0a2ca2bd799e..8b74073d36ced56920c134748e6ce42b2dc31f85 100644
--- a/apps/schemas/task.py
+++ b/apps/schemas/task.py
@@ -6,9 +6,20 @@ from typing import Any
from pydantic import BaseModel, Field
+from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime
+
from .flow import Step
+class TaskData(BaseModel):
+ """任务数据"""
+
+ metadata: Task = Field(description="任务")
+ runtime: TaskRuntime = Field(description="任务运行时数据")
+ state: ExecutorCheckpoint | None = Field(description="执行状态")
+ context: list[ExecutorHistory] = Field(description="执行历史")
+
+
class CheckpointExtra(BaseModel):
"""Executor额外数据"""
@@ -16,11 +27,6 @@ class CheckpointExtra(BaseModel):
error_message: str = Field(description="错误信息", default="")
retry_times: int = Field(description="当前步骤重试次数", default=0)
-
-class TaskExtra(BaseModel):
- """任务额外数据"""
-
-
class StepQueueItem(BaseModel):
"""步骤栈中的元素"""
diff --git a/apps/services/conversation.py b/apps/services/conversation.py
index 131c7a0ec1cdfa22ee2373df5d297f98ec5b4882..aa251e0efc32fd18afc4e881ae643b7f6e8d772b 100644
--- a/apps/services/conversation.py
+++ b/apps/services/conversation.py
@@ -158,7 +158,7 @@ class ConversationManager:
await session.delete(conv)
await session.commit()
- await TaskManager.delete_task_history_checkpoint_by_conversation_id(conversation_id)
+ await TaskManager.delete_tasks_by_conversation_id(conversation_id)
@staticmethod
diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py
index ed1eb6952e8f30ad5071dea901fe641900ab0fd1..b44dc729e4bf144d0b901ba75dd90c8b2e6320a2 100644
--- a/apps/services/flow_validate.py
+++ b/apps/services/flow_validate.py
@@ -3,6 +3,7 @@
import collections
import logging
+import uuid
from typing import TYPE_CHECKING
from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError
@@ -94,7 +95,7 @@ class FlowService:
return flow_item
@staticmethod
- async def _validate_node_ids(nodes: list[NodeItem]) -> tuple[str, str]:
+ async def _validate_node_ids(nodes: list[NodeItem]) -> tuple[uuid.UUID, uuid.UUID]:
"""验证节点ID的唯一性并获取起始和终止节点ID,当节点ID重复或起始/终止节点数量不为1时抛出异常"""
ids = set()
start_cnt = 0
@@ -128,7 +129,7 @@ class FlowService:
return start_id, end_id
@staticmethod
- async def validate_flow_illegal(flow_item: FlowItem) -> tuple[str, str]:
+ async def validate_flow_illegal(flow_item: FlowItem) -> tuple[uuid.UUID, uuid.UUID]:
"""验证流程图是否合法;当流程图不合法时抛出异常"""
# 验证节点ID并获取起始和终止节点
start_id, end_id = await FlowService._validate_node_ids(flow_item.nodes)
diff --git a/apps/services/llm.py b/apps/services/llm.py
index 0cd31fefa4d03c049b76df0bca7b887f2df7fcd4..ceb76c6dcee50824562d3686ea74098564b567d3 100644
--- a/apps/services/llm.py
+++ b/apps/services/llm.py
@@ -56,7 +56,7 @@ class LLMManager:
logger.error("[LLMManager] 用户 %s 不存在", user_sub)
return None
- return user.defaultLLM
+ return user.reasoningLLM
@staticmethod
@@ -186,7 +186,7 @@ class LLMManager:
if not user:
err = f"[LLMManager] 用户 {user_sub} 不存在"
raise ValueError(err)
- user.defaultLLM = None
+ user.reasoningLLM = None
await session.commit()
@@ -203,7 +203,7 @@ class LLMManager:
if not user:
err = f"[LLMManager] 用户 {user_sub} 不存在"
raise ValueError(err)
- user.defaultLLM = llm_id
+ user.reasoningLLM = llm_id
await session.commit()
diff --git a/apps/services/rag.py b/apps/services/rag.py
index caf2930c75c7884856f5725cc0fe992c8996e2f5..af5bd506b532ef10db4d4a885bf96145a6502e76 100644
--- a/apps/services/rag.py
+++ b/apps/services/rag.py
@@ -16,12 +16,13 @@ from apps.llm.patterns.rewrite import QuestionRewrite
from apps.llm.reasoning import ReasoningLLM
from apps.llm.token import TokenCalculator
from apps.models.llm import LLMData
-from apps.schemas.config import LLMConfig
from apps.schemas.enum_var import EventType, LanguageType
from apps.schemas.rag_data import RAGQueryReq
+from apps.services.llm import LLMManager
from apps.services.session import SessionManager
logger = logging.getLogger(__name__)
+CHUNK_ELEMENT_TOKENS = 5
class RAG:
@@ -127,7 +128,7 @@ class RAG:
@staticmethod
async def get_doc_info_from_rag(
- user_sub: str, max_tokens: int, doc_ids: list[str], data: RAGQueryReq,
+ user_sub: str, max_tokens: int | None, doc_ids: list[str], data: RAGQueryReq,
) -> list[dict[str, Any]]:
"""获取RAG服务的文档信息"""
session_id = await SessionManager.get_session_by_user_sub(user_sub)
@@ -182,32 +183,8 @@ class RAG:
doc_info_list = []
doc_cnt = 0
doc_id_map = {}
- leave_tokens = max_tokens
- token_calculator = TokenCalculator()
- for doc_chunk in doc_chunk_list:
- if doc_chunk["docId"] not in doc_id_map:
- doc_cnt += 1
- doc_id_map[doc_chunk["docId"]] = doc_cnt
- doc_index = doc_id_map[doc_chunk["docId"]]
- leave_tokens -= token_calculator.calculate_token_length(
- messages=[
- {
- "role": "user",
- "content": f"""""",
- },
- {"role": "user", "content": ""},
- ],
- pure_text=True,
- )
- tokens_of_chunk_element = token_calculator.calculate_token_length(
- messages=[
- {"role": "user", "content": ""},
- {"role": "user", "content": ""},
- ],
- pure_text=True,
- )
- doc_cnt = 0
- doc_id_map = {}
+ remaining_tokens = max_tokens * 0.8
+
for doc_chunk in doc_chunk_list:
if doc_chunk["docId"] not in doc_id_map:
doc_cnt += 1
@@ -231,16 +208,18 @@ class RAG:
})
doc_id_map[doc_chunk["docId"]] = doc_cnt
doc_index = doc_id_map[doc_chunk["docId"]]
+
if bac_info:
bac_info += "\n\n"
bac_info += f""""""
+
for chunk in doc_chunk["chunks"]:
- if leave_tokens <= tokens_of_chunk_element:
+ if remaining_tokens <= CHUNK_ELEMENT_TOKENS:
break
chunk_text = chunk["text"]
- chunk_text = TokenCalculator.get_k_tokens_words_from_content(
- content=chunk_text, k=leave_tokens)
- leave_tokens -= token_calculator.calculate_token_length(messages=[
+ chunk_text = TokenCalculator().get_k_tokens_words_from_content(
+                content=chunk_text, k=int(remaining_tokens))
+ remaining_tokens -= TokenCalculator().calculate_token_length(messages=[
{"role": "user", "content": ""},
{"role": "user", "content": chunk_text},
{"role": "user", "content": ""},
@@ -256,21 +235,20 @@ class RAG:
@staticmethod
async def chat_with_llm_base_on_rag( # noqa: C901, PLR0913
user_sub: str,
- llm: LLMData,
+ llm_id: str,
history: list[dict[str, str]],
doc_ids: list[str],
data: RAGQueryReq,
language: LanguageType = LanguageType.CHINESE,
) -> AsyncGenerator[str, None]:
"""获取RAG服务的结果"""
- reasion_llm = ReasoningLLM(
- LLMConfig(
- endpoint=llm.openaiBaseUrl,
- key=llm.openaiAPIKey,
- model=llm.modelName,
- max_tokens=llm.maxToken,
- ),
- )
+ llm_config = await LLMManager.get_llm(llm_id)
+ if not llm_config:
+ err = "[RAG] 未设置问答所用LLM"
+ logger.error(err)
+ raise RuntimeError(err)
+ reasion_llm = ReasoningLLM(llm_config)
+
if history:
try:
question_obj = QuestionRewrite()
@@ -280,9 +258,11 @@ class RAG:
except Exception:
logger.exception("[RAG] 问题重写失败")
doc_chunk_list = await RAG.get_doc_info_from_rag(
- user_sub=user_sub, max_tokens=llm.maxToken, doc_ids=doc_ids, data=data)
+ user_sub=user_sub, max_tokens=llm_config.maxToken, doc_ids=doc_ids, data=data,
+ )
bac_info, doc_info_list = await RAG.assemble_doc_info(
- doc_chunk_list=doc_chunk_list, max_tokens=llm.maxToken)
+ doc_chunk_list=doc_chunk_list, max_tokens=llm_config.maxToken,
+ )
messages = [
*history,
{
@@ -323,11 +303,11 @@ class RAG:
buffer = ""
async for chunk in reasion_llm.call(
messages,
- max_tokens=llm.maxToken,
+ max_tokens=llm_config.maxToken,
streaming=True,
temperature=0.7,
result_only=False,
- model=llm.modelName,
+ model=llm_config.modelName,
):
tmp_chunk = buffer + chunk
# 防止脚注被截断
diff --git a/apps/services/task.py b/apps/services/task.py
index a2771f26c0a073831283af7217c318719e23372c..a3c1bff1fa9509b30fc28a10e0271a5ef1acc112 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -9,7 +9,7 @@ from sqlalchemy import and_, delete, select, update
from apps.common.postgres import postgres
from apps.models.conversation import Conversation
from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime
-from apps.schemas.request_data import RequestData
+from apps.schemas.task import TaskData
logger = logging.getLogger(__name__)
@@ -48,115 +48,65 @@ class TaskManager:
@staticmethod
- async def get_task_by_task_id(task_id: uuid.UUID) -> Task | None:
+ async def get_task_data_by_task_id(task_id: uuid.UUID, context_length: int | None = None) -> TaskData | None:
"""根据task_id获取任务"""
async with postgres.session() as session:
- return (await session.scalars(
+ task_data = (await session.scalars(
select(Task).where(Task.id == task_id),
)).one_or_none()
+ if not task_data:
+ logger.error("[TaskManager] 任务不存在 %s", task_id)
+ return None
-
- @staticmethod
- async def get_task_runtime_by_task_id(task_id: uuid.UUID) -> TaskRuntime | None:
- """根据task_id获取任务运行时"""
- async with postgres.session() as session:
- return (await session.scalars(
+ runtime = (await session.scalars(
select(TaskRuntime).where(TaskRuntime.taskId == task_id),
)).one_or_none()
+ if not runtime:
+ runtime = TaskRuntime(
+ taskId=task_id,
+ inputToken=0,
+ outputToken=0,
+ )
-
- @staticmethod
- async def get_task_state_by_task_id(task_id: uuid.UUID) -> ExecutorCheckpoint | None:
- """根据task_id获取任务状态"""
- async with postgres.session() as session:
- return (await session.scalars(
+ state = (await session.scalars(
select(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == task_id),
)).one_or_none()
-
- @staticmethod
- async def get_context_by_task_id(task_id: uuid.UUID, length: int | None = None) -> list[ExecutorHistory]:
- """根据task_id获取flow信息"""
- async with postgres.session() as session:
- return list((await session.scalars(
+ if context_length == 0:
+ context = []
+ else:
+ context = list((await session.scalars(
select(ExecutorHistory).where(
ExecutorHistory.taskId == task_id,
- ).order_by(ExecutorHistory.updatedAt.desc()).limit(length),
+ ).order_by(ExecutorHistory.updatedAt.desc()).limit(context_length),
)).all())
-
- @staticmethod
- async def init_new_task(
- user_sub: str,
- session_id: str | None = None,
- post_body: RequestData | None = None,
- ) -> Task:
- """获取任务块"""
- return Task(
- _id=str(uuid.uuid4()),
- ids=TaskIds(
- user_sub=user_sub if user_sub else "",
- session_id=session_id if session_id else "",
- conversation_id=post_body.conversation_id,
- ),
- question=post_body.question if post_body else "",
- tokens=TaskTokens(),
- runtime=TaskRuntime(),
- )
-
- @staticmethod
- async def save_flow_context(task_id: str, flow_context: list[ExecutorHistory]) -> None:
- """保存flow信息到flow_context"""
- if not flow_context:
- return
-
- flow_context_collection = MongoDB().get_collection("flow_context")
- try:
- for history in flow_context:
- # 查找是否存在
- current_context = await flow_context_collection.find_one({
- "task_id": task_id,
- "_id": history.id,
- })
- if current_context:
- await flow_context_collection.update_one(
- {"_id": current_context["_id"]},
- {"$set": history.model_dump(exclude_none=True, by_alias=True)},
- )
- else:
- await flow_context_collection.insert_one(history.model_dump(exclude_none=True, by_alias=True))
- except Exception:
- logger.exception("[TaskManager] 保存flow执行记录失败")
+ return TaskData(
+ metadata=task_data,
+ runtime=runtime,
+ state=state,
+ context=context,
+ )
@staticmethod
async def delete_task_by_task_id(task_id: uuid.UUID) -> None:
"""通过task_id删除Task信息"""
async with postgres.session() as session:
- task = (await session.scalars(
- select(Task).where(Task.id == task_id),
- )).one_or_none()
- if task:
- await session.delete(task)
-
-
- @staticmethod
- async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> list[uuid.UUID]:
- """通过ConversationID删除Task信息"""
- async with postgres.session() as session:
- task_ids = []
- tasks = (await session.scalars(
- select(Task).where(Task.conversationId == conversation_id),
- )).all()
- for task in tasks:
- task_ids.append(str(task.id))
- await session.delete(task)
+ await session.execute(
+ delete(Task).where(Task.id == task_id),
+ )
+ await session.execute(
+ delete(TaskRuntime).where(TaskRuntime.taskId == task_id),
+ )
+ await session.execute(
+ delete(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == task_id),
+ )
await session.commit()
- return task_ids
@staticmethod
- async def delete_task_history_checkpoint_by_conversation_id(conversation_id: uuid.UUID) -> None:
+ async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> None:
"""通过ConversationID删除Task信息"""
# 删除Task
task_ids = []
@@ -172,6 +122,29 @@ class TaskManager:
await session.execute(
delete(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId.in_(task_ids)),
)
+ await session.execute(
+ delete(TaskRuntime).where(TaskRuntime.taskId.in_(task_ids)),
+ )
+ await session.commit()
+
+
+ @staticmethod
+ async def delete_task_context_by_task_id(task_id: uuid.UUID) -> None:
+ """通过task_id删除TaskContext信息"""
+ async with postgres.session() as session:
+ await session.execute(
+ delete(ExecutorHistory).where(ExecutorHistory.taskId == task_id),
+ )
+ await session.commit()
+
+
+ @staticmethod
+ async def delete_task_context_by_conversation_id(conversation_id: uuid.UUID) -> None:
+ """通过ConversationID删除TaskContext信息"""
+ async with postgres.session() as session:
+ task_ids = list((await session.scalars(
+ select(Task.id).where(Task.conversationId == conversation_id),
+ )).all())
await session.execute(
delete(ExecutorHistory).where(ExecutorHistory.taskId.in_(task_ids)),
)