Honor RoundTripper contract in middleware; validate incoming X-Request-Id
BearerAuth, BasicAuth and DefaultHeaders mutated the caller's request, which violates the RoundTripper contract and risks races on shared/retried requests; clone before writing headers (matching RequestID). Validate the incoming X-Request-Id (length and character set) before propagating it to logs and the response header, preventing log forging and header splitting from a client-controlled value.
This commit is contained in:
@@ -15,6 +15,9 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
// RoundTrippers must not mutate the caller's request; clone before
|
||||||
|
// setting headers (req.Clone is shallow + a header copy).
|
||||||
|
req = req.Clone(req.Context())
|
||||||
req.Header.Set("Authorization", "Bearer "+token)
|
req.Header.Set("Authorization", "Bearer "+token)
|
||||||
return next.RoundTrip(req)
|
return next.RoundTrip(req)
|
||||||
})
|
})
|
||||||
@@ -26,6 +29,7 @@ func BearerAuth(tokenFunc func(ctx context.Context) (string, error)) Middleware
|
|||||||
func BasicAuth(username, password string) Middleware {
|
func BasicAuth(username, password string) Middleware {
|
||||||
return func(next http.RoundTripper) http.RoundTripper {
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
req.SetBasicAuth(username, password)
|
req.SetBasicAuth(username, password)
|
||||||
return next.RoundTrip(req)
|
return next.RoundTrip(req)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -7,10 +7,17 @@ import "net/http"
|
|||||||
func DefaultHeaders(headers http.Header) Middleware {
|
func DefaultHeaders(headers http.Header) Middleware {
|
||||||
return func(next http.RoundTripper) http.RoundTripper {
|
return func(next http.RoundTripper) http.RoundTripper {
|
||||||
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
return RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
// Clone lazily on the first header we actually add, so that
|
||||||
|
// RoundTrippers never mutate the caller's request.
|
||||||
|
cloned := false
|
||||||
for key, values := range headers {
|
for key, values := range headers {
|
||||||
if req.Header.Get(key) != "" {
|
if req.Header.Get(key) != "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !cloned {
|
||||||
|
req = req.Clone(req.Context())
|
||||||
|
cloned = true
|
||||||
|
}
|
||||||
for _, v := range values {
|
for _, v := range values {
|
||||||
req.Header.Add(key, v)
|
req.Header.Add(key, v)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,9 +9,14 @@ import (
|
|||||||
"git.codelab.vc/pkg/httpx/internal/requestid"
|
"git.codelab.vc/pkg/httpx/internal/requestid"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// maxRequestIDLen bounds the length of a client-supplied request ID that we
|
||||||
|
// are willing to propagate.
|
||||||
|
const maxRequestIDLen = 128
|
||||||
|
|
||||||
// RequestID returns a middleware that assigns a unique request ID to each
|
// RequestID returns a middleware that assigns a unique request ID to each
|
||||||
// request. If the incoming request already has an X-Request-Id header, that
|
// request. If the incoming request carries a valid X-Request-Id header, that
|
||||||
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
|
// value is reused; otherwise (or if the supplied value is empty, too long, or
|
||||||
|
// contains unsafe characters) a new UUID v4 is generated via crypto/rand.
|
||||||
//
|
//
|
||||||
// The request ID is stored in the request context (retrieve with
|
// The request ID is stored in the request context (retrieve with
|
||||||
// RequestIDFromContext) and set on the response X-Request-Id header.
|
// RequestIDFromContext) and set on the response X-Request-Id header.
|
||||||
@@ -19,7 +24,7 @@ func RequestID() Middleware {
|
|||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
id := r.Header.Get("X-Request-Id")
|
id := r.Header.Get("X-Request-Id")
|
||||||
if id == "" {
|
if !validRequestID(id) {
|
||||||
id = newUUID()
|
id = newUUID()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,6 +35,27 @@ func RequestID() Middleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validRequestID reports whether a client-supplied request ID is safe to
|
||||||
|
// propagate: non-empty, within a sane length, and restricted to characters
|
||||||
|
// that cannot forge log lines or split response headers.
|
||||||
|
func validRequestID(id string) bool {
|
||||||
|
if id == "" || len(id) > maxRequestIDLen {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for i := 0; i < len(id); i++ {
|
||||||
|
c := id[i]
|
||||||
|
switch {
|
||||||
|
case c >= 'a' && c <= 'z',
|
||||||
|
c >= 'A' && c <= 'Z',
|
||||||
|
c >= '0' && c <= '9',
|
||||||
|
c == '-', c == '_', c == '.':
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// RequestIDFromContext returns the request ID from the context, or an empty
|
// RequestIDFromContext returns the request ID from the context, or an empty
|
||||||
// string if none is set.
|
// string if none is set.
|
||||||
func RequestIDFromContext(ctx context.Context) string {
|
func RequestIDFromContext(ctx context.Context) string {
|
||||||
|
|||||||
@@ -214,6 +214,36 @@ func TestRequestID(t *testing.T) {
|
|||||||
t.Fatalf("expected empty, got %q", id)
|
t.Fatalf("expected empty, got %q", id)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("rejects unsafe incoming ID", func(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"header injection": "abc\r\nX-Injected: 1",
|
||||||
|
"contains space": "has space",
|
||||||
|
"too long": strings.Repeat("a", 200),
|
||||||
|
}
|
||||||
|
for name, badID := range cases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
var gotID string
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotID = server.RequestIDFromContext(r.Context())
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
mw := server.RequestID()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Request-Id", badID)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if gotID == badID {
|
||||||
|
t.Fatalf("unsafe incoming ID was propagated verbatim: %q", gotID)
|
||||||
|
}
|
||||||
|
if len(gotID) != 36 {
|
||||||
|
t.Fatalf("expected a freshly generated UUID, got %q (len %d)", gotID, len(gotID))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRequestID_UUIDFormat(t *testing.T) {
|
func TestRequestID_UUIDFormat(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user