diff --git a/client.go b/client.go index 25f752b..ca02535 100644 --- a/client.go +++ b/client.go @@ -102,9 +102,12 @@ func (c *Client) Do(ctx context.Context, req *http.Request) (*Response, error) { } if c.maxResponseBody > 0 { + // Read one byte past the limit so we can distinguish "exactly at the + // limit" (allowed) from "exceeds the limit" (ErrResponseTooLarge). resp.Body = &limitedReadCloser{ - R: io.LimitedReader{R: resp.Body, N: c.maxResponseBody}, - C: resp.Body, + r: io.LimitReader(resp.Body, c.maxResponseBody+1), + c: resp.Body, + limit: c.maxResponseBody, } } diff --git a/client_options.go b/client_options.go index b38f1a5..a986ef3 100644 --- a/client_options.go +++ b/client_options.go @@ -89,8 +89,8 @@ func WithBalancer(opts ...balancer.Option) Option { // WithMaxResponseBody limits the number of bytes read from response bodies // by Response.Bytes (and by extension String, JSON, XML). If the response -// body exceeds n bytes, reading stops and returns an error. -// A value of 0 means no limit (the default). +// body exceeds n bytes, reading returns ErrResponseTooLarge instead of +// silently truncating. A value of 0 means no limit (the default). func WithMaxResponseBody(n int64) Option { return func(o *clientOptions) { o.maxResponseBody = n } } diff --git a/error.go b/error.go index fc47c54..ec67f90 100644 --- a/error.go +++ b/error.go @@ -1,6 +1,7 @@ package httpx import ( + "errors" "fmt" "net/http" @@ -18,6 +19,11 @@ var ( ErrNoHealthy = balancer.ErrNoHealthy ) +// ErrResponseTooLarge is returned when reading a response body that exceeds +// the limit configured via WithMaxResponseBody. Any bytes read up to the +// limit are returned alongside the error. +var ErrResponseTooLarge = errors.New("httpx: response body exceeds configured limit") + // Error provides structured error information for failed HTTP operations. type Error struct { // Op is the operation that failed (e.g. "Get", "Do"). diff --git a/response.go b/response.go index af37eeb..ec8e7c5 100644 --- a/response.go +++ b/response.go @@ -98,17 +98,26 @@ func (r *Response) BodyReader() io.Reader { return r.Body } -// limitedReadCloser wraps an io.LimitedReader with a separate Closer -// so the original body can be closed. +// limitedReadCloser enforces a maximum number of bytes that may be read from +// a response body. Reading more than limit bytes returns ErrResponseTooLarge +// rather than silently truncating the body. The original body is closed via +// the separate Closer. type limitedReadCloser struct { - R io.LimitedReader - C io.Closer + r io.Reader // an io.LimitReader over the original body (limit+1 bytes) + c io.Closer // the original body, for Close + limit int64 + read int64 } func (l *limitedReadCloser) Read(p []byte) (int, error) { - return l.R.Read(p) + n, err := l.r.Read(p) + l.read += int64(n) + if l.read > l.limit { + return n, ErrResponseTooLarge + } + return n, err } func (l *limitedReadCloser) Close() error { - return l.C.Close() + return l.c.Close() } diff --git a/response_limit_test.go b/response_limit_test.go index caca2aa..4b6bc46 100644 --- a/response_limit_test.go +++ b/response_limit_test.go @@ -2,6 +2,7 @@ package httpx_test import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -32,13 +33,30 @@ func TestClient_MaxResponseBody(t *testing.T) { } }) - t.Run("truncates response exceeding limit", func(t *testing.T) { + t.Run("returns ErrResponseTooLarge when exceeding limit", func(t *testing.T) { largeBody := strings.Repeat("x", 1000) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { fmt.Fprint(w, largeBody) })) defer srv.Close() + client := httpx.New(httpx.WithMaxResponseBody(100)) + resp, err := client.Get(context.Background(), srv.URL+"/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if _, err := resp.Bytes(); !errors.Is(err, httpx.ErrResponseTooLarge) { + t.Fatalf("err = %v, want ErrResponseTooLarge", err) + } + }) + + t.Run("allows body exactly at limit", func(t *testing.T) { + exact := strings.Repeat("x", 100) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprint(w, exact) + })) + defer srv.Close() + client := httpx.New(httpx.WithMaxResponseBody(100)) resp, err := client.Get(context.Background(), srv.URL+"/") if err != nil { @@ -46,7 +64,7 @@ func TestClient_MaxResponseBody(t *testing.T) { } b, err := resp.Bytes() if err != nil { - t.Fatalf("reading body: %v", err) + t.Fatalf("reading body at exact limit: %v", err) } if len(b) != 100 { t.Fatalf("body length = %d, want %d", len(b), 100)