// tusk/src-tauri/src/commands/ai.rs
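//! AI-powered commands: Ollama settings and chat, natural-language SQL
//! generation/explanation/fixing, schema context building, data validation,
//! test data generation, and the index advisor.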

use crate::commands::data::bind_json_value;
use crate::commands::queries::pg_value_to_json;
use crate::error::{TuskError, TuskResult};
use crate::models::ai::{
AiProvider, AiSettings, DataGenProgress, GenerateDataParams, GeneratedDataPreview,
GeneratedTableData, IndexAdvisorReport, IndexRecommendation, IndexStats, OllamaChatMessage,
OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse, SlowQuery, TableStats,
ValidationRule, ValidationStatus,
};
use crate::state::AppState;
use crate::utils::{escape_ident, topological_sort_tables};
use serde_json::Value;
use sqlx::{Column, Row};
use std::collections::{BTreeMap, HashMap};
use std::fs;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tauri::{AppHandle, Emitter, Manager, State};
const MAX_RETRIES: u32 = 2;
const RETRY_DELAY_MS: u64 = 1000;
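/// Shared HTTP client, lazily initialized on first use. The short connect
/// timeout keeps "Ollama is unreachable" failures fast, while the long overall
/// timeout accommodates slow local model inference.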
fn http_client() -> &'static reqwest::Client {
use std::sync::LazyLock;
static CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
reqwest::Client::builder()
.connect_timeout(Duration::from_secs(5))
.timeout(Duration::from_secs(300))
.build()
.unwrap_or_default()
});
&CLIENT
}
fn get_ai_settings_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
let dir = app
.path()
.app_data_dir()
.map_err(|e| TuskError::Config(e.to_string()))?;
fs::create_dir_all(&dir)?;
Ok(dir.join("ai_settings.json"))
}
#[tauri::command]
pub async fn get_ai_settings(app: AppHandle) -> TuskResult<AiSettings> {
let path = get_ai_settings_path(&app)?;
if !path.exists() {
return Ok(AiSettings::default());
}
let data = fs::read_to_string(&path)?;
let settings: AiSettings = serde_json::from_str(&data)?;
Ok(settings)
}
#[tauri::command]
pub async fn save_ai_settings(
app: AppHandle,
state: State<'_, Arc<AppState>>,
settings: AiSettings,
) -> TuskResult<()> {
let path = get_ai_settings_path(&app)?;
let data = serde_json::to_string_pretty(&settings)?;
fs::write(&path, data)?;
// Update in-memory cache
*state.ai_settings.write().await = Some(settings);
Ok(())
}
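/// Lists locally available models via the Ollama `/api/tags` endpoint.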
#[tauri::command]
pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaModel>> {
let url = format!("{}/api/tags", ollama_url.trim_end_matches('/'));
let resp =
http_client().get(&url).send().await.map_err(|e| {
TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e))
})?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Ollama error ({}): {}",
status, body
)));
}
let tags: OllamaTagsResponse = resp
.json()
.await
.map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?;
Ok(tags.models)
}
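/// Runs `f` up to `MAX_RETRIES` times, sleeping `RETRY_DELAY_MS` between
/// attempts, and returns the last error if every attempt fails.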
async fn call_ai_with_retry<F, Fut, T>(
_settings: &AiSettings,
operation: &str,
f: F,
) -> TuskResult<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = TuskResult<T>>,
{
let mut last_error = None;
for attempt in 0..MAX_RETRIES {
match f().await {
Ok(result) => return Ok(result),
Err(e) => {
last_error = Some(e);
if attempt < MAX_RETRIES - 1 {
log::warn!(
"{} failed (attempt {}/{}), retrying in {}ms...",
operation,
attempt + 1,
MAX_RETRIES,
RETRY_DELAY_MS
);
tokio::time::sleep(Duration::from_millis(RETRY_DELAY_MS)).await;
}
}
}
}
Err(last_error.unwrap_or_else(|| {
TuskError::Ai(format!(
"{} failed after {} attempts",
operation, MAX_RETRIES
))
}))
}
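/// Loads AI settings from the in-memory cache, falling back to the settings
/// file on disk (erroring if none exists) and caching the result for future
/// calls.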
async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskResult<AiSettings> {
// Try in-memory cache first
if let Some(cached) = state.ai_settings.read().await.clone() {
return Ok(cached);
}
// Fallback to disk
let path = get_ai_settings_path(app)?;
if !path.exists() {
return Err(TuskError::Ai(
"No AI model selected. Open AI settings to choose a model.".to_string(),
));
}
let data = fs::read_to_string(&path)?;
let settings: AiSettings = serde_json::from_str(&data)?;
// Populate cache for future calls
*state.ai_settings.write().await = Some(settings.clone());
Ok(settings)
}
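/// Sends a system + user message pair to the configured model and returns the
/// assistant's reply. Only the Ollama provider is currently implemented.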
async fn call_ollama_chat(
app: &AppHandle,
state: &AppState,
system_prompt: String,
user_content: String,
) -> TuskResult<String> {
let settings = load_ai_settings(app, state).await?;
if settings.model.is_empty() {
return Err(TuskError::Ai(
"No AI model selected. Open AI settings to choose a model.".to_string(),
));
}
if settings.provider != AiProvider::Ollama {
return Err(TuskError::Ai(format!(
"Provider {:?} not implemented yet",
settings.provider
)));
}
let model = settings.model.clone();
let url = format!("{}/api/chat", settings.ollama_url.trim_end_matches('/'));
let request = OllamaChatRequest {
model: model.clone(),
messages: vec![
OllamaChatMessage {
role: "system".to_string(),
content: system_prompt,
},
OllamaChatMessage {
role: "user".to_string(),
content: user_content,
},
],
stream: false,
};
call_ai_with_retry(&settings, "Ollama request", || {
let url = url.clone();
let request = request.clone();
async move {
let resp = http_client()
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| {
TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", url, e))
})?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Ollama error ({}): {}",
status, body
)));
}
let chat_resp: OllamaChatResponse = resp
.json()
.await
.map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?;
Ok(chat_resp.message.content)
}
})
.await
}
// ---------------------------------------------------------------------------
// SQL generation
// ---------------------------------------------------------------------------
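/// Translates a natural-language request into a single executable PostgreSQL
/// query, grounded in the live schema context.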
#[tauri::command]
pub async fn generate_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
prompt: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are an expert PostgreSQL query generator. You receive a database schema and a natural \
language request. Output ONLY a valid, executable PostgreSQL SQL query.\n\
\n\
OUTPUT FORMAT:\n\
- Raw SQL only. No explanations, no markdown code fences (```), no comments, no preamble.\n\
- The output must be directly executable in psql.\n\
- For complex queries use readable formatting with line breaks and indentation.\n\
\n\
CRITICAL RULES:\n\
1. ONLY reference tables and columns that exist in the schema. Never invent names.\n\
2. Use the FOREIGN KEY information to determine correct JOIN conditions.\n\
3. Use LEFT JOIN when the FK column is nullable or the relationship is optional; \
INNER JOIN when both sides must exist.\n\
4. Every non-aggregated column in SELECT must appear in GROUP BY.\n\
5. Use COALESCE for nullable columns in aggregations: COALESCE(SUM(x), 0).\n\
6. For enum columns, use ONLY the values listed in the ENUM TYPES section.\n\
7. Use IS NULL / IS NOT NULL for null checks — never = NULL or != NULL.\n\
8. Add LIMIT when the user asks for \"top N\", \"first N\", \"latest N\", etc.\n\
9. Qualify column names with table alias when the query involves multiple tables.\n\
\n\
SEMANTIC RULES (very important):\n\
- When a table has both actual_* and planned_* columns (e.g. actual_start vs planned_start), \
they represent DIFFERENT concepts: planned = future estimate, actual = what really happened. \
NEVER mix them with COALESCE unless the user explicitly requests a fallback.\n\
- For time-based calculations involving real events (\"how long did X take\", \"average time between\"), \
use ONLY actual/factual timestamps (actual_*, started_at, completed_at, ended_at). \
Filter out NULL values with WHERE instead of falling back to planned timestamps.\n\
- Planned timestamps (planned_*, scheduled_*, estimated_*) should ONLY be used when the user \
asks about plans, schedules, SLA, or compares plan vs fact.\n\
- When computing durations or averages, always filter out rows where any involved timestamp \
is NULL rather than substituting with unrelated defaults.\n\
- Pay attention to column descriptions/comments in the schema — they reveal business semantics \
that are critical for correct queries.\n\
\n\
TYPE RULES:\n\
- timestamp - timestamp = interval. For seconds: EXTRACT(EPOCH FROM (ts1 - ts2)).\n\
- interval cannot be cast to numeric directly; use EXTRACT(EPOCH FROM interval).\n\
- UNION/UNION ALL requires matching column count and compatible types; cast enums to text.\n\
- Use ::type for PostgreSQL-style casts.\n\
- For array columns use ANY, ALL, @>, <@ operators.\n\
- For JSONB columns use ->, ->>, #>, jsonb_extract_path.\n\
\n\
COMMON PATTERNS:\n\
- FIRST/LAST per group: to find MIN(started_at) per trip, use \
\"SELECT trip_id, MIN(started_at) FROM t GROUP BY trip_id\". \
NEVER put the aggregated column (started_at) into GROUP BY — that defeats the aggregation \
and returns every row separately instead of one per group.\n\
- TOP-1 per group with extra columns: use DISTINCT ON (group_col) ... ORDER BY group_col, sort_col \
or a subquery with ROW_NUMBER() OVER (PARTITION BY group_col ORDER BY sort_col) = 1.\n\
- For \"time from A to B\" calculations, ensure both timestamps are NOT NULL with WHERE filters; \
never use COALESCE to mix planned and actual timestamps.\n\
\n\
BEST PRACTICES:\n\
- Use ILIKE for case-insensitive text search, LIKE for case-sensitive.\n\
- Use EXISTS instead of IN for subquery existence checks.\n\
- Use CTE (WITH ... AS) for complex multi-step logic.\n\
- Use window functions (ROW_NUMBER, RANK, LAG, LEAD, SUM OVER) for ranking and running totals.\n\
- Use date_trunc('period', column) for time-based grouping.\n\
- Use generate_series() for creating ranges.\n\
- Use string_agg(col, ', ') for concatenating grouped values.\n\
- Use FILTER (WHERE ...) for conditional aggregation instead of CASE inside aggregate.\n\
\n\
{}\n",
schema_text
);
let raw = call_ollama_chat(&app, &state, system_prompt, prompt).await?;
Ok(clean_sql_response(&raw))
}
// ---------------------------------------------------------------------------
// SQL explanation
// ---------------------------------------------------------------------------
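/// Produces a structured, human-readable explanation of the given SQL, using
/// the schema context for table and column semantics.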
#[tauri::command]
pub async fn explain_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a PostgreSQL expert. Explain the given SQL query clearly and concisely.\n\
\n\
Structure your explanation as:\n\
1. **Summary** — one sentence describing what the query returns in business terms.\n\
2. **Step-by-step breakdown** — explain tables accessed, joins, filters, aggregations, \
subqueries, and sorting. Use bullet points.\n\
3. **Notes** — mention edge cases, potential issues, or performance concerns if any.\n\
\n\
Use the database schema below to understand table relationships and column meanings.\n\
Keep the explanation short; avoid restating the SQL verbatim.\n\
\n\
IMPORTANT: If you notice semantic issues (e.g. mixing planned_* and actual_* timestamps \
with COALESCE, comparing unrelated columns, missing NULL filters on nullable timestamps), \
mention them in the Notes section as potential problems.\n\
\n\
{}\n",
schema_text
);
call_ollama_chat(&app, &state, system_prompt, sql).await
}
// ---------------------------------------------------------------------------
// SQL error fixing
// ---------------------------------------------------------------------------
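/// Repairs a failing SQL query given its error message, changing only what is
/// needed while preserving the query's intent.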
#[tauri::command]
pub async fn fix_sql_error(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
error_message: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a PostgreSQL expert debugger. You receive a SQL query and the error it produced. \
Fix the query so it executes correctly.\n\
\n\
OUTPUT FORMAT:\n\
- Raw SQL only. No explanations, no markdown code fences (```), no comments.\n\
- The output must be directly executable.\n\
\n\
DIAGNOSTIC CHECKLIST:\n\
- Column/table does not exist → check the schema for correct names and spelling.\n\
- Column is ambiguous → qualify with table name or alias.\n\
- Must appear in GROUP BY → add missing non-aggregated columns to GROUP BY.\n\
- Type mismatch → add appropriate casts (::text, ::integer, etc.).\n\
- Permission denied → wrap in a read-only transaction if needed.\n\
- Syntax error → correct PostgreSQL syntax (check commas, parentheses, keywords).\n\
- Subquery returns more than one row → use IN, ANY, or add LIMIT 1.\n\
- Division by zero → wrap divisor with NULLIF(x, 0).\n\
\n\
ONLY use tables and columns from the schema below. Never invent names.\n\
Preserve the original intent of the query; change only what is necessary to fix the error.\n\
\n\
{}\n",
schema_text
);
let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);
let raw = call_ollama_chat(&app, &state, system_prompt, user_content).await?;
Ok(clean_sql_response(&raw))
}
// ---------------------------------------------------------------------------
// Schema context builder
// ---------------------------------------------------------------------------
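/// Builds a compact, LLM-oriented text description of the database: version,
/// enum types, tables with annotated columns (PK/NOT NULL/FK/defaults,
/// comments, pseudo-enum values, JSONB keys), unique constraints, and a
/// foreign-key summary. The result is cached per connection.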
pub(crate) async fn build_schema_context(
state: &AppState,
connection_id: &str,
) -> TuskResult<String> {
// Check cache first
if let Some(cached) = state.get_schema_cache(connection_id).await {
return Ok(cached);
}
let pool = state.get_pool(connection_id).await?;
// Run all metadata queries in parallel for speed
let (
version_res,
col_res,
fk_res,
enum_res,
tbl_comment_res,
col_comment_res,
unique_res,
varchar_res,
jsonb_res,
) = tokio::join!(
sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool),
fetch_columns(&pool),
fetch_foreign_keys_raw(&pool),
fetch_enum_types(&pool),
fetch_table_comments(&pool),
fetch_column_comments(&pool),
fetch_unique_constraints(&pool),
fetch_varchar_values(&pool),
fetch_jsonb_keys(&pool),
);
let version = version_res.map_err(TuskError::Database)?;
let col_rows = col_res?;
let fk_rows = fk_res?;
let enum_map = enum_res?;
let tbl_comments = tbl_comment_res?;
let col_comments = col_comment_res?;
let unique_constraints = unique_res?;
let varchar_values = varchar_res.unwrap_or_default();
let jsonb_keys = jsonb_res.unwrap_or_default();
// -- Build FK inline lookup: (schema, table, column) -> "ref_schema.ref_table(ref_col)" --
let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new();
let mut fk_lines: Vec<String> = Vec::new();
for fk in &fk_rows {
let line = format!(
"FK: {}.{}({}) -> {}.{}({})",
fk.schema,
fk.table,
fk.columns.join(", "),
fk.ref_schema,
fk.ref_table,
fk.ref_columns.join(", ")
);
fk_lines.push(line);
// For single-column FKs, enable inline annotation on column
if fk.columns.len() == 1 && fk.ref_columns.len() == 1 {
fk_inline.insert(
(fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()),
format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]),
);
}
}
// -- Build unique constraint lookup: (schema, table) -> Vec<column_list_string> --
let mut unique_map: HashMap<(String, String), Vec<String>> = HashMap::new();
for (schema, table, cols) in &unique_constraints {
unique_map
.entry((schema.clone(), table.clone()))
.or_default()
.push(cols.join(", "));
}
// -- Format output --
let mut output: Vec<String> = Vec::new();
// 1. PostgreSQL version (short form)
let short_version = version
.split_whitespace()
.take(2)
.collect::<Vec<_>>()
.join(" ");
output.push(format!("DATABASE SCHEMA ({})", short_version));
output.push(String::new());
// 2. Enum types
if !enum_map.is_empty() {
output.push("ENUM TYPES:".to_string());
for (type_name, values) in &enum_map {
let vals_str = values
.iter()
.map(|v| format!("'{}'", v))
.collect::<Vec<_>>()
.join(", ");
output.push(format!(" {} = [{}]", type_name, vals_str));
}
output.push(String::new());
}
// 3. Tables with columns
output.push("TABLES:".to_string());
// Group columns by schema.table preserving order
let mut tables: BTreeMap<String, Vec<ColumnInfo>> = BTreeMap::new();
for ci in &col_rows {
let key = format!("{}.{}", ci.schema, ci.table);
tables.entry(key).or_default().push(ci.clone());
}
for (full_name, columns) in &tables {
// Table header with optional comment
let tbl_comment = tbl_comments.get(full_name).map(|c| c.as_str());
match tbl_comment {
Some(comment) => output.push(format!("\nTABLE {} -- {}", full_name, comment)),
None => output.push(format!("\nTABLE {}", full_name)),
}
// Columns
for ci in columns {
let mut parts: Vec<String> = vec![ci.column.clone(), ci.data_type.clone()];
if ci.is_pk {
parts.push("PK".to_string());
}
if ci.not_null && !ci.is_pk {
parts.push("NOT NULL".to_string());
}
// Inline FK reference
if let Some(ref_target) =
fk_inline.get(&(ci.schema.clone(), ci.table.clone(), ci.column.clone()))
{
parts.push(format!("FK->{}", ref_target));
}
// Default value (simplified)
if let Some(ref def) = ci.column_default {
let simplified = simplify_default(def);
if !simplified.is_empty() {
parts.push(format!("DEFAULT {}", simplified));
}
}
// Column comment
let col_key = (ci.schema.clone(), ci.table.clone(), ci.column.clone());
let col_comment = col_comments.get(&col_key);
// Inline enum values for enum-typed columns
let enum_annotation = enum_map.get(&ci.data_type);
let mut suffix_parts: Vec<String> = Vec::new();
if let Some(vals) = enum_annotation {
let vals_str = vals
.iter()
.map(|v| format!("'{}'", v))
.collect::<Vec<_>>()
.join(", ");
suffix_parts.push(format!("enum: {}", vals_str));
}
// Inline varchar distinct values (pseudo-enums from pg_stats)
if enum_annotation.is_none() {
if let Some(vals) = varchar_values.get(&col_key) {
let vals_str = vals
.iter()
.take(15)
.map(|v| format!("'{}'", v))
.collect::<Vec<_>>()
.join(", ");
suffix_parts.push(format!("values: {}", vals_str));
}
}
// Inline JSONB key structure
if let Some(keys) = jsonb_keys.get(&col_key) {
suffix_parts.push(format!("json keys: {}", keys.join(", ")));
}
if let Some(comment) = col_comment {
suffix_parts.push(comment.clone());
}
if suffix_parts.is_empty() {
output.push(format!(" {}", parts.join(" ")));
} else {
output.push(format!(
" {} -- {}",
parts.join(" "),
suffix_parts.join("; ")
));
}
}
// Unique constraints for this table
let schema_table: Vec<&str> = full_name.splitn(2, '.').collect();
if schema_table.len() == 2 {
if let Some(uqs) =
unique_map.get(&(schema_table[0].to_string(), schema_table[1].to_string()))
{
for uq_cols in uqs {
output.push(format!(" UNIQUE({})", uq_cols));
}
}
}
}
// 4. Foreign keys summary
if !fk_lines.is_empty() {
output.push(String::new());
output.push("FOREIGN KEYS:".to_string());
for fk in &fk_lines {
output.push(format!(" {}", fk));
}
}
let result = output.join("\n");
// Cache the result
state
.set_schema_cache(connection_id.to_string(), result.clone())
.await;
Ok(result)
}
// ---------------------------------------------------------------------------
// Schema query helpers
// ---------------------------------------------------------------------------
#[derive(Clone)]
struct ColumnInfo {
schema: String,
table: String,
column: String,
data_type: String,
not_null: bool,
is_pk: bool,
column_default: Option<String>,
}
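/// Fetches all user-schema columns with resolved types (USER-DEFINED types map
/// to their udt name, arrays to `udt_name[]`) plus nullability, default, and
/// primary-key membership.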
async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult<Vec<ColumnInfo>> {
let rows = sqlx::query(
"SELECT \
c.table_schema, c.table_name, c.column_name, \
CASE \
WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name \
WHEN c.data_type = 'ARRAY' THEN c.udt_name || '[]' \
ELSE c.data_type \
END AS data_type, \
c.is_nullable = 'NO' AS not_null, \
c.column_default, \
EXISTS( \
SELECT 1 FROM information_schema.table_constraints tc \
JOIN information_schema.key_column_usage kcu \
ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema \
WHERE tc.constraint_type = 'PRIMARY KEY' \
AND tc.table_schema = c.table_schema \
AND tc.table_name = c.table_name \
AND kcu.column_name = c.column_name \
) AS is_pk \
FROM information_schema.columns c \
WHERE c.table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
ORDER BY c.table_schema, c.table_name, c.ordinal_position",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
Ok(rows
.iter()
.map(|r| ColumnInfo {
schema: r.get(0),
table: r.get(1),
column: r.get(2),
data_type: r.get(3),
not_null: r.get(4),
column_default: r.get(5),
is_pk: r.get(6),
})
.collect())
}
pub(crate) struct ForeignKeyInfo {
pub(crate) schema: String,
pub(crate) table: String,
pub(crate) columns: Vec<String>,
pub(crate) ref_schema: String,
pub(crate) ref_table: String,
pub(crate) ref_columns: Vec<String>,
}
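/// Fetches all user-schema foreign keys grouped per constraint. Note that
/// multi-column FKs aggregate their columns ordered by name, not by key
/// position.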
pub(crate) async fn fetch_foreign_keys_raw(pool: &sqlx::PgPool) -> TuskResult<Vec<ForeignKeyInfo>> {
let rows = sqlx::query(
"SELECT \
cn.nspname AS schema_name, cl.relname AS table_name, \
array_agg(DISTINCT a.attname ORDER BY a.attname) AS columns, \
cnf.nspname AS ref_schema, clf.relname AS ref_table, \
array_agg(DISTINCT af.attname ORDER BY af.attname) AS ref_columns \
FROM pg_constraint con \
JOIN pg_class cl ON con.conrelid = cl.oid \
JOIN pg_namespace cn ON cl.relnamespace = cn.oid \
JOIN pg_class clf ON con.confrelid = clf.oid \
JOIN pg_namespace cnf ON clf.relnamespace = cnf.oid \
JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \
JOIN pg_attribute af ON af.attrelid = con.confrelid AND af.attnum = ANY(con.confkey) \
WHERE con.contype = 'f' \
AND cn.nspname NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \
GROUP BY cn.nspname, cl.relname, cnf.nspname, clf.relname, con.oid",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
Ok(rows
.iter()
.map(|r| ForeignKeyInfo {
schema: r.get(0),
table: r.get(1),
columns: r.get(2),
ref_schema: r.get(3),
ref_table: r.get(4),
ref_columns: r.get(5),
})
.collect())
}
/// Returns BTreeMap<enum_type_name, Vec<enum_values>> ordered by type name
async fn fetch_enum_types(pool: &sqlx::PgPool) -> TuskResult<BTreeMap<String, Vec<String>>> {
let rows = sqlx::query(
"SELECT t.typname, \
array_agg(e.enumlabel ORDER BY e.enumsortorder) AS vals \
FROM pg_enum e \
JOIN pg_type t ON e.enumtypid = t.oid \
JOIN pg_namespace n ON t.typnamespace = n.oid \
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') \
GROUP BY t.typname \
ORDER BY t.typname",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
let mut map = BTreeMap::new();
for r in &rows {
let name: String = r.get(0);
let vals: Vec<String> = r.get(1);
map.insert(name, vals);
}
Ok(map)
}
/// Returns HashMap<"schema.table", comment>
async fn fetch_table_comments(pool: &sqlx::PgPool) -> TuskResult<HashMap<String, String>> {
let rows = sqlx::query(
"SELECT n.nspname, c.relname, obj_description(c.oid, 'pg_class') \
FROM pg_class c \
JOIN pg_namespace n ON c.relnamespace = n.oid \
WHERE c.relkind IN ('r', 'v', 'p', 'm') \
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
AND obj_description(c.oid, 'pg_class') IS NOT NULL",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
let mut map = HashMap::new();
for r in &rows {
let schema: String = r.get(0);
let table: String = r.get(1);
let comment: String = r.get(2);
map.insert(format!("{}.{}", schema, table), comment);
}
Ok(map)
}
/// Returns HashMap<(schema, table, column), comment>
async fn fetch_column_comments(
pool: &sqlx::PgPool,
) -> TuskResult<HashMap<(String, String, String), String>> {
let rows = sqlx::query(
"SELECT n.nspname, c.relname, a.attname, d.description \
FROM pg_description d \
JOIN pg_class c ON d.objoid = c.oid \
JOIN pg_namespace n ON c.relnamespace = n.oid \
JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = d.objsubid \
WHERE d.objsubid > 0 \
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit')",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
let mut map = HashMap::new();
for r in &rows {
let schema: String = r.get(0);
let table: String = r.get(1);
let column: String = r.get(2);
let comment: String = r.get(3);
map.insert((schema, table, column), comment);
}
Ok(map)
}
/// Returns Vec<(schema, table, Vec<column_names>)> for UNIQUE constraints
async fn fetch_unique_constraints(
pool: &sqlx::PgPool,
) -> TuskResult<Vec<(String, String, Vec<String>)>> {
let rows = sqlx::query(
"SELECT n.nspname, cl.relname, \
array_agg(a.attname ORDER BY a.attnum) AS cols \
FROM pg_constraint con \
JOIN pg_class cl ON con.conrelid = cl.oid \
JOIN pg_namespace n ON cl.relnamespace = n.oid \
JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \
WHERE con.contype = 'u' \
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
GROUP BY n.nspname, cl.relname, con.oid \
ORDER BY n.nspname, cl.relname",
)
.fetch_all(pool)
.await
.map_err(TuskError::Database)?;
Ok(rows
.iter()
.map(|r| {
let schema: String = r.get(0);
let table: String = r.get(1);
let cols: Vec<String> = r.get(2);
(schema, table, cols)
})
.collect())
}
/// Returns HashMap<(schema, table, column), Vec<distinct_values>> for varchar columns
/// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery.
/// Returns None if pg_stats is not accessible (graceful degradation).
async fn fetch_varchar_values(
pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> {
let rows = match sqlx::query(
"SELECT s.schemaname, s.tablename, s.attname, \
s.most_common_vals::text AS vals \
FROM pg_stats s \
JOIN information_schema.columns c \
ON c.table_schema = s.schemaname \
AND c.table_name = s.tablename \
AND c.column_name = s.attname \
WHERE s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
AND c.data_type = 'character varying' \
AND s.n_distinct > 0 AND s.n_distinct <= 20 \
AND s.most_common_vals IS NOT NULL \
ORDER BY s.schemaname, s.tablename, s.attname",
)
.fetch_all(pool)
.await
{
Ok(r) => r,
Err(e) => {
log::warn!("Failed to fetch varchar values from pg_stats: {}", e);
return None;
}
};
let mut map = HashMap::new();
for r in &rows {
let schema: String = r.get(0);
let table: String = r.get(1);
let column: String = r.get(2);
let vals_text: String = r.get(3);
let vals = parse_pg_array_text(&vals_text);
if !vals.is_empty() {
map.insert((schema, table, column), vals);
}
}
Some(map)
}
/// Discovers top-level keys in JSONB columns by sampling actual data.
/// Runs two sequential queries internally: first discovers JSONB columns,
/// then samples keys from each via a single UNION ALL query.
/// Returns None on error (graceful degradation).
async fn fetch_jsonb_keys(
pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> {
// Step 1: Find all JSONB columns
let col_rows = match sqlx::query(
"SELECT table_schema, table_name, column_name \
FROM information_schema.columns \
WHERE data_type = 'jsonb' \
AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
ORDER BY table_schema, table_name, column_name",
)
.fetch_all(pool)
.await
{
Ok(r) => r,
Err(e) => {
log::warn!("Failed to fetch JSONB columns: {}", e);
return None;
}
};
if col_rows.is_empty() {
return Some(HashMap::new());
}
// Cap at 50 JSONB columns to prevent unbounded UNION ALL queries on large schemas
let columns: Vec<(String, String, String)> = col_rows
.iter()
.take(50)
.map(|r| {
(
r.get::<String, _>(0),
r.get::<String, _>(1),
r.get::<String, _>(2),
)
})
.collect();
// Step 2: Build a single UNION ALL query to sample keys from all JSONB columns
let parts: Vec<String> = columns
.iter()
.enumerate()
.map(|(i, (schema, table, col))| {
// Escape embedded double quotes so the quoted identifiers below stay valid
let qs = schema.replace('"', "\"\"");
let qt = table.replace('"', "\"\"");
let qc = col.replace('"', "\"\"");
format!(
"(SELECT '{}.{}.{}' AS col_ref, key FROM (\
SELECT DISTINCT jsonb_object_keys(\"{}\") AS key \
FROM \"{}\".\"{}\" \
WHERE \"{}\" IS NOT NULL AND jsonb_typeof(\"{}\") = 'object' \
LIMIT 50\
) sub{})",
schema, table, col, qc, qs, qt, qc, qc, i
)
})
.collect();
let query = parts.join(" UNION ALL ");
let rows = match sqlx::query(&query).fetch_all(pool).await {
Ok(r) => r,
Err(e) => {
log::warn!("Failed to fetch JSONB keys: {}", e);
return None;
}
};
let mut map: HashMap<(String, String, String), Vec<String>> = HashMap::new();
for r in &rows {
let col_ref: String = r.get(0);
let key: String = r.get(1);
let ref_parts: Vec<&str> = col_ref.splitn(3, '.').collect();
if ref_parts.len() == 3 {
let entry = map
.entry((
ref_parts[0].to_string(),
ref_parts[1].to_string(),
ref_parts[2].to_string(),
))
.or_default();
if !entry.contains(&key) {
entry.push(key);
}
}
}
for vals in map.values_mut() {
vals.sort();
}
Some(map)
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"}
fn parse_pg_array_text(s: &str) -> Vec<String> {
let s = s.trim();
let s = s.strip_prefix('{').unwrap_or(s);
let s = s.strip_suffix('}').unwrap_or(s);
if s.is_empty() {
return Vec::new();
}
let mut result = Vec::new();
let mut current = String::new();
let mut in_quotes = false;
let mut chars = s.chars().peekable();
while let Some(ch) = chars.next() {
match ch {
'"' if !in_quotes => {
in_quotes = true;
}
'"' if in_quotes => {
if chars.peek() == Some(&'"') {
current.push('"');
chars.next();
} else {
in_quotes = false;
}
}
',' if !in_quotes => {
result.push(current.trim().to_string());
current = String::new();
}
_ => {
current.push(ch);
}
}
}
if !current.is_empty() || !result.is_empty() {
result.push(current.trim().to_string());
}
result
}
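/// Shortens a raw column default for display: sequence-backed defaults become
/// "auto-increment", timestamp defaults normalize to "now()", booleans are
/// kept, and anything longer than 50 characters is dropped.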
fn simplify_default(raw: &str) -> String {
let s = raw.trim();
if s.contains("nextval(") {
return "auto-increment".to_string();
}
// Shorten common defaults
if s == "now()" || s == "CURRENT_TIMESTAMP" || s == "current_timestamp" {
return "now()".to_string();
}
if s == "true" || s == "false" {
return s.to_string();
}
// Numeric/string literals — keep short ones, skip very long generated defaults
if s.len() > 50 {
return String::new();
}
s.to_string()
}
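/// Requires the statement to begin with SELECT. This is a prefix check only;
/// trailing statements after a semicolon are not rejected (see the tests).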
fn validate_select_statement(sql: &str) -> TuskResult<()> {
let sql_upper = sql.trim().to_uppercase();
if !sql_upper.starts_with("SELECT") {
return Err(TuskError::Validation(
"Validation query must be a SELECT statement".to_string(),
));
}
Ok(())
}
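/// Allows only statements beginning with CREATE INDEX or DROP INDEX. Like
/// `validate_select_statement`, this is a prefix check, not a parser.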
fn validate_index_ddl(ddl: &str) -> TuskResult<()> {
let ddl_upper = ddl.trim().to_uppercase();
if !ddl_upper.starts_with("CREATE INDEX") && !ddl_upper.starts_with("DROP INDEX") {
return Err(TuskError::Validation(
"Only CREATE INDEX and DROP INDEX statements are allowed".to_string(),
));
}
Ok(())
}
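/// Strips the markdown code fences ("sql"/"SQL"/"postgresql"-tagged or bare)
/// that models often emit despite instructions, returning the trimmed SQL.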
fn clean_sql_response(raw: &str) -> String {
let trimmed = raw.trim();
// Remove markdown code fences
let without_fences = if trimmed.starts_with("```") {
let inner = trimmed
.strip_prefix("```sql")
.or_else(|| trimmed.strip_prefix("```SQL"))
.or_else(|| trimmed.strip_prefix("```postgresql"))
.or_else(|| trimmed.strip_prefix("```"))
.unwrap_or(trimmed);
inner.strip_suffix("```").unwrap_or(inner)
} else {
trimmed
};
without_fences.trim().to_string()
}
// ---------------------------------------------------------------------------
// Wave 1: AI Data Assertions (Validation)
// ---------------------------------------------------------------------------
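/// Generates a SELECT that returns rows violating a natural-language data
/// quality rule, constrained to tables and columns in the live schema.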
#[tauri::command]
pub async fn generate_validation_sql(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
rule_description: String,
) -> TuskResult<String> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are an expert PostgreSQL data quality validator. Given a database schema and a natural \
language data quality rule, generate a SELECT query that finds ALL rows violating the rule.\n\
\n\
OUTPUT FORMAT:\n\
- Raw SQL only. No explanations, no markdown code fences, no comments.\n\
- The query MUST be a SELECT statement.\n\
- Return violating rows with enough context columns to identify them.\n\
\n\
VALIDATION PATTERNS:\n\
- NULL checks: SELECT * FROM table WHERE required_column IS NULL\n\
- Format checks: WHERE column !~ 'pattern'\n\
- Range checks: WHERE column < min OR column > max\n\
- FK integrity: LEFT JOIN parent ON ... WHERE parent.id IS NULL\n\
- Uniqueness: GROUP BY ... HAVING COUNT(*) > 1\n\
- Date consistency: WHERE start_date > end_date\n\
- Enum validity: WHERE column NOT IN ('val1', 'val2', ...)\n\
\n\
ONLY reference tables and columns that exist in the schema.\n\
\n\
{}\n",
schema_text
);
let raw = call_ollama_chat(&app, &state, system_prompt, rule_description).await?;
Ok(clean_sql_response(&raw))
}
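/// Executes a validation query inside a read-only transaction with a 30s
/// statement timeout, returning the total violation count plus a capped
/// sample of violating rows.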
#[tauri::command]
pub async fn run_validation_rule(
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
sample_limit: Option<u32>,
) -> TuskResult<ValidationRule> {
validate_select_statement(&sql)?;
let pool = state.get_pool(&connection_id).await?;
let limit = sample_limit.unwrap_or(10);
let _start = Instant::now();
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
sqlx::query("SET TRANSACTION READ ONLY")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
sqlx::query("SET statement_timeout = '30s'")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
// Count violations
let count_sql = format!("SELECT COUNT(*) FROM ({}) AS _v", sql);
let count_row = sqlx::query(&count_sql)
.fetch_one(&mut *tx)
.await
.map_err(TuskError::Database)?;
let violation_count: i64 = count_row.get(0);
// Sample violations
let sample_sql = format!("SELECT * FROM ({}) AS _v LIMIT {}", sql, limit);
let sample_rows = sqlx::query(&sample_sql)
.fetch_all(&mut *tx)
.await
.map_err(TuskError::Database)?;
tx.rollback().await.map_err(TuskError::Database)?;
let mut violation_columns = Vec::new();
let mut sample_violations = Vec::new();
if let Some(first) = sample_rows.first() {
for col in first.columns() {
violation_columns.push(col.name().to_string());
}
}
for row in &sample_rows {
let vals: Vec<Value> = (0..violation_columns.len())
.map(|i| pg_value_to_json(row, i))
.collect();
sample_violations.push(vals);
}
let status = if violation_count > 0 {
ValidationStatus::Failed
} else {
ValidationStatus::Passed
};
Ok(ValidationRule {
id: String::new(),
description: String::new(),
generated_sql: sql,
status,
violation_count: violation_count as u64,
sample_violations,
violation_columns,
error: None,
})
}
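/// Asks the model for 5-10 natural-language validation rule suggestions,
/// parsed from a JSON array in its reply.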
#[tauri::command]
pub async fn suggest_validation_rules(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
) -> TuskResult<Vec<String>> {
let schema_text = build_schema_context(&state, &connection_id).await?;
let system_prompt = format!(
"You are a data quality expert. Given a database schema, suggest 5-10 data quality \
validation rules as natural language descriptions.\n\
\n\
OUTPUT FORMAT:\n\
- Return ONLY a JSON array of strings, each string being a validation rule.\n\
- No markdown, no explanations, no code fences.\n\
- Example: [\"All users must have a non-empty email address\", \"Order total must be positive\"]\n\
\n\
RULE CATEGORIES TO COVER:\n\
- NOT NULL checks for critical columns\n\
- Business logic (positive amounts, valid ranges, consistent dates)\n\
- Referential integrity (orphaned foreign keys)\n\
- Format validation (emails, phone numbers, codes)\n\
- Enum/status field validity\n\
- Date consistency (start before end, not in the future where inappropriate)\n\
\n\
{}\n",
schema_text
);
let raw = call_ollama_chat(
&app,
&state,
system_prompt,
"Suggest validation rules".to_string(),
)
.await?;
let cleaned = raw.trim();
let json_start = cleaned.find('[').unwrap_or(0);
let json_end = cleaned.rfind(']').map(|i| i + 1).unwrap_or(cleaned.len()).max(json_start);
let json_str = &cleaned[json_start..json_end];
let rules: Vec<String> = serde_json::from_str(json_str).map_err(|e| {
TuskError::Ai(format!(
"Failed to parse AI response as JSON array: {}. Response: {}",
e, cleaned
))
})?;
Ok(rules)
}
// ---------------------------------------------------------------------------
// Wave 2: AI Data Generator
// ---------------------------------------------------------------------------
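/// Asks the model for JSON test data for the target table (and, when
/// requested, its FK parents), ordering tables so parents are generated before
/// children and emitting `datagen-progress` events along the way.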
#[tauri::command]
pub async fn generate_test_data_preview(
app: AppHandle,
state: State<'_, Arc<AppState>>,
params: GenerateDataParams,
gen_id: String,
) -> TuskResult<GeneratedDataPreview> {
let pool = state.get_pool(&params.connection_id).await?;
let _ = app.emit(
"datagen-progress",
DataGenProgress {
gen_id: gen_id.clone(),
stage: "schema".to_string(),
percent: 10,
message: "Building schema context...".to_string(),
detail: None,
},
);
let schema_text = build_schema_context(&state, &params.connection_id).await?;
// Get FK info for topological sort
let fk_rows = fetch_foreign_keys_raw(&pool).await?;
let mut target_tables = vec![(params.schema.clone(), params.table.clone())];
if params.include_related {
// Add parent tables (tables referenced by FKs from target)
for fk in &fk_rows {
if fk.schema == params.schema && fk.table == params.table {
let parent = (fk.ref_schema.clone(), fk.ref_table.clone());
if !target_tables.contains(&parent) {
target_tables.push(parent);
}
}
}
}
let fk_edges: Vec<(String, String, String, String)> = fk_rows
.iter()
.map(|fk| {
(
fk.schema.clone(),
fk.table.clone(),
fk.ref_schema.clone(),
fk.ref_table.clone(),
)
})
.collect();
let sorted_tables = topological_sort_tables(&fk_edges, &target_tables);
let insert_order: Vec<String> = sorted_tables
.iter()
.map(|(s, t)| format!("{}.{}", s, t))
.collect();
let row_count = params.row_count.min(1000);
let _ = app.emit(
"datagen-progress",
DataGenProgress {
gen_id: gen_id.clone(),
stage: "generating".to_string(),
percent: 30,
message: "AI is generating test data...".to_string(),
detail: None,
},
);
let tables_desc: Vec<String> = sorted_tables
.iter()
.map(|(s, t)| {
let count = if s == &params.schema && t == &params.table {
row_count
} else {
(row_count / 3).max(1)
};
format!("{}.{}: {} rows", s, t, count)
})
.collect();
let custom = params
.custom_instructions
.as_deref()
.unwrap_or("Generate realistic sample data");
let system_prompt = format!(
"You are a PostgreSQL test data generator. Generate realistic test data as JSON.\n\
\n\
OUTPUT FORMAT:\n\
- Return ONLY a JSON object where keys are \"schema.table\" and values are arrays of row objects.\n\
- Each row object has column names as keys and values matching the column types.\n\
- No markdown, no explanations, no code fences.\n\
- Example: {{\"public.users\": [{{\"name\": \"Alice\", \"email\": \"alice@example.com\"}}]}}\n\
\n\
RULES:\n\
1. Respect column types exactly (text, integer, boolean, timestamp, uuid, etc.)\n\
2. Use valid foreign key values: parent tables are generated first, so reference their IDs\n\
3. Respect enum types - use only valid enum values\n\
4. Omit auto-increment/serial/identity columns (they have DEFAULT auto-increment)\n\
5. Generate realistic data: real names, valid emails, plausible dates, etc.\n\
6. Respect NOT NULL constraints\n\
7. For UUID columns, generate valid UUIDs\n\
8. For timestamp columns, use ISO 8601 format\n\
\n\
Tables to generate (in this exact order):\n\
{}\n\
\n\
Custom instructions: {}\n\
\n\
{}\n",
tables_desc.join("\n"),
custom,
schema_text
);
let raw = call_ollama_chat(
&app,
&state,
system_prompt,
format!("Generate test data for {} tables", sorted_tables.len()),
)
.await?;
let _ = app.emit(
"datagen-progress",
DataGenProgress {
gen_id: gen_id.clone(),
stage: "parsing".to_string(),
percent: 80,
message: "Parsing generated data...".to_string(),
detail: None,
},
);
// Parse JSON response
let cleaned = raw.trim();
let json_start = cleaned.find('{').unwrap_or(0);
let json_end = cleaned.rfind('}').map(|i| i + 1).unwrap_or(cleaned.len()).max(json_start);
let json_str = &cleaned[json_start..json_end];
let data_map: HashMap<String, Vec<HashMap<String, Value>>> = serde_json::from_str(json_str)
.map_err(|e| {
TuskError::Ai(format!(
"Failed to parse generated data: {}. Response: {}",
e,
cleaned.chars().take(500).collect::<String>()
))
})?;
let mut tables = Vec::new();
let mut total_rows: u32 = 0;
for (schema, table) in &sorted_tables {
let key = format!("{}.{}", schema, table);
if let Some(rows_data) = data_map.get(&key) {
let columns: Vec<String> = if let Some(first) = rows_data.first() {
first.keys().cloned().collect()
} else {
Vec::new()
};
let rows: Vec<Vec<Value>> = rows_data
.iter()
.map(|row_map| {
columns
.iter()
.map(|col| row_map.get(col).cloned().unwrap_or(Value::Null))
.collect()
})
.collect();
let count = rows.len() as u32;
total_rows += count;
tables.push(GeneratedTableData {
schema: schema.clone(),
table: table.clone(),
columns,
rows,
row_count: count,
});
}
}
let _ = app.emit(
"datagen-progress",
DataGenProgress {
gen_id: gen_id.clone(),
stage: "done".to_string(),
percent: 100,
message: "Data generation complete".to_string(),
detail: Some(format!(
"{} rows across {} tables",
total_rows,
tables.len()
)),
},
);
Ok(GeneratedDataPreview {
tables,
insert_order,
total_rows,
})
}
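/// Inserts previewed rows table-by-table inside one transaction with
/// constraints deferred, so circular foreign keys do not fail mid-batch.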
#[tauri::command]
pub async fn insert_generated_data(
state: State<'_, Arc<AppState>>,
connection_id: String,
preview: GeneratedDataPreview,
) -> TuskResult<u64> {
if state.is_read_only(&connection_id).await {
return Err(TuskError::ReadOnly);
}
let pool = state.get_pool(&connection_id).await?;
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
// Defer constraints for circular FKs
sqlx::query("SET CONSTRAINTS ALL DEFERRED")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
let mut total_inserted: u64 = 0;
for table_data in &preview.tables {
if table_data.columns.is_empty() || table_data.rows.is_empty() {
continue;
}
let qualified = format!(
"{}.{}",
escape_ident(&table_data.schema),
escape_ident(&table_data.table)
);
let col_list: Vec<String> = table_data.columns.iter().map(|c| escape_ident(c)).collect();
let placeholders: Vec<String> = (1..=table_data.columns.len())
.map(|i| format!("${}", i))
.collect();
let sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
qualified,
col_list.join(", "),
placeholders.join(", ")
);
for row in &table_data.rows {
let mut query = sqlx::query(&sql);
for val in row {
query = bind_json_value(query, val);
}
query.execute(&mut *tx).await.map_err(TuskError::Database)?;
total_inserted += 1;
}
}
tx.commit().await.map_err(TuskError::Database)?;
// Invalidate schema cache since data changed
state.invalidate_schema_cache(&connection_id).await;
Ok(total_inserted)
}
// ---------------------------------------------------------------------------
// Wave 3A: Smart Index Advisor
// ---------------------------------------------------------------------------
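/// Gathers table and index statistics (plus slow queries when
/// pg_stat_statements is available), then asks the model for index
/// recommendations as JSON, parsed leniently into the report.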
#[tauri::command]
pub async fn get_index_advisor_report(
app: AppHandle,
state: State<'_, Arc<AppState>>,
connection_id: String,
) -> TuskResult<IndexAdvisorReport> {
let pool = state.get_pool(&connection_id).await?;
// Fetch table stats, index stats, and slow queries concurrently
let (table_stats_res, index_stats_res, slow_queries_res) = tokio::join!(
sqlx::query(
"SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \
pg_size_pretty(pg_total_relation_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS table_size, \
pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \
FROM pg_stat_user_tables \
ORDER BY seq_scan DESC \
LIMIT 50"
)
.fetch_all(&pool),
sqlx::query(
"SELECT schemaname, relname, indexrelname, idx_scan, \
pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \
pg_get_indexdef(indexrelid) AS definition \
FROM pg_stat_user_indexes \
ORDER BY idx_scan ASC \
LIMIT 50",
)
.fetch_all(&pool),
sqlx::query(
"SELECT query, calls, total_exec_time, mean_exec_time, rows \
FROM pg_stat_statements \
WHERE calls > 0 \
ORDER BY mean_exec_time DESC \
LIMIT 20",
)
.fetch_all(&pool),
);
let table_stats: Vec<TableStats> = table_stats_res
.map_err(TuskError::Database)?
.iter()
.map(|r| TableStats {
schema: r.get(0),
table: r.get(1),
seq_scan: r.get(2),
idx_scan: r.get(3),
n_live_tup: r.get(4),
table_size: r.get(5),
index_size: r.get(6),
})
.collect();
let index_stats: Vec<IndexStats> = index_stats_res
.map_err(TuskError::Database)?
.iter()
.map(|r| IndexStats {
schema: r.get(0),
table: r.get(1),
index_name: r.get(2),
idx_scan: r.get(3),
index_size: r.get(4),
definition: r.get(5),
})
.collect();
let (slow_queries, has_pg_stat_statements) = match slow_queries_res {
Ok(rows) => {
let queries: Vec<SlowQuery> = rows
.iter()
.map(|r| SlowQuery {
query: r.get(0),
calls: r.get(1),
total_time_ms: r.get(2),
mean_time_ms: r.get(3),
rows: r.get(4),
})
.collect();
(queries, true)
}
Err(_) => (Vec::new(), false),
};
// Build AI prompt for recommendations
let schema_text = build_schema_context(&state, &connection_id).await?;
let mut stats_text = String::from("TABLE STATISTICS:\n");
for ts in &table_stats {
stats_text.push_str(&format!(
" {}.{}: seq_scan={}, idx_scan={}, rows={}, size={}, idx_size={}\n",
ts.schema,
ts.table,
ts.seq_scan,
ts.idx_scan,
ts.n_live_tup,
ts.table_size,
ts.index_size
));
}
stats_text.push_str("\nINDEX STATISTICS:\n");
for is in &index_stats {
stats_text.push_str(&format!(
" {}.{}.{}: scans={}, size={}, def={}\n",
is.schema, is.table, is.index_name, is.idx_scan, is.index_size, is.definition
));
}
if !slow_queries.is_empty() {
stats_text.push_str("\nSLOW QUERIES:\n");
for sq in &slow_queries {
stats_text.push_str(&format!(
" calls={}, mean={:.1}ms, total={:.1}ms, rows={}: {}\n",
sq.calls,
sq.mean_time_ms,
sq.total_time_ms,
sq.rows,
sq.query.chars().take(200).collect::<String>()
));
}
}
let system_prompt = format!(
"You are a PostgreSQL performance expert. Analyze the database statistics and recommend index changes.\n\
\n\
OUTPUT FORMAT:\n\
- Return ONLY a JSON array of recommendation objects.\n\
- No markdown, no explanations, no code fences.\n\
- Each object: {{\"recommendation_type\": \"create_index\"|\"drop_index\"|\"replace_index\", \
\"table_schema\": \"...\", \"table_name\": \"...\", \"index_name\": \"...\"|null, \
\"ddl\": \"CREATE INDEX CONCURRENTLY ...\", \"rationale\": \"...\", \
\"estimated_impact\": \"high\"|\"medium\"|\"low\", \"priority\": \"high\"|\"medium\"|\"low\"}}\n\
\n\
RULES:\n\
1. Prefer CREATE INDEX CONCURRENTLY to avoid locking\n\
2. Never suggest dropping PRIMARY KEY or UNIQUE indexes\n\
3. Suggest dropping indexes with 0 scans on tables with many rows\n\
4. Suggest composite indexes for commonly co-filtered columns\n\
5. Suggest partial indexes for low-cardinality boolean columns\n\
6. Consider covering indexes for frequently selected columns\n\
7. High seq_scan + high row count = strong candidate for new index\n\
\n\
{}\n\
\n\
{}\n",
stats_text, schema_text
);
let raw = call_ollama_chat(
&app,
&state,
system_prompt,
"Analyze indexes and provide recommendations".to_string(),
)
.await?;
let cleaned = raw.trim();
let json_start = cleaned.find('[').unwrap_or(0);
let json_end = cleaned.rfind(']').map(|i| i + 1).unwrap_or(cleaned.len()).max(json_start);
let json_str = &cleaned[json_start..json_end];
let recommendations: Vec<IndexRecommendation> =
serde_json::from_str(json_str).unwrap_or_default();
Ok(IndexAdvisorReport {
table_stats,
index_stats,
slow_queries,
recommendations,
has_pg_stat_statements,
})
}
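/// Applies a single validated CREATE/DROP INDEX statement. Runs outside a
/// transaction so CONCURRENTLY variants are permitted.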
#[tauri::command]
pub async fn apply_index_recommendation(
state: State<'_, Arc<AppState>>,
connection_id: String,
ddl: String,
) -> TuskResult<()> {
if state.is_read_only(&connection_id).await {
return Err(TuskError::ReadOnly);
}
validate_index_ddl(&ddl)?;
let pool = state.get_pool(&connection_id).await?;
// CREATE INDEX CONCURRENTLY cannot run inside a transaction, so execute directly on the pool
sqlx::query(&ddl)
.execute(&pool)
.await
.map_err(TuskError::Database)?;
// Invalidate schema cache
state.invalidate_schema_cache(&connection_id).await;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
// ── validate_select_statement ─────────────────────────────
#[test]
fn select_valid_simple() {
assert!(validate_select_statement("SELECT 1").is_ok());
}
#[test]
fn select_valid_with_leading_whitespace() {
assert!(validate_select_statement(" SELECT * FROM users").is_ok());
}
#[test]
fn select_valid_lowercase() {
assert!(validate_select_statement("select * from users").is_ok());
}
#[test]
fn select_rejects_insert() {
assert!(validate_select_statement("INSERT INTO users VALUES (1)").is_err());
}
#[test]
fn select_rejects_delete() {
assert!(validate_select_statement("DELETE FROM users").is_err());
}
#[test]
fn select_rejects_drop() {
assert!(validate_select_statement("DROP TABLE users").is_err());
}
#[test]
fn select_rejects_empty() {
assert!(validate_select_statement("").is_err());
}
#[test]
fn select_rejects_whitespace_only() {
assert!(validate_select_statement(" ").is_err());
}
// NOTE: This test documents a known weakness — SELECT prefix allows injection
#[test]
fn select_allows_semicolon_after_select() {
// "SELECT 1; DROP TABLE users" starts with SELECT — passes validation
// This is a known limitation documented in the review
assert!(validate_select_statement("SELECT 1; DROP TABLE users").is_ok());
}
// ── validate_index_ddl ────────────────────────────────────
#[test]
fn ddl_valid_create_index() {
assert!(validate_index_ddl("CREATE INDEX idx_name ON users(email)").is_ok());
}
#[test]
fn ddl_valid_create_index_concurrently() {
assert!(validate_index_ddl("CREATE INDEX CONCURRENTLY idx ON t(c)").is_ok());
}
#[test]
fn ddl_valid_drop_index() {
assert!(validate_index_ddl("DROP INDEX idx_name").is_ok());
}
#[test]
fn ddl_valid_with_leading_whitespace() {
assert!(validate_index_ddl(" CREATE INDEX idx ON t(c)").is_ok());
}
#[test]
fn ddl_valid_lowercase() {
assert!(validate_index_ddl("create index idx on t(c)").is_ok());
}
#[test]
fn ddl_rejects_create_table() {
assert!(validate_index_ddl("CREATE TABLE evil(id int)").is_err());
}
#[test]
fn ddl_rejects_drop_table() {
assert!(validate_index_ddl("DROP TABLE users").is_err());
}
#[test]
fn ddl_rejects_alter_table() {
assert!(validate_index_ddl("ALTER TABLE users ADD COLUMN x int").is_err());
}
#[test]
fn ddl_rejects_empty() {
assert!(validate_index_ddl("").is_err());
}
// NOTE: Documents bypass weakness — semicolon after valid prefix
#[test]
fn ddl_allows_semicolon_injection() {
// "CREATE INDEX x ON t(c); DROP TABLE users" — passes validation
// Mitigated by sqlx single-statement execution
assert!(validate_index_ddl("CREATE INDEX x ON t(c); DROP TABLE users").is_ok());
}
// ── clean_sql_response ────────────────────────────────────
#[test]
fn clean_sql_plain() {
assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1");
}
#[test]
fn clean_sql_with_fences() {
assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1");
}
#[test]
fn clean_sql_with_generic_fences() {
assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1");
}
#[test]
fn clean_sql_with_postgresql_fences() {
assert_eq!(
clean_sql_response("```postgresql\nSELECT 1\n```"),
"SELECT 1"
);
}
#[test]
fn clean_sql_with_whitespace() {
assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1");
}
#[test]
fn clean_sql_no_fences_multiline() {
assert_eq!(
clean_sql_response("SELECT\n *\nFROM users"),
"SELECT\n *\nFROM users"
);
}
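// ── parse_pg_array_text / simplify_default ────────────────
// Illustrative extra cases (a sketch, not part of the original suite)
// that pin down the documented behavior of the two helpers above.
#[test]
fn pg_array_simple_values() {
assert_eq!(parse_pg_array_text("{a,b,c}"), vec!["a", "b", "c"]);
}
#[test]
fn pg_array_quoted_value_with_comma() {
assert_eq!(
parse_pg_array_text("{\"hello, world\",x}"),
vec!["hello, world", "x"]
);
}
#[test]
fn pg_array_doubled_quote_inside_quoted_value() {
// The parser treats "" inside a quoted element as an escaped quote.
assert_eq!(parse_pg_array_text("{\"a\"\"b\"}"), vec!["a\"b"]);
}
#[test]
fn pg_array_empty() {
assert!(parse_pg_array_text("{}").is_empty());
}
#[test]
fn simplify_default_nextval_is_auto_increment() {
assert_eq!(
simplify_default("nextval('users_id_seq'::regclass)"),
"auto-increment"
);
}
#[test]
fn simplify_default_drops_long_defaults() {
let long = "x".repeat(60);
assert_eq!(simplify_default(&long), "");
}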
}