Files
httpx/middleware/auth_test.go
Aleksey Shakhmatov a90c4cd7fa Add standard middlewares: logging, headers, auth, and panic recovery
- Logging: structured slog output with method, URL, status, duration
- DefaultHeaders/UserAgent: inject headers without overwriting existing
- BearerAuth/BasicAuth: per-request token resolution and static credentials
- Recovery: catches panics in the RoundTripper chain
2026-03-20 14:22:14 +03:00

90 lines
2.2 KiB
Go

package middleware_test
import (
"context"
"errors"
"net/http"
"testing"
"git.codelab.vc/pkg/httpx/middleware"
)
// TestBearerAuth verifies the BearerAuth middleware:
//   - a token returned by tokenFunc is forwarded to the wrapped transport as
//     an "Authorization: Bearer <token>" header;
//   - an error from tokenFunc is returned from RoundTrip (matchable with
//     errors.Is) and the wrapped transport is never invoked.
func TestBearerAuth(t *testing.T) {
	t.Run("sets authorization header", func(t *testing.T) {
		var (
			captured http.Header
			called   bool
		)
		base := mockTransport(func(req *http.Request) (*http.Response, error) {
			called = true
			// Clone so later mutation of req by any layer cannot affect the assertion.
			captured = req.Header.Clone()
			return okResponse(), nil
		})
		tokenFunc := func(_ context.Context) (string, error) {
			return "my-secret-token", nil
		}
		transport := middleware.BearerAuth(tokenFunc)(base)

		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}
		_, err = transport.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// Fail with a precise message rather than an indirect header mismatch
		// if the middleware never delegates to the base transport.
		if !called {
			t.Fatal("base transport was not called")
		}

		want := "Bearer my-secret-token"
		if got := captured.Get("Authorization"); got != want {
			t.Errorf("Authorization = %q, want %q", got, want)
		}
	})

	t.Run("returns error when tokenFunc fails", func(t *testing.T) {
		base := mockTransport(func(req *http.Request) (*http.Response, error) {
			// RoundTrip is synchronous here, so t.Fatal runs on the test goroutine.
			t.Fatal("base transport should not be called")
			return nil, nil
		})
		tokenErr := errors.New("token expired")
		tokenFunc := func(_ context.Context) (string, error) {
			return "", tokenErr
		}
		transport := middleware.BearerAuth(tokenFunc)(base)

		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}
		_, err = transport.RoundTrip(req)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		// The middleware may wrap the error; callers must still be able to match it.
		if !errors.Is(err, tokenErr) {
			t.Errorf("got error %v, want %v", err, tokenErr)
		}
	})
}
// TestBasicAuth verifies that the BasicAuth middleware attaches the static
// username/password to the outgoing request so that the wrapped transport
// sees them via (*http.Request).BasicAuth.
func TestBasicAuth(t *testing.T) {
	t.Run("sets basic auth header", func(t *testing.T) {
		var capturedReq *http.Request
		base := mockTransport(func(req *http.Request) (*http.Response, error) {
			capturedReq = req
			return okResponse(), nil
		})
		transport := middleware.BasicAuth("user", "pass")(base)

		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}
		_, err = transport.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		// Guard against a nil-pointer panic below if the middleware never
		// delegated to the base transport; report it as a test failure instead.
		if capturedReq == nil {
			t.Fatal("base transport was not called")
		}

		username, password, ok := capturedReq.BasicAuth()
		if !ok {
			t.Fatal("BasicAuth() returned ok=false")
		}
		if username != "user" {
			t.Errorf("username = %q, want %q", username, "user")
		}
		if password != "pass" {
			t.Errorf("password = %q, want %q", password, "pass")
		}
	})
}