//! Chat agent commands: `chat_send` runs a JSON-action tool loop against the
//! Ollama model (queries, schema lookups, memory, saved queries), and
//! `chat_compact` LLM-summarises older parts of a thread.

use crate::commands::ai::{build_overview_context, call_ollama_chat_messages};
use crate::commands::chat_tools::{
    find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
    switch_database_tool,
};
use crate::commands::memory::{append_memory_core, read_memory_core};
use crate::commands::queries::execute_query_core;
use crate::error::{TuskError, TuskResult};
use crate::models::ai::OllamaChatMessage;
use crate::models::chat::{ChatMessage, ChatTurnResult, ContextUsage};
use crate::models::query_result::QueryResult;
use crate::state::AppState;
use chrono::Utc;
use serde_json::Value;
use std::sync::Arc;
use tauri::{AppHandle, State};

/// Maximum number of model calls (one action each) per `chat_send` turn.
/// When no `final` action arrives within this many hops, the turn ends with
/// a "Stopped after … tool calls" assistant message.
const MAX_HOPS: usize = 8;

/// Number of MOST RECENT run_query tool_results that get full sample rows in
/// LLM history. Older ones are reduced to a marker so very long threads stay
/// within the model's context budget.
const RECENT_TOOL_RESULTS_FULL: usize = 4;

/// Sample-row cap for compressed run_query results in LLM history.
const RUN_QUERY_SAMPLE_ROWS: usize = 10;

/// Per-cell character cap when stringifying sample rows.
const CELL_CHAR_CAP: usize = 200;

/// Per text-tool-result character cap (list_tables, get_columns, etc).
const TEXT_TOOL_CHAR_CAP: usize = 10_000;

