feat: add AI data validation, test data generator, index advisor, and snapshots

Four new killer features leveraging AI (Ollama) and PostgreSQL internals:

- Data Validation: describe quality rules in natural language, AI generates
  SQL to find violations, run with pass/fail results and sample violations
- Test Data Generator: right-click table to generate realistic FK-aware test
  data with AI, preview before inserting in a transaction
- Index Advisor: analyze pg_stat tables + AI recommendations for CREATE/DROP
  INDEX with one-click apply
- Data Snapshots: export selected tables to JSON (FK-ordered), restore from
  file with optional truncate in a transaction

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-21 13:27:41 +03:00
parent d507162377
commit a3b05b0328
26 changed files with 3438 additions and 17 deletions

View File

@@ -1,15 +1,21 @@
use crate::commands::data::bind_json_value;
use crate::commands::queries::pg_value_to_json;
use crate::error::{TuskError, TuskResult};
use crate::models::ai::{
AiProvider, AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse,
OllamaModel, OllamaTagsResponse,
AiProvider, AiSettings, GenerateDataParams, GeneratedDataPreview, GeneratedTableData,
IndexAdvisorReport, IndexRecommendation, IndexStats,
OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse,
SlowQuery, TableStats, ValidationRule, ValidationStatus, DataGenProgress,
};
use crate::state::AppState;
use sqlx::Row;
use crate::utils::{escape_ident, topological_sort_tables};
use serde_json::Value;
use sqlx::{Column, Row};
use std::collections::{BTreeMap, HashMap};
use std::fs;
use std::sync::Arc;
use std::time::Duration;
use tauri::{AppHandle, Manager, State};
use std::time::{Duration, Instant};
use tauri::{AppHandle, Emitter, Manager, State};
const MAX_RETRIES: u32 = 2;
const RETRY_DELAY_MS: u64 = 1000;
@@ -386,7 +392,7 @@ pub async fn fix_sql_error(
// Schema context builder
// ---------------------------------------------------------------------------
async fn build_schema_context(
pub(crate) async fn build_schema_context(
state: &AppState,
connection_id: &str,
) -> TuskResult<String> {
@@ -665,16 +671,16 @@ async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult<Vec<ColumnInfo>> {
.collect())
}
struct ForeignKeyInfo {
schema: String,
table: String,
columns: Vec<String>,
ref_schema: String,
ref_table: String,
ref_columns: Vec<String>,
pub(crate) struct ForeignKeyInfo {
pub(crate) schema: String,
pub(crate) table: String,
pub(crate) columns: Vec<String>,
pub(crate) ref_schema: String,
pub(crate) ref_table: String,
pub(crate) ref_columns: Vec<String>,
}
async fn fetch_foreign_keys_raw(pool: &sqlx::PgPool) -> TuskResult<Vec<ForeignKeyInfo>> {
pub(crate) async fn fetch_foreign_keys_raw(pool: &sqlx::PgPool) -> TuskResult<Vec<ForeignKeyInfo>> {
let rows = sqlx::query(
"SELECT \
cn.nspname AS schema_name, cl.relname AS table_name, \
@@ -1043,3 +1049,609 @@ fn clean_sql_response(raw: &str) -> String {
};
without_fences.trim().to_string()
}
// ---------------------------------------------------------------------------
// Wave 1: AI Data Assertions (Validation)
// ---------------------------------------------------------------------------
/// Builds a schema-aware prompt and asks the model for a single SELECT
/// statement that surfaces every row violating the described data-quality
/// rule. The reply is stripped of fences/noise before being returned.
#[tauri::command]
pub async fn generate_validation_sql(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
    rule_description: String,
) -> TuskResult<String> {
    let schema = build_schema_context(&state, &connection_id).await?;
    // The prompt pins the output contract (raw SELECT only) and enumerates
    // the common violation-query shapes the model should reach for.
    let prompt = format!(
        "You are an expert PostgreSQL data quality validator. Given a database schema and a natural \
         language data quality rule, generate a SELECT query that finds ALL rows violating the rule.\n\
         \n\
         OUTPUT FORMAT:\n\
         - Raw SQL only. No explanations, no markdown code fences, no comments.\n\
         - The query MUST be a SELECT statement.\n\
         - Return violating rows with enough context columns to identify them.\n\
         \n\
         VALIDATION PATTERNS:\n\
         - NULL checks: SELECT * FROM table WHERE required_column IS NULL\n\
         - Format checks: WHERE column !~ 'pattern'\n\
         - Range checks: WHERE column < min OR column > max\n\
         - FK integrity: LEFT JOIN parent ON ... WHERE parent.id IS NULL\n\
         - Uniqueness: GROUP BY ... HAVING COUNT(*) > 1\n\
         - Date consistency: WHERE start_date > end_date\n\
         - Enum validity: WHERE column NOT IN ('val1', 'val2', ...)\n\
         \n\
         ONLY reference tables and columns that exist in the schema.\n\
         \n\
         {}\n",
        schema
    );
    let reply = call_ollama_chat(&app, &state, prompt, rule_description).await?;
    Ok(clean_sql_response(&reply))
}
#[tauri::command]
pub async fn run_validation_rule(
state: State<'_, Arc<AppState>>,
connection_id: String,
sql: String,
sample_limit: Option<u32>,
) -> TuskResult<ValidationRule> {
let sql_upper = sql.trim().to_uppercase();
if !sql_upper.starts_with("SELECT") {
return Err(TuskError::Custom(
"Validation query must be a SELECT statement".to_string(),
));
}
let pool = state.get_pool(&connection_id).await?;
let limit = sample_limit.unwrap_or(10);
let _start = Instant::now();
let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
sqlx::query("SET TRANSACTION READ ONLY")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
sqlx::query("SET statement_timeout = '30s'")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
// Count violations
let count_sql = format!("SELECT COUNT(*) FROM ({}) AS _v", sql);
let count_row = sqlx::query(&count_sql)
.fetch_one(&mut *tx)
.await
.map_err(TuskError::Database)?;
let violation_count: i64 = count_row.get(0);
// Sample violations
let sample_sql = format!("SELECT * FROM ({}) AS _v LIMIT {}", sql, limit);
let sample_rows = sqlx::query(&sample_sql)
.fetch_all(&mut *tx)
.await
.map_err(TuskError::Database)?;
tx.rollback().await.map_err(TuskError::Database)?;
let mut violation_columns = Vec::new();
let mut sample_violations = Vec::new();
if let Some(first) = sample_rows.first() {
for col in first.columns() {
violation_columns.push(col.name().to_string());
}
}
for row in &sample_rows {
let vals: Vec<Value> = (0..violation_columns.len())
.map(|i| pg_value_to_json(row, i))
.collect();
sample_violations.push(vals);
}
let status = if violation_count > 0 {
ValidationStatus::Failed
} else {
ValidationStatus::Passed
};
Ok(ValidationRule {
id: String::new(),
description: String::new(),
generated_sql: sql,
status,
violation_count: violation_count as u64,
sample_violations,
violation_columns,
error: None,
})
}
/// Asks the model for 5-10 natural-language data-quality rule suggestions
/// tailored to the connected database's schema.
///
/// Returns the parsed list of rule descriptions, or an `Ai` error when the
/// model's reply cannot be parsed as a JSON string array.
#[tauri::command]
pub async fn suggest_validation_rules(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
) -> TuskResult<Vec<String>> {
    let schema_text = build_schema_context(&state, &connection_id).await?;
    let system_prompt = format!(
        "You are a data quality expert. Given a database schema, suggest 5-10 data quality \
         validation rules as natural language descriptions.\n\
         \n\
         OUTPUT FORMAT:\n\
         - Return ONLY a JSON array of strings, each string being a validation rule.\n\
         - No markdown, no explanations, no code fences.\n\
         - Example: [\"All users must have a non-empty email address\", \"Order total must be positive\"]\n\
         \n\
         RULE CATEGORIES TO COVER:\n\
         - NOT NULL checks for critical columns\n\
         - Business logic (positive amounts, valid ranges, consistent dates)\n\
         - Referential integrity (orphaned foreign keys)\n\
         - Format validation (emails, phone numbers, codes)\n\
         - Enum/status field validity\n\
         - Date consistency (start before end, not in future where inappropriate)\n\
         \n\
         {}\n",
        schema_text
    );
    let raw = call_ollama_chat(&app, &state, system_prompt, "Suggest validation rules".to_string()).await?;
    let cleaned = raw.trim();
    // Extract the outermost [...] span; models sometimes wrap the JSON in
    // prose. Fall back to the whole reply, and never panic on a malformed
    // span (e.g. a ']' appearing before the first '[' would make start > end
    // and a naive slice would panic).
    let json_str = match (cleaned.find('['), cleaned.rfind(']')) {
        (Some(start), Some(end)) if start < end => &cleaned[start..=end],
        _ => cleaned,
    };
    let rules: Vec<String> = serde_json::from_str(json_str).map_err(|e| {
        TuskError::Ai(format!("Failed to parse AI response as JSON array: {}. Response: {}", e, cleaned))
    })?;
    Ok(rules)
}
// ---------------------------------------------------------------------------
// Wave 2: AI Data Generator
// ---------------------------------------------------------------------------
/// Generates a preview of realistic, AI-produced test data for one table
/// (optionally plus its direct FK parent tables) without writing to the
/// database; insertion happens later in `insert_generated_data`.
///
/// Progress is reported as best-effort `datagen-progress` events keyed by
/// `gen_id`; emit failures are deliberately ignored (`let _ =`).
#[tauri::command]
pub async fn generate_test_data_preview(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    params: GenerateDataParams,
    gen_id: String,
) -> TuskResult<GeneratedDataPreview> {
    let pool = state.get_pool(&params.connection_id).await?;
    let _ = app.emit("datagen-progress", DataGenProgress {
        gen_id: gen_id.clone(),
        stage: "schema".to_string(),
        percent: 10,
        message: "Building schema context...".to_string(),
        detail: None,
    });
    let schema_text = build_schema_context(&state, &params.connection_id).await?;
    // Get FK info for topological sort
    let fk_rows = fetch_foreign_keys_raw(&pool).await?;
    let mut target_tables = vec![(params.schema.clone(), params.table.clone())];
    if params.include_related {
        // Add parent tables (tables referenced by FKs from target).
        // Only one level is expanded: parents of the target, not the
        // parents' own parents.
        for fk in &fk_rows {
            if fk.schema == params.schema && fk.table == params.table {
                let parent = (fk.ref_schema.clone(), fk.ref_table.clone());
                if !target_tables.contains(&parent) {
                    target_tables.push(parent);
                }
            }
        }
    }
    // Sort parents-first so generated FK values can reference previously
    // generated parent rows (mirrors rule 2 of the prompt below).
    let fk_edges: Vec<(String, String, String, String)> = fk_rows
        .iter()
        .map(|fk| (fk.schema.clone(), fk.table.clone(), fk.ref_schema.clone(), fk.ref_table.clone()))
        .collect();
    let sorted_tables = topological_sort_tables(&fk_edges, &target_tables);
    let insert_order: Vec<String> = sorted_tables
        .iter()
        .map(|(s, t)| format!("{}.{}", s, t))
        .collect();
    // Hard cap keeps the prompt and the preview payload bounded.
    let row_count = params.row_count.min(1000);
    let _ = app.emit("datagen-progress", DataGenProgress {
        gen_id: gen_id.clone(),
        stage: "generating".to_string(),
        percent: 30,
        message: "AI is generating test data...".to_string(),
        detail: None,
    });
    // Related (parent) tables get roughly a third of the requested rows,
    // with a floor of one row each.
    let tables_desc: Vec<String> = sorted_tables
        .iter()
        .map(|(s, t)| {
            let count = if s == &params.schema && t == &params.table {
                row_count
            } else {
                (row_count / 3).max(1)
            };
            format!("{}.{}: {} rows", s, t, count)
        })
        .collect();
    let custom = params
        .custom_instructions
        .as_deref()
        .unwrap_or("Generate realistic sample data");
    let system_prompt = format!(
        "You are a PostgreSQL test data generator. Generate realistic test data as JSON.\n\
         \n\
         OUTPUT FORMAT:\n\
         - Return ONLY a JSON object where keys are \"schema.table\" and values are arrays of row objects.\n\
         - Each row object has column names as keys and values matching the column types.\n\
         - No markdown, no explanations, no code fences.\n\
         - Example: {{\"public.users\": [{{\"name\": \"Alice\", \"email\": \"alice@example.com\"}}]}}\n\
         \n\
         RULES:\n\
         1. Respect column types exactly (text, integer, boolean, timestamp, uuid, etc.)\n\
         2. Use valid foreign key values - parent tables are generated first, reference their IDs\n\
         3. Respect enum types - use only valid enum values\n\
         4. Omit auto-increment/serial/identity columns (they have DEFAULT auto-increment)\n\
         5. Generate realistic data: real names, valid emails, plausible dates, etc.\n\
         6. Respect NOT NULL constraints\n\
         7. For UUID columns, generate valid UUIDs\n\
         8. For timestamp columns, use ISO 8601 format\n\
         \n\
         Tables to generate (in this exact order):\n\
         {}\n\
         \n\
         Custom instructions: {}\n\
         \n\
         {}\n",
        tables_desc.join("\n"),
        custom,
        schema_text
    );
    let raw = call_ollama_chat(
        &app,
        &state,
        system_prompt,
        format!("Generate test data for {} tables", sorted_tables.len()),
    )
    .await?;
    let _ = app.emit("datagen-progress", DataGenProgress {
        gen_id: gen_id.clone(),
        stage: "parsing".to_string(),
        percent: 80,
        message: "Parsing generated data...".to_string(),
        detail: None,
    });
    // Parse JSON response: take the outermost {...} span to tolerate stray
    // prose around the JSON.
    // NOTE(review): if the reply's only '{' appears after its last '}', the
    // slice below has start > end and panics — consider a guarded match like
    // the other JSON-extraction sites.
    // NOTE(review): the 500-byte truncation in the error message below can
    // split a multi-byte UTF-8 character and panic; prefer a
    // char-boundary-safe cut.
    let cleaned = raw.trim();
    let json_start = cleaned.find('{').unwrap_or(0);
    let json_end = cleaned.rfind('}').map(|i| i + 1).unwrap_or(cleaned.len());
    let json_str = &cleaned[json_start..json_end];
    let data_map: HashMap<String, Vec<HashMap<String, Value>>> =
        serde_json::from_str(json_str).map_err(|e| {
            TuskError::Ai(format!("Failed to parse generated data: {}. Response: {}", e, &cleaned[..cleaned.len().min(500)]))
        })?;
    // Re-shape keyed row maps into column/row vectors, preserving the
    // FK-safe table order. Column order comes from the first row's keys
    // (HashMap iteration order, i.e. arbitrary but consistent within one
    // table); keys missing from later rows become NULL. Tables the model
    // did not return are silently skipped.
    let mut tables = Vec::new();
    let mut total_rows: u32 = 0;
    for (schema, table) in &sorted_tables {
        let key = format!("{}.{}", schema, table);
        if let Some(rows_data) = data_map.get(&key) {
            let columns: Vec<String> = if let Some(first) = rows_data.first() {
                first.keys().cloned().collect()
            } else {
                Vec::new()
            };
            let rows: Vec<Vec<Value>> = rows_data
                .iter()
                .map(|row_map| columns.iter().map(|col| row_map.get(col).cloned().unwrap_or(Value::Null)).collect())
                .collect();
            let count = rows.len() as u32;
            total_rows += count;
            tables.push(GeneratedTableData {
                schema: schema.clone(),
                table: table.clone(),
                columns,
                rows,
                row_count: count,
            });
        }
    }
    let _ = app.emit("datagen-progress", DataGenProgress {
        gen_id: gen_id.clone(),
        stage: "done".to_string(),
        percent: 100,
        message: "Data generation complete".to_string(),
        detail: Some(format!("{} rows across {} tables", total_rows, tables.len())),
    });
    Ok(GeneratedDataPreview {
        tables,
        insert_order,
        total_rows,
    })
}
#[tauri::command]
pub async fn insert_generated_data(
state: State<'_, Arc<AppState>>,
connection_id: String,
preview: GeneratedDataPreview,
) -> TuskResult<u64> {
if state.is_read_only(&connection_id).await {
return Err(TuskError::ReadOnly);
}
let pool = state.get_pool(&connection_id).await?;
let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
// Defer constraints for circular FKs
sqlx::query("SET CONSTRAINTS ALL DEFERRED")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
let mut total_inserted: u64 = 0;
for table_data in &preview.tables {
if table_data.columns.is_empty() || table_data.rows.is_empty() {
continue;
}
let qualified = format!(
"{}.{}",
escape_ident(&table_data.schema),
escape_ident(&table_data.table)
);
let col_list: Vec<String> = table_data.columns.iter().map(|c| escape_ident(c)).collect();
let placeholders: Vec<String> = (1..=table_data.columns.len())
.map(|i| format!("${}", i))
.collect();
let sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
qualified,
col_list.join(", "),
placeholders.join(", ")
);
for row in &table_data.rows {
let mut query = sqlx::query(&sql);
for val in row {
query = bind_json_value(query, val);
}
query.execute(&mut *tx).await.map_err(TuskError::Database)?;
total_inserted += 1;
}
}
tx.commit().await.map_err(TuskError::Database)?;
// Invalidate schema cache since data changed
state.invalidate_schema_cache(&connection_id).await;
Ok(total_inserted)
}
// ---------------------------------------------------------------------------
// Wave 3A: Smart Index Advisor
// ---------------------------------------------------------------------------
/// Collects table/index usage statistics (and slow queries when
/// `pg_stat_statements` is installed), then asks the model for index
/// recommendations. Statistics are always returned even if the AI reply
/// cannot be parsed.
#[tauri::command]
pub async fn get_index_advisor_report(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
) -> TuskResult<IndexAdvisorReport> {
    let pool = state.get_pool(&connection_id).await?;
    // Fetch table stats: top 50 tables by sequential-scan count.
    // NOTE(review): pg_total_relation_size receives the unquoted
    // `schemaname || '.' || relname` while pg_indexes_size gets quote_ident'd
    // input — the first will fail for mixed-case/special table names;
    // consider quoting both.
    let table_stats_rows = sqlx::query(
        "SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \
         pg_size_pretty(pg_total_relation_size(schemaname || '.' || relname)) AS table_size, \
         pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \
         FROM pg_stat_user_tables \
         ORDER BY seq_scan DESC \
         LIMIT 50"
    )
    .fetch_all(&pool)
    .await
    .map_err(TuskError::Database)?;
    let table_stats: Vec<TableStats> = table_stats_rows
        .iter()
        .map(|r| TableStats {
            schema: r.get(0),
            table: r.get(1),
            seq_scan: r.get(2),
            idx_scan: r.get(3),
            n_live_tup: r.get(4),
            table_size: r.get(5),
            index_size: r.get(6),
        })
        .collect();
    // Fetch index stats: least-used indexes first (drop candidates).
    let index_stats_rows = sqlx::query(
        "SELECT schemaname, relname, indexrelname, idx_scan, \
         pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \
         pg_get_indexdef(indexrelid) AS definition \
         FROM pg_stat_user_indexes \
         ORDER BY idx_scan ASC \
         LIMIT 50"
    )
    .fetch_all(&pool)
    .await
    .map_err(TuskError::Database)?;
    let index_stats: Vec<IndexStats> = index_stats_rows
        .iter()
        .map(|r| IndexStats {
            schema: r.get(0),
            table: r.get(1),
            index_name: r.get(2),
            idx_scan: r.get(3),
            index_size: r.get(4),
            definition: r.get(5),
        })
        .collect();
    // Fetch slow queries (graceful if pg_stat_statements not available):
    // any error — typically "relation does not exist" when the extension is
    // missing — is treated as "extension absent".
    let (slow_queries, has_pg_stat_statements) = match sqlx::query(
        "SELECT query, calls, total_exec_time, mean_exec_time, rows \
         FROM pg_stat_statements \
         WHERE calls > 0 \
         ORDER BY mean_exec_time DESC \
         LIMIT 20"
    )
    .fetch_all(&pool)
    .await
    {
        Ok(rows) => {
            let queries: Vec<SlowQuery> = rows
                .iter()
                .map(|r| SlowQuery {
                    query: r.get(0),
                    calls: r.get(1),
                    total_time_ms: r.get(2),
                    mean_time_ms: r.get(3),
                    rows: r.get(4),
                })
                .collect();
            (queries, true)
        }
        Err(_) => (Vec::new(), false),
    };
    // Build AI prompt for recommendations: schema context plus a plain-text
    // digest of the statistics gathered above.
    let schema_text = build_schema_context(&state, &connection_id).await?;
    let mut stats_text = String::from("TABLE STATISTICS:\n");
    for ts in &table_stats {
        stats_text.push_str(&format!(
            "  {}.{}: seq_scan={}, idx_scan={}, rows={}, size={}, idx_size={}\n",
            ts.schema, ts.table, ts.seq_scan, ts.idx_scan, ts.n_live_tup, ts.table_size, ts.index_size
        ));
    }
    stats_text.push_str("\nINDEX STATISTICS:\n");
    for is in &index_stats {
        stats_text.push_str(&format!(
            "  {}.{}.{}: scans={}, size={}, def={}\n",
            is.schema, is.table, is.index_name, is.idx_scan, is.index_size, is.definition
        ));
    }
    if !slow_queries.is_empty() {
        stats_text.push_str("\nSLOW QUERIES:\n");
        for sq in &slow_queries {
            // Query text is truncated to 200 chars to keep the prompt small.
            stats_text.push_str(&format!(
                "  calls={}, mean={:.1}ms, total={:.1}ms, rows={}: {}\n",
                sq.calls, sq.mean_time_ms, sq.total_time_ms, sq.rows,
                sq.query.chars().take(200).collect::<String>()
            ));
        }
    }
    let system_prompt = format!(
        "You are a PostgreSQL performance expert. Analyze the database statistics and recommend index changes.\n\
         \n\
         OUTPUT FORMAT:\n\
         - Return ONLY a JSON array of recommendation objects.\n\
         - No markdown, no explanations, no code fences.\n\
         - Each object: {{\"recommendation_type\": \"create_index\"|\"drop_index\"|\"replace_index\", \
         \"table_schema\": \"...\", \"table_name\": \"...\", \"index_name\": \"...\"|null, \
         \"ddl\": \"CREATE INDEX CONCURRENTLY ...\", \"rationale\": \"...\", \
         \"estimated_impact\": \"high\"|\"medium\"|\"low\", \"priority\": \"high\"|\"medium\"|\"low\"}}\n\
         \n\
         RULES:\n\
         1. Prefer CREATE INDEX CONCURRENTLY to avoid locking\n\
         2. Never suggest dropping PRIMARY KEY or UNIQUE indexes\n\
         3. Suggest dropping indexes with 0 scans on tables with many rows\n\
         4. Suggest composite indexes for commonly co-filtered columns\n\
         5. Suggest partial indexes for low-cardinality boolean columns\n\
         6. Consider covering indexes for frequently selected columns\n\
         7. High seq_scan + high row count = strong candidate for new index\n\
         \n\
         {}\n\
         \n\
         {}\n",
        stats_text, schema_text
    );
    let raw = call_ollama_chat(
        &app,
        &state,
        system_prompt,
        "Analyze indexes and provide recommendations".to_string(),
    )
    .await?;
    // Extract the outermost [...] span from the reply.
    // NOTE(review): if a ']' precedes the first '[', json_start > json_end
    // and the slice panics — consider the guarded-match pattern.
    // NOTE(review): unwrap_or_default() silently yields zero recommendations
    // on ANY parse failure; in particular the prompt's object spec above has
    // no \"id\" field, while IndexRecommendation requires `id`, so
    // deserialization will fail unless `id` has a serde default — confirm
    // against the model definition.
    let cleaned = raw.trim();
    let json_start = cleaned.find('[').unwrap_or(0);
    let json_end = cleaned.rfind(']').map(|i| i + 1).unwrap_or(cleaned.len());
    let json_str = &cleaned[json_start..json_end];
    let recommendations: Vec<IndexRecommendation> =
        serde_json::from_str(json_str).unwrap_or_default();
    Ok(IndexAdvisorReport {
        table_stats,
        index_stats,
        slow_queries,
        recommendations,
        has_pg_stat_statements,
    })
}
/// Applies a single index DDL statement produced by the advisor.
///
/// Accepts only `CREATE [UNIQUE] INDEX ...` and `DROP INDEX ...` (trailing
/// semicolon tolerated); multi-statement payloads are rejected. Honors the
/// connection's read-only flag.
#[tauri::command]
pub async fn apply_index_recommendation(
    state: State<'_, Arc<AppState>>,
    connection_id: String,
    ddl: String,
) -> TuskResult<()> {
    if state.is_read_only(&connection_id).await {
        return Err(TuskError::ReadOnly);
    }
    let trimmed = ddl.trim().trim_end_matches(';').trim_end();
    let ddl_upper = trimmed.to_uppercase();
    // CREATE UNIQUE INDEX is also legitimate advisor output, so allow it.
    let allowed = ddl_upper.starts_with("CREATE INDEX")
        || ddl_upper.starts_with("CREATE UNIQUE INDEX")
        || ddl_upper.starts_with("DROP INDEX");
    if !allowed {
        return Err(TuskError::Custom(
            "Only CREATE INDEX and DROP INDEX statements are allowed".to_string(),
        ));
    }
    // An interior semicolon would smuggle in a second statement after the
    // vetted prefix; reject outright.
    if trimmed.contains(';') {
        return Err(TuskError::Custom(
            "Multiple SQL statements are not allowed".to_string(),
        ));
    }
    let pool = state.get_pool(&connection_id).await?;
    // CONCURRENTLY cannot run inside a transaction block, so execute the
    // statement directly on the pool.
    sqlx::query(trimmed)
        .execute(&pool)
        .await
        .map_err(TuskError::Database)?;
    // Index set changed; cached schema info is stale.
    state.invalidate_schema_cache(&connection_id).await;
    Ok(())
}

View File

@@ -283,7 +283,7 @@ pub async fn delete_rows(
Ok(total_affected)
}
fn bind_json_value<'q>(
pub(crate) fn bind_json_value<'q>(
query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
value: &'q Value,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {

View File

@@ -9,4 +9,5 @@ pub mod management;
pub mod queries;
pub mod saved_queries;
pub mod schema;
pub mod snapshot;
pub mod settings;

View File

@@ -0,0 +1,349 @@
use crate::commands::ai::fetch_foreign_keys_raw;
use crate::commands::data::bind_json_value;
use crate::commands::queries::pg_value_to_json;
use crate::error::{TuskError, TuskResult};
use crate::models::snapshot::{
CreateSnapshotParams, RestoreSnapshotParams, Snapshot, SnapshotMetadata, SnapshotProgress,
SnapshotTableData, SnapshotTableMeta,
};
use crate::state::AppState;
use crate::utils::{escape_ident, topological_sort_tables};
use serde_json::Value;
use sqlx::{Column, Row, TypeInfo};
use std::fs;
use std::sync::Arc;
use tauri::{AppHandle, Emitter, Manager, State};
#[tauri::command]
pub async fn create_snapshot(
app: AppHandle,
state: State<'_, Arc<AppState>>,
params: CreateSnapshotParams,
snapshot_id: String,
file_path: String,
) -> TuskResult<SnapshotMetadata> {
let pool = state.get_pool(&params.connection_id).await?;
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "preparing".to_string(),
percent: 5,
message: "Preparing snapshot...".to_string(),
detail: None,
},
);
let mut target_tables: Vec<(String, String)> = params
.tables
.iter()
.map(|t| (t.schema.clone(), t.table.clone()))
.collect();
if params.include_dependencies {
let fk_rows = fetch_foreign_keys_raw(&pool).await?;
for fk in &fk_rows {
for (schema, table) in &params.tables.iter().map(|t| (t.schema.clone(), t.table.clone())).collect::<Vec<_>>() {
if &fk.schema == schema && &fk.table == table {
let parent = (fk.ref_schema.clone(), fk.ref_table.clone());
if !target_tables.contains(&parent) {
target_tables.push(parent);
}
}
}
}
}
// FK-based topological sort
let fk_rows = fetch_foreign_keys_raw(&pool).await?;
let fk_edges: Vec<(String, String, String, String)> = fk_rows
.iter()
.map(|fk| (fk.schema.clone(), fk.table.clone(), fk.ref_schema.clone(), fk.ref_table.clone()))
.collect();
let sorted_tables = topological_sort_tables(&fk_edges, &target_tables);
let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
sqlx::query("SET TRANSACTION READ ONLY")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
let total_tables = sorted_tables.len();
let mut snapshot_tables: Vec<SnapshotTableData> = Vec::new();
let mut table_metas: Vec<SnapshotTableMeta> = Vec::new();
let mut total_rows: u64 = 0;
for (i, (schema, table)) in sorted_tables.iter().enumerate() {
let percent = 10 + ((i as u8) * 80 / total_tables.max(1) as u8);
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "exporting".to_string(),
percent,
message: format!("Exporting {}.{}...", schema, table),
detail: None,
},
);
let qualified = format!("{}.{}", escape_ident(schema), escape_ident(table));
let sql = format!("SELECT * FROM {}", qualified);
let rows = sqlx::query(&sql)
.fetch_all(&mut *tx)
.await
.map_err(TuskError::Database)?;
let mut columns = Vec::new();
let mut column_types = Vec::new();
if let Some(first) = rows.first() {
for col in first.columns() {
columns.push(col.name().to_string());
column_types.push(col.type_info().name().to_string());
}
}
let data_rows: Vec<Vec<Value>> = rows
.iter()
.map(|row| (0..columns.len()).map(|i| pg_value_to_json(row, i)).collect())
.collect();
let row_count = data_rows.len() as u64;
total_rows += row_count;
table_metas.push(SnapshotTableMeta {
schema: schema.clone(),
table: table.clone(),
row_count,
columns: columns.clone(),
column_types: column_types.clone(),
});
snapshot_tables.push(SnapshotTableData {
schema: schema.clone(),
table: table.clone(),
columns,
column_types,
rows: data_rows,
});
}
tx.rollback().await.map_err(TuskError::Database)?;
let metadata = SnapshotMetadata {
id: snapshot_id.clone(),
name: params.name.clone(),
created_at: chrono::Utc::now().to_rfc3339(),
connection_name: String::new(),
database: String::new(),
tables: table_metas,
total_rows,
file_size_bytes: 0,
version: 1,
};
let snapshot = Snapshot {
metadata: metadata.clone(),
tables: snapshot_tables,
};
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "saving".to_string(),
percent: 95,
message: "Saving snapshot file...".to_string(),
detail: None,
},
);
let json = serde_json::to_string_pretty(&snapshot)?;
let file_size = json.len() as u64;
fs::write(&file_path, json)?;
let mut final_metadata = metadata;
final_metadata.file_size_bytes = file_size;
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "done".to_string(),
percent: 100,
message: "Snapshot created successfully".to_string(),
detail: Some(format!("{} rows, {} tables", total_rows, total_tables)),
},
);
Ok(final_metadata)
}
#[tauri::command]
pub async fn restore_snapshot(
app: AppHandle,
state: State<'_, Arc<AppState>>,
params: RestoreSnapshotParams,
snapshot_id: String,
) -> TuskResult<u64> {
if state.is_read_only(&params.connection_id).await {
return Err(TuskError::ReadOnly);
}
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "reading".to_string(),
percent: 5,
message: "Reading snapshot file...".to_string(),
detail: None,
},
);
let data = fs::read_to_string(&params.file_path)?;
let snapshot: Snapshot = serde_json::from_str(&data)?;
let pool = state.get_pool(&params.connection_id).await?;
let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
sqlx::query("SET CONSTRAINTS ALL DEFERRED")
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
// TRUNCATE in reverse order (children first)
if params.truncate_before_restore {
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "truncating".to_string(),
percent: 15,
message: "Truncating existing data...".to_string(),
detail: None,
},
);
for table_data in snapshot.tables.iter().rev() {
let qualified = format!(
"{}.{}",
escape_ident(&table_data.schema),
escape_ident(&table_data.table)
);
let truncate_sql = format!("TRUNCATE {} CASCADE", qualified);
sqlx::query(&truncate_sql)
.execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
}
}
// INSERT in forward order (parents first)
let total_tables = snapshot.tables.len();
let mut total_inserted: u64 = 0;
for (i, table_data) in snapshot.tables.iter().enumerate() {
if table_data.columns.is_empty() || table_data.rows.is_empty() {
continue;
}
let percent = 20 + ((i as u8) * 75 / total_tables.max(1) as u8);
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "inserting".to_string(),
percent,
message: format!("Restoring {}.{}...", table_data.schema, table_data.table),
detail: Some(format!("{} rows", table_data.rows.len())),
},
);
let qualified = format!(
"{}.{}",
escape_ident(&table_data.schema),
escape_ident(&table_data.table)
);
let col_list: Vec<String> = table_data.columns.iter().map(|c| escape_ident(c)).collect();
let placeholders: Vec<String> = (1..=table_data.columns.len())
.map(|i| format!("${}", i))
.collect();
let sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
qualified,
col_list.join(", "),
placeholders.join(", ")
);
// Chunked insert
for row in &table_data.rows {
let mut query = sqlx::query(&sql);
for val in row {
query = bind_json_value(query, val);
}
query.execute(&mut *tx).await.map_err(TuskError::Database)?;
total_inserted += 1;
}
}
tx.commit().await.map_err(TuskError::Database)?;
let _ = app.emit(
"snapshot-progress",
SnapshotProgress {
snapshot_id: snapshot_id.clone(),
stage: "done".to_string(),
percent: 100,
message: "Restore completed successfully".to_string(),
detail: Some(format!("{} rows restored", total_inserted)),
},
);
state.invalidate_schema_cache(&params.connection_id).await;
Ok(total_inserted)
}
/// Scans the app-data `snapshots` directory and returns the metadata of
/// every parseable `*.json` snapshot file, newest first (by `created_at`).
/// Unreadable or unparseable files are skipped silently.
#[tauri::command]
pub async fn list_snapshots(app: AppHandle) -> TuskResult<Vec<SnapshotMetadata>> {
    let dir = app
        .path()
        .app_data_dir()
        .map_err(|e| TuskError::Custom(e.to_string()))?
        .join("snapshots");
    if !dir.exists() {
        return Ok(Vec::new());
    }
    let mut snapshots: Vec<SnapshotMetadata> = Vec::new();
    for entry in fs::read_dir(&dir)? {
        let entry = entry?;
        let path = entry.path();
        // Only *.json files are candidates.
        let is_json = matches!(path.extension(), Some(ext) if ext == "json");
        if !is_json {
            continue;
        }
        let Ok(data) = fs::read_to_string(&path) else {
            continue;
        };
        let Ok(snapshot) = serde_json::from_str::<Snapshot>(&data) else {
            continue;
        };
        let mut meta = snapshot.metadata;
        // Report the current on-disk size rather than whatever was recorded
        // at creation time.
        meta.file_size_bytes = entry.metadata().map(|m| m.len()).unwrap_or(0);
        snapshots.push(meta);
    }
    snapshots.sort_by(|a, b| b.created_at.cmp(&a.created_at));
    Ok(snapshots)
}
/// Loads a snapshot file and returns just its metadata section, with
/// `file_size_bytes` refreshed from the filesystem (0 if unavailable).
#[tauri::command]
pub async fn read_snapshot_metadata(file_path: String) -> TuskResult<SnapshotMetadata> {
    let contents = fs::read_to_string(&file_path)?;
    let snapshot: Snapshot = serde_json::from_str(&contents)?;
    let size = fs::metadata(&file_path).map(|m| m.len()).unwrap_or(0);
    let mut meta = snapshot.metadata;
    meta.file_size_bytes = size;
    Ok(meta)
}

