diff --git a/.gitignore b/.gitignore
index c62e334e6e0f521a7f889c54a758a5956a99c0ea..4a6c7f8f9ed214b33348170a24679fe8061102df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,7 @@ logs
 .ruff_cache/
 config
 uv.lock
+update.sh
+auto-update.sh
+
+*.DS_Store
\ No newline at end of file
diff --git a/CONVERSATION_VARIABLE_FIX.md b/CONVERSATION_VARIABLE_FIX.md
new file mode 100644
index 0000000000000000000000000000000000000000..2a84ca946bef291cdf8c2e376f3d2a9a00b26684
--- /dev/null
+++ b/CONVERSATION_VARIABLE_FIX.md
@@ -0,0 +1,179 @@
+# Conversation Variable Template Fix Summary
+
+## 🔍 **Problem Description**
+
+A user reported that after the conversation-level variable `test` was created successfully, it could not be retrieved through the API:
+```
+GET /api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+```
+
+## 🔧 **Root Cause Analysis**
+
+Analysis showed the problem came from an omission in my earlier refactoring of the variable architecture:
+
+### 1. **The create API works correctly**
+- ✅ Conversation variables are correctly stored in FlowVariablePool's `_conversation_templates` dict
+- ✅ Database persistence succeeds
+
+### 2. **The query API is flawed**
+- ❌ The `list_variables_from_any_pool` method in `pool_manager.py` only handled the `conversation_id` parameter
+- ❌ It did not handle queries of the form `scope=conversation&flow_id=xxx`
+- ❌ The `get_variable_from_any_pool` method had the same problem
+
+### 3. **The update and delete APIs are broken**
+- ❌ FlowVariablePool's `update_variable` and `delete_variable` methods only searched the `_variables` dict
+- ❌ They could not find conversation variable templates stored in the `_conversation_templates` dict
+
+## 🛠️ **Fix Plan**
+
+### 1. 
**Fix the query logic**
+
+#### The `list_variables_from_any_pool` method
+**Before**:
+```python
+elif scope == VariableScope.CONVERSATION and conversation_id:
+    pool = await self.get_conversation_pool(conversation_id)
+    if pool:
+        return await pool.list_variables(include_system=False)
+    return []
+```
+
+**After**:
+```python
+elif scope == VariableScope.CONVERSATION:
+    if conversation_id:
+        # Query conversation variable instances by conversation_id
+        pool = await self.get_conversation_pool(conversation_id)
+        if pool:
+            return await pool.list_variables(include_system=False)
+    elif flow_id:
+        # Query conversation variable templates by flow_id
+        flow_pool = await self.get_flow_pool(flow_id)
+        if flow_pool:
+            return await flow_pool.list_conversation_templates()
+    return []
+```
+
+#### The `get_variable_from_any_pool` method
+A similar fix, adding support for querying conversation variable templates by `flow_id`.
+
+### 2. **Fix the create logic**
+
+#### Change the `create_variable` route
+**Before**:
+```python
+# Create the variable
+variable = await pool.add_variable(...)
+```
+
+**After**:
+```python
+# Create a different variable kind depending on scope
+if request.scope == VariableScope.CONVERSATION:
+    # Create a conversation variable template
+    variable = await pool.add_conversation_template(...)
+else:
+    # Create other kinds of variables
+    variable = await pool.add_variable(...)
+```
+
+### 3. **Extend FlowVariablePool**
+
+Overriding methods were added to FlowVariablePool so it can operate across multiple dicts:
+
+#### The `update_variable` method
+- Searches environment variables, system variable templates, and conversation variable templates, in that order
+- Applies the update once the variable is found
+- Persists correctly to the database
+
+#### The `delete_variable` method
+- Supports deleting variables stored in the different dicts
+- Keeps the permission check (system variable templates cannot be deleted)
+
+#### The `get_variable` method
+- A unified variable lookup interface
+- Supports lookup across dicts
+
+## ✅ **Fix Verification**
+
+### The complete workflow now supported:
+
+#### 1. 
**Flow-level operations** (variable template management)
+```bash
+# Create a conversation variable template
+POST /api/variable/create
+{
+  "name": "test",
+  "var_type": "string",
+  "scope": "conversation",
+  "value": "123",
+  "description": "321",
+  "flow_id": "52e069c7-5556-42af-bdfc-63f4dc2dcd28"
+}
+
+# Query conversation variable templates
+GET /api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+
+# Update a conversation variable template
+PUT /api/variable/update?name=test&scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+
+# Delete a conversation variable template
+DELETE /api/variable/delete?name=test&scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+```
+
+#### 2. **Conversation-level operations** (variable instance management)
+```bash
+# Query conversation variable instances
+GET /api/variable/list?scope=conversation&conversation_id=conv123
+
+# Update a conversation variable instance's value
+PUT /api/variable/update?name=test&scope=conversation&conversation_id=conv123
+```
+
+## 🎯 **Testing Suggestions**
+
+### Test immediately
+The originally failing API call can now be retried:
+```bash
+curl "http://10.211.55.10:8002/api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28"
+```
+
+### Full test flow
+1. **Create a conversation variable template** (already verified from the frontend)
+2. **Query conversation variable templates** (should now be found)
+3. **Update a conversation variable template**
+4. **Delete a conversation variable template**
+
+### Automated tests
+Run the test script to verify:
+```bash
+cd euler-copilot-framework
+python test_conversation_variables.py
+```
+
+## 📊 **Architecture Completeness Verification**
+
+Queries for every variable kind should now work:
+
+### Flow-level queries
+- ✅ System variable templates: `GET /api/variable/list?scope=system&flow_id=xxx`
+- ✅ Conversation variable templates: `GET /api/variable/list?scope=conversation&flow_id=xxx`
+- ✅ Environment variables: `GET /api/variable/list?scope=environment&flow_id=xxx`
+
+### Conversation-level queries
+- ✅ System variable instances: `GET /api/variable/list?scope=system&conversation_id=xxx`
+- ✅ Conversation variable instances: `GET /api/variable/list?scope=conversation&conversation_id=xxx`
+
+### User-level queries
+- ✅ User variables: `GET /api/variable/list?scope=user`
+
+## 🎉 **Expected Results**
+
+After the fix, your frontend should be able to:
+
+1. **Create conversation variable templates successfully** (verified)
+2. **Query conversation variable templates successfully** (the core problem fixed here)
+3. **Update conversation variable templates successfully**
+4. 
**Delete conversation variable templates successfully**
+
+All operations happen at the Flow level, which matches your design: **the Flow level manages template definitions, while the Conversation level operates on the actual data**.
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 7180b6a9e4fc7316a27efb46a431676dbbb31780..a52aed1db099236b8a7612961892fa3db80264ed 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,10 @@
-FROM hub.oepkgs.net/neocopilot/framework_base:0.9.6-x86-test
+FROM hub.oepkgs.net/neocopilot/framework_base:0.9.6-arm
 
 ENV PYTHONPATH=/app
 ENV TIKTOKEN_CACHE_DIR=/app/assets/tiktoken
 
-COPY --chmod=550 ./ /app/
+COPY ./ /app/
+RUN chmod -R 550 /app/
 RUN chmod 766 /root
 
-CMD ["uv", "run", "--no-sync", "--no-dev", "apps/main.py"]
+CMD ["uv", "run", "--no-dev", "apps/main.py"]
diff --git a/apps/common/redis_cache.py b/apps/common/redis_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..019d0be917a130360ac5a6e2e39459af7ad5b311
--- /dev/null
+++ b/apps/common/redis_cache.py
@@ -0,0 +1,271 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Redis缓存模块 - 用于前置节点变量预解析缓存"""
+
+import json
+import logging
+import asyncio
+from typing import List, Dict, Any, Optional
+from datetime import datetime, UTC
+
+import redis.asyncio as redis
+from apps.common.singleton import SingletonMeta
+
+logger = logging.getLogger(__name__)
+
+
+class RedisCache(metaclass=SingletonMeta):
+    """Redis缓存管理器"""
+
+    def __init__(self):
+        self._redis: Optional[redis.Redis] = None
+        self._connected = False
+
+    async def init(self, redis_config=None, redis_url: Optional[str] = None):
+        """初始化Redis连接
+
+        Args:
+            redis_config: Redis配置对象(优先级更高)
+            redis_url: Redis连接URL(降级选项)
+        """
+        try:
+            if redis_config:
+                # 使用配置对象构建连接,添加连接池和超时参数
+                self._redis = redis.Redis(
+                    host=redis_config.host,
+                    port=redis_config.port,
+                    password=redis_config.password if redis_config.password else None,
+                    db=redis_config.database,
+                    decode_responses=redis_config.decode_responses,
+                    # 连接池配置
+                    max_connections=redis_config.max_connections,
+                    # 超时配置
+                    socket_timeout=redis_config.socket_timeout,
+
socket_connect_timeout=redis_config.socket_connect_timeout, + socket_keepalive=True, + socket_keepalive_options={}, + # 连接重试 + retry_on_timeout=True, + retry_on_error=[ConnectionError, TimeoutError], + health_check_interval=redis_config.health_check_interval + ) + logger.info(f"使用配置连接Redis: {redis_config.host}:{redis_config.port}, 数据库: {redis_config.database}") + elif redis_url: + # 降级使用URL连接 + self._redis = redis.from_url( + redis_url, + decode_responses=True, + socket_timeout=5.0, + socket_connect_timeout=5.0, + max_connections=10 + ) + logger.info(f"使用URL连接Redis: {redis_url}") + else: + raise ValueError("必须提供redis_config或redis_url参数") + + # 测试连接 + logger.info("正在测试Redis连接...") + ping_result = await self._redis.ping() + logger.info(f"Redis ping结果: {ping_result}") + + # 测试基本操作 + test_key = "__redis_test__" + await self._redis.set(test_key, "test", ex=10) + test_value = await self._redis.get(test_key) + await self._redis.delete(test_key) + logger.info(f"Redis读写测试成功: {test_value}") + + self._connected = True + logger.info("Redis连接初始化成功") + except ConnectionError as e: + logger.error(f"Redis连接错误: {e}") + self._connected = False + except TimeoutError as e: + logger.error(f"Redis连接超时: {e}") + self._connected = False + except Exception as e: + logger.error(f"Redis连接初始化失败: {e}") + logger.error(f"错误类型: {type(e).__name__}") + self._connected = False + + def is_connected(self) -> bool: + """检查Redis连接状态""" + return self._connected and self._redis is not None + + async def close(self): + """关闭Redis连接""" + if self._redis: + await self._redis.close() + self._connected = False + + +class PredecessorVariableCache: + """前置节点变量预解析缓存管理器""" + + def __init__(self, redis_cache: RedisCache): + self.redis = redis_cache + self.CACHE_PREFIX = "predecessor_vars" + self.PARSING_STATUS_PREFIX = "parsing_status" + self.CACHE_TTL = 3600 * 24 # 缓存24小时 + + def _get_cache_key(self, flow_id: str, step_id: str) -> str: + """生成缓存key""" + return f"{self.CACHE_PREFIX}:{flow_id}:{step_id}" + + def 
_get_status_key(self, flow_id: str, step_id: str) -> str: + """生成解析状态key""" + return f"{self.PARSING_STATUS_PREFIX}:{flow_id}:{step_id}" + + def _get_flow_hash_key(self, flow_id: str) -> str: + """生成Flow哈希key,用于存储Flow的拓扑结构哈希值""" + return f"flow_hash:{flow_id}" + + async def get_cached_variables(self, flow_id: str, step_id: str) -> Optional[List[Dict[str, Any]]]: + """获取缓存的前置节点变量""" + if not self.redis.is_connected(): + return None + + try: + cache_key = self._get_cache_key(flow_id, step_id) + cached_data = await self.redis._redis.get(cache_key) + + if cached_data: + data = json.loads(cached_data) + logger.info(f"从缓存获取前置节点变量: {flow_id}:{step_id}, 数量: {len(data.get('variables', []))}") + return data.get('variables', []) + + except Exception as e: + logger.error(f"获取缓存的前置节点变量失败: {e}") + + return None + + async def set_cached_variables(self, flow_id: str, step_id: str, variables: List[Dict[str, Any]], flow_hash: str): + """设置缓存的前置节点变量""" + if not self.redis.is_connected(): + return False + + try: + cache_key = self._get_cache_key(flow_id, step_id) + cache_data = { + 'variables': variables, + 'flow_hash': flow_hash, + 'cached_at': datetime.now(UTC).isoformat(), + 'step_count': len(variables) + } + + await self.redis._redis.setex( + cache_key, + self.CACHE_TTL, + json.dumps(cache_data, default=str) + ) + + logger.info(f"缓存前置节点变量成功: {flow_id}:{step_id}, 数量: {len(variables)}") + return True + + except Exception as e: + logger.error(f"缓存前置节点变量失败: {e}") + return False + + async def is_parsing_in_progress(self, flow_id: str, step_id: str) -> bool: + """检查是否正在解析中""" + if not self.redis.is_connected(): + return False + + try: + status_key = self._get_status_key(flow_id, step_id) + status = await self.redis._redis.get(status_key) + return status == "parsing" + except Exception as e: + logger.error(f"检查解析状态失败: {e}") + return False + + async def set_parsing_status(self, flow_id: str, step_id: str, status: str, ttl: int = 300): + """设置解析状态 (parsing, completed, failed)""" + if not 
self.redis.is_connected(): + return False + + try: + status_key = self._get_status_key(flow_id, step_id) + await self.redis._redis.setex(status_key, ttl, status) + return True + except Exception as e: + logger.error(f"设置解析状态失败: {e}") + return False + + async def wait_for_parsing_completion(self, flow_id: str, step_id: str, max_wait_time: int = 30) -> bool: + """等待解析完成""" + if not self.redis.is_connected(): + return False + + start_time = datetime.now(UTC) + + while (datetime.now(UTC) - start_time).total_seconds() < max_wait_time: + if not await self.is_parsing_in_progress(flow_id, step_id): + # 检查是否有缓存结果 + cached_vars = await self.get_cached_variables(flow_id, step_id) + return cached_vars is not None + + await asyncio.sleep(0.5) # 等待500ms后重试 + + logger.warning(f"等待解析完成超时: {flow_id}:{step_id}") + return False + + async def invalidate_flow_cache(self, flow_id: str): + """使某个Flow的所有缓存失效""" + if not self.redis.is_connected(): + return + + try: + # 查找所有相关的缓存key + pattern = f"{self.CACHE_PREFIX}:{flow_id}:*" + keys = await self.redis._redis.keys(pattern) + + # 同时删除解析状态key + status_pattern = f"{self.PARSING_STATUS_PREFIX}:{flow_id}:*" + status_keys = await self.redis._redis.keys(status_pattern) + + all_keys = keys + status_keys + + if all_keys: + await self.redis._redis.delete(*all_keys) + logger.info(f"清除Flow缓存: {flow_id}, 删除key数量: {len(all_keys)}") + + except Exception as e: + logger.error(f"清除Flow缓存失败: {e}") + + async def get_flow_hash(self, flow_id: str) -> Optional[str]: + """获取Flow的拓扑结构哈希值""" + if not self.redis.is_connected(): + return None + + try: + hash_key = self._get_flow_hash_key(flow_id) + return await self.redis._redis.get(hash_key) + except Exception as e: + logger.error(f"获取Flow哈希失败: {e}") + return None + + async def set_flow_hash(self, flow_id: str, flow_hash: str): + """设置Flow的拓扑结构哈希值""" + if not self.redis.is_connected(): + return False + + try: + # 检查事件循环是否仍然活跃 + import asyncio + try: + asyncio.get_running_loop() + except RuntimeError: + 
logger.warning(f"事件循环已关闭,跳过设置Flow哈希: {flow_id}") + return False + + hash_key = self._get_flow_hash_key(flow_id) + await self.redis._redis.setex(hash_key, self.CACHE_TTL, flow_hash) + return True + except Exception as e: + logger.error(f"设置Flow哈希失败: {e}") + return False + + +# 全局实例 +redis_cache = RedisCache() +predecessor_cache = PredecessorVariableCache(redis_cache) \ No newline at end of file diff --git a/apps/dependency/user.py b/apps/dependency/user.py index fce67e51e00dd76b35bec2f7787dfe8039111589..87cbd290283507aa911e2c0d188bea2099bd150f 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -5,10 +5,12 @@ import logging from fastapi import Depends from fastapi.security import OAuth2PasswordBearer +import secrets from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config from apps.services.api_key import ApiKeyManager from apps.services.session import SessionManager @@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return secrets.token_hex(16) session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( @@ -69,6 +74,9 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return Config().get_config().no_auth.user_sub session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( diff --git a/apps/llm/function.py b/apps/llm/function.py index 1f995fe7ba187cead03aa6fc62a4cbce1ec05a65..74cc9a1723395e77d799e38a93e161328bd1e596 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -119,7 +119,7 @@ class FunctionLLM: "name": "generate", "description": "Generate answer based on the background information", "parameters": schema, 
-                },
+                }
             },
         ]
diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py
index fdb36fc05adf38920bcce0d962b6aafc21e44b71..453267f833b2744a65bf3ecea5f00727abb11c35 100644
--- a/apps/llm/reasoning.py
+++ b/apps/llm/reasoning.py
@@ -134,6 +134,7 @@ class ReasoningLLM:
         max_tokens: int | None,
         temperature: float | None,
         model: str | None = None,
+        frequency_penalty: float | None = None,
     ) -> AsyncGenerator[ChatCompletionChunk, None]:
         """创建流式响应"""
         if model is None:
@@ -143,6 +144,7 @@ class ReasoningLLM:
             messages=messages,  # type: ignore[]
             max_tokens=max_tokens or self._config.max_tokens,
             temperature=temperature or self._config.temperature,
+            frequency_penalty=frequency_penalty if frequency_penalty is not None else self._config.frequency_penalty,
             stream=True,
             stream_options={"include_usage": True},
         )  # type: ignore[]
@@ -156,6 +158,7 @@ class ReasoningLLM:
         streaming: bool = True,
         result_only: bool = True,
         model: str | None = None,
+        frequency_penalty: float | None = None,
     ) -> AsyncGenerator[str, None]:
         """调用大模型,分为流式和非流式两种"""
         # 检查max_tokens和temperature
@@ -166,7 +169,7 @@ class ReasoningLLM:
         if model is None:
             model = self._config.model
         msg_list = self._validate_messages(messages)
-        stream = await self._create_stream(msg_list, max_tokens, temperature, model)
+        stream = await self._create_stream(msg_list, max_tokens, temperature, model, frequency_penalty)
         reasoning = ReasoningContent()
         reasoning_content = ""
         result = ""
@@ -202,7 +205,7 @@ class ReasoningLLM:
         yield reasoning_content
         yield result
-        logger.info("[Reasoning] 推理内容: %s\n\n%s", reasoning_content, result)
+        logger.info("[Reasoning] 推理内容: %r\n\n%s", reasoning_content, result)
         # 更新token统计
         if self.input_tokens == 0 or self.output_tokens == 0:
diff --git a/apps/main.py b/apps/main.py
index c4ca2bfb116db11624f4e4fbbe4e876a86e5f1eb..50447ea9902ee4225cb5c78bf9ccb7aa89c4f6c4 100644
--- a/apps/main.py
+++ b/apps/main.py
@@ -8,6 +8,10 @@ from __future__ import annotations
 import asyncio
 import logging
+import logging.config
+import signal
+import sys
+from 
contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI @@ -36,11 +40,62 @@ from apps.routers import ( record, service, user, + parameter, + variable ) from apps.scheduler.pool.pool import Pool +from apps.services.predecessor_cache_service import cleanup_background_tasks + +# 全局变量用于跟踪后台任务 +_cleanup_task = None +async def cleanup_on_shutdown(): + """应用关闭时的清理函数""" + logger = logging.getLogger(__name__) + logger.info("开始清理应用资源...") + + try: + # 取消定期清理任务 + global _cleanup_task + if _cleanup_task and not _cleanup_task.done(): + _cleanup_task.cancel() + try: + await _cleanup_task + except asyncio.CancelledError: + logger.info("定期清理任务已取消") + + # 清理后台任务 + await cleanup_background_tasks() + + # 关闭Redis连接 + from apps.common.redis_cache import RedisCache + redis_cache = RedisCache() + if redis_cache.is_connected(): + await redis_cache.close() + logger.info("Redis连接已关闭") + + except Exception as e: + logger.error(f"清理应用资源时出错: {e}") + + logger.info("应用资源清理完成") + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + # 启动时的初始化 + await init_resources() + + yield + + # 关闭时的清理 + await cleanup_on_shutdown() # 定义FastAPI app -app = FastAPI(redoc_url=None) +app = FastAPI( + title="Euler Copilot Framework", + description="AI-powered automation framework", + version="1.0.0", + lifespan=lifespan, +) # 定义FastAPI全局中间件 app.add_middleware( CORSMiddleware, @@ -66,6 +121,8 @@ app.include_router(llm.router) app.include_router(mcp_service.router) app.include_router(flow.router) app.include_router(user.router) +app.include_router(parameter.router) +app.include_router(variable.router) # logger配置 LOGGER_FORMAT = "%(funcName)s() - %(message)s" @@ -88,10 +145,48 @@ async def init_resources() -> None: await Pool.init() TokenCalculator() + # 初始化变量池管理器 + from apps.scheduler.variable.pool_manager import initialize_pool_manager + await initialize_pool_manager() + + # 初始化前置节点变量缓存服务 + try: + from apps.services.predecessor_cache_service import 
PredecessorCacheService, periodic_cleanup_background_tasks + await PredecessorCacheService.initialize_redis() + + # 启动定期清理任务 + global _cleanup_task + _cleanup_task = asyncio.create_task(start_periodic_cleanup()) + + logging.info("前置节点变量缓存服务初始化成功") + except Exception as e: + logging.warning(f"前置节点变量缓存服务初始化失败(将降级使用实时解析): {e}") + +async def start_periodic_cleanup(): + """启动定期清理任务""" + try: + from apps.services.predecessor_cache_service import periodic_cleanup_background_tasks + while True: + # 每60秒清理一次已完成的后台任务 + await asyncio.sleep(60) + await periodic_cleanup_background_tasks() + except asyncio.CancelledError: + logging.info("定期清理任务已取消") + raise # 重新抛出CancelledError + except Exception as e: + logging.error(f"定期清理任务异常: {e}") + # 运行 if __name__ == "__main__": - # 初始化必要资源 - asyncio.run(init_resources()) + def signal_handler(signum, frame): + """信号处理器""" + logger = logging.getLogger(__name__) + logger.info(f"收到信号 {signum},准备关闭应用...") + sys.exit(0) + + # 注册信号处理器 + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) # 启动FastAPI uvicorn.run(app, host="0.0.0.0", port=8002, log_level="info", log_config=None) diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py index 158cfc13a5be17f16e2d73ec5fa44c4a18e4f392..51366a2192eaa429a3f1e1d40672b9e915bec424 100644 --- a/apps/routers/api_key.py +++ b/apps/routers/api_key.py @@ -6,6 +6,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse + from apps.dependency.user import get_user, verify_user from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp from apps.schemas.response_data import ResponseData diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1cba5ed629d90776b59ae3bf652381318dcabf0e..6c56aaa25ba1c22c75b45f823e6f7cd052a82341 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses 
import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates +from apps.common.config import Config from apps.common.oidc import oidc_provider from apps.dependency import get_session, get_user, verify_user from apps.schemas.collection import Audit @@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: user_info = await oidc_provider.get_oidc_user(token["access_token"]) user_sub: str | None = user_info.get("user_sub", None) - if user_sub: - await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"]) except Exception as e: logger.exception("User login failed") status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN @@ -149,7 +148,7 @@ async def oidc_redirect() -> JSONResponse: # TODO(zwt): OIDC主动触发logout @router.post("/logout", dependencies=[Depends(verify_user)], response_model=ResponseData) -async def oidc_logout(token: str) -> JSONResponse: +async def oidc_logout(token: str) -> JSONResponse | None: """OIDC主动触发登出""" @@ -191,6 +190,59 @@ async def userinfo( status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": ResponseData}, }, ) + +# 增加获取user session,开放平台api调用接口 +@router.post("/create-session", response_model=ResponseData) +async def create_session( + request: Request, + user_sub: str +) -> JSONResponse: + """通过user_sub直接创建session""" + if not request.client: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="客户端IP地址缺失", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + + try: + session_id = await SessionManager.create_session( + ip=request.client.host, + user_sub=user_sub + ) + + # 记录审计日志 + data = Audit( + http_method="post", + module="auth", + client_ip=request.client.host, + user_sub=user_sub, + message="/api/auth/create-session: Session创建成功", + ) + await AuditLogManager.add_audit_log(data) + + return JSONResponse( + 
status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="success", + result={"session_id": session_id}, + ).model_dump(exclude_none=True, by_alias=True), + ) + except Exception as e: + logger.exception("Session创建失败") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=str(e), + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001 """更新用户协议信息""" ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True) @@ -227,3 +279,9 @@ async def update_revision_number(request: Request, user_sub: Annotated[str, Depe ), ).model_dump(exclude_none=True, by_alias=True), ) + +# user_info = await oidc_provider.get_oidc_user(token["access_token"]) + +# user_sub: str | None = user_info.get("user_sub", None) +# if user_sub: + diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7fe5162c0f7db804c6d723207ffa327d9394e3eb..b16b41bb0370b06911e3b8127dc1df86b1336391 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -6,6 +6,7 @@ import logging import uuid from collections.abc import AsyncGenerator from typing import Annotated +import json from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse, StreamingResponse @@ -36,11 +37,16 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - # 创建或还原Task + if post_body.new_task: + # 创建或还原Task + task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + if task: + await TaskManager.delete_task_by_task_id(task.id) task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) # 更改信息并刷新数据库 - 
task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + if post_body.new_task: + task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id return task @@ -136,6 +142,45 @@ async def chat( }, ) +@router.post("/chat_without_streaming") +async def chat_without_streaming( + post_body: RequestData, + user_sub: Annotated[str, Depends(get_user)], + session_id: Annotated[str, Depends(get_session)], +) -> JSONResponse: + """非流式对话接口""" + # 问题黑名单检测 + if not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question): + # 用户扣分 + await UserBlacklistManager.change_blacklisted_users(user_sub, -10) + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") + + # 限流检查 + if await Activity.is_active(user_sub): + raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") + + content_texts = [] + async for chunk in chat_generator(post_body, user_sub, session_id): + if chunk.startswith('data: {"event": "text.add", '): + content_texts.append(chunk) + + final_answer = ''.join([ + json.loads(chunk.split('data: ')[1].split('\n\n')[0])['content']['text'] + for chunk in content_texts + if chunk.strip() + ]) + + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "code": 200, + "message": "success", + "result": { + "content": final_answer + } + }, + ) + @router.post("/stop", response_model=ResponseData) async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 diff --git a/apps/routers/flow.py b/apps/routers/flow.py index fc38c1bfb9879c1d846d700e91b812c83866c2c4..646213a88ac74c296b159678d3f47b9c0746219f 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""FastAPI Flow拓扑结构展示API""" +import logging from typing import Annotated from fastapi import APIRouter, Body, Depends, Query, status @@ -25,6 +26,8 @@ from apps.services.application import AppManager from apps.services.flow import FlowManager from apps.services.flow_validate import FlowService +logger = logging.getLogger(__name__) + router = APIRouter( prefix="/api/flow", tags=["flow"], @@ -153,6 +156,20 @@ async def put_flow( result=FlowStructurePutMsg(), ).model_dump(exclude_none=True, by_alias=True), ) + + # 触发前置节点变量预解析(异步执行,不阻塞响应) + try: + from apps.services.predecessor_cache_service import PredecessorCacheService + import asyncio + + # 在后台异步触发预解析 + asyncio.create_task( + PredecessorCacheService.trigger_flow_parsing(flow_id, force_refresh=True) + ) + logger.info(f"已触发Flow前置节点变量预解析: {flow_id}") + except Exception as trigger_error: + logger.warning(f"触发Flow前置节点变量预解析失败: {flow_id}, 错误: {trigger_error}") + return JSONResponse( status_code=status.HTTP_200_OK, content=FlowStructurePutRsp( diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..6edbe2e142cb6589be8947c28ae2eb4a7287baa1 --- /dev/null +++ b/apps/routers/parameter.py @@ -0,0 +1,77 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, status +from fastapi.responses import JSONResponse + +from apps.dependency import get_user +from apps.dependency.user import verify_user +from apps.services.parameter import ParameterManager +from apps.schemas.response_data import ( + GetOperaRsp, + GetParamsRsp +) +from apps.services.application import AppManager +from apps.services.flow import FlowManager + +router = APIRouter( + prefix="/api/parameter", + tags=["parameter"], + dependencies=[ + Depends(verify_user), + ], +) + + +@router.get("", response_model=GetParamsRsp) +async def get_parameters( + user_sub: Annotated[str, Depends(get_user)], + app_id: Annotated[str, Query(alias="appId")], + flow_id: Annotated[str, 
Query(alias="flowId")], + step_id: Annotated[str, Query(alias="stepId")], +) -> JSONResponse: + """Get parameters for node choice.""" + if not await AppManager.validate_user_app_access(user_sub, app_id): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content=GetParamsRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + if not flow: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=GetParamsRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetParamsRsp( + code=status.HTTP_200_OK, + message="获取参数成功", + result=result + ).model_dump(exclude_none=True, by_alias=True), + ) + + +@router.get("/operate", response_model=GetOperaRsp) +async def get_operate_parameters( + user_sub: Annotated[str, Depends(get_user)], + param_type: Annotated[str, Query(alias="ParamType")], +) -> JSONResponse: + """Get parameters for node choice.""" + result = await ParameterManager.get_operate_and_bind_type(param_type) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetOperaRsp( + code=status.HTTP_200_OK, + message="获取操作成功", + result=result + ).model_dump(exclude_none=True, by_alias=True), + ) diff --git a/apps/routers/variable.py b/apps/routers/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..9ae53a94828a9e7e629ec6b72b11a28844c91f51 --- /dev/null +++ b/apps/routers/variable.py @@ -0,0 +1,983 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""FastAPI 变量管理 API""" + +import logging +from typing import Annotated, List, Optional, Dict + +from fastapi import APIRouter, Body, Depends, HTTPException, Query, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +from apps.dependency import get_user +from apps.dependency.user import verify_user +from apps.scheduler.variable.pool_manager import get_pool_manager +from apps.scheduler.variable.type import VariableType, VariableScope +from apps.scheduler.variable.parser import VariableParser +from apps.schemas.response_data import ResponseData +from apps.services.flow import FlowManager + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/api/variable", + tags=["variable"], + dependencies=[ + Depends(verify_user), + ], +) + + +async def _get_predecessor_node_variables( + user_sub: str, + flow_id: str, + conversation_id: Optional[str], + current_step_id: str +) -> List: + """获取前置节点的输出变量(优化版本,使用缓存) + + Args: + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID(可选,配置阶段可能为None) + current_step_id: 当前步骤ID + + Returns: + List: 前置节点的输出变量列表 + """ + try: + variables = [] + pool_manager = await get_pool_manager() + + if conversation_id: + # 运行阶段:从对话池获取实际的前置节点变量 + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + if conversation_pool: + # 获取所有对话变量 + all_conversation_vars = await conversation_pool.list_variables() + + # 筛选出前置节点的输出变量(格式为 node_id.key) + for var in all_conversation_vars: + var_name = var.name + # 检查是否为节点输出变量格式(包含.且不是系统变量) + if "." in var_name and not var_name.startswith("system."): + # 提取节点ID + node_id = var_name.split(".")[0] + + # 检查是否为前置节点(这里可以根据需要添加更精确的前置判断逻辑) + if node_id != current_step_id: # 不是当前节点的变量 + variables.append(var) + else: + # 配置阶段:优先使用缓存,降级到实时解析 + try: + # 尝试使用优化的缓存服务 + from apps.services.predecessor_cache_service import PredecessorCacheService + + # 1. 
先从flow池中查找已存在的前置节点变量 + flow_pool = await pool_manager.get_flow_pool(flow_id) + if flow_pool: + flow_conversation_vars = await flow_pool.list_variables() + + # 筛选出前置节点的输出变量(格式为 node_id.key) + for var in flow_conversation_vars: + var_name = var.name + if "." in var_name and not var_name.startswith("system."): + node_id = var_name.split(".")[0] + if node_id != current_step_id: + variables.append(var) + + # 2. 使用优化的缓存服务获取前置节点变量 + cached_var_data = await PredecessorCacheService.get_predecessor_variables_optimized( + flow_id, current_step_id, user_sub, max_wait_time=5 + ) + + # 将缓存的变量数据转换为Variable对象 + for var_data in cached_var_data: + try: + from apps.scheduler.variable.variables import create_variable + from apps.scheduler.variable.base import VariableMetadata + from apps.scheduler.variable.type import VariableType, VariableScope + from datetime import datetime + + # 创建变量元数据 + metadata = VariableMetadata( + name=var_data['name'], + var_type=VariableType(var_data['var_type']), + scope=VariableScope(var_data['scope']), + description=var_data.get('description', ''), + created_by=user_sub, + created_at=datetime.fromisoformat(var_data['created_at'].replace('Z', '+00:00')), + updated_at=datetime.fromisoformat(var_data['updated_at'].replace('Z', '+00:00')) + ) + + # 创建变量对象,并附加缓存的节点信息 + variable = create_variable(metadata, var_data.get('value', '')) + + # 将节点信息附加到变量对象上(用于后续响应格式化) + if hasattr(variable, '_cache_data'): + variable._cache_data = var_data + else: + # 如果对象不支持动态属性,我们可以创建一个包装类或者在响应时处理 + setattr(variable, '_cache_data', var_data) + + variables.append(variable) + + except Exception as var_create_error: + logger.warning(f"创建缓存变量对象失败: {var_create_error}") + continue + + logger.info(f"配置阶段:为节点 {current_step_id} 找到前置节点变量总数: {len([v for v in variables if hasattr(v, 'name') and '.' 
in v.name and not v.name.startswith('system.')])}") + + except Exception as flow_error: + logger.warning(f"配置阶段获取前置节点变量失败,降级到实时解析: {flow_error}") + # 降级到原有的实时解析逻辑 + predecessor_vars = await _get_predecessor_variables_from_topology( + flow_id, current_step_id, user_sub + ) + variables.extend(predecessor_vars) + + return variables + + except Exception as e: + logger.error(f"获取前置节点变量失败: {e}") + return [] + + + + + +# 请求和响应模型 +class CreateVariableRequest(BaseModel): + """创建变量请求""" + name: str = Field(description="变量名称") + var_type: VariableType = Field(description="变量类型") + scope: VariableScope = Field(description="变量作用域") + value: Optional[str] = Field(default=None, description="变量值") + description: Optional[str] = Field(default=None, description="变量描述") + flow_id: Optional[str] = Field(default=None, description="流程ID(环境级和对话级变量必需)") + + +class UpdateVariableRequest(BaseModel): + """更新变量请求""" + value: Optional[str] = Field(default=None, description="新的变量值") + var_type: Optional[VariableType] = Field(default=None, description="新的变量类型") + description: Optional[str] = Field(default=None, description="新的变量描述") + + +class VariableResponse(BaseModel): + """变量响应""" + name: str = Field(description="变量名称") + var_type: str = Field(description="变量类型") + scope: str = Field(description="变量作用域") + value: str = Field(description="变量值") + description: Optional[str] = Field(description="变量描述") + created_at: str = Field(description="创建时间") + updated_at: str = Field(description="更新时间") + step: Optional[str] = Field(default=None, description="节点名称(前置节点变量专用)") + step_id: Optional[str] = Field(default=None, description="节点ID(前置节点变量专用)") + + +class VariableListResponse(BaseModel): + """变量列表响应""" + variables: List[VariableResponse] = Field(description="变量列表") + total: int = Field(description="总数量") + + +class ParseTemplateRequest(BaseModel): + """解析模板请求""" + template: str = Field(description="包含变量引用的模板") + flow_id: Optional[str] = Field(default=None, description="流程ID") + + +class 
ParseTemplateResponse(BaseModel): + """解析模板响应""" + parsed_template: str = Field(description="解析后的模板") + variables_used: List[str] = Field(description="使用的变量引用") + + +class ValidateTemplateResponse(BaseModel): + """验证模板响应""" + is_valid: bool = Field(description="是否有效") + invalid_references: List[str] = Field(description="无效的变量引用") + + +@router.post( + "/create", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + }, +) +async def create_variable( + user_sub: Annotated[str, Depends(get_user)], + request: CreateVariableRequest = Body(...), +) -> ResponseData: + """创建变量""" + try: + # 验证作用域权限 + if request.scope == VariableScope.SYSTEM: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="不允许创建系统级变量" + ) + + pool_manager = await get_pool_manager() + + # 根据作用域获取合适的变量池 + if request.scope == VariableScope.USER: + # 用户级变量需要user_sub参数 + if not user_sub: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户级变量需要用户身份" + ) + pool = await pool_manager.get_user_pool(user_sub) + elif request.scope == VariableScope.ENVIRONMENT: + # 环境级变量需要flow_id参数 + if not request.flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="环境级变量需要flow_id参数" + ) + pool = await pool_manager.get_flow_pool(request.flow_id) + elif request.scope == VariableScope.CONVERSATION: + # 对话级变量需要flow_id参数,用于创建模板 + if not request.flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量需要flow_id参数" + ) + # 对话级变量模板在流程池中定义 + pool = await pool_manager.get_flow_pool(request.flow_id) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的变量作用域: {request.scope.value}" + ) + + if not pool: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="无法获取变量池" + ) + + # 根据作用域创建不同类型的变量 + if request.scope == VariableScope.CONVERSATION: + # 
创建对话变量模板 + variable = await pool.add_conversation_template( + name=request.name, + var_type=request.var_type, + default_value=request.value, + description=request.description, + created_by=user_sub + ) + else: + # 创建其他类型的变量 + variable = await pool.add_variable( + name=request.name, + var_type=request.var_type, + value=request.value, + description=request.description, + created_by=user_sub + ) + + + return ResponseData( + code=200, + message="变量创建成功", + result={"variable_name": variable.name}, + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"创建变量失败: {str(e)}" + ) + + +@router.put( + "/update", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + }, +) +async def update_variable( + user_sub: Annotated[str, Depends(get_user)], + name: str = Query(..., description="变量名称"), + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"), + request: UpdateVariableRequest = Body(...), +) -> ResponseData: + """更新变量值""" + try: + pool_manager = await get_pool_manager() + + # 根据作用域获取合适的变量池 + if scope == VariableScope.USER: + if not user_sub: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户级变量需要用户身份" + ) + pool = await pool_manager.get_user_pool(user_sub) + elif scope == VariableScope.ENVIRONMENT: + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="环境级变量需要flow_id参数" + ) + pool = await pool_manager.get_flow_pool(flow_id) + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 
运行时:使用对话池,如果不存在则创建 + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量运行时需要conversation_id和flow_id参数" + ) + pool = await pool_manager.get_conversation_pool(conversation_id) + if not pool: + # 对话池不存在,自动创建 + pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) + elif flow_id: + # 配置时:使用流程池 + pool = await pool_manager.get_flow_pool(flow_id) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量需要conversation_id(运行时)或flow_id(配置时)参数" + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的变量作用域: {scope.value}" + ) + + if not pool: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="无法获取变量池" + ) + + # 更新变量 + variable = await pool.update_variable( + name=name, + value=request.value, + var_type=request.var_type, + description=request.description + ) + + return ResponseData( + code=200, + message="变量更新成功", + result={"variable_name": variable.name} + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e) + ) + except PermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"更新变量失败: {str(e)}" + ) + + +@router.delete( + "/delete", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + }, +) +async def delete_variable( + user_sub: Annotated[str, Depends(get_user)], + name: str = Query(..., description="变量名称"), + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"), +) -> ResponseData: + """删除变量""" + try: + 
pool_manager = await get_pool_manager() + + # 根据作用域获取合适的变量池 + if scope == VariableScope.USER: + if not user_sub: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户级变量需要用户身份" + ) + pool = await pool_manager.get_user_pool(user_sub) + elif scope == VariableScope.ENVIRONMENT: + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="环境级变量需要flow_id参数" + ) + pool = await pool_manager.get_flow_pool(flow_id) + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 运行时:使用对话池,如果不存在则创建 + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量运行时需要conversation_id和flow_id参数" + ) + pool = await pool_manager.get_conversation_pool(conversation_id) + if not pool: + # 对话池不存在,自动创建 + pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) + elif flow_id: + # 配置时:使用流程池 + pool = await pool_manager.get_flow_pool(flow_id) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量需要conversation_id(运行时)或flow_id(配置时)参数" + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的变量作用域: {scope.value}" + ) + + if not pool: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="无法获取变量池" + ) + + # 删除变量 + success = await pool.delete_variable(name) + + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="变量不存在" + ) + + return ResponseData( + code=200, + message="变量删除成功", + result={"variable_name": name} + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) + except PermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"删除变量失败: {str(e)}" + ) + + +@router.get( + "/get", + responses={ + status.HTTP_200_OK: 
{"model": VariableResponse}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + }, +) +async def get_variable( + user_sub: Annotated[str, Depends(get_user)], + name: str = Query(..., description="变量名称"), + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"), +) -> VariableResponse: + """获取单个变量""" + try: + pool_manager = await get_pool_manager() + + # 根据作用域获取变量 + variable = await pool_manager.get_variable_from_any_pool( + name=name, + scope=scope, + user_id=user_sub if scope == VariableScope.USER else None, + flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None + ) + + if not variable: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="变量不存在" + ) + + # 检查权限 + if not variable.can_access(user_sub): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="没有权限访问此变量" + ) + + # 构建响应 + var_dict = variable.to_dict() + return VariableResponse( + name=variable.name, + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取变量失败: {str(e)}" + ) + + +@router.get( + "/list", + responses={ + status.HTTP_200_OK: {"model": VariableListResponse}, + }, +) +async def list_variables( + user_sub: Annotated[str, Depends(get_user)], + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = 
Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"), + current_step_id: Optional[str] = Query(default=None, description="当前步骤ID(用于获取前置节点变量)"), +) -> VariableListResponse: + """列出指定作用域的变量""" + try: + pool_manager = await get_pool_manager() + + # 获取变量列表 + variables = await pool_manager.list_variables_from_any_pool( + scope=scope, + user_id=user_sub if scope == VariableScope.USER else None, + flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None + ) + + # 如果是对话级变量且提供了current_step_id,则额外获取前置节点的输出变量 + if scope == VariableScope.CONVERSATION and current_step_id and flow_id: + predecessor_variables = await _get_predecessor_node_variables( + user_sub, flow_id, conversation_id, current_step_id + ) + variables.extend(predecessor_variables) + + # 过滤权限并构建响应 + filtered_variables = [] + for variable in variables: + if variable.can_access(user_sub): + var_dict = variable.to_dict() + + # 检查是否为前置节点变量 + is_predecessor_var = ( + "." 
in variable.name and + not variable.name.startswith("system.") and + scope == VariableScope.CONVERSATION and + flow_id + ) + + if is_predecessor_var: + # 前置节点变量特殊处理 + parts = variable.name.split(".", 1) + if len(parts) == 2: + step_id, var_name = parts + + # 优先使用缓存数据中的节点信息 + if hasattr(variable, '_cache_data') and variable._cache_data: + cache_data = variable._cache_data + step_name = cache_data.get('step_name', step_id) + step_id_from_cache = cache_data.get('step_id', step_id) + else: + # 降级到实时获取节点信息 + node_info = await _get_node_info_by_step_id(flow_id, step_id) + step_name = node_info["name"] + step_id_from_cache = node_info["step_id"] + + filtered_variables.append(VariableResponse( + name=var_name, # 只保留变量名部分 + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + step=step_name, # 节点名称 + step_id=step_id_from_cache # 节点ID + )) + else: + # 降级处理,如果格式不符合预期 + filtered_variables.append(VariableResponse( + name=variable.name, + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + )) + else: + # 普通变量 + filtered_variables.append(VariableResponse( + name=variable.name, + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + )) + + return VariableListResponse( + variables=filtered_variables, + total=len(filtered_variables) + ) + + except Exception as e: + raise 
HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取变量列表失败: {str(e)}" + ) + + +@router.post( + "/parse", + responses={ + status.HTTP_200_OK: {"model": ParseTemplateResponse}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + }, +) +async def parse_template( + user_sub: Annotated[str, Depends(get_user)], + request: ParseTemplateRequest = Body(...), +) -> ParseTemplateResponse: + """解析模板中的变量引用""" + try: + # 创建变量解析器 + parser = VariableParser( + user_id=user_sub, + flow_id=request.flow_id, + conversation_id=None, # 不再使用conversation_id + ) + + # 解析模板 + parsed_template = await parser.parse_template(request.template) + + # 提取使用的变量 + variables_used = await parser.extract_variables(request.template) + + return ParseTemplateResponse( + parsed_template=parsed_template, + variables_used=variables_used + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"解析模板失败: {str(e)}" + ) + + +@router.post( + "/validate", + responses={ + status.HTTP_200_OK: {"model": ValidateTemplateResponse}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + }, +) +async def validate_template( + user_sub: Annotated[str, Depends(get_user)], + request: ParseTemplateRequest = Body(...), +) -> ValidateTemplateResponse: + """验证模板中的变量引用是否有效""" + try: + # 创建变量解析器 + parser = VariableParser( + user_id=user_sub, + flow_id=request.flow_id, + conversation_id=None, # 不再使用conversation_id + ) + + # 验证模板 + is_valid, invalid_refs = await parser.validate_template(request.template) + + return ValidateTemplateResponse( + is_valid=is_valid, + invalid_references=invalid_refs + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"验证模板失败: {str(e)}" + ) + + +@router.get( + "/types", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + }, +) +async def get_variable_types() -> ResponseData: + """获取支持的变量类型列表""" + return ResponseData( + code=200, + message="获取变量类型成功", + 
result={ + "types": [vtype.value for vtype in VariableType], + "scopes": [scope.value for scope in VariableScope], + } + ) + + +@router.post( + "/clear-conversation", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + }, +) +async def clear_conversation_variables( + user_sub: Annotated[str, Depends(get_user)], + flow_id: str = Query(..., description="流程ID"), +) -> ResponseData: + """清空指定工作流的对话级变量""" + try: + pool_manager = await get_pool_manager() + # 清空工作流的对话级变量 + await pool_manager.clear_conversation_variables(flow_id) + + return ResponseData( + code=200, + message="工作流对话变量已清空", + result={"flow_id": flow_id} + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"清空对话变量失败: {str(e)}" + ) + + +async def _get_node_info_by_step_id(flow_id: str, step_id: str) -> Dict[str, str]: + """根据step_id获取节点信息""" + try: + flow_item = await _get_flow_by_flow_id(flow_id) + if not flow_item: + return {"name": step_id, "step_id": step_id} # 降级返回step_id作为名称 + + # 查找对应的节点 + for node in flow_item.nodes: + if node.step_id == step_id: + return { + "name": node.name or step_id, # 如果没有名称则使用step_id + "step_id": step_id + } + + # 如果没有找到节点,返回默认值 + return {"name": step_id, "step_id": step_id} + + except Exception as e: + logger.error(f"获取节点信息失败: {e}") + return {"name": step_id, "step_id": step_id} + + +async def _get_predecessor_variables_from_topology( + flow_id: str, + current_step_id: str, + user_sub: str +) -> List: + """通过工作流拓扑分析获取前置节点变量""" + try: + variables = [] + + # 直接通过flow_id获取工作流拓扑信息 + flow_item = await _get_flow_by_flow_id(flow_id) + if not flow_item: + logger.warning(f"无法获取工作流信息: flow_id={flow_id}") + return variables + + # 分析前置节点 + predecessor_nodes = _find_predecessor_nodes(flow_item, current_step_id) + + # 为每个前置节点创建潜在的输出变量 + for node in predecessor_nodes: + node_vars = await _create_node_output_variables(node, user_sub) + variables.extend(node_vars) + + logger.info(f"通过拓扑分析为节点 {current_step_id} 创建了 
{len(variables)} 个前置节点变量") + return variables + + except Exception as e: + logger.error(f"通过拓扑分析获取前置节点变量失败: {e}") + return [] + + +async def _get_flow_by_flow_id(flow_id: str): + """直接通过flow_id获取工作流信息""" + try: + from apps.common.mongo import MongoDB + + app_collection = MongoDB().get_collection("app") + + # 查询包含此flow_id的app,同时获取app_id + app_record = await app_collection.find_one( + {"flows.id": flow_id}, + {"_id": 1} + ) + + if not app_record: + logger.warning(f"未找到包含flow_id {flow_id} 的应用") + return None + + app_id = app_record["_id"] + + # 使用现有的FlowManager方法获取flow + flow_item = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + return flow_item + + except Exception as e: + logger.error(f"通过flow_id获取工作流失败: {e}") + return None + + +def _find_predecessor_nodes(flow_item, current_step_id: str) -> List: + """在工作流中查找前置节点""" + try: + predecessor_nodes = [] + + # 遍历边,找到指向当前节点的边 + for edge in flow_item.edges: + if edge.target_node == current_step_id: + # 找到前置节点 + source_node = next( + (node for node in flow_item.nodes if node.step_id == edge.source_node), + None + ) + if source_node: + predecessor_nodes.append(source_node) + + logger.info(f"为节点 {current_step_id} 找到 {len(predecessor_nodes)} 个前置节点") + return predecessor_nodes + + except Exception as e: + logger.error(f"查找前置节点失败: {e}") + return [] + + +async def _create_node_output_variables(node, user_sub: str) -> List: + """根据节点的output_parameters配置创建输出变量""" + try: + from apps.scheduler.variable.variables import create_variable + from apps.scheduler.variable.base import VariableMetadata + from datetime import datetime, UTC + + variables = [] + node_id = node.step_id + + # 调试:输出节点的完整参数信息 + logger.info(f"节点 {node_id} 的参数结构: {node.parameters}") + + # 统一从节点的output_parameters创建变量 + output_params = {} + if hasattr(node, 'parameters') and node.parameters: + # 尝试不同的访问方式 + if isinstance(node.parameters, dict): + output_params = node.parameters.get('output_parameters', {}) + logger.info(f"从字典中获取output_parameters: 
{output_params}")
+        else:
+            output_params = getattr(node.parameters, 'output_parameters', {})
+            logger.info(f"从对象属性中获取output_parameters: {output_params}")
+
+        # 如果没有配置output_parameters,跳过此节点
+        if not output_params:
+            logger.info(f"节点 {node_id} 没有配置output_parameters,跳过创建输出变量")
+            return variables
+
+        # 参数类型字符串到变量类型的映射表,未匹配时默认为STRING
+        param_type_mapping = {
+            'number': VariableType.NUMBER,
+            'boolean': VariableType.BOOLEAN,
+            'object': VariableType.OBJECT,
+            'array': VariableType.ARRAY_ANY,
+            'array[any]': VariableType.ARRAY_ANY,
+            'array[string]': VariableType.ARRAY_STRING,
+            'array[number]': VariableType.ARRAY_NUMBER,
+            'array[object]': VariableType.ARRAY_OBJECT,
+            'array[boolean]': VariableType.ARRAY_BOOLEAN,
+            'array[file]': VariableType.ARRAY_FILE,
+            'array[secret]': VariableType.ARRAY_SECRET,
+            'file': VariableType.FILE,
+            'secret': VariableType.SECRET,
+        }
+
+        # 遍历output_parameters中的每个key-value对,创建对应的变量
+        for param_name, param_config in output_params.items():
+            # 解析参数配置
+            if isinstance(param_config, dict):
+                param_type = param_config.get('type', 'string')
+                description = param_config.get('description', '')
+            else:
+                # 如果param_config不是字典,可能是简单的类型字符串
+                param_type = str(param_config) if param_config else 'string'
+                description = ''
+
+            # 确定变量类型
+            var_type = param_type_mapping.get(param_type, VariableType.STRING)
+
+            # 创建变量元数据
+            metadata = VariableMetadata(
+                name=f"{node_id}.{param_name}",
+                var_type=var_type,
+                scope=VariableScope.CONVERSATION,
+                description=description or f"来自节点 {node_id} 的输出参数 {param_name}",
+                created_by=user_sub,
+                created_at=datetime.now(UTC),
+                updated_at=datetime.now(UTC)
+            )
+
+            # 创建变量对象
+            variable = create_variable(metadata, "")  # 配置阶段的潜在变量,值为空
+            variables.append(variable)
+
+        
logger.info(f"为节点 {node_id} 创建了 {len(variables)} 个输出变量: {[v.name for v in variables]}") + return variables + + except Exception as e: + logger.error(f"创建节点输出变量失败: {e}") + return [] \ No newline at end of file diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 2ee8b862885b0d88e519bf00a1678a49863df2bd..a7475bf9c82dd94f97f8097fb263e3cfaf5241d4 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -2,20 +2,28 @@ """Agent工具部分""" from apps.scheduler.call.api.api import API +from apps.scheduler.call.code.code import Code from apps.scheduler.call.graph.graph import Graph from apps.scheduler.call.llm.llm import LLM from apps.scheduler.call.mcp.mcp import MCP from apps.scheduler.call.rag.rag import RAG +from apps.scheduler.call.reply.direct_reply import DirectReply from apps.scheduler.call.sql.sql import SQL -from apps.scheduler.call.suggest.suggest import Suggestion +# from apps.scheduler.call.graph.graph import Graph +# from apps.scheduler.call.suggest.suggest import Suggestion +from apps.scheduler.call.suggest.suggest import Suggestion +from apps.scheduler.call.choice.choice import Choice # 只包含需要在编排界面展示的工具 __all__ = [ "API", + "Code", + "DirectReply", "LLM", "MCP", "RAG", - "SQL", - "Graph", - "Suggestion", + # "SQL", + # "Graph", + # "Suggestion", + "Choice" ] diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index e1891f7259b72a2fa03228f6289c54da6297b958..c34cbc0585616c0c12f55e6d4f97012396818396 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -14,8 +14,8 @@ from pydantic.json_schema import SkipJsonSchema from apps.common.oidc import oidc_provider from apps.scheduler.call.api.schema import APIInput, APIOutput -from apps.scheduler.call.core import CoreCall -from apps.schemas.enum_var import CallOutputType, ContentType, HTTPMethod +from apps.scheduler.call.core import CoreCall, NodeType +from apps.schemas.enum_var import CallOutputType, CallType, 
ContentType, HTTPMethod from apps.schemas.scheduler import ( CallError, CallInfo, @@ -49,20 +49,28 @@ SUCCESS_HTTP_CODES = [ class API(CoreCall, input_model=APIInput, output_model=APIOutput): """API调用工具""" - enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=True) + enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False) url: str = Field(description="API接口的完整URL") method: HTTPMethod = Field(description="API接口的HTTP Method") content_type: ContentType | None = Field(description="API接口的Content-Type", default=None) timeout: int = Field(description="工具超时时间", default=300, gt=30) - body: dict[str, Any] = Field(description="已知的部分请求体", default={}) - query: dict[str, Any] = Field(description="已知的部分请求参数", default={}) + body: list[dict[str, Any]] = Field(description="已知的部分请求体", default=[]) + query: list[dict[str, Any]]= Field(description="已知的部分请求参数", default=[]) + headers: list[dict[str, Any]] = Field(description="已知的部分请求头", default=[]) + # 增加node_type + node_type: NodeType | None = Field(description="节点类型", default=NodeType.API) + @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="API调用", description="向某一个API接口发送HTTP请求,获取数据。") + return CallInfo( + name="API调用", + type=CallType.TOOL, + description="向某一个API接口发送HTTP请求,获取数据。" + ) async def _init(self, call_vars: CallVars) -> APIInput: """初始化API调用工具""" @@ -97,6 +105,7 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): method=self.method, query=self.query, body=self.body, + headers=self.headers, ) async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: @@ -129,12 +138,20 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): url=self.url, cookies=req_cookie, ) + logger.info("api_input ----------------------- %r", data) # 根据HTTP方法创建请求 if self.method in [HTTPMethod.GET.value, HTTPMethod.DELETE.value]: # GET/DELETE 请求处理 - req_params.update(data.query) - return await 
request_factory(params=req_params) + # 转换query参数格式: list[dict] -> dict + query_dict = {item["key"]: item["value"] for item in data.query} + req_params.update(query_dict) + logger.info("query_dict ----------------------- %r", query_dict) + logger.info("req_params ----------------------- %r", req_params) + header_dict = {item["key"]: item["value"] for item in data.headers} + req_header.update(header_dict) + logger.info("req_header ----------------------- %r", req_header) + return await request_factory(params=req_params, headers=req_header) if self.method in [HTTPMethod.POST.value, HTTPMethod.PUT.value, HTTPMethod.PATCH.value]: # POST/PUT/PATCH 请求处理 @@ -142,8 +159,13 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): raise CallError(message="API接口的Content-Type未指定", data={}) # 根据Content-Type设置请求参数 - req_body = data.body - req_header.update({"Content-Type": self.content_type}) + # 转换body参数格式: list[dict] -> dict + req_body = {item["key"]: item["value"] for item in data.body} + header_dict = {item["key"]: item["value"] for item in data.headers} + header_dict["Content-Type"] = self.content_type + req_header.update(header_dict) + logger.info("header_dict ----------------------- %r", header_dict) + logger.info("req_header ----------------------- %r", req_header) # 根据Content-Type决定如何发送请求体 content_type_handlers = { diff --git a/apps/scheduler/call/api/schema.py b/apps/scheduler/call/api/schema.py index 055008a889676e4e2fbd5057912ba78421e52aa2..1fbe82baae820f28ae6820f67323a48b6090d083 100644 --- a/apps/scheduler/call/api/schema.py +++ b/apps/scheduler/call/api/schema.py @@ -5,7 +5,6 @@ from typing import Any from pydantic import Field from pydantic.json_schema import SkipJsonSchema - from apps.scheduler.call.core import DataBase @@ -15,8 +14,9 @@ class APIInput(DataBase): url: SkipJsonSchema[str] = Field(description="API调用工具的URL") method: SkipJsonSchema[str] = Field(description="API调用工具的HTTP方法") - query: dict[str, Any] = Field(description="API调用工具的请求参数", 
default={}) - body: dict[str, Any] = Field(description="API调用工具的请求体", default={}) + query: list[dict[str, Any]] = Field(description="API调用工具的请求参数", default=[]) + body: list[dict[str, Any]] = Field(description="API调用工具的请求体", default=[]) + headers: list[dict[str, Any]] = Field(description="API调用工具的请求头", default=[]) class APIOutput(DataBase): diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index a5edf21afb2eeb308dd909a40696500751c9a086..47a0df747483c3915efb21c868d5cb1785b903e0 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -1,19 +1,153 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """使用大模型或使用程序做出判断""" -from enum import Enum +import ast +import copy +import logging +from collections.abc import AsyncGenerator +from typing import Any -from apps.scheduler.call.choice.schema import ChoiceInput, ChoiceOutput -from apps.scheduler.call.core import CoreCall +from pydantic import Field +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + Condition, + ChoiceBranch, + ChoiceInput, + ChoiceOutput, + Logic, +) +from apps.schemas.parameters import Type +from apps.scheduler.call.core import CoreCall, NodeType +from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) -class Operator(str, Enum): - """Choice工具支持的运算符""" - - pass +logger = logging.getLogger(__name__) class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" - pass + to_user: bool = Field(default=False) + choices: list[ChoiceBranch] = Field(description="分支", default=[ChoiceBranch(), + ChoiceBranch(conditions=[Condition()], is_default=False)]) + # 增加node_type + node_type: NodeType | None = Field(description="节点类型", default=NodeType.CHOICE) + + + @classmethod + def info(cls) -> CallInfo: + 
"""返回Call的名称和描述""" + return CallInfo(name="选择器", description="使用大模型或使用程序做出判断", type=CallType.LOGIC) + + async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: + """替换choices中的系统变量""" + valid_choices = [] + + for choice in self.choices: + try: + # 验证逻辑运算符 + if choice.logic not in [Logic.AND, Logic.OR]: + msg = f"无效的逻辑运算符: {choice.logic}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + + valid_conditions = [] + for i in range(len(choice.conditions)): + condition = copy.deepcopy(choice.conditions[i]) + # 处理左值 + if condition.left.step_id is not None: + condition.left.value = self._extract_history_variables( + condition.left.step_id+'/'+condition.left.value, call_vars.history) + # 检查历史变量是否成功提取 + if condition.left.value is None: + msg = f"步骤 {condition.left.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.left.value, condition.left.type): + msg = f"左值类型不匹配: {condition.left.value} 应为 {condition.left.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + else: + msg = "左侧变量缺少step_id" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + # 处理右值 + if condition.right.step_id is not None: + condition.right.value = self._extract_history_variables( + condition.right.step_id+'/'+condition.right.value, call_vars.history) + # 检查历史变量是否成功提取 + if condition.right.value is None: + msg = f"步骤 {condition.right.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + else: + # 如果右值没有step_id,尝试从call_vars中获取 + right_value_type = await ConditionHandler.get_value_type_from_operate( + 
condition.operate) + if right_value_type is None: + msg = f"不支持的运算符: {condition.operate}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if condition.right.type != right_value_type: + msg = f"右值类型不匹配: {condition.right.value} 应为 {right_value_type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if right_value_type == Type.STRING: + condition.right.value = str(condition.right.value) + else: + condition.right.value = ast.literal_eval(condition.right.value) + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + valid_conditions.append(condition) + + # 如果所有条件都无效,抛出异常 + if not valid_conditions: + msg = "分支没有有效条件" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + + # 更新有效条件 + choice.conditions = valid_conditions + valid_choices.append(choice) + + except ValueError as e: + logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e)) + continue + + return valid_choices + + async def _init(self, call_vars: CallVars) -> ChoiceInput: + """初始化Choice工具""" + return ChoiceInput( + choices=await self._prepare_message(call_vars), + ) + + async def _exec( + self, input_data: dict[str, Any] + ) -> AsyncGenerator[CallOutputChunk, None]: + """执行Choice工具""" + # 解析输入数据 + data = ChoiceInput(**input_data) + try: + branch_id = ConditionHandler.handler(data.choices) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=ChoiceOutput(branch_id=branch_id).model_dump(exclude_none=True, by_alias=True), + ) + except Exception as e: + raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py new file mode 100644 index 
0000000000000000000000000000000000000000..7542f2944eedff72b820a79725d64b5ddeba5a11
--- /dev/null
+++ b/apps/scheduler/call/choice/condition_handler.py
@@ -0,0 +1,314 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""处理条件分支的工具"""
+
+
+import logging
+
+from pydantic import BaseModel
+
+from apps.schemas.parameters import (
+    Type,
+    NumberOperate,
+    StringOperate,
+    ListOperate,
+    BoolOperate,
+    DictOperate,
+)
+
+from apps.scheduler.call.choice.schema import (
+    ChoiceBranch,
+    Condition,
+    Logic,
+    Value,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionHandler(BaseModel):
+    """条件分支处理器"""
+
+    @staticmethod
+    async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate |
+                                          BoolOperate | DictOperate) -> Type | None:
+        """获取右值的类型;不支持的运算符返回None"""
+        if isinstance(operate, NumberOperate):
+            return Type.NUMBER
+        if operate in [
+                StringOperate.EQUAL, StringOperate.NOT_EQUAL, StringOperate.CONTAINS, StringOperate.NOT_CONTAINS,
+                StringOperate.STARTS_WITH, StringOperate.ENDS_WITH, StringOperate.REGEX_MATCH]:
+            return Type.STRING
+        if operate in [StringOperate.LENGTH_EQUAL, StringOperate.LENGTH_GREATER_THAN,
+                       StringOperate.LENGTH_GREATER_THAN_OR_EQUAL, StringOperate.LENGTH_LESS_THAN,
+                       StringOperate.LENGTH_LESS_THAN_OR_EQUAL]:
+            return Type.NUMBER
+        if operate in [ListOperate.EQUAL, ListOperate.NOT_EQUAL]:
+            return Type.LIST
+        if operate in [ListOperate.CONTAINS, ListOperate.NOT_CONTAINS]:
+            return Type.STRING
+        if operate in [ListOperate.LENGTH_EQUAL, ListOperate.LENGTH_GREATER_THAN,
+                       ListOperate.LENGTH_GREATER_THAN_OR_EQUAL, ListOperate.LENGTH_LESS_THAN,
+                       ListOperate.LENGTH_LESS_THAN_OR_EQUAL]:
+            return Type.NUMBER
+        if operate in [BoolOperate.EQUAL, BoolOperate.NOT_EQUAL]:
+            return Type.BOOL
+        if operate in [DictOperate.EQUAL, DictOperate.NOT_EQUAL]:
+            return Type.DICT
+        if operate in [DictOperate.CONTAINS_KEY, DictOperate.NOT_CONTAINS_KEY]:
+            return Type.STRING
+        return None
+
+    @staticmethod
+    def
check_value_type(value: object, expected_type: Type) -> bool:
+        """检查值的类型是否符合预期;value为原始值(调用方传入的是Value.value)"""
+        if expected_type == Type.STRING and isinstance(value, str):
+            return True
+        # bool是int的子类,需要显式排除,避免True/False被当作数字
+        if expected_type == Type.NUMBER and isinstance(value, (int, float)) and not isinstance(value, bool):
+            return True
+        if expected_type == Type.LIST and isinstance(value, list):
+            return True
+        if expected_type == Type.DICT and isinstance(value, dict):
+            return True
+        if expected_type == Type.BOOL and isinstance(value, bool):
+            return True
+        return False
+
+    @staticmethod
+    def handler(choices: list[ChoiceBranch]) -> str:
+        """处理条件"""
+        default_branch = [c for c in choices if c.is_default]
+
+        for block_judgement in choices:
+            results = []
+            if block_judgement.is_default:
+                continue
+            for condition in block_judgement.conditions:
+                result = ConditionHandler._judge_condition(condition)
+                results.append(result)
+            # 先赋默认值,防止logic非法时final_result未绑定
+            final_result = False
+            if block_judgement.logic == Logic.AND:
+                final_result = all(results)
+            elif block_judgement.logic == Logic.OR:
+                final_result = any(results)
+
+            if final_result:
+                return block_judgement.branch_id
+
+        # 如果没有匹配的分支,选择默认分支
+        if default_branch:
+            return default_branch[0].branch_id
+        return ""
+
+    @staticmethod
+    def _judge_condition(condition: Condition) -> bool:
+        """
+        判断条件是否成立。
+
+        Args:
+            condition (Condition): 'left', 'operate', 'right', 'type'
+
+        Returns:
+            bool
+
+        """
+        left = condition.left
+        operate = condition.operate
+        right = condition.right
+        value_type = condition.type
+
+        result = None
+        if value_type == Type.STRING:
+            result = ConditionHandler._judge_string_condition(left, operate, right)
+        elif value_type == Type.NUMBER:
+            result = ConditionHandler._judge_number_condition(left, operate, right)
+        elif value_type == Type.BOOL:
+            result = ConditionHandler._judge_bool_condition(left, operate, right)
+        elif value_type == Type.LIST:
+            result = ConditionHandler._judge_list_condition(left, operate, right)
+        elif value_type == Type.DICT:
+            result = ConditionHandler._judge_dict_condition(left, operate,
right) + else: + logger.error("不支持的数据类型: %s", value_type) + msg = f"不支持的数据类型: {value_type}" + raise ValueError(msg) + return result + + @staticmethod + def _judge_string_condition(left: Value, operate: StringOperate, right: Value) -> bool: + """ + 判断字符串类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, str): + logger.error("左值不是字符串类型: %s", left_value) + msg = "左值必须是字符串类型" + raise TypeError(msg) + right_value = right.value + result = False + if operate == StringOperate.EQUAL: + return left_value == right_value + elif operate == StringOperate.NOT_EQUAL: + return left_value != right_value + elif operate == StringOperate.CONTAINS: + return right_value in left_value + elif operate == StringOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == StringOperate.STARTS_WITH: + return left_value.startswith(right_value) + elif operate == StringOperate.ENDS_WITH: + return left_value.endswith(right_value) + elif operate == StringOperate.REGEX_MATCH: + import re + return bool(re.match(right_value, left_value)) + elif operate == StringOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == StringOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == StringOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == StringOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == StringOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False + + @staticmethod + def _judge_number_condition(left: Value, operate: NumberOperate, right: Value) -> bool: # noqa: PLR0911 + """ + 判断数字类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, (int, 
float)): + logger.error("左值不是数字类型: %s", left_value) + msg = "左值必须是数字类型" + raise TypeError(msg) + right_value = right.value + if operate == NumberOperate.EQUAL: + return left_value == right_value + elif operate == NumberOperate.NOT_EQUAL: + return left_value != right_value + elif operate == NumberOperate.GREATER_THAN: + return left_value > right_value + elif operate == NumberOperate.LESS_THAN: # noqa: PLR2004 + return left_value < right_value + elif operate == NumberOperate.GREATER_THAN_OR_EQUAL: + return left_value >= right_value + elif operate == NumberOperate.LESS_THAN_OR_EQUAL: + return left_value <= right_value + return False + + @staticmethod + def _judge_bool_condition(left: Value, operate: BoolOperate, right: Value) -> bool: + """ + 判断布尔类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, bool): + logger.error("左值不是布尔类型: %s", left_value) + msg = "左值必须是布尔类型" + raise TypeError(msg) + right_value = right.value + if operate == BoolOperate.EQUAL: + return left_value == right_value + elif operate == BoolOperate.NOT_EQUAL: + return left_value != right_value + elif operate == BoolOperate.IS_EMPTY: + return not left_value + elif operate == BoolOperate.NOT_EMPTY: + return left_value + return False + + @staticmethod + def _judge_list_condition(left: Value, operate: ListOperate, right: Value): + """ + 判断列表类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, list): + logger.error("左值不是列表类型: %s", left_value) + msg = "左值必须是列表类型" + raise TypeError(msg) + right_value = right.value + if operate == ListOperate.EQUAL: + return left_value == right_value + elif operate == ListOperate.NOT_EQUAL: + return left_value != right_value + elif operate == ListOperate.CONTAINS: + return right_value in 
left_value + elif operate == ListOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == ListOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == ListOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == ListOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == ListOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == ListOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False + + @staticmethod + def _judge_dict_condition(left: Value, operate: DictOperate, right: Value): + """ + 判断字典类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, dict): + logger.error("左值不是字典类型: %s", left_value) + msg = "左值必须是字典类型" + raise TypeError(msg) + right_value = right.value + if operate == DictOperate.EQUAL: + return left_value == right_value + elif operate == DictOperate.NOT_EQUAL: + return left_value != right_value + elif operate == DictOperate.CONTAINS_KEY: + return right_value in left_value + elif operate == DictOperate.NOT_CONTAINS_KEY: + return right_value not in left_value + return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index 60b62d09fd66adbf32295f44ec86398a537f38d5..b95b166879608bc431525958e7e120ad93c5e5a3 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -1,12 +1,63 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""Choice Call的输入和输出""" +import uuid +from enum import Enum + +from pydantic import BaseModel, Field + +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.scheduler.call.core import DataBase +class Logic(str, Enum): + """Choice 工具支持的逻辑运算符""" + + AND = "and" + OR = "or" + + +class Value(DataBase): + """值的结构""" + + step_id: str | None = Field(description="步骤id", default=None) + type: Type | None = Field(description="值的类型", default=None) + value: str | float | int | bool | list | dict | None = Field(description="值", default=None) + + +class Condition(DataBase): + """单个条件""" + + left: Value = Field(description="左值", default=Value()) + right: Value = Field(description="右值", default=Value()) + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate | None = Field( + description="运算符", default=None) + id: str = Field(description="条件ID", default_factory=lambda: str(uuid.uuid4())) + + +class ChoiceBranch(DataBase): + """子分支""" + + branch_id: str = Field(description="分支ID", default_factory=lambda: str(uuid.uuid4())) + logic: Logic = Field(description="逻辑运算符", default=Logic.AND) + conditions: list[Condition] = Field(description="条件列表", default=[]) + is_default: bool = Field(description="是否为默认分支", default=True) + + class ChoiceInput(DataBase): """Choice Call的输入""" + choices: list[ChoiceBranch] = Field(description="分支", default=[]) + class ChoiceOutput(DataBase): """Choice Call的输出""" + + branch_id: str = Field(description="分支ID", default="") diff --git a/apps/scheduler/call/code/__init__.py b/apps/scheduler/call/code/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44da1c5078f9f6b32fec8cf3e23e6984aa1e8948 --- /dev/null +++ b/apps/scheduler/call/code/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""代码执行工具""" + +from apps.scheduler.call.code.code import Code + +__all__ = ["Code"] diff --git a/apps/scheduler/call/code/code.py b/apps/scheduler/call/code/code.py new file mode 100644 index 0000000000000000000000000000000000000000..d8ed7c01c415b3ede054c790bb195cd6d6048d32 --- /dev/null +++ b/apps/scheduler/call/code/code.py @@ -0,0 +1,318 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""代码执行工具""" + +import logging +from collections.abc import AsyncGenerator +from typing import Any +import httpx +from pydantic import Field +from apps.common.config import Config +from apps.scheduler.call.code.schema import CodeInput, CodeOutput +from apps.scheduler.call.core import CoreCall, NodeType +from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) + +logger = logging.getLogger(__name__) + + +class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): + """代码执行工具""" + + to_user: bool = Field(default=True) + + # 代码执行参数 + code: str = Field(description="要执行的代码", default="") + code_type: str = Field(description="代码类型,支持python、javascript、bash", default="python") + security_level: str = Field(description="安全等级,low或high", default="low") + timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300) + memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024) + cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0) + input_parameters: dict[str, Any] = Field(description="输入参数配置", default={}) + output_parameters: dict[str, Any] = Field(description="输出参数配置", default={}) + # 增加node_type + node_type: NodeType | None = Field(description="节点类型", default=NodeType.CODE) + + + + @classmethod + def info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo( + name="代码执行", + type=CallType.TRANSFORM, + description="在安全的沙箱环境中执行Python、JavaScript、Bash代码。" + ) + + + async def _init(self, 
call_vars: CallVars) -> CodeInput: + """初始化代码执行工具""" + # 构造用户信息 + user_info = { + "user_id": call_vars.ids.user_sub, + "username": call_vars.ids.user_sub, # 可以从其他地方获取真实用户名 + "permissions": ["execute"] + } + + # 处理输入参数 - 使用基类的变量解析功能 + input_arg = {} + if self.input_parameters: + # 解析每个输入参数 + for param_name, param_config in self.input_parameters.items(): + resolved_value = await self._resolve_variables_in_config(param_config, call_vars) + input_arg[param_name] = resolved_value + + return CodeInput( + code=self.code, + code_type=self.code_type, + user_info=user_info, + security_level=self.security_level, + timeout_seconds=self.timeout_seconds, + memory_limit_mb=self.memory_limit_mb, + cpu_limit=self.cpu_limit, + input_arg=input_arg, + ) + + + async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行代码""" + data = CodeInput(**input_data) + + try: + # 获取sandbox服务地址 + config = Config().get_config() + sandbox_url = config.sandbox.sandbox_service + + # 构造请求数据 + request_data = { + "code": data.code, + "code_type": data.code_type, + "user_info": data.user_info, + "security_level": data.security_level, + "timeout_seconds": data.timeout_seconds, + "memory_limit_mb": data.memory_limit_mb, + "cpu_limit": data.cpu_limit, + "input_arg": data.input_arg, + } + + # 发送执行请求 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + f"{sandbox_url.rstrip('/')}/execute", + json=request_data, + headers={"Content-Type": "application/json"} + ) + + if response.status_code != 200: + raise CallError( + message=f"代码执行服务请求失败: {response.status_code}", + data={"status_code": response.status_code, "response": response.text} + ) + + result = response.json() + logger.info(f"Sandbox service response: {result}") + + # 检查请求是否成功 + success = result.get("success", False) + message = result.get("message", "") + timestamp = result.get("timestamp", "") + + if not success: + raise CallError( + message=f"代码执行服务返回错误: {message}", + 
data={"response": result} + ) + + # 提取任务信息 + data = result.get("data", {}) + task_id = data.get("task_id", "") + estimated_wait_time = data.get("estimated_wait_time", 0) + queue_position = data.get("queue_position", 0) + + logger.info(f"Task submitted successfully - task_id: {task_id}, estimated_wait_time: {estimated_wait_time}s, queue_position: {queue_position}, timestamp: {timestamp}") + + # 轮询获取结果 + if task_id: + # 有task_id,需要轮询获取最终结果 + result = await self._wait_for_result(sandbox_url, task_id) + else: + # 没有task_id,可能是同步执行,直接使用初始响应 + # 但需要确保结果格式正确 + if "output" not in result and "error" not in result: + # 如果初始响应没有包含执行结果,可能是异步但没有返回task_id的错误情况 + result = { + "status": "error", + "error": "服务器没有返回task_id且没有执行结果", + "output": "", + "result": {} + } + + # 处理sandbox返回的结果,提取output_parameters指定的数据 + extracted_data = await self._process_sandbox_result(result) + + # 构建最终输出内容 + final_content = CodeOutput( + task_id=task_id, + status=result.get("status", "unknown"), + output=result.get("output") or "", + error=result.get("error") or "", + ).model_dump(by_alias=True, exclude_none=True) + + # 如果成功提取到数据,将其合并到输出中 + if extracted_data and result.get("status") == "completed": + final_content.update(extracted_data) + logger.info(f"[Code] 已将提取的数据合并到输出: {list(extracted_data.keys())}") + + # 返回最终结果 + yield CallOutputChunk( + type=CallOutputType.DATA, + content=final_content, + ) + + except httpx.TimeoutException: + raise CallError(message="代码执行超时", data={}) + except httpx.RequestError as e: + raise CallError(message=f"请求sandbox服务失败: {e!s}", data={}) + except Exception as e: + raise CallError(message=f"代码执行失败: {e!s}", data={}) + + + async def _process_sandbox_result(self, result: dict[str, Any]) -> dict[str, Any] | None: + """处理sandbox返回的结果,根据output_parameters提取数据""" + try: + # 检查是否有output_parameters配置 + if not hasattr(self, 'output_parameters') or not self.output_parameters: + logger.debug("[Code] 无output_parameters配置,跳过数据提取") + return None + + # 获取sandbox返回的output + sandbox_output 
= result.get("output") + if not sandbox_output: + logger.warning("[Code] sandbox返回的结果中没有output字段") + return None + + # 确保output是字典类型 + if isinstance(sandbox_output, str): + # 尝试解析JSON字符串 + try: + import json + sandbox_output = json.loads(sandbox_output) + except json.JSONDecodeError: + logger.warning(f"[Code] sandbox返回的output不是有效的JSON格式: {sandbox_output}") + return None + + if not isinstance(sandbox_output, dict): + logger.warning(f"[Code] sandbox返回的output不是字典类型: {type(sandbox_output)}") + return None + + # 根据output_parameters提取对应的kv对 + extracted_data = {} + for param_name, param_config in self.output_parameters.items(): + try: + # 支持多种提取方式 + if param_name in sandbox_output: + # 直接键匹配 + extracted_data[param_name] = sandbox_output[param_name] + elif isinstance(param_config, dict) and "path" in param_config: + # 路径提取 + path = param_config["path"] + value = self._extract_value_by_path(sandbox_output, path) + if value is not None: + extracted_data[param_name] = value + elif isinstance(param_config, dict) and param_config.get("source") == "full_output": + # 使用完整输出 + extracted_data[param_name] = sandbox_output + elif isinstance(param_config, dict) and "default" in param_config: + # 使用默认值 + extracted_data[param_name] = param_config["default"] + else: + logger.debug(f"[Code] 无法提取参数 {param_name},在output中未找到对应值") + + except Exception as e: + logger.warning(f"[Code] 提取参数 {param_name} 失败: {e}") + + if extracted_data: + logger.info(f"[Code] 成功提取 {len(extracted_data)} 个输出参数: {list(extracted_data.keys())}") + return extracted_data + else: + logger.debug("[Code] 未能提取到任何输出参数") + return None + + except Exception as e: + logger.error(f"[Code] 处理sandbox结果失败: {e}") + return None + + def _extract_value_by_path(self, data: dict, path: str) -> Any: + """根据路径提取值 (例如: 'result.data.value')""" + try: + current = data + for key in path.split('.'): + if isinstance(current, dict) and key in current: + current = current[key] + else: + return None + return current + except Exception: + return None 
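The dotted-path lookup in `_extract_value_by_path` is easy to verify in isolation. The sketch below re-implements it standalone; the function name and sample payload are illustrative and not part of the patch:

```python
from typing import Any


def extract_value_by_path(data: dict, path: str) -> Any:
    """Walk a nested dict using a dotted path such as 'result.data.value'."""
    current: Any = data
    for key in path.split("."):
        if isinstance(current, dict) and key in current:
            current = current[key]
        else:
            # Any missing segment aborts the lookup, mirroring the patch's behaviour
            return None
    return current


payload = {"result": {"data": {"value": 42}, "status": "completed"}}
print(extract_value_by_path(payload, "result.data.value"))  # 42
print(extract_value_by_path(payload, "result.missing"))     # None
```

Note that a `None` stored in the payload is indistinguishable from a missing key, which is why `_process_sandbox_result` only adds a parameter when the extracted value is not `None`.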
+ + + async def _wait_for_result(self, sandbox_url: str, task_id: str, max_attempts: int = 30) -> dict[str, Any]: + """等待任务执行完成""" + import asyncio + + async with httpx.AsyncClient(timeout=10.0) as client: + for _ in range(max_attempts): + try: + # 获取任务状态 + response = await client.get(f"{sandbox_url.rstrip('/')}/task/{task_id}/status") + if response.status_code == 200: + status_result = response.json() + logger.info(f"Task status response: {status_result}") + + # 检查响应是否成功 + success = status_result.get("success", False) + if not success: + message = status_result.get("message", "获取任务状态失败") + logger.warning(f"Failed to get task status: {message}") + await asyncio.sleep(1) + continue + + # 提取任务状态 + data = status_result.get("data", {}) + status = data.get("status", "") + + # 如果任务完成,获取结果 + if status in ["completed", "failed", "cancelled"]: + result_response = await client.get(f"{sandbox_url.rstrip('/')}/task/{task_id}/result") + if result_response.status_code == 200: + result_data = result_response.json() + logger.info(f"Task result response: {result_data}") + + # 检查获取结果是否成功 + if result_data.get("success", False): + # 返回实际的执行结果 + return result_data.get("data", {}) + else: + return {"status": status, "error": result_data.get("message", "获取结果失败")} + else: + return {"status": status, "error": "无法获取结果"} + + # 如果任务仍在运行,继续等待 + if status in ["pending", "running"]: + logger.debug(f"Task {task_id} still {status}, waiting...") + await asyncio.sleep(1) + continue + + # 其他状态或请求失败,等待后重试 + await asyncio.sleep(1) + + except Exception: + # 请求异常,等待后重试 + await asyncio.sleep(1) + + # 超时返回 + return {"status": "timeout", "error": "等待任务完成超时"} \ No newline at end of file diff --git a/apps/scheduler/call/code/schema.py b/apps/scheduler/call/code/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..93648a811a34983a81c9a644067115efcfa68ec5 --- /dev/null +++ b/apps/scheduler/call/code/schema.py @@ -0,0 +1,30 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. 
All rights reserved.
+"""代码执行工具的数据结构"""
+
+from typing import Any
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class CodeInput(DataBase):
+    """代码执行工具的输入"""
+
+    code: str = Field(description="要执行的代码")
+    code_type: str = Field(description="代码类型,支持python、javascript、bash")
+    user_info: dict[str, Any] = Field(description="用户信息", default={})
+    security_level: str = Field(description="安全等级,low或high", default="low")
+    timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300)
+    memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024)
+    cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0)
+    input_arg: dict[str, Any] = Field(description="传递给main函数的输入参数", default={})
+
+
+class CodeOutput(DataBase):
+    """代码执行工具的输出"""
+
+    task_id: str = Field(description="任务ID")
+    status: str = Field(description="任务状态")
+    output: str = Field(description="执行输出", default="")
+    error: str = Field(description="错误信息", default="")
\ No newline at end of file
diff --git a/apps/scheduler/call/convert/convert.py b/apps/scheduler/call/convert/convert.py
index 27980bd8ad46aaaa7741d6a6919661aac580de64..bbe0dbe80217ba5e24c629fd02279086e08230fc 100644
--- a/apps/scheduler/call/convert/convert.py
+++ b/apps/scheduler/call/convert/convert.py
@@ -12,7 +12,7 @@ from pydantic import Field
 from apps.scheduler.call.convert.schema import ConvertInput, ConvertOutput
 from apps.scheduler.call.core import CallOutputChunk, CoreCall
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
 from apps.schemas.scheduler import (
     CallInfo,
     CallOutputChunk,
@@ -30,7 +30,11 @@ class Convert(CoreCall, input_model=ConvertInput, output_model=ConvertOutput):
     @classmethod
     def info(cls) -> CallInfo:
         """返回Call的名称和描述"""
-        return CallInfo(name="模板转换", description="使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。")
+        return CallInfo(
+            name="模板转换",
+            type=CallType.TRANSFORM,
+
description="使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。" + ) async def _init(self, call_vars: CallVars) -> ConvertInput: """初始化工具""" diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 2b1cbba83b9345e8df84e8f23f893137cb46532b..f6ae08661a0ece5d1f50753e64ec7ad62c634be9 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -6,6 +6,7 @@ Core Call类是定义了所有Call都应具有的方法和参数的PyDantic类 """ import logging +import re from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any, ClassVar, Self @@ -14,6 +15,7 @@ from pydantic.json_schema import SkipJsonSchema from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM +from apps.scheduler.variable.integration import VariableIntegration from apps.schemas.enum_var import CallOutputType from apps.schemas.pool import NodePool from apps.schemas.scheduler import ( @@ -25,6 +27,7 @@ from apps.schemas.scheduler import ( CallVars, ) from apps.schemas.task import FlowStepHistory +from enum import Enum if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor @@ -32,6 +35,23 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class NodeType(str, Enum): + + """ + 节点类型 用来为前后端判断节点类型 + 仅增加了当前在修改的几个类型 + """ + + LLM = "llm" + API = "api" + RAG = 'rag' + MCP = 'mcp' + CHOICE = "choice" + CODE = 'code' + DIRECTREPLY = "directreply" + + + class DataBase(BaseModel): """所有Call的输入基类""" @@ -70,12 +90,15 @@ class CoreCall(BaseModel): ) to_user: bool = Field(description="是否需要将输出返回给用户", default=False) + enable_variable_resolution: bool = Field(description="是否启用自动变量解析", default=True) model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", ) + node_type: NodeType | None = Field(description="节点类型,只包含已改造后的节点", default=None) + def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None: """初始化子类""" @@ -83,14 +106,12 @@ class CoreCall(BaseModel): cls.input_model = input_model 
cls.output_model = output_model - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" err = "[CoreCall] 必须手动实现info方法" raise NotImplementedError(err) - @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" @@ -111,6 +132,7 @@ class CoreCall(BaseModel): task_id=executor.task.id, flow_id=executor.task.state.flow_id, session_id=executor.task.ids.session_id, + conversation_id=executor.task.ids.conversation_id, user_sub=executor.task.ids.user_sub, app_id=executor.task.state.app_id, ), @@ -120,7 +142,6 @@ class CoreCall(BaseModel): summary=executor.task.runtime.summary, ) - @staticmethod def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any: """ @@ -131,18 +152,16 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") + if len(split_path) < 2: + err = f"[CoreCall] 路径格式错误: {path}" + logger.error(err) + return None if split_path[0] not in history: err = f"[CoreCall] 步骤{split_path[0]}不存在" logger.error(err) - raise CallError( - message=err, - data={ - "step_id": split_path[0], - }, - ) - + return None data = history[split_path[0]].output_data - for key in split_path[1:]: + for key in split_path[2:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" logger.error(err) @@ -156,13 +175,27 @@ class CoreCall(BaseModel): data = data[key] return data - @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """实例化Call类""" + if 'name' in kwargs: + # 如果kwargs中已经有name参数,使用kwargs中的值 + name = kwargs.pop('name') + else: + # 否则使用executor.step.step.name + name = executor.step.step.name + + # 检查kwargs中是否已经包含description参数,避免重复传入 + if 'description' in kwargs: + # 如果kwargs中已经有description参数,使用kwargs中的值 + description = kwargs.pop('description') + else: + # 否则使用executor.step.step.description + description = executor.step.step.description + obj = cls( - name=executor.step.step.name, - description=executor.step.step.description, + name=name, + 
description=description, node=node, **kwargs, ) @@ -170,35 +203,206 @@ class CoreCall(BaseModel): await obj._set_input(executor) return obj + async def _initialize_variable_context(self, call_vars: CallVars) -> dict[str, Any]: + """初始化变量解析上下文并初始化系统变量""" + context = { + "question": call_vars.question, + "user_sub": call_vars.ids.user_sub, + "flow_id": call_vars.ids.flow_id, + "session_id": call_vars.ids.session_id, + "app_id": call_vars.ids.app_id, + "conversation_id": call_vars.ids.conversation_id, + } + + await VariableIntegration.initialize_system_variables(context) + return context + + async def _resolve_variables_in_config(self, config: Any, call_vars: CallVars) -> Any: + """解析配置中的变量引用 + + Args: + config: 配置值,可能包含变量引用 + call_vars: Call变量 + + Returns: + 解析后的配置值 + """ + if isinstance(config, dict): + if "reference" in config: + # 解析变量引用 + resolved_value = await VariableIntegration.resolve_variable_reference( + config["reference"], + user_sub=call_vars.ids.user_sub, + flow_id=call_vars.ids.flow_id, + conversation_id=call_vars.ids.conversation_id + ) + return resolved_value + elif "value" in config: + # 使用默认值 + return config["value"] + else: + # 递归解析字典中的所有值 + resolved_dict = {} + for key, value in config.items(): + resolved_dict[key] = await self._resolve_variables_in_config(value, call_vars) + return resolved_dict + elif isinstance(config, list): + # 递归解析列表中的所有值 + resolved_list = [] + for item in config: + resolved_item = await self._resolve_variables_in_config(item, call_vars) + resolved_list.append(resolved_item) + return resolved_list + elif isinstance(config, str): + # 解析字符串中的变量引用 + return await self._resolve_variables_in_text(config, call_vars) + else: + # 直接返回配置值 + return config + + async def _resolve_variables_in_text(self, text: str, call_vars: CallVars) -> str: + """解析文本中的变量引用({{...}} 语法) + + Args: + text: 包含变量引用的文本 + call_vars: Call变量 + + Returns: + 解析后的文本 + """ + if not isinstance(text, str): + return text + + # 检查是否包含变量引用语法 + if not 
re.search(r'\{\{.*?\}\}', text): + return text + + # 提取所有变量引用并逐一解析替换 + variable_pattern = r'\{\{(.*?)\}\}' + matches = re.findall(variable_pattern, text) + + resolved_text = text + for match in matches: + try: + # 解析变量引用 + resolved_value = await VariableIntegration.resolve_variable_reference( + match.strip(), + user_sub=call_vars.ids.user_sub, + flow_id=call_vars.ids.flow_id, + conversation_id=call_vars.ids.conversation_id + ) + # 替换原始文本中的变量引用 + resolved_text = resolved_text.replace(f'{{{{{match}}}}}', str(resolved_value)) + except Exception as e: + logger.warning(f"[CoreCall] 解析变量引用 '{match}' 失败: {e}") + # 如果解析失败,保留原始的变量引用 + continue + + return resolved_text + async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) + self._step_id = executor.step.step_id # 存储 step_id 用于变量名构造 + + # 如果启用了变量解析,初始化变量上下文 + if self.enable_variable_resolution: + await self._initialize_variable_context(self._sys_vars) + input_data = await self._init(self._sys_vars) self.input = input_data.model_dump(by_alias=True, exclude_none=True) - async def _init(self, call_vars: CallVars) -> DataBase: """初始化Call类,并返回Call的输入""" err = "[CoreCall] 初始化方法必须手动实现" raise NotImplementedError(err) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的流式输出方法""" yield CallOutputChunk(type=CallOutputType.TEXT, content="") - async def _after_exec(self, input_data: dict[str, Any]) -> None: """Call类实例的执行后方法""" + # 自动保存 extracted_data 到对话变量池 + await self._save_extracted_data_to_variables() + + async def _save_extracted_data_to_variables(self) -> None: + """将 extracted_data 保存到对话变量池""" + try: + # 检查是否有 output_parameters 配置 + if not hasattr(self, 'output_parameters') or not self.output_parameters: + return + + # 检查是否有输出数据 + if not hasattr(self, 'input') or not self.input: + return + + # 获取当前的输出数据(从最近的执行结果中) + # 注意:这里假设 extracted_data 已经被合并到输出中 + output_data = getattr(self, '_last_output_data', 
{}) + if not output_data or not isinstance(output_data, dict): + return + + # 确定变量名前缀 + from apps.scheduler.executor.step_config import should_use_direct_conversation_format + use_direct_format = should_use_direct_conversation_format( + call_id=getattr(self, '__class__').__name__.lower(), + step_name=self.name, + step_id=getattr(self, '_step_id', self.name) + ) + + var_prefix = "" if use_direct_format else f"{self._step_id}." + + # 保存每个 output_parameter 到变量池 + saved_count = 0 + for param_name, param_config in self.output_parameters.items(): + try: + # 检查输出数据中是否包含该参数 + if param_name in output_data: + param_value = output_data[param_name] + + # 构造变量名 + var_name = f"{var_prefix}{param_name}" + + # 保存到对话变量池 + success = await VariableIntegration.save_conversation_variable( + var_name=var_name, + value=param_value, + var_type=param_config.get("type", "string") if isinstance(param_config, dict) else "string", + description=param_config.get("description", "") if isinstance(param_config, dict) else "", + user_sub=self._sys_vars.ids.user_sub, + flow_id=self._sys_vars.ids.flow_id, + conversation_id=self._sys_vars.ids.conversation_id + ) + + if success: + saved_count += 1 + logger.debug(f"[CoreCall] 已保存提取数据到变量池: conversation.{var_name} = {param_value}") + else: + logger.warning(f"[CoreCall] 保存提取数据变量失败: {var_name}") + + except Exception as e: + logger.warning(f"[CoreCall] 保存提取数据 {param_name} 失败: {e}") + + if saved_count > 0: + logger.info(f"[CoreCall] 已保存 {saved_count} 个提取数据到对话变量池") + + except Exception as e: + logger.error(f"[CoreCall] 保存提取数据到变量池失败: {e}") async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" + self._last_output_data = {} # 初始化输出数据存储 + async for chunk in self._exec(input_data): + # 捕获最后的输出数据 + if chunk.type == CallOutputType.DATA and isinstance(chunk.content, dict): + self._last_output_data = chunk.content yield chunk - await self._after_exec(input_data) + await 
self._after_exec(input_data) async def _llm(self, messages: list[dict[str, Any]]) -> str: """Call可直接使用的LLM非流式调用""" @@ -210,9 +414,11 @@ class CoreCall(BaseModel): self.output_tokens = llm.output_tokens return result - async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel: """Call可直接使用的JSON生成""" json = FunctionLLM() result = await json.call(messages=messages, schema=schema.model_json_schema()) return schema.model_validate(result) + + + diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py index 5865bc7e804491a7d6aa41fa251dc8c5d9c77dc6..a66aac5251ab580942c583576e9ec9bba3204951 100644 --- a/apps/scheduler/call/empty.py +++ b/apps/scheduler/call/empty.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator from typing import Any from apps.scheduler.call.core import CoreCall, DataBase -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars @@ -20,7 +20,11 @@ class Empty(CoreCall, input_model=DataBase, output_model=DataBase): :return: Call的名称和描述 :rtype: CallInfo """ - return CallInfo(name="空白", description="空白节点,用于占位") + return CallInfo( + name="空白", + type=CallType.DEFAULT, + description="空白节点,用于占位" + ) async def _init(self, call_vars: CallVars) -> DataBase: diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index f8aebcd748d92de4109553d0f68fecac43363f58..10241d85a97d3f2ea70a8a93a1a5a72173e4cfdd 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -16,7 +16,7 @@ from apps.scheduler.call.facts.schema import ( FactsInput, FactsOutput, ) -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars from apps.services.user_domain import UserDomainManager @@ -34,7 
+34,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。") + return CallInfo( + name="提取事实", + type=CallType.DEFAULT, + description="从对话上下文和文档片段中提取事实。" + ) @classmethod diff --git a/apps/scheduler/call/graph/graph.py b/apps/scheduler/call/graph/graph.py index c2728f17913fcd0e8343168f2b508dbe6006fd6e..7383b2f2c85aacf78db11b00f95f22c2996f071f 100644 --- a/apps/scheduler/call/graph/graph.py +++ b/apps/scheduler/call/graph/graph.py @@ -11,7 +11,7 @@ from pydantic import Field from apps.scheduler.call.core import CoreCall from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput from apps.scheduler.call.graph.style import RenderStyle -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -29,7 +29,11 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="图表", description="将SQL查询出的数据转换为图表") + return CallInfo( + name="图表", + type=CallType.TRANSFORM, + description="将SQL查询出的数据转换为图表" + ) async def _init(self, call_vars: CallVars) -> RenderInput: diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 6a679dce98af6164211edd16b2fa38714d899f31..23098d9dac841008cb295400863364f727aada06 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -12,20 +12,25 @@ from jinja2.sandbox import SandboxedEnvironment from pydantic import Field from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.core import CoreCall, NodeType from apps.scheduler.call.llm.prompt import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT from apps.scheduler.call.llm.schema import LLMInput, LLMOutput -from 
apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, CallOutputChunk, CallVars, ) +from apps.schemas.task import FlowStepHistory +from apps.services.llm import LLMManager +from apps.schemas.config import LLMConfig + logger = logging.getLogger(__name__) + class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): """大模型调用工具""" @@ -37,12 +42,20 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3, ge=1, le=10) system_prompt: str = Field(description="大模型系统提示词", default="You are a helpful assistant.") user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT) + frequency_penalty: float = Field(description="频率惩罚大模型生成文本时用于减少重复词汇出现的参数", default=0) + llmId: str = Field(description="指定使用的大模型ID", default="empty") + # 增加node_type + node_type: NodeType | None = Field(description="节点类型", default=NodeType.LLM) @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="大模型", description="以指定的提示词和上下文信息调用大模型,并获得输出。") + return CallInfo( + name="大模型", + type=CallType.DEFAULT, + description="以指定的提示词和上下文信息调用大模型,并获得输出。" + ) async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: @@ -80,7 +93,7 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): try: # 准备系统提示词 system_tmpl = env.from_string(self.system_prompt) - system_input = system_tmpl.render(**formatter) + system_input = system_tmpl.render() # 准备用户提示词 user_tmpl = env.from_string(self.user_prompt) @@ -106,6 +119,8 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): data = LLMInput(**input_data) try: llm = ReasoningLLM() + + async for chunk in llm.call(messages=data.message): if not chunk: continue @@ -113,4 +128,4 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): self.tokens.input_tokens = llm.input_tokens 
self.tokens.output_tokens = llm.output_tokens except Exception as e: - raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e + raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e \ No newline at end of file diff --git a/apps/scheduler/call/llm/prompt.py b/apps/scheduler/call/llm/prompt.py index 0f227dcaa618b11a2c888f55a61fa51f349d7d8b..5701f7b3ff3e7783ae1e768c5003247c2fc42030 100644 --- a/apps/scheduler/call/llm/prompt.py +++ b/apps/scheduler/call/llm/prompt.py @@ -72,12 +72,12 @@ LLM_ERROR_PROMPT = dedent( RAG_ANSWER_PROMPT = dedent( r""" - 你是由openEuler社区构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。 + 你是由华鲲振宇构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。 用户的问题将在中给出,上下文背景信息将在中给出,文档片段将在中给出。 注意事项: 1. 输出不要包含任何XML标签。请确保输出内容的正确性,不要编造任何信息。 - 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫EulerCopilot,是openEuler社区的智能助手”。 + 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫huakun-copilot,是华鲲振宇的智能助手”。 3. 背景信息仅供参考,若背景信息与用户问题无关,请忽略背景信息直接作答。 4. 请在回答中使用Markdown格式,并**不要**将内容放在"```"中。 diff --git a/apps/scheduler/call/llm/schema.py b/apps/scheduler/call/llm/schema.py index c7bb50541d168a406fa478f25894c769b034ec9c..b418f96ad2652c87bad1542547ac83d8783a4d19 100644 --- a/apps/scheduler/call/llm/schema.py +++ b/apps/scheduler/call/llm/schema.py @@ -12,5 +12,8 @@ class LLMInput(DataBase): message: list[dict[str, str]] = Field(description="输入给大模型的消息列表") + + class LLMOutput(DataBase): """定义LLM工具调用的输出""" + text: str = Field(description="大模型输出") diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 661e9ada74da76e6d78a0f6cba42cef08c562335..d57102fc169e602fbf23136f664ee235d01e0034 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -8,7 +8,7 @@ from typing import Any from pydantic import Field -from apps.scheduler.call.core import CallError, CoreCall +from apps.scheduler.call.core import CallError, CoreCall, NodeType from apps.scheduler.call.mcp.schema import ( MCPInput, MCPMessage, @@ -16,7 +16,7 @@ from apps.scheduler.call.mcp.schema import ( MCPOutput, ) from 
apps.scheduler.mcp import MCPHost, MCPPlanner, MCPSelector -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.mcp import MCPPlanItem from apps.schemas.scheduler import ( CallInfo, @@ -34,6 +34,9 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): max_steps: int = Field(description="最大步骤数", default=6) text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) + # 增加node_type + node_type: NodeType | None = Field(description="节点类型", default=NodeType.MCP) + @classmethod @@ -44,7 +47,11 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): :return: Call的名称和描述 :rtype: CallInfo """ - return CallInfo(name="MCP", description="调用MCP Server,执行工具") + return CallInfo( + name="MCP", + type=CallType.DEFAULT, + description="调用MCP Server,执行工具" + ) async def _init(self, call_vars: CallVars) -> MCPInput: @@ -63,7 +70,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行MCP""" # 生成计划 @@ -80,7 +86,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): async for chunk in self._generate_answer(): yield chunk - async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]: """生成执行计划""" # 开始提示 @@ -89,6 +94,7 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): # 选择工具并生成计划 selector = MCPSelector() top_tool = await selector.select_top_tool(self._call_vars.question, self.mcp_list) + logger.info("[MCP] 选择到的工具: %s", top_tool) planner = MCPPlanner(self._call_vars.question) self._plan = await planner.create_plan(top_tool, self.max_steps) @@ -103,7 +109,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): data=self._plan.model_dump(), ) - async def
_execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]: """执行单个计划项""" # 判断是否为Final @@ -141,7 +146,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): }, ) - async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]: """生成总结""" # 提示开始总结 @@ -163,7 +167,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): ).model_dump(), ) - def _create_output( self, text: str, diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index e27327d8ad4d01387eeeb4f1644c056d32bc0dbe..ad9c52d46c98da31c07f94bda8cbabcf7c865ad3 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -11,9 +11,9 @@ from pydantic import Field from apps.common.config import Config from apps.llm.patterns.rewrite import QuestionRewrite -from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.core import CoreCall, NodeType from apps.scheduler.call.rag.schema import RAGInput, RAGOutput, SearchMethod -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -36,11 +36,18 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): is_rerank: bool = Field(description="是否重新排序", default=False) is_compress: bool = Field(description="是否压缩", default=False) tokens_limit: int = Field(description="token限制", default=8192) + # 增加node_type + node_type: NodeType | None = Field(description="节点类型", default=NodeType.RAG) + @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="知识库", description="查询知识库,从文档中获取必要信息") + return CallInfo( + name="知识库", + type=CallType.DEFAULT, + description="查询知识库,从文档中获取必要信息" + ) async def _init(self, call_vars: CallVars) -> RAGInput: """初始化RAG工具""" diff --git a/apps/scheduler/call/reply/__init__.py b/apps/scheduler/call/reply/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..5588bcc298b338a542312a22843494eb8fe1d890 --- /dev/null +++ b/apps/scheduler/call/reply/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Reply工具模块""" + +from apps.scheduler.call.reply.direct_reply import DirectReply + +__all__ = [ + "DirectReply", +] diff --git a/apps/scheduler/call/reply/direct_reply.py b/apps/scheduler/call/reply/direct_reply.py new file mode 100644 index 0000000000000000000000000000000000000000..1c06911021cfc7556e995d485fd7ad96d3847130 --- /dev/null +++ b/apps/scheduler/call/reply/direct_reply.py @@ -0,0 +1,67 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""直接回复工具""" + +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from pydantic import Field + +from apps.scheduler.call.core import CoreCall, NodeType +from apps.scheduler.call.reply.schema import DirectReplyInput, DirectReplyOutput +from apps.schemas.enum_var import CallOutputType, CallType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) + +logger = logging.getLogger(__name__) + + +class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectReplyOutput): + """直接回复工具,支持变量引用语法""" + + to_user: bool = Field(default=True) + # 增加node_type + node_type: NodeType | None = Field(description="节点类型", default=NodeType.DIRECTREPLY) + + + @classmethod + def info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo( + name="直接回复", + type=CallType.DEFAULT, + description="直接回复用户输入的内容,支持变量插入" + ) + + async def _init(self, call_vars: CallVars) -> DirectReplyInput: + """初始化DirectReply工具""" + answer = getattr(self, 'answer', '') + return DirectReplyInput(answer=answer) + + async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行直接回复""" + data = DirectReplyInput(**input_data) + + try: + # 使用基类的变量解析功能处理文本中的变量引用 + final_answer = 
await self._resolve_variables_in_text(data.answer, self._sys_vars) + + logger.info(f"[DirectReply] 原始答案: {data.answer}") + logger.info(f"[DirectReply] 解析后答案: {final_answer}") + + # 直接返回处理后的内容 + yield CallOutputChunk( + type=CallOutputType.TEXT, + content=final_answer + ) + + except Exception as e: + logger.error(f"[DirectReply] 处理回复内容失败: {e}") + raise CallError( + message=f"直接回复处理失败:{e!s}", + data={"original_answer": data.answer} + ) from e diff --git a/apps/scheduler/call/reply/schema.py b/apps/scheduler/call/reply/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..f13dc636284dc6da764aa73fc62e99405674febe --- /dev/null +++ b/apps/scheduler/call/reply/schema.py @@ -0,0 +1,18 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""DirectReply工具的输入输出定义""" + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class DirectReplyInput(DataBase): + """定义DirectReply工具调用的输入""" + + answer: str = Field(description="直接回复的内容,支持变量引用语法") + + +class DirectReplyOutput(DataBase): + """定义DirectReply工具调用的输出""" + + message: str = Field(description="处理后的回复消息") \ No newline at end of file diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 4f8e1010cc0bd88f22e050778bce236d7d3515e0..69c2bfbcd94c04adab5396a9f27763633715e182 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -15,7 +15,7 @@ from apps.scheduler.call.core import CoreCall from apps.scheduler.call.slot.prompt import SLOT_GEN_PROMPT from apps.scheduler.call.slot.schema import SlotInput, SlotOutput from apps.scheduler.slot.slot import Slot as SlotProcessor -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars @@ -36,7 +36,11 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): @classmethod 
def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="参数自动填充", description="根据步骤历史,自动填充参数") + return CallInfo( + name="参数自动填充", + type=CallType.TRANSFORM, + description="根据步骤历史,自动填充参数" + ) async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]: diff --git a/apps/scheduler/call/sql/sql.py b/apps/scheduler/call/sql/sql.py index 3e24301de508e06adf5cfdbf24b3d8ca37c0cc27..2f4ef97fb095baeac0b89ceb2dff1c98338fed18 100644 --- a/apps/scheduler/call/sql/sql.py +++ b/apps/scheduler/call/sql/sql.py @@ -12,7 +12,7 @@ from pydantic import Field from apps.common.config import Config from apps.scheduler.call.core import CoreCall from apps.scheduler.call.sql.schema import SQLInput, SQLOutput -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -35,7 +35,11 @@ class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="SQL查询", description="使用大模型生成SQL语句,用于查询数据库中的结构化数据") + return CallInfo( + name="SQL查询", + type=CallType.TOOL, + description="使用大模型生成SQL语句,用于查询数据库中的结构化数据" + ) async def _init(self, call_vars: CallVars) -> SQLInput: diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index 1788fa0f4a8ede3af38264c9bb4a82628018086b..663e1d9c95d3faae2b51d75de0a7bab3c54ca66c 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -20,7 +20,7 @@ from apps.scheduler.call.suggest.schema import ( SuggestionInput, SuggestionOutput, ) -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.record import RecordContent from apps.schemas.scheduler import ( @@ -50,7 +50,11 @@ class Suggestion(CoreCall, input_model=SuggestionInput, 
output_model=SuggestionO @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="问题推荐", description="在答案下方显示推荐的下一个问题") + return CallInfo( + name="问题推荐", + type=CallType.DEFAULT, + description="在答案下方显示推荐的下一个问题" + ) @classmethod diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py index b605204e179246f915d561c0884fd277712faf33..7f6ff062d522efa25af8a5fdc5ceec38d7ba967d 100644 --- a/apps/scheduler/call/summary/summary.py +++ b/apps/scheduler/call/summary/summary.py @@ -9,7 +9,7 @@ from pydantic import Field from apps.llm.patterns.executor import ExecutorSummary from apps.scheduler.call.core import CoreCall, DataBase from apps.scheduler.call.summary.schema import SummaryOutput -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.scheduler import ( CallInfo, @@ -31,7 +31,11 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="理解上下文", description="使用大模型,理解对话上下文") + return CallInfo( + name="理解上下文", + type=CallType.DEFAULT, + description="使用大模型,理解对话上下文" + ) @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f6814dd364636d304382dc258f87fdd9f06eb9a8..2ff4f3d39633b26d792b0b2c2a257107b4874d30 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -7,6 +7,8 @@ from pydantic import Field from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.agent.mcp import MCPAgent +from apps.schemas.task import ExecutorState, StepQueueItem +from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -15,26 +17,14 @@ class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" question: 
str = Field(description="用户输入") - max_steps: int = Field(default=10, description="最大步数") + max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - async def run(self) -> None: - """运行MCP Agent""" - agent = await MCPAgent.create( - servers_id=self.servers_id, - max_steps=self.max_steps, - task=self.task, - msg_queue=self.msg_queue, - question=self.question, - agent_id=self.agent_id, - description=self.agent_description, - ) - - try: - answer = await agent.run(self.question) - self.task = agent.task - self.task.runtime.answer = answer - except Exception as e: - logger.error(f"Error: {str(e)}") + async def load_state(self) -> None: + """从数据库中加载MCPAgentExecutor的状态""" + logger.info("[MCPAgentExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a70d0d7073c50ec2290c3fd4c797bbb3efcd4501..a86ec4ac063c3cda5c6444ae17d7af5a6f6e7f46 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -47,7 +47,6 @@ class FlowExecutor(BaseExecutor): question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") @@ -70,7 +69,6 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: """单一Step执行""" # 创建步骤Runner @@ -90,7 +88,6 @@ class FlowExecutor(BaseExecutor): # 更新Task(已存过库) self.task = step_runner.task - async def _step_process(self) -> None: """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: @@ -102,7 +99,6 @@ class FlowExecutor(BaseExecutor): #
执行Step await self._invoke_runner(queue_item) - async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" next_ids = [] @@ -111,14 +107,22 @@ class FlowExecutor(BaseExecutor): next_ids += [edge.edge_to] return next_ids - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] - - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + if self.task.state.step_name == "Choice": + # 如果是choice节点,获取分支ID + branch_id = self.task.context[-1]["output_data"]["branch_id"] + if branch_id: + self.task.state.step_id = self.task.state.step_id + "." + branch_id + logger.info("[FlowExecutor] 分支ID:%s", branch_id) + else: + logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") + return [] + + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -137,7 +141,6 @@ class FlowExecutor(BaseExecutor): for next_step in next_steps ] - async def run(self) -> None: """ 运行流,返回各步骤结果,直到无法继续执行 @@ -150,8 +153,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -170,7 +173,7 @@ class FlowExecutor(BaseExecutor): # 运行Flow(未达终点) while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() 
self.step_queue.appendleft(StepQueueItem( @@ -183,7 +186,7 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.task.state.error_info["err_msg"], # type: ignore[arg-type] ), }, ), diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 6b3451fa9ccd92b8f9b2d497f3983a64ff3a6981..5bf37392a4f599e464b00b10d47a3f622fe9289c 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -13,11 +13,14 @@ from pydantic import ConfigDict from apps.scheduler.call.core import CoreCall from apps.scheduler.call.empty import Empty from apps.scheduler.call.facts.facts import FactsCall +from apps.scheduler.call.reply.direct_reply import DirectReply from apps.scheduler.call.slot.schema import SlotOutput from apps.scheduler.call.slot.slot import Slot from apps.scheduler.call.summary.summary import Summary from apps.scheduler.executor.base import BaseExecutor +from apps.scheduler.executor.step_config import should_use_direct_conversation_format from apps.scheduler.pool.pool import Pool +from apps.scheduler.variable.integration import VariableIntegration from apps.schemas.enum_var import ( EventType, SpecialCallType, @@ -70,6 +73,8 @@ class StepExecutor(BaseExecutor): return FactsCall if call_id == SpecialCallType.SLOT.value: return Slot + if call_id == SpecialCallType.DIRECT_REPLY.value: + return DirectReply # 从Pool中获取对应的Call call_cls: type[CoreCall] = await Pool().get_call(call_id) @@ -86,8 +91,8 @@ class StepExecutor(BaseExecutor): logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) # State写入ID和运行状态 - self.task.state.step_id = self.step.step_id # type: ignore[arg-type] - self.task.state.step_name = self.step.step.name # type: ignore[arg-type] + self.task.state.step_id = self.step.step_id # type: ignore[arg-type] + self.task.state.step_name = self.step.step.name # type: ignore[arg-type] # 获取并验证Call类 
node_id = self.step.step.node @@ -119,7 +124,6 @@ class StepExecutor(BaseExecutor): logger.exception("[StepExecutor] 初始化Call失败") raise - async def _run_slot_filling(self) -> None: """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断是否需要进行自动参数填充 @@ -127,13 +131,13 @@ class StepExecutor(BaseExecutor): return # 暂存旧数据 - current_step_id = self.task.state.step_id # type: ignore[arg-type] - current_step_name = self.task.state.step_name # type: ignore[arg-type] + current_step_id = self.task.state.step_id # type: ignore[arg-type] + current_step_name = self.task.state.step_name # type: ignore[arg-type] # 更新State - self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] - self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] + self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,21 +160,20 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.task.state.step_id = current_step_id # type: ignore[arg-type] - self.task.state.step_name = current_step_name # type: ignore[arg-type] + self.task.state.step_id = current_step_id # type: ignore[arg-type] + self.task.state.step_name = current_step_name # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens 
self.task.tokens.output_tokens += self.obj.tokens.output_tokens - async def _process_chunk( self, iterator: AsyncGenerator[CallOutputChunk, None], @@ -203,43 +206,149 @@ class StepExecutor(BaseExecutor): return content + async def _save_output_parameters_to_variables(self, output_data: str | dict[str, Any]) -> None: + """保存节点输出参数到变量池""" + try: + # 检查是否有output_parameters配置 + output_parameters = None + if self.step.step.params and isinstance(self.step.step.params, dict): + output_parameters = self.step.step.params.get("output_parameters", {}) + + if not output_parameters or not isinstance(output_parameters, dict): + return + + # 确保output_data是字典格式 + if isinstance(output_data, str): + # 如果是字符串,包装成字典 + data_dict = {"text": output_data} + else: + data_dict = output_data if isinstance(output_data, dict) else {} + + # 确定变量名前缀(根据配置决定是否使用直接格式) + use_direct_format = should_use_direct_conversation_format( + call_id=self._call_id, + step_name=self.step.step.name, + step_id=self.step.step_id + ) + + if use_direct_format: + # 配置允许的节点类型保持原有格式:conversation.key + var_prefix = "" + logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用直接变量格式") + else: + # 其他节点使用格式:conversation.node_id.key + var_prefix = f"{self.step.step_id}." 
+ logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用带前缀变量格式") + + # 保存每个output_parameter到变量池 + saved_count = 0 + for param_name, param_config in output_parameters.items(): + try: + # 获取参数值 + param_value = self._extract_value_from_output_data(param_name, data_dict, param_config) + + if param_value is not None: + # 构造变量名 + var_name = f"{var_prefix}{param_name}" + + # 保存到对话变量池 + success = await VariableIntegration.save_conversation_variable( + var_name=var_name, + value=param_value, + var_type=param_config.get("type", "string"), + description=param_config.get("description", ""), + user_sub=self.task.ids.user_sub, + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + conversation_id=self.task.ids.conversation_id + ) + + if success: + saved_count += 1 + logger.debug(f"[StepExecutor] 已保存输出参数变量: conversation.{var_name} = {param_value}") + else: + logger.warning(f"[StepExecutor] 保存输出参数变量失败: {var_name}") + + except Exception as e: + logger.warning(f"[StepExecutor] 保存输出参数 {param_name} 失败: {e}") + + if saved_count > 0: + logger.info(f"[StepExecutor] 已保存 {saved_count} 个输出参数到变量池") + + except Exception as e: + logger.error(f"[StepExecutor] 保存输出参数到变量池失败: {e}") + + def _extract_value_from_output_data(self, param_name: str, output_data: dict[str, Any], param_config: dict) -> Any: + """从输出数据中提取参数值""" + # 支持多种提取方式 + + # 1. 直接从输出数据中获取同名key + if param_name in output_data: + return output_data[param_name] + + # 2. 支持路径提取(例如:result.data.value) + if "path" in param_config: + path = param_config["path"] + current_data = output_data + for key in path.split("."): + if isinstance(current_data, dict) and key in current_data: + current_data = current_data[key] + else: + return None + return current_data + + # 3. 支持默认值 + if "default" in param_config: + return param_config["default"] + + # 4. 
如果参数配置为"full_output",返回完整输出 + if param_config.get("source") == "full_output": + return output_data + + return None + + async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name) + # 进行自动参数填充 await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) # 执行步骤 iterator = self.obj.exec(self, self.obj.input) try: content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": e.message, "data": e.data, } else: - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": str(e), "data": {}, } return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -250,19 +359,22 @@ class StepExecutor(BaseExecutor): else: output_data = content + # 保存output_parameters到变量池 + await self._save_output_parameters_to_variables(output_data) + #
更新context history = FlowStepHistory( task_id=self.task.id, - flow_id=self.task.state.flow_id, # type: ignore[arg-type] - flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + flow_name=self.task.state.flow_name, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + status=self.task.state.status, # type: ignore[arg-type] input_data=self.obj.input, output_data=output_data, ) - self.task.context.append(history.model_dump(exclude_none=True, by_alias=True)) + self.task.context.append(history) # 推送输出 await self.push_message(EventType.STEP_OUTPUT.value, output_data) diff --git a/apps/scheduler/executor/step_config.py b/apps/scheduler/executor/step_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1abcf44960a3dee65ba537bb731a71a21f64a43b --- /dev/null +++ b/apps/scheduler/executor/step_config.py @@ -0,0 +1,79 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""步骤执行器配置""" + +# 可以直接写入 conversation.key 格式的节点类型配置 +# 这些节点类型的输出变量不会添加 step_id 前缀 +# +# 说明: +# - 添加新的节点类型时,直接在这个集合中添加对应的 call_id 或节点名称即可 +# - 集合匹配区分大小写,建议同时添加大小写版本以确保兼容性 +# - 这些节点的输出变量将保存为 conversation.key 格式 +# - 其他节点的输出变量将保存为 conversation.step_id.key 格式 +DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: set[str] = { + # 开始节点相关 + "Start", + "start", + + # 输入节点相关 + "Input", + "UserInput", + "input", + + # 未来可能的节点类型示例(取消注释即可启用) + # "GlobalConfig", # 全局配置节点 + # "SessionInit", # 会话初始化节点 + # "SystemConfig", # 系统配置节点 +} + +# 可以通过节点名称模式匹配的规则 +# 如果节点名称或step_id(转换为小写后)以这些字符串开头,则使用直接格式 +# +# 说明: +# - 这些模式用于匹配节点名称或step_id的前缀(不区分大小写) +# - 比如:"start"会匹配 "StartProcess"、"start_workflow" 等 +# - 适用于无法提前知道具体节点名称,但可以通过命名规范识别的场景 +DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: set[str] = { + "start", # 匹配所有以start开头的节点 + "init", # 匹配所有以init开头的节点 + "input", # 匹配所有以input开头的节点 + + # 可以根据需要添加更多模式 + # "config", # 匹配配置相关节点 + # "setup", # 匹配设置相关节点 +} + +def should_use_direct_conversation_format(call_id: str, step_name: str, step_id: str) -> bool: + """ + 判断是否应该使用直接的 conversation.key 格式 + + Args: + call_id: 节点的call_id + step_name: 节点名称 + step_id: 节点ID + + Returns: + bool: True表示使用 conversation.key,False表示使用 conversation.step_id.key + """ + # 1. 检查call_id是否在直接写入列表中 + if call_id in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: + return True + + # 2. 检查节点名称是否在直接写入列表中 + if step_name in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: + return True + + # 3. 检查节点名称是否匹配模式 + step_name_lower = step_name.lower() + for pattern in DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: + if step_name_lower.startswith(pattern): + return True + + # 4.
检查step_id是否匹配模式 + step_id_lower = step_id.lower() + for pattern in DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: + if step_id_lower.startswith(pattern): + return True + + return False \ No newline at end of file diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 78aa7bc3ee869e8710e1fb02a2d9fb438d04be34..93c60711fd0d0f90a340c4f8b0d47e79c221a3ff 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -69,7 +69,7 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue context_list.append(context) @@ -120,7 +120,7 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py index cd4f5975eea3f023a92626966081c2d1eb33bdb7..b4b982037767c1ab7b991a37d98f39b846144f96 100644 --- a/apps/scheduler/mcp/plan.py +++ b/apps/scheduler/mcp/plan.py @@ -1,6 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""MCP 用户目标拆解与规划""" - +import logging from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -9,6 +9,7 @@ from apps.llm.reasoning import ReasoningLLM from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER from apps.schemas.mcp import MCPPlan, MCPTool +logger = logging.getLogger(__name__) class MCPPlanner: """MCP 用户目标拆解与规划""" @@ -50,6 +51,7 @@ class MCPPlanner: {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ] + reasoning_llm = ReasoningLLM() result = "" async for chunk in reasoning_llm.call( diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index b322fb0883e8ed935243389cb86066845a549631..8d45bb67db897f45bf440a726c80a4b2d2896d4e 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -62,25 +62,21 @@ MCP_SELECT = dedent(r""" ### 请一步一步思考: """) + CREATE_PLAN = dedent(r""" 你是一个计划生成器。 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 - # 一个好的计划应该: - 1. 能够成功完成用户的目标 2. 计划中的每一个步骤必须且只能使用一个工具。 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 - # 生成计划时的注意事项: - - 每一条计划包含3个部分: - 计划内容:描述单个计划步骤的大致内容 - 工具ID:必须从下文的工具列表中选择 - 工具指令:改写用户的目标,使其更符合工具的输入要求 - 必须按照如下格式生成计划,不要输出任何额外数据: - ```json { "plans": [ @@ -92,16 +88,12 @@ CREATE_PLAN = dedent(r""" ] } ``` - - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ -思考过程应放置在 XML标签中。 +思考过程应放置在 XML标签中。 - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 - # 工具 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - {% for tool in tools %} - {{ tool.id }}{{tool.name}};{{ tool.description }} @@ -113,12 +105,9 @@ CREATE_PLAN = dedent(r""" # 样例 ## 目标 - 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 - ## 计划 - - + 1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server 2. 目标可以拆解为以下几个部分: - 运行alpine:latest容器 @@ -126,7 +115,7 @@ CREATE_PLAN = dedent(r""" - 在后台运行 - 执行top命令 3. 
需要先选择MCP Server,然后生成Docker命令,最后执行命令 - + ```json { @@ -156,9 +145,7 @@ CREATE_PLAN = dedent(r""" ``` # 现在开始生成计划: - ## 目标 - {{goal}} # 计划 @@ -188,7 +175,7 @@ EVALUATE_PLAN = dedent(r""" # 进行评估时的注意事项: - - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。 + - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。 - 评估结果分为两个部分: - 计划评估的结论 - 改进后的计划 @@ -212,6 +199,7 @@ EVALUATE_PLAN = dedent(r""" """) FINAL_ANSWER = dedent(r""" 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 + **注意:你只负责汇报目标完成情况,不对任务执行过程中的准确性/合理性等判断** # 用户目标 diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 2ff5034471c5e9c38f166c6187b76dfb4596f734..cca35ef2eec1644c5512729762681ffd7cd7febe 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -171,6 +171,7 @@ class MCPSelector: for tool_vec in tool_vecs: # 到MongoDB里找对应的工具 logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"]) + logger.debug("[MCPHelper] mcp_id=%s, tool_id=%s", tool_vec["mcp_id"], tool_vec["id"]) tool_data = await tool_collection.aggregate([ {"$match": {"_id": tool_vec["mcp_id"]}}, {"$unwind": "$tools"}, @@ -178,6 +179,7 @@ {"$project": {"_id": 0, "tools": 1}}, {"$replaceRoot": {"newRoot": "$tools"}}, ]) + logger.debug("[MCPHelper] tool_data: %s", tool_data) async for tool in tool_data: tool_obj = MCPTool.model_validate(tool) llm_tool_list.append(tool_obj) diff --git a/apps/scheduler/pool/loader/call.py index 2834487860a5f8cad1c33723ac5add4976160b57..157342875dfaf679ac112c8ff43e60126b28c4a7 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -42,13 +42,12 @@ class CallLoader(metaclass=SingletonMeta): call_metadata.append( CallPool( _id=call_id, - type=CallType.SYSTEM, + type=call_info.type, name=call_info.name, description=call_info.description, path=f"python::apps.scheduler.call::{call_id}", ), ) - return call_metadata async def _load_single_call_dir(self, call_dir_name: str) -> list[CallPool]: @@ -189,6 +188,7 @@ class
CallLoader(metaclass=SingletonMeta): NodePool( _id=call.id, name=call.name, + type=call.type, description=call.description, service_id="", call_id=call.id, diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 57344d40a5c1aba0454a7f00fffcfbf5a93ee221..2b4fcdc7afdcadbcc176486685bd89ab4009233c 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -27,6 +27,10 @@ BASE_PATH = Path(Config().get_config().deploy.data_dir) / "semantics" / "app" class FlowLoader: """工作流加载器""" + # 添加并发控制 + _loading_flows = {} # 改为字典,存储加载任务 + _loading_lock = asyncio.Lock() + async def _load_yaml_file(self, flow_path: Path) -> dict[str, Any]: """从YAML文件加载工作流配置""" try: @@ -100,7 +104,61 @@ class FlowLoader: async def load(self, app_id: str, flow_id: str) -> Flow | None: """从文件系统中加载【单个】工作流""" - logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", flow_id, app_id) + flow_key = f"{app_id}:{flow_id}" + + # 第一次检查:是否已在加载中 + existing_task = None + async with self._loading_lock: + if flow_key in self._loading_flows: + existing_task = self._loading_flows[flow_key] + + # 如果找到现有任务,等待其完成 + if existing_task is not None: + logger.info(f"[FlowLoader] 工作流正在加载中,等待完成: {flow_key}") + try: + return await existing_task + except Exception as e: + logger.error(f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}") + # 如果等待失败,清理失败的任务并重试 + async with self._loading_lock: + if self._loading_flows.get(flow_key) == existing_task: + self._loading_flows.pop(flow_key, None) + return None + + # 创建新的加载任务 + task = None + async with self._loading_lock: + # 再次检查,防止竞态条件 + if flow_key in self._loading_flows: + existing_task = self._loading_flows[flow_key] + # 如果有新任务出现,等待它完成 + if existing_task is not None: + try: + return await existing_task + except Exception as e: + logger.error(f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}") + return None + + # 创建新的加载任务 + task = asyncio.create_task(self._do_load(app_id, flow_id)) + self._loading_flows[flow_key] = task + + # 执行加载任务 + try: 
+ result = await task + return result + except Exception as e: + logger.error(f"[FlowLoader] 工作流加载失败: {flow_key}, 错误: {e}") + return None + finally: + # 确保从加载集合中移除 + async with self._loading_lock: + if self._loading_flows.get(flow_key) == task: + self._loading_flows.pop(flow_key, None) + + async def _do_load(self, app_id: str, flow_id: str) -> Flow | None: + """实际执行加载工作流的方法""" + logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", app_id, flow_id) # 构建工作流文件路径 flow_path = BASE_PATH / app_id / "flow" / f"{flow_id}.yaml" @@ -235,18 +293,29 @@ class FlowLoader: except Exception: logger.exception("[FlowLoader] 更新 MongoDB 失败") - # 删除重复的ID - while True: + # 删除重复的ID,增加重试次数限制 + max_retries = 10 + retry_count = 0 + while retry_count < max_retries: try: table = await LanceDB().get_table("flow") await table.delete(f"id = '{metadata.id}'") break except RuntimeError as e: if "Commit conflict" in str(e): - logger.error("[FlowLoader] LanceDB删除flow冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) + retry_count += 1 + logger.error(f"[FlowLoader] LanceDB删除flow冲突,重试中... 
({retry_count}/{max_retries})") # noqa: TRY400 + # 指数退避,减少冲突概率 + await asyncio.sleep(0.01 * (2 ** min(retry_count, 5))) else: raise + except Exception as e: + logger.error(f"[FlowLoader] LanceDB删除操作异常: {e}") + break + + if retry_count >= max_retries: + logger.warning(f"[FlowLoader] LanceDB删除flow达到最大重试次数,跳过删除: {metadata.id}") + # 不抛出异常,继续执行后续操作 # 进行向量化 service_embedding = await Embedding.get_embedding([metadata.description]) vector_data = [ @@ -256,7 +325,10 @@ class FlowLoader: embedding=service_embedding[0], ), ] - while True: + # 插入向量数据,增加重试次数限制 + max_retries_insert = 10 + retry_count_insert = 0 + while retry_count_insert < max_retries_insert: try: table = await LanceDB().get_table("flow") await table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( @@ -265,7 +337,16 @@ class FlowLoader: break except RuntimeError as e: if "Commit conflict" in str(e): - logger.error("[FlowLoader] LanceDB插入flow冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) + retry_count_insert += 1 + logger.error(f"[FlowLoader] LanceDB插入flow冲突,重试中... 
({retry_count_insert}/{max_retries_insert})") # noqa: TRY400 + # 指数退避,减少冲突概率 + await asyncio.sleep(0.01 * (2 ** min(retry_count_insert, 5))) else: raise + except Exception as e: + logger.error(f"[FlowLoader] LanceDB插入操作异常: {e}") + break + + if retry_count_insert >= max_retries_insert: + logger.error(f"[FlowLoader] LanceDB插入flow达到最大重试次数,操作失败: {metadata.id}") + raise RuntimeError(f"LanceDB插入flow失败,达到最大重试次数: {metadata.id}") diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 66a516e77a69bb9e3418953d7ae6fe2c8ddc4052..3cd6183a1c8025977396765c703446ae9b8feb77 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -153,7 +153,6 @@ class MCPLoader(metaclass=SingletonMeta): # 检查目录 template_path = MCP_PATH / "template" / mcp_id await Path.mkdir(template_path, parents=True, exist_ok=True) - ProcessHandler.clear_finished_tasks() # 安装MCP模板 if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, mcp_id, config): err = f"安装任务无法执行,请稍后重试: {mcp_id}" @@ -451,23 +450,27 @@ class MCPLoader(metaclass=SingletonMeta): ) @staticmethod - async def _find_deleted_mcp() -> list[str]: + async def _find_deleted_and_undeleted_mcp() -> tuple[list[str], list[str]]: """ 查找在文件系统中被修改和被删除的MCP - :return: 被修改的MCP列表和被删除的MCP列表 + :return: 被修改的MCP列表和被删除的MCP列表, 和未被修改和删除的MCP列表 :rtype: tuple[list[str], list[str]] """ deleted_mcp_list = [] + undeleted_mcp_list = [] mcp_collection = MongoDB().get_collection("mcp") mcp_list = await mcp_collection.find({}, {"_id": 1}).to_list(None) for db_item in mcp_list: mcp_path: Path = MCP_PATH / "template" / db_item["_id"] - if not await mcp_path.exists(): + if await mcp_path.exists(): + undeleted_mcp_list.append(db_item["_id"]) + else: deleted_mcp_list.append(db_item["_id"]) logger.info("[MCPLoader] 这些MCP在文件系统中被删除: %s", deleted_mcp_list) - return deleted_mcp_list + logger.info("[MCPLoader] 这些MCP在文件系统中存在: %s", undeleted_mcp_list) + return deleted_mcp_list, undeleted_mcp_list @staticmethod 
async def remove_deleted_mcp(deleted_mcp_list: list[str]) -> None: @@ -505,6 +508,45 @@ raise logger.info("[MCPLoader] 清除LanceDB中无效的MCP") + @staticmethod + async def remove_cached_mcp_in_lance(undeleted_mcp_list: list[str]) -> None: + """ + 清除LanceDB中多余的MCP缓存记录 + + :param undeleted_mcp_list: 文件系统中仍然存在的MCP ID列表 + :return: 无 + """ + try: + mcp_table = await LanceDB().get_table("mcp") + + all_data = await mcp_table.to_pandas() + logger.info("[MCPLoader] MCP表中原始数据量: %s条", len(all_data)) + + if len(all_data) == 0: + logger.info("[MCPLoader] LanceDB中没有MCP数据,跳过删除") + return + + ids_to_delete = all_data[~all_data['id'].isin(undeleted_mcp_list)]['id'].tolist() + + if not ids_to_delete: + logger.info("[MCPLoader] 没有需要删除的记录") + return + + logger.info("[MCPLoader] 将要删除 %s 条记录,ID列表: %s", len(ids_to_delete), ids_to_delete) + for id_to_delete in ids_to_delete: + await mcp_table.delete(f"id == '{id_to_delete}'") + + logger.info("[MCPLoader] 删除操作完成") + + remaining_data = await mcp_table.to_pandas() + logger.info("[MCPLoader] 删除后表中剩余数据量: %s条", len(remaining_data)) + except Exception as e: + if "Commit conflict" in str(e): + logger.error("[MCPLoader] LanceDB删除mcp冲突,本次清理跳过") # noqa: TRY400 + else: + raise + @staticmethod async def delete_mcp(mcp_id: str) -> None: """ @@ -568,8 +610,20 @@ :return: 无 """ # 清空数据库 - deleted_mcp_list = await MCPLoader._find_deleted_mcp() + + # 先删除文件系统中已不存在的MCP(MongoDB与LanceDB同步删除), + # 再对比文件系统清单,清理LanceDB中多余的缓存记录 + deleted_mcp_list, undeleted_mcp_list = await MCPLoader._find_deleted_and_undeleted_mcp() await MCPLoader.remove_deleted_mcp(deleted_mcp_list) + await MCPLoader.remove_cached_mcp_in_lance(undeleted_mcp_list) # 检查目录 await MCPLoader._check_dir() diff --git a/apps/scheduler/scheduler/context.py index
b7088d8da45ad01e94e0cf7208db25c5f56c3b22..687697fb81ba4389ffb1d247381288dfdd831cd7 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -114,11 +114,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: used_docs.append( RecordGroupDocument( _id=docs["id"], + author=docs.get("author", ""), + order=docs.get("order", 0), name=docs["name"], abstract=docs.get("abstract", ""), extension=docs.get("extension", ""), size=docs.get("size", 0), associated="answer", + created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)), ) ) if docs.get("order") is not None: @@ -185,7 +188,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: feature={}, ), createdAt=current_time, - flow=[i["_id"] for i in task.context], + flow=[i.id for i in task.context], ) # 检查是否存在group_id diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index c89fdd1014ba4714d0777742412e98c0802ca68c..a2a45e41494e3ddb8d291c0a907da06bb479d4a5 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -71,14 +71,7 @@ async def push_rag_message( # 如果是文本消息,直接拼接到答案中 full_answer += content_obj.content elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append({ - "id": content_obj.content.get("id", ""), - "order": content_obj.content.get("order", 0), - "name": content_obj.content.get("name", ""), - "abstract": content_obj.content.get("abstract", ""), - "extension": content_obj.content.get("extension", ""), - "size": content_obj.content.get("size", 0), - }) + task.runtime.documents.append(content_obj.content) # 保存答案 task.runtime.answer = full_answer await TaskManager.save_task(task.id, task) @@ -115,10 +108,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl data=DocumentAddContent( documentId=content_obj.content.get("id", ""), 
documentOrder=content_obj.content.get("order", 0), + documentAuthor=content_obj.content.get("author", ""), documentName=content_obj.content.get("name", ""), documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), + createdAt=round(datetime.now(tz=UTC).timestamp(), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: diff --git a/apps/scheduler/scheduler/scheduler.py index ed73638ced241909f55989597f8bb747b50f1945..417f93d28a147cd0cf65b124624484a02a49ab06 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -206,7 +206,7 @@ class Scheduler: task=self.task, msg_queue=queue, question=post_body.question, - max_steps=app_metadata.history_len, + history_len=app_metadata.history_len, servers_id=servers_id, background=background, agent_id=app_info.app_id, diff --git a/apps/scheduler/slot/slot.py index 4caab4d2dc945ec5ac14b7769770f94c39980a9f..89433cade929ef6917664450bb3645500ed2df5a 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -12,6 +12,8 @@ from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema.validators import extend +from apps.schemas.response_data import ParamsNode +from apps.scheduler.call.choice.schema import Type from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, @@ -221,6 +223,45 @@ class Slot: return data return _extract_type_desc(self._schema) + def get_params_node_from_schema(self, root: str = "") -> ParamsNode | None: + """从JSON Schema中提取ParamsNode,解析失败时返回None""" + def _extract_params_node(schema_node: dict[str, Any], name: str = "", path: str = "") -> ParamsNode | None: + """递归提取ParamsNode,类型不支持时返回None""" + if "type" not in schema_node: + return None + + param_type = schema_node["type"] + if param_type == "object": + param_type = Type.DICT + elif
param_type == "array": + param_type = Type.LIST + elif param_type == "string": + param_type = Type.STRING + elif param_type == "number": + param_type = Type.NUMBER + elif param_type == "boolean": + param_type = Type.BOOL + else: + logger.warning(f"[Slot] 不支持的参数类型: {param_type}") + return None + sub_params = [] + + if param_type == "object" and "properties" in schema_node: + for key, value in schema_node["properties"].items(): + child = _extract_params_node(value, name=key, path=f"{path}/{key}") + if child is not None: + sub_params.append(child) + else: + # 对于非对象类型,直接返回空子参数 + sub_params = None + return ParamsNode(paramName=name, + paramPath=path, + paramType=param_type, + subParams=sub_params) + try: + return _extract_params_node(self._schema, name=root, path=root) + except Exception: + logger.exception("[Slot] 提取ParamsNode失败") + return None + def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """将JSON Schema扁平化""" result = {} @@ -276,7 +317,6 @@ logger.exception("[Slot] 错误schema不合法: %s", error.schema) return {}, [] - def _assemble_patch( self, key: str, @@ -329,7 +369,6 @@ logger.info("[Slot] 组装patch: %s", patch_list) return patch_list - def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]: """将用户手动填充的参数专为真实JSON""" json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data diff --git a/apps/scheduler/variable/README.md new file mode 100644 index 0000000000000000000000000000000000000000..73718ce01f2eca2514d233b43059cc76018c7a48 --- /dev/null +++ b/apps/scheduler/variable/README.md @@ -0,0 +1,264 @@ +# 变量池架构文档 + +## 架构设计 + +基于用户需求,变量系统采用"模板-实例"的两级架构: + +### 设计理念 + +- **Flow级别(父pool)**:管理变量模板定义,用户可以查看和配置变量结构 +- **Conversation级别(子pool)**:管理变量实例,存储实际的运行时数据 + +### 变量分类 + +不同类型的变量有不同的存储和管理方式: +- **系统变量**:模板在Flow级别定义,实例在Conversation级别运行时更新 +- **对话变量**:模板在Flow级别定义,实例在Conversation级别用户可设置 +- **环境变量**:直接在Flow级别存储和使用 +- **用户变量**:在User级别长期存储 + +## 
架构实现 + +### 变量池类型 + +#### 1. UserVariablePool(用户变量池) +- **关联ID**: `user_id` +- **权限**: 用户可读写 +- **生命周期**: 随用户创建而创建,长期存在 +- **典型变量**: API密钥、用户偏好、个人配置等 + +#### 2. FlowVariablePool(流程变量池) +- **关联ID**: `flow_id` +- **权限**: 流程可读写 +- **生命周期**: 随 flow 创建而创建 +- **继承**: 支持从父流程继承 +- **存储内容**: + - 环境变量(直接使用) + - 系统变量模板(供对话继承) + - 对话变量模板(供对话继承) + +#### 3. ConversationVariablePool(对话变量池) +- **关联ID**: `conversation_id` +- **权限**: + - 系统变量实例:只读,由系统自动更新 + - 对话变量实例:可读写,用户可设置值 +- **生命周期**: 随对话创建而创建,对话结束后可选择性清理 +- **初始化方式**: 从FlowVariablePool的模板自动继承 +- **包含内容**: + - **系统变量实例**:`query`, `files`, `dialogue_count`等运行时值 + - **对话变量实例**:用户定义的对话上下文数据 + +## 核心设计原则 + +### 1. 统一的对话上下文 +所有对话相关的变量(无论是系统变量还是对话变量)都在同一个对话变量池中管理,确保上下文的一致性。 + +### 2. 权限区分 +通过 `is_system` 标记区分系统变量和对话变量: +- `is_system=True`: 系统变量,只读,由系统自动更新 +- `is_system=False`: 对话变量,可读写,支持人为修改 + +### 3. 自动初始化和持久化 +创建对话变量池时,自动初始化所有必需的系统变量,设置合理的默认值,并立即持久化到数据库,确保系统变量在任何时候都可用。 + +## 使用方式 + +### 1. 创建对话变量池 + +```python +pool_manager = await get_pool_manager() + +# 创建对话变量池(自动包含系统变量) +conv_pool = await pool_manager.create_conversation_pool("conv123", "flow456") +``` + +### 2. 更新系统变量 + +```python +# 系统变量由解析器自动更新 +parser = VariableParser( + user_id="user123", + flow_id="flow456", + conversation_id="conv123" +) + +# 更新系统变量 +await parser.update_system_variables({ + "question": "你好,请帮我分析数据", + "files": [{"name": "data.csv", "size": 1024}], + "dialogue_count": 1, + "user_sub": "user123" +}) +``` + +### 3. 更新对话变量 + +```python +# 添加对话变量 +await conv_pool.add_variable( + name="context_history", + var_type=VariableType.ARRAY_STRING, + value=["用户问候", "系统回应"], + description="对话历史" +) + +# 更新对话变量 +await conv_pool.update_variable("context_history", value=["问候", "回应", "新消息"]) +``` + +### 4. 
变量解析 + +```python +# 系统变量和对话变量使用相同的引用语法 +template = """ +系统变量 - 用户查询: {{sys.query}} +系统变量 - 对话轮数: {{sys.dialogue_count}} +对话变量 - 历史: {{conversation.context_history}} +用户变量 - 偏好: {{user.preferences}} +环境变量 - 数据库: {{env.database_url}} +""" + +parsed = await parser.parse_template(template) +``` + +## 变量引用语法 + +变量引用保持不变: +- `{{sys.variable_name}}` - 系统变量(对话级别,只读) +- `{{conversation.variable_name}}` - 对话变量(对话级别,可读写) +- `{{user.variable_name}}` - 用户变量 +- `{{env.variable_name}}` - 环境变量 + +## 权限控制详细说明 + +### 系统变量权限 +```python +# 普通更新会被拒绝 +await conv_pool.update_variable("query", value="new query") # ❌ 抛出 PermissionError + +# 系统内部更新 +await conv_pool.update_system_variable("query", "new query") # ✅ 成功 +# 或者 +await conv_pool.update_variable("query", value="new query", force_system_update=True) # ✅ 成功 +``` + +### 对话变量权限 +```python +# 普通对话变量可以自由更新 +await conv_pool.update_variable("context_history", value=new_history) # ✅ 成功 +``` + +## 数据存储 + +### 元数据增强 +```python +class VariableMetadata(BaseModel): + # ... 其他字段 + is_system: bool = Field(default=False, description="是否为系统变量(只读)") +``` + +### 数据库查询 +系统变量和对话变量存储在同一个集合中,通过 `metadata.is_system` 字段区分。 + +## 迁移影响 + +### 对用户的影响 +- ✅ **变量引用语法完全不变** +- ✅ **API接口完全兼容** +- ✅ **现有功能正常工作** + +### 内部实现变化 +- 去掉了独立的 `SystemVariablePool` +- 系统变量现在在 `ConversationVariablePool` 中管理 +- 通过权限控制区分系统变量和对话变量 + +## 架构优势 + +### 1. 逻辑一致性 +系统变量和对话变量都属于对话上下文,在同一个池中管理更合理。 + +### 2. 简化管理 +不需要在系统池和对话池之间同步数据,避免了数据一致性问题。 + +### 3. 更好的性能 +减少了池之间的数据传递和同步开销。 + +### 4. 
扩展性 +为未来可能的对话级系统变量扩展提供了更好的基础。 + +## 总结 + +修正后的架构更准确地反映了变量的实际使用场景: +- **用户变量**: 用户级别,长期存在 +- **环境变量**: 流程级别,配置相关 +- **系统变量 + 对话变量**: 对话级别,上下文相关 + +这样的设计更符合实际业务逻辑,也更容易理解和维护。 + +## 系统变量详细说明 + +### 预定义系统变量 + +每个对话变量池创建时,会自动初始化以下系统变量: + +| 变量名 | 类型 | 描述 | 初始值 | +|-------|------|------|--------| +| `query` | STRING | 用户查询内容 | "" | +| `files` | ARRAY_FILE | 用户上传的文件列表 | [] | +| `dialogue_count` | NUMBER | 对话轮数 | 0 | +| `app_id` | STRING | 应用ID | "" | +| `flow_id` | STRING | 工作流ID | {flow_id} | +| `user_id` | STRING | 用户ID | "" | +| `session_id` | STRING | 会话ID | "" | +| `conversation_id` | STRING | 对话ID | {conversation_id} | +| `timestamp` | NUMBER | 当前时间戳 | {当前时间} | + +### 系统变量生命周期 + +1. **创建阶段**:对话变量池创建时,所有系统变量被初始化并持久化到数据库 +2. **更新阶段**:通过`VariableParser.update_system_variables()`方法更新系统变量值 +3. **访问阶段**:通过模板解析或直接访问获取系统变量值 +4. **清理阶段**:对话结束时,整个对话变量池被清理 + +### 系统变量更新机制 + +```python +# 创建解析器并确保对话池存在 +parser = VariableParser(user_id=user_id, flow_id=flow_id, conversation_id=conversation_id) +await parser.create_conversation_pool_if_needed() + +# 更新系统变量 +context = { + "question": "用户的问题", + "files": [{"name": "file.txt", "size": 1024}], + "dialogue_count": 1, + "app_id": "app123", + "user_sub": user_id, + "session_id": "session456" +} + +await parser.update_system_variables(context) +``` + +### 系统变量的只读保护 + +```python +# ❌ 直接修改系统变量会失败 +await conversation_pool.update_variable("query", value="修改内容") # 抛出PermissionError + +# ✅ 只能通过系统内部接口更新 +await conversation_pool.update_system_variable("query", "新内容") # 成功 +``` + +### 使用系统变量 + +```python +# 在模板中引用系统变量 +template = """ +用户问题:{{sys.query}} +对话轮数:{{sys.dialogue_count}} +工作流ID:{{sys.flow_id}} +""" + +parsed = await parser.parse_template(template) +``` \ No newline at end of file diff --git a/apps/scheduler/variable/__init__.py b/apps/scheduler/variable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c487379c54120835552954559f0e31bd2a6cce0 --- /dev/null +++ b/apps/scheduler/variable/__init__.py @@ -0,0 
+1,60 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""工作流变量管理模块 + +这个模块提供了完整的变量管理功能,包括: +- 多种变量类型支持(String、Number、Boolean、Object、Secret、File、Array等) +- 四种作用域支持(系统级、用户级、环境级、对话级) +- 变量解析和模板替换 +- 安全的密钥变量处理 +- 与工作流调度器的集成 +""" + +from .type import VariableType, VariableScope +from .base import BaseVariable, VariableMetadata +from .variables import ( + StringVariable, + NumberVariable, + BooleanVariable, + ObjectVariable, + SecretVariable, + FileVariable, + ArrayVariable, + create_variable, + VARIABLE_CLASS_MAP, +) +from .pool_manager import VariablePoolManager, get_pool_manager +from .parser import VariableParser, VariableReferenceBuilder, VariableContext +from .integration import VariableIntegration + +__all__ = [ + # 基础类型和枚举 + "VariableType", + "VariableScope", + "VariableMetadata", + "BaseVariable", + + # 具体变量类型 + "StringVariable", + "NumberVariable", + "BooleanVariable", + "ObjectVariable", + "SecretVariable", + "FileVariable", + "ArrayVariable", + + # 工厂函数和映射 + "create_variable", + "VARIABLE_CLASS_MAP", + + # 变量池管理器 + "VariablePoolManager", + "get_pool_manager", + + # 解析器 + "VariableParser", + "VariableReferenceBuilder", + "VariableContext", + + # 集成功能 + "VariableIntegration", +] diff --git a/apps/scheduler/variable/base.py b/apps/scheduler/variable/base.py new file mode 100644 index 0000000000000000000000000000000000000000..94beaaebc0f084c1a5dbec6696f7b233a5d66154 --- /dev/null +++ b/apps/scheduler/variable/base.py @@ -0,0 +1,190 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field +from datetime import datetime, UTC + +from .type import VariableType, VariableScope + + +class VariableMetadata(BaseModel): + """变量元数据""" + name: str = Field(description="变量名称") + var_type: VariableType = Field(description="变量类型") + scope: VariableScope = Field(description="变量作用域") + description: Optional[str] = Field(default=None, description="变量描述") + created_at: datetime = 
Field(default_factory=lambda: datetime.now(UTC), description="创建时间") + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="更新时间") + created_by: Optional[str] = Field(default=None, description="创建者用户ID") + + # 作用域相关属性 + user_sub: Optional[str] = Field(default=None, description="用户级变量的用户ID") + flow_id: Optional[str] = Field(default=None, description="环境级/对话级变量的流程ID") + conversation_id: Optional[str] = Field(default=None, description="对话级变量的对话ID") + + # 系统变量标识 + is_system: bool = Field(default=False, description="是否为系统变量(只读)") + + # 模板变量标识 + is_template: bool = Field(default=False, description="是否为模板变量(存储在flow级别)") + + # 安全相关属性 + is_encrypted: bool = Field(default=False, description="是否加密存储") + access_permissions: Optional[Dict[str, Any]] = Field(default=None, description="访问权限") + + +class BaseVariable(ABC): + """变量处理基类""" + + def __init__(self, metadata: VariableMetadata, value: Any = None): + """初始化变量 + + Args: + metadata: 变量元数据 + value: 变量值 + """ + self.metadata = metadata + self._value = None # 先设置为None + self._original_value = value + self._initializing = True # 标记正在初始化 + + # 通过setter设置值,触发类型验证 + if value is not None: + self.value = value # 这会触发setter和类型验证 + else: + self._value = value # 如果是None则直接设置 + + self._initializing = False # 初始化完成 + + @property + def name(self) -> str: + """获取变量名称""" + return self.metadata.name + + @property + def var_type(self) -> VariableType: + """获取变量类型""" + return self.metadata.var_type + + @property + def scope(self) -> VariableScope: + """获取变量作用域""" + return self.metadata.scope + + @property + def value(self) -> Any: + """获取变量值""" + return self._value + + @value.setter + def value(self, new_value: Any) -> None: + """设置变量值""" + # 只有在非初始化阶段才检查系统级变量的修改限制 + if self.scope == VariableScope.SYSTEM and not getattr(self, '_initializing', False): + raise ValueError("系统级变量不能修改") + + # 验证类型 + if not self._validate_type(new_value): + raise TypeError(f"变量 {self.name} 的值类型不匹配,期望: {self.var_type}") + + self._value 
= new_value + self.metadata.updated_at = datetime.now(UTC) + + @abstractmethod + def _validate_type(self, value: Any) -> bool: + """验证值的类型是否正确 + + Args: + value: 要验证的值 + + Returns: + bool: 类型是否正确 + """ + pass + + @abstractmethod + def to_string(self) -> str: + """将变量转换为字符串表示 + + Returns: + str: 字符串表示 + """ + pass + + @abstractmethod + def to_dict(self) -> Dict[str, Any]: + """将变量转换为字典表示 + + Returns: + Dict[str, Any]: 字典表示 + """ + pass + + @abstractmethod + def serialize(self) -> Dict[str, Any]: + """序列化变量用于存储 + + Returns: + Dict[str, Any]: 序列化后的数据 + """ + pass + + @classmethod + @abstractmethod + def deserialize(cls, data: Dict[str, Any]) -> "BaseVariable": + """从序列化数据恢复变量 + + Args: + data: 序列化数据 + + Returns: + BaseVariable: 恢复的变量实例 + """ + pass + + def copy(self) -> "BaseVariable": + """创建变量的副本 + + Returns: + BaseVariable: 变量副本 + """ + # 深拷贝元数据和值 + import copy + new_metadata = copy.deepcopy(self.metadata) + new_value = copy.deepcopy(self._value) + return self.__class__(new_metadata, new_value) + + def reset(self) -> None: + """重置变量到初始值""" + if self.scope == VariableScope.SYSTEM: + raise ValueError("系统级变量不能重置") + + self._value = self._original_value + self.metadata.updated_at = datetime.now(UTC) + + def can_access(self, user_sub: str) -> bool: + """检查用户是否有权限访问此变量 + + Args: + user_sub: 用户ID + + Returns: + bool: 是否有权限访问 + """ + # 系统级变量所有人都可以访问 + if self.scope == VariableScope.SYSTEM: + return True + + # 用户级变量只有创建者可以访问 + if self.scope == VariableScope.USER: + return self.metadata.user_sub == user_sub + + # 环境级和对话级变量根据上下文判断(这里简化处理) + return True + + def __str__(self) -> str: + """字符串表示""" + return f"{self.name}({self.var_type.value})={self.to_string()}" + + def __repr__(self) -> str: + """调试表示""" + return f"<{self.__class__.__name__}(name='{self.name}', type='{self.var_type.value}', scope='{self.scope.value}')>" \ No newline at end of file diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py new file mode 100644 index 
0000000000000000000000000000000000000000..a33385a176513a8717d98a59bcaf5f028ce8a47f --- /dev/null +++ b/apps/scheduler/variable/integration.py @@ -0,0 +1,366 @@ +"""变量解析与工作流调度器集成""" + +import logging +from typing import Any, Dict, List, Optional, Union + +from apps.scheduler.variable.parser import VariableParser +from apps.scheduler.variable.pool_manager import get_pool_manager +from apps.scheduler.variable.type import VariableScope + +logger = logging.getLogger(__name__) + + +class VariableIntegration: + """变量解析集成类 - 为现有调度器提供变量功能""" + + @staticmethod + async def initialize_system_variables(context: Dict[str, Any]) -> None: + """初始化系统变量 + + Args: + context: 系统上下文信息,包括用户查询、文件等 + """ + try: + parser = VariableParser( + user_sub=context.get("user_sub"), + flow_id=context.get("flow_id"), + conversation_id=context.get("conversation_id") + ) + + # 更新系统变量 + await parser.update_system_variables(context) + + logger.info("系统变量已初始化") + except Exception as e: + logger.error(f"初始化系统变量失败: {e}") + raise + + @staticmethod + async def parse_call_input(input_data: Dict[str, Any], + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> Union[str, Dict, List]: + """解析Call输入中的变量引用 + + Args: + input_data: 输入数据 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + Union[str, Dict, List]: 解析后的输入数据(与输入结构一致) + """ + try: + parser = VariableParser( + user_sub=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + # 递归解析JSON模板中的变量引用 + parsed_input = await parser.parse_json_template(input_data) + + return parsed_input + + except Exception as e: + logger.warning(f"解析Call输入变量失败: {e}") + # 如果解析失败,返回原始输入 + return input_data + + @staticmethod + async def resolve_variable_reference( + reference: str, + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None + ) -> Any: + """解析单个变量引用 + + Args: + reference: 变量引用字符串(如 "{{user.name}}" 或 "user.name") + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: 
+ Any: 解析后的变量值 + """ + try: + parser = VariableParser( + user_id=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + # 清理引用字符串(移除花括号) + clean_reference = reference.strip("{}") + + # 使用解析器解析变量引用 + resolved_value = await parser._resolve_variable_reference(clean_reference) + + return resolved_value + + except Exception as e: + logger.error(f"解析变量引用失败: {reference}, 错误: {e}") + raise + + @staticmethod + async def save_conversation_variable( + var_name: str, + value: Any, + var_type: str = "string", + description: str = "", + user_sub: str = "", + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None + ) -> bool: + """保存对话变量 + + Args: + var_name: 变量名(不包含scope前缀) + value: 变量值 + var_type: 变量类型 + description: 变量描述 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + bool: 是否保存成功 + """ + try: + if not conversation_id: + logger.warning("无法保存对话变量:缺少conversation_id") + return False + + # 直接使用pool_manager,避免解析器的复杂逻辑 + pool_manager = await get_pool_manager() + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + + if not conversation_pool: + logger.warning(f"无法获取对话变量池: {conversation_id}") + return False + + # 转换变量类型 + from apps.scheduler.variable.type import VariableType + try: + var_type_enum = VariableType(var_type) + except ValueError: + var_type_enum = VariableType.STRING + logger.warning(f"未知的变量类型 {var_type},使用默认类型 string") + + # 尝试更新变量,如果不存在则创建 + try: + await conversation_pool.update_variable(var_name, value=value) + logger.debug(f"对话变量已更新: {var_name} = {value}") + return True + except ValueError as e: + if "不存在" in str(e): + # 变量不存在,创建新变量 + await conversation_pool.add_variable( + name=var_name, + var_type=var_type_enum, + value=value, + description=description, + created_by=user_sub or "system" + ) + logger.debug(f"对话变量已创建: {var_name} = {value}") + return True + else: + raise # 其他错误重新抛出 + + except Exception as e: + logger.error(f"保存对话变量失败: {var_name} - {e}") + return False + + @staticmethod + 
async def parse_template_string(template: str, + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> str: + """解析模板字符串中的变量引用 + + Args: + template: 模板字符串 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + str: 解析后的字符串 + """ + try: + parser = VariableParser( + user_sub=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + return await parser.parse_template(template) + except Exception as e: + logger.warning(f"解析模板字符串失败: {e}") + # 如果解析失败,返回原始模板 + return template + + @staticmethod + async def add_conversation_variable(name: str, + value: Any, + conversation_id: str, + var_type_str: str = "string") -> bool: + """添加对话级变量 + + Args: + name: 变量名 + value: 变量值 + conversation_id: 对话ID(在内部作为flow_id使用) + var_type_str: 变量类型字符串 + + Returns: + bool: 是否添加成功 + """ + try: + from apps.scheduler.variable.type import VariableType + + # 转换变量类型 + var_type = VariableType(var_type_str) + + pool_manager = await get_pool_manager() + # 获取对话变量池(如果不存在会抛出异常) + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + if not conversation_pool: + logger.error(f"对话变量池不存在: {conversation_id}") + return False + + await conversation_pool.add_variable( + name=name, + var_type=var_type, + value=value, + description=f"对话变量: {name}" + ) + + logger.debug(f"已添加对话变量: {name} = {value}") + return True + except Exception as e: + logger.error(f"添加对话变量失败: {e}") + return False + + @staticmethod + async def update_conversation_variable(name: str, + value: Any, + conversation_id: str) -> bool: + """更新对话级变量 + + Args: + name: 变量名 + value: 新值 + conversation_id: 对话ID + + Returns: + bool: 是否更新成功 + """ + try: + pool_manager = await get_pool_manager() + # 获取对话变量池 + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + if not conversation_pool: + logger.error(f"对话变量池不存在: {conversation_id}") + return False + + await conversation_pool.update_variable( + name=name, + value=value + ) + + 
logger.debug(f"已更新对话变量: {name} = {value}") + return True + except Exception as e: + logger.error(f"更新对话变量失败: {e}") + return False + + @staticmethod + async def extract_output_variables(output_data: Dict[str, Any], + conversation_id: str, + step_name: str) -> None: + """从步骤输出中提取变量并设置为对话级变量 + + Args: + output_data: 步骤输出数据 + conversation_id: 对话ID + step_name: 步骤名称 + """ + try: + # 将整个输出作为对象变量存储 + await VariableIntegration.add_conversation_variable( + name=f"step_{step_name}_output", + value=output_data, + conversation_id=conversation_id, + var_type_str="object" + ) + + # 如果输出中有特定的变量定义,也可以单独提取 + if isinstance(output_data, dict): + for key, value in output_data.items(): + if key.startswith("var_"): + # 以 var_ 开头的字段被视为变量定义 + var_name = key[4:] # 移除 var_ 前缀 + await VariableIntegration.add_conversation_variable( + name=var_name, + value=value, + conversation_id=conversation_id, + var_type_str="string" + ) + + except Exception as e: + logger.error(f"提取输出变量失败: {e}") + + @staticmethod + async def clear_conversation_context(conversation_id: str) -> None: + """清理对话上下文中的变量 + + Args: + conversation_id: 对话ID + """ + try: + pool_manager = await get_pool_manager() + # 移除对话变量池 + success = await pool_manager.remove_conversation_pool(conversation_id) + if success: + logger.info(f"已清理对话 {conversation_id} 的变量") + else: + logger.warning(f"对话变量池不存在: {conversation_id}") + except Exception as e: + logger.error(f"清理对话变量失败: {e}") + + @staticmethod + async def validate_variable_references(template: str, + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> tuple[bool, list[str]]: + """验证模板中的变量引用是否有效 + + Args: + template: 模板字符串 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + tuple[bool, list[str]]: (是否全部有效, 无效的变量引用列表) + """ + try: + parser = VariableParser( + user_sub=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + return await parser.validate_template(template) + except Exception as e: + 
logger.error(f"验证变量引用失败: {e}") + return False, [str(e)] + + +# 注意:原本的 monkey_patch_scheduler 和相关扩展类已被移除 +# 因为 CoreCall 类现在已经内置了完整的变量解析功能 +# 这些代码是旧版本的遗留,会导致循环导入问题 \ No newline at end of file diff --git a/apps/scheduler/variable/parser.py b/apps/scheduler/variable/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb0dcfaff90d5a0254c4f13917197b12c13ef95 --- /dev/null +++ b/apps/scheduler/variable/parser.py @@ -0,0 +1,514 @@ +import re +import logging +from typing import Any, Dict, List, Optional, Tuple, Union +import json +from datetime import datetime, UTC + +from .pool_manager import get_pool_manager +from .type import VariableScope + +logger = logging.getLogger(__name__) + + +class VariableParser: + """变量解析器 - 支持新架构的变量解析器(系统变量和对话变量都在对话池中)""" + + # 变量引用的正则表达式:{{scope.variable_name.nested_path}} + VARIABLE_PATTERN = re.compile(r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*\}\}') + + def __init__(self, + user_id: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None, + user_sub: Optional[str] = None): + """初始化变量解析器 + + Args: + user_id: 用户ID (向后兼容) + flow_id: 流程ID + conversation_id: 对话ID + user_sub: 用户订阅ID (优先使用,用于未来鉴权等需求) + """ + # 优先使用 user_sub,如果没有则使用 user_id + self.user_id = user_sub if user_sub is not None else user_id + self.flow_id = flow_id + self.conversation_id = conversation_id + self._pool_manager = None + + async def _get_pool_manager(self): + """获取变量池管理器实例""" + if self._pool_manager is None: + self._pool_manager = await get_pool_manager() + return self._pool_manager + + async def parse_template(self, template: str) -> str: + """解析模板字符串,替换其中的变量引用 + + Args: + template: 包含变量引用的模板字符串 + + Returns: + str: 替换后的字符串 + """ + if not template: + return template + + # 查找所有变量引用 + matches = self.VARIABLE_PATTERN.findall(template) + + # 替换每个变量引用 + result = template + for match in matches: + try: + # 解析变量引用 + value = await self._resolve_variable_reference(match) + + # 转换为字符串 + 
str_value = self._convert_to_string(value) + + # 替换模板中的变量引用(用正则替换,兼容 {{ sys.query }} 这类花括号内带空白的写法; + # 用 lambda 作为 repl,避免值中的反斜杠被当作反向引用) + ref_pattern = re.compile(r"\{\{\s*" + re.escape(match) + r"\s*\}\}") + result = ref_pattern.sub(lambda _m: str_value, result) + + logger.debug(f"已替换变量: {{{{{match}}}}} -> {str_value}") + + except Exception as e: + logger.warning(f"解析变量引用失败: {{{{{match}}}}} - {e}") + # 保持原始引用不变 + continue + + return result + + async def _resolve_variable_reference(self, reference: str) -> Any: + """解析变量引用 + + Args: + reference: 变量引用字符串(不含花括号) + + Returns: + Any: 变量值 + """ + pool_manager = await self._get_pool_manager() + + # 解析作用域和变量名 + parts = reference.split(".", 1) + if len(parts) != 2: + raise ValueError(f"无效的变量引用格式: {reference}") + + scope_str, var_path = parts + + # 确定作用域 + scope_map = { + "sys": VariableScope.SYSTEM, + "system": VariableScope.SYSTEM, + "user": VariableScope.USER, + "env": VariableScope.ENVIRONMENT, + "environment": VariableScope.ENVIRONMENT, + "conversation": VariableScope.CONVERSATION, + "conv": VariableScope.CONVERSATION, + } + + scope = scope_map.get(scope_str) + if not scope: + raise ValueError(f"无效的变量作用域: {scope_str}") + + # 解析变量路径 + # 对于conversation作用域,支持节点输出变量格式:conversation.node_id.key + if scope == VariableScope.CONVERSATION and "." 
in var_path: + # 检查是否为节点输出变量(格式:node_id.key) + # 先尝试获取完整路径作为变量名 + try: + variable = await pool_manager.get_variable_from_any_pool( + name=var_path, # 使用完整路径作为变量名 + scope=scope, + user_id=self.user_id if scope == VariableScope.USER else None, + flow_id=self.flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=self.conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None + ) + if variable: + return variable.value + except Exception: + pass # 如果找不到,继续使用原有逻辑(避免裸except吞掉系统级异常) + + # 原有逻辑:支持嵌套访问如 user.config.api_key + path_parts = var_path.split(".") + var_name = path_parts[0] + + # 根据作用域获取变量 + variable = await pool_manager.get_variable_from_any_pool( + name=var_name, + scope=scope, + user_id=self.user_id, + flow_id=self.flow_id, + conversation_id=self.conversation_id + ) + + if not variable: + raise ValueError(f"变量不存在: {scope_str}.{var_name}") + + # 获取变量值 + value = variable.value + + # 如果有嵌套路径,继续解析 + for path_part in path_parts[1:]: + if isinstance(value, dict): + value = value.get(path_part) + elif isinstance(value, list) and path_part.isdigit(): + try: + value = value[int(path_part)] + except IndexError: + value = None + else: + raise ValueError(f"无法访问路径: {var_path}") + + return value + + async def extract_variables(self, template: str) -> List[str]: + """提取模板中的所有变量引用 + + Args: + template: 模板字符串 + + Returns: + List[str]: 变量引用列表 + """ + if not template: + return [] + + matches = self.VARIABLE_PATTERN.findall(template) + return [f"{{{{{match}}}}}" for match in matches] + + async def validate_template(self, template: str) -> Tuple[bool, List[str]]: + """验证模板中的变量引用是否都存在 + + Args: + template: 模板字符串 + + Returns: + Tuple[bool, List[str]]: (是否全部有效, 无效的变量引用列表) + """ + if not template: + return True, [] + + matches = self.VARIABLE_PATTERN.findall(template) + invalid_refs = [] + + for match in matches: + try: + await self._resolve_variable_reference(match) + except Exception: + 
invalid_refs.append(f"{{{{{match}}}}}") + + return len(invalid_refs) == 0, invalid_refs + + async def parse_json_template(self, json_template: Union[str, Dict, List]) -> Union[str, Dict, List]: + """解析JSON格式的模板,递归处理所有字符串值中的变量引用 + + Args: + json_template: JSON模板(字符串、字典或列表) + + Returns: + Union[str, Dict, List]: 解析后的JSON + """ + if isinstance(json_template, str): + return await self.parse_template(json_template) + elif isinstance(json_template, dict): + result = {} + for key, value in json_template.items(): + # 键也可能包含变量引用 + parsed_key = await self.parse_template(str(key)) + parsed_value = await self.parse_json_template(value) + result[parsed_key] = parsed_value + return result + elif isinstance(json_template, list): + result = [] + for item in json_template: + parsed_item = await self.parse_json_template(item) + result.append(parsed_item) + return result + else: + # 其他类型直接返回 + return json_template + + async def update_system_variables(self, context: Dict[str, Any]): + """更新系统变量的值 + + Args: + context: 系统上下文信息 + """ + if not self.conversation_id: + logger.warning("无法更新系统变量:缺少conversation_id") + return + + # 确保对话变量池存在 + await self.create_conversation_pool_if_needed() + + pool_manager = await self._get_pool_manager() + + # 预定义的系统变量映射 + system_var_mappings = { + "query": context.get("question", ""), + "files": context.get("files", []), + "dialogue_count": context.get("dialogue_count", 0), + "app_id": context.get("app_id", ""), + "flow_id": context.get("flow_id", self.flow_id or ""), + "user_id": context.get("user_sub", self.user_id or ""), + "session_id": context.get("session_id", ""), + "conversation_id": self.conversation_id, + "timestamp": datetime.now(UTC).timestamp(), + } + + # 获取对话变量池 + conversation_pool = await pool_manager.get_conversation_pool(self.conversation_id) + if not conversation_pool: + logger.error(f"对话变量池不存在,无法更新系统变量: {self.conversation_id}") + return + + # 更新系统变量 + updated_count = 0 + for var_name, var_value in system_var_mappings.items(): + try: + 
success = await conversation_pool.update_system_variable(var_name, var_value) + if success: + updated_count += 1 + logger.debug(f"已更新系统变量: {var_name} = {var_value}") + else: + logger.warning(f"系统变量更新失败: {var_name}") + except Exception as e: + logger.warning(f"更新系统变量失败: {var_name} - {e}") + + logger.info(f"系统变量更新完成: {updated_count}/{len(system_var_mappings)} 个变量更新成功") + + async def update_conversation_variable(self, var_name: str, value: Any) -> bool: + """更新对话变量的值 + + Args: + var_name: 变量名 + value: 新值 + + Returns: + bool: 是否更新成功 + """ + if not self.conversation_id: + logger.warning("无法更新对话变量:缺少conversation_id") + return False + + pool_manager = await self._get_pool_manager() + conversation_pool = await pool_manager.get_conversation_pool(self.conversation_id) + + if not conversation_pool: + logger.warning(f"无法获取对话变量池: {self.conversation_id}") + return False + + try: + await conversation_pool.update_variable(var_name, value=value) + logger.info(f"已更新对话变量: {var_name} = {value}") + return True + except Exception as e: + logger.error(f"更新对话变量失败: {var_name} - {e}") + return False + + async def create_conversation_pool_if_needed(self) -> bool: + """如果需要,创建对话变量池 + + Returns: + bool: 是否创建成功 + """ + if not self.conversation_id or not self.flow_id: + return False + + pool_manager = await self._get_pool_manager() + existing_pool = await pool_manager.get_conversation_pool(self.conversation_id) + + if existing_pool: + return True + + try: + await pool_manager.create_conversation_pool(self.conversation_id, self.flow_id) + logger.info(f"已创建对话变量池: {self.conversation_id}") + return True + except Exception as e: + logger.error(f"创建对话变量池失败: {self.conversation_id} - {e}") + return False + + def _convert_to_string(self, value: Any) -> str: + """将值转换为字符串 + + Args: + value: 要转换的值 + + Returns: + str: 字符串表示 + """ + if value is None: + return "" + elif isinstance(value, str): + return value + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, (int, float)): + 
return str(value) + elif isinstance(value, (dict, list)): + try: + return json.dumps(value, ensure_ascii=False, separators=(',', ':')) + except (TypeError, ValueError): + return str(value) + else: + return str(value) + + @classmethod + def escape_variable_reference(cls, text: str) -> str: + """转义变量引用,防止被解析 + + Args: + text: 包含变量引用的文本 + + Returns: + str: 转义后的文本 + """ + return text.replace("{{", "\\{\\{").replace("}}", "\\}\\}") + + @classmethod + def unescape_variable_reference(cls, text: str) -> str: + """取消转义变量引用 + + Args: + text: 转义的文本 + + Returns: + str: 取消转义后的文本 + """ + return text.replace("\\{\\{", "{{").replace("\\}\\}", "}}") + + +class VariableReferenceBuilder: + """变量引用构建器 - 帮助构建标准的变量引用字符串""" + + @staticmethod + def system(var_name: str, nested_path: Optional[str] = None) -> str: + """构建系统变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径(如 config.api_key) + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{sys.{var_name}.{nested_path}}}}}" + return f"{{{{sys.{var_name}}}}}" + + @staticmethod + def user(var_name: str, nested_path: Optional[str] = None) -> str: + """构建用户变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径 + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{user.{var_name}.{nested_path}}}}}" + return f"{{{{user.{var_name}}}}}" + + @staticmethod + def environment(var_name: str, nested_path: Optional[str] = None) -> str: + """构建环境变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径 + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{env.{var_name}.{nested_path}}}}}" + return f"{{{{env.{var_name}}}}}" + + @staticmethod + def conversation(var_name: str, nested_path: Optional[str] = None) -> str: + """构建对话变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径 + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{conversation.{var_name}.{nested_path}}}}}" + return f"{{{{conversation.{var_name}}}}}" + + +class VariableContext: + """变量上下文管理器 - 管理局部变量作用域""" + + def __init__(self, + parser: 
VariableParser, + parent_context: Optional["VariableContext"] = None): + """初始化变量上下文 + + Args: + parser: 变量解析器 + parent_context: 父级上下文(用于嵌套作用域) + """ + self.parser = parser + self.parent_context = parent_context + self._local_variables: Dict[str, Any] = {} + + def set_local_variable(self, name: str, value: Any): + """设置局部变量 + + Args: + name: 变量名 + value: 变量值 + """ + self._local_variables[name] = value + + def get_local_variable(self, name: str) -> Any: + """获取局部变量 + + Args: + name: 变量名 + + Returns: + Any: 变量值 + """ + if name in self._local_variables: + return self._local_variables[name] + elif self.parent_context: + return self.parent_context.get_local_variable(name) + return None + + async def parse_with_locals(self, template: str) -> str: + """使用局部变量解析模板 + + Args: + template: 模板字符串 + + Returns: + str: 解析后的字符串 + """ + # 首先用局部变量替换 + result = template + + # 替换局部变量(使用简单的 ${var_name} 语法) + for var_name, var_value in self._local_variables.items(): + pattern = f"${{{var_name}}}" + str_value = self.parser._convert_to_string(var_value) + result = result.replace(pattern, str_value) + + # 然后用全局变量解析器处理剩余的变量引用 + return await self.parser.parse_template(result) + + def create_child_context(self) -> "VariableContext": + """创建子级上下文 + + Returns: + VariableContext: 子级上下文 + """ + return VariableContext(self.parser, self) \ No newline at end of file diff --git a/apps/scheduler/variable/pool_base.py b/apps/scheduler/variable/pool_base.py new file mode 100644 index 0000000000000000000000000000000000000000..83acdffdceaa4409142d56c9c96675ebedf93680 --- /dev/null +++ b/apps/scheduler/variable/pool_base.py @@ -0,0 +1,828 @@ +import logging +import asyncio +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Set +from datetime import datetime, UTC + +from apps.common.mongo import MongoDB +from .base import BaseVariable, VariableMetadata +from .type import VariableType, VariableScope +from .variables import create_variable, VARIABLE_CLASS_MAP + +logger = 
logging.getLogger(__name__) + + +class BaseVariablePool(ABC): + """变量池基类""" + + def __init__(self, pool_id: str, scope: VariableScope): + """初始化变量池 + + Args: + pool_id: 池标识符(如user_id、flow_id、conversation_id等) + scope: 池作用域 + """ + self.pool_id = pool_id + self.scope = scope + self._variables: Dict[str, BaseVariable] = {} + self._initialized = False + self._lock = asyncio.Lock() + + @property + def is_initialized(self) -> bool: + """检查是否已初始化""" + return self._initialized + + async def initialize(self): + """初始化变量池""" + async with self._lock: + if not self._initialized: + await self._load_variables() + await self._setup_default_variables() + self._initialized = True + logger.info(f"已初始化变量池: {self.__class__.__name__}({self.pool_id})") + + @abstractmethod + async def _load_variables(self): + """从存储加载变量""" + pass + + @abstractmethod + async def _setup_default_variables(self): + """设置默认变量""" + pass + + @abstractmethod + def can_modify(self) -> bool: + """检查是否允许修改变量""" + pass + + async def add_variable(self, + name: str, + var_type: VariableType, + value: Any = None, + description: Optional[str] = None, + created_by: Optional[str] = None, + is_system: bool = False) -> BaseVariable: + """添加变量""" + if not self.can_modify(): + raise PermissionError(f"不允许修改{self.scope.value}级变量") + + if name in self._variables: + raise ValueError(f"变量 {name} 已存在") + + # 创建变量元数据 + metadata = VariableMetadata( + name=name, + var_type=var_type, + scope=self.scope, + description=description, + user_sub=getattr(self, 'user_id', None), + flow_id=getattr(self, 'flow_id', None), + conversation_id=getattr(self, 'conversation_id', None), + created_by=created_by or "system", + is_system=is_system # 标记是否为系统变量 + ) + + # 创建变量 + variable = create_variable(metadata, value) + self._variables[name] = variable + + # 持久化 + await self._persist_variable(variable) + + logger.info(f"已添加{'系统' if is_system else ''}变量: {name} 到池 {self.pool_id}") + return variable + + async def update_variable(self, + name: str, + 
value: Optional[Any] = None, + var_type: Optional[VariableType] = None, + description: Optional[str] = None, + force_system_update: bool = False) -> BaseVariable: + """更新变量""" + if name not in self._variables: + raise ValueError(f"变量 {name} 不存在") + + variable = self._variables[name] + + # 检查系统变量的修改权限 + if hasattr(variable.metadata, 'is_system') and variable.metadata.is_system and not force_system_update: + raise PermissionError(f"系统变量 {name} 不允许修改") + + if not self.can_modify() and not force_system_update: + raise PermissionError(f"不允许修改{self.scope.value}级变量") + + # 更新字段 + if value is not None: + variable.value = value + if var_type is not None: + variable.metadata.var_type = var_type + if description is not None: + variable.metadata.description = description + + variable.metadata.updated_at = datetime.now(UTC) + + # 持久化 + await self._persist_variable(variable) + + logger.info(f"已更新变量: {name} 在池 {self.pool_id}, 值为{value}") + return variable + + async def delete_variable(self, name: str) -> bool: + """删除变量""" + if not self.can_modify(): + raise PermissionError(f"不允许修改{self.scope.value}级变量") + + if name not in self._variables: + return False + + variable = self._variables[name] + + # 检查是否为系统变量 + if hasattr(variable.metadata, 'is_system') and variable.metadata.is_system: + raise PermissionError(f"系统变量 {name} 不允许删除") + + del self._variables[name] + + # 从数据库删除 + await self._delete_variable_from_db(variable) + + logger.info(f"已删除变量: {name} 从池 {self.pool_id}") + return True + + async def get_variable(self, name: str) -> Optional[BaseVariable]: + """获取变量""" + return self._variables.get(name) + + async def list_variables(self, include_system: bool = True) -> List[BaseVariable]: + """列出所有变量""" + if include_system: + return list(self._variables.values()) + else: + # 只返回非系统变量 + return [var for var in self._variables.values() + if not (hasattr(var.metadata, 'is_system') and var.metadata.is_system)] + + async def list_system_variables(self) -> List[BaseVariable]: + """列出系统变量""" 
+ return [var for var in self._variables.values() + if hasattr(var.metadata, 'is_system') and var.metadata.is_system] + + async def has_variable(self, name: str) -> bool: + """检查变量是否存在""" + return name in self._variables + + async def copy_variables(self) -> Dict[str, BaseVariable]: + """拷贝所有变量""" + copied = {} + for name, variable in self._variables.items(): + # 创建新的元数据 + new_metadata = VariableMetadata( + name=variable.metadata.name, + var_type=variable.metadata.var_type, + scope=variable.metadata.scope, + description=variable.metadata.description, + user_sub=variable.metadata.user_sub, + flow_id=variable.metadata.flow_id, + conversation_id=variable.metadata.conversation_id, + created_by=variable.metadata.created_by, + is_system=getattr(variable.metadata, 'is_system', False) + ) + # 创建新的变量实例 + copied[name] = create_variable(new_metadata, variable.value) + return copied + + async def _persist_variable(self, variable: BaseVariable): + """持久化变量""" + try: + collection = MongoDB().get_collection("variables") + data = variable.serialize() + + # 构建查询条件 + query = { + "metadata.name": variable.name, + "metadata.scope": variable.scope.value + } + + # 添加池特定的查询条件 + self._add_pool_query_conditions(query, variable) + + # 更新或插入 + from pymongo import WriteConcern + result = await collection.with_options( + write_concern=WriteConcern(w="majority", j=True) + ).replace_one(query, data, upsert=True) + + if not (result.acknowledged and (result.matched_count > 0 or result.upserted_id)): + raise RuntimeError(f"变量持久化失败: {variable.name}") + + except Exception as e: + logger.error(f"持久化变量失败: {e}") + raise + + async def _delete_variable_from_db(self, variable: BaseVariable): + """从数据库删除变量""" + try: + collection = MongoDB().get_collection("variables") + + query = { + "metadata.name": variable.name, + "metadata.scope": variable.scope.value + } + + # 添加池特定的查询条件 + self._add_pool_query_conditions(query, variable) + + from pymongo import WriteConcern + result = await collection.with_options( + 
write_concern=WriteConcern(w="majority", j=True) + ).delete_one(query) + + if not result.acknowledged: + raise RuntimeError(f"变量删除失败: {variable.name}") + + except Exception as e: + logger.error(f"删除变量失败: {e}") + raise + + @abstractmethod + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加池特定的查询条件""" + pass + + async def resolve_variable_reference(self, reference: str) -> Any: + """解析变量引用""" + # 移除 {{ 和 }} + clean_ref = reference.strip("{}").strip() + + # 解析变量路径 + path_parts = clean_ref.split(".") + var_name = path_parts[0] + + # 获取变量 + variable = await self.get_variable(var_name) + if not variable: + raise ValueError(f"变量不存在: {var_name}") + + # 获取变量值 + value = variable.value + + # 处理嵌套路径 + for path_part in path_parts[1:]: + if isinstance(value, dict): + value = value.get(path_part) + elif isinstance(value, list) and path_part.isdigit(): + try: + value = value[int(path_part)] + except IndexError: + value = None + else: + raise ValueError(f"无法访问路径: {clean_ref}") + + return value + + +class UserVariablePool(BaseVariablePool): + """用户变量池""" + + def __init__(self, user_id: str): + super().__init__(user_id, VariableScope.USER) + self.user_id = user_id + + async def _load_variables(self): + """从数据库加载用户变量""" + try: + collection = MongoDB().get_collection("variables") + cursor = collection.find({ + "metadata.scope": VariableScope.USER.value, + "metadata.user_sub": self.user_id + }) + + loaded_count = 0 + async for doc in cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._variables[variable.name] = variable + loaded_count += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"用户变量 {var_name} 数据损坏: {e}") + + logger.debug(f"用户 {self.user_id} 
加载变量完成: {loaded_count} 个") + + except Exception as e: + logger.error(f"加载用户变量失败: {e}") + + async def _setup_default_variables(self): + """用户变量池不需要默认变量""" + pass + + def can_modify(self) -> bool: + """用户变量允许修改""" + return True + + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加用户变量池的查询条件""" + query["metadata.user_sub"] = self.user_id + + +class FlowVariablePool(BaseVariablePool): + """流程变量池(环境变量 + 系统变量模板 + 对话变量模板)""" + + def __init__(self, flow_id: str, parent_flow_id: Optional[str] = None): + super().__init__(flow_id, VariableScope.ENVIRONMENT) # 保持主要scope为ENVIRONMENT + self.flow_id = flow_id + self.parent_flow_id = parent_flow_id + + # 分别存储不同类型的变量 + # _variables 继续存储环境变量(保持向后兼容) + self._system_templates: Dict[str, BaseVariable] = {} # 系统变量模板 + self._conversation_templates: Dict[str, BaseVariable] = {} # 对话变量模板 + + async def _load_variables(self): + """从数据库加载所有类型的变量(环境变量 + 模板变量)""" + try: + collection = MongoDB().get_collection("variables") + loaded_counts = {"environment": 0, "system_templates": 0, "conversation_templates": 0} + + # 1. 加载环境变量 + env_cursor = collection.find({ + "metadata.scope": VariableScope.ENVIRONMENT.value, + "metadata.flow_id": self.flow_id + }) + + async for doc in env_cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._variables[variable.name] = variable + loaded_counts["environment"] += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"环境变量 {var_name} 数据损坏: {e}") + + # 2. 
加载系统变量模板 + system_template_cursor = collection.find({ + "metadata.scope": VariableScope.SYSTEM.value, + "metadata.flow_id": self.flow_id, + "metadata.is_template": True + }) + + async for doc in system_template_cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._system_templates[variable.name] = variable + loaded_counts["system_templates"] += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"系统变量模板 {var_name} 数据损坏: {e}") + + # 3. 加载对话变量模板 + conv_template_cursor = collection.find({ + "metadata.scope": VariableScope.CONVERSATION.value, + "metadata.flow_id": self.flow_id, + "metadata.is_template": True + }) + + async for doc in conv_template_cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._conversation_templates[variable.name] = variable + loaded_counts["conversation_templates"] += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"对话变量模板 {var_name} 数据损坏: {e}") + + total_loaded = sum(loaded_counts.values()) + logger.debug(f"流程 {self.flow_id} 加载变量完成: 环境变量{loaded_counts['environment']}个, " + f"系统模板{loaded_counts['system_templates']}个, " + f"对话模板{loaded_counts['conversation_templates']}个, 总计{total_loaded}个") + + except Exception as e: + logger.error(f"加载流程变量失败: {e}") + + async def _setup_default_variables(self): + """设置默认的系统变量模板""" + from datetime import datetime, UTC + + # 定义系统变量模板(这些是模板,不是实例) + system_var_templates = [ + ("query", VariableType.STRING, "用户查询内容", ""), + ("files", 
VariableType.ARRAY_FILE, "List of files uploaded by the user", []),
+            ("dialogue_count", VariableType.NUMBER, "Number of dialogue turns", 0),
+            ("app_id", VariableType.STRING, "Application ID", ""),
+            ("flow_id", VariableType.STRING, "Workflow ID", self.flow_id),
+            ("user_id", VariableType.STRING, "User ID", ""),
+            ("session_id", VariableType.STRING, "Session ID", ""),
+            ("conversation_id", VariableType.STRING, "Conversation ID", ""),
+            ("timestamp", VariableType.NUMBER, "Current timestamp", 0),
+        ]
+
+        created_count = 0
+        for var_name, var_type, description, default_value in system_var_templates:
+            # Create the system variable template only if it does not exist yet
+            if var_name not in self._system_templates:
+                metadata = VariableMetadata(
+                    name=var_name,
+                    var_type=var_type,
+                    scope=VariableScope.SYSTEM,
+                    description=description,
+                    flow_id=self.flow_id,
+                    created_by="system",
+                    is_system=True,
+                    is_template=True  # mark as a template
+                )
+                variable = create_variable(metadata, default_value)
+                self._system_templates[var_name] = variable
+
+                # Persist the template to the database
+                try:
+                    await self._persist_variable(variable)
+                    created_count += 1
+                    logger.debug(f"Persisted system variable template: {var_name}")
+                except Exception as e:
+                    logger.error(f"Failed to persist system variable template: {var_name} - {e}")
+
+        if created_count > 0:
+            logger.info(f"Initialized {created_count} system variable templates for flow {self.flow_id}")
+
+    def can_modify(self) -> bool:
+        """Environment variables may be modified."""
+        return True
+
+    # === System variable template methods ===
+
+    async def get_system_template(self, name: str) -> Optional[BaseVariable]:
+        """Get a system variable template."""
+        return self._system_templates.get(name)
+
+    async def list_system_templates(self) -> List[BaseVariable]:
+        """List all system variable templates."""
+        return list(self._system_templates.values())
+
+    async def add_system_template(self, name: str, var_type: VariableType,
+                                  default_value: Any = None,
+                                  description: Optional[str] = None) -> BaseVariable:
+        """Add a system variable template."""
+        if name in self._system_templates:
+            raise ValueError(f"System variable template {name} already exists")
+
+        metadata = VariableMetadata(
+            name=name,
+            var_type=var_type,
+            scope=VariableScope.SYSTEM,
+            description=description,
+            flow_id=self.flow_id,
+            created_by="system",
+            is_system=True,
+            is_template=True
+        )
+
+        variable = create_variable(metadata, default_value)
+        self._system_templates[name] = variable
+
+        # Persist to the database
+        await self._persist_variable(variable)
+
+        logger.info(f"Added system variable template {name} to flow {self.flow_id}")
+        return variable
+
+    # === Conversation variable template methods ===
+
+    async def get_conversation_template(self, name: str) -> Optional[BaseVariable]:
+        """Get a conversation variable template."""
+        return self._conversation_templates.get(name)
+
+    async def list_conversation_templates(self) -> List[BaseVariable]:
+        """List all conversation variable templates."""
+        return list(self._conversation_templates.values())
+
+    async def add_conversation_template(self, name: str, var_type: VariableType,
+                                        default_value: Any = None,
+                                        description: Optional[str] = None,
+                                        created_by: Optional[str] = None) -> BaseVariable:
+        """Add a conversation variable template."""
+        if name in self._conversation_templates:
+            raise ValueError(f"Conversation variable template {name} already exists")
+
+        metadata = VariableMetadata(
+            name=name,
+            var_type=var_type,
+            scope=VariableScope.CONVERSATION,
+            description=description,
+            flow_id=self.flow_id,
+            created_by=created_by or "user",
+            is_system=False,
+            is_template=True
+        )
+
+        variable = create_variable(metadata, default_value)
+        self._conversation_templates[name] = variable
+
+        # Persist to the database
+        await self._persist_variable(variable)
+
+        logger.info(f"Added conversation variable template {name} to flow {self.flow_id}")
+        return variable
+
+    # === Overrides of base-class methods for multi-scope lookup ===
+
+    async def get_variable_by_scope(self, name: str, scope: VariableScope) -> Optional[BaseVariable]:
+        """Get a variable by scope."""
+        if scope == VariableScope.ENVIRONMENT:
+            return self._variables.get(name)
+        elif scope == VariableScope.SYSTEM:
+            return self._system_templates.get(name)
+        elif scope == VariableScope.CONVERSATION:
+            return self._conversation_templates.get(name)
+        else:
+            return None
+
+    async def list_variables_by_scope(self, scope: VariableScope) -> List[BaseVariable]:
+        """List variables by scope."""
+        if scope == VariableScope.ENVIRONMENT:
+            return list(self._variables.values())
+        elif scope == VariableScope.SYSTEM:
+            return list(self._system_templates.values())
+        elif scope 
== VariableScope.CONVERSATION: + return list(self._conversation_templates.values()) + else: + return [] + + # === 重写基类方法支持多字典操作 === + + async def update_variable(self, name: str, value: Any = None, + var_type: Optional[VariableType] = None, + description: Optional[str] = None, + force_system_update: bool = False) -> BaseVariable: + """更新变量(支持多字典查找)""" + + # 先在环境变量中查找 + if name in self._variables: + return await super().update_variable(name, value, var_type, description, force_system_update) + + # 在系统变量模板中查找 + elif name in self._system_templates: + variable = self._system_templates[name] + + # 检查权限 + if not force_system_update and getattr(variable.metadata, 'is_system', False): + raise PermissionError(f"系统变量 {name} 不允许直接修改") + + # 更新变量 + if value is not None: + variable.value = value + if var_type is not None: + variable.metadata.var_type = var_type + if description is not None: + variable.metadata.description = description + + # 持久化 + await self._persist_variable(variable) + return variable + + # 在对话变量模板中查找 + elif name in self._conversation_templates: + variable = self._conversation_templates[name] + + # 更新变量 + if value is not None: + variable.value = value + if var_type is not None: + variable.metadata.var_type = var_type + if description is not None: + variable.metadata.description = description + + # 持久化 + await self._persist_variable(variable) + return variable + + else: + raise ValueError(f"变量 {name} 不存在") + + async def delete_variable(self, name: str) -> bool: + """删除变量(支持多字典查找)""" + + # 先在环境变量中查找 + if name in self._variables: + return await super().delete_variable(name) + + # 在系统变量模板中查找 + elif name in self._system_templates: + variable = self._system_templates[name] + + # 检查权限 + if getattr(variable.metadata, 'is_system', False): + raise PermissionError(f"系统变量模板 {name} 不允许删除") + + del self._system_templates[name] + await self._delete_variable_from_db(variable) + return True + + # 在对话变量模板中查找 + elif name in self._conversation_templates: + variable = 
self._conversation_templates[name] + del self._conversation_templates[name] + await self._delete_variable_from_db(variable) + return True + + else: + return False + + async def get_variable(self, name: str) -> Optional[BaseVariable]: + """获取变量(支持多字典查找)""" + + # 先在环境变量中查找 + if name in self._variables: + return self._variables[name] + + # 在系统变量模板中查找 + elif name in self._system_templates: + return self._system_templates[name] + + # 在对话变量模板中查找 + elif name in self._conversation_templates: + return self._conversation_templates[name] + + else: + return None + + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加环境变量池的查询条件""" + query["metadata.flow_id"] = self.flow_id + + async def inherit_from_parent(self, parent_pool: "FlowVariablePool"): + """从父流程继承环境变量""" + parent_variables = await parent_pool.copy_variables() + for name, variable in parent_variables.items(): + # 更新元数据中的flow_id + variable.metadata.flow_id = self.flow_id + self._variables[name] = variable + # 持久化继承的变量 + await self._persist_variable(variable) + + logger.info(f"流程 {self.flow_id} 从父流程 {parent_pool.flow_id} 继承了 {len(parent_variables)} 个环境变量") + + +class ConversationVariablePool(BaseVariablePool): + """对话变量池 - 包含系统变量和对话变量""" + + def __init__(self, conversation_id: str, flow_id: str): + super().__init__(conversation_id, VariableScope.CONVERSATION) + self.conversation_id = conversation_id + self.flow_id = flow_id + + async def _load_variables(self): + """从数据库加载对话变量""" + try: + collection = MongoDB().get_collection("variables") + cursor = collection.find({ + "metadata.scope": VariableScope.CONVERSATION.value, + "metadata.conversation_id": self.conversation_id + }) + + loaded_count = 0 + async for doc in cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = 
var_class.deserialize(doc) + self._variables[variable.name] = variable + loaded_count += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"对话变量 {var_name} 数据损坏: {e}") + + logger.debug(f"对话 {self.conversation_id} 加载变量完成: {loaded_count} 个") + + except Exception as e: + logger.error(f"加载对话变量失败: {e}") + + async def _setup_default_variables(self): + """从flow模板继承系统变量和对话变量""" + from .pool_manager import get_pool_manager + + try: + pool_manager = await get_pool_manager() + flow_pool = await pool_manager.get_flow_pool(self.flow_id) + + if not flow_pool: + logger.warning(f"未找到流程池 {self.flow_id},无法继承变量模板") + return + + created_count = 0 + + # 1. 从系统变量模板创建系统变量实例 + system_templates = await flow_pool.list_system_templates() + for template in system_templates: + if template.name not in self._variables: + # 创建系统变量实例(不是模板) + metadata = VariableMetadata( + name=template.name, + var_type=template.var_type, + scope=VariableScope.CONVERSATION, # 存储在对话作用域 + description=template.metadata.description, + flow_id=self.flow_id, + conversation_id=self.conversation_id, + created_by="system", + is_system=True, # 标记为系统变量 + is_template=False # 这是实例,不是模板 + ) + + # 使用模板的默认值创建实例 + variable = create_variable(metadata, template.value) + self._variables[template.name] = variable + + # 持久化系统变量实例 + try: + await self._persist_variable(variable) + created_count += 1 + logger.debug(f"已从模板创建系统变量实例: {template.name}") + except Exception as e: + logger.error(f"持久化系统变量实例失败: {template.name} - {e}") + + # 2. 
从对话变量模板创建对话变量实例 + conversation_templates = await flow_pool.list_conversation_templates() + for template in conversation_templates: + if template.name not in self._variables: + # 创建对话变量实例 + metadata = VariableMetadata( + name=template.name, + var_type=template.var_type, + scope=VariableScope.CONVERSATION, + description=template.metadata.description, + flow_id=self.flow_id, + conversation_id=self.conversation_id, + created_by=template.metadata.created_by, + is_system=False, # 对话变量 + is_template=False # 这是实例,不是模板 + ) + + # 使用模板的默认值创建实例 + variable = create_variable(metadata, template.value) + self._variables[template.name] = variable + + # 持久化对话变量实例 + try: + await self._persist_variable(variable) + created_count += 1 + logger.debug(f"已从模板创建对话变量实例: {template.name}") + except Exception as e: + logger.error(f"持久化对话变量实例失败: {template.name} - {e}") + + if created_count > 0: + logger.info(f"已为对话 {self.conversation_id} 从流程模板继承 {created_count} 个变量") + + except Exception as e: + logger.error(f"从流程模板继承变量失败: {e}") + + def can_modify(self) -> bool: + """对话变量允许修改""" + return True + + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加对话变量池的查询条件""" + query["metadata.conversation_id"] = self.conversation_id + query["metadata.flow_id"] = self.flow_id + + async def update_system_variable(self, name: str, value: Any) -> bool: + """更新系统变量的值(系统内部调用)""" + try: + await self.update_variable(name, value=value, force_system_update=True) + return True + except Exception as e: + logger.error(f"更新系统变量失败: {name} - {e}") + return False + + async def inherit_from_conversation_template(self, template_pool: Optional["ConversationVariablePool"] = None): + """从对话模板池继承变量(如果存在)""" + if template_pool: + template_variables = await template_pool.copy_variables() + for name, variable in template_variables.items(): + # 只继承非系统变量 + if not (hasattr(variable.metadata, 'is_system') and variable.metadata.is_system): + variable.metadata.conversation_id = self.conversation_id + 
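Both inheritance loops in `_setup_default_variables` follow one recipe: copy the template's metadata, stamp the conversation id onto it, flip `is_template` off, and seed the instance with the template's default value. A runnable reduction of that recipe — `VarMeta` here is a trimmed stand-in dataclass, not the project's `VariableMetadata`:

```python
from dataclasses import dataclass, replace
from typing import Any, Optional

@dataclass(frozen=True)
class VarMeta:
    """Stand-in for VariableMetadata, trimmed to the fields the recipe touches."""
    name: str
    scope: str
    flow_id: str
    conversation_id: Optional[str] = None
    is_template: bool = True

def instantiate(template: VarMeta, default_value: Any, conversation_id: str) -> tuple:
    """Turn a flow-level template into a per-conversation instance."""
    meta = replace(template, conversation_id=conversation_id, is_template=False)
    return meta, default_value

meta, value = instantiate(VarMeta("dialogue_count", "conversation", "flow-1"), 0, "conv-9")
assert meta.is_template is False
assert meta.conversation_id == "conv-9"
assert value == 0
```

Because the template object itself is never mutated, the same template can seed any number of conversations.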
self._variables[name] = variable + + logger.info(f"对话 {self.conversation_id} 从模板继承了 {len(template_variables)} 个变量") \ No newline at end of file diff --git a/apps/scheduler/variable/pool_manager.py b/apps/scheduler/variable/pool_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..479015add1115e479ce1919e004906f91224c579 --- /dev/null +++ b/apps/scheduler/variable/pool_manager.py @@ -0,0 +1,353 @@ +import logging +import asyncio +from typing import Dict, List, Optional, Set, Tuple, Any +from contextlib import asynccontextmanager + +from apps.common.mongo import MongoDB +from .pool_base import ( + BaseVariablePool, + UserVariablePool, + FlowVariablePool, + ConversationVariablePool +) +from .type import VariableScope +from .base import BaseVariable + +logger = logging.getLogger(__name__) + + +class VariablePoolManager: + """变量池管理器 - 管理所有类型变量池的生命周期""" + + def __init__(self): + """初始化变量池管理器""" + # 用户变量池缓存: user_id -> UserVariablePool + self._user_pools: Dict[str, UserVariablePool] = {} + + # 流程变量池缓存: flow_id -> FlowVariablePool + self._flow_pools: Dict[str, FlowVariablePool] = {} + + # 对话变量池缓存: conversation_id -> ConversationVariablePool + self._conversation_pools: Dict[str, ConversationVariablePool] = {} + + # 流程继承关系缓存: child_flow_id -> parent_flow_id + self._flow_inheritance: Dict[str, str] = {} + + self._initialized = False + self._lock = asyncio.Lock() + + async def initialize(self): + """初始化变量池管理器""" + async with self._lock: + if not self._initialized: + await self._load_existing_entities() + await self._patrol_and_create_missing_pools() + self._initialized = True + logger.info("变量池管理器初始化完成") + + async def _load_existing_entities(self): + """加载现有的用户和流程实体""" + try: + # 这里应该从相应的用户和流程数据库表中加载 + # 目前先从变量表中推断存在的实体 + collection = MongoDB().get_collection("variables") + + # 获取所有唯一的用户ID + user_ids = await collection.distinct("metadata.user_sub", { + "metadata.user_sub": {"$ne": None} + }) + logger.info(f"发现 {len(user_ids)} 个用户需要变量池") + + # 
获取所有唯一的流程ID + flow_ids = await collection.distinct("metadata.flow_id", { + "metadata.flow_id": {"$ne": None} + }) + logger.info(f"发现 {len(flow_ids)} 个流程需要变量池") + + # 缓存实体信息用于后续创建池 + self._discovered_users = set(user_ids) + self._discovered_flows = set(flow_ids) + + except Exception as e: + logger.error(f"加载现有实体失败: {e}") + self._discovered_users = set() + self._discovered_flows = set() + + async def _patrol_and_create_missing_pools(self): + """巡检并创建缺失的变量池""" + logger.info("开始巡检并创建缺失的变量池...") + + # 为所有发现的用户创建用户变量池 + created_user_pools = 0 + for user_id in self._discovered_users: + if user_id not in self._user_pools: + await self._create_user_pool(user_id) + created_user_pools += 1 + + # 为所有发现的流程创建流程变量池 + created_flow_pools = 0 + for flow_id in self._discovered_flows: + if flow_id not in self._flow_pools: + await self._create_flow_pool(flow_id) + created_flow_pools += 1 + + logger.info(f"巡检完成: 创建了 {created_user_pools} 个用户池, " + f"{created_flow_pools} 个流程池") + + async def get_user_pool(self, user_id: str, auto_create: bool = True) -> Optional[UserVariablePool]: + """获取用户变量池""" + if user_id in self._user_pools: + return self._user_pools[user_id] + + if auto_create: + return await self._create_user_pool(user_id) + + return None + + async def get_flow_pool(self, flow_id: str, parent_flow_id: Optional[str] = None, + auto_create: bool = True) -> Optional[FlowVariablePool]: + """获取流程变量池""" + if flow_id in self._flow_pools: + return self._flow_pools[flow_id] + + if auto_create: + return await self._create_flow_pool(flow_id, parent_flow_id) + + return None + + async def create_conversation_pool(self, conversation_id: str, flow_id: str) -> ConversationVariablePool: + """创建对话变量池(包含系统变量和对话变量)""" + if conversation_id in self._conversation_pools: + logger.warning(f"对话池 {conversation_id} 已存在,将覆盖") + + # 创建对话变量池 + conversation_pool = ConversationVariablePool(conversation_id, flow_id) + await conversation_pool.initialize() + + # 从对话模板池继承变量(如果存在) + conversation_template_pool = await 
self._get_conversation_template_pool(flow_id) + await conversation_pool.inherit_from_conversation_template(conversation_template_pool) + + # 缓存池 + self._conversation_pools[conversation_id] = conversation_pool + + logger.info(f"已创建对话变量池: {conversation_id}") + return conversation_pool + + async def get_conversation_pool(self, conversation_id: str) -> Optional[ConversationVariablePool]: + """获取对话变量池""" + return self._conversation_pools.get(conversation_id) + + async def remove_conversation_pool(self, conversation_id: str) -> bool: + """移除对话变量池""" + if conversation_id in self._conversation_pools: + del self._conversation_pools[conversation_id] + logger.info(f"已移除对话变量池: {conversation_id}") + return True + return False + + async def get_variable_from_any_pool(self, + name: str, + scope: VariableScope, + user_id: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> Optional[BaseVariable]: + """从任意池中获取变量""" + if scope == VariableScope.USER and user_id: + pool = await self.get_user_pool(user_id) + return await pool.get_variable(name) if pool else None + + elif scope == VariableScope.ENVIRONMENT and flow_id: + pool = await self.get_flow_pool(flow_id) + return await pool.get_variable(name) if pool else None + + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 使用conversation_id查询对话变量实例 + pool = await self.get_conversation_pool(conversation_id) + return await pool.get_variable(name) if pool else None + elif flow_id: + # 使用flow_id查询对话变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.get_conversation_template(name) + return None + + # 系统变量处理 + elif scope == VariableScope.SYSTEM: + if conversation_id: + # 优先使用conversation_id查询实际的系统变量实例 + pool = await self.get_conversation_pool(conversation_id) + if pool: + variable = await pool.get_variable(name) + # 检查是否为系统变量 + if variable and hasattr(variable.metadata, 'is_system') and variable.metadata.is_system: + return variable + elif 
flow_id: + # 使用flow_id查询系统变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.get_system_template(name) + + return None + + async def list_variables_from_any_pool(self, + scope: VariableScope, + user_id: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> List[BaseVariable]: + """从任意池中列出变量""" + if scope == VariableScope.USER and user_id: + pool = await self.get_user_pool(user_id) + return await pool.list_variables() if pool else [] + + elif scope == VariableScope.ENVIRONMENT and flow_id: + pool = await self.get_flow_pool(flow_id) + return await pool.list_variables() if pool else [] + + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 使用conversation_id查询对话变量实例 + pool = await self.get_conversation_pool(conversation_id) + if pool: + # 只返回非系统变量 + return await pool.list_variables(include_system=False) + elif flow_id: + # 使用flow_id查询对话变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.list_conversation_templates() + return [] + + # 系统变量处理 + elif scope == VariableScope.SYSTEM: + if conversation_id: + # 优先使用conversation_id查询实际的系统变量实例 + pool = await self.get_conversation_pool(conversation_id) + if pool: + # 只返回系统变量 + return await pool.list_system_variables() + elif flow_id: + # 使用flow_id查询系统变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.list_system_templates() + return [] + + return [] + + async def update_system_variable(self, conversation_id: str, name: str, value: Any) -> bool: + """更新对话中的系统变量""" + conversation_pool = await self.get_conversation_pool(conversation_id) + if conversation_pool: + return await conversation_pool.update_system_variable(name, value) + return False + + async def _create_user_pool(self, user_id: str) -> UserVariablePool: + """创建用户变量池""" + pool = UserVariablePool(user_id) + await pool.initialize() + self._user_pools[user_id] = pool + 
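The routing shared by `get_variable_from_any_pool` and `list_variables_from_any_pool` — conversation and system scopes prefer a live conversation pool and fall back to flow-level templates when only `flow_id` is given — can be expressed as a small pure function. This is an illustrative reduction of the branch logic, not part of the module:

```python
from typing import Optional

def pick_source(scope: str,
                user_id: Optional[str] = None,
                flow_id: Optional[str] = None,
                conversation_id: Optional[str] = None) -> Optional[str]:
    """Return which pool a lookup should hit, mirroring the branches above."""
    if scope == "user" and user_id:
        return "user_pool"
    if scope == "environment" and flow_id:
        return "flow_pool"
    if scope in ("conversation", "system"):
        if conversation_id:
            return "conversation_pool"   # live variable instances
        if flow_id:
            return "flow_templates"      # conversation/system templates on the flow pool
    return None

# The case from the bug report: scope=conversation with only a flow_id
assert pick_source("conversation", flow_id="52e069c7") == "flow_templates"
assert pick_source("conversation", conversation_id="c1") == "conversation_pool"
```

The missing `flow_id` fallback for the conversation scope was exactly the gap described in CONVERSATION_VARIABLE_FIX.md.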
logger.info(f"已创建用户变量池: {user_id}") + return pool + + async def _create_flow_pool(self, flow_id: str, parent_flow_id: Optional[str] = None) -> FlowVariablePool: + """创建流程变量池""" + pool = FlowVariablePool(flow_id, parent_flow_id) + await pool.initialize() + + # 如果有父流程,从父流程继承变量 + if parent_flow_id and parent_flow_id in self._flow_pools: + parent_pool = self._flow_pools[parent_flow_id] + await pool.inherit_from_parent(parent_pool) + self._flow_inheritance[flow_id] = parent_flow_id + + self._flow_pools[flow_id] = pool + logger.info(f"已创建流程变量池: {flow_id}") + return pool + + async def _get_conversation_template_pool(self, flow_id: str) -> Optional[ConversationVariablePool]: + """获取对话模板池(目前简化处理,返回None)""" + # 这里可以实现从数据库加载对话模板的逻辑 + # 目前简化处理,返回None + return None + + async def clear_conversation_variables(self, flow_id: str): + """清空工作流的所有对话变量池""" + to_remove = [] + for conversation_id, pool in self._conversation_pools.items(): + if pool.flow_id == flow_id: + to_remove.append(conversation_id) + + for conversation_id in to_remove: + del self._conversation_pools[conversation_id] + + logger.info(f"已清空工作流 {flow_id} 的 {len(to_remove)} 个对话变量池") + + async def get_pool_stats(self) -> Dict[str, int]: + """获取变量池统计信息""" + return { + "user_pools": len(self._user_pools), + "flow_pools": len(self._flow_pools), + "conversation_pools": len(self._conversation_pools), + } + + async def cleanup_unused_pools(self, active_conversations: Set[str]): + """清理未使用的对话变量池""" + to_remove = [] + for conversation_id in self._conversation_pools: + if conversation_id not in active_conversations: + to_remove.append(conversation_id) + + for conversation_id in to_remove: + del self._conversation_pools[conversation_id] + + if to_remove: + logger.info(f"清理了 {len(to_remove)} 个未使用的对话变量池") + + @asynccontextmanager + async def get_pool_for_scope(self, + scope: VariableScope, + user_id: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None): + """上下文管理器,获取指定作用域的变量池""" + pool = 
None + + try: + if scope == VariableScope.USER and user_id: + pool = await self.get_user_pool(user_id) + elif scope == VariableScope.ENVIRONMENT and flow_id: + pool = await self.get_flow_pool(flow_id) + elif scope in [VariableScope.CONVERSATION, VariableScope.SYSTEM] and conversation_id: + pool = await self.get_conversation_pool(conversation_id) + + if not pool: + raise ValueError(f"无法获取 {scope.value} 级变量池") + + yield pool + + except Exception: + raise + finally: + # 这里可以添加清理逻辑,比如对话池的自动清理等 + pass + + +# 全局变量池管理器实例 +_pool_manager = None + + +async def get_pool_manager() -> VariablePoolManager: + """获取全局变量池管理器实例""" + global _pool_manager + if _pool_manager is None: + _pool_manager = VariablePoolManager() + await _pool_manager.initialize() + return _pool_manager + + +async def initialize_pool_manager(): + """初始化变量池管理器(在应用启动时调用)""" + await get_pool_manager() + logger.info("变量池管理器已启动") \ No newline at end of file diff --git a/apps/scheduler/variable/security.py b/apps/scheduler/variable/security.py new file mode 100644 index 0000000000000000000000000000000000000000..07fde158c01e7bf9375d7f2a555e3bc2d5b1806b --- /dev/null +++ b/apps/scheduler/variable/security.py @@ -0,0 +1,415 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""变量安全管理模块 + +提供密钥变量的额外安全保障,包括: +- 访问审计日志 +- 密钥轮换 +- 安全检查 +- 权限验证 +""" + +import logging +import hashlib +import time +from typing import Any, Dict, List, Optional, Set +from datetime import datetime, UTC, timedelta +from dataclasses import dataclass + +from apps.common.mongo import MongoDB + +logger = logging.getLogger(__name__) + + +@dataclass +class AccessLog: + """访问日志记录""" + variable_name: str + user_sub: str + access_time: datetime + access_type: str # read, write, delete + ip_address: Optional[str] = None + user_agent: Optional[str] = None + success: bool = True + error_message: Optional[str] = None + + +class SecretVariableSecurity: + """密钥变量安全管理器""" + + def __init__(self): + """初始化安全管理器""" + self._access_logs: List[AccessLog] = [] + self._failed_access_attempts: Dict[str, List[datetime]] = {} + self._blocked_users: Set[str] = set() + + # 安全配置 + self.max_failed_attempts = 5 # 最大失败尝试次数 + self.block_duration_minutes = 30 # 封禁时长(分钟) + self.audit_retention_days = 90 # 审计日志保留天数 + + async def verify_access_permission(self, + variable_name: str, + user_sub: str, + access_type: str = "read", + ip_address: Optional[str] = None) -> bool: + """验证访问权限 + + Args: + variable_name: 变量名 + user_sub: 用户ID + access_type: 访问类型(read, write, delete) + ip_address: IP地址 + + Returns: + bool: 是否允许访问 + """ + try: + # 检查用户是否被封禁 + if await self._is_user_blocked(user_sub): + await self._log_access( + variable_name, user_sub, access_type, ip_address, + success=False, error_message="用户被临时封禁" + ) + return False + + # 检查是否超出访问频率限制 + if await self._check_rate_limit(user_sub): + await self._log_access( + variable_name, user_sub, access_type, ip_address, + success=False, error_message="访问频率过高" + ) + await self._record_failed_attempt(user_sub) + return False + + # 验证IP地址(可选) + if ip_address and not await self._is_ip_allowed(ip_address): + await self._log_access( + variable_name, user_sub, access_type, ip_address, + success=False, error_message="IP地址不在允许列表中" + ) + return False + + # 记录成功访问 + 
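`verify_access_permission` delegates the ban decision to `_is_user_blocked` (defined further down), which counts only the failures inside a sliding window and blocks once the threshold is reached. The windowing logic in isolation, using the same defaults as the class (5 attempts, 30 minutes):

```python
from datetime import datetime, timedelta, timezone
from typing import List

MAX_FAILED_ATTEMPTS = 5               # class default max_failed_attempts
BLOCK_WINDOW = timedelta(minutes=30)  # class default block_duration_minutes

def is_blocked(failures: List[datetime], now: datetime) -> bool:
    """True once enough failures fall inside the sliding window."""
    recent = [t for t in failures if t > now - BLOCK_WINDOW]
    return len(recent) >= MAX_FAILED_ATTEMPTS

now = datetime.now(timezone.utc)
assert is_blocked([now - timedelta(minutes=1)] * 5, now)        # 5 fresh failures -> blocked
assert not is_blocked([now - timedelta(minutes=45)] * 10, now)  # all outside the window
```

Old failures therefore age out automatically; no explicit unblock step is needed.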
await self._log_access(variable_name, user_sub, access_type, ip_address) + await self._clear_failed_attempts(user_sub) + + return True + + except Exception as e: + logger.error(f"验证访问权限失败: {e}") + return False + + async def audit_secret_access(self, + variable_name: str, + user_sub: str, + actual_value: str, + access_type: str = "read") -> None: + """审计密钥访问 + + Args: + variable_name: 变量名 + user_sub: 用户ID + actual_value: 实际访问的值 + access_type: 访问类型 + """ + try: + # 计算值的哈希(用于审计,不存储原始值) + value_hash = hashlib.sha256(actual_value.encode()).hexdigest()[:16] + + # 记录详细的审计信息 + audit_record = { + "variable_name": variable_name, + "user_sub": user_sub, + "access_time": datetime.now(UTC), + "access_type": access_type, + "value_hash": value_hash, + "value_length": len(actual_value), + } + + # 保存到审计日志 + await self._save_audit_record(audit_record) + + logger.info(f"已审计密钥访问: {variable_name} by {user_sub}") + + except Exception as e: + logger.error(f"审计密钥访问失败: {e}") + + async def check_secret_strength(self, secret_value: str) -> Dict[str, Any]: + """检查密钥强度 + + Args: + secret_value: 密钥值 + + Returns: + Dict[str, Any]: 检查结果 + """ + result = { + "is_strong": False, + "score": 0, + "warnings": [], + "recommendations": [] + } + + try: + # 长度检查 + if len(secret_value) < 8: + result["warnings"].append("密钥长度过短(建议至少8位)") + elif len(secret_value) >= 12: + result["score"] += 2 + else: + result["score"] += 1 + + # 复杂性检查 + has_upper = any(c.isupper() for c in secret_value) + has_lower = any(c.islower() for c in secret_value) + has_digit = any(c.isdigit() for c in secret_value) + has_special = any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" 
for c in secret_value) + + complexity_score = sum([has_upper, has_lower, has_digit, has_special]) + result["score"] += complexity_score + + if complexity_score < 3: + result["recommendations"].append("建议包含大小写字母、数字和特殊字符") + + # 常见密码检查 + common_passwords = ["password", "123456", "admin", "root", "qwerty"] + if secret_value.lower() in common_passwords: + result["warnings"].append("使用了常见的弱密码") + result["score"] = 0 + + # 重复字符检查 + if len(set(secret_value)) < len(secret_value) * 0.6: + result["warnings"].append("包含过多重复字符") + + # 设置强度等级 + if result["score"] >= 6 and len(result["warnings"]) == 0: + result["is_strong"] = True + + return result + + except Exception as e: + logger.error(f"检查密钥强度失败: {e}") + return result + + async def rotate_secret_key(self, variable_name: str, user_sub: str) -> bool: + """轮换密钥的加密密钥 + + Args: + variable_name: 变量名 + user_sub: 用户ID + + Returns: + bool: 是否轮换成功 + """ + try: + from .pool_manager import get_pool_manager + from .type import VariableScope + + pool_manager = await get_pool_manager() + + # 获取用户变量池 + user_pool = await pool_manager.get_user_pool(user_sub) + if not user_pool: + logger.error(f"用户变量池不存在: {user_sub}") + return False + + # 获取密钥变量 + variable = await user_pool.get_variable(variable_name) + if not variable or not variable.var_type.is_secret_type(): + return False + + # 获取原始值 + original_value = variable.value + + # 重新生成加密密钥并重新加密 + variable._encryption_key = variable._generate_encryption_key() + variable.value = original_value # 这会触发重新加密 + + # 更新存储 + await user_pool._persist_variable(variable) + + # 记录轮换操作 + await self._log_access( + variable_name, user_sub, "key_rotation", + success=True, error_message="密钥轮换成功" + ) + + logger.info(f"已轮换密钥变量的加密密钥: {variable_name}") + return True + + except Exception as e: + logger.error(f"轮换密钥失败: {e}") + return False + + async def get_access_logs(self, + variable_name: Optional[str] = None, + user_sub: Optional[str] = None, + hours: int = 24) -> List[Dict[str, Any]]: + """获取访问日志 + + Args: + 
variable_name: 变量名(可选) + user_sub: 用户ID(可选) + hours: 查询最近几小时的日志 + + Returns: + List[Dict[str, Any]]: 访问日志列表 + """ + try: + collection = MongoDB().get_collection("variable_access_logs") + + # 构建查询条件 + query = { + "access_time": { + "$gte": datetime.now(UTC) - timedelta(hours=hours) + } + } + + if variable_name: + query["variable_name"] = variable_name + if user_sub: + query["user_sub"] = user_sub + + # 查询日志 + logs = [] + async for doc in collection.find(query).sort("access_time", -1): + logs.append({ + "variable_name": doc.get("variable_name"), + "user_sub": doc.get("user_sub"), + "access_time": doc.get("access_time"), + "access_type": doc.get("access_type"), + "ip_address": doc.get("ip_address"), + "success": doc.get("success", True), + "error_message": doc.get("error_message") + }) + + return logs + + except Exception as e: + logger.error(f"获取访问日志失败: {e}") + return [] + + async def _is_user_blocked(self, user_sub: str) -> bool: + """检查用户是否被封禁""" + if user_sub not in self._failed_access_attempts: + return False + + failed_attempts = self._failed_access_attempts[user_sub] + recent_failures = [ + attempt for attempt in failed_attempts + if attempt > datetime.now(UTC) - timedelta(minutes=self.block_duration_minutes) + ] + + return len(recent_failures) >= self.max_failed_attempts + + async def _check_rate_limit(self, user_sub: str) -> bool: + """检查访问频率限制""" + try: + # 获取最近1分钟的访问记录 + collection = MongoDB().get_collection("variable_access_logs") + count = await collection.count_documents({ + "user_sub": user_sub, + "access_time": { + "$gte": datetime.now(UTC) - timedelta(minutes=1) + } + }) + + # 限制每分钟最多30次访问 + return count >= 30 + + except Exception as e: + logger.error(f"检查访问频率失败: {e}") + return False + + async def _is_ip_allowed(self, ip_address: str) -> bool: + """检查IP地址是否被允许""" + # 这里可以实现IP白名单/黑名单逻辑 + # 暂时返回True,表示允许所有IP + return True + + async def _record_failed_attempt(self, user_sub: str): + """记录失败尝试""" + if user_sub not in self._failed_access_attempts: + 
self._failed_access_attempts[user_sub] = [] + + self._failed_access_attempts[user_sub].append(datetime.now(UTC)) + + # 清理过期的失败记录 + cutoff_time = datetime.now(UTC) - timedelta(minutes=self.block_duration_minutes) + self._failed_access_attempts[user_sub] = [ + attempt for attempt in self._failed_access_attempts[user_sub] + if attempt > cutoff_time + ] + + async def _clear_failed_attempts(self, user_sub: str): + """清除失败尝试记录""" + if user_sub in self._failed_access_attempts: + del self._failed_access_attempts[user_sub] + + async def _log_access(self, + variable_name: str, + user_sub: str, + access_type: str, + ip_address: Optional[str] = None, + success: bool = True, + error_message: Optional[str] = None): + """记录访问日志""" + try: + log_entry = { + "variable_name": variable_name, + "user_sub": user_sub, + "access_time": datetime.now(UTC), + "access_type": access_type, + "ip_address": ip_address, + "success": success, + "error_message": error_message + } + + # 保存到数据库 + collection = MongoDB().get_collection("variable_access_logs") + await collection.insert_one(log_entry) + + except Exception as e: + logger.error(f"记录访问日志失败: {e}") + + async def _save_audit_record(self, audit_record: Dict[str, Any]): + """保存审计记录""" + try: + collection = MongoDB().get_collection("variable_audit_logs") + await collection.insert_one(audit_record) + except Exception as e: + logger.error(f"保存审计记录失败: {e}") + + async def cleanup_old_logs(self): + """清理过期的日志记录""" + try: + cutoff_time = datetime.now(UTC) - timedelta(days=self.audit_retention_days) + + # 清理访问日志 + access_collection = MongoDB().get_collection("variable_access_logs") + result1 = await access_collection.delete_many({ + "access_time": {"$lt": cutoff_time} + }) + + # 清理审计日志 + audit_collection = MongoDB().get_collection("variable_audit_logs") + result2 = await audit_collection.delete_many({ + "access_time": {"$lt": cutoff_time} + }) + + logger.info(f"已清理过期日志: 访问日志 {result1.deleted_count} 条, 审计日志 {result2.deleted_count} 条") + + except Exception 
as e: + logger.error(f"清理过期日志失败: {e}") + + +# 全局安全管理器实例 +_security_manager = None + + +def get_security_manager() -> SecretVariableSecurity: + """获取全局安全管理器实例""" + global _security_manager + if _security_manager is None: + _security_manager = SecretVariableSecurity() + return _security_manager \ No newline at end of file diff --git a/apps/scheduler/variable/system_variables_example.py b/apps/scheduler/variable/system_variables_example.py new file mode 100644 index 0000000000000000000000000000000000000000..0e48de30a11cc3434261361782a2d0af4b8b7bc3 --- /dev/null +++ b/apps/scheduler/variable/system_variables_example.py @@ -0,0 +1,220 @@ +""" +系统变量使用示例 + +演示系统变量的正确初始化、更新和访问流程 +""" + +import asyncio +from datetime import datetime, UTC +from typing import Dict, Any + +from .pool_manager import get_pool_manager +from .parser import VariableParser, VariableReferenceBuilder +from .type import VariableScope + + +async def demonstrate_system_variables(): + """演示系统变量的完整工作流程""" + + # 模拟对话参数 + user_id = "user123" + flow_id = "flow456" + conversation_id = "conv789" + + print("=== 系统变量演示 ===\n") + + # 1. 创建变量解析器(会自动创建对话池并初始化系统变量) + print("1. 创建变量解析器并初始化对话变量池...") + parser = VariableParser( + user_id=user_id, + flow_id=flow_id, + conversation_id=conversation_id + ) + + # 确保对话池存在 + success = await parser.create_conversation_pool_if_needed() + print(f" 对话池创建结果: {'成功' if success else '失败'}") + + # 2. 检查系统变量是否已正确初始化 + print("\n2. 检查初始化的系统变量...") + pool_manager = await get_pool_manager() + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + + if conversation_pool: + system_vars = await conversation_pool.list_system_variables() + print(f" 已初始化 {len(system_vars)} 个系统变量:") + for var in system_vars: + print(f" - {var.name}: {var.value} ({var.var_type.value})") + + # 3. 更新系统变量(模拟对话开始) + print("\n3. 
更新系统变量...") + context = { + "question": "请帮我分析这个数据文件", + "files": [{"name": "data.csv", "size": 1024, "type": "text/csv"}], + "dialogue_count": 1, + "app_id": "app001", + "user_sub": user_id, + "session_id": "session123" + } + + await parser.update_system_variables(context) + print(" 系统变量更新完成") + + # 4. 验证系统变量已正确更新 + print("\n4. 验证系统变量更新结果...") + updated_vars = await conversation_pool.list_system_variables() + for var in updated_vars: + if var.name in ["query", "files", "dialogue_count", "app_id", "user_id", "session_id"]: + print(f" - {var.name}: {var.value}") + + # 5. 使用变量引用解析模板 + print("\n5. 解析包含系统变量的模板...") + template = """ +用户查询: {{sys.query}} +对话轮数: {{sys.dialogue_count}} +流程ID: {{sys.flow_id}} +用户ID: {{sys.user_id}} +文件数量: {{sys.files.length}} +""" + + try: + parsed_result = await parser.parse_template(template) + print(" 模板解析结果:") + print(parsed_result) + except Exception as e: + print(f" 模板解析失败: {e}") + + # 6. 验证系统变量的只读性 + print("\n6. 验证系统变量的只读保护...") + try: + # 尝试直接修改系统变量(应该失败) + await conversation_pool.update_variable("query", value="恶意修改") + print(" ❌ 错误:系统变量被意外修改") + except PermissionError: + print(" ✅ 正确:系统变量只读保护生效") + except Exception as e: + print(f" 🤔 意外错误: {e}") + + # 7. 展示系统变量的强制更新(内部使用) + print("\n7. 演示系统变量的内部更新...") + success = await conversation_pool.update_system_variable("dialogue_count", 2) + if success: + updated_var = await conversation_pool.get_variable("dialogue_count") + print(f" ✅ 系统变量内部更新成功: dialogue_count = {updated_var.value}") + else: + print(" ❌ 系统变量内部更新失败") + + # 8. 清理 + print("\n8. 清理对话变量池...") + removed = await pool_manager.remove_conversation_pool(conversation_id) + print(f" 清理结果: {'成功' if removed else '失败'}") + + print("\n=== 演示完成 ===") + + +async def demonstrate_variable_references(): + """演示系统变量引用的构建和使用""" + + print("\n=== 变量引用演示 ===\n") + + # 构建各种变量引用 + print("1. 
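The `{{sys.query}}`-style references resolved in step 5 can be approximated with a small stand-in resolver. This is a hedged sketch only — the real `VariableParser.parse_template` is async and also handles pools and permission checks, none of which appear in this chunk:

```python
import re


def resolve_template(template: str, variables: dict) -> str:
    """解析 {{scope.name[.path]}} 形式的变量引用(极简示意)。"""
    pattern = re.compile(r"\{\{(\w+)\.([\w.]+)\}\}")

    def replace(match):
        scope, path = match.group(1), match.group(2)
        value = variables.get(scope, {})
        for part in path.split("."):
            if isinstance(value, dict):
                value = value.get(part, "")
            elif isinstance(value, list) and part == "length":
                value = len(value)           # files.length 这类长度访问
            elif isinstance(value, list) and part.isdigit():
                value = value[int(part)]     # files.0.name 这类索引访问
        return str(value)

    return pattern.sub(replace, template)


sys_vars = {"query": "请帮我分析这个数据文件", "files": [{"name": "data.csv"}]}
out = resolve_template("问题: {{sys.query}} / 文件数: {{sys.files.length}}", {"sys": sys_vars})
assert out == "问题: 请帮我分析这个数据文件 / 文件数: 1"
```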
变量引用构建示例:") + + # 系统变量引用 + query_ref = VariableReferenceBuilder.system("query") + files_ref = VariableReferenceBuilder.system("files", "0.name") # 嵌套访问 + + # 用户变量引用 + api_key_ref = VariableReferenceBuilder.user("api_key") + + # 环境变量引用 + db_url_ref = VariableReferenceBuilder.environment("database_url") + + # 对话变量引用 + history_ref = VariableReferenceBuilder.conversation("chat_history") + + print(f" 系统变量 - 用户查询: {query_ref}") + print(f" 系统变量 - 首个文件名: {files_ref}") + print(f" 用户变量 - API密钥: {api_key_ref}") + print(f" 环境变量 - 数据库: {db_url_ref}") + print(f" 对话变量 - 聊天历史: {history_ref}") + + # 构建复杂模板 + print("\n2. 复杂模板示例:") + complex_template = f""" +# 对话上下文 +- 用户: {query_ref} +- 轮次: {VariableReferenceBuilder.system("dialogue_count")} +- 时间: {VariableReferenceBuilder.system("timestamp")} + +# 文件信息 +- 文件列表: {VariableReferenceBuilder.system("files")} +- 文件数量: {VariableReferenceBuilder.system("files", "length")} + +# 会话信息 +- 对话ID: {VariableReferenceBuilder.system("conversation_id")} +- 流程ID: {VariableReferenceBuilder.system("flow_id")} +- 用户ID: {VariableReferenceBuilder.system("user_id")} +""" + + print(complex_template) + + print("=== 引用演示完成 ===") + + +async def validate_system_variable_persistence(): + """验证系统变量的持久化""" + + print("\n=== 持久化验证 ===\n") + + conversation_id = "test_persistence_conv" + flow_id = "test_persistence_flow" + + # 创建对话池 + pool_manager = await get_pool_manager() + conversation_pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) + + print("1. 检查新创建池的系统变量...") + system_vars_before = await conversation_pool.list_system_variables() + print(f" 创建后的系统变量数量: {len(system_vars_before)}") + + # 模拟应用重启 - 重新获取池 + print("\n2. 
模拟重新加载...") + await pool_manager.remove_conversation_pool(conversation_id) + + # 重新创建同一个对话池 + conversation_pool_reloaded = await pool_manager.create_conversation_pool(conversation_id, flow_id) + system_vars_after = await conversation_pool_reloaded.list_system_variables() + + print(f" 重新加载后的系统变量数量: {len(system_vars_after)}") + + # 验证变量是否一致 + vars_before_names = {var.name for var in system_vars_before} + vars_after_names = {var.name for var in system_vars_after} + + if vars_before_names == vars_after_names: + print(" ✅ 系统变量持久化验证成功") + else: + print(" ❌ 系统变量持久化验证失败") + print(f" 之前: {vars_before_names}") + print(f" 之后: {vars_after_names}") + + # 清理 + await pool_manager.remove_conversation_pool(conversation_id) + + print("=== 持久化验证完成 ===") + + +if __name__ == "__main__": + async def main(): + """运行所有演示""" + try: + await demonstrate_system_variables() + await demonstrate_variable_references() + await validate_system_variable_persistence() + except Exception as e: + print(f"演示过程中发生错误: {e}") + import traceback + traceback.print_exc() + + asyncio.run(main()) \ No newline at end of file diff --git a/apps/scheduler/variable/type.py b/apps/scheduler/variable/type.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb4627d8b5cf3762c24eb9bc0214623d56c5c8b --- /dev/null +++ b/apps/scheduler/variable/type.py @@ -0,0 +1,72 @@ +from enum import StrEnum + + +class VariableType(StrEnum): + """变量类型枚举""" + # 基础类型 + NUMBER = "number" + STRING = "string" + BOOLEAN = "boolean" + OBJECT = "object" + SECRET = "secret" + GROUP = "group" + FILE = "file" + + # 数组类型 + ARRAY_ANY = "array[any]" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILE = "array[file]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_SECRET = "array[secret]" + + def is_array_type(self) -> bool: + """检查是否为数组类型""" + return self in _ARRAY_TYPES + + def is_secret_type(self) -> bool: + """检查是否为密钥类型""" + return self in _SECRET_TYPES + + def 
get_array_element_type(self) -> "VariableType | None": + """获取数组元素类型""" + if not self.is_array_type(): + return None + + element_type_map = { + VariableType.ARRAY_ANY: None, # any类型无法确定具体元素类型 + VariableType.ARRAY_STRING: VariableType.STRING, + VariableType.ARRAY_NUMBER: VariableType.NUMBER, + VariableType.ARRAY_OBJECT: VariableType.OBJECT, + VariableType.ARRAY_FILE: VariableType.FILE, + VariableType.ARRAY_BOOLEAN: VariableType.BOOLEAN, + VariableType.ARRAY_SECRET: VariableType.SECRET, + } + return element_type_map.get(self) + + +class VariableScope(StrEnum): + """变量作用域枚举""" + SYSTEM = "system" # 系统级变量(只读) + USER = "user" # 用户级变量(跟随用户) + ENVIRONMENT = "env" # 环境级变量(跟随flow) + CONVERSATION = "conversation" # 对话级变量(每次运行重新初始化) + + +# 数组类型集合 +_ARRAY_TYPES = frozenset([ + VariableType.ARRAY_ANY, + VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER, + VariableType.ARRAY_OBJECT, + VariableType.ARRAY_FILE, + VariableType.ARRAY_BOOLEAN, + VariableType.ARRAY_SECRET, +]) + +# 密钥类型集合 +_SECRET_TYPES = frozenset([ + VariableType.SECRET, + VariableType.ARRAY_SECRET, +]) \ No newline at end of file diff --git a/apps/scheduler/variable/variables.py b/apps/scheduler/variable/variables.py new file mode 100644 index 0000000000000000000000000000000000000000..2902dd3adf454ff96aaa98eb2507a533ffccae7c --- /dev/null +++ b/apps/scheduler/variable/variables.py @@ -0,0 +1,479 @@ +import json +import base64 +import hashlib +from typing import Any, Dict, List, Union, Optional +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + +from .base import BaseVariable, VariableMetadata +from .type import VariableType + + +class StringVariable(BaseVariable): + """字符串变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为字符串类型""" + return isinstance(value, str) + + def to_string(self) -> str: + """转换为字符串""" + return str(self._value) if self._value is not None else "" + + def 
to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "StringVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class NumberVariable(BaseVariable): + """数字变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为数字类型""" + return isinstance(value, (int, float)) + + def to_string(self) -> str: + """转换为字符串""" + return str(self._value) if self._value is not None else "0" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "NumberVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class BooleanVariable(BaseVariable): + """布尔变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为布尔类型""" + return isinstance(value, bool) + + def to_string(self) -> str: + """转换为字符串""" + return str(self._value).lower() if self._value is not None else "false" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: 
Dict[str, Any]) -> "BooleanVariable":
+        """反序列化"""
+        metadata = VariableMetadata(**data["metadata"])
+        return cls(metadata, data["value"])
+
+
+class ObjectVariable(BaseVariable):
+    """对象变量"""
+
+    def _validate_type(self, value: Any) -> bool:
+        """验证值是否为对象类型"""
+        return isinstance(value, dict)
+
+    def to_string(self) -> str:
+        """转换为字符串"""
+        if self._value is None:
+            return "{}"
+        try:
+            return json.dumps(self._value, ensure_ascii=False, indent=2)
+        except (TypeError, ValueError):
+            return str(self._value)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """转换为字典"""
+        return {
+            "name": self.name,
+            "type": self.var_type.value,
+            "value": self._value,
+            "scope": self.scope.value
+        }
+
+    def serialize(self) -> Dict[str, Any]:
+        """序列化"""
+        return {
+            "metadata": self.metadata.model_dump(),
+            "value": self._value,
+            "class": self.__class__.__name__
+        }
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "ObjectVariable":
+        """反序列化"""
+        metadata = VariableMetadata(**data["metadata"])
+        return cls(metadata, data["value"])
+
+
+class SecretVariable(BaseVariable):
+    """密钥变量 - 提供安全存储和访问机制"""
+
+    def __init__(self, metadata: VariableMetadata, value: Any = None, encryption_key: Optional[str] = None):
+        """初始化密钥变量
+
+        Args:
+            metadata: 变量元数据
+            value: 变量值
+            encryption_key: 加密密钥(可选,如果不提供会自动生成)
+        """
+        # 先基于传入的 metadata 派生加密密钥:此时 self.metadata 尚未由父类
+        # __init__ 赋值,而父类初始化会触发 value setter,加密时就需要密钥
+        self._encryption_key = encryption_key or self._generate_encryption_key(metadata)
+        super().__init__(metadata, value)
+        self.metadata.is_encrypted = True
+
+        # 如果提供了值,它已在 super().__init__ 调用 value setter 时被加密,
+        # 这里不需要再次加密
+
+    def _generate_encryption_key(self, metadata: Optional[VariableMetadata] = None) -> str:
+        """生成加密密钥
+
+        初始化阶段 self.metadata 尚不可用,可通过参数显式传入元数据;
+        不传参时(如密钥轮换场景)回退到 self.metadata。
+        """
+        meta = metadata if metadata is not None else self.metadata
+        # 使用用户ID和变量名生成唯一的加密密钥
+        user_sub = meta.user_sub or "default"
+        salt = f"{user_sub}:{meta.name}".encode()
+
+        kdf = PBKDF2HMAC(
+            algorithm=hashes.SHA256(),
+            length=32,
+            salt=salt,
+            iterations=100000,
+        )
+        key = base64.urlsafe_b64encode(kdf.derive(b"secret_variable_key"))
+        return key.decode()
+
+    def _encrypt_value(self, 
value: str) -> str: + """加密值""" + if not isinstance(value, str): + value = str(value) + + f = Fernet(self._encryption_key.encode()) + encrypted_value = f.encrypt(value.encode()) + return base64.urlsafe_b64encode(encrypted_value).decode() + + def _decrypt_value(self, encrypted_value: str) -> str: + """解密值""" + try: + encrypted_bytes = base64.urlsafe_b64decode(encrypted_value.encode()) + f = Fernet(self._encryption_key.encode()) + decrypted_value = f.decrypt(encrypted_bytes) + return decrypted_value.decode() + except Exception: + return "[解密失败]" + + def _validate_type(self, value: Any) -> bool: + """验证值类型""" + return isinstance(value, str) + + @property + def value(self) -> str: + """获取解密后的值""" + if self._value is None: + return "" + return self._decrypt_value(self._value) + + @value.setter + def value(self, new_value: Any) -> None: + """设置新值(会自动加密)""" + if self.scope.value == "system": + raise ValueError("系统级变量不能修改") + + if not self._validate_type(new_value): + raise TypeError(f"变量 {self.name} 的值类型不匹配,期望: {self.var_type}") + + self._value = self._encrypt_value(new_value) + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def get_masked_value(self) -> str: + """获取掩码值用于显示""" + actual_value = self.value + if len(actual_value) <= 4: + return "*" * len(actual_value) + return actual_value[:2] + "*" * (len(actual_value) - 4) + actual_value[-2:] + + def to_string(self) -> str: + """转换为字符串(掩码形式)""" + return self.get_masked_value() + + def to_dict(self) -> Dict[str, Any]: + """转换为字典(掩码形式)""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self.get_masked_value(), + "scope": self.scope.value + } + + def to_dict_with_actual_value(self, user_sub: str) -> Dict[str, Any]: + """转换为包含实际值的字典(需要权限检查)""" + if not self.can_access(user_sub): + raise PermissionError(f"用户 {user_sub} 没有权限访问密钥变量 {self.name}") + + return { + "name": self.name, + "type": self.var_type.value, + "value": self.value, # 实际解密值 + "scope": 
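The masking rule in `get_masked_value` keeps only the first and last two characters of the secret; as a standalone function:

```python
def mask(value: str) -> str:
    """与 SecretVariable.get_masked_value 相同的掩码规则(示意)。"""
    if len(value) <= 4:
        return "*" * len(value)
    return value[:2] + "*" * (len(value) - 4) + value[-2:]


assert mask("key") == "***"
assert mask("sk-12345") == "sk****45"
```

One caveat worth knowing: for a 5-character secret this reveals four of the five characters (`mask("abcde") == "ab*de"`), so callers handling short secrets may want a stricter rule.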
self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化(保持加密状态)""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, # 加密后的值 + "encryption_key": self._encryption_key, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "SecretVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + instance = cls(metadata, None, data.get("encryption_key")) + instance._value = data["value"] # 直接设置加密值 + return instance + + +class FileVariable(BaseVariable): + """文件变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为文件路径或文件对象""" + return isinstance(value, (str, dict)) and ( + isinstance(value, str) or + (isinstance(value, dict) and "filename" in value and "content" in value) + ) + + def to_string(self) -> str: + """转换为字符串""" + if isinstance(self._value, str): + return self._value + elif isinstance(self._value, dict): + return self._value.get("filename", "unnamed_file") + return "" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "FileVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class ArrayVariable(BaseVariable): + """数组变量""" + + def __init__(self, metadata: VariableMetadata, value: Any = None): + """初始化数组变量""" + # 先设置元素类型,因为父类初始化时会调用_validate_type + self._element_type = metadata.var_type.get_array_element_type() + super().__init__(metadata, value or []) + + def _validate_type(self, value: Any) -> bool: + """验证值是否为数组类型,并检查元素类型""" + if not isinstance(value, list): + return False + + # 如果是 array[any],不需要检查元素类型 + if self._element_type 
is None: + return True + + # 检查所有元素类型 + for item in value: + if not self._validate_element_type(item): + return False + + return True + + def _validate_element_type(self, element: Any) -> bool: + """验证单个元素的类型""" + if self._element_type is None: # array[any] + return True + + type_validators = { + VariableType.STRING: lambda x: isinstance(x, str), + VariableType.NUMBER: lambda x: isinstance(x, (int, float)), + VariableType.BOOLEAN: lambda x: isinstance(x, bool), + VariableType.OBJECT: lambda x: isinstance(x, dict), + VariableType.SECRET: lambda x: isinstance(x, str), + VariableType.FILE: lambda x: isinstance(x, (str, dict)), + } + + validator = type_validators.get(self._element_type) + return validator(element) if validator else False + + def append(self, item: Any) -> None: + """添加元素到数组""" + if not self._validate_element_type(item): + raise TypeError(f"元素类型不匹配,期望: {self._element_type}") + + if self._value is None: + self._value = [] + self._value.append(item) + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def remove(self, item: Any) -> None: + """从数组中移除元素""" + if self._value and item in self._value: + self._value.remove(item) + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def __len__(self) -> int: + """获取数组长度""" + return len(self._value) if self._value else 0 + + def __getitem__(self, index: int) -> Any: + """获取指定索引的元素""" + if self._value is None: + raise IndexError("数组为空") + return self._value[index] + + def __setitem__(self, index: int, value: Any) -> None: + """设置指定索引的元素""" + if not self._validate_element_type(value): + raise TypeError(f"元素类型不匹配,期望: {self._element_type}") + + if self._value is None: + self._value = [] + + # 扩展数组到指定索引 + while len(self._value) <= index: + self._value.append(None) + + self._value[index] = value + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def to_string(self) -> str: + """转换为字符串""" + if self._value is None: + 
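The per-element checks in `_validate_element_type` amount to a type-dispatch table. A sketch — note that the original `NUMBER` check, `isinstance(x, (int, float))`, also accepts `bool` because `bool` subclasses `int`; the sketch below excludes it, which may or may not be the intended behaviour:

```python
from typing import Any, Callable, Optional

# 与 ArrayVariable._validate_element_type 类似的元素校验表(示意)
VALIDATORS: "dict[str, Callable[[Any], bool]]" = {
    "string": lambda x: isinstance(x, str),
    "number": lambda x: isinstance(x, (int, float)) and not isinstance(x, bool),
    "boolean": lambda x: isinstance(x, bool),
    "object": lambda x: isinstance(x, dict),
}


def validate_array(items: list, element_type: Optional[str]) -> bool:
    """element_type 为 None 时对应 array[any],不检查元素类型。"""
    if element_type is None:
        return True
    check = VALIDATORS.get(element_type)
    return check is not None and all(check(i) for i in items)


assert validate_array([1, 2.5], "number")
assert not validate_array([True], "number")   # bool 不算 number(本示意的取舍)
assert validate_array([True, "x"], None)      # array[any]
```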
return "[]" + try: + return json.dumps(self._value, ensure_ascii=False, indent=2) + except (TypeError, ValueError): + return str(self._value) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value, + "element_type": self._element_type.value if self._element_type else None + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "ArrayVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +# 变量类型映射 +VARIABLE_CLASS_MAP = { + VariableType.STRING: StringVariable, + VariableType.NUMBER: NumberVariable, + VariableType.BOOLEAN: BooleanVariable, + VariableType.OBJECT: ObjectVariable, + VariableType.SECRET: SecretVariable, + VariableType.FILE: FileVariable, + VariableType.ARRAY_ANY: ArrayVariable, + VariableType.ARRAY_STRING: ArrayVariable, + VariableType.ARRAY_NUMBER: ArrayVariable, + VariableType.ARRAY_OBJECT: ArrayVariable, + VariableType.ARRAY_FILE: ArrayVariable, + VariableType.ARRAY_BOOLEAN: ArrayVariable, + VariableType.ARRAY_SECRET: ArrayVariable, +} + + +def create_variable(metadata: VariableMetadata, value: Any = None) -> BaseVariable: + """根据类型创建变量实例 + + Args: + metadata: 变量元数据 + value: 变量值 + + Returns: + BaseVariable: 创建的变量实例 + """ + variable_class = VARIABLE_CLASS_MAP.get(metadata.var_type) + if not variable_class: + raise ValueError(f"不支持的变量类型: {metadata.var_type}") + + return variable_class(metadata, value) \ No newline at end of file diff --git a/apps/schemas/config.py b/apps/schemas/config.py index b88a81f1afb018663713a3bf4c0b9b62436e59d4..1c2648d28f1953c38411f04f311138b9ecdc2adb 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -6,6 +6,13 @@ from typing import Literal from pydantic 
import BaseModel, Field +class NoauthConfig(BaseModel): + """无认证配置""" + + enable: bool = Field(description="是否启用无认证访问", default=False) + user_sub: str = Field(description="调试用户的sub", default="admin") + + class DeployConfig(BaseModel): """部署配置""" @@ -77,6 +84,20 @@ class MongoDBConfig(BaseModel): database: str = Field(description="MongoDB数据库名") +class RedisConfig(BaseModel): + """Redis配置""" + + host: str = Field(description="Redis主机名", default="redis-db") + port: int = Field(description="Redis端口号", default=6379) + password: str | None = Field(description="Redis密码", default=None) + database: int = Field(description="Redis数据库编号", default=0) + decode_responses: bool = Field(description="是否解码响应", default=True) + socket_timeout: float = Field(description="套接字超时时间(秒)", default=5.0) + socket_connect_timeout: float = Field(description="连接超时时间(秒)", default=5.0) + max_connections: int = Field(description="最大连接数", default=10) + health_check_interval: int = Field(description="健康检查间隔(秒)", default=30) + + class LLMConfig(BaseModel): """LLM配置""" @@ -85,6 +106,7 @@ class LLMConfig(BaseModel): model: str = Field(description="LLM API 模型名") max_tokens: int | None = Field(description="LLM API 最大Token数", default=None) temperature: float | None = Field(description="LLM API 温度", default=None) + frequency_penalty: float | None = Field(description="频率惩罚是大模型生成文本时用于减少重复词汇出现的参数", default=0) class FunctionCallConfig(BaseModel): @@ -114,6 +136,12 @@ class CheckConfig(BaseModel): words_list: str = Field(description="敏感词列表文件路径") +class SandboxConfig(BaseModel): + """代码沙箱配置""" + + sandbox_service: str = Field(description="代码沙箱服务地址") + + class ExtraConfig(BaseModel): """额外配置""" @@ -122,7 +150,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - + no_auth: NoauthConfig = Field(description="无认证配置", default=NoauthConfig()) deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig @@ -130,8 +158,10 @@ class ConfigModel(BaseModel): fastapi: FastAPIConfig minio: 
MinioConfig mongodb: MongoDBConfig + redis: RedisConfig llm: LLMConfig function_call: FunctionCallConfig security: SecurityConfig check: CheckConfig + sandbox: SandboxConfig extra: ExtraConfig diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a20ba84d4805bcd502d9d4ca0f9ba3c49a7bb4c..40225a663143e07ec5a13d0e64abff5021557665 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,6 +15,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" + WAITING = "waiting" RUNNING = "running" SUCCESS = "success" ERROR = "error" @@ -38,6 +39,8 @@ class EventType(str, Enum): TEXT_ADD = "text.add" GRAPH = "graph" DOCUMENT_ADD = "document.add" + STEP_WAITING_FOR_START = "step.waiting_for_start" + STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" @@ -48,8 +51,10 @@ class EventType(str, Enum): class CallType(str, Enum): """Call类型""" - SYSTEM = "system" - PYTHON = "python" + DEFAULT = "default" + LOGIC = "logic" + TRANSFORM = "transform" + TOOL = "tool" class MetadataType(str, Enum): @@ -118,7 +123,6 @@ class HTTPMethod(str, Enum): DELETE = "delete" PATCH = "patch" - class ContentType(str, Enum): """Content-Type""" @@ -142,6 +146,7 @@ class SpecialCallType(str, Enum): FACTS = "Facts" SLOT = "Slot" LLM = "LLM" + DIRECT_REPLY = "DirectReply" START = "start" END = "end" CHOICE = "choice" diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py index 2646d04390099fd92988d78c00d1b1471780d6a9..b70e481477163d00bd3e9d7b371670d27b674233 100644 --- a/apps/schemas/flow.py +++ b/apps/schemas/flow.py @@ -156,3 +156,4 @@ class FlowConfig(BaseModel): flow_id: str flow_config: Flow + diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index d0ab666a9902ec7e0d81bb922f973d655f29eaf2..aa5da0d585d6d0e49d5b1d614a6fc8e1d7024d06 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -5,7 +5,7 @@ from typing import Any from 
pydantic import BaseModel, Field -from apps.schemas.enum_var import EdgeType +from apps.schemas.enum_var import CallType, EdgeType class NodeMetaDataItem(BaseModel): @@ -14,6 +14,7 @@ class NodeMetaDataItem(BaseModel): node_id: str = Field(alias="nodeId") call_id: str = Field(alias="callId") name: str + type: CallType description: str parameters: dict[str, Any] | None editable: bool = Field(default=True) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 44021b0ed5f5f6c372ac0a472bb3ac28ff5bbddb..60c8f17b4adc4f53f21b0cacc02a86e09b495d06 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" +import uuid from enum import Enum from typing import Any @@ -117,7 +118,7 @@ class MCPToolSelectResult(BaseModel): class MCPPlanItem(BaseModel): """MCP 计划""" - + id: str = Field(default_factory=lambda: str(uuid.uuid4())) content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") diff --git a/apps/schemas/message.py b/apps/schemas/message.py index d0661224e0817ff6f25e341d65c88903020d18e8..cf70a82b8573d2e3e34564b8f0ec5280195980d9 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -2,7 +2,7 @@ """队列中的消息结构""" from typing import Any - +from datetime import UTC, datetime from pydantic import BaseModel, Field from apps.schemas.enum_var import EventType, StepStatus @@ -24,6 +24,8 @@ class MessageFlow(BaseModel): flow_id: str = Field(description="Flow ID", alias="flowId") step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") + sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) + sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None) step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") @@ -60,10 +62,14 @@ class 
DocumentAddContent(BaseModel): document_id: str = Field(description="文档UUID", alias="documentId") document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder") + document_author: str = Field(description="文档作者", alias="documentAuthor", default="") document_name: str = Field(description="文档名称", alias="documentName") document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="") document_type: str = Field(description="文档MIME类型", alias="documentType", default="") document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0) + created_at: float = Field( + description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3) + ) class FlowStartContent(BaseModel): diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..bd908d2375415c798f4804c4eabde797f6a4f7a0 --- /dev/null +++ b/apps/schemas/parameters.py @@ -0,0 +1,69 @@ +from enum import Enum + + +class NumberOperate(str, Enum): + """Choice 工具支持的数字运算符""" + + EQUAL = "number_equal" + NOT_EQUAL = "number_not_equal" + GREATER_THAN = "number_greater_than" + LESS_THAN = "number_less_than" + GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal" + LESS_THAN_OR_EQUAL = "number_less_than_or_equal" + + +class StringOperate(str, Enum): + """Choice 工具支持的字符串运算符""" + + EQUAL = "string_equal" + NOT_EQUAL = "string_not_equal" + CONTAINS = "string_contains" + NOT_CONTAINS = "string_not_contains" + STARTS_WITH = "string_starts_with" + ENDS_WITH = "string_ends_with" + LENGTH_EQUAL = "string_length_equal" + LENGTH_GREATER_THAN = "string_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "string_length_greater_than_or_equal" + LENGTH_LESS_THAN = "string_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "string_length_less_than_or_equal" + REGEX_MATCH = "string_regex_match" + + +class ListOperate(str, Enum): + """Choice 工具支持的列表运算符""" 
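These operator enums suggest a dispatch-table evaluator on the Choice side. A hedged sketch for the numeric family — the repo's actual Choice evaluation code is not part of this diff, so `evaluate_number` and the table name are illustrative:

```python
import operator
from enum import Enum


class NumberOperate(str, Enum):
    """与上方 parameters.py 中的数字运算符取值一致"""
    EQUAL = "number_equal"
    NOT_EQUAL = "number_not_equal"
    GREATER_THAN = "number_greater_than"
    LESS_THAN = "number_less_than"
    GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal"
    LESS_THAN_OR_EQUAL = "number_less_than_or_equal"


# 枚举到谓词的映射表,避免 if/elif 链
_NUMBER_OPS = {
    NumberOperate.EQUAL: operator.eq,
    NumberOperate.NOT_EQUAL: operator.ne,
    NumberOperate.GREATER_THAN: operator.gt,
    NumberOperate.LESS_THAN: operator.lt,
    NumberOperate.GREATER_THAN_OR_EQUAL: operator.ge,
    NumberOperate.LESS_THAN_OR_EQUAL: operator.le,
}


def evaluate_number(op: NumberOperate, left: float, right: float) -> bool:
    return _NUMBER_OPS[op](left, right)


assert evaluate_number(NumberOperate.GREATER_THAN, 3, 2)
assert not evaluate_number(NumberOperate.EQUAL, 1, 2)
```

Because the enums inherit from `str`, the raw wire value round-trips directly: `NumberOperate("number_less_than_or_equal")` reconstructs the member from a request payload.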
+ + EQUAL = "list_equal" + NOT_EQUAL = "list_not_equal" + CONTAINS = "list_contains" + NOT_CONTAINS = "list_not_contains" + LENGTH_EQUAL = "list_length_equal" + LENGTH_GREATER_THAN = "list_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "list_length_greater_than_or_equal" + LENGTH_LESS_THAN = "list_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "list_length_less_than_or_equal" + + +class BoolOperate(str, Enum): + """Choice 工具支持的布尔运算符""" + + EQUAL = "bool_equal" + NOT_EQUAL = "bool_not_equal" + + +class DictOperate(str, Enum): + """Choice 工具支持的字典运算符""" + + EQUAL = "dict_equal" + NOT_EQUAL = "dict_not_equal" + CONTAINS_KEY = "dict_contains_key" + NOT_CONTAINS_KEY = "dict_not_contains_key" + + +class Type(str, Enum): + """Choice 工具支持的类型""" + + STRING = "string" + NUMBER = "number" + LIST = "list" + DICT = "dict" + BOOL = "bool" diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py index 27e16b370ec83acc11e1f435ac1da296fe2a9560..009c8206d967bd7cdb61ba1a41ebf6cf9e87dd68 100644 --- a/apps/schemas/pool.py +++ b/apps/schemas/pool.py @@ -45,9 +45,6 @@ class CallPool(BaseData): Call信息 collection: call - - “path”的格式如下: - 1. 
Python代码会被导入成包,路径格式为`python::::`,用于查找Call的包路径和类路径 """ type: CallType = Field(description="Call的类型") @@ -77,6 +74,7 @@ class NodePool(BaseData): service_id: str | None = Field(description="Node所属的Service ID", default=None) call_id: str = Field(description="所使用的Call的ID") + type: CallType = Field(description="所使用的Call的类型") known_params: dict[str, Any] | None = Field( description="已知的用于Call部分的参数,独立于输入和输出之外", default=None, diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d7acd368205d0aa5a2b368edf4af3a31ed4d1ff6..b5e1b0c55ad60d0569b57377a6a06a84e7a2920e 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -17,8 +17,10 @@ class RecordDocument(Document): """GET /api/record/{conversation_id} Result中的document数据结构""" id: str = Field(alias="_id", default="") + order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") user_sub: None = None + author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] class Config: @@ -103,11 +105,14 @@ class RecordGroupDocument(BaseModel): """RecordGroup关联的文件""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + order: int = Field(default=0, description="文档顺序") + author: str = Field(default="", description="文档作者") name: str = Field(default="", description="文档名称") abstract: str = Field(default="", description="文档摘要") extension: str = Field(default="", description="文档扩展名") size: int = Field(default=0, description="文档大小,单位是KB") associated: Literal["question", "answer"] + created_at: float = Field(default=0.0, description="文档创建时间") class Record(RecordData): diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 2305dd93dfe33969691dcc42928e578e7f1c4207..a3a8848c32e898364df671bbf3b26cb40c3f6a0a 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -16,8 +16,8 @@ class RequestDataApp(BaseModel): """模型对话中包含的app信息""" app_id: str = Field(description="应用ID", alias="appId") 
- flow_id: str = Field(description="Flow ID", alias="flowId") - params: dict[str, Any] = Field(description="插件参数") + flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") + params: dict[str, Any] | None = Field(default=None, description="插件参数") class MockRequestData(BaseModel): @@ -46,6 +46,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + new_task: bool = Field(default=True, description="是否新建任务") class QuestionBlacklistRequest(BaseModel): diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index b2a1872918638b153e1e94ff4db85ef00cfc0502..7a8cc783015a8f69259194cec4d2bea359f7a7fd 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -14,6 +14,14 @@ from apps.schemas.flow_topology import ( NodeServiceItem, PositionItem, ) +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType from apps.schemas.record import RecordData from apps.schemas.user import UserInfo @@ -608,7 +616,8 @@ class LLMProviderInfo(BaseModel): """LLM数据结构""" llm_id: str = Field(alias="llmId", description="LLM ID") - icon: str = Field(default="", description="LLM图标", max_length=25536) + # icon: str = Field(default="", description="LLM图标", max_length=25536) + icon: str = Field(default="", description="LLM图标") openai_base_url: str = Field( default="https://api.openai.com/v1", description="OpenAI API Base URL", @@ -628,3 +637,42 @@ class ListLLMRsp(ResponseData): """GET /api/llm 返回数据结构""" result: list[LLMProviderInfo] = Field(default=[], title="Result") + + +class ParamsNode(BaseModel): + """参数数据结构""" + param_name: str = Field(..., description="参数名称", alias="paramName") + param_path: str = Field(..., description="参数路径", 
alias="paramPath") + param_type: Type = Field(..., description="参数类型", alias="paramType") + sub_params: list["ParamsNode"] | None = Field( + default=None, description="子参数列表", alias="subParams" + ) + + +class StepParams(BaseModel): + """参数数据结构""" + step_id: str = Field(..., description="步骤ID", alias="stepId") + name: str = Field(..., description="Step名称") + params_node: ParamsNode | None = Field( + default=None, description="参数节点", alias="paramsNode") + + +class GetParamsRsp(ResponseData): + """GET /api/params 返回数据结构""" + + result: list[StepParams] = Field( + default=[], description="参数列表", alias="result" + ) + + +class OperateAndBindType(BaseModel): + """操作和绑定类型数据结构""" + + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型") + bind_type: Type = Field(description="绑定类型") + + +class GetOperaRsp(ResponseData): + """GET /api/operate 返回数据结构""" + + result: list[OperateAndBindType] = Field(..., title="Result") diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py index 38fd94ad4db6cff189e79f4abaec3448c97d45ec..dd2dbb6047241c99e4b09296b0cf67023edf6628 100644 --- a/apps/schemas/scheduler.py +++ b/apps/schemas/scheduler.py @@ -1,18 +1,19 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""插件、工作流、步骤相关数据结构定义""" +from enum import StrEnum from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.task import FlowStepHistory class CallInfo(BaseModel): """Call的名称和描述""" - name: str = Field(description="Call的名称") + type: CallType = Field(description="Call的类别") description: str = Field(description="Call的描述") @@ -22,6 +23,7 @@ class CallIds(BaseModel): task_id: str = Field(description="任务ID") flow_id: str = Field(description="Flow ID") session_id: str = Field(description="当前用户的Session ID") + conversation_id: str = Field(description="当前对话ID") app_id: str = Field(description="当前应用的ID") user_sub: str = Field(description="当前用户的用户ID") diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 8efcb59914d568478c65d611670b251d019b38ed..3a37126053e693226da63a07283832c9cf3c0fae 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from apps.schemas.enum_var import StepStatus from apps.schemas.flow import Step +from apps.schemas.mcp import MCPPlan class FlowStepHistory(BaseModel): @@ -42,6 +43,7 @@ class ExecutorState(BaseModel): # 附加信息 step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") + step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) @@ -75,6 +77,7 @@ class TaskRuntime(BaseModel): summary: str = Field(description="摘要", default="") filled: dict[str, Any] = Field(description="填充的槽位", default={}) documents: list[dict[str, Any]] = Field(description="文档列表", default=[]) + temporary_plans: MCPPlan | None = Field(description="临时计划列表", default=None) class Task(BaseModel): @@ -86,7 +89,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: 
str(uuid.uuid4()), alias="_id")
     ids: TaskIds = Field(description="任务涉及的各种ID")
-    context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[])
+    context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[])
     state: ExecutorState | None = Field(description="Flow的状态", default=None)
     tokens: TaskTokens = Field(description="Token信息")
     runtime: TaskRuntime = Field(description="任务运行时数据")
diff --git a/apps/services/conversation.py b/apps/services/conversation.py
index 4bcade45c757ea2b3dca6c58863b2852af98c15b..bac964db132fecc9599bb80943e6fa119048f387 100644
--- a/apps/services/conversation.py
+++ b/apps/services/conversation.py
@@ -59,7 +59,11 @@ class ConversationManager:
             model_name=llm.model_name,
         )
         kb_item_list = []
-        team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+        try:
+            team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+        except Exception:
+            logger.error("[ConversationManager] 获取团队知识库列表失败")
+            team_kb_list = []
         for team_kb in team_kb_list:
             for kb in team_kb["kbList"]:
                 if str(kb["kbId"]) in kb_ids:
diff --git a/apps/services/document.py b/apps/services/document.py
index 203162da1137cdd69d900ceea65a3029d2f4fd8e..451423a9d8f752c2abbf97f5b2a1df38e6d7fefe 100644
--- a/apps/services/document.py
+++ b/apps/services/document.py
@@ -2,6 +2,7 @@
 """文件Manager"""

 import base64
+from datetime import UTC, datetime
 import logging
 import uuid

@@ -131,12 +132,15 @@ class DocumentManager:
         return [
             RecordDocument(
                 _id=doc.id,
+                order=doc.order,
+                author=doc.author,
                 abstract=doc.abstract,
                 name=doc.name,
                 type=doc.extension,
                 size=doc.size,
                 conversation_id=record_group.get("conversation_id", ""),
                 associated=doc.associated,
+                created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3)
             )
             for doc in docs if type is None or doc.associated == type
         ]
diff --git a/apps/services/flow.py b/apps/services/flow.py index
9275fc60c75199e8e17ebcba8f164083190b7211..f84f5b832c02fd4bd6a704b4a25dec80db0dfada 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -20,7 +20,6 @@ from apps.schemas.flow_topology import ( PositionItem, ) from apps.services.node import NodeManager - logger = logging.getLogger(__name__) @@ -96,6 +95,7 @@ class FlowManager: nodeId=node_pool_record["_id"], callId=node_pool_record["call_id"], name=node_pool_record["name"], + type=node_pool_record["type"], description=node_pool_record["description"], editable=True, createdAt=node_pool_record["created_at"], @@ -154,7 +154,7 @@ class FlowManager: NodeServiceItem( serviceId=record["_id"], name=record["name"], - type="default", + type="default", # TODO record["type"]? nodeMetaDatas=[], createdAt=str(record["created_at"]), ) @@ -257,15 +257,21 @@ class FlowManager: debug=flow_config.debug, ) for node_id, node_config in flow_config.steps.items(): - input_parameters = node_config.params - if node_config.node not in ("Empty"): - _, output_parameters = await NodeManager.get_node_params(node_config.node) + # 对于Code节点,直接使用保存的完整params作为parameters + if node_config.type == "Code": + parameters = node_config.params # 直接使用保存的完整params else: - output_parameters = {} - parameters = { - "input_parameters": input_parameters, - "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(), - } + # 其他节点:使用原有逻辑 + input_parameters = node_config.params + if node_config.node not in ("Empty"): + _, output_parameters = await NodeManager.get_node_params(node_config.node) + else: + output_parameters = {} + parameters = { + "input_parameters": input_parameters, + "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(), + } + node_item = NodeItem( stepId=node_id, nodeId=node_config.node, @@ -275,8 +281,7 @@ class FlowManager: editable=True, callId=node_config.type, parameters=parameters, - position=PositionItem( - x=node_config.pos.x, y=node_config.pos.y), + position=PositionItem(x=node_config.pos.x, 
y=node_config.pos.y), ) flow_item.nodes.append(node_item) @@ -384,13 +389,19 @@ class FlowManager: debug=flow_item.debug, ) for node_item in flow_item.nodes: + # 对于Code节点,保存完整的parameters;其他节点只保存input_parameters + if node_item.call_id == "Code": + params = node_item.parameters # 保存完整的parameters(包含input_parameters、output_parameters以及code配置) + else: + params = node_item.parameters.get("input_parameters", {}) # 其他节点只保存input_parameters + flow_config.steps[node_item.step_id] = Step( type=node_item.call_id, node=node_item.node_id, name=node_item.name, description=node_item.description, pos=node_item.position, - params=node_item.parameters.get("input_parameters", {}), + params=params, ) for edge_item in flow_item.edges: edge_from = edge_item.source_node diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index 9b4077f92cf44bca42d321cc20975d7ea2e5cadd..bd8dfc9e808f642a8eecf493a96f762dad8ae7b9 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -138,7 +138,11 @@ class KnowledgeBaseManager: return [] kb_ids_update_success = [] kb_item_dict_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except Exception as e: + logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: diff --git a/apps/services/node.py b/apps/services/node.py index 6f0d492edacdd7611324a38953417e0fa769249b..bf48e71feec4df41dddc3fecfa1228913940571c 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -16,6 +16,7 @@ NODE_TYPE_MAP = { "API": APINode, } + class NodeManager: """Node管理器""" @@ -29,7 +30,6 @@ class NodeManager: raise ValueError(err) return node["call_id"] - @staticmethod async def get_node(node_id: str) -> NodePool: """获取Node的类型""" @@ -40,7 +40,6 @@ class NodeManager: raise ValueError(err) 
return NodePool.model_validate(node) - @staticmethod async def get_node_name(node_id: str) -> str: """获取node的名称""" @@ -52,7 +51,6 @@ class NodeManager: return "" return node_doc["name"] - @staticmethod def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]: """递归合并参数Schema,将known_params中的值填充到params_schema的对应位置""" @@ -75,7 +73,6 @@ class NodeManager: return params_schema - @staticmethod async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" @@ -100,7 +97,6 @@ class NodeManager: err = f"[NodeManager] Call {call_id} 不存在" logger.error(err) raise ValueError(err) - # 返回参数Schema return ( NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), diff --git a/apps/services/parameter.py b/apps/services/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..ae375e97fcf92bebfc8dca2b384ffcead4da58bf --- /dev/null +++ b/apps/services/parameter.py @@ -0,0 +1,86 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""Parameter Manager"""
+
+import logging
+
+from apps.services.node import NodeManager
+from apps.schemas.flow_topology import FlowItem
+from apps.scheduler.slot.slot import Slot
+from apps.scheduler.call.choice.condition_handler import ConditionHandler
+from apps.scheduler.call.choice.schema import (
+    NumberOperate,
+    StringOperate,
+    ListOperate,
+    BoolOperate,
+    DictOperate,
+    Type
+)
+from apps.schemas.response_data import (
+    OperateAndBindType,
+    ParamsNode,
+    StepParams,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ParameterManager:
+    """Parameter Manager"""
+
+    @staticmethod
+    async def get_operate_and_bind_type(param_type: Type) -> list[OperateAndBindType]:
+        """Get operate and bind type"""
+        result = []
+        operate = None
+        if param_type == Type.NUMBER:
+            operate = NumberOperate
+        elif param_type == Type.STRING:
+            operate = StringOperate
+        elif param_type == Type.LIST:
+            operate = ListOperate
+        elif param_type == Type.BOOL:
+            operate = BoolOperate
+        elif param_type == Type.DICT:
+            operate = DictOperate
+        if operate:
+            for item in operate:
+                result.append(OperateAndBindType(
+                    operate=item,
+                    bind_type=ConditionHandler.get_value_type_from_operate(item)))
+        return result
+
+    @staticmethod
+    async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]:
+        """Get pre params by flow and step id"""
+        index = 0
+        q = [step_id]
+        in_edges = {}
+        step_id_to_node_id = {}
+        for step in flow.nodes:
+            step_id_to_node_id[step.step_id] = step.node_id
+        for edge in flow.edges:
+            if edge.target_node not in in_edges:
+                in_edges[edge.target_node] = []
+            in_edges[edge.target_node].append(edge.source_node)
+        while index < len(q):
+            tmp_step_id = q[index]
+            index += 1
+            for i in range(len(in_edges.get(tmp_step_id, []))):
+                pre_node_id = in_edges[tmp_step_id][i]
+                if pre_node_id not in q:
+                    q.append(pre_node_id)
+        pre_step_params = []
+        for pre_step_id in q:
+            node_id = step_id_to_node_id.get(pre_step_id)
+            if node_id is None:
+                continue
+            params_schema, output_schema = await NodeManager.get_node_params(node_id)
+            slot = Slot(output_schema)
+            params_node = slot.get_params_node_from_schema(root='/output')
+            pre_step_params.append(
+                StepParams(
+                    stepId=pre_step_id,
+                    name=params_schema.get("name", ""),
+                    paramsNode=params_node
+                )
+            )
+        return pre_step_params
diff --git a/apps/services/predecessor_cache_service.py b/apps/services/predecessor_cache_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..967077ea96b06982c543cf9ad6306b300cea48d5
--- /dev/null
+++ b/apps/services/predecessor_cache_service.py
@@ -0,0 +1,488 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""前置节点变量预解析缓存服务"""
+
+import asyncio
+import hashlib
+import json
+import logging
+from typing import List, Dict, Any, Optional
+from datetime import datetime, UTC
+
+from apps.common.redis_cache import RedisCache
+from apps.common.process_handler import ProcessHandler
+from apps.services.flow import FlowManager
+from apps.scheduler.variable.variables import create_variable
+from apps.scheduler.variable.base import VariableMetadata
+from apps.scheduler.variable.type import VariableType, VariableScope
+
+logger = logging.getLogger(__name__)
+
+# 全局Redis缓存实例
+redis_cache = RedisCache()
+predecessor_cache = None
+
+# 添加任务管理
+_background_tasks: Dict[str, asyncio.Task] = {}
+_task_lock = asyncio.Lock()
+
+def _get_predecessor_cache():
+    """获取predecessor_cache实例,确保初始化"""
+    global predecessor_cache
+    if predecessor_cache is None:
+        from apps.common.redis_cache import PredecessorVariableCache
+        predecessor_cache = PredecessorVariableCache(redis_cache)
+    return predecessor_cache
+
+async def init_predecessor_cache():
+    """初始化前置节点变量缓存"""
+    global predecessor_cache
+    if predecessor_cache is None:
+        from apps.common.redis_cache import PredecessorVariableCache
+        predecessor_cache = PredecessorVariableCache(redis_cache)
+
+async def cleanup_background_tasks():
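`get_pre_params_by_flow_and_step_id` above walks the edge list breadth-first over reversed edges to collect every step upstream of the queried step. The same traversal as a standalone sketch, assuming edges as `(source, target)` tuples; note the diff's version keeps the starting step in its queue and so also returns that step's own parameters, while this sketch returns only strict predecessors:

```python
from collections import deque


def collect_upstream(step_id: str, edges: list[tuple[str, str]]) -> list[str]:
    """Breadth-first walk over reversed edges, collecting every step upstream of step_id."""
    # Build the reverse adjacency map: target -> [sources].
    in_edges: dict[str, list[str]] = {}
    for source, target in edges:
        in_edges.setdefault(target, []).append(source)

    seen = [step_id]
    queue = deque([step_id])
    while queue:
        current = queue.popleft()
        for predecessor in in_edges.get(current, []):
            if predecessor not in seen:
                seen.append(predecessor)
                queue.append(predecessor)
    return seen[1:]  # drop the starting step itself


# A small diamond graph: a -> b, a -> c, b -> d, c -> d
edges = [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d")]
print(collect_upstream("d", edges))  # -> ['b', 'c', 'a']
```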
"""清理后台任务""" + async with _task_lock: + if not _background_tasks: + logger.info("没有后台任务需要清理") + return + + logger.info(f"开始清理 {len(_background_tasks)} 个后台任务") + + # 取消所有未完成的任务 + cancelled_count = 0 + for task_id, task in list(_background_tasks.items()): + if not task.done(): + task.cancel() + cancelled_count += 1 + logger.debug(f"取消后台任务: {task_id}") + + # 等待任务取消完成,设置超时避免永久等待 + if cancelled_count > 0: + logger.info(f"等待 {cancelled_count} 个任务取消完成...") + timeout = 5.0 # 5秒超时 + try: + await asyncio.wait_for( + asyncio.gather(*[task for task in _background_tasks.values()], return_exceptions=True), + timeout=timeout + ) + except asyncio.TimeoutError: + logger.warning(f"等待任务取消超时 ({timeout}s),强制清理") + except Exception as e: + logger.error(f"等待任务取消时出错: {e}") + + # 清理任务字典 + completed_count = len(_background_tasks) + _background_tasks.clear() + logger.info(f"后台任务清理完成,共清理 {completed_count} 个任务") + +async def periodic_cleanup_background_tasks(): + """定期清理已完成的后台任务""" + try: + async with _task_lock: + if not _background_tasks: + return + + completed_tasks = [] + for task_id, task in list(_background_tasks.items()): + if task.done(): + completed_tasks.append(task_id) + try: + # 获取任务结果,记录异常 + await task + logger.debug(f"后台任务已完成: {task_id}") + except Exception as e: + logger.error(f"后台任务执行异常: {task_id}, 错误: {e}") + + # 移除已完成的任务 + for task_id in completed_tasks: + _background_tasks.pop(task_id, None) + + if completed_tasks: + logger.info(f"定期清理了 {len(completed_tasks)} 个已完成的后台任务") + + except Exception as e: + logger.error(f"定期清理后台任务失败: {e}") + + +class PredecessorCacheService: + """前置节点变量预解析缓存服务""" + + @staticmethod + async def initialize_redis(): + """初始化Redis连接""" + try: + # 从配置文件读取Redis配置 + from apps.common.config import Config + + config = Config().get_config() + redis_config = config.redis + + logger.info(f"准备连接Redis: {redis_config.host}:{redis_config.port}") + await redis_cache.init(redis_config=redis_config) + + # 验证连接是否正常 + if redis_cache.is_connected(): + 
logger.info("前置节点缓存服务Redis初始化成功") + return + else: + raise Exception("Redis连接验证失败") + + except Exception as e: + logger.error(f"使用配置文件连接Redis失败: {e}") + + # 尝试降级连接方案 + try: + logger.info("尝试降级连接方案...") + from apps.common.config import Config + config = Config().get_config() + redis_config = config.redis + + # 构建简单的Redis URL + password_part = f":{redis_config.password}@" if redis_config.password else "" + redis_url = f"redis://{password_part}{redis_config.host}:{redis_config.port}/{redis_config.database}" + + await redis_cache.init(redis_url=redis_url) + + if redis_cache.is_connected(): + logger.info("降级连接方案成功") + return + else: + raise Exception("降级连接方案也失败") + + except Exception as fallback_error: + logger.error(f"降级连接方案也失败: {fallback_error}") + + # 即使Redis初始化失败,也不要抛出异常,而是继续运行(降级模式) + logger.info("将使用实时解析模式作为降级方案") + + @staticmethod + def calculate_flow_hash(flow_item) -> str: + """计算Flow拓扑结构的哈希值""" + try: + # 提取关键的拓扑信息 + topology_data = { + 'nodes': [ + { + 'step_id': node.step_id, + 'call_id': getattr(node, 'call_id', ''), + 'parameters': getattr(node, 'parameters', {}) + } + for node in flow_item.nodes + ], + 'edges': [ + { + 'source_node': edge.source_node, + 'target_node': edge.target_node + } + for edge in flow_item.edges + ] + } + + # 生成哈希 + topology_json = json.dumps(topology_data, sort_keys=True) + return hashlib.md5(topology_json.encode()).hexdigest() + except Exception as e: + logger.error(f"计算Flow哈希失败: {e}") + return str(datetime.now(UTC).timestamp()) # 降级方案 + + @staticmethod + async def trigger_flow_parsing(flow_id: str, force_refresh: bool = False): + """触发整个Flow的前置节点变量解析""" + try: + # 获取Flow信息 + flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id) + if not flow_item: + logger.warning(f"Flow不存在,跳过解析: {flow_id}") + return + + # 计算当前Flow的哈希 + current_hash = PredecessorCacheService.calculate_flow_hash(flow_item) + + # 检查是否需要重新解析 + if not force_refresh: + cached_hash = await _get_predecessor_cache().get_flow_hash(flow_id) + if 
cached_hash == current_hash: + logger.info(f"Flow拓扑未变化,跳过解析: {flow_id}") + return + + # 更新Flow哈希 + await _get_predecessor_cache().set_flow_hash(flow_id, current_hash) + + # 清除旧缓存 + await _get_predecessor_cache().invalidate_flow_cache(flow_id) + + # 为每个节点启动异步解析任务 + tasks = [] + for node in flow_item.nodes: + step_id = node.step_id + task_id = f"parse_predecessor_{flow_id}_{step_id}" + + # 避免重复任务 + async with _task_lock: + if task_id in _background_tasks and not _background_tasks[task_id].done(): + continue + + # 异步启动解析任务 + task = asyncio.create_task( + PredecessorCacheService._parse_single_node_predecessor( + flow_id, step_id, current_hash + ) + ) + _background_tasks[task_id] = task + tasks.append((task_id, task)) + + if tasks: + logger.info(f"启动Flow前置节点解析任务: {flow_id}, 节点数量: {len(tasks)}") + # 简化处理:直接启动任务,依赖cleanup_background_tasks进行清理 + for task_id, task in tasks: + # 不添加回调,让任务自然完成 + logger.debug(f"启动后台任务: {task_id}") + + except Exception as e: + logger.error(f"触发Flow解析失败: {flow_id}, 错误: {e}") + + @staticmethod + async def _cleanup_task(task_id: str): + """清理完成的任务""" + try: + async with _task_lock: + task = _background_tasks.pop(task_id, None) + if task and task.done(): + # 检查任务是否有异常 + try: + result = await task + logger.debug(f"后台任务完成: {task_id}") + except Exception as e: + logger.error(f"后台任务执行异常: {task_id}, 错误: {e}") + except Exception as e: + logger.error(f"清理任务失败: {task_id}, 错误: {e}") + + @staticmethod + async def _parse_single_node_predecessor(flow_id: str, step_id: str, flow_hash: str): + """解析单个节点的前置节点变量""" + try: + # 检查事件循环是否仍然活跃 + try: + asyncio.get_running_loop() + except RuntimeError: + logger.warning(f"事件循环已关闭,跳过解析: {flow_id}:{step_id}") + return + + # 设置解析状态 + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "parsing") + + # 获取Flow信息 + flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id) + if not flow_item: + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "failed") + return + + # 查找前置节点 + 
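`trigger_flow_parsing` skips re-parsing when the cached flow hash matches the current one. A minimal sketch of such a topology hash, with an illustrative field set; unlike `calculate_flow_hash` above, it sorts nodes and edges before serializing, so storage order cannot spuriously invalidate the cache:

```python
import hashlib
import json


def topology_hash(nodes: list[dict], edges: list[dict]) -> str:
    """Hash only the fields that define the topology, so cosmetic edits
    (e.g. node positions) do not invalidate the predecessor cache."""
    payload = {
        "nodes": [
            {"step_id": n["step_id"], "call_id": n.get("call_id", "")}
            for n in sorted(nodes, key=lambda n: n["step_id"])
        ],
        "edges": sorted((e["source_node"], e["target_node"]) for e in edges),
    }
    return hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()


h1 = topology_hash([{"step_id": "a"}, {"step_id": "b"}],
                   [{"source_node": "a", "target_node": "b"}])
h2 = topology_hash([{"step_id": "b"}, {"step_id": "a"}],
                   [{"source_node": "a", "target_node": "b"}])
print(h1 == h2)  # node order does not affect the hash -> True
```

MD5 is fine here since the hash is only a change detector, not a security boundary.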
predecessor_nodes = PredecessorCacheService._find_predecessor_nodes(flow_item, step_id) + + # 为每个前置节点创建输出变量 + variables_data = [] + for node in predecessor_nodes: + node_vars = await PredecessorCacheService._create_node_output_variables(node) + variables_data.extend(node_vars) + + # 缓存结果 + await _get_predecessor_cache().set_cached_variables(flow_id, step_id, variables_data, flow_hash) + + # 设置完成状态 + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "completed") + + logger.info(f"节点前置变量解析完成: {flow_id}:{step_id}, 变量数量: {len(variables_data)}") + + except asyncio.CancelledError: + logger.info(f"节点前置变量解析任务被取消: {flow_id}:{step_id}") + try: + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "cancelled") + except Exception: + pass # 忽略清理时的错误 + except Exception as e: + logger.error(f"解析节点前置变量失败: {flow_id}:{step_id}, 错误: {e}") + try: + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "failed") + except Exception: + # 如果连设置状态都失败了,说明可能是事件循环关闭导致的 + logger.warning(f"无法设置解析状态为失败: {flow_id}:{step_id}") + + @staticmethod + async def get_predecessor_variables_optimized( + flow_id: str, + step_id: str, + user_sub: str, + max_wait_time: int = 10 + ) -> List[Dict[str, Any]]: + """优化的前置节点变量获取(优先使用缓存)""" + try: + # 1. 先尝试从缓存获取 + cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id) + if cached_vars is not None: + logger.info(f"使用缓存的前置节点变量: {flow_id}:{step_id}") + return cached_vars + + # 2. 检查是否正在解析中 + if await _get_predecessor_cache().is_parsing_in_progress(flow_id, step_id): + logger.info(f"等待前置节点变量解析完成: {flow_id}:{step_id}") + # 等待解析完成 + if await _get_predecessor_cache().wait_for_parsing_completion(flow_id, step_id, max_wait_time): + cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id) + if cached_vars is not None: + return cached_vars + + # 3. 
缓存未命中,启动实时解析 + logger.info(f"缓存未命中,启动实时解析: {flow_id}:{step_id}") + + # 获取Flow信息 + flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id) + if not flow_item: + return [] + + # 计算Flow哈希 + flow_hash = PredecessorCacheService.calculate_flow_hash(flow_item) + + # 立即解析并缓存 + await PredecessorCacheService._parse_single_node_predecessor(flow_id, step_id, flow_hash) + + # 再次尝试从缓存获取 + cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id) + return cached_vars or [] + + except Exception as e: + logger.error(f"获取优化前置节点变量失败: {flow_id}:{step_id}, 错误: {e}") + return [] + + @staticmethod + async def _get_flow_by_flow_id(flow_id: str): + """通过flow_id获取工作流信息""" + try: + from apps.common.mongo import MongoDB + + app_collection = MongoDB().get_collection("app") + + # 查询包含此flow_id的app,同时获取app_id + app_record = await app_collection.find_one( + {"flows.id": flow_id}, + {"_id": 1} + ) + + if not app_record: + logger.warning(f"未找到包含flow_id {flow_id} 的应用") + return None + + app_id = app_record["_id"] + + # 使用现有的FlowManager方法获取flow + flow_item = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + return flow_item + + except Exception as e: + logger.error(f"通过flow_id获取工作流失败: {e}") + return None + + @staticmethod + def _find_predecessor_nodes(flow_item, current_step_id: str) -> List: + """在工作流中查找前置节点""" + try: + predecessor_nodes = [] + + # 遍历边,找到指向当前节点的边 + for edge in flow_item.edges: + if edge.target_node == current_step_id: + # 找到前置节点 + source_node = next( + (node for node in flow_item.nodes if node.step_id == edge.source_node), + None + ) + if source_node: + predecessor_nodes.append(source_node) + + logger.debug(f"为节点 {current_step_id} 找到 {len(predecessor_nodes)} 个前置节点") + return predecessor_nodes + + except Exception as e: + logger.error(f"查找前置节点失败: {e}") + return [] + + @staticmethod + async def _create_node_output_variables(node) -> List[Dict[str, Any]]: + """根据节点的output_parameters配置创建输出变量数据""" + try: + variables_data = [] + 
node_id = node.step_id + + # 统一从节点的output_parameters创建变量 + output_params = {} + if hasattr(node, 'parameters') and node.parameters: + if isinstance(node.parameters, dict): + output_params = node.parameters.get('output_parameters', {}) + else: + output_params = getattr(node.parameters, 'output_parameters', {}) + + # 如果没有配置output_parameters,跳过此节点 + if not output_params: + logger.debug(f"节点 {node_id} 没有配置output_parameters,跳过创建输出变量") + return variables_data + + # 遍历output_parameters中的每个key-value对,创建对应的变量数据 + for param_name, param_config in output_params.items(): + # 解析参数配置 + if isinstance(param_config, dict): + param_type = param_config.get('type', 'string') + description = param_config.get('description', '') + else: + # 如果param_config不是字典,可能是简单的类型字符串 + param_type = str(param_config) if param_config else 'string' + description = '' + + # 确定变量类型 + var_type = VariableType.STRING # 默认类型 + if param_type == 'number': + var_type = VariableType.NUMBER + elif param_type == 'boolean': + var_type = VariableType.BOOLEAN + elif param_type == 'object': + var_type = VariableType.OBJECT + elif param_type == 'array' or param_type == 'array[any]': + var_type = VariableType.ARRAY_ANY + elif param_type == 'array[string]': + var_type = VariableType.ARRAY_STRING + elif param_type == 'array[number]': + var_type = VariableType.ARRAY_NUMBER + elif param_type == 'array[object]': + var_type = VariableType.ARRAY_OBJECT + elif param_type == 'array[boolean]': + var_type = VariableType.ARRAY_BOOLEAN + elif param_type == 'array[file]': + var_type = VariableType.ARRAY_FILE + elif param_type == 'array[secret]': + var_type = VariableType.ARRAY_SECRET + elif param_type == 'file': + var_type = VariableType.FILE + elif param_type == 'secret': + var_type = VariableType.SECRET + + # 创建变量数据(用于缓存的字典格式) + variable_data = { + 'name': f"{node_id}.{param_name}", + 'var_type': var_type.value, + 'scope': VariableScope.CONVERSATION.value, + 'value': "", # 配置阶段的潜在变量,值为空 + 'description': description or f"来自节点 
{node_id} 的输出参数 {param_name}", + 'created_at': datetime.now(UTC).isoformat(), + 'updated_at': datetime.now(UTC).isoformat(), + 'step_name': getattr(node, 'name', node_id), # 节点名称 + 'step_id': node_id # 节点ID + } + + variables_data.append(variable_data) + + logger.debug(f"为节点 {node_id} 创建了 {len(variables_data)} 个输出变量: {[v['name'] for v in variables_data]}") + return variables_data + + except Exception as e: + logger.error(f"创建节点输出变量失败: {e}") + return [] \ No newline at end of file diff --git a/apps/services/rag.py b/apps/services/rag.py index 6b6c843dc6fc09a14809ef0a11bc44b2da434f29..1dc6fd32f9b90453046de07b402274b98e59a8be 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """对接Euler Copilot RAG""" +from datetime import UTC, datetime import json import logging from collections.abc import AsyncGenerator @@ -30,8 +31,7 @@ class RAG: """系统提示词""" user_prompt = """' - 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后进行脚注。 - 一个例子将在中给出。 + 你是华鲲振宇的智能助手。请结合给出的背景信息, 回答用户的提问。 上下文背景信息将在中给出。 用户的提问将在中给出。 注意: @@ -68,7 +68,7 @@ class RAG: openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]] - + {bac_info} @@ -156,9 +156,11 @@ class RAG: "id": doc_chunk["docId"], "order": doc_cnt, "name": doc_chunk.get("docName", ""), + "author": doc_chunk.get("docAuthor", ""), "extension": doc_chunk.get("docExtension", ""), "abstract": doc_chunk.get("docAbstract", ""), "size": doc_chunk.get("docSize", 0), + "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)), }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] diff --git a/apps/services/task.py b/apps/services/task.py index 1e672be690a6f17896ab01bc7149aa353987ecf7..17976e62c1e3b0b8c53091789d15a06751042095 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -9,6 +9,7 @@ from apps.common.mongo import MongoDB from apps.schemas.record import RecordGroup from 
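`_create_node_output_variables` above maps JSON-schema type strings to `VariableType` members through a long if/elif chain. A table-driven sketch of the same mapping, with a trimmed, illustrative `VariableType` (the real enum lives in `apps.scheduler.variable.type`):

```python
from enum import Enum


class VariableType(str, Enum):
    STRING = "string"
    NUMBER = "number"
    BOOLEAN = "boolean"
    OBJECT = "object"
    ARRAY_ANY = "array[any]"
    ARRAY_STRING = "array[string]"
    FILE = "file"


# One lookup table replaces the if/elif ladder; unknown types fall back to
# STRING, matching the chain's default branch.
_TYPE_MAP: dict[str, VariableType] = {
    "number": VariableType.NUMBER,
    "boolean": VariableType.BOOLEAN,
    "object": VariableType.OBJECT,
    "array": VariableType.ARRAY_ANY,
    "array[any]": VariableType.ARRAY_ANY,
    "array[string]": VariableType.ARRAY_STRING,
    "file": VariableType.FILE,
}


def to_variable_type(param_type: str) -> VariableType:
    """Resolve a schema type string to a VariableType, defaulting to STRING."""
    return _TYPE_MAP.get(param_type, VariableType.STRING)


print(to_variable_type("array"))    # -> VariableType.ARRAY_ANY
print(to_variable_type("unknown"))  # -> VariableType.STRING (fallback)
```

Keeping the mapping in one dict also makes the supported type strings auditable at a glance.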
apps.schemas.request_data import RequestData from apps.schemas.task import ( + FlowStepHistory, Task, TaskIds, TaskRuntime, @@ -45,7 +46,6 @@ class TaskManager: return Task.model_validate(task) - @staticmethod async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None: """获取组ID的最后一条问答组关联的任务""" @@ -58,7 +58,6 @@ class TaskManager: task = await task_collection.find_one({"_id": record_group_obj.task_id}) return Task.model_validate(task) - @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" @@ -68,7 +67,6 @@ class TaskManager: return None return Task.model_validate(task) - @staticmethod async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: """根据record_group_id获取flow信息""" @@ -95,9 +93,8 @@ class TaskManager: else: return flow_context_list - @staticmethod - async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: + async def get_context_by_task_id(task_id: str, length: int = 0) -> list[FlowStepHistory]: """根据task_id获取flow信息""" flow_context_collection = MongoDB().get_collection("flow_context") @@ -115,9 +112,8 @@ class TaskManager: else: return flow_context - @staticmethod - async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: + async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB().get_collection("flow_context") try: @@ -125,7 +121,7 @@ class TaskManager: # 查找是否存在 current_context = await flow_context_collection.find_one({ "task_id": task_id, - "_id": history["_id"], + "_id": history.id, }) if current_context: await flow_context_collection.update_one( @@ -133,11 +129,10 @@ class TaskManager: {"$set": history}, ) else: - await flow_context_collection.insert_one(history) + await flow_context_collection.insert_one(history.model_dump(by_alias=True, exclude_none=True)) except Exception: 
             logger.exception("[TaskManager] 保存flow执行记录失败")
 
-    @staticmethod
     async def delete_task_by_task_id(task_id: str) -> None:
         """通过task_id删除Task信息"""
@@ -148,7 +143,6 @@ class TaskManager:
         if task:
             await task_collection.delete_one({"_id": task_id})
 
-    @staticmethod
     async def delete_tasks_by_conversation_id(conversation_id: str) -> None:
         """通过ConversationID删除Task信息"""
@@ -167,7 +161,6 @@ class TaskManager:
             await task_collection.delete_many({"conversation_id": conversation_id}, session=session)
             await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session)
 
-    @classmethod
     async def get_task(
         cls,
@@ -212,7 +205,6 @@ class TaskManager:
             runtime=TaskRuntime(),
         )
 
-    @classmethod
     async def save_task(cls, task_id: str, task: Task) -> None:
         """保存任务块"""
diff --git a/apps/templates/generate_llm_operator_config.py b/apps/templates/generate_llm_operator_config.py
index 2dc270b6c327f8347139ef157fe175e4c185d1b2..f63136413bedd37b177250b6b2e6550476cedb3b 100644
--- a/apps/templates/generate_llm_operator_config.py
+++ b/apps/templates/generate_llm_operator_config.py
@@ -5,54 +5,60 @@ import base64
 import os
 
 llm_provider_dict={
-    "baichuan":{
-        "provider":"baichuan",
-        "url":"https://api.baichuan-ai.com/v1",
-        "description":"百川大模型平台",
-        "icon":"",
-    },
-    "modelscope":{
-        "provider":"modelscope",
-        "url":None,
-        "description":"基于魔塔部署的本地大模型服务",
-        "icon":"",
+    "huakunzhenyu": {
+        "provider": "huakunzhenyu",
+        "url": "http://192.168.226.30:9997/v1",
+        "description": "天巡CubeX智擎平台",
+        "icon": "",
     },
+    # "baichuan":{
+    #     "provider":"baichuan",
+    #     "url":"https://api.baichuan-ai.com/v1",
+    #     "description":"百川大模型平台",
+    #     "icon":"",
+    # },
+    # "modelscope":{
+    #     "provider":"modelscope",
+    #     "url":None,
+    #     "description":"基于魔塔部署的本地大模型服务",
+    #     "icon":"",
+    # },
     "ollama":{
         "provider":"ollama",
         "url":None,
         "description":"基于Ollama部署的本地大模型服务",
         "icon":"",
     },
-    "openai":{
-        "provider":"openai",
-        "url":"https://api.openai.com/v1",
-        "description":"OpenAI大模型平台",
-        "icon":"",
-    },
-    "qwen":{
-        "provider":"qwen",
-        "url":"https://dashscope.aliyuncs.com/compatible-mode/v1",
-        "description":"阿里百炼大模型平台",
-        "icon":"",
-    },
-    "spark":{
-        "provider":"spark",
-        "url":"https://spark-api-open.xf-yun.com/v1",
-        "description":"讯飞星火大模型平台",
-        "icon":"",
-    },
+    # "openai":{
+    #     "provider":"openai",
+    #     "url":"https://api.openai.com/v1",
+    #     "description":"OpenAI大模型平台",
+    #     "icon":"",
+    # },
+    # "qwen":{
+    #     "provider":"qwen",
+    #     "url":"https://dashscope.aliyuncs.com/compatible-mode/v1",
+    #     "description":"阿里百炼大模型平台",
+    #     "icon":"",
+    # },
+    # "spark":{
+    #     "provider":"spark",
+    #     "url":"https://spark-api-open.xf-yun.com/v1",
+    #     "description":"讯飞星火大模型平台",
+    #     "icon":"",
+    # },
     "vllm":{
         "provider":"vllm",
         "url":None,
         "description":"基于VLLM部署的本地大模型服务",
         "icon":"",
     },
-    "wenxin":{
-        "provider":"wenxin",
-        "url":"https://qianfan.baidubce.com/v2",
-        "description":"百度文心大模型平台",
-        "icon":"",
-    },
+    # "wenxin":{
+    #     "provider":"wenxin",
+    #     "url":"https://qianfan.baidubce.com/v2",
+    #     "description":"百度文心大模型平台",
+    #     "icon":"",
+    # },
 }
 icon_path="./apps/templates/llm_provider_icon"
 icon_file_name_list=os.listdir(icon_path)
@@ -68,3 +74,6 @@ for file_name in icon_file_name_list:
     if provider_name in provider:
         llm_provider_dict[provider]['icon'] = f"data:image/svg+xml;base64,{base64_string}"
         break
+
+if __name__ == "__main__":
+    print(llm_provider_dict)
\ No newline at end of file
diff --git a/apps/templates/llm_provider_icon/huakunzhenyu.svg b/apps/templates/llm_provider_icon/huakunzhenyu.svg
new file mode 100644
index 0000000000000000000000000000000000000000..d5351cafe9f340b1fbd66ebc8d4c50308444b031
--- /dev/null
+++ b/apps/templates/llm_provider_icon/huakunzhenyu.svg
@@ -0,0 +1,3 @@
+
+
diff --git a/assets/.config.example.toml b/assets/.config.example.toml
index 61f13e49a893e8a698df3113148226c8a6a7463a..28fb5ef8ec5e3451101b1776aa4ce27f25f0057a 100644
--- a/assets/.config.example.toml
+++ b/assets/.config.example.toml
@@ -37,6 +37,17 @@ user = 'euler_copilot'
 password = ''
 database = 'euler_copilot'
 
+[redis]
+host = 'redis-db'
+port = 6379
+password = ''
+database = 0
+decode_responses = true
+socket_timeout = 5.0
+socket_connect_timeout = 5.0
+max_connections = 10
+health_check_interval = 30
+
 [minio]
 endpoint = '127.0.0.1:9000'
 access_key = 'minioadmin'
@@ -62,5 +73,8 @@ temperature = 0.7
 enable = false
 words_list = ''
 
+[sandbox]
+sandbox_service = 'http://127.0.0.1:8000'
+
 [extra]
 sql_url = 'http://127.0.0.1:9015'
diff --git a/deploy/chart/euler_copilot/configs/framework/config.toml b/deploy/chart/euler_copilot/configs/framework/config.toml
index 64a1a6a03600b1cfc48c504f32c1e8c9ce4350cb..863f3aa5c7abf96aadc7ee367af1e064d44633f0 100644
--- a/deploy/chart/euler_copilot/configs/framework/config.toml
+++ b/deploy/chart/euler_copilot/configs/framework/config.toml
@@ -37,6 +37,17 @@ user = 'euler_copilot'
 password = '${mongo-password}'
 database = 'euler_copilot'
 
+[redis]
+host = 'redis-db.{{ .Release.Namespace }}.svc.cluster.local'
+port = 6379
+password = '${redis-password}'
+database = 5
+decode_responses = true
+socket_timeout = 5.0
+socket_connect_timeout = 5.0
+max_connections = 10
+health_check_interval = 30
+
 [minio]
 endpoint = 'minio-service.{{ .Release.Namespace }}.svc.cluster.local:9000'
 access_key = 'minioadmin'
@@ -62,5 +73,8 @@ temperature = {{ default 0.7 .Values.models.functionCall.temperature }}
 enable = false
 words_list = ""
 
+[sandbox]
+sandbox_service = 'http://euler-copilot-sandbox-service.{{ .Release.Namespace }}.svc.cluster.local:8000'
+
 [extra]
 sql_url = ''
diff --git a/docs/variable_configuration.md b/docs/variable_configuration.md
new file mode 100644
index 0000000000000000000000000000000000000000..b105c4af1fb43d785e8fadb0af84f5965a6783af
--- /dev/null
+++ b/docs/variable_configuration.md
@@ -0,0 +1,133 @@
+# 变量存储格式配置说明
+
+## 概述
+
+节点执行完成后,系统会根据节点的`output_parameters`配置自动将输出数据保存到对话变量池中。变量的存储格式有两种:
+
+1. **直接格式**: `conversation.key`
+2. **带前缀格式**: `conversation.step_id.key`
+
+## 配置方式
+
+### 1. 通过节点类型配置
+
+在 `apps/scheduler/executor/step_config.py` 中的 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 集合中添加节点类型:
+
+```python
+DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: Set[str] = {
+    "Start",        # Start节点
+    "Input",        # 输入节点
+    "YourNewNode",  # 新增的节点类型
+}
+```
+
+### 2. 通过名称模式配置
+
+在 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 集合中添加匹配模式:
+
+```python
+DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: Set[str] = {
+    "start",   # 匹配以"start"开头的节点名称
+    "init",    # 匹配以"init"开头的节点名称
+    "config",  # 新增:匹配以"config"开头的节点名称
+}
+```
+
+## 判断逻辑
+
+系统会按以下顺序判断是否使用直接格式:
+
+1. 检查节点的 `call_id` 是否在 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 中
+2. 检查节点的 `step_name` 是否在 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 中
+3. 检查节点名称(小写)是否以 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 中的模式开头
+4. 检查 `step_id`(小写)是否以 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 中的模式开头
+
+如果任一条件满足,则使用直接格式 `conversation.key`,否则使用带前缀格式 `conversation.step_id.key`。
+
+## 使用示例
+
+### 示例1:Start节点
+
+```json
+// 节点配置
+{
+    "call_id": "Start",
+    "step_name": "start",
+    "step_id": "start_001",
+    "output_parameters": {
+        "user_name": {"type": "string"},
+        "session_id": {"type": "string"}
+    }
+}
+
+// 保存的变量格式
+conversation.user_name = "张三"
+conversation.session_id = "sess_123"
+```
+
+### 示例2:普通处理节点
+
+```json
+// 节点配置
+{
+    "call_id": "Code",
+    "step_name": "数据处理",
+    "step_id": "process_001",
+    "output_parameters": {
+        "result": {"type": "object"},
+        "status": {"type": "string"}
+    }
+}
+
+// 保存的变量格式
+conversation.process_001.result = {...}
+conversation.process_001.status = "success"
+```
+
+### 示例3:配置节点(新增类型)
+
+```python
+# 在step_config.py中添加
+DIRECT_CONVERSATION_VARIABLE_NODE_TYPES.add("GlobalConfig")
+```
+
+```json
+// 节点配置
+{
+    "call_id": "GlobalConfig",
+    "step_name": "全局配置",
+    "step_id": "config_001",
+    "output_parameters": {
+        "api_key": {"type": "secret"},
+        "timeout": {"type": "number"}
+    }
+}
+
+// 保存的变量格式(使用直接格式)
+conversation.api_key = "xxx"
+conversation.timeout = 30
+```
+
+## 变量引用
+
+在其他节点中可以通过以下方式引用这些变量:
+
+```json
+{
+    "input_parameters": {
+        "user": {
+            "reference": "{{conversation.user_name}}"  // 直接格式变量
+        },
+        "data": {
+            "reference": "{{conversation.process_001.result}}"  // 带前缀格式变量
+        }
+    }
+}
+```
+
+## 注意事项
+
+1. **一致性**: 建议同时添加大小写版本以确保兼容性
+2. **命名冲突**: 使用直接格式时需要注意变量名冲突问题
+3. **可追溯性**: 带前缀格式便于追踪变量来源,直接格式便于全局访问
+4. **配置变更**: 修改配置后需要重启服务才能生效
\ No newline at end of file
diff --git a/euler_copilot_framework.egg-info/PKG-INFO b/euler_copilot_framework.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..768d3f89fce727f40aaee643151eddd45dc3fd9d
--- /dev/null
+++ b/euler_copilot_framework.egg-info/PKG-INFO
@@ -0,0 +1,38 @@
+Metadata-Version: 2.4
+Name: euler-copilot-framework
+Version: 0.9.6
+Summary: EulerCopilot 后端服务
+Requires-Python: ==3.11.6
+License-File: LICENSE
+Requires-Dist: aiofiles==24.1.0
+Requires-Dist: asyncer==0.0.8
+Requires-Dist: asyncpg==0.30.0
+Requires-Dist: cryptography==44.0.2
+Requires-Dist: fastapi==0.115.12
+Requires-Dist: httpx==0.28.1
+Requires-Dist: httpx-sse==0.4.0
+Requires-Dist: jinja2==3.1.6
+Requires-Dist: jionlp==1.5.20
+Requires-Dist: jsonschema==4.23.0
+Requires-Dist: lancedb==0.21.2
+Requires-Dist: mcp==1.9.4
+Requires-Dist: minio==7.2.15
+Requires-Dist: ollama==0.5.1
+Requires-Dist: openai==1.91.0
+Requires-Dist: pandas==2.2.3
+Requires-Dist: pgvector==0.4.1
+Requires-Dist: pillow==10.3.0
+Requires-Dist: pydantic==2.11.7
+Requires-Dist: pymongo==4.12.1
+Requires-Dist: python-jsonpath==1.3.0
+Requires-Dist: python-magic==0.4.27
+Requires-Dist: python-multipart==0.0.20
+Requires-Dist: pytz==2025.2
+Requires-Dist: pyyaml==6.0.2
+Requires-Dist: rich==13.9.4
+Requires-Dist: sqids==0.5.1
+Requires-Dist: sqlalchemy==2.0.41
+Requires-Dist: tiktoken==0.9.0
+Requires-Dist: toml==0.10.2
+Requires-Dist: uvicorn==0.34.0
+Dynamic: license-file
diff --git a/euler_copilot_framework.egg-info/SOURCES.txt b/euler_copilot_framework.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3beeaae71ae08e9482a4157601718053efd51c51
--- /dev/null
+++ b/euler_copilot_framework.egg-info/SOURCES.txt
@@ -0,0 +1,12 @@
+LICENSE
+README.md
+pyproject.toml
+apps/__init__.py
+apps/constants.py
+apps/exceptions.py
+apps/main.py
+euler_copilot_framework.egg-info/PKG-INFO
+euler_copilot_framework.egg-info/SOURCES.txt
+euler_copilot_framework.egg-info/dependency_links.txt
+euler_copilot_framework.egg-info/requires.txt
+euler_copilot_framework.egg-info/top_level.txt
\ No newline at end of file
diff --git a/euler_copilot_framework.egg-info/dependency_links.txt b/euler_copilot_framework.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/euler_copilot_framework.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/euler_copilot_framework.egg-info/requires.txt b/euler_copilot_framework.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..de92c758f82c631f9fd0ef4277e15568d0e2d6bb
--- /dev/null
+++ b/euler_copilot_framework.egg-info/requires.txt
@@ -0,0 +1,31 @@
+aiofiles==24.1.0
+asyncer==0.0.8
+asyncpg==0.30.0
+cryptography==44.0.2
+fastapi==0.115.12
+httpx==0.28.1
+httpx-sse==0.4.0
+jinja2==3.1.6
+jionlp==1.5.20
+jsonschema==4.23.0
+lancedb==0.21.2
+mcp==1.9.4
+minio==7.2.15
+ollama==0.5.1
+openai==1.91.0
+pandas==2.2.3
+pgvector==0.4.1
+pillow==10.3.0
+pydantic==2.11.7
+pymongo==4.12.1
+python-jsonpath==1.3.0
+python-magic==0.4.27
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.2
+rich==13.9.4
+sqids==0.5.1
+sqlalchemy==2.0.41
+tiktoken==0.9.0
+toml==0.10.2
+uvicorn==0.34.0
diff --git a/euler_copilot_framework.egg-info/top_level.txt b/euler_copilot_framework.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ea5a6165b61e9169f28b337d9757f02c7af34ba1
--- /dev/null
+++ b/euler_copilot_framework.egg-info/top_level.txt
@@ -0,0 +1 @@
+apps
diff --git a/pyproject.toml b/pyproject.toml
index 681ea9fb4552650da9148a22dd55d2063dca2410..8a34b17dd1b202e375447139133b46495cf5c30f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,6 +5,7 @@ description = "EulerCopilot 后端服务"
 requires-python = "==3.11.6"
 dependencies = [
     "aiofiles==24.1.0",
+    "redis==5.0.8",
     "asyncer==0.0.8",
     "asyncpg==0.30.0",
     "cryptography==44.0.2",
@@ -51,3 +52,10 @@ dev = [
     "sphinx==8.2.3",
     "sphinx-rtd-theme==3.0.2",
 ]
+
+[build-system]
+requires = ["setuptools>=65"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["apps"] # 仅包含apps目录
\ No newline at end of file
diff --git a/tests/common/test_config.py b/tests/common/test_config.py
index 6c77141ca78e662ab12a8ec3c1ae017f3de6fced..9d90da2efd8b6f81a14136a6787ba38e157ebd16 100644
--- a/tests/common/test_config.py
+++ b/tests/common/test_config.py
@@ -68,6 +68,7 @@ MOCK_CONFIG_DATA: dict[str, Any] = {
         "jwt_key": "test_jwt_key",
     },
     "check": {"enable": False, "words_list": ""},
+    "sandbox": {"sandbox_service": "http://localhost:8000"},
     "extra": {"sql_url": "http://localhost"},
 }
diff --git a/update.sh b/update.sh
new file mode 100755
index 0000000000000000000000000000000000000000..88b1df6b8850215b9c8712c32a96c396419600fc
--- /dev/null
+++ b/update.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+# update_euler_framework.sh - Docker镜像构建与推送脚本
+# 用法: ./update_euler_framework.sh
+# 示例: ./update_euler_framework.sh 0.9.6
+
+# 检查是否传入版本参数
+if [ $# -eq 0 ]; then
+    echo "错误: 必须指定版本号作为参数"
+    echo "用法: $0 "
+    exit 1
+fi
+
+VERSION=$1
+DOCKERFILE_PATH="/home/jay/code/euler-copilot-framework/Dockerfile"
+IMAGE_NAME="swr.cn-north-4.myhuaweicloud.com/euler-copilot/euler-copilot-framework:hkzy-${VERSION}-arm"
+
+echo "开始拉取最新版本的代码, 会使用当前项目分支..."
+git pull
+
+echo "开始构建Docker镜像,版本: ${VERSION}"
+
+# 构建Docker镜像
+docker build . -f ${DOCKERFILE_PATH} -t ${IMAGE_NAME}
+if [ $? -ne 0 ]; then
+    echo "错误: Docker镜像构建失败"
+    exit 1
+fi
+
+echo "Docker镜像构建成功,开始推送..."
+
+# 推送Docker镜像
+docker push ${IMAGE_NAME}
+if [ $? -ne 0 ]; then
+    echo "错误: Docker镜像推送失败"
+    exit 1
+fi
+
+echo "Docker镜像推送成功..."
+echo "操作完成,镜像地址: ${IMAGE_NAME}"
+
+echo "helm 更新开始..."
+helm upgrade euler-copilot /home/jay/deploy/chart/euler_copilot/. -n euler-copilot
+
+echo "输出euler-copilot pod状态..."
+kubectl get pods -n euler-copilot
\ No newline at end of file
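附注:上文节点输出变量创建代码用一长串 `if/elif` 把 `output_parameters` 中的 `type` 字符串映射为 `VariableType`。下面是一个等价的查表写法示意(假设性代码:`VariableType` 此处用最小桩枚举代替项目内的真实定义,`_PARAM_TYPE_MAP` 与 `resolve_var_type` 均为本示例虚构的名称,未出现在本PR中):

```python
from enum import Enum


class VariableType(str, Enum):
    """最小桩:仅为演示查表映射,成员名取自上文代码,真实定义见项目内枚举"""
    STRING = "string"
    NUMBER = "number"
    BOOLEAN = "boolean"
    OBJECT = "object"
    FILE = "file"
    SECRET = "secret"
    ARRAY_ANY = "array[any]"
    ARRAY_STRING = "array[string]"
    ARRAY_NUMBER = "array[number]"
    ARRAY_OBJECT = "array[object]"
    ARRAY_BOOLEAN = "array[boolean]"
    ARRAY_FILE = "array[file]"
    ARRAY_SECRET = "array[secret]"


# type字符串 -> VariableType 的映射表,与原elif链逐项对应
# 注意 'array' 与 'array[any]' 均映射到 ARRAY_ANY,和原实现一致
_PARAM_TYPE_MAP = {
    "number": VariableType.NUMBER,
    "boolean": VariableType.BOOLEAN,
    "object": VariableType.OBJECT,
    "array": VariableType.ARRAY_ANY,
    "array[any]": VariableType.ARRAY_ANY,
    "array[string]": VariableType.ARRAY_STRING,
    "array[number]": VariableType.ARRAY_NUMBER,
    "array[object]": VariableType.ARRAY_OBJECT,
    "array[boolean]": VariableType.ARRAY_BOOLEAN,
    "array[file]": VariableType.ARRAY_FILE,
    "array[secret]": VariableType.ARRAY_SECRET,
    "file": VariableType.FILE,
    "secret": VariableType.SECRET,
}


def resolve_var_type(param_type: str) -> VariableType:
    """未知类型与原实现一致,回退为 STRING"""
    return _PARAM_TYPE_MAP.get(param_type, VariableType.STRING)
```

这样新增一种参数类型只需在映射表中加一行,避免继续拉长 `elif` 链。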