feat: add Fireworks AI provider for chat agent

Routes chat-completions through a managed OpenAI-compatible inference
endpoint as an alternative to local Ollama, useful when the agent needs
fast multi-hop reasoning that local hardware can't sustain.

- backend: rename `call_ollama_chat_messages` → `call_chat_messages`,
  dispatch by provider; add `call_fireworks` branch (Bearer auth,
  `response_format: json_object` mapped from internal `format="json"`)
  and `list_fireworks_models` Tauri command
- settings: extend `AiProvider` enum + `AiSettings.fireworks_api_key`
  (serde-default for legacy config compat); Fireworks base URL hardcoded
- UI: provider selector in both popover and AppSettingsSheet (only
  ollama+fireworks shown; legacy openai/anthropic kept for serde-compat
  but normalized to ollama in UI); password input + dynamic model list
  for Fireworks; switching provider clears stale model selection
- 4 unit tests: serde round-trip, legacy settings deserialization,
  Fireworks chat-completions parsing, models-list parsing
This commit is contained in:
2026-05-06 23:04:10 +03:00
parent 532ebf3b44
commit 96a54edcd0
10 changed files with 524 additions and 65 deletions

View File

@@ -1,6 +1,7 @@
use crate::error::{TuskError, TuskResult};
use crate::models::ai::{
AiProvider, AiSettings, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel,
AiProvider, AiSettings, FireworksChatRequest, FireworksChatResponse, FireworksModelsResponse,
FireworksResponseFormat, OllamaChatMessage, OllamaChatRequest, OllamaChatResponse, OllamaModel,
OllamaTagsResponse,
};
use crate::state::{AppState, DbFlavor};
@@ -13,6 +14,7 @@ use tauri::{AppHandle, Manager, State};
const MAX_RETRIES: u32 = 2;
const RETRY_DELAY_MS: u64 = 1000;
const FIREWORKS_BASE_URL: &str = "https://api.fireworks.ai/inference/v1";
fn http_client() -> &'static reqwest::Client {
use std::sync::LazyLock;
@@ -85,6 +87,42 @@ pub async fn list_ollama_models(ollama_url: String) -> TuskResult<Vec<OllamaMode
Ok(tags.models)
}
#[tauri::command]
pub async fn list_fireworks_models(api_key: String) -> TuskResult<Vec<OllamaModel>> {
let key = api_key.trim();
if key.is_empty() {
return Err(TuskError::Ai("Fireworks API key required".to_string()));
}
let url = format!("{}/models", FIREWORKS_BASE_URL);
let resp = http_client()
.get(&url)
.bearer_auth(key)
.send()
.await
.map_err(|e| TuskError::Ai(format!("Cannot reach Fireworks: {}", e)))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}",
status, body
)));
}
let parsed: FireworksModelsResponse = resp
.json()
.await
.map_err(|e| TuskError::Ai(format!("Failed to parse Fireworks models list: {}", e)))?;
Ok(parsed
.data
.into_iter()
.map(|m| OllamaModel { name: m.id })
.collect())
}
async fn call_ai_with_retry<F, Fut, T>(
_settings: &AiSettings,
operation: &str,
@@ -142,13 +180,13 @@ pub(crate) async fn load_ai_settings(app: &AppHandle, state: &AppState) -> TuskR
Ok(settings)
}
async fn call_ollama_chat(
async fn call_chat_simple(
app: &AppHandle,
state: &AppState,
system_prompt: String,
user_content: String,
) -> TuskResult<String> {
call_ollama_chat_messages(
call_chat_messages(
app,
state,
vec![
@@ -166,7 +204,10 @@ async fn call_ollama_chat(
.await
}
pub(crate) async fn call_ollama_chat_messages(
/// Provider-agnostic chat-completions dispatcher used by every LLM-driven feature
/// (chat agent, generate_sql, explain_sql, fix_sql_error). Returns the model's
/// raw text content.
pub(crate) async fn call_chat_messages(
app: &AppHandle,
state: &AppState,
messages: Vec<OllamaChatMessage>,
@@ -180,24 +221,30 @@ pub(crate) async fn call_ollama_chat_messages(
));
}
if settings.provider != AiProvider::Ollama {
return Err(TuskError::Ai(format!(
match settings.provider {
AiProvider::Ollama => call_ollama(&settings, messages, format).await,
AiProvider::Fireworks => call_fireworks(&settings, messages, format).await,
AiProvider::OpenAi | AiProvider::Anthropic => Err(TuskError::Ai(format!(
"Provider {:?} not implemented yet",
settings.provider
)));
))),
}
}
let model = settings.model.clone();
async fn call_ollama(
settings: &AiSettings,
messages: Vec<OllamaChatMessage>,
format: Option<String>,
) -> TuskResult<String> {
let url = format!("{}/api/chat", settings.ollama_url.trim_end_matches('/'));
let request = OllamaChatRequest {
model: model.clone(),
model: settings.model.clone(),
messages,
stream: false,
format,
};
call_ai_with_retry(&settings, "Ollama request", || {
call_ai_with_retry(settings, "Ollama request", || {
let url = url.clone();
let request = request.clone();
async move {
@@ -230,6 +277,75 @@ pub(crate) async fn call_ollama_chat_messages(
.await
}
async fn call_fireworks(
settings: &AiSettings,
messages: Vec<OllamaChatMessage>,
format: Option<String>,
) -> TuskResult<String> {
let api_key = settings
.fireworks_api_key
.clone()
.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.ok_or_else(|| {
TuskError::Ai("Fireworks API key not set. Open AI settings to add it.".to_string())
})?;
let url = format!("{}/chat/completions", FIREWORKS_BASE_URL);
let response_format = format.as_deref().map(|f| FireworksResponseFormat {
kind: if f == "json" {
"json_object".to_string()
} else {
f.to_string()
},
});
let request = FireworksChatRequest {
model: settings.model.clone(),
messages,
temperature: 0.0,
response_format,
};
call_ai_with_retry(settings, "Fireworks request", || {
let url = url.clone();
let request = request.clone();
let api_key = api_key.clone();
async move {
let resp = http_client()
.post(&url)
.bearer_auth(&api_key)
.json(&request)
.send()
.await
.map_err(|e| {
TuskError::Ai(format!("Cannot reach Fireworks at {}: {}", url, e))
})?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
return Err(TuskError::Ai(format!(
"Fireworks error ({}): {}",
status, body
)));
}
let parsed: FireworksChatResponse = resp.json().await.map_err(|e| {
TuskError::Ai(format!("Failed to parse Fireworks response: {}", e))
})?;
parsed
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| TuskError::Ai("Fireworks returned no choices".to_string()))
}
})
.await
}
// ---------------------------------------------------------------------------
// SQL generation
// ---------------------------------------------------------------------------
@@ -310,7 +426,7 @@ pub async fn generate_sql(
schema_text
);
let raw = call_ollama_chat(&app, &state, system_prompt, prompt).await?;
let raw = call_chat_simple(&app, &state, system_prompt, prompt).await?;
Ok(clean_sql_response(&raw))
}
@@ -347,7 +463,7 @@ pub async fn explain_sql(
schema_text
);
call_ollama_chat(&app, &state, system_prompt, sql).await
call_chat_simple(&app, &state, system_prompt, sql).await
}
// ---------------------------------------------------------------------------
@@ -391,7 +507,7 @@ pub async fn fix_sql_error(
let user_content = format!("SQL query:\n{}\n\nError message:\n{}", sql, error_message);
let raw = call_ollama_chat(&app, &state, system_prompt, user_content).await?;
let raw = call_chat_simple(&app, &state, system_prompt, user_content).await?;
Ok(clean_sql_response(&raw))
}
@@ -1453,4 +1569,60 @@ mod tests {
"SELECT\n *\nFROM users"
);
}
// ── Fireworks provider ───────────────────────────────────
#[test]
fn serializes_fireworks_provider() {
let json = serde_json::to_string(&AiProvider::Fireworks).unwrap();
assert_eq!(json, "\"fireworks\"");
}
#[test]
fn deserializes_legacy_settings_without_fireworks_key() {
// Old config files won't have `fireworks_api_key` — must still parse.
let legacy = r#"{
"provider": "ollama",
"ollama_url": "http://localhost:11434",
"openai_api_key": null,
"anthropic_api_key": null,
"model": "qwen2.5-coder:7b"
}"#;
let parsed: AiSettings = serde_json::from_str(legacy).unwrap();
assert_eq!(parsed.provider, AiProvider::Ollama);
assert_eq!(parsed.ollama_url, "http://localhost:11434");
assert!(parsed.fireworks_api_key.is_none());
assert_eq!(parsed.model, "qwen2.5-coder:7b");
}
#[test]
fn parses_fireworks_chat_response() {
let body = r#"{
"choices": [
{"message": {"role": "assistant", "content": "hi"}}
]
}"#;
let parsed: FireworksChatResponse = serde_json::from_str(body).unwrap();
assert_eq!(parsed.choices.len(), 1);
assert_eq!(parsed.choices[0].message.role, "assistant");
assert_eq!(parsed.choices[0].message.content, "hi");
}
#[test]
fn parses_fireworks_models_list() {
let body = r#"{
"data": [
{"id": "accounts/fireworks/models/qwen2p5-coder-32b-instruct"},
{"id": "accounts/fireworks/models/deepseek-v3"}
]
}"#;
let parsed: FireworksModelsResponse = serde_json::from_str(body).unwrap();
let names: Vec<String> = parsed.data.into_iter().map(|m| m.id).collect();
assert_eq!(names.len(), 2);
assert_eq!(
names[0],
"accounts/fireworks/models/qwen2p5-coder-32b-instruct"
);
assert_eq!(names[1], "accounts/fireworks/models/deepseek-v3");
}
}

View File

@@ -1,4 +1,4 @@
use crate::commands::ai::{build_overview_context, call_ollama_chat_messages};
use crate::commands::ai::{build_overview_context, call_chat_messages};
use crate::commands::chat_tools::{
find_queries_tool, get_columns_tool, list_databases_tool, list_tables_tool, save_query_tool,
switch_database_tool,
@@ -555,7 +555,7 @@ pub async fn chat_send(
let history = build_history(&working, &overview_text, &memory_text);
let raw =
call_ollama_chat_messages(&app, &state, history, Some("json".to_string())).await?;
call_chat_messages(&app, &state, history, Some("json".to_string())).await?;
let trimmed = raw.trim();
let action = match parse_agent_action(trimmed) {
@@ -1039,7 +1039,7 @@ async fn force_final_synthesis(
content: convo,
},
];
match call_ollama_chat_messages(app, state, llm_messages, None).await {
match call_chat_messages(app, state, llm_messages, None).await {
Ok(s) => {
let cleaned = clean_summary(&s);
if cleaned.trim().is_empty() {
@@ -1148,7 +1148,7 @@ pub async fn chat_compact(
content: convo,
},
];
let summary = call_ollama_chat_messages(&app, &state, llm_messages, None)
let summary = call_chat_messages(&app, &state, llm_messages, None)
.await
.map_err(|e| TuskError::Ai(format!("Compact failed: {}", e)))?;

View File

@@ -110,6 +110,7 @@ pub fn run() {
commands::ai::get_ai_settings,
commands::ai::save_ai_settings,
commands::ai::list_ollama_models,
commands::ai::list_fireworks_models,
commands::ai::generate_sql,
commands::ai::explain_sql,
commands::ai::fix_sql_error,

View File

@@ -7,14 +7,19 @@ pub enum AiProvider {
Ollama,
OpenAi,
Anthropic,
Fireworks,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiSettings {
pub provider: AiProvider,
pub ollama_url: String,
#[serde(default)]
pub openai_api_key: Option<String>,
#[serde(default)]
pub anthropic_api_key: Option<String>,
#[serde(default)]
pub fireworks_api_key: Option<String>,
pub model: String,
}
@@ -25,11 +30,14 @@ impl Default for AiSettings {
ollama_url: "http://localhost:11434".to_string(),
openai_api_key: None,
anthropic_api_key: None,
fireworks_api_key: None,
model: String::new(),
}
}
}
/// Generic chat message used by all chat providers (Ollama, Fireworks, OpenAI-compatible).
/// `{role, content}` shape is identical across these APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaChatMessage {
pub role: String,
@@ -55,7 +63,48 @@ pub struct OllamaTagsResponse {
pub models: Vec<OllamaModel>,
}
/// Generic chat-model descriptor exposed to the UI dropdown.
/// Reused as the return shape for both Ollama and Fireworks model listings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaModel {
pub name: String,
}
// ---------------------------------------------------------------------------
// Fireworks (OpenAI-compatible chat-completions)
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize)]
pub struct FireworksChatRequest {
pub model: String,
pub messages: Vec<OllamaChatMessage>,
pub temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<FireworksResponseFormat>,
}
#[derive(Debug, Clone, Serialize)]
pub struct FireworksResponseFormat {
#[serde(rename = "type")]
pub kind: String,
}
#[derive(Debug, Deserialize)]
pub struct FireworksChatResponse {
pub choices: Vec<FireworksChoice>,
}
#[derive(Debug, Deserialize)]
pub struct FireworksChoice {
pub message: OllamaChatMessage,
}
#[derive(Debug, Deserialize)]
pub struct FireworksModelsResponse {
pub data: Vec<FireworksModelEntry>,
}
#[derive(Debug, Deserialize)]
pub struct FireworksModelEntry {
pub id: String,
}