The chat usage badge was hardcoded to ~8K-token Ollama defaults (`CONTEXT_BUDGET_CHARS = 24_000`), which made every Fireworks session look 150%+ full after a few hops, even though models like Kimi-K2 have 256K-token context windows. The budget is now selected per provider:

- Ollama → 24K chars (~8K tok), unchanged.
- Fireworks → 384K chars (~128K tok), a compromise value: it stays well below the largest Fireworks windows, so the bar is not over-filled there, but it exceeds the window of the smallest chat models (e.g. qwen2.5-coder at 32K), so it is not a guaranteed-safe floor for them.

Auto-compact thresholds and the % badge both read this value back from the backend, so they now scale correctly when the user switches providers.
1662 lines
65 KiB
Rust
1662 lines
65 KiB
Rust
use crate::commands::ai::{build_overview_context, call_chat_messages, load_ai_settings};
|
||
use crate::commands::chat_tools::{
|
||
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
|
||
switch_database_tool,
|
||
};
|
||
use crate::commands::memory::{append_memory_core, read_memory_core};
|
||
use crate::commands::queries::execute_query_core;
|
||
use crate::error::{TuskError, TuskResult};
|
||
use crate::models::ai::OllamaChatMessage;
|
||
use crate::models::chat::{ChartConfig, ChatMessage, ChatTurnResult, ContextUsage};
|
||
use crate::models::query_result::QueryResult;
|
||
use crate::state::AppState;
|
||
use chrono::Utc;
|
||
use serde_json::Value;
|
||
use std::sync::Arc;
|
||
use tauri::{AppHandle, State};
|
||
|
||
/// Maximum number of model round-trips (one parsed action each) per user turn.
const MAX_HOPS: usize = 10;

/// Number of MOST RECENT run_query tool_results that keep full sample rows in
/// LLM history. Older ones are reduced to a marker so very long threads stay
/// within the model's context budget.
const RECENT_TOOL_RESULTS_FULL: usize = 4;

/// Sample-row cap for compressed run_query results in LLM history.
const RUN_QUERY_SAMPLE_ROWS: usize = 10;

/// Per-cell cap, in Unicode scalar values (`char`s), when stringifying sample rows.
const CELL_CHAR_CAP: usize = 200;

/// Per text-tool-result size cap (list_tables, get_columns, etc.).
/// Measured in UTF-8 bytes, so non-ASCII text (e.g. Cyrillic) keeps fewer
/// visible characters than this number.
const TEXT_TOOL_CHAR_CAP: usize = 10_000;

/// Soft cap on the serialized size of history + system prompt before the user
/// is nudged to /compact. Tuned for Ollama defaults (~8K tokens at num_ctx=8192).
/// Token estimate ≈ size / 3 for mixed Cyrillic/ASCII content. Note: the usage
/// side is computed from `str::len`, i.e. UTF-8 bytes, not characters.
const CONTEXT_BUDGET_CHARS_OLLAMA: u64 = 24_000;

/// Budget for managed providers (Fireworks). Fireworks chat models ship with
/// roughly 32K–256K-token context windows; 384K (~128K tokens at the ≈3:1 ratio
/// above) avoids false /compact nags on the larger models while still flagging
/// genuinely runaway threads.
///
/// NOTE: this is a compromise, NOT a safe floor — it exceeds the window of the
/// smallest models (e.g. 32K-token ones such as qwen2.5-coder), which can
/// overflow before the badge warns. A per-model budget would be more accurate.
const CONTEXT_BUDGET_CHARS_FIREWORKS: u64 = 384_000;

