Add production-ready HTTP server package with routing, health checks, and middleware

Introduces server/ sub-package as the server-side companion to the existing Client.
Includes Router (over http.ServeMux with groups and mounting), graceful shutdown with
signal handling, health endpoints (/healthz, /readyz), and built-in middlewares
(RequestID, Recovery, Logging). Zero external dependencies.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-21 13:41:54 +03:00
parent 6b901c931e
commit cea75d198b
12 changed files with 1215 additions and 0 deletions

55
server/health.go Normal file
View File

@@ -0,0 +1,55 @@
package server
import (
"encoding/json"
"net/http"
)
// ReadinessChecker is a function that reports whether a dependency is ready.
// Return nil if healthy, or an error describing the problem.
type ReadinessChecker func() error
// HealthHandler returns an http.Handler that exposes liveness and readiness
// endpoints:
//
// - GET /healthz — liveness check, always returns 200 OK
// - GET /readyz — readiness check, returns 200 if all checkers pass, 503 otherwise
func HealthHandler(checkers ...ReadinessChecker) http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
})
mux.HandleFunc("GET /readyz", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
var errs []string
for _, check := range checkers {
if err := check(); err != nil {
errs = append(errs, err.Error())
}
}
if len(errs) > 0 {
w.WriteHeader(http.StatusServiceUnavailable)
_ = json.NewEncoder(w).Encode(healthResponse{
Status: "unavailable",
Errors: errs,
})
return
}
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
})
return mux
}
type healthResponse struct {
Status string `json:"status"`
Errors []string `json:"errors,omitempty"`
}

90
server/health_test.go Normal file
View File

@@ -0,0 +1,90 @@
package server_test
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestHealthHandler(t *testing.T) {
t.Run("liveness always returns 200", func(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
if resp["status"] != "ok" {
t.Fatalf("got status %q, want %q", resp["status"], "ok")
}
})
t.Run("readiness returns 200 when all checks pass", func(t *testing.T) {
h := server.HealthHandler(
func() error { return nil },
func() error { return nil },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
t.Run("readiness returns 503 when a check fails", func(t *testing.T) {
h := server.HealthHandler(
func() error { return nil },
func() error { return errors.New("db down") },
)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
}
var resp map[string]any
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatalf("decode failed: %v", err)
}
if resp["status"] != "unavailable" {
t.Fatalf("got status %q, want %q", resp["status"], "unavailable")
}
errs, ok := resp["errors"].([]any)
if !ok || len(errs) != 1 {
t.Fatalf("expected 1 error, got %v", resp["errors"])
}
if errs[0] != "db down" {
t.Fatalf("got error %q, want %q", errs[0], "db down")
}
})
t.Run("readiness returns 200 with no checkers", func(t *testing.T) {
h := server.HealthHandler()
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
h.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
}

56
server/middleware.go Normal file
View File

@@ -0,0 +1,56 @@
package server
import "net/http"
// Middleware wraps an http.Handler to add behavior.
// This is the server-side counterpart of the client middleware type
// func(http.RoundTripper) http.RoundTripper.
type Middleware func(http.Handler) http.Handler
// Chain composes middlewares so that Chain(A, B, C)(handler) == A(B(C(handler))).
// Middlewares are applied from right to left: C wraps handler first, then B wraps
// the result, then A wraps last. This means A is the outermost layer and sees
// every request first.
func Chain(mws ...Middleware) Middleware {
return func(h http.Handler) http.Handler {
for i := len(mws) - 1; i >= 0; i-- {
h = mws[i](h)
}
return h
}
}
// statusWriter wraps http.ResponseWriter to capture the response status code.
// It implements Unwrap() so that http.ResponseController can access the
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
type statusWriter struct {
http.ResponseWriter
status int
written bool
}
// WriteHeader captures the status code and delegates to the underlying writer.
func (w *statusWriter) WriteHeader(code int) {
if !w.written {
w.status = code
w.written = true
}
w.ResponseWriter.WriteHeader(code)
}
// Write delegates to the underlying writer, defaulting status to 200 if
// WriteHeader was not called explicitly.
func (w *statusWriter) Write(b []byte) (int, error) {
if !w.written {
w.status = http.StatusOK
w.written = true
}
return w.ResponseWriter.Write(b)
}
// Unwrap returns the underlying ResponseWriter. This is required for
// http.ResponseController to detect optional interfaces like http.Flusher
// and http.Hijacker on the original writer.
func (w *statusWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}

View File

@@ -0,0 +1,39 @@
package server
import (
"log/slog"
"net/http"
"time"
)
// Logging returns a middleware that logs each request's method, path,
// status code, duration, and request ID using the provided structured logger.
func Logging(logger *slog.Logger) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(sw, r)
duration := time.Since(start)
attrs := []slog.Attr{
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int("status", sw.status),
slog.Duration("duration", duration),
}
if id := RequestIDFromContext(r.Context()); id != "" {
attrs = append(attrs, slog.String("request_id", id))
}
level := slog.LevelInfo
if sw.status >= http.StatusInternalServerError {
level = slog.LevelError
}
logger.LogAttrs(r.Context(), level, "request completed", attrs...)
})
}
}

