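package server

import (
	"net/http"
	"strconv"
	"strings"
)

// corsOptions holds the configuration assembled from CORSOption values.
type corsOptions struct {
	allowOrigins     []string
	allowMethods     []string
	allowHeaders     []string
	exposeHeaders    []string
	allowCredentials bool
	maxAge           int
}

// CORSOption configures the CORS middleware.
type CORSOption func(*corsOptions)

// AllowOrigins sets the allowed origins. Use "*" to allow any origin.
// Origins are matched exactly against the request's Origin header,
// e.g. "https://example.com". The default is no origins, which disables CORS.
func AllowOrigins(origins ...string) CORSOption {
	return func(o *corsOptions) {
		o.allowOrigins = origins
	}
}

// AllowMethods sets the HTTP methods advertised in preflight responses.
// The default is GET, POST, HEAD.
func AllowMethods(methods ...string) CORSOption {
	return func(o *corsOptions) {
		o.allowMethods = methods
	}
}

// AllowHeaders sets the request headers advertised in preflight responses.
// The default is none, so browsers reject preflights that request any
// non-safelisted header.
func AllowHeaders(headers ...string) CORSOption {
	return func(o *corsOptions) {
		o.allowHeaders = headers
	}
}

// ExposeHeaders sets the response headers that browsers may expose to
// frontend JavaScript, beyond the CORS-safelisted response headers.
func ExposeHeaders(headers ...string) CORSOption {
	return func(o *corsOptions) {
		o.exposeHeaders = headers
	}
}

// AllowCredentials sets whether browsers may expose the response to frontend
// JavaScript when the request is made with credentials (cookies, HTTP
// authentication). When enabled together with AllowOrigins("*"), the
// middleware echoes the request's Origin instead of "*", because browsers
// reject a wildcard origin on credentialed responses. Note that this grants
// credentialed access to every origin.
func AllowCredentials(allow bool) CORSOption {
	return func(o *corsOptions) {
		o.allowCredentials = allow
	}
}

// MaxAge sets how long, in seconds, browsers may cache a preflight result.
// Values <= 0 (the default is 0) omit the Access-Control-Max-Age header.
func MaxAge(seconds int) CORSOption {
	return func(o *corsOptions) {
		o.maxAge = seconds
	}
}

// CORS returns a middleware that handles Cross-Origin Resource Sharing.
// For requests from an allowed origin it sets the appropriate
// Access-Control-* response headers, and it answers preflight OPTIONS
// requests itself with 204 No Content. Requests without an Origin header,
// or from an origin that is not allowed, are passed to the next handler
// without Access-Control-* headers.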
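//
// A minimal usage sketch (the origin and port are placeholders; this
// assumes Middleware is func(http.Handler) http.Handler, as the returned
// closure implies):
//
//	mux := http.NewServeMux()
//	handler := CORS(
//		AllowOrigins("https://app.example.com"),
//		AllowHeaders("Content-Type", "Authorization"),
//		AllowCredentials(true),
//		MaxAge(600),
//	)(mux)
//	log.Fatal(http.ListenAndServe(":8080", handler))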
func CORS(opts ...CORSOption) Middleware {
	o := &corsOptions{
		allowMethods: []string{"GET", "POST", "HEAD"},
	}
	for _, opt := range opts {
		opt(o)
	}

	// Precompute the origin set and the joined header values once,
	// rather than on every request.
	allowedOrigins := make(map[string]struct{}, len(o.allowOrigins))
	allowAll := false
	for _, origin := range o.allowOrigins {
		if origin == "*" {
			allowAll = true
		}
		allowedOrigins[origin] = struct{}{}
	}
	methods := strings.Join(o.allowMethods, ", ")
	headers := strings.Join(o.allowHeaders, ", ")
	expose := strings.Join(o.exposeHeaders, ", ")

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Whether CORS headers are present depends on the Origin
			// header, so caches must key on it. Otherwise a response
			// cached for one origin (or for a non-CORS request) could be
			// served to another. Set it before any early return.
			w.Header().Add("Vary", "Origin")

			origin := r.Header.Get("Origin")
			if origin == "" {
				// Not a CORS request.
				next.ServeHTTP(w, r)
				return
			}

			allowed := allowAll
			if !allowed {
				_, allowed = allowedOrigins[origin]
			}
			if !allowed {
				// Serve the request without CORS headers; the browser
				// then blocks the cross-origin read.
				next.ServeHTTP(w, r)
				return
			}

			// Set the allowed origin. When credentials are enabled,
			// we must echo the specific origin, not "*".
			if allowAll && !o.allowCredentials {
				w.Header().Set("Access-Control-Allow-Origin", "*")
			} else {
				w.Header().Set("Access-Control-Allow-Origin", origin)
			}
			if o.allowCredentials {
				w.Header().Set("Access-Control-Allow-Credentials", "true")
			}
			if expose != "" {
				w.Header().Set("Access-Control-Expose-Headers", expose)
			}

			// Handle preflight: an OPTIONS request that carries
			// Access-Control-Request-Method. It is answered here and
			// never reaches next.
			if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
				w.Header().Set("Access-Control-Allow-Methods", methods)
				if headers != "" {
					w.Header().Set("Access-Control-Allow-Headers", headers)
				}
				if o.maxAge > 0 {
					w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge))
				}
				w.WriteHeader(http.StatusNoContent)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}