fix: harden security, reduce duplication, and improve robustness

- Fix SQL injection in data.rs by wrapping get_table_data in READ ONLY transaction
- Fix SQL injection in docker.rs CREATE DATABASE via escape_ident
- Fix command injection in docker.rs by validating pg_version/container_name
  and escaping shell-interpolated values
- Fix UTF-8 panic on stderr truncation with char_indices
- Wrap delete_rows in a transaction for atomicity
- Replace .expect() with proper error propagation in lib.rs
- Cache AI settings in AppState to avoid repeated disk reads
- Cap JSONB column discovery at 50 to prevent unbounded queries
- Fix ERD colorMode to respect system theme via useTheme()
- Extract AppState::get_pool() replacing ~19 inline pool patterns
- Extract shared AiSettingsFields component (DRY popover + sheet)
- Make get_connections_path pub(crate) and reuse from docker.rs
- Deduplicate check_docker by delegating to check_docker_internal

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-21 11:41:14 +03:00
parent baa794b66a
commit d507162377
15 changed files with 1196 additions and 667 deletions

345
src-tauri/Cargo.lock generated
View File

@@ -383,6 +383,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]] [[package]]
name = "chrono" name = "chrono"
version = "0.4.43" version = "0.4.43"
@@ -438,16 +444,6 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "core-foundation" name = "core-foundation"
version = "0.10.1" version = "0.10.1"
@@ -471,9 +467,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1" checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
dependencies = [ dependencies = [
"bitflags 2.10.0", "bitflags 2.10.0",
"core-foundation 0.10.1", "core-foundation",
"core-graphics-types", "core-graphics-types",
"foreign-types 0.5.0", "foreign-types",
"libc", "libc",
] ]
@@ -484,7 +480,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
dependencies = [ dependencies = [
"bitflags 2.10.0", "bitflags 2.10.0",
"core-foundation 0.10.1", "core-foundation",
"libc", "libc",
] ]
@@ -930,12 +926,6 @@ dependencies = [
"pin-project-lite", "pin-project-lite",
] ]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]] [[package]]
name = "fdeflate" name = "fdeflate"
version = "0.3.7" version = "0.3.7"
@@ -994,15 +984,6 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared 0.1.1",
]
[[package]] [[package]]
name = "foreign-types" name = "foreign-types"
version = "0.5.0" version = "0.5.0"
@@ -1010,7 +991,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [ dependencies = [
"foreign-types-macros", "foreign-types-macros",
"foreign-types-shared 0.3.1", "foreign-types-shared",
] ]
[[package]] [[package]]
@@ -1024,12 +1005,6 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]] [[package]]
name = "foreign-types-shared" name = "foreign-types-shared"
version = "0.3.1" version = "0.3.1"
@@ -1291,8 +1266,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"wasi 0.11.1+wasi-snapshot-preview1", "wasi 0.11.1+wasi-snapshot-preview1",
"wasm-bindgen",
] ]
[[package]] [[package]]
@@ -1302,9 +1279,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"js-sys",
"libc", "libc",
"r-efi", "r-efi",
"wasip2", "wasip2",
"wasm-bindgen",
] ]
[[package]] [[package]]
@@ -1455,25 +1434,6 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "h2"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http",
"indexmap 2.13.0",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.12.3" version = "0.12.3"
@@ -1618,7 +1578,6 @@ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
"futures-core", "futures-core",
"h2",
"http", "http",
"http-body", "http-body",
"httparse", "httparse",
@@ -1645,22 +1604,7 @@ dependencies = [
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tower-service", "tower-service",
] "webpki-roots 1.0.6",
[[package]]
name = "hyper-tls"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0"
dependencies = [
"bytes",
"http-body-util",
"hyper",
"hyper-util",
"native-tls",
"tokio",
"tokio-native-tls",
"tower-service",
] ]
[[package]] [[package]]
@@ -1681,11 +1625,9 @@ dependencies = [
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"socket2", "socket2",
"system-configuration",
"tokio", "tokio",
"tower-service", "tower-service",
"tracing", "tracing",
"windows-registry",
] ]
[[package]] [[package]]
@@ -2079,12 +2021,6 @@ dependencies = [
"vcpkg", "vcpkg",
] ]
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]] [[package]]
name = "litemap" name = "litemap"
version = "0.8.1" version = "0.8.1"
@@ -2106,6 +2042,12 @@ version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "lru-slab"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154"
[[package]] [[package]]
name = "mac" name = "mac"
version = "0.1.1" version = "0.1.1"
@@ -2222,23 +2164,6 @@ dependencies = [
"windows-sys 0.60.2", "windows-sys 0.60.2",
] ]
[[package]]
name = "native-tls"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6cdede44f9a69cab2899a2049e2c3bd49bf911a157f6a3353d4a91c61abbce44"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]] [[package]]
name = "ndk" name = "ndk"
version = "0.9.0" version = "0.9.0"
@@ -2595,50 +2520,6 @@ dependencies = [
"pathdiff", "pathdiff",
] ]
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "option-ext" name = "option-ext"
version = "0.2.0" version = "0.2.0"
@@ -3042,6 +2923,61 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "quinn"
version = "0.11.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20"
dependencies = [
"bytes",
"cfg_aliases",
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"socket2",
"thiserror 2.0.18",
"tokio",
"tracing",
"web-time",
]
[[package]]
name = "quinn-proto"
version = "0.11.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31"
dependencies = [
"bytes",
"getrandom 0.3.4",
"lru-slab",
"rand 0.9.2",
"ring",
"rustc-hash",
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.18",
"tinyvec",
"tracing",
"web-time",
]
[[package]]
name = "quinn-udp"
version = "0.5.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd"
dependencies = [
"cfg_aliases",
"libc",
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.44" version = "1.0.44"
@@ -3259,29 +3195,26 @@ checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
"encoding_rs",
"futures-core", "futures-core",
"h2",
"http", "http",
"http-body", "http-body",
"http-body-util", "http-body-util",
"hyper", "hyper",
"hyper-rustls", "hyper-rustls",
"hyper-tls",
"hyper-util", "hyper-util",
"js-sys", "js-sys",
"log", "log",
"mime",
"native-tls",
"percent-encoding", "percent-encoding",
"pin-project-lite", "pin-project-lite",
"quinn",
"rustls",
"rustls-pki-types", "rustls-pki-types",
"serde", "serde",
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-native-tls", "tokio-rustls",
"tower", "tower",
"tower-http", "tower-http",
"tower-service", "tower-service",
@@ -3289,6 +3222,7 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"web-sys", "web-sys",
"webpki-roots 1.0.6",
] ]
[[package]] [[package]]
@@ -3428,6 +3362,12 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "rustc-hash"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]] [[package]]
name = "rustc_version" name = "rustc_version"
version = "0.4.1" version = "0.4.1"
@@ -3437,19 +3377,6 @@ dependencies = [
"semver", "semver",
] ]
[[package]]
name = "rustix"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.23.36" version = "0.23.36"
@@ -3470,6 +3397,7 @@ version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [ dependencies = [
"web-time",
"zeroize", "zeroize",
] ]
@@ -3505,15 +3433,6 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "schannel"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "schemars" name = "schemars"
version = "0.8.22" version = "0.8.22"
@@ -3585,29 +3504,6 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "selectors" name = "selectors"
version = "0.24.0" version = "0.24.0"
@@ -4329,27 +4225,6 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "system-configuration"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "system-deps" name = "system-deps"
version = "6.2.2" version = "6.2.2"
@@ -4371,7 +4246,7 @@ checksum = "f3a753bdc39c07b192151523a3f77cd0394aa75413802c883a0f6f6a0e5ee2e7"
dependencies = [ dependencies = [
"bitflags 2.10.0", "bitflags 2.10.0",
"block2", "block2",
"core-foundation 0.10.1", "core-foundation",
"core-graphics", "core-graphics",
"crossbeam-channel", "crossbeam-channel",
"dispatch", "dispatch",
@@ -4713,19 +4588,6 @@ dependencies = [
"toml 0.9.12+spec-1.1.0", "toml 0.9.12+spec-1.1.0",
] ]
[[package]]
name = "tempfile"
version = "3.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "tendril" name = "tendril"
version = "0.4.3" version = "0.4.3"
@@ -4861,16 +4723,6 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.26.4" version = "0.26.4"
@@ -5440,6 +5292,16 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "web-time"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]] [[package]]
name = "webkit2gtk" name = "webkit2gtk"
version = "2.0.2" version = "2.0.2"
@@ -5697,17 +5559,6 @@ dependencies = [
"windows-link 0.1.3", "windows-link 0.1.3",
] ]
[[package]]
name = "windows-registry"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720"
dependencies = [
"windows-link 0.2.1",
"windows-result 0.4.1",
"windows-strings 0.5.1",
]
[[package]] [[package]]
name = "windows-result" name = "windows-result"
version = "0.3.4" version = "0.3.4"