/// Stop the loop when the model fails the same SQL hurdle this many times in a
/// row. Beyond this, additional hops almost always burn the rest of the budget
/// on identical retries; a definitive `final` with the error is more useful.
const MAX_CONSECUTIVE_QUERY_ERRORS: usize = 2;
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Action protocol
|
||
// ---------------------------------------------------------------------------
|
||
|
||
#[derive(Debug)]
|
||
enum AgentAction {
|
||
Final { text: String },
|
||
RunQuery { sql: String },
|
||
ListDatabases,
|
||
ListTables { database: Option<String> },
|
||
GetColumns { tables: Vec<String> },
|
||
SwitchDatabase { database: String },
|
||
Remember { note: String },
|
||
SaveQuery { name: String, sql: String },
|
||
FindQueries { text: String },
|
||
MakeChart { config: ChartConfig },
|
||
}
|
||
|
||
/// Parse the model's JSON response. Accepts both shapes the model tends to emit:
|
||
/// {"action":"X","field":"..."} — flat (matches our prompt)
|
||
/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention)
|
||
fn parse_agent_action(raw: &str) -> Result<AgentAction, String> {
|
||
let v: Value = serde_json::from_str(raw).map_err(|e| e.to_string())?;
|
||
let obj = v.as_object().ok_or_else(|| "expected JSON object".to_string())?;
|
||
let action = obj
|
||
.get("action")
|
||
.and_then(|a| a.as_str())
|
||
.ok_or_else(|| "missing field `action`".to_string())?;
|
||
|
||
let lookup = |key: &str| -> Option<&Value> {
|
||
obj.get(key)
|
||
.or_else(|| obj.get("input").and_then(|i| i.as_object()).and_then(|i| i.get(key)))
|
||
};
|
||
|
||
match action {
|
||
"final" => {
|
||
let text = lookup("text")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "final action missing `text`".to_string())?
|
||
.to_string();
|
||
Ok(AgentAction::Final { text })
|
||
}
|
||
"run_query" => {
|
||
let sql = lookup("sql")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "run_query action missing `sql`".to_string())?
|
||
.to_string();
|
||
Ok(AgentAction::RunQuery { sql })
|
||
}
|
||
"list_databases" => Ok(AgentAction::ListDatabases),
|
||
"list_tables" => {
|
||
let database = lookup("database")
|
||
.and_then(|v| v.as_str())
|
||
.map(|s| s.to_string());
|
||
Ok(AgentAction::ListTables { database })
|
||
}
|
||
"get_columns" => {
|
||
let arr = lookup("tables")
|
||
.and_then(|v| v.as_array())
|
||
.ok_or_else(|| "get_columns action missing `tables`: [...]".to_string())?;
|
||
let tables: Vec<String> = arr
|
||
.iter()
|
||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||
.collect();
|
||
if tables.is_empty() {
|
||
return Err("get_columns `tables` array must not be empty".into());
|
||
}
|
||
Ok(AgentAction::GetColumns { tables })
|
||
}
|
||
"switch_database" => {
|
||
let database = lookup("database")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "switch_database missing `database`".to_string())?
|
||
.to_string();
|
||
Ok(AgentAction::SwitchDatabase { database })
|
||
}
|
||
"remember" => {
|
||
let note = lookup("note")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "remember action missing `note`".to_string())?
|
||
.trim()
|
||
.to_string();
|
||
if note.is_empty() {
|
||
return Err("remember `note` must not be empty".into());
|
||
}
|
||
Ok(AgentAction::Remember { note })
|
||
}
|
||
"save_query" => {
|
||
let name = lookup("name")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "save_query missing `name`".to_string())?
|
||
.trim()
|
||
.to_string();
|
||
let sql = lookup("sql")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "save_query missing `sql`".to_string())?
|
||
.trim()
|
||
.to_string();
|
||
if name.is_empty() {
|
||
return Err("save_query `name` must not be empty".into());
|
||
}
|
||
if sql.is_empty() {
|
||
return Err("save_query `sql` must not be empty".into());
|
||
}
|
||
Ok(AgentAction::SaveQuery { name, sql })
|
||
}
|
||
"find_queries" => {
|
||
let text = lookup("text")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "find_queries missing `text`".to_string())?
|
||
.trim()
|
||
.to_string();
|
||
if text.is_empty() {
|
||
return Err("find_queries `text` must not be empty".into());
|
||
}
|
||
Ok(AgentAction::FindQueries { text })
|
||
}
|
||
"make_chart" => {
|
||
let chart_type = lookup("chart_type")
|
||
.or_else(|| lookup("type"))
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "make_chart missing `chart_type`".to_string())?
|
||
.trim()
|
||
.to_lowercase();
|
||
if !["bar", "line", "area", "pie"].contains(&chart_type.as_str()) {
|
||
return Err(format!(
|
||
"make_chart `chart_type` must be one of: bar, line, area, pie. Got: {}",
|
||
chart_type
|
||
));
|
||
}
|
||
let x = lookup("x")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "make_chart missing `x` column".to_string())?
|
||
.trim()
|
||
.to_string();
|
||
let y = lookup("y")
|
||
.and_then(|v| v.as_str())
|
||
.ok_or_else(|| "make_chart missing `y` column".to_string())?
|
||
.trim()
|
||
.to_string();
|
||
if x.is_empty() || y.is_empty() {
|
||
return Err("make_chart `x` and `y` must not be empty".into());
|
||
}
|
||
let group = lookup("group")
|
||
.and_then(|v| v.as_str())
|
||
.map(|s| s.trim().to_string())
|
||
.filter(|s| !s.is_empty());
|
||
let title = lookup("title")
|
||
.and_then(|v| v.as_str())
|
||
.map(|s| s.trim().to_string())
|
||
.filter(|s| !s.is_empty());
|
||
let orientation = lookup("orientation")
|
||
.and_then(|v| v.as_str())
|
||
.map(|s| s.trim().to_lowercase())
|
||
.filter(|s| !s.is_empty());
|
||
Ok(AgentAction::MakeChart {
|
||
config: ChartConfig {
|
||
chart_type,
|
||
x,
|
||
y,
|
||
group,
|
||
title,
|
||
orientation,
|
||
},
|
||
})
|
||
}
|
||
// Legacy from earlier iterations — silently ignored at parse time so the
|
||
// model can recover with a different action.
|
||
"get_schema" => Err(
|
||
"get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(),
|
||
),
|
||
other => Err(format!("unknown action `{}`", other)),
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// id / time helpers
|
||
// ---------------------------------------------------------------------------
|
||
|
||
fn now_ms() -> i64 {
|
||
Utc::now().timestamp_millis()
|
||
}
|
||
|
||
fn new_id(prefix: &str) -> String {
|
||
format!("{}-{}-{}", prefix, now_ms(), rand_suffix())
|
||
}
|
||
|
||
/// Hex suffix that keeps ids from `new_id` unique within this process.
///
/// Combines the wall clock's sub-second nanoseconds (fixed-width, 8 hex
/// digits) with a process-wide counter. The clock part alone is not enough:
/// on platforms with coarse clock resolution, two ids minted back-to-back can
/// read the same value, and the counter breaks that tie. Because the nanos
/// field is fixed-width, the concatenation cannot be ambiguous.
fn rand_suffix() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQ: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);
    // Relaxed is enough: only uniqueness of the fetched value matters, and
    // `fetch_add` is atomic regardless of ordering.
    let seq = SEQ.fetch_add(1, Ordering::Relaxed);
    format!("{:08x}{:x}", nanos, seq)
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// System prompt
|
||
// ---------------------------------------------------------------------------
|
||
|
||
fn system_prompt(overview: &str, memory: &str) -> String {
|
||
let overview_block = if overview.is_empty() {
|
||
"(overview unavailable; respond with `final` asking the user to reconnect.)".to_string()
|
||
} else {
|
||
overview.to_string()
|
||
};
|
||
|
||
let memory_block = if memory.trim().is_empty() {
|
||
"(empty — call remember() when you discover non-obvious facts about this database)".to_string()
|
||
} else {
|
||
memory.to_string()
|
||
};
|
||
|
||
format!(
|
||
r#"ROLE: Tusk's data assistant. Reply in the user's language.
|
||
|
||
You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On every turn output STRICT JSON — exactly one of these shapes, with all fields at the root (no `input` wrapper, no markdown fences):
|
||
|
||
{{"action":"list_databases"}}
|
||
Refresh the database list when the OVERVIEW seems stale.
|
||
|
||
{{"action":"list_tables"}}
|
||
List tables in the active database.
|
||
|
||
{{"action":"list_tables","database":"<name>"}}
|
||
List tables in a specific database (PostgreSQL: requires switch_database before run_query).
|
||
|
||
{{"action":"get_columns","tables":["schema.table","schema.table2"]}}
|
||
Load full column info (types, PK, FK, comments, enums) for the listed tables. Use this BEFORE writing SQL when you don't already know the columns.
|
||
|
||
{{"action":"switch_database","database":"<name>"}}
|
||
Change the active database. Required for PostgreSQL when the user's question concerns data in another database. ClickHouse rarely needs this — `db.table` qualifiers are allowed without switching.
|
||
|
||
{{"action":"run_query","sql":"SELECT ..."}}
|
||
Execute read-only SQL (SELECT / WITH ... SELECT / EXPLAIN / SHOW / DESCRIBE). Mutating SQL is rejected by the read-only guard.
|
||
|
||
{{"action":"remember","note":"<short observation>"}}
|
||
Persist a non-obvious fact about THIS database for future sessions: column semantics, naming conventions, business-rule encodings, gotchas. Keep notes < 200 chars. The user sees and can edit your notes in the Memory sidebar tab.
|
||
|
||
{{"action":"find_queries","text":"<keywords>"}}
|
||
Search saved queries (your past work + user-saved). Use BEFORE writing complex SQL — a usable variant may already exist. Top 10 matches with SQL preview are returned.
|
||
|
||
{{"action":"save_query","name":"<short label>","sql":"<the SQL>"}}
|
||
Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved.
|
||
|
||
{{"action":"make_chart","chart_type":"bar","x":"<col>","y":"<col>","title":"<short title>"}}
|
||
Visualise the LAST successful run_query result as a chart inline. `chart_type` is one of: bar, line, area, pie. `x` and `y` MUST be column names from the previous result. Optional: `group` (column for series), `orientation` ("vertical"/"horizontal", bar only). Use after run_query when the data is aggregated and would be clearer as a chart (top-N comparisons → bar; time series → line/area; proportions → pie). Skip for tiny results (≤2 rows) and giant ones (>500 rows).
|
||
|
||
{{"action":"final","text":"..."}}
|
||
End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling).
|
||
|
||
WORKFLOW
|
||
1. Read LEARNED NOTES below first — the user (or your past self) may have already documented relevant facts.
|
||
2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question.
|
||
3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs.
|
||
4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns.
|
||
5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first.
|
||
6. Execute run_query.
|
||
7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here.
|
||
8. If the query is likely to be re-run later (a real report-style request, not a one-off lookup), call `save_query` with a concise `name`.
|
||
9. Answer with `final`.
|
||
|
||
RULES
|
||
- Use ONLY identifiers visible to you (overview / list_tables / get_columns output). Don't pluralize, translate, or guess.
|
||
- **After get_columns, your next run_query must use ONLY column names that appear verbatim in that output.** Do not assume conventions like `name`, `id`, `title`. If get_columns shows a table has `company_name` and `legal_name` but no `name`, then `name` does NOT exist on that table.
|
||
- LIMIT on ad-hoc SELECTs unless aggregating.
|
||
- When run_query fails, READ the error carefully — especially any `HINT:` line, which often spells out the fix. Common PostgreSQL fixes:
|
||
* `operator does not exist: X = Y` (e.g. `character varying = uuid`) → cast one side, e.g. `a.id::uuid = b.id` or `a.id = b.id::text`. If unsure of types, call get_columns on both tables.
|
||
* `column "X" does not exist` → call get_columns on the table you're querying; the column is named differently. The error message lists which alias the column was attached to (e.g. `column le.name` means it's missing on the table aliased as `le`).
|
||
* `relation "X" does not exist` → check the OVERVIEW table list; the table may be in a different schema or database.
|
||
- On SQL error retry at most ONCE with a corrected query. On the second consecutive failure, STOP and respond with `final` explaining what's missing — do not loop. The harness will force-stop after 2 consecutive errors regardless.
|
||
- You have a hop budget of 10 tool calls per user turn. Spend them deliberately: don't burn hops re-running the same query — investigate (get_columns) when in doubt.
|
||
- `remember` is for durable facts, not transient observations. Don't memorise query results — only insights about the schema/data model that aren't already in the OVERVIEW.
|
||
|
||
═══════════════════════════════════════════════════════════════
|
||
LEARNED NOTES (per-connection memory; user can edit in sidebar → Memory)
|
||
═══════════════════════════════════════════════════════════════
|
||
{memory}
|
||
═══════════════════════════════════════════════════════════════
|
||
|
||
═══════════════════════════════════════════════════════════════
|
||
OVERVIEW (refreshed every turn)
|
||
═══════════════════════════════════════════════════════════════
|
||
{overview}
|
||
═══════════════════════════════════════════════════════════════
|
||
"#,
|
||
hops = MAX_HOPS,
|
||
memory = memory_block,
|
||
overview = overview_block,
|
||
)
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Compressed history projection
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/// Compact view of a QueryResult for re-injection into the LLM history.
|
||
/// Keeps just enough for the model to reason about the next step (column
|
||
/// names, types, total row count, first N rows) without the full payload.
|
||
fn compact_query_result(result: &QueryResult) -> Value {
|
||
let total = result.rows.len();
|
||
let sample: Vec<Vec<Value>> = result
|
||
.rows
|
||
.iter()
|
||
.take(RUN_QUERY_SAMPLE_ROWS)
|
||
.map(|row| row.iter().map(truncate_cell).collect())
|
||
.collect();
|
||
serde_json::json!({
|
||
"columns": result.columns,
|
||
"types": result.types,
|
||
"row_count": total,
|
||
"execution_time_ms": result.execution_time_ms,
|
||
"sample_rows": sample,
|
||
"truncated": total > RUN_QUERY_SAMPLE_ROWS,
|
||
})
|
||
}
|
||
|
||
fn truncate_cell(v: &Value) -> Value {
|
||
match v {
|
||
Value::String(s) if s.chars().count() > CELL_CHAR_CAP => {
|
||
let truncated: String = s.chars().take(CELL_CHAR_CAP).collect();
|
||
Value::String(format!("{}…", truncated))
|
||
}
|
||
other => other.clone(),
|
||
}
|
||
}
|
||
|
||
fn truncate_text(text: &str) -> String {
|
||
if text.len() <= TEXT_TOOL_CHAR_CAP {
|
||
text.to_string()
|
||
} else {
|
||
let mut out = text[..TEXT_TOOL_CHAR_CAP].to_string();
|
||
out.push_str("\n…(truncated)");
|
||
out
|
||
}
|
||
}
|
||
|
||
fn build_history(
|
||
messages: &[ChatMessage],
|
||
overview_text: &str,
|
||
memory_text: &str,
|
||
) -> Vec<OllamaChatMessage> {
|
||
// Index of run_query tool_results in `messages`. Used to mark which ones
|
||
// get full sample rows vs the "(rows omitted)" placeholder.
|
||
let run_query_indices: Vec<usize> = messages
|
||
.iter()
|
||
.enumerate()
|
||
.filter_map(|(i, m)| match m {
|
||
ChatMessage::ToolResult { tool, .. } if tool == "run_query" => Some(i),
|
||
_ => None,
|
||
})
|
||
.collect();
|
||
let keep_full_after_index: usize = if run_query_indices.len() <= RECENT_TOOL_RESULTS_FULL {
|
||
0
|
||
} else {
|
||
run_query_indices[run_query_indices.len() - RECENT_TOOL_RESULTS_FULL]
|
||
};
|
||
|
||
let mut out = Vec::with_capacity(messages.len() + 1);
|
||
out.push(OllamaChatMessage {
|
||
role: "system".to_string(),
|
||
content: system_prompt(overview_text, memory_text),
|
||
});
|
||
|
||
for (idx, m) in messages.iter().enumerate() {
|
||
match m {
|
||
ChatMessage::User { text, .. } => out.push(OllamaChatMessage {
|
||
role: "user".to_string(),
|
||
content: text.clone(),
|
||
}),
|
||
ChatMessage::Assistant { text, .. } => out.push(OllamaChatMessage {
|
||
role: "assistant".to_string(),
|
||
content: serde_json::json!({ "action": "final", "text": text }).to_string(),
|
||
}),
|
||
ChatMessage::ToolCall { tool, input_json, .. } => {
|
||
if tool == "get_schema" {
|
||
continue; // legacy
|
||
}
|
||
let mut envelope = serde_json::Map::new();
|
||
envelope.insert("action".to_string(), Value::String(tool.clone()));
|
||
if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) {
|
||
for (k, v) in input {
|
||
envelope.insert(k, v);
|
||
}
|
||
}
|
||
out.push(OllamaChatMessage {
|
||
role: "assistant".to_string(),
|
||
content: Value::Object(envelope).to_string(),
|
||
});
|
||
}
|
||
ChatMessage::ToolResult {
|
||
tool,
|
||
is_error,
|
||
text,
|
||
result,
|
||
..
|
||
} => {
|
||
if tool == "get_schema" {
|
||
continue; // legacy
|
||
}
|
||
let payload = match tool.as_str() {
|
||
"run_query" => {
|
||
if *is_error {
|
||
serde_json::json!({
|
||
"tool": "run_query",
|
||
"error": true,
|
||
"text": text.clone().unwrap_or_default(),
|
||
})
|
||
} else if idx < keep_full_after_index {
|
||
serde_json::json!({
|
||
"tool": "run_query",
|
||
"error": false,
|
||
"note": "rows omitted (older result; user has it in the UI above)",
|
||
})
|
||
} else if let Some(qr) = result {
|
||
serde_json::json!({
|
||
"tool": "run_query",
|
||
"error": false,
|
||
"result": compact_query_result(qr),
|
||
})
|
||
} else {
|
||
serde_json::json!({
|
||
"tool": "run_query",
|
||
"error": false,
|
||
"result": null,
|
||
})
|
||
}
|
||
}
|
||
// Text-only tools — pass through with cap.
|
||
_ => serde_json::json!({
|
||
"tool": tool,
|
||
"error": *is_error,
|
||
"text": text.as_deref().map(truncate_text),
|
||
}),
|
||
};
|
||
|
||
out.push(OllamaChatMessage {
|
||
role: "user".to_string(),
|
||
content: format!("TOOL_RESULT {}", payload),
|
||
});
|
||
}
|
||
}
|
||
}
|
||
out
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// chat_send
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/// Estimate how many characters the next LLM call will serialize to history
|
||
/// (system prompt + conversation, after compression). This is the same data
|
||
/// path as the actual call, so the count is exact for the chosen budget unit.
|
||
async fn compute_usage(
|
||
state: &AppState,
|
||
app: &AppHandle,
|
||
connection_id: &str,
|
||
working: &[ChatMessage],
|
||
) -> ContextUsage {
|
||
let overview = build_overview_context(state, connection_id)
|
||
.await
|
||
.unwrap_or_default();
|
||
let memory = read_memory_core(app, connection_id).unwrap_or_default();
|
||
let history = build_history(working, &overview, &memory);
|
||
// role string ("system"/"user"/"assistant") ≤ 9 chars + content + JSON envelope overhead
|
||
let used: u64 = history
|
||
.iter()
|
||
.map(|m| (m.role.len() + m.content.len() + 16) as u64)
|
||
.sum();
|
||
ContextUsage {
|
||
used_chars: used,
|
||
budget_chars: provider_budget_chars(state, app).await,
|
||
}
|
||
}
|
||
|
||
/// Returns the soft context budget appropriate for the currently-configured
|
||
/// LLM provider. Falls back to the Ollama default if settings can't be loaded.
|
||
async fn provider_budget_chars(state: &AppState, app: &AppHandle) -> u64 {
|
||
use crate::models::ai::AiProvider;
|
||
match load_ai_settings(app, state).await {
|
||
Ok(s) => match s.provider {
|
||
AiProvider::Fireworks => CONTEXT_BUDGET_CHARS_FIREWORKS,
|
||
_ => CONTEXT_BUDGET_CHARS_OLLAMA,
|
||
},
|
||
Err(_) => CONTEXT_BUDGET_CHARS_OLLAMA,
|
||
}
|
||
}
|
||
|
||
#[tauri::command]
|
||
pub async fn chat_send(
|
||
app: AppHandle,
|
||
state: State<'_, Arc<AppState>>,
|
||
connection_id: String,
|
||
messages: Vec<ChatMessage>,
|
||
) -> TuskResult<ChatTurnResult> {
|
||
let mut new_messages: Vec<ChatMessage> = Vec::new();
|
||
let mut working: Vec<ChatMessage> = messages;
|
||
let mut consecutive_query_errors: usize = 0;
|
||
|
||
for _hop in 0..MAX_HOPS {
|
||
// Hard guard: stop once the model has hit the same SQL hurdle multiple
|
||
// times in a row. Beyond this point further hops virtually always
|
||
// re-run the same broken query against the model's better judgment.
|
||
if consecutive_query_errors >= MAX_CONSECUTIVE_QUERY_ERRORS {
|
||
let last_err = last_run_query_error(&working).unwrap_or_default();
|
||
let msg = ChatMessage::Assistant {
|
||
id: new_id("asst"),
|
||
text: format!(
|
||
"I tried {} times and kept hitting the same SQL error:\n\n{}\n\nThis usually means a schema mismatch — wrong column type, missing cast (often `::uuid` or `::text`), or a join key that doesn't exist as I assumed. Could you double-check the question, or open the table in the sidebar to verify column types? You can also write the SQL manually in Advanced mode.",
|
||
consecutive_query_errors, last_err
|
||
),
|
||
created_at: now_ms(),
|
||
};
|
||
new_messages.push(msg.clone());
|
||
working.push(msg);
|
||
let usage = compute_usage(&state, &app, &connection_id, &working).await;
|
||
return Ok(ChatTurnResult {
|
||
messages: new_messages,
|
||
usage,
|
||
});
|
||
}
|
||
|
||
// Overview is rebuilt per turn — cheap (cached) and reflects the active DB
|
||
// even if the user (or the agent) just switched it.
|
||
let overview_text = build_overview_context(&state, &connection_id)
|
||
.await
|
||
.unwrap_or_default();
|
||
// Memory is read fresh each turn so user-side edits in the Memory tab
|
||
// are visible to the agent immediately.
|
||
let memory_text = read_memory_core(&app, &connection_id).unwrap_or_default();
|
||
|
||
let history = build_history(&working, &overview_text, &memory_text);
|
||
let raw =
|
||
call_chat_messages(&app, &state, history, Some("json".to_string())).await?;
|
||
let trimmed = raw.trim();
|
||
|
||
let action = match parse_agent_action(trimmed) {
|
||
Ok(a) => a,
|
||
Err(parse_err) => {
|
||
let msg = ChatMessage::Assistant {
|
||
id: new_id("asst"),
|
||
text: format!(
|
||
"{}\n\n_(Note: model returned non-protocol output: {})_",
|
||
trimmed, parse_err
|
||
),
|
||
created_at: now_ms(),
|
||
};
|
||
new_messages.push(msg.clone());
|
||
working.push(msg);
|
||
let usage = compute_usage(&state, &app, &connection_id, &working).await;
|
||
return Ok(ChatTurnResult {
|
||
messages: new_messages,
|
||
usage,
|
||
});
|
||
}
|
||
};
|
||
|
||
let is_run_query = matches!(&action, AgentAction::RunQuery { .. });
|
||
|
||
match action {
|
||
AgentAction::Final { text } => {
|
||
let msg = ChatMessage::Assistant {
|
||
id: new_id("asst"),
|
||
text,
|
||
created_at: now_ms(),
|
||
};
|
||
new_messages.push(msg.clone());
|
||
working.push(msg);
|
||
let usage = compute_usage(&state, &app, &connection_id, &working).await;
|
||
return Ok(ChatTurnResult {
|
||
messages: new_messages,
|
||
usage,
|
||
});
|
||
}
|
||
AgentAction::RunQuery { sql } => {
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"run_query",
|
||
serde_json::json!({ "sql": sql }).to_string(),
|
||
);
|
||
let result = match execute_query_core(&state, &connection_id, &sql).await {
|
||
Ok(qr) => {
|
||
consecutive_query_errors = 0;
|
||
ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "run_query".to_string(),
|
||
is_error: false,
|
||
text: None,
|
||
result: Some(qr),
|
||
created_at: now_ms(),
|
||
}
|
||
}
|
||
Err(e) => {
|
||
consecutive_query_errors += 1;
|
||
let suffix = match &e {
|
||
TuskError::ReadOnly => {
|
||
"\n\nRead-only mode is on. Toggle it off in the toolbar to allow writes."
|
||
}
|
||
_ => "",
|
||
};
|
||
ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "run_query".to_string(),
|
||
is_error: true,
|
||
text: Some(format!("{}{}", format_db_error(&e), suffix)),
|
||
result: None,
|
||
created_at: now_ms(),
|
||
}
|
||
}
|
||
};
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::ListDatabases => {
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"list_databases",
|
||
"{}".to_string(),
|
||
);
|
||
let result = run_text_tool(
|
||
list_databases_tool(&state, &connection_id).await,
|
||
"list_databases",
|
||
);
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::ListTables { database } => {
|
||
let input_json = match &database {
|
||
Some(db) => serde_json::json!({ "database": db }).to_string(),
|
||
None => "{}".to_string(),
|
||
};
|
||
push_tool_call(&mut new_messages, &mut working, "list_tables", input_json);
|
||
let result = run_text_tool(
|
||
list_tables_tool(&app, &state, &connection_id, database.as_deref()).await,
|
||
"list_tables",
|
||
);
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::GetColumns { tables } => {
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"get_columns",
|
||
serde_json::json!({ "tables": tables }).to_string(),
|
||
);
|
||
let result = run_text_tool(
|
||
get_columns_tool(&state, &connection_id, &tables).await,
|
||
"get_columns",
|
||
);
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::SwitchDatabase { database } => {
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"switch_database",
|
||
serde_json::json!({ "database": &database }).to_string(),
|
||
);
|
||
let result = run_text_tool(
|
||
switch_database_tool(&app, &state, &connection_id, &database).await,
|
||
"switch_database",
|
||
);
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::Remember { note } => {
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"remember",
|
||
serde_json::json!({ "note": ¬e }).to_string(),
|
||
);
|
||
let outcome = append_memory_core(&app, &connection_id, ¬e)
|
||
.map(|_| format!("Saved note ({} chars).", note.len()));
|
||
let result = run_text_tool(outcome, "remember");
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::SaveQuery { name, sql } => {
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"save_query",
|
||
serde_json::json!({ "name": &name, "sql": &sql }).to_string(),
|
||
);
|
||
let result = run_text_tool(
|
||
save_query_tool(&app, &connection_id, &name, &sql).await,
|
||
"save_query",
|
||
);
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::FindQueries { text } => {
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"find_queries",
|
||
serde_json::json!({ "text": &text }).to_string(),
|
||
);
|
||
let result = run_text_tool(
|
||
find_queries_tool(&app, &connection_id, &text).await,
|
||
"find_queries",
|
||
);
|
||
push_tool_result(&mut new_messages, &mut working, result);
|
||
}
|
||
AgentAction::MakeChart { config } => {
|
||
let config_json = serde_json::to_string(&config).unwrap_or_else(|_| "{}".into());
|
||
push_tool_call(
|
||
&mut new_messages,
|
||
&mut working,
|
||
"make_chart",
|
||
config_json.clone(),
|
||
);
|
||
|
||
let result_msg = match last_successful_query_result(&working) {
|
||
None => ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "make_chart".to_string(),
|
||
is_error: true,
|
||
text: Some(
|
||
"make_chart needs a successful run_query result above it. Run a SELECT first, then call make_chart."
|
||
.to_string(),
|
||
),
|
||
result: None,
|
||
created_at: now_ms(),
|
||
},
|
||
Some(qr) => {
|
||
if !qr.columns.iter().any(|c| c == &config.x) {
|
||
ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "make_chart".to_string(),
|
||
is_error: true,
|
||
text: Some(format!(
|
||
"x column `{}` is not in the last result. Available: {}.",
|
||
config.x,
|
||
qr.columns.join(", ")
|
||
)),
|
||
result: None,
|
||
created_at: now_ms(),
|
||
}
|
||
} else if !qr.columns.iter().any(|c| c == &config.y) {
|
||
ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "make_chart".to_string(),
|
||
is_error: true,
|
||
text: Some(format!(
|
||
"y column `{}` is not in the last result. Available: {}.",
|
||
config.y,
|
||
qr.columns.join(", ")
|
||
)),
|
||
result: None,
|
||
created_at: now_ms(),
|
||
}
|
||
} else if let Some(group) = &config.group {
|
||
if !qr.columns.iter().any(|c| c == group) {
|
||
ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "make_chart".to_string(),
|
||
is_error: true,
|
||
text: Some(format!(
|
||
"group column `{}` is not in the last result. Available: {}.",
|
||
group,
|
||
qr.columns.join(", ")
|
||
)),
|
||
result: None,
|
||
created_at: now_ms(),
|
||
}
|
||
} else {
|
||
ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "make_chart".to_string(),
|
||
is_error: false,
|
||
text: Some(config_json.clone()),
|
||
result: Some(qr),
|
||
created_at: now_ms(),
|
||
}
|
||
}
|
||
} else {
|
||
ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: "make_chart".to_string(),
|
||
is_error: false,
|
||
text: Some(config_json.clone()),
|
||
result: Some(qr),
|
||
created_at: now_ms(),
|
||
}
|
||
}
|
||
}
|
||
};
|
||
push_tool_result(&mut new_messages, &mut working, result_msg);
|
||
}
|
||
}
|
||
|
||
// Any non-RunQuery, non-Final action means the model is investigating
|
||
// (e.g. get_columns to verify a type). Give it a fresh error budget.
|
||
if !is_run_query {
|
||
consecutive_query_errors = 0;
|
||
}
|
||
}
|
||
|
||
// Last-chance synthesis: model is out of tool calls but may have collected
|
||
// enough data on the last hop to answer. One extra LLM call, no JSON
|
||
// protocol, just plain text.
|
||
let synthesis = force_final_synthesis(&app, &state, &working).await;
|
||
let text = match synthesis {
|
||
Some(t) => format!("{}\n\n_(Tool-call limit reached; answer synthesised from collected results.)_", t),
|
||
None => format!(
|
||
"Stopped after {} tool calls without a final answer. Try rephrasing the question, splitting it into smaller parts, or running the SQL manually in Advanced mode.",
|
||
MAX_HOPS
|
||
),
|
||
};
|
||
let msg = ChatMessage::Assistant {
|
||
id: new_id("asst"),
|
||
text,
|
||
created_at: now_ms(),
|
||
};
|
||
new_messages.push(msg.clone());
|
||
working.push(msg);
|
||
let usage = compute_usage(&state, &app, &connection_id, &working).await;
|
||
Ok(ChatTurnResult {
|
||
messages: new_messages,
|
||
usage,
|
||
})
|
||
}
|
||
|
||
fn push_tool_call(
|
||
new_messages: &mut Vec<ChatMessage>,
|
||
working: &mut Vec<ChatMessage>,
|
||
tool: &str,
|
||
input_json: String,
|
||
) {
|
||
let call = ChatMessage::ToolCall {
|
||
id: new_id("call"),
|
||
tool: tool.to_string(),
|
||
input_json,
|
||
created_at: now_ms(),
|
||
};
|
||
new_messages.push(call.clone());
|
||
working.push(call);
|
||
}
|
||
|
||
fn push_tool_result(
|
||
new_messages: &mut Vec<ChatMessage>,
|
||
working: &mut Vec<ChatMessage>,
|
||
result: ChatMessage,
|
||
) {
|
||
new_messages.push(result.clone());
|
||
working.push(result);
|
||
}
|
||
|
||
fn run_text_tool(outcome: TuskResult<String>, tool: &str) -> ChatMessage {
|
||
match outcome {
|
||
Ok(text) => ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: tool.to_string(),
|
||
is_error: false,
|
||
text: Some(text),
|
||
result: None,
|
||
created_at: now_ms(),
|
||
},
|
||
Err(e) => ChatMessage::ToolResult {
|
||
id: new_id("res"),
|
||
tool: tool.to_string(),
|
||
is_error: true,
|
||
text: Some(e.to_string()),
|
||
result: None,
|
||
created_at: now_ms(),
|
||
},
|
||
}
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// chat_compact
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/// Render the older-history portion of the thread as a compact text block
|
||
/// for LLM-driven summarization. Skips QueryResult.rows (huge), keeps only
|
||
/// columns + row_count + sample.
|
||
fn render_thread_for_summary(messages: &[ChatMessage]) -> String {
|
||
let mut out = String::new();
|
||
for m in messages {
|
||
match m {
|
||
ChatMessage::User { text, .. } => {
|
||
out.push_str(&format!("USER: {}\n\n", text));
|
||
}
|
||
ChatMessage::Assistant { text, .. } => {
|
||
out.push_str(&format!("ASSISTANT: {}\n\n", text));
|
||
}
|
||
ChatMessage::ToolCall { tool, input_json, .. } => {
|
||
out.push_str(&format!("TOOL_CALL [{}]: {}\n\n", tool, input_json));
|
||
}
|
||
ChatMessage::ToolResult {
|
||
tool,
|
||
is_error,
|
||
text,
|
||
result,
|
||
..
|
||
} => {
|
||
if *is_error {
|
||
out.push_str(&format!(
|
||
"TOOL_ERROR [{}]: {}\n\n",
|
||
tool,
|
||
text.as_deref().unwrap_or("")
|
||
));
|
||
continue;
|
||
}
|
||
if let Some(qr) = result {
|
||
out.push_str(&format!(
|
||
"TOOL_RESULT [{}]: {} rows; columns={}\n\n",
|
||
tool,
|
||
qr.row_count,
|
||
qr.columns.join(", ")
|
||
));
|
||
} else if let Some(t) = text {
|
||
let snippet: String = t.chars().take(800).collect();
|
||
out.push_str(&format!("TOOL_RESULT [{}]: {}\n\n", tool, snippet));
|
||
}
|
||
}
|
||
}
|
||
}
|
||
out
|
||
}
|
||
|
||
/// Format a database error with PostgreSQL HINT/DETAIL extracted when present.
|
||
/// Plain `e.to_string()` shows just the bare error message, which makes
|
||
/// errors like `operator does not exist: character varying = uuid` impossible
|
||
/// to act on. PG's own HINT (`You might need to add explicit type casts`) is
|
||
/// often the difference between the agent finding the fix on retry and
|
||
/// looping forever.
|
||
fn format_db_error(e: &TuskError) -> String {
|
||
if let TuskError::Database(sqlx::Error::Database(db_err)) = e {
|
||
let mut out = format!("Database error: {}", db_err.message());
|
||
if let Some(pg) = db_err.try_downcast_ref::<sqlx::postgres::PgDatabaseError>() {
|
||
if let Some(detail) = pg.detail() {
|
||
out.push_str(&format!("\nDETAIL: {}", detail));
|
||
}
|
||
if let Some(hint) = pg.hint() {
|
||
out.push_str(&format!("\nHINT: {}", hint));
|
||
}
|
||
}
|
||
return out;
|
||
}
|
||
e.to_string()
|
||
}
|
||
|
||
/// Locate the most recent SUCCESSFUL run_query in the working thread and
|
||
/// return its full QueryResult. Used by make_chart to attach data to a chart
|
||
/// directive without relying on the model to re-send it.
|
||
fn last_successful_query_result(messages: &[ChatMessage]) -> Option<QueryResult> {
|
||
for m in messages.iter().rev() {
|
||
if let ChatMessage::ToolResult {
|
||
tool,
|
||
is_error: false,
|
||
result: Some(qr),
|
||
..
|
||
} = m
|
||
{
|
||
if tool == "run_query" {
|
||
return Some(qr.clone());
|
||
}
|
||
}
|
||
}
|
||
None
|
||
}
|
||
|
||
/// Pull the most recent run_query error text from the working thread, so the
|
||
/// post-loop "I gave up" summary can quote concrete errors back to the user.
|
||
fn last_run_query_error(messages: &[ChatMessage]) -> Option<String> {
|
||
for m in messages.iter().rev() {
|
||
if let ChatMessage::ToolResult {
|
||
tool,
|
||
is_error: true,
|
||
text,
|
||
..
|
||
} = m
|
||
{
|
||
if tool == "run_query" {
|
||
return text.clone();
|
||
}
|
||
}
|
||
}
|
||
None
|
||
}
|
||
|
||
/// Last-chance LLM call after MAX_HOPS is exhausted: nudge the model to
|
||
/// produce a `final` answer based on whatever data and tool results the
|
||
/// thread already contains. Without this, a thread that succeeded on its
|
||
/// last run_query would end with "Stopped after N tool calls" and waste
|
||
/// the result. Returns None if the LLM call fails.
|
||
async fn force_final_synthesis(
|
||
app: &AppHandle,
|
||
state: &AppState,
|
||
working: &[ChatMessage],
|
||
) -> Option<String> {
|
||
if working.is_empty() {
|
||
return None;
|
||
}
|
||
let convo = render_thread_for_summary(working);
|
||
let system = "The agent loop has reached the tool-call limit. The user is waiting for an answer right now. \
|
||
Based ONLY on the conversation below, write a SHORT plain-text answer for the user. \
|
||
Reply in the SAME language the user used. \
|
||
If a query produced results, summarise what those results show. \
|
||
If queries kept failing, explain what went wrong and what the user could do (provide the missing piece, switch to Advanced mode, etc.). \
|
||
Be concrete with numbers and identifiers from the results.\n\
|
||
\n\
|
||
OUTPUT FORMAT: PLAIN TEXT ONLY. \
|
||
DO NOT output JSON, markdown fences, or field names. \
|
||
DO NOT call any tools. DO NOT use the action protocol. \
|
||
Just the answer text.";
|
||
|
||
let llm_messages = vec![
|
||
OllamaChatMessage {
|
||
role: "system".to_string(),
|
||
content: system.to_string(),
|
||
},
|
||
OllamaChatMessage {
|
||
role: "user".to_string(),
|
||
content: convo,
|
||
},
|
||
];
|
||
match call_chat_messages(app, state, llm_messages, None).await {
|
||
Ok(s) => {
|
||
let cleaned = clean_summary(&s);
|
||
if cleaned.trim().is_empty() {
|
||
None
|
||
} else {
|
||
Some(cleaned)
|
||
}
|
||
}
|
||
Err(_) => None,
|
||
}
|
||
}
|
||
|
||
/// Strip JSON envelopes, markdown fences, and known field-extraction patterns
|
||
/// that the agent-trained model tends to emit even for non-agent prompts.
|
||
/// Returns the underlying summary text.
|
||
fn clean_summary(raw: &str) -> String {
|
||
let trimmed = raw.trim();
|
||
if trimmed.is_empty() {
|
||
return trimmed.to_string();
|
||
}
|
||
|
||
// Strip ```...``` fences (with or without lang).
|
||
let unfenced = if trimmed.starts_with("```") {
|
||
let body = trimmed.trim_start_matches('`').trim_start_matches('`').trim_start_matches('`');
|
||
// Drop optional language identifier on the first line.
|
||
let after_lang = body.split_once('\n').map(|(_, rest)| rest).unwrap_or(body);
|
||
let trimmed_end = after_lang.trim_end_matches('`').trim_end_matches('`').trim_end_matches('`');
|
||
trimmed_end.trim().to_string()
|
||
} else {
|
||
trimmed.to_string()
|
||
};
|
||
|
||
// If the model returned a JSON envelope, extract a known string field.
|
||
if unfenced.starts_with('{') {
|
||
if let Ok(v) = serde_json::from_str::<serde_json::Value>(&unfenced) {
|
||
for key in ["text", "summary", "content", "answer", "output"] {
|
||
if let Some(s) = v.get(key).and_then(|x| x.as_str()) {
|
||
return s.trim().to_string();
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
unfenced
|
||
}
|
||
|
||
/// Find the index of the last User message; returns messages.len() if no user message.
|
||
fn last_user_turn_index(messages: &[ChatMessage]) -> usize {
|
||
for (i, m) in messages.iter().enumerate().rev() {
|
||
if matches!(m, ChatMessage::User { .. }) {
|
||
return i;
|
||
}
|
||
}
|
||
messages.len()
|
||
}
|
||
|
||
/// LLM-summarise the older portion of a chat thread.
|
||
/// Returns thread = [ Assistant("📋 Compacted: …") , <last_user_turn_if_any> ].
|
||
/// If the thread has nothing to compact, returns it unchanged.
|
||
#[tauri::command]
|
||
pub async fn chat_compact(
|
||
app: AppHandle,
|
||
state: State<'_, Arc<AppState>>,
|
||
connection_id: String,
|
||
messages: Vec<ChatMessage>,
|
||
) -> TuskResult<ChatTurnResult> {
|
||
if messages.is_empty() {
|
||
let usage = compute_usage(&state, &app, &connection_id, &messages).await;
|
||
return Ok(ChatTurnResult { messages, usage });
|
||
}
|
||
|
||
// Preserve the user's most recent question (if any) untouched so the
|
||
// model can continue from it after compaction. Everything before goes
|
||
// into the summary.
|
||
let split_at = last_user_turn_index(&messages);
|
||
let (older, recent): (&[ChatMessage], &[ChatMessage]) = if split_at == messages.len() {
|
||
(&messages[..], &[])
|
||
} else {
|
||
(&messages[..split_at], &messages[split_at..])
|
||
};
|
||
|
||
if older.is_empty() {
|
||
let usage = compute_usage(&state, &app, &connection_id, &messages).await;
|
||
return Ok(ChatTurnResult { messages, usage });
|
||
}
|
||
|
||
let convo = render_thread_for_summary(older);
|
||
let system = "You are a precise summarizer of a database analysis dialogue. \
|
||
Produce a SHORT summary in the SAME language the user spoke. \
|
||
Use 3-6 bullet points covering: the user's goal, key tables/columns/queries used, \
|
||
numerical findings, conclusions reached, any open questions. \
|
||
Be concrete with numbers and identifiers. Total length < 800 chars.\n\
|
||
\n\
|
||
OUTPUT FORMAT: PLAIN TEXT. Start each bullet with `- `. \
|
||
DO NOT output JSON. DO NOT wrap output in `{` or `}`. \
|
||
DO NOT use markdown fences. DO NOT include field names like `action`, `text`, `summary`. \
|
||
DO NOT add a preamble. The first character of your reply must be `-`.";
|
||
|
||
let llm_messages = vec![
|
||
OllamaChatMessage {
|
||
role: "system".to_string(),
|
||
content: system.to_string(),
|
||
},
|
||
OllamaChatMessage {
|
||
role: "user".to_string(),
|
||
content: convo,
|
||
},
|
||
];
|
||
let summary = call_chat_messages(&app, &state, llm_messages, None)
|
||
.await
|
||
.map_err(|e| TuskError::Ai(format!("Compact failed: {}", e)))?;
|
||
|
||
let cleaned = clean_summary(&summary);
|
||
let compacted_msg = ChatMessage::Assistant {
|
||
id: new_id("asst"),
|
||
text: format!(
|
||
"📋 Compacted {} earlier message{}:\n\n{}",
|
||
older.len(),
|
||
if older.len() == 1 { "" } else { "s" },
|
||
cleaned
|
||
),
|
||
created_at: now_ms(),
|
||
};
|
||
|
||
let mut out: Vec<ChatMessage> = Vec::with_capacity(1 + recent.len());
|
||
out.push(compacted_msg);
|
||
out.extend(recent.iter().cloned());
|
||
|
||
let usage = compute_usage(&state, &app, &connection_id, &out).await;
|
||
Ok(ChatTurnResult {
|
||
messages: out,
|
||
usage,
|
||
})
|
||
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// tests
|
||
// ---------------------------------------------------------------------------
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_agent_action ------------------------------------------------

    #[test]
    fn parses_flat_run_query() {
        let a = parse_agent_action(r#"{"action":"run_query","sql":"SELECT 1"}"#).unwrap();
        match a {
            AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 1"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn parses_nested_run_query() {
        let a =
            parse_agent_action(r#"{"action":"run_query","input":{"sql":"SELECT 2"}}"#).unwrap();
        match a {
            AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 2"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn parses_get_columns() {
        let a = parse_agent_action(
            r#"{"action":"get_columns","tables":["public.users","public.orders"]}"#,
        )
        .unwrap();
        match a {
            AgentAction::GetColumns { tables } => {
                assert_eq!(tables, vec!["public.users", "public.orders"]);
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn parses_get_columns_nested() {
        let a = parse_agent_action(
            r#"{"action":"get_columns","input":{"tables":["public.t"]}}"#,
        )
        .unwrap();
        match a {
            AgentAction::GetColumns { tables } => assert_eq!(tables, vec!["public.t"]),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn rejects_get_columns_empty_tables() {
        assert!(parse_agent_action(r#"{"action":"get_columns","tables":[]}"#).is_err());
    }

    #[test]
    fn parses_switch_database() {
        let a = parse_agent_action(r#"{"action":"switch_database","database":"orders_db"}"#)
            .unwrap();
        match a {
            AgentAction::SwitchDatabase { database } => assert_eq!(database, "orders_db"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn parses_list_tables_optional_db() {
        let a1 = parse_agent_action(r#"{"action":"list_tables"}"#).unwrap();
        match a1 {
            AgentAction::ListTables { database } => assert!(database.is_none()),
            _ => panic!("wrong variant"),
        }
        let a2 = parse_agent_action(r#"{"action":"list_tables","database":"x"}"#).unwrap();
        match a2 {
            AgentAction::ListTables { database } => assert_eq!(database.as_deref(), Some("x")),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn rejects_unknown_action() {
        assert!(parse_agent_action(r#"{"action":"nuke","yes":true}"#).is_err());
    }

    #[test]
    fn parses_remember_flat() {
        let a = parse_agent_action(
            r#"{"action":"remember","note":"trips.started_at is NULL for cancelled"}"#,
        )
        .unwrap();
        match a {
            AgentAction::Remember { note } => {
                assert_eq!(note, "trips.started_at is NULL for cancelled");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn parses_remember_nested() {
        let a = parse_agent_action(
            r#"{"action":"remember","input":{"note":" surrounded by spaces "}}"#,
        )
        .unwrap();
        match a {
            AgentAction::Remember { note } => {
                // trim happens in parser
                assert_eq!(note, "surrounded by spaces");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn rejects_remember_without_note() {
        assert!(parse_agent_action(r#"{"action":"remember"}"#).is_err());
    }

    #[test]
    fn rejects_remember_empty_note() {
        assert!(parse_agent_action(r#"{"action":"remember","note":" "}"#).is_err());
    }

    #[test]
    fn parses_save_query_flat() {
        let a = parse_agent_action(
            r#"{"action":"save_query","name":"GMV last 30d","sql":"SELECT 1"}"#,
        )
        .unwrap();
        match a {
            AgentAction::SaveQuery { name, sql } => {
                assert_eq!(name, "GMV last 30d");
                assert_eq!(sql, "SELECT 1");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn parses_save_query_nested() {
        let a = parse_agent_action(
            r#"{"action":"save_query","input":{"name":"x","sql":"SELECT 2"}}"#,
        )
        .unwrap();
        match a {
            AgentAction::SaveQuery { name, sql } => {
                assert_eq!(name, "x");
                assert_eq!(sql, "SELECT 2");
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn rejects_save_query_missing_fields() {
        assert!(parse_agent_action(r#"{"action":"save_query","name":"x"}"#).is_err());
        assert!(parse_agent_action(r#"{"action":"save_query","sql":"SELECT 1"}"#).is_err());
        assert!(
            parse_agent_action(r#"{"action":"save_query","name":" ","sql":"SELECT 1"}"#).is_err()
        );
    }

    #[test]
    fn parses_find_queries() {
        let a = parse_agent_action(r#"{"action":"find_queries","text":"gmv"}"#).unwrap();
        match a {
            AgentAction::FindQueries { text } => assert_eq!(text, "gmv"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn rejects_find_queries_empty_text() {
        assert!(parse_agent_action(r#"{"action":"find_queries","text":""}"#).is_err());
    }

    #[test]
    fn parses_make_chart_minimal() {
        let a = parse_agent_action(
            r#"{"action":"make_chart","chart_type":"bar","x":"carrier","y":"trips"}"#,
        )
        .unwrap();
        match a {
            AgentAction::MakeChart { config } => {
                assert_eq!(config.chart_type, "bar");
                assert_eq!(config.x, "carrier");
                assert_eq!(config.y, "trips");
                assert!(config.group.is_none());
                assert!(config.title.is_none());
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn parses_make_chart_with_group_and_title() {
        let a = parse_agent_action(
            r#"{"action":"make_chart","chart_type":"line","x":"month","y":"revenue","group":"region","title":"Revenue"}"#,
        )
        .unwrap();
        match a {
            AgentAction::MakeChart { config } => {
                assert_eq!(config.group.as_deref(), Some("region"));
                assert_eq!(config.title.as_deref(), Some("Revenue"));
            }
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn make_chart_accepts_alternative_field_name_type() {
        // Some models emit `type` instead of `chart_type`.
        let a = parse_agent_action(
            r#"{"action":"make_chart","type":"pie","x":"label","y":"value"}"#,
        )
        .unwrap();
        match a {
            AgentAction::MakeChart { config } => assert_eq!(config.chart_type, "pie"),
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn rejects_make_chart_with_unknown_chart_type() {
        let r = parse_agent_action(
            r#"{"action":"make_chart","chart_type":"radar","x":"a","y":"b"}"#,
        );
        assert!(r.is_err());
    }

    #[test]
    fn rejects_make_chart_missing_x_or_y() {
        assert!(
            parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","y":"a"}"#).is_err()
        );
        assert!(
            parse_agent_action(r#"{"action":"make_chart","chart_type":"bar","x":"a"}"#).is_err()
        );
    }

    #[test]
    fn rejects_legacy_get_schema() {
        assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err());
    }

    // ---- thread helpers ----------------------------------------------------

    #[test]
    fn last_user_turn_index_finds_last_user() {
        let msgs = vec![
            ChatMessage::User { id: "u1".into(), text: "first".into(), created_at: 1 },
            ChatMessage::Assistant { id: "a1".into(), text: "ans".into(), created_at: 2 },
            ChatMessage::User { id: "u2".into(), text: "second".into(), created_at: 3 },
            ChatMessage::Assistant { id: "a2".into(), text: "ans2".into(), created_at: 4 },
        ];
        assert_eq!(last_user_turn_index(&msgs), 2);
    }

    #[test]
    fn last_user_turn_index_returns_len_when_no_user() {
        let msgs = vec![ChatMessage::Assistant {
            id: "a1".into(),
            text: "alone".into(),
            created_at: 1,
        }];
        assert_eq!(last_user_turn_index(&msgs), msgs.len());
    }

    #[test]
    fn last_user_turn_index_empty_thread() {
        // Empty thread: len() == 0 is still a valid split point.
        assert_eq!(last_user_turn_index(&[]), 0);
    }

    #[test]
    fn last_run_query_error_finds_most_recent() {
        let msgs = vec![
            ChatMessage::ToolResult {
                id: "r1".into(),
                tool: "run_query".into(),
                is_error: true,
                text: Some("first error".into()),
                result: None,
                created_at: 1,
            },
            ChatMessage::ToolResult {
                id: "r2".into(),
                tool: "get_columns".into(),
                is_error: true,
                text: Some("not run_query".into()),
                result: None,
                created_at: 2,
            },
            ChatMessage::ToolResult {
                id: "r3".into(),
                tool: "run_query".into(),
                is_error: true,
                text: Some("second error".into()),
                result: None,
                created_at: 3,
            },
        ];
        assert_eq!(last_run_query_error(&msgs).as_deref(), Some("second error"));
    }

    #[test]
    fn last_run_query_error_none_for_empty_thread() {
        assert!(last_run_query_error(&[]).is_none());
    }

    #[test]
    fn last_run_query_error_none_when_no_errors() {
        let msgs = vec![ChatMessage::User {
            id: "u1".into(),
            text: "hi".into(),
            created_at: 1,
        }];
        assert!(last_run_query_error(&msgs).is_none());
    }

    #[test]
    fn last_successful_query_result_finds_recent() {
        // Two successes: the LATER one must win. The trailing error must be
        // skipped rather than ending the search.
        let older = QueryResult {
            columns: vec!["a".into()],
            types: vec!["INT4".into()],
            rows: vec![vec![Value::Number(1.into())]],
            row_count: 1,
            execution_time_ms: 1,
        };
        let newer = QueryResult {
            columns: vec!["b".into()],
            types: vec!["INT4".into()],
            rows: vec![vec![Value::Number(2.into())]],
            row_count: 1,
            execution_time_ms: 1,
        };
        let msgs = vec![
            ChatMessage::ToolResult {
                id: "r1".into(),
                tool: "run_query".into(),
                is_error: false,
                text: None,
                result: Some(older),
                created_at: 1,
            },
            ChatMessage::ToolResult {
                id: "r2".into(),
                tool: "run_query".into(),
                is_error: false,
                text: None,
                result: Some(newer),
                created_at: 2,
            },
            ChatMessage::ToolResult {
                id: "r3".into(),
                tool: "run_query".into(),
                is_error: true,
                text: Some("oops".into()),
                result: None,
                created_at: 3,
            },
        ];
        let found = last_successful_query_result(&msgs).expect("ok");
        assert_eq!(found.columns, vec!["b".to_string()]);
    }

    #[test]
    fn last_successful_query_result_skips_non_run_query() {
        let qr = QueryResult {
            columns: vec!["a".into()],
            types: vec!["INT4".into()],
            rows: vec![],
            row_count: 0,
            execution_time_ms: 0,
        };
        let msgs = vec![ChatMessage::ToolResult {
            id: "r1".into(),
            tool: "list_tables".into(),
            is_error: false,
            text: Some("public.x".into()),
            result: Some(qr),
            created_at: 1,
        }];
        assert!(last_successful_query_result(&msgs).is_none());
    }

    // ---- render_thread_for_summary -----------------------------------------

    #[test]
    fn render_thread_for_summary_includes_roles_and_skips_rows() {
        let msgs = vec![
            ChatMessage::User { id: "u1".into(), text: "find users".into(), created_at: 1 },
            ChatMessage::ToolCall {
                id: "c1".into(),
                tool: "run_query".into(),
                input_json: r#"{"sql":"SELECT 1"}"#.into(),
                created_at: 2,
            },
            ChatMessage::ToolResult {
                id: "r1".into(),
                tool: "run_query".into(),
                is_error: false,
                text: None,
                result: Some(QueryResult {
                    columns: vec!["id".into(), "name".into()],
                    types: vec!["INT4".into(), "TEXT".into()],
                    rows: vec![
                        vec![Value::Number(1.into()), Value::String("alice".into())];
                        1000
                    ],
                    row_count: 1000,
                    execution_time_ms: 12,
                }),
                created_at: 3,
            },
        ];
        let rendered = render_thread_for_summary(&msgs);
        assert!(rendered.contains("USER: find users"));
        assert!(rendered.contains("TOOL_CALL [run_query]"));
        assert!(rendered.contains("1000 rows"));
        // Must NOT include the actual rows
        assert!(!rendered.contains("alice"));
    }

    #[test]
    fn render_thread_for_summary_reports_tool_errors() {
        let msgs = vec![ChatMessage::ToolResult {
            id: "r1".into(),
            tool: "run_query".into(),
            is_error: true,
            text: Some("boom".into()),
            result: None,
            created_at: 1,
        }];
        assert_eq!(render_thread_for_summary(&msgs), "TOOL_ERROR [run_query]: boom\n\n");
    }

    #[test]
    fn render_thread_for_summary_caps_text_results_at_800_chars() {
        let msgs = vec![ChatMessage::ToolResult {
            id: "r1".into(),
            tool: "list_tables".into(),
            is_error: false,
            text: Some("x".repeat(1000)),
            result: None,
            created_at: 1,
        }];
        let rendered = render_thread_for_summary(&msgs);
        assert!(rendered.contains(&"x".repeat(800)));
        assert!(!rendered.contains(&"x".repeat(801)));
    }

    // ---- clean_summary -----------------------------------------------------

    #[test]
    fn clean_summary_passes_plain_text() {
        let s = "- bullet one\n- bullet two";
        assert_eq!(clean_summary(s), s);
    }

    #[test]
    fn clean_summary_empty_and_whitespace() {
        assert_eq!(clean_summary(""), "");
        assert_eq!(clean_summary("  \n\t "), "");
    }

    #[test]
    fn clean_summary_strips_markdown_fences() {
        let s = "```\n- bullet\n```";
        assert_eq!(clean_summary(s), "- bullet");
    }

    #[test]
    fn clean_summary_strips_lang_fence() {
        let s = "```text\n- bullet one\n- bullet two\n```";
        assert_eq!(clean_summary(s), "- bullet one\n- bullet two");
    }

    #[test]
    fn clean_summary_extracts_text_field_from_json_envelope() {
        let s = r#"{"action":"final","text":"- bullet one\n- bullet two"}"#;
        assert_eq!(clean_summary(s), "- bullet one\n- bullet two");
    }

    #[test]
    fn clean_summary_extracts_summary_field() {
        let s = r#"{"summary":"- a\n- b"}"#;
        assert_eq!(clean_summary(s), "- a\n- b");
    }

    #[test]
    fn clean_summary_returns_unchanged_for_unrecognised_json() {
        let s = r#"{"weird":42}"#;
        assert_eq!(clean_summary(s), s);
    }

    // ---- errors ------------------------------------------------------------

    #[test]
    fn format_db_error_falls_back_to_to_string_for_non_db_errors() {
        let e = TuskError::Custom("oops".into());
        assert_eq!(format_db_error(&e), e.to_string());
    }

    // ---- LLM-history compression -------------------------------------------

    #[test]
    fn truncates_long_cell() {
        let long = "a".repeat(CELL_CHAR_CAP + 50);
        let v = truncate_cell(&Value::String(long));
        let s = v.as_str().unwrap();
        assert!(s.ends_with('…'));
        assert!(s.chars().count() <= CELL_CHAR_CAP + 1);
    }

    #[test]
    fn compact_drops_rows_beyond_sample() {
        let rows: Vec<Vec<Value>> = (0..50).map(|i| vec![Value::Number(i.into())]).collect();
        let qr = QueryResult {
            columns: vec!["id".into()],
            types: vec!["INT4".into()],
            rows,
            row_count: 50,
            execution_time_ms: 1,
        };
        let v = compact_query_result(&qr);
        let sample = v.get("sample_rows").unwrap().as_array().unwrap();
        assert_eq!(sample.len(), RUN_QUERY_SAMPLE_ROWS);
        assert_eq!(v.get("truncated").unwrap(), &Value::Bool(true));
        assert_eq!(v.get("row_count").unwrap().as_u64().unwrap(), 50);
    }
}
|