feat: add Greenplum 7 compatibility and AI SQL generation

Greenplum 7 (PG12-based) compatibility:
- Auto-detect GP via version() string, store DbFlavor per connection
- connect returns ConnectResult with version + flavor
- Fix pg_total_relation_size to use c.oid (universal, safer on both PG/GP)
- Branch is_identity column query for GP (lacks the column)
- Branch list_sessions wait_event fields for GP
- Exclude gp_toolkit schema in schema listing, completion, lookup, AI context
- Smart StatusBar version display: GP shows "GP 7.0.0 (PG 12.4)"
- Fix connection list spinner showing on all cards during connect

AI SQL generation (Ollama):
- Add AI settings, model selection, and generate_sql command
- Frontend AI panel with prompt input and SQL output

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 18:24:06 +03:00
parent d5cff8bd5e
commit e8d99c645b
27 changed files with 1276 additions and 113 deletions

315
src-tauri/Cargo.lock generated
View File

@@ -438,6 +438,16 @@ dependencies = [
"version_check",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation"
version = "0.10.1"
@@ -461,9 +471,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
dependencies = [
"bitflags 2.10.0",
"core-foundation",
"core-foundation 0.10.1",
"core-graphics-types",
"foreign-types",
"foreign-types 0.5.0",
"libc",
]
@@ -474,7 +484,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
dependencies = [
"bitflags 2.10.0",
"core-foundation",
"core-foundation 0.10.1",
"libc",
]
@@ -920,6 +930,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "fdeflate"
version = "0.3.7"
@@ -978,6 +994,15 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared 0.1.1",
]
[[package]]
name = "foreign-types"
version = "0.5.0"
@@ -985,7 +1010,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [
"foreign-types-macros",
"foreign-types-shared",
"foreign-types-shared 0.3.1",
]
[[package]]
@@ -999,6 +1024,12 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]]
name = "foreign-types-shared"
version = "0.3.1"
@@ -1424,6 +1455,25 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "h2"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http",
"indexmap 2.13.0",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@@ -1568,6 +1618,7 @@ dependencies = [
"bytes",
"futures-channel",
"futures-core",
"h2",
"http",
"http-body",
"httparse",
@@ -1580,6 +1631,38 @@ dependencies = [
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.27.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58"
dependencies = [
"http",
"hyper",
"hyper-util",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
]
[[package]]
name = "hyper-util"
version = "0.1.20"
@@ -1598,9 +1681,11 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"socket2",
"system-configuration",
"tokio",
"tower-service",
"tracing",
"windows-registry",
]
[[package]]
@@ -1994,6 +2079,12 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]]
name = "litemap"
version = "0.8.1"
@@ -2131,6 +2222,23 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "native-tls"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cdede44f9a69cab2899a2049e2c3bd49bf911a157f6a3353d4a91c61abbce44"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndk"
version = "0.9.0"
@@ -2487,6 +2595,50 @@ dependencies = [
"pathdiff",
]
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "option-ext"
version = "0.2.0"
@@ -3099,6 +3251,46 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c"
[[package]]
name = "reqwest"
version = "0.12.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
dependencies = [
"base64 0.22.1",
"bytes",
"encoding_rs",
"futures-core",
"h2",
"http",
"http-body",
"http-body-util",
"hyper",
"hyper-rustls",
"hyper-tls",
"hyper-util",
"js-sys",
"log",
"mime",
"native-tls",
"percent-encoding",
"pin-project-lite",
"rustls-pki-types",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tower",
"tower-http",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "reqwest"
version = "0.13.2"
@@ -3245,6 +3437,19 @@ dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls"
version = "0.23.36"
@@ -3300,6 +3505,15 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "schannel"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "schemars"
version = "0.8.22"
@@ -3371,6 +3585,29 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "selectors"
version = "0.24.0"
@@ -4092,6 +4329,27 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "system-configuration"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "system-deps"
version = "6.2.2"
@@ -4113,7 +4371,7 @@ checksum = "f3a753bdc39c07b192151523a3f77cd0394aa75413802c883a0f6f6a0e5ee2e7"
dependencies = [
"bitflags 2.10.0",
"block2",
"core-foundation",
"core-foundation 0.10.1",
"core-graphics",
"crossbeam-channel",
"dispatch",
@@ -4192,7 +4450,7 @@ dependencies = [
"percent-encoding",
"plist",
"raw-window-handle",
"reqwest",
"reqwest 0.13.2",
"serde",
"serde_json",
"serde_repr",
@@ -4455,6 +4713,19 @@ dependencies = [
"toml 0.9.12+spec-1.1.0",
]
[[package]]
name = "tempfile"
version = "3.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]]
name = "tendril"
version = "0.4.3"
@@ -4590,6 +4861,26 @@ dependencies = [
"syn 2.0.114",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
dependencies = [
"rustls",
"tokio",
]
[[package]]
name = "tokio-stream"
version = "0.1.18"
@@ -4826,6 +5117,7 @@ dependencies = [
"csv",
"hex",
"log",
"reqwest 0.12.28",
"rmcp",
"schemars 1.2.1",
"serde",
@@ -5405,6 +5697,17 @@ dependencies = [
"windows-link 0.1.3",
]
[[package]]
name = "windows-registry"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720"
dependencies = [
"windows-link 0.2.1",
"windows-result 0.4.1",
"windows-strings 0.5.1",
]
[[package]]
name = "windows-result"
version = "0.3.4"

View File

@@ -30,6 +30,7 @@ csv = "1"
log = "0.4"
hex = "0.4"
bigdecimal = { version = "0.4", features = ["serde"] }
reqwest = { version = "0.12", features = ["json"] }
rmcp = { version = "0.15", features = ["server", "macros", "transport-streamable-http-server"] }
axum = "0.8"
schemars = "1"

View File

@@ -0,0 +1,299 @@
use crate::error::{TuskError, TuskResult};
use crate::models::ai::{
AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel,
OllamaTagsResponse,
};
use crate::state::AppState;
use sqlx::Row;
use std::collections::BTreeMap;
use std::fs;
use std::sync::Arc;
use std::time::Duration;
use tauri::{AppHandle, Manager, State};
fn http_client() -> reqwest::Client {
reqwest::Client::builder()
.connect_timeout(Duration::from_secs(5))
.timeout(Duration::from_secs(300))
.build()
.unwrap_or_default()
}
fn get_ai_settings_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
let dir = app
.path()
.app_data_dir()
.map_err(|e| TuskError::Custom(e.to_string()))?;
fs::create_dir_all(&dir)?;
Ok(dir.join("ai_settings.json"))
}
#[tauri::command]
pub async fn get_ai_settings(app: AppHandle) -> TuskResult<AiSettings> {
let path = get_ai_settings_path(&app)?;
if !path.exists() {
return Ok(AiSettings::default());
}
let data = fs::read_to_string(&path)?;
let settings: AiSettings = serde_json::from_str(&data)?;
Ok(settings)
}
#[tauri::command]
pub async fn save_ai_settings(app: AppHandle, settings: AiSettings) -> TuskResult<()> {
let path = get_ai_settings_path(&app)?;
let data = serde_json::to_string_pretty(&settings)?;
fs::write(&path, data)?;
Ok(())
}
#[tauri::command]
pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaModel>> {
let url = format!("{}/api/tags", ollama_url.trim_end_matches('/'));
let resp = http_client()
.get(&url)
.send()
.await
.map_err(|e| TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e)))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Ollama error ({}): {}",
status, body
)));
}
let tags: OllamaTagsResponse = resp
.json()
.await
.map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?;
Ok(tags.models)
}
#[tauri::command]
pub async fn generate_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
prompt: String,
) -> TuskResult<String> {
// Load AI settings
let settings = {
let path = get_ai_settings_path(&app)?;
if !path.exists() {
return Err(TuskError::Ai(
"No AI model selected. Open AI settings to choose a model.".to_string(),
));
}
let data = fs::read_to_string(&path)?;
serde_json::from_str::<AiSettings>(&data)?
};
if settings.model.is_empty() {
return Err(TuskError::Ai(
"No AI model selected. Open AI settings to choose a model.".to_string(),
));
}
// Build schema context
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a PostgreSQL SQL generator. Given the database schema below and a natural language request, \
output ONLY a valid PostgreSQL SQL query. Do not include any explanation, markdown formatting, \
or code fences. Output raw SQL only.\n\n\
RULES:\n\
- Use FK relationships for correct JOIN conditions.\n\
- timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\
- interval cannot be cast to numeric directly.\n\
- When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\
- Use COALESCE for nullable columns in aggregations when appropriate.\n\
- Prefer LEFT JOIN when the related row may not exist.\n\n\
DATABASE SCHEMA:\n{}",
schema_text
);
let request = OllamaChatRequest {
model: settings.model,
messages: vec![
OllamaChatMessage {
role: "system".to_string(),
content: system_prompt,
},
OllamaChatMessage {
role: "user".to_string(),
content: prompt,
},
],
stream: false,
};
let url = format!(
"{}/api/chat",
settings.ollama_url.trim_end_matches('/')
);
let resp = http_client()
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| {
TuskError::Ai(format!(
"Cannot connect to Ollama at {}: {}",
settings.ollama_url, e
))
})?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Ollama error ({}): {}",
status, body
)));
}
let chat_resp: OllamaChatResponse = resp
.json()
.await
.map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?;
let sql = clean_sql_response(&chat_resp.message.content);
Ok(sql)
}
async fn build_schema_context(
state: &AppState,
connection_id: &str,
) -> TuskResult<String> {
let pools = state.pools.read().await;
let pool = pools
.get(connection_id)
.ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?;
// Single query: all columns with real type names (enum types show actual name, not USER-DEFINED)
let col_rows = sqlx::query(
"SELECT \
c.table_schema, c.table_name, c.column_name, \
CASE WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name ELSE c.data_type END AS data_type, \
c.is_nullable = 'NO' AS not_null, \
EXISTS( \
SELECT 1 FROM information_schema.table_constraints tc \
JOIN information_schema.key_column_usage kcu \
ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema \
WHERE tc.constraint_type = 'PRIMARY KEY' \
AND tc.table_schema = c.table_schema \
AND tc.table_name = c.table_name \
AND kcu.column_name = c.column_name \
) AS is_pk \
FROM information_schema.columns c \
WHERE c.table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
ORDER BY c.table_schema, c.table_name, c.ordinal_position",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
// Group columns by schema.table
let mut tables: BTreeMap<String, Vec<String>> = BTreeMap::new();
for row in &col_rows {
let schema: String = row.get(0);
let table: String = row.get(1);
let col_name: String = row.get(2);
let data_type: String = row.get(3);
let not_null: bool = row.get(4);
let is_pk: bool = row.get(5);
let mut parts = vec![col_name, data_type];
if is_pk {
parts.push("PK".to_string());
}
if not_null {
parts.push("NOT NULL".to_string());
}
let key = format!("{}.{}", schema, table);
tables.entry(key).or_default().push(parts.join(" "));
}
let mut lines: Vec<String> = tables
.into_iter()
.map(|(key, cols)| format!("{}({})", key, cols.join(", ")))
.collect();
// Fetch FK relationships
let fks = fetch_foreign_keys_from_pool(pool).await?;
for fk in &fks {
lines.push(fk.clone());
}
Ok(lines.join("\n"))
}
async fn fetch_foreign_keys_from_pool(
pool: &sqlx::PgPool,
) -> TuskResult<Vec<String>> {
let rows = sqlx::query(
"SELECT \
cn.nspname AS schema_name, cl.relname AS table_name, \
array_agg(DISTINCT a.attname ORDER BY a.attname) AS columns, \
cnf.nspname AS ref_schema, clf.relname AS ref_table, \
array_agg(DISTINCT af.attname ORDER BY af.attname) AS ref_columns \
FROM pg_constraint con \
JOIN pg_class cl ON con.conrelid = cl.oid \
JOIN pg_namespace cn ON cl.relnamespace = cn.oid \
JOIN pg_class clf ON con.confrelid = clf.oid \
JOIN pg_namespace cnf ON clf.relnamespace = cnf.oid \
JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \
JOIN pg_attribute af ON af.attrelid = con.confrelid AND af.attnum = ANY(con.confkey) \
WHERE con.contype = 'f' \
AND cn.nspname NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \
GROUP BY cn.nspname, cl.relname, cnf.nspname, clf.relname, con.oid",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
let fks: Vec<String> = rows
.iter()
.map(|r| {
let schema: String = r.get(0);
let table: String = r.get(1);
let cols: Vec<String> = r.get(2);
let ref_schema: String = r.get(3);
let ref_table: String = r.get(4);
let ref_cols: Vec<String> = r.get(5);
format!(
"FK: {}.{}({}) -> {}.{}({})",
schema,
table,
cols.join(", "),
ref_schema,
ref_table,
ref_cols.join(", ")
)
})
.collect();
Ok(fks)
}
fn clean_sql_response(raw: &str) -> String {
let trimmed = raw.trim();
// Remove markdown code fences
let without_fences = if trimmed.starts_with("```") {
let inner = trimmed
.strip_prefix("```sql")
.or_else(|| trimmed.strip_prefix("```SQL"))
.or_else(|| trimmed.strip_prefix("```"))
.unwrap_or(trimmed);
inner.strip_suffix("```").unwrap_or(inner)
} else {
trimmed
};
without_fences.trim().to_string()
}

View File

@@ -1,12 +1,19 @@
use crate::error::{TuskError, TuskResult};
use crate::models::connection::ConnectionConfig;
use crate::state::AppState;
use crate::state::{AppState, DbFlavor};
use serde::Serialize;
use sqlx::PgPool;
use sqlx::Row;
use std::fs;
use std::sync::Arc;
use tauri::{AppHandle, Manager, State};
#[derive(Debug, Clone, Serialize)]
pub struct ConnectResult {
pub version: String,
pub flavor: DbFlavor,
}
fn get_connections_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
let dir = app
.path()
@@ -72,6 +79,9 @@ pub async fn delete_connection(
let mut ro = state.read_only.write().await;
ro.remove(&id);
let mut flavors = state.db_flavors.write().await;
flavors.remove(&id);
Ok(())
}
@@ -92,7 +102,10 @@ pub async fn test_connection(config: ConnectionConfig) -> TuskResult<String> {
}
#[tauri::command]
pub async fn connect(state: State<'_, Arc<AppState>>, config: ConnectionConfig) -> TuskResult<()> {
pub async fn connect(
state: State<'_, Arc<AppState>>,
config: ConnectionConfig,
) -> TuskResult<ConnectResult> {
let pool = PgPool::connect(&config.connection_url())
.await
.map_err(TuskError::Database)?;
@@ -103,13 +116,29 @@ pub async fn connect(state: State<'_, Arc<AppState>>, config: ConnectionConfig)
.await
.map_err(TuskError::Database)?;
// Detect database flavor via version()
let row = sqlx::query("SELECT version()")
.fetch_one(&pool)
.await
.map_err(TuskError::Database)?;
let version: String = row.get(0);
let flavor = if version.to_lowercase().contains("greenplum") {
DbFlavor::Greenplum
} else {
DbFlavor::PostgreSQL
};
let mut pools = state.pools.write().await;
pools.insert(config.id.clone(), pool);
let mut ro = state.read_only.write().await;
ro.insert(config.id.clone(), true);
Ok(())
let mut flavors = state.db_flavors.write().await;
flavors.insert(config.id.clone(), flavor);
Ok(ConnectResult { version, flavor })
}
#[tauri::command]
@@ -149,6 +178,9 @@ pub async fn disconnect(state: State<'_, Arc<AppState>>, id: String) -> TuskResu
let mut ro = state.read_only.write().await;
ro.remove(&id);
let mut flavors = state.db_flavors.write().await;
flavors.remove(&id);
Ok(())
}
@@ -170,3 +202,11 @@ pub async fn get_read_only(
) -> TuskResult<bool> {
Ok(state.is_read_only(&connection_id).await)
}
#[tauri::command]
pub async fn get_db_flavor(
state: State<'_, Arc<AppState>>,
connection_id: String,
) -> TuskResult<DbFlavor> {
Ok(state.get_flavor(&connection_id).await)
}

View File

@@ -82,7 +82,7 @@ async fn search_database_inner(
"SELECT table_schema, table_name, data_type \
FROM information_schema.columns \
WHERE column_name = $1 \
AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast')",
AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit')",
)
.bind(column_name)
.fetch_all(pool)

View File

@@ -1,6 +1,6 @@
use crate::error::{TuskError, TuskResult};
use crate::models::management::*;
use crate::state::AppState;
use crate::state::{AppState, DbFlavor};
use crate::utils::escape_ident;
use sqlx::Row;
use std::sync::Arc;
@@ -514,22 +514,32 @@ pub async fn list_sessions(
state: State<'_, Arc<AppState>>,
connection_id: String,
) -> TuskResult<Vec<SessionInfo>> {
let flavor = state.get_flavor(&connection_id).await;
let pools = state.pools.read().await;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query(
let sql = if flavor == DbFlavor::Greenplum {
"SELECT pid, usename, datname, state, query, \
query_start::text, NULL::text as wait_event_type, NULL::text as wait_event, \
client_addr::text \
FROM pg_stat_activity \
WHERE datname IS NOT NULL \
ORDER BY query_start DESC NULLS LAST"
} else {
"SELECT pid, usename, datname, state, query, \
query_start::text, wait_event_type, wait_event, \
client_addr::text \
FROM pg_stat_activity \
WHERE datname IS NOT NULL \
ORDER BY query_start DESC NULLS LAST",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
ORDER BY query_start DESC NULLS LAST"
};
let rows = sqlx::query(sql)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
let sessions = rows
.iter()

View File

@@ -1,6 +1,6 @@
use crate::error::{TuskError, TuskResult};
use crate::models::schema::{ColumnDetail, ColumnInfo, ConstraintInfo, IndexInfo, SchemaObject};
use crate::state::AppState;
use crate::state::{AppState, DbFlavor};
use sqlx::Row;
use std::collections::HashMap;
use std::sync::Arc;
@@ -37,14 +37,21 @@ pub async fn list_schemas_core(
.get(connection_id)
.ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?;
let rows = sqlx::query(
let flavor = state.get_flavor(connection_id).await;
let sql = if flavor == DbFlavor::Greenplum {
"SELECT schema_name FROM information_schema.schemata \
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
ORDER BY schema_name"
} else {
"SELECT schema_name FROM information_schema.schemata \
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast') \
ORDER BY schema_name",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
ORDER BY schema_name"
};
let rows = sqlx::query(sql)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
Ok(rows.iter().map(|r| r.get::<String, _>(0)).collect())
}
@@ -70,7 +77,7 @@ pub async fn list_tables_core(
let rows = sqlx::query(
"SELECT t.table_name, \
c.reltuples::bigint as row_count, \
pg_total_relation_size(quote_ident(t.table_schema) || '.' || quote_ident(t.table_name))::bigint as size_bytes \
pg_total_relation_size(c.oid)::bigint as size_bytes \
FROM information_schema.tables t \
LEFT JOIN pg_class c ON c.relname = t.table_name \
AND c.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = $1) \
@@ -387,20 +394,28 @@ pub async fn get_completion_schema(
state: State<'_, Arc<AppState>>,
connection_id: String,
) -> TuskResult<HashMap<String, HashMap<String, Vec<String>>>> {
let flavor = state.get_flavor(&connection_id).await;
let pools = state.pools.read().await;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query(
let sql = if flavor == DbFlavor::Greenplum {
"SELECT table_schema, table_name, column_name \
FROM information_schema.columns \
WHERE table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
ORDER BY table_schema, table_name, ordinal_position"
} else {
"SELECT table_schema, table_name, column_name \
FROM information_schema.columns \
WHERE table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast') \
ORDER BY table_schema, table_name, ordinal_position",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
ORDER BY table_schema, table_name, ordinal_position"
};
let rows = sqlx::query(sql)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
let mut result: HashMap<String, HashMap<String, Vec<String>>> = HashMap::new();
for row in &rows {
@@ -426,25 +441,36 @@ pub async fn get_column_details(
schema: String,
table: String,
) -> TuskResult<Vec<ColumnDetail>> {
let flavor = state.get_flavor(&connection_id).await;
let pools = state.pools.read().await;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query(
let sql = if flavor == DbFlavor::Greenplum {
"SELECT c.column_name, c.data_type, \
c.is_nullable = 'YES' as is_nullable, \
c.column_default, \
false as is_identity \
FROM information_schema.columns c \
WHERE c.table_schema = $1 AND c.table_name = $2 \
ORDER BY c.ordinal_position"
} else {
"SELECT c.column_name, c.data_type, \
c.is_nullable = 'YES' as is_nullable, \
c.column_default, \
c.is_identity = 'YES' as is_identity \
FROM information_schema.columns c \
WHERE c.table_schema = $1 AND c.table_name = $2 \
ORDER BY c.ordinal_position",
)
.bind(&schema)
.bind(&table)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
ORDER BY c.ordinal_position"
};
let rows = sqlx::query(sql)
.bind(&schema)
.bind(&table)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
Ok(rows
.iter()

View File

@@ -20,6 +20,9 @@ pub enum TuskError {
#[error("Connection is in read-only mode")]
ReadOnly,
#[error("AI error: {0}")]
Ai(String),
#[error("{0}")]
Custom(String),
}

View File

@@ -43,6 +43,7 @@ pub fn run() {
commands::connections::disconnect,
commands::connections::set_read_only,
commands::connections::get_read_only,
commands::connections::get_db_flavor,
// queries
commands::queries::execute_query,
// schema

View File

@@ -0,0 +1,44 @@
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiSettings {
pub ollama_url: String,
pub model: String,
}
impl Default for AiSettings {
fn default() -> Self {
Self {
ollama_url: "http://localhost:11434".to_string(),
model: String::new(),
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OllamaChatMessage {
pub role: String,
pub content: String,
}
#[derive(Debug, Serialize)]
pub struct OllamaChatRequest {
pub model: String,
pub messages: Vec<OllamaChatMessage>,
pub stream: bool,
}
#[derive(Debug, Deserialize)]
pub struct OllamaChatResponse {
pub message: OllamaChatMessage,
}
#[derive(Debug, Deserialize)]
pub struct OllamaTagsResponse {
pub models: Vec<OllamaModel>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModel {
pub name: String,
}

View File

@@ -1,12 +1,21 @@
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::collections::HashMap;
use std::path::PathBuf;
use tokio::sync::RwLock;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DbFlavor {
PostgreSQL,
Greenplum,
}
pub struct AppState {
pub pools: RwLock<HashMap<String, PgPool>>,
pub config_path: RwLock<Option<PathBuf>>,
pub read_only: RwLock<HashMap<String, bool>>,
pub db_flavors: RwLock<HashMap<String, DbFlavor>>,
}
impl AppState {
@@ -15,6 +24,7 @@ impl AppState {
pools: RwLock::new(HashMap::new()),
config_path: RwLock::new(None),
read_only: RwLock::new(HashMap::new()),
db_flavors: RwLock::new(HashMap::new()),
}
}
@@ -22,4 +32,9 @@ impl AppState {
let map = self.read_only.read().await;
map.get(id).copied().unwrap_or(true)
}
pub async fn get_flavor(&self, id: &str) -> DbFlavor {
let map = self.db_flavors.read().await;
map.get(id).copied().unwrap_or(DbFlavor::PostgreSQL)
}
}

View File

@@ -13,7 +13,7 @@ import { useAppStore } from "@/stores/app-store";
import type { Tab } from "@/types";
export default function App() {
const { activeConnectionId, addTab } = useAppStore();
const { activeConnectionId, currentDatabase, addTab } = useAppStore();
const handleNewQuery = useCallback(() => {
if (!activeConnectionId) return;
@@ -22,10 +22,11 @@ export default function App() {
type: "query",
title: "New Query",
connectionId: activeConnectionId,
database: currentDatabase ?? undefined,
sql: "",
};
addTab(tab);
}, [activeConnectionId, addTab]);
}, [activeConnectionId, currentDatabase, addTab]);
const handleCloseTab = useCallback(() => {
const { activeTabId, closeTab } = useAppStore.getState();

View File

@@ -0,0 +1,92 @@
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { AiSettingsPopover } from "./AiSettingsPopover";
import { useGenerateSql } from "@/hooks/use-ai";
import { Sparkles, Loader2, X } from "lucide-react";
import { toast } from "sonner";
interface Props {
connectionId: string;
onSqlGenerated: (sql: string) => void;
onClose: () => void;
onExecute?: () => void;
}
export function AiBar({ connectionId, onSqlGenerated, onClose, onExecute }: Props) {
const [prompt, setPrompt] = useState("");
const generateMutation = useGenerateSql();
const handleGenerate = () => {
if (!prompt.trim() || generateMutation.isPending) return;
generateMutation.mutate(
{ connectionId, prompt },
{
onSuccess: (sql) => {
onSqlGenerated(sql);
setPrompt("");
},
onError: (err) => {
toast.error("AI generation failed", { description: String(err) });
},
}
);
};
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === "Enter" && (e.ctrlKey || e.metaKey)) {
e.preventDefault();
e.stopPropagation();
onExecute?.();
return;
}
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
e.stopPropagation();
handleGenerate();
return;
}
if (e.key === "Escape") {
e.stopPropagation();
onClose();
}
};
return (
<div className="flex items-center gap-2 border-b bg-muted/50 px-2 py-1">
<Sparkles className="h-3.5 w-3.5 shrink-0 text-purple-500" />
<Input
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
onKeyDown={handleKeyDown}
placeholder="Describe the query you want..."
className="h-7 min-w-0 flex-1 text-xs"
autoFocus
disabled={generateMutation.isPending}
/>
<Button
size="sm"
variant="ghost"
className="h-6 gap-1 text-xs"
onClick={handleGenerate}
disabled={generateMutation.isPending || !prompt.trim()}
>
{generateMutation.isPending ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
"Generate"
)}
</Button>
<AiSettingsPopover />
<Button
size="sm"
variant="ghost"
className="h-6 w-6 p-0"
onClick={onClose}
title="Close AI bar"
>
<X className="h-3 w-3" />
</Button>
</div>
);
}

View File

@@ -0,0 +1,121 @@
import { useState } from "react";
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useAiSettings, useSaveAiSettings, useOllamaModels } from "@/hooks/use-ai";
import { Settings, RefreshCw, Loader2 } from "lucide-react";
import { toast } from "sonner";
export function AiSettingsPopover() {
const { data: settings } = useAiSettings();
const saveMutation = useSaveAiSettings();
const [url, setUrl] = useState<string | null>(null);
const [model, setModel] = useState<string | null>(null);
const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434";
const currentModel = model ?? settings?.model ?? "";
const {
data: models,
isLoading: modelsLoading,
isError: modelsError,
refetch: refetchModels,
} = useOllamaModels(currentUrl);
const handleSave = () => {
saveMutation.mutate(
{ ollama_url: currentUrl, model: currentModel },
{
onSuccess: () => toast.success("AI settings saved"),
onError: (err) =>
toast.error("Failed to save AI settings", {
description: String(err),
}),
}
);
};
return (
<Popover>
<PopoverTrigger asChild>
<Button
size="sm"
variant="ghost"
className="h-6 w-6 p-0"
title="AI Settings"
>
<Settings className="h-3 w-3" />
</Button>
</PopoverTrigger>
<PopoverContent className="w-80" align="end">
<div className="flex flex-col gap-3">
<h4 className="text-sm font-medium">Ollama Settings</h4>
<div className="flex flex-col gap-1.5">
<label className="text-xs text-muted-foreground">Ollama URL</label>
<Input
value={currentUrl}
onChange={(e) => setUrl(e.target.value)}
placeholder="http://localhost:11434"
className="h-8 text-xs"
/>
</div>
<div className="flex flex-col gap-1.5">
<div className="flex items-center justify-between">
<label className="text-xs text-muted-foreground">Model</label>
<Button
size="sm"
variant="ghost"
className="h-5 w-5 p-0"
onClick={() => refetchModels()}
disabled={modelsLoading}
title="Refresh models"
>
{modelsLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<RefreshCw className="h-3 w-3" />
)}
</Button>
</div>
{modelsError ? (
<p className="text-xs text-destructive">
Cannot connect to Ollama
</p>
) : (
<Select value={currentModel} onValueChange={setModel}>
<SelectTrigger className="h-8 w-full text-xs">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{models?.map((m) => (
<SelectItem key={m.name} value={m.name}>
{m.name}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
<Button size="sm" className="h-7 text-xs" onClick={handleSave}>
Save
</Button>
</div>
</PopoverContent>
</Popover>
);
}

View File

@@ -25,6 +25,7 @@ import {
import type { ConnectionConfig } from "@/types";
import { EnvironmentBadge } from "@/components/connections/EnvironmentBadge";
import { ENVIRONMENTS } from "@/lib/environment";
import { useState } from "react";
interface Props {
open: boolean;
@@ -39,8 +40,10 @@ export function ConnectionList({ open, onOpenChange, onEdit, onNew }: Props) {
const connectMutation = useConnect();
const disconnectMutation = useDisconnect();
const { connectedIds, activeConnectionId } = useAppStore();
const [connectingId, setConnectingId] = useState<string | null>(null);
const handleConnect = (conn: ConnectionConfig) => {
setConnectingId(conn.id);
connectMutation.mutate(conn, {
onSuccess: () => {
toast.success(`Connected to ${conn.name}`);
@@ -49,6 +52,9 @@ export function ConnectionList({ open, onOpenChange, onEdit, onNew }: Props) {
onError: (err) => {
toast.error("Connection failed", { description: String(err) });
},
onSettled: () => {
setConnectingId(null);
},
});
};
@@ -169,9 +175,9 @@ export function ConnectionList({ open, onOpenChange, onEdit, onNew }: Props) {
variant="ghost"
className="h-7 w-7"
onClick={() => handleConnect(conn)}
disabled={connectMutation.isPending}
disabled={connectingId !== null}
>
{connectMutation.isPending ? (
{connectingId === conn.id ? (
<Loader2 className="h-3.5 w-3.5 animate-spin" />
) : (
<Plug className="h-3.5 w-3.5" />

View File

@@ -12,13 +12,14 @@ export function HistoryPanel() {
const { data: entries } = useHistory(undefined, search || undefined);
const clearMutation = useClearHistory();
const handleClick = (sql: string, connectionId: string) => {
const handleClick = (sql: string, connectionId: string, database?: string) => {
const cid = activeConnectionId ?? connectionId;
const tab: Tab = {
id: crypto.randomUUID(),
type: "query",
title: "History Query",
connectionId: cid,
database,
sql,
};
addTab(tab);
@@ -52,7 +53,7 @@ export function HistoryPanel() {
<button
key={entry.id}
className="flex w-full flex-col gap-0.5 border-b px-3 py-2 text-left text-xs hover:bg-accent"
onClick={() => handleClick(entry.sql, entry.connection_id)}
onClick={() => handleClick(entry.sql, entry.connection_id, entry.database || undefined)}
>
<div className="flex items-center gap-1.5">
{entry.status === "success" ? (

View File

@@ -3,6 +3,16 @@ import { useConnections } from "@/hooks/use-connections";
import { Circle } from "lucide-react";
import { EnvironmentBadge } from "@/components/connections/EnvironmentBadge";
function formatDbVersion(version: string): string {
const gpMatch = version.match(/Greenplum Database ([\d.]+)/i);
if (gpMatch) {
const pgMatch = version.match(/^PostgreSQL ([\d.]+)/);
const pgVer = pgMatch ? ` (PG ${pgMatch[1]})` : "";
return `GP ${gpMatch[1]}${pgVer}`;
}
return version.split(",")[0]?.replace("PostgreSQL ", "PG ") ?? version;
}
interface Props {
rowCount?: number | null;
executionTime?: number | null;
@@ -46,7 +56,7 @@ export function StatusBar({ rowCount, executionTime }: Props) {
</span>
)}
{pgVersion && (
<span className="hidden sm:inline">{pgVersion.split(",")[0]?.replace("PostgreSQL ", "PG ")}</span>
<span className="hidden sm:inline">{formatDbVersion(pgVersion)}</span>
)}
</div>
<div className="flex items-center gap-3">

View File

@@ -54,6 +54,7 @@ export function AdminPanel() {
type: "roles",
title: "Roles & Users",
connectionId: activeConnectionId,
database: currentDatabase ?? undefined,
};
addTab(tab);
}}
@@ -66,6 +67,7 @@ export function AdminPanel() {
type: "sessions",
title: "Active Sessions",
connectionId: activeConnectionId,
database: currentDatabase ?? undefined,
};
addTab(tab);
}}

View File

@@ -8,7 +8,7 @@ import type { Tab } from "@/types";
export function SavedQueriesPanel() {
const [search, setSearch] = useState("");
const { activeConnectionId, addTab } = useAppStore();
const { activeConnectionId, currentDatabase, addTab } = useAppStore();
const { data: queries } = useSavedQueries(search || undefined);
const deleteMutation = useDeleteSavedQuery();
@@ -20,6 +20,7 @@ export function SavedQueriesPanel() {
type: "query",
title: "Saved Query",
connectionId: cid,
database: currentDatabase ?? undefined,
sql,
};
addTab(tab);

View File

@@ -118,6 +118,7 @@ export function SchemaTree() {
type: "table",
title: table,
connectionId: activeConnectionId,
database: currentDatabase ?? undefined,
schema,
table,
};
@@ -129,6 +130,7 @@ export function SchemaTree() {
type: "structure",
title: `${table} (structure)`,
connectionId: activeConnectionId,
database: currentDatabase ?? undefined,
schema,
table,
};

View File

@@ -0,0 +1,87 @@
import * as React from "react"
import { Popover as PopoverPrimitive } from "radix-ui"
import { cn } from "@/lib/utils"
function Popover({
...props
}: React.ComponentProps<typeof PopoverPrimitive.Root>) {
return <PopoverPrimitive.Root data-slot="popover" {...props} />
}
function PopoverTrigger({
...props
}: React.ComponentProps<typeof PopoverPrimitive.Trigger>) {
return <PopoverPrimitive.Trigger data-slot="popover-trigger" {...props} />
}
function PopoverContent({
className,
align = "center",
sideOffset = 4,
...props
}: React.ComponentProps<typeof PopoverPrimitive.Content>) {
return (
<PopoverPrimitive.Portal>
<PopoverPrimitive.Content
data-slot="popover-content"
align={align}
sideOffset={sideOffset}
className={cn(
"bg-popover text-popover-foreground data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 z-50 w-72 origin-(--radix-popover-content-transform-origin) rounded-md border p-4 shadow-md outline-hidden",
className
)}
{...props}
/>
</PopoverPrimitive.Portal>
)
}
function PopoverAnchor({
...props
}: React.ComponentProps<typeof PopoverPrimitive.Anchor>) {
return <PopoverPrimitive.Anchor data-slot="popover-anchor" {...props} />
}
function PopoverHeader({ className, ...props }: React.ComponentProps<"div">) {
return (
<div
data-slot="popover-header"
className={cn("flex flex-col gap-1 text-sm", className)}
{...props}
/>
)
}
function PopoverTitle({ className, ...props }: React.ComponentProps<"h2">) {
return (
<div
data-slot="popover-title"
className={cn("font-medium", className)}
{...props}
/>
)
}
function PopoverDescription({
className,
...props
}: React.ComponentProps<"p">) {
return (
<p
data-slot="popover-description"
className={cn("text-muted-foreground", className)}
{...props}
/>
)
}
export {
Popover,
PopoverTrigger,
PopoverContent,
PopoverAnchor,
PopoverHeader,
PopoverTitle,
PopoverDescription,
}

View File

@@ -13,7 +13,7 @@ import { useCompletionSchema } from "@/hooks/use-completion-schema";
import { useConnections } from "@/hooks/use-connections";
import { useAppStore } from "@/stores/app-store";
import { Button } from "@/components/ui/button";
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces } from "lucide-react";
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles } from "lucide-react";
import { format as formatSql } from "sql-formatter";
import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog";
import {
@@ -25,6 +25,7 @@ import {
import { exportCsv, exportJson } from "@/lib/tauri";
import { save } from "@tauri-apps/plugin-dialog";
import { toast } from "sonner";
import { AiBar } from "@/components/ai/AiBar";
import type { QueryResult, ExplainResult } from "@/types";
interface Props {
@@ -51,6 +52,7 @@ export function WorkspacePanel({
const [resultView, setResultView] = useState<"results" | "explain">("results");
const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table");
const [saveDialogOpen, setSaveDialogOpen] = useState(false);
const [aiBarOpen, setAiBarOpen] = useState(false);
const queryMutation = useQueryExecution();
const addHistoryMutation = useAddHistory();
@@ -245,6 +247,16 @@ export function WorkspacePanel({
<Bookmark className="h-3 w-3" />
Save
</Button>
<Button
size="sm"
variant={aiBarOpen ? "secondary" : "ghost"}
className="h-6 gap-1 text-xs"
onClick={() => setAiBarOpen(!aiBarOpen)}
title="AI SQL Generator"
>
<Sparkles className="h-3 w-3" />
AI
</Button>
{result && result.columns.length > 0 && (
<DropdownMenu>
<DropdownMenuTrigger asChild>
@@ -277,7 +289,18 @@ export function WorkspacePanel({
</span>
)}
</div>
<div className="flex-1 overflow-hidden">
{aiBarOpen && (
<AiBar
connectionId={connectionId}
onSqlGenerated={(sql) => {
setSqlValue(sql);
onSqlChange?.(sql);
}}
onClose={() => setAiBarOpen(false)}
onExecute={handleExecute}
/>
)}
<div className="min-h-0 flex-1">
<SqlEditor
value={sqlValue}
onChange={handleChange}
@@ -290,70 +313,74 @@ export function WorkspacePanel({
</ResizablePanel>
<ResizableHandle withHandle />
<ResizablePanel id="results" defaultSize="60%" minSize="15%">
{(explainData || result || error) && (
<div className="flex items-center border-b text-xs">
<button
className={`px-3 py-1 font-medium ${
resultView === "results"
? "bg-background text-foreground"
: "text-muted-foreground hover:text-foreground"
}`}
onClick={() => setResultView("results")}
>
Results
</button>
{explainData && (
<div className="flex h-full flex-col overflow-hidden">
{(explainData || result || error) && (
<div className="flex shrink-0 items-center border-b text-xs">
<button
className={`px-3 py-1 font-medium ${
resultView === "explain"
resultView === "results"
? "bg-background text-foreground"
: "text-muted-foreground hover:text-foreground"
}`}
onClick={() => setResultView("explain")}
onClick={() => setResultView("results")}
>
Explain
Results
</button>
)}
{resultView === "results" && result && result.columns.length > 0 && (
<div className="ml-auto mr-2 flex items-center rounded-md border">
{explainData && (
<button
className={`flex items-center gap-1 px-2 py-0.5 font-medium ${
resultViewMode === "table"
? "bg-muted text-foreground"
className={`px-3 py-1 font-medium ${
resultView === "explain"
? "bg-background text-foreground"
: "text-muted-foreground hover:text-foreground"
}`}
onClick={() => setResultViewMode("table")}
title="Table view"
onClick={() => setResultView("explain")}
>
<Table2 className="h-3 w-3" />
Table
Explain
</button>
<button
className={`flex items-center gap-1 px-2 py-0.5 font-medium ${
resultViewMode === "json"
? "bg-muted text-foreground"
: "text-muted-foreground hover:text-foreground"
}`}
onClick={() => setResultViewMode("json")}
title="JSON view"
>
<Braces className="h-3 w-3" />
JSON
</button>
</div>
)}
{resultView === "results" && result && result.columns.length > 0 && (
<div className="ml-auto mr-2 flex items-center rounded-md border">
<button
className={`flex items-center gap-1 px-2 py-0.5 font-medium ${
resultViewMode === "table"
? "bg-muted text-foreground"
: "text-muted-foreground hover:text-foreground"
}`}
onClick={() => setResultViewMode("table")}
title="Table view"
>
<Table2 className="h-3 w-3" />
Table
</button>
<button
className={`flex items-center gap-1 px-2 py-0.5 font-medium ${
resultViewMode === "json"
? "bg-muted text-foreground"
: "text-muted-foreground hover:text-foreground"
}`}
onClick={() => setResultViewMode("json")}
title="JSON view"
>
<Braces className="h-3 w-3" />
JSON
</button>
</div>
)}
</div>
)}
<div className="min-h-0 flex-1">
{resultView === "explain" && explainData ? (
<ExplainView data={explainData} />
) : (
<ResultsPanel
result={result}
error={error}
isLoading={queryMutation.isPending && resultView === "results"}
viewMode={resultViewMode}
/>
)}
</div>
)}
{resultView === "explain" && explainData ? (
<ExplainView data={explainData} />
) : (
<ResultsPanel
result={result}
error={error}
isLoading={queryMutation.isPending && resultView === "results"}
viewMode={resultViewMode}
/>
)}
</div>
</ResizablePanel>
</ResizablePanelGroup>

47
src/hooks/use-ai.ts Normal file
View File

@@ -0,0 +1,47 @@
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import {
getAiSettings,
saveAiSettings,
listOllamaModels,
generateSql,
} from "@/lib/tauri";
import type { AiSettings } from "@/types";
export function useAiSettings() {
return useQuery({
queryKey: ["ai-settings"],
queryFn: getAiSettings,
staleTime: Infinity,
});
}
export function useSaveAiSettings() {
const queryClient = useQueryClient();
return useMutation({
mutationFn: (settings: AiSettings) => saveAiSettings(settings),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ["ai-settings"] });
},
});
}
export function useOllamaModels(ollamaUrl: string | undefined) {
return useQuery({
queryKey: ["ollama-models", ollamaUrl],
queryFn: () => listOllamaModels(ollamaUrl!),
enabled: !!ollamaUrl,
retry: false,
});
}
export function useGenerateSql() {
return useMutation({
mutationFn: ({
connectionId,
prompt,
}: {
connectionId: string;
prompt: string;
}) => generateSql(connectionId, prompt),
});
}

View File

@@ -51,19 +51,19 @@ export function useTestConnection() {
}
export function useConnect() {
const { addConnectedId, setActiveConnectionId, setPgVersion, setCurrentDatabase } =
const { addConnectedId, setActiveConnectionId, setPgVersion, setDbFlavor, setCurrentDatabase } =
useAppStore();
return useMutation({
mutationFn: async (config: ConnectionConfig) => {
await connectDb(config);
const version = await testConnection(config);
return { id: config.id, version, database: config.database };
const result = await connectDb(config);
return { id: config.id, ...result, database: config.database };
},
onSuccess: ({ id, version, database }) => {
onSuccess: ({ id, version, flavor, database }) => {
addConnectedId(id);
setActiveConnectionId(id);
setPgVersion(version);
setDbFlavor(id, flavor);
setCurrentDatabase(database);
},
});
@@ -91,17 +91,17 @@ export function useDisconnect() {
export function useReconnect() {
const queryClient = useQueryClient();
const { setPgVersion, setCurrentDatabase } = useAppStore();
const { setPgVersion, setDbFlavor, setCurrentDatabase } = useAppStore();
return useMutation({
mutationFn: async (config: ConnectionConfig) => {
await disconnectDb(config.id);
await connectDb(config);
const version = await testConnection(config);
return { version, database: config.database };
const result = await connectDb(config);
return { id: config.id, ...result, database: config.database };
},
onSuccess: ({ version, database }) => {
onSuccess: ({ id, version, flavor, database }) => {
setPgVersion(version);
setDbFlavor(id, flavor);
setCurrentDatabase(database);
queryClient.invalidateQueries();
},

View File

@@ -2,6 +2,8 @@ import { invoke } from "@tauri-apps/api/core";
import { listen, type UnlistenFn } from "@tauri-apps/api/event";
import type {
ConnectionConfig,
ConnectResult,
DbFlavor,
QueryResult,
PaginatedQueryResult,
SchemaObject,
@@ -40,7 +42,7 @@ export const testConnection = (config: ConnectionConfig) =>
invoke<string>("test_connection", { config });
export const connectDb = (config: ConnectionConfig) =>
invoke<void>("connect", { config });
invoke<ConnectResult>("connect", { config });
export const disconnectDb = (id: string) =>
invoke<void>("disconnect", { id });
@@ -55,6 +57,9 @@ export const setReadOnly = (connectionId: string, readOnly: boolean) =>
export const getReadOnly = (connectionId: string) =>
invoke<boolean>("get_read_only", { connectionId });
export const getDbFlavor = (connectionId: string) =>
invoke<DbFlavor>("get_db_flavor", { connectionId });
// Queries
export const executeQuery = (connectionId: string, sql: string) =>
invoke<QueryResult>("execute_query", { connectionId, sql });

View File

@@ -1,5 +1,5 @@
import { create } from "zustand";
import type { ConnectionConfig, Tab } from "@/types";
import type { ConnectionConfig, DbFlavor, Tab } from "@/types";
interface AppState {
connections: ConnectionConfig[];
@@ -7,6 +7,7 @@ interface AppState {
currentDatabase: string | null;
connectedIds: Set<string>;
readOnlyMap: Record<string, boolean>;
dbFlavors: Record<string, DbFlavor>;
tabs: Tab[];
activeTabId: string | null;
sidebarWidth: number;
@@ -18,6 +19,7 @@ interface AppState {
addConnectedId: (id: string) => void;
removeConnectedId: (id: string) => void;
setReadOnly: (connectionId: string, readOnly: boolean) => void;
setDbFlavor: (connectionId: string, flavor: DbFlavor) => void;
setPgVersion: (version: string | null) => void;
addTab: (tab: Tab) => void;
@@ -33,6 +35,7 @@ export const useAppStore = create<AppState>((set) => ({
currentDatabase: null,
connectedIds: new Set(),
readOnlyMap: {},
dbFlavors: {},
tabs: [],
activeTabId: null,
sidebarWidth: 260,
@@ -50,13 +53,22 @@ export const useAppStore = create<AppState>((set) => ({
set((state) => {
const next = new Set(state.connectedIds);
next.delete(id);
const { [id]: _, ...restRo } = state.readOnlyMap;
return { connectedIds: next, readOnlyMap: restRo };
const restRo = Object.fromEntries(
Object.entries(state.readOnlyMap).filter(([k]) => k !== id)
);
const restFlavors = Object.fromEntries(
Object.entries(state.dbFlavors).filter(([k]) => k !== id)
);
return { connectedIds: next, readOnlyMap: restRo, dbFlavors: restFlavors };
}),
setReadOnly: (connectionId, readOnly) =>
set((state) => ({
readOnlyMap: { ...state.readOnlyMap, [connectionId]: readOnly },
})),
setDbFlavor: (connectionId, flavor) =>
set((state) => ({
dbFlavors: { ...state.dbFlavors, [connectionId]: flavor },
})),
setPgVersion: (version) => set({ pgVersion: version }),
addTab: (tab) =>

View File

@@ -1,3 +1,10 @@
export type DbFlavor = "postgresql" | "greenplum";
export interface ConnectResult {
version: string;
flavor: DbFlavor;
}
export interface ConnectionConfig {
id: string;
name: string;