refactor(ai): consolidate AI around chat tool-calling; add OpenRouter

- rework chat backend (chat.rs, chat_tools.rs, ai.rs, models, state) around tool calls
- add OpenRouter provider alongside Ollama/Fireworks in settings
- drop inline AiBar, ResultsPanel explain/fix UI and ChartPreview in favour of the chat panel
- add frontend chat tool-registry
This commit is contained in:
2026-05-23 15:01:52 +03:00
parent a485cf7ee3
commit 0cba457fb7
19 changed files with 1244 additions and 1931 deletions

View File

@@ -1,103 +0,0 @@
import { useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { AiSettingsPopover } from "./AiSettingsPopover";
import { useGenerateSql } from "@/hooks/use-ai";
import { Sparkles, Loader2, X, Eraser } from "lucide-react";
import { toast } from "sonner";
interface Props {
connectionId: string;
onSqlGenerated: (sql: string) => void;
onClose: () => void;
onExecute?: () => void;
}
export function AiBar({ connectionId, onSqlGenerated, onClose, onExecute }: Props) {
const [prompt, setPrompt] = useState("");
const generateMutation = useGenerateSql();
const handleGenerate = () => {
if (!prompt.trim() || generateMutation.isPending) return;
generateMutation.mutate(
{ connectionId, prompt },
{
onSuccess: (sql) => {
onSqlGenerated(sql);
},
onError: (err) => {
toast.error("AI generation failed", { description: String(err) });
},
}
);
};
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === "Enter" && (e.ctrlKey || e.metaKey)) {
e.preventDefault();
e.stopPropagation();
onExecute?.();
return;
}
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
e.stopPropagation();
handleGenerate();
return;
}
if (e.key === "Escape") {
e.stopPropagation();
onClose();
}
};
return (
<div className="tusk-ai-bar flex items-center gap-2 px-2 py-1.5 tusk-fade-in">
<Sparkles className="h-3.5 w-3.5 shrink-0 tusk-ai-icon" />
<Input
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
onKeyDown={handleKeyDown}
placeholder="Describe the query you want..."
className="h-7 min-w-0 flex-1 border-tusk-purple/20 bg-tusk-purple/5 text-xs placeholder:text-muted-foreground/40 focus:border-tusk-purple/40 focus:ring-tusk-purple/20"
autoFocus
disabled={generateMutation.isPending}
/>
<Button
size="xs"
variant="ghost"
className="gap-1 text-[11px] text-tusk-purple hover:bg-tusk-purple/10 hover:text-tusk-purple"
onClick={handleGenerate}
disabled={generateMutation.isPending || !prompt.trim()}
>
{generateMutation.isPending ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
"Generate"
)}
</Button>
{prompt.trim() && (
<Button
size="icon-xs"
variant="ghost"
onClick={() => setPrompt("")}
title="Clear prompt"
disabled={generateMutation.isPending}
className="text-muted-foreground"
>
<Eraser className="h-3 w-3" />
</Button>
)}
<AiSettingsPopover />
<Button
size="icon-xs"
variant="ghost"
onClick={onClose}
title="Close AI bar"
className="text-muted-foreground"
>
<X className="h-3 w-3" />
</Button>
</div>
);
}

View File

