diff --git a/server/middleware_timeout.go b/server/middleware_timeout.go new file mode 100644 index 0000000..a447660 --- /dev/null +++ b/server/middleware_timeout.go @@ -0,0 +1,15 @@ +package server + +import ( + "net/http" + "time" +) + +// Timeout returns a middleware that limits request processing time. +// If the handler does not complete within d, the client receives a +// 503 Service Unavailable response. It wraps http.TimeoutHandler. +func Timeout(d time.Duration) Middleware { + return func(next http.Handler) http.Handler { + return http.TimeoutHandler(next, d, "Service Unavailable\n") + } +} diff --git a/server/middleware_timeout_test.go b/server/middleware_timeout_test.go new file mode 100644 index 0000000..74145d6 --- /dev/null +++ b/server/middleware_timeout_test.go @@ -0,0 +1,49 @@ +package server_test + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "git.codelab.vc/pkg/httpx/server" +) + +func TestTimeout(t *testing.T) { + t.Run("handler completes within timeout", func(t *testing.T) { + handler := server.Timeout(1 * time.Second)( + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }), + ) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("handler exceeds timeout returns 503", func(t *testing.T) { + handler := server.Timeout(10 * time.Millisecond)( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-time.After(1 * time.Second): + case <-r.Context().Done(): + } + w.WriteHeader(http.StatusOK) + }), + ) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + handler.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable) + } + }) +}