From 92b4bb6c8a899dd4af8fe876f814f0a76a78a6a4 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 25 Feb 2025 21:08:04 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E5=BC=80=E5=8F=91Scheduler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 59 ++---- apps/common/task.py | 2 + apps/constants.py | 2 + apps/entities/message.py | 19 +- apps/entities/node.py | 3 +- apps/entities/scheduler.py | 26 +-- apps/entities/task.py | 11 +- apps/llm/patterns/domain.py | 2 +- apps/llm/patterns/executor.py | 41 ++-- apps/llm/patterns/facts.py | 4 +- apps/llm/patterns/json.py | 22 +- apps/llm/patterns/recommend.py | 6 +- apps/llm/patterns/rewrite.py | 9 + apps/llm/patterns/select.py | 7 +- apps/llm/reasoning.py | 2 +- apps/main.py | 4 + apps/manager/node.py | 76 +++++++ apps/manager/task.py | 13 +- apps/routers/chat.py | 30 ++- apps/scheduler/call/__init__.py | 8 - apps/scheduler/call/api.py | 26 ++- apps/scheduler/call/convert.py | 12 +- apps/scheduler/call/core.py | 75 ++----- apps/scheduler/call/direct.py | 0 apps/scheduler/call/llm.py | 71 ++++--- apps/scheduler/call/rag.py | 46 ++--- apps/scheduler/call/render/render.py | 6 +- apps/scheduler/call/sql.py | 6 +- apps/scheduler/call/suggest.py | 4 +- apps/scheduler/executor/flow.py | 141 +++++++------ apps/scheduler/executor/message.py | 136 ++++--------- apps/scheduler/pool/check.py | 8 +- apps/scheduler/pool/loader/app.py | 2 + apps/scheduler/pool/loader/call.py | 30 --- apps/scheduler/pool/loader/metadata.py | 2 +- apps/scheduler/pool/loader/service.py | 2 + apps/scheduler/pool/pool.py | 37 +++- apps/scheduler/scheduler/context.py | 100 ++++++++- apps/scheduler/scheduler/flow.py | 4 +- apps/scheduler/scheduler/message.py | 69 +++---- apps/scheduler/scheduler/scheduler.py | 272 ++++++++++--------------- apps/scheduler/slot/slot.py | 6 +- 42 files changed, 674 insertions(+), 727 deletions(-) delete mode 100644 apps/scheduler/call/direct.py diff --git 
a/apps/common/queue.py b/apps/common/queue.py index e327568b1..3ae2e3baf 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -9,7 +9,7 @@ import ray from redis.exceptions import ResponseError from apps.constants import LOGGER -from apps.entities.enum_var import EventType, StepStatus +from apps.entities.enum_var import EventType from apps.entities.message import ( HeartbeatData, MessageBase, @@ -20,13 +20,14 @@ from apps.entities.task import TaskBlock from apps.models.redis import RedisConnectionPool +@ray.remote class MessageQueue: """包装SimpleQueue,加入组装消息、自动心跳等机制""" _heartbeat_interval: float = 3.0 - async def init(self, task_id: str, *, enable_heartbeat: bool = False) -> None: + async def init(self, task_id: str) -> None: """异步初始化消息队列 :param task_id: 任务ID @@ -38,11 +39,10 @@ class MessageQueue: self._consumer_name = "consumer" self._close = False - if enable_heartbeat: - self._heartbeat_task = asyncio.create_task(self._heartbeat()) + self._heartbeat_task = asyncio.get_running_loop().create_task(self._heartbeat()) - async def push_output(self, event_type: EventType, data: dict[str, Any]) -> None: + async def push_output(self, task: TaskBlock, event_type: EventType, data: dict[str, Any]) -> None: """组装用于向用户(前端/Shell端)输出的消息""" client = RedisConnectionPool.get_redis_connection() @@ -50,46 +50,31 @@ class MessageQueue: await client.publish(self._stream_name, "[DONE]") return - task = ray.get_actor("task") - tcb: TaskBlock = await task.get_task.remote(self._task_id) - - # 计算创建Task到现在的时间 - used_time = round((datetime.now(timezone.utc).timestamp() - tcb.record.metadata.time), 2) + # 计算当前Step时间 + step_time = round((datetime.now(timezone.utc).timestamp() - task.record.metadata.time), 3) metadata = MessageMetadata( - time=used_time, - input_tokens=tcb.record.metadata.input_tokens, - output_tokens=tcb.record.metadata.output_tokens, + time=step_time, + input_tokens=task.record.metadata.input_tokens, + output_tokens=task.record.metadata.output_tokens, ) - if 
tcb.flow_state: - history_ids = tcb.new_context - if not history_ids: - # 如果new_history为空,则说明是第一次执行,创建一个空值 - flow = MessageFlow( - appId=tcb.flow_state.app_id, - flowId=tcb.flow_state.name, - stepId="start", - stepStatus=StepStatus.RUNNING, - ) - else: - # 如果new_history不为空,则说明是继续执行,使用最后一个FlowHistory - history = tcb.flow_context[tcb.flow_state.step_id] - - flow = MessageFlow( - appId=history.app_id, - flowId=history.flow_id, - stepId=history.step_id, - stepStatus=history.status, - ) + if task.flow_state: + # 如果使用了Flow + flow = MessageFlow( + appId=task.flow_state.app_id, + flowId=task.flow_state.name, + stepId=task.flow_state.step_id, + stepStatus=task.flow_state.status, + ) else: flow = None message = MessageBase( event=event_type, - id=tcb.record.id, - groupId=tcb.record.group_id, - conversationId=tcb.record.conversation_id, - taskId=tcb.record.task_id, + id=task.record.id, + groupId=task.record.group_id, + conversationId=task.record.conversation_id, + taskId=task.record.task_id, metadata=metadata, flow=flow, content=data, diff --git a/apps/common/task.py b/apps/common/task.py index 9443cd594..7970b22eb 100644 --- a/apps/common/task.py +++ b/apps/common/task.py @@ -29,11 +29,13 @@ class Task: """初始化TaskManager""" self._task_map: dict[str, TaskBlock] = {} + async def update_token_summary(self, task_id: str, input_num: int, output_num: int) -> None: """更新对应task_id的Token统计数据""" self._task_map[task_id].record.metadata.input_tokens += input_num self._task_map[task_id].record.metadata.output_tokens += output_num + async def get_task(self, task_id: Optional[str] = None, session_id: Optional[str] = None, post_body: Optional[RequestData] = None) -> TaskBlock: """获取任务块""" # 如果task_map里面已经有了,则直接返回副本 diff --git a/apps/constants.py b/apps/constants.py index 994b74dec..b50adbc53 100644 --- a/apps/constants.py +++ b/apps/constants.py @@ -26,6 +26,8 @@ APP_DIR = "app" FLOW_DIR = "flow" # 日志记录器 LOGGER = logging.getLogger("ray") +# Scheduler进程数 +SCHEDULER_REPLICAS = 4 
REASONING_BEGIN_TOKEN = [ "", diff --git a/apps/entities/message.py b/apps/entities/message.py index 9988f6d1a..60f79865d 100644 --- a/apps/entities/message.py +++ b/apps/entities/message.py @@ -80,26 +80,11 @@ class FlowStartContent(BaseModel): params: dict[str, Any] = Field(description="预先提供的参数") -class StepInputContent(BaseModel): - """step.input消息的content""" - - call_type: str = Field(description="Call类型", alias="callType") - params: dict[str, Any] = Field(description="Step最后输入的参数") - - -class StepOutputContent(BaseModel): - """step.output消息的content""" - - call_type: str = Field(description="Call类型", alias="callType") - message: str = Field(description="LLM大模型输出的自然语言文本") - output: dict[str, Any] = Field(description="Step输出的结构化数据") - - class FlowStopContent(BaseModel): """flow.stop消息的content""" - type: FlowOutputType = Field(description="Flow输出的类型") - data: Optional[dict[str, Any]] = Field(description="Flow输出的数据") + type: Optional[FlowOutputType] = Field(description="Flow输出的类型", default=None) + data: Optional[dict[str, Any]] = Field(description="Flow输出的数据", default=None) class MessageBase(HeartbeatData): diff --git a/apps/entities/node.py b/apps/entities/node.py index 530fe6717..03aab2e7e 100644 --- a/apps/entities/node.py +++ b/apps/entities/node.py @@ -12,6 +12,7 @@ class APINodeInput(BaseModel): param_schema: Optional[dict[str, Any]] = Field(description="API节点输入参数Schema", default=None) body_schema: Optional[dict[str, Any]] = Field(description="API节点输入请求体Schema", default=None) + class APINodeOutput(BaseModel): """API节点覆盖输出""" @@ -24,5 +25,3 @@ class APINode(NodePool): call_id: str = "API" override_input: Optional[APINodeInput] = Field(description="API节点输入覆盖", default=None) override_output: Optional[APINodeOutput] = Field(description="API节点输出覆盖", default=None) - - diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py index 658624dda..3a23d0f5e 100644 --- a/apps/entities/scheduler.py +++ b/apps/entities/scheduler.py @@ -6,11 +6,10 @@ from typing 
import Any from pydantic import BaseModel, Field -from apps.common.queue import MessageQueue -from apps.entities.task import FlowHistory, RequestDataApp +from apps.entities.task import FlowStepHistory -class SysCallVars(BaseModel): +class CallVars(BaseModel): """所有Call都需要接受的参数。包含用户输入、上下文信息、Step的输出记录等 这一部分的参数由Executor填充,用户无法修改 @@ -18,7 +17,7 @@ class SysCallVars(BaseModel): background: str = Field(description="上下文信息") question: str = Field(description="改写后的用户输入") - history: list[FlowHistory] = Field(description="Executor中历史工具的结构化数据", default=[]) + history: list[FlowStepHistory] = Field(description="Executor中历史工具的结构化数据", default=[]) task_id: str = Field(description="任务ID") session_id: str = Field(description="当前用户的Session ID") extra: dict[str, Any] = Field(description="其他Executor设置的、用户不可修改的参数", default={}) @@ -32,25 +31,6 @@ class ExecutorBackground(BaseModel): thought: str = Field(description="之前Executor的思考内容", default="") -class SysExecVars(BaseModel): - """Executor状态 - - 由系统自动传递给Executor - """ - - queue: MessageQueue = Field(description="当前Executor关联的Queue") - question: str = Field(description="当前Agent的目标") - task_id: str = Field(description="当前Executor关联的TaskID") - session_id: str = Field(description="当前用户的Session ID") - app_data: RequestDataApp = Field(description="传递给Executor中Call的参数") - background: ExecutorBackground = Field(description="当前Executor的背景信息") - - class Config: - """允许任意类型""" - - arbitrary_types_allowed = True - - class CallError(Exception): """Call错误""" diff --git a/apps/entities/task.py b/apps/entities/task.py index 26a788ede..77737fb8b 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -12,7 +12,7 @@ from apps.entities.enum_var import StepStatus from apps.entities.record import RecordData -class FlowHistory(BaseModel): +class FlowStepHistory(BaseModel): """任务执行历史;每个Executor每个步骤执行后都会创建 Collection: flow_history @@ -40,7 +40,7 @@ class ExecutorState(BaseModel): app_id: str = Field(description="应用ID") # 运行时数据 thought: str = 
Field(description="大模型的思考内容", default="") - slot_data: dict[str, Any] = Field(description="待使用的参数", default={}) + filled_data: dict[str, Any] = Field(description="待使用的参数", default={}) remaining_schema: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) @@ -50,7 +50,7 @@ class TaskBlock(BaseModel): session_id: str = Field(description="浏览器会话ID") record: RecordData = Field(description="当前任务执行过程关联的Record") flow_state: Optional[ExecutorState] = Field(description="Flow的状态", default=None) - flow_context: dict[str, FlowHistory] = Field(description="Flow的执行信息", default={}) + flow_context: dict[str, FlowStepHistory] = Field(description="Flow的执行信息", default={}) new_context: list[str] = Field(description="Flow的执行信息(增量ID)", default=[]) @@ -76,3 +76,8 @@ class TaskData(BaseModel): ended: bool = False updated_at: float = Field(default_factory=lambda: round(datetime.now(tz=timezone.utc).timestamp(), 3)) + +class SchedulerResult(BaseModel): + """调度器返回结果""" + + used_docs: list[str] = Field(description="已使用的文档ID列表") diff --git a/apps/llm/patterns/domain.py b/apps/llm/patterns/domain.py index 159e5b40b..40c75e169 100644 --- a/apps/llm/patterns/domain.py +++ b/apps/llm/patterns/domain.py @@ -73,5 +73,5 @@ class Domain(CorePattern): {"role": "assistant", "content": result}, ] - output = await Json().generate(task_id, conversation=messages, spec=self.slot_schema) + output = await Json().generate("", conversation=messages, spec=self.slot_schema) return output["keywords"] diff --git a/apps/llm/patterns/executor.py b/apps/llm/patterns/executor.py index 4a5dd8e1c..086ee7174 100644 --- a/apps/llm/patterns/executor.py +++ b/apps/llm/patterns/executor.py @@ -24,7 +24,7 @@ class ExecutorThought(CorePattern): 注意: 工具的相关信息在标签中给出。 为了使你更好的理解发生了什么,你之前的思考过程在标签中给出。 - 输出时请不要包含XML标签,请精准、简明。 + 输出时请不要包含XML标签,输出时请保持简明和清晰。 @@ -51,6 +51,7 @@ class ExecutorThought(CorePattern): """处理Prompt""" super().__init__(system_prompt, user_prompt) + async def generate(self, task_id: str, **kwargs) -> str: 
# noqa: ANN003 """调用大模型,生成对话总结""" try: @@ -83,11 +84,13 @@ class ExecutorBackground(CorePattern): """使用大模型进行生成Executor初始背景""" user_prompt: str = r""" - 根据对话上文,结合给定的AI助手思考过程,生成一个完整的背景总结。这个总结将用于后续对话的上下文理解。 - 生成总结的要求如下: - 1. 突出重要信息点,例如时间、地点、人物、事件等。 - 2. 下面给出的事实条目若与历史记录有关,则可以在生成总结时作为已知信息。 - 3. 确保信息准确性,不得编造信息。 + + + 根据,结合给定的AI助手思考过程,生成一个完整的背景总结。这个总结将用于后续对话的上下文理解。 + 生成总结的要求如下: + 1. 突出重要信息点,例如时间、地点、人物、事件等。 + 2. 下面给出的事实条目若与历史记录有关,则可以在生成总结时作为已知信息。 + 3. 确保信息准确性,不得编造信息。 4. 总结应少于1000个字。 思考过程(在标签中): @@ -141,28 +144,36 @@ class FinalThought(CorePattern): """使用大模型生成Executor的最终结果""" user_prompt: str = r""" - 你是AI智能助手,请回答用户的问题并满足以下要求: - 1. 使用中文回答问题,不要使用其他语言。 - 2. 回答应当语气友好、通俗易懂,并包含尽可能完整的信息。 - 3. 回答时应结合思考过程。 + + 你是AI智能助手,请回答用户的问题并满足以下要求: - 用户的问题是: - {question} + 1. 使用中文回答问题,不要使用其他语言。 + 2. 回答应当语气友好、通俗易懂,并包含尽可能完整的信息。 + 3. 回答时应结合思考过程。 + 4. 输出时请不要包含XML标签,不要编造任何信息。 + + 用户的问题将在标签中给出,你之前的思考过程将在标签中给出。 + + + + {question} + - 思考过程(在标签中): - {thought}{output} + {thought}{output} 现在,请根据以上信息进行回答: """ """用户提示词""" + def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None: """初始化ExecutorResult模式""" super().__init__(system_prompt, user_prompt) - async def generate(self, task_id: str, **kwargs) -> AsyncGenerator[str, None]: # noqa: ANN003 + + async def generate(self, task_id: str, **kwargs) -> AsyncGenerator[str, Any]: # noqa: ANN003 """进行ExecutorResult生成""" question: str = kwargs["question"] thought: str = kwargs["thought"] diff --git a/apps/llm/patterns/facts.py b/apps/llm/patterns/facts.py index cf9451e59..a61e51785 100644 --- a/apps/llm/patterns/facts.py +++ b/apps/llm/patterns/facts.py @@ -29,11 +29,9 @@ class Facts(CorePattern): 2. 事实必须清晰、简洁、易于理解。必须少于30个字。 3. 
必须按照以下JSON格式输出: - ```json {{ "facts": ["事实1", "事实2", "事实3"] }} - ``` @@ -89,7 +87,7 @@ class Facts(CorePattern): result += chunk messages += [{"role": "assistant", "content": result}] - fact_dict = await Json().generate(task_id, conversation=messages, spec=self.slot_schema) + fact_dict = await Json().generate("", conversation=messages, spec=self.slot_schema) if not fact_dict or "facts" not in fact_dict or not fact_dict["facts"]: return [] diff --git a/apps/llm/patterns/json.py b/apps/llm/patterns/json.py index 4b06bc8fe..10c594f44 100644 --- a/apps/llm/patterns/json.py +++ b/apps/llm/patterns/json.py @@ -32,7 +32,7 @@ class Json(CorePattern): EXAMPLE - [HUMAN] 创建“任务1”,并进行扫描 + [HUMAN] 创建"任务1",并进行扫描 @@ -122,25 +122,25 @@ class Json(CorePattern): # 把必要信息放到描述中,起提示作用 if "pattern" in spec: - new_spec["description"] += f"\nThe regex pattern is: {spec['pattern']}." + new_spec["description"] += f"\n正则表达式模式为:{spec['pattern']}" if "example" in spec: - new_spec["description"] += f"\nFor example: {spec['example']}." + new_spec["description"] += f"\n示例:{spec['example']}" if "default" in spec: - new_spec["description"] += f"\nThe default value is: {spec['default']}." + new_spec["description"] += f"\n默认值为:{spec['default']}" if "enum" in spec: - new_spec["description"] += f"\nValue must be one of: {', '.join(str(item) for item in spec['enum'])}." + new_spec["description"] += f"\n取值必须是以下之一:{', '.join(str(item) for item in spec['enum'])}" if "minimum" in spec: - new_spec["description"] += f"\nValue must be greater than or equal to: {spec['minimum']}." + new_spec["description"] += f"\n值必须大于或等于:{spec['minimum']}" if "maximum" in spec: - new_spec["description"] += f"\nValue must be less than or equal to: {spec['maximum']}." + new_spec["description"] += f"\n值必须小于或等于:{spec['maximum']}" if "minLength" in spec: - new_spec["description"] += f"\nValue must be at least {spec['minLength']} characters long." 
+ new_spec["description"] += f"\n长度必须至少为 {spec['minLength']} 个字符" if "maxLength" in spec: - new_spec["description"] += f"\nValue must be at most {spec['maxLength']} characters long." + new_spec["description"] += f"\n长度不能超过 {spec['maxLength']} 个字符" if "minItems" in spec: - new_spec["description"] += f"\nArray must contain at least {spec['minItems']} items." + new_spec["description"] += f"\n数组至少包含 {spec['minItems']} 个项目" if "maxItems" in spec: - new_spec["description"] += f"\nArray must contain at most {spec['maxItems']} items." + new_spec["description"] += f"\n数组最多包含 {spec['maxItems']} 个项目" return new_spec diff --git a/apps/llm/patterns/recommend.py b/apps/llm/patterns/recommend.py index 9397d0d38..9307fff5b 100644 --- a/apps/llm/patterns/recommend.py +++ b/apps/llm/patterns/recommend.py @@ -90,10 +90,12 @@ class Recommend(CorePattern): } """最终输出的JSON Schema""" + def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None: """初始化推荐问题生成Prompt""" super().__init__(system_prompt, user_prompt) + async def generate(self, task_id: str, **kwargs) -> list[str]: # noqa: ANN003 """生成推荐问题""" if "action_description" not in kwargs or not kwargs["action_description"]: @@ -123,11 +125,11 @@ class Recommend(CorePattern): ] result = "" - async for chunk in ReasoningLLM().call(task_id, messages, streaming=False, temperature=0.7, result_only=True): + async for chunk in ReasoningLLM().call(task_id, messages, streaming=False, temperature=0.7): result += chunk messages += [{"role": "assistant", "content": result}] - question_dict = await Json().generate(task_id, conversation=messages, spec=self.slot_schema) + question_dict = await Json().generate("", conversation=messages, spec=self.slot_schema) if not question_dict or "predicted_questions" not in question_dict or not question_dict["predicted_questions"]: return [] diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py index be2066e7b..1760db18d 100644 --- 
a/apps/llm/patterns/rewrite.py +++ b/apps/llm/patterns/rewrite.py @@ -25,4 +25,13 @@ class QuestionRewrite(CorePattern): {question} """ + """用户提示词""" + + async def generate(self, task_id: str, **kwargs) -> str: # noqa: ANN003 + """问题补全与重写""" + question = kwargs["question"] + + messages = [ + + ] diff --git a/apps/llm/patterns/select.py b/apps/llm/patterns/select.py index 47ad9f41e..6e1990f6c 100644 --- a/apps/llm/patterns/select.py +++ b/apps/llm/patterns/select.py @@ -77,6 +77,7 @@ class Select(CorePattern): """初始化Prompt""" super().__init__(system_prompt, user_prompt) + @staticmethod def _choices_to_prompt(choices: list[dict[str, Any]]) -> tuple[str, list[str]]: """将选项转换为Prompt""" @@ -87,6 +88,7 @@ class Select(CorePattern): choice_str_list.append(choice["name"]) return choices_prompt, choice_str_list + async def _generate_single_attempt(self, task_id: str, user_input: str, choice_list: list[str]) -> str: """使用ReasoningLLM进行单次尝试""" messages = [ @@ -96,14 +98,17 @@ class Select(CorePattern): result = "" async for chunk in ReasoningLLM().call(task_id, messages, streaming=False): result += chunk + + # 使用FunctionLLM进行参数提取 schema = self.slot_schema schema["properties"]["choice"]["enum"] = choice_list messages += [{"role": "assistant", "content": result}] - function_result = await Json().generate(task_id, conversation=messages, spec=schema) + function_result = await Json().generate("", conversation=messages, spec=schema) return function_result["choice"] + async def generate(self, task_id: str, **kwargs) -> str: # noqa: ANN003 """使用大模型做出选择""" max_try = 3 diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py index 6cb2ca0a5..ab9606748 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/reasoning.py @@ -158,4 +158,4 @@ class ReasoningLLM: output_tokens = self._calculate_token_length([{"role": "assistant", "content": result}], pure_text=True) task = ray.get_actor("task") - await task.update_token_summary.remote(task_id, input_tokens, output_tokens) + await 
task.update_token_summary.remote(task_id, input_tokens, output_tokens) \ No newline at end of file diff --git a/apps/main.py b/apps/main.py index 71f3f8d1e..b59e88c35 100644 --- a/apps/main.py +++ b/apps/main.py @@ -15,6 +15,7 @@ from ray.serve.config import HTTPOptions from apps.common.config import config from apps.common.task import Task from apps.common.wordscheck import WordsCheck +from apps.constants import SCHEDULER_REPLICAS from apps.cron.delete_user import DeleteUserCron from apps.dependency.session import VerifySessionMiddleware from apps.routers import ( @@ -36,6 +37,7 @@ from apps.routers import ( user, ) from apps.scheduler.pool.pool import Pool +from apps.scheduler.scheduler.scheduler import Scheduler # 定义FastAPI app app = FastAPI(docs_url=None, redoc_url=None) @@ -88,6 +90,8 @@ if __name__ == "__main__": task = Task.options(name="task").remote() pool_actor = Pool.options(name="pool").remote() ray.get(pool_actor.init.remote()) # type: ignore[attr-type] + # 初始化Scheduler + scheduler_sctors = [Scheduler.options(name=f"scheduler_{i}").remote() for i in range(SCHEDULER_REPLICAS)] # 启动FastAPI serve.start(http_options=HTTPOptions(host="0.0.0.0", port=8002)) # noqa: S104 diff --git a/apps/manager/node.py b/apps/manager/node.py index 05f0a3d76..e79175a5b 100644 --- a/apps/manager/node.py +++ b/apps/manager/node.py @@ -1,6 +1,16 @@ """Node管理器""" +from typing import Any, Optional + +import ray + +from apps.constants import LOGGER +from apps.entities.node import APINode +from apps.entities.pool import CallPool, Node, NodePool from apps.models.mongo import MongoDB +NODE_TYPE_MAP = { + "API": APINode, +} class NodeManager: """Node管理器""" @@ -14,3 +24,69 @@ class NodeManager: err = f"[NodeManager] Node {node_id} not found." 
raise ValueError(err) return node["call_id"] + + + @staticmethod + async def get_node_name(node_id: str) -> str: + """获取node的名称""" + node_collection = MongoDB.get_collection("node") + # 查询 Node 集合获取对应的 name + node_doc = await node_collection.find_one({"_id": node_id}, {"name": 1}) + if not node_doc: + LOGGER.error(f"Node {node_id} not found") + return "" + return node_doc["name"] + + + @staticmethod + def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]: + """合并参数Schema""" + pass + + + @staticmethod + async def get_node_data(node_id: str) -> Optional[Node]: + """获取Node数据""" + # 查找Node信息 + node_collection = MongoDB().get_collection("node") + try: + node = await node_collection.find_one({"id": node_id}) + node_data = NodePool.model_validate(node) + except Exception as e: + err = f"[NodeManager] Get node data error: {e}" + LOGGER.error(err) + raise ValueError(err) from e + + call_id = node_data.call_id + # 查找Node对应的Call信息 + call_collection = MongoDB().get_collection("call") + try: + call = await call_collection.find_one({"id": call_id}) + call_data = CallPool.model_validate(call) + except Exception as e: + err = f"[NodeManager] Get call data error: {e}" + LOGGER.error(err) + raise ValueError(err) from e + + # 查找Call信息 + pool = ray.get_actor("pool") + call_class = await pool.get_call.remote(call_data.path) + if not call_class: + err = f"[NodeManager] Call {call_data.path} not found" + LOGGER.error(err) + raise ValueError(err) + + # 找到Call的参数 + result_node = Node( + _id=node_data.id, + name=node_data.name, + description=node_data.description, + created_at=node_data.created_at, + service_id=node_data.service_id, + call_id=node_data.call_id, + output_schema=call_class.output_schema, + params_schema=NodeManager.merge_params_schema(call_class.params_schema, node_data.known_params or {}), + ) + + return call_class + diff --git a/apps/manager/task.py b/apps/manager/task.py index 6aa4fb4c9..dd1aeaa3a 100644 --- 
a/apps/manager/task.py +++ b/apps/manager/task.py @@ -6,7 +6,7 @@ from typing import Optional from apps.constants import LOGGER from apps.entities.collection import RecordGroup -from apps.entities.task import FlowHistory, TaskData +from apps.entities.task import FlowStepHistory, TaskData from apps.manager.record import RecordManager from apps.models.mongo import MongoDB @@ -59,7 +59,7 @@ class TaskManager: return None @staticmethod - async def get_flow_history_by_record_id(record_group_id: str, record_id: str) -> list[FlowHistory]: + async def get_flow_history_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: """根据record_group_id获取flow信息""" record_group_collection = MongoDB.get_collection("record_group") flow_context_collection = MongoDB.get_collection("flow_context") @@ -77,7 +77,7 @@ class TaskManager: for flow_context_id in records[0]["records"]["flow"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: - flow_context = FlowHistory.model_validate(flow_context) + flow_context = FlowStepHistory.model_validate(flow_context) flow_context_list.append(flow_context) return flow_context_list @@ -88,14 +88,14 @@ class TaskManager: @staticmethod - async def get_flow_history_by_task_id(task_id: str) -> dict[str, FlowHistory]: + async def get_flow_history_by_task_id(task_id: str) -> dict[str, FlowStepHistory]: """根据task_id获取flow信息""" flow_context_collection = MongoDB.get_collection("flow_context") flow_context = {} try: async for history in flow_context_collection.find({"task_id": task_id}): - history_obj = FlowHistory.model_validate(history) + history_obj = FlowStepHistory.model_validate(history) flow_context[history_obj.step_id] = history_obj return flow_context @@ -105,7 +105,7 @@ class TaskManager: @staticmethod - async def create_flows(flow_context: list[FlowHistory]) -> None: + async def create_flows(flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" 
flow_context_collection = MongoDB.get_collection("flow_context") try: @@ -128,4 +128,3 @@ class TaskManager: await session.commit_transaction() except Exception as e: LOGGER.error(f"[TaskManager] Delete tasks by conversation_id failed: {e}") - diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 58ab7fe8b..2bb806368 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -2,7 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import asyncio +import random import traceback import uuid from collections.abc import AsyncGenerator @@ -13,7 +13,7 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue -from apps.constants import LOGGER +from apps.constants import LOGGER, SCHEDULER_REPLICAS from apps.dependency import ( get_session, get_user, @@ -24,7 +24,7 @@ from apps.entities.request_data import RequestData from apps.entities.response_data import ResponseData from apps.manager.appcenter import AppCenterManager from apps.manager.blacklist import QuestionBlacklistManager, UserBlacklistManager -from apps.scheduler.scheduler import Scheduler +from apps.scheduler.scheduler.context import save_data from apps.service.activity import Activity RECOMMEND_TRES = 5 @@ -61,22 +61,24 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) await task_pool.set_task.remote(task_id, task) # 创建queue;由Scheduler进行关闭 - queue = MessageQueue() - await queue.init(task_id, enable_heartbeat=True) + queue = MessageQueue.remote() + await queue.init.remote(task_id) # type: ignore[attr-defined] # 在单独Task中运行Scheduler,拉齐queue.get的时机 - scheduler = Scheduler(task_id, queue) - scheduler_task = asyncio.create_task(scheduler.run(user_sub, session_id, post_body)) + randnum = random.randint(0, SCHEDULER_REPLICAS - 1) # noqa: S311 + scheduler_actor = ray.get_actor(f"scheduler_{randnum}") + scheduler = 
scheduler_actor.run.remote(task_id, queue, user_sub, post_body) # 处理每一条消息 - async for event in queue.get(): - if event[:6] == "[DONE]": + async for event in queue.get.remote(): # type: ignore[attr-defined] + content = await event + if content[:6] == "[DONE]": break - yield "data: " + event + "\n\n" + yield "data: " + content + "\n\n" # 等待Scheduler运行完毕 - await asyncio.gather(scheduler_task) + result = await scheduler # 获取最终答案 task = await task_pool.get_task.remote(task_id) @@ -95,9 +97,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) return # 创建新Record,存入数据库 - await scheduler.save_state(user_sub, post_body) - # 保存Task,从task_map中删除task - await task_pool.save_task.remote(task_id) + await save_data(task_id, user_sub, post_body, result.used_docs) yield "data: [DONE]\n\n" @@ -106,8 +106,6 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) yield "data: [ERROR]\n\n" finally: - if scheduler_task: - scheduler_task.cancel() await Activity.remove_active(user_sub) diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index c962e24c1..96281e1a3 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -3,19 +3,11 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from apps.scheduler.call.api import API -from apps.scheduler.call.convert import Convert from apps.scheduler.call.llm import LLM from apps.scheduler.call.rag import RAG -from apps.scheduler.call.render.render import Render -from apps.scheduler.call.sql import SQL -from apps.scheduler.call.suggest import Suggestion __all__ = [ "API", "LLM", "RAG", - "SQL", - "Convert", - "Render", - "Suggestion", ] diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index 8035ec476..d2c76b489 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -3,20 +3,20 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import json -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, ClassVar import aiohttp from fastapi import status from pydantic import BaseModel, Field from apps.constants import LOGGER -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError, CallVars from apps.manager.token import TokenManager from apps.scheduler.call.core import CoreCall from apps.scheduler.slot.slot import Slot -class APIParams(BaseModel): +class _APIParams(BaseModel): """API调用工具的参数""" url: str = Field(description="API接口的完整URL") @@ -28,7 +28,6 @@ class APIParams(BaseModel): ] = Field(description="API接口的Content-Type") timeout: int = Field(description="工具超时时间", default=300) body: dict[str, Any] = Field(description="已知的部分请求体", default={}) - input_schema: dict[str, Any] = Field(description="API请求体的JSON Schema", default={}) auth: dict[str, Any] = Field(description="API鉴权信息", default={}) @@ -40,13 +39,13 @@ class _APIOutput(BaseModel): output: dict[str, Any] = Field(description="API调用工具的输出") -class API(metaclass=CoreCall, param_cls=APIParams, output_cls=_APIOutput): +class API(CoreCall): """API调用工具""" - name: str = "api" - description: str = "根据给定的用户输入和历史记录信息,向某一个API接口发送请求、获取数据。" + name: ClassVar[str] = "HTTP请求" + description: ClassVar[str] = "向某一个API接口发送HTTP请求,获取数据。" - async def __call__(self, slot_data: dict[str, Any]) -> _APIOutput: + async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> _APIOutput: """调用API,然后返回LLM解析后的数据""" self._session = aiohttp.ClientSession() try: @@ -60,11 +59,10 @@ class API(metaclass=CoreCall, param_cls=APIParams, output_cls=_APIOutput): async def _make_api_call(self, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202, C901 # 获取必要参数 - params: APIParams = getattr(self, "_params") - syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + params: _APIParams = getattr(self, "_params") + syscall_vars: CallVars = getattr(self, "_syscall_vars") - """调用API""" - if 
self._data_type != "form": + if params.content_type != "form": req_header = { "Content-Type": "application/json", } @@ -113,7 +111,7 @@ class API(metaclass=CoreCall, param_cls=APIParams, output_cls=_APIOutput): async def _call_api(self, slot_data: Optional[dict[str, Any]] = None) -> _APIOutput: # 获取必要参数 - params: APIParams = getattr(self, "_params") + params: _APIParams = getattr(self, "_params") LOGGER.info(f"调用接口{params.url},请求数据为{slot_data}") session_context = await self._make_api_call(slot_data, aiohttp.FormData()) @@ -155,5 +153,5 @@ class API(metaclass=CoreCall, param_cls=APIParams, output_cls=_APIOutput): return _APIOutput( http_code=response_status, output=json.loads(response_data), - message=message + """The API returned some data, and is shown in the "output" field below.""", + message=message + "The API returned some data, and is shown in the 'output' field below.", ) diff --git a/apps/scheduler/call/convert.py b/apps/scheduler/call/convert.py index 3a1ecea9b..f29b74118 100644 --- a/apps/scheduler/call/convert.py +++ b/apps/scheduler/call/convert.py @@ -13,7 +13,7 @@ from jinja2 import BaseLoader, select_autoescape from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field -from apps.entities.scheduler import SysCallVars +from apps.entities.scheduler import CallVars from apps.scheduler.call.core import CoreCall @@ -27,8 +27,8 @@ class _ConvertParam(BaseModel): class _ConvertOutput(BaseModel): """定义Convert工具的输出""" - message: str = Field(description="格式化后的文字信息") - output: dict = Field(description="格式化后的结果") + text: str = Field(description="格式化后的文字信息") + data: dict = Field(description="格式化后的结果") class Convert(metaclass=CoreCall, param_cls=_ConvertParam, output_cls=_ConvertOutput): @@ -46,7 +46,7 @@ class Convert(metaclass=CoreCall, param_cls=_ConvertParam, output_cls=_ConvertOu """ # 获取必要参数 params: _ConvertParam = getattr(self, "_params") - syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + syscall_vars: CallVars = 
getattr(self, "_syscall_vars") last_output = syscall_vars.history[-1].output_data # 判断用户是否给了值 time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") @@ -77,6 +77,6 @@ class Convert(metaclass=CoreCall, param_cls=_ConvertParam, output_cls=_ConvertOu result_data = json.loads(_jsonnet.evaluate_snippet(data_template, params.data), ensure_ascii=False) return _ConvertOutput( - message=result_message, - output=result_data, + text=result_message, + data=result_data, ) diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 15d70f058..f50fafd2e 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -3,73 +3,30 @@ 所有Call类必须继承此类,并实现所有方法。 Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from typing import Any +from typing import Any, ClassVar -from pydantic import BaseModel +from pydantic import BaseModel, Field -from apps.entities.scheduler import SysCallVars +from apps.entities.scheduler import CallVars -class CoreCall(type): - """Call元类。所有Call必须继承此类,并实现所有方法。""" +class CoreCall(BaseModel): + """所有Call的父类,所有Call必须继承此类。""" - @staticmethod - def _check_class_attr(cls_name: str, attrs: dict[str, Any]) -> None: - """检查类属性是否存在""" - if "name" not in attrs: - err = f"类{cls_name}中不存在属性name" - raise AttributeError(err) - if "description" not in attrs: - err = f"类{cls_name}中不存在属性description" - raise AttributeError(err) - if "__call__" not in attrs or not callable(attrs["__call__"]): - err = f"类{cls_name}中不存在属性__call__" - raise AttributeError(err) + name: ClassVar[str] = Field(description="Call的名称") + description: ClassVar[str] = Field(description="Call的描述") - @staticmethod - def _class_init_fixed(self, syscall_vars: SysCallVars, **kwargs) -> None: # type: ignore[] # noqa: ANN001, ANN003, PLW0211 - """Call子类的固定初始化函数""" - self._syscall_vars = syscall_vars - self._params = self._param_cls.model_validate(kwargs) - # 调用附加的初始化函数 - self.init(syscall_vars, **kwargs) + class Config: + 
"""Pydantic 配置类""" - @staticmethod - def _class_init(self, syscall_vars: SysCallVars, **kwargs) -> None: # type: ignore[] # noqa: ANN001, ANN003, PLW0211 - """Call子类的附加初始化函数""" + arbitrary_types_allowed = True - @staticmethod - def _class_load(self) -> None: # type: ignore[] # noqa: ANN001, PLW0211 - """Call子类的文件载入函数""" + def __init_subclass__(cls, **kwargs: Any) -> None: + """初始化子类""" + return super().__init_subclass__(**kwargs) - def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], **kwargs) -> type: # noqa: ANN003 - """创建Call类""" - # 检查kwargs - if "param_cls" not in kwargs: - err = f"请给工具{name}提供参数模板!" - raise AttributeError(err) - if not issubclass(kwargs["param_cls"], BaseModel): - err = f"参数模板{kwargs['param_cls']}不是Pydantic类!" - raise TypeError(err) - if "output_cls" not in kwargs: - err = f"请给工具{name}提供输出模板!" - raise AttributeError(err) - if not issubclass(kwargs["output_cls"], BaseModel): - err = f"输出模板{kwargs['output_cls']}不是Pydantic类!" - raise TypeError(err) - - # 设置参数相关的属性 - attrs["_param_cls"] = kwargs["param_cls"] - attrs["params_schema"] = kwargs["param_cls"].model_json_schema() - attrs["output_schema"] = kwargs["output_cls"].model_json_schema() - # __init__不允许自定义 - attrs["__init__"] = lambda self, syscall_vars, **kwargs: self._class_init_fixed(syscall_vars, **kwargs) - # 提供空逻辑占位 - if "init" not in attrs: - attrs["init"] = cls._class_init - if "load" not in attrs: - attrs["load"] = cls._class_load - # 提供 - return super().__new__(cls, name, bases, attrs) + async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> type[BaseModel]: + """Call类实例的调用方法""" + raise NotImplementedError diff --git a/apps/scheduler/call/direct.py b/apps/scheduler/call/direct.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm.py index fda3ef815..e6f45f8db 100644 --- a/apps/scheduler/call/llm.py +++ b/apps/scheduler/call/llm.py @@ -4,85 +4,82 @@ Copyright (c) Huawei Technologies Co., 
Ltd. 2023-2024. All rights reserved. """ from datetime import datetime from textwrap import dedent -from typing import Any +from typing import Any, ClassVar import pytz from jinja2 import BaseLoader, select_autoescape from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError, CallVars from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall +LLM_DEFAULT_PROMPT = dedent( + r""" + + 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。 + 当前时间:{{ time }},可以作为时间参照。 + 用户的问题将在中给出,上下文背景信息将在中给出。 + 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 + -class _LLMParams(BaseModel): - """LLMParams类用于定义大模型调用的参数,包括温度设置、系统提示词、用户提示词和超时时间。 + + {{ question }} + - 属性: - temperature (float): 大模型温度设置,默认值是1.0。 - system_prompt (str): 大模型系统提示词。 - user_prompt (str): 大模型用户提示词。 - timeout (int): 超时时间,默认值是30秒。 - """ + + {{ context }} + + """, + ).strip("\n") - temperature: float = Field(description="大模型温度设置", default=1.0) - system_prompt: str = Field(description="大模型系统提示词", default="你是一个乐于助人的助手。") - user_prompt: str = Field( - description="大模型用户提示词", - default=dedent(""" - 回答下面的用户问题: - {{ question }} - 附加信息: - 当前时间为{{ time }}。用户在提问前,使用了工具,并获得了以下返回值:`{{ last.output }}`。 - 额外的背景信息:{{ context }} - """).strip("\n")) - timeout: int = Field(description="超时时间", default=30) - - -class _LLMOutput(BaseModel): +class LLMNodeOutput(BaseModel): """定义LLM工具调用的输出""" message: str = Field(description="大模型输出的文字信息") -class LLM(metaclass=CoreCall, param_cls=_LLMParams, output_cls=_LLMOutput): +class LLM(CoreCall): """大模型调用工具""" - name: str = "llm" - description: str = "大模型调用工具,用于以指定的提示词和上下文信息调用大模型,并获得输出。" + name: ClassVar[str] = "大模型" + description: ClassVar[str] = "以指定的提示词和上下文信息调用大模型,并获得输出。" + temperature: float = Field(description="大模型温度(随机化程度)", default=0.7) + enable_context: bool = Field(description="是否启用上下文", default=True) + system_prompt: str = 
Field(description="大模型系统提示词", default="") + user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT) - async def __call__(self, _slot_data: dict[str, Any]) -> _LLMOutput: + + async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> LLMNodeOutput: """运行LLM Call""" - # 获取必要参数 - syscall_vars: SysCallVars = getattr(self, "_syscall_vars") - params: _LLMParams = getattr(self, "_params") # 参数 time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") formatter = { "time": time, "context": syscall_vars.background, "question": syscall_vars.question, - "history": syscall_vars.history, } try: - # 准备提示词 + # 准备系统提示词 system_tmpl = SandboxedEnvironment( loader=BaseLoader(), autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(params.system_prompt) + ).from_string(self.system_prompt) system_input = system_tmpl.render(**formatter) + + # 准备用户提示词 user_tmpl = SandboxedEnvironment( loader=BaseLoader(), autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(params.user_prompt) + ).from_string(self.user_prompt) user_input = user_tmpl.render(**formatter) except Exception as e: raise CallError(message=f"用户提示词渲染失败:{e!s}", data={}) from e @@ -99,4 +96,4 @@ class LLM(metaclass=CoreCall, param_cls=_LLMParams, output_cls=_LLMOutput): except Exception as e: raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e - return _LLMOutput(message=result) + return LLMNodeOutput(message=result) diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag.py index 3cace0fd2..a234bb082 100644 --- a/apps/scheduler/call/rag.py +++ b/apps/scheduler/call/rag.py @@ -2,50 +2,42 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" -from typing import Any, Optional +from typing import Any, ClassVar, Optional import aiohttp from fastapi import status from pydantic import BaseModel, Field from apps.common.config import config -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError, CallVars from apps.scheduler.call.core import CoreCall -class _RAGParams(BaseModel): - """RAG工具的参数""" - - knowledge_base: str = Field(description="知识库的id", alias="kb_sn") - top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5) - methods: Optional[list[str]] = Field(description="rag检索方法") - - -class _RAGOutputList(BaseModel): +class _RAGOutput(BaseModel): """RAG工具的输出""" corpus: list[str] = Field(description="知识库的语料列表") -class _RAGOutput(BaseModel): - """RAG工具的输出""" - - output: _RAGOutputList = Field(description="RAG工具的输出") +class RAG(CoreCall): + """RAG工具:查询知识库""" + name: ClassVar[str] = "知识库" + description: ClassVar[str] = "查询知识库,从文档中获取必要信息" -class RAG(metaclass=CoreCall, param_cls=_RAGParams, output_cls=_RAGOutput): - """RAG工具:查询知识库""" + knowledge_base: str = Field(description="知识库的id", alias="kb_sn") + top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5) + methods: Optional[list[str]] = Field(description="rag检索方法") - name: str = "rag" - description: str = "RAG工具,用于查询知识库" - async def __call__(self, _slot_data: dict[str, Any]) -> _RAGOutput: + async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> _RAGOutput: """调用RAG工具""" - syscall_vars: SysCallVars = getattr(self, "_syscall_vars") - params: _RAGParams = getattr(self, "_params") - - params_dict = params.model_dump(exclude_none=True, by_alias=True) - params_dict["question"] = syscall_vars.question + params_dict = { + "kb_sn": self.knowledge_base, + "top_k": self.top_k, + "methods": self.methods, + "question": syscall_vars.question, + } url = config["RAG_HOST"].rstrip("/") + "/chunk/get" headers = { @@ -59,9 +51,7 @@ class RAG(metaclass=CoreCall, param_cls=_RAGParams, 
output_cls=_RAGOutput): if response.status == status.HTTP_200_OK: result = await response.json() chunk_list = result["data"] - return _RAGOutput( - output=_RAGOutputList(corpus=chunk_list), - ) + return _RAGOutput(corpus=chunk_list) text = await response.text() raise CallError( message=f"rag调用失败:{text}", diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/render/render.py index 08079e11e..09bc3b8a0 100644 --- a/apps/scheduler/call/render/render.py +++ b/apps/scheduler/call/render/render.py @@ -8,7 +8,7 @@ from typing import Any from pydantic import BaseModel, Field -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError, CallVars from apps.scheduler.call.core import CoreCall from apps.scheduler.call.render.style import RenderStyle @@ -50,7 +50,7 @@ class Render(metaclass=CoreCall, param_cls=_RenderParam, output_cls=_RenderOutpu description: str = "渲染图表工具,可将给定的数据绘制为图表。" - def init(self, _syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 + def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003 """初始化Render Call,校验参数,读取option模板""" try: option_location = Path(__file__).parent / "option.json" @@ -63,7 +63,7 @@ class Render(metaclass=CoreCall, param_cls=_RenderParam, output_cls=_RenderOutpu async def __call__(self, _slot_data: dict[str, Any]) -> _RenderOutput: """运行Render Call""" # 获取必要参数 - syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + syscall_vars: CallVars = getattr(self, "_syscall_vars") # 检测前一个工具是否为SQL if "dataset" not in syscall_vars.history[-1].output_data: diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql.py index 0aca289c6..28e5e4c89 100644 --- a/apps/scheduler/call/sql.py +++ b/apps/scheduler/call/sql.py @@ -13,7 +13,7 @@ from sqlalchemy import text from apps.common.config import config from apps.constants import LOGGER -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError, 
CallVars from apps.models.postgres import PostgreSQL from apps.scheduler.call.core import CoreCall @@ -38,7 +38,7 @@ class SQL(metaclass=CoreCall, param_cls=_SQLParams, output_cls=_SQLOutput): description: str = "SQL工具,用于查询数据库中的结构化数据" - def init(self, _syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 + def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003 """初始化SQL工具。""" # 初始化aiohttp的ClientSession self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) @@ -48,7 +48,7 @@ class SQL(metaclass=CoreCall, param_cls=_SQLParams, output_cls=_SQLOutput): """运行SQL工具""" # 获取必要参数 params: _SQLParams = getattr(self, "_params") - syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + syscall_vars: CallVars = getattr(self, "_syscall_vars") # 若手动设置了SQL,则直接使用 session = await PostgreSQL.get_session() diff --git a/apps/scheduler/call/suggest.py b/apps/scheduler/call/suggest.py index 61d853794..ce3bf36ab 100644 --- a/apps/scheduler/call/suggest.py +++ b/apps/scheduler/call/suggest.py @@ -7,7 +7,7 @@ from typing import Any, Optional import ray from pydantic import BaseModel, Field -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError, CallVars from apps.manager.user_domain import UserDomainManager from apps.scheduler.call.core import CoreCall @@ -50,7 +50,7 @@ class Suggestion(metaclass=CoreCall, param_cls=_SuggestInput, output_cls=_Sugges async def __call__(self, _slot_data: dict[str, Any]) -> _SuggestionOutput: """运行问题推荐""" - sys_vars: SysCallVars = getattr(self, "_syscall_vars") + sys_vars: CallVars = getattr(self, "_syscall_vars") params: _SuggestInput = getattr(self, "_params") # 获取当前任务 diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 6472d9f6e..6b68d22d9 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -3,17 +3,16 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import traceback -from typing import Any, Optional +from typing import Any import ray +from pydantic import BaseModel, Field from apps.constants import LOGGER, STEP_HISTORY_SIZE from apps.entities.enum_var import StepStatus -from apps.entities.flow import Step -from apps.entities.scheduler import ( - SysCallVars, - SysExecVars, -) +from apps.entities.flow import Flow, Step +from apps.entities.request_data import RequestDataApp +from apps.entities.scheduler import CallVars from apps.entities.task import ExecutorState, TaskBlock from apps.llm.patterns import ExecutorThought from apps.llm.patterns.executor import ExecutorBackground @@ -25,77 +24,63 @@ from apps.scheduler.executor.message import ( push_step_input, push_step_output, ) -from apps.scheduler.pool.pool import Pool from apps.scheduler.slot.slot import Slot # 单个流的执行工具 -@ray.remote -class Executor: +class Executor(BaseModel): """用于执行工作流的Executor""" - name: str = "" - """Flow名称""" - description: str = "" - """Flow描述""" - - - async def load_state(self, sysexec_vars: SysExecVars) -> None: - """从JSON中加载FlowExecutor的状态""" - # 获取Task - task_actor = ray.get_actor("task") - try: - self._task: TaskBlock = await task_actor.get_task.remote(sysexec_vars.task_id) - except Exception as e: - err = f"[Executor] Task error. {e!s}" - raise ValueError(err) from e + name: str = Field(description="Flow名称") + description: str = Field(description="Flow描述") - # 加载Flow信息 - pool = ray.get_actor("pool") - flow, flow_data = await pool.get_flow.remote(sysexec_vars.app_data.flow_id, sysexec_vars.app_data.app_id) + flow: Flow = Field(description="工作流数据") + task: TaskBlock = Field(description="任务信息") + queue: ray.ObjectRef = Field(description="消息队列") + question: str = Field(description="用户输入") + context: str = Field(description="上下文", default="") + post_body_app: RequestDataApp = Field(description="请求体中的app信息") - # Flow不合法,拒绝执行 - if flow is None or flow_data is None: - err = "Flow不存在!请检查Flow ID是否正确!" 
- raise ValueError(err) + class Config: + """Pydantic配置""" - # 设置名称和描述 - self.name = str(flow.name) - self.description = str(flow.description) + arbitrary_types_allowed = True - # 保存当前变量(只读) - self._vars = sysexec_vars - # 保存Flow数据(只读) - self._flow_data = flow_data + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" # 尝试恢复State - if self._task.flow_state: - self.flow_state = self._task.flow_state + if self.task.flow_state: + self.flow_state = self.task.flow_state # 如果flow_context为空,则从flow_history中恢复 - if not self._task.flow_context: - self._task.flow_context = await TaskManager.get_flow_history_by_task_id(self._vars.task_id) - self._task.new_context = [] + if not self.task.flow_context: + self.task.flow_context = await TaskManager.get_flow_history_by_task_id(self.task.record.task_id) + self.task.new_context = [] else: # 创建ExecutorState self.flow_state = ExecutorState( - name=str(flow.name), - description=str(flow.description), + name=str(self.flow.name), + description=str(self.flow.description), status=StepStatus.RUNNING, - app_id=str(sysexec_vars.app_data.app_id), + app_id=str(self.post_body_app.app_id), step_id="start", thought="", - slot_data=sysexec_vars.app_data.params, + filled_data=self.post_body_app.params, ) # 是否结束运行 self._stop = False - await task_actor.set_task.remote(self._vars.task_id, self._task) - async def _get_last_output(self, task: TaskBlock) -> dict[str, Any]: - """获取上一步的输出""" - if not task.flow_context: - return {} - return CallResult(**task.flow_context[self.flow_state.step_id].output_data) + async def _check_cls(self, call_cls: Any) -> bool: + """检查Call是否符合标准要求""" + flag = True + if not hasattr(call_cls, "name") or not isinstance(call_cls.name, str): + flag = False + if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str): + flag = False + if not hasattr(call_cls, "exec") or not callable(call_cls.exec): + flag = False + return flag async def _run_step(self, step_data: Step) -> dict[str, Any]: @@ 
-104,32 +89,44 @@ class Executor: self.flow_state.step_id = step_data.name self.flow_state.status = StepStatus.RUNNING - # Call类型为none,直接错误 + # Call类型为none,跳过执行 node_id = step_data.node if node_id == "none": return {} # 获取对应Node的call_id - call_id = await NodeManager.get_node_call_id(node_id) + try: + call_id = await NodeManager.get_node_call_id(node_id) + except Exception as e: + LOGGER.error(f"[FlowExecutor] 获取工具{node_id}的call_id时发生错误:{e}。\n{traceback.format_exc()}") + self.flow_state.status = StepStatus.ERROR + return {} + # 从Pool中获取对应的Call pool = ray.get_actor("pool") try: - call_data, call_cls = await pool.get_call.remote(call_id, self.flow_state.app_id) + call_cls = await pool.get_call.remote(call_id, self.flow_state.app_id) except Exception as e: - LOGGER.error(f"[FlowExecutor] 尝试执行工具{node_id}时发生错误:{e}。\n{traceback.format_exc()}") + LOGGER.error(f"[FlowExecutor] 载入工具{node_id}时发生错误:{e}。\n{traceback.format_exc()}") + self.flow_state.status = StepStatus.ERROR + return {} + + # 检查Call合法性 + if not self._check_cls(call_cls): + LOGGER.error(f"[FlowExecutor] 工具{node_id}不符合Call标准要求。") self.flow_state.status = StepStatus.ERROR return {} # 准备history - history = list(self._task.flow_context.values()) + history = list(self.task.flow_context.values()) length = min(STEP_HISTORY_SIZE, len(history)) history = history[-length:] # 准备SysCallVars - sys_vars = SysCallVars( - question=self._vars.question, - task_id=self._vars.task_id, - session_id=self._vars.session_id, + sys_vars = CallVars( + question=self.question, + task_id=self.task.record.task_id, + session_id=self.task.session_id, extra={ "app_id": self.flow_state.app_id, "flow_id": self.flow_state.name, @@ -160,28 +157,28 @@ class Executor: if slot_processor is not None: # 处理参数 remaining_schema, slot_data = await slot_processor.process( - self.flow_state.slot_data, - self._vars.app_data.params, + self.flow_state.filled_data, + self.sysexec_vars.app_data.params, { - "task_id": self._vars.task_id, - "question": 
self._vars.question, + "task_id": self.task.record.task_id, + "question": self.question, "thought": self.flow_state.thought, - "previous_output": await self._get_last_output(task), + "previous_output": await self._get_last_output(self.task), }, ) # 保存Schema至State self.flow_state.remaining_schema = remaining_schema - self.flow_state.slot_data.update(slot_data) + self.flow_state.filled_data.update(slot_data) # 如果还有未填充的部分,则终止执行 if remaining_schema: self._stop = True self.flow_state.status = StepStatus.RUNNING # 推送空输入 - await push_step_input(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data) + await push_step_input(self.task, self.queue, self.flow_state, self.flow) # 推送空输出 self.flow_state.status = StepStatus.PARAM result = {} - await push_step_output(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data, result) + await push_step_output(self.task, self.queue, self.flow_state, result) return result # 推送步骤输入 @@ -189,7 +186,7 @@ class Executor: # 执行Call try: - result: dict[str, Any] = await call_obj.call(self.flow_state.slot_data) + result: dict[str, Any] = await call_obj.call(self.flow_state.filled_data) except Exception as e: err = f"[FlowExecutor] 执行工具{node_id}时发生错误:{e!s}\n{traceback.format_exc()}" LOGGER.error(err) @@ -244,7 +241,7 @@ class Executor: 数据通过向Queue发送消息的方式传输 """ # 推送Flow开始 - await push_flow_start(self._vars.task_id, self._vars.queue, self.flow_state, self._vars.question) + await push_flow_start(self._task, self._vars.queue, self.flow_state, self._vars.question) # 更新背景 self.flow_state.thought = await ExecutorBackground().generate(self._vars.task_id, background=self._vars.background) diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py index 0aa7a87d5..1f41eff9c 100644 --- a/apps/scheduler/executor/message.py +++ b/apps/scheduler/executor/message.py @@ -6,163 +6,103 @@ from typing import Any import ray -from apps.common.queue import MessageQueue from apps.entities.enum_var import 
EventType, FlowOutputType, StepStatus from apps.entities.flow import Flow from apps.entities.message import ( FlowStartContent, FlowStopContent, - StepInputContent, - StepOutputContent, - TextAddContent, ) -from apps.entities.task import ExecutorState, FlowHistory, TaskBlock -from apps.llm.patterns.executor import FinalThought +from apps.entities.task import ExecutorState, FlowStepHistory, TaskBlock -async def push_step_input(task_id: str, queue: MessageQueue, state: ExecutorState, flow: Flow) -> None: +async def push_step_input(task_id: str, queue: ray.ObjectRef, state: ExecutorState, input_data: dict[str, Any]) -> None: """推送步骤输入""" - # 获取Task task_actor = ray.get_actor("task") - task = await task_actor.get_task.remote(task_id) - - if not task.flow_state: - err = "当前Record不存在Flow信息!" - raise ValueError(err) - + task: TaskBlock = await task_actor.get_task.remote(task_id) # 更新State task.flow_state = state # 更新FlowContext - flow_history = FlowHistory( - task_id=task_id, + task.flow_context[state.step_id] = FlowStepHistory( + task_id=task.record.task_id, flow_id=state.name, step_id=state.step_id, status=state.status, - input_data=state.slot_data, + input_data=state.filled_data, output_data={}, ) - task.new_context.append(flow_history.id) - task.flow_context[state.step_id] = flow_history - # 保存Task到TaskMap - await task.set_task.remote(task_id, task) - # 组装消息 - if state.status == StepStatus.ERROR: - # 如果当前步骤是错误,则推送错误步骤的输入 - if not flow.on_error: - err = "当前步骤不存在错误处理步骤!" 
- raise ValueError(err) - content = StepInputContent( - callType="llm", - params=state.slot_data, - ) - else: - content = StepInputContent( - callType=flow.steps[state.step_id].node, - params=state.slot_data, - ) # 推送消息 - await queue.push_output(event_type=EventType.STEP_INPUT, data=content.model_dump(exclude_none=True, by_alias=True)) + await queue.push_output.remote(task, event_type=EventType.STEP_INPUT, data=input_data) # type: ignore[attr-defined] + await task_actor.set_task.remote(task_id, task) -async def push_step_output(task_id: str, queue: MessageQueue, state: ExecutorState, flow: Flow, output: dict[str, Any]) -> None: +async def push_step_output(task_id: str, queue: ray.ObjectRef, state: ExecutorState, output: dict[str, Any]) -> None: """推送步骤输出""" - # 获取Task task_actor = ray.get_actor("task") - task = await task_actor.get_task.remote(task_id) - - if not task.flow_state: - err = "当前Record不存在Flow信息!" - raise ValueError(err) - + task: TaskBlock = await task_actor.get_task.remote(task_id) # 更新State task.flow_state = state # 更新FlowContext task.flow_context[state.step_id].output_data = output task.flow_context[state.step_id].status = state.status - # 保存Task到TaskMap - await task.set_task.remote(task_id, task) - - # 组装消息;只保留message和output - content = StepOutputContent( - callType=flow.steps[state.step_id].node, - message=output["message"] if output and "message" in output else "", - output=output["output"] if output and "output" in output else {}, - ) - await queue.push_output(event_type=EventType.STEP_OUTPUT, data=content.model_dump(exclude_none=True, by_alias=True)) + # FlowContext加入Record + task.new_context.append(task.flow_context[state.step_id].id) + + # 推送消息 + await queue.push_output.remote(task, event_type=EventType.STEP_OUTPUT, data=output) # type: ignore[attr-defined] + await task_actor.set_task.remote(task_id, task) -async def push_flow_start(task_id: str, queue: MessageQueue, state: ExecutorState, question: str) -> None: + +async def 
push_flow_start(task_id: str, queue: ray.ObjectRef, state: ExecutorState, question: str) -> None: """推送Flow开始""" - # 获取Task task_actor = ray.get_actor("task") - task = await task_actor.get_task.remote(task_id) - + task: TaskBlock = await task_actor.get_task.remote(task_id) # 设置state task.flow_state = state - # 保存Task到TaskMap - await task.set_task.remote(task_id, task) # 组装消息 content = FlowStartContent( question=question, - params=state.slot_data, + params=state.filled_data, ) # 推送消息 - await queue.push_output(event_type=EventType.FLOW_START, data=content.model_dump(exclude_none=True, by_alias=True)) - + await queue.push_output.remote(task, event_type=EventType.FLOW_START, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined] + await task_actor.set_task.remote(task_id, task) -async def push_flow_stop(task_id: str, queue: MessageQueue, state: ExecutorState, flow: Flow, question: str) -> None: - """推送Flow结束""" - # 获取Task - task_actor = ray.get_actor("task") - task = await task_actor.get_task.remote(task_id) - - task.flow_state = state - await task.set_task.remote(task_id, task) - # 准备必要数据 +async def assemble_flow_stop_content(state: ExecutorState, flow: Flow) -> FlowStopContent: + """组装Flow结束消息""" call_type = flow.steps[state.step_id].call_type - if state.remaining_schema: # 如果当前Flow是填充步骤,则推送Schema content = FlowStopContent( type=FlowOutputType.SCHEMA, data=state.remaining_schema, - ).model_dump(exclude_none=True, by_alias=True) + ) elif call_type == "render": # 如果当前Flow是图表,则推送Chart chart_option = task.flow_context[state.step_id].output_data["output"] content = FlowStopContent( type=FlowOutputType.CHART, data=chart_option, - ).model_dump(exclude_none=True, by_alias=True) + ) else: # 如果当前Flow是其他类型,则推送空消息 - content = {} - - # 推送最终结果 - params = { - "question": question, - "thought": state.thought, - "final_output": content, - } - full_text = "" - async for chunk in FinalThought().generate(task_id, **params): - if not chunk: - continue - 
await queue.push_output( - event_type=EventType.TEXT_ADD, - data=TextAddContent(text=chunk).model_dump(exclude_none=True, by_alias=True), - ) - full_text += chunk + content = FlowStopContent() - # 推送Stop消息 - await queue.push_output(event_type=EventType.FLOW_STOP, data=content) + return content - # 更新Thought - task.record.content.answer = full_text + +async def push_flow_stop(task_id: str, queue: ray.ObjectRef, state: ExecutorState, flow: Flow) -> None: + """推送Flow结束""" + task_actor = ray.get_actor("task") + task: TaskBlock = await task_actor.get_task.remote(task_id) + # 设置state task.flow_state = state + content = await assemble_flow_stop_content(state, flow) - await task.set_task.remote(task_id, task) + # 推送Stop消息 + await queue.push_output.remote(task, event_type=EventType.FLOW_STOP, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined] + await task_actor.set_task.remote(task_id, task) diff --git a/apps/scheduler/pool/check.py b/apps/scheduler/pool/check.py index 8e8b975ac..2a474bd74 100644 --- a/apps/scheduler/pool/check.py +++ b/apps/scheduler/pool/check.py @@ -33,10 +33,10 @@ class FileChecker: raise err async for file in path.iterdir(): - if file.is_file(): + if await file.is_file(): relative_path = file.relative_to(self._resource_path) hashes[relative_path.as_posix()] = sha256(await file.read_bytes()).hexdigest() - elif file.is_dir(): + elif await file.is_dir(): hashes.update(await self.check_one(file)) return hashes @@ -71,11 +71,11 @@ class FileChecker: # 遍历列表 for list_item in items: # 判断是否存在? 
- if not Path(self._dir_path / list_item["_id"]).exists(): + if not await Path(self._dir_path / list_item["_id"]).exists(): deleted_list.append(list_item["_id"]) continue # 判断是否发生变化 - if self.diff_one(Path(self._dir_path / list_item["_id"]), list_item.get("hashes", None)): + if await self.diff_one(Path(self._dir_path / list_item["_id"]), list_item.get("hashes", None)): changed_list.append(list_item["_id"]) # 遍历目录 diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 98f72b59c..77a260760 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -76,6 +76,7 @@ class AppLoader: raise RuntimeError(err) from e await self._update_db(metadata) + async def save(self, metadata: AppMetadata, app_id: str) -> None: """保存应用 @@ -115,6 +116,7 @@ class AppLoader: await session.aclose() + async def _update_db(self, metadata: AppMetadata) -> None: """更新数据库""" if not metadata.hashes: diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index c466dbbd7..5c46e7d66 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -26,27 +26,6 @@ class CallLoader: 用户Call放在call下 """ - @staticmethod - def _check_class(user_cls) -> bool: # noqa: ANN001 - """检查用户类是否符合Call标准要求""" - flag = True - - if not hasattr(user_cls, "name") or not isinstance(user_cls.name, str): - flag = False - if not hasattr(user_cls, "description") or not isinstance(user_cls.description, str): - flag = False - if not hasattr(user_cls, "output_schema") or not isinstance(user_cls.output_schema, dict): - flag = False - if not hasattr(user_cls, "params_schema") or not isinstance(user_cls.params_schema, dict): - flag = False - if not hasattr(user_cls, "init") or not callable(user_cls.init): - flag = False - if not callable(user_cls) or not callable(user_cls.__call__): - flag = False - - return flag - - async def _load_system_call(self) -> list[CallPool]: """加载系统Call""" call_metadata = [] @@ -54,10 
+33,6 @@ class CallLoader: # 检查合法性 for call_id in system_call.__all__: call_cls = getattr(system_call, call_id) - if not self._check_class(call_cls): - err = f"系统类{call_cls.__name__}不符合Call标准要求。" - LOGGER.info(msg=err) - continue call_metadata.append( CallPool( @@ -108,11 +83,6 @@ class CallLoader: LOGGER.info(msg=err) continue - if not self._check_class(call_cls): - err = f"工具call.{call_name}.{call_id}不符合标准要求;跳过载入。" - LOGGER.info(msg=err) - continue - cls_path = f"{call_package.service}::call.{call_name}.{call_id}" cls_hash = shake_128(cls_path.encode()).hexdigest(8) call_metadata.append( diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py index fa60c416e..3942059da 100644 --- a/apps/scheduler/pool/loader/metadata.py +++ b/apps/scheduler/pool/loader/metadata.py @@ -2,7 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ - +import json from typing import Any, Optional, Union import yaml diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 51392b82e..d408a3c34 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -55,6 +55,7 @@ class ServiceLoader: nodes = [NodePool(**node.model_dump(exclude_none=True, by_alias=True)) for node in nodes] await self._update_db(nodes, metadata) + async def save(self, service_id: str, metadata: ServiceMetadata, data: dict) -> None: """在文件系统上保存Service,并更新数据库""" service_path = Path(config["SEMANTICS_DIR"]) / SERVICE_DIR / service_id @@ -75,6 +76,7 @@ class ServiceLoader: await file_checker.diff_one(service_path) await self.load(service_id, file_checker.hashes[f"{SERVICE_DIR}/{service_id}"]) + async def delete(self, service_id: str) -> None: """删除Service,并更新数据库""" service_collection = MongoDB.get_collection("service") diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 2d94a8c01..6d23c568d 100644 --- a/apps/scheduler/pool/pool.py +++ 
b/apps/scheduler/pool/pool.py @@ -9,10 +9,10 @@ from typing import Any, Optional import ray from anyio import Path -from apps.common.config import config from apps.constants import APP_DIR, LOGGER, SERVICE_DIR from apps.entities.enum_var import MetadataType -from apps.entities.flow_topology import FlowItem +from apps.entities.flow import Flow +from apps.entities.pool import AppFlow from apps.models.mongo import MongoDB from apps.scheduler.pool.check import FileChecker from apps.scheduler.pool.loader import ( @@ -80,23 +80,42 @@ class Pool: pass - async def get_flow_metadata(self, app_id: str) -> Optional[FlowItem]: + async def get_flow_metadata(self, app_id: str) -> list[AppFlow]: """从数据库中获取特定App的全部Flow的元数据""" app_collection = MongoDB.get_collection("app") + flow_metadata_list = [] try: flow_list = await app_collection.find_one({"_id": app_id}, {"flows": 1}) + if not flow_list: + return [] + for flow in flow_list["flows"]: + flow_metadata_list.append(AppFlow(**flow)) except Exception as e: err = f"获取App{app_id}的Flow列表失败:{e}" LOGGER.error(err) raise RuntimeError(err) from e - if not flow_list: - return None + return flow_metadata_list - async def get_flow(self, app_id: str, flow_id: str) -> Optional[FlowItem]: + # TODO + async def get_flow(self, app_id: str, flow_id: str) -> Optional[Flow]: """从数据库中获取单个Flow的全部数据""" - pass + app_collection = MongoDB.get_collection("app") + try: + # 使用聚合管道来查找特定的flow + pipeline = [ + {"$match": {"_id": app_id}}, + {"$unwind": "$flows"}, + {"$match": {"flows._id": flow_id}}, + ] + async for flow in app_collection.aggregate(pipeline): + return Flow(**flow["flows"]) + return None + except Exception as e: + err = f"获取App {app_id} 的Flow {flow_id} 失败:{e}" + LOGGER.error(err) + raise RuntimeError(err) from e async def get_call(self, call_path: str) -> Any: @@ -107,11 +126,11 @@ class Pool: LOGGER.error(err) raise ValueError(err) + # Python类型的Call if call_path_split[0] == "python": try: call_module = importlib.import_module(call_path_split[1]) - 
call_class = getattr(call_module, call_path_split[2]) - return call_class + return getattr(call_module, call_path_split[2]) except Exception as e: err = f"导入Call{call_path}失败:{e}" LOGGER.error(err) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 106687ce1..bba2a0b98 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -2,13 +2,39 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +from datetime import datetime, timezone +from typing import Union + import ray from apps.common.security import Security -from apps.entities.collection import RecordContent +from apps.constants import LOGGER +from apps.entities.collection import Document, Record, RecordContent +from apps.entities.record import RecordDocument from apps.entities.request_data import RequestData +from apps.entities.task import TaskBlock from apps.llm.patterns.facts import Facts +from apps.manager.document import DocumentManager from apps.manager.record import RecordManager +from apps.manager.task import TaskManager + + +async def get_docs(user_sub: str, post_body: RequestData) -> tuple[Union[list[RecordDocument], list[Document]], list[str]]: + """获取当前问答可供关联的文档""" + doc_ids = [] + if post_body.group_id: + # 是重新生成,直接从RecordGroup中获取 + docs = await DocumentManager.get_used_docs_by_record_group(user_sub, post_body.group_id) + doc_ids += [doc.id for doc in docs] + else: + # 是新提问 + # 从Conversation中获取刚上传的文档 + docs = await DocumentManager.get_unused_docs(user_sub, post_body.conversation_id) + # 从最近10条Record中获取文档 + docs += await DocumentManager.get_used_docs(user_sub, post_body.conversation_id, 10) + doc_ids += [doc.id for doc in docs] + + return docs, doc_ids async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[str, list[str]]: @@ -44,17 +70,73 @@ async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[st return messages, facts -async def 
generate_facts(task_id: str, question: str) -> list[str]: +async def generate_facts(task: TaskBlock, question: str) -> list[str]: """生成Facts""" - task_pool = ray.get_actor("task") - task = await task_pool.get_task.remote(task_id) - if not task: - err = "Task not found" - raise ValueError(err) - message = { "question": question, "answer": task.record.content.answer, } - return await Facts().generate(task_id, message=message) + return await Facts().generate(task.record.task_id, message=message) + + +async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_docs: list[str]) -> None: + """保存当前Executor、Task、Record等的数据""" + # 获取当前Task + task_actor = ray.get_actor("task") + task: TaskBlock = await task_actor.get_task.remote(task_id) + # 加密Record数据 + try: + encrypt_data, encrypt_config = Security.encrypt(task.record.content.model_dump_json(by_alias=True)) + except Exception as e: + LOGGER.info(f"[Scheduler] Encryption failed: {e}") + return + + # 保存Flow信息 + if task.flow_state: + # 循环创建FlowHistory + history_data = [] + # 遍历查找数据,并添加 + for history_id in task.new_context: + for history in task.flow_context.values(): + if history.id == history_id: + history_data.append(history) + break + await TaskManager.create_flows(history_data) + + # 修改metadata里面时间为实际运行时间 + task.record.metadata.time = round(datetime.now(timezone.utc).timestamp() - task.record.metadata.time, 2) + + # 提取facts + # 记忆提取 + facts = await generate_facts(task, post_body.question) + + # 整理Record数据 + record = Record( + record_id=task.record.id, + user_sub=user_sub, + data=encrypt_data, + key=encrypt_config, + facts=facts, + metadata=task.record.metadata, + created_at=task.record.created_at, + flow=task.new_context, + ) + + record_group = task.record.group_id + # 检查是否存在group_id + if not await RecordManager.check_group_id(record_group, user_sub): + record_group = await RecordManager.create_record_group(user_sub, post_body.conversation_id, task.record.task_id) + if not record_group: + 
LOGGER.error("[Scheduler] Create record group failed.") + return + + # 修改文件状态 + await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) + # 保存Record + await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record) + # 保存与答案关联的文件 + await DocumentManager.save_answer_doc(user_sub, record_group, used_docs) + + # 保存Task + await task_actor.save_task.remote(task_id) diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 64a20ef75..4be23bcd1 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -40,9 +40,7 @@ class FlowChooser: if not flow_list: return "KnowledgeBase" - top_flow = await Select.generate() - return top_flow - + return await Select().generate(self._task_id, question=self._question, choices=flow_list) async def choose_flow(self) -> Optional[RequestDataApp]: diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index a38617827..22c1fab90 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -8,7 +8,6 @@ from typing import Union import ray -from apps.common.queue import MessageQueue from apps.constants import LOGGER from apps.entities.collection import Document from apps.entities.enum_var import EventType @@ -25,15 +24,10 @@ from apps.entities.task import TaskBlock from apps.service import RAG -async def push_init_message(task_id: str, queue: MessageQueue, post_body: RequestData, *, is_flow: bool = False) -> None: +async def push_init_message(task_id: str, queue: ray.ObjectRef, post_body: RequestData, *, is_flow: bool = False) -> None: """推送初始化消息""" - # 拿到Task task_actor = ray.get_actor("task") task: TaskBlock = await task_actor.get_task.remote(task_id) - if not task: - err = "[Scheduler] Task not found" - raise ValueError(err) - # 组装feature if is_flow: feature = InitContentFeature( @@ -54,68 +48,75 @@ async def push_init_message(task_id: str, queue: MessageQueue, 
post_body: Reques created_at = round(datetime.now(timezone.utc).timestamp(), 2) task.record.metadata.time = created_at task.record.metadata.feature = feature.model_dump(exclude_none=True, by_alias=True) - await task_actor.set_task.remote(task_id, task) # 推送初始化消息 - await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True)) + await queue.push_output.remote( # type: ignore[attr-defined] + task=task, + event_type=EventType.INIT, + data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True), + ) + await task_actor.set_task.remote(task_id, task) -async def push_rag_message(task_id: str, queue: MessageQueue, user_sub: str, rag_data: RAGQueryReq) -> None: +async def push_rag_message(task_id: str, queue: ray.ObjectRef, user_sub: str, rag_data: RAGQueryReq) -> None: """推送RAG消息""" task_actor = ray.get_actor("task") - task: TaskBlock = await task_actor.get_task.remote(task_id) - if not task: - err = "Task not found" - raise ValueError(err) - - rag_input_tokens = 0 - rag_output_tokens = 0 full_answer = "" async for chunk in RAG.get_rag_result(user_sub, rag_data): - chunk_content, rag_input_tokens, rag_output_tokens = await _push_rag_chunk(task_id, queue, chunk, rag_input_tokens, rag_output_tokens) + chunk_content = await _push_rag_chunk(task_id, queue, chunk) full_answer += chunk_content # 保存答案 + task: TaskBlock = await task_actor.get_task.remote(task_id) task.record.content.answer = full_answer await task_actor.set_task.remote(task_id, task) -async def _push_rag_chunk(task_id: str, queue: MessageQueue, content: str, rag_input_tokens: int, rag_output_tokens: int) -> tuple[str, int, int]: +async def _push_rag_chunk(task_id: str, queue: ray.ObjectRef, content: str) -> str: """推送RAG单个消息块""" + task_actor = ray.get_actor("task") + task: TaskBlock = await task_actor.get_task.remote(task_id) # 如果是换行 if not content or not content.rstrip().rstrip("\n"): 
- return "", rag_input_tokens, rag_output_tokens + return "" try: content_obj = RAGEventData.model_validate_json(dedent(content[6:]).rstrip("\n")) # 如果是空消息 if not content_obj.content: - return "", rag_input_tokens, rag_output_tokens + return "" - # 计算Token数量 - delta_input_tokens = content_obj.input_tokens - rag_input_tokens - delta_output_tokens = content_obj.output_tokens - rag_output_tokens - task_actor = ray.get_actor("task") - await task_actor.update_token_summary.remote(task_id, delta_input_tokens, delta_output_tokens) - # 更新Token的值 - rag_input_tokens = content_obj.input_tokens - rag_output_tokens = content_obj.output_tokens + # 更新Token数量 + task.record.metadata.input_tokens = content_obj.input_tokens + task.record.metadata.output_tokens = content_obj.output_tokens # 推送消息 - await queue.push_output(event_type=EventType.TEXT_ADD, data=TextAddContent(text=content_obj.content).model_dump(exclude_none=True, by_alias=True)) - return content_obj.content, rag_input_tokens, rag_output_tokens + await queue.push_output.remote( # type: ignore[attr-defined] + task=task, + event_type=EventType.TEXT_ADD, + data=TextAddContent(text=content_obj.content).model_dump(exclude_none=True, by_alias=True), + ) + await task_actor.set_task.remote(task_id, task) + return content_obj.content except Exception as e: LOGGER.error(f"[Scheduler] RAG服务返回错误数据: {e!s}\n{content}") - return "", rag_input_tokens, rag_output_tokens + return "" -async def push_document_message(queue: MessageQueue, doc: Union[RecordDocument, Document]) -> None: +async def push_document_message(task_id: str, queue: ray.ObjectRef, doc: Union[RecordDocument, Document]) -> None: """推送文档消息""" + task_actor = ray.get_actor("task") + task: TaskBlock = await task_actor.get_task.remote(task_id) content = DocumentAddContent( documentId=doc.id, documentName=doc.name, documentType=doc.type, documentSize=round(doc.size, 2), ) - await queue.push_output(event_type=EventType.DOCUMENT_ADD, data=content.model_dump(exclude_none=True, 
by_alias=True)) + await queue.push_output.remote( # type: ignore[attr-defined] + task=task, + event_type=EventType.DOCUMENT_ADD, + data=content.model_dump(exclude_none=True, by_alias=True), + ) + await task_actor.set_task.remote(task_id, task) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index d9c7a9a17..320ff8d2c 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -4,30 +4,19 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import asyncio import traceback -from datetime import datetime, timezone -from typing import Union import ray -from apps.common.queue import MessageQueue -from apps.common.security import Security from apps.constants import LOGGER -from apps.entities.collection import ( - Document, - Record, -) -from apps.entities.enum_var import EventType, StepStatus +from apps.entities.enum_var import EventType from apps.entities.rag_data import RAGQueryReq -from apps.entities.record import RecordDocument from apps.entities.request_data import RequestData -from apps.entities.scheduler import ExecutorBackground, SysExecVars -from apps.entities.task import RequestDataApp -from apps.manager.document import DocumentManager -from apps.manager.record import RecordManager -from apps.manager.task import TaskManager +from apps.entities.scheduler import ExecutorBackground +from apps.entities.task import SchedulerResult, TaskBlock from apps.manager.user import UserManager from apps.scheduler.executor import Executor -from apps.scheduler.scheduler.context import generate_facts, get_context +from apps.scheduler.scheduler.context import get_context, get_docs +from apps.scheduler.scheduler.flow import FlowChooser from apps.scheduler.scheduler.message import ( push_document_message, push_init_message, @@ -35,173 +24,126 @@ from apps.scheduler.scheduler.message import ( ) +@ray.remote class Scheduler: """“调度器”,是最顶层的、控制Executor执行顺序和状态的逻辑。 
Scheduler包含一个“SchedulerContext”,作用为多个Executor的“聊天会话” """ - def __init__(self, task_id: str, queue: MessageQueue) -> None: - """初始化Scheduler""" - self._task_id = task_id - self._queue = queue - self.used_docs = [] - - - async def _get_docs(self, user_sub: str, post_body: RequestData) -> tuple[Union[list[RecordDocument], list[Document]], list[str]]: - """获取当前问答可供关联的文档""" - doc_ids = [] - if post_body.group_id: - # 是重新生成,直接从RecordGroup中获取 - docs = await DocumentManager.get_used_docs_by_record_group(user_sub, post_body.group_id) - doc_ids += [doc.id for doc in docs] - else: - # 是新提问 - # 从Conversation中获取刚上传的文档 - docs = await DocumentManager.get_unused_docs(user_sub, post_body.conversation_id) - # 从最近10条Record中获取文档 - docs += await DocumentManager.get_used_docs(user_sub, post_body.conversation_id, 10) - doc_ids += [doc.id for doc in docs] - - return docs, doc_ids - - - async def run(self, user_sub: str, session_id: str, post_body: RequestData) -> None: + async def run(self, task_id: str, queue: ray.ObjectRef, user_sub: str, post_body: RequestData) -> SchedulerResult: """运行调度器""" + task_actor = ray.get_actor("task") + try: + task = await task_actor.get_task.remote(task_id) + except Exception as e: + LOGGER.error(f"[Scheduler] Task {task_id} not found: {e!s}\n{traceback.format_exc()}") + await queue.close.remote() # type: ignore[attr-defined] + return SchedulerResult(used_docs=[]) + try: - # 根据用户的请求,返回插件ID列表,选择Flow # 获取当前问答可供关联的文档 - docs, doc_ids = await self._get_docs(user_sub, post_body) + docs, doc_ids = await get_docs(user_sub, post_body) + except Exception as e: + LOGGER.error(f"[Scheduler] Get docs failed: {e!s}\n{traceback.format_exc()}") + await queue.close.remote() # type: ignore[attr-defined] + return SchedulerResult(used_docs=[]) + + try: # 获取上下文;最多20轮 context, facts = await get_context(user_sub, post_body, post_body.features.context_num) + except Exception as e: + LOGGER.error(f"[Scheduler] Get context failed: {e!s}\n{traceback.format_exc()}") + await 
queue.close.remote() # type: ignore[attr-defined] + return SchedulerResult(used_docs=[]) + + # 获取用户配置的kb_sn + user_info = await UserManager.get_userinfo_by_user_sub(user_sub) + if not user_info: + err = "[Scheduler] User not found" + LOGGER.error(err) + await queue.close.remote() # type: ignore[attr-defined] + return SchedulerResult(used_docs=[]) + + # 组装RAG请求数据,备用 + rag_data = RAGQueryReq( + question=post_body.question, + language=post_body.language, + document_ids=doc_ids, + kb_sn=None if not user_info.kb_id else user_info.kb_id, + top_k=5, + ) + # 已使用文档 + used_docs = [] + + # 如果是智能问答,直接执行 + if not post_body.app or post_body.app.app_id == "": + await push_init_message(task_id, queue, post_body, is_flow=False) + await asyncio.sleep(0.1) + for doc in docs: + # 保存使用的文件ID + used_docs.append(doc.id) + await push_document_message(task_id, queue, doc) + await asyncio.sleep(0.1) - # 获取用户配置的kb_sn - user_info = await UserManager.get_userinfo_by_user_sub(user_sub) - if not user_info: - err = "[Scheduler] User not found" - raise ValueError(err) # noqa: TRY301 - # 组装RAG请求数据,备用 - rag_data = RAGQueryReq( - question=post_body.question, - language=post_body.language, - document_ids=doc_ids, - kb_sn=None if not user_info.kb_id else user_info.kb_id, - top_k=5, + # 保存有数据的最后一条消息 + await push_rag_message(task_id, queue, user_sub, rag_data) + else: + # 需要执行Flow + await push_init_message(task_id, queue, post_body, is_flow=True) + # 组装上下文 + background = ExecutorBackground( + conversation=context, + facts=facts, ) + await self.run_executor(task, queue, post_body, background) - # 如果是智能问答,直接执行 - if not post_body.app or post_body.app.app_id == "": - await push_init_message(self._task_id, self._queue, post_body, is_flow=False) - await asyncio.sleep(0.1) - for doc in docs: - # 保存使用的文件ID - self.used_docs.append(doc.id) - await push_document_message(self._queue, doc) - - # 保存有数据的最后一条消息 - await push_rag_message(self._task_id, self._queue, user_sub, rag_data) - else: - # 需要执行Flow - await 
push_init_message(self._task_id, self._queue, post_body, is_flow=True) - # 组装上下文 - background = ExecutorBackground( - conversation=context, - facts=facts, - ) - - # 记忆提取 - self._facts = await generate_facts(self._task_id, post_body.question) - - # 发送结束消息 - await self._queue.push_output(event_type=EventType.DONE, data={}) - # 关闭Queue - await self._queue.close() - except Exception as e: - LOGGER.error(f"[Scheduler] Error: {e!s}\n{traceback.format_exc()}") - await self._queue.close() + # 发送结束消息 + task = await task_actor.get_task.remote(task_id) + await queue.push_output.remote(task, event_type=EventType.DONE, data={}) # type: ignore[attr-defined] + # 关闭Queue + await queue.close.remote() # type: ignore[attr-defined] + return SchedulerResult(used_docs=used_docs) - async def run_executor(self, session_id: str, post_body: RequestData, background: ExecutorBackground, selected_flow: RequestDataApp) -> bool: + async def run_executor(self, task: TaskBlock, queue: ray.ObjectRef, post_body: RequestData, background: ExecutorBackground) -> None: """构造FlowExecutor,并执行所选择的流""" - # 获取当前Task - task_pool = ray.get_actor("task") - task = await task_pool.get_task.remote(self._task_id) - if not task: - err = "[Scheduler] Task error." 
- raise ValueError(err) - - # 设置Flow接受的系统变量 - param = SysExecVars( - queue=self._queue, - question=post_body.question, - task_id=self._task_id, - session_id=session_id, - app_data=selected_flow, - background=background, - ) + # 读取App中所有Flow的信息 + pool_actor = ray.get_actor("pool") + if not post_body.app: + LOGGER.error("[Scheduler] Not using workflow!") + return + flow_info = await pool_actor.get_flow_metadata.remote(post_body.app.app_id) - # 执行Executor - # flow_exec = Executor() - # await flow_exec.load_state(param) - # # 开始运行 - # await flow_exec.run() - # # 判断状态 - # return flow_exec.flow_state.status != StepStatus.PARAM - - async def save_state(self, user_sub: str, post_body: RequestData) -> None: - """保存当前Executor、Task、Record等的数据""" - # 获取当前Task - task_pool = ray.get_actor("task") - task = await task_pool.get_task.remote(self._task_id) - if not task: - err = "Task not found" - raise ValueError(err) - - # 加密Record数据 - try: - encrypt_data, encrypt_config = Security.encrypt(task.record.content.model_dump_json(by_alias=True)) - except Exception as e: - LOGGER.info(f"[Scheduler] Encryption failed: {e}") + # 如果flow_info为空,则直接返回 + if not flow_info: + LOGGER.error(f"[Scheduler] Flow info not found for app {post_body.app.app_id}") return - # 保存Flow信息 - if task.flow_state: - # 循环创建FlowHistory - history_data = [] - # 遍历查找数据,并添加 - for history_id in task.new_context: - for history in task.flow_context.values(): - if history.id == history_id: - history_data.append(history) - break - await Task.create_flows(history_data) - - # 修改metadata里面时间为实际运行时间 - task.record.metadata.time = round(datetime.now(timezone.utc).timestamp() - task.record.metadata.time, 2) - - # 整理Record数据 - record = Record( - record_id=task.record.id, - user_sub=user_sub, - data=encrypt_data, - key=encrypt_config, - facts=self._facts, - metadata=task.record.metadata, - created_at=task.record.created_at, - flow=task.new_context, - ) + # 如果用户选了特定的Flow + if post_body.app.flow_id: + flow_data = await 
pool_actor.get_flow.remote(post_body.app.app_id, post_body.app.flow_id) + else: + # 如果用户没有选特定的Flow,则根据语义选择一个Flow + flow_chooser = FlowChooser(task.record.task_id, post_body.question, post_body.app) + flow_id = await flow_chooser.get_top_flow() + flow_data = await pool_actor.get_flow.remote(post_body.app.app_id, flow_id) + + # 如果flow_data为空,则直接返回 + if not flow_data: + LOGGER.error(f"[Scheduler] Flow data not found for app {post_body.app.app_id}") + return - record_group = task.record.group_id - # 检查是否存在group_id - if not await RecordManager.check_group_id(record_group, user_sub): - record_group = await RecordManager.create_record_group(user_sub, post_body.conversation_id, self._task_id) - if not record_group: - LOGGER.error("[Scheduler] Create record group failed.") - return - - # 修改文件状态 - await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) - # 保存Record - await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record) - # 保存与答案关联的文件 - await DocumentManager.save_answer_doc(user_sub, record_group, self.used_docs) + # 初始化Executor + flow_exec = Executor( + name=flow_data.name, + description=flow_data.description, + flow=flow_data, + task=task, + queue=queue, + question=post_body.question, + post_body_app=post_body.app, + ) + # 开始运行 + await flow_exec.load_state() + await flow_exec.run() diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 927165aac..f92c7fd7b 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -268,9 +268,9 @@ class Slot: ] if previous_output is not None: - tool_str = f"""I used a tool to get extra information from other sources. \ - The output data of the tool is `{previous_output}`. - The schema of the output is `{json.dumps(previous_output["output_schema"], ensure_ascii=False)}`, which contains description of the output. 
+ tool_str = f"""我使用了一个工具从其他来源获取额外信息。\ + 工具的输出数据是 `{previous_output}`。 + 输出的schema是 `{json.dumps(previous_output["output_schema"], ensure_ascii=False)}`,其中包含了输出的描述信息。 """ conversation.append({"role": "tool", "content": tool_str}) -- Gitee From 21c0a8822328770c012ca11dc401f64b6454f8af Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 26 Feb 2025 02:18:28 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E5=BC=80=E5=8F=91Scheduler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/pool.py | 1 - apps/entities/scheduler.py | 5 +- apps/entities/task.py | 2 +- apps/llm/patterns/__init__.py | 4 +- apps/llm/patterns/executor.py | 2 +- apps/main.py | 2 - apps/manager/flow.py | 18 +- apps/manager/node.py | 78 +++++- apps/routers/chat.py | 127 ++++----- apps/routers/flow.py | 35 ++- apps/scheduler/call/api.py | 12 +- apps/scheduler/call/convert.py | 2 +- apps/scheduler/call/core.py | 7 +- apps/scheduler/call/llm.py | 2 +- apps/scheduler/call/rag.py | 28 +- apps/scheduler/call/sql.py | 22 +- apps/scheduler/executor/__init__.py | 5 - apps/scheduler/executor/flow.py | 257 +++++++++--------- apps/scheduler/executor/message.py | 11 +- apps/scheduler/pool/loader/app.py | 7 + apps/scheduler/pool/loader/call.py | 13 +- apps/scheduler/pool/loader/flow.py | 2 +- apps/scheduler/pool/loader/openapi.py | 7 +- apps/scheduler/scheduler/scheduler.py | 60 +--- .../euler_copilot/templates/secrets.yaml | 4 +- 25 files changed, 376 insertions(+), 337 deletions(-) diff --git a/apps/entities/pool.py b/apps/entities/pool.py index 97260ccda..f9a85c256 100644 --- a/apps/entities/pool.py +++ b/apps/entities/pool.py @@ -76,7 +76,6 @@ class NodePool(BaseData): service_id: Optional[str] = Field(description="Node所属的Service ID", default=None) call_id: str = Field(description="所使用的Call的ID") - annotation: Optional[str] = Field(description="Node的注释", default=None) known_params: Optional[dict[str, Any]] = Field( description="已知的用于Call部分的参数,独立于输入和输出之外", 
default=None, diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py index 3a23d0f5e..18a7c95aa 100644 --- a/apps/entities/scheduler.py +++ b/apps/entities/scheduler.py @@ -19,16 +19,15 @@ class CallVars(BaseModel): question: str = Field(description="改写后的用户输入") history: list[FlowStepHistory] = Field(description="Executor中历史工具的结构化数据", default=[]) task_id: str = Field(description="任务ID") + flow_id: str = Field(description="Flow ID") session_id: str = Field(description="当前用户的Session ID") - extra: dict[str, Any] = Field(description="其他Executor设置的、用户不可修改的参数", default={}) class ExecutorBackground(BaseModel): """Executor的背景信息""" - conversation: list[dict[str, str]] = Field(description="当前Executor的背景信息") + conversation: str = Field(description="当前Executor的背景信息") facts: list[str] = Field(description="当前Executor的背景信息") - thought: str = Field(description="之前Executor的思考内容", default="") class CallError(Exception): diff --git a/apps/entities/task.py b/apps/entities/task.py index 77737fb8b..23fd1db40 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -39,7 +39,7 @@ class ExecutorState(BaseModel): step_id: str = Field(description="当前步骤名称") app_id: str = Field(description="应用ID") # 运行时数据 - thought: str = Field(description="大模型的思考内容", default="") + ai_summary: str = Field(description="大模型的思考内容", default="") filled_data: dict[str, Any] = Field(description="待使用的参数", default={}) remaining_schema: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) diff --git a/apps/llm/patterns/__init__.py b/apps/llm/patterns/__init__.py index 914ad23c0..20ef8bbcf 100644 --- a/apps/llm/patterns/__init__.py +++ b/apps/llm/patterns/__init__.py @@ -5,7 +5,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
from apps.llm.patterns.core import CorePattern from apps.llm.patterns.domain import Domain from apps.llm.patterns.executor import ( - ExecutorBackground, + ExecutorSummary, FinalThought, ExecutorThought, ) @@ -16,7 +16,7 @@ from apps.llm.patterns.select import Select __all__ = [ "CorePattern", "Domain", - "ExecutorBackground", + "ExecutorSummary", "FinalThought", "ExecutorThought", "Json", diff --git a/apps/llm/patterns/executor.py b/apps/llm/patterns/executor.py index 086ee7174..dc975900e 100644 --- a/apps/llm/patterns/executor.py +++ b/apps/llm/patterns/executor.py @@ -80,7 +80,7 @@ class ExecutorThought(CorePattern): return result -class ExecutorBackground(CorePattern): +class ExecutorSummary(CorePattern): """使用大模型进行生成Executor初始背景""" user_prompt: str = r""" diff --git a/apps/main.py b/apps/main.py index b59e88c35..a132c24d0 100644 --- a/apps/main.py +++ b/apps/main.py @@ -31,7 +31,6 @@ from apps.routers import ( flow, health, knowledge, - mock, record, service, user, @@ -65,7 +64,6 @@ app.include_router(blacklist.router) app.include_router(document.router) app.include_router(knowledge.router) app.include_router(flow.router) -app.include_router(mock.router) app.include_router(user.router) # 初始化后台定时任务 scheduler = BackgroundScheduler() diff --git a/apps/manager/flow.py b/apps/manager/flow.py index f7ac31a79..3b2950c5e 100644 --- a/apps/manager/flow.py +++ b/apps/manager/flow.py @@ -4,13 +4,13 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" from typing import Optional -from pymongo import ASCENDING import ray +from pymongo import ASCENDING from apps.constants import LOGGER from apps.entities.appcenter import AppLink -from apps.entities.enum_var import MetadataType, PermissionType -from apps.entities.flow import AppMetadata, Edge, Flow, FlowConfig, Permission, Step, StepPos +from apps.entities.enum_var import EdgeType, MetadataType, PermissionType +from apps.entities.flow import AppMetadata, Edge, Flow, Permission, Step, StepPos from apps.entities.flow_topology import ( EdgeItem, FlowItem, @@ -41,6 +41,9 @@ class FlowManager: try: node_pool_record = await node_pool_collection.find_one({"_id": node_meta_data_id}) + if node_pool_record is None: + LOGGER.error(f"节点元数据{node_meta_data_id}不存在") + return False match_conditions = [ {"author": user_sub}, {"permissions.type": PermissionType.PUBLIC.value}, @@ -159,6 +162,9 @@ class FlowManager: node_pool_collection = MongoDB.get_collection("node") # 获取节点集合 try: node_pool_record = await node_pool_collection.find_one({"_id": node_meta_data_id}) + if node_pool_record is None: + LOGGER.error(f"节点元数据{node_meta_data_id}不存在") + return None parameters = { "input_parameters": node_pool_record["params_schema"], "output_parameters": node_pool_record["output_schema"], @@ -253,7 +259,7 @@ class FlowManager: edgeId=edge_config.id, sourceNode=edge_from, targetNode=edge_config.edge_to, - type=edge_config.edge_type, + type=edge_config.edge_type.value if edge_config.edge_type else EdgeType.NORMAL.value, branchId=branch_id, )) return (flow_item, focus_point) @@ -319,7 +325,7 @@ class FlowManager: id=edge_item.edge_id, edge_from=edge_from, edge_to=edge_item.target_node, - edge_type=edge_item.type + edge_type=EdgeType(edge_item.type) if edge_item.type else EdgeType.NORMAL, ) flow_config.edges.append(edge_config) await FlowLoader().save(app_id, flow_id, flow_config) @@ -447,5 +453,5 @@ class FlowManager: return True except Exception as e: - LOGGER.error(f'Update flow debug from 
app pool failed: {e}') + LOGGER.error(f"Update flow debug from app pool failed: {e!s}") return False diff --git a/apps/manager/node.py b/apps/manager/node.py index b02a2f877..728a6f178 100644 --- a/apps/manager/node.py +++ b/apps/manager/node.py @@ -1,11 +1,11 @@ """Node管理器""" -from typing import Any, Optional +from typing import Any import ray from apps.constants import LOGGER from apps.entities.node import APINode -from apps.entities.pool import CallPool, Node, NodePool +from apps.entities.pool import CallPool, NodePool from apps.models.mongo import MongoDB NODE_TYPE_MAP = { @@ -25,12 +25,76 @@ class NodeManager: raise ValueError(err) return node["call_id"] + @staticmethod async def get_node_name(node_id: str) -> str: - """获取Node的名称""" + """获取node的名称""" + node_collection = MongoDB.get_collection("node") + # 查询 Node 集合获取对应的 name + node_doc = await node_collection.find_one({"_id": node_id}, {"name": 1}) + if not node_doc: + LOGGER.error(f"Node {node_id} not found") + return "" + return node_doc["name"] + + + @staticmethod + def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]: + """递归合并参数Schema,将known_params中的值填充到params_schema的对应位置""" + if not isinstance(params_schema, dict): + return params_schema + + if params_schema.get("type") == "object": + properties = params_schema.get("properties", {}) + for key, value in properties.items(): + if key in known_params: + # 如果在known_params中找到匹配的键,更新default值 + properties[key]["default"] = known_params[key] + # 递归处理嵌套的schema + properties[key] = NodeManager.merge_params_schema(value, known_params) + + elif params_schema.get("type") == "array": + items = params_schema.get("items", {}) + # 递归处理数组项 + params_schema["items"] = NodeManager.merge_params_schema(items, known_params) + + return params_schema + + + @staticmethod + async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: + """获取Node数据""" + # 查找Node信息 node_collection = MongoDB().get_collection("node") 
- node = await node_collection.find_one({"_id": node_id}, {"name": 1}) - if not node: - err = f"[NodeManager] Node name_id {node_id} not found." + try: + node = await node_collection.find_one({"id": node_id}) + node_data = NodePool.model_validate(node) + except Exception as e: + err = f"[NodeManager] Get node data error: {e}" + LOGGER.error(err) + raise ValueError(err) from e + + call_id = node_data.call_id + # 查找Node对应的Call信息 + call_collection = MongoDB().get_collection("call") + try: + call = await call_collection.find_one({"id": call_id}) + call_data = CallPool.model_validate(call) + except Exception as e: + err = f"[NodeManager] Get call data error: {e}" + LOGGER.error(err) + raise ValueError(err) from e + + # 查找Call信息 + pool = ray.get_actor("pool") + call_class = await pool.get_call.remote(call_data.path) + if not call_class: + err = f"[NodeManager] Call {call_data.path} not found" + LOGGER.error(err) raise ValueError(err) - return node["name"] + + # 返回参数Schema + return ( + NodeManager.merge_params_schema(call_class.params_schema, node_data.known_params or {}), + call_class.output_schema, + ) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 568d98dbe..2bb806368 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -11,7 +11,7 @@ from typing import Annotated import ray from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse, StreamingResponse -from apps.routers.mock import mock_data + from apps.common.queue import MessageQueue from apps.constants import LOGGER, SCHEDULER_REPLICAS from apps.dependency import ( @@ -24,7 +24,7 @@ from apps.entities.request_data import RequestData from apps.entities.response_data import ResponseData from apps.manager.appcenter import AppCenterManager from apps.manager.blacklist import QuestionBlacklistManager, UserBlacklistManager -# from apps.scheduler.scheduler import Scheduler +from apps.scheduler.scheduler.context import save_data from 
apps.service.activity import Activity RECOMMEND_TRES = 5 @@ -35,80 +35,78 @@ router = APIRouter( ) -# async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: -# """进行实际问答,并从MQ中获取消息""" -# try: -# await Activity.set_active(user_sub) +async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: + """进行实际问答,并从MQ中获取消息""" + try: + await Activity.set_active(user_sub) -# # 敏感词检查 -# word_check = ray.get_actor("words_check") -# if await word_check.check.remote(post_body.question) != 1: -# yield "data: [SENSITIVE]\n\n" -# LOGGER.info(msg="问题包含敏感词!") -# await Activity.remove_active(user_sub) -# return + # 敏感词检查 + word_check = ray.get_actor("words_check") + if await word_check.check.remote(post_body.question) != 1: + yield "data: [SENSITIVE]\n\n" + LOGGER.info(msg="问题包含敏感词!") + await Activity.remove_active(user_sub) + return -# # 生成group_id -# group_id = str(uuid.uuid4()) if not post_body.group_id else post_body.group_id + # 生成group_id + group_id = str(uuid.uuid4()) if not post_body.group_id else post_body.group_id -# # 创建或还原Task(获取task_id) -# task_pool = ray.get_actor("task") -# task = await task_pool.get_task.remote(session_id=session_id, post_body=post_body) -# task_id = task.record.task_id + # 创建或还原Task + task_pool = ray.get_actor("task") + task = await task_pool.get_task.remote(session_id=session_id, post_body=post_body) + task_id = task.record.task_id -# task.record.group_id = group_id -# post_body.group_id = group_id -# await task_pool.set_task.remote(task_id, task) + task.record.group_id = group_id + post_body.group_id = group_id + await task_pool.set_task.remote(task_id, task) -# # 创建queue;由Scheduler进行关闭 -# queue = MessageQueue() -# await queue.init(task_id, enable_heartbeat=True) + # 创建queue;由Scheduler进行关闭 + queue = MessageQueue.remote() + await queue.init.remote(task_id) # type: ignore[attr-defined] -# # 在单独Task中运行Scheduler,拉齐queue.get的时机 -# scheduler = 
Scheduler(task_id, queue) -# scheduler_task = asyncio.create_task(scheduler.run(user_sub, session_id, post_body)) + # 在单独Task中运行Scheduler,拉齐queue.get的时机 + randnum = random.randint(0, SCHEDULER_REPLICAS - 1) # noqa: S311 + scheduler_actor = ray.get_actor(f"scheduler_{randnum}") + scheduler = scheduler_actor.run.remote(task_id, queue, user_sub, post_body) -# # 处理每一条消息 -# async for event in queue.get(): -# if event[:6] == "[DONE]": -# break + # 处理每一条消息 + async for event in queue.get.remote(): # type: ignore[attr-defined] + content = await event + if content[:6] == "[DONE]": + break -# yield "data: " + event + "\n\n" + yield "data: " + content + "\n\n" -# # 等待Scheduler运行完毕 -# await asyncio.gather(scheduler_task) + # 等待Scheduler运行完毕 + result = await scheduler -# # 获取最终答案 -# task = await task_pool.get_task.remote(task_id) -# answer_text = task.record.content.answer -# if not answer_text: -# LOGGER.error(msg="Answer is empty") -# yield "data: [ERROR]\n\n" -# await Activity.remove_active(user_sub) -# return + # 获取最终答案 + task = await task_pool.get_task.remote(task_id) + answer_text = task.record.content.answer + if not answer_text: + LOGGER.error(msg="Answer is empty") + yield "data: [ERROR]\n\n" + await Activity.remove_active(user_sub) + return -# # 对结果进行敏感词检查 -# if await word_check.check.remote(answer_text) != 1: -# yield "data: [SENSITIVE]\n\n" -# LOGGER.info(msg="答案包含敏感词!") -# await Activity.remove_active(user_sub) -# return + # 对结果进行敏感词检查 + if await word_check.check.remote(answer_text) != 1: + yield "data: [SENSITIVE]\n\n" + LOGGER.info(msg="答案包含敏感词!") + await Activity.remove_active(user_sub) + return -# # 创建新Record,存入数据库 -# await scheduler.save_state(user_sub, post_body) -# # 保存Task,从task_map中删除task -# await task_pool.save_task.remote(task_id) + # 创建新Record,存入数据库 + await save_data(task_id, user_sub, post_body, result.used_docs) -# yield "data: [DONE]\n\n" + yield "data: [DONE]\n\n" -# except Exception as e: -# 
LOGGER.error(msg=f"生成答案失败:{e!s}\n{traceback.format_exc()}") -# yield "data: [ERROR]\n\n" + except Exception as e: + LOGGER.error(msg=f"生成答案失败:{e!s}\n{traceback.format_exc()}") + yield "data: [ERROR]\n\n" -# finally: -# if scheduler_task: -# scheduler_task.cancel() -# await Activity.remove_active(user_sub) + finally: + await Activity.remove_active(user_sub) @router.post("/chat", dependencies=[Depends(verify_csrf_token), Depends(verify_user)]) @@ -130,12 +128,7 @@ async def chat( if post_body.app and post_body.app.app_id: await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) - # res = chat_generator(post_body, user_sub, session_id) - - if post_body.app and post_body.app.app_id: - res = mock_data(appId=post_body.app.app_id, conversationId=post_body.conversation_id, flowId=post_body.app.flow_id,question=post_body.question) - else: - res = mock_data(question=post_body.question) + res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, media_type="text/event-stream", diff --git a/apps/routers/flow.py b/apps/routers/flow.py index f49a1dbf2..3b054a4d8 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -50,19 +50,19 @@ async def get_services( return NodeServiceListRsp( code=status.HTTP_404_NOT_FOUND, message="未找到符合条件的服务", - result=NodeServiceListMsg() + result=NodeServiceListMsg(), ) return NodeServiceListRsp( code=status.HTTP_200_OK, message="节点元数据所在服务信息获取成功", - result=NodeServiceListMsg(services=services) + result=NodeServiceListMsg(services=services), ) @router.get("/service/node", response_model=NodeMetaDataRsp, responses={ status.HTTP_403_FORBIDDEN: {"model": ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData} + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, }) async def get_node_metadatas( user_sub: Annotated[str, Depends(get_user)], @@ -93,7 +93,7 @@ async def get_node_metadatas( @router.get("", response_model=FlowStructureGetRsp, responses={ status.HTTP_403_FORBIDDEN: {"model": 
ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData} + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, }) async def get_flow( user_sub: Annotated[str, Depends(get_user)], @@ -117,7 +117,7 @@ async def get_flow( return JSONResponse(status_code=status.HTTP_200_OK, content=FlowStructureGetRsp( code=status.HTTP_200_OK, message="应用下流程获取成功", - result=FlowStructureGetMsg(flow=result[0], focus_point=result[1]) + result=FlowStructureGetMsg(flow=result[0], focus_point=result[1]), ).model_dump(exclude_none=True, by_alias=True)) @@ -125,14 +125,14 @@ async def get_flow( status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, status.HTTP_403_FORBIDDEN: {"model": ResponseData}, status.HTTP_404_NOT_FOUND: {"model": ResponseData}, - status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": ResponseData} + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": ResponseData}, }) -async def put_flow( +async def put_flow( # noqa: ANN201 user_sub: Annotated[str, Depends(get_user)], - app_id: str = Query(..., alias="appId"), - flow_id: str = Query(..., alias="flowId"), - topology_check: Optional[bool] = Query(..., alias="topologyCheck"), - put_body: PutFlowReq = Body(...) 
+ app_id: Annotated[str, Query(alias="appId")], + flow_id: Annotated[str, Query(alias="flowId")], + put_body: Annotated[PutFlowReq, Body(...)], + topology_check: Annotated[Optional[bool], Query(alias="topologyCheck")] = True, ): """修改流拓扑结构""" if not await AppManager.validate_app_belong_to_user(user_sub, app_id): @@ -162,18 +162,17 @@ async def put_flow( return JSONResponse(status_code=status.HTTP_200_OK, content=FlowStructurePutRsp( code=status.HTTP_200_OK, message="应用下流更新成功", - result=FlowStructurePutMsg(flow=flow[0]) + result=FlowStructurePutMsg(flow=flow[0]), ).model_dump(exclude_none=True, by_alias=True)) - @router.delete("", response_model=FlowStructureDeleteRsp, responses={ - status.HTTP_404_NOT_FOUND: {"model": ResponseData} + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, }) -async def delete_flow( +async def delete_flow( # noqa: ANN201 user_sub: Annotated[str, Depends(get_user)], - app_id: str = Query(..., alias="appId"), - flow_id: str = Query(..., alias="flowId") + app_id: Annotated[str, Query(alias="appId")], + flow_id: Annotated[str, Query(alias="flowId")], ): """删除流拓扑结构""" if not await AppManager.validate_app_belong_to_user(user_sub, app_id): @@ -192,5 +191,5 @@ async def delete_flow( return JSONResponse(status_code=status.HTTP_200_OK, content=FlowStructureDeleteRsp( code=status.HTTP_200_OK, message="应用下流程删除成功", - result=FlowStructureDeleteMsg(flowId=result) + result=FlowStructureDeleteMsg(flowId=result), ).model_dump(exclude_none=True, by_alias=True)) diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index d2c76b489..bf6b43624 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -3,7 +3,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import json -from typing import Any, Literal, Optional, ClassVar +from typing import Any, ClassVar, Literal, Optional import aiohttp from fastapi import status @@ -31,7 +31,7 @@ class _APIParams(BaseModel): auth: dict[str, Any] = Field(description="API鉴权信息", default={}) -class _APIOutput(BaseModel): +class APIOutput(BaseModel): """API调用工具的输出""" http_code: int = Field(description="API调用工具的HTTP返回码") @@ -45,7 +45,7 @@ class API(CoreCall): name: ClassVar[str] = "HTTP请求" description: ClassVar[str] = "向某一个API接口发送HTTP请求,获取数据。" - async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> _APIOutput: + async def __call__(self, syscall_vars: CallVars, **_kwargs: Any) -> APIOutput: """调用API,然后返回LLM解析后的数据""" self._session = aiohttp.ClientSession() try: @@ -109,7 +109,7 @@ class API(CoreCall): raise NotImplementedError(err) - async def _call_api(self, slot_data: Optional[dict[str, Any]] = None) -> _APIOutput: + async def _call_api(self, slot_data: Optional[dict[str, Any]] = None) -> APIOutput: # 获取必要参数 params: _APIParams = getattr(self, "_params") LOGGER.info(f"调用接口{params.url},请求数据为{slot_data}") @@ -132,7 +132,7 @@ class API(CoreCall): message = f"""You called the HTTP API "{params.url}", which is used to "{self._spec[2]['summary']}".""" # 如果没有返回结果 if response_data is None: - return _APIOutput( + return APIOutput( http_code=response_status, output={}, message=message + "But the API returned an empty response.", @@ -150,7 +150,7 @@ class API(CoreCall): slot = Slot(response_schema) response_data = json.dumps(slot.process_json(response_dict), ensure_ascii=False) - return _APIOutput( + return APIOutput( http_code=response_status, output=json.loads(response_data), message=message + "The API returned some data, and is shown in the 'output' field below.", diff --git a/apps/scheduler/call/convert.py b/apps/scheduler/call/convert.py index f29b74118..80b7fdd96 100644 --- a/apps/scheduler/call/convert.py +++ b/apps/scheduler/call/convert.py @@ -31,7 +31,7 @@ class 
_ConvertOutput(BaseModel): data: dict = Field(description="格式化后的结果") -class Convert(metaclass=CoreCall, param_cls=_ConvertParam, output_cls=_ConvertOutput): +class Convert(CoreCall): """Convert 工具,用于对生成的文字信息和原始数据进行格式化""" name: str = "convert" diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index f50fafd2e..f80d9f3e7 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -22,11 +22,6 @@ class CoreCall(BaseModel): arbitrary_types_allowed = True - def __init_subclass__(cls, **kwargs: Any) -> None: - """初始化子类""" - return super().__init_subclass__(**kwargs) - - - async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> type[BaseModel]: + async def __call__(self, syscall_vars: CallVars, **kwargs: Any) -> type[BaseModel]: """Call类实例的调用方法""" raise NotImplementedError diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm.py index e6f45f8db..26e694f19 100644 --- a/apps/scheduler/call/llm.py +++ b/apps/scheduler/call/llm.py @@ -53,7 +53,7 @@ class LLM(CoreCall): user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT) - async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> LLMNodeOutput: + async def __call__(self, syscall_vars: CallVars, **_kwargs: Any) -> LLMNodeOutput: """运行LLM Call""" # 参数 time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag.py index be2990740..b8c824844 100644 --- a/apps/scheduler/call/rag.py +++ b/apps/scheduler/call/rag.py @@ -2,7 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar, Literal import aiohttp from fastapi import status @@ -13,7 +13,7 @@ from apps.entities.scheduler import CallError, CallVars from apps.scheduler.call.core import CoreCall -class _RAGOutput(BaseModel): +class RAGOutput(BaseModel): """RAG工具的输出""" corpus: list[str] = Field(description="知识库的语料列表") @@ -27,16 +27,18 @@ class RAG(CoreCall): knowledge_base: str = Field(description="知识库的id", alias="kb_sn", default=None) top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5) - retrieval_mode: str = Field(description="检索模式", default="chunk", choices=['chunk', 'full_text']) + retrieval_mode: Literal["chunk", "full_text"] = Field(description="检索模式", default="chunk") - async def exec(self, syscall_vars: CallVars, **kwargs: Any) -> _RAGOutput: + async def __call__(self, syscall_vars: CallVars, **_kwargs: Any) -> RAGOutput: """调用RAG工具""" - syscall_vars: SysCallVars = getattr(self, "_syscall_vars") - params: _RAGParams = getattr(self, "_params") + params_dict = { + "kb_sn": self.knowledge_base, + "top_k": self.top_k, + "retrieval_mode": self.retrieval_mode, + "question": syscall_vars.question, + } - params_dict = params.model_dump(exclude_none=True, by_alias=True) - params_dict["content"] = syscall_vars.question url = config["RAG_HOST"].rstrip("/") + "/chunk/get" headers = { "Content-Type": "application/json", @@ -49,10 +51,14 @@ class RAG(CoreCall): if response.status == status.HTTP_200_OK: result = await response.json() chunk_list = result["data"] + + corpus = [] for chunk in chunk_list: - chunk=chunk.replace("\n", " ") - return _RAGOutput( - output=_RAGOutputList(corpus=chunk_list), + clean_chunk = chunk.replace("\n", " ") + corpus.append(clean_chunk) + + return RAGOutput( + corpus=corpus, ) text = await response.text() raise CallError( diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql.py index 28e5e4c89..164a1899f 100644 --- a/apps/scheduler/call/sql.py +++ 
b/apps/scheduler/call/sql.py @@ -18,24 +18,19 @@ from apps.models.postgres import PostgreSQL from apps.scheduler.call.core import CoreCall -class _SQLParams(BaseModel): - """SQL工具的参数""" - - sql: Optional[str] = Field(description="用户输入") - - -class _SQLOutput(BaseModel): +class SQLOutput(BaseModel): """SQL工具的输出""" message: str = Field(description="SQL工具的执行结果") dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") -class SQL(metaclass=CoreCall, param_cls=_SQLParams, output_cls=_SQLOutput): +class SQL(CoreCall): """SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。""" - name: str = "sql" - description: str = "SQL工具,用于查询数据库中的结构化数据" + name: str = "数据库" + description: str = "使用大模型生成SQL语句,用于查询数据库中的结构化数据" + sql: Optional[str] = Field(description="用户输入") def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003 @@ -44,10 +39,9 @@ class SQL(metaclass=CoreCall, param_cls=_SQLParams, output_cls=_SQLOutput): self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) - async def __call__(self, _slot_data: dict[str, Any]) -> _SQLOutput: + async def __call__(self, _slot_data: dict[str, Any]) -> SQLOutput: """运行SQL工具""" # 获取必要参数 - params: _SQLParams = getattr(self, "_params") syscall_vars: CallVars = getattr(self, "_syscall_vars") # 若手动设置了SQL,则直接使用 @@ -58,7 +52,7 @@ class SQL(metaclass=CoreCall, param_cls=_SQLParams, output_cls=_SQLOutput): await session.close() dataset_list = [db_item._asdict() for db_item in result] - return _SQLOutput( + return SQLOutput( message="SQL查询成功!", dataset=dataset_list, ) @@ -91,7 +85,7 @@ class SQL(metaclass=CoreCall, param_cls=_SQLParams, output_cls=_SQLOutput): await session.close() dataset_list = [db_item._asdict() for db_item in db_result] - return _SQLOutput( + return SQLOutput( message="数据库查询成功!", dataset=dataset_list, ) diff --git a/apps/scheduler/executor/__init__.py b/apps/scheduler/executor/__init__.py index 0d4222ae2..85f8c3734 100644 --- a/apps/scheduler/executor/__init__.py +++ 
b/apps/scheduler/executor/__init__.py @@ -2,8 +2,3 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from apps.scheduler.executor.flow import Executor - -__all__ = [ - "Executor", -] diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 6b68d22d9..d29d09465 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -3,21 +3,21 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import traceback -from typing import Any +from typing import Any, Optional import ray from pydantic import BaseModel, Field from apps.constants import LOGGER, STEP_HISTORY_SIZE from apps.entities.enum_var import StepStatus -from apps.entities.flow import Flow, Step +from apps.entities.flow import Flow, FlowError, Step from apps.entities.request_data import RequestDataApp -from apps.entities.scheduler import CallVars +from apps.entities.scheduler import CallVars, ExecutorBackground from apps.entities.task import ExecutorState, TaskBlock -from apps.llm.patterns import ExecutorThought -from apps.llm.patterns.executor import ExecutorBackground +from apps.llm.patterns.executor import ExecutorSummary from apps.manager.node import NodeManager from apps.manager.task import TaskManager +from apps.scheduler.call.core import CoreCall from apps.scheduler.executor.message import ( push_flow_start, push_flow_stop, @@ -40,6 +40,7 @@ class Executor(BaseModel): question: str = Field(description="用户输入") context: str = Field(description="上下文", default="") post_body_app: RequestDataApp = Field(description="请求体中的app信息") + executor_background: ExecutorBackground = Field(description="Executor的背景信息") class Config: """Pydantic配置""" @@ -64,7 +65,7 @@ class Executor(BaseModel): status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), step_id="start", - thought="", + ai_summary="", filled_data=self.post_body_app.params, ) # 是否结束运行 @@ -83,156 +84,168 @@ class Executor(BaseModel): return 
flag - async def _run_step(self, step_data: Step) -> dict[str, Any]: - """运行单个步骤""" - # 更新State - self.flow_state.step_id = step_data.name - self.flow_state.status = StepStatus.RUNNING + async def _run_error(self, step: FlowError) -> dict[str, Any]: + """运行错误处理步骤""" + pass - # Call类型为none,跳过执行 - node_id = step_data.node - if node_id == "none": - return {} + async def _get_call_cls(self, node_id: str) -> Optional[type[CoreCall]]: + """获取并验证Call类""" # 获取对应Node的call_id try: call_id = await NodeManager.get_node_call_id(node_id) except Exception as e: LOGGER.error(f"[FlowExecutor] 获取工具{node_id}的call_id时发生错误:{e}。\n{traceback.format_exc()}") self.flow_state.status = StepStatus.ERROR - return {} + return None # 从Pool中获取对应的Call pool = ray.get_actor("pool") try: - call_cls = await pool.get_call.remote(call_id, self.flow_state.app_id) + call_cls: type[CoreCall] = await pool.get_call.remote(call_id, self.flow_state.app_id) except Exception as e: LOGGER.error(f"[FlowExecutor] 载入工具{node_id}时发生错误:{e}。\n{traceback.format_exc()}") self.flow_state.status = StepStatus.ERROR - return {} + return None # 检查Call合法性 if not self._check_cls(call_cls): LOGGER.error(f"[FlowExecutor] 工具{node_id}不符合Call标准要求。") self.flow_state.status = StepStatus.ERROR + return None + + return call_cls + + + async def _process_slots(self, call_obj: Any) -> tuple[bool, Optional[dict[str, Any]]]: + """处理slot参数""" + if not (hasattr(call_obj, "slot_schema") and call_obj.slot_schema): + return True, None + + slot_processor = Slot(call_obj.slot_schema) + remaining_schema, slot_data = await slot_processor.process( + self.flow_state.filled_data, + self.post_body_app.params, + { + "task_id": self.task.record.task_id, + "question": self.question, + "thought": self.flow_state.ai_summary, + "previous_output": await self._get_last_output(self.task), + }, + ) + + # 保存Schema至State + self.flow_state.remaining_schema = remaining_schema + self.flow_state.filled_data.update(slot_data) + + # 如果还有未填充的部分,则返回False + if 
remaining_schema: + self._stop = True + self.flow_state.status = StepStatus.RUNNING + # 推送空输入输出 + await push_step_input(self.task.record.task_id, self.queue, self.flow_state, self.flow) + self.flow_state.status = StepStatus.PARAM + await push_step_output(self.task.record.task_id, self.queue, self.flow_state, {}) + return False, None + + return True, slot_data + + + async def _execute_call(self, call_obj: Any, sys_vars: CallVars, node_id: str) -> dict[str, Any]: + """执行Call并处理结果""" + if not call_obj: + LOGGER.error(f"[FlowExecutor] 工具{node_id}不存在。") + return {} + + try: + result: BaseModel = await call_obj(sys_vars) + except Exception as e: + LOGGER.error(f"[FlowExecutor] 执行工具{node_id}时发生错误:{e!s}\n{traceback.format_exc()}") + self.flow_state.status = StepStatus.ERROR + return {} + + try: + result_data = result.model_dump(exclude_none=True, by_alias=True) + except Exception as e: + LOGGER.error(f"[FlowExecutor] 无法处理工具{node_id}返回值:{e!s}\n{traceback.format_exc()}") + self.flow_state.status = StepStatus.ERROR return {} - # 准备history - history = list(self.task.flow_context.values()) - length = min(STEP_HISTORY_SIZE, len(history)) - history = history[-length:] + self.flow_state.status = StepStatus.SUCCESS + return result_data - # 准备SysCallVars + + async def _run_step(self, step_data: Step) -> None: + """运行单个步骤""" + # 更新State + self.flow_state.step_id = step_data.name + self.flow_state.status = StepStatus.RUNNING + + # Call类型为none,跳过执行 + node_id = step_data.node + if node_id == "none": + return + + # 获取并验证Call类 + call_cls = await self._get_call_cls(node_id) + if call_cls is None: + return + + # 准备系统变量 + history = list(self.task.flow_context.values())[-STEP_HISTORY_SIZE:] sys_vars = CallVars( question=self.question, task_id=self.task.record.task_id, + flow_id=self.post_body_app.flow_id, session_id=self.task.session_id, - extra={ - "app_id": self.flow_state.app_id, - "flow_id": self.flow_state.name, - }, history=history, - background=self.flow_state.thought, + 
background=self.flow_state.ai_summary, ) # 初始化Call try: - # 拿到开发者定义的参数 - params = step_data.params - # 初始化Call - call_obj = call_cls(sys_vars, **params) + call_obj = call_cls.model_validate(step_data.params) except Exception as e: err = f"[FlowExecutor] 初始化工具{node_id}时发生错误:{e!s}\n{traceback.format_exc()}" LOGGER.error(err) self.flow_state.status = StepStatus.ERROR - return {} + return - # 如果call_obj里面有slot_schema,初始化Slot处理器 - if hasattr(call_obj, "slot_schema") and call_obj.slot_schema: - slot_processor = Slot(call_obj.slot_schema) - else: - # 没有schema,不进行处理 - slot_processor = None - - if slot_processor is not None: - # 处理参数 - remaining_schema, slot_data = await slot_processor.process( - self.flow_state.filled_data, - self.sysexec_vars.app_data.params, - { - "task_id": self.task.record.task_id, - "question": self.question, - "thought": self.flow_state.thought, - "previous_output": await self._get_last_output(self.task), - }, - ) - # 保存Schema至State - self.flow_state.remaining_schema = remaining_schema - self.flow_state.filled_data.update(slot_data) - # 如果还有未填充的部分,则终止执行 - if remaining_schema: - self._stop = True - self.flow_state.status = StepStatus.RUNNING - # 推送空输入 - await push_step_input(self.task, self.queue, self.flow_state, self.flow) - # 推送空输出 - self.flow_state.status = StepStatus.PARAM - result = {} - await push_step_output(self.task, self.queue, self.flow_state, result) - return result + # TODO: 处理slots + # can_continue, slot_data = await self._process_slots(call_obj) + # if not can_continue: + # return # 推送步骤输入 - await push_step_input(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data) + await push_step_input(self.task.record.task_id, self.queue, self.flow_state, self.flow_state.filled_data) - # 执行Call - try: - result: dict[str, Any] = await call_obj.call(self.flow_state.filled_data) - except Exception as e: - err = f"[FlowExecutor] 执行工具{node_id}时发生错误:{e!s}\n{traceback.format_exc()}" - LOGGER.error(err) - self.flow_state.status = 
StepStatus.ERROR - # 推送空输出 - result = {} - await push_step_output(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data, result) - return result - - # 更新背景 - await self._update_thought(call_obj.name, call_obj.description, result) - # 推送消息、保存结果 - self.flow_state.status = StepStatus.SUCCESS - await push_step_output(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data, result) - return result + # 执行Call并获取结果 + result_data = await self._execute_call(call_obj, sys_vars, node_id) + + # 推送输出 + await push_step_output(self.task.record.task_id, self.queue, self.flow_state, result_data) + return - async def _handle_next_step(self, result: dict[str, Any]) -> None: + async def _handle_next_step(self) -> None: """处理下一步""" - if self._next_step is None: - return + next_nodes = [] + # 遍历Edges,查找下一个节点 + for edge in self.flow.edges: + if edge.edge_from == self.flow_state.step_id: + next_nodes += [edge.edge_to] + # TODO # 处理分支(cloice工具) - if self._flow_data.steps[self._next_step].call_type == "cloice" and result.extra is not None: - self._next_step = result.extra.get("next_step") - return + # if self._flow_data.steps[self._next_step].call_type == "choice" and result.extra is not None: + # self._next_step = result.extra.get("next_step") + # return # 处理下一步 - self._next_step = self._flow_data.steps[self._next_step].next - - - async def _update_thought(self, call_name: str, call_description: str, call_result: dict[str, Any]) -> None: - """执行步骤后,更新FlowExecutor的思考内容""" - # 组装工具信息 - tool_info = { - "name": call_name, - "description": call_description, - "output": call_result, - } - # 更新背景 - self.flow_state.thought = await ExecutorThought().generate( - self._vars.task_id, - last_thought=self.flow_state.thought, - user_question=self._vars.question, - tool_info=tool_info, - ) + if not next_nodes: + self.flow_state.step_id = "end" + else: + self.flow_state.step_id = next_nodes[0] async def run(self) -> None: @@ -241,41 +254,39 @@ class Executor(BaseModel): 
数据通过向Queue发送消息的方式传输 """ # 推送Flow开始 - await push_flow_start(self._task, self._vars.queue, self.flow_state, self._vars.question) - - # 更新背景 - self.flow_state.thought = await ExecutorBackground().generate(self._vars.task_id, background=self._vars.background) + await push_flow_start(self.task.record.task_id, self.queue, self.flow_state, self.question) while not self._stop: # 当前步骤不存在 - if self.flow_state.step_id not in self._flow_data.steps: + if self.flow_state.step_id not in self.flow.steps: break if self.flow_state.status == StepStatus.ERROR: # 当前步骤为错误处理步骤 - step = self._flow_data.on_error + step = self.flow.on_error else: - step = self._flow_data.steps[self.flow_state.step_id] + step = self.flow.steps[self.flow_state.step_id] # 当前步骤空白 if not step: break # 判断当前是否为最后一步 - if step.name == "end": - self._stop = True - if not step.next or step.next == "end": - self._stop = True + if self.flow_state.step_id == "end": + break # 运行步骤 - result = await self._run_step(step) + if isinstance(step, FlowError): + result = await self._run_error(step) + else: + result = await self._run_step(step) # 如果停止,则结束执行 if self._stop: break # 处理下一步 - await self._handle_next_step(result) + await self._handle_next_step() # Flow停止运行,推送消息 - await push_flow_stop(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data, self._vars.question) + await push_flow_stop(self.task.record.task_id, self.queue, self.flow_state, self.flow) diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py index 1f41eff9c..879234a8c 100644 --- a/apps/scheduler/executor/message.py +++ b/apps/scheduler/executor/message.py @@ -6,13 +6,18 @@ from typing import Any import ray -from apps.entities.enum_var import EventType, FlowOutputType, StepStatus +from apps.entities.enum_var import EventType, FlowOutputType from apps.entities.flow import Flow from apps.entities.message import ( FlowStartContent, FlowStopContent, ) -from apps.entities.task import ExecutorState, FlowStepHistory, 
TaskBlock +from apps.entities.task import ( + ExecutorState, + FlowStepHistory, + TaskBlock, +) +from apps.manager.node import NodeManager async def push_step_input(task_id: str, queue: ray.ObjectRef, state: ExecutorState, input_data: dict[str, Any]) -> None: @@ -74,7 +79,7 @@ async def push_flow_start(task_id: str, queue: ray.ObjectRef, state: ExecutorSta async def assemble_flow_stop_content(state: ExecutorState, flow: Flow) -> FlowStopContent: """组装Flow结束消息""" - call_type = flow.steps[state.step_id].call_type + call_type = await NodeManager.get_node_call_id(state.step_id) if state.remaining_schema: # 如果当前Flow是填充步骤,则推送Schema content = FlowStopContent( diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 91ecbe680..3164a7f80 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -3,6 +3,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import shutil + import ray from anyio import Path from fastapi.encoders import jsonable_encoder @@ -94,6 +96,7 @@ class AppLoader: await file_checker.diff_one(app_path) await self.load(app_id, file_checker.hashes[f"{APP_DIR}/{app_id}"]) + async def delete(self, app_id: str) -> None: """删除App,并更新数据库 @@ -116,6 +119,10 @@ class AppLoader: await session.aclose() + app_path = Path(config["SEMANTICS_DIR"]) / APP_DIR / app_id + if await app_path.exists(): + shutil.rmtree(str(app_path), ignore_errors=True) + async def _update_db(self, metadata: AppMetadata) -> None: """更新数据库""" diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 5c46e7d66..f3d6d079e 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -13,7 +13,7 @@ import apps.scheduler.call as system_call from apps.common.config import config from apps.constants import CALL_DIR, LOGGER from apps.entities.enum_var import CallType -from apps.entities.pool import CallPool +from apps.entities.pool import 
CallPool, NodePool from apps.entities.vector import CallPoolVector from apps.models.mongo import MongoDB from apps.models.postgres import PostgreSQL @@ -128,7 +128,6 @@ class CallLoader: return call_metadata - # TODO: 动态卸载 async def _delete_one(self, call_name: str) -> None: """删除单个Call""" pass @@ -144,10 +143,18 @@ class CallLoader: """更新数据库""" # 更新MongoDB call_collection = MongoDB.get_collection("call") + node_collection = MongoDB.get_collection("node") call_descriptions = [] try: for call in call_metadata: await call_collection.update_one({"_id": call.id}, {"$set": call.model_dump(exclude_none=True, by_alias=True)}, upsert=True) + await node_collection.insert_one(NodePool( + _id=call.id, + name=call.name, + description=call.description, + service_id="", + call_id=call.id, + ).model_dump(exclude_none=True, by_alias=True)) call_descriptions += [call.description] except Exception as e: err = f"更新MongoDB失败:{e}" @@ -174,8 +181,10 @@ class CallLoader: """初始化Call信息""" # 清空collection call_collection = MongoDB.get_collection("call") + node_collection = MongoDB.get_collection("node") try: await call_collection.delete_many({}) + await node_collection.delete_many({"service_id": ""}) except Exception as e: LOGGER.error(msg=f"Call的collection清空失败:{e}") diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 2baf03e6e..9271115ee 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -5,9 +5,9 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
from typing import Optional import aiofiles -from fastapi.encoders import jsonable_encoder import yaml from anyio import Path +from fastapi.encoders import jsonable_encoder from apps.common.config import config from apps.constants import APP_DIR, FLOW_DIR, LOGGER diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py index 46e94aa97..c2d864fbc 100644 --- a/apps/scheduler/pool/loader/openapi.py +++ b/apps/scheduler/pool/loader/openapi.py @@ -3,7 +3,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import uuid +from hashlib import shake_128 from typing import Any import ray @@ -122,15 +122,16 @@ class OpenAPILoader: """将OpenAPI文档拆解为Node""" nodes = [] for api_endpoint in spec.endpoints: + # 通过算法生成唯一的标识符 + identifier = shake_128(f"openapi::{yaml_filename}::{api_endpoint.uri}".encode()).hexdigest(16) # 组装新的NodePool item node = APINode( - _id=str(uuid.uuid4()), + _id=identifier, name=api_endpoint.name, # 此处固定Call的ID是"API" call_id="API", description=api_endpoint.description, service_id=service_id, - annotation=f"openapi::{yaml_filename}", ) # 合并参数 diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index b377ffc48..254267e5f 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -14,8 +14,9 @@ from apps.entities.request_data import RequestData from apps.entities.scheduler import ExecutorBackground from apps.entities.task import SchedulerResult, TaskBlock from apps.manager.user import UserManager -from apps.scheduler.scheduler.context import generate_facts, get_context -from apps.scheduler.scheduler.flow import Flow, FlowChooser +from apps.scheduler.executor.flow import Executor +from apps.scheduler.scheduler.context import get_context, get_docs +from apps.scheduler.scheduler.flow import FlowChooser from apps.scheduler.scheduler.message import ( push_document_message, push_init_message, @@ -41,9 +42,6 @@ class Scheduler: 
return SchedulerResult(used_docs=[]) try: - # 根据用户的请求,返回插件ID列表,选择Flow - flow_chooser = FlowChooser(task_id=self._task_id, question=post_body.question,user_selected=post_body.app) - user_selected_flow = flow_chooser.choose_flow() # 获取当前问答可供关联的文档 docs, doc_ids = await get_docs(user_sub, post_body) except Exception as e: @@ -54,49 +52,6 @@ class Scheduler: try: # 获取上下文;最多20轮 context, facts = await get_context(user_sub, post_body, post_body.features.context_num) - - # 获取用户配置的kb_sn - user_info = await UserManager.get_userinfo_by_user_sub(user_sub) - if not user_info: - err = "[Scheduler] User not found" - raise ValueError(err) # noqa: TRY301 - # 组装RAG请求数据,备用 - rag_data = RAGQueryReq( - question=post_body.question, - language=post_body.language, - document_ids=doc_ids, - kb_sn=None if not user_info.kb_id else user_info.kb_id, - top_k=5, - ) - - # 如果是智能问答,直接执行 - if not user_selected_flow: - await push_init_message(self._task_id, self._queue, post_body, is_flow=False) - await asyncio.sleep(0.1) - for doc in docs: - # 保存使用的文件ID - self.used_docs.append(doc.id) - await push_document_message(self._queue, doc) - - # 保存有数据的最后一条消息 - await push_rag_message(self._task_id, self._queue, user_sub, rag_data) - else: - # 需要执行Flow - await push_init_message(self._task_id, self._queue, post_body, is_flow=True) - # 组装上下文 - background = ExecutorBackground( - conversation=context, - facts=facts, - ) - need_recommend = await self.run_executor(session_id, post_body, background, user_selected_flow) - - # 记忆提取 - self._facts = await generate_facts(self._task_id, post_body.question) - - # 发送结束消息 - await self._queue.push_output(event_type=EventType.DONE, data={}) - # 关闭Queue - await self._queue.close() except Exception as e: LOGGER.error(f"[Scheduler] Get context failed: {e!s}\n{traceback.format_exc()}") await queue.close.remote() # type: ignore[attr-defined] @@ -118,7 +73,8 @@ class Scheduler: kb_sn=None if not user_info.kb_id else user_info.kb_id, top_k=5, ) - # print("begin_to_run") + # 已使用文档 + 
used_docs = [] # 如果是智能问答,直接执行 if not post_body.app or post_body.app.app_id == "": @@ -136,11 +92,11 @@ class Scheduler: # 需要执行Flow await push_init_message(task_id, queue, post_body, is_flow=True) # 组装上下文 - background = ExecutorBackground( + executor_background = ExecutorBackground( conversation=context, facts=facts, ) - await self.run_executor(task, queue, post_body, background) + await self.run_executor(task, queue, post_body, executor_background) # 发送结束消息 task = await task_actor.get_task.remote(task_id) @@ -150,6 +106,7 @@ class Scheduler: return SchedulerResult(used_docs=used_docs) + async def run_executor(self, task: TaskBlock, queue: ray.ObjectRef, post_body: RequestData, background: ExecutorBackground) -> None: """构造FlowExecutor,并执行所选择的流""" # 读取App中所有Flow的信息 @@ -187,6 +144,7 @@ class Scheduler: queue=queue, question=post_body.question, post_body_app=post_body.app, + executor_background=background, ) # 开始运行 await flow_exec.load_state() diff --git a/deploy/chart/euler_copilot/templates/secrets.yaml b/deploy/chart/euler_copilot/templates/secrets.yaml index 221cb4f03..5e7faa723 100644 --- a/deploy/chart/euler_copilot/templates/secrets.yaml +++ b/deploy/chart/euler_copilot/templates/secrets.yaml @@ -14,8 +14,8 @@ stringData: halfKey2: {{ index $systemSecret.data.halfKey2 | b64dec }} halfKey3: {{ index $systemSecret.data.halfKey3 | b64dec }} csrfKey: {{ index $systemSecret.data.csrfKey | b64dec }} - clientId: {{ index $systemSecret.data.clientId | b64dec }} - clientSecret: {{ index $systemSecret.data.clientSecret | b64dec }} + clientId: {{ .Values.login.client.id }} + clientSecret: {{ .Values.login.client.secret }} {{- else -}} apiVersion: v1 kind: Secret -- Gitee From 5c7ad170012d2aab681841df499a5753695d4c55 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 26 Feb 2025 03:04:23 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E5=8E=BB=E9=99=A4Loader=E7=9A=84Ray?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 
apps/entities/flow_topology.py | 2 +- apps/manager/appcenter.py | 17 +++++-------- apps/manager/flow.py | 22 ++++++++--------- apps/manager/service.py | 16 +++++-------- apps/routers/chat.py | 1 + apps/routers/mock.py | 14 ++++------- apps/scheduler/pool/loader/app.py | 2 -- apps/scheduler/pool/loader/metadata.py | 1 - apps/scheduler/pool/loader/service.py | 1 - apps/scheduler/pool/pool.py | 33 ++++++++++---------------- 10 files changed, 40 insertions(+), 69 deletions(-) diff --git a/apps/entities/flow_topology.py b/apps/entities/flow_topology.py index 19649dbbd..469ee3f87 100644 --- a/apps/entities/flow_topology.py +++ b/apps/entities/flow_topology.py @@ -28,7 +28,7 @@ class NodeServiceItem(BaseModel): name: str = Field(..., description="服务名称") type: str = Field(..., description="服务类型") node_meta_datas: list[NodeMetaDataItem] = Field(alias="nodeMetaDatas", default=[]) - created_at: Optional[str] = Field(..., alias="createdAt", description="创建时间") + created_at: Optional[str] = Field(default=None, alias="createdAt", description="创建时间") class PositionItem(BaseModel): diff --git a/apps/manager/appcenter.py b/apps/manager/appcenter.py index 1605e1e4a..30d321322 100644 --- a/apps/manager/appcenter.py +++ b/apps/manager/appcenter.py @@ -8,8 +8,6 @@ from datetime import datetime, timezone from enum import Enum from typing import Any, Optional -import ray - from apps.constants import LOGGER from apps.entities.appcenter import AppCenterCardItem, AppData from apps.entities.collection import User @@ -223,9 +221,8 @@ class AppCenterManager: ), ) try: - app_loader = AppLoader.remote() - await app_loader.save.remote(metadata, app_id) # type: ignore[attr-type] - ray.kill(app_loader) + app_loader = AppLoader() + await app_loader.save(metadata, app_id) return app_id except Exception as e: LOGGER.error(f"[AppCenterManager] Create app failed: {e}") @@ -262,9 +259,8 @@ class AppCenterManager: return False metadata.flows = app_data.flows metadata.published = app_data.published - 
app_loader = AppLoader.remote() - await app_loader.save.remote(metadata, app_id) # type: ignore[attr-type] - ray.kill(app_loader) + app_loader = AppLoader() + await app_loader.save(metadata, app_id) return True except Exception as e: LOGGER.error(f"[AppCenterManager] Update app failed: {e}") @@ -343,9 +339,8 @@ class AppCenterManager: if app_data.author != user_sub: return False # 删除应用 - app_loader = AppLoader.remote() - await app_loader.delete.remote(app_id) # type: ignore[attr-type] - ray.kill(app_loader) + app_loader = AppLoader() + await app_loader.delete(app_id) user_collection = MongoDB.get_collection("user") # 删除用户使用记录 await user_collection.update_many( diff --git a/apps/manager/flow.py b/apps/manager/flow.py index f60d09b5d..b73b1fa21 100644 --- a/apps/manager/flow.py +++ b/apps/manager/flow.py @@ -4,7 +4,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Optional -import ray from pymongo import ASCENDING from apps.constants import LOGGER @@ -20,10 +19,11 @@ from apps.entities.flow_topology import ( PositionItem, ) from apps.entities.pool import AppFlow, AppPool +from apps.manager.node import NodeManager from apps.models.mongo import MongoDB from apps.scheduler.pool.loader.app import AppLoader from apps.scheduler.pool.loader.flow import FlowLoader -from apps.manager.node import NodeManager + class FlowManager: """Flow相关操作""" @@ -79,7 +79,8 @@ class FlowManager: nodes_meta_data_items = [] async for node_pool_record in cursor: - params_schema, output_schema = await NodeManager.get_node_params(node_pool_record["_id"]) + # params_schema, output_schema = await NodeManager.get_node_params(node_pool_record["_id"]) + params_schema, output_schema = {},{} parameters = { "input_parameters": params_schema, "output_parameters": output_schema, @@ -374,9 +375,8 @@ class FlowManager: focus_point=PositionItem(x=focus_point.x, y=focus_point.y), ) metadata.flows.append(new_flow) - app_loader = AppLoader.remote() - await 
app_loader.save.remote(metadata, app_id) # type: ignore[attr-type] - ray.kill(app_loader) + app_loader = AppLoader() + await app_loader.save(metadata, app_id) if result is None: LOGGER.error("Add flow failed") return None @@ -419,9 +419,8 @@ class FlowManager: for flow in metadata.flows: if flow.id == flow_id: metadata.flows.remove(flow) - app_loader = AppLoader.remote() - await app_loader.save.remote(metadata, app_id) # type: ignore[attr-type] - ray.kill(app_loader) + app_loader = AppLoader() + await app_loader.save(metadata, app_id) if result is None: LOGGER.error("Delete flow from app pool failed") return None @@ -472,9 +471,8 @@ class FlowManager: for flows in metadata.flows: if flows.id == flow_id: flows.debug = debug - app_loader = AppLoader.remote() - await app_loader.save.remote(metadata, app_id) # type: ignore[attr-type] - ray.kill(app_loader) + app_loader = AppLoader() + await app_loader.save(metadata, app_id) flow_loader = FlowLoader() flow = await flow_loader.load(app_id, flow_id) if flow is None: diff --git a/apps/manager/service.py b/apps/manager/service.py index cf4550a44..3e0cb147f 100644 --- a/apps/manager/service.py +++ b/apps/manager/service.py @@ -5,7 +5,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. 
import uuid from typing import Any, Optional -import ray import yaml from anyio import Path from jsonschema import ValidationError @@ -120,9 +119,8 @@ class ServiceCenterManager: author=user_sub, api=ServiceApiConfig(server=data["servers"][0]["url"]), ) - service_loader = ServiceLoader.remote() - await service_loader.save.remote(service_id, service_metadata, data) # type: ignore[attr-type] - ray.kill(service_loader) + service_loader = ServiceLoader() + await service_loader.save(service_id, service_metadata, data) # 返回服务ID return service_id @@ -154,9 +152,8 @@ class ServiceCenterManager: author=user_sub, api=ServiceApiConfig(server=data["servers"][0]["url"]), ) - service_loader = ServiceLoader.remote() - await service_loader.save.remote(service_id, service_metadata, data) # type: ignore[attr-type] - ray.kill(service_loader) + service_loader = ServiceLoader() + await service_loader.save(service_id, service_metadata, data) # 返回服务ID return service_id @@ -226,9 +223,8 @@ class ServiceCenterManager: msg = "Permission denied" raise ValueError(msg) # 删除服务 - service_loader = ServiceLoader.remote() - await service_loader.delete.remote(service_id) # type: ignore[attr-type] - ray.kill(service_loader) + service_loader = ServiceLoader() + await service_loader.delete(service_id) # 删除用户收藏 await user_collection.update_many( {"fav_services": {"$in": [service_id]}}, diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 2bb806368..850f1a6df 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -138,6 +138,7 @@ async def chat( ) + @router.post("/stop", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """停止生成""" diff --git a/apps/routers/mock.py b/apps/routers/mock.py index f2b7017c2..0f00ab577 100644 --- a/apps/routers/mock.py +++ b/apps/routers/mock.py @@ -1,5 +1,3 @@ -import asyncio -import copy import json import random import time @@ -8,7 +6,7 @@ from 
typing import Any, AsyncGenerator, Dict, Optional import aiohttp from pydantic import BaseModel, Field import tiktoken -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, status from fastapi.responses import StreamingResponse from apps.common.config import config @@ -19,20 +17,16 @@ from apps.dependency import ( verify_user, ) from apps.entities.request_data import MockRequestData, RequestData -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError from apps.manager.flow import FlowManager from apps.scheduler.pool.loader.flow import FlowLoader from datetime import datetime from textwrap import dedent from typing import Any -import pytz -from jinja2 import BaseLoader, select_autoescape -from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field -from apps.entities.scheduler import CallError, SysCallVars -from apps.scheduler.call.core import CoreCall +from apps.entities.scheduler import CallError """问答大模型调用 @@ -323,7 +317,7 @@ async def mock_data(appId="68dd3d90-6a97-4da0-aa62-d38a81c7d2f5", flowId="966c79 time.sleep(t) yield "data: " + json.dumps(message,ensure_ascii=False) + "\n\n" mid_message = [] - flow = await FlowLoader.load(appId, flowId) + flow = await FlowLoader().load(appId, flowId) now_flow_item = "start" start_time = time.time() last_item = "" diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 3164a7f80..8d8d8c2fa 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -5,7 +5,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
import shutil -import ray from anyio import Path from fastapi.encoders import jsonable_encoder from sqlalchemy import delete @@ -23,7 +22,6 @@ from apps.scheduler.pool.loader.flow import FlowLoader from apps.scheduler.pool.loader.metadata import MetadataLoader -@ray.remote class AppLoader: """应用加载器""" diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py index 3942059da..ffe786908 100644 --- a/apps/scheduler/pool/loader/metadata.py +++ b/apps/scheduler/pool/loader/metadata.py @@ -2,7 +2,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import json from typing import Any, Optional, Union import yaml diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 566ac0886..2386afbe1 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -23,7 +23,6 @@ from apps.scheduler.pool.loader.metadata import MetadataLoader, MetadataType from apps.scheduler.pool.loader.openapi import OpenAPILoader -@ray.remote class ServiceLoader: """Service 加载器""" diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 6d23c568d..a94b488e3 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -2,7 +2,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" -import asyncio import importlib from typing import Any, Optional @@ -36,43 +35,35 @@ class Pool: changed_service, deleted_service = await checker.diff(MetadataType.SERVICE) # 处理Service - service_loader = ServiceLoader.remote() + service_loader = ServiceLoader() # 批量删除 - delete_task = [service_loader.delete.remote(service) for service in changed_service] # type: ignore[attr-type] - delete_task += [service_loader.delete.remote(service) for service in deleted_service] # type: ignore[attr-type] - await asyncio.gather(*delete_task) + for service in changed_service: + await service_loader.delete(service) + for service in deleted_service: + await service_loader.delete(service) # 批量加载 - load_task = [] for service in changed_service: hash_key = Path(SERVICE_DIR + "/" + service).as_posix() if hash_key in checker.hashes: - load_task.append(service_loader.load.remote(service, checker.hashes[hash_key])) # type: ignore[attr-type] - await asyncio.gather(*load_task) - - # 完成Service load - ray.kill(service_loader) + await service_loader.load(service, checker.hashes[hash_key]) # 加载App changed_app, deleted_app = await checker.diff(MetadataType.APP) - app_loader = AppLoader.remote() + app_loader = AppLoader() # 批量删除App - delete_task = [app_loader.delete.remote(app) for app in changed_app] # type: ignore[attr-type] - delete_task += [app_loader.delete.remote(app) for app in deleted_app] # type: ignore[attr-type] - await asyncio.gather(*delete_task) + for app in changed_app: + await app_loader.delete(app) + for app in deleted_app: + await app_loader.delete(app) # 批量加载App - load_task = [] for app in changed_app: hash_key = Path(APP_DIR + "/" + app).as_posix() if hash_key in checker.hashes: - load_task.append(app_loader.load.remote(app, checker.hashes[hash_key])) # type: ignore[attr-type] - await asyncio.gather(*load_task) - - # 完成App load - ray.kill(app_loader) + await app_loader.load(app, checker.hashes[hash_key]) async def save(self, *, is_deletion: bool = False) -> None: -- Gitee 
From 6c44e7c76c4487f75af5e7e70c0c084dacc08fad Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 26 Feb 2025 07:05:13 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/task.py | 2 +- apps/entities/request_data.py | 4 +- apps/llm/patterns/select.py | 6 ++- apps/manager/node.py | 23 +++++++-- apps/routers/chat.py | 1 - apps/scheduler/call/api.py | 3 +- apps/scheduler/call/core.py | 7 +++ apps/scheduler/call/llm.py | 3 +- apps/scheduler/call/rag.py | 2 +- apps/scheduler/pool/loader/__init__.py | 3 +- apps/scheduler/pool/loader/flow.py | 69 ++++---------------------- apps/scheduler/pool/pool.py | 29 +++-------- apps/scheduler/scheduler/flow.py | 3 ++ apps/scheduler/scheduler/message.py | 12 ++--- apps/scheduler/scheduler/scheduler.py | 31 ++++++++---- 15 files changed, 86 insertions(+), 112 deletions(-) diff --git a/apps/common/task.py b/apps/common/task.py index 7970b22eb..e46ed338a 100644 --- a/apps/common/task.py +++ b/apps/common/task.py @@ -66,7 +66,7 @@ class Task: input_tokens=0, output_tokens=0, time=0, - feature=post_body.features.model_dump(by_alias=True), + feature={}, ), createdAt=round(datetime.now(timezone.utc).timestamp(), 3), ) diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index 5adf76329..d3b4c2100 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -7,6 +7,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field +from apps.common.config import config from apps.entities.appcenter import AppData from apps.entities.flow_topology import FlowItem, PositionItem from apps.entities.task import RequestDataApp @@ -24,7 +25,7 @@ class MockRequestData(BaseModel): class RequestDataFeatures(BaseModel): """POST /api/chat的features字段数据""" - max_tokens: int = Field(default=8192, description="最大生成token数", ge=0) + max_tokens: int = 
Field(default=config["LLM_MAX_TOKENS"], description="最大生成token数", ge=0) context_num: int = Field(default=5, description="上下文消息数量", le=10, ge=0) @@ -37,7 +38,6 @@ class RequestData(BaseModel): language: str = Field(default="zh", description="语言") files: list[str] = Field(default=[], description="文件列表") app: Optional[RequestDataApp] = Field(default=None, description="应用") - features: Optional[RequestDataFeatures] = Field(default=None, description="消息功能设置") debug: bool = Field(default=False, description="是否调试") diff --git a/apps/llm/patterns/select.py b/apps/llm/patterns/select.py index 6e1990f6c..d9c814c5d 100644 --- a/apps/llm/patterns/select.py +++ b/apps/llm/patterns/select.py @@ -7,6 +7,7 @@ import json from collections import Counter from typing import Any, ClassVar, Optional +from apps.constants import LOGGER from apps.llm.patterns.core import CorePattern from apps.llm.patterns.json import Json from apps.llm.reasoning import ReasoningLLM @@ -91,6 +92,7 @@ class Select(CorePattern): async def _generate_single_attempt(self, task_id: str, user_input: str, choice_list: list[str]) -> str: """使用ReasoningLLM进行单次尝试""" + LOGGER.info(f"[Select] Trying single attempt for task {task_id}...") messages = [ {"role": "system", "content": self.system_prompt}, {"role": "user", "content": user_input}, @@ -98,7 +100,7 @@ class Select(CorePattern): result = "" async for chunk in ReasoningLLM().call(task_id, messages, streaming=False): result += chunk - + LOGGER.info(f"[Select] Result: {result}") # 使用FunctionLLM进行参数提取 schema = self.slot_schema @@ -111,6 +113,7 @@ class Select(CorePattern): async def generate(self, task_id: str, **kwargs) -> str: # noqa: ANN003 """使用大模型做出选择""" + LOGGER.info(f"[Select] Selecting using LLM: {task_id}...") max_try = 3 result_list = [] @@ -129,4 +132,5 @@ class Select(CorePattern): result_list = await asyncio.gather(*result_coroutine) count = Counter(result_list) + LOGGER.info(f"[Select] Result: {count.most_common(1)[0][0]}") return 
count.most_common(1)[0][0] diff --git a/apps/manager/node.py b/apps/manager/node.py index 728a6f178..e4ef2b600 100644 --- a/apps/manager/node.py +++ b/apps/manager/node.py @@ -2,6 +2,7 @@ from typing import Any import ray +from pydantic import BaseModel from apps.constants import LOGGER from apps.entities.node import APINode @@ -65,9 +66,15 @@ class NodeManager: async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" # 查找Node信息 + LOGGER.info(f"[NodeManager] Getting node {node_id}...") node_collection = MongoDB().get_collection("node") + node = await node_collection.find_one({"_id": node_id}) + if not node: + err = f"[NodeManager] Node {node_id} not found." + LOGGER.error(err) + raise ValueError(err) + try: - node = await node_collection.find_one({"id": node_id}) node_data = NodePool.model_validate(node) except Exception as e: err = f"[NodeManager] Get node data error: {e}" @@ -77,8 +84,13 @@ class NodeManager: call_id = node_data.call_id # 查找Node对应的Call信息 call_collection = MongoDB().get_collection("call") + call = await call_collection.find_one({"_id": call_id}) + if not call: + err = f"[NodeManager] Call {call_id} not found." 
+ LOGGER.error(err) + raise ValueError(err) + try: - call = await call_collection.find_one({"id": call_id}) call_data = CallPool.model_validate(call) except Exception as e: err = f"[NodeManager] Get call data error: {e}" @@ -86,8 +98,9 @@ class NodeManager: raise ValueError(err) from e # 查找Call信息 + LOGGER.info(f"[NodeManager] Getting call {call_data.path}...") pool = ray.get_actor("pool") - call_class = await pool.get_call.remote(call_data.path) + call_class: type[BaseModel] = await pool.get_call.remote(call_data.path) if not call_class: err = f"[NodeManager] Call {call_data.path} not found" LOGGER.error(err) @@ -95,6 +108,6 @@ class NodeManager: # 返回参数Schema return ( - NodeManager.merge_params_schema(call_class.params_schema, node_data.known_params or {}), - call_class.output_schema, + NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), + call_class.ret_type.model_json_schema(), ) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 850f1a6df..2bb806368 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -138,7 +138,6 @@ async def chat( ) - @router.post("/stop", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """停止生成""" diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index bf6b43624..41e9f34c8 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -39,12 +39,13 @@ class APIOutput(BaseModel): output: dict[str, Any] = Field(description="API调用工具的输出") -class API(CoreCall): +class API(CoreCall, ret_type=APIOutput): """API调用工具""" name: ClassVar[str] = "HTTP请求" description: ClassVar[str] = "向某一个API接口发送HTTP请求,获取数据。" + async def __call__(self, syscall_vars: CallVars, **_kwargs: Any) -> APIOutput: """调用API,然后返回LLM解析后的数据""" self._session = aiohttp.ClientSession() diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index f80d9f3e7..59551dff5 
100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -16,6 +16,13 @@ class CoreCall(BaseModel): name: ClassVar[str] = Field(description="Call的名称") description: ClassVar[str] = Field(description="Call的描述") + ret_type: ClassVar[type[BaseModel]] + + def __init_subclass__(cls, ret_type: type[BaseModel], **kwargs: Any) -> None: + """初始化子类""" + super().__init_subclass__(**kwargs) + cls.ret_type = ret_type + class Config: """Pydantic 配置类""" diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm.py index 26e694f19..c3e8ea666 100644 --- a/apps/scheduler/call/llm.py +++ b/apps/scheduler/call/llm.py @@ -41,7 +41,7 @@ class LLMNodeOutput(BaseModel): message: str = Field(description="大模型输出的文字信息") -class LLM(CoreCall): +class LLM(CoreCall, ret_type=LLMNodeOutput): """大模型调用工具""" name: ClassVar[str] = "大模型" @@ -52,7 +52,6 @@ class LLM(CoreCall): system_prompt: str = Field(description="大模型系统提示词", default="") user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT) - async def __call__(self, syscall_vars: CallVars, **_kwargs: Any) -> LLMNodeOutput: """运行LLM Call""" # 参数 diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag.py index b8c824844..776d7c6b1 100644 --- a/apps/scheduler/call/rag.py +++ b/apps/scheduler/call/rag.py @@ -19,7 +19,7 @@ class RAGOutput(BaseModel): corpus: list[str] = Field(description="知识库的语料列表") -class RAG(CoreCall): +class RAG(CoreCall, ret_type=RAGOutput): """RAG工具:查询知识库""" name: ClassVar[str] = "知识库" diff --git a/apps/scheduler/pool/loader/__init__.py b/apps/scheduler/pool/loader/__init__.py index cb2a85024..8a842a898 100644 --- a/apps/scheduler/pool/loader/__init__.py +++ b/apps/scheduler/pool/loader/__init__.py @@ -4,6 +4,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" from apps.scheduler.pool.loader.app import AppLoader from apps.scheduler.pool.loader.call import CallLoader +from apps.scheduler.pool.loader.flow import FlowLoader from apps.scheduler.pool.loader.service import ServiceLoader -__all__ = ["AppLoader", "CallLoader", "ServiceLoader"] +__all__ = ["AppLoader", "CallLoader", "FlowLoader", "ServiceLoader"] diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 9271115ee..5461612e1 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -7,14 +7,12 @@ from typing import Optional import aiofiles import yaml from anyio import Path -from fastapi.encoders import jsonable_encoder from apps.common.config import config from apps.constants import APP_DIR, FLOW_DIR, LOGGER from apps.entities.enum_var import EdgeType -from apps.entities.flow import Flow, FlowConfig +from apps.entities.flow import Flow from apps.manager.node import NodeManager -from apps.models.mongo import MongoDB class FlowLoader: @@ -22,6 +20,7 @@ class FlowLoader: async def load(self, app_id, flow_id) -> Optional[Flow]: """从文件系统中加载【单个】工作流""" + LOGGER.info(f"[FlowLoader] Loading flow {flow_id} for app {app_id}...") flow_path = Path(config["SEMANTICS_DIR"]) / "app" / app_id / "flow" / f"{flow_id}.yaml" async with aiofiles.open(flow_path, encoding="utf-8") as f: flow_yaml = yaml.safe_load(await f.read()) @@ -39,6 +38,7 @@ class FlowLoader: LOGGER.error(err) raise ValueError(err) + LOGGER.info(f"[FlowLoader] Parsing edges of flow {flow_id} for app {app_id}...") for edge in flow_yaml["edges"]: # 把from变成edge_from,to改成edge_to,type改成edge_type if "from" in edge: @@ -54,6 +54,7 @@ class FlowLoader: LOGGER.error(err) raise ValueError(err) from e + LOGGER.info(f"[FlowLoader] Parsing steps of flow {flow_id} for app {app_id}...") for key, step in flow_yaml["steps"].items(): if key == "start": step["name"] = "开始" @@ -67,25 +68,20 @@ class FlowLoader: step["type"] = await 
NodeManager.get_node_call_id(step["node"]) step["name"] = await NodeManager.get_node_name(step["node"]) if "name" not in step or step["name"] == "" else step["name"] + LOGGER.info(f"[FlowLoader] Validating flow {flow_id} for app {app_id}...") try: # 检查Flow格式,并转换为Flow对象 - flow = Flow.model_validate(flow_yaml) + return Flow.model_validate(flow_yaml) except Exception as e: LOGGER.error(f"Invalid flow format: {e}") return None - await self._updata_db(FlowConfig(flow_config=flow, flow_id=flow_id)) - - return flow async def save(self, app_id: str, flow_id: str, flow: Flow) -> None: """保存工作流""" - await self._updata_db(FlowConfig(flow_config=flow, flow_id=flow_id)) flow_path = Path(config["SEMANTICS_DIR"]) / "app" / app_id / "flow" / f"{flow_id}.yaml" if not await flow_path.parent.exists(): await flow_path.parent.mkdir(parents=True) - if not await flow_path.exists(): - await flow_path.touch() flow_dict = { "name": flow.name, @@ -124,61 +120,14 @@ class FlowLoader: """删除指定工作流文件""" flow_path = Path(config["SEMANTICS_DIR"]) / APP_DIR / app_id / FLOW_DIR / f"{flow_id}.yaml" # 确保目标为文件且存在 - if await flow_path.is_file(): + if await flow_path.exists(): try: await flow_path.unlink() LOGGER.info(f"[FlowLoader] Successfully deleted flow file: {flow_path}") + return True except OSError as e: LOGGER.error(f"[FlowLoader] Failed to delete flow file {flow_path}: {e}") return False else: LOGGER.warning(f"[FlowLoader] Flow file does not exist or is not a file: {flow_path}") - return False - - flow_collection = MongoDB.get_collection("flow") - try: - await flow_collection.delete_one({"_id": flow_id}) - except Exception as e: - LOGGER.error(f"[FlowLoader] Failed to delete flow from database: {e}") - return False - - - async def _updata_db(self, flow_config: FlowConfig): - """更新数据库""" - try: - flow_collection = MongoDB.get_collection("flow") - flow = flow_config.flow_config - # 查询条件为app_id - if await flow_collection.find_one({"_id": flow_config.flow_id}) is None: - # 
创建应用时需写入完整数据结构,自动初始化创建时间、flow列表、收藏列表和权限 - await flow_collection.insert_one( - jsonable_encoder( - Flow( - name=flow.name, - description=flow.description, - on_error=flow.on_error, - steps=flow.steps, - edges=flow.edges, - debug=flow.debug, - ), - ), - ) - else: - # 更新应用数据:部分映射 AppMetadata 到 AppPool,其他字段不更新 - await flow_collection.update_one( - {"_id":flow_config.flow_id}, - jsonable_encoder( - Flow( - name=flow.name, - description=flow.description, - on_error=flow.on_error, - steps=flow.steps, - edges=flow.edges, - debug=flow.debug, - ), - ), - ) - except Exception as e: - err=f"[FlowLoader] Failed to update flow in database: {e}" - LOGGER.error(err) - raise ValueError(err) + return True diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index a94b488e3..fc98457d7 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -17,6 +17,7 @@ from apps.scheduler.pool.check import FileChecker from apps.scheduler.pool.loader import ( AppLoader, CallLoader, + FlowLoader, ServiceLoader, ) @@ -80,33 +81,19 @@ class Pool: if not flow_list: return [] for flow in flow_list["flows"]: - flow_metadata_list.extend(AppFlow(**flow)) + flow_metadata_list += [AppFlow.model_validate(flow)] + return flow_metadata_list except Exception as e: err = f"获取App{app_id}的Flow列表失败:{e}" LOGGER.error(err) - raise RuntimeError(err) from e + return [] - return flow_metadata_list - - # TODO async def get_flow(self, app_id: str, flow_id: str) -> Optional[Flow]: - """从数据库中获取单个Flow的全部数据""" - app_collection = MongoDB.get_collection("app") - try: - # 使用聚合管道来查找特定的flow - pipeline = [ - {"$match": {"_id": app_id}}, - {"$unwind": "$flows"}, - {"$match": {"flows._id": flow_id}}, - ] - async for flow in await app_collection.aggregate(pipeline): - return Flow(**flow) - return None - except Exception as e: - err = f"获取App {app_id} 的Flow {flow_id} 失败:{e}" - LOGGER.error(err) - raise RuntimeError(err) from e + """从文件系统中获取单个Flow的全部数据""" + LOGGER.info(f"[Pool] Getting flow 
{flow_id} for app {app_id}...") + flow_loader = FlowLoader() + return await flow_loader.load(app_id, flow_id) async def get_call(self, call_path: str) -> Any: diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 4be23bcd1..f9665c53d 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -7,6 +7,7 @@ from typing import Optional import ray +from apps.constants import LOGGER from apps.entities.flow import Flow from apps.entities.task import RequestDataApp from apps.llm.patterns import Select @@ -31,6 +32,7 @@ class FlowChooser: async def get_top_flow(self) -> str: """获取Top1 Flow""" + LOGGER.info(f"[FlowChooser] Judging top flow for task {self._task_id}...") pool = ray.get_actor("pool") # 获取所选应用的所有Flow if not self._user_selected or not self._user_selected.app_id: @@ -40,6 +42,7 @@ class FlowChooser: if not flow_list: return "KnowledgeBase" + LOGGER.info(f"[FlowChooser] Selecting top flow for task {self._task_id}...") return await Select().generate(self._task_id, question=self._question, choices=flow_list) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 22c1fab90..be91d6102 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -8,6 +8,7 @@ from typing import Union import ray +from apps.common.config import config from apps.constants import LOGGER from apps.entities.collection import Document from apps.entities.enum_var import EventType @@ -19,27 +20,26 @@ from apps.entities.message import ( ) from apps.entities.rag_data import RAGEventData, RAGQueryReq from apps.entities.record import RecordDocument -from apps.entities.request_data import RequestData from apps.entities.task import TaskBlock from apps.service import RAG -async def push_init_message(task_id: str, queue: ray.ObjectRef, post_body: RequestData, *, is_flow: bool = False) -> None: +async def push_init_message(task_id: str, queue: ray.ObjectRef, context_num: 
int, *, is_flow: bool = False) -> None: """推送初始化消息""" task_actor = ray.get_actor("task") task: TaskBlock = await task_actor.get_task.remote(task_id) # 组装feature if is_flow: feature = InitContentFeature( - maxTokens=post_body.features.max_tokens, - contextNum=post_body.features.context_num, + maxTokens=config["LLM_MAX_TOKENS"], + contextNum=context_num, enableFeedback=False, enableRegenerate=False, ) else: feature = InitContentFeature( - maxTokens=post_body.features.max_tokens, - contextNum=post_body.features.context_num, + maxTokens=config["LLM_MAX_TOKENS"], + contextNum=context_num, enableFeedback=True, enableRegenerate=True, ) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 254267e5f..65221f323 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -13,6 +13,7 @@ from apps.entities.rag_data import RAGQueryReq from apps.entities.request_data import RequestData from apps.entities.scheduler import ExecutorBackground from apps.entities.task import SchedulerResult, TaskBlock +from apps.manager.appcenter import AppCenterManager from apps.manager.user import UserManager from apps.scheduler.executor.flow import Executor from apps.scheduler.scheduler.context import get_context, get_docs @@ -49,14 +50,6 @@ class Scheduler: await queue.close.remote() # type: ignore[attr-defined] return SchedulerResult(used_docs=[]) - try: - # 获取上下文;最多20轮 - context, facts = await get_context(user_sub, post_body, post_body.features.context_num) - except Exception as e: - LOGGER.error(f"[Scheduler] Get context failed: {e!s}\n{traceback.format_exc()}") - await queue.close.remote() # type: ignore[attr-defined] - return SchedulerResult(used_docs=[]) - # 获取用户配置的kb_sn user_info = await UserManager.get_userinfo_by_user_sub(user_sub) if not user_info: @@ -78,7 +71,7 @@ class Scheduler: # 如果是智能问答,直接执行 if not post_body.app or post_body.app.app_id == "": - await push_init_message(task_id, queue, post_body, 
is_flow=False) + await push_init_message(task_id, queue, 3, is_flow=False) await asyncio.sleep(0.1) for doc in docs: # 保存使用的文件ID @@ -89,8 +82,18 @@ class Scheduler: # 保存有数据的最后一条消息 await push_rag_message(task_id, queue, user_sub, rag_data) else: + # 查找对应的App元数据 + app_data = await AppCenterManager.fetch_app_data_by_id(post_body.app.app_id) + if not app_data: + LOGGER.error(f"[Scheduler] App {post_body.app.app_id} not found") + await queue.close.remote() # type: ignore[attr-defined] + return SchedulerResult(used_docs=[]) + + # 获取上下文 + context, facts = await get_context(user_sub, post_body, app_data.history_len) + # 需要执行Flow - await push_init_message(task_id, queue, post_body, is_flow=True) + await push_init_message(task_id, queue, app_data.history_len, is_flow=True) # 组装上下文 executor_background = ExecutorBackground( conversation=context, @@ -110,10 +113,12 @@ class Scheduler: async def run_executor(self, task: TaskBlock, queue: ray.ObjectRef, post_body: RequestData, background: ExecutorBackground) -> None: """构造FlowExecutor,并执行所选择的流""" # 读取App中所有Flow的信息 + LOGGER.info(f"[Scheduler] Getting flow metadata for app {post_body.app}...") pool_actor = ray.get_actor("pool") if not post_body.app: LOGGER.error("[Scheduler] Not using workflow!") return + LOGGER.info(f"[Scheduler] Getting flow metadata for app {post_body.app}...") flow_info = await pool_actor.get_flow_metadata.remote(post_body.app.app_id) # 如果flow_info为空,则直接返回 @@ -123,11 +128,14 @@ class Scheduler: # 如果用户选了特定的Flow if post_body.app.flow_id: + LOGGER.info(f"[Scheduler] Getting flow data for app {post_body.app.app_id} and flow {post_body.app.flow_id}...") flow_data = await pool_actor.get_flow.remote(post_body.app.app_id, post_body.app.flow_id) else: # 如果用户没有选特定的Flow,则根据语义选择一个Flow + LOGGER.info(f"[Scheduler] Choosing top flow for app {post_body.app.app_id}...") flow_chooser = FlowChooser(task.record.task_id, post_body.question, post_body.app) flow_id = await flow_chooser.get_top_flow() + LOGGER.info(f"[Scheduler] 
Getting flow data for app {post_body.app.app_id} and flow {flow_id}...") flow_data = await pool_actor.get_flow.remote(post_body.app.app_id, flow_id) # 如果flow_data为空,则直接返回 @@ -136,6 +144,7 @@ class Scheduler: return # 初始化Executor + LOGGER.info(f"[Scheduler] Initializing executor for app {post_body.app.app_id} and flow {flow_id}...") flow_exec = Executor( name=flow_data.name, description=flow_data.description, @@ -147,5 +156,7 @@ class Scheduler: executor_background=background, ) # 开始运行 + LOGGER.info(f"[Scheduler] Loading state for app {post_body.app.app_id} and flow {flow_id}...") await flow_exec.load_state() + LOGGER.info(f"[Scheduler] Running executor for app {post_body.app.app_id} and flow {flow_id}...") await flow_exec.run() -- Gitee