From d33d48bd8c1840049d2106abfe1f7e0632e4ef23 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Thu, 23 Jan 2025 21:33:59 +0800 Subject: [PATCH] =?UTF-8?q?Call=E4=BD=BF=E7=94=A8=E8=87=AA=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=E8=BF=94=E5=9B=9E=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/constants.py | 2 +- apps/entities/plugin.py | 11 +- apps/manager/knowledge.py | 5 +- apps/scheduler/call/__init__.py | 6 +- apps/scheduler/call/api.py | 14 +- apps/scheduler/call/choice.py | 50 ++++--- apps/scheduler/call/cmd/assembler.py | 1 - apps/scheduler/call/cmd/cmd.py | 8 +- apps/scheduler/call/cmd/solver.py | 3 +- apps/scheduler/call/core.py | 7 + apps/scheduler/call/llm.py | 34 +++-- apps/scheduler/call/next_flow.py | 10 +- apps/scheduler/call/rag.py | 0 apps/scheduler/call/reformat.py | 48 ++++--- apps/scheduler/call/render/render.py | 108 ++++++-------- apps/scheduler/call/render/style.py | 151 ++++++++++---------- deploy/chart/euler_copilot/configs/rag/.env | 4 +- 17 files changed, 231 insertions(+), 231 deletions(-) create mode 100644 apps/scheduler/call/rag.py diff --git a/apps/constants.py b/apps/constants.py index 270fd8814..642281a45 100644 --- a/apps/constants.py +++ b/apps/constants.py @@ -18,5 +18,5 @@ MAX_API_RESPONSE_LENGTH = 4096 MAX_SCHEDULER_HISTORY_SIZE = 3 # 语义接口目录中工具子目录 CALL_DIR = "call" - +# 日志记录器 LOGGER = logging.getLogger("ray") diff --git a/apps/entities/plugin.py b/apps/entities/plugin.py index 0855100a8..af3082082 100644 --- a/apps/entities/plugin.py +++ b/apps/entities/plugin.py @@ -2,7 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -10,15 +10,6 @@ from apps.common.queue import MessageQueue from apps.entities.task import FlowHistory, RequestDataPlugin -class CallResult(BaseModel): - """Call运行后的返回值""" - - message: str = Field(description="经LLM理解后的Call的输出") - output: dict[str, Any] = Field(description="Call的原始输出") - output_schema: dict[str, Any] = Field(description="Call中Output对应的Schema") - extra: Optional[dict[str, Any]] = Field(description="Call的额外输出", default=None) - - class SysCallVars(BaseModel): """所有Call都需要接受的参数。包含用户输入、上下文信息、Step的输出记录等 diff --git a/apps/manager/knowledge.py b/apps/manager/knowledge.py index dc2250444..d56c3b6a6 100644 --- a/apps/manager/knowledge.py +++ b/apps/manager/knowledge.py @@ -16,10 +16,11 @@ class KnowledgeBaseManager: """修改当前用户的知识库ID""" user_collection = MongoDB.get_collection("user") try: - result = await user_collection.update_one({"_id": user_sub}, {"$set": {"kb_id": kb_id}}) - if result.modified_count == 0: + user = await user_collection.find_one({"_id": user_sub}, {"kb_id": 1}) + if user is None: LOGGER.error("[KnowledgeBaseManager] change kb_id error: user_sub not found") return False + await user_collection.update_one({"_id": user_sub}, {"$set": {"kb_id": kb_id}}) return True except Exception as e: LOGGER.error(f"[KnowledgeBaseManager] change kb_id error: {e}") diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index b8075f44b..467770842 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -2,10 +2,10 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from apps.scheduler.call.api.api import API +from apps.scheduler.call.api import API from apps.scheduler.call.choice import Choice from apps.scheduler.call.llm import LLM -from apps.scheduler.call.reformat import Extract +from apps.scheduler.call.reformat import Reformat from apps.scheduler.call.render.render import Render from apps.scheduler.call.sql import SQL @@ -14,6 +14,6 @@ __all__ = [ "LLM", "SQL", "Choice", - "Extract", + "Reformat", "Render", ] diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index 8bcb6f734..5ca7c7ac0 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -10,7 +10,7 @@ from fastapi import status from pydantic import BaseModel, Field from apps.constants import LOGGER -from apps.entities.plugin import CallError, CallResult, SysCallVars +from apps.entities.plugin import CallError, SysCallVars from apps.manager.token import TokenManager from apps.scheduler.call.core import CoreCall from apps.scheduler.slot.slot import Slot @@ -26,6 +26,12 @@ class _APIParams(BaseModel): service_id: Optional[str] = Field(description="服务ID") +class _APIOutput(BaseModel): + """API调用工具的输出""" + + output: dict[str, Any] = Field(description="API调用工具的输出") + + class API(CoreCall): """API调用工具""" @@ -36,8 +42,6 @@ class API(CoreCall): async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 """初始化API调用工具""" - await super().init(syscall_vars, **kwargs) - # 插件鉴权 self._auth = json.loads(str(plugin_data.auth)) # 从spec中找出该接口对应的spec @@ -54,7 +58,7 @@ class API(CoreCall): self.slot_schema, self._data_type = self._check_data_type(self._spec[2]["requestBody"]["content"]) elif method == "GET": if "parameters" in self._spec[2]: - self.slot_schema = APISanitizer.parameters_to_spec(self._spec[2]["parameters"]) + self.slot_schema = self.parameters_to_spec(self._spec[2]["parameters"]) self._data_type = "json" else: err = "[API] HTTP method not implemented." @@ -217,4 +221,4 @@ class API(CoreCall): response_schema = {} LOGGER.info(f"调用接口{url}, 结果为 {response_data}") - return APISanitizer.process(response_data, url, self._spec[1], response_schema) + return self.process(response_data, url, self._spec[1], response_schema) diff --git a/apps/scheduler/call/choice.py b/apps/scheduler/call/choice.py index a5cb6bee0..f2845fa06 100644 --- a/apps/scheduler/call/choice.py +++ b/apps/scheduler/call/choice.py @@ -6,19 +6,33 @@ from typing import Any from pydantic import BaseModel, Field -from apps.entities.plugin import CallError, CallResult +from apps.entities.plugin import CallError, SysCallVars from apps.llm.patterns.select import Select from apps.scheduler.call.core import CoreCall +class _ChoiceBranch(BaseModel): + """Choice工具的选项""" + + name: str = Field(description="选项的名称") + value: str = Field(description="选项的值") + + class _ChoiceParams(BaseModel): """Choice工具所需的额外参数""" propose: str = Field(description="针对哪一个问题进行答案选择?") - choices: list[dict[str, Any]] = Field(description="Choice工具所有可能的选项") + choices: list[_ChoiceBranch] = Field(description="Choice工具所有可能的选项") + + +class _ChoiceOutput(BaseModel): + """Choice工具的输出""" + + message: str = Field(description="Choice工具的输出") + next_step: str = Field(description="Choice工具的输出") -class Choice(CoreCall): +class Choice(metaclass=CoreCall): """Choice工具。用于大模型在多个选项中选择一个,并跳转到对应的Step。""" name: str = "choice" @@ -26,32 +40,28 @@ class Choice(CoreCall): params: type[_ChoiceParams] = _ChoiceParams - async def call(self, _slot_data: dict[str, Any]) -> CallResult: - """调用Choice工具。 + async def __call__(self, _slot_data: dict[str, Any]) -> _ChoiceOutput: + """调用Choice工具。""" + # 获取必要参数 + params: _ChoiceParams = getattr(self, "_params") + syscall_vars: SysCallVars = getattr(self, "_syscall_vars") - :param _slot_data: 经用户修正过的参数(暂未使用) - :return: Choice工具的输出信息。包含下一个Step的名称、自然语言解释等。 - """ previous_data = {} - if len(self._syscall_vars.history) > 0: - previous_data = CallResult(**self._syscall_vars.history[-1].output_data).output + if len(syscall_vars.history) > 0: + previous_data = syscall_vars.history[-1].output_data try: result = await Select().generate( - question=self.params.propose, - background=self._syscall_vars.background, + question=params.propose, + background=syscall_vars.background, data=previous_data, - choices=self.params.choices, - task_id=self._syscall_vars.task_id, + choices=params.choices, + task_id=syscall_vars.task_id, ) except Exception as e: raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e - return CallResult( - output={}, - output_schema={}, - extra={ - "next_step": result, - }, + return _ChoiceOutput( + next_step=result, message=f"针对“{self.params.propose}”,作出的选择为:{result}。", ) diff --git a/apps/scheduler/call/cmd/assembler.py b/apps/scheduler/call/cmd/assembler.py index f4f0bb7c0..d1f774187 100644 --- a/apps/scheduler/call/cmd/assembler.py +++ b/apps/scheduler/call/cmd/assembler.py @@ -6,7 +6,6 @@ import string from typing import Any, Literal, Optional from apps.llm.patterns.select import Select -from apps.scheduler.vector import DocumentWrapper, VectorDB class CommandlineAssembler: diff --git a/apps/scheduler/call/cmd/cmd.py b/apps/scheduler/call/cmd/cmd.py index 9d7e15e97..790095c6c 100644 --- a/apps/scheduler/call/cmd/cmd.py +++ b/apps/scheduler/call/cmd/cmd.py @@ -6,7 +6,6 @@ from typing import Any, Optional from pydantic import BaseModel, Field -from apps.entities.plugin import CallResult from apps.scheduler.call.core import CoreCall @@ -17,14 +16,17 @@ class _CmdParams(BaseModel): args: list[str] = Field(default=[], description="命令中可执行文件的参数(例如 `--help`),可选") +class _CmdOutput(BaseModel): + """Cmd工具的输出""" + + class Cmd(CoreCall): """Cmd工具。用于根据BTDL描述文件,生成命令。""" name: str = "cmd" description: str = "根据BTDL描述文件,生成命令。" - params: type[_CmdParams] = _CmdParams - async def call(self, _slot_data: dict[str, Any]) -> CallResult: + async def __call__(self, _slot_data: dict[str, Any]) -> _CmdOutput: """调用Cmd工具""" pass diff --git a/apps/scheduler/call/cmd/solver.py b/apps/scheduler/call/cmd/solver.py index 20acce2a2..513ea7379 100644 --- a/apps/scheduler/call/cmd/solver.py +++ b/apps/scheduler/call/cmd/solver.py @@ -14,7 +14,8 @@ class Solver: """解析命令行生成器""" @staticmethod - async def _get_option(agent_input: str, collection_name: str, binary_name: str, subcmd_name: str, spec: dict[str, Any]): + async def _get_option(agent_input: str, collection_name: str, binary_name: str, subcmd_name: str, spec: dict[str, Any]) -> tuple[str, str]: + """选择最匹配的命令行参数""" # 选择最匹配的Global Options global_options = CommandlineAssembler.get_data("global_option", agent_input, collection_name, binary_name, num=2) # 选择最匹配的Options diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 7e98343ab..271da63b7 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -53,10 +53,17 @@ class CoreCall(type): if not issubclass(kwargs["param_cls"], BaseModel): err = f"参数模板{kwargs['param_cls']}不是Pydantic类!" raise TypeError(err) + if "output_cls" not in kwargs: + err = f"请给工具{name}提供输出模板!" + raise AttributeError(err) + if not issubclass(kwargs["output_cls"], BaseModel): + err = f"输出模板{kwargs['output_cls']}不是Pydantic类!" + raise TypeError(err) # 设置参数相关的属性 attrs["_param_cls"] = kwargs["param_cls"] attrs["params_schema"] = kwargs["param_cls"].model_json_schema() + attrs["output_schema"] = kwargs["output_cls"].model_json_schema() # __init__不允许自定义 attrs["__init__"] = lambda self, syscall_vars, **kwargs: self._class_init_fixed(syscall_vars, **kwargs) # 提供空逻辑占位 diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm.py index f6a28d6d7..a99deea83 100644 --- a/apps/scheduler/call/llm.py +++ b/apps/scheduler/call/llm.py @@ -11,7 +11,7 @@ from jinja2 import BaseLoader, select_autoescape from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field -from apps.entities.plugin import CallError, CallResult +from apps.entities.plugin import CallError, SysCallVars from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall @@ -41,23 +41,31 @@ class _LLMParams(BaseModel): timeout: int = Field(description="超时时间", default=30) -class LLM(CoreCall): +class _LLMOutput(BaseModel): + """定义LLM工具调用的输出""" + + message: str = Field(description="大模型输出的文字信息") + + +class LLM(metaclass=CoreCall, param_cls=_LLMParams, output_cls=_LLMOutput): """大模型调用工具""" name: str = "llm" description: str = "大模型调用工具,用于以指定的提示词和上下文信息调用大模型,并获得输出。" - params: type[_LLMParams] = _LLMParams - async def call(self, _slot_data: dict[str, Any]) -> CallResult: + async def __call__(self, _slot_data: dict[str, Any]) -> _LLMOutput: """运行LLM Call""" + # 获取必要参数 + syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + params: _LLMParams = getattr(self, "_params") # 参数 time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") formatter = { "time": time, - "context": self._syscall_vars.background, - "question": self._syscall_vars.question, - "history": self._syscall_vars.history, + "context": syscall_vars.background, + "question": syscall_vars.question, + "history": syscall_vars.history, } try: @@ -67,14 +75,14 @@ class LLM(CoreCall): autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(self.params.system_prompt) + ).from_string(params.system_prompt) system_input = system_tmpl.render(**formatter) user_tmpl = SandboxedEnvironment( loader=BaseLoader(), autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(self.params.user_prompt) + ).from_string(params.user_prompt) user_input = user_tmpl.render(**formatter) except Exception as e: raise CallError(message=f"用户提示词渲染失败:{e!s}", data={}) from e @@ -86,13 +94,9 @@ class LLM(CoreCall): try: result = "" - async for chunk in ReasoningLLM().call(task_id=self._syscall_vars.task_id, messages=message): + async for chunk in ReasoningLLM().call(task_id=syscall_vars.task_id, messages=message): result += chunk except Exception as e: raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e - return CallResult( - output={}, - message=result, - output_schema={}, - ) + return _LLMOutput(message=result) diff --git a/apps/scheduler/call/next_flow.py b/apps/scheduler/call/next_flow.py index 5bddaee43..c60b90636 100644 --- a/apps/scheduler/call/next_flow.py +++ b/apps/scheduler/call/next_flow.py @@ -1,13 +1,15 @@ """用于下一步工作流推荐的工具""" -from apps.scheduler.call.core import CallResult, CoreCall +from typing import Any +from apps.scheduler.call.core import CoreCall -class NextFlowCall(CoreCall): + +class NextFlowCall(metaclass=CoreCall): """用于下一步工作流推荐的工具""" name = "next_flow" description = "用于下一步工作流推荐的工具" - def call(self) -> CallResult: - return CallResult(output={}, message="", output_schema={}) + async def __call__(self, _slot_data: dict[str, Any]): + """调用NextFlow工具""" diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/scheduler/call/reformat.py b/apps/scheduler/call/reformat.py index d86cc607b..608dc05e5 100644 --- a/apps/scheduler/call/reformat.py +++ b/apps/scheduler/call/reformat.py @@ -13,7 +13,7 @@ from jinja2 import BaseLoader, select_autoescape from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field -from apps.entities.plugin import CallResult, SysCallVars +from apps.entities.plugin import SysCallVars from apps.scheduler.call.core import CoreCall @@ -24,57 +24,59 @@ class _ReformatParam(BaseModel): data: Optional[str] = Field(description="对生成的原始数据(JSON)进行格式化,没有则不改动;jsonnet语法", default=None) -class Extract(CoreCall): +class _ReformatOutput(BaseModel): + """定义Reformat工具的输出""" + + message: str = Field(description="格式化后的文字信息") + output: dict = Field(description="格式化后的结果") + + +class Reformat(metaclass=CoreCall, param_cls=_ReformatParam, output_cls=_ReformatOutput): """Reformat 工具,用于对生成的文字信息和原始数据进行格式化""" name: str = "reformat" description: str = "从上一步的工具的原始JSON返回结果中,提取特定字段的信息。" - params: type[_ReformatParam] = _ReformatParam - async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 - """初始化Reformat工具""" - await super().init(syscall_vars, **kwargs) - self._last_output = CallResult(**self._syscall_vars.history[-1].output_data) - async def call(self, _slot_data: dict[str, Any]) -> CallResult: + async def __call__(self, _slot_data: dict[str, Any]) -> _ReformatOutput: """调用Reformat工具 :param _slot_data: 经用户确认后的参数(目前未使用) :return: 提取出的字段 """ + # 获取必要参数 + params: _ReformatParam = getattr(self, "_params") + syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + last_output = syscall_vars.history[-1].output_data # 判断用户是否给了值 time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") - if self.params.text is None: - result_message = self._last_output.message + if params.text is None: + result_message = last_output.get("message", "") else: text_template = SandboxedEnvironment( loader=BaseLoader(), autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(self.params.text) - result_message = text_template.render(time=time, history=self._syscall_vars.history, question=self._syscall_vars.question) + ).from_string(params.text) + result_message = text_template.render(time=time, history=syscall_vars.history, question=syscall_vars.question) - if self.params.data is None: - result_data = self._last_output.output + if params.data is None: + result_data = last_output.get("output", {}) else: extra_str = json.dumps({ "time": time, - "question": self._syscall_vars.question, + "question": syscall_vars.question, }, ensure_ascii=False) - history_str = json.dumps([CallResult(**item.output_data).output for item in self._syscall_vars.history], ensure_ascii=False) + history_str = json.dumps([CallResult(**item.output_data).output for item in syscall_vars.history], ensure_ascii=False) data_template = dedent(f""" local extra = {extra_str}; local history = {history_str}; - {self.params.data} + {params.data} """) - result_data = json.loads(_jsonnet.evaluate_snippet(data_template, self.params.data), ensure_ascii=False) + result_data = json.loads(_jsonnet.evaluate_snippet(data_template, params.data), ensure_ascii=False) - return CallResult( + return _ReformatOutput( message=result_message, output=result_data, - output_schema={ - "type": "object", - "description": "格式化后的结果", - }, ) diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/render/render.py index 6cd406bf6..443f5bd5f 100644 --- a/apps/scheduler/call/render/render.py +++ b/apps/scheduler/call/render/render.py @@ -6,25 +6,52 @@ import json from pathlib import Path from typing import Any -from apps.entities.plugin import CallError, CallResult, SysCallVars +from pydantic import BaseModel, Field + +from apps.entities.plugin import CallError, SysCallVars from apps.scheduler.call.core import CoreCall from apps.scheduler.call.render.style import RenderStyle -class Render(CoreCall): +class _RenderAxis(BaseModel): + """ECharts图表的轴配置""" + + type: str = Field(description="轴的类型") + axisTick: dict = Field(description="轴刻度配置") # noqa: N815 + + +class _RenderFormat(BaseModel): + """ECharts图表配置""" + + tooltip: dict[str, Any] = Field(description="ECharts图表的提示框配置") + legend: dict[str, Any] = Field(description="ECharts图表的图例配置") + dataset: dict[str, Any] = Field(description="ECharts图表的数据集配置") + xAxis: _RenderAxis = Field(description="ECharts图表的X轴配置") # noqa: N815 + yAxis: _RenderAxis = Field(description="ECharts图表的Y轴配置") # noqa: N815 + series: list[dict[str, Any]] = Field(description="ECharts图表的数据列配置") + + +class _RenderOutput(BaseModel): + """Render工具的输出""" + + output: _RenderFormat = Field(description="ECharts图表配置") + message: str = Field(description="Render工具的输出") + + +class _RenderParam(BaseModel): + """Render工具的参数""" + + + +class Render(metaclass=CoreCall, param_cls=_RenderParam, output_cls=_RenderOutput): """Render Call,用于将SQL Tool查询出的数据转换为图表""" name: str = "render" description: str = "渲染图表工具,可将给定的数据绘制为图表。" - async def init(self, syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 - """初始化Render Call,校验参数,读取option模板 - - :param syscall_vars: Render Call参数 - """ - await super().init(syscall_vars, **_kwargs) - + def init(self, _syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 + """初始化Render Call,校验参数,读取option模板""" try: option_location = Path(__file__).parent / "option.json" with Path(option_location).open(encoding="utf-8") as f: @@ -33,10 +60,13 @@ class Render(CoreCall): raise CallError(message=f"图表模板读取失败:{e!s}", data={}) from e - async def call(self, _slot_data: dict[str, Any]) -> CallResult: + async def __call__(self, _slot_data: dict[str, Any]) -> _RenderOutput: """运行Render Call""" + # 获取必要参数 + syscall_vars: SysCallVars = getattr(self, "_syscall_vars") + # 检测前一个工具是否为SQL - data = CallResult(**self._syscall_vars.history[-1].output_data).output + data = syscall_vars.history[-1].output_data if data["type"] != "sql" or "dataset" not in data: raise CallError( message="图表生成失败!Render必须在SQL后调用!", @@ -67,65 +97,18 @@ class Render(CoreCall): self._option_template["dataset"]["source"] = data try: - llm_output = await RenderStyle().generate(self._syscall_vars.task_id, question=self._syscall_vars.question) + llm_output = await RenderStyle().generate(syscall_vars.task_id, question=syscall_vars.question) add_style = llm_output.get("additional_style", "") self._parse_options(column_num, llm_output["chart_type"], add_style, llm_output["scale_type"]) except Exception as e: raise CallError(message=f"图表生成失败:{e!s}", data={"data": data}) from e - return CallResult( - output=self._option_template, - output_schema={ - "type": "object", - "description": "ECharts图表配置", - "properties": { - "tooltip": { - "type": "object", - "description": "ECharts图表的提示框配置", - }, - "legend": { - "type": "object", - "description": "ECharts图表的图例配置", - }, - "dataset": { - "type": "object", - "description": "ECharts图表的数据集配置", - }, - "xAxis": { - "type": "object", - "description": "ECharts图表的X轴配置", - "properties": { - "type": { - "type": "string", - "description": "ECharts图表的X轴类型", - "default": "category", - }, - "axisTick": { - "type": "object", - "description": "ECharts图表的X轴刻度配置", - }, - }, - }, - "yAxis": { - "type": "object", - "description": "ECharts图表的Y轴配置", - "properties": { - "type": { - "type": "string", - "description": "ECharts图表的Y轴类型", - "default": "value", - }, - }, - }, - "series": { - "type": "array", - "description": "ECharts图表的数据列配置", - }, - }, - }, + return _RenderOutput( + output=_RenderFormat.model_validate(self._option_template), message="图表生成成功!图表将使用外置工具进行展示。", ) + @staticmethod def _separate_key_value(data: list[dict[str, Any]]) -> list[dict[str, Any]]: """若数据只有一组(例如:{"aaa": "bbb"}),则分离键值对。 @@ -141,6 +124,7 @@ class Render(CoreCall): result.append({"type": key, "value": val}) return result + def _parse_options(self, column_num: int, chart_style: str, additional_style: str, scale_style: str) -> None: """解析LLM做出的图表样式选择""" series_template = {} diff --git a/apps/scheduler/call/render/style.py b/apps/scheduler/call/render/style.py index f9369dfe0..f014ea39a 100644 --- a/apps/scheduler/call/render/style.py +++ b/apps/scheduler/call/render/style.py @@ -2,7 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from typing import Any, Optional +from typing import Any, ClassVar, Optional from apps.llm.patterns.core import CorePattern from apps.llm.patterns.json import Json @@ -12,95 +12,88 @@ from apps.llm.reasoning import ReasoningLLM class RenderStyle(CorePattern): """选择图表样式""" - @property - def predefined_system_prompt(self) -> str: - """系统提示词""" - return r""" - You are a helpful assistant. Help the user make style choices when drawing a chart. - Chart title should be short and less than 3 words. - - Available types: - - `bar`: Bar graph - - `pie`: Pie graph - - `line`: Line graph - - `scatter`: Scatter graph - - Available bar additional styles: - - `normal`: Normal bar graph - - `stacked`: Stacked bar graph - - Available pie additional styles: - - `normal`: Normal pie graph - - `ring`: Ring pie graph - - Available scales: - - `linear`: Linear scale - - `log`: Logarithmic scale - - EXAMPLE - ## Question - 查询数据库中的数据,并绘制堆叠柱状图。 - - ## Thought - Let's think step by step. The user requires drawing a stacked bar chart, so the chart type should be `bar`, \ - i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form. - - ## Answer - The chart type should be: bar - The chart style should be: stacked - The scale should be: linear - - END OF EXAMPLE - - Let's begin. - """ - - def predefined_user_prompt(self) -> str: - """用户提示词""" - return r""" - ## Question - {question} - - ## Thought - Let's think step by step. - """ - - def slot_schema(self) -> dict[str, Any]: - """槽位Schema""" - return { - "type": "object", - "properties": { - "chart_type": { - "type": "string", - "description": "The type of the chart.", - "enum": ["bar", "pie", "line", "scatter"], - }, - "additional_style": { - "type": "string", - "description": "The additional style of the chart.", - "enum": ["normal", "stacked", "ring"], - }, - "scale_type": { - "type": "string", - "description": "The scale of the chart.", - "enum": ["linear", "log"], - }, + system_prompt = r""" + You are a helpful assistant. Help the user make style choices when drawing a chart. + Chart title should be short and less than 3 words. + + Available types: + - `bar`: Bar graph + - `pie`: Pie graph + - `line`: Line graph + - `scatter`: Scatter graph + + Available bar additional styles: + - `normal`: Normal bar graph + - `stacked`: Stacked bar graph + + Available pie additional styles: + - `normal`: Normal pie graph + - `ring`: Ring pie graph + + Available scales: + - `linear`: Linear scale + - `log`: Logarithmic scale + + EXAMPLE + ## Question + 查询数据库中的数据,并绘制堆叠柱状图。 + + ## Thought + Let's think step by step. The user requires drawing a stacked bar chart, so the chart type should be `bar`, \ + i.e. a bar chart; the chart style should be `stacked`, i.e. a stacked form. + + ## Answer + The chart type should be: bar + The chart style should be: stacked + The scale should be: linear + + END OF EXAMPLE + + Let's begin. + """ + + user_prompt = r""" + ## Question + {question} + + ## Thought + Let's think step by step. + """ + + slot_schema: ClassVar[dict[str, Any]] = { + "type": "object", + "properties": { + "chart_type": { + "type": "string", + "description": "The type of the chart.", + "enum": ["bar", "pie", "line", "scatter"], }, - "required": ["chart_type", "scale_type"], - } + "additional_style": { + "type": "string", + "description": "The additional style of the chart.", + "enum": ["normal", "stacked", "ring"], + }, + "scale_type": { + "type": "string", + "description": "The scale of the chart.", + "enum": ["linear", "log"], + }, + }, + "required": ["chart_type", "scale_type"], + } def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None: """初始化RenderStyle Prompt""" super().__init__(system_prompt, user_prompt) - async def generate(self, task_id: str, **kwargs) -> dict[str, Any]: + async def generate(self, task_id: str, **kwargs) -> dict[str, Any]: # noqa: ANN003 """使用LLM选择图表样式""" question = kwargs["question"] # 使用Reasoning模型进行推理 messages = [ - {"role": "system", "content": self._system_prompt}, - {"role": "user", "content": self._user_prompt.format(question=question)}, + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format(question=question)}, ] result = "" async for chunk in ReasoningLLM().call(task_id, messages, streaming=False): diff --git a/deploy/chart/euler_copilot/configs/rag/.env b/deploy/chart/euler_copilot/configs/rag/.env index 83851d566..cb3578df4 100644 --- a/deploy/chart/euler_copilot/configs/rag/.env +++ b/deploy/chart/euler_copilot/configs/rag/.env @@ -25,9 +25,9 @@ TASK_RETRY_TIME=3 # Embedding Service {{- if .Values.euler_copilot.rag.vectorize.use_internal }} -REMOTE_EMBEDDING_ENDPOINT=http://vectorize-agent-service-{{ .Release.Name }}.{{ .Release.Namespace }}.svc.cluster.local:8001 +REMOTE_EMBEDDING_ENDPOINT=http://vectorize-agent-service-{{ .Release.Name }}.{{ .Release.Namespace }}.svc.cluster.local:8001/embeddings {{- else }} -REMOTE_EMBEDDING_ENDPOINT={{ .Values.euler_copilot.rag.vectorize.url }} +REMOTE_EMBEDDING_ENDPOINT={{ .Values.euler_copilot.rag.vectorize.url }}/embeddings {{- end }} # Token -- Gitee