From 3ad0ee5cc3aeb8a6a4ad3631de411bdd9ee0133f Mon Sep 17 00:00:00 2001 From: "A.Shakhmatov" Date: Fri, 13 Feb 2026 18:48:39 +0300 Subject: [PATCH] feat: add AI Explain Query and Fix Error via Ollama Extract shared call_ollama_chat helper from generate_sql to reuse settings loading and Ollama API call logic. Add two new AI commands: - explain_sql: explains what a SQL query does in plain language - fix_sql_error: suggests corrected SQL based on the error and schema UI additions: "AI Explain" toolbar button, "Explain" and "Fix with AI" action buttons on query errors, inline explanation display in results. Co-Authored-By: Claude Opus 4.6 --- src-tauri/src/commands/ai.rs | 118 +++++++++++++++----- src-tauri/src/lib.rs | 2 + src/components/results/ResultsPanel.tsx | 65 ++++++++++- src/components/workspace/WorkspacePanel.tsx | 82 +++++++++++++- src/hooks/use-ai.ts | 28 +++++ src/lib/tauri.ts | 6 + 6 files changed, 268 insertions(+), 33 deletions(-) diff --git a/src-tauri/src/commands/ai.rs b/src-tauri/src/commands/ai.rs index 132fd03..ea9ebaf 100644 --- a/src-tauri/src/commands/ai.rs +++ b/src-tauri/src/commands/ai.rs @@ -73,16 +73,13 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult>, - connection_id: String, - prompt: String, +async fn call_ollama_chat( + app: &AppHandle, + system_prompt: String, + user_content: String, ) -> TuskResult { - // Load AI settings let settings = { - let path = get_ai_settings_path(&app)?; + let path = get_ai_settings_path(app)?; if !path.exists() { return Err(TuskError::Ai( "No AI model selected. Open AI settings to choose a model.".to_string(), @@ -98,24 +95,6 @@ pub async fn generate_sql( )); } - // Build schema context - let schema_text = build_schema_context(&state, &connection_id).await?; - - let system_prompt = format!( - "You are a PostgreSQL SQL generator. Given the database schema below and a natural language request, \ - output ONLY a valid PostgreSQL SQL query. Do not include any explanation, markdown formatting, \ - or code fences. Output raw SQL only.\n\n\ - RULES:\n\ - - Use FK relationships for correct JOIN conditions.\n\ - - timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ - - interval cannot be cast to numeric directly.\n\ - - When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\ - - Use COALESCE for nullable columns in aggregations when appropriate.\n\ - - Prefer LEFT JOIN when the related row may not exist.\n\n\ - DATABASE SCHEMA:\n{}", - schema_text - ); - let request = OllamaChatRequest { model: settings.model, messages: vec![ @@ -125,7 +104,7 @@ pub async fn generate_sql( }, OllamaChatMessage { role: "user".to_string(), - content: prompt, + content: user_content, }, ], stream: false, @@ -162,8 +141,89 @@ pub async fn generate_sql( .await .map_err(|e| TuskError::Ai(format!("Failed to parse Ollama response: {}", e)))?; - let sql = clean_sql_response(&chat_resp.message.content); - Ok(sql) + Ok(chat_resp.message.content) +} + +#[tauri::command] +pub async fn generate_sql( + app: AppHandle, + state: State<'_, Arc>, + connection_id: String, + prompt: String, +) -> TuskResult { + let schema_text = build_schema_context(&state, &connection_id).await?; + + let system_prompt = format!( + "You are a PostgreSQL SQL generator. Given the database schema below and a natural language request, \ + output ONLY a valid PostgreSQL SQL query. Do not include any explanation, markdown formatting, \ + or code fences. Output raw SQL only.\n\n\ + RULES:\n\ + - Use FK relationships for correct JOIN conditions.\n\ + - timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ + - interval cannot be cast to numeric directly.\n\ + - When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\ + - Use COALESCE for nullable columns in aggregations when appropriate.\n\ + - Prefer LEFT JOIN when the related row may not exist.\n\n\ + DATABASE SCHEMA:\n{}", + schema_text + ); + + let raw = call_ollama_chat(&app, system_prompt, prompt).await?; + Ok(clean_sql_response(&raw)) +} + +#[tauri::command] +pub async fn explain_sql( + app: AppHandle, + state: State<'_, Arc>, + connection_id: String, + sql: String, +) -> TuskResult { + let schema_text = build_schema_context(&state, &connection_id).await?; + + let system_prompt = format!( + "You are a PostgreSQL expert. Explain what this SQL query does in clear, concise language. \ + Focus on the business logic, mention the tables, joins, and filters used. \ + Use short paragraphs or bullet points.\n\n\ + DATABASE SCHEMA:\n{}", + schema_text + ); + + call_ollama_chat(&app, system_prompt, sql).await +} + +#[tauri::command] +pub async fn fix_sql_error( + app: AppHandle, + state: State<'_, Arc>, + connection_id: String, + sql: String, + error_message: String, +) -> TuskResult { + let schema_text = build_schema_context(&state, &connection_id).await?; + + let system_prompt = format!( + "You are a PostgreSQL expert. Fix the SQL query based on the error message. \ + Output ONLY the corrected valid PostgreSQL SQL. Do not include any explanation, \ + markdown formatting, or code fences. Output raw SQL only.\n\n\ + RULES:\n\ + - Use FK relationships for correct JOIN conditions.\n\ + - timestamp - timestamp = interval. To get a number use EXTRACT(EPOCH FROM (ts1 - ts2)).\n\ + - interval cannot be cast to numeric directly.\n\ + - When using UNION/UNION ALL, ensure matching column types; cast enums to text if they differ.\n\ + - Use COALESCE for nullable columns in aggregations when appropriate.\n\ + - Prefer LEFT JOIN when the related row may not exist.\n\n\ + DATABASE SCHEMA:\n{}", + schema_text + ); + + let user_content = format!( + "Original SQL:\n{}\n\nError:\n{}", + sql, error_message + ); + + let raw = call_ollama_chat(&app, system_prompt, user_content).await?; + Ok(clean_sql_response(&raw)) } async fn build_schema_context( diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 1239bf3..4bd2458 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -94,6 +94,8 @@ pub fn run() { commands::ai::save_ai_settings, commands::ai::list_ollama_models, commands::ai::generate_sql, + commands::ai::explain_sql, + commands::ai::fix_sql_error, // lookup commands::lookup::entity_lookup, ]) diff --git a/src/components/results/ResultsPanel.tsx b/src/components/results/ResultsPanel.tsx index bed3c93..c57cf7b 100644 --- a/src/components/results/ResultsPanel.tsx +++ b/src/components/results/ResultsPanel.tsx @@ -1,7 +1,8 @@ import { ResultsTable } from "./ResultsTable"; import { ResultsJsonView } from "./ResultsJsonView"; import type { QueryResult } from "@/types"; -import { Loader2, AlertCircle } from "lucide-react"; +import { Loader2, AlertCircle, Sparkles, Wand2 } from "lucide-react"; +import { Button } from "@/components/ui/button"; interface Props { result?: QueryResult | null; @@ -14,6 +15,10 @@ interface Props { value: unknown ) => void; highlightedCells?: Set; + aiExplanation?: string | null; + isAiLoading?: boolean; + onExplainError?: () => void; + onFixError?: () => void; } export function ResultsPanel({ @@ -23,6 +28,10 @@ export function ResultsPanel({ viewMode = "table", onCellDoubleClick, highlightedCells, + aiExplanation, + isAiLoading, + onExplainError, + onFixError, }: Props) { if (isLoading) { return ( @@ -33,13 +42,65 @@ export function ResultsPanel({ ); } + if (aiExplanation) { + return ( +
+
+
+ + AI Explanation +
+
+            {aiExplanation}
+          
+
+
+ ); + } + if (error) { return ( -
+
{error}
+ {(onExplainError || onFixError) && ( +
+ {onExplainError && ( + + )} + {onFixError && ( + + )} +
+ )}
); } diff --git a/src/components/workspace/WorkspacePanel.tsx b/src/components/workspace/WorkspacePanel.tsx index 71ffd0d..0c954de 100644 --- a/src/components/workspace/WorkspacePanel.tsx +++ b/src/components/workspace/WorkspacePanel.tsx @@ -13,7 +13,7 @@ import { useCompletionSchema } from "@/hooks/use-completion-schema"; import { useConnections } from "@/hooks/use-connections"; import { useAppStore } from "@/stores/app-store"; import { Button } from "@/components/ui/button"; -import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles } from "lucide-react"; +import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles, BrainCircuit } from "lucide-react"; import { format as formatSql } from "sql-formatter"; import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog"; import { @@ -26,6 +26,7 @@ import { exportCsv, exportJson } from "@/lib/tauri"; import { save } from "@tauri-apps/plugin-dialog"; import { toast } from "sonner"; import { AiBar } from "@/components/ai/AiBar"; +import { useExplainSql, useFixSqlError } from "@/hooks/use-ai"; import type { QueryResult, ExplainResult } from "@/types"; interface Props { @@ -53,8 +54,11 @@ export function WorkspacePanel({ const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table"); const [saveDialogOpen, setSaveDialogOpen] = useState(false); const [aiBarOpen, setAiBarOpen] = useState(false); + const [aiExplanation, setAiExplanation] = useState(null); const queryMutation = useQueryExecution(); + const explainMutation = useExplainSql(); + const fixMutation = useFixSqlError(); const addHistoryMutation = useAddHistory(); const { data: connections } = useConnections(); const { data: completionSchema } = useCompletionSchema(connectionId); @@ -98,6 +102,7 @@ export function WorkspacePanel({ if (!sqlValue.trim() || !connectionId) return; setError(null); setExplainData(null); + setAiExplanation(null); setResultView("results"); queryMutation.mutate( { connectionId, sql: sqlValue }, @@ -191,6 +196,60 @@ export function WorkspacePanel({ [result] ); + const isAiLoading = explainMutation.isPending || fixMutation.isPending; + + const handleAiExplain = useCallback(() => { + if (!sqlValue.trim() || !connectionId) return; + setAiExplanation(null); + setResultView("results"); + explainMutation.mutate( + { connectionId, sql: sqlValue }, + { + onSuccess: (explanation) => { + setAiExplanation(explanation); + }, + onError: (err) => { + toast.error("AI Explain failed", { description: String(err) }); + }, + } + ); + }, [connectionId, sqlValue, explainMutation]); + + const handleExplainError = useCallback(() => { + if (!sqlValue.trim() || !connectionId || !error) return; + setAiExplanation(null); + explainMutation.mutate( + { connectionId, sql: `${sqlValue}\n\n-- Error: ${error}` }, + { + onSuccess: (explanation) => { + setAiExplanation(explanation); + }, + onError: (err) => { + toast.error("AI Explain failed", { description: String(err) }); + }, + } + ); + }, [connectionId, sqlValue, error, explainMutation]); + + const handleFixError = useCallback(() => { + if (!sqlValue.trim() || !connectionId || !error) return; + fixMutation.mutate( + { connectionId, sql: sqlValue, errorMessage: error }, + { + onSuccess: (fixedSql) => { + setSqlValue(fixedSql); + onSqlChange?.(fixedSql); + setError(null); + setAiExplanation(null); + toast.success("SQL replaced by AI suggestion"); + }, + onError: (err) => { + toast.error("AI Fix failed", { description: String(err) }); + }, + } + ); + }, [connectionId, sqlValue, error, fixMutation, onSqlChange]); + return ( <> @@ -257,6 +316,21 @@ export function WorkspacePanel({ AI + {result && result.columns.length > 0 && ( @@ -314,7 +388,7 @@ export function WorkspacePanel({
- {(explainData || result || error) && ( + {(explainData || result || error || aiExplanation) && (
diff --git a/src/hooks/use-ai.ts b/src/hooks/use-ai.ts index 5070dbc..d8db13b 100644 --- a/src/hooks/use-ai.ts +++ b/src/hooks/use-ai.ts @@ -4,6 +4,8 @@ import { saveAiSettings, listOllamaModels, generateSql, + explainSql, + fixSqlError, } from "@/lib/tauri"; import type { AiSettings } from "@/types"; @@ -45,3 +47,29 @@ export function useGenerateSql() { }) => generateSql(connectionId, prompt), }); } + +export function useExplainSql() { + return useMutation({ + mutationFn: ({ + connectionId, + sql, + }: { + connectionId: string; + sql: string; + }) => explainSql(connectionId, sql), + }); +} + +export function useFixSqlError() { + return useMutation({ + mutationFn: ({ + connectionId, + sql, + errorMessage, + }: { + connectionId: string; + sql: string; + errorMessage: string; + }) => fixSqlError(connectionId, sql, errorMessage), + }); +} diff --git a/src/lib/tauri.ts b/src/lib/tauri.ts index 7c7d295..8af6968 100644 --- a/src/lib/tauri.ts +++ b/src/lib/tauri.ts @@ -254,6 +254,12 @@ export const listOllamaModels = (ollamaUrl: string) => export const generateSql = (connectionId: string, prompt: string) => invoke("generate_sql", { connectionId, prompt }); +export const explainSql = (connectionId: string, sql: string) => + invoke("explain_sql", { connectionId, sql }); + +export const fixSqlError = (connectionId: string, sql: string, errorMessage: string) => + invoke("fix_sql_error", { connectionId, sql, errorMessage }); + // Entity Lookup export const entityLookup = ( config: ConnectionConfig,