refactor(ai): consolidate AI around chat tool-calling; add OpenRouter

- rework chat backend (chat.rs, chat_tools.rs, ai.rs, models, state) around tool calls
- add OpenRouter provider alongside Ollama/Fireworks in settings
- drop inline AiBar, ResultsPanel explain/fix UI and ChartPreview in favour of the chat panel
- add frontend chat tool-registry
This commit is contained in:
2026-05-23 15:01:52 +03:00
parent a485cf7ee3
commit 0cba457fb7
19 changed files with 1244 additions and 1931 deletions

View File

@@ -15,6 +15,13 @@ use tauri::{AppHandle, Manager, State};
const MAX_RETRIES: u32 = 2; const MAX_RETRIES: u32 = 2;
const RETRY_DELAY_MS: u64 = 1000; const RETRY_DELAY_MS: u64 = 1000;
const FIREWORKS_BASE_URL: &str = "https://api.fireworks.ai/inference/v1"; const FIREWORKS_BASE_URL: &str = "https://api.fireworks.ai/inference/v1";
const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
/// Optional attribution headers OpenRouter uses for its public app leaderboard.
/// Harmless to send on every request; ignored by other OpenAI-compatible APIs.
const OPENROUTER_HEADERS: &[(&str, &str)] = &[
("HTTP-Referer", "https://github.com/codelab/tusk"),
("X-Title", "Tusk"),
];
fn http_client() -> &'static reqwest::Client { fn http_client() -> &'static reqwest::Client {
use std::sync::LazyLock; use std::sync::LazyLock;
@@ -89,32 +96,51 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaMode
#[tauri::command] #[tauri::command]
pub async fn list_fireworks_models(api_key: String) -> TuskResult<Vec<OllamaModel>> { pub async fn list_fireworks_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
list_openai_compatible_models(FIREWORKS_BASE_URL, &api_key, "Fireworks", &[]).await
}
#[tauri::command]
pub async fn list_openrouter_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
list_openai_compatible_models(OPENROUTER_BASE_URL, &api_key, "OpenRouter", OPENROUTER_HEADERS)
.await
}
/// List the available chat models for any OpenAI-compatible provider via its
/// `GET {base_url}/models` endpoint. `extra_headers` carries provider-specific
/// attribution headers (OpenRouter recommends `HTTP-Referer`/`X-Title`).
async fn list_openai_compatible_models(
base_url: &str,
api_key: &str,
provider_label: &str,
extra_headers: &[(&str, &str)],
) -> TuskResult<Vec<OllamaModel>> {
let key = api_key.trim(); let key = api_key.trim();
if key.is_empty() { if key.is_empty() {
return Err(TuskError::Ai("Fireworks API key required".to_string())); return Err(TuskError::Ai(format!("{} API key required", provider_label)));
} }
let url = format!("{}/models", FIREWORKS_BASE_URL); let url = format!("{}/models", base_url);
let resp = http_client() let mut req = http_client().get(&url).bearer_auth(key);
.get(&url) for (name, value) in extra_headers {
.bearer_auth(key) req = req.header(*name, *value);
}
let resp = req
.send() .send()
.await .await
.map_err(|e| TuskError::Ai(format!("Cannot reach Fireworks: {}", e)))?; .map_err(|e| TuskError::Ai(format!("Cannot reach {}: {}", provider_label, e)))?;
if !resp.status().is_success() { if !resp.status().is_success() {
let status = resp.status(); let status = resp.status();
let body = resp.text().await.unwrap_or_default(); let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!( return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}", "{} error ({}): {}",
status, body provider_label, status, body
))); )));
} }
let parsed: FireworksModelsResponse = resp let parsed: FireworksModelsResponse = resp.json().await.map_err(|e| {
.json() TuskError::Ai(format!("Failed to parse {} models list: {}", provider_label, e))
.await })?;
.map_err(|e| TuskError::Ai(format!("Failed to parse Fireworks models list: {}", e)))?;
Ok(parsed Ok(parsed
.data .data
@@ -180,33 +206,8 @@ pub(crate) async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskR
Ok(settings) Ok(settings)
} }
async fn call_chat_simple( /// Provider-agnostic chat-completions dispatcher used by the chat agent.
app: &AppHandle, /// Returns the model's raw text content.
state: &AppState,
system_prompt: String,
user_content: String,
) -> TuskResult<String> {
call_chat_messages(
app,
state,
vec![
OllamaChatMessage {
role: "system".to_string(),
content: system_prompt,
},
OllamaChatMessage {
role: "user".to_string(),
content: user_content,
},
],
None,
)
.await
}
/// Provider-agnostic chat-completions dispatcher used by every LLM-driven feature
/// (chat agent, generate_sql, explain_sql, fix_sql_error). Returns the model's
/// raw text content.
pub(crate) async fn call_chat_messages( pub(crate) async fn call_chat_messages(
app: &AppHandle, app: &AppHandle,
state: &AppState, state: &AppState,
@@ -223,14 +224,49 @@ pub(crate) async fn call_chat_messages(
match settings.provider { match settings.provider {
AiProvider::Ollama => call_ollama(&settings, messages, format).await, AiProvider::Ollama => call_ollama(&settings, messages, format).await,
AiProvider::Fireworks => call_fireworks(&settings, messages, format).await, AiProvider::Fireworks => {
AiProvider::OpenAi | AiProvider::Anthropic => Err(TuskError::Ai(format!( let api_key = require_api_key(
"Provider {:?} not implemented yet", settings.fireworks_api_key.as_deref(),
settings.provider "Fireworks API key not set. Open AI settings to add it.",
))), )?;
call_openai_compatible(
&settings,
FIREWORKS_BASE_URL,
&api_key,
"Fireworks",
&[],
messages,
format,
)
.await
}
AiProvider::OpenRouter => {
let api_key = require_api_key(
settings.openrouter_api_key.as_deref(),
"OpenRouter API key not set. Open AI settings to add it.",
)?;
call_openai_compatible(
&settings,
OPENROUTER_BASE_URL,
&api_key,
"OpenRouter",
OPENROUTER_HEADERS,
messages,
format,
)
.await
}
} }
} }
/// Trim and validate an optional API key, returning a user-facing error when
/// it's missing or blank.
fn require_api_key(key: Option<&str>, missing_msg: &str) -> TuskResult<String> {
key.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.ok_or_else(|| TuskError::Ai(missing_msg.to_string()))
}
async fn call_ollama( async fn call_ollama(
settings: &AiSettings, settings: &AiSettings,
messages: Vec<OllamaChatMessage>, messages: Vec<OllamaChatMessage>,
@@ -277,21 +313,18 @@ async fn call_ollama(
.await .await
} }
async fn call_fireworks( /// Chat-completions call for any OpenAI-compatible provider (Fireworks,
/// OpenRouter). `extra_headers` carries provider-specific attribution headers.
async fn call_openai_compatible(
settings: &AiSettings, settings: &AiSettings,
base_url: &str,
api_key: &str,
provider_label: &str,
extra_headers: &[(&str, &str)],
messages: Vec<OllamaChatMessage>, messages: Vec<OllamaChatMessage>,
format: Option<String>, format: Option<String>,
) -> TuskResult<String> { ) -> TuskResult<String> {
let api_key = settings let url = format!("{}/chat/completions", base_url);
.fireworks_api_key
.clone()
.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.ok_or_else(|| {
TuskError::Ai("Fireworks API key not set. Open AI settings to add it.".to_string())
})?;
let url = format!("{}/chat/completions", FIREWORKS_BASE_URL);
let response_format = format.as_deref().map(|f| FireworksResponseFormat { let response_format = format.as_deref().map(|f| FireworksResponseFormat {
kind: if f == "json" { kind: if f == "json" {
"json_object".to_string() "json_object".to_string()
@@ -307,19 +340,22 @@ async fn call_fireworks(
response_format, response_format,
}; };
call_ai_with_retry(settings, "Fireworks request", || { let operation = format!("{} request", provider_label);
call_ai_with_retry(settings, &operation, || {
let url = url.clone(); let url = url.clone();
let request = request.clone(); let request = request.clone();
let api_key = api_key.clone(); let api_key = api_key.to_string();
async move { async move {
let resp = http_client() let mut req = http_client().post(&url).bearer_auth(&api_key);
.post(&url) for (name, value) in extra_headers {
.bearer_auth(&api_key) req = req.header(*name, *value);
}
let resp = req
.json(&request) .json(&request)
.send() .send()
.await .await
.map_err(|e| { .map_err(|e| {
TuskError::Ai(format!("Cannot reach Fireworks at {}: {}", url, e)) TuskError::Ai(format!("Cannot reach {} at {}: {}", provider_label, url, e))
})?; })?;
if !resp.status().is_success() { if !resp.status().is_success() {
@@ -339,13 +375,13 @@ async fn call_fireworks(
)); ));
} }
return Err(TuskError::Ai(format!( return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}", "{} error ({}): {}",
status, body provider_label, status, body
))); )));
} }
let parsed: FireworksChatResponse = resp.json().await.map_err(|e| { let parsed: FireworksChatResponse = resp.json().await.map_err(|e| {
TuskError::Ai(format!("Failed to parse Fireworks response: {}", e)) TuskError::Ai(format!("Failed to parse {} response: {}", provider_label, e))
})?; })?;
parsed parsed
@@ -353,177 +389,14 @@ async fn call_fireworks(
.into_iter() .into_iter()
.next() .next()
.map(|c| c.message.content) .map(|c| c.message.content)
.ok_or_else(|| TuskError::Ai("Fireworks returned no choices".to_string())) .ok_or_else(|| {
TuskError::Ai(format!("{} returned no choices", provider_label))
})
} }
}) })
.await .await
} }
// ---------------------------------------------------------------------------
// SQL generation
// ---------------------------------------------------------------------------
#[tauri::command]
pub async fn generate_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
prompt: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are an expert PostgreSQL query generator. You receive a database schema and a natural \
language request. Output ONLY a valid, executable PostgreSQL SQL query.\n\
\n\
OUTPUT FORMAT:\n\
- Raw SQL only. No explanations, no markdown code fences (```), no comments, no preamble.\n\
- The output must be directly executable in psql.\n\
- For complex queries use readable formatting with line breaks and indentation.\n\
\n\
CRITICAL RULES:\n\
1. ONLY reference tables and columns that exist in the schema. Never invent names.\n\
2. Use the FOREIGN KEY information to determine correct JOIN conditions.\n\
3. Use LEFT JOIN when the FK column is nullable or the relationship is optional; \
INNER JOIN when both sides must exist.\n\
4. Every non-aggregated column in SELECT must appear in GROUP BY.\n\
5. Use COALESCE for nullable columns in aggregations: COALESCE(SUM(x), 0).\n\
6. For enum columns, use ONLY the values listed in the ENUM TYPES section.\n\
7. Use IS NULL / IS NOT NULL for null checks — never = NULL or != NULL.\n\
8. Add LIMIT when the user asks for \"top N\", \"first N\", \"latest N\", etc.\n\
9. Qualify column names with table alias when the query involves multiple tables.\n\
\n\
SEMANTIC RULES (very important):\n\
- When a table has both actual_* and planned_* columns (e.g. actual_start vs planned_start), \
they represent DIFFERENT concepts: planned = future estimate, actual = what really happened. \
NEVER mix them with COALESCE unless the user explicitly requests a fallback.\n\
- For time-based calculations involving real events (\"how long did X take\", \"average time between\"), \
use ONLY actual/factual timestamps (actual_*, started_at, completed_at, ended_at). \
Filter out NULL values with WHERE instead of falling back to planned timestamps.\n\
- Planned timestamps (planned_*, scheduled_*, estimated_*) should ONLY be used when the user \
asks about plans, schedules, SLA, or compares plan vs fact.\n\
- When computing durations or averages, always filter out rows where any involved timestamp \
is NULL rather than substituting with unrelated defaults.\n\
- Pay attention to column descriptions/comments in the schema — they reveal business semantics \
that are critical for correct queries.\n\
\n\
TYPE RULES:\n\
- timestamp - timestamp = interval. For seconds: EXTRACT(EPOCH FROM (ts1 - ts2)).\n\
- interval cannot be cast to numeric directly; use EXTRACT(EPOCH FROM interval).\n\
- UNION/UNION ALL requires matching column count and compatible types; cast enums to text.\n\
- Use ::type for PostgreSQL-style casts.\n\
- For array columns use ANY, ALL, @>, <@ operators.\n\
- For JSONB columns use ->, ->>, #>, jsonb_extract_path.\n\
\n\
COMMON PATTERNS:\n\
- FIRST/LAST per group: to find MIN(started_at) per trip, use \
\"SELECT trip_id, MIN(started_at) FROM t GROUP BY trip_id\". \
NEVER put the aggregated column (started_at) into GROUP BY — that defeats the aggregation \
and returns every row separately instead of one per group.\n\
- TOP-1 per group with extra columns: use DISTINCT ON (group_col) ... ORDER BY group_col, sort_col \
or a subquery with ROW_NUMBER() OVER (PARTITION BY group_col ORDER BY sort_col) = 1.\n\
- For \"time from A to B\" calculations, ensure both timestamps are NOT NULL with WHERE filters; \
never use COALESCE to mix planned and actual timestamps.\n\
\n\
BEST PRACTICES:\n\
- Use ILIKE for case-insensitive text search, LIKE for case-sensitive.\n\
- Use EXISTS instead of IN for subquery existence checks.\n\
- Use CTE (WITH ... AS) for complex multi-step logic.\n\
- Use window functions (ROW_NUMBER, RANK, LAG, LEAD, SUM OVER) for ranking and running totals.\n\
- Use date_trunc('period', column) for time-based grouping.\n\
- Use generate_series() for creating ranges.\n\
- Use string_agg(col, ', ') for concatenating grouped values.\n\
- Use FILTER (WHERE ...) for conditional aggregation instead of CASE inside aggregate.\n\
\n\
{}\n",
schema_text
);
let raw = call_chat_simple(&app, &state, system_prompt, prompt).await?;
Ok(clean_sql_response(&raw))
}
// ---------------------------------------------------------------------------
// SQL explanation
// ---------------------------------------------------------------------------
#[tauri::command]
pub async fn explain_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a PostgreSQL expert. Explain the given SQL query clearly and concisely.\n\
\n\
Structure your explanation as:\n\
1. **Summary** — one sentence describing what the query returns in business terms.\n\
2. **Step-by-step breakdown** — explain tables accessed, joins, filters, aggregations, \
subqueries, and sorting. Use bullet points.\n\
3. **Notes** — mention edge cases, potential issues, or performance concerns if any.\n\
\n\
Use the database schema below to understand table relationships and column meanings.\n\
Keep the explanation short; avoid restating the SQL verbatim.\n\
\n\
IMPORTANT: If you notice semantic issues (e.g. mixing planned_* and actual_* timestamps \
with COALESCE, comparing unrelated columns, missing NULL filters on nullable timestamps), \
mention them in the Notes section as potential problems.\n\
\n\
{}\n",
schema_text
);
call_chat_simple(&app, &state, system_prompt, sql).await
}
// ---------------------------------------------------------------------------
// SQL error fixing
// ---------------------------------------------------------------------------
#[tauri::command]
pub async fn fix_sql_error(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
error_message: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a PostgreSQL expert debugger. You receive a SQL query and the error it produced. \
Fix the query so it executes correctly.\n\
\n\
OUTPUT FORMAT:\n\
- Raw SQL only. No explanations, no markdown code fences (```), no comments.\n\
- The output must be directly executable.\n\
\n\
DIAGNOSTIC CHECKLIST:\n\
- Column/table does not exist → check the schema for correct names and spelling.\n\
- Column is ambiguous → qualify with table name or alias.\n\
- Must appear in GROUP BY → add missing non-aggregated columns to GROUP BY.\n\
- Type mismatch → add appropriate casts (::text, ::integer, etc.).\n\
- Permission denied → wrap in a read-only transaction if needed.\n\
- Syntax error → correct PostgreSQL syntax (check commas, parentheses, keywords).\n\
- Subquery returns more than one row → use IN, ANY, or add LIMIT 1.\n\
- Division by zero → wrap divisor with NULLIF(x, 0).\n\
\n\
ONLY use tables and columns from the schema below. Never invent names.\n\
Preserve the original intent of the query; change only what is necessary to fix the error.\n\
\n\
{}\n",
schema_text
);
let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);
let raw = call_chat_simple(&app, &state, system_prompt, user_content).await?;
Ok(clean_sql_response(&raw))
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Lite overview builder (chat v2) // Lite overview builder (chat v2)
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -735,231 +608,6 @@ async fn build_overview_clickhouse(state: &AppState, connection_id: &str) -> Tus
Ok(out.join("\n")) Ok(out.join("\n"))
} }
// ---------------------------------------------------------------------------
// Full schema context builder (legacy — used by generate_sql/explain_sql/fix_sql_error)
// ---------------------------------------------------------------------------
pub(crate) async fn build_schema_context(
state: &AppState,
connection_id: &str,
) -> TuskResult<String> {
// Check cache first
if let Some(cached) = state.get_schema_cache(connection_id).await {
return Ok(cached);
}
let flavor = state.get_flavor(connection_id).await;
if matches!(flavor, DbFlavor::ClickHouse) {
return build_clickhouse_schema_context(state, connection_id).await;
}
let is_greenplum = matches!(flavor, DbFlavor::Greenplum);
let gp_major = state.get_gp_major(connection_id).await.unwrap_or(7);
let pool = state.get_pool(connection_id).await?;
// Run all metadata queries in parallel for speed
let (
version_res,
current_db_res,
all_dbs_res,
col_res,
fk_res,
enum_res,
tbl_comment_res,
col_comment_res,
unique_res,
varchar_res,
jsonb_res,
) = tokio::join!(
sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool),
sqlx::query_scalar::<_, String>("SELECT current_database()").fetch_one(&pool),
sqlx::query_scalar::<_, String>(
"SELECT datname FROM pg_database \
WHERE datistemplate = false AND datallowconn = true \
ORDER BY datname"
)
.fetch_all(&pool),
fetch_columns(&pool),
fetch_foreign_keys_raw(&pool),
fetch_enum_types(&pool),
fetch_table_comments(&pool),
fetch_column_comments(&pool),
fetch_unique_constraints(&pool),
fetch_varchar_values(&pool),
fetch_jsonb_keys(&pool),
);
let version = version_res.map_err(TuskError::Database)?;
let current_db = current_db_res.unwrap_or_default();
let all_dbs = all_dbs_res.unwrap_or_default();
let col_rows = col_res?;
let fk_rows = fk_res?;
let enum_map = enum_res?;
let tbl_comments = tbl_comment_res?;
let col_comments = col_comment_res?;
let unique_constraints = unique_res?;
let varchar_values = varchar_res.unwrap_or_default();
let jsonb_keys = jsonb_res.unwrap_or_default();
let gp_extras = if is_greenplum {
Some(fetch_gp_table_extras(&pool, gp_major).await)
} else {
None
};
// -- Build FK inline lookup: (schema, table, column) -> "ref_schema.ref_table(ref_col)" --
let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new();
let mut fk_lines: Vec<String> = Vec::new();
for fk in &fk_rows {
let line = format!(
"FK: {}.{}({}) -> {}.{}({})",
fk.schema,
fk.table,
fk.columns.join(", "),
fk.ref_schema,
fk.ref_table,
fk.ref_columns.join(", ")
);
fk_lines.push(line);
// For single-column FKs, enable inline annotation on column
if fk.columns.len() == 1 && fk.ref_columns.len() == 1 {
fk_inline.insert(
(fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()),
format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]),
);
}
}
// -- Build unique constraint lookup: (schema, table) -> Vec<column_list_string> --
let mut unique_map: HashMap<(String, String), Vec<String>> = HashMap::new();
for (schema, table, cols) in &unique_constraints {
unique_map
.entry((schema.clone(), table.clone()))
.or_default()
.push(cols.join(", "));
}
// -- Format output --
let mut output: Vec<String> = Vec::new();
// 1. PostgreSQL version (short form)
let short_version = version
.split_whitespace()
.take(2)
.collect::<Vec<_>>()
.join(" ");
output.push(format!("DATABASE SCHEMA ({})", short_version));
if !current_db.is_empty() {
output.push(format!("ACTIVE DATABASE: {}", current_db));
}
output.push(String::new());
// 2. Cluster topology — other databases on this server.
// Each PG database is isolated; cross-DB queries are not possible from a single connection.
if all_dbs.len() > 1 {
output.push("DATABASES ON THIS SERVER:".to_string());
for db in &all_dbs {
if db == &current_db {
output.push(format!(" * {} (active)", db));
} else {
output.push(format!(" {}", db));
}
}
output.push(String::new());
output.push(
"NOTE: Tables in other databases are NOT queryable from this session. \
If the user's question concerns data likely stored in a different database \
(e.g. an identity service in a separate DB), respond with a `final` message \
asking them to switch the active database via the connection selector."
.to_string(),
);
output.push(String::new());
}
// 3. Quick table+column index for fast existence checks before writing SQL.
// Each line lists `schema.table(col1, col2, ...)` so the model can grep both
// table names and column names without scrolling through the full TABLES section.
{
let mut by_table: BTreeMap<(String, String), Vec<String>> = BTreeMap::new();
for c in &col_rows {
by_table
.entry((c.schema.clone(), c.table.clone()))
.or_default()
.push(c.column.clone());
}
if !by_table.is_empty() {
output.push(format!(
"TABLE INDEX (database `{}`, {} tables — table_name(column_list)):",
current_db,
by_table.len()
));
for ((schema, table), cols) in &by_table {
output.push(format!(" {}.{}({})", schema, table, cols.join(", ")));
}
output.push(String::new());
}
}
// 4. Enum types
if !enum_map.is_empty() {
output.push("ENUM TYPES:".to_string());
for (type_name, values) in &enum_map {
let vals_str = values
.iter()
.map(|v| format!("'{}'", v))
.collect::<Vec<_>>()
.join(", ");
output.push(format!(" {} = [{}]", type_name, vals_str));
}
output.push(String::new());
}
// 5. Tables with columns
output.push("TABLES:".to_string());
// Group columns by schema.table preserving order
let mut tables: BTreeMap<String, Vec<ColumnInfo>> = BTreeMap::new();
for ci in &col_rows {
let key = format!("{}.{}", ci.schema, ci.table);
tables.entry(key).or_default().push(ci.clone());
}
for (full_name, columns) in &tables {
format_table_block(
full_name,
columns,
&tbl_comments,
&col_comments,
&fk_inline,
&enum_map,
&unique_map,
&varchar_values,
&jsonb_keys,
gp_extras.as_ref(),
&mut output,
);
}
// 6. Foreign keys summary
if !fk_lines.is_empty() {
output.push(String::new());
output.push("FOREIGN KEYS:".to_string());
for fk in &fk_lines {
output.push(format!(" {}", fk));
}
}
let result = output.join("\n");
// Cache the result
state
.set_schema_cache(connection_id.to_string(), result.clone())
.await;
Ok(result)
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Schema query helpers // Schema query helpers
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -976,8 +624,7 @@ pub(crate) struct ColumnInfo {
} }
/// Render a single table's column block in the human/LLM-readable schema format. /// Render a single table's column block in the human/LLM-readable schema format.
/// Reused by both `build_schema_context` (full DDL for legacy AI commands) and /// Used by the chat agent's `get_columns` tool.
/// the new `get_columns` chat tool.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub(crate) fn format_table_block( pub(crate) fn format_table_block(
full_name: &str, full_name: &str,
@@ -1102,106 +749,6 @@ pub(crate) fn format_table_block(
} }
} }
async fn build_clickhouse_schema_context(
state: &AppState,
connection_id: &str,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let db = client.database.clone();
// ClickHouse exposes ALL databases via system.* — pull cross-DB schema in one shot.
let columns_sql = "SELECT database, table, name, type, is_in_primary_key \
FROM system.columns \
WHERE database NOT IN ('system','INFORMATION_SCHEMA','information_schema') \
ORDER BY database, table, position";
let dbs_sql = "SELECT name FROM system.databases \
WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \
ORDER BY name";
let rows = client.fetch_objects(columns_sql).await?;
let db_rows = client.fetch_objects(dbs_sql).await.unwrap_or_default();
let version = client.ping().await.unwrap_or_default();
let mut out = String::new();
out.push_str("DATABASE: ClickHouse\n");
if !version.is_empty() {
out.push_str(&format!("VERSION: {}\n", version.trim()));
}
out.push_str(&format!("ACTIVE_DATABASE: {}\n\n", db));
// Cluster overview
if !db_rows.is_empty() {
out.push_str("DATABASES ON THIS SERVER:\n");
for row in &db_rows {
let name = row.get("name").and_then(|v| v.as_str()).unwrap_or("");
if name == db {
out.push_str(&format!(" * {} (active)\n", name));
} else {
out.push_str(&format!(" {}\n", name));
}
}
out.push('\n');
}
// Table+column index — same shape as the PG path so model has a uniform reference.
{
let mut by_table: BTreeMap<(String, String), Vec<String>> = BTreeMap::new();
for r in &rows {
let dbn = r.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string();
let tbl = r.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string();
let col = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
if !dbn.is_empty() && !tbl.is_empty() && !col.is_empty() {
by_table.entry((dbn, tbl)).or_default().push(col);
}
}
if !by_table.is_empty() {
out.push_str(&format!(
"TABLE INDEX ({} tables across all databases — db.table(column_list)):\n",
by_table.len()
));
for ((dbn, tbl), cols) in &by_table {
out.push_str(&format!(" {}.{}({})\n", dbn, tbl, cols.join(", ")));
}
out.push('\n');
}
}
out.push_str("TABLES:\n");
let mut current_key: Option<String> = None;
for row in &rows {
let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string();
let table = row.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string();
let column = row.get("name").and_then(|v| v.as_str()).unwrap_or("");
let dtype = row.get("type").and_then(|v| v.as_str()).unwrap_or("");
let is_pk = matches!(row.get("is_in_primary_key"), Some(serde_json::Value::Number(n)) if n.as_i64() == Some(1))
|| matches!(row.get("is_in_primary_key"), Some(serde_json::Value::String(s)) if s == "1");
let key = format!("{}.{}", dbn, table);
if Some(&key) != current_key.as_ref() {
out.push_str(&format!("\nTABLE {}.{}\n", dbn, table));
current_key = Some(key);
}
out.push_str(&format!(
" {} {}{}\n",
column,
dtype,
if is_pk { " [PK]" } else { "" }
));
}
out.push_str(
"\nNOTES:\n\
- Use ClickHouse SQL dialect. Functions differ from PostgreSQL (e.g. count(), arrayJoin, toDate, formatDateTime).\n\
- ClickHouse allows fully-qualified `database.table` in queries — you CAN cross-reference databases on this server.\n\
- Read-only mode is enforced for the agent: only SELECT/WITH/EXPLAIN/SHOW/DESCRIBE allowed.\n\
- Always include LIMIT for ad-hoc SELECTs.\n",
);
state
.set_schema_cache(connection_id.to_string(), out.clone())
.await;
Ok(out)
}
pub(crate) async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult<Vec<ColumnInfo>> { pub(crate) async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult<Vec<ColumnInfo>> {
let rows = sqlx::query( let rows = sqlx::query(
"SELECT \ "SELECT \
@@ -1399,6 +946,7 @@ pub(crate) async fn fetch_unique_constraints(
/// Returns HashMap<(schema, table, column), Vec<distinct_values>> for varchar columns /// Returns HashMap<(schema, table, column), Vec<distinct_values>> for varchar columns
/// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery. /// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery.
/// Returns None if pg_stats is not accessible (graceful degradation). /// Returns None if pg_stats is not accessible (graceful degradation).
#[allow(dead_code)] // re-exposed by profile_table tool (PR2)
async fn fetch_varchar_values( async fn fetch_varchar_values(
pool: &sqlx::PgPool, pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> { ) -> Option<HashMap<(String, String, String), Vec<String>>> {
@@ -1440,104 +988,6 @@ async fn fetch_varchar_values(
Some(map) Some(map)
} }
/// Discovers top-level keys in JSONB columns by sampling actual data.
/// Runs two sequential queries internally: first discovers JSONB columns,
/// then samples keys from each via a single UNION ALL query.
/// Returns None on error (graceful degradation).
async fn fetch_jsonb_keys(
pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> {
// Step 1: Find all JSONB columns
let col_rows = match sqlx::query(
"SELECT table_schema, table_name, column_name \
FROM information_schema.columns \
WHERE data_type = 'jsonb' \
AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
ORDER BY table_schema, table_name, column_name",
)
.fetch_all(pool)
.await
{
Ok(r) => r,
Err(e) => {
log::warn!("Failed to fetch JSONB columns: {}", e);
return None;
}
};
if col_rows.is_empty() {
return Some(HashMap::new());
}
// Cap at 50 JSONB columns to prevent unbounded UNION ALL queries on large schemas
let columns: Vec<(String, String, String)> = col_rows
.iter()
.take(50)
.map(|r| {
(
r.get::<String, _>(0),
r.get::<String, _>(1),
r.get::<String, _>(2),
)
})
.collect();
// Step 2: Build a single UNION ALL query to sample keys from all JSONB columns
let parts: Vec<String> = columns
.iter()
.enumerate()
.map(|(i, (schema, table, col))| {
let qs = schema.replace('"', "\"\"");
let qt = table.replace('"', "\"\"");
let qc = col.replace('"', "\"\"");
format!(
"(SELECT '{}.{}.{}' AS col_ref, key FROM (\
SELECT DISTINCT jsonb_object_keys(\"{}\") AS key \
FROM \"{}\".\"{}\" \
WHERE \"{}\" IS NOT NULL AND jsonb_typeof(\"{}\") = 'object' \
LIMIT 50\
) sub{})",
schema, table, col, qc, qs, qt, qc, qc, i
)
})
.collect();
let query = parts.join(" UNION ALL ");
let rows = match sqlx::query(&query).fetch_all(pool).await {
Ok(r) => r,
Err(e) => {
log::warn!("Failed to fetch JSONB keys: {}", e);
return None;
}
};
let mut map: HashMap<(String, String, String), Vec<String>> = HashMap::new();
for r in &rows {
let col_ref: String = r.get(0);
let key: String = r.get(1);
let ref_parts: Vec<&str> = col_ref.splitn(3, '.').collect();
if ref_parts.len() == 3 {
let entry = map
.entry((
ref_parts[0].to_string(),
ref_parts[1].to_string(),
ref_parts[2].to_string(),
))
.or_default();
if !entry.contains(&key) {
entry.push(key);
}
}
}
for vals in map.values_mut() {
vals.sort();
}
Some(map)
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Greenplum-specific table attributes // Greenplum-specific table attributes
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -1656,6 +1106,7 @@ pub(crate) async fn fetch_gp_table_extras(
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
/// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"} /// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"}
#[allow(dead_code)] // helper for fetch_varchar_values; re-exposed by profile_table tool (PR2)
fn parse_pg_array_text(s: &str) -> Vec<String> { fn parse_pg_array_text(s: &str) -> Vec<String> {
let s = s.trim(); let s = s.trim();
let s = s.strip_prefix('{').unwrap_or(s); let s = s.strip_prefix('{').unwrap_or(s);
@@ -1716,65 +1167,10 @@ fn simplify_default(raw: &str) -> String {
s.to_string() s.to_string()
} }
fn clean_sql_response(raw: &str) -> String {
let trimmed = raw.trim();
// Remove markdown code fences
let without_fences = if trimmed.starts_with("```") {
let inner = trimmed
.strip_prefix("```sql")
.or_else(|| trimmed.strip_prefix("```SQL"))
.or_else(|| trimmed.strip_prefix("```postgresql"))
.or_else(|| trimmed.strip_prefix("```"))
.unwrap_or(trimmed);
inner.strip_suffix("```").unwrap_or(inner)
} else {
trimmed
};
without_fences.trim().to_string()
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
// ── clean_sql_response ────────────────────────────────────
#[test]
fn clean_sql_plain() {
assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1");
}
#[test]
fn clean_sql_with_fences() {
assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1");
}
#[test]
fn clean_sql_with_generic_fences() {
assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1");
}
#[test]
fn clean_sql_with_postgresql_fences() {
assert_eq!(
clean_sql_response("```postgresql\nSELECT 1\n```"),
"SELECT 1"
);
}
#[test]
fn clean_sql_with_whitespace() {
assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1");
}
#[test]
fn clean_sql_no_fences_multiline() {
assert_eq!(
clean_sql_response("SELECT\n *\nFROM users"),
"SELECT\n *\nFROM users"
);
}
// ── Fireworks provider ─────────────────────────────────── // ── Fireworks provider ───────────────────────────────────
#[test] #[test]
@@ -1784,12 +1180,14 @@ mod tests {
} }
#[test] #[test]
fn deserializes_legacy_settings_without_fireworks_key() { fn deserializes_legacy_settings_with_dropped_provider_keys() {
// Old config files won't have `fireworks_api_key` — must still parse. // Old config files may include `openai_api_key`/`anthropic_api_key` and a
// legacy `"provider": "openai"` value — both must be tolerated, with the
// unknown provider coerced to Ollama.
let legacy = r#"{ let legacy = r#"{
"provider": "ollama", "provider": "openai",
"ollama_url": "http://localhost:11434", "ollama_url": "http://localhost:11434",
"openai_api_key": null, "openai_api_key": "sk-deprecated",
"anthropic_api_key": null, "anthropic_api_key": null,
"model": "qwen2.5-coder:7b" "model": "qwen2.5-coder:7b"
}"#; }"#;
@@ -1813,6 +1211,28 @@ mod tests {
assert_eq!(parsed.choices[0].message.content, "hi"); assert_eq!(parsed.choices[0].message.content, "hi");
} }
// ── OpenRouter provider ──────────────────────────────────
#[test]
fn serializes_openrouter_provider() {
let json = serde_json::to_string(&AiProvider::OpenRouter).unwrap();
assert_eq!(json, "\"openrouter\"");
}
#[test]
fn deserializes_openrouter_settings() {
let cfg = r#"{
"provider": "openrouter",
"ollama_url": "http://localhost:11434",
"openrouter_api_key": "sk-or-v1-abc",
"model": "anthropic/claude-3.5-sonnet"
}"#;
let parsed: AiSettings = serde_json::from_str(cfg).unwrap();
assert_eq!(parsed.provider, AiProvider::OpenRouter);
assert_eq!(parsed.openrouter_api_key.as_deref(), Some("sk-or-v1-abc"));
assert_eq!(parsed.model, "anthropic/claude-3.5-sonnet");
}
#[test] #[test]
fn parses_fireworks_models_list() { fn parses_fireworks_models_list() {
let body = r#"{ let body = r#"{

View File

@@ -1,13 +1,14 @@
use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings}; use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings};
use crate::commands::chat_tools::{ use crate::commands::chat_tools::{
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool, build_sample_sql, detect_skew_tool, explain_query_tool, find_queries_tool, get_columns_tool,
list_databases_tool, list_tables_tool, profile_table_tool, save_query_tool,
switch_database_tool, switch_database_tool,
}; };
use crate::commands::memory::{append_memory_core, read_memory_core}; use crate::commands::memory::{append_memory_core, read_memory_core};
use crate::commands::queries::execute_query_core; use crate::commands::queries::execute_query_core;
use crate::error::{TuskError, TuskResult}; use crate::error::{TuskError, TuskResult};
use crate::models::ai::OllamaChatMessage; use crate::models::ai::OllamaChatMessage;
use crate::models::chat::{ChartConfig, ChatMessage, ChatTurnResult, ContextUsage}; use crate::models::chat::{ChatMessage, ChatTurnResult, ContextUsage};
use crate::models::query_result::QueryResult; use crate::models::query_result::QueryResult;
use crate::state::AppState; use crate::state::AppState;
use chrono::Utc; use chrono::Utc;
@@ -30,11 +31,11 @@ const TEXT_TOOL_CHAR_CAP: usize = 10_000;
/// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192). /// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192).
/// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content. /// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content.
const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000; const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000;
/// Conservative default for managed providers (Fireworks). Most chat-capable /// Conservative default for managed providers (Fireworks, OpenRouter). Most
/// Fireworks models ship with 32K256K context windows; 384K chars (~128K tok) /// chat-capable hosted models ship with 32K256K context windows; 384K chars
/// is a safe floor that won't trigger false /compact nags on normal sessions /// (~128K tok) is a safe floor that won't trigger false /compact nags on normal
/// while still flagging genuinely runaway threads. /// sessions while still flagging genuinely runaway threads.
const CONTEXT_BUDGET_CHARS_FIREWORKS: u64 = 384_000; const CONTEXT_BUDGET_CHARS_MANAGED: u64 = 384_000;
/// Stop the loop when the model fails the same SQL hurdle this many times in a /// Stop the loop when the model fails the same SQL hurdle this many times in a
/// row. Beyond this, additional hops almost always burn the rest of the budget /// row. Beyond this, additional hops almost always burn the rest of the budget
/// on identical retries; a definitive `final` with the error is more useful. /// on identical retries; a definitive `final` with the error is more useful.
@@ -55,9 +56,15 @@ enum AgentAction {
Remember { note: String }, Remember { note: String },
SaveQuery { name: String, sql: String }, SaveQuery { name: String, sql: String },
FindQueries { text: String }, FindQueries { text: String },
MakeChart { config: ChartConfig }, ProfileTable { table: String },
SampleData { table: String, limit: u32 },
ExplainQuery { sql: String },
DetectSkew { table: String },
} }
const SAMPLE_DATA_DEFAULT_LIMIT: u32 = 50;
const SAMPLE_DATA_MAX_LIMIT: u32 = 200;
/// Parse the model's JSON response. Accepts both shapes the model tends to emit: /// Parse the model's JSON response. Accepts both shapes the model tends to emit:
/// {"action":"X","field":"..."} — flat (matches our prompt) /// {"action":"X","field":"..."} — flat (matches our prompt)
/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention) /// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention)
@@ -157,60 +164,55 @@ fn parse_agent_action(raw: &str) -> Result<AgentAction, String> {
} }
Ok(AgentAction::FindQueries { text }) Ok(AgentAction::FindQueries { text })
} }
"make_chart" => { "profile_table" => {
let chart_type = lookup("chart_type") let table = lookup("table")
.or_else(|| lookup("type"))
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `chart_type`".to_string())? .ok_or_else(|| "profile_table missing `table`".to_string())?
.trim()
.to_lowercase();
if !["bar", "line", "area", "pie"].contains(&chart_type.as_str()) {
return Err(format!(
"make_chart `chart_type` must be one of: bar, line, area, pie. Got: {}",
chart_type
));
}
let x = lookup("x")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `x` column".to_string())?
.trim() .trim()
.to_string(); .to_string();
let y = lookup("y") if table.is_empty() {
.and_then(|v| v.as_str()) return Err("profile_table `table` must not be empty".into());
.ok_or_else(|| "make_chart missing `y` column".to_string())?
.trim()
.to_string();
if x.is_empty() || y.is_empty() {
return Err("make_chart `x` and `y` must not be empty".into());
} }
let group = lookup("group") Ok(AgentAction::ProfileTable { table })
.and_then(|v| v.as_str()) }
.map(|s| s.trim().to_string()) "sample_data" => {
.filter(|s| !s.is_empty()); let table = lookup("table")
let title = lookup("title") .and_then(|v| v.as_str())
.and_then(|v| v.as_str()) .ok_or_else(|| "sample_data missing `table`".to_string())?
.map(|s| s.trim().to_string()) .trim()
.filter(|s| !s.is_empty()); .to_string();
let orientation = lookup("orientation") if table.is_empty() {
.and_then(|v| v.as_str()) return Err("sample_data `table` must not be empty".into());
.map(|s| s.trim().to_lowercase()) }
.filter(|s| !s.is_empty()); let limit = lookup("limit")
Ok(AgentAction::MakeChart { .and_then(|v| v.as_u64())
config: ChartConfig { .map(|n| n as u32)
chart_type, .unwrap_or(SAMPLE_DATA_DEFAULT_LIMIT)
x, .clamp(1, SAMPLE_DATA_MAX_LIMIT);
y, Ok(AgentAction::SampleData { table, limit })
group, }
title, "explain_query" => {
orientation, let sql = lookup("sql")
}, .and_then(|v| v.as_str())
}) .ok_or_else(|| "explain_query missing `sql`".to_string())?
.trim()
.to_string();
if sql.is_empty() {
return Err("explain_query `sql` must not be empty".into());
}
Ok(AgentAction::ExplainQuery { sql })
}
"detect_skew" => {
let table = lookup("table")
.and_then(|v| v.as_str())
.ok_or_else(|| "detect_skew missing `table`".to_string())?
.trim()
.to_string();
if table.is_empty() {
return Err("detect_skew `table` must not be empty".into());
}
Ok(AgentAction::DetectSkew { table })
} }
// Legacy from earlier iterations — silently ignored at parse time so the
// model can recover with a different action.
"get_schema" => Err(
"get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(),
),
other => Err(format!("unknown action `{}`", other)), other => Err(format!("unknown action `{}`", other)),
} }
} }
@@ -285,8 +287,17 @@ You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On
{{"action":"save_query","name":"<short label>","sql":"<the SQL>"}} {{"action":"save_query","name":"<short label>","sql":"<the SQL>"}}
Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved. Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved.
{{"action":"make_chart","chart_type":"bar","x":"<col>","y":"<col>","title":"<short title>"}} {{"action":"profile_table","table":"schema.table"}}
Visualise the LAST successful run_query result as a chart inline. `chart_type` is one of: bar, line, area, pie. `x` and `y` MUST be column names from the previous result. Optional: `group` (column for series), `orientation` ("vertical"/"horizontal", bar only). Use after run_query when the data is aggregated and would be clearer as a chart (top-N comparisons → bar; time series → line/area; proportions → pie). Skip for tiny results (≤2 rows) and giant ones (>500 rows). Per-column profile: NULL fraction, distinct cardinality, min/max range, top-K values. PG/GP reads pg_stats (zero-cost; ensure ANALYZE has run). ClickHouse fires one summary query (cheap on MergeTree). Use BEFORE writing aggregations to spot pseudo-enums, NULL-heavy columns, or skewed distributions.
{{"action":"sample_data","table":"schema.table","limit":50}}
Random row sample (default 50, max 200). PG/GP uses TABLESAMPLE BERNOULLI when reltuples > 0, else ORDER BY random(). CH uses SAMPLE 0.01 on MergeTree with a sampling key, else ORDER BY rand(). Use to eyeball value shape BEFORE writing filters; cheaper than `SELECT * LIMIT N` on huge tables.
{{"action":"explain_query","sql":"SELECT ..."}}
Run EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) on PG/GP, EXPLAIN PLAN on CH. Reports root node, planning + execution time, seq-scanned tables, spilled sorts, est-vs-actual row skew, Greenplum Motions. Use AFTER a slow run_query.
{{"action":"detect_skew","table":"schema.table"}}
Greenplum-only: counts rows per gp_segment_id and reports max/min/avg + skew ratio. Ratio > 1.5 ⇒ uneven distribution; suggests revisiting DISTRIBUTED BY. Soft-errors on PG/CH.
{{"action":"final","text":"..."}} {{"action":"final","text":"..."}}
End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling). End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling).
@@ -296,6 +307,8 @@ WORKFLOW
2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question. 2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question.
3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs. 3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs.
4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns. 4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns.
4a. If the user asks about value shape (cardinality, NULL rates, top values), prefer `profile_table` over a hand-written run_query. To eyeball actual rows, prefer `sample_data` over `LIMIT 100`.
4b. If the user reports a slow query or asks why something takes long, run `explain_query` on it. On Greenplum, if a single table appears unbalanced, check `detect_skew`.
5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first. 5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first.
6. Execute run_query. 6. Execute run_query.
7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here. 7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here.
@@ -415,9 +428,6 @@ fn build_history(
content: serde_json::json!({ "action": "final", "text": text }).to_string(), content: serde_json::json!({ "action": "final", "text": text }).to_string(),
}), }),
ChatMessage::ToolCall { tool, input_json, .. } => { ChatMessage::ToolCall { tool, input_json, .. } => {
if tool == "get_schema" {
continue; // legacy
}
let mut envelope = serde_json::Map::new(); let mut envelope = serde_json::Map::new();
envelope.insert("action".to_string(), Value::String(tool.clone())); envelope.insert("action".to_string(), Value::String(tool.clone()));
if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) { if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) {
@@ -437,9 +447,6 @@ fn build_history(
result, result,
.. ..
} => { } => {
if tool == "get_schema" {
continue; // legacy
}
let payload = match tool.as_str() { let payload = match tool.as_str() {
"run_query" => { "run_query" => {
if *is_error { if *is_error {
@@ -521,7 +528,7 @@ async fn provider_budget_chars(state: &AppState, app: &AppHandle) -> u64 {
use crate::models::ai::AiProvider; use crate::models::ai::AiProvider;
match load_ai_settings(app, state).await { match load_ai_settings(app, state).await {
Ok(s) => match s.provider { Ok(s) => match s.provider {
AiProvider::Fireworks => CONTEXT_BUDGET_CHARS_FIREWORKS, AiProvider::Fireworks | AiProvider::OpenRouter => CONTEXT_BUDGET_CHARS_MANAGED,
_ => CONTEXT_BUDGET_CHARS_OLLAMA, _ => CONTEXT_BUDGET_CHARS_OLLAMA,
}, },
Err(_) => CONTEXT_BUDGET_CHARS_OLLAMA, Err(_) => CONTEXT_BUDGET_CHARS_OLLAMA,
@@ -597,7 +604,10 @@ pub async fn chat_send(
} }
}; };
let is_run_query = matches!(&action, AgentAction::RunQuery { .. }); let is_run_query = matches!(
&action,
AgentAction::RunQuery { .. } | AgentAction::SampleData { .. }
);
match action { match action {
AgentAction::Final { text } => { AgentAction::Final { text } => {
@@ -742,91 +752,90 @@ pub async fn chat_send(
); );
push_tool_result(&mut new_messages, &mut working, result); push_tool_result(&mut new_messages, &mut working, result);
} }
AgentAction::MakeChart { config } => { AgentAction::ProfileTable { table } => {
let config_json = serde_json::to_string(&config).unwrap_or_else(|_| "{}".into());
push_tool_call( push_tool_call(
&mut new_messages, &mut new_messages,
&mut working, &mut working,
"make_chart", "profile_table",
config_json.clone(), serde_json::json!({ "table": &table }).to_string(),
); );
let result = run_text_tool(
let result_msg = match last_successful_query_result(&working) { profile_table_tool(&state, &connection_id, &table).await,
None => ChatMessage::ToolResult { "profile_table",
id: new_id("res"), );
tool: "make_chart".to_string(), push_tool_result(&mut new_messages, &mut working, result);
is_error: true, }
text: Some( AgentAction::ExplainQuery { sql } => {
"make_chart needs a successful run_query result above it. Run a SELECT first, then call make_chart." push_tool_call(
.to_string(), &mut new_messages,
), &mut working,
result: None, "explain_query",
created_at: now_ms(), serde_json::json!({ "sql": &sql }).to_string(),
}, );
Some(qr) => { let result = run_text_tool(
if !qr.columns.iter().any(|c| c == &config.x) { explain_query_tool(&state, &connection_id, &sql).await,
"explain_query",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::DetectSkew { table } => {
push_tool_call(
&mut new_messages,
&mut working,
"detect_skew",
serde_json::json!({ "table": &table }).to_string(),
);
let result = run_text_tool(
detect_skew_tool(&state, &connection_id, &table).await,
"detect_skew",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::SampleData { table, limit } => {
push_tool_call(
&mut new_messages,
&mut working,
"sample_data",
serde_json::json!({ "table": &table, "limit": limit }).to_string(),
);
let outcome = match build_sample_sql(&state, &connection_id, &table, limit).await {
Ok(sql) => match execute_query_core(&state, &connection_id, &sql).await {
Ok(qr) => {
consecutive_query_errors = 0;
ChatMessage::ToolResult { ChatMessage::ToolResult {
id: new_id("res"), id: new_id("res"),
tool: "make_chart".to_string(), tool: "sample_data".to_string(),
is_error: true,
text: Some(format!(
"x column `{}` is not in the last result. Available: {}.",
config.x,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if !qr.columns.iter().any(|c| c == &config.y) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"y column `{}` is not in the last result. Available: {}.",
config.y,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if let Some(group) = &config.group {
if !qr.columns.iter().any(|c| c == group) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"group column `{}` is not in the last result. Available: {}.",
group,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: false,
text: Some(config_json.clone()),
result: Some(qr),
created_at: now_ms(),
}
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: false, is_error: false,
text: Some(config_json.clone()), text: None,
result: Some(qr), result: Some(qr),
created_at: now_ms(), created_at: now_ms(),
} }
} }
Err(e) => {
consecutive_query_errors += 1;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "sample_data".to_string(),
is_error: true,
text: Some(format_db_error(&e)),
result: None,
created_at: now_ms(),
}
}
},
Err(e) => {
consecutive_query_errors += 1;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "sample_data".to_string(),
is_error: true,
text: Some(format_db_error(&e)),
result: None,
created_at: now_ms(),
}
} }
}; };
push_tool_result(&mut new_messages, &mut working, result_msg); push_tool_result(&mut new_messages, &mut working, outcome);
} }
} }
@@ -982,26 +991,6 @@ fn format_db_error(e: &TuskError) -> String {
e.to_string() e.to_string()
} }
/// Locate the most recent SUCCESSFUL run_query in the working thread and
/// return its full QueryResult. Used by make_chart to attach data to a chart
/// directive without relying on the model to re-send it.
fn last_successful_query_result(messages: &[ChatMessage]) -> Option<QueryResult> {
for m in messages.iter().rev() {
if let ChatMessage::ToolResult {
tool,
is_error: false,
result: Some(qr),
..
} = m
{
if tool == "run_query" {
return Some(qr.clone());
}
}
}
None
}
/// Pull the most recent run_query error text from the working thread, so the /// Pull the most recent run_query error text from the working thread, so the
/// post-loop "I gave up" summary can quote concrete errors back to the user. /// post-loop "I gave up" summary can quote concrete errors back to the user.
fn last_run_query_error(messages: &[ChatMessage]) -> Option<String> { fn last_run_query_error(messages: &[ChatMessage]) -> Option<String> {
@@ -1484,119 +1473,6 @@ mod tests {
assert!(last_run_query_error(&msgs).is_none()); assert!(last_run_query_error(&msgs).is_none());
} }
#[test]
fn parses_make_chart_minimal() {
let a = parse_agent_action(
r#"{"action":"make_chart","chart_type":"bar","x":"carrier","y":"trips"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => {
assert_eq!(config.chart_type, "bar");
assert_eq!(config.x, "carrier");
assert_eq!(config.y, "trips");
assert!(config.group.is_none());
assert!(config.title.is_none());
}
_ => panic!("wrong variant"),
}
}
#[test]
fn parses_make_chart_with_group_and_title() {
let a = parse_agent_action(
r#"{"action":"make_chart","chart_type":"line","x":"month","y":"revenue","group":"region","title":"Revenue"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => {
assert_eq!(config.group.as_deref(), Some("region"));
assert_eq!(config.title.as_deref(), Some("Revenue"));
}
_ => panic!("wrong variant"),
}
}
#[test]
fn make_chart_accepts_alternative_field_name_type() {
// Some models emit `type` instead of `chart_type`.
let a = parse_agent_action(
r#"{"action":"make_chart","type":"pie","x":"label","y":"value"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => assert_eq!(config.chart_type, "pie"),
_ => panic!("wrong variant"),
}
}
#[test]
fn rejects_make_chart_with_unknown_chart_type() {
let r = parse_agent_action(
r#"{"action":"make_chart","chart_type":"radar","x":"a","y":"b"}"#,
);
assert!(r.is_err());
}
#[test]
fn rejects_make_chart_missing_x_or_y() {
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","y":"a"}"#).is_err());
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","x":"a"}"#).is_err());
}
#[test]
fn last_successful_query_result_finds_recent() {
use crate::models::query_result::QueryResult;
let qr = QueryResult {
columns: vec!["a".into()],
types: vec!["INT4".into()],
rows: vec![vec![Value::Number(1.into())]],
row_count: 1,
execution_time_ms: 1,
};
let msgs = vec![
ChatMessage::ToolResult {
id: "r1".into(),
tool: "run_query".into(),
is_error: false,
text: None,
result: Some(qr.clone()),
created_at: 1,
},
ChatMessage::ToolResult {
id: "r2".into(),
tool: "run_query".into(),
is_error: true,
text: Some("oops".into()),
result: None,
created_at: 2,
},
];
let found = last_successful_query_result(&msgs).expect("ok");
assert_eq!(found.columns, vec!["a".to_string()]);
}
#[test]
fn last_successful_query_result_skips_non_run_query() {
use crate::models::query_result::QueryResult;
let qr = QueryResult {
columns: vec!["a".into()],
types: vec!["INT4".into()],
rows: vec![],
row_count: 0,
execution_time_ms: 0,
};
let msgs = vec![ChatMessage::ToolResult {
id: "r1".into(),
tool: "list_tables".into(),
is_error: false,
text: Some("public.x".into()),
result: Some(qr),
created_at: 1,
}];
assert!(last_successful_query_result(&msgs).is_none());
}
#[test] #[test]
fn render_thread_for_summary_includes_roles_and_skips_rows() { fn render_thread_for_summary_includes_roles_and_skips_rows() {
let msgs = vec![ let msgs = vec![
@@ -1625,11 +1501,6 @@ mod tests {
assert!(!rendered.contains("alice")); assert!(!rendered.contains("alice"));
} }
#[test]
fn rejects_legacy_get_schema() {
assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err());
}
#[test] #[test]
fn truncates_long_cell() { fn truncates_long_cell() {
let long = "a".repeat(CELL_CHAR_CAP + 50); let long = "a".repeat(CELL_CHAR_CAP + 50);

View File

@@ -10,11 +10,14 @@ use crate::commands::ai::{
ColumnInfo, ColumnInfo,
}; };
use crate::commands::connections::{load_connection_config, switch_database_core}; use crate::commands::connections::{load_connection_config, switch_database_core};
use crate::commands::queries::execute_query_core;
use crate::commands::saved_queries::{list_saved_queries_core, save_query_core}; use crate::commands::saved_queries::{list_saved_queries_core, save_query_core};
use crate::commands::schema::{list_databases_core, list_tables_core}; use crate::commands::schema::{list_databases_core, list_tables_core};
use crate::db::sql_guard::ensure_readonly_sql;
use crate::error::{TuskError, TuskResult}; use crate::error::{TuskError, TuskResult};
use crate::models::saved_queries::SavedQuery; use crate::models::saved_queries::SavedQuery;
use crate::state::{AppState, CachedVec, DbFlavor}; use crate::state::{AppState, CachedVec, DbFlavor};
use crate::utils::escape_ident;
use sqlx::{PgPool, Row}; use sqlx::{PgPool, Row};
use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeMap, HashMap};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@@ -565,3 +568,690 @@ pub async fn find_queries_tool(
Ok(out) Ok(out)
} }
// ---------------------------------------------------------------------------
// profile_table (PR2 — data-engineering tool)
// ---------------------------------------------------------------------------
const PROFILE_TABLE_MAX_COLUMNS: usize = 30;
const PROFILE_TABLE_TOPK: usize = 5;
pub async fn profile_table_tool(
state: &AppState,
connection_id: &str,
table: &str,
) -> TuskResult<String> {
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
let flavor = state.get_flavor(connection_id).await;
match flavor {
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
profile_table_postgres(state, connection_id, &schema, &tbl).await
}
DbFlavor::ClickHouse => profile_table_clickhouse(state, connection_id, &schema, &tbl).await,
}
}
async fn profile_table_postgres(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
) -> TuskResult<String> {
let pool = state.get_pool(connection_id).await?;
let exists = sqlx::query_scalar::<_, i64>(
"SELECT 1 FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \
WHERE n.nspname = $1 AND c.relname = $2 LIMIT 1",
)
.bind(schema)
.bind(table)
.fetch_optional(&pool)
.await
.map_err(TuskError::Database)?;
if exists.is_none() {
return Err(TuskError::Custom(format!(
"Table '{}.{}' does not exist (or no privileges).",
schema, table
)));
}
let last_analyze: Option<chrono::DateTime<chrono::Utc>> = sqlx::query_scalar(
"SELECT GREATEST(last_analyze, last_autoanalyze) FROM pg_stat_user_tables \
WHERE schemaname = $1 AND relname = $2",
)
.bind(schema)
.bind(table)
.fetch_optional(&pool)
.await
.ok()
.flatten();
let stat_rows = sqlx::query(
"SELECT attname, null_frac, n_distinct, \
most_common_vals::text, most_common_freqs, histogram_bounds::text \
FROM pg_stats \
WHERE schemaname = $1 AND tablename = $2 \
ORDER BY attname",
)
.bind(schema)
.bind(table)
.fetch_all(&pool)
.await
.map_err(TuskError::Database)?;
let mut out = format!("PROFILE {}.{}\n", schema, table);
match last_analyze {
Some(ts) => out.push_str(&format!("Last ANALYZE: {}\n", ts.to_rfc3339())),
None => out.push_str("Last ANALYZE: never\n"),
}
if stat_rows.is_empty() {
out.push_str(&format!(
"\nNo statistics in pg_stats. Run: ANALYZE {}.{};\n",
escape_ident(schema),
escape_ident(table)
));
return Ok(out);
}
let total = stat_rows.len();
let take = total.min(PROFILE_TABLE_MAX_COLUMNS);
out.push_str(&format!("\n{} columns with stats\n", total));
for r in stat_rows.iter().take(take) {
let attname: String = r.get(0);
let null_frac: f32 = r.try_get(1).unwrap_or(0.0);
let n_distinct: f32 = r.try_get(2).unwrap_or(0.0);
let mcv_text: Option<String> = r.try_get(3).ok();
let mcf_arr: Option<Vec<f32>> = r.try_get(4).ok();
let hist_text: Option<String> = r.try_get(5).ok();
out.push_str(&format!("\n {}:\n", attname));
out.push_str(&format!(" null_frac: {:.4}\n", null_frac));
if n_distinct < 0.0 {
out.push_str(&format!(
" n_distinct: {:.3} (ratio of total rows)\n",
-n_distinct
));
} else {
out.push_str(&format!(" n_distinct: {}\n", n_distinct as i64));
}
if let Some(text) = hist_text.as_deref() {
let bounds = parse_pg_array_text_local(text);
if let (Some(min), Some(max)) = (bounds.first(), bounds.last()) {
out.push_str(&format!(" range: {}{}\n", min, max));
}
}
if let Some(text) = mcv_text.as_deref() {
let vals = parse_pg_array_text_local(text);
if !vals.is_empty() {
let freqs = mcf_arr.unwrap_or_default();
let pairs: Vec<String> = vals
.iter()
.take(PROFILE_TABLE_TOPK)
.enumerate()
.map(|(i, v)| match freqs.get(i) {
Some(f) => format!("{}({:.3})", v, f),
None => v.clone(),
})
.collect();
out.push_str(&format!(" top: {}\n", pairs.join(", ")));
}
}
}
if total > take {
out.push_str(&format!("\n…and {} more columns\n", total - take));
}
Ok(out)
}
/// Local pg-array parser used by profile_table; mirrors `parse_pg_array_text` in ai.rs
/// but kept local to avoid importing a private helper.
fn parse_pg_array_text_local(s: &str) -> Vec<String> {
let s = s.trim();
let s = s.strip_prefix('{').unwrap_or(s);
let s = s.strip_suffix('}').unwrap_or(s);
if s.is_empty() {
return Vec::new();
}
let mut out = Vec::new();
let mut cur = String::new();
let mut in_quotes = false;
let mut chars = s.chars().peekable();
while let Some(c) = chars.next() {
match c {
'"' if !in_quotes => in_quotes = true,
'"' if in_quotes => {
if chars.peek() == Some(&'"') {
cur.push('"');
chars.next();
} else {
in_quotes = false;
}
}
',' if !in_quotes => {
out.push(std::mem::take(&mut cur));
}
'\\' if in_quotes => {
if let Some(next) = chars.next() {
cur.push(next);
}
}
other => cur.push(other),
}
}
if !cur.is_empty() || s.ends_with(',') {
out.push(cur);
}
out
}
async fn profile_table_clickhouse(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let active_db = client.database.clone();
let dbn = if schema == "public" || schema.is_empty() {
active_db
} else {
schema.to_string()
};
let cols_sql = format!(
"SELECT name, type FROM system.columns \
WHERE database = '{}' AND table = '{}' \
ORDER BY position LIMIT {}",
dbn.replace('\'', "\\'"),
table.replace('\'', "\\'"),
PROFILE_TABLE_MAX_COLUMNS
);
let col_rows = client.fetch_objects(&cols_sql).await?;
if col_rows.is_empty() {
return Err(TuskError::Custom(format!(
"Table '{}.{}' does not exist (or no privileges).",
dbn, table
)));
}
let mut select_parts: Vec<String> = vec!["count() AS rows_total".to_string()];
let mut col_names: Vec<String> = Vec::new();
let mut col_types: Vec<String> = Vec::new();
for r in &col_rows {
let name = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
let dtype = r.get("type").and_then(|v| v.as_str()).unwrap_or("").to_string();
if name.is_empty() {
continue;
}
col_names.push(name.clone());
col_types.push(dtype);
let q = name.replace('`', "``");
select_parts.push(format!("countIf(`{}` IS NULL) AS null_{}", q, col_names.len()));
select_parts.push(format!("uniqHLL12(`{}`) AS dist_{}", q, col_names.len()));
select_parts.push(format!("toString(min(`{}`)) AS min_{}", q, col_names.len()));
select_parts.push(format!("toString(max(`{}`)) AS max_{}", q, col_names.len()));
select_parts.push(format!(
"arrayStringConcat(arrayMap(x -> toString(x), topK({})(`{}`)), '|') AS top_{}",
PROFILE_TABLE_TOPK,
q,
col_names.len()
));
}
let agg_sql = format!(
"SELECT {} FROM `{}`.`{}`",
select_parts.join(", "),
dbn.replace('`', "``"),
table.replace('`', "``")
);
let agg_rows = client.fetch_objects(&agg_sql).await?;
let row = agg_rows
.first()
.ok_or_else(|| TuskError::Custom("ClickHouse returned no row for profile aggregate".into()))?;
let rows_total = row
.get("rows_total")
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
.unwrap_or(0);
let mut out = format!(
"PROFILE {}.{}\nRows: {}\n{} columns profiled\n",
dbn,
table,
rows_total,
col_names.len()
);
for (i, name) in col_names.iter().enumerate() {
let n = i + 1;
let nulls = row
.get(&format!("null_{}", n))
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
.unwrap_or(0);
let dist = row
.get(&format!("dist_{}", n))
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
.unwrap_or(0);
let min = row.get(&format!("min_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
let max = row.get(&format!("max_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
let top_raw = row.get(&format!("top_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
out.push_str(&format!("\n {} ({}):\n", name, col_types[i]));
let null_frac = if rows_total > 0 {
nulls as f64 / rows_total as f64
} else {
0.0
};
out.push_str(&format!(" null_frac: {:.4}\n", null_frac));
out.push_str(&format!(" distinct (HLL): {}\n", dist));
if !min.is_empty() || !max.is_empty() {
out.push_str(&format!(" range: {}{}\n", min, max));
}
if !top_raw.is_empty() {
let top_vals: Vec<&str> = top_raw.split('|').take(PROFILE_TABLE_TOPK).collect();
out.push_str(&format!(" top: {}\n", top_vals.join(", ")));
}
}
if col_rows.len() == PROFILE_TABLE_MAX_COLUMNS {
out.push_str(&format!(
"\n…showing first {} columns\n",
PROFILE_TABLE_MAX_COLUMNS
));
}
Ok(out)
}
// ---------------------------------------------------------------------------
// sample_data (PR2 — returns SQL string; dispatch site runs it through
// execute_query_core so the QueryResult feeds the standard renderer)
// ---------------------------------------------------------------------------
pub async fn build_sample_sql(
state: &AppState,
connection_id: &str,
table: &str,
limit: u32,
) -> TuskResult<String> {
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
let flavor = state.get_flavor(connection_id).await;
match flavor {
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
build_sample_sql_postgres(state, connection_id, &schema, &tbl, limit).await
}
DbFlavor::ClickHouse => {
build_sample_sql_clickhouse(state, connection_id, &schema, &tbl, limit).await
}
}
}
async fn build_sample_sql_postgres(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
limit: u32,
) -> TuskResult<String> {
let pool = state.get_pool(connection_id).await?;
let reltuples: f64 = sqlx::query_scalar(
"SELECT c.reltuples FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \
WHERE n.nspname = $1 AND c.relname = $2",
)
.bind(schema)
.bind(table)
.fetch_optional(&pool)
.await
.map_err(TuskError::Database)?
.unwrap_or(0.0);
let qualified = format!("{}.{}", escape_ident(schema), escape_ident(table));
if reltuples > 0.0 {
let target = limit as f64 * 100.0 / reltuples;
let percent = target.clamp(0.01, 100.0);
Ok(format!(
"SELECT * FROM {} TABLESAMPLE BERNOULLI({:.4}) LIMIT {}",
qualified, percent, limit
))
} else {
Ok(format!(
"SELECT * FROM {} ORDER BY random() LIMIT {}",
qualified, limit
))
}
}
async fn build_sample_sql_clickhouse(
state: &AppState,
connection_id: &str,
schema: &str,
table: &str,
limit: u32,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let active_db = client.database.clone();
let dbn = if schema == "public" || schema.is_empty() {
active_db
} else {
schema.to_string()
};
let info_sql = format!(
"SELECT engine, sampling_key FROM system.tables \
WHERE database = '{}' AND name = '{}' LIMIT 1",
dbn.replace('\'', "\\'"),
table.replace('\'', "\\'")
);
let rows = client.fetch_objects(&info_sql).await.unwrap_or_default();
let (engine, sampling_key) = match rows.first() {
Some(r) => (
r.get("engine").and_then(|v| v.as_str()).unwrap_or("").to_string(),
r.get("sampling_key").and_then(|v| v.as_str()).unwrap_or("").to_string(),
),
None => (String::new(), String::new()),
};
let qualified = format!(
"`{}`.`{}`",
dbn.replace('`', "``"),
table.replace('`', "``")
);
if engine.starts_with("Merge") && !sampling_key.trim().is_empty() {
Ok(format!(
"SELECT * FROM {} SAMPLE 0.01 LIMIT {}",
qualified, limit
))
} else {
Ok(format!(
"SELECT * FROM {} ORDER BY rand() LIMIT {}",
qualified, limit
))
}
}
// ---------------------------------------------------------------------------
// explain_query (PR2)
// ---------------------------------------------------------------------------
pub async fn explain_query_tool(
state: &AppState,
connection_id: &str,
sql: &str,
) -> TuskResult<String> {
let trimmed = sql.trim();
if trimmed.is_empty() {
return Err(TuskError::Custom("explain_query: sql must not be empty".into()));
}
// Validate the user's statement BEFORE prefixing EXPLAIN so the error message
// references their SQL, not the wrapper. ensure_readonly_sql also rejects any
// forbidden keywords (INSERT/UPDATE/DELETE/...) even nested under EXPLAIN.
ensure_readonly_sql(trimmed).map_err(|e| TuskError::Custom(e.to_string()))?;
let flavor = state.get_flavor(connection_id).await;
match flavor {
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
explain_query_postgres(state, connection_id, trimmed).await
}
DbFlavor::ClickHouse => explain_query_clickhouse(state, connection_id, trimmed).await,
}
}
async fn explain_query_postgres(
state: &AppState,
connection_id: &str,
sql: &str,
) -> TuskResult<String> {
let pool = state.get_pool(connection_id).await?;
let plan_sql = format!("EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) {}", sql);
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
sqlx::query("SET TRANSACTION READ ONLY")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
let row = sqlx::query(&plan_sql)
.fetch_one(&mut *tx)
.await
.map_err(TuskError::Database)?;
let _ = tx.rollback().await;
let raw_json: serde_json::Value = match row.try_get::<serde_json::Value, _>(0) {
Ok(v) => v,
Err(_) => {
let s: String = row.try_get(0).map_err(TuskError::Database)?;
serde_json::from_str(&s)
.map_err(|e| TuskError::Custom(format!("EXPLAIN JSON parse failed: {}", e)))?
}
};
let plans = raw_json
.as_array()
.ok_or_else(|| TuskError::Custom("EXPLAIN JSON: expected array".into()))?;
let plan = plans.first().and_then(|p| p.get("Plan")).ok_or_else(|| {
TuskError::Custom("EXPLAIN JSON: missing top-level Plan node".into())
})?;
let root_node = plan.get("Node Type").and_then(|v| v.as_str()).unwrap_or("?");
let total_cost = plan.get("Total Cost").and_then(|v| v.as_f64()).unwrap_or(0.0);
let planning = plans
.first()
.and_then(|p| p.get("Planning Time").and_then(|v| v.as_f64()))
.unwrap_or(0.0);
let execution = plans
.first()
.and_then(|p| p.get("Execution Time").and_then(|v| v.as_f64()))
.unwrap_or(0.0);
let mut seq_scans: Vec<String> = Vec::new();
let mut spilled: Vec<String> = Vec::new();
let mut motions: Vec<String> = Vec::new();
let mut max_skew: Option<(f64, String)> = None;
walk_pg_plan(plan, &mut seq_scans, &mut spilled, &mut motions, &mut max_skew);
let mut out = format!(
"PLAN root: {}, total cost {:.1}\nPlanning: {:.2} ms Execution: {:.2} ms\n",
root_node, total_cost, planning, execution
);
if !seq_scans.is_empty() {
out.push_str(&format!("Seq scans on: {}\n", seq_scans.join(", ")));
}
if !spilled.is_empty() {
out.push_str(&format!("Spilled to disk: {}\n", spilled.join(", ")));
}
if !motions.is_empty() {
out.push_str(&format!("Motions (Greenplum): {}\n", motions.join(", ")));
}
if let Some((ratio, node)) = max_skew {
if ratio >= 5.0 {
out.push_str(&format!(
"Estimate skew: max plan/actual ratio = {:.1} on {}\n",
ratio, node
));
}
}
if seq_scans.is_empty() && spilled.is_empty() && motions.is_empty() {
out.push_str("No obvious red flags.\n");
}
Ok(out)
}
fn walk_pg_plan(
node: &serde_json::Value,
seq_scans: &mut Vec<String>,
spilled: &mut Vec<String>,
motions: &mut Vec<String>,
max_skew: &mut Option<(f64, String)>,
) {
let node_type = node.get("Node Type").and_then(|v| v.as_str()).unwrap_or("");
if node_type == "Seq Scan" {
let rel = node
.get("Relation Name")
.and_then(|v| v.as_str())
.unwrap_or("?");
let schema = node
.get("Schema")
.and_then(|v| v.as_str())
.map(|s| format!("{}.", s))
.unwrap_or_default();
seq_scans.push(format!("{}{}", schema, rel));
}
if let Some(method) = node.get("Sort Method").and_then(|v| v.as_str()) {
if method.contains("disk") || method.contains("external") {
spilled.push(format!("Sort ({})", method));
}
}
if node_type.contains("Motion") {
motions.push(node_type.to_string());
}
let plan_rows = node.get("Plan Rows").and_then(|v| v.as_f64()).unwrap_or(0.0);
let actual_rows = node.get("Actual Rows").and_then(|v| v.as_f64()).unwrap_or(0.0);
if actual_rows > 0.0 && plan_rows > 0.0 {
let ratio = (plan_rows / actual_rows).max(actual_rows / plan_rows);
if max_skew.as_ref().map(|(r, _)| ratio > *r).unwrap_or(true) {
*max_skew = Some((ratio, node_type.to_string()));
}
}
if let Some(children) = node.get("Plans").and_then(|v| v.as_array()) {
for child in children {
walk_pg_plan(child, seq_scans, spilled, motions, max_skew);
}
}
}
async fn explain_query_clickhouse(
state: &AppState,
connection_id: &str,
sql: &str,
) -> TuskResult<String> {
let client = state.get_ch_client(connection_id).await?;
let plan_sql = format!("EXPLAIN PLAN {}", sql);
let qr = client.execute_query(&plan_sql, true).await?;
if qr.rows.is_empty() {
return Ok("(empty plan)".to_string());
}
let mut out = String::from("ClickHouse plan:\n");
for row in &qr.rows {
if let Some(cell) = row.first() {
if let Some(s) = cell.as_str() {
out.push_str(s);
out.push('\n');
}
}
}
Ok(out)
}
// ---------------------------------------------------------------------------
// detect_skew (PR2 — Greenplum-only)
// ---------------------------------------------------------------------------
pub async fn detect_skew_tool(
state: &AppState,
connection_id: &str,
table: &str,
) -> TuskResult<String> {
let flavor = state.get_flavor(connection_id).await;
if !matches!(flavor, DbFlavor::Greenplum) {
return Ok("detect_skew is only available on Greenplum connections.".to_string());
}
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&tbl));
let sql = format!(
"SELECT gp_segment_id, COUNT(*) AS n FROM {} GROUP BY 1 ORDER BY 1",
qualified
);
let qr = execute_query_core(state, connection_id, &sql).await?;
let mut counts: Vec<(i64, i64)> = Vec::new();
for row in &qr.rows {
let seg = row
.get(0)
.and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok())))
.unwrap_or(0);
let n = row
.get(1)
.and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok())))
.unwrap_or(0);
counts.push((seg, n));
}
if counts.is_empty() {
return Ok(format!("Table {}.{} is empty.", schema, tbl));
}
let total: i64 = counts.iter().map(|(_, n)| *n).sum();
let max = counts.iter().map(|(_, n)| *n).max().unwrap_or(0);
let min = counts.iter().map(|(_, n)| *n).min().unwrap_or(0);
let avg = total as f64 / counts.len() as f64;
let ratio = if avg > 0.0 { max as f64 / avg } else { 0.0 };
let mut out = format!(
"Per-segment row distribution for {}.{}\nsegments: {} total rows: {}\nmin: {} max: {} avg: {:.0}\nskew ratio (max/avg): {:.2}",
schema,
tbl,
counts.len(),
total,
min,
max,
avg,
ratio
);
if ratio > 1.5 {
out.push_str(" ⚠ uneven distribution\n");
} else {
out.push_str(" OK — within 1.5x of average\n");
}
let pool = state.get_pool(connection_id).await?;
if let Some(policy) = fetch_gp_distribution_for(&pool, &schema, &tbl).await {
out.push_str(&format!("\nCurrent policy: {}\n", policy));
if ratio > 1.5 {
out.push_str(
"Hint: pick a higher-cardinality column. Run profile_table to compare n_distinct.\n",
);
}
}
Ok(out)
}
/// Fetch the Greenplum DISTRIBUTED BY policy for a single table. Returns None if
/// the catalog query fails (non-GP connection, missing privileges, etc.).
async fn fetch_gp_distribution_for(
pool: &PgPool,
schema: &str,
table: &str,
) -> Option<String> {
let row = sqlx::query(
"SELECT COALESCE(\
(SELECT array_agg(a.attname ORDER BY ord.idx) \
FROM regexp_split_to_table(NULLIF(trim(p.distkey::text), ''), ' ') \
WITH ORDINALITY AS ord(attnum_str, idx) \
JOIN pg_attribute a \
ON a.attrelid = c.oid \
AND a.attnum::int = ord.attnum_str::int), \
ARRAY[]::text[] \
) AS dist_columns \
FROM gp_distribution_policy p \
JOIN pg_class c ON p.localoid = c.oid \
JOIN pg_namespace n ON c.relnamespace = n.oid \
WHERE n.nspname = $1 AND c.relname = $2",
)
.bind(schema)
.bind(table)
.fetch_optional(pool)
.await
.ok()
.flatten()?;
let cols: Vec<String> = row.try_get(0).ok()?;
Some(if cols.is_empty() {
"DISTRIBUTED RANDOMLY".to_string()
} else {
format!("DISTRIBUTED BY ({})", cols.join(", "))
})
}

View File

@@ -111,9 +111,7 @@ pub fn run() {
commands::ai::save_ai_settings, commands::ai::save_ai_settings,
commands::ai::list_ollama_models, commands::ai::list_ollama_models,
commands::ai::list_fireworks_models, commands::ai::list_fireworks_models,
commands::ai::generate_sql, commands::ai::list_openrouter_models,
commands::ai::explain_sql,
commands::ai::fix_sql_error,
// chat // chat
commands::chat::chat_send, commands::chat::chat_send,
commands::chat::chat_compact, commands::chat::chat_compact,

View File

@@ -1,13 +1,26 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Deserializer, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum AiProvider { pub enum AiProvider {
#[default] #[default]
Ollama, Ollama,
OpenAi,
Anthropic,
Fireworks, Fireworks,
OpenRouter,
}
/// Deserialize a provider string, coercing legacy `openai`/`anthropic` and any
/// unknown value to `Ollama`. Keeps existing config files loadable after the
/// stub providers were removed.
impl<'de> Deserialize<'de> for AiProvider {
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
let s = String::deserialize(d)?;
Ok(match s.as_str() {
"fireworks" => AiProvider::Fireworks,
"openrouter" => AiProvider::OpenRouter,
_ => AiProvider::Ollama,
})
}
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -15,11 +28,9 @@ pub struct AiSettings {
pub provider: AiProvider, pub provider: AiProvider,
pub ollama_url: String, pub ollama_url: String,
#[serde(default)] #[serde(default)]
pub openai_api_key: Option<String>,
#[serde(default)]
pub anthropic_api_key: Option<String>,
#[serde(default)]
pub fireworks_api_key: Option<String>, pub fireworks_api_key: Option<String>,
#[serde(default)]
pub openrouter_api_key: Option<String>,
pub model: String, pub model: String,
} }
@@ -28,9 +39,8 @@ impl Default for AiSettings {
Self { Self {
provider: AiProvider::Ollama, provider: AiProvider::Ollama,
ollama_url: "http://localhost:11434".to_string(), ollama_url: "http://localhost:11434".to_string(),
openai_api_key: None,
anthropic_api_key: None,
fireworks_api_key: None, fireworks_api_key: None,
openrouter_api_key: None,
model: String::new(), model: String::new(),
} }
} }
@@ -71,7 +81,9 @@ pub struct OllamaModel {
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Fireworks (OpenAI-compatible chat-completions) // OpenAI-compatible chat-completions (Fireworks, OpenRouter)
// These request/response shapes are shared by every OpenAI-compatible provider;
// the `Fireworks*` names are retained for historical reasons.
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]

View File

@@ -31,18 +31,3 @@ pub struct ChatTurnResult {
pub messages: Vec<ChatMessage>, pub messages: Vec<ChatMessage>,
pub usage: ContextUsage, pub usage: ContextUsage,
} }
/// Chart configuration produced by the agent's `make_chart` tool.
/// Embedded as JSON in `ToolResult.text` for tool == "make_chart" while the
/// underlying data lives in `ToolResult.result`. The frontend reads both to
/// render the chart inline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChartConfig {
pub chart_type: String, // "bar" | "line" | "area" | "pie"
pub x: String, // column name for X axis / category
pub y: String, // column name for Y axis / numeric value
pub group: Option<String>, // optional column for series grouping
pub title: Option<String>,
pub orientation: Option<String>, // "vertical" | "horizontal" — bar only
}

View File

@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
use sqlx::PgPool; use sqlx::PgPool;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::Instant;
use tokio::sync::{watch, RwLock}; use tokio::sync::{watch, RwLock};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@@ -17,12 +17,6 @@ pub enum DbFlavor {
} }
#[derive(Clone)]
pub struct SchemaCacheEntry {
pub schema_text: String,
pub cached_at: Instant,
}
#[derive(Clone)] #[derive(Clone)]
pub struct CachedString { pub struct CachedString {
pub value: String, pub value: String,
@@ -43,23 +37,16 @@ pub struct AppState {
/// Greenplum major version (6 or 7), tracked separately because GP6 and GP7 /// Greenplum major version (6 or 7), tracked separately because GP6 and GP7
/// expose very different system catalogs (GP6 = PG9.4 base, GP7 = PG14 base). /// expose very different system catalogs (GP6 = PG9.4 base, GP7 = PG14 base).
pub gp_majors: RwLock<HashMap<String, u8>>, pub gp_majors: RwLock<HashMap<String, u8>>,
/// Legacy cache used by generate_sql/explain_sql/fix_sql_error — full DDL. /// Chat agent caches: lite overview per connection.
pub schema_cache: RwLock<HashMap<String, SchemaCacheEntry>>,
/// Chat v2 caches: lite overview per connection.
pub overview_cache: RwLock<HashMap<String, CachedString>>, pub overview_cache: RwLock<HashMap<String, CachedString>>,
/// Chat v2 caches: list of tables per (connection_id, db_name) — used for /// Chat agent caches: list of tables per (connection_id, db_name) — used for
/// list_tables on a non-active PG database via temporary pool. /// list_tables on a non-active PG database via temporary pool.
pub tables_by_db_cache: RwLock<HashMap<(String, String), CachedVec<String>>>, pub tables_by_db_cache: RwLock<HashMap<(String, String), CachedVec<String>>>,
/// Chat v2 caches: column block per (connection_id, db_name, "schema.table").
pub columns_cache: RwLock<HashMap<(String, String, String), CachedString>>,
pub mcp_shutdown_tx: watch::Sender<bool>, pub mcp_shutdown_tx: watch::Sender<bool>,
pub mcp_running: RwLock<bool>, pub mcp_running: RwLock<bool>,
pub ai_settings: RwLock<Option<AiSettings>>, pub ai_settings: RwLock<Option<AiSettings>>,
} }
const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes
const SCHEMA_CACHE_MAX_SIZE: usize = 100;
impl AppState { impl AppState {
pub fn new() -> Self { pub fn new() -> Self {
let (mcp_shutdown_tx, _) = watch::channel(false); let (mcp_shutdown_tx, _) = watch::channel(false);
@@ -69,10 +56,8 @@ impl AppState {
read_only: RwLock::new(HashMap::new()), read_only: RwLock::new(HashMap::new()),
db_flavors: RwLock::new(HashMap::new()), db_flavors: RwLock::new(HashMap::new()),
gp_majors: RwLock::new(HashMap::new()), gp_majors: RwLock::new(HashMap::new()),
schema_cache: RwLock::new(HashMap::new()),
overview_cache: RwLock::new(HashMap::new()), overview_cache: RwLock::new(HashMap::new()),
tables_by_db_cache: RwLock::new(HashMap::new()), tables_by_db_cache: RwLock::new(HashMap::new()),
columns_cache: RwLock::new(HashMap::new()),
mcp_shutdown_tx, mcp_shutdown_tx,
mcp_running: RwLock::new(false), mcp_running: RwLock::new(false),
ai_settings: RwLock::new(None), ai_settings: RwLock::new(None),
@@ -82,16 +67,11 @@ impl AppState {
/// Drop every chat-agent cache entry tied to this connection. /// Drop every chat-agent cache entry tied to this connection.
/// Called by switch_database_core, disconnect, and on connection delete. /// Called by switch_database_core, disconnect, and on connection delete.
pub async fn invalidate_chat_caches_for(&self, connection_id: &str) { pub async fn invalidate_chat_caches_for(&self, connection_id: &str) {
self.schema_cache.write().await.remove(connection_id);
self.overview_cache.write().await.remove(connection_id); self.overview_cache.write().await.remove(connection_id);
self.tables_by_db_cache self.tables_by_db_cache
.write() .write()
.await .await
.retain(|(cid, _), _| cid != connection_id); .retain(|(cid, _), _| cid != connection_id);
self.columns_cache
.write()
.await
.retain(|(cid, _, _), _| cid != connection_id);
} }
pub async fn get_pool(&self, connection_id: &str) -> TuskResult<PgPool> { pub async fn get_pool(&self, connection_id: &str) -> TuskResult<PgPool> {
@@ -125,39 +105,4 @@ impl AppState {
pub async fn get_gp_major(&self, id: &str) -> Option<u8> { pub async fn get_gp_major(&self, id: &str) -> Option<u8> {
self.gp_majors.read().await.get(id).copied() self.gp_majors.read().await.get(id).copied()
} }
pub async fn get_schema_cache(&self, connection_id: &str) -> Option<String> {
let cache = self.schema_cache.read().await;
cache.get(connection_id).and_then(|entry| {
if entry.cached_at.elapsed() < SCHEMA_CACHE_TTL {
Some(entry.schema_text.clone())
} else {
None
}
})
}
pub async fn set_schema_cache(&self, connection_id: String, schema_text: String) {
let mut cache = self.schema_cache.write().await;
// Evict stale entries to prevent unbounded memory growth
cache.retain(|_, entry| entry.cached_at.elapsed() < SCHEMA_CACHE_TTL);
// If still at capacity, remove the oldest entry
if cache.len() >= SCHEMA_CACHE_MAX_SIZE {
if let Some(oldest_key) = cache
.iter()
.min_by_key(|(_, e)| e.cached_at)
.map(|(k, _)| k.clone())
{
cache.remove(&oldest_key);
}
}
cache.insert(
connection_id,
SchemaCacheEntry {
schema_text,
cached_at: Instant::now(),
},
);
}
} }

View File

@@ -1,103 +0,0 @@
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { AiSettingsPopover } from "./AiSettingsPopover";
import { useGenerateSql } from "@/hooks/use-ai";
import { Sparkles, Loader2, X, Eraser } from "lucide-react";
import { toast } from "sonner";
interface Props {
connectionId: string;
onSqlGenerated: (sql: string) => void;
onClose: () => void;
onExecute?: () => void;
}
export function AiBar({ connectionId, onSqlGenerated, onClose, onExecute }: Props) {
const [prompt, setPrompt] = useState("");
const generateMutation = useGenerateSql();
const handleGenerate = () => {
if (!prompt.trim() || generateMutation.isPending) return;
generateMutation.mutate(
{ connectionId, prompt },
{
onSuccess: (sql) => {
onSqlGenerated(sql);
},
onError: (err) => {
toast.error("AI generation failed", { description: String(err) });
},
}
);
};
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === "Enter" && (e.ctrlKey || e.metaKey)) {
e.preventDefault();
e.stopPropagation();
onExecute?.();
return;
}
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
e.stopPropagation();
handleGenerate();
return;
}
if (e.key === "Escape") {
e.stopPropagation();
onClose();
}
};
return (
<div className="tusk-ai-bar flex items-center gap-2 px-2 py-1.5 tusk-fade-in">
<Sparkles className="h-3.5 w-3.5 shrink-0 tusk-ai-icon" />
<Input
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
onKeyDown={handleKeyDown}
placeholder="Describe the query you want..."
className="h-7 min-w-0 flex-1 border-tusk-purple/20 bg-tusk-purple/5 text-xs placeholder:text-muted-foreground/40 focus:border-tusk-purple/40 focus:ring-tusk-purple/20"
autoFocus
disabled={generateMutation.isPending}
/>
<Button
size="xs"
variant="ghost"
className="gap-1 text-[11px] text-tusk-purple hover:bg-tusk-purple/10 hover:text-tusk-purple"
onClick={handleGenerate}
disabled={generateMutation.isPending || !prompt.trim()}
>
{generateMutation.isPending ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
"Generate"
)}
</Button>
{prompt.trim() && (
<Button
size="icon-xs"
variant="ghost"
onClick={() => setPrompt("")}
title="Clear prompt"
disabled={generateMutation.isPending}
className="text-muted-foreground"
>
<Eraser className="h-3 w-3" />
</Button>
)}
<AiSettingsPopover />
<Button
size="icon-xs"
variant="ghost"
onClick={onClose}
title="Close AI bar"
className="text-muted-foreground"
>
<X className="h-3 w-3" />
</Button>
</div>
);
}

View File

@@ -7,7 +7,11 @@ import {
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
} from "@/components/ui/select"; } from "@/components/ui/select";
import { useFireworksModels, useOllamaModels } from "@/hooks/use-ai"; import {
useFireworksModels,
useOllamaModels,
useOpenRouterModels,
} from "@/hooks/use-ai";
import { RefreshCw, Loader2 } from "lucide-react"; import { RefreshCw, Loader2 } from "lucide-react";
import type { AiProvider, OllamaModel } from "@/types"; import type { AiProvider, OllamaModel } from "@/types";
@@ -17,6 +21,8 @@ interface Props {
onOllamaUrlChange: (url: string) => void; onOllamaUrlChange: (url: string) => void;
fireworksApiKey: string; fireworksApiKey: string;
onFireworksApiKeyChange: (key: string) => void; onFireworksApiKeyChange: (key: string) => void;
openrouterApiKey: string;
onOpenRouterApiKeyChange: (key: string) => void;
model: string; model: string;
onModelChange: (model: string) => void; onModelChange: (model: string) => void;
} }
@@ -27,6 +33,8 @@ export function AiSettingsFields({
onOllamaUrlChange, onOllamaUrlChange,
fireworksApiKey, fireworksApiKey,
onFireworksApiKeyChange, onFireworksApiKeyChange,
openrouterApiKey,
onOpenRouterApiKeyChange,
model, model,
onModelChange, onModelChange,
}: Props) { }: Props) {
@@ -41,6 +49,17 @@ export function AiSettingsFields({
); );
} }
if (provider === "openrouter") {
return (
<OpenRouterFields
apiKey={openrouterApiKey}
onApiKeyChange={onOpenRouterApiKeyChange}
model={model}
onModelChange={onModelChange}
/>
);
}
return ( return (
<OllamaFields <OllamaFields
ollamaUrl={ollamaUrl} ollamaUrl={ollamaUrl}
@@ -143,6 +162,55 @@ function FireworksFields({
); );
} }
function OpenRouterFields({
apiKey,
onApiKeyChange,
model,
onModelChange,
}: {
apiKey: string;
onApiKeyChange: (key: string) => void;
model: string;
onModelChange: (model: string) => void;
}) {
const {
data: models,
isLoading,
isError,
refetch,
} = useOpenRouterModels(apiKey);
return (
<>
<div className="flex flex-col gap-1.5">
<label className="text-xs text-muted-foreground">OpenRouter API key</label>
<Input
type="password"
value={apiKey}
onChange={(e) => onApiKeyChange(e.target.value)}
placeholder="sk-or-..."
className="h-8 text-xs"
autoComplete="off"
/>
<p className="text-[10px] text-muted-foreground/70">
Stored locally; sent only to openrouter.ai.
</p>
</div>
<ModelDropdown
models={models}
loading={isLoading}
errored={isError}
errorText="Cannot reach OpenRouter (check API key)"
onRefresh={() => refetch()}
model={model}
onModelChange={onModelChange}
emptyHint={apiKey.trim() ? "Click ↻ to load models" : "Enter API key first"}
/>
</>
);
}
function ModelDropdown({ function ModelDropdown({
models, models,
loading, loading,

View File

@@ -21,6 +21,7 @@ import type { AiProvider } from "@/types";
const SUPPORTED_PROVIDERS: { value: AiProvider; label: string }[] = [ const SUPPORTED_PROVIDERS: { value: AiProvider; label: string }[] = [
{ value: "ollama", label: "Ollama (local)" }, { value: "ollama", label: "Ollama (local)" },
{ value: "fireworks", label: "Fireworks AI" }, { value: "fireworks", label: "Fireworks AI" },
{ value: "openrouter", label: "OpenRouter" },
]; ];
export function AiSettingsPopover() { export function AiSettingsPopover() {
@@ -30,22 +31,16 @@ export function AiSettingsPopover() {
const [provider, setProvider] = useState<AiProvider | null>(null); const [provider, setProvider] = useState<AiProvider | null>(null);
const [url, setUrl] = useState<string | null>(null); const [url, setUrl] = useState<string | null>(null);
const [fireworksKey, setFireworksKey] = useState<string | null>(null); const [fireworksKey, setFireworksKey] = useState<string | null>(null);
const [openrouterKey, setOpenrouterKey] = useState<string | null>(null);
const [model, setModel] = useState<string | null>(null); const [model, setModel] = useState<string | null>(null);
const settingsProvider = settings?.provider;
// Hide unsupported legacy values (openai/anthropic) from the selector.
const normalizedSettingsProvider: AiProvider | undefined =
settingsProvider === "ollama" || settingsProvider === "fireworks"
? settingsProvider
: settingsProvider
? "ollama"
: undefined;
const currentProvider: AiProvider = const currentProvider: AiProvider =
provider ?? normalizedSettingsProvider ?? "ollama"; provider ?? settings?.provider ?? "ollama";
const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434"; const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434";
const currentFireworksKey = const currentFireworksKey =
fireworksKey ?? settings?.fireworks_api_key ?? ""; fireworksKey ?? settings?.fireworks_api_key ?? "";
const currentOpenrouterKey =
openrouterKey ?? settings?.openrouter_api_key ?? "";
const currentModel = model ?? settings?.model ?? ""; const currentModel = model ?? settings?.model ?? "";
const handleProviderChange = (next: AiProvider) => { const handleProviderChange = (next: AiProvider) => {
@@ -64,6 +59,10 @@ export function AiSettingsPopover() {
currentProvider === "fireworks" currentProvider === "fireworks"
? currentFireworksKey.trim() || undefined ? currentFireworksKey.trim() || undefined
: settings?.fireworks_api_key, : settings?.fireworks_api_key,
openrouter_api_key:
currentProvider === "openrouter"
? currentOpenrouterKey.trim() || undefined
: settings?.openrouter_api_key,
model: currentModel, model: currentModel,
}, },
{ {
@@ -117,6 +116,8 @@ export function AiSettingsPopover() {
onOllamaUrlChange={setUrl} onOllamaUrlChange={setUrl}
fireworksApiKey={currentFireworksKey} fireworksApiKey={currentFireworksKey}
onFireworksApiKeyChange={setFireworksKey} onFireworksApiKeyChange={setFireworksKey}
openrouterApiKey={currentOpenrouterKey}
onOpenRouterApiKeyChange={setOpenrouterKey}
model={currentModel} model={currentModel}
onModelChange={setModel} onModelChange={setModel}
/> />

View File

@@ -1,327 +0,0 @@
import { useMemo } from "react";
import {
Area,
AreaChart,
Bar,
BarChart,
CartesianGrid,
Cell,
Legend,
Line,
LineChart,
Pie,
PieChart,
ResponsiveContainer,
Tooltip,
XAxis,
YAxis,
} from "recharts";
import type { ChartConfig } from "@/types";
interface Props {
config: ChartConfig;
columns: string[];
rows: unknown[][];
height?: number;
}
const PALETTE = [
"#60a5fa", // blue-400
"#34d399", // emerald-400
"#fbbf24", // amber-400
"#f87171", // red-400
"#a78bfa", // violet-400
"#22d3ee", // cyan-400
"#fb923c", // orange-400
"#f472b6", // pink-400
];
const MAX_POINTS = 500;
export function ChartPreview({ config, columns, rows, height = 280 }: Props) {
const xIdx = columns.indexOf(config.x);
const yIdx = columns.indexOf(config.y);
const groupIdx = config.group ? columns.indexOf(config.group) : -1;
const limited = useMemo(() => rows.slice(0, MAX_POINTS), [rows]);
if (xIdx < 0 || yIdx < 0) {
return (
<ChartFallback
config={config}
message={`Column not found: ${xIdx < 0 ? config.x : config.y}`}
/>
);
}
// Coerce y values to numbers; chart libs need numeric Y.
const numericY = (v: unknown): number => {
if (typeof v === "number") return v;
if (typeof v === "string") {
const n = parseFloat(v);
return Number.isFinite(n) ? n : 0;
}
return 0;
};
const labelX = (v: unknown): string => {
if (v == null) return "—";
if (typeof v === "string") return v;
if (typeof v === "number" || typeof v === "boolean") return String(v);
return JSON.stringify(v);
};
const isGrouped = groupIdx >= 0;
// ──────────── grouped data shape ────────────
// For multi-series: pivot to { x: <xValue>, <group1>: yVal, <group2>: yVal, … }
// Used by line, area, and grouped-bar.
const pivoted = useMemo(() => {
if (!isGrouped) return null;
const map = new Map<string, Record<string, unknown>>();
const groupSet = new Set<string>();
for (const row of limited) {
const xv = labelX(row[xIdx]);
const gv = labelX(row[groupIdx!]);
const yv = numericY(row[yIdx]);
groupSet.add(gv);
const acc = map.get(xv) ?? { _x: xv };
acc[gv] = ((acc[gv] as number) ?? 0) + yv;
map.set(xv, acc);
}
return {
data: Array.from(map.values()),
groups: Array.from(groupSet),
};
}, [isGrouped, limited, xIdx, yIdx, groupIdx]);
// Single series shape: [{ _x, _y }]
const flat = useMemo(() => {
return limited.map((row) => ({
_x: labelX(row[xIdx]),
_y: numericY(row[yIdx]),
}));
}, [limited, xIdx, yIdx]);
const tickStyle = {
fill: "var(--muted-foreground)",
fontSize: 10,
} as const;
const axisLine = {
stroke: "rgba(255, 255, 255, 0.08)",
} as const;
const tooltipStyle = {
backgroundColor: "var(--popover)",
border: "1px solid var(--border)",
borderRadius: 6,
fontSize: 11,
} as const;
if (config.chart_type === "pie") {
// Pie: aggregate y by x label (sum), no group support.
const agg = new Map<string, number>();
for (const row of limited) {
const xv = labelX(row[xIdx]);
agg.set(xv, (agg.get(xv) ?? 0) + numericY(row[yIdx]));
}
const data = Array.from(agg.entries()).map(([name, value]) => ({ name, value }));
return (
<ChartFrame config={config} height={height} count={data.length} totalRows={rows.length}>
<ResponsiveContainer width="100%" height={height}>
<PieChart>
<Pie
data={data}
dataKey="value"
nameKey="name"
outerRadius={Math.min(height / 2.5, 110)}
label={(entry) =>
typeof entry.name === "string" && entry.name.length < 20 ? entry.name : ""
}
>
{data.map((_, i) => (
<Cell key={i} fill={PALETTE[i % PALETTE.length]} />
))}
</Pie>
<Tooltip contentStyle={tooltipStyle} />
<Legend
wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }}
verticalAlign="bottom"
/>
</PieChart>
</ResponsiveContainer>
</ChartFrame>
);
}
if (config.chart_type === "line") {
return (
<ChartFrame
config={config}
height={height}
count={isGrouped ? pivoted!.data.length : flat.length}
totalRows={rows.length}
>
<ResponsiveContainer width="100%" height={height}>
<LineChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<Tooltip contentStyle={tooltipStyle} />
{isGrouped ? (
<>
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
{pivoted!.groups.map((g, i) => (
<Line
key={g}
type="monotone"
dataKey={g}
stroke={PALETTE[i % PALETTE.length]}
strokeWidth={2}
dot={false}
/>
))}
</>
) : (
<Line type="monotone" dataKey="_y" stroke={PALETTE[0]} strokeWidth={2} dot={false} />
)}
</LineChart>
</ResponsiveContainer>
</ChartFrame>
);
}
if (config.chart_type === "area") {
return (
<ChartFrame
config={config}
height={height}
count={isGrouped ? pivoted!.data.length : flat.length}
totalRows={rows.length}
>
<ResponsiveContainer width="100%" height={height}>
<AreaChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<Tooltip contentStyle={tooltipStyle} />
{isGrouped ? (
<>
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
{pivoted!.groups.map((g, i) => (
<Area
key={g}
type="monotone"
dataKey={g}
stackId="1"
stroke={PALETTE[i % PALETTE.length]}
fill={PALETTE[i % PALETTE.length]}
fillOpacity={0.35}
/>
))}
</>
) : (
<Area
type="monotone"
dataKey="_y"
stroke={PALETTE[0]}
fill={PALETTE[0]}
fillOpacity={0.35}
/>
)}
</AreaChart>
</ResponsiveContainer>
</ChartFrame>
);
}
// bar (default)
const horizontal = config.orientation === "horizontal";
return (
<ChartFrame
config={config}
height={height}
count={isGrouped ? pivoted!.data.length : flat.length}
totalRows={rows.length}
>
<ResponsiveContainer width="100%" height={height}>
<BarChart
layout={horizontal ? "vertical" : "horizontal"}
data={isGrouped ? pivoted!.data : flat}
margin={{ top: 8, right: 12, left: horizontal ? 24 : 0, bottom: 4 }}
>
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={horizontal} horizontal={!horizontal} />
{horizontal ? (
<>
<XAxis type="number" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis dataKey="_x" type="category" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} width={100} />
</>
) : (
<>
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
</>
)}
<Tooltip contentStyle={tooltipStyle} />
{isGrouped ? (
<>
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
{pivoted!.groups.map((g, i) => (
<Bar key={g} dataKey={g} fill={PALETTE[i % PALETTE.length]} radius={[3, 3, 0, 0]} />
))}
</>
) : (
<Bar dataKey="_y" fill={PALETTE[0]} radius={[3, 3, 0, 0]} />
)}
</BarChart>
</ResponsiveContainer>
</ChartFrame>
);
}
function ChartFrame({
config,
height,
count,
totalRows,
children,
}: {
config: ChartConfig;
height: number;
count: number;
totalRows: number;
children: React.ReactNode;
}) {
return (
<div className="rounded-md border border-border/40 bg-background">
<div className="flex items-center gap-2 border-b border-border/30 px-2 py-1 text-[11px] text-muted-foreground">
<span className="font-medium text-foreground/80">
{config.title ?? `${capitalize(config.chart_type)} chart`}
</span>
<span className="ml-auto text-muted-foreground/60">
{count} point{count === 1 ? "" : "s"}
{totalRows > MAX_POINTS && ` (of ${totalRows}, capped at ${MAX_POINTS})`}
</span>
</div>
<div className="p-2" style={{ minHeight: height }}>
{children}
</div>
</div>
);
}
function ChartFallback({ config, message }: { config: ChartConfig; message: string }) {
return (
<div className="rounded-md border border-destructive/40 bg-destructive/5 p-3 text-xs">
<div className="font-medium text-destructive">
Chart {config.chart_type} failed
</div>
<div className="mt-1 text-muted-foreground">{message}</div>
</div>
);
}
function capitalize(s: string) {
return s.charAt(0).toUpperCase() + s.slice(1);
}

View File

@@ -1,7 +1,6 @@
import { useState } from "react"; import { useState } from "react";
import { ResultsTable } from "@/components/results/ResultsTable"; import { ResultsTable } from "@/components/results/ResultsTable";
import { ExportDialog } from "@/components/export/ExportDialog"; import { ExportDialog } from "@/components/export/ExportDialog";
import { ChartPreview } from "./ChartPreview";
import { import {
Dialog, Dialog,
DialogContent, DialogContent,
@@ -15,19 +14,12 @@ import {
AlertCircle, AlertCircle,
Sparkles, Sparkles,
User, User,
Wrench,
Database, Database,
Columns,
Layers,
RefreshCw,
StickyNote,
Bookmark,
BookmarkPlus,
Maximize2, Maximize2,
Download, Download,
BarChart3,
} from "lucide-react"; } from "lucide-react";
import type { ChartConfig, ChatMessage } from "@/types"; import type { ChatMessage } from "@/types";
import { getToolMeta, isQueryResultTool } from "./tool-registry";
interface Props { interface Props {
message: ChatMessage; message: ChatMessage;
@@ -79,8 +71,10 @@ function AssistantBubble({ text }: { text: string }) {
function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) { function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) {
const [expanded, setExpanded] = useState(false); const [expanded, setExpanded] = useState(false);
const preview = extractToolPreview(tool, inputJson); const meta = getToolMeta(tool);
const Icon = iconForTool(tool); const preview = previewFromJson(tool, inputJson);
const Icon = meta.icon;
const showSqlPreview = (tool === "run_query" || tool === "explain_query") && preview;
return ( return (
<div className="ml-8 rounded-md border border-border/40 bg-muted/20"> <div className="ml-8 rounded-md border border-border/40 bg-muted/20">
@@ -91,17 +85,14 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
> >
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />} {expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
<Icon className="h-3 w-3" /> <Icon className="h-3 w-3" />
<span className="font-medium">{labelForTool(tool)}</span> <span className="font-medium">{meta.label}</span>
{preview && ( {preview && (
<span className="ml-1 truncate text-muted-foreground/70"> <span className="ml-1 truncate text-muted-foreground/70">{preview}</span>
{preview.slice(0, 80)}
{preview.length > 80 ? "…" : ""}
</span>
)} )}
</button> </button>
{expanded && ( {expanded && (
<div className="border-t border-border/30 p-2"> <div className="border-t border-border/30 p-2">
{tool === "run_query" && preview ? ( {showSqlPreview ? (
<pre className="overflow-x-auto whitespace-pre-wrap rounded bg-background/60 p-2 font-mono text-[11px]"> <pre className="overflow-x-auto whitespace-pre-wrap rounded bg-background/60 p-2 font-mono text-[11px]">
{preview} {preview}
</pre> </pre>
@@ -116,6 +107,15 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
); );
} }
function previewFromJson(tool: string, inputJson: string): string | null {
try {
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
return getToolMeta(tool).preview(parsed);
} catch {
return null;
}
}
function ToolResultBlock({ function ToolResultBlock({
tool, tool,
isError, isError,
@@ -132,87 +132,20 @@ function ToolResultBlock({
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs"> <div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" /> <AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
<div> <div>
<div className="font-medium text-destructive">{labelForTool(tool)} failed</div> <div className="font-medium text-destructive">{getToolMeta(tool).label} failed</div>
{text && <div className="mt-1 whitespace-pre-wrap text-muted-foreground">{text}</div>} {text && <div className="mt-1 whitespace-pre-wrap text-muted-foreground">{text}</div>}
</div> </div>
</div> </div>
); );
} }
// Legacy schema tool — keep a one-line indicator for old threads. // Tools that produce a QueryResult (rendered as a table): run_query, sample_data.
if (tool === "get_schema") { if (isQueryResultTool(tool) && result) {
return (
<div className="ml-8 flex items-center gap-2 rounded-md border border-border/40 bg-muted/20 px-2 py-1.5 text-xs text-muted-foreground">
<Database className="h-3 w-3" />
<span>Loaded schema context ({text?.length ?? 0} chars)</span>
</div>
);
}
// Text-only tools (chat v2/v3): list_databases, list_tables, get_columns, switch_database,
// remember, save_query, find_queries.
if (
tool === "list_databases" ||
tool === "list_tables" ||
tool === "get_columns" ||
tool === "switch_database" ||
tool === "remember" ||
tool === "save_query" ||
tool === "find_queries"
) {
return <TextToolResult tool={tool} text={text} />;
}
// make_chart — render chart inline using config from text + data from result.
if (tool === "make_chart") {
return <ChartToolResult text={text} result={result} />;
}
// run_query — full results table with Open-full / Export actions.
if (result) {
return <RunQueryResultBlock result={result} />; return <RunQueryResultBlock result={result} />;
} }
return null; // Everything else falls back to a collapsible text block.
} return <TextToolResult tool={tool} text={text} />;
function ChartToolResult({
text,
result,
}: {
text: string | null;
result: { columns: string[]; types: string[]; rows: unknown[][]; row_count: number; execution_time_ms: number } | null;
}) {
let config: ChartConfig | null = null;
try {
if (text) {
config = JSON.parse(text) as ChartConfig;
}
} catch {
config = null;
}
if (!config || !result) {
return (
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
<div>
<div className="font-medium text-destructive">Chart unavailable</div>
<div className="mt-1 text-muted-foreground">
The agent referenced a chart but the previous query result is not attached.
</div>
</div>
</div>
);
}
return (
<div className="ml-8">
<ChartPreview
config={config}
columns={result.columns}
rows={result.rows}
/>
</div>
);
} }
function RunQueryResultBlock({ function RunQueryResultBlock({
@@ -315,8 +248,10 @@ function RunQueryResultBlock({
} }
function TextToolResult({ tool, text }: { tool: string; text: string | null }) { function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
// Lazy preview: switch_database is short; everything else collapses by default.
const [expanded, setExpanded] = useState(tool === "switch_database"); const [expanded, setExpanded] = useState(tool === "switch_database");
const Icon = iconForTool(tool); const meta = getToolMeta(tool);
const Icon = meta.icon;
const lineCount = text ? text.split("\n").length : 0; const lineCount = text ? text.split("\n").length : 0;
return ( return (
@@ -328,7 +263,7 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
> >
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />} {expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
<Icon className="h-3 w-3" /> <Icon className="h-3 w-3" />
<span className="font-medium">{labelForTool(tool)}</span> <span className="font-medium">{meta.label}</span>
{text && ( {text && (
<span className="ml-1 text-muted-foreground/60"> <span className="ml-1 text-muted-foreground/60">
{lineCount} line{lineCount === 1 ? "" : "s"} {lineCount} line{lineCount === 1 ? "" : "s"}
@@ -346,93 +281,6 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
); );
} }
function labelForTool(tool: string): string {
switch (tool) {
case "run_query":
return "Run SQL";
case "list_databases":
return "List databases";
case "list_tables":
return "List tables";
case "get_columns":
return "Inspect columns";
case "switch_database":
return "Switch database";
case "remember":
return "Remember";
case "save_query":
return "Save query";
case "find_queries":
return "Find saved queries";
case "make_chart":
return "Make chart";
case "get_schema":
return "Load schema";
default:
return tool;
}
}
function iconForTool(tool: string) {
switch (tool) {
case "run_query":
return Wrench;
case "list_databases":
return Database;
case "list_tables":
return Layers;
case "get_columns":
return Columns;
case "switch_database":
return RefreshCw;
case "remember":
return StickyNote;
case "save_query":
return BookmarkPlus;
case "find_queries":
return Bookmark;
case "make_chart":
return BarChart3;
case "get_schema":
return Database;
default:
return Wrench;
}
}
function extractToolPreview(tool: string, inputJson: string): string | null {
try {
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
switch (tool) {
case "run_query":
return typeof parsed.sql === "string" ? parsed.sql : null;
case "list_tables":
return typeof parsed.database === "string" ? parsed.database : null;
case "switch_database":
return typeof parsed.database === "string" ? parsed.database : null;
case "get_columns":
return Array.isArray(parsed.tables) ? parsed.tables.join(", ") : null;
case "remember":
return typeof parsed.note === "string" ? parsed.note : null;
case "save_query":
return typeof parsed.name === "string" ? parsed.name : null;
case "find_queries":
return typeof parsed.text === "string" ? parsed.text : null;
case "make_chart": {
const t = typeof parsed.chart_type === "string" ? parsed.chart_type : null;
const x = typeof parsed.x === "string" ? parsed.x : null;
const y = typeof parsed.y === "string" ? parsed.y : null;
if (t && x && y) return `${t}: ${x}${y}`;
return null;
}
default:
return null;
}
} catch {
return null;
}
}
function prettyJson(s: string): string { function prettyJson(s: string): string {
try { try {
return JSON.stringify(JSON.parse(s), null, 2); return JSON.stringify(JSON.parse(s), null, 2);

View File

@@ -0,0 +1,107 @@
import {
Database,
Layers,
Columns,
RefreshCw,
Wrench,
StickyNote,
Bookmark,
BookmarkPlus,
Activity,
Shuffle,
GitBranch,
AlertTriangle,
} from "lucide-react";
import type { LucideIcon } from "lucide-react";
export type ToolMeta = {
icon: LucideIcon;
label: string;
preview: (parsed: Record<string, unknown>) => string | null;
};
const truncate = (s: unknown, n = 80): string | null => {
if (typeof s !== "string") return null;
return s.length > n ? `${s.slice(0, n)}` : s;
};
export const TOOLS: Record<string, ToolMeta> = {
list_databases: {
icon: Database,
label: "List databases",
preview: () => null,
},
list_tables: {
icon: Layers,
label: "List tables",
preview: (p) => (typeof p.database === "string" ? p.database : null),
},
get_columns: {
icon: Columns,
label: "Inspect columns",
preview: (p) => (Array.isArray(p.tables) ? (p.tables as string[]).join(", ") : null),
},
switch_database: {
icon: RefreshCw,
label: "Switch database",
preview: (p) => (typeof p.database === "string" ? p.database : null),
},
run_query: {
icon: Wrench,
label: "Run SQL",
preview: (p) => truncate(p.sql),
},
remember: {
icon: StickyNote,
label: "Remember",
preview: (p) => (typeof p.note === "string" ? p.note : null),
},
save_query: {
icon: BookmarkPlus,
label: "Save query",
preview: (p) => (typeof p.name === "string" ? p.name : null),
},
find_queries: {
icon: Bookmark,
label: "Find saved queries",
preview: (p) => (typeof p.text === "string" ? p.text : null),
},
profile_table: {
icon: Activity,
label: "Profile table",
preview: (p) => (typeof p.table === "string" ? p.table : null),
},
sample_data: {
icon: Shuffle,
label: "Sample rows",
preview: (p) => {
const t = typeof p.table === "string" ? p.table : "";
const limit = typeof p.limit === "number" ? p.limit : 50;
return t ? `${t} (${limit})` : null;
},
},
explain_query: {
icon: GitBranch,
label: "Explain query",
preview: (p) => truncate(p.sql),
},
detect_skew: {
icon: AlertTriangle,
label: "Detect skew",
preview: (p) => (typeof p.table === "string" ? p.table : null),
},
};
export function getToolMeta(tool: string): ToolMeta {
return (
TOOLS[tool] ?? {
icon: Wrench,
label: tool,
preview: () => null,
}
);
}
export function isQueryResultTool(tool: string): boolean {
return tool === "run_query" || tool === "sample_data";
}

View File

@@ -1,8 +1,7 @@
import { ResultsTable } from "./ResultsTable"; import { ResultsTable } from "./ResultsTable";
import { ResultsJsonView } from "./ResultsJsonView"; import { ResultsJsonView } from "./ResultsJsonView";
import type { QueryResult } from "@/types"; import type { QueryResult } from "@/types";
import { Loader2, AlertCircle, Sparkles, Wand2 } from "lucide-react"; import { Loader2, AlertCircle } from "lucide-react";
import { Button } from "@/components/ui/button";
interface Props { interface Props {
result?: QueryResult | null; result?: QueryResult | null;
@@ -15,10 +14,6 @@ interface Props {
value: unknown value: unknown
) => void; ) => void;
highlightedCells?: Set<string>; highlightedCells?: Set<string>;
aiExplanation?: string | null;
isAiLoading?: boolean;
onExplainError?: () => void;
onFixError?: () => void;
} }
export function ResultsPanel({ export function ResultsPanel({
@@ -28,10 +23,6 @@ export function ResultsPanel({
viewMode = "table", viewMode = "table",
onCellDoubleClick, onCellDoubleClick,
highlightedCells, highlightedCells,
aiExplanation,
isAiLoading,
onExplainError,
onFixError,
}: Props) { }: Props) {
if (isLoading) { if (isLoading) {
return ( return (
@@ -42,22 +33,6 @@ export function ResultsPanel({
); );
} }
if (aiExplanation) {
return (
<div className="h-full select-text overflow-auto p-4">
<div className="rounded-md border bg-muted/30 p-4">
<div className="mb-2 flex items-center gap-2 text-xs font-medium text-muted-foreground">
<Sparkles className="h-3.5 w-3.5" />
AI Explanation
</div>
<pre className="whitespace-pre-wrap font-sans text-sm leading-relaxed text-foreground">
{aiExplanation}
</pre>
</div>
</div>
);
}
if (error) { if (error) {
return ( return (
<div className="flex h-full select-text flex-col items-center justify-center gap-3 p-4"> <div className="flex h-full select-text flex-col items-center justify-center gap-3 p-4">
@@ -65,42 +40,6 @@ export function ResultsPanel({
<AlertCircle className="mt-0.5 h-4 w-4 shrink-0" /> <AlertCircle className="mt-0.5 h-4 w-4 shrink-0" />
<pre className="whitespace-pre-wrap font-mono text-xs">{error}</pre> <pre className="whitespace-pre-wrap font-mono text-xs">{error}</pre>
</div> </div>
{(onExplainError || onFixError) && (
<div className="flex items-center gap-2">
{onExplainError && (
<Button
size="sm"
variant="outline"
className="h-7 gap-1.5 text-xs"
onClick={onExplainError}
disabled={isAiLoading}
>
{isAiLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<Sparkles className="h-3 w-3" />
)}
Explain
</Button>
)}
{onFixError && (
<Button
size="sm"
variant="outline"
className="h-7 gap-1.5 text-xs"
onClick={onFixError}
disabled={isAiLoading}
>
{isAiLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<Wand2 className="h-3 w-3" />
)}
Fix with AI
</Button>
)}
</div>
)}
</div> </div>
); );
} }

View File

@@ -27,6 +27,7 @@ import type { AiProvider, AppSettings } from "@/types";
const SUPPORTED_AI_PROVIDERS: { value: AiProvider; label: string }[] = [ const SUPPORTED_AI_PROVIDERS: { value: AiProvider; label: string }[] = [
{ value: "ollama", label: "Ollama (local)" }, { value: "ollama", label: "Ollama (local)" },
{ value: "fireworks", label: "Fireworks AI" }, { value: "fireworks", label: "Fireworks AI" },
{ value: "openrouter", label: "OpenRouter" },
]; ];
interface Props { interface Props {
@@ -50,6 +51,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
const [aiProvider, setAiProvider] = useState<AiProvider>("ollama"); const [aiProvider, setAiProvider] = useState<AiProvider>("ollama");
const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434"); const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434");
const [fireworksApiKey, setFireworksApiKey] = useState(""); const [fireworksApiKey, setFireworksApiKey] = useState("");
const [openrouterApiKey, setOpenrouterApiKey] = useState("");
const [aiModel, setAiModel] = useState(""); const [aiModel, setAiModel] = useState("");
const [copied, setCopied] = useState(false); const [copied, setCopied] = useState(false);
@@ -70,10 +72,14 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
if (aiSettings) { if (aiSettings) {
// Legacy openai/anthropic values aren't user-selectable here — fall back to ollama. // Legacy openai/anthropic values aren't user-selectable here — fall back to ollama.
setAiProvider( setAiProvider(
aiSettings.provider === "fireworks" ? "fireworks" : "ollama" aiSettings.provider === "fireworks" ||
aiSettings.provider === "openrouter"
? aiSettings.provider
: "ollama"
); );
setOllamaUrl(aiSettings.ollama_url); setOllamaUrl(aiSettings.ollama_url);
setFireworksApiKey(aiSettings.fireworks_api_key ?? ""); setFireworksApiKey(aiSettings.fireworks_api_key ?? "");
setOpenrouterApiKey(aiSettings.openrouter_api_key ?? "");
setAiModel(aiSettings.model); setAiModel(aiSettings.model);
} }
} }
@@ -115,6 +121,10 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
aiProvider === "fireworks" aiProvider === "fireworks"
? fireworksApiKey.trim() || undefined ? fireworksApiKey.trim() || undefined
: aiSettings?.fireworks_api_key, : aiSettings?.fireworks_api_key,
openrouter_api_key:
aiProvider === "openrouter"
? openrouterApiKey.trim() || undefined
: aiSettings?.openrouter_api_key,
model: aiModel, model: aiModel,
}, },
{ {
@@ -167,7 +177,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
<span <span
className={`inline-block h-2 w-2 rounded-full ${ className={`inline-block h-2 w-2 rounded-full ${
mcpStatus?.running mcpStatus?.running
? "bg-green-500" ? "bg-success ring-2 ring-success/25"
: "bg-muted-foreground/30" : "bg-muted-foreground/30"
}`} }`}
/> />
@@ -189,7 +199,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
title="Copy endpoint URL" title="Copy endpoint URL"
> >
{copied ? ( {copied ? (
<Check className="h-3 w-3 text-green-500" /> <Check className="h-3 w-3 text-success" />
) : ( ) : (
<Copy className="h-3 w-3" /> <Copy className="h-3 w-3" />
)} )}
@@ -229,6 +239,8 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
onOllamaUrlChange={setOllamaUrl} onOllamaUrlChange={setOllamaUrl}
fireworksApiKey={fireworksApiKey} fireworksApiKey={fireworksApiKey}
onFireworksApiKeyChange={setFireworksApiKey} onFireworksApiKeyChange={setFireworksApiKey}
openrouterApiKey={openrouterApiKey}
onOpenRouterApiKeyChange={setOpenrouterApiKey}
model={aiModel} model={aiModel}
onModelChange={setAiModel} onModelChange={setAiModel}
/> />

View File

@@ -13,7 +13,7 @@ import { useCompletionSchema } from "@/hooks/use-completion-schema";
import { useConnections } from "@/hooks/use-connections"; import { useConnections } from "@/hooks/use-connections";
import { useAppStore } from "@/stores/app-store"; import { useAppStore } from "@/stores/app-store";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles, BrainCircuit } from "lucide-react"; import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces } from "lucide-react";
import { format as formatSql } from "sql-formatter"; import { format as formatSql } from "sql-formatter";
import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog"; import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog";
import { import {
@@ -25,8 +25,6 @@ import {
import { exportCsv, exportJson } from "@/lib/tauri"; import { exportCsv, exportJson } from "@/lib/tauri";
import { save } from "@tauri-apps/plugin-dialog"; import { save } from "@tauri-apps/plugin-dialog";
import { toast } from "sonner"; import { toast } from "sonner";
import { AiBar } from "@/components/ai/AiBar";
import { useExplainSql, useFixSqlError } from "@/hooks/use-ai";
import type { QueryResult, ExplainResult } from "@/types"; import type { QueryResult, ExplainResult } from "@/types";
interface Props { interface Props {
@@ -53,12 +51,8 @@ export function WorkspacePanel({
const [resultView, setResultView] = useState<"results" | "explain">("results"); const [resultView, setResultView] = useState<"results" | "explain">("results");
const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table"); const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table");
const [saveDialogOpen, setSaveDialogOpen] = useState(false); const [saveDialogOpen, setSaveDialogOpen] = useState(false);
const [aiBarOpen, setAiBarOpen] = useState(false);
const [aiExplanation, setAiExplanation] = useState<string | null>(null);
const queryMutation = useQueryExecution(); const queryMutation = useQueryExecution();
const explainMutation = useExplainSql();
const fixMutation = useFixSqlError();
const addHistoryMutation = useAddHistory(); const addHistoryMutation = useAddHistory();
const { data: connections } = useConnections(); const { data: connections } = useConnections();
const { data: completionSchema } = useCompletionSchema(connectionId); const { data: completionSchema } = useCompletionSchema(connectionId);
@@ -102,7 +96,6 @@ export function WorkspacePanel({
if (!sqlValue.trim() || !connectionId) return; if (!sqlValue.trim() || !connectionId) return;
setError(null); setError(null);
setExplainData(null); setExplainData(null);
setAiExplanation(null);
setResultView("results"); setResultView("results");
queryMutation.mutate( queryMutation.mutate(
{ connectionId, sql: sqlValue }, { connectionId, sql: sqlValue },
@@ -196,60 +189,6 @@ export function WorkspacePanel({
[result] [result]
); );
const isAiLoading = explainMutation.isPending || fixMutation.isPending;
const handleAiExplain = useCallback(() => {
if (!sqlValue.trim() || !connectionId) return;
setAiExplanation(null);
setResultView("results");
explainMutation.mutate(
{ connectionId, sql: sqlValue },
{
onSuccess: (explanation) => {
setAiExplanation(explanation);
},
onError: (err) => {
toast.error("AI Explain failed", { description: String(err) });
},
}
);
}, [connectionId, sqlValue, explainMutation]);
const handleExplainError = useCallback(() => {
if (!sqlValue.trim() || !connectionId || !error) return;
setAiExplanation(null);
explainMutation.mutate(
{ connectionId, sql: `${sqlValue}\n\n-- Error: ${error}` },
{
onSuccess: (explanation) => {
setAiExplanation(explanation);
},
onError: (err) => {
toast.error("AI Explain failed", { description: String(err) });
},
}
);
}, [connectionId, sqlValue, error, explainMutation]);
const handleFixError = useCallback(() => {
if (!sqlValue.trim() || !connectionId || !error) return;
fixMutation.mutate(
{ connectionId, sql: sqlValue, errorMessage: error },
{
onSuccess: (fixedSql) => {
setSqlValue(fixedSql);
onSqlChange?.(fixedSql);
setError(null);
setAiExplanation(null);
toast.success("SQL replaced by AI suggestion");
},
onError: (err) => {
toast.error("AI Fix failed", { description: String(err) });
},
}
);
}, [connectionId, sqlValue, error, fixMutation, onSqlChange]);
return ( return (
<> <>
<ResizablePanelGroup orientation="vertical"> <ResizablePanelGroup orientation="vertical">
@@ -308,35 +247,6 @@ export function WorkspacePanel({
Save Save
</Button> </Button>
<div className="mx-1 h-3.5 w-px bg-border/40" />
{/* AI actions group — purple-branded */}
<Button
size="xs"
variant={aiBarOpen ? "secondary" : "ghost"}
className={`gap-1 text-[11px] ${aiBarOpen ? "text-tusk-purple" : ""}`}
onClick={() => setAiBarOpen(!aiBarOpen)}
title="AI SQL Generator"
>
<Sparkles className={`h-3 w-3 ${aiBarOpen ? "tusk-ai-icon" : ""}`} />
AI
</Button>
<Button
size="xs"
variant="ghost"
className="gap-1 text-[11px]"
onClick={handleAiExplain}
disabled={isAiLoading || !sqlValue.trim()}
title="Explain query with AI"
>
{isAiLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<BrainCircuit className="h-3 w-3" />
)}
AI Explain
</Button>
{result && result.columns.length > 0 && ( {result && result.columns.length > 0 && (
<> <>
<div className="mx-1 h-3.5 w-px bg-border/40" /> <div className="mx-1 h-3.5 w-px bg-border/40" />
@@ -369,23 +279,12 @@ export function WorkspacePanel({
{"\u2318"}Enter {"\u2318"}Enter
</span> </span>
{isReadOnly && ( {isReadOnly && (
<span className="ml-2 flex items-center gap-1 rounded-sm bg-amber-500/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-amber-500"> <span className="ml-2 flex items-center gap-1 rounded-sm bg-warning/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-warning">
<Lock className="h-2.5 w-2.5" /> <Lock className="h-2.5 w-2.5" />
READ READ
</span> </span>
)} )}
</div> </div>
{aiBarOpen && (
<AiBar
connectionId={connectionId}
onSqlGenerated={(sql) => {
setSqlValue(sql);
onSqlChange?.(sql);
}}
onClose={() => setAiBarOpen(false)}
onExecute={handleExecute}
/>
)}
<div className="min-h-0 flex-1"> <div className="min-h-0 flex-1">
<SqlEditor <SqlEditor
value={sqlValue} value={sqlValue}
@@ -400,7 +299,7 @@ export function WorkspacePanel({
<ResizableHandle withHandle /> <ResizableHandle withHandle />
<ResizablePanel id="results" defaultSize="60%" minSize="15%"> <ResizablePanel id="results" defaultSize="60%" minSize="15%">
<div className="flex h-full flex-col overflow-hidden"> <div className="flex h-full flex-col overflow-hidden">
{(explainData || result || error || aiExplanation) && ( {(explainData || result || error) && (
<div className="flex shrink-0 items-center border-b border-border/40 text-xs"> <div className="flex shrink-0 items-center border-b border-border/40 text-xs">
<button <button
className={`relative px-3 py-1.5 font-medium transition-colors ${ className={`relative px-3 py-1.5 font-medium transition-colors ${
@@ -469,10 +368,6 @@ export function WorkspacePanel({
error={error} error={error}
isLoading={queryMutation.isPending && resultView === "results"} isLoading={queryMutation.isPending && resultView === "results"}
viewMode={resultViewMode} viewMode={resultViewMode}
aiExplanation={aiExplanation}
isAiLoading={isAiLoading}
onExplainError={error ? handleExplainError : undefined}
onFixError={error ? handleFixError : undefined}
/> />
)} )}
</div> </div>

View File

@@ -4,9 +4,7 @@ import {
saveAiSettings, saveAiSettings,
listOllamaModels, listOllamaModels,
listFireworksModels, listFireworksModels,
generateSql, listOpenRouterModels,
explainSql,
fixSqlError,
} from "@/lib/tauri"; } from "@/lib/tauri";
import type { AiSettings } from "@/types"; import type { AiSettings } from "@/types";
@@ -47,40 +45,12 @@ export function useFireworksModels(apiKey: string | undefined) {
}); });
} }
export function useGenerateSql() { export function useOpenRouterModels(apiKey: string | undefined) {
return useMutation({ return useQuery({
mutationFn: ({ queryKey: ["openrouter-models", apiKey],
connectionId, queryFn: () => listOpenRouterModels(apiKey!),
prompt, enabled: !!apiKey && apiKey.trim().length > 0,
}: { retry: false,
connectionId: string; staleTime: 60_000,
prompt: string;
}) => generateSql(connectionId, prompt),
});
}
export function useExplainSql() {
return useMutation({
mutationFn: ({
connectionId,
sql,
}: {
connectionId: string;
sql: string;
}) => explainSql(connectionId, sql),
});
}
export function useFixSqlError() {
return useMutation({
mutationFn: ({
connectionId,
sql,
errorMessage,
}: {
connectionId: string;
sql: string;
errorMessage: string;
}) => fixSqlError(connectionId, sql, errorMessage),
}); });
} }

View File

@@ -214,14 +214,8 @@ export const listOllamaModels = (ollamaUrl: string) =>
export const listFireworksModels = (apiKey: string) => export const listFireworksModels = (apiKey: string) =>
invoke<OllamaModel[]>("list_fireworks_models", { apiKey }); invoke<OllamaModel[]>("list_fireworks_models", { apiKey });
export const generateSql = (connectionId: string, prompt: string) => export const listOpenRouterModels = (apiKey: string) =>
invoke<string>("generate_sql", { connectionId, prompt }); invoke<OllamaModel[]>("list_openrouter_models", { apiKey });
export const explainSql = (connectionId: string, sql: string) =>
invoke<string>("explain_sql", { connectionId, sql });
export const fixSqlError = (connectionId: string, sql: string, errorMessage: string) =>
invoke<string>("fix_sql_error", { connectionId, sql, errorMessage });
export const chatSend = (connectionId: string, messages: ChatMessage[]) => export const chatSend = (connectionId: string, messages: ChatMessage[]) =>
invoke<ChatTurnResult>("chat_send", { connectionId, messages }); invoke<ChatTurnResult>("chat_send", { connectionId, messages });

View File

@@ -134,14 +134,13 @@ export interface SavedQuery {
created_at: string; created_at: string;
} }
export type AiProvider = "ollama" | "openai" | "anthropic" | "fireworks"; export type AiProvider = "ollama" | "fireworks" | "openrouter";
export interface AiSettings { export interface AiSettings {
provider: AiProvider; provider: AiProvider;
ollama_url: string; ollama_url: string;
openai_api_key?: string;
anthropic_api_key?: string;
fireworks_api_key?: string; fireworks_api_key?: string;
openrouter_api_key?: string;
model: string; model: string;
} }
@@ -216,14 +215,3 @@ export interface ChatTurnResult {
messages: ChatMessage[]; messages: ChatMessage[];
usage: ContextUsage; usage: ContextUsage;
} }
export type ChartType = "bar" | "line" | "area" | "pie";
export interface ChartConfig {
chart_type: ChartType;
x: string;
y: string;
group?: string | null;
title?: string | null;
orientation?: string | null;
}