diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs index 99eca49..ad5291c 100644 --- a/src-tauri/src/commands/ai.rs +++ b/src-tauri/src/commands/ai.rs @@ -2,10 +2,10 @@ use crate::commands::data::bind_json_value; use crate::commands::queries::pg_value_to_json; use crate::error::{TuskError, TuskResult}; use crate::models::ai::{ - AiProvider, AiSettings, GenerateDataParams, GeneratedDataPreview, GeneratedTableData, - IndexAdvisorReport, IndexRecommendation, IndexStats, - OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse, - SlowQuery, TableStats, ValidationRule, ValidationStatus, DataGenProgress, + AiProvider, AiSettings, DataGenProgress, GenerateDataParams, GeneratedDataPreview, + GeneratedTableData, IndexAdvisorReport, IndexRecommendation, IndexStats, OllamaChatMessage, + OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse, SlowQuery, TableStats, + ValidationRule, ValidationStatus, }; use crate::state::AppState; use crate::utils::{escape_ident, topological_sort_tables}; @@ -20,12 +20,16 @@ use tauri::{AppHandle, Emitter, Manager, State}; const MAX_RETRIES: u32 = 2; const RETRY_DELAY_MS: u64 = 1000; -fn http_client() -> reqwest::Client { - reqwest::Client::builder() - .connect_timeout(Duration::from_secs(5)) - .timeout(Duration::from_secs(300)) - .build() - .unwrap_or_default() +fn http_client() -> &'static reqwest::Client { + use std::sync::LazyLock; + static CLIENT: LazyLock = LazyLock::new(|| { + reqwest::Client::builder() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(300)) + .build() + .unwrap_or_default() + }); + &CLIENT } fn get_ai_settings_path(app: &AppHandle) -> TuskResult { @@ -65,11 +69,10 @@ pub async fn save_ai_settings( #[tauri::command] pub async fn list_ollama_models(ollama_url: String) -> TuskResult> { let url = format!("{}/api/tags", ollama_url.trim_end_matches('/')); - let resp = http_client() - .get(&url) - .send() - .await - .map_err(|e| TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e)))?; + let resp = + http_client().get(&url).send().await.map_err(|e| { + TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e)) + })?; if !resp.status().is_success() { let status = resp.status(); @@ -119,7 +122,10 @@ where } Err(last_error.unwrap_or_else(|| { - TuskError::Ai(format!("{} failed after {} attempts", operation, MAX_RETRIES)) + TuskError::Ai(format!( + "{} failed after {} attempts", + operation, MAX_RETRIES + )) })) } @@ -164,10 +170,7 @@ async fn call_ollama_chat( } let model = settings.model.clone(); - let url = format!( - "{}/api/chat", - settings.ollama_url.trim_end_matches('/') - ); + let url = format!("{}/api/chat", settings.ollama_url.trim_end_matches('/')); let request = OllamaChatRequest { model: model.clone(), @@ -194,10 +197,7 @@ async fn call_ollama_chat( .send() .await .map_err(|e| { - TuskError::Ai(format!( - "Cannot connect to Ollama at {}: {}", - url, e - )) + TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", url, e)) })?; if !resp.status().is_success() { @@ -379,10 +379,7 @@ pub async fn fix_sql_error( schema_text ); - let user_content = format!( - "SQL query:\n{}\n\nError message:\n{}", - sql, error_message - ); + let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message); let raw = call_ollama_chat(&app, &state, system_prompt, user_content).await?; Ok(clean_sql_response(&raw)) @@ -405,9 +402,15 @@ pub(crate) async fn build_schema_context( // Run all metadata queries in parallel for speed let ( - version_res, col_res, fk_res, enum_res, - tbl_comment_res, col_comment_res, unique_res, - varchar_res, jsonb_res, + version_res, + col_res, + fk_res, + enum_res, + tbl_comment_res, + col_comment_res, + unique_res, + varchar_res, + jsonb_res, ) = tokio::join!( sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool), fetch_columns(&pool), @@ -586,10 +589,9 @@ pub(crate) async fn build_schema_context( // Unique constraints for this table let schema_table: Vec<&str> = full_name.splitn(2, '.').collect(); if schema_table.len() == 2 { - if let Some(uqs) = unique_map.get(&( - schema_table[0].to_string(), - schema_table[1].to_string(), - )) { + if let Some(uqs) = + unique_map.get(&(schema_table[0].to_string(), schema_table[1].to_string())) + { for uq_cols in uqs { output.push(format!(" UNIQUE({})", uq_cols)); } @@ -609,7 +611,9 @@ pub(crate) async fn build_schema_context( let result = output.join("\n"); // Cache the result - state.set_schema_cache(connection_id.to_string(), result.clone()).await; + state + .set_schema_cache(connection_id.to_string(), result.clone()) + .await; Ok(result) } @@ -931,10 +935,7 @@ async fn fetch_jsonb_keys( let query = parts.join(" UNION ALL "); - let rows = match sqlx::query(&query) - .fetch_all(pool) - .await - { + let rows = match sqlx::query(&query).fetch_all(pool).await { Ok(r) => r, Err(e) => { log::warn!("Failed to fetch JSONB keys: {}", e); @@ -1033,6 +1034,26 @@ fn simplify_default(raw: &str) -> String { s.to_string() } +fn validate_select_statement(sql: &str) -> TuskResult<()> { + let sql_upper = sql.trim().to_uppercase(); + if !sql_upper.starts_with("SELECT") { + return Err(TuskError::Custom( + "Validation query must be a SELECT statement".to_string(), + )); + } + Ok(()) +} + +fn validate_index_ddl(ddl: &str) -> TuskResult<()> { + let ddl_upper = ddl.trim().to_uppercase(); + if !ddl_upper.starts_with("CREATE INDEX") && !ddl_upper.starts_with("DROP INDEX") { + return Err(TuskError::Custom( + "Only CREATE INDEX and DROP INDEX statements are allowed".to_string(), + )); + } + Ok(()) +} + fn clean_sql_response(raw: &str) -> String { let trimmed = raw.trim(); // Remove markdown code fences @@ -1098,18 +1119,13 @@ pub async fn run_validation_rule( sql: String, sample_limit: Option, ) -> TuskResult { - let sql_upper = sql.trim().to_uppercase(); - if !sql_upper.starts_with("SELECT") { - return Err(TuskError::Custom( - "Validation query must be a SELECT statement".to_string(), - )); - } + validate_select_statement(&sql)?; let pool = state.get_pool(&connection_id).await?; let limit = sample_limit.unwrap_or(10); let _start = Instant::now(); - let mut tx = (&pool).begin().await.map_err(TuskError::Database)?; + let mut tx = pool.begin().await.map_err(TuskError::Database)?; sqlx::query("SET TRANSACTION READ ONLY") .execute(&mut *tx) .await @@ -1199,7 +1215,13 @@ pub async fn suggest_validation_rules( schema_text ); - let raw = call_ollama_chat(&app, &state, system_prompt, "Suggest validation rules".to_string()).await?; + let raw = call_ollama_chat( + &app, + &state, + system_prompt, + "Suggest validation rules".to_string(), + ) + .await?; let cleaned = raw.trim(); let json_start = cleaned.find('[').unwrap_or(0); @@ -1207,7 +1229,10 @@ pub async fn suggest_validation_rules( let json_str = &cleaned[json_start..json_end]; let rules: Vec = serde_json::from_str(json_str).map_err(|e| { - TuskError::Ai(format!("Failed to parse AI response as JSON array: {}. Response: {}", e, cleaned)) + TuskError::Ai(format!( + "Failed to parse AI response as JSON array: {}. Response: {}", + e, cleaned + )) })?; Ok(rules) @@ -1226,13 +1251,16 @@ pub async fn generate_test_data_preview( ) -> TuskResult { let pool = state.get_pool(¶ms.connection_id).await?; - let _ = app.emit("datagen-progress", DataGenProgress { - gen_id: gen_id.clone(), - stage: "schema".to_string(), - percent: 10, - message: "Building schema context...".to_string(), - detail: None, - }); + let _ = app.emit( + "datagen-progress", + DataGenProgress { + gen_id: gen_id.clone(), + stage: "schema".to_string(), + percent: 10, + message: "Building schema context...".to_string(), + detail: None, + }, + ); let schema_text = build_schema_context(&state, ¶ms.connection_id).await?; @@ -1255,7 +1283,14 @@ pub async fn generate_test_data_preview( let fk_edges: Vec<(String, String, String, String)> = fk_rows .iter() - .map(|fk| (fk.schema.clone(), fk.table.clone(), fk.ref_schema.clone(), fk.ref_table.clone())) + .map(|fk| { + ( + fk.schema.clone(), + fk.table.clone(), + fk.ref_schema.clone(), + fk.ref_table.clone(), + ) + }) .collect(); let sorted_tables = topological_sort_tables(&fk_edges, &target_tables); @@ -1266,13 +1301,16 @@ pub async fn generate_test_data_preview( let row_count = params.row_count.min(1000); - let _ = app.emit("datagen-progress", DataGenProgress { - gen_id: gen_id.clone(), - stage: "generating".to_string(), - percent: 30, - message: "AI is generating test data...".to_string(), - detail: None, - }); + let _ = app.emit( + "datagen-progress", + DataGenProgress { + gen_id: gen_id.clone(), + stage: "generating".to_string(), + percent: 30, + message: "AI is generating test data...".to_string(), + detail: None, + }, + ); let tables_desc: Vec = sorted_tables .iter() @@ -1329,13 +1367,16 @@ pub async fn generate_test_data_preview( ) .await?; - let _ = app.emit("datagen-progress", DataGenProgress { - gen_id: gen_id.clone(), - stage: "parsing".to_string(), - percent: 80, - message: "Parsing generated data...".to_string(), - detail: None, - }); + let _ = app.emit( + "datagen-progress", + DataGenProgress { + gen_id: gen_id.clone(), + stage: "parsing".to_string(), + percent: 80, + message: "Parsing generated data...".to_string(), + detail: None, + }, + ); // Parse JSON response let cleaned = raw.trim(); @@ -1343,9 +1384,13 @@ pub async fn generate_test_data_preview( let json_end = cleaned.rfind('}').map(|i| i + 1).unwrap_or(cleaned.len()); let json_str = &cleaned[json_start..json_end]; - let data_map: HashMap>> = - serde_json::from_str(json_str).map_err(|e| { - TuskError::Ai(format!("Failed to parse generated data: {}. Response: {}", e, &cleaned[..cleaned.len().min(500)])) + let data_map: HashMap>> = serde_json::from_str(json_str) + .map_err(|e| { + TuskError::Ai(format!( + "Failed to parse generated data: {}. Response: {}", + e, + &cleaned[..cleaned.len().min(500)] + )) })?; let mut tables = Vec::new(); @@ -1362,7 +1407,12 @@ pub async fn generate_test_data_preview( let rows: Vec> = rows_data .iter() - .map(|row_map| columns.iter().map(|col| row_map.get(col).cloned().unwrap_or(Value::Null)).collect()) + .map(|row_map| { + columns + .iter() + .map(|col| row_map.get(col).cloned().unwrap_or(Value::Null)) + .collect() + }) .collect(); let count = rows.len() as u32; @@ -1378,13 +1428,20 @@ pub async fn generate_test_data_preview( } } - let _ = app.emit("datagen-progress", DataGenProgress { - gen_id: gen_id.clone(), - stage: "done".to_string(), - percent: 100, - message: "Data generation complete".to_string(), - detail: Some(format!("{} rows across {} tables", total_rows, tables.len())), - }); + let _ = app.emit( + "datagen-progress", + DataGenProgress { + gen_id: gen_id.clone(), + stage: "done".to_string(), + percent: 100, + message: "Data generation complete".to_string(), + detail: Some(format!( + "{} rows across {} tables", + total_rows, + tables.len() + )), + }, + ); Ok(GeneratedDataPreview { tables, @@ -1404,7 +1461,7 @@ pub async fn insert_generated_data( } let pool = state.get_pool(&connection_id).await?; - let mut tx = (&pool).begin().await.map_err(TuskError::Database)?; + let mut tx = pool.begin().await.map_err(TuskError::Database)?; // Defer constraints for circular FKs sqlx::query("SET CONSTRAINTS ALL DEFERRED") @@ -1466,20 +1523,38 @@ pub async fn get_index_advisor_report( ) -> TuskResult { let pool = state.get_pool(&connection_id).await?; - // Fetch table stats - let table_stats_rows = sqlx::query( - "SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \ - pg_size_pretty(pg_total_relation_size(schemaname || '.' || relname)) AS table_size, \ - pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \ - FROM pg_stat_user_tables \ - ORDER BY seq_scan DESC \ - LIMIT 50" - ) - .fetch_all(&pool) - .await - .map_err(TuskError::Database)?; + // Fetch table stats, index stats, and slow queries concurrently + let (table_stats_res, index_stats_res, slow_queries_res) = tokio::join!( + sqlx::query( + "SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \ + pg_size_pretty(pg_total_relation_size(schemaname || '.' || relname)) AS table_size, \ + pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \ + FROM pg_stat_user_tables \ + ORDER BY seq_scan DESC \ + LIMIT 50" + ) + .fetch_all(&pool), + sqlx::query( + "SELECT schemaname, relname, indexrelname, idx_scan, \ + pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \ + pg_get_indexdef(indexrelid) AS definition \ + FROM pg_stat_user_indexes \ + ORDER BY idx_scan ASC \ + LIMIT 50", + ) + .fetch_all(&pool), + sqlx::query( + "SELECT query, calls, total_exec_time, mean_exec_time, rows \ + FROM pg_stat_statements \ + WHERE calls > 0 \ + ORDER BY mean_exec_time DESC \ + LIMIT 20", + ) + .fetch_all(&pool), + ); - let table_stats: Vec = table_stats_rows + let table_stats: Vec = table_stats_res + .map_err(TuskError::Database)? .iter() .map(|r| TableStats { schema: r.get(0), @@ -1492,20 +1567,8 @@ pub async fn get_index_advisor_report( }) .collect(); - // Fetch index stats - let index_stats_rows = sqlx::query( - "SELECT schemaname, relname, indexrelname, idx_scan, \ - pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \ - pg_get_indexdef(indexrelid) AS definition \ - FROM pg_stat_user_indexes \ - ORDER BY idx_scan ASC \ - LIMIT 50" - ) - .fetch_all(&pool) - .await - .map_err(TuskError::Database)?; - - let index_stats: Vec = index_stats_rows + let index_stats: Vec = index_stats_res + .map_err(TuskError::Database)? .iter() .map(|r| IndexStats { schema: r.get(0), @@ -1517,17 +1580,7 @@ pub async fn get_index_advisor_report( }) .collect(); - // Fetch slow queries (graceful if pg_stat_statements not available) - let (slow_queries, has_pg_stat_statements) = match sqlx::query( - "SELECT query, calls, total_exec_time, mean_exec_time, rows \ - FROM pg_stat_statements \ - WHERE calls > 0 \ - ORDER BY mean_exec_time DESC \ - LIMIT 20" - ) - .fetch_all(&pool) - .await - { + let (slow_queries, has_pg_stat_statements) = match slow_queries_res { Ok(rows) => { let queries: Vec = rows .iter() @@ -1551,7 +1604,13 @@ pub async fn get_index_advisor_report( for ts in &table_stats { stats_text.push_str(&format!( " {}.{}: seq_scan={}, idx_scan={}, rows={}, size={}, idx_size={}\n", - ts.schema, ts.table, ts.seq_scan, ts.idx_scan, ts.n_live_tup, ts.table_size, ts.index_size + ts.schema, + ts.table, + ts.seq_scan, + ts.idx_scan, + ts.n_live_tup, + ts.table_size, + ts.index_size )); } @@ -1568,7 +1627,10 @@ pub async fn get_index_advisor_report( for sq in &slow_queries { stats_text.push_str(&format!( " calls={}, mean={:.1}ms, total={:.1}ms, rows={}: {}\n", - sq.calls, sq.mean_time_ms, sq.total_time_ms, sq.rows, + sq.calls, + sq.mean_time_ms, + sq.total_time_ms, + sq.rows, sq.query.chars().take(200).collect::() )); } @@ -1635,12 +1697,7 @@ pub async fn apply_index_recommendation( return Err(TuskError::ReadOnly); } - let ddl_upper = ddl.trim().to_uppercase(); - if !ddl_upper.starts_with("CREATE INDEX") && !ddl_upper.starts_with("DROP INDEX") { - return Err(TuskError::Custom( - "Only CREATE INDEX and DROP INDEX statements are allowed".to_string(), - )); - } + validate_index_ddl(&ddl)?; let pool = state.get_pool(&connection_id).await?; @@ -1655,3 +1712,151 @@ pub async fn apply_index_recommendation( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + // ── validate_select_statement ───────────────────────────── + + #[test] + fn select_valid_simple() { + assert!(validate_select_statement("SELECT 1").is_ok()); + } + + #[test] + fn select_valid_with_leading_whitespace() { + assert!(validate_select_statement(" SELECT * FROM users").is_ok()); + } + + #[test] + fn select_valid_lowercase() { + assert!(validate_select_statement("select * from users").is_ok()); + } + + #[test] + fn select_rejects_insert() { + assert!(validate_select_statement("INSERT INTO users VALUES (1)").is_err()); + } + + #[test] + fn select_rejects_delete() { + assert!(validate_select_statement("DELETE FROM users").is_err()); + } + + #[test] + fn select_rejects_drop() { + assert!(validate_select_statement("DROP TABLE users").is_err()); + } + + #[test] + fn select_rejects_empty() { + assert!(validate_select_statement("").is_err()); + } + + #[test] + fn select_rejects_whitespace_only() { + assert!(validate_select_statement(" ").is_err()); + } + + // NOTE: This test documents a known weakness — SELECT prefix allows injection + #[test] + fn select_allows_semicolon_after_select() { + // "SELECT 1; DROP TABLE users" starts with SELECT — passes validation + // This is a known limitation documented in the review + assert!(validate_select_statement("SELECT 1; DROP TABLE users").is_ok()); + } + + // ── validate_index_ddl ──────────────────────────────────── + + #[test] + fn ddl_valid_create_index() { + assert!(validate_index_ddl("CREATE INDEX idx_name ON users(email)").is_ok()); + } + + #[test] + fn ddl_valid_create_index_concurrently() { + assert!(validate_index_ddl("CREATE INDEX CONCURRENTLY idx ON t(c)").is_ok()); + } + + #[test] + fn ddl_valid_drop_index() { + assert!(validate_index_ddl("DROP INDEX idx_name").is_ok()); + } + + #[test] + fn ddl_valid_with_leading_whitespace() { + assert!(validate_index_ddl(" CREATE INDEX idx ON t(c)").is_ok()); + } + + #[test] + fn ddl_valid_lowercase() { + assert!(validate_index_ddl("create index idx on t(c)").is_ok()); + } + + #[test] + fn ddl_rejects_create_table() { + assert!(validate_index_ddl("CREATE TABLE evil(id int)").is_err()); + } + + #[test] + fn ddl_rejects_drop_table() { + assert!(validate_index_ddl("DROP TABLE users").is_err()); + } + + #[test] + fn ddl_rejects_alter_table() { + assert!(validate_index_ddl("ALTER TABLE users ADD COLUMN x int").is_err()); + } + + #[test] + fn ddl_rejects_empty() { + assert!(validate_index_ddl("").is_err()); + } + + // NOTE: Documents bypass weakness — semicolon after valid prefix + #[test] + fn ddl_allows_semicolon_injection() { + // "CREATE INDEX x ON t(c); DROP TABLE users" — passes validation + // Mitigated by sqlx single-statement execution + assert!(validate_index_ddl("CREATE INDEX x ON t(c); DROP TABLE users").is_ok()); + } + + // ── clean_sql_response ──────────────────────────────────── + + #[test] + fn clean_sql_plain() { + assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1"); + } + + #[test] + fn clean_sql_with_fences() { + assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1"); + } + + #[test] + fn clean_sql_with_generic_fences() { + assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1"); + } + + #[test] + fn clean_sql_with_postgresql_fences() { + assert_eq!( + clean_sql_response("```postgresql\nSELECT 1\n```"), + "SELECT 1" + ); + } + + #[test] + fn clean_sql_with_whitespace() { + assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1"); + } + + #[test] + fn clean_sql_no_fences_multiline() { + assert_eq!( + clean_sql_response("SELECT\n *\nFROM users"), + "SELECT\n *\nFROM users" + ); + } +} diff --git a/src-tauri/src/commands/docker.rs b/src-tauri/src/commands/docker.rs index e39df66..62de158 100644 --- a/src-tauri/src/commands/docker.rs +++ b/src-tauri/src/commands/docker.rs @@ -63,10 +63,7 @@ fn shell_escape(s: &str) -> String { /// Validate pg_version matches a safe pattern (e.g. "16", "16.2", "17.1") fn validate_pg_version(version: &str) -> TuskResult<()> { - let is_valid = !version.is_empty() - && version - .chars() - .all(|c| c.is_ascii_digit() || c == '.'); + let is_valid = !version.is_empty() && version.chars().all(|c| c.is_ascii_digit() || c == '.'); if !is_valid { return Err(docker_err(format!( "Invalid pg_version '{}': must contain only digits and dots (e.g. '16', '16.2')", @@ -116,7 +113,9 @@ pub async fn check_docker(state: State<'_, Arc>) -> TuskResult>) -> TuskResult> { +pub async fn list_tusk_containers( + state: State<'_, Arc>, +) -> TuskResult> { let output = docker_cmd(&state) .await .args([ @@ -234,8 +233,8 @@ async fn check_docker_internal(docker_host: &Option) -> TuskResult p, None => find_free_port().await?, }; - emit_progress(app, clone_id, "port", 10, &format!("Using port {}", host_port), None); + emit_progress( + app, + clone_id, + "port", + 10, + &format!("Using port {}", host_port), + None, + ); // Step 3: Create container - emit_progress(app, clone_id, "container", 20, "Creating PostgreSQL container...", None); + emit_progress( + app, + clone_id, + "container", + 20, + "Creating PostgreSQL container...", + None, + ); let pg_password = params.postgres_password.as_deref().unwrap_or("tusk"); let image = format!("postgres:{}", params.pg_version); let create_output = docker_cmd_sync(&docker_host) .args([ - "run", "-d", - "--name", ¶ms.container_name, - "-p", &format!("{}:5432", host_port), - "-e", &format!("POSTGRES_PASSWORD={}", pg_password), - "-l", "tusk.managed=true", - "-l", &format!("tusk.source-db={}", params.source_database), - "-l", &format!("tusk.source-connection={}", params.source_connection_id), - "-l", &format!("tusk.pg-version={}", params.pg_version), + "run", + "-d", + "--name", + ¶ms.container_name, + "-p", + &format!("{}:5432", host_port), + "-e", + &format!("POSTGRES_PASSWORD={}", pg_password), + "-l", + "tusk.managed=true", + "-l", + &format!("tusk.source-db={}", params.source_database), + "-l", + &format!("tusk.source-connection={}", params.source_connection_id), + "-l", + &format!("tusk.pg-version={}", params.pg_version), &image, ]) .output() @@ -306,24 +334,56 @@ async fn do_clone( .map_err(|e| docker_err(format!("Failed to create container: {}", e)))?; if !create_output.status.success() { - let stderr = String::from_utf8_lossy(&create_output.stderr).trim().to_string(); - emit_progress(app, clone_id, "error", 20, &format!("Failed to create container: {}", stderr), None); - return Err(docker_err(format!("Failed to create container: {}", stderr))); + let stderr = String::from_utf8_lossy(&create_output.stderr) + .trim() + .to_string(); + emit_progress( + app, + clone_id, + "error", + 20, + &format!("Failed to create container: {}", stderr), + None, + ); + return Err(docker_err(format!( + "Failed to create container: {}", + stderr + ))); } - let container_id = String::from_utf8_lossy(&create_output.stdout).trim().to_string(); + let container_id = String::from_utf8_lossy(&create_output.stdout) + .trim() + .to_string(); // Step 4: Wait for PostgreSQL to be ready - emit_progress(app, clone_id, "waiting", 30, "Waiting for PostgreSQL to be ready...", None); + emit_progress( + app, + clone_id, + "waiting", + 30, + "Waiting for PostgreSQL to be ready...", + None, + ); wait_for_pg_ready(&docker_host, ¶ms.container_name, 30).await?; emit_progress(app, clone_id, "waiting", 35, "PostgreSQL is ready", None); // Step 5: Create target database - emit_progress(app, clone_id, "database", 35, &format!("Creating database '{}'...", params.source_database), None); + emit_progress( + app, + clone_id, + "database", + 35, + &format!("Creating database '{}'...", params.source_database), + None, + ); let create_db_output = docker_cmd_sync(&docker_host) .args([ - "exec", ¶ms.container_name, - "psql", "-U", "postgres", "-c", + "exec", + ¶ms.container_name, + "psql", + "-U", + "postgres", + "-c", &format!("CREATE DATABASE {}", escape_ident(¶ms.source_database)), ]) .output() @@ -331,20 +391,42 @@ async fn do_clone( .map_err(|e| docker_err(format!("Failed to create database: {}", e)))?; if !create_db_output.status.success() { - let stderr = String::from_utf8_lossy(&create_db_output.stderr).trim().to_string(); + let stderr = String::from_utf8_lossy(&create_db_output.stderr) + .trim() + .to_string(); if !stderr.contains("already exists") { - emit_progress(app, clone_id, "error", 35, &format!("Failed to create database: {}", stderr), None); + emit_progress( + app, + clone_id, + "error", + 35, + &format!("Failed to create database: {}", stderr), + None, + ); return Err(docker_err(format!("Failed to create database: {}", stderr))); } } // Step 6: Get source connection URL (using the specific database to clone) - emit_progress(app, clone_id, "dump", 40, "Preparing data transfer...", None); + emit_progress( + app, + clone_id, + "dump", + 40, + "Preparing data transfer...", + None, + ); let source_config = load_connection_config(app, ¶ms.source_connection_id)?; let source_url = source_config.connection_url_for_db(¶ms.source_database); emit_progress( - app, clone_id, "dump", 40, - &format!("Source: {}@{}:{}/{}", source_config.user, source_config.host, source_config.port, params.source_database), + app, + clone_id, + "dump", + 40, + &format!( + "Source: {}@{}:{}/{}", + source_config.user, source_config.host, source_config.port, params.source_database + ), None, ); @@ -352,23 +434,84 @@ async fn do_clone( match params.clone_mode { CloneMode::SchemaOnly => { emit_progress(app, clone_id, "transfer", 45, "Dumping schema...", None); - transfer_schema_only(app, clone_id, &source_url, ¶ms.container_name, ¶ms.source_database, ¶ms.pg_version, &docker_host).await?; + transfer_schema_only( + app, + clone_id, + &source_url, + ¶ms.container_name, + ¶ms.source_database, + ¶ms.pg_version, + &docker_host, + ) + .await?; } CloneMode::FullClone => { - emit_progress(app, clone_id, "transfer", 45, "Performing full database clone...", None); - transfer_full_clone(app, clone_id, &source_url, ¶ms.container_name, ¶ms.source_database, ¶ms.pg_version, &docker_host).await?; + emit_progress( + app, + clone_id, + "transfer", + 45, + "Performing full database clone...", + None, + ); + transfer_full_clone( + app, + clone_id, + &source_url, + ¶ms.container_name, + ¶ms.source_database, + ¶ms.pg_version, + &docker_host, + ) + .await?; } CloneMode::SampleData => { + let has_local = try_local_pg_dump().await; emit_progress(app, clone_id, "transfer", 45, "Dumping schema...", None); - transfer_schema_only(app, clone_id, &source_url, ¶ms.container_name, ¶ms.source_database, ¶ms.pg_version, &docker_host).await?; - emit_progress(app, clone_id, "transfer", 65, "Copying sample data...", None); + transfer_schema_only_with( + app, + clone_id, + &source_url, + ¶ms.container_name, + ¶ms.source_database, + ¶ms.pg_version, + &docker_host, + has_local, + ) + .await?; + emit_progress( + app, + clone_id, + "transfer", + 65, + "Copying sample data...", + None, + ); let sample_rows = params.sample_rows.unwrap_or(1000); - transfer_sample_data(app, clone_id, &source_url, ¶ms.container_name, ¶ms.source_database, ¶ms.pg_version, sample_rows, &docker_host).await?; + transfer_sample_data_with( + app, + clone_id, + &source_url, + ¶ms.container_name, + ¶ms.source_database, + ¶ms.pg_version, + sample_rows, + &docker_host, + has_local, + ) + .await?; } } // Step 8: Save connection in Tusk - emit_progress(app, clone_id, "connection", 90, "Saving connection...", None); + emit_progress( + app, + clone_id, + "connection", + 90, + "Saving connection...", + None, + ); let connection_id = uuid::Uuid::new_v4().to_string(); let new_config = ConnectionConfig { id: connection_id.clone(), @@ -407,7 +550,14 @@ async fn do_clone( connection_url, }; - emit_progress(app, clone_id, "done", 100, "Clone completed successfully!", None); + emit_progress( + app, + clone_id, + "done", + 100, + "Clone completed successfully!", + None, + ); Ok(result) } @@ -424,7 +574,11 @@ async fn find_free_port() -> TuskResult { Ok(port) } -async fn wait_for_pg_ready(docker_host: &Option, container_name: &str, timeout_secs: u64) -> TuskResult<()> { +async fn wait_for_pg_ready( + docker_host: &Option, + container_name: &str, + timeout_secs: u64, +) -> TuskResult<()> { let start = std::time::Instant::now(); let timeout = std::time::Duration::from_secs(timeout_secs); @@ -466,7 +620,13 @@ fn docker_host_flag(docker_host: &Option) -> String { } /// Build the pg_dump portion of a shell command -fn pg_dump_shell_cmd(has_local: bool, pg_version: &str, extra_args: &str, source_url: &str, docker_host: &Option) -> String { +fn pg_dump_shell_cmd( + has_local: bool, + pg_version: &str, + extra_args: &str, + source_url: &str, + docker_host: &Option, +) -> String { let escaped_url = shell_escape(source_url); if has_local { format!("pg_dump {} '{}'", extra_args, escaped_url) @@ -503,7 +663,8 @@ async fn run_pipe_cmd( if !stderr.is_empty() { // Truncate for progress display (full log can be long) let short = if stderr.len() > 500 { - let truncated = stderr.char_indices() + let truncated = stderr + .char_indices() .nth(500) .map(|(i, _)| &stderr[..i]) .unwrap_or(&stderr); @@ -511,33 +672,57 @@ async fn run_pipe_cmd( } else { stderr.clone() }; - emit_progress(app, clone_id, "transfer", 55, &format!("{}: stderr output", label), Some(&short)); + emit_progress( + app, + clone_id, + "transfer", + 55, + &format!("{}: stderr output", label), + Some(&short), + ); } // Count DDL statements in stdout for feedback if !stdout.is_empty() { - let creates = stdout.lines() + let creates = stdout + .lines() .filter(|l| l.starts_with("CREATE") || l.starts_with("ALTER") || l.starts_with("SET")) .count(); if creates > 0 { - emit_progress(app, clone_id, "transfer", 58, &format!("Applied {} SQL statements", creates), None); + emit_progress( + app, + clone_id, + "transfer", + 58, + &format!("Applied {} SQL statements", creates), + None, + ); } } if !output.status.success() { let code = output.status.code().unwrap_or(-1); emit_progress( - app, clone_id, "transfer", 55, + app, + clone_id, + "transfer", + 55, &format!("{} exited with code {}", label, code), Some(&stderr), ); // Only hard-fail on connection / fatal errors - if stderr.contains("FATAL") || stderr.contains("could not connect") - || stderr.contains("No such file") || stderr.contains("password authentication failed") - || stderr.contains("does not exist") || (stdout.is_empty() && stderr.is_empty()) + if stderr.contains("FATAL") + || stderr.contains("could not connect") + || stderr.contains("No such file") + || stderr.contains("password authentication failed") + || stderr.contains("does not exist") + || (stdout.is_empty() && stderr.is_empty()) { - return Err(docker_err(format!("{} failed (exit {}): {}", label, code, stderr))); + return Err(docker_err(format!( + "{} failed (exit {}): {}", + label, code, stderr + ))); } } @@ -554,20 +739,61 @@ async fn transfer_schema_only( docker_host: &Option, ) -> TuskResult<()> { let has_local = try_local_pg_dump().await; - let label = if has_local { "local pg_dump" } else { "Docker-based pg_dump" }; - emit_progress(app, clone_id, "transfer", 48, &format!("Using {} for schema...", label), None); + transfer_schema_only_with(app, clone_id, source_url, container_name, database, pg_version, docker_host, has_local).await +} - let dump_cmd = pg_dump_shell_cmd(has_local, pg_version, "--schema-only --no-owner --no-acl", source_url, docker_host); +#[allow(clippy::too_many_arguments)] +async fn transfer_schema_only_with( + app: &AppHandle, + clone_id: &str, + source_url: &str, + container_name: &str, + database: &str, + pg_version: &str, + docker_host: &Option, + has_local: bool, +) -> TuskResult<()> { + let label = if has_local { + "local pg_dump" + } else { + "Docker-based pg_dump" + }; + emit_progress( + app, + clone_id, + "transfer", + 48, + &format!("Using {} for schema...", label), + None, + ); + + let dump_cmd = pg_dump_shell_cmd( + has_local, + pg_version, + "--schema-only --no-owner --no-acl", + source_url, + docker_host, + ); let escaped_db = shell_escape(database); let host_flag = docker_host_flag(docker_host); let pipe_cmd = format!( "{} | docker {} exec -i '{}' psql -U postgres -d '{}'", - dump_cmd, host_flag, shell_escape(container_name), escaped_db + dump_cmd, + host_flag, + shell_escape(container_name), + escaped_db ); run_pipe_cmd(app, clone_id, &pipe_cmd, "Schema transfer").await?; - emit_progress(app, clone_id, "transfer", 60, "Schema transferred successfully", None); + emit_progress( + app, + clone_id, + "transfer", + 60, + "Schema transferred successfully", + None, + ); Ok(()) } @@ -581,16 +807,36 @@ async fn transfer_full_clone( docker_host: &Option, ) -> TuskResult<()> { let has_local = try_local_pg_dump().await; - let label = if has_local { "local pg_dump" } else { "Docker-based pg_dump" }; - emit_progress(app, clone_id, "transfer", 48, &format!("Using {} for full clone...", label), None); + let label = if has_local { + "local pg_dump" + } else { + "Docker-based pg_dump" + }; + emit_progress( + app, + clone_id, + "transfer", + 48, + &format!("Using {} for full clone...", label), + None, + ); // Use plain text format piped to psql (more reliable than -Fc | pg_restore through docker exec) - let dump_cmd = pg_dump_shell_cmd(has_local, pg_version, "--no-owner --no-acl", source_url, docker_host); + let dump_cmd = pg_dump_shell_cmd( + has_local, + pg_version, + "--no-owner --no-acl", + source_url, + docker_host, + ); let escaped_db = shell_escape(database); let host_flag = docker_host_flag(docker_host); let pipe_cmd = format!( "{} | docker {} exec -i '{}' psql -U postgres -d '{}'", - dump_cmd, host_flag, shell_escape(container_name), escaped_db + dump_cmd, + host_flag, + shell_escape(container_name), + escaped_db ); run_pipe_cmd(app, clone_id, &pipe_cmd, "Full clone").await?; @@ -599,7 +845,8 @@ async fn transfer_full_clone( Ok(()) } -async fn transfer_sample_data( +#[allow(clippy::too_many_arguments)] +async fn transfer_sample_data_with( app: &AppHandle, clone_id: &str, source_url: &str, @@ -608,6 +855,7 @@ async fn transfer_sample_data( pg_version: &str, sample_rows: u32, docker_host: &Option, + has_local: bool, ) -> TuskResult<()> { // List tables from the target (schema already transferred) let target_output = docker_cmd_sync(docker_host) @@ -622,21 +870,37 @@ async fn transfer_sample_data( .map_err(|e| docker_err(format!("Failed to list tables: {}", e)))?; let tables_str = String::from_utf8_lossy(&target_output.stdout); - let tables: Vec<&str> = tables_str.lines().filter(|l| !l.trim().is_empty()).collect(); + let tables: Vec<&str> = tables_str + .lines() + .filter(|l| !l.trim().is_empty()) + .collect(); let total = tables.len(); if total == 0 { - emit_progress(app, clone_id, "transfer", 85, "No tables to copy data for", None); + emit_progress( + app, + clone_id, + "transfer", + 85, + "No tables to copy data for", + None, + ); return Ok(()); } - let has_local = try_local_pg_dump().await; - for (i, qualified_table) in tables.iter().enumerate() { let pct = 65 + ((i * 20) / total.max(1)).min(20) as u8; emit_progress( - app, clone_id, "transfer", pct, - &format!("Copying sample data: {} ({}/{})", qualified_table, i + 1, total), + app, + clone_id, + "transfer", + pct, + &format!( + "Copying sample data: {} ({}/{})", + qualified_table, + i + 1, + total + ), None, ); @@ -680,17 +944,17 @@ async fn transfer_sample_data( source_cmd, host_flag, escaped_container, escaped_db, copy_in_sql ); - let output = Command::new("bash") - .args(["-c", &pipe_cmd]) - .output() - .await; + let output = Command::new("bash").args(["-c", &pipe_cmd]).output().await; match output { Ok(out) => { let stderr = String::from_utf8_lossy(&out.stderr).trim().to_string(); if !stderr.is_empty() && (stderr.contains("ERROR") || stderr.contains("FATAL")) { emit_progress( - app, clone_id, "transfer", pct, + app, + clone_id, + "transfer", + pct, &format!("Warning: {}", qualified_table), Some(&stderr), ); @@ -698,7 +962,10 @@ async fn transfer_sample_data( } Err(e) => { emit_progress( - app, clone_id, "transfer", pct, + app, + clone_id, + "transfer", + pct, &format!("Warning: failed to copy {}: {}", qualified_table, e), None, ); @@ -706,7 +973,14 @@ async fn transfer_sample_data( } } - emit_progress(app, clone_id, "transfer", 85, "Sample data transfer completed", None); + emit_progress( + app, + clone_id, + "transfer", + 85, + "Sample data transfer completed", + None, + ); Ok(()) } @@ -776,8 +1050,159 @@ pub async fn remove_container(state: State<'_, Arc>, name: String) -> if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); - return Err(docker_err(format!("Failed to remove container: {}", stderr))); + return Err(docker_err(format!( + "Failed to remove container: {}", + stderr + ))); } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + // ── validate_container_name ─────────────────────────────── + + #[test] + fn container_name_valid_simple() { + assert!(validate_container_name("mycontainer").is_ok()); + } + + #[test] + fn container_name_valid_with_dots_dashes_underscores() { + assert!(validate_container_name("my-container_v1.2").is_ok()); + } + + #[test] + fn container_name_valid_starts_with_digit() { + assert!(validate_container_name("1container").is_ok()); + } + + #[test] + fn container_name_empty() { + assert!(validate_container_name("").is_err()); + } + + #[test] + fn container_name_starts_with_dash() { + assert!(validate_container_name("-bad").is_err()); + } + + #[test] + fn container_name_starts_with_dot() { + assert!(validate_container_name(".bad").is_err()); + } + + #[test] + fn container_name_starts_with_underscore() { + assert!(validate_container_name("_bad").is_err()); + } + + #[test] + fn container_name_with_spaces() { + assert!(validate_container_name("bad name").is_err()); + } + + #[test] + fn container_name_with_unicode() { + assert!(validate_container_name("контейнер").is_err()); + } + + #[test] + fn container_name_with_special_chars() { + assert!(validate_container_name("bad;name").is_err()); + assert!(validate_container_name("bad/name").is_err()); + assert!(validate_container_name("bad:name").is_err()); + assert!(validate_container_name("bad@name").is_err()); + } + + #[test] + fn container_name_with_shell_injection() { + assert!(validate_container_name("x; rm -rf /").is_err()); + assert!(validate_container_name("x$(whoami)").is_err()); + } + + // ── validate_pg_version ─────────────────────────────────── + + #[test] + fn pg_version_valid_major() { + assert!(validate_pg_version("16").is_ok()); + } + + #[test] + fn pg_version_valid_major_minor() { + assert!(validate_pg_version("16.2").is_ok()); + } + + #[test] + fn pg_version_valid_three_parts() { + assert!(validate_pg_version("17.1.0").is_ok()); + } + + #[test] + fn pg_version_empty() { + assert!(validate_pg_version("").is_err()); + } + + #[test] + fn pg_version_with_letters() { + assert!(validate_pg_version("16beta1").is_err()); + } + + #[test] + fn pg_version_with_injection() { + assert!(validate_pg_version("16; rm -rf").is_err()); + } + + #[test] + fn pg_version_only_dots() { + // Current impl allows dots-only — this documents the behavior + assert!(validate_pg_version("...").is_ok()); + } + + // ── shell_escape ────────────────────────────────────────── + + #[test] + fn shell_escape_no_quotes() { + assert_eq!(shell_escape("hello"), "hello"); + } + + #[test] + fn shell_escape_with_single_quote() { + assert_eq!(shell_escape("it's"), "it'\\''s"); + } + + #[test] + fn shell_escape_multiple_quotes() { + assert_eq!(shell_escape("a'b'c"), "a'\\''b'\\''c"); + } + + // ── shell_escape_double ─────────────────────────────────── + + #[test] + fn shell_escape_double_no_special() { + assert_eq!(shell_escape_double("hello"), "hello"); + } + + #[test] + fn shell_escape_double_with_backslash() { + assert_eq!(shell_escape_double(r"a\b"), r"a\\b"); + } + + #[test] + fn shell_escape_double_with_dollar() { + assert_eq!(shell_escape_double("$HOME"), "\\$HOME"); + } + + #[test] + fn shell_escape_double_with_backtick() { + assert_eq!(shell_escape_double("`whoami`"), "\\`whoami\\`"); + } + + #[test] + fn shell_escape_double_with_double_quote() { + assert_eq!(shell_escape_double(r#"say "hi""#), r#"say \"hi\""#); + } +} diff --git a/src-tauri/src/commands/snapshot.rs b/src-tauri/src/commands/snapshot.rs index 8ac093e..fa57d5c 100644 --- a/src-tauri/src/commands/snapshot.rs +++ b/src-tauri/src/commands/snapshot.rs @@ -46,12 +46,10 @@ pub async fn create_snapshot( if params.include_dependencies { for fk in &fk_rows { - for (schema, table) in ¶ms.tables.iter().map(|t| (t.schema.clone(), t.table.clone())).collect::>() { - if &fk.schema == schema && &fk.table == table { - let parent = (fk.ref_schema.clone(), fk.ref_table.clone()); - if !target_tables.contains(&parent) { - target_tables.push(parent); - } + if target_tables.iter().any(|(s, t)| s == &fk.schema && t == &fk.table) { + let parent = (fk.ref_schema.clone(), fk.ref_table.clone()); + if !target_tables.contains(&parent) { + target_tables.push(parent); } } } @@ -60,11 +58,18 @@ pub async fn create_snapshot( // FK-based topological sort let fk_edges: Vec<(String, String, String, String)> = fk_rows .iter() - .map(|fk| (fk.schema.clone(), fk.table.clone(), fk.ref_schema.clone(), fk.ref_table.clone())) + .map(|fk| { + ( + fk.schema.clone(), + fk.table.clone(), + fk.ref_schema.clone(), + fk.ref_table.clone(), + ) + }) .collect(); let sorted_tables = topological_sort_tables(&fk_edges, &target_tables); - let mut tx = (&pool).begin().await.map_err(TuskError::Database)?; + let mut tx = pool.begin().await.map_err(TuskError::Database)?; sqlx::query("SET TRANSACTION READ ONLY") .execute(&mut *tx) .await @@ -107,7 +112,11 @@ pub async fn create_snapshot( let data_rows: Vec> = rows .iter() - .map(|row| (0..columns.len()).map(|i| pg_value_to_json(row, i)).collect()) + .map(|row| { + (0..columns.len()) + .map(|i| pg_value_to_json(row, i)) + .collect() + }) .collect(); let row_count = data_rows.len() as u64; @@ -207,7 +216,7 @@ pub async fn restore_snapshot( let snapshot: Snapshot = serde_json::from_str(&data)?; let pool = state.get_pool(¶ms.connection_id).await?; - let mut tx = (&pool).begin().await.map_err(TuskError::Database)?; + let mut tx = pool.begin().await.map_err(TuskError::Database)?; sqlx::query("SET CONSTRAINTS ALL DEFERRED") .execute(&mut *tx) diff --git a/src-tauri/src/models/ai.rs b/src-tauri/src/models/ai.rs index b26d3a0..00a5fba 100644 --- a/src-tauri/src/models/ai.rs +++ b/src-tauri/src/models/ai.rs @@ -83,16 +83,6 @@ pub struct ValidationRule { pub error: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ValidationReport { - pub rules: Vec, - pub total_rules: usize, - pub passed: usize, - pub failed: usize, - pub errors: usize, - pub execution_time_ms: u128, -} - // --- Wave 2: Data Generator --- #[derive(Debug, Clone, Serialize, Deserialize)] @@ -165,9 +155,12 @@ pub struct SlowQuery { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum IndexRecommendationType { - CreateIndex, - DropIndex, - ReplaceIndex, + #[serde(rename = "create_index")] + Create, + #[serde(rename = "drop_index")] + Drop, + #[serde(rename = "replace_index")] + Replace, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index d86ae5e..931f914 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -3,7 +3,6 @@ use crate::models::ai::AiSettings; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; -use std::path::PathBuf; use std::time::{Duration, Instant}; use tokio::sync::{watch, RwLock}; @@ -22,7 +21,6 @@ pub struct SchemaCacheEntry { pub struct AppState { pub pools: RwLock>, - pub config_path: RwLock>, pub read_only: RwLock>, pub db_flavors: RwLock>, pub schema_cache: RwLock>, @@ -39,7 +37,6 @@ impl AppState { let (mcp_shutdown_tx, _) = watch::channel(false); Self { pools: RwLock::new(HashMap::new()), - config_path: RwLock::new(None), read_only: RwLock::new(HashMap::new()), db_flavors: RwLock::new(HashMap::new()), schema_cache: RwLock::new(HashMap::new()), @@ -81,6 +78,8 @@ impl AppState { pub async fn set_schema_cache(&self, connection_id: String, schema_text: String) { let mut cache = self.schema_cache.write().await; + // Evict stale entries to prevent unbounded memory growth + cache.retain(|_, entry| entry.cached_at.elapsed() < SCHEMA_CACHE_TTL); cache.insert( connection_id, SchemaCacheEntry {