package httpx_test

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"git.codelab.vc/pkg/httpx"
	"git.codelab.vc/pkg/httpx/balancer"
	"git.codelab.vc/pkg/httpx/middleware"
	"git.codelab.vc/pkg/httpx/retry"
)

// TestClient_Get checks that Client.Get sends a GET request and exposes the
// response status code and body.
func TestClient_Get(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			t.Errorf("expected GET, got %s", r.Method)
		}
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, "hello")
	}))
	defer srv.Close()

	client := httpx.New()
	resp, err := client.Get(context.Background(), srv.URL+"/test")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Always release the response, as TestClient_BalancerIntegration does.
	// NOTE(review): assumes Close after String is harmless — confirm in httpx.
	defer resp.Close()

	body, err := resp.String()
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if body != "hello" {
		t.Errorf("expected body %q, got %q", "hello", body)
	}
	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}
}

// TestClient_Post checks that Client.Post sends a POST request carrying the
// given body and exposes the 201 response.
func TestClient_Post(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("expected POST, got %s", r.Method)
		}
		b, err := io.ReadAll(r.Body)
		if err != nil {
			// Report instead of silently comparing a truncated body.
			t.Errorf("reading request body: %v", err)
		}
		if string(b) != "request-body" {
			t.Errorf("expected body %q, got %q", "request-body", string(b))
		}
		w.WriteHeader(http.StatusCreated)
		fmt.Fprint(w, "created")
	}))
	defer srv.Close()

	client := httpx.New()
	resp, err := client.Post(context.Background(), srv.URL+"/items", strings.NewReader("request-body"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusCreated {
		t.Errorf("expected status 201, got %d", resp.StatusCode)
	}
	body, err := resp.String()
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if body != "created" {
		t.Errorf("expected body %q, got %q", "created", body)
	}
}

// TestClient_BaseURL checks that a relative request path is resolved against
// the base URL configured with WithBaseURL.
func TestClient_BaseURL(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/api/v1/users" {
			t.Errorf("expected path /api/v1/users, got %s", r.URL.Path)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	client := httpx.New(httpx.WithBaseURL(srv.URL + "/api/v1"))

	// Use a relative path (no scheme/host).
	resp, err := client.Get(context.Background(), "/users")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}
}

// TestClient_WithMiddleware checks that a RoundTripper middleware registered
// via WithMiddleware runs on outgoing requests.
func TestClient_WithMiddleware(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		val := r.Header.Get("X-Custom-Header")
		if val != "test-value" {
			t.Errorf("expected header X-Custom-Header=%q, got %q", "test-value", val)
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer srv.Close()

	addHeader := func(next http.RoundTripper) http.RoundTripper {
		return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
			// Clone so the caller's request is not mutated (RoundTripper contract).
			req = req.Clone(req.Context())
			req.Header.Set("X-Custom-Header", "test-value")
			return next.RoundTrip(req)
		})
	}

	client := httpx.New(httpx.WithMiddleware(addHeader))
	resp, err := client.Get(context.Background(), srv.URL+"/test")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}
}

// TestClient_RetryIntegration checks the retry behavior configured with
// WithRetry. The server answers 503 twice and then 200, so with three allowed
// attempts the call should succeed after exactly three requests.
func TestClient_RetryIntegration(t *testing.T) {
	var calls atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := calls.Add(1)
		if n <= 2 {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, "success")
	}))
	defer srv.Close()

	client := httpx.New(
		httpx.WithRetry(
			retry.WithMaxAttempts(3),
			retry.WithBackoff(retry.ConstantBackoff(1*time.Millisecond)),
		),
	)

	resp, err := client.Get(context.Background(), srv.URL+"/flaky")
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("expected status 200, got %d", resp.StatusCode)
	}
	body, err := resp.String()
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if body != "success" {
		t.Errorf("expected body %q, got %q", "success", body)
	}
	if got := calls.Load(); got != 3 {
		t.Errorf("expected 3 total requests, got %d", got)
	}
}

