package middleware_test

import (
	"context"
	"errors"
	"net/http"
	"testing"

	"git.codelab.vc/pkg/httpx/middleware"
)

// TestBearerAuth checks that the BearerAuth middleware sends the token
// returned by tokenFunc as "Authorization: Bearer <token>", and that an error
// from tokenFunc is returned to the caller without the wrapped transport
// being invoked.
func TestBearerAuth(t *testing.T) {
	t.Run("sets authorization header", func(t *testing.T) {
		var captured http.Header
		base := mockTransport(func(req *http.Request) (*http.Response, error) {
			// Snapshot the headers as seen by the wrapped transport.
			captured = req.Header.Clone()
			return okResponse(), nil
		})

		tokenFunc := func(_ context.Context) (string, error) {
			return "my-secret-token", nil
		}

		transport := middleware.BearerAuth(tokenFunc)(base)

		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}

		_, err = transport.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		want := "Bearer my-secret-token"
		if got := captured.Get("Authorization"); got != want {
			t.Errorf("Authorization = %q, want %q", got, want)
		}
	})

	t.Run("returns error when tokenFunc fails", func(t *testing.T) {
		base := mockTransport(func(_ *http.Request) (*http.Response, error) {
			// t.Error (not t.Fatal) because FailNow must only be called from
			// the test goroutine, and we cannot see which goroutine the
			// middleware uses to call base. Return an error rather than
			// (nil, nil), which would violate the RoundTripper contract.
			t.Error("base transport should not be called")
			return nil, errors.New("unexpected call to base transport")
		})

		tokenErr := errors.New("token expired")
		tokenFunc := func(_ context.Context) (string, error) {
			return "", tokenErr
		}

		transport := middleware.BearerAuth(tokenFunc)(base)

		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}

		_, err = transport.RoundTrip(req)
		if err == nil {
			t.Fatal("expected error, got nil")
		}
		// The middleware may wrap the error; only require that tokenErr is
		// somewhere in the chain.
		if !errors.Is(err, tokenErr) {
			t.Errorf("got error %v, want %v", err, tokenErr)
		}
	})
}

// TestBasicAuth checks that the BasicAuth middleware attaches the given
// username and password so that the request seen by the wrapped transport
// reports them via (*http.Request).BasicAuth.
func TestBasicAuth(t *testing.T) {
	t.Run("sets basic auth header", func(t *testing.T) {
		var capturedReq *http.Request
		base := mockTransport(func(req *http.Request) (*http.Response, error) {
			capturedReq = req
			return okResponse(), nil
		})

		transport := middleware.BasicAuth("user", "pass")(base)

		req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
		if err != nil {
			t.Fatalf("NewRequest: %v", err)
		}

		_, err = transport.RoundTrip(req)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}

		// Without this guard, a middleware that never reaches base would
		// cause a nil-pointer panic below, aborting the whole test binary
		// instead of failing just this test.
		if capturedReq == nil {
			t.Fatal("base transport was not called")
		}

		username, password, ok := capturedReq.BasicAuth()
		if !ok {
			t.Fatal("BasicAuth() returned ok=false")
		}
		if username != "user" {
			t.Errorf("username = %q, want %q", username, "user")
		}
		if password != "pass" {
			t.Errorf("password = %q, want %q", password, "pass")
		}
	})
}