View File

@@ -135,6 +135,18 @@ pub fn run() {
commands::ai::generate_sql,
commands::ai::explain_sql,
commands::ai::fix_sql_error,
commands::ai::generate_validation_sql,
commands::ai::run_validation_rule,
commands::ai::suggest_validation_rules,
commands::ai::generate_test_data_preview,
commands::ai::insert_generated_data,
commands::ai::get_index_advisor_report,
commands::ai::apply_index_recommendation,
// snapshot
commands::snapshot::create_snapshot,
commands::snapshot::restore_snapshot,
commands::snapshot::list_snapshots,
commands::snapshot::read_snapshot_metadata,
// lookup
commands::lookup::entity_lookup,
// docker

View File

@@ -57,3 +57,137 @@ pub struct OllamaTagsResponse {
pub struct OllamaModel {
pub name: String,
}
// --- Wave 1: Validation ---
/// Lifecycle state of a validation rule, serialized in snake_case for the
/// frontend.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ValidationStatus {
    /// Not yet processed.
    Pending,
    /// AI is generating the validation SQL.
    Generating,
    /// The generated SQL is executing.
    Running,
    /// Query ran; zero violating rows found.
    Passed,
    /// Query ran; at least one violating row found.
    Failed,
    /// Generation or execution failed; see `ValidationRule::error`.
    Error,
}
/// One data-quality rule: its description, the SQL generated for it, and
/// the outcome of running that SQL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationRule {
    /// Returned empty by `run_validation_rule`; presumably assigned by the
    /// caller — confirm against the frontend.
    pub id: String,
    /// Natural-language rule text; also empty when produced by the backend.
    pub description: String,
    /// The SELECT statement that finds violating rows.
    pub generated_sql: String,
    pub status: ValidationStatus,
    /// Total number of violating rows (COUNT over the full query).
    pub violation_count: u64,
    /// Up to `sample_limit` violating rows, as JSON values per column.
    pub sample_violations: Vec<Vec<serde_json::Value>>,
    /// Column names corresponding to each row in `sample_violations`.
    pub violation_columns: Vec<String>,
    pub error: Option<String>,
}
/// Aggregate result over a batch of validation rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationReport {
    pub rules: Vec<ValidationRule>,
    pub total_rules: usize,
    pub passed: usize,
    pub failed: usize,
    pub errors: usize,
    pub execution_time_ms: u128,
}
// --- Wave 2: Data Generator ---
/// Input for `generate_test_data_preview`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateDataParams {
    pub connection_id: String,
    /// Schema of the target table to fill.
    pub schema: String,
    /// Target table name.
    pub table: String,
    /// Requested row count for the target table (capped at 1000 backend-side).
    pub row_count: u32,
    /// Also generate rows for direct FK parent tables.
    pub include_related: bool,
    /// Free-form guidance forwarded into the AI prompt.
    pub custom_instructions: Option<String>,
}
/// AI-generated data staged for review before insertion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedDataPreview {
    pub tables: Vec<GeneratedTableData>,
    /// "schema.table" strings in FK-safe (parents-first) insert order.
    pub insert_order: Vec<String>,
    pub total_rows: u32,
}
/// Generated rows for a single table, column-aligned.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedTableData {
    pub schema: String,
    pub table: String,
    pub columns: Vec<String>,
    /// Each inner vec holds one value per entry in `columns`, in order.
    pub rows: Vec<Vec<serde_json::Value>>,
    pub row_count: u32,
}
/// Payload for `datagen-progress` events emitted during generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataGenProgress {
    /// Correlates events with the request that started the generation.
    pub gen_id: String,
    /// One of "schema", "generating", "parsing", "done".
    pub stage: String,
    /// 0-100.
    pub percent: u8,
    pub message: String,
    pub detail: Option<String>,
}
// --- Wave 3A: Index Advisor ---
/// Per-table usage statistics for the index advisor
/// (field names mirror PostgreSQL's pg_stat tables — presumably sourced
/// from pg_stat_user_tables; confirm in the collecting query).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableStats {
    /// Schema the table lives in.
    pub schema: String,
    /// Table name.
    pub table: String,
    /// Number of sequential scans performed on the table.
    pub seq_scan: i64,
    /// Number of index scans performed on the table.
    pub idx_scan: i64,
    /// Estimated number of live rows.
    pub n_live_tup: i64,
    /// Human-readable table size (presumably pg_size_pretty output — confirm).
    pub table_size: String,
    /// Human-readable total index size for the table (same format as `table_size`).
    pub index_size: String,
}
/// Per-index usage statistics for the index advisor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
    /// Schema the indexed table lives in.
    pub schema: String,
    /// Table the index belongs to.
    pub table: String,
    /// Name of the index.
    pub index_name: String,
    /// Number of scans that used this index (0 suggests an unused index).
    pub idx_scan: i64,
    /// Human-readable index size string.
    pub index_size: String,
    /// Index definition DDL (presumably pg_indexes.indexdef — confirm).
    pub definition: String,
}
/// A slow query entry for the index advisor, sourced from the
/// pg_stat_statements extension (see `IndexAdvisorReport::has_pg_stat_statements`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SlowQuery {
    /// Normalized query text.
    pub query: String,
    /// Number of times the statement was executed.
    pub calls: i64,
    /// Total execution time across all calls, in milliseconds.
    pub total_time_ms: f64,
    /// Mean execution time per call, in milliseconds.
    pub mean_time_ms: f64,
    /// Rows figure reported by pg_stat_statements (total rows retrieved or
    /// affected across calls — presumably; confirm against the query).
    pub rows: i64,
}
/// Kind of action an index recommendation proposes.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IndexRecommendationType {
    /// Create a new index.
    CreateIndex,
    /// Drop an existing (e.g. unused or redundant) index.
    DropIndex,
    /// Replace an existing index with a different definition.
    ReplaceIndex,
}
/// A single AI-produced index recommendation, applyable via its `ddl`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexRecommendation {
    /// Unique identifier of the recommendation.
    pub id: String,
    /// What kind of change is proposed (create/drop/replace).
    pub recommendation_type: IndexRecommendationType,
    /// Schema of the affected table.
    pub table_schema: String,
    /// Name of the affected table.
    pub table_name: String,
    /// Existing index targeted by drop/replace recommendations
    /// (presumably `None` for pure creates — confirm in the producer).
    pub index_name: Option<String>,
    /// SQL statement(s) to execute to apply the recommendation.
    pub ddl: String,
    /// AI's explanation of why this change is recommended.
    pub rationale: String,
    /// Free-text description of the expected performance impact.
    pub estimated_impact: String,
    /// Free-text priority label (exact values defined by the AI prompt — confirm).
    pub priority: String,
}
/// Full index-advisor report: raw statistics plus AI recommendations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexAdvisorReport {
    /// Per-table scan/size statistics.
    pub table_stats: Vec<TableStats>,
    /// Per-index usage statistics.
    pub index_stats: Vec<IndexStats>,
    /// Slow queries from pg_stat_statements (presumably empty when the
    /// extension is unavailable — see `has_pg_stat_statements`).
    pub slow_queries: Vec<SlowQuery>,
    /// AI-generated recommendations derived from the statistics above.
    pub recommendations: Vec<IndexRecommendation>,
    /// Whether the pg_stat_statements extension was available on the server.
    pub has_pg_stat_statements: bool,
}

