package middleware_test import ( "io" "net/http" "strings" "testing" "git.codelab.vc/pkg/httpx/middleware" ) func mockTransport(fn func(*http.Request) (*http.Response, error)) http.RoundTripper { return middleware.RoundTripperFunc(fn) } func okResponse() *http.Response { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("ok")), Header: make(http.Header), } } func TestChain(t *testing.T) { t.Run("applies middlewares in correct order", func(t *testing.T) { var order []string mwA := func(next http.RoundTripper) http.RoundTripper { return mockTransport(func(req *http.Request) (*http.Response, error) { order = append(order, "A-before") resp, err := next.RoundTrip(req) order = append(order, "A-after") return resp, err }) } mwB := func(next http.RoundTripper) http.RoundTripper { return mockTransport(func(req *http.Request) (*http.Response, error) { order = append(order, "B-before") resp, err := next.RoundTrip(req) order = append(order, "B-after") return resp, err }) } mwC := func(next http.RoundTripper) http.RoundTripper { return mockTransport(func(req *http.Request) (*http.Response, error) { order = append(order, "C-before") resp, err := next.RoundTrip(req) order = append(order, "C-after") return resp, err }) } base := mockTransport(func(req *http.Request) (*http.Response, error) { order = append(order, "base") return okResponse(), nil }) chained := middleware.Chain(mwA, mwB, mwC)(base) req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) _, err := chained.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } expected := []string{"A-before", "B-before", "C-before", "base", "C-after", "B-after", "A-after"} if len(order) != len(expected) { t.Fatalf("got %v, want %v", order, expected) } for i, v := range expected { if order[i] != v { t.Fatalf("order[%d] = %q, want %q", i, order[i], v) } } }) t.Run("empty chain returns base transport", func(t *testing.T) { called := false base := mockTransport(func(req *http.Request) (*http.Response, error) { called = true return okResponse(), nil }) chained := middleware.Chain()(base) req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) _, err := chained.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if !called { t.Fatal("base transport was not called") } }) } func TestRoundTripperFunc(t *testing.T) { t.Run("implements RoundTripper", func(t *testing.T) { fn := middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { return okResponse(), nil }) var rt http.RoundTripper = fn req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) resp, err := rt.RoundTrip(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp.StatusCode != http.StatusOK { t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK) } }) }