refactor(ai): consolidate AI around chat tool-calling; add OpenRouter

- rework chat backend (chat.rs, chat_tools.rs, ai.rs, models, state) around tool calls
- add OpenRouter provider alongside Ollama/Fireworks in settings
- drop inline AiBar, ResultsPanel explain/fix UI and ChartPreview in favour of the chat panel
- add frontend chat tool-registry
This commit is contained in:
2026-05-23 15:01:52 +03:00
parent a485cf7ee3
commit 0cba457fb7
19 changed files with 1244 additions and 1931 deletions

View File

@@ -1,13 +1,14 @@
use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings};
use crate::commands::chat_tools::{
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
build_sample_sql, detect_skew_tool, explain_query_tool, find_queries_tool, get_columns_tool,
list_databases_tool, list_tables_tool, profile_table_tool, save_query_tool,
switch_database_tool,
};
use crate::commands::memory::{append_memory_core, read_memory_core};
use crate::commands::queries::execute_query_core;
use crate::error::{TuskError, TuskResult};
use crate::models::ai::OllamaChatMessage;
use crate::models::chat::{ChartConfig, ChatMessage, ChatTurnResult, ContextUsage};
use crate::models::chat::{ChatMessage, ChatTurnResult, ContextUsage};
use crate::models::query_result::QueryResult;
use crate::state::AppState;
use chrono::Utc;
@@ -30,11 +31,11 @@ const TEXT_TOOL_CHAR_CAP: usize = 10_000;
/// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192).
/// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content.
const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000;
/// Conservative default for managed providers (Fireworks). Most chat-capable
/// Fireworks models ship with 32K256K context windows; 384K chars (~128K tok)
/// is a safe floor that won't trigger false /compact nags on normal sessions
/// while still flagging genuinely runaway threads.
const CONTEXT_BUDGET_CHARS_FIREWORKS: u64 = 384_000;
/// Conservative default for managed providers (Fireworks, OpenRouter). Most
/// chat-capable hosted models ship with 32K256K context windows; 384K chars
/// (~128K tok) is a safe floor that won't trigger false /compact nags on normal
/// sessions while still flagging genuinely runaway threads.
const CONTEXT_BUDGET_CHARS_MANAGED: u64 = 384_000;
/// Stop the loop when the model fails the same SQL hurdle this many times in a
/// row. Beyond this, additional hops almost always burn the rest of the budget
/// on identical retries; a definitive `final` with the error is more useful.
@@ -55,9 +56,15 @@ enum AgentAction {
Remember { note: String },
SaveQuery { name: String, sql: String },
FindQueries { text: String },
MakeChart { config: ChartConfig },
ProfileTable { table: String },
SampleData { table: String, limit: u32 },
ExplainQuery { sql: String },
DetectSkew { table: String },
}
const SAMPLE_DATA_DEFAULT_LIMIT: u32 = 50;
const SAMPLE_DATA_MAX_LIMIT: u32 = 200;
/// Parse the model's JSON response. Accepts both shapes the model tends to emit:
/// {"action":"X","field":"..."} — flat (matches our prompt)
/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention)
@@ -157,60 +164,55 @@ fn parse_agent_action(raw: &str) -> Result<AgentAction, String> {
}
Ok(AgentAction::FindQueries { text })
}
"make_chart" => {
let chart_type = lookup("chart_type")
.or_else(|| lookup("type"))
"profile_table" => {
let table = lookup("table")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `chart_type`".to_string())?
.trim()
.to_lowercase();
if !["bar", "line", "area", "pie"].contains(&chart_type.as_str()) {
return Err(format!(
"make_chart `chart_type` must be one of: bar, line, area, pie. Got: {}",
chart_type
));
}
let x = lookup("x")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `x` column".to_string())?
.ok_or_else(|| "profile_table missing `table`".to_string())?
.trim()
.to_string();
let y = lookup("y")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `y` column".to_string())?
.trim()
.to_string();
if x.is_empty() || y.is_empty() {
return Err("make_chart `x` and `y` must not be empty".into());
if table.is_empty() {
return Err("profile_table `table` must not be empty".into());
}
let group = lookup("group")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
let title = lookup("title")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
let orientation = lookup("orientation")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty());
Ok(AgentAction::MakeChart {
config: ChartConfig {
chart_type,
x,
y,
group,
title,
orientation,
},
})
Ok(AgentAction::ProfileTable { table })
}
"sample_data" => {
let table = lookup("table")
.and_then(|v| v.as_str())
.ok_or_else(|| "sample_data missing `table`".to_string())?
.trim()
.to_string();
if table.is_empty() {
return Err("sample_data `table` must not be empty".into());
}
let limit = lookup("limit")
.and_then(|v| v.as_u64())
.map(|n| n as u32)
.unwrap_or(SAMPLE_DATA_DEFAULT_LIMIT)
.clamp(1, SAMPLE_DATA_MAX_LIMIT);
Ok(AgentAction::SampleData { table, limit })
}
"explain_query" => {
let sql = lookup("sql")
.and_then(|v| v.as_str())
.ok_or_else(|| "explain_query missing `sql`".to_string())?
.trim()
.to_string();
if sql.is_empty() {
return Err("explain_query `sql` must not be empty".into());
}
Ok(AgentAction::ExplainQuery { sql })
}
"detect_skew" => {
let table = lookup("table")
.and_then(|v| v.as_str())
.ok_or_else(|| "detect_skew missing `table`".to_string())?
.trim()
.to_string();
if table.is_empty() {
return Err("detect_skew `table` must not be empty".into());
}
Ok(AgentAction::DetectSkew { table })
}
// Legacy from earlier iterations — silently ignored at parse time so the
// model can recover with a different action.
"get_schema" => Err(
"get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(),
),
other => Err(format!("unknown action `{}`", other)),
}
}
@@ -285,8 +287,17 @@ You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On
{{"action":"save_query","name":"<short label>","sql":"<the SQL>"}}
Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved.
{{"action":"make_chart","chart_type":"bar","x":"<col>","y":"<col>","title":"<short title>"}}
Visualise the LAST successful run_query result as a chart inline. `chart_type` is one of: bar, line, area, pie. `x` and `y` MUST be column names from the previous result. Optional: `group` (column for series), `orientation` ("vertical"/"horizontal", bar only). Use after run_query when the data is aggregated and would be clearer as a chart (top-N comparisons → bar; time series → line/area; proportions → pie). Skip for tiny results (≤2 rows) and giant ones (>500 rows).
{{"action":"profile_table","table":"schema.table"}}
Per-column profile: NULL fraction, distinct cardinality, min/max range, top-K values. PG/GP reads pg_stats (zero-cost; ensure ANALYZE has run). ClickHouse fires one summary query (cheap on MergeTree). Use BEFORE writing aggregations to spot pseudo-enums, NULL-heavy columns, or skewed distributions.
{{"action":"sample_data","table":"schema.table","limit":50}}
Random row sample (default 50, max 200). PG/GP uses TABLESAMPLE BERNOULLI when reltuples > 0, else ORDER BY random(). CH uses SAMPLE 0.01 on MergeTree with a sampling key, else ORDER BY rand(). Use to eyeball value shape BEFORE writing filters; cheaper than `SELECT * LIMIT N` on huge tables.
{{"action":"explain_query","sql":"SELECT ..."}}
Run EXPLAIN (FORMAT JSON, ANALYZE, BUFFERS) on PG/GP, EXPLAIN PLAN on CH. Reports root node, planning + execution time, seq-scanned tables, spilled sorts, est-vs-actual row skew, Greenplum Motions. Use AFTER a slow run_query.
{{"action":"detect_skew","table":"schema.table"}}
Greenplum-only: counts rows per gp_segment_id and reports max/min/avg + skew ratio. Ratio > 1.5 ⇒ uneven distribution; suggests revisiting DISTRIBUTED BY. Soft-errors on PG/CH.
{{"action":"final","text":"..."}}
End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling).
@@ -296,6 +307,8 @@ WORKFLOW
2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question.
3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs.
4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns.
4a. If the user asks about value shape (cardinality, NULL rates, top values), prefer `profile_table` over a hand-written run_query. To eyeball actual rows, prefer `sample_data` over `LIMIT 100`.
4b. If the user reports a slow query or asks why something takes long, run `explain_query` on it. On Greenplum, if a single table appears unbalanced, check `detect_skew`.
5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first.
6. Execute run_query.
7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here.
@@ -415,9 +428,6 @@ fn build_history(
content: serde_json::json!({ "action": "final", "text": text }).to_string(),
}),
ChatMessage::ToolCall { tool, input_json, .. } => {
if tool == "get_schema" {
continue; // legacy
}
let mut envelope = serde_json::Map::new();
envelope.insert("action".to_string(), Value::String(tool.clone()));
if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) {
@@ -437,9 +447,6 @@ fn build_history(
result,
..
} => {
if tool == "get_schema" {
continue; // legacy
}
let payload = match tool.as_str() {
"run_query" => {
if *is_error {
@@ -521,7 +528,7 @@ async fn provider_budget_chars(state: &AppState, app: &AppHandle) -> u64 {
use crate::models::ai::AiProvider;
match load_ai_settings(app, state).await {
Ok(s) => match s.provider {
AiProvider::Fireworks => CONTEXT_BUDGET_CHARS_FIREWORKS,
AiProvider::Fireworks | AiProvider::OpenRouter => CONTEXT_BUDGET_CHARS_MANAGED,
_ => CONTEXT_BUDGET_CHARS_OLLAMA,
},
Err(_) => CONTEXT_BUDGET_CHARS_OLLAMA,
@@ -597,7 +604,10 @@ pub async fn chat_send(
}
};
let is_run_query = matches!(&action, AgentAction::RunQuery { .. });
let is_run_query = matches!(
&action,
AgentAction::RunQuery { .. } | AgentAction::SampleData { .. }
);
match action {
AgentAction::Final { text } => {
@@ -742,91 +752,90 @@ pub async fn chat_send(
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::MakeChart { config } => {
let config_json = serde_json::to_string(&config).unwrap_or_else(|_| "{}".into());
AgentAction::ProfileTable { table } => {
push_tool_call(
&mut new_messages,
&mut working,
"make_chart",
config_json.clone(),
"profile_table",
serde_json::json!({ "table": &table }).to_string(),
);
let result_msg = match last_successful_query_result(&working) {
None => ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(
"make_chart needs a successful run_query result above it. Run a SELECT first, then call make_chart."
.to_string(),
),
result: None,
created_at: now_ms(),
},
Some(qr) => {
if !qr.columns.iter().any(|c| c == &config.x) {
let result = run_text_tool(
profile_table_tool(&state, &connection_id, &table).await,
"profile_table",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::ExplainQuery { sql } => {
push_tool_call(
&mut new_messages,
&mut working,
"explain_query",
serde_json::json!({ "sql": &sql }).to_string(),
);
let result = run_text_tool(
explain_query_tool(&state, &connection_id, &sql).await,
"explain_query",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::DetectSkew { table } => {
push_tool_call(
&mut new_messages,
&mut working,
"detect_skew",
serde_json::json!({ "table": &table }).to_string(),
);
let result = run_text_tool(
detect_skew_tool(&state, &connection_id, &table).await,
"detect_skew",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::SampleData { table, limit } => {
push_tool_call(
&mut new_messages,
&mut working,
"sample_data",
serde_json::json!({ "table": &table, "limit": limit }).to_string(),
);
let outcome = match build_sample_sql(&state, &connection_id, &table, limit).await {
Ok(sql) => match execute_query_core(&state, &connection_id, &sql).await {
Ok(qr) => {
consecutive_query_errors = 0;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"x column `{}` is not in the last result. Available: {}.",
config.x,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if !qr.columns.iter().any(|c| c == &config.y) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"y column `{}` is not in the last result. Available: {}.",
config.y,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if let Some(group) = &config.group {
if !qr.columns.iter().any(|c| c == group) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"group column `{}` is not in the last result. Available: {}.",
group,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: false,
text: Some(config_json.clone()),
result: Some(qr),
created_at: now_ms(),
}
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
tool: "sample_data".to_string(),
is_error: false,
text: Some(config_json.clone()),
text: None,
result: Some(qr),
created_at: now_ms(),
}
}
Err(e) => {
consecutive_query_errors += 1;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "sample_data".to_string(),
is_error: true,
text: Some(format_db_error(&e)),
result: None,
created_at: now_ms(),
}
}
},
Err(e) => {
consecutive_query_errors += 1;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "sample_data".to_string(),
is_error: true,
text: Some(format_db_error(&e)),
result: None,
created_at: now_ms(),
}
}
};
push_tool_result(&mut new_messages, &mut working, result_msg);
push_tool_result(&mut new_messages, &mut working, outcome);
}
}
@@ -982,26 +991,6 @@ fn format_db_error(e: &TuskError) -> String {
e.to_string()
}
/// Locate the most recent SUCCESSFUL run_query in the working thread and
/// return its full QueryResult. Used by make_chart to attach data to a chart
/// directive without relying on the model to re-send it.
fn last_successful_query_result(messages: &[ChatMessage]) -> Option<QueryResult> {
for m in messages.iter().rev() {
if let ChatMessage::ToolResult {
tool,
is_error: false,
result: Some(qr),
..
} = m
{
if tool == "run_query" {
return Some(qr.clone());
}
}
}
None
}
/// Pull the most recent run_query error text from the working thread, so the
/// post-loop "I gave up" summary can quote concrete errors back to the user.
fn last_run_query_error(messages: &[ChatMessage]) -> Option<String> {
@@ -1484,119 +1473,6 @@ mod tests {
assert!(last_run_query_error(&msgs).is_none());
}
#[test]
fn parses_make_chart_minimal() {
let a = parse_agent_action(
r#"{"action":"make_chart","chart_type":"bar","x":"carrier","y":"trips"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => {
assert_eq!(config.chart_type, "bar");
assert_eq!(config.x, "carrier");
assert_eq!(config.y, "trips");
assert!(config.group.is_none());
assert!(config.title.is_none());
}
_ => panic!("wrong variant"),
}
}
#[test]
fn parses_make_chart_with_group_and_title() {
let a = parse_agent_action(
r#"{"action":"make_chart","chart_type":"line","x":"month","y":"revenue","group":"region","title":"Revenue"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => {
assert_eq!(config.group.as_deref(), Some("region"));
assert_eq!(config.title.as_deref(), Some("Revenue"));
}
_ => panic!("wrong variant"),
}
}
#[test]
fn make_chart_accepts_alternative_field_name_type() {
// Some models emit `type` instead of `chart_type`.
let a = parse_agent_action(
r#"{"action":"make_chart","type":"pie","x":"label","y":"value"}"#,
)
.unwrap();
match a {
AgentAction::MakeChart { config } => assert_eq!(config.chart_type, "pie"),
_ => panic!("wrong variant"),
}
}
#[test]
fn rejects_make_chart_with_unknown_chart_type() {
let r = parse_agent_action(
r#"{"action":"make_chart","chart_type":"radar","x":"a","y":"b"}"#,
);
assert!(r.is_err());
}
#[test]
fn rejects_make_chart_missing_x_or_y() {
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","y":"a"}"#).is_err());
assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","x":"a"}"#).is_err());
}
#[test]
fn last_successful_query_result_finds_recent() {
use crate::models::query_result::QueryResult;
let qr = QueryResult {
columns: vec!["a".into()],
types: vec!["INT4".into()],
rows: vec![vec![Value::Number(1.into())]],
row_count: 1,
execution_time_ms: 1,
};
let msgs = vec![
ChatMessage::ToolResult {
id: "r1".into(),
tool: "run_query".into(),
is_error: false,
text: None,
result: Some(qr.clone()),
created_at: 1,
},
ChatMessage::ToolResult {
id: "r2".into(),
tool: "run_query".into(),
is_error: true,
text: Some("oops".into()),
result: None,
created_at: 2,
},
];
let found = last_successful_query_result(&msgs).expect("ok");
assert_eq!(found.columns, vec!["a".to_string()]);
}
#[test]
fn last_successful_query_result_skips_non_run_query() {
use crate::models::query_result::QueryResult;
let qr = QueryResult {
columns: vec!["a".into()],
types: vec!["INT4".into()],
rows: vec![],
row_count: 0,
execution_time_ms: 0,
};
let msgs = vec![ChatMessage::ToolResult {
id: "r1".into(),
tool: "list_tables".into(),
is_error: false,
text: Some("public.x".into()),
result: Some(qr),
created_at: 1,
}];
assert!(last_successful_query_result(&msgs).is_none());
}
#[test]
fn render_thread_for_summary_includes_roles_and_skips_rows() {
let msgs = vec![
@@ -1625,11 +1501,6 @@ mod tests {
assert!(!rendered.contains("alice"));
}
#[test]
fn rejects_legacy_get_schema() {
assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err());
}
#[test]
fn truncates_long_cell() {
let long = "a".repeat(CELL_CHAR_CAP + 50);