From 9d5d21515eead3e19cb7d015bb2a1bc6ee4f1ac6 Mon Sep 17 00:00:00 2001 From: Loshawn <2428333123@qq.com> Date: Thu, 4 Sep 2025 11:57:25 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9SQL=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/sql/schema.py | 4 +- apps/scheduler/call/sql/sql.py | 136 +++++++++++------------------- 2 files changed, 51 insertions(+), 89 deletions(-) diff --git a/apps/scheduler/call/sql/schema.py b/apps/scheduler/call/sql/schema.py index 06ffb4f0..c3b57c52 100644 --- a/apps/scheduler/call/sql/schema.py +++ b/apps/scheduler/call/sql/schema.py @@ -1,7 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """SQL工具的输入输出""" -from typing import Any +from typing import Any, Optional from pydantic import Field @@ -17,5 +17,5 @@ class SQLInput(DataBase): class SQLOutput(DataBase): """SQL工具的输出""" - dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") + result: list[dict[str, Any]] = Field(description="SQL工具的执行结果") sql: str = Field(description="SQL语句") diff --git a/apps/scheduler/call/sql/sql.py b/apps/scheduler/call/sql/sql.py index 6eab8443..cf0c6f67 100644 --- a/apps/scheduler/call/sql/sql.py +++ b/apps/scheduler/call/sql/sql.py @@ -5,6 +5,7 @@ import logging from collections.abc import AsyncGenerator from typing import Any, ClassVar +from urllib.parse import urlparse import httpx from fastapi import status from pydantic import Field @@ -37,10 +38,16 @@ MESSAGE = { class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): """SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。""" - database_url: str = Field(description="数据库连接地址") + database_type: str = Field(description="数据库类型",default="postgres") # mysql mongodb opengauss postgres + host: str = Field(description="数据库地址",default="localhost") + port: int = Field(description="数据库端口",default=5432) + username: str = Field(description="数据库用户名",default="root") + password: str = Field(description="数据库密码",default="root") + database: str = Field(description="数据库名称",default="postgres") + table_name_list: list[str] = Field(description="表名列表",default=[]) - top_k: int = Field(description="生成SQL语句数量",default=5) - use_llm_enhancements: bool = Field(description="是否使用大模型增强", default=False) + + i18n_info: ClassVar[dict[str, dict]] = { LanguageType.CHINESE: { @@ -59,100 +66,55 @@ class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): question=call_vars.question, ) + async def _exec( + self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE + ) -> AsyncGenerator[CallOutputChunk, None]: + """运行SQL工具, 支持MySQL, MongoDB, PostgreSQL, OpenGauss""" + + data = SQLInput(**input_data) + + headers = {"Content-Type": "application/json"} - async def _generate_sql(self, data: SQLInput) -> list[dict[str, Any]]: - """生成SQL语句列表""" post_data = { - "database_url": self.database_url, - "table_name_list": self.table_name_list, - "question": data.question, - "topk": self.top_k, - "use_llm_enhancements": self.use_llm_enhancements, + "type": self.database_type, + "host": self.host, + "port": self.port, + "username": self.username, + "password": self.password, + "database": self.database, + "goal": data.question, + "table_list": self.table_name_list, } - headers = {"Content-Type": "application/json"} - sql_list = [] - request_num = 0 - max_request = 5 - - while request_num < max_request and len(sql_list) < self.top_k: - try: - async with httpx.AsyncClient() as client: - response = await client.post( - Config().get_config().extra.sql_url + "/database/sql", - headers=headers, - json=post_data, - timeout=60.0, - ) - request_num += 1 - if response.status_code == status.HTTP_200_OK: - result = response.json() - if result["code"] == status.HTTP_200_OK: - sql_list.extend(result["result"]["sql_list"]) - else: - logger.error("[SQL] 生成失败:%s", response.text) - except Exception: - logger.exception("[SQL] 生成失败") - request_num += 1 - - return sql_list - - - async def _execute_sql( - self, - sql_list: list[dict[str, Any]], - ) -> tuple[list[dict[str, Any]] | None, str | None]: - """执行SQL语句并返回结果""" - headers = {"Content-Type": "application/json"} + try: + async with httpx.AsyncClient() as client: + response = await client.post( + Config().get_config().extra.sql_url + "/sql/handler", + headers=headers, + json=post_data, + timeout=60.0, + ) - for sql_dict in sql_list: - try: - async with httpx.AsyncClient() as client: - response = await client.post( - Config().get_config().extra.sql_url + "/sql/execute", - headers=headers, - json={ - "database_id": sql_dict["database_id"], - "sql": sql_dict["sql"], - }, - timeout=60.0, - ) - if response.status_code == status.HTTP_200_OK: - result = response.json() - if result["code"] == status.HTTP_200_OK: - return result["result"], sql_dict["sql"] - else: - logger.error("[SQL] 调用失败:%s", response.text) - except Exception: - logger.exception("[SQL] 调用失败") - - return None, None + result = response.json() + if response.status_code == status.HTTP_200_OK: + if result["code"] == status.HTTP_200_OK: + result_data = result["result"] + sql_exec_results = result_data.get("execute_result") + sql_exec = result_data.get("sql") + sql_exec_risk = result_data.get("risk") + logger.info("[SQL] 调用成功\n[SQL 语句]: %s\n[SQL 结果]: %s\n[SQL 风险]: %s", sql_exec, sql_exec_results, sql_exec_risk) + + else: + logger.error("[SQL] 调用失败:%s", response.text) + logger.error("[SQL] 错误信息:%s", response["result"]) + except Exception: + logger.exception("[SQL] 调用失败") - async def _exec( - self, input_data: dict[str, Any], language: LanguageType = LanguageType.CHINESE - ) -> AsyncGenerator[CallOutputChunk, None]: - """运行SQL工具""" - data = SQLInput(**input_data) - # 生成SQL语句 - sql_list = await self._generate_sql(data) - if not sql_list: - raise CallError( - message=MESSAGE["invaild"][language], - data={}, - ) - - # 执行SQL语句 - sql_exec_results, sql_exec = await self._execute_sql(sql_list) - if sql_exec_results is None or sql_exec is None: - raise CallError( - message=MESSAGE["fail"][language], - data={}, - ) # 返回结果 data = SQLOutput( - dataset=sql_exec_results, + result=sql_exec_results, sql=sql_exec, ).model_dump(exclude_none=True, by_alias=True) -- Gitee