View File

@@ -0,0 +1,50 @@
package server
import (
"log/slog"
"net/http"
"runtime/debug"
)
// RecoveryOption configures the Recovery middleware.
type RecoveryOption func(*recoveryOptions)
type recoveryOptions struct {
logger *slog.Logger
}
// WithRecoveryLogger sets the logger for the Recovery middleware.
// If not set, panics are recovered silently (500 is still returned).
func WithRecoveryLogger(l *slog.Logger) RecoveryOption {
return func(o *recoveryOptions) { o.logger = l }
}
// Recovery returns a middleware that recovers from panics in downstream
// handlers. A recovered panic results in a 500 Internal Server Error
// response and is logged (if a logger is configured) with the stack trace.
func Recovery(opts ...RecoveryOption) Middleware {
o := &recoveryOptions{}
for _, opt := range opts {
opt(o)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if v := recover(); v != nil {
if o.logger != nil {
o.logger.LogAttrs(r.Context(), slog.LevelError, "panic recovered",
slog.Any("panic", v),
slog.String("stack", string(debug.Stack())),
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.String("request_id", RequestIDFromContext(r.Context())),
)
}
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}()
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,52 @@
package server
import (
"context"
"crypto/rand"
"fmt"
"net/http"
)
type requestIDKey struct{}
// RequestID returns a middleware that assigns a unique request ID to each
// request. If the incoming request already has an X-Request-Id header, that
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
//
// The request ID is stored in the request context (retrieve with
// RequestIDFromContext) and set on the response X-Request-Id header.
func RequestID() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-Id")
if id == "" {
id = newUUID()
}
ctx := context.WithValue(r.Context(), requestIDKey{}, id)
w.Header().Set("X-Request-Id", id)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RequestIDFromContext returns the request ID from the context, or an empty
// string if none is set.
func RequestIDFromContext(ctx context.Context) string {
id, _ := ctx.Value(requestIDKey{}).(string)
return id
}
// newUUID generates a UUID v4 string using crypto/rand.
func newUUID() string {
var uuid [16]byte
_, _ = rand.Read(uuid[:])
// Set version 4 (bits 12-15 of time_hi_and_version).
uuid[6] = (uuid[6] & 0x0f) | 0x40
// Set variant bits (10xx).
uuid[8] = (uuid[8] & 0x3f) | 0x80
return fmt.Sprintf("%x-%x-%x-%x-%x",
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
}

234
server/middleware_test.go Normal file
View File

@@ -0,0 +1,234 @@
package server_test
import (
"bytes"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestChain(t *testing.T) {
t.Run("applies middlewares in correct order", func(t *testing.T) {
var order []string
mwA := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "A-before")
next.ServeHTTP(w, r)
order = append(order, "A-after")
})
}
mwB := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
order = append(order, "B-before")
next.ServeHTTP(w, r)
order = append(order, "B-after")
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
order = append(order, "handler")
w.WriteHeader(http.StatusOK)
})
chained := server.Chain(mwA, mwB)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
expected := []string{"A-before", "B-before", "handler", "B-after", "A-after"}
if len(order) != len(expected) {
t.Fatalf("got %v, want %v", order, expected)
}
for i, v := range expected {
if order[i] != v {
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
}
}
})
t.Run("empty chain returns handler unchanged", func(t *testing.T) {
called := false
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
called = true
w.WriteHeader(http.StatusOK)
})
chained := server.Chain()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
chained.ServeHTTP(w, req)
if !called {
t.Fatal("handler was not called")
}
})
}
func TestRequestID(t *testing.T) {
t.Run("generates ID when not present", func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if gotID == "" {
t.Fatal("expected request ID in context, got empty")
}
if w.Header().Get("X-Request-Id") != gotID {
t.Fatalf("response header %q != context ID %q", w.Header().Get("X-Request-Id"), gotID)
}
// UUID v4 format: 8-4-4-4-12 hex chars.
if len(gotID) != 36 {
t.Fatalf("expected UUID length 36, got %d: %q", len(gotID), gotID)
}
})
t.Run("preserves existing ID", func(t *testing.T) {
var gotID string
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotID = server.RequestIDFromContext(r.Context())
w.WriteHeader(http.StatusOK)
})
mw := server.RequestID()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set("X-Request-Id", "custom-123")
mw.ServeHTTP(w, req)
if gotID != "custom-123" {
t.Fatalf("got ID %q, want %q", gotID, "custom-123")
}
})
t.Run("context without ID returns empty", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
if id := server.RequestIDFromContext(req.Context()); id != "" {
t.Fatalf("expected empty, got %q", id)
}
})
}
func TestRecovery(t *testing.T) {
t.Run("recovers from panic and returns 500", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("something went wrong")
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusInternalServerError {
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
}
})
t.Run("logs panic with logger", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
panic("boom")
})
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if !strings.Contains(buf.String(), "panic recovered") {
t.Fatalf("expected log to contain 'panic recovered', got %q", buf.String())
}
if !strings.Contains(buf.String(), "boom") {
t.Fatalf("expected log to contain 'boom', got %q", buf.String())
}
})
t.Run("passes through without panic", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
mw := server.Recovery()(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
})
}
func TestLogging(t *testing.T) {
t.Run("logs request details", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusCreated)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "request completed") {
t.Fatalf("expected 'request completed' in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "POST") {
t.Fatalf("expected method in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "/api/users") {
t.Fatalf("expected path in log, got %q", logOutput)
}
if !strings.Contains(logOutput, "status=201") {
t.Fatalf("expected status=201 in log, got %q", logOutput)
}
})
t.Run("logs error level for 5xx", func(t *testing.T) {
var buf bytes.Buffer
logger := slog.New(slog.NewTextHandler(&buf, nil))
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusBadGateway)
})
mw := server.Logging(logger)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
mw.ServeHTTP(w, req)
logOutput := buf.String()
if !strings.Contains(logOutput, "level=ERROR") {
t.Fatalf("expected ERROR level in log, got %q", logOutput)
}
})
}

