package server_test

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"git.codelab.vc/pkg/httpx/server"
)

// TestCORS exercises the server.CORS middleware. It covers pass-through when
// no Origin header is sent, origin matching (wildcard, exact match, rejected
// origin), preflight OPTIONS handling, credentials, exposed headers and the
// Vary header.
func TestCORS(t *testing.T) {
	// next is the downstream handler wrapped by every middleware under test.
	next := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
	})

	// newReq builds a request for target. The Origin header is set only when
	// origin is non-empty, so the "no Origin" case sends no header at all.
	newReq := func(method, target, origin string) *http.Request {
		r := httptest.NewRequest(method, target, nil)
		if origin != "" {
			r.Header.Set("Origin", origin)
		}
		return r
	}

	// serve runs r through h and returns the recorded response.
	serve := func(h http.Handler, r *http.Request) *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		h.ServeHTTP(rec, r)
		return rec
	}

	// wantHeader fails the current subtest when the response header key does
	// not equal want.
	wantHeader := func(t *testing.T, rec *httptest.ResponseRecorder, key, want string) {
		t.Helper()
		if got := rec.Header().Get(key); got != want {
			t.Fatalf("%s = %q, want %q", key, got, want)
		}
	}

	t.Run("no Origin header passes through", func(t *testing.T) {
		h := server.CORS(server.AllowOrigins("*"))(next)
		rec := serve(h, newReq(http.MethodGet, "/", ""))
		if rec.Header().Get("Access-Control-Allow-Origin") != "" {
			t.Fatal("expected no CORS headers without Origin")
		}
	})

	t.Run("wildcard origin", func(t *testing.T) {
		h := server.CORS(server.AllowOrigins("*"))(next)
		rec := serve(h, newReq(http.MethodGet, "/", "http://example.com"))
		wantHeader(t, rec, "Access-Control-Allow-Origin", "*")
	})

	t.Run("specific origin allowed", func(t *testing.T) {
		h := server.CORS(server.AllowOrigins("http://example.com"))(next)
		rec := serve(h, newReq(http.MethodGet, "/", "http://example.com"))
		wantHeader(t, rec, "Access-Control-Allow-Origin", "http://example.com")
	})

	t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
		h := server.CORS(server.AllowOrigins("http://example.com"))(next)
		rec := serve(h, newReq(http.MethodGet, "/", "http://evil.com"))
		if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "" {
			t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
		}
	})

	t.Run("preflight OPTIONS", func(t *testing.T) {
		h := server.CORS(
			server.AllowOrigins("http://example.com"),
			server.AllowMethods("GET", "POST", "PUT"),
			server.AllowHeaders("Authorization", "Content-Type"),
			server.MaxAge(3600),
		)(next)

		req := newReq(http.MethodOptions, "/api/data", "http://example.com")
		req.Header.Set("Access-Control-Request-Method", "POST")
		rec := serve(h, req)

		if rec.Code != http.StatusNoContent {
			t.Fatalf("got status %d, want %d", rec.Code, http.StatusNoContent)
		}
		wantHeader(t, rec, "Access-Control-Allow-Methods", "GET, POST, PUT")
		wantHeader(t, rec, "Access-Control-Allow-Headers", "Authorization, Content-Type")
		wantHeader(t, rec, "Access-Control-Max-Age", "3600")
	})

	t.Run("credentials with specific origin", func(t *testing.T) {
		h := server.CORS(
			server.AllowOrigins("*"),
			server.AllowCredentials(true),
		)(next)
		rec := serve(h, newReq(http.MethodGet, "/", "http://example.com"))

		// With credentials, must echo specific origin even with wildcard config.
		wantHeader(t, rec, "Access-Control-Allow-Origin", "http://example.com")
		wantHeader(t, rec, "Access-Control-Allow-Credentials", "true")
	})

	t.Run("expose headers", func(t *testing.T) {
		h := server.CORS(
			server.AllowOrigins("*"),
			server.ExposeHeaders("X-Custom", "X-Request-Id"),
		)(next)
		rec := serve(h, newReq(http.MethodGet, "/", "http://example.com"))
		wantHeader(t, rec, "Access-Control-Expose-Headers", "X-Custom, X-Request-Id")
	})

	t.Run("Vary header is set", func(t *testing.T) {
		h := server.CORS(server.AllowOrigins("*"))(next)
		rec := serve(h, newReq(http.MethodGet, "/", "http://example.com"))
		wantHeader(t, rec, "Vary", "Origin")
	})
}