diff --git a/Cargo.lock b/Cargo.lock index 77b3fb4236c7a12c1da8542ab92ca5376ef1cfde..7365b445d66e6a8b79f2a0492de47ab179b731e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -942,6 +942,7 @@ checksum = "f2481980430f9f78649238835720ddccc57e52df14ffce1c6f37391d61b563e9" dependencies = [ "equivalent", "hashbrown 0.15.5", + "serde", ] [[package]] @@ -1169,6 +1170,7 @@ dependencies = [ "chrono", "dotenvy", "fnv", + "indexmap", "lazy_static", "log", "regex", @@ -1607,6 +1609,7 @@ version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ + "indexmap", "itoa", "memchr", "ryu", diff --git a/Cargo.toml b/Cargo.toml index 17b876409367cac38ddbd0280b327e0c818ff2f7..cb9e671d42c033525cf242c5c8bd17ef2e27f189 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ base64 = { version = "0.22.1" } # 序列化/反序列化 serde = { version = "1.0", features = ["derive"] } -serde_json = { version = "1" } +serde_json = { version = "1", features = ["preserve_order"]} serde_yaml = { version = "0.9.33" } # 数据库 @@ -38,4 +38,5 @@ tower-http = { version = "0.5", features = ["cors"] } # 异步运行时 tokio = { version = "1.0", features = ["full"] } +indexmap = { version = "2.11.0", features = ["serde"] } diff --git a/src/app/common/rpc.rs b/src/app/common/rpc.rs index 9d54952439913edd140e57c0be7505cb85ac5bb2..92ed136b3a90615e5677472385a9821224513d64 100644 --- a/src/app/common/rpc.rs +++ b/src/app/common/rpc.rs @@ -1,13 +1,4 @@ -/// 标准RPC响应结构体 -#[derive(Debug, Clone)] -pub struct RpcResult { - /// HTTP状态码,与http::StatusCode一致 - pub code: HttpCode, - /// 人类可读的消息 - pub msg: Option, - /// 响应数据负载 - pub data: Option, -} + /// HTTP状态码枚举 #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/src/app/datasource/dialect.rs b/src/app/datasource/dialect.rs index c5856fb7ecf5cf1f6f7fa733a59e9f06b9c7a4c2..f1bf124ce0cc9d69d839bb758dc9683eff6e6f9a 100644 --- a/src/app/datasource/dialect.rs +++ b/src/app/datasource/dialect.rs @@ -1,5 +1,5 @@ use crate::app::datasource::config::DataSourceKind; -use std::collections::HashMap; +use serde_json::{Value, Map}; /// SQL方言特性 pub trait SqlDialect: Send + Sync { @@ -48,10 +48,10 @@ pub trait SqlDialect: Send + Sync { fn escape_string_value(&self, value: &str) -> String; /// 格式化值(根据类型) - fn format_value(&self, value: &serde_json::Value) -> String; + fn format_value(&self, value: &Value) -> String; /// 构建WHERE子句 - fn build_wheres(&self, conditions: &HashMap) -> String; + fn build_wheres(&self, conditions: &Map) -> String; /// 构建SELECT字段列表,支持字段选择和别名 /// @@ -68,7 +68,7 @@ pub struct MySqlDialect; impl MySqlDialect { /// 解析单个条件 - fn parse_condition(&self, key: &str, value: &serde_json::Value) -> String { + fn parse_condition(&self, key: &str, value: &Value) -> String { // 解析逻辑运算符 let (field_name, logic_op, operator) = self.parse_key(key); @@ -76,8 +76,8 @@ impl MySqlDialect { "{}" => { // 对于{}操作符,需要判断值的类型来决定是IN条件还是范围条件 match value { - serde_json::Value::Array(_) => self.build_in_condition(&field_name, value, &logic_op), - serde_json::Value::String(s) if s.contains(',') && (s.contains('<') || s.contains('>') || s.contains('=')) => { + Value::Array(_) => self.build_in_condition(&field_name, value, &logic_op), + Value::String(s) if s.contains(',') && (s.contains('<') || s.contains('>') || s.contains('=')) => { self.build_range_condition(&field_name, value, &logic_op) }, _ => self.build_in_condition(&field_name, value, &logic_op), @@ -130,11 +130,11 @@ impl MySqlDialect { } /// 构建IN条件 - fn build_in_condition(&self, field: &str, value: &serde_json::Value, logic_op: &str) -> String { + fn build_in_condition(&self, field: &str, value: &Value, logic_op: &str) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::Array(arr) => { + Value::Array(arr) => { let values: Vec = arr.iter().map(|v| self.format_value(v)).collect(); let condition = format!("{} IN ({})", escaped_field, values.join(",")); if logic_op == "!" { @@ -155,20 +155,20 @@ impl MySqlDialect { } /// 构建包含条件(JSON_CONTAINS) - fn build_contains_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_contains_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); let formatted_value = self.format_value(value); format!("JSON_CONTAINS({}, {})", escaped_field, formatted_value) } /// 构建LIKE条件 - fn build_like_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_like_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { format!("{} LIKE {}", escaped_field, self.escape_string_value(s)) }, - serde_json::Value::Array(arr) => { + Value::Array(arr) => { let conditions: Vec = arr.iter() .filter_map(|v| v.as_str()) .map(|s| format!("{} LIKE {}", escaped_field, self.escape_string_value(s))) @@ -184,13 +184,13 @@ impl MySqlDialect { } /// 构建正则表达式条件 - fn build_regexp_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_regexp_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { format!("{} REGEXP {}", escaped_field, self.escape_string_value(s)) }, - serde_json::Value::Array(arr) => { + Value::Array(arr) => { let conditions: Vec = arr.iter() .filter_map(|v| v.as_str()) .map(|s| format!("{} REGEXP {}", escaped_field, self.escape_string_value(s))) @@ -206,10 +206,10 @@ impl MySqlDialect { } /// 构建BETWEEN条件 - fn build_between_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_between_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { let parts: Vec<&str> = s.split(',').collect(); if parts.len() == 2 { let start = self.escape_string_value(parts[0].trim()); @@ -219,7 +219,7 @@ impl MySqlDialect { "1=1".to_string() } }, - serde_json::Value::Array(arr) => { + Value::Array(arr) => { if arr.len() == 2 { let start = self.format_value(&arr[0]); let end = self.format_value(&arr[1]); @@ -233,11 +233,11 @@ impl MySqlDialect { } /// 构建范围条件 - fn build_range_condition(&self, field: &str, value: &serde_json::Value, logic_op: &str) -> String { + fn build_range_condition(&self, field: &str, value: &Value, logic_op: &str) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { let conditions: Vec = s.split(',') .map(|condition| { let condition = condition.trim(); @@ -261,7 +261,7 @@ impl MySqlDialect { let joiner = if logic_op == "&" { " AND " } else { " OR " }; format!("({})", conditions.join(joiner)) }, - serde_json::Value::Array(arr) => { + Value::Array(arr) => { let conditions: Vec = arr.iter() .filter_map(|v| v.as_str()) .map(|condition| { @@ -392,17 +392,17 @@ impl SqlDialect for MySqlDialect { format!("'{}'", value.replace("'", "''").replace("\\", "\\\\")) } - fn format_value(&self, value: &serde_json::Value) -> String { + fn format_value(&self, value: &Value) -> String { match value { - serde_json::Value::Null => "NULL".to_string(), - serde_json::Value::Bool(b) => if *b { "TRUE" } else { "FALSE" }.to_string(), - serde_json::Value::Number(n) => n.to_string(), - serde_json::Value::String(s) => self.escape_string_value(s), + Value::Null => "NULL".to_string(), + Value::Bool(b) => if *b { "TRUE" } else { "FALSE" }.to_string(), + Value::Number(n) => n.to_string(), + Value::String(s) => self.escape_string_value(s), _ => self.escape_string_value(&value.to_string()), } } - fn build_wheres(&self, conditions: &HashMap) -> String { + fn build_wheres(&self, conditions: &Map) -> String { let mut where_clauses = Vec::new(); for (key, value) in conditions { @@ -454,7 +454,7 @@ pub struct PostgreSqlDialect; impl PostgreSqlDialect { /// 解析单个条件 - fn parse_condition(&self, key: &str, value: &serde_json::Value) -> String { + fn parse_condition(&self, key: &str, value: &Value) -> String { // 解析逻辑运算符 let (field_name, logic_op, operator) = self.parse_key(key); @@ -462,8 +462,8 @@ impl PostgreSqlDialect { "{}" => { // 对于{}操作符,需要判断值的类型来决定是IN条件还是范围条件 match value { - serde_json::Value::Array(_) => self.build_in_condition(&field_name, value, &logic_op), - serde_json::Value::String(s) if s.contains(',') && (s.contains('<') || s.contains('>') || s.contains('=')) => { + Value::Array(_) => self.build_in_condition(&field_name, value, &logic_op), + Value::String(s) if s.contains(',') && (s.contains('<') || s.contains('>') || s.contains('=')) => { self.build_range_condition(&field_name, value, &logic_op) }, _ => self.build_in_condition(&field_name, value, &logic_op), @@ -516,11 +516,11 @@ impl PostgreSqlDialect { } /// 构建IN条件 - fn build_in_condition(&self, field: &str, value: &serde_json::Value, logic_op: &str) -> String { + fn build_in_condition(&self, field: &str, value: &Value, logic_op: &str) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::Array(arr) => { + Value::Array(arr) => { let values: Vec = arr.iter().map(|v| self.format_value(v)).collect(); let condition = format!("{} IN ({})", escaped_field, values.join(",")); if logic_op == "!" { @@ -541,20 +541,20 @@ impl PostgreSqlDialect { } /// 构建包含条件(PostgreSQL使用@>操作符) - fn build_contains_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_contains_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); let formatted_value = self.format_value(value); format!("{} @> {}", escaped_field, formatted_value) } /// 构建LIKE条件 - fn build_like_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_like_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { format!("{} LIKE {}", escaped_field, self.escape_string_value(s)) }, - serde_json::Value::Array(arr) => { + Value::Array(arr) => { let conditions: Vec = arr.iter() .filter_map(|v| v.as_str()) .map(|s| format!("{} LIKE {}", escaped_field, self.escape_string_value(s))) @@ -570,13 +570,13 @@ impl PostgreSqlDialect { } /// 构建正则表达式条件(PostgreSQL使用~操作符) - fn build_regexp_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_regexp_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { format!("{} ~ {}", escaped_field, self.escape_string_value(s)) }, - serde_json::Value::Array(arr) => { + Value::Array(arr) => { let conditions: Vec = arr.iter() .filter_map(|v| v.as_str()) .map(|s| format!("{} ~ {}", escaped_field, self.escape_string_value(s))) @@ -592,10 +592,10 @@ impl PostgreSqlDialect { } /// 构建BETWEEN条件 - fn build_between_condition(&self, field: &str, value: &serde_json::Value) -> String { + fn build_between_condition(&self, field: &str, value: &Value) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { let parts: Vec<&str> = s.split(',').collect(); if parts.len() == 2 { let start = self.escape_string_value(parts[0].trim()); @@ -605,7 +605,7 @@ impl PostgreSqlDialect { "1=1".to_string() } }, - serde_json::Value::Array(arr) => { + Value::Array(arr) => { if arr.len() == 2 { let start = self.format_value(&arr[0]); let end = self.format_value(&arr[1]); @@ -619,11 +619,11 @@ impl PostgreSqlDialect { } /// 构建范围条件 - fn build_range_condition(&self, field: &str, value: &serde_json::Value, logic_op: &str) -> String { + fn build_range_condition(&self, field: &str, value: &Value, logic_op: &str) -> String { let escaped_field = self.escape_identifier(field); match value { - serde_json::Value::String(s) => { + Value::String(s) => { let conditions: Vec = s.split(',') .map(|condition| { let condition = condition.trim(); @@ -762,17 +762,17 @@ impl SqlDialect for PostgreSqlDialect { format!("'{}'", value.replace("'", "''")) } - fn format_value(&self, value: &serde_json::Value) -> String { + fn format_value(&self, value: &Value) -> String { match value { - serde_json::Value::Null => "NULL".to_string(), - serde_json::Value::Bool(b) => if *b { "true" } else { "false" }.to_string(), - serde_json::Value::Number(n) => n.to_string(), - serde_json::Value::String(s) => self.escape_string_value(s), + Value::Null => "NULL".to_string(), + Value::Bool(b) => if *b { "true" } else { "false" }.to_string(), + Value::Number(n) => n.to_string(), + Value::String(s) => self.escape_string_value(s), _ => self.escape_string_value(&value.to_string()), } } - fn build_wheres(&self, conditions: &HashMap) -> String { + fn build_wheres(&self, conditions: &Map) -> String { let mut where_clauses = Vec::new(); for (key, value) in conditions { @@ -850,7 +850,7 @@ impl SqlBuilder { &self, schema: &str, table: &str, - data: &HashMap, + data: &Map, ) -> String { let fields: Vec = data.keys().cloned().collect(); let values: Vec = data.values().map(|v| self.dialect.format_value(v)).collect(); @@ -863,7 +863,7 @@ impl SqlBuilder { &self, schema: &str, table: &str, - data: &HashMap, + data: &Map, id: i64, ) -> String { let set_clauses: Vec = data @@ -935,7 +935,7 @@ impl SqlBuilder { /// 构建复杂WHERE条件 pub fn build_where_conditions( &self, - conditions: &HashMap, + conditions: &Map, ) -> String { self.dialect.build_wheres(conditions) } @@ -946,7 +946,7 @@ impl SqlBuilder { schema: &str, table: &str, fields: Option<&[String]>, - conditions: &HashMap, + conditions: &Map, order_by: Option<&str>, limit: Option, offset: Option, diff --git a/src/app/datasource/manager.rs b/src/app/datasource/manager.rs index f05dc18826165d5db1ab6b5e13aa8bb6926863c3..ff4b35aa4ec772175f1178d77d3a2213c9d625bc 100644 --- a/src/app/datasource/manager.rs +++ b/src/app/datasource/manager.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use indexmap::IndexMap; use std::sync::Arc; use crate::app::datasource::config::{DataSourceKind, DataSourcesConfig}; use crate::app::datasource::mysql::DBConn; @@ -15,7 +15,7 @@ pub enum DatabaseConnection { impl DatabaseConnection { /// 查询单条记录 - pub async fn query_one(&self, sql: &str, params: Vec) -> Result>, Box> { + pub async fn query_one(&self, sql: &str, params: Vec) -> Result>, Box> { match self { DatabaseConnection::Mysql(conn) => { conn.query_one(sql, params).await.map_err(|e| Box::new(e) as Box) @@ -27,7 +27,7 @@ impl DatabaseConnection { } /// 查询多条记录 - pub async fn query_list(&self, sql: &str, params: Vec) -> Result>, Box> { + pub async fn query_list(&self, sql: &str, params: Vec) -> Result>, Box> { match self { DatabaseConnection::Mysql(conn) => { conn.query_list(sql, params).await.map_err(|e| Box::new(e) as Box) @@ -64,7 +64,7 @@ impl DatabaseConnection { } /// 数据源管理器 -/// +/// /// 负责管理多个数据源的连接,提供统一的数据库操作接口 #[derive(Debug, Clone)] pub struct DataSourceManager { @@ -73,45 +73,45 @@ pub struct DataSourceManager { /// 数据源名称到连接的映射 /// Key: 数据源名称 /// Value: 数据库名称到连接的映射 - connections: Arc>>>, + connections: Arc>>>, /// 默认数据源名称 default_datasource: Option, } impl DataSourceManager { /// 创建新的数据源管理器 - /// + /// /// # 参数 /// * `config` - 数据源配置 - /// + /// /// # 返回值 /// 返回数据源管理器实例 pub fn new(config: DataSourcesConfig) -> Self { let default_datasource = config.get_default_datasource().map(|ds| ds.name.clone()); - + Self { config, - connections: Arc::new(std::sync::RwLock::new(HashMap::new())), + connections: Arc::new(std::sync::RwLock::new(IndexMap::new())), default_datasource, } } /// 初始化所有数据源连接 - /// + /// /// 遍历配置中的所有数据源,为每个数据源的每个数据库建立连接 - /// + /// /// # 返回值 /// * `Ok(())` - 初始化成功 /// * `Err(Box)` - 初始化失败 pub async fn initialize(&self) -> Result<(), Box> { let mut connections = self.connections.write().unwrap(); - + for datasource in &self.config.datasource { - let mut ds_connections = HashMap::new(); - + let mut ds_connections = IndexMap::new(); + for database in &datasource.database { let connection_url = datasource.build_connection_url(database); - + let db_connection = match datasource.kind { DataSourceKind::Mysql => { let conn = DBConn::new(&connection_url).await @@ -124,24 +124,24 @@ impl DataSourceManager { DatabaseConnection::Postgres(conn) } }; - + ds_connections.insert(database.clone(), db_connection); log::info!("Connected to database: {} in datasource: {}", database, datasource.name); } - + connections.insert(datasource.name.clone(), ds_connections); log::info!("Datasource {} initialized with {} databases", datasource.name, datasource.database.len()); } - + Ok(()) } /// 获取指定数据源和数据库的连接 - /// + /// /// # 参数 /// * `datasource_name` - 数据源名称 /// * `database_name` - 数据库名称 - /// + /// /// # 返回值 /// 返回数据库连接的克隆 pub fn get_connection(&self, datasource_name: &str, database_name: &str) -> Option { @@ -152,10 +152,10 @@ impl DataSourceManager { } /// 获取默认数据源的指定数据库连接 - /// + /// /// # 参数 /// * `database_name` - 数据库名称 - /// + /// /// # 返回值 /// 返回默认数据源中指定数据库的连接 pub fn get_default_connection(&self, database_name: &str) -> Option { @@ -167,7 +167,7 @@ impl DataSourceManager { } /// 获取所有数据源名称 - /// + /// /// # 返回值 /// 返回所有数据源名称的向量 pub fn get_datasource_names(&self) -> Vec { @@ -175,10 +175,10 @@ impl DataSourceManager { } /// 获取指定数据源的所有数据库名称 - /// + /// /// # 参数 /// * `datasource_name` - 数据源名称 - /// + /// /// # 返回值 /// 返回指定数据源的所有数据库名称 pub fn get_database_names(&self, datasource_name: &str) -> Vec { @@ -190,23 +190,23 @@ impl DataSourceManager { } /// 获取所有数据源和数据库的映射 - /// + /// /// # 返回值 /// 返回数据源名称到数据库名称列表的映射 - pub fn get_all_datasource_databases(&self) -> HashMap> { + pub fn get_all_datasource_databases(&self) -> IndexMap> { self.config.datasource.iter() .map(|ds| (ds.name.clone(), ds.database.clone())) .collect() } /// 根据数据源名称和数据库名称查询单条记录 - /// + /// /// # 参数 /// * `datasource_name` - 数据源名称 /// * `database_name` - 数据库名称 /// * `sql` - SQL 查询语句 /// * `params` - 查询参数 - /// + /// /// # 返回值 /// 返回查询结果 pub async fn query_one( @@ -215,7 +215,7 @@ impl DataSourceManager { database_name: &str, sql: &str, params: Vec, - ) -> Result>, Box> { + ) -> Result>, Box> { if let Some(connection) = self.get_connection(datasource_name, database_name) { connection.query_one(sql, params).await } else { @@ -224,13 +224,13 @@ impl DataSourceManager { } /// 根据数据源名称和数据库名称查询多条记录 - /// + /// /// # 参数 /// * `datasource_name` - 数据源名称 /// * `database_name` - 数据库名称 /// * `sql` - SQL 查询语句 /// * `params` - 查询参数 - /// + /// /// # 返回值 /// 返回查询结果列表 pub async fn query_list( @@ -239,7 +239,7 @@ impl DataSourceManager { database_name: &str, sql: &str, params: Vec, - ) -> Result>, Box> { + ) -> Result>, Box> { if let Some(connection) = self.get_connection(datasource_name, database_name) { connection.query_list(sql, params).await } else { diff --git a/src/app/datasource/mysql.rs b/src/app/datasource/mysql.rs index cfe5b7869763942a4ef705fa23da2faea618d761..76c9d997c163a35873cda380803d13d17576169c 100644 --- a/src/app/datasource/mysql.rs +++ b/src/app/datasource/mysql.rs @@ -4,10 +4,10 @@ use sqlx::{ mysql::{MySqlColumn, MySqlPool, MySqlRow}, types::Decimal, }; -use std::collections::HashMap; use crate::app::datasource::metadata::{put_db_meta, put_db_tables, put_table_meta}; use crate::app::datasource::codec::base64_encode; use crate::app::datasource::{ColumnMeta, TableMeta}; +use indexmap::IndexMap; // MySQL系统数据库列表` const MYSQL_SYS_DB: &[&str] = &["information_schema", "mysql", "performance_schema", "sys"]; @@ -219,10 +219,10 @@ impl DBConn { /// * `params` - 查询参数列表,用于绑定到 SQL 语句中的占位符 /// /// # 返回值 - /// * `Ok(Some(HashMap))` - 成功查询到记录,返回包含字段名和值的映射 + /// * `Ok(Some(IndexMap))` - 成功查询到记录,返回包含字段名和值的映射 /// * `Ok(None)` - 没有查询到记录 /// * `Err(sqlx::Error)` - 查询过程中发生错误 - pub async fn query_one(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { + pub async fn query_one(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { // 如果 SQL 语句中没有包含 LIMIT,则自动添加 LIMIT 1 以提高查询效率 let sql = if !sql.to_lowercase().contains("limit") { format!("{} LIMIT 1", sql) @@ -239,13 +239,13 @@ impl DBConn { // 执行查询并获取结果 let row_opt = query.fetch_optional(&self.pool).await?; - // 如果查询到记录,则将行数据转换为 HashMap + // 如果查询到记录,则将行数据转换为 IndexMap match row_opt { Some(row) => { let columns = row.columns(); - let mut record = HashMap::with_capacity(columns.len()); + let mut record = IndexMap::with_capacity(columns.len()); - // 遍历所有列,将列名和对应的值插入到 HashMap 中 + // 遍历所有列,将列名和对应的值插入到 IndexMap 中 for column in columns { let value = Self::get_column_val(&row, column); record.insert(column.name().to_string(), value); @@ -266,9 +266,9 @@ impl DBConn { /// * `params` - 查询参数列表,用于绑定到 SQL 语句中的占位符 /// /// # 返回值 - /// * `Ok(Vec>)` - 成功查询到的记录列表,每条记录为字段名和值的映射 + /// * `Ok(Vec>)` - 成功查询到的记录列表,每条记录为字段名和值的映射 /// * `Err(sqlx::Error)` - 查询过程中发生错误 - pub async fn query_list(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { + pub async fn query_list(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { // 构建查询并绑定参数 let mut query = sqlx::query(sql); for param in params { @@ -281,12 +281,12 @@ impl DBConn { // 预分配容量的结果向量 let mut results = Vec::with_capacity(rows.len()); - // 遍历每一行,将行数据转换为 HashMap 并添加到结果中 + // 遍历每一行,将行数据转换为 IndexMap 并添加到结果中 for row in rows.into_iter() { let columns = row.columns(); - let mut record = HashMap::with_capacity(columns.len()); + let mut record = IndexMap::with_capacity(columns.len()); - // 遍历所有列,将列名和对应的值插入到 HashMap 中 + // 遍历所有列,将列名和对应的值插入到 IndexMap 中 for column in columns { let value = Self::get_column_val(&row, column); record.insert(column.name().to_string(), value); diff --git a/src/app/datasource/postgres.rs b/src/app/datasource/postgres.rs index 51c841b225bb529841f3f8a98c1dcac21a741a5d..4b871261061879bf703de8020cef2cefe33ee5d1 100644 --- a/src/app/datasource/postgres.rs +++ b/src/app/datasource/postgres.rs @@ -4,7 +4,7 @@ use sqlx::{ postgres::{PgColumn, PgPool, PgRow}, types::Decimal, }; -use std::collections::HashMap; +use indexmap::IndexMap; use crate::app::datasource::metadata::{put_db_meta, put_db_tables, put_table_meta}; use crate::app::datasource::codec::base64_encode; use crate::app::datasource::{ColumnMeta, TableMeta}; @@ -254,10 +254,10 @@ impl PgConn { /// * `params` - 查询参数列表,用于绑定到 SQL 语句中的占位符 /// /// # 返回值 - /// * `Ok(Some(HashMap))` - 成功查询到记录,返回包含字段名和值的映射 + /// * `Ok(Some(IndexMap))` - 成功查询到记录,返回包含字段名和值的映射 /// * `Ok(None)` - 没有查询到记录 /// * `Err(sqlx::Error)` - 查询过程中发生错误 - pub async fn query_one(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { + pub async fn query_one(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { // 如果 SQL 语句中没有包含 LIMIT,则自动添加 LIMIT 1 以提高查询效率 let sql = if !sql.to_lowercase().contains("limit") { format!("{} LIMIT 1", sql) @@ -274,13 +274,13 @@ impl PgConn { // 执行查询并获取结果 let row_opt = query.fetch_optional(&self.pool).await?; - // 如果查询到记录,则将行数据转换为 HashMap + // 如果查询到记录,则将行数据转换为 IndexMap match row_opt { Some(row) => { let columns = row.columns(); - let mut record = HashMap::with_capacity(columns.len()); + let mut record = IndexMap::with_capacity(columns.len()); - // 遍历所有列,将列名和对应的值插入到 HashMap 中 + // 遍历所有列,将列名和对应的值插入到 IndexMap 中 for column in columns { let value = Self::get_column_val(&row, column); record.insert(column.name().to_string(), value); @@ -301,9 +301,9 @@ impl PgConn { /// * `params` - 查询参数列表,用于绑定到 SQL 语句中的占位符 /// /// # 返回值 - /// * `Ok(Vec>)` - 成功查询到的记录列表,每条记录为字段名和值的映射 + /// * `Ok(Vec>)` - 成功查询到的记录列表,每条记录为字段名和值的映射 /// * `Err(sqlx::Error)` - 查询过程中发生错误 - pub async fn query_list(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { + pub async fn query_list(&self, sql: &str, params: Vec) -> Result>, sqlx::Error> { // 构建查询并绑定参数 let mut query = sqlx::query(sql); for param in params { @@ -316,12 +316,12 @@ impl PgConn { // 预分配容量的结果向量 let mut results = Vec::with_capacity(rows.len()); - // 遍历每一行,将行数据转换为 HashMap 并添加到结果中 + // 遍历每一行,将行数据转换为 IndexMap 并添加到结果中 for row in rows.into_iter() { let columns = row.columns(); - let mut record = HashMap::with_capacity(columns.len()); + let mut record = IndexMap::with_capacity(columns.len()); - // 遍历所有列,将列名和对应的值插入到 HashMap 中 + // 遍历所有列,将列名和对应的值插入到 IndexMap 中 for column in columns { let value = Self::get_column_val(&row, column); record.insert(column.name().to_string(), value); diff --git a/src/app/handler/ctx/dialect.rs b/src/app/handler/ctx/dialect.rs index 5e87bb6a9accbb37aa0211b479419e1a4b0f60bc..05053c42fc206dbb111aa1da36643df6cd71ad01 100644 --- a/src/app/handler/ctx/dialect.rs +++ b/src/app/handler/ctx/dialect.rs @@ -1,4 +1,6 @@ use std::fmt::Debug; +use sqlx::ColumnIndex; +use crate::app::handler::util::string_util::{is_num, is_name}; pub trait SqlDialect: Debug + Send + Sync { #[allow(dead_code)] @@ -22,21 +24,73 @@ impl SqlDialect for MySqlDialect { } fn build_columns(&self, columns_str: &str) -> Vec { - columns_str - .split(',') - .map(|s| { - let trimmed = s.trim(); - if trimmed.contains(" as ") || trimmed.contains(" AS ") { - // 处理别名 - trimmed.to_string() + let mut fcs: Vec = Vec::new(); + + let _ = columns_str + .split(';') + .map(|cs| { + let mut part = cs; + let mut alias = ""; + let ind = part.rfind(':'); + if (ind.is_some()) { + alias = &part[(ind.unwrap() + 1)..]; + part = &part[0..ind.unwrap()]; + } + + let start = part.find('('); + let end = if start.is_none() { 0 } else { part.rfind(')').unwrap() }; + + let fun = if end <= 0 { "" } else { &part[..start.unwrap()] }; + let arg_str = if end <= 0 { part } else { &part[(start.unwrap() + 1)..end] }; + let args = arg_str.split(',').map(|mut a_s| { + let mut als = ""; + if fun.is_empty() { + let ind = a_s.rfind(':'); + if (ind.is_some()) { + als = &a_s[(ind.unwrap() + 1)..]; + a_s = &a_s[0..ind.unwrap()]; + } + } + + if als.is_empty() { + if is_num(a_s) { + return a_s.to_string(); + } + if is_name(a_s) { + return format!("`{}`", a_s.replace("`", "\\`")); + } + return format!("'{}'", a_s.replace("'", "\\'")); + } else { + if is_num(a_s) { + return format!("{} AS `{}`", a_s, als.replace("`", "\\`")); + } + if is_name(a_s) { + return format!("`{}` AS `{}`", a_s.replace("`", "\\`"), als.replace("`", "\\`")); + } + return format!("'{}' AS `{}`", a_s.replace("'", "\\'"), als.replace("`", "\\`")); + } + }).collect::>(); + + let a_s = args.join(","); + let a_s2 = a_s.as_str(); + if fun.is_empty() { + if alias.is_empty() { + fcs.push(format!("{}", a_s2)); + } else { + fcs.push(format!("{} AS `{}`", a_s2, alias.replace("`", "\\`"))); + } } else { - // 普通字段,添加反引号 - format!("`{}`", trimmed) + if alias.is_empty() { + fcs.push(format!("{}({})", fun, a_s2)); + } else { + fcs.push(format!("{}({}) AS `{}`", fun, a_s2, alias.replace("`", "\\`"))); + } } - }) - .collect() + }).count(); + + return fcs } - + fn build_limit(&self, limit: usize, offset: usize) -> String { if offset > 0 { format!("LIMIT {} OFFSET {}", limit, offset) diff --git a/src/app/handler/ctx/query_context.rs b/src/app/handler/ctx/query_context.rs index 6b74f8cc6b02038a270a70bb4bcdc12173b3f3f6..05ea73d59c3b949702914e242d43fe653d9946b3 100644 --- a/src/app/handler/ctx/query_context.rs +++ b/src/app/handler/ctx/query_context.rs @@ -1,18 +1,13 @@ -use fnv::FnvHashMap; use std::sync::Arc; -use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::collections::{BTreeMap, VecDeque}; use crate::app::common::rpc::HttpCode; use crate::app::handler::ctx::query_executor::QueryExecutor; use crate::app::datasource::config::DataSourceKind; use crate::app::handler::util::parser::DatabaseTargetDefaults; +use serde_json::{Number, Value, Map, json}; +use indexmap::IndexMap; -/// 主节点权重常量 -pub const RATIO_PRIMARY: i32 = 10000; -/// 被依赖一次的权重系数 -const RATIO_RELATED: i32 = 10; - - -#[derive(Debug)] +#[derive()] pub struct QueryContext { // 状态码 pub code: HttpCode, @@ -23,26 +18,24 @@ pub struct QueryContext { pub database_defaults: Option, // 主节点字段映射表(主节点路径 -> 主节点字段 -> 指向从节点关联字段路径) - pub primary_relate_kv: FnvHashMap>, + pub primary_relate_kv: IndexMap>, // 从节点字段映射表(从节点路径 -> 从节点字段 -> 指向主节点关联字段路径) - pub slave_relate_kv: FnvHashMap>, + pub slave_relate_kv: IndexMap>, // 分层节点,层级: 节点列表 pub layer_query_node: BTreeMap>>, // 命名空间节点 - pub namespace_node: FnvHashMap>, + pub namespace_node: IndexMap>, // 数据查询节点,节点路径: 节点 - pub query_node: FnvHashMap>, + pub query_node: IndexMap>, // 主节点数据列表(节点路径 -> 结果数据),主节点就是每一个命名空间的主查询节点 - pub primary_node_data: FnvHashMap>>, + pub primary_node_data: IndexMap>>, // 被关联字段的值(主节点字段路径 -> 主节点字段值(默认Value::Null, 结果是array or object)) - pub primary_node_related_field_values: FnvHashMap, + pub primary_node_related_field_values: IndexMap, // 从节点关联字段映射表(从节点父路径 -> 字段对应的值), 用于主节点获取从节点数据 - pub slave_node_relate_data: FnvHashMap>>>, + pub slave_node_relate_data: IndexMap, - // 节点权重(不在节点内存储权重,避免可变共享状态) - pub node_weight: FnvHashMap, } #[derive(Debug, Clone)] @@ -54,9 +47,7 @@ pub struct QueryNode { // 标记是否是列表查询 pub is_list: bool, /// 属性映射 - pub attributes: HashMap, - // 权重(不再使用,由 QueryContext.node_weight 管理,保留字段以兼容) - pub weight: i32, + pub attributes: IndexMap, // SQL执行器,负责生成和执行SQL pub sql_executor: QueryExecutor, } @@ -64,24 +55,24 @@ pub struct QueryNode { impl QueryContext { /// 从 JSON 值构建 QueryContext - pub fn from_json(root: HashMap, datasource_kind: DataSourceKind, database_defaults: Option) -> Self { + pub fn from_json(root: Map, datasource_kind: DataSourceKind, database_defaults: Option) -> Self { // 创建处理队列,每项包含:(父路径, 节点名称, 节点值, 深度) - let mut json_vec_deque: VecDeque<(String, String, serde_json::Value, i32)> = VecDeque::new(); + let mut json_vec_deque: VecDeque<(String, String, Value, i32)> = VecDeque::new(); // 分层节点,层级: 节点列表 let mut layer_query_node: BTreeMap>> = BTreeMap::default(); // 初始化数据结构,用于构建查询上下文 - let mut namespace_node = FnvHashMap::default(); + let mut namespace_node = IndexMap::default(); // 数据查询节点,节点路径: 节点 - let mut query_node: FnvHashMap> = FnvHashMap::default(); + let mut query_node: IndexMap> = IndexMap::default(); // 主节点字段映射表(主节点路径 -> 主节点字段 -> 指向从节点关联字段路径) - let mut primary_relate_kv: FnvHashMap> = FnvHashMap::default(); + let mut primary_relate_kv: IndexMap> = IndexMap::default(); // 从节点字段映射表(从节点路径 -> 从节点字段 -> 指向主节点关联字段路径) - let mut slave_relate_kv: FnvHashMap> = FnvHashMap::default(); + let mut slave_relate_kv: IndexMap> = IndexMap::default(); // 被关联字段的值(主节点字段路径 -> 主节点字段值(默认Value::Null, 结果是array or object)) - let mut primary_node_related_field_values: FnvHashMap = FnvHashMap::default(); + let mut primary_node_related_field_values: IndexMap = IndexMap::default(); // 处理根节点,区分数组节点和普通节点 for (key, val) in root { @@ -99,7 +90,7 @@ impl QueryContext { // 将子对象加入处理队列: // - 父路径: 当前数组节点的key // - 子节点名称: k - // - 子节点值: v + // - 子节点值: v // - 深度设为2(相对于根节点的深度) json_vec_deque.push_back((key.clone(), k.clone(), v.clone(), 2)); } @@ -130,7 +121,7 @@ impl QueryContext { ); } } else { // 处理普通节点,提取属性和关联关系 - let mut attributes = HashMap::new(); + let mut attributes = IndexMap::new(); // 判断节点是否属于列表(父节点是否为数组) let mut is_list = parent_path.ends_with("[]"); if let Some(map) = node_val.as_object() { @@ -141,30 +132,31 @@ impl QueryContext { let field_name = field_key[..(field_key.len()-1)].to_string(); let field_path = format!("{}/{}", &node_path, field_name); // 依赖关系是唯一索引则节点数据结果一定不是 list - if field_name.as_str() == "id" { is_list = false; } + // if field_name.as_str() == "id" { is_list = false; } // 关联关系 - if let serde_json::Value::String(primary_field_path) = field_value { + if let Value::String(primary_field_path) = field_value { slave_relate_kv.entry(node_path.clone()).or_default().insert(field_name, primary_field_path.to_string()); // 添加主节点字段对应的值到值映射表中 - primary_node_related_field_values.insert(primary_field_path.to_string(), serde_json::Value::Null); + primary_node_related_field_values.insert(primary_field_path.to_string(), Value::Null); // 添加主节点字段对应的值到字段映射表中 let index = primary_field_path.rfind('/').unwrap_or(0); let primary_node_path = &primary_field_path[..index]; let primary_related_field = &primary_field_path[(index+1)..]; primary_relate_kv.entry(primary_node_path.to_string()).or_default().insert(primary_related_field.to_string(), field_path.to_string()); } - } else { // 普通查询属性 - attributes.insert(field_key.clone(), field_value.clone()); + // } else { + } + // 普通查询属性 + attributes.insert(field_key.clone(), field_value.clone()); } } } - + // 创建查询节点并添加到对应深度的节点列表中 let shared_node = Arc::new(QueryNode { name: (&name).to_string(), path: node_path.clone(), - weight: 0, is_list, attributes, sql_executor: QueryExecutor::new(datasource_kind.clone()), @@ -184,57 +176,19 @@ impl QueryContext { slave_relate_kv, primary_node_related_field_values, - slave_node_relate_data: FnvHashMap::default(), - primary_node_data: FnvHashMap::default(), - node_weight: FnvHashMap::default(), + slave_node_relate_data: IndexMap::default(), + primary_node_data: IndexMap::default() }; - ctx.compute_node_weight(); - ctx + return ctx } - /// 计算每个节点的权重 - fn compute_node_weight(&mut self) { - // 收集所有节点引用,将多层嵌套的节点扁平化为一个列表 - let all_nodes: Vec> = self.layer_query_node.values().flatten().cloned().collect(); - - // 统计每个节点被依赖的次数,用于后续权重计算 - let mut counts: HashMap = HashMap::new(); - for node_rc in &all_nodes { - let node_path = node_rc.path.clone(); - // 检查当前节点是否有从属关系映射 - if let Some(relate) = self.slave_relate_kv.get(&node_path) { - // 遍历所有从属关系映射值(主节点路径) - relate.values().for_each(|parent_path| { - // 更新主节点被依赖计数: - // 1. 如果主节点路径不存在于counts中,则插入并初始化为0 - // 2. 对主节点路径的计数加1 - *counts.entry(parent_path.clone()).or_insert(0) += 1; - }); - } - } - - // 根据节点的依赖关系计算权重 - // 1. 无依赖的节点获得基础权重 RATIO_PRIMARY - // 2. 被依赖的节点获得额外权重 RATIO_RELATED^count - for node_rc in &all_nodes { - let (path, has_dep) = { - let b_path = node_rc.path.clone(); - let b_relate_kv = self.slave_relate_kv.get(&b_path); - (b_path, b_relate_kv.is_some()) - }; - let count = counts.get(&path).copied().unwrap_or(0); - let addition = if count > 0 { RATIO_RELATED.pow(count) } else { 0 }; - let weight = if !has_dep { RATIO_PRIMARY + addition } else { addition }; - self.node_weight.insert(path, weight); - } - } } /// 判断是否为标量 -fn is_scalar_field(v: &serde_json::Value) -> bool { v.is_number() || v.is_string() || v.is_boolean() } +fn is_scalar_field(v: &Value) -> bool { v.is_number() || v.is_string() || v.is_boolean() } /// 从 JSON 对象中收集标量属性 -fn collect_scalar_attrs(v: &serde_json::Value) -> FnvHashMap { +fn collect_scalar_attrs(v: &Value) -> IndexMap { // 从JSON值中收集标量属性(数字、字符串、布尔值) match v.as_object() { // 如果值是JSON对象 @@ -244,7 +198,7 @@ fn collect_scalar_attrs(v: &serde_json::Value) -> FnvHashMap FnvHashMap::default(), + None => IndexMap::default(), } } diff --git a/src/app/handler/ctx/query_executor.rs b/src/app/handler/ctx/query_executor.rs index 27fb99867e92525ac7c26b8007dad82012fc9109..ae52e3d458a69a4a38ba2ed257b0918b352ab9dd 100644 --- a/src/app/handler/ctx/query_executor.rs +++ b/src/app/handler/ctx/query_executor.rs @@ -1,4 +1,6 @@ -use std::collections::HashMap; +use indexmap::IndexMap; +use std::ptr::null; +use serde_json::Value; use crate::app::datasource::mysql::DBConn; use crate::app::datasource::metadata::get_table; use crate::app::handler::ctx::dialect::{SqlDialect, MySqlDialect, PostgreSqlDialect}; @@ -12,7 +14,9 @@ pub struct QueryExecutor { table: String, columns: Vec, where_clauses: Vec, - params: Vec, + params: Vec, + group: Option, + having: Vec, order: Option, page: i32, limit: i32, @@ -27,6 +31,8 @@ impl QueryExecutor { columns: vec![], where_clauses: vec![], params: vec![], + group: None, + having: Vec::new(), // 报错 vec![], order: None, page: 0, limit: 1, @@ -42,31 +48,31 @@ impl QueryExecutor { } #[allow(dead_code)] - pub fn get_params(&self) -> Vec { + pub fn get_params(&self) -> Vec { self.params.clone() } pub fn get_string_params(&self) -> Vec { self.params.iter().map(|v| { match v { - serde_json::Value::String(s) => s.clone(), - serde_json::Value::Number(n) => n.to_string(), - serde_json::Value::Bool(b) => b.to_string(), - serde_json::Value::Null => "NULL".to_string(), + Value::String(s) => s.clone(), + Value::Number(n) => n.to_string(), + Value::Bool(b) => b.to_string(), + Value::Null => "NULL".to_string(), _ => v.to_string(), } }).collect() } #[allow(dead_code)] - pub async fn exec(&self, db: &DBConn) -> Result>, sqlx::Error> { + pub async fn exec(&self, db: &DBConn) -> Result>, sqlx::Error> { let sql = self.to_sql(); log::info!("sql.exec: {}, params: {}", sql, serde_json::to_string(&self.params).unwrap()); let params: Vec = self.params.iter() .map(|v| match v { - serde_json::Value::Null => "NULL".to_string(), - serde_json::Value::String(s) => s.clone(), - serde_json::Value::Array(_) | serde_json::Value::Object(_) => + Value::Null => "NULL".to_string(), + Value::String(s) => s.clone(), + Value::Array(_) | Value::Object(_) => serde_json::to_string(v).unwrap_or_else(|_| "NULL".to_string()), _ => v.to_string(), }) @@ -87,15 +93,31 @@ impl QueryExecutor { } // FROM子句 - let escaped_schema = dialect.escape_identifier(&self.schema); + let escaped_table = dialect.escape_identifier(&self.table); - sql.push_str(&format!(" FROM {}.{}", escaped_schema, escaped_table)); + if (self.schema.is_empty()) { // FIXME 必须有 schema,否则查询总是为空 + sql.push_str(&format!(" FROM {}", escaped_table)); + } else { + let escaped_schema = dialect.escape_identifier(&self.schema); + sql.push_str(&format!(" FROM {}.{}", escaped_schema, escaped_table)); + } // WHERE子句 if !self.where_clauses.is_empty() { sql.push_str(" WHERE "); sql.push_str(&self.where_clauses.join(" AND ")); } + + // GROUP BY子句 + if let Some(group) = &self.group { + sql.push_str(&format!(" GROUP BY {}", group)); + } + + // HAVING子句 + if !self.having.is_empty() { + sql.push_str(" HAVING "); + sql.push_str(&self.having.join(" AND ")); + } // ORDER BY子句 if let Some(order) = &self.order { @@ -111,59 +133,76 @@ impl QueryExecutor { sql } - pub fn parse_table(&mut self, table_key: &str) -> Result<(), String> { + pub fn parse_table(&mut self, table_key: &str) -> Result { let table_key = if table_key.ends_with("[]") { &table_key[..table_key.len()-2] } else { table_key }; - let schema_table_vec = table_key.split(".").collect::>(); - let schema = schema_table_vec[0]; - let table = schema_table_vec[1]; - match get_table(&schema, table) { - Some(table) => { - self.table = table.name.clone(); - self.schema = table.schema.clone(); - Ok(()) - }, - None => Err(format!("table: {} not exists", table_key)) + let keys = table_key.split("-").collect::>(); + let l = keys.len(); + let table = if l < 1 {""} else {keys[0]}; + // match get_table(&schema, table) { + // Some(table) => { + // self.table = table.name.clone(); + // self.schema = table.schema.clone(); + // Ok(()) + // }, + // None => Err(format!("table: {} not exists", table_key)) + // } + if (table.is_empty()) { + Err(format!("Table is empty"))? } + self.table = table.to_string().clone(); + self.schema = "sys".to_string(); + Ok(table.to_string()) } - pub fn parse_condition(&mut self, field: &str, value: &serde_json::Value) { + pub fn parse_condition(&mut self, field: &str, value: &Value) { // 处理特殊参数 if field.starts_with('@') { match &field[1..] { - "order" => { - if let serde_json::Value::String(order) = value { - self.order = Some(order.to_string()); - } - } "column" => { - if let serde_json::Value::String(cols) = value { + if let Value::String(cols) = value { // 使用dialect的build_columns方法处理字段选择和别名 let dialect = self.get_dialect(); - let columns = dialect.build_columns(cols); + let columns = dialect.build_columns(cols.as_str()); self.columns = columns; } } + "group" => { + if let Value::String(group) = value { + self.group = Some(group.to_string()); + } + } + "having" => { + if let Value::String(having) = value { + self.having = having.split(";").map(|s| s.to_string()).collect::>(); + } + } + "order" => { + if let Value::String(order) = value { + self.order = Some(order.to_string().replace("+", " ASC ").replace("-", " DESC ")); + } + } _ => {} } return; } - // 处理各种查询条件 + // 处理各种查询条件 https://github.com/Tencent/APIJSON/blob/master/Document.md#3.2 + // https://github.com/Tencent/APIJSON/blob/master/APIJSONORM/src/main/java/apijson/JSONMap.java#L152-L190 if field.ends_with('$') { - // 模糊查询 + // 模糊搜索 https://github.com/Tencent/APIJSON/blob/master/APIJSONORM/src/main/java/apijson/orm/AbstractSQLConfig.java#L4114-L4232 let actual_field = &field[..field.len() - 1]; let dialect = self.get_dialect(); let escaped_field = dialect.escape_identifier(actual_field); self.where_clauses.push(format!("{} LIKE ?", escaped_field)); self.params.push(value.to_owned()); - } else if field.ends_with('?') { - // 正则匹配 + } else if field.ends_with('~') { + // 正则匹配 https://github.com/Tencent/APIJSON/blob/master/APIJSONORM/src/main/java/apijson/orm/AbstractSQLConfig.java#L4236-L4323 let actual_field = &field[..field.len() - 1]; let dialect = self.get_dialect(); let escaped_field = dialect.escape_identifier(actual_field); - let regex_op = match self.dialect { - DataSourceKind::Mysql => "REGEXP", - DataSourceKind::Postgres => "~", + let regex_op = match self.dialect { // FIXME 根据不同数据库类型及版本来适配,MySQL 8.0+ 用 regexp_like() + DataSourceKind::Mysql => if field.ends_with("*~") { "REGEXP" } else { "REGEXP BINARY" }, + DataSourceKind::Postgres => if field.ends_with("*~") { "*~" } else { "~" }, }; self.where_clauses.push(format!("{} {} ?", escaped_field, regex_op)); self.params.push(value.to_owned()); @@ -173,7 +212,7 @@ impl QueryExecutor { let dialect = self.get_dialect(); let escaped_field = dialect.escape_identifier(actual_field); match value { - serde_json::Value::Array(values) => { + Value::Array(values) => { if !values.is_empty() { let placeholders = vec!["?"; values.len()].join(","); self.where_clauses.push(format!("{} IN ({})", escaped_field, placeholders)); @@ -185,16 +224,26 @@ impl QueryExecutor { self.params.push(value.to_owned()); } } + } else if field.ends_with("%") { + let actual_field = &field[..field.len() - 1]; + let dialect = self.get_dialect(); + let escaped_field = dialect.escape_identifier(actual_field); + self.where_clauses.push(format!("{} BETWEEN ? AND ?", escaped_field)); + + let vals = value.as_str().unwrap().split(',').collect::>(); + assert_eq!(vals.len(), 2); + self.params.push(Value::from(vals[0])); + self.params.push(Value::from(vals[1])); } else if field.ends_with("<>") { - // NOT IN查询 + // json contains 查询 https://github.com/Tencent/APIJSON/blob/master/APIJSONORM/src/main/java/apijson/orm/AbstractSQLConfig.java#L4561-L4656 let actual_field = &field[..field.len() - 2]; let dialect = self.get_dialect(); let escaped_field = dialect.escape_identifier(actual_field); match value { - serde_json::Value::Array(values) => { + Value::Array(values) => { if !values.is_empty() { - let placeholders = vec!["?"; values.len()].join(","); - self.where_clauses.push(format!("{} NOT IN ({})", escaped_field, placeholders)); + let placeholders = vec!["?"; values.len()].join(","); // FIXME 根据不同数据库类型及版本来适配 + self.where_clauses.push(format!("json_contains({}, {}, '$')", escaped_field, placeholders)); self.params.extend(values.to_owned()); } } @@ -203,17 +252,48 @@ impl QueryExecutor { self.params.push(value.to_owned()); } } + } else if field.ends_with(">=") { + let actual_field = &field[..field.len() - 2]; + let dialect = self.get_dialect(); + let escaped_field = dialect.escape_identifier(actual_field); + self.where_clauses.push(format!("{} >= ?", escaped_field)); + self.params.push(value.to_owned()); + } else if field.ends_with("<=") { + let actual_field = &field[..field.len() - 2]; + let dialect = self.get_dialect(); + let escaped_field = dialect.escape_identifier(actual_field); + self.where_clauses.push(format!("{} <= ?", escaped_field)); + self.params.push(value.to_owned()); + } else if field.ends_with(">") { + let actual_field = &field[..field.len() - 1]; + let dialect = self.get_dialect(); + let escaped_field = dialect.escape_identifier(actual_field); + self.where_clauses.push(format!("{} > ?", escaped_field)); + self.params.push(value.to_owned()); + } else if field.ends_with("<") { + let actual_field = &field[..field.len() - 1]; + let dialect = self.get_dialect(); + let escaped_field = dialect.escape_identifier(actual_field); + self.where_clauses.push(format!("{} < ?", escaped_field)); + self.params.push(value.to_owned()); + } else if field.ends_with("!") { + let actual_field = &field[..field.len() - 1]; + let dialect = self.get_dialect(); + let escaped_field = dialect.escape_identifier(actual_field); + self.where_clauses.push(format!("{} != ?", escaped_field)); + self.params.push(value.to_owned()); } else { // 普通等值查询 let dialect = self.get_dialect(); let escaped_field = dialect.escape_identifier(field); match value { - serde_json::Value::Array(values) => { - if !values.is_empty() { - let placeholders = vec!["?"; values.len()].join(","); - self.where_clauses.push(format!("{} IN ({})", escaped_field, placeholders)); - self.params.extend(values.to_owned()); - } + Value::Object(values) => { + assert!(false, "还未实现子查询!!!!!"); + // if !values.is_empty() { + // let placeholders = vec!["?"; values.len()].join(","); + // self.where_clauses.push(format!("{} IN ({})", escaped_field, placeholders)); + // self.params.extend(values.to_owned()); + // } } _ => { self.where_clauses.push(format!("{} = ?", escaped_field)); @@ -223,14 +303,14 @@ impl QueryExecutor { } } - pub fn page_size(&mut self, page: serde_json::Value, count: serde_json::Value) { + pub fn page_size(&mut self, page: Value, count: Value) { self.page = Self::parse_num(&page, 0); self.limit = Self::parse_num(&count, 10); } - fn parse_num(value: &serde_json::Value, default_val: i32) -> i32 { + fn parse_num(value: &Value, default_val: i32) -> i32 { match value { - serde_json::Value::Number(n) => n.as_f64() + Value::Number(n) => n.as_f64() .map(|f| f as i32) .unwrap_or(default_val), _ => default_val, diff --git a/src/app/handler/delete.rs b/src/app/handler/delete.rs index f8c8f44b8449b7f6a02cacf1ee16fa6cfe0d502d..32ac0cae36085039f24f29a8e0ec0fbcc33c000d 100644 --- a/src/app/handler/delete.rs +++ b/src/app/handler/delete.rs @@ -1,11 +1,10 @@ -use std::collections::HashMap; use log::{debug, info, warn, error}; -use crate::app::common::rpc::{RpcResult, HttpCode}; +use crate::app::common::rpc::{HttpCode}; use crate::app::datasource::manager::{DataSourceManager, DatabaseConnection}; use crate::app::datasource::dialect::SqlBuilder; -use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults, DatabaseTarget}; +use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults, DatabaseTarget, new_ok_result, new_err_result}; use std::sync::Arc; -use serde_json::Value; +use serde_json::{Value, Map}; /// 处理删除数据的请求 /// @@ -15,24 +14,16 @@ use serde_json::Value; /// pub async fn handle_delete( datasource_manager: Arc, - body_map: HashMap, -) -> RpcResult> { + body_map: Map, +) -> Map { info!("开始处理删除请求,包含 {} 个表项", body_map.len()); debug!("删除请求体: {:?}", body_map); - - let mut rpc_result = RpcResult { - code: HttpCode::Ok, - msg: None, - data: None, - }; // 解析请求体 let parse_result = DatabaseTargetParser::parse_request_body(body_map); if !parse_result.errors.is_empty() { error!("解析请求体失败: {:?}", parse_result.errors); - rpc_result.code = HttpCode::BadRequest; - rpc_result.msg = Some(format!("请求解析失败: {}", parse_result.errors.join(", "))); - return rpc_result; + return new_err_result(HttpCode::BadRequest, parse_result.errors.join(", ").as_str()); } info!("成功解析 {} 个数据项", parse_result.items.len()); @@ -44,7 +35,7 @@ pub async fn handle_delete( Some("public".to_string()), ); - let mut result_payload = HashMap::::new(); + let mut result_payload = Map::::new(); // 处理每个数据项 for item in parse_result.items { @@ -53,9 +44,11 @@ pub async fn handle_delete( // 验证目标完整性 if let Err(err) = defaults.validate_target(&target) { warn!("目标验证失败: {}", err); - rpc_result.code = HttpCode::BadRequest; - result_payload.insert(target.table.clone(), Value::String(err)); - continue; + return new_err_result(HttpCode::BadRequest, err.as_str()); + + // rpc_result.code = HttpCode::BadRequest; + // result_payload.insert(target.table.clone(), Value::String(err)); + // continue; } // 执行删除操作 @@ -66,18 +59,13 @@ pub async fn handle_delete( } Err(err) => { error!("删除失败,表: {}, 错误: {}", target.table, err); - rpc_result.code = HttpCode::InternalServerError; - result_payload.insert(target.table, Value::String(err)); + return new_err_result(HttpCode::InternalServerError, err.as_str()); } } } - if !result_payload.is_empty() { - rpc_result.data = Some(result_payload); - } - info!("删除请求处理完成"); - rpc_result + return new_ok_result(result_payload) } /// 执行数据删除操作 @@ -95,7 +83,7 @@ pub async fn handle_delete( async fn delete_one( datasource_manager: &DataSourceManager, target: &DatabaseTarget, - data: &HashMap, + data: &Map, ) -> Result { let datasource_name = target.datasource.as_ref().ok_or("数据源名称为空")?; let database_name = target.database.as_ref().ok_or("数据库名称为空")?; diff --git a/src/app/handler/get.rs b/src/app/handler/get.rs index e071b47fb790bea7afe81170745d9c15b6967b55..18777f1ef81f5ce525cfb21ecba727c7426c8947 100644 --- a/src/app/handler/get.rs +++ b/src/app/handler/get.rs @@ -1,12 +1,13 @@ -use std::collections::HashMap; -use fnv::FnvHashMap; -use crate::app::common::rpc::{RpcResult, HttpCode}; +use std::ptr::null; +use serde_json::{Number, Value, Map, json, to_value}; +use indexmap::IndexMap; +use crate::app::common::rpc::{HttpCode}; use crate::app::datasource::manager::DataSourceManager; -use crate::app::handler::ctx::query_context::{get_parent_node_path, QueryContext, QueryNode, RATIO_PRIMARY}; +use crate::app::handler::ctx::query_context::{get_parent_node_path, QueryContext, QueryNode}; use crate::app::datasource::config::DataSourceKind; use crate::app::handler::ctx::query_executor::DEFAULT_MAX_COUNT; use crate::app::handler::util::transform::transform_salve_value; -use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults}; +use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults, KEY_CODE, KEY_MSG, MSG_SUCCESS}; /// 处理GET请求的异步方法 /// @@ -16,7 +17,7 @@ use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefa /// /// # 返回值 /// 返回serde_json::Value类型的JSON响应数据 -pub async fn handle_get(manager: &DataSourceManager, body_map: HashMap) -> RpcResult::> { +pub async fn handle_get(manager: &DataSourceManager, body_map: Map) -> Map { // 解析数据库目标信息(@datasource、@database、@schema) let parse_result = DatabaseTargetParser::parse_request_body(body_map.clone()); @@ -46,7 +47,7 @@ pub async fn handle_get(manager: &DataSourceManager, body_map: HashMap RpcResult::> { + async fn response(&mut self, manager: &DataSourceManager) -> Map { // 克隆 query_node 以避免借用冲突 let query_node = self.layer_query_node.clone(); @@ -54,19 +55,18 @@ impl QueryContext { for nodes in query_node.values() { // 按权重降序排序 let mut sorted_nodes = nodes.clone(); - sorted_nodes.sort_unstable_by(|a, b| b.weight.cmp(&a.weight)); for node in sorted_nodes { let mut node_owned = (*node).clone(); - if node_owned.weight >= RATIO_PRIMARY { - self.query_primary_node(&mut node_owned, manager).await; - } else { + // if node_owned.weight >= RATIO_PRIMARY { + // self.query_primary_node(&mut node_owned, manager).await; + // } else { self.query_relate_node(&mut node_owned, manager).await; - } + // } } } // 构建响应结果映射 - let mut response_payload = HashMap::new(); + let mut response_payload = Map::new(); // 遍历所有主节点数据(每个主节点路径及其对应的查询结果) for (node_path, results) in &self.primary_node_data { // 获取当前主节点的引用 @@ -82,11 +82,11 @@ impl QueryContext { if is_list { // 如果主节点是列表类型,遍历每个结果,构建主节点及其关联从节点的嵌套结构 - let primary_node_result_list: Vec<_> = results.iter() + let primary_node_result_list: Vec> = results.iter() .map(|result| self.build_primary_value(&namespace, node_name, result, &primary_relate_kv)) .collect(); // 将结果列表插入到响应映射中,键为命名空间 - response_payload.insert(namespace, serde_json::json!(primary_node_result_list)); + response_payload.insert(namespace, json!(primary_node_result_list)); } else { // 如果主节点不是列表类型,取第一个结果(若无则用默认值) let result = results.first().cloned().unwrap_or_default(); @@ -99,23 +99,25 @@ impl QueryContext { } } - let status_code = self.code; - let err_msg = &self.err_msg; - RpcResult::>{ code: status_code, msg: err_msg.to_owned(), data: Some(response_payload) } + response_payload.insert(KEY_CODE.to_string(), Value::Number(Number::from(self.code as i32))); + response_payload.insert(KEY_MSG.to_string(), Value::String(String::from(&self.err_msg.clone().unwrap_or( + if self.code == HttpCode::Ok { MSG_SUCCESS.to_string() } else {"unknown error!".to_string()} + ).to_string()))); + return response_payload } - fn build_primary_value(&self, namespace: &str, primary_node_name: &str, primary_node_data: &HashMap, primary_relate_kv: &HashMap) -> HashMap { - let mut result_map = HashMap::::new(); + fn build_primary_value(&self, namespace: &str, primary_node_name: &str, primary_node_data: &IndexMap, primary_relate_kv: &IndexMap) -> IndexMap { + let mut result_map = IndexMap::::new(); // 主节点数据 - result_map.insert(primary_node_name.to_string(), serde_json::to_value(primary_node_data.clone()).unwrap()); + result_map.insert(primary_node_name.to_string(), to_value(primary_node_data.clone()).unwrap_or_default()); // 从节点数据 for (primary_field, slave_node_field_path) in primary_relate_kv { // 获取主节点中关联字段的值,用于查询从节点数据 - let primary_field_value = primary_node_data.get(primary_field).unwrap(); + let primary_field_value = primary_node_data.get(primary_field).unwrap_or_default(); // 从路径中提取从节点字段名称(取最后一个斜杠后的部分) - let slave_node_field = slave_node_field_path.split("/").last().unwrap(); + let slave_node_field = slave_node_field_path.split("/").last().unwrap_or_default(); // 构建从节点字段值的键,格式为"字段名/字段值" let slave_node_field_value_key = format!("{}/{}", slave_node_field, primary_field_value); // 根据从节点字段路径和值键获取对应的从节点数据 @@ -142,7 +144,7 @@ impl QueryContext { // 如果路径包含"/",表示需要进行嵌套结构转换 if let Some(slave_data) = slave_node_field_data_opt { // 创建一个只有一个键值对的映射,用于转换 - let slave_field_value_map = std::iter::once((node_data_relative_path, slave_data)).collect::>(); + let slave_field_value_map = std::iter::once((node_data_relative_path, slave_data)).collect::>(); // 使用transform_salve_value函数将扁平结构转换为嵌套结构 result_map.extend(transform_salve_value(slave_field_value_map)); } @@ -164,7 +166,7 @@ impl QueryContext { /// /// # 返回值 /// 返回Option,包含从节点数据的JSON值或None - fn get_slave_node_data(&self, slave_node_field_path: &str, slave_node_field_value_key: &str) -> Option { + fn get_slave_node_data(&self, slave_node_field_path: &str, slave_node_field_value_key: &str) -> Option { // 获取从节点路径(去除字段名部分) let slave_node_path = get_parent_node_path(slave_node_field_path); @@ -177,14 +179,15 @@ impl QueryContext { .and_then(|relate_field_data| { // 检查从节点是否为列表类型 let is_list = self.query_node.get(&slave_node_path)?.is_list; + let data = relate_field_data; - if relate_field_data.is_empty() { // 记录空数据日志 + if data.is_array() && data.as_array().unwrap().is_empty() { // 记录空数据日志 log::debug!("slave.data: {}.{} is empty", &slave_node_path, slave_node_field_value_key); None } else if is_list { // 列表类型:直接序列化整个数组 - Some(serde_json::to_value(relate_field_data).unwrap()) + Some(to_value(relate_field_data).unwrap()) } else { // 非列表类型:只取第一个元素序列化 - Some(serde_json::to_value(&relate_field_data[0]).unwrap()) + Some(to_value(&relate_field_data[0]).unwrap()) } } ) @@ -213,14 +216,14 @@ impl QueryContext { } // 处理列表类型结果 - fn process_list_results(&mut self, node: &QueryNode, results: Vec>) { + fn process_list_results(&mut self, node: &QueryNode, results: Vec>) { for result in results { for (k, v) in result { let full_path = format!("{}/{}", node.path, k); if let Some(entry) = self.primary_node_related_field_values.get_mut(&full_path) { match entry { - serde_json::Value::Null => *entry = serde_json::Value::Array(vec![v]), - serde_json::Value::Array(arr) => arr.push(v), + Value::Null => *entry = Value::Array(vec![v]), + Value::Array(arr) => arr.push(v), _ => {} } } @@ -229,7 +232,7 @@ impl QueryContext { } // 处理单个结果 - fn process_single_result(&mut self, node: &QueryNode, result: &HashMap) { + fn process_single_result(&mut self, node: &QueryNode, result: &IndexMap) { for (k, v) in result { let full_path = format!("{}/{}", node.path, k); if let Some(existing_value_slot) = self.primary_node_related_field_values.get_mut(&full_path) { @@ -238,67 +241,146 @@ impl QueryContext { } } - async fn query_relate_node(&mut self, node: &mut QueryNode, manager: &DataSourceManager) { + async fn query_relate_node(&mut self, node: &mut QueryNode, manager: &DataSourceManager) -> Option>> { // 获取当前节点的路径和关联字段映射关系 - let node_path = node.path.clone(); - let node_relate_kv = self.slave_relate_kv.get(&node_path).cloned().unwrap_or_default(); + let node_name = &node.name; // .to_lowercase(); + let node_path = &node.path.clone(); + let node_attrs = &node.attributes; + let mut sql_attrs: IndexMap = IndexMap::new(); + for (field, val) in node_attrs { + // 处理字段名后缀@的情况 + let val2 = val.clone(); + + if field.ends_with('@') && val.is_string() { + let field_key = &field[..field.len() - 1]; + let mut keys = val.as_str().unwrap().split('/').collect::>(); + let kp = keys.join("/"); + + let node_relate_kv = &self.slave_node_relate_data; + let value = node_relate_kv.get(&kp); - // 处理每个关联字段的查询条件 - for (field_name, primary_node_field_path) in &node_relate_kv { - // 从主节点获取关联字段的值 - if let Some(value) = self.primary_node_related_field_values.get(primary_node_field_path) { // 如果值为空则直接返回 - if value.is_null() { - return; - } - // 如果是数组类型,设置分页大小为数组长度 - if let serde_json::Value::Array(array) = value { - node.sql_executor.page_size(serde_json::json!(0), serde_json::json!(array.len())); + let mut val3 = value.unwrap_or(&Value::Null); + if val3.is_null() { + let last_key = keys[keys.len() - 1]; + keys.remove(keys.len() - 1); + let parent_val = node_relate_kv.get(&keys.join("/")); + let parent = parent_val.unwrap_or(&Value::Null); + // 如果值为空则直接返回 + val3 = if parent.is_null() { + &Value::Null + } else { + parent.get(last_key).unwrap_or(&Value::Null) + }; + + if (val3.is_null()) { + return Some(Vec::>::new()); + } } - // 解析查询条件 - node.sql_executor.parse_condition(field_name, value); - } else { + + sql_attrs.insert(field_key.to_string(), val3.clone()); continue; } - // 确保关联字段在查询字段列表中 - node.sql_executor.add_column(field_name); + + sql_attrs.insert(field.to_string(), val2); } - // 执行节点数据查询 - if let Some(node_results) = self.query_node_data(node, manager).await { - // 处理每个关联字段的查询结果 - for (field, _) in &node_relate_kv { - let mut field_map = FnvHashMap::>>::default(); - // 处理字段名后缀@的情况 - let field_key = if field.ends_with('@') { - &field[..field.len() - 1] - } else { - field.as_str() - }; + let sql_attrs2 = sql_attrs.clone(); + node.attributes = sql_attrs; + + // 设置查询的表名 + let _ = node.sql_executor.parse_table(node_name); + + // 解析节点属性中的查询条件 + for (key, value) in sql_attrs2 { + let _ = node.sql_executor.parse_condition(&key, &value); + } + + // 处理列表查询的分页逻辑 + if node.is_list { + let parent_path = get_parent_node_path(node_path); + // 尝试从父节点获取分页参数 + if let Some(parent_node_attrs) = self.namespace_node.get(&parent_path).cloned() { + // 获取页码和每页数量,如果不存在则使用默认值 + let page = parent_node_attrs.get("page").cloned().unwrap_or_else(|| json!(0)); + let count = parent_node_attrs.get("count").cloned().unwrap_or_else(|| json!(DEFAULT_MAX_COUNT)); + node.sql_executor.page_size(page, count); + } else { + // 父节点不存在时使用默认分页参数 + node.sql_executor.page_size(json!(0), json!(DEFAULT_MAX_COUNT)); + } + } + + // 构建SQL语句 + let sql = node.sql_executor.to_sql(); + let params = node.sql_executor.get_string_params(); + + println!("sql = {}", sql); + println!("params = {:?}", params); + + // 获取数据源和数据库名称 + let (datasource_name, database_name) = if let Some(ref defaults) = self.database_defaults { + ( + defaults.default_datasource.as_deref().unwrap_or("default"), + defaults.default_database.as_deref().unwrap_or("sys") + ) + } else { + ("default", "sys") + }; + + // 执行查询 + match manager.query_list(datasource_name, database_name, &sql, params).await { + Ok(results) => { + self.primary_node_data.insert(node.path.clone(), results.clone()); + + // 处理每个关联字段的查询结果 // 遍历查询结果,构建字段映射关系 - for result in &node_results { - if let Some(field_value) = result.get(field_key) { - // 构建字段路径格式:字段名/字段值 - let field_path = format!("{}/{}", field_key, field_value); - // 将结果存入字段映射表 - field_map.entry(field_path).or_insert_with(Vec::new).push(result.clone()); + if ! results.is_empty() { // 将字段映射表存入从节点关联数据 + let node_relate_kv = &self.slave_node_relate_data; + let mut node_relate_kv2 = node_relate_kv.clone(); + + if node.is_list { + let mut i= -1; + for result in &results { + i += 1; + let mut m = Map::new(); + for (k, v) in result { + m.insert(k.clone(), v.clone()); + } + node_relate_kv2.insert(node.path.clone() + "/" + i.to_string().as_str(), Value::Object(m)); + } + } else { + let mut m = Map::new(); + let first = results.get(0).unwrap(); + for (k, v) in first.clone() { + m.insert(k, v); + } + node_relate_kv2.insert(node.path.clone(), Value::Object(m)); } + + self.slave_node_relate_data = node_relate_kv2; } - // 将字段映射表存入从节点关联数据 - self.slave_node_relate_data.insert(node.path.clone(), field_map); + + Some(results) + }, + Err(e) => { + self.err_msg = Some(format!("查询失败: {}", e)); + self.code = HttpCode::InternalServerError; + None } } + } - async fn query_node_data(&mut self, node: &mut QueryNode, manager: &DataSourceManager) -> Option>> { + async fn query_node_data(&mut self, node: &mut QueryNode, manager: &DataSourceManager) -> Option>> { // 准备SQL查询的基本参数 - let node_name = &node.name.to_lowercase(); + let node_name = &node.name; // .to_lowercase(); let node_path = &node.path; let node_attrs = &node.attributes; // 设置查询的表名 let _ = node.sql_executor.parse_table(node_name); - + // 解析节点属性中的查询条件 for (key, value) in node_attrs { let _ = node.sql_executor.parse_condition(key, value); @@ -310,12 +392,12 @@ impl QueryContext { // 尝试从父节点获取分页参数 if let Some(parent_node_attrs) = self.namespace_node.get(&parent_path).cloned() { // 获取页码和每页数量,如果不存在则使用默认值 - let page = parent_node_attrs.get("page").cloned().unwrap_or_else(|| serde_json::json!(0)); - let count = parent_node_attrs.get("count").cloned().unwrap_or_else(|| serde_json::json!(DEFAULT_MAX_COUNT)); + let page = parent_node_attrs.get("page").cloned().unwrap_or_else(|| json!(0)); + let count = parent_node_attrs.get("count").cloned().unwrap_or_else(|| json!(DEFAULT_MAX_COUNT)); node.sql_executor.page_size(page, count); } else { // 父节点不存在时使用默认分页参数 - node.sql_executor.page_size(serde_json::json!(0), serde_json::json!(DEFAULT_MAX_COUNT)); + node.sql_executor.page_size(json!(0), json!(DEFAULT_MAX_COUNT)); } } @@ -327,10 +409,10 @@ impl QueryContext { let (datasource_name, database_name) = if let Some(ref defaults) = self.database_defaults { ( defaults.default_datasource.as_deref().unwrap_or("default"), - defaults.default_database.as_deref().unwrap_or("default") + defaults.default_database.as_deref().unwrap_or("sys") ) } else { - ("default", "default") + ("default", "sys") }; // 执行查询 diff --git a/src/app/handler/head.rs b/src/app/handler/head.rs index def3f61df6e9d661a700183a385c3fb59b0900b7..b0a99bb8da4f42995e2ee9ef411890d80557511a 100644 --- a/src/app/handler/head.rs +++ b/src/app/handler/head.rs @@ -1,16 +1,16 @@ -use std::collections::HashMap; use log::{info, warn, error, debug}; -use crate::app::common::rpc::{RpcResult, HttpCode}; +use crate::app::common::rpc::{HttpCode}; use crate::app::datasource::manager::DataSourceManager; use crate::app::datasource::dialect::SqlBuilder; -use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults}; +use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults, new_err_result, new_ok_result, KEY_CODE, KEY_MSG}; use std::sync::Arc; +use serde_json::{Value, Map, json}; /// 处理HEAD请求的异步方法,主要用于检查表是否存在和记录计数 /// /// # 参数 /// * `datasource_manager` - 数据源管理器 -/// * `body_map` - 包含请求参数的HashMap,键为表名(String),值为查询条件(serde_json::Value) +/// * `body_map` - 包含请求参数的HashMap,键为表名(String),值为查询条件(Value) /// * `defaults` - 默认值提供者 /// /// # 返回值 @@ -22,17 +22,13 @@ use std::sync::Arc; /// - 如果查询失败,返回错误信息 pub async fn handle_head( datasource_manager: Arc, - body_map: HashMap, + body_map: Map, defaults: Option, -) -> RpcResult> { +) -> Map { info!("开始处理HEAD请求"); debug!("请求数据: {:?}", body_map); - - let mut rpc_result = RpcResult { - code: HttpCode::Ok, - msg: None, - data: None, - }; + + let mut result_payload: Map = Map::new(); // 解析请求体 let parse_result = DatabaseTargetParser::parse_request_body(body_map); @@ -40,14 +36,12 @@ pub async fn handle_head( // 如果有解析错误,直接返回 if !parse_result.errors.is_empty() { warn!("请求解析失败: {:?}", parse_result.errors); - rpc_result.code = HttpCode::BadRequest; - rpc_result.msg = Some(format!("解析错误: {}", parse_result.errors.join("; "))); - return rpc_result; + result_payload = new_err_result(HttpCode::BadRequest, parse_result.errors.join("; ").as_str()); + return result_payload; } info!("请求解析成功,共解析到 {} 个数据项", parse_result.items.len()); - let mut result_payload = HashMap::new(); let defaults = defaults.unwrap_or_else(|| { DatabaseTargetDefaults::new(None, None, Some("public".to_string())) }); @@ -76,17 +70,13 @@ pub async fn handle_head( }, Err(err) => { error!("查询失败: {}", err); - rpc_result.code = HttpCode::BadRequest; - result_payload.insert(target.table.clone(), serde_json::Value::String(err)); + result_payload = new_err_result(HttpCode::BadRequest, err.as_str()); break; } } } - - if !result_payload.is_empty() { - rpc_result.data = Some(result_payload); - } - rpc_result + + return new_ok_result(result_payload); } async fn count_one( @@ -95,8 +85,8 @@ async fn count_one( database_name: &str, schema: &str, table: &str, - conditions: &HashMap, -) -> Result { + conditions: &Map, +) -> Result { debug!("开始处理查询: {}.{}, 条件数量: {}", schema, table, conditions.len()); // 获取数据库连接 @@ -118,7 +108,7 @@ async fn count_one( .unwrap_or("*"); // 过滤掉@开头的特殊参数,只保留查询条件 - let query_conditions: HashMap = conditions + let query_conditions: Map = conditions .iter() .filter(|(key, _)| !key.starts_with('@')) .map(|(k, v)| (k.clone(), v.clone())) @@ -146,7 +136,7 @@ async fn count_one( match connection.query_list(&sql, vec![]).await { Ok(rows) => { debug!("查询结果行数: {}", rows.len()); - Ok(serde_json::json!(rows)) + Ok(json!(rows)) }, Err(e) => { error!("SELECT查询失败: {}", e); @@ -169,7 +159,7 @@ async fn count_one( match connection.count(&sql, vec![]).await { Ok(count) => { debug!("统计结果: {}", count); - Ok(serde_json::json!(count)) + Ok(json!(count)) }, Err(e) => { error!("COUNT查询失败: {}", e); diff --git a/src/app/handler/post.rs b/src/app/handler/post.rs index bb9ae5a0d3a050b3a8bdabbad54e0896b1feec17..2176b26861861a440c626e3c8f31bbae5abca7fa 100644 --- a/src/app/handler/post.rs +++ b/src/app/handler/post.rs @@ -1,11 +1,11 @@ -use std::collections::HashMap; use log::{info, warn, error, debug}; -use crate::app::common::rpc::{RpcResult, HttpCode}; +use crate::app::common::rpc::{HttpCode}; use crate::app::common::id::get_next_id; use crate::app::datasource::manager::DataSourceManager; use crate::app::datasource::dialect::SqlBuilder; -use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults}; +use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults, new_err_result, new_ok_result}; use std::sync::Arc; +use serde_json::{Value, Map, Number, json}; /// 处理数据插入请求 /// @@ -20,17 +20,11 @@ use std::sync::Arc; /// * 失败:`{"code": 400, "msg": "错误信息"}` pub async fn handle_post( datasource_manager: Arc, - body_map: HashMap, + body_map: Map, defaults: Option, -) -> RpcResult> { +) -> Map { info!("开始处理POST请求"); debug!("请求数据: {:?}", body_map); - - let mut rpc_result = RpcResult { - code: HttpCode::Ok, - msg: None, - data: None, - }; // 解析请求体 let parse_result = DatabaseTargetParser::parse_request_body(body_map); @@ -38,18 +32,17 @@ pub async fn handle_post( // 如果有解析错误,直接返回 if !parse_result.errors.is_empty() { warn!("请求解析失败: {:?}", parse_result.errors); - rpc_result.code = HttpCode::BadRequest; - rpc_result.msg = Some(format!("解析错误: {}", parse_result.errors.join("; "))); - return rpc_result; + return new_err_result(HttpCode::BadRequest, parse_result.errors.join("; ").as_str()); } info!("请求解析成功,共解析到 {} 个数据项", parse_result.items.len()); - let mut result_payload = HashMap::new(); let defaults = defaults.unwrap_or_else(|| { DatabaseTargetDefaults::new(None, None, Some("public".to_string())) }); + let mut result_payload = Map::new(); + // 处理每个解析后的数据项 for item in parse_result.items { // 应用默认值 @@ -60,12 +53,13 @@ pub async fn handle_post( // 验证目标完整性 if let Err(err) = defaults.validate_target(&target) { error!("目标验证失败: {}, table: {}", err, target.table); - rpc_result.code = HttpCode::BadRequest; - result_payload.insert( - target.table.clone(), - serde_json::json!(format!("目标验证失败: {}", err)) - ); - continue; + return new_err_result(HttpCode::BadRequest, err.as_str()); + + // result_payload.insert( + // target.table.clone(), + // json!(format!("目标验证失败: {}", err)) + // ); + // continue; } // 执行插入操作 @@ -85,20 +79,16 @@ pub async fn handle_post( ).await { Ok(id) => { info!("插入操作成功: {}.{}.{}.{}, id={}", datasource_name, database_name, schema_name, target.table, id); - result_payload.insert(target.table.clone(), serde_json::json!(id)); + result_payload.insert(target.table.clone(), json!(id)); }, Err(err) => { error!("插入操作失败: {}.{}.{}.{}, 错误: {}", datasource_name, database_name, schema_name, target.table, err); - rpc_result.code = HttpCode::BadRequest; - result_payload.insert(target.table.clone(), serde_json::Value::String(err)); + result_payload = new_err_result(HttpCode::BadRequest, err.as_str()); } } } - if !result_payload.is_empty() { - rpc_result.data = Some(result_payload); - } - rpc_result + return new_ok_result(result_payload) } /// 执行单条记录的插入操作 @@ -120,7 +110,7 @@ async fn insert_one( database_name: &str, schema: &str, table: &str, - data: &HashMap, + data: &Map, ) -> Result { debug!("开始执行单条插入操作: {}.{}.{}.{}", datasource_name, database_name, schema, table); debug!("插入数据: {:?}", data); @@ -155,7 +145,7 @@ async fn insert_one( let mut insert_data = data.clone(); let data_id = get_next_id(); debug!("生成记录ID: {}", data_id); - insert_data.insert("id".to_string(), serde_json::Value::Number(serde_json::Number::from(data_id))); + insert_data.insert("id".to_string(), Value::Number(Number::from(data_id))); // 构建INSERT SQL let sql = sql_builder.build_insert(schema, table, &insert_data); diff --git a/src/app/handler/put.rs b/src/app/handler/put.rs index 702d51fffadb0d26598025aba3b7fee945f1374b..a7c2dcbb0abe1f1cbfa83b32298fefdd7fb4ceb7 100644 --- a/src/app/handler/put.rs +++ b/src/app/handler/put.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; use std::sync::Arc; use log::{info, warn, error, debug}; use crate::app::datasource::manager::DataSourceManager; use crate::app::datasource::dialect::SqlBuilder; -use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults}; -use crate::app::common::rpc::{RpcResult, HttpCode}; +use crate::app::handler::util::parser::{DatabaseTargetParser, DatabaseTargetDefaults, new_err_result, new_ok_result}; +use crate::app::common::rpc::{HttpCode}; +use serde_json::{Value, Map}; /// 处理数据更新请求 /// @@ -34,17 +34,11 @@ use crate::app::common::rpc::{RpcResult, HttpCode}; /// ``` pub async fn handle_put( datasource_manager: Arc, - body_map: HashMap, + body_map: Map, defaults: Option, -) -> RpcResult> { +) -> Map { info!("开始处理PUT请求"); debug!("请求数据: {:?}", body_map); - - let mut rpc_result = RpcResult::>{ - code: HttpCode::Ok, - msg: None, - data: None - }; // 解析请求体 let parse_result = DatabaseTargetParser::parse_request_body(body_map); @@ -52,9 +46,7 @@ pub async fn handle_put( // 检查解析错误 if !parse_result.errors.is_empty() { warn!("请求解析失败: {:?}", parse_result.errors); - rpc_result.code = HttpCode::BadRequest; - rpc_result.msg = Some(format!("请求解析失败: {}", parse_result.errors.join(", "))); - return rpc_result; + return new_err_result(HttpCode::BadRequest, parse_result.errors.join("; ").as_str()); } info!("请求解析成功,共解析到 {} 个数据项", parse_result.items.len()); @@ -69,16 +61,14 @@ pub async fn handle_put( parse_result.items }; - let mut result_payload = HashMap::new(); - + let mut result_payload = Map::new(); + for item in items_with_defaults { // 验证目标完整性 if item.target.datasource.is_none() || item.target.database.is_none() || item.target.schema.is_none() { error!("缺少必要的数据库目标信息: datasource={:?}, database={:?}, schema={:?}", item.target.datasource, item.target.database, item.target.schema); - rpc_result.code = HttpCode::BadRequest; - rpc_result.msg = Some("缺少必要的数据库目标信息".to_string()); - return rpc_result; + return new_err_result(HttpCode::BadRequest, "缺少必要的数据库目标信息"); } let target = item.target; @@ -104,16 +94,12 @@ pub async fn handle_put( }, Err(err) => { error!("更新操作失败: {}.{}.{}.{}, 错误: {}", datasource_name, database_name, schema_name, target.table, err); - rpc_result.code = HttpCode::BadRequest; - result_payload.insert(target.table.clone(), serde_json::json!(err)); + return new_err_result(HttpCode::BadRequest, err.as_str()); } } } - - if !result_payload.is_empty() { - rpc_result.data = Some(result_payload); - } - rpc_result + + return new_ok_result(result_payload) } /// 执行单条记录的更新操作 @@ -135,7 +121,7 @@ async fn update_one( database_name: &str, schema: &str, table: &str, - data: &HashMap, + data: &Map, ) -> Result { debug!("开始执行单条更新操作: {}.{}.{}.{}", datasource_name, database_name, schema, table); debug!("更新数据: {:?}", data); diff --git a/src/app/handler/util/mod.rs b/src/app/handler/util/mod.rs index f4268743cb0e6876bd768635af106939fea46b84..cf3aed80faddda7f0c19944420ece1b7726b12cc 100644 --- a/src/app/handler/util/mod.rs +++ b/src/app/handler/util/mod.rs @@ -1,2 +1,3 @@ pub mod parser; -pub mod transform; \ No newline at end of file +pub mod transform; +pub(crate) mod string_util; \ No newline at end of file diff --git a/src/app/handler/util/parser.rs b/src/app/handler/util/parser.rs index def1f0c90a03ba2c36a3c583d9de99f25013fadb..05bdf5df86e4d28a8719b629065597932d58a725 100644 --- a/src/app/handler/util/parser.rs +++ b/src/app/handler/util/parser.rs @@ -1,6 +1,27 @@ -use std::collections::HashMap; -use serde_json::Value; +use serde_json::{Number, Value, Map}; use log::{debug, info, warn, error}; +use crate::app::common::rpc::HttpCode; + +// https://github.com/Tencent/APIJSON/blob/master/APIJSONORM/src/main/java/apijson/JSONResponse.java#L51-L80 + +pub const KEY_CODE: &str = "code"; +pub const KEY_MSG: &str = "msg"; +pub const MSG_SUCCESS: &str = "success"; + +pub fn new_err_result(code: HttpCode, msg: &str) -> Map { + let mut result_payload = Map::new(); + result_payload.insert(String::from(KEY_CODE), Value::Number(Number::from(code as i32))); + result_payload.insert(String::from(KEY_MSG), Value::String(String::from(msg))); + return result_payload +} + +pub fn new_ok_result(data: Map) -> Map { + let mut result_payload = data.clone(); + result_payload.insert(String::from(KEY_CODE), Value::Number(Number::from(HttpCode::Ok as i32))); + result_payload.insert(String::from(KEY_MSG), Value::String(String::from(MSG_SUCCESS))); + return result_payload +} + /// 数据库四元素信息 #[derive(Debug, Clone)] @@ -21,7 +42,7 @@ pub struct ParsedDataItem { /// 数据库目标信息 pub target: DatabaseTarget, /// 清理后的数据(移除了元数据字段) - pub data: HashMap, + pub data: Map, } /// 解析结果 @@ -44,7 +65,7 @@ impl DatabaseTargetParser { /// /// # 返回值 /// 返回解析结果,包含所有解析后的数据项和错误信息 - pub fn parse_request_body(body_map: HashMap) -> ParseResult { + pub fn parse_request_body(body_map: Map) -> ParseResult { info!("开始解析请求体,包含 {} 个表项", body_map.len()); debug!("请求体内容: {:?}", body_map); @@ -94,28 +115,46 @@ impl DatabaseTargetParser { debug!("解析表数据: table_name={}, is_array={}", table_name, is_array); if is_array { - // 处理数组数据 - debug!("处理数组数据"); - match param.as_array() { - Some(array) => { - debug!("数组包含 {} 个元素", array.len()); - for (index, item) in array.iter().enumerate() { - match item.as_object() { - Some(obj) => { - debug!("解析数组第 {} 项", index); - let parsed_item = Self::parse_single_item(table_name, obj.clone(), true)?; - items.push(parsed_item); - } - None => { - error!("数组第{}项不是有效的对象: {:?}", index, item); - return Err(format!("数组第{}项不是有效的对象", index)); + if param.is_object() { + let obj = param.as_object().unwrap(); + let count_val = obj.get("count"); + let page_val = obj.get("page"); + + let count = if count_val == None {10} else { count_val.unwrap().as_i64().unwrap()}; + let page = if page_val == None {0} else { page_val.unwrap().as_i64().unwrap()}; + + for i in 0..count { + let parsed_item = Self::parse_single_item(table_name, obj.clone(), true)?; + if parsed_item.data.is_empty() { + break + } + + items.push(parsed_item); + } + } else { + // 处理数组数据 + debug!("处理数组数据"); + match param.as_array() { + Some(array) => { + debug!("数组包含 {} 个元素", array.len()); + for (index, item) in array.iter().enumerate() { + match item.as_object() { + Some(obj) => { + debug!("解析数组第 {} 项", index); + let parsed_item = Self::parse_single_item(table_name, obj.clone(), true)?; + items.push(parsed_item); + } + None => { + error!("数组第{}项不是有效的对象: {:?}", index, item); + return Err(format!("数组第{}项不是有效的对象", index)); + } } } } - } - None => { - error!("期望数组类型,但收到: {:?}", param); - return Err("期望数组类型,但收到其他类型".to_string()); + None => { + error!("期望数组类型,但收到: {:?}", param); + return Err("期望数组类型,但收到其他类型".to_string()); + } } } } else { @@ -152,7 +191,7 @@ impl DatabaseTargetParser { ) -> Result { debug!("解析单个数据项: table={}, is_array={}", table_name, is_array); - let mut data = HashMap::new(); + let mut data = Map::new(); let mut datasource = None; let mut database = None; let mut schema = None; @@ -280,7 +319,7 @@ mod tests { #[test] fn test_parse_single_object() { - let mut body_map = HashMap::new(); + let mut body_map = Map::new(); body_map.insert( "users".to_string(), json!({ @@ -306,7 +345,7 @@ mod tests { #[test] fn test_parse_array_data() { - let mut body_map = HashMap::new(); + let mut body_map = Map::new(); body_map.insert( "moment[]".to_string(), json!([ diff --git a/src/app/handler/util/string_util.rs b/src/app/handler/util/string_util.rs new file mode 100644 index 0000000000000000000000000000000000000000..6f650d8641f045f70492bff0b4450616a8bc291a --- /dev/null +++ b/src/app/handler/util/string_util.rs @@ -0,0 +1,26 @@ +use regex::Regex; + +const REGEXP_NUM: &str = r"^[0-9]*$"; +const REGEXP_NAME: &str = r"^[a-zA-Z][a-zA-Z0-9_]*$"; +const REGEXP_BIG_NAME: &str = r"^[A-Z][a-zA-Z0-9_]*$"; +const REGEXP_SMALL_NAME: &str = r"^[a-z][a-zA-Z0-9_]*$"; + +pub fn is_num(s: &str) -> bool { + let re = Regex::new(REGEXP_NUM).unwrap(); + return re.is_match(s) +} + +pub fn is_name(s: &str) -> bool { + let re = Regex::new(REGEXP_NAME).unwrap(); + return re.is_match(s) +} + +pub fn is_big_name(s: &str) -> bool { + let re = Regex::new(REGEXP_BIG_NAME).unwrap(); + return re.is_match(s) +} + +pub fn is_small_name(s: &str) -> bool { + let re = Regex::new(REGEXP_SMALL_NAME).unwrap(); + return re.is_match(s) +} \ No newline at end of file diff --git a/src/app/handler/util/transform.rs b/src/app/handler/util/transform.rs index 152d3ba0efe73379cc50c361bd3b761187204733..fa4371ce64a48a834c89b383ed9e51ca5769ec6a 100644 --- a/src/app/handler/util/transform.rs +++ b/src/app/handler/util/transform.rs @@ -1,8 +1,8 @@ -use std::collections::HashMap; +use indexmap::IndexMap; /// 将一个平铺了 "a/b/c" 风格键的 JSON 对象,转换为嵌套版本 -pub fn transform_salve_value(input_map: HashMap) -> HashMap { - let mut result_map = HashMap::new(); +pub fn transform_salve_value(input_map: IndexMap) -> IndexMap { + let mut result_map = IndexMap::new(); for (raw_key, value) in input_map { // 拆分路径 @@ -65,7 +65,7 @@ fn build_nested_object(segments: &[&str], value: serde_json::Value) -> serde_jso #[cfg(test)] mod tests { - use std::collections::HashMap; + use indexmap::IndexMap; use crate::app::handler::util::transform::transform_salve_value; #[test] @@ -95,7 +95,7 @@ mod tests { for (i, src) in test_cases.iter().enumerate() { println!("测试用例 {}", i + 1); let v: serde_json::Value = serde_json::from_str(src).unwrap(); - let input_map: HashMap = v.as_object() + let input_map: IndexMap = v.as_object() .unwrap() .iter() .map(|(k, v)| (k.clone(), v.clone())) diff --git a/src/app/server.rs b/src/app/server.rs index b7547ffc4df3ac93a306ed0af64f05447b46b090..9c646565572fa5ea26f0bf836efcbb13cb80dbaa 100644 --- a/src/app/server.rs +++ b/src/app/server.rs @@ -9,7 +9,7 @@ use axum::{ routing::{get, post}, Router, }; -use serde_json::{json, Value}; +use serde_json::{json, Value, Map}; use std::collections::HashMap; use tower_http::cors::CorsLayer; use crate::app::{ @@ -22,11 +22,13 @@ use crate::app::{ delete::handle_delete, }, }; +use crate::app::common::rpc::HttpCode; +use crate::app::handler::util::parser::new_err_result; /// 创建HTTP服务器路由 pub fn create_router(datasource_manager: Arc) -> Router { let mgr = datasource_manager.clone(); - let curd_handler = move |Path(method): Path, Json(request_data): Json>| { + let curd_handler = move |Path(method): Path, Json(request_data): Json>| { let mgr = mgr.clone(); async move { let method_norm = method.strip_suffix(".json").unwrap_or(&method); @@ -47,18 +49,10 @@ pub fn create_router(datasource_manager: Arc) -> Router { handle_delete(mgr.clone(), request_data).await } _ => { - return Json(json!({ - "code": 400u16, - "data": serde_json::Value::Null, - "msg": format!("unknown method: {}", method_norm) - })); + return Json(json!(new_err_result(HttpCode::MethodNotAllowed, format!("unknown method: {}", method_norm).as_str()))); } }; - Json(json!({ - "code": rpc_result.code as u16, - "data": rpc_result.data, - "msg": rpc_result.msg - })) + Json(json!(rpc_result)) } };