BearerAuth, BasicAuth, and DefaultHeaders mutated the caller's request, which violates the RoundTripper contract and risks data races on shared or retried requests; they now clone the request before writing headers, as RequestID already does. Separately, the incoming X-Request-Id is now validated (length and character set) before it is propagated to logs and the response header, preventing log forging and header splitting via a client-controlled value.
78 lines
2.1 KiB
Go
78 lines
2.1 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"git.codelab.vc/pkg/httpx/internal/requestid"
|
|
)
|
|
|
|
const (
	// maxRequestIDLen is the longest client-supplied X-Request-Id (in bytes)
	// that will be accepted and propagated; longer values are replaced.
	maxRequestIDLen = 128
)
|
|
|
|
// RequestID returns a middleware that assigns a unique request ID to each
|
|
// request. If the incoming request carries a valid X-Request-Id header, that
|
|
// value is reused; otherwise (or if the supplied value is empty, too long, or
|
|
// contains unsafe characters) a new UUID v4 is generated via crypto/rand.
|
|
//
|
|
// The request ID is stored in the request context (retrieve with
|
|
// RequestIDFromContext) and set on the response X-Request-Id header.
|
|
func RequestID() Middleware {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
id := r.Header.Get("X-Request-Id")
|
|
if !validRequestID(id) {
|
|
id = newUUID()
|
|
}
|
|
|
|
ctx := requestid.NewContext(r.Context(), id)
|
|
w.Header().Set("X-Request-Id", id)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// validRequestID reports whether a client-supplied request ID is safe to
|
|
// propagate: non-empty, within a sane length, and restricted to characters
|
|
// that cannot forge log lines or split response headers.
|
|
func validRequestID(id string) bool {
|
|
if id == "" || len(id) > maxRequestIDLen {
|
|
return false
|
|
}
|
|
for i := 0; i < len(id); i++ {
|
|
c := id[i]
|
|
switch {
|
|
case c >= 'a' && c <= 'z',
|
|
c >= 'A' && c <= 'Z',
|
|
c >= '0' && c <= '9',
|
|
c == '-', c == '_', c == '.':
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// RequestIDFromContext reports the request ID that the RequestID middleware
// attached to ctx. The lookup is delegated to the internal requestid
// package; if no ID was stored, the result is the empty string.
func RequestIDFromContext(ctx context.Context) string {
	id := requestid.FromContext(ctx)
	return id
}
|
|
|
|
// newUUID returns a random (version 4) UUID in the canonical 8-4-4-4-12
// lowercase-hex form. Its 16 bytes are read from crypto/rand.
func newUUID() string {
	var b [16]byte

	// The error is deliberately discarded. Since Go 1.24, crypto/rand.Read is
	// documented never to return an error; the program aborts instead. On
	// older toolchains a failure would go unnoticed and leave some or all of
	// b zeroed, producing a non-unique ID.
	// NOTE(review): confirm the module's minimum Go version.
	_, _ = rand.Read(b[:])

	b[6] = b[6]&^0xf0 | 0x40 // high nibble of byte 6 := 0100 (version 4)
	b[8] = b[8]&^0xc0 | 0x80 // top two bits of byte 8 := 10 (RFC 4122 variant)

	timeLow, timeMid, timeHi, clockSeq, node := b[:4], b[4:6], b[6:8], b[8:10], b[10:]
	return fmt.Sprintf("%x-%x-%x-%x-%x", timeLow, timeMid, timeHi, clockSeq, node)
}
|