- Make reqwest::Client a LazyLock singleton instead of per-call allocation - Parallelize 3 independent DB queries in get_index_advisor_report with tokio::join! - Eliminate per-iteration Vec allocation in snapshot FK dependency loop - Hoist try_local_pg_dump() call in SampleData clone mode to avoid double execution - Evict stale schema cache entries on write to prevent unbounded memory growth - Remove unused ValidationReport struct and config_path field - Rename IndexRecommendationType variants to remove redundant suffix
1863 lines
62 KiB
Rust
1863 lines
62 KiB
Rust
use crate::commands::data::bind_json_value;
|
|
use crate::commands::queries::pg_value_to_json;
|
|
use crate::error::{TuskError, TuskResult};
|
|
use crate::models::ai::{
|
|
AiProvider, AiSettings, DataGenProgress, GenerateDataParams, GeneratedDataPreview,
|
|
GeneratedTableData, IndexAdvisorReport, IndexRecommendation, IndexStats, OllamaChatMessage,
|
|
OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse, SlowQuery, TableStats,
|
|
ValidationRule, ValidationStatus,
|
|
};
|
|
use crate::state::AppState;
|
|
use crate::utils::{escape_ident, topological_sort_tables};
|
|
use serde_json::Value;
|
|
use sqlx::{Column, Row};
|
|
use std::collections::{BTreeMap, HashMap};
|
|
use std::fs;
|
|
use std::sync::Arc;
|
|
use std::time::{Duration, Instant};
|
|
use tauri::{AppHandle, Emitter, Manager, State};
|
|
|
|
/// Total number of attempts for one AI request (initial try counts as attempt 1).
const MAX_RETRIES: u32 = 2;