// TestClient_BalancerIntegration checks that relative requests are spread over
// the endpoints given to WithEndpoints. Every request must land somewhere, and
// both servers must receive at least one.
func TestClient_BalancerIntegration(t *testing.T) {
	var hits1, hits2 atomic.Int32

	srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hits1.Add(1)
		fmt.Fprint(w, "server1")
	}))
	defer srv1.Close()

	srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hits2.Add(1)
		fmt.Fprint(w, "server2")
	}))
	defer srv2.Close()

	client := httpx.New(
		httpx.WithEndpoints(
			balancer.Endpoint{URL: srv1.URL},
			balancer.Endpoint{URL: srv2.URL},
		),
	)

	const totalRequests = 6
	for i := range totalRequests {
		resp, err := client.Get(context.Background(), fmt.Sprintf("/item/%d", i))
		if err != nil {
			t.Fatalf("request %d: unexpected error: %v", i, err)
		}
		// Close inside the loop (not deferred) so each iteration releases its response.
		resp.Close()
	}

	h1 := hits1.Load()
	h2 := hits2.Load()
	if h1+h2 != totalRequests {
		t.Errorf("expected %d total hits, got %d", totalRequests, h1+h2)
	}
	if h1 == 0 || h2 == 0 {
		t.Errorf("expected requests distributed across both servers, got server1=%d server2=%d", h1, h2)
	}
}

// TestClient_ErrorMapper checks that an error returned by the WithErrorMapper
// callback is surfaced to the caller. The response must still be returned
// alongside that error.
func TestClient_ErrorMapper(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		fmt.Fprint(w, "not found")
	}))
	defer srv.Close()

	mapper := func(resp *http.Response) error {
		if resp.StatusCode >= 400 {
			return fmt.Errorf("HTTP %d", resp.StatusCode)
		}
		return nil
	}

	client := httpx.New(httpx.WithErrorMapper(mapper))
	resp, err := client.Get(context.Background(), srv.URL+"/missing")
	if err == nil {
		t.Fatal("expected error, got nil")
	}

	// The response should still be returned alongside the error.
	if resp == nil {
		t.Fatal("expected non-nil response even on mapped error")
	}
	// The response is non-nil even though err is set, so it still has to be released.
	defer resp.Close()

	if resp.StatusCode != http.StatusNotFound {
		t.Errorf("expected status 404, got %d", resp.StatusCode)
	}

	// Verify the error message contains the status code.
	if !strings.Contains(err.Error(), "404") {
		t.Errorf("expected error to contain 404, got: %v", err)
	}
}

// TestClient_JSON checks the JSON round trip. NewJSONRequest must encode the
// payload with a JSON Content-Type, and Response.JSON must decode the reply.
func TestClient_JSON(t *testing.T) {
	type reqPayload struct {
		Name string `json:"name"`
		Age  int    `json:"age"`
	}
	type respPayload struct {
		ID   int    `json:"id"`
		Name string `json:"name"`
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if ct := r.Header.Get("Content-Type"); ct != "application/json" {
			t.Errorf("expected Content-Type application/json, got %q", ct)
		}
		var p reqPayload
		if err := json.NewDecoder(r.Body).Decode(&p); err != nil {
			t.Errorf("decoding request body: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if p.Name != "Alice" || p.Age != 30 {
			t.Errorf("unexpected payload: %+v", p)
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		if err := json.NewEncoder(w).Encode(respPayload{ID: 1, Name: p.Name}); err != nil {
			// Report instead of letting the client fail later with a confusing decode error.
			t.Errorf("encoding response body: %v", err)
		}
	}))
	defer srv.Close()

	client := httpx.New()
	req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, srv.URL+"/users", reqPayload{
		Name: "Alice",
		Age:  30,
	})
	if err != nil {
		t.Fatalf("creating JSON request: %v", err)
	}

	resp, err := client.Do(context.Background(), req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	var result respPayload
	if err := resp.JSON(&result); err != nil {
		t.Fatalf("decoding JSON response: %v", err)
	}
	if result.ID != 1 {
		t.Errorf("expected ID 1, got %d", result.ID)
	}
	if result.Name != "Alice" {
		t.Errorf("expected Name %q, got %q", "Alice", result.Name)
	}
}

// slog is not otherwise referenced in the tests above. This blank reference
// keeps the import compiling; remove both together if slog is no longer wanted.
var _ = slog.Default