Add production-ready HTTP server package with routing, health checks, and middleware
Introduces server/ sub-package as the server-side companion to the existing Client. Includes Router (over http.ServeMux with groups and mounting), graceful shutdown with signal handling, health endpoints (/healthz, /readyz), and built-in middlewares (RequestID, Recovery, Logging). Zero external dependencies. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
55
server/health.go
Normal file
55
server/health.go
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReadinessChecker is a function that reports whether a dependency is ready.
|
||||||
|
// Return nil if healthy, or an error describing the problem.
|
||||||
|
type ReadinessChecker func() error
|
||||||
|
|
||||||
|
// HealthHandler returns an http.Handler that exposes liveness and readiness
|
||||||
|
// endpoints:
|
||||||
|
//
|
||||||
|
// - GET /healthz — liveness check, always returns 200 OK
|
||||||
|
// - GET /readyz — readiness check, returns 200 if all checkers pass, 503 otherwise
|
||||||
|
func HealthHandler(checkers ...ReadinessChecker) http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
mux.HandleFunc("GET /healthz", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("GET /readyz", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
var errs []string
|
||||||
|
for _, check := range checkers {
|
||||||
|
if err := check(); err != nil {
|
||||||
|
errs = append(errs, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
_ = json.NewEncoder(w).Encode(healthResponse{
|
||||||
|
Status: "unavailable",
|
||||||
|
Errors: errs,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(healthResponse{Status: "ok"})
|
||||||
|
})
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
type healthResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Errors []string `json:"errors,omitempty"`
|
||||||
|
}
|
||||||
90
server/health_test.go
Normal file
90
server/health_test.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthHandler(t *testing.T) {
|
||||||
|
t.Run("liveness always returns 200", func(t *testing.T) {
|
||||||
|
h := server.HealthHandler()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("decode failed: %v", err)
|
||||||
|
}
|
||||||
|
if resp["status"] != "ok" {
|
||||||
|
t.Fatalf("got status %q, want %q", resp["status"], "ok")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("readiness returns 200 when all checks pass", func(t *testing.T) {
|
||||||
|
h := server.HealthHandler(
|
||||||
|
func() error { return nil },
|
||||||
|
func() error { return nil },
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("readiness returns 503 when a check fails", func(t *testing.T) {
|
||||||
|
h := server.HealthHandler(
|
||||||
|
func() error { return nil },
|
||||||
|
func() error { return errors.New("db down") },
|
||||||
|
)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]any
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("decode failed: %v", err)
|
||||||
|
}
|
||||||
|
if resp["status"] != "unavailable" {
|
||||||
|
t.Fatalf("got status %q, want %q", resp["status"], "unavailable")
|
||||||
|
}
|
||||||
|
errs, ok := resp["errors"].([]any)
|
||||||
|
if !ok || len(errs) != 1 {
|
||||||
|
t.Fatalf("expected 1 error, got %v", resp["errors"])
|
||||||
|
}
|
||||||
|
if errs[0] != "db down" {
|
||||||
|
t.Fatalf("got error %q, want %q", errs[0], "db down")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("readiness returns 200 with no checkers", func(t *testing.T) {
|
||||||
|
h := server.HealthHandler()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/readyz", nil)
|
||||||
|
h.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
56
server/middleware.go
Normal file
56
server/middleware.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// Middleware wraps an http.Handler to add behavior.
|
||||||
|
// This is the server-side counterpart of the client middleware type
|
||||||
|
// func(http.RoundTripper) http.RoundTripper.
|
||||||
|
type Middleware func(http.Handler) http.Handler
|
||||||
|
|
||||||
|
// Chain composes middlewares so that Chain(A, B, C)(handler) == A(B(C(handler))).
|
||||||
|
// Middlewares are applied from right to left: C wraps handler first, then B wraps
|
||||||
|
// the result, then A wraps last. This means A is the outermost layer and sees
|
||||||
|
// every request first.
|
||||||
|
func Chain(mws ...Middleware) Middleware {
|
||||||
|
return func(h http.Handler) http.Handler {
|
||||||
|
for i := len(mws) - 1; i >= 0; i-- {
|
||||||
|
h = mws[i](h)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// statusWriter wraps http.ResponseWriter to capture the response status code.
|
||||||
|
// It implements Unwrap() so that http.ResponseController can access the
|
||||||
|
// underlying ResponseWriter's optional interfaces (Flusher, Hijacker, etc.).
|
||||||
|
type statusWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
status int
|
||||||
|
written bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||||
|
func (w *statusWriter) WriteHeader(code int) {
|
||||||
|
if !w.written {
|
||||||
|
w.status = code
|
||||||
|
w.written = true
|
||||||
|
}
|
||||||
|
w.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write delegates to the underlying writer, defaulting status to 200 if
|
||||||
|
// WriteHeader was not called explicitly.
|
||||||
|
func (w *statusWriter) Write(b []byte) (int, error) {
|
||||||
|
if !w.written {
|
||||||
|
w.status = http.StatusOK
|
||||||
|
w.written = true
|
||||||
|
}
|
||||||
|
return w.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the underlying ResponseWriter. This is required for
|
||||||
|
// http.ResponseController to detect optional interfaces like http.Flusher
|
||||||
|
// and http.Hijacker on the original writer.
|
||||||
|
func (w *statusWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return w.ResponseWriter
|
||||||
|
}
|
||||||
39
server/middleware_logging.go
Normal file
39
server/middleware_logging.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Logging returns a middleware that logs each request's method, path,
|
||||||
|
// status code, duration, and request ID using the provided structured logger.
|
||||||
|
func Logging(logger *slog.Logger) Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||||
|
next.ServeHTTP(sw, r)
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
attrs := []slog.Attr{
|
||||||
|
slog.String("method", r.Method),
|
||||||
|
slog.String("path", r.URL.Path),
|
||||||
|
slog.Int("status", sw.status),
|
||||||
|
slog.Duration("duration", duration),
|
||||||
|
}
|
||||||
|
|
||||||
|
if id := RequestIDFromContext(r.Context()); id != "" {
|
||||||
|
attrs = append(attrs, slog.String("request_id", id))
|
||||||
|
}
|
||||||
|
|
||||||
|
level := slog.LevelInfo
|
||||||
|
if sw.status >= http.StatusInternalServerError {
|
||||||
|
level = slog.LevelError
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogAttrs(r.Context(), level, "request completed", attrs...)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
50
server/middleware_recovery.go
Normal file
50
server/middleware_recovery.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RecoveryOption configures the Recovery middleware.
|
||||||
|
type RecoveryOption func(*recoveryOptions)
|
||||||
|
|
||||||
|
type recoveryOptions struct {
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRecoveryLogger sets the logger for the Recovery middleware.
|
||||||
|
// If not set, panics are recovered silently (500 is still returned).
|
||||||
|
func WithRecoveryLogger(l *slog.Logger) RecoveryOption {
|
||||||
|
return func(o *recoveryOptions) { o.logger = l }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recovery returns a middleware that recovers from panics in downstream
|
||||||
|
// handlers. A recovered panic results in a 500 Internal Server Error
|
||||||
|
// response and is logged (if a logger is configured) with the stack trace.
|
||||||
|
func Recovery(opts ...RecoveryOption) Middleware {
|
||||||
|
o := &recoveryOptions{}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer func() {
|
||||||
|
if v := recover(); v != nil {
|
||||||
|
if o.logger != nil {
|
||||||
|
o.logger.LogAttrs(r.Context(), slog.LevelError, "panic recovered",
|
||||||
|
slog.Any("panic", v),
|
||||||
|
slog.String("stack", string(debug.Stack())),
|
||||||
|
slog.String("method", r.Method),
|
||||||
|
slog.String("path", r.URL.Path),
|
||||||
|
slog.String("request_id", RequestIDFromContext(r.Context())),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
52
server/middleware_requestid.go
Normal file
52
server/middleware_requestid.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type requestIDKey struct{}
|
||||||
|
|
||||||
|
// RequestID returns a middleware that assigns a unique request ID to each
|
||||||
|
// request. If the incoming request already has an X-Request-Id header, that
|
||||||
|
// value is used. Otherwise a new UUID v4 is generated via crypto/rand.
|
||||||
|
//
|
||||||
|
// The request ID is stored in the request context (retrieve with
|
||||||
|
// RequestIDFromContext) and set on the response X-Request-Id header.
|
||||||
|
func RequestID() Middleware {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id := r.Header.Get("X-Request-Id")
|
||||||
|
if id == "" {
|
||||||
|
id = newUUID()
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), requestIDKey{}, id)
|
||||||
|
w.Header().Set("X-Request-Id", id)
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestIDFromContext returns the request ID from the context, or an empty
|
||||||
|
// string if none is set.
|
||||||
|
func RequestIDFromContext(ctx context.Context) string {
|
||||||
|
id, _ := ctx.Value(requestIDKey{}).(string)
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// newUUID generates a UUID v4 string using crypto/rand.
|
||||||
|
func newUUID() string {
|
||||||
|
var uuid [16]byte
|
||||||
|
_, _ = rand.Read(uuid[:])
|
||||||
|
|
||||||
|
// Set version 4 (bits 12-15 of time_hi_and_version).
|
||||||
|
uuid[6] = (uuid[6] & 0x0f) | 0x40
|
||||||
|
// Set variant bits (10xx).
|
||||||
|
uuid[8] = (uuid[8] & 0x3f) | 0x80
|
||||||
|
|
||||||
|
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||||
|
uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:16])
|
||||||
|
}
|
||||||
234
server/middleware_test.go
Normal file
234
server/middleware_test.go
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestChain(t *testing.T) {
|
||||||
|
t.Run("applies middlewares in correct order", func(t *testing.T) {
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
mwA := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
order = append(order, "A-before")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
order = append(order, "A-after")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
mwB := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
order = append(order, "B-before")
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
order = append(order, "B-after")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
order = append(order, "handler")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
chained := server.Chain(mwA, mwB)(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
chained.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
expected := []string{"A-before", "B-before", "handler", "B-after", "A-after"}
|
||||||
|
if len(order) != len(expected) {
|
||||||
|
t.Fatalf("got %v, want %v", order, expected)
|
||||||
|
}
|
||||||
|
for i, v := range expected {
|
||||||
|
if order[i] != v {
|
||||||
|
t.Fatalf("order[%d] = %q, want %q", i, order[i], v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty chain returns handler unchanged", func(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
chained := server.Chain()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
chained.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("handler was not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestID(t *testing.T) {
|
||||||
|
t.Run("generates ID when not present", func(t *testing.T) {
|
||||||
|
var gotID string
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotID = server.RequestIDFromContext(r.Context())
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := server.RequestID()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if gotID == "" {
|
||||||
|
t.Fatal("expected request ID in context, got empty")
|
||||||
|
}
|
||||||
|
if w.Header().Get("X-Request-Id") != gotID {
|
||||||
|
t.Fatalf("response header %q != context ID %q", w.Header().Get("X-Request-Id"), gotID)
|
||||||
|
}
|
||||||
|
// UUID v4 format: 8-4-4-4-12 hex chars.
|
||||||
|
if len(gotID) != 36 {
|
||||||
|
t.Fatalf("expected UUID length 36, got %d: %q", len(gotID), gotID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing ID", func(t *testing.T) {
|
||||||
|
var gotID string
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
gotID = server.RequestIDFromContext(r.Context())
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := server.RequestID()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Request-Id", "custom-123")
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if gotID != "custom-123" {
|
||||||
|
t.Fatalf("got ID %q, want %q", gotID, "custom-123")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("context without ID returns empty", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if id := server.RequestIDFromContext(req.Context()); id != "" {
|
||||||
|
t.Fatalf("expected empty, got %q", id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecovery(t *testing.T) {
|
||||||
|
t.Run("recovers from panic and returns 500", func(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
panic("something went wrong")
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := server.Recovery()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("logs panic with logger", func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
panic("boom")
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := server.Recovery(server.WithRecoveryLogger(logger))(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if !strings.Contains(buf.String(), "panic recovered") {
|
||||||
|
t.Fatalf("expected log to contain 'panic recovered', got %q", buf.String())
|
||||||
|
}
|
||||||
|
if !strings.Contains(buf.String(), "boom") {
|
||||||
|
t.Fatalf("expected log to contain 'boom', got %q", buf.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("passes through without panic", func(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := server.Recovery()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogging(t *testing.T) {
|
||||||
|
t.Run("logs request details", func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := server.Logging(logger)(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/users", nil)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
logOutput := buf.String()
|
||||||
|
if !strings.Contains(logOutput, "request completed") {
|
||||||
|
t.Fatalf("expected 'request completed' in log, got %q", logOutput)
|
||||||
|
}
|
||||||
|
if !strings.Contains(logOutput, "POST") {
|
||||||
|
t.Fatalf("expected method in log, got %q", logOutput)
|
||||||
|
}
|
||||||
|
if !strings.Contains(logOutput, "/api/users") {
|
||||||
|
t.Fatalf("expected path in log, got %q", logOutput)
|
||||||
|
}
|
||||||
|
if !strings.Contains(logOutput, "status=201") {
|
||||||
|
t.Fatalf("expected status=201 in log, got %q", logOutput)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("logs error level for 5xx", func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
})
|
||||||
|
|
||||||
|
mw := server.Logging(logger)(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
mw.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
logOutput := buf.String()
|
||||||
|
if !strings.Contains(logOutput, "level=ERROR") {
|
||||||
|
t.Fatalf("expected ERROR level in log, got %q", logOutput)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
89
server/options.go
Normal file
89
server/options.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serverOptions struct {
|
||||||
|
addr string
|
||||||
|
readTimeout time.Duration
|
||||||
|
readHeaderTimeout time.Duration
|
||||||
|
writeTimeout time.Duration
|
||||||
|
idleTimeout time.Duration
|
||||||
|
shutdownTimeout time.Duration
|
||||||
|
logger *slog.Logger
|
||||||
|
middlewares []Middleware
|
||||||
|
onShutdown []func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option configures a Server.
|
||||||
|
type Option func(*serverOptions)
|
||||||
|
|
||||||
|
// WithAddr sets the listen address. Default is ":8080".
|
||||||
|
func WithAddr(addr string) Option {
|
||||||
|
return func(o *serverOptions) { o.addr = addr }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithReadTimeout sets the maximum duration for reading the entire request.
|
||||||
|
func WithReadTimeout(d time.Duration) Option {
|
||||||
|
return func(o *serverOptions) { o.readTimeout = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithReadHeaderTimeout sets the maximum duration for reading request headers.
|
||||||
|
func WithReadHeaderTimeout(d time.Duration) Option {
|
||||||
|
return func(o *serverOptions) { o.readHeaderTimeout = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithWriteTimeout sets the maximum duration before timing out writes of the response.
|
||||||
|
func WithWriteTimeout(d time.Duration) Option {
|
||||||
|
return func(o *serverOptions) { o.writeTimeout = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithIdleTimeout sets the maximum amount of time to wait for the next request
|
||||||
|
// when keep-alives are enabled.
|
||||||
|
func WithIdleTimeout(d time.Duration) Option {
|
||||||
|
return func(o *serverOptions) { o.idleTimeout = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithShutdownTimeout sets the maximum duration to wait for active connections
|
||||||
|
// to close during graceful shutdown. Default is 15 seconds.
|
||||||
|
func WithShutdownTimeout(d time.Duration) Option {
|
||||||
|
return func(o *serverOptions) { o.shutdownTimeout = d }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithLogger sets the structured logger used by the server for lifecycle events.
|
||||||
|
func WithLogger(l *slog.Logger) Option {
|
||||||
|
return func(o *serverOptions) { o.logger = l }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithMiddleware appends server middlewares to the chain.
|
||||||
|
// These are applied to the handler in the order given.
|
||||||
|
func WithMiddleware(mws ...Middleware) Option {
|
||||||
|
return func(o *serverOptions) { o.middlewares = append(o.middlewares, mws...) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithOnShutdown registers a function to be called during graceful shutdown,
|
||||||
|
// before the HTTP server begins draining connections.
|
||||||
|
func WithOnShutdown(fn func()) Option {
|
||||||
|
return func(o *serverOptions) { o.onShutdown = append(o.onShutdown, fn) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Defaults returns a production-ready set of options including standard
|
||||||
|
// middleware (RequestID, Recovery, Logging), sensible timeouts, and the
|
||||||
|
// provided logger.
|
||||||
|
//
|
||||||
|
// Middleware order: RequestID → Recovery → Logging → user handler.
|
||||||
|
func Defaults(logger *slog.Logger) []Option {
|
||||||
|
return []Option{
|
||||||
|
WithReadHeaderTimeout(10 * time.Second),
|
||||||
|
WithIdleTimeout(120 * time.Second),
|
||||||
|
WithShutdownTimeout(15 * time.Second),
|
||||||
|
WithLogger(logger),
|
||||||
|
WithMiddleware(
|
||||||
|
RequestID(),
|
||||||
|
Recovery(WithRecoveryLogger(logger)),
|
||||||
|
Logging(logger),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
103
server/route.go
Normal file
103
server/route.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Router is a lightweight wrapper around http.ServeMux that adds middleware
|
||||||
|
// groups and sub-router mounting. It leverages Go 1.22+ enhanced patterns
|
||||||
|
// like "GET /users/{id}".
|
||||||
|
type Router struct {
|
||||||
|
mux *http.ServeMux
|
||||||
|
prefix string
|
||||||
|
middlewares []Middleware
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter creates a new Router backed by a fresh http.ServeMux.
|
||||||
|
func NewRouter() *Router {
|
||||||
|
return &Router{
|
||||||
|
mux: http.NewServeMux(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle registers a handler for the given pattern. The pattern follows
|
||||||
|
// http.ServeMux conventions, including method-based patterns like "GET /users".
|
||||||
|
func (r *Router) Handle(pattern string, handler http.Handler) {
|
||||||
|
if len(r.middlewares) > 0 {
|
||||||
|
handler = Chain(r.middlewares...)(handler)
|
||||||
|
}
|
||||||
|
r.mux.Handle(r.prefixedPattern(pattern), handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFunc registers a handler function for the given pattern.
|
||||||
|
func (r *Router) HandleFunc(pattern string, fn http.HandlerFunc) {
|
||||||
|
r.Handle(pattern, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group creates a sub-router with a shared prefix and optional middleware.
|
||||||
|
// Patterns registered on the group are prefixed automatically. The group
|
||||||
|
// shares the underlying ServeMux with the parent router.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// api := router.Group("/api/v1", authMiddleware)
|
||||||
|
// api.HandleFunc("GET /users", listUsers) // registers "GET /api/v1/users"
|
||||||
|
func (r *Router) Group(prefix string, mws ...Middleware) *Router {
|
||||||
|
return &Router{
|
||||||
|
mux: r.mux,
|
||||||
|
prefix: r.prefix + prefix,
|
||||||
|
middlewares: append(r.middlewaresSnapshot(), mws...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mount attaches an http.Handler under the given prefix. All requests
|
||||||
|
// starting with prefix are forwarded to the handler with the prefix stripped.
|
||||||
|
func (r *Router) Mount(prefix string, handler http.Handler) {
|
||||||
|
full := r.prefix + prefix
|
||||||
|
if !strings.HasSuffix(full, "/") {
|
||||||
|
full += "/"
|
||||||
|
}
|
||||||
|
r.mux.Handle(full, http.StripPrefix(strings.TrimSuffix(full, "/"), handler))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements http.Handler, making Router usable as a handler.
|
||||||
|
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
r.mux.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixedPattern inserts the router prefix into a pattern. It is aware of
|
||||||
|
// method prefixes: "GET /users" with prefix "/api" becomes "GET /api/users".
|
||||||
|
func (r *Router) prefixedPattern(pattern string) string {
|
||||||
|
if r.prefix == "" {
|
||||||
|
return pattern
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split method prefix if present: "GET /users" → method="GET ", path="/users"
|
||||||
|
method, path, hasMethod := splitMethodPattern(pattern)
|
||||||
|
|
||||||
|
path = r.prefix + path
|
||||||
|
|
||||||
|
if hasMethod {
|
||||||
|
return method + path
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitMethodPattern splits "GET /path" into ("GET ", "/path", true).
|
||||||
|
// If there is no method prefix, returns ("", pattern, false).
|
||||||
|
func splitMethodPattern(pattern string) (method, path string, hasMethod bool) {
|
||||||
|
if idx := strings.IndexByte(pattern, ' '); idx >= 0 {
|
||||||
|
return pattern[:idx+1], pattern[idx+1:], true
|
||||||
|
}
|
||||||
|
return "", pattern, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) middlewaresSnapshot() []Middleware {
|
||||||
|
if len(r.middlewares) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cp := make([]Middleware, len(r.middlewares))
|
||||||
|
copy(cp, r.middlewares)
|
||||||
|
return cp
|
||||||
|
}
|
||||||
143
server/route_test.go
Normal file
143
server/route_test.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRouter(t *testing.T) {
|
||||||
|
t.Run("basic route", func(t *testing.T) {
|
||||||
|
r := server.NewRouter()
|
||||||
|
r.HandleFunc("GET /hello", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("world"))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if body := w.Body.String(); body != "world" {
|
||||||
|
t.Fatalf("got body %q, want %q", body, "world")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Handle with http.Handler", func(t *testing.T) {
|
||||||
|
r := server.NewRouter()
|
||||||
|
r.Handle("GET /ping", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("pong"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if body := w.Body.String(); body != "pong" {
|
||||||
|
t.Fatalf("got body %q, want %q", body, "pong")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("path parameter", func(t *testing.T) {
|
||||||
|
r := server.NewRouter()
|
||||||
|
r.HandleFunc("GET /users/{id}", func(w http.ResponseWriter, req *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("user:" + req.PathValue("id")))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/users/42", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if body := w.Body.String(); body != "user:42" {
|
||||||
|
t.Fatalf("got body %q, want %q", body, "user:42")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterGroup(t *testing.T) {
|
||||||
|
t.Run("prefix is applied", func(t *testing.T) {
|
||||||
|
r := server.NewRouter()
|
||||||
|
api := r.Group("/api/v1")
|
||||||
|
api.HandleFunc("GET /users", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("users"))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if body := w.Body.String(); body != "users" {
|
||||||
|
t.Fatalf("got body %q, want %q", body, "users")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nested groups", func(t *testing.T) {
|
||||||
|
r := server.NewRouter()
|
||||||
|
api := r.Group("/api")
|
||||||
|
v1 := api.Group("/v1")
|
||||||
|
v1.HandleFunc("GET /items", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("items"))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/v1/items", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if body := w.Body.String(); body != "items" {
|
||||||
|
t.Fatalf("got body %q, want %q", body, "items")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("group middleware", func(t *testing.T) {
|
||||||
|
var mwCalled bool
|
||||||
|
mw := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
mwCalled = true
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
r := server.NewRouter()
|
||||||
|
g := r.Group("/admin", mw)
|
||||||
|
g.HandleFunc("GET /dashboard", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("ok"))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if !mwCalled {
|
||||||
|
t.Fatal("group middleware was not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouterMount(t *testing.T) {
|
||||||
|
t.Run("mounts sub-handler with prefix stripping", func(t *testing.T) {
|
||||||
|
sub := http.NewServeMux()
|
||||||
|
sub.HandleFunc("GET /info", func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = w.Write([]byte("info"))
|
||||||
|
})
|
||||||
|
|
||||||
|
r := server.NewRouter()
|
||||||
|
r.Mount("/sub", sub)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/sub/info", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(w.Body)
|
||||||
|
if string(body) != "info" {
|
||||||
|
t.Fatalf("got body %q, want %q", body, "info")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
173
server/server.go
Normal file
173
server/server.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os/signal"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server is a production-ready HTTP server with graceful shutdown,
|
||||||
|
// middleware support, and signal handling.
|
||||||
|
type Server struct {
|
||||||
|
httpServer *http.Server
|
||||||
|
listener net.Listener
|
||||||
|
addr atomic.Value
|
||||||
|
logger *slog.Logger
|
||||||
|
shutdownTimeout time.Duration
|
||||||
|
onShutdown []func()
|
||||||
|
listenAddr string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Server that will serve the given handler with the
|
||||||
|
// provided options. Middleware from options is applied to the handler.
|
||||||
|
func New(handler http.Handler, opts ...Option) *Server {
|
||||||
|
o := &serverOptions{
|
||||||
|
addr: ":8080",
|
||||||
|
shutdownTimeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply middleware chain to the handler.
|
||||||
|
if len(o.middlewares) > 0 {
|
||||||
|
handler = Chain(o.middlewares...)(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := &Server{
|
||||||
|
httpServer: &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadTimeout: o.readTimeout,
|
||||||
|
ReadHeaderTimeout: o.readHeaderTimeout,
|
||||||
|
WriteTimeout: o.writeTimeout,
|
||||||
|
IdleTimeout: o.idleTimeout,
|
||||||
|
},
|
||||||
|
logger: o.logger,
|
||||||
|
shutdownTimeout: o.shutdownTimeout,
|
||||||
|
onShutdown: o.onShutdown,
|
||||||
|
listenAddr: o.addr,
|
||||||
|
}
|
||||||
|
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServe starts the server and blocks until a SIGINT or SIGTERM
|
||||||
|
// signal is received. It then performs a graceful shutdown within the
|
||||||
|
// configured shutdown timeout.
|
||||||
|
//
|
||||||
|
// Returns nil on clean shutdown or an error if listen/shutdown fails.
|
||||||
|
func (s *Server) ListenAndServe() error {
|
||||||
|
ln, err := net.Listen("tcp", s.listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listener = ln
|
||||||
|
s.addr.Store(ln.Addr().String())
|
||||||
|
|
||||||
|
s.log("server started", slog.String("addr", ln.Addr().String()))
|
||||||
|
|
||||||
|
// Wait for signal in context.
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := s.httpServer.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
close(errCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
stop()
|
||||||
|
return s.shutdown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServeTLS starts the server with TLS and blocks until a signal
|
||||||
|
// is received.
|
||||||
|
func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
|
||||||
|
ln, err := net.Listen("tcp", s.listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.listener = ln
|
||||||
|
s.addr.Store(ln.Addr().String())
|
||||||
|
|
||||||
|
s.log("server started (TLS)", slog.String("addr", ln.Addr().String()))
|
||||||
|
|
||||||
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := s.httpServer.ServeTLS(ln, certFile, keyFile); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
close(errCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
stop()
|
||||||
|
return s.shutdown()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown gracefully shuts down the server. It calls any registered
|
||||||
|
// onShutdown hooks, then waits for active connections to drain within
|
||||||
|
// the shutdown timeout.
|
||||||
|
func (s *Server) Shutdown(ctx context.Context) error {
|
||||||
|
s.runOnShutdown()
|
||||||
|
return s.httpServer.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns the listener address after the server has started.
|
||||||
|
// Returns an empty string if the server has not started yet.
|
||||||
|
func (s *Server) Addr() string {
|
||||||
|
v := s.addr.Load()
|
||||||
|
if v == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) shutdown() error {
|
||||||
|
s.log("shutting down")
|
||||||
|
|
||||||
|
s.runOnShutdown()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), s.shutdownTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||||
|
s.log("shutdown error", slog.String("error", err.Error()))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log("server stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) runOnShutdown() {
|
||||||
|
for _, fn := range s.onShutdown {
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) log(msg string, attrs ...slog.Attr) {
|
||||||
|
if s.logger != nil {
|
||||||
|
s.logger.LogAttrs(context.Background(), slog.LevelInfo, msg, attrs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
131
server/server_test.go
Normal file
131
server/server_test.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.codelab.vc/pkg/httpx/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServerLifecycle(t *testing.T) {
|
||||||
|
t.Run("starts and serves requests", func(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("hello"))
|
||||||
|
})
|
||||||
|
|
||||||
|
srv := server.New(handler, server.WithAddr(":0"))
|
||||||
|
|
||||||
|
// Start in background and wait for addr.
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() { errCh <- srv.ListenAndServe() }()
|
||||||
|
|
||||||
|
waitForAddr(t, srv)
|
||||||
|
|
||||||
|
resp, err := http.Get("http://" + srv.Addr())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if string(body) != "hello" {
|
||||||
|
t.Fatalf("got body %q, want %q", body, "hello")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown.
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
t.Fatalf("shutdown failed: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("addr returns empty before start", func(t *testing.T) {
|
||||||
|
srv := server.New(http.NotFoundHandler())
|
||||||
|
if addr := srv.Addr(); addr != "" {
|
||||||
|
t.Fatalf("got addr %q before start, want empty", addr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGracefulShutdown(t *testing.T) {
|
||||||
|
t.Run("calls onShutdown hooks", func(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
srv := server.New(handler,
|
||||||
|
server.WithAddr(":0"),
|
||||||
|
server.WithOnShutdown(func() { called = true }),
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() { _ = srv.ListenAndServe() }()
|
||||||
|
waitForAddr(t, srv)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
t.Fatalf("shutdown failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("onShutdown hook was not called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerWithMiddleware(t *testing.T) {
|
||||||
|
t.Run("applies middleware from options", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
mw := func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
srv := server.New(handler,
|
||||||
|
server.WithAddr(":0"),
|
||||||
|
server.WithMiddleware(mw),
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() { _ = srv.ListenAndServe() }()
|
||||||
|
waitForAddr(t, srv)
|
||||||
|
|
||||||
|
resp, err := http.Get("http://" + srv.Addr())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET failed: %v", err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("middleware was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = srv.Shutdown(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForAddr polls until the server's Addr() is non-empty.
|
||||||
|
func waitForAddr(t *testing.T, srv *server.Server) {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(2 * time.Second)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
if srv.Addr() != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatal("server did not start in time")
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user