Add standard middlewares: logging, headers, auth, and panic recovery
- Logging: structured slog output with method, URL, status, duration - DefaultHeaders/UserAgent: inject headers without overwriting existing - BearerAuth/BasicAuth: per-request token resolution and static credentials - Recovery: catches panics in the RoundTripper chain
This commit is contained in:
89
middleware/auth_test.go
Normal file
89
middleware/auth_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/middleware"
|
||||
)
|
||||
|
||||
func TestBearerAuth(t *testing.T) {
|
||||
t.Run("sets authorization header", func(t *testing.T) {
|
||||
var captured http.Header
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
captured = req.Header.Clone()
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
tokenFunc := func(_ context.Context) (string, error) {
|
||||
return "my-secret-token", nil
|
||||
}
|
||||
|
||||
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
want := "Bearer my-secret-token"
|
||||
if got := captured.Get("Authorization"); got != want {
|
||||
t.Errorf("Authorization = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when tokenFunc fails", func(t *testing.T) {
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
t.Fatal("base transport should not be called")
|
||||
return nil, nil
|
||||
})
|
||||
|
||||
tokenErr := errors.New("token expired")
|
||||
tokenFunc := func(_ context.Context) (string, error) {
|
||||
return "", tokenErr
|
||||
}
|
||||
|
||||
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, tokenErr) {
|
||||
t.Errorf("got error %v, want %v", err, tokenErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
t.Run("sets basic auth header", func(t *testing.T) {
|
||||
var capturedReq *http.Request
|
||||
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||
capturedReq = req
|
||||
return okResponse(), nil
|
||||
})
|
||||
|
||||
transport := middleware.BasicAuth("user", "pass")(base)
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
|
||||
_, err := transport.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
username, password, ok := capturedReq.BasicAuth()
|
||||
if !ok {
|
||||
t.Fatal("BasicAuth() returned ok=false")
|
||||
}
|
||||
if username != "user" {
|
||||
t.Errorf("username = %q, want %q", username, "user")
|
||||
}
|
||||
if password != "pass" {
|
||||
t.Errorf("password = %q, want %q", password, "pass")
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user