89
server/options.go Normal file
View File

@@ -0,0 +1,89 @@
package server
import (
"log/slog"
"time"
)
type serverOptions struct {
addr string
readTimeout time.Duration
readHeaderTimeout time.Duration
writeTimeout time.Duration
idleTimeout time.Duration
shutdownTimeout time.Duration
logger *slog.Logger
middlewares []Middleware
onShutdown []func()
}
// Option configures a Server.
type Option func(*serverOptions)
// WithAddr sets the listen address. Default is ":8080".
func WithAddr(addr string) Option {
return func(o *serverOptions) { o.addr = addr }
}
// WithReadTimeout sets the maximum duration for reading the entire request.
func WithReadTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.readTimeout = d }
}
// WithReadHeaderTimeout sets the maximum duration for reading request headers.
func WithReadHeaderTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.readHeaderTimeout = d }
}
// WithWriteTimeout sets the maximum duration before timing out writes of the response.
func WithWriteTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.writeTimeout = d }
}
// WithIdleTimeout sets the maximum amount of time to wait for the next request
// when keep-alives are enabled.
func WithIdleTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.idleTimeout = d }
}
// WithShutdownTimeout sets the maximum duration to wait for active connections
// to close during graceful shutdown. Default is 15 seconds.
func WithShutdownTimeout(d time.Duration) Option {
return func(o *serverOptions) { o.shutdownTimeout = d }
}
// WithLogger sets the structured logger used by the server for lifecycle events.
func WithLogger(l *slog.Logger) Option {
return func(o *serverOptions) { o.logger = l }
}
// WithMiddleware appends server middlewares to the chain.
// These are applied to the handler in the order given.
func WithMiddleware(mws ...Middleware) Option {
return func(o *serverOptions) { o.middlewares = append(o.middlewares, mws...) }
}
// WithOnShutdown registers a function to be called during graceful shutdown,
// before the HTTP server begins draining connections.
func WithOnShutdown(fn func()) Option {
return func(o *serverOptions) { o.onShutdown = append(o.onShutdown, fn) }
}
// Defaults returns a production-ready set of options including standard
// middleware (RequestID, Recovery, Logging), sensible timeouts, and the
// provided logger.
//
// Middleware order: RequestID → Recovery → Logging → user handler.
func Defaults(logger *slog.Logger) []Option {
return []Option{
WithReadHeaderTimeout(10 * time.Second),
WithIdleTimeout(120 * time.Second),
WithShutdownTimeout(15 * time.Second),
WithLogger(logger),
WithMiddleware(
RequestID(),
Recovery(WithRecoveryLogger(logger)),
Logging(logger),
),
}
}

