diff --git a/internal/requestid/requestid.go b/internal/requestid/requestid.go new file mode 100644 index 0000000..a6bc8ef --- /dev/null +++ b/internal/requestid/requestid.go @@ -0,0 +1,19 @@ +// Package requestid provides a shared context key for request IDs, +// allowing both client and server packages to access request IDs +// without circular imports. +package requestid + +import "context" + +type key struct{} + +// NewContext returns a context with the given request ID. +func NewContext(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, key{}, id) +} + +// FromContext returns the request ID from ctx, or empty string if not set. +func FromContext(ctx context.Context) string { + id, _ := ctx.Value(key{}).(string) + return id +} diff --git a/middleware/requestid.go b/middleware/requestid.go new file mode 100644 index 0000000..67d41be --- /dev/null +++ b/middleware/requestid.go @@ -0,0 +1,23 @@ +package middleware + +import ( + "net/http" + + "git.codelab.vc/pkg/httpx/internal/requestid" +) + +// RequestID returns a middleware that propagates the request ID from the +// request context to the outgoing X-Request-Id header. This pairs with +// the server.RequestID middleware: the server stores the ID in the context, +// and the client middleware forwards it to downstream services. +func RequestID() Middleware { + return func(next http.RoundTripper) http.RoundTripper { + return RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if id := requestid.FromContext(req.Context()); id != "" { + req = req.Clone(req.Context()) + req.Header.Set("X-Request-Id", id) + } + return next.RoundTrip(req) + }) + } +} diff --git a/middleware/requestid_test.go b/middleware/requestid_test.go new file mode 100644 index 0000000..4229592 --- /dev/null +++ b/middleware/requestid_test.go @@ -0,0 +1,69 @@ +package middleware_test + +import ( + "context" + "net/http" + "testing" + + "git.codelab.vc/pkg/httpx/internal/requestid" + "git.codelab.vc/pkg/httpx/middleware" +) + +func TestRequestID(t *testing.T) { + t.Run("propagates ID from context", func(t *testing.T) { + var gotHeader string + base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + gotHeader = req.Header.Get("X-Request-Id") + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + mw := middleware.RequestID()(base) + + ctx := requestid.NewContext(context.Background(), "test-id-123") + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil) + _, err := mw.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotHeader != "test-id-123" { + t.Fatalf("X-Request-Id = %q, want %q", gotHeader, "test-id-123") + } + }) + + t.Run("no ID in context skips header", func(t *testing.T) { + var gotHeader string + base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + gotHeader = req.Header.Get("X-Request-Id") + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + mw := middleware.RequestID()(base) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://example.com", nil) + _, err := mw.RoundTrip(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotHeader != "" { + t.Fatalf("expected no X-Request-Id header, got %q", gotHeader) + } + }) + + t.Run("does not mutate original request", func(t *testing.T) { + base := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil + }) + + mw := middleware.RequestID()(base) + + ctx := requestid.NewContext(context.Background(), "test-id") + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com", nil) + _, _ = mw.RoundTrip(req) + + if req.Header.Get("X-Request-Id") != "" { + t.Fatal("original request was mutated") + } + }) +} diff --git a/server/middleware_requestid.go b/server/middleware_requestid.go index 081a5a7..5cfdb26 100644 --- a/server/middleware_requestid.go +++ b/server/middleware_requestid.go @@ -5,9 +5,9 @@ import ( "crypto/rand" "fmt" "net/http" -) -type requestIDKey struct{} + "git.codelab.vc/pkg/httpx/internal/requestid" +) // RequestID returns a middleware that assigns a unique request ID to each // request. If the incoming request already has an X-Request-Id header, that @@ -23,7 +23,7 @@ func RequestID() Middleware { id = newUUID() } - ctx := context.WithValue(r.Context(), requestIDKey{}, id) + ctx := requestid.NewContext(r.Context(), id) w.Header().Set("X-Request-Id", id) next.ServeHTTP(w, r.WithContext(ctx)) }) @@ -33,8 +33,7 @@ func RequestID() Middleware { // RequestIDFromContext returns the request ID from the context, or an empty // string if none is set. func RequestIDFromContext(ctx context.Context) string { - id, _ := ctx.Value(requestIDKey{}).(string) - return id + return requestid.FromContext(ctx) } // newUUID generates a UUID v4 string using crypto/rand.