/// Pause between consecutive retry attempts, in milliseconds.
const RETRY_DELAY_MS: u64 = 1000;
|
|
|
|
fn http_client() -> &'static reqwest::Client {
|
|
use std::sync::LazyLock;
|
|
static CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
|
|
reqwest::Client::builder()
|
|
.connect_timeout(Duration::from_secs(5))
|
|
.timeout(Duration::from_secs(300))
|
|
.build()
|
|
.unwrap_or_default()
|
|
});
|
|
&CLIENT
|
|
}
|
|
|
|
fn get_ai_settings_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
|
|
let dir = app
|
|
.path()
|
|
.app_data_dir()
|
|
.map_err(|e| TuskError::Custom(e.to_string()))?;
|
|
fs::create_dir_all(&dir)?;
|
|
Ok(dir.join("ai_settings.json"))
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub async fn get_ai_settings(app: AppHandle) -> TuskResult<AiSettings> {
|
|
let path = get_ai_settings_path(&app)?;
|
|
if !path.exists() {
|
|
return Ok(AiSettings::default());
|
|
}
|
|
let data = fs::read_to_string(&path)?;
|
|
let settings: AiSettings = serde_json::from_str(&data)?;
|
|
Ok(settings)
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub async fn save_ai_settings(
|
|
app: AppHandle,
|
|
state: State<'_, Arc<AppState>>,
|
|
settings: AiSettings,
|
|
) -> TuskResult<()> {
|
|
let path = get_ai_settings_path(&app)?;
|
|
let data = serde_json::to_string_pretty(&settings)?;
|
|
fs::write(&path, data)?;
|
|
// Update in-memory cache
|
|
*state.ai_settings.write().await = Some(settings);
|
|
Ok(())
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaModel>> {
|
|
let url = format!("{}/api/tags", ollama_url.trim_end_matches('/'));
|
|
let resp =
|
|
http_client().get(&url).send().await.map_err(|e| {
|
|
TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e))
|
|
})?;
|
|
|
|
if !resp.status().is_success() {
|
|
let status = resp.status();
|
|
let body = resp.text().await.unwrap_or_default();
|
|
return Err(TuskError::Ai(format!(
|
|
"Ollama error ({}): {}",
|
|
status, body
|
|
)));
|
|
}
|
|
|
|
let tags: OllamaTagsResponse = resp
|
|
.json()
|
|
.await
|
|
.map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?;
|
|
|
|
Ok(tags.models)
|
|
}
|
|
|
|
/// Runs `f` up to `MAX_RETRIES` times, sleeping `RETRY_DELAY_MS` between
/// attempts, and returns the first successful result.
///
/// `operation` is a human-readable label used only in log/error messages.
/// `_settings` is currently unused; it is kept in the signature so callers
/// can pass provider-specific retry configuration later without an API change.
///
/// Returns the error from the final attempt when every attempt fails.
async fn call_ai_with_retry<F, Fut, T>(
    _settings: &AiSettings,
    operation: &str,
    f: F,
) -> TuskResult<T>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = TuskResult<T>>,
{
    let mut last_error = None;

    for attempt in 0..MAX_RETRIES {
        match f().await {
            Ok(result) => return Ok(result),
            Err(e) => {
                last_error = Some(e);
                // Don't sleep after the final attempt — we're about to give up.
                if attempt < MAX_RETRIES - 1 {
                    log::warn!(
                        "{} failed (attempt {}/{}), retrying in {}ms...",
                        operation,
                        attempt + 1,
                        MAX_RETRIES,
                        RETRY_DELAY_MS
                    );
                    tokio::time::sleep(Duration::from_millis(RETRY_DELAY_MS)).await;
                }
            }
        }
    }

    // All attempts failed. The fallback message should be unreachable as long
    // as MAX_RETRIES > 0 guarantees at least one loop iteration set last_error.
    Err(last_error.unwrap_or_else(|| {
        TuskError::Ai(format!(
            "{} failed after {} attempts",
            operation, MAX_RETRIES
        ))
    }))
}
|
|
|
|
/// Loads AI settings, preferring the in-memory cache and falling back to the
/// settings file on disk; a successful disk read repopulates the cache.
///
/// Errors with a user-facing message when no settings file exists yet.
async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskResult<AiSettings> {
    // Try in-memory cache first
    if let Some(cached) = state.ai_settings.read().await.clone() {
        return Ok(cached);
    }
    // Fallback to disk
    let path = get_ai_settings_path(app)?;
    if !path.exists() {
        return Err(TuskError::Ai(
            "No AI model selected. Open AI settings to choose a model.".to_string(),
        ));
    }
    let data = fs::read_to_string(&path)?;
    let settings: AiSettings = serde_json::from_str(&data)?;
    // Populate cache for future calls
    *state.ai_settings.write().await = Some(settings.clone());
    Ok(settings)
}
|
|
|
|
/// Sends a system + user message pair to the configured Ollama model and
/// returns the assistant's reply text.
///
/// Fails early when no model is selected or the configured provider is not
/// Ollama. The HTTP call is wrapped in `call_ai_with_retry`, so the request
/// closure may run more than once.
async fn call_ollama_chat(
    app: &AppHandle,
    state: &AppState,
    system_prompt: String,
    user_content: String,
) -> TuskResult<String> {
    let settings = load_ai_settings(app, state).await?;

    if settings.model.is_empty() {
        return Err(TuskError::Ai(
            "No AI model selected. Open AI settings to choose a model.".to_string(),
        ));
    }

    if settings.provider != AiProvider::Ollama {
        return Err(TuskError::Ai(format!(
            "Provider {:?} not implemented yet",
            settings.provider
        )));
    }

    let model = settings.model.clone();
    let url = format!("{}/api/chat", settings.ollama_url.trim_end_matches('/'));

    // Non-streaming chat request: system prompt first, then the user message.
    let request = OllamaChatRequest {
        model: model.clone(),
        messages: vec![
            OllamaChatMessage {
                role: "system".to_string(),
                content: system_prompt,
            },
            OllamaChatMessage {
                role: "user".to_string(),
                content: user_content,
            },
        ],
        stream: false,
    };

    call_ai_with_retry(&settings, "Ollama request", || {
        // Clone per attempt: the retry helper may invoke this closure again.
        let url = url.clone();
        let request = request.clone();
        async move {
            let resp = http_client()
                .post(&url)
                .json(&request)
                .send()
                .await
                .map_err(|e| {
                    TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", url, e))
                })?;

            if !resp.status().is_success() {
                let status = resp.status();
                let body = resp.text().await.unwrap_or_default();
                return Err(TuskError::Ai(format!(
                    "Ollama error ({}): {}",
                    status, body
                )));
            }

            let chat_resp: OllamaChatResponse = resp
                .json()
                .await
                .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?;

            Ok(chat_resp.message.content)
        }
    })
    .await
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// SQL generation
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Tauri command: generates a PostgreSQL query from a natural-language prompt
/// using the configured AI model.
///
/// The connection's full schema context (tables, columns, FKs, enums,
/// comments) is built via `build_schema_context` and embedded in the system
/// prompt so the model only references real objects. The raw model output is
/// passed through `clean_sql_response` (defined elsewhere in this module)
/// before being returned.
#[tauri::command]
pub async fn generate_sql(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
    prompt: String,
) -> TuskResult<String> {
    let schema_text = build_schema_context(&state, &connection_id).await?;

    let system_prompt = format!(
        "You are an expert PostgreSQL query generator. You receive a database schema and a natural \
         language request. Output ONLY a valid, executable PostgreSQL SQL query.\n\
         \n\
         OUTPUT FORMAT:\n\
         - Raw SQL only. No explanations, no markdown code fences (```), no comments, no preamble.\n\
         - The output must be directly executable in psql.\n\
         - For complex queries use readable formatting with line breaks and indentation.\n\
         \n\
         CRITICAL RULES:\n\
         1. ONLY reference tables and columns that exist in the schema. Never invent names.\n\
         2. Use the FOREIGN KEY information to determine correct JOIN conditions.\n\
         3. Use LEFT JOIN when the FK column is nullable or the relationship is optional; \
         INNER JOIN when both sides must exist.\n\
         4. Every non-aggregated column in SELECT must appear in GROUP BY.\n\
         5. Use COALESCE for nullable columns in aggregations: COALESCE(SUM(x), 0).\n\
         6. For enum columns, use ONLY the values listed in the ENUM TYPES section.\n\
         7. Use IS NULL / IS NOT NULL for null checks — never = NULL or != NULL.\n\
         8. Add LIMIT when the user asks for \"top N\", \"first N\", \"latest N\", etc.\n\
         9. Qualify column names with table alias when the query involves multiple tables.\n\
         \n\
         SEMANTIC RULES (very important):\n\
         - When a table has both actual_* and planned_* columns (e.g. actual_start vs planned_start), \
         they represent DIFFERENT concepts: planned = future estimate, actual = what really happened. \
         NEVER mix them with COALESCE unless the user explicitly requests a fallback.\n\
         - For time-based calculations involving real events (\"how long did X take\", \"average time between\"), \
         use ONLY actual/factual timestamps (actual_*, started_at, completed_at, ended_at). \
         Filter out NULL values with WHERE instead of falling back to planned timestamps.\n\
         - Planned timestamps (planned_*, scheduled_*, estimated_*) should ONLY be used when the user \
         asks about plans, schedules, SLA, or compares plan vs fact.\n\
         - When computing durations or averages, always filter out rows where any involved timestamp \
         is NULL rather than substituting with unrelated defaults.\n\
         - Pay attention to column descriptions/comments in the schema — they reveal business semantics \
         that are critical for correct queries.\n\
         \n\
         TYPE RULES:\n\
         - timestamp - timestamp = interval. For seconds: EXTRACT(EPOCH FROM (ts1 - ts2)).\n\
         - interval cannot be cast to numeric directly; use EXTRACT(EPOCH FROM interval).\n\
         - UNION/UNION ALL requires matching column count and compatible types; cast enums to text.\n\
         - Use ::type for PostgreSQL-style casts.\n\
         - For array columns use ANY, ALL, @>, <@ operators.\n\
         - For JSONB columns use ->, ->>, #>, jsonb_extract_path.\n\
         \n\
         COMMON PATTERNS:\n\
         - FIRST/LAST per group: to find MIN(started_at) per trip, use \
         \"SELECT trip_id, MIN(started_at) FROM t GROUP BY trip_id\". \
         NEVER put the aggregated column (started_at) into GROUP BY — that defeats the aggregation \
         and returns every row separately instead of one per group.\n\
         - TOP-1 per group with extra columns: use DISTINCT ON (group_col) ... ORDER BY group_col, sort_col \
         or a subquery with ROW_NUMBER() OVER (PARTITION BY group_col ORDER BY sort_col) = 1.\n\
         - For \"time from A to B\" calculations, ensure both timestamps are NOT NULL with WHERE filters; \
         never use COALESCE to mix planned and actual timestamps.\n\
         \n\
         BEST PRACTICES:\n\
         - Use ILIKE for case-insensitive text search, LIKE for case-sensitive.\n\
         - Use EXISTS instead of IN for subquery existence checks.\n\
         - Use CTE (WITH ... AS) for complex multi-step logic.\n\
         - Use window functions (ROW_NUMBER, RANK, LAG, LEAD, SUM OVER) for ranking and running totals.\n\
         - Use date_trunc('period', column) for time-based grouping.\n\
         - Use generate_series() for creating ranges.\n\
         - Use string_agg(col, ', ') for concatenating grouped values.\n\
         - Use FILTER (WHERE ...) for conditional aggregation instead of CASE inside aggregate.\n\
         \n\
         {}\n",
        schema_text
    );

    let raw = call_ollama_chat(&app, &state, system_prompt, prompt).await?;
    Ok(clean_sql_response(&raw))
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// SQL explanation
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Tauri command: produces a structured, human-readable explanation of a SQL
/// query via the configured AI model, grounded in the connection's schema
/// context so table relationships and column semantics are understood.
#[tauri::command]
pub async fn explain_sql(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
    sql: String,
) -> TuskResult<String> {
    let schema_text = build_schema_context(&state, &connection_id).await?;

    let system_prompt = format!(
        "You are a PostgreSQL expert. Explain the given SQL query clearly and concisely.\n\
         \n\
         Structure your explanation as:\n\
         1. **Summary** — one sentence describing what the query returns in business terms.\n\
         2. **Step-by-step breakdown** — explain tables accessed, joins, filters, aggregations, \
         subqueries, and sorting. Use bullet points.\n\
         3. **Notes** — mention edge cases, potential issues, or performance concerns if any.\n\
         \n\
         Use the database schema below to understand table relationships and column meanings.\n\
         Keep the explanation short; avoid restating the SQL verbatim.\n\
         \n\
         IMPORTANT: If you notice semantic issues (e.g. mixing planned_* and actual_* timestamps \
         with COALESCE, comparing unrelated columns, missing NULL filters on nullable timestamps), \
         mention them in the Notes section as potential problems.\n\
         \n\
         {}\n",
        schema_text
    );

    call_ollama_chat(&app, &state, system_prompt, sql).await
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// SQL error fixing
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Tauri command: asks the configured AI model to repair a SQL query given
/// the error message it produced. The model output is cleaned via
/// `clean_sql_response` (defined elsewhere in this module) before returning.
#[tauri::command]
pub async fn fix_sql_error(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
    sql: String,
    error_message: String,
) -> TuskResult<String> {
    let schema_text = build_schema_context(&state, &connection_id).await?;

    let system_prompt = format!(
        "You are a PostgreSQL expert debugger. You receive a SQL query and the error it produced. \
         Fix the query so it executes correctly.\n\
         \n\
         OUTPUT FORMAT:\n\
         - Raw SQL only. No explanations, no markdown code fences (```), no comments.\n\
         - The output must be directly executable.\n\
         \n\
         DIAGNOSTIC CHECKLIST:\n\
         - Column/table does not exist → check the schema for correct names and spelling.\n\
         - Column is ambiguous → qualify with table name or alias.\n\
         - Must appear in GROUP BY → add missing non-aggregated columns to GROUP BY.\n\
         - Type mismatch → add appropriate casts (::text, ::integer, etc.).\n\
         - Permission denied → wrap in a read-only transaction if needed.\n\
         - Syntax error → correct PostgreSQL syntax (check commas, parentheses, keywords).\n\
         - Subquery returns more than one row → use IN, ANY, or add LIMIT 1.\n\
         - Division by zero → wrap divisor with NULLIF(x, 0).\n\
         \n\
         ONLY use tables and columns from the schema below. Never invent names.\n\
         Preserve the original intent of the query; change only what is necessary to fix the error.\n\
         \n\
         {}\n",
        schema_text
    );

    let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);

    let raw = call_ollama_chat(&app, &state, system_prompt, user_content).await?;
    Ok(clean_sql_response(&raw))
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Schema context builder
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Builds the plain-text schema description fed to the AI model: PostgreSQL
/// version, enum types, every table with its columns (PK/NOT NULL/FK/default
/// annotations, comments, inline enum / pseudo-enum values, JSONB key hints),
/// unique constraints, and a foreign-key summary.
///
/// Results are cached per connection via `AppState::{get,set}_schema_cache`;
/// the expensive metadata queries run concurrently with `tokio::join!`.
pub(crate) async fn build_schema_context(
    state: &AppState,
    connection_id: &str,
) -> TuskResult<String> {
    // Check cache first
    if let Some(cached) = state.get_schema_cache(connection_id).await {
        return Ok(cached);
    }

    let pool = state.get_pool(connection_id).await?;

    // Run all metadata queries in parallel for speed
    let (
        version_res,
        col_res,
        fk_res,
        enum_res,
        tbl_comment_res,
        col_comment_res,
        unique_res,
        varchar_res,
        jsonb_res,
    ) = tokio::join!(
        sqlx::query_scalar::<_, String>("SELECT version()").fetch_one(&pool),
        fetch_columns(&pool),
        fetch_foreign_keys_raw(&pool),
        fetch_enum_types(&pool),
        fetch_table_comments(&pool),
        fetch_column_comments(&pool),
        fetch_unique_constraints(&pool),
        fetch_varchar_values(&pool),
        fetch_jsonb_keys(&pool),
    );

    let version = version_res.map_err(TuskError::Database)?;
    let col_rows = col_res?;
    let fk_rows = fk_res?;
    let enum_map = enum_res?;
    let tbl_comments = tbl_comment_res?;
    let col_comments = col_comment_res?;
    let unique_constraints = unique_res?;
    // varchar values and JSONB keys are best-effort enrichments: their
    // fetchers return Option, and a failure degrades to an empty map.
    let varchar_values = varchar_res.unwrap_or_default();
    let jsonb_keys = jsonb_res.unwrap_or_default();

    // -- Build FK inline lookup: (schema, table, column) -> "ref_schema.ref_table(ref_col)" --
    let mut fk_inline: HashMap<(String, String, String), String> = HashMap::new();
    let mut fk_lines: Vec<String> = Vec::new();
    for fk in &fk_rows {
        let line = format!(
            "FK: {}.{}({}) -> {}.{}({})",
            fk.schema,
            fk.table,
            fk.columns.join(", "),
            fk.ref_schema,
            fk.ref_table,
            fk.ref_columns.join(", ")
        );
        fk_lines.push(line);

        // For single-column FKs, enable inline annotation on column
        if fk.columns.len() == 1 && fk.ref_columns.len() == 1 {
            fk_inline.insert(
                (fk.schema.clone(), fk.table.clone(), fk.columns[0].clone()),
                format!("{}.{}({})", fk.ref_schema, fk.ref_table, fk.ref_columns[0]),
            );
        }
    }

    // -- Build unique constraint lookup: (schema, table) -> Vec<column_list_string> --
    let mut unique_map: HashMap<(String, String), Vec<String>> = HashMap::new();
    for (schema, table, cols) in &unique_constraints {
        unique_map
            .entry((schema.clone(), table.clone()))
            .or_default()
            .push(cols.join(", "));
    }

    // -- Format output --
    let mut output: Vec<String> = Vec::new();

    // 1. PostgreSQL version (short form)
    let short_version = version
        .split_whitespace()
        .take(2)
        .collect::<Vec<_>>()
        .join(" ");
    output.push(format!("DATABASE SCHEMA ({})", short_version));
    output.push(String::new());

    // 2. Enum types
    if !enum_map.is_empty() {
        output.push("ENUM TYPES:".to_string());
        for (type_name, values) in &enum_map {
            let vals_str = values
                .iter()
                .map(|v| format!("'{}'", v))
                .collect::<Vec<_>>()
                .join(", ");
            output.push(format!(" {} = [{}]", type_name, vals_str));
        }
        output.push(String::new());
    }

    // 3. Tables with columns
    output.push("TABLES:".to_string());

    // Group columns by schema.table preserving order
    let mut tables: BTreeMap<String, Vec<ColumnInfo>> = BTreeMap::new();
    for ci in &col_rows {
        let key = format!("{}.{}", ci.schema, ci.table);
        tables.entry(key).or_default().push(ci.clone());
    }

    for (full_name, columns) in &tables {
        // Table header with optional comment
        let tbl_comment = tbl_comments.get(full_name).map(|c| c.as_str());
        match tbl_comment {
            Some(comment) => output.push(format!("\nTABLE {} -- {}", full_name, comment)),
            None => output.push(format!("\nTABLE {}", full_name)),
        }

        // Columns
        for ci in columns {
            let mut parts: Vec<String> = vec![ci.column.clone(), ci.data_type.clone()];

            if ci.is_pk {
                parts.push("PK".to_string());
            }
            // PK already implies NOT NULL, so avoid redundant annotation.
            if ci.not_null && !ci.is_pk {
                parts.push("NOT NULL".to_string());
            }

            // Inline FK reference
            if let Some(ref_target) =
                fk_inline.get(&(ci.schema.clone(), ci.table.clone(), ci.column.clone()))
            {
                parts.push(format!("FK->{}", ref_target));
            }

            // Default value (simplified)
            if let Some(ref def) = ci.column_default {
                let simplified = simplify_default(def);
                if !simplified.is_empty() {
                    parts.push(format!("DEFAULT {}", simplified));
                }
            }

            // Column comment
            let col_key = (ci.schema.clone(), ci.table.clone(), ci.column.clone());
            let col_comment = col_comments.get(&col_key);

            // Inline enum values for enum-typed columns
            let enum_annotation = enum_map.get(&ci.data_type);

            let mut suffix_parts: Vec<String> = Vec::new();
            if let Some(vals) = enum_annotation {
                let vals_str = vals
                    .iter()
                    .map(|v| format!("'{}'", v))
                    .collect::<Vec<_>>()
                    .join(", ");
                suffix_parts.push(format!("enum: {}", vals_str));
            }

            // Inline varchar distinct values (pseudo-enums from pg_stats)
            if enum_annotation.is_none() {
                if let Some(vals) = varchar_values.get(&col_key) {
                    let vals_str = vals
                        .iter()
                        .take(15)
                        .map(|v| format!("'{}'", v))
                        .collect::<Vec<_>>()
                        .join(", ");
                    suffix_parts.push(format!("values: {}", vals_str));
                }
            }

            // Inline JSONB key structure
            if let Some(keys) = jsonb_keys.get(&col_key) {
                suffix_parts.push(format!("json keys: {}", keys.join(", ")));
            }

            if let Some(comment) = col_comment {
                suffix_parts.push(comment.clone());
            }

            if suffix_parts.is_empty() {
                output.push(format!(" {}", parts.join(" ")));
            } else {
                output.push(format!(
                    " {} -- {}",
                    parts.join(" "),
                    suffix_parts.join("; ")
                ));
            }
        }

        // Unique constraints for this table
        let schema_table: Vec<&str> = full_name.splitn(2, '.').collect();
        if schema_table.len() == 2 {
            if let Some(uqs) =
                unique_map.get(&(schema_table[0].to_string(), schema_table[1].to_string()))
            {
                for uq_cols in uqs {
                    output.push(format!(" UNIQUE({})", uq_cols));
                }
            }
        }
    }

    // 4. Foreign keys summary
    if !fk_lines.is_empty() {
        output.push(String::new());
        output.push("FOREIGN KEYS:".to_string());
        for fk in &fk_lines {
            output.push(format!(" {}", fk));
        }
    }

    let result = output.join("\n");

    // Cache the result
    state
        .set_schema_cache(connection_id.to_string(), result.clone())
        .await;

    Ok(result)
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Schema query helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// One column of a user table, as reported by `information_schema.columns`.
#[derive(Clone)]
struct ColumnInfo {
    schema: String,
    table: String,
    column: String,
    /// Resolved type name: `udt_name` for user-defined types, `udt[]` for
    /// arrays, otherwise the plain `data_type` (see `fetch_columns`).
    data_type: String,
    not_null: bool,
    is_pk: bool,
    column_default: Option<String>,
}
|
|
|
|
async fn fetch_columns(pool: &sqlx::PgPool) -> TuskResult<Vec<ColumnInfo>> {
|
|
let rows = sqlx::query(
|
|
"SELECT \
|
|
c.table_schema, c.table_name, c.column_name, \
|
|
CASE \
|
|
WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name \
|
|
WHEN c.data_type = 'ARRAY' THEN c.udt_name || '[]' \
|
|
ELSE c.data_type \
|
|
END AS data_type, \
|
|
c.is_nullable = 'NO' AS not_null, \
|
|
c.column_default, \
|
|
EXISTS( \
|
|
SELECT 1 FROM information_schema.table_constraints tc \
|
|
JOIN information_schema.key_column_usage kcu \
|
|
ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema \
|
|
WHERE tc.constraint_type = 'PRIMARY KEY' \
|
|
AND tc.table_schema = c.table_schema \
|
|
AND tc.table_name = c.table_name \
|
|
AND kcu.column_name = c.column_name \
|
|
) AS is_pk \
|
|
FROM information_schema.columns c \
|
|
WHERE c.table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
|
|
ORDER BY c.table_schema, c.table_name, c.ordinal_position",
|
|
)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
Ok(rows
|
|
.iter()
|
|
.map(|r| ColumnInfo {
|
|
schema: r.get(0),
|
|
table: r.get(1),
|
|
column: r.get(2),
|
|
data_type: r.get(3),
|
|
not_null: r.get(4),
|
|
column_default: r.get(5),
|
|
is_pk: r.get(6),
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
/// A foreign-key constraint: `schema.table(columns)` references
/// `ref_schema.ref_table(ref_columns)`; `columns[i]` pairs with
/// `ref_columns[i]` for composite keys.
pub(crate) struct ForeignKeyInfo {
    pub(crate) schema: String,
    pub(crate) table: String,
    pub(crate) columns: Vec<String>,
    pub(crate) ref_schema: String,
    pub(crate) ref_table: String,
    pub(crate) ref_columns: Vec<String>,
}
|
|
|
|
pub(crate) async fn fetch_foreign_keys_raw(pool: &sqlx::PgPool) -> TuskResult<Vec<ForeignKeyInfo>> {
|
|
let rows = sqlx::query(
|
|
"SELECT \
|
|
cn.nspname AS schema_name, cl.relname AS table_name, \
|
|
array_agg(DISTINCT a.attname ORDER BY a.attname) AS columns, \
|
|
cnf.nspname AS ref_schema, clf.relname AS ref_table, \
|
|
array_agg(DISTINCT af.attname ORDER BY af.attname) AS ref_columns \
|
|
FROM pg_constraint con \
|
|
JOIN pg_class cl ON con.conrelid = cl.oid \
|
|
JOIN pg_namespace cn ON cl.relnamespace = cn.oid \
|
|
JOIN pg_class clf ON con.confrelid = clf.oid \
|
|
JOIN pg_namespace cnf ON clf.relnamespace = cnf.oid \
|
|
JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \
|
|
JOIN pg_attribute af ON af.attrelid = con.confrelid AND af.attnum = ANY(con.confkey) \
|
|
WHERE con.contype = 'f' \
|
|
AND cn.nspname NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \
|
|
GROUP BY cn.nspname, cl.relname, cnf.nspname, clf.relname, con.oid",
|
|
)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
Ok(rows
|
|
.iter()
|
|
.map(|r| ForeignKeyInfo {
|
|
schema: r.get(0),
|
|
table: r.get(1),
|
|
columns: r.get(2),
|
|
ref_schema: r.get(3),
|
|
ref_table: r.get(4),
|
|
ref_columns: r.get(5),
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
/// Returns BTreeMap<enum_type_name, Vec<enum_values>> ordered by type name
|
|
async fn fetch_enum_types(pool: &sqlx::PgPool) -> TuskResult<BTreeMap<String, Vec<String>>> {
|
|
let rows = sqlx::query(
|
|
"SELECT t.typname, \
|
|
array_agg(e.enumlabel ORDER BY e.enumsortorder) AS vals \
|
|
FROM pg_enum e \
|
|
JOIN pg_type t ON e.enumtypid = t.oid \
|
|
JOIN pg_namespace n ON t.typnamespace = n.oid \
|
|
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema') \
|
|
GROUP BY t.typname \
|
|
ORDER BY t.typname",
|
|
)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
let mut map = BTreeMap::new();
|
|
for r in &rows {
|
|
let name: String = r.get(0);
|
|
let vals: Vec<String> = r.get(1);
|
|
map.insert(name, vals);
|
|
}
|
|
Ok(map)
|
|
}
|
|
|
|
/// Returns HashMap<"schema.table", comment>
|
|
async fn fetch_table_comments(pool: &sqlx::PgPool) -> TuskResult<HashMap<String, String>> {
|
|
let rows = sqlx::query(
|
|
"SELECT n.nspname, c.relname, obj_description(c.oid, 'pg_class') \
|
|
FROM pg_class c \
|
|
JOIN pg_namespace n ON c.relnamespace = n.oid \
|
|
WHERE c.relkind IN ('r', 'v', 'p', 'm') \
|
|
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
|
|
AND obj_description(c.oid, 'pg_class') IS NOT NULL",
|
|
)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
let mut map = HashMap::new();
|
|
for r in &rows {
|
|
let schema: String = r.get(0);
|
|
let table: String = r.get(1);
|
|
let comment: String = r.get(2);
|
|
map.insert(format!("{}.{}", schema, table), comment);
|
|
}
|
|
Ok(map)
|
|
}
|
|
|
|
/// Returns HashMap<(schema, table, column), comment>
|
|
async fn fetch_column_comments(
|
|
pool: &sqlx::PgPool,
|
|
) -> TuskResult<HashMap<(String, String, String), String>> {
|
|
let rows = sqlx::query(
|
|
"SELECT n.nspname, c.relname, a.attname, d.description \
|
|
FROM pg_description d \
|
|
JOIN pg_class c ON d.objoid = c.oid \
|
|
JOIN pg_namespace n ON c.relnamespace = n.oid \
|
|
JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = d.objsubid \
|
|
WHERE d.objsubid > 0 \
|
|
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit')",
|
|
)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
let mut map = HashMap::new();
|
|
for r in &rows {
|
|
let schema: String = r.get(0);
|
|
let table: String = r.get(1);
|
|
let column: String = r.get(2);
|
|
let comment: String = r.get(3);
|
|
map.insert((schema, table, column), comment);
|
|
}
|
|
Ok(map)
|
|
}
|
|
|
|
/// Returns Vec<(schema, table, Vec<column_names>)> for UNIQUE constraints
|
|
async fn fetch_unique_constraints(
|
|
pool: &sqlx::PgPool,
|
|
) -> TuskResult<Vec<(String, String, Vec<String>)>> {
|
|
let rows = sqlx::query(
|
|
"SELECT n.nspname, cl.relname, \
|
|
array_agg(a.attname ORDER BY a.attnum) AS cols \
|
|
FROM pg_constraint con \
|
|
JOIN pg_class cl ON con.conrelid = cl.oid \
|
|
JOIN pg_namespace n ON cl.relnamespace = n.oid \
|
|
JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \
|
|
WHERE con.contype = 'u' \
|
|
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
|
|
GROUP BY n.nspname, cl.relname, con.oid \
|
|
ORDER BY n.nspname, cl.relname",
|
|
)
|
|
.fetch_all(pool)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
Ok(rows
|
|
.iter()
|
|
.map(|r| {
|
|
let schema: String = r.get(0);
|
|
let table: String = r.get(1);
|
|
let cols: Vec<String> = r.get(2);
|
|
(schema, table, cols)
|
|
})
|
|
.collect())
|
|
}
|
|
|
|
/// Returns HashMap<(schema, table, column), Vec<distinct_values>> for varchar columns
/// with few distinct values (pseudo-enums), using pg_stats for zero-cost discovery.
/// Returns None if pg_stats is not accessible (graceful degradation).
async fn fetch_varchar_values(
    pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> {
    // pg_stats only shows rows for tables the current role can read, and may
    // be empty until ANALYZE has run — both cases simply yield fewer entries.
    let rows = match sqlx::query(
        "SELECT s.schemaname, s.tablename, s.attname, \
         s.most_common_vals::text AS vals \
         FROM pg_stats s \
         JOIN information_schema.columns c \
         ON c.table_schema = s.schemaname \
         AND c.table_name = s.tablename \
         AND c.column_name = s.attname \
         WHERE s.schemaname NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
         AND c.data_type = 'character varying' \
         AND s.n_distinct > 0 AND s.n_distinct <= 20 \
         AND s.most_common_vals IS NOT NULL \
         ORDER BY s.schemaname, s.tablename, s.attname",
    )
    .fetch_all(pool)
    .await
    {
        Ok(r) => r,
        Err(e) => {
            log::warn!("Failed to fetch varchar values from pg_stats: {}", e);
            return None;
        }
    };

    let mut map = HashMap::new();
    for r in &rows {
        let schema: String = r.get(0);
        let table: String = r.get(1);
        let column: String = r.get(2);
        // most_common_vals is an anyarray; cast to text in SQL and parse its
        // `{a,b,"c,d"}` representation here.
        let vals_text: String = r.get(3);
        let vals = parse_pg_array_text(&vals_text);
        if !vals.is_empty() {
            map.insert((schema, table, column), vals);
        }
    }
    Some(map)
}
|
|
|
|
/// Discovers top-level keys in JSONB columns by sampling actual data.
/// Runs two sequential queries internally: first discovers JSONB columns,
/// then samples keys from each via a single UNION ALL query.
/// Returns None on error (graceful degradation).
async fn fetch_jsonb_keys(
    pool: &sqlx::PgPool,
) -> Option<HashMap<(String, String, String), Vec<String>>> {
    // Step 1: Find all JSONB columns
    let col_rows = match sqlx::query(
        "SELECT table_schema, table_name, column_name \
         FROM information_schema.columns \
         WHERE data_type = 'jsonb' \
         AND table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \
         ORDER BY table_schema, table_name, column_name",
    )
    .fetch_all(pool)
    .await
    {
        Ok(r) => r,
        Err(e) => {
            log::warn!("Failed to fetch JSONB columns: {}", e);
            return None;
        }
    };

    if col_rows.is_empty() {
        return Some(HashMap::new());
    }

    // Cap at 50 JSONB columns to prevent unbounded UNION ALL queries on large schemas
    let columns: Vec<(String, String, String)> = col_rows
        .iter()
        .take(50)
        .map(|r| {
            (
                r.get::<String, _>(0),
                r.get::<String, _>(1),
                r.get::<String, _>(2),
            )
        })
        .collect();

    // Step 2: Build a single UNION ALL query to sample keys from all JSONB columns.
    // Identifiers are double-quote-escaped for the FROM/WHERE positions.
    // NOTE(review): identifiers containing a single quote would break the
    // col_ref string literal, and a '.' in any name would confuse the
    // splitn(3, '.') parsing below — assumes sanely named schemas/tables/
    // columns; confirm or switch to three separate SELECT columns.
    let parts: Vec<String> = columns
        .iter()
        .enumerate()
        .map(|(i, (schema, table, col))| {
            let qs = schema.replace('"', "\"\"");
            let qt = table.replace('"', "\"\"");
            let qc = col.replace('"', "\"\"");
            format!(
                "(SELECT '{}.{}.{}' AS col_ref, key FROM (\
                 SELECT DISTINCT jsonb_object_keys(\"{}\") AS key \
                 FROM \"{}\".\"{}\" \
                 WHERE \"{}\" IS NOT NULL AND jsonb_typeof(\"{}\") = 'object' \
                 LIMIT 50\
                 ) sub{})",
                schema, table, col, qc, qs, qt, qc, qc, i
            )
        })
        .collect();

    let query = parts.join(" UNION ALL ");

    let rows = match sqlx::query(&query).fetch_all(pool).await {
        Ok(r) => r,
        Err(e) => {
            log::warn!("Failed to fetch JSONB keys: {}", e);
            return None;
        }
    };

    // Regroup sampled keys under their (schema, table, column) triple,
    // de-duplicating keys seen in multiple sampled rows.
    let mut map: HashMap<(String, String, String), Vec<String>> = HashMap::new();
    for r in &rows {
        let col_ref: String = r.get(0);
        let key: String = r.get(1);
        let ref_parts: Vec<&str> = col_ref.splitn(3, '.').collect();
        if ref_parts.len() == 3 {
            let entry = map
                .entry((
                    ref_parts[0].to_string(),
                    ref_parts[1].to_string(),
                    ref_parts[2].to_string(),
                ))
                .or_default();
            if !entry.contains(&key) {
                entry.push(key);
            }
        }
    }

    // Stable, readable ordering for the schema-context output.
    for vals in map.values_mut() {
        vals.sort();
    }

    Some(map)
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Helpers
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Parses PostgreSQL text representation of arrays: {val1,val2,"val with comma"}
///
/// Handles PostgreSQL's actual output escaping inside quoted elements
/// (backslash-escaped `\"` and `\\`), the doubled-quote form (`""`) for
/// robustness against hand-written input, and quoted empty elements
/// (`{""}` yields one empty string). Unquoted elements are trimmed;
/// quoted elements are kept verbatim, matching PG semantics where quoted
/// content is literal.
fn parse_pg_array_text(s: &str) -> Vec<String> {
    let s = s.trim();
    let s = s.strip_prefix('{').unwrap_or(s);
    let s = s.strip_suffix('}').unwrap_or(s);
    if s.is_empty() {
        return Vec::new();
    }

    let mut result = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    // Tracks whether the current element was quoted: quoted content is
    // literal and must not be trimmed; it also makes `{""}` a real element.
    let mut was_quoted = false;
    let mut chars = s.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            // PG escapes embedded quotes/backslashes with a backslash
            '\\' if in_quotes => {
                if let Some(&next) = chars.peek() {
                    current.push(next);
                    chars.next();
                }
            }
            '"' if !in_quotes => {
                in_quotes = true;
                was_quoted = true;
            }
            '"' if in_quotes => {
                // Doubled quote inside quotes -> literal quote
                if chars.peek() == Some(&'"') {
                    current.push('"');
                    chars.next();
                } else {
                    in_quotes = false;
                }
            }
            ',' if !in_quotes => {
                let elem = std::mem::take(&mut current);
                result.push(if was_quoted {
                    elem
                } else {
                    elem.trim().to_string()
                });
                was_quoted = false;
            }
            _ => {
                current.push(ch);
            }
        }
    }
    // Flush the trailing element; a quoted empty string still counts
    if was_quoted || !current.is_empty() || !result.is_empty() {
        result.push(if was_quoted {
            current
        } else {
            current.trim().to_string()
        });
    }
    result
}
|
|
|
|
/// Condenses a raw column DEFAULT expression into a short display string.
///
/// Sequence-backed defaults become "auto-increment", the common timestamp
/// defaults collapse to "now()", booleans pass through unchanged, and any
/// default longer than 50 characters is dropped (empty string returned).
fn simplify_default(raw: &str) -> String {
    let trimmed = raw.trim();
    // nextval(...) anywhere in the expression marks a sequence default
    if trimmed.contains("nextval(") {
        return "auto-increment".to_string();
    }
    match trimmed {
        "now()" | "CURRENT_TIMESTAMP" | "current_timestamp" => "now()".to_string(),
        "true" | "false" => trimmed.to_string(),
        // Very long generated defaults add noise — omit them entirely
        _ if trimmed.len() > 50 => String::new(),
        // Short numeric/string literals are kept verbatim
        _ => trimmed.to_string(),
    }
}
|
|
|
|
fn validate_select_statement(sql: &str) -> TuskResult<()> {
|
|
let sql_upper = sql.trim().to_uppercase();
|
|
if !sql_upper.starts_with("SELECT") {
|
|
return Err(TuskError::Custom(
|
|
"Validation query must be a SELECT statement".to_string(),
|
|
));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn validate_index_ddl(ddl: &str) -> TuskResult<()> {
|
|
let ddl_upper = ddl.trim().to_uppercase();
|
|
if !ddl_upper.starts_with("CREATE INDEX") && !ddl_upper.starts_with("DROP INDEX") {
|
|
return Err(TuskError::Custom(
|
|
"Only CREATE INDEX and DROP INDEX statements are allowed".to_string(),
|
|
));
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Strips markdown code fences (and whitespace) from an AI SQL reply.
///
/// Handles the common explicit tags (```sql, ```SQL, ```postgresql) on the
/// same line as content, and additionally drops any other alphanumeric
/// language tag occupying the fence line (e.g. ```Sql, ```plpgsql) — the
/// previous version leaked such tags into the returned SQL.
fn clean_sql_response(raw: &str) -> String {
    let trimmed = raw.trim();
    let without_fences = if let Some(rest) = trimmed
        .strip_prefix("```sql")
        .or_else(|| trimmed.strip_prefix("```SQL"))
        .or_else(|| trimmed.strip_prefix("```postgresql"))
    {
        rest.strip_suffix("```").unwrap_or(rest)
    } else if let Some(rest) = trimmed.strip_prefix("```") {
        // Generic fence: drop an optional language tag that fills the fence line
        let rest = match rest.split_once('\n') {
            Some((tag, body)) if tag.trim().chars().all(|c| c.is_ascii_alphanumeric()) => body,
            _ => rest,
        };
        rest.strip_suffix("```").unwrap_or(rest)
    } else {
        trimmed
    };
    without_fences.trim().to_string()
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Wave 1: AI Data Assertions (Validation)
|
|
// ---------------------------------------------------------------------------
|
|
|
|
/// Generates a validation SELECT statement from a natural-language rule.
///
/// Builds a schema-aware system prompt, asks the configured Ollama model for
/// SQL, and strips any markdown fences from the reply. The returned SQL is
/// NOT executed here; `run_validation_rule` validates and runs it read-only.
///
/// # Errors
/// Fails when the schema context cannot be built or the AI call fails.
#[tauri::command]
pub async fn generate_validation_sql(
    app: AppHandle,
    state: State<'_, Arc<AppState>>,
    connection_id: String,
    rule_description: String,
) -> TuskResult<String> {
    // Schema context keeps the model grounded in real tables/columns
    let schema_text = build_schema_context(&state, &connection_id).await?;

    let system_prompt = format!(
        "You are an expert PostgreSQL data quality validator. Given a database schema and a natural \
         language data quality rule, generate a SELECT query that finds ALL rows violating the rule.\n\
         \n\
         OUTPUT FORMAT:\n\
         - Raw SQL only. No explanations, no markdown code fences, no comments.\n\
         - The query MUST be a SELECT statement.\n\
         - Return violating rows with enough context columns to identify them.\n\
         \n\
         VALIDATION PATTERNS:\n\
         - NULL checks: SELECT * FROM table WHERE required_column IS NULL\n\
         - Format checks: WHERE column !~ 'pattern'\n\
         - Range checks: WHERE column < min OR column > max\n\
         - FK integrity: LEFT JOIN parent ON ... WHERE parent.id IS NULL\n\
         - Uniqueness: GROUP BY ... HAVING COUNT(*) > 1\n\
         - Date consistency: WHERE start_date > end_date\n\
         - Enum validity: WHERE column NOT IN ('val1', 'val2', ...)\n\
         \n\
         ONLY reference tables and columns that exist in the schema.\n\
         \n\
         {}\n",
        schema_text
    );

    // The rule description is the user message; the system prompt carries all constraints
    let raw = call_ollama_chat(&app, &state, system_prompt, rule_description).await?;
    Ok(clean_sql_response(&raw))
}
|
|
|
|
#[tauri::command]
|
|
pub async fn run_validation_rule(
|
|
state: State<'_, Arc<AppState>>,
|
|
connection_id: String,
|
|
sql: String,
|
|
sample_limit: Option<u32>,
|
|
) -> TuskResult<ValidationRule> {
|
|
validate_select_statement(&sql)?;
|
|
|
|
let pool = state.get_pool(&connection_id).await?;
|
|
let limit = sample_limit.unwrap_or(10);
|
|
let _start = Instant::now();
|
|
|
|
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
|
|
sqlx::query("SET TRANSACTION READ ONLY")
|
|
.execute(&mut *tx)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
sqlx::query("SET statement_timeout = '30s'")
|
|
.execute(&mut *tx)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
// Count violations
|
|
let count_sql = format!("SELECT COUNT(*) FROM ({}) AS _v", sql);
|
|
let count_row = sqlx::query(&count_sql)
|
|
.fetch_one(&mut *tx)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
let violation_count: i64 = count_row.get(0);
|
|
|
|
// Sample violations
|
|
let sample_sql = format!("SELECT * FROM ({}) AS _v LIMIT {}", sql, limit);
|
|
let sample_rows = sqlx::query(&sample_sql)
|
|
.fetch_all(&mut *tx)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
tx.rollback().await.map_err(TuskError::Database)?;
|
|
|
|
let mut violation_columns = Vec::new();
|
|
let mut sample_violations = Vec::new();
|
|
|
|
if let Some(first) = sample_rows.first() {
|
|
for col in first.columns() {
|
|
violation_columns.push(col.name().to_string());
|
|
}
|
|
}
|
|
|
|
for row in &sample_rows {
|
|
let vals: Vec<Value> = (0..violation_columns.len())
|
|
.map(|i| pg_value_to_json(row, i))
|
|
.collect();
|
|
sample_violations.push(vals);
|
|
}
|
|
|
|
let status = if violation_count > 0 {
|
|
ValidationStatus::Failed
|
|
} else {
|
|
ValidationStatus::Passed
|
|
};
|
|
|
|
Ok(ValidationRule {
|
|
id: String::new(),
|
|
description: String::new(),
|
|
generated_sql: sql,
|
|
status,
|
|
violation_count: violation_count as u64,
|
|
sample_violations,
|
|
violation_columns,
|
|
error: None,
|
|
})
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub async fn suggest_validation_rules(
|
|
app: AppHandle,
|
|
state: State<'_, Arc<AppState>>,
|
|
connection_id: String,
|
|
) -> TuskResult<Vec<String>> {
|
|
let schema_text = build_schema_context(&state, &connection_id).await?;
|
|
|
|
let system_prompt = format!(
|
|
"You are a data quality expert. Given a database schema, suggest 5-10 data quality \
|
|
validation rules as natural language descriptions.\n\
|
|
\n\
|
|
OUTPUT FORMAT:\n\
|
|
- Return ONLY a JSON array of strings, each string being a validation rule.\n\
|
|
- No markdown, no explanations, no code fences.\n\
|
|
- Example: [\"All users must have a non-empty email address\", \"Order total must be positive\"]\n\
|
|
\n\
|
|
RULE CATEGORIES TO COVER:\n\
|
|
- NOT NULL checks for critical columns\n\
|
|
- Business logic (positive amounts, valid ranges, consistent dates)\n\
|
|
- Referential integrity (orphaned foreign keys)\n\
|
|
- Format validation (emails, phone numbers, codes)\n\
|
|
- Enum/status field validity\n\
|
|
- Date consistency (start before end, not in future where inappropriate)\n\
|
|
\n\
|
|
{}\n",
|
|
schema_text
|
|
);
|
|
|
|
let raw = call_ollama_chat(
|
|
&app,
|
|
&state,
|
|
system_prompt,
|
|
"Suggest validation rules".to_string(),
|
|
)
|
|
.await?;
|
|
|
|
let cleaned = raw.trim();
|
|
let json_start = cleaned.find('[').unwrap_or(0);
|
|
let json_end = cleaned.rfind(']').map(|i| i + 1).unwrap_or(cleaned.len());
|
|
let json_str = &cleaned[json_start..json_end];
|
|
|
|
let rules: Vec<String> = serde_json::from_str(json_str).map_err(|e| {
|
|
TuskError::Ai(format!(
|
|
"Failed to parse AI response as JSON array: {}. Response: {}",
|
|
e, cleaned
|
|
))
|
|
})?;
|
|
|
|
Ok(rules)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Wave 2: AI Data Generator
|
|
// ---------------------------------------------------------------------------
|
|
|
|
#[tauri::command]
|
|
pub async fn generate_test_data_preview(
|
|
app: AppHandle,
|
|
state: State<'_, Arc<AppState>>,
|
|
params: GenerateDataParams,
|
|
gen_id: String,
|
|
) -> TuskResult<GeneratedDataPreview> {
|
|
let pool = state.get_pool(¶ms.connection_id).await?;
|
|
|
|
let _ = app.emit(
|
|
"datagen-progress",
|
|
DataGenProgress {
|
|
gen_id: gen_id.clone(),
|
|
stage: "schema".to_string(),
|
|
percent: 10,
|
|
message: "Building schema context...".to_string(),
|
|
detail: None,
|
|
},
|
|
);
|
|
|
|
let schema_text = build_schema_context(&state, ¶ms.connection_id).await?;
|
|
|
|
// Get FK info for topological sort
|
|
let fk_rows = fetch_foreign_keys_raw(&pool).await?;
|
|
|
|
let mut target_tables = vec![(params.schema.clone(), params.table.clone())];
|
|
|
|
if params.include_related {
|
|
// Add parent tables (tables referenced by FKs from target)
|
|
for fk in &fk_rows {
|
|
if fk.schema == params.schema && fk.table == params.table {
|
|
let parent = (fk.ref_schema.clone(), fk.ref_table.clone());
|
|
if !target_tables.contains(&parent) {
|
|
target_tables.push(parent);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
let fk_edges: Vec<(String, String, String, String)> = fk_rows
|
|
.iter()
|
|
.map(|fk| {
|
|
(
|
|
fk.schema.clone(),
|
|
fk.table.clone(),
|
|
fk.ref_schema.clone(),
|
|
fk.ref_table.clone(),
|
|
)
|
|
})
|
|
.collect();
|
|
|
|
let sorted_tables = topological_sort_tables(&fk_edges, &target_tables);
|
|
let insert_order: Vec<String> = sorted_tables
|
|
.iter()
|
|
.map(|(s, t)| format!("{}.{}", s, t))
|
|
.collect();
|
|
|
|
let row_count = params.row_count.min(1000);
|
|
|
|
let _ = app.emit(
|
|
"datagen-progress",
|
|
DataGenProgress {
|
|
gen_id: gen_id.clone(),
|
|
stage: "generating".to_string(),
|
|
percent: 30,
|
|
message: "AI is generating test data...".to_string(),
|
|
detail: None,
|
|
},
|
|
);
|
|
|
|
let tables_desc: Vec<String> = sorted_tables
|
|
.iter()
|
|
.map(|(s, t)| {
|
|
let count = if s == ¶ms.schema && t == ¶ms.table {
|
|
row_count
|
|
} else {
|
|
(row_count / 3).max(1)
|
|
};
|
|
format!("{}.{}: {} rows", s, t, count)
|
|
})
|
|
.collect();
|
|
|
|
let custom = params
|
|
.custom_instructions
|
|
.as_deref()
|
|
.unwrap_or("Generate realistic sample data");
|
|
|
|
let system_prompt = format!(
|
|
"You are a PostgreSQL test data generator. Generate realistic test data as JSON.\n\
|
|
\n\
|
|
OUTPUT FORMAT:\n\
|
|
- Return ONLY a JSON object where keys are \"schema.table\" and values are arrays of row objects.\n\
|
|
- Each row object has column names as keys and values matching the column types.\n\
|
|
- No markdown, no explanations, no code fences.\n\
|
|
- Example: {{\"public.users\": [{{\"name\": \"Alice\", \"email\": \"alice@example.com\"}}]}}\n\
|
|
\n\
|
|
RULES:\n\
|
|
1. Respect column types exactly (text, integer, boolean, timestamp, uuid, etc.)\n\
|
|
2. Use valid foreign key values - parent tables are generated first, reference their IDs\n\
|
|
3. Respect enum types - use only valid enum values\n\
|
|
4. Omit auto-increment/serial/identity columns (they have DEFAULT auto-increment)\n\
|
|
5. Generate realistic data: real names, valid emails, plausible dates, etc.\n\
|
|
6. Respect NOT NULL constraints\n\
|
|
7. For UUID columns, generate valid UUIDs\n\
|
|
8. For timestamp columns, use ISO 8601 format\n\
|
|
\n\
|
|
Tables to generate (in this exact order):\n\
|
|
{}\n\
|
|
\n\
|
|
Custom instructions: {}\n\
|
|
\n\
|
|
{}\n",
|
|
tables_desc.join("\n"),
|
|
custom,
|
|
schema_text
|
|
);
|
|
|
|
let raw = call_ollama_chat(
|
|
&app,
|
|
&state,
|
|
system_prompt,
|
|
format!("Generate test data for {} tables", sorted_tables.len()),
|
|
)
|
|
.await?;
|
|
|
|
let _ = app.emit(
|
|
"datagen-progress",
|
|
DataGenProgress {
|
|
gen_id: gen_id.clone(),
|
|
stage: "parsing".to_string(),
|
|
percent: 80,
|
|
message: "Parsing generated data...".to_string(),
|
|
detail: None,
|
|
},
|
|
);
|
|
|
|
// Parse JSON response
|
|
let cleaned = raw.trim();
|
|
let json_start = cleaned.find('{').unwrap_or(0);
|
|
let json_end = cleaned.rfind('}').map(|i| i + 1).unwrap_or(cleaned.len());
|
|
let json_str = &cleaned[json_start..json_end];
|
|
|
|
let data_map: HashMap<String, Vec<HashMap<String, Value>>> = serde_json::from_str(json_str)
|
|
.map_err(|e| {
|
|
TuskError::Ai(format!(
|
|
"Failed to parse generated data: {}. Response: {}",
|
|
e,
|
|
&cleaned[..cleaned.len().min(500)]
|
|
))
|
|
})?;
|
|
|
|
let mut tables = Vec::new();
|
|
let mut total_rows: u32 = 0;
|
|
|
|
for (schema, table) in &sorted_tables {
|
|
let key = format!("{}.{}", schema, table);
|
|
if let Some(rows_data) = data_map.get(&key) {
|
|
let columns: Vec<String> = if let Some(first) = rows_data.first() {
|
|
first.keys().cloned().collect()
|
|
} else {
|
|
Vec::new()
|
|
};
|
|
|
|
let rows: Vec<Vec<Value>> = rows_data
|
|
.iter()
|
|
.map(|row_map| {
|
|
columns
|
|
.iter()
|
|
.map(|col| row_map.get(col).cloned().unwrap_or(Value::Null))
|
|
.collect()
|
|
})
|
|
.collect();
|
|
|
|
let count = rows.len() as u32;
|
|
total_rows += count;
|
|
|
|
tables.push(GeneratedTableData {
|
|
schema: schema.clone(),
|
|
table: table.clone(),
|
|
columns,
|
|
rows,
|
|
row_count: count,
|
|
});
|
|
}
|
|
}
|
|
|
|
let _ = app.emit(
|
|
"datagen-progress",
|
|
DataGenProgress {
|
|
gen_id: gen_id.clone(),
|
|
stage: "done".to_string(),
|
|
percent: 100,
|
|
message: "Data generation complete".to_string(),
|
|
detail: Some(format!(
|
|
"{} rows across {} tables",
|
|
total_rows,
|
|
tables.len()
|
|
)),
|
|
},
|
|
);
|
|
|
|
Ok(GeneratedDataPreview {
|
|
tables,
|
|
insert_order,
|
|
total_rows,
|
|
})
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub async fn insert_generated_data(
|
|
state: State<'_, Arc<AppState>>,
|
|
connection_id: String,
|
|
preview: GeneratedDataPreview,
|
|
) -> TuskResult<u64> {
|
|
if state.is_read_only(&connection_id).await {
|
|
return Err(TuskError::ReadOnly);
|
|
}
|
|
|
|
let pool = state.get_pool(&connection_id).await?;
|
|
let mut tx = pool.begin().await.map_err(TuskError::Database)?;
|
|
|
|
// Defer constraints for circular FKs
|
|
sqlx::query("SET CONSTRAINTS ALL DEFERRED")
|
|
.execute(&mut *tx)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
let mut total_inserted: u64 = 0;
|
|
|
|
for table_data in &preview.tables {
|
|
if table_data.columns.is_empty() || table_data.rows.is_empty() {
|
|
continue;
|
|
}
|
|
|
|
let qualified = format!(
|
|
"{}.{}",
|
|
escape_ident(&table_data.schema),
|
|
escape_ident(&table_data.table)
|
|
);
|
|
let col_list: Vec<String> = table_data.columns.iter().map(|c| escape_ident(c)).collect();
|
|
let placeholders: Vec<String> = (1..=table_data.columns.len())
|
|
.map(|i| format!("${}", i))
|
|
.collect();
|
|
|
|
let sql = format!(
|
|
"INSERT INTO {} ({}) VALUES ({})",
|
|
qualified,
|
|
col_list.join(", "),
|
|
placeholders.join(", ")
|
|
);
|
|
|
|
for row in &table_data.rows {
|
|
let mut query = sqlx::query(&sql);
|
|
for val in row {
|
|
query = bind_json_value(query, val);
|
|
}
|
|
query.execute(&mut *tx).await.map_err(TuskError::Database)?;
|
|
total_inserted += 1;
|
|
}
|
|
}
|
|
|
|
tx.commit().await.map_err(TuskError::Database)?;
|
|
|
|
// Invalidate schema cache since data changed
|
|
state.invalidate_schema_cache(&connection_id).await;
|
|
|
|
Ok(total_inserted)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Wave 3A: Smart Index Advisor
|
|
// ---------------------------------------------------------------------------
|
|
|
|
#[tauri::command]
|
|
pub async fn get_index_advisor_report(
|
|
app: AppHandle,
|
|
state: State<'_, Arc<AppState>>,
|
|
connection_id: String,
|
|
) -> TuskResult<IndexAdvisorReport> {
|
|
let pool = state.get_pool(&connection_id).await?;
|
|
|
|
// Fetch table stats, index stats, and slow queries concurrently
|
|
let (table_stats_res, index_stats_res, slow_queries_res) = tokio::join!(
|
|
sqlx::query(
|
|
"SELECT schemaname, relname, seq_scan, idx_scan, n_live_tup, \
|
|
pg_size_pretty(pg_total_relation_size(schemaname || '.' || relname)) AS table_size, \
|
|
pg_size_pretty(pg_indexes_size(quote_ident(schemaname) || '.' || quote_ident(relname))) AS index_size \
|
|
FROM pg_stat_user_tables \
|
|
ORDER BY seq_scan DESC \
|
|
LIMIT 50"
|
|
)
|
|
.fetch_all(&pool),
|
|
sqlx::query(
|
|
"SELECT schemaname, relname, indexrelname, idx_scan, \
|
|
pg_size_pretty(pg_relation_size(indexrelid)) AS index_size, \
|
|
pg_get_indexdef(indexrelid) AS definition \
|
|
FROM pg_stat_user_indexes \
|
|
ORDER BY idx_scan ASC \
|
|
LIMIT 50",
|
|
)
|
|
.fetch_all(&pool),
|
|
sqlx::query(
|
|
"SELECT query, calls, total_exec_time, mean_exec_time, rows \
|
|
FROM pg_stat_statements \
|
|
WHERE calls > 0 \
|
|
ORDER BY mean_exec_time DESC \
|
|
LIMIT 20",
|
|
)
|
|
.fetch_all(&pool),
|
|
);
|
|
|
|
let table_stats: Vec<TableStats> = table_stats_res
|
|
.map_err(TuskError::Database)?
|
|
.iter()
|
|
.map(|r| TableStats {
|
|
schema: r.get(0),
|
|
table: r.get(1),
|
|
seq_scan: r.get(2),
|
|
idx_scan: r.get(3),
|
|
n_live_tup: r.get(4),
|
|
table_size: r.get(5),
|
|
index_size: r.get(6),
|
|
})
|
|
.collect();
|
|
|
|
let index_stats: Vec<IndexStats> = index_stats_res
|
|
.map_err(TuskError::Database)?
|
|
.iter()
|
|
.map(|r| IndexStats {
|
|
schema: r.get(0),
|
|
table: r.get(1),
|
|
index_name: r.get(2),
|
|
idx_scan: r.get(3),
|
|
index_size: r.get(4),
|
|
definition: r.get(5),
|
|
})
|
|
.collect();
|
|
|
|
let (slow_queries, has_pg_stat_statements) = match slow_queries_res {
|
|
Ok(rows) => {
|
|
let queries: Vec<SlowQuery> = rows
|
|
.iter()
|
|
.map(|r| SlowQuery {
|
|
query: r.get(0),
|
|
calls: r.get(1),
|
|
total_time_ms: r.get(2),
|
|
mean_time_ms: r.get(3),
|
|
rows: r.get(4),
|
|
})
|
|
.collect();
|
|
(queries, true)
|
|
}
|
|
Err(_) => (Vec::new(), false),
|
|
};
|
|
|
|
// Build AI prompt for recommendations
|
|
let schema_text = build_schema_context(&state, &connection_id).await?;
|
|
|
|
let mut stats_text = String::from("TABLE STATISTICS:\n");
|
|
for ts in &table_stats {
|
|
stats_text.push_str(&format!(
|
|
" {}.{}: seq_scan={}, idx_scan={}, rows={}, size={}, idx_size={}\n",
|
|
ts.schema,
|
|
ts.table,
|
|
ts.seq_scan,
|
|
ts.idx_scan,
|
|
ts.n_live_tup,
|
|
ts.table_size,
|
|
ts.index_size
|
|
));
|
|
}
|
|
|
|
stats_text.push_str("\nINDEX STATISTICS:\n");
|
|
for is in &index_stats {
|
|
stats_text.push_str(&format!(
|
|
" {}.{}.{}: scans={}, size={}, def={}\n",
|
|
is.schema, is.table, is.index_name, is.idx_scan, is.index_size, is.definition
|
|
));
|
|
}
|
|
|
|
if !slow_queries.is_empty() {
|
|
stats_text.push_str("\nSLOW QUERIES:\n");
|
|
for sq in &slow_queries {
|
|
stats_text.push_str(&format!(
|
|
" calls={}, mean={:.1}ms, total={:.1}ms, rows={}: {}\n",
|
|
sq.calls,
|
|
sq.mean_time_ms,
|
|
sq.total_time_ms,
|
|
sq.rows,
|
|
sq.query.chars().take(200).collect::<String>()
|
|
));
|
|
}
|
|
}
|
|
|
|
let system_prompt = format!(
|
|
"You are a PostgreSQL performance expert. Analyze the database statistics and recommend index changes.\n\
|
|
\n\
|
|
OUTPUT FORMAT:\n\
|
|
- Return ONLY a JSON array of recommendation objects.\n\
|
|
- No markdown, no explanations, no code fences.\n\
|
|
- Each object: {{\"recommendation_type\": \"create_index\"|\"drop_index\"|\"replace_index\", \
|
|
\"table_schema\": \"...\", \"table_name\": \"...\", \"index_name\": \"...\"|null, \
|
|
\"ddl\": \"CREATE INDEX CONCURRENTLY ...\", \"rationale\": \"...\", \
|
|
\"estimated_impact\": \"high\"|\"medium\"|\"low\", \"priority\": \"high\"|\"medium\"|\"low\"}}\n\
|
|
\n\
|
|
RULES:\n\
|
|
1. Prefer CREATE INDEX CONCURRENTLY to avoid locking\n\
|
|
2. Never suggest dropping PRIMARY KEY or UNIQUE indexes\n\
|
|
3. Suggest dropping indexes with 0 scans on tables with many rows\n\
|
|
4. Suggest composite indexes for commonly co-filtered columns\n\
|
|
5. Suggest partial indexes for low-cardinality boolean columns\n\
|
|
6. Consider covering indexes for frequently selected columns\n\
|
|
7. High seq_scan + high row count = strong candidate for new index\n\
|
|
\n\
|
|
{}\n\
|
|
\n\
|
|
{}\n",
|
|
stats_text, schema_text
|
|
);
|
|
|
|
let raw = call_ollama_chat(
|
|
&app,
|
|
&state,
|
|
system_prompt,
|
|
"Analyze indexes and provide recommendations".to_string(),
|
|
)
|
|
.await?;
|
|
|
|
let cleaned = raw.trim();
|
|
let json_start = cleaned.find('[').unwrap_or(0);
|
|
let json_end = cleaned.rfind(']').map(|i| i + 1).unwrap_or(cleaned.len());
|
|
let json_str = &cleaned[json_start..json_end];
|
|
|
|
let recommendations: Vec<IndexRecommendation> =
|
|
serde_json::from_str(json_str).unwrap_or_default();
|
|
|
|
Ok(IndexAdvisorReport {
|
|
table_stats,
|
|
index_stats,
|
|
slow_queries,
|
|
recommendations,
|
|
has_pg_stat_statements,
|
|
})
|
|
}
|
|
|
|
#[tauri::command]
|
|
pub async fn apply_index_recommendation(
|
|
state: State<'_, Arc<AppState>>,
|
|
connection_id: String,
|
|
ddl: String,
|
|
) -> TuskResult<()> {
|
|
if state.is_read_only(&connection_id).await {
|
|
return Err(TuskError::ReadOnly);
|
|
}
|
|
|
|
validate_index_ddl(&ddl)?;
|
|
|
|
let pool = state.get_pool(&connection_id).await?;
|
|
|
|
// CONCURRENTLY cannot run inside a transaction, execute directly
|
|
sqlx::query(&ddl)
|
|
.execute(&pool)
|
|
.await
|
|
.map_err(TuskError::Database)?;
|
|
|
|
// Invalidate schema cache
|
|
state.invalidate_schema_cache(&connection_id).await;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    // Unit tests for the pure, stdlib-only helpers in this module. The
    // SQL-validation tests deliberately document the prefix-check weakness
    // (semicolon piggybacking) rather than asserting it is rejected.
    use super::*;

    // ── validate_select_statement ─────────────────────────────

    #[test]
    fn select_valid_simple() {
        assert!(validate_select_statement("SELECT 1").is_ok());
    }

    #[test]
    fn select_valid_with_leading_whitespace() {
        assert!(validate_select_statement(" SELECT * FROM users").is_ok());
    }

    #[test]
    fn select_valid_lowercase() {
        assert!(validate_select_statement("select * from users").is_ok());
    }

    #[test]
    fn select_rejects_insert() {
        assert!(validate_select_statement("INSERT INTO users VALUES (1)").is_err());
    }

    #[test]
    fn select_rejects_delete() {
        assert!(validate_select_statement("DELETE FROM users").is_err());
    }

    #[test]
    fn select_rejects_drop() {
        assert!(validate_select_statement("DROP TABLE users").is_err());
    }

    #[test]
    fn select_rejects_empty() {
        assert!(validate_select_statement("").is_err());
    }

    #[test]
    fn select_rejects_whitespace_only() {
        assert!(validate_select_statement(" ").is_err());
    }

    // NOTE: This test documents a known weakness — SELECT prefix allows injection
    #[test]
    fn select_allows_semicolon_after_select() {
        // "SELECT 1; DROP TABLE users" starts with SELECT — passes validation
        // This is a known limitation documented in the review
        assert!(validate_select_statement("SELECT 1; DROP TABLE users").is_ok());
    }

    // ── validate_index_ddl ────────────────────────────────────

    #[test]
    fn ddl_valid_create_index() {
        assert!(validate_index_ddl("CREATE INDEX idx_name ON users(email)").is_ok());
    }

    #[test]
    fn ddl_valid_create_index_concurrently() {
        assert!(validate_index_ddl("CREATE INDEX CONCURRENTLY idx ON t(c)").is_ok());
    }

    #[test]
    fn ddl_valid_drop_index() {
        assert!(validate_index_ddl("DROP INDEX idx_name").is_ok());
    }

    #[test]
    fn ddl_valid_with_leading_whitespace() {
        assert!(validate_index_ddl(" CREATE INDEX idx ON t(c)").is_ok());
    }

    #[test]
    fn ddl_valid_lowercase() {
        assert!(validate_index_ddl("create index idx on t(c)").is_ok());
    }

    #[test]
    fn ddl_rejects_create_table() {
        assert!(validate_index_ddl("CREATE TABLE evil(id int)").is_err());
    }

    #[test]
    fn ddl_rejects_drop_table() {
        assert!(validate_index_ddl("DROP TABLE users").is_err());
    }

    #[test]
    fn ddl_rejects_alter_table() {
        assert!(validate_index_ddl("ALTER TABLE users ADD COLUMN x int").is_err());
    }

    #[test]
    fn ddl_rejects_empty() {
        assert!(validate_index_ddl("").is_err());
    }

    // NOTE: Documents bypass weakness — semicolon after valid prefix
    #[test]
    fn ddl_allows_semicolon_injection() {
        // "CREATE INDEX x ON t(c); DROP TABLE users" — passes validation
        // Mitigated by sqlx single-statement execution
        assert!(validate_index_ddl("CREATE INDEX x ON t(c); DROP TABLE users").is_ok());
    }

    // ── clean_sql_response ────────────────────────────────────

    #[test]
    fn clean_sql_plain() {
        assert_eq!(clean_sql_response("SELECT 1"), "SELECT 1");
    }

    #[test]
    fn clean_sql_with_fences() {
        assert_eq!(clean_sql_response("```sql\nSELECT 1\n```"), "SELECT 1");
    }

    #[test]
    fn clean_sql_with_generic_fences() {
        assert_eq!(clean_sql_response("```\nSELECT 1\n```"), "SELECT 1");
    }

    #[test]
    fn clean_sql_with_postgresql_fences() {
        assert_eq!(
            clean_sql_response("```postgresql\nSELECT 1\n```"),
            "SELECT 1"
        );
    }

    #[test]
    fn clean_sql_with_whitespace() {
        assert_eq!(clean_sql_response(" SELECT 1 "), "SELECT 1");
    }

    // Multi-line SQL without fences must pass through untouched
    #[test]
    fn clean_sql_no_fences_multiline() {
        assert_eq!(
            clean_sql_response("SELECT\n *\nFROM users"),
            "SELECT\n *\nFROM users"
        );
    }
}
|