use crate::error::{TuskError, TuskResult}; use crate::models::ai::{ AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, OllamaTagsResponse, }; use crate::state::AppState; use sqlx::Row; use std::collections::BTreeMap; use std::fs; use std::sync::Arc; use std::time::Duration; use tauri::{AppHandle, Manager, State}; fn http_client() -> reqwest::Client { reqwest::Client::builder() .connect_timeout(Duration::from_secs(5)) .timeout(Duration::from_secs(300)) .build() .unwrap_or_default() } fn get_ai_settings_path(app: &AppHandle) -> TuskResult { let dir = app .path() .app_data_dir() .map_err(|e| TuskError::Custom(e.to_string()))?; fs::create_dir_all(&dir)?; Ok(dir.join("ai_settings.json")) } #[tauri::command] pub async fn get_ai_settings(app: AppHandle) -> TuskResult { let path = get_ai_settings_path(&app)?; if !path.exists() { return Ok(AiSettings::default()); } let data = fs::read_to_string(&path)?; let settings: AiSettings = serde_json::from_str(&data)?; Ok(settings) } #[tauri::command] pub async fn save_ai_settings(app: AppHandle, settings: AiSettings) -> TuskResult<()> { let path = get_ai_settings_path(&app)?; let data = serde_json::to_string_pretty(&settings)?; fs::write(&path, data)?; Ok(()) } #[tauri::command] pub async fn list_ollama_models(ollama_url: String) -> TuskResult> { let url = format!("{}/api/tags", ollama_url.trim_end_matches('/')); let resp = http_client() .get(&url) .send() .await .map_err(|e| TuskError::Ai(format!("Cannot connect to Ollama at {}: {}", ollama_url, e)))?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); return Err(TuskError::Ai(format!( "Ollama error ({}): {}", status, body ))); } let tags: OllamaTagsResponse = resp .json() .await .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?; Ok(tags.models) } #[tauri::command] pub async fn generate_sql( app: AppHandle, state: State<'_, Arc>, connection_id: String, prompt: String, ) -> TuskResult { // Load AI settings let settings = { let path = get_ai_settings_path(&app)?; if !path.exists() { return Err(TuskError::Ai( "No AI model selected. Open AI settings to choose a model.".to_string(), )); } let data = fs::read_to_string(&path)?; serde_json::from_str::(&data)? }; if settings.model.is_empty() { return Err(TuskError::Ai( "No AI model selected. Open AI settings to choose a model.".to_string(), )); } // Build schema context let schema_text = build_schema_context(&state, &connection_id).await?; let system_prompt = format!( "You are a PostgreSQL SQL generator. Given the database schema below and a natural language request, \ output ONLY a valid PostgreSQL SQL query. Do not include any explanation, markdown formatting, \ or code fences. Output raw SQL only.\n\n\ RULES:\n\ - Use FK relationships for correct JOIN conditions.\n\ - timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ - interval cannot be cast to numeric directly.\n\ - When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\ - Use COALESCE for nullable columns in aggregations when appropriate.\n\ - Prefer LEFT JOIN when the related row may not exist.\n\n\ DATABASE SCHEMA:\n{}", schema_text ); let request = OllamaChatRequest { model: settings.model, messages: vec![ OllamaChatMessage { role: "system".to_string(), content: system_prompt, }, OllamaChatMessage { role: "user".to_string(), content: prompt, }, ], stream: false, }; let url = format!( "{}/api/chat", settings.ollama_url.trim_end_matches('/') ); let resp = http_client() .post(&url) .json(&request) .send() .await .map_err(|e| { TuskError::Ai(format!( "Cannot connect to Ollama at {}: {}", settings.ollama_url, e )) })?; if !resp.status().is_success() { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); return Err(TuskError::Ai(format!( "Ollama error ({}): {}", status, body ))); } let chat_resp: OllamaChatResponse = resp .json() .await .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?; let sql = clean_sql_response(&chat_resp.message.content); Ok(sql) } async fn build_schema_context( state: &AppState, connection_id: &str, ) -> TuskResult { let pools = state.pools.read().await; let pool = pools .get(connection_id) .ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?; // Single query: all columns with real type names (enum types show actual name, not USER-DEFINED) let col_rows = sqlx::query( "SELECT \ c.table_schema, c.table_name, c.column_name, \ CASE WHEN c.data_type = 'USER-DEFINED' THEN c.udt_name ELSE c.data_type END AS data_type, \ c.is_nullable = 'NO' AS not_null, \ EXISTS( \ SELECT 1 FROM information_schema.table_constraints tc \ JOIN information_schema.key_column_usage kcu \ ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema \ WHERE tc.constraint_type = 'PRIMARY KEY' \ AND tc.table_schema = c.table_schema \ AND tc.table_name = c.table_name \ AND kcu.column_name = c.column_name \ ) AS is_pk \ FROM information_schema.columns c \ WHERE c.table_schema NOT IN ('pg_catalog', 'information_schema', 'pg_toast', 'gp_toolkit') \ ORDER BY c.table_schema, c.table_name, c.ordinal_position", ) .fetch_all(pool) .await .map_err(TuskError::Database)?; // Group columns by schema.table let mut tables: BTreeMap> = BTreeMap::new(); for row in &col_rows { let schema: String = row.get(0); let table: String = row.get(1); let col_name: String = row.get(2); let data_type: String = row.get(3); let not_null: bool = row.get(4); let is_pk: bool = row.get(5); let mut parts = vec![col_name, data_type]; if is_pk { parts.push("PK".to_string()); } if not_null { parts.push("NOT NULL".to_string()); } let key = format!("{}.{}", schema, table); tables.entry(key).or_default().push(parts.join(" ")); } let mut lines: Vec = tables .into_iter() .map(|(key, cols)| format!("{}({})", key, cols.join(", "))) .collect(); // Fetch FK relationships let fks = fetch_foreign_keys_from_pool(pool).await?; for fk in &fks { lines.push(fk.clone()); } Ok(lines.join("\n")) } async fn fetch_foreign_keys_from_pool( pool: &sqlx::PgPool, ) -> TuskResult> { let rows = sqlx::query( "SELECT \ cn.nspname AS schema_name, cl.relname AS table_name, \ array_agg(DISTINCT a.attname ORDER BY a.attname) AS columns, \ cnf.nspname AS ref_schema, clf.relname AS ref_table, \ array_agg(DISTINCT af.attname ORDER BY af.attname) AS ref_columns \ FROM pg_constraint con \ JOIN pg_class cl ON con.conrelid = cl.oid \ JOIN pg_namespace cn ON cl.relnamespace = cn.oid \ JOIN pg_class clf ON con.confrelid = clf.oid \ JOIN pg_namespace cnf ON clf.relnamespace = cnf.oid \ JOIN pg_attribute a ON a.attrelid = con.conrelid AND a.attnum = ANY(con.conkey) \ JOIN pg_attribute af ON af.attrelid = con.confrelid AND af.attnum = ANY(con.confkey) \ WHERE con.contype = 'f' \ AND cn.nspname NOT IN ('pg_catalog','information_schema','pg_toast','gp_toolkit') \ GROUP BY cn.nspname, cl.relname, cnf.nspname, clf.relname, con.oid", ) .fetch_all(pool) .await .map_err(TuskError::Database)?; let fks: Vec = rows .iter() .map(|r| { let schema: String = r.get(0); let table: String = r.get(1); let cols: Vec = r.get(2); let ref_schema: String = r.get(3); let ref_table: String = r.get(4); let ref_cols: Vec = r.get(5); format!( "FK: {}.{}({}) -> {}.{}({})", schema, table, cols.join(", "), ref_schema, ref_table, ref_cols.join(", ") ) }) .collect(); Ok(fks) } fn clean_sql_response(raw: &str) -> String { let trimmed = raw.trim(); // Remove markdown code fences let without_fences = if trimmed.starts_with("```") { let inner = trimmed .strip_prefix("```sql") .or_else(|| trimmed.strip_prefix("```SQL")) .or_else(|| trimmed.strip_prefix("```")) .unwrap_or(trimmed); inner.strip_suffix("```").unwrap_or(inner) } else { trimmed }; without_fences.trim().to_string() }