From 0cba457fb7e2eb4ebce233e2719f51650346ca8c Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Sat, 23 May 2026 15:01:52 +0300 Subject: [PATCH] refactor(ai): consolidate AI around chat tool-calling; add OpenRouter - rework chat backend (chat.rs, chat_tools.rs, ai.rs, models, state) around tool calls - add OpenRouter provider alongside Ollama/Fireworks in settings - drop inline AiBar, ResultsPanel explain/fix UI and ChartPreview in favour of the chat panel - add frontend chat tool-registry --- src-tauri/src/commands/ai.rs | 848 +++---------------- src-tauri/src/commands/chat.rs | 425 ++++------ src-tauri/src/commands/chat_tools.rs | 690 +++++++++++++++ src-tauri/src/lib.rs | 4 +- src-tauri/src/models/ai.rs | 34 +- src-tauri/src/models/chat.rs | 15 - src-tauri/src/state.rs | 61 +- src/components/ai/AiBar.tsx | 103 --- src/components/ai/AiSettingsFields.tsx | 70 +- src/components/ai/AiSettingsPopover.tsx | 21 +- src/components/chat/ChartPreview.tsx | 327 ------- src/components/chat/ChatMessageView.tsx | 206 +---- src/components/chat/tool-registry.ts | 107 +++ src/components/results/ResultsPanel.tsx | 63 +- src/components/settings/AppSettingsSheet.tsx | 18 +- src/components/workspace/WorkspacePanel.tsx | 111 +-- src/hooks/use-ai.ts | 46 +- src/lib/tauri.ts | 10 +- src/types/index.ts | 16 +- 19 files changed, 1244 insertions(+), 1931 deletions(-) delete mode 100644 src/components/ai/AiBar.tsx delete mode 100644 src/components/chat/ChartPreview.tsx create mode 100644 src/components/chat/tool-registry.ts diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs index 8edb499..31cebaa 100644 --- a/src-tauri/src/commands/ai.rs +++ b/src-tauri/src/commands/ai.rs @@ -15,6 +15,13 @@ use tauri::{AppHandle, Manager, State}; const MAX_RETRIES: u32 = 2; const RETRY_DELAY_MS: u64 = 1000; const FIREWORKS_BASE_URL: &str = "https://api.fireworks.ai/inference/v1"; +const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1"; +/// Optional attribution headers OpenRouter uses for its public app leaderboard. +/// Harmless to send on every request; ignored by other OpenAI-compatible APIs. +const OPENROUTER_HEADERS: &[(&str, &str)] = &[ + ("HTTP-Referer", "https://github.com/codelab/tusk"), + ("X-Title", "Tusk"), +]; fn http_client() -> &'static reqwest::Client { use std::sync::LazyLock; @@ -89,32 +96,51 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult TuskResult> { + list_openai_compatible_models(FIREWORKS_BASE_URL, &api_key, "Fireworks", &[]).await +} + +#[tauri::command] +pub async fn list_openrouter_models(api_key: String) -> TuskResult> { + list_openai_compatible_models(OPENROUTER_BASE_URL, &api_key, "OpenRouter", OPENROUTER_HEADERS) + .await +} + +/// List the available chat models for any OpenAI-compatible provider via its +/// `GET {base_url}/models` endpoint. `extra_headers` carries provider-specific +/// attribution headers (OpenRouter recommends `HTTP-Referer`/`X-Title`). +async fn list_openai_compatible_models( + base_url: &str, + api_key: &str, + provider_label: &str, + extra_headers: &[(&str, &str)], +) -> TuskResult> { let key = api_key.trim(); if key.is_empty() { - return Err(TuskError::Ai("Fireworks API key required".to_string())); + return Err(TuskError::Ai(format!("{} API key required", provider_label))); } - let url = format!("{}/models", FIREWORKS_BASE_URL); - let resp = http_client() - .get(&url) - .bearer_auth(key) + let url = format!("{}/models", base_url); + let mut req = http_client().get(&url).bearer_auth(key); + for (name, value) in extra_headers { + req = req.header(*name, *value); + } + let resp = req .send() .await - .map_err(|e| TuskError::Ai(format!("Cannot reach Fireworks: {}", e)))?; + .map_err(|e| TuskError::Ai(format!("Cannot reach {}: {}", provider_label, e)))?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); return Err(TuskError::Ai(format!( - "Fireworks error ({}): {}", - status, body + "{} error ({}): {}", + provider_label, status, body ))); } - let parsed: FireworksModelsResponse = resp - .json() - .await - .map_err(|e| TuskError::Ai(format!("Failed to parse Fireworks models list: {}", e)))?; + let parsed: FireworksModelsResponse = resp.json().await.map_err(|e| { + TuskError::Ai(format!("Failed to parse {} models list: {}", provider_label, e)) + })?; Ok(parsed .data @@ -180,33 +206,8 @@ pub(crate) async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskR Ok(settings) } -async fn call_chat_simple( - app: &AppHandle, - state: &AppState, - system_prompt: String, - user_content: String, -) -> TuskResult { - call_chat_messages( - app, - state, - vec![ - OllamaChatMessage { - role: "system".to_string(), - content: system_prompt, - }, - OllamaChatMessage { - role: "user".to_string(), - content: user_content, - }, - ], - None, - ) - .await -} - -/// Provider-agnostic chat-completions dispatcher used by every LLM-driven feature -/// (chat agent, generate_sql, explain_sql, fix_sql_error). Returns the model's -/// raw text content. +/// Provider-agnostic chat-completions dispatcher used by the chat agent. +/// Returns the model's raw text content. pub(crate) async fn call_chat_messages( app: &AppHandle, state: &AppState, @@ -223,14 +224,49 @@ pub(crate) async fn call_chat_messages( match settings.provider { AiProvider::Ollama => call_ollama(&settings, messages, format).await, - AiProvider::Fireworks => call_fireworks(&settings, messages, format).await, - AiProvider::OpenAi | AiProvider::Anthropic => Err(TuskError::Ai(format!( - "Provider {:?} not implemented yet", - settings.provider - ))), + AiProvider::Fireworks => { + let api_key = require_api_key( + settings.fireworks_api_key.as_deref(), + "Fireworks API key not set. Open AI settings to add it.", + )?; + call_openai_compatible( + &settings, + FIREWORKS_BASE_URL, + &api_key, + "Fireworks", + &[], + messages, + format, + ) + .await + } + AiProvider::OpenRouter => { + let api_key = require_api_key( + settings.openrouter_api_key.as_deref(), + "OpenRouter API key not set. Open AI settings to add it.", + )?; + call_openai_compatible( + &settings, + OPENROUTER_BASE_URL, + &api_key, + "OpenRouter", + OPENROUTER_HEADERS, + messages, + format, + ) + .await + } } } +/// Trim and validate an optional API key, returning a user-facing error when +/// it's missing or blank. +fn require_api_key(key: Option<&str>, missing_msg: &str) -> TuskResult { + key.map(|k| k.trim().to_string()) + .filter(|k| !k.is_empty()) + .ok_or_else(|| TuskError::Ai(missing_msg.to_string())) +} + async fn call_ollama( settings: &AiSettings, messages: Vec, @@ -277,21 +313,18 @@ async fn call_ollama( .await } -async fn call_fireworks( +/// Chat-completions call for any OpenAI-compatible provider (Fireworks, +/// OpenRouter). `extra_headers` carries provider-specific attribution headers. +async fn call_openai_compatible( settings: &AiSettings, + base_url: &str, + api_key: &str, + provider_label: &str, + extra_headers: &[(&str, &str)], messages: Vec, format: Option, ) -> TuskResult { - let api_key = settings - .fireworks_api_key - .clone() - .map(|k| k.trim().to_string()) - .filter(|k| !k.is_empty()) - .ok_or_else(|| { - TuskError::Ai("Fireworks API key not set. Open AI settings to add it.".to_string()) - })?; - - let url = format!("{}/chat/completions", FIREWORKS_BASE_URL); + let url = format!("{}/chat/completions", base_url); let response_format = format.as_deref().map(|f| FireworksResponseFormat { kind: if f == "json" { "json_object".to_string() @@ -307,19 +340,22 @@ async fn call_fireworks( response_format, }; - call_ai_with_retry(settings, "Fireworks request", || { + let operation = format!("{} request", provider_label); + call_ai_with_retry(settings, &operation, || { let url = url.clone(); let request = request.clone(); - let api_key = api_key.clone(); + let api_key = api_key.to_string(); async move { - let resp = http_client() - .post(&url) - .bearer_auth(&api_key) + let mut req = http_client().post(&url).bearer_auth(&api_key); + for (name, value) in extra_headers { + req = req.header(*name, *value); + } + let resp = req .json(&request) .send() .await .map_err(|e| { - TuskError::Ai(format!("Cannot reach Fireworks at {}: {}", url, e)) + TuskError::Ai(format!("Cannot reach {} at {}: {}", provider_label, url, e)) })?; if !resp.status().is_success() { @@ -339,13 +375,13 @@ async fn call_fireworks( )); } return Err(TuskError::Ai(format!( - "Fireworks error ({}): {}", - status, body + "{} error ({}): {}", + provider_label, status, body ))); } let parsed: FireworksChatResponse = resp.json().await.map_err(|e| { - TuskError::Ai(format!("Failed to parse Fireworks response: {}", e)) + TuskError::Ai(format!("Failed to parse {} response: {}", provider_label, e)) })?; parsed @@ -353,177 +389,14 @@ async fn call_fireworks( .into_iter() .next() .map(|c| c.message.content) - .ok_or_else(|| TuskError::Ai("Fireworks returned no choices".to_string())) + .ok_or_else(|| { + TuskError::Ai(format!("{} returned no choices", provider_label)) + }) } }) .await } -// --------------------------------------------------------------------------- -// SQL generation -// --------------------------------------------------------------------------- - -#[tauri::command] -pub async fn generate_sql( - app: AppHandle, - state: State<'_, Arc>, - connection_id: String, - prompt: String, -) -> TuskResult { - let schema_text = build_schema_context(&state, &connection_id).await?; - - let system_prompt = format!( - "You are an expert PostgreSQL query generator. You receive a database schema and a natural \ - language request. Output ONLY a valid, executable PostgreSQL SQL query.\n\ - \n\ - OUTPUT FORMAT:\n\ - - Raw SQL only. No explanations, no markdown code fences (```), no comments, no preamble.\n\ - - The output must be directly executable in psql.\n\ - - For complex queries use readable formatting with line breaks and indentation.\n\ - \n\ - CRITICAL RULES:\n\ - 1. ONLY reference tables and columns that exist in the schema. Never invent names.\n\ - 2. Use the FOREIGN KEY information to determine correct JOIN conditions.\n\ - 3. Use LEFT JOIN when the FK column is nullable or the relationship is optional; \ - INNER JOIN when both sides must exist.\n\ - 4. Every non-aggregated column in SELECT must appear in GROUP BY.\n\ - 5. Use COALESCE for nullable columns in aggregations: COALESCE(SUM(x), 0).\n\ - 6. For enum columns, use ONLY the values listed in the ENUM TYPES section.\n\ - 7. Use IS NULL / IS NOT NULL for null checks — never = NULL or != NULL.\n\ - 8. Add LIMIT when the user asks for \"top N\", \"first N\", \"latest N\", etc.\n\ - 9. Qualify column names with table alias when the query involves multiple tables.\n\ - \n\ - SEMANTIC RULES (very important):\n\ - - When a table has both actual_* and planned_* columns (e.g. actual_start vs planned_start), \ - they represent DIFFERENT concepts: planned = future estimate, actual = what really happened. \ - NEVER mix them with COALESCE unless the user explicitly requests a fallback.\n\ - - For time-based calculations involving real events (\"how long did X take\", \"average time between\"), \ - use ONLY actual/factual timestamps (actual_*, started_at, completed_at, ended_at). \ - Filter out NULL values with WHERE instead of falling back to planned timestamps.\n\ - - Planned timestamps (planned_*, scheduled_*, estimated_*) should ONLY be used when the user \ - asks about plans, schedules, SLA, or compares plan vs fact.\n\ - - When computing durations or averages, always filter out rows where any involved timestamp \ - is NULL rather than substituting with unrelated defaults.\n\ - - Pay attention to column descriptions/comments in the schema — they reveal business semantics \ - that are critical for correct queries.\n\ - \n\ - TYPE RULES:\n\ - - timestamp - timestamp = interval. For seconds: EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ - - interval cannot be cast to numeric directly; use EXTRACT(EPOCH FROM interval).\n\ - - UNION/UNION ALL requires matching column count and compatible types; cast enums to text.\n\ - - Use ::type for PostgreSQL-style casts.\n\ - - For array columns use ANY, ALL, @>, <@ operators.\n\ - - For JSONB columns use ->, ->>, #>, jsonb_extract_path.\n\ - \n\ - COMMON PATTERNS:\n\ - - FIRST/LAST per group: to find MIN(started_at) per trip, use \ - \"SELECT trip_id, MIN(started_at) FROM t GROUP BY trip_id\". \ - NEVER put the aggregated column (started_at) into GROUP BY — that defeats the aggregation \ - and returns every row separately instead of one per group.\n\ - - TOP-1 per group with extra columns: use DISTINCT ON (group_col) ... ORDER BY group_col, sort_col \ - or a subquery with ROW_NUMBER() OVER (PARTITION BY group_col ORDER BY sort_col) = 1.\n\ - - For \"time from A to B\" calculations, ensure both timestamps are NOT NULL with WHERE filters; \ - never use COALESCE to mix planned and actual timestamps.\n\ - \n\ - BEST PRACTICES:\n\ - - Use ILIKE for case-insensitive text search, LIKE for case-sensitive.\n\ - - Use EXISTS instead of IN for subquery existence checks.\n\ - - Use CTE (WITH ... AS) for complex multi-step logic.\n\ - - Use window functions (ROW_NUMBER, RANK, LAG, LEAD, SUM OVER) for ranking and running totals.\n\ - - Use date_trunc('period', column) for time-based grouping.\n\ - - Use generate_series() for creating ranges.\n\ - - Use string_agg(col, ', ') for concatenating grouped values.\n\ - - Use FILTER (WHERE ...) for conditional aggregation instead of CASE inside aggregate.\n\ - \n\ - {}\n", - schema_text - ); - - let raw = call_chat_simple(&app, &state, system_prompt, prompt).await?; - Ok(clean_sql_response(&raw)) -} - -// --------------------------------------------------------------------------- -// SQL explanation -// --------------------------------------------------------------------------- - -#[tauri::command] -pub async fn explain_sql( - app: AppHandle, - state: State<'_, Arc>, - connection_id: String, - sql: String, -) -> TuskResult { - let schema_text = build_schema_context(&state, &connection_id).await?; - - let system_prompt = format!( - "You are a PostgreSQL expert. Explain the given SQL query clearly and concisely.\n\ - \n\ - Structure your explanation as:\n\ - 1. **Summary** — one sentence describing what the query returns in business terms.\n\ - 2. **Step-by-step breakdown** — explain tables accessed, joins, filters, aggregations, \ - subqueries, and sorting. Use bullet points.\n\ - 3. **Notes** — mention edge cases, potential issues, or performance concerns if any.\n\ - \n\ - Use the database schema below to understand table relationships and column meanings.\n\ - Keep the explanation short; avoid restating the SQL verbatim.\n\ - \n\ - IMPORTANT: If you notice semantic issues (e.g. mixing planned_* and actual_* timestamps \ - with COALESCE, comparing unrelated columns, missing NULL filters on nullable timestamps), \ - mention them in the Notes section as potential problems.\n\ - \n\ - {}\n", - schema_text - ); - - call_chat_simple(&app, &state, system_prompt, sql).await -} - -// --------------------------------------------------------------------------- -// SQL error fixing -// --------------------------------------------------------------------------- - -#[tauri::command] -pub async fn fix_sql_error( - app: AppHandle, - state: State<'_, Arc>, - connection_id: String, - sql: String, - error_message: String, -) -> TuskResult { - let schema_text = build_schema_context(&state, &connection_id).await?; - - let system_prompt = format!( - "You are a PostgreSQL expert debugger. You receive a SQL query and the error it produced. \ - Fix the query so it executes correctly.\n\ - \n\ - OUTPUT FORMAT:\n\ - - Raw SQL only. No explanations, no markdown code fences (```), no comments.\n\ - - The output must be directly executable.\n\ - \n\ - DIAGNOSTIC CHECKLIST:\n\ - - Column/table does not exist → check the schema for correct names and spelling.\n\ - - Column is ambiguous → qualify with table name or alias.\n\ - - Must appear in GROUP BY → add missing non-aggregated columns to GROUP BY.\n\ - - Type mismatch → add appropriate casts (::text, ::integer, etc.).\n\ - - Permission denied → wrap in a read-only transaction if needed.\n\ - - Syntax error → correct PostgreSQL syntax (check commas, parentheses, keywords).\n\ - - Subquery returns more than one row → use IN, ANY, or add LIMIT 1.\n\ - - Division by zero → wrap divisor with NULLIF(x, 0).\n\ - \n\ - ONLY use tables and columns from the schema below. Never invent names.\n\ - Preserve the original intent of the query; change only what is necessary to fix the error.\n\ - \n\ - {}\n", - schema_text - ); - - let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message); - - let raw = call_chat_simple(&app, &state, system_prompt, user_content).await?; - Ok(clean_sql_response(&raw)) -} - // --------------------------------------------------------------------------- // Lite overview builder (chat v2) // --------------------------------------------------------------------------- @@ -735,231 +608,6 @@ async fn build_overview_clickhouse(state: &AppState, connection_id: &str) -> Tus Ok(out.join("\n")) } -// --------------------------------------------------------------------------- -// Full schema context builder (legacy — used by generate_sql/explain_sql/fix_sql_error) -// --------------------------------------------------------------------------- - -pub(crate) async fn build_schema_context( - state: &AppState, - connection_id: &str, -) -> TuskResult { - // Check cache first - if let Some(cached) = state.get_schema_cache(connection_id).await { - return Ok(cached); - } - - let flavor = state.get_flavor(connection_id).await; - if matches!(flavor, DbFlavor::ClickHouse) { - return build_clickhouse_schema_context(state, connection_id).await; - } - - let is_greenplum = matches!(flavor, DbFlavor::Greenplum); - let gp_major = state.get_gp_major(connection_id).await.unwrap_or(7); - - let pool = state.get_pool(connection_id).await?; - - // Run all metadata queries in parallel for speed - let ( - version_res, - current_db_res, - all_dbs_res, - col_res, - fk_res, - enum_res, - tbl_comment_res, - col_comment_res, - unique_res, - varchar_res, - jsonb_res, - ) = tokio::join!( - sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool), - sqlx::query_scalar::<_, String>("SELECT current_database()").fetch_one(&pool), - sqlx::query_scalar::<_, String>( - "SELECT datname FROM pg_database \ - WHERE datistemplate = false AND datallowconn = true \ - ORDER BY datname" - ) - .fetch_all(&pool), - fetch_columns(&pool), - fetch_foreign_keys_raw(&pool), - fetch_enum_types(&pool), - fetch_table_comments(&pool), - fetch_column_comments(&pool), - fetch_unique_constraints(&pool), - fetch_varchar_values(&pool), - fetch_jsonb_keys(&pool), - ); - - let version = version_res.map_err(TuskError::Database)?; - let current_db = current_db_res.unwrap_or_default(); - let all_dbs = all_dbs_res.unwrap_or_default(); - let col_rows = col_res?; - let fk_rows = fk_res?; - let enum_map = enum_res?; - let tbl_comments = tbl_comment_res?; - let col_comments = col_comment_res?; - let unique_constraints = unique_res?; - let varchar_values = varchar_res.unwrap_or_default(); - let jsonb_keys = jsonb_res.unwrap_or_default(); - let gp_extras = if is_greenplum { - Some(fetch_gp_table_extras(&pool, gp_major).await) - } else { - None - }; - - // -- Build FK inline lookup: (schema, table, column) -> "ref_schema.ref_table(ref_col)" -- - let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new(); - let mut fk_lines: Vec = Vec::new(); - for fk in &fk_rows { - let line = format!( - "FK: {}.{}({}) -> {}.{}({})", - fk.schema, - fk.table, - fk.columns.join(", "), - fk.ref_schema, - fk.ref_table, - fk.ref_columns.join(", ") - ); - fk_lines.push(line); - - // For single-column FKs, enable inline annotation on column - if fk.columns.len() == 1 && fk.ref_columns.len() == 1 { - fk_inline.insert( - (fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()), - format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]), - ); - } - } - - // -- Build unique constraint lookup: (schema, table) -> Vec -- - let mut unique_map: HashMap<(String, String), Vec> = HashMap::new(); - for (schema, table, cols) in &unique_constraints { - unique_map - .entry((schema.clone(), table.clone())) - .or_default() - .push(cols.join(", ")); - } - - // -- Format output -- - let mut output: Vec = Vec::new(); - - // 1. PostgreSQL version (short form) - let short_version = version - .split_whitespace() - .take(2) - .collect::>() - .join(" "); - output.push(format!("DATABASE SCHEMA ({})", short_version)); - if !current_db.is_empty() { - output.push(format!("ACTIVE DATABASE: {}", current_db)); - } - output.push(String::new()); - - // 2. Cluster topology — other databases on this server. - // Each PG database is isolated; cross-DB queries are not possible from a single connection. - if all_dbs.len() > 1 { - output.push("DATABASES ON THIS SERVER:".to_string()); - for db in &all_dbs { - if db == ¤t_db { - output.push(format!(" * {} (active)", db)); - } else { - output.push(format!(" {}", db)); - } - } - output.push(String::new()); - output.push( - "NOTE: Tables in other databases are NOT queryable from this session. \ - If the user's question concerns data likely stored in a different database \ - (e.g. an identity service in a separate DB), respond with a `final` message \ - asking them to switch the active database via the connection selector." - .to_string(), - ); - output.push(String::new()); - } - - // 3. Quick table+column index for fast existence checks before writing SQL. - // Each line lists `schema.table(col1, col2, ...)` so the model can grep both - // table names and column names without scrolling through the full TABLES section. - { - let mut by_table: BTreeMap<(String, String), Vec> = BTreeMap::new(); - for c in &col_rows { - by_table - .entry((c.schema.clone(), c.table.clone())) - .or_default() - .push(c.column.clone()); - } - if !by_table.is_empty() { - output.push(format!( - "TABLE INDEX (database `{}`, {} tables — table_name(column_list)):", - current_db, - by_table.len() - )); - for ((schema, table), cols) in &by_table { - output.push(format!(" {}.{}({})", schema, table, cols.join(", "))); - } - output.push(String::new()); - } - } - - // 4. Enum types - if !enum_map.is_empty() { - output.push("ENUM TYPES:".to_string()); - for (type_name, values) in &enum_map { - let vals_str = values - .iter() - .map(|v| format!("'{}'", v)) - .collect::>() - .join(", "); - output.push(format!(" {} = [{}]", type_name, vals_str)); - } - output.push(String::new()); - } - - // 5. Tables with columns - output.push("TABLES:".to_string()); - - // Group columns by schema.table preserving order - let mut tables: BTreeMap> = BTreeMap::new(); - for ci in &col_rows { - let key = format!("{}.{}", ci.schema, ci.table); - tables.entry(key).or_default().push(ci.clone()); - } - - for (full_name, columns) in &tables { - format_table_block( - full_name, - columns, - &tbl_comments, - &col_comments, - &fk_inline, - &enum_map, - &unique_map, - &varchar_values, - &jsonb_keys, - gp_extras.as_ref(), - &mut output, - ); - } - - // 6. Foreign keys summary - if !fk_lines.is_empty() { - output.push(String::new()); - output.push("FOREIGN KEYS:".to_string()); - for fk in &fk_lines { - output.push(format!(" {}", fk)); - } - } - - let result = output.join("\n"); - - // Cache the result - state - .set_schema_cache(connection_id.to_string(), result.clone()) - .await; - - Ok(result) -} - // --------------------------------------------------------------------------- // Schema query helpers // --------------------------------------------------------------------------- @@ -976,8 +624,7 @@ pub(crate) struct ColumnInfo { } /// Render a single table's column block in the human/LLM-readable schema format. -/// Reused by both `build_schema_context` (full DDL for legacy AI commands) and -/// the new `get_columns` chat tool. +/// Used by the chat agent's `get_columns` tool. #[allow(clippy::too_many_arguments)] pub(crate) fn format_table_block( full_name: &str, @@ -1102,106 +749,6 @@ pub(crate) fn format_table_block( } } -async fn build_clickhouse_schema_context( - state: &AppState, - connection_id: &str, -) -> TuskResult { - let client = state.get_ch_client(connection_id).await?; - let db = client.database.clone(); - - // ClickHouse exposes ALL databases via system.* — pull cross-DB schema in one shot. - let columns_sql = "SELECT database, table, name, type, is_in_primary_key \ - FROM system.columns \ - WHERE database NOT IN ('system','INFORMATION_SCHEMA','information_schema') \ - ORDER BY database, table, position"; - let dbs_sql = "SELECT name FROM system.databases \ - WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \ - ORDER BY name"; - - let rows = client.fetch_objects(columns_sql).await?; - let db_rows = client.fetch_objects(dbs_sql).await.unwrap_or_default(); - - let version = client.ping().await.unwrap_or_default(); - let mut out = String::new(); - out.push_str("DATABASE: ClickHouse\n"); - if !version.is_empty() { - out.push_str(&format!("VERSION: {}\n", version.trim())); - } - out.push_str(&format!("ACTIVE_DATABASE: {}\n\n", db)); - - // Cluster overview - if !db_rows.is_empty() { - out.push_str("DATABASES ON THIS SERVER:\n"); - for row in &db_rows { - let name = row.get("name").and_then(|v| v.as_str()).unwrap_or(""); - if name == db { - out.push_str(&format!(" * {} (active)\n", name)); - } else { - out.push_str(&format!(" {}\n", name)); - } - } - out.push('\n'); - } - - // Table+column index — same shape as the PG path so model has a uniform reference. - { - let mut by_table: BTreeMap<(String, String), Vec> = BTreeMap::new(); - for r in &rows { - let dbn = r.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let tbl = r.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let col = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); - if !dbn.is_empty() && !tbl.is_empty() && !col.is_empty() { - by_table.entry((dbn, tbl)).or_default().push(col); - } - } - if !by_table.is_empty() { - out.push_str(&format!( - "TABLE INDEX ({} tables across all databases — db.table(column_list)):\n", - by_table.len() - )); - for ((dbn, tbl), cols) in &by_table { - out.push_str(&format!(" {}.{}({})\n", dbn, tbl, cols.join(", "))); - } - out.push('\n'); - } - } - - out.push_str("TABLES:\n"); - let mut current_key: Option = None; - for row in &rows { - let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let table = row.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string(); - let column = row.get("name").and_then(|v| v.as_str()).unwrap_or(""); - let dtype = row.get("type").and_then(|v| v.as_str()).unwrap_or(""); - let is_pk = matches!(row.get("is_in_primary_key"), Some(serde_json::Value::Number(n)) if n.as_i64() == Some(1)) - || matches!(row.get("is_in_primary_key"), Some(serde_json::Value::String(s)) if s == "1"); - let key = format!("{}.{}", dbn, table); - if Some(&key) != current_key.as_ref() { - out.push_str(&format!("\nTABLE {}.{}\n", dbn, table)); - current_key = Some(key); - } - out.push_str(&format!( - " {} {}{}\n", - column, - dtype, - if is_pk { " [PK]" } else { "" } - )); - } - - out.push_str( - "\nNOTES:\n\ - - Use ClickHouse SQL dialect. Functions differ from PostgreSQL (e.g. count(), arrayJoin, toDate, formatDateTime).\n\ - - ClickHouse allows fully-qualified `database.table` in queries — you CAN cross-reference databases on this server.\n\ - - Read-only mode is enforced for the agent: only SELECT/WITH/EXPLAIN/SHOW/DESCRIBE allowed.\n\ - - Always include LIMIT for ad-hoc SELECTs.\n", - ); - - state - .set_schema_cache(connection_id.to_string(), out.clone()) - .await; - Ok(out) -} - pub(crate) async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult> { let rows = sqlx::query( "SELECT \ @@ -1399,6 +946,7 @@ pub(crate) async fn fetch_unique_constraints( /// Returns HashMap<(schema, table, column), Vec> for varchar columns /// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery. /// Returns None if pg_stats is not accessible (graceful degradation). +#[allow(dead_code)] // re-exposed by profile_table tool (PR2) async fn fetch_varchar_values( pool: &sqlx::PgPool, ) -> Option>> { @@ -1440,104 +988,6 @@ async fn fetch_varchar_values( Some(map) } -/// Discovers top-level keys in JSONB columns by sampling actual data. -/// Runs two sequential queries internally: first discovers JSONB columns, -/// then samples keys from each via a single UNION ALL query. -/// Returns None on error (graceful degradation). -async fn fetch_jsonb_keys( - pool: &sqlx::PgPool, -) -> Option>> { - // Step 1: Find all JSONB columns - let col_rows = match sqlx::query( - "SELECT table_schema, table_name, column_name \ - FROM information_schema.columns \ - WHERE data_type = 'jsonb' \ - AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ - ORDER BY table_schema, table_name, column_name", - ) - .fetch_all(pool) - .await - { - Ok(r) => r, - Err(e) => { - log::warn!("Failed to fetch JSONB columns: {}", e); - return None; - } - }; - - if col_rows.is_empty() { - return Some(HashMap::new()); - } - - // Cap at 50 JSONB columns to prevent unbounded UNION ALL queries on large schemas - let columns: Vec<(String, String, String)> = col_rows - .iter() - .take(50) - .map(|r| { - ( - r.get::(0), - r.get::(1), - r.get::(2), - ) - }) - .collect(); - - // Step 2: Build a single UNION ALL query to sample keys from all JSONB columns - let parts: Vec = columns - .iter() - .enumerate() - .map(|(i, (schema, table, col))| { - let qs = schema.replace('"', "\"\""); - let qt = table.replace('"', "\"\""); - let qc = col.replace('"', "\"\""); - format!( - "(SELECT '{}.{}.{}' AS col_ref, key FROM (\ - SELECT DISTINCT jsonb_object_keys(\"{}\") AS key \ - FROM \"{}\".\"{}\" \ - WHERE \"{}\" IS NOT NULL AND jsonb_typeof(\"{}\") = 'object' \ - LIMIT 50\ - ) sub{})", - schema, table, col, qc, qs, qt, qc, qc, i - ) - }) - .collect(); - - let query = parts.join(" UNION ALL "); - - let rows = match sqlx::query(&query).fetch_all(pool).await { - Ok(r) => r, - Err(e) => { - log::warn!("Failed to fetch JSONB keys: {}", e); - return None; - } - }; - - let mut map: HashMap<(String, String, String), Vec> = HashMap::new(); - for r in &rows { - let col_ref: String = r.get(0); - let key: String = r.get(1); - let ref_parts: Vec<&str> = col_ref.splitn(3, '.').collect(); - if ref_parts.len() == 3 { - let entry = map - .entry(( - ref_parts[0].to_string(), - ref_parts[1].to_string(), - ref_parts[2].to_string(), - )) - .or_default(); - if !entry.contains(&key) { - entry.push(key); - } - } - } - - for vals in map.values_mut() { - vals.sort(); - } - - Some(map) -} - // --------------------------------------------------------------------------- // Greenplum-specific table attributes // --------------------------------------------------------------------------- @@ -1656,6 +1106,7 @@ pub(crate) async fn fetch_gp_table_extras( // --------------------------------------------------------------------------- /// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"} +#[allow(dead_code)] // helper for fetch_varchar_values; re-exposed by profile_table tool (PR2) fn parse_pg_array_text(s: &str) -> Vec { let s = s.trim(); let s = s.strip_prefix('{').unwrap_or(s); @@ -1716,65 +1167,10 @@ fn simplify_default(raw: &str) -> String { s.to_string() } -fn clean_sql_response(raw: &str) -> String { - let trimmed = raw.trim(); - // Remove markdown code fences - let without_fences = if trimmed.starts_with("```") { - let inner = trimmed - .strip_prefix("```sql") - .or_else(|| trimmed.strip_prefix("```SQL")) - .or_else(|| trimmed.strip_prefix("```postgresql")) - .or_else(|| trimmed.strip_prefix("```")) - .unwrap_or(trimmed); - inner.strip_suffix("```").unwrap_or(inner) - } else { - trimmed - }; - without_fences.trim().to_string() -} - #[cfg(test)] mod tests { use super::*; - // ── clean_sql_response ──────────────────────────────────── - - #[test] - fn clean_sql_plain() { - assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1"); - } - - #[test] - fn clean_sql_with_fences() { - assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1"); - } - - #[test] - fn clean_sql_with_generic_fences() { - assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1"); - } - - #[test] - fn clean_sql_with_postgresql_fences() { - assert_eq!( - clean_sql_response("```postgresql\nSELECT 1\n```"), - "SELECT 1" - ); - } - - #[test] - fn clean_sql_with_whitespace() { - assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1"); - } - - #[test] - fn clean_sql_no_fences_multiline() { - assert_eq!( - clean_sql_response("SELECT\n *\nFROM users"), - "SELECT\n *\nFROM users" - ); - } - // ── Fireworks provider ─────────────────────────────────── #[test] @@ -1784,12 +1180,14 @@ mod tests { } #[test] - fn deserializes_legacy_settings_without_fireworks_key() { - // Old config files won't have `fireworks_api_key` — must still parse. + fn deserializes_legacy_settings_with_dropped_provider_keys() { + // Old config files may include `openai_api_key`/`anthropic_api_key` and a + // legacy `"provider": "openai"` value — both must be tolerated, with the + // unknown provider coerced to Ollama. let legacy = r#"{ - "provider": "ollama", + "provider": "openai", "ollama_url": "http://localhost:11434", - "openai_api_key": null, + "openai_api_key": "sk-deprecated", "anthropic_api_key": null, "model": "qwen2.5-coder:7b" }"#; @@ -1813,6 +1211,28 @@ mod tests { assert_eq!(parsed.choices[0].message.content, "hi"); } + // ── OpenRouter provider ────────────────────────────────── + + #[test] + fn serializes_openrouter_provider() { + let json = serde_json::to_string(&AiProvider::OpenRouter).unwrap(); + assert_eq!(json, "\"openrouter\""); + } + + #[test] + fn deserializes_openrouter_settings() { + let cfg = r#"{ + "provider": "openrouter", + "ollama_url": "http://localhost:11434", + "openrouter_api_key": "sk-or-v1-abc", + "model": "anthropic/claude-3.5-sonnet" + }"#; + let parsed: AiSettings = serde_json::from_str(cfg).unwrap(); + assert_eq!(parsed.provider, AiProvider::OpenRouter); + assert_eq!(parsed.openrouter_api_key.as_deref(), Some("sk-or-v1-abc")); + assert_eq!(parsed.model, "anthropic/claude-3.5-sonnet"); + } + #[test] fn parses_fireworks_models_list() { let body = r#"{ diff --git a/src-tauri/src/commands/chat.rs b/src-tauri/src/commands/chat.rs index 4b65bbe..785db32 100644 --- a/src-tauri/src/commands/chat.rs +++ b/src-tauri/src/commands/chat.rs @@ -1,13 +1,14 @@ use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings}; use crate::commands::chat_tools::{ - find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool, + build_sample_sql, detect_skew_tool, explain_query_tool, find_queries_tool, get_columns_tool, + list_databases_tool, list_tables_tool, profile_table_tool, save_query_tool, switch_database_tool, }; use crate::commands::memory::{append_memory_core, read_memory_core}; use crate::commands::queries::execute_query_core; use crate::error::{TuskError, TuskResult}; use crate::models::ai::OllamaChatMessage; -use crate::models::chat::{ChartConfig, ChatMessage, ChatTurnResult, ContextUsage}; +use crate::models::chat::{ChatMessage, ChatTurnResult, ContextUsage}; use crate::models::query_result::QueryResult; use crate::state::AppState; use chrono::Utc; @@ -30,11 +31,11 @@ const TEXT_TOOL_CHAR_CAP: usize = 10_000; /// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192). /// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content. const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000; -/// Conservative default for managed providers (Fireworks). Most chat-capable -/// Fireworks models ship with 32K–256K context windows; 384K chars (~128K tok) -/// is a safe floor that won't trigger false /compact nags on normal sessions -/// while still flagging genuinely runaway threads. -const CONTEXT_BUDGET_CHARS_FIREWORKS: u64 = 384_000; +/// Conservative default for managed providers (Fireworks, OpenRouter). Most +/// chat-capable hosted models ship with 32K–256K context windows; 384K chars +/// (~128K tok) is a safe floor that won't trigger false /compact nags on normal +/// sessions while still flagging genuinely runaway threads. +const CONTEXT_BUDGET_CHARS_MANAGED: u64 = 384_000; /// Stop the loop when the model fails the same SQL hurdle this many times in a /// row. Beyond this, additional hops almost always burn the rest of the budget /// on identical retries; a definitive `final` with the error is more useful. @@ -55,9 +56,15 @@ enum AgentAction { Remember { note: String }, SaveQuery { name: String, sql: String }, FindQueries { text: String }, - MakeChart { config: ChartConfig }, + ProfileTable { table: String }, + SampleData { table: String, limit: u32 }, + ExplainQuery { sql: String }, + DetectSkew { table: String }, } +const SAMPLE_DATA_DEFAULT_LIMIT: u32 = 50; +const SAMPLE_DATA_MAX_LIMIT: u32 = 200; + /// Parse the model's JSON response. Accepts both shapes the model tends to emit: /// {"action":"X","field":"..."} — flat (matches our prompt) /// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention) @@ -157,60 +164,55 @@ fn parse_agent_action(raw: &str) -> Result { } Ok(AgentAction::FindQueries { text }) } - "make_chart" => { - let chart_type = lookup("chart_type") - .or_else(|| lookup("type")) + "profile_table" => { + let table = lookup("table") .and_then(|v| v.as_str()) - .ok_or_else(|| "make_chart missing `chart_type`".to_string())? - .trim() - .to_lowercase(); - if !["bar", "line", "area", "pie"].contains(&chart_type.as_str()) { - return Err(format!( - "make_chart `chart_type` must be one of: bar, line, area, pie. Got: {}", - chart_type - )); - } - let x = lookup("x") - .and_then(|v| v.as_str()) - .ok_or_else(|| "make_chart missing `x` column".to_string())? + .ok_or_else(|| "profile_table missing `table`".to_string())? .trim() .to_string(); - let y = lookup("y") - .and_then(|v| v.as_str()) - .ok_or_else(|| "make_chart missing `y` column".to_string())? - .trim() - .to_string(); - if x.is_empty() || y.is_empty() { - return Err("make_chart `x` and `y` must not be empty".into()); + if table.is_empty() { + return Err("profile_table `table` must not be empty".into()); } - let group = lookup("group") - .and_then(|v| v.as_str()) - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()); - let title = lookup("title") - .and_then(|v| v.as_str()) - .map(|s| s.trim().to_string()) - .filter(|s| !s.is_empty()); - let orientation = lookup("orientation") - .and_then(|v| v.as_str()) - .map(|s| s.trim().to_lowercase()) - .filter(|s| !s.is_empty()); - Ok(AgentAction::MakeChart { - config: ChartConfig { - chart_type, - x, - y, - group, - title, - orientation, - }, - }) + Ok(AgentAction::ProfileTable { table }) + } + "sample_data" => { + let table = lookup("table") + .and_then(|v| v.as_str()) + .ok_or_else(|| "sample_data missing `table`".to_string())? + .trim() + .to_string(); + if table.is_empty() { + return Err("sample_data `table` must not be empty".into()); + } + let limit = lookup("limit") + .and_then(|v| v.as_u64()) + .map(|n| n as u32) + .unwrap_or(SAMPLE_DATA_DEFAULT_LIMIT) + .clamp(1, SAMPLE_DATA_MAX_LIMIT); + Ok(AgentAction::SampleData { table, limit }) + } + "explain_query" => { + let sql = lookup("sql") + .and_then(|v| v.as_str()) + .ok_or_else(|| "explain_query missing `sql`".to_string())? + .trim() + .to_string(); + if sql.is_empty() { + return Err("explain_query `sql` must not be empty".into()); + } + Ok(AgentAction::ExplainQuery { sql }) + } + "detect_skew" => { + let table = lookup("table") + .and_then(|v| v.as_str()) + .ok_or_else(|| "detect_skew missing `table`".to_string())? + .trim() + .to_string(); + if table.is_empty() { + return Err("detect_skew `table` must not be empty".into()); + } + Ok(AgentAction::DetectSkew { table }) } - // Legacy from earlier iterations — silently ignored at parse time so the - // model can recover with a different action. - "get_schema" => Err( - "get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(), - ), other => Err(format!("unknown action `{}`", other)), } } @@ -285,8 +287,17 @@ You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On {{"action":"save_query","name":"","sql":""}} Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved. - {{"action":"make_chart","chart_type":"bar","x":"","y":"","title":""}} - Visualise the LAST successful run_query result as a chart inline. `chart_type` is one of: bar, line, area, pie. `x` and `y` MUST be column names from the previous result. Optional: `group` (column for series), `orientation` ("vertical"/"horizontal", bar only). Use after run_query when the data is aggregated and would be clearer as a chart (top-N comparisons → bar; time series → line/area; proportions → pie). Skip for tiny results (≤2 rows) and giant ones (>500 rows). + {{"action":"profile_table","table":"schema.table"}} + Per-column profile: NULL fraction, distinct cardinality, min/max range, top-K values. PG/GP reads pg_stats (zero-cost; ensure ANALYZE has run). ClickHouse fires one summary query (cheap on MergeTree). Use BEFORE writing aggregations to spot pseudo-enums, NULL-heavy columns, or skewed distributions. + + {{"action":"sample_data","table":"schema.table","limit":50}} + Random row sample (default 50, max 200). PG/GP uses TABLESAMPLE BERNOULLI when reltuples > 0, else ORDER BY random(). CH uses SAMPLE 0.01 on MergeTree with a sampling key, else ORDER BY rand(). Use to eyeball value shape BEFORE writing filters; cheaper than `SELECT * LIMIT N` on huge tables. + + {{"action":"explain_query","sql":"SELECT ..."}} + Run EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) on PG/GP, EXPLAIN PLAN on CH. Reports root node, planning + execution time, seq-scanned tables, spilled sorts, est-vs-actual row skew, Greenplum Motions. Use AFTER a slow run_query. + + {{"action":"detect_skew","table":"schema.table"}} + Greenplum-only: counts rows per gp_segment_id and reports max/min/avg + skew ratio. Ratio > 1.5 ⇒ uneven distribution; suggests revisiting DISTRIBUTED BY. Soft-errors on PG/CH. {{"action":"final","text":"..."}} End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling). @@ -296,6 +307,8 @@ WORKFLOW 2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question. 3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs. 4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns. + 4a. If the user asks about value shape (cardinality, NULL rates, top values), prefer `profile_table` over a hand-written run_query. To eyeball actual rows, prefer `sample_data` over `LIMIT 100`. + 4b. If the user reports a slow query or asks why something takes long, run `explain_query` on it. On Greenplum, if a single table appears unbalanced, check `detect_skew`. 5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first. 6. Execute run_query. 7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here. @@ -415,9 +428,6 @@ fn build_history( content: serde_json::json!({ "action": "final", "text": text }).to_string(), }), ChatMessage::ToolCall { tool, input_json, .. } => { - if tool == "get_schema" { - continue; // legacy - } let mut envelope = serde_json::Map::new(); envelope.insert("action".to_string(), Value::String(tool.clone())); if let Ok(Value::Object(input)) = serde_json::from_str::(input_json) { @@ -437,9 +447,6 @@ fn build_history( result, .. } => { - if tool == "get_schema" { - continue; // legacy - } let payload = match tool.as_str() { "run_query" => { if *is_error { @@ -521,7 +528,7 @@ async fn provider_budget_chars(state: &AppState, app: &AppHandle) -> u64 { use crate::models::ai::AiProvider; match load_ai_settings(app, state).await { Ok(s) => match s.provider { - AiProvider::Fireworks => CONTEXT_BUDGET_CHARS_FIREWORKS, + AiProvider::Fireworks | AiProvider::OpenRouter => CONTEXT_BUDGET_CHARS_MANAGED, _ => CONTEXT_BUDGET_CHARS_OLLAMA, }, Err(_) => CONTEXT_BUDGET_CHARS_OLLAMA, @@ -597,7 +604,10 @@ pub async fn chat_send( } }; - let is_run_query = matches!(&action, AgentAction::RunQuery { .. }); + let is_run_query = matches!( + &action, + AgentAction::RunQuery { .. } | AgentAction::SampleData { .. } + ); match action { AgentAction::Final { text } => { @@ -742,91 +752,90 @@ pub async fn chat_send( ); push_tool_result(&mut new_messages, &mut working, result); } - AgentAction::MakeChart { config } => { - let config_json = serde_json::to_string(&config).unwrap_or_else(|_| "{}".into()); + AgentAction::ProfileTable { table } => { push_tool_call( &mut new_messages, &mut working, - "make_chart", - config_json.clone(), + "profile_table", + serde_json::json!({ "table": &table }).to_string(), ); - - let result_msg = match last_successful_query_result(&working) { - None => ChatMessage::ToolResult { - id: new_id("res"), - tool: "make_chart".to_string(), - is_error: true, - text: Some( - "make_chart needs a successful run_query result above it. Run a SELECT first, then call make_chart." - .to_string(), - ), - result: None, - created_at: now_ms(), - }, - Some(qr) => { - if !qr.columns.iter().any(|c| c == &config.x) { + let result = run_text_tool( + profile_table_tool(&state, &connection_id, &table).await, + "profile_table", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::ExplainQuery { sql } => { + push_tool_call( + &mut new_messages, + &mut working, + "explain_query", + serde_json::json!({ "sql": &sql }).to_string(), + ); + let result = run_text_tool( + explain_query_tool(&state, &connection_id, &sql).await, + "explain_query", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::DetectSkew { table } => { + push_tool_call( + &mut new_messages, + &mut working, + "detect_skew", + serde_json::json!({ "table": &table }).to_string(), + ); + let result = run_text_tool( + detect_skew_tool(&state, &connection_id, &table).await, + "detect_skew", + ); + push_tool_result(&mut new_messages, &mut working, result); + } + AgentAction::SampleData { table, limit } => { + push_tool_call( + &mut new_messages, + &mut working, + "sample_data", + serde_json::json!({ "table": &table, "limit": limit }).to_string(), + ); + let outcome = match build_sample_sql(&state, &connection_id, &table, limit).await { + Ok(sql) => match execute_query_core(&state, &connection_id, &sql).await { + Ok(qr) => { + consecutive_query_errors = 0; ChatMessage::ToolResult { id: new_id("res"), - tool: "make_chart".to_string(), - is_error: true, - text: Some(format!( - "x column `{}` is not in the last result. Available: {}.", - config.x, - qr.columns.join(", ") - )), - result: None, - created_at: now_ms(), - } - } else if !qr.columns.iter().any(|c| c == &config.y) { - ChatMessage::ToolResult { - id: new_id("res"), - tool: "make_chart".to_string(), - is_error: true, - text: Some(format!( - "y column `{}` is not in the last result. Available: {}.", - config.y, - qr.columns.join(", ") - )), - result: None, - created_at: now_ms(), - } - } else if let Some(group) = &config.group { - if !qr.columns.iter().any(|c| c == group) { - ChatMessage::ToolResult { - id: new_id("res"), - tool: "make_chart".to_string(), - is_error: true, - text: Some(format!( - "group column `{}` is not in the last result. Available: {}.", - group, - qr.columns.join(", ") - )), - result: None, - created_at: now_ms(), - } - } else { - ChatMessage::ToolResult { - id: new_id("res"), - tool: "make_chart".to_string(), - is_error: false, - text: Some(config_json.clone()), - result: Some(qr), - created_at: now_ms(), - } - } - } else { - ChatMessage::ToolResult { - id: new_id("res"), - tool: "make_chart".to_string(), + tool: "sample_data".to_string(), is_error: false, - text: Some(config_json.clone()), + text: None, result: Some(qr), created_at: now_ms(), } } + Err(e) => { + consecutive_query_errors += 1; + ChatMessage::ToolResult { + id: new_id("res"), + tool: "sample_data".to_string(), + is_error: true, + text: Some(format_db_error(&e)), + result: None, + created_at: now_ms(), + } + } + }, + Err(e) => { + consecutive_query_errors += 1; + ChatMessage::ToolResult { + id: new_id("res"), + tool: "sample_data".to_string(), + is_error: true, + text: Some(format_db_error(&e)), + result: None, + created_at: now_ms(), + } } }; - push_tool_result(&mut new_messages, &mut working, result_msg); + push_tool_result(&mut new_messages, &mut working, outcome); } } @@ -982,26 +991,6 @@ fn format_db_error(e: &TuskError) -> String { e.to_string() } -/// Locate the most recent SUCCESSFUL run_query in the working thread and -/// return its full QueryResult. Used by make_chart to attach data to a chart -/// directive without relying on the model to re-send it. -fn last_successful_query_result(messages: &[ChatMessage]) -> Option { - for m in messages.iter().rev() { - if let ChatMessage::ToolResult { - tool, - is_error: false, - result: Some(qr), - .. - } = m - { - if tool == "run_query" { - return Some(qr.clone()); - } - } - } - None -} - /// Pull the most recent run_query error text from the working thread, so the /// post-loop "I gave up" summary can quote concrete errors back to the user. fn last_run_query_error(messages: &[ChatMessage]) -> Option { @@ -1484,119 +1473,6 @@ mod tests { assert!(last_run_query_error(&msgs).is_none()); } - #[test] - fn parses_make_chart_minimal() { - let a = parse_agent_action( - r#"{"action":"make_chart","chart_type":"bar","x":"carrier","y":"trips"}"#, - ) - .unwrap(); - match a { - AgentAction::MakeChart { config } => { - assert_eq!(config.chart_type, "bar"); - assert_eq!(config.x, "carrier"); - assert_eq!(config.y, "trips"); - assert!(config.group.is_none()); - assert!(config.title.is_none()); - } - _ => panic!("wrong variant"), - } - } - - #[test] - fn parses_make_chart_with_group_and_title() { - let a = parse_agent_action( - r#"{"action":"make_chart","chart_type":"line","x":"month","y":"revenue","group":"region","title":"Revenue"}"#, - ) - .unwrap(); - match a { - AgentAction::MakeChart { config } => { - assert_eq!(config.group.as_deref(), Some("region")); - assert_eq!(config.title.as_deref(), Some("Revenue")); - } - _ => panic!("wrong variant"), - } - } - - #[test] - fn make_chart_accepts_alternative_field_name_type() { - // Some models emit `type` instead of `chart_type`. - let a = parse_agent_action( - r#"{"action":"make_chart","type":"pie","x":"label","y":"value"}"#, - ) - .unwrap(); - match a { - AgentAction::MakeChart { config } => assert_eq!(config.chart_type, "pie"), - _ => panic!("wrong variant"), - } - } - - #[test] - fn rejects_make_chart_with_unknown_chart_type() { - let r = parse_agent_action( - r#"{"action":"make_chart","chart_type":"radar","x":"a","y":"b"}"#, - ); - assert!(r.is_err()); - } - - #[test] - fn rejects_make_chart_missing_x_or_y() { - assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","y":"a"}"#).is_err()); - assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","x":"a"}"#).is_err()); - } - - #[test] - fn last_successful_query_result_finds_recent() { - use crate::models::query_result::QueryResult; - let qr = QueryResult { - columns: vec!["a".into()], - types: vec!["INT4".into()], - rows: vec![vec![Value::Number(1.into())]], - row_count: 1, - execution_time_ms: 1, - }; - let msgs = vec![ - ChatMessage::ToolResult { - id: "r1".into(), - tool: "run_query".into(), - is_error: false, - text: None, - result: Some(qr.clone()), - created_at: 1, - }, - ChatMessage::ToolResult { - id: "r2".into(), - tool: "run_query".into(), - is_error: true, - text: Some("oops".into()), - result: None, - created_at: 2, - }, - ]; - let found = last_successful_query_result(&msgs).expect("ok"); - assert_eq!(found.columns, vec!["a".to_string()]); - } - - #[test] - fn last_successful_query_result_skips_non_run_query() { - use crate::models::query_result::QueryResult; - let qr = QueryResult { - columns: vec!["a".into()], - types: vec!["INT4".into()], - rows: vec![], - row_count: 0, - execution_time_ms: 0, - }; - let msgs = vec![ChatMessage::ToolResult { - id: "r1".into(), - tool: "list_tables".into(), - is_error: false, - text: Some("public.x".into()), - result: Some(qr), - created_at: 1, - }]; - assert!(last_successful_query_result(&msgs).is_none()); - } - #[test] fn render_thread_for_summary_includes_roles_and_skips_rows() { let msgs = vec![ @@ -1625,11 +1501,6 @@ mod tests { assert!(!rendered.contains("alice")); } - #[test] - fn rejects_legacy_get_schema() { - assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err()); - } - #[test] fn truncates_long_cell() { let long = "a".repeat(CELL_CHAR_CAP + 50); diff --git a/src-tauri/src/commands/chat_tools.rs b/src-tauri/src/commands/chat_tools.rs index 3608c74..ad4ff15 100644 --- a/src-tauri/src/commands/chat_tools.rs +++ b/src-tauri/src/commands/chat_tools.rs @@ -10,11 +10,14 @@ use crate::commands::ai::{ ColumnInfo, }; use crate::commands::connections::{load_connection_config, switch_database_core}; +use crate::commands::queries::execute_query_core; use crate::commands::saved_queries::{list_saved_queries_core, save_query_core}; use crate::commands::schema::{list_databases_core, list_tables_core}; +use crate::db::sql_guard::ensure_readonly_sql; use crate::error::{TuskError, TuskResult}; use crate::models::saved_queries::SavedQuery; use crate::state::{AppState, CachedVec, DbFlavor}; +use crate::utils::escape_ident; use sqlx::{PgPool, Row}; use std::collections::{BTreeMap, HashMap}; use std::time::{Duration, Instant}; @@ -565,3 +568,690 @@ pub async fn find_queries_tool( Ok(out) } +// --------------------------------------------------------------------------- +// profile_table (PR2 — data-engineering tool) +// --------------------------------------------------------------------------- + +const PROFILE_TABLE_MAX_COLUMNS: usize = 30; +const PROFILE_TABLE_TOPK: usize = 5; + +pub async fn profile_table_tool( + state: &AppState, + connection_id: &str, + table: &str, +) -> TuskResult { + let active_db = active_db_name(state, connection_id).await.unwrap_or_default(); + let (schema, tbl, _raw) = normalise_table_ref(table, &active_db); + let flavor = state.get_flavor(connection_id).await; + match flavor { + DbFlavor::PostgreSQL | DbFlavor::Greenplum => { + profile_table_postgres(state, connection_id, &schema, &tbl).await + } + DbFlavor::ClickHouse => profile_table_clickhouse(state, connection_id, &schema, &tbl).await, + } +} + +async fn profile_table_postgres( + state: &AppState, + connection_id: &str, + schema: &str, + table: &str, +) -> TuskResult { + let pool = state.get_pool(connection_id).await?; + + let exists = sqlx::query_scalar::<_, i64>( + "SELECT 1 FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \ + WHERE n.nspname = $1 AND c.relname = $2 LIMIT 1", + ) + .bind(schema) + .bind(table) + .fetch_optional(&pool) + .await + .map_err(TuskError::Database)?; + if exists.is_none() { + return Err(TuskError::Custom(format!( + "Table '{}.{}' does not exist (or no privileges).", + schema, table + ))); + } + + let last_analyze: Option> = sqlx::query_scalar( + "SELECT GREATEST(last_analyze, last_autoanalyze) FROM pg_stat_user_tables \ + WHERE schemaname = $1 AND relname = $2", + ) + .bind(schema) + .bind(table) + .fetch_optional(&pool) + .await + .ok() + .flatten(); + + let stat_rows = sqlx::query( + "SELECT attname, null_frac, n_distinct, \ + most_common_vals::text, most_common_freqs, histogram_bounds::text \ + FROM pg_stats \ + WHERE schemaname = $1 AND tablename = $2 \ + ORDER BY attname", + ) + .bind(schema) + .bind(table) + .fetch_all(&pool) + .await + .map_err(TuskError::Database)?; + + let mut out = format!("PROFILE {}.{}\n", schema, table); + match last_analyze { + Some(ts) => out.push_str(&format!("Last ANALYZE: {}\n", ts.to_rfc3339())), + None => out.push_str("Last ANALYZE: never\n"), + } + + if stat_rows.is_empty() { + out.push_str(&format!( + "\nNo statistics in pg_stats. Run: ANALYZE {}.{};\n", + escape_ident(schema), + escape_ident(table) + )); + return Ok(out); + } + + let total = stat_rows.len(); + let take = total.min(PROFILE_TABLE_MAX_COLUMNS); + out.push_str(&format!("\n{} columns with stats\n", total)); + + for r in stat_rows.iter().take(take) { + let attname: String = r.get(0); + let null_frac: f32 = r.try_get(1).unwrap_or(0.0); + let n_distinct: f32 = r.try_get(2).unwrap_or(0.0); + let mcv_text: Option = r.try_get(3).ok(); + let mcf_arr: Option> = r.try_get(4).ok(); + let hist_text: Option = r.try_get(5).ok(); + + out.push_str(&format!("\n {}:\n", attname)); + out.push_str(&format!(" null_frac: {:.4}\n", null_frac)); + if n_distinct < 0.0 { + out.push_str(&format!( + " n_distinct: {:.3} (ratio of total rows)\n", + -n_distinct + )); + } else { + out.push_str(&format!(" n_distinct: {}\n", n_distinct as i64)); + } + + if let Some(text) = hist_text.as_deref() { + let bounds = parse_pg_array_text_local(text); + if let (Some(min), Some(max)) = (bounds.first(), bounds.last()) { + out.push_str(&format!(" range: {} … {}\n", min, max)); + } + } + + if let Some(text) = mcv_text.as_deref() { + let vals = parse_pg_array_text_local(text); + if !vals.is_empty() { + let freqs = mcf_arr.unwrap_or_default(); + let pairs: Vec = vals + .iter() + .take(PROFILE_TABLE_TOPK) + .enumerate() + .map(|(i, v)| match freqs.get(i) { + Some(f) => format!("{}({:.3})", v, f), + None => v.clone(), + }) + .collect(); + out.push_str(&format!(" top: {}\n", pairs.join(", "))); + } + } + } + + if total > take { + out.push_str(&format!("\n…and {} more columns\n", total - take)); + } + Ok(out) +} + +/// Local pg-array parser used by profile_table; mirrors `parse_pg_array_text` in ai.rs +/// but kept local to avoid importing a private helper. +fn parse_pg_array_text_local(s: &str) -> Vec { + let s = s.trim(); + let s = s.strip_prefix('{').unwrap_or(s); + let s = s.strip_suffix('}').unwrap_or(s); + if s.is_empty() { + return Vec::new(); + } + let mut out = Vec::new(); + let mut cur = String::new(); + let mut in_quotes = false; + let mut chars = s.chars().peekable(); + while let Some(c) = chars.next() { + match c { + '"' if !in_quotes => in_quotes = true, + '"' if in_quotes => { + if chars.peek() == Some(&'"') { + cur.push('"'); + chars.next(); + } else { + in_quotes = false; + } + } + ',' if !in_quotes => { + out.push(std::mem::take(&mut cur)); + } + '\\' if in_quotes => { + if let Some(next) = chars.next() { + cur.push(next); + } + } + other => cur.push(other), + } + } + if !cur.is_empty() || s.ends_with(',') { + out.push(cur); + } + out +} + +async fn profile_table_clickhouse( + state: &AppState, + connection_id: &str, + schema: &str, + table: &str, +) -> TuskResult { + let client = state.get_ch_client(connection_id).await?; + let active_db = client.database.clone(); + let dbn = if schema == "public" || schema.is_empty() { + active_db + } else { + schema.to_string() + }; + + let cols_sql = format!( + "SELECT name, type FROM system.columns \ + WHERE database = '{}' AND table = '{}' \ + ORDER BY position LIMIT {}", + dbn.replace('\'', "\\'"), + table.replace('\'', "\\'"), + PROFILE_TABLE_MAX_COLUMNS + ); + let col_rows = client.fetch_objects(&cols_sql).await?; + if col_rows.is_empty() { + return Err(TuskError::Custom(format!( + "Table '{}.{}' does not exist (or no privileges).", + dbn, table + ))); + } + + let mut select_parts: Vec = vec!["count() AS rows_total".to_string()]; + let mut col_names: Vec = Vec::new(); + let mut col_types: Vec = Vec::new(); + for r in &col_rows { + let name = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(); + let dtype = r.get("type").and_then(|v| v.as_str()).unwrap_or("").to_string(); + if name.is_empty() { + continue; + } + col_names.push(name.clone()); + col_types.push(dtype); + let q = name.replace('`', "``"); + select_parts.push(format!("countIf(`{}` IS NULL) AS null_{}", q, col_names.len())); + select_parts.push(format!("uniqHLL12(`{}`) AS dist_{}", q, col_names.len())); + select_parts.push(format!("toString(min(`{}`)) AS min_{}", q, col_names.len())); + select_parts.push(format!("toString(max(`{}`)) AS max_{}", q, col_names.len())); + select_parts.push(format!( + "arrayStringConcat(arrayMap(x -> toString(x), topK({})(`{}`)), '|') AS top_{}", + PROFILE_TABLE_TOPK, + q, + col_names.len() + )); + } + + let agg_sql = format!( + "SELECT {} FROM `{}`.`{}`", + select_parts.join(", "), + dbn.replace('`', "``"), + table.replace('`', "``") + ); + let agg_rows = client.fetch_objects(&agg_sql).await?; + let row = agg_rows + .first() + .ok_or_else(|| TuskError::Custom("ClickHouse returned no row for profile aggregate".into()))?; + + let rows_total = row + .get("rows_total") + .and_then(|v| v.as_str().and_then(|s| s.parse::().ok()).or_else(|| v.as_i64())) + .unwrap_or(0); + + let mut out = format!( + "PROFILE {}.{}\nRows: {}\n{} columns profiled\n", + dbn, + table, + rows_total, + col_names.len() + ); + + for (i, name) in col_names.iter().enumerate() { + let n = i + 1; + let nulls = row + .get(&format!("null_{}", n)) + .and_then(|v| v.as_str().and_then(|s| s.parse::().ok()).or_else(|| v.as_i64())) + .unwrap_or(0); + let dist = row + .get(&format!("dist_{}", n)) + .and_then(|v| v.as_str().and_then(|s| s.parse::().ok()).or_else(|| v.as_i64())) + .unwrap_or(0); + let min = row.get(&format!("min_{}", n)).and_then(|v| v.as_str()).unwrap_or(""); + let max = row.get(&format!("max_{}", n)).and_then(|v| v.as_str()).unwrap_or(""); + let top_raw = row.get(&format!("top_{}", n)).and_then(|v| v.as_str()).unwrap_or(""); + + out.push_str(&format!("\n {} ({}):\n", name, col_types[i])); + let null_frac = if rows_total > 0 { + nulls as f64 / rows_total as f64 + } else { + 0.0 + }; + out.push_str(&format!(" null_frac: {:.4}\n", null_frac)); + out.push_str(&format!(" distinct (HLL): {}\n", dist)); + if !min.is_empty() || !max.is_empty() { + out.push_str(&format!(" range: {} … {}\n", min, max)); + } + if !top_raw.is_empty() { + let top_vals: Vec<&str> = top_raw.split('|').take(PROFILE_TABLE_TOPK).collect(); + out.push_str(&format!(" top: {}\n", top_vals.join(", "))); + } + } + + if col_rows.len() == PROFILE_TABLE_MAX_COLUMNS { + out.push_str(&format!( + "\n…showing first {} columns\n", + PROFILE_TABLE_MAX_COLUMNS + )); + } + Ok(out) +} + +// --------------------------------------------------------------------------- +// sample_data (PR2 — returns SQL string; dispatch site runs it through +// execute_query_core so the QueryResult feeds the standard renderer) +// --------------------------------------------------------------------------- + +pub async fn build_sample_sql( + state: &AppState, + connection_id: &str, + table: &str, + limit: u32, +) -> TuskResult { + let active_db = active_db_name(state, connection_id).await.unwrap_or_default(); + let (schema, tbl, _raw) = normalise_table_ref(table, &active_db); + let flavor = state.get_flavor(connection_id).await; + match flavor { + DbFlavor::PostgreSQL | DbFlavor::Greenplum => { + build_sample_sql_postgres(state, connection_id, &schema, &tbl, limit).await + } + DbFlavor::ClickHouse => { + build_sample_sql_clickhouse(state, connection_id, &schema, &tbl, limit).await + } + } +} + +async fn build_sample_sql_postgres( + state: &AppState, + connection_id: &str, + schema: &str, + table: &str, + limit: u32, +) -> TuskResult { + let pool = state.get_pool(connection_id).await?; + let reltuples: f64 = sqlx::query_scalar( + "SELECT c.reltuples FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \ + WHERE n.nspname = $1 AND c.relname = $2", + ) + .bind(schema) + .bind(table) + .fetch_optional(&pool) + .await + .map_err(TuskError::Database)? + .unwrap_or(0.0); + + let qualified = format!("{}.{}", escape_ident(schema), escape_ident(table)); + if reltuples > 0.0 { + let target = limit as f64 * 100.0 / reltuples; + let percent = target.clamp(0.01, 100.0); + Ok(format!( + "SELECT * FROM {} TABLESAMPLE BERNOULLI({:.4}) LIMIT {}", + qualified, percent, limit + )) + } else { + Ok(format!( + "SELECT * FROM {} ORDER BY random() LIMIT {}", + qualified, limit + )) + } +} + +async fn build_sample_sql_clickhouse( + state: &AppState, + connection_id: &str, + schema: &str, + table: &str, + limit: u32, +) -> TuskResult { + let client = state.get_ch_client(connection_id).await?; + let active_db = client.database.clone(); + let dbn = if schema == "public" || schema.is_empty() { + active_db + } else { + schema.to_string() + }; + + let info_sql = format!( + "SELECT engine, sampling_key FROM system.tables \ + WHERE database = '{}' AND name = '{}' LIMIT 1", + dbn.replace('\'', "\\'"), + table.replace('\'', "\\'") + ); + let rows = client.fetch_objects(&info_sql).await.unwrap_or_default(); + let (engine, sampling_key) = match rows.first() { + Some(r) => ( + r.get("engine").and_then(|v| v.as_str()).unwrap_or("").to_string(), + r.get("sampling_key").and_then(|v| v.as_str()).unwrap_or("").to_string(), + ), + None => (String::new(), String::new()), + }; + + let qualified = format!( + "`{}`.`{}`", + dbn.replace('`', "``"), + table.replace('`', "``") + ); + if engine.starts_with("Merge") && !sampling_key.trim().is_empty() { + Ok(format!( + "SELECT * FROM {} SAMPLE 0.01 LIMIT {}", + qualified, limit + )) + } else { + Ok(format!( + "SELECT * FROM {} ORDER BY rand() LIMIT {}", + qualified, limit + )) + } +} + +// --------------------------------------------------------------------------- +// explain_query (PR2) +// --------------------------------------------------------------------------- + +pub async fn explain_query_tool( + state: &AppState, + connection_id: &str, + sql: &str, +) -> TuskResult { + let trimmed = sql.trim(); + if trimmed.is_empty() { + return Err(TuskError::Custom("explain_query: sql must not be empty".into())); + } + // Validate the user's statement BEFORE prefixing EXPLAIN so the error message + // references their SQL, not the wrapper. ensure_readonly_sql also rejects any + // forbidden keywords (INSERT/UPDATE/DELETE/...) even nested under EXPLAIN. + ensure_readonly_sql(trimmed).map_err(|e| TuskError::Custom(e.to_string()))?; + + let flavor = state.get_flavor(connection_id).await; + match flavor { + DbFlavor::PostgreSQL | DbFlavor::Greenplum => { + explain_query_postgres(state, connection_id, trimmed).await + } + DbFlavor::ClickHouse => explain_query_clickhouse(state, connection_id, trimmed).await, + } +} + +async fn explain_query_postgres( + state: &AppState, + connection_id: &str, + sql: &str, +) -> TuskResult { + let pool = state.get_pool(connection_id).await?; + let plan_sql = format!("EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) {}", sql); + let mut tx = pool.begin().await.map_err(TuskError::Database)?; + sqlx::query("SET TRANSACTION READ ONLY") + .execute(&mut *tx) + .await + .map_err(TuskError::Database)?; + let row = sqlx::query(&plan_sql) + .fetch_one(&mut *tx) + .await + .map_err(TuskError::Database)?; + let _ = tx.rollback().await; + + let raw_json: serde_json::Value = match row.try_get::(0) { + Ok(v) => v, + Err(_) => { + let s: String = row.try_get(0).map_err(TuskError::Database)?; + serde_json::from_str(&s) + .map_err(|e| TuskError::Custom(format!("EXPLAIN JSON parse failed: {}", e)))? + } + }; + + let plans = raw_json + .as_array() + .ok_or_else(|| TuskError::Custom("EXPLAIN JSON: expected array".into()))?; + let plan = plans.first().and_then(|p| p.get("Plan")).ok_or_else(|| { + TuskError::Custom("EXPLAIN JSON: missing top-level Plan node".into()) + })?; + + let root_node = plan.get("Node Type").and_then(|v| v.as_str()).unwrap_or("?"); + let total_cost = plan.get("Total Cost").and_then(|v| v.as_f64()).unwrap_or(0.0); + let planning = plans + .first() + .and_then(|p| p.get("Planning Time").and_then(|v| v.as_f64())) + .unwrap_or(0.0); + let execution = plans + .first() + .and_then(|p| p.get("Execution Time").and_then(|v| v.as_f64())) + .unwrap_or(0.0); + + let mut seq_scans: Vec = Vec::new(); + let mut spilled: Vec = Vec::new(); + let mut motions: Vec = Vec::new(); + let mut max_skew: Option<(f64, String)> = None; + walk_pg_plan(plan, &mut seq_scans, &mut spilled, &mut motions, &mut max_skew); + + let mut out = format!( + "PLAN root: {}, total cost {:.1}\nPlanning: {:.2} ms Execution: {:.2} ms\n", + root_node, total_cost, planning, execution + ); + if !seq_scans.is_empty() { + out.push_str(&format!("Seq scans on: {}\n", seq_scans.join(", "))); + } + if !spilled.is_empty() { + out.push_str(&format!("Spilled to disk: {}\n", spilled.join(", "))); + } + if !motions.is_empty() { + out.push_str(&format!("Motions (Greenplum): {}\n", motions.join(", "))); + } + if let Some((ratio, node)) = max_skew { + if ratio >= 5.0 { + out.push_str(&format!( + "Estimate skew: max plan/actual ratio = {:.1} on {}\n", + ratio, node + )); + } + } + if seq_scans.is_empty() && spilled.is_empty() && motions.is_empty() { + out.push_str("No obvious red flags.\n"); + } + Ok(out) +} + +fn walk_pg_plan( + node: &serde_json::Value, + seq_scans: &mut Vec, + spilled: &mut Vec, + motions: &mut Vec, + max_skew: &mut Option<(f64, String)>, +) { + let node_type = node.get("Node Type").and_then(|v| v.as_str()).unwrap_or(""); + if node_type == "Seq Scan" { + let rel = node + .get("Relation Name") + .and_then(|v| v.as_str()) + .unwrap_or("?"); + let schema = node + .get("Schema") + .and_then(|v| v.as_str()) + .map(|s| format!("{}.", s)) + .unwrap_or_default(); + seq_scans.push(format!("{}{}", schema, rel)); + } + if let Some(method) = node.get("Sort Method").and_then(|v| v.as_str()) { + if method.contains("disk") || method.contains("external") { + spilled.push(format!("Sort ({})", method)); + } + } + if node_type.contains("Motion") { + motions.push(node_type.to_string()); + } + let plan_rows = node.get("Plan Rows").and_then(|v| v.as_f64()).unwrap_or(0.0); + let actual_rows = node.get("Actual Rows").and_then(|v| v.as_f64()).unwrap_or(0.0); + if actual_rows > 0.0 && plan_rows > 0.0 { + let ratio = (plan_rows / actual_rows).max(actual_rows / plan_rows); + if max_skew.as_ref().map(|(r, _)| ratio > *r).unwrap_or(true) { + *max_skew = Some((ratio, node_type.to_string())); + } + } + if let Some(children) = node.get("Plans").and_then(|v| v.as_array()) { + for child in children { + walk_pg_plan(child, seq_scans, spilled, motions, max_skew); + } + } +} + +async fn explain_query_clickhouse( + state: &AppState, + connection_id: &str, + sql: &str, +) -> TuskResult { + let client = state.get_ch_client(connection_id).await?; + let plan_sql = format!("EXPLAIN PLAN {}", sql); + let qr = client.execute_query(&plan_sql, true).await?; + if qr.rows.is_empty() { + return Ok("(empty plan)".to_string()); + } + let mut out = String::from("ClickHouse plan:\n"); + for row in &qr.rows { + if let Some(cell) = row.first() { + if let Some(s) = cell.as_str() { + out.push_str(s); + out.push('\n'); + } + } + } + Ok(out) +} + +// --------------------------------------------------------------------------- +// detect_skew (PR2 — Greenplum-only) +// --------------------------------------------------------------------------- + +pub async fn detect_skew_tool( + state: &AppState, + connection_id: &str, + table: &str, +) -> TuskResult { + let flavor = state.get_flavor(connection_id).await; + if !matches!(flavor, DbFlavor::Greenplum) { + return Ok("detect_skew is only available on Greenplum connections.".to_string()); + } + let active_db = active_db_name(state, connection_id).await.unwrap_or_default(); + let (schema, tbl, _raw) = normalise_table_ref(table, &active_db); + + let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&tbl)); + let sql = format!( + "SELECT gp_segment_id, COUNT(*) AS n FROM {} GROUP BY 1 ORDER BY 1", + qualified + ); + let qr = execute_query_core(state, connection_id, &sql).await?; + + let mut counts: Vec<(i64, i64)> = Vec::new(); + for row in &qr.rows { + let seg = row + .get(0) + .and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok()))) + .unwrap_or(0); + let n = row + .get(1) + .and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok()))) + .unwrap_or(0); + counts.push((seg, n)); + } + + if counts.is_empty() { + return Ok(format!("Table {}.{} is empty.", schema, tbl)); + } + + let total: i64 = counts.iter().map(|(_, n)| *n).sum(); + let max = counts.iter().map(|(_, n)| *n).max().unwrap_or(0); + let min = counts.iter().map(|(_, n)| *n).min().unwrap_or(0); + let avg = total as f64 / counts.len() as f64; + let ratio = if avg > 0.0 { max as f64 / avg } else { 0.0 }; + + let mut out = format!( + "Per-segment row distribution for {}.{}\nsegments: {} total rows: {}\nmin: {} max: {} avg: {:.0}\nskew ratio (max/avg): {:.2}", + schema, + tbl, + counts.len(), + total, + min, + max, + avg, + ratio + ); + if ratio > 1.5 { + out.push_str(" ⚠ uneven distribution\n"); + } else { + out.push_str(" OK — within 1.5x of average\n"); + } + + let pool = state.get_pool(connection_id).await?; + if let Some(policy) = fetch_gp_distribution_for(&pool, &schema, &tbl).await { + out.push_str(&format!("\nCurrent policy: {}\n", policy)); + if ratio > 1.5 { + out.push_str( + "Hint: pick a higher-cardinality column. Run profile_table to compare n_distinct.\n", + ); + } + } + Ok(out) +} + +/// Fetch the Greenplum DISTRIBUTED BY policy for a single table. Returns None if +/// the catalog query fails (non-GP connection, missing privileges, etc.). +async fn fetch_gp_distribution_for( + pool: &PgPool, + schema: &str, + table: &str, +) -> Option { + let row = sqlx::query( + "SELECT COALESCE(\ + (SELECT array_agg(a.attname ORDER BY ord.idx) \ + FROM regexp_split_to_table(NULLIF(trim(p.distkey::text), ''), ' ') \ + WITH ORDINALITY AS ord(attnum_str, idx) \ + JOIN pg_attribute a \ + ON a.attrelid = c.oid \ + AND a.attnum::int = ord.attnum_str::int), \ + ARRAY[]::text[] \ + ) AS dist_columns \ + FROM gp_distribution_policy p \ + JOIN pg_class c ON p.localoid = c.oid \ + JOIN pg_namespace n ON c.relnamespace = n.oid \ + WHERE n.nspname = $1 AND c.relname = $2", + ) + .bind(schema) + .bind(table) + .fetch_optional(pool) + .await + .ok() + .flatten()?; + let cols: Vec = row.try_get(0).ok()?; + Some(if cols.is_empty() { + "DISTRIBUTED RANDOMLY".to_string() + } else { + format!("DISTRIBUTED BY ({})", cols.join(", ")) + }) +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index ed653c6..5dd6f48 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -111,9 +111,7 @@ pub fn run() { commands::ai::save_ai_settings, commands::ai::list_ollama_models, commands::ai::list_fireworks_models, - commands::ai::generate_sql, - commands::ai::explain_sql, - commands::ai::fix_sql_error, + commands::ai::list_openrouter_models, // chat commands::chat::chat_send, commands::chat::chat_compact, diff --git a/src-tauri/src/models/ai.rs b/src-tauri/src/models/ai.rs index 586b031..4fc4575 100644 --- a/src-tauri/src/models/ai.rs +++ b/src-tauri/src/models/ai.rs @@ -1,13 +1,26 @@ -use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)] #[serde(rename_all = "lowercase")] pub enum AiProvider { #[default] Ollama, - OpenAi, - Anthropic, Fireworks, + OpenRouter, +} + +/// Deserialize a provider string, coercing legacy `openai`/`anthropic` and any +/// unknown value to `Ollama`. Keeps existing config files loadable after the +/// stub providers were removed. +impl<'de> Deserialize<'de> for AiProvider { + fn deserialize>(d: D) -> Result { + let s = String::deserialize(d)?; + Ok(match s.as_str() { + "fireworks" => AiProvider::Fireworks, + "openrouter" => AiProvider::OpenRouter, + _ => AiProvider::Ollama, + }) + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -15,11 +28,9 @@ pub struct AiSettings { pub provider: AiProvider, pub ollama_url: String, #[serde(default)] - pub openai_api_key: Option, - #[serde(default)] - pub anthropic_api_key: Option, - #[serde(default)] pub fireworks_api_key: Option, + #[serde(default)] + pub openrouter_api_key: Option, pub model: String, } @@ -28,9 +39,8 @@ impl Default for AiSettings { Self { provider: AiProvider::Ollama, ollama_url: "http://localhost:11434".to_string(), - openai_api_key: None, - anthropic_api_key: None, fireworks_api_key: None, + openrouter_api_key: None, model: String::new(), } } @@ -71,7 +81,9 @@ pub struct OllamaModel { } // --------------------------------------------------------------------------- -// Fireworks (OpenAI-compatible chat-completions) +// OpenAI-compatible chat-completions (Fireworks, OpenRouter) +// These request/response shapes are shared by every OpenAI-compatible provider; +// the `Fireworks*` names are retained for historical reasons. // --------------------------------------------------------------------------- #[derive(Debug, Clone, Serialize)] diff --git a/src-tauri/src/models/chat.rs b/src-tauri/src/models/chat.rs index 2893b50..69f9f6f 100644 --- a/src-tauri/src/models/chat.rs +++ b/src-tauri/src/models/chat.rs @@ -31,18 +31,3 @@ pub struct ChatTurnResult { pub messages: Vec, pub usage: ContextUsage, } - -/// Chart configuration produced by the agent's `make_chart` tool. -/// Embedded as JSON in `ToolResult.text` for tool == "make_chart" while the -/// underlying data lives in `ToolResult.result`. The frontend reads both to -/// render the chart inline. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ChartConfig { - pub chart_type: String, // "bar" | "line" | "area" | "pie" - pub x: String, // column name for X axis / category - pub y: String, // column name for Y axis / numeric value - pub group: Option, // optional column for series grouping - pub title: Option, - pub orientation: Option, // "vertical" | "horizontal" — bar only -} - diff --git a/src-tauri/src/state.rs b/src-tauri/src/state.rs index 3336c65..2ec21cd 100644 --- a/src-tauri/src/state.rs +++ b/src-tauri/src/state.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; use std::sync::Arc; -use std::time::{Duration, Instant}; +use std::time::Instant; use tokio::sync::{watch, RwLock}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -17,12 +17,6 @@ pub enum DbFlavor { } -#[derive(Clone)] -pub struct SchemaCacheEntry { - pub schema_text: String, - pub cached_at: Instant, -} - #[derive(Clone)] pub struct CachedString { pub value: String, @@ -43,23 +37,16 @@ pub struct AppState { /// Greenplum major version (6 or 7), tracked separately because GP6 and GP7 /// expose very different system catalogs (GP6 = PG9.4 base, GP7 = PG14 base). pub gp_majors: RwLock>, - /// Legacy cache used by generate_sql/explain_sql/fix_sql_error — full DDL. - pub schema_cache: RwLock>, - /// Chat v2 caches: lite overview per connection. + /// Chat agent caches: lite overview per connection. pub overview_cache: RwLock>, - /// Chat v2 caches: list of tables per (connection_id, db_name) — used for + /// Chat agent caches: list of tables per (connection_id, db_name) — used for /// list_tables on a non-active PG database via temporary pool. pub tables_by_db_cache: RwLock>>, - /// Chat v2 caches: column block per (connection_id, db_name, "schema.table"). - pub columns_cache: RwLock>, pub mcp_shutdown_tx: watch::Sender, pub mcp_running: RwLock, pub ai_settings: RwLock>, } -const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes -const SCHEMA_CACHE_MAX_SIZE: usize = 100; - impl AppState { pub fn new() -> Self { let (mcp_shutdown_tx, _) = watch::channel(false); @@ -69,10 +56,8 @@ impl AppState { read_only: RwLock::new(HashMap::new()), db_flavors: RwLock::new(HashMap::new()), gp_majors: RwLock::new(HashMap::new()), - schema_cache: RwLock::new(HashMap::new()), overview_cache: RwLock::new(HashMap::new()), tables_by_db_cache: RwLock::new(HashMap::new()), - columns_cache: RwLock::new(HashMap::new()), mcp_shutdown_tx, mcp_running: RwLock::new(false), ai_settings: RwLock::new(None), @@ -82,16 +67,11 @@ impl AppState { /// Drop every chat-agent cache entry tied to this connection. /// Called by switch_database_core, disconnect, and on connection delete. pub async fn invalidate_chat_caches_for(&self, connection_id: &str) { - self.schema_cache.write().await.remove(connection_id); self.overview_cache.write().await.remove(connection_id); self.tables_by_db_cache .write() .await .retain(|(cid, _), _| cid != connection_id); - self.columns_cache - .write() - .await - .retain(|(cid, _, _), _| cid != connection_id); } pub async fn get_pool(&self, connection_id: &str) -> TuskResult { @@ -125,39 +105,4 @@ impl AppState { pub async fn get_gp_major(&self, id: &str) -> Option { self.gp_majors.read().await.get(id).copied() } - - pub async fn get_schema_cache(&self, connection_id: &str) -> Option { - let cache = self.schema_cache.read().await; - cache.get(connection_id).and_then(|entry| { - if entry.cached_at.elapsed() < SCHEMA_CACHE_TTL { - Some(entry.schema_text.clone()) - } else { - None - } - }) - } - - pub async fn set_schema_cache(&self, connection_id: String, schema_text: String) { - let mut cache = self.schema_cache.write().await; - // Evict stale entries to prevent unbounded memory growth - cache.retain(|_, entry| entry.cached_at.elapsed() < SCHEMA_CACHE_TTL); - // If still at capacity, remove the oldest entry - if cache.len() >= SCHEMA_CACHE_MAX_SIZE { - if let Some(oldest_key) = cache - .iter() - .min_by_key(|(_, e)| e.cached_at) - .map(|(k, _)| k.clone()) - { - cache.remove(&oldest_key); - } - } - cache.insert( - connection_id, - SchemaCacheEntry { - schema_text, - cached_at: Instant::now(), - }, - ); - } - } diff --git a/src/components/ai/AiBar.tsx b/src/components/ai/AiBar.tsx deleted file mode 100644 index 5adf802..0000000 --- a/src/components/ai/AiBar.tsx +++ /dev/null @@ -1,103 +0,0 @@ -import { useState } from "react"; -import { Button } from "@/components/ui/button"; -import { Input } from "@/components/ui/input"; -import { AiSettingsPopover } from "./AiSettingsPopover"; -import { useGenerateSql } from "@/hooks/use-ai"; -import { Sparkles, Loader2, X, Eraser } from "lucide-react"; -import { toast } from "sonner"; - -interface Props { - connectionId: string; - onSqlGenerated: (sql: string) => void; - onClose: () => void; - onExecute?: () => void; -} - -export function AiBar({ connectionId, onSqlGenerated, onClose, onExecute }: Props) { - const [prompt, setPrompt] = useState(""); - const generateMutation = useGenerateSql(); - - const handleGenerate = () => { - if (!prompt.trim() || generateMutation.isPending) return; - generateMutation.mutate( - { connectionId, prompt }, - { - onSuccess: (sql) => { - onSqlGenerated(sql); - }, - onError: (err) => { - toast.error("AI generation failed", { description: String(err) }); - }, - } - ); - }; - - const handleKeyDown = (e: React.KeyboardEvent) => { - if (e.key === "Enter" && (e.ctrlKey || e.metaKey)) { - e.preventDefault(); - e.stopPropagation(); - onExecute?.(); - return; - } - if (e.key === "Enter" && !e.shiftKey) { - e.preventDefault(); - e.stopPropagation(); - handleGenerate(); - return; - } - if (e.key === "Escape") { - e.stopPropagation(); - onClose(); - } - }; - - return ( -
- - setPrompt(e.target.value)} - onKeyDown={handleKeyDown} - placeholder="Describe the query you want..." - className="h-7 min-w-0 flex-1 border-tusk-purple/20 bg-tusk-purple/5 text-xs placeholder:text-muted-foreground/40 focus:border-tusk-purple/40 focus:ring-tusk-purple/20" - autoFocus - disabled={generateMutation.isPending} - /> - - {prompt.trim() && ( - - )} - - -
- ); -} diff --git a/src/components/ai/AiSettingsFields.tsx b/src/components/ai/AiSettingsFields.tsx index 6c2047e..6aa7d2e 100644 --- a/src/components/ai/AiSettingsFields.tsx +++ b/src/components/ai/AiSettingsFields.tsx @@ -7,7 +7,11 @@ import { SelectTrigger, SelectValue, } from "@/components/ui/select"; -import { useFireworksModels, useOllamaModels } from "@/hooks/use-ai"; +import { + useFireworksModels, + useOllamaModels, + useOpenRouterModels, +} from "@/hooks/use-ai"; import { RefreshCw, Loader2 } from "lucide-react"; import type { AiProvider, OllamaModel } from "@/types"; @@ -17,6 +21,8 @@ interface Props { onOllamaUrlChange: (url: string) => void; fireworksApiKey: string; onFireworksApiKeyChange: (key: string) => void; + openrouterApiKey: string; + onOpenRouterApiKeyChange: (key: string) => void; model: string; onModelChange: (model: string) => void; } @@ -27,6 +33,8 @@ export function AiSettingsFields({ onOllamaUrlChange, fireworksApiKey, onFireworksApiKeyChange, + openrouterApiKey, + onOpenRouterApiKeyChange, model, onModelChange, }: Props) { @@ -41,6 +49,17 @@ export function AiSettingsFields({ ); } + if (provider === "openrouter") { + return ( + + ); + } + return ( void; + model: string; + onModelChange: (model: string) => void; +}) { + const { + data: models, + isLoading, + isError, + refetch, + } = useOpenRouterModels(apiKey); + + return ( + <> +
+ + onApiKeyChange(e.target.value)} + placeholder="sk-or-..." + className="h-8 text-xs" + autoComplete="off" + /> +

+ Stored locally; sent only to openrouter.ai. +

+
+ + refetch()} + model={model} + onModelChange={onModelChange} + emptyHint={apiKey.trim() ? "Click ↻ to load models" : "Enter API key first"} + /> + + ); +} + function ModelDropdown({ models, loading, diff --git a/src/components/ai/AiSettingsPopover.tsx b/src/components/ai/AiSettingsPopover.tsx index 4cd7344..3282fed 100644 --- a/src/components/ai/AiSettingsPopover.tsx +++ b/src/components/ai/AiSettingsPopover.tsx @@ -21,6 +21,7 @@ import type { AiProvider } from "@/types"; const SUPPORTED_PROVIDERS: { value: AiProvider; label: string }[] = [ { value: "ollama", label: "Ollama (local)" }, { value: "fireworks", label: "Fireworks AI" }, + { value: "openrouter", label: "OpenRouter" }, ]; export function AiSettingsPopover() { @@ -30,22 +31,16 @@ export function AiSettingsPopover() { const [provider, setProvider] = useState(null); const [url, setUrl] = useState(null); const [fireworksKey, setFireworksKey] = useState(null); + const [openrouterKey, setOpenrouterKey] = useState(null); const [model, setModel] = useState(null); - const settingsProvider = settings?.provider; - // Hide unsupported legacy values (openai/anthropic) from the selector. - const normalizedSettingsProvider: AiProvider | undefined = - settingsProvider === "ollama" || settingsProvider === "fireworks" - ? settingsProvider - : settingsProvider - ? "ollama" - : undefined; - const currentProvider: AiProvider = - provider ?? normalizedSettingsProvider ?? "ollama"; + provider ?? settings?.provider ?? "ollama"; const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434"; const currentFireworksKey = fireworksKey ?? settings?.fireworks_api_key ?? ""; + const currentOpenrouterKey = + openrouterKey ?? settings?.openrouter_api_key ?? ""; const currentModel = model ?? settings?.model ?? ""; const handleProviderChange = (next: AiProvider) => { @@ -64,6 +59,10 @@ export function AiSettingsPopover() { currentProvider === "fireworks" ? currentFireworksKey.trim() || undefined : settings?.fireworks_api_key, + openrouter_api_key: + currentProvider === "openrouter" + ? currentOpenrouterKey.trim() || undefined + : settings?.openrouter_api_key, model: currentModel, }, { @@ -117,6 +116,8 @@ export function AiSettingsPopover() { onOllamaUrlChange={setUrl} fireworksApiKey={currentFireworksKey} onFireworksApiKeyChange={setFireworksKey} + openrouterApiKey={currentOpenrouterKey} + onOpenRouterApiKeyChange={setOpenrouterKey} model={currentModel} onModelChange={setModel} /> diff --git a/src/components/chat/ChartPreview.tsx b/src/components/chat/ChartPreview.tsx deleted file mode 100644 index f079d48..0000000 --- a/src/components/chat/ChartPreview.tsx +++ /dev/null @@ -1,327 +0,0 @@ -import { useMemo } from "react"; -import { - Area, - AreaChart, - Bar, - BarChart, - CartesianGrid, - Cell, - Legend, - Line, - LineChart, - Pie, - PieChart, - ResponsiveContainer, - Tooltip, - XAxis, - YAxis, -} from "recharts"; -import type { ChartConfig } from "@/types"; - -interface Props { - config: ChartConfig; - columns: string[]; - rows: unknown[][]; - height?: number; -} - -const PALETTE = [ - "#60a5fa", // blue-400 - "#34d399", // emerald-400 - "#fbbf24", // amber-400 - "#f87171", // red-400 - "#a78bfa", // violet-400 - "#22d3ee", // cyan-400 - "#fb923c", // orange-400 - "#f472b6", // pink-400 -]; - -const MAX_POINTS = 500; - -export function ChartPreview({ config, columns, rows, height = 280 }: Props) { - const xIdx = columns.indexOf(config.x); - const yIdx = columns.indexOf(config.y); - const groupIdx = config.group ? columns.indexOf(config.group) : -1; - - const limited = useMemo(() => rows.slice(0, MAX_POINTS), [rows]); - - if (xIdx < 0 || yIdx < 0) { - return ( - - ); - } - - // Coerce y values to numbers; chart libs need numeric Y. - const numericY = (v: unknown): number => { - if (typeof v === "number") return v; - if (typeof v === "string") { - const n = parseFloat(v); - return Number.isFinite(n) ? n : 0; - } - return 0; - }; - - const labelX = (v: unknown): string => { - if (v == null) return "—"; - if (typeof v === "string") return v; - if (typeof v === "number" || typeof v === "boolean") return String(v); - return JSON.stringify(v); - }; - - const isGrouped = groupIdx >= 0; - - // ──────────── grouped data shape ──────────── - // For multi-series: pivot to { x: , : yVal, : yVal, … } - // Used by line, area, and grouped-bar. - const pivoted = useMemo(() => { - if (!isGrouped) return null; - const map = new Map>(); - const groupSet = new Set(); - for (const row of limited) { - const xv = labelX(row[xIdx]); - const gv = labelX(row[groupIdx!]); - const yv = numericY(row[yIdx]); - groupSet.add(gv); - const acc = map.get(xv) ?? { _x: xv }; - acc[gv] = ((acc[gv] as number) ?? 0) + yv; - map.set(xv, acc); - } - return { - data: Array.from(map.values()), - groups: Array.from(groupSet), - }; - }, [isGrouped, limited, xIdx, yIdx, groupIdx]); - - // Single series shape: [{ _x, _y }] - const flat = useMemo(() => { - return limited.map((row) => ({ - _x: labelX(row[xIdx]), - _y: numericY(row[yIdx]), - })); - }, [limited, xIdx, yIdx]); - - const tickStyle = { - fill: "var(--muted-foreground)", - fontSize: 10, - } as const; - - const axisLine = { - stroke: "rgba(255, 255, 255, 0.08)", - } as const; - - const tooltipStyle = { - backgroundColor: "var(--popover)", - border: "1px solid var(--border)", - borderRadius: 6, - fontSize: 11, - } as const; - - if (config.chart_type === "pie") { - // Pie: aggregate y by x label (sum), no group support. - const agg = new Map(); - for (const row of limited) { - const xv = labelX(row[xIdx]); - agg.set(xv, (agg.get(xv) ?? 0) + numericY(row[yIdx])); - } - const data = Array.from(agg.entries()).map(([name, value]) => ({ name, value })); - return ( - - - - - typeof entry.name === "string" && entry.name.length < 20 ? entry.name : "" - } - > - {data.map((_, i) => ( - - ))} - - - - - - - ); - } - - if (config.chart_type === "line") { - return ( - - - - - - - - {isGrouped ? ( - <> - - {pivoted!.groups.map((g, i) => ( - - ))} - - ) : ( - - )} - - - - ); - } - - if (config.chart_type === "area") { - return ( - - - - - - - - {isGrouped ? ( - <> - - {pivoted!.groups.map((g, i) => ( - - ))} - - ) : ( - - )} - - - - ); - } - - // bar (default) - const horizontal = config.orientation === "horizontal"; - return ( - - - - - {horizontal ? ( - <> - - - - ) : ( - <> - - - - )} - - {isGrouped ? ( - <> - - {pivoted!.groups.map((g, i) => ( - - ))} - - ) : ( - - )} - - - - ); -} - -function ChartFrame({ - config, - height, - count, - totalRows, - children, -}: { - config: ChartConfig; - height: number; - count: number; - totalRows: number; - children: React.ReactNode; -}) { - return ( -
-
- - {config.title ?? `${capitalize(config.chart_type)} chart`} - - - {count} point{count === 1 ? "" : "s"} - {totalRows > MAX_POINTS && ` (of ${totalRows}, capped at ${MAX_POINTS})`} - -
-
- {children} -
-
- ); -} - -function ChartFallback({ config, message }: { config: ChartConfig; message: string }) { - return ( -
-
- Chart {config.chart_type} failed -
-
{message}
-
- ); -} - -function capitalize(s: string) { - return s.charAt(0).toUpperCase() + s.slice(1); -} diff --git a/src/components/chat/ChatMessageView.tsx b/src/components/chat/ChatMessageView.tsx index 8db7693..682c052 100644 --- a/src/components/chat/ChatMessageView.tsx +++ b/src/components/chat/ChatMessageView.tsx @@ -1,7 +1,6 @@ import { useState } from "react"; import { ResultsTable } from "@/components/results/ResultsTable"; import { ExportDialog } from "@/components/export/ExportDialog"; -import { ChartPreview } from "./ChartPreview"; import { Dialog, DialogContent, @@ -15,19 +14,12 @@ import { AlertCircle, Sparkles, User, - Wrench, Database, - Columns, - Layers, - RefreshCw, - StickyNote, - Bookmark, - BookmarkPlus, Maximize2, Download, - BarChart3, } from "lucide-react"; -import type { ChartConfig, ChatMessage } from "@/types"; +import type { ChatMessage } from "@/types"; +import { getToolMeta, isQueryResultTool } from "./tool-registry"; interface Props { message: ChatMessage; @@ -79,8 +71,10 @@ function AssistantBubble({ text }: { text: string }) { function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) { const [expanded, setExpanded] = useState(false); - const preview = extractToolPreview(tool, inputJson); - const Icon = iconForTool(tool); + const meta = getToolMeta(tool); + const preview = previewFromJson(tool, inputJson); + const Icon = meta.icon; + const showSqlPreview = (tool === "run_query" || tool === "explain_query") && preview; return (
@@ -91,17 +85,14 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) > {expanded ? : } - {labelForTool(tool)} + {meta.label} {preview && ( - - {preview.slice(0, 80)} - {preview.length > 80 ? "…" : ""} - + {preview} )} {expanded && (
- {tool === "run_query" && preview ? ( + {showSqlPreview ? (
               {preview}
             
@@ -116,6 +107,15 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) ); } +function previewFromJson(tool: string, inputJson: string): string | null { + try { + const parsed = JSON.parse(inputJson) as Record; + return getToolMeta(tool).preview(parsed); + } catch { + return null; + } +} + function ToolResultBlock({ tool, isError, @@ -132,87 +132,20 @@ function ToolResultBlock({
-
{labelForTool(tool)} failed
+
{getToolMeta(tool).label} failed
{text &&
{text}
}
); } - // Legacy schema tool — keep a one-line indicator for old threads. - if (tool === "get_schema") { - return ( -
- - Loaded schema context ({text?.length ?? 0} chars) -
- ); - } - - // Text-only tools (chat v2/v3): list_databases, list_tables, get_columns, switch_database, - // remember, save_query, find_queries. - if ( - tool === "list_databases" || - tool === "list_tables" || - tool === "get_columns" || - tool === "switch_database" || - tool === "remember" || - tool === "save_query" || - tool === "find_queries" - ) { - return ; - } - - // make_chart — render chart inline using config from text + data from result. - if (tool === "make_chart") { - return ; - } - - // run_query — full results table with Open-full / Export actions. - if (result) { + // Tools that produce a QueryResult (rendered as a table): run_query, sample_data. + if (isQueryResultTool(tool) && result) { return ; } - return null; -} - -function ChartToolResult({ - text, - result, -}: { - text: string | null; - result: { columns: string[]; types: string[]; rows: unknown[][]; row_count: number; execution_time_ms: number } | null; -}) { - let config: ChartConfig | null = null; - try { - if (text) { - config = JSON.parse(text) as ChartConfig; - } - } catch { - config = null; - } - if (!config || !result) { - return ( -
- -
-
Chart unavailable
-
- The agent referenced a chart but the previous query result is not attached. -
-
-
- ); - } - return ( -
- -
- ); + // Everything else falls back to a collapsible text block. + return ; } function RunQueryResultBlock({ @@ -315,8 +248,10 @@ function RunQueryResultBlock({ } function TextToolResult({ tool, text }: { tool: string; text: string | null }) { + // Lazy preview: switch_database is short; everything else collapses by default. const [expanded, setExpanded] = useState(tool === "switch_database"); - const Icon = iconForTool(tool); + const meta = getToolMeta(tool); + const Icon = meta.icon; const lineCount = text ? text.split("\n").length : 0; return ( @@ -328,7 +263,7 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) { > {expanded ? : } - {labelForTool(tool)} + {meta.label} {text && ( {lineCount} line{lineCount === 1 ? "" : "s"} @@ -346,93 +281,6 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) { ); } -function labelForTool(tool: string): string { - switch (tool) { - case "run_query": - return "Run SQL"; - case "list_databases": - return "List databases"; - case "list_tables": - return "List tables"; - case "get_columns": - return "Inspect columns"; - case "switch_database": - return "Switch database"; - case "remember": - return "Remember"; - case "save_query": - return "Save query"; - case "find_queries": - return "Find saved queries"; - case "make_chart": - return "Make chart"; - case "get_schema": - return "Load schema"; - default: - return tool; - } -} - -function iconForTool(tool: string) { - switch (tool) { - case "run_query": - return Wrench; - case "list_databases": - return Database; - case "list_tables": - return Layers; - case "get_columns": - return Columns; - case "switch_database": - return RefreshCw; - case "remember": - return StickyNote; - case "save_query": - return BookmarkPlus; - case "find_queries": - return Bookmark; - case "make_chart": - return BarChart3; - case "get_schema": - return Database; - default: - return Wrench; - } -} - -function extractToolPreview(tool: string, inputJson: string): string | null { - try { - const parsed = JSON.parse(inputJson) as Record; - switch (tool) { - case "run_query": - return typeof parsed.sql === "string" ? parsed.sql : null; - case "list_tables": - return typeof parsed.database === "string" ? parsed.database : null; - case "switch_database": - return typeof parsed.database === "string" ? parsed.database : null; - case "get_columns": - return Array.isArray(parsed.tables) ? parsed.tables.join(", ") : null; - case "remember": - return typeof parsed.note === "string" ? parsed.note : null; - case "save_query": - return typeof parsed.name === "string" ? parsed.name : null; - case "find_queries": - return typeof parsed.text === "string" ? parsed.text : null; - case "make_chart": { - const t = typeof parsed.chart_type === "string" ? parsed.chart_type : null; - const x = typeof parsed.x === "string" ? parsed.x : null; - const y = typeof parsed.y === "string" ? parsed.y : null; - if (t && x && y) return `${t}: ${x} → ${y}`; - return null; - } - default: - return null; - } - } catch { - return null; - } -} - function prettyJson(s: string): string { try { return JSON.stringify(JSON.parse(s), null, 2); diff --git a/src/components/chat/tool-registry.ts b/src/components/chat/tool-registry.ts new file mode 100644 index 0000000..27efaec --- /dev/null +++ b/src/components/chat/tool-registry.ts @@ -0,0 +1,107 @@ +import { + Database, + Layers, + Columns, + RefreshCw, + Wrench, + StickyNote, + Bookmark, + BookmarkPlus, + Activity, + Shuffle, + GitBranch, + AlertTriangle, +} from "lucide-react"; +import type { LucideIcon } from "lucide-react"; + +export type ToolMeta = { + icon: LucideIcon; + label: string; + preview: (parsed: Record) => string | null; +}; + +const truncate = (s: unknown, n = 80): string | null => { + if (typeof s !== "string") return null; + return s.length > n ? `${s.slice(0, n)}…` : s; +}; + +export const TOOLS: Record = { + list_databases: { + icon: Database, + label: "List databases", + preview: () => null, + }, + list_tables: { + icon: Layers, + label: "List tables", + preview: (p) => (typeof p.database === "string" ? p.database : null), + }, + get_columns: { + icon: Columns, + label: "Inspect columns", + preview: (p) => (Array.isArray(p.tables) ? (p.tables as string[]).join(", ") : null), + }, + switch_database: { + icon: RefreshCw, + label: "Switch database", + preview: (p) => (typeof p.database === "string" ? p.database : null), + }, + run_query: { + icon: Wrench, + label: "Run SQL", + preview: (p) => truncate(p.sql), + }, + remember: { + icon: StickyNote, + label: "Remember", + preview: (p) => (typeof p.note === "string" ? p.note : null), + }, + save_query: { + icon: BookmarkPlus, + label: "Save query", + preview: (p) => (typeof p.name === "string" ? p.name : null), + }, + find_queries: { + icon: Bookmark, + label: "Find saved queries", + preview: (p) => (typeof p.text === "string" ? p.text : null), + }, + profile_table: { + icon: Activity, + label: "Profile table", + preview: (p) => (typeof p.table === "string" ? p.table : null), + }, + sample_data: { + icon: Shuffle, + label: "Sample rows", + preview: (p) => { + const t = typeof p.table === "string" ? p.table : ""; + const limit = typeof p.limit === "number" ? p.limit : 50; + return t ? `${t} (${limit})` : null; + }, + }, + explain_query: { + icon: GitBranch, + label: "Explain query", + preview: (p) => truncate(p.sql), + }, + detect_skew: { + icon: AlertTriangle, + label: "Detect skew", + preview: (p) => (typeof p.table === "string" ? p.table : null), + }, +}; + +export function getToolMeta(tool: string): ToolMeta { + return ( + TOOLS[tool] ?? { + icon: Wrench, + label: tool, + preview: () => null, + } + ); +} + +export function isQueryResultTool(tool: string): boolean { + return tool === "run_query" || tool === "sample_data"; +} diff --git a/src/components/results/ResultsPanel.tsx b/src/components/results/ResultsPanel.tsx index 9067499..2604e92 100644 --- a/src/components/results/ResultsPanel.tsx +++ b/src/components/results/ResultsPanel.tsx @@ -1,8 +1,7 @@ import { ResultsTable } from "./ResultsTable"; import { ResultsJsonView } from "./ResultsJsonView"; import type { QueryResult } from "@/types"; -import { Loader2, AlertCircle, Sparkles, Wand2 } from "lucide-react"; -import { Button } from "@/components/ui/button"; +import { Loader2, AlertCircle } from "lucide-react"; interface Props { result?: QueryResult | null; @@ -15,10 +14,6 @@ interface Props { value: unknown ) => void; highlightedCells?: Set; - aiExplanation?: string | null; - isAiLoading?: boolean; - onExplainError?: () => void; - onFixError?: () => void; } export function ResultsPanel({ @@ -28,10 +23,6 @@ export function ResultsPanel({ viewMode = "table", onCellDoubleClick, highlightedCells, - aiExplanation, - isAiLoading, - onExplainError, - onFixError, }: Props) { if (isLoading) { return ( @@ -42,22 +33,6 @@ export function ResultsPanel({ ); } - if (aiExplanation) { - return ( -
-
-
- - AI Explanation -
-
-            {aiExplanation}
-          
-
-
- ); - } - if (error) { return (
@@ -65,42 +40,6 @@ export function ResultsPanel({
{error}
- {(onExplainError || onFixError) && ( -
- {onExplainError && ( - - )} - {onFixError && ( - - )} -
- )}
); } diff --git a/src/components/settings/AppSettingsSheet.tsx b/src/components/settings/AppSettingsSheet.tsx index c1956a4..b1f67b0 100644 --- a/src/components/settings/AppSettingsSheet.tsx +++ b/src/components/settings/AppSettingsSheet.tsx @@ -27,6 +27,7 @@ import type { AiProvider, AppSettings } from "@/types"; const SUPPORTED_AI_PROVIDERS: { value: AiProvider; label: string }[] = [ { value: "ollama", label: "Ollama (local)" }, { value: "fireworks", label: "Fireworks AI" }, + { value: "openrouter", label: "OpenRouter" }, ]; interface Props { @@ -50,6 +51,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) { const [aiProvider, setAiProvider] = useState("ollama"); const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434"); const [fireworksApiKey, setFireworksApiKey] = useState(""); + const [openrouterApiKey, setOpenrouterApiKey] = useState(""); const [aiModel, setAiModel] = useState(""); const [copied, setCopied] = useState(false); @@ -70,10 +72,14 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) { if (aiSettings) { // Legacy openai/anthropic values aren't user-selectable here — fall back to ollama. setAiProvider( - aiSettings.provider === "fireworks" ? "fireworks" : "ollama" + aiSettings.provider === "fireworks" || + aiSettings.provider === "openrouter" + ? aiSettings.provider + : "ollama" ); setOllamaUrl(aiSettings.ollama_url); setFireworksApiKey(aiSettings.fireworks_api_key ?? ""); + setOpenrouterApiKey(aiSettings.openrouter_api_key ?? ""); setAiModel(aiSettings.model); } } @@ -115,6 +121,10 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) { aiProvider === "fireworks" ? fireworksApiKey.trim() || undefined : aiSettings?.fireworks_api_key, + openrouter_api_key: + aiProvider === "openrouter" + ? openrouterApiKey.trim() || undefined + : aiSettings?.openrouter_api_key, model: aiModel, }, { @@ -167,7 +177,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) { @@ -189,7 +199,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) { title="Copy endpoint URL" > {copied ? ( - + ) : ( )} @@ -229,6 +239,8 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) { onOllamaUrlChange={setOllamaUrl} fireworksApiKey={fireworksApiKey} onFireworksApiKeyChange={setFireworksApiKey} + openrouterApiKey={openrouterApiKey} + onOpenRouterApiKeyChange={setOpenrouterApiKey} model={aiModel} onModelChange={setAiModel} /> diff --git a/src/components/workspace/WorkspacePanel.tsx b/src/components/workspace/WorkspacePanel.tsx index c0fcf98..9e65a7b 100644 --- a/src/components/workspace/WorkspacePanel.tsx +++ b/src/components/workspace/WorkspacePanel.tsx @@ -13,7 +13,7 @@ import { useCompletionSchema } from "@/hooks/use-completion-schema"; import { useConnections } from "@/hooks/use-connections"; import { useAppStore } from "@/stores/app-store"; import { Button } from "@/components/ui/button"; -import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles, BrainCircuit } from "lucide-react"; +import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces } from "lucide-react"; import { format as formatSql } from "sql-formatter"; import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog"; import { @@ -25,8 +25,6 @@ import { import { exportCsv, exportJson } from "@/lib/tauri"; import { save } from "@tauri-apps/plugin-dialog"; import { toast } from "sonner"; -import { AiBar } from "@/components/ai/AiBar"; -import { useExplainSql, useFixSqlError } from "@/hooks/use-ai"; import type { QueryResult, ExplainResult } from "@/types"; interface Props { @@ -53,12 +51,8 @@ export function WorkspacePanel({ const [resultView, setResultView] = useState<"results" | "explain">("results"); const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table"); const [saveDialogOpen, setSaveDialogOpen] = useState(false); - const [aiBarOpen, setAiBarOpen] = useState(false); - const [aiExplanation, setAiExplanation] = useState(null); const queryMutation = useQueryExecution(); - const explainMutation = useExplainSql(); - const fixMutation = useFixSqlError(); const addHistoryMutation = useAddHistory(); const { data: connections } = useConnections(); const { data: completionSchema } = useCompletionSchema(connectionId); @@ -102,7 +96,6 @@ export function WorkspacePanel({ if (!sqlValue.trim() || !connectionId) return; setError(null); setExplainData(null); - setAiExplanation(null); setResultView("results"); queryMutation.mutate( { connectionId, sql: sqlValue }, @@ -196,60 +189,6 @@ export function WorkspacePanel({ [result] ); - const isAiLoading = explainMutation.isPending || fixMutation.isPending; - - const handleAiExplain = useCallback(() => { - if (!sqlValue.trim() || !connectionId) return; - setAiExplanation(null); - setResultView("results"); - explainMutation.mutate( - { connectionId, sql: sqlValue }, - { - onSuccess: (explanation) => { - setAiExplanation(explanation); - }, - onError: (err) => { - toast.error("AI Explain failed", { description: String(err) }); - }, - } - ); - }, [connectionId, sqlValue, explainMutation]); - - const handleExplainError = useCallback(() => { - if (!sqlValue.trim() || !connectionId || !error) return; - setAiExplanation(null); - explainMutation.mutate( - { connectionId, sql: `${sqlValue}\n\n-- Error: ${error}` }, - { - onSuccess: (explanation) => { - setAiExplanation(explanation); - }, - onError: (err) => { - toast.error("AI Explain failed", { description: String(err) }); - }, - } - ); - }, [connectionId, sqlValue, error, explainMutation]); - - const handleFixError = useCallback(() => { - if (!sqlValue.trim() || !connectionId || !error) return; - fixMutation.mutate( - { connectionId, sql: sqlValue, errorMessage: error }, - { - onSuccess: (fixedSql) => { - setSqlValue(fixedSql); - onSqlChange?.(fixedSql); - setError(null); - setAiExplanation(null); - toast.success("SQL replaced by AI suggestion"); - }, - onError: (err) => { - toast.error("AI Fix failed", { description: String(err) }); - }, - } - ); - }, [connectionId, sqlValue, error, fixMutation, onSqlChange]); - return ( <> @@ -308,35 +247,6 @@ export function WorkspacePanel({ Save -
- - {/* AI actions group — purple-branded */} - - - {result && result.columns.length > 0 && ( <>
@@ -369,23 +279,12 @@ export function WorkspacePanel({ {"\u2318"}Enter {isReadOnly && ( - + READ )}
- {aiBarOpen && ( - { - setSqlValue(sql); - onSqlChange?.(sql); - }} - onClose={() => setAiBarOpen(false)} - onExecute={handleExecute} - /> - )}
- {(explainData || result || error || aiExplanation) && ( + {(explainData || result || error) && (
diff --git a/src/hooks/use-ai.ts b/src/hooks/use-ai.ts index 322e7f9..6a7b855 100644 --- a/src/hooks/use-ai.ts +++ b/src/hooks/use-ai.ts @@ -4,9 +4,7 @@ import { saveAiSettings, listOllamaModels, listFireworksModels, - generateSql, - explainSql, - fixSqlError, + listOpenRouterModels, } from "@/lib/tauri"; import type { AiSettings } from "@/types"; @@ -47,40 +45,12 @@ export function useFireworksModels(apiKey: string | undefined) { }); } -export function useGenerateSql() { - return useMutation({ - mutationFn: ({ - connectionId, - prompt, - }: { - connectionId: string; - prompt: string; - }) => generateSql(connectionId, prompt), - }); -} - -export function useExplainSql() { - return useMutation({ - mutationFn: ({ - connectionId, - sql, - }: { - connectionId: string; - sql: string; - }) => explainSql(connectionId, sql), - }); -} - -export function useFixSqlError() { - return useMutation({ - mutationFn: ({ - connectionId, - sql, - errorMessage, - }: { - connectionId: string; - sql: string; - errorMessage: string; - }) => fixSqlError(connectionId, sql, errorMessage), +export function useOpenRouterModels(apiKey: string | undefined) { + return useQuery({ + queryKey: ["openrouter-models", apiKey], + queryFn: () => listOpenRouterModels(apiKey!), + enabled: !!apiKey && apiKey.trim().length > 0, + retry: false, + staleTime: 60_000, }); } diff --git a/src/lib/tauri.ts b/src/lib/tauri.ts index 39f6f1d..f9c5770 100644 --- a/src/lib/tauri.ts +++ b/src/lib/tauri.ts @@ -214,14 +214,8 @@ export const listOllamaModels = (ollamaUrl: string) => export const listFireworksModels = (apiKey: string) => invoke("list_fireworks_models", { apiKey }); -export const generateSql = (connectionId: string, prompt: string) => - invoke("generate_sql", { connectionId, prompt }); - -export const explainSql = (connectionId: string, sql: string) => - invoke("explain_sql", { connectionId, sql }); - -export const fixSqlError = (connectionId: string, sql: string, errorMessage: string) => - invoke("fix_sql_error", { connectionId, sql, errorMessage }); +export const listOpenRouterModels = (apiKey: string) => + invoke("list_openrouter_models", { apiKey }); export const chatSend = (connectionId: string, messages: ChatMessage[]) => invoke("chat_send", { connectionId, messages }); diff --git a/src/types/index.ts b/src/types/index.ts index da75819..0b7a15d 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -134,14 +134,13 @@ export interface SavedQuery { created_at: string; } -export type AiProvider = "ollama" | "openai" | "anthropic" | "fireworks"; +export type AiProvider = "ollama" | "fireworks" | "openrouter"; export interface AiSettings { provider: AiProvider; ollama_url: string; - openai_api_key?: string; - anthropic_api_key?: string; fireworks_api_key?: string; + openrouter_api_key?: string; model: string; } @@ -216,14 +215,3 @@ export interface ChatTurnResult { messages: ChatMessage[]; usage: ContextUsage; } - -export type ChartType = "bar" | "line" | "area" | "pie"; - -export interface ChartConfig { - chart_type: ChartType; - x: string; - y: string; - group?: string | null; - title?: string | null; - orientation?: string | null; -}