Honor RoundTripper contract in middleware; validate incoming X-Request-Id

BearerAuth, BasicAuth and DefaultHeaders mutated the caller's request, which
violates the RoundTripper contract and risks races on shared/retried requests;
clone before writing headers (matching RequestID). Validate the incoming
X-Request-Id (length and character set) before propagating it to logs and the
response header, preventing log forging and header splitting from a
client-controlled value.
This commit is contained in:
2026-05-23 13:47:38 +03:00
parent 01478be0dc
commit b5259af73e
5 changed files with 72 additions and 5 deletions

View File

@@ -15,6 +15,9 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
if err != nil { if err != nil {
return nil, err return nil, err
} }
// RoundTrippers must not mutate the caller's request; clone before
// setting headers (req.Clone is shallow + a header copy).
req = req.Clone(req.Context())
req.Header.Set("Authorization", "Bearer "+token) req.Header.Set("Authorization", "Bearer "+token)
return next.RoundTrip(req) return next.RoundTrip(req)
}) })
@@ -26,6 +29,7 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
func BasicAuth(username, password string) Middleware { func BasicAuth(username, password string) Middleware {
return func(next http.RoundTripper) http.RoundTripper { return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.SetBasicAuth(username, password) req.SetBasicAuth(username, password)
return next.RoundTrip(req) return next.RoundTrip(req)
}) })

View File

@@ -7,10 +7,17 @@ import "net/http"
func DefaultHeaders(headers http.Header) Middleware { func DefaultHeaders(headers http.Header) Middleware {
return func(next http.RoundTripper) http.RoundTripper { return func(next http.RoundTripper) http.RoundTripper {
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
// Clone lazily on the first header we actually add, so that
// RoundTrippers never mutate the caller's request.
cloned := false
for key, values := range headers { for key, values := range headers {
if req.Header.Get(key) != "" { if req.Header.Get(key) != "" {
continue continue
} }
if !cloned {
req = req.Clone(req.Context())
cloned = true
}
for _, v := range values { for _, v := range values {
req.Header.Add(key, v) req.Header.Add(key, v)
} }

View File

@@ -25,8 +25,8 @@ func Chain(mws ...Middleware) Middleware {
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.). // underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
type statusWriter struct { type statusWriter struct {
http.ResponseWriter http.ResponseWriter
status int status int
written bool written bool
} }
// WriteHeader captures the status code and delegates to the underlying writer. // WriteHeader captures the status code and delegates to the underlying writer.

View File

@@ -9,9 +9,14 @@ import (
"git.codelab.vc/pkg/httpx/internal/requestid" "git.codelab.vc/pkg/httpx/internal/requestid"
) )
// maxRequestIDLen bounds the length of a client-supplied request ID that we
// are willing to propagate.
const maxRequestIDLen = 128
// RequestID returns a middleware that assigns a unique request ID to each // RequestID returns a middleware that assigns a unique request ID to each
// request. If the incoming request already has an X-Request-Id header, that // request. If the incoming request carries a valid X-Request-Id header, that
// value is used. Otherwise a new UUID v4 is generated via crypto/rand. // value is reused; otherwise (or if the supplied value is empty, too long, or
// contains unsafe characters) a new UUID v4 is generated via crypto/rand.
// //
// The request ID is stored in the request context (retrieve with // The request ID is stored in the request context (retrieve with
// RequestIDFromContext) and set on the response X-Request-Id header. // RequestIDFromContext) and set on the response X-Request-Id header.
@@ -19,7 +24,7 @@ func RequestID() Middleware {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-Id") id := r.Header.Get("X-Request-Id")
if id == "" { if !validRequestID(id) {
id = newUUID() id = newUUID()
} }
@@ -30,6 +35,27 @@ func RequestID() Middleware {
} }
} }
// validRequestID reports whether a client-supplied request ID is safe to
// propagate: non-empty, within a sane length, and restricted to characters
// that cannot forge log lines or split response headers.
func validRequestID(id string) bool {
if id == "" || len(id) > maxRequestIDLen {
return false
}
for i := 0; i < len(id); i++ {
c := id[i]
switch {
case c >= 'a' && c <= 'z',
c >= 'A' && c <= 'Z',
c >= '0' && c <= '9',
c == '-', c == '_', c == '.':
default:
return false
}
}
return true
}
// RequestIDFromContext returns the request ID from the context, or an empty // RequestIDFromContext returns the request ID from the context, or an empty
// string if none is set. // string if none is set.
func RequestIDFromContext(ctx context.Context) string { func RequestIDFromContext(ctx context.Context) string {

View File

@@ -214,6 +214,36 @@ func TestRequestID(t *testing.T) {
t.Fatalf("expected empty, got %q", id) t.Fatalf("expected empty, got %q", id)
} }
}) })
t.Run("rejects unsafe incoming ID", func(t *testing.T) {
cases := map[string]string{
"header injection": "abc\r\nX-Injected: 1",
"contains space": "has space",
"too long": strings.Repeat("a", 200),
}
for name, badID := range cases {
t.Run(name, func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Request-Id", badID)
mw.ServeHTTP(w, req)
if gotID == badID {
t.Fatalf("unsafe incoming ID was propagated verbatim: %q", gotID)
}
if len(gotID) != 36 {
t.Fatalf("expected a freshly generated UUID, got %q (len %d)", gotID, len(gotID))
}
})
}
})
} }
func TestRequestID_UUIDFormat(t *testing.T) { func TestRequestID_UUIDFormat(t *testing.T) {