feat: rescope to AI-first DB harness with multi-DB chat agent
Removes enterprise/DBA features and replaces the marginal AI bar with a central chat agent that has progressive-discovery tools, cross-session memory, saved-query reuse, and inline result actions. Adds ClickHouse support alongside PostgreSQL/Greenplum. Cleanup - Drop ~10k LOC of advanced features: Docker, Snapshots, Validation, Index Advisor, Role/User Management, Data Generator, ERD, Lookup. - Trim deps: drop @xyflow/react, dagre, @types/dagre; cut tokio features to rt-multi-thread/sync/time/net/macros. - Remove unused TuskError variants and dead helpers (topological_sort, invalidate_schema_cache). Multi-DB (PostgreSQL + ClickHouse) - New src-tauri/src/db/ module: ChClient (HTTP-based, reuses reqwest), sql_guard (cross-flavor read-only whitelist with 8 tests). - ConnectionConfig gains db_flavor and secure fields with serde defaults for backwards-compatible connections.json. - All connection/query/schema/data commands dispatch by flavor; CH covers connect, execute_query, list_databases/schemas/tables/views/ columns/completion_schema, paginated table fetch. - Frontend: dbCapabilities matrix, ConnectionDialog engine selector with port auto-swap and HTTPS toggle, SqlEditor switches to StandardSQL dialect for CH, TableDataView surfaces CH connections as read-only. AI-first chat agent - New src/components/chat/ panel with composer, message rendering, collapsible tool-call/result blocks, top-level ErrorBoundary. - Backend agent loop in commands/chat.rs with strict-JSON tool protocol. Nine tools: list_databases, list_tables, get_columns, switch_database, run_query, remember, save_query, find_queries, final. Forgiving parser accepts both flat and nested-input shapes. - Compressed history: only the last 4 run_query results carry sample rows (≤10, cells truncated to 200 chars) into LLM context; older results marked omitted. - System prompt uses lite OVERVIEW (DB list + active-DB tables only) instead of full DDL — schema details are loaded on demand via get_columns. CH OVERVIEW shows cross-DB tables since CH allows db.table queries. Cross-session memory (F1) - Per-connection markdown file at app_data_dir/memory/<connection_id>.md, 16KB cap with oldest-block eviction. Agent appends via remember() tool; the file is injected into LEARNED NOTES section of every system prompt. - New Memory sidebar tab with editable textarea, badge for note count, empty-state with template. Edits picked up on the next agent turn. Saved-query reuse (F2) - Tools save_query and find_queries scoped to current connection. save_query attaches a UUID + timestamp; find_queries returns top 10 matches with SQL preview ≤500 chars. - Storage shared with the sidebar Saved panel. Inline result actions (F3) - run_query result block in chat gets Open-full (90vw × 80vh modal with full ResultsTable, no row cap) and Export (reuses ExportDialog for CSV/JSON via existing exportCsv/exportJson commands). Verification - cargo check clean, zero warnings. - cargo test --lib: 50 pass (20 chat parser + 4 memory + 8 sql_guard + 6 clean_sql + 12 escape_ident). - npx tsc --noEmit clean. - npx vitest run: 20 pass.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
869
src-tauri/src/commands/chat.rs
Normal file
869
src-tauri/src/commands/chat.rs
Normal file
@@ -0,0 +1,869 @@
|
||||
use crate::commands::ai::{build_overview_context, call_ollama_chat_messages};
|
||||
use crate::commands::chat_tools::{
|
||||
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
|
||||
switch_database_tool,
|
||||
};
|
||||
use crate::commands::memory::{append_memory_core, read_memory_core};
|
||||
use crate::commands::queries::execute_query_core;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::ai::OllamaChatMessage;
|
||||
use crate::models::chat::ChatMessage;
|
||||
use crate::models::query_result::QueryResult;
|
||||
use crate::state::AppState;
|
||||
use chrono::Utc;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tauri::{AppHandle, State};
|
||||
|
||||
const MAX_HOPS: usize = 8;
|
||||
/// Number of MOST RECENT run_query tool_results that get full sample-rows in
|
||||
/// LLM history. Older ones are reduced to a marker so very long threads stay
|
||||
/// within model context budget.
|
||||
const RECENT_TOOL_RESULTS_FULL: usize = 4;
|
||||
/// Sample-row cap for compressed run_query results in LLM history.
|
||||
const RUN_QUERY_SAMPLE_ROWS: usize = 10;
|
||||
/// Per-cell character cap when stringifying sample rows.
|
||||
const CELL_CHAR_CAP: usize = 200;
|
||||
/// Per text-tool-result character cap (list_tables, get_columns, etc).
|
||||
const TEXT_TOOL_CHAR_CAP: usize = 10_000;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Action protocol
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug)]
|
||||
enum AgentAction {
|
||||
Final { text: String },
|
||||
RunQuery { sql: String },
|
||||
ListDatabases,
|
||||
ListTables { database: Option<String> },
|
||||
GetColumns { tables: Vec<String> },
|
||||
SwitchDatabase { database: String },
|
||||
Remember { note: String },
|
||||
SaveQuery { name: String, sql: String },
|
||||
FindQueries { text: String },
|
||||
}
|
||||
|
||||
/// Parse the model's JSON response. Accepts both shapes the model tends to emit:
|
||||
/// {"action":"X","field":"..."} — flat (matches our prompt)
|
||||
/// {"action":"X","input":{"field":"..."}} — nested (common tool-use convention)
|
||||
fn parse_agent_action(raw: &str) -> Result<AgentAction, String> {
|
||||
let v: Value = serde_json::from_str(raw).map_err(|e| e.to_string())?;
|
||||
let obj = v.as_object().ok_or_else(|| "expected JSON object".to_string())?;
|
||||
let action = obj
|
||||
.get("action")
|
||||
.and_then(|a| a.as_str())
|
||||
.ok_or_else(|| "missing field `action`".to_string())?;
|
||||
|
||||
let lookup = |key: &str| -> Option<&Value> {
|
||||
obj.get(key)
|
||||
.or_else(|| obj.get("input").and_then(|i| i.as_object()).and_then(|i| i.get(key)))
|
||||
};
|
||||
|
||||
match action {
|
||||
"final" => {
|
||||
let text = lookup("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "final action missing `text`".to_string())?
|
||||
.to_string();
|
||||
Ok(AgentAction::Final { text })
|
||||
}
|
||||
"run_query" => {
|
||||
let sql = lookup("sql")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "run_query action missing `sql`".to_string())?
|
||||
.to_string();
|
||||
Ok(AgentAction::RunQuery { sql })
|
||||
}
|
||||
"list_databases" => Ok(AgentAction::ListDatabases),
|
||||
"list_tables" => {
|
||||
let database = lookup("database")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
Ok(AgentAction::ListTables { database })
|
||||
}
|
||||
"get_columns" => {
|
||||
let arr = lookup("tables")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| "get_columns action missing `tables`: [...]".to_string())?;
|
||||
let tables: Vec<String> = arr
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str().map(|s| s.to_string()))
|
||||
.collect();
|
||||
if tables.is_empty() {
|
||||
return Err("get_columns `tables` array must not be empty".into());
|
||||
}
|
||||
Ok(AgentAction::GetColumns { tables })
|
||||
}
|
||||
"switch_database" => {
|
||||
let database = lookup("database")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "switch_database missing `database`".to_string())?
|
||||
.to_string();
|
||||
Ok(AgentAction::SwitchDatabase { database })
|
||||
}
|
||||
"remember" => {
|
||||
let note = lookup("note")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "remember action missing `note`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
if note.is_empty() {
|
||||
return Err("remember `note` must not be empty".into());
|
||||
}
|
||||
Ok(AgentAction::Remember { note })
|
||||
}
|
||||
"save_query" => {
|
||||
let name = lookup("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "save_query missing `name`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
let sql = lookup("sql")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "save_query missing `sql`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
if name.is_empty() {
|
||||
return Err("save_query `name` must not be empty".into());
|
||||
}
|
||||
if sql.is_empty() {
|
||||
return Err("save_query `sql` must not be empty".into());
|
||||
}
|
||||
Ok(AgentAction::SaveQuery { name, sql })
|
||||
}
|
||||
"find_queries" => {
|
||||
let text = lookup("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| "find_queries missing `text`".to_string())?
|
||||
.trim()
|
||||
.to_string();
|
||||
if text.is_empty() {
|
||||
return Err("find_queries `text` must not be empty".into());
|
||||
}
|
||||
Ok(AgentAction::FindQueries { text })
|
||||
}
|
||||
// Legacy from earlier iterations — silently ignored at parse time so the
|
||||
// model can recover with a different action.
|
||||
"get_schema" => Err(
|
||||
"get_schema is deprecated; use get_columns({\"tables\":[...]}) instead.".to_string(),
|
||||
),
|
||||
other => Err(format!("unknown action `{}`", other)),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// id / time helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn now_ms() -> i64 {
|
||||
Utc::now().timestamp_millis()
|
||||
}
|
||||
|
||||
fn new_id(prefix: &str) -> String {
|
||||
format!("{}-{}-{}", prefix, now_ms(), rand_suffix())
|
||||
}
|
||||
|
||||
fn rand_suffix() -> String {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let nanos = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|d| d.subsec_nanos())
|
||||
.unwrap_or(0);
|
||||
format!("{:x}", nanos)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// System prompt
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn system_prompt(overview: &str, memory: &str) -> String {
|
||||
let overview_block = if overview.is_empty() {
|
||||
"(overview unavailable; respond with `final` asking the user to reconnect.)".to_string()
|
||||
} else {
|
||||
overview.to_string()
|
||||
};
|
||||
|
||||
let memory_block = if memory.trim().is_empty() {
|
||||
"(empty — call remember() when you discover non-obvious facts about this database)".to_string()
|
||||
} else {
|
||||
memory.to_string()
|
||||
};
|
||||
|
||||
format!(
|
||||
r#"ROLE: Tusk's data assistant. Reply in the user's language.
|
||||
|
||||
You operate as an agent in a single-tool-per-turn loop with hop limit {hops}. On every turn output STRICT JSON — exactly one of these shapes, with all fields at the root (no `input` wrapper, no markdown fences):
|
||||
|
||||
{{"action":"list_databases"}}
|
||||
Refresh the database list when the OVERVIEW seems stale.
|
||||
|
||||
{{"action":"list_tables"}}
|
||||
List tables in the active database.
|
||||
|
||||
{{"action":"list_tables","database":"<name>"}}
|
||||
List tables in a specific database (PostgreSQL: requires switch_database before run_query).
|
||||
|
||||
{{"action":"get_columns","tables":["schema.table","schema.table2"]}}
|
||||
Load full column info (types, PK, FK, comments, enums) for the listed tables. Use this BEFORE writing SQL when you don't already know the columns.
|
||||
|
||||
{{"action":"switch_database","database":"<name>"}}
|
||||
Change the active database. Required for PostgreSQL when the user's question concerns data in another database. ClickHouse rarely needs this — `db.table` qualifiers are allowed without switching.
|
||||
|
||||
{{"action":"run_query","sql":"SELECT ..."}}
|
||||
Execute read-only SQL (SELECT / WITH ... SELECT / EXPLAIN / SHOW / DESCRIBE). Mutating SQL is rejected by the read-only guard.
|
||||
|
||||
{{"action":"remember","note":"<short observation>"}}
|
||||
Persist a non-obvious fact about THIS database for future sessions: column semantics, naming conventions, business-rule encodings, gotchas. Keep notes < 200 chars. The user sees and can edit your notes in the Memory sidebar tab.
|
||||
|
||||
{{"action":"find_queries","text":"<keywords>"}}
|
||||
Search saved queries (your past work + user-saved). Use BEFORE writing complex SQL — a usable variant may already exist. Top 10 matches with SQL preview are returned.
|
||||
|
||||
{{"action":"save_query","name":"<short label>","sql":"<the SQL>"}}
|
||||
Persist a non-trivial working SELECT for reuse later. Use AFTER a successful run_query when the query is likely to be re-run. Keep `name` short and descriptive (e.g. "GMV by carrier — last 30d"). The user sees these in sidebar → Saved.
|
||||
|
||||
{{"action":"final","text":"..."}}
|
||||
End the turn with a plain-language answer for the user. Do NOT repeat the result table — the UI shows it. Mention caveats (LIMIT, NULL filters, sampling).
|
||||
|
||||
WORKFLOW
|
||||
1. Read LEARNED NOTES below first — the user (or your past self) may have already documented relevant facts.
|
||||
2. For non-trivial requests, run `find_queries({{text}})` once to check if a saved query already answers the question.
|
||||
3. Pick candidate tables from the OVERVIEW (active DB) or call list_tables if you need other DBs.
|
||||
4. If a candidate's columns are unknown, call get_columns FIRST. NEVER invent columns.
|
||||
5. If the user's data lives in a different DB and engine is PostgreSQL, switch_database first.
|
||||
6. Execute run_query.
|
||||
7. If you discovered something non-obvious (semantics, gotcha, business rule that isn't visible from the schema alone), call `remember` BEFORE `final`. Future sessions will see your notes here.
|
||||
8. If the query is likely to be re-run later (a real report-style request, not a one-off lookup), call `save_query` with a concise `name`.
|
||||
9. Answer with `final`.
|
||||
|
||||
RULES
|
||||
- Use ONLY identifiers visible to you (overview / list_tables / get_columns output). Don't pluralize, translate, or guess.
|
||||
- LIMIT on ad-hoc SELECTs unless aggregating.
|
||||
- On SQL error retry once with a fix; on the second failure respond with `final` explaining what's missing.
|
||||
- `remember` is for durable facts, not transient observations. Don't memorise query results — only insights about the schema/data model that aren't already in the OVERVIEW.
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
LEARNED NOTES (per-connection memory; user can edit in sidebar → Memory)
|
||||
═══════════════════════════════════════════════════════════════
|
||||
{memory}
|
||||
═══════════════════════════════════════════════════════════════
|
||||
|
||||
═══════════════════════════════════════════════════════════════
|
||||
OVERVIEW (refreshed every turn)
|
||||
═══════════════════════════════════════════════════════════════
|
||||
{overview}
|
||||
═══════════════════════════════════════════════════════════════
|
||||
"#,
|
||||
hops = MAX_HOPS,
|
||||
memory = memory_block,
|
||||
overview = overview_block,
|
||||
)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compressed history projection
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compact view of a QueryResult for re-injection into the LLM history.
|
||||
/// Keeps just enough for the model to reason about the next step (column
|
||||
/// names, types, total row count, first N rows) without the full payload.
|
||||
fn compact_query_result(result: &QueryResult) -> Value {
|
||||
let total = result.rows.len();
|
||||
let sample: Vec<Vec<Value>> = result
|
||||
.rows
|
||||
.iter()
|
||||
.take(RUN_QUERY_SAMPLE_ROWS)
|
||||
.map(|row| row.iter().map(truncate_cell).collect())
|
||||
.collect();
|
||||
serde_json::json!({
|
||||
"columns": result.columns,
|
||||
"types": result.types,
|
||||
"row_count": total,
|
||||
"execution_time_ms": result.execution_time_ms,
|
||||
"sample_rows": sample,
|
||||
"truncated": total > RUN_QUERY_SAMPLE_ROWS,
|
||||
})
|
||||
}
|
||||
|
||||
fn truncate_cell(v: &Value) -> Value {
|
||||
match v {
|
||||
Value::String(s) if s.chars().count() > CELL_CHAR_CAP => {
|
||||
let truncated: String = s.chars().take(CELL_CHAR_CAP).collect();
|
||||
Value::String(format!("{}…", truncated))
|
||||
}
|
||||
other => other.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn truncate_text(text: &str) -> String {
|
||||
if text.len() <= TEXT_TOOL_CHAR_CAP {
|
||||
text.to_string()
|
||||
} else {
|
||||
let mut out = text[..TEXT_TOOL_CHAR_CAP].to_string();
|
||||
out.push_str("\n…(truncated)");
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
fn build_history(
|
||||
messages: &[ChatMessage],
|
||||
overview_text: &str,
|
||||
memory_text: &str,
|
||||
) -> Vec<OllamaChatMessage> {
|
||||
// Index of run_query tool_results in `messages`. Used to mark which ones
|
||||
// get full sample rows vs the "(rows omitted)" placeholder.
|
||||
let run_query_indices: Vec<usize> = messages
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, m)| match m {
|
||||
ChatMessage::ToolResult { tool, .. } if tool == "run_query" => Some(i),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
let keep_full_after_index: usize = if run_query_indices.len() <= RECENT_TOOL_RESULTS_FULL {
|
||||
0
|
||||
} else {
|
||||
run_query_indices[run_query_indices.len() - RECENT_TOOL_RESULTS_FULL]
|
||||
};
|
||||
|
||||
let mut out = Vec::with_capacity(messages.len() + 1);
|
||||
out.push(OllamaChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: system_prompt(overview_text, memory_text),
|
||||
});
|
||||
|
||||
for (idx, m) in messages.iter().enumerate() {
|
||||
match m {
|
||||
ChatMessage::User { text, .. } => out.push(OllamaChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: text.clone(),
|
||||
}),
|
||||
ChatMessage::Assistant { text, .. } => out.push(OllamaChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: serde_json::json!({ "action": "final", "text": text }).to_string(),
|
||||
}),
|
||||
ChatMessage::ToolCall { tool, input_json, .. } => {
|
||||
if tool == "get_schema" {
|
||||
continue; // legacy
|
||||
}
|
||||
let mut envelope = serde_json::Map::new();
|
||||
envelope.insert("action".to_string(), Value::String(tool.clone()));
|
||||
if let Ok(Value::Object(input)) = serde_json::from_str::<Value>(input_json) {
|
||||
for (k, v) in input {
|
||||
envelope.insert(k, v);
|
||||
}
|
||||
}
|
||||
out.push(OllamaChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: Value::Object(envelope).to_string(),
|
||||
});
|
||||
}
|
||||
ChatMessage::ToolResult {
|
||||
tool,
|
||||
is_error,
|
||||
text,
|
||||
result,
|
||||
..
|
||||
} => {
|
||||
if tool == "get_schema" {
|
||||
continue; // legacy
|
||||
}
|
||||
let payload = match tool.as_str() {
|
||||
"run_query" => {
|
||||
if *is_error {
|
||||
serde_json::json!({
|
||||
"tool": "run_query",
|
||||
"error": true,
|
||||
"text": text.clone().unwrap_or_default(),
|
||||
})
|
||||
} else if idx < keep_full_after_index {
|
||||
serde_json::json!({
|
||||
"tool": "run_query",
|
||||
"error": false,
|
||||
"note": "rows omitted (older result; user has it in the UI above)",
|
||||
})
|
||||
} else if let Some(qr) = result {
|
||||
serde_json::json!({
|
||||
"tool": "run_query",
|
||||
"error": false,
|
||||
"result": compact_query_result(qr),
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"tool": "run_query",
|
||||
"error": false,
|
||||
"result": null,
|
||||
})
|
||||
}
|
||||
}
|
||||
// Text-only tools — pass through with cap.
|
||||
_ => serde_json::json!({
|
||||
"tool": tool,
|
||||
"error": *is_error,
|
||||
"text": text.as_deref().map(truncate_text),
|
||||
}),
|
||||
};
|
||||
|
||||
out.push(OllamaChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: format!("TOOL_RESULT {}", payload),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// chat_send
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn chat_send(
|
||||
app: AppHandle,
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
messages: Vec<ChatMessage>,
|
||||
) -> TuskResult<Vec<ChatMessage>> {
|
||||
let mut new_messages: Vec<ChatMessage> = Vec::new();
|
||||
let mut working: Vec<ChatMessage> = messages;
|
||||
|
||||
for _hop in 0..MAX_HOPS {
|
||||
// Overview is rebuilt per turn — cheap (cached) and reflects the active DB
|
||||
// even if the user (or the agent) just switched it.
|
||||
let overview_text = build_overview_context(&state, &connection_id)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
// Memory is read fresh each turn so user-side edits in the Memory tab
|
||||
// are visible to the agent immediately.
|
||||
let memory_text = read_memory_core(&app, &connection_id).unwrap_or_default();
|
||||
|
||||
let history = build_history(&working, &overview_text, &memory_text);
|
||||
let raw =
|
||||
call_ollama_chat_messages(&app, &state, history, Some("json".to_string())).await?;
|
||||
let trimmed = raw.trim();
|
||||
|
||||
let action = match parse_agent_action(trimmed) {
|
||||
Ok(a) => a,
|
||||
Err(parse_err) => {
|
||||
let msg = ChatMessage::Assistant {
|
||||
id: new_id("asst"),
|
||||
text: format!(
|
||||
"{}\n\n_(Note: model returned non-protocol output: {})_",
|
||||
trimmed, parse_err
|
||||
),
|
||||
created_at: now_ms(),
|
||||
};
|
||||
new_messages.push(msg.clone());
|
||||
working.push(msg);
|
||||
return Ok(new_messages);
|
||||
}
|
||||
};
|
||||
|
||||
match action {
|
||||
AgentAction::Final { text } => {
|
||||
let msg = ChatMessage::Assistant {
|
||||
id: new_id("asst"),
|
||||
text,
|
||||
created_at: now_ms(),
|
||||
};
|
||||
new_messages.push(msg.clone());
|
||||
working.push(msg);
|
||||
return Ok(new_messages);
|
||||
}
|
||||
AgentAction::RunQuery { sql } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"run_query",
|
||||
serde_json::json!({ "sql": sql }).to_string(),
|
||||
);
|
||||
let result = match execute_query_core(&state, &connection_id, &sql).await {
|
||||
Ok(qr) => ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "run_query".to_string(),
|
||||
is_error: false,
|
||||
text: None,
|
||||
result: Some(qr),
|
||||
created_at: now_ms(),
|
||||
},
|
||||
Err(e) => {
|
||||
let hint = match e {
|
||||
TuskError::ReadOnly => "\n\nRead-only mode is on. Toggle it off in the toolbar to allow writes.",
|
||||
_ => "",
|
||||
};
|
||||
ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: "run_query".to_string(),
|
||||
is_error: true,
|
||||
text: Some(format!("{}{}", e, hint)),
|
||||
result: None,
|
||||
created_at: now_ms(),
|
||||
}
|
||||
}
|
||||
};
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::ListDatabases => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"list_databases",
|
||||
"{}".to_string(),
|
||||
);
|
||||
let result = run_text_tool(
|
||||
list_databases_tool(&state, &connection_id).await,
|
||||
"list_databases",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::ListTables { database } => {
|
||||
let input_json = match &database {
|
||||
Some(db) => serde_json::json!({ "database": db }).to_string(),
|
||||
None => "{}".to_string(),
|
||||
};
|
||||
push_tool_call(&mut new_messages, &mut working, "list_tables", input_json);
|
||||
let result = run_text_tool(
|
||||
list_tables_tool(&app, &state, &connection_id, database.as_deref()).await,
|
||||
"list_tables",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::GetColumns { tables } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"get_columns",
|
||||
serde_json::json!({ "tables": tables }).to_string(),
|
||||
);
|
||||
let result = run_text_tool(
|
||||
get_columns_tool(&state, &connection_id, &tables).await,
|
||||
"get_columns",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::SwitchDatabase { database } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"switch_database",
|
||||
serde_json::json!({ "database": &database }).to_string(),
|
||||
);
|
||||
let result = run_text_tool(
|
||||
switch_database_tool(&app, &state, &connection_id, &database).await,
|
||||
"switch_database",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::Remember { note } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"remember",
|
||||
serde_json::json!({ "note": ¬e }).to_string(),
|
||||
);
|
||||
let outcome = append_memory_core(&app, &connection_id, ¬e)
|
||||
.map(|_| format!("Saved note ({} chars).", note.len()));
|
||||
let result = run_text_tool(outcome, "remember");
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::SaveQuery { name, sql } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"save_query",
|
||||
serde_json::json!({ "name": &name, "sql": &sql }).to_string(),
|
||||
);
|
||||
let result = run_text_tool(
|
||||
save_query_tool(&app, &connection_id, &name, &sql).await,
|
||||
"save_query",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
AgentAction::FindQueries { text } => {
|
||||
push_tool_call(
|
||||
&mut new_messages,
|
||||
&mut working,
|
||||
"find_queries",
|
||||
serde_json::json!({ "text": &text }).to_string(),
|
||||
);
|
||||
let result = run_text_tool(
|
||||
find_queries_tool(&app, &connection_id, &text).await,
|
||||
"find_queries",
|
||||
);
|
||||
push_tool_result(&mut new_messages, &mut working, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let msg = ChatMessage::Assistant {
|
||||
id: new_id("asst"),
|
||||
text: format!(
|
||||
"Stopped after {} tool calls without a final answer. Try rephrasing or simplifying the question.",
|
||||
MAX_HOPS
|
||||
),
|
||||
created_at: now_ms(),
|
||||
};
|
||||
new_messages.push(msg);
|
||||
Ok(new_messages)
|
||||
}
|
||||
|
||||
fn push_tool_call(
|
||||
new_messages: &mut Vec<ChatMessage>,
|
||||
working: &mut Vec<ChatMessage>,
|
||||
tool: &str,
|
||||
input_json: String,
|
||||
) {
|
||||
let call = ChatMessage::ToolCall {
|
||||
id: new_id("call"),
|
||||
tool: tool.to_string(),
|
||||
input_json,
|
||||
created_at: now_ms(),
|
||||
};
|
||||
new_messages.push(call.clone());
|
||||
working.push(call);
|
||||
}
|
||||
|
||||
fn push_tool_result(
|
||||
new_messages: &mut Vec<ChatMessage>,
|
||||
working: &mut Vec<ChatMessage>,
|
||||
result: ChatMessage,
|
||||
) {
|
||||
new_messages.push(result.clone());
|
||||
working.push(result);
|
||||
}
|
||||
|
||||
fn run_text_tool(outcome: TuskResult<String>, tool: &str) -> ChatMessage {
|
||||
match outcome {
|
||||
Ok(text) => ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: tool.to_string(),
|
||||
is_error: false,
|
||||
text: Some(text),
|
||||
result: None,
|
||||
created_at: now_ms(),
|
||||
},
|
||||
Err(e) => ChatMessage::ToolResult {
|
||||
id: new_id("res"),
|
||||
tool: tool.to_string(),
|
||||
is_error: true,
|
||||
text: Some(e.to_string()),
|
||||
result: None,
|
||||
created_at: now_ms(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parses_flat_run_query() {
|
||||
let a = parse_agent_action(r#"{"action":"run_query","sql":"SELECT 1"}"#).unwrap();
|
||||
match a {
|
||||
AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 1"),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_nested_run_query() {
|
||||
let a =
|
||||
parse_agent_action(r#"{"action":"run_query","input":{"sql":"SELECT 2"}}"#).unwrap();
|
||||
match a {
|
||||
AgentAction::RunQuery { sql } => assert_eq!(sql, "SELECT 2"),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_get_columns() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"get_columns","tables":["public.users","public.orders"]}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::GetColumns { tables } => {
|
||||
assert_eq!(tables, vec!["public.users", "public.orders"]);
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_get_columns_nested() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"get_columns","input":{"tables":["public.t"]}}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::GetColumns { tables } => assert_eq!(tables, vec!["public.t"]),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_get_columns_empty_tables() {
|
||||
assert!(parse_agent_action(r#"{"action":"get_columns","tables":[]}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_switch_database() {
|
||||
let a = parse_agent_action(r#"{"action":"switch_database","database":"orders_db"}"#)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::SwitchDatabase { database } => assert_eq!(database, "orders_db"),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_list_tables_optional_db() {
|
||||
let a1 = parse_agent_action(r#"{"action":"list_tables"}"#).unwrap();
|
||||
match a1 {
|
||||
AgentAction::ListTables { database } => assert!(database.is_none()),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
let a2 = parse_agent_action(r#"{"action":"list_tables","database":"x"}"#).unwrap();
|
||||
match a2 {
|
||||
AgentAction::ListTables { database } => assert_eq!(database.as_deref(), Some("x")),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_unknown_action() {
|
||||
assert!(parse_agent_action(r#"{"action":"nuke","yes":true}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_remember_flat() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"remember","note":"trips.started_at is NULL for cancelled"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::Remember { note } => {
|
||||
assert_eq!(note, "trips.started_at is NULL for cancelled");
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_remember_nested() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"remember","input":{"note":" surrounded by spaces "}}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::Remember { note } => {
|
||||
// trim happens in parser
|
||||
assert_eq!(note, "surrounded by spaces");
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_remember_without_note() {
|
||||
assert!(parse_agent_action(r#"{"action":"remember"}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_remember_empty_note() {
|
||||
assert!(parse_agent_action(r#"{"action":"remember","note":" "}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_save_query_flat() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"save_query","name":"GMV last 30d","sql":"SELECT 1"}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::SaveQuery { name, sql } => {
|
||||
assert_eq!(name, "GMV last 30d");
|
||||
assert_eq!(sql, "SELECT 1");
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_save_query_nested() {
|
||||
let a = parse_agent_action(
|
||||
r#"{"action":"save_query","input":{"name":"x","sql":"SELECT 2"}}"#,
|
||||
)
|
||||
.unwrap();
|
||||
match a {
|
||||
AgentAction::SaveQuery { name, sql } => {
|
||||
assert_eq!(name, "x");
|
||||
assert_eq!(sql, "SELECT 2");
|
||||
}
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_save_query_missing_fields() {
|
||||
assert!(parse_agent_action(r#"{"action":"save_query","name":"x"}"#).is_err());
|
||||
assert!(parse_agent_action(r#"{"action":"save_query","sql":"SELECT 1"}"#).is_err());
|
||||
assert!(
|
||||
parse_agent_action(r#"{"action":"save_query","name":" ","sql":"SELECT 1"}"#).is_err()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_find_queries() {
|
||||
let a = parse_agent_action(r#"{"action":"find_queries","text":"gmv"}"#).unwrap();
|
||||
match a {
|
||||
AgentAction::FindQueries { text } => assert_eq!(text, "gmv"),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_find_queries_empty_text() {
|
||||
assert!(parse_agent_action(r#"{"action":"find_queries","text":""}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_legacy_get_schema() {
|
||||
assert!(parse_agent_action(r#"{"action":"get_schema"}"#).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_long_cell() {
|
||||
let long = "a".repeat(CELL_CHAR_CAP + 50);
|
||||
let v = truncate_cell(&Value::String(long));
|
||||
let s = v.as_str().unwrap();
|
||||
assert!(s.ends_with('…'));
|
||||
assert!(s.chars().count() <= CELL_CHAR_CAP + 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compact_drops_rows_beyond_sample() {
|
||||
let mut rows = Vec::new();
|
||||
for i in 0..50 {
|
||||
rows.push(vec![Value::Number(i.into())]);
|
||||
}
|
||||
let qr = QueryResult {
|
||||
columns: vec!["id".into()],
|
||||
types: vec!["INT4".into()],
|
||||
rows,
|
||||
row_count: 50,
|
||||
execution_time_ms: 1,
|
||||
};
|
||||
let v = compact_query_result(&qr);
|
||||
let sample = v.get("sample_rows").unwrap().as_array().unwrap();
|
||||
assert_eq!(sample.len(), RUN_QUERY_SAMPLE_ROWS);
|
||||
assert_eq!(v.get("truncated").unwrap(), &Value::Bool(true));
|
||||
assert_eq!(v.get("row_count").unwrap().as_u64().unwrap(), 50);
|
||||
}
|
||||
}
|
||||
558
src-tauri/src/commands/chat_tools.rs
Normal file
558
src-tauri/src/commands/chat_tools.rs
Normal file
@@ -0,0 +1,558 @@
|
||||
//! Chat agent tool handlers (chat v2).
|
||||
//!
|
||||
//! Each `*_tool` function returns a plain string formatted for direct injection
|
||||
//! into the LLM tool-result history. They reuse the schema helpers in
|
||||
//! `commands::ai` and `commands::schema` rather than re-implementing SQL.
|
||||
|
||||
use crate::commands::ai::{
|
||||
fetch_column_comments, fetch_columns, fetch_enum_types, fetch_foreign_keys_raw,
|
||||
fetch_table_comments, fetch_unique_constraints, format_table_block, ColumnInfo,
|
||||
};
|
||||
use crate::commands::connections::{load_connection_config, switch_database_core};
|
||||
use crate::commands::saved_queries::{list_saved_queries_core, save_query_core};
|
||||
use crate::commands::schema::{list_databases_core, list_tables_core};
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::saved_queries::SavedQuery;
|
||||
use crate::state::{AppState, CachedVec, DbFlavor};
|
||||
use sqlx::{PgPool, Row};
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::time::{Duration, Instant};
|
||||
use tauri::AppHandle;
|
||||
|
||||
const TOOL_CACHE_TTL: Duration = Duration::from_secs(300);
|
||||
const MAX_TABLES_PER_GET_COLUMNS: usize = 20;
|
||||
const COLUMNS_TOOL_OUTPUT_CAP: usize = 15_000;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// list_databases
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_databases_tool(state: &AppState, connection_id: &str) -> TuskResult<String> {
|
||||
let dbs = list_databases_core(state, connection_id).await?;
|
||||
let active = active_db_name(state, connection_id).await;
|
||||
|
||||
let mut out = format!("DATABASES ({}):", dbs.len());
|
||||
for db in &dbs {
|
||||
if Some(db) == active.as_ref() {
|
||||
out.push_str(&format!("\n * {} (active)", db));
|
||||
} else {
|
||||
out.push_str(&format!("\n {}", db));
|
||||
}
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// list_tables
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn list_tables_tool(
|
||||
app: &AppHandle,
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
db: Option<&str>,
|
||||
) -> TuskResult<String> {
|
||||
let active = active_db_name(state, connection_id).await;
|
||||
let target = db.map(|s| s.to_string()).or_else(|| active.clone());
|
||||
|
||||
let target_name = match target.as_deref() {
|
||||
Some(n) => n.to_string(),
|
||||
None => return Err(TuskError::Custom("No active database selected.".into())),
|
||||
};
|
||||
|
||||
let same_as_active = active.as_deref() == Some(target_name.as_str());
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
|
||||
let table_names = match (flavor, same_as_active) {
|
||||
(DbFlavor::ClickHouse, _) => list_tables_clickhouse(state, connection_id, &target_name).await?,
|
||||
(_, true) => list_tables_active_pg(state, connection_id).await?,
|
||||
(_, false) => list_tables_other_pg(app, state, connection_id, &target_name).await?,
|
||||
};
|
||||
|
||||
let header = if same_as_active {
|
||||
format!("TABLES IN ACTIVE DATABASE `{}` ({}):", target_name, table_names.len())
|
||||
} else {
|
||||
format!("TABLES IN DATABASE `{}` ({}):", target_name, table_names.len())
|
||||
};
|
||||
let body: Vec<String> = table_names.iter().map(|t| format!(" {}", t)).collect();
|
||||
Ok(format!("{}\n{}", header, body.join("\n")))
|
||||
}
|
||||
|
||||
async fn list_tables_active_pg(state: &AppState, connection_id: &str) -> TuskResult<Vec<String>> {
|
||||
let schemas = crate::commands::schema::list_schemas_core(state, connection_id).await?;
|
||||
let mut all: Vec<String> = Vec::new();
|
||||
for schema in &schemas {
|
||||
let tables = list_tables_core(state, connection_id, schema).await?;
|
||||
for t in tables {
|
||||
all.push(format!("{}.{}", schema, t.name));
|
||||
}
|
||||
}
|
||||
Ok(all)
|
||||
}
|
||||
|
||||
async fn list_tables_other_pg(
|
||||
app: &AppHandle,
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
target_db: &str,
|
||||
) -> TuskResult<Vec<String>> {
|
||||
let cache_key = (connection_id.to_string(), target_db.to_string());
|
||||
if let Some(hit) = state.tables_by_db_cache.read().await.get(&cache_key).cloned() {
|
||||
if hit.cached_at.elapsed() < TOOL_CACHE_TTL {
|
||||
return Ok(hit.value);
|
||||
}
|
||||
}
|
||||
|
||||
let config = load_connection_config(app, connection_id)?;
|
||||
let url = config.connection_url_for_db(target_db);
|
||||
let pool = PgPool::connect(&url).await.map_err(|e| {
|
||||
TuskError::Custom(format!(
|
||||
"Could not connect to database '{}' on this server: {}",
|
||||
target_db, e
|
||||
))
|
||||
})?;
|
||||
let rows = sqlx::query(
|
||||
"SELECT table_schema, table_name FROM information_schema.tables \
|
||||
WHERE table_schema NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \
|
||||
AND table_type = 'BASE TABLE' \
|
||||
ORDER BY table_schema, table_name",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
pool.close().await;
|
||||
|
||||
let names: Vec<String> = rows
|
||||
.iter()
|
||||
.map(|r| format!("{}.{}", r.get::<String, _>(0), r.get::<String, _>(1)))
|
||||
.collect();
|
||||
|
||||
state.tables_by_db_cache.write().await.insert(
|
||||
cache_key,
|
||||
CachedVec {
|
||||
value: names.clone(),
|
||||
cached_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
Ok(names)
|
||||
}
|
||||
|
||||
async fn list_tables_clickhouse(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
target_db: &str,
|
||||
) -> TuskResult<Vec<String>> {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let escaped = target_db.replace('\\', "\\\\").replace('\'', "\\'");
|
||||
let sql = format!(
|
||||
"SELECT name FROM system.tables WHERE database = '{}' ORDER BY name",
|
||||
escaped
|
||||
);
|
||||
let rows = client.fetch_objects(&sql).await?;
|
||||
Ok(rows
|
||||
.iter()
|
||||
.filter_map(|r| r.get("name").and_then(|v| v.as_str()).map(String::from))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// get_columns
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn get_columns_tool(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
tables: &[String],
|
||||
) -> TuskResult<String> {
|
||||
if tables.is_empty() {
|
||||
return Err(TuskError::Custom("get_columns requires at least one table.".into()));
|
||||
}
|
||||
if tables.len() > MAX_TABLES_PER_GET_COLUMNS {
|
||||
return Err(TuskError::Custom(format!(
|
||||
"Too many tables ({}); split into batches of ≤{}.",
|
||||
tables.len(),
|
||||
MAX_TABLES_PER_GET_COLUMNS
|
||||
)));
|
||||
}
|
||||
|
||||
let active_db = active_db_name(state, connection_id).await.unwrap_or_default();
|
||||
|
||||
// Normalise: accept "schema.table", "db.schema.table" (drop db if == active),
|
||||
// and "table" (assume schema "public" for PG, or active DB for CH).
|
||||
let parsed: Vec<(String, String, String)> = tables
|
||||
.iter()
|
||||
.map(|raw| normalise_table_ref(raw, &active_db))
|
||||
.collect();
|
||||
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return get_columns_clickhouse(state, connection_id, &parsed).await;
|
||||
}
|
||||
get_columns_postgres(state, connection_id, &parsed).await
|
||||
}
|
||||
|
||||
fn normalise_table_ref(raw: &str, active_db: &str) -> (String, String, String) {
|
||||
// Returns (schema, table, original_input_for_diagnostics)
|
||||
let trimmed = raw.trim().trim_matches('"').trim_matches('`');
|
||||
let parts: Vec<&str> = trimmed.split('.').collect();
|
||||
match parts.len() {
|
||||
1 => ("public".to_string(), parts[0].to_string(), raw.to_string()),
|
||||
2 => (parts[0].to_string(), parts[1].to_string(), raw.to_string()),
|
||||
3 => {
|
||||
// "db.schema.table" — drop db prefix when it matches active
|
||||
let (db, schema, table) = (parts[0], parts[1], parts[2]);
|
||||
if db == active_db {
|
||||
(schema.to_string(), table.to_string(), raw.to_string())
|
||||
} else {
|
||||
// Different DB requested — let the caller surface a not-found warning.
|
||||
// We still parse it as schema.table here.
|
||||
(schema.to_string(), table.to_string(), raw.to_string())
|
||||
}
|
||||
}
|
||||
_ => ("public".to_string(), trimmed.to_string(), raw.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_columns_postgres(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
requested: &[(String, String, String)],
|
||||
) -> TuskResult<String> {
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
|
||||
let (col_res, fk_res, enum_res, tbl_comm_res, col_comm_res, unique_res) = tokio::join!(
|
||||
fetch_columns(&pool),
|
||||
fetch_foreign_keys_raw(&pool),
|
||||
fetch_enum_types(&pool),
|
||||
fetch_table_comments(&pool),
|
||||
fetch_column_comments(&pool),
|
||||
fetch_unique_constraints(&pool),
|
||||
);
|
||||
let all_cols = col_res?;
|
||||
let fk_rows = fk_res?;
|
||||
let enum_map = enum_res.unwrap_or_default();
|
||||
let tbl_comments = tbl_comm_res.unwrap_or_default();
|
||||
let col_comments = col_comm_res.unwrap_or_default();
|
||||
let uniques = unique_res.unwrap_or_default();
|
||||
|
||||
// Build (schema, table) → Vec<ColumnInfo>
|
||||
let mut by_table: BTreeMap<(String, String), Vec<ColumnInfo>> = BTreeMap::new();
|
||||
for ci in &all_cols {
|
||||
by_table
|
||||
.entry((ci.schema.clone(), ci.table.clone()))
|
||||
.or_default()
|
||||
.push(ci.clone());
|
||||
}
|
||||
|
||||
let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new();
|
||||
for fk in &fk_rows {
|
||||
if fk.columns.len() == 1 && fk.ref_columns.len() == 1 {
|
||||
fk_inline.insert(
|
||||
(fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()),
|
||||
format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let mut unique_map: HashMap<(String, String), Vec<String>> = HashMap::new();
|
||||
for (schema, table, cols) in &uniques {
|
||||
unique_map
|
||||
.entry((schema.clone(), table.clone()))
|
||||
.or_default()
|
||||
.push(cols.join(", "));
|
||||
}
|
||||
|
||||
let varchar_values: HashMap<(String, String, String), Vec<String>> = HashMap::new();
|
||||
let jsonb_keys: HashMap<(String, String, String), Vec<String>> = HashMap::new();
|
||||
|
||||
let mut output: Vec<String> = Vec::new();
|
||||
let mut not_found: Vec<String> = Vec::new();
|
||||
|
||||
for (schema, table, raw) in requested {
|
||||
match by_table.get(&(schema.clone(), table.clone())) {
|
||||
Some(cols) => {
|
||||
let full_name = format!("{}.{}", schema, table);
|
||||
format_table_block(
|
||||
&full_name,
|
||||
cols,
|
||||
&tbl_comments,
|
||||
&col_comments,
|
||||
&fk_inline,
|
||||
&enum_map,
|
||||
&unique_map,
|
||||
&varchar_values,
|
||||
&jsonb_keys,
|
||||
&mut output,
|
||||
);
|
||||
}
|
||||
None => not_found.push(raw.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
if !not_found.is_empty() {
|
||||
let nearest = nearest_table_matches(&by_table, ¬_found);
|
||||
let header = format!(
|
||||
"WARNING: tables not found: {}.{}",
|
||||
not_found.join(", "),
|
||||
if nearest.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" Nearest matches: {}.", nearest.join(", "))
|
||||
}
|
||||
);
|
||||
output.insert(0, header);
|
||||
output.insert(1, String::new());
|
||||
}
|
||||
|
||||
let mut text = output.join("\n");
|
||||
if text.len() > COLUMNS_TOOL_OUTPUT_CAP {
|
||||
text.truncate(COLUMNS_TOOL_OUTPUT_CAP);
|
||||
text.push_str("\n... (output truncated)");
|
||||
}
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
async fn get_columns_clickhouse(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
requested: &[(String, String, String)],
|
||||
) -> TuskResult<String> {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let active_db = client.database.clone();
|
||||
|
||||
let where_terms: Vec<String> = requested
|
||||
.iter()
|
||||
.map(|(schema, table, _)| {
|
||||
// For CH, treat the parsed "schema" as the database name; if it equals
|
||||
// a PG-conventional default ("public"), substitute with active CH database.
|
||||
let dbn = if schema == "public" { active_db.clone() } else { schema.clone() };
|
||||
format!(
|
||||
"(database = '{}' AND name = '{}')",
|
||||
dbn.replace('\'', "\\'"),
|
||||
table.replace('\'', "\\'")
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let where_clause = where_terms.join(" OR ");
|
||||
|
||||
let sql = format!(
|
||||
"SELECT database, table, name, type, default_expression, is_in_primary_key, comment, position \
|
||||
FROM system.columns WHERE {} ORDER BY database, table, position",
|
||||
where_clause
|
||||
);
|
||||
let rows = client.fetch_objects(&sql).await?;
|
||||
|
||||
// Group by (database, table)
|
||||
let mut grouped: BTreeMap<(String, String), Vec<&serde_json::Map<String, serde_json::Value>>> =
|
||||
BTreeMap::new();
|
||||
for row in &rows {
|
||||
let dbn = row.get("database").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
let tbl = row.get("table").and_then(|v| v.as_str()).unwrap_or("").to_string();
|
||||
grouped.entry((dbn, tbl)).or_default().push(row);
|
||||
}
|
||||
|
||||
// Track which requested tables were found
|
||||
let mut output = String::new();
|
||||
let mut not_found: Vec<String> = Vec::new();
|
||||
for (schema, table, raw) in requested {
|
||||
let dbn = if schema == "public" { active_db.clone() } else { schema.clone() };
|
||||
match grouped.get(&(dbn.clone(), table.clone())) {
|
||||
Some(cols) => {
|
||||
output.push_str(&format!("\nTABLE {}.{}\n", dbn, table));
|
||||
for col in cols {
|
||||
let name = col.get("name").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let dtype = col.get("type").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let is_pk = matches!(
|
||||
col.get("is_in_primary_key"),
|
||||
Some(serde_json::Value::Number(n)) if n.as_i64() == Some(1)
|
||||
) || matches!(
|
||||
col.get("is_in_primary_key"),
|
||||
Some(serde_json::Value::String(s)) if s == "1"
|
||||
);
|
||||
let default = col.get("default_expression").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let comment = col.get("comment").and_then(|v| v.as_str()).unwrap_or("");
|
||||
let mut line = format!(" {} {}", name, dtype);
|
||||
if is_pk {
|
||||
line.push_str(" [PK]");
|
||||
}
|
||||
if !default.is_empty() {
|
||||
line.push_str(&format!(" DEFAULT {}", default));
|
||||
}
|
||||
if !comment.is_empty() {
|
||||
line.push_str(&format!(" -- {}", comment));
|
||||
}
|
||||
output.push_str(&line);
|
||||
output.push('\n');
|
||||
}
|
||||
}
|
||||
None => not_found.push(raw.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
let mut header = String::new();
|
||||
if !not_found.is_empty() {
|
||||
header.push_str(&format!(
|
||||
"WARNING: tables not found: {}\n\n",
|
||||
not_found.join(", ")
|
||||
));
|
||||
}
|
||||
let mut combined = format!("{}{}", header, output.trim_start());
|
||||
if combined.len() > COLUMNS_TOOL_OUTPUT_CAP {
|
||||
combined.truncate(COLUMNS_TOOL_OUTPUT_CAP);
|
||||
combined.push_str("\n... (output truncated)");
|
||||
}
|
||||
Ok(combined)
|
||||
}
|
||||
|
||||
fn nearest_table_matches(
|
||||
by_table: &BTreeMap<(String, String), Vec<ColumnInfo>>,
|
||||
missing: &[String],
|
||||
) -> Vec<String> {
|
||||
let all: Vec<String> = by_table
|
||||
.keys()
|
||||
.map(|(s, t)| format!("{}.{}", s, t))
|
||||
.collect();
|
||||
let mut hints: Vec<String> = Vec::new();
|
||||
for m in missing {
|
||||
let needle = m.to_lowercase();
|
||||
let mut candidates: Vec<&String> = all
|
||||
.iter()
|
||||
.filter(|n| {
|
||||
let lower = n.to_lowercase();
|
||||
lower.contains(&needle) || needle.contains(lower.split('.').last().unwrap_or(""))
|
||||
})
|
||||
.take(3)
|
||||
.collect();
|
||||
candidates.dedup();
|
||||
for c in candidates {
|
||||
if !hints.contains(c) {
|
||||
hints.push(c.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
hints
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// switch_database
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn switch_database_tool(
|
||||
app: &AppHandle,
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
target_db: &str,
|
||||
) -> TuskResult<String> {
|
||||
let config = load_connection_config(app, connection_id)?;
|
||||
|
||||
// Verify target exists in cluster
|
||||
let dbs = list_databases_core(state, connection_id).await?;
|
||||
if !dbs.iter().any(|d| d == target_db) {
|
||||
return Err(TuskError::Custom(format!(
|
||||
"Database '{}' does not exist on this server. Available: {}",
|
||||
target_db,
|
||||
dbs.join(", ")
|
||||
)));
|
||||
}
|
||||
|
||||
switch_database_core(state, &config, target_db).await?;
|
||||
Ok(format!("Switched active database to '{}'.", target_db))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn active_db_name(state: &AppState, connection_id: &str) -> Option<String> {
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return state
|
||||
.get_ch_client(connection_id)
|
||||
.await
|
||||
.ok()
|
||||
.map(|c| c.database.clone());
|
||||
}
|
||||
let pool = state.get_pool(connection_id).await.ok()?;
|
||||
sqlx::query_scalar::<_, String>("SELECT current_database()")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.ok()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// save_query / find_queries (chat v3 — F2)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const FIND_QUERIES_LIMIT: usize = 10;
|
||||
const FIND_QUERIES_SQL_PREVIEW_CHARS: usize = 500;
|
||||
|
||||
pub async fn save_query_tool(
|
||||
app: &AppHandle,
|
||||
connection_id: &str,
|
||||
name: &str,
|
||||
sql: &str,
|
||||
) -> TuskResult<String> {
|
||||
let trimmed_name = name.trim();
|
||||
let trimmed_sql = sql.trim();
|
||||
if trimmed_name.is_empty() {
|
||||
return Err(TuskError::Custom("save_query: name must not be empty".into()));
|
||||
}
|
||||
if trimmed_sql.is_empty() {
|
||||
return Err(TuskError::Custom("save_query: sql must not be empty".into()));
|
||||
}
|
||||
|
||||
let entry = SavedQuery {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
name: trimmed_name.to_string(),
|
||||
sql: trimmed_sql.to_string(),
|
||||
connection_id: Some(connection_id.to_string()),
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
};
|
||||
save_query_core(app, entry).await?;
|
||||
Ok(format!("Saved query \"{}\" — visible in sidebar → Saved.", trimmed_name))
|
||||
}
|
||||
|
||||
pub async fn find_queries_tool(
|
||||
app: &AppHandle,
|
||||
connection_id: &str,
|
||||
text: &str,
|
||||
) -> TuskResult<String> {
|
||||
let trimmed = text.trim();
|
||||
if trimmed.is_empty() {
|
||||
return Err(TuskError::Custom("find_queries: text must not be empty".into()));
|
||||
}
|
||||
|
||||
let all = list_saved_queries_core(app, Some(trimmed)).await?;
|
||||
let matches: Vec<SavedQuery> = all
|
||||
.into_iter()
|
||||
.filter(|q| q.connection_id.as_deref() == Some(connection_id))
|
||||
.take(FIND_QUERIES_LIMIT)
|
||||
.collect();
|
||||
|
||||
if matches.is_empty() {
|
||||
return Ok(format!(
|
||||
"No saved queries match \"{}\" for this connection.",
|
||||
trimmed
|
||||
));
|
||||
}
|
||||
|
||||
let mut out = format!(
|
||||
"Saved queries matching \"{}\" ({}):",
|
||||
trimmed,
|
||||
matches.len()
|
||||
);
|
||||
for q in &matches {
|
||||
let sql_preview: String = if q.sql.chars().count() > FIND_QUERIES_SQL_PREVIEW_CHARS {
|
||||
let truncated: String = q.sql.chars().take(FIND_QUERIES_SQL_PREVIEW_CHARS).collect();
|
||||
format!("{}…", truncated)
|
||||
} else {
|
||||
q.sql.clone()
|
||||
};
|
||||
out.push_str(&format!(
|
||||
"\n\n[{}] {}\n{}",
|
||||
q.created_at, q.name, sql_preview
|
||||
));
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::db::clickhouse::ChClient;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::connection::ConnectionConfig;
|
||||
use crate::state::{AppState, DbFlavor};
|
||||
@@ -23,6 +24,34 @@ pub(crate) fn get_connections_path(app: &AppHandle) -> TuskResult<std::path::Pat
|
||||
Ok(dir.join("connections.json"))
|
||||
}
|
||||
|
||||
/// Read all saved connection configs from disk.
|
||||
pub(crate) fn load_all_connections(app: &AppHandle) -> TuskResult<Vec<ConnectionConfig>> {
|
||||
let path = get_connections_path(app)?;
|
||||
if !path.exists() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let data = fs::read_to_string(&path)?;
|
||||
let connections: Vec<ConnectionConfig> = serde_json::from_str(&data)?;
|
||||
Ok(connections)
|
||||
}
|
||||
|
||||
/// Look up a single saved connection by id. Used by tools that need credentials
|
||||
/// (e.g. switch_database from inside the chat agent loop) but only have the id in scope.
|
||||
pub(crate) fn load_connection_config(
|
||||
app: &AppHandle,
|
||||
connection_id: &str,
|
||||
) -> TuskResult<ConnectionConfig> {
|
||||
load_all_connections(app)?
|
||||
.into_iter()
|
||||
.find(|c| c.id == connection_id)
|
||||
.ok_or_else(|| {
|
||||
TuskError::Custom(format!(
|
||||
"Connection '{}' not found in connections.json",
|
||||
connection_id
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_connections(app: AppHandle) -> TuskResult<Vec<ConnectionConfig>> {
|
||||
let path = get_connections_path(&app)?;
|
||||
@@ -55,6 +84,24 @@ pub async fn save_connection(app: AppHandle, config: ConnectionConfig) -> TuskRe
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn close_connection(state: &AppState, id: &str) {
|
||||
let mut pools = state.pools.write().await;
|
||||
if let Some(pool) = pools.remove(id) {
|
||||
pool.close().await;
|
||||
}
|
||||
drop(pools);
|
||||
let mut clients = state.ch_clients.write().await;
|
||||
clients.remove(id);
|
||||
drop(clients);
|
||||
let mut ro = state.read_only.write().await;
|
||||
ro.remove(id);
|
||||
drop(ro);
|
||||
let mut flavors = state.db_flavors.write().await;
|
||||
flavors.remove(id);
|
||||
drop(flavors);
|
||||
state.invalidate_chat_caches_for(id).await;
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn delete_connection(
|
||||
app: AppHandle,
|
||||
@@ -69,36 +116,37 @@ pub async fn delete_connection(
|
||||
let data = serde_json::to_string_pretty(&connections)?;
|
||||
fs::write(&path, data)?;
|
||||
}
|
||||
|
||||
// Close pool if connected
|
||||
let mut pools = state.pools.write().await;
|
||||
if let Some(pool) = pools.remove(&id) {
|
||||
pool.close().await;
|
||||
}
|
||||
|
||||
let mut ro = state.read_only.write().await;
|
||||
ro.remove(&id);
|
||||
|
||||
let mut flavors = state.db_flavors.write().await;
|
||||
flavors.remove(&id);
|
||||
|
||||
close_connection(&state, &id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn test_connection(config: ConnectionConfig) -> TuskResult<String> {
|
||||
let pool = PgPool::connect(&config.connection_url())
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let row = sqlx::query("SELECT version()")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let version: String = row.get(0);
|
||||
pool.close().await;
|
||||
Ok(version)
|
||||
match config.db_flavor {
|
||||
DbFlavor::ClickHouse => {
|
||||
let client = ChClient::new(
|
||||
&config.host,
|
||||
config.port,
|
||||
config.secure,
|
||||
&config.user,
|
||||
&config.password,
|
||||
&config.database,
|
||||
);
|
||||
client.ping().await
|
||||
}
|
||||
_ => {
|
||||
let pool = PgPool::connect(&config.connection_url())
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let row = sqlx::query("SELECT version()")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let version: String = row.get(0);
|
||||
pool.close().await;
|
||||
Ok(version)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
@@ -106,39 +154,110 @@ pub async fn connect(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
config: ConnectionConfig,
|
||||
) -> TuskResult<ConnectResult> {
|
||||
let pool = PgPool::connect(&config.connection_url())
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
match config.db_flavor {
|
||||
DbFlavor::ClickHouse => {
|
||||
let client = ChClient::new(
|
||||
&config.host,
|
||||
config.port,
|
||||
config.secure,
|
||||
&config.user,
|
||||
&config.password,
|
||||
&config.database,
|
||||
);
|
||||
let version = client.ping().await?;
|
||||
let arc = Arc::new(client);
|
||||
state.ch_clients.write().await.insert(config.id.clone(), arc);
|
||||
state.read_only.write().await.insert(config.id.clone(), true);
|
||||
state
|
||||
.db_flavors
|
||||
.write()
|
||||
.await
|
||||
.insert(config.id.clone(), DbFlavor::ClickHouse);
|
||||
Ok(ConnectResult {
|
||||
version,
|
||||
flavor: DbFlavor::ClickHouse,
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
let pool = PgPool::connect(&config.connection_url())
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
sqlx::query("SELECT 1")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let row = sqlx::query("SELECT version()")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let version: String = row.get(0);
|
||||
let flavor = if version.to_lowercase().contains("greenplum") {
|
||||
DbFlavor::Greenplum
|
||||
} else {
|
||||
DbFlavor::PostgreSQL
|
||||
};
|
||||
state.pools.write().await.insert(config.id.clone(), pool);
|
||||
state.read_only.write().await.insert(config.id.clone(), true);
|
||||
state
|
||||
.db_flavors
|
||||
.write()
|
||||
.await
|
||||
.insert(config.id.clone(), flavor);
|
||||
Ok(ConnectResult { version, flavor })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify connection
|
||||
sqlx::query("SELECT 1")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
/// Core implementation of switching the active database for a connection.
|
||||
/// Reusable from both the Tauri command (frontend-driven) and the chat agent
|
||||
/// loop (model-driven via the switch_database tool).
|
||||
pub(crate) async fn switch_database_core(
|
||||
state: &AppState,
|
||||
config: &ConnectionConfig,
|
||||
database: &str,
|
||||
) -> TuskResult<()> {
|
||||
let mut switched = config.clone();
|
||||
switched.database = database.to_string();
|
||||
|
||||
// Detect database flavor via version()
|
||||
let row = sqlx::query("SELECT version()")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let version: String = row.get(0);
|
||||
|
||||
let flavor = if version.to_lowercase().contains("greenplum") {
|
||||
DbFlavor::Greenplum
|
||||
} else {
|
||||
DbFlavor::PostgreSQL
|
||||
let result: TuskResult<()> = match config.db_flavor {
|
||||
DbFlavor::ClickHouse => {
|
||||
let client = ChClient::new(
|
||||
&switched.host,
|
||||
switched.port,
|
||||
switched.secure,
|
||||
&switched.user,
|
||||
&switched.password,
|
||||
&switched.database,
|
||||
);
|
||||
client.ping().await?;
|
||||
state
|
||||
.ch_clients
|
||||
.write()
|
||||
.await
|
||||
.insert(config.id.clone(), Arc::new(client));
|
||||
Ok(())
|
||||
}
|
||||
_ => {
|
||||
let pool = PgPool::connect(&switched.connection_url())
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
sqlx::query("SELECT 1")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
let mut pools = state.pools.write().await;
|
||||
if let Some(old_pool) = pools.remove(&config.id) {
|
||||
old_pool.close().await;
|
||||
}
|
||||
pools.insert(config.id.clone(), pool);
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
|
||||
let mut pools = state.pools.write().await;
|
||||
pools.insert(config.id.clone(), pool);
|
||||
// Drop every cache that's bound to this connection's previous database.
|
||||
state.invalidate_chat_caches_for(&config.id).await;
|
||||
|
||||
let mut ro = state.read_only.write().await;
|
||||
ro.insert(config.id.clone(), true);
|
||||
|
||||
let mut flavors = state.db_flavors.write().await;
|
||||
flavors.insert(config.id.clone(), flavor);
|
||||
|
||||
Ok(ConnectResult { version, flavor })
|
||||
result
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
@@ -147,40 +266,12 @@ pub async fn switch_database(
|
||||
config: ConnectionConfig,
|
||||
database: String,
|
||||
) -> TuskResult<()> {
|
||||
let mut switched = config.clone();
|
||||
switched.database = database;
|
||||
|
||||
let pool = PgPool::connect(&switched.connection_url())
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
sqlx::query("SELECT 1")
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let mut pools = state.pools.write().await;
|
||||
if let Some(old_pool) = pools.remove(&config.id) {
|
||||
old_pool.close().await;
|
||||
}
|
||||
pools.insert(config.id.clone(), pool);
|
||||
|
||||
Ok(())
|
||||
switch_database_core(&state, &config, &database).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn disconnect(state: State<'_, Arc<AppState>>, id: String) -> TuskResult<()> {
|
||||
let mut pools = state.pools.write().await;
|
||||
if let Some(pool) = pools.remove(&id) {
|
||||
pool.close().await;
|
||||
}
|
||||
|
||||
let mut ro = state.read_only.write().await;
|
||||
ro.remove(&id);
|
||||
|
||||
let mut flavors = state.db_flavors.write().await;
|
||||
flavors.remove(&id);
|
||||
|
||||
close_connection(&state, &id).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::commands::queries::pg_value_to_json;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::query_result::PaginatedQueryResult;
|
||||
use crate::state::AppState;
|
||||
use crate::state::{AppState, DbFlavor};
|
||||
use crate::utils::escape_ident;
|
||||
use serde_json::Value;
|
||||
use sqlx::{Column, Row, TypeInfo};
|
||||
@@ -9,6 +9,80 @@ use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tauri::State;
|
||||
|
||||
async fn ch_get_table_data(
|
||||
state: &AppState,
|
||||
connection_id: &str,
|
||||
schema: &str,
|
||||
table: &str,
|
||||
page: u32,
|
||||
page_size: u32,
|
||||
sort_column: Option<&str>,
|
||||
sort_direction: Option<&str>,
|
||||
filter: Option<&str>,
|
||||
) -> TuskResult<PaginatedQueryResult> {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let qualified = format!(
|
||||
"{}.{}",
|
||||
ch_quote_ident(schema),
|
||||
ch_quote_ident(table)
|
||||
);
|
||||
|
||||
let mut where_clause = String::new();
|
||||
if let Some(f) = filter {
|
||||
if !f.trim().is_empty() {
|
||||
crate::db::sql_guard::ensure_readonly_sql(&format!("SELECT 1 FROM x WHERE {}", f))?;
|
||||
where_clause = format!(" WHERE {}", f);
|
||||
}
|
||||
}
|
||||
|
||||
let mut order_clause = String::new();
|
||||
if let Some(col) = sort_column {
|
||||
if !col.trim().is_empty() {
|
||||
let dir = match sort_direction {
|
||||
Some("DESC") | Some("desc") => "DESC",
|
||||
_ => "ASC",
|
||||
};
|
||||
order_clause = format!(" ORDER BY {} {}", ch_quote_ident(col), dir);
|
||||
}
|
||||
}
|
||||
|
||||
let offset = (page.saturating_sub(1)) as i64 * page_size as i64;
|
||||
let data_sql = format!(
|
||||
"SELECT * FROM {}{}{} LIMIT {} OFFSET {}",
|
||||
qualified, where_clause, order_clause, page_size, offset
|
||||
);
|
||||
let count_sql = format!("SELECT count() AS c FROM {}{}", qualified, where_clause);
|
||||
|
||||
let result = client.execute_query(&data_sql, true).await?;
|
||||
let count_rows = client.fetch_objects(&count_sql).await?;
|
||||
let total_rows = count_rows
|
||||
.first()
|
||||
.and_then(|o| o.get("c"))
|
||||
.and_then(|v| match v {
|
||||
Value::Number(n) => n.as_i64(),
|
||||
Value::String(s) => s.parse::<i64>().ok(),
|
||||
_ => None,
|
||||
})
|
||||
.unwrap_or(0);
|
||||
|
||||
Ok(PaginatedQueryResult {
|
||||
columns: result.columns,
|
||||
types: result.types,
|
||||
rows: result.rows,
|
||||
row_count: result.row_count,
|
||||
execution_time_ms: result.execution_time_ms,
|
||||
total_rows,
|
||||
page,
|
||||
page_size,
|
||||
ctids: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn ch_quote_ident(s: &str) -> String {
|
||||
let escaped = s.replace('`', "``");
|
||||
format!("`{}`", escaped)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn get_table_data(
|
||||
@@ -22,6 +96,21 @@ pub async fn get_table_data(
|
||||
sort_direction: Option<String>,
|
||||
filter: Option<String>,
|
||||
) -> TuskResult<PaginatedQueryResult> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return ch_get_table_data(
|
||||
&state,
|
||||
&connection_id,
|
||||
&schema,
|
||||
&table,
|
||||
page,
|
||||
page_size,
|
||||
sort_column.as_deref(),
|
||||
sort_direction.as_deref(),
|
||||
filter.as_deref(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table));
|
||||
@@ -74,7 +163,7 @@ pub async fn get_table_data(
|
||||
|
||||
tx.rollback().await.map_err(TuskError::Database)?;
|
||||
|
||||
let execution_time_ms = start.elapsed().as_millis();
|
||||
let execution_time_ms = start.elapsed().as_millis() as u64;
|
||||
let total_rows: i64 = count_row.get(0);
|
||||
|
||||
let mut all_columns = Vec::new();
|
||||
@@ -146,6 +235,11 @@ pub async fn update_row(
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
if matches!(state.get_flavor(&connection_id).await, DbFlavor::ClickHouse) {
|
||||
return Err(TuskError::Custom(
|
||||
"Inline row edit is not supported for ClickHouse — use SQL ALTER … UPDATE.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
@@ -202,6 +296,11 @@ pub async fn insert_row(
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
if matches!(state.get_flavor(&connection_id).await, DbFlavor::ClickHouse) {
|
||||
return Err(TuskError::Custom(
|
||||
"Inline row insert is not supported for ClickHouse — use SQL INSERT.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
@@ -240,6 +339,11 @@ pub async fn delete_rows(
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
if matches!(state.get_flavor(&connection_id).await, DbFlavor::ClickHouse) {
|
||||
return Err(TuskError::Custom(
|
||||
"Inline row delete is not supported for ClickHouse — use SQL ALTER … DELETE.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,367 +0,0 @@
|
||||
use crate::commands::queries::pg_value_to_json;
|
||||
use crate::error::TuskResult;
|
||||
use crate::models::connection::ConnectionConfig;
|
||||
use crate::models::lookup::{
|
||||
EntityLookupResult, LookupDatabaseResult, LookupProgress, LookupTableMatch,
|
||||
};
|
||||
use crate::utils::escape_ident;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use sqlx::{Column, Row, TypeInfo};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tauri::{AppHandle, Emitter};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
struct TableCandidate {
|
||||
schema: String,
|
||||
table: String,
|
||||
data_type: String,
|
||||
}
|
||||
|
||||
async fn search_database(
|
||||
config: &ConnectionConfig,
|
||||
database: &str,
|
||||
column_name: &str,
|
||||
value: &str,
|
||||
) -> LookupDatabaseResult {
|
||||
let start = Instant::now();
|
||||
|
||||
let mut db_config = config.clone();
|
||||
db_config.database = database.to_string();
|
||||
let url = db_config.connection_url();
|
||||
|
||||
let pool = match PgPoolOptions::new()
|
||||
.max_connections(2)
|
||||
.acquire_timeout(std::time::Duration::from_secs(5))
|
||||
.connect(&url)
|
||||
.await
|
||||
{
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return LookupDatabaseResult {
|
||||
database: database.to_string(),
|
||||
tables: vec![],
|
||||
error: Some(format!("Connection failed: {}", e)),
|
||||
search_time_ms: start.elapsed().as_millis(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(120),
|
||||
search_database_inner(&pool, database, column_name, value),
|
||||
)
|
||||
.await;
|
||||
|
||||
pool.close().await;
|
||||
|
||||
match result {
|
||||
Ok(db_result) => {
|
||||
let mut db_result = db_result;
|
||||
db_result.search_time_ms = start.elapsed().as_millis();
|
||||
db_result
|
||||
}
|
||||
Err(_) => LookupDatabaseResult {
|
||||
database: database.to_string(),
|
||||
tables: vec![],
|
||||
error: Some("Timeout (120s)".to_string()),
|
||||
search_time_ms: start.elapsed().as_millis(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn search_database_inner(
|
||||
pool: &sqlx::PgPool,
|
||||
database: &str,
|
||||
column_name: &str,
|
||||
value: &str,
|
||||
) -> LookupDatabaseResult {
|
||||
// Find tables that have this column
|
||||
let candidates = match sqlx::query_as::<_, (String, String, String)>(
|
||||
"SELECT table_schema, table_name, data_type \
|
||||
FROM information_schema.columns \
|
||||
WHERE column_name = $1 \
|
||||
AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit')",
|
||||
)
|
||||
.bind(column_name)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
{
|
||||
Ok(rows) => rows
|
||||
.into_iter()
|
||||
.map(|(schema, table, data_type)| TableCandidate {
|
||||
schema,
|
||||
table,
|
||||
data_type,
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
Err(e) => {
|
||||
return LookupDatabaseResult {
|
||||
database: database.to_string(),
|
||||
tables: vec![],
|
||||
error: Some(format!("Schema query failed: {}", e)),
|
||||
search_time_ms: 0,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let mut tables = Vec::new();
|
||||
|
||||
for candidate in &candidates {
|
||||
let qualified = format!(
|
||||
"{}.{}",
|
||||
escape_ident(&candidate.schema),
|
||||
escape_ident(&candidate.table)
|
||||
);
|
||||
let col_ident = escape_ident(column_name);
|
||||
|
||||
// Read-only transaction: SELECT rows + COUNT
|
||||
let select_sql = format!(
|
||||
"SELECT * FROM {} WHERE {}::text = $1 LIMIT 50",
|
||||
qualified, col_ident
|
||||
);
|
||||
let count_sql = format!(
|
||||
"SELECT COUNT(*) FROM {} WHERE {}::text = $1",
|
||||
qualified, col_ident
|
||||
);
|
||||
|
||||
let mut tx = match pool.begin().await {
|
||||
Ok(tx) => tx,
|
||||
Err(e) => {
|
||||
tables.push(LookupTableMatch {
|
||||
schema: candidate.schema.clone(),
|
||||
table: candidate.table.clone(),
|
||||
column_type: candidate.data_type.clone(),
|
||||
columns: vec![],
|
||||
types: vec![],
|
||||
rows: vec![],
|
||||
row_count: 0,
|
||||
total_count: 0,
|
||||
});
|
||||
log::warn!(
|
||||
"Failed to begin tx for {}.{}: {}",
|
||||
candidate.schema,
|
||||
candidate.table,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = sqlx::query("SET TRANSACTION READ ONLY")
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
{
|
||||
let _ = tx.rollback().await;
|
||||
log::warn!("Failed SET TRANSACTION READ ONLY: {}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
let rows_result = sqlx::query(&select_sql)
|
||||
.bind(value)
|
||||
.fetch_all(&mut *tx)
|
||||
.await;
|
||||
|
||||
let count_result: Result<i64, _> = sqlx::query_scalar(&count_sql)
|
||||
.bind(value)
|
||||
.fetch_one(&mut *tx)
|
||||
.await;
|
||||
|
||||
let _ = tx.rollback().await;
|
||||
|
||||
match rows_result {
|
||||
Ok(rows) if !rows.is_empty() => {
|
||||
let mut col_names = Vec::new();
|
||||
let mut col_types = Vec::new();
|
||||
if let Some(first) = rows.first() {
|
||||
for col in first.columns() {
|
||||
col_names.push(col.name().to_string());
|
||||
col_types.push(col.type_info().name().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let result_rows: Vec<Vec<serde_json::Value>> = rows
|
||||
.iter()
|
||||
.map(|row| {
|
||||
(0..col_names.len())
|
||||
.map(|i| pg_value_to_json(row, i))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let row_count = result_rows.len();
|
||||
let total_count = count_result.unwrap_or(row_count as i64);
|
||||
|
||||
tables.push(LookupTableMatch {
|
||||
schema: candidate.schema.clone(),
|
||||
table: candidate.table.clone(),
|
||||
column_type: candidate.data_type.clone(),
|
||||
columns: col_names,
|
||||
types: col_types,
|
||||
rows: result_rows,
|
||||
row_count,
|
||||
total_count,
|
||||
});
|
||||
}
|
||||
Ok(_) => {
|
||||
// No rows matched — skip
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!(
|
||||
"Query failed for {}.{}: {}",
|
||||
candidate.schema,
|
||||
candidate.table,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
LookupDatabaseResult {
|
||||
database: database.to_string(),
|
||||
tables,
|
||||
error: None,
|
||||
search_time_ms: 0,
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn entity_lookup(
|
||||
app: AppHandle,
|
||||
config: ConnectionConfig,
|
||||
column_name: String,
|
||||
value: String,
|
||||
databases: Option<Vec<String>>,
|
||||
lookup_id: String,
|
||||
) -> TuskResult<EntityLookupResult> {
|
||||
let start = Instant::now();
|
||||
|
||||
// 1. Get list of databases
|
||||
let url = config.connection_url();
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.acquire_timeout(std::time::Duration::from_secs(5))
|
||||
.connect(&url)
|
||||
.await
|
||||
.map_err(crate::error::TuskError::Database)?;
|
||||
|
||||
let db_names: Vec<String> = sqlx::query_scalar(
|
||||
"SELECT datname FROM pg_database WHERE datistemplate = false ORDER BY datname",
|
||||
)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(crate::error::TuskError::Database)?;
|
||||
|
||||
pool.close().await;
|
||||
|
||||
// Filter if specific databases requested
|
||||
let db_names: Vec<String> = if let Some(ref filter) = databases {
|
||||
db_names
|
||||
.into_iter()
|
||||
.filter(|d| filter.contains(d))
|
||||
.collect()
|
||||
} else {
|
||||
db_names
|
||||
};
|
||||
|
||||
let total = db_names.len();
|
||||
let completed = Arc::new(AtomicUsize::new(0));
|
||||
let semaphore = Arc::new(Semaphore::new(5));
|
||||
|
||||
// 2. Parallel search across databases
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for db_name in db_names {
|
||||
let config = config.clone();
|
||||
let column_name = column_name.clone();
|
||||
let value = value.clone();
|
||||
let lookup_id = lookup_id.clone();
|
||||
let app = app.clone();
|
||||
let semaphore = semaphore.clone();
|
||||
let completed = completed.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let _permit = semaphore.acquire().await.unwrap();
|
||||
|
||||
// Emit "searching" progress
|
||||
let _ = app.emit(
|
||||
"lookup-progress",
|
||||
LookupProgress {
|
||||
lookup_id: lookup_id.clone(),
|
||||
database: db_name.clone(),
|
||||
status: "searching".to_string(),
|
||||
tables_found: 0,
|
||||
rows_found: 0,
|
||||
error: None,
|
||||
completed: completed.load(Ordering::Relaxed),
|
||||
total,
|
||||
},
|
||||
);
|
||||
|
||||
let result = search_database(&config, &db_name, &column_name, &value).await;
|
||||
|
||||
let done = completed.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
let status = if result.error.is_some() {
|
||||
"error"
|
||||
} else {
|
||||
"done"
|
||||
};
|
||||
|
||||
let _ = app.emit(
|
||||
"lookup-progress",
|
||||
LookupProgress {
|
||||
lookup_id: lookup_id.clone(),
|
||||
database: db_name.clone(),
|
||||
status: status.to_string(),
|
||||
tables_found: result.tables.len(),
|
||||
rows_found: result.tables.iter().map(|t| t.row_count).sum(),
|
||||
error: result.error.clone(),
|
||||
completed: done,
|
||||
total,
|
||||
},
|
||||
);
|
||||
|
||||
result
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// 3. Collect results
|
||||
let mut all_results = Vec::new();
|
||||
for handle in handles {
|
||||
match handle.await {
|
||||
Ok(result) => all_results.push(result),
|
||||
Err(e) => {
|
||||
log::error!("Join error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort: databases with matches first, then by name
|
||||
all_results.sort_by(|a, b| {
|
||||
let a_has = !a.tables.is_empty();
|
||||
let b_has = !b.tables.is_empty();
|
||||
b_has.cmp(&a_has).then(a.database.cmp(&b.database))
|
||||
});
|
||||
|
||||
let total_databases_searched = all_results.len();
|
||||
let total_tables_matched: usize = all_results.iter().map(|d| d.tables.len()).sum();
|
||||
let total_rows_found: usize = all_results
|
||||
.iter()
|
||||
.flat_map(|d| d.tables.iter())
|
||||
.map(|t| t.row_count)
|
||||
.sum();
|
||||
|
||||
Ok(EntityLookupResult {
|
||||
column_name,
|
||||
value,
|
||||
databases: all_results,
|
||||
total_databases_searched,
|
||||
total_tables_matched,
|
||||
total_rows_found,
|
||||
total_time_ms: start.elapsed().as_millis(),
|
||||
})
|
||||
}
|
||||
@@ -1,594 +0,0 @@
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::management::*;
|
||||
use crate::state::{AppState, DbFlavor};
|
||||
use crate::utils::escape_ident;
|
||||
use sqlx::Row;
|
||||
use std::sync::Arc;
|
||||
use tauri::State;
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_database_info(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
) -> TuskResult<Vec<DatabaseInfo>> {
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT d.datname, \
|
||||
pg_catalog.pg_get_userbyid(d.datdba) AS owner, \
|
||||
pg_catalog.pg_encoding_to_char(d.encoding) AS encoding, \
|
||||
d.datcollate, \
|
||||
d.datctype, \
|
||||
COALESCE(t.spcname, 'pg_default') AS tablespace, \
|
||||
d.datconnlimit, \
|
||||
pg_catalog.pg_size_pretty(pg_catalog.pg_database_size(d.datname)) AS size, \
|
||||
pg_catalog.shobj_description(d.oid, 'pg_database') AS description \
|
||||
FROM pg_catalog.pg_database d \
|
||||
LEFT JOIN pg_catalog.pg_tablespace t ON d.dattablespace = t.oid \
|
||||
WHERE NOT d.datistemplate \
|
||||
ORDER BY d.datname",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let databases = rows
|
||||
.iter()
|
||||
.map(|row| DatabaseInfo {
|
||||
name: row.get("datname"),
|
||||
owner: row.get("owner"),
|
||||
encoding: row.get("encoding"),
|
||||
collation: row.get("datcollate"),
|
||||
ctype: row.get("datctype"),
|
||||
tablespace: row.get("tablespace"),
|
||||
connection_limit: row.get("datconnlimit"),
|
||||
size: row.get("size"),
|
||||
description: row.get("description"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(databases)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn create_database(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
params: CreateDatabaseParams,
|
||||
) -> TuskResult<()> {
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let mut sql = format!("CREATE DATABASE {}", escape_ident(¶ms.name));
|
||||
|
||||
if let Some(ref owner) = params.owner {
|
||||
sql.push_str(&format!(" OWNER {}", escape_ident(owner)));
|
||||
}
|
||||
if let Some(ref template) = params.template {
|
||||
sql.push_str(&format!(" TEMPLATE {}", escape_ident(template)));
|
||||
}
|
||||
if let Some(ref encoding) = params.encoding {
|
||||
sql.push_str(&format!(" ENCODING '{}'", encoding.replace('\'', "''")));
|
||||
}
|
||||
if let Some(ref tablespace) = params.tablespace {
|
||||
sql.push_str(&format!(" TABLESPACE {}", escape_ident(tablespace)));
|
||||
}
|
||||
if let Some(limit) = params.connection_limit {
|
||||
sql.push_str(&format!(" CONNECTION LIMIT {}", limit));
|
||||
}
|
||||
|
||||
sqlx::query(&sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn drop_database(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
name: String,
|
||||
) -> TuskResult<()> {
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
// Terminate active connections to the target database
|
||||
sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1::name AND pid <> pg_backend_pid()")
|
||||
.bind(&name)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let drop_sql = format!("DROP DATABASE {}", escape_ident(&name));
|
||||
sqlx::query(&drop_sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_roles(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
) -> TuskResult<Vec<RoleInfo>> {
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT r.rolname, \
|
||||
r.rolsuper, \
|
||||
r.rolcanlogin, \
|
||||
r.rolcreatedb, \
|
||||
r.rolcreaterole, \
|
||||
r.rolinherit, \
|
||||
r.rolreplication, \
|
||||
r.rolconnlimit, \
|
||||
r.rolpassword IS NOT NULL AS password_set, \
|
||||
r.rolvaliduntil::text, \
|
||||
COALESCE(( \
|
||||
SELECT array_agg(g.rolname ORDER BY g.rolname) \
|
||||
FROM pg_catalog.pg_auth_members m \
|
||||
JOIN pg_catalog.pg_roles g ON m.roleid = g.oid \
|
||||
WHERE m.member = r.oid \
|
||||
), ARRAY[]::text[]) AS member_of, \
|
||||
COALESCE(( \
|
||||
SELECT array_agg(m2.rolname ORDER BY m2.rolname) \
|
||||
FROM pg_catalog.pg_auth_members am \
|
||||
JOIN pg_catalog.pg_roles m2 ON am.member = m2.oid \
|
||||
WHERE am.roleid = r.oid \
|
||||
), ARRAY[]::text[]) AS members, \
|
||||
pg_catalog.shobj_description(r.oid, 'pg_authid') AS description \
|
||||
FROM pg_catalog.pg_roles r \
|
||||
WHERE r.rolname !~ '^pg_' \
|
||||
ORDER BY r.rolname",
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let roles = rows
|
||||
.iter()
|
||||
.map(|row| RoleInfo {
|
||||
name: row.get("rolname"),
|
||||
is_superuser: row.get("rolsuper"),
|
||||
can_login: row.get("rolcanlogin"),
|
||||
can_create_db: row.get("rolcreatedb"),
|
||||
can_create_role: row.get("rolcreaterole"),
|
||||
inherit: row.get("rolinherit"),
|
||||
is_replication: row.get("rolreplication"),
|
||||
connection_limit: row.get("rolconnlimit"),
|
||||
password_set: row.get("password_set"),
|
||||
valid_until: row.get("rolvaliduntil"),
|
||||
member_of: row.get("member_of"),
|
||||
members: row.get("members"),
|
||||
description: row.get("description"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(roles)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn create_role(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
params: CreateRoleParams,
|
||||
) -> TuskResult<()> {
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let mut sql = format!("CREATE ROLE {}", escape_ident(¶ms.name));
|
||||
|
||||
let mut options = Vec::new();
|
||||
options.push(if params.login { "LOGIN" } else { "NOLOGIN" });
|
||||
options.push(if params.superuser {
|
||||
"SUPERUSER"
|
||||
} else {
|
||||
"NOSUPERUSER"
|
||||
});
|
||||
options.push(if params.createdb {
|
||||
"CREATEDB"
|
||||
} else {
|
||||
"NOCREATEDB"
|
||||
});
|
||||
options.push(if params.createrole {
|
||||
"CREATEROLE"
|
||||
} else {
|
||||
"NOCREATEROLE"
|
||||
});
|
||||
options.push(if params.inherit {
|
||||
"INHERIT"
|
||||
} else {
|
||||
"NOINHERIT"
|
||||
});
|
||||
options.push(if params.replication {
|
||||
"REPLICATION"
|
||||
} else {
|
||||
"NOREPLICATION"
|
||||
});
|
||||
|
||||
if let Some(ref password) = params.password {
|
||||
options.push("PASSWORD");
|
||||
// Will be appended separately
|
||||
sql.push_str(&format!(" {}", options.join(" ")));
|
||||
sql.push_str(&format!(" '{}'", password.replace('\'', "''")));
|
||||
} else {
|
||||
sql.push_str(&format!(" {}", options.join(" ")));
|
||||
}
|
||||
|
||||
if let Some(limit) = params.connection_limit {
|
||||
sql.push_str(&format!(" CONNECTION LIMIT {}", limit));
|
||||
}
|
||||
|
||||
if let Some(ref valid_until) = params.valid_until {
|
||||
sql.push_str(&format!(
|
||||
" VALID UNTIL '{}'",
|
||||
valid_until.replace('\'', "''")
|
||||
));
|
||||
}
|
||||
|
||||
if !params.in_roles.is_empty() {
|
||||
let roles: Vec<String> = params.in_roles.iter().map(|r| escape_ident(r)).collect();
|
||||
sql.push_str(&format!(" IN ROLE {}", roles.join(", ")));
|
||||
}
|
||||
|
||||
sqlx::query(&sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn alter_role(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
params: AlterRoleParams,
|
||||
) -> TuskResult<()> {
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let mut options = Vec::new();
|
||||
|
||||
if let Some(login) = params.login {
|
||||
options.push(if login {
|
||||
"LOGIN".to_string()
|
||||
} else {
|
||||
"NOLOGIN".to_string()
|
||||
});
|
||||
}
|
||||
if let Some(superuser) = params.superuser {
|
||||
options.push(if superuser {
|
||||
"SUPERUSER".to_string()
|
||||
} else {
|
||||
"NOSUPERUSER".to_string()
|
||||
});
|
||||
}
|
||||
if let Some(createdb) = params.createdb {
|
||||
options.push(if createdb {
|
||||
"CREATEDB".to_string()
|
||||
} else {
|
||||
"NOCREATEDB".to_string()
|
||||
});
|
||||
}
|
||||
if let Some(createrole) = params.createrole {
|
||||
options.push(if createrole {
|
||||
"CREATEROLE".to_string()
|
||||
} else {
|
||||
"NOCREATEROLE".to_string()
|
||||
});
|
||||
}
|
||||
if let Some(inherit) = params.inherit {
|
||||
options.push(if inherit {
|
||||
"INHERIT".to_string()
|
||||
} else {
|
||||
"NOINHERIT".to_string()
|
||||
});
|
||||
}
|
||||
if let Some(replication) = params.replication {
|
||||
options.push(if replication {
|
||||
"REPLICATION".to_string()
|
||||
} else {
|
||||
"NOREPLICATION".to_string()
|
||||
});
|
||||
}
|
||||
if let Some(ref password) = params.password {
|
||||
options.push(format!("PASSWORD '{}'", password.replace('\'', "''")));
|
||||
}
|
||||
if let Some(limit) = params.connection_limit {
|
||||
options.push(format!("CONNECTION LIMIT {}", limit));
|
||||
}
|
||||
if let Some(ref valid_until) = params.valid_until {
|
||||
options.push(format!("VALID UNTIL '{}'", valid_until.replace('\'', "''")));
|
||||
}
|
||||
|
||||
if !options.is_empty() {
|
||||
let sql = format!(
|
||||
"ALTER ROLE {} {}",
|
||||
escape_ident(¶ms.name),
|
||||
options.join(" ")
|
||||
);
|
||||
sqlx::query(&sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
}
|
||||
|
||||
if let Some(ref new_name) = params.rename_to {
|
||||
let sql = format!(
|
||||
"ALTER ROLE {} RENAME TO {}",
|
||||
escape_ident(¶ms.name),
|
||||
escape_ident(new_name)
|
||||
);
|
||||
sqlx::query(&sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn drop_role(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
name: String,
|
||||
) -> TuskResult<()> {
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let sql = format!("DROP ROLE {}", escape_ident(&name));
|
||||
sqlx::query(&sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_table_privileges(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
schema: String,
|
||||
table: String,
|
||||
) -> TuskResult<Vec<TablePrivilege>> {
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT grantee, table_schema, table_name, privilege_type, \
|
||||
is_grantable = 'YES' AS is_grantable \
|
||||
FROM information_schema.role_table_grants \
|
||||
WHERE table_schema = $1 AND table_name = $2 \
|
||||
ORDER BY grantee, privilege_type",
|
||||
)
|
||||
.bind(&schema)
|
||||
.bind(&table)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let privileges = rows
|
||||
.iter()
|
||||
.map(|row| TablePrivilege {
|
||||
grantee: row.get("grantee"),
|
||||
table_schema: row.get("table_schema"),
|
||||
table_name: row.get("table_name"),
|
||||
privilege_type: row.get("privilege_type"),
|
||||
is_grantable: row.get("is_grantable"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(privileges)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn grant_revoke(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
params: GrantRevokeParams,
|
||||
) -> TuskResult<()> {
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let privs = params.privileges.join(", ");
|
||||
let object_type = params.object_type.to_uppercase();
|
||||
let object_ref = escape_ident(¶ms.object_name);
|
||||
let role_ref = escape_ident(¶ms.role_name);
|
||||
|
||||
let sql = if params.action.to_uppercase() == "GRANT" {
|
||||
let grant_option = if params.with_grant_option {
|
||||
" WITH GRANT OPTION"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!(
|
||||
"GRANT {} ON {} {} TO {}{}",
|
||||
privs, object_type, object_ref, role_ref, grant_option
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"REVOKE {} ON {} {} FROM {}",
|
||||
privs, object_type, object_ref, role_ref
|
||||
)
|
||||
};
|
||||
|
||||
sqlx::query(&sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn manage_role_membership(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
params: RoleMembershipParams,
|
||||
) -> TuskResult<()> {
|
||||
if state.is_read_only(&connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let role_ref = escape_ident(¶ms.role_name);
|
||||
let member_ref = escape_ident(¶ms.member_name);
|
||||
|
||||
let sql = if params.action.to_uppercase() == "GRANT" {
|
||||
format!("GRANT {} TO {}", role_ref, member_ref)
|
||||
} else {
|
||||
format!("REVOKE {} FROM {}", role_ref, member_ref)
|
||||
};
|
||||
|
||||
sqlx::query(&sql)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_sessions(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
) -> TuskResult<Vec<SessionInfo>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let sql = if flavor == DbFlavor::Greenplum {
|
||||
"SELECT pid, usename, datname, state, query, \
|
||||
query_start::text, NULL::text as wait_event_type, NULL::text as wait_event, \
|
||||
client_addr::text \
|
||||
FROM pg_stat_activity \
|
||||
WHERE datname IS NOT NULL \
|
||||
ORDER BY query_start DESC NULLS LAST"
|
||||
} else {
|
||||
"SELECT pid, usename, datname, state, query, \
|
||||
query_start::text, wait_event_type, wait_event, \
|
||||
client_addr::text \
|
||||
FROM pg_stat_activity \
|
||||
WHERE datname IS NOT NULL \
|
||||
ORDER BY query_start DESC NULLS LAST"
|
||||
};
|
||||
|
||||
let rows = sqlx::query(sql)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let sessions = rows
|
||||
.iter()
|
||||
.map(|row| SessionInfo {
|
||||
pid: row.get("pid"),
|
||||
usename: row.get("usename"),
|
||||
datname: row.get("datname"),
|
||||
state: row.get("state"),
|
||||
query: row.get("query"),
|
||||
query_start: row.get("query_start"),
|
||||
wait_event_type: row.get("wait_event_type"),
|
||||
wait_event: row.get("wait_event"),
|
||||
client_addr: row.get("client_addr"),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(sessions)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn cancel_query(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
pid: i32,
|
||||
) -> TuskResult<bool> {
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let row = sqlx::query("SELECT pg_cancel_backend($1)")
|
||||
.bind(pid)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(row.get::<bool, _>(0))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn terminate_backend(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
pid: i32,
|
||||
) -> TuskResult<bool> {
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
.get(&connection_id)
|
||||
.ok_or(TuskError::NotConnected(connection_id))?;
|
||||
|
||||
let row = sqlx::query("SELECT pg_terminate_backend($1)")
|
||||
.bind(pid)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
Ok(row.get::<bool, _>(0))
|
||||
}
|
||||
214
src-tauri/src/commands/memory.rs
Normal file
214
src-tauri/src/commands/memory.rs
Normal file
@@ -0,0 +1,214 @@
|
||||
//! Per-connection long-term memory for the chat agent (F1).
|
||||
//!
|
||||
//! Stored as a markdown file at `<app_data_dir>/memory/<connection_id>.md`.
|
||||
//! The agent appends notes via the `remember` tool; the user can view and edit
|
||||
//! the file in the Memory sidebar tab. The same content is injected into the
|
||||
//! LEARNED NOTES section of the system prompt every turn.
|
||||
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use chrono::Utc;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tauri::{AppHandle, Manager};
|
||||
|
||||
/// Soft cap on memory file size. Overflow drops oldest `## ts` blocks until fits.
|
||||
pub const MEMORY_BYTE_CAP: usize = 16 * 1024;
|
||||
|
||||
pub(crate) fn get_memory_path(
|
||||
app: &AppHandle,
|
||||
connection_id: &str,
|
||||
) -> TuskResult<PathBuf> {
|
||||
let dir = app
|
||||
.path()
|
||||
.app_data_dir()
|
||||
.map_err(|e| TuskError::Config(e.to_string()))?
|
||||
.join("memory");
|
||||
fs::create_dir_all(&dir)?;
|
||||
let safe = sanitize_connection_id(connection_id);
|
||||
Ok(dir.join(format!("{}.md", safe)))
|
||||
}
|
||||
|
||||
fn sanitize_connection_id(id: &str) -> String {
|
||||
id.chars()
|
||||
.map(|c| match c {
|
||||
'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' => c,
|
||||
_ => '_',
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) fn read_memory_core(
|
||||
app: &AppHandle,
|
||||
connection_id: &str,
|
||||
) -> TuskResult<String> {
|
||||
let path = get_memory_path(app, connection_id)?;
|
||||
if !path.exists() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
Ok(fs::read_to_string(&path)?)
|
||||
}
|
||||
|
||||
pub(crate) fn write_memory_core(
|
||||
app: &AppHandle,
|
||||
connection_id: &str,
|
||||
content: &str,
|
||||
) -> TuskResult<()> {
|
||||
let path = get_memory_path(app, connection_id)?;
|
||||
let trimmed = enforce_size_cap(content, MEMORY_BYTE_CAP);
|
||||
fs::write(&path, trimmed)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn append_memory_core(
|
||||
app: &AppHandle,
|
||||
connection_id: &str,
|
||||
note: &str,
|
||||
) -> TuskResult<()> {
|
||||
let trimmed_note = note.trim();
|
||||
if trimmed_note.is_empty() {
|
||||
return Err(TuskError::Custom("remember: note must not be empty".into()));
|
||||
}
|
||||
|
||||
let existing = read_memory_core(app, connection_id)?;
|
||||
let mut buf = String::new();
|
||||
if existing.is_empty() {
|
||||
buf.push_str("# Memory\n\n");
|
||||
} else {
|
||||
buf.push_str(&existing);
|
||||
if !buf.ends_with('\n') {
|
||||
buf.push('\n');
|
||||
}
|
||||
if !buf.ends_with("\n\n") {
|
||||
buf.push('\n');
|
||||
}
|
||||
}
|
||||
let ts = Utc::now().format("%Y-%m-%dT%H:%M:%SZ");
|
||||
buf.push_str(&format!("## {}\n{}\n", ts, trimmed_note));
|
||||
|
||||
let final_content = enforce_size_cap(&buf, MEMORY_BYTE_CAP);
|
||||
let path = get_memory_path(app, connection_id)?;
|
||||
fs::write(&path, final_content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Trim the file from the *oldest* note (top) until it fits within `cap` bytes.
|
||||
/// Always preserves the trailing notes (the most recent observations). Keeps
|
||||
/// the leading `# Memory\n\n` header if present.
|
||||
pub(crate) fn enforce_size_cap(content: &str, cap: usize) -> String {
|
||||
if content.len() <= cap {
|
||||
return content.to_string();
|
||||
}
|
||||
|
||||
let header = if content.starts_with("# Memory") {
|
||||
match content.find("\n## ") {
|
||||
Some(pos) => &content[..pos + 1],
|
||||
None => "# Memory\n\n",
|
||||
}
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
// Split into note blocks by "\n## " marker.
|
||||
// First block (after header) might lack the leading "## " — handle uniformly.
|
||||
let body_start = header.len();
|
||||
let body = &content[body_start..];
|
||||
|
||||
let mut blocks: Vec<&str> = Vec::new();
|
||||
let mut idx = 0;
|
||||
while idx < body.len() {
|
||||
// Find the next "\n## " starting at idx; if not found, the rest is one block.
|
||||
let rel = body[idx..].find("\n## ");
|
||||
match rel {
|
||||
Some(r) => {
|
||||
blocks.push(&body[idx..idx + r + 1]); // include trailing newline before next block
|
||||
idx = idx + r + 1; // start of "## "
|
||||
}
|
||||
None => {
|
||||
blocks.push(&body[idx..]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Drop blocks from the front until total fits.
|
||||
let mut current_size: usize = header.len() + blocks.iter().map(|b| b.len()).sum::<usize>();
|
||||
let mut start = 0usize;
|
||||
while current_size > cap && start < blocks.len() {
|
||||
current_size -= blocks[start].len();
|
||||
start += 1;
|
||||
}
|
||||
|
||||
let mut out = String::with_capacity(current_size);
|
||||
out.push_str(header);
|
||||
for b in &blocks[start..] {
|
||||
out.push_str(b);
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_memory(app: AppHandle, connection_id: String) -> TuskResult<String> {
|
||||
read_memory_core(&app, &connection_id)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn save_memory(
|
||||
app: AppHandle,
|
||||
connection_id: String,
|
||||
content: String,
|
||||
) -> TuskResult<()> {
|
||||
write_memory_core(&app, &connection_id, &content)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn cap_passthrough_under_limit() {
|
||||
let small = "# Memory\n\n## 2026-01-01T00:00:00Z\nshort note\n";
|
||||
assert_eq!(enforce_size_cap(small, MEMORY_BYTE_CAP), small);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_drops_oldest_blocks() {
|
||||
// 3 blocks of ~6KB each -> 18KB total > 16KB cap
|
||||
let block_body = "x".repeat(6000);
|
||||
let content = format!(
|
||||
"# Memory\n\n## 2026-01-01T00:00:00Z\n{body}\n## 2026-02-01T00:00:00Z\n{body}\n## 2026-03-01T00:00:00Z\n{body}\n",
|
||||
body = block_body
|
||||
);
|
||||
assert!(content.len() > MEMORY_BYTE_CAP);
|
||||
let trimmed = enforce_size_cap(&content, MEMORY_BYTE_CAP);
|
||||
assert!(trimmed.len() <= MEMORY_BYTE_CAP);
|
||||
// Most recent block must survive.
|
||||
assert!(trimmed.contains("2026-03-01T00:00:00Z"));
|
||||
// Oldest must be dropped.
|
||||
assert!(!trimmed.contains("2026-01-01T00:00:00Z"));
|
||||
// Header preserved.
|
||||
assert!(trimmed.starts_with("# Memory"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cap_keeps_only_latest_when_single_block_huge() {
|
||||
let block_body = "y".repeat(20_000);
|
||||
let content = format!(
|
||||
"# Memory\n\n## 2026-01-01T00:00:00Z\n{}\n",
|
||||
block_body
|
||||
);
|
||||
let trimmed = enforce_size_cap(&content, MEMORY_BYTE_CAP);
|
||||
// Even after dropping that single block we keep at least the header,
|
||||
// so the result is just the header (or close to it).
|
||||
assert!(trimmed.starts_with("# Memory"));
|
||||
assert!(trimmed.len() <= MEMORY_BYTE_CAP);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_strips_path_chars() {
|
||||
assert_eq!(sanitize_connection_id("abc/../etc"), "abc____etc");
|
||||
assert_eq!(
|
||||
sanitize_connection_id("cf9feefd-59ab-4a7c"),
|
||||
"cf9feefd-59ab-4a7c"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
pub mod ai;
|
||||
pub mod chat;
|
||||
pub mod chat_tools;
|
||||
pub mod connections;
|
||||
pub mod data;
|
||||
pub mod docker;
|
||||
pub mod export;
|
||||
pub mod history;
|
||||
pub mod lookup;
|
||||
pub mod management;
|
||||
pub mod memory;
|
||||
pub mod queries;
|
||||
pub mod saved_queries;
|
||||
pub mod schema;
|
||||
pub mod settings;
|
||||
pub mod snapshot;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::db::sql_guard::ensure_readonly_sql;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::query_result::QueryResult;
|
||||
use crate::state::AppState;
|
||||
use crate::state::{AppState, DbFlavor};
|
||||
use serde_json::Value;
|
||||
use sqlx::postgres::PgRow;
|
||||
use sqlx::{Column, Row, TypeInfo};
|
||||
@@ -81,6 +82,16 @@ pub async fn execute_query_core(
|
||||
sql: &str,
|
||||
) -> TuskResult<QueryResult> {
|
||||
let read_only = state.is_read_only(connection_id).await;
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
|
||||
if read_only {
|
||||
ensure_readonly_sql(sql)?;
|
||||
}
|
||||
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
return client.execute_query(sql, read_only).await;
|
||||
}
|
||||
|
||||
let pools = state.pools.read().await;
|
||||
let pool = pools
|
||||
@@ -106,7 +117,7 @@ pub async fn execute_query_core(
|
||||
.await
|
||||
.map_err(TuskError::Database)?
|
||||
};
|
||||
let execution_time_ms = start.elapsed().as_millis();
|
||||
let execution_time_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
let mut columns = Vec::new();
|
||||
let mut types = Vec::new();
|
||||
|
||||
@@ -12,12 +12,11 @@ fn get_saved_queries_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
|
||||
Ok(dir.join("saved_queries.json"))
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_saved_queries(
|
||||
app: AppHandle,
|
||||
search: Option<String>,
|
||||
pub(crate) async fn list_saved_queries_core(
|
||||
app: &AppHandle,
|
||||
search: Option<&str>,
|
||||
) -> TuskResult<Vec<SavedQuery>> {
|
||||
let path = get_saved_queries_path(&app)?;
|
||||
let path = get_saved_queries_path(app)?;
|
||||
if !path.exists() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
@@ -27,7 +26,7 @@ pub async fn list_saved_queries(
|
||||
let filtered: Vec<SavedQuery> = entries
|
||||
.into_iter()
|
||||
.filter(|e| {
|
||||
if let Some(ref s) = search {
|
||||
if let Some(s) = search {
|
||||
let lower = s.to_lowercase();
|
||||
e.name.to_lowercase().contains(&lower) || e.sql.to_lowercase().contains(&lower)
|
||||
} else {
|
||||
@@ -39,9 +38,8 @@ pub async fn list_saved_queries(
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn save_query(app: AppHandle, query: SavedQuery) -> TuskResult<()> {
|
||||
let path = get_saved_queries_path(&app)?;
|
||||
pub(crate) async fn save_query_core(app: &AppHandle, query: SavedQuery) -> TuskResult<()> {
|
||||
let path = get_saved_queries_path(app)?;
|
||||
let mut entries = if path.exists() {
|
||||
let data = fs::read_to_string(&path)?;
|
||||
serde_json::from_str::<Vec<SavedQuery>>(&data).unwrap_or_default()
|
||||
@@ -56,6 +54,19 @@ pub async fn save_query(app: AppHandle, query: SavedQuery) -> TuskResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_saved_queries(
|
||||
app: AppHandle,
|
||||
search: Option<String>,
|
||||
) -> TuskResult<Vec<SavedQuery>> {
|
||||
list_saved_queries_core(&app, search.as_deref()).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn save_query(app: AppHandle, query: SavedQuery) -> TuskResult<()> {
|
||||
save_query_core(&app, query).await
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn delete_saved_query(app: AppHandle, id: String) -> TuskResult<()> {
|
||||
let path = get_saved_queries_path(&app)?;
|
||||
|
||||
@@ -1,20 +1,53 @@
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::schema::{
|
||||
ColumnDetail, ColumnInfo, ConstraintInfo, ErdColumn, ErdData, ErdRelationship, ErdTable,
|
||||
IndexInfo, SchemaObject, TriggerInfo,
|
||||
ColumnDetail, ColumnInfo, ConstraintInfo, IndexInfo, SchemaObject, TriggerInfo,
|
||||
};
|
||||
use crate::state::{AppState, DbFlavor};
|
||||
use serde_json::Value;
|
||||
use sqlx::Row;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tauri::State;
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_databases(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
) -> TuskResult<Vec<String>> {
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
fn ch_string_literal(s: &str) -> String {
|
||||
let escaped = s.replace('\\', "\\\\").replace('\'', "\\'");
|
||||
format!("'{}'", escaped)
|
||||
}
|
||||
|
||||
fn ch_obj_string(obj: &serde_json::Map<String, Value>, key: &str) -> Option<String> {
|
||||
obj.get(key).and_then(|v| match v {
|
||||
Value::String(s) => Some(s.clone()),
|
||||
Value::Number(n) => Some(n.to_string()),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn ch_obj_i64(obj: &serde_json::Map<String, Value>, key: &str) -> Option<i64> {
|
||||
obj.get(key).and_then(|v| match v {
|
||||
Value::Number(n) => n.as_i64(),
|
||||
Value::String(s) => s.parse::<i64>().ok(),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn list_databases_core(state: &AppState, connection_id: &str) -> TuskResult<Vec<String>> {
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let rows = client
|
||||
.fetch_objects(
|
||||
"SELECT name FROM system.databases \
|
||||
WHERE name NOT IN ('system','INFORMATION_SCHEMA','information_schema') \
|
||||
ORDER BY name",
|
||||
)
|
||||
.await?;
|
||||
return Ok(rows
|
||||
.iter()
|
||||
.filter_map(|o| ch_obj_string(o, "name"))
|
||||
.collect());
|
||||
}
|
||||
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
"SELECT datname FROM pg_database \
|
||||
@@ -28,10 +61,24 @@ pub async fn list_databases(
|
||||
Ok(rows.iter().map(|r| r.get::<String, _>(0)).collect())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_databases(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
) -> TuskResult<Vec<String>> {
|
||||
list_databases_core(&state, &connection_id).await
|
||||
}
|
||||
|
||||
pub async fn list_schemas_core(state: &AppState, connection_id: &str) -> TuskResult<Vec<String>> {
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
// ClickHouse has no schema layer — surface the active database as a virtual schema.
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
return Ok(vec![client.database.clone()]);
|
||||
}
|
||||
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
let sql = if flavor == DbFlavor::Greenplum {
|
||||
"SELECT schema_name FROM information_schema.schemata \
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
|
||||
@@ -63,6 +110,29 @@ pub async fn list_tables_core(
|
||||
connection_id: &str,
|
||||
schema: &str,
|
||||
) -> TuskResult<Vec<SchemaObject>> {
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let escaped = ch_string_literal(schema);
|
||||
let sql = format!(
|
||||
"SELECT name, total_rows, total_bytes FROM system.tables \
|
||||
WHERE database = {} AND engine NOT LIKE '%View' \
|
||||
ORDER BY name",
|
||||
escaped
|
||||
);
|
||||
let rows = client.fetch_objects(&sql).await?;
|
||||
return Ok(rows
|
||||
.iter()
|
||||
.map(|o| SchemaObject {
|
||||
name: ch_obj_string(o, "name").unwrap_or_default(),
|
||||
object_type: "table".to_string(),
|
||||
schema: schema.to_string(),
|
||||
row_count: ch_obj_i64(o, "total_rows"),
|
||||
size_bytes: ch_obj_i64(o, "total_bytes"),
|
||||
})
|
||||
.collect());
|
||||
}
|
||||
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -107,6 +177,28 @@ pub async fn list_views(
|
||||
connection_id: String,
|
||||
schema: String,
|
||||
) -> TuskResult<Vec<SchemaObject>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
let client = state.get_ch_client(&connection_id).await?;
|
||||
let sql = format!(
|
||||
"SELECT name FROM system.tables \
|
||||
WHERE database = {} AND engine LIKE '%View' \
|
||||
ORDER BY name",
|
||||
ch_string_literal(&schema)
|
||||
);
|
||||
let rows = client.fetch_objects(&sql).await?;
|
||||
return Ok(rows
|
||||
.iter()
|
||||
.map(|o| SchemaObject {
|
||||
name: ch_obj_string(o, "name").unwrap_or_default(),
|
||||
object_type: "view".to_string(),
|
||||
schema: schema.clone(),
|
||||
row_count: None,
|
||||
size_bytes: None,
|
||||
})
|
||||
.collect());
|
||||
}
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -137,6 +229,11 @@ pub async fn list_functions(
|
||||
connection_id: String,
|
||||
schema: String,
|
||||
) -> TuskResult<Vec<SchemaObject>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
// ClickHouse functions are global, not schema-scoped — surface empty here.
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -167,6 +264,10 @@ pub async fn list_indexes(
|
||||
connection_id: String,
|
||||
schema: String,
|
||||
) -> TuskResult<Vec<SchemaObject>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -197,6 +298,10 @@ pub async fn list_sequences(
|
||||
connection_id: String,
|
||||
schema: String,
|
||||
) -> TuskResult<Vec<SchemaObject>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -227,6 +332,36 @@ pub async fn get_table_columns_core(
|
||||
schema: &str,
|
||||
table: &str,
|
||||
) -> TuskResult<Vec<ColumnInfo>> {
|
||||
let flavor = state.get_flavor(connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
let client = state.get_ch_client(connection_id).await?;
|
||||
let sql = format!(
|
||||
"SELECT name, type, default_expression, is_in_primary_key, comment, position \
|
||||
FROM system.columns WHERE database = {} AND table = {} \
|
||||
ORDER BY position",
|
||||
ch_string_literal(schema),
|
||||
ch_string_literal(table)
|
||||
);
|
||||
let rows = client.fetch_objects(&sql).await?;
|
||||
return Ok(rows
|
||||
.iter()
|
||||
.map(|o| {
|
||||
let type_str = ch_obj_string(o, "type").unwrap_or_default();
|
||||
let is_nullable = type_str.starts_with("Nullable(");
|
||||
ColumnInfo {
|
||||
name: ch_obj_string(o, "name").unwrap_or_default(),
|
||||
data_type: type_str,
|
||||
is_nullable,
|
||||
column_default: ch_obj_string(o, "default_expression"),
|
||||
ordinal_position: ch_obj_i64(o, "position").unwrap_or(0) as i32,
|
||||
character_maximum_length: None,
|
||||
is_primary_key: ch_obj_i64(o, "is_in_primary_key").unwrap_or(0) != 0,
|
||||
comment: ch_obj_string(o, "comment"),
|
||||
}
|
||||
})
|
||||
.collect());
|
||||
}
|
||||
|
||||
let pool = state.get_pool(connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -296,6 +431,10 @@ pub async fn get_table_constraints(
|
||||
schema: String,
|
||||
table: String,
|
||||
) -> TuskResult<Vec<ConstraintInfo>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -372,6 +511,10 @@ pub async fn get_table_indexes(
|
||||
schema: String,
|
||||
table: String,
|
||||
) -> TuskResult<Vec<IndexInfo>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -410,6 +553,25 @@ pub async fn get_completion_schema(
|
||||
connection_id: String,
|
||||
) -> TuskResult<HashMap<String, HashMap<String, Vec<String>>>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
let client = state.get_ch_client(&connection_id).await?;
|
||||
let sql = format!(
|
||||
"SELECT database, table, name FROM system.columns \
|
||||
WHERE database = {} \
|
||||
ORDER BY database, table, position",
|
||||
ch_string_literal(&client.database)
|
||||
);
|
||||
let rows = client.fetch_objects(&sql).await?;
|
||||
let mut result: HashMap<String, HashMap<String, Vec<String>>> = HashMap::new();
|
||||
for row in rows {
|
||||
let db = ch_obj_string(&row, "database").unwrap_or_default();
|
||||
let table = ch_obj_string(&row, "table").unwrap_or_default();
|
||||
let column = ch_obj_string(&row, "name").unwrap_or_default();
|
||||
result.entry(db).or_default().entry(table).or_default().push(column);
|
||||
}
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let sql = if flavor == DbFlavor::Greenplum {
|
||||
@@ -454,6 +616,19 @@ pub async fn get_column_details(
|
||||
table: String,
|
||||
) -> TuskResult<Vec<ColumnDetail>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
let columns = get_table_columns_core(&state, &connection_id, &schema, &table).await?;
|
||||
return Ok(columns
|
||||
.into_iter()
|
||||
.map(|c| ColumnDetail {
|
||||
column_name: c.name,
|
||||
data_type: c.data_type,
|
||||
is_nullable: c.is_nullable,
|
||||
column_default: c.column_default,
|
||||
is_identity: false,
|
||||
})
|
||||
.collect());
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let sql = if flavor == DbFlavor::Greenplum {
|
||||
@@ -500,6 +675,10 @@ pub async fn get_table_triggers(
|
||||
schema: String,
|
||||
table: String,
|
||||
) -> TuskResult<Vec<TriggerInfo>> {
|
||||
let flavor = state.get_flavor(&connection_id).await;
|
||||
if matches!(flavor, DbFlavor::ClickHouse) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
let rows = sqlx::query(
|
||||
@@ -547,127 +726,3 @@ pub async fn get_table_triggers(
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn get_schema_erd(
|
||||
state: State<'_, Arc<AppState>>,
|
||||
connection_id: String,
|
||||
schema: String,
|
||||
) -> TuskResult<ErdData> {
|
||||
let pool = state.get_pool(&connection_id).await?;
|
||||
|
||||
// Get all tables with columns
|
||||
let col_rows = sqlx::query(
|
||||
"SELECT \
|
||||
c.table_name, \
|
||||
c.column_name, \
|
||||
c.data_type, \
|
||||
c.is_nullable = 'YES' AS is_nullable, \
|
||||
COALESCE(( \
|
||||
SELECT true FROM pg_constraint con \
|
||||
JOIN pg_class cl ON cl.oid = con.conrelid \
|
||||
JOIN pg_namespace ns ON ns.oid = cl.relnamespace \
|
||||
WHERE con.contype = 'p' \
|
||||
AND ns.nspname = $1 AND cl.relname = c.table_name \
|
||||
AND EXISTS ( \
|
||||
SELECT 1 FROM unnest(con.conkey) k \
|
||||
JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = k \
|
||||
WHERE a.attname = c.column_name \
|
||||
) \
|
||||
LIMIT 1 \
|
||||
), false) AS is_pk \
|
||||
FROM information_schema.columns c \
|
||||
JOIN information_schema.tables t \
|
||||
ON t.table_schema = c.table_schema AND t.table_name = c.table_name \
|
||||
WHERE c.table_schema = $1 AND t.table_type = 'BASE TABLE' \
|
||||
ORDER BY c.table_name, c.ordinal_position",
|
||||
)
|
||||
.bind(&schema)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
// Build tables map
|
||||
let mut tables_map: HashMap<String, ErdTable> = HashMap::new();
|
||||
for row in &col_rows {
|
||||
let table_name: String = row.get(0);
|
||||
let entry = tables_map
|
||||
.entry(table_name.clone())
|
||||
.or_insert_with(|| ErdTable {
|
||||
schema: schema.clone(),
|
||||
name: table_name,
|
||||
columns: Vec::new(),
|
||||
});
|
||||
entry.columns.push(ErdColumn {
|
||||
name: row.get(1),
|
||||
data_type: row.get(2),
|
||||
is_nullable: row.get(3),
|
||||
is_primary_key: row.get(4),
|
||||
});
|
||||
}
|
||||
let tables: Vec<ErdTable> = tables_map.into_values().collect();
|
||||
|
||||
// Get all FK relationships
|
||||
let fk_rows = sqlx::query(
|
||||
"SELECT \
|
||||
c.conname AS constraint_name, \
|
||||
src_ns.nspname AS source_schema, \
|
||||
src_cl.relname AS source_table, \
|
||||
ARRAY( \
|
||||
SELECT a.attname FROM unnest(c.conkey) WITH ORDINALITY AS k(attnum, ord) \
|
||||
JOIN pg_attribute a ON a.attrelid = c.conrelid AND a.attnum = k.attnum \
|
||||
ORDER BY k.ord \
|
||||
)::text[] AS source_columns, \
|
||||
ref_ns.nspname AS target_schema, \
|
||||
ref_cl.relname AS target_table, \
|
||||
ARRAY( \
|
||||
SELECT a.attname FROM unnest(c.confkey) WITH ORDINALITY AS k(attnum, ord) \
|
||||
JOIN pg_attribute a ON a.attrelid = c.confrelid AND a.attnum = k.attnum \
|
||||
ORDER BY k.ord \
|
||||
)::text[] AS target_columns, \
|
||||
CASE c.confupdtype \
|
||||
WHEN 'a' THEN 'NO ACTION' \
|
||||
WHEN 'r' THEN 'RESTRICT' \
|
||||
WHEN 'c' THEN 'CASCADE' \
|
||||
WHEN 'n' THEN 'SET NULL' \
|
||||
WHEN 'd' THEN 'SET DEFAULT' \
|
||||
END AS update_rule, \
|
||||
CASE c.confdeltype \
|
||||
WHEN 'a' THEN 'NO ACTION' \
|
||||
WHEN 'r' THEN 'RESTRICT' \
|
||||
WHEN 'c' THEN 'CASCADE' \
|
||||
WHEN 'n' THEN 'SET NULL' \
|
||||
WHEN 'd' THEN 'SET DEFAULT' \
|
||||
END AS delete_rule \
|
||||
FROM pg_constraint c \
|
||||
JOIN pg_class src_cl ON src_cl.oid = c.conrelid \
|
||||
JOIN pg_namespace src_ns ON src_ns.oid = src_cl.relnamespace \
|
||||
JOIN pg_class ref_cl ON ref_cl.oid = c.confrelid \
|
||||
JOIN pg_namespace ref_ns ON ref_ns.oid = ref_cl.relnamespace \
|
||||
WHERE c.contype = 'f' AND src_ns.nspname = $1 \
|
||||
ORDER BY c.conname",
|
||||
)
|
||||
.bind(&schema)
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let relationships: Vec<ErdRelationship> = fk_rows
|
||||
.iter()
|
||||
.map(|r| ErdRelationship {
|
||||
constraint_name: r.get(0),
|
||||
source_schema: r.get(1),
|
||||
source_table: r.get(2),
|
||||
source_columns: r.get(3),
|
||||
target_schema: r.get(4),
|
||||
target_table: r.get(5),
|
||||
target_columns: r.get(6),
|
||||
update_rule: r.get(7),
|
||||
delete_rule: r.get(8),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ErdData {
|
||||
tables,
|
||||
relationships,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::mcp;
|
||||
use crate::models::settings::{AppSettings, DockerHost, McpStatus};
|
||||
use crate::models::settings::{AppSettings, McpStatus};
|
||||
use crate::state::AppState;
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
@@ -36,15 +36,6 @@ pub async fn save_app_settings(
|
||||
let data = serde_json::to_string_pretty(&settings)?;
|
||||
fs::write(&path, data)?;
|
||||
|
||||
// Apply docker host setting
|
||||
{
|
||||
let mut docker_host = state.docker_host.write().await;
|
||||
*docker_host = match settings.docker.host {
|
||||
DockerHost::Remote => settings.docker.remote_url.clone(),
|
||||
DockerHost::Local => None,
|
||||
};
|
||||
}
|
||||
|
||||
// Apply MCP setting: restart or stop
|
||||
let is_running = *state.mcp_running.read().await;
|
||||
|
||||
|
||||
@@ -1,362 +0,0 @@
|
||||
use crate::commands::ai::fetch_foreign_keys_raw;
|
||||
use crate::commands::data::bind_json_value;
|
||||
use crate::commands::queries::pg_value_to_json;
|
||||
use crate::error::{TuskError, TuskResult};
|
||||
use crate::models::snapshot::{
|
||||
CreateSnapshotParams, RestoreSnapshotParams, Snapshot, SnapshotMetadata, SnapshotProgress,
|
||||
SnapshotTableData, SnapshotTableMeta,
|
||||
};
|
||||
use crate::state::AppState;
|
||||
use crate::utils::{escape_ident, topological_sort_tables};
|
||||
use serde_json::Value;
|
||||
use sqlx::{Column, Row, TypeInfo};
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
use tauri::{AppHandle, Emitter, Manager, State};
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn create_snapshot(
|
||||
app: AppHandle,
|
||||
state: State<'_, Arc<AppState>>,
|
||||
params: CreateSnapshotParams,
|
||||
snapshot_id: String,
|
||||
file_path: String,
|
||||
) -> TuskResult<SnapshotMetadata> {
|
||||
let pool = state.get_pool(¶ms.connection_id).await?;
|
||||
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "preparing".to_string(),
|
||||
percent: 5,
|
||||
message: "Preparing snapshot...".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
let mut target_tables: Vec<(String, String)> = params
|
||||
.tables
|
||||
.iter()
|
||||
.map(|t| (t.schema.clone(), t.table.clone()))
|
||||
.collect();
|
||||
|
||||
// Fetch FK info once — used for both dependency expansion and topological sort
|
||||
let fk_rows = fetch_foreign_keys_raw(&pool).await?;
|
||||
|
||||
if params.include_dependencies {
|
||||
for fk in &fk_rows {
|
||||
if target_tables
|
||||
.iter()
|
||||
.any(|(s, t)| s == &fk.schema && t == &fk.table)
|
||||
{
|
||||
let parent = (fk.ref_schema.clone(), fk.ref_table.clone());
|
||||
if !target_tables.contains(&parent) {
|
||||
target_tables.push(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FK-based topological sort
|
||||
let fk_edges: Vec<(String, String, String, String)> = fk_rows
|
||||
.iter()
|
||||
.map(|fk| {
|
||||
(
|
||||
fk.schema.clone(),
|
||||
fk.table.clone(),
|
||||
fk.ref_schema.clone(),
|
||||
fk.ref_table.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let sorted_tables = topological_sort_tables(&fk_edges, &target_tables);
|
||||
|
||||
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
|
||||
sqlx::query("SET TRANSACTION READ ONLY")
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let total_tables = sorted_tables.len();
|
||||
let mut snapshot_tables: Vec<SnapshotTableData> = Vec::new();
|
||||
let mut table_metas: Vec<SnapshotTableMeta> = Vec::new();
|
||||
let mut total_rows: u64 = 0;
|
||||
|
||||
for (i, (schema, table)) in sorted_tables.iter().enumerate() {
|
||||
let percent = (10 + (i * 80 / total_tables.max(1))).min(90) as u8;
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "exporting".to_string(),
|
||||
percent,
|
||||
message: format!("Exporting {}.{}...", schema, table),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
let qualified = format!("{}.{}", escape_ident(schema), escape_ident(table));
|
||||
let sql = format!("SELECT * FROM {}", qualified);
|
||||
let rows = sqlx::query(&sql)
|
||||
.fetch_all(&mut *tx)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
let mut columns = Vec::new();
|
||||
let mut column_types = Vec::new();
|
||||
|
||||
if let Some(first) = rows.first() {
|
||||
for col in first.columns() {
|
||||
columns.push(col.name().to_string());
|
||||
column_types.push(col.type_info().name().to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let data_rows: Vec<Vec<Value>> = rows
|
||||
.iter()
|
||||
.map(|row| {
|
||||
(0..columns.len())
|
||||
.map(|i| pg_value_to_json(row, i))
|
||||
.collect()
|
||||
})
|
||||
.collect();
|
||||
|
||||
let row_count = data_rows.len() as u64;
|
||||
total_rows += row_count;
|
||||
|
||||
table_metas.push(SnapshotTableMeta {
|
||||
schema: schema.clone(),
|
||||
table: table.clone(),
|
||||
row_count,
|
||||
columns: columns.clone(),
|
||||
column_types: column_types.clone(),
|
||||
});
|
||||
|
||||
snapshot_tables.push(SnapshotTableData {
|
||||
schema: schema.clone(),
|
||||
table: table.clone(),
|
||||
columns,
|
||||
column_types,
|
||||
rows: data_rows,
|
||||
});
|
||||
}
|
||||
|
||||
tx.rollback().await.map_err(TuskError::Database)?;
|
||||
|
||||
let metadata = SnapshotMetadata {
|
||||
id: snapshot_id.clone(),
|
||||
name: params.name.clone(),
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
connection_name: String::new(),
|
||||
database: String::new(),
|
||||
tables: table_metas,
|
||||
total_rows,
|
||||
file_size_bytes: 0,
|
||||
version: 1,
|
||||
};
|
||||
|
||||
let snapshot = Snapshot {
|
||||
metadata: metadata.clone(),
|
||||
tables: snapshot_tables,
|
||||
};
|
||||
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "saving".to_string(),
|
||||
percent: 95,
|
||||
message: "Saving snapshot file...".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
let json = serde_json::to_string_pretty(&snapshot)?;
|
||||
let file_size = json.len() as u64;
|
||||
fs::write(&file_path, json)?;
|
||||
|
||||
let mut final_metadata = metadata;
|
||||
final_metadata.file_size_bytes = file_size;
|
||||
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "done".to_string(),
|
||||
percent: 100,
|
||||
message: "Snapshot created successfully".to_string(),
|
||||
detail: Some(format!("{} rows, {} tables", total_rows, total_tables)),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(final_metadata)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn restore_snapshot(
|
||||
app: AppHandle,
|
||||
state: State<'_, Arc<AppState>>,
|
||||
params: RestoreSnapshotParams,
|
||||
snapshot_id: String,
|
||||
) -> TuskResult<u64> {
|
||||
if state.is_read_only(¶ms.connection_id).await {
|
||||
return Err(TuskError::ReadOnly);
|
||||
}
|
||||
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "reading".to_string(),
|
||||
percent: 5,
|
||||
message: "Reading snapshot file...".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
let data = fs::read_to_string(¶ms.file_path)?;
|
||||
let snapshot: Snapshot = serde_json::from_str(&data)?;
|
||||
|
||||
let pool = state.get_pool(¶ms.connection_id).await?;
|
||||
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
|
||||
|
||||
sqlx::query("SET CONSTRAINTS ALL DEFERRED")
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
|
||||
// TRUNCATE in reverse order (children first)
|
||||
if params.truncate_before_restore {
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "truncating".to_string(),
|
||||
percent: 15,
|
||||
message: "Truncating existing data...".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
);
|
||||
|
||||
for table_data in snapshot.tables.iter().rev() {
|
||||
let qualified = format!(
|
||||
"{}.{}",
|
||||
escape_ident(&table_data.schema),
|
||||
escape_ident(&table_data.table)
|
||||
);
|
||||
let truncate_sql = format!("TRUNCATE {} CASCADE", qualified);
|
||||
sqlx::query(&truncate_sql)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(TuskError::Database)?;
|
||||
}
|
||||
}
|
||||
|
||||
// INSERT in forward order (parents first)
|
||||
let total_tables = snapshot.tables.len();
|
||||
let mut total_inserted: u64 = 0;
|
||||
|
||||
for (i, table_data) in snapshot.tables.iter().enumerate() {
|
||||
if table_data.columns.is_empty() || table_data.rows.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let percent = (20 + (i * 75 / total_tables.max(1))).min(95) as u8;
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "inserting".to_string(),
|
||||
percent,
|
||||
message: format!("Restoring {}.{}...", table_data.schema, table_data.table),
|
||||
detail: Some(format!("{} rows", table_data.rows.len())),
|
||||
},
|
||||
);
|
||||
|
||||
let qualified = format!(
|
||||
"{}.{}",
|
||||
escape_ident(&table_data.schema),
|
||||
escape_ident(&table_data.table)
|
||||
);
|
||||
let col_list: Vec<String> = table_data.columns.iter().map(|c| escape_ident(c)).collect();
|
||||
let placeholders: Vec<String> = (1..=table_data.columns.len())
|
||||
.map(|i| format!("${}", i))
|
||||
.collect();
|
||||
|
||||
let sql = format!(
|
||||
"INSERT INTO {} ({}) VALUES ({})",
|
||||
qualified,
|
||||
col_list.join(", "),
|
||||
placeholders.join(", ")
|
||||
);
|
||||
|
||||
// Chunked insert
|
||||
for row in &table_data.rows {
|
||||
let mut query = sqlx::query(&sql);
|
||||
for val in row {
|
||||
query = bind_json_value(query, val);
|
||||
}
|
||||
query.execute(&mut *tx).await.map_err(TuskError::Database)?;
|
||||
total_inserted += 1;
|
||||
}
|
||||
}
|
||||
|
||||
tx.commit().await.map_err(TuskError::Database)?;
|
||||
|
||||
let _ = app.emit(
|
||||
"snapshot-progress",
|
||||
SnapshotProgress {
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
stage: "done".to_string(),
|
||||
percent: 100,
|
||||
message: "Restore completed successfully".to_string(),
|
||||
detail: Some(format!("{} rows restored", total_inserted)),
|
||||
},
|
||||
);
|
||||
|
||||
state.invalidate_schema_cache(¶ms.connection_id).await;
|
||||
|
||||
Ok(total_inserted)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn list_snapshots(app: AppHandle) -> TuskResult<Vec<SnapshotMetadata>> {
|
||||
let dir = app
|
||||
.path()
|
||||
.app_data_dir()
|
||||
.map_err(|e| TuskError::Config(e.to_string()))?
|
||||
.join("snapshots");
|
||||
|
||||
if !dir.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut snapshots = Vec::new();
|
||||
|
||||
for entry in fs::read_dir(&dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().map(|e| e == "json").unwrap_or(false) {
|
||||
if let Ok(data) = fs::read_to_string(&path) {
|
||||
if let Ok(snapshot) = serde_json::from_str::<Snapshot>(&data) {
|
||||
let mut meta = snapshot.metadata;
|
||||
meta.file_size_bytes = entry.metadata().map(|m| m.len()).unwrap_or(0);
|
||||
snapshots.push(meta);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
|
||||
Ok(snapshots)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn read_snapshot_metadata(file_path: String) -> TuskResult<SnapshotMetadata> {
|
||||
let data = fs::read_to_string(&file_path)?;
|
||||
let snapshot: Snapshot = serde_json::from_str(&data)?;
|
||||
let mut meta = snapshot.metadata;
|
||||
meta.file_size_bytes = fs::metadata(&file_path).map(|m| m.len()).unwrap_or(0);
|
||||
Ok(meta)
|
||||
}
|
||||
Reference in New Issue
Block a user