feat: add Fireworks AI provider for chat agent

Routes chat-completions through a managed OpenAI-compatible inference
endpoint as an alternative to local Ollama, useful when the agent needs
fast multi-hop reasoning that local hardware can't sustain.

- backend: rename `call_ollama_chat_messages` → `call_chat_messages`,
  dispatch by provider; add `call_fireworks` branch (Bearer auth,
  `response_format: json_object` mapped from internal `format="json"`)
  and `list_fireworks_models` Tauri command
- settings: extend `AiProvider` enum + `AiSettings.fireworks_api_key`
  (serde-default for legacy config compat); Fireworks base URL hardcoded
- UI: provider selector in both popover and AppSettingsSheet (only
  ollama+fireworks shown; legacy openai/anthropic kept for serde-compat
  but normalized to ollama in UI); password input + dynamic model list
  for Fireworks; switching provider clears stale model selection
- 4 unit tests: serde round-trip, legacy settings deserialization,
  Fireworks chat-completions parsing, models-list parsing
This commit is contained in:
2026-05-06 23:04:10 +03:00
parent 532ebf3b44
commit 96a54edcd0
10 changed files with 524 additions and 65 deletions

View File

@@ -1,6 +1,7 @@
use crate::error::{TuskError, TuskResult}; use crate::error::{TuskError, TuskResult};
use crate::models::ai::{ use crate::models::ai::{
AiProvider, AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel, AiProvider, AiSettings, FireworksChatRequest, FireworksChatResponse, FireworksModelsResponse,
FireworksResponseFormat, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel,
OllamaTagsResponse, OllamaTagsResponse,
}; };
use crate::state::{AppState, DbFlavor}; use crate::state::{AppState, DbFlavor};
@@ -13,6 +14,7 @@ use tauri::{AppHandle, Manager, State};
const MAX_RETRIES: u32 = 2; const MAX_RETRIES: u32 = 2;
const RETRY_DELAY_MS: u64 = 1000; const RETRY_DELAY_MS: u64 = 1000;
const FIREWORKS_BASE_URL: &str = "https://api.fireworks.ai/inference/v1";
fn http_client() -> &'static reqwest::Client { fn http_client() -> &'static reqwest::Client {
use std::sync::LazyLock; use std::sync::LazyLock;
@@ -85,6 +87,42 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaMode
Ok(tags.models) Ok(tags.models)
} }
#[tauri::command]
pub async fn list_fireworks_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
let key = api_key.trim();
if key.is_empty() {
return Err(TuskError::Ai("Fireworks API key required".to_string()));
}
let url = format!("{}/models", FIREWORKS_BASE_URL);
let resp = http_client()
.get(&url)
.bearer_auth(key)
.send()
.await
.map_err(|e| TuskError::Ai(format!("Cannot reach Fireworks: {}", e)))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}",
status, body
)));
}
let parsed: FireworksModelsResponse = resp
.json()
.await
.map_err(|e| TuskError::Ai(format!("Failed to parse Fireworks models list: {}", e)))?;
Ok(parsed
.data
.into_iter()
.map(|m| OllamaModel { name: m.id })
.collect())
}
async fn call_ai_with_retry<F, Fut, T>( async fn call_ai_with_retry<F, Fut, T>(
_settings: &AiSettings, _settings: &AiSettings,
operation: &str, operation: &str,
@@ -142,13 +180,13 @@ pub(crate) async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskR
Ok(settings) Ok(settings)
} }
async fn call_ollama_chat( async fn call_chat_simple(
app: &AppHandle, app: &AppHandle,
state: &AppState, state: &AppState,
system_prompt: String, system_prompt: String,
user_content: String, user_content: String,
) -> TuskResult<String> { ) -> TuskResult<String> {
call_ollama_chat_messages( call_chat_messages(
app, app,
state, state,
vec![ vec![
@@ -166,7 +204,10 @@ async fn call_ollama_chat(
.await .await
} }
pub(crate) async fn call_ollama_chat_messages( /// Provider-agnostic chat-completions dispatcher used by every LLM-driven feature
/// (chat agent, generate_sql, explain_sql, fix_sql_error). Returns the model's
/// raw text content.
pub(crate) async fn call_chat_messages(
app: &AppHandle, app: &AppHandle,
state: &AppState, state: &AppState,
messages: Vec<OllamaChatMessage>, messages: Vec<OllamaChatMessage>,
@@ -180,24 +221,30 @@ pub(crate) async fn call_ollama_chat_messages(
)); ));
} }
if settings.provider != AiProvider::Ollama { match settings.provider {
return Err(TuskError::Ai(format!( AiProvider::Ollama => call_ollama(&settings, messages, format).await,
AiProvider::Fireworks => call_fireworks(&settings, messages, format).await,
AiProvider::OpenAi | AiProvider::Anthropic => Err(TuskError::Ai(format!(
"Provider {:?} not implemented yet", "Provider {:?} not implemented yet",
settings.provider settings.provider
))); ))),
} }
}
let model = settings.model.clone(); async fn call_ollama(
settings: &AiSettings,
messages: Vec<OllamaChatMessage>,
format: Option<String>,
) -> TuskResult<String> {
let url = format!("{}/api/chat", settings.ollama_url.trim_end_matches('/')); let url = format!("{}/api/chat", settings.ollama_url.trim_end_matches('/'));
let request = OllamaChatRequest { let request = OllamaChatRequest {
model: model.clone(), model: settings.model.clone(),
messages, messages,
stream: false, stream: false,
format, format,
}; };
call_ai_with_retry(&settings, "Ollama request", || { call_ai_with_retry(settings, "Ollama request", || {
let url = url.clone(); let url = url.clone();
let request = request.clone(); let request = request.clone();
async move { async move {
@@ -230,6 +277,75 @@ pub(crate) async fn call_ollama_chat_messages(
.await .await
} }
async fn call_fireworks(
settings: &AiSettings,
messages: Vec<OllamaChatMessage>,
format: Option<String>,
) -> TuskResult<String> {
let api_key = settings
.fireworks_api_key
.clone()
.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.ok_or_else(|| {
TuskError::Ai("Fireworks API key not set. Open AI settings to add it.".to_string())
})?;
let url = format!("{}/chat/completions", FIREWORKS_BASE_URL);
let response_format = format.as_deref().map(|f| FireworksResponseFormat {
kind: if f == "json" {
"json_object".to_string()
} else {
f.to_string()
},
});
let request = FireworksChatRequest {
model: settings.model.clone(),
messages,
temperature: 0.0,
response_format,
};
call_ai_with_retry(settings, "Fireworks request", || {
let url = url.clone();
let request = request.clone();
let api_key = api_key.clone();
async move {
let resp = http_client()
.post(&url)
.bearer_auth(&api_key)
.json(&request)
.send()
.await
.map_err(|e| {
TuskError::Ai(format!("Cannot reach Fireworks at {}: {}", url, e))
})?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}",
status, body
)));
}
let parsed: FireworksChatResponse = resp.json().await.map_err(|e| {
TuskError::Ai(format!("Failed to parse Fireworks response: {}", e))
})?;
parsed
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| TuskError::Ai("Fireworks returned no choices".to_string()))
}
})
.await
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// SQL generation // SQL generation
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -310,7 +426,7 @@ pub async fn generate_sql(
schema_text schema_text
); );
let raw = call_ollama_chat(&app, &state, system_prompt, prompt).await?; let raw = call_chat_simple(&app, &state, system_prompt, prompt).await?;
Ok(clean_sql_response(&raw)) Ok(clean_sql_response(&raw))
} }
@@ -347,7 +463,7 @@ pub async fn explain_sql(
schema_text schema_text
); );
call_ollama_chat(&app, &state, system_prompt, sql).await call_chat_simple(&app, &state, system_prompt, sql).await
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -391,7 +507,7 @@ pub async fn fix_sql_error(
let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message); let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);
let raw = call_ollama_chat(&app, &state, system_prompt, user_content).await?; let raw = call_chat_simple(&app, &state, system_prompt, user_content).await?;
Ok(clean_sql_response(&raw)) Ok(clean_sql_response(&raw))
} }
@@ -1453,4 +1569,60 @@ mod tests {
"SELECT\n *\nFROM users" "SELECT\n *\nFROM users"
); );
} }
// ── Fireworks provider ───────────────────────────────────
#[test]
fn serializes_fireworks_provider() {
let json = serde_json::to_string(&AiProvider::Fireworks).unwrap();
assert_eq!(json, "\"fireworks\"");
}
#[test]
fn deserializes_legacy_settings_without_fireworks_key() {
// Old config files won't have `fireworks_api_key` — must still parse.
let legacy = r#"{
"provider": "ollama",
"ollama_url": "http://localhost:11434",
"openai_api_key": null,
"anthropic_api_key": null,
"model": "qwen2.5-coder:7b"
}"#;
let parsed: AiSettings = serde_json::from_str(legacy).unwrap();
assert_eq!(parsed.provider, AiProvider::Ollama);
assert_eq!(parsed.ollama_url, "http://localhost:11434");
assert!(parsed.fireworks_api_key.is_none());
assert_eq!(parsed.model, "qwen2.5-coder:7b");
}
#[test]
fn parses_fireworks_chat_response() {
let body = r#"{
"choices": [
{"message": {"role": "assistant", "content": "hi"}}
]
}"#;
let parsed: FireworksChatResponse = serde_json::from_str(body).unwrap();
assert_eq!(parsed.choices.len(), 1);
assert_eq!(parsed.choices[0].message.role, "assistant");
assert_eq!(parsed.choices[0].message.content, "hi");
}
#[test]
fn parses_fireworks_models_list() {
let body = r#"{
"data": [
{"id": "accounts/fireworks/models/qwen2p5-coder-32b-instruct"},
{"id": "accounts/fireworks/models/deepseek-v3"}
]
}"#;
let parsed: FireworksModelsResponse = serde_json::from_str(body).unwrap();
let names: Vec<String> = parsed.data.into_iter().map(|m| m.id).collect();
assert_eq!(names.len(), 2);
assert_eq!(
names[0],
"accounts/fireworks/models/qwen2p5-coder-32b-instruct"
);
assert_eq!(names[1], "accounts/fireworks/models/deepseek-v3");
}
} }

