diff --git a/apps/common/postgres.py b/apps/common/postgres.py
index 90b1837a9e00e9f612a0d950f17fc08919d03842..74b0dcfb2d21277920574b77e0cb40bcb1273482 100644
--- a/apps/common/postgres.py
+++ b/apps/common/postgres.py
@@ -5,7 +5,7 @@ import urllib.parse
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
-from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from apps.models import (
Base,
@@ -19,18 +19,20 @@ logger = logging.getLogger(__name__)
class Postgres:
"""Postgres连接器"""
+ engine: AsyncEngine
+
async def init(self) -> None:
"""初始化Postgres连接器"""
logger.info("[Postgres] 初始化Postgres连接器")
- self._engine = create_async_engine(
+ self.engine = create_async_engine(
f"postgresql+asyncpg://{urllib.parse.quote_plus(config.postgres.user)}:"
f"{urllib.parse.quote_plus(config.postgres.password)}@{config.postgres.host}:"
f"{config.postgres.port}/{config.postgres.database}",
)
- self._session = async_sessionmaker(self._engine, expire_on_commit=False)
+ self._session = async_sessionmaker(self.engine, expire_on_commit=False)
logger.info("[Postgres] 创建表")
- async with self._engine.begin() as conn:
+ async with self.engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
@asynccontextmanager
diff --git a/apps/llm/embedding.py b/apps/llm/embedding.py
index 3fd98667e4775e3b78e5618f87504a09f5816215..c18ec1bb378afc4e2a4a09016d36a138485bd8a1 100644
--- a/apps/llm/embedding.py
+++ b/apps/llm/embedding.py
@@ -1,6 +1,7 @@
"""Embedding模型"""
import logging
+from typing import Any
import httpx
from pgvector.sqlalchemy import Vector
@@ -12,7 +13,6 @@ from apps.common.postgres import postgres
from apps.models.llm import EmbeddingBackend, LLMData
_logger = logging.getLogger(__name__)
-VectorBase = declarative_base()
_flow_pool_vector_table = {
"__tablename__": "framework_flow_vector",
"appId": Column(UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=False),
@@ -86,6 +86,13 @@ _mcp_tool_vector_table = {
class Embedding:
"""Embedding模型"""
+ VectorBase: Any
+ NodePoolVector: Any
+ FlowPoolVector: Any
+ ServicePoolVector: Any
+ MCPVector: Any
+ MCPToolVector: Any
+
async def _get_embedding_dimension(self) -> int:
"""获取Embedding的维度"""
embedding = await self.get_embedding(["测试文本"])
@@ -116,17 +123,31 @@ class Embedding:
_mcp_tool_vector_table,
]:
table["embedding"] = Column(Vector(dim), nullable=False)
- # 创建表
- VectorBase.metadata.create_all(VectorBase.metadata)
- async def __init__(self, llm_config: LLMData | None = None) -> None:
- """初始化Embedding模型"""
+ # 创建表
+ self.VectorBase = declarative_base()
+ self.NodePoolVector = type("NodePoolVector", (self.VectorBase,), _node_pool_vector_table)
+ self.FlowPoolVector = type("FlowPoolVector", (self.VectorBase,), _flow_pool_vector_table)
+ self.ServicePoolVector = type("ServicePoolVector", (self.VectorBase,), _service_pool_vector_table)
+ self.MCPVector = type("MCPVector", (self.VectorBase,), _mcp_vector_table)
+ self.MCPToolVector = type("MCPToolVector", (self.VectorBase,), _mcp_tool_vector_table)
+ self.VectorBase.metadata.create_all(postgres.engine)
+
+ def __init__(self, llm_config: LLMData | None = None) -> None:
+ """初始化Embedding对象"""
if not llm_config or not llm_config.embeddingBackend:
err = "[Embedding] 未设置Embedding模型"
_logger.error(err)
raise RuntimeError(err)
self._config: LLMData = llm_config
+ async def init(self) -> None:
+ """在使用Embedding前初始化数据库表等资源"""
+ await self._delete_vector()
+ # 检测维度
+ dim = await self._get_embedding_dimension()
+ await self._create_vector_table(dim)
+
async def _get_openai_embedding(self, text: list[str]) -> list[list[float]]:
"""访问OpenAI兼容的Embedding API,获得向量化数据"""
api = self._config.openaiBaseUrl + "/embeddings"
@@ -152,7 +173,6 @@ class Embedding:
json = response.json()
return [item["embedding"] for item in json["data"]]
-
async def _get_tei_embedding(self, text: list[str]) -> list[list[float]]:
"""访问TEI兼容的Embedding API,获得向量化数据"""
api = self._config.openaiBaseUrl + "/embed"
@@ -177,7 +197,6 @@ class Embedding:
return result
-
async def get_embedding(self, text: list[str]) -> list[list[float]]:
"""
访问OpenAI兼容的Embedding API,获得向量化数据
@@ -185,14 +204,11 @@ class Embedding:
:param text: 待向量化文本(多条文本组成List)
:return: 文本对应的向量(顺序与text一致,也为List)
"""
- try:
- if self._config.embeddingBackend == EmbeddingBackend.OPENAI:
- return await self._get_openai_embedding(text)
- if self._config.embeddingBackend == EmbeddingBackend.TEI:
- return await self._get_tei_embedding(text)
-
- _logger.error("[Embedding] 不支持的Embedding API类型: %s", self._config.modelName)
- return [[0.0] * 1024 for _ in range(len(text))]
- except Exception:
- _logger.exception("[Embedding] 获取Embedding失败")
- return [[0.0] * 1024 for _ in range(len(text))]
+ if self._config.embeddingBackend == EmbeddingBackend.OPENAI:
+ return await self._get_openai_embedding(text)
+ if self._config.embeddingBackend == EmbeddingBackend.TEI:
+ return await self._get_tei_embedding(text)
+
+ err = f"[Embedding] 不支持的Embedding API类型: {self._config.modelName}"
+ _logger.error(err)
+ raise RuntimeError(err)
diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py
index 0c3fb2ffd76ce1f078ebb0608cd9f2018e016c29..f4e83230fab8e9dbf8c5d568ed731e297c45f20c 100644
--- a/apps/llm/patterns/rewrite.py
+++ b/apps/llm/patterns/rewrite.py
@@ -172,7 +172,7 @@ application scenarios."
llm = kwargs.get("llm")
if not llm:
llm = ReasoningLLM()
- leave_tokens = llm._config.max_tokens
+ leave_tokens = llm.config.max_tokens
leave_tokens -= TokenCalculator().calculate_token_length(messages)
if leave_tokens <= 0:
logger.error("[QuestionRewrite] 大模型上下文窗口不足,无法进行问题补全与重写")
diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py
index 4cd886b2057fe8f4ec9fffa32f01d03f17240d2a..9141a5c35a1cab0df7a4a7d755cdf1fee920ebb9 100644
--- a/apps/llm/reasoning.py
+++ b/apps/llm/reasoning.py
@@ -100,21 +100,21 @@ class ReasoningLLM:
err = "未设置大模型配置"
logger.error(err)
raise RuntimeError(err)
- self._config: LLMData = llm_config
+ self.config: LLMData = llm_config
self._init_client()
def _init_client(self) -> None:
"""初始化OpenAI客户端"""
- if not self._config.openaiAPIKey:
+ if not self.config.openaiAPIKey:
self._client = AsyncOpenAI(
- base_url=self._config.openaiBaseUrl,
+ base_url=self.config.openaiBaseUrl,
timeout=self.timeout,
)
return
self._client = AsyncOpenAI(
- api_key=self._config.openaiAPIKey,
- base_url=self._config.openaiBaseUrl,
+ api_key=self.config.openaiAPIKey,
+ base_url=self.config.openaiBaseUrl,
timeout=self.timeout,
)
@@ -140,12 +140,12 @@ class ReasoningLLM:
) -> AsyncGenerator[ChatCompletionChunk, None]:
"""创建流式响应"""
if model is None:
- model = self._config.modelName
+ model = self.config.modelName
return await self._client.chat.completions.create(
model=model,
messages=messages, # type: ignore[]
- max_tokens=max_tokens or self._config.maxToken,
- temperature=temperature or self._config.temperature,
+ max_tokens=max_tokens or self.config.maxToken,
+ temperature=temperature or self.config.temperature,
stream=True,
stream_options={"include_usage": True},
) # type: ignore[]
@@ -163,11 +163,11 @@ class ReasoningLLM:
"""调用大模型,分为流式和非流式两种"""
# 检查max_tokens和temperature
if max_tokens is None:
- max_tokens = self._config.maxToken
+ max_tokens = self.config.maxToken
if temperature is None:
- temperature = self._config.temperature
+ temperature = self.config.temperature
if model is None:
- model = self._config.modelName
+ model = self.config.modelName
msg_list = self._validate_messages(messages)
stream = await self._create_stream(msg_list, max_tokens, temperature, model)
reasoning = ReasoningContent()
diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py
index 69abdfae6e5b19ac19b35a6840d4960eb6c74339..4e237aad36f10ce64638558f29a7f439913d3789 100644
--- a/apps/scheduler/call/api/api.py
+++ b/apps/scheduler/call/api/api.py
@@ -171,6 +171,11 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput):
async def _apply_auth(self) -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
"""应用认证信息到请求参数中"""
+ if not self._session_id:
+ err = "[API] 未设置Session ID"
+ logger.error(err)
+ raise CallError(message=err, data={})
+
# self._auth可能是None或ServiceApiAuth类型
# ServiceApiAuth类型包含header、cookie、query和oidc属性
req_header = {}
diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py
index 05d0709920d4c357ceeb8ab46f71379a970c9148..bb3e2429dbf6d1819e0cd37cfde53b31d004c772 100644
--- a/apps/scheduler/call/core.py
+++ b/apps/scheduler/call/core.py
@@ -12,8 +12,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Self
from pydantic import BaseModel, ConfigDict, Field
from pydantic.json_schema import SkipJsonSchema
-from apps.llm.function import FunctionLLM
-from apps.llm.reasoning import ReasoningLLM
from apps.models.node import NodeInfo
from apps.models.task import ExecutorHistory
from apps.schemas.enum_var import CallOutputType, LanguageType
@@ -22,10 +20,8 @@ from apps.schemas.scheduler import (
CallIds,
CallInfo,
CallOutputChunk,
- CallTokens,
CallVars,
)
-from apps.services.llm import LLMManager
if TYPE_CHECKING:
from apps.scheduler.executor.step import StepExecutor
@@ -54,11 +50,6 @@ class CoreCall(BaseModel):
description: SkipJsonSchema[str] = Field(description="Step的描述", exclude=True)
node: SkipJsonSchema[NodeInfo | None] = Field(description="节点信息", exclude=True)
enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False, exclude=True)
- tokens: SkipJsonSchema[CallTokens] = Field(
- description="Call的输入输出Tokens信息",
- default=CallTokens(),
- exclude=True,
- )
input_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field(
description="Call的输入Pydantic类型;不包含override的模板",
exclude=True,
@@ -166,6 +157,7 @@ class CoreCall(BaseModel):
async def _set_input(self, executor: "StepExecutor") -> None:
"""获取Call的输入"""
+ self._llm_obj = executor.llm
self._sys_vars = self._assemble_call_vars(executor)
input_data = await self._init(self._sys_vars)
self.input = input_data.model_dump(by_alias=True, exclude_none=True)
@@ -196,18 +188,19 @@ class CoreCall(BaseModel):
async def _llm(self, messages: list[dict[str, Any]], *, streaming: bool = False) -> AsyncGenerator[str, None]:
"""Call可直接使用的LLM非流式调用"""
if streaming:
- async for chunk in llm.call(messages, streaming=streaming):
+ async for chunk in self._llm_obj.reasoning.call(messages, streaming=streaming):
yield chunk
else:
result = ""
- async for chunk in llm.call(messages, streaming=streaming):
+ async for chunk in self._llm_obj.reasoning.call(messages, streaming=streaming):
result += chunk
yield result
- self.tokens.input_tokens += llm.input_tokens
- self.tokens.output_tokens += llm.output_tokens
-
async def _json(self, messages: list[dict[str, Any]], schema: dict[str, Any]) -> dict[str, Any]:
"""Call可直接使用的JSON生成"""
- return await json.call(messages=messages, schema=schema)
+ if not self._llm_obj.function:
+ err = "[CoreCall] 未设置函数调用模型!"
+ logger.error(err)
+ raise CallError(message=err, data={})
+ return await self._llm_obj.function.call(messages=messages, schema=schema)
diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py
index 48d2531f3ccfebbbb2a14192994a0cb74c190d62..57339125d6edf84b5972c102ef1d6cc2c6d3bf0c 100644
--- a/apps/scheduler/call/facts/facts.py
+++ b/apps/scheduler/call/facts/facts.py
@@ -49,7 +49,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
"""初始化工具"""
obj = cls(
- answer=executor.runtime.fullAnswer,
+ answer=executor.task.runtime.fullAnswer,
name=executor.step.step.name,
description=executor.step.step.description,
node=node,
@@ -122,5 +122,5 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
if not isinstance(content, dict):
err = "[FactsCall] 工具输出格式错误"
raise TypeError(err)
- executor.runtime.fact = FactsOutput.model_validate(content).facts
+ executor.task.runtime.fact = FactsOutput.model_validate(content).facts
yield chunk
diff --git a/apps/scheduler/call/rag/prompt.py b/apps/scheduler/call/rag/prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a32ec4c4321d968c1fe1390e8adc5600323e685
--- /dev/null
+++ b/apps/scheduler/call/rag/prompt.py
@@ -0,0 +1,108 @@
+"""RAG工具的提示词"""
+
+from apps.schemas.enum_var import LanguageType
+
+GEN_RAG_ANSWER: dict[LanguageType, str] = {
+ LanguageType.CHINESE: r"""
+
+ 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后\
+进行脚注。
+ 一个例子将在中给出。
+ 上下文背景信息将在中给出。
+ 用户的提问将在中给出。
+ 注意:
+ 1.输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。
+ 2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。
+ 3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。
+ 4.不要对脚注本身进行解释或说明。
+ 5.请不要使用中的文档的id作为脚注。
+
+
+
+
+
+ openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。
+
+
+ openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。
+
+
+
+
+ openEuler社区的成员来自世界各地,包括开发者、用户和企业。
+
+
+ openEuler社区的成员共同努力,推动开源操作系统的发展,并且为用户提供支持和帮助。
+
+
+
+
+ openEuler社区的目标是什么?
+
+
+ openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。[[1]] openEuler社区的目标是为\
+用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]]
+
+
+
+
+ {bac_info}
+
+
+ {user_question}
+
+ """,
+ LanguageType.ENGLISH: r"""
+
+ You are a helpful assistant of openEuler community. Please answer the user's question based on the \
+given background information and add footnotes after the related sentences.
+ An example will be given in .
+ The background information will be given in .
+ The user's question will be given in .
+ Note:
+ 1. Do not include any XML tags in the output, and do not make up any information. If you think the \
+user's question is unrelated to the background information, please ignore the background information and directly \
+answer.
+ 2. Your response should not exceed 250 words.
+
+
+
+
+
+ openEuler community is an open source operating system community, committed to promoting \
+the development of the Linux operating system.
+
+
+ openEuler community aims to provide users with a stable, secure, and efficient operating \
+system platform, and support multiple hardware architectures.
+
+
+
+
+ Members of the openEuler community come from all over the world, including developers, \
+users, and enterprises.
+
+
+ Members of the openEuler community work together to promote the development of open \
+source operating systems, and provide support and assistance to users.
+
+
+
+
+ What is the goal of openEuler community?
+
+
+ openEuler community is an open source operating system community, committed to promoting the \
+development of the Linux operating system. [[1]] openEuler community aims to provide users with a stable, secure, \
+and efficient operating system platform, and support multiple hardware architectures. [[1]]
+
+
+
+
+ {bac_info}
+
+
+ {user_question}
+
+ """,
+ }
diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py
index e4a12bf1eb0b7a6504390e31facb3d258f746301..740110900f8140f26408a2609c25ff466931918b 100644
--- a/apps/scheduler/call/rag/rag.py
+++ b/apps/scheduler/call/rag/rag.py
@@ -54,6 +54,11 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput):
async def _init(self, call_vars: CallVars) -> RAGInput:
"""初始化RAG工具"""
+ if not call_vars.ids.session_id:
+ err = "[RAG] 未设置Session ID"
+ logger.error(err)
+ raise CallError(message=err, data={})
+
return RAGInput(
session_id=call_vars.ids.session_id,
kbIds=self.knowledge_base_ids,
@@ -74,8 +79,6 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput):
question_obj = QuestionRewrite()
question = await question_obj.generate(question=data.question)
data.question = question
- self.tokens.input_tokens += question_obj.input_tokens
- self.tokens.output_tokens += question_obj.output_tokens
url = config.rag.rag_service.rstrip("/") + "/chunk/search"
headers = {
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index ddfcaa01ccca519611f5f8e967dbd6621d96ff70..92520f513f03d3b1e32f2e45d3c63308f583e0a7 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -97,7 +97,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
name=executor.step.step.name,
description=executor.step.step.description,
facts=executor.background.facts,
- summary=executor.runtime.reasoning,
+ summary=executor.task.runtime.reasoning,
node=node,
**kwargs,
)
diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py
index d67dc385f8ba882ac269936d47d88f559762db73..f58eae0ebe0b2fc7b555215ad3f85535d7bb4959 100644
--- a/apps/scheduler/call/suggest/suggest.py
+++ b/apps/scheduler/call/suggest/suggest.py
@@ -69,11 +69,11 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO
context = [
{
"role": "user",
- "content": executor.runtime.userInput,
+ "content": executor.task.runtime.userInput,
},
{
"role": "assistant",
- "content": executor.runtime.fullAnswer,
+ "content": executor.task.runtime.fullAnswer,
},
]
obj = cls(
@@ -81,7 +81,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO
description=executor.step.step.description,
node=node,
context=context,
- conversation_id=executor.task.conversationId,
+ conversation_id=executor.task.metadata.conversationId,
**kwargs,
)
await obj._set_input(executor)
diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py
index 6f1956dc7b32339aef7f99de71197c06a54bf391..e0abafa45f6e1d235d6cd90f9e1b60e6ba71d8ea 100644
--- a/apps/scheduler/call/summary/summary.py
+++ b/apps/scheduler/call/summary/summary.py
@@ -64,8 +64,6 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput):
"""执行工具"""
summary_obj = ExecutorSummary()
summary = await summary_obj.generate(background=self.context, language=self._sys_vars.language)
- self.tokens.input_tokens += summary_obj.input_tokens
- self.tokens.output_tokens += summary_obj.output_tokens
yield CallOutputChunk(type=CallOutputType.TEXT, content=summary)
@@ -77,5 +75,5 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput):
if not isinstance(content, str):
err = "[SummaryCall] 工具输出格式错误"
raise TypeError(err)
- executor.runtime.reasoning = content
+ executor.task.runtime.reasoning = content
yield chunk
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index e316061af84d526c7dbf220cec5efe8885875777..2b3f25a49fa9cd3473afd9a74e985641de8d1986 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -106,6 +106,7 @@ class FlowExecutor(BaseExecutor):
step=self.current_step,
background=self.background,
question=self.question,
+ llm=self.llm,
)
# 初始化步骤
diff --git a/apps/scheduler/executor/qa.py b/apps/scheduler/executor/qa.py
index 4c1a485d5dde73bca9a616f42f020af66790d068..d5d8190d90cf39d9e4592a319fe4ce822298452f 100644
--- a/apps/scheduler/executor/qa.py
+++ b/apps/scheduler/executor/qa.py
@@ -1,12 +1,21 @@
"""用于执行智能问答的Executor"""
+import logging
import uuid
+from datetime import UTC, datetime
+from textwrap import dedent
from apps.models.document import Document
+from apps.schemas.enum_var import EventType
+from apps.schemas.message import DocumentAddContent, TextAddContent
+from apps.schemas.rag_data import RAGEventData
from apps.schemas.record import RecordDocument
from apps.services.document import DocumentManager
+from apps.services.rag import RAG
from .base import BaseExecutor
+_logger = logging.getLogger(__name__)
+
class QAExecutor(BaseExecutor):
"""用于执行智能问答的Executor"""
@@ -21,6 +30,66 @@ class QAExecutor(BaseExecutor):
doc_ids += [doc.id for doc in docs]
return docs, doc_ids
+
+ async def _push_rag_chunk(self, content: str) -> RAGEventData | None:
+ """推送RAG单个消息块"""
+ # 如果是换行
+ if not content or not content.rstrip().rstrip("\n"):
+ return None
+
+ try:
+ content_obj = RAGEventData.model_validate_json(dedent(content[6:]).rstrip("\n"))
+ # 如果是空消息
+ if not content_obj.content:
+ return None
+
+ # 推送消息
+ if content_obj.event_type == EventType.TEXT_ADD.value:
+ await self.msg_queue.push_output(
+ task=self.task,
+ event_type=content_obj.event_type,
+ data=TextAddContent(text=content_obj.content).model_dump(exclude_none=True, by_alias=True),
+ )
+ elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
+ await self.msg_queue.push_output(
+ task=self.task,
+ event_type=content_obj.event_type,
+ data=DocumentAddContent(
+ documentId=content_obj.content.get("id", ""),
+ documentOrder=content_obj.content.get("order", 0),
+ documentAuthor=content_obj.content.get("author", ""),
+ documentName=content_obj.content.get("name", ""),
+ documentAbstract=content_obj.content.get("abstract", ""),
+ documentType=content_obj.content.get("extension", ""),
+ documentSize=content_obj.content.get("size", 0),
+ createdAt=round(content_obj.content.get("created_at", datetime.now(tz=UTC).timestamp()), 3),
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ except Exception:
+ _logger.exception("[Scheduler] RAG服务返回错误数据")
+ return None
+ else:
+ return content_obj
+
async def run(self) -> None:
"""运行QA"""
- pass
+ full_answer = ""
+
+ try:
+ async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data):
+ task, content_obj = await self._push_rag_chunk(task, queue, chunk)
+ if not content_obj:
+ continue
+ if content_obj.event_type == EventType.TEXT_ADD.value:
+ # 如果是文本消息,直接拼接到答案中
+ full_answer += content_obj.content
+ elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
+ task.runtime.documents.append(content_obj.content)
+ task.state.flow_status = ExecutorStatus.SUCCESS
+ except Exception as e:
+ logger.error(f"[Scheduler] RAG服务发生错误: {e}")
+ task.state.flow_status = ExecutorStatus.ERROR
+ # 保存答案
+ task.runtime.answer = full_answer
+ task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time
+ await TaskManager.save_task(task.id, task)
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
index cbecd9293d47839335d58c820f3a883de2c88a48..6bdf01c896f2a6885ba078d1fa5341b2202746c4 100644
--- a/apps/scheduler/executor/step.py
+++ b/apps/scheduler/executor/step.py
@@ -26,7 +26,6 @@ from apps.schemas.enum_var import (
from apps.schemas.message import TextAddContent
from apps.schemas.scheduler import CallError, CallOutputChunk
from apps.services.node import NodeManager
-from apps.services.task import TaskManager
from .base import BaseExecutor
@@ -162,11 +161,6 @@ class StepExecutor(BaseExecutor):
iterator = slot_obj.exec(self, slot_obj.input)
async for chunk in iterator:
result: SlotOutput = SlotOutput.model_validate(chunk.content)
- await TaskManager.update_task_token(
- self.task.metadata.id,
- input_token=slot_obj.tokens.input_tokens,
- output_token=slot_obj.tokens.output_tokens,
- )
# 如果没有填全,则状态设置为待填参
if result.remaining_schema:
@@ -181,11 +175,6 @@ class StepExecutor(BaseExecutor):
# 恢复State
self.task.state.stepId = current_step_id
self.task.state.stepName = current_step_name
- await TaskManager.update_task_token(
- self.task.metadata.id,
- input_token=self.obj.tokens.input_tokens,
- output_token=self.obj.tokens.output_tokens,
- )
async def _process_chunk(
@@ -260,11 +249,6 @@ class StepExecutor(BaseExecutor):
# 更新执行状态
self.task.state.stepStatus = StepStatus.SUCCESS
- await TaskManager.update_task_token(
- self.task.metadata.id,
- input_token=self.obj.tokens.input_tokens,
- output_token=self.obj.tokens.output_tokens,
- )
self.task.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.task.runtime.time
# 更新history
diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py
index 78b3b75720b78603c5cb3cb9fac5e90ae3e57c28..8634f97290713ac7cfa12e9c289ca59e9ebe7ab2 100644
--- a/apps/scheduler/pool/loader/call.py
+++ b/apps/scheduler/pool/loader/call.py
@@ -8,7 +8,7 @@ from sqlalchemy import delete
import apps.scheduler.call as system_call
from apps.common.postgres import postgres
from apps.common.singleton import SingletonMeta
-from apps.llm.embedding import Embedding, VectorBase
+from apps.llm.embedding import Embedding
from apps.models.node import NodeInfo
from apps.schemas.scheduler import CallInfo
@@ -31,14 +31,12 @@ class CallLoader(metaclass=SingletonMeta):
return call_metadata
- # 更新数据库
- async def _add_to_db(self, call_metadata: dict[str, CallInfo]) -> None:
- """更新数据库"""
+ # 将数据插入数据库
+ async def _add_data_to_db(self, call_metadata: dict[str, CallInfo]) -> None:
+ """将数据插入数据库"""
# 清除旧数据
async with postgres.session() as session:
await session.execute(delete(NodeInfo).where(NodeInfo.serviceId == None)) # noqa: E711
- NodePoolVector = VectorBase.metadata.tables["framework_node_vector"] # noqa: N806
- await session.execute(delete(NodePoolVector).where(NodePoolVector.serviceId == None)) # noqa: E711
# 更新数据库
call_descriptions = []
@@ -55,18 +53,26 @@ class CallLoader(metaclass=SingletonMeta):
))
call_descriptions.append(call.description)
- # 进行向量化
- call_vecs = await Embedding.get_embedding(call_descriptions)
- vector_data = []
+ await session.commit()
+
+
+ # 将向量化数据存入数据库
+ async def _add_vector_to_db(
+ self, call_metadata: dict[str, CallInfo], embedding_model: Embedding,
+ ) -> None:
+ """将向量化数据存入数据库"""
+ async with postgres.session() as session:
+ # 删除旧数据
+ await session.execute(
+ delete(embedding_model.NodePoolVector).where(embedding_model.NodePoolVector.serviceId == None), # noqa: E711
+ )
+
+ call_vecs = await embedding_model.get_embedding([call.description for call in call_metadata.values()])
for call_id, vec in zip(call_metadata.keys(), call_vecs, strict=True):
- vector_data.append(
- NodePoolVector(
- id=call_id,
- serviceId=None,
- embedding=vec,
- ),
- )
- session.add_all(vector_data)
+ session.add(embedding_model.NodePoolVector(
+ id=call_id,
+ embedding=vec,
+ ))
await session.commit()
@@ -81,4 +87,4 @@ class CallLoader(metaclass=SingletonMeta):
raise RuntimeError(err) from e
# 更新数据库
- await self._add_to_db(sys_call_metadata)
+ await self._add_data_to_db(sys_call_metadata)
diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py
index ab74a79100f979911acd976030457f28977045b1..ef95968631107238db2ff3ea6592b3783b16ea5c 100644
--- a/apps/scheduler/pool/loader/flow.py
+++ b/apps/scheduler/pool/loader/flow.py
@@ -147,7 +147,7 @@ class FlowLoader:
description=flow_config.description,
enabled=True,
path=str(flow_path),
- debug=flow_config.debug,
+ debug=flow_config.checkStatus.debug,
),
)
return Flow.model_validate(flow_yaml)
@@ -180,13 +180,13 @@ class FlowLoader:
description=flow.description,
enabled=True,
path=str(flow_path),
- debug=flow.debug,
+ debug=flow.checkStatus.debug,
),
)
@staticmethod
- async def delete(app_id: uuid.UUID, flow_id: str) -> None:
+ async def delete(app_id: uuid.UUID, flow_id: str, embedding_model: Embedding | None = None) -> None:
"""删除指定工作流文件"""
flow_path = BASE_PATH / str(app_id) / "flow" / f"{flow_id}.yaml"
# 确保目标为文件且存在
@@ -201,7 +201,10 @@ class FlowLoader:
FlowInfo.id == flow_id,
),
))
- await session.execute(delete(FlowPoolVector).where(FlowPoolVector.id == flow_id))
+ if embedding_model:
+ await session.execute(
+ delete(embedding_model.FlowPoolVector).where(embedding_model.FlowPoolVector.id == flow_id),
+ )
await session.commit()
return
logger.warning("[FlowLoader] 工作流文件不存在或不是文件:%s", flow_path)
@@ -228,7 +231,6 @@ class FlowLoader:
AppHashes.filePath == f"flow/{metadata.id}.yaml",
),
))
- await session.execute(delete(FlowPoolVector).where(FlowPoolVector.id == metadata.id))
# 创建新的Flow数据
session.add(metadata)
@@ -243,15 +245,21 @@ class FlowLoader:
filePath=f"flow/{metadata.id}.yaml",
)
session.add(flow_hash)
+ await session.commit()
- # 进行向量化
- service_embedding = await Embedding.get_embedding([metadata.description])
- vector_data = [
- FlowPoolVector(
- id=metadata.id,
- appId=app_id,
- embedding=service_embedding[0],
- ),
- ]
- session.add_all(vector_data)
+ @staticmethod
+ async def _update_vector(app_id: uuid.UUID, metadata: FlowInfo, embedding_model: Embedding) -> None:
+ """将向量化数据存入数据库"""
+ async with postgres.session() as session:
+ await session.execute(
+ delete(embedding_model.FlowPoolVector).where(embedding_model.FlowPoolVector.id == metadata.id),
+ )
+
+ # 获取向量数据
+ service_embedding = await embedding_model.get_embedding([metadata.description])
+ session.add(embedding_model.FlowPoolVector(
+ id=metadata.id,
+ appId=app_id,
+ embedding=service_embedding[0],
+ ))
await session.commit()
diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py
index 3fac807119f99e48e28cfb5e4fd19a6581727f84..5074133ef4b42c1e641a041a571adf17237080c8 100644
--- a/apps/scheduler/pool/loader/mcp.py
+++ b/apps/scheduler/pool/loader/mcp.py
@@ -315,29 +315,29 @@ class MCPLoader(metaclass=SingletonMeta):
session.add_all(tool_list)
await session.commit()
- # 服务本身向量化
- embedding = await Embedding.get_embedding([config.description])
- async with postgres.session() as session:
- # 删除旧的向量
- await session.execute(delete(MCPVector).where(MCPVector.id == mcp_id))
- # 插入新的向量
- session.add(MCPVector(
- id=mcp_id,
- embedding=embedding[0],
- ))
- await session.commit()
-
- # 工具向量化
+ @staticmethod
+ async def _insert_template_tool_vector(mcp_id: str, config: MCPServerConfig, embedding_model: Embedding) -> None:
+ """插入MCP相关的向量数据"""
+ # 获取工具列表
+ tool_list = await MCPLoader._get_template_tool(mcp_id, config)
tool_desc_list = [tool.description for tool in tool_list]
- tool_embedding = await Embedding.get_embedding(tool_desc_list)
+ mcp_embedding = await embedding_model.get_embedding([config.description])
+ tool_embedding = await embedding_model.get_embedding(tool_desc_list)
async with postgres.session() as session:
- # 删除旧的工具向量
- await session.execute(delete(MCPToolVector).where(MCPToolVector.mcpId == mcp_id))
- # 插入新的工具向量
+ # 删除旧数据
+ await session.execute(delete(embedding_model.MCPVector).where(embedding_model.MCPVector.id == mcp_id))
+ await session.execute(
+ delete(embedding_model.MCPToolVector).where(embedding_model.MCPToolVector.mcpId == mcp_id),
+ )
+ # 插入新数据
+ session.add(embedding_model.MCPVector(
+ id=mcp_id,
+ embedding=mcp_embedding[0],
+ ))
for tool, embedding in zip(tool_list, tool_embedding, strict=True):
- session.add(MCPToolVector(
+ session.add(embedding_model.MCPToolVector(
id=tool.id,
mcpId=mcp_id,
embedding=embedding,
@@ -522,7 +522,7 @@ class MCPLoader(metaclass=SingletonMeta):
@staticmethod
- async def remove_deleted_mcp(deleted_mcp_list: list[str]) -> None:
+ async def remove_deleted_mcp(deleted_mcp_list: list[str], embedding_model: Embedding | None = None) -> None:
"""
删除无效的MCP在数据库中的记录
@@ -545,12 +545,17 @@ class MCPLoader(metaclass=SingletonMeta):
logger.info("[MCPLoader] 清除数据库中无效的MCP")
# 删除MCP的向量化数据
- async with postgres.session() as session:
- for mcp_id in deleted_mcp_list:
- await session.execute(delete(MCPVector).where(MCPVector.id == mcp_id))
- await session.execute(delete(MCPToolVector).where(MCPToolVector.mcpId == mcp_id))
- await session.commit()
- logger.info("[MCPLoader] 清除数据库中无效的MCP向量化数据")
+ if embedding_model:
+ async with postgres.session() as session:
+ for mcp_id in deleted_mcp_list:
+ await session.execute(
+ delete(embedding_model.MCPVector).where(embedding_model.MCPVector.id == mcp_id),
+ )
+ await session.execute(
+ delete(embedding_model.MCPToolVector).where(embedding_model.MCPToolVector.mcpId == mcp_id),
+ )
+ await session.commit()
+ logger.info("[MCPLoader] 清除数据库中无效的MCP向量化数据")
@staticmethod
diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py
index fb73ea299f84e8a164de08501648c8b5b0ff40c0..3a0ce5f944f7c1a344e5aea79db3d62cfc5f93fa 100644
--- a/apps/scheduler/pool/loader/metadata.py
+++ b/apps/scheduler/pool/loader/metadata.py
@@ -63,7 +63,7 @@ class MetadataLoader:
raise RuntimeError(err) from e
elif metadata_type == MetadataType.SERVICE.value:
try:
- metadata = ServiceMetadata(id=file_path.parent.name, **metadata_dict)
+ metadata = ServiceMetadata(id=uuid.UUID(file_path.parent.name), **metadata_dict)
except Exception as e:
err = "[MetadataLoader] Service metadata.yaml格式错误"
logger.exception(err)
diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py
index 802ec35050ae860727dedb2e354c277853330559..950491e894399858a0edbc2d395f498bd1c6da9f 100644
--- a/apps/scheduler/pool/loader/service.py
+++ b/apps/scheduler/pool/loader/service.py
@@ -13,7 +13,6 @@ from apps.common.postgres import postgres
from apps.llm.embedding import Embedding
from apps.models.node import NodeInfo
from apps.models.service import Service, ServiceACL, ServiceHashes
-from apps.models.vectors import NodePoolVector, ServicePoolVector
from apps.scheduler.pool.check import FileChecker
from apps.schemas.flow import PermissionType, ServiceMetadata
@@ -75,15 +74,25 @@ class ServiceLoader:
@staticmethod
- async def delete(service_id: uuid.UUID, *, is_reload: bool = False) -> None:
+ async def delete(
+ service_id: uuid.UUID, embedding_model: Embedding | None = None, *, is_reload: bool = False,
+ ) -> None:
"""删除Service,并更新数据库"""
async with postgres.session() as session:
await session.execute(delete(Service).where(Service.id == service_id))
await session.execute(delete(NodeInfo).where(NodeInfo.serviceId == service_id))
await session.execute(delete(ServiceACL).where(ServiceACL.serviceId == service_id))
await session.execute(delete(ServiceHashes).where(ServiceHashes.serviceId == service_id))
- await session.execute(delete(ServicePoolVector).where(ServicePoolVector.id == service_id))
- await session.execute(delete(NodePoolVector).where(NodePoolVector.serviceId == service_id))
+
+ if embedding_model:
+ await session.execute(
+ delete(embedding_model.ServicePoolVector).where(embedding_model.ServicePoolVector.id == service_id),
+ )
+ await session.execute(
+ delete(
+ embedding_model.NodePoolVector,
+ ).where(embedding_model.NodePoolVector.serviceId == service_id),
+ )
await session.commit()
if not is_reload:
@@ -120,30 +129,31 @@ class ServiceLoader:
session.add(node)
await session.commit()
- # 删除旧的向量数据
- async with postgres.session() as session:
- await session.execute(delete(ServicePoolVector).where(ServicePoolVector.id == metadata.id))
- await session.execute(delete(NodePoolVector).where(NodePoolVector.serviceId == metadata.id))
- await session.commit()
-
- # 进行向量化,更新postgres
- service_vecs = await Embedding.get_embedding([metadata.description])
- async with postgres.session() as session:
- pool_data = ServicePoolVector(
- id=metadata.id,
- embedding=service_vecs[0],
- )
- session.add(pool_data)
- await session.commit()
+ @staticmethod
+ async def _update_vector(nodes: list[NodeInfo], metadata: ServiceMetadata, embedding_model: Embedding) -> None:
+ """更新向量数据"""
+ service_vecs = await embedding_model.get_embedding([metadata.description])
node_descriptions = []
for node in nodes:
node_descriptions += [node.description]
- node_vecs = await Embedding.get_embedding(node_descriptions)
+ node_vecs = await embedding_model.get_embedding(node_descriptions)
async with postgres.session() as session:
+ # 删除旧数据
+ await session.execute(
+ delete(embedding_model.ServicePoolVector).where(embedding_model.ServicePoolVector.id == metadata.id),
+ )
+ await session.execute(
+ delete(embedding_model.NodePoolVector).where(embedding_model.NodePoolVector.serviceId == metadata.id),
+ )
+ # 插入新数据
+ session.add(embedding_model.ServicePoolVector(
+ id=metadata.id,
+ embedding=service_vecs[0],
+ ))
for vec in node_vecs:
- node_data = NodePoolVector(
+ node_data = embedding_model.NodePoolVector(
id=node.id,
serviceId=metadata.id,
embedding=vec,
diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py
index dbb06682bb29fb5bf91808991a477d023830eec5..986bbc44f89f873617cd3786e0d8957755c81911 100644
--- a/apps/scheduler/scheduler/context.py
+++ b/apps/scheduler/scheduler/context.py
@@ -7,13 +7,13 @@ from datetime import UTC, datetime
from apps.common.security import Security
from apps.models.record import Record, RecordMetadata
+from apps.scheduler.scheduler import Scheduler
from apps.schemas.enum_var import StepStatus
from apps.schemas.record import (
FlowHistory,
RecordContent,
RecordGroupDocument,
)
-from apps.schemas.request_data import RequestData
from apps.services.appcenter import AppCenterManager
from apps.services.document import DocumentManager
from apps.services.record import RecordManager
@@ -22,7 +22,7 @@ from apps.services.task import TaskManager
logger = logging.getLogger(__name__)
-async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
+async def save_data(scheduler: "Scheduler") -> None:
"""保存当前Executor、Task、Record等的数据"""
# 构造RecordContent
used_docs = []
diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py
index da40cc718487b116dc2592f51ff4cb7644542700..62041c1d7375ffbf8479bd0b8466ef1af9b772ea 100644
--- a/apps/scheduler/scheduler/flow.py
+++ b/apps/scheduler/scheduler/flow.py
@@ -7,7 +7,6 @@ import uuid
from apps.llm.patterns import Select
from apps.scheduler.pool.pool import Pool
from apps.schemas.request_data import RequestDataApp
-from apps.services.task import TaskManager
logger = logging.getLogger(__name__)
@@ -38,7 +37,4 @@ class FlowChooser:
"description": f"{flow.name}, {flow.description}",
} for flow in flow_list]
select_obj = Select()
- top_flow = await select_obj.generate(question=self._question, choices=choices)
-
- await TaskManager.update_task_token(self.task_id, select_obj.input_tokens, select_obj.output_tokens)
- return top_flow
+ return await select_obj.generate(question=self._question, choices=choices)
diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py
deleted file mode 100644
index 25f2f6d4ead4e76fb221978a5efba0ac2109db81..0000000000000000000000000000000000000000
--- a/apps/scheduler/scheduler/message.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""Scheduler消息推送"""
-
-import logging
-from datetime import UTC, datetime
-from textwrap import dedent
-
-from apps.common.config import config
-from apps.common.queue import MessageQueue
-from apps.models.document import Document
-from apps.models.task import Task
-from apps.schemas.enum_var import EventType, ExecutorStatus
-from apps.schemas.message import (
- DocumentAddContent,
- InitContent,
- InitContentFeature,
- TextAddContent,
-)
-from apps.schemas.rag_data import RAGEventData, RAGQueryReq
-from apps.schemas.record import RecordDocument
-from apps.services.rag import RAG
-from apps.services.task import TaskManager
-
-logger = logging.getLogger(__name__)
-
-
-async def push_init_message(
- task: Task, queue: MessageQueue, context_num: int, *, is_flow: bool = False,
-) -> Task:
- """推送初始化消息"""
- # 组装feature
- if is_flow:
- feature = InitContentFeature(
- maxTokens=config.llm.max_tokens or 0,
- contextNum=context_num,
- enableFeedback=False,
- enableRegenerate=False,
- )
- else:
- feature = InitContentFeature(
- maxTokens=config.llm.max_tokens or 0,
- contextNum=context_num,
- enableFeedback=True,
- enableRegenerate=True,
- )
-
- # 保存必要信息到Task
- created_at = round(datetime.now(UTC).timestamp(), 3)
- task.tokens.time = created_at
-
- await TaskManager.save_task(task.id, task)
- # 推送初始化消息
- await queue.push_output(
- task=task,
- event_type=EventType.INIT.value,
- data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True),
- )
- return task
-
-
-async def push_rag_message(
- task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]],
- doc_ids: list[str],
- rag_data: RAGQueryReq) -> None:
- """推送RAG消息"""
- full_answer = ""
-
- try:
- async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data):
- task, content_obj = await _push_rag_chunk(task, queue, chunk)
- if not content_obj:
- continue
- if content_obj.event_type == EventType.TEXT_ADD.value:
- # 如果是文本消息,直接拼接到答案中
- full_answer += content_obj.content
- elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
- task.runtime.documents.append(content_obj.content)
- task.state.flow_status = ExecutorStatus.SUCCESS
- except Exception as e:
- logger.error(f"[Scheduler] RAG服务发生错误: {e}")
- task.state.flow_status = ExecutorStatus.ERROR
- # 保存答案
- task.runtime.answer = full_answer
- task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time
- await TaskManager.save_task(task.id, task)
-
-
-async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData | None]:
- """推送RAG单个消息块"""
- # 如果是换行
- if not content or not content.rstrip().rstrip("\n"):
- return task, None
-
- try:
- content_obj = RAGEventData.model_validate_json(dedent(content[6:]).rstrip("\n"))
- # 如果是空消息
- if not content_obj.content:
- return task, None
-
- await TaskManager.update_task_token(
- task.id,
- content_obj.input_tokens,
- content_obj.output_tokens,
- override=True,
- )
-
- await TaskManager.save_task(task.id, task)
- # 推送消息
- if content_obj.event_type == EventType.TEXT_ADD.value:
- await queue.push_output(
- task=task,
- event_type=content_obj.event_type,
- data=TextAddContent(text=content_obj.content).model_dump(exclude_none=True, by_alias=True),
- )
- elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
- await queue.push_output(
- task=task,
- event_type=content_obj.event_type,
- data=DocumentAddContent(
- documentId=content_obj.content.get("id", ""),
- documentOrder=content_obj.content.get("order", 0),
- documentAuthor=content_obj.content.get("author", ""),
- documentName=content_obj.content.get("name", ""),
- documentAbstract=content_obj.content.get("abstract", ""),
- documentType=content_obj.content.get("extension", ""),
- documentSize=content_obj.content.get("size", 0),
- createdAt=round(content_obj.content.get("created_at", datetime.now(tz=UTC).timestamp()), 3),
- ).model_dump(exclude_none=True, by_alias=True),
- )
- except Exception:
- logger.exception("[Scheduler] RAG服务返回错误数据")
- return task, None
- else:
- return task, content_obj
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index 945f4e72e312348c6c7edeff078f66e1fe6e1a61..3de6392a32a44095e388d4206050b0829dbff133 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -4,6 +4,7 @@
import asyncio
import logging
import uuid
+from datetime import UTC, datetime
from apps.common.queue import MessageQueue
from apps.llm.embedding import Embedding
@@ -15,13 +16,12 @@ from apps.models.user import User
from apps.scheduler.executor.agent import MCPAgentExecutor
from apps.scheduler.executor.flow import FlowExecutor
from apps.scheduler.pool.pool import Pool
-from apps.scheduler.scheduler.context import get_context, get_docs
from apps.scheduler.scheduler.flow import FlowChooser
-from apps.scheduler.scheduler.message import (
- push_init_message,
- push_rag_message,
-)
from apps.schemas.enum_var import AppType, EventType, ExecutorStatus
+from apps.schemas.message import (
+ InitContent,
+ InitContentFeature,
+)
from apps.schemas.rag_data import RAGQueryReq
from apps.schemas.request_data import RequestData
from apps.schemas.scheduler import ExecutorBackground, LLMConfig
@@ -86,6 +86,37 @@ class Scheduler:
self.task = task
+ async def push_init_message(
+ self, context_num: int, *, is_flow: bool = False,
+ ) -> None:
+ """推送初始化消息"""
+ # 组装feature
+ if is_flow:
+ feature = InitContentFeature(
+ maxTokens=self.llm.reasoning.config.maxToken or 0,
+ contextNum=context_num,
+ enableFeedback=False,
+ enableRegenerate=False,
+ )
+ else:
+ feature = InitContentFeature(
+ maxTokens=self.llm.reasoning.config.maxToken or 0,
+ contextNum=context_num,
+ enableFeedback=True,
+ enableRegenerate=True,
+ )
+
+ # 保存必要信息到Task
+ created_at = round(datetime.now(UTC).timestamp(), 3)
+ self.task.runtime.time = created_at
+
+ # 推送初始化消息
+ await self.queue.push_output(
+ task=self.task,
+ event_type=EventType.INIT.value,
+ data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True),
+ )
+
async def _monitor_activity(self, kill_event: asyncio.Event, user_sub: str) -> None:
"""监控用户活动状态,不活跃时终止工作流"""
try:
diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py
index d127a6eefa606514d6e30db3f49a0a08b350bdf2..8b94684d13f12809d6d1b1638a6d789ff93da0c0 100644
--- a/apps/schemas/mcp.py
+++ b/apps/schemas/mcp.py
@@ -37,7 +37,7 @@ class MCPServerStdioConfig(MCPBasicConfig):
class MCPServerSSEConfig(MCPBasicConfig):
"""MCP 服务器配置"""
- url: str = Field(description="MCP 服务器地址", default="http://example.com/sse", pattern=r"^https?://.+/sse$")
+ url: str = Field(description="MCP 服务器地址", default="http://example.com/sse", pattern=r"^https?://.+$")
class MCPServerItem(BaseModel):
diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py
index 0d6ebb985c7eb6abef893a510b4129bd8008429b..068d03cef2c446ffdf0569ac94f61796b7e90652 100644
--- a/apps/schemas/scheduler.py
+++ b/apps/schemas/scheduler.py
@@ -34,7 +34,7 @@ class CallIds(BaseModel):
task_id: uuid.UUID = Field(description="任务ID")
executor_id: str = Field(description="Flow ID")
- session_id: str = Field(description="当前用户的Session ID")
+ session_id: str | None = Field(description="当前用户的Session ID")
app_id: uuid.UUID = Field(description="当前应用的ID")
user_sub: str = Field(description="当前用户的用户ID")
@@ -50,13 +50,6 @@ class CallVars(BaseModel):
language: LanguageType = Field(description="语言", default=LanguageType.CHINESE)
-class CallTokens(BaseModel):
- """Call的Tokens"""
-
- input_tokens: int = Field(description="输入的Tokens", default=0)
- output_tokens: int = Field(description="输出的Tokens", default=0)
-
-
class ExecutorBackground(BaseModel):
"""Executor的背景信息"""
diff --git a/apps/services/rag.py b/apps/services/rag.py
index af5bd506b532ef10db4d4a885bf96145a6502e76..43e35de8b0bd69acd7705edb8a27032f82d5acc3 100644
--- a/apps/services/rag.py
+++ b/apps/services/rag.py
@@ -15,7 +15,6 @@ from apps.common.config import config
from apps.llm.patterns.rewrite import QuestionRewrite
from apps.llm.reasoning import ReasoningLLM
from apps.llm.token import TokenCalculator
-from apps.models.llm import LLMData
from apps.schemas.enum_var import EventType, LanguageType
from apps.schemas.rag_data import RAGQueryReq
from apps.services.llm import LLMManager
@@ -28,104 +27,6 @@ CHUNK_ELEMENT_TOKENS = 5
class RAG:
"""调用RAG服务,获取知识库答案"""
- system_prompt: str = "You are a helpful assistant."
- """系统提示词"""
- user_prompt: dict[LanguageType, str] = {
- LanguageType.CHINESE: r"""
-
- 你是openEuler社区的智能助手。请结合给出的背景信息, 回答用户的提问,并且基于给出的背景信息在相关句子后进行脚注。
- 一个例子将在中给出。
- 上下文背景信息将在中给出。
- 用户的提问将在中给出。
- 注意:
- 1.输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。
- 2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。
- 3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。
- 4.不要对脚注本身进行解释或说明。
- 5.请不要使用中的文档的id作为脚注。
-
-
-
-
-
- openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。
-
-
- openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。
-
-
-
-
- openEuler社区的成员来自世界各地,包括开发者、用户和企业。
-
-
- openEuler社区的成员共同努力,推动开源操作系统的发展,并且为用户提供支持和帮助。
-
-
-
-
- openEuler社区的目标是什么?
-
-
- openEuler社区是一个开源操作系统社区,致力于推动Linux操作系统的发展。[[1]]
- openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]]
-
-
-
-
- {bac_info}
-
-
- {user_question}
-
- """,
- LanguageType.ENGLISH: r"""
-
- You are a helpful assistant of openEuler community. Please answer the user's question based on the given background information and add footnotes after the related sentences.
- An example will be given in .
- The background information will be given in .
- The user's question will be given in .
- Note:
- 1. Do not include any XML tags in the output, and do not make up any information. If you think the user's question is unrelated to the background information, please ignore the background information and directly answer.
- 2. Your response should not exceed 250 words.
-
-
-
-
-
- openEuler community is an open source operating system community, committed to promoting the development of the Linux operating system.
-
-
- openEuler community aims to provide users with a stable, secure, and efficient operating system platform, and support multiple hardware architectures.
-
-
-
-
- Members of the openEuler community come from all over the world, including developers, users, and enterprises.
-
-
- Members of the openEuler community work together to promote the development of open source operating systems, and provide support and assistance to users.
-
-
-
-
- What is the goal of openEuler community?
-
-
- openEuler community is an open source operating system community, committed to promoting the development of the Linux operating system. [[1]]
- openEuler community aims to provide users with a stable, secure, and efficient operating system platform, and support multiple hardware architectures. [[1]]
-
-
-
-
- {bac_info}
-
-
- {user_question}
-
- """,
- }
-
@staticmethod
async def get_doc_info_from_rag(
user_sub: str, max_tokens: int | None, doc_ids: list[str], data: RAGQueryReq,