The CORS middleware supports the AllowOrigins, AllowMethods, AllowHeaders, ExposeHeaders, AllowCredentials, and MaxAge options. It handles preflight OPTIONS requests, sets the Vary: Origin header, and echoes the request's origin instead of "*" when credentials are allowed. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
129 lines
3.5 KiB
Go
129 lines
3.5 KiB
Go
package server
|
|
|
|
import (
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
type (
	// corsOptions collects the settings applied by CORSOption values before
	// the CORS middleware is built.
	corsOptions struct {
		// allowOrigins lists accepted values of the request's Origin header;
		// the entry "*" accepts any origin.
		allowOrigins []string
		// allowMethods is advertised in Access-Control-Allow-Methods on
		// preflight responses.
		allowMethods []string
		// allowHeaders is advertised in Access-Control-Allow-Headers on
		// preflight responses.
		allowHeaders []string
		// exposeHeaders is sent in Access-Control-Expose-Headers on responses
		// to accepted origins.
		exposeHeaders []string
		// allowCredentials enables Access-Control-Allow-Credentials: true and
		// makes the middleware echo the request origin instead of "*".
		allowCredentials bool
		// maxAge is the preflight cache lifetime in seconds; the
		// Access-Control-Max-Age header is only sent when it is positive.
		maxAge int
	}

	// CORSOption is a functional option that adjusts the CORS middleware's
	// settings; pass any number of them to CORS.
	CORSOption func(*corsOptions)
)
|
|
|
|
// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
|
|
// Default is no origins (CORS disabled).
|
|
func AllowOrigins(origins ...string) CORSOption {
|
|
return func(o *corsOptions) { o.allowOrigins = origins }
|
|
}
|
|
|
|
// AllowMethods sets the allowed HTTP methods for preflight requests.
|
|
// Default is GET, POST, HEAD.
|
|
func AllowMethods(methods ...string) CORSOption {
|
|
return func(o *corsOptions) { o.allowMethods = methods }
|
|
}
|
|
|
|
// AllowHeaders sets the allowed request headers for preflight requests.
|
|
func AllowHeaders(headers ...string) CORSOption {
|
|
return func(o *corsOptions) { o.allowHeaders = headers }
|
|
}
|
|
|
|
// ExposeHeaders sets headers that browsers are allowed to access.
|
|
func ExposeHeaders(headers ...string) CORSOption {
|
|
return func(o *corsOptions) { o.exposeHeaders = headers }
|
|
}
|
|
|
|
// AllowCredentials indicates whether the response to the request can be
|
|
// exposed when the credentials flag is true.
|
|
func AllowCredentials(allow bool) CORSOption {
|
|
return func(o *corsOptions) { o.allowCredentials = allow }
|
|
}
|
|
|
|
// MaxAge sets the maximum time (in seconds) a preflight result can be cached.
|
|
func MaxAge(seconds int) CORSOption {
|
|
return func(o *corsOptions) { o.maxAge = seconds }
|
|
}
|
|
|
|
// CORS returns a middleware that handles Cross-Origin Resource Sharing.
|
|
// It processes preflight OPTIONS requests and sets the appropriate
|
|
// Access-Control-* response headers.
|
|
func CORS(opts ...CORSOption) Middleware {
|
|
o := &corsOptions{
|
|
allowMethods: []string{"GET", "POST", "HEAD"},
|
|
}
|
|
for _, opt := range opts {
|
|
opt(o)
|
|
}
|
|
|
|
allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
|
|
allowAll := false
|
|
for _, origin := range o.allowOrigins {
|
|
if origin == "*" {
|
|
allowAll = true
|
|
}
|
|
allowedOrigins[origin] = struct{}{}
|
|
}
|
|
|
|
methods := strings.Join(o.allowMethods, ", ")
|
|
headers := strings.Join(o.allowHeaders, ", ")
|
|
expose := strings.Join(o.exposeHeaders, ", ")
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
origin := r.Header.Get("Origin")
|
|
if origin == "" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
allowed := allowAll
|
|
if !allowed {
|
|
_, allowed = allowedOrigins[origin]
|
|
}
|
|
if !allowed {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
// Set the allowed origin. When credentials are enabled,
|
|
// we must echo the specific origin, not "*".
|
|
if allowAll && !o.allowCredentials {
|
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
} else {
|
|
w.Header().Set("Access-Control-Allow-Origin", origin)
|
|
}
|
|
|
|
if o.allowCredentials {
|
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
}
|
|
if expose != "" {
|
|
w.Header().Set("Access-Control-Expose-Headers", expose)
|
|
}
|
|
|
|
w.Header().Add("Vary", "Origin")
|
|
|
|
// Handle preflight.
|
|
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
|
w.Header().Set("Access-Control-Allow-Methods", methods)
|
|
if headers != "" {
|
|
w.Header().Set("Access-Control-Allow-Headers", headers)
|
|
}
|
|
if o.maxAge > 0 {
|
|
w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
|
|
}
|
|
w.WriteHeader(http.StatusNoContent)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|