View File

@@ -1,4 +1,4 @@
use crate::commands::ai::{build_overview_context, call_ollama_chat_messages}; use crate::commands::ai::{build_overview_context, call_chat_messages};
use crate::commands::chat_tools::{ use crate::commands::chat_tools::{
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool, find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
switch_database_tool, switch_database_tool,
@@ -555,7 +555,7 @@ pub async fn chat_send(
let history = build_history(&working, &overview_text, &memory_text); let history = build_history(&working, &overview_text, &memory_text);
let raw = let raw =
call_ollama_chat_messages(&app, &state, history, Some("json".to_string())).await?; call_chat_messages(&app, &state, history, Some("json".to_string())).await?;
let trimmed = raw.trim(); let trimmed = raw.trim();
let action = match parse_agent_action(trimmed) { let action = match parse_agent_action(trimmed) {
@@ -1039,7 +1039,7 @@ async fn force_final_synthesis(
content: convo, content: convo,
}, },
]; ];
match call_ollama_chat_messages(app, state, llm_messages, None).await { match call_chat_messages(app, state, llm_messages, None).await {
Ok(s) => { Ok(s) => {
let cleaned = clean_summary(&s); let cleaned = clean_summary(&s);
if cleaned.trim().is_empty() { if cleaned.trim().is_empty() {
@@ -1148,7 +1148,7 @@ pub async fn chat_compact(
content: convo, content: convo,
}, },
]; ];
let summary = call_ollama_chat_messages(&app, &state, llm_messages, None) let summary = call_chat_messages(&app, &state, llm_messages, None)
.await .await
.map_err(|e| TuskError::Ai(format!("Compact failed: {}", e)))?; .map_err(|e| TuskError::Ai(format!("Compact failed: {}", e)))?;

View File

@@ -110,6 +110,7 @@ pub fn run() {
commands::ai::get_ai_settings, commands::ai::get_ai_settings,
commands::ai::save_ai_settings, commands::ai::save_ai_settings,
commands::ai::list_ollama_models, commands::ai::list_ollama_models,
commands::ai::list_fireworks_models,
commands::ai::generate_sql, commands::ai::generate_sql,
commands::ai::explain_sql, commands::ai::explain_sql,
commands::ai::fix_sql_error, commands::ai::fix_sql_error,

View File

@@ -7,14 +7,19 @@ pub enum AiProvider {
Ollama, Ollama,
OpenAi, OpenAi,
Anthropic, Anthropic,
Fireworks,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiSettings { pub struct AiSettings {
pub provider: AiProvider, pub provider: AiProvider,
pub ollama_url: String, pub ollama_url: String,
#[serde(default)]
pub openai_api_key: Option<String>, pub openai_api_key: Option<String>,
#[serde(default)]
pub anthropic_api_key: Option<String>, pub anthropic_api_key: Option<String>,
#[serde(default)]
pub fireworks_api_key: Option<String>,
pub model: String, pub model: String,
} }
@@ -25,11 +30,14 @@ impl Default for AiSettings {
ollama_url: "http://localhost:11434".to_string(), ollama_url: "http://localhost:11434".to_string(),
openai_api_key: None, openai_api_key: None,
anthropic_api_key: None, anthropic_api_key: None,
fireworks_api_key: None,
model: String::new(), model: String::new(),
} }
} }
} }
/// Generic chat message used by all chat providers (Ollama, Fireworks, OpenAI-compatible).
/// `{role, content}` shape is identical across these APIs.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaChatMessage { pub struct OllamaChatMessage {
pub role: String, pub role: String,
@@ -55,7 +63,48 @@ pub struct OllamaTagsResponse {
pub models: Vec<OllamaModel>, pub models: Vec<OllamaModel>,
} }
/// Generic chat-model descriptor exposed to the UI dropdown.
/// Reused as the return shape for both Ollama and Fireworks model listings.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModel { pub struct OllamaModel {
pub name: String, pub name: String,
} }
// ---------------------------------------------------------------------------
// Fireworks (OpenAI-compatible chat-completions)
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize)]
pub struct FireworksChatRequest {
pub model: String,
pub messages: Vec<OllamaChatMessage>,
pub temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<FireworksResponseFormat>,
}
#[derive(Debug, Clone, Serialize)]
pub struct FireworksResponseFormat {
#[serde(rename = "type")]
pub kind: String,
}
#[derive(Debug, Deserialize)]
pub struct FireworksChatResponse {
pub choices: Vec<FireworksChoice>,
}
#[derive(Debug, Deserialize)]
pub struct FireworksChoice {
pub message: OllamaChatMessage,
}
#[derive(Debug, Deserialize)]
pub struct FireworksModelsResponse {
pub data: Vec<FireworksModelEntry>,
}
#[derive(Debug, Deserialize)]
pub struct FireworksModelEntry {
pub id: String,
}