View File

@@ -8,3 +8,4 @@ pub mod query_result;
pub mod saved_queries;
pub mod schema;
pub mod settings;
pub mod snapshot;

View File

@@ -0,0 +1,68 @@
use serde::{Deserialize, Serialize};
/// Descriptive header of a data snapshot: what was captured, from where, and when.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotMetadata {
    /// Unique identifier of the snapshot.
    pub id: String,
    /// User-chosen snapshot name.
    pub name: String,
    /// Creation timestamp as a string (format set by the writer — confirm; likely ISO 8601).
    pub created_at: String,
    /// Name of the connection the snapshot was taken from.
    pub connection_name: String,
    /// Database the snapshot was taken from.
    pub database: String,
    /// Per-table metadata for every table included in the snapshot.
    pub tables: Vec<SnapshotTableMeta>,
    /// Total number of rows across all tables.
    pub total_rows: u64,
    /// Size of the snapshot file on disk, in bytes.
    pub file_size_bytes: u64,
    /// Snapshot file-format version, for forward compatibility on restore.
    pub version: u32,
}
/// Metadata about one table captured in a snapshot (no row data).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotTableMeta {
    /// Schema the table lives in.
    pub schema: String,
    /// Table name.
    pub table: String,
    /// Number of rows captured for this table.
    pub row_count: u64,
    /// Column names, in captured order.
    pub columns: Vec<String>,
    /// Column type names, parallel to `columns`.
    pub column_types: Vec<String>,
}
/// A complete snapshot: metadata header plus the captured row data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snapshot {
    /// Descriptive header (tables, counts, origin, format version).
    pub metadata: SnapshotMetadata,
    /// Captured rows for each table.
    pub tables: Vec<SnapshotTableData>,
}
/// Captured rows for one table in a snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotTableData {
    /// Schema the table lives in.
    pub schema: String,
    /// Table name.
    pub table: String,
    /// Column names; each row's values are in this order.
    pub columns: Vec<String>,
    /// Column type names, parallel to `columns` (used when rebinding values on restore).
    pub column_types: Vec<String>,
    /// Row-major data; each inner vec aligns with `columns`.
    pub rows: Vec<Vec<serde_json::Value>>,
}
/// Progress event emitted while creating or restoring a snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SnapshotProgress {
    /// Correlates progress events with a single snapshot operation.
    pub snapshot_id: String,
    /// Machine-readable stage name (values defined by the emitter).
    pub stage: String,
    /// Completion percentage (presumably 0–100 — confirm at the emit sites).
    pub percent: u8,
    /// Human-readable status message.
    pub message: String,
    /// Optional extra detail for the current stage.
    pub detail: Option<String>,
}
/// Request parameters for creating a snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateSnapshotParams {
    /// Identifier of the database connection to snapshot from.
    pub connection_id: String,
    /// Tables to export.
    pub tables: Vec<TableRef>,
    /// User-chosen snapshot name.
    pub name: String,
    /// Whether to also include tables referenced via foreign keys
    /// (presumably FK parents of the selected tables — confirm in the command).
    pub include_dependencies: bool,
}
/// A schema-qualified reference to a table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TableRef {
    /// Schema the table lives in.
    pub schema: String,
    /// Table name.
    pub table: String,
}
/// Request parameters for restoring a snapshot from disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RestoreSnapshotParams {
    /// Identifier of the database connection to restore into.
    pub connection_id: String,
    /// Path to the snapshot file on disk.
    pub file_path: String,
    /// Whether to truncate the target tables before inserting the snapshot rows.
    pub truncate_before_restore: bool,
}

