refactor(ai): consolidate AI around chat tool-calling; add OpenRouter

- rework chat backend (chat.rs, chat_tools.rs, ai.rs, models, state) around tool calls
- add OpenRouter provider alongside Ollama/Fireworks in settings
- drop inline AiBar, ResultsPanel explain/fix UI and ChartPreview in favour of the chat panel
- add frontend chat tool-registry
This commit is contained in:
2026-05-23 15:01:52 +03:00
parent a485cf7ee3
commit 0cba457fb7
19 changed files with 1244 additions and 1931 deletions

View File

@@ -15,6 +15,13 @@ use tauri::{AppHandle, Manager, State};
const MAX_RETRIES: u32 = 2;
const RETRY_DELAY_MS: u64 = 1000;
const FIREWORKS_BASE_URL: &str = "https://api.fireworks.ai/inference/v1";
const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
/// Optional attribution headers OpenRouter uses for its public app leaderboard.
/// Harmless to send on every request; ignored by other OpenAI-compatible APIs.
const OPENROUTER_HEADERS: &[(&str, &str)] = &[
("HTTP-Referer", "https://github.com/codelab/tusk"),
("X-Title", "Tusk"),
];
fn http_client() -> &'static reqwest::Client {
use std::sync::LazyLock;
@@ -89,32 +96,51 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaMode
#[tauri::command]
pub async fn list_fireworks_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
list_openai_compatible_models(FIREWORKS_BASE_URL, &api_key, "Fireworks", &[]).await
}
#[tauri::command]
pub async fn list_openrouter_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
list_openai_compatible_models(OPENROUTER_BASE_URL, &api_key, "OpenRouter", OPENROUTER_HEADERS)
.await
}
/// List the available chat models for any OpenAI-compatible provider via its
/// `GET {base_url}/models` endpoint. `extra_headers` carries provider-specific
/// attribution headers (OpenRouter recommends `HTTP-Referer`/`X-Title`).
async fn list_openai_compatible_models(
base_url: &str,
api_key: &str,
provider_label: &str,
extra_headers: &[(&str, &str)],
) -> TuskResult<Vec<OllamaModel>> {
let key = api_key.trim();
if key.is_empty() {
return Err(TuskError::Ai("Fireworks API key required".to_string()));
return Err(TuskError::Ai(format!("{} API key required", provider_label)));
}
let url = format!("{}/models", FIREWORKS_BASE_URL);
let resp = http_client()
.get(&url)
.bearer_auth(key)
let url = format!("{}/models", base_url);
let mut req = http_client().get(&url).bearer_auth(key);
for (name, value) in extra_headers {
req = req.header(*name, *value);
}
let resp = req
.send()
.await
.map_err(|e| TuskError::Ai(format!("Cannot reach Fireworks: {}", e)))?;
.map_err(|e| TuskError::Ai(format!("Cannot reach {}: {}", provider_label, e)))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}",
status, body
"{} error ({}): {}",
provider_label, status, body
)));
}
let parsed: FireworksModelsResponse = resp
.json()
.await
.map_err(|e| TuskError::Ai(format!("Failed to parse Fireworks models list: {}", e)))?;
let parsed: FireworksModelsResponse = resp.json().await.map_err(|e| {
TuskError::Ai(format!("Failed to parse {} models list: {}", provider_label, e))
})?;
Ok(parsed
.data
@@ -180,33 +206,8 @@ pub(crate) async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskR
Ok(settings)
}
async fn call_chat_simple(
app: &AppHandle,
state: &AppState,
system_prompt: String,
user_content: String,
) -> TuskResult<String> {
call_chat_messages(
app,
state,
vec![
OllamaChatMessage {
role: "system".to_string(),
content: system_prompt,
},
OllamaChatMessage {
role: "user".to_string(),
content: user_content,
},
],
None,
)
.await
}
/// Provider-agnostic chat-completions dispatcher used by every LLM-driven feature
/// (chat agent, generate_sql, explain_sql, fix_sql_error). Returns the model's
/// raw text content.
/// Provider-agnostic chat-completions dispatcher used by the chat agent.
/// Returns the model's raw text content.
pub(crate) async fn call_chat_messages(
app: &AppHandle,
state: &AppState,
@@ -223,14 +224,49 @@ pub(crate) async fn call_chat_messages(
match settings.provider {
AiProvider::Ollama => call_ollama(&settings, messages, format).await,
AiProvider::Fireworks => call_fireworks(&settings, messages, format).await,
AiProvider::OpenAi | AiProvider::Anthropic => Err(TuskError::Ai(format!(
"Provider {:?} not implemented yet",
settings.provider
))),
AiProvider::Fireworks => {
let api_key = require_api_key(
settings.fireworks_api_key.as_deref(),
"Fireworks API key not set. Open AI settings to add it.",
)?;
call_openai_compatible(
&settings,
FIREWORKS_BASE_URL,
&api_key,
"Fireworks",
&[],
messages,
format,
)
.await
}
AiProvider::OpenRouter => {
let api_key = require_api_key(
settings.openrouter_api_key.as_deref(),
"OpenRouter API key not set. Open AI settings to add it.",
)?;
call_openai_compatible(
&settings,
OPENROUTER_BASE_URL,
&api_key,
"OpenRouter",
OPENROUTER_HEADERS,
messages,
format,
)
.await
}
}
}
/// Trim and validate an optional API key, returning a user-facing error when
/// it's missing or blank.
fn require_api_key(key: Option<&str>, missing_msg: &str) -> TuskResult<String> {
key.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.ok_or_else(|| TuskError::Ai(missing_msg.to_string()))
}
async fn call_ollama(
settings: &AiSettings,
messages: Vec<OllamaChatMessage>,
@@ -277,21 +313,18 @@ async fn call_ollama(
.await
}
async fn call_fireworks(
/// Chat-completions call for any OpenAI-compatible provider (Fireworks,
/// OpenRouter). `extra_headers` carries provider-specific attribution headers.
async fn call_openai_compatible(
settings: &AiSettings,
base_url: &str,
api_key: &str,
provider_label: &str,
extra_headers: &[(&str, &str)],
messages: Vec<OllamaChatMessage>,
format: Option<String>,
) -> TuskResult<String> {
let api_key = settings
.fireworks_api_key
.clone()
.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.ok_or_else(|| {
TuskError::Ai("Fireworks API key not set. Open AI settings to add it.".to_string())
})?;
let url = format!("{}/chat/completions", FIREWORKS_BASE_URL);
let url = format!("{}/chat/completions", base_url);
let response_format = format.as_deref().map(|f| FireworksResponseFormat {
kind: if f == "json" {
"json_object".to_string()
@@ -307,19 +340,22 @@ async fn call_fireworks(
response_format,
};
call_ai_with_retry(settings, "Fireworks request", || {
let operation = format!("{} request", provider_label);
call_ai_with_retry(settings, &operation, || {
let url = url.clone();
let request = request.clone();
let api_key = api_key.clone();
let api_key = api_key.to_string();
async move {
let resp = http_client()
.post(&url)
.bearer_auth(&api_key)
let mut req = http_client().post(&url).bearer_auth(&api_key);
for (name, value) in extra_headers {
req = req.header(*name, *value);
}
let resp = req
.json(&request)
.send()
.await
.map_err(|e| {
TuskError::Ai(format!("Cannot reach Fireworks at {}: {}", url, e))
TuskError::Ai(format!("Cannot reach {} at {}: {}", provider_label, url, e))
})?;
if !resp.status().is_success() {
@@ -339,13 +375,13 @@ async fn call_fireworks(
));
}
return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}",
status, body
"{} error ({}): {}",
provider_label, status, body
)));
}
let parsed: FireworksChatResponse = resp.json().await.map_err(|e| {
TuskError::Ai(format!("Failed to parse Fireworks response: {}", e))
TuskError::Ai(format!("Failed to parse {} response: {}", provider_label, e))
})?;
parsed
@@ -353,177 +389,14 @@ async fn call_fireworks(
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| TuskError::Ai("Fireworks returned no choices".to_string()))
.ok_or_else(|| {
TuskError::Ai(format!("{} returned no choices", provider_label))
})
}
})
.await
}
// ---------------------------------------------------------------------------
// SQL generation
// ---------------------------------------------------------------------------
#[tauri::command]
pub async fn generate_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
prompt: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are an expert PostgreSQL query generator. You receive a database schema and a natural \
language request. Output ONLY a valid, executable PostgreSQL SQL query.\n\
\n\
OUTPUT FORMAT:\n\
- Raw SQL only. No explanations, no markdown code fences (```), no comments, no preamble.\n\
- The output must be directly executable in psql.\n\
- For complex queries use readable formatting with line breaks and indentation.\n\
\n\
CRITICAL RULES:\n\
1. ONLY reference tables and columns that exist in the schema. Never invent names.\n\
2. Use the FOREIGN KEY information to determine correct JOIN conditions.\n\
3. Use LEFT JOIN when the FK column is nullable or the relationship is optional; \
INNER JOIN when both sides must exist.\n\
4. Every non-aggregated column in SELECT must appear in GROUP BY.\n\
5. Use COALESCE for nullable columns in aggregations: COALESCE(SUM(x), 0).\n\
6. For enum columns, use ONLY the values listed in the ENUM TYPES section.\n\
7. Use IS NULL / IS NOT NULL for null checks — never = NULL or != NULL.\n\
8. Add LIMIT when the user asks for \"top N\", \"first N\", \"latest N\", etc.\n\
9. Qualify column names with table alias when the query involves multiple tables.\n\
\n\
SEMANTIC RULES (very important):\n\
- When a table has both actual_* and planned_* columns (e.g. actual_start vs planned_start), \
they represent DIFFERENT concepts: planned = future estimate, actual = what really happened. \
NEVER mix them with COALESCE unless the user explicitly requests a fallback.\n\
- For time-based calculations involving real events (\"how long did X take\", \"average time between\"), \
use ONLY actual/factual timestamps (actual_*, started_at, completed_at, ended_at). \
Filter out NULL values with WHERE instead of falling back to planned timestamps.\n\
- Planned timestamps (planned_*, scheduled_*, estimated_*) should ONLY be used when the user \
asks about plans, schedules, SLA, or compares plan vs fact.\n\
- When computing durations or averages, always filter out rows where any involved timestamp \
is NULL rather than substituting with unrelated defaults.\n\
- Pay attention to column descriptions/comments in the schema — they reveal business semantics \
that are critical for correct queries.\n\
\n\
TYPE RULES:\n\
- timestamp - timestamp = interval. For seconds: EXTRACT(EPOCH FROM (ts1 - ts2)).\n\
- interval cannot be cast to numeric directly; use EXTRACT(EPOCH FROM interval).\n\
- UNION/UNION ALL requires matching column count and compatible types; cast enums to text.\n\
- Use ::type for PostgreSQL-style casts.\n\
- For array columns use ANY, ALL, @>, <@ operators.\n\
- For JSONB columns use ->, ->>, #>, jsonb_extract_path.\n\
\n\
COMMON PATTERNS:\n\
- FIRST/LAST per group: to find MIN(started_at) per trip, use \
\"SELECT trip_id, MIN(started_at) FROM t GROUP BY trip_id\". \
NEVER put the aggregated column (started_at) into GROUP BY — that defeats the aggregation \
and returns every row separately instead of one per group.\n\
- TOP-1 per group with extra columns: use DISTINCT ON (group_col) ... ORDER BY group_col, sort_col \
or a subquery with ROW_NUMBER() OVER (PARTITION BY group_col ORDER BY sort_col) = 1.\n\
- For \"time from A to B\" calculations, ensure both timestamps are NOT NULL with WHERE filters; \
never use COALESCE to mix planned and actual timestamps.\n\
\n\
BEST PRACTICES:\n\
- Use ILIKE for case-insensitive text search, LIKE for case-sensitive.\n\
- Use EXISTS instead of IN for subquery existence checks.\n\
- Use CTE (WITH ... AS) for complex multi-step logic.\n\
- Use window functions (ROW_NUMBER, RANK, LAG, LEAD, SUM OVER) for ranking and running totals.\n\
- Use date_trunc('period', column) for time-based grouping.\n\
- Use generate_series() for creating ranges.\n\
- Use string_agg(col, ', ') for concatenating grouped values.\n\
- Use FILTER (WHERE ...) for conditional aggregation instead of CASE inside aggregate.\n\
\n\
{}\n",
schema_text
);
let raw = call_chat_simple(&app, &state, system_prompt, prompt).await?;
Ok(clean_sql_response(&raw))
}
// ---------------------------------------------------------------------------
// SQL explanation
// ---------------------------------------------------------------------------
#[tauri::command]
pub async fn explain_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a PostgreSQL expert. Explain the given SQL query clearly and concisely.\n\
\n\
Structure your explanation as:\n\
1. **Summary** — one sentence describing what the query returns in business terms.\n\
2. **Step-by-step breakdown** — explain tables accessed, joins, filters, aggregations, \
subqueries, and sorting. Use bullet points.\n\
3. **Notes** — mention edge cases, potential issues, or performance concerns if any.\n\
\n\
Use the database schema below to understand table relationships and column meanings.\n\
Keep the explanation short; avoid restating the SQL verbatim.\n\
\n\
IMPORTANT: If you notice semantic issues (e.g. mixing planned_* and actual_* timestamps \
with COALESCE, comparing unrelated columns, missing NULL filters on nullable timestamps), \
mention them in the Notes section as potential problems.\n\
\n\
{}\n",
schema_text
);
call_chat_simple(&app, &state, system_prompt, sql).await
}
// ---------------------------------------------------------------------------
// SQL error fixing
// ---------------------------------------------------------------------------
#[tauri::command]
pub async fn fix_sql_error(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
error_message: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a PostgreSQL expert debugger. You receive a SQL query and the error it produced. \
Fix the query so it executes correctly.\n\
\n\
OUTPUT FORMAT:\n\
- Raw SQL only. No explanations, no markdown code fences (```), no comments.\n\
- The output must be directly executable.\n\
\n\
DIAGNOSTIC CHECKLIST:\n\
- Column/table does not exist → check the schema for correct names and spelling.\n\
- Column is ambiguous → qualify with table name or alias.\n\
- Must appear in GROUP BY → add missing non-aggregated columns to GROUP BY.\n\
- Type mismatch → add appropriate casts (::text, ::integer, etc.).\n\
- Permission denied → wrap in a read-only transaction if needed.\n\
- Syntax error → correct PostgreSQL syntax (check commas, parentheses, keywords).\n\
- Subquery returns more than one row → use IN, ANY, or add LIMIT 1.\n\
- Division by zero → wrap divisor with NULLIF(x, 0).\n\
\n\
ONLY use tables and columns from the schema below. Never invent names.\n\
Preserve the original intent of the query; change only what is necessary to fix the error.\n\
\n\
{}\n",
schema_text
);
let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);
let raw = call_chat_simple(&app, &state, system_prompt, user_content).await?;
Ok(clean_sql_response(&raw))
}
// ---------------------------------------------------------------------------
// Lite overview builder (chat v2)
// ---------------------------------------------------------------------------
@@ -735,231 +608,6 @@ async fn build_overview_clickhouse(state: &AppState, connection_id: &str) -> Tus
Ok(out.join("\n"))
}
// ---------------------------------------------------------------------------
// Full schema context builder (legacy — used by generate_sql/explain_sql/fix_sql_error)
// ---------------------------------------------------------------------------
pub(crate) async fn build_schema_context(
state: &AppState,
connection_id: &str,
) -> TuskResult<String> {
// Check cache first
if let Some(cached) = state.get_schema_cache(connection_id).await {
return Ok(cached);
}
let flavor = state.get_flavor(connection_id).await;
if matches!(flavor, DbFlavor::ClickHouse) {
return build_clickhouse_schema_context(state, connection_id).await;
}
let is_greenplum = matches!(flavor, DbFlavor::Greenplum);
let gp_major = state.get_gp_major(connection_id).await.unwrap_or(7);
let pool = state.get_pool(connection_id).await?;
// Run all metadata queries in parallel for speed
let (
version_res,
current_db_res,
all_dbs_res,
col_res,
fk_res,
enum_res,
tbl_comment_res,
col_comment_res,
unique_res,
varchar_res,
jsonb_res,
) = tokio::join!(
sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool),
sqlx::query_scalar::<_, String>("SELECT current_database()").fetch_one(&pool),
sqlx::query_scalar::<_, String>(
"SELECT datname FROM pg_database \
WHERE datistemplate = false AND datallowconn = true \
ORDER BY datname"
)
.fetch_all(&pool),
fetch_columns(&pool),
fetch_foreign_keys_raw(&pool),
fetch_enum_types(&pool),
fetch_table_comments(&pool),
fetch_column_comments(&pool),
fetch_unique_constraints(&pool),
fetch_varchar_values(&pool),
fetch_jsonb_keys(&pool),
);
let version = version_res.map_err(TuskError::Database)?;
let current_db = current_db_res.unwrap_or_default();
let all_dbs = all_dbs_res.unwrap_or_default();
let col_rows = col_res?;
let fk_rows = fk_res?;
let enum_map = enum_res?;
let tbl_comments = tbl_comment_res?;
let col_comments = col_comment_res?;
let unique_constraints = unique_res?;
let varchar_values = varchar_res.unwrap_or_default();
let jsonb_keys = jsonb_res.unwrap_or_default();
let gp_extras = if is_greenplum {
Some(fetch_gp_table_extras(&pool, gp_major).await)
} else {
None
};
// -- Build FK inline lookup: (schema, table, column) -> "ref_schema.ref_table(ref_col)" --
let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new();
let mut fk_lines: Vec<String> = Vec::new();
for fk in &fk_rows {
let line = format!(
"FK: {}.{}({}) -> {}.{}({})",
fk.schema,
fk.table,
fk.columns.join(", "),
fk.ref_schema,
fk.ref_table,
fk.ref_columns.join(", ")
);
fk_lines.push(line);
// For single-column FKs, enable inline annotation on column
if fk.columns.len() == 1 && fk.ref_columns.len() == 1 {
fk_inline.insert(
(fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()),
format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]),
);
}
}
// -- Build unique constraint lookup: (schema, table) -> Vec<column_list_string> --
let mut unique_map: HashMap<(String, String), Vec<String>> = HashMap::new();
for (schema, table, cols) in &unique_constraints {
unique_map
.entry((schema.clone(), table.clone()))
.or_default()
.push(cols.join(", "));
}
// -- Format output --
let mut output: Vec<String> = Vec::new();
// 1. PostgreSQL version (short form)
let short_version = version
.split_whitespace()
.take(2)
.collect::<Vec<_>>()
.join(" ");
output.push(format!("DATABASE SCHEMA ({})", short_version));
if !current_db.is_empty() {
output.push(format!("ACTIVE DATABASE: {}", current_db));
}
output.push(String::new());
// 2. Cluster topology — other databases on this server.
// Each PG database is isolated; cross-DB queries are not possible from a single connection.
if all_dbs.len() > 1 {
output.push("DATABASES ON THIS SERVER:".to_string());
for db in &all_dbs {
if db == &current_db {
output.push(format!(" * {} (active)", db));
} else {
output.push(format!(" {}", db));
}
}
output.push(String::new());
output.push(
"NOTE: Tables in other databases are NOT queryable from this session. \
If the user's question concerns data likely stored in a different database \
(e.g. an identity service in a separate DB), respond with a `final` message \
asking them to switch the active database via the connection selector."
.to_string(),
);
output.push(String::new());
}
// 3. Quick table+column index for fast existence checks before writing SQL.
// Each line lists `schema.table(col1, col2, ...)` so the model can grep both
// table names and column names without scrolling through the full TABLES section.
{
let mut by_table: BTreeMap<(String, String), Vec<String>> = BTreeMap::new();
for c in &col_rows {
by_table
.entry((c.schema.clone(), c.table.clone()))
.or_default()
.push(c.column.clone());
}
if !by_table.is_empty() {
output.push(format!(
"TABLE INDEX (database `{}`, {} tables — table_name(column_list)):",
current_db,
by_table.len()
));
for ((schema, table), cols) in &by_table {
output.push(format!(" {}.{}({})", schema, table, cols.join(", ")));
}
output.push(String::new());
}
}
// 4. Enum types
if !enum_map.is_empty() {
output.push("ENUM TYPES:".to_string());
for (type_name, values) in &enum_map {
let vals_str = values
.iter()
.map(|v| format!("'{}'", v))
.collect::<Vec<_>>()
.join(", ");
output.push(format!(" {} = [{}]", type_name, vals_str));
}
output.push(String::new());
}
// 5. Tables with columns
output.push("TABLES:".to_string());
// Group columns by schema.table preserving order
let mut tables: BTreeMap<String, Vec<ColumnInfo>> = BTreeMap::new();
for ci in &col_rows {
let key = format!("{}.{}", ci.schema, ci.table);
tables.entry(key).or_default().push(ci.clone());
}
for (full_name, columns) in &tables {
format_table_block(
full_name,
columns,
&tbl_comments,
&col_comments,
&fk_inline,
&enum_map,
&unique_map,
&varchar_values,
&jsonb_keys,
gp_extras.as_ref(),
&mut output,
);
}
// 6. Foreign keys summary
if !fk_lines.is_empty() {
output.push(String::new());
output.push("FOREIGN KEYS:".to_string());
for fk in &fk_lines {
output.push(format!(" {}", fk));
}
}
let result = output.join("\n");
// Cache the result
state
.set_schema_cache(connection_id.to_string(), result.clone())
.await;
Ok(result)
}
// ---------------------------------------------------------------------------
// Schema query helpers
// ---------------------------------------------------------------------------
@@ -976,8 +624,7 @@ pub(crate) struct ColumnInfo {
}
/// Render a single table's column block in the human/LLM-readable schema format.
/// Reused by both `build_schema_context` (full DDL for legacy AI commands) and
/// the new `get_columns` chat tool.
/// Used by the chat agent's `get_columns` tool.
#[allow(clippy::too_many_arguments)]
pub(crate) fn format_table_block(
full_name: &str,
@@ -1102,106 +749,6 @@ pub(crate) fn format_table_block(
}
}
async fn build_clickhouse_schema_context(
state: &AppState,
connection_id: &str,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let db = client.database.clone();
// ClickHouse exposes ALL databases via system.* — pull cross-DB schema in one shot.
let columns_sql = "SELECT database, table, name, type, is_in_primary_key \
FROM system.columns \
WHERE database NOT IN ('system','INFORMATION_SCHEMA','information_schema') \
ORDER BY database, table, position";
let dbs_sql = "SELECT name FROM system.databases \
WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \
ORDER BY name";
let rows = client.fetch_objects(columns_sql).await?;
let db_rows = client.fetch_objects(dbs_sql).await.unwrap_or_default();
let version = client.ping().await.unwrap_or_default();
let mut out = String::new();
out.push_str("DATABASE: ClickHouse\n");
if !version.is_empty() {
out.push_str(&format!("VERSION: {}\n", version.trim()));
}
out.push_str(&format!("ACTIVE_DATABASE: {}\n\n", db));
// Cluster overview
if !db_rows.is_empty() {
out.push_str("DATABASES ON THIS SERVER:\n");
for row in &db_rows {
let name = row.get("name").and_then(|v| v.as_str()).unwrap_or("");
if name == db {
out.push_str(&format!(" * {} (active)\n", name));
} else {
out.push_str(&format!(" {}\n", name));
}
}
out.push('\n');
}
// Table+column index — same shape as the PG path so model has a uniform reference.
{
let mut by_table: BTreeMap<(String, String), Vec<String>> = BTreeMap::new();
for r in &rows {
let dbn = r.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string();
let tbl = r.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string();
let col = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
if !dbn.is_empty() && !tbl.is_empty() && !col.is_empty() {
by_table.entry((dbn, tbl)).or_default().push(col);
}
}
if !by_table.is_empty() {
out.push_str(&format!(
"TABLE INDEX ({} tables across all databases — db.table(column_list)):\n",
by_table.len()
));
for ((dbn, tbl), cols) in &by_table {
out.push_str(&format!(" {}.{}({})\n", dbn, tbl, cols.join(", ")));
}
out.push('\n');
}
}
out.push_str("TABLES:\n");
let mut current_key: Option<String> = None;
for row in &rows {
let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string();
let table = row.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string();
let column = row.get("name").and_then(|v| v.as_str()).unwrap_or("");
let dtype = row.get("type").and_then(|v| v.as_str()).unwrap_or("");
let is_pk = matches!(row.get("is_in_primary_key"), Some(serde_json::Value::Number(n)) if n.as_i64() == Some(1))
|| matches!(row.get("is_in_primary_key"), Some(serde_json::Value::String(s)) if s == "1");
let key = format!("{}.{}", dbn, table);
if Some(&key) != current_key.as_ref() {
out.push_str(&format!("\nTABLE {}.{}\n", dbn, table));
current_key = Some(key);
}
out.push_str(&format!(
" {} {}{}\n",
column,
dtype,
if is_pk { " [PK]" } else { "" }
));
}
out.push_str(
"\nNOTES:\n\
- Use ClickHouse SQL dialect. Functions differ from PostgreSQL (e.g. count(), arrayJoin, toDate, formatDateTime).\n\
- ClickHouse allows fully-qualified `database.table` in queries — you CAN cross-reference databases on this server.\n\
- Read-only mode is enforced for the agent: only SELECT/WITH/EXPLAIN/SHOW/DESCRIBE allowed.\n\
- Always include LIMIT for ad-hoc SELECTs.\n",
);
state
.set_schema_cache(connection_id.to_string(), out.clone())
.await;
Ok(out)
}
pub(crate) async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult<Vec<ColumnInfo>> {
let rows = sqlx::query(
"SELECT \
@@ -1399,6 +946,7 @@ pub(crate) async fn fetch_unique_constraints(
/// Returns HashMap<(schema, table, column), Vec<distinct_values>> for varchar columns
/// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery.
/// Returns None if pg_stats is not accessible (graceful degradation).
#[allow(dead_code)] // re-exposed by profile_table tool (PR2)
async fn fetch_varchar_values(
pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> {
@@ -1440,104 +988,6 @@ async fn fetch_varchar_values(
Some(map)
}
/// Discovers top-level keys in JSONB columns by sampling actual data.
/// Runs two sequential queries internally: first discovers JSONB columns,
/// then samples keys from each via a single UNION ALL query.
/// Returns None on error (graceful degradation).
async fn fetch_jsonb_keys(
pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> {
// Step 1: Find all JSONB columns
let col_rows = match sqlx::query(
"SELECT table_schema, table_name, column_name \
FROM information_schema.columns \
WHERE data_type = 'jsonb' \
AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
ORDER BY table_schema, table_name, column_name",
)
.fetch_all(pool)
.await
{
Ok(r) => r,
Err(e) => {
log::warn!("Failed to fetch JSONB columns: {}", e);
return None;
}
};
if col_rows.is_empty() {
return Some(HashMap::new());
}
// Cap at 50 JSONB columns to prevent unbounded UNION ALL queries on large schemas
let columns: Vec<(String, String, String)> = col_rows
.iter()
.take(50)
.map(|r| {
(
r.get::<String, _>(0),
r.get::<String, _>(1),
r.get::<String, _>(2),
)
})
.collect();
// Step 2: Build a single UNION ALL query to sample keys from all JSONB columns
let parts: Vec<String> = columns
.iter()
.enumerate()
.map(|(i, (schema, table, col))| {
let qs = schema.replace('"', "\"\"");
let qt = table.replace('"', "\"\"");
let qc = col.replace('"', "\"\"");
format!(
"(SELECT '{}.{}.{}' AS col_ref, key FROM (\
SELECT DISTINCT jsonb_object_keys(\"{}\") AS key \
FROM \"{}\".\"{}\" \
WHERE \"{}\" IS NOT NULL AND jsonb_typeof(\"{}\") = 'object' \
LIMIT 50\
) sub{})",
schema, table, col, qc, qs, qt, qc, qc, i
)
})
.collect();
let query = parts.join(" UNION ALL ");
let rows = match sqlx::query(&query).fetch_all(pool).await {
Ok(r) => r,
Err(e) => {
log::warn!("Failed to fetch JSONB keys: {}", e);
return None;
}
};
let mut map: HashMap<(String, String, String), Vec<String>> = HashMap::new();
for r in &rows {
let col_ref: String = r.get(0);
let key: String = r.get(1);
let ref_parts: Vec<&str> = col_ref.splitn(3, '.').collect();
if ref_parts.len() == 3 {
let entry = map
.entry((
ref_parts[0].to_string(),
ref_parts[1].to_string(),
ref_parts[2].to_string(),
))
.or_default();
if !entry.contains(&key) {
entry.push(key);
}
}
}
for vals in map.values_mut() {
vals.sort();
}
Some(map)
}
// ---------------------------------------------------------------------------
// Greenplum-specific table attributes
// ---------------------------------------------------------------------------
@@ -1656,6 +1106,7 @@ pub(crate) async fn fetch_gp_table_extras(
// ---------------------------------------------------------------------------
/// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"}
#[allow(dead_code)] // helper for fetch_varchar_values; re-exposed by profile_table tool (PR2)
fn parse_pg_array_text(s: &str) -> Vec<String> {
let s = s.trim();
let s = s.strip_prefix('{').unwrap_or(s);
@@ -1716,65 +1167,10 @@ fn simplify_default(raw: &str) -> String {
s.to_string()
}
fn clean_sql_response(raw: &str) -> String {
let trimmed = raw.trim();
// Remove markdown code fences
let without_fences = if trimmed.starts_with("```") {
let inner = trimmed
.strip_prefix("```sql")
.or_else(|| trimmed.strip_prefix("```SQL"))
.or_else(|| trimmed.strip_prefix("```postgresql"))
.or_else(|| trimmed.strip_prefix("```"))
.unwrap_or(trimmed);
inner.strip_suffix("```").unwrap_or(inner)
} else {
trimmed
};
without_fences.trim().to_string()
}
#[cfg(test)]
mod tests {
use super::*;
// ── clean_sql_response ────────────────────────────────────
#[test]
fn clean_sql_plain() {
assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1");
}
#[test]
fn clean_sql_with_fences() {
assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1");
}
#[test]
fn clean_sql_with_generic_fences() {
assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1");
}
#[test]
fn clean_sql_with_postgresql_fences() {
assert_eq!(
clean_sql_response("```postgresql\nSELECT 1\n```"),
"SELECT 1"
);
}
#[test]
fn clean_sql_with_whitespace() {
assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1");
}
#[test]
fn clean_sql_no_fences_multiline() {
assert_eq!(
clean_sql_response("SELECT\n *\nFROM users"),
"SELECT\n *\nFROM users"
);
}
// ── Fireworks provider ───────────────────────────────────
#[test]
@@ -1784,12 +1180,14 @@ mod tests {
}
#[test]
fn deserializes_legacy_settings_without_fireworks_key() {
// Old config files won't have `fireworks_api_key` — must still parse.
fn deserializes_legacy_settings_with_dropped_provider_keys() {
// Old config files may include `openai_api_key`/`anthropic_api_key` and a
// legacy `"provider": "openai"` value — both must be tolerated, with the
// unknown provider coerced to Ollama.
let legacy = r#"{
"provider": "ollama",
"provider": "openai",
"ollama_url": "http://localhost:11434",
"openai_api_key": null,
"openai_api_key": "sk-deprecated",
"anthropic_api_key": null,
"model": "qwen2.5-coder:7b"
}"#;
@@ -1813,6 +1211,28 @@ mod tests {
assert_eq!(parsed.choices[0].message.content, "hi");
}
// ── OpenRouter provider ──────────────────────────────────
#[test]
fn serializes_openrouter_provider() {
let json = serde_json::to_string(&AiProvider::OpenRouter).unwrap();
assert_eq!(json, "\"openrouter\"");
}
#[test]
fn deserializes_openrouter_settings() {
let cfg = r#"{
"provider": "openrouter",
"ollama_url": "http://localhost:11434",
"openrouter_api_key": "sk-or-v1-abc",
"model": "anthropic/claude-3.5-sonnet"
}"#;
let parsed: AiSettings = serde_json::from_str(cfg).unwrap();
assert_eq!(parsed.provider, AiProvider::OpenRouter);
assert_eq!(parsed.openrouter_api_key.as_deref(), Some("sk-or-v1-abc"));
assert_eq!(parsed.model, "anthropic/claude-3.5-sonnet");
}
#[test]
fn parses_fireworks_models_list() {
let body = r#"{

View File

@@ -1,13 +1,14 @@
use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings};
use crate::commands::chat_tools::{
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
build_sample_sql, detect_skew_tool, explain_query_tool, find_queries_tool, get_columns_tool,
list_databases_tool, list_tables_tool, profile_table_tool, save_query_tool,
switch_database_tool,
};
use crate::commands::memory::{append_memory_core, read_memory_core};
use crate::commands::queries::execute_query_core;
use crate::error::{TuskError, TuskResult};
use crate::models::ai::OllamaChatMessage;
use crate::models::chat::{ChartConfig, ChatMessage, ChatTurnResult, ContextUsage};
use crate::models::chat::{ChatMessage, ChatTurnResult, ContextUsage};
use crate::models::query_result::QueryResult;
use crate::state::AppState;
use chrono::Utc;
@@ -30,11 +31,11 @@ const TEXT_TOOL_CHAR_CAP: usize = 10_000;
/// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192).
/// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content.
const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000;
/// Conservative default for managed providers (Fireworks). Most chat-capable
/// Fireworks models ship with 32K256K context windows; 384K chars (~128K tok)
/// is a safe floor that won't trigger false /compact nags on normal sessions
/// while still flagging genuinely runaway threads.
const CONTEXT_BUDGET_CHARS_FIREWORKS: u64 = 384_000;
/// Conservative default for managed providers (Fireworks, OpenRouter). Most
/// chat-capable hosted models ship with 32K256K context windows; 384K chars
/// (~128K tok) is a safe floor that won't trigger false /compact nags on normal
/// sessions while still flagging genuinely runaway threads.
const CONTEXT_BUDGET_CHARS_MANAGED: u64 = 384_000;
/// Stop the loop when the model fails the same SQL hurdle this many times in a
/// row. Beyond this, additional hops almost always burn the rest of the budget
/// on identical retries; a definitive `final` with the error is more useful.
@@ -55,9 +56,15 @@ enum AgentAction {
Remember { note: String },
SaveQuery { name: String, sql: String },
FindQueries { text: String },
MakeChart { config: ChartConfig },
ProfileTable { table: String },
SampleData { table: String, limit: u32 },
ExplainQuery { sql: String },
DetectSkew { table: String },
}
const SAMPLE_DATA_DEFAULT_LIMIT: u32 = 50;
const SAMPLE_DATA_MAX_LIMIT: u32 = 200;
/// Parse the model's JSON response. Accepts both shapes the model tends to emit:
/// {"action":"X","field":"..."} — flat (matches our prompt)
/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention)
@@ -157,60 +164,55 @@ fn parse_agent_action(raw: &str) -> Result<AgentAction, String> {
}
Ok(AgentAction::FindQueries { text })
}
"make_chart" => {
let chart_type = lookup("chart_type")
.or_else(|| lookup("type"))
"profile_table" => {
let table = lookup("table")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `chart_type`".to_string())?
.trim()
.to_lowercase();
if !["bar", "line", "area", "pie"].contains(&chart_type.as_str()) {
return Err(format!(
"make_chart `chart_type` must be one of: bar, line, area, pie. Got: {}",
chart_type
));
}
let x = lookup("x")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `x` column".to_string())?
.ok_or_else(|| "profile_table missing `table`".to_string())?
.trim()
.to_string();
let y = lookup("y")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `y` column".to_string())?
.trim()
.to_string();
if x.is_empty() || y.is_empty() {
return Err("make_chart `x` and `y` must not be empty".into());
if table.is_empty() {
return Err("profile_table `table` must not be empty".into());
}
let group = lookup("group")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
let title = lookup("title")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
let orientation = lookup("orientation")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty());
Ok(AgentAction::MakeChart {
config: ChartConfig {
chart_type,
x,
y,
group,
title,
orientation,
},
})
Ok(AgentAction::ProfileTable { table })
}
"sample_data" => {
let table = lookup("table")
.and_then(|v| v.as_str())
.ok_or_else(|| "sample_data missing `table`".to_string())?
.trim()
.to_string();
if table.is_empty() {
return Err("sample_data `table` must not be empty".into());
}
let limit = lookup("limit")
.and_then(|v| v.as_u64())
.map(|n| n as u32)
.unwrap_or(SAMPLE_DATA_DEFAULT_LIMIT)
.clamp(1, SAMPLE_DATA_MAX_LIMIT);
Ok(AgentAction::SampleData { table, limit })
}
"explain_query" => {
let sql = lookup("sql")
.and_then(|v| v.as_str())
.ok_or_else(|| "explain_query missing `sql`".to_string())?
.trim()
.to_string();
if sql.is_empty() {
return Err("explain_query `sql` must not be empty".into());
}
Ok(AgentAction::ExplainQuery { sql })
}
"detect_skew" => {
let table = lookup("table")
.and_then(|v| v.as_str())
.ok_or_else(|| "detect_skew missing `table`".to_string())?
.trim()
.to_string();
if table.is_empty() {
return Err("detect_skew `table` must not be empty".into());
}
Ok(AgentAction::DetectSkew { table })
}
// Legacy from earlier iterations — silently ignored at parse time so the
// model can recover with a different action.
"get_schema" => Err(
"get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(),
),
other => Err(format!("unknown action `{}`", other)),
}
}
@@ -285,8 +287,17 @@ You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On
{{"action":"save_query","name":"<short label>","sql":"<the SQL>"}}
Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved.
{{"action":"make_chart","chart_type":"bar","x":"<col>","y":"<col>","title":"<short title>"}}
Visualise the LAST successful run_query result as a chart inline. `chart_type` is one of: bar, line, area, pie. `x` and `y` MUST be column names from the previous result. Optional: `group` (column for series), `orientation` ("vertical"/"horizontal", bar only). Use after run_query when the data is aggregated and would be clearer as a chart (top-N comparisons → bar; time series → line/area; proportions → pie). Skip for tiny results (≤2 rows) and giant ones (>500 rows).
{{"action":"profile_table","table":"schema.table"}}
Per-column profile: NULL fraction, distinct cardinality, min/max range, top-K values. PG/GP reads pg_stats (zero-cost; ensure ANALYZE has run). ClickHouse fires one summary query (cheap on MergeTree). Use BEFORE writing aggregations to spot pseudo-enums, NULL-heavy columns, or skewed distributions.
{{"action":"sample_data","table":"schema.table","limit":50}}
Random row sample (default 50, max 200). PG/GP uses TABLESAMPLE BERNOULLI when reltuples > 0, else ORDER BY random(). CH uses SAMPLE 0.01 on MergeTree with a sampling key, else ORDER BY rand(). Use to eyeball value shape BEFORE writing filters; cheaper than `SELECT * LIMIT N` on huge tables.
{{"action":"explain_query","sql":"SELECT ..."}}
Run EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) on PG/GP, EXPLAIN PLAN on CH. Reports root node, planning + execution time, seq-scanned tables, spilled sorts, est-vs-actual row skew, Greenplum Motions. Use AFTER a slow run_query.
{{"action":"detect_skew","table":"schema.table"}}
Greenplum-only: counts rows per gp_segment_id and reports max/min/avg + skew ratio. Ratio > 1.5 ⇒ uneven distribution; suggests revisiting DISTRIBUTED BY. Soft-errors on PG/CH.
{{"action":"final","text":"..."}}
End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling).
@@ -296,6 +307,8 @@ WORKFLOW
2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question.
3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs.
4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns.
4a. If the user asks about value shape (cardinality, NULL rates, top values), prefer `profile_table` over a hand-written run_query. To eyeball actual rows, prefer `sample_data` over `LIMIT 100`.
4b. If the user reports a slow query or asks why something takes long, run `explain_query` on it. On Greenplum, if a single table appears unbalanced, check `detect_skew`.
5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first.
6. Execute run_query.
7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here.
@@ -415,9 +428,6 @@ fn build_history(
content: serde_json::json!({ "action": "final", "text": text }).to_string(),
}),
ChatMessage::ToolCall { tool, input_json, .. } => {
if tool == "get_schema" {
continue; // legacy
}
let mut envelope = serde_json::Map::new();
envelope.insert("action".to_string(), Value::String(tool.clone()));
if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) {
@@ -437,9 +447,6 @@ fn build_history(
result,
..
} => {
if tool == "get_schema" {
continue; // legacy
}
let payload = match tool.as_str() {
"run_query" => {
if *is_error {
@@ -521,7 +528,7 @@ async fn provider_budget_chars(state: &AppState, app: &AppHandle) -> u64 {
use crate::models::ai::AiProvider;
match load_ai_settings(app, state).await {
Ok(s) => match s.provider {
AiProvider::Fireworks => CONTEXT_BUDGET_CHARS_FIREWORKS,
AiProvider::Fireworks | AiProvider::OpenRouter => CONTEXT_BUDGET_CHARS_MANAGED,
_ => CONTEXT_BUDGET_CHARS_OLLAMA,
},
Err(_) => CONTEXT_BUDGET_CHARS_OLLAMA,
@@ -597,7 +604,10 @@ pub async fn chat_send(
}
};
let is_run_query = matches!(&action, AgentAction::RunQuery { .. });
let is_run_query = matches!(
&action,
AgentAction::RunQuery { .. } | AgentAction::SampleData { .. }
);
match action {
AgentAction::Final { text } => {
@@ -742,91 +752,90 @@ pub async fn chat_send(
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::MakeChart { config } => {
let config_json = serde_json::to_string(&config).unwrap_or_else(|_| "{}".into());
AgentAction::ProfileTable { table } => {
push_tool_call(
&mut new_messages,
&mut working,
"make_chart",
config_json.clone(),
"profile_table",
serde_json::json!({ "table": &table }).to_string(),
);
let result_msg = match last_successful_query_result(&working) {
None => ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(
"make_chart needs a successful run_query result above it. Run a SELECT first, then call make_chart."
.to_string(),
),
result: None,
created_at: now_ms(),
},
Some(qr) => {
if !qr.columns.iter().any(|c| c == &config.x) {
let result = run_text_tool(
profile_table_tool(&state, &connection_id, &table).await,
"profile_table",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::ExplainQuery { sql } => {
push_tool_call(
&mut new_messages,
&mut working,
"explain_query",
serde_json::json!({ "sql": &sql }).to_string(),
);
let result = run_text_tool(
explain_query_tool(&state, &connection_id, &sql).await,
"explain_query",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::DetectSkew { table } => {
push_tool_call(
&mut new_messages,
&mut working,
"detect_skew",
serde_json::json!({ "table": &table }).to_string(),
);
let result = run_text_tool(
detect_skew_tool(&state, &connection_id, &table).await,
"detect_skew",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::SampleData { table, limit } => {
push_tool_call(
&mut new_messages,
&mut working,
"sample_data",
serde_json::json!({ "table": &table, "limit": limit }).to_string(),
);
let outcome = match build_sample_sql(&state, &connection_id, &table, limit).await {
Ok(sql) => match execute_query_core(&state, &connection_id, &sql).await {
Ok(qr) => {
consecutive_query_errors = 0;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"x column `{}` is not in the last result. Available: {}.",
config.x,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if !qr.columns.iter().any(|c| c == &config.y) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"y column `{}` is not in the last result. Available: {}.",
config.y,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if let Some(group) = &config.group {
if !qr.columns.iter().any(|c| c == group) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"group column `{}` is not in the last result. Available: {}.",
group,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: false,
text: Some(config_json.clone()),
result: Some(qr),
created_at: now_ms(),
}
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
tool: "sample_data".to_string(),
is_error: false,
text: Some(config_json.clone()),
text: None,
result: Some(qr),
created_at: now_ms(),
}
}
Err(e) => {
consecutive_query_errors += 1;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "sample_data".to_string(),
is_error: true,
text: Some(format_db_error(&e)),
result: None,
created_at: now_ms(),
}
}
},
Err(e) => {
consecutive_query_errors += 1;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "sample_data".to_string(),
is_error: true,
text: Some(format_db_error(&e)),
result: None,
created_at: now_ms(),
}
}
};
push_tool_result(&mut new_messages, &mut working, result_msg);
push_tool_result(&mut new_messages, &mut working, outcome);
}
}
@@ -982,26 +991,6 @@ fn format_db_error(e: &TuskError) -> String {
e.to_string()
}
/// Locate the most recent SUCCESSFUL run_query in the working thread and
/// return its full QueryResult. Used by make_chart to attach data to a chart
/// directive without relying on the model to re-send it.
fn last_successful_query_result(messages: &[ChatMessage]) -> Option<QueryResult> {
for m in messages.iter().rev() {
if let ChatMessage::ToolResult {
tool,
is_error: false,
result: Some(qr),
..
} = m
{
if tool == "run_query" {
return Some(qr.clone());
}
}
}
None
}
/// Pull the most recent run_query error text from the working thread, so the
/// post-loop "I gave up" summary can quote concrete errors back to the user.
fn last_run_query_error(messages: &[ChatMessage]) -> Option<String> {
@@ -1484,119 +1473,6 @@ mod tests {
assert!(last_run_query_error(&msgs).is_none());
}
#[test]
fn parses_make_chart_minimal() {
let a = parse_agent_action(
r#"{"action":"make_chart","chart_type":"bar","x":"carrier","y":"trips"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => {
assert_eq!(config.chart_type, "bar");
assert_eq!(config.x, "carrier");
assert_eq!(config.y, "trips");
assert!(config.group.is_none());
assert!(config.title.is_none());
}
_ => panic!("wrong variant"),
}
}
#[test]
fn parses_make_chart_with_group_and_title() {
let a = parse_agent_action(
r#"{"action":"make_chart","chart_type":"line","x":"month","y":"revenue","group":"region","title":"Revenue"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => {
assert_eq!(config.group.as_deref(), Some("region"));
assert_eq!(config.title.as_deref(), Some("Revenue"));
}
_ => panic!("wrong variant"),
}
}
#[test]
fn make_chart_accepts_alternative_field_name_type() {
// Some models emit `type` instead of `chart_type`.
let a = parse_agent_action(
r#"{"action":"make_chart","type":"pie","x":"label","y":"value"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => assert_eq!(config.chart_type, "pie"),
_ => panic!("wrong variant"),
}
}
#[test]
fn rejects_make_chart_with_unknown_chart_type() {
let r = parse_agent_action(
r#"{"action":"make_chart","chart_type":"radar","x":"a","y":"b"}"#,
);
assert!(r.is_err());
}
#[test]
fn rejects_make_chart_missing_x_or_y() {
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","y":"a"}"#).is_err());
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","x":"a"}"#).is_err());
}
#[test]
fn last_successful_query_result_finds_recent() {
use crate::models::query_result::QueryResult;
let qr = QueryResult {
columns: vec!["a".into()],
types: vec!["INT4".into()],
rows: vec![vec![Value::Number(1.into())]],
row_count: 1,
execution_time_ms: 1,
};
let msgs = vec![
ChatMessage::ToolResult {
id: "r1".into(),
tool: "run_query".into(),
is_error: false,
text: None,
result: Some(qr.clone()),
created_at: 1,
},
ChatMessage::ToolResult {
id: "r2".into(),
tool: "run_query".into(),
is_error: true,
text: Some("oops".into()),
result: None,
created_at: 2,
},
];
let found = last_successful_query_result(&msgs).expect("ok");
assert_eq!(found.columns, vec!["a".to_string()]);
}
#[test]
fn last_successful_query_result_skips_non_run_query() {
use crate::models::query_result::QueryResult;
let qr = QueryResult {
columns: vec!["a".into()],
types: vec!["INT4".into()],
rows: vec![],
row_count: 0,
execution_time_ms: 0,
};
let msgs = vec![ChatMessage::ToolResult {
id: "r1".into(),
tool: "list_tables".into(),
is_error: false,
text: Some("public.x".into()),
result: Some(qr),
created_at: 1,
}];
assert!(last_successful_query_result(&msgs).is_none());
}
#[test]
fn render_thread_for_summary_includes_roles_and_skips_rows() {
let msgs = vec![
@@ -1625,11 +1501,6 @@ mod tests {
assert!(!rendered.contains("alice"));
}
#[test]
fn rejects_legacy_get_schema() {
assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err());
}
#[test]
fn truncates_long_cell() {
let long = "a".repeat(CELL_CHAR_CAP + 50);

View File

@@ -10,11 +10,14 @@ use crate::commands::ai::{
ColumnInfo,
};
use crate::commands::connections::{load_connection_config, switch_database_core};
use crate::commands::queries::execute_query_core;
use crate::commands::saved_queries::{list_saved_queries_core, save_query_core};
use crate::commands::schema::{list_databases_core, list_tables_core};
use crate::db::sql_guard::ensure_readonly_sql;
use crate::error::{TuskError, TuskResult};
use crate::models::saved_queries::SavedQuery;
use crate::state::{AppState, CachedVec, DbFlavor};
use crate::utils::escape_ident;
use sqlx::{PgPool, Row};
use std::collections::{BTreeMap, HashMap};
use std::time::{Duration, Instant};
@@ -565,3 +568,690 @@ pub async fn find_queries_tool(
Ok(out)
}
// ---------------------------------------------------------------------------
// profile_table (PR2 — data-engineering tool)
// ---------------------------------------------------------------------------
const PROFILE_TABLE_MAX_COLUMNS: usize = 30;
const PROFILE_TABLE_TOPK: usize = 5;
pub async fn profile_table_tool(
state: &AppState,
connection_id: &str,
table: &str,
) -> TuskResult<String> {
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
let flavor = state.get_flavor(connection_id).await;
match flavor {
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
profile_table_postgres(state, connection_id, &schema, &tbl).await
}
DbFlavor::ClickHouse => profile_table_clickhouse(state, connection_id, &schema, &tbl).await,
}
}
async fn profile_table_postgres(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
) -> TuskResult<String> {
let pool = state.get_pool(connection_id).await?;
let exists = sqlx::query_scalar::<_, i64>(
"SELECT 1 FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \
WHERE n.nspname = $1 AND c.relname = $2 LIMIT 1",
)
.bind(schema)
.bind(table)
.fetch_optional(&pool)
.await
.map_err(TuskError::Database)?;
if exists.is_none() {
return Err(TuskError::Custom(format!(
"Table '{}.{}' does not exist (or no privileges).",
schema, table
)));
}
let last_analyze: Option<chrono::DateTime<chrono::Utc>> = sqlx::query_scalar(
"SELECT GREATEST(last_analyze, last_autoanalyze) FROM pg_stat_user_tables \
WHERE schemaname = $1 AND relname = $2",
)
.bind(schema)
.bind(table)
.fetch_optional(&pool)
.await
.ok()
.flatten();
let stat_rows = sqlx::query(
"SELECT attname, null_frac, n_distinct, \
most_common_vals::text, most_common_freqs, histogram_bounds::text \
FROM pg_stats \
WHERE schemaname = $1 AND tablename = $2 \
ORDER BY attname",
)
.bind(schema)
.bind(table)
.fetch_all(&pool)
.await
.map_err(TuskError::Database)?;
let mut out = format!("PROFILE {}.{}\n", schema, table);
match last_analyze {
Some(ts) => out.push_str(&format!("Last ANALYZE: {}\n", ts.to_rfc3339())),
None => out.push_str("Last ANALYZE: never\n"),
}
if stat_rows.is_empty() {
out.push_str(&format!(
"\nNo statistics in pg_stats. Run: ANALYZE {}.{};\n",
escape_ident(schema),
escape_ident(table)
));
return Ok(out);
}
let total = stat_rows.len();
let take = total.min(PROFILE_TABLE_MAX_COLUMNS);
out.push_str(&format!("\n{} columns with stats\n", total));
for r in stat_rows.iter().take(take) {
let attname: String = r.get(0);
let null_frac: f32 = r.try_get(1).unwrap_or(0.0);
let n_distinct: f32 = r.try_get(2).unwrap_or(0.0);
let mcv_text: Option<String> = r.try_get(3).ok();
let mcf_arr: Option<Vec<f32>> = r.try_get(4).ok();
let hist_text: Option<String> = r.try_get(5).ok();
out.push_str(&format!("\n {}:\n", attname));
out.push_str(&format!(" null_frac: {:.4}\n", null_frac));
if n_distinct < 0.0 {
out.push_str(&format!(
" n_distinct: {:.3} (ratio of total rows)\n",
-n_distinct
));
} else {
out.push_str(&format!(" n_distinct: {}\n", n_distinct as i64));
}
if let Some(text) = hist_text.as_deref() {
let bounds = parse_pg_array_text_local(text);
if let (Some(min), Some(max)) = (bounds.first(), bounds.last()) {
out.push_str(&format!(" range: {}{}\n", min, max));
}
}
if let Some(text) = mcv_text.as_deref() {
let vals = parse_pg_array_text_local(text);
if !vals.is_empty() {
let freqs = mcf_arr.unwrap_or_default();
let pairs: Vec<String> = vals
.iter()
.take(PROFILE_TABLE_TOPK)
.enumerate()
.map(|(i, v)| match freqs.get(i) {
Some(f) => format!("{}({:.3})", v, f),
None => v.clone(),
})
.collect();
out.push_str(&format!(" top: {}\n", pairs.join(", ")));
}
}
}
if total > take {
out.push_str(&format!("\n…and {} more columns\n", total - take));
}
Ok(out)
}
/// Local pg-array parser used by profile_table; mirrors `parse_pg_array_text` in ai.rs
/// but kept local to avoid importing a private helper.
fn parse_pg_array_text_local(s: &str) -> Vec<String> {
let s = s.trim();
let s = s.strip_prefix('{').unwrap_or(s);
let s = s.strip_suffix('}').unwrap_or(s);
if s.is_empty() {
return Vec::new();
}
let mut out = Vec::new();
let mut cur = String::new();
let mut in_quotes = false;
let mut chars = s.chars().peekable();
while let Some(c) = chars.next() {
match c {
'"' if !in_quotes => in_quotes = true,
'"' if in_quotes => {
if chars.peek() == Some(&'"') {
cur.push('"');
chars.next();
} else {
in_quotes = false;
}
}
',' if !in_quotes => {
out.push(std::mem::take(&mut cur));
}
'\\' if in_quotes => {
if let Some(next) = chars.next() {
cur.push(next);
}
}
other => cur.push(other),
}
}
if !cur.is_empty() || s.ends_with(',') {
out.push(cur);
}
out
}
async fn profile_table_clickhouse(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let active_db = client.database.clone();
let dbn = if schema == "public" || schema.is_empty() {
active_db
} else {
schema.to_string()
};
let cols_sql = format!(
"SELECT name, type FROM system.columns \
WHERE database = '{}' AND table = '{}' \
ORDER BY position LIMIT {}",
dbn.replace('\'', "\\'"),
table.replace('\'', "\\'"),
PROFILE_TABLE_MAX_COLUMNS
);
let col_rows = client.fetch_objects(&cols_sql).await?;
if col_rows.is_empty() {
return Err(TuskError::Custom(format!(
"Table '{}.{}' does not exist (or no privileges).",
dbn, table
)));
}
let mut select_parts: Vec<String> = vec!["count() AS rows_total".to_string()];
let mut col_names: Vec<String> = Vec::new();
let mut col_types: Vec<String> = Vec::new();
for r in &col_rows {
let name = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
let dtype = r.get("type").and_then(|v| v.as_str()).unwrap_or("").to_string();
if name.is_empty() {
continue;
}
col_names.push(name.clone());
col_types.push(dtype);
let q = name.replace('`', "``");
select_parts.push(format!("countIf(`{}` IS NULL) AS null_{}", q, col_names.len()));
select_parts.push(format!("uniqHLL12(`{}`) AS dist_{}", q, col_names.len()));
select_parts.push(format!("toString(min(`{}`)) AS min_{}", q, col_names.len()));
select_parts.push(format!("toString(max(`{}`)) AS max_{}", q, col_names.len()));
select_parts.push(format!(
"arrayStringConcat(arrayMap(x -> toString(x), topK({})(`{}`)), '|') AS top_{}",
PROFILE_TABLE_TOPK,
q,
col_names.len()
));
}
let agg_sql = format!(
"SELECT {} FROM `{}`.`{}`",
select_parts.join(", "),
dbn.replace('`', "``"),
table.replace('`', "``")
);
let agg_rows = client.fetch_objects(&agg_sql).await?;
let row = agg_rows
.first()
.ok_or_else(|| TuskError::Custom("ClickHouse returned no row for profile aggregate".into()))?;
let rows_total = row
.get("rows_total")
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
.unwrap_or(0);
let mut out = format!(
"PROFILE {}.{}\nRows: {}\n{} columns profiled\n",
dbn,
table,
rows_total,
col_names.len()
);
for (i, name) in col_names.iter().enumerate() {
let n = i + 1;
let nulls = row
.get(&format!("null_{}", n))
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
.unwrap_or(0);
let dist = row
.get(&format!("dist_{}", n))
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
.unwrap_or(0);
let min = row.get(&format!("min_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
let max = row.get(&format!("max_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
let top_raw = row.get(&format!("top_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
out.push_str(&format!("\n {} ({}):\n", name, col_types[i]));
let null_frac = if rows_total > 0 {
nulls as f64 / rows_total as f64
} else {
0.0
};
out.push_str(&format!(" null_frac: {:.4}\n", null_frac));
out.push_str(&format!(" distinct (HLL): {}\n", dist));
if !min.is_empty() || !max.is_empty() {
out.push_str(&format!(" range: {}{}\n", min, max));
}
if !top_raw.is_empty() {
let top_vals: Vec<&str> = top_raw.split('|').take(PROFILE_TABLE_TOPK).collect();
out.push_str(&format!(" top: {}\n", top_vals.join(", ")));
}
}
if col_rows.len() == PROFILE_TABLE_MAX_COLUMNS {
out.push_str(&format!(
"\n…showing first {} columns\n",
PROFILE_TABLE_MAX_COLUMNS
));
}
Ok(out)
}
// ---------------------------------------------------------------------------
// sample_data (PR2 — returns SQL string; dispatch site runs it through
// execute_query_core so the QueryResult feeds the standard renderer)
// ---------------------------------------------------------------------------
pub async fn build_sample_sql(
state: &AppState,
connection_id: &str,
table: &str,
limit: u32,
) -> TuskResult<String> {
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
let flavor = state.get_flavor(connection_id).await;
match flavor {
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
build_sample_sql_postgres(state, connection_id, &schema, &tbl, limit).await
}
DbFlavor::ClickHouse => {
build_sample_sql_clickhouse(state, connection_id, &schema, &tbl, limit).await
}
}
}
async fn build_sample_sql_postgres(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
limit: u32,
) -> TuskResult<String> {
let pool = state.get_pool(connection_id).await?;
let reltuples: f64 = sqlx::query_scalar(
"SELECT c.reltuples FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \
WHERE n.nspname = $1 AND c.relname = $2",
)
.bind(schema)
.bind(table)
.fetch_optional(&pool)
.await
.map_err(TuskError::Database)?
.unwrap_or(0.0);
let qualified = format!("{}.{}", escape_ident(schema), escape_ident(table));
if reltuples > 0.0 {
let target = limit as f64 * 100.0 / reltuples;
let percent = target.clamp(0.01, 100.0);
Ok(format!(
"SELECT * FROM {} TABLESAMPLE BERNOULLI({:.4}) LIMIT {}",
qualified, percent, limit
))
} else {
Ok(format!(
"SELECT * FROM {} ORDER BY random() LIMIT {}",
qualified, limit
))
}
}
async fn build_sample_sql_clickhouse(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
limit: u32,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let active_db = client.database.clone();
let dbn = if schema == "public" || schema.is_empty() {
active_db
} else {
schema.to_string()
};
let info_sql = format!(
"SELECT engine, sampling_key FROM system.tables \
WHERE database = '{}' AND name = '{}' LIMIT 1",
dbn.replace('\'', "\\'"),
table.replace('\'', "\\'")
);
let rows = client.fetch_objects(&info_sql).await.unwrap_or_default();
let (engine, sampling_key) = match rows.first() {
Some(r) => (
r.get("engine").and_then(|v| v.as_str()).unwrap_or("").to_string(),
r.get("sampling_key").and_then(|v| v.as_str()).unwrap_or("").to_string(),
),
None => (String::new(), String::new()),
};
let qualified = format!(
"`{}`.`{}`",
dbn.replace('`', "``"),
table.replace('`', "``")
);
if engine.starts_with("Merge") && !sampling_key.trim().is_empty() {
Ok(format!(
"SELECT * FROM {} SAMPLE 0.01 LIMIT {}",
qualified, limit
))
} else {
Ok(format!(
"SELECT * FROM {} ORDER BY rand() LIMIT {}",
qualified, limit
))
}
}
// ---------------------------------------------------------------------------
// explain_query (PR2)
// ---------------------------------------------------------------------------
pub async fn explain_query_tool(
state: &AppState,
connection_id: &str,
sql: &str,
) -> TuskResult<String> {
let trimmed = sql.trim();
if trimmed.is_empty() {
return Err(TuskError::Custom("explain_query: sql must not be empty".into()));
}
// Validate the user's statement BEFORE prefixing EXPLAIN so the error message
// references their SQL, not the wrapper. ensure_readonly_sql also rejects any
// forbidden keywords (INSERT/UPDATE/DELETE/...) even nested under EXPLAIN.
ensure_readonly_sql(trimmed).map_err(|e| TuskError::Custom(e.to_string()))?;
let flavor = state.get_flavor(connection_id).await;
match flavor {
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
explain_query_postgres(state, connection_id, trimmed).await
}
DbFlavor::ClickHouse => explain_query_clickhouse(state, connection_id, trimmed).await,
}
}
async fn explain_query_postgres(
state: &AppState,
connection_id: &str,
sql: &str,
) -> TuskResult<String> {
let pool = state.get_pool(connection_id).await?;
let plan_sql = format!("EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) {}", sql);
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
sqlx::query("SET TRANSACTION READ ONLY")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
let row = sqlx::query(&plan_sql)
.fetch_one(&mut *tx)
.await
.map_err(TuskError::Database)?;
let _ = tx.rollback().await;
let raw_json: serde_json::Value = match row.try_get::<serde_json::Value, _>(0) {
Ok(v) => v,
Err(_) => {
let s: String = row.try_get(0).map_err(TuskError::Database)?;
serde_json::from_str(&s)
.map_err(|e| TuskError::Custom(format!("EXPLAIN JSON parse failed: {}", e)))?
}
};
let plans = raw_json
.as_array()
.ok_or_else(|| TuskError::Custom("EXPLAIN JSON: expected array".into()))?;
let plan = plans.first().and_then(|p| p.get("Plan")).ok_or_else(|| {
TuskError::Custom("EXPLAIN JSON: missing top-level Plan node".into())
})?;
let root_node = plan.get("Node Type").and_then(|v| v.as_str()).unwrap_or("?");
let total_cost = plan.get("Total Cost").and_then(|v| v.as_f64()).unwrap_or(0.0);
let planning = plans
.first()
.and_then(|p| p.get("Planning Time").and_then(|v| v.as_f64()))
.unwrap_or(0.0);
let execution = plans
.first()
.and_then(|p| p.get("Execution Time").and_then(|v| v.as_f64()))
.unwrap_or(0.0);
let mut seq_scans: Vec<String> = Vec::new();
let mut spilled: Vec<String> = Vec::new();
let mut motions: Vec<String> = Vec::new();
let mut max_skew: Option<(f64, String)> = None;
walk_pg_plan(plan, &mut seq_scans, &mut spilled, &mut motions, &mut max_skew);
let mut out = format!(
"PLAN root: {}, total cost {:.1}\nPlanning: {:.2} ms Execution: {:.2} ms\n",
root_node, total_cost, planning, execution
);
if !seq_scans.is_empty() {
out.push_str(&format!("Seq scans on: {}\n", seq_scans.join(", ")));
}
if !spilled.is_empty() {
out.push_str(&format!("Spilled to disk: {}\n", spilled.join(", ")));
}
if !motions.is_empty() {
out.push_str(&format!("Motions (Greenplum): {}\n", motions.join(", ")));
}
if let Some((ratio, node)) = max_skew {
if ratio >= 5.0 {
out.push_str(&format!(
"Estimate skew: max plan/actual ratio = {:.1} on {}\n",
ratio, node
));
}
}
if seq_scans.is_empty() && spilled.is_empty() && motions.is_empty() {
out.push_str("No obvious red flags.\n");
}
Ok(out)
}
fn walk_pg_plan(
node: &serde_json::Value,
seq_scans: &mut Vec<String>,
spilled: &mut Vec<String>,
motions: &mut Vec<String>,
max_skew: &mut Option<(f64, String)>,
) {
let node_type = node.get("Node Type").and_then(|v| v.as_str()).unwrap_or("");
if node_type == "Seq Scan" {
let rel = node
.get("Relation Name")
.and_then(|v| v.as_str())
.unwrap_or("?");
let schema = node
.get("Schema")
.and_then(|v| v.as_str())
.map(|s| format!("{}.", s))
.unwrap_or_default();
seq_scans.push(format!("{}{}", schema, rel));
}
if let Some(method) = node.get("Sort Method").and_then(|v| v.as_str()) {
if method.contains("disk") || method.contains("external") {
spilled.push(format!("Sort ({})", method));
}
}
if node_type.contains("Motion") {
motions.push(node_type.to_string());
}
let plan_rows = node.get("Plan Rows").and_then(|v| v.as_f64()).unwrap_or(0.0);
let actual_rows = node.get("Actual Rows").and_then(|v| v.as_f64()).unwrap_or(0.0);
if actual_rows > 0.0 && plan_rows > 0.0 {
let ratio = (plan_rows / actual_rows).max(actual_rows / plan_rows);
if max_skew.as_ref().map(|(r, _)| ratio > *r).unwrap_or(true) {
*max_skew = Some((ratio, node_type.to_string()));
}
}
if let Some(children) = node.get("Plans").and_then(|v| v.as_array()) {
for child in children {
walk_pg_plan(child, seq_scans, spilled, motions, max_skew);
}
}
}
async fn explain_query_clickhouse(
state: &AppState,
connection_id: &str,
sql: &str,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let plan_sql = format!("EXPLAIN PLAN {}", sql);
let qr = client.execute_query(&plan_sql, true).await?;
if qr.rows.is_empty() {
return Ok("(empty plan)".to_string());
}
let mut out = String::from("ClickHouse plan:\n");
for row in &qr.rows {
if let Some(cell) = row.first() {
if let Some(s) = cell.as_str() {
out.push_str(s);
out.push('\n');
}
}
}
Ok(out)
}
// ---------------------------------------------------------------------------
// detect_skew (PR2 — Greenplum-only)
// ---------------------------------------------------------------------------
pub async fn detect_skew_tool(
state: &AppState,
connection_id: &str,
table: &str,
) -> TuskResult<String> {
let flavor = state.get_flavor(connection_id).await;
if !matches!(flavor, DbFlavor::Greenplum) {
return Ok("detect_skew is only available on Greenplum connections.".to_string());
}
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&tbl));
let sql = format!(
"SELECT gp_segment_id, COUNT(*) AS n FROM {} GROUP BY 1 ORDER BY 1",
qualified
);
let qr = execute_query_core(state, connection_id, &sql).await?;
let mut counts: Vec<(i64, i64)> = Vec::new();
for row in &qr.rows {
let seg = row
.get(0)
.and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok())))
.unwrap_or(0);
let n = row
.get(1)
.and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok())))
.unwrap_or(0);
counts.push((seg, n));
}
if counts.is_empty() {
return Ok(format!("Table {}.{} is empty.", schema, tbl));
}
let total: i64 = counts.iter().map(|(_, n)| *n).sum();
let max = counts.iter().map(|(_, n)| *n).max().unwrap_or(0);
let min = counts.iter().map(|(_, n)| *n).min().unwrap_or(0);
let avg = total as f64 / counts.len() as f64;
let ratio = if avg > 0.0 { max as f64 / avg } else { 0.0 };
let mut out = format!(
"Per-segment row distribution for {}.{}\nsegments: {} total rows: {}\nmin: {} max: {} avg: {:.0}\nskew ratio (max/avg): {:.2}",
schema,
tbl,
counts.len(),
total,
min,
max,
avg,
ratio
);
if ratio > 1.5 {
out.push_str(" ⚠ uneven distribution\n");
} else {
out.push_str(" OK — within 1.5x of average\n");
}
let pool = state.get_pool(connection_id).await?;
if let Some(policy) = fetch_gp_distribution_for(&pool, &schema, &tbl).await {
out.push_str(&format!("\nCurrent policy: {}\n", policy));
if ratio > 1.5 {
out.push_str(
"Hint: pick a higher-cardinality column. Run profile_table to compare n_distinct.\n",
);
}
}
Ok(out)
}
/// Fetch the Greenplum DISTRIBUTED BY policy for a single table. Returns None if
/// the catalog query fails (non-GP connection, missing privileges, etc.).
async fn fetch_gp_distribution_for(
pool: &PgPool,
schema: &str,
table: &str,
) -> Option<String> {
let row = sqlx::query(
"SELECT COALESCE(\
(SELECT array_agg(a.attname ORDER BY ord.idx) \
FROM regexp_split_to_table(NULLIF(trim(p.distkey::text), ''), ' ') \
WITH ORDINALITY AS ord(attnum_str, idx) \
JOIN pg_attribute a \
ON a.attrelid = c.oid \
AND a.attnum::int = ord.attnum_str::int), \
ARRAY[]::text[] \
) AS dist_columns \
FROM gp_distribution_policy p \
JOIN pg_class c ON p.localoid = c.oid \
JOIN pg_namespace n ON c.relnamespace = n.oid \
WHERE n.nspname = $1 AND c.relname = $2",
)
.bind(schema)
.bind(table)
.fetch_optional(pool)
.await
.ok()
.flatten()?;
let cols: Vec<String> = row.try_get(0).ok()?;
Some(if cols.is_empty() {
"DISTRIBUTED RANDOMLY".to_string()
} else {
format!("DISTRIBUTED BY ({})", cols.join(", "))
})
}

View File

@@ -111,9 +111,7 @@ pub fn run() {
commands::ai::save_ai_settings,
commands::ai::list_ollama_models,
commands::ai::list_fireworks_models,
commands::ai::generate_sql,
commands::ai::explain_sql,
commands::ai::fix_sql_error,
commands::ai::list_openrouter_models,
// chat
commands::chat::chat_send,
commands::chat::chat_compact,

View File

@@ -1,13 +1,26 @@
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AiProvider {
#[default]
Ollama,
OpenAi,
Anthropic,
Fireworks,
OpenRouter,
}
/// Deserialize a provider string, coercing legacy `openai`/`anthropic` and any
/// unknown value to `Ollama`. Keeps existing config files loadable after the
/// stub providers were removed.
impl<'de> Deserialize<'de> for AiProvider {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
Ok(match s.as_str() {
"fireworks" => AiProvider::Fireworks,
"openrouter" => AiProvider::OpenRouter,
_ => AiProvider::Ollama,
})
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -15,11 +28,9 @@ pub struct AiSettings {
pub provider: AiProvider,
pub ollama_url: String,
#[serde(default)]
pub openai_api_key: Option<String>,
#[serde(default)]
pub anthropic_api_key: Option<String>,
#[serde(default)]
pub fireworks_api_key: Option<String>,
#[serde(default)]
pub openrouter_api_key: Option<String>,
pub model: String,
}
@@ -28,9 +39,8 @@ impl Default for AiSettings {
Self {
provider: AiProvider::Ollama,
ollama_url: "http://localhost:11434".to_string(),
openai_api_key: None,
anthropic_api_key: None,
fireworks_api_key: None,
openrouter_api_key: None,
model: String::new(),
}
}
@@ -71,7 +81,9 @@ pub struct OllamaModel {
}
// ---------------------------------------------------------------------------
// Fireworks (OpenAI-compatible chat-completions)
// OpenAI-compatible chat-completions (Fireworks, OpenRouter)
// These request/response shapes are shared by every OpenAI-compatible provider;
// the `Fireworks*` names are retained for historical reasons.
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize)]

View File

@@ -31,18 +31,3 @@ pub struct ChatTurnResult {
pub messages: Vec<ChatMessage>,
pub usage: ContextUsage,
}
/// Chart configuration produced by the agent's `make_chart` tool.
/// Embedded as JSON in `ToolResult.text` for tool == "make_chart" while the
/// underlying data lives in `ToolResult.result`. The frontend reads both to
/// render the chart inline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChartConfig {
pub chart_type: String, // "bar" | "line" | "area" | "pie"
pub x: String, // column name for X axis / category
pub y: String, // column name for Y axis / numeric value
pub group: Option<String>, // optional column for series grouping
pub title: Option<String>,
pub orientation: Option<String>, // "vertical" | "horizontal" — bar only
}

View File

@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use std::time::Instant;
use tokio::sync::{watch, RwLock};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@@ -17,12 +17,6 @@ pub enum DbFlavor {
}
#[derive(Clone)]
pub struct SchemaCacheEntry {
pub schema_text: String,
pub cached_at: Instant,
}
#[derive(Clone)]
pub struct CachedString {
pub value: String,
@@ -43,23 +37,16 @@ pub struct AppState {
/// Greenplum major version (6 or 7), tracked separately because GP6 and GP7
/// expose very different system catalogs (GP6 = PG9.4 base, GP7 = PG14 base).
pub gp_majors: RwLock<HashMap<String, u8>>,
/// Legacy cache used by generate_sql/explain_sql/fix_sql_error — full DDL.
pub schema_cache: RwLock<HashMap<String, SchemaCacheEntry>>,
/// Chat v2 caches: lite overview per connection.
/// Chat agent caches: lite overview per connection.
pub overview_cache: RwLock<HashMap<String, CachedString>>,
/// Chat v2 caches: list of tables per (connection_id, db_name) — used for
/// Chat agent caches: list of tables per (connection_id, db_name) — used for
/// list_tables on a non-active PG database via temporary pool.
pub tables_by_db_cache: RwLock<HashMap<(String, String), CachedVec<String>>>,
/// Chat v2 caches: column block per (connection_id, db_name, "schema.table").
pub columns_cache: RwLock<HashMap<(String, String, String), CachedString>>,
pub mcp_shutdown_tx: watch::Sender<bool>,
pub mcp_running: RwLock<bool>,
pub ai_settings: RwLock<Option<AiSettings>>,
}
const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes
const SCHEMA_CACHE_MAX_SIZE: usize = 100;
impl AppState {
pub fn new() -> Self {
let (mcp_shutdown_tx, _) = watch::channel(false);
@@ -69,10 +56,8 @@ impl AppState {
read_only: RwLock::new(HashMap::new()),
db_flavors: RwLock::new(HashMap::new()),
gp_majors: RwLock::new(HashMap::new()),
schema_cache: RwLock::new(HashMap::new()),
overview_cache: RwLock::new(HashMap::new()),
tables_by_db_cache: RwLock::new(HashMap::new()),
columns_cache: RwLock::new(HashMap::new()),
mcp_shutdown_tx,
mcp_running: RwLock::new(false),
ai_settings: RwLock::new(None),
@@ -82,16 +67,11 @@ impl AppState {
/// Drop every chat-agent cache entry tied to this connection.
/// Called by switch_database_core, disconnect, and on connection delete.
pub async fn invalidate_chat_caches_for(&self, connection_id: &str) {
self.schema_cache.write().await.remove(connection_id);
self.overview_cache.write().await.remove(connection_id);
self.tables_by_db_cache
.write()
.await
.retain(|(cid, _), _| cid != connection_id);
self.columns_cache
.write()
.await
.retain(|(cid, _, _), _| cid != connection_id);
}
pub async fn get_pool(&self, connection_id: &str) -> TuskResult<PgPool> {
@@ -125,39 +105,4 @@ impl AppState {
pub async fn get_gp_major(&self, id: &str) -> Option<u8> {
self.gp_majors.read().await.get(id).copied()
}
pub async fn get_schema_cache(&self, connection_id: &str) -> Option<String> {
let cache = self.schema_cache.read().await;
cache.get(connection_id).and_then(|entry| {
if entry.cached_at.elapsed() < SCHEMA_CACHE_TTL {
Some(entry.schema_text.clone())
} else {
None
}
})
}
pub async fn set_schema_cache(&self, connection_id: String, schema_text: String) {
let mut cache = self.schema_cache.write().await;
// Evict stale entries to prevent unbounded memory growth
cache.retain(|_, entry| entry.cached_at.elapsed() < SCHEMA_CACHE_TTL);
// If still at capacity, remove the oldest entry
if cache.len() >= SCHEMA_CACHE_MAX_SIZE {
if let Some(oldest_key) = cache
.iter()
.min_by_key(|(_, e)| e.cached_at)
.map(|(k, _)| k.clone())
{
cache.remove(&oldest_key);
}
}
cache.insert(
connection_id,
SchemaCacheEntry {
schema_text,
cached_at: Instant::now(),
},
);
}
}