From f9a05f5c57c5c94e093cb52c97b7efd32084015b Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Fri, 20 Mar 2026 14:22:22 +0300 Subject: [PATCH] Add Client with response wrapper, request helpers, and full middleware assembly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the top-level httpx.Client that composes the full chain: Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Transport - Response wrapper with JSON/XML/Bytes decoding and body caching - NewJSONRequest helper with Content-Type and GetBody support - Functional options: WithBaseURL, WithTimeout, WithRetry, WithEndpoints, etc. - Integration tests covering retry, balancing, error mapping, and JSON round-trips --- client.go | 164 ++++++++++++++++++++++++ client_options.go | 87 +++++++++++++ client_test.go | 311 ++++++++++++++++++++++++++++++++++++++++++++++ error_test.go | 90 ++++++++++++++ request.go | 34 +++++ request_test.go | 78 ++++++++++++ response.go | 99 +++++++++++++++ response_test.go | 96 ++++++++++++++ 8 files changed, 959 insertions(+) create mode 100644 client.go create mode 100644 client_options.go create mode 100644 client_test.go create mode 100644 error_test.go create mode 100644 request.go create mode 100644 request_test.go create mode 100644 response.go create mode 100644 response_test.go diff --git a/client.go b/client.go new file mode 100644 index 0000000..eb25b28 --- /dev/null +++ b/client.go @@ -0,0 +1,164 @@ +package httpx + +import ( + "context" + "io" + "net/http" + "strings" + + "git.codelab.vc/pkg/httpx/balancer" + "git.codelab.vc/pkg/httpx/circuitbreaker" + "git.codelab.vc/pkg/httpx/middleware" + "git.codelab.vc/pkg/httpx/retry" +) + +// Client is a high-level HTTP client that composes middleware for retry, +// circuit breaking, load balancing, logging, and more. +type Client struct { + httpClient *http.Client + baseURL string + errorMapper ErrorMapper +} + +// New creates a new Client with the given options. +// +// The middleware chain is assembled as (outermost → innermost): +// +// Logging → User Middlewares → Retry → Circuit Breaker → Balancer → Base Transport +func New(opts ...Option) *Client { + o := &clientOptions{ + transport: http.DefaultTransport, + } + for _, opt := range opts { + opt(o) + } + + // Build the middleware chain from inside out. + var chain []middleware.Middleware + + // Balancer (innermost, wraps base transport). + if len(o.endpoints) > 0 { + chain = append(chain, balancer.Transport(o.endpoints, o.balancerOpts...)) + } + + // Circuit breaker wraps balancer. + if o.enableCB { + chain = append(chain, circuitbreaker.Transport(o.cbOpts...)) + } + + // Retry wraps circuit breaker + balancer. + if o.enableRetry { + chain = append(chain, retry.Transport(o.retryOpts...)) + } + + // User middlewares. + for i := len(o.middlewares) - 1; i >= 0; i-- { + chain = append(chain, o.middlewares[i]) + } + + // Logging (outermost). + if o.logger != nil { + chain = append(chain, middleware.Logging(o.logger)) + } + + // Assemble: chain[last] is outermost. + rt := o.transport + for _, mw := range chain { + rt = mw(rt) + } + + return &Client{ + httpClient: &http.Client{ + Transport: rt, + Timeout: o.timeout, + }, + baseURL: o.baseURL, + errorMapper: o.errorMapper, + } +} + +// Do executes an HTTP request. +func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) { + req = req.WithContext(ctx) + c.resolveURL(req) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, &Error{ + Op: req.Method, + URL: req.URL.String(), + Err: err, + } + } + + r := newResponse(resp) + + if c.errorMapper != nil { + if mapErr := c.errorMapper(resp); mapErr != nil { + return r, &Error{ + Op: req.Method, + URL: req.URL.String(), + StatusCode: resp.StatusCode, + Err: mapErr, + } + } + } + + return r, nil +} + +// Get performs a GET request to the given URL. +func (c *Client) Get(ctx context.Context, url string) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + return c.Do(ctx, req) +} + +// Post performs a POST request to the given URL with the given body. +func (c *Client) Post(ctx context.Context, url string, body io.Reader) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body) + if err != nil { + return nil, err + } + return c.Do(ctx, req) +} + +// Put performs a PUT request to the given URL with the given body. +func (c *Client) Put(ctx context.Context, url string, body io.Reader) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body) + if err != nil { + return nil, err + } + return c.Do(ctx, req) +} + +// Delete performs a DELETE request to the given URL. +func (c *Client) Delete(ctx context.Context, url string) (*Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) + if err != nil { + return nil, err + } + return c.Do(ctx, req) +} + +// HTTPClient returns the underlying *http.Client for advanced use cases. +func (c *Client) HTTPClient() *http.Client { + return c.httpClient +} + +func (c *Client) resolveURL(req *http.Request) { + if c.baseURL == "" { + return + } + // Only resolve relative URLs (no scheme). + if req.URL.Scheme == "" && req.URL.Host == "" { + path := req.URL.String() + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + base := strings.TrimRight(c.baseURL, "/") + req.URL, _ = req.URL.Parse(base + path) + } +} diff --git a/client_options.go b/client_options.go new file mode 100644 index 0000000..084a72b --- /dev/null +++ b/client_options.go @@ -0,0 +1,87 @@ +package httpx + +import ( + "log/slog" + "net/http" + "time" + + "git.codelab.vc/pkg/httpx/balancer" + "git.codelab.vc/pkg/httpx/circuitbreaker" + "git.codelab.vc/pkg/httpx/middleware" + "git.codelab.vc/pkg/httpx/retry" +) + +type clientOptions struct { + baseURL string + timeout time.Duration + transport http.RoundTripper + logger *slog.Logger + errorMapper ErrorMapper + middlewares []middleware.Middleware + retryOpts []retry.Option + enableRetry bool + cbOpts []circuitbreaker.Option + enableCB bool + endpoints []balancer.Endpoint + balancerOpts []balancer.Option +} + +// Option configures a Client. +type Option func(*clientOptions) + +// WithBaseURL sets the base URL prepended to all relative request paths. +func WithBaseURL(url string) Option { + return func(o *clientOptions) { o.baseURL = url } +} + +// WithTimeout sets the overall request timeout. +func WithTimeout(d time.Duration) Option { + return func(o *clientOptions) { o.timeout = d } +} + +// WithTransport sets the base http.RoundTripper. Defaults to http.DefaultTransport. +func WithTransport(rt http.RoundTripper) Option { + return func(o *clientOptions) { o.transport = rt } +} + +// WithLogger enables structured logging of requests and responses. +func WithLogger(l *slog.Logger) Option { + return func(o *clientOptions) { o.logger = l } +} + +// WithErrorMapper sets a function that maps HTTP responses to errors. +func WithErrorMapper(m ErrorMapper) Option { + return func(o *clientOptions) { o.errorMapper = m } +} + +// WithMiddleware appends user middlewares to the chain. +// These run between logging and retry in the middleware stack. +func WithMiddleware(mws ...middleware.Middleware) Option { + return func(o *clientOptions) { o.middlewares = append(o.middlewares, mws...) } +} + +// WithRetry enables retry with the given options. +func WithRetry(opts ...retry.Option) Option { + return func(o *clientOptions) { + o.enableRetry = true + o.retryOpts = opts + } +} + +// WithCircuitBreaker enables per-host circuit breaking. +func WithCircuitBreaker(opts ...circuitbreaker.Option) Option { + return func(o *clientOptions) { + o.enableCB = true + o.cbOpts = opts + } +} + +// WithEndpoints sets the endpoints for load balancing. +func WithEndpoints(eps ...balancer.Endpoint) Option { + return func(o *clientOptions) { o.endpoints = eps } +} + +// WithBalancer configures the load balancer strategy and options. +func WithBalancer(opts ...balancer.Option) Option { + return func(o *clientOptions) { o.balancerOpts = opts } +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..518f220 --- /dev/null +++ b/client_test.go @@ -0,0 +1,311 @@ +package httpx_test + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "git.codelab.vc/pkg/httpx" + "git.codelab.vc/pkg/httpx/balancer" + "git.codelab.vc/pkg/httpx/middleware" + "git.codelab.vc/pkg/httpx/retry" +) + +func TestClient_Get(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("expected GET, got %s", r.Method) + } + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "hello") + })) + defer srv.Close() + + client := httpx.New() + resp, err := client.Get(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + body, err := resp.String() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if body != "hello" { + t.Errorf("expected body %q, got %q", "hello", body) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestClient_Post(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + b, _ := io.ReadAll(r.Body) + if string(b) != "request-body" { + t.Errorf("expected body %q, got %q", "request-body", string(b)) + } + w.WriteHeader(http.StatusCreated) + fmt.Fprint(w, "created") + })) + defer srv.Close() + + client := httpx.New() + resp, err := client.Post(context.Background(), srv.URL+"/items", strings.NewReader("request-body")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resp.StatusCode != http.StatusCreated { + t.Errorf("expected status 201, got %d", resp.StatusCode) + } + body, err := resp.String() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if body != "created" { + t.Errorf("expected body %q, got %q", "created", body) + } +} + +func TestClient_BaseURL(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/users" { + t.Errorf("expected path /api/v1/users, got %s", r.URL.Path) + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + client := httpx.New(httpx.WithBaseURL(srv.URL + "/api/v1")) + + // Use a relative path (no scheme/host). + resp, err := client.Get(context.Background(), "/users") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestClient_WithMiddleware(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + val := r.Header.Get("X-Custom-Header") + if val != "test-value" { + t.Errorf("expected header X-Custom-Header=%q, got %q", "test-value", val) + } + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + addHeader := func(next http.RoundTripper) http.RoundTripper { + return middleware.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + req.Header.Set("X-Custom-Header", "test-value") + return next.RoundTrip(req) + }) + } + + client := httpx.New(httpx.WithMiddleware(addHeader)) + + resp, err := client.Get(context.Background(), srv.URL+"/test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +func TestClient_RetryIntegration(t *testing.T) { + var calls atomic.Int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + n := calls.Add(1) + if n <= 2 { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, "success") + })) + defer srv.Close() + + client := httpx.New( + httpx.WithRetry( + retry.WithMaxAttempts(3), + retry.WithBackoff(retry.ConstantBackoff(1*time.Millisecond)), + ), + ) + + resp, err := client.Get(context.Background(), srv.URL+"/flaky") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } + + body, err := resp.String() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if body != "success" { + t.Errorf("expected body %q, got %q", "success", body) + } + + if got := calls.Load(); got != 3 { + t.Errorf("expected 3 total requests, got %d", got) + } +} + +func TestClient_BalancerIntegration(t *testing.T) { + var hits1, hits2 atomic.Int32 + + srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits1.Add(1) + fmt.Fprint(w, "server1") + })) + defer srv1.Close() + + srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hits2.Add(1) + fmt.Fprint(w, "server2") + })) + defer srv2.Close() + + client := httpx.New( + httpx.WithEndpoints( + balancer.Endpoint{URL: srv1.URL}, + balancer.Endpoint{URL: srv2.URL}, + ), + ) + + const totalRequests = 6 + for i := range totalRequests { + resp, err := client.Get(context.Background(), fmt.Sprintf("/item/%d", i)) + if err != nil { + t.Fatalf("request %d: unexpected error: %v", i, err) + } + resp.Close() + } + + h1 := hits1.Load() + h2 := hits2.Load() + + if h1+h2 != totalRequests { + t.Errorf("expected %d total hits, got %d", totalRequests, h1+h2) + } + if h1 == 0 || h2 == 0 { + t.Errorf("expected requests distributed across both servers, got server1=%d server2=%d", h1, h2) + } +} + +func TestClient_ErrorMapper(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, "not found") + })) + defer srv.Close() + + mapper := func(resp *http.Response) error { + if resp.StatusCode >= 400 { + return fmt.Errorf("HTTP %d", resp.StatusCode) + } + return nil + } + + client := httpx.New(httpx.WithErrorMapper(mapper)) + + resp, err := client.Get(context.Background(), srv.URL+"/missing") + if err == nil { + t.Fatal("expected error, got nil") + } + + // The response should still be returned alongside the error. + if resp == nil { + t.Fatal("expected non-nil response even on mapped error") + } + if resp.StatusCode != http.StatusNotFound { + t.Errorf("expected status 404, got %d", resp.StatusCode) + } + + // Verify the error message contains the status code. + if !strings.Contains(err.Error(), "404") { + t.Errorf("expected error to contain 404, got: %v", err) + } +} + +func TestClient_JSON(t *testing.T) { + type reqPayload struct { + Name string `json:"name"` + Age int `json:"age"` + } + type respPayload struct { + ID int `json:"id"` + Name string `json:"name"` + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected Content-Type application/json, got %q", ct) + } + + var p reqPayload + if err := json.NewDecoder(r.Body).Decode(&p); err != nil { + t.Errorf("decoding request body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + if p.Name != "Alice" || p.Age != 30 { + t.Errorf("unexpected payload: %+v", p) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(respPayload{ID: 1, Name: p.Name}) + })) + defer srv.Close() + + client := httpx.New() + + req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, srv.URL+"/users", reqPayload{ + Name: "Alice", + Age: 30, + }) + if err != nil { + t.Fatalf("creating JSON request: %v", err) + } + + resp, err := client.Do(context.Background(), req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + var result respPayload + if err := resp.JSON(&result); err != nil { + t.Fatalf("decoding JSON response: %v", err) + } + if result.ID != 1 { + t.Errorf("expected ID 1, got %d", result.ID) + } + if result.Name != "Alice" { + t.Errorf("expected Name %q, got %q", "Alice", result.Name) + } +} + +// Ensure slog import is used (referenced in imports for completeness with the spec). +var _ = slog.Default diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..7645680 --- /dev/null +++ b/error_test.go @@ -0,0 +1,90 @@ +package httpx_test + +import ( + "errors" + "testing" + + "git.codelab.vc/pkg/httpx" +) + +func TestError(t *testing.T) { + t.Run("formats without endpoint", func(t *testing.T) { + inner := errors.New("connection refused") + e := &httpx.Error{ + Op: "Get", + URL: "http://example.com/api", + Err: inner, + } + + want := "httpx: Get http://example.com/api: connection refused" + if got := e.Error(); got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("formats with endpoint different from url", func(t *testing.T) { + inner := errors.New("timeout") + e := &httpx.Error{ + Op: "Do", + URL: "http://example.com/api", + Endpoint: "http://node1.example.com/api", + Err: inner, + } + + want := "httpx: Do http://example.com/api (endpoint http://node1.example.com/api): timeout" + if got := e.Error(); got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("formats with endpoint same as url", func(t *testing.T) { + inner := errors.New("not found") + e := &httpx.Error{ + Op: "Get", + URL: "http://example.com/api", + Endpoint: "http://example.com/api", + Err: inner, + } + + want := "httpx: Get http://example.com/api: not found" + if got := e.Error(); got != want { + t.Errorf("got %q, want %q", got, want) + } + }) + + t.Run("unwrap returns inner error", func(t *testing.T) { + inner := errors.New("underlying") + e := &httpx.Error{Op: "Get", URL: "http://example.com", Err: inner} + + if got := e.Unwrap(); got != inner { + t.Errorf("Unwrap() = %v, want %v", got, inner) + } + + if !errors.Is(e, inner) { + t.Error("errors.Is should find the inner error") + } + }) +} + +func TestSentinelErrors(t *testing.T) { + t.Run("ErrRetryExhausted", func(t *testing.T) { + if httpx.ErrRetryExhausted == nil { + t.Fatal("ErrRetryExhausted is nil") + } + if httpx.ErrRetryExhausted.Error() == "" { + t.Fatal("ErrRetryExhausted has empty message") + } + }) + + t.Run("ErrCircuitOpen", func(t *testing.T) { + if httpx.ErrCircuitOpen == nil { + t.Fatal("ErrCircuitOpen is nil") + } + }) + + t.Run("ErrNoHealthy", func(t *testing.T) { + if httpx.ErrNoHealthy == nil { + t.Fatal("ErrNoHealthy is nil") + } + }) +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..c7babbe --- /dev/null +++ b/request.go @@ -0,0 +1,34 @@ +package httpx + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// NewRequest creates an http.Request with context. It is a convenience +// wrapper around http.NewRequestWithContext. +func NewRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + return http.NewRequestWithContext(ctx, method, url, body) +} + +// NewJSONRequest creates an http.Request with a JSON-encoded body and +// sets Content-Type to application/json. +func NewJSONRequest(ctx context.Context, method, url string, body any) (*http.Request, error) { + b, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("httpx: encoding JSON body: %w", err) + } + req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(b)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(b)), nil + } + return req, nil +} diff --git a/request_test.go b/request_test.go new file mode 100644 index 0000000..a67f6f8 --- /dev/null +++ b/request_test.go @@ -0,0 +1,78 @@ +package httpx_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + + "git.codelab.vc/pkg/httpx" +) + +func TestNewJSONRequest(t *testing.T) { + t.Run("body is JSON encoded", func(t *testing.T) { + payload := map[string]string{"key": "value"} + req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + body, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + + var decoded map[string]string + if err := json.Unmarshal(body, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if decoded["key"] != "value" { + t.Errorf("decoded[key] = %q, want %q", decoded["key"], "value") + } + }) + + t.Run("content type is set", func(t *testing.T) { + req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", "test") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ct := req.Header.Get("Content-Type") + if ct != "application/json" { + t.Errorf("Content-Type = %q, want %q", ct, "application/json") + } + }) + + t.Run("GetBody works", func(t *testing.T) { + payload := map[string]int{"num": 123} + req, err := httpx.NewJSONRequest(context.Background(), http.MethodPost, "http://example.com", payload) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if req.GetBody == nil { + t.Fatal("GetBody is nil") + } + + // Read body first time + b1, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + + // Get a fresh body + body2, err := req.GetBody() + if err != nil { + t.Fatalf("GetBody(): %v", err) + } + b2, err := io.ReadAll(body2) + if err != nil { + t.Fatalf("reading body2: %v", err) + } + + if string(b1) != string(b2) { + t.Errorf("GetBody returned different data: %q vs %q", b1, b2) + } + }) +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..3068cfe --- /dev/null +++ b/response.go @@ -0,0 +1,99 @@ +package httpx + +import ( + "bytes" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" +) + +// Response wraps http.Response with convenience methods. +type Response struct { + *http.Response + body []byte + read bool +} + +func newResponse(resp *http.Response) *Response { + return &Response{Response: resp} +} + +// Bytes reads and returns the entire response body. +// The body is cached so subsequent calls return the same data. +func (r *Response) Bytes() ([]byte, error) { + if r.read { + return r.body, nil + } + defer r.Body.Close() + b, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + r.body = b + r.read = true + return b, nil +} + +// String reads the response body and returns it as a string. +func (r *Response) String() (string, error) { + b, err := r.Bytes() + if err != nil { + return "", err + } + return string(b), nil +} + +// JSON decodes the response body as JSON into v. +func (r *Response) JSON(v any) error { + b, err := r.Bytes() + if err != nil { + return fmt.Errorf("httpx: reading response body: %w", err) + } + if err := json.Unmarshal(b, v); err != nil { + return fmt.Errorf("httpx: decoding JSON: %w", err) + } + return nil +} + +// XML decodes the response body as XML into v. +func (r *Response) XML(v any) error { + b, err := r.Bytes() + if err != nil { + return fmt.Errorf("httpx: reading response body: %w", err) + } + if err := xml.Unmarshal(b, v); err != nil { + return fmt.Errorf("httpx: decoding XML: %w", err) + } + return nil +} + +// IsSuccess returns true if the status code is in the 2xx range. +func (r *Response) IsSuccess() bool { + return r.StatusCode >= 200 && r.StatusCode < 300 +} + +// IsError returns true if the status code is 4xx or 5xx. +func (r *Response) IsError() bool { + return r.StatusCode >= 400 +} + +// Close drains and closes the response body. +func (r *Response) Close() error { + if r.read { + return nil + } + _, _ = io.Copy(io.Discard, r.Body) + return r.Body.Close() +} + +// BodyReader returns a reader for the response body. +// If the body has already been read via Bytes/String/JSON/XML, +// returns a reader over the cached bytes. +func (r *Response) BodyReader() io.Reader { + if r.read { + return bytes.NewReader(r.body) + } + return r.Body +} diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..8a45fc4 --- /dev/null +++ b/response_test.go @@ -0,0 +1,96 @@ +package httpx + +import ( + "io" + "net/http" + "strings" + "testing" +) + +func makeTestResponse(statusCode int, body string) *Response { + return newResponse(&http.Response{ + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }) +} + +func TestResponse(t *testing.T) { + t.Run("Bytes returns body", func(t *testing.T) { + r := makeTestResponse(200, "hello world") + b, err := r.Bytes() + if err != nil { + t.Fatalf("Bytes() error: %v", err) + } + if string(b) != "hello world" { + t.Errorf("Bytes() = %q, want %q", string(b), "hello world") + } + }) + + t.Run("body caching returns same data", func(t *testing.T) { + r := makeTestResponse(200, "cached body") + b1, err := r.Bytes() + if err != nil { + t.Fatalf("first Bytes() error: %v", err) + } + b2, err := r.Bytes() + if err != nil { + t.Fatalf("second Bytes() error: %v", err) + } + if string(b1) != string(b2) { + t.Errorf("Bytes() returned different data: %q vs %q", b1, b2) + } + }) + + t.Run("String returns body as string", func(t *testing.T) { + r := makeTestResponse(200, "string body") + s, err := r.String() + if err != nil { + t.Fatalf("String() error: %v", err) + } + if s != "string body" { + t.Errorf("String() = %q, want %q", s, "string body") + } + }) + + t.Run("JSON decodes body", func(t *testing.T) { + r := makeTestResponse(200, `{"name":"test","value":42}`) + var result struct { + Name string `json:"name"` + Value int `json:"value"` + } + if err := r.JSON(&result); err != nil { + t.Fatalf("JSON() error: %v", err) + } + if result.Name != "test" { + t.Errorf("Name = %q, want %q", result.Name, "test") + } + if result.Value != 42 { + t.Errorf("Value = %d, want %d", result.Value, 42) + } + }) + + t.Run("IsSuccess for 2xx", func(t *testing.T) { + for _, code := range []int{200, 201, 204, 299} { + r := makeTestResponse(code, "") + if !r.IsSuccess() { + t.Errorf("IsSuccess() = false for status %d", code) + } + if r.IsError() { + t.Errorf("IsError() = true for status %d", code) + } + } + }) + + t.Run("IsError for 4xx and 5xx", func(t *testing.T) { + for _, code := range []int{400, 404, 500, 503} { + r := makeTestResponse(code, "") + if !r.IsError() { + t.Errorf("IsError() = false for status %d", code) + } + if r.IsSuccess() { + t.Errorf("IsSuccess() = true for status %d", code) + } + } + }) +}