Supports AllowOrigins, AllowMethods, AllowHeaders, ExposeHeaders, AllowCredentials, and MaxAge options. Handles preflight OPTIONS requests correctly, including Vary header and credential-aware origin echoing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
144 lines
4.6 KiB
Go
144 lines
4.6 KiB
Go
package server_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"git.codelab.vc/pkg/httpx/server"
|
|
)
|
|
|
|
func TestCORS(t *testing.T) {
|
|
okHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
t.Run("no Origin header passes through", func(t *testing.T) {
|
|
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Header().Get("Access-Control-Allow-Origin") != "" {
|
|
t.Fatal("expected no CORS headers without Origin")
|
|
}
|
|
})
|
|
|
|
t.Run("wildcard origin", func(t *testing.T) {
|
|
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "*" {
|
|
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "*")
|
|
}
|
|
})
|
|
|
|
t.Run("specific origin allowed", func(t *testing.T) {
|
|
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
|
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
|
}
|
|
})
|
|
|
|
t.Run("disallowed origin gets no CORS headers", func(t *testing.T) {
|
|
mw := server.CORS(server.AllowOrigins("http://example.com"))(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "http://evil.com")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
|
t.Fatalf("expected no Access-Control-Allow-Origin for disallowed origin, got %q", got)
|
|
}
|
|
})
|
|
|
|
t.Run("preflight OPTIONS", func(t *testing.T) {
|
|
mw := server.CORS(
|
|
server.AllowOrigins("http://example.com"),
|
|
server.AllowMethods("GET", "POST", "PUT"),
|
|
server.AllowHeaders("Authorization", "Content-Type"),
|
|
server.MaxAge(3600),
|
|
)(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodOptions, "/api/data", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
req.Header.Set("Access-Control-Request-Method", "POST")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if w.Code != http.StatusNoContent {
|
|
t.Fatalf("got status %d, want %d", w.Code, http.StatusNoContent)
|
|
}
|
|
if got := w.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, PUT" {
|
|
t.Fatalf("Access-Control-Allow-Methods = %q, want %q", got, "GET, POST, PUT")
|
|
}
|
|
if got := w.Header().Get("Access-Control-Allow-Headers"); got != "Authorization, Content-Type" {
|
|
t.Fatalf("Access-Control-Allow-Headers = %q, want %q", got, "Authorization, Content-Type")
|
|
}
|
|
if got := w.Header().Get("Access-Control-Max-Age"); got != "3600" {
|
|
t.Fatalf("Access-Control-Max-Age = %q, want %q", got, "3600")
|
|
}
|
|
})
|
|
|
|
t.Run("credentials with specific origin", func(t *testing.T) {
|
|
mw := server.CORS(
|
|
server.AllowOrigins("*"),
|
|
server.AllowCredentials(true),
|
|
)(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
// With credentials, must echo specific origin even with wildcard config.
|
|
if got := w.Header().Get("Access-Control-Allow-Origin"); got != "http://example.com" {
|
|
t.Fatalf("Access-Control-Allow-Origin = %q, want %q", got, "http://example.com")
|
|
}
|
|
if got := w.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
|
t.Fatalf("Access-Control-Allow-Credentials = %q, want %q", got, "true")
|
|
}
|
|
})
|
|
|
|
t.Run("expose headers", func(t *testing.T) {
|
|
mw := server.CORS(
|
|
server.AllowOrigins("*"),
|
|
server.ExposeHeaders("X-Custom", "X-Request-Id"),
|
|
)(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if got := w.Header().Get("Access-Control-Expose-Headers"); got != "X-Custom, X-Request-Id" {
|
|
t.Fatalf("Access-Control-Expose-Headers = %q, want %q", got, "X-Custom, X-Request-Id")
|
|
}
|
|
})
|
|
|
|
t.Run("Vary header is set", func(t *testing.T) {
|
|
mw := server.CORS(server.AllowOrigins("*"))(okHandler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Origin", "http://example.com")
|
|
mw.ServeHTTP(w, req)
|
|
|
|
if got := w.Header().Get("Vary"); got != "Origin" {
|
|
t.Fatalf("Vary = %q, want %q", got, "Origin")
|
|
}
|
|
})
|
|
}
|