103
server/route.go Normal file
View File

@@ -0,0 +1,103 @@
package server
import (
"net/http"
"strings"
)
// Router is a lightweight wrapper around http.ServeMux that adds middleware
// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns
// like "GET /users/{id}".
type Router struct {
mux *http.ServeMux
prefix string
middlewares []Middleware
}
// NewRouter creates a new Router backed by a fresh http.ServeMux.
func NewRouter() *Router {
return &Router{
mux: http.NewServeMux(),
}
}
// Handle registers a handler for the given pattern. The pattern follows
// http.ServeMux conventions, including method-based patterns like "GET /users".
func (r *Router) Handle(pattern string, handler http.Handler) {
if len(r.middlewares) > 0 {
handler = Chain(r.middlewares...)(handler)
}
r.mux.Handle(r.prefixedPattern(pattern), handler)
}
// HandleFunc registers a handler function for the given pattern.
func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) {
r.Handle(pattern, fn)
}
// Group creates a sub-router with a shared prefix and optional middleware.
// Patterns registered on the group are prefixed automatically. The group
// shares the underlying ServeMux with the parent router.
//
// Example:
//
// api := router.Group("/api/v1", authMiddleware)
// api.HandleFunc("GET /users", listUsers) // registers "GET /api/v1/users"
func (r *Router) Group(prefix string, mws ...Middleware) *Router {
return &Router{
mux: r.mux,
prefix: r.prefix + prefix,
middlewares: append(r.middlewaresSnapshot(), mws...),
}
}
// Mount attaches an http.Handler under the given prefix. All requests
// starting with prefix are forwarded to the handler with the prefix stripped.
func (r *Router) Mount(prefix string, handler http.Handler) {
full := r.prefix + prefix
if !strings.HasSuffix(full, "/") {
full += "/"
}
r.mux.Handle(full, http.StripPrefix(strings.TrimSuffix(full, "/"), handler))
}
// ServeHTTP implements http.Handler, making Router usable as a handler.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.mux.ServeHTTP(w, req)
}
// prefixedPattern inserts the router prefix into a pattern. It is aware of
// method prefixes: "GET /users" with prefix "/api" becomes "GET /api/users".
func (r *Router) prefixedPattern(pattern string) string {
if r.prefix == "" {
return pattern
}
// Split method prefix if present: "GET /users" → method="GET ", path="/users"
method, path, hasMethod := splitMethodPattern(pattern)
path = r.prefix + path
if hasMethod {
return method + path
}
return path
}
// splitMethodPattern splits "GET /path" into ("GET ", "/path", true).
// If there is no method prefix, returns ("", pattern, false).
func splitMethodPattern(pattern string) (method, path string, hasMethod bool) {
if idx := strings.IndexByte(pattern, ' '); idx >= 0 {
return pattern[:idx+1], pattern[idx+1:], true
}
return "", pattern, false
}
func (r *Router) middlewaresSnapshot() []Middleware {
if len(r.middlewares) == 0 {
return nil
}
cp := make([]Middleware, len(r.middlewares))
copy(cp, r.middlewares)
return cp
}

