From cea75d198b69aeac04677ba5f52df882bd3749ae Mon Sep 17 00:00:00 2001 From: Aleksey Shakhmatov Date: Sat, 21 Mar 2026 13:41:54 +0300 Subject: [PATCH] Add production-ready HTTP server package with routing, health checks, and middleware Introduces server/ sub-package as the server-side companion to the existing Client. Includes Router (over http.ServeMux with groups and mounting), graceful shutdown with signal handling, health endpoints (/healthz, /readyz), and built-in middlewares (RequestID, Recovery, Logging). Zero external dependencies. Co-Authored-By: Claude Opus 4.6 (1M context) --- server/health.go | 55 ++++++++ server/health_test.go | 90 +++++++++++++ server/middleware.go | 56 ++++++++ server/middleware_logging.go | 39 ++++++ server/middleware_recovery.go | 50 +++++++ server/middleware_requestid.go | 52 ++++++++ server/middleware_test.go | 234 +++++++++++++++++++++++++++++++++ server/options.go | 89 +++++++++++++ server/route.go | 103 +++++++++++++++ server/route_test.go | 143 ++++++++++++++++++++ server/server.go | 173 ++++++++++++++++++++++++ server/server_test.go | 131 ++++++++++++++++++ 12 files changed, 1215 insertions(+) create mode 100644 server/health.go create mode 100644 server/health_test.go create mode 100644 server/middleware.go create mode 100644 server/middleware_logging.go create mode 100644 server/middleware_recovery.go create mode 100644 server/middleware_requestid.go create mode 100644 server/middleware_test.go create mode 100644 server/options.go create mode 100644 server/route.go create mode 100644 server/route_test.go create mode 100644 server/server.go create mode 100644 server/server_test.go diff --git a/server/health.go b/server/health.go new file mode 100644 index 0000000..0a2e6a8 --- /dev/null +++ b/server/health.go @@ -0,0 +1,55 @@ +package server + +import ( + "encoding/json" + "net/http" +) + +// ReadinessChecker is a function that reports whether a dependency is ready. +// Return nil if healthy, or an error describing the problem. +type ReadinessChecker func() error + +// HealthHandler returns an http.Handler that exposes liveness and readiness +// endpoints: +// +// - GET /healthz — liveness check, always returns 200 OK +// - GET /readyz — readiness check, returns 200 if all checkers pass, 503 otherwise +func HealthHandler(checkers ...ReadinessChecker) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"}) + }) + + mux.HandleFunc("GET /readyz", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var errs []string + for _, check := range checkers { + if err := check(); err != nil { + errs = append(errs, err.Error()) + } + } + + if len(errs) > 0 { + w.WriteHeader(http.StatusServiceUnavailable) + _ = json.NewEncoder(w).Encode(healthResponse{ + Status: "unavailable", + Errors: errs, + }) + return + } + + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"}) + }) + + return mux +} + +type healthResponse struct { + Status string `json:"status"` + Errors []string `json:"errors,omitempty"` +} diff --git a/server/health_test.go b/server/health_test.go new file mode 100644 index 0000000..d543fe3 --- /dev/null +++ b/server/health_test.go @@ -0,0 +1,90 @@ +package server_test + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "git.codelab.vc/pkg/httpx/server" +) + +func TestHealthHandler(t *testing.T) { + t.Run("liveness always returns 200", func(t *testing.T) { + h := server.HealthHandler() + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/healthz", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + + var resp map[string]any + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode failed: %v", err) + } + if resp["status"] != "ok" { + t.Fatalf("got status %q, want %q", resp["status"], "ok") + } + }) + + t.Run("readiness returns 200 when all checks pass", func(t *testing.T) { + h := server.HealthHandler( + func() error { return nil }, + func() error { return nil }, + ) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/readyz", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("readiness returns 503 when a check fails", func(t *testing.T) { + h := server.HealthHandler( + func() error { return nil }, + func() error { return errors.New("db down") }, + ) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/readyz", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable) + } + + var resp map[string]any + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("decode failed: %v", err) + } + if resp["status"] != "unavailable" { + t.Fatalf("got status %q, want %q", resp["status"], "unavailable") + } + errs, ok := resp["errors"].([]any) + if !ok || len(errs) != 1 { + t.Fatalf("expected 1 error, got %v", resp["errors"]) + } + if errs[0] != "db down" { + t.Fatalf("got error %q, want %q", errs[0], "db down") + } + }) + + t.Run("readiness returns 200 with no checkers", func(t *testing.T) { + h := server.HealthHandler() + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/readyz", nil) + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + }) +} diff --git a/server/middleware.go b/server/middleware.go new file mode 100644 index 0000000..9b7b79f --- /dev/null +++ b/server/middleware.go @@ -0,0 +1,56 @@ +package server + +import "net/http" + +// Middleware wraps an http.Handler to add behavior. +// This is the server-side counterpart of the client middleware type +// func(http.RoundTripper) http.RoundTripper. +type Middleware func(http.Handler) http.Handler + +// Chain composes middlewares so that Chain(A, B, C)(handler) == A(B(C(handler))). +// Middlewares are applied from right to left: C wraps handler first, then B wraps +// the result, then A wraps last. This means A is the outermost layer and sees +// every request first. +func Chain(mws ...Middleware) Middleware { + return func(h http.Handler) http.Handler { + for i := len(mws) - 1; i >= 0; i-- { + h = mws[i](h) + } + return h + } +} + +// statusWriter wraps http.ResponseWriter to capture the response status code. +// It implements Unwrap() so that http.ResponseController can access the +// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.). +type statusWriter struct { + http.ResponseWriter + status int + written bool +} + +// WriteHeader captures the status code and delegates to the underlying writer. +func (w *statusWriter) WriteHeader(code int) { + if !w.written { + w.status = code + w.written = true + } + w.ResponseWriter.WriteHeader(code) +} + +// Write delegates to the underlying writer, defaulting status to 200 if +// WriteHeader was not called explicitly. +func (w *statusWriter) Write(b []byte) (int, error) { + if !w.written { + w.status = http.StatusOK + w.written = true + } + return w.ResponseWriter.Write(b) +} + +// Unwrap returns the underlying ResponseWriter. This is required for +// http.ResponseController to detect optional interfaces like http.Flusher +// and http.Hijacker on the original writer. +func (w *statusWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} diff --git a/server/middleware_logging.go b/server/middleware_logging.go new file mode 100644 index 0000000..cf86829 --- /dev/null +++ b/server/middleware_logging.go @@ -0,0 +1,39 @@ +package server + +import ( + "log/slog" + "net/http" + "time" +) + +// Logging returns a middleware that logs each request's method, path, +// status code, duration, and request ID using the provided structured logger. +func Logging(logger *slog.Logger) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + + sw := &statusWriter{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(sw, r) + + duration := time.Since(start) + attrs := []slog.Attr{ + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", sw.status), + slog.Duration("duration", duration), + } + + if id := RequestIDFromContext(r.Context()); id != "" { + attrs = append(attrs, slog.String("request_id", id)) + } + + level := slog.LevelInfo + if sw.status >= http.StatusInternalServerError { + level = slog.LevelError + } + + logger.LogAttrs(r.Context(), level, "request completed", attrs...) + }) + } +} diff --git a/server/middleware_recovery.go b/server/middleware_recovery.go new file mode 100644 index 0000000..692ca50 --- /dev/null +++ b/server/middleware_recovery.go @@ -0,0 +1,50 @@ +package server + +import ( + "log/slog" + "net/http" + "runtime/debug" +) + +// RecoveryOption configures the Recovery middleware. +type RecoveryOption func(*recoveryOptions) + +type recoveryOptions struct { + logger *slog.Logger +} + +// WithRecoveryLogger sets the logger for the Recovery middleware. +// If not set, panics are recovered silently (500 is still returned). +func WithRecoveryLogger(l *slog.Logger) RecoveryOption { + return func(o *recoveryOptions) { o.logger = l } +} + +// Recovery returns a middleware that recovers from panics in downstream +// handlers. A recovered panic results in a 500 Internal Server Error +// response and is logged (if a logger is configured) with the stack trace. +func Recovery(opts ...RecoveryOption) Middleware { + o := &recoveryOptions{} + for _, opt := range opts { + opt(o) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if v := recover(); v != nil { + if o.logger != nil { + o.logger.LogAttrs(r.Context(), slog.LevelError, "panic recovered", + slog.Any("panic", v), + slog.String("stack", string(debug.Stack())), + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.String("request_id", RequestIDFromContext(r.Context())), + ) + } + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) + } +} diff --git a/server/middleware_requestid.go b/server/middleware_requestid.go new file mode 100644 index 0000000..081a5a7 --- /dev/null +++ b/server/middleware_requestid.go @@ -0,0 +1,52 @@ +package server + +import ( + "context" + "crypto/rand" + "fmt" + "net/http" +) + +type requestIDKey struct{} + +// RequestID returns a middleware that assigns a unique request ID to each +// request. If the incoming request already has an X-Request-Id header, that +// value is used. Otherwise a new UUID v4 is generated via crypto/rand. +// +// The request ID is stored in the request context (retrieve with +// RequestIDFromContext) and set on the response X-Request-Id header. +func RequestID() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + id := r.Header.Get("X-Request-Id") + if id == "" { + id = newUUID() + } + + ctx := context.WithValue(r.Context(), requestIDKey{}, id) + w.Header().Set("X-Request-Id", id) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// RequestIDFromContext returns the request ID from the context, or an empty +// string if none is set. +func RequestIDFromContext(ctx context.Context) string { + id, _ := ctx.Value(requestIDKey{}).(string) + return id +} + +// newUUID generates a UUID v4 string using crypto/rand. +func newUUID() string { + var uuid [16]byte + _, _ = rand.Read(uuid[:]) + + // Set version 4 (bits 12-15 of time_hi_and_version). + uuid[6] = (uuid[6] & 0x0f) | 0x40 + // Set variant bits (10xx). + uuid[8] = (uuid[8] & 0x3f) | 0x80 + + return fmt.Sprintf("%x-%x-%x-%x-%x", + uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16]) +} diff --git a/server/middleware_test.go b/server/middleware_test.go new file mode 100644 index 0000000..0be9f6c --- /dev/null +++ b/server/middleware_test.go @@ -0,0 +1,234 @@ +package server_test + +import ( + "bytes" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.codelab.vc/pkg/httpx/server" +) + +func TestChain(t *testing.T) { + t.Run("applies middlewares in correct order", func(t *testing.T) { + var order []string + + mwA := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + order = append(order, "A-before") + next.ServeHTTP(w, r) + order = append(order, "A-after") + }) + } + + mwB := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + order = append(order, "B-before") + next.ServeHTTP(w, r) + order = append(order, "B-after") + }) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + order = append(order, "handler") + w.WriteHeader(http.StatusOK) + }) + + chained := server.Chain(mwA, mwB)(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + chained.ServeHTTP(w, req) + + expected := []string{"A-before", "B-before", "handler", "B-after", "A-after"} + if len(order) != len(expected) { + t.Fatalf("got %v, want %v", order, expected) + } + for i, v := range expected { + if order[i] != v { + t.Fatalf("order[%d] = %q, want %q", i, order[i], v) + } + } + }) + + t.Run("empty chain returns handler unchanged", func(t *testing.T) { + called := false + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + called = true + w.WriteHeader(http.StatusOK) + }) + + chained := server.Chain()(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + chained.ServeHTTP(w, req) + + if !called { + t.Fatal("handler was not called") + } + }) +} + +func TestRequestID(t *testing.T) { + t.Run("generates ID when not present", func(t *testing.T) { + var gotID string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotID = server.RequestIDFromContext(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + mw := server.RequestID()(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + mw.ServeHTTP(w, req) + + if gotID == "" { + t.Fatal("expected request ID in context, got empty") + } + if w.Header().Get("X-Request-Id") != gotID { + t.Fatalf("response header %q != context ID %q", w.Header().Get("X-Request-Id"), gotID) + } + // UUID v4 format: 8-4-4-4-12 hex chars. + if len(gotID) != 36 { + t.Fatalf("expected UUID length 36, got %d: %q", len(gotID), gotID) + } + }) + + t.Run("preserves existing ID", func(t *testing.T) { + var gotID string + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotID = server.RequestIDFromContext(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + mw := server.RequestID()(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-Request-Id", "custom-123") + mw.ServeHTTP(w, req) + + if gotID != "custom-123" { + t.Fatalf("got ID %q, want %q", gotID, "custom-123") + } + }) + + t.Run("context without ID returns empty", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if id := server.RequestIDFromContext(req.Context()); id != "" { + t.Fatalf("expected empty, got %q", id) + } + }) +} + +func TestRecovery(t *testing.T) { + t.Run("recovers from panic and returns 500", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + panic("something went wrong") + }) + + mw := server.Recovery()(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + mw.ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError) + } + }) + + t.Run("logs panic with logger", func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + panic("boom") + }) + + mw := server.Recovery(server.WithRecoveryLogger(logger))(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + mw.ServeHTTP(w, req) + + if !strings.Contains(buf.String(), "panic recovered") { + t.Fatalf("expected log to contain 'panic recovered', got %q", buf.String()) + } + if !strings.Contains(buf.String(), "boom") { + t.Fatalf("expected log to contain 'boom', got %q", buf.String()) + } + }) + + t.Run("passes through without panic", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + + mw := server.Recovery()(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + mw.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + }) +} + +func TestLogging(t *testing.T) { + t.Run("logs request details", func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusCreated) + }) + + mw := server.Logging(logger)(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/users", nil) + mw.ServeHTTP(w, req) + + logOutput := buf.String() + if !strings.Contains(logOutput, "request completed") { + t.Fatalf("expected 'request completed' in log, got %q", logOutput) + } + if !strings.Contains(logOutput, "POST") { + t.Fatalf("expected method in log, got %q", logOutput) + } + if !strings.Contains(logOutput, "/api/users") { + t.Fatalf("expected path in log, got %q", logOutput) + } + if !strings.Contains(logOutput, "status=201") { + t.Fatalf("expected status=201 in log, got %q", logOutput) + } + }) + + t.Run("logs error level for 5xx", func(t *testing.T) { + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadGateway) + }) + + mw := server.Logging(logger)(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + mw.ServeHTTP(w, req) + + logOutput := buf.String() + if !strings.Contains(logOutput, "level=ERROR") { + t.Fatalf("expected ERROR level in log, got %q", logOutput) + } + }) +} diff --git a/server/options.go b/server/options.go new file mode 100644 index 0000000..644a2f3 --- /dev/null +++ b/server/options.go @@ -0,0 +1,89 @@ +package server + +import ( + "log/slog" + "time" +) + +type serverOptions struct { + addr string + readTimeout time.Duration + readHeaderTimeout time.Duration + writeTimeout time.Duration + idleTimeout time.Duration + shutdownTimeout time.Duration + logger *slog.Logger + middlewares []Middleware + onShutdown []func() +} + +// Option configures a Server. +type Option func(*serverOptions) + +// WithAddr sets the listen address. Default is ":8080". +func WithAddr(addr string) Option { + return func(o *serverOptions) { o.addr = addr } +} + +// WithReadTimeout sets the maximum duration for reading the entire request. +func WithReadTimeout(d time.Duration) Option { + return func(o *serverOptions) { o.readTimeout = d } +} + +// WithReadHeaderTimeout sets the maximum duration for reading request headers. +func WithReadHeaderTimeout(d time.Duration) Option { + return func(o *serverOptions) { o.readHeaderTimeout = d } +} + +// WithWriteTimeout sets the maximum duration before timing out writes of the response. +func WithWriteTimeout(d time.Duration) Option { + return func(o *serverOptions) { o.writeTimeout = d } +} + +// WithIdleTimeout sets the maximum amount of time to wait for the next request +// when keep-alives are enabled. +func WithIdleTimeout(d time.Duration) Option { + return func(o *serverOptions) { o.idleTimeout = d } +} + +// WithShutdownTimeout sets the maximum duration to wait for active connections +// to close during graceful shutdown. Default is 15 seconds. +func WithShutdownTimeout(d time.Duration) Option { + return func(o *serverOptions) { o.shutdownTimeout = d } +} + +// WithLogger sets the structured logger used by the server for lifecycle events. +func WithLogger(l *slog.Logger) Option { + return func(o *serverOptions) { o.logger = l } +} + +// WithMiddleware appends server middlewares to the chain. +// These are applied to the handler in the order given. +func WithMiddleware(mws ...Middleware) Option { + return func(o *serverOptions) { o.middlewares = append(o.middlewares, mws...) } +} + +// WithOnShutdown registers a function to be called during graceful shutdown, +// before the HTTP server begins draining connections. +func WithOnShutdown(fn func()) Option { + return func(o *serverOptions) { o.onShutdown = append(o.onShutdown, fn) } +} + +// Defaults returns a production-ready set of options including standard +// middleware (RequestID, Recovery, Logging), sensible timeouts, and the +// provided logger. +// +// Middleware order: RequestID → Recovery → Logging → user handler. +func Defaults(logger *slog.Logger) []Option { + return []Option{ + WithReadHeaderTimeout(10 * time.Second), + WithIdleTimeout(120 * time.Second), + WithShutdownTimeout(15 * time.Second), + WithLogger(logger), + WithMiddleware( + RequestID(), + Recovery(WithRecoveryLogger(logger)), + Logging(logger), + ), + } +} diff --git a/server/route.go b/server/route.go new file mode 100644 index 0000000..27599ca --- /dev/null +++ b/server/route.go @@ -0,0 +1,103 @@ +package server + +import ( + "net/http" + "strings" +) + +// Router is a lightweight wrapper around http.ServeMux that adds middleware +// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns +// like "GET /users/{id}". +type Router struct { + mux *http.ServeMux + prefix string + middlewares []Middleware +} + +// NewRouter creates a new Router backed by a fresh http.ServeMux. +func NewRouter() *Router { + return &Router{ + mux: http.NewServeMux(), + } +} + +// Handle registers a handler for the given pattern. The pattern follows +// http.ServeMux conventions, including method-based patterns like "GET /users". +func (r *Router) Handle(pattern string, handler http.Handler) { + if len(r.middlewares) > 0 { + handler = Chain(r.middlewares...)(handler) + } + r.mux.Handle(r.prefixedPattern(pattern), handler) +} + +// HandleFunc registers a handler function for the given pattern. +func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) { + r.Handle(pattern, fn) +} + +// Group creates a sub-router with a shared prefix and optional middleware. +// Patterns registered on the group are prefixed automatically. The group +// shares the underlying ServeMux with the parent router. +// +// Example: +// +// api := router.Group("/api/v1", authMiddleware) +// api.HandleFunc("GET /users", listUsers) // registers "GET /api/v1/users" +func (r *Router) Group(prefix string, mws ...Middleware) *Router { + return &Router{ + mux: r.mux, + prefix: r.prefix + prefix, + middlewares: append(r.middlewaresSnapshot(), mws...), + } +} + +// Mount attaches an http.Handler under the given prefix. All requests +// starting with prefix are forwarded to the handler with the prefix stripped. +func (r *Router) Mount(prefix string, handler http.Handler) { + full := r.prefix + prefix + if !strings.HasSuffix(full, "/") { + full += "/" + } + r.mux.Handle(full, http.StripPrefix(strings.TrimSuffix(full, "/"), handler)) +} + +// ServeHTTP implements http.Handler, making Router usable as a handler. +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + r.mux.ServeHTTP(w, req) +} + +// prefixedPattern inserts the router prefix into a pattern. It is aware of +// method prefixes: "GET /users" with prefix "/api" becomes "GET /api/users". +func (r *Router) prefixedPattern(pattern string) string { + if r.prefix == "" { + return pattern + } + + // Split method prefix if present: "GET /users" → method="GET ", path="/users" + method, path, hasMethod := splitMethodPattern(pattern) + + path = r.prefix + path + + if hasMethod { + return method + path + } + return path +} + +// splitMethodPattern splits "GET /path" into ("GET ", "/path", true). +// If there is no method prefix, returns ("", pattern, false). +func splitMethodPattern(pattern string) (method, path string, hasMethod bool) { + if idx := strings.IndexByte(pattern, ' '); idx >= 0 { + return pattern[:idx+1], pattern[idx+1:], true + } + return "", pattern, false +} + +func (r *Router) middlewaresSnapshot() []Middleware { + if len(r.middlewares) == 0 { + return nil + } + cp := make([]Middleware, len(r.middlewares)) + copy(cp, r.middlewares) + return cp +} diff --git a/server/route_test.go b/server/route_test.go new file mode 100644 index 0000000..aba99e7 --- /dev/null +++ b/server/route_test.go @@ -0,0 +1,143 @@ +package server_test + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.codelab.vc/pkg/httpx/server" +) + +func TestRouter(t *testing.T) { + t.Run("basic route", func(t *testing.T) { + r := server.NewRouter() + r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("world")) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/hello", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + if body := w.Body.String(); body != "world" { + t.Fatalf("got body %q, want %q", body, "world") + } + }) + + t.Run("Handle with http.Handler", func(t *testing.T) { + r := server.NewRouter() + r.Handle("GET /ping", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("pong")) + })) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/ping", nil) + r.ServeHTTP(w, req) + + if body := w.Body.String(); body != "pong" { + t.Fatalf("got body %q, want %q", body, "pong") + } + }) + + t.Run("path parameter", func(t *testing.T) { + r := server.NewRouter() + r.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) { + _, _ = w.Write([]byte("user:" + req.PathValue("id"))) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/users/42", nil) + r.ServeHTTP(w, req) + + if body := w.Body.String(); body != "user:42" { + t.Fatalf("got body %q, want %q", body, "user:42") + } + }) +} + +func TestRouterGroup(t *testing.T) { + t.Run("prefix is applied", func(t *testing.T) { + r := server.NewRouter() + api := r.Group("/api/v1") + api.HandleFunc("GET /users", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("users")) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/users", nil) + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("got status %d, want %d", w.Code, http.StatusOK) + } + if body := w.Body.String(); body != "users" { + t.Fatalf("got body %q, want %q", body, "users") + } + }) + + t.Run("nested groups", func(t *testing.T) { + r := server.NewRouter() + api := r.Group("/api") + v1 := api.Group("/v1") + v1.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("items")) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil) + r.ServeHTTP(w, req) + + if body := w.Body.String(); body != "items" { + t.Fatalf("got body %q, want %q", body, "items") + } + }) + + t.Run("group middleware", func(t *testing.T) { + var mwCalled bool + mw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mwCalled = true + next.ServeHTTP(w, r) + }) + } + + r := server.NewRouter() + g := r.Group("/admin", mw) + g.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("ok")) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil) + r.ServeHTTP(w, req) + + if !mwCalled { + t.Fatal("group middleware was not called") + } + }) +} + +func TestRouterMount(t *testing.T) { + t.Run("mounts sub-handler with prefix stripping", func(t *testing.T) { + sub := http.NewServeMux() + sub.HandleFunc("GET /info", func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("info")) + }) + + r := server.NewRouter() + r.Mount("/sub", sub) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/sub/info", nil) + r.ServeHTTP(w, req) + + body, _ := io.ReadAll(w.Body) + if string(body) != "info" { + t.Fatalf("got body %q, want %q", body, "info") + } + }) +} diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..40b2b72 --- /dev/null +++ b/server/server.go @@ -0,0 +1,173 @@ +package server + +import ( + "context" + "errors" + "log/slog" + "net" + "net/http" + "os/signal" + "sync/atomic" + "syscall" + "time" +) + +// Server is a production-ready HTTP server with graceful shutdown, +// middleware support, and signal handling. +type Server struct { + httpServer *http.Server + listener net.Listener + addr atomic.Value + logger *slog.Logger + shutdownTimeout time.Duration + onShutdown []func() + listenAddr string +} + +// New creates a new Server that will serve the given handler with the +// provided options. Middleware from options is applied to the handler. +func New(handler http.Handler, opts ...Option) *Server { + o := &serverOptions{ + addr: ":8080", + shutdownTimeout: 15 * time.Second, + } + for _, opt := range opts { + opt(o) + } + + // Apply middleware chain to the handler. + if len(o.middlewares) > 0 { + handler = Chain(o.middlewares...)(handler) + } + + srv := &Server{ + httpServer: &http.Server{ + Handler: handler, + ReadTimeout: o.readTimeout, + ReadHeaderTimeout: o.readHeaderTimeout, + WriteTimeout: o.writeTimeout, + IdleTimeout: o.idleTimeout, + }, + logger: o.logger, + shutdownTimeout: o.shutdownTimeout, + onShutdown: o.onShutdown, + listenAddr: o.addr, + } + + return srv +} + +// ListenAndServe starts the server and blocks until a SIGINT or SIGTERM +// signal is received. It then performs a graceful shutdown within the +// configured shutdown timeout. +// +// Returns nil on clean shutdown or an error if listen/shutdown fails. +func (s *Server) ListenAndServe() error { + ln, err := net.Listen("tcp", s.listenAddr) + if err != nil { + return err + } + s.listener = ln + s.addr.Store(ln.Addr().String()) + + s.log("server started", slog.String("addr", ln.Addr().String())) + + // Wait for signal in context. + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 1) + go func() { + if err := s.httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + close(errCh) + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + stop() + return s.shutdown() + } +} + +// ListenAndServeTLS starts the server with TLS and blocks until a signal +// is received. +func (s *Server) ListenAndServeTLS(certFile, keyFile string) error { + ln, err := net.Listen("tcp", s.listenAddr) + if err != nil { + return err + } + s.listener = ln + s.addr.Store(ln.Addr().String()) + + s.log("server started (TLS)", slog.String("addr", ln.Addr().String())) + + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 1) + go func() { + if err := s.httpServer.ServeTLS(ln, certFile, keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) { + errCh <- err + } + close(errCh) + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + stop() + return s.shutdown() + } +} + +// Shutdown gracefully shuts down the server. It calls any registered +// onShutdown hooks, then waits for active connections to drain within +// the shutdown timeout. +func (s *Server) Shutdown(ctx context.Context) error { + s.runOnShutdown() + return s.httpServer.Shutdown(ctx) +} + +// Addr returns the listener address after the server has started. +// Returns an empty string if the server has not started yet. +func (s *Server) Addr() string { + v := s.addr.Load() + if v == nil { + return "" + } + return v.(string) +} + +func (s *Server) shutdown() error { + s.log("shutting down") + + s.runOnShutdown() + + ctx, cancel := context.WithTimeout(context.Background(), s.shutdownTimeout) + defer cancel() + + if err := s.httpServer.Shutdown(ctx); err != nil { + s.log("shutdown error", slog.String("error", err.Error())) + return err + } + + s.log("server stopped") + return nil +} + +func (s *Server) runOnShutdown() { + for _, fn := range s.onShutdown { + fn() + } +} + +func (s *Server) log(msg string, attrs ...slog.Attr) { + if s.logger != nil { + s.logger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrs...) + } +} diff --git a/server/server_test.go b/server/server_test.go new file mode 100644 index 0000000..775ce04 --- /dev/null +++ b/server/server_test.go @@ -0,0 +1,131 @@ +package server_test + +import ( + "context" + "io" + "net/http" + "testing" + "time" + + "git.codelab.vc/pkg/httpx/server" +) + +func TestServerLifecycle(t *testing.T) { + t.Run("starts and serves requests", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("hello")) + }) + + srv := server.New(handler, server.WithAddr(":0")) + + // Start in background and wait for addr. + errCh := make(chan error, 1) + go func() { errCh <- srv.ListenAndServe() }() + + waitForAddr(t, srv) + + resp, err := http.Get("http://" + srv.Addr()) + if err != nil { + t.Fatalf("GET failed: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if string(body) != "hello" { + t.Fatalf("got body %q, want %q", body, "hello") + } + + // Shutdown. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + t.Fatalf("shutdown failed: %v", err) + } + }) + + t.Run("addr returns empty before start", func(t *testing.T) { + srv := server.New(http.NotFoundHandler()) + if addr := srv.Addr(); addr != "" { + t.Fatalf("got addr %q before start, want empty", addr) + } + }) +} + +func TestGracefulShutdown(t *testing.T) { + t.Run("calls onShutdown hooks", func(t *testing.T) { + called := false + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + srv := server.New(handler, + server.WithAddr(":0"), + server.WithOnShutdown(func() { called = true }), + ) + + go func() { _ = srv.ListenAndServe() }() + waitForAddr(t, srv) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + t.Fatalf("shutdown failed: %v", err) + } + + if !called { + t.Fatal("onShutdown hook was not called") + } + }) +} + +func TestServerWithMiddleware(t *testing.T) { + t.Run("applies middleware from options", func(t *testing.T) { + var called bool + mw := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + next.ServeHTTP(w, r) + }) + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + srv := server.New(handler, + server.WithAddr(":0"), + server.WithMiddleware(mw), + ) + + go func() { _ = srv.ListenAndServe() }() + waitForAddr(t, srv) + + resp, err := http.Get("http://" + srv.Addr()) + if err != nil { + t.Fatalf("GET failed: %v", err) + } + resp.Body.Close() + + if !called { + t.Fatal("middleware was not called") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = srv.Shutdown(ctx) + }) +} + +// waitForAddr polls until the server's Addr() is non-empty. +func waitForAddr(t *testing.T, srv *server.Server) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if srv.Addr() != "" { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatal("server did not start in time") +}