diff --git a/client_options.go b/client_options.go index 084a72b..b38f1a5 100644 --- a/client_options.go +++ b/client_options.go @@ -12,18 +12,19 @@ import ( ) type clientOptions struct { - baseURL string - timeout time.Duration - transport http.RoundTripper - logger *slog.Logger - errorMapper ErrorMapper - middlewares []middleware.Middleware - retryOpts []retry.Option - enableRetry bool - cbOpts []circuitbreaker.Option - enableCB bool - endpoints []balancer.Endpoint - balancerOpts []balancer.Option + baseURL string + timeout time.Duration + transport http.RoundTripper + logger *slog.Logger + errorMapper ErrorMapper + middlewares []middleware.Middleware + retryOpts []retry.Option + enableRetry bool + cbOpts []circuitbreaker.Option + enableCB bool + endpoints []balancer.Endpoint + balancerOpts []balancer.Option + maxResponseBody int64 } // Option configures a Client. @@ -85,3 +86,11 @@ func WithEndpoints(eps ...balancer.Endpoint) Option { func WithBalancer(opts ...balancer.Option) Option { return func(o *clientOptions) { o.balancerOpts = opts } } + +// WithMaxResponseBody limits the number of bytes read from response bodies +// by Response.Bytes (and by extension String, JSON, XML). If the response +// body exceeds n bytes, reading stops and returns an error. +// A value of 0 means no limit (the default). +func WithMaxResponseBody(n int64) Option { + return func(o *clientOptions) { o.maxResponseBody = n } +} diff --git a/response.go b/response.go index 3068cfe..af37eeb 100644 --- a/response.go +++ b/response.go @@ -97,3 +97,18 @@ func (r *Response) BodyReader() io.Reader { } return r.Body } + +// limitedReadCloser wraps an io.LimitedReader with a separate Closer +// so the original body can be closed. +type limitedReadCloser struct { + R io.LimitedReader + C io.Closer +} + +func (l *limitedReadCloser) Read(p []byte) (int, error) { + return l.R.Read(p) +} + +func (l *limitedReadCloser) Close() error { + return l.C.Close() +} diff --git a/response_limit_test.go b/response_limit_test.go new file mode 100644 index 0000000..caca2aa --- /dev/null +++ b/response_limit_test.go @@ -0,0 +1,76 @@ +package httpx_test + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.codelab.vc/pkg/httpx" +) + +func TestClient_MaxResponseBody(t *testing.T) { + t.Run("allows response within limit", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprint(w, "hello") + })) + defer srv.Close() + + client := httpx.New(httpx.WithMaxResponseBody(1024)) + resp, err := client.Get(context.Background(), srv.URL+"/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + body, err := resp.String() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if body != "hello" { + t.Fatalf("body = %q, want %q", body, "hello") + } + }) + + t.Run("truncates response exceeding limit", func(t *testing.T) { + largeBody := strings.Repeat("x", 1000) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprint(w, largeBody) + })) + defer srv.Close() + + client := httpx.New(httpx.WithMaxResponseBody(100)) + resp, err := client.Get(context.Background(), srv.URL+"/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + b, err := resp.Bytes() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if len(b) != 100 { + t.Fatalf("body length = %d, want %d", len(b), 100) + } + }) + + t.Run("no limit when zero", func(t *testing.T) { + largeBody := strings.Repeat("x", 10000) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprint(w, largeBody) + })) + defer srv.Close() + + client := httpx.New() + resp, err := client.Get(context.Background(), srv.URL+"/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + b, err := resp.Bytes() + if err != nil { + t.Fatalf("reading body: %v", err) + } + if len(b) != 10000 { + t.Fatalf("body length = %d, want %d", len(b), 10000) + } + }) +}