Add server CORS middleware with preflight handling
Supports AllowOrigins, AllowMethods, AllowHeaders, ExposeHeaders, AllowCredentials, and MaxAge options. Handles preflight OPTIONS requests correctly, including Vary header and credential-aware origin echoing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
128
server/middleware_cors.go
Normal file
128
server/middleware_cors.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type corsOptions struct {
|
||||
allowOrigins []string
|
||||
allowMethods []string
|
||||
allowHeaders []string
|
||||
exposeHeaders []string
|
||||
allowCredentials bool
|
||||
maxAge int
|
||||
}
|
||||
|
||||
// CORSOption configures the CORS middleware.
|
||||
type CORSOption func(*corsOptions)
|
||||
|
||||
// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
|
||||
// Default is no origins (CORS disabled).
|
||||
func AllowOrigins(origins ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowOrigins = origins }
|
||||
}
|
||||
|
||||
// AllowMethods sets the allowed HTTP methods for preflight requests.
|
||||
// Default is GET, POST, HEAD.
|
||||
func AllowMethods(methods ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowMethods = methods }
|
||||
}
|
||||
|
||||
// AllowHeaders sets the allowed request headers for preflight requests.
|
||||
func AllowHeaders(headers ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.allowHeaders = headers }
|
||||
}
|
||||
|
||||
// ExposeHeaders sets headers that browsers are allowed to access.
|
||||
func ExposeHeaders(headers ...string) CORSOption {
|
||||
return func(o *corsOptions) { o.exposeHeaders = headers }
|
||||
}
|
||||
|
||||
// AllowCredentials indicates whether the response to the request can be
|
||||
// exposed when the credentials flag is true.
|
||||
func AllowCredentials(allow bool) CORSOption {
|
||||
return func(o *corsOptions) { o.allowCredentials = allow }
|
||||
}
|
||||
|
||||
// MaxAge sets the maximum time (in seconds) a preflight result can be cached.
|
||||
func MaxAge(seconds int) CORSOption {
|
||||
return func(o *corsOptions) { o.maxAge = seconds }
|
||||
}
|
||||
|
||||
// CORS returns a middleware that handles Cross-Origin Resource Sharing.
|
||||
// It processes preflight OPTIONS requests and sets the appropriate
|
||||
// Access-Control-* response headers.
|
||||
func CORS(opts ...CORSOption) Middleware {
|
||||
o := &corsOptions{
|
||||
allowMethods: []string{"GET", "POST", "HEAD"},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
|
||||
allowAll := false
|
||||
for _, origin := range o.allowOrigins {
|
||||
if origin == "*" {
|
||||
allowAll = true
|
||||
}
|
||||
allowedOrigins[origin] = struct{}{}
|
||||
}
|
||||
|
||||
methods := strings.Join(o.allowMethods, ", ")
|
||||
headers := strings.Join(o.allowHeaders, ", ")
|
||||
expose := strings.Join(o.exposeHeaders, ", ")
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
if origin == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
allowed := allowAll
|
||||
if !allowed {
|
||||
_, allowed = allowedOrigins[origin]
|
||||
}
|
||||
if !allowed {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Set the allowed origin. When credentials are enabled,
|
||||
// we must echo the specific origin, not "*".
|
||||
if allowAll && !o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
w.Header().Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if o.allowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if expose != "" {
|
||||
w.Header().Set("Access-Control-Expose-Headers", expose)
|
||||
}
|
||||
|
||||
w.Header().Add("Vary", "Origin")
|
||||
|
||||
// Handle preflight.
|
||||
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
||||
w.Header().Set("Access-Control-Allow-Methods", methods)
|
||||
if headers != "" {
|
||||
w.Header().Set("Access-Control-Allow-Headers", headers)
|
||||
}
|
||||
if o.maxAge > 0 {
|
||||
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
143
server/middleware_cors_test.go
Normal file
143
server/middleware_cors_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package server_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"git.codelab.vc/pkg/httpx/server"
|
||||
)
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
t.Run("no Origin header passes through", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
||||
t.Fatal("expected no CORS headers without Origin")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wildcard origin", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("specific origin allowed", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://evil.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
||||
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preflight OPTIONS", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("http://example.com"),
|
||||
server.AllowMethods("GET", "POST", "PUT"),
|
||||
server.AllowHeaders("Authorization", "Content-Type"),
|
||||
server.MaxAge(3600),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodOptions, "/api/data", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
req.Header.Set("Access-Control-Request-Method", "POST")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusNoContent {
|
||||
t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent)
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" {
|
||||
t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" {
|
||||
t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" {
|
||||
t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("credentials with specific origin", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("*"),
|
||||
server.AllowCredentials(true),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
// With credentials, must echo specific origin even with wildcard config.
|
||||
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
||||
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
||||
}
|
||||
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
||||
t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expose headers", func(t *testing.T) {
|
||||
mw := server.CORS(
|
||||
server.AllowOrigins("*"),
|
||||
server.ExposeHeaders("X-Custom", "X-Request-Id"),
|
||||
)(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" {
|
||||
t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Vary header is set", func(t *testing.T) {
|
||||
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set("Origin", "http://example.com")
|
||||
mw.ServeHTTP(w, req)
|
||||
|
||||
if got := w.Header().Get("Vary"); got != "Origin" {
|
||||
t.Fatalf("Vary = %q, want %q", got, "Origin")
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user