Skip to content

Commit 0be0bb2

Browse files
committed
make some tests use pkce
1 parent ead92c8 commit 0be0bb2

File tree

3 files changed

+81
-11
lines changed

3 files changed

+81
-11
lines changed

coderd/coderdtest/oidctest/idp.go

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ type FakeIDP struct {
169169
// clientID to be used by coderd
170170
clientID string
171171
clientSecret string
172+
pkce bool // TODO: Implement for refresh token flow as well
172173
// externalProviderID is optional to match the provider in coderd for
173174
// redirectURLs.
174175
externalProviderID string
@@ -181,6 +182,8 @@ type FakeIDP struct {
181182
// These maps are used to control the state of the IDP.
182183
// That is the various access tokens, refresh tokens, states, etc.
183184
codeToStateMap *syncmap.Map[string, string]
185+
// Code -> PKCE Challenge
186+
codeToChallengeMap *syncmap.Map[string, string]
184187
// Token -> Email
185188
accessTokens *syncmap.Map[string, token]
186189
// Refresh Token -> Email
@@ -239,6 +242,12 @@ func (s statusHookError) Error() string {
239242

240243
type FakeIDPOpt func(idp *FakeIDP)
241244

245+
func WithPKCE() func(*FakeIDP) {
246+
return func(f *FakeIDP) {
247+
f.pkce = true
248+
}
249+
}
250+
242251
func WithAuthorizedRedirectURL(hook func(redirectURL string) error) func(*FakeIDP) {
243252
return func(f *FakeIDP) {
244253
f.hookValidRedirectURL = hook
@@ -450,6 +459,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
450459
clientSecret: uuid.NewString(),
451460
logger: slog.Make(),
452461
codeToStateMap: syncmap.New[string, string](),
462+
codeToChallengeMap: syncmap.New[string, string](),
453463
accessTokens: syncmap.New[string, token](),
454464
refreshTokens: syncmap.New[string, string](),
455465
refreshTokensUsed: syncmap.New[string, bool](),
@@ -557,8 +567,16 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
557567
func (f *FakeIDP) GenerateAuthenticatedToken(claims jwt.MapClaims) (*oauth2.Token, error) {
558568
state := uuid.NewString()
559569
f.stateToIDTokenClaims.Store(state, claims)
560-
code := f.newCode(state)
561-
return f.locked.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code)
570+
571+
exchangeOpts := []oauth2.AuthCodeOption{}
572+
verifier := ""
573+
if f.pkce {
574+
verifier = oauth2.GenerateVerifier()
575+
exchangeOpts = append(exchangeOpts, oauth2.VerifierOption(verifier))
576+
}
577+
code := f.newCode(state, oauth2.S256ChallengeFromVerifier(verifier))
578+
579+
return f.locked.Config().Exchange(oidc.ClientContext(context.Background(), f.HTTPClient(nil)), code, exchangeOpts...)
562580
}
563581

564582
// Login does the full OIDC flow starting at the "LoginButton".
@@ -790,9 +808,10 @@ type ProviderJSON struct {
790808

791809
// newCode enforces the code exchanged is actually a valid code
792810
// created by the IDP.
793-
func (f *FakeIDP) newCode(state string) string {
811+
func (f *FakeIDP) newCode(state string, challenge string) string {
794812
code := uuid.NewString()
795813
f.codeToStateMap.Store(code, state)
814+
f.codeToChallengeMap.Store(code, challenge)
796815
return code
797816
}
798817

@@ -918,6 +937,22 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
918937
mux.Handle(authorizePath, http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
919938
f.logger.Info(r.Context(), "http call authorize", slogRequestFields(r)...)
920939

940+
challenge := ""
941+
if f.pkce {
942+
method := r.URL.Query().Get("code_challenge_method")
943+
challenge = r.URL.Query().Get("code_challenge")
944+
945+
if method == "" {
946+
httpError(rw, http.StatusBadRequest, xerrors.New("missing code_challenge_method"))
947+
return
948+
}
949+
950+
if challenge == "" {
951+
httpError(rw, http.StatusBadRequest, xerrors.New("missing code_challenge"))
952+
return
953+
}
954+
}
955+
921956
clientID := r.URL.Query().Get("client_id")
922957
if !assert.Equal(t, f.clientID, clientID, "unexpected client_id") {
923958
httpError(rw, http.StatusBadRequest, xerrors.New("invalid client_id"))
@@ -959,7 +994,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
959994

960995
q := ru.Query()
961996
q.Set("state", state)
962-
q.Set("code", f.newCode(state))
997+
q.Set("code", f.newCode(state, challenge))
963998
ru.RawQuery = q.Encode()
964999

9651000
http.Redirect(rw, r, ru.String(), http.StatusTemporaryRedirect)
@@ -1009,6 +1044,21 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10091044
http.Error(rw, "invalid code", http.StatusBadRequest)
10101045
return
10111046
}
1047+
1048+
if f.pkce {
1049+
challenge, ok := f.codeToChallengeMap.Load(code)
1050+
if !ok {
1051+
httpError(rw, http.StatusBadRequest, xerrors.New("code challenge not found for code"))
1052+
return
1053+
}
1054+
codeVerifier := values.Get("code_verifier")
1055+
expecter := oauth2.S256ChallengeFromVerifier(codeVerifier)
1056+
if challenge != expecter {
1057+
httpError(rw, http.StatusBadRequest, xerrors.New("invalid code verifier"))
1058+
return
1059+
}
1060+
}
1061+
10121062
// Always invalidate the code after it is used.
10131063
f.codeToStateMap.Delete(code)
10141064

coderd/externalauth/externalauth_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -792,20 +792,21 @@ func setupOauth2Test(t *testing.T, settings testConfig) (*oidctest.FakeIDP, *ext
792792

793793
const providerID = "test-idp"
794794
fake := oidctest.NewFakeIDP(t,
795-
append([]oidctest.FakeIDPOpt{}, settings.FakeIDPOpts...)...,
795+
append([]oidctest.FakeIDPOpt{oidctest.WithPKCE()}, settings.FakeIDPOpts...)...,
796796
)
797797

798798
f := promoauth.NewFactory(prometheus.NewRegistry())
799799
cid, cs := fake.AppCredentials()
800800
config := &externalauth.Config{
801801
InstrumentedOAuth2Config: f.New("test-oauth2",
802802
fake.OIDCConfig(t, nil, settings.CoderOIDCConfigOpts...)),
803-
ID: providerID,
804-
ClientID: cid,
805-
ClientSecret: cs,
806-
ValidateURL: fake.WellknownConfig().UserInfoURL,
807-
RevokeURL: fake.WellknownConfig().RevokeURL,
808-
RevokeTimeout: 1 * time.Second,
803+
ID: providerID,
804+
ClientID: cid,
805+
ClientSecret: cs,
806+
ValidateURL: fake.WellknownConfig().UserInfoURL,
807+
RevokeURL: fake.WellknownConfig().RevokeURL,
808+
RevokeTimeout: 1 * time.Second,
809+
CodeChallengeMethodsSupported: []promoauth.Oauth2PKCEChallengeMethod{promoauth.PKCEChallengeMethodSha256},
809810
}
810811
settings.ExternalAuthOpt(config)
811812

coderd/externalauth_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/coder/coder/v2/coderd/database/dbtime"
2727
"github.com/coder/coder/v2/coderd/externalauth"
2828
"github.com/coder/coder/v2/coderd/httpapi"
29+
"github.com/coder/coder/v2/coderd/promoauth"
2930
"github.com/coder/coder/v2/codersdk"
3031
"github.com/coder/coder/v2/codersdk/agentsdk"
3132
"github.com/coder/coder/v2/provisioner/echo"
@@ -34,6 +35,24 @@ import (
3435

3536
func TestExternalAuthByID(t *testing.T) {
3637
t.Parallel()
38+
t.Run("PKCEMissing", func(t *testing.T) {
39+
t.Parallel()
40+
const providerID = "fake-github"
41+
fake := oidctest.NewFakeIDP(t, oidctest.WithServing())
42+
43+
client := coderdtest.New(t, &coderdtest.Options{
44+
ExternalAuthConfigs: []*externalauth.Config{
45+
fake.ExternalAuthConfig(t, providerID, nil, func(cfg *externalauth.Config) {
46+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
47+
cfg.CodeChallengeMethodsSupported = []promoauth.Oauth2PKCEChallengeMethod{}
48+
}),
49+
},
50+
})
51+
coderdtest.CreateFirstUser(t, client)
52+
auth, err := client.ExternalAuthByID(context.Background(), providerID)
53+
require.NoError(t, err)
54+
require.False(t, auth.Authenticated)
55+
})
3756
t.Run("Unauthenticated", func(t *testing.T) {
3857
t.Parallel()
3958
const providerID = "fake-github"

0 commit comments

Comments
 (0)