143
server/route_test.go Normal file
View File

@@ -0,0 +1,143 @@
package server_test
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"git.codelab.vc/pkg/httpx/server"
)
func TestRouter(t *testing.T) {
t.Run("basic route", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("world"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "world" {
t.Fatalf("got body %q, want %q", body, "world")
}
})
t.Run("Handle with http.Handler", func(t *testing.T) {
r := server.NewRouter()
r.Handle("GET /ping", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("pong"))
}))
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "pong" {
t.Fatalf("got body %q, want %q", body, "pong")
}
})
t.Run("path parameter", func(t *testing.T) {
r := server.NewRouter()
r.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
_, _ = w.Write([]byte("user:" + req.PathValue("id")))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/users/42", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "user:42" {
t.Fatalf("got body %q, want %q", body, "user:42")
}
})
}
func TestRouterGroup(t *testing.T) {
t.Run("prefix is applied", func(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api/v1")
api.HandleFunc("GET /users", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("users"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
}
if body := w.Body.String(); body != "users" {
t.Fatalf("got body %q, want %q", body, "users")
}
})
t.Run("nested groups", func(t *testing.T) {
r := server.NewRouter()
api := r.Group("/api")
v1 := api.Group("/v1")
v1.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("items"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
r.ServeHTTP(w, req)
if body := w.Body.String(); body != "items" {
t.Fatalf("got body %q, want %q", body, "items")
}
})
t.Run("group middleware", func(t *testing.T) {
var mwCalled bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mwCalled = true
next.ServeHTTP(w, r)
})
}
r := server.NewRouter()
g := r.Group("/admin", mw)
g.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("ok"))
})
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil)
r.ServeHTTP(w, req)
if !mwCalled {
t.Fatal("group middleware was not called")
}
})
}
func TestRouterMount(t *testing.T) {
t.Run("mounts sub-handler with prefix stripping", func(t *testing.T) {
sub := http.NewServeMux()
sub.HandleFunc("GET /info", func(w http.ResponseWriter, _ *http.Request) {
_, _ = w.Write([]byte("info"))
})
r := server.NewRouter()
r.Mount("/sub", sub)
w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/sub/info", nil)
r.ServeHTTP(w, req)
body, _ := io.ReadAll(w.Body)
if string(body) != "info" {
t.Fatalf("got body %q, want %q", body, "info")
}
})
}

173
server/server.go Normal file
View File

