Add standard middlewares: logging, headers, auth, and panic recovery
- Logging: structured slog output with method, URL, status, duration - DefaultHeaders/UserAgent: inject headers without overwriting existing - BearerAuth/BasicAuth: per-request token resolution and static credentials - Recovery: catches panics in the RoundTripper chain
This commit is contained in:
33
middleware/auth.go
Normal file
33
middleware/auth.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BearerAuth returns a middleware that sets the Authorization header to a
|
||||||
|
// Bearer token obtained by calling tokenFunc on each request. If tokenFunc
|
||||||
|
// returns an error, the request is not sent and the error is returned.
|
||||||
|
func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
token, err := tokenFunc(req.Context())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BasicAuth returns a middleware that sets HTTP Basic Authentication
|
||||||
|
// credentials on every outgoing request.
|
||||||
|
func BasicAuth(username, password string) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req.SetBasicAuth(username, password)
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
89
middleware/auth_test.go
Normal file
89
middleware/auth_test.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBearerAuth(t *testing.T) {
|
||||||
|
t.Run("sets authorization header", func(t *testing.T) {
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
tokenFunc := func(_ context.Context) (string, error) {
|
||||||
|
return "my-secret-token", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := "Bearer my-secret-token"
|
||||||
|
if got := captured.Get("Authorization"); got != want {
|
||||||
|
t.Errorf("Authorization = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error when tokenFunc fails", func(t *testing.T) {
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
t.Fatal("base transport should not be called")
|
||||||
|
return nil, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
tokenErr := errors.New("token expired")
|
||||||
|
tokenFunc := func(_ context.Context) (string, error) {
|
||||||
|
return "", tokenErr
|
||||||
|
}
|
||||||
|
|
||||||
|
transport := middleware.BearerAuth(tokenFunc)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, tokenErr) {
|
||||||
|
t.Errorf("got error %v, want %v", err, tokenErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBasicAuth(t *testing.T) {
|
||||||
|
t.Run("sets basic auth header", func(t *testing.T) {
|
||||||
|
var capturedReq *http.Request
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
capturedReq = req
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.BasicAuth("user", "pass")(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
username, password, ok := capturedReq.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("BasicAuth() returned ok=false")
|
||||||
|
}
|
||||||
|
if username != "user" {
|
||||||
|
t.Errorf("username = %q, want %q", username, "user")
|
||||||
|
}
|
||||||
|
if password != "pass" {
|
||||||
|
t.Errorf("password = %q, want %q", password, "pass")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
29
middleware/headers.go
Normal file
29
middleware/headers.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// DefaultHeaders returns a middleware that adds the given headers to every
|
||||||
|
// outgoing request. Existing headers on the request are not overwritten.
|
||||||
|
func DefaultHeaders(headers http.Header) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
for key, values := range headers {
|
||||||
|
if req.Header.Get(key) != "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, v := range values {
|
||||||
|
req.Header.Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserAgent returns a middleware that sets the User-Agent header on every
|
||||||
|
// outgoing request, unless one is already present.
|
||||||
|
func UserAgent(ua string) Middleware {
|
||||||
|
return DefaultHeaders(http.Header{
|
||||||
|
"User-Agent": {ua},
|
||||||
|
})
|
||||||
|
}
|
||||||
107
middleware/headers_test.go
Normal file
107
middleware/headers_test.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultHeaders(t *testing.T) {
|
||||||
|
t.Run("adds headers without overwriting existing", func(t *testing.T) {
|
||||||
|
defaults := http.Header{
|
||||||
|
"X-Custom": {"default-value"},
|
||||||
|
"X-Untouched": {"from-middleware"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.DefaultHeaders(defaults)(base)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
req.Header.Set("X-Custom", "request-value")
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("X-Custom"); got != "request-value" {
|
||||||
|
t.Errorf("X-Custom = %q, want %q (should not overwrite)", got, "request-value")
|
||||||
|
}
|
||||||
|
if got := captured.Get("X-Untouched"); got != "from-middleware" {
|
||||||
|
t.Errorf("X-Untouched = %q, want %q", got, "from-middleware")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("adds headers when absent", func(t *testing.T) {
|
||||||
|
defaults := http.Header{
|
||||||
|
"Accept": {"application/json"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.DefaultHeaders(defaults)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("Accept"); got != "application/json" {
|
||||||
|
t.Errorf("Accept = %q, want %q", got, "application/json")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserAgent(t *testing.T) {
|
||||||
|
t.Run("sets user agent header", func(t *testing.T) {
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.UserAgent("httpx/1.0")(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("User-Agent"); got != "httpx/1.0" {
|
||||||
|
t.Errorf("User-Agent = %q, want %q", got, "httpx/1.0")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("does not overwrite existing user agent", func(t *testing.T) {
|
||||||
|
var captured http.Header
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
captured = req.Header.Clone()
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.UserAgent("httpx/1.0")(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
req.Header.Set("User-Agent", "custom-agent")
|
||||||
|
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := captured.Get("User-Agent"); got != "custom-agent" {
|
||||||
|
t.Errorf("User-Agent = %q, want %q", got, "custom-agent")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
38
middleware/logging.go
Normal file
38
middleware/logging.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logging returns a middleware that logs each request's method, URL, status
|
||||||
|
// code, duration, and error (if any) using the provided structured logger.
|
||||||
|
// Successful responses are logged at Info level; errors at Error level.
|
||||||
|
func Logging(logger *slog.Logger) Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
attrs := []slog.Attr{
|
||||||
|
slog.String("method", req.Method),
|
||||||
|
slog.String("url", req.URL.String()),
|
||||||
|
slog.Duration("duration", duration),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
attrs = append(attrs, slog.String("error", err.Error()))
|
||||||
|
logger.LogAttrs(req.Context(), slog.LevelError, "request failed", attrs...)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs = append(attrs, slog.Int("status", resp.StatusCode))
|
||||||
|
logger.LogAttrs(req.Context(), slog.LevelInfo, "request completed", attrs...)
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
124
middleware/logging_test.go
Normal file
124
middleware/logging_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// captureHandler is a slog.Handler that captures log records for inspection.
|
||||||
|
type captureHandler struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
records []slog.Record
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *captureHandler) Enabled(_ context.Context, _ slog.Level) bool { return true }
|
||||||
|
|
||||||
|
func (h *captureHandler) Handle(_ context.Context, r slog.Record) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.records = append(h.records, r)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *captureHandler) WithAttrs(_ []slog.Attr) slog.Handler { return h }
|
||||||
|
func (h *captureHandler) WithGroup(_ string) slog.Handler { return h }
|
||||||
|
|
||||||
|
func (h *captureHandler) lastRecord() slog.Record {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
return h.records[len(h.records)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogging(t *testing.T) {
|
||||||
|
t.Run("logs method url status duration on success", func(t *testing.T) {
|
||||||
|
handler := &captureHandler{}
|
||||||
|
logger := slog.New(handler)
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Logging(logger)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://example.com/api", nil)
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := handler.lastRecord()
|
||||||
|
if rec.Level != slog.LevelInfo {
|
||||||
|
t.Errorf("got level %v, want %v", rec.Level, slog.LevelInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := map[string]string{}
|
||||||
|
rec.Attrs(func(a slog.Attr) bool {
|
||||||
|
attrs[a.Key] = a.Value.String()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if attrs["method"] != "POST" {
|
||||||
|
t.Errorf("method = %q, want %q", attrs["method"], "POST")
|
||||||
|
}
|
||||||
|
if attrs["url"] != "http://example.com/api" {
|
||||||
|
t.Errorf("url = %q, want %q", attrs["url"], "http://example.com/api")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["status"]; !ok {
|
||||||
|
t.Error("missing status attribute")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["duration"]; !ok {
|
||||||
|
t.Error("missing duration attribute")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("logs error on failure", func(t *testing.T) {
|
||||||
|
handler := &captureHandler{}
|
||||||
|
logger := slog.New(handler)
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, errors.New("connection refused")
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Logging(logger)(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com/fail", nil)
|
||||||
|
_, err := transport.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
rec := handler.lastRecord()
|
||||||
|
if rec.Level != slog.LevelError {
|
||||||
|
t.Errorf("got level %v, want %v", rec.Level, slog.LevelError)
|
||||||
|
}
|
||||||
|
|
||||||
|
attrs := map[string]string{}
|
||||||
|
rec.Attrs(func(a slog.Attr) bool {
|
||||||
|
attrs[a.Key] = a.Value.String()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if attrs["error"] != "connection refused" {
|
||||||
|
t.Errorf("error = %q, want %q", attrs["error"], "connection refused")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["method"]; !ok {
|
||||||
|
t.Error("missing method attribute")
|
||||||
|
}
|
||||||
|
if _, ok := attrs["url"]; !ok {
|
||||||
|
t.Error("missing url attribute")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
115
middleware/middleware_test.go
Normal file
115
middleware/middleware_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper {
|
||||||
|
return middleware.RoundTripperFunc(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func okResponse() *http.Response {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Body: io.NopCloser(strings.NewReader("ok")),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChain(t *testing.T) {
|
||||||
|
t.Run("applies middlewares in correct order", func(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
mwA := func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "A-before")
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
order = append(order, "A-after")
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
mwB := func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "B-before")
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
order = append(order, "B-after")
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
mwC := func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "C-before")
|
||||||
|
resp, err := next.RoundTrip(req)
|
||||||
|
order = append(order, "C-after")
|
||||||
|
return resp, err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
order = append(order, "base")
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
chained := middleware.Chain(mwA, mwB, mwC)(base)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
_, err := chained.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"A-before", "B-before", "C-before", "base", "C-after", "B-after", "A-after"}
|
||||||
|
if len(order) != len(expected) {
|
||||||
|
t.Fatalf("got %v, want %v", order, expected)
|
||||||
|
}
|
||||||
|
for i, v := range expected {
|
||||||
|
if order[i] != v {
|
||||||
|
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty chain returns base transport", func(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
called = true
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
chained := middleware.Chain()(base)
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
_, err := chained.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Fatal("base transport was not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripperFunc(t *testing.T) {
|
||||||
|
t.Run("implements RoundTripper", func(t *testing.T) {
|
||||||
|
fn := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
var rt http.RoundTripper = fn
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
resp, err := rt.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
22
middleware/recovery.go
Normal file
22
middleware/recovery.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Recovery returns a middleware that recovers from panics in the inner
|
||||||
|
// RoundTripper chain. A recovered panic is converted to an error wrapping
|
||||||
|
// the panic value.
|
||||||
|
func Recovery() Middleware {
|
||||||
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
|
return RoundTripperFunc(func(req *http.Request) (resp *http.Response, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic recovered in round trip: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return next.RoundTrip(req)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
51
middleware/recovery_test.go
Normal file
51
middleware/recovery_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecovery(t *testing.T) {
|
||||||
|
t.Run("recovers from panic and returns error", func(t *testing.T) {
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
panic("something went wrong")
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Recovery()(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
t.Errorf("expected nil response, got %v", resp)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "panic recovered") {
|
||||||
|
t.Errorf("error = %q, want it to contain %q", err.Error(), "panic recovered")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "something went wrong") {
|
||||||
|
t.Errorf("error = %q, want it to contain %q", err.Error(), "something went wrong")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passes through normal responses", func(t *testing.T) {
|
||||||
|
base := mockTransport(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return okResponse(), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
transport := middleware.Recovery()(base)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||||
|
|
||||||
|
resp, err := transport.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("got status %d, want %d", resp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user