Files
tusk/src-tauri/src/commands/chat.rs
Aleksey Shakhmatov 9a424dcd34 fix: use provider-aware context budget so Fireworks doesn't show 150% on small threads
The chat usage badge was hardcoded to ~8K-token Ollama defaults
(`CONTEXT_BUDGET_CHARS = 24_000`), which made every Fireworks session
look 150%+ full after a few hops even though models like Kimi-K2 carry
256K context windows. Now the budget is selected per-provider:

- Ollama → 24K chars (~8K tok), unchanged
- Fireworks → 384K chars (~128K tok), a safe floor for the smallest
  Fireworks chat models (qwen2.5-coder 32K) while not stuffing the bar
  for the larger ones

Auto-compact thresholds and the % badge both read this back from the
backend, so they now scale correctly when the user switches providers.
2026-05-06 23:11:56 +03:00

1662 lines
65 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings};
use crate::commands::chat_tools::{
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
switch_database_tool,
};
use crate::commands::memory::{append_memory_core, read_memory_core};
use crate::commands::queries::execute_query_core;
use crate::error::{TuskError, TuskResult};
use crate::models::ai::OllamaChatMessage;
use crate::models::chat::{ChartConfig, ChatMessage, ChatTurnResult, ContextUsage};
use crate::models::query_result::QueryResult;
use crate::state::AppState;
use chrono::Utc;
use serde_json::Value;
use std::sync::Arc;
use tauri::{AppHandle, State};
/// Maximum number of agent hops (one model call plus one tool action) per
/// user turn in `chat_send`; when exhausted, the loop falls back to a single
/// plain-text synthesis call.
const MAX_HOPS: usize = 10;
/// Number of MOST RECENT run_query tool_results that get full sample-rows in
/// LLM history. Older ones are reduced to a marker so very long threads stay
/// within model context budget.
const RECENT_TOOL_RESULTS_FULL: usize = 4;
/// Sample-row cap for compressed run_query results in LLM history.
const RUN_QUERY_SAMPLE_ROWS: usize = 10;
/// Per-cell character cap when stringifying sample rows.
const CELL_CHAR_CAP: usize = 200;
/// Per text-tool-result size cap (list_tables, get_columns, etc). Compared
/// against `str::len` in `truncate_text`, i.e. measured in UTF-8 bytes.
const TEXT_TOOL_CHAR_CAP: usize = 10_000;
/// Soft cap on serialized history+system prompt characters before the user
/// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192).
/// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content.
/// NOTE(review): usage is summed from `String::len` (UTF-8 bytes), so each
/// Cyrillic character counts as two units against this budget.
const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000;
/// Conservative default for managed providers (Fireworks). Most chat-capable
/// Fireworks models ship with 32K–256K-token context windows; 384K chars
/// (~128K tok) won't trigger false /compact nags on normal sessions while
/// still flagging genuinely runaway threads.
/// NOTE(review): ~128K tokens exceeds the window of the 32K-token models, so
/// for those the usage badge under-reports how full the context really is.
const CONTEXT_BUDGET_CHARS_FIREWORKS: u64 = 384_000;
/// Stop the loop after this many consecutive failed run_query calls (the
/// counter resets on a successful query or on any non-query tool action).
/// Beyond this, additional hops almost always burn the rest of the budget
/// on identical retries; a definitive `final` with the error is more useful.
const MAX_CONSECUTIVE_QUERY_ERRORS: usize = 2;
// ---------------------------------------------------------------------------
// Action protocol
// ---------------------------------------------------------------------------
/// A single step requested by the model, decoded by `parse_agent_action`
/// from its JSON reply and executed by the `chat_send` loop.
///
/// Variants are listed in the same order the system prompt introduces them.
#[derive(Debug)]
enum AgentAction {
    /// Refresh the list of databases.
    ListDatabases,
    /// List tables in `database`, or in the active database when `None`.
    ListTables { database: Option<String> },
    /// Load column details for the named tables.
    GetColumns { tables: Vec<String> },
    /// Make `database` the active database.
    SwitchDatabase { database: String },
    /// Execute `sql` against the active connection.
    RunQuery { sql: String },
    /// Append `note` to the per-connection memory.
    Remember { note: String },
    /// Search saved queries for `text`.
    FindQueries { text: String },
    /// Save `sql` under the label `name`.
    SaveQuery { name: String, sql: String },
    /// Chart the last successful run_query result.
    MakeChart { config: ChartConfig },
    /// End the turn with a plain-language answer.
    Final { text: String },
}
/// Parse the model's JSON response. Accepts both shapes the model tends to emit:
/// {"action":"X","field":"..."} — flat (matches our prompt)
/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention)
fn parse_agent_action(raw: &str) -> Result<AgentAction, String> {
let v: Value = serde_json::from_str(raw).map_err(|e| e.to_string())?;
let obj = v.as_object().ok_or_else(|| "expected JSON object".to_string())?;
let action = obj
.get("action")
.and_then(|a| a.as_str())
.ok_or_else(|| "missing field `action`".to_string())?;
let lookup = |key: &str| -> Option<&Value> {
obj.get(key)
.or_else(|| obj.get("input").and_then(|i| i.as_object()).and_then(|i| i.get(key)))
};
match action {
"final" => {
let text = lookup("text")
.and_then(|v| v.as_str())
.ok_or_else(|| "final action missing `text`".to_string())?
.to_string();
Ok(AgentAction::Final { text })
}
"run_query" => {
let sql = lookup("sql")
.and_then(|v| v.as_str())
.ok_or_else(|| "run_query action missing `sql`".to_string())?
.to_string();
Ok(AgentAction::RunQuery { sql })
}
"list_databases" => Ok(AgentAction::ListDatabases),
"list_tables" => {
let database = lookup("database")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
Ok(AgentAction::ListTables { database })
}
"get_columns" => {
let arr = lookup("tables")
.and_then(|v| v.as_array())
.ok_or_else(|| "get_columns action missing `tables`: [...]".to_string())?;
let tables: Vec<String> = arr
.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect();
if tables.is_empty() {
return Err("get_columns `tables` array must not be empty".into());
}
Ok(AgentAction::GetColumns { tables })
}
"switch_database" => {
let database = lookup("database")
.and_then(|v| v.as_str())
.ok_or_else(|| "switch_database missing `database`".to_string())?
.to_string();
Ok(AgentAction::SwitchDatabase { database })
}
"remember" => {
let note = lookup("note")
.and_then(|v| v.as_str())
.ok_or_else(|| "remember action missing `note`".to_string())?
.trim()
.to_string();
if note.is_empty() {
return Err("remember `note` must not be empty".into());
}
Ok(AgentAction::Remember { note })
}
"save_query" => {
let name = lookup("name")
.and_then(|v| v.as_str())
.ok_or_else(|| "save_query missing `name`".to_string())?
.trim()
.to_string();
let sql = lookup("sql")
.and_then(|v| v.as_str())
.ok_or_else(|| "save_query missing `sql`".to_string())?
.trim()
.to_string();
if name.is_empty() {
return Err("save_query `name` must not be empty".into());
}
if sql.is_empty() {
return Err("save_query `sql` must not be empty".into());
}
Ok(AgentAction::SaveQuery { name, sql })
}
"find_queries" => {
let text = lookup("text")
.and_then(|v| v.as_str())
.ok_or_else(|| "find_queries missing `text`".to_string())?
.trim()
.to_string();
if text.is_empty() {
return Err("find_queries `text` must not be empty".into());
}
Ok(AgentAction::FindQueries { text })
}
"make_chart" => {
let chart_type = lookup("chart_type")
.or_else(|| lookup("type"))
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `chart_type`".to_string())?
.trim()
.to_lowercase();
if !["bar", "line", "area", "pie"].contains(&chart_type.as_str()) {
return Err(format!(
"make_chart `chart_type` must be one of: bar, line, area, pie. Got: {}",
chart_type
));
}
let x = lookup("x")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `x` column".to_string())?
.trim()
.to_string();
let y = lookup("y")
.and_then(|v| v.as_str())
.ok_or_else(|| "make_chart missing `y` column".to_string())?
.trim()
.to_string();
if x.is_empty() || y.is_empty() {
return Err("make_chart `x` and `y` must not be empty".into());
}
let group = lookup("group")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
let title = lookup("title")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty());
let orientation = lookup("orientation")
.and_then(|v| v.as_str())
.map(|s| s.trim().to_lowercase())
.filter(|s| !s.is_empty());
Ok(AgentAction::MakeChart {
config: ChartConfig {
chart_type,
x,
y,
group,
title,
orientation,
},
})
}
// Legacy from earlier iterations — silently ignored at parse time so the
// model can recover with a different action.
"get_schema" => Err(
"get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(),
),
other => Err(format!("unknown action `{}`", other)),
}
}
// ---------------------------------------------------------------------------
// id / time helpers
// ---------------------------------------------------------------------------
/// Current wall-clock time in UTC, as milliseconds since the Unix epoch.
fn now_ms() -> i64 {
    let now = Utc::now();
    now.timestamp_millis()
}
/// Build a message id of the form `<prefix>-<unix millis>-<suffix>`.
fn new_id(prefix: &str) -> String {
    let millis = now_ms();
    let suffix = rand_suffix();
    [prefix.to_string(), millis.to_string(), suffix].join("-")
}
/// Hex suffix that keeps `new_id` values unique.
///
/// Combines the wall clock's sub-second nanoseconds with a process-wide
/// counter. The clock alone is not enough: ids can be minted in quick
/// succession (e.g. a tool call and its result), and two calls that land
/// within the clock's resolution would otherwise get identical suffixes.
fn rand_suffix() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);
    // Atomicity of the read-modify-write alone guarantees distinct values;
    // no cross-thread ordering is needed, so Relaxed suffices.
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    // subsec_nanos() < 1e9 < 16^8: a fixed 8-digit field keeps the
    // nanos/counter boundary unambiguous in the concatenation.
    format!("{:08x}{:x}", nanos, seq)
}
// ---------------------------------------------------------------------------
// System prompt
// ---------------------------------------------------------------------------
fn system_prompt(overview: &str, memory: &str) -> String {
let overview_block = if overview.is_empty() {
"(overview unavailable; respond with `final` asking the user to reconnect.)".to_string()
} else {
overview.to_string()
};
let memory_block = if memory.trim().is_empty() {
"(empty — call remember() when you discover non-obvious facts about this database)".to_string()
} else {
memory.to_string()
};
format!(
r#"ROLE: Tusk's data assistant. Reply in the user's language.
You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On every turn output STRICT JSON — exactly one of these shapes, with all fields at the root (no `input` wrapper, no markdown fences):
{{"action":"list_databases"}}
Refresh the database list when the OVERVIEW seems stale.
{{"action":"list_tables"}}
List tables in the active database.
{{"action":"list_tables","database":"<name>"}}
List tables in a specific database (PostgreSQL: requires switch_database before run_query).
{{"action":"get_columns","tables":["schema.table","schema.table2"]}}
Load full column info (types, PK, FK, comments, enums) for the listed tables. Use this BEFORE writing SQL when you don't already know the columns.
{{"action":"switch_database","database":"<name>"}}
Change the active database. Required for PostgreSQL when the user's question concerns data in another database. ClickHouse rarely needs this — `db.table` qualifiers are allowed without switching.
{{"action":"run_query","sql":"SELECT ..."}}
Execute read-only SQL (SELECT / WITH ... SELECT / EXPLAIN / SHOW / DESCRIBE). Mutating SQL is rejected by the read-only guard.
{{"action":"remember","note":"<short observation>"}}
Persist a non-obvious fact about THIS database for future sessions: column semantics, naming conventions, business-rule encodings, gotchas. Keep notes < 200 chars. The user sees and can edit your notes in the Memory sidebar tab.
{{"action":"find_queries","text":"<keywords>"}}
Search saved queries (your past work + user-saved). Use BEFORE writing complex SQL — a usable variant may already exist. Top 10 matches with SQL preview are returned.
{{"action":"save_query","name":"<short label>","sql":"<the SQL>"}}
Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved.
{{"action":"make_chart","chart_type":"bar","x":"<col>","y":"<col>","title":"<short title>"}}
Visualise the LAST successful run_query result as a chart inline. `chart_type` is one of: bar, line, area, pie. `x` and `y` MUST be column names from the previous result. Optional: `group` (column for series), `orientation` ("vertical"/"horizontal", bar only). Use after run_query when the data is aggregated and would be clearer as a chart (top-N comparisons → bar; time series → line/area; proportions → pie). Skip for tiny results (≤2 rows) and giant ones (>500 rows).
{{"action":"final","text":"..."}}
End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling).
WORKFLOW
1. Read LEARNED NOTES below first — the user (or your past self) may have already documented relevant facts.
2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question.
3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs.
4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns.
5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first.
6. Execute run_query.
7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here.
8. If the query is likely to be re-run later (a real report-style request, not a one-off lookup), call `save_query` with a concise `name`.
9. Answer with `final`.
RULES
- Use ONLY identifiers visible to you (overview / list_tables / get_columns output). Don't pluralize, translate, or guess.
- **After get_columns, your next run_query must use ONLY column names that appear verbatim in that output.** Do not assume conventions like `name`, `id`, `title`. If get_columns shows a table has `company_name` and `legal_name` but no `name`, then `name` does NOT exist on that table.
- LIMIT on ad-hoc SELECTs unless aggregating.
- When run_query fails, READ the error carefully — especially any `HINT:` line, which often spells out the fix. Common PostgreSQL fixes:
* `operator does not exist: X = Y` (e.g. `character varying = uuid`) → cast one side, e.g. `a.id::uuid = b.id` or `a.id = b.id::text`. If unsure of types, call get_columns on both tables.
* `column "X" does not exist` → call get_columns on the table you're querying; the column is named differently. The error message lists which alias the column was attached to (e.g. `column le.name` means it's missing on the table aliased as `le`).
* `relation "X" does not exist` → check the OVERVIEW table list; the table may be in a different schema or database.
- On SQL error retry at most ONCE with a corrected query. On the second consecutive failure, STOP and respond with `final` explaining what's missing — do not loop. The harness will force-stop after 2 consecutive errors regardless.
- You have a hop budget of 10 tool calls per user turn. Spend them deliberately: don't burn hops re-running the same query — investigate (get_columns) when in doubt.
- `remember` is for durable facts, not transient observations. Don't memorise query results — only insights about the schema/data model that aren't already in the OVERVIEW.
═══════════════════════════════════════════════════════════════
LEARNED NOTES (per-connection memory; user can edit in sidebar → Memory)
═══════════════════════════════════════════════════════════════
{memory}
═══════════════════════════════════════════════════════════════
═══════════════════════════════════════════════════════════════
OVERVIEW (refreshed every turn)
═══════════════════════════════════════════════════════════════
{overview}
═══════════════════════════════════════════════════════════════
"#,
hops = MAX_HOPS,
memory = memory_block,
overview = overview_block,
)
}
// ---------------------------------------------------------------------------
// Compressed history projection
// ---------------------------------------------------------------------------
/// Compact projection of a [`QueryResult`] for re-injection into LLM history.
///
/// Carries what the model needs to plan its next step: column names and
/// types, the number of rows, execution time, and only the first
/// `RUN_QUERY_SAMPLE_ROWS` rows (string cells shortened by `truncate_cell`).
/// `truncated` tells the model whether rows beyond the sample were dropped.
fn compact_query_result(result: &QueryResult) -> Value {
    let row_count = result.rows.len();
    let mut sample_rows: Vec<Vec<Value>> = Vec::new();
    for row in result.rows.iter().take(RUN_QUERY_SAMPLE_ROWS) {
        sample_rows.push(row.iter().map(truncate_cell).collect());
    }
    let truncated = row_count > RUN_QUERY_SAMPLE_ROWS;
    serde_json::json!({
        "columns": result.columns,
        "types": result.types,
        "row_count": row_count,
        "execution_time_ms": result.execution_time_ms,
        "sample_rows": sample_rows,
        "truncated": truncated,
    })
}
/// Shorten a string cell to at most `CELL_CHAR_CAP` characters for LLM history.
///
/// A shortened string gets a trailing `…`, so the model can tell the value
/// was cut instead of treating the prefix as the full value. Non-string
/// values and strings within the cap are returned unchanged (cloned).
fn truncate_cell(v: &Value) -> Value {
    match v {
        // `nth(CAP)` is Some exactly when the string has more than CAP chars,
        // and its byte offset is where the first CAP chars end — one pass,
        // and always a valid char boundary.
        Value::String(s) => match s.char_indices().nth(CELL_CHAR_CAP) {
            Some((cut, _)) => Value::String(format!("{}…", &s[..cut])),
            None => v.clone(),
        },
        other => other.clone(),
    }
}
/// Cap a text tool result at `TEXT_TOOL_CHAR_CAP` bytes for LLM history.
///
/// The cut point is moved back to the nearest UTF-8 character boundary, so
/// multi-byte text (e.g. Cyrillic names or comments) can no longer make the
/// slice panic. Truncated output ends with a `\n…(truncated)` marker; text
/// within the cap is returned unchanged.
fn truncate_text(text: &str) -> String {
    if text.len() <= TEXT_TOOL_CHAR_CAP {
        return text.to_string();
    }
    // Byte 0 is always a char boundary, so the search always succeeds; it
    // steps back at most 3 bytes (the longest UTF-8 continuation run).
    let cut = (0..=TEXT_TOOL_CHAR_CAP)
        .rev()
        .find(|&i| text.is_char_boundary(i))
        .unwrap_or(0);
    let mut out = String::with_capacity(cut + 16);
    out.push_str(&text[..cut]);
    out.push_str("\n…(truncated)");
    out
}
/// Project the stored chat thread into the message list sent to the LLM.
///
/// The first entry is always the system prompt built from `overview_text`
/// and `memory_text`. Each stored message then maps to at most one entry:
/// user text passes through as `user`; assistant replies are re-encoded as
/// `{"action":"final","text":…}`; tool calls become the flat
/// `{"action":<tool>, …input}` shape; tool results become `user` messages
/// prefixed with `TOOL_RESULT `. Legacy `get_schema` calls/results are skipped.
///
/// Only the newest `RECENT_TOOL_RESULTS_FULL` run_query results keep their
/// compacted rows; older successful ones are replaced by a "rows omitted"
/// note, while failed ones always keep their error text.
fn build_history(
    messages: &[ChatMessage],
    overview_text: &str,
    memory_text: &str,
) -> Vec<OllamaChatMessage> {
    // Positions (within `messages`) of every run_query result, oldest first.
    let run_query_positions: Vec<usize> = (0..messages.len())
        .filter(|&pos| {
            matches!(&messages[pos], ChatMessage::ToolResult { tool, .. } if tool == "run_query")
        })
        .collect();
    // Successful run_query results strictly before this position lose rows.
    let first_full_position = match run_query_positions
        .len()
        .checked_sub(RECENT_TOOL_RESULTS_FULL)
    {
        Some(skip) if skip > 0 => run_query_positions[skip],
        _ => 0,
    };

    let mut history = Vec::with_capacity(messages.len() + 1);
    history.push(OllamaChatMessage {
        role: "system".to_string(),
        content: system_prompt(overview_text, memory_text),
    });
    history.extend(
        messages
            .iter()
            .enumerate()
            .filter_map(|(pos, msg)| project_history_message(pos, msg, first_full_position)),
    );
    history
}

