diff --git a/middleware/auth.go b/middleware/auth.go index 0564a40..55ad6b7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -15,6 +15,9 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware if err != nil { return nil, err } + // RoundTrippers must not mutate the caller's request; clone before + // setting headers (req.Clone is shallow + a header copy). + req = req.Clone(req.Context()) req.Header.Set("Authorization", "Bearer "+token) return next.RoundTrip(req) }) @@ -26,6 +29,7 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware func BasicAuth(username, password string) Middleware { return func(next http.RoundTripper) http.RoundTripper { return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) req.SetBasicAuth(username, password) return next.RoundTrip(req) }) diff --git a/middleware/headers.go b/middleware/headers.go index 895f890..00b6567 100644 --- a/middleware/headers.go +++ b/middleware/headers.go @@ -7,10 +7,17 @@ import "net/http" func DefaultHeaders(headers http.Header) Middleware { return func(next http.RoundTripper) http.RoundTripper { return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + // Clone lazily on the first header we actually add, so that + // RoundTrippers never mutate the caller's request. + cloned := false for key, values := range headers { if req.Header.Get(key) != "" { continue } + if !cloned { + req = req.Clone(req.Context()) + cloned = true + } for _, v := range values { req.Header.Add(key, v) } diff --git a/server/middleware.go b/server/middleware.go index 9b7b79f..78300aa 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -25,8 +25,8 @@ func Chain(mws ...Middleware) Middleware { // underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.). type statusWriter struct { http.ResponseWriter - status int - written bool + status int + written bool } // WriteHeader captures the status code and delegates to the underlying writer. diff --git a/server/middleware_requestid.go b/server/middleware_requestid.go index 5cfdb26..4a005b0 100644 --- a/server/middleware_requestid.go +++ b/server/middleware_requestid.go @@ -9,9 +9,14 @@ import ( "git.codelab.vc/pkg/httpx/internal/requestid" ) +// maxRequestIDLen bounds the length of a client-supplied request ID that we +// are willing to propagate. +const maxRequestIDLen = 128 + // RequestID returns a middleware that assigns a unique request ID to each -// request. If the incoming request already has an X-Request-Id header, that -// value is used. Otherwise a new UUID v4 is generated via crypto/rand. +// request. If the incoming request carries a valid X-Request-Id header, that +// value is reused; otherwise (or if the supplied value is empty, too long, or +// contains unsafe characters) a new UUID v4 is generated via crypto/rand. // // The request ID is stored in the request context (retrieve with // RequestIDFromContext) and set on the response X-Request-Id header. @@ -19,7 +24,7 @@ func RequestID() Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { id := r.Header.Get("X-Request-Id") - if id == "" { + if !validRequestID(id) { id = newUUID() } @@ -30,6 +35,27 @@ func RequestID() Middleware { } } +// validRequestID reports whether a client-supplied request ID is safe to +// propagate: non-empty, within a sane length, and restricted to characters +// that cannot forge log lines or split response headers. +func validRequestID(id string) bool { + if id == "" || len(id) > maxRequestIDLen { + return false + } + for i := 0; i < len(id); i++ { + c := id[i] + switch { + case c >= 'a' && c <= 'z', + c >= 'A' && c <= 'Z', + c >= '0' && c <= '9', + c == '-', c == '_', c == '.': + default: + return false + } + } + return true +} + // RequestIDFromContext returns the request ID from the context, or an empty // string if none is set. func RequestIDFromContext(ctx context.Context) string { diff --git a/server/middleware_test.go b/server/middleware_test.go index 0419885..c9d61bd 100644 --- a/server/middleware_test.go +++ b/server/middleware_test.go @@ -214,6 +214,36 @@ func TestRequestID(t *testing.T) { t.Fatalf("expected empty, got %q", id) } }) + + t.Run("rejects unsafe incoming ID", func(t *testing.T) { + cases := map[string]string{ + "header injection": "abc\r\nX-Injected: 1", + "contains space": "has space", + "too long": strings.Repeat("a", 200), + } + for name, badID := range cases { + t.Run(name, func(t *testing.T) { + var gotID string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotID = server.RequestIDFromContext(r.Context()) + w.WriteHeader(http.StatusOK) + }) + mw := server.RequestID()(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-Id", badID) + mw.ServeHTTP(w, req) + + if gotID == badID { + t.Fatalf("unsafe incoming ID was propagated verbatim: %q", gotID) + } + if len(gotID) != 36 { + t.Fatalf("expected a freshly generated UUID, got %q (len %d)", gotID, len(gotID)) + } + }) + } + }) } func TestRequestID_UUIDFormat(t *testing.T) {