From d507162377f352e24dd28a45272c2980dbb40352 Mon Sep 17 00:00:00 2001 From: "A.Shakhmatov" Date: Sat, 21 Feb 2026 11:41:14 +0300 Subject: [PATCH] fix: harden security, reduce duplication, and improve robustness - Fix SQL injection in data.rs by wrapping get_table_data in READ ONLY transaction - Fix SQL injection in docker.rs CREATE DATABASE via escape_ident - Fix command injection in docker.rs by validating pg_version/container_name and escaping shell-interpolated values - Fix UTF-8 panic on stderr truncation with char_indices - Wrap delete_rows in a transaction for atomicity - Replace .expect() with proper error propagation in lib.rs - Cache AI settings in AppState to avoid repeated disk reads - Cap JSONB column discovery at 50 to prevent unbounded queries - Fix ERD colorMode to respect system theme via useTheme() - Extract AppState::get_pool() replacing ~19 inline pool patterns - Extract shared AiSettingsFields component (DRY popover + sheet) - Make get_connections_path pub(crate) and reuse from docker.rs - Deduplicate check_docker by delegating to check_docker_internal Co-Authored-By: Claude Opus 4.6 --- src-tauri/Cargo.lock | 345 ++----- src-tauri/Cargo.toml | 2 +- src-tauri/src/commands/ai.rs | 948 ++++++++++++++++--- src-tauri/src/commands/connections.rs | 2 +- src-tauri/src/commands/data.rs | 58 +- src-tauri/src/commands/docker.rs | 127 ++- src-tauri/src/commands/schema.rs | 100 +- src-tauri/src/lib.rs | 18 +- src-tauri/src/models/ai.rs | 19 +- src-tauri/src/state.rs | 12 + src-tauri/tauri.conf.json | 4 +- src/components/ai/AiSettingsFields.tsx | 82 ++ src/components/ai/AiSettingsPopover.tsx | 75 +- src/components/erd/ErdDiagram.tsx | 8 +- src/components/settings/AppSettingsSheet.tsx | 63 +- 15 files changed, 1196 insertions(+), 667 deletions(-) create mode 100644 src/components/ai/AiSettingsFields.tsx diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index bce075b..d5c5155 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -383,6 +383,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.43" @@ -438,16 +444,6 @@ dependencies = [ "version_check", ] -[[package]] -name = "core-foundation" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "core-foundation" version = "0.10.1" @@ -471,9 +467,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1" dependencies = [ "bitflags 2.10.0", - "core-foundation 0.10.1", + "core-foundation", "core-graphics-types", - "foreign-types 0.5.0", + "foreign-types", "libc", ] @@ -484,7 +480,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" dependencies = [ "bitflags 2.10.0", - "core-foundation 0.10.1", + "core-foundation", "libc", ] @@ -930,12 +926,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "fastrand" -version = "2.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" - [[package]] name = "fdeflate" version = "0.3.7" @@ -994,15 +984,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared 0.1.1", -] - [[package]] name = "foreign-types" version = "0.5.0" @@ -1010,7 +991,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared 0.3.1", + "foreign-types-shared", ] [[package]] @@ -1024,12 +1005,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -1291,8 +1266,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1302,9 +1279,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1455,25 +1434,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "h2" -version = "0.4.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" -dependencies = [ - "atomic-waker", - "bytes", - "fnv", - "futures-core", - "futures-sink", - "http", - "indexmap 2.13.0", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -1618,7 +1578,6 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "h2", "http", "http-body", "httparse", @@ -1645,22 +1604,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", -] - -[[package]] -name = "hyper-tls" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" -dependencies = [ - "bytes", - "http-body-util", - "hyper", - "hyper-util", - "native-tls", - "tokio", - "tokio-native-tls", - "tower-service", + "webpki-roots 1.0.6", ] [[package]] @@ -1681,11 +1625,9 @@ dependencies = [ "percent-encoding", "pin-project-lite", "socket2", - "system-configuration", "tokio", "tower-service", "tracing", - "windows-registry", ] [[package]] @@ -2079,12 +2021,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "linux-raw-sys" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" - [[package]] name = "litemap" version = "0.8.1" @@ -2106,6 +2042,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "mac" version = "0.1.1" @@ -2222,23 +2164,6 @@ dependencies = [ "windows-sys 0.60.2", ] -[[package]] -name = "native-tls" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6cdede44f9a69cab2899a2049e2c3bd49bf911a157f6a3353d4a91c61abbce44" -dependencies = [ - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "ndk" version = "0.9.0" @@ -2595,50 +2520,6 @@ dependencies = [ "pathdiff", ] -[[package]] -name = "openssl" -version = "0.10.75" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" -dependencies = [ - "bitflags 2.10.0", - "cfg-if", - "foreign-types 0.3.2", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.114", -] - -[[package]] -name = "openssl-probe" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" - -[[package]] -name = "openssl-sys" -version = "0.9.111" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "option-ext" version = "0.2.0" @@ -3042,6 +2923,61 @@ dependencies = [ "memchr", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.44" @@ -3259,29 +3195,26 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" dependencies = [ "base64 0.22.1", "bytes", - "encoding_rs", "futures-core", - "h2", "http", "http-body", "http-body-util", "hyper", "hyper-rustls", - "hyper-tls", "hyper-util", "js-sys", "log", - "mime", - "native-tls", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-native-tls", + "tokio-rustls", "tower", "tower-http", "tower-service", @@ -3289,6 +3222,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", + "webpki-roots 1.0.6", ] [[package]] @@ -3428,6 +3362,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -3437,19 +3377,6 @@ dependencies = [ "semver", ] -[[package]] -name = "rustix" -version = "1.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" -dependencies = [ - "bitflags 2.10.0", - "errno", - "libc", - "linux-raw-sys", - "windows-sys 0.61.2", -] - [[package]] name = "rustls" version = "0.23.36" @@ -3470,6 +3397,7 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ + "web-time", "zeroize", ] @@ -3505,15 +3433,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schannel" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" -dependencies = [ - "windows-sys 0.61.2", -] - [[package]] name = "schemars" version = "0.8.22" @@ -3585,29 +3504,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "security-framework" -version = "2.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" -dependencies = [ - "bitflags 2.10.0", - "core-foundation 0.9.4", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "selectors" version = "0.24.0" @@ -4329,27 +4225,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "system-configuration" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" -dependencies = [ - "bitflags 2.10.0", - "core-foundation 0.9.4", - "system-configuration-sys", -] - -[[package]] -name = "system-configuration-sys" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "system-deps" version = "6.2.2" @@ -4371,7 +4246,7 @@ checksum = "f3a753bdc39c07b192151523a3f77cd0394aa75413802c883a0f6f6a0e5ee2e7" dependencies = [ "bitflags 2.10.0", "block2", - "core-foundation 0.10.1", + "core-foundation", "core-graphics", "crossbeam-channel", "dispatch", @@ -4713,19 +4588,6 @@ dependencies = [ "toml 0.9.12+spec-1.1.0", ] -[[package]] -name = "tempfile" -version = "3.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" -dependencies = [ - "fastrand", - "getrandom 0.3.4", - "once_cell", - "rustix", - "windows-sys 0.61.2", -] - [[package]] name = "tendril" version = "0.4.3" @@ -4861,16 +4723,6 @@ dependencies = [ "syn 2.0.114", ] -[[package]] -name = "tokio-native-tls" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" -dependencies = [ - "native-tls", - "tokio", -] - [[package]] name = "tokio-rustls" version = "0.26.4" @@ -5440,6 +5292,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webkit2gtk" version = "2.0.2" @@ -5697,17 +5559,6 @@ dependencies = [ "windows-link 0.1.3", ] -[[package]] -name = "windows-registry" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" -dependencies = [ - "windows-link 0.2.1", - "windows-result 0.4.1", - "windows-strings 0.5.1", -] - [[package]] name = "windows-result" version = "0.3.4" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 39a5150..1b6d48c 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -30,7 +30,7 @@ csv = "1" log = "0.4" hex = "0.4" bigdecimal = { version = "0.4", features = ["serde"] } -reqwest = { version = "0.12", features = ["json"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } rmcp = { version = "0.15", features = ["server", "macros", "transport-streamable-http-server"] } axum = "0.8" schemars = "1" diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs index ea9ebaf..8fe6eb0 100644 --- a/src-tauri/src/commands/ai.rs +++ b/src-tauri/src/commands/ai.rs @@ -1,16 +1,19 @@ use crate::error::{TuskError, TuskResult}; use crate::models::ai::{ - AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, - OllamaTagsResponse, + AiProvider, AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, + OllamaModel, OllamaTagsResponse, }; use crate::state::AppState; use sqlx::Row; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::fs; use std::sync::Arc; use std::time::Duration; use tauri::{AppHandle, Manager, State}; +const MAX_RETRIES: u32 = 2; +const RETRY_DELAY_MS: u64 = 1000; + fn http_client() -> reqwest::Client { reqwest::Client::builder() .connect_timeout(Duration::from_secs(5)) @@ -40,10 +43,16 @@ pub async fn get_ai_settings(app: AppHandle) -> TuskResult { } #[tauri::command] -pub async fn save_ai_settings(app: AppHandle, settings: AiSettings) -> TuskResult<()> { +pub async fn save_ai_settings( + app: AppHandle, + state: State<'_, Arc>, + settings: AiSettings, +) -> TuskResult<()> { let path = get_ai_settings_path(&app)?; let data = serde_json::to_string_pretty(&settings)?; fs::write(&path, data)?; + // Update in-memory cache + *state.ai_settings.write().await = Some(settings); Ok(()) } @@ -73,21 +82,67 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult( + _settings: &AiSettings, + operation: &str, + f: F, +) -> TuskResult +where + F: Fn() -> Fut, + Fut: std::future::Future>, +{ + let mut last_error = None; + + for attempt in 0..MAX_RETRIES { + match f().await { + Ok(result) => return Ok(result), + Err(e) => { + last_error = Some(e); + if attempt < MAX_RETRIES - 1 { + log::warn!( + "{} failed (attempt {}/{}), retrying in {}ms...", + operation, + attempt + 1, + MAX_RETRIES, + RETRY_DELAY_MS + ); + tokio::time::sleep(Duration::from_millis(RETRY_DELAY_MS)).await; + } + } + } + } + + Err(last_error.unwrap_or_else(|| { + TuskError::Ai(format!("{} failed after {} attempts", operation, MAX_RETRIES)) + })) +} + +async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskResult { + // Try in-memory cache first + if let Some(cached) = state.ai_settings.read().await.clone() { + return Ok(cached); + } + // Fallback to disk + let path = get_ai_settings_path(app)?; + if !path.exists() { + return Err(TuskError::Ai( + "No AI model selected. Open AI settings to choose a model.".to_string(), + )); + } + let data = fs::read_to_string(&path)?; + let settings: AiSettings = serde_json::from_str(&data)?; + // Populate cache for future calls + *state.ai_settings.write().await = Some(settings.clone()); + Ok(settings) +} + async fn call_ollama_chat( app: &AppHandle, + state: &AppState, system_prompt: String, user_content: String, ) -> TuskResult { - let settings = { - let path = get_ai_settings_path(app)?; - if !path.exists() { - return Err(TuskError::Ai( - "No AI model selected. Open AI settings to choose a model.".to_string(), - )); - } - let data = fs::read_to_string(&path)?; - serde_json::from_str::(&data)? - }; + let settings = load_ai_settings(app, state).await?; if settings.model.is_empty() { return Err(TuskError::Ai( @@ -95,8 +150,21 @@ async fn call_ollama_chat( )); } + if settings.provider != AiProvider::Ollama { + return Err(TuskError::Ai(format!( + "Provider {:?} not implemented yet", + settings.provider + ))); + } + + let model = settings.model.clone(); + let url = format!( + "{}/api/chat", + settings.ollama_url.trim_end_matches('/') + ); + let request = OllamaChatRequest { - model: settings.model, + model: model.clone(), messages: vec![ OllamaChatMessage { role: "system".to_string(), @@ -110,40 +178,46 @@ async fn call_ollama_chat( stream: false, }; - let url = format!( - "{}/api/chat", - settings.ollama_url.trim_end_matches('/') - ); + call_ai_with_retry(&settings, "Ollama request", || { + let url = url.clone(); + let request = request.clone(); + async move { + let resp = http_client() + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| { + TuskError::Ai(format!( + "Cannot connect to Ollama at {}: {}", + url, e + )) + })?; - let resp = http_client() - .post(&url) - .json(&request) - .send() - .await - .map_err(|e| { - TuskError::Ai(format!( - "Cannot connect to Ollama at {}: {}", - settings.ollama_url, e - )) - })?; + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(TuskError::Ai(format!( + "Ollama error ({}): {}", + status, body + ))); + } - if !resp.status().is_success() { - let status = resp.status(); - let body = resp.text().await.unwrap_or_default(); - return Err(TuskError::Ai(format!( - "Ollama error ({}): {}", - status, body - ))); - } + let chat_resp: OllamaChatResponse = resp + .json() + .await + .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?; - let chat_resp: OllamaChatResponse = resp - .json() - .await - .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?; - - Ok(chat_resp.message.content) + Ok(chat_resp.message.content) + } + }) + .await } +// --------------------------------------------------------------------------- +// SQL generation +// --------------------------------------------------------------------------- + #[tauri::command] pub async fn generate_sql( app: AppHandle, @@ -154,24 +228,80 @@ pub async fn generate_sql( let schema_text = build_schema_context(&state, &connection_id).await?; let system_prompt = format!( - "You are a PostgreSQL SQL generator. Given the database schema below and a natural language request, \ - output ONLY a valid PostgreSQL SQL query. Do not include any explanation, markdown formatting, \ - or code fences. Output raw SQL only.\n\n\ - RULES:\n\ - - Use FK relationships for correct JOIN conditions.\n\ - - timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ - - interval cannot be cast to numeric directly.\n\ - - When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\ - - Use COALESCE for nullable columns in aggregations when appropriate.\n\ - - Prefer LEFT JOIN when the related row may not exist.\n\n\ - DATABASE SCHEMA:\n{}", + "You are an expert PostgreSQL query generator. You receive a database schema and a natural \ + language request. Output ONLY a valid, executable PostgreSQL SQL query.\n\ + \n\ + OUTPUT FORMAT:\n\ + - Raw SQL only. No explanations, no markdown code fences (```), no comments, no preamble.\n\ + - The output must be directly executable in psql.\n\ + - For complex queries use readable formatting with line breaks and indentation.\n\ + \n\ + CRITICAL RULES:\n\ + 1. ONLY reference tables and columns that exist in the schema. Never invent names.\n\ + 2. Use the FOREIGN KEY information to determine correct JOIN conditions.\n\ + 3. Use LEFT JOIN when the FK column is nullable or the relationship is optional; \ + INNER JOIN when both sides must exist.\n\ + 4. Every non-aggregated column in SELECT must appear in GROUP BY.\n\ + 5. Use COALESCE for nullable columns in aggregations: COALESCE(SUM(x), 0).\n\ + 6. For enum columns, use ONLY the values listed in the ENUM TYPES section.\n\ + 7. Use IS NULL / IS NOT NULL for null checks — never = NULL or != NULL.\n\ + 8. Add LIMIT when the user asks for \"top N\", \"first N\", \"latest N\", etc.\n\ + 9. Qualify column names with table alias when the query involves multiple tables.\n\ + \n\ + SEMANTIC RULES (very important):\n\ + - When a table has both actual_* and planned_* columns (e.g. actual_start vs planned_start), \ + they represent DIFFERENT concepts: planned = future estimate, actual = what really happened. \ + NEVER mix them with COALESCE unless the user explicitly requests a fallback.\n\ + - For time-based calculations involving real events (\"how long did X take\", \"average time between\"), \ + use ONLY actual/factual timestamps (actual_*, started_at, completed_at, ended_at). \ + Filter out NULL values with WHERE instead of falling back to planned timestamps.\n\ + - Planned timestamps (planned_*, scheduled_*, estimated_*) should ONLY be used when the user \ + asks about plans, schedules, SLA, or compares plan vs fact.\n\ + - When computing durations or averages, always filter out rows where any involved timestamp \ + is NULL rather than substituting with unrelated defaults.\n\ + - Pay attention to column descriptions/comments in the schema — they reveal business semantics \ + that are critical for correct queries.\n\ + \n\ + TYPE RULES:\n\ + - timestamp - timestamp = interval. For seconds: EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ + - interval cannot be cast to numeric directly; use EXTRACT(EPOCH FROM interval).\n\ + - UNION/UNION ALL requires matching column count and compatible types; cast enums to text.\n\ + - Use ::type for PostgreSQL-style casts.\n\ + - For array columns use ANY, ALL, @>, <@ operators.\n\ + - For JSONB columns use ->, ->>, #>, jsonb_extract_path.\n\ + \n\ + COMMON PATTERNS:\n\ + - FIRST/LAST per group: to find MIN(started_at) per trip, use \ + \"SELECT trip_id, MIN(started_at) FROM t GROUP BY trip_id\". \ + NEVER put the aggregated column (started_at) into GROUP BY — that defeats the aggregation \ + and returns every row separately instead of one per group.\n\ + - TOP-1 per group with extra columns: use DISTINCT ON (group_col) ... ORDER BY group_col, sort_col \ + or a subquery with ROW_NUMBER() OVER (PARTITION BY group_col ORDER BY sort_col) = 1.\n\ + - For \"time from A to B\" calculations, ensure both timestamps are NOT NULL with WHERE filters; \ + never use COALESCE to mix planned and actual timestamps.\n\ + \n\ + BEST PRACTICES:\n\ + - Use ILIKE for case-insensitive text search, LIKE for case-sensitive.\n\ + - Use EXISTS instead of IN for subquery existence checks.\n\ + - Use CTE (WITH ... AS) for complex multi-step logic.\n\ + - Use window functions (ROW_NUMBER, RANK, LAG, LEAD, SUM OVER) for ranking and running totals.\n\ + - Use date_trunc('period', column) for time-based grouping.\n\ + - Use generate_series() for creating ranges.\n\ + - Use string_agg(col, ', ') for concatenating grouped values.\n\ + - Use FILTER (WHERE ...) for conditional aggregation instead of CASE inside aggregate.\n\ + \n\ + {}\n", schema_text ); - let raw = call_ollama_chat(&app, system_prompt, prompt).await?; + let raw = call_ollama_chat(&app, &state, system_prompt, prompt).await?; Ok(clean_sql_response(&raw)) } +// --------------------------------------------------------------------------- +// SQL explanation +// --------------------------------------------------------------------------- + #[tauri::command] pub async fn explain_sql( app: AppHandle, @@ -182,16 +312,32 @@ pub async fn explain_sql( let schema_text = build_schema_context(&state, &connection_id).await?; let system_prompt = format!( - "You are a PostgreSQL expert. Explain what this SQL query does in clear, concise language. \ - Focus on the business logic, mention the tables, joins, and filters used. \ - Use short paragraphs or bullet points.\n\n\ - DATABASE SCHEMA:\n{}", + "You are a PostgreSQL expert. Explain the given SQL query clearly and concisely.\n\ + \n\ + Structure your explanation as:\n\ + 1. **Summary** — one sentence describing what the query returns in business terms.\n\ + 2. **Step-by-step breakdown** — explain tables accessed, joins, filters, aggregations, \ + subqueries, and sorting. Use bullet points.\n\ + 3. **Notes** — mention edge cases, potential issues, or performance concerns if any.\n\ + \n\ + Use the database schema below to understand table relationships and column meanings.\n\ + Keep the explanation short; avoid restating the SQL verbatim.\n\ + \n\ + IMPORTANT: If you notice semantic issues (e.g. mixing planned_* and actual_* timestamps \ + with COALESCE, comparing unrelated columns, missing NULL filters on nullable timestamps), \ + mention them in the Notes section as potential problems.\n\ + \n\ + {}\n", schema_text ); - call_ollama_chat(&app, system_prompt, sql).await + call_ollama_chat(&app, &state, system_prompt, sql).await } +// --------------------------------------------------------------------------- +// SQL error fixing +// --------------------------------------------------------------------------- + #[tauri::command] pub async fn fix_sql_error( app: AppHandle, @@ -203,44 +349,291 @@ pub async fn fix_sql_error( let schema_text = build_schema_context(&state, &connection_id).await?; let system_prompt = format!( - "You are a PostgreSQL expert. Fix the SQL query based on the error message. \ - Output ONLY the corrected valid PostgreSQL SQL. Do not include any explanation, \ - markdown formatting, or code fences. Output raw SQL only.\n\n\ - RULES:\n\ - - Use FK relationships for correct JOIN conditions.\n\ - - timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ - - interval cannot be cast to numeric directly.\n\ - - When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\ - - Use COALESCE for nullable columns in aggregations when appropriate.\n\ - - Prefer LEFT JOIN when the related row may not exist.\n\n\ - DATABASE SCHEMA:\n{}", + "You are a PostgreSQL expert debugger. You receive a SQL query and the error it produced. \ + Fix the query so it executes correctly.\n\ + \n\ + OUTPUT FORMAT:\n\ + - Raw SQL only. No explanations, no markdown code fences (```), no comments.\n\ + - The output must be directly executable.\n\ + \n\ + DIAGNOSTIC CHECKLIST:\n\ + - Column/table does not exist → check the schema for correct names and spelling.\n\ + - Column is ambiguous → qualify with table name or alias.\n\ + - Must appear in GROUP BY → add missing non-aggregated columns to GROUP BY.\n\ + - Type mismatch → add appropriate casts (::text, ::integer, etc.).\n\ + - Permission denied → wrap in a read-only transaction if needed.\n\ + - Syntax error → correct PostgreSQL syntax (check commas, parentheses, keywords).\n\ + - Subquery returns more than one row → use IN, ANY, or add LIMIT 1.\n\ + - Division by zero → wrap divisor with NULLIF(x, 0).\n\ + \n\ + ONLY use tables and columns from the schema below. Never invent names.\n\ + Preserve the original intent of the query; change only what is necessary to fix the error.\n\ + \n\ + {}\n", schema_text ); let user_content = format!( - "Original SQL:\n{}\n\nError:\n{}", + "SQL query:\n{}\n\nError message:\n{}", sql, error_message ); - let raw = call_ollama_chat(&app, system_prompt, user_content).await?; + let raw = call_ollama_chat(&app, &state, system_prompt, user_content).await?; Ok(clean_sql_response(&raw)) } +// --------------------------------------------------------------------------- +// Schema context builder +// --------------------------------------------------------------------------- + async fn build_schema_context( state: &AppState, connection_id: &str, ) -> TuskResult { - let pools = state.pools.read().await; - let pool = pools - .get(connection_id) - .ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?; + // Check cache first + if let Some(cached) = state.get_schema_cache(connection_id).await { + return Ok(cached); + } - // Single query: all columns with real type names (enum types show actual name, not USER-DEFINED) - let col_rows = sqlx::query( + let pool = state.get_pool(connection_id).await?; + + // Run all metadata queries in parallel for speed + let ( + version_res, col_res, fk_res, enum_res, + tbl_comment_res, col_comment_res, unique_res, + varchar_res, jsonb_res, + ) = tokio::join!( + sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool), + fetch_columns(&pool), + fetch_foreign_keys_raw(&pool), + fetch_enum_types(&pool), + fetch_table_comments(&pool), + fetch_column_comments(&pool), + fetch_unique_constraints(&pool), + fetch_varchar_values(&pool), + fetch_jsonb_keys(&pool), + ); + + let version = version_res.map_err(TuskError::Database)?; + let col_rows = col_res?; + let fk_rows = fk_res?; + let enum_map = enum_res?; + let tbl_comments = tbl_comment_res?; + let col_comments = col_comment_res?; + let unique_constraints = unique_res?; + let varchar_values = varchar_res.unwrap_or_default(); + let jsonb_keys = jsonb_res.unwrap_or_default(); + + // -- Build FK inline lookup: (schema, table, column) -> "ref_schema.ref_table(ref_col)" -- + let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new(); + let mut fk_lines: Vec = Vec::new(); + for fk in &fk_rows { + let line = format!( + "FK: {}.{}({}) -> {}.{}({})", + fk.schema, + fk.table, + fk.columns.join(", "), + fk.ref_schema, + fk.ref_table, + fk.ref_columns.join(", ") + ); + fk_lines.push(line); + + // For single-column FKs, enable inline annotation on column + if fk.columns.len() == 1 && fk.ref_columns.len() == 1 { + fk_inline.insert( + (fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()), + format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]), + ); + } + } + + // -- Build unique constraint lookup: (schema, table) -> Vec -- + let mut unique_map: HashMap<(String, String), Vec> = HashMap::new(); + for (schema, table, cols) in &unique_constraints { + unique_map + .entry((schema.clone(), table.clone())) + .or_default() + .push(cols.join(", ")); + } + + // -- Format output -- + let mut output: Vec = Vec::new(); + + // 1. PostgreSQL version (short form) + let short_version = version + .split_whitespace() + .take(2) + .collect::>() + .join(" "); + output.push(format!("DATABASE SCHEMA ({})", short_version)); + output.push(String::new()); + + // 2. Enum types + if !enum_map.is_empty() { + output.push("ENUM TYPES:".to_string()); + for (type_name, values) in &enum_map { + let vals_str = values + .iter() + .map(|v| format!("'{}'", v)) + .collect::>() + .join(", "); + output.push(format!(" {} = [{}]", type_name, vals_str)); + } + output.push(String::new()); + } + + // 3. Tables with columns + output.push("TABLES:".to_string()); + + // Group columns by schema.table preserving order + let mut tables: BTreeMap> = BTreeMap::new(); + for ci in &col_rows { + let key = format!("{}.{}", ci.schema, ci.table); + tables.entry(key).or_default().push(ci.clone()); + } + + for (full_name, columns) in &tables { + // Table header with optional comment + let tbl_comment = tbl_comments.get(full_name).map(|c| c.as_str()); + match tbl_comment { + Some(comment) => output.push(format!("\nTABLE {} -- {}", full_name, comment)), + None => output.push(format!("\nTABLE {}", full_name)), + } + + // Columns + for ci in columns { + let mut parts: Vec = vec![ci.column.clone(), ci.data_type.clone()]; + + if ci.is_pk { + parts.push("PK".to_string()); + } + if ci.not_null && !ci.is_pk { + parts.push("NOT NULL".to_string()); + } + + // Inline FK reference + if let Some(ref_target) = + fk_inline.get(&(ci.schema.clone(), ci.table.clone(), ci.column.clone())) + { + parts.push(format!("FK->{}", ref_target)); + } + + // Default value (simplified) + if let Some(ref def) = ci.column_default { + let simplified = simplify_default(def); + if !simplified.is_empty() { + parts.push(format!("DEFAULT {}", simplified)); + } + } + + // Column comment + let col_key = (ci.schema.clone(), ci.table.clone(), ci.column.clone()); + let col_comment = col_comments.get(&col_key); + + // Inline enum values for enum-typed columns + let enum_annotation = enum_map.get(&ci.data_type); + + let mut suffix_parts: Vec = Vec::new(); + if let Some(vals) = enum_annotation { + let vals_str = vals + .iter() + .map(|v| format!("'{}'", v)) + .collect::>() + .join(", "); + suffix_parts.push(format!("enum: {}", vals_str)); + } + + // Inline varchar distinct values (pseudo-enums from pg_stats) + if enum_annotation.is_none() { + if let Some(vals) = varchar_values.get(&col_key) { + let vals_str = vals + .iter() + .take(15) + .map(|v| format!("'{}'", v)) + .collect::>() + .join(", "); + suffix_parts.push(format!("values: {}", vals_str)); + } + } + + // Inline JSONB key structure + if let Some(keys) = jsonb_keys.get(&col_key) { + suffix_parts.push(format!("json keys: {}", keys.join(", "))); + } + + if let Some(comment) = col_comment { + suffix_parts.push(comment.clone()); + } + + if suffix_parts.is_empty() { + output.push(format!(" {}", parts.join(" "))); + } else { + output.push(format!( + " {} -- {}", + parts.join(" "), + suffix_parts.join("; ") + )); + } + } + + // Unique constraints for this table + let schema_table: Vec<&str> = full_name.splitn(2, '.').collect(); + if schema_table.len() == 2 { + if let Some(uqs) = unique_map.get(&( + schema_table[0].to_string(), + schema_table[1].to_string(), + )) { + for uq_cols in uqs { + output.push(format!(" UNIQUE({})", uq_cols)); + } + } + } + } + + // 4. Foreign keys summary + if !fk_lines.is_empty() { + output.push(String::new()); + output.push("FOREIGN KEYS:".to_string()); + for fk in &fk_lines { + output.push(format!(" {}", fk)); + } + } + + let result = output.join("\n"); + + // Cache the result + state.set_schema_cache(connection_id.to_string(), result.clone()).await; + + Ok(result) +} + +// --------------------------------------------------------------------------- +// Schema query helpers +// --------------------------------------------------------------------------- + +#[derive(Clone)] +struct ColumnInfo { + schema: String, + table: String, + column: String, + data_type: String, + not_null: bool, + is_pk: bool, + column_default: Option, +} + +async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult> { + let rows = sqlx::query( "SELECT \ c.table_schema, c.table_name, c.column_name, \ - CASE WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name ELSE c.data_type END AS data_type, \ + CASE \ + WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name \ + WHEN c.data_type = 'ARRAY' THEN c.udt_name || '[]' \ + ELSE c.data_type \ + END AS data_type, \ c.is_nullable = 'NO' AS not_null, \ + c.column_default, \ EXISTS( \ SELECT 1 FROM information_schema.table_constraints tc \ JOIN information_schema.key_column_usage kcu \ @@ -258,45 +651,30 @@ async fn build_schema_context( .await .map_err(TuskError::Database)?; - // Group columns by schema.table - let mut tables: BTreeMap> = BTreeMap::new(); - for row in &col_rows { - let schema: String = row.get(0); - let table: String = row.get(1); - let col_name: String = row.get(2); - let data_type: String = row.get(3); - let not_null: bool = row.get(4); - let is_pk: bool = row.get(5); - - let mut parts = vec![col_name, data_type]; - if is_pk { - parts.push("PK".to_string()); - } - if not_null { - parts.push("NOT NULL".to_string()); - } - - let key = format!("{}.{}", schema, table); - tables.entry(key).or_default().push(parts.join(" ")); - } - - let mut lines: Vec = tables - .into_iter() - .map(|(key, cols)| format!("{}({})", key, cols.join(", "))) - .collect(); - - // Fetch FK relationships - let fks = fetch_foreign_keys_from_pool(pool).await?; - for fk in &fks { - lines.push(fk.clone()); - } - - Ok(lines.join("\n")) + Ok(rows + .iter() + .map(|r| ColumnInfo { + schema: r.get(0), + table: r.get(1), + column: r.get(2), + data_type: r.get(3), + not_null: r.get(4), + column_default: r.get(5), + is_pk: r.get(6), + }) + .collect()) } -async fn fetch_foreign_keys_from_pool( - pool: &sqlx::PgPool, -) -> TuskResult> { +struct ForeignKeyInfo { + schema: String, + table: String, + columns: Vec, + ref_schema: String, + ref_table: String, + ref_columns: Vec, +} + +async fn fetch_foreign_keys_raw(pool: &sqlx::PgPool) -> TuskResult> { let rows = sqlx::query( "SELECT \ cn.nspname AS schema_name, cl.relname AS table_name, \ @@ -318,28 +696,335 @@ async fn fetch_foreign_keys_from_pool( .await .map_err(TuskError::Database)?; - let fks: Vec = rows + Ok(rows + .iter() + .map(|r| ForeignKeyInfo { + schema: r.get(0), + table: r.get(1), + columns: r.get(2), + ref_schema: r.get(3), + ref_table: r.get(4), + ref_columns: r.get(5), + }) + .collect()) +} + +/// Returns BTreeMap> ordered by type name +async fn fetch_enum_types(pool: &sqlx::PgPool) -> TuskResult>> { + let rows = sqlx::query( + "SELECT t.typname, \ + array_agg(e.enumlabel ORDER BY e.enumsortorder) AS vals \ + FROM pg_enum e \ + JOIN pg_type t ON e.enumtypid = t.oid \ + JOIN pg_namespace n ON t.typnamespace = n.oid \ + WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') \ + GROUP BY t.typname \ + ORDER BY t.typname", + ) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; + + let mut map = BTreeMap::new(); + for r in &rows { + let name: String = r.get(0); + let vals: Vec = r.get(1); + map.insert(name, vals); + } + Ok(map) +} + +/// Returns HashMap<"schema.table", comment> +async fn fetch_table_comments(pool: &sqlx::PgPool) -> TuskResult> { + let rows = sqlx::query( + "SELECT n.nspname, c.relname, obj_description(c.oid, 'pg_class') \ + FROM pg_class c \ + JOIN pg_namespace n ON c.relnamespace = n.oid \ + WHERE c.relkind IN ('r', 'v', 'p', 'm') \ + AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ + AND obj_description(c.oid, 'pg_class') IS NOT NULL", + ) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; + + let mut map = HashMap::new(); + for r in &rows { + let schema: String = r.get(0); + let table: String = r.get(1); + let comment: String = r.get(2); + map.insert(format!("{}.{}", schema, table), comment); + } + Ok(map) +} + +/// Returns HashMap<(schema, table, column), comment> +async fn fetch_column_comments( + pool: &sqlx::PgPool, +) -> TuskResult> { + let rows = sqlx::query( + "SELECT n.nspname, c.relname, a.attname, d.description \ + FROM pg_description d \ + JOIN pg_class c ON d.objoid = c.oid \ + JOIN pg_namespace n ON c.relnamespace = n.oid \ + JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = d.objsubid \ + WHERE d.objsubid > 0 \ + AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit')", + ) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; + + let mut map = HashMap::new(); + for r in &rows { + let schema: String = r.get(0); + let table: String = r.get(1); + let column: String = r.get(2); + let comment: String = r.get(3); + map.insert((schema, table, column), comment); + } + Ok(map) +} + +/// Returns Vec<(schema, table, Vec)> for UNIQUE constraints +async fn fetch_unique_constraints( + pool: &sqlx::PgPool, +) -> TuskResult)>> { + let rows = sqlx::query( + "SELECT n.nspname, cl.relname, \ + array_agg(a.attname ORDER BY a.attnum) AS cols \ + FROM pg_constraint con \ + JOIN pg_class cl ON con.conrelid = cl.oid \ + JOIN pg_namespace n ON cl.relnamespace = n.oid \ + JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \ + WHERE con.contype = 'u' \ + AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ + GROUP BY n.nspname, cl.relname, con.oid \ + ORDER BY n.nspname, cl.relname", + ) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; + + Ok(rows .iter() .map(|r| { let schema: String = r.get(0); let table: String = r.get(1); let cols: Vec = r.get(2); - let ref_schema: String = r.get(3); - let ref_table: String = r.get(4); - let ref_cols: Vec = r.get(5); - format!( - "FK: {}.{}({}) -> {}.{}({})", - schema, - table, - cols.join(", "), - ref_schema, - ref_table, - ref_cols.join(", ") + (schema, table, cols) + }) + .collect()) +} + +/// Returns HashMap<(schema, table, column), Vec> for varchar columns +/// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery. +/// Returns None if pg_stats is not accessible (graceful degradation). +async fn fetch_varchar_values( + pool: &sqlx::PgPool, +) -> Option>> { + let rows = match sqlx::query( + "SELECT s.schemaname, s.tablename, s.attname, \ + s.most_common_vals::text AS vals \ + FROM pg_stats s \ + JOIN information_schema.columns c \ + ON c.table_schema = s.schemaname \ + AND c.table_name = s.tablename \ + AND c.column_name = s.attname \ + WHERE s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ + AND c.data_type = 'character varying' \ + AND s.n_distinct > 0 AND s.n_distinct <= 20 \ + AND s.most_common_vals IS NOT NULL \ + ORDER BY s.schemaname, s.tablename, s.attname", + ) + .fetch_all(pool) + .await + { + Ok(r) => r, + Err(e) => { + log::warn!("Failed to fetch varchar values from pg_stats: {}", e); + return None; + } + }; + + let mut map = HashMap::new(); + for r in &rows { + let schema: String = r.get(0); + let table: String = r.get(1); + let column: String = r.get(2); + let vals_text: String = r.get(3); + let vals = parse_pg_array_text(&vals_text); + if !vals.is_empty() { + map.insert((schema, table, column), vals); + } + } + Some(map) +} + +/// Discovers top-level keys in JSONB columns by sampling actual data. +/// Runs two sequential queries internally: first discovers JSONB columns, +/// then samples keys from each via a single UNION ALL query. +/// Returns None on error (graceful degradation). +async fn fetch_jsonb_keys( + pool: &sqlx::PgPool, +) -> Option>> { + // Step 1: Find all JSONB columns + let col_rows = match sqlx::query( + "SELECT table_schema, table_name, column_name \ + FROM information_schema.columns \ + WHERE data_type = 'jsonb' \ + AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ + ORDER BY table_schema, table_name, column_name", + ) + .fetch_all(pool) + .await + { + Ok(r) => r, + Err(e) => { + log::warn!("Failed to fetch JSONB columns: {}", e); + return None; + } + }; + + if col_rows.is_empty() { + return Some(HashMap::new()); + } + + // Cap at 50 JSONB columns to prevent unbounded UNION ALL queries on large schemas + let columns: Vec<(String, String, String)> = col_rows + .iter() + .take(50) + .map(|r| { + ( + r.get::(0), + r.get::(1), + r.get::(2), ) }) .collect(); - Ok(fks) + // Step 2: Build a single UNION ALL query to sample keys from all JSONB columns + let parts: Vec = columns + .iter() + .enumerate() + .map(|(i, (schema, table, col))| { + let qs = schema.replace('"', "\"\""); + let qt = table.replace('"', "\"\""); + let qc = col.replace('"', "\"\""); + format!( + "(SELECT '{}.{}.{}' AS col_ref, key FROM (\ + SELECT DISTINCT jsonb_object_keys(\"{}\") AS key \ + FROM \"{}\".\"{}\" \ + WHERE \"{}\" IS NOT NULL AND jsonb_typeof(\"{}\") = 'object' \ + LIMIT 50\ + ) sub{})", + schema, table, col, qc, qs, qt, qc, qc, i + ) + }) + .collect(); + + let query = parts.join(" UNION ALL "); + + let rows = match sqlx::query(&query) + .fetch_all(pool) + .await + { + Ok(r) => r, + Err(e) => { + log::warn!("Failed to fetch JSONB keys: {}", e); + return None; + } + }; + + let mut map: HashMap<(String, String, String), Vec> = HashMap::new(); + for r in &rows { + let col_ref: String = r.get(0); + let key: String = r.get(1); + let ref_parts: Vec<&str> = col_ref.splitn(3, '.').collect(); + if ref_parts.len() == 3 { + let entry = map + .entry(( + ref_parts[0].to_string(), + ref_parts[1].to_string(), + ref_parts[2].to_string(), + )) + .or_default(); + if !entry.contains(&key) { + entry.push(key); + } + } + } + + for vals in map.values_mut() { + vals.sort(); + } + + Some(map) +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"} +fn parse_pg_array_text(s: &str) -> Vec { + let s = s.trim(); + let s = s.strip_prefix('{').unwrap_or(s); + let s = s.strip_suffix('}').unwrap_or(s); + if s.is_empty() { + return Vec::new(); + } + + let mut result = Vec::new(); + let mut current = String::new(); + let mut in_quotes = false; + let mut chars = s.chars().peekable(); + + while let Some(ch) = chars.next() { + match ch { + '"' if !in_quotes => { + in_quotes = true; + } + '"' if in_quotes => { + if chars.peek() == Some(&'"') { + current.push('"'); + chars.next(); + } else { + in_quotes = false; + } + } + ',' if !in_quotes => { + result.push(current.trim().to_string()); + current = String::new(); + } + _ => { + current.push(ch); + } + } + } + if !current.is_empty() || !result.is_empty() { + result.push(current.trim().to_string()); + } + result +} + +fn simplify_default(raw: &str) -> String { + let s = raw.trim(); + if s.contains("nextval(") { + return "auto-increment".to_string(); + } + // Shorten common defaults + if s == "now()" || s == "CURRENT_TIMESTAMP" || s == "current_timestamp" { + return "now()".to_string(); + } + if s == "true" || s == "false" { + return s.to_string(); + } + // Numeric/string literals — keep short ones, skip very long generated defaults + if s.len() > 50 { + return String::new(); + } + s.to_string() } fn clean_sql_response(raw: &str) -> String { @@ -349,6 +1034,7 @@ fn clean_sql_response(raw: &str) -> String { let inner = trimmed .strip_prefix("```sql") .or_else(|| trimmed.strip_prefix("```SQL")) + .or_else(|| trimmed.strip_prefix("```postgresql")) .or_else(|| trimmed.strip_prefix("```")) .unwrap_or(trimmed); inner.strip_suffix("```").unwrap_or(inner) diff --git a/src-tauri/src/commands/connections.rs b/src-tauri/src/commands/connections.rs index 5550c9a..0e36d38 100644 --- a/src-tauri/src/commands/connections.rs +++ b/src-tauri/src/commands/connections.rs @@ -14,7 +14,7 @@ pub struct ConnectResult { pub flavor: DbFlavor, } -fn get_connections_path(app: &AppHandle) -> TuskResult { +pub(crate) fn get_connections_path(app: &AppHandle) -> TuskResult { let dir = app .path() .app_data_dir() diff --git a/src-tauri/src/commands/data.rs b/src-tauri/src/commands/data.rs index 0043fd9..51e41cf 100644 --- a/src-tauri/src/commands/data.rs +++ b/src-tauri/src/commands/data.rs @@ -21,10 +21,7 @@ pub async fn get_table_data( sort_direction: Option, filter: Option, ) -> TuskResult { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); @@ -56,11 +53,24 @@ pub async fn get_table_data( let start = Instant::now(); - let (rows, count_row) = tokio::try_join!( - sqlx::query(&data_sql).fetch_all(pool), - sqlx::query(&count_sql).fetch_one(pool), - ) - .map_err(TuskError::Database)?; + // Always run table data queries in a read-only transaction to prevent + // writable CTEs or other mutation via the raw filter parameter. + let mut tx = (&pool).begin().await.map_err(TuskError::Database)?; + sqlx::query("SET TRANSACTION READ ONLY") + .execute(&mut *tx) + .await + .map_err(TuskError::Database)?; + + let rows = sqlx::query(&data_sql) + .fetch_all(&mut *tx) + .await + .map_err(TuskError::Database)?; + let count_row = sqlx::query(&count_sql) + .fetch_one(&mut *tx) + .await + .map_err(TuskError::Database)?; + + tx.rollback().await.map_err(TuskError::Database)?; let execution_time_ms = start.elapsed().as_millis(); let total_rows: i64 = count_row.get(0); @@ -134,10 +144,7 @@ pub async fn update_row( return Err(TuskError::ReadOnly); } - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); @@ -155,7 +162,7 @@ pub async fn update_row( let mut query = sqlx::query(&sql); query = bind_json_value(query, &value); query = query.bind(ctid_val); - query.execute(pool).await.map_err(TuskError::Database)?; + query.execute(&pool).await.map_err(TuskError::Database)?; } else { let where_parts: Vec = pk_columns .iter() @@ -174,7 +181,7 @@ pub async fn update_row( for pk_val in &pk_values { query = bind_json_value(query, pk_val); } - query.execute(pool).await.map_err(TuskError::Database)?; + query.execute(&pool).await.map_err(TuskError::Database)?; } Ok(()) @@ -193,10 +200,7 @@ pub async fn insert_row( return Err(TuskError::ReadOnly); } - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); @@ -215,7 +219,7 @@ pub async fn insert_row( query = bind_json_value(query, val); } - query.execute(pool).await.map_err(TuskError::Database)?; + query.execute(&pool).await.map_err(TuskError::Database)?; Ok(()) } @@ -234,14 +238,14 @@ pub async fn delete_rows( return Err(TuskError::ReadOnly); } - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); let mut total_affected: u64 = 0; + // Wrap all deletes in a transaction for atomicity + let mut tx = (&pool).begin().await.map_err(TuskError::Database)?; + if pk_columns.is_empty() { // Fallback: use ctids for row identification let ctid_list = ctids.ok_or_else(|| { @@ -250,7 +254,7 @@ pub async fn delete_rows( for ctid_val in &ctid_list { let sql = format!("DELETE FROM {} WHERE ctid = $1::tid", qualified); let query = sqlx::query(&sql).bind(ctid_val); - let result = query.execute(pool).await.map_err(TuskError::Database)?; + let result = query.execute(&mut *tx).await.map_err(TuskError::Database)?; total_affected += result.rows_affected(); } } else { @@ -269,11 +273,13 @@ pub async fn delete_rows( query = bind_json_value(query, val); } - let result = query.execute(pool).await.map_err(TuskError::Database)?; + let result = query.execute(&mut *tx).await.map_err(TuskError::Database)?; total_affected += result.rows_affected(); } } + tx.commit().await.map_err(TuskError::Database)?; + Ok(total_affected) } diff --git a/src-tauri/src/commands/docker.rs b/src-tauri/src/commands/docker.rs index 73010ac..e39df66 100644 --- a/src-tauri/src/commands/docker.rs +++ b/src-tauri/src/commands/docker.rs @@ -4,9 +4,10 @@ use crate::models::docker::{ CloneMode, CloneProgress, CloneResult, CloneToDockerParams, DockerStatus, TuskContainer, }; use crate::state::AppState; +use crate::utils::escape_ident; use std::fs; use std::sync::Arc; -use tauri::{AppHandle, Emitter, Manager, State}; +use tauri::{AppHandle, Emitter, State}; use tokio::process::Command; async fn docker_cmd(state: &AppState) -> Command { @@ -42,17 +43,8 @@ fn emit_progress( ); } -fn get_connections_path(app: &AppHandle) -> TuskResult { - let dir = app - .path() - .app_data_dir() - .map_err(|e| TuskError::Custom(e.to_string()))?; - fs::create_dir_all(&dir)?; - Ok(dir.join("connections.json")) -} - fn load_connection_config(app: &AppHandle, connection_id: &str) -> TuskResult { - let path = get_connections_path(app)?; + let path = super::connections::get_connections_path(app)?; if !path.exists() { return Err(TuskError::ConnectionNotFound(connection_id.to_string())); } @@ -69,43 +61,58 @@ fn shell_escape(s: &str) -> String { s.replace('\'', "'\\''") } +/// Validate pg_version matches a safe pattern (e.g. "16", "16.2", "17.1") +fn validate_pg_version(version: &str) -> TuskResult<()> { + let is_valid = !version.is_empty() + && version + .chars() + .all(|c| c.is_ascii_digit() || c == '.'); + if !is_valid { + return Err(docker_err(format!( + "Invalid pg_version '{}': must contain only digits and dots (e.g. '16', '16.2')", + version + ))); + } + Ok(()) +} + +/// Validate container name matches Docker naming rules: [a-zA-Z0-9][a-zA-Z0-9_.-]* +fn validate_container_name(name: &str) -> TuskResult<()> { + if name.is_empty() { + return Err(docker_err("Container name cannot be empty")); + } + let first = name.chars().next().unwrap(); + if !first.is_ascii_alphanumeric() { + return Err(docker_err(format!( + "Invalid container name '{}': must start with an alphanumeric character", + name + ))); + } + let is_valid = name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == '-'); + if !is_valid { + return Err(docker_err(format!( + "Invalid container name '{}': only [a-zA-Z0-9_.-] characters are allowed", + name + ))); + } + Ok(()) +} + +/// Shell-escape a string for use inside double-quoted shell contexts +fn shell_escape_double(s: &str) -> String { + s.replace('\\', "\\\\") + .replace('"', "\\\"") + .replace('$', "\\$") + .replace('`', "\\`") + .replace('!', "\\!") +} + #[tauri::command] pub async fn check_docker(state: State<'_, Arc>) -> TuskResult { - let output = docker_cmd(&state) - .await - .args(["version", "--format", "{{.Server.Version}}"]) - .output() - .await; - - match output { - Ok(out) => { - if out.status.success() { - let version = String::from_utf8_lossy(&out.stdout).trim().to_string(); - Ok(DockerStatus { - installed: true, - daemon_running: true, - version: Some(version), - error: None, - }) - } else { - let stderr = String::from_utf8_lossy(&out.stderr).trim().to_string(); - let daemon_running = !stderr.contains("Cannot connect") - && !stderr.contains("connection refused"); - Ok(DockerStatus { - installed: true, - daemon_running, - version: None, - error: Some(stderr), - }) - } - } - Err(_) => Ok(DockerStatus { - installed: false, - daemon_running: false, - version: None, - error: Some("Docker CLI not found. Please install Docker.".to_string()), - }), - } + let docker_host = state.docker_host.read().await.clone(); + check_docker_internal(&docker_host).await } #[tauri::command] @@ -252,6 +259,10 @@ async fn do_clone( params: &CloneToDockerParams, clone_id: &str, ) -> TuskResult { + // Validate user inputs before any operations + validate_pg_version(¶ms.pg_version)?; + validate_container_name(¶ms.container_name)?; + let docker_host = state.docker_host.read().await.clone(); // Step 1: Check Docker @@ -313,7 +324,7 @@ async fn do_clone( .args([ "exec", ¶ms.container_name, "psql", "-U", "postgres", "-c", - &format!("CREATE DATABASE \"{}\"", params.source_database), + &format!("CREATE DATABASE {}", escape_ident(¶ms.source_database)), ]) .output() .await @@ -492,7 +503,11 @@ async fn run_pipe_cmd( if !stderr.is_empty() { // Truncate for progress display (full log can be long) let short = if stderr.len() > 500 { - format!("{}...", &stderr[..500]) + let truncated = stderr.char_indices() + .nth(500) + .map(|(i, _)| &stderr[..i]) + .unwrap_or(&stderr); + format!("{}...", truncated) } else { stderr.clone() }; @@ -633,13 +648,16 @@ async fn transfer_sample_data( let table = parts[1]; // Use COPY (SELECT ... LIMIT N) TO STDOUT piped to COPY ... FROM STDIN + // Escape schema/table for use inside double-quoted shell strings + let escaped_schema = shell_escape_double(schema); + let escaped_table = shell_escape_double(table); let copy_out_sql = format!( "\\copy (SELECT * FROM \\\"{}\\\".\\\"{}\\\" LIMIT {}) TO STDOUT", - schema, table, sample_rows + escaped_schema, escaped_table, sample_rows ); let copy_in_sql = format!( "\\copy \\\"{}\\\".\\\"{}\\\" FROM STDIN", - schema, table + escaped_schema, escaped_table ); let escaped_url = shell_escape(source_url); @@ -693,7 +711,7 @@ async fn transfer_sample_data( } fn save_connection_config(app: &AppHandle, config: &ConnectionConfig) -> TuskResult<()> { - let path = get_connections_path(app)?; + let path = super::connections::get_connections_path(app)?; let mut connections = if path.exists() { let data = fs::read_to_string(&path)?; serde_json::from_str::>(&data)? @@ -701,7 +719,12 @@ fn save_connection_config(app: &AppHandle, config: &ConnectionConfig) -> TuskRes vec![] }; - connections.push(config.clone()); + // Upsert by ID to avoid duplicate entries on retry + if let Some(pos) = connections.iter().position(|c| c.id == config.id) { + connections[pos] = config.clone(); + } else { + connections.push(config.clone()); + } let data = serde_json::to_string_pretty(&connections)?; fs::write(&path, data)?; diff --git a/src-tauri/src/commands/schema.rs b/src-tauri/src/commands/schema.rs index 97cba8c..4535729 100644 --- a/src-tauri/src/commands/schema.rs +++ b/src-tauri/src/commands/schema.rs @@ -14,17 +14,14 @@ pub async fn list_databases( state: State<'_, Arc>, connection_id: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT datname FROM pg_database \ WHERE datistemplate = false \ ORDER BY datname", ) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -35,10 +32,7 @@ pub async fn list_schemas_core( state: &AppState, connection_id: &str, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(connection_id) - .ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?; + let pool = state.get_pool(connection_id).await?; let flavor = state.get_flavor(connection_id).await; let sql = if flavor == DbFlavor::Greenplum { @@ -52,7 +46,7 @@ pub async fn list_schemas_core( }; let rows = sqlx::query(sql) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -72,10 +66,7 @@ pub async fn list_tables_core( connection_id: &str, schema: &str, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(connection_id) - .ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?; + let pool = state.get_pool(connection_id).await?; let rows = sqlx::query( "SELECT t.table_name, \ @@ -88,7 +79,7 @@ pub async fn list_tables_core( ORDER BY t.table_name", ) .bind(schema) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -119,10 +110,7 @@ pub async fn list_views( connection_id: String, schema: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT table_name FROM information_schema.views \ @@ -130,7 +118,7 @@ pub async fn list_views( ORDER BY table_name", ) .bind(&schema) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -152,10 +140,7 @@ pub async fn list_functions( connection_id: String, schema: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT routine_name FROM information_schema.routines \ @@ -163,7 +148,7 @@ pub async fn list_functions( ORDER BY routine_name", ) .bind(&schema) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -185,10 +170,7 @@ pub async fn list_indexes( connection_id: String, schema: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT indexname FROM pg_indexes \ @@ -196,7 +178,7 @@ pub async fn list_indexes( ORDER BY indexname", ) .bind(&schema) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -218,10 +200,7 @@ pub async fn list_sequences( connection_id: String, schema: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT sequence_name FROM information_schema.sequences \ @@ -229,7 +208,7 @@ pub async fn list_sequences( ORDER BY sequence_name", ) .bind(&schema) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -251,10 +230,7 @@ pub async fn get_table_columns_core( schema: &str, table: &str, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(connection_id) - .ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?; + let pool = state.get_pool(connection_id).await?; let rows = sqlx::query( "SELECT \ @@ -287,7 +263,7 @@ pub async fn get_table_columns_core( ) .bind(schema) .bind(table) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -323,10 +299,7 @@ pub async fn get_table_constraints( schema: String, table: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT \ @@ -376,7 +349,7 @@ pub async fn get_table_constraints( ) .bind(&schema) .bind(&table) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -402,10 +375,7 @@ pub async fn get_table_indexes( schema: String, table: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT \ @@ -422,7 +392,7 @@ pub async fn get_table_indexes( ) .bind(&schema) .bind(&table) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -443,10 +413,7 @@ pub async fn get_completion_schema( connection_id: String, ) -> TuskResult>>> { let flavor = state.get_flavor(&connection_id).await; - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let sql = if flavor == DbFlavor::Greenplum { "SELECT table_schema, table_name, column_name \ @@ -461,7 +428,7 @@ pub async fn get_completion_schema( }; let rows = sqlx::query(sql) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -490,10 +457,7 @@ pub async fn get_column_details( table: String, ) -> TuskResult> { let flavor = state.get_flavor(&connection_id).await; - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let sql = if flavor == DbFlavor::Greenplum { "SELECT c.column_name, c.data_type, \ @@ -516,7 +480,7 @@ pub async fn get_column_details( let rows = sqlx::query(sql) .bind(&schema) .bind(&table) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -539,10 +503,7 @@ pub async fn get_table_triggers( schema: String, table: String, ) -> TuskResult> { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; let rows = sqlx::query( "SELECT \ @@ -571,7 +532,7 @@ pub async fn get_table_triggers( ) .bind(&schema) .bind(&table) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -595,10 +556,7 @@ pub async fn get_schema_erd( connection_id: String, schema: String, ) -> TuskResult { - let pools = state.pools.read().await; - let pool = pools - .get(&connection_id) - .ok_or(TuskError::NotConnected(connection_id))?; + let pool = state.get_pool(&connection_id).await?; // Get all tables with columns let col_rows = sqlx::query( @@ -627,7 +585,7 @@ pub async fn get_schema_erd( ORDER BY c.table_name, c.ordinal_position", ) .bind(&schema) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; @@ -690,7 +648,7 @@ pub async fn get_schema_erd( ORDER BY c.conname", ) .bind(&schema) - .fetch_all(pool) + .fetch_all(&pool) .await .map_err(TuskError::Database)?; diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index b2ed6a6..6a794b5 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -13,24 +13,20 @@ use tauri::Manager; pub fn run() { let shared_state = Arc::new(AppState::new()); - tauri::Builder::default() + let _ = tauri::Builder::default() .plugin(tauri_plugin_shell::init()) .plugin(tauri_plugin_dialog::init()) .manage(shared_state) .setup(|app| { let state = app.state::>().inner().clone(); - let connections_path = app + let data_dir = app .path() .app_data_dir() - .expect("failed to resolve app data dir") - .join("connections.json"); + .map_err(|e| Box::new(e) as Box)?; + let connections_path = data_dir.join("connections.json"); // Read app settings - let settings_path = app - .path() - .app_data_dir() - .expect("failed to resolve app data dir") - .join("app_settings.json"); + let settings_path = data_dir.join("app_settings.json"); let settings = if settings_path.exists() { std::fs::read_to_string(&settings_path) @@ -154,5 +150,7 @@ pub fn run() { commands::settings::get_mcp_status, ]) .run(tauri::generate_context!()) - .expect("error while running tauri application"); + .inspect_err(|e| { + log::error!("Tauri application error: {}", e); + }); } diff --git a/src-tauri/src/models/ai.rs b/src-tauri/src/models/ai.rs index 1dabb7d..3b312dc 100644 --- a/src-tauri/src/models/ai.rs +++ b/src-tauri/src/models/ai.rs @@ -1,27 +1,42 @@ use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum AiProvider { + #[default] + Ollama, + OpenAi, + Anthropic, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AiSettings { + pub provider: AiProvider, pub ollama_url: String, + pub openai_api_key: Option, + pub anthropic_api_key: Option, pub model: String, } impl Default for AiSettings { fn default() -> Self { Self { + provider: AiProvider::Ollama, ollama_url: "http://localhost:11434".to_string(), + openai_api_key: None, + anthropic_api_key: None, model: String::new(), } } } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaChatMessage { pub role: String, pub content: String, } -#[derive(Debug, Serialize)] +#[derive(Debug, Clone, Serialize)] pub struct OllamaChatRequest { pub model: String, pub messages: Vec, diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index afd1f00..d86ae5e 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -1,3 +1,5 @@ +use crate::error::{TuskError, TuskResult}; +use crate::models::ai::AiSettings; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; @@ -27,6 +29,7 @@ pub struct AppState { pub mcp_shutdown_tx: watch::Sender, pub mcp_running: RwLock, pub docker_host: RwLock>, + pub ai_settings: RwLock>, } const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes @@ -43,9 +46,18 @@ impl AppState { mcp_shutdown_tx, mcp_running: RwLock::new(false), docker_host: RwLock::new(None), + ai_settings: RwLock::new(None), } } + pub async fn get_pool(&self, connection_id: &str) -> TuskResult { + let pools = self.pools.read().await; + pools + .get(connection_id) + .cloned() + .ok_or_else(|| TuskError::NotConnected(connection_id.to_string())) + } + pub async fn is_read_only(&self, id: &str) -> bool { let map = self.read_only.read().await; map.get(id).copied().unwrap_or(true) diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index f9e3d05..c5c50a8 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -2,7 +2,7 @@ "$schema": "https://schema.tauri.app/config/2", "productName": "Tusk", "version": "0.1.0", - "identifier": "com.tusk.app", + "identifier": "com.tusk.dbm", "build": { "frontendDist": "../dist", "devUrl": "http://localhost:5173", @@ -27,7 +27,7 @@ }, "bundle": { "active": true, - "targets": "all", + "targets": ["deb", "rpm", "dmg", "nsis"], "icon": [ "icons/32x32.png", "icons/128x128.png", diff --git a/src/components/ai/AiSettingsFields.tsx b/src/components/ai/AiSettingsFields.tsx new file mode 100644 index 0000000..edd163a --- /dev/null +++ b/src/components/ai/AiSettingsFields.tsx @@ -0,0 +1,82 @@ +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { useOllamaModels } from "@/hooks/use-ai"; +import { RefreshCw, Loader2 } from "lucide-react"; + +interface Props { + ollamaUrl: string; + onOllamaUrlChange: (url: string) => void; + model: string; + onModelChange: (model: string) => void; +} + +export function AiSettingsFields({ + ollamaUrl, + onOllamaUrlChange, + model, + onModelChange, +}: Props) { + const { + data: models, + isLoading: modelsLoading, + isError: modelsError, + refetch: refetchModels, + } = useOllamaModels(ollamaUrl); + + return ( + <> +
+ + onOllamaUrlChange(e.target.value)} + placeholder="http://localhost:11434" + className="h-8 text-xs" + /> +
+ +
+
+ + +
+ {modelsError ? ( +

Cannot connect to Ollama

+ ) : ( + + )} +
+ + ); +} diff --git a/src/components/ai/AiSettingsPopover.tsx b/src/components/ai/AiSettingsPopover.tsx index c942d17..54a82a2 100644 --- a/src/components/ai/AiSettingsPopover.tsx +++ b/src/components/ai/AiSettingsPopover.tsx @@ -5,17 +5,10 @@ import { PopoverTrigger, } from "@/components/ui/popover"; import { Button } from "@/components/ui/button"; -import { Input } from "@/components/ui/input"; -import { - Select, - SelectContent, - SelectItem, - SelectTrigger, - SelectValue, -} from "@/components/ui/select"; -import { useAiSettings, useSaveAiSettings, useOllamaModels } from "@/hooks/use-ai"; -import { Settings, RefreshCw, Loader2 } from "lucide-react"; +import { useAiSettings, useSaveAiSettings } from "@/hooks/use-ai"; +import { Settings } from "lucide-react"; import { toast } from "sonner"; +import { AiSettingsFields } from "./AiSettingsFields"; export function AiSettingsPopover() { const { data: settings } = useAiSettings(); @@ -27,16 +20,9 @@ export function AiSettingsPopover() { const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434"; const currentModel = model ?? settings?.model ?? ""; - const { - data: models, - isLoading: modelsLoading, - isError: modelsError, - refetch: refetchModels, - } = useOllamaModels(currentUrl); - const handleSave = () => { saveMutation.mutate( - { ollama_url: currentUrl, model: currentModel }, + { provider: "ollama", ollama_url: currentUrl, model: currentModel }, { onSuccess: () => toast.success("AI settings saved"), onError: (err) => @@ -63,53 +49,12 @@ export function AiSettingsPopover() {

Ollama Settings

-
- - setUrl(e.target.value)} - placeholder="http://localhost:11434" - className="h-8 text-xs" - /> -
- -
-
- - -
- {modelsError ? ( -

- Cannot connect to Ollama -

- ) : ( - - )} -
+
-
- - setOllamaUrl(e.target.value)} - placeholder="http://localhost:11434" - className="h-8 text-xs" - /> -
- -
-
- - -
- {modelsError ? ( -

Cannot connect to Ollama

- ) : ( - - )} -
+