View File

@@ -30,7 +30,7 @@ csv = "1"
log = "0.4" log = "0.4"
hex = "0.4" hex = "0.4"
bigdecimal = { version = "0.4", features = ["serde"] } bigdecimal = { version = "0.4", features = ["serde"] }
reqwest = { version = "0.12", features = ["json"] } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
rmcp = { version = "0.15", features = ["server", "macros", "transport-streamable-http-server"] } rmcp = { version = "0.15", features = ["server", "macros", "transport-streamable-http-server"] }
axum = "0.8" axum = "0.8"
schemars = "1" schemars = "1"

File diff suppressed because it is too large Load Diff

View File

@@ -14,7 +14,7 @@ pub struct ConnectResult {
pub flavor: DbFlavor, pub flavor: DbFlavor,
} }
fn get_connections_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> { pub(crate) fn get_connections_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
let dir = app let dir = app
.path() .path()
.app_data_dir() .app_data_dir()

View File

@@ -21,10 +21,7 @@ pub async fn get_table_data(
sort_direction: Option<String>, sort_direction: Option<String>,
filter: Option<String>, filter: Option<String>,
) -> TuskResult<PaginatedQueryResult> { ) -> TuskResult<PaginatedQueryResult> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table));
@@ -56,11 +53,24 @@ pub async fn get_table_data(
let start = Instant::now(); let start = Instant::now();
let (rows, count_row) = tokio::try_join!( // Always run table data queries in a read-only transaction to prevent
sqlx::query(&data_sql).fetch_all(pool), // writable CTEs or other mutation via the raw filter parameter.
sqlx::query(&count_sql).fetch_one(pool), let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
) sqlx::query("SET TRANSACTION READ ONLY")
.map_err(TuskError::Database)?; .execute(&mut *tx)
.await
.map_err(TuskError::Database)?;
let rows = sqlx::query(&data_sql)
.fetch_all(&mut *tx)
.await
.map_err(TuskError::Database)?;
let count_row = sqlx::query(&count_sql)
.fetch_one(&mut *tx)
.await
.map_err(TuskError::Database)?;
tx.rollback().await.map_err(TuskError::Database)?;
let execution_time_ms = start.elapsed().as_millis(); let execution_time_ms = start.elapsed().as_millis();
let total_rows: i64 = count_row.get(0); let total_rows: i64 = count_row.get(0);
@@ -134,10 +144,7 @@ pub async fn update_row(
return Err(TuskError::ReadOnly); return Err(TuskError::ReadOnly);
} }
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table));
@@ -155,7 +162,7 @@ pub async fn update_row(
let mut query = sqlx::query(&sql); let mut query = sqlx::query(&sql);
query = bind_json_value(query, &value); query = bind_json_value(query, &value);
query = query.bind(ctid_val); query = query.bind(ctid_val);
query.execute(pool).await.map_err(TuskError::Database)?; query.execute(&pool).await.map_err(TuskError::Database)?;
} else { } else {
let where_parts: Vec<String> = pk_columns let where_parts: Vec<String> = pk_columns
.iter() .iter()
@@ -174,7 +181,7 @@ pub async fn update_row(
for pk_val in &pk_values { for pk_val in &pk_values {
query = bind_json_value(query, pk_val); query = bind_json_value(query, pk_val);
} }
query.execute(pool).await.map_err(TuskError::Database)?; query.execute(&pool).await.map_err(TuskError::Database)?;
} }
Ok(()) Ok(())
@@ -193,10 +200,7 @@ pub async fn insert_row(
return Err(TuskError::ReadOnly); return Err(TuskError::ReadOnly);
} }
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table));
@@ -215,7 +219,7 @@ pub async fn insert_row(
query = bind_json_value(query, val); query = bind_json_value(query, val);
} }
query.execute(pool).await.map_err(TuskError::Database)?; query.execute(&pool).await.map_err(TuskError::Database)?;
Ok(()) Ok(())
} }
@@ -234,14 +238,14 @@ pub async fn delete_rows(
return Err(TuskError::ReadOnly); return Err(TuskError::ReadOnly);
} }
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table)); let qualified = format!("{}.{}", escape_ident(&schema), escape_ident(&table));
let mut total_affected: u64 = 0; let mut total_affected: u64 = 0;
// Wrap all deletes in a transaction for atomicity
let mut tx = (&pool).begin().await.map_err(TuskError::Database)?;
if pk_columns.is_empty() { if pk_columns.is_empty() {
// Fallback: use ctids for row identification // Fallback: use ctids for row identification
let ctid_list = ctids.ok_or_else(|| { let ctid_list = ctids.ok_or_else(|| {
@@ -250,7 +254,7 @@ pub async fn delete_rows(
for ctid_val in &ctid_list { for ctid_val in &ctid_list {
let sql = format!("DELETE FROM {} WHERE ctid = $1::tid", qualified); let sql = format!("DELETE FROM {} WHERE ctid = $1::tid", qualified);
let query = sqlx::query(&sql).bind(ctid_val); let query = sqlx::query(&sql).bind(ctid_val);
let result = query.execute(pool).await.map_err(TuskError::Database)?; let result = query.execute(&mut *tx).await.map_err(TuskError::Database)?;
total_affected += result.rows_affected(); total_affected += result.rows_affected();
} }
} else { } else {
@@ -269,11 +273,13 @@ pub async fn delete_rows(
query = bind_json_value(query, val); query = bind_json_value(query, val);
} }
let result = query.execute(pool).await.map_err(TuskError::Database)?; let result = query.execute(&mut *tx).await.map_err(TuskError::Database)?;
total_affected += result.rows_affected(); total_affected += result.rows_affected();
} }
} }
tx.commit().await.map_err(TuskError::Database)?;
Ok(total_affected) Ok(total_affected)
} }

