diff --git a/apps/common/wordscheck.py b/apps/common/wordscheck.py index d6c81c944e4f03faf8101462ca31988f92bb7a8f..cc389e938523a5585fb9caafc77aa46676e1d424 100644 --- a/apps/common/wordscheck.py +++ b/apps/common/wordscheck.py @@ -4,6 +4,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import http import re +from pathlib import Path from typing import Union import requests @@ -41,7 +42,7 @@ class KeywordCheck: def __init__(self) -> None: """初始化关键词列表""" - with open(config["WORDS_LIST"], encoding="utf-8") as f: + with Path(config["WORDS_LIST"]).open("r", encoding="utf-8") as f: self.words_list = f.read().splitlines() def check(self, message: str) -> int: diff --git a/apps/dependency/user.py b/apps/dependency/user.py index ea9e069aa6484fb0268aade23e6c3e43892e76c7..9be898dd6931513d1c517a620e78abf645c89a45 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -35,7 +35,6 @@ async def get_session(request: HTTPConnection) -> str: if not await SessionManager.verify_user(session_id): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") return session_id - # return "test" async def get_user(request: HTTPConnection) -> str: """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence @@ -48,7 +47,6 @@ async def get_user(request: HTTPConnection) -> str: if not user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") return user - # return "test" async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: """验证API Key是否有效;无效则抛出HTTP 401;接口级dependence diff --git a/apps/entities/pool.py b/apps/entities/pool.py index 750c7387ec7b467c73fd45bc94603b04482e70dd..810b6358e33b561c9ebb85785a32d281a0627d69 100644 --- a/apps/entities/pool.py +++ b/apps/entities/pool.py @@ -2,7 +2,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import uuid from datetime import datetime, timezone from typing import Any, Optional @@ -13,7 +12,7 @@ from apps.entities.flow import AppLink, Permission from apps.entities.flow_topology import PositionItem -class PoolBase(BaseModel): +class BaseData(BaseModel): """Pool的基础信息""" id: str = Field(alias="_id") @@ -30,7 +29,7 @@ class ServiceApiInfo(BaseModel): path: str = Field(description="OpenAPI文件路径") -class ServicePool(PoolBase): +class ServicePool(BaseData): """外部服务信息 collection: service @@ -43,37 +42,42 @@ class ServicePool(PoolBase): openapi_spec: dict = Field(description="服务关联的 OpenAPI 文件内容") -class CallPool(PoolBase): +class CallPool(BaseData): """Call信息 collection: call + + 路径的格式: + 1. 系统Call的路径格式样例:“LLM” + 2. Python Call的路径格式样例:“tune::call.tune.CheckSystem” """ - id: str = Field(description="Call的ID", alias="_id") type: CallType = Field(description="Call的类型") path: str = Field(description="Call的路径") -class NodePool(PoolBase): - """Node信息 +class Node(BaseData): + """Node合并后的信息(不存库)""" + + service_id: Optional[str] = Field(description="Node所属的Service ID", default=None) + call_id: str = Field(description="所使用的Call的ID") + params_schema: dict[str, Any] = Field(description="Node的参数schema", default={}) + output_schema: dict[str, Any] = Field(description="Node输出的完整Schema", default={}) + + +class NodePool(BaseData): + """Node合并前的信息(作为附带信息的指针) collection: node - 注: - 1. 基类Call的ID,即meta_call,可以为None,表示该Node是系统Node - 2. 路径的格式: - 1. 系统Node的路径格式样例:“LLM” - 2. Python Node的路径格式样例:“tune::call.tune.CheckSystem” """ - id: str = Field(description="Node的ID", default_factory=lambda: str(uuid.uuid4()), alias="_id") - service_id: str = Field(description="Node所属的Service ID") + service_id: Optional[str] = Field(description="Node所属的Service ID", default=None) call_id: str = Field(description="所使用的Call的ID") - api_path: Optional[str] = Field(description="Call的API路径", default=None) - params_schema: dict[str, Any] = Field(description="Node的参数schema;只包含用户可以改变的参数", default={}) - output_schema: dict[str, Any] = Field(description="Node的输出schema;做输出的展示用", default={}) + path: str = Field(description="Node的路径") + known_params: dict[str, Any] = Field(description="已知的用于Call部分的参数", default={}) -class AppFlow(PoolBase): +class AppFlow(BaseData): """Flow的元数据;会被存储在App下面""" enabled: bool = Field(description="是否启用", default=True) @@ -82,7 +86,7 @@ class AppFlow(PoolBase): description="Flow的视觉焦点", default=PositionItem(x=0, y=0)) -class AppPool(PoolBase): +class AppPool(BaseData): """应用信息 collection: app diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index 8321a5b9e66b09914ca47667065a192560fee4a8..99b96a901965903511501969d4b0b6dd8b09e1d4 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -14,6 +14,7 @@ from apps.entities.task import RequestDataApp class MockRequestData(BaseModel): """POST /api/mock/chat的请求体""" + app_id: str = Field(default="", description="应用ID", alias="appId") flow_id: str = Field(default="", description="流程ID", alias="flowId") conversation_id : str = Field(..., description="会话ID", alias="conversationId") diff --git a/apps/entities/vector.py b/apps/entities/vector.py index 780b2e7f7b4bd09ff0cbae7df042b3f96f638df7..d7f207aed0dde7ee7cc431aa7e15de54555bb21f 100644 --- a/apps/entities/vector.py +++ b/apps/entities/vector.py @@ -9,7 +9,7 @@ from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() -class AppVector(Base): +class AppPoolVector(Base): """App向量信息""" __tablename__ = "app_vector" @@ -17,7 +17,7 @@ class AppVector(Base): embedding = Column(Vector(1024), nullable=False) -class ServiceVector(Base): +class ServicePoolVector(Base): """Service向量信息""" __tablename__ = "service_vector" @@ -25,7 +25,7 @@ class ServiceVector(Base): embedding = Column(Vector(1024), nullable=False) -class NodeVector(Base): +class NodePoolVector(Base): """Node向量信息""" __tablename__ = "node_vector" diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py index a9844f465ce53b9c618104e78edcfbbe21a098e3..31b12e929d3601e6b09e78973a88d71b77bb97f8 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/reasoning.py @@ -55,7 +55,7 @@ class ReasoningLLM(metaclass=Singleton): return messages - async def call(self, task_id: str, messages: list[dict[str, str]], # noqa: C901, PLR0912 + async def call(self, task_id: str, messages: list[dict[str, str]], # noqa: C901 max_tokens: Optional[int] = None, temperature: Optional[float] = None, *, streaming: bool = True, result_only: bool = True) -> AsyncGenerator[str, None]: """调用大模型,分为流式和非流式两种""" diff --git a/apps/manager/flow.py b/apps/manager/flow.py index d2b302a8f48b454ea2db6b1d7010f66f07cd8a1b..105a8d6cca6cac4764c1bff776aeaa9b8f9fa338 100644 --- a/apps/manager/flow.py +++ b/apps/manager/flow.py @@ -9,12 +9,20 @@ from pymongo import ASCENDING from apps.constants import LOGGER from apps.entities.enum_var import PermissionType from apps.entities.flow import Edge, Flow, FlowConfig, Step, StepPos -from apps.entities.flow_topology import EdgeItem, FlowItem, NodeItem, NodeMetaDataItem, NodeServiceItem, PositionItem +from apps.entities.flow_topology import ( + EdgeItem, + FlowItem, + NodeItem, + NodeMetaDataItem, + NodeServiceItem, + PositionItem, +) from apps.entities.pool import AppFlow from apps.models.mongo import MongoDB class FlowManager: + """Flow相关操作""" @staticmethod async def validate_user_node_meta_data_access(user_sub: str, node_meta_data_id: str) -> bool: diff --git a/apps/models/postgres.py b/apps/models/postgres.py index 9ede34c242197466e95e902ae50d0f22eb4064a1..f76fd74ef6ab24721f836ca56fc11929580e4f9f 100644 --- a/apps/models/postgres.py +++ b/apps/models/postgres.py @@ -42,14 +42,29 @@ class PostgreSQL: @staticmethod async def get_embedding(text: list[str]) -> list[float]: - """访问Vectorize的Embedding API,获得向量化数据 + """访问OpenAI兼容的Embedding API,获得向量化数据 :param text: 待向量化文本(多条文本组成List) :return: 文本对应的向量(顺序与text一致,也为List) """ - api = config["VECTORIZE_HOST"].rstrip("/") + "/embedding" + api = config["EMBEDDING_URL"] + + if config["EMBEDDING_KEY"]: + headers = { + "Authorization": f"Bearer {config['EMBEDDING_KEY']}", + } + else: + headers = {} + + headers["Content-Type"] = "application/json" + data = { + "encoding_format": "float", + "model": config["EMBEDDING_MODEL"], + "input": text, + } async with aiohttp.ClientSession() as session, session.post( - api, json={"texts": text}, timeout=30, + api, json=data, headers=headers, timeout=60, ) as response: - return await response.json() + json = await response.json() + return [item["embedding"] for item in json["data"]] diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 19172eeb9c7b58dac8e6bc0e52a192b11faf93a0..46066efb814927b6444b19169048717b754760a2 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -124,11 +124,11 @@ async def get_conversation_list(user_sub: Annotated[str, Depends(get_user)]): # @router.post("", dependencies=[Depends(verify_csrf_token)], response_model=AddConversationRsp) -async def add_conversation( +async def add_conversation( # noqa: ANN201 user_sub: Annotated[str, Depends(get_user)], - appId: Optional[str] = None, # noqa: N803 - debug: Optional[bool] = None, # noqa: N803 -): + appId: Optional[str] = None, + debug: Optional[bool] = None, +): """手动创建新对话""" conversations = await ConversationManager.get_conversation_by_user_sub(user_sub) # 尝试创建新对话 @@ -162,7 +162,7 @@ async def add_conversation( @router.put("", response_model=UpdateConversationRsp, dependencies=[Depends(verify_csrf_token)]) async def update_conversation( # noqa: ANN201 post_body: ModifyConversationData, - conversationId: Annotated[str, Query()], # noqa: N803 + conversationId: Annotated[str, Query()], user_sub: Annotated[str, Depends(get_user)], ): """更新特定Conversation的数据""" diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index 6558f30c0da1c67715297c125f0561807f6680df..bcca063d815e311b0063bb1249dee90d6af29127 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -24,8 +24,9 @@ class APIParams(BaseModel): "get", "post", "put", "delete", "patch", ] = Field(description="API接口的HTTP Method") timeout: int = Field(description="工具超时时间", default=300) - input_data: dict[str, Any] = Field(description="固定数据", default={}) + body_override: dict[str, Any] = Field(description="固定数据", default={}) auth: dict[str, Any] = Field(description="API鉴权信息", default={}) + input_schema: dict[str, Any] = Field(description="API请求体的JSON Schema", default={}) service_id: Optional[str] = Field(description="服务ID") @@ -46,7 +47,6 @@ class API(metaclass=CoreCall, param_cls=APIParams, output_cls=_APIOutput): def init(self, syscall_vars: SysCallVars, **kwargs) -> None: # noqa: ANN003 """初始化API调用工具""" - if len(self.) if kwargs["method"] == "POST": if "requestBody" in self._spec[2]: self.slot_schema, self._data_type = self._check_data_type(self._spec[2]["requestBody"]["content"]) diff --git a/apps/scheduler/call/suggest.py b/apps/scheduler/call/suggest.py index ebfa3fde74b69e87af037ffafc4f0bf248fa87d3..03f9ae7bb3c8ebbb1dcfc3eddbb2bb911155b257 100644 --- a/apps/scheduler/call/suggest.py +++ b/apps/scheduler/call/suggest.py @@ -7,7 +7,10 @@ from typing import Any, Optional from pydantic import BaseModel, Field from apps.entities.scheduler import CallError, SysCallVars -from apps.manager import TaskManager, UserDomainManager +from apps.manager import ( + TaskManager, + UserDomainManager, +) from apps.scheduler.call.core import CoreCall @@ -55,4 +58,17 @@ class Suggestion(metaclass=CoreCall, param_cls=_SuggestInput, output_cls=_Sugges # 获取当前任务 task = await TaskManager.get_task(sys_vars.task_id) + # 获取当前用户的画像 + user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(sys_vars.user_sub, 5) + + current_record = [ + { + "role": "user", + "content": task.record.content.question, + }, + { + "role": "assistant", + "content": task.record.content.answer, + }, + ] diff --git a/apps/scheduler/openapi.py b/apps/scheduler/openapi.py index b5e6657992b52e73de97ebaa041cbad02649c163..187bf424242a6fe45952a471edde4e47f23a27f4 100644 --- a/apps/scheduler/openapi.py +++ b/apps/scheduler/openapi.py @@ -12,11 +12,12 @@ from pydantic import BaseModel, Field class ReducedOpenAPIEndpoint(BaseModel): """精简后的OpenAPI文档中的单个API""" + id: Optional[str] = Field(default=None, description="API的Operation ID") uri: str = Field(..., description="API的URI") method: str = Field(..., description="API的请求方法") name: str = Field(..., description="API的自定义名称") description: str = Field(..., description="API的描述信息") - schema: dict = Field(..., description="API的JSON Schema") + spec: dict = Field(..., description="API的JSON Schema") class ReducedOpenAPISpec(BaseModel): @@ -149,15 +150,16 @@ def reduce_openapi_spec(spec: dict) -> ReducedOpenAPISpec: # 只支持get, post, patch, put, delete API;强制去除ref;提取关键字段 endpoints = [ ReducedOpenAPIEndpoint( + id=docs.get("operationId", None), uri=route, method=operation_name, name=docs.get("summary"), description=docs.get("description"), - schema=reduce_endpoint_docs(dereference_refs(docs, full_schema=spec)), + spec=reduce_endpoint_docs(dereference_refs(docs, full_schema=spec)), ) for route, operation in spec["paths"].items() for operation_name, docs in operation.items() - if operation_name in ["get", "post", "patch", "put", "delete"] + if operation_name in ["get", "post", "patch", "put", "delete"] and (not hasattr(docs, "deprecated") or not docs.deprecated) ] return ReducedOpenAPISpec( diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 617834ecb72a336c70af84e38b6d2274ed6bcff1..e6f6f4c9ef3b2db7718595506c64e3c490a9fc92 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -12,8 +12,13 @@ import apps.scheduler.call as system_call from apps.common.config import config from apps.constants import CALL_DIR, LOGGER from apps.entities.enum_var import CallType -from apps.entities.pool import CallPool +from apps.entities.pool import ( + CallPool, + NodePool, +) +from apps.entities.vector import NodePoolVector from apps.models.mongo import MongoDB +from apps.models.postgres import PostgreSQL from apps.scheduler.pool.util import get_short_hash @@ -41,95 +46,224 @@ class CallLoader: return flag - @staticmethod - async def _load_system_call() -> list[CallPool]: + @classmethod + async def _load_system_call(cls) -> tuple[list[CallPool], list[NodePool]]: """加载系统Call""" - metadata = [] + call_metadata = [] + node_metadata = [] - for call_name in system_call.__all__: - call_cls = getattr(system_call, call_name) - if not CallLoader._check_class(call_cls): + for call_id in system_call.__all__: + call_cls = getattr(system_call, call_id) + if not cls._check_class(call_cls): err = f"类{call_cls.__name__}不符合Call标准要求。" LOGGER.info(msg=err) continue - metadata.append( + call_metadata.append( CallPool( - _id=call_name, + _id=call_id, type=CallType.SYSTEM, name=call_cls.name, description=call_cls.description, - path=call_name, + path=call_id, ), ) - return metadata + node_metadata.append( + NodePool( + _id=call_id, + call_id=call_id, + name=call_cls.name, + description=call_cls.description, + path=call_id, + ), + ) + + return call_metadata, node_metadata @classmethod - async def _load_python_call(cls) -> list[CallPool]: - """加载Python Call""" - call_dir = Path(config["SERVICE_DIR"]) / CALL_DIR - metadata = [] + async def _load_single_call_dir(cls, call_name: str) -> tuple[list[CallPool], list[NodePool]]: + """加载单个Call package""" + call_metadata = [] + node_metadata = [] - # 检查是否存在__init__.py + call_dir = Path(config["SERVICE_DIR"]) / CALL_DIR / call_name if not (call_dir / "__init__.py").exists(): - LOGGER.info(msg=f"目录{call_dir}不存在__init__.py文件。") - (Path(call_dir) / "__init__.py").touch() + LOGGER.info(msg=f"模块{call_dir}不存在__init__.py文件,尝试自动创建。") + try: + (Path(call_dir) / "__init__.py").touch() + except Exception as e: + err = f"自动创建模块文件{call_dir}/__init__.py失败:{e}。" + raise RuntimeError(err) from e + + # 载入子包 + try: + call_package = importlib.import_module("call." + call_name) + except Exception as e: + err = f"载入模块call.{call_name}失败:{e}。" + raise RuntimeError(err) from e + + # 已载入包,处理包中每个工具 + if not hasattr(call_package, "__all__"): + err = f"包call.{call_name}不符合模块要求,无法处理。" + LOGGER.info(msg=err) + raise ValueError(err) + + for call_id in call_package.__all__: + try: + call_cls = getattr(call_package, call_id) + except Exception as e: + err = f"载入工具{call_name}.{call_id}失败:{e};跳过载入。" + LOGGER.info(msg=err) + continue + + if not cls._check_class(call_cls): + err = f"工具{call_name}.{call_id}不符合标准要求;跳过载入。" + LOGGER.info(msg=err) + continue + + cls_path = f"{call_package.service}::call.{call_name}.{call_id}" + cls_hash = get_short_hash(cls_path.encode()) + call_metadata.append( + CallPool( + _id=cls_hash, + type=CallType.PYTHON, + name=call_cls.name, + description=call_cls.description, + path=cls_path, + ), + ) + node_metadata.append( + NodePool( + _id=cls_hash, + call_id=cls_hash, + name=call_cls.name, + description=call_cls.description, + path=cls_path, + ), + ) + + return call_metadata, node_metadata - # 载入整个包 + @classmethod + async def _load_all_user_call(cls) -> tuple[list[CallPool], list[NodePool]]: + """加载Python Call""" + call_dir = Path(config["SERVICE_DIR"]) / CALL_DIR + call_metadata = [] + node_metadata = [] + + # 载入父包 try: sys.path.insert(0, str(call_dir)) + if not (call_dir / "__init__.py").exists(): + LOGGER.info(msg=f"父模块{call_dir}不存在__init__.py文件,尝试自动创建。") + (Path(call_dir) / "__init__.py").touch() importlib.import_module("call") except Exception as e: - err = f"载入包{call_dir}失败:{e}" + err = f'父模块"call"创建失败:{e};无法载入。' raise RuntimeError(err) from e # 处理每一个子包 for call_file in Path(call_dir).rglob("*"): if not call_file.is_dir(): continue - # 载入包 try: - call_package = importlib.import_module("call." + call_file.name) - if not CallLoader._check_class(call_package.service): - LOGGER.info(msg=f"包call.{call_file.name}不符合Call标准要求,跳过载入。") - continue - - for call_id in call_package.__all__: - call_cls = getattr(call_package, call_id) - if not CallLoader._check_class(call_cls): - LOGGER.info(msg=f"类{call_cls.__name__}不符合Call标准要求,跳过载入。") - continue - - cls_path = f"{call_package.service}::call.{call_file.name}.{call_id}" - metadata.append( - CallPool( - _id=get_short_hash(cls_path.encode()), - type=CallType.PYTHON, - name=call_cls.name, - description=call_cls.description, - path=cls_path, - ), - ) + call_metadata, node_metadata = await CallLoader._load_single_call_dir(call_file.name) + call_metadata.extend(call_metadata) + node_metadata.extend(node_metadata) + except Exception as e: - err = f"载入包{call_file}失败:{e},跳过载入" + err = f"载入模块{call_file}失败:{e},跳过载入。" LOGGER.info(msg=err) continue - return metadata + return call_metadata, node_metadata + + + # TODO: 动态卸载 + + + # 更新数据库 + @staticmethod + async def _update_db(call_metadata: list[CallPool], node_metadata: list[NodePool]) -> None: + """更新数据库;call和node下标一致""" + # 更新MongoDB + call_collection = MongoDB.get_collection("call") + node_collection = MongoDB.get_collection("node") + try: + for call, node in zip(call_metadata, node_metadata): + await call_collection.update_one({"_id": call.id}, {"$set": call.model_dump(exclude_none=True, by_alias=True)}, upsert=True) + await node_collection.update_one({"_id": node.id}, {"$set": node.model_dump(exclude_none=True, by_alias=True)}, upsert=True) + except Exception as e: + err = f"更新MongoDB失败:{e}" + LOGGER.error(msg=err) + raise RuntimeError(err) from e + + # 进行向量化,更新PostgreSQL + node_descriptions = [] + for node in node_metadata: + node_descriptions += [node.description] + + session = await PostgreSQL.get_session() + node_vecs = await PostgreSQL.get_embedding(node_descriptions) + for i, data in enumerate(node_vecs): + node_vec = NodePoolVector( + _id=node_metadata[i].id, + embedding=data, + ) + session.add(node_vec) + await session.commit() @staticmethod - async def load_one() -> None: - """加载Call""" - call_metadata = await CallLoader._load_system_call() - call_metadata.extend(await CallLoader._load_python_call()) + async def init() -> None: + """初始化Call信息""" + # 清空collection + call_collection = MongoDB.get_collection("call") + node_collection = MongoDB.get_collection("node") + try: + await call_collection.delete_many({}) + await node_collection.delete_many({}) + except Exception as e: + LOGGER.error(msg=f"Call和Node的collection清空失败:{e}") + + # 载入所有已知的Call信息 + try: + sys_call_metadata, sys_node_metadata = await CallLoader._load_system_call() + except Exception as e: + err = f"载入系统Call信息失败:{e};停止运行。" + LOGGER.error(msg=err) + raise RuntimeError(err) from e + + user_call_metadata, user_node_metadata = await CallLoader._load_all_user_call() + + # 合并Call元数据 + call_metadata = sys_call_metadata + user_call_metadata + node_metadata = sys_node_metadata + user_node_metadata + + # 更新数据库 + await CallLoader._update_db(call_metadata, node_metadata) + + + @staticmethod + async def load_one(call_name: str) -> None: + """加载单个Call""" + try: + call_metadata, node_metadata = await CallLoader._load_single_call_dir(call_name) + except Exception as e: + err = f"载入Call信息失败:{e}。" + LOGGER.error(msg=err) + raise RuntimeError(err) from e + + # 有数据时更新数据库 + if call_metadata: + await CallLoader._update_db(call_metadata, node_metadata) @staticmethod async def get() -> list[CallPool]: - """获取当前已知的所有Call元数据""" + """获取当前已知的所有Python Call元数据""" call_collection = MongoDB.get_collection("call") result: list[CallPool] = [] try: diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py index 46a6c36d2cffe6be0a41bb1d2382250f1a29b708..a5ff7d0fbb17c05821a550ae284ce7a6326764d4 100644 --- a/apps/scheduler/pool/loader/openapi.py +++ b/apps/scheduler/pool/loader/openapi.py @@ -2,12 +2,14 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import uuid + import yaml from anyio import Path from apps.constants import LOGGER +from apps.entities.flow import ServiceMetadata from apps.entities.pool import NodePool -from apps.scheduler.call import API from apps.scheduler.openapi import ( ReducedOpenAPISpec, reduce_openapi_spec, @@ -17,12 +19,8 @@ from apps.scheduler.openapi import ( class OpenAPILoader: """OpenAPI文档载入器""" - # 工具的参数类 - api_param_cls = API.params - - @classmethod - async def load(cls, yaml_path: Path) -> ReducedOpenAPISpec: + async def _read_yaml(cls, yaml_path: Path) -> ReducedOpenAPISpec: """从本地磁盘加载OpenAPI文档""" if not yaml_path.exists(): msg = f"File not found: {yaml_path}" @@ -36,33 +34,44 @@ class OpenAPILoader: @classmethod - def _process_spec(cls, service_id: str, spec: ReducedOpenAPISpec) -> list[NodePool]: + def _process_spec(cls, service_id: str, spec: ReducedOpenAPISpec, service_metadata: ServiceMetadata) -> list[NodePool]: """将OpenAPI文档拆解为Node""" nodes = [] for api_endpoint in spec.endpoints: - # 组装新的Node + # 判断用户是否手动设置了ID + node_id = api_endpoint.id if api_endpoint.id else str(uuid.uuid4()) + + # 组装新的NodePool item node = NodePool( + _id=node_id, name=api_endpoint.name, # 此处固定Call的ID是“API” call_id="API", description=api_endpoint.description, service_id=service_id, + path="", ) - # 提取固定参数 + # 合并参数 + node.known_params = { + "method": api_endpoint.method, + "full_url": service_metadata.api.server + api_endpoint.uri, + } nodes.append(node) return nodes @classmethod - async def load_one(cls, yaml_path: Path) -> list[NodePool]: + async def load_one(cls, yaml_folder: Path, service_metadata: ServiceMetadata) -> list[NodePool]: """加载单个OpenAPI文档,可以直接指定路径""" - try: - spec = await cls.load(yaml_path) - except Exception as e: - err = f"加载OpenAPI文档失败:{e}" - LOGGER.error(msg=err) - raise RuntimeError(err) from e - - return cls._process_spec(yaml_path.name, spec) + async for yaml_path in yaml_folder.rglob("*.yaml"): + try: + spec = await cls._read_yaml(yaml_path) + except Exception as e: + err = f"加载OpenAPI文档{yaml_path}失败:{e}" + LOGGER.error(msg=err) + continue + + service_id = yaml_folder.parent.name + return cls._process_spec(service_id, spec, service_metadata) diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 77a09c4bb79c363807bc5b828749bd9c54ba705c..63ef4a4f3f1d94a8e62342cf6c38d98ae2602f8e 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -7,7 +7,10 @@ from typing import Any from anyio import Path from apps.common.config import config -from apps.entities.vector import NodeVector, ServiceVector +from apps.constants import LOGGER +from apps.entities.flow import ServiceMetadata +from apps.entities.pool import NodePool +from apps.entities.vector import NodePoolVector, ServicePoolVector from apps.models.mongo import MongoDB from apps.models.postgres import PostgreSQL from apps.scheduler.pool.loader.metadata import MetadataLoader @@ -21,17 +24,27 @@ class ServiceLoader: @classmethod - async def load(cls, service_dir: Path) -> None: + async def load_one(cls, service_dir: Path) -> None: """加载单个Service""" service_path = Path(config["SERVICE_DIR"]) / "service" / service_dir # 载入元数据 metadata = await MetadataLoader.load(service_path / "metadata.yaml") + if not isinstance(metadata, ServiceMetadata): + err = f"元数据类型错误: {service_path / 'metadata.yaml'}" + LOGGER.error(err) + raise TypeError(err) + # 载入OpenAPI文档,获取Node列表 - nodes = await OpenAPILoader.load_one(service_path / "openapi.yaml") + nodes = await OpenAPILoader.load_one(service_path / "openapi", metadata) + + + @classmethod + async def _update_db(cls, nodes: list[NodePool], metadata: ServiceMetadata) -> None: + """更新数据库""" # 向量化所有数据 session = await PostgreSQL.get_session() - service_vec = ServiceVector( + service_vec = ServicePoolVector( _id=metadata.id, embedding=PostgreSQL.get_embedding([metadata.description]), ) @@ -43,7 +56,7 @@ class ServiceLoader: node_vecs = await PostgreSQL.get_embedding(node_descriptions) for i, data in enumerate(node_vecs): - node_vec = NodeVector( + node_vec = NodePoolVector( _id=nodes[i].id, embedding=data, ) @@ -52,10 +65,6 @@ class ServiceLoader: await session.commit() - - - - @staticmethod async def save(cls) -> dict[str, Any]: """加载所有Service""" @@ -63,6 +72,6 @@ class ServiceLoader: @staticmethod - async def load_all(cls) -> dict[str, Any]: - """执行Service的加载""" + async def init() -> None: + """在初始化时加载所有Service""" pass diff --git a/mock/make_data.py b/mock/make_data.py index bfe9549e1fafe96ea57740376f5f2694fbfd88f8..bebbb4b75fe68f564e3ced84eb5ed2be8806f24c 100644 --- a/mock/make_data.py +++ b/mock/make_data.py @@ -117,7 +117,7 @@ async def insert_service_pool(): print(f"An error occurred while inserting the document: {e}") -class NodePool(PoolBase): +class Node(PoolBase): """Node信息 collection: node @@ -140,7 +140,7 @@ async def insert_node_pool() -> None: collection = MongoDB.get_collection("node") result = collection.delete_many({}) # 清空集合中的所有文档(仅用于演示) node_pools = [ - NodePool( + Node( _id=str(uuid.uuid4()), # 自动生成一个唯一的 ID service_id="6a08c845-abdc-45fb-853e-54a806437dab", # 使用 "test" 作为 service_id call_id="knowledge_base", # 随机生成一个 call_id @@ -155,7 +155,7 @@ async def insert_node_pool() -> None: }, output_schema={"content": {"type": "string", "description": "回答"}}, ), - NodePool( + Node( _id=str(uuid.uuid4()), # 自动生成一个唯一的 ID service_id="6a08c845-abdc-45fb-853e-54a806437dab", # 使用 "test" 作为 service_id call_id="LLM", # 随机生成一个 call_id @@ -171,7 +171,7 @@ async def insert_node_pool() -> None: }, output_schema={"content": {}}, ), - NodePool( + Node( _id=str(uuid.uuid4()), # 自动生成一个唯一的 ID service_id="6a08c845-abdc-45fb-853e-54a806437dab", # 使用 "test" 作为 service_id call_id="choice", # 随机生成一个 call_id @@ -201,7 +201,7 @@ async def insert_node_pool() -> None: }, }, ), - NodePool( + Node( _id=str(uuid.uuid4()), # 自动生成一个唯一的 ID service_id="6a08c845-abdc-45fb-853e-54a806437dab", # 使用 "test" 作为 service_id call_id="choice", # 随机生成一个 call_id @@ -224,7 +224,7 @@ async def insert_node_pool() -> None: }, output_schema={}, ), - NodePool( + Node( _id=str(uuid.uuid4()), # 自动生成一个唯一的 ID service_id="6a08c845-abdc-45fb-853e-54a806437dab", # 使用 "test" 作为 service_id call_id="loop_begin", # 随机生成一个 call_id @@ -233,7 +233,7 @@ async def insert_node_pool() -> None: params_schema={"operation_exp": {}}, output_schema={}, ), - NodePool( + Node( _id=str(uuid.uuid4()), # 自动生成一个唯一的 ID service_id="6a08c845-abdc-45fb-853e-54a806437dab", # 使用 "test" 作为 service_id call_id="loop_begin", # 随机生成一个 call_id @@ -242,7 +242,7 @@ async def insert_node_pool() -> None: params_schema={"operation_exp": {}}, output_schema={}, ), - NodePool( + Node( _id=str(uuid.uuid4()), # 自动生成一个唯一的 ID service_id="6a08c845-abdc-45fb-853e-54a806437dab", # 使用 "test" 作为 service_id call_id="template_exchange", # 随机生成一个 call_id @@ -269,7 +269,7 @@ async def insert_node_pool() -> None: }, }, ), - NodePool( + Node( _id="343da7db-5da8-42ef-9b59-cc56df54d9aa", service_id="1137ab09-20ae-4278-8346-524d4ce81d2f", call_id="api", @@ -332,7 +332,7 @@ async def insert_node_pool() -> None: }, }, ), - NodePool( + Node( _id="8841e328-da5b-45c7-8839-5b8054a92de7", service_id="1137ab09-20ae-4278-8346-524d4ce81d2f", call_id="choice", @@ -359,7 +359,7 @@ async def insert_node_pool() -> None: "properties": {}, }, ), - NodePool( + Node( _id="7377ad0d-f867-46fe-806a-d0c4535d2f1a", service_id="1137ab09-20ae-4278-8346-524d4ce81d2f", call_id="api", @@ -426,7 +426,7 @@ async def insert_node_pool() -> None: }, }, ), - NodePool( + Node( _id="3d94b288-a0df-4717-b75c-fc2c67e24294", service_id="1137ab09-20ae-4278-8346-524d4ce81d2f", call_id="api", @@ -501,7 +501,7 @@ async def insert_node_pool() -> None: "required": ["status_code", "data"], }, ), - NodePool( + Node( _id="1a8ddfb9-c894-4819-ab9b-88fcb5f14c10", service_id="1137ab09-20ae-4278-8346-524d4ce81d2f", call_id="llm",