refactor(ai): consolidate AI around chat tool-calling; add OpenRouter
- rework chat backend (chat.rs, chat_tools.rs, ai.rs, models, state) around tool calls - add OpenRouter provider alongside Ollama/Fireworks in settings - drop inline AiBar, ResultsPanel explain/fix UI and ChartPreview in favour of the chat panel - add frontend chat tool-registry
This commit is contained in:
@@ -1,103 +0,0 @@
|
||||
import { useState } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { AiSettingsPopover } from "./AiSettingsPopover";
|
||||
import { useGenerateSql } from "@/hooks/use-ai";
|
||||
import { Sparkles, Loader2, X, Eraser } from "lucide-react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
interface Props {
|
||||
connectionId: string;
|
||||
onSqlGenerated: (sql: string) => void;
|
||||
onClose: () => void;
|
||||
onExecute?: () => void;
|
||||
}
|
||||
|
||||
export function AiBar({ connectionId, onSqlGenerated, onClose, onExecute }: Props) {
|
||||
const [prompt, setPrompt] = useState("");
|
||||
const generateMutation = useGenerateSql();
|
||||
|
||||
const handleGenerate = () => {
|
||||
if (!prompt.trim() || generateMutation.isPending) return;
|
||||
generateMutation.mutate(
|
||||
{ connectionId, prompt },
|
||||
{
|
||||
onSuccess: (sql) => {
|
||||
onSqlGenerated(sql);
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI generation failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
if (e.key === "Enter" && (e.ctrlKey || e.metaKey)) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
onExecute?.();
|
||||
return;
|
||||
}
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
handleGenerate();
|
||||
return;
|
||||
}
|
||||
if (e.key === "Escape") {
|
||||
e.stopPropagation();
|
||||
onClose();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="tusk-ai-bar flex items-center gap-2 px-2 py-1.5 tusk-fade-in">
|
||||
<Sparkles className="h-3.5 w-3.5 shrink-0 tusk-ai-icon" />
|
||||
<Input
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Describe the query you want..."
|
||||
className="h-7 min-w-0 flex-1 border-tusk-purple/20 bg-tusk-purple/5 text-xs placeholder:text-muted-foreground/40 focus:border-tusk-purple/40 focus:ring-tusk-purple/20"
|
||||
autoFocus
|
||||
disabled={generateMutation.isPending}
|
||||
/>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
className="gap-1 text-[11px] text-tusk-purple hover:bg-tusk-purple/10 hover:text-tusk-purple"
|
||||
onClick={handleGenerate}
|
||||
disabled={generateMutation.isPending || !prompt.trim()}
|
||||
>
|
||||
{generateMutation.isPending ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
"Generate"
|
||||
)}
|
||||
</Button>
|
||||
{prompt.trim() && (
|
||||
<Button
|
||||
size="icon-xs"
|
||||
variant="ghost"
|
||||
onClick={() => setPrompt("")}
|
||||
title="Clear prompt"
|
||||
disabled={generateMutation.isPending}
|
||||
className="text-muted-foreground"
|
||||
>
|
||||
<Eraser className="h-3 w-3" />
|
||||
</Button>
|
||||
)}
|
||||
<AiSettingsPopover />
|
||||
<Button
|
||||
size="icon-xs"
|
||||
variant="ghost"
|
||||
onClick={onClose}
|
||||
title="Close AI bar"
|
||||
className="text-muted-foreground"
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -7,7 +7,11 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useFireworksModels, useOllamaModels } from "@/hooks/use-ai";
|
||||
import {
|
||||
useFireworksModels,
|
||||
useOllamaModels,
|
||||
useOpenRouterModels,
|
||||
} from "@/hooks/use-ai";
|
||||
import { RefreshCw, Loader2 } from "lucide-react";
|
||||
import type { AiProvider, OllamaModel } from "@/types";
|
||||
|
||||
@@ -17,6 +21,8 @@ interface Props {
|
||||
onOllamaUrlChange: (url: string) => void;
|
||||
fireworksApiKey: string;
|
||||
onFireworksApiKeyChange: (key: string) => void;
|
||||
openrouterApiKey: string;
|
||||
onOpenRouterApiKeyChange: (key: string) => void;
|
||||
model: string;
|
||||
onModelChange: (model: string) => void;
|
||||
}
|
||||
@@ -27,6 +33,8 @@ export function AiSettingsFields({
|
||||
onOllamaUrlChange,
|
||||
fireworksApiKey,
|
||||
onFireworksApiKeyChange,
|
||||
openrouterApiKey,
|
||||
onOpenRouterApiKeyChange,
|
||||
model,
|
||||
onModelChange,
|
||||
}: Props) {
|
||||
@@ -41,6 +49,17 @@ export function AiSettingsFields({
|
||||
);
|
||||
}
|
||||
|
||||
if (provider === "openrouter") {
|
||||
return (
|
||||
<OpenRouterFields
|
||||
apiKey={openrouterApiKey}
|
||||
onApiKeyChange={onOpenRouterApiKeyChange}
|
||||
model={model}
|
||||
onModelChange={onModelChange}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<OllamaFields
|
||||
ollamaUrl={ollamaUrl}
|
||||
@@ -143,6 +162,55 @@ function FireworksFields({
|
||||
);
|
||||
}
|
||||
|
||||
function OpenRouterFields({
|
||||
apiKey,
|
||||
onApiKeyChange,
|
||||
model,
|
||||
onModelChange,
|
||||
}: {
|
||||
apiKey: string;
|
||||
onApiKeyChange: (key: string) => void;
|
||||
model: string;
|
||||
onModelChange: (model: string) => void;
|
||||
}) {
|
||||
const {
|
||||
data: models,
|
||||
isLoading,
|
||||
isError,
|
||||
refetch,
|
||||
} = useOpenRouterModels(apiKey);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-1.5">
|
||||
<label className="text-xs text-muted-foreground">OpenRouter API key</label>
|
||||
<Input
|
||||
type="password"
|
||||
value={apiKey}
|
||||
onChange={(e) => onApiKeyChange(e.target.value)}
|
||||
placeholder="sk-or-..."
|
||||
className="h-8 text-xs"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<p className="text-[10px] text-muted-foreground/70">
|
||||
Stored locally; sent only to openrouter.ai.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<ModelDropdown
|
||||
models={models}
|
||||
loading={isLoading}
|
||||
errored={isError}
|
||||
errorText="Cannot reach OpenRouter (check API key)"
|
||||
onRefresh={() => refetch()}
|
||||
model={model}
|
||||
onModelChange={onModelChange}
|
||||
emptyHint={apiKey.trim() ? "Click ↻ to load models" : "Enter API key first"}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
function ModelDropdown({
|
||||
models,
|
||||
loading,
|
||||
|
||||
@@ -21,6 +21,7 @@ import type { AiProvider } from "@/types";
|
||||
const SUPPORTED_PROVIDERS: { value: AiProvider; label: string }[] = [
|
||||
{ value: "ollama", label: "Ollama (local)" },
|
||||
{ value: "fireworks", label: "Fireworks AI" },
|
||||
{ value: "openrouter", label: "OpenRouter" },
|
||||
];
|
||||
|
||||
export function AiSettingsPopover() {
|
||||
@@ -30,22 +31,16 @@ export function AiSettingsPopover() {
|
||||
const [provider, setProvider] = useState<AiProvider | null>(null);
|
||||
const [url, setUrl] = useState<string | null>(null);
|
||||
const [fireworksKey, setFireworksKey] = useState<string | null>(null);
|
||||
const [openrouterKey, setOpenrouterKey] = useState<string | null>(null);
|
||||
const [model, setModel] = useState<string | null>(null);
|
||||
|
||||
const settingsProvider = settings?.provider;
|
||||
// Hide unsupported legacy values (openai/anthropic) from the selector.
|
||||
const normalizedSettingsProvider: AiProvider | undefined =
|
||||
settingsProvider === "ollama" || settingsProvider === "fireworks"
|
||||
? settingsProvider
|
||||
: settingsProvider
|
||||
? "ollama"
|
||||
: undefined;
|
||||
|
||||
const currentProvider: AiProvider =
|
||||
provider ?? normalizedSettingsProvider ?? "ollama";
|
||||
provider ?? settings?.provider ?? "ollama";
|
||||
const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434";
|
||||
const currentFireworksKey =
|
||||
fireworksKey ?? settings?.fireworks_api_key ?? "";
|
||||
const currentOpenrouterKey =
|
||||
openrouterKey ?? settings?.openrouter_api_key ?? "";
|
||||
const currentModel = model ?? settings?.model ?? "";
|
||||
|
||||
const handleProviderChange = (next: AiProvider) => {
|
||||
@@ -64,6 +59,10 @@ export function AiSettingsPopover() {
|
||||
currentProvider === "fireworks"
|
||||
? currentFireworksKey.trim() || undefined
|
||||
: settings?.fireworks_api_key,
|
||||
openrouter_api_key:
|
||||
currentProvider === "openrouter"
|
||||
? currentOpenrouterKey.trim() || undefined
|
||||
: settings?.openrouter_api_key,
|
||||
model: currentModel,
|
||||
},
|
||||
{
|
||||
@@ -117,6 +116,8 @@ export function AiSettingsPopover() {
|
||||
onOllamaUrlChange={setUrl}
|
||||
fireworksApiKey={currentFireworksKey}
|
||||
onFireworksApiKeyChange={setFireworksKey}
|
||||
openrouterApiKey={currentOpenrouterKey}
|
||||
onOpenRouterApiKeyChange={setOpenrouterKey}
|
||||
model={currentModel}
|
||||
onModelChange={setModel}
|
||||
/>
|
||||
|
||||
@@ -1,327 +0,0 @@
|
||||
import { useMemo } from "react";
|
||||
import {
|
||||
Area,
|
||||
AreaChart,
|
||||
Bar,
|
||||
BarChart,
|
||||
CartesianGrid,
|
||||
Cell,
|
||||
Legend,
|
||||
Line,
|
||||
LineChart,
|
||||
Pie,
|
||||
PieChart,
|
||||
ResponsiveContainer,
|
||||
Tooltip,
|
||||
XAxis,
|
||||
YAxis,
|
||||
} from "recharts";
|
||||
import type { ChartConfig } from "@/types";
|
||||
|
||||
interface Props {
|
||||
config: ChartConfig;
|
||||
columns: string[];
|
||||
rows: unknown[][];
|
||||
height?: number;
|
||||
}
|
||||
|
||||
const PALETTE = [
|
||||
"#60a5fa", // blue-400
|
||||
"#34d399", // emerald-400
|
||||
"#fbbf24", // amber-400
|
||||
"#f87171", // red-400
|
||||
"#a78bfa", // violet-400
|
||||
"#22d3ee", // cyan-400
|
||||
"#fb923c", // orange-400
|
||||
"#f472b6", // pink-400
|
||||
];
|
||||
|
||||
const MAX_POINTS = 500;
|
||||
|
||||
export function ChartPreview({ config, columns, rows, height = 280 }: Props) {
|
||||
const xIdx = columns.indexOf(config.x);
|
||||
const yIdx = columns.indexOf(config.y);
|
||||
const groupIdx = config.group ? columns.indexOf(config.group) : -1;
|
||||
|
||||
const limited = useMemo(() => rows.slice(0, MAX_POINTS), [rows]);
|
||||
|
||||
if (xIdx < 0 || yIdx < 0) {
|
||||
return (
|
||||
<ChartFallback
|
||||
config={config}
|
||||
message={`Column not found: ${xIdx < 0 ? config.x : config.y}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Coerce y values to numbers; chart libs need numeric Y.
|
||||
const numericY = (v: unknown): number => {
|
||||
if (typeof v === "number") return v;
|
||||
if (typeof v === "string") {
|
||||
const n = parseFloat(v);
|
||||
return Number.isFinite(n) ? n : 0;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
const labelX = (v: unknown): string => {
|
||||
if (v == null) return "—";
|
||||
if (typeof v === "string") return v;
|
||||
if (typeof v === "number" || typeof v === "boolean") return String(v);
|
||||
return JSON.stringify(v);
|
||||
};
|
||||
|
||||
const isGrouped = groupIdx >= 0;
|
||||
|
||||
// ──────────── grouped data shape ────────────
|
||||
// For multi-series: pivot to { x: <xValue>, <group1>: yVal, <group2>: yVal, … }
|
||||
// Used by line, area, and grouped-bar.
|
||||
const pivoted = useMemo(() => {
|
||||
if (!isGrouped) return null;
|
||||
const map = new Map<string, Record<string, unknown>>();
|
||||
const groupSet = new Set<string>();
|
||||
for (const row of limited) {
|
||||
const xv = labelX(row[xIdx]);
|
||||
const gv = labelX(row[groupIdx!]);
|
||||
const yv = numericY(row[yIdx]);
|
||||
groupSet.add(gv);
|
||||
const acc = map.get(xv) ?? { _x: xv };
|
||||
acc[gv] = ((acc[gv] as number) ?? 0) + yv;
|
||||
map.set(xv, acc);
|
||||
}
|
||||
return {
|
||||
data: Array.from(map.values()),
|
||||
groups: Array.from(groupSet),
|
||||
};
|
||||
}, [isGrouped, limited, xIdx, yIdx, groupIdx]);
|
||||
|
||||
// Single series shape: [{ _x, _y }]
|
||||
const flat = useMemo(() => {
|
||||
return limited.map((row) => ({
|
||||
_x: labelX(row[xIdx]),
|
||||
_y: numericY(row[yIdx]),
|
||||
}));
|
||||
}, [limited, xIdx, yIdx]);
|
||||
|
||||
const tickStyle = {
|
||||
fill: "var(--muted-foreground)",
|
||||
fontSize: 10,
|
||||
} as const;
|
||||
|
||||
const axisLine = {
|
||||
stroke: "rgba(255, 255, 255, 0.08)",
|
||||
} as const;
|
||||
|
||||
const tooltipStyle = {
|
||||
backgroundColor: "var(--popover)",
|
||||
border: "1px solid var(--border)",
|
||||
borderRadius: 6,
|
||||
fontSize: 11,
|
||||
} as const;
|
||||
|
||||
if (config.chart_type === "pie") {
|
||||
// Pie: aggregate y by x label (sum), no group support.
|
||||
const agg = new Map<string, number>();
|
||||
for (const row of limited) {
|
||||
const xv = labelX(row[xIdx]);
|
||||
agg.set(xv, (agg.get(xv) ?? 0) + numericY(row[yIdx]));
|
||||
}
|
||||
const data = Array.from(agg.entries()).map(([name, value]) => ({ name, value }));
|
||||
return (
|
||||
<ChartFrame config={config} height={height} count={data.length} totalRows={rows.length}>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<PieChart>
|
||||
<Pie
|
||||
data={data}
|
||||
dataKey="value"
|
||||
nameKey="name"
|
||||
outerRadius={Math.min(height / 2.5, 110)}
|
||||
label={(entry) =>
|
||||
typeof entry.name === "string" && entry.name.length < 20 ? entry.name : ""
|
||||
}
|
||||
>
|
||||
{data.map((_, i) => (
|
||||
<Cell key={i} fill={PALETTE[i % PALETTE.length]} />
|
||||
))}
|
||||
</Pie>
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
<Legend
|
||||
wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }}
|
||||
verticalAlign="bottom"
|
||||
/>
|
||||
</PieChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
if (config.chart_type === "line") {
|
||||
return (
|
||||
<ChartFrame
|
||||
config={config}
|
||||
height={height}
|
||||
count={isGrouped ? pivoted!.data.length : flat.length}
|
||||
totalRows={rows.length}
|
||||
>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<LineChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
|
||||
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
|
||||
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
{isGrouped ? (
|
||||
<>
|
||||
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
|
||||
{pivoted!.groups.map((g, i) => (
|
||||
<Line
|
||||
key={g}
|
||||
type="monotone"
|
||||
dataKey={g}
|
||||
stroke={PALETTE[i % PALETTE.length]}
|
||||
strokeWidth={2}
|
||||
dot={false}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
<Line type="monotone" dataKey="_y" stroke={PALETTE[0]} strokeWidth={2} dot={false} />
|
||||
)}
|
||||
</LineChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
if (config.chart_type === "area") {
|
||||
return (
|
||||
<ChartFrame
|
||||
config={config}
|
||||
height={height}
|
||||
count={isGrouped ? pivoted!.data.length : flat.length}
|
||||
totalRows={rows.length}
|
||||
>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<AreaChart data={isGrouped ? pivoted!.data : flat} margin={{ top: 8, right: 12, left: 0, bottom: 4 }}>
|
||||
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={false} />
|
||||
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
{isGrouped ? (
|
||||
<>
|
||||
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
|
||||
{pivoted!.groups.map((g, i) => (
|
||||
<Area
|
||||
key={g}
|
||||
type="monotone"
|
||||
dataKey={g}
|
||||
stackId="1"
|
||||
stroke={PALETTE[i % PALETTE.length]}
|
||||
fill={PALETTE[i % PALETTE.length]}
|
||||
fillOpacity={0.35}
|
||||
/>
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
<Area
|
||||
type="monotone"
|
||||
dataKey="_y"
|
||||
stroke={PALETTE[0]}
|
||||
fill={PALETTE[0]}
|
||||
fillOpacity={0.35}
|
||||
/>
|
||||
)}
|
||||
</AreaChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
// bar (default)
|
||||
const horizontal = config.orientation === "horizontal";
|
||||
return (
|
||||
<ChartFrame
|
||||
config={config}
|
||||
height={height}
|
||||
count={isGrouped ? pivoted!.data.length : flat.length}
|
||||
totalRows={rows.length}
|
||||
>
|
||||
<ResponsiveContainer width="100%" height={height}>
|
||||
<BarChart
|
||||
layout={horizontal ? "vertical" : "horizontal"}
|
||||
data={isGrouped ? pivoted!.data : flat}
|
||||
margin={{ top: 8, right: 12, left: horizontal ? 24 : 0, bottom: 4 }}
|
||||
>
|
||||
<CartesianGrid stroke="rgba(255,255,255,0.05)" vertical={horizontal} horizontal={!horizontal} />
|
||||
{horizontal ? (
|
||||
<>
|
||||
<XAxis type="number" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis dataKey="_x" type="category" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} width={100} />
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<XAxis dataKey="_x" tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
<YAxis tick={tickStyle} axisLine={axisLine} tickLine={axisLine} />
|
||||
</>
|
||||
)}
|
||||
<Tooltip contentStyle={tooltipStyle} />
|
||||
{isGrouped ? (
|
||||
<>
|
||||
<Legend wrapperStyle={{ fontSize: 11, color: "var(--muted-foreground)" }} />
|
||||
{pivoted!.groups.map((g, i) => (
|
||||
<Bar key={g} dataKey={g} fill={PALETTE[i % PALETTE.length]} radius={[3, 3, 0, 0]} />
|
||||
))}
|
||||
</>
|
||||
) : (
|
||||
<Bar dataKey="_y" fill={PALETTE[0]} radius={[3, 3, 0, 0]} />
|
||||
)}
|
||||
</BarChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartFrame>
|
||||
);
|
||||
}
|
||||
|
||||
function ChartFrame({
|
||||
config,
|
||||
height,
|
||||
count,
|
||||
totalRows,
|
||||
children,
|
||||
}: {
|
||||
config: ChartConfig;
|
||||
height: number;
|
||||
count: number;
|
||||
totalRows: number;
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div className="rounded-md border border-border/40 bg-background">
|
||||
<div className="flex items-center gap-2 border-b border-border/30 px-2 py-1 text-[11px] text-muted-foreground">
|
||||
<span className="font-medium text-foreground/80">
|
||||
{config.title ?? `${capitalize(config.chart_type)} chart`}
|
||||
</span>
|
||||
<span className="ml-auto text-muted-foreground/60">
|
||||
{count} point{count === 1 ? "" : "s"}
|
||||
{totalRows > MAX_POINTS && ` (of ${totalRows}, capped at ${MAX_POINTS})`}
|
||||
</span>
|
||||
</div>
|
||||
<div className="p-2" style={{ minHeight: height }}>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ChartFallback({ config, message }: { config: ChartConfig; message: string }) {
|
||||
return (
|
||||
<div className="rounded-md border border-destructive/40 bg-destructive/5 p-3 text-xs">
|
||||
<div className="font-medium text-destructive">
|
||||
Chart {config.chart_type} failed
|
||||
</div>
|
||||
<div className="mt-1 text-muted-foreground">{message}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function capitalize(s: string) {
|
||||
return s.charAt(0).toUpperCase() + s.slice(1);
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useState } from "react";
|
||||
import { ResultsTable } from "@/components/results/ResultsTable";
|
||||
import { ExportDialog } from "@/components/export/ExportDialog";
|
||||
import { ChartPreview } from "./ChartPreview";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
@@ -15,19 +14,12 @@ import {
|
||||
AlertCircle,
|
||||
Sparkles,
|
||||
User,
|
||||
Wrench,
|
||||
Database,
|
||||
Columns,
|
||||
Layers,
|
||||
RefreshCw,
|
||||
StickyNote,
|
||||
Bookmark,
|
||||
BookmarkPlus,
|
||||
Maximize2,
|
||||
Download,
|
||||
BarChart3,
|
||||
} from "lucide-react";
|
||||
import type { ChartConfig, ChatMessage } from "@/types";
|
||||
import type { ChatMessage } from "@/types";
|
||||
import { getToolMeta, isQueryResultTool } from "./tool-registry";
|
||||
|
||||
interface Props {
|
||||
message: ChatMessage;
|
||||
@@ -79,8 +71,10 @@ function AssistantBubble({ text }: { text: string }) {
|
||||
|
||||
function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string }) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const preview = extractToolPreview(tool, inputJson);
|
||||
const Icon = iconForTool(tool);
|
||||
const meta = getToolMeta(tool);
|
||||
const preview = previewFromJson(tool, inputJson);
|
||||
const Icon = meta.icon;
|
||||
const showSqlPreview = (tool === "run_query" || tool === "explain_query") && preview;
|
||||
|
||||
return (
|
||||
<div className="ml-8 rounded-md border border-border/40 bg-muted/20">
|
||||
@@ -91,17 +85,14 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
|
||||
>
|
||||
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
|
||||
<Icon className="h-3 w-3" />
|
||||
<span className="font-medium">{labelForTool(tool)}</span>
|
||||
<span className="font-medium">{meta.label}</span>
|
||||
{preview && (
|
||||
<span className="ml-1 truncate text-muted-foreground/70">
|
||||
{preview.slice(0, 80)}
|
||||
{preview.length > 80 ? "…" : ""}
|
||||
</span>
|
||||
<span className="ml-1 truncate text-muted-foreground/70">{preview}</span>
|
||||
)}
|
||||
</button>
|
||||
{expanded && (
|
||||
<div className="border-t border-border/30 p-2">
|
||||
{tool === "run_query" && preview ? (
|
||||
{showSqlPreview ? (
|
||||
<pre className="overflow-x-auto whitespace-pre-wrap rounded bg-background/60 p-2 font-mono text-[11px]">
|
||||
{preview}
|
||||
</pre>
|
||||
@@ -116,6 +107,15 @@ function ToolCallBlock({ tool, inputJson }: { tool: string; inputJson: string })
|
||||
);
|
||||
}
|
||||
|
||||
function previewFromJson(tool: string, inputJson: string): string | null {
|
||||
try {
|
||||
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
|
||||
return getToolMeta(tool).preview(parsed);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function ToolResultBlock({
|
||||
tool,
|
||||
isError,
|
||||
@@ -132,87 +132,20 @@ function ToolResultBlock({
|
||||
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
|
||||
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
|
||||
<div>
|
||||
<div className="font-medium text-destructive">{labelForTool(tool)} failed</div>
|
||||
<div className="font-medium text-destructive">{getToolMeta(tool).label} failed</div>
|
||||
{text && <div className="mt-1 whitespace-pre-wrap text-muted-foreground">{text}</div>}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Legacy schema tool — keep a one-line indicator for old threads.
|
||||
if (tool === "get_schema") {
|
||||
return (
|
||||
<div className="ml-8 flex items-center gap-2 rounded-md border border-border/40 bg-muted/20 px-2 py-1.5 text-xs text-muted-foreground">
|
||||
<Database className="h-3 w-3" />
|
||||
<span>Loaded schema context ({text?.length ?? 0} chars)</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Text-only tools (chat v2/v3): list_databases, list_tables, get_columns, switch_database,
|
||||
// remember, save_query, find_queries.
|
||||
if (
|
||||
tool === "list_databases" ||
|
||||
tool === "list_tables" ||
|
||||
tool === "get_columns" ||
|
||||
tool === "switch_database" ||
|
||||
tool === "remember" ||
|
||||
tool === "save_query" ||
|
||||
tool === "find_queries"
|
||||
) {
|
||||
return <TextToolResult tool={tool} text={text} />;
|
||||
}
|
||||
|
||||
// make_chart — render chart inline using config from text + data from result.
|
||||
if (tool === "make_chart") {
|
||||
return <ChartToolResult text={text} result={result} />;
|
||||
}
|
||||
|
||||
// run_query — full results table with Open-full / Export actions.
|
||||
if (result) {
|
||||
// Tools that produce a QueryResult (rendered as a table): run_query, sample_data.
|
||||
if (isQueryResultTool(tool) && result) {
|
||||
return <RunQueryResultBlock result={result} />;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function ChartToolResult({
|
||||
text,
|
||||
result,
|
||||
}: {
|
||||
text: string | null;
|
||||
result: { columns: string[]; types: string[]; rows: unknown[][]; row_count: number; execution_time_ms: number } | null;
|
||||
}) {
|
||||
let config: ChartConfig | null = null;
|
||||
try {
|
||||
if (text) {
|
||||
config = JSON.parse(text) as ChartConfig;
|
||||
}
|
||||
} catch {
|
||||
config = null;
|
||||
}
|
||||
if (!config || !result) {
|
||||
return (
|
||||
<div className="ml-8 flex items-start gap-2 rounded-md border border-destructive/40 bg-destructive/5 px-3 py-2 text-xs">
|
||||
<AlertCircle className="mt-0.5 h-3.5 w-3.5 shrink-0 text-destructive" />
|
||||
<div>
|
||||
<div className="font-medium text-destructive">Chart unavailable</div>
|
||||
<div className="mt-1 text-muted-foreground">
|
||||
The agent referenced a chart but the previous query result is not attached.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return (
|
||||
<div className="ml-8">
|
||||
<ChartPreview
|
||||
config={config}
|
||||
columns={result.columns}
|
||||
rows={result.rows}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
// Everything else falls back to a collapsible text block.
|
||||
return <TextToolResult tool={tool} text={text} />;
|
||||
}
|
||||
|
||||
function RunQueryResultBlock({
|
||||
@@ -315,8 +248,10 @@ function RunQueryResultBlock({
|
||||
}
|
||||
|
||||
function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
|
||||
// Lazy preview: switch_database is short; everything else collapses by default.
|
||||
const [expanded, setExpanded] = useState(tool === "switch_database");
|
||||
const Icon = iconForTool(tool);
|
||||
const meta = getToolMeta(tool);
|
||||
const Icon = meta.icon;
|
||||
const lineCount = text ? text.split("\n").length : 0;
|
||||
|
||||
return (
|
||||
@@ -328,7 +263,7 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
|
||||
>
|
||||
{expanded ? <ChevronDown className="h-3 w-3" /> : <ChevronRight className="h-3 w-3" />}
|
||||
<Icon className="h-3 w-3" />
|
||||
<span className="font-medium">{labelForTool(tool)}</span>
|
||||
<span className="font-medium">{meta.label}</span>
|
||||
{text && (
|
||||
<span className="ml-1 text-muted-foreground/60">
|
||||
{lineCount} line{lineCount === 1 ? "" : "s"}
|
||||
@@ -346,93 +281,6 @@ function TextToolResult({ tool, text }: { tool: string; text: string | null }) {
|
||||
);
|
||||
}
|
||||
|
||||
function labelForTool(tool: string): string {
|
||||
switch (tool) {
|
||||
case "run_query":
|
||||
return "Run SQL";
|
||||
case "list_databases":
|
||||
return "List databases";
|
||||
case "list_tables":
|
||||
return "List tables";
|
||||
case "get_columns":
|
||||
return "Inspect columns";
|
||||
case "switch_database":
|
||||
return "Switch database";
|
||||
case "remember":
|
||||
return "Remember";
|
||||
case "save_query":
|
||||
return "Save query";
|
||||
case "find_queries":
|
||||
return "Find saved queries";
|
||||
case "make_chart":
|
||||
return "Make chart";
|
||||
case "get_schema":
|
||||
return "Load schema";
|
||||
default:
|
||||
return tool;
|
||||
}
|
||||
}
|
||||
|
||||
function iconForTool(tool: string) {
|
||||
switch (tool) {
|
||||
case "run_query":
|
||||
return Wrench;
|
||||
case "list_databases":
|
||||
return Database;
|
||||
case "list_tables":
|
||||
return Layers;
|
||||
case "get_columns":
|
||||
return Columns;
|
||||
case "switch_database":
|
||||
return RefreshCw;
|
||||
case "remember":
|
||||
return StickyNote;
|
||||
case "save_query":
|
||||
return BookmarkPlus;
|
||||
case "find_queries":
|
||||
return Bookmark;
|
||||
case "make_chart":
|
||||
return BarChart3;
|
||||
case "get_schema":
|
||||
return Database;
|
||||
default:
|
||||
return Wrench;
|
||||
}
|
||||
}
|
||||
|
||||
function extractToolPreview(tool: string, inputJson: string): string | null {
|
||||
try {
|
||||
const parsed = JSON.parse(inputJson) as Record<string, unknown>;
|
||||
switch (tool) {
|
||||
case "run_query":
|
||||
return typeof parsed.sql === "string" ? parsed.sql : null;
|
||||
case "list_tables":
|
||||
return typeof parsed.database === "string" ? parsed.database : null;
|
||||
case "switch_database":
|
||||
return typeof parsed.database === "string" ? parsed.database : null;
|
||||
case "get_columns":
|
||||
return Array.isArray(parsed.tables) ? parsed.tables.join(", ") : null;
|
||||
case "remember":
|
||||
return typeof parsed.note === "string" ? parsed.note : null;
|
||||
case "save_query":
|
||||
return typeof parsed.name === "string" ? parsed.name : null;
|
||||
case "find_queries":
|
||||
return typeof parsed.text === "string" ? parsed.text : null;
|
||||
case "make_chart": {
|
||||
const t = typeof parsed.chart_type === "string" ? parsed.chart_type : null;
|
||||
const x = typeof parsed.x === "string" ? parsed.x : null;
|
||||
const y = typeof parsed.y === "string" ? parsed.y : null;
|
||||
if (t && x && y) return `${t}: ${x} → ${y}`;
|
||||
return null;
|
||||
}
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function prettyJson(s: string): string {
|
||||
try {
|
||||
return JSON.stringify(JSON.parse(s), null, 2);
|
||||
|
||||
107
src/components/chat/tool-registry.ts
Normal file
107
src/components/chat/tool-registry.ts
Normal file
@@ -0,0 +1,107 @@
|
||||
import {
|
||||
Database,
|
||||
Layers,
|
||||
Columns,
|
||||
RefreshCw,
|
||||
Wrench,
|
||||
StickyNote,
|
||||
Bookmark,
|
||||
BookmarkPlus,
|
||||
Activity,
|
||||
Shuffle,
|
||||
GitBranch,
|
||||
AlertTriangle,
|
||||
} from "lucide-react";
|
||||
import type { LucideIcon } from "lucide-react";
|
||||
|
||||
export type ToolMeta = {
|
||||
icon: LucideIcon;
|
||||
label: string;
|
||||
preview: (parsed: Record<string, unknown>) => string | null;
|
||||
};
|
||||
|
||||
const truncate = (s: unknown, n = 80): string | null => {
|
||||
if (typeof s !== "string") return null;
|
||||
return s.length > n ? `${s.slice(0, n)}…` : s;
|
||||
};
|
||||
|
||||
export const TOOLS: Record<string, ToolMeta> = {
|
||||
list_databases: {
|
||||
icon: Database,
|
||||
label: "List databases",
|
||||
preview: () => null,
|
||||
},
|
||||
list_tables: {
|
||||
icon: Layers,
|
||||
label: "List tables",
|
||||
preview: (p) => (typeof p.database === "string" ? p.database : null),
|
||||
},
|
||||
get_columns: {
|
||||
icon: Columns,
|
||||
label: "Inspect columns",
|
||||
preview: (p) => (Array.isArray(p.tables) ? (p.tables as string[]).join(", ") : null),
|
||||
},
|
||||
switch_database: {
|
||||
icon: RefreshCw,
|
||||
label: "Switch database",
|
||||
preview: (p) => (typeof p.database === "string" ? p.database : null),
|
||||
},
|
||||
run_query: {
|
||||
icon: Wrench,
|
||||
label: "Run SQL",
|
||||
preview: (p) => truncate(p.sql),
|
||||
},
|
||||
remember: {
|
||||
icon: StickyNote,
|
||||
label: "Remember",
|
||||
preview: (p) => (typeof p.note === "string" ? p.note : null),
|
||||
},
|
||||
save_query: {
|
||||
icon: BookmarkPlus,
|
||||
label: "Save query",
|
||||
preview: (p) => (typeof p.name === "string" ? p.name : null),
|
||||
},
|
||||
find_queries: {
|
||||
icon: Bookmark,
|
||||
label: "Find saved queries",
|
||||
preview: (p) => (typeof p.text === "string" ? p.text : null),
|
||||
},
|
||||
profile_table: {
|
||||
icon: Activity,
|
||||
label: "Profile table",
|
||||
preview: (p) => (typeof p.table === "string" ? p.table : null),
|
||||
},
|
||||
sample_data: {
|
||||
icon: Shuffle,
|
||||
label: "Sample rows",
|
||||
preview: (p) => {
|
||||
const t = typeof p.table === "string" ? p.table : "";
|
||||
const limit = typeof p.limit === "number" ? p.limit : 50;
|
||||
return t ? `${t} (${limit})` : null;
|
||||
},
|
||||
},
|
||||
explain_query: {
|
||||
icon: GitBranch,
|
||||
label: "Explain query",
|
||||
preview: (p) => truncate(p.sql),
|
||||
},
|
||||
detect_skew: {
|
||||
icon: AlertTriangle,
|
||||
label: "Detect skew",
|
||||
preview: (p) => (typeof p.table === "string" ? p.table : null),
|
||||
},
|
||||
};
|
||||
|
||||
export function getToolMeta(tool: string): ToolMeta {
|
||||
return (
|
||||
TOOLS[tool] ?? {
|
||||
icon: Wrench,
|
||||
label: tool,
|
||||
preview: () => null,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export function isQueryResultTool(tool: string): boolean {
|
||||
return tool === "run_query" || tool === "sample_data";
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
import { ResultsTable } from "./ResultsTable";
|
||||
import { ResultsJsonView } from "./ResultsJsonView";
|
||||
import type { QueryResult } from "@/types";
|
||||
import { Loader2, AlertCircle, Sparkles, Wand2 } from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Loader2, AlertCircle } from "lucide-react";
|
||||
|
||||
interface Props {
|
||||
result?: QueryResult | null;
|
||||
@@ -15,10 +14,6 @@ interface Props {
|
||||
value: unknown
|
||||
) => void;
|
||||
highlightedCells?: Set<string>;
|
||||
aiExplanation?: string | null;
|
||||
isAiLoading?: boolean;
|
||||
onExplainError?: () => void;
|
||||
onFixError?: () => void;
|
||||
}
|
||||
|
||||
export function ResultsPanel({
|
||||
@@ -28,10 +23,6 @@ export function ResultsPanel({
|
||||
viewMode = "table",
|
||||
onCellDoubleClick,
|
||||
highlightedCells,
|
||||
aiExplanation,
|
||||
isAiLoading,
|
||||
onExplainError,
|
||||
onFixError,
|
||||
}: Props) {
|
||||
if (isLoading) {
|
||||
return (
|
||||
@@ -42,22 +33,6 @@ export function ResultsPanel({
|
||||
);
|
||||
}
|
||||
|
||||
if (aiExplanation) {
|
||||
return (
|
||||
<div className="h-full select-text overflow-auto p-4">
|
||||
<div className="rounded-md border bg-muted/30 p-4">
|
||||
<div className="mb-2 flex items-center gap-2 text-xs font-medium text-muted-foreground">
|
||||
<Sparkles className="h-3.5 w-3.5" />
|
||||
AI Explanation
|
||||
</div>
|
||||
<pre className="whitespace-pre-wrap font-sans text-sm leading-relaxed text-foreground">
|
||||
{aiExplanation}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="flex h-full select-text flex-col items-center justify-center gap-3 p-4">
|
||||
@@ -65,42 +40,6 @@ export function ResultsPanel({
|
||||
<AlertCircle className="mt-0.5 h-4 w-4 shrink-0" />
|
||||
<pre className="whitespace-pre-wrap font-mono text-xs">{error}</pre>
|
||||
</div>
|
||||
{(onExplainError || onFixError) && (
|
||||
<div className="flex items-center gap-2">
|
||||
{onExplainError && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
className="h-7 gap-1.5 text-xs"
|
||||
onClick={onExplainError}
|
||||
disabled={isAiLoading}
|
||||
>
|
||||
{isAiLoading ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
<Sparkles className="h-3 w-3" />
|
||||
)}
|
||||
Explain
|
||||
</Button>
|
||||
)}
|
||||
{onFixError && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
className="h-7 gap-1.5 text-xs"
|
||||
onClick={onFixError}
|
||||
disabled={isAiLoading}
|
||||
>
|
||||
{isAiLoading ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
<Wand2 className="h-3 w-3" />
|
||||
)}
|
||||
Fix with AI
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import type { AiProvider, AppSettings } from "@/types";
|
||||
const SUPPORTED_AI_PROVIDERS: { value: AiProvider; label: string }[] = [
|
||||
{ value: "ollama", label: "Ollama (local)" },
|
||||
{ value: "fireworks", label: "Fireworks AI" },
|
||||
{ value: "openrouter", label: "OpenRouter" },
|
||||
];
|
||||
|
||||
interface Props {
|
||||
@@ -50,6 +51,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
const [aiProvider, setAiProvider] = useState<AiProvider>("ollama");
|
||||
const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434");
|
||||
const [fireworksApiKey, setFireworksApiKey] = useState("");
|
||||
const [openrouterApiKey, setOpenrouterApiKey] = useState("");
|
||||
const [aiModel, setAiModel] = useState("");
|
||||
|
||||
const [copied, setCopied] = useState(false);
|
||||
@@ -70,10 +72,14 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
if (aiSettings) {
|
||||
// Legacy openai/anthropic values aren't user-selectable here — fall back to ollama.
|
||||
setAiProvider(
|
||||
aiSettings.provider === "fireworks" ? "fireworks" : "ollama"
|
||||
aiSettings.provider === "fireworks" ||
|
||||
aiSettings.provider === "openrouter"
|
||||
? aiSettings.provider
|
||||
: "ollama"
|
||||
);
|
||||
setOllamaUrl(aiSettings.ollama_url);
|
||||
setFireworksApiKey(aiSettings.fireworks_api_key ?? "");
|
||||
setOpenrouterApiKey(aiSettings.openrouter_api_key ?? "");
|
||||
setAiModel(aiSettings.model);
|
||||
}
|
||||
}
|
||||
@@ -115,6 +121,10 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
aiProvider === "fireworks"
|
||||
? fireworksApiKey.trim() || undefined
|
||||
: aiSettings?.fireworks_api_key,
|
||||
openrouter_api_key:
|
||||
aiProvider === "openrouter"
|
||||
? openrouterApiKey.trim() || undefined
|
||||
: aiSettings?.openrouter_api_key,
|
||||
model: aiModel,
|
||||
},
|
||||
{
|
||||
@@ -167,7 +177,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
<span
|
||||
className={`inline-block h-2 w-2 rounded-full ${
|
||||
mcpStatus?.running
|
||||
? "bg-green-500"
|
||||
? "bg-success ring-2 ring-success/25"
|
||||
: "bg-muted-foreground/30"
|
||||
}`}
|
||||
/>
|
||||
@@ -189,7 +199,7 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
title="Copy endpoint URL"
|
||||
>
|
||||
{copied ? (
|
||||
<Check className="h-3 w-3 text-green-500" />
|
||||
<Check className="h-3 w-3 text-success" />
|
||||
) : (
|
||||
<Copy className="h-3 w-3" />
|
||||
)}
|
||||
@@ -229,6 +239,8 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
|
||||
onOllamaUrlChange={setOllamaUrl}
|
||||
fireworksApiKey={fireworksApiKey}
|
||||
onFireworksApiKeyChange={setFireworksApiKey}
|
||||
openrouterApiKey={openrouterApiKey}
|
||||
onOpenRouterApiKeyChange={setOpenrouterApiKey}
|
||||
model={aiModel}
|
||||
onModelChange={setAiModel}
|
||||
/>
|
||||
|
||||
@@ -13,7 +13,7 @@ import { useCompletionSchema } from "@/hooks/use-completion-schema";
|
||||
import { useConnections } from "@/hooks/use-connections";
|
||||
import { useAppStore } from "@/stores/app-store";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces, Sparkles, BrainCircuit } from "lucide-react";
|
||||
import { Play, Loader2, Lock, BarChart3, Download, AlignLeft, Bookmark, Table2, Braces } from "lucide-react";
|
||||
import { format as formatSql } from "sql-formatter";
|
||||
import { SaveQueryDialog } from "@/components/saved-queries/SaveQueryDialog";
|
||||
import {
|
||||
@@ -25,8 +25,6 @@ import {
|
||||
import { exportCsv, exportJson } from "@/lib/tauri";
|
||||
import { save } from "@tauri-apps/plugin-dialog";
|
||||
import { toast } from "sonner";
|
||||
import { AiBar } from "@/components/ai/AiBar";
|
||||
import { useExplainSql, useFixSqlError } from "@/hooks/use-ai";
|
||||
import type { QueryResult, ExplainResult } from "@/types";
|
||||
|
||||
interface Props {
|
||||
@@ -53,12 +51,8 @@ export function WorkspacePanel({
|
||||
const [resultView, setResultView] = useState<"results" | "explain">("results");
|
||||
const [resultViewMode, setResultViewMode] = useState<"table" | "json">("table");
|
||||
const [saveDialogOpen, setSaveDialogOpen] = useState(false);
|
||||
const [aiBarOpen, setAiBarOpen] = useState(false);
|
||||
const [aiExplanation, setAiExplanation] = useState<string | null>(null);
|
||||
|
||||
const queryMutation = useQueryExecution();
|
||||
const explainMutation = useExplainSql();
|
||||
const fixMutation = useFixSqlError();
|
||||
const addHistoryMutation = useAddHistory();
|
||||
const { data: connections } = useConnections();
|
||||
const { data: completionSchema } = useCompletionSchema(connectionId);
|
||||
@@ -102,7 +96,6 @@ export function WorkspacePanel({
|
||||
if (!sqlValue.trim() || !connectionId) return;
|
||||
setError(null);
|
||||
setExplainData(null);
|
||||
setAiExplanation(null);
|
||||
setResultView("results");
|
||||
queryMutation.mutate(
|
||||
{ connectionId, sql: sqlValue },
|
||||
@@ -196,60 +189,6 @@ export function WorkspacePanel({
|
||||
[result]
|
||||
);
|
||||
|
||||
const isAiLoading = explainMutation.isPending || fixMutation.isPending;
|
||||
|
||||
const handleAiExplain = useCallback(() => {
|
||||
if (!sqlValue.trim() || !connectionId) return;
|
||||
setAiExplanation(null);
|
||||
setResultView("results");
|
||||
explainMutation.mutate(
|
||||
{ connectionId, sql: sqlValue },
|
||||
{
|
||||
onSuccess: (explanation) => {
|
||||
setAiExplanation(explanation);
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI Explain failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
}, [connectionId, sqlValue, explainMutation]);
|
||||
|
||||
const handleExplainError = useCallback(() => {
|
||||
if (!sqlValue.trim() || !connectionId || !error) return;
|
||||
setAiExplanation(null);
|
||||
explainMutation.mutate(
|
||||
{ connectionId, sql: `${sqlValue}\n\n-- Error: ${error}` },
|
||||
{
|
||||
onSuccess: (explanation) => {
|
||||
setAiExplanation(explanation);
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI Explain failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
}, [connectionId, sqlValue, error, explainMutation]);
|
||||
|
||||
const handleFixError = useCallback(() => {
|
||||
if (!sqlValue.trim() || !connectionId || !error) return;
|
||||
fixMutation.mutate(
|
||||
{ connectionId, sql: sqlValue, errorMessage: error },
|
||||
{
|
||||
onSuccess: (fixedSql) => {
|
||||
setSqlValue(fixedSql);
|
||||
onSqlChange?.(fixedSql);
|
||||
setError(null);
|
||||
setAiExplanation(null);
|
||||
toast.success("SQL replaced by AI suggestion");
|
||||
},
|
||||
onError: (err) => {
|
||||
toast.error("AI Fix failed", { description: String(err) });
|
||||
},
|
||||
}
|
||||
);
|
||||
}, [connectionId, sqlValue, error, fixMutation, onSqlChange]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ResizablePanelGroup orientation="vertical">
|
||||
@@ -308,35 +247,6 @@ export function WorkspacePanel({
|
||||
Save
|
||||
</Button>
|
||||
|
||||
<div className="mx-1 h-3.5 w-px bg-border/40" />
|
||||
|
||||
{/* AI actions group — purple-branded */}
|
||||
<Button
|
||||
size="xs"
|
||||
variant={aiBarOpen ? "secondary" : "ghost"}
|
||||
className={`gap-1 text-[11px] ${aiBarOpen ? "text-tusk-purple" : ""}`}
|
||||
onClick={() => setAiBarOpen(!aiBarOpen)}
|
||||
title="AI SQL Generator"
|
||||
>
|
||||
<Sparkles className={`h-3 w-3 ${aiBarOpen ? "tusk-ai-icon" : ""}`} />
|
||||
AI
|
||||
</Button>
|
||||
<Button
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
className="gap-1 text-[11px]"
|
||||
onClick={handleAiExplain}
|
||||
disabled={isAiLoading || !sqlValue.trim()}
|
||||
title="Explain query with AI"
|
||||
>
|
||||
{isAiLoading ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin" />
|
||||
) : (
|
||||
<BrainCircuit className="h-3 w-3" />
|
||||
)}
|
||||
AI Explain
|
||||
</Button>
|
||||
|
||||
{result && result.columns.length > 0 && (
|
||||
<>
|
||||
<div className="mx-1 h-3.5 w-px bg-border/40" />
|
||||
@@ -369,23 +279,12 @@ export function WorkspacePanel({
|
||||
{"\u2318"}Enter
|
||||
</span>
|
||||
{isReadOnly && (
|
||||
<span className="ml-2 flex items-center gap-1 rounded-sm bg-amber-500/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-amber-500">
|
||||
<span className="ml-2 flex items-center gap-1 rounded-sm bg-warning/10 px-1.5 py-0.5 text-[10px] font-semibold tracking-wide text-warning">
|
||||
<Lock className="h-2.5 w-2.5" />
|
||||
READ
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{aiBarOpen && (
|
||||
<AiBar
|
||||
connectionId={connectionId}
|
||||
onSqlGenerated={(sql) => {
|
||||
setSqlValue(sql);
|
||||
onSqlChange?.(sql);
|
||||
}}
|
||||
onClose={() => setAiBarOpen(false)}
|
||||
onExecute={handleExecute}
|
||||
/>
|
||||
)}
|
||||
<div className="min-h-0 flex-1">
|
||||
<SqlEditor
|
||||
value={sqlValue}
|
||||
@@ -400,7 +299,7 @@ export function WorkspacePanel({
|
||||
<ResizableHandle withHandle />
|
||||
<ResizablePanel id="results" defaultSize="60%" minSize="15%">
|
||||
<div className="flex h-full flex-col overflow-hidden">
|
||||
{(explainData || result || error || aiExplanation) && (
|
||||
{(explainData || result || error) && (
|
||||
<div className="flex shrink-0 items-center border-b border-border/40 text-xs">
|
||||
<button
|
||||
className={`relative px-3 py-1.5 font-medium transition-colors ${
|
||||
@@ -469,10 +368,6 @@ export function WorkspacePanel({
|
||||
error={error}
|
||||
isLoading={queryMutation.isPending && resultView === "results"}
|
||||
viewMode={resultViewMode}
|
||||
aiExplanation={aiExplanation}
|
||||
isAiLoading={isAiLoading}
|
||||
onExplainError={error ? handleExplainError : undefined}
|
||||
onFixError={error ? handleFixError : undefined}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -4,9 +4,7 @@ import {
|
||||
saveAiSettings,
|
||||
listOllamaModels,
|
||||
listFireworksModels,
|
||||
generateSql,
|
||||
explainSql,
|
||||
fixSqlError,
|
||||
listOpenRouterModels,
|
||||
} from "@/lib/tauri";
|
||||
import type { AiSettings } from "@/types";
|
||||
|
||||
@@ -47,40 +45,12 @@ export function useFireworksModels(apiKey: string | undefined) {
|
||||
});
|
||||
}
|
||||
|
||||
export function useGenerateSql() {
|
||||
return useMutation({
|
||||
mutationFn: ({
|
||||
connectionId,
|
||||
prompt,
|
||||
}: {
|
||||
connectionId: string;
|
||||
prompt: string;
|
||||
}) => generateSql(connectionId, prompt),
|
||||
});
|
||||
}
|
||||
|
||||
export function useExplainSql() {
|
||||
return useMutation({
|
||||
mutationFn: ({
|
||||
connectionId,
|
||||
sql,
|
||||
}: {
|
||||
connectionId: string;
|
||||
sql: string;
|
||||
}) => explainSql(connectionId, sql),
|
||||
});
|
||||
}
|
||||
|
||||
export function useFixSqlError() {
|
||||
return useMutation({
|
||||
mutationFn: ({
|
||||
connectionId,
|
||||
sql,
|
||||
errorMessage,
|
||||
}: {
|
||||
connectionId: string;
|
||||
sql: string;
|
||||
errorMessage: string;
|
||||
}) => fixSqlError(connectionId, sql, errorMessage),
|
||||
export function useOpenRouterModels(apiKey: string | undefined) {
|
||||
return useQuery({
|
||||
queryKey: ["openrouter-models", apiKey],
|
||||
queryFn: () => listOpenRouterModels(apiKey!),
|
||||
enabled: !!apiKey && apiKey.trim().length > 0,
|
||||
retry: false,
|
||||
staleTime: 60_000,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -214,14 +214,8 @@ export const listOllamaModels = (ollamaUrl: string) =>
|
||||
export const listFireworksModels = (apiKey: string) =>
|
||||
invoke<OllamaModel[]>("list_fireworks_models", { apiKey });
|
||||
|
||||
export const generateSql = (connectionId: string, prompt: string) =>
|
||||
invoke<string>("generate_sql", { connectionId, prompt });
|
||||
|
||||
export const explainSql = (connectionId: string, sql: string) =>
|
||||
invoke<string>("explain_sql", { connectionId, sql });
|
||||
|
||||
export const fixSqlError = (connectionId: string, sql: string, errorMessage: string) =>
|
||||
invoke<string>("fix_sql_error", { connectionId, sql, errorMessage });
|
||||
export const listOpenRouterModels = (apiKey: string) =>
|
||||
invoke<OllamaModel[]>("list_openrouter_models", { apiKey });
|
||||
|
||||
export const chatSend = (connectionId: string, messages: ChatMessage[]) =>
|
||||
invoke<ChatTurnResult>("chat_send", { connectionId, messages });
|
||||
|
||||
@@ -134,14 +134,13 @@ export interface SavedQuery {
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export type AiProvider = "ollama" | "openai" | "anthropic" | "fireworks";
|
||||
export type AiProvider = "ollama" | "fireworks" | "openrouter";
|
||||
|
||||
export interface AiSettings {
|
||||
provider: AiProvider;
|
||||
ollama_url: string;
|
||||
openai_api_key?: string;
|
||||
anthropic_api_key?: string;
|
||||
fireworks_api_key?: string;
|
||||
openrouter_api_key?: string;
|
||||
model: string;
|
||||
}
|
||||
|
||||
@@ -216,14 +215,3 @@ export interface ChatTurnResult {
|
||||
messages: ChatMessage[];
|
||||
usage: ContextUsage;
|
||||
}
|
||||
|
||||
export type ChartType = "bar" | "line" | "area" | "pie";
|
||||
|
||||
export interface ChartConfig {
|
||||
chart_type: ChartType;
|
||||
x: string;
|
||||
y: string;
|
||||
group?: string | null;
|
||||
title?: string | null;
|
||||
orientation?: string | null;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user