diff --git a/.gitignore b/.gitignore
index c62e334e6e0f521a7f889c54a758a5956a99c0ea..4a6c7f8f9ed214b33348170a24679fe8061102df 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,7 @@ logs
.ruff_cache/
config
uv.lock
+update.sh
+auto-update.sh
+
+*.DS_Store
\ No newline at end of file
diff --git a/CONVERSATION_VARIABLE_FIX.md b/CONVERSATION_VARIABLE_FIX.md
new file mode 100644
index 0000000000000000000000000000000000000000..2a84ca946bef291cdf8c2e376f3d2a9a00b26684
--- /dev/null
+++ b/CONVERSATION_VARIABLE_FIX.md
@@ -0,0 +1,179 @@
+# 对话变量模板问题修复总结
+
+## 🔍 **问题描述**
+
+用户报告创建对话级变量`test`成功后,无法通过API接口查询到:
+```
+GET /api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+```
+
+## 🔧 **根本原因分析**
+
+经过分析发现,问题出现在我之前重构变量架构时的遗漏:
+
+### 1. **创建API工作正常**
+- ✅ 对话变量正确存储到FlowVariablePool的`_conversation_templates`字典中
+- ✅ 数据库持久化成功
+
+### 2. **查询API有缺陷**
+- ❌ `pool_manager.py`中`list_variables_from_any_pool`方法只处理了`conversation_id`参数
+- ❌ 没有处理`scope=conversation&flow_id=xxx`的查询情况
+- ❌ `get_variable_from_any_pool`方法也有同样问题
+
+### 3. **更新删除API有问题**
+- ❌ FlowVariablePool的`update_variable`和`delete_variable`方法只在`_variables`字典中查找
+- ❌ 找不到存储在`_conversation_templates`字典中的对话变量模板
+
+## 🛠️ **修复方案**
+
+### 1. **修复查询逻辑**
+
+#### `list_variables_from_any_pool`方法
+**修复前**:
+```python
+elif scope == VariableScope.CONVERSATION and conversation_id:
+ pool = await self.get_conversation_pool(conversation_id)
+ if pool:
+ return await pool.list_variables(include_system=False)
+ return []
+```
+
+**修复后**:
+```python
+elif scope == VariableScope.CONVERSATION:
+ if conversation_id:
+ # 使用conversation_id查询对话变量实例
+ pool = await self.get_conversation_pool(conversation_id)
+ if pool:
+ return await pool.list_variables(include_system=False)
+ elif flow_id:
+ # 使用flow_id查询对话变量模板
+ flow_pool = await self.get_flow_pool(flow_id)
+ if flow_pool:
+ return await flow_pool.list_conversation_templates()
+ return []
+```
+
+#### `get_variable_from_any_pool`方法
+类似的修复,支持通过`flow_id`查询对话变量模板。
+
+### 2. **修复创建逻辑**
+
+#### 修改`create_variable`路由
+**修复前**:
+```python
+# 创建变量
+variable = await pool.add_variable(...)
+```
+
+**修复后**:
+```python
+# 根据作用域创建不同类型的变量
+if request.scope == VariableScope.CONVERSATION:
+ # 创建对话变量模板
+ variable = await pool.add_conversation_template(...)
+else:
+ # 创建其他类型的变量
+ variable = await pool.add_variable(...)
+```
+
+### 3. **增强FlowVariablePool功能**
+
+为FlowVariablePool添加了重写的方法,支持多字典操作:
+
+#### `update_variable`方法
+- 在环境变量、系统变量模板、对话变量模板中按顺序查找
+- 找到变量后执行更新操作
+- 正确持久化到数据库
+
+#### `delete_variable`方法
+- 支持删除存储在不同字典中的变量
+- 保留权限检查(系统变量模板不允许删除)
+
+#### `get_variable`方法
+- 统一的变量查找接口
+- 支持跨字典查找
+
+## ✅ **修复验证**
+
+### 现在支持的完整工作流程:
+
+#### 1. **Flow级别操作**(变量模板管理)
+```bash
+# 创建对话变量模板
+POST /api/variable/create
+{
+ "name": "test",
+ "var_type": "string",
+ "scope": "conversation",
+ "value": "123",
+ "description": "321",
+ "flow_id": "52e069c7-5556-42af-bdfc-63f4dc2dcd28"
+}
+
+# 查询对话变量模板
+GET /api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+
+# 更新对话变量模板
+PUT /api/variable/update?name=test&scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+
+# 删除对话变量模板
+DELETE /api/variable/delete?name=test&scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28
+```
+
+#### 2. **Conversation级别操作**(变量实例管理)
+```bash
+# 查询对话变量实例
+GET /api/variable/list?scope=conversation&conversation_id=conv123
+
+# 更新对话变量实例值
+PUT /api/variable/update?name=test&scope=conversation&conversation_id=conv123
+```
+
+## 🎯 **测试建议**
+
+### 立即测试
+现在可以重新测试原来失败的API调用:
+```bash
+curl "http://10.211.55.10:8002/api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28"
+```
+
+### 完整测试流程
+1. **创建对话变量模板**(前端已测试成功)
+2. **查询对话变量模板**(现在应该能查到)
+3. **更新对话变量模板**
+4. **删除对话变量模板**
+
+### 自动化测试
+运行测试脚本验证:
+```bash
+cd euler-copilot-framework
+python test_conversation_variables.py
+```
+
+## 📊 **架构完整性验证**
+
+现在所有变量类型的查询都应该正常工作:
+
+### Flow级别查询
+- ✅ 系统变量模板:`GET /api/variable/list?scope=system&flow_id=xxx`
+- ✅ 对话变量模板:`GET /api/variable/list?scope=conversation&flow_id=xxx`
+- ✅ 环境变量:`GET /api/variable/list?scope=environment&flow_id=xxx`
+
+### Conversation级别查询
+- ✅ 系统变量实例:`GET /api/variable/list?scope=system&conversation_id=xxx`
+- ✅ 对话变量实例:`GET /api/variable/list?scope=conversation&conversation_id=xxx`
+
+### User级别查询
+- ✅ 用户变量:`GET /api/variable/list?scope=user`
+
+## 🎉 **预期结果**
+
+修复后,你的前端应该能够:
+
+1. **成功创建对话变量模板**(已验证)
+2. **成功查询对话变量模板**(修复的核心问题)
+3. **成功更新对话变量模板**
+4. **成功删除对话变量模板**
+
+所有操作都在Flow级别进行,符合你的设计需求:**Flow级别管理模板定义,Conversation级别操作实际数据**。
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 7180b6a9e4fc7316a27efb46a431676dbbb31780..a52aed1db099236b8a7612961892fa3db80264ed 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,10 @@
-FROM hub.oepkgs.net/neocopilot/framework_base:0.9.6-x86-test
+FROM hub.oepkgs.net/neocopilot/framework_base:0.9.6-arm
ENV PYTHONPATH=/app
ENV TIKTOKEN_CACHE_DIR=/app/assets/tiktoken
-COPY --chmod=550 ./ /app/
+COPY ./ /app/
+RUN chmod -R 550 /app/
RUN chmod 766 /root
-CMD ["uv", "run", "--no-sync", "--no-dev", "apps/main.py"]
+CMD ["uv", "run", "--no-dev", "apps/main.py"]
diff --git a/apps/common/redis_cache.py b/apps/common/redis_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..019d0be917a130360ac5a6e2e39459af7ad5b311
--- /dev/null
+++ b/apps/common/redis_cache.py
@@ -0,0 +1,271 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Redis缓存模块 - 用于前置节点变量预解析缓存"""
+
+import json
+import logging
+import asyncio
+from typing import List, Dict, Any, Optional
+from datetime import datetime, UTC
+
+import redis.asyncio as redis
+from apps.common.singleton import SingletonMeta
+
+logger = logging.getLogger(__name__)
+
+
+class RedisCache(metaclass=SingletonMeta):
+ """Redis缓存管理器"""
+
+ def __init__(self):
+ self._redis: Optional[redis.Redis] = None
+ self._connected = False
+
+ async def init(self, redis_config=None, redis_url: str = None):
+ """初始化Redis连接
+
+ Args:
+ redis_config: Redis配置对象(优先级更高)
+ redis_url: Redis连接URL(降级选项)
+ """
+ try:
+ if redis_config:
+ # 使用配置对象构建连接,添加连接池和超时参数
+ self._redis = redis.Redis(
+ host=redis_config.host,
+ port=redis_config.port,
+ password=redis_config.password if redis_config.password else None,
+ db=redis_config.database,
+ decode_responses=redis_config.decode_responses,
+ # 连接池配置
+ max_connections=redis_config.max_connections,
+ # 超时配置
+ socket_timeout=redis_config.socket_timeout,
+ socket_connect_timeout=redis_config.socket_connect_timeout,
+ socket_keepalive=True,
+ socket_keepalive_options={},
+ # 连接重试
+ retry_on_timeout=True,
+ retry_on_error=[ConnectionError, TimeoutError],
+ health_check_interval=redis_config.health_check_interval
+ )
+ logger.info(f"使用配置连接Redis: {redis_config.host}:{redis_config.port}, 数据库: {redis_config.database}")
+ elif redis_url:
+ # 降级使用URL连接
+ self._redis = redis.from_url(
+ redis_url,
+ decode_responses=True,
+ socket_timeout=5.0,
+ socket_connect_timeout=5.0,
+ max_connections=10
+ )
+ logger.info(f"使用URL连接Redis: {redis_url}")
+ else:
+ raise ValueError("必须提供redis_config或redis_url参数")
+
+ # 测试连接
+ logger.info("正在测试Redis连接...")
+ ping_result = await self._redis.ping()
+ logger.info(f"Redis ping结果: {ping_result}")
+
+ # 测试基本操作
+ test_key = "__redis_test__"
+ await self._redis.set(test_key, "test", ex=10)
+ test_value = await self._redis.get(test_key)
+ await self._redis.delete(test_key)
+ logger.info(f"Redis读写测试成功: {test_value}")
+
+ self._connected = True
+ logger.info("Redis连接初始化成功")
+ except ConnectionError as e:
+ logger.error(f"Redis连接错误: {e}")
+ self._connected = False
+ except TimeoutError as e:
+ logger.error(f"Redis连接超时: {e}")
+ self._connected = False
+ except Exception as e:
+ logger.error(f"Redis连接初始化失败: {e}")
+ logger.error(f"错误类型: {type(e).__name__}")
+ self._connected = False
+
+ def is_connected(self) -> bool:
+ """检查Redis连接状态"""
+ return self._connected and self._redis is not None
+
+ async def close(self):
+ """关闭Redis连接"""
+ if self._redis:
+ await self._redis.close()
+ self._connected = False
+
+
+class PredecessorVariableCache:
+ """前置节点变量预解析缓存管理器"""
+
+ def __init__(self, redis_cache: RedisCache):
+ self.redis = redis_cache
+ self.CACHE_PREFIX = "predecessor_vars"
+ self.PARSING_STATUS_PREFIX = "parsing_status"
+ self.CACHE_TTL = 3600 * 24 # 缓存24小时
+
+ def _get_cache_key(self, flow_id: str, step_id: str) -> str:
+ """生成缓存key"""
+ return f"{self.CACHE_PREFIX}:{flow_id}:{step_id}"
+
+ def _get_status_key(self, flow_id: str, step_id: str) -> str:
+ """生成解析状态key"""
+ return f"{self.PARSING_STATUS_PREFIX}:{flow_id}:{step_id}"
+
+ def _get_flow_hash_key(self, flow_id: str) -> str:
+ """生成Flow哈希key,用于存储Flow的拓扑结构哈希值"""
+ return f"flow_hash:{flow_id}"
+
+ async def get_cached_variables(self, flow_id: str, step_id: str) -> Optional[List[Dict[str, Any]]]:
+ """获取缓存的前置节点变量"""
+ if not self.redis.is_connected():
+ return None
+
+ try:
+ cache_key = self._get_cache_key(flow_id, step_id)
+ cached_data = await self.redis._redis.get(cache_key)
+
+ if cached_data:
+ data = json.loads(cached_data)
+ logger.info(f"从缓存获取前置节点变量: {flow_id}:{step_id}, 数量: {len(data.get('variables', []))}")
+ return data.get('variables', [])
+
+ except Exception as e:
+ logger.error(f"获取缓存的前置节点变量失败: {e}")
+
+ return None
+
+ async def set_cached_variables(self, flow_id: str, step_id: str, variables: List[Dict[str, Any]], flow_hash: str):
+ """设置缓存的前置节点变量"""
+ if not self.redis.is_connected():
+ return False
+
+ try:
+ cache_key = self._get_cache_key(flow_id, step_id)
+ cache_data = {
+ 'variables': variables,
+ 'flow_hash': flow_hash,
+ 'cached_at': datetime.now(UTC).isoformat(),
+ 'step_count': len(variables)
+ }
+
+ await self.redis._redis.setex(
+ cache_key,
+ self.CACHE_TTL,
+ json.dumps(cache_data, default=str)
+ )
+
+ logger.info(f"缓存前置节点变量成功: {flow_id}:{step_id}, 数量: {len(variables)}")
+ return True
+
+ except Exception as e:
+ logger.error(f"缓存前置节点变量失败: {e}")
+ return False
+
+ async def is_parsing_in_progress(self, flow_id: str, step_id: str) -> bool:
+ """检查是否正在解析中"""
+ if not self.redis.is_connected():
+ return False
+
+ try:
+ status_key = self._get_status_key(flow_id, step_id)
+ status = await self.redis._redis.get(status_key)
+ return status == "parsing"
+ except Exception as e:
+ logger.error(f"检查解析状态失败: {e}")
+ return False
+
+ async def set_parsing_status(self, flow_id: str, step_id: str, status: str, ttl: int = 300):
+ """设置解析状态 (parsing, completed, failed)"""
+ if not self.redis.is_connected():
+ return False
+
+ try:
+ status_key = self._get_status_key(flow_id, step_id)
+ await self.redis._redis.setex(status_key, ttl, status)
+ return True
+ except Exception as e:
+ logger.error(f"设置解析状态失败: {e}")
+ return False
+
+ async def wait_for_parsing_completion(self, flow_id: str, step_id: str, max_wait_time: int = 30) -> bool:
+ """等待解析完成"""
+ if not self.redis.is_connected():
+ return False
+
+ start_time = datetime.now(UTC)
+
+ while (datetime.now(UTC) - start_time).total_seconds() < max_wait_time:
+ if not await self.is_parsing_in_progress(flow_id, step_id):
+ # 检查是否有缓存结果
+ cached_vars = await self.get_cached_variables(flow_id, step_id)
+ return cached_vars is not None
+
+ await asyncio.sleep(0.5) # 等待500ms后重试
+
+ logger.warning(f"等待解析完成超时: {flow_id}:{step_id}")
+ return False
+
+ async def invalidate_flow_cache(self, flow_id: str):
+ """使某个Flow的所有缓存失效"""
+ if not self.redis.is_connected():
+ return
+
+ try:
+ # 查找所有相关的缓存key
+ pattern = f"{self.CACHE_PREFIX}:{flow_id}:*"
+ keys = await self.redis._redis.keys(pattern)
+
+ # 同时删除解析状态key
+ status_pattern = f"{self.PARSING_STATUS_PREFIX}:{flow_id}:*"
+ status_keys = await self.redis._redis.keys(status_pattern)
+
+ all_keys = keys + status_keys
+
+ if all_keys:
+ await self.redis._redis.delete(*all_keys)
+ logger.info(f"清除Flow缓存: {flow_id}, 删除key数量: {len(all_keys)}")
+
+ except Exception as e:
+ logger.error(f"清除Flow缓存失败: {e}")
+
+ async def get_flow_hash(self, flow_id: str) -> Optional[str]:
+ """获取Flow的拓扑结构哈希值"""
+ if not self.redis.is_connected():
+ return None
+
+ try:
+ hash_key = self._get_flow_hash_key(flow_id)
+ return await self.redis._redis.get(hash_key)
+ except Exception as e:
+ logger.error(f"获取Flow哈希失败: {e}")
+ return None
+
+ async def set_flow_hash(self, flow_id: str, flow_hash: str):
+ """设置Flow的拓扑结构哈希值"""
+ if not self.redis.is_connected():
+ return False
+
+ try:
+ # 检查事件循环是否仍然活跃
+ import asyncio
+ try:
+ asyncio.get_running_loop()
+ except RuntimeError:
+ logger.warning(f"事件循环已关闭,跳过设置Flow哈希: {flow_id}")
+ return False
+
+ hash_key = self._get_flow_hash_key(flow_id)
+ await self.redis._redis.setex(hash_key, self.CACHE_TTL, flow_hash)
+ return True
+ except Exception as e:
+ logger.error(f"设置Flow哈希失败: {e}")
+ return False
+
+
+# 全局实例
+redis_cache = RedisCache()
+predecessor_cache = PredecessorVariableCache(redis_cache)
\ No newline at end of file
diff --git a/apps/dependency/user.py b/apps/dependency/user.py
index fce67e51e00dd76b35bec2f7787dfe8039111589..87cbd290283507aa911e2c0d188bea2099bd150f 100644
--- a/apps/dependency/user.py
+++ b/apps/dependency/user.py
@@ -5,10 +5,12 @@ import logging
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
+import secrets
from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection
+from apps.common.config import Config
from apps.services.api_key import ApiKeyManager
from apps.services.session import SessionManager
@@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str:
:param request: HTTP请求
:return: Session ID
"""
+ if Config().get_config().no_auth.enable:
+ # 如果启用了无认证访问,直接返回调试用户
+ return secrets.token_hex(16)
session_id = await _get_session_id_from_request(request)
if not session_id:
raise HTTPException(
@@ -69,6 +74,9 @@ async def get_user(request: HTTPConnection) -> str:
:param request: HTTP请求体
:return: 用户sub
"""
+ if Config().get_config().no_auth.enable:
+ # 如果启用了无认证访问,直接返回调试用户
+ return Config().get_config().no_auth.user_sub
session_id = await _get_session_id_from_request(request)
if not session_id:
raise HTTPException(
diff --git a/apps/llm/function.py b/apps/llm/function.py
index 1f995fe7ba187cead03aa6fc62a4cbce1ec05a65..74cc9a1723395e77d799e38a93e161328bd1e596 100644
--- a/apps/llm/function.py
+++ b/apps/llm/function.py
@@ -119,7 +119,7 @@ class FunctionLLM:
"name": "generate",
"description": "Generate answer based on the background information",
"parameters": schema,
- },
+ }
},
]
diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py
index fdb36fc05adf38920bcce0d962b6aafc21e44b71..453267f833b2744a65bf3ecea5f00727abb11c35 100644
--- a/apps/llm/reasoning.py
+++ b/apps/llm/reasoning.py
@@ -134,6 +134,7 @@ class ReasoningLLM:
max_tokens: int | None,
temperature: float | None,
model: str | None = None,
+ frequency_penalty: float | None = None
) -> AsyncGenerator[ChatCompletionChunk, None]:
"""创建流式响应"""
if model is None:
@@ -143,6 +144,7 @@ class ReasoningLLM:
messages=messages, # type: ignore[]
max_tokens=max_tokens or self._config.max_tokens,
temperature=temperature or self._config.temperature,
+ frequency_penalty=frequency_penalty or self._config.frequency_penalty,
stream=True,
stream_options={"include_usage": True},
) # type: ignore[]
@@ -156,6 +158,7 @@ class ReasoningLLM:
streaming: bool = True,
result_only: bool = True,
model: str | None = None,
+ frequency_penalty: float | None = 0
) -> AsyncGenerator[str, None]:
"""调用大模型,分为流式和非流式两种"""
# 检查max_tokens和temperature
@@ -166,7 +169,7 @@ class ReasoningLLM:
if model is None:
model = self._config.model
msg_list = self._validate_messages(messages)
- stream = await self._create_stream(msg_list, max_tokens, temperature, model)
+ stream = await self._create_stream(msg_list, max_tokens, temperature, model, frequency_penalty)
reasoning = ReasoningContent()
reasoning_content = ""
result = ""
@@ -202,7 +205,7 @@ class ReasoningLLM:
yield reasoning_content
yield result
- logger.info("[Reasoning] 推理内容: %s\n\n%s", reasoning_content, result)
+ logger.info("[Reasoning] 推理内容: %r\n\n%s", reasoning_content, result)
# 更新token统计
if self.input_tokens == 0 or self.output_tokens == 0:
diff --git a/apps/main.py b/apps/main.py
index c4ca2bfb116db11624f4e4fbbe4e876a86e5f1eb..50447ea9902ee4225cb5c78bf9ccb7aa89c4f6c4 100644
--- a/apps/main.py
+++ b/apps/main.py
@@ -8,6 +8,10 @@ from __future__ import annotations
import asyncio
import logging
+import logging.config
+import signal
+import sys
+from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI
@@ -36,11 +40,62 @@ from apps.routers import (
record,
service,
user,
+ parameter,
+ variable
)
from apps.scheduler.pool.pool import Pool
+from apps.services.predecessor_cache_service import cleanup_background_tasks
+
+# 全局变量用于跟踪后台任务
+_cleanup_task = None
+async def cleanup_on_shutdown():
+ """应用关闭时的清理函数"""
+ logger = logging.getLogger(__name__)
+ logger.info("开始清理应用资源...")
+
+ try:
+ # 取消定期清理任务
+ global _cleanup_task
+ if _cleanup_task and not _cleanup_task.done():
+ _cleanup_task.cancel()
+ try:
+ await _cleanup_task
+ except asyncio.CancelledError:
+ logger.info("定期清理任务已取消")
+
+ # 清理后台任务
+ await cleanup_background_tasks()
+
+ # 关闭Redis连接
+ from apps.common.redis_cache import RedisCache
+ redis_cache = RedisCache()
+ if redis_cache.is_connected():
+ await redis_cache.close()
+ logger.info("Redis连接已关闭")
+
+ except Exception as e:
+ logger.error(f"清理应用资源时出错: {e}")
+
+ logger.info("应用资源清理完成")
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ """应用生命周期管理"""
+ # 启动时的初始化
+ await init_resources()
+
+ yield
+
+ # 关闭时的清理
+ await cleanup_on_shutdown()
# 定义FastAPI app
-app = FastAPI(redoc_url=None)
+app = FastAPI(
+ title="Euler Copilot Framework",
+ description="AI-powered automation framework",
+ version="1.0.0",
+ lifespan=lifespan,
+)
# 定义FastAPI全局中间件
app.add_middleware(
CORSMiddleware,
@@ -66,6 +121,8 @@ app.include_router(llm.router)
app.include_router(mcp_service.router)
app.include_router(flow.router)
app.include_router(user.router)
+app.include_router(parameter.router)
+app.include_router(variable.router)
# logger配置
LOGGER_FORMAT = "%(funcName)s() - %(message)s"
@@ -88,10 +145,48 @@ async def init_resources() -> None:
await Pool.init()
TokenCalculator()
+ # 初始化变量池管理器
+ from apps.scheduler.variable.pool_manager import initialize_pool_manager
+ await initialize_pool_manager()
+
+ # 初始化前置节点变量缓存服务
+ try:
+ from apps.services.predecessor_cache_service import PredecessorCacheService, periodic_cleanup_background_tasks
+ await PredecessorCacheService.initialize_redis()
+
+ # 启动定期清理任务
+ global _cleanup_task
+ _cleanup_task = asyncio.create_task(start_periodic_cleanup())
+
+ logging.info("前置节点变量缓存服务初始化成功")
+ except Exception as e:
+ logging.warning(f"前置节点变量缓存服务初始化失败(将降级使用实时解析): {e}")
+
+async def start_periodic_cleanup():
+ """启动定期清理任务"""
+ try:
+ from apps.services.predecessor_cache_service import periodic_cleanup_background_tasks
+ while True:
+ # 每60秒清理一次已完成的后台任务
+ await asyncio.sleep(60)
+ await periodic_cleanup_background_tasks()
+ except asyncio.CancelledError:
+ logging.info("定期清理任务已取消")
+ raise # 重新抛出CancelledError
+ except Exception as e:
+ logging.error(f"定期清理任务异常: {e}")
+
# 运行
if __name__ == "__main__":
- # 初始化必要资源
- asyncio.run(init_resources())
+ def signal_handler(signum, frame):
+ """信号处理器"""
+ logger = logging.getLogger(__name__)
+ logger.info(f"收到信号 {signum},准备关闭应用...")
+ sys.exit(0)
+
+ # 注册信号处理器
+ signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, signal_handler)
# 启动FastAPI
uvicorn.run(app, host="0.0.0.0", port=8002, log_level="info", log_config=None)
diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py
index 158cfc13a5be17f16e2d73ec5fa44c4a18e4f392..51366a2192eaa429a3f1e1d40672b9e915bec424 100644
--- a/apps/routers/api_key.py
+++ b/apps/routers/api_key.py
@@ -6,6 +6,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
+
from apps.dependency.user import get_user, verify_user
from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp
from apps.schemas.response_data import ResponseData
diff --git a/apps/routers/auth.py b/apps/routers/auth.py
index 1cba5ed629d90776b59ae3bf652381318dcabf0e..6c56aaa25ba1c22c75b45f823e6f7cd052a82341 100644
--- a/apps/routers/auth.py
+++ b/apps/routers/auth.py
@@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
+from apps.common.config import Config
from apps.common.oidc import oidc_provider
from apps.dependency import get_session, get_user, verify_user
from apps.schemas.collection import Audit
@@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
user_info = await oidc_provider.get_oidc_user(token["access_token"])
user_sub: str | None = user_info.get("user_sub", None)
- if user_sub:
- await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"])
except Exception as e:
logger.exception("User login failed")
status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN
@@ -149,7 +148,7 @@ async def oidc_redirect() -> JSONResponse:
# TODO(zwt): OIDC主动触发logout
@router.post("/logout", dependencies=[Depends(verify_user)], response_model=ResponseData)
-async def oidc_logout(token: str) -> JSONResponse:
+async def oidc_logout(token: str) -> JSONResponse | None:
"""OIDC主动触发登出"""
@@ -191,6 +190,59 @@ async def userinfo(
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": ResponseData},
},
)
+
+# 增加获取user session,开放平台api调用接口
+@router.post("/create-session", response_model=ResponseData)
+async def create_session(
+ request: Request,
+ user_sub: str
+) -> JSONResponse:
+ """通过user_sub直接创建session"""
+ if not request.client:
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content=ResponseData(
+ code=status.HTTP_400_BAD_REQUEST,
+ message="客户端IP地址缺失",
+ result={},
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+
+ try:
+ session_id = await SessionManager.create_session(
+ ip=request.client.host,
+ user_sub=user_sub
+ )
+
+ # 记录审计日志
+ data = Audit(
+ http_method="post",
+ module="auth",
+ client_ip=request.client.host,
+ user_sub=user_sub,
+ message="/api/auth/create-session: Session创建成功",
+ )
+ await AuditLogManager.add_audit_log(data)
+
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content=ResponseData(
+ code=status.HTTP_200_OK,
+ message="success",
+ result={"session_id": session_id},
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ except Exception as e:
+ logger.exception("Session创建失败")
+ return JSONResponse(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ content=ResponseData(
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ message=str(e),
+ result={},
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+
async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001
"""更新用户协议信息"""
ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True)
@@ -227,3 +279,9 @@ async def update_revision_number(request: Request, user_sub: Annotated[str, Depe
),
).model_dump(exclude_none=True, by_alias=True),
)
+
+# user_info = await oidc_provider.get_oidc_user(token["access_token"])
+
+# user_sub: str | None = user_info.get("user_sub", None)
+# if user_sub:
+
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index 7fe5162c0f7db804c6d723207ffa327d9394e3eb..b16b41bb0370b06911e3b8127dc1df86b1336391 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -6,6 +6,7 @@ import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Annotated
+import json
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import JSONResponse, StreamingResponse
@@ -36,11 +37,16 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T
# 生成group_id
if not post_body.group_id:
post_body.group_id = str(uuid.uuid4())
- # 创建或还原Task
+ if post_body.new_task:
+ # 创建或还原Task
+ task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub)
+ if task:
+ await TaskManager.delete_task_by_task_id(task.id)
task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub)
# 更改信息并刷新数据库
- task.runtime.question = post_body.question
- task.ids.group_id = post_body.group_id
+ if post_body.new_task:
+ task.runtime.question = post_body.question
+ task.ids.group_id = post_body.group_id
return task
@@ -136,6 +142,45 @@ async def chat(
},
)
+@router.post("/chat_without_streaming")
+async def chat_without_streaming(
+ post_body: RequestData,
+ user_sub: Annotated[str, Depends(get_user)],
+ session_id: Annotated[str, Depends(get_session)],
+) -> JSONResponse:
+ """非流式对话接口"""
+ # 问题黑名单检测
+ if not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question):
+ # 用户扣分
+ await UserBlacklistManager.change_blacklisted_users(user_sub, -10)
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted")
+
+ # 限流检查
+ if await Activity.is_active(user_sub):
+ raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests")
+
+ content_texts = []
+ async for chunk in chat_generator(post_body, user_sub, session_id):
+ if chunk.startswith('data: {"event": "text.add", '):
+ content_texts.append(chunk)
+
+ final_answer = ''.join([
+ json.loads(chunk.split('data: ')[1].split('\n\n')[0])['content']['text']
+ for chunk in content_texts
+ if chunk.strip()
+ ])
+
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={
+ "code": 200,
+ "message": "success",
+ "result": {
+ "content": final_answer
+ }
+ },
+ )
+
@router.post("/stop", response_model=ResponseData)
async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201
diff --git a/apps/routers/flow.py b/apps/routers/flow.py
index fc38c1bfb9879c1d846d700e91b812c83866c2c4..646213a88ac74c296b159678d3f47b9c0746219f 100644
--- a/apps/routers/flow.py
+++ b/apps/routers/flow.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""FastAPI Flow拓扑结构展示API"""
+import logging
from typing import Annotated
from fastapi import APIRouter, Body, Depends, Query, status
@@ -25,6 +26,8 @@ from apps.services.application import AppManager
from apps.services.flow import FlowManager
from apps.services.flow_validate import FlowService
+logger = logging.getLogger(__name__)
+
router = APIRouter(
prefix="/api/flow",
tags=["flow"],
@@ -153,6 +156,20 @@ async def put_flow(
result=FlowStructurePutMsg(),
).model_dump(exclude_none=True, by_alias=True),
)
+
+ # 触发前置节点变量预解析(异步执行,不阻塞响应)
+ try:
+ from apps.services.predecessor_cache_service import PredecessorCacheService
+ import asyncio
+
+ # 在后台异步触发预解析
+ asyncio.create_task(
+ PredecessorCacheService.trigger_flow_parsing(flow_id, force_refresh=True)
+ )
+ logger.info(f"已触发Flow前置节点变量预解析: {flow_id}")
+ except Exception as trigger_error:
+ logger.warning(f"触发Flow前置节点变量预解析失败: {flow_id}, 错误: {trigger_error}")
+
return JSONResponse(
status_code=status.HTTP_200_OK,
content=FlowStructurePutRsp(
diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py
new file mode 100644
index 0000000000000000000000000000000000000000..6edbe2e142cb6589be8947c28ae2eb4a7287baa1
--- /dev/null
+++ b/apps/routers/parameter.py
@@ -0,0 +1,77 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Query, status
+from fastapi.responses import JSONResponse
+
+from apps.dependency import get_user
+from apps.dependency.user import verify_user
+from apps.services.parameter import ParameterManager
+from apps.schemas.response_data import (
+ GetOperaRsp,
+ GetParamsRsp
+)
+from apps.services.application import AppManager
+from apps.services.flow import FlowManager
+
+router = APIRouter(
+ prefix="/api/parameter",
+ tags=["parameter"],
+ dependencies=[
+ Depends(verify_user),
+ ],
+)
+
+
+@router.get("", response_model=GetParamsRsp)
+async def get_parameters(
+ user_sub: Annotated[str, Depends(get_user)],
+ app_id: Annotated[str, Query(alias="appId")],
+ flow_id: Annotated[str, Query(alias="flowId")],
+ step_id: Annotated[str, Query(alias="stepId")],
+) -> JSONResponse:
+ """Get parameters for node choice."""
+ if not await AppManager.validate_user_app_access(user_sub, app_id):
+ return JSONResponse(
+ status_code=status.HTTP_403_FORBIDDEN,
+ content=GetParamsRsp(
+ code=status.HTTP_403_FORBIDDEN,
+ message="用户没有权限访问该流",
+ result=[],
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id)
+ if not flow:
+ return JSONResponse(
+ status_code=status.HTTP_404_NOT_FOUND,
+ content=GetParamsRsp(
+ code=status.HTTP_404_NOT_FOUND,
+ message="未找到该流",
+ result=[],
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id)
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content=GetParamsRsp(
+ code=status.HTTP_200_OK,
+ message="获取参数成功",
+ result=result
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+
+
+@router.get("/operate", response_model=GetOperaRsp)
+async def get_operate_parameters(
+ user_sub: Annotated[str, Depends(get_user)],
+ param_type: Annotated[str, Query(alias="ParamType")],
+) -> JSONResponse:
+ """Get parameters for node choice."""
+ result = await ParameterManager.get_operate_and_bind_type(param_type)
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content=GetOperaRsp(
+ code=status.HTTP_200_OK,
+ message="获取操作成功",
+ result=result
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
diff --git a/apps/routers/variable.py b/apps/routers/variable.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae53a94828a9e7e629ec6b72b11a28844c91f51
--- /dev/null
+++ b/apps/routers/variable.py
@@ -0,0 +1,983 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""FastAPI 变量管理 API"""
+
+import logging
+from typing import Annotated, List, Optional, Dict
+
+from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field
+
+from apps.dependency import get_user
+from apps.dependency.user import verify_user
+from apps.scheduler.variable.pool_manager import get_pool_manager
+from apps.scheduler.variable.type import VariableType, VariableScope
+from apps.scheduler.variable.parser import VariableParser
+from apps.schemas.response_data import ResponseData
+from apps.services.flow import FlowManager
+
+logger = logging.getLogger(__name__)
+
+# Router for the variable-management API; every route requires a verified user.
+router = APIRouter(
+    prefix="/api/variable",
+    tags=["variable"],
+    dependencies=[
+        Depends(verify_user),
+    ],
+)
+
+
+async def _get_predecessor_node_variables(
+    user_sub: str,
+    flow_id: str,
+    conversation_id: Optional[str],
+    current_step_id: str
+) -> List:
+    """Collect output variables of nodes preceding *current_step_id*.
+
+    Running stage (conversation_id given): reads actual node outputs from
+    the conversation pool. Configuration stage: reads the flow pool, then
+    the predecessor cache, and finally falls back to live topology parsing.
+
+    Args:
+        user_sub: User ID.
+        flow_id: Flow ID.
+        conversation_id: Conversation ID (may be None at configuration stage).
+        current_step_id: Current step ID.
+
+    Returns:
+        List: Output variables of the predecessor nodes (empty on failure).
+    """
+    try:
+        variables = []
+        pool_manager = await get_pool_manager()
+
+        if conversation_id:
+            # Running stage: fetch actual predecessor variables from the
+            # conversation pool.
+            conversation_pool = await pool_manager.get_conversation_pool(conversation_id)
+            if conversation_pool:
+                # All conversation-scope variables.
+                all_conversation_vars = await conversation_pool.list_variables()
+
+                # Keep node-output variables (named "node_id.key").
+                for var in all_conversation_vars:
+                    var_name = var.name
+                    # Node outputs contain "." and are not system variables.
+                    if "." in var_name and not var_name.startswith("system."):
+                        # Extract the node-ID prefix.
+                        node_id = var_name.split(".")[0]
+
+                        # Any other node counts as a predecessor; a stricter
+                        # topology check could be added here.
+                        if node_id != current_step_id:  # not a variable of the current node
+                            variables.append(var)
+        else:
+            # Configuration stage: prefer the cache, degrade to live parsing.
+            try:
+                # Try the optimized predecessor cache service first.
+                from apps.services.predecessor_cache_service import PredecessorCacheService
+
+                # 1. Look for existing predecessor variables in the flow pool.
+                flow_pool = await pool_manager.get_flow_pool(flow_id)
+                if flow_pool:
+                    flow_conversation_vars = await flow_pool.list_variables()
+
+                    # Keep node-output variables (named "node_id.key").
+                    for var in flow_conversation_vars:
+                        var_name = var.name
+                        if "." in var_name and not var_name.startswith("system."):
+                            node_id = var_name.split(".")[0]
+                            if node_id != current_step_id:
+                                variables.append(var)
+
+                # 2. Fetch predecessor variables via the optimized cache.
+                cached_var_data = await PredecessorCacheService.get_predecessor_variables_optimized(
+                    flow_id, current_step_id, user_sub, max_wait_time=5
+                )
+
+                # Convert the cached dicts back into Variable objects.
+                for var_data in cached_var_data:
+                    try:
+                        from apps.scheduler.variable.variables import create_variable
+                        from apps.scheduler.variable.base import VariableMetadata
+                        from apps.scheduler.variable.type import VariableType, VariableScope
+                        from datetime import datetime
+
+                        # Rebuild the variable metadata from the cached fields.
+                        metadata = VariableMetadata(
+                            name=var_data['name'],
+                            var_type=VariableType(var_data['var_type']),
+                            scope=VariableScope(var_data['scope']),
+                            description=var_data.get('description', ''),
+                            created_by=user_sub,
+                            created_at=datetime.fromisoformat(var_data['created_at'].replace('Z', '+00:00')),
+                            updated_at=datetime.fromisoformat(var_data['updated_at'].replace('Z', '+00:00'))
+                        )
+
+                        # Create the variable and attach the raw cache entry.
+                        variable = create_variable(metadata, var_data.get('value', ''))
+
+                        # Attach node info for later response formatting.
+                        if hasattr(variable, '_cache_data'):
+                            variable._cache_data = var_data
+                        else:
+                            # Object may not support dynamic attributes; set it
+                            # explicitly (or handle at response time).
+                            setattr(variable, '_cache_data', var_data)
+
+                        variables.append(variable)
+
+                    except Exception as var_create_error:
+                        # Best-effort: skip entries that cannot be rebuilt.
+                        logger.warning(f"创建缓存变量对象失败: {var_create_error}")
+                        continue
+
+                logger.info(f"配置阶段:为节点 {current_step_id} 找到前置节点变量总数: {len([v for v in variables if hasattr(v, 'name') and '.' in v.name and not v.name.startswith('system.')])}")
+
+            except Exception as flow_error:
+                logger.warning(f"配置阶段获取前置节点变量失败,降级到实时解析: {flow_error}")
+                # Fall back to the original live topology parsing.
+                predecessor_vars = await _get_predecessor_variables_from_topology(
+                    flow_id, current_step_id, user_sub
+                )
+                variables.extend(predecessor_vars)
+
+        return variables
+
+    except Exception as e:
+        # Deliberate best-effort: callers treat a failure as "no variables".
+        logger.error(f"获取前置节点变量失败: {e}")
+        return []
+
+
+
+
+
+# Request and response models
+class CreateVariableRequest(BaseModel):
+    """Request body for creating a variable."""
+    name: str = Field(description="变量名称")
+    var_type: VariableType = Field(description="变量类型")
+    scope: VariableScope = Field(description="变量作用域")
+    value: Optional[str] = Field(default=None, description="变量值")
+    description: Optional[str] = Field(default=None, description="变量描述")
+    flow_id: Optional[str] = Field(default=None, description="流程ID(环境级和对话级变量必需)")
+
+
+class UpdateVariableRequest(BaseModel):
+    """Request body for updating an existing variable."""
+    value: Optional[str] = Field(default=None, description="新的变量值")
+    var_type: Optional[VariableType] = Field(default=None, description="新的变量类型")
+    description: Optional[str] = Field(default=None, description="新的变量描述")
+
+
+class VariableResponse(BaseModel):
+    """Serialized view of a single variable; step/step_id are only set for
+    predecessor-node output variables."""
+    name: str = Field(description="变量名称")
+    var_type: str = Field(description="变量类型")
+    scope: str = Field(description="变量作用域")
+    value: str = Field(description="变量值")
+    description: Optional[str] = Field(description="变量描述")
+    created_at: str = Field(description="创建时间")
+    updated_at: str = Field(description="更新时间")
+    step: Optional[str] = Field(default=None, description="节点名称(前置节点变量专用)")
+    step_id: Optional[str] = Field(default=None, description="节点ID(前置节点变量专用)")
+
+
+class VariableListResponse(BaseModel):
+    """A list of variables plus the total count."""
+    variables: List[VariableResponse] = Field(description="变量列表")
+    total: int = Field(description="总数量")
+
+
+class ParseTemplateRequest(BaseModel):
+    """Request body for parsing/validating a template with variable references."""
+    template: str = Field(description="包含变量引用的模板")
+    flow_id: Optional[str] = Field(default=None, description="流程ID")
+
+
+class ParseTemplateResponse(BaseModel):
+    """Result of parsing a template: resolved text plus the references used."""
+    parsed_template: str = Field(description="解析后的模板")
+    variables_used: List[str] = Field(description="使用的变量引用")
+
+
+class ValidateTemplateResponse(BaseModel):
+    """Result of validating a template's variable references."""
+    is_valid: bool = Field(description="是否有效")
+    invalid_references: List[str] = Field(description="无效的变量引用")
+
+
+@router.post(
+ "/create",
+ responses={
+ status.HTTP_200_OK: {"model": ResponseData},
+ status.HTTP_400_BAD_REQUEST: {"model": ResponseData},
+ status.HTTP_403_FORBIDDEN: {"model": ResponseData},
+ },
+)
+async def create_variable(
+ user_sub: Annotated[str, Depends(get_user)],
+ request: CreateVariableRequest = Body(...),
+) -> ResponseData:
+ """创建变量"""
+ try:
+ # 验证作用域权限
+ if request.scope == VariableScope.SYSTEM:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="不允许创建系统级变量"
+ )
+
+ pool_manager = await get_pool_manager()
+
+ # 根据作用域获取合适的变量池
+ if request.scope == VariableScope.USER:
+ # 用户级变量需要user_sub参数
+ if not user_sub:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="用户级变量需要用户身份"
+ )
+ pool = await pool_manager.get_user_pool(user_sub)
+ elif request.scope == VariableScope.ENVIRONMENT:
+ # 环境级变量需要flow_id参数
+ if not request.flow_id:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="环境级变量需要flow_id参数"
+ )
+ pool = await pool_manager.get_flow_pool(request.flow_id)
+ elif request.scope == VariableScope.CONVERSATION:
+ # 对话级变量需要flow_id参数,用于创建模板
+ if not request.flow_id:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="对话级变量需要flow_id参数"
+ )
+ # 对话级变量模板在流程池中定义
+ pool = await pool_manager.get_flow_pool(request.flow_id)
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"不支持的变量作用域: {request.scope.value}"
+ )
+
+ if not pool:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="无法获取变量池"
+ )
+
+ # 根据作用域创建不同类型的变量
+ if request.scope == VariableScope.CONVERSATION:
+ # 创建对话变量模板
+ variable = await pool.add_conversation_template(
+ name=request.name,
+ var_type=request.var_type,
+ default_value=request.value,
+ description=request.description,
+ created_by=user_sub
+ )
+ else:
+ # 创建其他类型的变量
+ variable = await pool.add_variable(
+ name=request.name,
+ var_type=request.var_type,
+ value=request.value,
+ description=request.description,
+ created_by=user_sub
+ )
+
+
+ return ResponseData(
+ code=200,
+ message="变量创建成功",
+ result={"variable_name": variable.name},
+ )
+
+ except ValueError as e:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=str(e)
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"创建变量失败: {str(e)}"
+ )
+
+
+@router.put(
+ "/update",
+ responses={
+ status.HTTP_200_OK: {"model": ResponseData},
+ status.HTTP_400_BAD_REQUEST: {"model": ResponseData},
+ status.HTTP_403_FORBIDDEN: {"model": ResponseData},
+ status.HTTP_404_NOT_FOUND: {"model": ResponseData},
+ },
+)
+async def update_variable(
+ user_sub: Annotated[str, Depends(get_user)],
+ name: str = Query(..., description="变量名称"),
+ scope: VariableScope = Query(..., description="变量作用域"),
+ flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"),
+ conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"),
+ request: UpdateVariableRequest = Body(...),
+) -> ResponseData:
+ """更新变量值"""
+ try:
+ pool_manager = await get_pool_manager()
+
+ # 根据作用域获取合适的变量池
+ if scope == VariableScope.USER:
+ if not user_sub:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="用户级变量需要用户身份"
+ )
+ pool = await pool_manager.get_user_pool(user_sub)
+ elif scope == VariableScope.ENVIRONMENT:
+ if not flow_id:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="环境级变量需要flow_id参数"
+ )
+ pool = await pool_manager.get_flow_pool(flow_id)
+ elif scope == VariableScope.CONVERSATION:
+ if conversation_id:
+ # 运行时:使用对话池,如果不存在则创建
+ if not flow_id:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="对话级变量运行时需要conversation_id和flow_id参数"
+ )
+ pool = await pool_manager.get_conversation_pool(conversation_id)
+ if not pool:
+ # 对话池不存在,自动创建
+ pool = await pool_manager.create_conversation_pool(conversation_id, flow_id)
+ elif flow_id:
+ # 配置时:使用流程池
+ pool = await pool_manager.get_flow_pool(flow_id)
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="对话级变量需要conversation_id(运行时)或flow_id(配置时)参数"
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"不支持的变量作用域: {scope.value}"
+ )
+
+ if not pool:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="无法获取变量池"
+ )
+
+ # 更新变量
+ variable = await pool.update_variable(
+ name=name,
+ value=request.value,
+ var_type=request.var_type,
+ description=request.description
+ )
+
+ return ResponseData(
+ code=200,
+ message="变量更新成功",
+ result={"variable_name": variable.name}
+ )
+
+ except ValueError as e:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=str(e)
+ )
+ except PermissionError as e:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=str(e)
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"更新变量失败: {str(e)}"
+ )
+
+
+@router.delete(
+ "/delete",
+ responses={
+ status.HTTP_200_OK: {"model": ResponseData},
+ status.HTTP_403_FORBIDDEN: {"model": ResponseData},
+ status.HTTP_404_NOT_FOUND: {"model": ResponseData},
+ },
+)
+async def delete_variable(
+ user_sub: Annotated[str, Depends(get_user)],
+ name: str = Query(..., description="变量名称"),
+ scope: VariableScope = Query(..., description="变量作用域"),
+ flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"),
+ conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"),
+) -> ResponseData:
+ """删除变量"""
+ try:
+ pool_manager = await get_pool_manager()
+
+ # 根据作用域获取合适的变量池
+ if scope == VariableScope.USER:
+ if not user_sub:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="用户级变量需要用户身份"
+ )
+ pool = await pool_manager.get_user_pool(user_sub)
+ elif scope == VariableScope.ENVIRONMENT:
+ if not flow_id:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="环境级变量需要flow_id参数"
+ )
+ pool = await pool_manager.get_flow_pool(flow_id)
+ elif scope == VariableScope.CONVERSATION:
+ if conversation_id:
+ # 运行时:使用对话池,如果不存在则创建
+ if not flow_id:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="对话级变量运行时需要conversation_id和flow_id参数"
+ )
+ pool = await pool_manager.get_conversation_pool(conversation_id)
+ if not pool:
+ # 对话池不存在,自动创建
+ pool = await pool_manager.create_conversation_pool(conversation_id, flow_id)
+ elif flow_id:
+ # 配置时:使用流程池
+ pool = await pool_manager.get_flow_pool(flow_id)
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="对话级变量需要conversation_id(运行时)或flow_id(配置时)参数"
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"不支持的变量作用域: {scope.value}"
+ )
+
+ if not pool:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="无法获取变量池"
+ )
+
+ # 删除变量
+ success = await pool.delete_variable(name)
+
+ if not success:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="变量不存在"
+ )
+
+ return ResponseData(
+ code=200,
+ message="变量删除成功",
+ result={"variable_name": name}
+ )
+
+ except ValueError as e:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=str(e)
+ )
+ except PermissionError as e:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=str(e)
+ )
+ except Exception as e:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"删除变量失败: {str(e)}"
+ )
+
+
+@router.get(
+    "/get",
+    responses={
+        status.HTTP_200_OK: {"model": VariableResponse},
+        status.HTTP_404_NOT_FOUND: {"model": ResponseData},
+    },
+)
+async def get_variable(
+    user_sub: Annotated[str, Depends(get_user)],
+    name: str = Query(..., description="变量名称"),
+    scope: VariableScope = Query(..., description="变量作用域"),
+    flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"),
+    conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"),
+) -> VariableResponse:
+    """Fetch a single variable by name and scope.
+
+    Returns 404 when it does not exist and 403 when the caller may not
+    access it; other failures surface as 500.
+    """
+    try:
+        pool_manager = await get_pool_manager()
+
+        # Resolve the variable from whichever pool owns this scope; only
+        # scope-relevant identifiers are forwarded.
+        variable = await pool_manager.get_variable_from_any_pool(
+            name=name,
+            scope=scope,
+            user_id=user_sub if scope == VariableScope.USER else None,
+            flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None,
+            conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None
+        )
+
+        if not variable:
+            raise HTTPException(
+                status_code=status.HTTP_404_NOT_FOUND,
+                detail="变量不存在"
+            )
+
+        # Permission check.
+        if not variable.can_access(user_sub):
+            raise HTTPException(
+                status_code=status.HTTP_403_FORBIDDEN,
+                detail="没有权限访问此变量"
+            )
+
+        # Build the response payload.
+        var_dict = variable.to_dict()
+        return VariableResponse(
+            name=variable.name,
+            var_type=variable.var_type.value,
+            scope=variable.scope.value,
+            value=str(var_dict["value"]) if var_dict["value"] is not None else "",
+            description=variable.metadata.description,
+            created_at=variable.metadata.created_at.isoformat(),
+            updated_at=variable.metadata.updated_at.isoformat(),
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"获取变量失败: {str(e)}"
+        )
+
+
+@router.get(
+    "/list",
+    responses={
+        status.HTTP_200_OK: {"model": VariableListResponse},
+    },
+)
+async def list_variables(
+    user_sub: Annotated[str, Depends(get_user)],
+    scope: VariableScope = Query(..., description="变量作用域"),
+    flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"),
+    conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"),
+    current_step_id: Optional[str] = Query(default=None, description="当前步骤ID(用于获取前置节点变量)"),
+) -> VariableListResponse:
+    """List all variables the caller may access in the given scope.
+
+    For conversation scope with current_step_id, predecessor-node output
+    variables (named "node_id.key") are merged in and reshaped so that
+    `name` holds only the key and `step`/`step_id` identify the node.
+    """
+    try:
+        pool_manager = await get_pool_manager()
+
+        # Fetch variables; only scope-relevant identifiers are forwarded.
+        variables = await pool_manager.list_variables_from_any_pool(
+            scope=scope,
+            user_id=user_sub if scope == VariableScope.USER else None,
+            flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None,
+            conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None
+        )
+
+        # Conversation scope with a current step also gets the outputs of
+        # predecessor nodes.
+        if scope == VariableScope.CONVERSATION and current_step_id and flow_id:
+            predecessor_variables = await _get_predecessor_node_variables(
+                user_sub, flow_id, conversation_id, current_step_id
+            )
+            variables.extend(predecessor_variables)
+
+        # Filter by access permission and build the response list.
+        filtered_variables = []
+        for variable in variables:
+            if variable.can_access(user_sub):
+                var_dict = variable.to_dict()
+
+                # Predecessor-node variables are detected by naming convention.
+                is_predecessor_var = (
+                    "." in variable.name and
+                    not variable.name.startswith("system.") and
+                    scope == VariableScope.CONVERSATION and
+                    flow_id
+                )
+
+                if is_predecessor_var:
+                    # Split "node_id.key" into node id and variable name.
+                    parts = variable.name.split(".", 1)
+                    if len(parts) == 2:
+                        step_id, var_name = parts
+
+                        # Prefer node info from the cache entry when present.
+                        if hasattr(variable, '_cache_data') and variable._cache_data:
+                            cache_data = variable._cache_data
+                            step_name = cache_data.get('step_name', step_id)
+                            step_id_from_cache = cache_data.get('step_id', step_id)
+                        else:
+                            # Fall back to a live node-info lookup.
+                            node_info = await _get_node_info_by_step_id(flow_id, step_id)
+                            step_name = node_info["name"]
+                            step_id_from_cache = node_info["step_id"]
+
+                        filtered_variables.append(VariableResponse(
+                            name=var_name,  # keep only the key part
+                            var_type=variable.var_type.value,
+                            scope=variable.scope.value,
+                            value=str(var_dict["value"]) if var_dict["value"] is not None else "",
+                            description=variable.metadata.description,
+                            created_at=variable.metadata.created_at.isoformat(),
+                            updated_at=variable.metadata.updated_at.isoformat(),
+                            step=step_name,  # node display name
+                            step_id=step_id_from_cache  # node ID
+                        ))
+                    else:
+                        # Unexpected name format — emit the variable as-is.
+                        filtered_variables.append(VariableResponse(
+                            name=variable.name,
+                            var_type=variable.var_type.value,
+                            scope=variable.scope.value,
+                            value=str(var_dict["value"]) if var_dict["value"] is not None else "",
+                            description=variable.metadata.description,
+                            created_at=variable.metadata.created_at.isoformat(),
+                            updated_at=variable.metadata.updated_at.isoformat(),
+                        ))
+                else:
+                    # Regular (non-predecessor) variable.
+                    filtered_variables.append(VariableResponse(
+                        name=variable.name,
+                        var_type=variable.var_type.value,
+                        scope=variable.scope.value,
+                        value=str(var_dict["value"]) if var_dict["value"] is not None else "",
+                        description=variable.metadata.description,
+                        created_at=variable.metadata.created_at.isoformat(),
+                        updated_at=variable.metadata.updated_at.isoformat(),
+                    ))
+
+        return VariableListResponse(
+            variables=filtered_variables,
+            total=len(filtered_variables)
+        )
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"获取变量列表失败: {str(e)}"
+        )
+
+
+@router.post(
+    "/parse",
+    responses={
+        status.HTTP_200_OK: {"model": ParseTemplateResponse},
+        status.HTTP_400_BAD_REQUEST: {"model": ResponseData},
+    },
+)
+async def parse_template(
+    user_sub: Annotated[str, Depends(get_user)],
+    request: ParseTemplateRequest = Body(...),
+) -> ParseTemplateResponse:
+    """Resolve variable references in a template and report which were used."""
+    try:
+        # Build a parser scoped to this user and flow.
+        parser = VariableParser(
+            user_id=user_sub,
+            flow_id=request.flow_id,
+            conversation_id=None,  # conversation_id is no longer used here
+        )
+
+        # Substitute variable references in the template.
+        parsed_template = await parser.parse_template(request.template)
+
+        # Collect the references that appeared in the template.
+        variables_used = await parser.extract_variables(request.template)
+
+        return ParseTemplateResponse(
+            parsed_template=parsed_template,
+            variables_used=variables_used
+        )
+
+    except Exception as e:
+        # All parser failures are reported as a 400 to the caller.
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"解析模板失败: {str(e)}"
+        )
+
+
+@router.post(
+    "/validate",
+    responses={
+        status.HTTP_200_OK: {"model": ValidateTemplateResponse},
+        status.HTTP_400_BAD_REQUEST: {"model": ResponseData},
+    },
+)
+async def validate_template(
+    user_sub: Annotated[str, Depends(get_user)],
+    request: ParseTemplateRequest = Body(...),
+) -> ValidateTemplateResponse:
+    """Check whether every variable reference in a template resolves."""
+    try:
+        # Build a parser scoped to this user and flow.
+        parser = VariableParser(
+            user_id=user_sub,
+            flow_id=request.flow_id,
+            conversation_id=None,  # conversation_id is no longer used here
+        )
+
+        # Validate references without substituting them.
+        is_valid, invalid_refs = await parser.validate_template(request.template)
+
+        return ValidateTemplateResponse(
+            is_valid=is_valid,
+            invalid_references=invalid_refs
+        )
+
+    except Exception as e:
+        # All parser failures are reported as a 400 to the caller.
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"验证模板失败: {str(e)}"
+        )
+
+
+@router.get(
+    "/types",
+    responses={
+        status.HTTP_200_OK: {"model": ResponseData},
+    },
+)
+async def get_variable_types() -> ResponseData:
+    """Enumerate the supported variable types and scopes."""
+    return ResponseData(
+        code=200,
+        message="获取变量类型成功",
+        result={
+            "types": [vtype.value for vtype in VariableType],
+            "scopes": [scope.value for scope in VariableScope],
+        }
+    )
+
+
+@router.post(
+    "/clear-conversation",
+    responses={
+        status.HTTP_200_OK: {"model": ResponseData},
+    },
+)
+async def clear_conversation_variables(
+    user_sub: Annotated[str, Depends(get_user)],
+    flow_id: str = Query(..., description="流程ID"),
+) -> ResponseData:
+    """Clear all conversation-scope variables of the given flow."""
+    try:
+        pool_manager = await get_pool_manager()
+        # Delegate the actual clearing to the pool manager.
+        await pool_manager.clear_conversation_variables(flow_id)
+
+        return ResponseData(
+            code=200,
+            message="工作流对话变量已清空",
+            result={"flow_id": flow_id}
+        )
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"清空对话变量失败: {str(e)}"
+        )
+
+
+async def _get_node_info_by_step_id(flow_id: str, step_id: str) -> Dict[str, str]:
+    """Resolve a step_id to {"name", "step_id"}, falling back to the id itself."""
+    try:
+        flow_item = await _get_flow_by_flow_id(flow_id)
+        if not flow_item:
+            return {"name": step_id, "step_id": step_id}  # fallback: use step_id as the name
+
+        # Look up the node in the flow definition.
+        for node in flow_item.nodes:
+            if node.step_id == step_id:
+                return {
+                    "name": node.name or step_id,  # fall back to step_id when unnamed
+                    "step_id": step_id
+                }
+
+        # Node not found: return defaults.
+        return {"name": step_id, "step_id": step_id}
+
+    except Exception as e:
+        # Best-effort: lookup failures degrade to the id-as-name fallback.
+        logger.error(f"获取节点信息失败: {e}")
+        return {"name": step_id, "step_id": step_id}
+
+
+async def _get_predecessor_variables_from_topology(
+    flow_id: str,
+    current_step_id: str,
+    user_sub: str
+) -> List:
+    """Derive predecessor output variables by analyzing the flow topology.
+
+    Fallback path used when neither the pools nor the cache have the data;
+    returns an empty list on any failure.
+    """
+    try:
+        variables = []
+
+        # Load the flow definition directly by flow_id.
+        flow_item = await _get_flow_by_flow_id(flow_id)
+        if not flow_item:
+            logger.warning(f"无法获取工作流信息: flow_id={flow_id}")
+            return variables
+
+        # Find the nodes whose edges point at the current step.
+        predecessor_nodes = _find_predecessor_nodes(flow_item, current_step_id)
+
+        # Materialize each predecessor's declared output parameters.
+        for node in predecessor_nodes:
+            node_vars = await _create_node_output_variables(node, user_sub)
+            variables.extend(node_vars)
+
+        logger.info(f"通过拓扑分析为节点 {current_step_id} 创建了 {len(variables)} 个前置节点变量")
+        return variables
+
+    except Exception as e:
+        logger.error(f"通过拓扑分析获取前置节点变量失败: {e}")
+        return []
+
+
+async def _get_flow_by_flow_id(flow_id: str):
+    """Load a flow definition knowing only its flow_id.
+
+    Finds the owning app via MongoDB, then delegates to FlowManager.
+    Returns None (and logs) when the app or flow cannot be found.
+    """
+    try:
+        from apps.common.mongo import MongoDB
+
+        app_collection = MongoDB().get_collection("app")
+
+        # Find the app document containing this flow; only _id is needed.
+        app_record = await app_collection.find_one(
+            {"flows.id": flow_id},
+            {"_id": 1}
+        )
+
+        if not app_record:
+            logger.warning(f"未找到包含flow_id {flow_id} 的应用")
+            return None
+
+        app_id = app_record["_id"]
+
+        # Reuse the existing FlowManager lookup now that app_id is known.
+        flow_item = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id)
+        return flow_item
+
+    except Exception as e:
+        logger.error(f"通过flow_id获取工作流失败: {e}")
+        return None
+
+
+def _find_predecessor_nodes(flow_item, current_step_id: str) -> List:
+    """Return the nodes with an edge pointing directly at *current_step_id*."""
+    try:
+        predecessor_nodes = []
+
+        # Scan edges for those targeting the current node.
+        for edge in flow_item.edges:
+            if edge.target_node == current_step_id:
+                # Resolve the edge's source node, if it exists.
+                source_node = next(
+                    (node for node in flow_item.nodes if node.step_id == edge.source_node),
+                    None
+                )
+                if source_node:
+                    predecessor_nodes.append(source_node)
+
+        logger.info(f"为节点 {current_step_id} 找到 {len(predecessor_nodes)} 个前置节点")
+        return predecessor_nodes
+
+    except Exception as e:
+        logger.error(f"查找前置节点失败: {e}")
+        return []
+
+
+async def _create_node_output_variables(node, user_sub: str) -> List:
+ """根据节点的output_parameters配置创建输出变量"""
+ try:
+ from apps.scheduler.variable.variables import create_variable
+ from apps.scheduler.variable.base import VariableMetadata
+ from datetime import datetime, UTC
+
+ variables = []
+ node_id = node.step_id
+
+ # 调试:输出节点的完整参数信息
+ logger.info(f"节点 {node_id} 的参数结构: {node.parameters}")
+
+ # 统一从节点的output_parameters创建变量
+ output_params = {}
+ if hasattr(node, 'parameters') and node.parameters:
+ # 尝试不同的访问方式
+ if isinstance(node.parameters, dict):
+ output_params = node.parameters.get('output_parameters', {})
+ logger.info(f"从字典中获取output_parameters: {output_params}")
+ else:
+ output_params = getattr(node.parameters, 'output_parameters', {})
+ logger.info(f"从对象属性中获取output_parameters: {output_params}")
+
+ # 如果没有配置output_parameters,跳过此节点
+ if not output_params:
+ logger.info(f"节点 {node_id} 没有配置output_parameters,跳过创建输出变量")
+ return variables
+
+ # 遍历output_parameters中的每个key-value对,创建对应的变量
+ for param_name, param_config in output_params.items():
+ # 解析参数配置
+ if isinstance(param_config, dict):
+ param_type = param_config.get('type', 'string')
+ description = param_config.get('description', '')
+ else:
+ # 如果param_config不是字典,可能是简单的类型字符串
+ param_type = str(param_config) if param_config else 'string'
+ description = ''
+
+ # 确定变量类型
+ var_type = VariableType.STRING # 默认类型
+ if param_type == 'number':
+ var_type = VariableType.NUMBER
+ elif param_type == 'boolean':
+ var_type = VariableType.BOOLEAN
+ elif param_type == 'object':
+ var_type = VariableType.OBJECT
+ elif param_type == 'array' or param_type == 'array[any]':
+ var_type = VariableType.ARRAY_ANY
+ elif param_type == 'array[string]':
+ var_type = VariableType.ARRAY_STRING
+ elif param_type == 'array[number]':
+ var_type = VariableType.ARRAY_NUMBER
+ elif param_type == 'array[object]':
+ var_type = VariableType.ARRAY_OBJECT
+ elif param_type == 'array[boolean]':
+ var_type = VariableType.ARRAY_BOOLEAN
+ elif param_type == 'array[file]':
+ var_type = VariableType.ARRAY_FILE
+ elif param_type == 'array[secret]':
+ var_type = VariableType.ARRAY_SECRET
+ elif param_type == 'file':
+ var_type = VariableType.FILE
+ elif param_type == 'secret':
+ var_type = VariableType.SECRET
+
+ # 创建变量元数据
+ metadata = VariableMetadata(
+ name=f"{node_id}.{param_name}",
+ var_type=var_type,
+ scope=VariableScope.CONVERSATION,
+ description=description or f"来自节点 {node_id} 的输出参数 {param_name}",
+ created_by=user_sub,
+ created_at=datetime.now(UTC),
+ updated_at=datetime.now(UTC)
+ )
+
+ # 创建变量对象
+ variable = create_variable(metadata, "") # 配置阶段的潜在变量,值为空
+ variables.append(variable)
+
+ logger.info(f"为节点 {node_id} 创建了 {len(variables)} 个输出变量: {[v.name for v in variables]}")
+ return variables
+
+ except Exception as e:
+ logger.error(f"创建节点输出变量失败: {e}")
+ return []
\ No newline at end of file
diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py
index 2ee8b862885b0d88e519bf00a1678a49863df2bd..a7475bf9c82dd94f97f8097fb263e3cfaf5241d4 100644
--- a/apps/scheduler/call/__init__.py
+++ b/apps/scheduler/call/__init__.py
@@ -2,20 +2,28 @@
"""Agent工具部分"""
from apps.scheduler.call.api.api import API
+from apps.scheduler.call.code.code import Code
from apps.scheduler.call.graph.graph import Graph
from apps.scheduler.call.llm.llm import LLM
from apps.scheduler.call.mcp.mcp import MCP
from apps.scheduler.call.rag.rag import RAG
+from apps.scheduler.call.reply.direct_reply import DirectReply
from apps.scheduler.call.sql.sql import SQL
-from apps.scheduler.call.suggest.suggest import Suggestion
+# from apps.scheduler.call.graph.graph import Graph
+# from apps.scheduler.call.suggest.suggest import Suggestion
+from apps.scheduler.call.suggest.suggest import Suggestion
+from apps.scheduler.call.choice.choice import Choice
# 只包含需要在编排界面展示的工具
__all__ = [
"API",
+ "Code",
+ "DirectReply",
"LLM",
"MCP",
"RAG",
- "SQL",
- "Graph",
- "Suggestion",
+ # "SQL",
+ # "Graph",
+ # "Suggestion",
+ "Choice"
]
diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py
index e1891f7259b72a2fa03228f6289c54da6297b958..c34cbc0585616c0c12f55e6d4f97012396818396 100644
--- a/apps/scheduler/call/api/api.py
+++ b/apps/scheduler/call/api/api.py
@@ -14,8 +14,8 @@ from pydantic.json_schema import SkipJsonSchema
from apps.common.oidc import oidc_provider
from apps.scheduler.call.api.schema import APIInput, APIOutput
-from apps.scheduler.call.core import CoreCall
-from apps.schemas.enum_var import CallOutputType, ContentType, HTTPMethod
+from apps.scheduler.call.core import CoreCall, NodeType
+from apps.schemas.enum_var import CallOutputType, CallType, ContentType, HTTPMethod
from apps.schemas.scheduler import (
CallError,
CallInfo,
@@ -49,20 +49,28 @@ SUCCESS_HTTP_CODES = [
class API(CoreCall, input_model=APIInput, output_model=APIOutput):
"""API调用工具"""
- enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=True)
+ enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False)
url: str = Field(description="API接口的完整URL")
method: HTTPMethod = Field(description="API接口的HTTP Method")
content_type: ContentType | None = Field(description="API接口的Content-Type", default=None)
timeout: int = Field(description="工具超时时间", default=300, gt=30)
- body: dict[str, Any] = Field(description="已知的部分请求体", default={})
- query: dict[str, Any] = Field(description="已知的部分请求参数", default={})
+ body: list[dict[str, Any]] = Field(description="已知的部分请求体", default=[])
+ query: list[dict[str, Any]]= Field(description="已知的部分请求参数", default=[])
+ headers: list[dict[str, Any]] = Field(description="已知的部分请求头", default=[])
+ # 增加node_type
+ node_type: NodeType | None = Field(description="节点类型", default=NodeType.API)
+
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="API调用", description="向某一个API接口发送HTTP请求,获取数据。")
+ return CallInfo(
+ name="API调用",
+ type=CallType.TOOL,
+ description="向某一个API接口发送HTTP请求,获取数据。"
+ )
async def _init(self, call_vars: CallVars) -> APIInput:
"""初始化API调用工具"""
@@ -97,6 +105,7 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput):
method=self.method,
query=self.query,
body=self.body,
+ headers=self.headers,
)
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
@@ -129,12 +138,20 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput):
url=self.url,
cookies=req_cookie,
)
+ logger.info("api_input ----------------------- %r", data)
# 根据HTTP方法创建请求
if self.method in [HTTPMethod.GET.value, HTTPMethod.DELETE.value]:
# GET/DELETE 请求处理
- req_params.update(data.query)
- return await request_factory(params=req_params)
+ # 转换query参数格式: list[dict] -> dict
+ query_dict = {item["key"]: item["value"] for item in data.query}
+ req_params.update(query_dict)
+ logger.info("query_dict ----------------------- %r", query_dict)
+ logger.info("req_params ----------------------- %r", req_params)
+ header_dict = {item["key"]: item["value"] for item in data.headers}
+ req_header.update(header_dict)
+ logger.info("req_header ----------------------- %r", req_header)
+ return await request_factory(params=req_params, headers=req_header)
if self.method in [HTTPMethod.POST.value, HTTPMethod.PUT.value, HTTPMethod.PATCH.value]:
# POST/PUT/PATCH 请求处理
@@ -142,8 +159,13 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput):
raise CallError(message="API接口的Content-Type未指定", data={})
# 根据Content-Type设置请求参数
- req_body = data.body
- req_header.update({"Content-Type": self.content_type})
+ # 转换body参数格式: list[dict] -> dict
+ req_body = {item["key"]: item["value"] for item in data.body}
+ header_dict = {item["key"]: item["value"] for item in data.headers}
+ header_dict["Content-Type"] = self.content_type
+ req_header.update(header_dict)
+ logger.info("header_dict ----------------------- %r", header_dict)
+ logger.info("req_header ----------------------- %r", req_header)
# 根据Content-Type决定如何发送请求体
content_type_handlers = {
diff --git a/apps/scheduler/call/api/schema.py b/apps/scheduler/call/api/schema.py
index 055008a889676e4e2fbd5057912ba78421e52aa2..1fbe82baae820f28ae6820f67323a48b6090d083 100644
--- a/apps/scheduler/call/api/schema.py
+++ b/apps/scheduler/call/api/schema.py
@@ -5,7 +5,6 @@ from typing import Any
from pydantic import Field
from pydantic.json_schema import SkipJsonSchema
-
from apps.scheduler.call.core import DataBase
@@ -15,8 +14,9 @@ class APIInput(DataBase):
url: SkipJsonSchema[str] = Field(description="API调用工具的URL")
method: SkipJsonSchema[str] = Field(description="API调用工具的HTTP方法")
- query: dict[str, Any] = Field(description="API调用工具的请求参数", default={})
- body: dict[str, Any] = Field(description="API调用工具的请求体", default={})
+ query: list[dict[str, Any]] = Field(description="API调用工具的请求参数", default=[])
+ body: list[dict[str, Any]] = Field(description="API调用工具的请求体", default=[])
+ headers: list[dict[str, Any]] = Field(description="API调用工具的请求头", default=[])
class APIOutput(DataBase):
diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py
index a5edf21afb2eeb308dd909a40696500751c9a086..47a0df747483c3915efb21c868d5cb1785b903e0 100644
--- a/apps/scheduler/call/choice/choice.py
+++ b/apps/scheduler/call/choice/choice.py
@@ -1,19 +1,153 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""使用大模型或使用程序做出判断"""
-from enum import Enum
+import ast
+import copy
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
-from apps.scheduler.call.choice.schema import ChoiceInput, ChoiceOutput
-from apps.scheduler.call.core import CoreCall
+from pydantic import Field
+from apps.scheduler.call.choice.condition_handler import ConditionHandler
+from apps.scheduler.call.choice.schema import (
+ Condition,
+ ChoiceBranch,
+ ChoiceInput,
+ ChoiceOutput,
+ Logic,
+)
+from apps.schemas.parameters import Type
+from apps.scheduler.call.core import CoreCall, NodeType
+from apps.schemas.enum_var import CallOutputType, CallType
+from apps.schemas.scheduler import (
+ CallError,
+ CallInfo,
+ CallOutputChunk,
+ CallVars,
+)
-class Operator(str, Enum):
- """Choice工具支持的运算符"""
-
- pass
+logger = logging.getLogger(__name__)
class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput):
    """Choice工具"""

    to_user: bool = Field(default=False)
    # default_factory so the branch templates are built per instance, not shared.
    choices: list[ChoiceBranch] = Field(
        description="分支",
        default_factory=lambda: [
            ChoiceBranch(),
            ChoiceBranch(conditions=[Condition()], is_default=False),
        ],
    )
    # Node-kind marker used by the frontend/backend to tell node types apart.
    node_type: NodeType | None = Field(description="节点类型", default=NodeType.CHOICE)

    @classmethod
    def info(cls) -> CallInfo:
        """Return the Call's name and description."""
        return CallInfo(name="选择器", description="使用大模型或使用程序做出判断", type=CallType.LOGIC)

    @staticmethod
    def _warn_branch(branch_id: str, msg: str) -> None:
        """Log a uniform warning for a branch whose condition failed validation."""
        logger.warning("[Choice] 分支 %s 条件处理失败: %s", branch_id, msg)

    def _resolve_left(self, condition: Condition, call_vars: CallVars, branch_id: str) -> bool:
        """Resolve the left-hand value from step history; return False if the condition is invalid."""
        if condition.left.step_id is None:
            self._warn_branch(branch_id, "左侧变量缺少step_id")
            return False
        condition.left.value = self._extract_history_variables(
            condition.left.step_id + "/" + condition.left.value, call_vars.history)
        # _extract_history_variables returns None when the step/path is missing.
        if condition.left.value is None:
            self._warn_branch(branch_id, f"步骤 {condition.left.step_id} 的历史变量不存在")
            return False
        if not ConditionHandler.check_value_type(condition.left.value, condition.left.type):
            self._warn_branch(
                branch_id, f"左值类型不匹配: {condition.left.value} 应为 {condition.left.type.value}")
            return False
        return True

    async def _resolve_right(self, condition: Condition, call_vars: CallVars, branch_id: str) -> bool:
        """Resolve the right-hand value (from history or a literal); return False if invalid."""
        if condition.right.step_id is not None:
            condition.right.value = self._extract_history_variables(
                condition.right.step_id + "/" + condition.right.value, call_vars.history)
            if condition.right.value is None:
                self._warn_branch(branch_id, f"步骤 {condition.right.step_id} 的历史变量不存在")
                return False
            if not ConditionHandler.check_value_type(condition.right.value, condition.right.type):
                self._warn_branch(
                    branch_id, f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}")
                return False
            return True
        # Literal right value: its expected type is implied by the operator.
        right_value_type = await ConditionHandler.get_value_type_from_operate(condition.operate)
        if right_value_type is None:
            self._warn_branch(branch_id, f"不支持的运算符: {condition.operate}")
            return False
        if condition.right.type != right_value_type:
            self._warn_branch(
                branch_id, f"右值类型不匹配: {condition.right.value} 应为 {right_value_type.value}")
            return False
        if right_value_type == Type.STRING:
            condition.right.value = str(condition.right.value)
        else:
            try:
                condition.right.value = ast.literal_eval(condition.right.value)
            except (ValueError, SyntaxError):
                # literal_eval raises SyntaxError on malformed input; skip the
                # condition instead of letting it escape the ValueError handler.
                self._warn_branch(branch_id, f"右值无法解析: {condition.right.value}")
                return False
        if not ConditionHandler.check_value_type(condition.right.value, condition.right.type):
            self._warn_branch(
                branch_id, f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}")
            return False
        return True

    async def _prepare_message(self, call_vars: CallVars) -> list[ChoiceBranch]:
        """
        Resolve history variables inside every branch and drop invalid branches.

        Returns the branches whose conditions all resolved and type-checked.
        (Return annotation fixed: this returns ChoiceBranch objects, not dicts.)
        """
        valid_choices: list[ChoiceBranch] = []
        for choice in self.choices:
            try:
                if choice.logic not in (Logic.AND, Logic.OR):
                    self._warn_branch(choice.branch_id, f"无效的逻辑运算符: {choice.logic}")
                    continue

                valid_conditions = []
                for original in choice.conditions:
                    # Deep-copy so resolution never mutates the configured template.
                    condition = copy.deepcopy(original)
                    if not self._resolve_left(condition, call_vars, choice.branch_id):
                        continue
                    if not await self._resolve_right(condition, call_vars, choice.branch_id):
                        continue
                    valid_conditions.append(condition)

                if not valid_conditions:
                    self._warn_branch(choice.branch_id, "分支没有有效条件")
                    continue

                choice.conditions = valid_conditions
                valid_choices.append(choice)
            except ValueError as e:
                logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e))
                continue
        return valid_choices

    async def _init(self, call_vars: CallVars) -> ChoiceInput:
        """Initialize the Choice tool with the validated branches."""
        return ChoiceInput(
            choices=await self._prepare_message(call_vars),
        )

    async def _exec(
        self, input_data: dict[str, Any]
    ) -> AsyncGenerator[CallOutputChunk, None]:
        """Evaluate the branches and yield the id of the first matching branch."""
        data = ChoiceInput(**input_data)
        try:
            branch_id = ConditionHandler.handler(data.choices)
            yield CallOutputChunk(
                type=CallOutputType.DATA,
                content=ChoiceOutput(branch_id=branch_id).model_dump(exclude_none=True, by_alias=True),
            )
        except Exception as e:
            raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e
diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7542f2944eedff72b820a79725d64b5ddeba5a11
--- /dev/null
+++ b/apps/scheduler/call/choice/condition_handler.py
@@ -0,0 +1,314 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""处理条件分支的工具"""
+
+
+import logging
+
+from pydantic import BaseModel
+
+from apps.schemas.parameters import (
+ Type,
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+)
+
+from apps.scheduler.call.choice.schema import (
+ ChoiceBranch,
+ Condition,
+ Logic,
+ Value
+)
+
+logger = logging.getLogger(__name__)
+
+
class ConditionHandler(BaseModel):
    """条件分支处理器"""

    @staticmethod
    async def get_value_type_from_operate(
        operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate,
    ) -> Type | None:
        """
        Return the expected type of the right-hand value for *operate*.

        Returns None for an unsupported operator (annotation fixed: the original
        declared ``-> Type`` but fell through to ``return None``).
        """
        if isinstance(operate, NumberOperate):
            return Type.NUMBER
        if operate in [
                StringOperate.EQUAL, StringOperate.NOT_EQUAL, StringOperate.CONTAINS,
                StringOperate.NOT_CONTAINS, StringOperate.STARTS_WITH,
                StringOperate.ENDS_WITH, StringOperate.REGEX_MATCH]:
            return Type.STRING
        if operate in [StringOperate.LENGTH_EQUAL, StringOperate.LENGTH_GREATER_THAN,
                       StringOperate.LENGTH_GREATER_THAN_OR_EQUAL, StringOperate.LENGTH_LESS_THAN,
                       StringOperate.LENGTH_LESS_THAN_OR_EQUAL]:
            return Type.NUMBER
        if operate in [ListOperate.EQUAL, ListOperate.NOT_EQUAL]:
            return Type.LIST
        if operate in [ListOperate.CONTAINS, ListOperate.NOT_CONTAINS]:
            return Type.STRING
        if operate in [ListOperate.LENGTH_EQUAL, ListOperate.LENGTH_GREATER_THAN,
                       ListOperate.LENGTH_GREATER_THAN_OR_EQUAL, ListOperate.LENGTH_LESS_THAN,
                       ListOperate.LENGTH_LESS_THAN_OR_EQUAL]:
            return Type.NUMBER
        if operate in [BoolOperate.EQUAL, BoolOperate.NOT_EQUAL]:
            return Type.BOOL
        if operate in [DictOperate.EQUAL, DictOperate.NOT_EQUAL]:
            return Type.DICT
        if operate in [DictOperate.CONTAINS_KEY, DictOperate.NOT_CONTAINS_KEY]:
            return Type.STRING
        return None

    @staticmethod
    def check_value_type(value: object, expected_type: Type) -> bool:
        """
        Check whether a raw value matches *expected_type*.

        BUGFIX: every caller in choice.py passes the raw value (e.g.
        ``condition.left.value``), not a ``Value`` wrapper, so the original
        ``value.value`` dereference raised AttributeError at runtime.
        bool is excluded from NUMBER because bool is a subclass of int.
        """
        if expected_type == Type.STRING:
            return isinstance(value, str)
        if expected_type == Type.NUMBER:
            return isinstance(value, (int, float)) and not isinstance(value, bool)
        if expected_type == Type.LIST:
            return isinstance(value, list)
        if expected_type == Type.DICT:
            return isinstance(value, dict)
        if expected_type == Type.BOOL:
            return isinstance(value, bool)
        return False

    @staticmethod
    def handler(choices: list[ChoiceBranch]) -> str:
        """
        Evaluate branches in order and return the id of the first matching branch.

        Falls back to the default branch if nothing matched; returns "" when
        there is no default branch either.
        """
        default_branch = [c for c in choices if c.is_default]

        for branch in choices:
            if branch.is_default:
                continue
            results = [ConditionHandler._judge_condition(cond) for cond in branch.conditions]
            if branch.logic == Logic.AND:
                matched = all(results)
            elif branch.logic == Logic.OR:
                matched = any(results)
            else:
                # Unknown logic operator: treat as "no match" instead of crashing
                # (the original left the result variable unbound here).
                matched = False
            if matched:
                return branch.branch_id

        if default_branch:
            return default_branch[0].branch_id
        return ""

    @staticmethod
    def _judge_condition(condition: Condition) -> bool:
        """
        Judge a single condition by dispatching on its declared value type.

        Raises ValueError for an unsupported data type.
        """
        left = condition.left
        operate = condition.operate
        right = condition.right
        # NOTE(review): Condition declares no `type` field in schema.py; this
        # relies on the model accepting extra attributes — confirm against
        # DataBase's model config.
        value_type = condition.type

        if value_type == Type.STRING:
            return ConditionHandler._judge_string_condition(left, operate, right)
        if value_type == Type.NUMBER:
            # BUGFIX: the original called `_judge_int_condition`, which does not
            # exist (AttributeError); the implemented method is `_judge_number_condition`.
            return ConditionHandler._judge_number_condition(left, operate, right)
        if value_type == Type.BOOL:
            return ConditionHandler._judge_bool_condition(left, operate, right)
        if value_type == Type.LIST:
            return ConditionHandler._judge_list_condition(left, operate, right)
        if value_type == Type.DICT:
            return ConditionHandler._judge_dict_condition(left, operate, right)

        logger.error("不支持的数据类型: %s", value_type)
        msg = f"不支持的数据类型: {value_type}"
        raise ValueError(msg)

    @staticmethod
    def _judge_string_condition(left: Value, operate: StringOperate, right: Value) -> bool:
        """
        Judge a string-typed condition.

        Raises TypeError when the left value is not a str.
        """
        import re

        left_value = left.value
        if not isinstance(left_value, str):
            logger.error("左值不是字符串类型: %s", left_value)
            msg = "左值必须是字符串类型"
            raise TypeError(msg)
        right_value = right.value
        if operate == StringOperate.EQUAL:
            return left_value == right_value
        if operate == StringOperate.NOT_EQUAL:
            return left_value != right_value
        if operate == StringOperate.CONTAINS:
            return right_value in left_value
        if operate == StringOperate.NOT_CONTAINS:
            return right_value not in left_value
        if operate == StringOperate.STARTS_WITH:
            return left_value.startswith(right_value)
        if operate == StringOperate.ENDS_WITH:
            return left_value.endswith(right_value)
        if operate == StringOperate.REGEX_MATCH:
            return bool(re.match(right_value, left_value))
        if operate == StringOperate.LENGTH_EQUAL:
            return len(left_value) == right_value
        if operate == StringOperate.LENGTH_GREATER_THAN:
            return len(left_value) > right_value
        if operate == StringOperate.LENGTH_GREATER_THAN_OR_EQUAL:
            return len(left_value) >= right_value
        if operate == StringOperate.LENGTH_LESS_THAN:
            return len(left_value) < right_value
        if operate == StringOperate.LENGTH_LESS_THAN_OR_EQUAL:
            return len(left_value) <= right_value
        return False

    @staticmethod
    def _judge_number_condition(left: Value, operate: NumberOperate, right: Value) -> bool:
        """
        Judge a number-typed condition.

        Raises TypeError when the left value is not an int/float.
        """
        left_value = left.value
        if not isinstance(left_value, (int, float)):
            logger.error("左值不是数字类型: %s", left_value)
            msg = "左值必须是数字类型"
            raise TypeError(msg)
        right_value = right.value
        if operate == NumberOperate.EQUAL:
            return left_value == right_value
        if operate == NumberOperate.NOT_EQUAL:
            return left_value != right_value
        if operate == NumberOperate.GREATER_THAN:
            return left_value > right_value
        if operate == NumberOperate.LESS_THAN:
            return left_value < right_value
        if operate == NumberOperate.GREATER_THAN_OR_EQUAL:
            return left_value >= right_value
        if operate == NumberOperate.LESS_THAN_OR_EQUAL:
            return left_value <= right_value
        return False

    @staticmethod
    def _judge_bool_condition(left: Value, operate: BoolOperate, right: Value) -> bool:
        """
        Judge a bool-typed condition.

        Raises TypeError when the left value is not a bool.
        """
        left_value = left.value
        if not isinstance(left_value, bool):
            logger.error("左值不是布尔类型: %s", left_value)
            msg = "左值必须是布尔类型"
            raise TypeError(msg)
        right_value = right.value
        if operate == BoolOperate.EQUAL:
            return left_value == right_value
        if operate == BoolOperate.NOT_EQUAL:
            return left_value != right_value
        if operate == BoolOperate.IS_EMPTY:
            return not left_value
        if operate == BoolOperate.NOT_EMPTY:
            return left_value
        return False

    @staticmethod
    def _judge_list_condition(left: Value, operate: ListOperate, right: Value) -> bool:
        """
        Judge a list-typed condition.

        Raises TypeError when the left value is not a list.
        """
        left_value = left.value
        if not isinstance(left_value, list):
            logger.error("左值不是列表类型: %s", left_value)
            msg = "左值必须是列表类型"
            raise TypeError(msg)
        right_value = right.value
        if operate == ListOperate.EQUAL:
            return left_value == right_value
        if operate == ListOperate.NOT_EQUAL:
            return left_value != right_value
        if operate == ListOperate.CONTAINS:
            return right_value in left_value
        if operate == ListOperate.NOT_CONTAINS:
            return right_value not in left_value
        if operate == ListOperate.LENGTH_EQUAL:
            return len(left_value) == right_value
        if operate == ListOperate.LENGTH_GREATER_THAN:
            return len(left_value) > right_value
        if operate == ListOperate.LENGTH_GREATER_THAN_OR_EQUAL:
            return len(left_value) >= right_value
        if operate == ListOperate.LENGTH_LESS_THAN:
            return len(left_value) < right_value
        if operate == ListOperate.LENGTH_LESS_THAN_OR_EQUAL:
            return len(left_value) <= right_value
        return False

    @staticmethod
    def _judge_dict_condition(left: Value, operate: DictOperate, right: Value) -> bool:
        """
        Judge a dict-typed condition.

        Raises TypeError when the left value is not a dict.
        """
        left_value = left.value
        if not isinstance(left_value, dict):
            logger.error("左值不是字典类型: %s", left_value)
            msg = "左值必须是字典类型"
            raise TypeError(msg)
        right_value = right.value
        if operate == DictOperate.EQUAL:
            return left_value == right_value
        if operate == DictOperate.NOT_EQUAL:
            return left_value != right_value
        if operate == DictOperate.CONTAINS_KEY:
            return right_value in left_value
        if operate == DictOperate.NOT_CONTAINS_KEY:
            return right_value not in left_value
        return False
diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py
index 60b62d09fd66adbf32295f44ec86398a537f38d5..b95b166879608bc431525958e7e120ad93c5e5a3 100644
--- a/apps/scheduler/call/choice/schema.py
+++ b/apps/scheduler/call/choice/schema.py
@@ -1,12 +1,63 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""Choice Call的输入和输出"""
+import uuid
+from enum import Enum
+
+from pydantic import BaseModel, Field
+
+from apps.schemas.parameters import (
+ Type,
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+)
from apps.scheduler.call.core import DataBase
class Logic(str, Enum):
    """Logical connectives that the Choice tool understands for joining conditions."""

    AND = "and"
    OR = "or"
+
+
class Value(DataBase):
    """值的结构"""

    # Optional reference to the step whose history output supplies this value.
    step_id: str | None = Field(default=None, description="步骤id")
    # Declared type, checked against the resolved value at runtime.
    type: Type | None = Field(default=None, description="值的类型")
    # The literal (or resolved) value itself.
    value: str | float | int | bool | list | dict | None = Field(default=None, description="值")
+
+
class Condition(DataBase):
    """单个条件"""

    # default_factory builds a fresh Value per instance (idiomatic for mutable defaults).
    left: Value = Field(description="左值", default_factory=Value)
    right: Value = Field(description="右值", default_factory=Value)
    # The comparison operator; its concrete enum determines the operand types.
    operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate | None = Field(
        description="运算符", default=None)
    id: str = Field(description="条件ID", default_factory=lambda: str(uuid.uuid4()))
+
+
class ChoiceBranch(DataBase):
    """子分支"""

    branch_id: str = Field(description="分支ID", default_factory=lambda: str(uuid.uuid4()))
    logic: Logic = Field(description="逻辑运算符", default=Logic.AND)
    # default_factory instead of a shared mutable [] default.
    conditions: list[Condition] = Field(description="条件列表", default_factory=list)
    is_default: bool = Field(description="是否为默认分支", default=True)
+
+
class ChoiceInput(DataBase):
    """Choice Call的输入"""

    # default_factory instead of a shared mutable [] default.
    choices: list[ChoiceBranch] = Field(description="分支", default_factory=list)
+
class ChoiceOutput(DataBase):
    """Choice Call的输出"""

    # Id of the branch that matched; "" when no branch (and no default) matched.
    branch_id: str = Field(default="", description="分支ID")
diff --git a/apps/scheduler/call/code/__init__.py b/apps/scheduler/call/code/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44da1c5078f9f6b32fec8cf3e23e6984aa1e8948
--- /dev/null
+++ b/apps/scheduler/call/code/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""代码执行工具"""
+
+from apps.scheduler.call.code.code import Code
+
+__all__ = ["Code"]
diff --git a/apps/scheduler/call/code/code.py b/apps/scheduler/call/code/code.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8ed7c01c415b3ede054c790bb195cd6d6048d32
--- /dev/null
+++ b/apps/scheduler/call/code/code.py
@@ -0,0 +1,318 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""代码执行工具"""
+
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
+import httpx
+from pydantic import Field
+from apps.common.config import Config
+from apps.scheduler.call.code.schema import CodeInput, CodeOutput
+from apps.scheduler.call.core import CoreCall, NodeType
+from apps.schemas.enum_var import CallOutputType, CallType
+from apps.schemas.scheduler import (
+ CallError,
+ CallInfo,
+ CallOutputChunk,
+ CallVars,
+)
+
+logger = logging.getLogger(__name__)
+
+
class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput):
    """代码执行工具"""

    to_user: bool = Field(default=True)

    # Execution parameters forwarded verbatim to the sandbox service.
    code: str = Field(description="要执行的代码", default="")
    code_type: str = Field(description="代码类型,支持python、javascript、bash", default="python")
    security_level: str = Field(description="安全等级,low或high", default="low")
    timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300)
    memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024)
    cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0)
    # default_factory instead of shared mutable {} defaults.
    input_parameters: dict[str, Any] = Field(description="输入参数配置", default_factory=dict)
    output_parameters: dict[str, Any] = Field(description="输出参数配置", default_factory=dict)
    # Node-kind marker used by the frontend/backend.
    node_type: NodeType | None = Field(description="节点类型", default=NodeType.CODE)

    @classmethod
    def info(cls) -> CallInfo:
        """Return the Call's name and description."""
        return CallInfo(
            name="代码执行",
            type=CallType.TRANSFORM,
            description="在安全的沙箱环境中执行Python、JavaScript、Bash代码。",
        )

    async def _init(self, call_vars: CallVars) -> CodeInput:
        """Build the CodeInput: user info plus variable-resolved input parameters."""
        user_info = {
            "user_id": call_vars.ids.user_sub,
            "username": call_vars.ids.user_sub,  # TODO: fetch the real username if available
            "permissions": ["execute"],
        }

        # Resolve each configured input parameter through the base-class helper.
        input_arg: dict[str, Any] = {}
        for param_name, param_config in self.input_parameters.items():
            input_arg[param_name] = await self._resolve_variables_in_config(param_config, call_vars)

        return CodeInput(
            code=self.code,
            code_type=self.code_type,
            user_info=user_info,
            security_level=self.security_level,
            timeout_seconds=self.timeout_seconds,
            memory_limit_mb=self.memory_limit_mb,
            cpu_limit=self.cpu_limit,
            input_arg=input_arg,
        )

    async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
        """Submit the code to the sandbox service, wait for completion, yield the result."""
        code_input = CodeInput(**input_data)

        try:
            config = Config().get_config()
            sandbox_url = config.sandbox.sandbox_service

            request_data = {
                "code": code_input.code,
                "code_type": code_input.code_type,
                "user_info": code_input.user_info,
                "security_level": code_input.security_level,
                "timeout_seconds": code_input.timeout_seconds,
                "memory_limit_mb": code_input.memory_limit_mb,
                "cpu_limit": code_input.cpu_limit,
                "input_arg": code_input.input_arg,
            }

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.post(
                    f"{sandbox_url.rstrip('/')}/execute",
                    json=request_data,
                    headers={"Content-Type": "application/json"},
                )

            if response.status_code != httpx.codes.OK:
                raise CallError(
                    message=f"代码执行服务请求失败: {response.status_code}",
                    data={"status_code": response.status_code, "response": response.text},
                )

            submit_result = response.json()
            logger.info("Sandbox service response: %s", submit_result)

            if not submit_result.get("success", False):
                message = submit_result.get("message", "")
                raise CallError(
                    message=f"代码执行服务返回错误: {message}",
                    data={"response": submit_result},
                )

            # BUGFIX: the original rebound `data` here, clobbering the parsed CodeInput.
            task_info = submit_result.get("data", {})
            task_id = task_info.get("task_id", "")
            logger.info(
                "Task submitted successfully - task_id: %s, estimated_wait_time: %ss, "
                "queue_position: %s, timestamp: %s",
                task_id,
                task_info.get("estimated_wait_time", 0),
                task_info.get("queue_position", 0),
                submit_result.get("timestamp", ""),
            )

            if task_id:
                # Asynchronous execution: poll until the task reaches a final state.
                result = await self._wait_for_result(sandbox_url, task_id)
            elif "output" in submit_result or "error" in submit_result:
                # Synchronous execution: the submit response already carries the result.
                result = submit_result
            else:
                # Async submission without a task_id is an error condition.
                result = {
                    "status": "error",
                    "error": "服务器没有返回task_id且没有执行结果",
                    "output": "",
                    "result": {},
                }

            # Extract the values named by output_parameters from the sandbox output.
            extracted_data = await self._process_sandbox_result(result)

            final_content = CodeOutput(
                task_id=task_id,
                status=result.get("status", "unknown"),
                output=result.get("output") or "",
                error=result.get("error") or "",
            ).model_dump(by_alias=True, exclude_none=True)

            if extracted_data and result.get("status") == "completed":
                final_content.update(extracted_data)
                logger.info("[Code] 已将提取的数据合并到输出: %s", list(extracted_data.keys()))

            yield CallOutputChunk(
                type=CallOutputType.DATA,
                content=final_content,
            )

        except CallError:
            # Keep the specific error instead of re-wrapping it as a generic one
            # (the original blanket `except Exception` swallowed CallError details).
            raise
        except httpx.TimeoutException as e:
            raise CallError(message="代码执行超时", data={}) from e
        except httpx.RequestError as e:
            raise CallError(message=f"请求sandbox服务失败: {e!s}", data={}) from e
        except Exception as e:
            raise CallError(message=f"代码执行失败: {e!s}", data={}) from e

    async def _process_sandbox_result(self, result: dict[str, Any]) -> dict[str, Any] | None:
        """Extract the key/value pairs named by output_parameters from the sandbox output."""
        import json

        try:
            # output_parameters is a declared field, so no hasattr check is needed.
            if not self.output_parameters:
                logger.debug("[Code] 无output_parameters配置,跳过数据提取")
                return None

            sandbox_output = result.get("output")
            if not sandbox_output:
                logger.warning("[Code] sandbox返回的结果中没有output字段")
                return None

            # The sandbox may return the output as a JSON string.
            if isinstance(sandbox_output, str):
                try:
                    sandbox_output = json.loads(sandbox_output)
                except json.JSONDecodeError:
                    logger.warning("[Code] sandbox返回的output不是有效的JSON格式: %s", sandbox_output)
                    return None

            if not isinstance(sandbox_output, dict):
                logger.warning("[Code] sandbox返回的output不是字典类型: %s", type(sandbox_output))
                return None

            extracted_data: dict[str, Any] = {}
            for param_name, param_config in self.output_parameters.items():
                try:
                    if param_name in sandbox_output:
                        # Direct key match.
                        extracted_data[param_name] = sandbox_output[param_name]
                    elif isinstance(param_config, dict) and "path" in param_config:
                        # Dotted-path extraction.
                        value = self._extract_value_by_path(sandbox_output, param_config["path"])
                        if value is not None:
                            extracted_data[param_name] = value
                    elif isinstance(param_config, dict) and param_config.get("source") == "full_output":
                        # Use the full output object.
                        extracted_data[param_name] = sandbox_output
                    elif isinstance(param_config, dict) and "default" in param_config:
                        # Fall back to the configured default.
                        extracted_data[param_name] = param_config["default"]
                    else:
                        logger.debug("[Code] 无法提取参数 %s,在output中未找到对应值", param_name)
                except Exception as e:
                    logger.warning("[Code] 提取参数 %s 失败: %s", param_name, e)

            if extracted_data:
                logger.info(
                    "[Code] 成功提取 %s 个输出参数: %s", len(extracted_data), list(extracted_data.keys()))
                return extracted_data
            logger.debug("[Code] 未能提取到任何输出参数")
            return None

        except Exception:
            logger.exception("[Code] 处理sandbox结果失败")
            return None

    def _extract_value_by_path(self, data: dict, path: str) -> Any:
        """Walk a dotted path (e.g. 'result.data.value') through nested dicts; None if missing."""
        current: Any = data
        for key in path.split("."):
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                return None
        return current

    async def _wait_for_result(self, sandbox_url: str, task_id: str, max_attempts: int = 30) -> dict[str, Any]:
        """Poll the sandbox for task status; return the result payload or a timeout/error dict."""
        import asyncio

        base = sandbox_url.rstrip("/")
        async with httpx.AsyncClient(timeout=10.0) as client:
            for _ in range(max_attempts):
                try:
                    response = await client.get(f"{base}/task/{task_id}/status")
                    if response.status_code == httpx.codes.OK:
                        status_result = response.json()
                        logger.info("Task status response: %s", status_result)

                        if not status_result.get("success", False):
                            logger.warning(
                                "Failed to get task status: %s",
                                status_result.get("message", "获取任务状态失败"))
                            await asyncio.sleep(1)
                            continue

                        status = status_result.get("data", {}).get("status", "")

                        # Final state: fetch and return the result.
                        if status in ["completed", "failed", "cancelled"]:
                            result_response = await client.get(f"{base}/task/{task_id}/result")
                            if result_response.status_code != httpx.codes.OK:
                                return {"status": status, "error": "无法获取结果"}
                            result_data = result_response.json()
                            logger.info("Task result response: %s", result_data)
                            if result_data.get("success", False):
                                return result_data.get("data", {})
                            return {"status": status, "error": result_data.get("message", "获取结果失败")}

                        # Still in progress: wait and poll again.
                        if status in ["pending", "running"]:
                            logger.debug("Task %s still %s, waiting...", task_id, status)
                            await asyncio.sleep(1)
                            continue

                    # Unexpected status or HTTP failure: back off and retry.
                    await asyncio.sleep(1)
                except Exception:
                    # Best-effort polling: transient errors are retried until max_attempts.
                    logger.debug("Polling task %s failed, retrying", task_id, exc_info=True)
                    await asyncio.sleep(1)

        return {"status": "timeout", "error": "等待任务完成超时"}
\ No newline at end of file
diff --git a/apps/scheduler/call/code/schema.py b/apps/scheduler/call/code/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..93648a811a34983a81c9a644067115efcfa68ec5
--- /dev/null
+++ b/apps/scheduler/call/code/schema.py
@@ -0,0 +1,30 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""代码执行工具的数据结构"""
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from apps.scheduler.call.core import DataBase
+
+
class CodeInput(DataBase):
    """代码执行工具的输入"""

    code: str = Field(description="要执行的代码")
    code_type: str = Field(description="代码类型,支持python、javascript、bash")
    # default_factory instead of shared mutable {} defaults.
    user_info: dict[str, Any] = Field(description="用户信息", default_factory=dict)
    security_level: str = Field(description="安全等级,low或high", default="low")
    timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300)
    memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024)
    cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0)
    input_arg: dict[str, Any] = Field(description="传递给main函数的输入参数", default_factory=dict)
+
+
class CodeOutput(DataBase):
    """代码执行工具的输出"""

    # Identifier assigned by the sandbox for the submitted execution.
    task_id: str = Field(description="任务ID")
    # Final task state as reported by the sandbox.
    status: str = Field(description="任务状态")
    output: str = Field(default="", description="执行输出")
    error: str = Field(default="", description="错误信息")
\ No newline at end of file
diff --git a/apps/scheduler/call/convert/convert.py b/apps/scheduler/call/convert/convert.py
index 27980bd8ad46aaaa7741d6a6919661aac580de64..bbe0dbe80217ba5e24c629fd02279086e08230fc 100644
--- a/apps/scheduler/call/convert/convert.py
+++ b/apps/scheduler/call/convert/convert.py
@@ -12,7 +12,7 @@ from pydantic import Field
from apps.scheduler.call.convert.schema import ConvertInput, ConvertOutput
from apps.scheduler.call.core import CallOutputChunk, CoreCall
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.scheduler import (
CallInfo,
CallOutputChunk,
@@ -30,7 +30,11 @@ class Convert(CoreCall, input_model=ConvertInput, output_model=ConvertOutput):
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="模板转换", description="使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。")
+ return CallInfo(
+ name="模板转换",
+ type=CallType.TRANSFORM,
+ description="使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。"
+ )
async def _init(self, call_vars: CallVars) -> ConvertInput:
"""初始化工具"""
diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py
index 2b1cbba83b9345e8df84e8f23f893137cb46532b..f6ae08661a0ece5d1f50753e64ec7ad62c634be9 100644
--- a/apps/scheduler/call/core.py
+++ b/apps/scheduler/call/core.py
@@ -6,6 +6,7 @@ Core Call类是定义了所有Call都应具有的方法和参数的PyDantic类
"""
import logging
+import re
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, ClassVar, Self
@@ -14,6 +15,7 @@ from pydantic.json_schema import SkipJsonSchema
from apps.llm.function import FunctionLLM
from apps.llm.reasoning import ReasoningLLM
+from apps.scheduler.variable.integration import VariableIntegration
from apps.schemas.enum_var import CallOutputType
from apps.schemas.pool import NodePool
from apps.schemas.scheduler import (
@@ -25,6 +27,7 @@ from apps.schemas.scheduler import (
CallVars,
)
from apps.schemas.task import FlowStepHistory
+from enum import Enum
if TYPE_CHECKING:
from apps.scheduler.executor.step import StepExecutor
@@ -32,6 +35,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class NodeType(str, Enum):
+
+ """
+ 节点类型 用来为前后端判断节点类型
+ 仅增加了当前在修改的几个类型
+ """
+
+ LLM = "llm"
+ API = "api"
+ RAG = 'rag'
+ MCP = 'mcp'
+ CHOICE = "choice"
+ CODE = 'code'
+ DIRECTREPLY = "directreply"
+
+
+
class DataBase(BaseModel):
"""所有Call的输入基类"""
@@ -70,12 +90,15 @@ class CoreCall(BaseModel):
)
to_user: bool = Field(description="是否需要将输出返回给用户", default=False)
+ enable_variable_resolution: bool = Field(description="是否启用自动变量解析", default=True)
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="allow",
)
+ node_type: NodeType | None = Field(description="节点类型,只包含已改造后的节点", default=None)
+
def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None:
"""初始化子类"""
@@ -83,14 +106,12 @@ class CoreCall(BaseModel):
cls.input_model = input_model
cls.output_model = output_model
-
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
err = "[CoreCall] 必须手动实现info方法"
raise NotImplementedError(err)
-
@staticmethod
def _assemble_call_vars(executor: "StepExecutor") -> CallVars:
"""组装CallVars"""
@@ -111,6 +132,7 @@ class CoreCall(BaseModel):
task_id=executor.task.id,
flow_id=executor.task.state.flow_id,
session_id=executor.task.ids.session_id,
+ conversation_id=executor.task.ids.conversation_id,
user_sub=executor.task.ids.user_sub,
app_id=executor.task.state.app_id,
),
@@ -120,7 +142,6 @@ class CoreCall(BaseModel):
summary=executor.task.runtime.summary,
)
-
@staticmethod
def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any:
"""
@@ -131,18 +152,16 @@ class CoreCall(BaseModel):
:return: 变量
"""
split_path = path.split("/")
+ if len(split_path) < 2:
+ err = f"[CoreCall] 路径格式错误: {path}"
+ logger.error(err)
+ return None
if split_path[0] not in history:
err = f"[CoreCall] 步骤{split_path[0]}不存在"
logger.error(err)
- raise CallError(
- message=err,
- data={
- "step_id": split_path[0],
- },
- )
-
+ return None
data = history[split_path[0]].output_data
- for key in split_path[1:]:
+    for key in split_path[1:]:
if key not in data:
err = f"[CoreCall] 输出Key {key} 不存在"
logger.error(err)
@@ -156,13 +175,27 @@ class CoreCall(BaseModel):
data = data[key]
return data
-
@classmethod
async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
"""实例化Call类"""
+ if 'name' in kwargs:
+ # 如果kwargs中已经有name参数,使用kwargs中的值
+ name = kwargs.pop('name')
+ else:
+ # 否则使用executor.step.step.name
+ name = executor.step.step.name
+
+ # 检查kwargs中是否已经包含description参数,避免重复传入
+ if 'description' in kwargs:
+ # 如果kwargs中已经有description参数,使用kwargs中的值
+ description = kwargs.pop('description')
+ else:
+ # 否则使用executor.step.step.description
+ description = executor.step.step.description
+
obj = cls(
- name=executor.step.step.name,
- description=executor.step.step.description,
+ name=name,
+ description=description,
node=node,
**kwargs,
)
@@ -170,35 +203,206 @@ class CoreCall(BaseModel):
await obj._set_input(executor)
return obj
+ async def _initialize_variable_context(self, call_vars: CallVars) -> dict[str, Any]:
+ """初始化变量解析上下文并初始化系统变量"""
+ context = {
+ "question": call_vars.question,
+ "user_sub": call_vars.ids.user_sub,
+ "flow_id": call_vars.ids.flow_id,
+ "session_id": call_vars.ids.session_id,
+ "app_id": call_vars.ids.app_id,
+ "conversation_id": call_vars.ids.conversation_id,
+ }
+
+ await VariableIntegration.initialize_system_variables(context)
+ return context
+
+ async def _resolve_variables_in_config(self, config: Any, call_vars: CallVars) -> Any:
+ """解析配置中的变量引用
+
+ Args:
+ config: 配置值,可能包含变量引用
+ call_vars: Call变量
+
+ Returns:
+ 解析后的配置值
+ """
+ if isinstance(config, dict):
+ if "reference" in config:
+ # 解析变量引用
+ resolved_value = await VariableIntegration.resolve_variable_reference(
+ config["reference"],
+ user_sub=call_vars.ids.user_sub,
+ flow_id=call_vars.ids.flow_id,
+ conversation_id=call_vars.ids.conversation_id
+ )
+ return resolved_value
+ elif "value" in config:
+ # 使用默认值
+ return config["value"]
+ else:
+ # 递归解析字典中的所有值
+ resolved_dict = {}
+ for key, value in config.items():
+ resolved_dict[key] = await self._resolve_variables_in_config(value, call_vars)
+ return resolved_dict
+ elif isinstance(config, list):
+ # 递归解析列表中的所有值
+ resolved_list = []
+ for item in config:
+ resolved_item = await self._resolve_variables_in_config(item, call_vars)
+ resolved_list.append(resolved_item)
+ return resolved_list
+ elif isinstance(config, str):
+ # 解析字符串中的变量引用
+ return await self._resolve_variables_in_text(config, call_vars)
+ else:
+ # 直接返回配置值
+ return config
+
+ async def _resolve_variables_in_text(self, text: str, call_vars: CallVars) -> str:
+ """解析文本中的变量引用({{...}} 语法)
+
+ Args:
+ text: 包含变量引用的文本
+ call_vars: Call变量
+
+ Returns:
+ 解析后的文本
+ """
+ if not isinstance(text, str):
+ return text
+
+ # 检查是否包含变量引用语法
+ if not re.search(r'\{\{.*?\}\}', text):
+ return text
+
+ # 提取所有变量引用并逐一解析替换
+ variable_pattern = r'\{\{(.*?)\}\}'
+ matches = re.findall(variable_pattern, text)
+
+ resolved_text = text
+ for match in matches:
+ try:
+ # 解析变量引用
+ resolved_value = await VariableIntegration.resolve_variable_reference(
+ match.strip(),
+ user_sub=call_vars.ids.user_sub,
+ flow_id=call_vars.ids.flow_id,
+ conversation_id=call_vars.ids.conversation_id
+ )
+ # 替换原始文本中的变量引用
+ resolved_text = resolved_text.replace(f'{{{{{match}}}}}', str(resolved_value))
+ except Exception as e:
+ logger.warning(f"[CoreCall] 解析变量引用 '{match}' 失败: {e}")
+ # 如果解析失败,保留原始的变量引用
+ continue
+
+ return resolved_text
+
async def _set_input(self, executor: "StepExecutor") -> None:
"""获取Call的输入"""
self._sys_vars = self._assemble_call_vars(executor)
+ self._step_id = executor.step.step_id # 存储 step_id 用于变量名构造
+
+ # 如果启用了变量解析,初始化变量上下文
+ if self.enable_variable_resolution:
+ await self._initialize_variable_context(self._sys_vars)
+
input_data = await self._init(self._sys_vars)
self.input = input_data.model_dump(by_alias=True, exclude_none=True)
-
async def _init(self, call_vars: CallVars) -> DataBase:
"""初始化Call类,并返回Call的输入"""
err = "[CoreCall] 初始化方法必须手动实现"
raise NotImplementedError(err)
-
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""Call类实例的流式输出方法"""
yield CallOutputChunk(type=CallOutputType.TEXT, content="")
-
async def _after_exec(self, input_data: dict[str, Any]) -> None:
"""Call类实例的执行后方法"""
+ # 自动保存 extracted_data 到对话变量池
+ await self._save_extracted_data_to_variables()
+
+ async def _save_extracted_data_to_variables(self) -> None:
+ """将 extracted_data 保存到对话变量池"""
+ try:
+ # 检查是否有 output_parameters 配置
+ if not hasattr(self, 'output_parameters') or not self.output_parameters:
+ return
+
+ # 检查是否有输出数据
+ if not hasattr(self, 'input') or not self.input:
+ return
+
+ # 获取当前的输出数据(从最近的执行结果中)
+ # 注意:这里假设 extracted_data 已经被合并到输出中
+ output_data = getattr(self, '_last_output_data', {})
+ if not output_data or not isinstance(output_data, dict):
+ return
+
+ # 确定变量名前缀
+ from apps.scheduler.executor.step_config import should_use_direct_conversation_format
+ use_direct_format = should_use_direct_conversation_format(
+ call_id=getattr(self, '__class__').__name__.lower(),
+ step_name=self.name,
+ step_id=getattr(self, '_step_id', self.name)
+ )
+
+            var_prefix = "" if use_direct_format else f"{getattr(self, '_step_id', self.name)}."
+
+ # 保存每个 output_parameter 到变量池
+ saved_count = 0
+ for param_name, param_config in self.output_parameters.items():
+ try:
+ # 检查输出数据中是否包含该参数
+ if param_name in output_data:
+ param_value = output_data[param_name]
+
+ # 构造变量名
+ var_name = f"{var_prefix}{param_name}"
+
+ # 保存到对话变量池
+ success = await VariableIntegration.save_conversation_variable(
+ var_name=var_name,
+ value=param_value,
+ var_type=param_config.get("type", "string") if isinstance(param_config, dict) else "string",
+ description=param_config.get("description", "") if isinstance(param_config, dict) else "",
+ user_sub=self._sys_vars.ids.user_sub,
+ flow_id=self._sys_vars.ids.flow_id,
+ conversation_id=self._sys_vars.ids.conversation_id
+ )
+
+ if success:
+ saved_count += 1
+ logger.debug(f"[CoreCall] 已保存提取数据到变量池: conversation.{var_name} = {param_value}")
+ else:
+ logger.warning(f"[CoreCall] 保存提取数据变量失败: {var_name}")
+
+ except Exception as e:
+ logger.warning(f"[CoreCall] 保存提取数据 {param_name} 失败: {e}")
+
+ if saved_count > 0:
+ logger.info(f"[CoreCall] 已保存 {saved_count} 个提取数据到对话变量池")
+
+ except Exception as e:
+ logger.error(f"[CoreCall] 保存提取数据到变量池失败: {e}")
async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""Call类实例的执行方法"""
+ self._last_output_data = {} # 初始化输出数据存储
+
async for chunk in self._exec(input_data):
+ # 捕获最后的输出数据
+ if chunk.type == CallOutputType.DATA and isinstance(chunk.content, dict):
+ self._last_output_data = chunk.content
yield chunk
- await self._after_exec(input_data)
+ await self._after_exec(input_data)
async def _llm(self, messages: list[dict[str, Any]]) -> str:
"""Call可直接使用的LLM非流式调用"""
@@ -210,9 +414,11 @@ class CoreCall(BaseModel):
self.output_tokens = llm.output_tokens
return result
-
async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel:
"""Call可直接使用的JSON生成"""
json = FunctionLLM()
result = await json.call(messages=messages, schema=schema.model_json_schema())
return schema.model_validate(result)
+
+
+
diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py
index 5865bc7e804491a7d6aa41fa251dc8c5d9c77dc6..a66aac5251ab580942c583576e9ec9bba3204951 100644
--- a/apps/scheduler/call/empty.py
+++ b/apps/scheduler/call/empty.py
@@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator
from typing import Any
from apps.scheduler.call.core import CoreCall, DataBase
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
@@ -20,7 +20,11 @@ class Empty(CoreCall, input_model=DataBase, output_model=DataBase):
:return: Call的名称和描述
:rtype: CallInfo
"""
- return CallInfo(name="空白", description="空白节点,用于占位")
+ return CallInfo(
+ name="空白",
+ type=CallType.DEFAULT,
+ description="空白节点,用于占位"
+ )
async def _init(self, call_vars: CallVars) -> DataBase:
diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py
index f8aebcd748d92de4109553d0f68fecac43363f58..10241d85a97d3f2ea70a8a93a1a5a72173e4cfdd 100644
--- a/apps/scheduler/call/facts/facts.py
+++ b/apps/scheduler/call/facts/facts.py
@@ -16,7 +16,7 @@ from apps.scheduler.call.facts.schema import (
FactsInput,
FactsOutput,
)
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.pool import NodePool
from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
from apps.services.user_domain import UserDomainManager
@@ -34,7 +34,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。")
+ return CallInfo(
+ name="提取事实",
+ type=CallType.DEFAULT,
+ description="从对话上下文和文档片段中提取事实。"
+ )
@classmethod
diff --git a/apps/scheduler/call/graph/graph.py b/apps/scheduler/call/graph/graph.py
index c2728f17913fcd0e8343168f2b508dbe6006fd6e..7383b2f2c85aacf78db11b00f95f22c2996f071f 100644
--- a/apps/scheduler/call/graph/graph.py
+++ b/apps/scheduler/call/graph/graph.py
@@ -11,7 +11,7 @@ from pydantic import Field
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput
from apps.scheduler.call.graph.style import RenderStyle
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.scheduler import (
CallError,
CallInfo,
@@ -29,7 +29,11 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput):
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="图表", description="将SQL查询出的数据转换为图表")
+ return CallInfo(
+ name="图表",
+ type=CallType.TRANSFORM,
+ description="将SQL查询出的数据转换为图表"
+ )
async def _init(self, call_vars: CallVars) -> RenderInput:
diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py
index 6a679dce98af6164211edd16b2fa38714d899f31..23098d9dac841008cb295400863364f727aada06 100644
--- a/apps/scheduler/call/llm/llm.py
+++ b/apps/scheduler/call/llm/llm.py
@@ -12,20 +12,25 @@ from jinja2.sandbox import SandboxedEnvironment
from pydantic import Field
from apps.llm.reasoning import ReasoningLLM
-from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.core import CoreCall, NodeType
from apps.scheduler.call.llm.prompt import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT
from apps.scheduler.call.llm.schema import LLMInput, LLMOutput
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.scheduler import (
CallError,
CallInfo,
CallOutputChunk,
CallVars,
)
+from apps.schemas.task import FlowStepHistory
+from apps.services.llm import LLMManager
+from apps.schemas.config import LLMConfig
+
logger = logging.getLogger(__name__)
+
class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput):
"""大模型调用工具"""
@@ -37,12 +42,20 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput):
step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3, ge=1, le=10)
system_prompt: str = Field(description="大模型系统提示词", default="You are a helpful assistant.")
user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT)
+ frequency_penalty: float = Field(description="频率惩罚大模型生成文本时用于减少重复词汇出现的参数", default=0)
+ llmId: str = Field(description="指定使用的大模型ID", default="empty")
+ # 增加node_type
+ node_type: NodeType | None = Field(description="节点类型", default=NodeType.LLM)
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="大模型", description="以指定的提示词和上下文信息调用大模型,并获得输出。")
+ return CallInfo(
+ name="大模型",
+ type=CallType.DEFAULT,
+ description="以指定的提示词和上下文信息调用大模型,并获得输出。"
+ )
async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]:
@@ -80,7 +93,7 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput):
try:
# 准备系统提示词
system_tmpl = env.from_string(self.system_prompt)
- system_input = system_tmpl.render(**formatter)
+ system_input = system_tmpl.render()
# 准备用户提示词
user_tmpl = env.from_string(self.user_prompt)
@@ -106,6 +119,8 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput):
data = LLMInput(**input_data)
try:
llm = ReasoningLLM()
+
+
async for chunk in llm.call(messages=data.message):
if not chunk:
continue
@@ -113,4 +128,4 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput):
self.tokens.input_tokens = llm.input_tokens
self.tokens.output_tokens = llm.output_tokens
except Exception as e:
- raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e
+ raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e
\ No newline at end of file
diff --git a/apps/scheduler/call/llm/prompt.py b/apps/scheduler/call/llm/prompt.py
index 0f227dcaa618b11a2c888f55a61fa51f349d7d8b..5701f7b3ff3e7783ae1e768c5003247c2fc42030 100644
--- a/apps/scheduler/call/llm/prompt.py
+++ b/apps/scheduler/call/llm/prompt.py
@@ -72,12 +72,12 @@ LLM_ERROR_PROMPT = dedent(
RAG_ANSWER_PROMPT = dedent(
r"""
- 你是由openEuler社区构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。
+ 你是由华鲲振宇构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。
用户的问题将在中给出,上下文背景信息将在中给出,文档片段将在中给出。
注意事项:
1. 输出不要包含任何XML标签。请确保输出内容的正确性,不要编造任何信息。
- 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫EulerCopilot,是openEuler社区的智能助手”。
+ 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫huakun-copilot,是华鲲振宇的智能助手”。
3. 背景信息仅供参考,若背景信息与用户问题无关,请忽略背景信息直接作答。
4. 请在回答中使用Markdown格式,并**不要**将内容放在"```"中。
diff --git a/apps/scheduler/call/llm/schema.py b/apps/scheduler/call/llm/schema.py
index c7bb50541d168a406fa478f25894c769b034ec9c..b418f96ad2652c87bad1542547ac83d8783a4d19 100644
--- a/apps/scheduler/call/llm/schema.py
+++ b/apps/scheduler/call/llm/schema.py
@@ -12,5 +12,8 @@ class LLMInput(DataBase):
message: list[dict[str, str]] = Field(description="输入给大模型的消息列表")
+
+
class LLMOutput(DataBase):
"""定义LLM工具调用的输出"""
+ text: str = Field(description="大模型输出")
diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py
index 661e9ada74da76e6d78a0f6cba42cef08c562335..d57102fc169e602fbf23136f664ee235d01e0034 100644
--- a/apps/scheduler/call/mcp/mcp.py
+++ b/apps/scheduler/call/mcp/mcp.py
@@ -8,7 +8,7 @@ from typing import Any
from pydantic import Field
-from apps.scheduler.call.core import CallError, CoreCall
+from apps.scheduler.call.core import CallError, CoreCall, NodeType
from apps.scheduler.call.mcp.schema import (
MCPInput,
MCPMessage,
@@ -16,7 +16,7 @@ from apps.scheduler.call.mcp.schema import (
MCPOutput,
)
from apps.scheduler.mcp import MCPHost, MCPPlanner, MCPSelector
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.mcp import MCPPlanItem
from apps.schemas.scheduler import (
CallInfo,
@@ -34,6 +34,9 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
max_steps: int = Field(description="最大步骤数", default=6)
text_output: bool = Field(description="是否将结果以文本形式返回", default=True)
to_user: bool = Field(description="是否将结果返回给用户", default=True)
+ # 增加node_type
+ node_type: NodeType | None = Field(description="节点类型", default=NodeType.MCP)
+
@classmethod
@@ -44,7 +47,11 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
:return: Call的名称和描述
:rtype: CallInfo
"""
- return CallInfo(name="MCP", description="调用MCP Server,执行工具")
+ return CallInfo(
+ name="MCP",
+ type=CallType.DEFAULT,
+ description="调用MCP Server,执行工具"
+ )
async def _init(self, call_vars: CallVars) -> MCPInput:
@@ -63,7 +70,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps)
-
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""执行MCP"""
# 生成计划
@@ -80,7 +86,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
async for chunk in self._generate_answer():
yield chunk
-
async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]:
"""生成执行计划"""
# 开始提示
@@ -89,6 +94,7 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
# 选择工具并生成计划
selector = MCPSelector()
top_tool = await selector.select_top_tool(self._call_vars.question, self.mcp_list)
+        logger.info("[MCPSelector] 选择到的工具: %s", top_tool)
planner = MCPPlanner(self._call_vars.question)
self._plan = await planner.create_plan(top_tool, self.max_steps)
@@ -103,7 +109,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
data=self._plan.model_dump(),
)
-
async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]:
"""执行单个计划项"""
# 判断是否为Final
@@ -141,7 +146,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
},
)
-
async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]:
"""生成总结"""
# 提示开始总结
@@ -163,7 +167,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
).model_dump(),
)
-
def _create_output(
self,
text: str,
diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py
index e27327d8ad4d01387eeeb4f1644c056d32bc0dbe..ad9c52d46c98da31c07f94bda8cbabcf7c865ad3 100644
--- a/apps/scheduler/call/rag/rag.py
+++ b/apps/scheduler/call/rag/rag.py
@@ -11,9 +11,9 @@ from pydantic import Field
from apps.common.config import Config
from apps.llm.patterns.rewrite import QuestionRewrite
-from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.core import CoreCall, NodeType
from apps.scheduler.call.rag.schema import RAGInput, RAGOutput, SearchMethod
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.scheduler import (
CallError,
CallInfo,
@@ -36,11 +36,18 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput):
is_rerank: bool = Field(description="是否重新排序", default=False)
is_compress: bool = Field(description="是否压缩", default=False)
tokens_limit: int = Field(description="token限制", default=8192)
+ # 增加node_type
+ node_type: NodeType | None = Field(description="节点类型", default=NodeType.RAG)
+
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="知识库", description="查询知识库,从文档中获取必要信息")
+ return CallInfo(
+ name="知识库",
+ type=CallType.DEFAULT,
+ description="查询知识库,从文档中获取必要信息"
+ )
async def _init(self, call_vars: CallVars) -> RAGInput:
"""初始化RAG工具"""
diff --git a/apps/scheduler/call/reply/__init__.py b/apps/scheduler/call/reply/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5588bcc298b338a542312a22843494eb8fe1d890
--- /dev/null
+++ b/apps/scheduler/call/reply/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Reply工具模块"""
+
+from apps.scheduler.call.reply.direct_reply import DirectReply
+
+__all__ = [
+ "DirectReply",
+]
diff --git a/apps/scheduler/call/reply/direct_reply.py b/apps/scheduler/call/reply/direct_reply.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c06911021cfc7556e995d485fd7ad96d3847130
--- /dev/null
+++ b/apps/scheduler/call/reply/direct_reply.py
@@ -0,0 +1,67 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""直接回复工具"""
+
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
+
+from pydantic import Field
+
+from apps.scheduler.call.core import CoreCall, NodeType
+from apps.scheduler.call.reply.schema import DirectReplyInput, DirectReplyOutput
+from apps.schemas.enum_var import CallOutputType, CallType
+from apps.schemas.scheduler import (
+ CallError,
+ CallInfo,
+ CallOutputChunk,
+ CallVars,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectReplyOutput):
+ """直接回复工具,支持变量引用语法"""
+
+ to_user: bool = Field(default=True)
+ # 增加node_type
+ node_type: NodeType | None = Field(description="节点类型", default=NodeType.DIRECTREPLY)
+
+
+ @classmethod
+ def info(cls) -> CallInfo:
+ """返回Call的名称和描述"""
+ return CallInfo(
+ name="直接回复",
+ type=CallType.DEFAULT,
+ description="直接回复用户输入的内容,支持变量插入"
+ )
+
+ async def _init(self, call_vars: CallVars) -> DirectReplyInput:
+ """初始化DirectReply工具"""
+ answer = getattr(self, 'answer', '')
+ return DirectReplyInput(answer=answer)
+
+ async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """执行直接回复"""
+ data = DirectReplyInput(**input_data)
+
+ try:
+ # 使用基类的变量解析功能处理文本中的变量引用
+ final_answer = await self._resolve_variables_in_text(data.answer, self._sys_vars)
+
+ logger.info(f"[DirectReply] 原始答案: {data.answer}")
+ logger.info(f"[DirectReply] 解析后答案: {final_answer}")
+
+ # 直接返回处理后的内容
+ yield CallOutputChunk(
+ type=CallOutputType.TEXT,
+ content=final_answer
+ )
+
+ except Exception as e:
+ logger.error(f"[DirectReply] 处理回复内容失败: {e}")
+ raise CallError(
+ message=f"直接回复处理失败:{e!s}",
+ data={"original_answer": data.answer}
+ ) from e
diff --git a/apps/scheduler/call/reply/schema.py b/apps/scheduler/call/reply/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..f13dc636284dc6da764aa73fc62e99405674febe
--- /dev/null
+++ b/apps/scheduler/call/reply/schema.py
@@ -0,0 +1,18 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""DirectReply工具的输入输出定义"""
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class DirectReplyInput(DataBase):
+ """定义DirectReply工具调用的输入"""
+
+ answer: str = Field(description="直接回复的内容,支持变量引用语法")
+
+
+class DirectReplyOutput(DataBase):
+ """定义DirectReply工具调用的输出"""
+
+ message: str = Field(description="处理后的回复消息")
\ No newline at end of file
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index 4f8e1010cc0bd88f22e050778bce236d7d3515e0..69c2bfbcd94c04adab5396a9f27763633715e182 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -15,7 +15,7 @@ from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.slot.prompt import SLOT_GEN_PROMPT
from apps.scheduler.call.slot.schema import SlotInput, SlotOutput
from apps.scheduler.slot.slot import Slot as SlotProcessor
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.pool import NodePool
from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
@@ -36,7 +36,11 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="参数自动填充", description="根据步骤历史,自动填充参数")
+ return CallInfo(
+ name="参数自动填充",
+ type=CallType.TRANSFORM,
+ description="根据步骤历史,自动填充参数"
+ )
async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]:
diff --git a/apps/scheduler/call/sql/sql.py b/apps/scheduler/call/sql/sql.py
index 3e24301de508e06adf5cfdbf24b3d8ca37c0cc27..2f4ef97fb095baeac0b89ceb2dff1c98338fed18 100644
--- a/apps/scheduler/call/sql/sql.py
+++ b/apps/scheduler/call/sql/sql.py
@@ -12,7 +12,7 @@ from pydantic import Field
from apps.common.config import Config
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.sql.schema import SQLInput, SQLOutput
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.scheduler import (
CallError,
CallInfo,
@@ -35,7 +35,11 @@ class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput):
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="SQL查询", description="使用大模型生成SQL语句,用于查询数据库中的结构化数据")
+ return CallInfo(
+ name="SQL查询",
+ type=CallType.TOOL,
+ description="使用大模型生成SQL语句,用于查询数据库中的结构化数据"
+ )
async def _init(self, call_vars: CallVars) -> SQLInput:
diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py
index 1788fa0f4a8ede3af38264c9bb4a82628018086b..663e1d9c95d3faae2b51d75de0a7bab3c54ca66c 100644
--- a/apps/scheduler/call/suggest/suggest.py
+++ b/apps/scheduler/call/suggest/suggest.py
@@ -20,7 +20,7 @@ from apps.scheduler.call.suggest.schema import (
SuggestionInput,
SuggestionOutput,
)
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.pool import NodePool
from apps.schemas.record import RecordContent
from apps.schemas.scheduler import (
@@ -50,7 +50,11 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="问题推荐", description="在答案下方显示推荐的下一个问题")
+ return CallInfo(
+ name="问题推荐",
+ type=CallType.DEFAULT,
+ description="在答案下方显示推荐的下一个问题"
+ )
@classmethod
diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py
index b605204e179246f915d561c0884fd277712faf33..7f6ff062d522efa25af8a5fdc5ceec38d7ba967d 100644
--- a/apps/scheduler/call/summary/summary.py
+++ b/apps/scheduler/call/summary/summary.py
@@ -9,7 +9,7 @@ from pydantic import Field
from apps.llm.patterns.executor import ExecutorSummary
from apps.scheduler.call.core import CoreCall, DataBase
from apps.scheduler.call.summary.schema import SummaryOutput
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.pool import NodePool
from apps.schemas.scheduler import (
CallInfo,
@@ -31,7 +31,11 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput):
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
- return CallInfo(name="理解上下文", description="使用大模型,理解对话上下文")
+ return CallInfo(
+ name="理解上下文",
+ type=CallType.DEFAULT,
+ description="使用大模型,理解对话上下文"
+ )
@classmethod
async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py
index f6814dd364636d304382dc258f87fdd9f06eb9a8..2ff4f3d39633b26d792b0b2c2a257107b4874d30 100644
--- a/apps/scheduler/executor/agent.py
+++ b/apps/scheduler/executor/agent.py
@@ -7,6 +7,8 @@ from pydantic import Field
from apps.scheduler.executor.base import BaseExecutor
from apps.scheduler.mcp_agent.agent.mcp import MCPAgent
+from apps.schemas.task import ExecutorState, StepQueueItem
+from apps.services.task import TaskManager
logger = logging.getLogger(__name__)
@@ -15,26 +17,14 @@ class MCPAgentExecutor(BaseExecutor):
"""MCP Agent执行器"""
question: str = Field(description="用户输入")
- max_steps: int = Field(default=10, description="最大步数")
+ max_steps: int = Field(default=20, description="最大步数")
servers_id: list[str] = Field(description="MCP server id")
agent_id: str = Field(default="", description="Agent ID")
agent_description: str = Field(default="", description="Agent描述")
- async def run(self) -> None:
- """运行MCP Agent"""
- agent = await MCPAgent.create(
- servers_id=self.servers_id,
- max_steps=self.max_steps,
- task=self.task,
- msg_queue=self.msg_queue,
- question=self.question,
- agent_id=self.agent_id,
- description=self.agent_description,
- )
-
- try:
- answer = await agent.run(self.question)
- self.task = agent.task
- self.task.runtime.answer = answer
- except Exception as e:
- logger.error(f"Error: {str(e)}")
+ async def load_state(self) -> None:
+ """从数据库中加载FlowExecutor的状态"""
+ logger.info("[FlowExecutor] 加载Executor状态")
+ # 尝试恢复State
+ if self.task.state:
+ self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index a70d0d7073c50ec2290c3fd4c797bbb3efcd4501..a86ec4ac063c3cda5c6444ae17d7af5a6f6e7f46 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -47,7 +47,6 @@ class FlowExecutor(BaseExecutor):
question: str = Field(description="用户输入")
post_body_app: RequestDataApp = Field(description="请求体中的app信息")
-
async def load_state(self) -> None:
"""从数据库中加载FlowExecutor的状态"""
logger.info("[FlowExecutor] 加载Executor状态")
@@ -70,7 +69,6 @@ class FlowExecutor(BaseExecutor):
self._reached_end: bool = False
self.step_queue: deque[StepQueueItem] = deque()
-
async def _invoke_runner(self, queue_item: StepQueueItem) -> None:
"""单一Step执行"""
# 创建步骤Runner
@@ -90,7 +88,6 @@ class FlowExecutor(BaseExecutor):
# 更新Task(已存过库)
self.task = step_runner.task
-
async def _step_process(self) -> None:
"""执行当前queue里面的所有步骤(在用户看来是单一Step)"""
while True:
@@ -102,7 +99,6 @@ class FlowExecutor(BaseExecutor):
# 执行Step
await self._invoke_runner(queue_item)
-
async def _find_next_id(self, step_id: str) -> list[str]:
"""查找下一个节点"""
next_ids = []
@@ -111,14 +107,22 @@ class FlowExecutor(BaseExecutor):
next_ids += [edge.edge_to]
return next_ids
-
async def _find_flow_next(self) -> list[StepQueueItem]:
"""在当前步骤执行前,尝试获取下一步"""
# 如果当前步骤为结束,则直接返回
- if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type]
+ if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type]
return []
-
- next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type]
+ if self.task.state.step_name == "Choice":
+ # 如果是choice节点,获取分支ID
+ branch_id = self.task.context[-1]["output_data"]["branch_id"]
+ if branch_id:
+ self.task.state.step_id = self.task.state.step_id + "." + branch_id
+ logger.info("[FlowExecutor] 分支ID:%s", branch_id)
+ else:
+ logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表")
+ return []
+
+ next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type]
# 如果step没有任何出边,直接跳到end
if not next_steps:
return [
@@ -137,7 +141,6 @@ class FlowExecutor(BaseExecutor):
for next_step in next_steps
]
-
async def run(self) -> None:
"""
运行流,返回各步骤结果,直到无法继续执行
@@ -150,8 +153,8 @@ class FlowExecutor(BaseExecutor):
# 获取首个步骤
first_step = StepQueueItem(
- step_id=self.task.state.step_id, # type: ignore[arg-type]
- step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type]
+ step_id=self.task.state.step_id, # type: ignore[arg-type]
+ step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type]
)
# 头插开始前的系统步骤,并执行
@@ -170,7 +173,7 @@ class FlowExecutor(BaseExecutor):
# 运行Flow(未达终点)
while not self._reached_end:
# 如果当前步骤出错,执行错误处理步骤
- if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type]
+ if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type]
logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤")
self.step_queue.clear()
self.step_queue.appendleft(StepQueueItem(
@@ -183,7 +186,7 @@ class FlowExecutor(BaseExecutor):
params={
"user_prompt": LLM_ERROR_PROMPT.replace(
"{{ error_info }}",
- self.task.state.error_info["err_msg"], # type: ignore[arg-type]
+ self.task.state.error_info["err_msg"], # type: ignore[arg-type]
),
},
),
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
index 6b3451fa9ccd92b8f9b2d497f3983a64ff3a6981..5bf37392a4f599e464b00b10d47a3f622fe9289c 100644
--- a/apps/scheduler/executor/step.py
+++ b/apps/scheduler/executor/step.py
@@ -13,11 +13,14 @@ from pydantic import ConfigDict
from apps.scheduler.call.core import CoreCall
from apps.scheduler.call.empty import Empty
from apps.scheduler.call.facts.facts import FactsCall
+from apps.scheduler.call.reply.direct_reply import DirectReply
from apps.scheduler.call.slot.schema import SlotOutput
from apps.scheduler.call.slot.slot import Slot
from apps.scheduler.call.summary.summary import Summary
from apps.scheduler.executor.base import BaseExecutor
+from apps.scheduler.executor.step_config import should_use_direct_conversation_format
from apps.scheduler.pool.pool import Pool
+from apps.scheduler.variable.integration import VariableIntegration
from apps.schemas.enum_var import (
EventType,
SpecialCallType,
@@ -70,6 +73,8 @@ class StepExecutor(BaseExecutor):
return FactsCall
if call_id == SpecialCallType.SLOT.value:
return Slot
+ if call_id == SpecialCallType.DIRECT_REPLY.value:
+ return DirectReply
# 从Pool中获取对应的Call
call_cls: type[CoreCall] = await Pool().get_call(call_id)
@@ -86,8 +91,8 @@ class StepExecutor(BaseExecutor):
logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name)
# State写入ID和运行状态
- self.task.state.step_id = self.step.step_id # type: ignore[arg-type]
- self.task.state.step_name = self.step.step.name # type: ignore[arg-type]
+ self.task.state.step_id = self.step.step_id # type: ignore[arg-type]
+ self.task.state.step_name = self.step.step.name # type: ignore[arg-type]
# 获取并验证Call类
node_id = self.step.step.node
@@ -119,7 +124,6 @@ class StepExecutor(BaseExecutor):
logger.exception("[StepExecutor] 初始化Call失败")
raise
-
async def _run_slot_filling(self) -> None:
"""运行自动参数填充;相当于特殊Step,但是不存库"""
# 判断是否需要进行自动参数填充
@@ -127,13 +131,13 @@ class StepExecutor(BaseExecutor):
return
# 暂存旧数据
- current_step_id = self.task.state.step_id # type: ignore[arg-type]
- current_step_name = self.task.state.step_name # type: ignore[arg-type]
+ current_step_id = self.task.state.step_id # type: ignore[arg-type]
+ current_step_name = self.task.state.step_name # type: ignore[arg-type]
# 更新State
- self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type]
- self.task.state.step_name = "自动参数填充" # type: ignore[arg-type]
- self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type]
+ self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type]
+ self.task.state.step_name = "自动参数填充" # type: ignore[arg-type]
+ self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type]
self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2)
# 初始化填参
@@ -156,21 +160,20 @@ class StepExecutor(BaseExecutor):
# 如果没有填全,则状态设置为待填参
if result.remaining_schema:
- self.task.state.status = StepStatus.PARAM # type: ignore[arg-type]
+ self.task.state.status = StepStatus.PARAM # type: ignore[arg-type]
else:
- self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type]
+ self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type]
await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True))
# 更新输入
self.obj.input.update(result.slot_data)
# 恢复State
- self.task.state.step_id = current_step_id # type: ignore[arg-type]
- self.task.state.step_name = current_step_name # type: ignore[arg-type]
+ self.task.state.step_id = current_step_id # type: ignore[arg-type]
+ self.task.state.step_name = current_step_name # type: ignore[arg-type]
self.task.tokens.input_tokens += self.obj.tokens.input_tokens
self.task.tokens.output_tokens += self.obj.tokens.output_tokens
-
async def _process_chunk(
self,
iterator: AsyncGenerator[CallOutputChunk, None],
@@ -203,43 +206,149 @@ class StepExecutor(BaseExecutor):
return content
+ async def _save_output_parameters_to_variables(self, output_data: str | dict[str, Any]) -> None:
+ """保存节点输出参数到变量池"""
+ try:
+ # 检查是否有output_parameters配置
+ output_parameters = None
+ if self.step.step.params and isinstance(self.step.step.params, dict):
+ output_parameters = self.step.step.params.get("output_parameters", {})
+
+ if not output_parameters or not isinstance(output_parameters, dict):
+ return
+
+ # 确保output_data是字典格式
+ if isinstance(output_data, str):
+ # 如果是字符串,包装成字典
+ data_dict = {"text": output_data}
+ else:
+ data_dict = output_data if isinstance(output_data, dict) else {}
+
+ # 确定变量名前缀(根据配置决定是否使用直接格式)
+ use_direct_format = should_use_direct_conversation_format(
+ call_id=self._call_id,
+ step_name=self.step.step.name,
+ step_id=self.step.step_id
+ )
+
+ if use_direct_format:
+ # 配置允许的节点类型保持原有格式:conversation.key
+ var_prefix = ""
+ logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用直接变量格式")
+ else:
+ # 其他节点使用格式:conversation.node_id.key
+ var_prefix = f"{self.step.step_id}."
+ logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用带前缀变量格式")
+
+ # 保存每个output_parameter到变量池
+ saved_count = 0
+ for param_name, param_config in output_parameters.items():
+ try:
+ # 获取参数值
+ param_value = self._extract_value_from_output_data(param_name, data_dict, param_config)
+
+ if param_value is not None:
+ # 构造变量名
+ var_name = f"{var_prefix}{param_name}"
+
+ # 保存到对话变量池
+ success = await VariableIntegration.save_conversation_variable(
+ var_name=var_name,
+ value=param_value,
+ var_type=param_config.get("type", "string"),
+ description=param_config.get("description", ""),
+ user_sub=self.task.ids.user_sub,
+ flow_id=self.task.state.flow_id, # type: ignore[arg-type]
+ conversation_id=self.task.ids.conversation_id
+ )
+
+ if success:
+ saved_count += 1
+ logger.debug(f"[StepExecutor] 已保存输出参数变量: conversation.{var_name} = {param_value}")
+ else:
+ logger.warning(f"[StepExecutor] 保存输出参数变量失败: {var_name}")
+
+ except Exception as e:
+ logger.warning(f"[StepExecutor] 保存输出参数 {param_name} 失败: {e}")
+
+ if saved_count > 0:
+ logger.info(f"[StepExecutor] 已保存 {saved_count} 个输出参数到变量池")
+
+ except Exception as e:
+ logger.error(f"[StepExecutor] 保存输出参数到变量池失败: {e}")
+
+ def _extract_value_from_output_data(self, param_name: str, output_data: dict[str, Any], param_config: dict) -> Any:
+ """从输出数据中提取参数值"""
+ # 支持多种提取方式
+
+ # 1. 直接从输出数据中获取同名key
+ if param_name in output_data:
+ return output_data[param_name]
+
+ # 2. 支持路径提取(例如:result.data.value)
+ if "path" in param_config:
+ path = param_config["path"]
+ current_data = output_data
+ for key in path.split("."):
+ if isinstance(current_data, dict) and key in current_data:
+ current_data = current_data[key]
+ else:
+ return None
+ return current_data
+
+ # 3. 支持默认值
+ if "default" in param_config:
+ return param_config["default"]
+
+ # 4. 如果参数配置为"full_output",返回完整输出
+ if param_config.get("source") == "full_output":
+ return output_data
+
+ return None
+
+
async def run(self) -> None:
"""运行单个步骤"""
self.validate_flow_state(self.task)
logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name)
+ # logger.info("************************step.py中的run的self: %r", self)
+
# 进行自动参数填充
await self._run_slot_filling()
# 更新状态
- self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type]
+ self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type]
self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2)
# 推送输入
await self.push_message(EventType.STEP_INPUT.value, self.obj.input)
# 执行步骤
+
+ # node type 是新加来判断是否是改造后的node的
iterator = self.obj.exec(self, self.obj.input)
try:
content = await self._process_chunk(iterator, to_user=self.obj.to_user)
+
except Exception as e:
logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤")
- self.task.state.status = StepStatus.ERROR # type: ignore[arg-type]
+ self.task.state.status = StepStatus.ERROR # type: ignore[arg-type]
await self.push_message(EventType.STEP_OUTPUT.value, {})
if isinstance(e, CallError):
- self.task.state.error_info = { # type: ignore[arg-type]
+ self.task.state.error_info = { # type: ignore[arg-type]
"err_msg": e.message,
"data": e.data,
}
else:
- self.task.state.error_info = { # type: ignore[arg-type]
+ self.task.state.error_info = { # type: ignore[arg-type]
"err_msg": str(e),
"data": {},
}
return
# 更新执行状态
- self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type]
+ self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type]
self.task.tokens.input_tokens += self.obj.tokens.input_tokens
self.task.tokens.output_tokens += self.obj.tokens.output_tokens
self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time
@@ -250,19 +359,22 @@ class StepExecutor(BaseExecutor):
else:
output_data = content
+ # 保存output_parameters到变量池
+ await self._save_output_parameters_to_variables(output_data)
+
# 更新context
history = FlowStepHistory(
task_id=self.task.id,
- flow_id=self.task.state.flow_id, # type: ignore[arg-type]
- flow_name=self.task.state.flow_name, # type: ignore[arg-type]
+ flow_id=self.task.state.flow_id, # type: ignore[arg-type]
+ flow_name=self.task.state.flow_name, # type: ignore[arg-type]
step_id=self.step.step_id,
step_name=self.step.step.name,
step_description=self.step.step.description,
- status=self.task.state.status, # type: ignore[arg-type]
+ status=self.task.state.status, # type: ignore[arg-type]
input_data=self.obj.input,
output_data=output_data,
)
- self.task.context.append(history.model_dump(exclude_none=True, by_alias=True))
+ self.task.context.append(history)
# 推送输出
await self.push_message(EventType.STEP_OUTPUT.value, output_data)
diff --git a/apps/scheduler/executor/step_config.py b/apps/scheduler/executor/step_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1abcf44960a3dee65ba537bb731a71a21f64a43b
--- /dev/null
+++ b/apps/scheduler/executor/step_config.py
@@ -0,0 +1,79 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""步骤执行器配置"""
+
+from typing import Set
+
+# 可以直接写入 conversation.key 格式的节点类型配置
+# 这些节点类型的输出变量不会添加 step_id 前缀
+#
+# 说明:
+# - 添加新的节点类型时,直接在这个集合中添加对应的 call_id 或节点名称即可
+# - 支持大小写敏感匹配,建议同时添加大小写版本以确保兼容性
+# - 这些节点的输出变量将保存为 conversation.key 格式
+# - 其他节点的输出变量将保存为 conversation.step_id.key 格式
+DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: Set[str] = {
+ # 开始节点相关
+ "Start",
+ "start",
+
+ # 输入节点相关
+ "Input",
+ "UserInput",
+ "input",
+
+ # 未来可能的节点类型示例(取消注释即可启用)
+ # "GlobalConfig", # 全局配置节点
+ # "SessionInit", # 会话初始化节点
+ # "SystemConfig", # 系统配置节点
+}
+
+# 可以通过节点名称模式匹配的规则
+# 如果节点名称或step_id(转换为小写后)以这些字符串开头,则使用直接格式
+#
+# 说明:
+# - 这些模式用于匹配节点名称或step_id的前缀(不区分大小写)
+# - 比如:"start"会匹配 "StartProcess"、"start_workflow" 等
+# - 适用于无法提前知道具体节点名称,但可以通过命名规范识别的场景
+DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: Set[str] = {
+ "start", # 匹配所有以start开头的节点
+ "init", # 匹配所有以init开头的节点
+ "input", # 匹配所有以input开头的节点
+
+ # 可以根据需要添加更多模式
+ # "config", # 匹配配置相关节点
+ # "setup", # 匹配设置相关节点
+}
+
+def should_use_direct_conversation_format(call_id: str, step_name: str, step_id: str) -> bool:
+ """
+ 判断是否应该使用直接的 conversation.key 格式
+
+ Args:
+ call_id: 节点的call_id
+ step_name: 节点名称
+ step_id: 节点ID
+
+ Returns:
+ bool: True表示使用 conversation.key,False表示使用 conversation.step_id.key
+ """
+ # 1. 检查call_id是否在直接写入列表中
+ if call_id in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES:
+ return True
+
+ # 2. 检查节点名称是否在直接写入列表中
+ if step_name in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES:
+ return True
+
+ # 3. 检查节点名称是否匹配模式
+ step_name_lower = step_name.lower()
+ for pattern in DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS:
+ if step_name_lower.startswith(pattern):
+ return True
+
+ # 4. 检查step_id是否匹配模式
+ step_id_lower = step_id.lower()
+ for pattern in DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS:
+ if step_id_lower.startswith(pattern):
+ return True
+
+ return False
\ No newline at end of file
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index 78aa7bc3ee869e8710e1fb02a2d9fb438d04be34..93c60711fd0d0f90a340c4f8b0d47e79c221a3ff 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -69,7 +69,7 @@ class MCPHost:
context_list = []
for ctx_id in self._context_list:
- context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None)
+ context = next((ctx for ctx in task.context if ctx.id == ctx_id), None)
if not context:
continue
context_list.append(context)
@@ -120,7 +120,7 @@ class MCPHost:
logger.error("任务 %s 不存在", self._task_id)
return {}
self._context_list.append(context.id)
- task.context.append(context.model_dump(by_alias=True, exclude_none=True))
+ task.context.append(context)
await TaskManager.save_task(self._task_id, task)
return output_data
diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py
index cd4f5975eea3f023a92626966081c2d1eb33bdb7..b4b982037767c1ab7b991a37d98f39b846144f96 100644
--- a/apps/scheduler/mcp/plan.py
+++ b/apps/scheduler/mcp/plan.py
@@ -1,6 +1,6 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP 用户目标拆解与规划"""
-
+import logging
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
@@ -9,6 +9,7 @@ from apps.llm.reasoning import ReasoningLLM
from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER
from apps.schemas.mcp import MCPPlan, MCPTool
+logger = logging.getLogger(__name__)
class MCPPlanner:
"""MCP 用户目标拆解与规划"""
@@ -50,6 +51,7 @@ class MCPPlanner:
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
+
reasoning_llm = ReasoningLLM()
result = ""
async for chunk in reasoning_llm.call(
diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py
index b322fb0883e8ed935243389cb86066845a549631..8d45bb67db897f45bf440a726c80a4b2d2896d4e 100644
--- a/apps/scheduler/mcp/prompt.py
+++ b/apps/scheduler/mcp/prompt.py
@@ -62,25 +62,21 @@ MCP_SELECT = dedent(r"""
### 请一步一步思考:
""")
+
CREATE_PLAN = dedent(r"""
你是一个计划生成器。
请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。
-
# 一个好的计划应该:
-
1. 能够成功完成用户的目标
2. 计划中的每一个步骤必须且只能使用一个工具。
3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。
4. 计划中的最后一步必须是Final工具,以确保计划执行结束。
-
# 生成计划时的注意事项:
-
- 每一条计划包含3个部分:
- 计划内容:描述单个计划步骤的大致内容
- 工具ID:必须从下文的工具列表中选择
- 工具指令:改写用户的目标,使其更符合工具的输入要求
- 必须按照如下格式生成计划,不要输出任何额外数据:
-
```json
{
"plans": [
@@ -92,16 +88,12 @@ CREATE_PLAN = dedent(r"""
]
}
```
-
- 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\
-思考过程应放置在 XML标签中。
+思考过程应放置在 XML标签中。
- 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。
- 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。
-
# 工具
-
你可以访问并使用一些工具,这些工具将在 XML标签中给出。
-
{% for tool in tools %}
- {{ tool.id }}{{tool.name}};{{ tool.description }}
@@ -113,12 +105,9 @@ CREATE_PLAN = dedent(r"""
# 样例
## 目标
-
在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。
-
## 计划
-
-
+
1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server
2. 目标可以拆解为以下几个部分:
- 运行alpine:latest容器
@@ -126,7 +115,7 @@ CREATE_PLAN = dedent(r"""
- 在后台运行
- 执行top命令
3. 需要先选择MCP Server,然后生成Docker命令,最后执行命令
-
+
```json
{
@@ -156,9 +145,7 @@ CREATE_PLAN = dedent(r"""
```
# 现在开始生成计划:
-
## 目标
-
{{goal}}
# 计划
@@ -188,7 +175,7 @@ EVALUATE_PLAN = dedent(r"""
# 进行评估时的注意事项:
- - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。
+ - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。
- 评估结果分为两个部分:
- 计划评估的结论
- 改进后的计划
@@ -212,6 +199,7 @@ EVALUATE_PLAN = dedent(r"""
""")
FINAL_ANSWER = dedent(r"""
综合理解计划执行结果和背景信息,向用户报告目标的完成情况。
+ **注意:你只负责汇报目标完成情况,不对任务执行过程中的准确性/合理性等判断**
# 用户目标
diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py
index 2ff5034471c5e9c38f166c6187b76dfb4596f734..cca35ef2eec1644c5512729762681ffd7cd7febe 100644
--- a/apps/scheduler/mcp/select.py
+++ b/apps/scheduler/mcp/select.py
@@ -171,6 +171,7 @@ class MCPSelector:
for tool_vec in tool_vecs:
# 到MongoDB里找对应的工具
logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"])
+ logger.info("[mcp_id与tools.id]: %s, %s", tool_vec["mcp_id"], tool_vec["id"])
tool_data = await tool_collection.aggregate([
{"$match": {"_id": tool_vec["mcp_id"]}},
{"$unwind": "$tools"},
@@ -178,6 +179,7 @@ class MCPSelector:
{"$project": {"_id": 0, "tools": 1}},
{"$replaceRoot": {"newRoot": "$tools"}},
])
+ logger.info(f"[********** tool_data: {tool_data}")
async for tool in tool_data:
tool_obj = MCPTool.model_validate(tool)
llm_tool_list.append(tool_obj)
diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py
index 2834487860a5f8cad1c33723ac5add4976160b57..157342875dfaf679ac112c8ff43e60126b28c4a7 100644
--- a/apps/scheduler/pool/loader/call.py
+++ b/apps/scheduler/pool/loader/call.py
@@ -42,13 +42,12 @@ class CallLoader(metaclass=SingletonMeta):
call_metadata.append(
CallPool(
_id=call_id,
- type=CallType.SYSTEM,
+ type=call_info.type,
name=call_info.name,
description=call_info.description,
path=f"python::apps.scheduler.call::{call_id}",
),
)
-
return call_metadata
async def _load_single_call_dir(self, call_dir_name: str) -> list[CallPool]:
@@ -189,6 +188,7 @@ class CallLoader(metaclass=SingletonMeta):
NodePool(
_id=call.id,
name=call.name,
+ type=call.type,
description=call.description,
service_id="",
call_id=call.id,
diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py
index 57344d40a5c1aba0454a7f00fffcfbf5a93ee221..2b4fcdc7afdcadbcc176486685bd89ab4009233c 100644
--- a/apps/scheduler/pool/loader/flow.py
+++ b/apps/scheduler/pool/loader/flow.py
@@ -27,6 +27,10 @@ BASE_PATH = Path(Config().get_config().deploy.data_dir) / "semantics" / "app"
class FlowLoader:
"""工作流加载器"""
+ # 添加并发控制
+ _loading_flows = {} # 改为字典,存储加载任务
+ _loading_lock = asyncio.Lock()
+
async def _load_yaml_file(self, flow_path: Path) -> dict[str, Any]:
"""从YAML文件加载工作流配置"""
try:
@@ -100,7 +104,61 @@ class FlowLoader:
async def load(self, app_id: str, flow_id: str) -> Flow | None:
"""从文件系统中加载【单个】工作流"""
- logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", flow_id, app_id)
+ flow_key = f"{app_id}:{flow_id}"
+
+ # 第一次检查:是否已在加载中
+ existing_task = None
+ async with self._loading_lock:
+ if flow_key in self._loading_flows:
+ existing_task = self._loading_flows[flow_key]
+
+ # 如果找到现有任务,等待其完成
+ if existing_task is not None:
+ logger.info(f"[FlowLoader] 工作流正在加载中,等待完成: {flow_key}")
+ try:
+ return await existing_task
+ except Exception as e:
+ logger.error(f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}")
+ # 如果等待失败,清理失败的任务并重试
+ async with self._loading_lock:
+ if self._loading_flows.get(flow_key) == existing_task:
+ self._loading_flows.pop(flow_key, None)
+ return None
+
+ # 创建新的加载任务
+ task = None
+ async with self._loading_lock:
+ # 再次检查,防止竞态条件
+ if flow_key in self._loading_flows:
+ existing_task = self._loading_flows[flow_key]
+ # 如果有新任务出现,等待它完成
+ if existing_task is not None:
+ try:
+ return await existing_task
+ except Exception as e:
+ logger.error(f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}")
+ return None
+
+ # 创建新的加载任务
+ task = asyncio.create_task(self._do_load(app_id, flow_id))
+ self._loading_flows[flow_key] = task
+
+ # 执行加载任务
+ try:
+ result = await task
+ return result
+ except Exception as e:
+ logger.error(f"[FlowLoader] 工作流加载失败: {flow_key}, 错误: {e}")
+ return None
+ finally:
+ # 确保从加载集合中移除
+ async with self._loading_lock:
+ if self._loading_flows.get(flow_key) == task:
+ self._loading_flows.pop(flow_key, None)
+
+ async def _do_load(self, app_id: str, flow_id: str) -> Flow | None:
+ """实际执行加载工作流的方法"""
+ logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", app_id, flow_id)
# 构建工作流文件路径
flow_path = BASE_PATH / app_id / "flow" / f"{flow_id}.yaml"
@@ -235,18 +293,29 @@ class FlowLoader:
except Exception:
logger.exception("[FlowLoader] 更新 MongoDB 失败")
- # 删除重复的ID
- while True:
+ # 删除重复的ID,增加重试次数限制
+ max_retries = 10
+ retry_count = 0
+ while retry_count < max_retries:
try:
table = await LanceDB().get_table("flow")
await table.delete(f"id = '{metadata.id}'")
break
except RuntimeError as e:
if "Commit conflict" in str(e):
- logger.error("[FlowLoader] LanceDB删除flow冲突,重试中...") # noqa: TRY400
- await asyncio.sleep(0.01)
+ retry_count += 1
+ logger.error(f"[FlowLoader] LanceDB删除flow冲突,重试中... ({retry_count}/{max_retries})") # noqa: TRY400
+ # 指数退避,减少冲突概率
+ await asyncio.sleep(0.01 * (2 ** min(retry_count, 5)))
else:
raise
+ except Exception as e:
+ logger.error(f"[FlowLoader] LanceDB删除操作异常: {e}")
+ break
+
+ if retry_count >= max_retries:
+ logger.warning(f"[FlowLoader] LanceDB删除flow达到最大重试次数,跳过删除: {metadata.id}")
+ # 不抛出异常,继续执行后续操作
# 进行向量化
service_embedding = await Embedding.get_embedding([metadata.description])
vector_data = [
@@ -256,7 +325,10 @@ class FlowLoader:
embedding=service_embedding[0],
),
]
- while True:
+ # 插入向量数据,增加重试次数限制
+ max_retries_insert = 10
+ retry_count_insert = 0
+ while retry_count_insert < max_retries_insert:
try:
table = await LanceDB().get_table("flow")
await table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute(
@@ -265,7 +337,16 @@ class FlowLoader:
break
except RuntimeError as e:
if "Commit conflict" in str(e):
- logger.error("[FlowLoader] LanceDB插入flow冲突,重试中...") # noqa: TRY400
- await asyncio.sleep(0.01)
+ retry_count_insert += 1
+ logger.error(f"[FlowLoader] LanceDB插入flow冲突,重试中... ({retry_count_insert}/{max_retries_insert})") # noqa: TRY400
+ # 指数退避,减少冲突概率
+ await asyncio.sleep(0.01 * (2 ** min(retry_count_insert, 5)))
else:
raise
+ except Exception as e:
+ logger.error(f"[FlowLoader] LanceDB插入操作异常: {e}")
+ break
+
+ if retry_count_insert >= max_retries_insert:
+ logger.error(f"[FlowLoader] LanceDB插入flow达到最大重试次数,操作失败: {metadata.id}")
+ raise RuntimeError(f"LanceDB插入flow失败,达到最大重试次数: {metadata.id}")
diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py
index 66a516e77a69bb9e3418953d7ae6fe2c8ddc4052..3cd6183a1c8025977396765c703446ae9b8feb77 100644
--- a/apps/scheduler/pool/loader/mcp.py
+++ b/apps/scheduler/pool/loader/mcp.py
@@ -153,7 +153,6 @@ class MCPLoader(metaclass=SingletonMeta):
# 检查目录
template_path = MCP_PATH / "template" / mcp_id
await Path.mkdir(template_path, parents=True, exist_ok=True)
- ProcessHandler.clear_finished_tasks()
# 安装MCP模板
if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, mcp_id, config):
err = f"安装任务无法执行,请稍后重试: {mcp_id}"
@@ -451,23 +450,27 @@ class MCPLoader(metaclass=SingletonMeta):
)
@staticmethod
- async def _find_deleted_mcp() -> list[str]:
+ async def _find_deleted_and_undeleted_mcp() -> tuple[list[str], list[str]]:
"""
查找在文件系统中被修改和被删除的MCP
- :return: 被修改的MCP列表和被删除的MCP列表
+        :return: 被删除的MCP列表和仍存在于文件系统中的MCP列表
:rtype: tuple[list[str], list[str]]
"""
deleted_mcp_list = []
+ undeleted_mcp_list = []
mcp_collection = MongoDB().get_collection("mcp")
mcp_list = await mcp_collection.find({}, {"_id": 1}).to_list(None)
for db_item in mcp_list:
mcp_path: Path = MCP_PATH / "template" / db_item["_id"]
- if not await mcp_path.exists():
+ if await mcp_path.exists():
+ undeleted_mcp_list.append(db_item["_id"])
+ else:
deleted_mcp_list.append(db_item["_id"])
logger.info("[MCPLoader] 这些MCP在文件系统中被删除: %s", deleted_mcp_list)
- return deleted_mcp_list
+ logger.info("[MCPLoader] 这些MCP在文件系统中存在: %s", undeleted_mcp_list)
+ return deleted_mcp_list, undeleted_mcp_list
@staticmethod
async def remove_deleted_mcp(deleted_mcp_list: list[str]) -> None:
@@ -505,6 +508,45 @@ class MCPLoader(metaclass=SingletonMeta):
raise
logger.info("[MCPLoader] 清除LanceDB中无效的MCP")
+ @staticmethod
+ async def remove_cached_mcp_in_lance(undeleted_mcp_list: list[str]) -> None:
+ """
+        清除LanceDB中已失效(文件系统中不存在)的MCP缓存

+        :param undeleted_mcp_list: 仍存在于文件系统中的MCP ID列表
+ """
+
+ try:
+ mcp_table = await LanceDB().get_table("mcp")
+
+ all_data = await mcp_table.to_pandas()
+ logger.info(f"表中原始数据量: {len(all_data)}条")
+
+ if len(all_data) == 0:
+ logger.info("[MCPLoader] LanceDB中没有MCP数据,跳过删除")
+ return
+
+ ids_to_delete = all_data[~all_data['id'].isin(undeleted_mcp_list)]['id'].tolist()
+
+ if not ids_to_delete:
+ logger.info("没有需要删除的记录")
+ return
+
+ logger.info(f"将要删除 {len(ids_to_delete)} 条记录,ID列表: {ids_to_delete}")
+ for id_to_delete in ids_to_delete:
+ await mcp_table.delete(f"id == '{id_to_delete}'")
+
+ logger.info("删除操作完成")
+
+ remaining_data = await mcp_table.to_pandas()
+ logger.info(f"删除后表中剩余数据量: {len(remaining_data)}条")
+ except Exception as e:
+ if "Commit conflict" in str(e):
+ logger.error("[MCPLoader] LanceDB删除mcp冲突,重试中...") # noqa: TRY400
+ await asyncio.sleep(0.01)
+ else:
+ raise
+
@staticmethod
async def delete_mcp(mcp_id: str) -> None:
"""
@@ -568,8 +610,20 @@ class MCPLoader(metaclass=SingletonMeta):
:return: 无
"""
# 清空数据库
- deleted_mcp_list = await MCPLoader._find_deleted_mcp()
+
+ # As IS:
+ # 如果mongo的mcp库中搜索到的 在文件系统里没有 则认为是需要被[mongo, lance]中删除
+ # deleted_mcp_list = await MCPLoader._find_deleted_mcp()
+ # await MCPLoader.remove_deleted_mcp(deleted_mcp_list)
+
+ # To BE:
+ # 如果mongo的mcp中搜索到
+ # 文件中不存在的 删除 mongo删除 lance删除
+ # 再对比lance和mongo的mcp清单,删除mongo中不存在的
+
+ deleted_mcp_list, undeleted_mcp_list = await MCPLoader._find_deleted_and_undeleted_mcp()
await MCPLoader.remove_deleted_mcp(deleted_mcp_list)
+ await MCPLoader.remove_cached_mcp_in_lance(undeleted_mcp_list)
# 检查目录
await MCPLoader._check_dir()
diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py
index b7088d8da45ad01e94e0cf7208db25c5f56c3b22..687697fb81ba4389ffb1d247381288dfdd831cd7 100644
--- a/apps/scheduler/scheduler/context.py
+++ b/apps/scheduler/scheduler/context.py
@@ -114,11 +114,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
used_docs.append(
RecordGroupDocument(
_id=docs["id"],
+ author=docs.get("author", ""),
+ order=docs.get("order", 0),
name=docs["name"],
abstract=docs.get("abstract", ""),
extension=docs.get("extension", ""),
size=docs.get("size", 0),
associated="answer",
+ created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)),
)
)
if docs.get("order") is not None:
@@ -185,7 +188,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
feature={},
),
createdAt=current_time,
- flow=[i["_id"] for i in task.context],
+ flow=[i.id for i in task.context],
)
# 检查是否存在group_id
diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py
index c89fdd1014ba4714d0777742412e98c0802ca68c..a2a45e41494e3ddb8d291c0a907da06bb479d4a5 100644
--- a/apps/scheduler/scheduler/message.py
+++ b/apps/scheduler/scheduler/message.py
@@ -71,14 +71,7 @@ async def push_rag_message(
# 如果是文本消息,直接拼接到答案中
full_answer += content_obj.content
elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
- task.runtime.documents.append({
- "id": content_obj.content.get("id", ""),
- "order": content_obj.content.get("order", 0),
- "name": content_obj.content.get("name", ""),
- "abstract": content_obj.content.get("abstract", ""),
- "extension": content_obj.content.get("extension", ""),
- "size": content_obj.content.get("size", 0),
- })
+ task.runtime.documents.append(content_obj.content)
# 保存答案
task.runtime.answer = full_answer
await TaskManager.save_task(task.id, task)
@@ -115,10 +108,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl
data=DocumentAddContent(
documentId=content_obj.content.get("id", ""),
documentOrder=content_obj.content.get("order", 0),
+ documentAuthor=content_obj.content.get("author", ""),
documentName=content_obj.content.get("name", ""),
documentAbstract=content_obj.content.get("abstract", ""),
documentType=content_obj.content.get("extension", ""),
documentSize=content_obj.content.get("size", 0),
+ createdAt=round(datetime.now(tz=UTC).timestamp(), 3),
).model_dump(exclude_none=True, by_alias=True),
)
except Exception:
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index ed73638ced241909f55989597f8bb747b50f1945..417f93d28a147cd0cf65b124624484a02a49ab06 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -206,7 +206,7 @@ class Scheduler:
task=self.task,
msg_queue=queue,
question=post_body.question,
- max_steps=app_metadata.history_len,
+ history_len=app_metadata.history_len,
servers_id=servers_id,
background=background,
agent_id=app_info.app_id,
diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py
index 4caab4d2dc945ec5ac14b7769770f94c39980a9f..89433cade929ef6917664450bb3645500ed2df5a 100644
--- a/apps/scheduler/slot/slot.py
+++ b/apps/scheduler/slot/slot.py
@@ -12,6 +12,8 @@ from jsonschema.exceptions import ValidationError
from jsonschema.protocols import Validator
from jsonschema.validators import extend
+from apps.schemas.response_data import ParamsNode
+from apps.scheduler.call.choice.schema import Type
from apps.scheduler.slot.parser import (
SlotConstParser,
SlotDateParser,
@@ -221,6 +223,45 @@ class Slot:
return data
return _extract_type_desc(self._schema)
+ def get_params_node_from_schema(self, root: str = "") -> ParamsNode:
+ """从JSON Schema中提取ParamsNode"""
+ def _extract_params_node(schema_node: dict[str, Any], name: str = "", path: str = "") -> ParamsNode:
+ """递归提取ParamsNode"""
+ if "type" not in schema_node:
+ return None
+
+ param_type = schema_node["type"]
+ if param_type == "object":
+ param_type = Type.DICT
+ elif param_type == "array":
+ param_type = Type.LIST
+ elif param_type == "string":
+ param_type = Type.STRING
+ elif param_type == "number":
+ param_type = Type.NUMBER
+ elif param_type == "boolean":
+ param_type = Type.BOOL
+ else:
+ logger.warning(f"[Slot] 不支持的参数类型: {param_type}")
+ return None
+ sub_params = []
+
+ if param_type == "object" and "properties" in schema_node:
+ for key, value in schema_node["properties"].items():
+ sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}"))
+ else:
+ # 对于非对象类型,直接返回空子参数
+ sub_params = None
+ return ParamsNode(paramName=name,
+ paramPath=path,
+ paramType=param_type,
+ subParams=sub_params)
+ try:
+ return _extract_params_node(self._schema, name=root, path=root)
+ except Exception as e:
+ logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}")
+ return None
+
def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""将JSON Schema扁平化"""
result = {}
@@ -276,7 +317,6 @@ class Slot:
logger.exception("[Slot] 错误schema不合法: %s", error.schema)
return {}, []
-
def _assemble_patch(
self,
key: str,
@@ -329,7 +369,6 @@ class Slot:
logger.info("[Slot] 组装patch: %s", patch_list)
return patch_list
-
def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]:
"""将用户手动填充的参数专为真实JSON"""
json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data
diff --git a/apps/scheduler/variable/README.md b/apps/scheduler/variable/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..73718ce01f2eca2514d233b43059cc76018c7a48
--- /dev/null
+++ b/apps/scheduler/variable/README.md
@@ -0,0 +1,264 @@
+# 变量池架构文档
+
+## 架构设计
+
+基于用户需求,变量系统采用"模板-实例"的两级架构:
+
+### 设计理念
+
+- **Flow级别(父pool)**:管理变量模板定义,用户可以查看和配置变量结构
+- **Conversation级别(子pool)**:管理变量实例,存储实际的运行时数据
+
+### 变量分类
+
+不同类型的变量有不同的存储和管理方式:
+- **系统变量**:模板在Flow级别定义,实例在Conversation级别运行时更新
+- **对话变量**:模板在Flow级别定义,实例在Conversation级别用户可设置
+- **环境变量**:直接在Flow级别存储和使用
+- **用户变量**:在User级别长期存储
+
+## 架构实现
+
+### 变量池类型
+
+#### 1. UserVariablePool(用户变量池)
+- **关联ID**: `user_id`
+- **权限**: 用户可读写
+- **生命周期**: 随用户创建而创建,长期存在
+- **典型变量**: API密钥、用户偏好、个人配置等
+
+#### 2. FlowVariablePool(流程变量池)
+- **关联ID**: `flow_id`
+- **权限**: 流程可读写
+- **生命周期**: 随 flow 创建而创建
+- **继承**: 支持从父流程继承
+- **存储内容**:
+ - 环境变量(直接使用)
+ - 系统变量模板(供对话继承)
+ - 对话变量模板(供对话继承)
+
+#### 3. ConversationVariablePool(对话变量池)
+- **关联ID**: `conversation_id`
+- **权限**:
+ - 系统变量实例:只读,由系统自动更新
+ - 对话变量实例:可读写,用户可设置值
+- **生命周期**: 随对话创建而创建,对话结束后可选择性清理
+- **初始化方式**: 从FlowVariablePool的模板自动继承
+- **包含内容**:
+ - **系统变量实例**:`query`, `files`, `dialogue_count`等运行时值
+ - **对话变量实例**:用户定义的对话上下文数据
+
+## 核心设计原则
+
+### 1. 统一的对话上下文
+所有对话相关的变量(无论是系统变量还是对话变量)都在同一个对话变量池中管理,确保上下文的一致性。
+
+### 2. 权限区分
+通过 `is_system` 标记区分系统变量和对话变量:
+- `is_system=True`: 系统变量,只读,由系统自动更新
+- `is_system=False`: 对话变量,可读写,支持人为修改
+
+### 3. 自动初始化和持久化
+创建对话变量池时,自动初始化所有必需的系统变量,设置合理的默认值,并立即持久化到数据库,确保系统变量在任何时候都可用。
+
+## 使用方式
+
+### 1. 创建对话变量池
+
+```python
+pool_manager = await get_pool_manager()
+
+# 创建对话变量池(自动包含系统变量)
+conv_pool = await pool_manager.create_conversation_pool("conv123", "flow456")
+```
+
+### 2. 更新系统变量
+
+```python
+# 系统变量由解析器自动更新
+parser = VariableParser(
+ user_id="user123",
+ flow_id="flow456",
+ conversation_id="conv123"
+)
+
+# 更新系统变量
+await parser.update_system_variables({
+ "question": "你好,请帮我分析数据",
+ "files": [{"name": "data.csv", "size": 1024}],
+ "dialogue_count": 1,
+ "user_sub": "user123"
+})
+```
+
+### 3. 更新对话变量
+
+```python
+# 添加对话变量
+await conv_pool.add_variable(
+ name="context_history",
+ var_type=VariableType.ARRAY_STRING,
+ value=["用户问候", "系统回应"],
+ description="对话历史"
+)
+
+# 更新对话变量
+await conv_pool.update_variable("context_history", value=["问候", "回应", "新消息"])
+```
+
+### 4. 变量解析
+
+```python
+# 系统变量和对话变量使用相同的引用语法
+template = """
+系统变量 - 用户查询: {{sys.query}}
+系统变量 - 对话轮数: {{sys.dialogue_count}}
+对话变量 - 历史: {{conversation.context_history}}
+用户变量 - 偏好: {{user.preferences}}
+环境变量 - 数据库: {{env.database_url}}
+"""
+
+parsed = await parser.parse_template(template)
+```
+
+## 变量引用语法
+
+变量引用保持不变:
+- `{{sys.variable_name}}` - 系统变量(对话级别,只读)
+- `{{conversation.variable_name}}` - 对话变量(对话级别,可读写)
+- `{{user.variable_name}}` - 用户变量
+- `{{env.variable_name}}` - 环境变量
+
+## 权限控制详细说明
+
+### 系统变量权限
+```python
+# 普通更新会被拒绝
+await conv_pool.update_variable("query", value="new query") # ❌ 抛出 PermissionError
+
+# 系统内部更新
+await conv_pool.update_system_variable("query", "new query") # ✅ 成功
+# 或者
+await conv_pool.update_variable("query", value="new query", force_system_update=True) # ✅ 成功
+```
+
+### 对话变量权限
+```python
+# 普通对话变量可以自由更新
+await conv_pool.update_variable("context_history", value=new_history) # ✅ 成功
+```
+
+## 数据存储
+
+### 元数据增强
+```python
+class VariableMetadata(BaseModel):
+ # ... 其他字段
+ is_system: bool = Field(default=False, description="是否为系统变量(只读)")
+```
+
+### 数据库查询
+系统变量和对话变量存储在同一个集合中,通过 `metadata.is_system` 字段区分。
+
+## 迁移影响
+
+### 对用户的影响
+- ✅ **变量引用语法完全不变**
+- ✅ **API接口完全兼容**
+- ✅ **现有功能正常工作**
+
+### 内部实现变化
+- 去掉了独立的 `SystemVariablePool`
+- 系统变量现在在 `ConversationVariablePool` 中管理
+- 通过权限控制区分系统变量和对话变量
+
+## 架构优势
+
+### 1. 逻辑一致性
+系统变量和对话变量都属于对话上下文,在同一个池中管理更合理。
+
+### 2. 简化管理
+不需要在系统池和对话池之间同步数据,避免了数据一致性问题。
+
+### 3. 更好的性能
+减少了池之间的数据传递和同步开销。
+
+### 4. 扩展性
+为未来可能的对话级系统变量扩展提供了更好的基础。
+
+## 总结
+
+修正后的架构更准确地反映了变量的实际使用场景:
+- **用户变量**: 用户级别,长期存在
+- **环境变量**: 流程级别,配置相关
+- **系统变量 + 对话变量**: 对话级别,上下文相关
+
+这样的设计更符合实际业务逻辑,也更容易理解和维护。
+
+## 系统变量详细说明
+
+### 预定义系统变量
+
+每个对话变量池创建时,会自动初始化以下系统变量:
+
+| 变量名 | 类型 | 描述 | 初始值 |
+|-------|------|------|--------|
+| `query` | STRING | 用户查询内容 | "" |
+| `files` | ARRAY_FILE | 用户上传的文件列表 | [] |
+| `dialogue_count` | NUMBER | 对话轮数 | 0 |
+| `app_id` | STRING | 应用ID | "" |
+| `flow_id` | STRING | 工作流ID | {flow_id} |
+| `user_id` | STRING | 用户ID | "" |
+| `session_id` | STRING | 会话ID | "" |
+| `conversation_id` | STRING | 对话ID | {conversation_id} |
+| `timestamp` | NUMBER | 当前时间戳 | {当前时间} |
+
+### 系统变量生命周期
+
+1. **创建阶段**:对话变量池创建时,所有系统变量被初始化并持久化到数据库
+2. **更新阶段**:通过`VariableParser.update_system_variables()`方法更新系统变量值
+3. **访问阶段**:通过模板解析或直接访问获取系统变量值
+4. **清理阶段**:对话结束时,整个对话变量池被清理
+
+### 系统变量更新机制
+
+```python
+# 创建解析器并确保对话池存在
+parser = VariableParser(user_id=user_id, flow_id=flow_id, conversation_id=conversation_id)
+await parser.create_conversation_pool_if_needed()
+
+# 更新系统变量
+context = {
+ "question": "用户的问题",
+ "files": [{"name": "file.txt", "size": 1024}],
+ "dialogue_count": 1,
+ "app_id": "app123",
+ "user_sub": user_id,
+ "session_id": "session456"
+}
+
+await parser.update_system_variables(context)
+```
+
+### 系统变量的只读保护
+
+```python
+# ❌ 直接修改系统变量会失败
+await conversation_pool.update_variable("query", value="修改内容") # 抛出PermissionError
+
+# ✅ 只能通过系统内部接口更新
+await conversation_pool.update_system_variable("query", "新内容") # 成功
+```
+
+### 使用系统变量
+
+```python
+# 在模板中引用系统变量
+template = """
+用户问题:{{sys.query}}
+对话轮数:{{sys.dialogue_count}}
+工作流ID:{{sys.flow_id}}
+"""
+
+parsed = await parser.parse_template(template)
+```
\ No newline at end of file
diff --git a/apps/scheduler/variable/__init__.py b/apps/scheduler/variable/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c487379c54120835552954559f0e31bd2a6cce0
--- /dev/null
+++ b/apps/scheduler/variable/__init__.py
@@ -0,0 +1,60 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""工作流变量管理模块
+
+这个模块提供了完整的变量管理功能,包括:
+- 多种变量类型支持(String、Number、Boolean、Object、Secret、File、Array等)
+- 四种作用域支持(系统级、用户级、环境级、对话级)
+- 变量解析和模板替换
+- 安全的密钥变量处理
+- 与工作流调度器的集成
+"""
+
+from .type import VariableType, VariableScope
+from .base import BaseVariable, VariableMetadata
+from .variables import (
+ StringVariable,
+ NumberVariable,
+ BooleanVariable,
+ ObjectVariable,
+ SecretVariable,
+ FileVariable,
+ ArrayVariable,
+ create_variable,
+ VARIABLE_CLASS_MAP,
+)
+from .pool_manager import VariablePoolManager, get_pool_manager
+from .parser import VariableParser, VariableReferenceBuilder, VariableContext
+from .integration import VariableIntegration
+
+__all__ = [
+ # 基础类型和枚举
+ "VariableType",
+ "VariableScope",
+ "VariableMetadata",
+ "BaseVariable",
+
+ # 具体变量类型
+ "StringVariable",
+ "NumberVariable",
+ "BooleanVariable",
+ "ObjectVariable",
+ "SecretVariable",
+ "FileVariable",
+ "ArrayVariable",
+
+ # 工厂函数和映射
+ "create_variable",
+ "VARIABLE_CLASS_MAP",
+
+ # 变量池管理器
+ "VariablePoolManager",
+ "get_pool_manager",
+
+ # 解析器
+ "VariableParser",
+ "VariableReferenceBuilder",
+ "VariableContext",
+
+ # 集成功能
+ "VariableIntegration",
+]
diff --git a/apps/scheduler/variable/base.py b/apps/scheduler/variable/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..94beaaebc0f084c1a5dbec6696f7b233a5d66154
--- /dev/null
+++ b/apps/scheduler/variable/base.py
@@ -0,0 +1,190 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+from pydantic import BaseModel, Field
+from datetime import datetime, UTC
+
+from .type import VariableType, VariableScope
+
+
+class VariableMetadata(BaseModel):
+    """Metadata describing one variable: identity, typing, ownership and security flags."""
+    name: str = Field(description="变量名称")
+    var_type: VariableType = Field(description="变量类型")
+    scope: VariableScope = Field(description="变量作用域")
+    description: Optional[str] = Field(default=None, description="变量描述")
+    created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="创建时间")
+    updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="更新时间")
+    created_by: Optional[str] = Field(default=None, description="创建者用户ID")
+
+    # Scope-related owner IDs: which user / flow / conversation the variable belongs to.
+    user_sub: Optional[str] = Field(default=None, description="用户级变量的用户ID")
+    flow_id: Optional[str] = Field(default=None, description="环境级/对话级变量的流程ID")
+    conversation_id: Optional[str] = Field(default=None, description="对话级变量的对话ID")
+
+    # System-variable marker: pools treat is_system=True variables as read-only.
+    is_system: bool = Field(default=False, description="是否为系统变量(只读)")
+
+    # Template marker: template variables live at flow level, instances at conversation level.
+    is_template: bool = Field(default=False, description="是否为模板变量(存储在flow级别)")
+
+    # Security-related attributes.
+    is_encrypted: bool = Field(default=False, description="是否加密存储")
+    access_permissions: Optional[Dict[str, Any]] = Field(default=None, description="访问权限")
+
+
+class BaseVariable(ABC):
+    """Abstract base class for all variable implementations.
+
+    Subclasses implement type validation, string/dict conversion and
+    (de)serialization for one concrete VariableType.
+    """
+
+    def __init__(self, metadata: VariableMetadata, value: Any = None):
+        """Initialize the variable.
+
+        Args:
+            metadata: variable metadata (name, type, scope, ownership)
+            value: initial value; routed through the ``value`` setter so the
+                subclass's type validation runs even on construction
+        """
+        self.metadata = metadata
+        self._value = None  # set to None first; the real value goes through the setter below
+        self._original_value = value
+        self._initializing = True  # flag: while True, the setter skips the SYSTEM read-only check
+
+        # Assign through the setter to trigger type validation
+        if value is not None:
+            self.value = value  # runs the setter and its type validation
+        else:
+            self._value = value  # None bypasses the setter (nothing to validate)
+
+        self._initializing = False  # initialization finished; SYSTEM vars become read-only
+
+    @property
+    def name(self) -> str:
+        """Variable name (from metadata)."""
+        return self.metadata.name
+
+    @property
+    def var_type(self) -> VariableType:
+        """Variable type (from metadata)."""
+        return self.metadata.var_type
+
+    @property
+    def scope(self) -> VariableScope:
+        """Variable scope (from metadata)."""
+        return self.metadata.scope
+
+    @property
+    def value(self) -> Any:
+        """Current variable value."""
+        return self._value
+
+    @value.setter
+    def value(self, new_value: Any) -> None:
+        """Set the variable value, enforcing read-only SYSTEM scope and type validation."""
+        # Only enforce the SYSTEM-scope modification ban outside of __init__;
+        # getattr guards against the flag not existing yet during construction.
+        if self.scope == VariableScope.SYSTEM and not getattr(self, '_initializing', False):
+            raise ValueError("系统级变量不能修改")
+
+        # Validate the type via the subclass hook
+        if not self._validate_type(new_value):
+            raise TypeError(f"变量 {self.name} 的值类型不匹配,期望: {self.var_type}")
+
+        self._value = new_value
+        self.metadata.updated_at = datetime.now(UTC)
+
+    @abstractmethod
+    def _validate_type(self, value: Any) -> bool:
+        """Check whether *value* matches this variable's type.
+
+        Args:
+            value: candidate value
+
+        Returns:
+            bool: True if the type is acceptable
+        """
+        pass
+
+    @abstractmethod
+    def to_string(self) -> str:
+        """Render the variable as a string.
+
+        Returns:
+            str: string representation
+        """
+        pass
+
+    @abstractmethod
+    def to_dict(self) -> Dict[str, Any]:
+        """Render the variable as a dictionary.
+
+        Returns:
+            Dict[str, Any]: dictionary representation
+        """
+        pass
+
+    @abstractmethod
+    def serialize(self) -> Dict[str, Any]:
+        """Serialize the variable for storage.
+
+        Returns:
+            Dict[str, Any]: serialized payload
+        """
+        pass
+
+    @classmethod
+    @abstractmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "BaseVariable":
+        """Restore a variable from serialized data.
+
+        Args:
+            data: serialized payload produced by ``serialize``
+
+        Returns:
+            BaseVariable: restored variable instance
+        """
+        pass
+
+    def copy(self) -> "BaseVariable":
+        """Create a deep copy of this variable.
+
+        Returns:
+            BaseVariable: an independent copy (metadata and value deep-copied)
+        """
+        # Deep-copy metadata and value so the copy shares no mutable state
+        import copy
+        new_metadata = copy.deepcopy(self.metadata)
+        new_value = copy.deepcopy(self._value)
+        return self.__class__(new_metadata, new_value)
+
+    def reset(self) -> None:
+        """Reset the variable to the value it was constructed with.
+
+        NOTE(review): this bypasses the setter, so no type validation runs — the
+        original value was already validated in __init__, hence presumed safe.
+        """
+        if self.scope == VariableScope.SYSTEM:
+            raise ValueError("系统级变量不能重置")
+
+        self._value = self._original_value
+        self.metadata.updated_at = datetime.now(UTC)
+
+    def can_access(self, user_sub: str) -> bool:
+        """Check whether the given user may access this variable.
+
+        Args:
+            user_sub: user ID
+
+        Returns:
+            bool: True if access is allowed
+        """
+        # SYSTEM-scope variables are readable by everyone
+        if self.scope == VariableScope.SYSTEM:
+            return True
+
+        # USER-scope variables are only accessible to their owner
+        if self.scope == VariableScope.USER:
+            return self.metadata.user_sub == user_sub
+
+        # ENVIRONMENT/CONVERSATION scopes are context-dependent; simplified to "allow" here
+        return True
+
+    def __str__(self) -> str:
+        """Human-readable representation: name(type)=value."""
+        return f"{self.name}({self.var_type.value})={self.to_string()}"
+
+    def __repr__(self) -> str:
+        """Debug representation with name, type and scope."""
+        return f"<{self.__class__.__name__}(name='{self.name}', type='{self.var_type.value}', scope='{self.scope.value}')>"
\ No newline at end of file
diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py
new file mode 100644
index 0000000000000000000000000000000000000000..a33385a176513a8717d98a59bcaf5f028ce8a47f
--- /dev/null
+++ b/apps/scheduler/variable/integration.py
@@ -0,0 +1,366 @@
+"""变量解析与工作流调度器集成"""
+
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+from apps.scheduler.variable.parser import VariableParser
+from apps.scheduler.variable.pool_manager import get_pool_manager
+from apps.scheduler.variable.type import VariableScope
+
+logger = logging.getLogger(__name__)
+
+
+class VariableIntegration:
+ """变量解析集成类 - 为现有调度器提供变量功能"""
+
+ @staticmethod
+ async def initialize_system_variables(context: Dict[str, Any]) -> None:
+ """初始化系统变量
+
+ Args:
+ context: 系统上下文信息,包括用户查询、文件等
+ """
+ try:
+ parser = VariableParser(
+ user_sub=context.get("user_sub"),
+ flow_id=context.get("flow_id"),
+ conversation_id=context.get("conversation_id")
+ )
+
+ # 更新系统变量
+ await parser.update_system_variables(context)
+
+ logger.info("系统变量已初始化")
+ except Exception as e:
+ logger.error(f"初始化系统变量失败: {e}")
+ raise
+
+ @staticmethod
+ async def parse_call_input(input_data: Dict[str, Any],
+ user_sub: str,
+ flow_id: Optional[str] = None,
+ conversation_id: Optional[str] = None) -> Union[str, Dict, List]:
+ """解析Call输入中的变量引用
+
+ Args:
+ input_data: 输入数据
+ user_sub: 用户ID
+ flow_id: 流程ID
+ conversation_id: 对话ID
+
+ Returns:
+ Dict[str, Any]: 解析后的输入数据
+ """
+ try:
+ parser = VariableParser(
+ user_sub=user_sub,
+ flow_id=flow_id,
+ conversation_id=conversation_id
+ )
+
+ # 递归解析JSON模板中的变量引用
+ parsed_input = await parser.parse_json_template(input_data)
+
+ return parsed_input
+
+ except Exception as e:
+ logger.warning(f"解析Call输入变量失败: {e}")
+ # 如果解析失败,返回原始输入
+ return input_data
+
+ @staticmethod
+ async def resolve_variable_reference(
+ reference: str,
+ user_sub: str,
+ flow_id: Optional[str] = None,
+ conversation_id: Optional[str] = None
+ ) -> Any:
+ """解析单个变量引用
+
+ Args:
+ reference: 变量引用字符串(如 "{{user.name}}" 或 "user.name")
+ user_sub: 用户ID
+ flow_id: 流程ID
+ conversation_id: 对话ID
+
+ Returns:
+ Any: 解析后的变量值
+ """
+ try:
+ parser = VariableParser(
+ user_id=user_sub,
+ flow_id=flow_id,
+ conversation_id=conversation_id
+ )
+
+ # 清理引用字符串(移除花括号)
+ clean_reference = reference.strip("{}")
+
+ # 使用解析器解析变量引用
+ resolved_value = await parser._resolve_variable_reference(clean_reference)
+
+ return resolved_value
+
+ except Exception as e:
+ logger.error(f"解析变量引用失败: {reference}, 错误: {e}")
+ raise
+
+ @staticmethod
+ async def save_conversation_variable(
+ var_name: str,
+ value: Any,
+ var_type: str = "string",
+ description: str = "",
+ user_sub: str = "",
+ flow_id: Optional[str] = None,
+ conversation_id: Optional[str] = None
+ ) -> bool:
+ """保存对话变量
+
+ Args:
+ var_name: 变量名(不包含scope前缀)
+ value: 变量值
+ var_type: 变量类型
+ description: 变量描述
+ user_sub: 用户ID
+ flow_id: 流程ID
+ conversation_id: 对话ID
+
+ Returns:
+ bool: 是否保存成功
+ """
+ try:
+ if not conversation_id:
+ logger.warning("无法保存对话变量:缺少conversation_id")
+ return False
+
+ # 直接使用pool_manager,避免解析器的复杂逻辑
+ pool_manager = await get_pool_manager()
+ conversation_pool = await pool_manager.get_conversation_pool(conversation_id)
+
+ if not conversation_pool:
+ logger.warning(f"无法获取对话变量池: {conversation_id}")
+ return False
+
+ # 转换变量类型
+ from apps.scheduler.variable.type import VariableType
+ try:
+ var_type_enum = VariableType(var_type)
+ except ValueError:
+ var_type_enum = VariableType.STRING
+ logger.warning(f"未知的变量类型 {var_type},使用默认类型 string")
+
+ # 尝试更新变量,如果不存在则创建
+ try:
+ await conversation_pool.update_variable(var_name, value=value)
+ logger.debug(f"对话变量已更新: {var_name} = {value}")
+ return True
+ except ValueError as e:
+ if "不存在" in str(e):
+ # 变量不存在,创建新变量
+ await conversation_pool.add_variable(
+ name=var_name,
+ var_type=var_type_enum,
+ value=value,
+ description=description,
+ created_by=user_sub or "system"
+ )
+ logger.debug(f"对话变量已创建: {var_name} = {value}")
+ return True
+ else:
+ raise # 其他错误重新抛出
+
+ except Exception as e:
+ logger.error(f"保存对话变量失败: {var_name} - {e}")
+ return False
+
+ @staticmethod
+ async def parse_template_string(template: str,
+ user_sub: str,
+ flow_id: Optional[str] = None,
+ conversation_id: Optional[str] = None) -> str:
+ """解析模板字符串中的变量引用
+
+ Args:
+ template: 模板字符串
+ user_sub: 用户ID
+ flow_id: 流程ID
+ conversation_id: 对话ID
+
+ Returns:
+ str: 解析后的字符串
+ """
+ try:
+ parser = VariableParser(
+ user_sub=user_sub,
+ flow_id=flow_id,
+ conversation_id=conversation_id
+ )
+
+ return await parser.parse_template(template)
+ except Exception as e:
+ logger.warning(f"解析模板字符串失败: {e}")
+ # 如果解析失败,返回原始模板
+ return template
+
+ @staticmethod
+ async def add_conversation_variable(name: str,
+ value: Any,
+ conversation_id: str,
+ var_type_str: str = "string") -> bool:
+ """添加对话级变量
+
+ Args:
+ name: 变量名
+ value: 变量值
+ conversation_id: 对话ID(在内部作为flow_id使用)
+ var_type_str: 变量类型字符串
+
+ Returns:
+ bool: 是否添加成功
+ """
+ try:
+ from apps.scheduler.variable.type import VariableType
+
+ # 转换变量类型
+ var_type = VariableType(var_type_str)
+
+ pool_manager = await get_pool_manager()
+ # 获取对话变量池(如果不存在会抛出异常)
+ conversation_pool = await pool_manager.get_conversation_pool(conversation_id)
+ if not conversation_pool:
+ logger.error(f"对话变量池不存在: {conversation_id}")
+ return False
+
+ await conversation_pool.add_variable(
+ name=name,
+ var_type=var_type,
+ value=value,
+ description=f"对话变量: {name}"
+ )
+
+ logger.debug(f"已添加对话变量: {name} = {value}")
+ return True
+ except Exception as e:
+ logger.error(f"添加对话变量失败: {e}")
+ return False
+
+ @staticmethod
+ async def update_conversation_variable(name: str,
+ value: Any,
+ conversation_id: str) -> bool:
+ """更新对话级变量
+
+ Args:
+ name: 变量名
+ value: 新值
+ conversation_id: 对话ID
+
+ Returns:
+ bool: 是否更新成功
+ """
+ try:
+ pool_manager = await get_pool_manager()
+ # 获取对话变量池
+ conversation_pool = await pool_manager.get_conversation_pool(conversation_id)
+ if not conversation_pool:
+ logger.error(f"对话变量池不存在: {conversation_id}")
+ return False
+
+ await conversation_pool.update_variable(
+ name=name,
+ value=value
+ )
+
+ logger.debug(f"已更新对话变量: {name} = {value}")
+ return True
+ except Exception as e:
+ logger.error(f"更新对话变量失败: {e}")
+ return False
+
+ @staticmethod
+ async def extract_output_variables(output_data: Dict[str, Any],
+ conversation_id: str,
+ step_name: str) -> None:
+ """从步骤输出中提取变量并设置为对话级变量
+
+ Args:
+ output_data: 步骤输出数据
+ conversation_id: 对话ID
+ step_name: 步骤名称
+ """
+ try:
+ # 将整个输出作为对象变量存储
+ await VariableIntegration.add_conversation_variable(
+ name=f"step_{step_name}_output",
+ value=output_data,
+ conversation_id=conversation_id,
+ var_type_str="object"
+ )
+
+ # 如果输出中有特定的变量定义,也可以单独提取
+ if isinstance(output_data, dict):
+ for key, value in output_data.items():
+ if key.startswith("var_"):
+ # 以 var_ 开头的字段被视为变量定义
+ var_name = key[4:] # 移除 var_ 前缀
+ await VariableIntegration.add_conversation_variable(
+ name=var_name,
+ value=value,
+ conversation_id=conversation_id,
+ var_type_str="string"
+ )
+
+ except Exception as e:
+ logger.error(f"提取输出变量失败: {e}")
+
+ @staticmethod
+ async def clear_conversation_context(conversation_id: str) -> None:
+ """清理对话上下文中的变量
+
+ Args:
+ conversation_id: 对话ID
+ """
+ try:
+ pool_manager = await get_pool_manager()
+ # 移除对话变量池
+ success = await pool_manager.remove_conversation_pool(conversation_id)
+ if success:
+ logger.info(f"已清理对话 {conversation_id} 的变量")
+ else:
+ logger.warning(f"对话变量池不存在: {conversation_id}")
+ except Exception as e:
+ logger.error(f"清理对话变量失败: {e}")
+
+ @staticmethod
+ async def validate_variable_references(template: str,
+ user_sub: str,
+ flow_id: Optional[str] = None,
+ conversation_id: Optional[str] = None) -> tuple[bool, list[str]]:
+ """验证模板中的变量引用是否有效
+
+ Args:
+ template: 模板字符串
+ user_sub: 用户ID
+ flow_id: 流程ID
+ conversation_id: 对话ID
+
+ Returns:
+ tuple[bool, list[str]]: (是否全部有效, 无效的变量引用列表)
+ """
+ try:
+ parser = VariableParser(
+ user_sub=user_sub,
+ flow_id=flow_id,
+ conversation_id=conversation_id
+ )
+
+ return await parser.validate_template(template)
+ except Exception as e:
+ logger.error(f"验证变量引用失败: {e}")
+ return False, [str(e)]
+
+
+# 注意:原本的 monkey_patch_scheduler 和相关扩展类已被移除
+# 因为 CoreCall 类现在已经内置了完整的变量解析功能
+# 这些代码是旧版本的遗留,会导致循环导入问题
\ No newline at end of file
diff --git a/apps/scheduler/variable/parser.py b/apps/scheduler/variable/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eb0dcfaff90d5a0254c4f13917197b12c13ef95
--- /dev/null
+++ b/apps/scheduler/variable/parser.py
@@ -0,0 +1,514 @@
+import re
+import logging
+from typing import Any, Dict, List, Optional, Tuple, Union
+import json
+from datetime import datetime, UTC
+
+from .pool_manager import get_pool_manager
+from .type import VariableScope
+
+logger = logging.getLogger(__name__)
+
+
+class VariableParser:
+ """变量解析器 - 支持新架构的变量解析器(系统变量和对话变量都在对话池中)"""
+
+ # 变量引用的正则表达式:{{scope.variable_name.nested_path}}
+ VARIABLE_PATTERN = re.compile(r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*\}\}')
+
+ def __init__(self,
+ user_id: Optional[str] = None,
+ flow_id: Optional[str] = None,
+ conversation_id: Optional[str] = None,
+ user_sub: Optional[str] = None):
+ """初始化变量解析器
+
+ Args:
+ user_id: 用户ID (向后兼容)
+ flow_id: 流程ID
+ conversation_id: 对话ID
+ user_sub: 用户订阅ID (优先使用,用于未来鉴权等需求)
+ """
+ # 优先使用 user_sub,如果没有则使用 user_id
+ self.user_id = user_sub if user_sub is not None else user_id
+ self.flow_id = flow_id
+ self.conversation_id = conversation_id
+ self._pool_manager = None
+
+ async def _get_pool_manager(self):
+ """获取变量池管理器实例"""
+ if self._pool_manager is None:
+ self._pool_manager = await get_pool_manager()
+ return self._pool_manager
+
+ async def parse_template(self, template: str) -> str:
+ """解析模板字符串,替换其中的变量引用
+
+ Args:
+ template: 包含变量引用的模板字符串
+
+ Returns:
+ str: 替换后的字符串
+ """
+ if not template:
+ return template
+
+ # 查找所有变量引用
+ matches = self.VARIABLE_PATTERN.findall(template)
+
+ # 替换每个变量引用
+ result = template
+ for match in matches:
+ try:
+ # 解析变量引用
+ value = await self._resolve_variable_reference(match)
+
+ # 转换为字符串
+ str_value = self._convert_to_string(value)
+
+ # 替换模板中的变量引用
+ result = result.replace(f"{{{{{match}}}}}", str_value)
+
+ logger.debug(f"已替换变量: {{{{{match}}}}} -> {str_value}")
+
+ except Exception as e:
+ logger.warning(f"解析变量引用失败: {{{{{match}}}}} - {e}")
+ # 保持原始引用不变
+ continue
+
+ return result
+
+ async def _resolve_variable_reference(self, reference: str) -> Any:
+ """解析变量引用
+
+ Args:
+ reference: 变量引用字符串(不含花括号)
+
+ Returns:
+ Any: 变量值
+ """
+ pool_manager = await self._get_pool_manager()
+
+ # 解析作用域和变量名
+ parts = reference.split(".", 1)
+ if len(parts) != 2:
+ raise ValueError(f"无效的变量引用格式: {reference}")
+
+ scope_str, var_path = parts
+
+ # 确定作用域
+ scope_map = {
+ "sys": VariableScope.SYSTEM,
+ "system": VariableScope.SYSTEM,
+ "user": VariableScope.USER,
+ "env": VariableScope.ENVIRONMENT,
+ "environment": VariableScope.ENVIRONMENT,
+ "conversation": VariableScope.CONVERSATION,
+ "conv": VariableScope.CONVERSATION,
+ }
+
+ scope = scope_map.get(scope_str)
+ if not scope:
+ raise ValueError(f"无效的变量作用域: {scope_str}")
+
+ # 解析变量路径
+ # 对于conversation作用域,支持节点输出变量格式:conversation.node_id.key
+ if scope == VariableScope.CONVERSATION and "." in var_path:
+ # 检查是否为节点输出变量(格式:node_id.key)
+ # 先尝试获取完整路径作为变量名
+ try:
+ variable = await pool_manager.get_variable_from_any_pool(
+ name=var_path, # 使用完整路径作为变量名
+ scope=scope,
+ user_id=self.user_id if scope == VariableScope.USER else None,
+ flow_id=self.flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None,
+ conversation_id=self.conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None
+ )
+ if variable:
+ return variable.value
+ except:
+ pass # 如果找不到,继续使用原有逻辑
+
+ # 原有逻辑:支持嵌套访问如 user.config.api_key
+ path_parts = var_path.split(".")
+ var_name = path_parts[0]
+
+ # 根据作用域获取变量
+ variable = await pool_manager.get_variable_from_any_pool(
+ name=var_name,
+ scope=scope,
+ user_id=self.user_id,
+ flow_id=self.flow_id,
+ conversation_id=self.conversation_id
+ )
+
+ if not variable:
+ raise ValueError(f"变量不存在: {scope_str}.{var_name}")
+
+ # 获取变量值
+ value = variable.value
+
+ # 如果有嵌套路径,继续解析
+ for path_part in path_parts[1:]:
+ if isinstance(value, dict):
+ value = value.get(path_part)
+ elif isinstance(value, list) and path_part.isdigit():
+ try:
+ value = value[int(path_part)]
+ except IndexError:
+ value = None
+ else:
+ raise ValueError(f"无法访问路径: {var_path}")
+
+ return value
+
+ async def extract_variables(self, template: str) -> List[str]:
+ """提取模板中的所有变量引用
+
+ Args:
+ template: 模板字符串
+
+ Returns:
+ List[str]: 变量引用列表
+ """
+ if not template:
+ return []
+
+ matches = self.VARIABLE_PATTERN.findall(template)
+ return [f"{{{{{match}}}}}" for match in matches]
+
+ async def validate_template(self, template: str) -> Tuple[bool, List[str]]:
+ """验证模板中的变量引用是否都存在
+
+ Args:
+ template: 模板字符串
+
+ Returns:
+ Tuple[bool, List[str]]: (是否全部有效, 无效的变量引用列表)
+ """
+ if not template:
+ return True, []
+
+ matches = self.VARIABLE_PATTERN.findall(template)
+ invalid_refs = []
+
+ for match in matches:
+ try:
+ await self._resolve_variable_reference(match)
+ except Exception:
+ invalid_refs.append(f"{{{{{match}}}}}")
+
+ return len(invalid_refs) == 0, invalid_refs
+
+ async def parse_json_template(self, json_template: Union[str, Dict, List]) -> Union[str, Dict, List]:
+ """解析JSON格式的模板,递归处理所有字符串值中的变量引用
+
+ Args:
+ json_template: JSON模板(字符串、字典或列表)
+
+ Returns:
+ Union[str, Dict, List]: 解析后的JSON
+ """
+ if isinstance(json_template, str):
+ return await self.parse_template(json_template)
+ elif isinstance(json_template, dict):
+ result = {}
+ for key, value in json_template.items():
+ # 键也可能包含变量引用
+ parsed_key = await self.parse_template(str(key))
+ parsed_value = await self.parse_json_template(value)
+ result[parsed_key] = parsed_value
+ return result
+ elif isinstance(json_template, list):
+ result = []
+ for item in json_template:
+ parsed_item = await self.parse_json_template(item)
+ result.append(parsed_item)
+ return result
+ else:
+ # 其他类型直接返回
+ return json_template
+
+ async def update_system_variables(self, context: Dict[str, Any]):
+ """更新系统变量的值
+
+ Args:
+ context: 系统上下文信息
+ """
+ if not self.conversation_id:
+ logger.warning("无法更新系统变量:缺少conversation_id")
+ return
+
+ # 确保对话变量池存在
+ await self.create_conversation_pool_if_needed()
+
+ pool_manager = await self._get_pool_manager()
+
+ # 预定义的系统变量映射
+ system_var_mappings = {
+ "query": context.get("question", ""),
+ "files": context.get("files", []),
+ "dialogue_count": context.get("dialogue_count", 0),
+ "app_id": context.get("app_id", ""),
+ "flow_id": context.get("flow_id", self.flow_id or ""),
+ "user_id": context.get("user_sub", self.user_id or ""),
+ "session_id": context.get("session_id", ""),
+ "conversation_id": self.conversation_id,
+ "timestamp": datetime.now(UTC).timestamp(),
+ }
+
+ # 获取对话变量池
+ conversation_pool = await pool_manager.get_conversation_pool(self.conversation_id)
+ if not conversation_pool:
+ logger.error(f"对话变量池不存在,无法更新系统变量: {self.conversation_id}")
+ return
+
+ # 更新系统变量
+ updated_count = 0
+ for var_name, var_value in system_var_mappings.items():
+ try:
+ success = await conversation_pool.update_system_variable(var_name, var_value)
+ if success:
+ updated_count += 1
+ logger.debug(f"已更新系统变量: {var_name} = {var_value}")
+ else:
+ logger.warning(f"系统变量更新失败: {var_name}")
+ except Exception as e:
+ logger.warning(f"更新系统变量失败: {var_name} - {e}")
+
+ logger.info(f"系统变量更新完成: {updated_count}/{len(system_var_mappings)} 个变量更新成功")
+
+ async def update_conversation_variable(self, var_name: str, value: Any) -> bool:
+ """更新对话变量的值
+
+ Args:
+ var_name: 变量名
+ value: 新值
+
+ Returns:
+ bool: 是否更新成功
+ """
+ if not self.conversation_id:
+ logger.warning("无法更新对话变量:缺少conversation_id")
+ return False
+
+ pool_manager = await self._get_pool_manager()
+ conversation_pool = await pool_manager.get_conversation_pool(self.conversation_id)
+
+ if not conversation_pool:
+ logger.warning(f"无法获取对话变量池: {self.conversation_id}")
+ return False
+
+ try:
+ await conversation_pool.update_variable(var_name, value=value)
+ logger.info(f"已更新对话变量: {var_name} = {value}")
+ return True
+ except Exception as e:
+ logger.error(f"更新对话变量失败: {var_name} - {e}")
+ return False
+
+ async def create_conversation_pool_if_needed(self) -> bool:
+ """如果需要,创建对话变量池
+
+ Returns:
+ bool: 是否创建成功
+ """
+ if not self.conversation_id or not self.flow_id:
+ return False
+
+ pool_manager = await self._get_pool_manager()
+ existing_pool = await pool_manager.get_conversation_pool(self.conversation_id)
+
+ if existing_pool:
+ return True
+
+ try:
+ await pool_manager.create_conversation_pool(self.conversation_id, self.flow_id)
+ logger.info(f"已创建对话变量池: {self.conversation_id}")
+ return True
+ except Exception as e:
+ logger.error(f"创建对话变量池失败: {self.conversation_id} - {e}")
+ return False
+
+ def _convert_to_string(self, value: Any) -> str:
+ """将值转换为字符串
+
+ Args:
+ value: 要转换的值
+
+ Returns:
+ str: 字符串表示
+ """
+ if value is None:
+ return ""
+ elif isinstance(value, str):
+ return value
+ elif isinstance(value, bool):
+ return str(value).lower()
+ elif isinstance(value, (int, float)):
+ return str(value)
+ elif isinstance(value, (dict, list)):
+ try:
+ return json.dumps(value, ensure_ascii=False, separators=(',', ':'))
+ except (TypeError, ValueError):
+ return str(value)
+ else:
+ return str(value)
+
+ @classmethod
+ def escape_variable_reference(cls, text: str) -> str:
+ """转义变量引用,防止被解析
+
+ Args:
+ text: 包含变量引用的文本
+
+ Returns:
+ str: 转义后的文本
+ """
+ return text.replace("{{", "\\{\\{").replace("}}", "\\}\\}")
+
+ @classmethod
+ def unescape_variable_reference(cls, text: str) -> str:
+ """取消转义变量引用
+
+ Args:
+ text: 转义的文本
+
+ Returns:
+ str: 取消转义后的文本
+ """
+ return text.replace("\\{\\{", "{{").replace("\\}\\}", "}}")
+
+
+class VariableReferenceBuilder:
+ """变量引用构建器 - 帮助构建标准的变量引用字符串"""
+
+ @staticmethod
+ def system(var_name: str, nested_path: Optional[str] = None) -> str:
+ """构建系统变量引用
+
+ Args:
+ var_name: 变量名
+ nested_path: 嵌套路径(如 config.api_key)
+
+ Returns:
+ str: 变量引用字符串
+ """
+ if nested_path:
+ return f"{{{{sys.{var_name}.{nested_path}}}}}"
+ return f"{{{{sys.{var_name}}}}}"
+
+ @staticmethod
+ def user(var_name: str, nested_path: Optional[str] = None) -> str:
+ """构建用户变量引用
+
+ Args:
+ var_name: 变量名
+ nested_path: 嵌套路径
+
+ Returns:
+ str: 变量引用字符串
+ """
+ if nested_path:
+ return f"{{{{user.{var_name}.{nested_path}}}}}"
+ return f"{{{{user.{var_name}}}}}"
+
+ @staticmethod
+ def environment(var_name: str, nested_path: Optional[str] = None) -> str:
+ """构建环境变量引用
+
+ Args:
+ var_name: 变量名
+ nested_path: 嵌套路径
+
+ Returns:
+ str: 变量引用字符串
+ """
+ if nested_path:
+ return f"{{{{env.{var_name}.{nested_path}}}}}"
+ return f"{{{{env.{var_name}}}}}"
+
+ @staticmethod
+ def conversation(var_name: str, nested_path: Optional[str] = None) -> str:
+ """构建对话变量引用
+
+ Args:
+ var_name: 变量名
+ nested_path: 嵌套路径
+
+ Returns:
+ str: 变量引用字符串
+ """
+ if nested_path:
+ return f"{{{{conversation.{var_name}.{nested_path}}}}}"
+ return f"{{{{conversation.{var_name}}}}}"
+
+
+class VariableContext:
+    """Variable context manager — layers a local-variable scope over a VariableParser.
+
+    Local variables use the "${name}" syntax and are substituted before the
+    parser handles the global "{{scope.name}}" references.
+    """
+
+    def __init__(self,
+                 parser: VariableParser,
+                 parent_context: Optional["VariableContext"] = None):
+        """Initialize the context.
+
+        Args:
+            parser: variable parser used for global references
+            parent_context: parent context (for nested scopes; lookups fall
+                back to the parent when a name is not found locally)
+        """
+        self.parser = parser
+        self.parent_context = parent_context
+        self._local_variables: Dict[str, Any] = {}
+
+    def set_local_variable(self, name: str, value: Any):
+        """Set a local variable in this context.
+
+        Args:
+            name: variable name
+            value: variable value
+        """
+        self._local_variables[name] = value
+
+    def get_local_variable(self, name: str) -> Any:
+        """Look up a local variable, falling back to ancestor contexts.
+
+        Args:
+            name: variable name
+
+        Returns:
+            Any: the value, or None when the name is unknown in the whole chain
+        """
+        if name in self._local_variables:
+            return self._local_variables[name]
+        elif self.parent_context:
+            return self.parent_context.get_local_variable(name)
+        return None
+
+    async def parse_with_locals(self, template: str) -> str:
+        """Parse a template using local variables first, then global references.
+
+        Args:
+            template: template string
+
+        Returns:
+            str: the fully substituted string
+        """
+        # First substitute the locals
+        result = template
+
+        # Locals use the simple ${var_name} syntax; NOTE(review): only this
+        # context's own locals are substituted here, not the parent chain
+        for var_name, var_value in self._local_variables.items():
+            pattern = f"${{{var_name}}}"
+            str_value = self.parser._convert_to_string(var_value)
+            result = result.replace(pattern, str_value)
+
+        # Then let the global parser handle the remaining {{...}} references
+        return await self.parser.parse_template(result)
+
+    def create_child_context(self) -> "VariableContext":
+        """Create a nested child context whose lookups fall back to this one.
+
+        Returns:
+            VariableContext: the child context
+        """
+        return VariableContext(self.parser, self)
\ No newline at end of file
diff --git a/apps/scheduler/variable/pool_base.py b/apps/scheduler/variable/pool_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..83acdffdceaa4409142d56c9c96675ebedf93680
--- /dev/null
+++ b/apps/scheduler/variable/pool_base.py
@@ -0,0 +1,828 @@
+import logging
+import asyncio
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional, Set
+from datetime import datetime, UTC
+
+from apps.common.mongo import MongoDB
+from .base import BaseVariable, VariableMetadata
+from .type import VariableType, VariableScope
+from .variables import create_variable, VARIABLE_CLASS_MAP
+
+logger = logging.getLogger(__name__)
+
+
+class BaseVariablePool(ABC):
+    """Abstract base class for variable pools.
+
+    A pool keeps the variables of one scope in memory and persists them to
+    the MongoDB "variables" collection. Subclasses decide how variables are
+    loaded, which defaults exist, and whether the pool may be modified.
+    """
+
+    def __init__(self, pool_id: str, scope: VariableScope):
+        """Initialize the variable pool.
+
+        Args:
+            pool_id: Pool identifier (e.g. user_id, flow_id, conversation_id).
+            scope: Scope served by this pool.
+        """
+        self.pool_id = pool_id
+        self.scope = scope
+        self._variables: Dict[str, BaseVariable] = {}
+        self._initialized = False
+        self._lock = asyncio.Lock()
+
+    @property
+    def is_initialized(self) -> bool:
+        """Whether the pool has completed initialization."""
+        return self._initialized
+
+    async def initialize(self):
+        """Load persisted variables and set up defaults.
+
+        Idempotent: guarded by an asyncio lock so concurrent callers cannot
+        initialize the pool twice.
+        """
+        async with self._lock:
+            if not self._initialized:
+                await self._load_variables()
+                await self._setup_default_variables()
+                self._initialized = True
+                logger.info(f"已初始化变量池: {self.__class__.__name__}({self.pool_id})")
+
+    @abstractmethod
+    async def _load_variables(self):
+        """Load variables from storage."""
+        pass
+
+    @abstractmethod
+    async def _setup_default_variables(self):
+        """Set up default variables for this pool."""
+        pass
+
+    @abstractmethod
+    def can_modify(self) -> bool:
+        """Return whether variables in this pool may be modified."""
+        pass
+
+    async def add_variable(self,
+                          name: str,
+                          var_type: VariableType,
+                          value: Any = None,
+                          description: Optional[str] = None,
+                          created_by: Optional[str] = None,
+                          is_system: bool = False) -> BaseVariable:
+        """Add a new variable to the pool and persist it.
+
+        Raises:
+            PermissionError: If this pool is read-only.
+            ValueError: If a variable with this name already exists.
+        """
+        if not self.can_modify():
+            raise PermissionError(f"不允许修改{self.scope.value}级变量")
+
+        if name in self._variables:
+            raise ValueError(f"变量 {name} 已存在")
+
+        # Build the variable metadata; subclasses expose user_id / flow_id /
+        # conversation_id attributes as appropriate, hence getattr.
+        metadata = VariableMetadata(
+            name=name,
+            var_type=var_type,
+            scope=self.scope,
+            description=description,
+            user_sub=getattr(self, 'user_id', None),
+            flow_id=getattr(self, 'flow_id', None),
+            conversation_id=getattr(self, 'conversation_id', None),
+            created_by=created_by or "system",
+            is_system=is_system  # Marks whether this is a system variable
+        )
+
+        # Create the variable instance
+        variable = create_variable(metadata, value)
+        self._variables[name] = variable
+
+        # Persist to the database
+        await self._persist_variable(variable)
+
+        logger.info(f"已添加{'系统' if is_system else ''}变量: {name} 到池 {self.pool_id}")
+        return variable
+
+    async def update_variable(self,
+                             name: str,
+                             value: Optional[Any] = None,
+                             var_type: Optional[VariableType] = None,
+                             description: Optional[str] = None,
+                             force_system_update: bool = False) -> BaseVariable:
+        """Update a variable's value/type/description and persist it.
+
+        Args:
+            force_system_update: When True, bypasses the system-variable and
+                read-only-pool checks (for system-internal updates).
+
+        Raises:
+            ValueError: If the variable does not exist.
+            PermissionError: If modification is not allowed.
+        """
+        if name not in self._variables:
+            raise ValueError(f"变量 {name} 不存在")
+
+        variable = self._variables[name]
+
+        # System variables may only change with force_system_update.
+        if hasattr(variable.metadata, 'is_system') and variable.metadata.is_system and not force_system_update:
+            raise PermissionError(f"系统变量 {name} 不允许修改")
+
+        if not self.can_modify() and not force_system_update:
+            raise PermissionError(f"不允许修改{self.scope.value}级变量")
+
+        # Apply the requested field updates.
+        # NOTE(review): None means "leave unchanged", so a value can never be
+        # reset to None through this API — confirm that is intended.
+        if value is not None:
+            variable.value = value
+        if var_type is not None:
+            variable.metadata.var_type = var_type
+        if description is not None:
+            variable.metadata.description = description
+
+        variable.metadata.updated_at = datetime.now(UTC)
+
+        # Persist to the database
+        await self._persist_variable(variable)
+
+        logger.info(f"已更新变量: {name} 在池 {self.pool_id}, 值为{value}")
+        return variable
+
+    async def delete_variable(self, name: str) -> bool:
+        """Delete a variable from the pool and from the database.
+
+        Returns:
+            bool: True if deleted, False if the variable did not exist.
+
+        Raises:
+            PermissionError: If the pool is read-only or the variable is a
+                system variable.
+        """
+        if not self.can_modify():
+            raise PermissionError(f"不允许修改{self.scope.value}级变量")
+
+        if name not in self._variables:
+            return False
+
+        variable = self._variables[name]
+
+        # System variables may never be deleted.
+        if hasattr(variable.metadata, 'is_system') and variable.metadata.is_system:
+            raise PermissionError(f"系统变量 {name} 不允许删除")
+
+        del self._variables[name]
+
+        # Remove from the database
+        await self._delete_variable_from_db(variable)
+
+        logger.info(f"已删除变量: {name} 从池 {self.pool_id}")
+        return True
+
+    async def get_variable(self, name: str) -> Optional[BaseVariable]:
+        """Get a variable by name, or None if absent."""
+        return self._variables.get(name)
+
+    async def list_variables(self, include_system: bool = True) -> List[BaseVariable]:
+        """List all variables, optionally excluding system variables."""
+        if include_system:
+            return list(self._variables.values())
+        else:
+            # Only return non-system variables
+            return [var for var in self._variables.values()
+                   if not (hasattr(var.metadata, 'is_system') and var.metadata.is_system)]
+
+    async def list_system_variables(self) -> List[BaseVariable]:
+        """List only the system variables in this pool."""
+        return [var for var in self._variables.values()
+               if hasattr(var.metadata, 'is_system') and var.metadata.is_system]
+
+    async def has_variable(self, name: str) -> bool:
+        """Check whether a variable with this name exists."""
+        return name in self._variables
+
+    async def copy_variables(self) -> Dict[str, BaseVariable]:
+        """Return deep-ish copies of all variables (new metadata + instance).
+
+        NOTE(review): the copied metadata does not carry over `is_template`;
+        confirm templates are never expected to flow through this path.
+        """
+        copied = {}
+        for name, variable in self._variables.items():
+            # Rebuild metadata so the copy does not share the original's.
+            new_metadata = VariableMetadata(
+                name=variable.metadata.name,
+                var_type=variable.metadata.var_type,
+                scope=variable.metadata.scope,
+                description=variable.metadata.description,
+                user_sub=variable.metadata.user_sub,
+                flow_id=variable.metadata.flow_id,
+                conversation_id=variable.metadata.conversation_id,
+                created_by=variable.metadata.created_by,
+                is_system=getattr(variable.metadata, 'is_system', False)
+            )
+            # Create a fresh variable instance with the same value.
+            copied[name] = create_variable(new_metadata, variable.value)
+        return copied
+
+    async def _persist_variable(self, variable: BaseVariable):
+        """Upsert a variable document into MongoDB with majority/journal
+        write concern; raises on any persistence failure."""
+        try:
+            collection = MongoDB().get_collection("variables")
+            data = variable.serialize()
+
+            # Base query: name + scope ...
+            query = {
+                "metadata.name": variable.name,
+                "metadata.scope": variable.scope.value
+            }
+
+            # ... plus pool-specific conditions (user/flow/conversation ids).
+            self._add_pool_query_conditions(query, variable)
+
+            # Upsert with durable write concern so the write is acknowledged.
+            from pymongo import WriteConcern
+            result = await collection.with_options(
+                write_concern=WriteConcern(w="majority", j=True)
+            ).replace_one(query, data, upsert=True)
+
+            if not (result.acknowledged and (result.matched_count > 0 or result.upserted_id)):
+                raise RuntimeError(f"变量持久化失败: {variable.name}")
+
+        except Exception as e:
+            logger.error(f"持久化变量失败: {e}")
+            raise
+
+    async def _delete_variable_from_db(self, variable: BaseVariable):
+        """Delete a variable document from MongoDB; raises on failure."""
+        try:
+            collection = MongoDB().get_collection("variables")
+
+            query = {
+                "metadata.name": variable.name,
+                "metadata.scope": variable.scope.value
+            }
+
+            # Add pool-specific query conditions.
+            self._add_pool_query_conditions(query, variable)
+
+            from pymongo import WriteConcern
+            result = await collection.with_options(
+                write_concern=WriteConcern(w="majority", j=True)
+            ).delete_one(query)
+
+            if not result.acknowledged:
+                raise RuntimeError(f"变量删除失败: {variable.name}")
+
+        except Exception as e:
+            logger.error(f"删除变量失败: {e}")
+            raise
+
+    @abstractmethod
+    def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable):
+        """Add pool-specific conditions to a MongoDB query dict (in place)."""
+        pass
+
+    async def resolve_variable_reference(self, reference: str) -> Any:
+        """Resolve a `{{name.path...}}` reference against this pool.
+
+        Supports dict-key access and integer list indexing along the dotted
+        path; out-of-range list indexes resolve to None.
+
+        Raises:
+            ValueError: If the variable does not exist or a path segment
+                cannot be applied to the current value.
+        """
+        # Strip the surrounding {{ }} braces and whitespace.
+        clean_ref = reference.strip("{}").strip()
+
+        # Split into the variable name and the nested path.
+        path_parts = clean_ref.split(".")
+        var_name = path_parts[0]
+
+        # Look up the variable
+        variable = await self.get_variable(var_name)
+        if not variable:
+            raise ValueError(f"变量不存在: {var_name}")
+
+        # Start from the variable's value
+        value = variable.value
+
+        # Walk the nested path.
+        for path_part in path_parts[1:]:
+            if isinstance(value, dict):
+                value = value.get(path_part)
+            elif isinstance(value, list) and path_part.isdigit():
+                try:
+                    value = value[int(path_part)]
+                except IndexError:
+                    value = None
+            else:
+                raise ValueError(f"无法访问路径: {clean_ref}")
+
+        return value
+
+
+class UserVariablePool(BaseVariablePool):
+    """User-scoped variable pool (one per user, keyed by user_sub)."""
+
+    def __init__(self, user_id: str):
+        super().__init__(user_id, VariableScope.USER)
+        self.user_id = user_id
+
+    async def _load_variables(self):
+        """Load this user's variables from MongoDB.
+
+        Corrupt documents are logged and skipped rather than aborting the
+        whole load.
+        """
+        try:
+            collection = MongoDB().get_collection("variables")
+            cursor = collection.find({
+                "metadata.scope": VariableScope.USER.value,
+                "metadata.user_sub": self.user_id
+            })
+
+            loaded_count = 0
+            async for doc in cursor:
+                try:
+                    # Map the stored class name back to a concrete variable class.
+                    variable_class_name = doc.get("class")
+                    if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]:
+                        for var_class in VARIABLE_CLASS_MAP.values():
+                            if var_class.__name__ == variable_class_name:
+                                variable = var_class.deserialize(doc)
+                                self._variables[variable.name] = variable
+                                loaded_count += 1
+                                break
+                except Exception as e:
+                    var_name = doc.get("metadata", {}).get("name", "unknown")
+                    logger.warning(f"用户变量 {var_name} 数据损坏: {e}")
+
+            logger.debug(f"用户 {self.user_id} 加载变量完成: {loaded_count} 个")
+
+        except Exception as e:
+            logger.error(f"加载用户变量失败: {e}")
+
+    async def _setup_default_variables(self):
+        """User pools have no default variables."""
+        pass
+
+    def can_modify(self) -> bool:
+        """User variables are always modifiable."""
+        return True
+
+    def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable):
+        """Scope DB queries to this user's variables."""
+        query["metadata.user_sub"] = self.user_id
+
+
+class FlowVariablePool(BaseVariablePool):
+    """Flow-scoped pool: environment variables + system templates + conversation templates.
+
+    The inherited `_variables` dict continues to hold environment variables
+    (backward compatible); system and conversation variable *templates* live
+    in separate dicts and are instantiated per conversation elsewhere.
+    """
+
+    def __init__(self, flow_id: str, parent_flow_id: Optional[str] = None):
+        super().__init__(flow_id, VariableScope.ENVIRONMENT)  # Primary scope stays ENVIRONMENT
+        self.flow_id = flow_id
+        self.parent_flow_id = parent_flow_id
+
+        # Separate storage per variable kind:
+        # _variables keeps environment variables (backward compatibility)
+        self._system_templates: Dict[str, BaseVariable] = {}  # System variable templates
+        self._conversation_templates: Dict[str, BaseVariable] = {}  # Conversation variable templates
+
+    async def _load_variables(self):
+        """Load all variable kinds for this flow (environment + templates)."""
+        try:
+            collection = MongoDB().get_collection("variables")
+            loaded_counts = {"environment": 0, "system_templates": 0, "conversation_templates": 0}
+
+            # 1. Load environment variables
+            env_cursor = collection.find({
+                "metadata.scope": VariableScope.ENVIRONMENT.value,
+                "metadata.flow_id": self.flow_id
+            })
+
+            async for doc in env_cursor:
+                try:
+                    variable_class_name = doc.get("class")
+                    if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]:
+                        for var_class in VARIABLE_CLASS_MAP.values():
+                            if var_class.__name__ == variable_class_name:
+                                variable = var_class.deserialize(doc)
+                                self._variables[variable.name] = variable
+                                loaded_counts["environment"] += 1
+                                break
+                except Exception as e:
+                    var_name = doc.get("metadata", {}).get("name", "unknown")
+                    logger.warning(f"环境变量 {var_name} 数据损坏: {e}")
+
+            # 2. Load system variable templates
+            system_template_cursor = collection.find({
+                "metadata.scope": VariableScope.SYSTEM.value,
+                "metadata.flow_id": self.flow_id,
+                "metadata.is_template": True
+            })
+
+            async for doc in system_template_cursor:
+                try:
+                    variable_class_name = doc.get("class")
+                    if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]:
+                        for var_class in VARIABLE_CLASS_MAP.values():
+                            if var_class.__name__ == variable_class_name:
+                                variable = var_class.deserialize(doc)
+                                self._system_templates[variable.name] = variable
+                                loaded_counts["system_templates"] += 1
+                                break
+                except Exception as e:
+                    var_name = doc.get("metadata", {}).get("name", "unknown")
+                    logger.warning(f"系统变量模板 {var_name} 数据损坏: {e}")
+
+            # 3. Load conversation variable templates
+            conv_template_cursor = collection.find({
+                "metadata.scope": VariableScope.CONVERSATION.value,
+                "metadata.flow_id": self.flow_id,
+                "metadata.is_template": True
+            })
+
+            async for doc in conv_template_cursor:
+                try:
+                    variable_class_name = doc.get("class")
+                    if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]:
+                        for var_class in VARIABLE_CLASS_MAP.values():
+                            if var_class.__name__ == variable_class_name:
+                                variable = var_class.deserialize(doc)
+                                self._conversation_templates[variable.name] = variable
+                                loaded_counts["conversation_templates"] += 1
+                                break
+                except Exception as e:
+                    var_name = doc.get("metadata", {}).get("name", "unknown")
+                    logger.warning(f"对话变量模板 {var_name} 数据损坏: {e}")
+
+            total_loaded = sum(loaded_counts.values())
+            logger.debug(f"流程 {self.flow_id} 加载变量完成: 环境变量{loaded_counts['environment']}个, "
+                        f"系统模板{loaded_counts['system_templates']}个, "
+                        f"对话模板{loaded_counts['conversation_templates']}个, 总计{total_loaded}个")
+
+        except Exception as e:
+            logger.error(f"加载流程变量失败: {e}")
+
+    async def _setup_default_variables(self):
+        """Create and persist the default system variable templates."""
+        # NOTE(review): redundant import — datetime/UTC are already imported
+        # at module level and are unused in this method.
+        from datetime import datetime, UTC
+
+        # Default system variable templates (templates, not instances):
+        # (name, type, description, default value)
+        system_var_templates = [
+            ("query", VariableType.STRING, "用户查询内容", ""),
+            ("files", VariableType.ARRAY_FILE, "用户上传的文件列表", []),
+            ("dialogue_count", VariableType.NUMBER, "对话轮数", 0),
+            ("app_id", VariableType.STRING, "应用ID", ""),
+            ("flow_id", VariableType.STRING, "工作流ID", self.flow_id),
+            ("user_id", VariableType.STRING, "用户ID", ""),
+            ("session_id", VariableType.STRING, "会话ID", ""),
+            ("conversation_id", VariableType.STRING, "对话ID", ""),
+            ("timestamp", VariableType.NUMBER, "当前时间戳", 0),
+        ]
+
+        created_count = 0
+        for var_name, var_type, description, default_value in system_var_templates:
+            # Only create templates that do not exist yet.
+            if var_name not in self._system_templates:
+                metadata = VariableMetadata(
+                    name=var_name,
+                    var_type=var_type,
+                    scope=VariableScope.SYSTEM,
+                    description=description,
+                    flow_id=self.flow_id,
+                    created_by="system",
+                    is_system=True,
+                    is_template=True  # Mark as template
+                )
+                variable = create_variable(metadata, default_value)
+                self._system_templates[var_name] = variable
+
+                # Persist the template; a failure skips it but continues.
+                try:
+                    await self._persist_variable(variable)
+                    created_count += 1
+                    logger.debug(f"已持久化系统变量模板: {var_name}")
+                except Exception as e:
+                    logger.error(f"持久化系统变量模板失败: {var_name} - {e}")
+
+        if created_count > 0:
+            logger.info(f"已为流程 {self.flow_id} 初始化 {created_count} 个系统变量模板")
+
+    def can_modify(self) -> bool:
+        """Environment variables are modifiable."""
+        return True
+
+    # === System variable templates ===
+
+    async def get_system_template(self, name: str) -> Optional[BaseVariable]:
+        """Get a system variable template by name, or None."""
+        return self._system_templates.get(name)
+
+    async def list_system_templates(self) -> List[BaseVariable]:
+        """List all system variable templates."""
+        return list(self._system_templates.values())
+
+    async def add_system_template(self, name: str, var_type: VariableType,
+                                 default_value: Any = None, description: str = None) -> BaseVariable:
+        """Add a system variable template and persist it.
+
+        Raises:
+            ValueError: If a template with this name already exists.
+        """
+        if name in self._system_templates:
+            raise ValueError(f"系统变量模板 {name} 已存在")
+
+        metadata = VariableMetadata(
+            name=name,
+            var_type=var_type,
+            scope=VariableScope.SYSTEM,
+            description=description,
+            flow_id=self.flow_id,
+            created_by="system",
+            is_system=True,
+            is_template=True
+        )
+
+        variable = create_variable(metadata, default_value)
+        self._system_templates[name] = variable
+
+        # Persist to the database
+        await self._persist_variable(variable)
+
+        logger.info(f"已添加系统变量模板: {name} 到流程 {self.flow_id}")
+        return variable
+
+    # === Conversation variable templates ===
+
+    async def get_conversation_template(self, name: str) -> Optional[BaseVariable]:
+        """Get a conversation variable template by name, or None."""
+        return self._conversation_templates.get(name)
+
+    async def list_conversation_templates(self) -> List[BaseVariable]:
+        """List all conversation variable templates."""
+        return list(self._conversation_templates.values())
+
+    async def add_conversation_template(self, name: str, var_type: VariableType,
+                                       default_value: Any = None, description: str = None,
+                                       created_by: str = None) -> BaseVariable:
+        """Add a conversation variable template and persist it.
+
+        Raises:
+            ValueError: If a template with this name already exists.
+        """
+        if name in self._conversation_templates:
+            raise ValueError(f"对话变量模板 {name} 已存在")
+
+        metadata = VariableMetadata(
+            name=name,
+            var_type=var_type,
+            scope=VariableScope.CONVERSATION,
+            description=description,
+            flow_id=self.flow_id,
+            created_by=created_by or "user",
+            is_system=False,
+            is_template=True
+        )
+
+        variable = create_variable(metadata, default_value)
+        self._conversation_templates[name] = variable
+
+        # Persist to the database
+        await self._persist_variable(variable)
+
+        logger.info(f"已添加对话变量模板: {name} 到流程 {self.flow_id}")
+        return variable
+
+    # === Base-class overrides supporting multi-scope lookup ===
+
+    async def get_variable_by_scope(self, name: str, scope: VariableScope) -> Optional[BaseVariable]:
+        """Get a variable from the dict that backs the given scope."""
+        if scope == VariableScope.ENVIRONMENT:
+            return self._variables.get(name)
+        elif scope == VariableScope.SYSTEM:
+            return self._system_templates.get(name)
+        elif scope == VariableScope.CONVERSATION:
+            return self._conversation_templates.get(name)
+        else:
+            return None
+
+    async def list_variables_by_scope(self, scope: VariableScope) -> List[BaseVariable]:
+        """List the variables of one scope's backing dict."""
+        if scope == VariableScope.ENVIRONMENT:
+            return list(self._variables.values())
+        elif scope == VariableScope.SYSTEM:
+            return list(self._system_templates.values())
+        elif scope == VariableScope.CONVERSATION:
+            return list(self._conversation_templates.values())
+        else:
+            return []
+
+    # === Base-class overrides supporting multi-dict operations ===
+
+    async def update_variable(self, name: str, value: Any = None,
+                             var_type: Optional[VariableType] = None,
+                             description: Optional[str] = None,
+                             force_system_update: bool = False) -> BaseVariable:
+        """Update a variable, searching all three backing dicts in order:
+        environment variables, system templates, conversation templates.
+
+        Raises:
+            ValueError: If the name is found in none of them.
+            PermissionError: If a system template is updated without
+                force_system_update.
+        """
+
+        # 1) Environment variables delegate to the base implementation.
+        if name in self._variables:
+            return await super().update_variable(name, value, var_type, description, force_system_update)
+
+        # 2) System variable templates.
+        elif name in self._system_templates:
+            variable = self._system_templates[name]
+
+            # Permission check for system templates.
+            if not force_system_update and getattr(variable.metadata, 'is_system', False):
+                raise PermissionError(f"系统变量 {name} 不允许直接修改")
+
+            # Apply field updates (None means "leave unchanged").
+            # NOTE(review): unlike the base implementation this does not
+            # refresh metadata.updated_at — confirm whether that is intended.
+            if value is not None:
+                variable.value = value
+            if var_type is not None:
+                variable.metadata.var_type = var_type
+            if description is not None:
+                variable.metadata.description = description
+
+            # Persist
+            await self._persist_variable(variable)
+            return variable
+
+        # 3) Conversation variable templates.
+        elif name in self._conversation_templates:
+            variable = self._conversation_templates[name]
+
+            # Apply field updates (None means "leave unchanged").
+            if value is not None:
+                variable.value = value
+            if var_type is not None:
+                variable.metadata.var_type = var_type
+            if description is not None:
+                variable.metadata.description = description
+
+            # Persist
+            await self._persist_variable(variable)
+            return variable
+
+        else:
+            raise ValueError(f"变量 {name} 不存在")
+
+    async def delete_variable(self, name: str) -> bool:
+        """Delete a variable, searching all three backing dicts in order.
+
+        Returns:
+            bool: True if deleted, False if the name was not found.
+
+        Raises:
+            PermissionError: If the target is a system template.
+        """
+
+        # 1) Environment variables delegate to the base implementation.
+        if name in self._variables:
+            return await super().delete_variable(name)
+
+        # 2) System variable templates (never deletable).
+        elif name in self._system_templates:
+            variable = self._system_templates[name]
+
+            # Permission check
+            if getattr(variable.metadata, 'is_system', False):
+                raise PermissionError(f"系统变量模板 {name} 不允许删除")
+
+            del self._system_templates[name]
+            await self._delete_variable_from_db(variable)
+            return True
+
+        # 3) Conversation variable templates.
+        elif name in self._conversation_templates:
+            variable = self._conversation_templates[name]
+            del self._conversation_templates[name]
+            await self._delete_variable_from_db(variable)
+            return True
+
+        else:
+            return False
+
+    async def get_variable(self, name: str) -> Optional[BaseVariable]:
+        """Get a variable, searching environment variables first, then
+        system templates, then conversation templates."""
+
+        if name in self._variables:
+            return self._variables[name]
+
+        elif name in self._system_templates:
+            return self._system_templates[name]
+
+        elif name in self._conversation_templates:
+            return self._conversation_templates[name]
+
+        else:
+            return None
+
+    def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable):
+        """Scope DB queries to this flow's variables."""
+        query["metadata.flow_id"] = self.flow_id
+
+    async def inherit_from_parent(self, parent_pool: "FlowVariablePool"):
+        """Copy the parent flow's environment variables into this pool.
+
+        Each inherited copy is re-tagged with this flow's id and persisted.
+        """
+        parent_variables = await parent_pool.copy_variables()
+        for name, variable in parent_variables.items():
+            # Re-tag the copy with this flow's id before storing it.
+            variable.metadata.flow_id = self.flow_id
+            self._variables[name] = variable
+            # Persist each inherited variable.
+            await self._persist_variable(variable)
+
+        logger.info(f"流程 {self.flow_id} 从父流程 {parent_pool.flow_id} 继承了 {len(parent_variables)} 个环境变量")
+
+
+class ConversationVariablePool(BaseVariablePool):
+    """Conversation-scoped pool - holds system variables and conversation variables.
+
+    Instances are created from the owning flow's templates on first
+    initialization.
+    """
+
+    def __init__(self, conversation_id: str, flow_id: str):
+        super().__init__(conversation_id, VariableScope.CONVERSATION)
+        self.conversation_id = conversation_id
+        self.flow_id = flow_id
+
+    async def _load_variables(self):
+        """Load this conversation's persisted variables from MongoDB.
+
+        Corrupt documents are logged and skipped.
+        """
+        try:
+            collection = MongoDB().get_collection("variables")
+            cursor = collection.find({
+                "metadata.scope": VariableScope.CONVERSATION.value,
+                "metadata.conversation_id": self.conversation_id
+            })
+
+            loaded_count = 0
+            async for doc in cursor:
+                try:
+                    # Map the stored class name back to a concrete variable class.
+                    variable_class_name = doc.get("class")
+                    if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]:
+                        for var_class in VARIABLE_CLASS_MAP.values():
+                            if var_class.__name__ == variable_class_name:
+                                variable = var_class.deserialize(doc)
+                                self._variables[variable.name] = variable
+                                loaded_count += 1
+                                break
+                except Exception as e:
+                    var_name = doc.get("metadata", {}).get("name", "unknown")
+                    logger.warning(f"对话变量 {var_name} 数据损坏: {e}")
+
+            logger.debug(f"对话 {self.conversation_id} 加载变量完成: {loaded_count} 个")
+
+        except Exception as e:
+            logger.error(f"加载对话变量失败: {e}")
+
+    async def _setup_default_variables(self):
+        """Instantiate system and conversation variables from the flow's templates."""
+        # Imported locally: pool_manager imports this module at its top, so a
+        # module-level import here would be circular.
+        from .pool_manager import get_pool_manager
+
+        try:
+            pool_manager = await get_pool_manager()
+            flow_pool = await pool_manager.get_flow_pool(self.flow_id)
+
+            if not flow_pool:
+                logger.warning(f"未找到流程池 {self.flow_id},无法继承变量模板")
+                return
+
+            created_count = 0
+
+            # 1. Instantiate system variables from the system templates.
+            system_templates = await flow_pool.list_system_templates()
+            for template in system_templates:
+                if template.name not in self._variables:
+                    # Build an instance (not a template) in conversation scope.
+                    metadata = VariableMetadata(
+                        name=template.name,
+                        var_type=template.var_type,
+                        scope=VariableScope.CONVERSATION,  # Stored in conversation scope
+                        description=template.metadata.description,
+                        flow_id=self.flow_id,
+                        conversation_id=self.conversation_id,
+                        created_by="system",
+                        is_system=True,  # Marked as a system variable
+                        is_template=False  # This is an instance, not a template
+                    )
+
+                    # Seed the instance with the template's default value.
+                    variable = create_variable(metadata, template.value)
+                    self._variables[template.name] = variable
+
+                    # Persist the system-variable instance.
+                    try:
+                        await self._persist_variable(variable)
+                        created_count += 1
+                        logger.debug(f"已从模板创建系统变量实例: {template.name}")
+                    except Exception as e:
+                        logger.error(f"持久化系统变量实例失败: {template.name} - {e}")
+
+            # 2. Instantiate conversation variables from the conversation templates.
+            conversation_templates = await flow_pool.list_conversation_templates()
+            for template in conversation_templates:
+                if template.name not in self._variables:
+                    # Build the conversation-variable instance.
+                    metadata = VariableMetadata(
+                        name=template.name,
+                        var_type=template.var_type,
+                        scope=VariableScope.CONVERSATION,
+                        description=template.metadata.description,
+                        flow_id=self.flow_id,
+                        conversation_id=self.conversation_id,
+                        created_by=template.metadata.created_by,
+                        is_system=False,  # Conversation variable
+                        is_template=False  # This is an instance, not a template
+                    )
+
+                    # Seed the instance with the template's default value.
+                    variable = create_variable(metadata, template.value)
+                    self._variables[template.name] = variable
+
+                    # Persist the conversation-variable instance.
+                    try:
+                        await self._persist_variable(variable)
+                        created_count += 1
+                        logger.debug(f"已从模板创建对话变量实例: {template.name}")
+                    except Exception as e:
+                        logger.error(f"持久化对话变量实例失败: {template.name} - {e}")
+
+            if created_count > 0:
+                logger.info(f"已为对话 {self.conversation_id} 从流程模板继承 {created_count} 个变量")
+
+        except Exception as e:
+            logger.error(f"从流程模板继承变量失败: {e}")
+
+    def can_modify(self) -> bool:
+        """Conversation variables are modifiable."""
+        return True
+
+    def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable):
+        """Scope DB queries to this conversation (and its flow)."""
+        query["metadata.conversation_id"] = self.conversation_id
+        query["metadata.flow_id"] = self.flow_id
+
+    async def update_system_variable(self, name: str, value: Any) -> bool:
+        """Update a system variable's value (system-internal use only).
+
+        Returns:
+            bool: True on success; failures are logged and return False.
+        """
+        try:
+            await self.update_variable(name, value=value, force_system_update=True)
+            return True
+        except Exception as e:
+            logger.error(f"更新系统变量失败: {name} - {e}")
+            return False
+
+    async def inherit_from_conversation_template(self, template_pool: Optional["ConversationVariablePool"] = None):
+        """Copy non-system variables from a template pool, if one is given.
+
+        NOTE(review): inherited copies are not persisted here, unlike
+        FlowVariablePool.inherit_from_parent — confirm intended.
+        """
+        if template_pool:
+            template_variables = await template_pool.copy_variables()
+            for name, variable in template_variables.items():
+                # Only non-system variables are inherited.
+                if not (hasattr(variable.metadata, 'is_system') and variable.metadata.is_system):
+                    variable.metadata.conversation_id = self.conversation_id
+                    self._variables[name] = variable
+
+            logger.info(f"对话 {self.conversation_id} 从模板继承了 {len(template_variables)} 个变量")
\ No newline at end of file
diff --git a/apps/scheduler/variable/pool_manager.py b/apps/scheduler/variable/pool_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..479015add1115e479ce1919e004906f91224c579
--- /dev/null
+++ b/apps/scheduler/variable/pool_manager.py
@@ -0,0 +1,353 @@
+import logging
+import asyncio
+from typing import Dict, List, Optional, Set, Tuple, Any
+from contextlib import asynccontextmanager
+
+from apps.common.mongo import MongoDB
+from .pool_base import (
+ BaseVariablePool,
+ UserVariablePool,
+ FlowVariablePool,
+ ConversationVariablePool
+)
+from .type import VariableScope
+from .base import BaseVariable
+
+logger = logging.getLogger(__name__)
+
+
+class VariablePoolManager:
+    """Variable-pool manager - owns the lifecycle of all pool kinds.
+
+    Caches user, flow and conversation pools in memory and routes
+    scope-based variable queries to the right pool.
+    """
+
+    def __init__(self):
+        """Initialize the manager's caches."""
+        # User pool cache: user_id -> UserVariablePool
+        self._user_pools: Dict[str, UserVariablePool] = {}
+
+        # Flow pool cache: flow_id -> FlowVariablePool
+        self._flow_pools: Dict[str, FlowVariablePool] = {}
+
+        # Conversation pool cache: conversation_id -> ConversationVariablePool
+        self._conversation_pools: Dict[str, ConversationVariablePool] = {}
+
+        # Flow inheritance cache: child_flow_id -> parent_flow_id
+        self._flow_inheritance: Dict[str, str] = {}
+
+        self._initialized = False
+        self._lock = asyncio.Lock()
+
+    async def initialize(self):
+        """Discover existing entities and create their pools (idempotent)."""
+        async with self._lock:
+            if not self._initialized:
+                await self._load_existing_entities()
+                await self._patrol_and_create_missing_pools()
+                self._initialized = True
+                logger.info("变量池管理器初始化完成")
+
+    async def _load_existing_entities(self):
+        """Discover which users and flows need pools.
+
+        Should ideally come from the user/flow tables; for now the entities
+        are inferred from distinct ids found in the variables collection.
+        On failure the discovered sets are left empty.
+        """
+        try:
+            collection = MongoDB().get_collection("variables")
+
+            # All distinct user ids that own variables.
+            user_ids = await collection.distinct("metadata.user_sub", {
+                "metadata.user_sub": {"$ne": None}
+            })
+            logger.info(f"发现 {len(user_ids)} 个用户需要变量池")
+
+            # All distinct flow ids that own variables.
+            flow_ids = await collection.distinct("metadata.flow_id", {
+                "metadata.flow_id": {"$ne": None}
+            })
+            logger.info(f"发现 {len(flow_ids)} 个流程需要变量池")
+
+            # Cache the discovered entities for pool creation below.
+            self._discovered_users = set(user_ids)
+            self._discovered_flows = set(flow_ids)
+
+        except Exception as e:
+            logger.error(f"加载现有实体失败: {e}")
+            self._discovered_users = set()
+            self._discovered_flows = set()
+
+    async def _patrol_and_create_missing_pools(self):
+        """Create pools for every discovered user and flow that lacks one."""
+        logger.info("开始巡检并创建缺失的变量池...")
+
+        # Create user pools for all discovered users.
+        created_user_pools = 0
+        for user_id in self._discovered_users:
+            if user_id not in self._user_pools:
+                await self._create_user_pool(user_id)
+                created_user_pools += 1
+
+        # Create flow pools for all discovered flows.
+        created_flow_pools = 0
+        for flow_id in self._discovered_flows:
+            if flow_id not in self._flow_pools:
+                await self._create_flow_pool(flow_id)
+                created_flow_pools += 1
+
+        logger.info(f"巡检完成: 创建了 {created_user_pools} 个用户池, "
+                   f"{created_flow_pools} 个流程池")
+
+    async def get_user_pool(self, user_id: str, auto_create: bool = True) -> Optional[UserVariablePool]:
+        """Get (and optionally lazily create) a user pool."""
+        if user_id in self._user_pools:
+            return self._user_pools[user_id]
+
+        if auto_create:
+            return await self._create_user_pool(user_id)
+
+        return None
+
+    async def get_flow_pool(self, flow_id: str, parent_flow_id: Optional[str] = None,
+                           auto_create: bool = True) -> Optional[FlowVariablePool]:
+        """Get (and optionally lazily create) a flow pool."""
+        if flow_id in self._flow_pools:
+            return self._flow_pools[flow_id]
+
+        if auto_create:
+            return await self._create_flow_pool(flow_id, parent_flow_id)
+
+        return None
+
+    async def create_conversation_pool(self, conversation_id: str, flow_id: str) -> ConversationVariablePool:
+        """Create a conversation pool (system + conversation variables).
+
+        NOTE(review): an existing pool for the same conversation_id is
+        replaced (only a warning is logged) — confirm intended.
+        """
+        if conversation_id in self._conversation_pools:
+            logger.warning(f"对话池 {conversation_id} 已存在,将覆盖")
+
+        # Create and initialize the conversation pool.
+        conversation_pool = ConversationVariablePool(conversation_id, flow_id)
+        await conversation_pool.initialize()
+
+        # Inherit from a conversation template pool, if one exists.
+        conversation_template_pool = await self._get_conversation_template_pool(flow_id)
+        await conversation_pool.inherit_from_conversation_template(conversation_template_pool)
+
+        # Cache the pool.
+        self._conversation_pools[conversation_id] = conversation_pool
+
+        logger.info(f"已创建对话变量池: {conversation_id}")
+        return conversation_pool
+
+    async def get_conversation_pool(self, conversation_id: str) -> Optional[ConversationVariablePool]:
+        """Get a cached conversation pool, or None (never auto-creates)."""
+        return self._conversation_pools.get(conversation_id)
+
+    async def remove_conversation_pool(self, conversation_id: str) -> bool:
+        """Drop a conversation pool from the cache. Returns True if removed."""
+        if conversation_id in self._conversation_pools:
+            del self._conversation_pools[conversation_id]
+            logger.info(f"已移除对话变量池: {conversation_id}")
+            return True
+        return False
+
+    async def get_variable_from_any_pool(self,
+                                        name: str,
+                                        scope: VariableScope,
+                                        user_id: Optional[str] = None,
+                                        flow_id: Optional[str] = None,
+                                        conversation_id: Optional[str] = None) -> Optional[BaseVariable]:
+        """Resolve a variable from the pool matching the given scope/ids.
+
+        For CONVERSATION/SYSTEM scope, conversation_id selects the live
+        instance; when only flow_id is given the flow's template is returned.
+        """
+        if scope == VariableScope.USER and user_id:
+            pool = await self.get_user_pool(user_id)
+            return await pool.get_variable(name) if pool else None
+
+        elif scope == VariableScope.ENVIRONMENT and flow_id:
+            pool = await self.get_flow_pool(flow_id)
+            return await pool.get_variable(name) if pool else None
+
+        elif scope == VariableScope.CONVERSATION:
+            if conversation_id:
+                # conversation_id -> the conversation-variable instance.
+                pool = await self.get_conversation_pool(conversation_id)
+                return await pool.get_variable(name) if pool else None
+            elif flow_id:
+                # flow_id only -> the conversation-variable template.
+                flow_pool = await self.get_flow_pool(flow_id)
+                if flow_pool:
+                    return await flow_pool.get_conversation_template(name)
+            return None
+
+        # System-variable lookup.
+        elif scope == VariableScope.SYSTEM:
+            if conversation_id:
+                # Prefer the live system-variable instance of the conversation.
+                pool = await self.get_conversation_pool(conversation_id)
+                if pool:
+                    variable = await pool.get_variable(name)
+                    # Only return it if it really is a system variable.
+                    if variable and hasattr(variable.metadata, 'is_system') and variable.metadata.is_system:
+                        return variable
+            elif flow_id:
+                # flow_id only -> the system-variable template.
+                flow_pool = await self.get_flow_pool(flow_id)
+                if flow_pool:
+                    return await flow_pool.get_system_template(name)
+
+        return None
+
+    async def list_variables_from_any_pool(self,
+                                          scope: VariableScope,
+                                          user_id: Optional[str] = None,
+                                          flow_id: Optional[str] = None,
+                                          conversation_id: Optional[str] = None) -> List[BaseVariable]:
+        """List variables from the pool matching the given scope/ids.
+
+        Mirrors get_variable_from_any_pool: for CONVERSATION/SYSTEM scope,
+        conversation_id lists instances, flow_id alone lists templates.
+        """
+        if scope == VariableScope.USER and user_id:
+            pool = await self.get_user_pool(user_id)
+            return await pool.list_variables() if pool else []
+
+        elif scope == VariableScope.ENVIRONMENT and flow_id:
+            pool = await self.get_flow_pool(flow_id)
+            return await pool.list_variables() if pool else []
+
+        elif scope == VariableScope.CONVERSATION:
+            if conversation_id:
+                # conversation_id -> live conversation-variable instances.
+                pool = await self.get_conversation_pool(conversation_id)
+                if pool:
+                    # Exclude system variables.
+                    return await pool.list_variables(include_system=False)
+            elif flow_id:
+                # flow_id only -> conversation-variable templates.
+                flow_pool = await self.get_flow_pool(flow_id)
+                if flow_pool:
+                    return await flow_pool.list_conversation_templates()
+            return []
+
+        # System-variable listing.
+        elif scope == VariableScope.SYSTEM:
+            if conversation_id:
+                # Prefer the live system-variable instances of the conversation.
+                pool = await self.get_conversation_pool(conversation_id)
+                if pool:
+                    # Only system variables.
+                    return await pool.list_system_variables()
+            elif flow_id:
+                # flow_id only -> system-variable templates.
+                flow_pool = await self.get_flow_pool(flow_id)
+                if flow_pool:
+                    return await flow_pool.list_system_templates()
+            return []
+
+        return []
+
+    async def update_system_variable(self, conversation_id: str, name: str, value: Any) -> bool:
+        """Update a system variable inside one conversation's pool."""
+        conversation_pool = await self.get_conversation_pool(conversation_id)
+        if conversation_pool:
+            return await conversation_pool.update_system_variable(name, value)
+        return False
+
+    async def _create_user_pool(self, user_id: str) -> UserVariablePool:
+        """Create, initialize and cache a user pool."""
+        pool = UserVariablePool(user_id)
+        await pool.initialize()
+        self._user_pools[user_id] = pool
+        logger.info(f"已创建用户变量池: {user_id}")
+        return pool
+
+    async def _create_flow_pool(self, flow_id: str, parent_flow_id: Optional[str] = None) -> FlowVariablePool:
+        """Create, initialize and cache a flow pool, inheriting from the
+        parent flow's pool when one is already cached."""
+        pool = FlowVariablePool(flow_id, parent_flow_id)
+        await pool.initialize()
+
+        # Inherit environment variables from a cached parent pool, if any.
+        if parent_flow_id and parent_flow_id in self._flow_pools:
+            parent_pool = self._flow_pools[parent_flow_id]
+            await pool.inherit_from_parent(parent_pool)
+            self._flow_inheritance[flow_id] = parent_flow_id
+
+        self._flow_pools[flow_id] = pool
+        logger.info(f"已创建流程变量池: {flow_id}")
+        return pool
+
+    async def _get_conversation_template_pool(self, flow_id: str) -> Optional[ConversationVariablePool]:
+        """Get the conversation template pool (currently always None).
+
+        Loading templates from the database could be implemented here;
+        for now this is a deliberate stub.
+        """
+        return None
+
+    async def clear_conversation_variables(self, flow_id: str):
+        """Drop every cached conversation pool belonging to one flow."""
+        to_remove = []
+        for conversation_id, pool in self._conversation_pools.items():
+            if pool.flow_id == flow_id:
+                to_remove.append(conversation_id)
+
+        for conversation_id in to_remove:
+            del self._conversation_pools[conversation_id]
+
+        logger.info(f"已清空工作流 {flow_id} 的 {len(to_remove)} 个对话变量池")
+
+    async def get_pool_stats(self) -> Dict[str, int]:
+        """Return pool counts per kind (for monitoring/debugging)."""
+        return {
+            "user_pools": len(self._user_pools),
+            "flow_pools": len(self._flow_pools),
+            "conversation_pools": len(self._conversation_pools),
+        }
+
+    async def cleanup_unused_pools(self, active_conversations: Set[str]):
+        """Evict cached conversation pools not in the active set."""
+        to_remove = []
+        for conversation_id in self._conversation_pools:
+            if conversation_id not in active_conversations:
+                to_remove.append(conversation_id)
+
+        for conversation_id in to_remove:
+            del self._conversation_pools[conversation_id]
+
+        if to_remove:
+            logger.info(f"清理了 {len(to_remove)} 个未使用的对话变量池")
+
+    @asynccontextmanager
+    async def get_pool_for_scope(self,
+                                scope: VariableScope,
+                                user_id: Optional[str] = None,
+                                flow_id: Optional[str] = None,
+                                conversation_id: Optional[str] = None):
+        """Async context manager yielding the pool for the given scope.
+
+        Raises:
+            ValueError: If no pool can be resolved for the scope/ids.
+        """
+        pool = None
+
+        try:
+            if scope == VariableScope.USER and user_id:
+                pool = await self.get_user_pool(user_id)
+            elif scope == VariableScope.ENVIRONMENT and flow_id:
+                pool = await self.get_flow_pool(flow_id)
+            elif scope in [VariableScope.CONVERSATION, VariableScope.SYSTEM] and conversation_id:
+                pool = await self.get_conversation_pool(conversation_id)
+
+            if not pool:
+                raise ValueError(f"无法获取 {scope.value} 级变量池")
+
+            yield pool
+
+        except Exception:
+            raise
+        finally:
+            # Placeholder for cleanup (e.g. conversation-pool auto-eviction).
+            pass
+
+
+# Global singleton instance of the pool manager.
+_pool_manager = None
+
+
+async def get_pool_manager() -> VariablePoolManager:
+    """Get the global VariablePoolManager, creating it on first call.
+
+    NOTE(review): lazy init is not guarded by a lock, so two concurrent
+    first callers could race on creation — confirm acceptable.
+    """
+    global _pool_manager
+    if _pool_manager is None:
+        _pool_manager = VariablePoolManager()
+        await _pool_manager.initialize()
+    return _pool_manager
+
+
+async def initialize_pool_manager():
+    """Eagerly initialize the pool manager (call at application startup)."""
+    await get_pool_manager()
+    logger.info("变量池管理器已启动")
\ No newline at end of file
diff --git a/apps/scheduler/variable/security.py b/apps/scheduler/variable/security.py
new file mode 100644
index 0000000000000000000000000000000000000000..07fde158c01e7bf9375d7f2a555e3bc2d5b1806b
--- /dev/null
+++ b/apps/scheduler/variable/security.py
@@ -0,0 +1,415 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""变量安全管理模块
+
+提供密钥变量的额外安全保障,包括:
+- 访问审计日志
+- 密钥轮换
+- 安全检查
+- 权限验证
+"""
+
+import logging
+import hashlib
+import time
+from typing import Any, Dict, List, Optional, Set
+from datetime import datetime, UTC, timedelta
+from dataclasses import dataclass
+
+from apps.common.mongo import MongoDB
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AccessLog:
+    """One secret-variable access event (in-memory representation)."""
+    variable_name: str
+    user_sub: str
+    access_time: datetime
+    access_type: str  # one of: read, write, delete
+    ip_address: Optional[str] = None
+    user_agent: Optional[str] = None
+    success: bool = True
+    error_message: Optional[str] = None
+
+
+class SecretVariableSecurity:
+    """Security manager for secret variables.
+
+    Keeps short-lived state (failed attempts) in memory and persists
+    access/audit logs to MongoDB collections.
+    """
+
+    def __init__(self):
+        """Initialize the security manager with empty state and default limits."""
+        # NOTE(review): _access_logs and _blocked_users are never read or
+        # written by the visible methods — confirm they are still needed.
+        self._access_logs: List[AccessLog] = []
+        self._failed_access_attempts: Dict[str, List[datetime]] = {}
+        self._blocked_users: Set[str] = set()
+
+        # Security policy knobs
+        self.max_failed_attempts = 5  # failures within the window before blocking
+        self.block_duration_minutes = 30  # block duration (minutes)
+        self.audit_retention_days = 90  # audit-log retention (days)
+
+    async def verify_access_permission(self,
+                                     variable_name: str,
+                                     user_sub: str,
+                                     access_type: str = "read",
+                                     ip_address: Optional[str] = None) -> bool:
+        """Decide whether *user_sub* may perform *access_type* on a secret.
+
+        Args:
+            variable_name: Variable name.
+            user_sub: User id.
+            access_type: Access type (read, write, delete).
+            ip_address: Caller IP address, if known.
+
+        Returns:
+            bool: True when access is allowed; every denial is logged.
+        """
+        try:
+            # Denied outright while the user is temporarily blocked.
+            if await self._is_user_blocked(user_sub):
+                await self._log_access(
+                    variable_name, user_sub, access_type, ip_address,
+                    success=False, error_message="用户被临时封禁"
+                )
+                return False
+
+            # _check_rate_limit returns True when the limit is EXCEEDED.
+            if await self._check_rate_limit(user_sub):
+                await self._log_access(
+                    variable_name, user_sub, access_type, ip_address,
+                    success=False, error_message="访问频率过高"
+                )
+                await self._record_failed_attempt(user_sub)
+                return False
+
+            # Optional IP allow-list check (currently always passes).
+            if ip_address and not await self._is_ip_allowed(ip_address):
+                await self._log_access(
+                    variable_name, user_sub, access_type, ip_address,
+                    success=False, error_message="IP地址不在允许列表中"
+                )
+                return False
+
+            # Success: log it and reset the failure counter.
+            await self._log_access(variable_name, user_sub, access_type, ip_address)
+            await self._clear_failed_attempts(user_sub)
+
+            return True
+
+        except Exception as e:
+            # Fail closed: any internal error denies access.
+            logger.error(f"验证访问权限失败: {e}")
+            return False
+
+    async def audit_secret_access(self,
+                                variable_name: str,
+                                user_sub: str,
+                                actual_value: str,
+                                access_type: str = "read") -> None:
+        """Record an audit entry for a secret access (stores a hash, not the value).
+
+        Args:
+            variable_name: Variable name.
+            user_sub: User id.
+            actual_value: The plaintext value that was accessed.
+            access_type: Access type.
+        """
+        try:
+            # Hash the value for audit purposes; the raw value is never stored.
+            value_hash = hashlib.sha256(actual_value.encode()).hexdigest()[:16]
+
+            # Detailed audit record
+            audit_record = {
+                "variable_name": variable_name,
+                "user_sub": user_sub,
+                "access_time": datetime.now(UTC),
+                "access_type": access_type,
+                "value_hash": value_hash,
+                "value_length": len(actual_value),
+            }
+
+            # Persist to the audit collection
+            await self._save_audit_record(audit_record)
+
+            logger.info(f"已审计密钥访问: {variable_name} by {user_sub}")
+
+        except Exception as e:
+            logger.error(f"审计密钥访问失败: {e}")
+
+    async def check_secret_strength(self, secret_value: str) -> Dict[str, Any]:
+        """Score a secret's strength (length, character classes, common passwords).
+
+        Args:
+            secret_value: The secret to evaluate.
+
+        Returns:
+            Dict[str, Any]: {is_strong, score, warnings, recommendations};
+                is_strong requires score >= 6 and no warnings.
+        """
+        result = {
+            "is_strong": False,
+            "score": 0,
+            "warnings": [],
+            "recommendations": []
+        }
+
+        try:
+            # Length check (max +2 points)
+            if len(secret_value) < 8:
+                result["warnings"].append("密钥长度过短(建议至少8位)")
+            elif len(secret_value) >= 12:
+                result["score"] += 2
+            else:
+                result["score"] += 1
+
+            # Character-class complexity (max +4 points)
+            has_upper = any(c.isupper() for c in secret_value)
+            has_lower = any(c.islower() for c in secret_value)
+            has_digit = any(c.isdigit() for c in secret_value)
+            has_special = any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in secret_value)
+
+            complexity_score = sum([has_upper, has_lower, has_digit, has_special])
+            result["score"] += complexity_score
+
+            if complexity_score < 3:
+                result["recommendations"].append("建议包含大小写字母、数字和特殊字符")
+
+            # Known weak passwords zero the score outright.
+            common_passwords = ["password", "123456", "admin", "root", "qwerty"]
+            if secret_value.lower() in common_passwords:
+                result["warnings"].append("使用了常见的弱密码")
+                result["score"] = 0
+
+            # Repeated-character check (fewer than 60% unique chars)
+            if len(set(secret_value)) < len(secret_value) * 0.6:
+                result["warnings"].append("包含过多重复字符")
+
+            # Final verdict
+            if result["score"] >= 6 and len(result["warnings"]) == 0:
+                result["is_strong"] = True
+
+            return result
+
+        except Exception as e:
+            logger.error(f"检查密钥强度失败: {e}")
+            return result
+
+    async def rotate_secret_key(self, variable_name: str, user_sub: str) -> bool:
+        """Rotate the encryption key of a secret variable and re-encrypt it.
+
+        NOTE(review): _generate_encryption_key derives the key
+        deterministically from user_sub + variable name (see variables.py),
+        so regeneration may yield the identical key — verify rotation
+        actually produces a new key.
+
+        Args:
+            variable_name: Variable name.
+            user_sub: User id.
+
+        Returns:
+            bool: True on success.
+        """
+        try:
+            from .pool_manager import get_pool_manager
+            from .type import VariableScope
+
+            pool_manager = await get_pool_manager()
+
+            # Secrets are user-scoped: look up the user's pool.
+            user_pool = await pool_manager.get_user_pool(user_sub)
+            if not user_pool:
+                logger.error(f"用户变量池不存在: {user_sub}")
+                return False
+
+            # Only secret-typed variables can be rotated.
+            variable = await user_pool.get_variable(variable_name)
+            if not variable or not variable.var_type.is_secret_type():
+                return False
+
+            # Decrypt with the current key before it is replaced.
+            original_value = variable.value
+
+            # Regenerate the key, then re-encrypt via the value setter.
+            variable._encryption_key = variable._generate_encryption_key()
+            variable.value = original_value  # triggers re-encryption
+
+            # Persist the re-encrypted value.
+            await user_pool._persist_variable(variable)
+
+            # Record the rotation in the access log.
+            await self._log_access(
+                variable_name, user_sub, "key_rotation",
+                success=True, error_message="密钥轮换成功"
+            )
+
+            logger.info(f"已轮换密钥变量的加密密钥: {variable_name}")
+            return True
+
+        except Exception as e:
+            logger.error(f"轮换密钥失败: {e}")
+            return False
+
+    async def get_access_logs(self,
+                            variable_name: Optional[str] = None,
+                            user_sub: Optional[str] = None,
+                            hours: int = 24) -> List[Dict[str, Any]]:
+        """Query recent access logs from MongoDB, newest first.
+
+        Args:
+            variable_name: Optional filter by variable name.
+            user_sub: Optional filter by user id.
+            hours: How many hours back to query.
+
+        Returns:
+            List[Dict[str, Any]]: Matching log entries ([] on error).
+        """
+        try:
+            collection = MongoDB().get_collection("variable_access_logs")
+
+            # Time-window filter, optionally narrowed by name/user.
+            query = {
+                "access_time": {
+                    "$gte": datetime.now(UTC) - timedelta(hours=hours)
+                }
+            }
+
+            if variable_name:
+                query["variable_name"] = variable_name
+            if user_sub:
+                query["user_sub"] = user_sub
+
+            # Project only the fields callers need.
+            logs = []
+            async for doc in collection.find(query).sort("access_time", -1):
+                logs.append({
+                    "variable_name": doc.get("variable_name"),
+                    "user_sub": doc.get("user_sub"),
+                    "access_time": doc.get("access_time"),
+                    "access_type": doc.get("access_type"),
+                    "ip_address": doc.get("ip_address"),
+                    "success": doc.get("success", True),
+                    "error_message": doc.get("error_message")
+                })
+
+            return logs
+
+        except Exception as e:
+            logger.error(f"获取访问日志失败: {e}")
+            return []
+
+    async def _is_user_blocked(self, user_sub: str) -> bool:
+        """Return True if the user has >= max_failed_attempts recent failures.
+
+        "Recent" means within the last block_duration_minutes.
+        """
+        if user_sub not in self._failed_access_attempts:
+            return False
+
+        failed_attempts = self._failed_access_attempts[user_sub]
+        recent_failures = [
+            attempt for attempt in failed_attempts
+            if attempt > datetime.now(UTC) - timedelta(minutes=self.block_duration_minutes)
+        ]
+
+        return len(recent_failures) >= self.max_failed_attempts
+
+    async def _check_rate_limit(self, user_sub: str) -> bool:
+        """Return True when the user EXCEEDED the per-minute rate limit."""
+        try:
+            # Count the user's logged accesses in the last minute.
+            collection = MongoDB().get_collection("variable_access_logs")
+            count = await collection.count_documents({
+                "user_sub": user_sub,
+                "access_time": {
+                    "$gte": datetime.now(UTC) - timedelta(minutes=1)
+                }
+            })
+
+            # Hard limit: at most 30 accesses per minute.
+            return count >= 30
+
+        except Exception as e:
+            # On error, do not rate-limit (fail open for this check only).
+            logger.error(f"检查访问频率失败: {e}")
+            return False
+
+    async def _is_ip_allowed(self, ip_address: str) -> bool:
+        """Placeholder for an IP allow/deny list; currently allows every IP."""
+        return True
+
+    async def _record_failed_attempt(self, user_sub: str):
+        """Append a failed attempt and prune entries older than the block window."""
+        if user_sub not in self._failed_access_attempts:
+            self._failed_access_attempts[user_sub] = []
+
+        self._failed_access_attempts[user_sub].append(datetime.now(UTC))
+
+        # Prune failures older than the block window.
+        cutoff_time = datetime.now(UTC) - timedelta(minutes=self.block_duration_minutes)
+        self._failed_access_attempts[user_sub] = [
+            attempt for attempt in self._failed_access_attempts[user_sub]
+            if attempt > cutoff_time
+        ]
+
+    async def _clear_failed_attempts(self, user_sub: str):
+        """Reset the user's failure counter (called after a successful access)."""
+        if user_sub in self._failed_access_attempts:
+            del self._failed_access_attempts[user_sub]
+
+    async def _log_access(self,
+                        variable_name: str,
+                        user_sub: str,
+                        access_type: str,
+                        ip_address: Optional[str] = None,
+                        success: bool = True,
+                        error_message: Optional[str] = None):
+        """Persist one access-log entry to MongoDB (errors are logged, not raised)."""
+        try:
+            log_entry = {
+                "variable_name": variable_name,
+                "user_sub": user_sub,
+                "access_time": datetime.now(UTC),
+                "access_type": access_type,
+                "ip_address": ip_address,
+                "success": success,
+                "error_message": error_message
+            }
+
+            # Best-effort write; a logging failure must not block access checks.
+            collection = MongoDB().get_collection("variable_access_logs")
+            await collection.insert_one(log_entry)
+
+        except Exception as e:
+            logger.error(f"记录访问日志失败: {e}")
+
+    async def _save_audit_record(self, audit_record: Dict[str, Any]):
+        """Persist one audit record to MongoDB (errors are logged, not raised)."""
+        try:
+            collection = MongoDB().get_collection("variable_audit_logs")
+            await collection.insert_one(audit_record)
+        except Exception as e:
+            logger.error(f"保存审计记录失败: {e}")
+
+    async def cleanup_old_logs(self):
+        """Delete access/audit log entries older than audit_retention_days."""
+        try:
+            cutoff_time = datetime.now(UTC) - timedelta(days=self.audit_retention_days)
+
+            # Access logs
+            access_collection = MongoDB().get_collection("variable_access_logs")
+            result1 = await access_collection.delete_many({
+                "access_time": {"$lt": cutoff_time}
+            })
+
+            # Audit logs (share the access_time field)
+            audit_collection = MongoDB().get_collection("variable_audit_logs")
+            result2 = await audit_collection.delete_many({
+                "access_time": {"$lt": cutoff_time}
+            })
+
+            logger.info(f"已清理过期日志: 访问日志 {result1.deleted_count} 条, 审计日志 {result2.deleted_count} 条")
+
+        except Exception as e:
+            logger.error(f"清理过期日志失败: {e}")
+
+
+# Global security-manager singleton (created lazily on first use).
+_security_manager = None
+
+
+def get_security_manager() -> SecretVariableSecurity:
+    """Return the process-wide SecretVariableSecurity, creating it on first call."""
+    global _security_manager
+    if _security_manager is None:
+        _security_manager = SecretVariableSecurity()
+    return _security_manager
\ No newline at end of file
diff --git a/apps/scheduler/variable/system_variables_example.py b/apps/scheduler/variable/system_variables_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e48de30a11cc3434261361782a2d0af4b8b7bc3
--- /dev/null
+++ b/apps/scheduler/variable/system_variables_example.py
@@ -0,0 +1,220 @@
+"""
+系统变量使用示例
+
+演示系统变量的正确初始化、更新和访问流程
+"""
+
+import asyncio
+from datetime import datetime, UTC
+from typing import Dict, Any
+
+from .pool_manager import get_pool_manager
+from .parser import VariableParser, VariableReferenceBuilder
+from .type import VariableScope
+
+
+async def demonstrate_system_variables():
+    """Walk through the full system-variable workflow (demo/example code)."""
+
+    # Simulated conversation parameters
+    user_id = "user123"
+    flow_id = "flow456"
+    conversation_id = "conv789"
+
+    print("=== 系统变量演示 ===\n")
+
+    # 1. Create a variable parser (auto-creates the conversation pool and
+    #    initializes the system variables).
+    print("1. 创建变量解析器并初始化对话变量池...")
+    parser = VariableParser(
+        user_id=user_id,
+        flow_id=flow_id,
+        conversation_id=conversation_id
+    )
+
+    # Make sure the conversation pool exists.
+    success = await parser.create_conversation_pool_if_needed()
+    print(f"   对话池创建结果: {'成功' if success else '失败'}")
+
+    # 2. Verify the system variables were initialized correctly.
+    print("\n2. 检查初始化的系统变量...")
+    pool_manager = await get_pool_manager()
+    conversation_pool = await pool_manager.get_conversation_pool(conversation_id)
+
+    if conversation_pool:
+        system_vars = await conversation_pool.list_system_variables()
+        print(f"   已初始化 {len(system_vars)} 个系统变量:")
+        for var in system_vars:
+            print(f"   - {var.name}: {var.value} ({var.var_type.value})")
+
+    # 3. Update system variables (simulating a conversation start).
+    print("\n3. 更新系统变量...")
+    context = {
+        "question": "请帮我分析这个数据文件",
+        "files": [{"name": "data.csv", "size": 1024, "type": "text/csv"}],
+        "dialogue_count": 1,
+        "app_id": "app001",
+        "user_sub": user_id,
+        "session_id": "session123"
+    }
+
+    await parser.update_system_variables(context)
+    print("   系统变量更新完成")
+
+    # 4. Verify the system variables were updated.
+    print("\n4. 验证系统变量更新结果...")
+    updated_vars = await conversation_pool.list_system_variables()
+    for var in updated_vars:
+        if var.name in ["query", "files", "dialogue_count", "app_id", "user_id", "session_id"]:
+            print(f"   - {var.name}: {var.value}")
+
+    # 5. Resolve a template that references system variables.
+    print("\n5. 解析包含系统变量的模板...")
+    template = """
+用户查询: {{sys.query}}
+对话轮数: {{sys.dialogue_count}}
+流程ID: {{sys.flow_id}}
+用户ID: {{sys.user_id}}
+文件数量: {{sys.files.length}}
+"""
+
+    try:
+        parsed_result = await parser.parse_template(template)
+        print("   模板解析结果:")
+        print(parsed_result)
+    except Exception as e:
+        print(f"   模板解析失败: {e}")
+
+    # 6. Verify system variables are read-only from the public API.
+    print("\n6. 验证系统变量的只读保护...")
+    try:
+        # Attempting a direct modification should fail with PermissionError.
+        await conversation_pool.update_variable("query", value="恶意修改")
+        print("   ❌ 错误:系统变量被意外修改")
+    except PermissionError:
+        print("   ✅ 正确:系统变量只读保护生效")
+    except Exception as e:
+        print(f"   🤔 意外错误: {e}")
+
+    # 7. Show the forced (internal-only) system-variable update path.
+    print("\n7. 演示系统变量的内部更新...")
+    success = await conversation_pool.update_system_variable("dialogue_count", 2)
+    if success:
+        updated_var = await conversation_pool.get_variable("dialogue_count")
+        print(f"   ✅ 系统变量内部更新成功: dialogue_count = {updated_var.value}")
+    else:
+        print("   ❌ 系统变量内部更新失败")
+
+    # 8. Clean up the demo conversation pool.
+    print("\n8. 清理对话变量池...")
+    removed = await pool_manager.remove_conversation_pool(conversation_id)
+    print(f"   清理结果: {'成功' if removed else '失败'}")
+
+    print("\n=== 演示完成 ===")
+
+
+async def demonstrate_variable_references():
+    """Show how variable references are built and embedded in templates (demo)."""
+
+    print("\n=== 变量引用演示 ===\n")
+
+    # Build references for each supported scope.
+    print("1. 变量引用构建示例:")
+
+    # System-variable references
+    query_ref = VariableReferenceBuilder.system("query")
+    files_ref = VariableReferenceBuilder.system("files", "0.name")  # nested access
+
+    # User-variable reference
+    api_key_ref = VariableReferenceBuilder.user("api_key")
+
+    # Environment-variable reference
+    db_url_ref = VariableReferenceBuilder.environment("database_url")
+
+    # Conversation-variable reference
+    history_ref = VariableReferenceBuilder.conversation("chat_history")
+
+    print(f"   系统变量 - 用户查询: {query_ref}")
+    print(f"   系统变量 - 首个文件名: {files_ref}")
+    print(f"   用户变量 - API密钥: {api_key_ref}")
+    print(f"   环境变量 - 数据库: {db_url_ref}")
+    print(f"   对话变量 - 聊天历史: {history_ref}")
+
+    # Compose the references into a larger template.
+    print("\n2. 复杂模板示例:")
+    complex_template = f"""
+# 对话上下文
+- 用户: {query_ref}
+- 轮次: {VariableReferenceBuilder.system("dialogue_count")}
+- 时间: {VariableReferenceBuilder.system("timestamp")}
+
+# 文件信息
+- 文件列表: {VariableReferenceBuilder.system("files")}
+- 文件数量: {VariableReferenceBuilder.system("files", "length")}
+
+# 会话信息
+- 对话ID: {VariableReferenceBuilder.system("conversation_id")}
+- 流程ID: {VariableReferenceBuilder.system("flow_id")}
+- 用户ID: {VariableReferenceBuilder.system("user_id")}
+"""
+
+    print(complex_template)
+
+    print("=== 引用演示完成 ===")
+
+
+async def validate_system_variable_persistence():
+    """Check that system variables survive a pool remove/recreate cycle (demo)."""
+
+    print("\n=== 持久化验证 ===\n")
+
+    conversation_id = "test_persistence_conv"
+    flow_id = "test_persistence_flow"
+
+    # Create a fresh conversation pool.
+    pool_manager = await get_pool_manager()
+    conversation_pool = await pool_manager.create_conversation_pool(conversation_id, flow_id)
+
+    print("1. 检查新创建池的系统变量...")
+    system_vars_before = await conversation_pool.list_system_variables()
+    print(f"   创建后的系统变量数量: {len(system_vars_before)}")
+
+    # Simulate an application restart by dropping the cached pool.
+    print("\n2. 模拟重新加载...")
+    await pool_manager.remove_conversation_pool(conversation_id)
+
+    # Recreate the same conversation pool and re-read its system variables.
+    conversation_pool_reloaded = await pool_manager.create_conversation_pool(conversation_id, flow_id)
+    system_vars_after = await conversation_pool_reloaded.list_system_variables()
+
+    print(f"   重新加载后的系统变量数量: {len(system_vars_after)}")
+
+    # Compare variable names before and after the reload.
+    vars_before_names = {var.name for var in system_vars_before}
+    vars_after_names = {var.name for var in system_vars_after}
+
+    if vars_before_names == vars_after_names:
+        print("   ✅ 系统变量持久化验证成功")
+    else:
+        print("   ❌ 系统变量持久化验证失败")
+        print(f"   之前: {vars_before_names}")
+        print(f"   之后: {vars_after_names}")
+
+    # Clean up the test pool.
+    await pool_manager.remove_conversation_pool(conversation_id)
+
+    print("=== 持久化验证完成 ===")
+
+
+if __name__ == "__main__":
+    async def main():
+        """Run every demo in sequence; print the traceback on any failure."""
+        try:
+            await demonstrate_system_variables()
+            await demonstrate_variable_references()
+            await validate_system_variable_persistence()
+        except Exception as e:
+            print(f"演示过程中发生错误: {e}")
+            import traceback
+            traceback.print_exc()
+
+    asyncio.run(main())
\ No newline at end of file
diff --git a/apps/scheduler/variable/type.py b/apps/scheduler/variable/type.py
new file mode 100644
index 0000000000000000000000000000000000000000..9eb4627d8b5cf3762c24eb9bc0214623d56c5c8b
--- /dev/null
+++ b/apps/scheduler/variable/type.py
@@ -0,0 +1,72 @@
+from enum import StrEnum
+
+
+class VariableType(StrEnum):
+    """Enumeration of variable value types (scalar and array forms)."""
+    # Scalar types
+    NUMBER = "number"
+    STRING = "string"
+    BOOLEAN = "boolean"
+    OBJECT = "object"
+    SECRET = "secret"
+    GROUP = "group"
+    FILE = "file"
+
+    # Array types
+    ARRAY_ANY = "array[any]"
+    ARRAY_STRING = "array[string]"
+    ARRAY_NUMBER = "array[number]"
+    ARRAY_OBJECT = "array[object]"
+    ARRAY_FILE = "array[file]"
+    ARRAY_BOOLEAN = "array[boolean]"
+    ARRAY_SECRET = "array[secret]"
+
+    def is_array_type(self) -> bool:
+        """Return True for any array[...] type."""
+        return self in _ARRAY_TYPES
+
+    def is_secret_type(self) -> bool:
+        """Return True for SECRET and ARRAY_SECRET."""
+        return self in _SECRET_TYPES
+
+    def get_array_element_type(self) -> "VariableType | None":
+        """Return the element type of an array type.
+
+        Returns None for non-array types and for ARRAY_ANY (no single
+        element type can be determined).
+        """
+        if not self.is_array_type():
+            return None
+
+        element_type_map = {
+            VariableType.ARRAY_ANY: None,  # "any" has no determinable element type
+            VariableType.ARRAY_STRING: VariableType.STRING,
+            VariableType.ARRAY_NUMBER: VariableType.NUMBER,
+            VariableType.ARRAY_OBJECT: VariableType.OBJECT,
+            VariableType.ARRAY_FILE: VariableType.FILE,
+            VariableType.ARRAY_BOOLEAN: VariableType.BOOLEAN,
+            VariableType.ARRAY_SECRET: VariableType.SECRET,
+        }
+        return element_type_map.get(self)
+
+
+class VariableScope(StrEnum):
+    """Enumeration of variable scopes."""
+    SYSTEM = "system"          # system-level variables (read-only)
+    USER = "user"              # user-level variables (follow the user)
+    ENVIRONMENT = "env"        # environment-level variables (follow the flow)
+    CONVERSATION = "conversation"  # conversation-level (re-initialized per run)
+
+
+# All array-valued members; backs VariableType.is_array_type().
+_ARRAY_TYPES = frozenset([
+    VariableType.ARRAY_ANY,
+    VariableType.ARRAY_STRING,
+    VariableType.ARRAY_NUMBER,
+    VariableType.ARRAY_OBJECT,
+    VariableType.ARRAY_FILE,
+    VariableType.ARRAY_BOOLEAN,
+    VariableType.ARRAY_SECRET,
+])
+
+# Secret-valued members; backs VariableType.is_secret_type().
+_SECRET_TYPES = frozenset([
+    VariableType.SECRET,
+    VariableType.ARRAY_SECRET,
+])
\ No newline at end of file
diff --git a/apps/scheduler/variable/variables.py b/apps/scheduler/variable/variables.py
new file mode 100644
index 0000000000000000000000000000000000000000..2902dd3adf454ff96aaa98eb2507a533ffccae7c
--- /dev/null
+++ b/apps/scheduler/variable/variables.py
@@ -0,0 +1,479 @@
+import json
+import base64
+import hashlib
+from typing import Any, Dict, List, Union, Optional
+from cryptography.fernet import Fernet
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
+
+from .base import BaseVariable, VariableMetadata
+from .type import VariableType
+
+
+class StringVariable(BaseVariable):
+    """String variable."""
+
+    def _validate_type(self, value: Any) -> bool:
+        """Accept str values only."""
+        return isinstance(value, str)
+
+    def to_string(self) -> str:
+        """Return the value as a string ("" when unset)."""
+        return str(self._value) if self._value is not None else ""
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Plain-dict form for API responses."""
+        return {
+            "name": self.name,
+            "type": self.var_type.value,
+            "value": self._value,
+            "scope": self.scope.value
+        }
+
+    def serialize(self) -> Dict[str, Any]:
+        """Persistable form (metadata + raw value + concrete class name)."""
+        return {
+            "metadata": self.metadata.model_dump(),
+            "value": self._value,
+            "class": self.__class__.__name__
+        }
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "StringVariable":
+        """Rebuild an instance from ``serialize()`` output."""
+        metadata = VariableMetadata(**data["metadata"])
+        return cls(metadata, data["value"])
+
+
+class NumberVariable(BaseVariable):
+ """数字变量"""
+
+ def _validate_type(self, value: Any) -> bool:
+ """验证值是否为数字类型"""
+ return isinstance(value, (int, float))
+
+ def to_string(self) -> str:
+ """转换为字符串"""
+ return str(self._value) if self._value is not None else "0"
+
+ def to_dict(self) -> Dict[str, Any]:
+ """转换为字典"""
+ return {
+ "name": self.name,
+ "type": self.var_type.value,
+ "value": self._value,
+ "scope": self.scope.value
+ }
+
+ def serialize(self) -> Dict[str, Any]:
+ """序列化"""
+ return {
+ "metadata": self.metadata.model_dump(),
+ "value": self._value,
+ "class": self.__class__.__name__
+ }
+
+ @classmethod
+ def deserialize(cls, data: Dict[str, Any]) -> "NumberVariable":
+ """反序列化"""
+ metadata = VariableMetadata(**data["metadata"])
+ return cls(metadata, data["value"])
+
+
+class BooleanVariable(BaseVariable):
+    """Boolean variable."""
+
+    def _validate_type(self, value: Any) -> bool:
+        """Accept bool values only."""
+        return isinstance(value, bool)
+
+    def to_string(self) -> str:
+        """Return "true"/"false" ("false" when unset)."""
+        return str(self._value).lower() if self._value is not None else "false"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Plain-dict form for API responses."""
+        return {
+            "name": self.name,
+            "type": self.var_type.value,
+            "value": self._value,
+            "scope": self.scope.value
+        }
+
+    def serialize(self) -> Dict[str, Any]:
+        """Persistable form (metadata + raw value + concrete class name)."""
+        return {
+            "metadata": self.metadata.model_dump(),
+            "value": self._value,
+            "class": self.__class__.__name__
+        }
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "BooleanVariable":
+        """Rebuild an instance from ``serialize()`` output."""
+        metadata = VariableMetadata(**data["metadata"])
+        return cls(metadata, data["value"])
+
+
+class ObjectVariable(BaseVariable):
+    """Object (dict) variable."""
+
+    def _validate_type(self, value: Any) -> bool:
+        """Accept dict values only."""
+        return isinstance(value, dict)
+
+    def to_string(self) -> str:
+        """Return pretty-printed JSON; falls back to str() for non-JSON values."""
+        if self._value is None:
+            return "{}"
+        try:
+            return json.dumps(self._value, ensure_ascii=False, indent=2)
+        except (TypeError, ValueError):
+            return str(self._value)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Plain-dict form for API responses."""
+        return {
+            "name": self.name,
+            "type": self.var_type.value,
+            "value": self._value,
+            "scope": self.scope.value
+        }
+
+    def serialize(self) -> Dict[str, Any]:
+        """Persistable form (metadata + raw value + concrete class name)."""
+        return {
+            "metadata": self.metadata.model_dump(),
+            "value": self._value,
+            "class": self.__class__.__name__
+        }
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "ObjectVariable":
+        """Rebuild an instance from ``serialize()`` output."""
+        metadata = VariableMetadata(**data["metadata"])
+        return cls(metadata, data["value"])
+
+
+class SecretVariable(BaseVariable):
+ """密钥变量 - 提供安全存储和访问机制"""
+
+ def __init__(self, metadata: VariableMetadata, value: Any = None, encryption_key: Optional[str] = None):
+ """初始化密钥变量
+
+ Args:
+ metadata: 变量元数据
+ value: 变量值
+ encryption_key: 加密密钥(可选,如果不提供会自动生成)
+ """
+ # 先设置加密密钥,因为父类初始化时可能会调用value setter
+ self._encryption_key = encryption_key or self._generate_encryption_key()
+ super().__init__(metadata, value)
+ self.metadata.is_encrypted = True
+
+ # 如果提供了值,确保它已被加密(在value setter中已处理)
+ # 这里不需要再次加密,因为super().__init__已经通过setter处理了
+
+ def _generate_encryption_key(self) -> str:
+ """生成加密密钥"""
+ # 使用用户ID和变量名生成唯一的加密密钥
+ user_sub = self.metadata.user_sub or "default"
+ salt = f"{user_sub}:{self.metadata.name}".encode()
+
+ kdf = PBKDF2HMAC(
+ algorithm=hashes.SHA256(),
+ length=32,
+ salt=salt,
+ iterations=100000,
+ )
+ key = base64.urlsafe_b64encode(kdf.derive(b"secret_variable_key"))
+ return key.decode()
+
+ def _encrypt_value(self, value: str) -> str:
+ """加密值"""
+ if not isinstance(value, str):
+ value = str(value)
+
+ f = Fernet(self._encryption_key.encode())
+ encrypted_value = f.encrypt(value.encode())
+ return base64.urlsafe_b64encode(encrypted_value).decode()
+
+ def _decrypt_value(self, encrypted_value: str) -> str:
+ """解密值"""
+ try:
+ encrypted_bytes = base64.urlsafe_b64decode(encrypted_value.encode())
+ f = Fernet(self._encryption_key.encode())
+ decrypted_value = f.decrypt(encrypted_bytes)
+ return decrypted_value.decode()
+ except Exception:
+ return "[解密失败]"
+
+ def _validate_type(self, value: Any) -> bool:
+ """验证值类型"""
+ return isinstance(value, str)
+
+ @property
+ def value(self) -> str:
+ """获取解密后的值"""
+ if self._value is None:
+ return ""
+ return self._decrypt_value(self._value)
+
+ @value.setter
+ def value(self, new_value: Any) -> None:
+ """设置新值(会自动加密)"""
+ if self.scope.value == "system":
+ raise ValueError("系统级变量不能修改")
+
+ if not self._validate_type(new_value):
+ raise TypeError(f"变量 {self.name} 的值类型不匹配,期望: {self.var_type}")
+
+ self._value = self._encrypt_value(new_value)
+ from datetime import datetime, UTC
+ self.metadata.updated_at = datetime.now(UTC)
+
+ def get_masked_value(self) -> str:
+ """获取掩码值用于显示"""
+ actual_value = self.value
+ if len(actual_value) <= 4:
+ return "*" * len(actual_value)
+ return actual_value[:2] + "*" * (len(actual_value) - 4) + actual_value[-2:]
+
+ def to_string(self) -> str:
+ """转换为字符串(掩码形式)"""
+ return self.get_masked_value()
+
+ def to_dict(self) -> Dict[str, Any]:
+ """转换为字典(掩码形式)"""
+ return {
+ "name": self.name,
+ "type": self.var_type.value,
+ "value": self.get_masked_value(),
+ "scope": self.scope.value
+ }
+
+ def to_dict_with_actual_value(self, user_sub: str) -> Dict[str, Any]:
+ """转换为包含实际值的字典(需要权限检查)"""
+ if not self.can_access(user_sub):
+ raise PermissionError(f"用户 {user_sub} 没有权限访问密钥变量 {self.name}")
+
+ return {
+ "name": self.name,
+ "type": self.var_type.value,
+ "value": self.value, # 实际解密值
+ "scope": self.scope.value
+ }
+
+ def serialize(self) -> Dict[str, Any]:
+ """序列化(保持加密状态)"""
+ return {
+ "metadata": self.metadata.model_dump(),
+ "value": self._value, # 加密后的值
+ "encryption_key": self._encryption_key,
+ "class": self.__class__.__name__
+ }
+
+ @classmethod
+ def deserialize(cls, data: Dict[str, Any]) -> "SecretVariable":
+ """反序列化"""
+ metadata = VariableMetadata(**data["metadata"])
+ instance = cls(metadata, None, data.get("encryption_key"))
+ instance._value = data["value"] # 直接设置加密值
+ return instance
+
+
+class FileVariable(BaseVariable):
+    """File variable - either a path string or a {filename, content} dict."""
+
+    def _validate_type(self, value: Any) -> bool:
+        """Accept a path string or a dict with "filename" and "content" keys.
+
+        NOTE(review): the leading isinstance check is redundant with the
+        clauses that follow; kept as-is to preserve the exact expression.
+        """
+        return isinstance(value, (str, dict)) and (
+            isinstance(value, str) or
+            (isinstance(value, dict) and "filename" in value and "content" in value)
+        )
+
+    def to_string(self) -> str:
+        """Return the path, the filename, or "" when unset."""
+        if isinstance(self._value, str):
+            return self._value
+        elif isinstance(self._value, dict):
+            return self._value.get("filename", "unnamed_file")
+        return ""
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Plain-dict form for API responses."""
+        return {
+            "name": self.name,
+            "type": self.var_type.value,
+            "value": self._value,
+            "scope": self.scope.value
+        }
+
+    def serialize(self) -> Dict[str, Any]:
+        """Persistable form (metadata + raw value + concrete class name)."""
+        return {
+            "metadata": self.metadata.model_dump(),
+            "value": self._value,
+            "class": self.__class__.__name__
+        }
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "FileVariable":
+        """Rebuild an instance from ``serialize()`` output."""
+        metadata = VariableMetadata(**data["metadata"])
+        return cls(metadata, data["value"])
+
+
+class ArrayVariable(BaseVariable):
+    """Array variable with optional per-element type checking."""
+
+    def __init__(self, metadata: VariableMetadata, value: Any = None):
+        """Initialize; a missing value defaults to an empty list."""
+        # The element type must exist before super().__init__, which calls
+        # _validate_type during construction.
+        self._element_type = metadata.var_type.get_array_element_type()
+        super().__init__(metadata, value or [])
+
+    def _validate_type(self, value: Any) -> bool:
+        """Require a list whose every element matches the element type."""
+        if not isinstance(value, list):
+            return False
+
+        # array[any]: no element-type constraint.
+        if self._element_type is None:
+            return True
+
+        # Check each element against the declared element type.
+        for item in value:
+            if not self._validate_element_type(item):
+                return False
+
+        return True
+
+    def _validate_element_type(self, element: Any) -> bool:
+        """Validate a single element against the declared element type."""
+        if self._element_type is None:  # array[any]
+            return True
+
+        type_validators = {
+            VariableType.STRING: lambda x: isinstance(x, str),
+            VariableType.NUMBER: lambda x: isinstance(x, (int, float)),
+            VariableType.BOOLEAN: lambda x: isinstance(x, bool),
+            VariableType.OBJECT: lambda x: isinstance(x, dict),
+            VariableType.SECRET: lambda x: isinstance(x, str),
+            VariableType.FILE: lambda x: isinstance(x, (str, dict)),
+        }
+
+        validator = type_validators.get(self._element_type)
+        return validator(element) if validator else False
+
+    def append(self, item: Any) -> None:
+        """Append a type-checked element and bump the updated_at timestamp."""
+        if not self._validate_element_type(item):
+            raise TypeError(f"元素类型不匹配,期望: {self._element_type}")
+
+        if self._value is None:
+            self._value = []
+        self._value.append(item)
+        from datetime import datetime, UTC
+        self.metadata.updated_at = datetime.now(UTC)
+
+    def remove(self, item: Any) -> None:
+        """Remove the first occurrence of *item*; silently no-ops when absent."""
+        if self._value and item in self._value:
+            self._value.remove(item)
+            from datetime import datetime, UTC
+            self.metadata.updated_at = datetime.now(UTC)
+
+    def __len__(self) -> int:
+        """Length of the underlying list (0 when unset)."""
+        return len(self._value) if self._value else 0
+
+    def __getitem__(self, index: int) -> Any:
+        """Indexed read; raises IndexError on an unset array."""
+        if self._value is None:
+            raise IndexError("数组为空")
+        return self._value[index]
+
+    def __setitem__(self, index: int, value: Any) -> None:
+        """Indexed write; grows the list as needed.
+
+        NOTE(review): growth pads with None, which would not itself pass
+        _validate_element_type for typed arrays — confirm this is intended.
+        """
+        if not self._validate_element_type(value):
+            raise TypeError(f"元素类型不匹配,期望: {self._element_type}")
+
+        if self._value is None:
+            self._value = []
+
+        # Grow the list until the target index exists.
+        while len(self._value) <= index:
+            self._value.append(None)
+
+        self._value[index] = value
+        from datetime import datetime, UTC
+        self.metadata.updated_at = datetime.now(UTC)
+
+    def to_string(self) -> str:
+        """Return pretty-printed JSON; falls back to str() for non-JSON values."""
+        if self._value is None:
+            return "[]"
+        try:
+            return json.dumps(self._value, ensure_ascii=False, indent=2)
+        except (TypeError, ValueError):
+            return str(self._value)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Plain-dict form for API responses (includes the element type)."""
+        return {
+            "name": self.name,
+            "type": self.var_type.value,
+            "value": self._value,
+            "scope": self.scope.value,
+            "element_type": self._element_type.value if self._element_type else None
+        }
+
+    def serialize(self) -> Dict[str, Any]:
+        """Persistable form (metadata + raw value + concrete class name)."""
+        return {
+            "metadata": self.metadata.model_dump(),
+            "value": self._value,
+            "class": self.__class__.__name__
+        }
+
+    @classmethod
+    def deserialize(cls, data: Dict[str, Any]) -> "ArrayVariable":
+        """Rebuild an instance from ``serialize()`` output."""
+        metadata = VariableMetadata(**data["metadata"])
+        return cls(metadata, data["value"])
+
+
+# Maps each VariableType to the concrete class that implements it; every
+# array flavor shares ArrayVariable (which reads the element type off the
+# metadata).  Consumed by create_variable below.
+VARIABLE_CLASS_MAP = {
+    VariableType.STRING: StringVariable,
+    VariableType.NUMBER: NumberVariable,
+    VariableType.BOOLEAN: BooleanVariable,
+    VariableType.OBJECT: ObjectVariable,
+    VariableType.SECRET: SecretVariable,
+    VariableType.FILE: FileVariable,
+    VariableType.ARRAY_ANY: ArrayVariable,
+    VariableType.ARRAY_STRING: ArrayVariable,
+    VariableType.ARRAY_NUMBER: ArrayVariable,
+    VariableType.ARRAY_OBJECT: ArrayVariable,
+    VariableType.ARRAY_FILE: ArrayVariable,
+    VariableType.ARRAY_BOOLEAN: ArrayVariable,
+    VariableType.ARRAY_SECRET: ArrayVariable,
+}
+
+
+def create_variable(metadata: VariableMetadata, value: Any = None) -> BaseVariable:
+    """Factory: instantiate the variable class matching ``metadata.var_type``.
+
+    Args:
+        metadata: Variable metadata (its var_type selects the class).
+        value: Initial value.
+
+    Returns:
+        BaseVariable: The created variable instance.
+
+    Raises:
+        ValueError: if the type has no registered class (e.g. GROUP).
+    """
+    variable_class = VARIABLE_CLASS_MAP.get(metadata.var_type)
+    if not variable_class:
+        raise ValueError(f"不支持的变量类型: {metadata.var_type}")
+
+    return variable_class(metadata, value)
\ No newline at end of file
diff --git a/apps/schemas/config.py b/apps/schemas/config.py
index b88a81f1afb018663713a3bf4c0b9b62436e59d4..1c2648d28f1953c38411f04f311138b9ecdc2adb 100644
--- a/apps/schemas/config.py
+++ b/apps/schemas/config.py
@@ -6,6 +6,13 @@ from typing import Literal
from pydantic import BaseModel, Field
+class NoauthConfig(BaseModel):
+    """No-auth (debug) access configuration."""
+
+    enable: bool = Field(description="是否启用无认证访问", default=False)
+    user_sub: str = Field(description="调试用户的sub", default="admin")
+
+
class DeployConfig(BaseModel):
"""部署配置"""
@@ -77,6 +84,20 @@ class MongoDBConfig(BaseModel):
database: str = Field(description="MongoDB数据库名")
+class RedisConfig(BaseModel):
+    """Redis connection configuration."""
+
+    host: str = Field(description="Redis主机名", default="redis-db")
+    port: int = Field(description="Redis端口号", default=6379)
+    password: str | None = Field(description="Redis密码", default=None)
+    database: int = Field(description="Redis数据库编号", default=0)
+    decode_responses: bool = Field(description="是否解码响应", default=True)
+    socket_timeout: float = Field(description="套接字超时时间(秒)", default=5.0)
+    socket_connect_timeout: float = Field(description="连接超时时间(秒)", default=5.0)
+    max_connections: int = Field(description="最大连接数", default=10)
+    health_check_interval: int = Field(description="健康检查间隔(秒)", default=30)
+
+
class LLMConfig(BaseModel):
"""LLM配置"""
@@ -85,6 +106,7 @@ class LLMConfig(BaseModel):
model: str = Field(description="LLM API 模型名")
max_tokens: int | None = Field(description="LLM API 最大Token数", default=None)
temperature: float | None = Field(description="LLM API 温度", default=None)
+ frequency_penalty: float | None = Field(description="频率惩罚是大模型生成文本时用于减少重复词汇出现的参数", default=0)
class FunctionCallConfig(BaseModel):
@@ -114,6 +136,12 @@ class CheckConfig(BaseModel):
words_list: str = Field(description="敏感词列表文件路径")
class SandboxConfig(BaseModel):
    """Code-sandbox service settings."""

    # Address of the code-sandbox service; no default, so it must be configured.
    sandbox_service: str = Field(description="代码沙箱服务地址")
+
+
class ExtraConfig(BaseModel):
"""额外配置"""
@@ -122,7 +150,7 @@ class ExtraConfig(BaseModel):
class ConfigModel(BaseModel):
"""配置文件的校验Class"""
-
+ no_auth: NoauthConfig = Field(description="无认证配置", default=NoauthConfig())
deploy: DeployConfig
login: LoginConfig
embedding: EmbeddingConfig
@@ -130,8 +158,10 @@ class ConfigModel(BaseModel):
fastapi: FastAPIConfig
minio: MinioConfig
mongodb: MongoDBConfig
+ redis: RedisConfig
llm: LLMConfig
function_call: FunctionCallConfig
security: SecurityConfig
check: CheckConfig
+ sandbox: SandboxConfig
extra: ExtraConfig
diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py
index 9a20ba84d4805bcd502d9d4ca0f9ba3c49a7bb4c..40225a663143e07ec5a13d0e64abff5021557665 100644
--- a/apps/schemas/enum_var.py
+++ b/apps/schemas/enum_var.py
@@ -15,6 +15,7 @@ class SlotType(str, Enum):
class StepStatus(str, Enum):
"""步骤状态"""
+ WAITING = "waiting"
RUNNING = "running"
SUCCESS = "success"
ERROR = "error"
@@ -38,6 +39,8 @@ class EventType(str, Enum):
TEXT_ADD = "text.add"
GRAPH = "graph"
DOCUMENT_ADD = "document.add"
+ STEP_WAITING_FOR_START = "step.waiting_for_start"
+ STEP_WAITING_FOR_PARAM = "step.waiting_for_param"
FLOW_START = "flow.start"
STEP_INPUT = "step.input"
STEP_OUTPUT = "step.output"
@@ -48,8 +51,10 @@ class EventType(str, Enum):
class CallType(str, Enum):
"""Call类型"""
- SYSTEM = "system"
- PYTHON = "python"
+ DEFAULT = "default"
+ LOGIC = "logic"
+ TRANSFORM = "transform"
+ TOOL = "tool"
class MetadataType(str, Enum):
@@ -118,7 +123,6 @@ class HTTPMethod(str, Enum):
DELETE = "delete"
PATCH = "patch"
-
class ContentType(str, Enum):
"""Content-Type"""
@@ -142,6 +146,7 @@ class SpecialCallType(str, Enum):
FACTS = "Facts"
SLOT = "Slot"
LLM = "LLM"
+ DIRECT_REPLY = "DirectReply"
START = "start"
END = "end"
CHOICE = "choice"
diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py
index 2646d04390099fd92988d78c00d1b1471780d6a9..b70e481477163d00bd3e9d7b371670d27b674233 100644
--- a/apps/schemas/flow.py
+++ b/apps/schemas/flow.py
@@ -156,3 +156,4 @@ class FlowConfig(BaseModel):
flow_id: str
flow_config: Flow
+
diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py
index d0ab666a9902ec7e0d81bb922f973d655f29eaf2..aa5da0d585d6d0e49d5b1d614a6fc8e1d7024d06 100644
--- a/apps/schemas/flow_topology.py
+++ b/apps/schemas/flow_topology.py
@@ -5,7 +5,7 @@ from typing import Any
from pydantic import BaseModel, Field
-from apps.schemas.enum_var import EdgeType
+from apps.schemas.enum_var import CallType, EdgeType
class NodeMetaDataItem(BaseModel):
@@ -14,6 +14,7 @@ class NodeMetaDataItem(BaseModel):
node_id: str = Field(alias="nodeId")
call_id: str = Field(alias="callId")
name: str
+ type: CallType
description: str
parameters: dict[str, Any] | None
editable: bool = Field(default=True)
diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py
index 44021b0ed5f5f6c372ac0a472bb3ac28ff5bbddb..60c8f17b4adc4f53f21b0cacc02a86e09b495d06 100644
--- a/apps/schemas/mcp.py
+++ b/apps/schemas/mcp.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP 相关数据结构"""
+import uuid
from enum import Enum
from typing import Any
@@ -117,7 +118,7 @@ class MCPToolSelectResult(BaseModel):
class MCPPlanItem(BaseModel):
"""MCP 计划"""
-
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()))
content: str = Field(description="计划内容")
tool: str = Field(description="工具名称")
instruction: str = Field(description="工具指令")
diff --git a/apps/schemas/message.py b/apps/schemas/message.py
index d0661224e0817ff6f25e341d65c88903020d18e8..cf70a82b8573d2e3e34564b8f0ec5280195980d9 100644
--- a/apps/schemas/message.py
+++ b/apps/schemas/message.py
@@ -2,7 +2,7 @@
"""队列中的消息结构"""
from typing import Any
-
+from datetime import UTC, datetime
from pydantic import BaseModel, Field
from apps.schemas.enum_var import EventType, StepStatus
@@ -24,6 +24,8 @@ class MessageFlow(BaseModel):
flow_id: str = Field(description="Flow ID", alias="flowId")
step_id: str = Field(description="当前步骤ID", alias="stepId")
step_name: str = Field(description="当前步骤名称", alias="stepName")
+ sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None)
+ sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None)
step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus")
@@ -60,10 +62,14 @@ class DocumentAddContent(BaseModel):
document_id: str = Field(description="文档UUID", alias="documentId")
document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder")
+ document_author: str = Field(description="文档作者", alias="documentAuthor", default="")
document_name: str = Field(description="文档名称", alias="documentName")
document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="")
document_type: str = Field(description="文档MIME类型", alias="documentType", default="")
document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0)
+ created_at: float = Field(
+ description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)
+ )
class FlowStartContent(BaseModel):
diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd908d2375415c798f4804c4eabde797f6a4f7a0
--- /dev/null
+++ b/apps/schemas/parameters.py
@@ -0,0 +1,69 @@
+from enum import Enum
+
+
class NumberOperate(str, Enum):
    """Numeric comparison operators supported by the Choice tool.

    Declaration order is preserved by ``Enum`` iteration and is relied on by
    callers that list the available operators.
    """

    EQUAL = "number_equal"
    NOT_EQUAL = "number_not_equal"
    GREATER_THAN = "number_greater_than"
    LESS_THAN = "number_less_than"
    GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal"
    LESS_THAN_OR_EQUAL = "number_less_than_or_equal"
+
+
class StringOperate(str, Enum):
    """String comparison operators supported by the Choice tool.

    Declaration order is preserved by ``Enum`` iteration and is relied on by
    callers that list the available operators.
    """

    EQUAL = "string_equal"
    NOT_EQUAL = "string_not_equal"
    CONTAINS = "string_contains"
    NOT_CONTAINS = "string_not_contains"
    STARTS_WITH = "string_starts_with"
    ENDS_WITH = "string_ends_with"
    LENGTH_EQUAL = "string_length_equal"
    LENGTH_GREATER_THAN = "string_length_greater_than"
    LENGTH_GREATER_THAN_OR_EQUAL = "string_length_greater_than_or_equal"
    LENGTH_LESS_THAN = "string_length_less_than"
    LENGTH_LESS_THAN_OR_EQUAL = "string_length_less_than_or_equal"
    REGEX_MATCH = "string_regex_match"
+
+
class ListOperate(str, Enum):
    """List comparison operators supported by the Choice tool.

    Declaration order is preserved by ``Enum`` iteration and is relied on by
    callers that list the available operators.
    """

    EQUAL = "list_equal"
    NOT_EQUAL = "list_not_equal"
    CONTAINS = "list_contains"
    NOT_CONTAINS = "list_not_contains"
    LENGTH_EQUAL = "list_length_equal"
    LENGTH_GREATER_THAN = "list_length_greater_than"
    LENGTH_GREATER_THAN_OR_EQUAL = "list_length_greater_than_or_equal"
    LENGTH_LESS_THAN = "list_length_less_than"
    LENGTH_LESS_THAN_OR_EQUAL = "list_length_less_than_or_equal"
+
+
class BoolOperate(str, Enum):
    """Boolean comparison operators supported by the Choice tool."""

    EQUAL = "bool_equal"
    NOT_EQUAL = "bool_not_equal"
+
+
class DictOperate(str, Enum):
    """Dictionary comparison operators supported by the Choice tool."""

    EQUAL = "dict_equal"
    NOT_EQUAL = "dict_not_equal"
    CONTAINS_KEY = "dict_contains_key"
    NOT_CONTAINS_KEY = "dict_not_contains_key"
+
+
class Type(str, Enum):
    """Value types recognized by the Choice tool."""

    STRING = "string"
    NUMBER = "number"
    LIST = "list"
    DICT = "dict"
    BOOL = "bool"
diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py
index 27e16b370ec83acc11e1f435ac1da296fe2a9560..009c8206d967bd7cdb61ba1a41ebf6cf9e87dd68 100644
--- a/apps/schemas/pool.py
+++ b/apps/schemas/pool.py
@@ -45,9 +45,6 @@ class CallPool(BaseData):
Call信息
collection: call
-
- “path”的格式如下:
- 1. Python代码会被导入成包,路径格式为`python::::`,用于查找Call的包路径和类路径
"""
type: CallType = Field(description="Call的类型")
@@ -77,6 +74,7 @@ class NodePool(BaseData):
service_id: str | None = Field(description="Node所属的Service ID", default=None)
call_id: str = Field(description="所使用的Call的ID")
+ type: CallType = Field(description="所使用的Call的类型")
known_params: dict[str, Any] | None = Field(
description="已知的用于Call部分的参数,独立于输入和输出之外",
default=None,
diff --git a/apps/schemas/record.py b/apps/schemas/record.py
index d7acd368205d0aa5a2b368edf4af3a31ed4d1ff6..b5e1b0c55ad60d0569b57377a6a06a84e7a2920e 100644
--- a/apps/schemas/record.py
+++ b/apps/schemas/record.py
@@ -17,8 +17,10 @@ class RecordDocument(Document):
"""GET /api/record/{conversation_id} Result中的document数据结构"""
id: str = Field(alias="_id", default="")
+ order: int = Field(default=0, description="文档顺序")
abstract: str = Field(default="", description="文档摘要")
user_sub: None = None
+ author: str = Field(default="", description="文档作者")
associated: Literal["question", "answer"]
class Config:
@@ -103,11 +105,14 @@ class RecordGroupDocument(BaseModel):
"""RecordGroup关联的文件"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
+ order: int = Field(default=0, description="文档顺序")
+ author: str = Field(default="", description="文档作者")
name: str = Field(default="", description="文档名称")
abstract: str = Field(default="", description="文档摘要")
extension: str = Field(default="", description="文档扩展名")
size: int = Field(default=0, description="文档大小,单位是KB")
associated: Literal["question", "answer"]
+ created_at: float = Field(default=0.0, description="文档创建时间")
class Record(RecordData):
diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py
index 2305dd93dfe33969691dcc42928e578e7f1c4207..a3a8848c32e898364df671bbf3b26cb40c3f6a0a 100644
--- a/apps/schemas/request_data.py
+++ b/apps/schemas/request_data.py
@@ -16,8 +16,8 @@ class RequestDataApp(BaseModel):
"""模型对话中包含的app信息"""
app_id: str = Field(description="应用ID", alias="appId")
- flow_id: str = Field(description="Flow ID", alias="flowId")
- params: dict[str, Any] = Field(description="插件参数")
+ flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId")
+ params: dict[str, Any] | None = Field(default=None, description="插件参数")
class MockRequestData(BaseModel):
@@ -46,6 +46,7 @@ class RequestData(BaseModel):
files: list[str] = Field(default=[], description="文件列表")
app: RequestDataApp | None = Field(default=None, description="应用")
debug: bool = Field(default=False, description="是否调试")
+ new_task: bool = Field(default=True, description="是否新建任务")
class QuestionBlacklistRequest(BaseModel):
diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py
index b2a1872918638b153e1e94ff4db85ef00cfc0502..7a8cc783015a8f69259194cec4d2bea359f7a7fd 100644
--- a/apps/schemas/response_data.py
+++ b/apps/schemas/response_data.py
@@ -14,6 +14,14 @@ from apps.schemas.flow_topology import (
NodeServiceItem,
PositionItem,
)
+from apps.schemas.parameters import (
+ Type,
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+)
from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType
from apps.schemas.record import RecordData
from apps.schemas.user import UserInfo
@@ -608,7 +616,8 @@ class LLMProviderInfo(BaseModel):
"""LLM数据结构"""
llm_id: str = Field(alias="llmId", description="LLM ID")
- icon: str = Field(default="", description="LLM图标", max_length=25536)
+ # icon: str = Field(default="", description="LLM图标", max_length=25536)
+ icon: str = Field(default="", description="LLM图标")
openai_base_url: str = Field(
default="https://api.openai.com/v1",
description="OpenAI API Base URL",
@@ -628,3 +637,42 @@ class ListLLMRsp(ResponseData):
"""GET /api/llm 返回数据结构"""
result: list[LLMProviderInfo] = Field(default=[], title="Result")
+
+
class ParamsNode(BaseModel):
    """One parameter in a step's parameter tree (recursive structure)."""
    param_name: str = Field(..., description="参数名称", alias="paramName")
    param_path: str = Field(..., description="参数路径", alias="paramPath")
    param_type: Type = Field(..., description="参数类型", alias="paramType")
    # NOTE(review): presumably populated only for container types (list/dict)
    # and None for leaves — confirm against the schema producer.
    sub_params: list["ParamsNode"] | None = Field(
        default=None, description="子参数列表", alias="subParams"
    )
+
+
class StepParams(BaseModel):
    """Parameters exposed by a single flow step."""
    step_id: str = Field(..., description="步骤ID", alias="stepId")
    name: str = Field(..., description="Step名称")
    # Root of the step's parameter tree; None when the step exposes none.
    params_node: ParamsNode | None = Field(
        default=None, description="参数节点", alias="paramsNode")
+
+
class GetParamsRsp(ResponseData):
    """Response body for GET /api/params."""

    result: list[StepParams] = Field(
        default=[], description="参数列表", alias="result"
    )
+
+
class OperateAndBindType(BaseModel):
    """Pairs an operator with the value type that operator binds to."""

    # Any of the per-type operator enums; which one depends on the queried type.
    operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型")
    bind_type: Type = Field(description="绑定类型")
+
+
class GetOperaRsp(ResponseData):
    """Response body for GET /api/operate."""

    result: list[OperateAndBindType] = Field(..., title="Result")
diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py
index 38fd94ad4db6cff189e79f4abaec3448c97d45ec..dd2dbb6047241c99e4b09296b0cf67023edf6628 100644
--- a/apps/schemas/scheduler.py
+++ b/apps/schemas/scheduler.py
@@ -1,18 +1,19 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""插件、工作流、步骤相关数据结构定义"""
+from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
-from apps.schemas.enum_var import CallOutputType
+from apps.schemas.enum_var import CallOutputType, CallType
from apps.schemas.task import FlowStepHistory
class CallInfo(BaseModel):
"""Call的名称和描述"""
-
name: str = Field(description="Call的名称")
+ type: CallType = Field(description="Call的类别")
description: str = Field(description="Call的描述")
@@ -22,6 +23,7 @@ class CallIds(BaseModel):
task_id: str = Field(description="任务ID")
flow_id: str = Field(description="Flow ID")
session_id: str = Field(description="当前用户的Session ID")
+ conversation_id: str = Field(description="当前对话ID")
app_id: str = Field(description="当前应用的ID")
user_sub: str = Field(description="当前用户的用户ID")
diff --git a/apps/schemas/task.py b/apps/schemas/task.py
index 8efcb59914d568478c65d611670b251d019b38ed..3a37126053e693226da63a07283832c9cf3c0fae 100644
--- a/apps/schemas/task.py
+++ b/apps/schemas/task.py
@@ -9,6 +9,7 @@ from pydantic import BaseModel, Field
from apps.schemas.enum_var import StepStatus
from apps.schemas.flow import Step
+from apps.schemas.mcp import MCPPlan
class FlowStepHistory(BaseModel):
@@ -42,6 +43,7 @@ class ExecutorState(BaseModel):
# 附加信息
step_id: str = Field(description="当前步骤ID")
step_name: str = Field(description="当前步骤名称")
+ step_description: str = Field(description="当前步骤描述", default="")
app_id: str = Field(description="应用ID")
slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
error_info: dict[str, Any] = Field(description="错误信息", default={})
@@ -75,6 +77,7 @@ class TaskRuntime(BaseModel):
summary: str = Field(description="摘要", default="")
filled: dict[str, Any] = Field(description="填充的槽位", default={})
documents: list[dict[str, Any]] = Field(description="文档列表", default=[])
+ temporary_plans: MCPPlan | None = Field(description="临时计划列表", default=None)
class Task(BaseModel):
@@ -86,7 +89,7 @@ class Task(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
ids: TaskIds = Field(description="任务涉及的各种ID")
- context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[])
+ context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[])
state: ExecutorState | None = Field(description="Flow的状态", default=None)
tokens: TaskTokens = Field(description="Token信息")
runtime: TaskRuntime = Field(description="任务运行时数据")
diff --git a/apps/services/conversation.py b/apps/services/conversation.py
index 4bcade45c757ea2b3dca6c58863b2852af98c15b..bac964db132fecc9599bb80943e6fa119048f387 100644
--- a/apps/services/conversation.py
+++ b/apps/services/conversation.py
@@ -59,7 +59,11 @@ class ConversationManager:
model_name=llm.model_name,
)
kb_item_list = []
- team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ try:
+ team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ except:
+ logger.error("[ConversationManager] 获取团队知识库列表失败")
+ team_kb_list = []
for team_kb in team_kb_list:
for kb in team_kb["kbList"]:
if str(kb["kbId"]) in kb_ids:
diff --git a/apps/services/document.py b/apps/services/document.py
index 203162da1137cdd69d900ceea65a3029d2f4fd8e..451423a9d8f752c2abbf97f5b2a1df38e6d7fefe 100644
--- a/apps/services/document.py
+++ b/apps/services/document.py
@@ -2,6 +2,7 @@
"""文件Manager"""
import base64
+from datetime import UTC, datetime
import logging
import uuid
@@ -131,12 +132,15 @@ class DocumentManager:
return [
RecordDocument(
_id=doc.id,
+ order=doc.order,
+ author=doc.author,
abstract=doc.abstract,
name=doc.name,
type=doc.extension,
size=doc.size,
conversation_id=record_group.get("conversation_id", ""),
associated=doc.associated,
+ created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3)
)
for doc in docs if type is None or doc.associated == type
]
diff --git a/apps/services/flow.py b/apps/services/flow.py
index 9275fc60c75199e8e17ebcba8f164083190b7211..f84f5b832c02fd4bd6a704b4a25dec80db0dfada 100644
--- a/apps/services/flow.py
+++ b/apps/services/flow.py
@@ -20,7 +20,6 @@ from apps.schemas.flow_topology import (
PositionItem,
)
from apps.services.node import NodeManager
-
logger = logging.getLogger(__name__)
@@ -96,6 +95,7 @@ class FlowManager:
nodeId=node_pool_record["_id"],
callId=node_pool_record["call_id"],
name=node_pool_record["name"],
+ type=node_pool_record["type"],
description=node_pool_record["description"],
editable=True,
createdAt=node_pool_record["created_at"],
@@ -154,7 +154,7 @@ class FlowManager:
NodeServiceItem(
serviceId=record["_id"],
name=record["name"],
- type="default",
+ type="default", # TODO record["type"]?
nodeMetaDatas=[],
createdAt=str(record["created_at"]),
)
@@ -257,15 +257,21 @@ class FlowManager:
debug=flow_config.debug,
)
for node_id, node_config in flow_config.steps.items():
- input_parameters = node_config.params
- if node_config.node not in ("Empty"):
- _, output_parameters = await NodeManager.get_node_params(node_config.node)
+ # 对于Code节点,直接使用保存的完整params作为parameters
+ if node_config.type == "Code":
+ parameters = node_config.params # 直接使用保存的完整params
else:
- output_parameters = {}
- parameters = {
- "input_parameters": input_parameters,
- "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(),
- }
+ # 其他节点:使用原有逻辑
+ input_parameters = node_config.params
+ if node_config.node not in ("Empty"):
+ _, output_parameters = await NodeManager.get_node_params(node_config.node)
+ else:
+ output_parameters = {}
+ parameters = {
+ "input_parameters": input_parameters,
+ "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(),
+ }
+
node_item = NodeItem(
stepId=node_id,
nodeId=node_config.node,
@@ -275,8 +281,7 @@ class FlowManager:
editable=True,
callId=node_config.type,
parameters=parameters,
- position=PositionItem(
- x=node_config.pos.x, y=node_config.pos.y),
+ position=PositionItem(x=node_config.pos.x, y=node_config.pos.y),
)
flow_item.nodes.append(node_item)
@@ -384,13 +389,19 @@ class FlowManager:
debug=flow_item.debug,
)
for node_item in flow_item.nodes:
+ # 对于Code节点,保存完整的parameters;其他节点只保存input_parameters
+ if node_item.call_id == "Code":
+ params = node_item.parameters # 保存完整的parameters(包含input_parameters、output_parameters以及code配置)
+ else:
+ params = node_item.parameters.get("input_parameters", {}) # 其他节点只保存input_parameters
+
flow_config.steps[node_item.step_id] = Step(
type=node_item.call_id,
node=node_item.node_id,
name=node_item.name,
description=node_item.description,
pos=node_item.position,
- params=node_item.parameters.get("input_parameters", {}),
+ params=params,
)
for edge_item in flow_item.edges:
edge_from = edge_item.source_node
diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py
index 9b4077f92cf44bca42d321cc20975d7ea2e5cadd..bd8dfc9e808f642a8eecf493a96f762dad8ae7b9 100644
--- a/apps/services/knowledge.py
+++ b/apps/services/knowledge.py
@@ -138,7 +138,11 @@ class KnowledgeBaseManager:
return []
kb_ids_update_success = []
kb_item_dict_list = []
- team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ try:
+ team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ except Exception as e:
+ logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}")
+ team_kb_list = []
for team_kb in team_kb_list:
for kb in team_kb["kbList"]:
if str(kb["kbId"]) in kb_ids:
diff --git a/apps/services/node.py b/apps/services/node.py
index 6f0d492edacdd7611324a38953417e0fa769249b..bf48e71feec4df41dddc3fecfa1228913940571c 100644
--- a/apps/services/node.py
+++ b/apps/services/node.py
@@ -16,6 +16,7 @@ NODE_TYPE_MAP = {
"API": APINode,
}
+
class NodeManager:
"""Node管理器"""
@@ -29,7 +30,6 @@ class NodeManager:
raise ValueError(err)
return node["call_id"]
-
@staticmethod
async def get_node(node_id: str) -> NodePool:
"""获取Node的类型"""
@@ -40,7 +40,6 @@ class NodeManager:
raise ValueError(err)
return NodePool.model_validate(node)
-
@staticmethod
async def get_node_name(node_id: str) -> str:
"""获取node的名称"""
@@ -52,7 +51,6 @@ class NodeManager:
return ""
return node_doc["name"]
-
@staticmethod
def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]:
"""递归合并参数Schema,将known_params中的值填充到params_schema的对应位置"""
@@ -75,7 +73,6 @@ class NodeManager:
return params_schema
-
@staticmethod
async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]:
"""获取Node数据"""
@@ -100,7 +97,6 @@ class NodeManager:
err = f"[NodeManager] Call {call_id} 不存在"
logger.error(err)
raise ValueError(err)
-
# 返回参数Schema
return (
NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}),
diff --git a/apps/services/parameter.py b/apps/services/parameter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae375e97fcf92bebfc8dca2b384ffcead4da58bf
--- /dev/null
+++ b/apps/services/parameter.py
@@ -0,0 +1,86 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""flow Manager"""
+
+import logging
+
+from pymongo import ASCENDING
+
+from apps.services.node import NodeManager
+from apps.schemas.flow_topology import FlowItem
+from apps.scheduler.slot.slot import Slot
+from apps.scheduler.call.choice.condition_handler import ConditionHandler
+from apps.scheduler.call.choice.schema import (
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+ Type
+)
+from apps.schemas.response_data import (
+ OperateAndBindType,
+ ParamsNode,
+ StepParams,
+)
+from apps.services.node import NodeManager
+logger = logging.getLogger(__name__)
+
+
class ParameterManager:
    """Query helpers for Choice operators and predecessor step parameters."""

    @staticmethod
    async def get_operate_and_bind_type(param_type: Type) -> list[OperateAndBindType]:
        """Return the operators available for *param_type* with their bound value types.

        Args:
            param_type: The parameter type being compared.

        Returns:
            One entry per operator, in the operator enum's declaration order;
            empty list for an unrecognized type.
        """
        # Dispatch table instead of an if/elif chain; same mapping as before.
        operate_enum_by_type = {
            Type.NUMBER: NumberOperate,
            Type.STRING: StringOperate,
            Type.LIST: ListOperate,
            Type.BOOL: BoolOperate,
            Type.DICT: DictOperate,
        }
        operate_enum = operate_enum_by_type.get(param_type)
        if operate_enum is None:
            return []
        return [
            OperateAndBindType(
                operate=item,
                bind_type=ConditionHandler.get_value_type_from_operate(item),
            )
            for item in operate_enum
        ]

    @staticmethod
    async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]:
        """Collect the parameters of *step_id* and all its transitive predecessors.

        Performs a BFS over the reversed edges of *flow* starting at *step_id*
        (the starting step itself is included, matching previous behavior).

        Args:
            flow: The flow topology to walk.
            step_id: The step whose predecessors are wanted.

        Returns:
            One StepParams per reachable step, in BFS discovery order.
        """
        node_id_by_step = {step.step_id: step.node_id for step in flow.nodes}

        # Reverse adjacency: target step -> list of source steps.
        in_edges: dict[str, list[str]] = {}
        for edge in flow.edges:
            in_edges.setdefault(edge.target_node, []).append(edge.source_node)

        # BFS over incoming edges; `reachable` doubles as queue and visited set.
        reachable = [step_id]
        index = 0
        while index < len(reachable):
            current = reachable[index]
            index += 1
            for pre_step in in_edges.get(current, []):
                if pre_step not in reachable:
                    reachable.append(pre_step)

        pre_step_params: list[StepParams] = []
        # Fix: loop variable no longer shadows the `step_id` parameter.
        for reachable_step_id in reachable:
            node_id = node_id_by_step.get(reachable_step_id)
            if node_id is None:
                # Fix: an edge may reference a step missing from flow.nodes;
                # skip it instead of passing None into get_node_params.
                continue
            params_schema, output_schema = await NodeManager.get_node_params(node_id)
            params_node = Slot(output_schema).get_params_node_from_schema(root='/output')
            pre_step_params.append(
                StepParams(
                    # Fix: `stepId` previously received the node ID, but the
                    # field contract ("步骤ID") calls for the step ID.
                    stepId=reachable_step_id,
                    # NOTE(review): JSON schemas usually carry "title", not
                    # "name" — confirm this key against get_node_params output.
                    name=params_schema.get("name", ""),
                    paramsNode=params_node,
                )
            )
        return pre_step_params
diff --git a/apps/services/predecessor_cache_service.py b/apps/services/predecessor_cache_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..967077ea96b06982c543cf9ad6306b300cea48d5
--- /dev/null
+++ b/apps/services/predecessor_cache_service.py
@@ -0,0 +1,488 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""前置节点变量预解析缓存服务"""
+
+import asyncio
+import hashlib
+import json
+import logging
+from typing import List, Dict, Any, Optional
+from datetime import datetime, UTC
+
+from apps.common.redis_cache import RedisCache
+from apps.common.process_handler import ProcessHandler
+from apps.services.flow import FlowManager
+from apps.scheduler.variable.variables import create_variable
+from apps.scheduler.variable.base import VariableMetadata
+from apps.scheduler.variable.type import VariableType, VariableScope
+
+logger = logging.getLogger(__name__)
+
# Module-wide Redis cache instance shared by this service.
redis_cache = RedisCache()
# Lazily created PredecessorVariableCache; see _get_predecessor_cache().
predecessor_cache = None

# Registry of in-flight background parse tasks keyed by task id, guarded by
# _task_lock. NOTE(review): asyncio.Lock() is created at import time — assumes
# Python >= 3.10 loop-binding semantics; confirm the target runtime.
_background_tasks: Dict[str, asyncio.Task] = {}
_task_lock = asyncio.Lock()
+
def _get_predecessor_cache():
    """Return the module-level PredecessorVariableCache, creating it on first use.

    The import is deferred into the function body — presumably to avoid an
    import cycle with apps.common.redis_cache; confirm before hoisting it.
    """
    global predecessor_cache
    if predecessor_cache is None:
        from apps.common.redis_cache import PredecessorVariableCache
        predecessor_cache = PredecessorVariableCache(redis_cache)
    return predecessor_cache
+
async def init_predecessor_cache():
    """Eagerly create the predecessor-variable cache singleton.

    Delegates to _get_predecessor_cache() so the lazy-creation logic lives in
    exactly one place; previously this function duplicated it verbatim.
    """
    _get_predecessor_cache()
+
async def cleanup_background_tasks():
    """Cancel and drain all registered background tasks (shutdown hook).

    Holds _task_lock for the whole operation; waits up to 5 s for cancelled
    tasks to settle, then force-clears the registry regardless.
    """
    async with _task_lock:
        if not _background_tasks:
            logger.info("没有后台任务需要清理")
            return

        logger.info(f"开始清理 {len(_background_tasks)} 个后台任务")

        # Cancel every task that has not finished yet.
        cancelled_count = 0
        for task_id, task in list(_background_tasks.items()):
            if not task.done():
                task.cancel()
                cancelled_count += 1
                logger.debug(f"取消后台任务: {task_id}")

        # Wait for cancellations to complete, bounded so shutdown cannot hang.
        if cancelled_count > 0:
            logger.info(f"等待 {cancelled_count} 个任务取消完成...")
            timeout = 5.0  # seconds
            try:
                await asyncio.wait_for(
                    asyncio.gather(*[task for task in _background_tasks.values()], return_exceptions=True),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                logger.warning(f"等待任务取消超时 ({timeout}s),强制清理")
            except Exception as e:
                logger.error(f"等待任务取消时出错: {e}")

        # Drop every entry, finished or not.
        completed_count = len(_background_tasks)
        _background_tasks.clear()
        logger.info(f"后台任务清理完成,共清理 {completed_count} 个任务")
+
async def periodic_cleanup_background_tasks():
    """Remove finished tasks from the registry, logging any stored exception.

    Intended to be invoked periodically; all errors are logged, never raised.
    """
    try:
        async with _task_lock:
            if not _background_tasks:
                return

            completed_tasks = []
            for task_id, task in list(_background_tasks.items()):
                if task.done():
                    completed_tasks.append(task_id)
                    try:
                        # Awaiting a finished task re-raises its stored
                        # exception (if any) so it gets logged, not dropped.
                        await task
                        logger.debug(f"后台任务已完成: {task_id}")
                    except Exception as e:
                        logger.error(f"后台任务执行异常: {task_id}, 错误: {e}")

            # Remove the completed entries from the registry.
            for task_id in completed_tasks:
                _background_tasks.pop(task_id, None)

            if completed_tasks:
                logger.info(f"定期清理了 {len(completed_tasks)} 个已完成的后台任务")

    except Exception as e:
        logger.error(f"定期清理后台任务失败: {e}")
+
+
+class PredecessorCacheService:
+ """前置节点变量预解析缓存服务"""
+
    @staticmethod
    async def initialize_redis():
        """Connect the module-level redis_cache using the app configuration.

        First tries init with the config object, then falls back to a plain
        redis:// URL. Deliberately never raises: on total failure the service
        degrades to real-time parsing without a cache.
        """
        try:
            # Deferred import: configuration may not be loadable at module
            # import time.
            from apps.common.config import Config

            config = Config().get_config()
            redis_config = config.redis

            logger.info(f"准备连接Redis: {redis_config.host}:{redis_config.port}")
            await redis_cache.init(redis_config=redis_config)

            # Verify the connection actually works before declaring success.
            if redis_cache.is_connected():
                logger.info("前置节点缓存服务Redis初始化成功")
                return
            else:
                raise Exception("Redis连接验证失败")

        except Exception as e:
            logger.error(f"使用配置文件连接Redis失败: {e}")

            # Fallback: build a plain redis:// URL and retry once.
            try:
                logger.info("尝试降级连接方案...")
                from apps.common.config import Config
                config = Config().get_config()
                redis_config = config.redis

                # assumes the password needs no URL-escaping — TODO confirm
                password_part = f":{redis_config.password}@" if redis_config.password else ""
                redis_url = f"redis://{password_part}{redis_config.host}:{redis_config.port}/{redis_config.database}"

                await redis_cache.init(redis_url=redis_url)

                if redis_cache.is_connected():
                    logger.info("降级连接方案成功")
                    return
                else:
                    raise Exception("降级连接方案也失败")

            except Exception as fallback_error:
                logger.error(f"降级连接方案也失败: {fallback_error}")

            # Swallow the failure on purpose: the cache is optional and the
            # caller continues in real-time parsing mode.
            logger.info("将使用实时解析模式作为降级方案")
+
    @staticmethod
    def calculate_flow_hash(flow_item) -> str:
        """Return an MD5 fingerprint of the flow's topology.

        Only step ids, call ids, node parameters and edge endpoints are
        hashed, so cosmetic changes (e.g. node positions) do not invalidate
        the cache. MD5 is used as a cheap fingerprint, not for security.
        """
        try:
            # Reduce the flow to its structure-defining fields.
            topology_data = {
                'nodes': [
                    {
                        'step_id': node.step_id,
                        'call_id': getattr(node, 'call_id', ''),
                        'parameters': getattr(node, 'parameters', {})
                    }
                    for node in flow_item.nodes
                ],
                'edges': [
                    {
                        'source_node': edge.source_node,
                        'target_node': edge.target_node
                    }
                    for edge in flow_item.edges
                ]
            }

            # sort_keys makes the serialization (and thus the hash) deterministic.
            topology_json = json.dumps(topology_data, sort_keys=True)
            return hashlib.md5(topology_json.encode()).hexdigest()
        except Exception as e:
            logger.error(f"计算Flow哈希失败: {e}")
            # Fallback: a timestamp — effectively "always changed", forcing a re-parse.
            return str(datetime.now(UTC).timestamp())
+
    @staticmethod
    async def trigger_flow_parsing(flow_id: str, force_refresh: bool = False):
        """Kick off background pre-parsing of predecessor variables for a flow.

        Skips work when the flow topology hash is unchanged (unless
        force_refresh is True). One asyncio task per node is registered in
        _background_tasks; all errors are logged, never raised.
        """
        try:
            flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id)
            if not flow_item:
                logger.warning(f"Flow不存在,跳过解析: {flow_id}")
                return

            current_hash = PredecessorCacheService.calculate_flow_hash(flow_item)

            # Unchanged topology means the cache is still valid — nothing to do.
            if not force_refresh:
                cached_hash = await _get_predecessor_cache().get_flow_hash(flow_id)
                if cached_hash == current_hash:
                    logger.info(f"Flow拓扑未变化,跳过解析: {flow_id}")
                    return

            # Record the new hash, then invalidate stale cached entries.
            await _get_predecessor_cache().set_flow_hash(flow_id, current_hash)

            await _get_predecessor_cache().invalidate_flow_cache(flow_id)

            # Spawn one parse task per node, skipping ids already in flight.
            tasks = []
            for node in flow_item.nodes:
                step_id = node.step_id
                task_id = f"parse_predecessor_{flow_id}_{step_id}"

                # Skip ids that already have a live task.
                async with _task_lock:
                    if task_id in _background_tasks and not _background_tasks[task_id].done():
                        continue

                # NOTE(review): task creation and registration happen outside
                # the lock above, so two concurrent callers could race on the
                # same task_id — confirm this is acceptable.
                task = asyncio.create_task(
                    PredecessorCacheService._parse_single_node_predecessor(
                        flow_id, step_id, current_hash
                    )
                )
                _background_tasks[task_id] = task
                tasks.append((task_id, task))

            if tasks:
                logger.info(f"启动Flow前置节点解析任务: {flow_id}, 节点数量: {len(tasks)}")
                # Tasks are not awaited here; cleanup_background_tasks /
                # periodic_cleanup_background_tasks reap them later.
                for task_id, task in tasks:
                    logger.debug(f"启动后台任务: {task_id}")

        except Exception as e:
            logger.error(f"触发Flow解析失败: {flow_id}, 错误: {e}")
+
    @staticmethod
    async def _cleanup_task(task_id: str):
        """Pop a finished task from the registry and surface its exception.

        NOTE(review): not referenced anywhere in this module's visible code —
        possibly intended as a done-callback; confirm before removing.
        """
        try:
            async with _task_lock:
                task = _background_tasks.pop(task_id, None)
                if task and task.done():
                    # Awaiting a done task re-raises any stored exception.
                    try:
                        result = await task
                        logger.debug(f"后台任务完成: {task_id}")
                    except Exception as e:
                        logger.error(f"后台任务执行异常: {task_id}, 错误: {e}")
        except Exception as e:
            logger.error(f"清理任务失败: {task_id}, 错误: {e}")
+
    @staticmethod
    async def _parse_single_node_predecessor(flow_id: str, step_id: str, flow_hash: str):
        """Resolve and cache the predecessor output variables of one node.

        Progress is mirrored into the cache as a parsing status:
        "parsing" -> "completed" / "failed" / "cancelled".
        """
        try:
            # Bail out early if the event loop is already shutting down.
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                logger.warning(f"事件循环已关闭,跳过解析: {flow_id}:{step_id}")
                return

            await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "parsing")

            flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id)
            if not flow_item:
                await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "failed")
                return

            # Find this step's predecessors, then materialize their outputs.
            predecessor_nodes = PredecessorCacheService._find_predecessor_nodes(flow_item, step_id)

            variables_data = []
            for node in predecessor_nodes:
                node_vars = await PredecessorCacheService._create_node_output_variables(node)
                variables_data.extend(node_vars)

            # Store under the topology hash so stale results are detectable.
            await _get_predecessor_cache().set_cached_variables(flow_id, step_id, variables_data, flow_hash)

            await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "completed")

            logger.info(f"节点前置变量解析完成: {flow_id}:{step_id}, 变量数量: {len(variables_data)}")

        except asyncio.CancelledError:
            logger.info(f"节点前置变量解析任务被取消: {flow_id}:{step_id}")
            try:
                await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "cancelled")
            except Exception:
                pass  # best-effort status update during cancellation
        except Exception as e:
            logger.error(f"解析节点前置变量失败: {flow_id}:{step_id}, 错误: {e}")
            try:
                await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "failed")
            except Exception:
                # Likely the event loop is closing; the status update is best effort.
                logger.warning(f"无法设置解析状态为失败: {flow_id}:{step_id}")
+
+    @staticmethod
+    async def get_predecessor_variables_optimized(
+        flow_id: str,
+        step_id: str,
+        user_sub: str,
+        max_wait_time: int = 10
+    ) -> List[Dict[str, Any]]:
+        """Fetch predecessor-node variables, preferring the cache.
+
+        Lookup order: (1) return the cached result if present; (2) if another
+        task is currently parsing this node, wait up to *max_wait_time* and
+        re-check the cache; (3) otherwise parse synchronously and cache the
+        result. Returns an empty list when the flow cannot be loaded or an
+        unexpected error occurs.
+
+        Args:
+            flow_id: ID of the flow containing the node.
+            step_id: ID of the node whose predecessor variables are wanted.
+            user_sub: Caller's user identifier (not used by this lookup path;
+                kept for interface consistency — TODO confirm with callers).
+            max_wait_time: Max seconds to wait for an in-flight parse.
+        """
+        try:
+            # 1. Fast path: serve directly from the cache when available.
+            cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id)
+            if cached_vars is not None:
+                logger.info(f"使用缓存的前置节点变量: {flow_id}:{step_id}")
+                return cached_vars
+
+            # 2. Another task may already be parsing this node — wait for it.
+            if await _get_predecessor_cache().is_parsing_in_progress(flow_id, step_id):
+                logger.info(f"等待前置节点变量解析完成: {flow_id}:{step_id}")
+                # Block (bounded) until the in-flight parse finishes.
+                if await _get_predecessor_cache().wait_for_parsing_completion(flow_id, step_id, max_wait_time):
+                    cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id)
+                    if cached_vars is not None:
+                        return cached_vars
+
+            # 3. Cache miss: parse right now on the caller's behalf.
+            logger.info(f"缓存未命中,启动实时解析: {flow_id}:{step_id}")
+
+            # Load the flow definition.
+            flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id)
+            if not flow_item:
+                return []
+
+            # Compute the topology hash the cache entry will be keyed by.
+            flow_hash = PredecessorCacheService.calculate_flow_hash(flow_item)
+
+            # Parse synchronously and populate the cache.
+            await PredecessorCacheService._parse_single_node_predecessor(flow_id, step_id, flow_hash)
+
+            # The parse above writes to the cache; read it back.
+            cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id)
+            return cached_vars or []
+
+        except Exception as e:
+            logger.error(f"获取优化前置节点变量失败: {flow_id}:{step_id}, 错误: {e}")
+            return []
+
+    @staticmethod
+    async def _get_flow_by_flow_id(flow_id: str):
+        """Look up a flow definition by its flow_id alone.
+
+        Flows are embedded in "app" documents, so this first finds the app
+        whose ``flows.id`` matches *flow_id* (projecting only ``_id``), then
+        delegates to ``FlowManager.get_flow_by_app_and_flow_id``.
+
+        Returns:
+            The flow item returned by FlowManager (presumably a flow model
+            object — exact type defined elsewhere), or ``None`` when no app
+            contains this flow or an error occurs.
+        """
+        try:
+            from apps.common.mongo import MongoDB
+
+            app_collection = MongoDB().get_collection("app")
+
+            # Find the app that embeds this flow; project only its _id.
+            app_record = await app_collection.find_one(
+                {"flows.id": flow_id},
+                {"_id": 1}
+            )
+
+            if not app_record:
+                logger.warning(f"未找到包含flow_id {flow_id} 的应用")
+                return None
+
+            app_id = app_record["_id"]
+
+            # Reuse the existing FlowManager lookup to load the flow itself.
+            flow_item = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id)
+            return flow_item
+
+        except Exception as e:
+            logger.error(f"通过flow_id获取工作流失败: {e}")
+            return None
+
+    @staticmethod
+    def _find_predecessor_nodes(flow_item, current_step_id: str) -> List:
+        """Return the direct (one-edge-hop) predecessor nodes of a step.
+
+        Scans ``flow_item.edges`` for edges whose target is
+        *current_step_id* and collects the corresponding source nodes from
+        ``flow_item.nodes``. Transitive predecessors are NOT followed.
+
+        Returns:
+            List of node objects; empty on error or when the step has no
+            incoming edges.
+        """
+        try:
+            predecessor_nodes = []
+
+            # Walk every edge that points at the current node.
+            for edge in flow_item.edges:
+                if edge.target_node == current_step_id:
+                    # Resolve the edge's source node object by step_id.
+                    source_node = next(
+                        (node for node in flow_item.nodes if node.step_id == edge.source_node),
+                        None
+                    )
+                    if source_node:
+                        predecessor_nodes.append(source_node)
+
+            logger.debug(f"为节点 {current_step_id} 找到 {len(predecessor_nodes)} 个前置节点")
+            return predecessor_nodes
+
+        except Exception as e:
+            logger.error(f"查找前置节点失败: {e}")
+            return []
+
+    @staticmethod
+    async def _create_node_output_variables(node) -> List[Dict[str, Any]]:
+        """Build variable descriptors from a node's output_parameters config.
+
+        For each entry in the node's ``output_parameters`` a dict is produced
+        describing a potential conversation-scoped variable named
+        ``"{step_id}.{param_name}"`` with an empty value (these are
+        configuration-time placeholders, not runtime values).
+
+        Returns:
+            List of plain dicts suitable for caching; empty when the node has
+            no output_parameters or an error occurs.
+        """
+        try:
+            variables_data = []
+            node_id = node.step_id
+
+            # output_parameters may live in a dict or an object attribute,
+            # depending on how the node's parameters were deserialized.
+            output_params = {}
+            if hasattr(node, 'parameters') and node.parameters:
+                if isinstance(node.parameters, dict):
+                    output_params = node.parameters.get('output_parameters', {})
+                else:
+                    output_params = getattr(node.parameters, 'output_parameters', {})
+
+            # Nothing configured — this node exposes no output variables.
+            if not output_params:
+                logger.debug(f"节点 {node_id} 没有配置output_parameters,跳过创建输出变量")
+                return variables_data
+
+            # One variable descriptor per configured output parameter.
+            for param_name, param_config in output_params.items():
+                # A param config is usually a dict ({"type": ..., "description": ...}) …
+                if isinstance(param_config, dict):
+                    param_type = param_config.get('type', 'string')
+                    description = param_config.get('description', '')
+                else:
+                    # … but may also be a bare type string.
+                    param_type = str(param_config) if param_config else 'string'
+                    description = ''
+
+                # Map the declared type string to a VariableType member;
+                # unrecognized types fall back to STRING.
+                var_type = VariableType.STRING  # default type
+                if param_type == 'number':
+                    var_type = VariableType.NUMBER
+                elif param_type == 'boolean':
+                    var_type = VariableType.BOOLEAN
+                elif param_type == 'object':
+                    var_type = VariableType.OBJECT
+                elif param_type == 'array' or param_type == 'array[any]':
+                    var_type = VariableType.ARRAY_ANY
+                elif param_type == 'array[string]':
+                    var_type = VariableType.ARRAY_STRING
+                elif param_type == 'array[number]':
+                    var_type = VariableType.ARRAY_NUMBER
+                elif param_type == 'array[object]':
+                    var_type = VariableType.ARRAY_OBJECT
+                elif param_type == 'array[boolean]':
+                    var_type = VariableType.ARRAY_BOOLEAN
+                elif param_type == 'array[file]':
+                    var_type = VariableType.ARRAY_FILE
+                elif param_type == 'array[secret]':
+                    var_type = VariableType.ARRAY_SECRET
+                elif param_type == 'file':
+                    var_type = VariableType.FILE
+                elif param_type == 'secret':
+                    var_type = VariableType.SECRET
+
+                # Assemble the cache-friendly descriptor dict.
+                variable_data = {
+                    'name': f"{node_id}.{param_name}",
+                    'var_type': var_type.value,
+                    'scope': VariableScope.CONVERSATION.value,
+                    'value': "",  # Configuration-time placeholder; no runtime value yet.
+                    'description': description or f"来自节点 {node_id} 的输出参数 {param_name}",
+                    'created_at': datetime.now(UTC).isoformat(),
+                    'updated_at': datetime.now(UTC).isoformat(),
+                    'step_name': getattr(node, 'name', node_id),  # display name of the node
+                    'step_id': node_id  # node ID (prefix of the variable name)
+                }
+
+                variables_data.append(variable_data)
+
+            logger.debug(f"为节点 {node_id} 创建了 {len(variables_data)} 个输出变量: {[v['name'] for v in variables_data]}")
+            return variables_data
+
+        except Exception as e:
+            logger.error(f"创建节点输出变量失败: {e}")
+            return []
\ No newline at end of file
diff --git a/apps/services/rag.py b/apps/services/rag.py
index 6b6c843dc6fc09a14809ef0a11bc44b2da434f29..1dc6fd32f9b90453046de07b402274b98e59a8be 100644
--- a/apps/services/rag.py
+++ b/apps/services/rag.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""对接Euler Copilot RAG"""
+from datetime import UTC, datetime
import json
import logging
from collections.abc import AsyncGenerator
@@ -30,8 +31,7 @@ class RAG:
"""系统提示词"""
user_prompt = """'
- 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后进行脚注。
- 一个例子将在中给出。
+ 你是华鲲振宇的智能助手。请结合给出的背景信息, 回答用户的提问。
上下文背景信息将在中给出。
用户的提问将在中给出。
注意:
@@ -68,7 +68,7 @@ class RAG:
openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]]
-
+
{bac_info}
@@ -156,9 +156,11 @@ class RAG:
"id": doc_chunk["docId"],
"order": doc_cnt,
"name": doc_chunk.get("docName", ""),
+ "author": doc_chunk.get("docAuthor", ""),
"extension": doc_chunk.get("docExtension", ""),
"abstract": doc_chunk.get("docAbstract", ""),
"size": doc_chunk.get("docSize", 0),
+ "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)),
})
doc_id_map[doc_chunk["docId"]] = doc_cnt
doc_index = doc_id_map[doc_chunk["docId"]]
diff --git a/apps/services/task.py b/apps/services/task.py
index 1e672be690a6f17896ab01bc7149aa353987ecf7..17976e62c1e3b0b8c53091789d15a06751042095 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -9,6 +9,7 @@ from apps.common.mongo import MongoDB
from apps.schemas.record import RecordGroup
from apps.schemas.request_data import RequestData
from apps.schemas.task import (
+ FlowStepHistory,
Task,
TaskIds,
TaskRuntime,
@@ -45,7 +46,6 @@ class TaskManager:
return Task.model_validate(task)
-
@staticmethod
async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None:
"""获取组ID的最后一条问答组关联的任务"""
@@ -58,7 +58,6 @@ class TaskManager:
task = await task_collection.find_one({"_id": record_group_obj.task_id})
return Task.model_validate(task)
-
@staticmethod
async def get_task_by_task_id(task_id: str) -> Task | None:
"""根据task_id获取任务"""
@@ -68,7 +67,6 @@ class TaskManager:
return None
return Task.model_validate(task)
-
@staticmethod
async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]:
"""根据record_group_id获取flow信息"""
@@ -95,9 +93,8 @@ class TaskManager:
else:
return flow_context_list
-
@staticmethod
- async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]:
+ async def get_context_by_task_id(task_id: str, length: int = 0) -> list[FlowStepHistory]:
"""根据task_id获取flow信息"""
flow_context_collection = MongoDB().get_collection("flow_context")
@@ -115,9 +112,8 @@ class TaskManager:
else:
return flow_context
-
@staticmethod
- async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None:
+ async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None:
"""保存flow信息到flow_context"""
flow_context_collection = MongoDB().get_collection("flow_context")
try:
@@ -125,7 +121,7 @@ class TaskManager:
# 查找是否存在
current_context = await flow_context_collection.find_one({
"task_id": task_id,
- "_id": history["_id"],
+ "_id": history.id,
})
if current_context:
await flow_context_collection.update_one(
@@ -133,11 +129,10 @@ class TaskManager:
{"$set": history},
)
else:
- await flow_context_collection.insert_one(history)
+ await flow_context_collection.insert_one(history.model_dump(by_alias=True, exclude_none=True))
except Exception:
logger.exception("[TaskManager] 保存flow执行记录失败")
-
@staticmethod
async def delete_task_by_task_id(task_id: str) -> None:
"""通过task_id删除Task信息"""
@@ -148,7 +143,6 @@ class TaskManager:
if task:
await task_collection.delete_one({"_id": task_id})
-
@staticmethod
async def delete_tasks_by_conversation_id(conversation_id: str) -> None:
"""通过ConversationID删除Task信息"""
@@ -167,7 +161,6 @@ class TaskManager:
await task_collection.delete_many({"conversation_id": conversation_id}, session=session)
await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session)
-
@classmethod
async def get_task(
cls,
@@ -212,7 +205,6 @@ class TaskManager:
runtime=TaskRuntime(),
)
-
@classmethod
async def save_task(cls, task_id: str, task: Task) -> None:
"""保存任务块"""
diff --git a/apps/templates/generate_llm_operator_config.py b/apps/templates/generate_llm_operator_config.py
index 2dc270b6c327f8347139ef157fe175e4c185d1b2..f63136413bedd37b177250b6b2e6550476cedb3b 100644
--- a/apps/templates/generate_llm_operator_config.py
+++ b/apps/templates/generate_llm_operator_config.py
@@ -5,54 +5,60 @@ import base64
import os
llm_provider_dict={
- "baichuan":{
- "provider":"baichuan",
- "url":"https://api.baichuan-ai.com/v1",
- "description":"百川大模型平台",
- "icon":"",
- },
- "modelscope":{
- "provider":"modelscope",
- "url":None,
- "description":"基于魔塔部署的本地大模型服务",
- "icon":"",
+ "huakunzhenyu": {
+ "provider": "huakunzhenyu",
+ "url": "http://192.168.226.30:9997/v1",
+ "description": "天巡CubeX智擎平台",
+ "icon": "",
},
+ # "baichuan":{
+ # "provider":"baichuan",
+ # "url":"https://api.baichuan-ai.com/v1",
+ # "description":"百川大模型平台",
+ # "icon":"",
+ # },
+ # "modelscope":{
+ # "provider":"modelscope",
+ # "url":None,
+ # "description":"基于魔塔部署的本地大模型服务",
+ # "icon":"",
+ # },
"ollama":{
"provider":"ollama",
"url":None,
"description":"基于Ollama部署的本地大模型服务",
"icon":"",
},
- "openai":{
- "provider":"openai",
- "url":"https://api.openai.com/v1",
- "description":"OpenAI大模型平台",
- "icon":"",
- },
- "qwen":{
- "provider":"qwen",
- "url":"https://dashscope.aliyuncs.com/compatible-mode/v1",
- "description":"阿里百炼大模型平台",
- "icon":"",
- },
- "spark":{
- "provider":"spark",
- "url":"https://spark-api-open.xf-yun.com/v1",
- "description":"讯飞星火大模型平台",
- "icon":"",
- },
+ # "openai":{
+ # "provider":"openai",
+ # "url":"https://api.openai.com/v1",
+ # "description":"OpenAI大模型平台",
+ # "icon":"",
+ # },
+ # "qwen":{
+ # "provider":"qwen",
+ # "url":"https://dashscope.aliyuncs.com/compatible-mode/v1",
+ # "description":"阿里百炼大模型平台",
+ # "icon":"",
+ # },
+ # "spark":{
+ # "provider":"spark",
+ # "url":"https://spark-api-open.xf-yun.com/v1",
+ # "description":"讯飞星火大模型平台",
+ # "icon":"",
+ # },
"vllm":{
"provider":"vllm",
"url":None,
"description":"基于VLLM部署的本地大模型服务",
"icon":"",
},
- "wenxin":{
- "provider":"wenxin",
- "url":"https://qianfan.baidubce.com/v2",
- "description":"百度文心大模型平台",
- "icon":"",
- },
+ # "wenxin":{
+ # "provider":"wenxin",
+ # "url":"https://qianfan.baidubce.com/v2",
+ # "description":"百度文心大模型平台",
+ # "icon":"",
+ # },
}
icon_path="./apps/templates/llm_provider_icon"
icon_file_name_list=os.listdir(icon_path)
@@ -68,3 +74,6 @@ for file_name in icon_file_name_list:
if provider_name in provider:
llm_provider_dict[provider]['icon'] = f"data:image/svg+xml;base64,{base64_string}"
break
+
+if __name__ == "__main__":
+ print(llm_provider_dict)
\ No newline at end of file
diff --git a/apps/templates/llm_provider_icon/huakunzhenyu.svg b/apps/templates/llm_provider_icon/huakunzhenyu.svg
new file mode 100644
index 0000000000000000000000000000000000000000..d5351cafe9f340b1fbd66ebc8d4c50308444b031
--- /dev/null
+++ b/apps/templates/llm_provider_icon/huakunzhenyu.svg
@@ -0,0 +1,3 @@
+
diff --git a/assets/.config.example.toml b/assets/.config.example.toml
index 61f13e49a893e8a698df3113148226c8a6a7463a..28fb5ef8ec5e3451101b1776aa4ce27f25f0057a 100644
--- a/assets/.config.example.toml
+++ b/assets/.config.example.toml
@@ -37,6 +37,17 @@ user = 'euler_copilot'
password = ''
database = 'euler_copilot'
+[redis]
+host = 'redis-db'
+port = 6379
+password = ''
+database = 0
+decode_responses = true
+socket_timeout = 5.0
+socket_connect_timeout = 5.0
+max_connections = 10
+health_check_interval = 30
+
[minio]
endpoint = '127.0.0.1:9000'
access_key = 'minioadmin'
@@ -62,5 +73,8 @@ temperature = 0.7
enable = false
words_list = ''
+[sandbox]
+sandbox_service = 'http://127.0.0.1:8000'
+
[extra]
sql_url = 'http://127.0.0.1:9015'
diff --git a/deploy/chart/euler_copilot/configs/framework/config.toml b/deploy/chart/euler_copilot/configs/framework/config.toml
index 64a1a6a03600b1cfc48c504f32c1e8c9ce4350cb..863f3aa5c7abf96aadc7ee367af1e064d44633f0 100644
--- a/deploy/chart/euler_copilot/configs/framework/config.toml
+++ b/deploy/chart/euler_copilot/configs/framework/config.toml
@@ -37,6 +37,17 @@ user = 'euler_copilot'
password = '${mongo-password}'
database = 'euler_copilot'
+[redis]
+host = 'redis-db.{{ .Release.Namespace }}.svc.cluster.local'
+port = 6379
+password = '${redis-password}'
+database = 5
+decode_responses = true
+socket_timeout = 5.0
+socket_connect_timeout = 5.0
+max_connections = 10
+health_check_interval = 30
+
[minio]
endpoint = 'minio-service.{{ .Release.Namespace }}.svc.cluster.local:9000'
access_key = 'minioadmin'
@@ -62,5 +73,8 @@ temperature = {{ default 0.7 .Values.models.functionCall.temperature }}
enable = false
words_list = ""
+[sandbox]
+sandbox_service = 'http://euler-copilot-sandbox-service.{{ .Release.Namespace }}.svc.cluster.local:8000'
+
[extra]
sql_url = ''
diff --git a/docs/variable_configuration.md b/docs/variable_configuration.md
new file mode 100644
index 0000000000000000000000000000000000000000..b105c4af1fb43d785e8fadb0af84f5965a6783af
--- /dev/null
+++ b/docs/variable_configuration.md
@@ -0,0 +1,133 @@
+# 变量存储格式配置说明
+
+## 概述
+
+节点执行完成后,系统会根据节点的`output_parameters`配置自动将输出数据保存到对话变量池中。变量的存储格式有两种:
+
+1. **直接格式**: `conversation.key`
+2. **带前缀格式**: `conversation.step_id.key`
+
+## 配置方式
+
+### 1. 通过节点类型配置
+
+在 `apps/scheduler/executor/step_config.py` 中的 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 集合中添加节点类型:
+
+```python
+DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: Set[str] = {
+ "Start", # Start节点
+ "Input", # 输入节点
+ "YourNewNode", # 新增的节点类型
+}
+```
+
+### 2. 通过名称模式配置
+
+在 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 集合中添加匹配模式:
+
+```python
+DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: Set[str] = {
+ "start", # 匹配以"start"开头的节点名称
+ "init", # 匹配以"init"开头的节点名称
+ "config", # 新增:匹配以"config"开头的节点名称
+}
+```
+
+## 判断逻辑
+
+系统会按以下顺序判断是否使用直接格式:
+
+1. 检查节点的 `call_id` 是否在 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 中
+2. 检查节点的 `step_name` 是否在 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 中
+3. 检查节点名称(小写)是否以 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 中的模式开头
+4. 检查 `step_id`(小写)是否以 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 中的模式开头
+
+如果任一条件满足,则使用直接格式 `conversation.key`,否则使用带前缀格式 `conversation.step_id.key`。
+
+## 使用示例
+
+### 示例1:Start节点
+
+```json
+// 节点配置
+{
+ "call_id": "Start",
+ "step_name": "start",
+ "step_id": "start_001",
+ "output_parameters": {
+ "user_name": {"type": "string"},
+ "session_id": {"type": "string"}
+ }
+}
+
+// 保存的变量格式
+conversation.user_name = "张三"
+conversation.session_id = "sess_123"
+```
+
+### 示例2:普通处理节点
+
+```json
+// 节点配置
+{
+ "call_id": "Code",
+ "step_name": "数据处理",
+ "step_id": "process_001",
+ "output_parameters": {
+ "result": {"type": "object"},
+ "status": {"type": "string"}
+ }
+}
+
+// 保存的变量格式
+conversation.process_001.result = {...}
+conversation.process_001.status = "success"
+```
+
+### 示例3:配置节点(新增类型)
+
+```python
+# 在step_config.py中添加
+DIRECT_CONVERSATION_VARIABLE_NODE_TYPES.add("GlobalConfig")
+```
+
+```json
+// 节点配置
+{
+ "call_id": "GlobalConfig",
+ "step_name": "全局配置",
+ "step_id": "config_001",
+ "output_parameters": {
+ "api_key": {"type": "secret"},
+ "timeout": {"type": "number"}
+ }
+}
+
+// 保存的变量格式(使用直接格式)
+conversation.api_key = "xxx"
+conversation.timeout = 30
+```
+
+## 变量引用
+
+在其他节点中可以通过以下方式引用这些变量:
+
+```json
+{
+ "input_parameters": {
+ "user": {
+ "reference": "{{conversation.user_name}}" // 直接格式变量
+ },
+ "data": {
+ "reference": "{{conversation.process_001.result}}" // 带前缀格式变量
+ }
+ }
+}
+```
+
+## 注意事项
+
+1. **一致性**: 建议同时添加大小写版本以确保兼容性
+2. **命名冲突**: 使用直接格式时需要注意变量名冲突问题
+3. **可追溯性**: 带前缀格式便于追踪变量来源,直接格式便于全局访问
+4. **配置变更**: 修改配置后需要重启服务才能生效
\ No newline at end of file
diff --git a/euler_copilot_framework.egg-info/PKG-INFO b/euler_copilot_framework.egg-info/PKG-INFO
new file mode 100644
index 0000000000000000000000000000000000000000..768d3f89fce727f40aaee643151eddd45dc3fd9d
--- /dev/null
+++ b/euler_copilot_framework.egg-info/PKG-INFO
@@ -0,0 +1,38 @@
+Metadata-Version: 2.4
+Name: euler-copilot-framework
+Version: 0.9.6
+Summary: EulerCopilot 后端服务
+Requires-Python: ==3.11.6
+License-File: LICENSE
+Requires-Dist: aiofiles==24.1.0
+Requires-Dist: asyncer==0.0.8
+Requires-Dist: asyncpg==0.30.0
+Requires-Dist: cryptography==44.0.2
+Requires-Dist: fastapi==0.115.12
+Requires-Dist: httpx==0.28.1
+Requires-Dist: httpx-sse==0.4.0
+Requires-Dist: jinja2==3.1.6
+Requires-Dist: jionlp==1.5.20
+Requires-Dist: jsonschema==4.23.0
+Requires-Dist: lancedb==0.21.2
+Requires-Dist: mcp==1.9.4
+Requires-Dist: minio==7.2.15
+Requires-Dist: ollama==0.5.1
+Requires-Dist: openai==1.91.0
+Requires-Dist: pandas==2.2.3
+Requires-Dist: pgvector==0.4.1
+Requires-Dist: pillow==10.3.0
+Requires-Dist: pydantic==2.11.7
+Requires-Dist: pymongo==4.12.1
+Requires-Dist: python-jsonpath==1.3.0
+Requires-Dist: python-magic==0.4.27
+Requires-Dist: python-multipart==0.0.20
+Requires-Dist: pytz==2025.2
+Requires-Dist: pyyaml==6.0.2
+Requires-Dist: rich==13.9.4
+Requires-Dist: sqids==0.5.1
+Requires-Dist: sqlalchemy==2.0.41
+Requires-Dist: tiktoken==0.9.0
+Requires-Dist: toml==0.10.2
+Requires-Dist: uvicorn==0.34.0
+Dynamic: license-file
diff --git a/euler_copilot_framework.egg-info/SOURCES.txt b/euler_copilot_framework.egg-info/SOURCES.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3beeaae71ae08e9482a4157601718053efd51c51
--- /dev/null
+++ b/euler_copilot_framework.egg-info/SOURCES.txt
@@ -0,0 +1,12 @@
+LICENSE
+README.md
+pyproject.toml
+apps/__init__.py
+apps/constants.py
+apps/exceptions.py
+apps/main.py
+euler_copilot_framework.egg-info/PKG-INFO
+euler_copilot_framework.egg-info/SOURCES.txt
+euler_copilot_framework.egg-info/dependency_links.txt
+euler_copilot_framework.egg-info/requires.txt
+euler_copilot_framework.egg-info/top_level.txt
\ No newline at end of file
diff --git a/euler_copilot_framework.egg-info/dependency_links.txt b/euler_copilot_framework.egg-info/dependency_links.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/euler_copilot_framework.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/euler_copilot_framework.egg-info/requires.txt b/euler_copilot_framework.egg-info/requires.txt
new file mode 100644
index 0000000000000000000000000000000000000000..de92c758f82c631f9fd0ef4277e15568d0e2d6bb
--- /dev/null
+++ b/euler_copilot_framework.egg-info/requires.txt
@@ -0,0 +1,31 @@
+aiofiles==24.1.0
+asyncer==0.0.8
+asyncpg==0.30.0
+cryptography==44.0.2
+fastapi==0.115.12
+httpx==0.28.1
+httpx-sse==0.4.0
+jinja2==3.1.6
+jionlp==1.5.20
+jsonschema==4.23.0
+lancedb==0.21.2
+mcp==1.9.4
+minio==7.2.15
+ollama==0.5.1
+openai==1.91.0
+pandas==2.2.3
+pgvector==0.4.1
+pillow==10.3.0
+pydantic==2.11.7
+pymongo==4.12.1
+python-jsonpath==1.3.0
+python-magic==0.4.27
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.2
+rich==13.9.4
+sqids==0.5.1
+sqlalchemy==2.0.41
+tiktoken==0.9.0
+toml==0.10.2
+uvicorn==0.34.0
diff --git a/euler_copilot_framework.egg-info/top_level.txt b/euler_copilot_framework.egg-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ea5a6165b61e9169f28b337d9757f02c7af34ba1
--- /dev/null
+++ b/euler_copilot_framework.egg-info/top_level.txt
@@ -0,0 +1 @@
+apps
diff --git a/pyproject.toml b/pyproject.toml
index 681ea9fb4552650da9148a22dd55d2063dca2410..8a34b17dd1b202e375447139133b46495cf5c30f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,6 +5,7 @@ description = "EulerCopilot 后端服务"
requires-python = "==3.11.6"
dependencies = [
"aiofiles==24.1.0",
+ "redis==5.0.8",
"asyncer==0.0.8",
"asyncpg==0.30.0",
"cryptography==44.0.2",
@@ -51,3 +52,10 @@ dev = [
"sphinx==8.2.3",
"sphinx-rtd-theme==3.0.2",
]
+
+[build-system]
+requires = ["setuptools>=65"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["apps"] # 仅包含apps目录
\ No newline at end of file
diff --git a/tests/common/test_config.py b/tests/common/test_config.py
index 6c77141ca78e662ab12a8ec3c1ae017f3de6fced..9d90da2efd8b6f81a14136a6787ba38e157ebd16 100644
--- a/tests/common/test_config.py
+++ b/tests/common/test_config.py
@@ -68,6 +68,7 @@ MOCK_CONFIG_DATA: dict[str, Any] = {
"jwt_key": "test_jwt_key",
},
"check": {"enable": False, "words_list": ""},
+ "sandbox": {"sandbox_service": "http://localhost:8000"},
"extra": {"sql_url": "http://localhost"},
}
diff --git a/update.sh b/update.sh
new file mode 100755
index 0000000000000000000000000000000000000000..88b1df6b8850215b9c8712c32a96c396419600fc
--- /dev/null
+++ b/update.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+# update.sh - Docker镜像构建与推送脚本
+# 用法: ./update.sh <版本号>
+# 示例: ./update.sh 0.9.6
+
+# 检查是否传入版本参数
+if [ $# -eq 0 ]; then
+ echo "错误: 必须指定版本号作为参数"
+ echo "用法: $0 "
+ exit 1
+fi
+
+VERSION=$1
+DOCKERFILE_PATH="/home/jay/code/euler-copilot-framework/Dockerfile"
+IMAGE_NAME="swr.cn-north-4.myhuaweicloud.com/euler-copilot/euler-copilot-framework:hkzy-${VERSION}-arm"
+
+echo "开始拉取最新版本的代码, 会使用当前项目分支..."
+git pull
+
+echo "开始构建Docker镜像,版本: ${VERSION}"
+
+# 构建Docker镜像
+docker build . -f ${DOCKERFILE_PATH} -t ${IMAGE_NAME}
+if [ $? -ne 0 ]; then
+ echo "错误: Docker镜像构建失败"
+ exit 1
+fi
+
+echo "Docker镜像构建成功,开始推送..."
+
+# 推送Docker镜像
+docker push ${IMAGE_NAME}
+if [ $? -ne 0 ]; then
+ echo "错误: Docker镜像推送失败"
+ exit 1
+fi
+
+echo "Docker镜像推送成功..."
+echo "操作完成,镜像地址: ${IMAGE_NAME}"
+
+echo "helm 更新开始..."
+helm upgrade euler-copilot /home/jay/deploy/chart/euler_copilot/. -n euler-copilot
+
+echo "输出euler-copilot pod状态..."
+kubectl get pods -n euler-copilot
\ No newline at end of file