diff --git a/apps/llm/schema.py b/apps/llm/enum.py similarity index 100% rename from apps/llm/schema.py rename to apps/llm/enum.py diff --git a/apps/llm/function.py b/apps/llm/function.py index 759133cbbc3f11f92a8f88a17b209997f5394f9d..02750d06244977bdf1a3ab708883d1a510033fdb 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -16,7 +16,7 @@ from apps.common.config import Config from apps.constants import JSON_GEN_MAX_TRIAL, REASONING_END_TOKEN from apps.llm.prompt import JSON_GEN_BASIC from apps.llm.adapters import AdapterFactory, get_provider_from_endpoint - +from apps.schemas.config import FunctionCallConfig # 导入异常处理相关模块 import openai import httpx @@ -30,7 +30,7 @@ logger = logging.getLogger(__name__) def infer_backend_from_capabilities(provider: str, model_name: str, explicit_backend: str | None = None) -> str: """ 根据模型能力推断最佳的backend - + :param provider: 模型提供商 :param model_name: 模型名称 :param explicit_backend: 明确指定的backend(如果有),优先使用 @@ -40,34 +40,40 @@ def infer_backend_from_capabilities(provider: str, model_name: str, explicit_bac if explicit_backend: logger.info(f"[FunctionCall] 使用明确指定的backend: {explicit_backend}") return explicit_backend - + # 从模型注册表获取模型能力 - capabilities = model_registry.get_model_capabilities(provider, model_name, ModelType.CHAT) - + capabilities = model_registry.get_model_capabilities( + provider, model_name, ModelType.CHAT) + if not capabilities or not isinstance(capabilities, ChatCapabilities): - logger.warning(f"[FunctionCall] 无法获取模型 {provider}:{model_name} 的能力信息,使用默认backend: json_mode") + logger.warning( + f"[FunctionCall] 无法获取模型 {provider}:{model_name} 的能力信息,使用默认backend: json_mode") return "json_mode" - + # 根据能力优先级推断backend # 优先级: structured_output > function_call > json_mode if capabilities.supports_structured_output: - logger.info(f"[FunctionCall] 模型 {provider}:{model_name} 支持structured_output,自动选择") + logger.info( + f"[FunctionCall] 模型 {provider}:{model_name} 支持structured_output,自动选择") return "structured_output" elif capabilities.supports_function_calling: - logger.info(f"[FunctionCall] 模型 {provider}:{model_name} 支持function_calling,自动选择") + logger.info( + f"[FunctionCall] 模型 {provider}:{model_name} 支持function_calling,自动选择") return "function_call" elif capabilities.supports_json_mode: - logger.info(f"[FunctionCall] 模型 {provider}:{model_name} 支持json_mode,自动选择") + logger.info( + f"[FunctionCall] 模型 {provider}:{model_name} 支持json_mode,自动选择") return "json_mode" else: - logger.warning(f"[FunctionCall] 模型 {provider}:{model_name} 不支持任何JSON生成能力,回退到json_mode") + logger.warning( + f"[FunctionCall] 模型 {provider}:{model_name} 不支持任何JSON生成能力,回退到json_mode") return "json_mode" class FunctionLLM: """用于FunctionCall的模型""" - def __init__(self, llm_config=None) -> None: + def __init__(self, llm_config: FunctionCallConfig = None) -> None: """ 初始化用于FunctionCall的模型 @@ -77,32 +83,18 @@ class FunctionLLM: - function_call - json_mode - structured_output - + backend会根据模型能力自动推断,也可以通过配置明确指定 - + :param llm_config: 可选的LLM配置,如果不提供则使用配置文件中的function_call配置 """ # 使用传入的配置或从配置文件获取 if llm_config: self._config = llm_config - # 如果没有backend字段,根据模型特性推断 - if not hasattr(self._config, 'backend'): - # 创建一个包含backend的配置对象 - class ConfigWithBackend: - def __init__(self, base_config): - self.model = base_config.model - self.endpoint = base_config.endpoint - self.api_key = getattr(base_config, 'key', getattr(base_config, 'api_key', '')) - self.max_tokens = getattr(base_config, 'max_tokens', 8192) - self.temperature = getattr(base_config, 'temperature', 0.7) - # backend将在后面通过推断设置 - self.backend = 
None - - self._config = ConfigWithBackend(llm_config) else: # 暂存config;这里可以替代为从其他位置获取 self._config = Config().get_config().function_call - + if not self._config.model: err_msg = "[FunctionCall] 未设置FuntionCall所用模型!" logger.error(err_msg) @@ -114,20 +106,21 @@ class FunctionLLM: self._provider = self._config.provider else: self._provider = get_provider_from_endpoint(self._config.endpoint) - - self._adapter = AdapterFactory.create_adapter(self._provider, self._config.model) - + + self._adapter = AdapterFactory.create_adapter( + self._provider, self._config.model) + # 智能推断backend:如果配置中backend为None或空字符串,则根据模型能力推断 explicit_backend = getattr(self._config, 'backend', None) if not explicit_backend or explicit_backend == 'null': explicit_backend = None - + self._backend = infer_backend_from_capabilities( - self._provider, + self._provider, self._config.model, explicit_backend ) - + self._params = { "model": self._config.model, "messages": [], @@ -218,8 +211,9 @@ class FunctionLLM: ] # 使用适配器调整参数,对于JSON生成任务禁用thinking以避免解析问题 - adapted_params = self._adapter.adapt_create_params(self._params, enable_thinking=False) - + adapted_params = self._adapter.adapt_create_params( + self._params, enable_thinking=False) + try: # type: ignore[arg-type] response = await self._client.chat.completions.create(**adapted_params) @@ -481,14 +475,14 @@ class JsonGenerator: return tmp_js except Exception as e: logger.error("[JsonGenerator] 解析结果失败: %s", e) - return None + return {} - def __init__(self, query: str, conversation: list[dict[str, str]], schema: dict[str, Any]) -> None: + def __init__(self, query: str, conversation: list[dict[str, str]], schema: dict[str, Any], func_call_llm: FunctionLLM = FunctionLLM()) -> None: """初始化JSON生成器""" self._query = query self._conversation = conversation self._schema = schema - + self.func_call_llm = func_call_llm self._trial = {} self._count = 0 self._env = SandboxedEnvironment( @@ -524,8 +518,7 @@ class JsonGenerator: {"role": "system", "content": prompt}, {"role": "user", "content": "please generate a JSON response based on the above information and schema./no_think"}, ] - function = FunctionLLM() - return await function.call(messages, self._schema, max_tokens, temperature) + return await self.func_call_llm.call(messages, self._schema, max_tokens, temperature) async def generate(self) -> dict[str, Any]: """生成JSON""" diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py index eab73b6581c10648b40d286f0e37d9a5bbc6008c..c518b02e3ad617969541392ce988b15e624ff3a3 100644 --- a/apps/llm/patterns/rewrite.py +++ b/apps/llm/patterns/rewrite.py @@ -9,11 +9,17 @@ from textwrap import dedent from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment +from apps.llm.enum import DefaultModelId from apps.llm.function import JsonGenerator from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM +from apps.llm.function import FunctionLLM from apps.llm.token import TokenCalculator from apps.schemas.enum_var import LanguageType +from apps.services.llm import LLMManager +from apps.llm.adapters import get_provider_from_endpoint +from apps.schemas.config import LLMConfig +from apps.schemas.config import FunctionCallConfig logger = logging.getLogger(__name__) @@ -38,8 +44,7 @@ class QuestionRewrite(CorePattern): self, system_prompt: dict[LanguageType, str] | None = None, user_prompt: dict[LanguageType, str] | None = None, - llm_id: str | None = None, - enable_thinking: bool = False, + llm_id: str = 
DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value, ) -> None: """初始化问题改写模式 @@ -50,7 +55,6 @@ class QuestionRewrite(CorePattern): """ super().__init__(system_prompt, user_prompt) self.llm_id = llm_id - self.enable_thinking = enable_thinking def get_default_prompt(self) -> dict[LanguageType, str]: system_prompt = { @@ -188,31 +192,21 @@ class QuestionRewrite(CorePattern): language = kwargs.get("language", LanguageType.CHINESE) # 根据llm_id获取模型配置并创建LLM实例 - llm = None - if self.llm_id: - from apps.services.llm import LLMManager - from apps.llm.adapters import get_provider_from_endpoint - from apps.schemas.config import LLMConfig + llm_info = await LLMManager.get_llm_by_id(self.llm_id) + provider = llm_info.provider or get_provider_from_endpoint( + llm_info.openai_base_url) - llm_info = await LLMManager.get_llm_by_id(self.llm_id) - if llm_info: - provider = llm_info.provider or get_provider_from_endpoint( - llm_info.openai_base_url) - - llm_config = LLMConfig( - provider=provider, - endpoint=llm_info.openai_base_url, - api_key=llm_info.openai_api_key, - model=llm_info.model_name, - max_tokens=llm_info.max_tokens, - temperature=0.7, - ) - llm = ReasoningLLM(llm_config) - - if not llm: - llm = ReasoningLLM() + llm_config = LLMConfig( + provider=provider, + endpoint=llm_info.openai_base_url, + api_key=llm_info.openai_api_key, + model=llm_info.model_name, + max_tokens=llm_info.max_tokens, + temperature=0.7, + ) + llm = ReasoningLLM(llm_config) - leave_tokens = llm._config.max_tokens + leave_tokens = llm_info.max_tokens leave_tokens -= TokenCalculator().calculate_token_length( messages=[{"role": "system", "content": _env.from_string(self.system_prompt[language]).render( history="", question=question)}, @@ -247,10 +241,21 @@ class QuestionRewrite(CorePattern): if tmp_js is not None: return tmp_js['question'] messages += [{"role": "assistant", "content": result}] + # 使用Function LLM进行JSON生成 + func_call_llm_config = FunctionCallConfig( + provider=llm_info.provider, + endpoint=llm_info.openai_base_url, + api_key=llm_info.openai_api_key, + model=llm_info.model_name, + max_tokens=llm_info.max_tokens, + temperature=0.7, + ) + func_call_llm = FunctionLLM(func_call_llm_config) json_gen = JsonGenerator( query="根据给定的背景信息,生成预测问题", conversation=messages, schema=QuestionRewriteResult.model_json_schema(), + func_call_llm=func_call_llm ) try: question_dict = QuestionRewriteResult.model_validate(await json_gen.generate()) diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 73b6f13c71c279c27c5fbd876016a032abf010fe..84ab5551df1adcb1ab54648418ba6efb7ff61e6b 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse from apps.dependency.user import get_user, verify_user from apps.exceptions import InstancePermissionError -from apps.llm.schema import DefaultModelId +from apps.llm.enum import DefaultModelId from apps.schemas.appcenter import AppFlowInfo, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType from apps.schemas.request_data import CreateAppRequest, ModFavAppRequest diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 12bf9dff0aa0740fc5a2760f23bebb9a80b90699..93c5d0eca26295f0248656adfe4842e52877fb61 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -54,8 +54,10 @@ class CoreCall(BaseModel): """所有Call的父类,包含通用的逻辑""" name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True) - description: SkipJsonSchema[str] = Field(description="Step的描述", 
exclude=True) - node: SkipJsonSchema[NodePool | None] = Field(description="节点信息", exclude=True) + description: SkipJsonSchema[str] = Field( + description="Step的描述", exclude=True) + node: SkipJsonSchema[NodePool | None] = Field( + description="节点信息", exclude=True) enable_filling: SkipJsonSchema[bool] = Field( description="是否需要进行自动参数填充", default=False, exclude=True ) @@ -75,7 +77,8 @@ class CoreCall(BaseModel): frozen=True, ) to_user: bool = Field(description="是否需要将输出返回给用户", default=False) - enable_variable_resolution: bool = Field(description="是否启用自动变量解析", default=True) + enable_variable_resolution: bool = Field( + description="是否启用自动变量解析", default=True) controlled_output: bool = Field(description="是否允许用户定义输出参数", default=False) i18n_info: ClassVar[SkipJsonSchema[dict[str, dict]]] = {} @@ -92,7 +95,8 @@ class CoreCall(BaseModel): :return: Call的名称和描述 :rtype: CallInfo """ - lang_info = cls.i18n_info.get(language, cls.i18n_info[LanguageType.CHINESE]) + lang_info = cls.i18n_info.get( + language, cls.i18n_info[LanguageType.CHINESE]) return CallInfo(name=lang_info["name"], type=lang_info["type"], description=lang_info["description"]) def __init_subclass__( @@ -183,17 +187,17 @@ class CoreCall(BaseModel): "app_id": call_vars.ids.app_id, "conversation_id": call_vars.ids.conversation_id, } - + await VariableIntegration.initialize_system_variables(context) return context async def _resolve_variables_in_config(self, config: Any, call_vars: CallVars) -> Any: """解析配置中的变量引用 - + Args: config: 配置值,可能包含变量引用 call_vars: Call变量 - + Returns: 解析后的配置值 """ @@ -232,25 +236,25 @@ class CoreCall(BaseModel): async def _resolve_variables_in_text(self, text: str, call_vars: CallVars) -> str: """解析文本中的变量引用({{...}} 语法) - + Args: text: 包含变量引用的文本 call_vars: Call变量 - + Returns: 解析后的文本 """ if not isinstance(text, str): return text - + # 检查是否包含变量引用语法 if not re.search(r'\{\{.*?\}\}', text): return text - + # 提取所有变量引用并逐一解析替换 variable_pattern = r'\{\{(.*?)\}\}' matches = re.findall(variable_pattern, text) - + resolved_text = text for match in matches: try: @@ -263,21 +267,22 @@ class CoreCall(BaseModel): current_step_id=getattr(self, '_step_id', None) ) # 替换原始文本中的变量引用 - resolved_text = resolved_text.replace(f'{{{{{match}}}}}', str(resolved_value)) + resolved_text = resolved_text.replace( + f'{{{{{match}}}}}', str(resolved_value)) except Exception as e: logger.warning(f"[CoreCall] 解析变量引用 '{match}' 失败: {e}") # 如果解析失败,保留原始的变量引用 continue - + return resolved_text async def _resolve_single_value(self, value, call_vars: CallVars): """解析单个变量引用 - + Args: value: Value对象,包含type和value字段 call_vars: Call变量上下文 - + Returns: Value: 解析后的Value对象,如果是引用类型则解析为具体值和类型 """ @@ -285,11 +290,11 @@ class CoreCall(BaseModel): from apps.schemas.parameters import ValueType from apps.scheduler.variable.type import VariableType as VarType from apps.scheduler.call.choice.schema import Value - + # 如果不是引用类型,直接返回 if value.type != ValueType.REFERENCE: return value - + try: # 解析变量引用 resolved_value, resolved_type = await VariableIntegration.resolve_variable_reference( @@ -299,7 +304,7 @@ class CoreCall(BaseModel): conversation_id=call_vars.ids.conversation_id, current_step_id=getattr(self, '_step_id', None) ) - + # 🔑 关键修复:将VariableType转换为ValueType # VariableType到ValueType的映射 type_mapping = { @@ -315,29 +320,29 @@ class CoreCall(BaseModel): VarType.ARRAY_BOOLEAN: ValueType.LIST, VarType.ARRAY_SECRET: ValueType.LIST, } - + # 转换类型 if resolved_type in type_mapping: converted_type = type_mapping[resolved_type] else: # 如果没有映射,默认为STRING converted_type = ValueType.STRING - + 
except Exception as e: logger.warning(f"[CoreCall] 解析变量引用 '{value.value}' 失败: {e}") return value - + return Value(value=resolved_value, type=converted_type) async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) self._step_id = executor.step.step_id # 存储 step_id 用于变量名构造 - + # 如果启用了变量解析,初始化变量上下文 if self.enable_variable_resolution: await self._initialize_variable_context(self._sys_vars) - + input_data = await self._init(self._sys_vars) self.input = input_data.model_dump(by_alias=True, exclude_none=True) @@ -353,7 +358,6 @@ class CoreCall(BaseModel): async def _after_exec(self, input_data: dict[str, Any]) -> None: """Call类实例的执行后方法""" - async def exec( self, executor: "StepExecutor", @@ -362,13 +366,13 @@ class CoreCall(BaseModel): ) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" self._last_output_data = {} # 初始化输出数据存储 - + async for chunk in self._exec(input_data): # 捕获最后的输出数据 if chunk.type == CallOutputType.DATA and isinstance(chunk.content, dict): self._last_output_data = chunk.content yield chunk - + await self._after_exec(input_data) async def _llm(self, messages: list[dict[str, Any]]) -> str: diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 14aa58f10aa28a282d8f7e5d0f2072d70403cc8a..4dbdf88d8935d79c9613d900111e7078eb01efda 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -11,7 +11,10 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from pydantic import Field +from apps.services.llm import LLMManager +from apps.llm.adapters import get_provider_from_endpoint from apps.llm.reasoning import ReasoningLLM +from apps.llm.enum import DefaultModelId from apps.scheduler.call.core import CoreCall from apps.scheduler.call.llm.prompt import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT from apps.scheduler.call.llm.schema import LLMInput, LLMOutput @@ -22,6 +25,7 @@ from apps.schemas.scheduler import ( CallOutputChunk, CallVars, ) +from apps.schemas.config import LLMConfig logger = logging.getLogger(__name__) @@ -38,7 +42,8 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): }) # 模型配置 - llmId: str = Field(description="大模型ID", default="") + llm_id: str = Field(description="大模型ID", + default=DefaultModelId.DEFAULT_CHAT_MODEL_ID.value) # 大模型基础参数 temperature: float = Field(description="大模型温度(随机化程度)", default=0.7) @@ -146,27 +151,23 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): full_reply = "" # 用于累积完整回复 try: # 根据llmId获取模型配置 - llm_config = None - if self.llmId: - from apps.services.llm import LLMManager - from apps.llm.adapters import get_provider_from_endpoint - - llm_info = await LLMManager.get_llm_by_id(self.llmId) - if llm_info: - from apps.schemas.config import LLMConfig - - # 获取provider,如果没有则从endpoint推断 - provider = llm_info.provider or get_provider_from_endpoint( - llm_info.openai_base_url) - - llm_config = LLMConfig( - provider=provider, - endpoint=llm_info.openai_base_url, - api_key=llm_info.openai_api_key, - model=llm_info.model_name, - max_tokens=llm_info.max_tokens, - temperature=self.temperature if self.enable_temperature else 0.7, - ) + llm_info = await LLMManager.get_llm_by_id(self.llm_id) + if llm_info: + + # 获取provider,如果没有则从endpoint推断 + provider = llm_info.provider or get_provider_from_endpoint( + llm_info.openai_base_url) + + llm_config = LLMConfig( + provider=provider, + endpoint=llm_info.openai_base_url, + api_key=llm_info.openai_api_key, + model=llm_info.model_name, 
+                    max_tokens=llm_info.max_tokens,
+                    temperature=self.temperature if self.enable_temperature else 0.7,
+                )
+            else:
+                llm_config = None

             # 初始化LLM客户端(会自动加载适配器)
             llm = ReasoningLLM(llm_config) if llm_config else ReasoningLLM()
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index 71ed9afeae9e211fd371d092645d15954578be6a..f99643089e5fa6541143c0ce30b139079d1c92c9 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -2,6 +2,7 @@
 """自动参数填充工具"""

 import json
+from jsonschema import validate
 from collections.abc import AsyncGenerator
 from typing import TYPE_CHECKING, Any, Self, ClassVar

@@ -9,8 +10,10 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from pydantic import Field

+from apps.llm.enum import DefaultModelId
 from apps.llm.function import FunctionLLM, JsonGenerator
 from apps.llm.reasoning import ReasoningLLM
 from apps.scheduler.call.core import CoreCall
 from apps.scheduler.call.slot.prompt import SLOT_GEN_PROMPT
 from apps.scheduler.call.slot.schema import SlotInput, SlotOutput
@@ -18,6 +21,8 @@ from apps.scheduler.slot.slot import Slot as SlotProcessor
 from apps.schemas.enum_var import CallOutputType, CallType, LanguageType
 from apps.schemas.pool import NodePool
 from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
+from apps.schemas.config import LLMConfig, FunctionCallConfig
+from apps.services.llm import LLMManager

 if TYPE_CHECKING:
     from apps.scheduler.executor.step import StepExecutor
@@ -25,7 +30,10 @@ if TYPE_CHECKING:

 class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
     """参数填充工具"""
-
+    chat_llm_id: str = Field(
+        description="对话大模型ID", default=DefaultModelId.DEFAULT_CHAT_MODEL_ID.value)
+    func_llm_id: str = Field(description="Function Call大模型ID",
+                             default=DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value)
     data: dict[str, Any] = Field(description="当前输入", default={})
     current_schema: dict[str, Any] = Field(description="当前Schema", default={})
     summary: str = Field(description="背景信息总结", default="")
@@ -75,16 +83,32 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
         ]

         # 使用大模型进行尝试
-        reasoning = ReasoningLLM()
+        chat_llm = await LLMManager.get_llm_by_id(self.chat_llm_id)
+        chat_llm_config = LLMConfig(
+            provider=chat_llm.provider,
+            api_key=chat_llm.openai_api_key,
+            endpoint=chat_llm.openai_base_url,
+            model=chat_llm.model_name,
+            max_tokens=chat_llm.max_tokens
+        )
+        reasoning = ReasoningLLM(llm_config=chat_llm_config)
         answer = ""
         async for chunk in reasoning.call(messages=conversation, streaming=False):
             answer += chunk
         self.tokens.input_tokens += reasoning.input_tokens
         self.tokens.output_tokens += reasoning.output_tokens
-
-        answer = await FunctionLLM.process_response(answer)
+        data = None
         try:
             data = json.loads(answer)
+            validate(instance=data, schema=remaining_schema)
+        except Exception:
+            # 解析或校验失败时丢弃结果,转入下方的栈式解析,避免返回未通过校验的数据
+            data = None
+        if data is not None:
+            return answer, data
+        # 使用JsonGenerator进行解析
+        try:
+            data = await JsonGenerator._parse_result_by_stack(
+                answer, remaining_schema)
         except Exception:  # noqa: BLE001
             data = {}
         return answer, data
@@ -95,27 +119,37 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
             {"role": "user", "content": self._question},
             {"role": "assistant", "content": answer},
         ]
+        func_llm_info = await LLMManager.get_llm_by_id(self.func_llm_id)
+        func_call_llm_config = FunctionCallConfig(
+            provider=func_llm_info.provider,
+            endpoint=func_llm_info.openai_base_url,
+            api_key=func_llm_info.openai_api_key,
+            model=func_llm_info.model_name,
+            max_tokens=func_llm_info.max_tokens,
+            temperature=0.7,
+        )
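+        # 用按ID加载的配置显式构建FunctionLLM并注入JsonGenerator,
+        # 避免JsonGenerator内部回退到全局配置文件中的function_call设置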
+        func_call_llm = FunctionLLM(func_call_llm_config)
         json_gen = JsonGenerator(
             query=self._question,
             conversation=conversation,
             schema=remaining_schema,
+            func_call_llm=func_call_llm
         )
         return await json_gen.generate()

     @classmethod
     async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
         """实例化Call类"""
         obj = cls(
             name=executor.step.step.name,
             description=executor.step.step.description,
             facts=executor.background.facts,
             summary=executor.task.runtime.summary,
             node=node,
+            chat_llm_id=executor.chat_llm_id,
+            func_llm_id=executor.func_call_llm_id,
             **kwargs,
         )
         await obj._set_input(executor)
         return obj

     async def _init(self, call_vars: CallVars) -> SlotInput:
         """初始化"""
@@ -136,7 +170,6 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
             remaining_schema=remaining_schema,
         )

-
     async def _exec(
         self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE
     ) -> AsyncGenerator[CallOutputChunk, None]:
diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py
index 9f29d4e0b909eae4cf79e9eaa952f0ad32a55608..6e8a7b6a3a5283f89795c910f7a1ef1ba551051e 100644
--- a/apps/scheduler/executor/agent.py
+++ b/apps/scheduler/executor/agent.py
@@ -8,7 +8,10 @@ import anyio
 from mcp.types import TextContent
 from pydantic import Field

+from apps.services.llm import LLMManager
 from apps.llm.reasoning import ReasoningLLM
+from apps.llm.function import FunctionLLM
+from apps.llm.enum import DefaultModelId
 from apps.scheduler.executor.base import BaseExecutor
 from apps.schemas.enum_var import LanguageType
 from apps.scheduler.mcp_agent.host import MCPHost
@@ -23,6 +26,8 @@ from apps.schemas.mcp import (
 )
 from apps.schemas.message import FlowParams
 from apps.schemas.task import FlowStepHistory
+from apps.schemas.config import LLMConfig, FunctionCallConfig
 from apps.services.appcenter import AppCenterManager
 from apps.services.mcp_service import MCPServiceManager
 from apps.services.task import TaskManager
@@ -38,8 +43,10 @@ class MCPAgentExecutor(BaseExecutor):
     servers_id: list[str] = Field(description="MCP server id")
     agent_id: str = Field(default="", description="Agent ID")
     agent_description: str = Field(default="", description="Agent描述")
-    mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default_factory=list)
-    mcp_pool: MCPPool = Field(description="MCP池", default_factory=MCPPool, exclude=True)
+    mcp_list: list[MCPCollection] = Field(
+        description="MCP服务器列表", default_factory=list)
+    mcp_pool: MCPPool = Field(
+        description="MCP池", default_factory=MCPPool, exclude=True)
     tools: dict[str, MCPTool] = Field(
         description="MCP工具列表,key为tool_id",
         default_factory=dict,
@@ -53,13 +60,67 @@
         description="流执行过程中的参数补充",
         alias="params",
     )
+    chat_llm_id: str = Field(
+        default=DefaultModelId.DEFAULT_CHAT_MODEL_ID.value,
+        description="聊天大模型ID",
+    )
+    enable_thinking: bool = Field(default=False, description="是否启用思考模式")
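+    # 注:chat_llm_id/func_call_llm_id 等字段仅保存模型ID,
+    # 对应的客户端实例在 init_llms() 中构建,并由 init_mcp_plan_and_host() 注入。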
+    func_call_llm_id: str = Field(
+        default=DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value,
+        description="函数调用大模型ID",
+    )
     resoning_llm: ReasoningLLM = Field(
         default_factory=ReasoningLLM, description="推理大模型",
         exclude=True,
     )
+    function_call_llm: FunctionLLM = Field(
+        description="函数调用大模型",
+        default_factory=FunctionLLM, exclude=True,
+    )
     app_owner: str = Field(default="", description="应用所有者")
     auto_execute: bool | None = Field(default=None, description="是否自动执行(来自请求)")
+    mcp_host: MCPHost = Field(
+        description="MCP主机",
+        default_factory=MCPHost, exclude=True,
+    )
+    mcp_planner: MCPPlanner = Field(
+        description="MCP规划器",
+        default_factory=MCPPlanner, exclude=True,
+    )
+
+    async def init_llms(self) -> None:
+        """初始化大模型"""
+        reasoning_llm = await LLMManager.get_llm_by_id(self.chat_llm_id)
+        reasoning_llm_config = LLMConfig(
+            provider=reasoning_llm.provider,
+            api_key=reasoning_llm.openai_api_key,
+            endpoint=reasoning_llm.openai_base_url,
+            model=reasoning_llm.model_name,
+            max_tokens=reasoning_llm.max_tokens,
+            temperature=0.7
+        )
+        self.resoning_llm = ReasoningLLM(reasoning_llm_config)
+        function_call_llm = await LLMManager.get_llm_by_id(self.func_call_llm_id)
+        function_call_llm_config = FunctionCallConfig(
+            provider=function_call_llm.provider,
+            api_key=function_call_llm.openai_api_key,
+            endpoint=function_call_llm.openai_base_url,
+            model=function_call_llm.model_name,
+            max_tokens=function_call_llm.max_tokens,
+            temperature=0.0
+        )
+        self.function_call_llm = FunctionLLM(function_call_llm_config)
+
+    async def init_mcp_plan_and_host(self) -> None:
+        """初始化MCP的Host和Planner"""
+        self.mcp_host = MCPHost(
+            reasoning_llm=self.resoning_llm,
+            function_llm=self.function_call_llm,
+        )
+        self.mcp_planner = MCPPlanner(
+            reasoning_llm=self.resoning_llm,
+            function_llm=self.function_call_llm,
+        )

     async def get_auto_execute(self) -> bool:
         """
@@ -69,7 +130,7 @@
         # 如果请求中明确指定了,使用请求中的设置
         if self.auto_execute is not None:
             return self.auto_execute
-
+
         # 否则使用用户的全局设置
         user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub)
         return user_info.auto_execute
@@ -102,7 +163,8 @@
             # 尝试初始化MCP,只有成功才添加到列表
             client = await self.mcp_pool._init_mcp(mcp_id, self.app_owner)
             if client is None:
-                logger.warning("[MCPAgentExecutor] MCP服务 %s 初始化失败,跳过", mcp_service.name)
+                logger.warning(
+                    "[MCPAgentExecutor] MCP服务 %s 初始化失败,跳过", mcp_service.name)
                 continue
             self.mcp_list.append(mcp_service)
             for tool in mcp_service.tools:
@@ -164,8 +226,8 @@
         if is_first:
             # 获取第一个输入参数
             mcp_tool = self.tools[self.task.state.tool_id]
-            self.task.state.current_input = await MCPHost._get_first_input_params(
-                mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task, self.resoning_llm, self.enable_thinking
+            self.task.state.current_input = await self.mcp_host._get_first_input_params(
+                mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task
             )
         else:
             # 获取后续输入参数
@@ -176,7 +238,7 @@
             params = {}
             params_description = ""
             mcp_tool = self.tools[self.task.state.tool_id]
-            self.task.state.current_input = await self.mcp_host._fill_params(
+            self.task.state.current_input = await self.mcp_host._fill_params(
                 mcp_tool,
                 self.task.runtime.question,
                 self.task.state.step_description,
@@ -191,7 +253,7 @@
         """确认前步骤"""
         # 发送确认消息
         mcp_tool = self.tools[self.task.state.tool_id]
-        
confirm_message = await MCPPlanner.get_tool_risk( + confirm_message = await self.mcp_planner.get_tool_risk( mcp_tool, self.task.state.current_input, "", self.resoning_llm, self.task.language ) await self.update_tokens() @@ -297,22 +359,20 @@ class MCPAgentExecutor(BaseExecutor): async def generate_params_with_null(self) -> None: """生成参数补充""" mcp_tool = self.tools[self.task.state.tool_id] - params_with_null = await MCPPlanner.get_missing_param( + params_with_null = await self.mcp_planner.get_missing_param( mcp_tool, self.task.state.current_input, self.task.state.error_message, self.resoning_llm, self.task.language, - self.enable_thinking, ) await self.update_tokens() - error_message = await MCPPlanner.change_err_message_to_description( + error_message = await self.mcp_planner.change_err_message_to_description( error_message=self.task.state.error_message, tool=mcp_tool, input_params=self.task.state.current_input, reasoning_llm=self.resoning_llm, language=self.task.language, - enable_thinking=self.enable_thinking, ) await self.push_message( EventType.STEP_WAITING_FOR_PARAM, data={ @@ -346,12 +406,12 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.retry_times = 0 if self.task.state.step_cnt < self.max_steps: self.task.state.step_cnt += 1 - history = await MCPHost.assemble_memory(self.task) + history = await self.mcp_host.assemble_memory(self.task) max_retry = 3 step = None for i in range(max_retry): try: - step = await MCPPlanner.create_next_step(self.task.runtime.question, history, self.tool_list, self.resoning_llm, self.task.language, self.enable_thinking) + step = await self.mcp_planner.create_next_step(self.task.runtime.question, history, self.tool_list, self.resoning_llm, self.task.language) if step.tool_id in self.tools.keys(): break except Exception as e: @@ -488,16 +548,15 @@ class MCPAgentExecutor(BaseExecutor): await self.get_next_step() else: mcp_tool = self.tools[self.task.state.tool_id] - is_param_error = await MCPPlanner.is_param_error( + is_param_error = await self.mcp_planner.is_param_error( self.task.runtime.question, - await MCPHost.assemble_memory(self.task), + await self.mcp_host.assemble_memory(self.task), self.task.state.error_message, mcp_tool, self.task.state.step_description, self.task.state.current_input, self.resoning_llm, - self.task.language, - self.enable_thinking, + self.task.language ) if is_param_error.is_param_error: # 如果是参数错误,生成参数补充 @@ -537,11 +596,12 @@ class MCPAgentExecutor(BaseExecutor): async def summarize(self) -> None: """总结""" - async for chunk in MCPPlanner.generate_answer( + async for chunk in self.mcp_planner.generate_answer( self.task.runtime.question, - (await MCPHost.assemble_memory(self.task)), + (await self.mcp_host.assemble_memory(self.task)), self.resoning_llm, self.task.language, + enable_thinking=self.enable_thinking ): await self.push_message( EventType.TEXT_ADD, @@ -552,6 +612,8 @@ class MCPAgentExecutor(BaseExecutor): async def run(self) -> None: """执行MCP Agent的主逻辑""" # 初始化MCP服务 + await self.init_llms() + await self.init_mcp_plan_and_host() self.app_owner = (await AppCenterManager.fetch_app_data_by_id(self.agent_id)).author await self.load_state() await self.load_mcp() @@ -560,10 +622,10 @@ class MCPAgentExecutor(BaseExecutor): # 初始化状态 try: self.task.state.flow_id = str(uuid.uuid4()) - self.task.state.flow_name = (await MCPPlanner.get_flow_name( + self.task.state.flow_name = (await self.mcp_planner.get_flow_name( self.task.runtime.question, self.resoning_llm, self.task.language )).flow_name - flow_risk = await 
MCPPlanner.get_flow_excute_risk( + flow_risk = await self.mcp_planner.get_flow_excute_risk( self.task.runtime.question, self.tool_list, self.resoning_llm, self.task.language ) auto_execute = await self.get_auto_execute() diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 3e6d83a61a1372b4a804131d0e2b03f9bcaba741..0ec89a43ce159830e3157860614aa802cd22f470 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -8,6 +8,7 @@ from datetime import UTC, datetime from pydantic import Field +from apps.llm.enum import DefaultModelId from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.executor.step import StepExecutor @@ -63,8 +64,13 @@ class FlowExecutor(BaseExecutor): flow_id: str = Field(description="Flow ID") question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") + chat_llm_id: str = Field(description="对话使用的大模型ID", + default=DefaultModelId.DEFAULT_CHAT_MODEL_ID.value) enable_thinking: bool = Field(description="是否启用思维链", default=False) - llm_id: str | None = Field(description="应用配置的模型ID", default=None) + func_call_llm_id: str = Field( + description="Function Call使用的大模型ID", + default=DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value, + ) current_step: StepQueueItem | None = Field( description="当前执行的步骤", default=None @@ -81,7 +87,8 @@ class FlowExecutor(BaseExecutor): ): context_objects = await TaskManager.get_context_by_task_id(self.task.id) # 将对象转换为字典以保持与系统其他部分的一致性 - self.task.context = [context.model_dump(exclude_none=True, by_alias=True) for context in context_objects] + self.task.context = [context.model_dump( + exclude_none=True, by_alias=True) for context in context_objects] else: # 创建ExecutorState self.task.state = ExecutorState( @@ -114,6 +121,9 @@ class FlowExecutor(BaseExecutor): step=self.current_step, background=self.background, question=self.question, + chat_llm_id=self.chat_llm_id, + enable_thinking=self.enable_thinking, + func_call_llm_id=self.func_call_llm_id, ) # 初始化步骤 @@ -134,7 +144,8 @@ class FlowExecutor(BaseExecutor): # Check if it has been executed if self.current_step.step_id in self._executed_steps: - logger.info("[FlowExecutor] 步骤 %s 已经执行过,跳过执行", self.current_step.step_id) + logger.info("[FlowExecutor] 步骤 %s 已经执行过,跳过执行", + self.current_step.step_id) continue # 执行Step @@ -154,7 +165,8 @@ class FlowExecutor(BaseExecutor): async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,从变量池中获取分支ID @@ -171,17 +183,20 @@ class FlowExecutor(BaseExecutor): if branch_id: # 构建带分支ID的edge_from edge_from = f"{self.task.state.step_id}.{branch_id}" - logger.info("[FlowExecutor] 从变量池获取分支ID:%s,查找边:%s", branch_id, edge_from) + logger.info( + "[FlowExecutor] 从变量池获取分支ID:%s,查找边:%s", branch_id, edge_from) # 在edges中查找对应的下一个节点 next_steps = [] for edge in self.flow.edges: if edge.edge_from == edge_from: next_steps.append(edge.edge_to) - logger.info("[FlowExecutor] 找到下一个节点:%s", edge.edge_to) + logger.info( + "[FlowExecutor] 找到下一个节点:%s", edge.edge_to) if not next_steps: - logger.warning("[FlowExecutor] 没有找到分支 %s 对应的边", edge_from) + logger.warning( + "[FlowExecutor] 没有找到分支 %s 对应的边", edge_from) else: 
logger.warning("[FlowExecutor] 没有找到分支ID变量") return [] @@ -189,7 +204,8 @@ class FlowExecutor(BaseExecutor): logger.error("[FlowExecutor] 从变量池获取分支ID失败: %s", e) return [] else: - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + # type: ignore[arg-type] + next_steps = await self._find_next_id(self.task.state.step_id) # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -221,7 +237,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], ) # 🔑 获取function call场景使用的模型ID用于系统步骤(理解上下文、记忆存储) @@ -236,12 +253,14 @@ class FlowExecutor(BaseExecutor): function_call_model_id = self.llm_id logger.warning("[FlowExecutor] 未找到任何模型,使用应用配置的模型用于系统步骤") else: - logger.info(f"[FlowExecutor] 系统步骤(理解上下文、记忆存储)使用模型: {function_call_model_id}") - + logger.info( + f"[FlowExecutor] 系统步骤(理解上下文、记忆存储)使用模型: {function_call_model_id}") + # 头插开始前的系统步骤,并执行 for step in FIXED_STEPS_BEFORE_START: # 为系统步骤添加function call模型信息 - step_data = step.get(self.task.language, step[LanguageType.CHINESE]) + step_data = step.get(self.task.language, + step[LanguageType.CHINESE]) # 将llm_id和enable_thinking添加到step的params中 step_data_with_params = step_data.model_copy() step_data_with_params.params = { @@ -257,17 +276,19 @@ class FlowExecutor(BaseExecutor): ) ) await self._step_process() - + # 插入首个步骤 self.step_queue.append(first_step) - - self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] - + + # type: ignore[arg-type] + self.task.state.flow_status = FlowStatus.RUNNING + # 运行Flow(未达终点) is_error = False while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] + # type: ignore[arg-type] + if self.task.state.step_status == StepStatus.ERROR: logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft( @@ -276,8 +297,8 @@ class FlowExecutor(BaseExecutor): step=Step( name=( "错误处理" if self.task.language == LanguageType.CHINESE else "Error Handling" - ), - description=( + ), + description=( "错误处理" if self.task.language == LanguageType.CHINESE else "Error Handling" ), node=SpecialCallType.LLM.value, @@ -285,7 +306,8 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT[self.task.language].replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + # type: ignore[arg-type] + self.task.state.error_info["err_msg"], ), }, ), @@ -310,14 +332,17 @@ class FlowExecutor(BaseExecutor): if step.step_id not in self._executed_steps: self.step_queue.append(step) else: - logger.info("[FlowExecutor] 步骤 %s 已经执行过,不再添加到队列中", step.step_id) + logger.info( + "[FlowExecutor] 步骤 %s 已经执行过,不再添加到队列中", step.step_id) # 更新Task状态 if is_error: - self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] + # type: ignore[arg-type] + self.task.state.flow_status = FlowStatus.ERROR else: - self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] - + # type: ignore[arg-type] + self.task.state.flow_status = FlowStatus.SUCCESS + # 重置Conversation变量池 try: from apps.scheduler.variable.integration import VariableIntegration @@ -327,17 +352,21 @@ class FlowExecutor(BaseExecutor): user_sub=self.task.ids.user_sub ) if reset_success: - logger.info(f"[FlowExecutor] Flow {self.flow_id} 执行完成后,成功重置对话变量池到默认值") + logger.info( 
+ f"[FlowExecutor] Flow {self.flow_id} 执行完成后,成功重置对话变量池到默认值") else: - logger.warning(f"[FlowExecutor] Flow {self.flow_id} 执行完成后,重置对话变量池失败") + logger.warning( + f"[FlowExecutor] Flow {self.flow_id} 执行完成后,重置对话变量池失败") except Exception as e: # 重置失败不应该影响Flow的正常完成 - logger.error(f"[FlowExecutor] Flow {self.flow_id} 执行完成后重置对话变量池时发生异常: {e}") + logger.error( + f"[FlowExecutor] Flow {self.flow_id} 执行完成后重置对话变量池时发生异常: {e}") # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: # 为系统步骤添加function call模型信息 - step_data = step.get(self.task.language, step[LanguageType.CHINESE]) + step_data = step.get(self.task.language, + step[LanguageType.CHINESE]) # 将llm_id和enable_thinking添加到step的params中 step_data_with_params = step_data.model_copy() step_data_with_params.params = { @@ -353,7 +382,8 @@ class FlowExecutor(BaseExecutor): await self._step_process() # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) - self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time + self.task.tokens.time = round(datetime.now( + UTC).timestamp(), 2) - self.task.tokens.full_time # 推送Flow停止消息 if is_error: await self.push_message(EventType.FLOW_FAILED.value) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index ce6ba7ccbdd2393682a2d61d58040eab4381e7e2..04600e84064e39c7107008ffefaf63bc840b1606 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -6,7 +6,7 @@ import logging import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime -from typing import Any +from typing import Any, Field import json import jsonschema from pydantic import ConfigDict @@ -31,7 +31,7 @@ from apps.schemas.message import TextAddContent from apps.schemas.scheduler import CallError, CallOutputChunk from apps.schemas.task import FlowStepHistory, StepQueueItem from apps.services.node import NodeManager - +from apps.llm.enum import DefaultModelId logger = logging.getLogger(__name__) @@ -40,6 +40,13 @@ class StepExecutor(BaseExecutor): step: StepQueueItem + chat_llm_id: str = Field(description="对话使用的大模型ID", + default=DefaultModelId.DEFAULT_CHAT_MODEL_ID.value) + enable_thinking: bool = Field(description="是否启用思维链", default=False) + func_call_llm_id: str = Field( + description="Function Call使用的大模型ID", + default=DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value, + ) model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", @@ -127,7 +134,8 @@ class StepExecutor(BaseExecutor): params.update(input_params) # 对于LLM调用,注入enable_thinking参数 - if self._call_id == "LLM": + if self._call_id == SpecialCallType.LLM.value: + params["llm_id"] = self.chat_llm_id params['enable_thinking'] = self.background.enable_thinking try: @@ -154,7 +162,8 @@ class StepExecutor(BaseExecutor): self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 - slot_obj = await Slot.instance( + slot_obj = Slot() + await slot_obj.instance( self, self.node, data=self.obj.input, diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 90da7f78b632cdbb5c5700b9e19cf21cdb8b5ea1..8cb4c3f1f9920436b10756a234797f78c0a87f45 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -4,15 +4,26 @@ from jsonschema import validate import logging from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM - +from apps.llm.function import FunctionLLM logger = logging.getLogger(__name__) class MCPBase: """MCP基类""" - @staticmethod - async def get_resoning_result(prompt: str, resoning_llm: 
ReasoningLLM = ReasoningLLM(), enable_thinking: bool = False) -> str: + def __init__(self, reasoning_llm: ReasoningLLM | None = None, + function_llm: FunctionLLM | None = None) -> None: + """初始化MCP基类""" + if reasoning_llm: + self.reasoning_llm = reasoning_llm + else: + self.reasoning_llm = ReasoningLLM() + if function_llm: + self.function_llm = function_llm + else: + self.function_llm = FunctionLLM() + + async def get_resoning_result(self, prompt: str) -> str: """获取推理结果""" # 调用推理大模型 message = [ @@ -20,19 +31,17 @@ class MCPBase: {"role": "user", "content": "Please provide a JSON response based on the above information and schema."}, ] result = "" - async for chunk in resoning_llm.call( + async for chunk in self.reasoning_llm.call( message, streaming=False, temperature=0.07, - result_only=False, - enable_thinking=enable_thinking, + result_only=False ): result += chunk return result - @staticmethod - async def _parse_result(result: str, schema: dict[str, Any]) -> str: + async def _parse_result(self, result: str, schema: dict[str, Any]) -> str: """解析推理结果""" json_result = await JsonGenerator._parse_result_by_stack(result, schema) if json_result is not None: @@ -44,6 +53,7 @@ class MCPBase: {"role": "user", "content": result}, ], schema, + self.function_llm ) json_result = await json_generator.generate() return json_result diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index 35f3f39bfd2c262be19964cdff7df436a19660bb..208f3ccf9974392f5a70da672455319fe180441d 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -10,6 +10,7 @@ from jinja2.sandbox import SandboxedEnvironment from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM +from apps.llm.function import FunctionLLM from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import GEN_PARAMS, REPAIR_PARAMS @@ -42,22 +43,22 @@ LLM_QUERY_FIX = { class MCPHost(MCPBase): """MCP宿主服务""" - @staticmethod - async def assemble_memory(task: Task) -> str: + def __init__(self, reasoning_llm: ReasoningLLM = None, function_llm: FunctionLLM = None): + super().__init__(reasoning_llm, function_llm) + + async def assemble_memory(self, task: Task) -> str: """组装记忆""" return _env.from_string(MEMORY_TEMPLATE[task.language]).render( context_list=task.context, ) - @staticmethod async def _get_first_input_params( + self, mcp_tool: MCPTool, goal: str, current_goal: str, - task: Task, - resoning_llm: ReasoningLLM = ReasoningLLM(), - enable_thinking: bool = False, + task: Task ) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate @@ -67,18 +68,18 @@ class MCPHost(MCPBase): goal=goal, current_goal=current_goal, input_schema=mcp_tool.input_schema, - background_info=await MCPHost.assemble_memory(task), + background_info=await self.assemble_memory(task), ) - result = await MCPHost.get_resoning_result(prompt, resoning_llm, enable_thinking) + result = await self.get_resoning_result(prompt) # 使用JsonGenerator解析结果 - result = await MCPHost._parse_result( + result = await self._parse_result( result, mcp_tool.input_schema, ) return result - @staticmethod async def _fill_params( + self, mcp_tool: MCPTool, goal: str, current_goal: str, @@ -88,7 +89,6 @@ class MCPHost(MCPBase): params_description: str = "", language: LanguageType = LanguageType.CHINESE, ) -> dict[str, Any]: - llm_query = LLM_QUERY_FIX[language] prompt = _env.from_string(REPAIR_PARAMS[language]).render( tool_name=mcp_tool.name, 
goal=goal, @@ -100,12 +100,10 @@ class MCPHost(MCPBase): params=params, params_description=params_description, ) - json_generator = JsonGenerator( - llm_query, - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], + result = await self.get_resoning_result(prompt) + # 使用JsonGenerator解析结果 + result = await self._parse_result( + result, mcp_tool.input_schema, ) - return await json_generator.generate() + return result diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 1237a40438f37e0888f02f341e040c13194c7951..5e69262e8c0dc8585249d00e302ccb4865ca7a00 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -9,6 +9,7 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from apps.llm.reasoning import ReasoningLLM +from apps.llm.function import FunctionLLM from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import ( CHANGE_ERROR_MESSAGE_TO_DESCRIPTION, @@ -55,84 +56,79 @@ logger = logging.getLogger(__name__) class MCPPlanner(MCPBase): """MCP 用户目标拆解与规划""" - @staticmethod + def __init__(self, reasoning_llm: ReasoningLLM = None, function_llm: FunctionLLM = None): + super().__init__(reasoning_llm, function_llm) + async def evaluate_goal( + self, goal: str, tool_list: list[MCPTool], - resoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE,) -> GoalEvaluationResult: """评估用户目标的可行性""" # 获取推理结果 - result = await MCPPlanner._get_reasoning_evaluation(goal, tool_list, resoning_llm, language) + result = await self._get_reasoning_evaluation(goal, tool_list, language) # 返回评估结果 - return await MCPPlanner._parse_evaluation_result(result) + return await self._parse_evaluation_result(result) - @staticmethod async def _get_reasoning_evaluation( - goal, tool_list: list[MCPTool], - resoning_llm: ReasoningLLM = ReasoningLLM(), - language: LanguageType = LanguageType.CHINESE, - enable_thinking: bool = False,) -> str: + self, + goal, + tool_list: list[MCPTool], + language: LanguageType = LanguageType.CHINESE) -> str: """获取推理大模型的评估结果""" template = _env.from_string(EVALUATE_GOAL[language]) prompt = template.render( goal=goal, tools=tool_list, ) - return await MCPPlanner.get_resoning_result(prompt, resoning_llm, enable_thinking) + return await self.get_resoning_result(prompt) - @staticmethod - async def _parse_evaluation_result(result: str) -> GoalEvaluationResult: + async def _parse_evaluation_result(self, result: str) -> GoalEvaluationResult: """将推理结果解析为结构化数据""" schema = GoalEvaluationResult.model_json_schema() - evaluation = await MCPPlanner._parse_result(result, schema) + evaluation = await self._parse_result(result, schema) # 使用GoalEvaluationResult模型解析结果 return GoalEvaluationResult.model_validate(evaluation) async def get_flow_name( + self, user_goal: str, - resoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, ) -> FlowName: """获取当前流程的名称""" - result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm, language) - result = await MCPPlanner._parse_result(result, FlowName.model_json_schema()) + result = await self._get_reasoning_flow_name(user_goal, language) + result = await self._parse_result(result, FlowName.model_json_schema()) # 使用FlowName模型解析结果 return FlowName.model_validate(result) - @staticmethod async def _get_reasoning_flow_name( + self, user_goal: str, - resoning_llm: ReasoningLLM = ReasoningLLM(), - language: LanguageType = 
LanguageType.CHINESE, - enable_thinking: bool = False, + language: LanguageType = LanguageType.CHINESE ) -> str: """获取推理大模型的流程名称""" template = _env.from_string(GENERATE_FLOW_NAME[language]) prompt = template.render(goal=user_goal) - return await MCPPlanner.get_resoning_result(prompt, resoning_llm, enable_thinking) + return await self.get_resoning_result(prompt) - @staticmethod async def get_flow_excute_risk( + self, user_goal: str, tools: list[MCPTool], - resoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, ) -> FlowRisk: """获取当前流程的风险评估结果""" - result = await MCPPlanner._get_reasoning_flow_risk(user_goal, tools, resoning_llm, language) - result = await MCPPlanner._parse_result(result, FlowRisk.model_json_schema()) + result = await self._get_reasoning_flow_risk(user_goal, tools, language) + result = await self._parse_result(result, FlowRisk.model_json_schema()) # 使用FlowRisk模型解析结果 return FlowRisk.model_validate(result) - @staticmethod async def _get_reasoning_flow_risk( + self, user_goal: str, tools: list[MCPTool], - resoning_llm: ReasoningLLM = ReasoningLLM(), - language: LanguageType = LanguageType.CHINESE, - enable_thinking: bool = False, + language: LanguageType = LanguageType.CHINESE ) -> str: """获取推理大模型的流程风险""" template = _env.from_string(GENERATE_FLOW_EXCUTE_RISK[language]) @@ -140,15 +136,14 @@ class MCPPlanner(MCPBase): goal=user_goal, tools=tools, ) - return await MCPPlanner.get_resoning_result(prompt, resoning_llm, enable_thinking) + return await self.get_resoning_result(prompt) - @staticmethod async def get_replan_start_step_index( + self, user_goal: str, error_message: str, current_plan: MCPPlan | None = None, history: str = "", - reasoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, ) -> RestartStepIndex: """获取重新规划的步骤索引""" @@ -161,18 +156,18 @@ class MCPPlanner(MCPBase): exclude_none=True, by_alias=True), history=history, ) - result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + result = await self.get_resoning_result(prompt) # 解析为结构化数据 schema = RestartStepIndex.model_json_schema() schema["properties"]["start_index"]["maximum"] = len( current_plan.plans) - 1 schema["properties"]["start_index"]["minimum"] = 0 - restart_index = await MCPPlanner._parse_result(result, schema) + restart_index = await self._parse_result(result, schema) # 使用RestartStepIndex模型解析结果 return RestartStepIndex.model_validate(restart_index) - @staticmethod async def create_plan( + self, user_goal: str, is_replan: bool = False, error_message: str = "", @@ -184,22 +179,21 @@ class MCPPlanner(MCPBase): ) -> MCPPlan: """规划下一步的执行流程,并输出""" # 获取推理结果 - result = await MCPPlanner._get_reasoning_plan( - user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm, language + result = await self._get_reasoning_plan( + user_goal, is_replan, error_message, current_plan, tool_list, max_steps, language ) # 解析为结构化数据 - return await MCPPlanner._parse_plan_result(result, max_steps) + return await self._parse_plan_result(result, max_steps) - @staticmethod async def _get_reasoning_plan( + self, user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, tool_list: list[MCPTool] = [], max_steps: int = 10, - reasoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, ) -> str: """获取推理大模型的结果""" @@ -222,32 +216,29 @@ class MCPPlanner(MCPBase): tools=tool_list, max_num=max_steps, ) - return await MCPPlanner.get_resoning_result(prompt, 
reasoning_llm) + return await self.get_resoning_result(prompt) - @staticmethod - async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan: + async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: """将推理结果解析为结构化数据""" # 格式化Prompt schema = MCPPlan.model_json_schema() schema["properties"]["plans"]["maxItems"] = max_steps - plan = await MCPPlanner._parse_result(result, schema) + plan = await self._parse_result(result, schema) # 使用Function模型解析结果 return MCPPlan.model_validate(plan) - @staticmethod async def create_next_step( + self, goal: str, history: str, tools: list[MCPTool], - reasoning_llm: ReasoningLLM = ReasoningLLM(), - language: LanguageType = LanguageType.CHINESE, - enable_thinking: bool = False, + language: LanguageType = LanguageType.CHINESE ) -> Step: """创建下一步的执行步骤""" # 获取推理结果 template = _env.from_string(GEN_STEP[language]) prompt = template.render(goal=goal, history=history, tools=tools) - result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm, enable_thinking) + result = await self.get_resoning_result(prompt) # 解析为结构化数据 schema = Step.model_json_schema() @@ -262,14 +253,13 @@ class MCPPlanner(MCPBase): step = Step.model_validate(step) return step - @staticmethod async def tool_skip( + self, task: Task, step_id: str, step_name: str, step_instruction: str, step_content: str, - reasoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, ) -> ToolSkip: """判断当前步骤是否需要跳过""" @@ -285,37 +275,35 @@ class MCPPlanner(MCPBase): history=history, goal=task.runtime.question ) - result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + result = await self.get_resoning_result(prompt) # 解析为结构化数据 schema = ToolSkip.model_json_schema() - skip_result = await MCPPlanner._parse_result(result, schema) + skip_result = await self._parse_result(result, schema) # 使用ToolSkip模型解析结果 return ToolSkip.model_validate(skip_result) - @staticmethod async def get_tool_risk( + self, tool: MCPTool, input_parm: dict[str, Any], additional_info: str = "", - resoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, ) -> ToolRisk: """获取MCP工具的风险评估结果""" # 获取推理结果 - result = await MCPPlanner._get_reasoning_risk( - tool, input_parm, additional_info, resoning_llm, language + result = await self._get_reasoning_risk( + tool, input_parm, additional_info, language ) # 返回风险评估结果 - return await MCPPlanner._parse_risk_result(result) + return await self._parse_risk_result(result) - @staticmethod async def _get_reasoning_risk( + self, tool: MCPTool, input_param: dict[str, Any], additional_info: str, - resoning_llm: ReasoningLLM, language: LanguageType = LanguageType.CHINESE, ) -> str: """获取推理大模型的风险评估结果""" @@ -326,24 +314,22 @@ class MCPPlanner(MCPBase): input_param=input_param, additional_info=additional_info, ) - return await MCPPlanner.get_resoning_result(prompt, resoning_llm) + return await self.get_resoning_result(prompt) - @staticmethod - async def _parse_risk_result(result: str) -> ToolRisk: + async def _parse_risk_result(self, result: str) -> ToolRisk: """将推理结果解析为结构化数据""" schema = ToolRisk.model_json_schema() - risk = await MCPPlanner._parse_result(result, schema) + risk = await self._parse_result(result, schema) # 使用ToolRisk模型解析结果 return ToolRisk.model_validate(risk) - @staticmethod async def _get_reasoning_tool_execute_error_type( + self, user_goal: str, current_plan: MCPPlan, tool: MCPTool, input_param: dict[str, Any], error_message: str, - reasoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = 
LanguageType.CHINESE, ) -> str: """获取推理大模型的工具执行错误类型""" @@ -357,45 +343,41 @@ class MCPPlanner(MCPBase): input_param=input_param, error_message=error_message, ) - return await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + return await self.get_resoning_result(prompt) - @staticmethod - async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType: + async def _parse_tool_execute_error_type_result(self, result: str) -> ToolExcutionErrorType: """将推理结果解析为工具执行错误类型""" schema = ToolExcutionErrorType.model_json_schema() - error_type = await MCPPlanner._parse_result(result, schema) + error_type = await self._parse_result(result, schema) # 使用ToolExcutionErrorType模型解析结果 return ToolExcutionErrorType.model_validate(error_type) - @staticmethod async def get_tool_execute_error_type( + self, user_goal: str, current_plan: MCPPlan, tool: MCPTool, input_param: dict[str, Any], error_message: str, - reasoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, ) -> ToolExcutionErrorType: """获取MCP工具执行错误类型""" # 获取推理结果 - result = await MCPPlanner._get_reasoning_tool_execute_error_type( - user_goal, current_plan, tool, input_param, error_message, reasoning_llm, language + result = await self._get_reasoning_tool_execute_error_type( + user_goal, current_plan, tool, input_param, error_message, language ) # 返回工具执行错误类型 - return await MCPPlanner._parse_tool_execute_error_type_result(result) + return await self._parse_tool_execute_error_type_result(result) - @staticmethod async def is_param_error( + self, goal: str, history: str, error_message: str, tool: MCPTool, step_description: str, input_params: dict[str, Any], - reasoning_llm: ReasoningLLM = ReasoningLLM(), - language: LanguageType = LanguageType.CHINESE, - enable_thinking: bool = False, + language: LanguageType = LanguageType.CHINESE ) -> IsParamError: """判断错误信息是否是参数错误""" tmplate = _env.from_string(IS_PARAM_ERROR[language]) @@ -408,21 +390,19 @@ class MCPPlanner(MCPBase): input_params=input_params, error_message=error_message, ) - result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm, enable_thinking) + result = await self.get_resoning_result(prompt) # 解析为结构化数据 schema = IsParamError.model_json_schema() - is_param_error = await MCPPlanner._parse_result(result, schema) + is_param_error = await self._parse_result(result, schema) # 使用IsParamError模型解析结果 return IsParamError.model_validate(is_param_error) - @staticmethod async def change_err_message_to_description( + self, error_message: str, tool: MCPTool, input_params: dict[str, Any], - reasoning_llm: ReasoningLLM = ReasoningLLM(), - language: LanguageType = LanguageType.CHINESE, - enable_thinking: bool = False, + language: LanguageType = LanguageType.CHINESE ) -> str: """将错误信息转换为工具描述""" template = _env.from_string( @@ -434,17 +414,15 @@ class MCPPlanner(MCPBase): input_schema=tool.input_schema, input_params=input_params, ) - result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm, enable_thinking) + result = await self.get_resoning_result(prompt) return result - @staticmethod async def get_missing_param( + self, tool: MCPTool, input_param: dict[str, Any], error_message: str, - reasoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, - enable_thinking: bool = False, ) -> list[str]: """获取缺失的参数""" slot = Slot(schema=tool.input_schema) @@ -457,17 +435,17 @@ class MCPPlanner(MCPBase): schema=schema_with_null, error_message=error_message, ) - result = await MCPPlanner.get_resoning_result(prompt, 
+        result = await self.get_resoning_result(prompt)

         # 解析为结构化数据
-        input_param_with_null = await MCPPlanner._parse_result(result, schema_with_null)
+        input_param_with_null = await self._parse_result(result, schema_with_null)
         return input_param_with_null

-    @staticmethod
     async def generate_answer(
+        self,
         user_goal: str,
         memory: str,
-        resoning_llm: ReasoningLLM = ReasoningLLM(),
         language: LanguageType = LanguageType.CHINESE,
+        enable_thinking: bool = False,
     ) -> AsyncGenerator[str, None]:
         """生成最终回答"""
         template = _env.from_string(FINAL_ANSWER[language])
@@ -475,10 +453,10 @@ class MCPPlanner(MCPBase):
             memory=memory,
             goal=user_goal,
         )
-        async for chunk in resoning_llm.call(
+        async for chunk in self.reasoning_llm.call(
             [{"role": "user", "content": prompt}],
             streaming=True,
             temperature=0.07,
-            enable_thinking=True
+            enable_thinking=enable_thinking,
         ):
             yield chunk
diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py
index a62af7ce8a28c285403408c00bcb2b4aa57bbfa4..c1feed2f6efef673cd73dbe4c61da7d82e1a70ce 100644
--- a/apps/scheduler/mcp_agent/select.py
+++ b/apps/scheduler/mcp_agent/select.py
@@ -8,6 +8,7 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment

 from apps.llm.reasoning import ReasoningLLM
+from apps.llm.function import FunctionLLM
 from apps.llm.token import TokenCalculator
 from apps.scheduler.mcp_agent.base import MCPBase
 from apps.scheduler.mcp_agent.prompt import TOOL_SELECT
@@ -31,18 +32,24 @@ SELF_DESC_TOOL_ID = "SELF_DESC"
 class MCPSelector(MCPBase):
     """MCP选择器"""

-    @staticmethod
+    def __init__(
+        self,
+        reasoning_llm: ReasoningLLM = None,
+        function_llm: FunctionLLM = None,
+    ):
+        super().__init__(reasoning_llm, function_llm)
+
     async def select_top_tool(
+        self,
         goal: str,
         tool_list: list[MCPTool],
         additional_info: str | None = None,
         top_n: int | None = None,
-        reasoning_llm: ReasoningLLM | None = None,
         language: LanguageType = LanguageType.CHINESE,
     ) -> list[MCPTool]:
         """选择最合适的工具"""
         random.shuffle(tool_list)
-        max_tokens = reasoning_llm._config.max_tokens
+        max_tokens = self.reasoning_llm._config.max_tokens
         template = _env.from_string(TOOL_SELECT[language])
         token_calculator = TokenCalculator()
         if (
@@ -71,7 +78,8 @@ class MCPSelector(MCPBase):
                 {
                     "role": "user",
                     "content": template.render(
-                        goal=goal, tools=[tool], additional_info=additional_info
+                        goal=goal, tools=[
+                            tool], additional_info=additional_info
                     ),
                 }
             ],
@@ -103,12 +111,14 @@ class MCPSelector(MCPBase):
         if "items" not in schema["properties"]["tool_ids"]:
             schema["properties"]["tool_ids"]["items"] = {}
         # 将enum添加到items中,限制数组元素的可选值
-        schema["properties"]["tool_ids"]["items"]["enum"] = [tool.id for tool in sub_tools]
-        result = await MCPSelector.get_resoning_result(
-            template.render(goal=goal, tools=sub_tools, additional_info="请根据目标选择对应的工具"),
-            reasoning_llm,
+        schema["properties"]["tool_ids"]["items"]["enum"] = [
+            tool.id for tool in sub_tools]
+        result = await self.get_resoning_result(
+            template.render(goal=goal, tools=sub_tools,
+                            additional_info="请根据目标选择对应的工具"),
+            self.reasoning_llm,
         )
-        result = await MCPSelector._parse_result(result, schema)
+        result = await self._parse_result(result, schema)
         try:
             result = MCPToolIdsSelectResult.model_validate(result)
             tool_ids.extend(result.tool_ids)
@@ -120,7 +130,8 @@ class MCPSelector(MCPBase):
         if top_n is not None:
             mcp_tools = mcp_tools[:top_n]
         mcp_tools.append(
-            MCPTool(id=FINAL_TOOL_ID, name="Final", description="终止", mcp_id=FINAL_TOOL_ID, input_schema={})
+            MCPTool(id=FINAL_TOOL_ID, name="Final", description="终止",
name="Final", description="终止", + mcp_id=FINAL_TOOL_ID, input_schema={}) ) # mcp_tools.append(MCPTool(id=SUMMARIZE_TOOL_ID, name="Summarize", # description="总结工具", mcp_id=SUMMARIZE_TOOL_ID, input_schema={})) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 6980dbc9f3bb11b318b6d7749db60c517d3a9c1a..652374cebe87a1d45b158486b8b33d5b03c01ef2 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -5,7 +5,7 @@ import asyncio import logging from datetime import UTC, datetime -from apps.llm.schema import DefaultModelId +from apps.llm.enum import DefaultModelId from apps.llm.reasoning import ReasoningLLM from apps.schemas.config import LLMConfig from apps.llm.patterns.rewrite import QuestionRewrite @@ -231,40 +231,16 @@ class Scheduler: if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") return - logger.info( - f"[Scheduler] 应用配置的模型ID: {app_metadata.llm_id}, 启用思维链: {app_metadata.enable_thinking if hasattr(app_metadata, 'enable_thinking') else 'N/A'}") - if not app_metadata.llm_id: - # 获取系统默认模型 - llm = await LLMManager.get_llm_by_id(DefaultModelId.DEFAULT_CHAT_MODEL_ID.value) - else: - llm = await LLMManager.get_llm_by_id(app_metadata.llm_id) - if not llm: - logger.error("[Scheduler] 获取大模型失败") - await self.queue.close() - return - reasion_llm = ReasoningLLM( - LLMConfig( - provider=llm.provider, - endpoint=llm.openai_base_url, - api_key=llm.openai_api_key, - model=llm.model_name, - max_tokens=llm.max_tokens, - ) - ) if background.conversation and self.task.state.flow_status == FlowStatus.INIT: + chat_llm_id = app_metadata.llm_id + if not chat_llm_id: + chat_llm_id = DefaultModelId.DEFAULT_CHAT_MODEL_ID.value + func_call_llm_id = app_metadata.llm_id + if not func_call_llm_id: + func_call_llm_id = DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value try: - # 使用function call模型进行问题改写 - # 降级顺序:应用配置模型 -> 用户偏好的function call模型 -> 系统默认function call模型 -> 系统默认chat模型 - llm_id_for_rewrite = app_metadata.llm_id - if not llm_id_for_rewrite: - llm_id_for_rewrite = DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value - enable_thinking_for_rewrite = app_metadata.enable_thinking if hasattr( - app_metadata, 'enable_thinking') else False - - logger.info(f"[Scheduler] 问题改写使用模型ID: {llm_id_for_rewrite}") question_obj = QuestionRewrite( - llm_id=llm_id_for_rewrite, - enable_thinking=enable_thinking_for_rewrite, + llm_id=func_call_llm_id ) post_body.question = await question_obj.generate( history=background.conversation, @@ -314,10 +290,9 @@ class Scheduler: msg_queue=queue, question=post_body.question, post_body_app=app_info, - enable_thinking=app_metadata.enable_thinking if hasattr( - app_metadata, 'enable_thinking') else False, - llm_id=app_metadata.llm_id if hasattr( - app_metadata, 'llm_id') and app_metadata.llm_id != "empty" else None, + chat_llm_id=chat_llm_id, + enable_thinking=app_metadata.enable_thinking, + func_call_llm_id=func_call_llm_id, background=background, ) @@ -339,8 +314,9 @@ class Scheduler: background=background, agent_id=app_info.app_id, params=post_body.params, + chat_llm_id=chat_llm_id, enable_thinking=app_metadata.enable_thinking, - resoning_llm=reasion_llm, + func_call_llm_id=func_call_llm_id, auto_execute=post_body.auto_execute ) # 开始运行 diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 1b1c6540a56c8959f42fc65baf55c1c5869dd068..dc2135b945e4760569d3997a94ff0137617f17f1 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -7,6 +7,7 @@ from typing import Any from pydantic import 
+from apps.llm.enum import DefaultModelId
 from apps.schemas.enum_var import FlowStatus, StepStatus, LanguageType
 from apps.schemas.flow import Step
 from apps.schemas.mcp import MCPPlan
@@ -31,7 +32,8 @@ class FlowStepHistory(BaseModel):
     input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={})
     output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={})
     ex_data: dict[str, Any] | None = Field(description="额外数据", default=None)
-    created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
+    created_at: float = Field(default_factory=lambda: round(
+        datetime.now(tz=UTC).timestamp(), 3))


 class ExecutorState(BaseModel):
@@ -41,18 +43,21 @@ class ExecutorState(BaseModel):
     flow_id: str = Field(description="Flow ID", default="")
     flow_name: str = Field(description="Flow名称", default="")
     description: str = Field(description="Flow描述", default="")
-    flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT)
+    flow_status: FlowStatus = Field(
+        description="Flow状态", default=FlowStatus.INIT)
     # 任务级数据
     step_cnt: int = Field(description="当前步骤数量", default=0)
     step_id: str = Field(description="当前步骤ID", default="")
     tool_id: str = Field(description="当前工具ID", default="")
     step_name: str = Field(description="当前步骤名称", default="")
-    step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN)
+    step_status: StepStatus = Field(
+        description="当前步骤状态", default=StepStatus.UNKNOWN)
     step_description: str = Field(description="当前步骤描述", default="")
     app_id: str = Field(description="应用ID", default="")
     current_input: dict[str, Any] = Field(description="当前输入数据", default={})
     error_message: str = Field(description="错误信息", default="")
-    error_info: dict[str, Any] | None = Field(description="详细错误信息", default=None)
+    error_info: dict[str, Any] | None = Field(
+        description="详细错误信息", default=None)
     retry_times: int = Field(description="当前步骤重试次数", default=0)
@@ -62,9 +67,19 @@
     session_id: str = Field(description="会话ID")
     group_id: str = Field(description="组ID")
     conversation_id: str = Field(description="对话ID")
-    record_id: str = Field(description="记录ID", default_factory=lambda: str(uuid.uuid4()))
+    record_id: str = Field(
+        description="记录ID", default_factory=lambda: str(uuid.uuid4()))
     user_sub: str = Field(description="用户ID")
-    active_id: str = Field(description="活动ID", default_factory=lambda: str(uuid.uuid4()))
+    active_id: str = Field(
+        description="活动ID", default_factory=lambda: str(uuid.uuid4()))
+    func_call_llm_id: str = Field(
+        description="函数调用大模型ID", default=DefaultModelId.DEFAULT_FUNCTION_CALL_MODEL_ID.value)
+    chat_llm_id: str = Field(
+        description="聊天大模型ID", default=DefaultModelId.DEFAULT_CHAT_MODEL_ID.value)
+    embedding_llm_id: str = Field(
+        description="向量化大模型ID", default=DefaultModelId.DEFAULT_EMBEDDING_MODEL_ID.value)
+    reranker_llm_id: str = Field(
+        description="重排序大模型ID", default=DefaultModelId.DEFAULT_RERANKER_MODEL_ID.value)


 class TaskTokens(BaseModel):
@@ -98,12 +113,16 @@
     id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
     ids: TaskIds = Field(description="任务涉及的各种ID")
-    context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[])
-    state: ExecutorState = Field(description="Flow的状态", default=ExecutorState())
+    context: list[FlowStepHistory] = Field(
+        description="Flow的步骤执行信息", default=[])
+    state: ExecutorState = Field(
+        description="Flow的状态", default=ExecutorState())
     tokens: TaskTokens = Field(description="Token信息")
Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - language: LanguageType = Field(description="语言", default=LanguageType.CHINESE) + created_at: float = Field(default_factory=lambda: round( + datetime.now(tz=UTC).timestamp(), 3)) + language: LanguageType = Field( + description="语言", default=LanguageType.CHINESE) class StepQueueItem(BaseModel): diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 89fe4dc885d82b69828b020e2ff0185f4de03ebe..f2b2551adc7aef8160c03e4839230eb68930aa7c 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -14,7 +14,7 @@ from apps.services.llm import LLMManager from apps.services.task import TaskManager from apps.templates.generate_llm_operator_config import llm_provider_dict from apps.llm.adapters import get_provider_from_endpoint -from apps.llm.schema import DefaultModelId +from apps.llm.enum import DefaultModelId logger = logging.getLogger(__name__) diff --git a/apps/services/llm.py b/apps/services/llm.py index 595551323c1b90d9b6546b812c06fb0d5baaba93..6ba8814a1967434e7ba354363ff2d8278e0c941f 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -13,7 +13,7 @@ from apps.schemas.request_data import ( ) from apps.schemas.response_data import LLMProvider, LLMProviderInfo from apps.templates.generate_llm_operator_config import llm_provider_dict -from apps.llm.schema import DefaultModelId +from apps.llm.enum import DefaultModelId from apps.llm.model_registry import model_registry from apps.llm.adapters import get_provider_from_endpoint @@ -32,15 +32,20 @@ class LLMManager: :return: LLMProviderInfo 对象 """ # 标准化type字段为列表格式 - llm_type = llm.get("type", "chat") - if isinstance(llm_type, str): - llm_type = [llm_type] - + llm_type = llm.get("type", None) + if llm_type is None: + llm_type = ["chat"] + # 将枚举类DefaultModelId转换为list + default_model_id_list = [item.value for item in DefaultModelId] + if llm["_id"] in default_model_id_list: + openai_api_key = "" # 系统默认模型不返回api key + else: + openai_api_key = llm.get("openai_api_key", "") return LLMProviderInfo( llmId=llm["_id"], # _id已经是UUID字符串 icon=llm["icon"], openaiBaseUrl=llm["openai_base_url"], - openaiApiKey=llm["openai_api_key"], + openaiApiKey=openai_api_key, modelName=llm["model_name"], maxTokens=llm["max_tokens"], isEditable=bool(llm.get("user_sub")), # 系统模型(user_sub="")不可编辑