@@ -0,0 +1,173 @@
package server
import (
"context"
"errors"
"log/slog"
"net"
"net/http"
"os/signal"
"sync/atomic"
"syscall"
"time"
)
// Server is a production-ready HTTP server with graceful shutdown,
// middleware support, and signal handling.
type Server struct {
httpServer *http.Server
listener net.Listener
addr atomic.Value
logger *slog.Logger
shutdownTimeout time.Duration
onShutdown []func()
listenAddr string
}
// New creates a new Server that will serve the given handler with the
// provided options. Middleware from options is applied to the handler.
func New(handler http.Handler, opts ...Option) *Server {
o := &serverOptions{
addr: ":8080",
shutdownTimeout: 15 * time.Second,
}
for _, opt := range opts {
opt(o)
}
// Apply middleware chain to the handler.
if len(o.middlewares) > 0 {
handler = Chain(o.middlewares...)(handler)
}
srv := &Server{
httpServer: &http.Server{
Handler: handler,
ReadTimeout: o.readTimeout,
ReadHeaderTimeout: o.readHeaderTimeout,
WriteTimeout: o.writeTimeout,
IdleTimeout: o.idleTimeout,
},
logger: o.logger,
shutdownTimeout: o.shutdownTimeout,
onShutdown: o.onShutdown,
listenAddr: o.addr,
}
return srv
}
// ListenAndServe starts the server and blocks until a SIGINT or SIGTERM
// signal is received. It then performs a graceful shutdown within the
// configured shutdown timeout.
//
// Returns nil on clean shutdown or an error if listen/shutdown fails.
func (s *Server) ListenAndServe() error {
ln, err := net.Listen("tcp", s.listenAddr)
if err != nil {
return err
}
s.listener = ln
s.addr.Store(ln.Addr().String())
s.log("server started", slog.String("addr", ln.Addr().String()))
// Wait for signal in context.
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
errCh := make(chan error, 1)
go func() {
if err := s.httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
close(errCh)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
stop()
return s.shutdown()
}
}
// ListenAndServeTLS starts the server with TLS and blocks until a signal
// is received.
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
ln, err := net.Listen("tcp", s.listenAddr)
if err != nil {
return err
}
s.listener = ln
s.addr.Store(ln.Addr().String())
s.log("server started (TLS)", slog.String("addr", ln.Addr().String()))
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
errCh := make(chan error, 1)
go func() {
if err := s.httpServer.ServeTLS(ln, certFile, keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) {
errCh <- err
}
close(errCh)
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
stop()
return s.shutdown()
}
}
// Shutdown gracefully shuts down the server. It calls any registered
// onShutdown hooks, then waits for active connections to drain within
// the shutdown timeout.
func (s *Server) Shutdown(ctx context.Context) error {
s.runOnShutdown()
return s.httpServer.Shutdown(ctx)
}
// Addr returns the listener address after the server has started.
// Returns an empty string if the server has not started yet.
func (s *Server) Addr() string {
v := s.addr.Load()
if v == nil {
return ""
}
return v.(string)
}
func (s *Server) shutdown() error {
s.log("shutting down")
s.runOnShutdown()
ctx, cancel := context.WithTimeout(context.Background(), s.shutdownTimeout)
defer cancel()
if err := s.httpServer.Shutdown(ctx); err != nil {
s.log("shutdown error", slog.String("error", err.Error()))
return err
}
s.log("server stopped")
return nil
}
func (s *Server) runOnShutdown() {
for _, fn := range s.onShutdown {
fn()
}
}
func (s *Server) log(msg string, attrs ...slog.Attr) {
if s.logger != nil {
s.logger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrs...)
}
}

131
server/server_test.go Normal file
View File

@@ -0,0 +1,131 @@
package server_test
import (
"context"
"io"
"net/http"
"testing"
"time"
"git.codelab.vc/pkg/httpx/server"
)
func TestServerLifecycle(t *testing.T) {
t.Run("starts and serves requests", func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("hello"))
})
srv := server.New(handler, server.WithAddr(":0"))
// Start in background and wait for addr.
errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if string(body) != "hello" {
t.Fatalf("got body %q, want %q", body, "hello")
}
// Shutdown.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
})
t.Run("addr returns empty before start", func(t *testing.T) {
srv := server.New(http.NotFoundHandler())
if addr := srv.Addr(); addr != "" {
t.Fatalf("got addr %q before start, want empty", addr)
}
})
}
func TestGracefulShutdown(t *testing.T) {
t.Run("calls onShutdown hooks", func(t *testing.T) {
called := false
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithOnShutdown(func() { called = true }),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatalf("shutdown failed: %v", err)
}
if !called {
t.Fatal("onShutdown hook was not called")
}
})
}
func TestServerWithMiddleware(t *testing.T) {
t.Run("applies middleware from options", func(t *testing.T) {
var called bool
mw := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
next.ServeHTTP(w, r)
})
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
})
srv := server.New(handler,
server.WithAddr(":0"),
server.WithMiddleware(mw),
)
go func() { _ = srv.ListenAndServe() }()
waitForAddr(t, srv)
resp, err := http.Get("http://" + srv.Addr())
if err != nil {
t.Fatalf("GET failed: %v", err)
}
resp.Body.Close()
if !called {
t.Fatal("middleware was not called")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = srv.Shutdown(ctx)
})
}
// waitForAddr polls until the server's Addr() is non-empty.
func waitForAddr(t *testing.T, srv *server.Server) {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if srv.Addr() != "" {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatal("server did not start in time")
}