/// Convert one stored chat message at `pos` into its LLM-history form, or
/// `None` when it is left out (legacy `get_schema` calls and results).
fn project_history_message(
    pos: usize,
    msg: &ChatMessage,
    first_full_position: usize,
) -> Option<OllamaChatMessage> {
    let (role, content) = match msg {
        ChatMessage::ToolCall { tool, .. } | ChatMessage::ToolResult { tool, .. }
            if tool == "get_schema" =>
        {
            return None; // legacy
        }
        ChatMessage::User { text, .. } => ("user", text.clone()),
        ChatMessage::Assistant { text, .. } => (
            "assistant",
            serde_json::json!({ "action": "final", "text": text }).to_string(),
        ),
        ChatMessage::ToolCall { tool, input_json, .. } => {
            let mut envelope = serde_json::Map::new();
            envelope.insert("action".to_string(), Value::String(tool.clone()));
            // Unparseable or non-object input leaves just the action.
            if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) {
                envelope.extend(input);
            }
            ("assistant", Value::Object(envelope).to_string())
        }
        ChatMessage::ToolResult {
            tool,
            is_error,
            text,
            result,
            ..
        } => {
            let payload = if tool != "run_query" {
                // Text-only tools — pass through with cap.
                serde_json::json!({
                    "tool": tool,
                    "error": *is_error,
                    "text": text.as_deref().map(truncate_text),
                })
            } else if *is_error {
                serde_json::json!({
                    "tool": "run_query",
                    "error": true,
                    "text": text.clone().unwrap_or_default(),
                })
            } else if pos < first_full_position {
                serde_json::json!({
                    "tool": "run_query",
                    "error": false,
                    "note": "rows omitted (older result; user has it in the UI above)",
                })
            } else {
                match result {
                    Some(qr) => serde_json::json!({
                        "tool": "run_query",
                        "error": false,
                        "result": compact_query_result(qr),
                    }),
                    None => serde_json::json!({
                        "tool": "run_query",
                        "error": false,
                        "result": null,
                    }),
                }
            };
            ("user", format!("TOOL_RESULT {}", payload))
        }
    };
    Some(OllamaChatMessage {
        role: role.to_string(),
        content,
    })
}
// ---------------------------------------------------------------------------
// chat_send
// ---------------------------------------------------------------------------
/// Estimate how much of the provider's context budget the next LLM call uses.
///
/// Rebuilds the same history `chat_send` sends (system prompt with a freshly
/// loaded overview and memory, then the compressed thread). It sums, per
/// message, the role and content lengths (`String::len`, i.e. UTF-8 bytes)
/// plus a fixed 16-unit allowance for the JSON envelope. Overview or memory
/// load failures count as empty text. The result is a size estimate, not a
/// token count.
async fn compute_usage(
    state: &AppState,
    app: &AppHandle,
    connection_id: &str,
    working: &[ChatMessage],
) -> ContextUsage {
    // Role string ("system"/"user"/"assistant") ≤ 9 chars; this covers it
    // plus the quotes/keys of the serialized message object.
    const ENVELOPE_OVERHEAD: usize = 16;

    let overview = build_overview_context(state, connection_id)
        .await
        .unwrap_or_default();
    let memory = read_memory_core(app, connection_id).unwrap_or_default();

    let mut used_chars: u64 = 0;
    for msg in build_history(working, &overview, &memory) {
        used_chars += (msg.role.len() + msg.content.len() + ENVELOPE_OVERHEAD) as u64;
    }

    let budget_chars = provider_budget_chars(state, app).await;
    ContextUsage {
        used_chars,
        budget_chars,
    }
}
/// Soft context budget for the configured LLM provider.
///
/// `CONTEXT_BUDGET_CHARS_FIREWORKS` when Fireworks is selected; otherwise —
/// including when the AI settings cannot be loaded —
/// `CONTEXT_BUDGET_CHARS_OLLAMA`.
async fn provider_budget_chars(state: &AppState, app: &AppHandle) -> u64 {
    use crate::models::ai::AiProvider;
    let on_fireworks = load_ai_settings(app, state)
        .await
        .map(|settings| matches!(settings.provider, AiProvider::Fireworks))
        .unwrap_or(false);
    if on_fireworks {
        CONTEXT_BUDGET_CHARS_FIREWORKS
    } else {
        CONTEXT_BUDGET_CHARS_OLLAMA
    }
}
#[tauri::command]
pub async fn chat_send(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
messages: Vec<ChatMessage>,
) -> TuskResult<ChatTurnResult> {
let mut new_messages: Vec<ChatMessage> = Vec::new();
let mut working: Vec<ChatMessage> = messages;
let mut consecutive_query_errors: usize = 0;
for _hop in 0..MAX_HOPS {
// Hard guard: stop once the model has hit the same SQL hurdle multiple
// times in a row. Beyond this point further hops virtually always
// re-run the same broken query against the model's better judgment.
if consecutive_query_errors >= MAX_CONSECUTIVE_QUERY_ERRORS {
let last_err = last_run_query_error(&working).unwrap_or_default();
let msg = ChatMessage::Assistant {
id: new_id("asst"),
text: format!(
"I tried {} times and kept hitting the same SQL error:\n\n{}\n\nThis usually means a schema mismatch — wrong column type, missing cast (often `::uuid` or `::text`), or a join key that doesn't exist as I assumed. Could you double-check the question, or open the table in the sidebar to verify column types? You can also write the SQL manually in Advanced mode.",
consecutive_query_errors, last_err
),
created_at: now_ms(),
};
new_messages.push(msg.clone());
working.push(msg);
let usage = compute_usage(&state, &app, &connection_id, &working).await;
return Ok(ChatTurnResult {
messages: new_messages,
usage,
});
}
// Overview is rebuilt per turn — cheap (cached) and reflects the active DB
// even if the user (or the agent) just switched it.
let overview_text = build_overview_context(&state, &connection_id)
.await
.unwrap_or_default();
// Memory is read fresh each turn so user-side edits in the Memory tab
// are visible to the agent immediately.
let memory_text = read_memory_core(&app, &connection_id).unwrap_or_default();
let history = build_history(&working, &overview_text, &memory_text);
let raw =
call_chat_messages(&app, &state, history, Some("json".to_string())).await?;
let trimmed = raw.trim();
let action = match parse_agent_action(trimmed) {
Ok(a) => a,
Err(parse_err) => {
let msg = ChatMessage::Assistant {
id: new_id("asst"),
text: format!(
"{}\n\n_(Note: model returned non-protocol output: {})_",
trimmed, parse_err
),
created_at: now_ms(),
};
new_messages.push(msg.clone());
working.push(msg);
let usage = compute_usage(&state, &app, &connection_id, &working).await;
return Ok(ChatTurnResult {
messages: new_messages,
usage,
});
}
};
let is_run_query = matches!(&action, AgentAction::RunQuery { .. });
match action {
AgentAction::Final { text } => {
let msg = ChatMessage::Assistant {
id: new_id("asst"),
text,
created_at: now_ms(),
};
new_messages.push(msg.clone());
working.push(msg);
let usage = compute_usage(&state, &app, &connection_id, &working).await;
return Ok(ChatTurnResult {
messages: new_messages,
usage,
});
}
AgentAction::RunQuery { sql } => {
push_tool_call(
&mut new_messages,
&mut working,
"run_query",
serde_json::json!({ "sql": sql }).to_string(),
);
let result = match execute_query_core(&state, &connection_id, &sql).await {
Ok(qr) => {
consecutive_query_errors = 0;
ChatMessage::ToolResult {
id: new_id("res"),
tool: "run_query".to_string(),
is_error: false,
text: None,
result: Some(qr),
created_at: now_ms(),
}
}
Err(e) => {
consecutive_query_errors += 1;
let suffix = match &e {
TuskError::ReadOnly => {
"\n\nRead-only mode is on. Toggle it off in the toolbar to allow writes."
}
_ => "",
};
ChatMessage::ToolResult {
id: new_id("res"),
tool: "run_query".to_string(),
is_error: true,
text: Some(format!("{}{}", format_db_error(&e), suffix)),
result: None,
created_at: now_ms(),
}
}
};
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::ListDatabases => {
push_tool_call(
&mut new_messages,
&mut working,
"list_databases",
"{}".to_string(),
);
let result = run_text_tool(
list_databases_tool(&state, &connection_id).await,
"list_databases",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::ListTables { database } => {
let input_json = match &database {
Some(db) => serde_json::json!({ "database": db }).to_string(),
None => "{}".to_string(),
};
push_tool_call(&mut new_messages, &mut working, "list_tables", input_json);
let result = run_text_tool(
list_tables_tool(&app, &state, &connection_id, database.as_deref()).await,
"list_tables",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::GetColumns { tables } => {
push_tool_call(
&mut new_messages,
&mut working,
"get_columns",
serde_json::json!({ "tables": tables }).to_string(),
);
let result = run_text_tool(
get_columns_tool(&state, &connection_id, &tables).await,
"get_columns",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::SwitchDatabase { database } => {
push_tool_call(
&mut new_messages,
&mut working,
"switch_database",
serde_json::json!({ "database": &database }).to_string(),
);
let result = run_text_tool(
switch_database_tool(&app, &state, &connection_id, &database).await,
"switch_database",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::Remember { note } => {
push_tool_call(
&mut new_messages,
&mut working,
"remember",
serde_json::json!({ "note": &note }).to_string(),
);
let outcome = append_memory_core(&app, &connection_id, &note)
.map(|_| format!("Saved note ({} chars).", note.len()));
let result = run_text_tool(outcome, "remember");
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::SaveQuery { name, sql } => {
push_tool_call(
&mut new_messages,
&mut working,
"save_query",
serde_json::json!({ "name": &name, "sql": &sql }).to_string(),
);
let result = run_text_tool(
save_query_tool(&app, &connection_id, &name, &sql).await,
"save_query",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::FindQueries { text } => {
push_tool_call(
&mut new_messages,
&mut working,
"find_queries",
serde_json::json!({ "text": &text }).to_string(),
);
let result = run_text_tool(
find_queries_tool(&app, &connection_id, &text).await,
"find_queries",
);
push_tool_result(&mut new_messages, &mut working, result);
}
AgentAction::MakeChart { config } => {
let config_json = serde_json::to_string(&config).unwrap_or_else(|_| "{}".into());
push_tool_call(
&mut new_messages,
&mut working,
"make_chart",
config_json.clone(),
);
let result_msg = match last_successful_query_result(&working) {
None => ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(
"make_chart needs a successful run_query result above it. Run a SELECT first, then call make_chart."
.to_string(),
),
result: None,
created_at: now_ms(),
},
Some(qr) => {
if !qr.columns.iter().any(|c| c == &config.x) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"x column `{}` is not in the last result. Available: {}.",
config.x,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if !qr.columns.iter().any(|c| c == &config.y) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"y column `{}` is not in the last result. Available: {}.",
config.y,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else if let Some(group) = &config.group {
if !qr.columns.iter().any(|c| c == group) {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: true,
text: Some(format!(
"group column `{}` is not in the last result. Available: {}.",
group,
qr.columns.join(", ")
)),
result: None,
created_at: now_ms(),
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: false,
text: Some(config_json.clone()),
result: Some(qr),
created_at: now_ms(),
}
}
} else {
ChatMessage::ToolResult {
id: new_id("res"),
tool: "make_chart".to_string(),
is_error: false,
text: Some(config_json.clone()),
result: Some(qr),
created_at: now_ms(),
}
}
}
};
push_tool_result(&mut new_messages, &mut working, result_msg);
}
}
// Any non-RunQuery, non-Final action means the model is investigating
// (e.g. get_columns to verify a type). Give it a fresh error budget.
if !is_run_query {
consecutive_query_errors = 0;
}
}
// Last-chance synthesis: model is out of tool calls but may have collected
// enough data on the last hop to answer. One extra LLM call, no JSON
// protocol, just plain text.
let synthesis = force_final_synthesis(&app, &state, &working).await;
let text = match synthesis {
Some(t) => format!("{}\n\n_(Tool-call limit reached; answer synthesised from collected results.)_", t),
None => format!(
"Stopped after {} tool calls without a final answer. Try rephrasing the question, splitting it into smaller parts, or running the SQL manually in Advanced mode.",
MAX_HOPS
),
};
let msg = ChatMessage::Assistant {
id: new_id("asst"),
text,
created_at: now_ms(),
};
new_messages.push(msg.clone());
working.push(msg);
let usage = compute_usage(&state, &app, &connection_id, &working).await;
Ok(ChatTurnResult {
messages: new_messages,
usage,
})
}
/// Record a tool invocation (with a fresh `call-…` id and timestamp) in both
/// the turn's output and the working thread fed back to the model.
fn push_tool_call(
    new_messages: &mut Vec<ChatMessage>,
    working: &mut Vec<ChatMessage>,
    tool: &str,
    input_json: String,
) {
    let id = new_id("call");
    let created_at = now_ms();
    let call = ChatMessage::ToolCall {
        id,
        tool: tool.to_owned(),
        input_json,
        created_at,
    };
    working.push(call.clone());
    new_messages.push(call);
}
/// Record a tool result in both the turn's output and the working thread.
fn push_tool_result(
    new_messages: &mut Vec<ChatMessage>,
    working: &mut Vec<ChatMessage>,
    result: ChatMessage,
) {
    working.push(result.clone());
    new_messages.push(result);
}
/// Wrap the outcome of a text-producing tool as a tool-result message.
///
/// `Ok(text)` becomes a successful result carrying `text`; `Err(e)` becomes
/// an error result carrying the error's `Display` text. Neither carries a
/// query result.
fn run_text_tool(outcome: TuskResult<String>, tool: &str) -> ChatMessage {
    let (is_error, text) = match outcome {
        Ok(body) => (false, body),
        Err(e) => (true, e.to_string()),
    };
    ChatMessage::ToolResult {
        id: new_id("res"),
        tool: tool.to_string(),
        is_error,
        text: Some(text),
        result: None,
        created_at: now_ms(),
    }
}
// ---------------------------------------------------------------------------
// chat_compact
// ---------------------------------------------------------------------------
/// Render a thread as a plain-text transcript for LLM prompts that do not
/// use the JSON protocol (e.g. the post-hop-limit synthesis).
///
/// Emits one blank-line-separated paragraph per message, labelled `USER:`,
/// `ASSISTANT:`, `TOOL_CALL [tool]`, `TOOL_ERROR [tool]` or
/// `TOOL_RESULT [tool]`. Query results are reduced to their `row_count` and
/// column names; no row data is included. Text tool results are clipped to
/// their first 800 characters. A successful result with neither data nor
/// text is omitted.
fn render_thread_for_summary(messages: &[ChatMessage]) -> String {
    use std::fmt::Write as _;
    // Character budget for a text tool result in the transcript.
    const TEXT_SNIPPET_CHARS: usize = 800;

    let mut out = String::new();
    for m in messages {
        // Writing into a String cannot fail, so the fmt::Result is discarded.
        let _ = match m {
            ChatMessage::User { text, .. } => writeln!(out, "USER: {}\n", text),
            ChatMessage::Assistant { text, .. } => writeln!(out, "ASSISTANT: {}\n", text),
            ChatMessage::ToolCall {
                tool, input_json, ..
            } => writeln!(out, "TOOL_CALL [{}]: {}\n", tool, input_json),
            ChatMessage::ToolResult {
                tool,
                is_error,
                text,
                result,
                ..
            } => {
                if *is_error {
                    writeln!(
                        out,
                        "TOOL_ERROR [{}]: {}\n",
                        tool,
                        text.as_deref().unwrap_or("")
                    )
                } else if let Some(qr) = result {
                    writeln!(
                        out,
                        "TOOL_RESULT [{}]: {} rows; columns={}\n",
                        tool,
                        qr.row_count,
                        qr.columns.join(", ")
                    )
                } else if let Some(t) = text {
                    let snippet: String = t.chars().take(TEXT_SNIPPET_CHARS).collect();
                    writeln!(out, "TOOL_RESULT [{}]: {}\n", tool, snippet)
                } else {
                    Ok(())
                }
            }
        };
    }
    out
}
/// Format a database error with PostgreSQL HINT/DETAIL extracted when present.
/// Plain `e.to_string()` shows just the bare error message, which makes
/// errors like `operator does not exist: character varying = uuid` impossible
/// to act on. PG's own HINT (`You might need to add explicit type casts`) is
/// often the difference between the agent finding the fix on retry and
/// looping forever.
fn format_db_error(e: &TuskError) -> String {
if let TuskError::Database(sqlx::Error::Database(db_err)) = e {
let mut out = format!("Database error: {}", db_err.message());
if let Some(pg) = db_err.try_downcast_ref::<sqlx::postgres::PgDatabaseError>() {
if let Some(detail) = pg.detail() {
out.push_str(&format!("\nDETAIL: {}", detail));
}
if let Some(hint) = pg.hint() {
out.push_str(&format!("\nHINT: {}", hint));
}
}
return out;
}
e.to_string()
}
/// Scan the working thread from newest to oldest and return a clone of the
/// first `run_query` tool result that succeeded and carries a `QueryResult`.
///
/// `make_chart` relies on this to attach data to a chart directive, so the
/// model never has to re-send the rows. Returns `None` when no successful
/// `run_query` result exists.
fn last_successful_query_result(messages: &[ChatMessage]) -> Option<QueryResult> {
    messages.iter().rev().find_map(|msg| match msg {
        ChatMessage::ToolResult {
            tool,
            is_error: false,
            result: Some(qr),
            ..
        } if tool == "run_query" => Some(qr.clone()),
        _ => None,
    })
}
/// Text of the newest failed `run_query` in the working thread. The post-loop
/// "I gave up" summary uses it to quote a concrete error back to the user.
///
/// Only the single most recent failed `run_query` is consulted. If that entry
/// has no text, `None` is returned rather than falling back to an older
/// error.
fn last_run_query_error(messages: &[ChatMessage]) -> Option<String> {
    messages
        .iter()
        .rev()
        .find_map(|msg| match msg {
            ChatMessage::ToolResult {
                tool,
                is_error: true,
                text,
                ..
            } if tool == "run_query" => Some(text.clone()),
            _ => None,
        })
        .flatten()
}
/// Last-chance LLM call after MAX_HOPS is exhausted: nudge the model to
/// produce a `final` answer based on whatever data and tool results the
/// thread already contains. Without this, a thread that succeeded on its
/// last run_query would end with "Stopped after N tool calls" and waste
/// the result. Returns None if the LLM call fails.
async fn force_final_synthesis(
app: &AppHandle,
state: &AppState,
working: &[ChatMessage],
) -> Option<String> {
if working.is_empty() {
return None;
}
let convo = render_thread_for_summary(working);
let system = "The agent loop has reached the tool-call limit. The user is waiting for an answer right now. \
Based ONLY on the conversation below, write a SHORT plain-text answer for the user. \
Reply in the SAME language the user used. \
If a query produced results, summarise what those results show. \
If queries kept failing, explain what went wrong and what the user could do (provide the missing piece, switch to Advanced mode, etc.). \
Be concrete with numbers and identifiers from the results.\n\
\n\
OUTPUT FORMAT: PLAIN TEXT ONLY. \
DO NOT output JSON, markdown fences, or field names. \
DO NOT call any tools. DO NOT use the action protocol. \
Just the answer text.";
let llm_messages = vec![
OllamaChatMessage {
role: "system".to_string(),
content: system.to_string(),
},
OllamaChatMessage {
role: "user".to_string(),
content: convo,
},
];
match call_chat_messages(app, state, llm_messages, None).await {
Ok(s) => {
let cleaned = clean_summary(&s);
if cleaned.trim().is_empty() {
None
} else {
Some(cleaned)
}
}
Err(_) => None,
}
}
/// Strip the wrappers an agent-trained model tends to emit even for
/// non-agent prompts, and return the underlying summary text.
///
/// Handled, in order:
/// 1. A leading triple-backtick fence. The first line after the opening
///    backticks is dropped only when it is empty or looks like a language
///    tag (`text`, `json`, …). A trailing run of backticks is removed only
///    when it is a real closing fence (three or more), so inline code at the
///    end of an unterminated block keeps its closing backtick.
/// 2. A JSON object envelope. The first string field among `text`,
///    `summary`, `content`, `answer`, `output` is returned, trimmed.
///
/// Anything else, including JSON without a recognised field, is returned
/// trimmed but otherwise unchanged.
fn clean_summary(raw: &str) -> String {
    let trimmed = raw.trim();

    // A "language tag" line: empty, or a single token like `json` / `c++`.
    // Anything containing whitespace (e.g. "- bullet one") is real content.
    let is_lang_tag = |line: &str| {
        line.trim()
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '_' | '-' | '+' | '.' | '#'))
    };

    let unfenced: &str = match trimmed.strip_prefix("```") {
        Some(after_open) => {
            // Tolerate longer opening fences (````) by dropping the whole run.
            let body = after_open.trim_start_matches('`');
            let body = match body.split_once('\n') {
                Some((first, rest)) if is_lang_tag(first) => rest,
                _ => body,
            };
            let body = body.trim_end();
            let body = if body.ends_with("```") {
                body.trim_end_matches('`')
            } else {
                body
            };
            body.trim()
        }
        None => trimmed,
    };

    // If the model returned a JSON envelope, extract a known string field.
    if unfenced.starts_with('{') {
        if let Ok(v) = serde_json::from_str::<serde_json::Value>(unfenced) {
            let field = ["text", "summary", "content", "answer", "output"]
                .iter()
                .find_map(|key| v.get(*key).and_then(|x| x.as_str()));
            if let Some(s) = field {
                return s.trim().to_string();
            }
        }
    }
    unfenced.to_string()
}
/// Position of the newest `User` message in `messages`.
///
/// When the thread contains no user message at all, `messages.len()` is
/// returned. In that case `messages[..idx]` is the whole thread and
/// `messages[idx..]` is empty.
fn last_user_turn_index(messages: &[ChatMessage]) -> usize {
    messages
        .iter()
        .rposition(|msg| matches!(msg, ChatMessage::User { .. }))
        .unwrap_or(messages.len())
}
/// LLM-summarise the older portion of a chat thread.
///
/// Everything before the most recent `User` message is rendered with
/// [`render_thread_for_summary`] and replaced by a single
/// `Assistant("📋 Compacted N earlier message(s): …")` message. The last user
/// turn and everything after it are kept verbatim. If the thread has no user
/// message, the whole thread is summarised.
///
/// The thread is returned unchanged (with freshly computed usage) when there
/// is nothing to compact. That is the case for an empty thread, or one whose
/// most recent user message is also its first message.
///
/// # Errors
/// Returns `TuskError::Ai` when the summarisation call fails, or when the
/// model's reply is empty after cleaning. In the second case, replacing the
/// older messages with an empty note would silently discard them.
#[tauri::command]
pub async fn chat_compact(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
    messages: Vec<ChatMessage>,
) -> TuskResult<ChatTurnResult> {
    // Preserve the user's most recent question (and anything after it)
    // untouched so the model can continue from it after compaction.
    // `last_user_turn_index` returns `len()` when there is no user message,
    // which makes `older` the whole thread and `recent` empty.
    let split_at = last_user_turn_index(&messages);
    let (older, recent) = messages.split_at(split_at);
    if older.is_empty() {
        // Empty thread, or the only user turn is already at the front.
        let usage = compute_usage(&state, &app, &connection_id, &messages).await;
        return Ok(ChatTurnResult { messages, usage });
    }
    let convo = render_thread_for_summary(older);
    let system = "You are a precise summarizer of a database analysis dialogue. \
        Produce a SHORT summary in the SAME language the user spoke. \
        Use 3-6 bullet points covering: the user's goal, key tables/columns/queries used, \
        numerical findings, conclusions reached, any open questions. \
        Be concrete with numbers and identifiers. Total length < 800 chars.\n\
        \n\
        OUTPUT FORMAT: PLAIN TEXT. Start each bullet with `- `. \
        DO NOT output JSON. DO NOT wrap output in `{` or `}`. \
        DO NOT use markdown fences. DO NOT include field names like `action`, `text`, `summary`. \
        DO NOT add a preamble. The first character of your reply must be `-`.";
    let llm_messages = vec![
        OllamaChatMessage {
            role: "system".to_string(),
            content: system.to_string(),
        },
        OllamaChatMessage {
            role: "user".to_string(),
            content: convo,
        },
    ];
    let summary = call_chat_messages(&app, &state, llm_messages, None)
        .await
        .map_err(|e| TuskError::Ai(format!("Compact failed: {}", e)))?;
    let cleaned = clean_summary(&summary);
    // An empty summary would replace `older` with nothing but a header.
    // Fail instead of returning a thread whose history was silently dropped.
    if cleaned.trim().is_empty() {
        return Err(TuskError::Ai(
            "Compact failed: the model returned an empty summary".to_string(),
        ));
    }
    let compacted_msg = ChatMessage::Assistant {
        id: new_id("asst"),
        text: format!(
            "📋 Compacted {} earlier message{}:\n\n{}",
            older.len(),
            if older.len() == 1 { "" } else { "s" },
            cleaned
        ),
        created_at: now_ms(),
    };
    let mut out: Vec<ChatMessage> = Vec::with_capacity(1 + recent.len());
    out.push(compacted_msg);
    out.extend(recent.iter().cloned());
    let usage = compute_usage(&state, &app, &connection_id, &out).await;
    Ok(ChatTurnResult {
        messages: out,
        usage,
    })
}
// ---------------------------------------------------------------------------
// tests
// ---------------------------------------------------------------------------
// Unit tests for this module's synchronous helpers: agent-action parsing,
// thread slicing, summary cleanup, error formatting, and query-result
// compaction. All inputs are built in memory inside each test.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_agent_action: fields may be flat or nested under "input" ----

    #[test]
    fn parses_flat_run_query() {
        let a = parse_agent_action(r#"{"action":"run_query","sql":"SELECT 1"}"#).unwrap();
        match a {
            AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 1"),
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn parses_nested_run_query() {
        let a =
            parse_agent_action(r#"{"action":"run_query","input":{"sql":"SELECT 2"}}"#).unwrap();
        match a {
            AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 2"),
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn parses_get_columns() {
        let a = parse_agent_action(
            r#"{"action":"get_columns","tables":["public.users","public.orders"]}"#,
        )
        .unwrap();
        match a {
            AgentAction::GetColumns { tables } => {
                assert_eq!(tables, vec!["public.users", "public.orders"]);
            }
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn parses_get_columns_nested() {
        let a = parse_agent_action(
            r#"{"action":"get_columns","input":{"tables":["public.t"]}}"#,
        )
        .unwrap();
        match a {
            AgentAction::GetColumns { tables } => assert_eq!(tables, vec!["public.t"]),
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn rejects_get_columns_empty_tables() {
        assert!(parse_agent_action(r#"{"action":"get_columns","tables":[]}"#).is_err());
    }
    #[test]
    fn parses_switch_database() {
        let a = parse_agent_action(r#"{"action":"switch_database","database":"orders_db"}"#)
            .unwrap();
        match a {
            AgentAction::SwitchDatabase { database } => assert_eq!(database, "orders_db"),
            _ => panic!("wrong variant"),
        }
    }
    // `database` is optional for list_tables: absent -> None.
    #[test]
    fn parses_list_tables_optional_db() {
        let a1 = parse_agent_action(r#"{"action":"list_tables"}"#).unwrap();
        match a1 {
            AgentAction::ListTables { database } => assert!(database.is_none()),
            _ => panic!("wrong variant"),
        }
        let a2 = parse_agent_action(r#"{"action":"list_tables","database":"x"}"#).unwrap();
        match a2 {
            AgentAction::ListTables { database } => assert_eq!(database.as_deref(), Some("x")),
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn rejects_unknown_action() {
        assert!(parse_agent_action(r#"{"action":"nuke","yes":true}"#).is_err());
    }

    // ---- parse_agent_action: remember / save_query / find_queries ----

    #[test]
    fn parses_remember_flat() {
        let a = parse_agent_action(
            r#"{"action":"remember","note":"trips.started_at is NULL for cancelled"}"#,
        )
        .unwrap();
        match a {
            AgentAction::Remember { note } => {
                assert_eq!(note, "trips.started_at is NULL for cancelled");
            }
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn parses_remember_nested() {
        let a = parse_agent_action(
            r#"{"action":"remember","input":{"note":" surrounded by spaces "}}"#,
        )
        .unwrap();
        match a {
            AgentAction::Remember { note } => {
                // trim happens in parser
                assert_eq!(note, "surrounded by spaces");
            }
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn rejects_remember_without_note() {
        assert!(parse_agent_action(r#"{"action":"remember"}"#).is_err());
    }
    // A whitespace-only note is treated as missing.
    #[test]
    fn rejects_remember_empty_note() {
        assert!(parse_agent_action(r#"{"action":"remember","note":" "}"#).is_err());
    }
    #[test]
    fn parses_save_query_flat() {
        let a = parse_agent_action(
            r#"{"action":"save_query","name":"GMV last 30d","sql":"SELECT 1"}"#,
        )
        .unwrap();
        match a {
            AgentAction::SaveQuery { name, sql } => {
                assert_eq!(name, "GMV last 30d");
                assert_eq!(sql, "SELECT 1");
            }
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn parses_save_query_nested() {
        let a = parse_agent_action(
            r#"{"action":"save_query","input":{"name":"x","sql":"SELECT 2"}}"#,
        )
        .unwrap();
        match a {
            AgentAction::SaveQuery { name, sql } => {
                assert_eq!(name, "x");
                assert_eq!(sql, "SELECT 2");
            }
            _ => panic!("wrong variant"),
        }
    }
    // Both name and sql are required; a whitespace-only name is rejected too.
    #[test]
    fn rejects_save_query_missing_fields() {
        assert!(parse_agent_action(r#"{"action":"save_query","name":"x"}"#).is_err());
        assert!(parse_agent_action(r#"{"action":"save_query","sql":"SELECT 1"}"#).is_err());
        assert!(
            parse_agent_action(r#"{"action":"save_query","name":" ","sql":"SELECT 1"}"#).is_err()
        );
    }
    #[test]
    fn parses_find_queries() {
        let a = parse_agent_action(r#"{"action":"find_queries","text":"gmv"}"#).unwrap();
        match a {
            AgentAction::FindQueries { text } => assert_eq!(text, "gmv"),
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn rejects_find_queries_empty_text() {
        assert!(parse_agent_action(r#"{"action":"find_queries","text":""}"#).is_err());
    }

    // ---- last_user_turn_index (split point used by chat_compact) ----

    #[test]
    fn last_user_turn_index_finds_last_user() {
        let msgs = vec![
            ChatMessage::User { id: "u1".into(), text: "first".into(), created_at: 1 },
            ChatMessage::Assistant { id: "a1".into(), text: "ans".into(), created_at: 2 },
            ChatMessage::User { id: "u2".into(), text: "second".into(), created_at: 3 },
            ChatMessage::Assistant { id: "a2".into(), text: "ans2".into(), created_at: 4 },
        ];
        assert_eq!(last_user_turn_index(&msgs), 2);
    }
    #[test]
    fn last_user_turn_index_returns_len_when_no_user() {
        let msgs = vec![ChatMessage::Assistant {
            id: "a1".into(),
            text: "alone".into(),
            created_at: 1,
        }];
        assert_eq!(last_user_turn_index(&msgs), msgs.len());
    }

    // ---- clean_summary: fences and JSON envelopes ----

    #[test]
    fn clean_summary_passes_plain_text() {
        let s = "- bullet one\n- bullet two";
        assert_eq!(clean_summary(s), s);
    }
    #[test]
    fn clean_summary_strips_markdown_fences() {
        let s = "```\n- bullet\n```";
        assert_eq!(clean_summary(s), "- bullet");
    }
    #[test]
    fn clean_summary_strips_lang_fence() {
        let s = "```text\n- bullet one\n- bullet two\n```";
        assert_eq!(clean_summary(s), "- bullet one\n- bullet two");
    }
    #[test]
    fn clean_summary_extracts_text_field_from_json_envelope() {
        let s = r#"{"action":"final","text":"- bullet one\n- bullet two"}"#;
        assert_eq!(clean_summary(s), "- bullet one\n- bullet two");
    }
    #[test]
    fn clean_summary_extracts_summary_field() {
        let s = r#"{"summary":"- a\n- b"}"#;
        assert_eq!(clean_summary(s), "- a\n- b");
    }
    // JSON without a recognised string field is passed through verbatim.
    #[test]
    fn clean_summary_returns_unchanged_for_unrecognised_json() {
        let s = r#"{"weird":42}"#;
        assert_eq!(clean_summary(s), s);
    }

    // ---- format_db_error / last_run_query_error ----

    #[test]
    fn format_db_error_falls_back_to_to_string_for_non_db_errors() {
        let e = TuskError::Custom("oops".into());
        assert_eq!(format_db_error(&e), e.to_string());
    }
    // Errors from other tools (get_columns here) must be skipped.
    #[test]
    fn last_run_query_error_finds_most_recent() {
        let msgs = vec![
            ChatMessage::ToolResult {
                id: "r1".into(),
                tool: "run_query".into(),
                is_error: true,
                text: Some("first error".into()),
                result: None,
                created_at: 1,
            },
            ChatMessage::ToolResult {
                id: "r2".into(),
                tool: "get_columns".into(),
                is_error: true,
                text: Some("not run_query".into()),
                result: None,
                created_at: 2,
            },
            ChatMessage::ToolResult {
                id: "r3".into(),
                tool: "run_query".into(),
                is_error: true,
                text: Some("second error".into()),
                result: None,
                created_at: 3,
            },
        ];
        assert_eq!(
            last_run_query_error(&msgs).as_deref(),
            Some("second error")
        );
    }
    #[test]
    fn last_run_query_error_none_for_empty_thread() {
        assert!(last_run_query_error(&[]).is_none());
    }
    #[test]
    fn last_run_query_error_none_when_no_errors() {
        let msgs = vec![ChatMessage::User {
            id: "u1".into(),
            text: "hi".into(),
            created_at: 1,
        }];
        assert!(last_run_query_error(&msgs).is_none());
    }

    // ---- parse_agent_action: make_chart ----

    #[test]
    fn parses_make_chart_minimal() {
        let a = parse_agent_action(
            r#"{"action":"make_chart","chart_type":"bar","x":"carrier","y":"trips"}"#,
        )
        .unwrap();
        match a {
            AgentAction::MakeChart { config } => {
                assert_eq!(config.chart_type, "bar");
                assert_eq!(config.x, "carrier");
                assert_eq!(config.y, "trips");
                assert!(config.group.is_none());
                assert!(config.title.is_none());
            }
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn parses_make_chart_with_group_and_title() {
        let a = parse_agent_action(
            r#"{"action":"make_chart","chart_type":"line","x":"month","y":"revenue","group":"region","title":"Revenue"}"#,
        )
        .unwrap();
        match a {
            AgentAction::MakeChart { config } => {
                assert_eq!(config.group.as_deref(), Some("region"));
                assert_eq!(config.title.as_deref(), Some("Revenue"));
            }
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn make_chart_accepts_alternative_field_name_type() {
        // Some models emit `type` instead of `chart_type`.
        let a = parse_agent_action(
            r#"{"action":"make_chart","type":"pie","x":"label","y":"value"}"#,
        )
        .unwrap();
        match a {
            AgentAction::MakeChart { config } => assert_eq!(config.chart_type, "pie"),
            _ => panic!("wrong variant"),
        }
    }
    #[test]
    fn rejects_make_chart_with_unknown_chart_type() {
        let r = parse_agent_action(
            r#"{"action":"make_chart","chart_type":"radar","x":"a","y":"b"}"#,
        );
        assert!(r.is_err());
    }
    #[test]
    fn rejects_make_chart_missing_x_or_y() {
        assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","y":"a"}"#).is_err());
        assert!(parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","x":"a"}"#).is_err());
    }

    // ---- last_successful_query_result (data source for make_chart) ----

    // A later failed run_query must not hide an earlier successful one.
    #[test]
    fn last_successful_query_result_finds_recent() {
        use crate::models::query_result::QueryResult;
        let qr = QueryResult {
            columns: vec!["a".into()],
            types: vec!["INT4".into()],
            rows: vec![vec![Value::Number(1.into())]],
            row_count: 1,
            execution_time_ms: 1,
        };
        let msgs = vec![
            ChatMessage::ToolResult {
                id: "r1".into(),
                tool: "run_query".into(),
                is_error: false,
                text: None,
                result: Some(qr.clone()),
                created_at: 1,
            },
            ChatMessage::ToolResult {
                id: "r2".into(),
                tool: "run_query".into(),
                is_error: true,
                text: Some("oops".into()),
                result: None,
                created_at: 2,
            },
        ];
        let found = last_successful_query_result(&msgs).expect("ok");
        assert_eq!(found.columns, vec!["a".to_string()]);
    }
    // Only results from the run_query tool count, even if another tool
    // carries a QueryResult.
    #[test]
    fn last_successful_query_result_skips_non_run_query() {
        use crate::models::query_result::QueryResult;
        let qr = QueryResult {
            columns: vec!["a".into()],
            types: vec!["INT4".into()],
            rows: vec![],
            row_count: 0,
            execution_time_ms: 0,
        };
        let msgs = vec![ChatMessage::ToolResult {
            id: "r1".into(),
            tool: "list_tables".into(),
            is_error: false,
            text: Some("public.x".into()),
            result: Some(qr),
            created_at: 1,
        }];
        assert!(last_successful_query_result(&msgs).is_none());
    }

    // ---- render_thread_for_summary ----

    #[test]
    fn render_thread_for_summary_includes_roles_and_skips_rows() {
        let msgs = vec![
            ChatMessage::User { id: "u1".into(), text: "find users".into(), created_at: 1 },
            ChatMessage::ToolCall { id: "c1".into(), tool: "run_query".into(), input_json: r#"{"sql":"SELECT 1"}"#.into(), created_at: 2 },
            ChatMessage::ToolResult {
                id: "r1".into(),
                tool: "run_query".into(),
                is_error: false,
                text: None,
                result: Some(QueryResult {
                    columns: vec!["id".into(), "name".into()],
                    types: vec!["INT4".into(), "TEXT".into()],
                    rows: vec![vec![Value::Number(1.into()), Value::String("alice".into())]; 1000],
                    row_count: 1000,
                    execution_time_ms: 12,
                }),
                created_at: 3,
            },
        ];
        let rendered = render_thread_for_summary(&msgs);
        assert!(rendered.contains("USER: find users"));
        assert!(rendered.contains("TOOL_CALL [run_query]"));
        assert!(rendered.contains("1000 rows"));
        // Must NOT include the actual rows
        assert!(!rendered.contains("alice"));
    }
    #[test]
    fn rejects_legacy_get_schema() {
        assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err());
    }

    // ---- cell truncation and query-result compaction for LLM history ----

    // A truncated cell keeps at most CELL_CHAR_CAP chars plus the '…' marker.
    #[test]
    fn truncates_long_cell() {
        let long = "a".repeat(CELL_CHAR_CAP + 50);
        let v = truncate_cell(&Value::String(long));
        let s = v.as_str().unwrap();
        assert!(s.ends_with('…'));
        assert!(s.chars().count() <= CELL_CHAR_CAP + 1);
    }
    // 50 rows -> only RUN_QUERY_SAMPLE_ROWS sampled, flagged as truncated,
    // while row_count still reports the full total.
    #[test]
    fn compact_drops_rows_beyond_sample() {
        let mut rows = Vec::new();
        for i in 0..50 {
            rows.push(vec![Value::Number(i.into())]);
        }
        let qr = QueryResult {
            columns: vec!["id".into()],
            types: vec!["INT4".into()],
            rows,
            row_count: 50,
            execution_time_ms: 1,
        };
        let v = compact_query_result(&qr);
        let sample = v.get("sample_rows").unwrap().as_array().unwrap();
        assert_eq!(sample.len(), RUN_QUERY_SAMPLE_ROWS);
        assert_eq!(v.get("truncated").unwrap(), &Value::Bool(true));
        assert_eq!(v.get("row_count").unwrap().as_u64().unwrap(), 50);
    }
}