View File

@@ -4,9 +4,10 @@ use crate::models::docker::{
CloneMode, CloneProgress, CloneResult, CloneToDockerParams, DockerStatus, TuskContainer, CloneMode, CloneProgress, CloneResult, CloneToDockerParams, DockerStatus, TuskContainer,
}; };
use crate::state::AppState; use crate::state::AppState;
use crate::utils::escape_ident;
use std::fs; use std::fs;
use std::sync::Arc; use std::sync::Arc;
use tauri::{AppHandle, Emitter, Manager, State}; use tauri::{AppHandle, Emitter, State};
use tokio::process::Command; use tokio::process::Command;
async fn docker_cmd(state: &AppState) -> Command { async fn docker_cmd(state: &AppState) -> Command {
@@ -42,17 +43,8 @@ fn emit_progress(
); );
} }
fn get_connections_path(app: &AppHandle) -> TuskResult<std::path::PathBuf> {
let dir = app
.path()
.app_data_dir()
.map_err(|e| TuskError::Custom(e.to_string()))?;
fs::create_dir_all(&dir)?;
Ok(dir.join("connections.json"))
}
fn load_connection_config(app: &AppHandle, connection_id: &str) -> TuskResult<ConnectionConfig> { fn load_connection_config(app: &AppHandle, connection_id: &str) -> TuskResult<ConnectionConfig> {
let path = get_connections_path(app)?; let path = super::connections::get_connections_path(app)?;
if !path.exists() { if !path.exists() {
return Err(TuskError::ConnectionNotFound(connection_id.to_string())); return Err(TuskError::ConnectionNotFound(connection_id.to_string()));
} }
@@ -69,43 +61,58 @@ fn shell_escape(s: &str) -> String {
s.replace('\'', "'\\''") s.replace('\'', "'\\''")
} }
/// Validate pg_version matches a safe pattern (e.g. "16", "16.2", "17.1")
fn validate_pg_version(version: &str) -> TuskResult<()> {
let is_valid = !version.is_empty()
&& version
.chars()
.all(|c| c.is_ascii_digit() || c == '.');
if !is_valid {
return Err(docker_err(format!(
"Invalid pg_version '{}': must contain only digits and dots (e.g. '16', '16.2')",
version
)));
}
Ok(())
}
/// Validate container name matches Docker naming rules: [a-zA-Z0-9][a-zA-Z0-9_.-]*
fn validate_container_name(name: &str) -> TuskResult<()> {
if name.is_empty() {
return Err(docker_err("Container name cannot be empty"));
}
let first = name.chars().next().unwrap();
if !first.is_ascii_alphanumeric() {
return Err(docker_err(format!(
"Invalid container name '{}': must start with an alphanumeric character",
name
)));
}
let is_valid = name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == '-');
if !is_valid {
return Err(docker_err(format!(
"Invalid container name '{}': only [a-zA-Z0-9_.-] characters are allowed",
name
)));
}
Ok(())
}
/// Shell-escape a string for use inside double-quoted shell contexts
fn shell_escape_double(s: &str) -> String {
s.replace('\\', "\\\\")
.replace('"', "\\\"")
.replace('$', "\\$")
.replace('`', "\\`")
.replace('!', "\\!")
}
#[tauri::command] #[tauri::command]
pub async fn check_docker(state: State<'_, Arc<AppState>>) -> TuskResult<DockerStatus> { pub async fn check_docker(state: State<'_, Arc<AppState>>) -> TuskResult<DockerStatus> {
let output = docker_cmd(&state) let docker_host = state.docker_host.read().await.clone();
.await check_docker_internal(&docker_host).await
.args(["version", "--format", "{{.Server.Version}}"])
.output()
.await;
match output {
Ok(out) => {
if out.status.success() {
let version = String::from_utf8_lossy(&out.stdout).trim().to_string();
Ok(DockerStatus {
installed: true,
daemon_running: true,
version: Some(version),
error: None,
})
} else {
let stderr = String::from_utf8_lossy(&out.stderr).trim().to_string();
let daemon_running = !stderr.contains("Cannot connect")
&& !stderr.contains("connection refused");
Ok(DockerStatus {
installed: true,
daemon_running,
version: None,
error: Some(stderr),
})
}
}
Err(_) => Ok(DockerStatus {
installed: false,
daemon_running: false,
version: None,
error: Some("Docker CLI not found. Please install Docker.".to_string()),
}),
}
} }
#[tauri::command] #[tauri::command]
@@ -252,6 +259,10 @@ async fn do_clone(
params: &CloneToDockerParams, params: &CloneToDockerParams,
clone_id: &str, clone_id: &str,
) -> TuskResult<CloneResult> { ) -> TuskResult<CloneResult> {
// Validate user inputs before any operations
validate_pg_version(&params.pg_version)?;
validate_container_name(&params.container_name)?;
let docker_host = state.docker_host.read().await.clone(); let docker_host = state.docker_host.read().await.clone();
// Step 1: Check Docker // Step 1: Check Docker
@@ -313,7 +324,7 @@ async fn do_clone(
.args([ .args([
"exec", &params.container_name, "exec", &params.container_name,
"psql", "-U", "postgres", "-c", "psql", "-U", "postgres", "-c",
&format!("CREATE DATABASE \"{}\"", params.source_database), &format!("CREATE DATABASE {}", escape_ident(&params.source_database)),
]) ])
.output() .output()
.await .await
@@ -492,7 +503,11 @@ async fn run_pipe_cmd(
if !stderr.is_empty() { if !stderr.is_empty() {
// Truncate for progress display (full log can be long) // Truncate for progress display (full log can be long)
let short = if stderr.len() > 500 { let short = if stderr.len() > 500 {
format!("{}...", &stderr[..500]) let truncated = stderr.char_indices()
.nth(500)
.map(|(i, _)| &stderr[..i])
.unwrap_or(&stderr);
format!("{}...", truncated)
} else { } else {
stderr.clone() stderr.clone()
}; };
@@ -633,13 +648,16 @@ async fn transfer_sample_data(
let table = parts[1]; let table = parts[1];
// Use COPY (SELECT ... LIMIT N) TO STDOUT piped to COPY ... FROM STDIN // Use COPY (SELECT ... LIMIT N) TO STDOUT piped to COPY ... FROM STDIN
// Escape schema/table for use inside double-quoted shell strings
let escaped_schema = shell_escape_double(schema);
let escaped_table = shell_escape_double(table);
let copy_out_sql = format!( let copy_out_sql = format!(
"\\copy (SELECT * FROM \\\"{}\\\".\\\"{}\\\" LIMIT {}) TO STDOUT", "\\copy (SELECT * FROM \\\"{}\\\".\\\"{}\\\" LIMIT {}) TO STDOUT",
schema, table, sample_rows escaped_schema, escaped_table, sample_rows
); );
let copy_in_sql = format!( let copy_in_sql = format!(
"\\copy \\\"{}\\\".\\\"{}\\\" FROM STDIN", "\\copy \\\"{}\\\".\\\"{}\\\" FROM STDIN",
schema, table escaped_schema, escaped_table
); );
let escaped_url = shell_escape(source_url); let escaped_url = shell_escape(source_url);
@@ -693,7 +711,7 @@ async fn transfer_sample_data(
} }
fn save_connection_config(app: &AppHandle, config: &ConnectionConfig) -> TuskResult<()> { fn save_connection_config(app: &AppHandle, config: &ConnectionConfig) -> TuskResult<()> {
let path = get_connections_path(app)?; let path = super::connections::get_connections_path(app)?;
let mut connections = if path.exists() { let mut connections = if path.exists() {
let data = fs::read_to_string(&path)?; let data = fs::read_to_string(&path)?;
serde_json::from_str::<Vec<ConnectionConfig>>(&data)? serde_json::from_str::<Vec<ConnectionConfig>>(&data)?
@@ -701,7 +719,12 @@ fn save_connection_config(app: &AppHandle, config: &ConnectionConfig) -> TuskRes
vec![] vec![]
}; };
connections.push(config.clone()); // Upsert by ID to avoid duplicate entries on retry
if let Some(pos) = connections.iter().position(|c| c.id == config.id) {
connections[pos] = config.clone();
} else {
connections.push(config.clone());
}
let data = serde_json::to_string_pretty(&connections)?; let data = serde_json::to_string_pretty(&connections)?;
fs::write(&path, data)?; fs::write(&path, data)?;

View File

@@ -14,17 +14,14 @@ pub async fn list_databases(
state: State<'_, Arc<AppState>>, state: State<'_, Arc<AppState>>,
connection_id: String, connection_id: String,
) -> TuskResult<Vec<String>> { ) -> TuskResult<Vec<String>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT datname FROM pg_database \ "SELECT datname FROM pg_database \
WHERE datistemplate = false \ WHERE datistemplate = false \
ORDER BY datname", ORDER BY datname",
) )
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -35,10 +32,7 @@ pub async fn list_schemas_core(
state: &AppState, state: &AppState,
connection_id: &str, connection_id: &str,
) -> TuskResult<Vec<String>> { ) -> TuskResult<Vec<String>> {
let pools = state.pools.read().await; let pool = state.get_pool(connection_id).await?;
let pool = pools
.get(connection_id)
.ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?;
let flavor = state.get_flavor(connection_id).await; let flavor = state.get_flavor(connection_id).await;
let sql = if flavor == DbFlavor::Greenplum { let sql = if flavor == DbFlavor::Greenplum {
@@ -52,7 +46,7 @@ pub async fn list_schemas_core(
}; };
let rows = sqlx::query(sql) let rows = sqlx::query(sql)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -72,10 +66,7 @@ pub async fn list_tables_core(
connection_id: &str, connection_id: &str,
schema: &str, schema: &str,
) -> TuskResult<Vec<SchemaObject>> { ) -> TuskResult<Vec<SchemaObject>> {
let pools = state.pools.read().await; let pool = state.get_pool(connection_id).await?;
let pool = pools
.get(connection_id)
.ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT t.table_name, \ "SELECT t.table_name, \
@@ -88,7 +79,7 @@ pub async fn list_tables_core(
ORDER BY t.table_name", ORDER BY t.table_name",
) )
.bind(schema) .bind(schema)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -119,10 +110,7 @@ pub async fn list_views(
connection_id: String, connection_id: String,
schema: String, schema: String,
) -> TuskResult<Vec<SchemaObject>> { ) -> TuskResult<Vec<SchemaObject>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT table_name FROM information_schema.views \ "SELECT table_name FROM information_schema.views \
@@ -130,7 +118,7 @@ pub async fn list_views(
ORDER BY table_name", ORDER BY table_name",
) )
.bind(&schema) .bind(&schema)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -152,10 +140,7 @@ pub async fn list_functions(
connection_id: String, connection_id: String,
schema: String, schema: String,
) -> TuskResult<Vec<SchemaObject>> { ) -> TuskResult<Vec<SchemaObject>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT routine_name FROM information_schema.routines \ "SELECT routine_name FROM information_schema.routines \
@@ -163,7 +148,7 @@ pub async fn list_functions(
ORDER BY routine_name", ORDER BY routine_name",
) )
.bind(&schema) .bind(&schema)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -185,10 +170,7 @@ pub async fn list_indexes(
connection_id: String, connection_id: String,
schema: String, schema: String,
) -> TuskResult<Vec<SchemaObject>> { ) -> TuskResult<Vec<SchemaObject>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT indexname FROM pg_indexes \ "SELECT indexname FROM pg_indexes \
@@ -196,7 +178,7 @@ pub async fn list_indexes(
ORDER BY indexname", ORDER BY indexname",
) )
.bind(&schema) .bind(&schema)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -218,10 +200,7 @@ pub async fn list_sequences(
connection_id: String, connection_id: String,
schema: String, schema: String,
) -> TuskResult<Vec<SchemaObject>> { ) -> TuskResult<Vec<SchemaObject>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT sequence_name FROM information_schema.sequences \ "SELECT sequence_name FROM information_schema.sequences \
@@ -229,7 +208,7 @@ pub async fn list_sequences(
ORDER BY sequence_name", ORDER BY sequence_name",
) )
.bind(&schema) .bind(&schema)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -251,10 +230,7 @@ pub async fn get_table_columns_core(
schema: &str, schema: &str,
table: &str, table: &str,
) -> TuskResult<Vec<ColumnInfo>> { ) -> TuskResult<Vec<ColumnInfo>> {
let pools = state.pools.read().await; let pool = state.get_pool(connection_id).await?;
let pool = pools
.get(connection_id)
.ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT \ "SELECT \
@@ -287,7 +263,7 @@ pub async fn get_table_columns_core(
) )
.bind(schema) .bind(schema)
.bind(table) .bind(table)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -323,10 +299,7 @@ pub async fn get_table_constraints(
schema: String, schema: String,
table: String, table: String,
) -> TuskResult<Vec<ConstraintInfo>> { ) -> TuskResult<Vec<ConstraintInfo>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT \ "SELECT \
@@ -376,7 +349,7 @@ pub async fn get_table_constraints(
) )
.bind(&schema) .bind(&schema)
.bind(&table) .bind(&table)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -402,10 +375,7 @@ pub async fn get_table_indexes(
schema: String, schema: String,
table: String, table: String,
) -> TuskResult<Vec<IndexInfo>> { ) -> TuskResult<Vec<IndexInfo>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT \ "SELECT \
@@ -422,7 +392,7 @@ pub async fn get_table_indexes(
) )
.bind(&schema) .bind(&schema)
.bind(&table) .bind(&table)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -443,10 +413,7 @@ pub async fn get_completion_schema(
connection_id: String, connection_id: String,
) -> TuskResult<HashMap<String, HashMap<String, Vec<String>>>> { ) -> TuskResult<HashMap<String, HashMap<String, Vec<String>>>> {
let flavor = state.get_flavor(&connection_id).await; let flavor = state.get_flavor(&connection_id).await;
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let sql = if flavor == DbFlavor::Greenplum { let sql = if flavor == DbFlavor::Greenplum {
"SELECT table_schema, table_name, column_name \ "SELECT table_schema, table_name, column_name \
@@ -461,7 +428,7 @@ pub async fn get_completion_schema(
}; };
let rows = sqlx::query(sql) let rows = sqlx::query(sql)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -490,10 +457,7 @@ pub async fn get_column_details(
table: String, table: String,
) -> TuskResult<Vec<ColumnDetail>> { ) -> TuskResult<Vec<ColumnDetail>> {
let flavor = state.get_flavor(&connection_id).await; let flavor = state.get_flavor(&connection_id).await;
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let sql = if flavor == DbFlavor::Greenplum { let sql = if flavor == DbFlavor::Greenplum {
"SELECT c.column_name, c.data_type, \ "SELECT c.column_name, c.data_type, \
@@ -516,7 +480,7 @@ pub async fn get_column_details(
let rows = sqlx::query(sql) let rows = sqlx::query(sql)
.bind(&schema) .bind(&schema)
.bind(&table) .bind(&table)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -539,10 +503,7 @@ pub async fn get_table_triggers(
schema: String, schema: String,
table: String, table: String,
) -> TuskResult<Vec<TriggerInfo>> { ) -> TuskResult<Vec<TriggerInfo>> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
let rows = sqlx::query( let rows = sqlx::query(
"SELECT \ "SELECT \
@@ -571,7 +532,7 @@ pub async fn get_table_triggers(
) )
.bind(&schema) .bind(&schema)
.bind(&table) .bind(&table)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -595,10 +556,7 @@ pub async fn get_schema_erd(
connection_id: String, connection_id: String,
schema: String, schema: String,
) -> TuskResult<ErdData> { ) -> TuskResult<ErdData> {
let pools = state.pools.read().await; let pool = state.get_pool(&connection_id).await?;
let pool = pools
.get(&connection_id)
.ok_or(TuskError::NotConnected(connection_id))?;
// Get all tables with columns // Get all tables with columns
let col_rows = sqlx::query( let col_rows = sqlx::query(
@@ -627,7 +585,7 @@ pub async fn get_schema_erd(
ORDER BY c.table_name, c.ordinal_position", ORDER BY c.table_name, c.ordinal_position",
) )
.bind(&schema) .bind(&schema)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;
@@ -690,7 +648,7 @@ pub async fn get_schema_erd(
ORDER BY c.conname", ORDER BY c.conname",
) )
.bind(&schema) .bind(&schema)
.fetch_all(pool) .fetch_all(&pool)
.await .await
.map_err(TuskError::Database)?; .map_err(TuskError::Database)?;

View File

@@ -13,24 +13,20 @@ use tauri::Manager;
pub fn run() { pub fn run() {
let shared_state = Arc::new(AppState::new()); let shared_state = Arc::new(AppState::new());
tauri::Builder::default() let _ = tauri::Builder::default()
.plugin(tauri_plugin_shell::init()) .plugin(tauri_plugin_shell::init())
.plugin(tauri_plugin_dialog::init()) .plugin(tauri_plugin_dialog::init())
.manage(shared_state) .manage(shared_state)
.setup(|app| { .setup(|app| {
let state = app.state::<Arc<AppState>>().inner().clone(); let state = app.state::<Arc<AppState>>().inner().clone();
let connections_path = app let data_dir = app
.path() .path()
.app_data_dir() .app_data_dir()
.expect("failed to resolve app data dir") .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
.join("connections.json"); let connections_path = data_dir.join("connections.json");
// Read app settings // Read app settings
let settings_path = app let settings_path = data_dir.join("app_settings.json");
.path()
.app_data_dir()
.expect("failed to resolve app data dir")
.join("app_settings.json");
let settings = if settings_path.exists() { let settings = if settings_path.exists() {
std::fs::read_to_string(&settings_path) std::fs::read_to_string(&settings_path)
@@ -154,5 +150,7 @@ pub fn run() {
commands::settings::get_mcp_status, commands::settings::get_mcp_status,
]) ])
.run(tauri::generate_context!()) .run(tauri::generate_context!())
.expect("error while running tauri application"); .inspect_err(|e| {
log::error!("Tauri application error: {}", e);
});
} }

View File

@@ -1,27 +1,42 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AiProvider {
#[default]
Ollama,
OpenAi,
Anthropic,
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AiSettings { pub struct AiSettings {
pub provider: AiProvider,
pub ollama_url: String, pub ollama_url: String,
pub openai_api_key: Option<String>,
pub anthropic_api_key: Option<String>,
pub model: String, pub model: String,
} }
impl Default for AiSettings { impl Default for AiSettings {
fn default() -> Self { fn default() -> Self {
Self { Self {
provider: AiProvider::Ollama,
ollama_url: "http://localhost:11434".to_string(), ollama_url: "http://localhost:11434".to_string(),
openai_api_key: None,
anthropic_api_key: None,
model: String::new(), model: String::new(),
} }
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaChatMessage { pub struct OllamaChatMessage {
pub role: String, pub role: String,
pub content: String, pub content: String,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct OllamaChatRequest { pub struct OllamaChatRequest {
pub model: String, pub model: String,
pub messages: Vec<OllamaChatMessage>, pub messages: Vec<OllamaChatMessage>,

View File

@@ -1,3 +1,5 @@
use crate::error::{TuskError, TuskResult};
use crate::models::ai::AiSettings;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::PgPool; use sqlx::PgPool;
use std::collections::HashMap; use std::collections::HashMap;
@@ -27,6 +29,7 @@ pub struct AppState {
pub mcp_shutdown_tx: watch::Sender<bool>, pub mcp_shutdown_tx: watch::Sender<bool>,
pub mcp_running: RwLock<bool>, pub mcp_running: RwLock<bool>,
pub docker_host: RwLock<Option<String>>, pub docker_host: RwLock<Option<String>>,
pub ai_settings: RwLock<Option<AiSettings>>,
} }
const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes const SCHEMA_CACHE_TTL: Duration = Duration::from_secs(300); // 5 minutes
@@ -43,9 +46,18 @@ impl AppState {
mcp_shutdown_tx, mcp_shutdown_tx,
mcp_running: RwLock::new(false), mcp_running: RwLock::new(false),
docker_host: RwLock::new(None), docker_host: RwLock::new(None),
ai_settings: RwLock::new(None),
} }
} }
pub async fn get_pool(&self, connection_id: &str) -> TuskResult<PgPool> {
let pools = self.pools.read().await;
pools
.get(connection_id)
.cloned()
.ok_or_else(|| TuskError::NotConnected(connection_id.to_string()))
}
pub async fn is_read_only(&self, id: &str) -> bool { pub async fn is_read_only(&self, id: &str) -> bool {
let map = self.read_only.read().await; let map = self.read_only.read().await;
map.get(id).copied().unwrap_or(true) map.get(id).copied().unwrap_or(true)

View File

@@ -2,7 +2,7 @@
"$schema": "https://schema.tauri.app/config/2", "$schema": "https://schema.tauri.app/config/2",
"productName": "Tusk", "productName": "Tusk",
"version": "0.1.0", "version": "0.1.0",
"identifier": "com.tusk.app", "identifier": "com.tusk.dbm",
"build": { "build": {
"frontendDist": "../dist", "frontendDist": "../dist",
"devUrl": "http://localhost:5173", "devUrl": "http://localhost:5173",
@@ -27,7 +27,7 @@
}, },
"bundle": { "bundle": {
"active": true, "active": true,
"targets": "all", "targets": ["deb", "rpm", "dmg", "nsis"],
"icon": [ "icon": [
"icons/32x32.png", "icons/32x32.png",
"icons/128x128.png", "icons/128x128.png",

View File

@@ -0,0 +1,82 @@
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useOllamaModels } from "@/hooks/use-ai";
import { RefreshCw, Loader2 } from "lucide-react";
interface Props {
ollamaUrl: string;
onOllamaUrlChange: (url: string) => void;
model: string;
onModelChange: (model: string) => void;
}
export function AiSettingsFields({
ollamaUrl,
onOllamaUrlChange,
model,
onModelChange,
}: Props) {
const {
data: models,
isLoading: modelsLoading,
isError: modelsError,
refetch: refetchModels,
} = useOllamaModels(ollamaUrl);
return (
<>
<div className="flex flex-col gap-1.5">
<label className="text-xs text-muted-foreground">Ollama URL</label>
<Input
value={ollamaUrl}
onChange={(e) => onOllamaUrlChange(e.target.value)}
placeholder="http://localhost:11434"
className="h-8 text-xs"
/>
</div>
<div className="flex flex-col gap-1.5">
<div className="flex items-center justify-between">
<label className="text-xs text-muted-foreground">Model</label>
<Button
size="sm"
variant="ghost"
className="h-5 w-5 p-0"
onClick={() => refetchModels()}
disabled={modelsLoading}
title="Refresh models"
>
{modelsLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<RefreshCw className="h-3 w-3" />
)}
</Button>
</div>
{modelsError ? (
<p className="text-xs text-destructive">Cannot connect to Ollama</p>
) : (
<Select value={model} onValueChange={onModelChange}>
<SelectTrigger className="h-8 w-full text-xs">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{models?.map((m) => (
<SelectItem key={m.name} value={m.name}>
{m.name}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
</>
);
}

View File

@@ -5,17 +5,10 @@ import {
PopoverTrigger, PopoverTrigger,
} from "@/components/ui/popover"; } from "@/components/ui/popover";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input"; import { useAiSettings, useSaveAiSettings } from "@/hooks/use-ai";
import { import { Settings } from "lucide-react";
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import { useAiSettings, useSaveAiSettings, useOllamaModels } from "@/hooks/use-ai";
import { Settings, RefreshCw, Loader2 } from "lucide-react";
import { toast } from "sonner"; import { toast } from "sonner";
import { AiSettingsFields } from "./AiSettingsFields";
export function AiSettingsPopover() { export function AiSettingsPopover() {
const { data: settings } = useAiSettings(); const { data: settings } = useAiSettings();
@@ -27,16 +20,9 @@ export function AiSettingsPopover() {
const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434"; const currentUrl = url ?? settings?.ollama_url ?? "http://localhost:11434";
const currentModel = model ?? settings?.model ?? ""; const currentModel = model ?? settings?.model ?? "";
const {
data: models,
isLoading: modelsLoading,
isError: modelsError,
refetch: refetchModels,
} = useOllamaModels(currentUrl);
const handleSave = () => { const handleSave = () => {
saveMutation.mutate( saveMutation.mutate(
{ ollama_url: currentUrl, model: currentModel }, { provider: "ollama", ollama_url: currentUrl, model: currentModel },
{ {
onSuccess: () => toast.success("AI settings saved"), onSuccess: () => toast.success("AI settings saved"),
onError: (err) => onError: (err) =>
@@ -63,53 +49,12 @@ export function AiSettingsPopover() {
<div className="flex flex-col gap-3"> <div className="flex flex-col gap-3">
<h4 className="text-sm font-medium">Ollama Settings</h4> <h4 className="text-sm font-medium">Ollama Settings</h4>
<div className="flex flex-col gap-1.5"> <AiSettingsFields
<label className="text-xs text-muted-foreground">Ollama URL</label> ollamaUrl={currentUrl}
<Input onOllamaUrlChange={setUrl}
value={currentUrl} model={currentModel}
onChange={(e) => setUrl(e.target.value)} onModelChange={setModel}
placeholder="http://localhost:11434" />
className="h-8 text-xs"
/>
</div>
<div className="flex flex-col gap-1.5">
<div className="flex items-center justify-between">
<label className="text-xs text-muted-foreground">Model</label>
<Button
size="sm"
variant="ghost"
className="h-5 w-5 p-0"
onClick={() => refetchModels()}
disabled={modelsLoading}
title="Refresh models"
>
{modelsLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<RefreshCw className="h-3 w-3" />
)}
</Button>
</div>
{modelsError ? (
<p className="text-xs text-destructive">
Cannot connect to Ollama
</p>
) : (
<Select value={currentModel} onValueChange={setModel}>
<SelectTrigger className="h-8 w-full text-xs">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{models?.map((m) => (
<SelectItem key={m.name} value={m.name}>
{m.name}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
<Button size="sm" className="h-7 text-xs" onClick={handleSave}> <Button size="sm" className="h-7 text-xs" onClick={handleSave}>
Save Save

View File

@@ -1,4 +1,5 @@
import { useMemo, useCallback, useEffect, useState } from "react"; import { useMemo, useCallback, useEffect, useState } from "react";
import { useTheme } from "next-themes";
import { import {
ReactFlow, ReactFlow,
Background, Background,
@@ -100,6 +101,7 @@ interface Props {
export function ErdDiagram({ connectionId, schema }: Props) { export function ErdDiagram({ connectionId, schema }: Props) {
const { data: erdData, isLoading, error } = useSchemaErd(connectionId, schema); const { data: erdData, isLoading, error } = useSchemaErd(connectionId, schema);
const { resolvedTheme } = useTheme();
const layout = useMemo(() => { const layout = useMemo(() => {
if (!erdData) return null; if (!erdData) return null;
@@ -126,9 +128,6 @@ export function ErdDiagram({ connectionId, schema }: Props) {
[], [],
); );
const onInit = useCallback((instance: { fitView: () => void }) => {
setTimeout(() => instance.fitView(), 50);
}, []);
if (isLoading) { if (isLoading) {
return ( return (
@@ -162,9 +161,8 @@ export function ErdDiagram({ connectionId, schema }: Props) {
onNodesChange={onNodesChange} onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange} onEdgesChange={onEdgesChange}
nodeTypes={nodeTypes} nodeTypes={nodeTypes}
onInit={onInit}
fitView fitView
colorMode="dark" colorMode={resolvedTheme === "dark" ? "dark" : "light"}
minZoom={0.05} minZoom={0.05}
maxZoom={3} maxZoom={3}
zoomOnScroll zoomOnScroll

View File

@@ -18,8 +18,9 @@ import {
} from "@/components/ui/select"; } from "@/components/ui/select";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { useAppSettings, useSaveAppSettings, useMcpStatus } from "@/hooks/use-settings"; import { useAppSettings, useSaveAppSettings, useMcpStatus } from "@/hooks/use-settings";
import { useAiSettings, useSaveAiSettings, useOllamaModels } from "@/hooks/use-ai"; import { useAiSettings, useSaveAiSettings } from "@/hooks/use-ai";
import { RefreshCw, Loader2, Copy, Check } from "lucide-react"; import { AiSettingsFields } from "@/components/ai/AiSettingsFields";
import { Loader2, Copy, Check } from "lucide-react";
import { toast } from "sonner"; import { toast } from "sonner";
import type { AppSettings, DockerHost } from "@/types"; import type { AppSettings, DockerHost } from "@/types";
@@ -67,13 +68,6 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
} }
}, [aiSettings]); }, [aiSettings]);
const {
data: models,
isLoading: modelsLoading,
isError: modelsError,
refetch: refetchModels,
} = useOllamaModels(ollamaUrl);
const mcpEndpoint = `http://127.0.0.1:${mcpPort}/mcp`; const mcpEndpoint = `http://127.0.0.1:${mcpPort}/mcp`;
const handleCopy = async () => { const handleCopy = async () => {
@@ -233,51 +227,12 @@ export function AppSettingsSheet({ open, onOpenChange }: Props) {
</Select> </Select>
</div> </div>
<div className="flex flex-col gap-1.5"> <AiSettingsFields
<label className="text-xs text-muted-foreground">Ollama URL</label> ollamaUrl={ollamaUrl}
<Input onOllamaUrlChange={setOllamaUrl}
value={ollamaUrl} model={aiModel}
onChange={(e) => setOllamaUrl(e.target.value)} onModelChange={setAiModel}
placeholder="http://localhost:11434" />
className="h-8 text-xs"
/>
</div>
<div className="flex flex-col gap-1.5">
<div className="flex items-center justify-between">
<label className="text-xs text-muted-foreground">Model</label>
<Button
size="sm"
variant="ghost"
className="h-5 w-5 p-0"
onClick={() => refetchModels()}
disabled={modelsLoading}
title="Refresh models"
>
{modelsLoading ? (
<Loader2 className="h-3 w-3 animate-spin" />
) : (
<RefreshCw className="h-3 w-3" />
)}
</Button>
</div>
{modelsError ? (
<p className="text-xs text-destructive">Cannot connect to Ollama</p>
) : (
<Select value={aiModel} onValueChange={setAiModel}>
<SelectTrigger className="h-8 w-full text-xs">
<SelectValue placeholder="Select a model" />
</SelectTrigger>
<SelectContent>
{models?.map((m) => (
<SelectItem key={m.name} value={m.name}>
{m.name}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
</section> </section>
</div> </div>