- Logging: structured slog output with method, URL, status, duration - DefaultHeaders/UserAgent: inject headers without overwriting existing - BearerAuth/BasicAuth: per-request token resolution and static credentials - Recovery: catches panics in the RoundTripper chain
90 lines
2.2 KiB
Go
90 lines
2.2 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"git.codelab.vc/pkg/httpx/middleware"
|
|
)
|
|
|
|
func TestBearerAuth(t *testing.T) {
|
|
t.Run("sets authorization header", func(t *testing.T) {
|
|
var captured http.Header
|
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
captured = req.Header.Clone()
|
|
return okResponse(), nil
|
|
})
|
|
|
|
tokenFunc := func(_ context.Context) (string, error) {
|
|
return "my-secret-token", nil
|
|
}
|
|
|
|
transport := middleware.BearerAuth(tokenFunc)(base)
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
|
|
_, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
want := "Bearer my-secret-token"
|
|
if got := captured.Get("Authorization"); got != want {
|
|
t.Errorf("Authorization = %q, want %q", got, want)
|
|
}
|
|
})
|
|
|
|
t.Run("returns error when tokenFunc fails", func(t *testing.T) {
|
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
t.Fatal("base transport should not be called")
|
|
return nil, nil
|
|
})
|
|
|
|
tokenErr := errors.New("token expired")
|
|
tokenFunc := func(_ context.Context) (string, error) {
|
|
return "", tokenErr
|
|
}
|
|
|
|
transport := middleware.BearerAuth(tokenFunc)(base)
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
|
|
_, err := transport.RoundTrip(req)
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
if !errors.Is(err, tokenErr) {
|
|
t.Errorf("got error %v, want %v", err, tokenErr)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestBasicAuth(t *testing.T) {
|
|
t.Run("sets basic auth header", func(t *testing.T) {
|
|
var capturedReq *http.Request
|
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
|
capturedReq = req
|
|
return okResponse(), nil
|
|
})
|
|
|
|
transport := middleware.BasicAuth("user", "pass")(base)
|
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
|
|
|
_, err := transport.RoundTrip(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
username, password, ok := capturedReq.BasicAuth()
|
|
if !ok {
|
|
t.Fatal("BasicAuth() returned ok=false")
|
|
}
|
|
if username != "user" {
|
|
t.Errorf("username = %q, want %q", username, "user")
|
|
}
|
|
if password != "pass" {
|
|
t.Errorf("password = %q, want %q", password, "pass")
|
|
}
|
|
})
|
|
}
|