View File

@@ -7,27 +7,66 @@ import {
SelectTrigger, SelectTrigger,
SelectValue, SelectValue,
} from "@/components/ui/select"; } from "@/components/ui/select";
import { useOllamaModels } from "@/hooks/use-ai"; import { useFireworksModels, useOllamaModels } from "@/hooks/use-ai";
import { RefreshCw, Loader2 } from "lucide-react"; import { RefreshCw, Loader2 } from "lucide-react";
import type { AiProvider, OllamaModel } from "@/types";
interface Props { interface Props {
provider: AiProvider;
ollamaUrl: string; ollamaUrl: string;
onOllamaUrlChange: (url: string) => void; onOllamaUrlChange: (url: string) => void;
fireworksApiKey: string;
onFireworksApiKeyChange: (key: string) => void;
model: string; model: string;
onModelChange: (model: string) => void; onModelChange: (model: string) => void;
} }
export function AiSettingsFields({ export function AiSettingsFields({
provider,
ollamaUrl,
onOllamaUrlChange,
fireworksApiKey,
onFireworksApiKeyChange,
model,
onModelChange,
}: Props) {
if (provider === "fireworks") {
return (
<FireworksFields
apiKey={fireworksApiKey}
onApiKeyChange={onFireworksApiKeyChange}
model={model}
onModelChange={onModelChange}
/>
);
}
return (
<OllamaFields
ollamaUrl={ollamaUrl}
onOllamaUrlChange={onOllamaUrlChange}
model={model}
onModelChange={onModelChange}
/>
);
}
function OllamaFields({
ollamaUrl, ollamaUrl,
onOllamaUrlChange, onOllamaUrlChange,
model, model,
onModelChange, onModelChange,
}: Props) { }: {
ollamaUrl: string;
onOllamaUrlChange: (url: string) => void;
model: string;
onModelChange: (model: string) => void;
}) {
const { const {
data: models, data: models,
isLoading: modelsLoading, isLoading,
isError: modelsError, isError,
refetch: refetchModels, refetch,
} = useOllamaModels(ollamaUrl); } = useOllamaModels(ollamaUrl);
return ( return (
@@ -42,41 +81,122 @@ export function AiSettingsFields({
/> />
</div> </div>
<div className="flex flex-col gap-1.5"> <ModelDropdown
<div className="flex items-center justify-between"> models={models}
<label className="text-xs text-muted-foreground">Model</label> loading={isLoading}
<Button errored={isError}
size="sm" errorText="Cannot connect to Ollama"
variant="ghost" onRefresh={() => refetch()}
className="h-5 w-5 p-0" model={model}
onClick={() => refetchModels()} onModelChange={onModelChange}
disabled={modelsLoading} />
title="Refresh models"
>
{modelsLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<RefreshCw className="h-3 w-3" />
)}
</Button>
</div>
{modelsError ? (
<p className="text-xs text-destructive">Cannot connect to Ollama</p>
) : (
<Select value={model} onValueChange={onModelChange}>
<SelectTrigger className="h-8 w-full text-xs">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{models?.map((m) => (
<SelectItem key={m.name} value={m.name}>
{m.name}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
</> </>
); );
} }
function FireworksFields({
apiKey,
onApiKeyChange,
model,
onModelChange,
}: {
apiKey: string;
onApiKeyChange: (key: string) => void;
model: string;
onModelChange: (model: string) => void;
}) {
const {
data: models,
isLoading,
isError,
refetch,
} = useFireworksModels(apiKey);
return (
<>
<div className="flex flex-col gap-1.5">
<label className="text-xs text-muted-foreground">Fireworks API key</label>
<Input
type="password"
value={apiKey}
onChange={(e) => onApiKeyChange(e.target.value)}
placeholder="fw_..."
className="h-8 text-xs"
autoComplete="off"
/>
<p className="text-[10px] text-muted-foreground/70">
Stored locally; sent only to api.fireworks.ai.
</p>
</div>
<ModelDropdown
models={models}
loading={isLoading}
errored={isError}
errorText="Cannot reach Fireworks (check API key)"
onRefresh={() => refetch()}
model={model}
onModelChange={onModelChange}
emptyHint={apiKey.trim() ? "Click ↻ to load models" : "Enter API key first"}
/>
</>
);
}
function ModelDropdown({
models,
loading,
errored,
errorText,
onRefresh,
model,
onModelChange,
emptyHint,
}: {
models: OllamaModel[] | undefined;
loading: boolean;
errored: boolean;
errorText: string;
onRefresh: () => void;
model: string;
onModelChange: (model: string) => void;
emptyHint?: string;
}) {
return (
<div className="flex flex-col gap-1.5">
<div className="flex items-center justify-between">
<label className="text-xs text-muted-foreground">Model</label>
<Button
size="sm"
variant="ghost"
className="h-5 w-5 p-0"
onClick={onRefresh}
disabled={loading}
title="Refresh models"
>
{loading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<RefreshCw className="h-3 w-3" />
)}
</Button>
</div>
{errored ? (
<p className="text-xs text-destructive">{errorText}</p>
) : (
<Select value={model} onValueChange={onModelChange}>
<SelectTrigger className="h-8 w-full text-xs">
<SelectValue placeholder={emptyHint ?? "Select a model"} />
</SelectTrigger>
<SelectContent>
{models?.map((m) => (
<SelectItem key={m.name} value={m.name}>
{m.name}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
);
}

View File

@@ -4,25 +4,68 @@ import {
PopoverContent, PopoverContent,
PopoverTrigger, PopoverTrigger,
} from "@/components/ui/popover"; } from "@/components/ui/popover";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { useAiSettings, useSaveAiSettings } from "@/hooks/use-ai"; import { useAiSettings, useSaveAiSettings } from "@/hooks/use-ai";
import { Settings } from "lucide-react"; import { Settings } from "lucide-react";
import { toast } from "sonner"; import { toast } from "sonner";
import { AiSettingsFields } from "./AiSettingsFields"; import { AiSettingsFields } from "./AiSettingsFields";
import type { AiProvider } from "@/types";
const SUPPORTED_PROVIDERS: { value: AiProvider; label: string }[] = [
{ value: "ollama", label: "Ollama (local)" },
{ value: "fireworks", label: "Fireworks AI" },
];
export function AiSettingsPopover() { export function AiSettingsPopover() {
const { data: settings } = useAiSettings(); const { data: settings } = useAiSettings();
const saveMutation = useSaveAiSettings(); const saveMutation = useSaveAiSettings();
const [provider, setProvider] = useState<AiProvider | null>(null);
const [url, setUrl] = useState<string | null>(null); const [url, setUrl] = useState<string | null>(null);
const [fireworksKey, setFireworksKey] = useState<string | null>(null);
const [model, setModel] = useState<string | null>(null); const [model, setModel] = useState<string | null>(null);
const settingsProvider = settings?.provider;
// Hide unsupported legacy values (openai/anthropic) from the selector.
const normalizedSettingsProvider: AiProvider | undefined =
settingsProvider === "ollama" || settingsProvider === "fireworks"
? settingsProvider
: settingsProvider
? "ollama"
: undefined;
const currentProvider: AiProvider =
provider ?? normalizedSettingsProvider ?? "ollama";
const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434"; const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434";
const currentFireworksKey =
fireworksKey ?? settings?.fireworks_api_key ?? "";
const currentModel = model ?? settings?.model ?? ""; const currentModel = model ?? settings?.model ?? "";
const handleProviderChange = (next: AiProvider) => {
if (next === currentProvider) return;
setProvider(next);
// Model lists differ between providers — drop the previous selection.
setModel("");
};
const handleSave = () => { const handleSave = () => {
saveMutation.mutate( saveMutation.mutate(
{ provider: "ollama", ollama_url: currentUrl, model: currentModel }, {
provider: currentProvider,
ollama_url: currentUrl,
fireworks_api_key:
currentProvider === "fireworks"
? currentFireworksKey.trim() || undefined
: settings?.fireworks_api_key,
model: currentModel,
},
{ {
onSuccess: () => toast.success("AI settings saved"), onSuccess: () => toast.success("AI settings saved"),
onError: (err) => onError: (err) =>
@@ -47,11 +90,33 @@ export function AiSettingsPopover() {
</PopoverTrigger> </PopoverTrigger>
<PopoverContent className="w-80" align="end"> <PopoverContent className="w-80" align="end">
<div className="flex flex-col gap-3"> <div className="flex flex-col gap-3">
<h4 className="text-sm font-medium">Ollama Settings</h4> <h4 className="text-sm font-medium">AI Settings</h4>
<div className="flex flex-col gap-1.5">
<label className="text-xs text-muted-foreground">Provider</label>
<Select
value={currentProvider}
onValueChange={(v) => handleProviderChange(v as AiProvider)}
>
<SelectTrigger className="h-8 w-full text-xs">
<SelectValue />
</SelectTrigger>
<SelectContent>
{SUPPORTED_PROVIDERS.map((p) => (
<SelectItem key={p.value} value={p.value}>
{p.label}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<AiSettingsFields <AiSettingsFields
provider={currentProvider}
ollamaUrl={currentUrl} ollamaUrl={currentUrl}
onOllamaUrlChange={setUrl} onOllamaUrlChange={setUrl}
fireworksApiKey={currentFireworksKey}
onFireworksApiKeyChange={setFireworksKey}
model={currentModel} model={currentModel}
onModelChange={setModel} onModelChange={setModel}
/> />

View File

@@ -22,7 +22,12 @@ import { useAiSettings, useSaveAiSettings } from "@/hooks/use-ai";
import { AiSettingsFields } from "@/components/ai/AiSettingsFields"; import { AiSettingsFields } from "@/components/ai/AiSettingsFields";
import { Loader2, Copy, Check } from "lucide-react"; import { Loader2, Copy, Check } from "lucide-react";
import { toast } from "sonner"; import { toast } from "sonner";
import type { AppSettings } from "@/types"; import type { AiProvider, AppSettings } from "@/types";
const SUPPORTED_AI_PROVIDERS: { value: AiProvider; label: string }[] = [
{ value: "ollama", label: "Ollama (local)" },
{ value: "fireworks", label: "Fireworks AI" },
];
interface Props { interface Props {
open: boolean; open: boolean;
@@ -42,7 +47,9 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
const [mcpPort, setMcpPort] = useState(9427); const [mcpPort, setMcpPort] = useState(9427);
// AI state // AI state
const [aiProvider, setAiProvider] = useState<AiProvider>("ollama");
const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434"); const [ollamaUrl, setOllamaUrl] = useState("http://localhost:11434");
const [fireworksApiKey, setFireworksApiKey] = useState("");
const [aiModel, setAiModel] = useState(""); const [aiModel, setAiModel] = useState("");
const [copied, setCopied] = useState(false); const [copied, setCopied] = useState(false);
@@ -61,11 +68,23 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
if (aiSettings !== prevAiSettings) { if (aiSettings !== prevAiSettings) {
setPrevAiSettings(aiSettings); setPrevAiSettings(aiSettings);
if (aiSettings) { if (aiSettings) {
// Legacy openai/anthropic values aren't user-selectable here — fall back to ollama.
setAiProvider(
aiSettings.provider === "fireworks" ? "fireworks" : "ollama"
);
setOllamaUrl(aiSettings.ollama_url); setOllamaUrl(aiSettings.ollama_url);
setFireworksApiKey(aiSettings.fireworks_api_key ?? "");
setAiModel(aiSettings.model); setAiModel(aiSettings.model);
} }
} }
const handleAiProviderChange = (next: AiProvider) => {
if (next === aiProvider) return;
setAiProvider(next);
// Model lists differ per provider — clear stale selection.
setAiModel("");
};
const mcpEndpoint = `http://127.0.0.1:${mcpPort}/mcp`; const mcpEndpoint = `http://127.0.0.1:${mcpPort}/mcp`;
const handleCopy = async () => { const handleCopy = async () => {
@@ -89,7 +108,15 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
// Save AI settings separately // Save AI settings separately
saveAiMutation.mutate( saveAiMutation.mutate(
{ provider: "ollama", ollama_url: ollamaUrl, model: aiModel }, {
provider: aiProvider,
ollama_url: ollamaUrl,
fireworks_api_key:
aiProvider === "fireworks"
? fireworksApiKey.trim() || undefined
: aiSettings?.fireworks_api_key,
model: aiModel,
},
{ {
onError: (err) => onError: (err) =>
toast.error("Failed to save AI settings", { description: String(err) }), toast.error("Failed to save AI settings", { description: String(err) }),
@@ -179,19 +206,29 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
<div className="flex flex-col gap-1.5"> <div className="flex flex-col gap-1.5">
<label className="text-xs text-muted-foreground">Provider</label> <label className="text-xs text-muted-foreground">Provider</label>
<Select value="ollama" disabled> <Select
value={aiProvider}
onValueChange={(v) => handleAiProviderChange(v as AiProvider)}
>
<SelectTrigger className="h-8 text-xs"> <SelectTrigger className="h-8 text-xs">
<SelectValue /> <SelectValue />
</SelectTrigger> </SelectTrigger>
<SelectContent> <SelectContent>
<SelectItem value="ollama">Ollama</SelectItem> {SUPPORTED_AI_PROVIDERS.map((p) => (
<SelectItem key={p.value} value={p.value}>
{p.label}
</SelectItem>
))}
</SelectContent> </SelectContent>
</Select> </Select>
</div> </div>
<AiSettingsFields <AiSettingsFields
provider={aiProvider}
ollamaUrl={ollamaUrl} ollamaUrl={ollamaUrl}
onOllamaUrlChange={setOllamaUrl} onOllamaUrlChange={setOllamaUrl}
fireworksApiKey={fireworksApiKey}
onFireworksApiKeyChange={setFireworksApiKey}
model={aiModel} model={aiModel}
onModelChange={setAiModel} onModelChange={setAiModel}
/> />

View File

@@ -3,6 +3,7 @@ import {
getAiSettings, getAiSettings,
saveAiSettings, saveAiSettings,
listOllamaModels, listOllamaModels,
listFireworksModels,
generateSql, generateSql,
explainSql, explainSql,
fixSqlError, fixSqlError,
@@ -36,6 +37,16 @@ export function useOllamaModels(ollamaUrl: string | undefined) {
}); });
} }
export function useFireworksModels(apiKey: string | undefined) {
return useQuery({
queryKey: ["fireworks-models", apiKey],
queryFn: () => listFireworksModels(apiKey!),
enabled: !!apiKey && apiKey.trim().length > 0,
retry: false,
staleTime: 60_000,
});
}
export function useGenerateSql() { export function useGenerateSql() {
return useMutation({ return useMutation({
mutationFn: ({ mutationFn: ({

View File

@@ -211,6 +211,9 @@ export const saveAiSettings = (settings: AiSettings) =>
export const listOllamaModels = (ollamaUrl: string) => export const listOllamaModels = (ollamaUrl: string) =>
invoke<OllamaModel[]>("list_ollama_models", { ollamaUrl }); invoke<OllamaModel[]>("list_ollama_models", { ollamaUrl });
export const listFireworksModels = (apiKey: string) =>
invoke<OllamaModel[]>("list_fireworks_models", { apiKey });
export const generateSql = (connectionId: string, prompt: string) => export const generateSql = (connectionId: string, prompt: string) =>
invoke<string>("generate_sql", { connectionId, prompt }); invoke<string>("generate_sql", { connectionId, prompt });

View File

@@ -134,13 +134,14 @@ export interface SavedQuery {
created_at: string; created_at: string;
} }
export type AiProvider = "ollama" | "openai" | "anthropic"; export type AiProvider = "ollama" | "openai" | "anthropic" | "fireworks";
export interface AiSettings { export interface AiSettings {
provider: AiProvider; provider: AiProvider;
ollama_url: string; ollama_url: string;
openai_api_key?: string; openai_api_key?: string;
anthropic_api_key?: string; anthropic_api_key?: string;
fireworks_api_key?: string;
model: string; model: string;
} }