/// Soft cap on serialized history + system prompt characters before the user
/// is nudged to /compact. Tuned for Ollama defaults (~4-8K tokens).
/// Token estimate ≈ chars / 3 for mixed Cyrillic/ASCII content.
const CONTEXT_BUDGET_CHARS: u64 = 24_000; // --------------------------------------------------------------------------- // Action protocol // --------------------------------------------------------------------------- #[derive(Debug)] enum AgentAction { Final { text: String }, RunQuery { sql: String }, ListDatabases, ListTables { database: Option }, GetColumns { tables: Vec }, SwitchDatabase { database: String }, Remember { note: String }, SaveQuery { name: String, sql: String }, FindQueries { text: String }, } /// Parse the model's JSON response. Accepts both shapes the model tends to emit: /// {"action":"X","field":"..."} — flat (matches our prompt) /// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention) fn parse_agent_action(raw: &str) -> Result { let v: Value = serde_json::from_str(raw).map_err(|e| e.to_string())?; let obj = v.as_object().ok_or_else(|| "expected JSON object".to_string())?; let action = obj .get("action") .and_then(|a| a.as_str()) .ok_or_else(|| "missing field `action`".to_string())?; let lookup = |key: &str| -> Option<&Value> { obj.get(key) .or_else(|| obj.get("input").and_then(|i| i.as_object()).and_then(|i| i.get(key))) }; match action { "final" => { let text = lookup("text") .and_then(|v| v.as_str()) .ok_or_else(|| "final action missing `text`".to_string())? .to_string(); Ok(AgentAction::Final { text }) } "run_query" => { let sql = lookup("sql") .and_then(|v| v.as_str()) .ok_or_else(|| "run_query action missing `sql`".to_string())? 
.to_string(); Ok(AgentAction::RunQuery { sql }) } "list_databases" => Ok(AgentAction::ListDatabases), "list_tables" => { let database = lookup("database") .and_then(|v| v.as_str()) .map(|s| s.to_string()); Ok(AgentAction::ListTables { database }) } "get_columns" => { let arr = lookup("tables") .and_then(|v| v.as_array()) .ok_or_else(|| "get_columns action missing `tables`: [...]".to_string())?; let tables: Vec = arr .iter() .filter_map(|v| v.as_str().map(|s| s.to_string())) .collect(); if tables.is_empty() { return Err("get_columns `tables` array must not be empty".into()); } Ok(AgentAction::GetColumns { tables }) } "switch_database" => { let database = lookup("database") .and_then(|v| v.as_str()) .ok_or_else(|| "switch_database missing `database`".to_string())? .to_string(); Ok(AgentAction::SwitchDatabase { database }) } "remember" => { let note = lookup("note") .and_then(|v| v.as_str()) .ok_or_else(|| "remember action missing `note`".to_string())? .trim() .to_string(); if note.is_empty() { return Err("remember `note` must not be empty".into()); } Ok(AgentAction::Remember { note }) } "save_query" => { let name = lookup("name") .and_then(|v| v.as_str()) .ok_or_else(|| "save_query missing `name`".to_string())? .trim() .to_string(); let sql = lookup("sql") .and_then(|v| v.as_str()) .ok_or_else(|| "save_query missing `sql`".to_string())? .trim() .to_string(); if name.is_empty() { return Err("save_query `name` must not be empty".into()); } if sql.is_empty() { return Err("save_query `sql` must not be empty".into()); } Ok(AgentAction::SaveQuery { name, sql }) } "find_queries" => { let text = lookup("text") .and_then(|v| v.as_str()) .ok_or_else(|| "find_queries missing `text`".to_string())? .trim() .to_string(); if text.is_empty() { return Err("find_queries `text` must not be empty".into()); } Ok(AgentAction::FindQueries { text }) } // Legacy from earlier iterations — silently ignored at parse time so the // model can recover with a different action. 
"get_schema" => Err( "get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(), ), other => Err(format!("unknown action `{}`", other)), } } // --------------------------------------------------------------------------- // id / time helpers // --------------------------------------------------------------------------- fn now_ms() -> i64 { Utc::now().timestamp_millis() } fn new_id(prefix: &str) -> String { format!("{}-{}-{}", prefix, now_ms(), rand_suffix()) } fn rand_suffix() -> String { use std::time::{SystemTime, UNIX_EPOCH}; let nanos = SystemTime::now() .duration_since(UNIX_EPOCH) .map(|d| d.subsec_nanos()) .unwrap_or(0); format!("{:x}", nanos) } // --------------------------------------------------------------------------- // System prompt // --------------------------------------------------------------------------- fn system_prompt(overview: &str, memory: &str) -> String { let overview_block = if overview.is_empty() { "(overview unavailable; respond with `final` asking the user to reconnect.)".to_string() } else { overview.to_string() }; let memory_block = if memory.trim().is_empty() { "(empty — call remember() when you discover non-obvious facts about this database)".to_string() } else { memory.to_string() }; format!( r#"ROLE: Tusk's data assistant. Reply in the user's language. You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On every turn output STRICT JSON — exactly one of these shapes, with all fields at the root (no `input` wrapper, no markdown fences): {{"action":"list_databases"}} Refresh the database list when the OVERVIEW seems stale. {{"action":"list_tables"}} List tables in the active database. {{"action":"list_tables","database":""}} List tables in a specific database (PostgreSQL: requires switch_database before run_query). {{"action":"get_columns","tables":["schema.table","schema.table2"]}} Load full column info (types, PK, FK, comments, enums) for the listed tables. 
Use this BEFORE writing SQL when you don't already know the columns. {{"action":"switch_database","database":""}} Change the active database. Required for PostgreSQL when the user's question concerns data in another database. ClickHouse rarely needs this — `db.table` qualifiers are allowed without switching. {{"action":"run_query","sql":"SELECT ..."}} Execute read-only SQL (SELECT / WITH ... SELECT / EXPLAIN / SHOW / DESCRIBE). Mutating SQL is rejected by the read-only guard. {{"action":"remember","note":""}} Persist a non-obvious fact about THIS database for future sessions: column semantics, naming conventions, business-rule encodings, gotchas. Keep notes < 200 chars. The user sees and can edit your notes in the Memory sidebar tab. {{"action":"find_queries","text":""}} Search saved queries (your past work + user-saved). Use BEFORE writing complex SQL — a usable variant may already exist. Top 10 matches with SQL preview are returned. {{"action":"save_query","name":"","sql":""}} Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved. {{"action":"final","text":"..."}} End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling). WORKFLOW 1. Read LEARNED NOTES below first — the user (or your past self) may have already documented relevant facts. 2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question. 3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs. 4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns. 5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first. 6. Execute run_query. 7. 
If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here. 8. If the query is likely to be re-run later (a real report-style request, not a one-off lookup), call `save_query` with a concise `name`. 9. Answer with `final`. RULES - Use ONLY identifiers visible to you (overview / list_tables / get_columns output). Don't pluralize, translate, or guess. - LIMIT on ad-hoc SELECTs unless aggregating. - On SQL error retry once with a fix; on the second failure respond with `final` explaining what's missing. - `remember` is for durable facts, not transient observations. Don't memorise query results — only insights about the schema/data model that aren't already in the OVERVIEW. ═══════════════════════════════════════════════════════════════ LEARNED NOTES (per-connection memory; user can edit in sidebar → Memory) ═══════════════════════════════════════════════════════════════ {memory} ═══════════════════════════════════════════════════════════════ ═══════════════════════════════════════════════════════════════ OVERVIEW (refreshed every turn) ═══════════════════════════════════════════════════════════════ {overview} ═══════════════════════════════════════════════════════════════ "#, hops = MAX_HOPS, memory = memory_block, overview = overview_block, ) } // --------------------------------------------------------------------------- // Compressed history projection // --------------------------------------------------------------------------- /// Compact view of a QueryResult for re-injection into the LLM history. /// Keeps just enough for the model to reason about the next step (column /// names, types, total row count, first N rows) without the full payload. 
fn compact_query_result(result: &QueryResult) -> Value { let total = result.rows.len(); let sample: Vec> = result .rows .iter() .take(RUN_QUERY_SAMPLE_ROWS) .map(|row| row.iter().map(truncate_cell).collect()) .collect(); serde_json::json!({ "columns": result.columns, "types": result.types, "row_count": total, "execution_time_ms": result.execution_time_ms, "sample_rows": sample, "truncated": total > RUN_QUERY_SAMPLE_ROWS, }) } fn truncate_cell(v: &Value) -> Value { match v { Value::String(s) if s.chars().count() > CELL_CHAR_CAP => { let truncated: String = s.chars().take(CELL_CHAR_CAP).collect(); Value::String(format!("{}…", truncated)) } other => other.clone(), } } fn truncate_text(text: &str) -> String { if text.len() <= TEXT_TOOL_CHAR_CAP { text.to_string() } else { let mut out = text[..TEXT_TOOL_CHAR_CAP].to_string(); out.push_str("\n…(truncated)"); out } } fn build_history( messages: &[ChatMessage], overview_text: &str, memory_text: &str, ) -> Vec { // Index of run_query tool_results in `messages`. Used to mark which ones // get full sample rows vs the "(rows omitted)" placeholder. let run_query_indices: Vec = messages .iter() .enumerate() .filter_map(|(i, m)| match m { ChatMessage::ToolResult { tool, .. } if tool == "run_query" => Some(i), _ => None, }) .collect(); let keep_full_after_index: usize = if run_query_indices.len() <= RECENT_TOOL_RESULTS_FULL { 0 } else { run_query_indices[run_query_indices.len() - RECENT_TOOL_RESULTS_FULL] }; let mut out = Vec::with_capacity(messages.len() + 1); out.push(OllamaChatMessage { role: "system".to_string(), content: system_prompt(overview_text, memory_text), }); for (idx, m) in messages.iter().enumerate() { match m { ChatMessage::User { text, .. } => out.push(OllamaChatMessage { role: "user".to_string(), content: text.clone(), }), ChatMessage::Assistant { text, .. 
} => out.push(OllamaChatMessage { role: "assistant".to_string(), content: serde_json::json!({ "action": "final", "text": text }).to_string(), }), ChatMessage::ToolCall { tool, input_json, .. } => { if tool == "get_schema" { continue; // legacy } let mut envelope = serde_json::Map::new(); envelope.insert("action".to_string(), Value::String(tool.clone())); if let Ok(Value::Object(input)) = serde_json::from_str::(input_json) { for (k, v) in input { envelope.insert(k, v); } } out.push(OllamaChatMessage { role: "assistant".to_string(), content: Value::Object(envelope).to_string(), }); } ChatMessage::ToolResult { tool, is_error, text, result, .. } => { if tool == "get_schema" { continue; // legacy } let payload = match tool.as_str() { "run_query" => { if *is_error { serde_json::json!({ "tool": "run_query", "error": true, "text": text.clone().unwrap_or_default(), }) } else if idx < keep_full_after_index { serde_json::json!({ "tool": "run_query", "error": false, "note": "rows omitted (older result; user has it in the UI above)", }) } else if let Some(qr) = result { serde_json::json!({ "tool": "run_query", "error": false, "result": compact_query_result(qr), }) } else { serde_json::json!({ "tool": "run_query", "error": false, "result": null, }) } } // Text-only tools — pass through with cap. _ => serde_json::json!({ "tool": tool, "error": *is_error, "text": text.as_deref().map(truncate_text), }), }; out.push(OllamaChatMessage { role: "user".to_string(), content: format!("TOOL_RESULT {}", payload), }); } } } out } // --------------------------------------------------------------------------- // chat_send // --------------------------------------------------------------------------- /// Estimate how many characters the next LLM call will serialize to history /// (system prompt + conversation, after compression). This is the same data /// path as the actual call, so the count is exact for the chosen budget unit. 
async fn compute_usage( state: &AppState, app: &AppHandle, connection_id: &str, working: &[ChatMessage], ) -> ContextUsage { let overview = build_overview_context(state, connection_id) .await .unwrap_or_default(); let memory = read_memory_core(app, connection_id).unwrap_or_default(); let history = build_history(working, &overview, &memory); // role string ("system"/"user"/"assistant") ≤ 9 chars + content + JSON envelope overhead let used: u64 = history .iter() .map(|m| (m.role.len() + m.content.len() + 16) as u64) .sum(); ContextUsage { used_chars: used, budget_chars: CONTEXT_BUDGET_CHARS, } } #[tauri::command] pub async fn chat_send( app: AppHandle, state: State<'_, Arc>, connection_id: String, messages: Vec, ) -> TuskResult { let mut new_messages: Vec = Vec::new(); let mut working: Vec = messages; for _hop in 0..MAX_HOPS { // Overview is rebuilt per turn — cheap (cached) and reflects the active DB // even if the user (or the agent) just switched it. let overview_text = build_overview_context(&state, &connection_id) .await .unwrap_or_default(); // Memory is read fresh each turn so user-side edits in the Memory tab // are visible to the agent immediately. 
let memory_text = read_memory_core(&app, &connection_id).unwrap_or_default(); let history = build_history(&working, &overview_text, &memory_text); let raw = call_ollama_chat_messages(&app, &state, history, Some("json".to_string())).await?; let trimmed = raw.trim(); let action = match parse_agent_action(trimmed) { Ok(a) => a, Err(parse_err) => { let msg = ChatMessage::Assistant { id: new_id("asst"), text: format!( "{}\n\n_(Note: model returned non-protocol output: {})_", trimmed, parse_err ), created_at: now_ms(), }; new_messages.push(msg.clone()); working.push(msg); let usage = compute_usage(&state, &app, &connection_id, &working).await; return Ok(ChatTurnResult { messages: new_messages, usage, }); } }; match action { AgentAction::Final { text } => { let msg = ChatMessage::Assistant { id: new_id("asst"), text, created_at: now_ms(), }; new_messages.push(msg.clone()); working.push(msg); let usage = compute_usage(&state, &app, &connection_id, &working).await; return Ok(ChatTurnResult { messages: new_messages, usage, }); } AgentAction::RunQuery { sql } => { push_tool_call( &mut new_messages, &mut working, "run_query", serde_json::json!({ "sql": sql }).to_string(), ); let result = match execute_query_core(&state, &connection_id, &sql).await { Ok(qr) => ChatMessage::ToolResult { id: new_id("res"), tool: "run_query".to_string(), is_error: false, text: None, result: Some(qr), created_at: now_ms(), }, Err(e) => { let hint = match e { TuskError::ReadOnly => "\n\nRead-only mode is on. 
Toggle it off in the toolbar to allow writes.", _ => "", }; ChatMessage::ToolResult { id: new_id("res"), tool: "run_query".to_string(), is_error: true, text: Some(format!("{}{}", e, hint)), result: None, created_at: now_ms(), } } }; push_tool_result(&mut new_messages, &mut working, result); } AgentAction::ListDatabases => { push_tool_call( &mut new_messages, &mut working, "list_databases", "{}".to_string(), ); let result = run_text_tool( list_databases_tool(&state, &connection_id).await, "list_databases", ); push_tool_result(&mut new_messages, &mut working, result); } AgentAction::ListTables { database } => { let input_json = match &database { Some(db) => serde_json::json!({ "database": db }).to_string(), None => "{}".to_string(), }; push_tool_call(&mut new_messages, &mut working, "list_tables", input_json); let result = run_text_tool( list_tables_tool(&app, &state, &connection_id, database.as_deref()).await, "list_tables", ); push_tool_result(&mut new_messages, &mut working, result); } AgentAction::GetColumns { tables } => { push_tool_call( &mut new_messages, &mut working, "get_columns", serde_json::json!({ "tables": tables }).to_string(), ); let result = run_text_tool( get_columns_tool(&state, &connection_id, &tables).await, "get_columns", ); push_tool_result(&mut new_messages, &mut working, result); } AgentAction::SwitchDatabase { database } => { push_tool_call( &mut new_messages, &mut working, "switch_database", serde_json::json!({ "database": &database }).to_string(), ); let result = run_text_tool( switch_database_tool(&app, &state, &connection_id, &database).await, "switch_database", ); push_tool_result(&mut new_messages, &mut working, result); } AgentAction::Remember { note } => { push_tool_call( &mut new_messages, &mut working, "remember", serde_json::json!({ "note": ¬e }).to_string(), ); let outcome = append_memory_core(&app, &connection_id, ¬e) .map(|_| format!("Saved note ({} chars).", note.len())); let result = run_text_tool(outcome, "remember"); 
push_tool_result(&mut new_messages, &mut working, result); } AgentAction::SaveQuery { name, sql } => { push_tool_call( &mut new_messages, &mut working, "save_query", serde_json::json!({ "name": &name, "sql": &sql }).to_string(), ); let result = run_text_tool( save_query_tool(&app, &connection_id, &name, &sql).await, "save_query", ); push_tool_result(&mut new_messages, &mut working, result); } AgentAction::FindQueries { text } => { push_tool_call( &mut new_messages, &mut working, "find_queries", serde_json::json!({ "text": &text }).to_string(), ); let result = run_text_tool( find_queries_tool(&app, &connection_id, &text).await, "find_queries", ); push_tool_result(&mut new_messages, &mut working, result); } } } let msg = ChatMessage::Assistant { id: new_id("asst"), text: format!( "Stopped after {} tool calls without a final answer. Try rephrasing or simplifying the question.", MAX_HOPS ), created_at: now_ms(), }; new_messages.push(msg.clone()); working.push(msg); let usage = compute_usage(&state, &app, &connection_id, &working).await; Ok(ChatTurnResult { messages: new_messages, usage, }) } fn push_tool_call( new_messages: &mut Vec, working: &mut Vec, tool: &str, input_json: String, ) { let call = ChatMessage::ToolCall { id: new_id("call"), tool: tool.to_string(), input_json, created_at: now_ms(), }; new_messages.push(call.clone()); working.push(call); } fn push_tool_result( new_messages: &mut Vec, working: &mut Vec, result: ChatMessage, ) { new_messages.push(result.clone()); working.push(result); } fn run_text_tool(outcome: TuskResult, tool: &str) -> ChatMessage { match outcome { Ok(text) => ChatMessage::ToolResult { id: new_id("res"), tool: tool.to_string(), is_error: false, text: Some(text), result: None, created_at: now_ms(), }, Err(e) => ChatMessage::ToolResult { id: new_id("res"), tool: tool.to_string(), is_error: true, text: Some(e.to_string()), result: None, created_at: now_ms(), }, } } // 
--------------------------------------------------------------------------- // chat_compact // --------------------------------------------------------------------------- /// Render the older-history portion of the thread as a compact text block /// for LLM-driven summarization. Skips QueryResult.rows (huge), keeps only /// columns + row_count + sample. fn render_thread_for_summary(messages: &[ChatMessage]) -> String { let mut out = String::new(); for m in messages { match m { ChatMessage::User { text, .. } => { out.push_str(&format!("USER: {}\n\n", text)); } ChatMessage::Assistant { text, .. } => { out.push_str(&format!("ASSISTANT: {}\n\n", text)); } ChatMessage::ToolCall { tool, input_json, .. } => { out.push_str(&format!("TOOL_CALL [{}]: {}\n\n", tool, input_json)); } ChatMessage::ToolResult { tool, is_error, text, result, .. } => { if *is_error { out.push_str(&format!( "TOOL_ERROR [{}]: {}\n\n", tool, text.as_deref().unwrap_or("") )); continue; } if let Some(qr) = result { out.push_str(&format!( "TOOL_RESULT [{}]: {} rows; columns={}\n\n", tool, qr.row_count, qr.columns.join(", ") )); } else if let Some(t) = text { let snippet: String = t.chars().take(800).collect(); out.push_str(&format!("TOOL_RESULT [{}]: {}\n\n", tool, snippet)); } } } } out } /// Find the index of the last User message; returns messages.len() if no user message. fn last_user_turn_index(messages: &[ChatMessage]) -> usize { for (i, m) in messages.iter().enumerate().rev() { if matches!(m, ChatMessage::User { .. }) { return i; } } messages.len() } /// LLM-summarise the older portion of a chat thread. /// Returns thread = [ Assistant("📋 Compacted: …") , ]. /// If the thread has nothing to compact, returns it unchanged. 
#[tauri::command] pub async fn chat_compact( app: AppHandle, state: State<'_, Arc>, connection_id: String, messages: Vec, ) -> TuskResult { if messages.is_empty() { let usage = compute_usage(&state, &app, &connection_id, &messages).await; return Ok(ChatTurnResult { messages, usage }); } // Preserve the user's most recent question (if any) untouched so the // model can continue from it after compaction. Everything before goes // into the summary. let split_at = last_user_turn_index(&messages); let (older, recent): (&[ChatMessage], &[ChatMessage]) = if split_at == messages.len() { (&messages[..], &[]) } else { (&messages[..split_at], &messages[split_at..]) }; if older.is_empty() { let usage = compute_usage(&state, &app, &connection_id, &messages).await; return Ok(ChatTurnResult { messages, usage }); } let convo = render_thread_for_summary(older); let system = "You are a precise summarizer of a database analysis dialogue. \ Produce a SHORT summary in the SAME language the user spoke. \ Use 3-6 bullet points covering: the user's goal, key tables/columns/queries used, \ numerical findings, conclusions reached, any open questions. \ Be concrete with numbers and identifiers. Total length < 800 chars. 
\ Output the bullets directly with no preamble, no JSON, no markdown fences."; let llm_messages = vec![ OllamaChatMessage { role: "system".to_string(), content: system.to_string(), }, OllamaChatMessage { role: "user".to_string(), content: convo, }, ]; let summary = call_ollama_chat_messages(&app, &state, llm_messages, None) .await .map_err(|e| TuskError::Ai(format!("Compact failed: {}", e)))?; let cleaned = summary.trim(); let compacted_msg = ChatMessage::Assistant { id: new_id("asst"), text: format!( "📋 Compacted {} earlier message{}:\n\n{}", older.len(), if older.len() == 1 { "" } else { "s" }, cleaned ), created_at: now_ms(), }; let mut out: Vec = Vec::with_capacity(1 + recent.len()); out.push(compacted_msg); out.extend(recent.iter().cloned()); let usage = compute_usage(&state, &app, &connection_id, &out).await; Ok(ChatTurnResult { messages: out, usage, }) } // --------------------------------------------------------------------------- // tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; #[test] fn parses_flat_run_query() { let a = parse_agent_action(r#"{"action":"run_query","sql":"SELECT 1"}"#).unwrap(); match a { AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 1"), _ => panic!("wrong variant"), } } #[test] fn parses_nested_run_query() { let a = parse_agent_action(r#"{"action":"run_query","input":{"sql":"SELECT 2"}}"#).unwrap(); match a { AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 2"), _ => panic!("wrong variant"), } } #[test] fn parses_get_columns() { let a = parse_agent_action( r#"{"action":"get_columns","tables":["public.users","public.orders"]}"#, ) .unwrap(); match a { AgentAction::GetColumns { tables } => { assert_eq!(tables, vec!["public.users", "public.orders"]); } _ => panic!("wrong variant"), } } #[test] fn parses_get_columns_nested() { let a = parse_agent_action( r#"{"action":"get_columns","input":{"tables":["public.t"]}}"#, ) .unwrap(); match a { 
AgentAction::GetColumns { tables } => assert_eq!(tables, vec!["public.t"]), _ => panic!("wrong variant"), } } #[test] fn rejects_get_columns_empty_tables() { assert!(parse_agent_action(r#"{"action":"get_columns","tables":[]}"#).is_err()); } #[test] fn parses_switch_database() { let a = parse_agent_action(r#"{"action":"switch_database","database":"orders_db"}"#) .unwrap(); match a { AgentAction::SwitchDatabase { database } => assert_eq!(database, "orders_db"), _ => panic!("wrong variant"), } } #[test] fn parses_list_tables_optional_db() { let a1 = parse_agent_action(r#"{"action":"list_tables"}"#).unwrap(); match a1 { AgentAction::ListTables { database } => assert!(database.is_none()), _ => panic!("wrong variant"), } let a2 = parse_agent_action(r#"{"action":"list_tables","database":"x"}"#).unwrap(); match a2 { AgentAction::ListTables { database } => assert_eq!(database.as_deref(), Some("x")), _ => panic!("wrong variant"), } } #[test] fn rejects_unknown_action() { assert!(parse_agent_action(r#"{"action":"nuke","yes":true}"#).is_err()); } #[test] fn parses_remember_flat() { let a = parse_agent_action( r#"{"action":"remember","note":"trips.started_at is NULL for cancelled"}"#, ) .unwrap(); match a { AgentAction::Remember { note } => { assert_eq!(note, "trips.started_at is NULL for cancelled"); } _ => panic!("wrong variant"), } } #[test] fn parses_remember_nested() { let a = parse_agent_action( r#"{"action":"remember","input":{"note":" surrounded by spaces "}}"#, ) .unwrap(); match a { AgentAction::Remember { note } => { // trim happens in parser assert_eq!(note, "surrounded by spaces"); } _ => panic!("wrong variant"), } } #[test] fn rejects_remember_without_note() { assert!(parse_agent_action(r#"{"action":"remember"}"#).is_err()); } #[test] fn rejects_remember_empty_note() { assert!(parse_agent_action(r#"{"action":"remember","note":" "}"#).is_err()); } #[test] fn parses_save_query_flat() { let a = parse_agent_action( r#"{"action":"save_query","name":"GMV last 
30d","sql":"SELECT 1"}"#, ) .unwrap(); match a { AgentAction::SaveQuery { name, sql } => { assert_eq!(name, "GMV last 30d"); assert_eq!(sql, "SELECT 1"); } _ => panic!("wrong variant"), } } #[test] fn parses_save_query_nested() { let a = parse_agent_action( r#"{"action":"save_query","input":{"name":"x","sql":"SELECT 2"}}"#, ) .unwrap(); match a { AgentAction::SaveQuery { name, sql } => { assert_eq!(name, "x"); assert_eq!(sql, "SELECT 2"); } _ => panic!("wrong variant"), } } #[test] fn rejects_save_query_missing_fields() { assert!(parse_agent_action(r#"{"action":"save_query","name":"x"}"#).is_err()); assert!(parse_agent_action(r#"{"action":"save_query","sql":"SELECT 1"}"#).is_err()); assert!( parse_agent_action(r#"{"action":"save_query","name":" ","sql":"SELECT 1"}"#).is_err() ); } #[test] fn parses_find_queries() { let a = parse_agent_action(r#"{"action":"find_queries","text":"gmv"}"#).unwrap(); match a { AgentAction::FindQueries { text } => assert_eq!(text, "gmv"), _ => panic!("wrong variant"), } } #[test] fn rejects_find_queries_empty_text() { assert!(parse_agent_action(r#"{"action":"find_queries","text":""}"#).is_err()); } #[test] fn last_user_turn_index_finds_last_user() { let msgs = vec![ ChatMessage::User { id: "u1".into(), text: "first".into(), created_at: 1 }, ChatMessage::Assistant { id: "a1".into(), text: "ans".into(), created_at: 2 }, ChatMessage::User { id: "u2".into(), text: "second".into(), created_at: 3 }, ChatMessage::Assistant { id: "a2".into(), text: "ans2".into(), created_at: 4 }, ]; assert_eq!(last_user_turn_index(&msgs), 2); } #[test] fn last_user_turn_index_returns_len_when_no_user() { let msgs = vec![ChatMessage::Assistant { id: "a1".into(), text: "alone".into(), created_at: 1, }]; assert_eq!(last_user_turn_index(&msgs), msgs.len()); } #[test] fn render_thread_for_summary_includes_roles_and_skips_rows() { let msgs = vec![ ChatMessage::User { id: "u1".into(), text: "find users".into(), created_at: 1 }, ChatMessage::ToolCall { id: "c1".into(), 
tool: "run_query".into(), input_json: r#"{"sql":"SELECT 1"}"#.into(), created_at: 2 }, ChatMessage::ToolResult { id: "r1".into(), tool: "run_query".into(), is_error: false, text: None, result: Some(QueryResult { columns: vec!["id".into(), "name".into()], types: vec!["INT4".into(), "TEXT".into()], rows: vec![vec![Value::Number(1.into()), Value::String("alice".into())]; 1000], row_count: 1000, execution_time_ms: 12, }), created_at: 3, }, ]; let rendered = render_thread_for_summary(&msgs); assert!(rendered.contains("USER: find users")); assert!(rendered.contains("TOOL_CALL [run_query]")); assert!(rendered.contains("1000 rows")); // Must NOT include the actual rows assert!(!rendered.contains("alice")); } #[test] fn rejects_legacy_get_schema() { assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err()); } #[test] fn truncates_long_cell() { let long = "a".repeat(CELL_CHAR_CAP + 50); let v = truncate_cell(&Value::String(long)); let s = v.as_str().unwrap(); assert!(s.ends_with('…')); assert!(s.chars().count() <= CELL_CHAR_CAP + 1); } #[test] fn compact_drops_rows_beyond_sample() { let mut rows = Vec::new(); for i in 0..50 { rows.push(vec![Value::Number(i.into())]); } let qr = QueryResult { columns: vec!["id".into()], types: vec!["INT4".into()], rows, row_count: 50, execution_time_ms: 1, }; let v = compact_query_result(&qr); let sample = v.get("sample_rows").unwrap().as_array().unwrap(); assert_eq!(sample.len(), RUN_QUERY_SAMPLE_ROWS); assert_eq!(v.get("truncated").unwrap(), &Value::Bool(true)); assert_eq!(v.get("row_count").unwrap().as_u64().unwrap(), 50); } }