refactor(ai): consolidate AI around chat tool-calling; add OpenRouter
- rework chat backend (chat.rs, chat_tools.rs, ai.rs, models, state) around tool calls - add OpenRouter provider alongside Ollama/Fireworks in settings - drop inline AiBar, ResultsPanel explain/fix UI and ChartPreview in favour of the chat panel - add frontend chat tool-registry
This commit is contained in:
@@ -15,6 +15,13 @@ use tauri::{AppHandle, Manager, State};
|
||||
const MAX_RETRIES: u32 = 2;
|
||||
const RETRY_DELAY_MS: u64 = 1000;
|
||||
const FIREWORKS_BASE_URL: &str = "https://api.fireworks.ai/inference/v1";
|
||||
const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
|
||||
/// Optional attribution headers OpenRouter uses for its public app leaderboard.
|
||||
/// Harmless to send on every request; ignored by other OpenAI-compatible APIs.
|
||||
const OPENROUTER_HEADERS: &[(&str, &str)] = &[
|
||||
("HTTP-Referer", "https://github.com/codelab/tusk"),
|
||||
("X-Title", "Tusk"),
|
||||
];
|
||||
|
||||
fn http_client() -> &'static reqwest::Client {
|
||||
use std::sync::LazyLock;
|
||||
@@ -89,32 +96,51 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaMode
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_fireworks_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
|
||||
let key = api_key.trim();
|
||||
if key.is_empty() {
|
||||
return Err(TuskError::Ai("Fireworks API key required".to_string()));
|
||||
list_openai_compatible_models(FIREWORKS_BASE_URL, &api_key, "Fireworks", &[]).await
|
||||
}
|
||||
|
||||
let url = format!("{}/models", FIREWORKS_BASE_URL);
|
||||
let resp = http_client()
|
||||
.get(&url)
|
||||
.bearer_auth(key)
|
||||
#[tauri::command]
|
||||
pub async fn list_openrouter_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
|
||||
list_openai_compatible_models(OPENROUTER_BASE_URL, &api_key, "OpenRouter", OPENROUTER_HEADERS)
|
||||
.await
|
||||
}
|
||||
|
||||
/// List the available chat models for any OpenAI-compatible provider via its
|
||||
/// `GET {base_url}/models` endpoint. `extra_headers` carries provider-specific
|
||||
/// attribution headers (OpenRouter recommends `HTTP-Referer`/`X-Title`).
|
||||
async fn list_openai_compatible_models(
|
||||
base_url: &str,
|
||||
api_key: &str,
|
||||
provider_label: &str,
|
||||
extra_headers: &[(&str, &str)],
|
||||
) -> TuskResult<Vec<OllamaModel>> {
|
||||
let key = api_key.trim();
|
||||
if key.is_empty() {
|
||||
return Err(TuskError::Ai(format!("{} API key required", provider_label)));
|
||||
}
|
||||
|
||||
let url = format!("{}/models", base_url);
|
||||
let mut req = http_client().get(&url).bearer_auth(key);
|
||||
for (name, value) in extra_headers {
|
||||
req = req.header(*name, *value);
|
||||
}
|
||||
let resp = req
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| TuskError::Ai(format!("Cannot reach Fireworks: {}", e)))?;
|
||||
.map_err(|e| TuskError::Ai(format!("Cannot reach {}: {}", provider_label, e)))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(TuskError::Ai(format!(
|
||||
"Fireworks error ({}): {}",
|
||||
status, body
|
||||
"{} error ({}): {}",
|
||||
provider_label, status, body
|
||||
)));
|
||||
}
|
||||
|
||||
let parsed: FireworksModelsResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| TuskError::Ai(format!("Failed to parse Fireworks models list: {}", e)))?;
|
||||
let parsed: FireworksModelsResponse = resp.json().await.map_err(|e| {
|
||||
TuskError::Ai(format!("Failed to parse {} models list: {}", provider_label, e))
|
||||
})?;
|
||||
|
||||
Ok(parsed
|
||||
.data
|
||||
@@ -180,33 +206,8 @@ pub(crate) async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskR
|
||||
Ok(settings)
|
||||
}
|
||||
|
||||
async fn call_chat_simple(
|
||||
app: &AppHandle,
|
||||
state: &AppState,
|
||||
system_prompt: String,
|
||||
user_content: String,
|
||||
) -> TuskResult<String> {
|
||||
call_chat_messages(
|
||||
app,
|
||||
state,
|
||||
vec![
|
||||
OllamaChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: system_prompt,
|
||||
},
|
||||
OllamaChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: user_content,
|
||||
},
|
||||
],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Provider-agnostic chat-completions dispatcher used by every LLM-driven feature
|
||||
/// (chat agent, generate_sql, explain_sql, fix_sql_error). Returns the model's
|
||||
/// raw text content.
|
||||
/// Provider-agnostic chat-completions dispatcher used by the chat agent.
|
||||
/// Returns the model's raw text content.
|
||||
pub(crate) async fn call_chat_messages(
|
||||
app: &AppHandle,
|
||||
state: &AppState,
|
||||
@@ -223,12 +224,47 @@ pub(crate) async fn call_chat_messages(
|
||||
|
||||
match settings.provider {
|
||||
AiProvider::Ollama => call_ollama(&settings, messages, format).await,
|
||||
AiProvider::Fireworks => call_fireworks(&settings, messages, format).await,
|
||||
AiProvider::OpenAi | AiProvider::Anthropic => Err(TuskError::Ai(format!(
|
||||
"Provider {:?} not implemented yet",
|
||||
settings.provider
|
||||
))),
|
||||
AiProvider::Fireworks => {
|
||||
let api_key = require_api_key(
|
||||
settings.fireworks_api_key.as_deref(),
|
||||
"Fireworks API key not set. Open AI settings to add it.",
|
||||
)?;
|
||||
call_openai_compatible(
|
||||
&settings,
|
||||
FIREWORKS_BASE_URL,
|
||||
&api_key,
|
||||
"Fireworks",
|
||||
&[],
|
||||
messages,
|
||||
format,
|
||||
)
|
||||
.await
|
||||
}
|
||||
AiProvider::OpenRouter => {
|
||||
let api_key = require_api_key(
|
||||
settings.openrouter_api_key.as_deref(),
|
||||
"OpenRouter API key not set. Open AI settings to add it.",
|
||||
)?;
|
||||
call_openai_compatible(
|
||||
&settings,
|
||||
OPENROUTER_BASE_URL,
|
||||
&api_key,
|
||||
"OpenRouter",
|
||||
OPENROUTER_HEADERS,
|
||||
messages,
|
||||
format,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trim and validate an optional API key, returning a user-facing error when
|
||||
/// it's missing or blank.
|
||||
fn require_api_key(key: Option<&str>, missing_msg: &str) -> TuskResult<String> {
|
||||
key.map(|k| k.trim().to_string())
|
||||
.filter(|k| !k.is_empty())
|
||||
.ok_or_else(|| TuskError::Ai(missing_msg.to_string()))
|
||||
}
|
||||
|
||||
async fn call_ollama(
|
||||
@@ -277,21 +313,18 @@ async fn call_ollama(
|
||||
.await
|
||||
}
|
||||
|
||||
async fn call_fireworks(
|
||||
/// Chat-completions call for any OpenAI-compatible provider (Fireworks,
|
||||
/// OpenRouter). `extra_headers` carries provider-specific attribution headers.
|
||||
async fn call_openai_compatible(
|
||||
settings: &AiSettings,
|
||||
base_url: &str,
|
||||
api_key: &str,
|
||||
provider_label: &str,
|
||||
extra_headers: &[(&str, &str)],
|
||||
messages: Vec<OllamaChatMessage>,
|
||||
format: Option<String>,
|
||||
) -> TuskResult<String> {
|
||||
let api_key = settings
|
||||
.fireworks_api_key
|
||||
.clone()
|
||||
.map(|k| k.trim().to_string())
|
||||
.filter(|k| !k.is_empty())
|
||||
.ok_or_else(|| {
|
||||
TuskError::Ai("Fireworks API key not set. Open AI settings to add it.".to_string())
|
||||
})?;
|
||||
|
||||
let url = format!("{}/chat/completions", FIREWORKS_BASE_URL);
|
||||
let url = format!("{}/chat/completions", base_url);
|
||||
let response_format = format.as_deref().map(|f| FireworksResponseFormat {
|
||||
kind: if f == "json" {
|
||||
"json_object".to_string()
|
||||
@@ -307,19 +340,22 @@ async fn call_fireworks(
|
||||
response_format,
|
||||
};
|
||||
|
||||
call_ai_with_retry(settings, "Fireworks request", || {
|
||||
let operation = format!("{} request", provider_label);
|
||||
call_ai_with_retry(settings, &operation, || {
|
||||
let url = url.clone();
|
||||
let request = request.clone();
|
||||
let api_key = api_key.clone();
|
||||
let api_key = api_key.to_string();
|
||||
async move {
|
||||
let resp = http_client()
|
||||
.post(&url)
|
||||
.bearer_auth(&api_key)
|
||||
let mut req = http_client().post(&url).bearer_auth(&api_key);
|
||||
for (name, value) in extra_headers {
|
||||
req = req.header(*name, *value);
|
||||
}
|
||||
let resp = req
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
TuskError::Ai(format!("Cannot reach Fireworks at {}: {}", url, e))
|
||||
TuskError::Ai(format!("Cannot reach {} at {}: {}", provider_label, url, e))
|
||||
})?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
@@ -339,13 +375,13 @@ async fn call_fireworks(
|
||||
));
|
||||
}
|
||||
return Err(TuskError::Ai(format!(
|
||||
"Fireworks error ({}): {}",
|
||||
status, body
|
||||
"{} error ({}): {}",
|
||||
provider_label, status, body
|
||||
)));
|
||||
}
|
||||
|
||||
let parsed: FireworksChatResponse = resp.json().await.map_err(|e| {
|
||||
TuskError::Ai(format!("Failed to parse Fireworks response: {}", e))
|
||||
TuskError::Ai(format!("Failed to parse {} response: {}", provider_label, e))
|
||||
})?;
|
||||
|
||||
parsed
|
||||
@@ -353,177 +389,14 @@ async fn call_fireworks(
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| TuskError::Ai("Fireworks returned no choices".to_string()))
|
||||
.ok_or_else(|| {
|
||||
TuskError::Ai(format!("{} returned no choices", provider_label))
|
||||
})
|
||||
}
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SQL generation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn generate_sql(
|
||||
app: AppHandle,
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
prompt: String,
|
||||
) -> TuskResult<String> {
|
||||
let schema_text = build_schema_context(&state, &connection_id).await?;
|
||||
|
||||
let system_prompt = format!(
|
||||
"You are an expert PostgreSQL query generator. You receive a database schema and a natural \
|
||||
language request. Output ONLY a valid, executable PostgreSQL SQL query.\n\
|
||||
\n\
|
||||
OUTPUT FORMAT:\n\
|
||||
- Raw SQL only. No explanations, no markdown code fences (```), no comments, no preamble.\n\
|
||||
- The output must be directly executable in psql.\n\
|
||||
- For complex queries use readable formatting with line breaks and indentation.\n\
|
||||
\n\
|
||||
CRITICAL RULES:\n\
|
||||
1. ONLY reference tables and columns that exist in the schema. Never invent names.\n\
|
||||
2. Use the FOREIGN KEY information to determine correct JOIN conditions.\n\
|
||||
3. Use LEFT JOIN when the FK column is nullable or the relationship is optional; \
|
||||
INNER JOIN when both sides must exist.\n\
|
||||
4. Every non-aggregated column in SELECT must appear in GROUP BY.\n\
|
||||
5. Use COALESCE for nullable columns in aggregations: COALESCE(SUM(x), 0).\n\
|
||||
6. For enum columns, use ONLY the values listed in the ENUM TYPES section.\n\
|
||||
7. Use IS NULL / IS NOT NULL for null checks — never = NULL or != NULL.\n\
|
||||
8. Add LIMIT when the user asks for \"top N\", \"first N\", \"latest N\", etc.\n\
|
||||
9. Qualify column names with table alias when the query involves multiple tables.\n\
|
||||
\n\
|
||||
SEMANTIC RULES (very important):\n\
|
||||
- When a table has both actual_* and planned_* columns (e.g. actual_start vs planned_start), \
|
||||
they represent DIFFERENT concepts: planned = future estimate, actual = what really happened. \
|
||||
NEVER mix them with COALESCE unless the user explicitly requests a fallback.\n\
|
||||
- For time-based calculations involving real events (\"how long did X take\", \"average time between\"), \
|
||||
use ONLY actual/factual timestamps (actual_*, started_at, completed_at, ended_at). \
|
||||
Filter out NULL values with WHERE instead of falling back to planned timestamps.\n\
|
||||
- Planned timestamps (planned_*, scheduled_*, estimated_*) should ONLY be used when the user \
|
||||
asks about plans, schedules, SLA, or compares plan vs fact.\n\
|
||||
- When computing durations or averages, always filter out rows where any involved timestamp \
|
||||
is NULL rather than substituting with unrelated defaults.\n\
|
||||
- Pay attention to column descriptions/comments in the schema — they reveal business semantics \
|
||||
that are critical for correct queries.\n\
|
||||
\n\
|
||||
TYPE RULES:\n\
|
||||
- timestamp - timestamp = interval. For seconds: EXTRACT(EPOCH FROM (ts1 - ts2)).\n\
|
||||
- interval cannot be cast to numeric directly; use EXTRACT(EPOCH FROM interval).\n\
|
||||
- UNION/UNION ALL requires matching column count and compatible types; cast enums to text.\n\
|
||||
- Use ::type for PostgreSQL-style casts.\n\
|
||||
- For array columns use ANY, ALL, @>, <@ operators.\n\
|
||||
- For JSONB columns use ->, ->>, #>, jsonb_extract_path.\n\
|
||||
\n\
|
||||
COMMON PATTERNS:\n\
|
||||
- FIRST/LAST per group: to find MIN(started_at) per trip, use \
|
||||
\"SELECT trip_id, MIN(started_at) FROM t GROUP BY trip_id\". \
|
||||
NEVER put the aggregated column (started_at) into GROUP BY — that defeats the aggregation \
|
||||
and returns every row separately instead of one per group.\n\
|
||||
- TOP-1 per group with extra columns: use DISTINCT ON (group_col) ... ORDER BY group_col, sort_col \
|
||||
or a subquery with ROW_NUMBER() OVER (PARTITION BY group_col ORDER BY sort_col) = 1.\n\
|
||||
- For \"time from A to B\" calculations, ensure both timestamps are NOT NULL with WHERE filters; \
|
||||
never use COALESCE to mix planned and actual timestamps.\n\
|
||||
\n\
|
||||
BEST PRACTICES:\n\
|
||||
- Use ILIKE for case-insensitive text search, LIKE for case-sensitive.\n\
|
||||
- Use EXISTS instead of IN for subquery existence checks.\n\
|
||||
- Use CTE (WITH ... AS) for complex multi-step logic.\n\
|
||||
- Use window functions (ROW_NUMBER, RANK, LAG, LEAD, SUM OVER) for ranking and running totals.\n\
|
||||
- Use date_trunc('period', column) for time-based grouping.\n\
|
||||
- Use generate_series() for creating ranges.\n\
|
||||
- Use string_agg(col, ', ') for concatenating grouped values.\n\
|
||||
- Use FILTER (WHERE ...) for conditional aggregation instead of CASE inside aggregate.\n\
|
||||
\n\
|
||||
{}\n",
|
||||
schema_text
|
||||
);
|
||||
|
||||
let raw = call_chat_simple(&app, &state, system_prompt, prompt).await?;
|
||||
Ok(clean_sql_response(&raw))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SQL explanation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn explain_sql(
|
||||
app: AppHandle,
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
sql: String,
|
||||
) -> TuskResult<String> {
|
||||
let schema_text = build_schema_context(&state, &connection_id).await?;
|
||||
|
||||
let system_prompt = format!(
|
||||
"You are a PostgreSQL expert. Explain the given SQL query clearly and concisely.\n\
|
||||
\n\
|
||||
Structure your explanation as:\n\
|
||||
1. **Summary** — one sentence describing what the query returns in business terms.\n\
|
||||
2. **Step-by-step breakdown** — explain tables accessed, joins, filters, aggregations, \
|
||||
subqueries, and sorting. Use bullet points.\n\
|
||||
3. **Notes** — mention edge cases, potential issues, or performance concerns if any.\n\
|
||||
\n\
|
||||
Use the database schema below to understand table relationships and column meanings.\n\
|
||||
Keep the explanation short; avoid restating the SQL verbatim.\n\
|
||||
\n\
|
||||
IMPORTANT: If you notice semantic issues (e.g. mixing planned_* and actual_* timestamps \
|
||||
with COALESCE, comparing unrelated columns, missing NULL filters on nullable timestamps), \
|
||||
mention them in the Notes section as potential problems.\n\
|
||||
\n\
|
||||
{}\n",
|
||||
schema_text
|
||||
);
|
||||
|
||||
call_chat_simple(&app, &state, system_prompt, sql).await
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SQL error fixing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn fix_sql_error(
|
||||
app: AppHandle,
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
sql: String,
|
||||
error_message: String,
|
||||
) -> TuskResult<String> {
|
||||
let schema_text = build_schema_context(&state, &connection_id).await?;
|
||||
|
||||
let system_prompt = format!(
|
||||
"You are a PostgreSQL expert debugger. You receive a SQL query and the error it produced. \
|
||||
Fix the query so it executes correctly.\n\
|
||||
\n\
|
||||
OUTPUT FORMAT:\n\
|
||||
- Raw SQL only. No explanations, no markdown code fences (```), no comments.\n\
|
||||
- The output must be directly executable.\n\
|
||||
\n\
|
||||
DIAGNOSTIC CHECKLIST:\n\
|
||||
- Column/table does not exist → check the schema for correct names and spelling.\n\
|
||||
- Column is ambiguous → qualify with table name or alias.\n\
|
||||
- Must appear in GROUP BY → add missing non-aggregated columns to GROUP BY.\n\
|
||||
- Type mismatch → add appropriate casts (::text, ::integer, etc.).\n\
|
||||
- Permission denied → wrap in a read-only transaction if needed.\n\
|
||||
- Syntax error → correct PostgreSQL syntax (check commas, parentheses, keywords).\n\
|
||||
- Subquery returns more than one row → use IN, ANY, or add LIMIT 1.\n\
|
||||
- Division by zero → wrap divisor with NULLIF(x, 0).\n\
|
||||
\n\
|
||||
ONLY use tables and columns from the schema below. Never invent names.\n\
|
||||
Preserve the original intent of the query; change only what is necessary to fix the error.\n\
|
||||
\n\
|
||||
{}\n",
|
||||
schema_text
|
||||
);
|
||||
|
||||
let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);
|
||||
|
||||
let raw = call_chat_simple(&app, &state, system_prompt, user_content).await?;
|
||||
Ok(clean_sql_response(&raw))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Lite overview builder (chat v2)
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -735,231 +608,6 @@ async fn build_overview_clickhouse(state: &AppState, connection_id: &str) -> Tus
|
||||
Ok(out.join("\n"))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Full schema context builder (legacy — used by generate_sql/explain_sql/fix_sql_error)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub(crate) async fn build_schema_context(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
) -> TuskResult<String> {
|
||||
// Check cache first
|
||||
if let Some(cached) = state.get_schema_cache(connection_id).await {
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return build_clickhouse_schema_context(state, connection_id).await;
|
||||
}
|
||||
|
||||
let is_greenplum = matches!(flavor, DbFlavor::Greenplum);
|
||||
let gp_major = state.get_gp_major(connection_id).await.unwrap_or(7);
|
||||
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
|
||||
// Run all metadata queries in parallel for speed
|
||||
let (
|
||||
version_res,
|
||||
current_db_res,
|
||||
all_dbs_res,
|
||||
col_res,
|
||||
fk_res,
|
||||
enum_res,
|
||||
tbl_comment_res,
|
||||
col_comment_res,
|
||||
unique_res,
|
||||
varchar_res,
|
||||
jsonb_res,
|
||||
) = tokio::join!(
|
||||
sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool),
|
||||
sqlx::query_scalar::<_, String>("SELECT current_database()").fetch_one(&pool),
|
||||
sqlx::query_scalar::<_, String>(
|
||||
"SELECT datname FROM pg_database \
|
||||
WHERE datistemplate = false AND datallowconn = true \
|
||||
ORDER BY datname"
|
||||
)
|
||||
.fetch_all(&pool),
|
||||
fetch_columns(&pool),
|
||||
fetch_foreign_keys_raw(&pool),
|
||||
fetch_enum_types(&pool),
|
||||
fetch_table_comments(&pool),
|
||||
fetch_column_comments(&pool),
|
||||
fetch_unique_constraints(&pool),
|
||||
fetch_varchar_values(&pool),
|
||||
fetch_jsonb_keys(&pool),
|
||||
);
|
||||
|
||||
let version = version_res.map_err(TuskError::Database)?;
|
||||
let current_db = current_db_res.unwrap_or_default();
|
||||
let all_dbs = all_dbs_res.unwrap_or_default();
|
||||
let col_rows = col_res?;
|
||||
let fk_rows = fk_res?;
|
||||
let enum_map = enum_res?;
|
||||
let tbl_comments = tbl_comment_res?;
|
||||
let col_comments = col_comment_res?;
|
||||
let unique_constraints = unique_res?;
|
||||
let varchar_values = varchar_res.unwrap_or_default();
|
||||
let jsonb_keys = jsonb_res.unwrap_or_default();
|
||||
let gp_extras = if is_greenplum {
|
||||
Some(fetch_gp_table_extras(&pool, gp_major).await)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// -- Build FK inline lookup: (schema, table, column) -> "ref_schema.ref_table(ref_col)" --
|
||||
let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new();
|
||||
let mut fk_lines: Vec<String> = Vec::new();
|
||||
for fk in &fk_rows {
|
||||
let line = format!(
|
||||
"FK: {}.{}({}) -> {}.{}({})",
|
||||
fk.schema,
|
||||
fk.table,
|
||||
fk.columns.join(", "),
|
||||
fk.ref_schema,
|
||||
fk.ref_table,
|
||||
fk.ref_columns.join(", ")
|
||||
);
|
||||
fk_lines.push(line);
|
||||
|
||||
// For single-column FKs, enable inline annotation on column
|
||||
if fk.columns.len() == 1 && fk.ref_columns.len() == 1 {
|
||||
fk_inline.insert(
|
||||
(fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()),
|
||||
format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// -- Build unique constraint lookup: (schema, table) -> Vec<column_list_string> --
|
||||
let mut unique_map: HashMap<(String, String), Vec<String>> = HashMap::new();
|
||||
for (schema, table, cols) in &unique_constraints {
|
||||
unique_map
|
||||
.entry((schema.clone(), table.clone()))
|
||||
.or_default()
|
||||
.push(cols.join(", "));
|
||||
}
|
||||
|
||||
// -- Format output --
|
||||
let mut output: Vec<String> = Vec::new();
|
||||
|
||||
// 1. PostgreSQL version (short form)
|
||||
let short_version = version
|
||||
.split_whitespace()
|
||||
.take(2)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
output.push(format!("DATABASE SCHEMA ({})", short_version));
|
||||
if !current_db.is_empty() {
|
||||
output.push(format!("ACTIVE DATABASE: {}", current_db));
|
||||
}
|
||||
output.push(String::new());
|
||||
|
||||
// 2. Cluster topology — other databases on this server.
|
||||
// Each PG database is isolated; cross-DB queries are not possible from a single connection.
|
||||
if all_dbs.len() > 1 {
|
||||
output.push("DATABASES ON THIS SERVER:".to_string());
|
||||
for db in &all_dbs {
|
||||
if db == ¤t_db {
|
||||
output.push(format!(" * {} (active)", db));
|
||||
} else {
|
||||
output.push(format!(" {}", db));
|
||||
}
|
||||
}
|
||||
output.push(String::new());
|
||||
output.push(
|
||||
"NOTE: Tables in other databases are NOT queryable from this session. \
|
||||
If the user's question concerns data likely stored in a different database \
|
||||
(e.g. an identity service in a separate DB), respond with a `final` message \
|
||||
asking them to switch the active database via the connection selector."
|
||||
.to_string(),
|
||||
);
|
||||
output.push(String::new());
|
||||
}
|
||||
|
||||
// 3. Quick table+column index for fast existence checks before writing SQL.
|
||||
// Each line lists `schema.table(col1, col2, ...)` so the model can grep both
|
||||
// table names and column names without scrolling through the full TABLES section.
|
||||
{
|
||||
let mut by_table: BTreeMap<(String, String), Vec<String>> = BTreeMap::new();
|
||||
for c in &col_rows {
|
||||
by_table
|
||||
.entry((c.schema.clone(), c.table.clone()))
|
||||
.or_default()
|
||||
.push(c.column.clone());
|
||||
}
|
||||
if !by_table.is_empty() {
|
||||
output.push(format!(
|
||||
"TABLE INDEX (database `{}`, {} tables — table_name(column_list)):",
|
||||
current_db,
|
||||
by_table.len()
|
||||
));
|
||||
for ((schema, table), cols) in &by_table {
|
||||
output.push(format!(" {}.{}({})", schema, table, cols.join(", ")));
|
||||
}
|
||||
output.push(String::new());
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Enum types
|
||||
if !enum_map.is_empty() {
|
||||
output.push("ENUM TYPES:".to_string());
|
||||
for (type_name, values) in &enum_map {
|
||||
let vals_str = values
|
||||
.iter()
|
||||
.map(|v| format!("'{}'", v))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
output.push(format!(" {} = [{}]", type_name, vals_str));
|
||||
}
|
||||
output.push(String::new());
|
||||
}
|
||||
|
||||
// 5. Tables with columns
|
||||
output.push("TABLES:".to_string());
|
||||
|
||||
// Group columns by schema.table preserving order
|
||||
let mut tables: BTreeMap<String, Vec<ColumnInfo>> = BTreeMap::new();
|
||||
for ci in &col_rows {
|
||||
let key = format!("{}.{}", ci.schema, ci.table);
|
||||
tables.entry(key).or_default().push(ci.clone());
|
||||
}
|
||||
|
||||
for (full_name, columns) in &tables {
|
||||
format_table_block(
|
||||
full_name,
|
||||
columns,
|
||||
&tbl_comments,
|
||||
&col_comments,
|
||||
&fk_inline,
|
||||
&enum_map,
|
||||
&unique_map,
|
||||
&varchar_values,
|
||||
&jsonb_keys,
|
||||
gp_extras.as_ref(),
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
|
||||
// 6. Foreign keys summary
|
||||
if !fk_lines.is_empty() {
|
||||
output.push(String::new());
|
||||
output.push("FOREIGN KEYS:".to_string());
|
||||
for fk in &fk_lines {
|
||||
output.push(format!(" {}", fk));
|
||||
}
|
||||
}
|
||||
|
||||
let result = output.join("\n");
|
||||
|
||||
// Cache the result
|
||||
state
|
||||
.set_schema_cache(connection_id.to_string(), result.clone())
|
||||
.await;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Schema query helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -976,8 +624,7 @@ pub(crate) struct ColumnInfo {
|
||||
}
|
||||
|
||||
/// Render a single table's column block in the human/LLM-readable schema format.
|
||||
/// Reused by both `build_schema_context` (full DDL for legacy AI commands) and
|
||||
/// the new `get_columns` chat tool.
|
||||
/// Used by the chat agent's `get_columns` tool.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn format_table_block(
|
||||
full_name: &str,
|
||||
@@ -1102,106 +749,6 @@ pub(crate) fn format_table_block(
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_clickhouse_schema_context(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
) -> TuskResult<String> {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let db = client.database.clone();
|
||||
|
||||
// ClickHouse exposes ALL databases via system.* — pull cross-DB schema in one shot.
|
||||
let columns_sql = "SELECT database, table, name, type, is_in_primary_key \
|
||||
FROM system.columns \
|
||||
WHERE database NOT IN ('system','INFORMATION_SCHEMA','information_schema') \
|
||||
ORDER BY database, table, position";
|
||||
let dbs_sql = "SELECT name FROM system.databases \
|
||||
WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \
|
||||
ORDER BY name";
|
||||
|
||||
let rows = client.fetch_objects(columns_sql).await?;
|
||||
let db_rows = client.fetch_objects(dbs_sql).await.unwrap_or_default();
|
||||
|
||||
let version = client.ping().await.unwrap_or_default();
|
||||
let mut out = String::new();
|
||||
out.push_str("DATABASE: ClickHouse\n");
|
||||
if !version.is_empty() {
|
||||
out.push_str(&format!("VERSION: {}\n", version.trim()));
|
||||
}
|
||||
out.push_str(&format!("ACTIVE_DATABASE: {}\n\n", db));
|
||||
|
||||
// Cluster overview
|
||||
if !db_rows.is_empty() {
|
||||
out.push_str("DATABASES ON THIS SERVER:\n");
|
||||
for row in &db_rows {
|
||||
let name = row.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if name == db {
|
||||
out.push_str(&format!(" * {} (active)\n", name));
|
||||
} else {
|
||||
out.push_str(&format!(" {}\n", name));
|
||||
}
|
||||
}
|
||||
out.push('\n');
|
||||
}
|
||||
|
||||
// Table+column index — same shape as the PG path so model has a uniform reference.
|
||||
{
|
||||
let mut by_table: BTreeMap<(String, String), Vec<String>> = BTreeMap::new();
|
||||
for r in &rows {
|
||||
let dbn = r.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let tbl = r.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let col = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
if !dbn.is_empty() && !tbl.is_empty() && !col.is_empty() {
|
||||
by_table.entry((dbn, tbl)).or_default().push(col);
|
||||
}
|
||||
}
|
||||
if !by_table.is_empty() {
|
||||
out.push_str(&format!(
|
||||
"TABLE INDEX ({} tables across all databases — db.table(column_list)):\n",
|
||||
by_table.len()
|
||||
));
|
||||
for ((dbn, tbl), cols) in &by_table {
|
||||
out.push_str(&format!(" {}.{}({})\n", dbn, tbl, cols.join(", ")));
|
||||
}
|
||||
out.push('\n');
|
||||
}
|
||||
}
|
||||
|
||||
out.push_str("TABLES:\n");
|
||||
let mut current_key: Option<String> = None;
|
||||
for row in &rows {
|
||||
let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let table = row.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let column = row.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let dtype = row.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let is_pk = matches!(row.get("is_in_primary_key"), Some(serde_json::Value::Number(n)) if n.as_i64() == Some(1))
|
||||
|| matches!(row.get("is_in_primary_key"), Some(serde_json::Value::String(s)) if s == "1");
|
||||
let key = format!("{}.{}", dbn, table);
|
||||
if Some(&key) != current_key.as_ref() {
|
||||
out.push_str(&format!("\nTABLE {}.{}\n", dbn, table));
|
||||
current_key = Some(key);
|
||||
}
|
||||
out.push_str(&format!(
|
||||
" {} {}{}\n",
|
||||
column,
|
||||
dtype,
|
||||
if is_pk { " [PK]" } else { "" }
|
||||
));
|
||||
}
|
||||
|
||||
out.push_str(
|
||||
"\nNOTES:\n\
|
||||
- Use ClickHouse SQL dialect. Functions differ from PostgreSQL (e.g. count(), arrayJoin, toDate, formatDateTime).\n\
|
||||
- ClickHouse allows fully-qualified `database.table` in queries — you CAN cross-reference databases on this server.\n\
|
||||
- Read-only mode is enforced for the agent: only SELECT/WITH/EXPLAIN/SHOW/DESCRIBE allowed.\n\
|
||||
- Always include LIMIT for ad-hoc SELECTs.\n",
|
||||
);
|
||||
|
||||
state
|
||||
.set_schema_cache(connection_id.to_string(), out.clone())
|
||||
.await;
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
pub(crate) async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult<Vec<ColumnInfo>> {
|
||||
let rows = sqlx::query(
|
||||
"SELECT \
|
||||
@@ -1399,6 +946,7 @@ pub(crate) async fn fetch_unique_constraints(
|
||||
/// Returns HashMap<(schema, table, column), Vec<distinct_values>> for varchar columns
|
||||
/// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery.
|
||||
/// Returns None if pg_stats is not accessible (graceful degradation).
|
||||
#[allow(dead_code)] // re-exposed by profile_table tool (PR2)
|
||||
async fn fetch_varchar_values(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Option<HashMap<(String, String, String), Vec<String>>> {
|
||||
@@ -1440,104 +988,6 @@ async fn fetch_varchar_values(
|
||||
Some(map)
|
||||
}
|
||||
|
||||
/// Discovers top-level keys in JSONB columns by sampling actual data.
|
||||
/// Runs two sequential queries internally: first discovers JSONB columns,
|
||||
/// then samples keys from each via a single UNION ALL query.
|
||||
/// Returns None on error (graceful degradation).
|
||||
async fn fetch_jsonb_keys(
|
||||
pool: &sqlx::PgPool,
|
||||
) -> Option<HashMap<(String, String, String), Vec<String>>> {
|
||||
// Step 1: Find all JSONB columns
|
||||
let col_rows = match sqlx::query(
|
||||
"SELECT table_schema, table_name, column_name \
|
||||
FROM information_schema.columns \
|
||||
WHERE data_type = 'jsonb' \
|
||||
AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
|
||||
ORDER BY table_schema, table_name, column_name",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
log::warn!("Failed to fetch JSONB columns: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
if col_rows.is_empty() {
|
||||
return Some(HashMap::new());
|
||||
}
|
||||
|
||||
// Cap at 50 JSONB columns to prevent unbounded UNION ALL queries on large schemas
|
||||
let columns: Vec<(String, String, String)> = col_rows
|
||||
.iter()
|
||||
.take(50)
|
||||
.map(|r| {
|
||||
(
|
||||
r.get::<String, _>(0),
|
||||
r.get::<String, _>(1),
|
||||
r.get::<String, _>(2),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Step 2: Build a single UNION ALL query to sample keys from all JSONB columns
|
||||
let parts: Vec<String> = columns
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (schema, table, col))| {
|
||||
let qs = schema.replace('"', "\"\"");
|
||||
let qt = table.replace('"', "\"\"");
|
||||
let qc = col.replace('"', "\"\"");
|
||||
format!(
|
||||
"(SELECT '{}.{}.{}' AS col_ref, key FROM (\
|
||||
SELECT DISTINCT jsonb_object_keys(\"{}\") AS key \
|
||||
FROM \"{}\".\"{}\" \
|
||||
WHERE \"{}\" IS NOT NULL AND jsonb_typeof(\"{}\") = 'object' \
|
||||
LIMIT 50\
|
||||
) sub{})",
|
||||
schema, table, col, qc, qs, qt, qc, qc, i
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let query = parts.join(" UNION ALL ");
|
||||
|
||||
let rows = match sqlx::query(&query).fetch_all(pool).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
log::warn!("Failed to fetch JSONB keys: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
let mut map: HashMap<(String, String, String), Vec<String>> = HashMap::new();
|
||||
for r in &rows {
|
||||
let col_ref: String = r.get(0);
|
||||
let key: String = r.get(1);
|
||||
let ref_parts: Vec<&str> = col_ref.splitn(3, '.').collect();
|
||||
if ref_parts.len() == 3 {
|
||||
let entry = map
|
||||
.entry((
|
||||
ref_parts[0].to_string(),
|
||||
ref_parts[1].to_string(),
|
||||
ref_parts[2].to_string(),
|
||||
))
|
||||
.or_default();
|
||||
if !entry.contains(&key) {
|
||||
entry.push(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for vals in map.values_mut() {
|
||||
vals.sort();
|
||||
}
|
||||
|
||||
Some(map)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Greenplum-specific table attributes
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -1656,6 +1106,7 @@ pub(crate) async fn fetch_gp_table_extras(
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"}
|
||||
#[allow(dead_code)] // helper for fetch_varchar_values; re-exposed by profile_table tool (PR2)
|
||||
fn parse_pg_array_text(s: &str) -> Vec<String> {
|
||||
let s = s.trim();
|
||||
let s = s.strip_prefix('{').unwrap_or(s);
|
||||
@@ -1716,65 +1167,10 @@ fn simplify_default(raw: &str) -> String {
|
||||
s.to_string()
|
||||
}
|
||||
|
||||
fn clean_sql_response(raw: &str) -> String {
|
||||
let trimmed = raw.trim();
|
||||
// Remove markdown code fences
|
||||
let without_fences = if trimmed.starts_with("```") {
|
||||
let inner = trimmed
|
||||
.strip_prefix("```sql")
|
||||
.or_else(|| trimmed.strip_prefix("```SQL"))
|
||||
.or_else(|| trimmed.strip_prefix("```postgresql"))
|
||||
.or_else(|| trimmed.strip_prefix("```"))
|
||||
.unwrap_or(trimmed);
|
||||
inner.strip_suffix("```").unwrap_or(inner)
|
||||
} else {
|
||||
trimmed
|
||||
};
|
||||
without_fences.trim().to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// ── clean_sql_response ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn clean_sql_plain() {
|
||||
assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_fences() {
|
||||
assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_generic_fences() {
|
||||
assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_postgresql_fences() {
|
||||
assert_eq!(
|
||||
clean_sql_response("```postgresql\nSELECT 1\n```"),
|
||||
"SELECT 1"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_with_whitespace() {
|
||||
assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clean_sql_no_fences_multiline() {
|
||||
assert_eq!(
|
||||
clean_sql_response("SELECT\n *\nFROM users"),
|
||||
"SELECT\n *\nFROM users"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Fireworks provider ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
@@ -1784,12 +1180,14 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserializes_legacy_settings_without_fireworks_key() {
|
||||
// Old config files won't have `fireworks_api_key` — must still parse.
|
||||
fn deserializes_legacy_settings_with_dropped_provider_keys() {
|
||||
// Old config files may include `openai_api_key`/`anthropic_api_key` and a
|
||||
// legacy `"provider": "openai"` value — both must be tolerated, with the
|
||||
// unknown provider coerced to Ollama.
|
||||
let legacy = r#"{
|
||||
"provider": "ollama",
|
||||
"provider": "openai",
|
||||
"ollama_url": "http://localhost:11434",
|
||||
"openai_api_key": null,
|
||||
"openai_api_key": "sk-deprecated",
|
||||
"anthropic_api_key": null,
|
||||
"model": "qwen2.5-coder:7b"
|
||||
}"#;
|
||||
@@ -1813,6 +1211,28 @@ mod tests {
|
||||
assert_eq!(parsed.choices[0].message.content, "hi");
|
||||
}
|
||||
|
||||
// ── OpenRouter provider ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn serializes_openrouter_provider() {
|
||||
let json = serde_json::to_string(&AiProvider::OpenRouter).unwrap();
|
||||
assert_eq!(json, "\"openrouter\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserializes_openrouter_settings() {
|
||||
let cfg = r#"{
|
||||
"provider": "openrouter",
|
||||
"ollama_url": "http://localhost:11434",
|
||||
"openrouter_api_key": "sk-or-v1-abc",
|
||||
"model": "anthropic/claude-3.5-sonnet"
|
||||
}"#;
|
||||
let parsed: AiSettings = serde_json::from_str(cfg).unwrap();
|
||||
assert_eq!(parsed.provider, AiProvider::OpenRouter);
|
||||
assert_eq!(parsed.openrouter_api_key.as_deref(), Some("sk-or-v1-abc"));
|
||||
assert_eq!(parsed.model, "anthropic/claude-3.5-sonnet");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_fireworks_models_list() {
|
||||
let body = r#"{
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings};
|
||||
use crate::commands::chat_tools::{
|
||||
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
|
||||
build_sample_sql, detect_skew_tool, explain_query_tool, find_queries_tool, get_columns_tool,
|
||||
list_databases_tool, list_tables_tool, profile_table_tool, save_query_tool,
|
||||
switch_database_tool,
|
||||
};
|
||||
use crate::commands::memory::{append_memory_core, read_memory_core};
|
||||
use crate::commands::queries::execute_query_core;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::ai::OllamaChatMessage;
|
||||
use crate::models::chat::{ChartConfig, ChatMessage, ChatTurnResult, ContextUsage};
|
||||
use crate::models::chat::{ChatMessage, ChatTurnResult, ContextUsage};
|
||||
use crate::models::query_result::QueryResult;
|
||||
use crate::state::AppState;
|
||||
use chrono::Utc;
|
||||
@@ -30,11 +31,11 @@ const TEXT_TOOL_CHAR_CAP: usize = 10_000;
|
||||
/// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192).
|
||||
/// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content.
|
||||
const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000;
|
||||
/// Conservative default for managed providers (Fireworks). Most chat-capable
|
||||
/// Fireworks models ship with 32K–256K context windows; 384K chars (~128K tok)
|
||||
/// is a safe floor that won't trigger false /compact nags on normal sessions
|
||||
/// while still flagging genuinely runaway threads.
|
||||
const CONTEXT_BUDGET_CHARS_FIREWORKS: u64 = 384_000;
|
||||
/// Conservative default for managed providers (Fireworks, OpenRouter). Most
|
||||
/// chat-capable hosted models ship with 32K–256K context windows; 384K chars
|
||||
/// (~128K tok) is a safe floor that won't trigger false /compact nags on normal
|
||||
/// sessions while still flagging genuinely runaway threads.
|
||||
const CONTEXT_BUDGET_CHARS_MANAGED: u64 = 384_000;
|
||||
/// Stop the loop when the model fails the same SQL hurdle this many times in a
|
||||
/// row. Beyond this, additional hops almost always burn the rest of the budget
|
||||
/// on identical retries; a definitive `final` with the error is more useful.
|
||||
@@ -55,9 +56,15 @@ enum AgentAction {
|
||||
Remember { note: String },
|
||||
SaveQuery { name: String, sql: String },
|
||||
FindQueries { text: String },
|
||||
MakeChart { config: ChartConfig },
|
||||
ProfileTable { table: String },
|
||||
SampleData { table: String, limit: u32 },
|
||||
ExplainQuery { sql: String },
|
||||
DetectSkew { table: String },
|
||||
}
|
||||
|
||||
const SAMPLE_DATA_DEFAULT_LIMIT: u32 = 50;
|
||||
const SAMPLE_DATA_MAX_LIMIT: u32 = 200;
|
||||
|
||||
/// Parse the model's JSON response. Accepts both shapes the model tends to emit:
|
||||
/// {"action":"X","field":"..."} — flat (matches our prompt)
|
||||
/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention)
|
||||
@@ -157,60 +164,55 @@ fn parse_agent_action(raw: &str) -> Result<AgentAction, String> {
|
||||
}
|
||||
Ok(AgentAction::FindQueries { text })
|
||||
}
|
||||
"make_chart" => {
|
||||
let chart_type = lookup("chart_type")
|
||||
.or_else(|| lookup("type"))
|
||||
"profile_table" => {
|
||||
let table = lookup("table")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "make_chart missing `chart_type`".to_string())?
|
||||
.trim()
|
||||
.to_lowercase();
|
||||
if !["bar", "line", "area", "pie"].contains(&chart_type.as_str()) {
|
||||
return Err(format!(
|
||||
"make_chart `chart_type` must be one of: bar, line, area, pie. Got: {}",
|
||||
chart_type
|
||||
));
|
||||
}
|
||||
let x = lookup("x")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "make_chart missing `x` column".to_string())?
|
||||
.ok_or_else(|| "profile_table missing `table`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
let y = lookup("y")
|
||||
if table.is_empty() {
|
||||
return Err("profile_table `table` must not be empty".into());
|
||||
}
|
||||
Ok(AgentAction::ProfileTable { table })
|
||||
}
|
||||
"sample_data" => {
|
||||
let table = lookup("table")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "make_chart missing `y` column".to_string())?
|
||||
.ok_or_else(|| "sample_data missing `table`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
if x.is_empty() || y.is_empty() {
|
||||
return Err("make_chart `x` and `y` must not be empty".into());
|
||||
if table.is_empty() {
|
||||
return Err("sample_data `table` must not be empty".into());
|
||||
}
|
||||
let group = lookup("group")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty());
|
||||
let title = lookup("title")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty());
|
||||
let orientation = lookup("orientation")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.trim().to_lowercase())
|
||||
.filter(|s| !s.is_empty());
|
||||
Ok(AgentAction::MakeChart {
|
||||
config: ChartConfig {
|
||||
chart_type,
|
||||
x,
|
||||
y,
|
||||
group,
|
||||
title,
|
||||
orientation,
|
||||
},
|
||||
})
|
||||
let limit = lookup("limit")
|
||||
.and_then(|v| v.as_u64())
|
||||
.map(|n| n as u32)
|
||||
.unwrap_or(SAMPLE_DATA_DEFAULT_LIMIT)
|
||||
.clamp(1, SAMPLE_DATA_MAX_LIMIT);
|
||||
Ok(AgentAction::SampleData { table, limit })
|
||||
}
|
||||
"explain_query" => {
|
||||
let sql = lookup("sql")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "explain_query missing `sql`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
if sql.is_empty() {
|
||||
return Err("explain_query `sql` must not be empty".into());
|
||||
}
|
||||
Ok(AgentAction::ExplainQuery { sql })
|
||||
}
|
||||
"detect_skew" => {
|
||||
let table = lookup("table")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "detect_skew missing `table`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
if table.is_empty() {
|
||||
return Err("detect_skew `table` must not be empty".into());
|
||||
}
|
||||
Ok(AgentAction::DetectSkew { table })
|
||||
}
|
||||
// Legacy from earlier iterations — silently ignored at parse time so the
|
||||
// model can recover with a different action.
|
||||
"get_schema" => Err(
|
||||
"get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(),
|
||||
),
|
||||
other => Err(format!("unknown action `{}`", other)),
|
||||
}
|
||||
}
|
||||
@@ -285,8 +287,17 @@ You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On
|
||||
{{"action":"save_query","name":"<short label>","sql":"<the SQL>"}}
|
||||
Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved.
|
||||
|
||||
{{"action":"make_chart","chart_type":"bar","x":"<col>","y":"<col>","title":"<short title>"}}
|
||||
Visualise the LAST successful run_query result as a chart inline. `chart_type` is one of: bar, line, area, pie. `x` and `y` MUST be column names from the previous result. Optional: `group` (column for series), `orientation` ("vertical"/"horizontal", bar only). Use after run_query when the data is aggregated and would be clearer as a chart (top-N comparisons → bar; time series → line/area; proportions → pie). Skip for tiny results (≤2 rows) and giant ones (>500 rows).
|
||||
{{"action":"profile_table","table":"schema.table"}}
|
||||
Per-column profile: NULL fraction, distinct cardinality, min/max range, top-K values. PG/GP reads pg_stats (zero-cost; ensure ANALYZE has run). ClickHouse fires one summary query (cheap on MergeTree). Use BEFORE writing aggregations to spot pseudo-enums, NULL-heavy columns, or skewed distributions.
|
||||
|
||||
{{"action":"sample_data","table":"schema.table","limit":50}}
|
||||
Random row sample (default 50, max 200). PG/GP uses TABLESAMPLE BERNOULLI when reltuples > 0, else ORDER BY random(). CH uses SAMPLE 0.01 on MergeTree with a sampling key, else ORDER BY rand(). Use to eyeball value shape BEFORE writing filters; cheaper than `SELECT * LIMIT N` on huge tables.
|
||||
|
||||
{{"action":"explain_query","sql":"SELECT ..."}}
|
||||
Run EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) on PG/GP, EXPLAIN PLAN on CH. Reports root node, planning + execution time, seq-scanned tables, spilled sorts, est-vs-actual row skew, Greenplum Motions. Use AFTER a slow run_query.
|
||||
|
||||
{{"action":"detect_skew","table":"schema.table"}}
|
||||
Greenplum-only: counts rows per gp_segment_id and reports max/min/avg + skew ratio. Ratio > 1.5 ⇒ uneven distribution; suggests revisiting DISTRIBUTED BY. Soft-errors on PG/CH.
|
||||
|
||||
{{"action":"final","text":"..."}}
|
||||
End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling).
|
||||
@@ -296,6 +307,8 @@ WORKFLOW
|
||||
2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question.
|
||||
3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs.
|
||||
4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns.
|
||||
4a. If the user asks about value shape (cardinality, NULL rates, top values), prefer `profile_table` over a hand-written run_query. To eyeball actual rows, prefer `sample_data` over `LIMIT 100`.
|
||||
4b. If the user reports a slow query or asks why something takes long, run `explain_query` on it. On Greenplum, if a single table appears unbalanced, check `detect_skew`.
|
||||
5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first.
|
||||
6. Execute run_query.
|
||||
7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here.
|
||||
@@ -415,9 +428,6 @@ fn build_history(
|
||||
content: serde_json::json!({ "action": "final", "text": text }).to_string(),
|
||||
}),
|
||||
ChatMessage::ToolCall { tool, input_json, .. } => {
|
||||
if tool == "get_schema" {
|
||||
continue; // legacy
|
||||
}
|
||||
let mut envelope = serde_json::Map::new();
|
||||
envelope.insert("action".to_string(), Value::String(tool.clone()));
|
||||
if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) {
|
||||
@@ -437,9 +447,6 @@ fn build_history(
|
||||
result,
|
||||
..
|
||||
} => {
|
||||
if tool == "get_schema" {
|
||||
continue; // legacy
|
||||
}
|
||||
let payload = match tool.as_str() {
|
||||
"run_query" => {
|
||||
if *is_error {
|
||||
@@ -521,7 +528,7 @@ async fn provider_budget_chars(state: &AppState, app: &AppHandle) -> u64 {
|
||||
use crate::models::ai::AiProvider;
|
||||
match load_ai_settings(app, state).await {
|
||||
Ok(s) => match s.provider {
|
||||
AiProvider::Fireworks => CONTEXT_BUDGET_CHARS_FIREWORKS,
|
||||
AiProvider::Fireworks | AiProvider::OpenRouter => CONTEXT_BUDGET_CHARS_MANAGED,
|
||||
_ => CONTEXT_BUDGET_CHARS_OLLAMA,
|
||||
},
|
||||
Err(_) => CONTEXT_BUDGET_CHARS_OLLAMA,
|
||||
@@ -597,7 +604,10 @@ pub async fn chat_send(
|
||||
}
|
||||
};
|
||||
|
||||
let is_run_query = matches!(&action, AgentAction::RunQuery { .. });
|
||||
let is_run_query = matches!(
|
||||
&action,
|
||||
AgentAction::RunQuery { .. } | AgentAction::SampleData { .. }
|
||||
);
|
||||
|
||||
match action {
|
||||
AgentAction::Final { text } => {
|
||||
@@ -742,91 +752,90 @@ pub async fn chat_send(
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::MakeChart { config } => {
|
||||
let config_json = serde_json::to_string(&config).unwrap_or_else(|_| "{}".into());
|
||||
AgentAction::ProfileTable { table } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"make_chart",
|
||||
config_json.clone(),
|
||||
"profile_table",
|
||||
serde_json::json!({ "table": &table }).to_string(),
|
||||
);
|
||||
|
||||
let result_msg = match last_successful_query_result(&working) {
|
||||
None => ChatMessage::ToolResult {
|
||||
let result = run_text_tool(
|
||||
profile_table_tool(&state, &connection_id, &table).await,
|
||||
"profile_table",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::ExplainQuery { sql } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"explain_query",
|
||||
serde_json::json!({ "sql": &sql }).to_string(),
|
||||
);
|
||||
let result = run_text_tool(
|
||||
explain_query_tool(&state, &connection_id, &sql).await,
|
||||
"explain_query",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::DetectSkew { table } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"detect_skew",
|
||||
serde_json::json!({ "table": &table }).to_string(),
|
||||
);
|
||||
let result = run_text_tool(
|
||||
detect_skew_tool(&state, &connection_id, &table).await,
|
||||
"detect_skew",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::SampleData { table, limit } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"sample_data",
|
||||
serde_json::json!({ "table": &table, "limit": limit }).to_string(),
|
||||
);
|
||||
let outcome = match build_sample_sql(&state, &connection_id, &table, limit).await {
|
||||
Ok(sql) => match execute_query_core(&state, &connection_id, &sql).await {
|
||||
Ok(qr) => {
|
||||
consecutive_query_errors = 0;
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "make_chart".to_string(),
|
||||
tool: "sample_data".to_string(),
|
||||
is_error: false,
|
||||
text: None,
|
||||
result: Some(qr),
|
||||
created_at: now_ms(),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
consecutive_query_errors += 1;
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "sample_data".to_string(),
|
||||
is_error: true,
|
||||
text: Some(
|
||||
"make_chart needs a successful run_query result above it. Run a SELECT first, then call make_chart."
|
||||
.to_string(),
|
||||
),
|
||||
text: Some(format_db_error(&e)),
|
||||
result: None,
|
||||
created_at: now_ms(),
|
||||
}
|
||||
}
|
||||
},
|
||||
Some(qr) => {
|
||||
if !qr.columns.iter().any(|c| c == &config.x) {
|
||||
Err(e) => {
|
||||
consecutive_query_errors += 1;
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "make_chart".to_string(),
|
||||
tool: "sample_data".to_string(),
|
||||
is_error: true,
|
||||
text: Some(format!(
|
||||
"x column `{}` is not in the last result. Available: {}.",
|
||||
config.x,
|
||||
qr.columns.join(", ")
|
||||
)),
|
||||
text: Some(format_db_error(&e)),
|
||||
result: None,
|
||||
created_at: now_ms(),
|
||||
}
|
||||
} else if !qr.columns.iter().any(|c| c == &config.y) {
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "make_chart".to_string(),
|
||||
is_error: true,
|
||||
text: Some(format!(
|
||||
"y column `{}` is not in the last result. Available: {}.",
|
||||
config.y,
|
||||
qr.columns.join(", ")
|
||||
)),
|
||||
result: None,
|
||||
created_at: now_ms(),
|
||||
}
|
||||
} else if let Some(group) = &config.group {
|
||||
if !qr.columns.iter().any(|c| c == group) {
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "make_chart".to_string(),
|
||||
is_error: true,
|
||||
text: Some(format!(
|
||||
"group column `{}` is not in the last result. Available: {}.",
|
||||
group,
|
||||
qr.columns.join(", ")
|
||||
)),
|
||||
result: None,
|
||||
created_at: now_ms(),
|
||||
}
|
||||
} else {
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "make_chart".to_string(),
|
||||
is_error: false,
|
||||
text: Some(config_json.clone()),
|
||||
result: Some(qr),
|
||||
created_at: now_ms(),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "make_chart".to_string(),
|
||||
is_error: false,
|
||||
text: Some(config_json.clone()),
|
||||
result: Some(qr),
|
||||
created_at: now_ms(),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
push_tool_result(&mut new_messages, &mut working, result_msg);
|
||||
push_tool_result(&mut new_messages, &mut working, outcome);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -982,26 +991,6 @@ fn format_db_error(e: &TuskError) -> String {
|
||||
e.to_string()
|
||||
}
|
||||
|
||||
/// Locate the most recent SUCCESSFUL run_query in the working thread and
|
||||
/// return its full QueryResult. Used by make_chart to attach data to a chart
|
||||
/// directive without relying on the model to re-send it.
|
||||
fn last_successful_query_result(messages: &[ChatMessage]) -> Option<QueryResult> {
|
||||
for m in messages.iter().rev() {
|
||||
if let ChatMessage::ToolResult {
|
||||
tool,
|
||||
is_error: false,
|
||||
result: Some(qr),
|
||||
..
|
||||
} = m
|
||||
{
|
||||
if tool == "run_query" {
|
||||
return Some(qr.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Pull the most recent run_query error text from the working thread, so the
|
||||
/// post-loop "I gave up" summary can quote concrete errors back to the user.
|
||||
fn last_run_query_error(messages: &[ChatMessage]) -> Option<String> {
|
||||
@@ -1484,119 +1473,6 @@ mod tests {
|
||||
assert!(last_run_query_error(&msgs).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_make_chart_minimal() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"make_chart","chart_type":"bar","x":"carrier","y":"trips"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::MakeChart { config } => {
|
||||
assert_eq!(config.chart_type, "bar");
|
||||
assert_eq!(config.x, "carrier");
|
||||
assert_eq!(config.y, "trips");
|
||||
assert!(config.group.is_none());
|
||||
assert!(config.title.is_none());
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_make_chart_with_group_and_title() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"make_chart","chart_type":"line","x":"month","y":"revenue","group":"region","title":"Revenue"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::MakeChart { config } => {
|
||||
assert_eq!(config.group.as_deref(), Some("region"));
|
||||
assert_eq!(config.title.as_deref(), Some("Revenue"));
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn make_chart_accepts_alternative_field_name_type() {
|
||||
// Some models emit `type` instead of `chart_type`.
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"make_chart","type":"pie","x":"label","y":"value"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::MakeChart { config } => assert_eq!(config.chart_type, "pie"),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_make_chart_with_unknown_chart_type() {
|
||||
let r = parse_agent_action(
|
||||
r#"{"action":"make_chart","chart_type":"radar","x":"a","y":"b"}"#,
|
||||
);
|
||||
assert!(r.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_make_chart_missing_x_or_y() {
|
||||
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","y":"a"}"#).is_err());
|
||||
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","x":"a"}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last_successful_query_result_finds_recent() {
|
||||
use crate::models::query_result::QueryResult;
|
||||
let qr = QueryResult {
|
||||
columns: vec!["a".into()],
|
||||
types: vec!["INT4".into()],
|
||||
rows: vec![vec![Value::Number(1.into())]],
|
||||
row_count: 1,
|
||||
execution_time_ms: 1,
|
||||
};
|
||||
let msgs = vec![
|
||||
ChatMessage::ToolResult {
|
||||
id: "r1".into(),
|
||||
tool: "run_query".into(),
|
||||
is_error: false,
|
||||
text: None,
|
||||
result: Some(qr.clone()),
|
||||
created_at: 1,
|
||||
},
|
||||
ChatMessage::ToolResult {
|
||||
id: "r2".into(),
|
||||
tool: "run_query".into(),
|
||||
is_error: true,
|
||||
text: Some("oops".into()),
|
||||
result: None,
|
||||
created_at: 2,
|
||||
},
|
||||
];
|
||||
let found = last_successful_query_result(&msgs).expect("ok");
|
||||
assert_eq!(found.columns, vec!["a".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last_successful_query_result_skips_non_run_query() {
|
||||
use crate::models::query_result::QueryResult;
|
||||
let qr = QueryResult {
|
||||
columns: vec!["a".into()],
|
||||
types: vec!["INT4".into()],
|
||||
rows: vec![],
|
||||
row_count: 0,
|
||||
execution_time_ms: 0,
|
||||
};
|
||||
let msgs = vec![ChatMessage::ToolResult {
|
||||
id: "r1".into(),
|
||||
tool: "list_tables".into(),
|
||||
is_error: false,
|
||||
text: Some("public.x".into()),
|
||||
result: Some(qr),
|
||||
created_at: 1,
|
||||
}];
|
||||
assert!(last_successful_query_result(&msgs).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn render_thread_for_summary_includes_roles_and_skips_rows() {
|
||||
let msgs = vec![
|
||||
@@ -1625,11 +1501,6 @@ mod tests {
|
||||
assert!(!rendered.contains("alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_legacy_get_schema() {
|
||||
assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_long_cell() {
|
||||
let long = "a".repeat(CELL_CHAR_CAP + 50);
|
||||
|
||||
@@ -10,11 +10,14 @@ use crate::commands::ai::{
|
||||
ColumnInfo,
|
||||
};
|
||||
use crate::commands::connections::{load_connection_config, switch_database_core};
|
||||
use crate::commands::queries::execute_query_core;
|
||||
use crate::commands::saved_queries::{list_saved_queries_core, save_query_core};
|
||||
use crate::commands::schema::{list_databases_core, list_tables_core};
|
||||
use crate::db::sql_guard::ensure_readonly_sql;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::saved_queries::SavedQuery;
|
||||
use crate::state::{AppState, CachedVec, DbFlavor};
|
||||
use crate::utils::escape_ident;
|
||||
use sqlx::{PgPool, Row};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::time::{Duration, Instant};
|
||||
@@ -565,3 +568,690 @@ pub async fn find_queries_tool(
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// profile_table (PR2 — data-engineering tool)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const PROFILE_TABLE_MAX_COLUMNS: usize = 30;
|
||||
const PROFILE_TABLE_TOPK: usize = 5;
|
||||
|
||||
pub async fn profile_table_tool(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
table: &str,
|
||||
) -> TuskResult<String> {
|
||||
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
|
||||
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
match flavor {
|
||||
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
|
||||
profile_table_postgres(state, connection_id, &schema, &tbl).await
|
||||
}
|
||||
DbFlavor::ClickHouse => profile_table_clickhouse(state, connection_id, &schema, &tbl).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn profile_table_postgres(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
schema: &str,
|
||||
table: &str,
|
||||
) -> TuskResult<String> {
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
|
||||
let exists = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT 1 FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \
|
||||
WHERE n.nspname = $1 AND c.relname = $2 LIMIT 1",
|
||||
)
|
||||
.bind(schema)
|
||||
.bind(table)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
if exists.is_none() {
|
||||
return Err(TuskError::Custom(format!(
|
||||
"Table '{}.{}' does not exist (or no privileges).",
|
||||
schema, table
|
||||
)));
|
||||
}
|
||||
|
||||
let last_analyze: Option<chrono::DateTime<chrono::Utc>> = sqlx::query_scalar(
|
||||
"SELECT GREATEST(last_analyze, last_autoanalyze) FROM pg_stat_user_tables \
|
||||
WHERE schemaname = $1 AND relname = $2",
|
||||
)
|
||||
.bind(schema)
|
||||
.bind(table)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let stat_rows = sqlx::query(
|
||||
"SELECT attname, null_frac, n_distinct, \
|
||||
most_common_vals::text, most_common_freqs, histogram_bounds::text \
|
||||
FROM pg_stats \
|
||||
WHERE schemaname = $1 AND tablename = $2 \
|
||||
ORDER BY attname",
|
||||
)
|
||||
.bind(schema)
|
||||
.bind(table)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let mut out = format!("PROFILE {}.{}\n", schema, table);
|
||||
match last_analyze {
|
||||
Some(ts) => out.push_str(&format!("Last ANALYZE: {}\n", ts.to_rfc3339())),
|
||||
None => out.push_str("Last ANALYZE: never\n"),
|
||||
}
|
||||
|
||||
if stat_rows.is_empty() {
|
||||
out.push_str(&format!(
|
||||
"\nNo statistics in pg_stats. Run: ANALYZE {}.{};\n",
|
||||
escape_ident(schema),
|
||||
escape_ident(table)
|
||||
));
|
||||
return Ok(out);
|
||||
}
|
||||
|
||||
let total = stat_rows.len();
|
||||
let take = total.min(PROFILE_TABLE_MAX_COLUMNS);
|
||||
out.push_str(&format!("\n{} columns with stats\n", total));
|
||||
|
||||
for r in stat_rows.iter().take(take) {
|
||||
let attname: String = r.get(0);
|
||||
let null_frac: f32 = r.try_get(1).unwrap_or(0.0);
|
||||
let n_distinct: f32 = r.try_get(2).unwrap_or(0.0);
|
||||
let mcv_text: Option<String> = r.try_get(3).ok();
|
||||
let mcf_arr: Option<Vec<f32>> = r.try_get(4).ok();
|
||||
let hist_text: Option<String> = r.try_get(5).ok();
|
||||
|
||||
out.push_str(&format!("\n {}:\n", attname));
|
||||
out.push_str(&format!(" null_frac: {:.4}\n", null_frac));
|
||||
if n_distinct < 0.0 {
|
||||
out.push_str(&format!(
|
||||
" n_distinct: {:.3} (ratio of total rows)\n",
|
||||
-n_distinct
|
||||
));
|
||||
} else {
|
||||
out.push_str(&format!(" n_distinct: {}\n", n_distinct as i64));
|
||||
}
|
||||
|
||||
if let Some(text) = hist_text.as_deref() {
|
||||
let bounds = parse_pg_array_text_local(text);
|
||||
if let (Some(min), Some(max)) = (bounds.first(), bounds.last()) {
|
||||
out.push_str(&format!(" range: {} … {}\n", min, max));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(text) = mcv_text.as_deref() {
|
||||
let vals = parse_pg_array_text_local(text);
|
||||
if !vals.is_empty() {
|
||||
let freqs = mcf_arr.unwrap_or_default();
|
||||
let pairs: Vec<String> = vals
|
||||
.iter()
|
||||
.take(PROFILE_TABLE_TOPK)
|
||||
.enumerate()
|
||||
.map(|(i, v)| match freqs.get(i) {
|
||||
Some(f) => format!("{}({:.3})", v, f),
|
||||
None => v.clone(),
|
||||
})
|
||||
.collect();
|
||||
out.push_str(&format!(" top: {}\n", pairs.join(", ")));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if total > take {
|
||||
out.push_str(&format!("\n…and {} more columns\n", total - take));
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Local pg-array parser used by profile_table; mirrors `parse_pg_array_text` in ai.rs
|
||||
/// but kept local to avoid importing a private helper.
|
||||
fn parse_pg_array_text_local(s: &str) -> Vec<String> {
|
||||
let s = s.trim();
|
||||
let s = s.strip_prefix('{').unwrap_or(s);
|
||||
let s = s.strip_suffix('}').unwrap_or(s);
|
||||
if s.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
let mut out = Vec::new();
|
||||
let mut cur = String::new();
|
||||
let mut in_quotes = false;
|
||||
let mut chars = s.chars().peekable();
|
||||
while let Some(c) = chars.next() {
|
||||
match c {
|
||||
'"' if !in_quotes => in_quotes = true,
|
||||
'"' if in_quotes => {
|
||||
if chars.peek() == Some(&'"') {
|
||||
cur.push('"');
|
||||
chars.next();
|
||||
} else {
|
||||
in_quotes = false;
|
||||
}
|
||||
}
|
||||
',' if !in_quotes => {
|
||||
out.push(std::mem::take(&mut cur));
|
||||
}
|
||||
'\\' if in_quotes => {
|
||||
if let Some(next) = chars.next() {
|
||||
cur.push(next);
|
||||
}
|
||||
}
|
||||
other => cur.push(other),
|
||||
}
|
||||
}
|
||||
if !cur.is_empty() || s.ends_with(',') {
|
||||
out.push(cur);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
async fn profile_table_clickhouse(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
schema: &str,
|
||||
table: &str,
|
||||
) -> TuskResult<String> {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let active_db = client.database.clone();
|
||||
let dbn = if schema == "public" || schema.is_empty() {
|
||||
active_db
|
||||
} else {
|
||||
schema.to_string()
|
||||
};
|
||||
|
||||
let cols_sql = format!(
|
||||
"SELECT name, type FROM system.columns \
|
||||
WHERE database = '{}' AND table = '{}' \
|
||||
ORDER BY position LIMIT {}",
|
||||
dbn.replace('\'', "\\'"),
|
||||
table.replace('\'', "\\'"),
|
||||
PROFILE_TABLE_MAX_COLUMNS
|
||||
);
|
||||
let col_rows = client.fetch_objects(&cols_sql).await?;
|
||||
if col_rows.is_empty() {
|
||||
return Err(TuskError::Custom(format!(
|
||||
"Table '{}.{}' does not exist (or no privileges).",
|
||||
dbn, table
|
||||
)));
|
||||
}
|
||||
|
||||
let mut select_parts: Vec<String> = vec!["count() AS rows_total".to_string()];
|
||||
let mut col_names: Vec<String> = Vec::new();
|
||||
let mut col_types: Vec<String> = Vec::new();
|
||||
for r in &col_rows {
|
||||
let name = r.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let dtype = r.get("type").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
if name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
col_names.push(name.clone());
|
||||
col_types.push(dtype);
|
||||
let q = name.replace('`', "``");
|
||||
select_parts.push(format!("countIf(`{}` IS NULL) AS null_{}", q, col_names.len()));
|
||||
select_parts.push(format!("uniqHLL12(`{}`) AS dist_{}", q, col_names.len()));
|
||||
select_parts.push(format!("toString(min(`{}`)) AS min_{}", q, col_names.len()));
|
||||
select_parts.push(format!("toString(max(`{}`)) AS max_{}", q, col_names.len()));
|
||||
select_parts.push(format!(
|
||||
"arrayStringConcat(arrayMap(x -> toString(x), topK({})(`{}`)), '|') AS top_{}",
|
||||
PROFILE_TABLE_TOPK,
|
||||
q,
|
||||
col_names.len()
|
||||
));
|
||||
}
|
||||
|
||||
let agg_sql = format!(
|
||||
"SELECT {} FROM `{}`.`{}`",
|
||||
select_parts.join(", "),
|
||||
dbn.replace('`', "``"),
|
||||
table.replace('`', "``")
|
||||
);
|
||||
let agg_rows = client.fetch_objects(&agg_sql).await?;
|
||||
let row = agg_rows
|
||||
.first()
|
||||
.ok_or_else(|| TuskError::Custom("ClickHouse returned no row for profile aggregate".into()))?;
|
||||
|
||||
let rows_total = row
|
||||
.get("rows_total")
|
||||
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
|
||||
.unwrap_or(0);
|
||||
|
||||
let mut out = format!(
|
||||
"PROFILE {}.{}\nRows: {}\n{} columns profiled\n",
|
||||
dbn,
|
||||
table,
|
||||
rows_total,
|
||||
col_names.len()
|
||||
);
|
||||
|
||||
for (i, name) in col_names.iter().enumerate() {
|
||||
let n = i + 1;
|
||||
let nulls = row
|
||||
.get(&format!("null_{}", n))
|
||||
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
|
||||
.unwrap_or(0);
|
||||
let dist = row
|
||||
.get(&format!("dist_{}", n))
|
||||
.and_then(|v| v.as_str().and_then(|s| s.parse::<i64>().ok()).or_else(|| v.as_i64()))
|
||||
.unwrap_or(0);
|
||||
let min = row.get(&format!("min_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
|
||||
let max = row.get(&format!("max_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
|
||||
let top_raw = row.get(&format!("top_{}", n)).and_then(|v| v.as_str()).unwrap_or("");
|
||||
|
||||
out.push_str(&format!("\n {} ({}):\n", name, col_types[i]));
|
||||
let null_frac = if rows_total > 0 {
|
||||
nulls as f64 / rows_total as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
out.push_str(&format!(" null_frac: {:.4}\n", null_frac));
|
||||
out.push_str(&format!(" distinct (HLL): {}\n", dist));
|
||||
if !min.is_empty() || !max.is_empty() {
|
||||
out.push_str(&format!(" range: {} … {}\n", min, max));
|
||||
}
|
||||
if !top_raw.is_empty() {
|
||||
let top_vals: Vec<&str> = top_raw.split('|').take(PROFILE_TABLE_TOPK).collect();
|
||||
out.push_str(&format!(" top: {}\n", top_vals.join(", ")));
|
||||
}
|
||||
}
|
||||
|
||||
if col_rows.len() == PROFILE_TABLE_MAX_COLUMNS {
|
||||
out.push_str(&format!(
|
||||
"\n…showing first {} columns\n",
|
||||
PROFILE_TABLE_MAX_COLUMNS
|
||||
));
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// sample_data (PR2 — returns SQL string; dispatch site runs it through
|
||||
// execute_query_core so the QueryResult feeds the standard renderer)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn build_sample_sql(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
table: &str,
|
||||
limit: u32,
|
||||
) -> TuskResult<String> {
|
||||
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
|
||||
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
match flavor {
|
||||
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
|
||||
build_sample_sql_postgres(state, connection_id, &schema, &tbl, limit).await
|
||||
}
|
||||
DbFlavor::ClickHouse => {
|
||||
build_sample_sql_clickhouse(state, connection_id, &schema, &tbl, limit).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_sample_sql_postgres(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
schema: &str,
|
||||
table: &str,
|
||||
limit: u32,
|
||||
) -> TuskResult<String> {
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
let reltuples: f64 = sqlx::query_scalar(
|
||||
"SELECT c.reltuples FROM pg_class c JOIN pg_namespace n ON c.relnamespace = n.oid \
|
||||
WHERE n.nspname = $1 AND c.relname = $2",
|
||||
)
|
||||
.bind(schema)
|
||||
.bind(table)
|
||||
.fetch_optional(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let qualified = format!("{}.{}", escape_ident(schema), escape_ident(table));
|
||||
if reltuples > 0.0 {
|
||||
let target = limit as f64 * 100.0 / reltuples;
|
||||
let percent = target.clamp(0.01, 100.0);
|
||||
Ok(format!(
|
||||
"SELECT * FROM {} TABLESAMPLE BERNOULLI({:.4}) LIMIT {}",
|
||||
qualified, percent, limit
|
||||
))
|
||||
} else {
|
||||
Ok(format!(
|
||||
"SELECT * FROM {} ORDER BY random() LIMIT {}",
|
||||
qualified, limit
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_sample_sql_clickhouse(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
schema: &str,
|
||||
table: &str,
|
||||
limit: u32,
|
||||
) -> TuskResult<String> {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let active_db = client.database.clone();
|
||||
let dbn = if schema == "public" || schema.is_empty() {
|
||||
active_db
|
||||
} else {
|
||||
schema.to_string()
|
||||
};
|
||||
|
||||
let info_sql = format!(
|
||||
"SELECT engine, sampling_key FROM system.tables \
|
||||
WHERE database = '{}' AND name = '{}' LIMIT 1",
|
||||
dbn.replace('\'', "\\'"),
|
||||
table.replace('\'', "\\'")
|
||||
);
|
||||
let rows = client.fetch_objects(&info_sql).await.unwrap_or_default();
|
||||
let (engine, sampling_key) = match rows.first() {
|
||||
Some(r) => (
|
||||
r.get("engine").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
r.get("sampling_key").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
),
|
||||
None => (String::new(), String::new()),
|
||||
};
|
||||
|
||||
let qualified = format!(
|
||||
"`{}`.`{}`",
|
||||
dbn.replace('`', "``"),
|
||||
table.replace('`', "``")
|
||||
);
|
||||
if engine.starts_with("Merge") && !sampling_key.trim().is_empty() {
|
||||
Ok(format!(
|
||||
"SELECT * FROM {} SAMPLE 0.01 LIMIT {}",
|
||||
qualified, limit
|
||||
))
|
||||
} else {
|
||||
Ok(format!(
|
||||
"SELECT * FROM {} ORDER BY rand() LIMIT {}",
|
||||
qualified, limit
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// explain_query (PR2)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn explain_query_tool(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
sql: &str,
|
||||
) -> TuskResult<String> {
|
||||
let trimmed = sql.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Err(TuskError::Custom("explain_query: sql must not be empty".into()));
|
||||
}
|
||||
// Validate the user's statement BEFORE prefixing EXPLAIN so the error message
|
||||
// references their SQL, not the wrapper. ensure_readonly_sql also rejects any
|
||||
// forbidden keywords (INSERT/UPDATE/DELETE/...) even nested under EXPLAIN.
|
||||
ensure_readonly_sql(trimmed).map_err(|e| TuskError::Custom(e.to_string()))?;
|
||||
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
match flavor {
|
||||
DbFlavor::PostgreSQL | DbFlavor::Greenplum => {
|
||||
explain_query_postgres(state, connection_id, trimmed).await
|
||||
}
|
||||
DbFlavor::ClickHouse => explain_query_clickhouse(state, connection_id, trimmed).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn explain_query_postgres(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
sql: &str,
|
||||
) -> TuskResult<String> {
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
let plan_sql = format!("EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) {}", sql);
|
||||
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
|
||||
sqlx::query("SET TRANSACTION READ ONLY")
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let row = sqlx::query(&plan_sql)
|
||||
.fetch_one(&mut *tx)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let _ = tx.rollback().await;
|
||||
|
||||
let raw_json: serde_json::Value = match row.try_get::<serde_json::Value, _>(0) {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
let s: String = row.try_get(0).map_err(TuskError::Database)?;
|
||||
serde_json::from_str(&s)
|
||||
.map_err(|e| TuskError::Custom(format!("EXPLAIN JSON parse failed: {}", e)))?
|
||||
}
|
||||
};
|
||||
|
||||
let plans = raw_json
|
||||
.as_array()
|
||||
.ok_or_else(|| TuskError::Custom("EXPLAIN JSON: expected array".into()))?;
|
||||
let plan = plans.first().and_then(|p| p.get("Plan")).ok_or_else(|| {
|
||||
TuskError::Custom("EXPLAIN JSON: missing top-level Plan node".into())
|
||||
})?;
|
||||
|
||||
let root_node = plan.get("Node Type").and_then(|v| v.as_str()).unwrap_or("?");
|
||||
let total_cost = plan.get("Total Cost").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
let planning = plans
|
||||
.first()
|
||||
.and_then(|p| p.get("Planning Time").and_then(|v| v.as_f64()))
|
||||
.unwrap_or(0.0);
|
||||
let execution = plans
|
||||
.first()
|
||||
.and_then(|p| p.get("Execution Time").and_then(|v| v.as_f64()))
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let mut seq_scans: Vec<String> = Vec::new();
|
||||
let mut spilled: Vec<String> = Vec::new();
|
||||
let mut motions: Vec<String> = Vec::new();
|
||||
let mut max_skew: Option<(f64, String)> = None;
|
||||
walk_pg_plan(plan, &mut seq_scans, &mut spilled, &mut motions, &mut max_skew);
|
||||
|
||||
let mut out = format!(
|
||||
"PLAN root: {}, total cost {:.1}\nPlanning: {:.2} ms Execution: {:.2} ms\n",
|
||||
root_node, total_cost, planning, execution
|
||||
);
|
||||
if !seq_scans.is_empty() {
|
||||
out.push_str(&format!("Seq scans on: {}\n", seq_scans.join(", ")));
|
||||
}
|
||||
if !spilled.is_empty() {
|
||||
out.push_str(&format!("Spilled to disk: {}\n", spilled.join(", ")));
|
||||
}
|
||||
if !motions.is_empty() {
|
||||
out.push_str(&format!("Motions (Greenplum): {}\n", motions.join(", ")));
|
||||
}
|
||||
if let Some((ratio, node)) = max_skew {
|
||||
if ratio >= 5.0 {
|
||||
out.push_str(&format!(
|
||||
"Estimate skew: max plan/actual ratio = {:.1} on {}\n",
|
||||
ratio, node
|
||||
));
|
||||
}
|
||||
}
|
||||
if seq_scans.is_empty() && spilled.is_empty() && motions.is_empty() {
|
||||
out.push_str("No obvious red flags.\n");
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn walk_pg_plan(
|
||||
node: &serde_json::Value,
|
||||
seq_scans: &mut Vec<String>,
|
||||
spilled: &mut Vec<String>,
|
||||
motions: &mut Vec<String>,
|
||||
max_skew: &mut Option<(f64, String)>,
|
||||
) {
|
||||
let node_type = node.get("Node Type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
if node_type == "Seq Scan" {
|
||||
let rel = node
|
||||
.get("Relation Name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("?");
|
||||
let schema = node
|
||||
.get("Schema")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| format!("{}.", s))
|
||||
.unwrap_or_default();
|
||||
seq_scans.push(format!("{}{}", schema, rel));
|
||||
}
|
||||
if let Some(method) = node.get("Sort Method").and_then(|v| v.as_str()) {
|
||||
if method.contains("disk") || method.contains("external") {
|
||||
spilled.push(format!("Sort ({})", method));
|
||||
}
|
||||
}
|
||||
if node_type.contains("Motion") {
|
||||
motions.push(node_type.to_string());
|
||||
}
|
||||
let plan_rows = node.get("Plan Rows").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
let actual_rows = node.get("Actual Rows").and_then(|v| v.as_f64()).unwrap_or(0.0);
|
||||
if actual_rows > 0.0 && plan_rows > 0.0 {
|
||||
let ratio = (plan_rows / actual_rows).max(actual_rows / plan_rows);
|
||||
if max_skew.as_ref().map(|(r, _)| ratio > *r).unwrap_or(true) {
|
||||
*max_skew = Some((ratio, node_type.to_string()));
|
||||
}
|
||||
}
|
||||
if let Some(children) = node.get("Plans").and_then(|v| v.as_array()) {
|
||||
for child in children {
|
||||
walk_pg_plan(child, seq_scans, spilled, motions, max_skew);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn explain_query_clickhouse(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
sql: &str,
|
||||
) -> TuskResult<String> {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let plan_sql = format!("EXPLAIN PLAN {}", sql);
|
||||
let qr = client.execute_query(&plan_sql, true).await?;
|
||||
if qr.rows.is_empty() {
|
||||
return Ok("(empty plan)".to_string());
|
||||
}
|
||||
let mut out = String::from("ClickHouse plan:\n");
|
||||
for row in &qr.rows {
|
||||
if let Some(cell) = row.first() {
|
||||
if let Some(s) = cell.as_str() {
|
||||
out.push_str(s);
|
||||
out.push('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// detect_skew (PR2 — Greenplum-only)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn detect_skew_tool(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
table: &str,
|
||||
) -> TuskResult<String> {
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if !matches!(flavor, DbFlavor::Greenplum) {
|
||||
return Ok("detect_skew is only available on Greenplum connections.".to_string());
|
||||
}
|
||||
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
|
||||
let (schema, tbl, _raw) = normalise_table_ref(table, &active_db);
|
||||
|
||||
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&tbl));
|
||||
let sql = format!(
|
||||
"SELECT gp_segment_id, COUNT(*) AS n FROM {} GROUP BY 1 ORDER BY 1",
|
||||
qualified
|
||||
);
|
||||
let qr = execute_query_core(state, connection_id, &sql).await?;
|
||||
|
||||
let mut counts: Vec<(i64, i64)> = Vec::new();
|
||||
for row in &qr.rows {
|
||||
let seg = row
|
||||
.get(0)
|
||||
.and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok())))
|
||||
.unwrap_or(0);
|
||||
let n = row
|
||||
.get(1)
|
||||
.and_then(|v| v.as_i64().or_else(|| v.as_str().and_then(|s| s.parse().ok())))
|
||||
.unwrap_or(0);
|
||||
counts.push((seg, n));
|
||||
}
|
||||
|
||||
if counts.is_empty() {
|
||||
return Ok(format!("Table {}.{} is empty.", schema, tbl));
|
||||
}
|
||||
|
||||
let total: i64 = counts.iter().map(|(_, n)| *n).sum();
|
||||
let max = counts.iter().map(|(_, n)| *n).max().unwrap_or(0);
|
||||
let min = counts.iter().map(|(_, n)| *n).min().unwrap_or(0);
|
||||
let avg = total as f64 / counts.len() as f64;
|
||||
let ratio = if avg > 0.0 { max as f64 / avg } else { 0.0 };
|
||||
|
||||
let mut out = format!(
|
||||
"Per-segment row distribution for {}.{}\nsegments: {} total rows: {}\nmin: {} max: {} avg: {:.0}\nskew ratio (max/avg): {:.2}",
|
||||
schema,
|
||||
tbl,
|
||||
counts.len(),
|
||||
total,
|
||||
min,
|
||||
max,
|
||||
avg,
|
||||
ratio
|
||||
);
|
||||
if ratio > 1.5 {
|
||||
out.push_str(" ⚠ uneven distribution\n");
|
||||
} else {
|
||||
out.push_str(" OK — within 1.5x of average\n");
|
||||
}
|
||||
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
if let Some(policy) = fetch_gp_distribution_for(&pool, &schema, &tbl).await {
|
||||
out.push_str(&format!("\nCurrent policy: {}\n", policy));
|
||||
if ratio > 1.5 {
|
||||
out.push_str(
|
||||
"Hint: pick a higher-cardinality column. Run profile_table to compare n_distinct.\n",
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
/// Fetch the Greenplum DISTRIBUTED BY policy for a single table. Returns None if
|
||||
/// the catalog query fails (non-GP connection, missing privileges, etc.).
|
||||
async fn fetch_gp_distribution_for(
|
||||
pool: &PgPool,
|
||||
schema: &str,
|
||||
table: &str,
|
||||
) -> Option<String> {
|
||||
let row = sqlx::query(
|
||||
"SELECT COALESCE(\
|
||||
(SELECT array_agg(a.attname ORDER BY ord.idx) \
|
||||
FROM regexp_split_to_table(NULLIF(trim(p.distkey::text), ''), ' ') \
|
||||
WITH ORDINALITY AS ord(attnum_str, idx) \
|
||||
JOIN pg_attribute a \
|
||||
ON a.attrelid = c.oid \
|
||||
AND a.attnum::int = ord.attnum_str::int), \
|
||||
ARRAY[]::text[] \
|
||||
) AS dist_columns \
|
||||
FROM gp_distribution_policy p \
|
||||
JOIN pg_class c ON p.localoid = c.oid \
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid \
|
||||
WHERE n.nspname = $1 AND c.relname = $2",
|
||||
)
|
||||
.bind(schema)
|
||||
.bind(table)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()?;
|
||||
let cols: Vec<String> = row.try_get(0).ok()?;
|
||||
Some(if cols.is_empty() {
|
||||
"DISTRIBUTED RANDOMLY".to_string()
|
||||
} else {
|
||||
format!("DISTRIBUTED BY ({})", cols.join(", "))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -111,9 +111,7 @@ pub fn run() {
|
||||
commands::ai::save_ai_settings,
|
||||
commands::ai::list_ollama_models,
|
||||
commands::ai::list_fireworks_models,
|
||||
commands::ai::generate_sql,
|
||||
commands::ai::explain_sql,
|
||||
commands::ai::fix_sql_error,
|
||||
commands::ai::list_openrouter_models,
|
||||
// chat
|
||||
commands::chat::chat_send,
|
||||
commands::chat::chat_compact,
|
||||
|
||||
@@ -1,13 +1,26 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum AiProvider {
|
||||
#[default]
|
||||
Ollama,
|
||||
OpenAi,
|
||||
Anthropic,
|
||||
Fireworks,
|
||||
OpenRouter,
|
||||
}
|
||||
|
||||
/// Deserialize a provider string, coercing legacy `openai`/`anthropic` and any
|
||||
/// unknown value to `Ollama`. Keeps existing config files loadable after the
|
||||
/// stub providers were removed.
|
||||
impl<'de> Deserialize<'de> for AiProvider {
|
||||
fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
|
||||
let s = String::deserialize(d)?;
|
||||
Ok(match s.as_str() {
|
||||
"fireworks" => AiProvider::Fireworks,
|
||||
"openrouter" => AiProvider::OpenRouter,
|
||||
_ => AiProvider::Ollama,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -15,11 +28,9 @@ pub struct AiSettings {
|
||||
pub provider: AiProvider,
|
||||
pub ollama_url: String,
|
||||
#[serde(default)]
|
||||
pub openai_api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub anthropic_api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub fireworks_api_key: Option<String>,
|
||||
#[serde(default)]
|
||||
pub openrouter_api_key: Option<String>,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
@@ -28,9 +39,8 @@ impl Default for AiSettings {
|
||||
Self {
|
||||
provider: AiProvider::Ollama,
|
||||
ollama_url: "http://localhost:11434".to_string(),
|
||||
openai_api_key: None,
|
||||
anthropic_api_key: None,
|
||||
fireworks_api_key: None,
|
||||
openrouter_api_key: None,
|
||||
model: String::new(),
|
||||
}
|
||||
}
|
||||
@@ -71,7 +81,9 @@ pub struct OllamaModel {
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Fireworks (OpenAI-compatible chat-completions)
|
||||
// OpenAI-compatible chat-completions (Fireworks, OpenRouter)
|
||||
// These request/response shapes are shared by every OpenAI-compatible provider;
|
||||
// the `Fireworks*` names are retained for historical reasons.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
|
||||
@@ -31,18 +31,3 @@ pub struct ChatTurnResult {
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub usage: ContextUsage,
|
||||
}
|
||||
|
||||
/// Chart configuration produced by the agent's `make_chart` tool.
|
||||
/// Embedded as JSON in `ToolResult.text` for tool == "make_chart" while the
|
||||
/// underlying data lives in `ToolResult.result`. The frontend reads both to
|
||||
/// render the chart inline.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChartConfig {
|
||||
pub chart_type: String, // "bar" | "line" | "area" | "pie"
|
||||
pub x: String, // column name for X axis / category
|
||||
pub y: String, // column name for Y axis / numeric value
|
||||
pub group: Option<String>, // optional column for series grouping
|
||||
pub title: Option<String>,
|
||||
pub orientation: Option<String>, // "vertical" | "horizontal" — bar only
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{watch, RwLock};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
@@ -17,12 +17,6 @@ pub enum DbFlavor {
|
||||
}
|
||||
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SchemaCacheEntry {
|
||||
pub schema_text: String,
|
||||
pub cached_at: Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct CachedString {
|
||||
pub value: String,
|
||||
@@ -43,23 +37,16 @@ pub struct AppState {
|
||||
/// Greenplum major version (6 or 7), tracked separately because GP6 and GP7
|
||||
/// expose very different system catalogs (GP6 = PG9.4 base, GP7 = PG14 base).
|
||||
pub gp_majors: RwLock<HashMap<String, u8>>,
|
||||
/// Legacy cache used by generate_sql/explain_sql/fix_sql_error — full DDL.
|
||||
pub schema_cache: RwLock<HashMap<String, SchemaCacheEntry>>,
|
||||
/// Chat v2 caches: lite overview per connection.
|
||||
/// Chat agent caches: lite overview per connection.
|
||||
pub overview_cache: RwLock<HashMap<String, CachedString>>,
|
||||
/// Chat v2 caches: list of tables per (connection_id, db_name) — used for
|
||||
/// Chat agent caches: list of tables per (connection_id, db_name) — used for
|
||||
/// list_tables on a non-active PG database via temporary pool.
|
||||
pub tables_by_db_cache: RwLock<HashMap<(String, String), CachedVec<String>>>,
|
||||
/// Chat v2 caches: column block per (connection_id, db_name, "schema.table").
|
||||
pub columns_cache: RwLock<HashMap<(String, String, String), CachedString>>,
|
||||
pub mcp_shutdown_tx: watch::Sender<bool>,
|
||||
pub mcp_running: RwLock<bool>,
|
||||
pub ai_settings: RwLock<Option<AiSettings>>,
|
||||
}
|
||||
|
||||
const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes
|
||||
const SCHEMA_CACHE_MAX_SIZE: usize = 100;
|
||||
|
||||
impl AppState {
|
||||
pub fn new() -> Self {
|
||||
let (mcp_shutdown_tx, _) = watch::channel(false);
|
||||
@@ -69,10 +56,8 @@ impl AppState {
|
||||
read_only: RwLock::new(HashMap::new()),
|
||||
db_flavors: RwLock::new(HashMap::new()),
|
||||
gp_majors: RwLock::new(HashMap::new()),
|
||||
schema_cache: RwLock::new(HashMap::new()),
|
||||
overview_cache: RwLock::new(HashMap::new()),
|
||||
tables_by_db_cache: RwLock::new(HashMap::new()),
|
||||
columns_cache: RwLock::new(HashMap::new()),
|
||||
mcp_shutdown_tx,
|
||||
mcp_running: RwLock::new(false),
|
||||
ai_settings: RwLock::new(None),
|
||||
@@ -82,16 +67,11 @@ impl AppState {
|
||||
/// Drop every chat-agent cache entry tied to this connection.
|
||||
/// Called by switch_database_core, disconnect, and on connection delete.
|
||||
pub async fn invalidate_chat_caches_for(&self, connection_id: &str) {
|
||||
self.schema_cache.write().await.remove(connection_id);
|
||||
self.overview_cache.write().await.remove(connection_id);
|
||||
self.tables_by_db_cache
|
||||
.write()
|
||||
.await
|
||||
.retain(|(cid, _), _| cid != connection_id);
|
||||
self.columns_cache
|
||||
.write()
|
||||
.await
|
||||
.retain(|(cid, _, _), _| cid != connection_id);
|
||||
}
|
||||
|
||||
pub async fn get_pool(&self, connection_id: &str) -> TuskResult<PgPool> {
|
||||
@@ -125,39 +105,4 @@ impl AppState {
|
||||
pub async fn get_gp_major(&self, id: &str) -> Option<u8> {
|
||||
self.gp_majors.read().await.get(id).copied()
|
||||
}
|
||||
|
||||
pub async fn get_schema_cache(&self, connection_id: &str) -> Option<String> {
|
||||
let cache = self.schema_cache.read().await;
|
||||
cache.get(connection_id).and_then(|entry| {
|
||||
if entry.cached_at.elapsed() < SCHEMA_CACHE_TTL {
|
||||
Some(entry.schema_text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn set_schema_cache(&self, connection_id: String, schema_text: String) {
|
||||
let mut cache = self.schema_cache.write().await;
|
||||
// Evict stale entries to prevent unbounded memory growth
|
||||
cache.retain(|_, entry| entry.cached_at.elapsed() < SCHEMA_CACHE_TTL);
|
||||
// If still at capacity, remove the oldest entry
|
||||
if cache.len() >= SCHEMA_CACHE_MAX_SIZE {
|
||||
if let Some(oldest_key) = cache
|
||||
.iter()
|
||||
.min_by_key(|(_, e)| e.cached_at)
|
||||
.map(|(k, _)| k.clone())
|
||||
{
|
||||
cache.remove(&oldest_key);
|
||||
}
|
||||
}
|
||||
cache.insert(
|
||||
connection_id,
|
||||
SchemaCacheEntry {
|
||||
schema_text,
|
||||
cached_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
import { useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { AiSettingsPopover } from "./AiSettingsPopover";
|
||||
import { useGenerateSql } from "@/hooks/use-ai";
|
||||
import { Sparkles, Loader2, X, Eraser } from "lucide-react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface Props {
|
||||
connectionId: string;
|
||||
onSqlGenerated: (sql: string) => void;
|
||||
onClose: () => void;
|
||||
onExecute?: () => void;
|
||||
}
|
||||
|
||||
export function AiBar({ connectionId, onSqlGenerated, onClose, onExecute }: Props) {
|
||||
const [prompt, setPrompt] = useState("");
|
||||
const generateMutation = useGenerateSql();
|
||||
|
||||
const handleGenerate = () => {
|
||||
if (!prompt.trim() || generateMutation.isPending) return;
|
||||
generateMutation.mutate(
|
||||
{ connectionId, prompt },
|
||||
{
|
||||
onSuccess: (sql) => {
|
||||
onSqlGenerated(sql);
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI generation failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
if (e.key === "Enter" && (e.ctrlKey || e.metaKey)) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
onExecute?.();
|
||||
return;
|
||||
}
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
handleGenerate();
|
||||
return;
|
||||
}
|
||||
if (e.key === "Escape") {
|
||||
e.stopPropagation();
|
||||
onClose();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="tusk-ai-bar flex items-center gap-2 px-2 py-1.5 tusk-fade-in">
|
||||
<Sparkles className="h-3.5 w-3.5 shrink-0 tusk-ai-icon" />
|
||||
<Input
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Describe the query you want..."
|
||||
className="h-7 min-w-0 flex-1 border-tusk-purple/20 bg-tusk-purple/5 text-xs placeholder:text-muted-foreground/40 focus:border-tusk-purple/40 focus:ring-tusk-purple/20"
|
||||
autoFocus
|
||||
disabled={generateMutation.isPending}
|
||||
/>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
className="gap-1 text-[11px] text-tusk-purple hover:bg-tusk-purple/10 hover:text-tusk-purple"
|
||||
onClick={handleGenerate}
|
||||
disabled={generateMutation.isPending || !prompt.trim()}
|
||||
>
|
||||
{generateMutation.isPending ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
"Generate"
|
||||
)}
|
||||
</Button>
|
||||
{prompt.trim() && (
|
||||
<Button
|
||||
size="icon-xs"
|
||||
variant="ghost"
|
||||
onClick={() => setPrompt("")}
|
||||
title="Clear prompt"
|
||||
disabled={generateMutation.isPending}
|
||||
className="text-muted-foreground"
|
||||
>
|
||||
<Eraser className="h-3 w-3" />
|
||||
</Button>
|
||||
)}
|
||||
<AiSettingsPopover />
|
||||
<Button
|
||||
size="icon-xs"
|
||||
variant="ghost"
|
||||
onClick={onClose}
|
||||
title="Close AI bar"
|
||||
className="text-muted-foreground"
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -7,7 +7,11 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useFireworksModels, useOllamaModels } from "@/hooks/use-ai";
|
||||
import {
|
||||
useFireworksModels,
|
||||
useOllamaModels,
|
||||
useOpenRouterModels,
|
||||
} from "@/hooks/use-ai";
|
||||
import { RefreshCw, Loader2 } from "lucide-react";
|
||||
import type { AiProvider, OllamaModel } from "@/types";
|
||||
|
||||
@@ -17,6 +21,8 @@ interface Props {
|
||||
onOllamaUrlChange: (url: string) => void;
|
||||
fireworksApiKey: string;
|
||||
onFireworksApiKeyChange: (key: string) => void;
|
||||
openrouterApiKey: string;
|
||||
onOpenRouterApiKeyChange: (key: string) => void;
|
||||
model: string;
|
||||
onModelChange: (model: string) => void;
|
||||
}
|
||||
@@ -27,6 +33,8 @@ export function AiSettingsFields({
|
||||
onOllamaUrlChange,
|
||||
fireworksApiKey,
|
||||
onFireworksApiKeyChange,
|
||||
openrouterApiKey,
|
||||
onOpenRouterApiKeyChange,
|
||||
model,
|
||||
onModelChange,
|
||||
}: Props) {
|
||||
@@ -41,6 +49,17 @@ export function AiSettingsFields({
|
||||
);
|
||||
}
|
||||
|
||||
if (provider === "openrouter") {
|
||||
return (
|
||||
<OpenRouterFields
|
||||
apiKey={openrouterApiKey}
|
||||
onApiKeyChange={onOpenRouterApiKeyChange}
|
||||
model={model}
|
||||
onModelChange={onModelChange}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<OllamaFields
|
||||
ollamaUrl={ollamaUrl}
|
||||
@@ -143,6 +162,55 @@ function FireworksFields({
|
||||
);
|
||||
}
|
||||
|
||||
function OpenRouterFields({
|
||||
apiKey,
|
||||
onApiKeyChange,
|
||||
model,
|
||||
onModelChange,
|
||||
}: {
|
||||
apiKey: string;
|
||||
onApiKeyChange: (key: string) => void;
|
||||
model: string;
|
||||
onModelChange: (model: string) => void;
|
||||
}) {
|
||||
const {
|
||||
data: models,
|
||||
isLoading,
|
||||
isError,
|
||||
refetch,
|
||||
} = useOpenRouterModels(apiKey);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<label className="text-xs text-muted-foreground">OpenRouter API key</label>
|
||||
<Input
|
||||
type="password"
|
||||
value={apiKey}
|
||||
onChange={(e) => onApiKeyChange(e.target.value)}
|
||||
placeholder="sk-or-..."
|
||||
className="h-8 text-xs"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<p className="text-[10px] text-muted-foreground/70">
|
||||
Stored locally; sent only to openrouter.ai.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<ModelDropdown
|
||||
models={models}
|
||||
loading={isLoading}
|
||||
errored={isError}
|
||||
errorText="Cannot reach OpenRouter (check API key)"
|
||||
onRefresh={() => refetch()}
|
||||
model={model}
|
||||
onModelChange={onModelChange}
|
||||
emptyHint={apiKey.trim() ? "Click ↻ to load models" : "Enter API key first"}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function ModelDropdown({
|
||||
models,
|
||||
loading,
|
||||
|
||||
@@ -21,6 +21,7 @@ import type { AiProvider } from "@/types";
|
||||
const SUPPORTED_PROVIDERS: { value: AiProvider; label: string }[] = [
|
||||
{ value: "ollama", label: "Ollama (local)" },
|
||||
{ value: "fireworks", label: "Fireworks AI" },
|
||||
{ value: "openrouter", label: "OpenRouter" },
|
||||
];
|
||||
|
||||
export function AiSettingsPopover() {
|
||||
@@ -30,22 +31,16 @@ export function AiSettingsPopover() {
|
||||
const [provider, setProvider] = useState<AiProvider | null>(null);
|
||||
const [url, setUrl] = useState<string | null>(null);
|
||||
const [fireworksKey, setFireworksKey] = useState<string | null>(null);
|
||||
const [openrouterKey, setOpenrouterKey] = useState<string | null>(null);
|
||||
const [model, setModel] = useState<string | null>(null);
|
||||
|
||||
const settingsProvider = settings?.provider;
|
||||
// Hide unsupported legacy values (openai/anthropic) from the selector.
|
||||
const normalizedSettingsProvider: AiProvider | undefined =
|
||||
settingsProvider === "ollama" || settingsProvider === "fireworks"
|
||||
? settingsProvider
|
||||
: settingsProvider
|
||||
? "ollama"
|
||||
: undefined;
|
||||
|
||||
const currentProvider: AiProvider =
|
||||
provider ?? normalizedSettingsProvider ?? "ollama";
|
||||
provider ?? settings?.provider ?? "ollama";
|
||||
const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434";
|
||||
const currentFireworksKey =
|
||||
fireworksKey ?? settings?.fireworks_api_key ?? "";
|
||||
const currentOpenrouterKey =
|
||||
openrouterKey ?? settings?.openrouter_api_key ?? "";
|
||||
const currentModel = model ?? settings?.model ?? "";
|
||||
|
||||
const handleProviderChange = (next: AiProvider) => {
|
||||
@@ -64,6 +59,10 @@ export function AiSettingsPopover() {
|
||||
currentProvider === "fireworks"
|
||||
? currentFireworksKey.trim() || undefined
|
||||
: settings?.fireworks_api_key,
|
||||
openrouter_api_key:
|
||||
currentProvider === "openrouter"
|
||||
? currentOpenrouterKey.trim() || undefined
|
||||
: settings?.openrouter_api_key,
|
||||
model: currentModel,
|
||||
},
|
||||
{
|
||||
@@ -117,6 +116,8 @@ export function AiSettingsPopover() {
|
||||
onOllamaUrlChange={setUrl}
|
||||
fireworksApiKey={currentFireworksKey}
|
||||
onFireworksApiKeyChange={setFireworksKey}
|
||||
openrouterApiKey={currentOpenrouterKey}
|
||||
onOpenRouterApiKeyChange={setOpenrouterKey}
|
||||
model={currentModel}
|
||||
onModelChange={setModel}
|
||||
/>
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
import { useMemo } from "react";
|
||||
import {
|
||||
Area,
|
||||
AreaChart,
|
||||
Bar,
|
||||
BarChart,
|
||||
CartesianGrid,
|
||||
Cell,
|
||||
Legend,
|
||||
Line,
|
||||
LineChart,
|
||||
Pie,
|
||||
PieChart,
|
||||
ResponsiveContainer,
|
||||
Tooltip,
|
||||
XAxis,
|
||||
YAxis,
|
||||
} from "recharts";
|
||||
import type { ChartConfig } from "@/types";
|
||||
|
||||
interface Props {
|
||||
config: ChartConfig;
|
||||
columns: string[];
|
||||
rows: unknown[][];
|
||||
height?: number;
|
||||
}
|
||||
|
||||
const PALETTE = [
|
||||
"#60a5fa", // blue-400
|
||||
"#34d399", // emerald-400
|
||||
"#fbbf24", // amber-400
|
||||
"#f87171", // red-400
|
||||
"#a78bfa", // violet-400
|
||||
"#22d3ee", // cyan-400
|
||||
"#fb923c", // orange-400
|
||||
"#f472b6", // pink-400
|
||||
];
|
||||
|
||||
const MAX_POINTS = 500;
|
||||
|
||||
export function ChartPreview({ config, columns, rows, height = 280 }: Props) {
|
||||
const xIdx = columns.indexOf(config.x);
|
||||
const yIdx = columns.indexOf(config.y);
|
||||
const groupIdx = config.group ? columns.indexOf(config.group) : -1;
|
||||
|
||||
const limited = useMemo(() => rows.slice(0, MAX_POINTS), [rows]);
|
||||
|
||||
if (xIdx < 0 || yIdx < 0) {
|
||||
return (
|
||||
<ChartFallback
|
||||
config={config}
|
||||
message={`Column not found: ${xIdx < 0 ? config.x : config.y}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Coerce y values to numbers; chart libs need numeric Y.
|
||||
const numericY = (v: unknown): number => {
|
||||
if (typeof v === "number") return v;
|
||||
if (typeof v === "string") {
|
||||
const n = parseFloat(v);
|
||||
return Number.isFinite(n) ? n : 0;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
const labelX = (v: unknown): string => {
|
||||
if (v == null) return "—";
|
||||
if (typeof v === "string") return v;
|
||||
if (typeof v === "number" || typeof v === "boolean") return String(v);
|
||||
return JSON.stringify(v);
|
||||
};
|
||||
|
||||
const isGrouped = groupIdx >= 0;
|
||||
|
||||
// ──────────── grouped data shape ────────────
|
||||
// For multi-series: pivot to { x: <xValue>, <group1>: yVal, <group2>: yVal, … }
|
||||
// Used by line, area, and grouped-bar.
|
||||
const pivoted = useMemo(() => {
|
||||
if (!isGrouped) return null;
|
||||
const map = new Map<string, Record<string, unknown>>();
|
||||
const groupSet = new Set<string>();
|
||||
for (const row of limited) {
|
||||
const xv = labelX(row[xIdx]);
|
||||
const gv = labelX(row[groupIdx!]);
|
||||
const yv = numericY(row[yIdx]);
|
||||
groupSet.add(gv);
|
||||
const acc = map.get(xv) ?? { _x: xv };
|
||||
acc[gv] = ((acc[gv] as number) ?? 0) + yv;
|
||||
map.set(xv, acc);
|
||||
}
|
||||
return {
|
||||
data: Array.from(map.values()),
|
||||
groups: Array.from(groupSet),
|
||||
};
|
||||
}, [isGrouped, limited, xIdx, yIdx, groupIdx]);
|
||||
|
||||
// Single series shape: [{ _x, _y }]
|
||||
const flat = useMemo(() => {
|
||||
return limited.map((row) => ({
|
||||
_x: labelX(row[xIdx]),
|
||||
_y: numericY(row[yIdx]),
|
||||
}));
|
||||
}, [limited, xIdx, yIdx]);
|
||||
|
||||
const tickStyle = {
|
||||
fill: "var(--muted-foreground)",
|
||||
fontSize: 10,
|
||||
} as const;
|
||||
|
||||
const axisLine = {
|
||||
stroke: "rgba(255, 255, 255, 0.08)",
|
||||
} as const;
|
||||
|
||||
const tooltipStyle = {
|
||||
backgroundColor: "var(--popover)",
|
||||
border: "1px solid var(--border)",
|
||||
borderRadius: 6,
|
||||
fontSize: 11,
|
||||
} as const;
|
||||
|
||||
if (config.chart_type === "pie") {
|
||||
// Pie: aggregate y by x label (sum), no group support.
|
||||
const agg = new Map<string, number>();
|
||||
for (const row of limited) {
|
||||
const xv = labelX(row[xIdx]);
|
||||
agg.set(xv, (agg.get(xv) ?? 0) + numericY(row[yIdx]));
|
||||
}
|
||||
const data = Array.from(agg.entries()).map(([name, value]) => ({ name, value }));
|
||||
return (
|
||||
<ChartFrame config={config} height={height} count={data.length} totalRows={rows.length}>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<PieChart>
|
||||
<Pie
|
||||
data={data}
|
||||
dataKey="value"
|
||||
nameKey="name"
|
||||
outerRadius={Math.min(height / 2.5, 110)}
|
||||
label={(entry) =>
|
||||
typeof entry.name === "string" && entry.name.length < 20 ? entry.name : ""
|
||||
}
|
||||
>
|
||||
{data.map((_, i) => (
|
||||
<Cell key={i} fill={PALETTE[i % PALETTE.length]} />
|
||||
))}
|
||||
</Pie>
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
<Legend
|
||||
wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }}
|
||||
verticalAlign="bottom"
|
||||
/>
|
||||
</PieChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
if (config.chart_type === "line") {
|
||||
return (
|
||||
<ChartFrame
|
||||
config={config}
|
||||
height={height}
|
||||
count={isGrouped ? pivoted!.data.length : flat.length}
|
||||
totalRows={rows.length}
|
||||
>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<LineChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
|
||||
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
|
||||
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
{isGrouped ? (
|
||||
<>
|
||||
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
|
||||
{pivoted!.groups.map((g, i) => (
|
||||
<Line
|
||||
key={g}
|
||||
type="monotone"
|
||||
dataKey={g}
|
||||
stroke={PALETTE[i % PALETTE.length]}
|
||||
strokeWidth={2}
|
||||
dot={false}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
<Line type="monotone" dataKey="_y" stroke={PALETTE[0]} strokeWidth={2} dot={false} />
|
||||
)}
|
||||
</LineChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
if (config.chart_type === "area") {
|
||||
return (
|
||||
<ChartFrame
|
||||
config={config}
|
||||
height={height}
|
||||
count={isGrouped ? pivoted!.data.length : flat.length}
|
||||
totalRows={rows.length}
|
||||
>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<AreaChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
|
||||
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
|
||||
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
{isGrouped ? (
|
||||
<>
|
||||
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
|
||||
{pivoted!.groups.map((g, i) => (
|
||||
<Area
|
||||
key={g}
|
||||
type="monotone"
|
||||
dataKey={g}
|
||||
stackId="1"
|
||||
stroke={PALETTE[i % PALETTE.length]}
|
||||
fill={PALETTE[i % PALETTE.length]}
|
||||
fillOpacity={0.35}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
<Area
|
||||
type="monotone"
|
||||
dataKey="_y"
|
||||
stroke={PALETTE[0]}
|
||||
fill={PALETTE[0]}
|
||||
fillOpacity={0.35}
|
||||
/>
|
||||
)}
|
||||
</AreaChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
// bar (default)
|
||||
const horizontal = config.orientation === "horizontal";
|
||||
return (
|
||||
<ChartFrame
|
||||
config={config}
|
||||
height={height}
|
||||
count={isGrouped ? pivoted!.data.length : flat.length}
|
||||
totalRows={rows.length}
|
||||
>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<BarChart
|
||||
layout={horizontal ? "vertical" : "horizontal"}
|
||||
data={isGrouped ? pivoted!.data : flat}
|
||||
margin={{ top: 8, right: 12, left: horizontal ? 24 : 0, bottom: 4 }}
|
||||
>
|
||||
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={horizontal} horizontal={!horizontal} />
|
||||
{horizontal ? (
|
||||
<>
|
||||
<XAxis type="number" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis dataKey="_x" type="category" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} width={100} />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
</>
|
||||
)}
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
{isGrouped ? (
|
||||
<>
|
||||
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
|
||||
{pivoted!.groups.map((g, i) => (
|
||||
<Bar key={g} dataKey={g} fill={PALETTE[i % PALETTE.length]} radius={[3, 3, 0, 0]} />
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
<Bar dataKey="_y" fill={PALETTE[0]} radius={[3, 3, 0, 0]} />
|
||||
)}
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
function ChartFrame({
|
||||
config,
|
||||
height,
|
||||
count,
|
||||
totalRows,
|
||||
children,
|
||||
}: {
|
||||
config: ChartConfig;
|
||||
height: number;
|
||||
count: number;
|
||||
totalRows: number;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div className="rounded-md border border-border/40 bg-background">
|
||||
<div className="flex items-center gap-2 border-b border-border/30 px-2 py-1 text-[11px] text-muted-foreground">
|
||||
<span className="font-medium text-foreground/80">
|
||||
{config.title ?? `${capitalize(config.chart_type)} chart`}
|
||||
</span>
|
||||
<span className="ml-auto text-muted-foreground/60">
|
||||
{count} point{count === 1 ? "" : "s"}
|
||||
{totalRows > MAX_POINTS && ` (of ${totalRows}, capped at ${MAX_POINTS})`}
|
||||
</span>
|
||||
</div>
|
||||
<div className="p-2" style={{ minHeight: height }}>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ChartFallback({ config, message }: { config: ChartConfig; message: string }) {
|
||||
return (
|
||||
<div className="rounded-md border border-destructive/40 bg-destructive/5 p-3 text-xs">
|
||||
<div className="font-medium text-destructive">
|
||||
Chart {config.chart_type} failed
|
||||
</div>
|
||||
<div className="mt-1 text-muted-foreground">{message}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function capitalize(s: string) {
|
||||
return s.charAt(0).toUpperCase() + s.slice(1);
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useState } from "react";
|
||||
import { ResultsTable } from "@/components/results/ResultsTable";
|
||||
import { ExportDialog } from "@/components/export/ExportDialog";
|
||||
import { ChartPreview } from "./ChartPreview";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
@@ -15,19 +14,12 @@ import {
|
||||
AlertCircle,
|
||||
Sparkles,
|
||||
User,
|
||||
Wrench,
|
||||
Database,
|
||||
Columns,
|
||||
Layers,
|
||||
RefreshCw,
|
||||
StickyNote,
|
||||
Bookmark,
|
||||
BookmarkPlus,
|
||||
Maximize2,
|
||||
Download,
|
||||
BarChart3,
|
||||
} from "lucide-react";
|
||||
import type { ChartConfig, ChatMessage } from "@/types";
|
||||
import type { ChatMessage } from "@/types";
|
||||
import { getToolMeta, isQueryResultTool } from "./tool-registry";
|
||||
|
||||
interface Props {
|
||||
message: ChatMessage;
|
||||
@@ -79,8 +71,10 @@ function AssistantBubble({ text }: { text: string }) {
|
||||
|
||||
function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const preview = extractToolPreview(tool, inputJson);
|
||||
const Icon = iconForTool(tool);
|
||||
const meta = getToolMeta(tool);
|
||||
const preview = previewFromJson(tool, inputJson);
|
||||
const Icon = meta.icon;
|
||||
const showSqlPreview = (tool === "run_query" || tool === "explain_query") && preview;
|
||||
|
||||
return (
|
||||
<div className="ml-8 rounded-md border border-border/40 bg-muted/20">
|
||||
@@ -91,17 +85,14 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
|
||||
>
|
||||
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
|
||||
<Icon className="h-3 w-3" />
|
||||
<span className="font-medium">{labelForTool(tool)}</span>
|
||||
<span className="font-medium">{meta.label}</span>
|
||||
{preview && (
|
||||
<span className="ml-1 truncate text-muted-foreground/70">
|
||||
{preview.slice(0, 80)}
|
||||
{preview.length > 80 ? "…" : ""}
|
||||
</span>
|
||||
<span className="ml-1 truncate text-muted-foreground/70">{preview}</span>
|
||||
)}
|
||||
</button>
|
||||
{expanded && (
|
||||
<div className="border-t border-border/30 p-2">
|
||||
{tool === "run_query" && preview ? (
|
||||
{showSqlPreview ? (
|
||||
<pre className="overflow-x-auto whitespace-pre-wrap rounded bg-background/60 p-2 font-mono text-[11px]">
|
||||
{preview}
|
||||
</pre>
|
||||
@@ -116,6 +107,15 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
|
||||
);
|
||||
}
|
||||
|
||||
function previewFromJson(tool: string, inputJson: string): string | null {
|
||||
try {
|
||||
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
|
||||
return getToolMeta(tool).preview(parsed);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function ToolResultBlock({
|
||||
tool,
|
||||
isError,
|
||||
@@ -132,87 +132,20 @@ function ToolResultBlock({
|
||||
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
|
||||
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
|
||||
<div>
|
||||
<div className="font-medium text-destructive">{labelForTool(tool)} failed</div>
|
||||
<div className="font-medium text-destructive">{getToolMeta(tool).label} failed</div>
|
||||
{text && <div className="mt-1 whitespace-pre-wrap text-muted-foreground">{text}</div>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Legacy schema tool — keep a one-line indicator for old threads.
|
||||
if (tool === "get_schema") {
|
||||
return (
|
||||
<div className="ml-8 flex items-center gap-2 rounded-md border border-border/40 bg-muted/20 px-2 py-1.5 text-xs text-muted-foreground">
|
||||
<Database className="h-3 w-3" />
|
||||
<span>Loaded schema context ({text?.length ?? 0} chars)</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Text-only tools (chat v2/v3): list_databases, list_tables, get_columns, switch_database,
|
||||
// remember, save_query, find_queries.
|
||||
if (
|
||||
tool === "list_databases" ||
|
||||
tool === "list_tables" ||
|
||||
tool === "get_columns" ||
|
||||
tool === "switch_database" ||
|
||||
tool === "remember" ||
|
||||
tool === "save_query" ||
|
||||
tool === "find_queries"
|
||||
) {
|
||||
return <TextToolResult tool={tool} text={text} />;
|
||||
}
|
||||
|
||||
// make_chart — render chart inline using config from text + data from result.
|
||||
if (tool === "make_chart") {
|
||||
return <ChartToolResult text={text} result={result} />;
|
||||
}
|
||||
|
||||
// run_query — full results table with Open-full / Export actions.
|
||||
if (result) {
|
||||
// Tools that produce a QueryResult (rendered as a table): run_query, sample_data.
|
||||
if (isQueryResultTool(tool) && result) {
|
||||
return <RunQueryResultBlock result={result} />;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function ChartToolResult({
|
||||
text,
|
||||
result,
|
||||
}: {
|
||||
text: string | null;
|
||||
result: { columns: string[]; types: string[]; rows: unknown[][]; row_count: number; execution_time_ms: number } | null;
|
||||
}) {
|
||||
let config: ChartConfig | null = null;
|
||||
try {
|
||||
if (text) {
|
||||
config = JSON.parse(text) as ChartConfig;
|
||||
}
|
||||
} catch {
|
||||
config = null;
|
||||
}
|
||||
if (!config || !result) {
|
||||
return (
|
||||
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
|
||||
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
|
||||
<div>
|
||||
<div className="font-medium text-destructive">Chart unavailable</div>
|
||||
<div className="mt-1 text-muted-foreground">
|
||||
The agent referenced a chart but the previous query result is not attached.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div className="ml-8">
|
||||
<ChartPreview
|
||||
config={config}
|
||||
columns={result.columns}
|
||||
rows={result.rows}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
// Everything else falls back to a collapsible text block.
|
||||
return <TextToolResult tool={tool} text={text} />;
|
||||
}
|
||||
|
||||
function RunQueryResultBlock({
|
||||
@@ -315,8 +248,10 @@ function RunQueryResultBlock({
|
||||
}
|
||||
|
||||
function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
|
||||
// Lazy preview: switch_database is short; everything else collapses by default.
|
||||
const [expanded, setExpanded] = useState(tool === "switch_database");
|
||||
const Icon = iconForTool(tool);
|
||||
const meta = getToolMeta(tool);
|
||||
const Icon = meta.icon;
|
||||
const lineCount = text ? text.split("\n").length : 0;
|
||||
|
||||
return (
|
||||
@@ -328,7 +263,7 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
|
||||
>
|
||||
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
|
||||
<Icon className="h-3 w-3" />
|
||||
<span className="font-medium">{labelForTool(tool)}</span>
|
||||
<span className="font-medium">{meta.label}</span>
|
||||
{text && (
|
||||
<span className="ml-1 text-muted-foreground/60">
|
||||
{lineCount} line{lineCount === 1 ? "" : "s"}
|
||||
@@ -346,93 +281,6 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
|
||||
);
|
||||
}
|
||||
|
||||
function labelForTool(tool: string): string {
|
||||
switch (tool) {
|
||||
case "run_query":
|
||||
return "Run SQL";
|
||||
case "list_databases":
|
||||
return "List databases";
|
||||
case "list_tables":
|
||||
return "List tables";
|
||||
case "get_columns":
|
||||
return "Inspect columns";
|
||||
case "switch_database":
|
||||
return "Switch database";
|
||||
case "remember":
|
||||
return "Remember";
|
||||
case "save_query":
|
||||
return "Save query";
|
||||
case "find_queries":
|
||||
return "Find saved queries";
|
||||
case "make_chart":
|
||||
return "Make chart";
|
||||
case "get_schema":
|
||||
return "Load schema";
|
||||
default:
|
||||
return tool;
|
||||
}
|
||||
}
|
||||
|
||||
function iconForTool(tool: string) {
|
||||
switch (tool) {
|
||||
case "run_query":
|
||||
return Wrench;
|
||||
case "list_databases":
|
||||
return Database;
|
||||
case "list_tables":
|
||||
return Layers;
|
||||
case "get_columns":
|
||||
return Columns;
|
||||
case "switch_database":
|
||||
return RefreshCw;
|
||||
case "remember":
|
||||
return StickyNote;
|
||||
case "save_query":
|
||||
return BookmarkPlus;
|
||||
case "find_queries":
|
||||
return Bookmark;
|
||||
case "make_chart":
|
||||
return BarChart3;
|
||||
case "get_schema":
|
||||
return Database;
|
||||
default:
|
||||
return Wrench;
|
||||
}
|
||||
}
|
||||
|
||||
function extractToolPreview(tool: string, inputJson: string): string | null {
|
||||
try {
|
||||
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
|
||||
switch (tool) {
|
||||
case "run_query":
|
||||
return typeof parsed.sql === "string" ? parsed.sql : null;
|
||||
case "list_tables":
|
||||
return typeof parsed.database === "string" ? parsed.database : null;
|
||||
case "switch_database":
|
||||
return typeof parsed.database === "string" ? parsed.database : null;
|
||||
case "get_columns":
|
||||
return Array.isArray(parsed.tables) ? parsed.tables.join(", ") : null;
|
||||
case "remember":
|
||||
return typeof parsed.note === "string" ? parsed.note : null;
|
||||
case "save_query":
|
||||
return typeof parsed.name === "string" ? parsed.name : null;
|
||||
case "find_queries":
|
||||
return typeof parsed.text === "string" ? parsed.text : null;
|
||||
case "make_chart": {
|
||||
const t = typeof parsed.chart_type === "string" ? parsed.chart_type : null;
|
||||
const x = typeof parsed.x === "string" ? parsed.x : null;
|
||||
const y = typeof parsed.y === "string" ? parsed.y : null;
|
||||
if (t && x && y) return `${t}: ${x} → ${y}`;
|
||||
return null;
|
||||
}
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function prettyJson(s: string): string {
|
||||
try {
|
||||
return JSON.stringify(JSON.parse(s), null, 2);
|
||||
|
||||
107
src/components/chat/tool-registry.ts
Normal file
107
src/components/chat/tool-registry.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
import {
|
||||
Database,
|
||||
Layers,
|
||||
Columns,
|
||||
RefreshCw,
|
||||
Wrench,
|
||||
StickyNote,
|
||||
Bookmark,
|
||||
BookmarkPlus,
|
||||
Activity,
|
||||
Shuffle,
|
||||
GitBranch,
|
||||
AlertTriangle,
|
||||
} from "lucide-react";
|
||||
import type { LucideIcon } from "lucide-react";
|
||||
|
||||
export type ToolMeta = {
|
||||
icon: LucideIcon;
|
||||
label: string;
|
||||
preview: (parsed: Record<string, unknown>) => string | null;
|
||||
};
|
||||
|
||||
const truncate = (s: unknown, n = 80): string | null => {
|
||||
if (typeof s !== "string") return null;
|
||||
return s.length > n ? `${s.slice(0, n)}…` : s;
|
||||
};
|
||||
|
||||
export const TOOLS: Record<string, ToolMeta> = {
|
||||
list_databases: {
|
||||
icon: Database,
|
||||
label: "List databases",
|
||||
preview: () => null,
|
||||
},
|
||||
list_tables: {
|
||||
icon: Layers,
|
||||
label: "List tables",
|
||||
preview: (p) => (typeof p.database === "string" ? p.database : null),
|
||||
},
|
||||
get_columns: {
|
||||
icon: Columns,
|
||||
label: "Inspect columns",
|
||||
preview: (p) => (Array.isArray(p.tables) ? (p.tables as string[]).join(", ") : null),
|
||||
},
|
||||
switch_database: {
|
||||
icon: RefreshCw,
|
||||
label: "Switch database",
|
||||
preview: (p) => (typeof p.database === "string" ? p.database : null),
|
||||
},
|
||||
run_query: {
|
||||
icon: Wrench,
|
||||
label: "Run SQL",
|
||||
preview: (p) => truncate(p.sql),
|
||||
},
|
||||
remember: {
|
||||
icon: StickyNote,
|
||||
label: "Remember",
|
||||
preview: (p) => (typeof p.note === "string" ? p.note : null),
|
||||
},
|
||||
save_query: {
|
||||
icon: BookmarkPlus,
|
||||
label: "Save query",
|
||||
preview: (p) => (typeof p.name === "string" ? p.name : null),
|
||||
},
|
||||
find_queries: {
|
||||
icon: Bookmark,
|
||||
label: "Find saved queries",
|
||||
preview: (p) => (typeof p.text === "string" ? p.text : null),
|
||||
},
|
||||
profile_table: {
|
||||
icon: Activity,
|
||||
label: "Profile table",
|
||||
preview: (p) => (typeof p.table === "string" ? p.table : null),
|
||||
},
|
||||
sample_data: {
|
||||
icon: Shuffle,
|
||||
label: "Sample rows",
|
||||
preview: (p) => {
|
||||
const t = typeof p.table === "string" ? p.table : "";
|
||||
const limit = typeof p.limit === "number" ? p.limit : 50;
|
||||
return t ? `${t} (${limit})` : null;
|
||||
},
|
||||
},
|
||||
explain_query: {
|
||||
icon: GitBranch,
|
||||
label: "Explain query",
|
||||
preview: (p) => truncate(p.sql),
|
||||
},
|
||||
detect_skew: {
|
||||
icon: AlertTriangle,
|
||||
label: "Detect skew",
|
||||
preview: (p) => (typeof p.table === "string" ? p.table : null),
|
||||
},
|
||||
};
|
||||
|
||||
export function getToolMeta(tool: string): ToolMeta {
|
||||
return (
|
||||
TOOLS[tool] ?? {
|
||||
icon: Wrench,
|
||||
label: tool,
|
||||
preview: () => null,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export function isQueryResultTool(tool: string): boolean {
|
||||
return tool === "run_query" || tool === "sample_data";
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
import { ResultsTable } from "./ResultsTable";
|
||||
import { ResultsJsonView } from "./ResultsJsonView";
|
||||
import type { QueryResult } from "@/types";
|
||||
import { Loader2, AlertCircle, Sparkles, Wand2 } from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Loader2, AlertCircle } from "lucide-react";
|
||||
|
||||
interface Props {
|
||||
result?: QueryResult | null;
|
||||
@@ -15,10 +14,6 @@ interface Props {
|
||||
value: unknown
|
||||
) => void;
|
||||
highlightedCells?: Set<string>;
|
||||
aiExplanation?: string | null;
|
||||
isAiLoading?: boolean;
|
||||
onExplainError?: () => void;
|
||||
onFixError?: () => void;
|
||||
}
|
||||
|
||||
export function ResultsPanel({
|
||||
@@ -28,10 +23,6 @@ export function ResultsPanel({
|
||||
viewMode = "table",
|
||||
onCellDoubleClick,
|
||||
highlightedCells,
|
||||
aiExplanation,
|
||||
isAiLoading,
|
||||
onExplainError,
|
||||
onFixError,
|
||||
}: Props) {
|
||||
if (isLoading) {
|
||||
return (
|
||||
@@ -42,22 +33,6 @@ export function ResultsPanel({
|
||||
);
|
||||
}
|
||||
|
||||
if (aiExplanation) {
|
||||
return (
|
||||
<div className="h-full select-text overflow-auto p-4">
|
||||
<div className="rounded-md border bg-muted/30 p-4">
|
||||
<div className="mb-2 flex items-center gap-2 text-xs font-medium text-muted-foreground">
|
||||
<Sparkles className="h-3.5 w-3.5" />
|
||||
AI Explanation
|
||||
</div>
|
||||
<pre className="whitespace-pre-wrap font-sans text-sm leading-relaxed text-foreground">
|
||||
{aiExplanation}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="flex h-full select-text flex-col items-center justify-center gap-3 p-4">
|
||||
@@ -65,42 +40,6 @@ export function ResultsPanel({
|
||||
<AlertCircle className="mt-0.5 h-4 w-4 shrink-0" />
|
||||
<pre className="whitespace-pre-wrap font-mono text-xs">{error}</pre>
|
||||
</div>
|
||||
{(onExplainError || onFixError) && (
|
||||
<div className="flex items-center gap-2">
|
||||
{onExplainError && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
className="h-7 gap-1.5 text-xs"
|
||||
onClick={onExplainError}
|
||||
disabled={isAiLoading}
|
||||
>
|
||||
{isAiLoading ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
<Sparkles className="h-3 w-3" />
|
||||
)}
|
||||
Explain
|
||||
</Button>
|
||||
)}
|
||||
{onFixError && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
className="h-7 gap-1.5 text-xs"
|
||||
onClick={onFixError}
|
||||
disabled={isAiLoading}
|
||||
>
|
||||
{isAiLoading ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
<Wand2 className="h-3 w-3" />
|
||||
)}
|
||||
Fix with AI
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import type { AiProvider, AppSettings } from "@/types";
|
||||
const SUPPORTED_AI_PROVIDERS: { value: AiProvider; label: string }[] = [
|
||||
{ value: "ollama", label: "Ollama (local)" },
|
||||
{ value: "fireworks", label: "Fireworks AI" },
|
||||
{ value: "openrouter", label: "OpenRouter" },
|
||||
];
|
||||
|
||||
interface Props {
|
||||
@@ -50,6 +51,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
const [aiProvider, setAiProvider] = useState<AiProvider>("ollama");
|
||||
const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434");
|
||||
const [fireworksApiKey, setFireworksApiKey] = useState("");
|
||||
const [openrouterApiKey, setOpenrouterApiKey] = useState("");
|
||||
const [aiModel, setAiModel] = useState("");
|
||||
|
||||
const [copied, setCopied] = useState(false);
|
||||
@@ -70,10 +72,14 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
if (aiSettings) {
|
||||
// Legacy openai/anthropic values aren't user-selectable here — fall back to ollama.
|
||||
setAiProvider(
|
||||
aiSettings.provider === "fireworks" ? "fireworks" : "ollama"
|
||||
aiSettings.provider === "fireworks" ||
|
||||
aiSettings.provider === "openrouter"
|
||||
? aiSettings.provider
|
||||
: "ollama"
|
||||
);
|
||||
setOllamaUrl(aiSettings.ollama_url);
|
||||
setFireworksApiKey(aiSettings.fireworks_api_key ?? "");
|
||||
setOpenrouterApiKey(aiSettings.openrouter_api_key ?? "");
|
||||
setAiModel(aiSettings.model);
|
||||
}
|
||||
}
|
||||
@@ -115,6 +121,10 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
aiProvider === "fireworks"
|
||||
? fireworksApiKey.trim() || undefined
|
||||
: aiSettings?.fireworks_api_key,
|
||||
openrouter_api_key:
|
||||
aiProvider === "openrouter"
|
||||
? openrouterApiKey.trim() || undefined
|
||||
: aiSettings?.openrouter_api_key,
|
||||
model: aiModel,
|
||||
},
|
||||
{
|
||||
@@ -167,7 +177,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
<span
|
||||
className={`inline-block h-2 w-2 rounded-full ${
|
||||
mcpStatus?.running
|
||||
? "bg-green-500"
|
||||
? "bg-success ring-2 ring-success/25"
|
||||
: "bg-muted-foreground/30"
|
||||
}`}
|
||||
/>
|
||||
@@ -189,7 +199,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
title="Copy endpoint URL"
|
||||
>
|
||||
{copied ? (
|
||||
<Check className="h-3 w-3 text-green-500" />
|
||||
<Check className="h-3 w-3 text-success" />
|
||||
) : (
|
||||
<Copy className="h-3 w-3" />
|
||||
)}
|
||||
@@ -229,6 +239,8 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
onOllamaUrlChange={setOllamaUrl}
|
||||
fireworksApiKey={fireworksApiKey}
|
||||
onFireworksApiKeyChange={setFireworksApiKey}
|
||||
openrouterApiKey={openrouterApiKey}
|
||||
onOpenRouterApiKeyChange={setOpenrouterApiKey}
|
||||
model={aiModel}
|
||||
onModelChange={setAiModel}
|
||||
/>
|
||||
|
||||
@@ -13,7 +13,7 @@ import { useCompletionSchema } from "@/hooks/use-completion-schema";
|
||||
import { useConnections } from "@/hooks/use-connections";
|
||||
import { useAppStore } from "@/stores/app-store";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles, BrainCircuit } from "lucide-react";
|
||||
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces } from "lucide-react";
|
||||
import { format as formatSql } from "sql-formatter";
|
||||
import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog";
|
||||
import {
|
||||
@@ -25,8 +25,6 @@ import {
|
||||
import { exportCsv, exportJson } from "@/lib/tauri";
|
||||
import { save } from "@tauri-apps/plugin-dialog";
|
||||
import { toast } from "sonner";
|
||||
import { AiBar } from "@/components/ai/AiBar";
|
||||
import { useExplainSql, useFixSqlError } from "@/hooks/use-ai";
|
||||
import type { QueryResult, ExplainResult } from "@/types";
|
||||
|
||||
interface Props {
|
||||
@@ -53,12 +51,8 @@ export function WorkspacePanel({
|
||||
const [resultView, setResultView] = useState<"results" | "explain">("results");
|
||||
const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table");
|
||||
const [saveDialogOpen, setSaveDialogOpen] = useState(false);
|
||||
const [aiBarOpen, setAiBarOpen] = useState(false);
|
||||
const [aiExplanation, setAiExplanation] = useState<string | null>(null);
|
||||
|
||||
const queryMutation = useQueryExecution();
|
||||
const explainMutation = useExplainSql();
|
||||
const fixMutation = useFixSqlError();
|
||||
const addHistoryMutation = useAddHistory();
|
||||
const { data: connections } = useConnections();
|
||||
const { data: completionSchema } = useCompletionSchema(connectionId);
|
||||
@@ -102,7 +96,6 @@ export function WorkspacePanel({
|
||||
if (!sqlValue.trim() || !connectionId) return;
|
||||
setError(null);
|
||||
setExplainData(null);
|
||||
setAiExplanation(null);
|
||||
setResultView("results");
|
||||
queryMutation.mutate(
|
||||
{ connectionId, sql: sqlValue },
|
||||
@@ -196,60 +189,6 @@ export function WorkspacePanel({
|
||||
[result]
|
||||
);
|
||||
|
||||
const isAiLoading = explainMutation.isPending || fixMutation.isPending;
|
||||
|
||||
const handleAiExplain = useCallback(() => {
|
||||
if (!sqlValue.trim() || !connectionId) return;
|
||||
setAiExplanation(null);
|
||||
setResultView("results");
|
||||
explainMutation.mutate(
|
||||
{ connectionId, sql: sqlValue },
|
||||
{
|
||||
onSuccess: (explanation) => {
|
||||
setAiExplanation(explanation);
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI Explain failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
}, [connectionId, sqlValue, explainMutation]);
|
||||
|
||||
const handleExplainError = useCallback(() => {
|
||||
if (!sqlValue.trim() || !connectionId || !error) return;
|
||||
setAiExplanation(null);
|
||||
explainMutation.mutate(
|
||||
{ connectionId, sql: `${sqlValue}\n\n-- Error: ${error}` },
|
||||
{
|
||||
onSuccess: (explanation) => {
|
||||
setAiExplanation(explanation);
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI Explain failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
}, [connectionId, sqlValue, error, explainMutation]);
|
||||
|
||||
const handleFixError = useCallback(() => {
|
||||
if (!sqlValue.trim() || !connectionId || !error) return;
|
||||
fixMutation.mutate(
|
||||
{ connectionId, sql: sqlValue, errorMessage: error },
|
||||
{
|
||||
onSuccess: (fixedSql) => {
|
||||
setSqlValue(fixedSql);
|
||||
onSqlChange?.(fixedSql);
|
||||
setError(null);
|
||||
setAiExplanation(null);
|
||||
toast.success("SQL replaced by AI suggestion");
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI Fix failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
}, [connectionId, sqlValue, error, fixMutation, onSqlChange]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ResizablePanelGroup orientation="vertical">
|
||||
@@ -308,35 +247,6 @@ export function WorkspacePanel({
|
||||
Save
|
||||
</Button>
|
||||
|
||||
<div className="mx-1 h-3.5 w-px bg-border/40" />
|
||||
|
||||
{/* AI actions group — purple-branded */}
|
||||
<Button
|
||||
size="xs"
|
||||
variant={aiBarOpen ? "secondary" : "ghost"}
|
||||
className={`gap-1 text-[11px] ${aiBarOpen ? "text-tusk-purple" : ""}`}
|
||||
onClick={() => setAiBarOpen(!aiBarOpen)}
|
||||
title="AI SQL Generator"
|
||||
>
|
||||
<Sparkles className={`h-3 w-3 ${aiBarOpen ? "tusk-ai-icon" : ""}`} />
|
||||
AI
|
||||
</Button>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
className="gap-1 text-[11px]"
|
||||
onClick={handleAiExplain}
|
||||
disabled={isAiLoading || !sqlValue.trim()}
|
||||
title="Explain query with AI"
|
||||
>
|
||||
{isAiLoading ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
<BrainCircuit className="h-3 w-3" />
|
||||
)}
|
||||
AI Explain
|
||||
</Button>
|
||||
|
||||
{result && result.columns.length > 0 && (
|
||||
<>
|
||||
<div className="mx-1 h-3.5 w-px bg-border/40" />
|
||||
@@ -369,23 +279,12 @@ export function WorkspacePanel({
|
||||
{"\u2318"}Enter
|
||||
</span>
|
||||
{isReadOnly && (
|
||||
<span className="ml-2 flex items-center gap-1 rounded-sm bg-amber-500/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-amber-500">
|
||||
<span className="ml-2 flex items-center gap-1 rounded-sm bg-warning/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-warning">
|
||||
<Lock className="h-2.5 w-2.5" />
|
||||
READ
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{aiBarOpen && (
|
||||
<AiBar
|
||||
connectionId={connectionId}
|
||||
onSqlGenerated={(sql) => {
|
||||
setSqlValue(sql);
|
||||
onSqlChange?.(sql);
|
||||
}}
|
||||
onClose={() => setAiBarOpen(false)}
|
||||
onExecute={handleExecute}
|
||||
/>
|
||||
)}
|
||||
<div className="min-h-0 flex-1">
|
||||
<SqlEditor
|
||||
value={sqlValue}
|
||||
@@ -400,7 +299,7 @@ export function WorkspacePanel({
|
||||
<ResizableHandle withHandle />
|
||||
<ResizablePanel id="results" defaultSize="60%" minSize="15%">
|
||||
<div className="flex h-full flex-col overflow-hidden">
|
||||
{(explainData || result || error || aiExplanation) && (
|
||||
{(explainData || result || error) && (
|
||||
<div className="flex shrink-0 items-center border-b border-border/40 text-xs">
|
||||
<button
|
||||
className={`relative px-3 py-1.5 font-medium transition-colors ${
|
||||
@@ -469,10 +368,6 @@ export function WorkspacePanel({
|
||||
error={error}
|
||||
isLoading={queryMutation.isPending && resultView === "results"}
|
||||
viewMode={resultViewMode}
|
||||
aiExplanation={aiExplanation}
|
||||
isAiLoading={isAiLoading}
|
||||
onExplainError={error ? handleExplainError : undefined}
|
||||
onFixError={error ? handleFixError : undefined}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -4,9 +4,7 @@ import {
|
||||
saveAiSettings,
|
||||
listOllamaModels,
|
||||
listFireworksModels,
|
||||
generateSql,
|
||||
explainSql,
|
||||
fixSqlError,
|
||||
listOpenRouterModels,
|
||||
} from "@/lib/tauri";
|
||||
import type { AiSettings } from "@/types";
|
||||
|
||||
@@ -47,40 +45,12 @@ export function useFireworksModels(apiKey: string | undefined) {
|
||||
});
|
||||
}
|
||||
|
||||
export function useGenerateSql() {
|
||||
return useMutation({
|
||||
mutationFn: ({
|
||||
connectionId,
|
||||
prompt,
|
||||
}: {
|
||||
connectionId: string;
|
||||
prompt: string;
|
||||
}) => generateSql(connectionId, prompt),
|
||||
});
|
||||
}
|
||||
|
||||
export function useExplainSql() {
|
||||
return useMutation({
|
||||
mutationFn: ({
|
||||
connectionId,
|
||||
sql,
|
||||
}: {
|
||||
connectionId: string;
|
||||
sql: string;
|
||||
}) => explainSql(connectionId, sql),
|
||||
});
|
||||
}
|
||||
|
||||
export function useFixSqlError() {
|
||||
return useMutation({
|
||||
mutationFn: ({
|
||||
connectionId,
|
||||
sql,
|
||||
errorMessage,
|
||||
}: {
|
||||
connectionId: string;
|
||||
sql: string;
|
||||
errorMessage: string;
|
||||
}) => fixSqlError(connectionId, sql, errorMessage),
|
||||
export function useOpenRouterModels(apiKey: string | undefined) {
|
||||
return useQuery({
|
||||
queryKey: ["openrouter-models", apiKey],
|
||||
queryFn: () => listOpenRouterModels(apiKey!),
|
||||
enabled: !!apiKey && apiKey.trim().length > 0,
|
||||
retry: false,
|
||||
staleTime: 60_000,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -214,14 +214,8 @@ export const listOllamaModels = (ollamaUrl: string) =>
|
||||
export const listFireworksModels = (apiKey: string) =>
|
||||
invoke<OllamaModel[]>("list_fireworks_models", { apiKey });
|
||||
|
||||
export const generateSql = (connectionId: string, prompt: string) =>
|
||||
invoke<string>("generate_sql", { connectionId, prompt });
|
||||
|
||||
export const explainSql = (connectionId: string, sql: string) =>
|
||||
invoke<string>("explain_sql", { connectionId, sql });
|
||||
|
||||
export const fixSqlError = (connectionId: string, sql: string, errorMessage: string) =>
|
||||
invoke<string>("fix_sql_error", { connectionId, sql, errorMessage });
|
||||
export const listOpenRouterModels = (apiKey: string) =>
|
||||
invoke<OllamaModel[]>("list_openrouter_models", { apiKey });
|
||||
|
||||
export const chatSend = (connectionId: string, messages: ChatMessage[]) =>
|
||||
invoke<ChatTurnResult>("chat_send", { connectionId, messages });
|
||||
|
||||
@@ -134,14 +134,13 @@ export interface SavedQuery {
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export type AiProvider = "ollama" | "openai" | "anthropic" | "fireworks";
|
||||
export type AiProvider = "ollama" | "fireworks" | "openrouter";
|
||||
|
||||
export interface AiSettings {
|
||||
provider: AiProvider;
|
||||
ollama_url: string;
|
||||
openai_api_key?: string;
|
||||
anthropic_api_key?: string;
|
||||
fireworks_api_key?: string;
|
||||
openrouter_api_key?: string;
|
||||
model: string;
|
||||
}
|
||||
|
||||
@@ -216,14 +215,3 @@ export interface ChatTurnResult {
|
||||
messages: ChatMessage[];
|
||||
usage: ContextUsage;
|
||||
}
|
||||
|
||||
export type ChartType = "bar" | "line" | "area" | "pie";
|
||||
|
||||
export interface ChartConfig {
|
||||
chart_type: ChartType;
|
||||
x: string;
|
||||
y: string;
|
||||
group?: string | null;
|
||||
title?: string | null;
|
||||
orientation?: string | null;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user