View File

@@ -1,3 +1,75 @@
use std::collections::{BinaryHeap, HashMap, HashSet};
/// Quote a SQL identifier for safe interpolation into DDL/DML,
/// doubling any embedded double quotes per SQL quoting rules.
pub fn escape_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            // An embedded quote is escaped by writing it twice.
            quoted.push_str("\"\"");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
/// Topological sort of tables based on foreign key dependencies.
/// Returns tables in insertion order: parents before children.
///
/// `fk_edges` lists FK relationships as (schema, table, ref_schema, ref_table),
/// i.e. `schema.table` references `ref_schema.ref_table`. Self-referencing FKs
/// and edges touching tables outside `target_tables` are ignored. Tables caught
/// in an FK cycle are appended at the end in their `target_tables` order.
///
/// Tie-breaking is deterministic: among ready tables, the lexicographically
/// greatest (schema, table) pair is emitted first (this matches the previous
/// sort-then-pop behavior exactly).
pub fn topological_sort_tables(
    fk_edges: &[(String, String, String, String)], // (schema, table, ref_schema, ref_table)
    target_tables: &[(String, String)],
) -> Vec<(String, String)> {
    let mut graph: HashMap<(String, String), HashSet<(String, String)>> = HashMap::new();
    let mut in_degree: HashMap<(String, String), usize> = HashMap::new();
    // Initialize all target tables so isolated tables still appear in the output.
    for t in target_tables {
        graph.entry(t.clone()).or_default();
        in_degree.entry(t.clone()).or_insert(0);
    }
    let target_set: HashSet<(String, String)> = target_tables.iter().cloned().collect();
    // Build edges: parent -> child (child depends on parent).
    for (schema, table, ref_schema, ref_table) in fk_edges {
        let child = (schema.clone(), table.clone());
        let parent = (ref_schema.clone(), ref_table.clone());
        if child == parent {
            continue; // self-referencing FK imposes no ordering constraint
        }
        if !target_set.contains(&child) || !target_set.contains(&parent) {
            continue; // edge involves a table we are not sorting
        }
        // HashSet::insert dedupes multiple FKs between the same pair of tables
        // so each dependency contributes exactly one in-degree.
        if graph.entry(parent.clone()).or_default().insert(child.clone()) {
            *in_degree.entry(child).or_insert(0) += 1;
        }
    }
    // Kahn's algorithm. A max-heap replaces the previous Vec that was fully
    // re-sorted after every push: identical deterministic pop order (always
    // the max element), but O(log n) per operation instead of O(n log n).
    let mut ready: BinaryHeap<(String, String)> = in_degree
        .iter()
        .filter(|(_, &deg)| deg == 0)
        .map(|(k, _)| k.clone())
        .collect();
    let mut result: Vec<(String, String)> = Vec::with_capacity(target_tables.len());
    while let Some(node) = ready.pop() {
        result.push(node.clone());
        if let Some(neighbors) = graph.get(&node) {
            for neighbor in neighbors {
                if let Some(deg) = in_degree.get_mut(neighbor) {
                    *deg -= 1;
                    if *deg == 0 {
                        ready.push(neighbor.clone());
                    }
                }
            }
        }
    }
    // Any tables not emitted are part of an FK cycle; append them in input
    // order. A HashSet lookup replaces the previous O(n^2) Vec::contains scan.
    let emitted: HashSet<&(String, String)> = result.iter().collect();
    let mut remainder: Vec<(String, String)> = target_tables
        .iter()
        .filter(|t| !emitted.contains(t))
        .cloned()
        .collect();
    result.append(&mut remainder);
    result
}