diff --git a/apps/routers/knowledge.py b/apps/routers/knowledge.py index ef85efc4db2aa4d5f6dd7c143c8ec2a654d8e40f..0ca70c6d9cefc5ab44a00bdd11087008cf4af1bd 100644 --- a/apps/routers/knowledge.py +++ b/apps/routers/knowledge.py @@ -64,3 +64,23 @@ async def update_conversation_kb( result=kb_ids_update_success, ).model_dump(exclude_none=True, by_alias=True), ) + + +@router.get("/team", response_model=ResponseData, responses={ + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, +}) +async def list_team_kb( + user_sub: Annotated[str, Depends(get_user)], + kb_id: Annotated[str, Query(alias="kbId")] = None, + kb_name: Annotated[str, Query(alias="kbName")] = "", +) -> JSONResponse: + """获取团队知识库列表""" + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, kb_id, kb_name) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="success", + result=team_kb_list, + ).model_dump(exclude_none=True, by_alias=True), + ) diff --git a/apps/routers/llm.py b/apps/routers/llm.py index 7b111ab64e8f63337224abbfbbc3b3f99988ba01..5cc98366db94477b76fb31311849a5602e8abf73 100644 --- a/apps/routers/llm.py +++ b/apps/routers/llm.py @@ -6,6 +6,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, Query, status from fastapi.responses import JSONResponse +from apps.common.config import Config from apps.dependency import get_user, verify_user from apps.schemas.request_data import ( UpdateLLMReq, @@ -131,3 +132,48 @@ async def update_conv_llm( result=llm_id, ).model_dump(exclude_none=True, by_alias=True), ) + + +@router.get("/embedding", response_model=ResponseData) +async def get_embedding_config() -> JSONResponse: + """获取 Embedding 配置""" + config = Config().get_config() + embedding_config = config.embedding.__dict__ + + # 如果配置中没有 icon,添加一个默认图标 + if 'icon' not in embedding_config or not embedding_config['icon']: + # 根据模型名称设置默认图标 + model_name = embedding_config.get('model', '') + if 'bge' in model_name.lower() and 'baai' in model_name.lower(): + embedding_config['icon'] = 'https://sf-maas-uat-prod.oss-cn-shanghai.aliyuncs.com/Model_LOGO/BAAI.svg' + else: + embedding_config['icon'] = '' + + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="success", + result=embedding_config, + ).model_dump(exclude_none=True, by_alias=True), + ) + +@router.get("/reranker", response_model=ResponseData) +async def get_reranker_config() -> JSONResponse: + """获取 Reranker 配置""" + config = Config().get_config() + reranker_config = config.reranker.__dict__ + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="success", + result=[ + { + 'type': 'algorithm', + 'name': 'jaccard_dis_reranker' + }, + reranker_config + ], + ).model_dump(exclude_none=True, by_alias=True), + ) \ No newline at end of file diff --git a/apps/routers/user.py b/apps/routers/user.py index 537f1bf3adb95e8b48ef1b3d384376bae86ac802..6ccc167fef96a788014150a284f547c6a7414c70 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -7,7 +7,7 @@ from fastapi import APIRouter, Body, Depends, status, Query from fastapi.responses import JSONResponse from apps.dependency import get_user -from apps.schemas.request_data import UserUpdateRequest +from apps.schemas.request_data import UserUpdateRequest, UserPreferencesRequest from apps.schemas.response_data import UserGetMsp, UserGetRsp from apps.schemas.user import UserInfo from apps.services.user import UserManager @@ -61,3 +61,33 @@ async def update_user_info( status_code=status.HTTP_200_OK, content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"}, ) + + +@router.put("/preferences") +async def update_user_preferences( + user_sub: Annotated[str, Depends(get_user)], + *, + data: Annotated[UserPreferencesRequest, Body(..., description="用户偏好设置更新信息")], +) -> JSONResponse: + """更新用户偏好设置接口""" + await UserManager.update_user_preferences_by_user_sub(user_sub, data) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"code": status.HTTP_200_OK, "message": "用户偏好设置更新成功"}, + ) + + +@router.get("/preferences") +async def get_user_preferences( + user_sub: Annotated[str, Depends(get_user)], +) -> JSONResponse: + """获取用户偏好设置接口""" + preferences = await UserManager.get_user_preferences_by_user_sub(user_sub) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={ + "code": status.HTTP_200_OK, + "message": "用户偏好设置获取成功", + "result": preferences.model_dump(by_alias=True, exclude_none=True) + }, + ) diff --git a/apps/routers/variable.py b/apps/routers/variable.py index 6297d6713a27eb571a928da274e02c6d4f35b807..3c19221b83c9401845bfdec7e1c51d26e6862634 100644 --- a/apps/routers/variable.py +++ b/apps/routers/variable.py @@ -82,6 +82,17 @@ async def _get_predecessor_node_variables( # 将缓存的变量数据转换为Variable对象 for var_data in cached_var_data: try: + var_name = var_data['name'] + + # 检查是否为当前步骤的输出变量 + if "." in var_name and not var_name.startswith("system."): + # 提取节点ID + node_id = var_name.split(".")[0] + + # 排除当前步骤的输出变量 + if node_id == current_step_id: + continue # 跳过当前步骤的输出变量 + from apps.scheduler.variable.variables import create_variable from apps.scheduler.variable.base import VariableMetadata from apps.scheduler.variable.type import VariableType, VariableScope @@ -89,7 +100,7 @@ async def _get_predecessor_node_variables( # 创建变量元数据 metadata = VariableMetadata( - name=var_data['name'], + name=var_name, var_type=VariableType(var_data['var_type']), scope=VariableScope(var_data['scope']), description=var_data.get('description', ''), @@ -776,6 +787,10 @@ async def list_variables( if len(parts) == 2: step_id, var_name = parts + # 确保不是当前步骤的输出变量(双重保险) + if current_step_id and step_id == current_step_id: + continue + # 优先使用缓存数据中的节点信息 if hasattr(variable, '_cache_data') and variable._cache_data: cache_data = variable._cache_data diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 27184963730898d4d180c88976e18cdefaef4e43..c6d750264ea71c70246b4acec50630944a3e834d 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -82,6 +82,9 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """ # 处理reference类型 if value.type == ValueType.REFERENCE: + # 检查引用值是否为空或空白 + if not value.value or (isinstance(value.value, str) and not value.value.strip()): + return False, None, f"{value_position}引用为空或无效" try: resolved_value = await self._resolve_single_value(value, call_vars) except Exception as e: @@ -255,6 +258,11 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): Returns: tuple[bool, list[str]]: (是否成功, 错误消息列表) """ + # 对于默认分支(else分支),直接返回成功,不需要处理条件 + if choice.is_default: + logger.debug(f"[Choice] 分支 {choice.branch_id} 是默认分支,跳过条件处理") + return True, [] + # 验证逻辑运算符 if not self._validate_branch_logic(choice): return False, ["无效的逻辑运算符"] @@ -274,7 +282,7 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {error_msg}") # 如果没有有效条件,返回失败 - if not valid_conditions and not choice.is_default: + if not valid_conditions: error_messages.append("分支没有有效条件") return False, error_messages diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index c14b3c794320f85ce255a471884bf8c7720f530f..5ea4066a93f0bef92bdbbfa9a8162a2a3bbc70b2 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -79,7 +79,6 @@ class ConditionHandler(BaseModel): @staticmethod def handler(choices: list[ChoiceBranch]) -> str: """处理条件""" - logger.error(choices) default_branch = [c for c in choices if c.is_default] # 先处理所有非默认分支 diff --git a/apps/scheduler/call/loop/loop.py b/apps/scheduler/call/loop/loop.py index 896fc602c98e754b1adcfe6374996cb7914d4924..28fba5d2c7cb0043236e3bea2ea95190dc8a47f1 100644 --- a/apps/scheduler/call/loop/loop.py +++ b/apps/scheduler/call/loop/loop.py @@ -498,7 +498,7 @@ class Loop(CoreCall, input_model=LoopInput, output_model=LoopOutput): # 保存原始的循环节点状态,确保不被子步骤状态影响 original_loop_step_id = step_executor.task.state.step_id original_loop_step_name = step_executor.task.state.step_name - original_loop_status = step_executor.task.state.status + original_loop_status = step_executor.task.state.step_status # 筛选出需要执行的步骤(排除start和end) executable_steps = [ @@ -550,7 +550,7 @@ class Loop(CoreCall, input_model=LoopInput, output_model=LoopOutput): # 🔑 关键修改:确保循环节点的状态完全恢复,不受子步骤影响 step_executor.task.state.step_id = original_loop_step_id step_executor.task.state.step_name = original_loop_step_name - step_executor.task.state.status = original_loop_status + step_executor.task.state.step_status = original_loop_status logger.info(f"[Loop] 已完全恢复循环节点状态: {original_loop_step_name} (ID: {original_loop_step_id})") async def _init(self, call_vars: CallVars) -> LoopInput: diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index e70e4e25de3e6c8ada39c8c7fdc08d01311856ee..7d2fd39268f95591cb8697f2a8e33018ff3e5b70 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -293,6 +293,22 @@ class FlowExecutor(BaseExecutor): self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] else: self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + + # 重置Conversation变量池 + try: + from apps.scheduler.variable.integration import VariableIntegration + reset_success = await VariableIntegration.reset_conversation_variables_to_defaults( + conversation_id=self.task.ids.conversation_id, + flow_id=self.flow_id, + user_sub=self.task.ids.user_sub + ) + if reset_success: + logger.info(f"[FlowExecutor] Flow {self.flow_id} 执行完成后,成功重置对话变量池到默认值") + else: + logger.warning(f"[FlowExecutor] Flow {self.flow_id} 执行完成后,重置对话变量池失败") + except Exception as e: + # 重置失败不应该影响Flow的正常完成 + logger.error(f"[FlowExecutor] Flow {self.flow_id} 执行完成后重置对话变量池时发生异常: {e}") # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 62ee6a49856e0efaf6552c66943916c822213f23..5c6aa688758a2f72c75e908a4c946fea74b8c3e3 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -154,7 +154,7 @@ class StepExecutor(BaseExecutor): # 更新State self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -340,7 +340,7 @@ class StepExecutor(BaseExecutor): # 如果有失败的参数,将步骤状态设置为失败 if failed_params: from apps.schemas.enum_var import StepStatus - self.task.state.status = StepStatus.FAILED # type: ignore[assignment] + self.task.state.step_status = StepStatus.FAILED # type: ignore[assignment] failure_msg = f"输出参数类型验证失败:\n" + "\n".join(failed_params) logger.error(f"[StepExecutor] 步骤 {self.step.step_id} 执行失败: {failure_msg}") @@ -363,7 +363,7 @@ class StepExecutor(BaseExecutor): logger.error(f"[StepExecutor] 保存输出参数到变量池失败: {e}") # 对于其他意外错误,也将步骤设置为失败 from apps.schemas.enum_var import StepStatus - self.task.state.status = StepStatus.FAILED # type: ignore[assignment] + self.task.state.step_status = StepStatus.FAILED # type: ignore[assignment] raise def _extract_value_from_output_data(self, param_name: str, output_data: dict[str, Any], param_config: dict) -> Any: diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py index 4b761f3b28478b371780251d7dc84892e1fb827f..f64ea5885d65d9a688d1f3c1ebd978a56d4b5769 100644 --- a/apps/scheduler/variable/integration.py +++ b/apps/scheduler/variable/integration.py @@ -209,6 +209,33 @@ class VariableIntegration: # 如果解析失败,返回原始模板 return template + @staticmethod + async def reset_conversation_variables_to_defaults( + conversation_id: str, + flow_id: Optional[str] = None, + user_sub: Optional[str] = None + ) -> bool: + """将对话变量池中的所有conversation类型变量重置为Flow定义的默认值 + + Args: + conversation_id: 对话ID + flow_id: 流程ID(可选,用于兼容,实际从对话池中获取) + user_sub: 用户ID(可选,用于日志记录) + + Returns: + bool: 是否重置成功 + """ + try: + # 直接委托给VariablePoolManager执行重置逻辑 + from apps.scheduler.variable.pool_manager import get_pool_manager + pool_manager = await get_pool_manager() + + return await pool_manager.reset_conversation_variables_to_defaults(conversation_id) + + except Exception as e: + logger.error(f"[VariableIntegration] 重置对话变量池失败: conversation_id={conversation_id}, 错误: {e}") + return False + # 注意:原本的 monkey_patch_scheduler 和相关扩展类已被移除 # 因为 CoreCall 类现在已经内置了完整的变量解析功能 diff --git a/apps/scheduler/variable/pool_manager.py b/apps/scheduler/variable/pool_manager.py index c6b8348c4195c9e48e426e0849ff7f0180ae923c..78691a5591382d19d7bd41e9765ee51ec4193c71 100644 --- a/apps/scheduler/variable/pool_manager.py +++ b/apps/scheduler/variable/pool_manager.py @@ -369,6 +369,67 @@ class VariablePoolManager: logger.info(f"已清空工作流 {flow_id} 的 {len(to_remove)} 个对话变量池") + async def reset_conversation_variables_to_defaults(self, conversation_id: str) -> bool: + """将指定对话的变量重置为Flow定义的默认值 + + Args: + conversation_id: 对话ID + + Returns: + bool: 是否重置成功 + """ + try: + conversation_pool = await self.get_conversation_pool(conversation_id) + if not conversation_pool: + logger.warning(f"[VariablePoolManager] 未找到对话变量池: {conversation_id}") + return False + + flow_id = conversation_pool.flow_id + flow_pool = await self.get_flow_pool(flow_id) + if not flow_pool: + logger.warning(f"[VariablePoolManager] 未找到Flow变量池: {flow_id}") + return False + + # 获取所有对话变量模板 + conversation_templates = await flow_pool.list_conversation_templates() + if not conversation_templates: + logger.info(f"[VariablePoolManager] Flow {flow_id} 没有定义对话变量模板,无需重置") + return True + + reset_count = 0 + failed_count = 0 + + # 重置每个对话变量到其默认值 + for template in conversation_templates: + try: + existing_variable = await conversation_pool.get_variable(template.name) + if existing_variable: + # 跳过系统变量 + if hasattr(existing_variable.metadata, 'is_system') and existing_variable.metadata.is_system: + continue + + # 重置为模板的默认值 + await conversation_pool.update_variable( + name=template.name, + value=template.value, + force_system_update=False + ) + reset_count += 1 + logger.debug(f"[VariablePoolManager] 已重置对话变量: {template.name} = {template.value}") + + except Exception as e: + failed_count += 1 + logger.error(f"[VariablePoolManager] 重置对话变量 {template.name} 失败: {e}") + + if reset_count > 0: + logger.info(f"[VariablePoolManager] 对话 {conversation_id} 成功重置了 {reset_count} 个变量到默认值") + + return failed_count == 0 + + except Exception as e: + logger.error(f"[VariablePoolManager] 重置对话变量池失败: conversation_id={conversation_id}, 错误: {e}") + return False + async def get_pool_stats(self) -> Dict[str, int]: """获取变量池统计信息""" return { diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py index 20bdfc7c9f1cc7357365baef9c41fa4971c51957..9a5daaef63c6d71b05ab07bf5eea138463f708f9 100644 --- a/apps/schemas/collection.py +++ b/apps/schemas/collection.py @@ -8,6 +8,7 @@ from pydantic import BaseModel, Field from apps.common.config import Config from apps.constants import NEW_CHAT +from apps.schemas.preferences import UserPreferences from apps.templates.generate_llm_operator_config import llm_provider_dict @@ -62,6 +63,10 @@ class User(BaseModel): fav_services: list[str] = [] is_admin: bool = Field(default=False, description="是否为管理员") auto_execute: bool = Field(default=True, description="是否自动执行任务") + preferences: UserPreferences = Field( + default_factory=UserPreferences, + description="用户偏好设置" + ) class LLM(BaseModel): diff --git a/apps/schemas/config.py b/apps/schemas/config.py index eb9a4593aca942bdafa1d3b9f6de7a75dcd96da6..d3dbf85927cfd49ad434b7c89bc1f0bb65d9af2a 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -50,6 +50,17 @@ class EmbeddingConfig(BaseModel): endpoint: str = Field(description="Embedding模型地址") api_key: str = Field(description="Embedding模型API Key") model: str = Field(description="Embedding模型名称") + icon: str = Field(description="Embedding模型图标") + + +class RerankerConfig(BaseModel): + """Reranker配置""" + + type: str = Field(description="Reranker接口类型", default="openai") + endpoint: str = Field(description="Reranker模型地址") + api_key: str = Field(description="Reranker模型API Key") + model: str = Field(description="Reranker模型名称") + icon: str = Field(description="Reranker模型图标") class RAGConfig(BaseModel): @@ -152,6 +163,7 @@ class ConfigModel(BaseModel): deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig + reranker: RerankerConfig rag: RAGConfig fastapi: FastAPIConfig minio: MinioConfig diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py index dfffd1f14c779952d9bb761d2e2bccd85528bc0f..d9d1ca7399916faab4255f9664fc7333a6692fea 100644 --- a/apps/schemas/flow.py +++ b/apps/schemas/flow.py @@ -15,6 +15,16 @@ from apps.schemas.enum_var import ( from apps.schemas.flow_topology import PositionItem +class Note(BaseModel): + """Flow中Note的数据""" + + note_id: str = Field(description="备注的ID") + text: str = Field(description="备注内容") + position: PositionItem = Field(description="备注在画布上的位置", default=PositionItem(x=0, y=0)) + width: float = Field(description="备注的宽度", default=200.0) + height: float = Field(description="备注的高度", default=100.0) + + class Edge(BaseModel): """Flow中Edge的数据""" @@ -53,6 +63,7 @@ class Flow(BaseModel): on_error: FlowError = FlowError(use_llm=True) steps: dict[str, Step] = Field(description="节点列表", default={}) edges: list[Edge] = Field(description="边列表", default=[]) + notes: list[Note] = Field(description="备注列表", default=[]) class Permission(BaseModel): diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index b7e1175eff4bcc0f728b3b7a36e33e55705570b9..2ed64a54473058e369d8a783d7979c1260b77544 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -72,6 +72,16 @@ class EdgeItem(BaseModel): branch_id: str = Field(alias="branchId") +class NoteItem(BaseModel): + """请求/响应中的备注变量类""" + + note_id: str = Field(alias="noteId") + text: str + position: PositionItem = Field(default=PositionItem()) + width: float = Field(default=200.0) + height: float = Field(default=100.0) + + class FlowItem(BaseModel): """请求/响应中的流变量类""" @@ -82,6 +92,7 @@ class FlowItem(BaseModel): editable: bool = Field(default=True) nodes: list[NodeItem] = Field(default=[]) edges: list[EdgeItem] = Field(default=[]) + notes: list[NoteItem] = Field(default=[]) created_at: float | None = Field(alias="createdAt", default=0) connectivity: bool = Field(default=False, description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem()) diff --git a/apps/schemas/preferences.py b/apps/schemas/preferences.py new file mode 100644 index 0000000000000000000000000000000000000000..948194f7ae18806387b4e1f85f883250b5ecd200 --- /dev/null +++ b/apps/schemas/preferences.py @@ -0,0 +1,73 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""用户偏好设置相关数据结构""" + +from pydantic import BaseModel, Field + + +class ReasoningModelPreference(BaseModel): + """推理模型偏好设置""" + + model_config = {"populate_by_name": True} + + llm_id: str = Field(alias="llmId", description="模型ID") + icon: str = Field(default="", description="模型图标") + openai_base_url: str = Field(alias="openaiBaseUrl", description="OpenAI API基础URL") + openai_api_key: str = Field(alias="openaiApiKey", description="OpenAI API密钥") + model_name: str = Field(alias="modelName", description="模型名称") + max_tokens: int = Field(alias="maxTokens", description="最大token数") + is_editable: bool | None = Field(alias="isEditable", default=None, description="是否可编辑") + + +class EmbeddingModelPreference(BaseModel): + """嵌入模型偏好设置""" + + model_config = {"populate_by_name": True} + + llm_id: str = Field(alias="llmId", description="模型ID") + model_name: str = Field(alias="modelName", description="模型名称") + icon: str = Field(default="", description="模型图标") + type: str = Field(description="模型类型") + endpoint: str = Field(description="API端点") + api_key: str = Field(alias="apiKey", description="API密钥") + + +class RerankerModelPreference(BaseModel): + """重排序模型偏好设置""" + + model_config = {"populate_by_name": True} + + llm_id: str = Field(alias="llmId", description="模型ID") + model_name: str = Field(alias="modelName", description="模型名称") + icon: str = Field(default="", description="模型图标") + type: str = Field(description="模型类型") + # 以下字段为可选,因为algorithm类型的reranker可能没有这些字段 + endpoint: str | None = Field(default=None, description="API端点") + api_key: str | None = Field(alias="apiKey", default=None, description="API密钥") + name: str | None = Field(default=None, description="算法名称") + + +class UserPreferences(BaseModel): + """用户偏好设置""" + + model_config = {"populate_by_name": True} + + reasoning_model_preference: ReasoningModelPreference | None = Field( + default=None, + description="推理模型偏好", + alias="reasoningModelPreference" + ) + embedding_model_preference: EmbeddingModelPreference | None = Field( + default=None, + description="嵌入模型偏好", + alias="embeddingModelPreference" + ) + reranker_preference: RerankerModelPreference | None = Field( + default=None, + description="重排序模型偏好", + alias="rerankerPreference" + ) + chain_of_thought_preference: bool = Field( + default=True, + description="思维链偏好", + alias="chainOfThoughtPreference" + ) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 2cf85ecc9650da4ce880e1948facf35ff740b3dd..c0e51120297c00fe5f5214592df41a28e04cd6fa 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -11,6 +11,7 @@ from apps.schemas.enum_var import CommentType, LanguageType from apps.schemas.flow_topology import FlowItem from apps.schemas.mcp import MCPType from apps.schemas.message import FlowParams +from apps.schemas.preferences import ReasoningModelPreference, EmbeddingModelPreference, RerankerModelPreference class RequestDataApp(BaseModel): @@ -191,4 +192,29 @@ class UpdateKbReq(BaseModel): class UserUpdateRequest(BaseModel): """更新用户信息请求体""" - auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") \ No newline at end of file + auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") + + +class UserPreferencesRequest(BaseModel): + """更新用户偏好设置请求体""" + + reasoning_model_preference: ReasoningModelPreference | None = Field( + default=None, + description="推理模型偏好", + alias="reasoningModelPreference" + ) + embedding_model_preference: EmbeddingModelPreference | None = Field( + default=None, + description="嵌入模型偏好", + alias="embeddingModelPreference" + ) + reranker_preference: RerankerModelPreference | None = Field( + default=None, + description="重排序模型偏好", + alias="rerankerPreference" + ) + chain_of_thought_preference: bool | None = Field( + default=None, + description="思维链偏好", + alias="chainOfThoughtPreference" + ) \ No newline at end of file diff --git a/apps/schemas/user.py b/apps/schemas/user.py index ebea66924c062f4e680ad2e414768b307b056f4a..23966539e6c2f33c78194451780047cdd25ce494 100644 --- a/apps/schemas/user.py +++ b/apps/schemas/user.py @@ -3,6 +3,8 @@ from pydantic import BaseModel, Field +from apps.schemas.preferences import UserPreferences + class UserInfo(BaseModel): """用户信息数据结构""" @@ -10,3 +12,4 @@ class UserInfo(BaseModel): user_sub: str = Field(alias="userSub", default="") user_name: str = Field(alias="userName", default="") auto_execute: bool | None = Field(alias="autoExecute", default=False) + preferences: UserPreferences | None = Field(alias="preferences", default=None) diff --git a/apps/services/flow.py b/apps/services/flow.py index 8c79bc5516330c0946137c9f2634509a35d3e802..8f2928f16983ad7f0e94ef25ed92f9962e34112d 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -12,13 +12,14 @@ from apps.scheduler.pool.loader.flow import FlowLoader from apps.scheduler.slot.slot import Slot from apps.schemas.collection import User from apps.schemas.enum_var import EdgeType, PermissionType, LanguageType -from apps.schemas.flow import Edge, Flow, Step +from apps.schemas.flow import Edge, Flow, Note, Step from apps.schemas.flow_topology import ( EdgeItem, FlowItem, NodeItem, NodeMetaDataItem, NodeServiceItem, + NoteItem, PositionItem, ) from apps.scheduler.pool.pool import Pool @@ -108,7 +109,7 @@ class FlowManager: if service_id == "": call_class: type[BaseModel] = await Pool().get_call(node_pool_record["_id"]) node_name = call_class.info(language).name - node_description = call_class.info().description + node_description = call_class.info(language).description else: node_name = node_pool_record["name"] node_description = node_pool_record["description"] @@ -285,6 +286,7 @@ class FlowManager: editable=True, nodes=[], edges=[], + notes=[], focusPoint=focus_point, connectivity=flow_config.connectivity, debug=flow_config.debug, @@ -354,6 +356,18 @@ class FlowManager: branchId=branch_id, ), ) + + # 处理notes + for note_config in flow_config.notes: + flow_item.notes.append( + NoteItem( + noteId=note_config.note_id, + text=note_config.text, + position=note_config.position, + width=note_config.width, + height=note_config.height, + ), + ) return flow_item except Exception: logger.exception("[FlowManager] 获取流失败") @@ -435,6 +449,7 @@ class FlowManager: description=flow_item.description, steps={}, edges=[], + notes=[], focus_point=flow_item.focus_point, connectivity=flow_item.connectivity, debug=flow_item.debug, @@ -505,7 +520,23 @@ class FlowManager: logger.error(f"[FlowManager] 创建边失败: {edge_item.edge_id}, 错误: {e}") continue - logger.info(f"[FlowManager] 构建完成,flow_config.edges数量: {len(flow_config.edges)}") + # 处理notes + for note_item in flow_item.notes: + try: + note_config = Note( + note_id=note_item.note_id, + text=note_item.text, + position=note_item.position, + width=note_item.width, + height=note_item.height, + ) + flow_config.notes.append(note_config) + logger.info(f"[FlowManager] 添加备注: {note_item.note_id}") + except Exception as e: + logger.error(f"[FlowManager] 创建备注失败: {note_item.note_id}, 错误: {e}") + continue + + logger.info(f"[FlowManager] 构建完成,flow_config.edges数量: {len(flow_config.edges)}, flow_config.notes数量: {len(flow_config.notes)}") if old_flow_config is None: error_msg = f"[FlowManager] 流 {flow_id} 不存在;可能为新创建" @@ -828,6 +859,7 @@ class FlowManager: description=flow_item.description, steps={}, edges=[], + notes=[], focus_point=flow_item.focus_point, connectivity=flow_item.connectivity, debug=flow_item.debug, @@ -856,6 +888,22 @@ class FlowManager: edge_type=EdgeType.NORMAL, # 子工作流默认使用普通边 ) flow_config.edges.append(edge_config) + + # 处理notes + for note_item in flow_item.notes: + try: + note_config = Note( + note_id=note_item.note_id, + text=note_item.text, + position=note_item.position, + width=note_item.width, + height=note_item.height, + ) + flow_config.notes.append(note_config) + logger.info(f"[FlowManager] 子工作流添加备注: {note_item.note_id}") + except Exception as e: + logger.error(f"[FlowManager] 子工作流创建备注失败: {note_item.note_id}, 错误: {e}") + continue # 使用子工作流专用的保存路径 flow_loader = FlowLoader() @@ -898,6 +946,7 @@ class FlowManager: editable=True, nodes=[], edges=[], + notes=[], focusPoint=focus_point, connectivity=flow_config.connectivity, debug=flow_config.debug, @@ -963,6 +1012,18 @@ class FlowManager: branchId=branch_id, ) flow_item.edges.append(edge_item) + + # 处理notes + for note_config in flow_config.notes: + flow_item.notes.append( + NoteItem( + noteId=note_config.note_id, + text=note_config.text, + position=note_config.position, + width=note_config.width, + height=note_config.height, + ), + ) return flow_item diff --git a/apps/services/user.py b/apps/services/user.py index 1b96df18143f5a87b102f68e2bb4e2df62753f05..aaa220da1463244e501f7845ce78066aab96f9f6 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -4,9 +4,10 @@ import logging from datetime import UTC, datetime -from apps.schemas.request_data import UserUpdateRequest +from apps.schemas.request_data import UserUpdateRequest, UserPreferencesRequest from apps.common.mongo import MongoDB from apps.schemas.collection import User +from apps.schemas.preferences import UserPreferences from apps.services.conversation import ConversationManager logger = logging.getLogger(__name__) @@ -129,3 +130,46 @@ class UserManager: for conv_id in result.conversations: await ConversationManager.delete_conversation_by_conversation_id(user_sub, conv_id) + + @staticmethod + async def update_user_preferences_by_user_sub(user_sub: str, data: UserPreferencesRequest) -> None: + """ + 根据用户sub更新用户偏好设置 + + :param user_sub: 用户sub + :param data: 用户偏好设置更新信息 + """ + mongo = MongoDB() + user_collection = mongo.get_collection("user") + + # 构建更新字典,只更新非None的字段 + preferences_update = {} + if data.reasoning_model_preference is not None: + preferences_update["preferences.reasoning_model_preference"] = data.reasoning_model_preference.model_dump() + if data.embedding_model_preference is not None: + preferences_update["preferences.embedding_model_preference"] = data.embedding_model_preference.model_dump() + if data.reranker_preference is not None: + preferences_update["preferences.reranker_preference"] = data.reranker_preference.model_dump() + if data.chain_of_thought_preference is not None: + preferences_update["preferences.chain_of_thought_preference"] = data.chain_of_thought_preference + + if preferences_update: + update_dict = {"$set": preferences_update} + await user_collection.update_one({"_id": user_sub}, update_dict) + + @staticmethod + async def get_user_preferences_by_user_sub(user_sub: str) -> UserPreferences: + """ + 根据用户sub获取用户偏好设置 + + :param user_sub: 用户sub + :return: 用户偏好设置 + """ + mongo = MongoDB() + user_collection = mongo.get_collection("user") + user_data = await user_collection.find_one({"_id": user_sub}, {"preferences": 1}) + if user_data and "preferences" in user_data: + # 使用model_validate来处理从数据库读取的数据,这样会正确处理别名映射 + return UserPreferences.model_validate(user_data["preferences"]) + else: + return UserPreferences() diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py index 5375180a3f453a9d9d18d17b95b0ed9c39d8c5c4..e1517277d09c2f3f560a560d6e55c4822af76130 100644 --- a/tests/common/test_queue.py +++ b/tests/common/test_queue.py @@ -74,7 +74,7 @@ async def test_push_output_with_flow(message_queue, mock_task): mock_task.state.flow_id = "flow_id" mock_task.state.step_id = "step_id" mock_task.state.step_name = "step_name" - mock_task.state.status = "running" + mock_task.state.step_status = "running" await message_queue.init("test_task") await message_queue.push_output(mock_task, EventType.TEXT_ADD, {})