diff --git a/server/middleware_cors.go b/server/middleware_cors.go new file mode 100644 index 0000000..e9e2db9 --- /dev/null +++ b/server/middleware_cors.go @@ -0,0 +1,128 @@ +package server + +import ( + "net/http" + "strconv" + "strings" +) + +type corsOptions struct { + allowOrigins []string + allowMethods []string + allowHeaders []string + exposeHeaders []string + allowCredentials bool + maxAge int +} + +// CORSOption configures the CORS middleware. +type CORSOption func(*corsOptions) + +// AllowOrigins sets the allowed origins. Use "*" to allow any origin. +// Default is no origins (CORS disabled). +func AllowOrigins(origins ...string) CORSOption { + return func(o *corsOptions) { o.allowOrigins = origins } +} + +// AllowMethods sets the allowed HTTP methods for preflight requests. +// Default is GET, POST, HEAD. +func AllowMethods(methods ...string) CORSOption { + return func(o *corsOptions) { o.allowMethods = methods } +} + +// AllowHeaders sets the allowed request headers for preflight requests. +func AllowHeaders(headers ...string) CORSOption { + return func(o *corsOptions) { o.allowHeaders = headers } +} + +// ExposeHeaders sets headers that browsers are allowed to access. +func ExposeHeaders(headers ...string) CORSOption { + return func(o *corsOptions) { o.exposeHeaders = headers } +} + +// AllowCredentials indicates whether the response to the request can be +// exposed when the credentials flag is true. +func AllowCredentials(allow bool) CORSOption { + return func(o *corsOptions) { o.allowCredentials = allow } +} + +// MaxAge sets the maximum time (in seconds) a preflight result can be cached. +func MaxAge(seconds int) CORSOption { + return func(o *corsOptions) { o.maxAge = seconds } +} + +// CORS returns a middleware that handles Cross-Origin Resource Sharing. +// It processes preflight OPTIONS requests and sets the appropriate +// Access-Control-* response headers. +func CORS(opts ...CORSOption) Middleware { + o := &corsOptions{ + allowMethods: []string{"GET", "POST", "HEAD"}, + } + for _, opt := range opts { + opt(o) + } + + allowedOrigins := make(map[string]struct{}, len(o.allowOrigins)) + allowAll := false + for _, origin := range o.allowOrigins { + if origin == "*" { + allowAll = true + } + allowedOrigins[origin] = struct{}{} + } + + methods := strings.Join(o.allowMethods, ", ") + headers := strings.Join(o.allowHeaders, ", ") + expose := strings.Join(o.exposeHeaders, ", ") + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if origin == "" { + next.ServeHTTP(w, r) + return + } + + allowed := allowAll + if !allowed { + _, allowed = allowedOrigins[origin] + } + if !allowed { + next.ServeHTTP(w, r) + return + } + + // Set the allowed origin. When credentials are enabled, + // we must echo the specific origin, not "*". + if allowAll && !o.allowCredentials { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else { + w.Header().Set("Access-Control-Allow-Origin", origin) + } + + if o.allowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + if expose != "" { + w.Header().Set("Access-Control-Expose-Headers", expose) + } + + w.Header().Add("Vary", "Origin") + + // Handle preflight. + if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" { + w.Header().Set("Access-Control-Allow-Methods", methods) + if headers != "" { + w.Header().Set("Access-Control-Allow-Headers", headers) + } + if o.maxAge > 0 { + w.Header().Set("Access-Control-Max-Age", strconv.Itoa(o.maxAge)) + } + w.WriteHeader(http.StatusNoContent) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/server/middleware_cors_test.go b/server/middleware_cors_test.go new file mode 100644 index 0000000..4ef21da --- /dev/null +++ b/server/middleware_cors_test.go @@ -0,0 +1,143 @@ +package server_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "git.codelab.vc/pkg/httpx/server" +) + +func TestCORS(t *testing.T) { + okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + t.Run("no Origin header passes through", func(t *testing.T) { + mw := server.CORS(server.AllowOrigins("*"))(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + mw.ServeHTTP(w, req) + + if w.Header().Get("Access-Control-Allow-Origin") != "" { + t.Fatal("expected no CORS headers without Origin") + } + }) + + t.Run("wildcard origin", func(t *testing.T) { + mw := server.CORS(server.AllowOrigins("*"))(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + mw.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" { + t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*") + } + }) + + t.Run("specific origin allowed", func(t *testing.T) { + mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + mw.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" { + t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com") + } + }) + + t.Run("disallowed origin gets no CORS headers", func(t *testing.T) { + mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://evil.com") + mw.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" { + t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got) + } + }) + + t.Run("preflight OPTIONS", func(t *testing.T) { + mw := server.CORS( + server.AllowOrigins("http://example.com"), + server.AllowMethods("GET", "POST", "PUT"), + server.AllowHeaders("Authorization", "Content-Type"), + server.MaxAge(3600), + )(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodOptions, "/api/data", nil) + req.Header.Set("Origin", "http://example.com") + req.Header.Set("Access-Control-Request-Method", "POST") + mw.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent) + } + if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" { + t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT") + } + if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" { + t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type") + } + if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" { + t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600") + } + }) + + t.Run("credentials with specific origin", func(t *testing.T) { + mw := server.CORS( + server.AllowOrigins("*"), + server.AllowCredentials(true), + )(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + mw.ServeHTTP(w, req) + + // With credentials, must echo specific origin even with wildcard config. + if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" { + t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com") + } + if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" { + t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true") + } + }) + + t.Run("expose headers", func(t *testing.T) { + mw := server.CORS( + server.AllowOrigins("*"), + server.ExposeHeaders("X-Custom", "X-Request-Id"), + )(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + mw.ServeHTTP(w, req) + + if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" { + t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id") + } + }) + + t.Run("Vary header is set", func(t *testing.T) { + mw := server.CORS(server.AllowOrigins("*"))(okHandler) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Origin", "http://example.com") + mw.ServeHTTP(w, req) + + if got := w.Header().Get("Vary"); got != "Origin" { + t.Fatalf("Vary = %q, want %q", got, "Origin") + } + }) +}