Honor RoundTripper contract in middleware; validate incoming X-Request-Id
BearerAuth, BasicAuth and DefaultHeaders mutated the caller's request, which violates the RoundTripper contract and risks races on shared/retried requests; clone before writing headers (matching RequestID). Validate the incoming X-Request-Id (length and character set) before propagating it to logs and the response header, preventing log forging and header splitting from a client-controlled value.
This commit is contained in:
@@ -15,6 +15,9 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// RoundTrippers must not mutate the caller's request; clone before
|
||||
// setting headers (req.Clone is shallow + a header copy).
|
||||
req = req.Clone(req.Context())
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
@@ -26,6 +29,7 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
||||
func BasicAuth(username, password string) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
req = req.Clone(req.Context())
|
||||
req.SetBasicAuth(username, password)
|
||||
return next.RoundTrip(req)
|
||||
})
|
||||
|
||||
@@ -7,10 +7,17 @@ import "net/http"
|
||||
func DefaultHeaders(headers http.Header) Middleware {
|
||||
return func(next http.RoundTripper) http.RoundTripper {
|
||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||
// Clone lazily on the first header we actually add, so that
|
||||
// RoundTrippers never mutate the caller's request.
|
||||
cloned := false
|
||||
for key, values := range headers {
|
||||
if req.Header.Get(key) != "" {
|
||||
continue
|
||||
}
|
||||
if !cloned {
|
||||
req = req.Clone(req.Context())
|
||||
cloned = true
|
||||
}
|
||||
for _, v := range values {
|
||||
req.Header.Add(key, v)
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ func Chain(mws ...Middleware) Middleware {
|
||||
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
written bool
|
||||
status int
|
||||
written bool
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||
|
||||
@@ -9,9 +9,14 @@ import (
|
||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||
)
|
||||
|
||||
// maxRequestIDLen bounds the length of a client-supplied request ID that we
|
||||
// are willing to propagate.
|
||||
const maxRequestIDLen = 128
|
||||
|
||||
// RequestID returns a middleware that assigns a unique request ID to each
|
||||
// request. If the incoming request already has an X-Request-Id header, that
|
||||
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
|
||||
// request. If the incoming request carries a valid X-Request-Id header, that
|
||||
// value is reused; otherwise (or if the supplied value is empty, too long, or
|
||||
// contains unsafe characters) a new UUID v4 is generated via crypto/rand.
|
||||
//
|
||||
// The request ID is stored in the request context (retrieve with
|
||||
// RequestIDFromContext) and set on the response X-Request-Id header.
|
||||
@@ -19,7 +24,7 @@ func RequestID() Middleware {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.Header.Get("X-Request-Id")
|
||||
if id == "" {
|
||||
if !validRequestID(id) {
|
||||
id = newUUID()
|
||||
}
|
||||
|
||||
@@ -30,6 +35,27 @@ func RequestID() Middleware {
|
||||
}
|
||||
}
|
||||
|
||||
// validRequestID reports whether a client-supplied request ID is safe to
|
||||
// propagate: non-empty, within a sane length, and restricted to characters
|
||||
// that cannot forge log lines or split response headers.
|
||||
func validRequestID(id string) bool {
|
||||
if id == "" || len(id) > maxRequestIDLen {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < len(id); i++ {
|
||||
c := id[i]
|
||||
switch {
|
||||
case c >= 'a' && c <= 'z',
|
||||
c >= 'A' && c <= 'Z',
|
||||
c >= '0' && c <= '9',
|
||||
c == '-', c == '_', c == '.':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestIDFromContext returns the request ID from the context, or an empty
|
||||
// string if none is set.
|
||||
func RequestIDFromContext(ctx context.Context) string {
|
||||
|
||||
@@ -214,6 +214,36 @@ func TestRequestID(t *testing.T) {
|
||||
t.Fatalf("expected empty, got %q", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects unsafe incoming ID", func(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"header injection": "abc\r\nX-Injected: 1",
|
||||
"contains space": "has space",
|
||||
"too long": strings.Repeat("a", 200),
|
||||
}
|
||||
for name, badID := range cases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
var gotID string
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
gotID = server.RequestIDFromContext(r.Context())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
mw := server.RequestID()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("X-Request-Id", badID)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if gotID == badID {
|
||||
t.Fatalf("unsafe incoming ID was propagated verbatim: %q", gotID)
|
||||
}
|
||||
if len(gotID) != 36 {
|
||||
t.Fatalf("expected a freshly generated UUID, got %q (len %d)", gotID, len(gotID))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestID_UUIDFormat(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user