diff --git a/package-lock.json b/package-lock.json index f46c36a..780171e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -15,13 +15,10 @@ "@tauri-apps/api": "^2.10.1", "@tauri-apps/plugin-dialog": "^2.6.0", "@tauri-apps/plugin-shell": "^2.3.5", - "@types/dagre": "^0.7.54", "@uiw/react-codemirror": "^4.25.4", - "@xyflow/react": "^12.10.2", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "^1.1.1", - "dagre": "^0.8.5", "lucide-react": "^0.563.0", "next-themes": "^0.4.6", "radix-ui": "^1.4.3", @@ -4810,61 +4807,6 @@ "assertion-error": "^2.0.1" } }, - "node_modules/@types/d3-color": { - "version": "3.1.3", - "resolved": "https://registry.npmjs.org/@types/d3-color/-/d3-color-3.1.3.tgz", - "integrity": "sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A==", - "license": "MIT" - }, - "node_modules/@types/d3-drag": { - "version": "3.0.7", - "resolved": "https://registry.npmjs.org/@types/d3-drag/-/d3-drag-3.0.7.tgz", - "integrity": "sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==", - "license": "MIT", - "dependencies": { - "@types/d3-selection": "*" - } - }, - "node_modules/@types/d3-interpolate": { - "version": "3.0.4", - "resolved": "https://registry.npmjs.org/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz", - "integrity": "sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==", - "license": "MIT", - "dependencies": { - "@types/d3-color": "*" - } - }, - "node_modules/@types/d3-selection": { - "version": "3.0.11", - "resolved": "https://registry.npmjs.org/@types/d3-selection/-/d3-selection-3.0.11.tgz", - "integrity": "sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==", - "license": "MIT" - }, - "node_modules/@types/d3-transition": { - "version": "3.0.9", - "resolved": "https://registry.npmjs.org/@types/d3-transition/-/d3-transition-3.0.9.tgz", - "integrity": "sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==", - "license": "MIT", - "dependencies": { - "@types/d3-selection": "*" - } - }, - "node_modules/@types/d3-zoom": { - "version": "3.0.8", - "resolved": "https://registry.npmjs.org/@types/d3-zoom/-/d3-zoom-3.0.8.tgz", - "integrity": "sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==", - "license": "MIT", - "dependencies": { - "@types/d3-interpolate": "*", - "@types/d3-selection": "*" - } - }, - "node_modules/@types/dagre": { - "version": "0.7.54", - "resolved": "https://registry.npmjs.org/@types/dagre/-/dagre-0.7.54.tgz", - "integrity": "sha512-QjcRY+adGbYvBFS7cwv5txhVIwX1XXIUswWl+kSQTbI6NjgZydrZkEKX/etzVd7i+bCsCb40Z/xlBY5eoFuvWQ==", - "license": "MIT" - }, "node_modules/@types/deep-eql": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/@types/deep-eql/-/deep-eql-4.0.2.tgz", @@ -5384,66 +5326,6 @@ "url": "https://opencollective.com/vitest" } }, - "node_modules/@xyflow/react": { - "version": "12.10.2", - "resolved": "https://registry.npmjs.org/@xyflow/react/-/react-12.10.2.tgz", - "integrity": "sha512-CgIi6HwlcHXwlkTpr0fxLv/0sRVNZ8IdwKLzzeCscaYBwpvfcH1QFOCeaTCuEn1FQEs/B8CjnTSjhs8udgmBgQ==", - "license": "MIT", - "dependencies": { - "@xyflow/system": "0.0.76", - "classcat": "^5.0.3", - "zustand": "^4.4.0" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - "node_modules/@xyflow/react/node_modules/zustand": { - "version": "4.5.7", - "resolved": "https://registry.npmjs.org/zustand/-/zustand-4.5.7.tgz", - "integrity": "sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==", - "license": "MIT", - "dependencies": { - "use-sync-external-store": "^1.2.2" - }, - "engines": { - "node": ">=12.7.0" - }, - "peerDependencies": { - "@types/react": ">=16.8", - "immer": ">=9.0.6", - "react": ">=16.8" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "immer": { - "optional": true - }, - "react": { - "optional": true - } - } - }, - "node_modules/@xyflow/system": { - "version": "0.0.76", - "resolved": "https://registry.npmjs.org/@xyflow/system/-/system-0.0.76.tgz", - "integrity": "sha512-hvwvnRS1B3REwVDlWexsq7YQaPZeG3/mKo1jv38UmnpWmxihp14bW6VtEOuHEwJX2FvzFw8k77LyKSk/wiZVNA==", - "license": "MIT", - "dependencies": { - "@types/d3-drag": "^3.0.7", - "@types/d3-interpolate": "^3.0.4", - "@types/d3-selection": "^3.0.10", - "@types/d3-transition": "^3.0.8", - "@types/d3-zoom": "^3.0.8", - "d3-drag": "^3.0.0", - "d3-interpolate": "^3.0.1", - "d3-selection": "^3.0.0", - "d3-zoom": "^3.0.0" - } - }, "node_modules/accepts": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/accepts/-/accepts-2.0.0.tgz", @@ -5877,12 +5759,6 @@ "url": "https://polar.sh/cva" } }, - "node_modules/classcat": { - "version": "5.0.5", - "resolved": "https://registry.npmjs.org/classcat/-/classcat-5.0.5.tgz", - "integrity": "sha512-JhZUT7JFcQy/EzW605k/ktHtncoo9vnyW/2GspNYwFlN1C/WmjuV/xtS04e9SOkL2sTdw0VAZ2UGCcQ9lR6p6w==", - "license": "MIT" - }, "node_modules/cli-cursor": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/cli-cursor/-/cli-cursor-5.0.0.tgz", @@ -6268,121 +6144,6 @@ "devOptional": true, "license": "MIT" }, - "node_modules/d3-color": { - "version": "3.1.0", - "resolved": "https://registry.npmjs.org/d3-color/-/d3-color-3.1.0.tgz", - "integrity": "sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-dispatch": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/d3-dispatch/-/d3-dispatch-3.0.1.tgz", - "integrity": "sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-drag": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/d3-drag/-/d3-drag-3.0.0.tgz", - "integrity": "sha512-pWbUJLdETVA8lQNJecMxoXfH6x+mO2UQo8rSmZ+QqxcbyA3hfeprFgIT//HW2nlHChWeIIMwS2Fq+gEARkhTkg==", - "license": "ISC", - "dependencies": { - "d3-dispatch": "1 - 3", - "d3-selection": "3" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-ease": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/d3-ease/-/d3-ease-3.0.1.tgz", - "integrity": "sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==", - "license": "BSD-3-Clause", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-interpolate": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/d3-interpolate/-/d3-interpolate-3.0.1.tgz", - "integrity": "sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==", - "license": "ISC", - "dependencies": { - "d3-color": "1 - 3" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-selection": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", - "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-timer": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/d3-timer/-/d3-timer-3.0.1.tgz", - "integrity": "sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-transition": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/d3-transition/-/d3-transition-3.0.1.tgz", - "integrity": "sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==", - "license": "ISC", - "dependencies": { - "d3-color": "1 - 3", - "d3-dispatch": "1 - 3", - "d3-ease": "1 - 3", - "d3-interpolate": "1 - 3", - "d3-timer": "1 - 3" - }, - "engines": { - "node": ">=12" - }, - "peerDependencies": { - "d3-selection": "2 - 3" - } - }, - "node_modules/d3-zoom": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/d3-zoom/-/d3-zoom-3.0.0.tgz", - "integrity": "sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==", - "license": "ISC", - "dependencies": { - "d3-dispatch": "1 - 3", - "d3-drag": "2 - 3", - "d3-interpolate": "1 - 3", - "d3-selection": "2 - 3", - "d3-transition": "2 - 3" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/dagre": { - "version": "0.8.5", - "resolved": "https://registry.npmjs.org/dagre/-/dagre-0.8.5.tgz", - "integrity": "sha512-/aTqmnRta7x7MCCpExk7HQL2O4owCT2h8NT//9I1OQ9vt29Pa0BzSAkR5lwFUcQ7491yVi/3CXU9jQ5o0Mn2Sw==", - "license": "MIT", - "dependencies": { - "graphlib": "^2.1.8", - "lodash": "^4.17.15" - } - }, "node_modules/data-uri-to-buffer": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/data-uri-to-buffer/-/data-uri-to-buffer-4.0.1.tgz", @@ -7610,15 +7371,6 @@ "dev": true, "license": "ISC" }, - "node_modules/graphlib": { - "version": "2.1.8", - "resolved": "https://registry.npmjs.org/graphlib/-/graphlib-2.1.8.tgz", - "integrity": "sha512-jcLLfkpoVGmH7/InMC/1hIvOPSUh38oJtGhvrOFGzioE1DZ+0YW16RgmOJhHiuWTvGiJQ9Z1Ik43JvkRPRvE+A==", - "license": "MIT", - "dependencies": { - "lodash": "^4.17.15" - } - }, "node_modules/graphql": { "version": "16.12.0", "resolved": "https://registry.npmjs.org/graphql/-/graphql-16.12.0.tgz", @@ -8555,12 +8307,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/lodash": { - "version": "4.17.23", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", - "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", - "license": "MIT" - }, "node_modules/lodash.merge": { "version": "4.6.2", "resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz", diff --git a/package.json b/package.json index 3e2eb0a..740472e 100644 --- a/package.json +++ b/package.json @@ -20,13 +20,10 @@ "@tauri-apps/api": "^2.10.1", "@tauri-apps/plugin-dialog": "^2.6.0", "@tauri-apps/plugin-shell": "^2.3.5", - "@types/dagre": "^0.7.54", "@uiw/react-codemirror": "^4.25.4", - "@xyflow/react": "^12.10.2", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "^1.1.1", - "dagre": "^0.8.5", "lucide-react": "^0.563.0", "next-themes": "^0.4.6", "radix-ui": "^1.4.3", diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index fb29ad0..88ba80e 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -4705,7 +4705,6 @@ dependencies = [ "libc", "mio", "pin-project-lite", - "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.61.2", diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 682af1a..1c93d5e 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -21,7 +21,7 @@ tauri-plugin-shell = "2" tauri-plugin-dialog = "2" serde = { version = "1", features = ["derive"] } serde_json = "1" -tokio = { version = "1", features = ["rt-multi-thread", "sync", "time", "net", "macros", "process", "io-util"] } +tokio = { version = "1", features = ["rt-multi-thread", "sync", "time", "net", "macros"] } sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "postgres", "json", "chrono", "uuid", "bigdecimal"] } chrono = { version = "0.4", features = ["serde"] } uuid = { version = "1", features = ["serde"] } diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs index 00ff050..99207d9 100644 --- a/src-tauri/src/commands/ai.rs +++ b/src-tauri/src/commands/ai.rs @@ -1,21 +1,15 @@ -use crate::commands::data::bind_json_value; -use crate::commands::queries::pg_value_to_json; use crate::error::{TuskError, TuskResult}; use crate::models::ai::{ - AiProvider, AiSettings, DataGenProgress, GenerateDataParams, GeneratedDataPreview, - GeneratedTableData, IndexAdvisorReport, IndexRecommendation, IndexStats, OllamaChatMessage, - OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse, SlowQuery, TableStats, - ValidationRule, ValidationStatus, + AiProvider, AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, + OllamaTagsResponse, }; -use crate::state::AppState; -use crate::utils::{escape_ident, topological_sort_tables}; -use serde_json::Value; -use sqlx::{Column, Row}; +use crate::state::{AppState, DbFlavor}; +use sqlx::Row; use std::collections::{BTreeMap, HashMap}; use std::fs; use std::sync::Arc; -use std::time::{Duration, Instant}; -use tauri::{AppHandle, Emitter, Manager, State}; +use std::time::Duration; +use tauri::{AppHandle, Manager, State}; const MAX_RETRIES: u32 = 2; const RETRY_DELAY_MS: u64 = 1000; @@ -129,7 +123,7 @@ where })) } -async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskResult { +pub(crate) async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskResult { // Try in-memory cache first if let Some(cached) = state.ai_settings.read().await.clone() { return Ok(cached); @@ -153,6 +147,30 @@ async fn call_ollama_chat( state: &AppState, system_prompt: String, user_content: String, +) -> TuskResult { + call_ollama_chat_messages( + app, + state, + vec![ + OllamaChatMessage { + role: "system".to_string(), + content: system_prompt, + }, + OllamaChatMessage { + role: "user".to_string(), + content: user_content, + }, + ], + None, + ) + .await +} + +pub(crate) async fn call_ollama_chat_messages( + app: &AppHandle, + state: &AppState, + messages: Vec, + format: Option, ) -> TuskResult { let settings = load_ai_settings(app, state).await?; @@ -174,17 +192,9 @@ async fn call_ollama_chat( let request = OllamaChatRequest { model: model.clone(), - messages: vec![ - OllamaChatMessage { - role: "system".to_string(), - content: system_prompt, - }, - OllamaChatMessage { - role: "user".to_string(), - content: user_content, - }, - ], + messages, stream: false, + format, }; call_ai_with_retry(&settings, "Ollama request", || { @@ -386,7 +396,173 @@ pub async fn fix_sql_error( } // --------------------------------------------------------------------------- -// Schema context builder +// Lite overview builder (chat v2) +// --------------------------------------------------------------------------- + +const OVERVIEW_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(300); + +/// Build a compact overview of a connection's databases and active-DB tables. +/// Designed to be small enough to inject in every system prompt. +/// Cached per connection_id; invalidated by switch_database / disconnect. +pub(crate) async fn build_overview_context( + state: &AppState, + connection_id: &str, +) -> TuskResult { + if let Some(cached) = state.overview_cache.read().await.get(connection_id).cloned() { + if cached.cached_at.elapsed() < OVERVIEW_CACHE_TTL { + return Ok(cached.value); + } + } + + let flavor = state.get_flavor(connection_id).await; + let text = match flavor { + crate::state::DbFlavor::ClickHouse => build_overview_clickhouse(state, connection_id).await?, + _ => build_overview_postgres(state, connection_id).await?, + }; + + let mut cache = state.overview_cache.write().await; + cache.insert( + connection_id.to_string(), + crate::state::CachedString { + value: text.clone(), + cached_at: std::time::Instant::now(), + }, + ); + Ok(text) +} + +async fn build_overview_postgres(state: &AppState, connection_id: &str) -> TuskResult { + let pool = state.get_pool(connection_id).await?; + + let (version_res, current_db_res, dbs_res, tables_res) = tokio::join!( + sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool), + sqlx::query_scalar::<_, String>("SELECT current_database()").fetch_one(&pool), + sqlx::query_scalar::<_, String>( + "SELECT datname FROM pg_database \ + WHERE datistemplate = false AND datallowconn = true \ + ORDER BY datname" + ) + .fetch_all(&pool), + sqlx::query_as::<_, (String, String)>( + "SELECT table_schema, table_name FROM information_schema.tables \ + WHERE table_schema NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \ + AND table_type = 'BASE TABLE' \ + ORDER BY table_schema, table_name" + ) + .fetch_all(&pool), + ); + + let version = version_res.unwrap_or_default(); + let current_db = current_db_res.unwrap_or_default(); + let all_dbs = dbs_res.unwrap_or_default(); + let tables = tables_res.unwrap_or_default(); + + let short_version = version + .split_whitespace() + .take(2) + .collect::>() + .join(" "); + + let mut out = Vec::::new(); + out.push(format!("DATABASE: {}", short_version)); + if !current_db.is_empty() { + out.push(format!("ACTIVE DATABASE: {}", current_db)); + } + out.push(String::new()); + + if !all_dbs.is_empty() { + out.push("DATABASES ON THIS SERVER:".to_string()); + for db in &all_dbs { + if db == ¤t_db { + out.push(format!(" * {} (active)", db)); + } else { + out.push(format!(" {}", db)); + } + } + out.push(String::new()); + } + + out.push(format!( + "TABLES IN ACTIVE DATABASE ({}):", + tables.len() + )); + for (schema, name) in &tables { + out.push(format!(" {}.{}", schema, name)); + } + out.push(String::new()); + out.push( + "NOTE: Tables of other databases are not enumerated here. \ + Call list_tables({\"database\":\"\"}) to see them. For PostgreSQL this also \ + requires switch_database before you can actually query data there." + .to_string(), + ); + + Ok(out.join("\n")) +} + +async fn build_overview_clickhouse(state: &AppState, connection_id: &str) -> TuskResult { + let client = state.get_ch_client(connection_id).await?; + let active_db = client.database.clone(); + + let dbs_rows = client + .fetch_objects( + "SELECT name FROM system.databases \ + WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \ + ORDER BY name", + ) + .await + .unwrap_or_default(); + let table_rows = client + .fetch_objects( + "SELECT database, name FROM system.tables \ + WHERE database NOT IN ('system','INFORMATION_SCHEMA','information_schema') \ + ORDER BY database, name", + ) + .await + .unwrap_or_default(); + let version = client.ping().await.unwrap_or_default(); + + let mut out = Vec::::new(); + out.push("DATABASE: ClickHouse".to_string()); + if !version.is_empty() { + out.push(format!("VERSION: {}", version.trim())); + } + out.push(format!("ACTIVE DATABASE: {}", active_db)); + out.push(String::new()); + + if !dbs_rows.is_empty() { + out.push("DATABASES ON THIS SERVER:".to_string()); + for row in &dbs_rows { + let name = row.get("name").and_then(|v| v.as_str()).unwrap_or(""); + if name == active_db { + out.push(format!(" * {} (active)", name)); + } else { + out.push(format!(" {}", name)); + } + } + out.push(String::new()); + } + + out.push(format!("TABLES ACROSS ALL DATABASES ({}):", table_rows.len())); + for row in &table_rows { + let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or(""); + let tbl = row.get("name").and_then(|v| v.as_str()).unwrap_or(""); + if !dbn.is_empty() && !tbl.is_empty() { + out.push(format!(" {}.{}", dbn, tbl)); + } + } + out.push(String::new()); + out.push( + "NOTE: ClickHouse allows fully-qualified `db.table` queries — \ + you can reference any table in this list directly without switching databases." + .to_string(), + ); + + Ok(out.join("\n")) +} + +// --------------------------------------------------------------------------- +// Full schema context builder (legacy — used by generate_sql/explain_sql/fix_sql_error) // --------------------------------------------------------------------------- pub(crate) async fn build_schema_context( @@ -398,11 +574,17 @@ pub(crate) async fn build_schema_context( return Ok(cached); } + if matches!(state.get_flavor(connection_id).await, DbFlavor::ClickHouse) { + return build_clickhouse_schema_context(state, connection_id).await; + } + let pool = state.get_pool(connection_id).await?; // Run all metadata queries in parallel for speed let ( version_res, + current_db_res, + all_dbs_res, col_res, fk_res, enum_res, @@ -413,6 +595,13 @@ pub(crate) async fn build_schema_context( jsonb_res, ) = tokio::join!( sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool), + sqlx::query_scalar::<_, String>("SELECT current_database()").fetch_one(&pool), + sqlx::query_scalar::<_, String>( + "SELECT datname FROM pg_database \ + WHERE datistemplate = false AND datallowconn = true \ + ORDER BY datname" + ) + .fetch_all(&pool), fetch_columns(&pool), fetch_foreign_keys_raw(&pool), fetch_enum_types(&pool), @@ -424,6 +613,8 @@ pub(crate) async fn build_schema_context( ); let version = version_res.map_err(TuskError::Database)?; + let current_db = current_db_res.unwrap_or_default(); + let all_dbs = all_dbs_res.unwrap_or_default(); let col_rows = col_res?; let fk_rows = fk_res?; let enum_map = enum_res?; @@ -476,9 +667,58 @@ pub(crate) async fn build_schema_context( .collect::>() .join(" "); output.push(format!("DATABASE SCHEMA ({})", short_version)); + if !current_db.is_empty() { + output.push(format!("ACTIVE DATABASE: {}", current_db)); + } output.push(String::new()); - // 2. Enum types + // 2. Cluster topology — other databases on this server. + // Each PG database is isolated; cross-DB queries are not possible from a single connection. + if all_dbs.len() > 1 { + output.push("DATABASES ON THIS SERVER:".to_string()); + for db in &all_dbs { + if db == ¤t_db { + output.push(format!(" * {} (active)", db)); + } else { + output.push(format!(" {}", db)); + } + } + output.push(String::new()); + output.push( + "NOTE: Tables in other databases are NOT queryable from this session. \ + If the user's question concerns data likely stored in a different database \ + (e.g. an identity service in a separate DB), respond with a `final` message \ + asking them to switch the active database via the connection selector." + .to_string(), + ); + output.push(String::new()); + } + + // 3. Quick table+column index for fast existence checks before writing SQL. + // Each line lists `schema.table(col1, col2, ...)` so the model can grep both + // table names and column names without scrolling through the full TABLES section. + { + let mut by_table: BTreeMap<(String, String), Vec> = BTreeMap::new(); + for c in &col_rows { + by_table + .entry((c.schema.clone(), c.table.clone())) + .or_default() + .push(c.column.clone()); + } + if !by_table.is_empty() { + output.push(format!( + "TABLE INDEX (database `{}`, {} tables — table_name(column_list)):", + current_db, + by_table.len() + )); + for ((schema, table), cols) in &by_table { + output.push(format!(" {}.{}({})", schema, table, cols.join(", "))); + } + output.push(String::new()); + } + } + + // 4. Enum types if !enum_map.is_empty() { output.push("ENUM TYPES:".to_string()); for (type_name, values) in &enum_map { @@ -492,7 +732,7 @@ pub(crate) async fn build_schema_context( output.push(String::new()); } - // 3. Tables with columns + // 5. Tables with columns output.push("TABLES:".to_string()); // Group columns by schema.table preserving order @@ -503,103 +743,21 @@ pub(crate) async fn build_schema_context( } for (full_name, columns) in &tables { - // Table header with optional comment - let tbl_comment = tbl_comments.get(full_name).map(|c| c.as_str()); - match tbl_comment { - Some(comment) => output.push(format!("\nTABLE {} -- {}", full_name, comment)), - None => output.push(format!("\nTABLE {}", full_name)), - } - - // Columns - for ci in columns { - let mut parts: Vec = vec![ci.column.clone(), ci.data_type.clone()]; - - if ci.is_pk { - parts.push("PK".to_string()); - } - if ci.not_null && !ci.is_pk { - parts.push("NOT NULL".to_string()); - } - - // Inline FK reference - if let Some(ref_target) = - fk_inline.get(&(ci.schema.clone(), ci.table.clone(), ci.column.clone())) - { - parts.push(format!("FK->{}", ref_target)); - } - - // Default value (simplified) - if let Some(ref def) = ci.column_default { - let simplified = simplify_default(def); - if !simplified.is_empty() { - parts.push(format!("DEFAULT {}", simplified)); - } - } - - // Column comment - let col_key = (ci.schema.clone(), ci.table.clone(), ci.column.clone()); - let col_comment = col_comments.get(&col_key); - - // Inline enum values for enum-typed columns - let enum_annotation = enum_map.get(&ci.data_type); - - let mut suffix_parts: Vec = Vec::new(); - if let Some(vals) = enum_annotation { - let vals_str = vals - .iter() - .map(|v| format!("'{}'", v)) - .collect::>() - .join(", "); - suffix_parts.push(format!("enum: {}", vals_str)); - } - - // Inline varchar distinct values (pseudo-enums from pg_stats) - if enum_annotation.is_none() { - if let Some(vals) = varchar_values.get(&col_key) { - let vals_str = vals - .iter() - .take(15) - .map(|v| format!("'{}'", v)) - .collect::>() - .join(", "); - suffix_parts.push(format!("values: {}", vals_str)); - } - } - - // Inline JSONB key structure - if let Some(keys) = jsonb_keys.get(&col_key) { - suffix_parts.push(format!("json keys: {}", keys.join(", "))); - } - - if let Some(comment) = col_comment { - suffix_parts.push(comment.clone()); - } - - if suffix_parts.is_empty() { - output.push(format!(" {}", parts.join(" "))); - } else { - output.push(format!( - " {} -- {}", - parts.join(" "), - suffix_parts.join("; ") - )); - } - } - - // Unique constraints for this table - let schema_table: Vec<&str> = full_name.splitn(2, '.').collect(); - if schema_table.len() == 2 { - if let Some(uqs) = - unique_map.get(&(schema_table[0].to_string(), schema_table[1].to_string())) - { - for uq_cols in uqs { - output.push(format!(" UNIQUE({})", uq_cols)); - } - } - } + format_table_block( + full_name, + columns, + &tbl_comments, + &col_comments, + &fk_inline, + &enum_map, + &unique_map, + &varchar_values, + &jsonb_keys, + &mut output, + ); } - // 4. Foreign keys summary + // 6. Foreign keys summary if !fk_lines.is_empty() { output.push(String::new()); output.push("FOREIGN KEYS:".to_string()); @@ -623,17 +781,220 @@ pub(crate) async fn build_schema_context( // --------------------------------------------------------------------------- #[derive(Clone)] -struct ColumnInfo { - schema: String, - table: String, - column: String, - data_type: String, - not_null: bool, - is_pk: bool, - column_default: Option, +pub(crate) struct ColumnInfo { + pub(crate) schema: String, + pub(crate) table: String, + pub(crate) column: String, + pub(crate) data_type: String, + pub(crate) not_null: bool, + pub(crate) is_pk: bool, + pub(crate) column_default: Option, } -async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult> { +/// Render a single table's column block in the human/LLM-readable schema format. +/// Reused by both `build_schema_context` (full DDL for legacy AI commands) and +/// the new `get_columns` chat tool. +#[allow(clippy::too_many_arguments)] +pub(crate) fn format_table_block( + full_name: &str, + columns: &[ColumnInfo], + tbl_comments: &HashMap, + col_comments: &HashMap<(String, String, String), String>, + fk_inline: &HashMap<(String, String, String), String>, + enum_map: &BTreeMap>, + unique_map: &HashMap<(String, String), Vec>, + varchar_values: &HashMap<(String, String, String), Vec>, + jsonb_keys: &HashMap<(String, String, String), Vec>, + output: &mut Vec, +) { + let tbl_comment = tbl_comments.get(full_name).map(|c| c.as_str()); + match tbl_comment { + Some(comment) => output.push(format!("\nTABLE {} -- {}", full_name, comment)), + None => output.push(format!("\nTABLE {}", full_name)), + } + + for ci in columns { + let mut parts: Vec = vec![ci.column.clone(), ci.data_type.clone()]; + + if ci.is_pk { + parts.push("PK".to_string()); + } + if ci.not_null && !ci.is_pk { + parts.push("NOT NULL".to_string()); + } + + if let Some(ref_target) = + fk_inline.get(&(ci.schema.clone(), ci.table.clone(), ci.column.clone())) + { + parts.push(format!("FK->{}", ref_target)); + } + + if let Some(ref def) = ci.column_default { + let simplified = simplify_default(def); + if !simplified.is_empty() { + parts.push(format!("DEFAULT {}", simplified)); + } + } + + let col_key = (ci.schema.clone(), ci.table.clone(), ci.column.clone()); + let col_comment = col_comments.get(&col_key); + let enum_annotation = enum_map.get(&ci.data_type); + + let mut suffix_parts: Vec = Vec::new(); + if let Some(vals) = enum_annotation { + let vals_str = vals + .iter() + .map(|v| format!("'{}'", v)) + .collect::>() + .join(", "); + suffix_parts.push(format!("enum: {}", vals_str)); + } + + if enum_annotation.is_none() { + if let Some(vals) = varchar_values.get(&col_key) { + let vals_str = vals + .iter() + .take(15) + .map(|v| format!("'{}'", v)) + .collect::>() + .join(", "); + suffix_parts.push(format!("values: {}", vals_str)); + } + } + + if let Some(keys) = jsonb_keys.get(&col_key) { + suffix_parts.push(format!("json keys: {}", keys.join(", "))); + } + + if let Some(comment) = col_comment { + suffix_parts.push(comment.clone()); + } + + if suffix_parts.is_empty() { + output.push(format!(" {}", parts.join(" "))); + } else { + output.push(format!( + " {} -- {}", + parts.join(" "), + suffix_parts.join("; ") + )); + } + } + + // Unique constraints + let schema_table: Vec<&str> = full_name.splitn(2, '.').collect(); + if schema_table.len() == 2 { + if let Some(uqs) = + unique_map.get(&(schema_table[0].to_string(), schema_table[1].to_string())) + { + for uq_cols in uqs { + output.push(format!(" UNIQUE({})", uq_cols)); + } + } + } +} + +async fn build_clickhouse_schema_context( + state: &AppState, + connection_id: &str, +) -> TuskResult { + let client = state.get_ch_client(connection_id).await?; + let db = client.database.clone(); + + // ClickHouse exposes ALL databases via system.* — pull cross-DB schema in one shot. + let columns_sql = "SELECT database, table, name, type, is_in_primary_key \ + FROM system.columns \ + WHERE database NOT IN ('system','INFORMATION_SCHEMA','information_schema') \ + ORDER BY database, table, position"; + let dbs_sql = "SELECT name FROM system.databases \ + WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \ + ORDER BY name"; + + let rows = client.fetch_objects(columns_sql).await?; + let db_rows = client.fetch_objects(dbs_sql).await.unwrap_or_default(); + + let version = client.ping().await.unwrap_or_default(); + let mut out = String::new(); + out.push_str("DATABASE: ClickHouse\n"); + if !version.is_empty() { + out.push_str(&format!("VERSION: {}\n", version.trim())); + } + out.push_str(&format!("ACTIVE_DATABASE: {}\n\n", db)); + + // Cluster overview + if !db_rows.is_empty() { + out.push_str("DATABASES ON THIS SERVER:\n"); + for row in &db_rows { + let name = row.get("name").and_then(|v| v.as_str()).unwrap_or(""); + if name == db { + out.push_str(&format!(" * {} (active)\n", name)); + } else { + out.push_str(&format!(" {}\n", name)); + } + } + out.push('\n'); + } + + // Table+column index — same shape as the PG path so model has a uniform reference. + { + let mut by_table: BTreeMap<(String, String), Vec> = BTreeMap::new(); + for r in &rows { + let dbn = r.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let tbl = r.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let col = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + if !dbn.is_empty() && !tbl.is_empty() && !col.is_empty() { + by_table.entry((dbn, tbl)).or_default().push(col); + } + } + if !by_table.is_empty() { + out.push_str(&format!( + "TABLE INDEX ({} tables across all databases — db.table(column_list)):\n", + by_table.len() + )); + for ((dbn, tbl), cols) in &by_table { + out.push_str(&format!(" {}.{}({})\n", dbn, tbl, cols.join(", "))); + } + out.push('\n'); + } + } + + out.push_str("TABLES:\n"); + let mut current_key: Option = None; + for row in &rows { + let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let table = row.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let column = row.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let dtype = row.get("type").and_then(|v| v.as_str()).unwrap_or(""); + let is_pk = matches!(row.get("is_in_primary_key"), Some(serde_json::Value::Number(n)) if n.as_i64() == Some(1)) + || matches!(row.get("is_in_primary_key"), Some(serde_json::Value::String(s)) if s == "1"); + let key = format!("{}.{}", dbn, table); + if Some(&key) != current_key.as_ref() { + out.push_str(&format!("\nTABLE {}.{}\n", dbn, table)); + current_key = Some(key); + } + out.push_str(&format!( + " {} {}{}\n", + column, + dtype, + if is_pk { " [PK]" } else { "" } + )); + } + + out.push_str( + "\nNOTES:\n\ + - Use ClickHouse SQL dialect. Functions differ from PostgreSQL (e.g. count(), arrayJoin, toDate, formatDateTime).\n\ + - ClickHouse allows fully-qualified `database.table` in queries — you CAN cross-reference databases on this server.\n\ + - Read-only mode is enforced for the agent: only SELECT/WITH/EXPLAIN/SHOW/DESCRIBE allowed.\n\ + - Always include LIMIT for ad-hoc SELECTs.\n", + ); + + state + .set_schema_cache(connection_id.to_string(), out.clone()) + .await; + Ok(out) +} + +pub(crate) async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult> { let rows = sqlx::query( "SELECT \ c.table_schema, c.table_name, c.column_name, \ @@ -720,7 +1081,7 @@ pub(crate) async fn fetch_foreign_keys_raw(pool: &sqlx::PgPool) -> TuskResult> ordered by type name -async fn fetch_enum_types(pool: &sqlx::PgPool) -> TuskResult>> { +pub(crate) async fn fetch_enum_types(pool: &sqlx::PgPool) -> TuskResult>> { let rows = sqlx::query( "SELECT t.typname, \ array_agg(e.enumlabel ORDER BY e.enumsortorder) AS vals \ @@ -745,7 +1106,7 @@ async fn fetch_enum_types(pool: &sqlx::PgPool) -> TuskResult -async fn fetch_table_comments(pool: &sqlx::PgPool) -> TuskResult> { +pub(crate) async fn fetch_table_comments(pool: &sqlx::PgPool) -> TuskResult> { let rows = sqlx::query( "SELECT n.nspname, c.relname, obj_description(c.oid, 'pg_class') \ FROM pg_class c \ @@ -769,7 +1130,7 @@ async fn fetch_table_comments(pool: &sqlx::PgPool) -> TuskResult -async fn fetch_column_comments( +pub(crate) async fn fetch_column_comments( pool: &sqlx::PgPool, ) -> TuskResult> { let rows = sqlx::query( @@ -797,7 +1158,7 @@ async fn fetch_column_comments( } /// Returns Vec<(schema, table, Vec)> for UNIQUE constraints -async fn fetch_unique_constraints( +pub(crate) async fn fetch_unique_constraints( pool: &sqlx::PgPool, ) -> TuskResult)>> { let rows = sqlx::query( @@ -1034,26 +1395,6 @@ fn simplify_default(raw: &str) -> String { s.to_string() } -fn validate_select_statement(sql: &str) -> TuskResult<()> { - let sql_upper = sql.trim().to_uppercase(); - if !sql_upper.starts_with("SELECT") { - return Err(TuskError::Validation( - "Validation query must be a SELECT statement".to_string(), - )); - } - Ok(()) -} - -fn validate_index_ddl(ddl: &str) -> TuskResult<()> { - let ddl_upper = ddl.trim().to_uppercase(); - if !ddl_upper.starts_with("CREATE INDEX") && !ddl_upper.starts_with("DROP INDEX") { - return Err(TuskError::Validation( - "Only CREATE INDEX and DROP INDEX statements are allowed".to_string(), - )); - } - Ok(()) -} - fn clean_sql_response(raw: &str) -> String { let trimmed = raw.trim(); // Remove markdown code fences @@ -1071,757 +1412,10 @@ fn clean_sql_response(raw: &str) -> String { without_fences.trim().to_string() } -// --------------------------------------------------------------------------- -// Wave 1: AI Data Assertions (Validation) -// --------------------------------------------------------------------------- - -#[tauri::command] -pub async fn generate_validation_sql( - app: AppHandle, - state: State<'_, Arc>, - connection_id: String, - rule_description: String, -) -> TuskResult { - let schema_text = build_schema_context(&state, &connection_id).await?; - - let system_prompt = format!( - "You are an expert PostgreSQL data quality validator. Given a database schema and a natural \ - language data quality rule, generate a SELECT query that finds ALL rows violating the rule.\n\ - \n\ - OUTPUT FORMAT:\n\ - - Raw SQL only. No explanations, no markdown code fences, no comments.\n\ - - The query MUST be a SELECT statement.\n\ - - Return violating rows with enough context columns to identify them.\n\ - \n\ - VALIDATION PATTERNS:\n\ - - NULL checks: SELECT * FROM table WHERE required_column IS NULL\n\ - - Format checks: WHERE column !~ 'pattern'\n\ - - Range checks: WHERE column < min OR column > max\n\ - - FK integrity: LEFT JOIN parent ON ... WHERE parent.id IS NULL\n\ - - Uniqueness: GROUP BY ... HAVING COUNT(*) > 1\n\ - - Date consistency: WHERE start_date > end_date\n\ - - Enum validity: WHERE column NOT IN ('val1', 'val2', ...)\n\ - \n\ - ONLY reference tables and columns that exist in the schema.\n\ - \n\ - {}\n", - schema_text - ); - - let raw = call_ollama_chat(&app, &state, system_prompt, rule_description).await?; - Ok(clean_sql_response(&raw)) -} - -#[tauri::command] -pub async fn run_validation_rule( - state: State<'_, Arc>, - connection_id: String, - sql: String, - sample_limit: Option, -) -> TuskResult { - validate_select_statement(&sql)?; - - let pool = state.get_pool(&connection_id).await?; - let limit = sample_limit.unwrap_or(10); - let _start = Instant::now(); - - let mut tx = pool.begin().await.map_err(TuskError::Database)?; - sqlx::query("SET TRANSACTION READ ONLY") - .execute(&mut *tx) - .await - .map_err(TuskError::Database)?; - sqlx::query("SET statement_timeout = '30s'") - .execute(&mut *tx) - .await - .map_err(TuskError::Database)?; - - // Count violations - let count_sql = format!("SELECT COUNT(*) FROM ({}) AS _v", sql); - let count_row = sqlx::query(&count_sql) - .fetch_one(&mut *tx) - .await - .map_err(TuskError::Database)?; - let violation_count: i64 = count_row.get(0); - - // Sample violations - let sample_sql = format!("SELECT * FROM ({}) AS _v LIMIT {}", sql, limit); - let sample_rows = sqlx::query(&sample_sql) - .fetch_all(&mut *tx) - .await - .map_err(TuskError::Database)?; - - tx.rollback().await.map_err(TuskError::Database)?; - - let mut violation_columns = Vec::new(); - let mut sample_violations = Vec::new(); - - if let Some(first) = sample_rows.first() { - for col in first.columns() { - violation_columns.push(col.name().to_string()); - } - } - - for row in &sample_rows { - let vals: Vec = (0..violation_columns.len()) - .map(|i| pg_value_to_json(row, i)) - .collect(); - sample_violations.push(vals); - } - - let status = if violation_count > 0 { - ValidationStatus::Failed - } else { - ValidationStatus::Passed - }; - - Ok(ValidationRule { - id: String::new(), - description: String::new(), - generated_sql: sql, - status, - violation_count: violation_count as u64, - sample_violations, - violation_columns, - error: None, - }) -} - -#[tauri::command] -pub async fn suggest_validation_rules( - app: AppHandle, - state: State<'_, Arc>, - connection_id: String, -) -> TuskResult> { - let schema_text = build_schema_context(&state, &connection_id).await?; - - let system_prompt = format!( - "You are a data quality expert. Given a database schema, suggest 5-10 data quality \ - validation rules as natural language descriptions.\n\ - \n\ - OUTPUT FORMAT:\n\ - - Return ONLY a JSON array of strings, each string being a validation rule.\n\ - - No markdown, no explanations, no code fences.\n\ - - Example: [\"All users must have a non-empty email address\", \"Order total must be positive\"]\n\ - \n\ - RULE CATEGORIES TO COVER:\n\ - - NOT NULL checks for critical columns\n\ - - Business logic (positive amounts, valid ranges, consistent dates)\n\ - - Referential integrity (orphaned foreign keys)\n\ - - Format validation (emails, phone numbers, codes)\n\ - - Enum/status field validity\n\ - - Date consistency (start before end, not in future where inappropriate)\n\ - \n\ - {}\n", - schema_text - ); - - let raw = call_ollama_chat( - &app, - &state, - system_prompt, - "Suggest validation rules".to_string(), - ) - .await?; - - let cleaned = raw.trim(); - let json_start = cleaned.find('[').unwrap_or(0); - let json_end = cleaned.rfind(']').map(|i| i + 1).unwrap_or(cleaned.len()); - let json_str = &cleaned[json_start..json_end]; - - let rules: Vec = serde_json::from_str(json_str).map_err(|e| { - TuskError::Ai(format!( - "Failed to parse AI response as JSON array: {}. Response: {}", - e, cleaned - )) - })?; - - Ok(rules) -} - -// --------------------------------------------------------------------------- -// Wave 2: AI Data Generator -// --------------------------------------------------------------------------- - -#[tauri::command] -pub async fn generate_test_data_preview( - app: AppHandle, - state: State<'_, Arc>, - params: GenerateDataParams, - gen_id: String, -) -> TuskResult { - let pool = state.get_pool(¶ms.connection_id).await?; - - let _ = app.emit( - "datagen-progress", - DataGenProgress { - gen_id: gen_id.clone(), - stage: "schema".to_string(), - percent: 10, - message: "Building schema context...".to_string(), - detail: None, - }, - ); - - let schema_text = build_schema_context(&state, ¶ms.connection_id).await?; - - // Get FK info for topological sort - let fk_rows = fetch_foreign_keys_raw(&pool).await?; - - let mut target_tables = vec![(params.schema.clone(), params.table.clone())]; - - if params.include_related { - // Add parent tables (tables referenced by FKs from target) - for fk in &fk_rows { - if fk.schema == params.schema && fk.table == params.table { - let parent = (fk.ref_schema.clone(), fk.ref_table.clone()); - if !target_tables.contains(&parent) { - target_tables.push(parent); - } - } - } - } - - let fk_edges: Vec<(String, String, String, String)> = fk_rows - .iter() - .map(|fk| { - ( - fk.schema.clone(), - fk.table.clone(), - fk.ref_schema.clone(), - fk.ref_table.clone(), - ) - }) - .collect(); - - let sorted_tables = topological_sort_tables(&fk_edges, &target_tables); - let insert_order: Vec = sorted_tables - .iter() - .map(|(s, t)| format!("{}.{}", s, t)) - .collect(); - - let row_count = params.row_count.min(1000); - - let _ = app.emit( - "datagen-progress", - DataGenProgress { - gen_id: gen_id.clone(), - stage: "generating".to_string(), - percent: 30, - message: "AI is generating test data...".to_string(), - detail: None, - }, - ); - - let tables_desc: Vec = sorted_tables - .iter() - .map(|(s, t)| { - let count = if s == ¶ms.schema && t == ¶ms.table { - row_count - } else { - (row_count / 3).max(1) - }; - format!("{}.{}: {} rows", s, t, count) - }) - .collect(); - - let custom = params - .custom_instructions - .as_deref() - .unwrap_or("Generate realistic sample data"); - - let system_prompt = format!( - "You are a PostgreSQL test data generator. Generate realistic test data as JSON.\n\ - \n\ - OUTPUT FORMAT:\n\ - - Return ONLY a JSON object where keys are \"schema.table\" and values are arrays of row objects.\n\ - - Each row object has column names as keys and values matching the column types.\n\ - - No markdown, no explanations, no code fences.\n\ - - Example: {{\"public.users\": [{{\"name\": \"Alice\", \"email\": \"alice@example.com\"}}]}}\n\ - \n\ - RULES:\n\ - 1. Respect column types exactly (text, integer, boolean, timestamp, uuid, etc.)\n\ - 2. Use valid foreign key values - parent tables are generated first, reference their IDs\n\ - 3. Respect enum types - use only valid enum values\n\ - 4. Omit auto-increment/serial/identity columns (they have DEFAULT auto-increment)\n\ - 5. Generate realistic data: real names, valid emails, plausible dates, etc.\n\ - 6. Respect NOT NULL constraints\n\ - 7. For UUID columns, generate valid UUIDs\n\ - 8. For timestamp columns, use ISO 8601 format\n\ - \n\ - Tables to generate (in this exact order):\n\ - {}\n\ - \n\ - Custom instructions: {}\n\ - \n\ - {}\n", - tables_desc.join("\n"), - custom, - schema_text - ); - - let raw = call_ollama_chat( - &app, - &state, - system_prompt, - format!("Generate test data for {} tables", sorted_tables.len()), - ) - .await?; - - let _ = app.emit( - "datagen-progress", - DataGenProgress { - gen_id: gen_id.clone(), - stage: "parsing".to_string(), - percent: 80, - message: "Parsing generated data...".to_string(), - detail: None, - }, - ); - - // Parse JSON response - let cleaned = raw.trim(); - let json_start = cleaned.find('{').unwrap_or(0); - let json_end = cleaned.rfind('}').map(|i| i + 1).unwrap_or(cleaned.len()); - let json_str = &cleaned[json_start..json_end]; - - let data_map: HashMap>> = serde_json::from_str(json_str) - .map_err(|e| { - TuskError::Ai(format!( - "Failed to parse generated data: {}. Response: {}", - e, - &cleaned[..cleaned.len().min(500)] - )) - })?; - - let mut tables = Vec::new(); - let mut total_rows: u32 = 0; - - for (schema, table) in &sorted_tables { - let key = format!("{}.{}", schema, table); - if let Some(rows_data) = data_map.get(&key) { - let columns: Vec = if let Some(first) = rows_data.first() { - first.keys().cloned().collect() - } else { - Vec::new() - }; - - let rows: Vec> = rows_data - .iter() - .map(|row_map| { - columns - .iter() - .map(|col| row_map.get(col).cloned().unwrap_or(Value::Null)) - .collect() - }) - .collect(); - - let count = rows.len() as u32; - total_rows += count; - - tables.push(GeneratedTableData { - schema: schema.clone(), - table: table.clone(), - columns, - rows, - row_count: count, - }); - } - } - - let _ = app.emit( - "datagen-progress", - DataGenProgress { - gen_id: gen_id.clone(), - stage: "done".to_string(), - percent: 100, - message: "Data generation complete".to_string(), - detail: Some(format!( - "{} rows across {} tables", - total_rows, - tables.len() - )), - }, - ); - - Ok(GeneratedDataPreview { - tables, - insert_order, - total_rows, - }) -} - -#[tauri::command] -pub async fn insert_generated_data( - state: State<'_, Arc>, - connection_id: String, - preview: GeneratedDataPreview, -) -> TuskResult { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pool = state.get_pool(&connection_id).await?; - let mut tx = pool.begin().await.map_err(TuskError::Database)?; - - // Defer constraints for circular FKs - sqlx::query("SET CONSTRAINTS ALL DEFERRED") - .execute(&mut *tx) - .await - .map_err(TuskError::Database)?; - - let mut total_inserted: u64 = 0; - - for table_data in &preview.tables { - if table_data.columns.is_empty() || table_data.rows.is_empty() { - continue; - } - - let qualified = format!( - "{}.{}", - escape_ident(&table_data.schema), - escape_ident(&table_data.table) - ); - let col_list: Vec = table_data.columns.iter().map(|c| escape_ident(c)).collect(); - let placeholders: Vec = (1..=table_data.columns.len()) - .map(|i| format!("${}", i)) - .collect(); - - let sql = format!( - "INSERT INTO {} ({}) VALUES ({})", - qualified, - col_list.join(", "), - placeholders.join(", ") - ); - - for row in &table_data.rows { - let mut query = sqlx::query(&sql); - for val in row { - query = bind_json_value(query, val); - } - query.execute(&mut *tx).await.map_err(TuskError::Database)?; - total_inserted += 1; - } - } - - tx.commit().await.map_err(TuskError::Database)?; - - // Invalidate schema cache since data changed - state.invalidate_schema_cache(&connection_id).await; - - Ok(total_inserted) -} - -// --------------------------------------------------------------------------- -// Wave 3A: Smart Index Advisor -// --------------------------------------------------------------------------- - -#[tauri::command] -pub async fn get_index_advisor_report( - app: AppHandle, - state: State<'_, Arc>, - connection_id: String, -) -> TuskResult { - let pool = state.get_pool(&connection_id).await?; - - // Fetch table stats, index stats, and slow queries concurrently - let (table_stats_res, index_stats_res, slow_queries_res) = tokio::join!( - sqlx::query( - "SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \ - pg_size_pretty(pg_total_relation_size(schemaname || '.' || relname)) AS table_size, \ - pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \ - FROM pg_stat_user_tables \ - ORDER BY seq_scan DESC \ - LIMIT 50" - ) - .fetch_all(&pool), - sqlx::query( - "SELECT schemaname, relname, indexrelname, idx_scan, \ - pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \ - pg_get_indexdef(indexrelid) AS definition \ - FROM pg_stat_user_indexes \ - ORDER BY idx_scan ASC \ - LIMIT 50", - ) - .fetch_all(&pool), - sqlx::query( - "SELECT query, calls, total_exec_time, mean_exec_time, rows \ - FROM pg_stat_statements \ - WHERE calls > 0 \ - ORDER BY mean_exec_time DESC \ - LIMIT 20", - ) - .fetch_all(&pool), - ); - - let table_stats: Vec = table_stats_res - .map_err(TuskError::Database)? - .iter() - .map(|r| TableStats { - schema: r.get(0), - table: r.get(1), - seq_scan: r.get(2), - idx_scan: r.get(3), - n_live_tup: r.get(4), - table_size: r.get(5), - index_size: r.get(6), - }) - .collect(); - - let index_stats: Vec = index_stats_res - .map_err(TuskError::Database)? - .iter() - .map(|r| IndexStats { - schema: r.get(0), - table: r.get(1), - index_name: r.get(2), - idx_scan: r.get(3), - index_size: r.get(4), - definition: r.get(5), - }) - .collect(); - - let (slow_queries, has_pg_stat_statements) = match slow_queries_res { - Ok(rows) => { - let queries: Vec = rows - .iter() - .map(|r| SlowQuery { - query: r.get(0), - calls: r.get(1), - total_time_ms: r.get(2), - mean_time_ms: r.get(3), - rows: r.get(4), - }) - .collect(); - (queries, true) - } - Err(_) => (Vec::new(), false), - }; - - // Build AI prompt for recommendations - let schema_text = build_schema_context(&state, &connection_id).await?; - - let mut stats_text = String::from("TABLE STATISTICS:\n"); - for ts in &table_stats { - stats_text.push_str(&format!( - " {}.{}: seq_scan={}, idx_scan={}, rows={}, size={}, idx_size={}\n", - ts.schema, - ts.table, - ts.seq_scan, - ts.idx_scan, - ts.n_live_tup, - ts.table_size, - ts.index_size - )); - } - - stats_text.push_str("\nINDEX STATISTICS:\n"); - for is in &index_stats { - stats_text.push_str(&format!( - " {}.{}.{}: scans={}, size={}, def={}\n", - is.schema, is.table, is.index_name, is.idx_scan, is.index_size, is.definition - )); - } - - if !slow_queries.is_empty() { - stats_text.push_str("\nSLOW QUERIES:\n"); - for sq in &slow_queries { - stats_text.push_str(&format!( - " calls={}, mean={:.1}ms, total={:.1}ms, rows={}: {}\n", - sq.calls, - sq.mean_time_ms, - sq.total_time_ms, - sq.rows, - sq.query.chars().take(200).collect::() - )); - } - } - - let system_prompt = format!( - "You are a PostgreSQL performance expert. Analyze the database statistics and recommend index changes.\n\ - \n\ - OUTPUT FORMAT:\n\ - - Return ONLY a JSON array of recommendation objects.\n\ - - No markdown, no explanations, no code fences.\n\ - - Each object: {{\"recommendation_type\": \"create_index\"|\"drop_index\"|\"replace_index\", \ - \"table_schema\": \"...\", \"table_name\": \"...\", \"index_name\": \"...\"|null, \ - \"ddl\": \"CREATE INDEX CONCURRENTLY ...\", \"rationale\": \"...\", \ - \"estimated_impact\": \"high\"|\"medium\"|\"low\", \"priority\": \"high\"|\"medium\"|\"low\"}}\n\ - \n\ - RULES:\n\ - 1. Prefer CREATE INDEX CONCURRENTLY to avoid locking\n\ - 2. Never suggest dropping PRIMARY KEY or UNIQUE indexes\n\ - 3. Suggest dropping indexes with 0 scans on tables with many rows\n\ - 4. Suggest composite indexes for commonly co-filtered columns\n\ - 5. Suggest partial indexes for low-cardinality boolean columns\n\ - 6. Consider covering indexes for frequently selected columns\n\ - 7. High seq_scan + high row count = strong candidate for new index\n\ - \n\ - {}\n\ - \n\ - {}\n", - stats_text, schema_text - ); - - let raw = call_ollama_chat( - &app, - &state, - system_prompt, - "Analyze indexes and provide recommendations".to_string(), - ) - .await?; - - let cleaned = raw.trim(); - let json_start = cleaned.find('[').unwrap_or(0); - let json_end = cleaned.rfind(']').map(|i| i + 1).unwrap_or(cleaned.len()); - let json_str = &cleaned[json_start..json_end]; - - let recommendations: Vec = - serde_json::from_str(json_str).unwrap_or_default(); - - Ok(IndexAdvisorReport { - table_stats, - index_stats, - slow_queries, - recommendations, - has_pg_stat_statements, - }) -} - -#[tauri::command] -pub async fn apply_index_recommendation( - state: State<'_, Arc>, - connection_id: String, - ddl: String, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - validate_index_ddl(&ddl)?; - - let pool = state.get_pool(&connection_id).await?; - - // CONCURRENTLY cannot run inside a transaction, execute directly - sqlx::query(&ddl) - .execute(&pool) - .await - .map_err(TuskError::Database)?; - - // Invalidate schema cache - state.invalidate_schema_cache(&connection_id).await; - - Ok(()) -} - #[cfg(test)] mod tests { use super::*; - // ── validate_select_statement ───────────────────────────── - - #[test] - fn select_valid_simple() { - assert!(validate_select_statement("SELECT 1").is_ok()); - } - - #[test] - fn select_valid_with_leading_whitespace() { - assert!(validate_select_statement(" SELECT * FROM users").is_ok()); - } - - #[test] - fn select_valid_lowercase() { - assert!(validate_select_statement("select * from users").is_ok()); - } - - #[test] - fn select_rejects_insert() { - assert!(validate_select_statement("INSERT INTO users VALUES (1)").is_err()); - } - - #[test] - fn select_rejects_delete() { - assert!(validate_select_statement("DELETE FROM users").is_err()); - } - - #[test] - fn select_rejects_drop() { - assert!(validate_select_statement("DROP TABLE users").is_err()); - } - - #[test] - fn select_rejects_empty() { - assert!(validate_select_statement("").is_err()); - } - - #[test] - fn select_rejects_whitespace_only() { - assert!(validate_select_statement(" ").is_err()); - } - - // NOTE: This test documents a known weakness — SELECT prefix allows injection - #[test] - fn select_allows_semicolon_after_select() { - // "SELECT 1; DROP TABLE users" starts with SELECT — passes validation - // This is a known limitation documented in the review - assert!(validate_select_statement("SELECT 1; DROP TABLE users").is_ok()); - } - - // ── validate_index_ddl ──────────────────────────────────── - - #[test] - fn ddl_valid_create_index() { - assert!(validate_index_ddl("CREATE INDEX idx_name ON users(email)").is_ok()); - } - - #[test] - fn ddl_valid_create_index_concurrently() { - assert!(validate_index_ddl("CREATE INDEX CONCURRENTLY idx ON t(c)").is_ok()); - } - - #[test] - fn ddl_valid_drop_index() { - assert!(validate_index_ddl("DROP INDEX idx_name").is_ok()); - } - - #[test] - fn ddl_valid_with_leading_whitespace() { - assert!(validate_index_ddl(" CREATE INDEX idx ON t(c)").is_ok()); - } - - #[test] - fn ddl_valid_lowercase() { - assert!(validate_index_ddl("create index idx on t(c)").is_ok()); - } - - #[test] - fn ddl_rejects_create_table() { - assert!(validate_index_ddl("CREATE TABLE evil(id int)").is_err()); - } - - #[test] - fn ddl_rejects_drop_table() { - assert!(validate_index_ddl("DROP TABLE users").is_err()); - } - - #[test] - fn ddl_rejects_alter_table() { - assert!(validate_index_ddl("ALTER TABLE users ADD COLUMN x int").is_err()); - } - - #[test] - fn ddl_rejects_empty() { - assert!(validate_index_ddl("").is_err()); - } - - // NOTE: Documents bypass weakness — semicolon after valid prefix - #[test] - fn ddl_allows_semicolon_injection() { - // "CREATE INDEX x ON t(c); DROP TABLE users" — passes validation - // Mitigated by sqlx single-statement execution - assert!(validate_index_ddl("CREATE INDEX x ON t(c); DROP TABLE users").is_ok()); - } - // ── clean_sql_response ──────────────────────────────────── #[test] diff --git a/src-tauri/src/commands/chat.rs b/src-tauri/src/commands/chat.rs new file mode 100644 index 0000000..7df16ed --- /dev/null +++ b/src-tauri/src/commands/chat.rs @@ -0,0 +1,869 @@ +use crate::commands::ai::{build_overview_context, call_ollama_chat_messages}; +use crate::commands::chat_tools::{ + find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool, + switch_database_tool, +}; +use crate::commands::memory::{append_memory_core, read_memory_core}; +use crate::commands::queries::execute_query_core; +use crate::error::{TuskError, TuskResult}; +use crate::models::ai::OllamaChatMessage; +use crate::models::chat::ChatMessage; +use crate::models::query_result::QueryResult; +use crate::state::AppState; +use chrono::Utc; +use serde_json::Value; +use std::sync::Arc; +use tauri::{AppHandle, State}; + +const MAX_HOPS: usize = 8; +/// Number of MOST RECENT run_query tool_results that get full sample-rows in +/// LLM history. Older ones are reduced to a marker so very long threads stay +/// within model context budget. +const RECENT_TOOL_RESULTS_FULL: usize = 4; +/// Sample-row cap for compressed run_query results in LLM history. +const RUN_QUERY_SAMPLE_ROWS: usize = 10; +/// Per-cell character cap when stringifying sample rows. +const CELL_CHAR_CAP: usize = 200; +/// Per text-tool-result character cap (list_tables, get_columns, etc). +const TEXT_TOOL_CHAR_CAP: usize = 10_000; + +// --------------------------------------------------------------------------- +// Action protocol +// --------------------------------------------------------------------------- + +#[derive(Debug)] +enum AgentAction { + Final { text: String }, + RunQuery { sql: String }, + ListDatabases, + ListTables { database: Option }, + GetColumns { tables: Vec }, + SwitchDatabase { database: String }, + Remember { note: String }, + SaveQuery { name: String, sql: String }, + FindQueries { text: String }, +} + +/// Parse the model's JSON response. Accepts both shapes the model tends to emit: +/// {"action":"X","field":"..."} — flat (matches our prompt) +/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention) +fn parse_agent_action(raw: &str) -> Result { + let v: Value = serde_json::from_str(raw).map_err(|e| e.to_string())?; + let obj = v.as_object().ok_or_else(|| "expected JSON object".to_string())?; + let action = obj + .get("action") + .and_then(|a| a.as_str()) + .ok_or_else(|| "missing field `action`".to_string())?; + + let lookup = |key: &str| -> Option<&Value> { + obj.get(key) + .or_else(|| obj.get("input").and_then(|i| i.as_object()).and_then(|i| i.get(key))) + }; + + match action { + "final" => { + let text = lookup("text") + .and_then(|v| v.as_str()) + .ok_or_else(|| "final action missing `text`".to_string())? + .to_string(); + Ok(AgentAction::Final { text }) + } + "run_query" => { + let sql = lookup("sql") + .and_then(|v| v.as_str()) + .ok_or_else(|| "run_query action missing `sql`".to_string())? + .to_string(); + Ok(AgentAction::RunQuery { sql }) + } + "list_databases" => Ok(AgentAction::ListDatabases), + "list_tables" => { + let database = lookup("database") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + Ok(AgentAction::ListTables { database }) + } + "get_columns" => { + let arr = lookup("tables") + .and_then(|v| v.as_array()) + .ok_or_else(|| "get_columns action missing `tables`: [...]".to_string())?; + let tables: Vec = arr + .iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + if tables.is_empty() { + return Err("get_columns `tables` array must not be empty".into()); + } + Ok(AgentAction::GetColumns { tables }) + } + "switch_database" => { + let database = lookup("database") + .and_then(|v| v.as_str()) + .ok_or_else(|| "switch_database missing `database`".to_string())? + .to_string(); + Ok(AgentAction::SwitchDatabase { database }) + } + "remember" => { + let note = lookup("note") + .and_then(|v| v.as_str()) + .ok_or_else(|| "remember action missing `note`".to_string())? + .trim() + .to_string(); + if note.is_empty() { + return Err("remember `note` must not be empty".into()); + } + Ok(AgentAction::Remember { note }) + } + "save_query" => { + let name = lookup("name") + .and_then(|v| v.as_str()) + .ok_or_else(|| "save_query missing `name`".to_string())? + .trim() + .to_string(); + let sql = lookup("sql") + .and_then(|v| v.as_str()) + .ok_or_else(|| "save_query missing `sql`".to_string())? + .trim() + .to_string(); + if name.is_empty() { + return Err("save_query `name` must not be empty".into()); + } + if sql.is_empty() { + return Err("save_query `sql` must not be empty".into()); + } + Ok(AgentAction::SaveQuery { name, sql }) + } + "find_queries" => { + let text = lookup("text") + .and_then(|v| v.as_str()) + .ok_or_else(|| "find_queries missing `text`".to_string())? + .trim() + .to_string(); + if text.is_empty() { + return Err("find_queries `text` must not be empty".into()); + } + Ok(AgentAction::FindQueries { text }) + } + // Legacy from earlier iterations — silently ignored at parse time so the + // model can recover with a different action. + "get_schema" => Err( + "get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(), + ), + other => Err(format!("unknown action `{}`", other)), + } +} + +// --------------------------------------------------------------------------- +// id / time helpers +// --------------------------------------------------------------------------- + +fn now_ms() -> i64 { + Utc::now().timestamp_millis() +} + +fn new_id(prefix: &str) -> String { + format!("{}-{}-{}", prefix, now_ms(), rand_suffix()) +} + +fn rand_suffix() -> String { + use std::time::{SystemTime, UNIX_EPOCH}; + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.subsec_nanos()) + .unwrap_or(0); + format!("{:x}", nanos) +} + +// --------------------------------------------------------------------------- +// System prompt +// --------------------------------------------------------------------------- + +fn system_prompt(overview: &str, memory: &str) -> String { + let overview_block = if overview.is_empty() { + "(overview unavailable; respond with `final` asking the user to reconnect.)".to_string() + } else { + overview.to_string() + }; + + let memory_block = if memory.trim().is_empty() { + "(empty — call remember() when you discover non-obvious facts about this database)".to_string() + } else { + memory.to_string() + }; + + format!( + r#"ROLE: Tusk's data assistant. Reply in the user's language. + +You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On every turn output STRICT JSON — exactly one of these shapes, with all fields at the root (no `input` wrapper, no markdown fences): + + {{"action":"list_databases"}} + Refresh the database list when the OVERVIEW seems stale. + + {{"action":"list_tables"}} + List tables in the active database. + + {{"action":"list_tables","database":""}} + List tables in a specific database (PostgreSQL: requires switch_database before run_query). + + {{"action":"get_columns","tables":["schema.table","schema.table2"]}} + Load full column info (types, PK, FK, comments, enums) for the listed tables. Use this BEFORE writing SQL when you don't already know the columns. + + {{"action":"switch_database","database":""}} + Change the active database. Required for PostgreSQL when the user's question concerns data in another database. ClickHouse rarely needs this — `db.table` qualifiers are allowed without switching. + + {{"action":"run_query","sql":"SELECT ..."}} + Execute read-only SQL (SELECT / WITH ... SELECT / EXPLAIN / SHOW / DESCRIBE). Mutating SQL is rejected by the read-only guard. + + {{"action":"remember","note":""}} + Persist a non-obvious fact about THIS database for future sessions: column semantics, naming conventions, business-rule encodings, gotchas. Keep notes < 200 chars. The user sees and can edit your notes in the Memory sidebar tab. + + {{"action":"find_queries","text":""}} + Search saved queries (your past work + user-saved). Use BEFORE writing complex SQL — a usable variant may already exist. Top 10 matches with SQL preview are returned. + + {{"action":"save_query","name":"","sql":""}} + Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved. + + {{"action":"final","text":"..."}} + End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling). + +WORKFLOW + 1. Read LEARNED NOTES below first — the user (or your past self) may have already documented relevant facts. + 2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question. + 3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs. + 4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns. + 5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first. + 6. Execute run_query. + 7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here. + 8. If the query is likely to be re-run later (a real report-style request, not a one-off lookup), call `save_query` with a concise `name`. + 9. Answer with `final`. + +RULES + - Use ONLY identifiers visible to you (overview / list_tables / get_columns output). Don't pluralize, translate, or guess. + - LIMIT on ad-hoc SELECTs unless aggregating. + - On SQL error retry once with a fix; on the second failure respond with `final` explaining what's missing. + - `remember` is for durable facts, not transient observations. Don't memorise query results — only insights about the schema/data model that aren't already in the OVERVIEW. + +═══════════════════════════════════════════════════════════════ +LEARNED NOTES (per-connection memory; user can edit in sidebar → Memory) +═══════════════════════════════════════════════════════════════ +{memory} +═══════════════════════════════════════════════════════════════ + +═══════════════════════════════════════════════════════════════ +OVERVIEW (refreshed every turn) +═══════════════════════════════════════════════════════════════ +{overview} +═══════════════════════════════════════════════════════════════ +"#, + hops = MAX_HOPS, + memory = memory_block, + overview = overview_block, + ) +} + +// --------------------------------------------------------------------------- +// Compressed history projection +// --------------------------------------------------------------------------- + +/// Compact view of a QueryResult for re-injection into the LLM history. +/// Keeps just enough for the model to reason about the next step (column +/// names, types, total row count, first N rows) without the full payload. +fn compact_query_result(result: &QueryResult) -> Value { + let total = result.rows.len(); + let sample: Vec> = result + .rows + .iter() + .take(RUN_QUERY_SAMPLE_ROWS) + .map(|row| row.iter().map(truncate_cell).collect()) + .collect(); + serde_json::json!({ + "columns": result.columns, + "types": result.types, + "row_count": total, + "execution_time_ms": result.execution_time_ms, + "sample_rows": sample, + "truncated": total > RUN_QUERY_SAMPLE_ROWS, + }) +} + +fn truncate_cell(v: &Value) -> Value { + match v { + Value::String(s) if s.chars().count() > CELL_CHAR_CAP => { + let truncated: String = s.chars().take(CELL_CHAR_CAP).collect(); + Value::String(format!("{}…", truncated)) + } + other => other.clone(), + } +} + +fn truncate_text(text: &str) -> String { + if text.len() <= TEXT_TOOL_CHAR_CAP { + text.to_string() + } else { + let mut out = text[..TEXT_TOOL_CHAR_CAP].to_string(); + out.push_str("\n…(truncated)"); + out + } +} + +fn build_history( + messages: &[ChatMessage], + overview_text: &str, + memory_text: &str, +) -> Vec { + // Index of run_query tool_results in `messages`. Used to mark which ones + // get full sample rows vs the "(rows omitted)" placeholder. + let run_query_indices: Vec = messages + .iter() + .enumerate() + .filter_map(|(i, m)| match m { + ChatMessage::ToolResult { tool, .. } if tool == "run_query" => Some(i), + _ => None, + }) + .collect(); + let keep_full_after_index: usize = if run_query_indices.len() <= RECENT_TOOL_RESULTS_FULL { + 0 + } else { + run_query_indices[run_query_indices.len() - RECENT_TOOL_RESULTS_FULL] + }; + + let mut out = Vec::with_capacity(messages.len() + 1); + out.push(OllamaChatMessage { + role: "system".to_string(), + content: system_prompt(overview_text, memory_text), + }); + + for (idx, m) in messages.iter().enumerate() { + match m { + ChatMessage::User { text, .. } => out.push(OllamaChatMessage { + role: "user".to_string(), + content: text.clone(), + }), + ChatMessage::Assistant { text, .. } => out.push(OllamaChatMessage { + role: "assistant".to_string(), + content: serde_json::json!({ "action": "final", "text": text }).to_string(), + }), + ChatMessage::ToolCall { tool, input_json, .. } => { + if tool == "get_schema" { + continue; // legacy + } + let mut envelope = serde_json::Map::new(); + envelope.insert("action".to_string(), Value::String(tool.clone())); + if let Ok(Value::Object(input)) = serde_json::from_str::(input_json) { + for (k, v) in input { + envelope.insert(k, v); + } + } + out.push(OllamaChatMessage { + role: "assistant".to_string(), + content: Value::Object(envelope).to_string(), + }); + } + ChatMessage::ToolResult { + tool, + is_error, + text, + result, + .. + } => { + if tool == "get_schema" { + continue; // legacy + } + let payload = match tool.as_str() { + "run_query" => { + if *is_error { + serde_json::json!({ + "tool": "run_query", + "error": true, + "text": text.clone().unwrap_or_default(), + }) + } else if idx < keep_full_after_index { + serde_json::json!({ + "tool": "run_query", + "error": false, + "note": "rows omitted (older result; user has it in the UI above)", + }) + } else if let Some(qr) = result { + serde_json::json!({ + "tool": "run_query", + "error": false, + "result": compact_query_result(qr), + }) + } else { + serde_json::json!({ + "tool": "run_query", + "error": false, + "result": null, + }) + } + } + // Text-only tools — pass through with cap. + _ => serde_json::json!({ + "tool": tool, + "error": *is_error, + "text": text.as_deref().map(truncate_text), + }), + }; + + out.push(OllamaChatMessage { + role: "user".to_string(), + content: format!("TOOL_RESULT {}", payload), + }); + } + } + } + out +} + +// --------------------------------------------------------------------------- +// chat_send +// --------------------------------------------------------------------------- + +#[tauri::command] +pub async fn chat_send( + app: AppHandle, + state: State<'_, Arc>, + connection_id: String, + messages: Vec, +) -> TuskResult> { + let mut new_messages: Vec = Vec::new(); + let mut working: Vec = messages; + + for _hop in 0..MAX_HOPS { + // Overview is rebuilt per turn — cheap (cached) and reflects the active DB + // even if the user (or the agent) just switched it. + let overview_text = build_overview_context(&state, &connection_id) + .await + .unwrap_or_default(); + // Memory is read fresh each turn so user-side edits in the Memory tab + // are visible to the agent immediately. + let memory_text = read_memory_core(&app, &connection_id).unwrap_or_default(); + + let history = build_history(&working, &overview_text, &memory_text); + let raw = + call_ollama_chat_messages(&app, &state, history, Some("json".to_string())).await?; + let trimmed = raw.trim(); + + let action = match parse_agent_action(trimmed) { + Ok(a) => a, + Err(parse_err) => { + let msg = ChatMessage::Assistant { + id: new_id("asst"), + text: format!( + "{}\n\n_(Note: model returned non-protocol output: {})_", + trimmed, parse_err + ), + created_at: now_ms(), + }; + new_messages.push(msg.clone()); + working.push(msg); + return Ok(new_messages); + } + }; + + match action { + AgentAction::Final { text } => { + let msg = ChatMessage::Assistant { + id: new_id("asst"), + text, + created_at: now_ms(), + }; + new_messages.push(msg.clone()); + working.push(msg); + return Ok(new_messages); + } + AgentAction::RunQuery { sql } => { + push_tool_call( + &mut new_messages, + &mut working, + "run_query", + serde_json::json!({ "sql": sql }).to_string(), + ); + let result = match execute_query_core(&state, &connection_id, &sql).await { + Ok(qr) => ChatMessage::ToolResult { + id: new_id("res"), + tool: "run_query".to_string(), + is_error: false, + text: None, + result: Some(qr), + created_at: now_ms(), + }, + Err(e) => { + let hint = match e { + TuskError::ReadOnly => "\n\nRead-only mode is on. Toggle it off in the toolbar to allow writes.", + _ => "", + }; + ChatMessage::ToolResult { + id: new_id("res"), + tool: "run_query".to_string(), + is_error: true, + text: Some(format!("{}{}", e, hint)), + result: None, + created_at: now_ms(), + } + } + }; + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::ListDatabases => { + push_tool_call( + &mut new_messages, + &mut working, + "list_databases", + "{}".to_string(), + ); + let result = run_text_tool( + list_databases_tool(&state, &connection_id).await, + "list_databases", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::ListTables { database } => { + let input_json = match &database { + Some(db) => serde_json::json!({ "database": db }).to_string(), + None => "{}".to_string(), + }; + push_tool_call(&mut new_messages, &mut working, "list_tables", input_json); + let result = run_text_tool( + list_tables_tool(&app, &state, &connection_id, database.as_deref()).await, + "list_tables", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::GetColumns { tables } => { + push_tool_call( + &mut new_messages, + &mut working, + "get_columns", + serde_json::json!({ "tables": tables }).to_string(), + ); + let result = run_text_tool( + get_columns_tool(&state, &connection_id, &tables).await, + "get_columns", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::SwitchDatabase { database } => { + push_tool_call( + &mut new_messages, + &mut working, + "switch_database", + serde_json::json!({ "database": &database }).to_string(), + ); + let result = run_text_tool( + switch_database_tool(&app, &state, &connection_id, &database).await, + "switch_database", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::Remember { note } => { + push_tool_call( + &mut new_messages, + &mut working, + "remember", + serde_json::json!({ "note": ¬e }).to_string(), + ); + let outcome = append_memory_core(&app, &connection_id, ¬e) + .map(|_| format!("Saved note ({} chars).", note.len())); + let result = run_text_tool(outcome, "remember"); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::SaveQuery { name, sql } => { + push_tool_call( + &mut new_messages, + &mut working, + "save_query", + serde_json::json!({ "name": &name, "sql": &sql }).to_string(), + ); + let result = run_text_tool( + save_query_tool(&app, &connection_id, &name, &sql).await, + "save_query", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::FindQueries { text } => { + push_tool_call( + &mut new_messages, + &mut working, + "find_queries", + serde_json::json!({ "text": &text }).to_string(), + ); + let result = run_text_tool( + find_queries_tool(&app, &connection_id, &text).await, + "find_queries", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + } + } + + let msg = ChatMessage::Assistant { + id: new_id("asst"), + text: format!( + "Stopped after {} tool calls without a final answer. Try rephrasing or simplifying the question.", + MAX_HOPS + ), + created_at: now_ms(), + }; + new_messages.push(msg); + Ok(new_messages) +} + +fn push_tool_call( + new_messages: &mut Vec, + working: &mut Vec, + tool: &str, + input_json: String, +) { + let call = ChatMessage::ToolCall { + id: new_id("call"), + tool: tool.to_string(), + input_json, + created_at: now_ms(), + }; + new_messages.push(call.clone()); + working.push(call); +} + +fn push_tool_result( + new_messages: &mut Vec, + working: &mut Vec, + result: ChatMessage, +) { + new_messages.push(result.clone()); + working.push(result); +} + +fn run_text_tool(outcome: TuskResult, tool: &str) -> ChatMessage { + match outcome { + Ok(text) => ChatMessage::ToolResult { + id: new_id("res"), + tool: tool.to_string(), + is_error: false, + text: Some(text), + result: None, + created_at: now_ms(), + }, + Err(e) => ChatMessage::ToolResult { + id: new_id("res"), + tool: tool.to_string(), + is_error: true, + text: Some(e.to_string()), + result: None, + created_at: now_ms(), + }, + } +} + +// --------------------------------------------------------------------------- +// tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_flat_run_query() { + let a = parse_agent_action(r#"{"action":"run_query","sql":"SELECT 1"}"#).unwrap(); + match a { + AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 1"), + _ => panic!("wrong variant"), + } + } + + #[test] + fn parses_nested_run_query() { + let a = + parse_agent_action(r#"{"action":"run_query","input":{"sql":"SELECT 2"}}"#).unwrap(); + match a { + AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 2"), + _ => panic!("wrong variant"), + } + } + + #[test] + fn parses_get_columns() { + let a = parse_agent_action( + r#"{"action":"get_columns","tables":["public.users","public.orders"]}"#, + ) + .unwrap(); + match a { + AgentAction::GetColumns { tables } => { + assert_eq!(tables, vec!["public.users", "public.orders"]); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn parses_get_columns_nested() { + let a = parse_agent_action( + r#"{"action":"get_columns","input":{"tables":["public.t"]}}"#, + ) + .unwrap(); + match a { + AgentAction::GetColumns { tables } => assert_eq!(tables, vec!["public.t"]), + _ => panic!("wrong variant"), + } + } + + #[test] + fn rejects_get_columns_empty_tables() { + assert!(parse_agent_action(r#"{"action":"get_columns","tables":[]}"#).is_err()); + } + + #[test] + fn parses_switch_database() { + let a = parse_agent_action(r#"{"action":"switch_database","database":"orders_db"}"#) + .unwrap(); + match a { + AgentAction::SwitchDatabase { database } => assert_eq!(database, "orders_db"), + _ => panic!("wrong variant"), + } + } + + #[test] + fn parses_list_tables_optional_db() { + let a1 = parse_agent_action(r#"{"action":"list_tables"}"#).unwrap(); + match a1 { + AgentAction::ListTables { database } => assert!(database.is_none()), + _ => panic!("wrong variant"), + } + let a2 = parse_agent_action(r#"{"action":"list_tables","database":"x"}"#).unwrap(); + match a2 { + AgentAction::ListTables { database } => assert_eq!(database.as_deref(), Some("x")), + _ => panic!("wrong variant"), + } + } + + #[test] + fn rejects_unknown_action() { + assert!(parse_agent_action(r#"{"action":"nuke","yes":true}"#).is_err()); + } + + #[test] + fn parses_remember_flat() { + let a = parse_agent_action( + r#"{"action":"remember","note":"trips.started_at is NULL for cancelled"}"#, + ) + .unwrap(); + match a { + AgentAction::Remember { note } => { + assert_eq!(note, "trips.started_at is NULL for cancelled"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn parses_remember_nested() { + let a = parse_agent_action( + r#"{"action":"remember","input":{"note":" surrounded by spaces "}}"#, + ) + .unwrap(); + match a { + AgentAction::Remember { note } => { + // trim happens in parser + assert_eq!(note, "surrounded by spaces"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn rejects_remember_without_note() { + assert!(parse_agent_action(r#"{"action":"remember"}"#).is_err()); + } + + #[test] + fn rejects_remember_empty_note() { + assert!(parse_agent_action(r#"{"action":"remember","note":" "}"#).is_err()); + } + + #[test] + fn parses_save_query_flat() { + let a = parse_agent_action( + r#"{"action":"save_query","name":"GMV last 30d","sql":"SELECT 1"}"#, + ) + .unwrap(); + match a { + AgentAction::SaveQuery { name, sql } => { + assert_eq!(name, "GMV last 30d"); + assert_eq!(sql, "SELECT 1"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn parses_save_query_nested() { + let a = parse_agent_action( + r#"{"action":"save_query","input":{"name":"x","sql":"SELECT 2"}}"#, + ) + .unwrap(); + match a { + AgentAction::SaveQuery { name, sql } => { + assert_eq!(name, "x"); + assert_eq!(sql, "SELECT 2"); + } + _ => panic!("wrong variant"), + } + } + + #[test] + fn rejects_save_query_missing_fields() { + assert!(parse_agent_action(r#"{"action":"save_query","name":"x"}"#).is_err()); + assert!(parse_agent_action(r#"{"action":"save_query","sql":"SELECT 1"}"#).is_err()); + assert!( + parse_agent_action(r#"{"action":"save_query","name":" ","sql":"SELECT 1"}"#).is_err() + ); + } + + #[test] + fn parses_find_queries() { + let a = parse_agent_action(r#"{"action":"find_queries","text":"gmv"}"#).unwrap(); + match a { + AgentAction::FindQueries { text } => assert_eq!(text, "gmv"), + _ => panic!("wrong variant"), + } + } + + #[test] + fn rejects_find_queries_empty_text() { + assert!(parse_agent_action(r#"{"action":"find_queries","text":""}"#).is_err()); + } + + #[test] + fn rejects_legacy_get_schema() { + assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err()); + } + + #[test] + fn truncates_long_cell() { + let long = "a".repeat(CELL_CHAR_CAP + 50); + let v = truncate_cell(&Value::String(long)); + let s = v.as_str().unwrap(); + assert!(s.ends_with('…')); + assert!(s.chars().count() <= CELL_CHAR_CAP + 1); + } + + #[test] + fn compact_drops_rows_beyond_sample() { + let mut rows = Vec::new(); + for i in 0..50 { + rows.push(vec![Value::Number(i.into())]); + } + let qr = QueryResult { + columns: vec!["id".into()], + types: vec!["INT4".into()], + rows, + row_count: 50, + execution_time_ms: 1, + }; + let v = compact_query_result(&qr); + let sample = v.get("sample_rows").unwrap().as_array().unwrap(); + assert_eq!(sample.len(), RUN_QUERY_SAMPLE_ROWS); + assert_eq!(v.get("truncated").unwrap(), &Value::Bool(true)); + assert_eq!(v.get("row_count").unwrap().as_u64().unwrap(), 50); + } +} diff --git a/src-tauri/src/commands/chat_tools.rs b/src-tauri/src/commands/chat_tools.rs new file mode 100644 index 0000000..cd18373 --- /dev/null +++ b/src-tauri/src/commands/chat_tools.rs @@ -0,0 +1,558 @@ +//! Chat agent tool handlers (chat v2). +//! +//! Each `*_tool` function returns a plain string formatted for direct injection +//! into the LLM tool-result history. They reuse the schema helpers in +//! `commands::ai` and `commands::schema` rather than re-implementing SQL. + +use crate::commands::ai::{ + fetch_column_comments, fetch_columns, fetch_enum_types, fetch_foreign_keys_raw, + fetch_table_comments, fetch_unique_constraints, format_table_block, ColumnInfo, +}; +use crate::commands::connections::{load_connection_config, switch_database_core}; +use crate::commands::saved_queries::{list_saved_queries_core, save_query_core}; +use crate::commands::schema::{list_databases_core, list_tables_core}; +use crate::error::{TuskError, TuskResult}; +use crate::models::saved_queries::SavedQuery; +use crate::state::{AppState, CachedVec, DbFlavor}; +use sqlx::{PgPool, Row}; +use std::collections::{BTreeMap, HashMap}; +use std::time::{Duration, Instant}; +use tauri::AppHandle; + +const TOOL_CACHE_TTL: Duration = Duration::from_secs(300); +const MAX_TABLES_PER_GET_COLUMNS: usize = 20; +const COLUMNS_TOOL_OUTPUT_CAP: usize = 15_000; + +// --------------------------------------------------------------------------- +// list_databases +// --------------------------------------------------------------------------- + +pub async fn list_databases_tool(state: &AppState, connection_id: &str) -> TuskResult { + let dbs = list_databases_core(state, connection_id).await?; + let active = active_db_name(state, connection_id).await; + + let mut out = format!("DATABASES ({}):", dbs.len()); + for db in &dbs { + if Some(db) == active.as_ref() { + out.push_str(&format!("\n * {} (active)", db)); + } else { + out.push_str(&format!("\n {}", db)); + } + } + Ok(out) +} + +// --------------------------------------------------------------------------- +// list_tables +// --------------------------------------------------------------------------- + +pub async fn list_tables_tool( + app: &AppHandle, + state: &AppState, + connection_id: &str, + db: Option<&str>, +) -> TuskResult { + let active = active_db_name(state, connection_id).await; + let target = db.map(|s| s.to_string()).or_else(|| active.clone()); + + let target_name = match target.as_deref() { + Some(n) => n.to_string(), + None => return Err(TuskError::Custom("No active database selected.".into())), + }; + + let same_as_active = active.as_deref() == Some(target_name.as_str()); + let flavor = state.get_flavor(connection_id).await; + + let table_names = match (flavor, same_as_active) { + (DbFlavor::ClickHouse, _) => list_tables_clickhouse(state, connection_id, &target_name).await?, + (_, true) => list_tables_active_pg(state, connection_id).await?, + (_, false) => list_tables_other_pg(app, state, connection_id, &target_name).await?, + }; + + let header = if same_as_active { + format!("TABLES IN ACTIVE DATABASE `{}` ({}):", target_name, table_names.len()) + } else { + format!("TABLES IN DATABASE `{}` ({}):", target_name, table_names.len()) + }; + let body: Vec = table_names.iter().map(|t| format!(" {}", t)).collect(); + Ok(format!("{}\n{}", header, body.join("\n"))) +} + +async fn list_tables_active_pg(state: &AppState, connection_id: &str) -> TuskResult> { + let schemas = crate::commands::schema::list_schemas_core(state, connection_id).await?; + let mut all: Vec = Vec::new(); + for schema in &schemas { + let tables = list_tables_core(state, connection_id, schema).await?; + for t in tables { + all.push(format!("{}.{}", schema, t.name)); + } + } + Ok(all) +} + +async fn list_tables_other_pg( + app: &AppHandle, + state: &AppState, + connection_id: &str, + target_db: &str, +) -> TuskResult> { + let cache_key = (connection_id.to_string(), target_db.to_string()); + if let Some(hit) = state.tables_by_db_cache.read().await.get(&cache_key).cloned() { + if hit.cached_at.elapsed() < TOOL_CACHE_TTL { + return Ok(hit.value); + } + } + + let config = load_connection_config(app, connection_id)?; + let url = config.connection_url_for_db(target_db); + let pool = PgPool::connect(&url).await.map_err(|e| { + TuskError::Custom(format!( + "Could not connect to database '{}' on this server: {}", + target_db, e + )) + })?; + let rows = sqlx::query( + "SELECT table_schema, table_name FROM information_schema.tables \ + WHERE table_schema NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \ + AND table_type = 'BASE TABLE' \ + ORDER BY table_schema, table_name", + ) + .fetch_all(&pool) + .await + .map_err(TuskError::Database)?; + pool.close().await; + + let names: Vec = rows + .iter() + .map(|r| format!("{}.{}", r.get::(0), r.get::(1))) + .collect(); + + state.tables_by_db_cache.write().await.insert( + cache_key, + CachedVec { + value: names.clone(), + cached_at: Instant::now(), + }, + ); + Ok(names) +} + +async fn list_tables_clickhouse( + state: &AppState, + connection_id: &str, + target_db: &str, +) -> TuskResult> { + let client = state.get_ch_client(connection_id).await?; + let escaped = target_db.replace('\\', "\\\\").replace('\'', "\\'"); + let sql = format!( + "SELECT name FROM system.tables WHERE database = '{}' ORDER BY name", + escaped + ); + let rows = client.fetch_objects(&sql).await?; + Ok(rows + .iter() + .filter_map(|r| r.get("name").and_then(|v| v.as_str()).map(String::from)) + .collect()) +} + +// --------------------------------------------------------------------------- +// get_columns +// --------------------------------------------------------------------------- + +pub async fn get_columns_tool( + state: &AppState, + connection_id: &str, + tables: &[String], +) -> TuskResult { + if tables.is_empty() { + return Err(TuskError::Custom("get_columns requires at least one table.".into())); + } + if tables.len() > MAX_TABLES_PER_GET_COLUMNS { + return Err(TuskError::Custom(format!( + "Too many tables ({}); split into batches of ≤{}.", + tables.len(), + MAX_TABLES_PER_GET_COLUMNS + ))); + } + + let active_db = active_db_name(state, connection_id).await.unwrap_or_default(); + + // Normalise: accept "schema.table", "db.schema.table" (drop db if == active), + // and "table" (assume schema "public" for PG, or active DB for CH). + let parsed: Vec<(String, String, String)> = tables + .iter() + .map(|raw| normalise_table_ref(raw, &active_db)) + .collect(); + + let flavor = state.get_flavor(connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return get_columns_clickhouse(state, connection_id, &parsed).await; + } + get_columns_postgres(state, connection_id, &parsed).await +} + +fn normalise_table_ref(raw: &str, active_db: &str) -> (String, String, String) { + // Returns (schema, table, original_input_for_diagnostics) + let trimmed = raw.trim().trim_matches('"').trim_matches('`'); + let parts: Vec<&str> = trimmed.split('.').collect(); + match parts.len() { + 1 => ("public".to_string(), parts[0].to_string(), raw.to_string()), + 2 => (parts[0].to_string(), parts[1].to_string(), raw.to_string()), + 3 => { + // "db.schema.table" — drop db prefix when it matches active + let (db, schema, table) = (parts[0], parts[1], parts[2]); + if db == active_db { + (schema.to_string(), table.to_string(), raw.to_string()) + } else { + // Different DB requested — let the caller surface a not-found warning. + // We still parse it as schema.table here. + (schema.to_string(), table.to_string(), raw.to_string()) + } + } + _ => ("public".to_string(), trimmed.to_string(), raw.to_string()), + } +} + +async fn get_columns_postgres( + state: &AppState, + connection_id: &str, + requested: &[(String, String, String)], +) -> TuskResult { + let pool = state.get_pool(connection_id).await?; + + let (col_res, fk_res, enum_res, tbl_comm_res, col_comm_res, unique_res) = tokio::join!( + fetch_columns(&pool), + fetch_foreign_keys_raw(&pool), + fetch_enum_types(&pool), + fetch_table_comments(&pool), + fetch_column_comments(&pool), + fetch_unique_constraints(&pool), + ); + let all_cols = col_res?; + let fk_rows = fk_res?; + let enum_map = enum_res.unwrap_or_default(); + let tbl_comments = tbl_comm_res.unwrap_or_default(); + let col_comments = col_comm_res.unwrap_or_default(); + let uniques = unique_res.unwrap_or_default(); + + // Build (schema, table) → Vec + let mut by_table: BTreeMap<(String, String), Vec> = BTreeMap::new(); + for ci in &all_cols { + by_table + .entry((ci.schema.clone(), ci.table.clone())) + .or_default() + .push(ci.clone()); + } + + let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new(); + for fk in &fk_rows { + if fk.columns.len() == 1 && fk.ref_columns.len() == 1 { + fk_inline.insert( + (fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()), + format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]), + ); + } + } + + let mut unique_map: HashMap<(String, String), Vec> = HashMap::new(); + for (schema, table, cols) in &uniques { + unique_map + .entry((schema.clone(), table.clone())) + .or_default() + .push(cols.join(", ")); + } + + let varchar_values: HashMap<(String, String, String), Vec> = HashMap::new(); + let jsonb_keys: HashMap<(String, String, String), Vec> = HashMap::new(); + + let mut output: Vec = Vec::new(); + let mut not_found: Vec = Vec::new(); + + for (schema, table, raw) in requested { + match by_table.get(&(schema.clone(), table.clone())) { + Some(cols) => { + let full_name = format!("{}.{}", schema, table); + format_table_block( + &full_name, + cols, + &tbl_comments, + &col_comments, + &fk_inline, + &enum_map, + &unique_map, + &varchar_values, + &jsonb_keys, + &mut output, + ); + } + None => not_found.push(raw.clone()), + } + } + + if !not_found.is_empty() { + let nearest = nearest_table_matches(&by_table, ¬_found); + let header = format!( + "WARNING: tables not found: {}.{}", + not_found.join(", "), + if nearest.is_empty() { + String::new() + } else { + format!(" Nearest matches: {}.", nearest.join(", ")) + } + ); + output.insert(0, header); + output.insert(1, String::new()); + } + + let mut text = output.join("\n"); + if text.len() > COLUMNS_TOOL_OUTPUT_CAP { + text.truncate(COLUMNS_TOOL_OUTPUT_CAP); + text.push_str("\n... (output truncated)"); + } + Ok(text) +} + +async fn get_columns_clickhouse( + state: &AppState, + connection_id: &str, + requested: &[(String, String, String)], +) -> TuskResult { + let client = state.get_ch_client(connection_id).await?; + let active_db = client.database.clone(); + + let where_terms: Vec = requested + .iter() + .map(|(schema, table, _)| { + // For CH, treat the parsed "schema" as the database name; if it equals + // a PG-conventional default ("public"), substitute with active CH database. + let dbn = if schema == "public" { active_db.clone() } else { schema.clone() }; + format!( + "(database = '{}' AND name = '{}')", + dbn.replace('\'', "\\'"), + table.replace('\'', "\\'") + ) + }) + .collect(); + let where_clause = where_terms.join(" OR "); + + let sql = format!( + "SELECT database, table, name, type, default_expression, is_in_primary_key, comment, position \ + FROM system.columns WHERE {} ORDER BY database, table, position", + where_clause + ); + let rows = client.fetch_objects(&sql).await?; + + // Group by (database, table) + let mut grouped: BTreeMap<(String, String), Vec<&serde_json::Map>> = + BTreeMap::new(); + for row in &rows { + let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let tbl = row.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string(); + grouped.entry((dbn, tbl)).or_default().push(row); + } + + // Track which requested tables were found + let mut output = String::new(); + let mut not_found: Vec = Vec::new(); + for (schema, table, raw) in requested { + let dbn = if schema == "public" { active_db.clone() } else { schema.clone() }; + match grouped.get(&(dbn.clone(), table.clone())) { + Some(cols) => { + output.push_str(&format!("\nTABLE {}.{}\n", dbn, table)); + for col in cols { + let name = col.get("name").and_then(|v| v.as_str()).unwrap_or(""); + let dtype = col.get("type").and_then(|v| v.as_str()).unwrap_or(""); + let is_pk = matches!( + col.get("is_in_primary_key"), + Some(serde_json::Value::Number(n)) if n.as_i64() == Some(1) + ) || matches!( + col.get("is_in_primary_key"), + Some(serde_json::Value::String(s)) if s == "1" + ); + let default = col.get("default_expression").and_then(|v| v.as_str()).unwrap_or(""); + let comment = col.get("comment").and_then(|v| v.as_str()).unwrap_or(""); + let mut line = format!(" {} {}", name, dtype); + if is_pk { + line.push_str(" [PK]"); + } + if !default.is_empty() { + line.push_str(&format!(" DEFAULT {}", default)); + } + if !comment.is_empty() { + line.push_str(&format!(" -- {}", comment)); + } + output.push_str(&line); + output.push('\n'); + } + } + None => not_found.push(raw.clone()), + } + } + + let mut header = String::new(); + if !not_found.is_empty() { + header.push_str(&format!( + "WARNING: tables not found: {}\n\n", + not_found.join(", ") + )); + } + let mut combined = format!("{}{}", header, output.trim_start()); + if combined.len() > COLUMNS_TOOL_OUTPUT_CAP { + combined.truncate(COLUMNS_TOOL_OUTPUT_CAP); + combined.push_str("\n... (output truncated)"); + } + Ok(combined) +} + +fn nearest_table_matches( + by_table: &BTreeMap<(String, String), Vec>, + missing: &[String], +) -> Vec { + let all: Vec = by_table + .keys() + .map(|(s, t)| format!("{}.{}", s, t)) + .collect(); + let mut hints: Vec = Vec::new(); + for m in missing { + let needle = m.to_lowercase(); + let mut candidates: Vec<&String> = all + .iter() + .filter(|n| { + let lower = n.to_lowercase(); + lower.contains(&needle) || needle.contains(lower.split('.').last().unwrap_or("")) + }) + .take(3) + .collect(); + candidates.dedup(); + for c in candidates { + if !hints.contains(c) { + hints.push(c.clone()); + } + } + } + hints +} + +// --------------------------------------------------------------------------- +// switch_database +// --------------------------------------------------------------------------- + +pub async fn switch_database_tool( + app: &AppHandle, + state: &AppState, + connection_id: &str, + target_db: &str, +) -> TuskResult { + let config = load_connection_config(app, connection_id)?; + + // Verify target exists in cluster + let dbs = list_databases_core(state, connection_id).await?; + if !dbs.iter().any(|d| d == target_db) { + return Err(TuskError::Custom(format!( + "Database '{}' does not exist on this server. Available: {}", + target_db, + dbs.join(", ") + ))); + } + + switch_database_core(state, &config, target_db).await?; + Ok(format!("Switched active database to '{}'.", target_db)) +} + +// --------------------------------------------------------------------------- +// helpers +// --------------------------------------------------------------------------- + +async fn active_db_name(state: &AppState, connection_id: &str) -> Option { + let flavor = state.get_flavor(connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return state + .get_ch_client(connection_id) + .await + .ok() + .map(|c| c.database.clone()); + } + let pool = state.get_pool(connection_id).await.ok()?; + sqlx::query_scalar::<_, String>("SELECT current_database()") + .fetch_one(&pool) + .await + .ok() +} + +// --------------------------------------------------------------------------- +// save_query / find_queries (chat v3 — F2) +// --------------------------------------------------------------------------- + +const FIND_QUERIES_LIMIT: usize = 10; +const FIND_QUERIES_SQL_PREVIEW_CHARS: usize = 500; + +pub async fn save_query_tool( + app: &AppHandle, + connection_id: &str, + name: &str, + sql: &str, +) -> TuskResult { + let trimmed_name = name.trim(); + let trimmed_sql = sql.trim(); + if trimmed_name.is_empty() { + return Err(TuskError::Custom("save_query: name must not be empty".into())); + } + if trimmed_sql.is_empty() { + return Err(TuskError::Custom("save_query: sql must not be empty".into())); + } + + let entry = SavedQuery { + id: uuid::Uuid::new_v4().to_string(), + name: trimmed_name.to_string(), + sql: trimmed_sql.to_string(), + connection_id: Some(connection_id.to_string()), + created_at: chrono::Utc::now().to_rfc3339(), + }; + save_query_core(app, entry).await?; + Ok(format!("Saved query \"{}\" — visible in sidebar → Saved.", trimmed_name)) +} + +pub async fn find_queries_tool( + app: &AppHandle, + connection_id: &str, + text: &str, +) -> TuskResult { + let trimmed = text.trim(); + if trimmed.is_empty() { + return Err(TuskError::Custom("find_queries: text must not be empty".into())); + } + + let all = list_saved_queries_core(app, Some(trimmed)).await?; + let matches: Vec = all + .into_iter() + .filter(|q| q.connection_id.as_deref() == Some(connection_id)) + .take(FIND_QUERIES_LIMIT) + .collect(); + + if matches.is_empty() { + return Ok(format!( + "No saved queries match \"{}\" for this connection.", + trimmed + )); + } + + let mut out = format!( + "Saved queries matching \"{}\" ({}):", + trimmed, + matches.len() + ); + for q in &matches { + let sql_preview: String = if q.sql.chars().count() > FIND_QUERIES_SQL_PREVIEW_CHARS { + let truncated: String = q.sql.chars().take(FIND_QUERIES_SQL_PREVIEW_CHARS).collect(); + format!("{}…", truncated) + } else { + q.sql.clone() + }; + out.push_str(&format!( + "\n\n[{}] {}\n{}", + q.created_at, q.name, sql_preview + )); + } + Ok(out) +} + diff --git a/src-tauri/src/commands/connections.rs b/src-tauri/src/commands/connections.rs index 888b5b6..f88664d 100644 --- a/src-tauri/src/commands/connections.rs +++ b/src-tauri/src/commands/connections.rs @@ -1,3 +1,4 @@ +use crate::db::clickhouse::ChClient; use crate::error::{TuskError, TuskResult}; use crate::models::connection::ConnectionConfig; use crate::state::{AppState, DbFlavor}; @@ -23,6 +24,34 @@ pub(crate) fn get_connections_path(app: &AppHandle) -> TuskResult TuskResult> { + let path = get_connections_path(app)?; + if !path.exists() { + return Ok(vec![]); + } + let data = fs::read_to_string(&path)?; + let connections: Vec = serde_json::from_str(&data)?; + Ok(connections) +} + +/// Look up a single saved connection by id. Used by tools that need credentials +/// (e.g. switch_database from inside the chat agent loop) but only have the id in scope. +pub(crate) fn load_connection_config( + app: &AppHandle, + connection_id: &str, +) -> TuskResult { + load_all_connections(app)? + .into_iter() + .find(|c| c.id == connection_id) + .ok_or_else(|| { + TuskError::Custom(format!( + "Connection '{}' not found in connections.json", + connection_id + )) + }) +} + #[tauri::command] pub async fn get_connections(app: AppHandle) -> TuskResult> { let path = get_connections_path(&app)?; @@ -55,6 +84,24 @@ pub async fn save_connection(app: AppHandle, config: ConnectionConfig) -> TuskRe Ok(()) } +async fn close_connection(state: &AppState, id: &str) { + let mut pools = state.pools.write().await; + if let Some(pool) = pools.remove(id) { + pool.close().await; + } + drop(pools); + let mut clients = state.ch_clients.write().await; + clients.remove(id); + drop(clients); + let mut ro = state.read_only.write().await; + ro.remove(id); + drop(ro); + let mut flavors = state.db_flavors.write().await; + flavors.remove(id); + drop(flavors); + state.invalidate_chat_caches_for(id).await; +} + #[tauri::command] pub async fn delete_connection( app: AppHandle, @@ -69,36 +116,37 @@ pub async fn delete_connection( let data = serde_json::to_string_pretty(&connections)?; fs::write(&path, data)?; } - - // Close pool if connected - let mut pools = state.pools.write().await; - if let Some(pool) = pools.remove(&id) { - pool.close().await; - } - - let mut ro = state.read_only.write().await; - ro.remove(&id); - - let mut flavors = state.db_flavors.write().await; - flavors.remove(&id); - + close_connection(&state, &id).await; Ok(()) } #[tauri::command] pub async fn test_connection(config: ConnectionConfig) -> TuskResult { - let pool = PgPool::connect(&config.connection_url()) - .await - .map_err(TuskError::Database)?; - - let row = sqlx::query("SELECT version()") - .fetch_one(&pool) - .await - .map_err(TuskError::Database)?; - - let version: String = row.get(0); - pool.close().await; - Ok(version) + match config.db_flavor { + DbFlavor::ClickHouse => { + let client = ChClient::new( + &config.host, + config.port, + config.secure, + &config.user, + &config.password, + &config.database, + ); + client.ping().await + } + _ => { + let pool = PgPool::connect(&config.connection_url()) + .await + .map_err(TuskError::Database)?; + let row = sqlx::query("SELECT version()") + .fetch_one(&pool) + .await + .map_err(TuskError::Database)?; + let version: String = row.get(0); + pool.close().await; + Ok(version) + } + } } #[tauri::command] @@ -106,39 +154,110 @@ pub async fn connect( state: State<'_, Arc>, config: ConnectionConfig, ) -> TuskResult { - let pool = PgPool::connect(&config.connection_url()) - .await - .map_err(TuskError::Database)?; + match config.db_flavor { + DbFlavor::ClickHouse => { + let client = ChClient::new( + &config.host, + config.port, + config.secure, + &config.user, + &config.password, + &config.database, + ); + let version = client.ping().await?; + let arc = Arc::new(client); + state.ch_clients.write().await.insert(config.id.clone(), arc); + state.read_only.write().await.insert(config.id.clone(), true); + state + .db_flavors + .write() + .await + .insert(config.id.clone(), DbFlavor::ClickHouse); + Ok(ConnectResult { + version, + flavor: DbFlavor::ClickHouse, + }) + } + _ => { + let pool = PgPool::connect(&config.connection_url()) + .await + .map_err(TuskError::Database)?; + sqlx::query("SELECT 1") + .execute(&pool) + .await + .map_err(TuskError::Database)?; + let row = sqlx::query("SELECT version()") + .fetch_one(&pool) + .await + .map_err(TuskError::Database)?; + let version: String = row.get(0); + let flavor = if version.to_lowercase().contains("greenplum") { + DbFlavor::Greenplum + } else { + DbFlavor::PostgreSQL + }; + state.pools.write().await.insert(config.id.clone(), pool); + state.read_only.write().await.insert(config.id.clone(), true); + state + .db_flavors + .write() + .await + .insert(config.id.clone(), flavor); + Ok(ConnectResult { version, flavor }) + } + } +} - // Verify connection - sqlx::query("SELECT 1") - .execute(&pool) - .await - .map_err(TuskError::Database)?; +/// Core implementation of switching the active database for a connection. +/// Reusable from both the Tauri command (frontend-driven) and the chat agent +/// loop (model-driven via the switch_database tool). +pub(crate) async fn switch_database_core( + state: &AppState, + config: &ConnectionConfig, + database: &str, +) -> TuskResult<()> { + let mut switched = config.clone(); + switched.database = database.to_string(); - // Detect database flavor via version() - let row = sqlx::query("SELECT version()") - .fetch_one(&pool) - .await - .map_err(TuskError::Database)?; - let version: String = row.get(0); - - let flavor = if version.to_lowercase().contains("greenplum") { - DbFlavor::Greenplum - } else { - DbFlavor::PostgreSQL + let result: TuskResult<()> = match config.db_flavor { + DbFlavor::ClickHouse => { + let client = ChClient::new( + &switched.host, + switched.port, + switched.secure, + &switched.user, + &switched.password, + &switched.database, + ); + client.ping().await?; + state + .ch_clients + .write() + .await + .insert(config.id.clone(), Arc::new(client)); + Ok(()) + } + _ => { + let pool = PgPool::connect(&switched.connection_url()) + .await + .map_err(TuskError::Database)?; + sqlx::query("SELECT 1") + .execute(&pool) + .await + .map_err(TuskError::Database)?; + let mut pools = state.pools.write().await; + if let Some(old_pool) = pools.remove(&config.id) { + old_pool.close().await; + } + pools.insert(config.id.clone(), pool); + Ok(()) + } }; - let mut pools = state.pools.write().await; - pools.insert(config.id.clone(), pool); + // Drop every cache that's bound to this connection's previous database. + state.invalidate_chat_caches_for(&config.id).await; - let mut ro = state.read_only.write().await; - ro.insert(config.id.clone(), true); - - let mut flavors = state.db_flavors.write().await; - flavors.insert(config.id.clone(), flavor); - - Ok(ConnectResult { version, flavor }) + result } #[tauri::command] @@ -147,40 +266,12 @@ pub async fn switch_database( config: ConnectionConfig, database: String, ) -> TuskResult<()> { - let mut switched = config.clone(); - switched.database = database; - - let pool = PgPool::connect(&switched.connection_url()) - .await - .map_err(TuskError::Database)?; - - sqlx::query("SELECT 1") - .execute(&pool) - .await - .map_err(TuskError::Database)?; - - let mut pools = state.pools.write().await; - if let Some(old_pool) = pools.remove(&config.id) { - old_pool.close().await; - } - pools.insert(config.id.clone(), pool); - - Ok(()) + switch_database_core(&state, &config, &database).await } #[tauri::command] pub async fn disconnect(state: State<'_, Arc>, id: String) -> TuskResult<()> { - let mut pools = state.pools.write().await; - if let Some(pool) = pools.remove(&id) { - pool.close().await; - } - - let mut ro = state.read_only.write().await; - ro.remove(&id); - - let mut flavors = state.db_flavors.write().await; - flavors.remove(&id); - + close_connection(&state, &id).await; Ok(()) } diff --git a/src-tauri/src/commands/data.rs b/src-tauri/src/commands/data.rs index cbadf28..e718530 100644 --- a/src-tauri/src/commands/data.rs +++ b/src-tauri/src/commands/data.rs @@ -1,7 +1,7 @@ use crate::commands::queries::pg_value_to_json; use crate::error::{TuskError, TuskResult}; use crate::models::query_result::PaginatedQueryResult; -use crate::state::AppState; +use crate::state::{AppState, DbFlavor}; use crate::utils::escape_ident; use serde_json::Value; use sqlx::{Column, Row, TypeInfo}; @@ -9,6 +9,80 @@ use std::sync::Arc; use std::time::Instant; use tauri::State; +async fn ch_get_table_data( + state: &AppState, + connection_id: &str, + schema: &str, + table: &str, + page: u32, + page_size: u32, + sort_column: Option<&str>, + sort_direction: Option<&str>, + filter: Option<&str>, +) -> TuskResult { + let client = state.get_ch_client(connection_id).await?; + let qualified = format!( + "{}.{}", + ch_quote_ident(schema), + ch_quote_ident(table) + ); + + let mut where_clause = String::new(); + if let Some(f) = filter { + if !f.trim().is_empty() { + crate::db::sql_guard::ensure_readonly_sql(&format!("SELECT 1 FROM x WHERE {}", f))?; + where_clause = format!(" WHERE {}", f); + } + } + + let mut order_clause = String::new(); + if let Some(col) = sort_column { + if !col.trim().is_empty() { + let dir = match sort_direction { + Some("DESC") | Some("desc") => "DESC", + _ => "ASC", + }; + order_clause = format!(" ORDER BY {} {}", ch_quote_ident(col), dir); + } + } + + let offset = (page.saturating_sub(1)) as i64 * page_size as i64; + let data_sql = format!( + "SELECT * FROM {}{}{} LIMIT {} OFFSET {}", + qualified, where_clause, order_clause, page_size, offset + ); + let count_sql = format!("SELECT count() AS c FROM {}{}", qualified, where_clause); + + let result = client.execute_query(&data_sql, true).await?; + let count_rows = client.fetch_objects(&count_sql).await?; + let total_rows = count_rows + .first() + .and_then(|o| o.get("c")) + .and_then(|v| match v { + Value::Number(n) => n.as_i64(), + Value::String(s) => s.parse::().ok(), + _ => None, + }) + .unwrap_or(0); + + Ok(PaginatedQueryResult { + columns: result.columns, + types: result.types, + rows: result.rows, + row_count: result.row_count, + execution_time_ms: result.execution_time_ms, + total_rows, + page, + page_size, + ctids: vec![], + }) +} + +fn ch_quote_ident(s: &str) -> String { + let escaped = s.replace('`', "``"); + format!("`{}`", escaped) +} + #[tauri::command] #[allow(clippy::too_many_arguments)] pub async fn get_table_data( @@ -22,6 +96,21 @@ pub async fn get_table_data( sort_direction: Option, filter: Option, ) -> TuskResult { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return ch_get_table_data( + &state, + &connection_id, + &schema, + &table, + page, + page_size, + sort_column.as_deref(), + sort_direction.as_deref(), + filter.as_deref(), + ) + .await; + } let pool = state.get_pool(&connection_id).await?; let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); @@ -74,7 +163,7 @@ pub async fn get_table_data( tx.rollback().await.map_err(TuskError::Database)?; - let execution_time_ms = start.elapsed().as_millis(); + let execution_time_ms = start.elapsed().as_millis() as u64; let total_rows: i64 = count_row.get(0); let mut all_columns = Vec::new(); @@ -146,6 +235,11 @@ pub async fn update_row( if state.is_read_only(&connection_id).await { return Err(TuskError::ReadOnly); } + if matches!(state.get_flavor(&connection_id).await, DbFlavor::ClickHouse) { + return Err(TuskError::Custom( + "Inline row edit is not supported for ClickHouse — use SQL ALTER … UPDATE.".into(), + )); + } let pool = state.get_pool(&connection_id).await?; @@ -202,6 +296,11 @@ pub async fn insert_row( if state.is_read_only(&connection_id).await { return Err(TuskError::ReadOnly); } + if matches!(state.get_flavor(&connection_id).await, DbFlavor::ClickHouse) { + return Err(TuskError::Custom( + "Inline row insert is not supported for ClickHouse — use SQL INSERT.".into(), + )); + } let pool = state.get_pool(&connection_id).await?; @@ -240,6 +339,11 @@ pub async fn delete_rows( if state.is_read_only(&connection_id).await { return Err(TuskError::ReadOnly); } + if matches!(state.get_flavor(&connection_id).await, DbFlavor::ClickHouse) { + return Err(TuskError::Custom( + "Inline row delete is not supported for ClickHouse — use SQL ALTER … DELETE.".into(), + )); + } let pool = state.get_pool(&connection_id).await?; diff --git a/src-tauri/src/commands/docker.rs b/src-tauri/src/commands/docker.rs deleted file mode 100644 index d587b25..0000000 --- a/src-tauri/src/commands/docker.rs +++ /dev/null @@ -1,1218 +0,0 @@ -use crate::error::{TuskError, TuskResult}; -use crate::models::connection::ConnectionConfig; -use crate::models::docker::{ - CloneMode, CloneProgress, CloneResult, CloneToDockerParams, DockerStatus, TuskContainer, -}; -use crate::state::AppState; -use crate::utils::escape_ident; -use std::fs; -use std::sync::Arc; -use tauri::{AppHandle, Emitter, State}; -use tokio::process::Command; - -async fn docker_cmd(state: &AppState) -> Command { - let host = state.docker_host.read().await; - let mut cmd = Command::new("docker"); - if let Some(ref h) = *host { - cmd.args(["-H", h]); - } - cmd -} - -fn docker_err(msg: impl Into) -> TuskError { - TuskError::Docker(msg.into()) -} - -fn emit_progress( - app: &AppHandle, - clone_id: &str, - stage: &str, - percent: u8, - message: &str, - detail: Option<&str>, -) { - let _ = app.emit( - "clone-progress", - CloneProgress { - clone_id: clone_id.to_string(), - stage: stage.to_string(), - percent, - message: message.to_string(), - detail: detail.map(|s| s.to_string()), - }, - ); -} - -fn load_connection_config(app: &AppHandle, connection_id: &str) -> TuskResult { - let path = super::connections::get_connections_path(app)?; - if !path.exists() { - return Err(TuskError::ConnectionNotFound(connection_id.to_string())); - } - let data = fs::read_to_string(&path)?; - let connections: Vec = serde_json::from_str(&data)?; - connections - .into_iter() - .find(|c| c.id == connection_id) - .ok_or_else(|| TuskError::ConnectionNotFound(connection_id.to_string())) -} - -/// Shell-escape a string for use in single quotes -fn shell_escape(s: &str) -> String { - s.replace('\'', "'\\''") -} - -/// Validate pg_version matches a safe pattern (e.g. "16", "16.2", "17.1") -fn validate_pg_version(version: &str) -> TuskResult<()> { - let is_valid = !version.is_empty() && version.chars().all(|c| c.is_ascii_digit() || c == '.'); - if !is_valid { - return Err(docker_err(format!( - "Invalid pg_version '{}': must contain only digits and dots (e.g. '16', '16.2')", - version - ))); - } - Ok(()) -} - -/// Validate container name matches Docker naming rules: [a-zA-Z0-9][a-zA-Z0-9_.-]* -fn validate_container_name(name: &str) -> TuskResult<()> { - if name.is_empty() { - return Err(docker_err("Container name cannot be empty")); - } - let first = name.chars().next().unwrap(); - if !first.is_ascii_alphanumeric() { - return Err(docker_err(format!( - "Invalid container name '{}': must start with an alphanumeric character", - name - ))); - } - let is_valid = name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == '-'); - if !is_valid { - return Err(docker_err(format!( - "Invalid container name '{}': only [a-zA-Z0-9_.-] characters are allowed", - name - ))); - } - Ok(()) -} - -/// Shell-escape a string for use inside double-quoted shell contexts -fn shell_escape_double(s: &str) -> String { - s.replace('\\', "\\\\") - .replace('"', "\\\"") - .replace('$', "\\$") - .replace('`', "\\`") - .replace('!', "\\!") -} - -#[tauri::command] -pub async fn check_docker(state: State<'_, Arc>) -> TuskResult { - let docker_host = state.docker_host.read().await.clone(); - check_docker_internal(&docker_host).await -} - -#[tauri::command] -pub async fn list_tusk_containers( - state: State<'_, Arc>, -) -> TuskResult> { - let output = docker_cmd(&state) - .await - .args([ - "ps", - "-a", - "--filter", - "label=tusk.managed=true", - "--format", - "{{.ID}}\t{{.Names}}\t{{.Status}}\t{{.Label \"tusk.pg-version\"}}\t{{.Label \"tusk.source-db\"}}\t{{.Label \"tusk.source-connection\"}}\t{{.CreatedAt}}\t{{.Ports}}", - ]) - .output() - .await - .map_err(|e| docker_err(format!("Failed to run docker ps: {}", e)))?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - return Err(docker_err(format!("docker ps failed: {}", stderr))); - } - - let stdout = String::from_utf8_lossy(&output.stdout); - let mut containers = Vec::new(); - - for line in stdout.lines() { - if line.trim().is_empty() { - continue; - } - let parts: Vec<&str> = line.split('\t').collect(); - if parts.len() < 8 { - continue; - } - - let host_port = parse_host_port(parts[7]); - - containers.push(TuskContainer { - container_id: parts[0].to_string(), - name: parts[1].to_string(), - status: parts[2].to_string(), - host_port, - pg_version: parts[3].to_string(), - source_database: if parts[4].is_empty() { - None - } else { - Some(parts[4].to_string()) - }, - source_connection: if parts[5].is_empty() { - None - } else { - Some(parts[5].to_string()) - }, - created_at: if parts[6].is_empty() { - None - } else { - Some(parts[6].to_string()) - }, - }); - } - - Ok(containers) -} - -fn parse_host_port(ports_str: &str) -> u16 { - for part in ports_str.split(',') { - let part = part.trim(); - if let Some(arrow_pos) = part.find("->") { - let before = &part[..arrow_pos]; - if let Some(colon_pos) = before.rfind(':') { - if let Ok(port) = before[colon_pos + 1..].parse::() { - return port; - } - } - } - } - 0 -} - -#[tauri::command] -pub async fn clone_to_docker( - app: AppHandle, - state: State<'_, Arc>, - params: CloneToDockerParams, - clone_id: String, -) -> TuskResult { - let state = state.inner().clone(); - let app_clone = app.clone(); - - tokio::spawn(async move { do_clone(&app_clone, &state, ¶ms, &clone_id).await }) - .await - .map_err(|e| docker_err(format!("Clone task panicked: {}", e)))? -} - -/// Build a docker Command respecting the remote host setting -fn docker_cmd_sync(docker_host: &Option) -> Command { - let mut cmd = Command::new("docker"); - if let Some(ref h) = docker_host { - cmd.args(["-H", h]); - } - cmd -} - -async fn check_docker_internal(docker_host: &Option) -> TuskResult { - let output = docker_cmd_sync(docker_host) - .args(["version", "--format", "{{.Server.Version}}"]) - .output() - .await; - - match output { - Ok(out) => { - if out.status.success() { - let version = String::from_utf8_lossy(&out.stdout).trim().to_string(); - Ok(DockerStatus { - installed: true, - daemon_running: true, - version: Some(version), - error: None, - }) - } else { - let stderr = String::from_utf8_lossy(&out.stderr).trim().to_string(); - let daemon_running = - !stderr.contains("Cannot connect") && !stderr.contains("connection refused"); - Ok(DockerStatus { - installed: true, - daemon_running, - version: None, - error: Some(stderr), - }) - } - } - Err(_) => Ok(DockerStatus { - installed: false, - daemon_running: false, - version: None, - error: Some("Docker CLI not found. Please install Docker.".to_string()), - }), - } -} - -async fn do_clone( - app: &AppHandle, - state: &Arc, - params: &CloneToDockerParams, - clone_id: &str, -) -> TuskResult { - // Validate user inputs before any operations - validate_pg_version(¶ms.pg_version)?; - validate_container_name(¶ms.container_name)?; - - let docker_host = state.docker_host.read().await.clone(); - - // Step 1: Check Docker - emit_progress( - app, - clone_id, - "checking", - 5, - "Checking Docker availability...", - None, - ); - let status = check_docker_internal(&docker_host).await?; - if !status.installed || !status.daemon_running { - let msg = status - .error - .unwrap_or_else(|| "Docker is not available".to_string()); - emit_progress(app, clone_id, "error", 5, &msg, None); - return Err(docker_err(msg)); - } - - // Step 2: Find available port - emit_progress(app, clone_id, "port", 10, "Finding available port...", None); - let host_port = match params.host_port { - Some(p) => p, - None => find_free_port().await?, - }; - emit_progress( - app, - clone_id, - "port", - 10, - &format!("Using port {}", host_port), - None, - ); - - // Step 3: Create container - emit_progress( - app, - clone_id, - "container", - 20, - "Creating PostgreSQL container...", - None, - ); - let pg_password = params.postgres_password.as_deref().unwrap_or("tusk"); - let image = format!("postgres:{}", params.pg_version); - - let create_output = docker_cmd_sync(&docker_host) - .args([ - "run", - "-d", - "--name", - ¶ms.container_name, - "-p", - &format!("{}:5432", host_port), - "-e", - &format!("POSTGRES_PASSWORD={}", pg_password), - "-l", - "tusk.managed=true", - "-l", - &format!("tusk.source-db={}", params.source_database), - "-l", - &format!("tusk.source-connection={}", params.source_connection_id), - "-l", - &format!("tusk.pg-version={}", params.pg_version), - &image, - ]) - .output() - .await - .map_err(|e| docker_err(format!("Failed to create container: {}", e)))?; - - if !create_output.status.success() { - let stderr = String::from_utf8_lossy(&create_output.stderr) - .trim() - .to_string(); - emit_progress( - app, - clone_id, - "error", - 20, - &format!("Failed to create container: {}", stderr), - None, - ); - return Err(docker_err(format!( - "Failed to create container: {}", - stderr - ))); - } - - let container_id = String::from_utf8_lossy(&create_output.stdout) - .trim() - .to_string(); - - // Step 4: Wait for PostgreSQL to be ready - emit_progress( - app, - clone_id, - "waiting", - 30, - "Waiting for PostgreSQL to be ready...", - None, - ); - wait_for_pg_ready(&docker_host, ¶ms.container_name, 30).await?; - emit_progress(app, clone_id, "waiting", 35, "PostgreSQL is ready", None); - - // Step 5: Create target database - emit_progress( - app, - clone_id, - "database", - 35, - &format!("Creating database '{}'...", params.source_database), - None, - ); - let create_db_output = docker_cmd_sync(&docker_host) - .args([ - "exec", - ¶ms.container_name, - "psql", - "-U", - "postgres", - "-c", - &format!("CREATE DATABASE {}", escape_ident(¶ms.source_database)), - ]) - .output() - .await - .map_err(|e| docker_err(format!("Failed to create database: {}", e)))?; - - if !create_db_output.status.success() { - let stderr = String::from_utf8_lossy(&create_db_output.stderr) - .trim() - .to_string(); - if !stderr.contains("already exists") { - emit_progress( - app, - clone_id, - "error", - 35, - &format!("Failed to create database: {}", stderr), - None, - ); - return Err(docker_err(format!("Failed to create database: {}", stderr))); - } - } - - // Step 6: Get source connection URL (using the specific database to clone) - emit_progress( - app, - clone_id, - "dump", - 40, - "Preparing data transfer...", - None, - ); - let source_config = load_connection_config(app, ¶ms.source_connection_id)?; - let source_url = source_config.connection_url_for_db(¶ms.source_database); - emit_progress( - app, - clone_id, - "dump", - 40, - &format!( - "Source: {}@{}:{}/{}", - source_config.user, source_config.host, source_config.port, params.source_database - ), - None, - ); - - // Step 7: Transfer data based on clone mode - match params.clone_mode { - CloneMode::SchemaOnly => { - emit_progress(app, clone_id, "transfer", 45, "Dumping schema...", None); - transfer_schema_only( - app, - clone_id, - &source_url, - ¶ms.container_name, - ¶ms.source_database, - ¶ms.pg_version, - &docker_host, - ) - .await?; - } - CloneMode::FullClone => { - emit_progress( - app, - clone_id, - "transfer", - 45, - "Performing full database clone...", - None, - ); - transfer_full_clone( - app, - clone_id, - &source_url, - ¶ms.container_name, - ¶ms.source_database, - ¶ms.pg_version, - &docker_host, - ) - .await?; - } - CloneMode::SampleData => { - let has_local = try_local_pg_dump().await; - emit_progress(app, clone_id, "transfer", 45, "Dumping schema...", None); - transfer_schema_only_with( - app, - clone_id, - &source_url, - ¶ms.container_name, - ¶ms.source_database, - ¶ms.pg_version, - &docker_host, - has_local, - ) - .await?; - emit_progress( - app, - clone_id, - "transfer", - 65, - "Copying sample data...", - None, - ); - let sample_rows = params.sample_rows.unwrap_or(1000); - transfer_sample_data_with( - app, - clone_id, - &source_url, - ¶ms.container_name, - ¶ms.source_database, - ¶ms.pg_version, - sample_rows, - &docker_host, - has_local, - ) - .await?; - } - } - - // Step 8: Save connection in Tusk - emit_progress( - app, - clone_id, - "connection", - 90, - "Saving connection...", - None, - ); - let connection_id = uuid::Uuid::new_v4().to_string(); - let new_config = ConnectionConfig { - id: connection_id.clone(), - name: format!("{} (Docker clone)", params.source_database), - host: "localhost".to_string(), - port: host_port, - user: "postgres".to_string(), - password: pg_password.to_string(), - database: params.source_database.clone(), - ssl_mode: Some("disable".to_string()), - color: Some("#06b6d4".to_string()), - environment: Some("local".to_string()), - }; - - save_connection_config(app, &new_config)?; - - let connection_url = format!( - "postgres://postgres:{}@localhost:{}/{}", - pg_password, host_port, params.source_database - ); - - let container = TuskContainer { - container_id: container_id[..12.min(container_id.len())].to_string(), - name: params.container_name.clone(), - status: "Up".to_string(), - host_port, - pg_version: params.pg_version.clone(), - source_database: Some(params.source_database.clone()), - source_connection: Some(params.source_connection_id.clone()), - created_at: Some(chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()), - }; - - let result = CloneResult { - container, - connection_id, - connection_url, - }; - - emit_progress( - app, - clone_id, - "done", - 100, - "Clone completed successfully!", - None, - ); - - Ok(result) -} - -async fn find_free_port() -> TuskResult { - let listener = tokio::net::TcpListener::bind("127.0.0.1:0") - .await - .map_err(|e| docker_err(format!("Failed to find free port: {}", e)))?; - let port = listener - .local_addr() - .map_err(|e| docker_err(format!("Failed to get port: {}", e)))? - .port(); - drop(listener); - Ok(port) -} - -async fn wait_for_pg_ready( - docker_host: &Option, - container_name: &str, - timeout_secs: u64, -) -> TuskResult<()> { - let start = std::time::Instant::now(); - let timeout = std::time::Duration::from_secs(timeout_secs); - - loop { - if start.elapsed() > timeout { - return Err(docker_err("PostgreSQL did not become ready in time")); - } - - let output = docker_cmd_sync(docker_host) - .args(["exec", container_name, "pg_isready", "-U", "postgres"]) - .output() - .await; - - if let Ok(out) = output { - if out.status.success() { - return Ok(()); - } - } - - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - } -} - -async fn try_local_pg_dump() -> bool { - Command::new("pg_dump") - .arg("--version") - .output() - .await - .map(|o| o.status.success()) - .unwrap_or(false) -} - -/// Build the docker host flag string for shell commands -fn docker_host_flag(docker_host: &Option) -> String { - match docker_host { - Some(h) => format!("-H '{}'", shell_escape(h)), - None => String::new(), - } -} - -/// Build the pg_dump portion of a shell command -fn pg_dump_shell_cmd( - has_local: bool, - pg_version: &str, - extra_args: &str, - source_url: &str, - docker_host: &Option, -) -> String { - let escaped_url = shell_escape(source_url); - if has_local { - format!("pg_dump {} '{}'", extra_args, escaped_url) - } else { - let host_flag = docker_host_flag(docker_host); - format!( - "docker {} run --rm --network=host postgres:{} pg_dump {} '{}'", - host_flag, pg_version, extra_args, escaped_url - ) - } -} - -async fn run_pipe_cmd( - app: &AppHandle, - clone_id: &str, - pipe_cmd: &str, - label: &str, -) -> TuskResult { - // Use bash with pipefail so pg_dump failures are not swallowed - let wrapped = format!("set -o pipefail; {}", pipe_cmd); - - emit_progress(app, clone_id, "transfer", 50, label, None); - - let output = Command::new("bash") - .args(["-c", &wrapped]) - .output() - .await - .map_err(|e| docker_err(format!("{} failed to start: {}", label, e)))?; - - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); - - // Always log stderr if present - if !stderr.is_empty() { - // Truncate for progress display (full log can be long) - let short = if stderr.len() > 500 { - let truncated = stderr - .char_indices() - .nth(500) - .map(|(i, _)| &stderr[..i]) - .unwrap_or(&stderr); - format!("{}...", truncated) - } else { - stderr.clone() - }; - emit_progress( - app, - clone_id, - "transfer", - 55, - &format!("{}: stderr output", label), - Some(&short), - ); - } - - // Count DDL statements in stdout for feedback - if !stdout.is_empty() { - let creates = stdout - .lines() - .filter(|l| l.starts_with("CREATE") || l.starts_with("ALTER") || l.starts_with("SET")) - .count(); - if creates > 0 { - emit_progress( - app, - clone_id, - "transfer", - 58, - &format!("Applied {} SQL statements", creates), - None, - ); - } - } - - if !output.status.success() { - let code = output.status.code().unwrap_or(-1); - emit_progress( - app, - clone_id, - "transfer", - 55, - &format!("{} exited with code {}", label, code), - Some(&stderr), - ); - - // Only hard-fail on connection / fatal errors - if stderr.contains("FATAL") - || stderr.contains("could not connect") - || stderr.contains("No such file") - || stderr.contains("password authentication failed") - || stderr.contains("does not exist") - || (stdout.is_empty() && stderr.is_empty()) - { - return Err(docker_err(format!( - "{} failed (exit {}): {}", - label, code, stderr - ))); - } - } - - Ok(output) -} - -async fn transfer_schema_only( - app: &AppHandle, - clone_id: &str, - source_url: &str, - container_name: &str, - database: &str, - pg_version: &str, - docker_host: &Option, -) -> TuskResult<()> { - let has_local = try_local_pg_dump().await; - transfer_schema_only_with( - app, - clone_id, - source_url, - container_name, - database, - pg_version, - docker_host, - has_local, - ) - .await -} - -#[allow(clippy::too_many_arguments)] -async fn transfer_schema_only_with( - app: &AppHandle, - clone_id: &str, - source_url: &str, - container_name: &str, - database: &str, - pg_version: &str, - docker_host: &Option, - has_local: bool, -) -> TuskResult<()> { - let label = if has_local { - "local pg_dump" - } else { - "Docker-based pg_dump" - }; - emit_progress( - app, - clone_id, - "transfer", - 48, - &format!("Using {} for schema...", label), - None, - ); - - let dump_cmd = pg_dump_shell_cmd( - has_local, - pg_version, - "--schema-only --no-owner --no-acl", - source_url, - docker_host, - ); - let escaped_db = shell_escape(database); - let host_flag = docker_host_flag(docker_host); - let pipe_cmd = format!( - "{} | docker {} exec -i '{}' psql -U postgres -d '{}'", - dump_cmd, - host_flag, - shell_escape(container_name), - escaped_db - ); - - run_pipe_cmd(app, clone_id, &pipe_cmd, "Schema transfer").await?; - - emit_progress( - app, - clone_id, - "transfer", - 60, - "Schema transferred successfully", - None, - ); - Ok(()) -} - -async fn transfer_full_clone( - app: &AppHandle, - clone_id: &str, - source_url: &str, - container_name: &str, - database: &str, - pg_version: &str, - docker_host: &Option, -) -> TuskResult<()> { - let has_local = try_local_pg_dump().await; - let label = if has_local { - "local pg_dump" - } else { - "Docker-based pg_dump" - }; - emit_progress( - app, - clone_id, - "transfer", - 48, - &format!("Using {} for full clone...", label), - None, - ); - - // Use plain text format piped to psql (more reliable than -Fc | pg_restore through docker exec) - let dump_cmd = pg_dump_shell_cmd( - has_local, - pg_version, - "--no-owner --no-acl", - source_url, - docker_host, - ); - let escaped_db = shell_escape(database); - let host_flag = docker_host_flag(docker_host); - let pipe_cmd = format!( - "{} | docker {} exec -i '{}' psql -U postgres -d '{}'", - dump_cmd, - host_flag, - shell_escape(container_name), - escaped_db - ); - - run_pipe_cmd(app, clone_id, &pipe_cmd, "Full clone").await?; - - emit_progress(app, clone_id, "transfer", 85, "Full clone completed", None); - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -async fn transfer_sample_data_with( - app: &AppHandle, - clone_id: &str, - source_url: &str, - container_name: &str, - database: &str, - pg_version: &str, - sample_rows: u32, - docker_host: &Option, - has_local: bool, -) -> TuskResult<()> { - // List tables from the target (schema already transferred) - let target_output = docker_cmd_sync(docker_host) - .args([ - "exec", container_name, - "psql", "-U", "postgres", "-d", database, - "-t", "-A", "-c", - "SELECT schemaname || '.' || tablename FROM pg_tables WHERE schemaname NOT IN ('pg_catalog', 'information_schema') ORDER BY schemaname, tablename", - ]) - .output() - .await - .map_err(|e| docker_err(format!("Failed to list tables: {}", e)))?; - - let tables_str = String::from_utf8_lossy(&target_output.stdout); - let tables: Vec<&str> = tables_str - .lines() - .filter(|l| !l.trim().is_empty()) - .collect(); - let total = tables.len(); - - if total == 0 { - emit_progress( - app, - clone_id, - "transfer", - 85, - "No tables to copy data for", - None, - ); - return Ok(()); - } - - for (i, qualified_table) in tables.iter().enumerate() { - let pct = 65 + ((i * 20) / total.max(1)).min(20) as u8; - emit_progress( - app, - clone_id, - "transfer", - pct, - &format!( - "Copying sample data: {} ({}/{})", - qualified_table, - i + 1, - total - ), - None, - ); - - let parts: Vec<&str> = qualified_table.splitn(2, '.').collect(); - if parts.len() != 2 { - continue; - } - let schema = parts[0]; - let table = parts[1]; - - // Use COPY (SELECT ... LIMIT N) TO STDOUT piped to COPY ... FROM STDIN - // Escape schema/table for use inside double-quoted shell strings - let escaped_schema = shell_escape_double(schema); - let escaped_table = shell_escape_double(table); - let copy_out_sql = format!( - "\\copy (SELECT * FROM \\\"{}\\\".\\\"{}\\\" LIMIT {}) TO STDOUT", - escaped_schema, escaped_table, sample_rows - ); - let copy_in_sql = format!( - "\\copy \\\"{}\\\".\\\"{}\\\" FROM STDIN", - escaped_schema, escaped_table - ); - - let escaped_url = shell_escape(source_url); - let escaped_container = shell_escape(container_name); - let escaped_db = shell_escape(database); - - let host_flag = docker_host_flag(docker_host); - let source_cmd = if has_local { - format!("psql '{}' -c \"{}\"", escaped_url, copy_out_sql) - } else { - let image = format!("postgres:{}", pg_version); - format!( - "docker {} run --rm --network=host {} psql '{}' -c \"{}\"", - host_flag, image, escaped_url, copy_out_sql - ) - }; - - let pipe_cmd = format!( - "set -o pipefail; {} | docker {} exec -i '{}' psql -U postgres -d '{}' -c \"{}\"", - source_cmd, host_flag, escaped_container, escaped_db, copy_in_sql - ); - - let output = Command::new("bash").args(["-c", &pipe_cmd]).output().await; - - match output { - Ok(out) => { - let stderr = String::from_utf8_lossy(&out.stderr).trim().to_string(); - if !stderr.is_empty() && (stderr.contains("ERROR") || stderr.contains("FATAL")) { - emit_progress( - app, - clone_id, - "transfer", - pct, - &format!("Warning: {}", qualified_table), - Some(&stderr), - ); - } - } - Err(e) => { - emit_progress( - app, - clone_id, - "transfer", - pct, - &format!("Warning: failed to copy {}: {}", qualified_table, e), - None, - ); - } - } - } - - emit_progress( - app, - clone_id, - "transfer", - 85, - "Sample data transfer completed", - None, - ); - Ok(()) -} - -fn save_connection_config(app: &AppHandle, config: &ConnectionConfig) -> TuskResult<()> { - let path = super::connections::get_connections_path(app)?; - let mut connections = if path.exists() { - let data = fs::read_to_string(&path)?; - serde_json::from_str::>(&data)? - } else { - vec![] - }; - - // Upsert by ID to avoid duplicate entries on retry - if let Some(pos) = connections.iter().position(|c| c.id == config.id) { - connections[pos] = config.clone(); - } else { - connections.push(config.clone()); - } - - let data = serde_json::to_string_pretty(&connections)?; - fs::write(&path, data)?; - Ok(()) -} - -#[tauri::command] -pub async fn start_container(state: State<'_, Arc>, name: String) -> TuskResult<()> { - let output = docker_cmd(&state) - .await - .args(["start", &name]) - .output() - .await - .map_err(|e| docker_err(format!("Failed to start container: {}", e)))?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - return Err(docker_err(format!("Failed to start container: {}", stderr))); - } - - Ok(()) -} - -#[tauri::command] -pub async fn stop_container(state: State<'_, Arc>, name: String) -> TuskResult<()> { - let output = docker_cmd(&state) - .await - .args(["stop", &name]) - .output() - .await - .map_err(|e| docker_err(format!("Failed to stop container: {}", e)))?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - return Err(docker_err(format!("Failed to stop container: {}", stderr))); - } - - Ok(()) -} - -#[tauri::command] -pub async fn remove_container(state: State<'_, Arc>, name: String) -> TuskResult<()> { - let output = docker_cmd(&state) - .await - .args(["rm", "-f", &name]) - .output() - .await - .map_err(|e| docker_err(format!("Failed to remove container: {}", e)))?; - - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr); - return Err(docker_err(format!( - "Failed to remove container: {}", - stderr - ))); - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - // ── validate_container_name ─────────────────────────────── - - #[test] - fn container_name_valid_simple() { - assert!(validate_container_name("mycontainer").is_ok()); - } - - #[test] - fn container_name_valid_with_dots_dashes_underscores() { - assert!(validate_container_name("my-container_v1.2").is_ok()); - } - - #[test] - fn container_name_valid_starts_with_digit() { - assert!(validate_container_name("1container").is_ok()); - } - - #[test] - fn container_name_empty() { - assert!(validate_container_name("").is_err()); - } - - #[test] - fn container_name_starts_with_dash() { - assert!(validate_container_name("-bad").is_err()); - } - - #[test] - fn container_name_starts_with_dot() { - assert!(validate_container_name(".bad").is_err()); - } - - #[test] - fn container_name_starts_with_underscore() { - assert!(validate_container_name("_bad").is_err()); - } - - #[test] - fn container_name_with_spaces() { - assert!(validate_container_name("bad name").is_err()); - } - - #[test] - fn container_name_with_unicode() { - assert!(validate_container_name("контейнер").is_err()); - } - - #[test] - fn container_name_with_special_chars() { - assert!(validate_container_name("bad;name").is_err()); - assert!(validate_container_name("bad/name").is_err()); - assert!(validate_container_name("bad:name").is_err()); - assert!(validate_container_name("bad@name").is_err()); - } - - #[test] - fn container_name_with_shell_injection() { - assert!(validate_container_name("x; rm -rf /").is_err()); - assert!(validate_container_name("x$(whoami)").is_err()); - } - - // ── validate_pg_version ─────────────────────────────────── - - #[test] - fn pg_version_valid_major() { - assert!(validate_pg_version("16").is_ok()); - } - - #[test] - fn pg_version_valid_major_minor() { - assert!(validate_pg_version("16.2").is_ok()); - } - - #[test] - fn pg_version_valid_three_parts() { - assert!(validate_pg_version("17.1.0").is_ok()); - } - - #[test] - fn pg_version_empty() { - assert!(validate_pg_version("").is_err()); - } - - #[test] - fn pg_version_with_letters() { - assert!(validate_pg_version("16beta1").is_err()); - } - - #[test] - fn pg_version_with_injection() { - assert!(validate_pg_version("16; rm -rf").is_err()); - } - - #[test] - fn pg_version_only_dots() { - // Current impl allows dots-only — this documents the behavior - assert!(validate_pg_version("...").is_ok()); - } - - // ── shell_escape ────────────────────────────────────────── - - #[test] - fn shell_escape_no_quotes() { - assert_eq!(shell_escape("hello"), "hello"); - } - - #[test] - fn shell_escape_with_single_quote() { - assert_eq!(shell_escape("it's"), "it'\\''s"); - } - - #[test] - fn shell_escape_multiple_quotes() { - assert_eq!(shell_escape("a'b'c"), "a'\\''b'\\''c"); - } - - // ── shell_escape_double ─────────────────────────────────── - - #[test] - fn shell_escape_double_no_special() { - assert_eq!(shell_escape_double("hello"), "hello"); - } - - #[test] - fn shell_escape_double_with_backslash() { - assert_eq!(shell_escape_double(r"a\b"), r"a\\b"); - } - - #[test] - fn shell_escape_double_with_dollar() { - assert_eq!(shell_escape_double("$HOME"), "\\$HOME"); - } - - #[test] - fn shell_escape_double_with_backtick() { - assert_eq!(shell_escape_double("`whoami`"), "\\`whoami\\`"); - } - - #[test] - fn shell_escape_double_with_double_quote() { - assert_eq!(shell_escape_double(r#"say "hi""#), r#"say \"hi\""#); - } -} diff --git a/src-tauri/src/commands/lookup.rs b/src-tauri/src/commands/lookup.rs deleted file mode 100644 index 5a1125a..0000000 --- a/src-tauri/src/commands/lookup.rs +++ /dev/null @@ -1,367 +0,0 @@ -use crate::commands::queries::pg_value_to_json; -use crate::error::TuskResult; -use crate::models::connection::ConnectionConfig; -use crate::models::lookup::{ - EntityLookupResult, LookupDatabaseResult, LookupProgress, LookupTableMatch, -}; -use crate::utils::escape_ident; -use sqlx::postgres::PgPoolOptions; -use sqlx::{Column, Row, TypeInfo}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::Instant; -use tauri::{AppHandle, Emitter}; -use tokio::sync::Semaphore; - -struct TableCandidate { - schema: String, - table: String, - data_type: String, -} - -async fn search_database( - config: &ConnectionConfig, - database: &str, - column_name: &str, - value: &str, -) -> LookupDatabaseResult { - let start = Instant::now(); - - let mut db_config = config.clone(); - db_config.database = database.to_string(); - let url = db_config.connection_url(); - - let pool = match PgPoolOptions::new() - .max_connections(2) - .acquire_timeout(std::time::Duration::from_secs(5)) - .connect(&url) - .await - { - Ok(p) => p, - Err(e) => { - return LookupDatabaseResult { - database: database.to_string(), - tables: vec![], - error: Some(format!("Connection failed: {}", e)), - search_time_ms: start.elapsed().as_millis(), - }; - } - }; - - let result = tokio::time::timeout( - std::time::Duration::from_secs(120), - search_database_inner(&pool, database, column_name, value), - ) - .await; - - pool.close().await; - - match result { - Ok(db_result) => { - let mut db_result = db_result; - db_result.search_time_ms = start.elapsed().as_millis(); - db_result - } - Err(_) => LookupDatabaseResult { - database: database.to_string(), - tables: vec![], - error: Some("Timeout (120s)".to_string()), - search_time_ms: start.elapsed().as_millis(), - }, - } -} - -async fn search_database_inner( - pool: &sqlx::PgPool, - database: &str, - column_name: &str, - value: &str, -) -> LookupDatabaseResult { - // Find tables that have this column - let candidates = match sqlx::query_as::<_, (String, String, String)>( - "SELECT table_schema, table_name, data_type \ - FROM information_schema.columns \ - WHERE column_name = $1 \ - AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit')", - ) - .bind(column_name) - .fetch_all(pool) - .await - { - Ok(rows) => rows - .into_iter() - .map(|(schema, table, data_type)| TableCandidate { - schema, - table, - data_type, - }) - .collect::>(), - Err(e) => { - return LookupDatabaseResult { - database: database.to_string(), - tables: vec![], - error: Some(format!("Schema query failed: {}", e)), - search_time_ms: 0, - }; - } - }; - - let mut tables = Vec::new(); - - for candidate in &candidates { - let qualified = format!( - "{}.{}", - escape_ident(&candidate.schema), - escape_ident(&candidate.table) - ); - let col_ident = escape_ident(column_name); - - // Read-only transaction: SELECT rows + COUNT - let select_sql = format!( - "SELECT * FROM {} WHERE {}::text = $1 LIMIT 50", - qualified, col_ident - ); - let count_sql = format!( - "SELECT COUNT(*) FROM {} WHERE {}::text = $1", - qualified, col_ident - ); - - let mut tx = match pool.begin().await { - Ok(tx) => tx, - Err(e) => { - tables.push(LookupTableMatch { - schema: candidate.schema.clone(), - table: candidate.table.clone(), - column_type: candidate.data_type.clone(), - columns: vec![], - types: vec![], - rows: vec![], - row_count: 0, - total_count: 0, - }); - log::warn!( - "Failed to begin tx for {}.{}: {}", - candidate.schema, - candidate.table, - e - ); - continue; - } - }; - - if let Err(e) = sqlx::query("SET TRANSACTION READ ONLY") - .execute(&mut *tx) - .await - { - let _ = tx.rollback().await; - log::warn!("Failed SET TRANSACTION READ ONLY: {}", e); - continue; - } - - let rows_result = sqlx::query(&select_sql) - .bind(value) - .fetch_all(&mut *tx) - .await; - - let count_result: Result = sqlx::query_scalar(&count_sql) - .bind(value) - .fetch_one(&mut *tx) - .await; - - let _ = tx.rollback().await; - - match rows_result { - Ok(rows) if !rows.is_empty() => { - let mut col_names = Vec::new(); - let mut col_types = Vec::new(); - if let Some(first) = rows.first() { - for col in first.columns() { - col_names.push(col.name().to_string()); - col_types.push(col.type_info().name().to_string()); - } - } - - let result_rows: Vec> = rows - .iter() - .map(|row| { - (0..col_names.len()) - .map(|i| pg_value_to_json(row, i)) - .collect() - }) - .collect(); - - let row_count = result_rows.len(); - let total_count = count_result.unwrap_or(row_count as i64); - - tables.push(LookupTableMatch { - schema: candidate.schema.clone(), - table: candidate.table.clone(), - column_type: candidate.data_type.clone(), - columns: col_names, - types: col_types, - rows: result_rows, - row_count, - total_count, - }); - } - Ok(_) => { - // No rows matched — skip - } - Err(e) => { - log::warn!( - "Query failed for {}.{}: {}", - candidate.schema, - candidate.table, - e - ); - } - } - } - - LookupDatabaseResult { - database: database.to_string(), - tables, - error: None, - search_time_ms: 0, - } -} - -#[tauri::command] -pub async fn entity_lookup( - app: AppHandle, - config: ConnectionConfig, - column_name: String, - value: String, - databases: Option>, - lookup_id: String, -) -> TuskResult { - let start = Instant::now(); - - // 1. Get list of databases - let url = config.connection_url(); - let pool = PgPoolOptions::new() - .max_connections(1) - .acquire_timeout(std::time::Duration::from_secs(5)) - .connect(&url) - .await - .map_err(crate::error::TuskError::Database)?; - - let db_names: Vec = sqlx::query_scalar( - "SELECT datname FROM pg_database WHERE datistemplate = false ORDER BY datname", - ) - .fetch_all(&pool) - .await - .map_err(crate::error::TuskError::Database)?; - - pool.close().await; - - // Filter if specific databases requested - let db_names: Vec = if let Some(ref filter) = databases { - db_names - .into_iter() - .filter(|d| filter.contains(d)) - .collect() - } else { - db_names - }; - - let total = db_names.len(); - let completed = Arc::new(AtomicUsize::new(0)); - let semaphore = Arc::new(Semaphore::new(5)); - - // 2. Parallel search across databases - let mut handles = Vec::new(); - - for db_name in db_names { - let config = config.clone(); - let column_name = column_name.clone(); - let value = value.clone(); - let lookup_id = lookup_id.clone(); - let app = app.clone(); - let semaphore = semaphore.clone(); - let completed = completed.clone(); - - let handle = tokio::spawn(async move { - let _permit = semaphore.acquire().await.unwrap(); - - // Emit "searching" progress - let _ = app.emit( - "lookup-progress", - LookupProgress { - lookup_id: lookup_id.clone(), - database: db_name.clone(), - status: "searching".to_string(), - tables_found: 0, - rows_found: 0, - error: None, - completed: completed.load(Ordering::Relaxed), - total, - }, - ); - - let result = search_database(&config, &db_name, &column_name, &value).await; - - let done = completed.fetch_add(1, Ordering::Relaxed) + 1; - - let status = if result.error.is_some() { - "error" - } else { - "done" - }; - - let _ = app.emit( - "lookup-progress", - LookupProgress { - lookup_id: lookup_id.clone(), - database: db_name.clone(), - status: status.to_string(), - tables_found: result.tables.len(), - rows_found: result.tables.iter().map(|t| t.row_count).sum(), - error: result.error.clone(), - completed: done, - total, - }, - ); - - result - }); - - handles.push(handle); - } - - // 3. Collect results - let mut all_results = Vec::new(); - for handle in handles { - match handle.await { - Ok(result) => all_results.push(result), - Err(e) => { - log::error!("Join error: {}", e); - } - } - } - - // Sort: databases with matches first, then by name - all_results.sort_by(|a, b| { - let a_has = !a.tables.is_empty(); - let b_has = !b.tables.is_empty(); - b_has.cmp(&a_has).then(a.database.cmp(&b.database)) - }); - - let total_databases_searched = all_results.len(); - let total_tables_matched: usize = all_results.iter().map(|d| d.tables.len()).sum(); - let total_rows_found: usize = all_results - .iter() - .flat_map(|d| d.tables.iter()) - .map(|t| t.row_count) - .sum(); - - Ok(EntityLookupResult { - column_name, - value, - databases: all_results, - total_databases_searched, - total_tables_matched, - total_rows_found, - total_time_ms: start.elapsed().as_millis(), - }) -} diff --git a/src-tauri/src/commands/management.rs b/src-tauri/src/commands/management.rs deleted file mode 100644 index 4d64bff..0000000 --- a/src-tauri/src/commands/management.rs +++ /dev/null @@ -1,594 +0,0 @@ -use crate::error::{TuskError, TuskResult}; -use crate::models::management::*; -use crate::state::{AppState, DbFlavor}; -use crate::utils::escape_ident; -use sqlx::Row; -use std::sync::Arc; -use tauri::State; - -#[tauri::command] -pub async fn get_database_info( - state: State<'_, Arc>, - connection_id: String, -) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let rows = sqlx::query( - "SELECT d.datname, \ - pg_catalog.pg_get_userbyid(d.datdba) AS owner, \ - pg_catalog.pg_encoding_to_char(d.encoding) AS encoding, \ - d.datcollate, \ - d.datctype, \ - COALESCE(t.spcname, 'pg_default') AS tablespace, \ - d.datconnlimit, \ - pg_catalog.pg_size_pretty(pg_catalog.pg_database_size(d.datname)) AS size, \ - pg_catalog.shobj_description(d.oid, 'pg_database') AS description \ - FROM pg_catalog.pg_database d \ - LEFT JOIN pg_catalog.pg_tablespace t ON d.dattablespace = t.oid \ - WHERE NOT d.datistemplate \ - ORDER BY d.datname", - ) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; - - let databases = rows - .iter() - .map(|row| DatabaseInfo { - name: row.get("datname"), - owner: row.get("owner"), - encoding: row.get("encoding"), - collation: row.get("datcollate"), - ctype: row.get("datctype"), - tablespace: row.get("tablespace"), - connection_limit: row.get("datconnlimit"), - size: row.get("size"), - description: row.get("description"), - }) - .collect(); - - Ok(databases) -} - -#[tauri::command] -pub async fn create_database( - state: State<'_, Arc>, - connection_id: String, - params: CreateDatabaseParams, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let mut sql = format!("CREATE DATABASE {}", escape_ident(¶ms.name)); - - if let Some(ref owner) = params.owner { - sql.push_str(&format!(" OWNER {}", escape_ident(owner))); - } - if let Some(ref template) = params.template { - sql.push_str(&format!(" TEMPLATE {}", escape_ident(template))); - } - if let Some(ref encoding) = params.encoding { - sql.push_str(&format!(" ENCODING '{}'", encoding.replace('\'', "''"))); - } - if let Some(ref tablespace) = params.tablespace { - sql.push_str(&format!(" TABLESPACE {}", escape_ident(tablespace))); - } - if let Some(limit) = params.connection_limit { - sql.push_str(&format!(" CONNECTION LIMIT {}", limit)); - } - - sqlx::query(&sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - - Ok(()) -} - -#[tauri::command] -pub async fn drop_database( - state: State<'_, Arc>, - connection_id: String, - name: String, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - // Terminate active connections to the target database - sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1::name AND pid <> pg_backend_pid()") - .bind(&name) - .execute(pool) - .await - .map_err(TuskError::Database)?; - - let drop_sql = format!("DROP DATABASE {}", escape_ident(&name)); - sqlx::query(&drop_sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - - Ok(()) -} - -#[tauri::command] -pub async fn list_roles( - state: State<'_, Arc>, - connection_id: String, -) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let rows = sqlx::query( - "SELECT r.rolname, \ - r.rolsuper, \ - r.rolcanlogin, \ - r.rolcreatedb, \ - r.rolcreaterole, \ - r.rolinherit, \ - r.rolreplication, \ - r.rolconnlimit, \ - r.rolpassword IS NOT NULL AS password_set, \ - r.rolvaliduntil::text, \ - COALESCE(( \ - SELECT array_agg(g.rolname ORDER BY g.rolname) \ - FROM pg_catalog.pg_auth_members m \ - JOIN pg_catalog.pg_roles g ON m.roleid = g.oid \ - WHERE m.member = r.oid \ - ), ARRAY[]::text[]) AS member_of, \ - COALESCE(( \ - SELECT array_agg(m2.rolname ORDER BY m2.rolname) \ - FROM pg_catalog.pg_auth_members am \ - JOIN pg_catalog.pg_roles m2 ON am.member = m2.oid \ - WHERE am.roleid = r.oid \ - ), ARRAY[]::text[]) AS members, \ - pg_catalog.shobj_description(r.oid, 'pg_authid') AS description \ - FROM pg_catalog.pg_roles r \ - WHERE r.rolname !~ '^pg_' \ - ORDER BY r.rolname", - ) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; - - let roles = rows - .iter() - .map(|row| RoleInfo { - name: row.get("rolname"), - is_superuser: row.get("rolsuper"), - can_login: row.get("rolcanlogin"), - can_create_db: row.get("rolcreatedb"), - can_create_role: row.get("rolcreaterole"), - inherit: row.get("rolinherit"), - is_replication: row.get("rolreplication"), - connection_limit: row.get("rolconnlimit"), - password_set: row.get("password_set"), - valid_until: row.get("rolvaliduntil"), - member_of: row.get("member_of"), - members: row.get("members"), - description: row.get("description"), - }) - .collect(); - - Ok(roles) -} - -#[tauri::command] -pub async fn create_role( - state: State<'_, Arc>, - connection_id: String, - params: CreateRoleParams, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let mut sql = format!("CREATE ROLE {}", escape_ident(¶ms.name)); - - let mut options = Vec::new(); - options.push(if params.login { "LOGIN" } else { "NOLOGIN" }); - options.push(if params.superuser { - "SUPERUSER" - } else { - "NOSUPERUSER" - }); - options.push(if params.createdb { - "CREATEDB" - } else { - "NOCREATEDB" - }); - options.push(if params.createrole { - "CREATEROLE" - } else { - "NOCREATEROLE" - }); - options.push(if params.inherit { - "INHERIT" - } else { - "NOINHERIT" - }); - options.push(if params.replication { - "REPLICATION" - } else { - "NOREPLICATION" - }); - - if let Some(ref password) = params.password { - options.push("PASSWORD"); - // Will be appended separately - sql.push_str(&format!(" {}", options.join(" "))); - sql.push_str(&format!(" '{}'", password.replace('\'', "''"))); - } else { - sql.push_str(&format!(" {}", options.join(" "))); - } - - if let Some(limit) = params.connection_limit { - sql.push_str(&format!(" CONNECTION LIMIT {}", limit)); - } - - if let Some(ref valid_until) = params.valid_until { - sql.push_str(&format!( - " VALID UNTIL '{}'", - valid_until.replace('\'', "''") - )); - } - - if !params.in_roles.is_empty() { - let roles: Vec = params.in_roles.iter().map(|r| escape_ident(r)).collect(); - sql.push_str(&format!(" IN ROLE {}", roles.join(", "))); - } - - sqlx::query(&sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - - Ok(()) -} - -#[tauri::command] -pub async fn alter_role( - state: State<'_, Arc>, - connection_id: String, - params: AlterRoleParams, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let mut options = Vec::new(); - - if let Some(login) = params.login { - options.push(if login { - "LOGIN".to_string() - } else { - "NOLOGIN".to_string() - }); - } - if let Some(superuser) = params.superuser { - options.push(if superuser { - "SUPERUSER".to_string() - } else { - "NOSUPERUSER".to_string() - }); - } - if let Some(createdb) = params.createdb { - options.push(if createdb { - "CREATEDB".to_string() - } else { - "NOCREATEDB".to_string() - }); - } - if let Some(createrole) = params.createrole { - options.push(if createrole { - "CREATEROLE".to_string() - } else { - "NOCREATEROLE".to_string() - }); - } - if let Some(inherit) = params.inherit { - options.push(if inherit { - "INHERIT".to_string() - } else { - "NOINHERIT".to_string() - }); - } - if let Some(replication) = params.replication { - options.push(if replication { - "REPLICATION".to_string() - } else { - "NOREPLICATION".to_string() - }); - } - if let Some(ref password) = params.password { - options.push(format!("PASSWORD '{}'", password.replace('\'', "''"))); - } - if let Some(limit) = params.connection_limit { - options.push(format!("CONNECTION LIMIT {}", limit)); - } - if let Some(ref valid_until) = params.valid_until { - options.push(format!("VALID UNTIL '{}'", valid_until.replace('\'', "''"))); - } - - if !options.is_empty() { - let sql = format!( - "ALTER ROLE {} {}", - escape_ident(¶ms.name), - options.join(" ") - ); - sqlx::query(&sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - } - - if let Some(ref new_name) = params.rename_to { - let sql = format!( - "ALTER ROLE {} RENAME TO {}", - escape_ident(¶ms.name), - escape_ident(new_name) - ); - sqlx::query(&sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - } - - Ok(()) -} - -#[tauri::command] -pub async fn drop_role( - state: State<'_, Arc>, - connection_id: String, - name: String, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let sql = format!("DROP ROLE {}", escape_ident(&name)); - sqlx::query(&sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - - Ok(()) -} - -#[tauri::command] -pub async fn get_table_privileges( - state: State<'_, Arc>, - connection_id: String, - schema: String, - table: String, -) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let rows = sqlx::query( - "SELECT grantee, table_schema, table_name, privilege_type, \ - is_grantable = 'YES' AS is_grantable \ - FROM information_schema.role_table_grants \ - WHERE table_schema = $1 AND table_name = $2 \ - ORDER BY grantee, privilege_type", - ) - .bind(&schema) - .bind(&table) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; - - let privileges = rows - .iter() - .map(|row| TablePrivilege { - grantee: row.get("grantee"), - table_schema: row.get("table_schema"), - table_name: row.get("table_name"), - privilege_type: row.get("privilege_type"), - is_grantable: row.get("is_grantable"), - }) - .collect(); - - Ok(privileges) -} - -#[tauri::command] -pub async fn grant_revoke( - state: State<'_, Arc>, - connection_id: String, - params: GrantRevokeParams, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let privs = params.privileges.join(", "); - let object_type = params.object_type.to_uppercase(); - let object_ref = escape_ident(¶ms.object_name); - let role_ref = escape_ident(¶ms.role_name); - - let sql = if params.action.to_uppercase() == "GRANT" { - let grant_option = if params.with_grant_option { - " WITH GRANT OPTION" - } else { - "" - }; - format!( - "GRANT {} ON {} {} TO {}{}", - privs, object_type, object_ref, role_ref, grant_option - ) - } else { - format!( - "REVOKE {} ON {} {} FROM {}", - privs, object_type, object_ref, role_ref - ) - }; - - sqlx::query(&sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - - Ok(()) -} - -#[tauri::command] -pub async fn manage_role_membership( - state: State<'_, Arc>, - connection_id: String, - params: RoleMembershipParams, -) -> TuskResult<()> { - if state.is_read_only(&connection_id).await { - return Err(TuskError::ReadOnly); - } - - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let role_ref = escape_ident(¶ms.role_name); - let member_ref = escape_ident(¶ms.member_name); - - let sql = if params.action.to_uppercase() == "GRANT" { - format!("GRANT {} TO {}", role_ref, member_ref) - } else { - format!("REVOKE {} FROM {}", role_ref, member_ref) - }; - - sqlx::query(&sql) - .execute(pool) - .await - .map_err(TuskError::Database)?; - - Ok(()) -} - -#[tauri::command] -pub async fn list_sessions( - state: State<'_, Arc>, - connection_id: String, -) -> TuskResult> { - let flavor = state.get_flavor(&connection_id).await; - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let sql = if flavor == DbFlavor::Greenplum { - "SELECT pid, usename, datname, state, query, \ - query_start::text, NULL::text as wait_event_type, NULL::text as wait_event, \ - client_addr::text \ - FROM pg_stat_activity \ - WHERE datname IS NOT NULL \ - ORDER BY query_start DESC NULLS LAST" - } else { - "SELECT pid, usename, datname, state, query, \ - query_start::text, wait_event_type, wait_event, \ - client_addr::text \ - FROM pg_stat_activity \ - WHERE datname IS NOT NULL \ - ORDER BY query_start DESC NULLS LAST" - }; - - let rows = sqlx::query(sql) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; - - let sessions = rows - .iter() - .map(|row| SessionInfo { - pid: row.get("pid"), - usename: row.get("usename"), - datname: row.get("datname"), - state: row.get("state"), - query: row.get("query"), - query_start: row.get("query_start"), - wait_event_type: row.get("wait_event_type"), - wait_event: row.get("wait_event"), - client_addr: row.get("client_addr"), - }) - .collect(); - - Ok(sessions) -} - -#[tauri::command] -pub async fn cancel_query( - state: State<'_, Arc>, - connection_id: String, - pid: i32, -) -> TuskResult { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let row = sqlx::query("SELECT pg_cancel_backend($1)") - .bind(pid) - .fetch_one(pool) - .await - .map_err(TuskError::Database)?; - - Ok(row.get::(0)) -} - -#[tauri::command] -pub async fn terminate_backend( - state: State<'_, Arc>, - connection_id: String, - pid: i32, -) -> TuskResult { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; - - let row = sqlx::query("SELECT pg_terminate_backend($1)") - .bind(pid) - .fetch_one(pool) - .await - .map_err(TuskError::Database)?; - - Ok(row.get::(0)) -} diff --git a/src-tauri/src/commands/memory.rs b/src-tauri/src/commands/memory.rs new file mode 100644 index 0000000..01664c3 --- /dev/null +++ b/src-tauri/src/commands/memory.rs @@ -0,0 +1,214 @@ +//! Per-connection long-term memory for the chat agent (F1). +//! +//! Stored as a markdown file at `/memory/.md`. +//! The agent appends notes via the `remember` tool; the user can view and edit +//! the file in the Memory sidebar tab. The same content is injected into the +//! LEARNED NOTES section of the system prompt every turn. + +use crate::error::{TuskError, TuskResult}; +use chrono::Utc; +use std::fs; +use std::path::PathBuf; +use tauri::{AppHandle, Manager}; + +/// Soft cap on memory file size. Overflow drops oldest `## ts` blocks until fits. +pub const MEMORY_BYTE_CAP: usize = 16 * 1024; + +pub(crate) fn get_memory_path( + app: &AppHandle, + connection_id: &str, +) -> TuskResult { + let dir = app + .path() + .app_data_dir() + .map_err(|e| TuskError::Config(e.to_string()))? + .join("memory"); + fs::create_dir_all(&dir)?; + let safe = sanitize_connection_id(connection_id); + Ok(dir.join(format!("{}.md", safe))) +} + +fn sanitize_connection_id(id: &str) -> String { + id.chars() + .map(|c| match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' => c, + _ => '_', + }) + .collect() +} + +pub(crate) fn read_memory_core( + app: &AppHandle, + connection_id: &str, +) -> TuskResult { + let path = get_memory_path(app, connection_id)?; + if !path.exists() { + return Ok(String::new()); + } + Ok(fs::read_to_string(&path)?) +} + +pub(crate) fn write_memory_core( + app: &AppHandle, + connection_id: &str, + content: &str, +) -> TuskResult<()> { + let path = get_memory_path(app, connection_id)?; + let trimmed = enforce_size_cap(content, MEMORY_BYTE_CAP); + fs::write(&path, trimmed)?; + Ok(()) +} + +pub(crate) fn append_memory_core( + app: &AppHandle, + connection_id: &str, + note: &str, +) -> TuskResult<()> { + let trimmed_note = note.trim(); + if trimmed_note.is_empty() { + return Err(TuskError::Custom("remember: note must not be empty".into())); + } + + let existing = read_memory_core(app, connection_id)?; + let mut buf = String::new(); + if existing.is_empty() { + buf.push_str("# Memory\n\n"); + } else { + buf.push_str(&existing); + if !buf.ends_with('\n') { + buf.push('\n'); + } + if !buf.ends_with("\n\n") { + buf.push('\n'); + } + } + let ts = Utc::now().format("%Y-%m-%dT%H:%M:%SZ"); + buf.push_str(&format!("## {}\n{}\n", ts, trimmed_note)); + + let final_content = enforce_size_cap(&buf, MEMORY_BYTE_CAP); + let path = get_memory_path(app, connection_id)?; + fs::write(&path, final_content)?; + Ok(()) +} + +/// Trim the file from the *oldest* note (top) until it fits within `cap` bytes. +/// Always preserves the trailing notes (the most recent observations). Keeps +/// the leading `# Memory\n\n` header if present. +pub(crate) fn enforce_size_cap(content: &str, cap: usize) -> String { + if content.len() <= cap { + return content.to_string(); + } + + let header = if content.starts_with("# Memory") { + match content.find("\n## ") { + Some(pos) => &content[..pos + 1], + None => "# Memory\n\n", + } + } else { + "" + }; + + // Split into note blocks by "\n## " marker. + // First block (after header) might lack the leading "## " — handle uniformly. + let body_start = header.len(); + let body = &content[body_start..]; + + let mut blocks: Vec<&str> = Vec::new(); + let mut idx = 0; + while idx < body.len() { + // Find the next "\n## " starting at idx; if not found, the rest is one block. + let rel = body[idx..].find("\n## "); + match rel { + Some(r) => { + blocks.push(&body[idx..idx + r + 1]); // include trailing newline before next block + idx = idx + r + 1; // start of "## " + } + None => { + blocks.push(&body[idx..]); + break; + } + } + } + + // Drop blocks from the front until total fits. + let mut current_size: usize = header.len() + blocks.iter().map(|b| b.len()).sum::(); + let mut start = 0usize; + while current_size > cap && start < blocks.len() { + current_size -= blocks[start].len(); + start += 1; + } + + let mut out = String::with_capacity(current_size); + out.push_str(header); + for b in &blocks[start..] { + out.push_str(b); + } + out +} + +#[tauri::command] +pub async fn get_memory(app: AppHandle, connection_id: String) -> TuskResult { + read_memory_core(&app, &connection_id) +} + +#[tauri::command] +pub async fn save_memory( + app: AppHandle, + connection_id: String, + content: String, +) -> TuskResult<()> { + write_memory_core(&app, &connection_id, &content) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cap_passthrough_under_limit() { + let small = "# Memory\n\n## 2026-01-01T00:00:00Z\nshort note\n"; + assert_eq!(enforce_size_cap(small, MEMORY_BYTE_CAP), small); + } + + #[test] + fn cap_drops_oldest_blocks() { + // 3 blocks of ~6KB each -> 18KB total > 16KB cap + let block_body = "x".repeat(6000); + let content = format!( + "# Memory\n\n## 2026-01-01T00:00:00Z\n{body}\n## 2026-02-01T00:00:00Z\n{body}\n## 2026-03-01T00:00:00Z\n{body}\n", + body = block_body + ); + assert!(content.len() > MEMORY_BYTE_CAP); + let trimmed = enforce_size_cap(&content, MEMORY_BYTE_CAP); + assert!(trimmed.len() <= MEMORY_BYTE_CAP); + // Most recent block must survive. + assert!(trimmed.contains("2026-03-01T00:00:00Z")); + // Oldest must be dropped. + assert!(!trimmed.contains("2026-01-01T00:00:00Z")); + // Header preserved. + assert!(trimmed.starts_with("# Memory")); + } + + #[test] + fn cap_keeps_only_latest_when_single_block_huge() { + let block_body = "y".repeat(20_000); + let content = format!( + "# Memory\n\n## 2026-01-01T00:00:00Z\n{}\n", + block_body + ); + let trimmed = enforce_size_cap(&content, MEMORY_BYTE_CAP); + // Even after dropping that single block we keep at least the header, + // so the result is just the header (or close to it). + assert!(trimmed.starts_with("# Memory")); + assert!(trimmed.len() <= MEMORY_BYTE_CAP); + } + + #[test] + fn sanitize_strips_path_chars() { + assert_eq!(sanitize_connection_id("abc/../etc"), "abc____etc"); + assert_eq!( + sanitize_connection_id("cf9feefd-59ab-4a7c"), + "cf9feefd-59ab-4a7c" + ); + } +} diff --git a/src-tauri/src/commands/mod.rs b/src-tauri/src/commands/mod.rs index 8bf6058..e290b41 100644 --- a/src-tauri/src/commands/mod.rs +++ b/src-tauri/src/commands/mod.rs @@ -1,13 +1,12 @@ pub mod ai; +pub mod chat; +pub mod chat_tools; pub mod connections; pub mod data; -pub mod docker; pub mod export; pub mod history; -pub mod lookup; -pub mod management; +pub mod memory; pub mod queries; pub mod saved_queries; pub mod schema; pub mod settings; -pub mod snapshot; diff --git a/src-tauri/src/commands/queries.rs b/src-tauri/src/commands/queries.rs index b94526f..a92e015 100644 --- a/src-tauri/src/commands/queries.rs +++ b/src-tauri/src/commands/queries.rs @@ -1,6 +1,7 @@ +use crate::db::sql_guard::ensure_readonly_sql; use crate::error::{TuskError, TuskResult}; use crate::models::query_result::QueryResult; -use crate::state::AppState; +use crate::state::{AppState, DbFlavor}; use serde_json::Value; use sqlx::postgres::PgRow; use sqlx::{Column, Row, TypeInfo}; @@ -81,6 +82,16 @@ pub async fn execute_query_core( sql: &str, ) -> TuskResult { let read_only = state.is_read_only(connection_id).await; + let flavor = state.get_flavor(connection_id).await; + + if read_only { + ensure_readonly_sql(sql)?; + } + + if matches!(flavor, DbFlavor::ClickHouse) { + let client = state.get_ch_client(connection_id).await?; + return client.execute_query(sql, read_only).await; + } let pools = state.pools.read().await; let pool = pools @@ -106,7 +117,7 @@ pub async fn execute_query_core( .await .map_err(TuskError::Database)? }; - let execution_time_ms = start.elapsed().as_millis(); + let execution_time_ms = start.elapsed().as_millis() as u64; let mut columns = Vec::new(); let mut types = Vec::new(); diff --git a/src-tauri/src/commands/saved_queries.rs b/src-tauri/src/commands/saved_queries.rs index ce233fd..951b029 100644 --- a/src-tauri/src/commands/saved_queries.rs +++ b/src-tauri/src/commands/saved_queries.rs @@ -12,12 +12,11 @@ fn get_saved_queries_path(app: &AppHandle) -> TuskResult { Ok(dir.join("saved_queries.json")) } -#[tauri::command] -pub async fn list_saved_queries( - app: AppHandle, - search: Option, +pub(crate) async fn list_saved_queries_core( + app: &AppHandle, + search: Option<&str>, ) -> TuskResult> { - let path = get_saved_queries_path(&app)?; + let path = get_saved_queries_path(app)?; if !path.exists() { return Ok(vec![]); } @@ -27,7 +26,7 @@ pub async fn list_saved_queries( let filtered: Vec = entries .into_iter() .filter(|e| { - if let Some(ref s) = search { + if let Some(s) = search { let lower = s.to_lowercase(); e.name.to_lowercase().contains(&lower) || e.sql.to_lowercase().contains(&lower) } else { @@ -39,9 +38,8 @@ pub async fn list_saved_queries( Ok(filtered) } -#[tauri::command] -pub async fn save_query(app: AppHandle, query: SavedQuery) -> TuskResult<()> { - let path = get_saved_queries_path(&app)?; +pub(crate) async fn save_query_core(app: &AppHandle, query: SavedQuery) -> TuskResult<()> { + let path = get_saved_queries_path(app)?; let mut entries = if path.exists() { let data = fs::read_to_string(&path)?; serde_json::from_str::>(&data).unwrap_or_default() @@ -56,6 +54,19 @@ pub async fn save_query(app: AppHandle, query: SavedQuery) -> TuskResult<()> { Ok(()) } +#[tauri::command] +pub async fn list_saved_queries( + app: AppHandle, + search: Option, +) -> TuskResult> { + list_saved_queries_core(&app, search.as_deref()).await +} + +#[tauri::command] +pub async fn save_query(app: AppHandle, query: SavedQuery) -> TuskResult<()> { + save_query_core(&app, query).await +} + #[tauri::command] pub async fn delete_saved_query(app: AppHandle, id: String) -> TuskResult<()> { let path = get_saved_queries_path(&app)?; diff --git a/src-tauri/src/commands/schema.rs b/src-tauri/src/commands/schema.rs index e68e071..23e02f5 100644 --- a/src-tauri/src/commands/schema.rs +++ b/src-tauri/src/commands/schema.rs @@ -1,20 +1,53 @@ use crate::error::{TuskError, TuskResult}; use crate::models::schema::{ - ColumnDetail, ColumnInfo, ConstraintInfo, ErdColumn, ErdData, ErdRelationship, ErdTable, - IndexInfo, SchemaObject, TriggerInfo, + ColumnDetail, ColumnInfo, ConstraintInfo, IndexInfo, SchemaObject, TriggerInfo, }; use crate::state::{AppState, DbFlavor}; +use serde_json::Value; use sqlx::Row; use std::collections::HashMap; use std::sync::Arc; use tauri::State; -#[tauri::command] -pub async fn list_databases( - state: State<'_, Arc>, - connection_id: String, -) -> TuskResult> { - let pool = state.get_pool(&connection_id).await?; +fn ch_string_literal(s: &str) -> String { + let escaped = s.replace('\\', "\\\\").replace('\'', "\\'"); + format!("'{}'", escaped) +} + +fn ch_obj_string(obj: &serde_json::Map, key: &str) -> Option { + obj.get(key).and_then(|v| match v { + Value::String(s) => Some(s.clone()), + Value::Number(n) => Some(n.to_string()), + _ => None, + }) +} + +fn ch_obj_i64(obj: &serde_json::Map, key: &str) -> Option { + obj.get(key).and_then(|v| match v { + Value::Number(n) => n.as_i64(), + Value::String(s) => s.parse::().ok(), + _ => None, + }) +} + +pub async fn list_databases_core(state: &AppState, connection_id: &str) -> TuskResult> { + let flavor = state.get_flavor(connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + let client = state.get_ch_client(connection_id).await?; + let rows = client + .fetch_objects( + "SELECT name FROM system.databases \ + WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \ + ORDER BY name", + ) + .await?; + return Ok(rows + .iter() + .filter_map(|o| ch_obj_string(o, "name")) + .collect()); + } + + let pool = state.get_pool(connection_id).await?; let rows = sqlx::query( "SELECT datname FROM pg_database \ @@ -28,10 +61,24 @@ pub async fn list_databases( Ok(rows.iter().map(|r| r.get::(0)).collect()) } +#[tauri::command] +pub async fn list_databases( + state: State<'_, Arc>, + connection_id: String, +) -> TuskResult> { + list_databases_core(&state, &connection_id).await +} + pub async fn list_schemas_core(state: &AppState, connection_id: &str) -> TuskResult> { + let flavor = state.get_flavor(connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + // ClickHouse has no schema layer — surface the active database as a virtual schema. + let client = state.get_ch_client(connection_id).await?; + return Ok(vec![client.database.clone()]); + } + let pool = state.get_pool(connection_id).await?; - let flavor = state.get_flavor(connection_id).await; let sql = if flavor == DbFlavor::Greenplum { "SELECT schema_name FROM information_schema.schemata \ WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ @@ -63,6 +110,29 @@ pub async fn list_tables_core( connection_id: &str, schema: &str, ) -> TuskResult> { + let flavor = state.get_flavor(connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + let client = state.get_ch_client(connection_id).await?; + let escaped = ch_string_literal(schema); + let sql = format!( + "SELECT name, total_rows, total_bytes FROM system.tables \ + WHERE database = {} AND engine NOT LIKE '%View' \ + ORDER BY name", + escaped + ); + let rows = client.fetch_objects(&sql).await?; + return Ok(rows + .iter() + .map(|o| SchemaObject { + name: ch_obj_string(o, "name").unwrap_or_default(), + object_type: "table".to_string(), + schema: schema.to_string(), + row_count: ch_obj_i64(o, "total_rows"), + size_bytes: ch_obj_i64(o, "total_bytes"), + }) + .collect()); + } + let pool = state.get_pool(connection_id).await?; let rows = sqlx::query( @@ -107,6 +177,28 @@ pub async fn list_views( connection_id: String, schema: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + let client = state.get_ch_client(&connection_id).await?; + let sql = format!( + "SELECT name FROM system.tables \ + WHERE database = {} AND engine LIKE '%View' \ + ORDER BY name", + ch_string_literal(&schema) + ); + let rows = client.fetch_objects(&sql).await?; + return Ok(rows + .iter() + .map(|o| SchemaObject { + name: ch_obj_string(o, "name").unwrap_or_default(), + object_type: "view".to_string(), + schema: schema.clone(), + row_count: None, + size_bytes: None, + }) + .collect()); + } + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( @@ -137,6 +229,11 @@ pub async fn list_functions( connection_id: String, schema: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + // ClickHouse functions are global, not schema-scoped — surface empty here. + return Ok(vec![]); + } let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( @@ -167,6 +264,10 @@ pub async fn list_indexes( connection_id: String, schema: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return Ok(vec![]); + } let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( @@ -197,6 +298,10 @@ pub async fn list_sequences( connection_id: String, schema: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return Ok(vec![]); + } let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( @@ -227,6 +332,36 @@ pub async fn get_table_columns_core( schema: &str, table: &str, ) -> TuskResult> { + let flavor = state.get_flavor(connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + let client = state.get_ch_client(connection_id).await?; + let sql = format!( + "SELECT name, type, default_expression, is_in_primary_key, comment, position \ + FROM system.columns WHERE database = {} AND table = {} \ + ORDER BY position", + ch_string_literal(schema), + ch_string_literal(table) + ); + let rows = client.fetch_objects(&sql).await?; + return Ok(rows + .iter() + .map(|o| { + let type_str = ch_obj_string(o, "type").unwrap_or_default(); + let is_nullable = type_str.starts_with("Nullable("); + ColumnInfo { + name: ch_obj_string(o, "name").unwrap_or_default(), + data_type: type_str, + is_nullable, + column_default: ch_obj_string(o, "default_expression"), + ordinal_position: ch_obj_i64(o, "position").unwrap_or(0) as i32, + character_maximum_length: None, + is_primary_key: ch_obj_i64(o, "is_in_primary_key").unwrap_or(0) != 0, + comment: ch_obj_string(o, "comment"), + } + }) + .collect()); + } + let pool = state.get_pool(connection_id).await?; let rows = sqlx::query( @@ -296,6 +431,10 @@ pub async fn get_table_constraints( schema: String, table: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return Ok(vec![]); + } let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( @@ -372,6 +511,10 @@ pub async fn get_table_indexes( schema: String, table: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return Ok(vec![]); + } let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( @@ -410,6 +553,25 @@ pub async fn get_completion_schema( connection_id: String, ) -> TuskResult>>> { let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + let client = state.get_ch_client(&connection_id).await?; + let sql = format!( + "SELECT database, table, name FROM system.columns \ + WHERE database = {} \ + ORDER BY database, table, position", + ch_string_literal(&client.database) + ); + let rows = client.fetch_objects(&sql).await?; + let mut result: HashMap>> = HashMap::new(); + for row in rows { + let db = ch_obj_string(&row, "database").unwrap_or_default(); + let table = ch_obj_string(&row, "table").unwrap_or_default(); + let column = ch_obj_string(&row, "name").unwrap_or_default(); + result.entry(db).or_default().entry(table).or_default().push(column); + } + return Ok(result); + } + let pool = state.get_pool(&connection_id).await?; let sql = if flavor == DbFlavor::Greenplum { @@ -454,6 +616,19 @@ pub async fn get_column_details( table: String, ) -> TuskResult> { let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + let columns = get_table_columns_core(&state, &connection_id, &schema, &table).await?; + return Ok(columns + .into_iter() + .map(|c| ColumnDetail { + column_name: c.name, + data_type: c.data_type, + is_nullable: c.is_nullable, + column_default: c.column_default, + is_identity: false, + }) + .collect()); + } let pool = state.get_pool(&connection_id).await?; let sql = if flavor == DbFlavor::Greenplum { @@ -500,6 +675,10 @@ pub async fn get_table_triggers( schema: String, table: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; + if matches!(flavor, DbFlavor::ClickHouse) { + return Ok(vec![]); + } let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( @@ -547,127 +726,3 @@ pub async fn get_table_triggers( .collect()) } -#[tauri::command] -pub async fn get_schema_erd( - state: State<'_, Arc>, - connection_id: String, - schema: String, -) -> TuskResult { - let pool = state.get_pool(&connection_id).await?; - - // Get all tables with columns - let col_rows = sqlx::query( - "SELECT \ - c.table_name, \ - c.column_name, \ - c.data_type, \ - c.is_nullable = 'YES' AS is_nullable, \ - COALESCE(( \ - SELECT true FROM pg_constraint con \ - JOIN pg_class cl ON cl.oid = con.conrelid \ - JOIN pg_namespace ns ON ns.oid = cl.relnamespace \ - WHERE con.contype = 'p' \ - AND ns.nspname = $1 AND cl.relname = c.table_name \ - AND EXISTS ( \ - SELECT 1 FROM unnest(con.conkey) k \ - JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = k \ - WHERE a.attname = c.column_name \ - ) \ - LIMIT 1 \ - ), false) AS is_pk \ - FROM information_schema.columns c \ - JOIN information_schema.tables t \ - ON t.table_schema = c.table_schema AND t.table_name = c.table_name \ - WHERE c.table_schema = $1 AND t.table_type = 'BASE TABLE' \ - ORDER BY c.table_name, c.ordinal_position", - ) - .bind(&schema) - .fetch_all(&pool) - .await - .map_err(TuskError::Database)?; - - // Build tables map - let mut tables_map: HashMap = HashMap::new(); - for row in &col_rows { - let table_name: String = row.get(0); - let entry = tables_map - .entry(table_name.clone()) - .or_insert_with(|| ErdTable { - schema: schema.clone(), - name: table_name, - columns: Vec::new(), - }); - entry.columns.push(ErdColumn { - name: row.get(1), - data_type: row.get(2), - is_nullable: row.get(3), - is_primary_key: row.get(4), - }); - } - let tables: Vec = tables_map.into_values().collect(); - - // Get all FK relationships - let fk_rows = sqlx::query( - "SELECT \ - c.conname AS constraint_name, \ - src_ns.nspname AS source_schema, \ - src_cl.relname AS source_table, \ - ARRAY( \ - SELECT a.attname FROM unnest(c.conkey) WITH ORDINALITY AS k(attnum, ord) \ - JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = k.attnum \ - ORDER BY k.ord \ - )::text[] AS source_columns, \ - ref_ns.nspname AS target_schema, \ - ref_cl.relname AS target_table, \ - ARRAY( \ - SELECT a.attname FROM unnest(c.confkey) WITH ORDINALITY AS k(attnum, ord) \ - JOIN pg_attribute a ON a.attrelid = c.confrelid AND a.attnum = k.attnum \ - ORDER BY k.ord \ - )::text[] AS target_columns, \ - CASE c.confupdtype \ - WHEN 'a' THEN 'NO ACTION' \ - WHEN 'r' THEN 'RESTRICT' \ - WHEN 'c' THEN 'CASCADE' \ - WHEN 'n' THEN 'SET NULL' \ - WHEN 'd' THEN 'SET DEFAULT' \ - END AS update_rule, \ - CASE c.confdeltype \ - WHEN 'a' THEN 'NO ACTION' \ - WHEN 'r' THEN 'RESTRICT' \ - WHEN 'c' THEN 'CASCADE' \ - WHEN 'n' THEN 'SET NULL' \ - WHEN 'd' THEN 'SET DEFAULT' \ - END AS delete_rule \ - FROM pg_constraint c \ - JOIN pg_class src_cl ON src_cl.oid = c.conrelid \ - JOIN pg_namespace src_ns ON src_ns.oid = src_cl.relnamespace \ - JOIN pg_class ref_cl ON ref_cl.oid = c.confrelid \ - JOIN pg_namespace ref_ns ON ref_ns.oid = ref_cl.relnamespace \ - WHERE c.contype = 'f' AND src_ns.nspname = $1 \ - ORDER BY c.conname", - ) - .bind(&schema) - .fetch_all(&pool) - .await - .map_err(TuskError::Database)?; - - let relationships: Vec = fk_rows - .iter() - .map(|r| ErdRelationship { - constraint_name: r.get(0), - source_schema: r.get(1), - source_table: r.get(2), - source_columns: r.get(3), - target_schema: r.get(4), - target_table: r.get(5), - target_columns: r.get(6), - update_rule: r.get(7), - delete_rule: r.get(8), - }) - .collect(); - - Ok(ErdData { - tables, - relationships, - }) -} diff --git a/src-tauri/src/commands/settings.rs b/src-tauri/src/commands/settings.rs index 5324d58..f621f57 100644 --- a/src-tauri/src/commands/settings.rs +++ b/src-tauri/src/commands/settings.rs @@ -1,6 +1,6 @@ use crate::error::{TuskError, TuskResult}; use crate::mcp; -use crate::models::settings::{AppSettings, DockerHost, McpStatus}; +use crate::models::settings::{AppSettings, McpStatus}; use crate::state::AppState; use std::fs; use std::sync::Arc; @@ -36,15 +36,6 @@ pub async fn save_app_settings( let data = serde_json::to_string_pretty(&settings)?; fs::write(&path, data)?; - // Apply docker host setting - { - let mut docker_host = state.docker_host.write().await; - *docker_host = match settings.docker.host { - DockerHost::Remote => settings.docker.remote_url.clone(), - DockerHost::Local => None, - }; - } - // Apply MCP setting: restart or stop let is_running = *state.mcp_running.read().await; diff --git a/src-tauri/src/commands/snapshot.rs b/src-tauri/src/commands/snapshot.rs deleted file mode 100644 index 03d0d25..0000000 --- a/src-tauri/src/commands/snapshot.rs +++ /dev/null @@ -1,362 +0,0 @@ -use crate::commands::ai::fetch_foreign_keys_raw; -use crate::commands::data::bind_json_value; -use crate::commands::queries::pg_value_to_json; -use crate::error::{TuskError, TuskResult}; -use crate::models::snapshot::{ - CreateSnapshotParams, RestoreSnapshotParams, Snapshot, SnapshotMetadata, SnapshotProgress, - SnapshotTableData, SnapshotTableMeta, -}; -use crate::state::AppState; -use crate::utils::{escape_ident, topological_sort_tables}; -use serde_json::Value; -use sqlx::{Column, Row, TypeInfo}; -use std::fs; -use std::sync::Arc; -use tauri::{AppHandle, Emitter, Manager, State}; - -#[tauri::command] -pub async fn create_snapshot( - app: AppHandle, - state: State<'_, Arc>, - params: CreateSnapshotParams, - snapshot_id: String, - file_path: String, -) -> TuskResult { - let pool = state.get_pool(¶ms.connection_id).await?; - - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "preparing".to_string(), - percent: 5, - message: "Preparing snapshot...".to_string(), - detail: None, - }, - ); - - let mut target_tables: Vec<(String, String)> = params - .tables - .iter() - .map(|t| (t.schema.clone(), t.table.clone())) - .collect(); - - // Fetch FK info once — used for both dependency expansion and topological sort - let fk_rows = fetch_foreign_keys_raw(&pool).await?; - - if params.include_dependencies { - for fk in &fk_rows { - if target_tables - .iter() - .any(|(s, t)| s == &fk.schema && t == &fk.table) - { - let parent = (fk.ref_schema.clone(), fk.ref_table.clone()); - if !target_tables.contains(&parent) { - target_tables.push(parent); - } - } - } - } - - // FK-based topological sort - let fk_edges: Vec<(String, String, String, String)> = fk_rows - .iter() - .map(|fk| { - ( - fk.schema.clone(), - fk.table.clone(), - fk.ref_schema.clone(), - fk.ref_table.clone(), - ) - }) - .collect(); - let sorted_tables = topological_sort_tables(&fk_edges, &target_tables); - - let mut tx = pool.begin().await.map_err(TuskError::Database)?; - sqlx::query("SET TRANSACTION READ ONLY") - .execute(&mut *tx) - .await - .map_err(TuskError::Database)?; - - let total_tables = sorted_tables.len(); - let mut snapshot_tables: Vec = Vec::new(); - let mut table_metas: Vec = Vec::new(); - let mut total_rows: u64 = 0; - - for (i, (schema, table)) in sorted_tables.iter().enumerate() { - let percent = (10 + (i * 80 / total_tables.max(1))).min(90) as u8; - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "exporting".to_string(), - percent, - message: format!("Exporting {}.{}...", schema, table), - detail: None, - }, - ); - - let qualified = format!("{}.{}", escape_ident(schema), escape_ident(table)); - let sql = format!("SELECT * FROM {}", qualified); - let rows = sqlx::query(&sql) - .fetch_all(&mut *tx) - .await - .map_err(TuskError::Database)?; - - let mut columns = Vec::new(); - let mut column_types = Vec::new(); - - if let Some(first) = rows.first() { - for col in first.columns() { - columns.push(col.name().to_string()); - column_types.push(col.type_info().name().to_string()); - } - } - - let data_rows: Vec> = rows - .iter() - .map(|row| { - (0..columns.len()) - .map(|i| pg_value_to_json(row, i)) - .collect() - }) - .collect(); - - let row_count = data_rows.len() as u64; - total_rows += row_count; - - table_metas.push(SnapshotTableMeta { - schema: schema.clone(), - table: table.clone(), - row_count, - columns: columns.clone(), - column_types: column_types.clone(), - }); - - snapshot_tables.push(SnapshotTableData { - schema: schema.clone(), - table: table.clone(), - columns, - column_types, - rows: data_rows, - }); - } - - tx.rollback().await.map_err(TuskError::Database)?; - - let metadata = SnapshotMetadata { - id: snapshot_id.clone(), - name: params.name.clone(), - created_at: chrono::Utc::now().to_rfc3339(), - connection_name: String::new(), - database: String::new(), - tables: table_metas, - total_rows, - file_size_bytes: 0, - version: 1, - }; - - let snapshot = Snapshot { - metadata: metadata.clone(), - tables: snapshot_tables, - }; - - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "saving".to_string(), - percent: 95, - message: "Saving snapshot file...".to_string(), - detail: None, - }, - ); - - let json = serde_json::to_string_pretty(&snapshot)?; - let file_size = json.len() as u64; - fs::write(&file_path, json)?; - - let mut final_metadata = metadata; - final_metadata.file_size_bytes = file_size; - - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "done".to_string(), - percent: 100, - message: "Snapshot created successfully".to_string(), - detail: Some(format!("{} rows, {} tables", total_rows, total_tables)), - }, - ); - - Ok(final_metadata) -} - -#[tauri::command] -pub async fn restore_snapshot( - app: AppHandle, - state: State<'_, Arc>, - params: RestoreSnapshotParams, - snapshot_id: String, -) -> TuskResult { - if state.is_read_only(¶ms.connection_id).await { - return Err(TuskError::ReadOnly); - } - - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "reading".to_string(), - percent: 5, - message: "Reading snapshot file...".to_string(), - detail: None, - }, - ); - - let data = fs::read_to_string(¶ms.file_path)?; - let snapshot: Snapshot = serde_json::from_str(&data)?; - - let pool = state.get_pool(¶ms.connection_id).await?; - let mut tx = pool.begin().await.map_err(TuskError::Database)?; - - sqlx::query("SET CONSTRAINTS ALL DEFERRED") - .execute(&mut *tx) - .await - .map_err(TuskError::Database)?; - - // TRUNCATE in reverse order (children first) - if params.truncate_before_restore { - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "truncating".to_string(), - percent: 15, - message: "Truncating existing data...".to_string(), - detail: None, - }, - ); - - for table_data in snapshot.tables.iter().rev() { - let qualified = format!( - "{}.{}", - escape_ident(&table_data.schema), - escape_ident(&table_data.table) - ); - let truncate_sql = format!("TRUNCATE {} CASCADE", qualified); - sqlx::query(&truncate_sql) - .execute(&mut *tx) - .await - .map_err(TuskError::Database)?; - } - } - - // INSERT in forward order (parents first) - let total_tables = snapshot.tables.len(); - let mut total_inserted: u64 = 0; - - for (i, table_data) in snapshot.tables.iter().enumerate() { - if table_data.columns.is_empty() || table_data.rows.is_empty() { - continue; - } - - let percent = (20 + (i * 75 / total_tables.max(1))).min(95) as u8; - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "inserting".to_string(), - percent, - message: format!("Restoring {}.{}...", table_data.schema, table_data.table), - detail: Some(format!("{} rows", table_data.rows.len())), - }, - ); - - let qualified = format!( - "{}.{}", - escape_ident(&table_data.schema), - escape_ident(&table_data.table) - ); - let col_list: Vec = table_data.columns.iter().map(|c| escape_ident(c)).collect(); - let placeholders: Vec = (1..=table_data.columns.len()) - .map(|i| format!("${}", i)) - .collect(); - - let sql = format!( - "INSERT INTO {} ({}) VALUES ({})", - qualified, - col_list.join(", "), - placeholders.join(", ") - ); - - // Chunked insert - for row in &table_data.rows { - let mut query = sqlx::query(&sql); - for val in row { - query = bind_json_value(query, val); - } - query.execute(&mut *tx).await.map_err(TuskError::Database)?; - total_inserted += 1; - } - } - - tx.commit().await.map_err(TuskError::Database)?; - - let _ = app.emit( - "snapshot-progress", - SnapshotProgress { - snapshot_id: snapshot_id.clone(), - stage: "done".to_string(), - percent: 100, - message: "Restore completed successfully".to_string(), - detail: Some(format!("{} rows restored", total_inserted)), - }, - ); - - state.invalidate_schema_cache(¶ms.connection_id).await; - - Ok(total_inserted) -} - -#[tauri::command] -pub async fn list_snapshots(app: AppHandle) -> TuskResult> { - let dir = app - .path() - .app_data_dir() - .map_err(|e| TuskError::Config(e.to_string()))? - .join("snapshots"); - - if !dir.exists() { - return Ok(Vec::new()); - } - - let mut snapshots = Vec::new(); - - for entry in fs::read_dir(&dir)? { - let entry = entry?; - let path = entry.path(); - if path.extension().map(|e| e == "json").unwrap_or(false) { - if let Ok(data) = fs::read_to_string(&path) { - if let Ok(snapshot) = serde_json::from_str::(&data) { - let mut meta = snapshot.metadata; - meta.file_size_bytes = entry.metadata().map(|m| m.len()).unwrap_or(0); - snapshots.push(meta); - } - } - } - } - - snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at)); - Ok(snapshots) -} - -#[tauri::command] -pub async fn read_snapshot_metadata(file_path: String) -> TuskResult { - let data = fs::read_to_string(&file_path)?; - let snapshot: Snapshot = serde_json::from_str(&data)?; - let mut meta = snapshot.metadata; - meta.file_size_bytes = fs::metadata(&file_path).map(|m| m.len()).unwrap_or(0); - Ok(meta) -} diff --git a/src-tauri/src/db/clickhouse.rs b/src-tauri/src/db/clickhouse.rs new file mode 100644 index 0000000..687a984 --- /dev/null +++ b/src-tauri/src/db/clickhouse.rs @@ -0,0 +1,168 @@ +use crate::error::{TuskError, TuskResult}; +use crate::models::query_result::QueryResult; +use serde::Deserialize; +use serde_json::{Map, Value}; +use std::sync::LazyLock; +use std::time::{Duration, Instant}; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(120); + +fn http_client() -> &'static reqwest::Client { + static CLIENT: LazyLock = LazyLock::new(|| { + reqwest::Client::builder() + .connect_timeout(Duration::from_secs(5)) + .timeout(DEFAULT_TIMEOUT) + .build() + .unwrap_or_default() + }); + &CLIENT +} + +#[derive(Debug, Clone)] +pub struct ChClient { + pub base_url: String, + pub user: String, + pub password: String, + pub database: String, +} + +impl ChClient { + pub fn new(host: &str, port: u16, secure: bool, user: &str, password: &str, database: &str) -> Self { + let scheme = if secure { "https" } else { "http" }; + let base_url = format!("{}://{}:{}", scheme, host, port); + Self { + base_url, + user: user.to_string(), + password: password.to_string(), + database: database.to_string(), + } + } + + fn endpoint(&self, database: Option<&str>, format: Option<&str>, read_only: bool) -> String { + let db = database.unwrap_or(&self.database); + let mut params = vec![ + format!("database={}", urlencode(db)), + format!("user={}", urlencode(&self.user)), + ]; + if !self.password.is_empty() { + params.push(format!("password={}", urlencode(&self.password))); + } + if let Some(fmt) = format { + params.push(format!("default_format={}", urlencode(fmt))); + } + if read_only { + params.push("readonly=1".to_string()); + } + format!("{}/?{}", self.base_url, params.join("&")) + } + + /// Execute SQL and return raw response body. + pub async fn execute_raw(&self, sql: &str, format: Option<&str>, read_only: bool) -> TuskResult { + let url = self.endpoint(None, format, read_only); + let resp = http_client() + .post(&url) + .body(sql.to_string()) + .send() + .await + .map_err(|e| TuskError::Custom(format!("ClickHouse request failed: {}", e)))?; + let status = resp.status(); + let body = resp + .text() + .await + .map_err(|e| TuskError::Custom(format!("Failed to read ClickHouse response: {}", e)))?; + if !status.is_success() { + return Err(TuskError::Custom(format!( + "ClickHouse error ({}): {}", + status, + body.trim() + ))); + } + Ok(body) + } + + /// Test connection by running `SELECT 1` and return the server version. + pub async fn ping(&self) -> TuskResult { + // Use raw FORMAT TabSeparated to fetch version + let body = self.execute_raw("SELECT version()", Some("TabSeparated"), false).await?; + Ok(body.trim().to_string()) + } + + /// Execute SQL and parse rows via JSONCompact to preserve column metadata + types. + pub async fn execute_query(&self, sql: &str, read_only: bool) -> TuskResult { + let start = Instant::now(); + let body = self.execute_raw(sql, Some("JSONCompact"), read_only).await?; + let execution_time_ms = start.elapsed().as_millis() as u64; + + // Empty body for statements without result set (DDL etc.) — return zero rows + if body.trim().is_empty() { + return Ok(QueryResult { + columns: vec![], + types: vec![], + rows: vec![], + row_count: 0, + execution_time_ms, + }); + } + + let parsed: ChJsonCompactResponse = serde_json::from_str(&body).map_err(|e| { + TuskError::Custom(format!( + "Failed to parse ClickHouse JSONCompact response: {} (body head: {})", + e, + body.chars().take(200).collect::() + )) + })?; + + let columns: Vec = parsed.meta.iter().map(|m| m.name.clone()).collect(); + let types: Vec = parsed.meta.iter().map(|m| m.r#type.clone()).collect(); + let row_count = parsed.data.len(); + + Ok(QueryResult { + columns, + types, + rows: parsed.data, + row_count, + execution_time_ms, + }) + } + + /// Execute SQL expecting result rows as objects (for schema introspection helpers). + pub async fn fetch_objects(&self, sql: &str) -> TuskResult>> { + let body = self.execute_raw(sql, Some("JSONEachRow"), false).await?; + let mut out = Vec::new(); + for line in body.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + let value: Value = serde_json::from_str(line).map_err(|e| { + TuskError::Custom(format!("Failed to parse JSONEachRow line: {}", e)) + })?; + if let Value::Object(obj) = value { + out.push(obj); + } + } + Ok(out) + } +} + +#[derive(Debug, Deserialize)] +struct ChJsonCompactResponse { + meta: Vec, + data: Vec>, +} + +#[derive(Debug, Deserialize)] +struct ChMetaEntry { + name: String, + r#type: String, +} + +fn urlencode(s: &str) -> String { + s.chars() + .map(|c| match c { + ':' | '/' | '?' | '#' | '[' | ']' | '@' | '!' | '$' | '&' | '\'' | '(' | ')' | '*' + | '+' | ',' | ';' | '=' | '%' | ' ' => format!("%{:02X}", c as u8), + _ => c.to_string(), + }) + .collect() +} diff --git a/src-tauri/src/db/mod.rs b/src-tauri/src/db/mod.rs new file mode 100644 index 0000000..b1c28d2 --- /dev/null +++ b/src-tauri/src/db/mod.rs @@ -0,0 +1,2 @@ +pub mod clickhouse; +pub mod sql_guard; diff --git a/src-tauri/src/db/sql_guard.rs b/src-tauri/src/db/sql_guard.rs new file mode 100644 index 0000000..96a14cd --- /dev/null +++ b/src-tauri/src/db/sql_guard.rs @@ -0,0 +1,140 @@ +use crate::error::{TuskError, TuskResult}; + +/// Cross-flavor whitelist guard for read-only SQL execution. +/// Allows: SELECT, WITH ... SELECT, SHOW, EXPLAIN, DESCRIBE. +/// Rejects: INSERT, UPDATE, DELETE, ALTER, DROP, CREATE, TRUNCATE, +/// RENAME, GRANT, REVOKE, ATTACH, DETACH, OPTIMIZE, SYSTEM. +pub fn ensure_readonly_sql(sql: &str) -> TuskResult<()> { + let normalized = strip_leading_comments(sql).to_ascii_uppercase(); + let trimmed = normalized.trim(); + if trimmed.is_empty() { + return Err(TuskError::Validation("Empty SQL statement".into())); + } + let allowed_starts = ["SELECT", "WITH", "SHOW", "EXPLAIN", "DESCRIBE", "DESC ", "DESC\n", "VALUES"]; + let starts_ok = allowed_starts + .iter() + .any(|p| trimmed.starts_with(p) || trimmed == p.trim()); + if !starts_ok { + return Err(TuskError::ReadOnly); + } + + // Reject if any forbidden keyword appears as a top-level token + let forbidden = [ + "INSERT", "UPDATE", "DELETE", "ALTER", "DROP", "CREATE", "TRUNCATE", + "RENAME", "GRANT", "REVOKE", "ATTACH", "DETACH", "OPTIMIZE", "SYSTEM", + "REPLACE", "MERGE", + ]; + for kw in forbidden { + if contains_keyword(&normalized, kw) { + return Err(TuskError::ReadOnly); + } + } + Ok(()) +} + +fn strip_leading_comments(sql: &str) -> &str { + let mut s = sql.trim_start(); + loop { + if let Some(rest) = s.strip_prefix("--") { + // line comment — skip to newline + match rest.find('\n') { + Some(idx) => s = rest[idx + 1..].trim_start(), + None => return "", + } + } else if let Some(rest) = s.strip_prefix("/*") { + match rest.find("*/") { + Some(idx) => s = rest[idx + 2..].trim_start(), + None => return "", + } + } else { + break; + } + } + s +} + +fn contains_keyword(haystack: &str, kw: &str) -> bool { + let bytes = haystack.as_bytes(); + let needle = kw.as_bytes(); + let mut i = 0; + while i + needle.len() <= bytes.len() { + if &bytes[i..i + needle.len()] == needle { + let before_ok = i == 0 || !is_word_char(bytes[i - 1]); + let after = i + needle.len(); + let after_ok = after == bytes.len() || !is_word_char(bytes[after]); + if before_ok && after_ok { + return true; + } + } + i += 1; + } + false +} + +fn is_word_char(b: u8) -> bool { + b.is_ascii_alphanumeric() || b == b'_' +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn allows_select() { + assert!(ensure_readonly_sql("SELECT 1").is_ok()); + assert!(ensure_readonly_sql(" SELECT * FROM t").is_ok()); + assert!(ensure_readonly_sql("select 1").is_ok()); + } + + #[test] + fn allows_with_select() { + assert!(ensure_readonly_sql("WITH x AS (SELECT 1) SELECT * FROM x").is_ok()); + } + + #[test] + fn allows_explain_show() { + assert!(ensure_readonly_sql("EXPLAIN SELECT 1").is_ok()); + assert!(ensure_readonly_sql("SHOW TABLES").is_ok()); + assert!(ensure_readonly_sql("DESCRIBE t").is_ok()); + } + + #[test] + fn rejects_dml_ddl() { + assert!(ensure_readonly_sql("INSERT INTO t VALUES (1)").is_err()); + assert!(ensure_readonly_sql("UPDATE t SET a=1").is_err()); + assert!(ensure_readonly_sql("DELETE FROM t").is_err()); + assert!(ensure_readonly_sql("DROP TABLE t").is_err()); + assert!(ensure_readonly_sql("CREATE TABLE t(a int)").is_err()); + assert!(ensure_readonly_sql("TRUNCATE TABLE t").is_err()); + } + + #[test] + fn rejects_writable_cte() { + // PG writable CTE — looks like WITH but contains INSERT + assert!( + ensure_readonly_sql("WITH x AS (INSERT INTO t VALUES (1) RETURNING *) SELECT * FROM x") + .is_err() + ); + } + + #[test] + fn rejects_select_with_drop_chain() { + assert!(ensure_readonly_sql("SELECT 1; DROP TABLE t").is_err()); + } + + #[test] + fn allows_select_with_keyword_in_string() { + // Real-world: column names containing forbidden keywords should pass; string literals + // containing them should also pass. Our guard is conservative and may reject some + // legitimate queries — that is acceptable for the read-only safety net. + // This test documents the limitation: queries embedding "DROP" as a literal will be rejected. + // The user can disable read-only mode to run them. + } + + #[test] + fn strips_leading_comments() { + assert!(ensure_readonly_sql("-- comment\nSELECT 1").is_ok()); + assert!(ensure_readonly_sql("/* block */ SELECT 1").is_ok()); + assert!(ensure_readonly_sql("/* multi\nline */\n SELECT 1").is_ok()); + } +} diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index 7440a7d..10f5309 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -11,9 +11,6 @@ pub enum TuskError { #[error("Serialization error: {0}")] Serde(#[from] serde_json::Error), - #[error("Connection not found: {0}")] - ConnectionNotFound(String), - #[error("Not connected: {0}")] NotConnected(String), @@ -23,9 +20,6 @@ pub enum TuskError { #[error("AI error: {0}")] Ai(String), - #[error("Docker error: {0}")] - Docker(String), - #[error("Configuration error: {0}")] Config(String), diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 40655c9..c6dff21 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -1,11 +1,12 @@ mod commands; +mod db; mod error; mod mcp; mod models; mod state; mod utils; -use models::settings::{AppSettings, DockerHost}; +use models::settings::AppSettings; use state::AppState; use std::sync::Arc; use tauri::Manager; @@ -37,21 +38,9 @@ pub fn run() { AppSettings::default() }; - // Apply docker host from settings - let docker_host = match settings.docker.host { - DockerHost::Remote => settings.docker.remote_url.clone(), - DockerHost::Local => None, - }; - let mcp_enabled = settings.mcp.enabled; let mcp_port = settings.mcp.port; - // Set docker host synchronously (state is fresh, no contention) - let state_for_setup = state.clone(); - tauri::async_runtime::block_on(async { - *state_for_setup.docker_host.write().await = docker_host; - }); - if mcp_enabled { let shutdown_rx = state.mcp_shutdown_tx.subscribe(); let mcp_state = state.clone(); @@ -101,7 +90,6 @@ pub fn run() { commands::schema::get_completion_schema, commands::schema::get_column_details, commands::schema::get_table_triggers, - commands::schema::get_schema_erd, // data commands::data::get_table_data, commands::data::update_row, @@ -110,20 +98,6 @@ pub fn run() { // export commands::export::export_csv, commands::export::export_json, - // management - commands::management::get_database_info, - commands::management::create_database, - commands::management::drop_database, - commands::management::list_roles, - commands::management::create_role, - commands::management::alter_role, - commands::management::drop_role, - commands::management::get_table_privileges, - commands::management::grant_revoke, - commands::management::manage_role_membership, - commands::management::list_sessions, - commands::management::cancel_query, - commands::management::terminate_backend, // history commands::history::add_history_entry, commands::history::get_history, @@ -139,27 +113,11 @@ pub fn run() { commands::ai::generate_sql, commands::ai::explain_sql, commands::ai::fix_sql_error, - commands::ai::generate_validation_sql, - commands::ai::run_validation_rule, - commands::ai::suggest_validation_rules, - commands::ai::generate_test_data_preview, - commands::ai::insert_generated_data, - commands::ai::get_index_advisor_report, - commands::ai::apply_index_recommendation, - // snapshot - commands::snapshot::create_snapshot, - commands::snapshot::restore_snapshot, - commands::snapshot::list_snapshots, - commands::snapshot::read_snapshot_metadata, - // lookup - commands::lookup::entity_lookup, - // docker - commands::docker::check_docker, - commands::docker::list_tusk_containers, - commands::docker::clone_to_docker, - commands::docker::start_container, - commands::docker::stop_container, - commands::docker::remove_container, + // chat + commands::chat::chat_send, + // memory + commands::memory::get_memory, + commands::memory::save_memory, // settings commands::settings::get_app_settings, commands::settings::save_app_settings, diff --git a/src-tauri/src/models/ai.rs b/src-tauri/src/models/ai.rs index 00a5fba..bb78dbc 100644 --- a/src-tauri/src/models/ai.rs +++ b/src-tauri/src/models/ai.rs @@ -41,6 +41,8 @@ pub struct OllamaChatRequest { pub model: String, pub messages: Vec, pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, } #[derive(Debug, Deserialize)] @@ -57,130 +59,3 @@ pub struct OllamaTagsResponse { pub struct OllamaModel { pub name: String, } - -// --- Wave 1: Validation --- - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum ValidationStatus { - Pending, - Generating, - Running, - Passed, - Failed, - Error, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ValidationRule { - pub id: String, - pub description: String, - pub generated_sql: String, - pub status: ValidationStatus, - pub violation_count: u64, - pub sample_violations: Vec>, - pub violation_columns: Vec, - pub error: Option, -} - -// --- Wave 2: Data Generator --- - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GenerateDataParams { - pub connection_id: String, - pub schema: String, - pub table: String, - pub row_count: u32, - pub include_related: bool, - pub custom_instructions: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeneratedDataPreview { - pub tables: Vec, - pub insert_order: Vec, - pub total_rows: u32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GeneratedTableData { - pub schema: String, - pub table: String, - pub columns: Vec, - pub rows: Vec>, - pub row_count: u32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DataGenProgress { - pub gen_id: String, - pub stage: String, - pub percent: u8, - pub message: String, - pub detail: Option, -} - -// --- Wave 3A: Index Advisor --- - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TableStats { - pub schema: String, - pub table: String, - pub seq_scan: i64, - pub idx_scan: i64, - pub n_live_tup: i64, - pub table_size: String, - pub index_size: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct IndexStats { - pub schema: String, - pub table: String, - pub index_name: String, - pub idx_scan: i64, - pub index_size: String, - pub definition: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SlowQuery { - pub query: String, - pub calls: i64, - pub total_time_ms: f64, - pub mean_time_ms: f64, - pub rows: i64, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum IndexRecommendationType { - #[serde(rename = "create_index")] - Create, - #[serde(rename = "drop_index")] - Drop, - #[serde(rename = "replace_index")] - Replace, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct IndexRecommendation { - pub id: String, - pub recommendation_type: IndexRecommendationType, - pub table_schema: String, - pub table_name: String, - pub index_name: Option, - pub ddl: String, - pub rationale: String, - pub estimated_impact: String, - pub priority: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct IndexAdvisorReport { - pub table_stats: Vec, - pub index_stats: Vec, - pub slow_queries: Vec, - pub recommendations: Vec, - pub has_pg_stat_statements: bool, -} diff --git a/src-tauri/src/models/chat.rs b/src-tauri/src/models/chat.rs new file mode 100644 index 0000000..857a4d9 --- /dev/null +++ b/src-tauri/src/models/chat.rs @@ -0,0 +1,19 @@ +use crate::models::query_result::QueryResult; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "role", rename_all = "snake_case")] +pub enum ChatMessage { + User { id: String, text: String, created_at: i64 }, + Assistant { id: String, text: String, created_at: i64 }, + ToolCall { id: String, tool: String, input_json: String, created_at: i64 }, + ToolResult { + id: String, + tool: String, + is_error: bool, + text: Option, + result: Option, + created_at: i64, + }, +} + diff --git a/src-tauri/src/models/connection.rs b/src-tauri/src/models/connection.rs index c8a678c..6c00810 100644 --- a/src-tauri/src/models/connection.rs +++ b/src-tauri/src/models/connection.rs @@ -1,3 +1,4 @@ +use crate::state::DbFlavor; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -12,6 +13,17 @@ pub struct ConnectionConfig { pub ssl_mode: Option, pub color: Option, pub environment: Option, + /// Database flavor selected by the user. Defaults to PostgreSQL for backwards + /// compatibility with older `connections.json` files written before multi-DB support. + #[serde(default = "default_flavor")] + pub db_flavor: DbFlavor, + /// HTTPS for ClickHouse. Defaults to false. + #[serde(default)] + pub secure: bool, +} + +fn default_flavor() -> DbFlavor { + DbFlavor::PostgreSQL } impl ConnectionConfig { diff --git a/src-tauri/src/models/docker.rs b/src-tauri/src/models/docker.rs deleted file mode 100644 index 71f96f2..0000000 --- a/src-tauri/src/models/docker.rs +++ /dev/null @@ -1,57 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DockerStatus { - pub installed: bool, - pub daemon_running: bool, - pub version: Option, - pub error: Option, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -pub enum CloneMode { - SchemaOnly, - FullClone, - SampleData, -} - -#[derive(Debug, Clone, Deserialize)] -pub struct CloneToDockerParams { - pub source_connection_id: String, - pub source_database: String, - pub container_name: String, - pub pg_version: String, - pub host_port: Option, - pub clone_mode: CloneMode, - pub sample_rows: Option, - pub postgres_password: Option, -} - -#[derive(Debug, Clone, Serialize)] -pub struct CloneProgress { - pub clone_id: String, - pub stage: String, - pub percent: u8, - pub message: String, - pub detail: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TuskContainer { - pub container_id: String, - pub name: String, - pub status: String, - pub host_port: u16, - pub pg_version: String, - pub source_database: Option, - pub source_connection: Option, - pub created_at: Option, -} - -#[derive(Debug, Clone, Serialize)] -pub struct CloneResult { - pub container: TuskContainer, - pub connection_id: String, - pub connection_url: String, -} diff --git a/src-tauri/src/models/lookup.rs b/src-tauri/src/models/lookup.rs deleted file mode 100644 index 73c02ef..0000000 --- a/src-tauri/src/models/lookup.rs +++ /dev/null @@ -1,44 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LookupTableMatch { - pub schema: String, - pub table: String, - pub column_type: String, - pub columns: Vec, - pub types: Vec, - pub rows: Vec>, - pub row_count: usize, - pub total_count: i64, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LookupDatabaseResult { - pub database: String, - pub tables: Vec, - pub error: Option, - pub search_time_ms: u128, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EntityLookupResult { - pub column_name: String, - pub value: String, - pub databases: Vec, - pub total_databases_searched: usize, - pub total_tables_matched: usize, - pub total_rows_found: usize, - pub total_time_ms: u128, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct LookupProgress { - pub lookup_id: String, - pub database: String, - pub status: String, - pub tables_found: usize, - pub rows_found: usize, - pub error: Option, - pub completed: usize, - pub total: usize, -} diff --git a/src-tauri/src/models/management.rs b/src-tauri/src/models/management.rs deleted file mode 100644 index 24ede69..0000000 --- a/src-tauri/src/models/management.rs +++ /dev/null @@ -1,110 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub struct DatabaseInfo { - pub name: String, - pub owner: String, - pub encoding: String, - pub collation: String, - pub ctype: String, - pub tablespace: String, - pub connection_limit: i32, - pub size: String, - pub description: Option, -} - -#[derive(Debug, Deserialize)] -pub struct CreateDatabaseParams { - pub name: String, - pub owner: Option, - pub template: Option, - pub encoding: Option, - pub tablespace: Option, - pub connection_limit: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct RoleInfo { - pub name: String, - pub is_superuser: bool, - pub can_login: bool, - pub can_create_db: bool, - pub can_create_role: bool, - pub inherit: bool, - pub is_replication: bool, - pub connection_limit: i32, - pub password_set: bool, - pub valid_until: Option, - pub member_of: Vec, - pub members: Vec, - pub description: Option, -} - -#[derive(Debug, Deserialize)] -pub struct CreateRoleParams { - pub name: String, - pub password: Option, - pub login: bool, - pub superuser: bool, - pub createdb: bool, - pub createrole: bool, - pub inherit: bool, - pub replication: bool, - pub connection_limit: Option, - pub valid_until: Option, - pub in_roles: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct AlterRoleParams { - pub name: String, - pub password: Option, - pub login: Option, - pub superuser: Option, - pub createdb: Option, - pub createrole: Option, - pub inherit: Option, - pub replication: Option, - pub connection_limit: Option, - pub valid_until: Option, - pub rename_to: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct TablePrivilege { - pub grantee: String, - pub table_schema: String, - pub table_name: String, - pub privilege_type: String, - pub is_grantable: bool, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct SessionInfo { - pub pid: i32, - pub usename: Option, - pub datname: Option, - pub state: Option, - pub query: Option, - pub query_start: Option, - pub wait_event_type: Option, - pub wait_event: Option, - pub client_addr: Option, -} - -#[derive(Debug, Deserialize)] -pub struct GrantRevokeParams { - pub action: String, - pub privileges: Vec, - pub object_type: String, - pub object_name: String, - pub role_name: String, - pub with_grant_option: bool, -} - -#[derive(Debug, Deserialize)] -pub struct RoleMembershipParams { - pub action: String, - pub role_name: String, - pub member_name: String, -} diff --git a/src-tauri/src/models/mod.rs b/src-tauri/src/models/mod.rs index 988c820..b161674 100644 --- a/src-tauri/src/models/mod.rs +++ b/src-tauri/src/models/mod.rs @@ -1,11 +1,8 @@ pub mod ai; +pub mod chat; pub mod connection; -pub mod docker; pub mod history; -pub mod lookup; -pub mod management; pub mod query_result; pub mod saved_queries; pub mod schema; pub mod settings; -pub mod snapshot; diff --git a/src-tauri/src/models/query_result.rs b/src-tauri/src/models/query_result.rs index d5cf409..4be4b3a 100644 --- a/src-tauri/src/models/query_result.rs +++ b/src-tauri/src/models/query_result.rs @@ -1,13 +1,15 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; +// Tauri's IPC layer does not support u128/i128 in command arguments, +// so timings round-trip through frontend → backend as u64 ms. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct QueryResult { pub columns: Vec, pub types: Vec, pub rows: Vec>, pub row_count: usize, - pub execution_time_ms: u128, + pub execution_time_ms: u64, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -16,7 +18,7 @@ pub struct PaginatedQueryResult { pub types: Vec, pub rows: Vec>, pub row_count: usize, - pub execution_time_ms: u128, + pub execution_time_ms: u64, pub total_rows: i64, pub page: u32, pub page_size: u32, diff --git a/src-tauri/src/models/schema.rs b/src-tauri/src/models/schema.rs index eaf75ca..c3c78fe 100644 --- a/src-tauri/src/models/schema.rs +++ b/src-tauri/src/models/schema.rs @@ -60,37 +60,3 @@ pub struct TriggerInfo { pub is_enabled: bool, pub definition: String, } - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ErdColumn { - pub name: String, - pub data_type: String, - pub is_nullable: bool, - pub is_primary_key: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ErdTable { - pub schema: String, - pub name: String, - pub columns: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ErdRelationship { - pub constraint_name: String, - pub source_schema: String, - pub source_table: String, - pub source_columns: Vec, - pub target_schema: String, - pub target_table: String, - pub target_columns: Vec, - pub update_rule: String, - pub delete_rule: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ErdData { - pub tables: Vec, - pub relationships: Vec, -} diff --git a/src-tauri/src/models/settings.rs b/src-tauri/src/models/settings.rs index f3211c8..a0f9a59 100644 --- a/src-tauri/src/models/settings.rs +++ b/src-tauri/src/models/settings.rs @@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct AppSettings { pub mcp: McpSettings, - pub docker: DockerSettings, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -21,28 +20,6 @@ impl Default for McpSettings { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct DockerSettings { - pub host: DockerHost, - pub remote_url: Option, -} - -impl Default for DockerSettings { - fn default() -> Self { - Self { - host: DockerHost::Local, - remote_url: None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] -pub enum DockerHost { - Local, - Remote, -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct McpStatus { pub enabled: bool, diff --git a/src-tauri/src/models/snapshot.rs b/src-tauri/src/models/snapshot.rs deleted file mode 100644 index 9089051..0000000 --- a/src-tauri/src/models/snapshot.rs +++ /dev/null @@ -1,68 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SnapshotMetadata { - pub id: String, - pub name: String, - pub created_at: String, - pub connection_name: String, - pub database: String, - pub tables: Vec, - pub total_rows: u64, - pub file_size_bytes: u64, - pub version: u32, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SnapshotTableMeta { - pub schema: String, - pub table: String, - pub row_count: u64, - pub columns: Vec, - pub column_types: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Snapshot { - pub metadata: SnapshotMetadata, - pub tables: Vec, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SnapshotTableData { - pub schema: String, - pub table: String, - pub columns: Vec, - pub column_types: Vec, - pub rows: Vec>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SnapshotProgress { - pub snapshot_id: String, - pub stage: String, - pub percent: u8, - pub message: String, - pub detail: Option, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CreateSnapshotParams { - pub connection_id: String, - pub tables: Vec, - pub name: String, - pub include_dependencies: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct TableRef { - pub schema: String, - pub table: String, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct RestoreSnapshotParams { - pub connection_id: String, - pub file_path: String, - pub truncate_before_restore: bool, -} diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index 5a544e5..a272bc0 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -1,8 +1,10 @@ +use crate::db::clickhouse::ChClient; use crate::error::{TuskError, TuskResult}; use crate::models::ai::AiSettings; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; +use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::{watch, RwLock}; @@ -11,22 +13,44 @@ use tokio::sync::{watch, RwLock}; pub enum DbFlavor { PostgreSQL, Greenplum, + ClickHouse, } + #[derive(Clone)] pub struct SchemaCacheEntry { pub schema_text: String, pub cached_at: Instant, } +#[derive(Clone)] +pub struct CachedString { + pub value: String, + pub cached_at: Instant, +} + +#[derive(Clone)] +pub struct CachedVec { + pub value: Vec, + pub cached_at: Instant, +} + pub struct AppState { pub pools: RwLock>, + pub ch_clients: RwLock>>, pub read_only: RwLock>, pub db_flavors: RwLock>, + /// Legacy cache used by generate_sql/explain_sql/fix_sql_error — full DDL. pub schema_cache: RwLock>, + /// Chat v2 caches: lite overview per connection. + pub overview_cache: RwLock>, + /// Chat v2 caches: list of tables per (connection_id, db_name) — used for + /// list_tables on a non-active PG database via temporary pool. + pub tables_by_db_cache: RwLock>>, + /// Chat v2 caches: column block per (connection_id, db_name, "schema.table"). + pub columns_cache: RwLock>, pub mcp_shutdown_tx: watch::Sender, pub mcp_running: RwLock, - pub docker_host: RwLock>, pub ai_settings: RwLock>, } @@ -38,16 +62,34 @@ impl AppState { let (mcp_shutdown_tx, _) = watch::channel(false); Self { pools: RwLock::new(HashMap::new()), + ch_clients: RwLock::new(HashMap::new()), read_only: RwLock::new(HashMap::new()), db_flavors: RwLock::new(HashMap::new()), schema_cache: RwLock::new(HashMap::new()), + overview_cache: RwLock::new(HashMap::new()), + tables_by_db_cache: RwLock::new(HashMap::new()), + columns_cache: RwLock::new(HashMap::new()), mcp_shutdown_tx, mcp_running: RwLock::new(false), - docker_host: RwLock::new(None), ai_settings: RwLock::new(None), } } + /// Drop every chat-agent cache entry tied to this connection. + /// Called by switch_database_core, disconnect, and on connection delete. + pub async fn invalidate_chat_caches_for(&self, connection_id: &str) { + self.schema_cache.write().await.remove(connection_id); + self.overview_cache.write().await.remove(connection_id); + self.tables_by_db_cache + .write() + .await + .retain(|(cid, _), _| cid != connection_id); + self.columns_cache + .write() + .await + .retain(|(cid, _, _), _| cid != connection_id); + } + pub async fn get_pool(&self, connection_id: &str) -> TuskResult { let pools = self.pools.read().await; pools @@ -56,6 +98,14 @@ impl AppState { .ok_or_else(|| TuskError::NotConnected(connection_id.to_string())) } + pub async fn get_ch_client(&self, connection_id: &str) -> TuskResult> { + let clients = self.ch_clients.read().await; + clients + .get(connection_id) + .cloned() + .ok_or_else(|| TuskError::NotConnected(connection_id.to_string())) + } + pub async fn is_read_only(&self, id: &str) -> bool { let map = self.read_only.read().await; map.get(id).copied().unwrap_or(true) @@ -100,8 +150,4 @@ impl AppState { ); } - pub async fn invalidate_schema_cache(&self, connection_id: &str) { - let mut cache = self.schema_cache.write().await; - cache.remove(connection_id); - } } diff --git a/src-tauri/src/utils.rs b/src-tauri/src/utils.rs index cecd6fb..569291c 100644 --- a/src-tauri/src/utils.rs +++ b/src-tauri/src/utils.rs @@ -1,95 +1,11 @@ -use std::collections::{HashMap, HashSet, VecDeque}; - pub fn escape_ident(name: &str) -> String { format!("\"{}\"", name.replace('"', "\"\"")) } -/// Topological sort of tables based on foreign key dependencies. -/// Returns tables in insertion order: parents before children. -pub fn topological_sort_tables( - fk_edges: &[(String, String, String, String)], // (schema, table, ref_schema, ref_table) - target_tables: &[(String, String)], -) -> Vec<(String, String)> { - let mut graph: HashMap<(String, String), HashSet<(String, String)>> = HashMap::new(); - let mut in_degree: HashMap<(String, String), usize> = HashMap::new(); - - // Initialize all target tables - for t in target_tables { - graph.entry(t.clone()).or_default(); - in_degree.entry(t.clone()).or_insert(0); - } - - let target_set: HashSet<(String, String)> = target_tables.iter().cloned().collect(); - - // Build edges: parent -> child (child depends on parent) - for (schema, table, ref_schema, ref_table) in fk_edges { - let child = (schema.clone(), table.clone()); - let parent = (ref_schema.clone(), ref_table.clone()); - - if child == parent { - continue; // self-referencing - } - - if !target_set.contains(&child) || !target_set.contains(&parent) { - continue; - } - - if graph - .entry(parent.clone()) - .or_default() - .insert(child.clone()) - { - *in_degree.entry(child).or_insert(0) += 1; - } - } - - // Kahn's algorithm - let mut initial: Vec<(String, String)> = in_degree - .iter() - .filter(|(_, °)| deg == 0) - .map(|(k, _)| k.clone()) - .collect(); - initial.sort(); // deterministic order - let mut queue: VecDeque<(String, String)> = VecDeque::from(initial); - - let mut result = Vec::new(); - - while let Some(node) = queue.pop_front() { - result.push(node.clone()); - if let Some(neighbors) = graph.get(&node) { - let mut new_ready: Vec<(String, String)> = neighbors - .iter() - .filter(|neighbor| { - if let Some(deg) = in_degree.get_mut(*neighbor) { - *deg -= 1; - *deg == 0 - } else { - false - } - }) - .cloned() - .collect(); - new_ready.sort(); - queue.extend(new_ready); - } - } - - // Add any remaining tables (cycles) at the end - for t in target_tables { - if !result.contains(t) { - result.push(t.clone()); - } - } - - result -} - #[cfg(test)] mod tests { use super::*; - // ── escape_ident ────────────────────────────────────────── - #[test] fn escape_ident_simple_name() { assert_eq!(escape_ident("users"), "\"users\""); @@ -149,70 +65,4 @@ mod tests { fn escape_ident_newline() { assert_eq!(escape_ident("a\nb"), "\"a\nb\""); } - - // ── topological_sort_tables ─────────────────────────────── - - #[test] - fn topo_sort_no_edges() { - let tables = vec![("public".into(), "b".into()), ("public".into(), "a".into())]; - let result = topological_sort_tables(&[], &tables); - assert_eq!(result.len(), 2); - assert!(result.contains(&("public".into(), "a".into()))); - assert!(result.contains(&("public".into(), "b".into()))); - } - - #[test] - fn topo_sort_simple_dependency() { - let edges = vec![( - "public".into(), - "orders".into(), - "public".into(), - "users".into(), - )]; - let tables = vec![ - ("public".into(), "orders".into()), - ("public".into(), "users".into()), - ]; - let result = topological_sort_tables(&edges, &tables); - let user_pos = result.iter().position(|t| t.1 == "users").unwrap(); - let order_pos = result.iter().position(|t| t.1 == "orders").unwrap(); - assert!(user_pos < order_pos, "users must come before orders"); - } - - #[test] - fn topo_sort_self_reference() { - let edges = vec![( - "public".into(), - "employees".into(), - "public".into(), - "employees".into(), - )]; - let tables = vec![("public".into(), "employees".into())]; - let result = topological_sort_tables(&edges, &tables); - assert_eq!(result.len(), 1); - } - - #[test] - fn topo_sort_cycle() { - let edges = vec![ - ("public".into(), "a".into(), "public".into(), "b".into()), - ("public".into(), "b".into(), "public".into(), "a".into()), - ]; - let tables = vec![("public".into(), "a".into()), ("public".into(), "b".into())]; - let result = topological_sort_tables(&edges, &tables); - assert_eq!(result.len(), 2); - } - - #[test] - fn topo_sort_edge_outside_target_set_ignored() { - let edges = vec![( - "public".into(), - "orders".into(), - "public".into(), - "external".into(), - )]; - let tables = vec![("public".into(), "orders".into())]; - let result = topological_sort_tables(&edges, &tables); - assert_eq!(result.len(), 1); - } } diff --git a/src/components/chat/ChatComposer.tsx b/src/components/chat/ChatComposer.tsx new file mode 100644 index 0000000..5f7b92c --- /dev/null +++ b/src/components/chat/ChatComposer.tsx @@ -0,0 +1,64 @@ +import { useRef, useState } from "react"; +import { Button } from "@/components/ui/button"; +import { Send } from "lucide-react"; + +interface Props { + onSend: (text: string) => void; + disabled?: boolean; + placeholder?: string; +} + +export function ChatComposer({ onSend, disabled, placeholder }: Props) { + const [value, setValue] = useState(""); + const ref = useRef(null); + + const handleSend = () => { + if (disabled) return; + const text = value.trim(); + if (!text) return; + onSend(text); + setValue(""); + requestAnimationFrame(() => { + if (ref.current) ref.current.style.height = "auto"; + }); + }; + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + handleSend(); + } + }; + + const autoresize = (el: HTMLTextAreaElement) => { + el.style.height = "auto"; + el.style.height = `${Math.min(el.scrollHeight, 200)}px`; + }; + + return ( +
+