From 4416ba89ee68b1ad8e217d3d7c2ab9a5fea079ea Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 15 Jan 2025 10:40:40 +0800 Subject: [PATCH] Framework dev --- Dockerfile | 4 +- Jenkinsfile | 62 +++---- apps/common/config.py | 20 +-- apps/common/oidc.py | 33 +++- apps/common/queue.py | 2 +- apps/common/wordscheck.py | 164 +++++++++++++++++- apps/constants.py | 10 +- apps/dependency/session.py | 1 + apps/entities/{enum.py => enum_var.py} | 14 ++ apps/entities/flow.py | 101 +++++++++++ apps/entities/message.py | 2 +- apps/entities/plugin.py | 30 ---- apps/entities/record.py | 2 +- apps/entities/request_data.py | 3 +- apps/entities/response_data.py | 2 +- apps/entities/task.py | 2 +- apps/entities/vector.py | 33 ++++ apps/gunicorn.conf.py | 34 ---- apps/llm/function.py | 3 + apps/llm/patterns/rewoo.py | 23 +++ apps/llm/reasoning.py | 63 +------ apps/main.py | 22 +++ apps/manager/document.py | 17 +- apps/manager/gitee_white_list.py | 45 +++++ apps/manager/session.py | 11 +- apps/manager/task.py | 2 +- apps/models/minio.py | 49 ++++++ apps/models/mongo.py | 8 + apps/routers/document.py | 2 +- apps/scheduler/call/api/api.py | 52 ++++-- apps/scheduler/call/choice.py | 31 ++-- apps/scheduler/call/cmd/cmd.py | 13 +- apps/scheduler/call/core.py | 24 +-- apps/scheduler/call/llm.py | 26 +-- apps/scheduler/call/next_flow.py | 13 ++ apps/scheduler/call/reformat.py | 31 ++-- apps/scheduler/call/render/format.py | 15 -- apps/scheduler/call/render/render.py | 13 +- apps/scheduler/call/sql.py | 20 +-- apps/scheduler/embedding.py | 21 +++ apps/scheduler/executor/flow.py | 6 +- apps/scheduler/executor/message.py | 4 +- apps/scheduler/json_schema.py | 23 +-- apps/scheduler/openapi.py | 162 +++++++++++++++++ apps/scheduler/pool/loader.py | 2 +- apps/scheduler/pool/loader/__init__.py | 9 + apps/scheduler/pool/{ => loader}/btdl.py | 16 +- apps/scheduler/pool/loader/metadata.py | 31 ++++ apps/scheduler/pool/loader/openapi.py | 45 +++++ apps/scheduler/pool/pool.py | 3 +- apps/scheduler/scheduler/message.py | 2 +- apps/scheduler/scheduler/scheduler.py | 2 +- apps/scheduler/slot/parser/__init__.py | 4 + apps/scheduler/slot/parser/const.py | 24 +++ apps/scheduler/slot/parser/core.py | 2 +- apps/scheduler/slot/parser/date.py | 2 +- apps/scheduler/slot/parser/default.py | 24 +++ apps/scheduler/slot/parser/timestamp.py | 2 +- apps/scheduler/slot/slot.py | 23 ++- apps/scheduler/vector.py | 131 -------------- apps/service/activity.py | 2 +- apps/service/rag.py | 4 +- apps/service/suggestion.py | 2 +- apps/utils/get_api_doc.py | 2 +- assets/.env.example | 13 ++ assets/logging.example.json | 47 ----- op.conf | 7 + requirements.txt | 4 +- sample/README.txt | 2 + .../apps/test_app}/flows/flow.yaml | 0 .../apps/test_app/metadata.yaml | 0 sample/calls/__init__.py | 8 + sample/calls/test_call/__init__.py | 14 ++ sample/calls/test_call/sub_lib/__init__.py | 4 + sample/calls/test_call/sub_lib/add.py | 9 + .../calls/test_call}/user_tool.py | 55 +++--- sample/services/test_service/metadata.yaml | 18 ++ .../services/test_service/openapi/api.yaml | 0 sdk/example_plugin/lib/__init__.py | 8 - sdk/example_plugin/plugin.json | 11 -- 80 files changed, 1160 insertions(+), 590 deletions(-) rename apps/entities/{enum.py => enum_var.py} (83%) create mode 100644 apps/entities/flow.py create mode 100644 apps/entities/vector.py delete mode 100644 apps/gunicorn.conf.py create mode 100644 apps/manager/gitee_white_list.py create mode 100644 apps/models/minio.py create mode 100644 apps/scheduler/call/next_flow.py delete mode 100644 
apps/scheduler/call/render/format.py create mode 100644 apps/scheduler/embedding.py create mode 100644 apps/scheduler/openapi.py create mode 100644 apps/scheduler/pool/loader/__init__.py rename apps/scheduler/pool/{ => loader}/btdl.py (96%) create mode 100644 apps/scheduler/pool/loader/metadata.py create mode 100644 apps/scheduler/pool/loader/openapi.py create mode 100644 apps/scheduler/slot/parser/const.py create mode 100644 apps/scheduler/slot/parser/default.py delete mode 100644 apps/scheduler/vector.py delete mode 100644 assets/logging.example.json create mode 100644 op.conf create mode 100644 sample/README.txt rename {sdk/example_plugin => sample/apps/test_app}/flows/flow.yaml (100%) rename sdk/example_plugin/lib/sub_lib/__init__.py => sample/apps/test_app/metadata.yaml (100%) create mode 100644 sample/calls/__init__.py create mode 100644 sample/calls/test_call/__init__.py create mode 100644 sample/calls/test_call/sub_lib/__init__.py create mode 100644 sample/calls/test_call/sub_lib/add.py rename {sdk/example_plugin/lib => sample/calls/test_call}/user_tool.py (59%) create mode 100644 sample/services/test_service/metadata.yaml rename sdk/example_plugin/openapi.yaml => sample/services/test_service/openapi/api.yaml (100%) delete mode 100644 sdk/example_plugin/lib/__init__.py delete mode 100644 sdk/example_plugin/plugin.json diff --git a/Dockerfile b/Dockerfile index 3f363e93..9cd9f102 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM hub.oepkgs.net/neocopilot/framework-baseimg:0.9.1 +FROM hub.oepkgs.net/neocopilot/framework-baseimg:dev USER root RUN sed -i 's/umask 002/umask 027/g' /etc/bashrc && \ @@ -11,4 +11,4 @@ COPY --chown=1001:1001 --chmod=550 ./ /euler-copilot-frame/ WORKDIR /euler-copilot-frame ENV PYTHONPATH /euler-copilot-frame -CMD bash -c "python3 -m gunicorn -c apps/gunicorn.conf.py apps.main:app" +CMD bash -c "python3 apps/main.py" diff --git a/Jenkinsfile b/Jenkinsfile index 8fb5afa1..ef3086b4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,41 +1,41 @@ node { + properties([ + parameters([ + string(name: "REPO", defaultValue: "framework-dev", description: "当前项目名") + ]) + ]) + echo "拉取代码仓库" checkout scm - def REPO = scm.getUserRemoteConfigs()[0].getUrl().tokenize('/').last().split("\\.")[0] - def BRANCH = scm.branches[0].name.split("/")[1] def BUILD = sh(script: 'git rev-parse --short HEAD', returnStdout: true).trim() + def reg = "" - withCredentials([string(credentialsId: "host", variable: "HOST")]) { - echo "构建当前分支Docker Image镜像" - sh "sed -i 's|framework_base|${HOST}:30000/framework-baseimg|g' Dockerfile" - docker.withRegistry("http://${HOST}:30000", "dockerAuth") { - def image = docker.build("${HOST}:30000/${REPO}:${BUILD}", "-f Dockerfile .") - image.push() - image.push("${BRANCH}") - } + withCredentials([string(credentialsId: "reg_host", variable: "REG_HOST")]) { + reg = "${REG_HOST}" + } + echo "构建当前分支Docker Image镜像" + docker.withRegistry("https://${reg}", "dockerAuth") { + def image = docker.build("${reg}/${params.REPO}:${BUILD}", "-f ./Dockerfile .") + image.push() + } - def remote = [:] - remote.name = "machine" + def remote = [:] + remote.name = "machine" + withCredentials([string(credentialsId: "ssh_host", variable: "HOST")]) { remote.host = "${HOST}" - withCredentials([usernamePassword(credentialsId: "ssh", usernameVariable: 'sshUser', passwordVariable: 'sshPass')]) { - remote.user = sshUser - remote.password = sshPass - } - remote.allowAnyHosts = true + } + withCredentials([usernamePassword(credentialsId: "ssh", usernameVariable: 'sshUser', 
passwordVariable: 'sshPass')]) { + remote.user = sshUser + remote.password = sshPass + } + remote.allowAnyHosts = true - echo "清除构建缓存" - sshCommand remote: remote, command: "sh -c \"docker rmi ${HOST}:30000/${REPO}:${BUILD} || true\";" - sshCommand remote: remote, command: "sh -c \"docker rmi ${REPO}:${BUILD} || true\";" - sshCommand remote: remote, command: "sh -c \"docker rmi ${REPO}:${BRANCH} || true\";" - sshCommand remote: remote, command: "sh -c \"docker image prune -f || true\";"; - sshCommand remote: remote, command: "sh -c \"docker builder prune -f || true\";"; - sshCommand remote: remote, command: "sh -c \"k3s crictl rmi --prune || true\";"; + echo "清除构建缓存" + sshCommand remote: remote, command: "sh -c \"docker rmi ${reg}/${params.REPO}:${BUILD} || true\";" + sshCommand remote: remote, command: "sh -c \"docker image prune -f || true\";"; + sshCommand remote: remote, command: "sh -c \"docker builder prune -f || true\";"; - echo "重新部署" - withCredentials([usernamePassword(credentialsId: "dockerAuth", usernameVariable: 'dockerUser', passwordVariable: 'dockerPass')]) { - sshCommand remote: remote, command: "sh -c \"cd /home/registry/registry-cli; python3 ./registry.py -l ${dockerUser}:${dockerPass} -r http://${HOST}:30000 --delete --keep-tags 'master' '0001' '330-feature' '430-feature' || true\";" - } - sshCommand remote: remote, command: "sh -c \"kubectl -n euler-copilot set image deployment/framework-deploy framework=${HOST}:30000/${REPO}:${BUILD}\";" - } -} + echo "重新部署" + sshCommand remote: remote, command: "sh -c \"kubectl -n euler-copilot set image deployment/framework-deploy framework=${reg}/${params.REPO}:${BUILD}\";" +} \ No newline at end of file diff --git a/apps/common/config.py b/apps/common/config.py index fcbeded2..0f44e8aa 100644 --- a/apps/common/config.py +++ b/apps/common/config.py @@ -4,6 +4,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import os import secrets +from pathlib import Path from typing import Optional from dotenv import dotenv_values @@ -49,6 +50,11 @@ class ConfigModel(BaseModel): DETECT_TYPE: Optional[str] = Field(description="敏感词检测系统类型", default=None) WORDS_CHECK: Optional[str] = Field(description="AutoGPT敏感词检测系统API URL", default=None) WORDS_LIST: Optional[str] = Field(description="敏感词列表文件路径", default=None) + SCAS_APP_ID: Optional[str] = Field(description="SCAS敏感词检测系统 APP ID", default=None) + SCAS_SIGN_KEY: Optional[str] = Field(description="SCAS敏感词检测系统 请求签名密钥", default=None) + SCAS_BUSINESS_ID: Optional[str] = Field(description="SCAS敏感词检测系统 业务ID", default=None) + SCAS_SCENE_ID: Optional[str] = Field(description="SCAS敏感词检测系统 场景ID", default=None) + SCAS_URL: Optional[str] = Field(description="SCAS实例域名", default=None) # CSRF ENABLE_CSRF: bool = Field(description="是否启用CSRF Token功能", default=True) # MongoDB @@ -71,18 +77,10 @@ class ConfigModel(BaseModel): HALF_KEY1: str = Field(description="Half key 1") HALF_KEY2: str = Field(description="Half key 2") HALF_KEY3: str = Field(description="Half key 3") - # 模型类型 - MODEL: str = Field(description="选择的模型类型", default="openai") - # OpenAI API + # OpenAI大模型 LLM_KEY: Optional[str] = Field(description="OpenAI API 密钥", default=None) LLM_URL: Optional[str] = Field(description="OpenAI API URL地址", default=None) LLM_MODEL: Optional[str] = Field(description="OpenAI API 模型名", default=None) - # 星火大模型 - SPARK_APP_ID: Optional[str] = Field(description="星火大模型API 应用名", default=None) - SPARK_API_KEY: Optional[str] = Field(description="星火大模型API 密钥名", default=None) - SPARK_API_SECRET: Optional[str] = Field(description="星火大模型API 密钥值", default=None) - SPARK_API_URL: Optional[str] = Field(description="星火大模型API URL地址", default=None) - SPARK_LLM_DOMAIN: Optional[str] = Field(description="星火大模型API 领域名", default=None) # 参数猜解 SCHEDULER_BACKEND: Optional[str] = Field(description="参数猜解后端", default=None) SCHEDULER_MODEL: Optional[str] = Field(description="参数猜解模型名", default=None) @@ -94,6 +92,8 @@ class ConfigModel(BaseModel): PLUGIN_DIR: Optional[str] = Field(description="插件路径", default=None) # SQL接口路径 SQL_URL: str = Field(description="Chat2DB接口路径") + # Gitee白名单路径 + GITEE_WHITELIST: Optional[str] = Field(description="Gitee白名单路径") class Config: @@ -109,7 +109,7 @@ class Config: self._config = ConfigModel.model_validate(dotenv_values(config_file)) if os.getenv("PROD"): - os.remove(config_file) + Path(config_file).unlink() def __getitem__(self, key: str): # noqa: ANN204 """获得配置文件中特定条目的值 diff --git a/apps/common/oidc.py b/apps/common/oidc.py index 9bfd28aa..6b82636a 100644 --- a/apps/common/oidc.py +++ b/apps/common/oidc.py @@ -5,10 +5,11 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. from typing import Any import aiohttp -from fastapi import status +from fastapi import HTTPException, status from apps.common.config import config from apps.constants import LOGGER +from apps.manager.gitee_white_list import GiteeIDManager from apps.models.redis import RedisConnectionPool @@ -51,6 +52,8 @@ async def get_oidc_user(access_token: str, refresh_token: str) -> dict: """获取OIDC用户""" if config["DEPLOY_MODE"] == "local": return await get_local_oidc_user(access_token, refresh_token) + if config["DEPLOY_MODE"] == "gitee": + return await get_gitee_oidc_user(access_token, refresh_token) if not access_token: err = "Access token is empty." 
@@ -132,3 +135,31 @@ async def get_local_oidc_user(access_token: str, refresh_token: str) -> dict: "user_sub": user_sub, } + +async def get_gitee_oidc_user(access_token: str, refresh_token: str) -> dict: + """获取Gitee用户信息""" + if not access_token: + err = "Access token is empty." + raise ValueError(err) + + url = f"{config['OIDC_USER_URL']}?access_token={access_token}" + result = None + async with aiohttp.ClientSession() as session, session.get(url, timeout=10) as resp: + if resp.status != status.HTTP_200_OK: + err = f"Get OIDC user error: {resp.status}, full response is: {await resp.text()}" + raise RuntimeError(err) + LOGGER.info(f"full response is {await resp.text()}") + result = await resp.json() + + user_sub = result["login"] + if not GiteeIDManager().check_user_exist_or_not(user_sub): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="auth error", + ) + await set_redis_token(user_sub, access_token, refresh_token) + + return { + "user_sub": user_sub, + } + diff --git a/apps/common/queue.py b/apps/common/queue.py index 569b2ef8..3b28be4d 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -8,7 +8,7 @@ from typing import Any from redis.exceptions import ResponseError from apps.constants import LOGGER -from apps.entities.enum import EventType, StepStatus +from apps.entities.enum_var import EventType, StepStatus from apps.entities.message import ( HeartbeatData, MessageBase, diff --git a/apps/common/wordscheck.py b/apps/common/wordscheck.py index 3c5d49ec..c6ae8eb7 100644 --- a/apps/common/wordscheck.py +++ b/apps/common/wordscheck.py @@ -2,11 +2,19 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import base64 +import hashlib +import hmac import http +import json import re -from typing import Union +import uuid +from datetime import datetime +from typing import Any, Optional, Union +import pytz import requests +from fastapi import status from apps.common.config import config from apps.constants import LOGGER @@ -35,6 +43,154 @@ class APICheck: return -1 +class SCAS: + """使用SCAS接口检查敏感词""" + + app_id: Optional[str] + sign_key: Optional[str] + business_id: Optional[str] + scene_id: Optional[str] + url: str + retry: int = 2 + timeout: int = 3 + count: int = 0 + enable: bool = True + THRESHOLD: int = 100 + + def __init__(self) -> None: + """初始化SCAS""" + self.app_id = config["SCAS_APP_ID"] + if self.app_id is None: + err = "配置文件中未设置SCAS_APP_ID" + raise ValueError(err) + + self.sign_key = config["SCAS_SIGN_KEY"] + if self.sign_key is None: + err = "配置文件中未设置SCAS_SIGN_KEY" + raise ValueError(err) + + self.business_id = config["SCAS_BUSINESS_ID"] + if self.business_id is None: + err = "配置文件中未设置SCAS_BUSINESS_ID" + raise ValueError(err) + + self.scene_id = config["SCAS_SCENE_ID"] + if self.scene_id is None: + err = "配置文件中未设置SCAS_SCENE_ID" + raise ValueError(err) + + self.url = config["SCAS_URL"] + "/scas/v1/textIdentify" + if self.url is None: + err = "配置文件中未设置SCAS_URL" + raise ValueError(err) + + def _make_auth_header(self, request: str, time: datetime) -> str: + """生成认证头""" + auth_param = 'CLOUDSOA-HMAC-SHA256 appid={}, timestamp={}, signature="{}"' + sign_format = "{}&{}&{}&{}&appid={}×tamp={}" + + current_timestamp = int(time.timestamp() * 1000) + + sign_str = sign_format.format("POST", "/scas/v1/textIdentify", "", request, + self.app_id, current_timestamp) + if self.sign_key is None: + err = "配置文件中未设置SCAS_SIGN_KEY" + raise ValueError(err) + + sign_value = base64.b64encode( + hmac.new( + bytes.fromhex(self.sign_key), + 
bytes(sign_str, "utf-8"), + hashlib.sha256, + ).digest()).decode("utf-8") + + return auth_param.format(self.app_id, current_timestamp, sign_value) + + def _make_request_body(self, message: str) -> tuple[dict[str, Any], str]: + """生成SCAS请求体""" + if not message: + return {}, "" + + task_id = str(uuid.uuid4()) + + current_time = datetime.now(pytz.timezone("Asia/Shanghai")) + timestamp = current_time.isoformat(sep=" ", timespec="milliseconds")[:-6] + timestamp += current_time.strftime("%z") + + post_data = { + "taskID": task_id, + "message": { + "text": message, + }, + "businessID": self.business_id, + "sceneID": self.scene_id, + "uid": "-1", + "reqTime": timestamp, + "returnCleanText": 0, + "loginType": "WEB", + } + header = self._make_auth_header(json.dumps(post_data), current_time) + + return post_data, header + + def _post_with_retry(self, post_data: dict[str, Any], header: str) -> Optional[requests.Response]: + for _ in range(self.retry): + try: + return requests.post(self.url, json=post_data, headers={ + "Content-Type": "application/json", + "Authorization": header, + }, timeout=self.timeout) + except Exception as e: # noqa: PERF203 + LOGGER.error(f"检查敏感词错误:{e!s}") + continue + return None + + def _check_message(self, message: str) -> int: # noqa: PLR0911 + """使用SCAS检查敏感词""" + # -1: 异常,0: 不通过,1: 通过 + post_data, header = self._make_request_body(message) + if not post_data and not header: + LOGGER.info("待审核信息错误") + return -1 + + req = self._post_with_retry(post_data, header) + if req is None: + LOGGER.info("风控接口调用参数错误") + return -1 + + if req.status_code != status.HTTP_200_OK: + LOGGER.info(f"风控HTTP错误:{req.status_code}") + return -1 + + return_data = req.json() + if "resultCode" not in return_data or "securityResult" not in return_data: + LOGGER.info("风控接口返回错误") + return -1 + + if return_data["resultCode"]: + LOGGER.info("风控处理错误:{}".format(return_data["resultCode"])) + return -1 + + if return_data["securityResult"] == "ACCEPT": + return 1 + return 0 + + def check(self, message: str) -> int: + """使用SCAS检查消息""" + ret = self._check_message(message) + if ret == -1: + if not self.enable: + # 放通 + return 1 + self.count += 1 + if self.count >= self.THRESHOLD: + self.enable = False + else: + self.enable = True + self.count = 0 + return ret + + class KeywordCheck: """使用关键词列表检查敏感词""" @@ -55,12 +211,14 @@ class KeywordCheck: class WordsCheck: """敏感词检查工具""" - tool: Union[APICheck, KeywordCheck, None] = None + tool: Union[APICheck, KeywordCheck, SCAS, None] = None @classmethod def init(cls) -> None: """初始化敏感词检查器""" - if config["DETECT_TYPE"] == "keyword": + if config["DETECT_TYPE"] == "scas": + cls.tool = SCAS() + elif config["DETECT_TYPE"] == "keyword": cls.tool = KeywordCheck() elif config["DETECT_TYPE"] == "wordscheck": cls.tool = APICheck() diff --git a/apps/constants.py b/apps/constants.py index 87ed23bc..270fd881 100644 --- a/apps/constants.py +++ b/apps/constants.py @@ -6,11 +6,17 @@ from __future__ import annotations import logging -CURRENT_REVISION_VERSION = "0.0.0" +# 新对话默认标题 NEW_CHAT = "New Chat" +# 滑动窗口限流 默认窗口期 SLIDE_WINDOW_TIME = 60 +# 滑动窗口限流 最大请求数 SLIDE_WINDOW_QUESTION_COUNT = 10 +# API Call 最大返回值长度(字符) MAX_API_RESPONSE_LENGTH = 4096 +# Scheduler最大历史轮次 MAX_SCHEDULER_HISTORY_SIZE = 3 +# 语义接口目录中工具子目录 +CALL_DIR = "call" -LOGGER = logging.getLogger("gunicorn.error") +LOGGER = logging.getLogger("ray") diff --git a/apps/dependency/session.py b/apps/dependency/session.py index cc10d4e4..05b1bf9a 100644 --- a/apps/dependency/session.py +++ b/apps/dependency/session.py @@ -24,6 +24,7 @@ class 
VerifySessionMiddleware(BaseHTTPMiddleware): if request.url.path in BYPASS_LIST: return await call_next(request) + # TODO: 加入apikey校验 cookie = request.cookies.get("ECSESSION", "") if request.client is None or request.client.host is None: err = "无法检测请求来源IP!" diff --git a/apps/entities/enum.py b/apps/entities/enum_var.py similarity index 83% rename from apps/entities/enum.py rename to apps/entities/enum_var.py index 8f93213e..e8eac33b 100644 --- a/apps/entities/enum.py +++ b/apps/entities/enum_var.py @@ -55,3 +55,17 @@ class EventType(str, Enum): STEP_OUTPUT = "step.output" FLOW_STOP = "flow.stop" DONE = "done" + + +class CallType(str, Enum): + """Call类型""" + + SYSTEM = "system" + PYTHON = "python" + + +class MetadataType(str, Enum): + """元数据类型""" + + SERVICE = "service" + APP = "app" diff --git a/apps/entities/flow.py b/apps/entities/flow.py new file mode 100644 index 00000000..27c20b1c --- /dev/null +++ b/apps/entities/flow.py @@ -0,0 +1,101 @@ +"""Flow和Service等外置配置数据结构 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +import uuid +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from apps.entities.enum_var import CallType, MetadataType + + +class Step(BaseModel): + """Flow中Step的数据""" + + name: str + confirm: bool = False + call_type: str + params: dict[str, Any] = {} + next: Optional[str] = None + + +class NextFlow(BaseModel): + """Flow中“下一步”的数据格式""" + + id: str + plugin: Optional[str] = None + question: Optional[str] = None + + +class Flow(BaseModel): + """Flow(工作流)的数据格式""" + + on_error: Optional[Step] = Step( + name="error", + call_type="llm", + params={ + "user_prompt": "当前工具执行发生错误,原始错误信息为:{data}. 请向用户展示错误信息,并给出可能的解决方案。\n\n背景信息:{context}", + }, + ) + steps: dict[str, Step] + next_flow: Optional[list[NextFlow]] = None + + +class Service(BaseModel): + """外部服务信息 + + collection: service + """ + + id: str = Field(alias="_id") + name: str + description: str + dir_path: str + + +class StepPool(BaseModel): + """Step信息 + + collection: step_pool + """ + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + name: str + description: str + + +class FlowPool(BaseModel): + """Flow信息 + + collection: flow_pool + """ + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + name: str + description: str + data: Flow + + +class CallMetadata(BaseModel): + """Call工具信息 + + key: call_metadata + """ + + id: str = Field(alias="_id", description="Call的ID") + type: CallType = Field(description="Call的类型") + name: str = Field(description="Call的名称") + description: str = Field(description="Call的描述") + path: str = Field(description="Call的路径;当为系统Call时,形如 system::LLM;当为Python Call时,形如 python::tune::call.tune.CheckSystem") + + +class Metadata(BaseModel): + """Service或App的元数据""" + + type: MetadataType = Field(description="元数据类型") + id: str = Field(alias="_id", description="元数据ID") + name: str = Field(description="元数据名称") + description: str = Field(description="元数据描述") + version: str = Field(description="元数据版本") + diff --git a/apps/entities/message.py b/apps/entities/message.py index 9beb4c4a..94e626a8 100644 --- a/apps/entities/message.py +++ b/apps/entities/message.py @@ -7,7 +7,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field from apps.entities.collection import RecordMetadata -from apps.entities.enum import EventType, FlowOutputType, StepStatus +from apps.entities.enum_var import EventType, FlowOutputType, StepStatus class HeartbeatData(BaseModel): diff --git a/apps/entities/plugin.py 
b/apps/entities/plugin.py index 6c5a756b..5dddc6c0 100644 --- a/apps/entities/plugin.py +++ b/apps/entities/plugin.py @@ -10,36 +10,6 @@ from apps.common.queue import MessageQueue from apps.entities.task import FlowHistory, RequestDataPlugin -class Step(BaseModel): - """Flow中Step的数据""" - - name: str - confirm: bool = False - call_type: str - params: dict[str, Any] = {} - next: Optional[str] = None - -class NextFlow(BaseModel): - """Flow中“下一步”的数据格式""" - - id: str - plugin: Optional[str] = None - question: Optional[str] = None - -class Flow(BaseModel): - """Flow(工作流)的数据格式""" - - on_error: Optional[Step] = Step( - name="error", - call_type="llm", - params={ - "user_prompt": "当前工具执行发生错误,原始错误信息为:{data}. 请向用户展示错误信息,并给出可能的解决方案。\n\n背景信息:{context}", - }, - ) - steps: dict[str, Step] - next_flow: Optional[list[NextFlow]] = None - - class PluginData(BaseModel): """插件数据格式""" diff --git a/apps/entities/record.py b/apps/entities/record.py index 4506815c..e720bb9c 100644 --- a/apps/entities/record.py +++ b/apps/entities/record.py @@ -11,7 +11,7 @@ from apps.entities.collection import ( RecordContent, RecordMetadata, ) -from apps.entities.enum import StepStatus +from apps.entities.enum_var import StepStatus class RecordDocument(Document): diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index 7ec5b01a..b97cbf5b 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -85,8 +85,7 @@ class AddCommentData(BaseModel): is_like: bool = Field(...) dislike_reason: list[str] = Field(default=[], max_length=10) reason_link: str = Field(default=None, max_length=200) - reason_description: str = Field( - default=None, max_length=500) + reason_description: str = Field(default=None, max_length=500) class PostDomainData(BaseModel): diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py index cf881c66..2534def0 100644 --- a/apps/entities/response_data.py +++ b/apps/entities/response_data.py @@ -7,7 +7,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field from apps.entities.collection import Blacklist, Document -from apps.entities.enum import DocumentStatus +from apps.entities.enum_var import DocumentStatus from apps.entities.plugin import PluginData from apps.entities.record import RecordData diff --git a/apps/entities/task.py b/apps/entities/task.py index 52705392..a3a1ae18 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -8,7 +8,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field -from apps.entities.enum import StepStatus +from apps.entities.enum_var import StepStatus from apps.entities.record import RecordData diff --git a/apps/entities/vector.py b/apps/entities/vector.py new file mode 100644 index 00000000..029283d5 --- /dev/null +++ b/apps/entities/vector.py @@ -0,0 +1,33 @@ +"""向量数据库数据结构;数据将存储在PostgreSQL中 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
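
Editor's note: as a usage sketch for the Flow/Step/NextFlow models moved into apps/entities/flow.py above (not part of the patch), a flow definition loaded from YAML or JSON validates straight into the Flow model, with the default on_error step filled in automatically. The step names and parameters below are illustrative only.

from apps.entities.flow import Flow  # models added above

flow = Flow.model_validate({
    "steps": {
        "start": {
            "name": "start",
            "call_type": "api",
            "params": {"endpoint": "GET /demo"},
            "next": "summary",
        },
        "summary": {
            "name": "summary",
            "call_type": "llm",
            "params": {"user_prompt": "Summarize: {context}"},
        },
    },
    "next_flow": [{"id": "follow_up", "question": "Continue troubleshooting?"}],
})
assert flow.on_error is not None   # default error-handling step is present
print(flow.steps["start"].next)    # -> "summary"
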
+""" +from pgvector.sqlalchemy import Vector +from sqlalchemy import Column, String +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + + +class FlowVector(Base): + """Flow向量信息""" + + __tablename__ = "flow_vector" + id = Column(String(length=100), primary_key=True, nullable=False, unique=True) + embedding = Column(Vector(1024), nullable=False) + + +class ServiceVector(Base): + """Service向量信息""" + + __tablename__ = "service_vector" + id = Column(String(length=100), primary_key=True, nullable=False, unique=True) + embedding = Column(Vector(1024), nullable=False) + + +class StepPoolVector(Base): + """StepPool向量信息""" + + __tablename__ = "step_pool_vector" + id = Column(String(length=100), primary_key=True, nullable=False, unique=True) + embedding = Column(Vector(1024), nullable=False) diff --git a/apps/gunicorn.conf.py b/apps/gunicorn.conf.py deleted file mode 100644 index 63c2dd21..00000000 --- a/apps/gunicorn.conf.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Gunicorn配置文件 - -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -""" -from __future__ import annotations - -from apps.common.wordscheck import WordsCheck -from apps.scheduler.pool.loader import Loader - -preload_app = True -bind = "0.0.0.0:8002" -workers = 8 -timeout = 300 -accesslog = "-" -capture_output = True -worker_class = "uvicorn.workers.UvicornWorker" - -def on_starting(server): # noqa: ANN001, ANN201, ARG001 - """Gunicorn服务器启动时的初始化代码 - - :param server: 服务器配置项 - :return: - """ - WordsCheck.init() - Loader.init() - - -def post_fork(server, worker): # noqa: ANN001, ANN201 - """Gunicorn服务器每个Worker进程启动后的初始化代码 - - :param server: 服务器配置项 - :param worker: Worker配置项 - :return: - """ diff --git a/apps/llm/function.py b/apps/llm/function.py index a98d5973..039581fc 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -39,6 +39,9 @@ class FunctionLLM: if config["SCHEDULER_BACKEND"] == "ollama": self._client = ollama.AsyncClient( host=config["SCHEDULER_URL"], + headers={ + # "Authorization": f"Bearer {config['SCHEDULER_API_KEY']}", + }, ) @staticmethod diff --git a/apps/llm/patterns/rewoo.py b/apps/llm/patterns/rewoo.py index 39e78822..6057523f 100644 --- a/apps/llm/patterns/rewoo.py +++ b/apps/llm/patterns/rewoo.py @@ -98,3 +98,26 @@ class InitPlan(CorePattern): result += chunk return result + + +# class PlanEvaluator: +# system_prompt = """You are a plan evaluator. Your task is: for the given user objective and your original plan, \ +# +# +# """ +# user_prompt = """""" +# +# def __init__(self, system_prompt: Union[str, None] = None, user_prompt: Union[str, None] = None): +# if system_prompt is not None: +# self.system_prompt = system_prompt +# if user_prompt is not None: +# self.user_prompt = user_prompt +# +# @staticmethod +# @sglang.function +# def _plan(s): +# pass +# +# def generate(self, **kwargs) -> str: +# pass + diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py index a0446c5a..d31bb56b 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/reasoning.py @@ -5,10 +5,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
from collections.abc import AsyncGenerator import tiktoken -from langchain_core.messages import ChatMessage as LangchainChatMessage +from langchain_core.messages import ChatMessage from langchain_openai import ChatOpenAI -from sparkai.llm.llm import ChatSparkLLM -from sparkai.messages import ChatMessage as SparkChatMessage from apps.common.config import config from apps.common.singleton import Singleton @@ -22,48 +20,13 @@ class ReasoningLLM(metaclass=Singleton): def __init__(self) -> None: """判断配置文件里用了哪种大模型;初始化大模型客户端""" - if config["MODEL"] == "openai": - self._client = ChatOpenAI( - api_key=config["LLM_KEY"], - base_url=config["LLM_URL"], - model=config["LLM_MODEL"], - tiktoken_model_name="cl100k_base", - streaming=True, - ) - elif config["MODEL"] == "spark": - self._client = ChatSparkLLM( - spark_app_id=config["SPARK_APP_ID"], - spark_api_key=config["SPARK_API_KEY"], - spark_api_secret=config["SPARK_API_SECRET"], - spark_api_url=config["SPARK_API_URL"], - spark_llm_domain=config["SPARK_LLM_DOMAIN"], - request_timeout=600, - streaming=True, - ) - else: - err = "暂不支持此种大模型API" - raise NotImplementedError(err) - - - @staticmethod - def _construct_openai_message(messages: list[dict[str, str]]) -> list[LangchainChatMessage]: - """模型类型为OpenAI API时:构造消息列表 - - :param messages: 原始的消息,形如`{"role": "xxx", "content": "xxx"}` - :returns: 构造后的消息内容 - """ - return [LangchainChatMessage(content=msg["content"], role=msg["role"]) for msg in messages] - - - @staticmethod - def _construct_spark_message(messages: list[dict[str, str]]) -> list[SparkChatMessage]: - """当模型类型为星火(星火SDK时),构造消息 - - :param messages: 原始的消息,形如`{"role": "xxx", "content": "xxx"}` - :return: 构造后的消息内容 - """ - return [SparkChatMessage(content=msg["content"], role=msg["role"]) for msg in messages] - + self._client = ChatOpenAI( + api_key=config["LLM_KEY"], + base_url=config["LLM_URL"], + model=config["LLM_MODEL"], + tiktoken_model_name="cl100k_base", + streaming=True, + ) def _calculate_token_length(self, messages: list[dict[str, str]], *, pure_text: bool = False) -> int: """使用ChatGPT的cl100k tokenizer,估算Token消耗量""" @@ -76,7 +39,6 @@ class ReasoningLLM(metaclass=Singleton): return result - async def call(self, task_id: str, messages: list[dict[str, str]], max_tokens: int = 8192, temperature: float = 0.07, *, streaming: bool = True) -> AsyncGenerator[str, None]: """调用大模型,分为流式和非流式两种 @@ -88,14 +50,7 @@ class ReasoningLLM(metaclass=Singleton): :param temperature: 模型温度(随机化程度) """ input_tokens = self._calculate_token_length(messages) - - if config["MODEL"] == "openai": - msg_list = self._construct_openai_message(messages) - elif config["MODEL"] == "spark": - msg_list = self._construct_spark_message(messages) - else: - err = "暂不支持此种大模型API" - raise NotImplementedError(err) + msg_list = [ChatMessage(content=msg["content"], role=msg["role"]) for msg in messages] if streaming: result = "" diff --git a/apps/main.py b/apps/main.py index fe4bab93..4f2f8f37 100644 --- a/apps/main.py +++ b/apps/main.py @@ -4,11 +4,15 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
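
Editor's note: a small usage sketch for the simplified ReasoningLLM above, which now always targets an OpenAI-compatible endpoint. call() is an async generator, so callers accumulate the streamed chunks; the task_id value here is an illustrative placeholder.

import asyncio

from apps.llm.reasoning import ReasoningLLM  # simplified class shown above


async def summarize(question: str) -> str:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": question},
    ]
    answer = ""
    # Accumulate the streamed output chunk by chunk.
    async for chunk in ReasoningLLM().call(task_id="demo-task", messages=messages):
        answer += chunk
    return answer


print(asyncio.run(summarize("What does Ray Serve do?")))
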
""" from __future__ import annotations +import ray from apscheduler.schedulers.background import BackgroundScheduler from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from ray import serve +from ray.serve.config import HTTPOptions from apps.common.config import config +from apps.common.wordscheck import WordsCheck from apps.cron.delete_user import DeleteUserCron from apps.dependency.session import VerifySessionMiddleware from apps.routers import ( @@ -25,6 +29,7 @@ from apps.routers import ( plugin, record, ) +from apps.scheduler.pool.loader import Loader # 定义FastAPI app app = FastAPI(docs_url=None, redoc_url=None) @@ -54,3 +59,20 @@ app.include_router(knowledge.router) scheduler = BackgroundScheduler() scheduler.start() scheduler.add_job(DeleteUserCron.delete_user, "cron", hour=3) + +# 包装Ray +@serve.deployment(ray_actor_options={"num_gpus": 0}) +@serve.ingress(app) +class FastAPIWrapper: + """FastAPI Ray包装器""" + + +# 运行 +if __name__ == "__main__": + # 初始化 + WordsCheck.init() + Loader.init() + # 启动Ray + ray.init(dashboard_host="0.0.0.0", num_cpus=4) # noqa: S104 + serve.start(http_options=HTTPOptions(host="0.0.0.0", port=8002)) # noqa: S104 + serve.run(FastAPIWrapper.bind(), blocking=True) diff --git a/apps/manager/document.py b/apps/manager/document.py index 0d457752..985b0ef2 100644 --- a/apps/manager/document.py +++ b/apps/manager/document.py @@ -8,10 +8,8 @@ from typing import Optional import asyncer import magic -import minio from fastapi import UploadFile -from apps.common.config import config from apps.constants import LOGGER from apps.entities.collection import ( Conversation, @@ -20,6 +18,7 @@ from apps.entities.collection import ( RecordGroupDocument, ) from apps.entities.record import RecordDocument +from apps.models.minio import MinioClient from apps.models.mongo import MongoDB from apps.service import KnowledgeBaseService @@ -27,18 +26,10 @@ from apps.service import KnowledgeBaseService class DocumentManager: """文件相关操作""" - client = minio.Minio( - endpoint=config["MINIO_ENDPOINT"], - access_key=config["MINIO_ACCESS_KEY"], - secret_key=config["MINIO_SECRET_KEY"], - secure=config["MINIO_SECURE"], - ) - @classmethod def _storage_single_doc_minio(cls, file_id: str, document: UploadFile) -> str: """存储单个文件到MinIO""" - if not cls.client.bucket_exists("document"): - cls.client.make_bucket("document") + MinioClient.check_bucket("document") # 获取文件MIME file = document.file @@ -46,7 +37,7 @@ class DocumentManager: file.seek(0) # 上传到MinIO - cls.client.put_object( + MinioClient.upload_file( bucket_name="document", object_name=file_id, data=file, @@ -166,7 +157,7 @@ class DocumentManager: @classmethod def _remove_doc_from_minio(cls, doc_id: str) -> None: """从MinIO中删除文件""" - cls.client.remove_object("document", doc_id) + MinioClient.delete_file("document", doc_id) @classmethod async def delete_document(cls, user_sub: str, document_list: list[str]) -> bool: diff --git a/apps/manager/gitee_white_list.py b/apps/manager/gitee_white_list.py new file mode 100644 index 00000000..43ee2361 --- /dev/null +++ b/apps/manager/gitee_white_list.py @@ -0,0 +1,45 @@ +"""Gitee ID 白名单 Manager + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+""" +import os +from pathlib import Path +from typing import ClassVar + +from apps.common.config import config +from apps.common.singleton import Singleton +from apps.constants import LOGGER + + +class GiteeIDManager(metaclass=Singleton): + """Gitee ID 白名单 Manager""" + + whitelist: ClassVar[list[str]] = [] + + def __init__(self) -> None: + """读取白名单文件""" + config_path = os.getenv("CONFIG") + if not config_path: + err = "CONFIG is not set." + raise ValueError(err) + + if not config["GITEE_WHITELIST"]: + LOGGER.warning("未设置GITEE白名单路径,不做处理。") + return + + path = Path(config_path, config["GITEE_WHITELIST"]) + with open(path, encoding="utf-8") as f: + for line in f: + line_strip = line.strip() + if not line_strip or line_strip.startswith("#"): + continue + GiteeIDManager.whitelist.append(line_strip) + + @staticmethod + def check_user_exist_or_not(gitee_id: str) -> bool: + """检查用户是否在白名单中 + + :param gitee_id: Gitee ID + :return: 是否在白名单中 + """ + return gitee_id in GiteeIDManager.whitelist diff --git a/apps/manager/session.py b/apps/manager/session.py index e339a00b..4098203f 100644 --- a/apps/manager/session.py +++ b/apps/manager/session.py @@ -68,14 +68,23 @@ class SessionManager: if not session_id: return await SessionManager.create_session(session_ip) + ip = None async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: try: pipe.hget(session_id, "ip") pipe.expire(session_id, config["SESSION_TTL"] * 60) - await pipe.execute() + result = await pipe.execute() + ip = result[0].decode() except Exception as e: LOGGER.error(f"Read session error: {e}") + # if not ip: + # session_id = SessionManager.create_session(session_ip) + # return session_id + # elif ip != session_ip: + # session_id = SessionManager.create_session(session_ip) + # return session_id + # else: return session_id @staticmethod diff --git a/apps/manager/task.py b/apps/manager/task.py index 4a85d450..eb9ee384 100644 --- a/apps/manager/task.py +++ b/apps/manager/task.py @@ -12,7 +12,7 @@ from apps.constants import LOGGER from apps.entities.collection import ( RecordGroup, ) -from apps.entities.enum import StepStatus +from apps.entities.enum_var import StepStatus from apps.entities.record import ( RecordContent, RecordData, diff --git a/apps/models/minio.py b/apps/models/minio.py new file mode 100644 index 00000000..39994f8c --- /dev/null +++ b/apps/models/minio.py @@ -0,0 +1,49 @@ +"""MinIO客户端""" + +from typing import Any + +import minio + +from apps.common.config import config + + +class MinioClient: + """MinIO客户端""" + + client = minio.Minio( + endpoint=config["MINIO_ENDPOINT"], + access_key=config["MINIO_ACCESS_KEY"], + secret_key=config["MINIO_SECRET_KEY"], + secure=config["MINIO_SECURE"], + ) + + @classmethod + def check_bucket(cls, bucket_name: str) -> None: + """检查Bucket是否存在""" + if not cls.client.bucket_exists(bucket_name): + cls.client.make_bucket(bucket_name) + + @classmethod + def upload_file(cls, **kwargs: Any) -> None: # noqa: ANN401 + """上传文件""" + cls.client.put_object(**kwargs) + + @classmethod + def download_file(cls, bucket_name: str, file_path: str) -> tuple[dict[str, Any], bytes]: + """下载文件""" + try: + obj_stat = cls.client.stat_object(bucket_name, file_path) + metadata = obj_stat.metadata if isinstance(obj_stat.metadata, dict) else {} + response = cls.client.get_object(bucket_name, file_path) + doc = response.read() + + return metadata, doc + finally: + if response: + response.close() + response.release_conn() + + @classmethod + def delete_file(cls, bucket_name: str, file_name: str) -> 
None: + """删除文件""" + cls.client.remove_object(bucket_name, file_name) diff --git a/apps/models/mongo.py b/apps/models/mongo.py index 53ed2183..7cb25419 100644 --- a/apps/models/mongo.py +++ b/apps/models/mongo.py @@ -33,6 +33,14 @@ class MongoDB: LOGGER.error(f"Get collection {collection_name} failed: {e}") raise RuntimeError(str(e)) from e + @classmethod + async def clear_collection(cls, collection_name: str) -> None: + """清空MongoDB集合(表)""" + try: + await cls._client[config["MONGODB_DATABASE"]][collection_name].delete_many({}) + except Exception as e: + LOGGER.error(f"Clear collection {collection_name} failed: {e}") + @classmethod def get_session(cls) -> AsyncClientSession: """获取MongoDB会话""" diff --git a/apps/routers/document.py b/apps/routers/document.py index 91077d3b..cc2457c6 100644 --- a/apps/routers/document.py +++ b/apps/routers/document.py @@ -8,7 +8,7 @@ from fastapi import APIRouter, Depends, File, Query, UploadFile, status from fastapi.responses import JSONResponse from apps.dependency import get_user, verify_csrf_token, verify_user -from apps.entities.enum import DocumentStatus +from apps.entities.enum_var import DocumentStatus from apps.entities.response_data import ( ConversationDocumentItem, ConversationDocumentMsg, diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index 50d93056..a102cb20 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -3,7 +3,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import json -from typing import Any, ClassVar, Optional +from typing import Any, Optional import aiohttp from fastapi import status @@ -20,6 +20,7 @@ from apps.scheduler.pool.pool import Pool class _APIParams(BaseModel): endpoint: str = Field(description="API接口HTTP Method 与 URI") timeout: int = Field(description="工具超时时间", default=300) + fixed_data: dict[str, Any] = Field(description="固定数据", default={}) class API(CoreCall): @@ -27,24 +28,20 @@ class API(CoreCall): name: str = "api" description: str = "根据给定的用户输入和历史记录信息,向某一个API接口发送请求、获取数据。" - params_schema: ClassVar[dict[str, Any]] = _APIParams.model_json_schema() + params: type[_APIParams] = _APIParams - def __init__(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 + async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 """初始化API调用工具""" - # 固定参数 - self._core_params = syscall_vars - self._params = _APIParams.model_validate(kwargs) - # 初始化Slot Schema - self.slot_schema = {} + await super().init(syscall_vars, **kwargs) # 额外参数 - if "plugin_id" not in self._core_params.extra: + if "plugin_id" not in self._syscall_vars.extra: err = "[API] plugin_id not in extra_data" raise ValueError(err) - plugin_name: str = self._core_params.extra["plugin_id"] + plugin_name: str = self._syscall_vars.extra["plugin_id"] - method, _ = self._params.endpoint.split(" ") + method, _ = self.params.endpoint.split(" ") plugin_data = Pool().get_plugin(plugin_name) if plugin_data is None: err = f"[API] 插件{plugin_name}不存在!" @@ -59,7 +56,7 @@ class API(CoreCall): # 从spec中找出该接口对应的spec for item in full_spec.endpoints: name, _, _ = item - if name == self._params.endpoint: + if name == self.params.endpoint: self._spec = item if not hasattr(self, "_spec"): err = "[API] Endpoint not found." 
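
Editor's note: a usage sketch for the MinioClient wrapper introduced earlier in this patch (apps/models/minio.py). The object name and payload are hypothetical; DocumentManager passes the caller's UploadFile stream instead.

from io import BytesIO

from apps.models.minio import MinioClient  # wrapper added above

payload = BytesIO(b"hello world")
MinioClient.check_bucket("document")
MinioClient.upload_file(
    bucket_name="document",
    object_name="demo-id",
    data=payload,
    length=payload.getbuffer().nbytes,
    content_type="text/plain",
)
metadata, blob = MinioClient.download_file("document", "demo-id")
MinioClient.delete_file("document", "demo-id")
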
@@ -79,7 +76,7 @@ class API(CoreCall): async def call(self, slot_data: dict[str, Any]) -> CallResult: """调用API,然后返回LLM解析后的数据""" - method, url = self._params.endpoint.split(" ") + method, url = self.params.endpoint.split(" ") self._session = aiohttp.ClientSession() try: result = await self._call_api(method, url, slot_data) @@ -114,7 +111,7 @@ class API(CoreCall): elif self._auth["type"] == "oidc": token = await TokenManager.get_plugin_token( self._auth["domain"], - self._core_params.session_id, + self._syscall_vars.session_id, self._auth["access_token_url"], int(self._auth["token_expire_time"]), ) @@ -123,16 +120,16 @@ class API(CoreCall): if method == "GET": params.update(data) return self._session.get(self._server + url, params=params, headers=header, cookies=cookie, - timeout=self._params.timeout) + timeout=self.params.timeout) if method == "POST": if self._data_type == "form": form_data = files for key, val in data.items(): form_data.add_field(key, val) return self._session.post(self._server + url, data=form_data, headers=header, cookies=cookie, - timeout=self._params.timeout) + timeout=self.params.timeout) return self._session.post(self._server + url, json=data, headers=header, cookies=cookie, - timeout=self._params.timeout) + timeout=self.params.timeout) err = "Method not implemented." raise NotImplementedError(err) @@ -149,6 +146,27 @@ class API(CoreCall): err = "Data type not implemented." raise NotImplementedError(err) + # def _file_to_lists(self, spec: dict[str, Any]) -> aiohttp.FormData: + # file_form = aiohttp.FormData() + + # if self._params.files is None: + # return file_form + + # file_names = [] + # for file in self._params.files: + # file_names.append(Files.get_by_id(file)["name"]) + + # file_spec = check_upload_file(spec, file_names) + # selected_file = choose_file(file_names, file_spec, self.params_obj.question, self.params_obj.background, self.usage) + + # for key, val in json.loads(selected_file).items(): + # if isinstance(val, str): + # file_form.add_field(key, open(Files.get_by_name(val)["path"], "rb"), filename=val) + # else: + # for item in val: + # file_form.add_field(key, open(Files.get_by_name(item)["path"], "rb"), filename=item) + # return file_form + async def _call_api(self, method: str, url: str, slot_data: Optional[dict[str, Any]] = None) -> CallResult: LOGGER.info(f"调用接口{url},请求数据为{slot_data}") session_context = await self._make_api_call(method, url, slot_data, aiohttp.FormData()) diff --git a/apps/scheduler/call/choice.py b/apps/scheduler/call/choice.py index cb64b27c..a5cb6bee 100644 --- a/apps/scheduler/call/choice.py +++ b/apps/scheduler/call/choice.py @@ -2,11 +2,11 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" -from typing import Any, ClassVar +from typing import Any from pydantic import BaseModel, Field -from apps.entities.plugin import CallError, CallResult, SysCallVars +from apps.entities.plugin import CallError, CallResult from apps.llm.patterns.select import Select from apps.scheduler.call.core import CoreCall @@ -21,20 +21,9 @@ class _ChoiceParams(BaseModel): class Choice(CoreCall): """Choice工具。用于大模型在多个选项中选择一个,并跳转到对应的Step。""" - def __init__(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 - """初始化Choice工具,解析参数。 - - :param params: Choice工具所需的参数 - """ - self._core_params = syscall_vars - self._params = _ChoiceParams.model_validate(kwargs) - # 初始化Slot Schema - self.slot_schema = {} - - name: str = "choice" description: str = "选择工具,用于根据给定的上下文和问题,判断正确/错误,或从选项列表中选择最符合用户要求的一项。" - params_schema: ClassVar[dict[str, Any]] = _ChoiceParams.model_json_schema() + params: type[_ChoiceParams] = _ChoiceParams async def call(self, _slot_data: dict[str, Any]) -> CallResult: @@ -44,16 +33,16 @@ class Choice(CoreCall): :return: Choice工具的输出信息。包含下一个Step的名称、自然语言解释等。 """ previous_data = {} - if len(self._core_params.history) > 0: - previous_data = CallResult(**self._core_params.history[-1].output_data).output + if len(self._syscall_vars.history) > 0: + previous_data = CallResult(**self._syscall_vars.history[-1].output_data).output try: result = await Select().generate( - question=self._params.propose, - background=self._core_params.background, + question=self.params.propose, + background=self._syscall_vars.background, data=previous_data, - choices=self._params.choices, - task_id=self._core_params.task_id, + choices=self.params.choices, + task_id=self._syscall_vars.task_id, ) except Exception as e: raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e @@ -64,5 +53,5 @@ class Choice(CoreCall): extra={ "next_step": result, }, - message=f"针对“{self._params.propose}”,作出的选择为:{result}。", + message=f"针对“{self.params.propose}”,作出的选择为:{result}。", ) diff --git a/apps/scheduler/call/cmd/cmd.py b/apps/scheduler/call/cmd/cmd.py index a6d22655..9d7e15e9 100644 --- a/apps/scheduler/call/cmd/cmd.py +++ b/apps/scheduler/call/cmd/cmd.py @@ -2,11 +2,11 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from typing import Any, ClassVar, Optional +from typing import Any, Optional from pydantic import BaseModel, Field -from apps.entities.plugin import CallResult, SysCallVars +from apps.entities.plugin import CallResult from apps.scheduler.call.core import CoreCall @@ -23,14 +23,7 @@ class Cmd(CoreCall): name: str = "cmd" description: str = "根据BTDL描述文件,生成命令。" - params_schema: ClassVar[dict[str, Any]] = {} - - - def __init__(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003, ARG002 - """初始化Cmd工具""" - self._syscall_vars = syscall_vars - # 初始化Slot Schema - self.slot_schema = {} + params: type[_CmdParams] = _CmdParams async def call(self, _slot_data: dict[str, Any]) -> CallResult: """调用Cmd工具""" diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 048826fc..1616dfe4 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -3,8 +3,7 @@ 所有Call类必须继承此类,并实现所有方法。 Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" -from abc import ABC, abstractmethod -from typing import Any, ClassVar +from typing import Any from pydantic import BaseModel @@ -16,33 +15,38 @@ class AdditionalParams(BaseModel): -class CoreCall(ABC): +class CoreCall: """Call抽象类。所有Call必须继承此类,并实现所有方法。""" name: str = "" description: str = "" - params_schema: ClassVar[dict[str, Any]] = {} + params: type[BaseModel] = AdditionalParams + @property + def params_schema(self) -> dict[str, Any]: + """返回params的schema""" + return self.params.model_json_schema() - @abstractmethod - def __init__(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 - """初始化Call,并对参数进行解析。 + async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 + """初始化Call,赋值参数 :param syscall_vars: Call所需的固定参数。此处的参数为系统提供。 :param kwargs: Call所需的额外参数。此处的参数为Flow开发者填充。 """ # 使用此种方式进行params校验 self._syscall_vars = syscall_vars - self._params = AdditionalParams.model_validate(kwargs) + self._params = self.params.model_validate(kwargs) # 在此初始化Slot Schema self.slot_schema: dict[str, Any] = {} + async def load(self) -> None: + """如果Call需要载入文件,则在这里定义逻辑""" + pass # noqa: PIE790 - @abstractmethod async def call(self, slot_data: dict[str, Any]) -> CallResult: """运行Call。 :param slot_data: Call的参数槽。此处的参数槽为用户通过大模型交互式填充。 - :return: Dict类型的数据。返回值中"output"为工具的原始返回信息(有格式字符串);"message"为工具经LLM处理后的返回信息(字符串)。也可以带有其他字段,其他字段将起到额外的说明和信息传递作用。 + :return: CallResult类型的数据。返回值中"output"为工具的原始返回信息(有格式字符串);"message"为工具经LLM处理后的返回信息(字符串)。也可以带有其他字段,其他字段将起到额外的说明和信息传递作用。 """ raise NotImplementedError diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm.py index 43b5bd96..f6a28d6d 100644 --- a/apps/scheduler/call/llm.py +++ b/apps/scheduler/call/llm.py @@ -4,14 +4,14 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" from datetime import datetime from textwrap import dedent -from typing import Any, ClassVar +from typing import Any import pytz from jinja2 import BaseLoader, select_autoescape from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field -from apps.entities.plugin import CallError, CallResult, SysCallVars +from apps.entities.plugin import CallError, CallResult from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall @@ -44,17 +44,9 @@ class _LLMParams(BaseModel): class LLM(CoreCall): """大模型调用工具""" - def __init__(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 - """初始化LLM Call""" - self._core_params = syscall_vars - self._params = _LLMParams.model_validate(kwargs) - # 初始化Slot Schema - self.slot_schema = {} - - name: str = "llm" description: str = "大模型调用工具,用于以指定的提示词和上下文信息调用大模型,并获得输出。" - params_schema: ClassVar[dict[str, Any]] = _LLMParams.model_json_schema() + params: type[_LLMParams] = _LLMParams async def call(self, _slot_data: dict[str, Any]) -> CallResult: @@ -63,9 +55,9 @@ class LLM(CoreCall): time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") formatter = { "time": time, - "context": self._core_params.background, - "question": self._core_params.question, - "history": self._core_params.history, + "context": self._syscall_vars.background, + "question": self._syscall_vars.question, + "history": self._syscall_vars.history, } try: @@ -75,14 +67,14 @@ class LLM(CoreCall): autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(self._params.system_prompt) + ).from_string(self.params.system_prompt) system_input = system_tmpl.render(**formatter) user_tmpl = SandboxedEnvironment( loader=BaseLoader(), autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(self._params.user_prompt) + ).from_string(self.params.user_prompt) user_input = user_tmpl.render(**formatter) except Exception as e: raise CallError(message=f"用户提示词渲染失败:{e!s}", data={}) from e @@ -94,7 +86,7 @@ class LLM(CoreCall): try: result = "" - async for chunk in ReasoningLLM().call(task_id=self._core_params.task_id, messages=message): + async for chunk in ReasoningLLM().call(task_id=self._syscall_vars.task_id, messages=message): result += chunk except Exception as e: raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/next_flow.py b/apps/scheduler/call/next_flow.py new file mode 100644 index 00000000..5bddaee4 --- /dev/null +++ b/apps/scheduler/call/next_flow.py @@ -0,0 +1,13 @@ +"""用于下一步工作流推荐的工具""" +from apps.scheduler.call.core import CallResult, CoreCall + + +class NextFlowCall(CoreCall): + """用于下一步工作流推荐的工具""" + + name = "next_flow" + description = "用于下一步工作流推荐的工具" + + def call(self) -> CallResult: + return CallResult(output={}, message="", output_schema={}) + diff --git a/apps/scheduler/call/reformat.py b/apps/scheduler/call/reformat.py index f31f9395..d86cc607 100644 --- a/apps/scheduler/call/reformat.py +++ b/apps/scheduler/call/reformat.py @@ -5,7 +5,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
import json from datetime import datetime from textwrap import dedent -from typing import Any, ClassVar, Optional +from typing import Any, Optional import _jsonnet import pytz @@ -29,17 +29,12 @@ class Extract(CoreCall): name: str = "reformat" description: str = "从上一步的工具的原始JSON返回结果中,提取特定字段的信息。" - params_schema: ClassVar[dict[str, Any]] = _ReformatParam.model_json_schema() + params: type[_ReformatParam] = _ReformatParam - - def __init__(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 + async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 """初始化Reformat工具""" - self._core_params = syscall_vars - self._params = _ReformatParam.model_validate(kwargs) - self._last_output = CallResult(**self._core_params.history[-1].output_data) - # 初始化Slot Schema - self.slot_schema = {} - + await super().init(syscall_vars, **kwargs) + self._last_output = CallResult(**self._syscall_vars.history[-1].output_data) async def call(self, _slot_data: dict[str, Any]) -> CallResult: """调用Reformat工具 @@ -49,7 +44,7 @@ class Extract(CoreCall): """ # 判断用户是否给了值 time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") - if self._params.text is None: + if self.params.text is None: result_message = self._last_output.message else: text_template = SandboxedEnvironment( @@ -57,23 +52,23 @@ class Extract(CoreCall): autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(self._params.text) - result_message = text_template.render(time=time, history=self._core_params.history, question=self._core_params.question) + ).from_string(self.params.text) + result_message = text_template.render(time=time, history=self._syscall_vars.history, question=self._syscall_vars.question) - if self._params.data is None: + if self.params.data is None: result_data = self._last_output.output else: extra_str = json.dumps({ "time": time, - "question": self._core_params.question, + "question": self._syscall_vars.question, }, ensure_ascii=False) - history_str = json.dumps([CallResult(**item.output_data).output for item in self._core_params.history], ensure_ascii=False) + history_str = json.dumps([CallResult(**item.output_data).output for item in self._syscall_vars.history], ensure_ascii=False) data_template = dedent(f""" local extra = {extra_str}; local history = {history_str}; - {self._params.data} + {self.params.data} """) - result_data = json.loads(_jsonnet.evaluate_snippet(data_template, self._params.data), ensure_ascii=False) + result_data = json.loads(_jsonnet.evaluate_snippet(data_template, self.params.data), ensure_ascii=False) return CallResult( message=result_message, diff --git a/apps/scheduler/call/render/format.py b/apps/scheduler/call/render/format.py deleted file mode 100644 index d75a3894..00000000 --- a/apps/scheduler/call/render/format.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
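
Editor's note: the Reformat tool above builds a Jsonnet snippet with local bindings for extra data and history, then evaluates it. Stripped of the framework, the underlying library call is sketched below; the snippet content is illustrative, and the first argument to evaluate_snippet is only a name used in Jsonnet error messages.

import json

import _jsonnet

history = [{"cpu": 87.5, "mem": 42.1}]
snippet = """
local history = %s;
{ cpu_load: history[0].cpu, high: history[0].cpu > 80 }
""" % json.dumps(history)

result = json.loads(_jsonnet.evaluate_snippet("reformat", snippet))
print(result)  # {'cpu_load': 87.5, 'high': True}
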
-from typing import Optional - -from apps.llm.patterns.core import CorePattern - - -class RenderFormat(CorePattern): - _system_prompt = "" - _user_prompt = "" - - def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None: - super().__init__(system_prompt, user_prompt) - - async def generate(self, task_id: str, **kwargs) -> str: - pass diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/render/render.py index ac3a8630..6cd406bf 100644 --- a/apps/scheduler/call/render/render.py +++ b/apps/scheduler/call/render/render.py @@ -4,7 +4,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import json from pathlib import Path -from typing import Any, ClassVar +from typing import Any from apps.entities.plugin import CallError, CallResult, SysCallVars from apps.scheduler.call.core import CoreCall @@ -16,17 +16,14 @@ class Render(CoreCall): name: str = "render" description: str = "渲染图表工具,可将给定的数据绘制为图表。" - params_schema: ClassVar[dict[str, Any]] = {} - def __init__(self, syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 + async def init(self, syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 """初始化Render Call,校验参数,读取option模板 :param syscall_vars: Render Call参数 """ - self._core_params = syscall_vars - # 初始化Slot Schema - self.slot_schema = {} + await super().init(syscall_vars, **_kwargs) try: option_location = Path(__file__).parent / "option.json" @@ -39,7 +36,7 @@ class Render(CoreCall): async def call(self, _slot_data: dict[str, Any]) -> CallResult: """运行Render Call""" # 检测前一个工具是否为SQL - data = CallResult(**self._core_params.history[-1].output_data).output + data = CallResult(**self._syscall_vars.history[-1].output_data).output if data["type"] != "sql" or "dataset" not in data: raise CallError( message="图表生成失败!Render必须在SQL后调用!", @@ -70,7 +67,7 @@ class Render(CoreCall): self._option_template["dataset"]["source"] = data try: - llm_output = await RenderStyle().generate(self._core_params.task_id, question=self._core_params.question) + llm_output = await RenderStyle().generate(self._syscall_vars.task_id, question=self._syscall_vars.question) add_style = llm_output.get("additional_style", "") self._parse_options(column_num, llm_output["chart_type"], add_style, llm_output["scale_type"]) except Exception as e: diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql.py index efea0f8b..43dd359d 100644 --- a/apps/scheduler/call/sql.py +++ b/apps/scheduler/call/sql.py @@ -4,7 +4,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import json -from typing import Any, ClassVar +from typing import Any import aiohttp from fastapi import status @@ -21,20 +21,14 @@ class SQL(CoreCall): name: str = "sql" description: str = "SQL工具,用于查询数据库中的结构化数据" - params_schema: ClassVar[dict[str, Any]] = {} - def __init__(self, syscall_vars: SysCallVars, **_kwargs) -> None: # noqa: ANN003 - """初始化SQL工具。 - - 解析SQL工具参数,拼接PostgreSQL连接字符串,创建SQLAlchemy Engine。 - :param params: SQL工具需要的参数。 - """ + async def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 + """初始化SQL工具。""" + await super().init(syscall_vars, **kwargs) + # 初始化aiohttp的ClientSession self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) - self._core_params = syscall_vars - # 初始化Slot Schema - self.slot_schema = {} - + # 初始化SQLAlchemy Engine try: db_url = f'postgresql+psycopg2://{config["POSTGRES_USER"]}:{config["POSTGRES_PWD"]}@{config["POSTGRES_HOST"]}/{config["POSTGRES_DATABASE"]}' self._engine = create_engine(db_url, pool_size=20, max_overflow=80, pool_recycle=300, pool_pre_ping=True) @@ -50,7 +44,7 @@ class SQL(CoreCall): :return: 从数据库中查询得到的数据,或报错信息 """ post_data = { - "question": self._core_params.question, + "question": self._syscall_vars.question, "topk_sql": 5, "use_llm_enhancements": True, } diff --git a/apps/scheduler/embedding.py b/apps/scheduler/embedding.py new file mode 100644 index 00000000..03493803 --- /dev/null +++ b/apps/scheduler/embedding.py @@ -0,0 +1,21 @@ +"""从Vectorize获取向量化数据 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +import aiohttp + +from apps.common.config import config + + +async def get_embedding(text: list[str]) -> list[float]: + """访问Vectorize的Embedding API,获得向量化数据 + + :param text: 待向量化文本(多条文本组成List) + :return: 文本对应的向量(顺序与text一致,也为List) + """ + api = config["VECTORIZE_HOST"].rstrip("/") + "/embedding" + + async with aiohttp.ClientSession() as session, session.post( + api, json={"texts": text}, timeout=30, + ) as response: + return await response.json() diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 560d7dd2..e7e2e210 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -6,10 +6,10 @@ import traceback from typing import Optional from apps.constants import LOGGER, MAX_SCHEDULER_HISTORY_SIZE -from apps.entities.enum import StepStatus +from apps.entities.enum_var import StepStatus +from apps.entities.flow import Step from apps.entities.plugin import ( CallResult, - Step, SysCallVars, SysExecVars, ) @@ -61,7 +61,7 @@ class Executor: # 保存Flow数据(只读) self._flow_data = flow_data - #尝试恢复State + # 尝试恢复State if task.flow_state: self.flow_state = task.flow_state # 如果flow_context为空,则从flow_history中恢复 diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py index 6e8d3476..ad1a3a2f 100644 --- a/apps/scheduler/executor/message.py +++ b/apps/scheduler/executor/message.py @@ -3,7 +3,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" from apps.common.queue import MessageQueue -from apps.entities.enum import EventType, FlowOutputType, StepStatus +from apps.entities.enum_var import EventType, FlowOutputType, StepStatus +from apps.entities.flow import Flow from apps.entities.message import ( FlowStartContent, FlowStopContent, @@ -13,7 +14,6 @@ from apps.entities.message import ( ) from apps.entities.plugin import ( CallResult, - Flow, ) from apps.entities.task import ExecutorState, FlowHistory from apps.llm.patterns.executor import ExecutorResult diff --git a/apps/scheduler/json_schema.py b/apps/scheduler/json_schema.py index 87bd4d28..12d18394 100644 --- a/apps/scheduler/json_schema.py +++ b/apps/scheduler/json_schema.py @@ -43,18 +43,21 @@ format_to_regex = { } -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): +def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None) -> str: """将JSON Schema转换为正则表达式""" - schema: dict[str, Any] = json.loads(schema) - Validator.check_schema(schema) + schema_dict: dict[str, Any] = json.loads(schema) + Validator.check_schema(schema_dict) # Build reference resolver - schema = Resource(contents=schema, specification=DRAFT202012) - uri = schema.id() if schema.id() is not None else "" - registry = Registry().with_resource(uri=uri, resource=schema) + schema_resource = Resource(contents=schema_dict, specification=DRAFT202012) + uri = schema_resource.id() if schema_resource.id() is not None else "" + if not uri: + err = "schema_resource.id() is None" + raise ValueError(err) + registry = Registry().with_resource(uri=uri, resource=schema_resource) # type: ignore[arg-type] resolver = registry.resolver() - content = schema.contents + content = schema_resource.contents return to_regex(resolver, content, whitespace_pattern) @@ -94,9 +97,9 @@ def validate_quantifiers( return min_bound, max_bound -def to_regex( +def to_regex( # noqa: C901, PLR0911, PLR0912, PLR0915 resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None, -): +) -> str: """将 JSON Schema 实例转换为对应的正则表达式""" # set whitespace pattern if whitespace_pattern is None: @@ -241,7 +244,7 @@ def to_regex( try: if int(max_items) < int(min_items): err = "maxLength must be greater than or equal to minLength" - raise ValueError(err) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume) + raise ValueError(err) # noqa: TRY301 except ValueError: pass return f'"{STRING_INNER}{{{min_items},{max_items}}}"' diff --git a/apps/scheduler/openapi.py b/apps/scheduler/openapi.py new file mode 100644 index 00000000..04d638ea --- /dev/null +++ b/apps/scheduler/openapi.py @@ -0,0 +1,162 @@ +"""OpenAPI文档相关操作 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +from collections.abc import Sequence +from copy import deepcopy +from typing import Any, Optional + +from pydantic import BaseModel + + +class ReducedOpenAPISpec(BaseModel): + """精简后的OpenAPISpec文档""" + + servers: list[dict] + id: str + description: str + endpoints: list[tuple[str, str, dict]] + + +def _retrieve_ref(path: str, schema: dict) -> dict: + """从OpenAPI文档中找到$ref对应的schema""" + components = path.split("/") + if components[0] != "#": + msg = ( + "ref paths are expected to be URI fragments, meaning they should start " + "with #." 
+ ) + raise ValueError(msg) + out = schema + for component in components[1:]: + if component in out: + out = out[component] + elif component.isdigit() and int(component) in out: + out = out[int(component)] + else: + msg = f"Reference '{path}' not found." + raise KeyError(msg) + return deepcopy(out) + + +def _dereference_refs_helper( + obj: Any, # noqa: ANN401 + full_schema: dict[str, Any], + skip_keys: Sequence[str], + processed_refs: Optional[set[str]] = None, +) -> Any: # noqa: ANN401 + """递归地将OpenAPI中的$ref替换为实际的schema""" + if processed_refs is None: + processed_refs = set() + + if isinstance(obj, dict): + obj_out = {} + for k, v in obj.items(): + if k in skip_keys: + obj_out[k] = v + elif k == "$ref": + if v in processed_refs: + continue + processed_refs.add(v) + ref = _retrieve_ref(v, full_schema) + full_ref = _dereference_refs_helper( + ref, full_schema, skip_keys, processed_refs, + ) + processed_refs.remove(v) + return full_ref + elif isinstance(v, (list, dict)): + obj_out[k] = _dereference_refs_helper( + v, full_schema, skip_keys, processed_refs, + ) + else: + obj_out[k] = v + return obj_out + + if isinstance(obj, list): + return [ + _dereference_refs_helper(el, full_schema, skip_keys, processed_refs) + for el in obj + ] + + return obj + + +def _infer_skip_keys( + obj: Any, full_schema: dict, processed_refs: Optional[set[str]] = None, # noqa: ANN401 +) -> list[str]: + """推断需要跳过的OpenAPI文档中的键""" + if processed_refs is None: + processed_refs = set() + + keys = [] + if isinstance(obj, dict): + for k, v in obj.items(): + if k == "$ref": + if v in processed_refs: + continue + processed_refs.add(v) + ref = _retrieve_ref(v, full_schema) + keys.append(v.split("/")[1]) + keys += _infer_skip_keys(ref, full_schema, processed_refs) + elif isinstance(v, (list, dict)): + keys += _infer_skip_keys(v, full_schema, processed_refs) + elif isinstance(obj, list): + for el in obj: + keys += _infer_skip_keys(el, full_schema, processed_refs) + return keys + + +def dereference_refs( + schema_obj: dict, + *, + full_schema: Optional[dict] = None, +) -> dict: + """将OpenAPI中的$ref替换为实际的schema""" + full_schema = full_schema or schema_obj + skip_keys = _infer_skip_keys(schema_obj, full_schema) + return _dereference_refs_helper(schema_obj, full_schema, skip_keys) + + +def reduce_openapi_spec(spec: dict) -> ReducedOpenAPISpec: + """解析和处理OpenAPI文档""" + # 只支持get, post, patch, put, delete API + endpoints = [ + (f"{operation_name.upper()} {route}", docs.get("description"), docs) + for route, operation in spec["paths"].items() + for operation_name, docs in operation.items() + if operation_name in ["get", "post", "patch", "put", "delete"] + ] + + # 强制去除ref + endpoints = [ + (name, description, dereference_refs(docs, full_schema=spec)) + for name, description, docs in endpoints + ] + + # 只提取关键字段【可修改】 + def reduce_endpoint_docs(docs: dict) -> dict: + out = {} + if docs.get("description"): + out["description"] = docs.get("description") + if docs.get("parameters"): + out["parameters"] = [ + parameter + for parameter in docs.get("parameters", []) + if parameter.get("required") + ] + if "200" in docs["responses"]: + out["responses"] = docs["responses"]["200"] + if docs.get("requestBody"): + out["requestBody"] = docs.get("requestBody") + return out + + endpoints = [ + (name, description, reduce_endpoint_docs(docs)) + for name, description, docs in endpoints + ] + return ReducedOpenAPISpec( + servers=spec["servers"], + id=spec["info"]["title"], + description=spec["info"].get("description", ""), + endpoints=endpoints, + ) diff 
--git a/apps/scheduler/pool/loader.py b/apps/scheduler/pool/loader.py index 9fa8b640..2c512e38 100644 --- a/apps/scheduler/pool/loader.py +++ b/apps/scheduler/pool/loader.py @@ -19,7 +19,7 @@ import apps.scheduler.call as system_call from apps.common.config import config from apps.common.singleton import Singleton from apps.constants import LOGGER -from apps.entities.plugin import Flow, NextFlow, Step +from apps.entities.flow import Flow, NextFlow, Step from apps.scheduler.pool.pool import Pool OPENAPI_FILENAME = "openapi.yaml" diff --git a/apps/scheduler/pool/loader/__init__.py b/apps/scheduler/pool/loader/__init__.py new file mode 100644 index 00000000..cb2a8502 --- /dev/null +++ b/apps/scheduler/pool/loader/__init__.py @@ -0,0 +1,9 @@ +"""配置加载器 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +from apps.scheduler.pool.loader.app import AppLoader +from apps.scheduler.pool.loader.call import CallLoader +from apps.scheduler.pool.loader.service import ServiceLoader + +__all__ = ["AppLoader", "CallLoader", "ServiceLoader"] diff --git a/apps/scheduler/pool/btdl.py b/apps/scheduler/pool/loader/btdl.py similarity index 96% rename from apps/scheduler/pool/btdl.py rename to apps/scheduler/pool/loader/btdl.py index 668caf4f..75623973 100644 --- a/apps/scheduler/pool/btdl.py +++ b/apps/scheduler/pool/loader/btdl.py @@ -1,10 +1,11 @@ +"""BTDL文档加载器 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" import hashlib -from typing import Any, Union +from typing import Any import yaml -from chromadb import Collection - -from apps.scheduler.vector import DocumentWrapper, VectorDB btdl_spec = [] @@ -17,13 +18,6 @@ btdl_spec = [] class BTDLLoader: """二进制描述文件 加载器""" - vec_collection: Collection - - def __init__(self, collection_name: str) -> None: - """初始化BTDL加载器""" - # Create or use existing vec_db - self.vec_collection = VectorDB.get_collection(collection_name) - @staticmethod # 循环检查每一个参数,确定为合法JSON Schema def _check_single_argument(argument: dict[str, Any], *, strict: bool = True) -> None: diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py new file mode 100644 index 00000000..5a81c3f8 --- /dev/null +++ b/apps/scheduler/pool/loader/metadata.py @@ -0,0 +1,31 @@ +"""元数据加载器 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +from pathlib import Path + +import yaml + +from apps.constants import LOGGER +from apps.entities.flow import Metadata + + +class MetadataLoader: + """元数据加载器""" + + @staticmethod + def check_metadata(dir: Path) -> bool: + """检查metadata.yaml是否正确""" + # 检查yaml格式 + try: + metadata = yaml.safe_load(Path(dir, "metadata.yaml").read_text()) + except Exception as e: + LOGGER.error("metadata.yaml读取失败: %s", e) + return False + + + + @classmethod + async def load(cls) -> None: + """执行元数据加载""" + pass diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py new file mode 100644 index 00000000..b498588f --- /dev/null +++ b/apps/scheduler/pool/loader/openapi.py @@ -0,0 +1,45 @@ +"""OpenAPI文档载入器 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+""" +from pathlib import Path +from typing import Any + +import yaml + +from apps.scheduler.openapi import ReducedOpenAPISpec, reduce_openapi_spec +from apps.scheduler.pool.util import get_bytes_hash + + +class OpenAPILoader: + """OpenAPI文档载入器""" + + @classmethod + def load_from_disk(cls, yaml_path: str) -> tuple[str, ReducedOpenAPISpec]: + """从本地磁盘加载OpenAPI文档""" + path = Path(yaml_path) + if not path.exists(): + msg = f"File not found: {yaml_path}" + raise FileNotFoundError(msg) + + with path.open(mode="rb") as f: + content = f.read() + hash_value = get_bytes_hash(content) + spec = yaml.safe_load(content) + return hash_value, reduce_openapi_spec(spec) + + @classmethod + def load_from_minio(cls, yaml_path: str) -> tuple[str, ReducedOpenAPISpec]: + """从MinIO加载OpenAPI文档""" + pass + + @classmethod + def load(cls) -> ReducedOpenAPISpec: + """执行OpenAPI文档的加载""" + pass + + @classmethod + def process(cls, spec: ReducedOpenAPISpec) -> dict[str, Any]: + """处理OpenAPI文档""" + pass + diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index cf27ee95..c1278e3c 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -18,7 +18,8 @@ from sqlalchemy.orm import sessionmaker from apps.common.config import config from apps.common.singleton import Singleton from apps.constants import LOGGER -from apps.entities.plugin import Flow, PluginData +from apps.entities.flow import Flow +from apps.entities.plugin import PluginData from apps.scheduler.pool.entities import Base, CallItem, FlowItem, PluginItem from apps.scheduler.vector import DocumentWrapper, VectorDB diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 64085818..fb7bd87b 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -9,7 +9,7 @@ from typing import Union from apps.common.queue import MessageQueue from apps.constants import LOGGER from apps.entities.collection import Document -from apps.entities.enum import EventType +from apps.entities.enum_var import EventType from apps.entities.message import ( DocumentAddContent, InitContent, diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 66706ce7..4c078d07 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -14,7 +14,7 @@ from apps.entities.collection import ( Document, Record, ) -from apps.entities.enum import EventType, StepStatus +from apps.entities.enum_var import EventType, StepStatus from apps.entities.plugin import ExecutorBackground, SysExecVars from apps.entities.rag_data import RAGQueryReq from apps.entities.record import RecordDocument diff --git a/apps/scheduler/slot/parser/__init__.py b/apps/scheduler/slot/parser/__init__.py index 1dc187f2..25804d68 100644 --- a/apps/scheduler/slot/parser/__init__.py +++ b/apps/scheduler/slot/parser/__init__.py @@ -2,10 +2,14 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" +from apps.scheduler.slot.parser.const import SlotConstParser from apps.scheduler.slot.parser.date import SlotDateParser +from apps.scheduler.slot.parser.default import SlotDefaultParser from apps.scheduler.slot.parser.timestamp import SlotTimestampParser __all__ = [ + "SlotConstParser", "SlotDateParser", + "SlotDefaultParser", "SlotTimestampParser", ] diff --git a/apps/scheduler/slot/parser/const.py b/apps/scheduler/slot/parser/const.py new file mode 100644 index 00000000..ed0d68dc --- /dev/null +++ b/apps/scheduler/slot/parser/const.py @@ -0,0 +1,24 @@ +"""固定值设置器 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +from typing import Any + +from apps.entities.enum_var import SlotType +from apps.scheduler.slot.parser.core import SlotParser + + +class SlotConstParser(SlotParser): + """给字段设置固定值""" + + type: SlotType = SlotType.KEYWORD + name: str = "const" + + @classmethod + def convert(cls, data: Any, **kwargs) -> Any: # noqa: ANN003, ANN401 + """生成keyword的验证器 + + 如果没有对应逻辑则不实现 + """ + raise NotImplementedError + diff --git a/apps/scheduler/slot/parser/core.py b/apps/scheduler/slot/parser/core.py index ee407eb7..c2789958 100644 --- a/apps/scheduler/slot/parser/core.py +++ b/apps/scheduler/slot/parser/core.py @@ -7,7 +7,7 @@ from typing import Any from jsonschema import TypeChecker from jsonschema.protocols import Validator -from apps.entities.enum import SlotType +from apps.entities.enum_var import SlotType class SlotParser: diff --git a/apps/scheduler/slot/parser/date.py b/apps/scheduler/slot/parser/date.py index bdce6f4f..5d8086ca 100644 --- a/apps/scheduler/slot/parser/date.py +++ b/apps/scheduler/slot/parser/date.py @@ -10,7 +10,7 @@ from jionlp import parse_time from jsonschema import TypeChecker from apps.constants import LOGGER -from apps.entities.enum import SlotType +from apps.entities.enum_var import SlotType from apps.scheduler.slot.parser.core import SlotParser diff --git a/apps/scheduler/slot/parser/default.py b/apps/scheduler/slot/parser/default.py new file mode 100644 index 00000000..89ec9c19 --- /dev/null +++ b/apps/scheduler/slot/parser/default.py @@ -0,0 +1,24 @@ +"""默认值设置器 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+""" +from typing import Any + +from apps.entities.enum_var import SlotType +from apps.scheduler.slot.parser.core import SlotParser + + +class SlotDefaultParser(SlotParser): + """给字段设置默认值""" + + type: SlotType = SlotType.KEYWORD + name: str = "default" + + @classmethod + def convert(cls, data: Any, **kwargs) -> Any: # noqa: ANN003, ANN401 + """给字段设置默认值 + + 如果没有对应逻辑则不实现 + """ + raise NotImplementedError + diff --git a/apps/scheduler/slot/parser/timestamp.py b/apps/scheduler/slot/parser/timestamp.py index e86d3103..4dede63f 100644 --- a/apps/scheduler/slot/parser/timestamp.py +++ b/apps/scheduler/slot/parser/timestamp.py @@ -9,7 +9,7 @@ import pytz from jsonschema import TypeChecker from apps.constants import LOGGER -from apps.entities.enum import SlotType +from apps.entities.enum_var import SlotType from apps.scheduler.slot.parser.core import SlotParser diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index d7a386b9..bedc3a49 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -16,7 +16,12 @@ from jsonschema.validators import extend from apps.constants import LOGGER from apps.entities.plugin import CallResult from apps.llm.patterns.json import Json -from apps.scheduler.slot.parser import SlotDateParser, SlotTimestampParser +from apps.scheduler.slot.parser import ( + SlotConstParser, + SlotDateParser, + SlotDefaultParser, + SlotTimestampParser, +) from apps.scheduler.slot.util import escape_path, patch_json # 各类检查器 @@ -25,12 +30,20 @@ _TYPE_CHECKER = [ SlotTimestampParser, ] _FORMAT_CHECKER = [] -_KEYWORD_CHECKER = {} -# 类型转换器 -_CONVERTER = [ +_KEYWORD_CHECKER = { + "const": SlotConstParser, + "default": SlotDefaultParser, +} + +# 各类转换器 +_TYPE_CONVERTER = [ SlotDateParser, SlotTimestampParser, ] +_KEYWORD_CONVERTER = { + "const": SlotConstParser, + "default": SlotDefaultParser, +} class Slot: """参数槽 @@ -112,7 +125,7 @@ class Slot: processed_dict[key] = Slot._process_json_value(val, spec_data["properties"][key]) return processed_dict - for converter in _CONVERTER: + for converter in _TYPE_CONVERTER: # 如果是自定义类型 if converter.name == spec_data["type"]: # 如果类型有附加字段 diff --git a/apps/scheduler/vector.py b/apps/scheduler/vector.py deleted file mode 100644 index f789a3d9..00000000 --- a/apps/scheduler/vector.py +++ /dev/null @@ -1,131 +0,0 @@ -"""ChromaDB内存向量数据库 - -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
-""" -from typing import ClassVar, Optional - -import numpy as np -import requests -from chromadb import ( - Client, - Collection, - Documents, - EmbeddingFunction, - Embeddings, -) -from chromadb.api import ClientAPI -from chromadb.api.types import IncludeEnum -from pydantic import BaseModel, Field - -from apps.common.config import config -from apps.constants import LOGGER - - -def _get_embedding(text: list[str]) -> list[np.ndarray]: - """访问Vectorize的Embedding API,获得向量化数据 - - :param text: 待向量化文本(多条文本组成List) - :return: 文本对应的向量(顺序与text一致,也为List) - """ - api = config["VECTORIZE_HOST"].rstrip("/") + "/embedding" - response = requests.post( - api, - json={"texts": text}, - verify=False, # noqa: S501 - timeout=30, - ) - - return [np.array(vec) for vec in response.json()] - - -# 模块内部类,不应在模块外部使用 -class DocumentWrapper(BaseModel): - """单个ChromaDB文档的结构""" - - data: str = Field(description="文档内容") - id: str = Field(description="文档ID,用于确保唯一性") - metadata: Optional[dict] = Field(description="文档元数据", default=None) - - -class RAGEmbedding(EmbeddingFunction): - """ChromaDB用于进行文本向量化的函数""" - - def __call__(self, input: Documents) -> Embeddings: # noqa: A002 - """调用RAG接口进行文本向量化""" - return _get_embedding(input) - - -class VectorDB: - """ChromaDB单例""" - - client: ClassVar[ClientAPI] = Client() - - @classmethod - def get_collection(cls, collection_name: str) -> Optional[Collection]: - """创建并返回ChromaDB集合 - - :param collection_name: 集合名称,字符串 - :return: ChromaDB集合对象 - """ - try: - return cls.client.get_or_create_collection(collection_name, embedding_function=RAGEmbedding(), - metadata={"hnsw:space": "cosine"}) - except Exception as e: - LOGGER.error(f"Get collection failed: {e}") - return None - - @classmethod - def delete_collection(cls, collection_name: str) -> None: - """删除ChromaDB集合 - - :param collection_name: 集合名称,字符串 - """ - cls.client.delete_collection(collection_name) - - @classmethod - def add_docs(cls, collection: Collection, docs: list[DocumentWrapper]) -> None: - """向ChromaDB集合中添加文档 - - :param collection: ChromaDB集合对象 - :param docs: 待向量化的文档List - """ - doc_list = [] - metadata_list = [] - id_list = [] - for doc in docs: - doc_list.append(doc.data) - id_list.append(doc.id) - metadata_list.append(doc.metadata) - - collection.add( - ids=id_list, - metadatas=metadata_list, - documents=doc_list, - ) - - @classmethod - def get_docs(cls, collection: Collection, question: str, requirements: dict, num: int = 3) -> list[DocumentWrapper]: - """根据输入,从ChromaDB中查询K个向量最相似的文档 - - :param collection: ChromaDB集合对象 - :param question: 查询输入 - :param requirements: 查询过滤条件 - :param num: Top K中K的值 - :return: 文档List,包含文档内容、元数据、ID - """ - result = collection.query( - query_texts=[question], - where=requirements, - n_results=num, - include=[IncludeEnum.documents, IncludeEnum.metadatas], - ) - - length = min(num, len(result["ids"][0])) - return [ - DocumentWrapper( - id=result["ids"][0][i], - metadata=result["metadatas"][0][i], # type: ignore[index] - data=result["documents"][0][i], # type: ignore[index] - ) - for i in range(length) - ] diff --git a/apps/service/activity.py b/apps/service/activity.py index 5ff0f305..ab0ad313 100644 --- a/apps/service/activity.py +++ b/apps/service/activity.py @@ -48,7 +48,7 @@ class Activity: @staticmethod async def remove_active(user_sub: str) -> None: - """清除用户的活动标识,释放GPU资源 + """清除用户的活跃标识,释放GPU资源 :param user_sub: 用户实体ID """ diff --git a/apps/service/rag.py b/apps/service/rag.py index 299e9a38..9c677425 100644 --- a/apps/service/rag.py +++ b/apps/service/rag.py @@ -29,7 +29,9 @@ class RAG: # 
asyncio HTTP请求 - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=300)) as session, session.post(url, headers=headers, data=payload, ssl=False) as response: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=300)) as session, session.post( + url, headers=headers, data=payload, ssl=False, + ) as response: if response.status != status.HTTP_200_OK: LOGGER.error(f"RAG服务返回错误码: {response.status}\n{await response.text()}") return diff --git a/apps/service/suggestion.py b/apps/service/suggestion.py index 53293e11..b23e95e0 100644 --- a/apps/service/suggestion.py +++ b/apps/service/suggestion.py @@ -9,7 +9,7 @@ from apps.common.queue import MessageQueue from apps.common.security import Security from apps.constants import LOGGER from apps.entities.collection import RecordContent -from apps.entities.enum import EventType +from apps.entities.enum_var import EventType from apps.entities.message import SuggestContent from apps.entities.task import RequestDataPlugin from apps.llm.patterns.recommend import Recommend diff --git a/apps/utils/get_api_doc.py b/apps/utils/get_api_doc.py index 8d8c9a6b..fe05d18a 100644 --- a/apps/utils/get_api_doc.py +++ b/apps/utils/get_api_doc.py @@ -21,7 +21,7 @@ def get_api_doc() -> None: raise ValueError(err) path = Path(config_path) / "openapi.json" - with open(path, "w", encoding="utf-8") as f: + with path.open("w", encoding="utf-8") as f: json.dump(get_openapi( title=app.title, version=app.version, diff --git a/assets/.env.example b/assets/.env.example index ccfc158f..bae83ccf 100644 --- a/assets/.env.example +++ b/assets/.env.example @@ -27,6 +27,11 @@ SESSION_TTL= DETECT_TYPE= WORDS_CHECK= WORDS_LIST= +SCAS_APP_ID= +SCAS_SIGN_KEY= +SCAS_BUSINESS_ID= +SCAS_SCENE_ID= +SCAS_URL= # logging LOG= @@ -46,6 +51,11 @@ PICKLE_KEY= DETECT_TYPE= WORDS_CHECK= WORDS_LIST= +SCAS_APP_ID= +SCAS_SIGN_KEY= +SCAS_BUSINESS_ID= +SCAS_SCENE_ID= +SCAS_URL= # CSRF ENABLE_CSRF=True @@ -100,3 +110,6 @@ PLUGIN_DIR= # SQL SQL_URL= + +# Gitee +GITEE_WHITELIST= diff --git a/assets/logging.example.json b/assets/logging.example.json deleted file mode 100644 index d1d1d62b..00000000 --- a/assets/logging.example.json +++ /dev/null @@ -1,47 +0,0 @@ -{ - "version": 1, - "disable_existing_loggers": false, - "root": { - "level": "INFO", - "handlers": [ - "console" - ] - }, - "loggers": { - "gunicorn.error": { - "level": "INFO", - "handlers": [ - "error_console" - ], - "propagate": true, - "qualname": "gunicorn.error" - }, - "gunicorn.access": { - "level": "INFO", - "handlers": [ - "console" - ], - "propagate": true, - "qualname": "gunicorn.access" - } - }, - "handlers": { - "console": { - "class": "logging.StreamHandler", - "formatter": "generic", - "stream": "ext://sys.stdout" - }, - "error_console": { - "class": "logging.StreamHandler", - "formatter": "generic", - "stream": "ext://sys.stderr" - } - }, - "formatters": { - "generic": { - "format": "[{asctime}][{levelname}][{name}][P{process}][T{thread}][{message}][{funcName}({filename}:{lineno})]", - "datefmt": "[%Y-%m-%d %H:%M:%S %z]", - "class": "logging.Formatter" - } - } -} \ No newline at end of file diff --git a/op.conf b/op.conf new file mode 100644 index 00000000..be6200b3 --- /dev/null +++ b/op.conf @@ -0,0 +1,7 @@ +up: git pull +build: + docker build . 
-t hub.oepkgs.net/neocopilot/framework-dev:`(git rev-parse --short HEAD)` + docker push hub.oepkgs.net/neocopilot/framework-dev:`(git rev-parse --short HEAD)` + docker rmi hub.oepkgs.net/neocopilot/framework-dev:`(git rev-parse --short HEAD)` + docker builder prune -f + kubectl -n euler-copilot set image deployment/framework-deploy framework=hub.oepkgs.net/neocopilot/framework-dev:`(git rev-parse --short HEAD)` diff --git a/requirements.txt b/requirements.txt index 5cbab4b8..0dddae7c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,9 @@ JSON-minify==0.3.0 -PyMySQL==1.1.1 aiofiles==24.1.0 aiohttp==3.10.11 apscheduler==3.10.4 asgiref==3.8.1 asyncer==0.0.8 -chromadb==0.5.15 coverage==7.6.4 cryptography==43.0.3 eval-type-backport==0.2.0 @@ -39,11 +37,11 @@ python-multipart==0.0.9 pytz==2024.2 pyyaml==6.0.2 rank-bm25==0.2.2 +ray[serve]==2.40.0 redis==5.2.0 requests==2.32.3 sglang==0.4.0.post1 sortedcontainers==2.4.0 -spark-ai-python==0.4.5 sqlalchemy==2.0.35 starlette==0.41.2 tiktoken==0.8.0 diff --git a/sample/README.txt b/sample/README.txt new file mode 100644 index 00000000..d14b16dd --- /dev/null +++ b/sample/README.txt @@ -0,0 +1,2 @@ +该插件包样例为MinIO内文件的目录树结构, +用户可手动上传或下载MinIO内configs/子路径内的文件,实现批量预载和手动备份 \ No newline at end of file diff --git a/sdk/example_plugin/flows/flow.yaml b/sample/apps/test_app/flows/flow.yaml similarity index 100% rename from sdk/example_plugin/flows/flow.yaml rename to sample/apps/test_app/flows/flow.yaml diff --git a/sdk/example_plugin/lib/sub_lib/__init__.py b/sample/apps/test_app/metadata.yaml similarity index 100% rename from sdk/example_plugin/lib/sub_lib/__init__.py rename to sample/apps/test_app/metadata.yaml diff --git a/sample/calls/__init__.py b/sample/calls/__init__.py new file mode 100644 index 00000000..6a2a67fd --- /dev/null +++ b/sample/calls/__init__.py @@ -0,0 +1,8 @@ +"""样例工具 + +calls/文件夹中存放由用户或开发者自定义的Python工具。 +这些工具可被Framework载入,展示在语义接口中心中,可在编写工作流时使用。 +编写自定义工具时,需遵循一定的标准。参考样例工具代码以了解详情。 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" diff --git a/sample/calls/test_call/__init__.py b/sample/calls/test_call/__init__.py new file mode 100644 index 00000000..f8db4e94 --- /dev/null +++ b/sample/calls/test_call/__init__.py @@ -0,0 +1,14 @@ +"""样例工具 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +# 这里应当导入所有工具类型 +from .user_tool import UserTool + +# 【必填】使用__all__暴露所有Call Class +__all__ = [ + "UserTool", +] + +# 【必填】这个工具关联的服务 +service = "test_service" diff --git a/sample/calls/test_call/sub_lib/__init__.py b/sample/calls/test_call/sub_lib/__init__.py new file mode 100644 index 00000000..dad10068 --- /dev/null +++ b/sample/calls/test_call/sub_lib/__init__.py @@ -0,0 +1,4 @@ +"""样例子包 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" diff --git a/sample/calls/test_call/sub_lib/add.py b/sample/calls/test_call/sub_lib/add.py new file mode 100644 index 00000000..fe1daf5f --- /dev/null +++ b/sample/calls/test_call/sub_lib/add.py @@ -0,0 +1,9 @@ +"""样例子包 - 逻辑 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
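As the README above describes, the sample/ directory mirrors the file tree expected in MinIO; collecting the files added or renamed in this patch gives the following layout:

    sample/
    ├── README.txt
    ├── apps/
    │   └── test_app/
    │       ├── metadata.yaml
    │       └── flows/
    │           └── flow.yaml
    ├── calls/
    │   ├── __init__.py
    │   └── test_call/
    │       ├── __init__.py
    │       ├── user_tool.py
    │       └── sub_lib/
    │           ├── __init__.py
    │           └── add.py
    └── services/
        └── test_service/
            ├── metadata.yaml
            └── openapi/
                └── api.yaml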
+""" + + +def add(a: int, b: int) -> int: + """加法""" + return a + b diff --git a/sdk/example_plugin/lib/user_tool.py b/sample/calls/test_call/user_tool.py similarity index 59% rename from sdk/example_plugin/lib/user_tool.py rename to sample/calls/test_call/user_tool.py index 080b81f2..dae413ee 100644 --- a/sdk/example_plugin/lib/user_tool.py +++ b/sample/calls/test_call/user_tool.py @@ -1,58 +1,55 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -# Python工具基本形式,供用户参考 -from typing import Optional, Any, List, Dict +"""样例工具 - 主逻辑 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +from typing import Any, Optional + from pydantic import BaseModel, Field -# 可以使用子模块 -from . import sub_lib + +# 可以导入并使用子模块 +from .sub_lib import add class UserCallResult(BaseModel): - """ - Call运行后的返回值 - """ + """Call运行后的返回值""" + message: str = Field(description="Call的文字输出") output: Dict[str, Any] = Field(description="Call的结构化数据输出") extra: Optional[Dict[str, Any]] = Field(description="Call的额外输出", default=None) class UserCallParams(BaseModel): - """ - 此处为工具接受的各项参数。参数可在flow中配置,也可由大模型自动填充 - """ + """此处为工具接受的各项参数。参数可在flow中配置,也可由大模型自动填充""" + background: str = Field(description="上下文信息,由Executor自动传递") question: str = Field(description="给Call提供的用户输入,由Executor自动传递") - files: List[str] = Field(description="用户询问问题时上传的文件,由Executor自动传递") - history: List[UserCallResult] = Field(description="Executor中历史Call的返回值,由Executor自动传递") + files: list[str] = Field(description="用户询问问题时上传的文件,由Executor自动传递") + history: list[UserCallResult] = Field(description="Executor中历史Call的返回值,由Executor自动传递") task_id: Optional[str] = Field(description="任务ID, 由Executor自动传递") class UserTool: - """ - 这是工具类的基础形式 - """ - _name: str = "user_tool" + """这是工具类的基础形式""" + + name: str = "user_tool" """工具名称,会体现在flow中的on_error.tool和steps[].tool字段内""" - _description: str = "用户自定义工具样例" + description: str = "用户自定义工具样例" """工具描述,后续将用于自动编排工具""" - _params_obj: UserCallParams + params_obj: UserCallParams """工具接受的参数""" - _slot_schema: Dict[str, Any] + slot_schema: dict[str, Any] """参数槽的JSON Schema""" - def __init__(self, params: Dict[str, Any]): - """ - 初始化工具,并对参数进行解析。 - """ + def __init__(self, params: dict[str, Any]): + """初始化工具,并对参数进行解析。""" self._params_obj = UserCallParams(**params) pass - # - async def call(self, slot_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - """ - 工具调用逻辑 + # 工具调用逻辑 + async def call(self, slot_data: Optional[dict[str, Any]] = None) -> dict[str, Any]: + """工具调用逻辑 :param slot_data: 参数槽,由大模型交互式填充 """ - output = {} message = "" # 返回值为dict类型,其中output字段为工具的原始数据(带格式);message字段为工具经LLM处理后的数据(仅字符串);您还可以提供其他数据字段 diff --git a/sample/services/test_service/metadata.yaml b/sample/services/test_service/metadata.yaml new file mode 100644 index 00000000..364a313d --- /dev/null +++ b/sample/services/test_service/metadata.yaml @@ -0,0 +1,18 @@ +# 元数据种类 +type: service + +# 服务的ID +id: test_service +# 服务的名称(展示用) +name: 测试服务 +# 服务的描述(展示用) +description: | + 这是一个测试服务!可以在该文件夹中放置连接服务必需的配置,例如OpenAPI文档等。 +# Service包版本(展示用) +version: "1.0.0" +# 关联的用户账号 +user: zjq + +# API相关设置项 +openapi: + \ No newline at end of file diff --git a/sdk/example_plugin/openapi.yaml b/sample/services/test_service/openapi/api.yaml similarity index 100% rename from sdk/example_plugin/openapi.yaml rename to sample/services/test_service/openapi/api.yaml diff --git a/sdk/example_plugin/lib/__init__.py b/sdk/example_plugin/lib/__init__.py deleted file mode 100644 index c4b58a3c..00000000 --- a/sdk/example_plugin/lib/__init__.py +++ 
/dev/null
@@ -1,8 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-# 这里应当导入所有工具类型
-from .user_tool import UserTool
-
-# 在_exported变量中,加入所有工具类型
-_exported = [
-    UserTool
-]
\ No newline at end of file
diff --git a/sdk/example_plugin/plugin.json b/sdk/example_plugin/plugin.json
deleted file mode 100644
index 31d4307d..00000000
--- a/sdk/example_plugin/plugin.json
+++ /dev/null
@@ -1,11 +0,0 @@
-{
-    "id": "1",
-    "name": "示例插件",
-    "description": "这是示例插件,用于演示编写插件的格式。\n 'type' 可以为 'api' 、 'local',指定了插件的类型。",
-    "auth": {
-        "type": "header",
-        "args": {
-            "Authorization": ""
-        }
-    }
-}
\ No newline at end of file
-- 
Gitee
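Finally, a hedged sketch of how the sample UserTool defined earlier in this patch is meant to be driven; the field names follow UserCallParams, the values are made up, and it assumes the sample/ directory is importable as a package:

    import asyncio

    from sample.calls.test_call import UserTool

    # Parameters normally filled in by the Executor
    params = {
        "background": "",
        "question": "Please add 1 and 1",
        "files": [],
        "history": [],
        "task_id": None,
    }

    tool = UserTool(params)
    result = asyncio.run(tool.call(slot_data={}))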