package middleware_test import ( "net/http" "testing" "git.codelab.vc/pkg/httpx/middleware" ) func TestDefaultHeaders(t *testing.T) { t.Run("adds headers without overwriting existing", func(t *testing.T) { defaults := http.Header{ "X-Custom": {"default-value"}, "X-Untouched": {"from-middleware"}, } var captured http.Header base := mockTransport(func(req *http.Request) (*http.Response, error) { captured = req.Header.Clone() return okResponse(), nil }) transport := middleware.DefaultHeaders(defaults)(base) req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.Header.Set("X-Custom", "request-value") _, err := transport.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if got := captured.Get("X-Custom"); got != "request-value" { t.Errorf("X-Custom = %q, want %q (should not overwrite)", got, "request-value") } if got := captured.Get("X-Untouched"); got != "from-middleware" { t.Errorf("X-Untouched = %q, want %q", got, "from-middleware") } }) t.Run("adds headers when absent", func(t *testing.T) { defaults := http.Header{ "Accept": {"application/json"}, } var captured http.Header base := mockTransport(func(req *http.Request) (*http.Response, error) { captured = req.Header.Clone() return okResponse(), nil }) transport := middleware.DefaultHeaders(defaults)(base) req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) _, err := transport.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if got := captured.Get("Accept"); got != "application/json" { t.Errorf("Accept = %q, want %q", got, "application/json") } }) } func TestUserAgent(t *testing.T) { t.Run("sets user agent header", func(t *testing.T) { var captured http.Header base := mockTransport(func(req *http.Request) (*http.Response, error) { captured = req.Header.Clone() return okResponse(), nil }) transport := middleware.UserAgent("httpx/1.0")(base) req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) _, err := transport.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if got := captured.Get("User-Agent"); got != "httpx/1.0" { t.Errorf("User-Agent = %q, want %q", got, "httpx/1.0") } }) t.Run("does not overwrite existing user agent", func(t *testing.T) { var captured http.Header base := mockTransport(func(req *http.Request) (*http.Response, error) { captured = req.Header.Clone() return okResponse(), nil }) transport := middleware.UserAgent("httpx/1.0")(base) req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) req.Header.Set("User-Agent", "custom-agent") _, err := transport.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if got := captured.Get("User-Agent"); got != "custom-agent" { t.Errorf("User-Agent = %q, want %q", got, "custom-agent") } }) }