From e8d99c645b051f5e5dfbf430d410a9db36fdcac9 Mon Sep 17 00:00:00 2001 From: "A.Shakhmatov" Date: Fri, 13 Feb 2026 18:24:06 +0300 Subject: [PATCH] feat: add Greenplum 7 compatibility and AI SQL generation Greenplum 7 (PG12-based) compatibility: - Auto-detect GP via version() string, store DbFlavor per connection - connect returns ConnectResult with version + flavor - Fix pg_total_relation_size to use c.oid (universal, safer on both PG/GP) - Branch is_identity column query for GP (lacks the column) - Branch list_sessions wait_event fields for GP - Exclude gp_toolkit schema in schema listing, completion, lookup, AI context - Smart StatusBar version display: GP shows "GP 7.0.0 (PG 12.4)" - Fix connection list spinner showing on all cards during connect AI SQL generation (Ollama): - Add AI settings, model selection, and generate_sql command - Frontend AI panel with prompt input and SQL output Co-Authored-By: Claude Opus 4.6 --- src-tauri/Cargo.lock | 315 +++++++++++++++++- src-tauri/Cargo.toml | 1 + src-tauri/src/commands/ai.rs | 299 +++++++++++++++++ src-tauri/src/commands/connections.rs | 46 ++- src-tauri/src/commands/lookup.rs | 2 +- src-tauri/src/commands/management.rs | 24 +- src-tauri/src/commands/schema.rs | 70 ++-- src-tauri/src/error.rs | 3 + src-tauri/src/lib.rs | 1 + src-tauri/src/models/ai.rs | 44 +++ src-tauri/src/state.rs | 15 + src/App.tsx | 5 +- src/components/ai/AiBar.tsx | 92 +++++ src/components/ai/AiSettingsPopover.tsx | 121 +++++++ src/components/connections/ConnectionList.tsx | 10 +- src/components/history/HistoryPanel.tsx | 5 +- src/components/layout/StatusBar.tsx | 12 +- src/components/management/AdminPanel.tsx | 2 + .../saved-queries/SavedQueriesPanel.tsx | 3 +- src/components/schema/SchemaTree.tsx | 2 + src/components/ui/popover.tsx | 87 +++++ src/components/workspace/WorkspacePanel.tsx | 131 +++++--- src/hooks/use-ai.ts | 47 +++ src/hooks/use-connections.ts | 20 +- src/lib/tauri.ts | 7 +- src/stores/app-store.ts | 18 +- src/types/index.ts | 7 + 27 files changed, 1276 insertions(+), 113 deletions(-) create mode 100644 src-tauri/src/commands/ai.rs create mode 100644 src-tauri/src/models/ai.rs create mode 100644 src/components/ai/AiBar.tsx create mode 100644 src/components/ai/AiSettingsPopover.tsx create mode 100644 src/components/ui/popover.tsx create mode 100644 src/hooks/use-ai.ts diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 3f3371f..bce075b 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -438,6 +438,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.1" @@ -461,9 +471,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1" dependencies = [ "bitflags 2.10.0", - "core-foundation", + "core-foundation 0.10.1", "core-graphics-types", - "foreign-types", + "foreign-types 0.5.0", "libc", ] @@ -474,7 +484,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" dependencies = [ "bitflags 2.10.0", - "core-foundation", + "core-foundation 0.10.1", "libc", ] @@ -920,6 +930,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fdeflate" version = "0.3.7" @@ -978,6 +994,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] + [[package]] name = "foreign-types" version = "0.5.0" @@ -985,7 +1010,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared", + "foreign-types-shared 0.3.1", ] [[package]] @@ -999,6 +1024,12 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -1424,6 +1455,25 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "h2" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap 2.13.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1568,6 +1618,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -1580,6 +1631,38 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.20" @@ -1598,9 +1681,11 @@ dependencies = [ "percent-encoding", "pin-project-lite", "socket2", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ -1994,6 +2079,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.1" @@ -2131,6 +2222,23 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "native-tls" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cdede44f9a69cab2899a2049e2c3bd49bf911a157f6a3353d4a91c61abbce44" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndk" version = "0.9.0" @@ -2487,6 +2595,50 @@ dependencies = [ "pathdiff", ] +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" @@ -3099,6 +3251,46 @@ version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64 0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "mime", + "native-tls", + "percent-encoding", + "pin-project-lite", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "reqwest" version = "0.13.2" @@ -3245,6 +3437,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags 2.10.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustls" version = "0.23.36" @@ -3300,6 +3505,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "schemars" version = "0.8.22" @@ -3371,6 +3585,29 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "selectors" version = "0.24.0" @@ -4092,6 +4329,27 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "system-configuration" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "system-deps" version = "6.2.2" @@ -4113,7 +4371,7 @@ checksum = "f3a753bdc39c07b192151523a3f77cd0394aa75413802c883a0f6f6a0e5ee2e7" dependencies = [ "bitflags 2.10.0", "block2", - "core-foundation", + "core-foundation 0.10.1", "core-graphics", "crossbeam-channel", "dispatch", @@ -4192,7 +4450,7 @@ dependencies = [ "percent-encoding", "plist", "raw-window-handle", - "reqwest", + "reqwest 0.13.2", "serde", "serde_json", "serde_repr", @@ -4455,6 +4713,19 @@ dependencies = [ "toml 0.9.12+spec-1.1.0", ] +[[package]] +name = "tempfile" +version = "3.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "tendril" version = "0.4.3" @@ -4590,6 +4861,26 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.18" @@ -4826,6 +5117,7 @@ dependencies = [ "csv", "hex", "log", + "reqwest 0.12.28", "rmcp", "schemars 1.2.1", "serde", @@ -5405,6 +5697,17 @@ dependencies = [ "windows-link 0.1.3", ] +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link 0.2.1", + "windows-result 0.4.1", + "windows-strings 0.5.1", +] + [[package]] name = "windows-result" version = "0.3.4" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index fa5c050..39a5150 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -30,6 +30,7 @@ csv = "1" log = "0.4" hex = "0.4" bigdecimal = { version = "0.4", features = ["serde"] } +reqwest = { version = "0.12", features = ["json"] } rmcp = { version = "0.15", features = ["server", "macros", "transport-streamable-http-server"] } axum = "0.8" schemars = "1" diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs new file mode 100644 index 0000000..132fd03 --- /dev/null +++ b/src-tauri/src/commands/ai.rs @@ -0,0 +1,299 @@ +use crate::error::{TuskError, TuskResult}; +use crate::models::ai::{ + AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, + OllamaTagsResponse, +}; +use crate::state::AppState; +use sqlx::Row; +use std::collections::BTreeMap; +use std::fs; +use std::sync::Arc; +use std::time::Duration; +use tauri::{AppHandle, Manager, State}; + +fn http_client() -> reqwest::Client { + reqwest::Client::builder() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(300)) + .build() + .unwrap_or_default() +} + +fn get_ai_settings_path(app: &AppHandle) -> TuskResult { + let dir = app + .path() + .app_data_dir() + .map_err(|e| TuskError::Custom(e.to_string()))?; + fs::create_dir_all(&dir)?; + Ok(dir.join("ai_settings.json")) +} + +#[tauri::command] +pub async fn get_ai_settings(app: AppHandle) -> TuskResult { + let path = get_ai_settings_path(&app)?; + if !path.exists() { + return Ok(AiSettings::default()); + } + let data = fs::read_to_string(&path)?; + let settings: AiSettings = serde_json::from_str(&data)?; + Ok(settings) +} + +#[tauri::command] +pub async fn save_ai_settings(app: AppHandle, settings: AiSettings) -> TuskResult<()> { + let path = get_ai_settings_path(&app)?; + let data = serde_json::to_string_pretty(&settings)?; + fs::write(&path, data)?; + Ok(()) +} + +#[tauri::command] +pub async fn list_ollama_models(ollama_url: String) -> TuskResult> { + let url = format!("{}/api/tags", ollama_url.trim_end_matches('/')); + let resp = http_client() + .get(&url) + .send() + .await + .map_err(|e| TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e)))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(TuskError::Ai(format!( + "Ollama error ({}): {}", + status, body + ))); + } + + let tags: OllamaTagsResponse = resp + .json() + .await + .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?; + + Ok(tags.models) +} + +#[tauri::command] +pub async fn generate_sql( + app: AppHandle, + state: State<'_, Arc>, + connection_id: String, + prompt: String, +) -> TuskResult { + // Load AI settings + let settings = { + let path = get_ai_settings_path(&app)?; + if !path.exists() { + return Err(TuskError::Ai( + "No AI model selected. Open AI settings to choose a model.".to_string(), + )); + } + let data = fs::read_to_string(&path)?; + serde_json::from_str::(&data)? + }; + + if settings.model.is_empty() { + return Err(TuskError::Ai( + "No AI model selected. Open AI settings to choose a model.".to_string(), + )); + } + + // Build schema context + let schema_text = build_schema_context(&state, &connection_id).await?; + + let system_prompt = format!( + "You are a PostgreSQL SQL generator. Given the database schema below and a natural language request, \ + output ONLY a valid PostgreSQL SQL query. Do not include any explanation, markdown formatting, \ + or code fences. Output raw SQL only.\n\n\ + RULES:\n\ + - Use FK relationships for correct JOIN conditions.\n\ + - timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ + - interval cannot be cast to numeric directly.\n\ + - When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\ + - Use COALESCE for nullable columns in aggregations when appropriate.\n\ + - Prefer LEFT JOIN when the related row may not exist.\n\n\ + DATABASE SCHEMA:\n{}", + schema_text + ); + + let request = OllamaChatRequest { + model: settings.model, + messages: vec![ + OllamaChatMessage { + role: "system".to_string(), + content: system_prompt, + }, + OllamaChatMessage { + role: "user".to_string(), + content: prompt, + }, + ], + stream: false, + }; + + let url = format!( + "{}/api/chat", + settings.ollama_url.trim_end_matches('/') + ); + + let resp = http_client() + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| { + TuskError::Ai(format!( + "Cannot connect to Ollama at {}: {}", + settings.ollama_url, e + )) + })?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(TuskError::Ai(format!( + "Ollama error ({}): {}", + status, body + ))); + } + + let chat_resp: OllamaChatResponse = resp + .json() + .await + .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?; + + let sql = clean_sql_response(&chat_resp.message.content); + Ok(sql) +} + +async fn build_schema_context( + state: &AppState, + connection_id: &str, +) -> TuskResult { + let pools = state.pools.read().await; + let pool = pools + .get(connection_id) + .ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?; + + // Single query: all columns with real type names (enum types show actual name, not USER-DEFINED) + let col_rows = sqlx::query( + "SELECT \ + c.table_schema, c.table_name, c.column_name, \ + CASE WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name ELSE c.data_type END AS data_type, \ + c.is_nullable = 'NO' AS not_null, \ + EXISTS( \ + SELECT 1 FROM information_schema.table_constraints tc \ + JOIN information_schema.key_column_usage kcu \ + ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema \ + WHERE tc.constraint_type = 'PRIMARY KEY' \ + AND tc.table_schema = c.table_schema \ + AND tc.table_name = c.table_name \ + AND kcu.column_name = c.column_name \ + ) AS is_pk \ + FROM information_schema.columns c \ + WHERE c.table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ + ORDER BY c.table_schema, c.table_name, c.ordinal_position", + ) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; + + // Group columns by schema.table + let mut tables: BTreeMap> = BTreeMap::new(); + for row in &col_rows { + let schema: String = row.get(0); + let table: String = row.get(1); + let col_name: String = row.get(2); + let data_type: String = row.get(3); + let not_null: bool = row.get(4); + let is_pk: bool = row.get(5); + + let mut parts = vec![col_name, data_type]; + if is_pk { + parts.push("PK".to_string()); + } + if not_null { + parts.push("NOT NULL".to_string()); + } + + let key = format!("{}.{}", schema, table); + tables.entry(key).or_default().push(parts.join(" ")); + } + + let mut lines: Vec = tables + .into_iter() + .map(|(key, cols)| format!("{}({})", key, cols.join(", "))) + .collect(); + + // Fetch FK relationships + let fks = fetch_foreign_keys_from_pool(pool).await?; + for fk in &fks { + lines.push(fk.clone()); + } + + Ok(lines.join("\n")) +} + +async fn fetch_foreign_keys_from_pool( + pool: &sqlx::PgPool, +) -> TuskResult> { + let rows = sqlx::query( + "SELECT \ + cn.nspname AS schema_name, cl.relname AS table_name, \ + array_agg(DISTINCT a.attname ORDER BY a.attname) AS columns, \ + cnf.nspname AS ref_schema, clf.relname AS ref_table, \ + array_agg(DISTINCT af.attname ORDER BY af.attname) AS ref_columns \ + FROM pg_constraint con \ + JOIN pg_class cl ON con.conrelid = cl.oid \ + JOIN pg_namespace cn ON cl.relnamespace = cn.oid \ + JOIN pg_class clf ON con.confrelid = clf.oid \ + JOIN pg_namespace cnf ON clf.relnamespace = cnf.oid \ + JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \ + JOIN pg_attribute af ON af.attrelid = con.confrelid AND af.attnum = ANY(con.confkey) \ + WHERE con.contype = 'f' \ + AND cn.nspname NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \ + GROUP BY cn.nspname, cl.relname, cnf.nspname, clf.relname, con.oid", + ) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; + + let fks: Vec = rows + .iter() + .map(|r| { + let schema: String = r.get(0); + let table: String = r.get(1); + let cols: Vec = r.get(2); + let ref_schema: String = r.get(3); + let ref_table: String = r.get(4); + let ref_cols: Vec = r.get(5); + format!( + "FK: {}.{}({}) -> {}.{}({})", + schema, + table, + cols.join(", "), + ref_schema, + ref_table, + ref_cols.join(", ") + ) + }) + .collect(); + + Ok(fks) +} + +fn clean_sql_response(raw: &str) -> String { + let trimmed = raw.trim(); + // Remove markdown code fences + let without_fences = if trimmed.starts_with("```") { + let inner = trimmed + .strip_prefix("```sql") + .or_else(|| trimmed.strip_prefix("```SQL")) + .or_else(|| trimmed.strip_prefix("```")) + .unwrap_or(trimmed); + inner.strip_suffix("```").unwrap_or(inner) + } else { + trimmed + }; + without_fences.trim().to_string() +} diff --git a/src-tauri/src/commands/connections.rs b/src-tauri/src/commands/connections.rs index f90c637..5550c9a 100644 --- a/src-tauri/src/commands/connections.rs +++ b/src-tauri/src/commands/connections.rs @@ -1,12 +1,19 @@ use crate::error::{TuskError, TuskResult}; use crate::models::connection::ConnectionConfig; -use crate::state::AppState; +use crate::state::{AppState, DbFlavor}; +use serde::Serialize; use sqlx::PgPool; use sqlx::Row; use std::fs; use std::sync::Arc; use tauri::{AppHandle, Manager, State}; +#[derive(Debug, Clone, Serialize)] +pub struct ConnectResult { + pub version: String, + pub flavor: DbFlavor, +} + fn get_connections_path(app: &AppHandle) -> TuskResult { let dir = app .path() @@ -72,6 +79,9 @@ pub async fn delete_connection( let mut ro = state.read_only.write().await; ro.remove(&id); + let mut flavors = state.db_flavors.write().await; + flavors.remove(&id); + Ok(()) } @@ -92,7 +102,10 @@ pub async fn test_connection(config: ConnectionConfig) -> TuskResult { } #[tauri::command] -pub async fn connect(state: State<'_, Arc>, config: ConnectionConfig) -> TuskResult<()> { +pub async fn connect( + state: State<'_, Arc>, + config: ConnectionConfig, +) -> TuskResult { let pool = PgPool::connect(&config.connection_url()) .await .map_err(TuskError::Database)?; @@ -103,13 +116,29 @@ pub async fn connect(state: State<'_, Arc>, config: ConnectionConfig) .await .map_err(TuskError::Database)?; + // Detect database flavor via version() + let row = sqlx::query("SELECT version()") + .fetch_one(&pool) + .await + .map_err(TuskError::Database)?; + let version: String = row.get(0); + + let flavor = if version.to_lowercase().contains("greenplum") { + DbFlavor::Greenplum + } else { + DbFlavor::PostgreSQL + }; + let mut pools = state.pools.write().await; pools.insert(config.id.clone(), pool); let mut ro = state.read_only.write().await; ro.insert(config.id.clone(), true); - Ok(()) + let mut flavors = state.db_flavors.write().await; + flavors.insert(config.id.clone(), flavor); + + Ok(ConnectResult { version, flavor }) } #[tauri::command] @@ -149,6 +178,9 @@ pub async fn disconnect(state: State<'_, Arc>, id: String) -> TuskResu let mut ro = state.read_only.write().await; ro.remove(&id); + let mut flavors = state.db_flavors.write().await; + flavors.remove(&id); + Ok(()) } @@ -170,3 +202,11 @@ pub async fn get_read_only( ) -> TuskResult { Ok(state.is_read_only(&connection_id).await) } + +#[tauri::command] +pub async fn get_db_flavor( + state: State<'_, Arc>, + connection_id: String, +) -> TuskResult { + Ok(state.get_flavor(&connection_id).await) +} diff --git a/src-tauri/src/commands/lookup.rs b/src-tauri/src/commands/lookup.rs index e9a0d66..b763938 100644 --- a/src-tauri/src/commands/lookup.rs +++ b/src-tauri/src/commands/lookup.rs @@ -82,7 +82,7 @@ async fn search_database_inner( "SELECT table_schema, table_name, data_type \ FROM information_schema.columns \ WHERE column_name = $1 \ - AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast')", + AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit')", ) .bind(column_name) .fetch_all(pool) diff --git a/src-tauri/src/commands/management.rs b/src-tauri/src/commands/management.rs index dfb7adb..b6fd918 100644 --- a/src-tauri/src/commands/management.rs +++ b/src-tauri/src/commands/management.rs @@ -1,6 +1,6 @@ use crate::error::{TuskError, TuskResult}; use crate::models::management::*; -use crate::state::AppState; +use crate::state::{AppState, DbFlavor}; use crate::utils::escape_ident; use sqlx::Row; use std::sync::Arc; @@ -514,22 +514,32 @@ pub async fn list_sessions( state: State<'_, Arc>, connection_id: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; let pools = state.pools.read().await; let pool = pools .get(&connection_id) .ok_or(TuskError::NotConnected(connection_id))?; - let rows = sqlx::query( + let sql = if flavor == DbFlavor::Greenplum { + "SELECT pid, usename, datname, state, query, \ + query_start::text, NULL::text as wait_event_type, NULL::text as wait_event, \ + client_addr::text \ + FROM pg_stat_activity \ + WHERE datname IS NOT NULL \ + ORDER BY query_start DESC NULLS LAST" + } else { "SELECT pid, usename, datname, state, query, \ query_start::text, wait_event_type, wait_event, \ client_addr::text \ FROM pg_stat_activity \ WHERE datname IS NOT NULL \ - ORDER BY query_start DESC NULLS LAST", - ) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; + ORDER BY query_start DESC NULLS LAST" + }; + + let rows = sqlx::query(sql) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; let sessions = rows .iter() diff --git a/src-tauri/src/commands/schema.rs b/src-tauri/src/commands/schema.rs index 0646c20..a240b67 100644 --- a/src-tauri/src/commands/schema.rs +++ b/src-tauri/src/commands/schema.rs @@ -1,6 +1,6 @@ use crate::error::{TuskError, TuskResult}; use crate::models::schema::{ColumnDetail, ColumnInfo, ConstraintInfo, IndexInfo, SchemaObject}; -use crate::state::AppState; +use crate::state::{AppState, DbFlavor}; use sqlx::Row; use std::collections::HashMap; use std::sync::Arc; @@ -37,14 +37,21 @@ pub async fn list_schemas_core( .get(connection_id) .ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?; - let rows = sqlx::query( + let flavor = state.get_flavor(connection_id).await; + let sql = if flavor == DbFlavor::Greenplum { + "SELECT schema_name FROM information_schema.schemata \ + WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ + ORDER BY schema_name" + } else { "SELECT schema_name FROM information_schema.schemata \ WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') \ - ORDER BY schema_name", - ) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; + ORDER BY schema_name" + }; + + let rows = sqlx::query(sql) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; Ok(rows.iter().map(|r| r.get::(0)).collect()) } @@ -70,7 +77,7 @@ pub async fn list_tables_core( let rows = sqlx::query( "SELECT t.table_name, \ c.reltuples::bigint as row_count, \ - pg_total_relation_size(quote_ident(t.table_schema) || '.' || quote_ident(t.table_name))::bigint as size_bytes \ + pg_total_relation_size(c.oid)::bigint as size_bytes \ FROM information_schema.tables t \ LEFT JOIN pg_class c ON c.relname = t.table_name \ AND c.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $1) \ @@ -387,20 +394,28 @@ pub async fn get_completion_schema( state: State<'_, Arc>, connection_id: String, ) -> TuskResult>>> { + let flavor = state.get_flavor(&connection_id).await; let pools = state.pools.read().await; let pool = pools .get(&connection_id) .ok_or(TuskError::NotConnected(connection_id))?; - let rows = sqlx::query( + let sql = if flavor == DbFlavor::Greenplum { + "SELECT table_schema, table_name, column_name \ + FROM information_schema.columns \ + WHERE table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ + ORDER BY table_schema, table_name, ordinal_position" + } else { "SELECT table_schema, table_name, column_name \ FROM information_schema.columns \ WHERE table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast') \ - ORDER BY table_schema, table_name, ordinal_position", - ) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; + ORDER BY table_schema, table_name, ordinal_position" + }; + + let rows = sqlx::query(sql) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; let mut result: HashMap>> = HashMap::new(); for row in &rows { @@ -426,25 +441,36 @@ pub async fn get_column_details( schema: String, table: String, ) -> TuskResult> { + let flavor = state.get_flavor(&connection_id).await; let pools = state.pools.read().await; let pool = pools .get(&connection_id) .ok_or(TuskError::NotConnected(connection_id))?; - let rows = sqlx::query( + let sql = if flavor == DbFlavor::Greenplum { + "SELECT c.column_name, c.data_type, \ + c.is_nullable = 'YES' as is_nullable, \ + c.column_default, \ + false as is_identity \ + FROM information_schema.columns c \ + WHERE c.table_schema = $1 AND c.table_name = $2 \ + ORDER BY c.ordinal_position" + } else { "SELECT c.column_name, c.data_type, \ c.is_nullable = 'YES' as is_nullable, \ c.column_default, \ c.is_identity = 'YES' as is_identity \ FROM information_schema.columns c \ WHERE c.table_schema = $1 AND c.table_name = $2 \ - ORDER BY c.ordinal_position", - ) - .bind(&schema) - .bind(&table) - .fetch_all(pool) - .await - .map_err(TuskError::Database)?; + ORDER BY c.ordinal_position" + }; + + let rows = sqlx::query(sql) + .bind(&schema) + .bind(&table) + .fetch_all(pool) + .await + .map_err(TuskError::Database)?; Ok(rows .iter() diff --git a/src-tauri/src/error.rs b/src-tauri/src/error.rs index 6f718f4..1f5c048 100644 --- a/src-tauri/src/error.rs +++ b/src-tauri/src/error.rs @@ -20,6 +20,9 @@ pub enum TuskError { #[error("Connection is in read-only mode")] ReadOnly, + #[error("AI error: {0}")] + Ai(String), + #[error("{0}")] Custom(String), } diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 1458de1..1239bf3 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -43,6 +43,7 @@ pub fn run() { commands::connections::disconnect, commands::connections::set_read_only, commands::connections::get_read_only, + commands::connections::get_db_flavor, // queries commands::queries::execute_query, // schema diff --git a/src-tauri/src/models/ai.rs b/src-tauri/src/models/ai.rs new file mode 100644 index 0000000..1dabb7d --- /dev/null +++ b/src-tauri/src/models/ai.rs @@ -0,0 +1,44 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AiSettings { + pub ollama_url: String, + pub model: String, +} + +impl Default for AiSettings { + fn default() -> Self { + Self { + ollama_url: "http://localhost:11434".to_string(), + model: String::new(), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct OllamaChatMessage { + pub role: String, + pub content: String, +} + +#[derive(Debug, Serialize)] +pub struct OllamaChatRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, +} + +#[derive(Debug, Deserialize)] +pub struct OllamaChatResponse { + pub message: OllamaChatMessage, +} + +#[derive(Debug, Deserialize)] +pub struct OllamaTagsResponse { + pub models: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OllamaModel { + pub name: String, +} diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index fefae63..94bfb15 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -1,12 +1,21 @@ +use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; use std::path::PathBuf; use tokio::sync::RwLock; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum DbFlavor { + PostgreSQL, + Greenplum, +} + pub struct AppState { pub pools: RwLock>, pub config_path: RwLock>, pub read_only: RwLock>, + pub db_flavors: RwLock>, } impl AppState { @@ -15,6 +24,7 @@ impl AppState { pools: RwLock::new(HashMap::new()), config_path: RwLock::new(None), read_only: RwLock::new(HashMap::new()), + db_flavors: RwLock::new(HashMap::new()), } } @@ -22,4 +32,9 @@ impl AppState { let map = self.read_only.read().await; map.get(id).copied().unwrap_or(true) } + + pub async fn get_flavor(&self, id: &str) -> DbFlavor { + let map = self.db_flavors.read().await; + map.get(id).copied().unwrap_or(DbFlavor::PostgreSQL) + } } diff --git a/src/App.tsx b/src/App.tsx index c56466a..b94aeb5 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -13,7 +13,7 @@ import { useAppStore } from "@/stores/app-store"; import type { Tab } from "@/types"; export default function App() { - const { activeConnectionId, addTab } = useAppStore(); + const { activeConnectionId, currentDatabase, addTab } = useAppStore(); const handleNewQuery = useCallback(() => { if (!activeConnectionId) return; @@ -22,10 +22,11 @@ export default function App() { type: "query", title: "New Query", connectionId: activeConnectionId, + database: currentDatabase ?? undefined, sql: "", }; addTab(tab); - }, [activeConnectionId, addTab]); + }, [activeConnectionId, currentDatabase, addTab]); const handleCloseTab = useCallback(() => { const { activeTabId, closeTab } = useAppStore.getState(); diff --git a/src/components/ai/AiBar.tsx b/src/components/ai/AiBar.tsx new file mode 100644 index 0000000..b110ea8 --- /dev/null +++ b/src/components/ai/AiBar.tsx @@ -0,0 +1,92 @@ +import { useState } from "react"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { AiSettingsPopover } from "./AiSettingsPopover"; +import { useGenerateSql } from "@/hooks/use-ai"; +import { Sparkles, Loader2, X } from "lucide-react"; +import { toast } from "sonner"; + +interface Props { + connectionId: string; + onSqlGenerated: (sql: string) => void; + onClose: () => void; + onExecute?: () => void; +} + +export function AiBar({ connectionId, onSqlGenerated, onClose, onExecute }: Props) { + const [prompt, setPrompt] = useState(""); + const generateMutation = useGenerateSql(); + + const handleGenerate = () => { + if (!prompt.trim() || generateMutation.isPending) return; + generateMutation.mutate( + { connectionId, prompt }, + { + onSuccess: (sql) => { + onSqlGenerated(sql); + setPrompt(""); + }, + onError: (err) => { + toast.error("AI generation failed", { description: String(err) }); + }, + } + ); + }; + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter" && (e.ctrlKey || e.metaKey)) { + e.preventDefault(); + e.stopPropagation(); + onExecute?.(); + return; + } + if (e.key === "Enter" && !e.shiftKey) { + e.preventDefault(); + e.stopPropagation(); + handleGenerate(); + return; + } + if (e.key === "Escape") { + e.stopPropagation(); + onClose(); + } + }; + + return ( +
+ + setPrompt(e.target.value)} + onKeyDown={handleKeyDown} + placeholder="Describe the query you want..." + className="h-7 min-w-0 flex-1 text-xs" + autoFocus + disabled={generateMutation.isPending} + /> + + + +
+ ); +} diff --git a/src/components/ai/AiSettingsPopover.tsx b/src/components/ai/AiSettingsPopover.tsx new file mode 100644 index 0000000..c942d17 --- /dev/null +++ b/src/components/ai/AiSettingsPopover.tsx @@ -0,0 +1,121 @@ +import { useState } from "react"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { useAiSettings, useSaveAiSettings, useOllamaModels } from "@/hooks/use-ai"; +import { Settings, RefreshCw, Loader2 } from "lucide-react"; +import { toast } from "sonner"; + +export function AiSettingsPopover() { + const { data: settings } = useAiSettings(); + const saveMutation = useSaveAiSettings(); + + const [url, setUrl] = useState(null); + const [model, setModel] = useState(null); + + const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434"; + const currentModel = model ?? settings?.model ?? ""; + + const { + data: models, + isLoading: modelsLoading, + isError: modelsError, + refetch: refetchModels, + } = useOllamaModels(currentUrl); + + const handleSave = () => { + saveMutation.mutate( + { ollama_url: currentUrl, model: currentModel }, + { + onSuccess: () => toast.success("AI settings saved"), + onError: (err) => + toast.error("Failed to save AI settings", { + description: String(err), + }), + } + ); + }; + + return ( + + + + + +
+

Ollama Settings

+ +
+ + setUrl(e.target.value)} + placeholder="http://localhost:11434" + className="h-8 text-xs" + /> +
+ +
+
+ + +
+ {modelsError ? ( +

+ Cannot connect to Ollama +

+ ) : ( + + )} +
+ + +
+
+
+ ); +} diff --git a/src/components/connections/ConnectionList.tsx b/src/components/connections/ConnectionList.tsx index 48ebdaa..27a8230 100644 --- a/src/components/connections/ConnectionList.tsx +++ b/src/components/connections/ConnectionList.tsx @@ -25,6 +25,7 @@ import { import type { ConnectionConfig } from "@/types"; import { EnvironmentBadge } from "@/components/connections/EnvironmentBadge"; import { ENVIRONMENTS } from "@/lib/environment"; +import { useState } from "react"; interface Props { open: boolean; @@ -39,8 +40,10 @@ export function ConnectionList({ open, onOpenChange, onEdit, onNew }: Props) { const connectMutation = useConnect(); const disconnectMutation = useDisconnect(); const { connectedIds, activeConnectionId } = useAppStore(); + const [connectingId, setConnectingId] = useState(null); const handleConnect = (conn: ConnectionConfig) => { + setConnectingId(conn.id); connectMutation.mutate(conn, { onSuccess: () => { toast.success(`Connected to ${conn.name}`); @@ -49,6 +52,9 @@ export function ConnectionList({ open, onOpenChange, onEdit, onNew }: Props) { onError: (err) => { toast.error("Connection failed", { description: String(err) }); }, + onSettled: () => { + setConnectingId(null); + }, }); }; @@ -169,9 +175,9 @@ export function ConnectionList({ open, onOpenChange, onEdit, onNew }: Props) { variant="ghost" className="h-7 w-7" onClick={() => handleConnect(conn)} - disabled={connectMutation.isPending} + disabled={connectingId !== null} > - {connectMutation.isPending ? ( + {connectingId === conn.id ? ( ) : ( diff --git a/src/components/history/HistoryPanel.tsx b/src/components/history/HistoryPanel.tsx index 20f3d7c..e622b4d 100644 --- a/src/components/history/HistoryPanel.tsx +++ b/src/components/history/HistoryPanel.tsx @@ -12,13 +12,14 @@ export function HistoryPanel() { const { data: entries } = useHistory(undefined, search || undefined); const clearMutation = useClearHistory(); - const handleClick = (sql: string, connectionId: string) => { + const handleClick = (sql: string, connectionId: string, database?: string) => { const cid = activeConnectionId ?? connectionId; const tab: Tab = { id: crypto.randomUUID(), type: "query", title: "History Query", connectionId: cid, + database, sql, }; addTab(tab); @@ -52,7 +53,7 @@ export function HistoryPanel() { + {result && result.columns.length > 0 && ( @@ -277,7 +289,18 @@ export function WorkspacePanel({ )} -
+ {aiBarOpen && ( + { + setSqlValue(sql); + onSqlChange?.(sql); + }} + onClose={() => setAiBarOpen(false)} + onExecute={handleExecute} + /> + )} +
- {(explainData || result || error) && ( -
- - {explainData && ( +
+ {(explainData || result || error) && ( +
- )} - {resultView === "results" && result && result.columns.length > 0 && ( -
+ {explainData && ( - -
+ )} + {resultView === "results" && result && result.columns.length > 0 && ( +
+ + +
+ )} +
+ )} +
+ {resultView === "explain" && explainData ? ( + + ) : ( + )}
- )} - {resultView === "explain" && explainData ? ( - - ) : ( - - )} +
diff --git a/src/hooks/use-ai.ts b/src/hooks/use-ai.ts new file mode 100644 index 0000000..5070dbc --- /dev/null +++ b/src/hooks/use-ai.ts @@ -0,0 +1,47 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { + getAiSettings, + saveAiSettings, + listOllamaModels, + generateSql, +} from "@/lib/tauri"; +import type { AiSettings } from "@/types"; + +export function useAiSettings() { + return useQuery({ + queryKey: ["ai-settings"], + queryFn: getAiSettings, + staleTime: Infinity, + }); +} + +export function useSaveAiSettings() { + const queryClient = useQueryClient(); + return useMutation({ + mutationFn: (settings: AiSettings) => saveAiSettings(settings), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ["ai-settings"] }); + }, + }); +} + +export function useOllamaModels(ollamaUrl: string | undefined) { + return useQuery({ + queryKey: ["ollama-models", ollamaUrl], + queryFn: () => listOllamaModels(ollamaUrl!), + enabled: !!ollamaUrl, + retry: false, + }); +} + +export function useGenerateSql() { + return useMutation({ + mutationFn: ({ + connectionId, + prompt, + }: { + connectionId: string; + prompt: string; + }) => generateSql(connectionId, prompt), + }); +} diff --git a/src/hooks/use-connections.ts b/src/hooks/use-connections.ts index 9094687..d7361b9 100644 --- a/src/hooks/use-connections.ts +++ b/src/hooks/use-connections.ts @@ -51,19 +51,19 @@ export function useTestConnection() { } export function useConnect() { - const { addConnectedId, setActiveConnectionId, setPgVersion, setCurrentDatabase } = + const { addConnectedId, setActiveConnectionId, setPgVersion, setDbFlavor, setCurrentDatabase } = useAppStore(); return useMutation({ mutationFn: async (config: ConnectionConfig) => { - await connectDb(config); - const version = await testConnection(config); - return { id: config.id, version, database: config.database }; + const result = await connectDb(config); + return { id: config.id, ...result, database: config.database }; }, - onSuccess: ({ id, version, database }) => { + onSuccess: ({ id, version, flavor, database }) => { addConnectedId(id); setActiveConnectionId(id); setPgVersion(version); + setDbFlavor(id, flavor); setCurrentDatabase(database); }, }); @@ -91,17 +91,17 @@ export function useDisconnect() { export function useReconnect() { const queryClient = useQueryClient(); - const { setPgVersion, setCurrentDatabase } = useAppStore(); + const { setPgVersion, setDbFlavor, setCurrentDatabase } = useAppStore(); return useMutation({ mutationFn: async (config: ConnectionConfig) => { await disconnectDb(config.id); - await connectDb(config); - const version = await testConnection(config); - return { version, database: config.database }; + const result = await connectDb(config); + return { id: config.id, ...result, database: config.database }; }, - onSuccess: ({ version, database }) => { + onSuccess: ({ id, version, flavor, database }) => { setPgVersion(version); + setDbFlavor(id, flavor); setCurrentDatabase(database); queryClient.invalidateQueries(); }, diff --git a/src/lib/tauri.ts b/src/lib/tauri.ts index 10316d2..7c7d295 100644 --- a/src/lib/tauri.ts +++ b/src/lib/tauri.ts @@ -2,6 +2,8 @@ import { invoke } from "@tauri-apps/api/core"; import { listen, type UnlistenFn } from "@tauri-apps/api/event"; import type { ConnectionConfig, + ConnectResult, + DbFlavor, QueryResult, PaginatedQueryResult, SchemaObject, @@ -40,7 +42,7 @@ export const testConnection = (config: ConnectionConfig) => invoke("test_connection", { config }); export const connectDb = (config: ConnectionConfig) => - invoke("connect", { config }); + invoke("connect", { config }); export const disconnectDb = (id: string) => invoke("disconnect", { id }); @@ -55,6 +57,9 @@ export const setReadOnly = (connectionId: string, readOnly: boolean) => export const getReadOnly = (connectionId: string) => invoke("get_read_only", { connectionId }); +export const getDbFlavor = (connectionId: string) => + invoke("get_db_flavor", { connectionId }); + // Queries export const executeQuery = (connectionId: string, sql: string) => invoke("execute_query", { connectionId, sql }); diff --git a/src/stores/app-store.ts b/src/stores/app-store.ts index fe0faa5..60bb520 100644 --- a/src/stores/app-store.ts +++ b/src/stores/app-store.ts @@ -1,5 +1,5 @@ import { create } from "zustand"; -import type { ConnectionConfig, Tab } from "@/types"; +import type { ConnectionConfig, DbFlavor, Tab } from "@/types"; interface AppState { connections: ConnectionConfig[]; @@ -7,6 +7,7 @@ interface AppState { currentDatabase: string | null; connectedIds: Set; readOnlyMap: Record; + dbFlavors: Record; tabs: Tab[]; activeTabId: string | null; sidebarWidth: number; @@ -18,6 +19,7 @@ interface AppState { addConnectedId: (id: string) => void; removeConnectedId: (id: string) => void; setReadOnly: (connectionId: string, readOnly: boolean) => void; + setDbFlavor: (connectionId: string, flavor: DbFlavor) => void; setPgVersion: (version: string | null) => void; addTab: (tab: Tab) => void; @@ -33,6 +35,7 @@ export const useAppStore = create((set) => ({ currentDatabase: null, connectedIds: new Set(), readOnlyMap: {}, + dbFlavors: {}, tabs: [], activeTabId: null, sidebarWidth: 260, @@ -50,13 +53,22 @@ export const useAppStore = create((set) => ({ set((state) => { const next = new Set(state.connectedIds); next.delete(id); - const { [id]: _, ...restRo } = state.readOnlyMap; - return { connectedIds: next, readOnlyMap: restRo }; + const restRo = Object.fromEntries( + Object.entries(state.readOnlyMap).filter(([k]) => k !== id) + ); + const restFlavors = Object.fromEntries( + Object.entries(state.dbFlavors).filter(([k]) => k !== id) + ); + return { connectedIds: next, readOnlyMap: restRo, dbFlavors: restFlavors }; }), setReadOnly: (connectionId, readOnly) => set((state) => ({ readOnlyMap: { ...state.readOnlyMap, [connectionId]: readOnly }, })), + setDbFlavor: (connectionId, flavor) => + set((state) => ({ + dbFlavors: { ...state.dbFlavors, [connectionId]: flavor }, + })), setPgVersion: (version) => set({ pgVersion: version }), addTab: (tab) => diff --git a/src/types/index.ts b/src/types/index.ts index 6ef3b7c..f26897d 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -1,3 +1,10 @@ +export type DbFlavor = "postgresql" | "greenplum"; + +export interface ConnectResult { + version: string; + flavor: DbFlavor; +} + export interface ConnectionConfig { id: string; name: string;