perf: optimize HTTP client, DB queries, and clean up dead code
- Make reqwest::Client a LazyLock singleton instead of per-call allocation - Parallelize 3 independent DB queries in get_index_advisor_report with tokio::join! - Eliminate per-iteration Vec allocation in snapshot FK dependency loop - Hoist try_local_pg_dump() call in SampleData clone mode to avoid double execution - Evict stale schema cache entries on write to prevent unbounded memory growth - Remove unused ValidationReport struct and config_path field - Rename IndexRecommendationType variants to remove redundant suffix
This commit is contained in:
@@ -2,10 +2,10 @@ use crate::commands::data::bind_json_value;
|
||||
use crate::commands::queries::pg_value_to_json;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::ai::{
|
||||
AiProvider, AiSettings, GenerateDataParams, GeneratedDataPreview, GeneratedTableData,
|
||||
IndexAdvisorReport, IndexRecommendation, IndexStats,
|
||||
OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse,
|
||||
SlowQuery, TableStats, ValidationRule, ValidationStatus, DataGenProgress,
|
||||
AiProvider, AiSettings, DataGenProgress, GenerateDataParams, GeneratedDataPreview,
|
||||
GeneratedTableData, IndexAdvisorReport, IndexRecommendation, IndexStats, OllamaChatMessage,
|
||||
OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse, SlowQuery, TableStats,
|
||||
ValidationRule, ValidationStatus,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use crate::utils::{escape_ident, topological_sort_tables};
|
||||
@@ -20,12 +20,16 @@ use tauri::{AppHandle, Emitter, Manager, State};
|
||||
const MAX_RETRIES: u32 = 2;
|
||||
const RETRY_DELAY_MS: u64 = 1000;
|
||||
|
||||
fn http_client() -> reqwest::Client {
|
||||
reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.timeout(Duration::from_secs(300))
|
||||
.build()
|
||||
.unwrap_or_default()
|
||||
fn http_client() -> &'static reqwest::Client {
|
||||
use std::sync::LazyLock;
|
||||
static CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
||||
reqwest::Client::builder()
|
||||
.connect_timeout(Duration::from_secs(5))
|
||||
.timeout(Duration::from_secs(300))
|
||||
.build()
|
||||
.unwrap_or_default()
|
||||
});
|
||||
&CLIENT
|
||||
}
|
||||
|
||||
fn get_ai_settings_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
|
||||
@@ -65,11 +69,10 @@ pub async fn save_ai_settings(
|
||||
#[tauri::command]
|
||||
pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaModel>> {
|
||||
let url = format!("{}/api/tags", ollama_url.trim_end_matches('/'));
|
||||
let resp = http_client()
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e)))?;
|
||||
let resp =
|
||||
http_client().get(&url).send().await.map_err(|e| {
|
||||
TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e))
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
@@ -119,7 +122,10 @@ where
|
||||
}
|
||||
|
||||
Err(last_error.unwrap_or_else(|| {
|
||||
TuskError::Ai(format!("{} failed after {} attempts", operation, MAX_RETRIES))
|
||||
TuskError::Ai(format!(
|
||||
"{} failed after {} attempts",
|
||||
operation, MAX_RETRIES
|
||||
))
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -164,10 +170,7 @@ async fn call_ollama_chat(
|
||||
}
|
||||
|
||||
let model = settings.model.clone();
|
||||
let url = format!(
|
||||
"{}/api/chat",
|
||||
settings.ollama_url.trim_end_matches('/')
|
||||
);
|
||||
let url = format!("{}/api/chat", settings.ollama_url.trim_end_matches('/'));
|
||||
|
||||
let request = OllamaChatRequest {
|
||||
model: model.clone(),
|
||||
@@ -194,10 +197,7 @@ async fn call_ollama_chat(
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
TuskError::Ai(format!(
|
||||
"Cannot connect to Ollama at {}: {}",
|
||||
url, e
|
||||
))
|
||||
TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", url, e))
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
@@ -379,10 +379,7 @@ pub async fn fix_sql_error(
|
||||
schema_text
|
||||
);
|
||||
|
||||
let user_content = format!(
|
||||
"SQL query:\n{}\n\nError message:\n{}",
|
||||
sql, error_message
|
||||
);
|
||||
let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);
|
||||
|
||||
let raw = call_ollama_chat(&app, &state, system_prompt, user_content).await?;
|
||||
Ok(clean_sql_response(&raw))
|
||||
@@ -405,9 +402,15 @@ pub(crate) async fn build_schema_context(
|
||||
|
||||
// Run all metadata queries in parallel for speed
|
||||
let (
|
||||
version_res, col_res, fk_res, enum_res,
|
||||
tbl_comment_res, col_comment_res, unique_res,
|
||||
varchar_res, jsonb_res,
|
||||
version_res,
|
||||
col_res,
|
||||
fk_res,
|
||||
enum_res,
|
||||
tbl_comment_res,
|
||||
col_comment_res,
|
||||
unique_res,
|
||||
varchar_res,
|
||||
jsonb_res,
|
||||
) = tokio::join!(
|
||||
sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool),
|
||||
fetch_columns(&pool),
|
||||
@@ -586,10 +589,9 @@ pub(crate) async fn build_schema_context(
|
||||
// Unique constraints for this table
|
||||
let schema_table: Vec<&str> = full_name.splitn(2, '.').collect();
|
||||
if schema_table.len() == 2 {
|
||||
if let Some(uqs) = unique_map.get(&(
|
||||
schema_table[0].to_string(),
|
||||
schema_table[1].to_string(),
|
||||
)) {
|
||||
if let Some(uqs) =
|
||||
unique_map.get(&(schema_table[0].to_string(), schema_table[1].to_string()))
|
||||
{
|
||||
for uq_cols in uqs {
|
||||
output.push(format!(" UNIQUE({})", uq_cols));
|
||||
}
|
||||
@@ -609,7 +611,9 @@ pub(crate) async fn build_schema_context(
|
||||
let result = output.join("\n");
|
||||
|
||||
// Cache the result
|
||||
state.set_schema_cache(connection_id.to_string(), result.clone()).await;
|
||||
state
|
||||
.set_schema_cache(connection_id.to_string(), result.clone())
|
||||
.await;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
@@ -931,10 +935,7 @@ async fn fetch_jsonb_keys(
|
||||
|
||||
let query = parts.join(" UNION ALL ");
|
||||
|
||||
let rows = match sqlx::query(&query)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
{
|
||||
let rows = match sqlx::query(&query).fetch_all(pool).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
log::warn!("Failed to fetch JSONB keys: {}", e);
|
||||
@@ -1033,6 +1034,26 @@ fn simplify_default(raw: &str) -> String {
|
||||
s.to_string()
|
||||
}
|
||||
|
||||
fn validate_select_statement(sql: &str) -> TuskResult<()> {
|
||||
let sql_upper = sql.trim().to_uppercase();
|
||||
if !sql_upper.starts_with("SELECT") {
|
||||
return Err(TuskError::Custom(
|
||||
"Validation query must be a SELECT statement".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_index_ddl(ddl: &str) -> TuskResult<()> {
|
||||
let ddl_upper = ddl.trim().to_uppercase();
|
||||
if !ddl_upper.starts_with("CREATE INDEX") && !ddl_upper.starts_with("DROP INDEX") {
|
||||
return Err(TuskError::Custom(
|
||||
"Only CREATE INDEX and DROP INDEX statements are allowed".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn clean_sql_response(raw: &str) -> String {
|
||||
let trimmed = raw.trim();
|
||||
// Remove markdown code fences
|
||||
@@ -1098,18 +1119,13 @@ pub async fn run_validation_rule(
|
||||
sql: String,
|
||||
sample_limit: Option<u32>,
|
||||
) -> TuskResult<ValidationRule> {
|
||||
let sql_upper = sql.trim().to_uppercase();
|
||||
if !sql_upper.starts_with("SELECT") {
|
||||
return Err(TuskError::Custom(
|
||||
"Validation query must be a SELECT statement".to_string(),
|
||||
));
|
||||
}
|
||||
validate_select_statement(&sql)?;
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
let limit = sample_limit.unwrap_or(10);
|
||||
let _start = Instant::now();
|
||||
|
||||
let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
|
||||
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
|
||||
sqlx::query("SET TRANSACTION READ ONLY")
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
@@ -1199,7 +1215,13 @@ pub async fn suggest_validation_rules(
|
||||
schema_text
|
||||
);
|
||||
|
||||
let raw = call_ollama_chat(&app, &state, system_prompt, "Suggest validation rules".to_string()).await?;
|
||||
let raw = call_ollama_chat(
|
||||
&app,
|
||||
&state,
|
||||
system_prompt,
|
||||
"Suggest validation rules".to_string(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let cleaned = raw.trim();
|
||||
let json_start = cleaned.find('[').unwrap_or(0);
|
||||
@@ -1207,7 +1229,10 @@ pub async fn suggest_validation_rules(
|
||||
let json_str = &cleaned[json_start..json_end];
|
||||
|
||||
let rules: Vec<String> = serde_json::from_str(json_str).map_err(|e| {
|
||||
TuskError::Ai(format!("Failed to parse AI response as JSON array: {}. Response: {}", e, cleaned))
|
||||
TuskError::Ai(format!(
|
||||
"Failed to parse AI response as JSON array: {}. Response: {}",
|
||||
e, cleaned
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(rules)
|
||||
@@ -1226,13 +1251,16 @@ pub async fn generate_test_data_preview(
|
||||
) -> TuskResult<GeneratedDataPreview> {
|
||||
let pool = state.get_pool(¶ms.connection_id).await?;
|
||||
|
||||
let _ = app.emit("datagen-progress", DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "schema".to_string(),
|
||||
percent: 10,
|
||||
message: "Building schema context...".to_string(),
|
||||
detail: None,
|
||||
});
|
||||
let _ = app.emit(
|
||||
"datagen-progress",
|
||||
DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "schema".to_string(),
|
||||
percent: 10,
|
||||
message: "Building schema context...".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
let schema_text = build_schema_context(&state, ¶ms.connection_id).await?;
|
||||
|
||||
@@ -1255,7 +1283,14 @@ pub async fn generate_test_data_preview(
|
||||
|
||||
let fk_edges: Vec<(String, String, String, String)> = fk_rows
|
||||
.iter()
|
||||
.map(|fk| (fk.schema.clone(), fk.table.clone(), fk.ref_schema.clone(), fk.ref_table.clone()))
|
||||
.map(|fk| {
|
||||
(
|
||||
fk.schema.clone(),
|
||||
fk.table.clone(),
|
||||
fk.ref_schema.clone(),
|
||||
fk.ref_table.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let sorted_tables = topological_sort_tables(&fk_edges, &target_tables);
|
||||
@@ -1266,13 +1301,16 @@ pub async fn generate_test_data_preview(
|
||||
|
||||
let row_count = params.row_count.min(1000);
|
||||
|
||||
let _ = app.emit("datagen-progress", DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "generating".to_string(),
|
||||
percent: 30,
|
||||
message: "AI is generating test data...".to_string(),
|
||||
detail: None,
|
||||
});
|
||||
let _ = app.emit(
|
||||
"datagen-progress",
|
||||
DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "generating".to_string(),
|
||||
percent: 30,
|
||||
message: "AI is generating test data...".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
let tables_desc: Vec<String> = sorted_tables
|
||||
.iter()
|
||||
@@ -1329,13 +1367,16 @@ pub async fn generate_test_data_preview(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let _ = app.emit("datagen-progress", DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "parsing".to_string(),
|
||||
percent: 80,
|
||||
message: "Parsing generated data...".to_string(),
|
||||
detail: None,
|
||||
});
|
||||
let _ = app.emit(
|
||||
"datagen-progress",
|
||||
DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "parsing".to_string(),
|
||||
percent: 80,
|
||||
message: "Parsing generated data...".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
// Parse JSON response
|
||||
let cleaned = raw.trim();
|
||||
@@ -1343,9 +1384,13 @@ pub async fn generate_test_data_preview(
|
||||
let json_end = cleaned.rfind('}').map(|i| i + 1).unwrap_or(cleaned.len());
|
||||
let json_str = &cleaned[json_start..json_end];
|
||||
|
||||
let data_map: HashMap<String, Vec<HashMap<String, Value>>> =
|
||||
serde_json::from_str(json_str).map_err(|e| {
|
||||
TuskError::Ai(format!("Failed to parse generated data: {}. Response: {}", e, &cleaned[..cleaned.len().min(500)]))
|
||||
let data_map: HashMap<String, Vec<HashMap<String, Value>>> = serde_json::from_str(json_str)
|
||||
.map_err(|e| {
|
||||
TuskError::Ai(format!(
|
||||
"Failed to parse generated data: {}. Response: {}",
|
||||
e,
|
||||
&cleaned[..cleaned.len().min(500)]
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut tables = Vec::new();
|
||||
@@ -1362,7 +1407,12 @@ pub async fn generate_test_data_preview(
|
||||
|
||||
let rows: Vec<Vec<Value>> = rows_data
|
||||
.iter()
|
||||
.map(|row_map| columns.iter().map(|col| row_map.get(col).cloned().unwrap_or(Value::Null)).collect())
|
||||
.map(|row_map| {
|
||||
columns
|
||||
.iter()
|
||||
.map(|col| row_map.get(col).cloned().unwrap_or(Value::Null))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let count = rows.len() as u32;
|
||||
@@ -1378,13 +1428,20 @@ pub async fn generate_test_data_preview(
|
||||
}
|
||||
}
|
||||
|
||||
let _ = app.emit("datagen-progress", DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "done".to_string(),
|
||||
percent: 100,
|
||||
message: "Data generation complete".to_string(),
|
||||
detail: Some(format!("{} rows across {} tables", total_rows, tables.len())),
|
||||
});
|
||||
let _ = app.emit(
|
||||
"datagen-progress",
|
||||
DataGenProgress {
|
||||
gen_id: gen_id.clone(),
|
||||
stage: "done".to_string(),
|
||||
percent: 100,
|
||||
message: "Data generation complete".to_string(),
|
||||
detail: Some(format!(
|
||||
"{} rows across {} tables",
|
||||
total_rows,
|
||||
tables.len()
|
||||
)),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(GeneratedDataPreview {
|
||||
tables,
|
||||
@@ -1404,7 +1461,7 @@ pub async fn insert_generated_data(
|
||||
}
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
|
||||
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
|
||||
|
||||
// Defer constraints for circular FKs
|
||||
sqlx::query("SET CONSTRAINTS ALL DEFERRED")
|
||||
@@ -1466,20 +1523,38 @@ pub async fn get_index_advisor_report(
|
||||
) -> TuskResult<IndexAdvisorReport> {
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
// Fetch table stats
|
||||
let table_stats_rows = sqlx::query(
|
||||
"SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \
|
||||
pg_size_pretty(pg_total_relation_size(schemaname || '.' || relname)) AS table_size, \
|
||||
pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \
|
||||
FROM pg_stat_user_tables \
|
||||
ORDER BY seq_scan DESC \
|
||||
LIMIT 50"
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
// Fetch table stats, index stats, and slow queries concurrently
|
||||
let (table_stats_res, index_stats_res, slow_queries_res) = tokio::join!(
|
||||
sqlx::query(
|
||||
"SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \
|
||||
pg_size_pretty(pg_total_relation_size(schemaname || '.' || relname)) AS table_size, \
|
||||
pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \
|
||||
FROM pg_stat_user_tables \
|
||||
ORDER BY seq_scan DESC \
|
||||
LIMIT 50"
|
||||
)
|
||||
.fetch_all(&pool),
|
||||
sqlx::query(
|
||||
"SELECT schemaname, relname, indexrelname, idx_scan, \
|
||||
pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \
|
||||
pg_get_indexdef(indexrelid) AS definition \
|
||||
FROM pg_stat_user_indexes \
|
||||
ORDER BY idx_scan ASC \
|
||||
LIMIT 50",
|
||||
)
|
||||
.fetch_all(&pool),
|
||||
sqlx::query(
|
||||
"SELECT query, calls, total_exec_time, mean_exec_time, rows \
|
||||
FROM pg_stat_statements \
|
||||
WHERE calls > 0 \
|
||||
ORDER BY mean_exec_time DESC \
|
||||
LIMIT 20",
|
||||
)
|
||||
.fetch_all(&pool),
|
||||
);
|
||||
|
||||
let table_stats: Vec<TableStats> = table_stats_rows
|
||||
let table_stats: Vec<TableStats> = table_stats_res
|
||||
.map_err(TuskError::Database)?
|
||||
.iter()
|
||||
.map(|r| TableStats {
|
||||
schema: r.get(0),
|
||||
@@ -1492,20 +1567,8 @@ pub async fn get_index_advisor_report(
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Fetch index stats
|
||||
let index_stats_rows = sqlx::query(
|
||||
"SELECT schemaname, relname, indexrelname, idx_scan, \
|
||||
pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \
|
||||
pg_get_indexdef(indexrelid) AS definition \
|
||||
FROM pg_stat_user_indexes \
|
||||
ORDER BY idx_scan ASC \
|
||||
LIMIT 50"
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let index_stats: Vec<IndexStats> = index_stats_rows
|
||||
let index_stats: Vec<IndexStats> = index_stats_res
|
||||
.map_err(TuskError::Database)?
|
||||
.iter()
|
||||
.map(|r| IndexStats {
|
||||
schema: r.get(0),
|
||||
@@ -1517,17 +1580,7 @@ pub async fn get_index_advisor_report(
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Fetch slow queries (graceful if pg_stat_statements not available)
|
||||
let (slow_queries, has_pg_stat_statements) = match sqlx::query(
|
||||
"SELECT query, calls, total_exec_time, mean_exec_time, rows \
|
||||
FROM pg_stat_statements \
|
||||
WHERE calls > 0 \
|
||||
ORDER BY mean_exec_time DESC \
|
||||
LIMIT 20"
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
{
|
||||
let (slow_queries, has_pg_stat_statements) = match slow_queries_res {
|
||||
Ok(rows) => {
|
||||
let queries: Vec<SlowQuery> = rows
|
||||
.iter()
|
||||
@@ -1551,7 +1604,13 @@ pub async fn get_index_advisor_report(
|
||||
for ts in &table_stats {
|
||||
stats_text.push_str(&format!(
|
||||
" {}.{}: seq_scan={}, idx_scan={}, rows={}, size={}, idx_size={}\n",
|
||||
ts.schema, ts.table, ts.seq_scan, ts.idx_scan, ts.n_live_tup, ts.table_size, ts.index_size
|
||||
ts.schema,
|
||||
ts.table,
|
||||
ts.seq_scan,
|
||||
ts.idx_scan,
|
||||
ts.n_live_tup,
|
||||
ts.table_size,
|
||||
ts.index_size
|
||||
));
|
||||
}
|
||||
|
||||
@@ -1568,7 +1627,10 @@ pub async fn get_index_advisor_report(
|
||||
for sq in &slow_queries {
|
||||
stats_text.push_str(&format!(
|
||||
" calls={}, mean={:.1}ms, total={:.1}ms, rows={}: {}\n",
|
||||
sq.calls, sq.mean_time_ms, sq.total_time_ms, sq.rows,
|
||||
sq.calls,
|
||||
sq.mean_time_ms,
|
||||
sq.total_time_ms,
|
||||
sq.rows,
|
||||
sq.query.chars().take(200).collect::<String>()
|
||||
));
|
||||
}
|
||||
@@ -1635,12 +1697,7 @@ pub async fn apply_index_recommendation(
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let ddl_upper = ddl.trim().to_uppercase();
|
||||
if !ddl_upper.starts_with("CREATE INDEX") && !ddl_upper.starts_with("DROP INDEX") {
|
||||
return Err(TuskError::Custom(
|
||||
"Only CREATE INDEX and DROP INDEX statements are allowed".to_string(),
|
||||
));
|
||||
}
|
||||
validate_index_ddl(&ddl)?;
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
@@ -1655,3 +1712,151 @@ pub async fn apply_index_recommendation(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── validate_select_statement ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn select_valid_simple() {
|
||||
assert!(validate_select_statement("SELECT 1").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_valid_with_leading_whitespace() {
|
||||
assert!(validate_select_statement(" SELECT * FROM users").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_valid_lowercase() {
|
||||
assert!(validate_select_statement("select * from users").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_rejects_insert() {
|
||||
assert!(validate_select_statement("INSERT INTO users VALUES (1)").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_rejects_delete() {
|
||||
assert!(validate_select_statement("DELETE FROM users").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_rejects_drop() {
|
||||
assert!(validate_select_statement("DROP TABLE users").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_rejects_empty() {
|
||||
assert!(validate_select_statement("").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn select_rejects_whitespace_only() {
|
||||
assert!(validate_select_statement(" ").is_err());
|
||||
}
|
||||
|
||||
// NOTE: This test documents a known weakness — SELECT prefix allows injection
|
||||
#[test]
|
||||
fn select_allows_semicolon_after_select() {
|
||||
// "SELECT 1; DROP TABLE users" starts with SELECT — passes validation
|
||||
// This is a known limitation documented in the review
|
||||
assert!(validate_select_statement("SELECT 1; DROP TABLE users").is_ok());
|
||||
}
|
||||
|
||||
// ── validate_index_ddl ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn ddl_valid_create_index() {
|
||||
assert!(validate_index_ddl("CREATE INDEX idx_name ON users(email)").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_valid_create_index_concurrently() {
|
||||
assert!(validate_index_ddl("CREATE INDEX CONCURRENTLY idx ON t(c)").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_valid_drop_index() {
|
||||
assert!(validate_index_ddl("DROP INDEX idx_name").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_valid_with_leading_whitespace() {
|
||||
assert!(validate_index_ddl(" CREATE INDEX idx ON t(c)").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_valid_lowercase() {
|
||||
assert!(validate_index_ddl("create index idx on t(c)").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_rejects_create_table() {
|
||||
assert!(validate_index_ddl("CREATE TABLE evil(id int)").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_rejects_drop_table() {
|
||||
assert!(validate_index_ddl("DROP TABLE users").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_rejects_alter_table() {
|
||||
assert!(validate_index_ddl("ALTER TABLE users ADD COLUMN x int").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ddl_rejects_empty() {
|
||||
assert!(validate_index_ddl("").is_err());
|
||||
}
|
||||
|
||||
// NOTE: Documents bypass weakness — semicolon after valid prefix
|
||||
#[test]
|
||||
fn ddl_allows_semicolon_injection() {
|
||||
// "CREATE INDEX x ON t(c); DROP TABLE users" — passes validation
|
||||
// Mitigated by sqlx single-statement execution
|
||||
assert!(validate_index_ddl("CREATE INDEX x ON t(c); DROP TABLE users").is_ok());
|
||||
}
|
||||
|
||||
// ── clean_sql_response ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn clean_sql_plain() {
|
||||
assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_fences() {
|
||||
assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_generic_fences() {
|
||||
assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_postgresql_fences() {
|
||||
assert_eq!(
|
||||
clean_sql_response("```postgresql\nSELECT 1\n```"),
|
||||
"SELECT 1"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_whitespace() {
|
||||
assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_no_fences_multiline() {
|
||||
assert_eq!(
|
||||
clean_sql_response("SELECT\n *\nFROM users"),
|
||||
"SELECT\n *\nFROM users"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user