@@ -7,7 +7,11 @@ import {
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useFireworksModels, useOllamaModels } from "@/hooks/use-ai";
import {
useFireworksModels,
useOllamaModels,
useOpenRouterModels,
} from "@/hooks/use-ai";
import { RefreshCw, Loader2 } from "lucide-react";
import type { AiProvider, OllamaModel } from "@/types";
@@ -17,6 +21,8 @@ interface Props {
onOllamaUrlChange: (url: string) => void;
fireworksApiKey: string;
onFireworksApiKeyChange: (key: string) => void;
openrouterApiKey: string;
onOpenRouterApiKeyChange: (key: string) => void;
model: string;
onModelChange: (model: string) => void;
}
@@ -27,6 +33,8 @@ export function AiSettingsFields({
onOllamaUrlChange,
fireworksApiKey,
onFireworksApiKeyChange,
openrouterApiKey,
onOpenRouterApiKeyChange,
model,
onModelChange,
}: Props) {
@@ -41,6 +49,17 @@ export function AiSettingsFields({
);
}
if (provider === "openrouter") {
return (
<OpenRouterFields
apiKey={openrouterApiKey}
onApiKeyChange={onOpenRouterApiKeyChange}
model={model}
onModelChange={onModelChange}
/>
);
}
return (
<OllamaFields
ollamaUrl={ollamaUrl}
@@ -143,6 +162,55 @@ function FireworksFields({
);
}
function OpenRouterFields({
apiKey,
onApiKeyChange,
model,
onModelChange,
}: {
apiKey: string;
onApiKeyChange: (key: string) => void;
model: string;
onModelChange: (model: string) => void;
}) {
const {
data: models,
isLoading,
isError,
refetch,
} = useOpenRouterModels(apiKey);
return (
<>
<div className="flex flex-col gap-1.5">
<label className="text-xs text-muted-foreground">OpenRouter API key</label>
<Input
type="password"
value={apiKey}
onChange={(e) => onApiKeyChange(e.target.value)}
placeholder="sk-or-..."
className="h-8 text-xs"
autoComplete="off"
/>
<p className="text-[10px] text-muted-foreground/70">
Stored locally; sent only to openrouter.ai.
</p>
</div>
<ModelDropdown
models={models}
loading={isLoading}
errored={isError}
errorText="Cannot reach OpenRouter (check API key)"
onRefresh={() => refetch()}
model={model}
onModelChange={onModelChange}
emptyHint={apiKey.trim() ? "Click ↻ to load models" : "Enter API key first"}
/>
</>
);
}
function ModelDropdown({
models,
loading,

View File

@@ -21,6 +21,7 @@ import type { AiProvider } from "@/types";
const SUPPORTED_PROVIDERS: { value: AiProvider; label: string }[] = [
{ value: "ollama", label: "Ollama (local)" },
{ value: "fireworks", label: "Fireworks AI" },
{ value: "openrouter", label: "OpenRouter" },
];
export function AiSettingsPopover() {
@@ -30,22 +31,16 @@ export function AiSettingsPopover() {
const [provider, setProvider] = useState<AiProvider | null>(null);
const [url, setUrl] = useState<string | null>(null);
const [fireworksKey, setFireworksKey] = useState<string | null>(null);
const [openrouterKey, setOpenrouterKey] = useState<string | null>(null);
const [model, setModel] = useState<string | null>(null);
const settingsProvider = settings?.provider;
// Hide unsupported legacy values (openai/anthropic) from the selector.
const normalizedSettingsProvider: AiProvider | undefined =
settingsProvider === "ollama" || settingsProvider === "fireworks"
? settingsProvider
: settingsProvider
? "ollama"
: undefined;
const currentProvider: AiProvider =
provider ?? normalizedSettingsProvider ?? "ollama";
provider ?? settings?.provider ?? "ollama";
const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434";
const currentFireworksKey =
fireworksKey ?? settings?.fireworks_api_key ?? "";
const currentOpenrouterKey =
openrouterKey ?? settings?.openrouter_api_key ?? "";
const currentModel = model ?? settings?.model ?? "";
const handleProviderChange = (next: AiProvider) => {
@@ -64,6 +59,10 @@ export function AiSettingsPopover() {
currentProvider === "fireworks"
? currentFireworksKey.trim() || undefined
: settings?.fireworks_api_key,
openrouter_api_key:
currentProvider === "openrouter"
? currentOpenrouterKey.trim() || undefined
: settings?.openrouter_api_key,
model: currentModel,
},
{
@@ -117,6 +116,8 @@ export function AiSettingsPopover() {
onOllamaUrlChange={setUrl}
fireworksApiKey={currentFireworksKey}
onFireworksApiKeyChange={setFireworksKey}
openrouterApiKey={currentOpenrouterKey}
onOpenRouterApiKeyChange={setOpenrouterKey}
model={currentModel}
onModelChange={setModel}
/>

View File

@@ -1,327 +0,0 @@
import { useMemo } from "react";
import {
Area,
AreaChart,
Bar,
BarChart,
CartesianGrid,
Cell,
Legend,
Line,
LineChart,
Pie,
PieChart,
ResponsiveContainer,
Tooltip,
XAxis,
YAxis,
} from "recharts";
import type { ChartConfig } from "@/types";
interface Props {
config: ChartConfig;
columns: string[];
rows: unknown[][];
height?: number;
}
const PALETTE = [
"#60a5fa", // blue-400
"#34d399", // emerald-400
"#fbbf24", // amber-400
"#f87171", // red-400
"#a78bfa", // violet-400
"#22d3ee", // cyan-400
"#fb923c", // orange-400
"#f472b6", // pink-400
];
const MAX_POINTS = 500;
export function ChartPreview({ config, columns, rows, height = 280 }: Props) {
const xIdx = columns.indexOf(config.x);
const yIdx = columns.indexOf(config.y);
const groupIdx = config.group ? columns.indexOf(config.group) : -1;
const limited = useMemo(() => rows.slice(0, MAX_POINTS), [rows]);
if (xIdx < 0 || yIdx < 0) {
return (
<ChartFallback
config={config}
message={`Column not found: ${xIdx < 0 ? config.x : config.y}`}
/>
);
}
// Coerce y values to numbers; chart libs need numeric Y.
const numericY = (v: unknown): number => {
if (typeof v === "number") return v;
if (typeof v === "string") {
const n = parseFloat(v);
return Number.isFinite(n) ? n : 0;
}
return 0;
};
const labelX = (v: unknown): string => {
if (v == null) return "—";
if (typeof v === "string") return v;
if (typeof v === "number" || typeof v === "boolean") return String(v);
return JSON.stringify(v);
};
const isGrouped = groupIdx >= 0;
// ──────────── grouped data shape ────────────
// For multi-series: pivot to { x: <xValue>, <group1>: yVal, <group2>: yVal, … }
// Used by line, area, and grouped-bar.
const pivoted = useMemo(() => {
if (!isGrouped) return null;
const map = new Map<string, Record<string, unknown>>();
const groupSet = new Set<string>();
for (const row of limited) {
const xv = labelX(row[xIdx]);
const gv = labelX(row[groupIdx!]);
const yv = numericY(row[yIdx]);
groupSet.add(gv);
const acc = map.get(xv) ?? { _x: xv };
acc[gv] = ((acc[gv] as number) ?? 0) + yv;
map.set(xv, acc);
}
return {
data: Array.from(map.values()),
groups: Array.from(groupSet),
};
}, [isGrouped, limited, xIdx, yIdx, groupIdx]);
// Single series shape: [{ _x, _y }]
const flat = useMemo(() => {
return limited.map((row) => ({
_x: labelX(row[xIdx]),
_y: numericY(row[yIdx]),
}));
}, [limited, xIdx, yIdx]);
const tickStyle = {
fill: "var(--muted-foreground)",
fontSize: 10,
} as const;
const axisLine = {
stroke: "rgba(255, 255, 255, 0.08)",
} as const;
const tooltipStyle = {
backgroundColor: "var(--popover)",
border: "1px solid var(--border)",
borderRadius: 6,
fontSize: 11,
} as const;
if (config.chart_type === "pie") {
// Pie: aggregate y by x label (sum), no group support.
const agg = new Map<string, number>();
for (const row of limited) {
const xv = labelX(row[xIdx]);
agg.set(xv, (agg.get(xv) ?? 0) + numericY(row[yIdx]));
}
const data = Array.from(agg.entries()).map(([name, value]) => ({ name, value }));
return (
<ChartFrame config={config} height={height} count={data.length} totalRows={rows.length}>
<ResponsiveContainer width="100%" height={height}>
<PieChart>
<Pie
data={data}
dataKey="value"
nameKey="name"
outerRadius={Math.min(height / 2.5, 110)}
label={(entry) =>
typeof entry.name === "string" && entry.name.length < 20 ? entry.name : ""
}
>
{data.map((_, i) => (
<Cell key={i} fill={PALETTE[i % PALETTE.length]} />
))}
</Pie>
<Tooltip contentStyle={tooltipStyle} />
<Legend
wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }}
verticalAlign="bottom"
/>
</PieChart>
</ResponsiveContainer>
</ChartFrame>
);
}
if (config.chart_type === "line") {
return (
<ChartFrame
config={config}
height={height}
count={isGrouped ? pivoted!.data.length : flat.length}
totalRows={rows.length}
>
<ResponsiveContainer width="100%" height={height}>
<LineChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<Tooltip contentStyle={tooltipStyle} />
{isGrouped ? (
<>
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
{pivoted!.groups.map((g, i) => (
<Line
key={g}
type="monotone"
dataKey={g}
stroke={PALETTE[i % PALETTE.length]}
strokeWidth={2}
dot={false}
/>
))}
</>
) : (
<Line type="monotone" dataKey="_y" stroke={PALETTE[0]} strokeWidth={2} dot={false} />
)}
</LineChart>
</ResponsiveContainer>
</ChartFrame>
);
}
if (config.chart_type === "area") {
return (
<ChartFrame
config={config}
height={height}
count={isGrouped ? pivoted!.data.length : flat.length}
totalRows={rows.length}
>
<ResponsiveContainer width="100%" height={height}>
<AreaChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<Tooltip contentStyle={tooltipStyle} />
{isGrouped ? (
<>
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
{pivoted!.groups.map((g, i) => (
<Area
key={g}
type="monotone"
dataKey={g}
stackId="1"
stroke={PALETTE[i % PALETTE.length]}
fill={PALETTE[i % PALETTE.length]}
fillOpacity={0.35}
/>
))}
</>
) : (
<Area
type="monotone"
dataKey="_y"
stroke={PALETTE[0]}
fill={PALETTE[0]}
fillOpacity={0.35}
/>
)}
</AreaChart>
</ResponsiveContainer>
</ChartFrame>
);
}
// bar (default)
const horizontal = config.orientation === "horizontal";
return (
<ChartFrame
config={config}
height={height}
count={isGrouped ? pivoted!.data.length : flat.length}
totalRows={rows.length}
>
<ResponsiveContainer width="100%" height={height}>
<BarChart
layout={horizontal ? "vertical" : "horizontal"}
data={isGrouped ? pivoted!.data : flat}
margin={{ top: 8, right: 12, left: horizontal ? 24 : 0, bottom: 4 }}
>
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={horizontal} horizontal={!horizontal} />
{horizontal ? (
<>
<XAxis type="number" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis dataKey="_x" type="category" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} width={100} />
</>
) : (
<>
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
</>
)}
<Tooltip contentStyle={tooltipStyle} />
{isGrouped ? (
<>
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
{pivoted!.groups.map((g, i) => (
<Bar key={g} dataKey={g} fill={PALETTE[i % PALETTE.length]} radius={[3, 3, 0, 0]} />
))}
</>
) : (
<Bar dataKey="_y" fill={PALETTE[0]} radius={[3, 3, 0, 0]} />
)}
</BarChart>
</ResponsiveContainer>
</ChartFrame>
);
}
function ChartFrame({
config,
height,
count,
totalRows,
children,
}: {
config: ChartConfig;
height: number;
count: number;
totalRows: number;
children: React.ReactNode;
}) {
return (
<div className="rounded-md border border-border/40 bg-background">
<div className="flex items-center gap-2 border-b border-border/30 px-2 py-1 text-[11px] text-muted-foreground">
<span className="font-medium text-foreground/80">
{config.title ?? `${capitalize(config.chart_type)} chart`}
</span>
<span className="ml-auto text-muted-foreground/60">
{count} point{count === 1 ? "" : "s"}
{totalRows > MAX_POINTS && ` (of ${totalRows}, capped at ${MAX_POINTS})`}
</span>
</div>
<div className="p-2" style={{ minHeight: height }}>
{children}
</div>
</div>
);
}
function ChartFallback({ config, message }: { config: ChartConfig; message: string }) {
return (
<div className="rounded-md border border-destructive/40 bg-destructive/5 p-3 text-xs">
<div className="font-medium text-destructive">
Chart {config.chart_type} failed
</div>
<div className="mt-1 text-muted-foreground">{message}</div>
</div>
);
}
function capitalize(s: string) {
return s.charAt(0).toUpperCase() + s.slice(1);
}

View File

@@ -1,7 +1,6 @@
import { useState } from "react";
import { ResultsTable } from "@/components/results/ResultsTable";
import { ExportDialog } from "@/components/export/ExportDialog";
import { ChartPreview } from "./ChartPreview";
import {
Dialog,
DialogContent,
@@ -15,19 +14,12 @@ import {
AlertCircle,
Sparkles,
User,
Wrench,
Database,
Columns,
Layers,
RefreshCw,
StickyNote,
Bookmark,
BookmarkPlus,
Maximize2,
Download,
BarChart3,
} from "lucide-react";
import type { ChartConfig, ChatMessage } from "@/types";
import type { ChatMessage } from "@/types";
import { getToolMeta, isQueryResultTool } from "./tool-registry";
interface Props {
message: ChatMessage;
@@ -79,8 +71,10 @@ function AssistantBubble({ text }: { text: string }) {
function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) {
const [expanded, setExpanded] = useState(false);
const preview = extractToolPreview(tool, inputJson);
const Icon = iconForTool(tool);
const meta = getToolMeta(tool);
const preview = previewFromJson(tool, inputJson);
const Icon = meta.icon;
const showSqlPreview = (tool === "run_query" || tool === "explain_query") && preview;
return (
<div className="ml-8 rounded-md border border-border/40 bg-muted/20">
@@ -91,17 +85,14 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
>
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
<Icon className="h-3 w-3" />
<span className="font-medium">{labelForTool(tool)}</span>
<span className="font-medium">{meta.label}</span>
{preview && (
<span className="ml-1 truncate text-muted-foreground/70">
{preview.slice(0, 80)}
{preview.length > 80 ? "…" : ""}
</span>
<span className="ml-1 truncate text-muted-foreground/70">{preview}</span>
)}
</button>
{expanded && (
<div className="border-t border-border/30 p-2">
{tool === "run_query" && preview ? (
{showSqlPreview ? (
<pre className="overflow-x-auto whitespace-pre-wrap rounded bg-background/60 p-2 font-mono text-[11px]">
{preview}
</pre>
@@ -116,6 +107,15 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
);
}
function previewFromJson(tool: string, inputJson: string): string | null {
try {
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
return getToolMeta(tool).preview(parsed);
} catch {
return null;
}
}
function ToolResultBlock({
tool,
isError,
@@ -132,87 +132,20 @@ function ToolResultBlock({
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
<div>
<div className="font-medium text-destructive">{labelForTool(tool)} failed</div>
<div className="font-medium text-destructive">{getToolMeta(tool).label} failed</div>
{text && <div className="mt-1 whitespace-pre-wrap text-muted-foreground">{text}</div>}
</div>
</div>
);
}
// Legacy schema tool — keep a one-line indicator for old threads.
if (tool === "get_schema") {
return (
<div className="ml-8 flex items-center gap-2 rounded-md border border-border/40 bg-muted/20 px-2 py-1.5 text-xs text-muted-foreground">
<Database className="h-3 w-3" />
<span>Loaded schema context ({text?.length ?? 0} chars)</span>
</div>
);
}
// Text-only tools (chat v2/v3): list_databases, list_tables, get_columns, switch_database,
// remember, save_query, find_queries.
if (
tool === "list_databases" ||
tool === "list_tables" ||
tool === "get_columns" ||
tool === "switch_database" ||
tool === "remember" ||
tool === "save_query" ||
tool === "find_queries"
) {
return <TextToolResult tool={tool} text={text} />;
}
// make_chart — render chart inline using config from text + data from result.
if (tool === "make_chart") {
return <ChartToolResult text={text} result={result} />;
}
// run_query — full results table with Open-full / Export actions.
if (result) {
// Tools that produce a QueryResult (rendered as a table): run_query, sample_data.
if (isQueryResultTool(tool) && result) {
return <RunQueryResultBlock result={result} />;
}
return null;
}
function ChartToolResult({
text,
result,
}: {
text: string | null;
result: { columns: string[]; types: string[]; rows: unknown[][]; row_count: number; execution_time_ms: number } | null;
}) {
let config: ChartConfig | null = null;
try {
if (text) {
config = JSON.parse(text) as ChartConfig;
}
} catch {
config = null;
}
if (!config || !result) {
return (
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
<div>
<div className="font-medium text-destructive">Chart unavailable</div>
<div className="mt-1 text-muted-foreground">
The agent referenced a chart but the previous query result is not attached.
</div>
</div>
</div>
);
}
return (
<div className="ml-8">
<ChartPreview
config={config}
columns={result.columns}
rows={result.rows}
/>
</div>
);
// Everything else falls back to a collapsible text block.
return <TextToolResult tool={tool} text={text} />;
}
function RunQueryResultBlock({
@@ -315,8 +248,10 @@ function RunQueryResultBlock({
}
function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
// Lazy preview: switch_database is short; everything else collapses by default.
const [expanded, setExpanded] = useState(tool === "switch_database");
const Icon = iconForTool(tool);
const meta = getToolMeta(tool);
const Icon = meta.icon;
const lineCount = text ? text.split("\n").length : 0;
return (
@@ -328,7 +263,7 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
>
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
<Icon className="h-3 w-3" />
<span className="font-medium">{labelForTool(tool)}</span>
<span className="font-medium">{meta.label}</span>
{text && (
<span className="ml-1 text-muted-foreground/60">
{lineCount} line{lineCount === 1 ? "" : "s"}
@@ -346,93 +281,6 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
);
}
function labelForTool(tool: string): string {
switch (tool) {
case "run_query":
return "Run SQL";
case "list_databases":
return "List databases";
case "list_tables":
return "List tables";
case "get_columns":
return "Inspect columns";
case "switch_database":
return "Switch database";
case "remember":
return "Remember";
case "save_query":
return "Save query";
case "find_queries":
return "Find saved queries";
case "make_chart":
return "Make chart";
case "get_schema":
return "Load schema";
default:
return tool;
}
}
function iconForTool(tool: string) {
switch (tool) {
case "run_query":
return Wrench;
case "list_databases":
return Database;
case "list_tables":
return Layers;
case "get_columns":
return Columns;
case "switch_database":
return RefreshCw;
case "remember":
return StickyNote;
case "save_query":
return BookmarkPlus;
case "find_queries":
return Bookmark;
case "make_chart":
return BarChart3;
case "get_schema":
return Database;
default:
return Wrench;
}
}
function extractToolPreview(tool: string, inputJson: string): string | null {
try {
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
switch (tool) {
case "run_query":
return typeof parsed.sql === "string" ? parsed.sql : null;
case "list_tables":
return typeof parsed.database === "string" ? parsed.database : null;
case "switch_database":
return typeof parsed.database === "string" ? parsed.database : null;
case "get_columns":
return Array.isArray(parsed.tables) ? parsed.tables.join(", ") : null;
case "remember":
return typeof parsed.note === "string" ? parsed.note : null;
case "save_query":
return typeof parsed.name === "string" ? parsed.name : null;
case "find_queries":
return typeof parsed.text === "string" ? parsed.text : null;
case "make_chart": {
const t = typeof parsed.chart_type === "string" ? parsed.chart_type : null;
const x = typeof parsed.x === "string" ? parsed.x : null;
const y = typeof parsed.y === "string" ? parsed.y : null;
if (t && x && y) return `${t}: ${x}${y}`;
return null;
}
default:
return null;
}
} catch {
return null;
}
}
function prettyJson(s: string): string {
try {
return JSON.stringify(JSON.parse(s), null, 2);

View File

@@ -0,0 +1,107 @@
import {
Database,
Layers,
Columns,
RefreshCw,
Wrench,
StickyNote,
Bookmark,
BookmarkPlus,
Activity,
Shuffle,
GitBranch,
AlertTriangle,
} from "lucide-react";
import type { LucideIcon } from "lucide-react";
export type ToolMeta = {
icon: LucideIcon;
label: string;
preview: (parsed: Record<string, unknown>) => string | null;
};
const truncate = (s: unknown, n = 80): string | null => {
if (typeof s !== "string") return null;
return s.length > n ? `${s.slice(0, n)}` : s;
};
export const TOOLS: Record<string, ToolMeta> = {
list_databases: {
icon: Database,
label: "List databases",
preview: () => null,
},
list_tables: {
icon: Layers,
label: "List tables",
preview: (p) => (typeof p.database === "string" ? p.database : null),
},
get_columns: {
icon: Columns,
label: "Inspect columns",
preview: (p) => (Array.isArray(p.tables) ? (p.tables as string[]).join(", ") : null),
},
switch_database: {
icon: RefreshCw,
label: "Switch database",
preview: (p) => (typeof p.database === "string" ? p.database : null),
},
run_query: {
icon: Wrench,
label: "Run SQL",
preview: (p) => truncate(p.sql),
},
remember: {
icon: StickyNote,
label: "Remember",
preview: (p) => (typeof p.note === "string" ? p.note : null),
},
save_query: {
icon: BookmarkPlus,
label: "Save query",
preview: (p) => (typeof p.name === "string" ? p.name : null),
},
find_queries: {
icon: Bookmark,
label: "Find saved queries",
preview: (p) => (typeof p.text === "string" ? p.text : null),
},
profile_table: {
icon: Activity,
label: "Profile table",
preview: (p) => (typeof p.table === "string" ? p.table : null),
},
sample_data: {
icon: Shuffle,
label: "Sample rows",
preview: (p) => {
const t = typeof p.table === "string" ? p.table : "";
const limit = typeof p.limit === "number" ? p.limit : 50;
return t ? `${t} (${limit})` : null;
},
},
explain_query: {
icon: GitBranch,
label: "Explain query",
preview: (p) => truncate(p.sql),
},
detect_skew: {
icon: AlertTriangle,
label: "Detect skew",
preview: (p) => (typeof p.table === "string" ? p.table : null),
},
};
export function getToolMeta(tool: string): ToolMeta {
return (
TOOLS[tool] ?? {
icon: Wrench,
label: tool,
preview: () => null,
}
);
}
export function isQueryResultTool(tool: string): boolean {
return tool === "run_query" || tool === "sample_data";
}

View File

@@ -1,8 +1,7 @@
import { ResultsTable } from "./ResultsTable";
import { ResultsJsonView } from "./ResultsJsonView";
import type { QueryResult } from "@/types";
import { Loader2, AlertCircle, Sparkles, Wand2 } from "lucide-react";
import { Button } from "@/components/ui/button";
import { Loader2, AlertCircle } from "lucide-react";
interface Props {
result?: QueryResult | null;
@@ -15,10 +14,6 @@ interface Props {
value: unknown
) => void;
highlightedCells?: Set<string>;
aiExplanation?: string | null;
isAiLoading?: boolean;
onExplainError?: () => void;
onFixError?: () => void;
}
export function ResultsPanel({
@@ -28,10 +23,6 @@ export function ResultsPanel({
viewMode = "table",
onCellDoubleClick,
highlightedCells,
aiExplanation,
isAiLoading,
onExplainError,
onFixError,
}: Props) {
if (isLoading) {
return (
@@ -42,22 +33,6 @@ export function ResultsPanel({
);
}
if (aiExplanation) {
return (
<div className="h-full select-text overflow-auto p-4">
<div className="rounded-md border bg-muted/30 p-4">
<div className="mb-2 flex items-center gap-2 text-xs font-medium text-muted-foreground">
<Sparkles className="h-3.5 w-3.5" />
AI Explanation
</div>
<pre className="whitespace-pre-wrap font-sans text-sm leading-relaxed text-foreground">
{aiExplanation}
</pre>
</div>
</div>
);
}
if (error) {
return (
<div className="flex h-full select-text flex-col items-center justify-center gap-3 p-4">
@@ -65,42 +40,6 @@ export function ResultsPanel({
<AlertCircle className="mt-0.5 h-4 w-4 shrink-0" />
<pre className="whitespace-pre-wrap font-mono text-xs">{error}</pre>
</div>
{(onExplainError || onFixError) && (
<div className="flex items-center gap-2">
{onExplainError && (
<Button
size="sm"
variant="outline"
className="h-7 gap-1.5 text-xs"
onClick={onExplainError}
disabled={isAiLoading}
>
{isAiLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<Sparkles className="h-3 w-3" />
)}
Explain
</Button>
)}
{onFixError && (
<Button
size="sm"
variant="outline"
className="h-7 gap-1.5 text-xs"
onClick={onFixError}
disabled={isAiLoading}
>
{isAiLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<Wand2 className="h-3 w-3" />
)}
Fix with AI
</Button>
)}
</div>
)}
</div>
);
}

View File

@@ -27,6 +27,7 @@ import type { AiProvider, AppSettings } from "@/types";
const SUPPORTED_AI_PROVIDERS: { value: AiProvider; label: string }[] = [
{ value: "ollama", label: "Ollama (local)" },
{ value: "fireworks", label: "Fireworks AI" },
{ value: "openrouter", label: "OpenRouter" },
];
interface Props {
@@ -50,6 +51,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
const [aiProvider, setAiProvider] = useState<AiProvider>("ollama");
const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434");
const [fireworksApiKey, setFireworksApiKey] = useState("");
const [openrouterApiKey, setOpenrouterApiKey] = useState("");
const [aiModel, setAiModel] = useState("");
const [copied, setCopied] = useState(false);
@@ -70,10 +72,14 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
if (aiSettings) {
// Legacy openai/anthropic values aren't user-selectable here — fall back to ollama.
setAiProvider(
aiSettings.provider === "fireworks" ? "fireworks" : "ollama"
aiSettings.provider === "fireworks" ||
aiSettings.provider === "openrouter"
? aiSettings.provider
: "ollama"
);
setOllamaUrl(aiSettings.ollama_url);
setFireworksApiKey(aiSettings.fireworks_api_key ?? "");
setOpenrouterApiKey(aiSettings.openrouter_api_key ?? "");
setAiModel(aiSettings.model);
}
}
@@ -115,6 +121,10 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
aiProvider === "fireworks"
? fireworksApiKey.trim() || undefined
: aiSettings?.fireworks_api_key,
openrouter_api_key:
aiProvider === "openrouter"
? openrouterApiKey.trim() || undefined
: aiSettings?.openrouter_api_key,
model: aiModel,
},
{
@@ -167,7 +177,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
<span
className={`inline-block h-2 w-2 rounded-full ${
mcpStatus?.running
? "bg-green-500"
? "bg-success ring-2 ring-success/25"
: "bg-muted-foreground/30"
}`}
/>
@@ -189,7 +199,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
title="Copy endpoint URL"
>
{copied ? (
<Check className="h-3 w-3 text-green-500" />
<Check className="h-3 w-3 text-success" />
) : (
<Copy className="h-3 w-3" />
)}
@@ -229,6 +239,8 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
onOllamaUrlChange={setOllamaUrl}
fireworksApiKey={fireworksApiKey}
onFireworksApiKeyChange={setFireworksApiKey}
openrouterApiKey={openrouterApiKey}
onOpenRouterApiKeyChange={setOpenrouterApiKey}
model={aiModel}
onModelChange={setAiModel}
/>

View File

@@ -13,7 +13,7 @@ import { useCompletionSchema } from "@/hooks/use-completion-schema";
import { useConnections } from "@/hooks/use-connections";
import { useAppStore } from "@/stores/app-store";
import { Button } from "@/components/ui/button";
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles, BrainCircuit } from "lucide-react";
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces } from "lucide-react";
import { format as formatSql } from "sql-formatter";
import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog";
import {
@@ -25,8 +25,6 @@ import {
import { exportCsv, exportJson } from "@/lib/tauri";
import { save } from "@tauri-apps/plugin-dialog";
import { toast } from "sonner";
import { AiBar } from "@/components/ai/AiBar";
import { useExplainSql, useFixSqlError } from "@/hooks/use-ai";
import type { QueryResult, ExplainResult } from "@/types";
interface Props {
@@ -53,12 +51,8 @@ export function WorkspacePanel({
const [resultView, setResultView] = useState<"results" | "explain">("results");
const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table");
const [saveDialogOpen, setSaveDialogOpen] = useState(false);
const [aiBarOpen, setAiBarOpen] = useState(false);
const [aiExplanation, setAiExplanation] = useState<string | null>(null);
const queryMutation = useQueryExecution();
const explainMutation = useExplainSql();
const fixMutation = useFixSqlError();
const addHistoryMutation = useAddHistory();
const { data: connections } = useConnections();
const { data: completionSchema } = useCompletionSchema(connectionId);
@@ -102,7 +96,6 @@ export function WorkspacePanel({
if (!sqlValue.trim() || !connectionId) return;
setError(null);
setExplainData(null);
setAiExplanation(null);
setResultView("results");
queryMutation.mutate(
{ connectionId, sql: sqlValue },
@@ -196,60 +189,6 @@ export function WorkspacePanel({
[result]
);
const isAiLoading = explainMutation.isPending || fixMutation.isPending;
const handleAiExplain = useCallback(() => {
if (!sqlValue.trim() || !connectionId) return;
setAiExplanation(null);
setResultView("results");
explainMutation.mutate(
{ connectionId, sql: sqlValue },
{
onSuccess: (explanation) => {
setAiExplanation(explanation);
},
onError: (err) => {
toast.error("AI Explain failed", { description: String(err) });
},
}
);
}, [connectionId, sqlValue, explainMutation]);
const handleExplainError = useCallback(() => {
if (!sqlValue.trim() || !connectionId || !error) return;
setAiExplanation(null);
explainMutation.mutate(
{ connectionId, sql: `${sqlValue}\n\n-- Error: ${error}` },
{
onSuccess: (explanation) => {
setAiExplanation(explanation);
},
onError: (err) => {
toast.error("AI Explain failed", { description: String(err) });
},
}
);
}, [connectionId, sqlValue, error, explainMutation]);
const handleFixError = useCallback(() => {
if (!sqlValue.trim() || !connectionId || !error) return;
fixMutation.mutate(
{ connectionId, sql: sqlValue, errorMessage: error },
{
onSuccess: (fixedSql) => {
setSqlValue(fixedSql);
onSqlChange?.(fixedSql);
setError(null);
setAiExplanation(null);
toast.success("SQL replaced by AI suggestion");
},
onError: (err) => {
toast.error("AI Fix failed", { description: String(err) });
},
}
);
}, [connectionId, sqlValue, error, fixMutation, onSqlChange]);
return (
<>
<ResizablePanelGroup orientation="vertical">
@@ -308,35 +247,6 @@ export function WorkspacePanel({
Save
</Button>
<div className="mx-1 h-3.5 w-px bg-border/40" />
{/* AI actions group — purple-branded */}
<Button
size="xs"
variant={aiBarOpen ? "secondary" : "ghost"}
className={`gap-1 text-[11px] ${aiBarOpen ? "text-tusk-purple" : ""}`}
onClick={() => setAiBarOpen(!aiBarOpen)}
title="AI SQL Generator"
>
<Sparkles className={`h-3 w-3 ${aiBarOpen ? "tusk-ai-icon" : ""}`} />
AI
</Button>
<Button
size="xs"
variant="ghost"
className="gap-1 text-[11px]"
onClick={handleAiExplain}
disabled={isAiLoading || !sqlValue.trim()}
title="Explain query with AI"
>
{isAiLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<BrainCircuit className="h-3 w-3" />
)}
AI Explain
</Button>
{result && result.columns.length > 0 && (
<>
<div className="mx-1 h-3.5 w-px bg-border/40" />
@@ -369,23 +279,12 @@ export function WorkspacePanel({
{"\u2318"}Enter
</span>
{isReadOnly && (
<span className="ml-2 flex items-center gap-1 rounded-sm bg-amber-500/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-amber-500">
<span className="ml-2 flex items-center gap-1 rounded-sm bg-warning/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-warning">
<Lock className="h-2.5 w-2.5" />
READ
</span>
)}
</div>
{aiBarOpen && (
<AiBar
connectionId={connectionId}
onSqlGenerated={(sql) => {
setSqlValue(sql);
onSqlChange?.(sql);
}}
onClose={() => setAiBarOpen(false)}
onExecute={handleExecute}
/>
)}
<div className="min-h-0 flex-1">
<SqlEditor
value={sqlValue}
@@ -400,7 +299,7 @@ export function WorkspacePanel({
<ResizableHandle withHandle />
<ResizablePanel id="results" defaultSize="60%" minSize="15%">
<div className="flex h-full flex-col overflow-hidden">
{(explainData || result || error || aiExplanation) && (
{(explainData || result || error) && (
<div className="flex shrink-0 items-center border-b border-border/40 text-xs">
<button
className={`relative px-3 py-1.5 font-medium transition-colors ${
@@ -469,10 +368,6 @@ export function WorkspacePanel({
error={error}
isLoading={queryMutation.isPending && resultView === "results"}
viewMode={resultViewMode}
aiExplanation={aiExplanation}
isAiLoading={isAiLoading}
onExplainError={error ? handleExplainError : undefined}
onFixError={error ? handleFixError : undefined}
/>
)}
</div>

View File

@@ -4,9 +4,7 @@ import {
saveAiSettings,
listOllamaModels,
listFireworksModels,
generateSql,
explainSql,
fixSqlError,
listOpenRouterModels,
} from "@/lib/tauri";
import type { AiSettings } from "@/types";
@@ -47,40 +45,12 @@ export function useFireworksModels(apiKey: string | undefined) {
});
}
export function useGenerateSql() {
return useMutation({
mutationFn: ({
connectionId,
prompt,
}: {
connectionId: string;
prompt: string;
}) => generateSql(connectionId, prompt),
});
}
export function useExplainSql() {
return useMutation({
mutationFn: ({
connectionId,
sql,
}: {
connectionId: string;
sql: string;
}) => explainSql(connectionId, sql),
});
}
export function useFixSqlError() {
return useMutation({
mutationFn: ({
connectionId,
sql,
errorMessage,
}: {
connectionId: string;
sql: string;
errorMessage: string;
}) => fixSqlError(connectionId, sql, errorMessage),
export function useOpenRouterModels(apiKey: string | undefined) {
return useQuery({
queryKey: ["openrouter-models", apiKey],
queryFn: () => listOpenRouterModels(apiKey!),
enabled: !!apiKey && apiKey.trim().length > 0,
retry: false,
staleTime: 60_000,
});
}

View File

@@ -214,14 +214,8 @@ export const listOllamaModels = (ollamaUrl: string) =>
export const listFireworksModels = (apiKey: string) =>
invoke<OllamaModel[]>("list_fireworks_models", { apiKey });
export const generateSql = (connectionId: string, prompt: string) =>
invoke<string>("generate_sql", { connectionId, prompt });
export const explainSql = (connectionId: string, sql: string) =>
invoke<string>("explain_sql", { connectionId, sql });
export const fixSqlError = (connectionId: string, sql: string, errorMessage: string) =>
invoke<string>("fix_sql_error", { connectionId, sql, errorMessage });
export const listOpenRouterModels = (apiKey: string) =>
invoke<OllamaModel[]>("list_openrouter_models", { apiKey });
export const chatSend = (connectionId: string, messages: ChatMessage[]) =>
invoke<ChatTurnResult>("chat_send", { connectionId, messages });

View File

@@ -134,14 +134,13 @@ export interface SavedQuery {
created_at: string;
}
export type AiProvider = "ollama" | "openai" | "anthropic" | "fireworks";
export type AiProvider = "ollama" | "fireworks" | "openrouter";
export interface AiSettings {
provider: AiProvider;
ollama_url: string;
openai_api_key?: string;
anthropic_api_key?: string;
fireworks_api_key?: string;
openrouter_api_key?: string;
model: string;
}
@@ -216,14 +215,3 @@ export interface ChatTurnResult {
messages: ChatMessage[];
usage: ContextUsage;
}
export type ChartType = "bar" | "line" | "area" | "pie";
export interface ChartConfig {
chart_type: ChartType;
x: string;
y: string;
group?: string | null;
title?: string | null;
orientation?: string | null;
}