@@ -169,6 +169,7 @@ type FakeIDP struct {
169169 // clientID to be used by coderd
170170 clientID string
171171 clientSecret string
172+ pkce bool // TODO: Implement for refresh token flow as well
172173 // externalProviderID is optional to match the provider in coderd for
173174 // redirectURLs.
174175 externalProviderID string
@@ -181,6 +182,8 @@ type FakeIDP struct {
181182 // These maps are used to control the state of the IDP.
182183 // That is the various access tokens, refresh tokens, states, etc.
183184 codeToStateMap * syncmap.Map [string , string ]
185+ // Code -> PKCE Challenge
186+ codeToChallengeMap * syncmap.Map [string , string ]
184187 // Token -> Email
185188 accessTokens * syncmap.Map [string , token ]
186189 // Refresh Token -> Email
@@ -239,6 +242,12 @@ func (s statusHookError) Error() string {
239242
240243type FakeIDPOpt func (idp * FakeIDP )
241244
245+ func WithPKCE () func (* FakeIDP ) {
246+ return func (f * FakeIDP ) {
247+ f .pkce = true
248+ }
249+ }
250+
242251func WithAuthorizedRedirectURL (hook func (redirectURL string ) error ) func (* FakeIDP ) {
243252 return func (f * FakeIDP ) {
244253 f .hookValidRedirectURL = hook
@@ -450,6 +459,7 @@ func NewFakeIDP(t testing.TB, opts ...FakeIDPOpt) *FakeIDP {
450459 clientSecret : uuid .NewString (),
451460 logger : slog .Make (),
452461 codeToStateMap : syncmap .New [string , string ](),
462+ codeToChallengeMap : syncmap .New [string , string ](),
453463 accessTokens : syncmap .New [string , token ](),
454464 refreshTokens : syncmap .New [string , string ](),
455465 refreshTokensUsed : syncmap .New [string , bool ](),
@@ -557,8 +567,16 @@ func (f *FakeIDP) realServer(t testing.TB) *httptest.Server {
557567func (f * FakeIDP ) GenerateAuthenticatedToken (claims jwt.MapClaims ) (* oauth2.Token , error ) {
558568 state := uuid .NewString ()
559569 f .stateToIDTokenClaims .Store (state , claims )
560- code := f .newCode (state )
561- return f .locked .Config ().Exchange (oidc .ClientContext (context .Background (), f .HTTPClient (nil )), code )
570+
571+ exchangeOpts := []oauth2.AuthCodeOption {}
572+ verifier := ""
573+ if f .pkce {
574+ verifier = oauth2 .GenerateVerifier ()
575+ exchangeOpts = append (exchangeOpts , oauth2 .VerifierOption (verifier ))
576+ }
577+ code := f .newCode (state , oauth2 .S256ChallengeFromVerifier (verifier ))
578+
579+ return f .locked .Config ().Exchange (oidc .ClientContext (context .Background (), f .HTTPClient (nil )), code , exchangeOpts ... )
562580}
563581
564582// Login does the full OIDC flow starting at the "LoginButton".
@@ -790,9 +808,10 @@ type ProviderJSON struct {
790808
791809// newCode enforces the code exchanged is actually a valid code
792810// created by the IDP.
793- func (f * FakeIDP ) newCode (state string ) string {
811+ func (f * FakeIDP ) newCode (state string , challenge string ) string {
794812 code := uuid .NewString ()
795813 f .codeToStateMap .Store (code , state )
814+ f .codeToChallengeMap .Store (code , challenge )
796815 return code
797816}
798817
@@ -918,6 +937,22 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
918937 mux .Handle (authorizePath , http .HandlerFunc (func (rw http.ResponseWriter , r * http.Request ) {
919938 f .logger .Info (r .Context (), "http call authorize" , slogRequestFields (r )... )
920939
940+ challenge := ""
941+ if f .pkce {
942+ method := r .URL .Query ().Get ("code_challenge_method" )
943+ challenge = r .URL .Query ().Get ("code_challenge" )
944+
945+ if method == "" {
946+ httpError (rw , http .StatusBadRequest , xerrors .New ("missing code_challenge_method" ))
947+ return
948+ }
949+
950+ if challenge == "" {
951+ httpError (rw , http .StatusBadRequest , xerrors .New ("missing code_challenge" ))
952+ return
953+ }
954+ }
955+
921956 clientID := r .URL .Query ().Get ("client_id" )
922957 if ! assert .Equal (t , f .clientID , clientID , "unexpected client_id" ) {
923958 httpError (rw , http .StatusBadRequest , xerrors .New ("invalid client_id" ))
@@ -959,7 +994,7 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
959994
960995 q := ru .Query ()
961996 q .Set ("state" , state )
962- q .Set ("code" , f .newCode (state ))
997+ q .Set ("code" , f .newCode (state , challenge ))
963998 ru .RawQuery = q .Encode ()
964999
9651000 http .Redirect (rw , r , ru .String (), http .StatusTemporaryRedirect )
@@ -1009,6 +1044,21 @@ func (f *FakeIDP) httpHandler(t testing.TB) http.Handler {
10091044 http .Error (rw , "invalid code" , http .StatusBadRequest )
10101045 return
10111046 }
1047+
1048+ if f .pkce {
1049+ challenge , ok := f .codeToChallengeMap .Load (code )
1050+ if ! ok {
1051+ httpError (rw , http .StatusBadRequest , xerrors .New ("code challenge not found for code" ))
1052+ return
1053+ }
1054+ codeVerifier := values .Get ("code_verifier" )
1055+ expecter := oauth2 .S256ChallengeFromVerifier (codeVerifier )
1056+ if challenge != expecter {
1057+ httpError (rw , http .StatusBadRequest , xerrors .New ("invalid code verifier" ))
1058+ return
1059+ }
1060+ }
1061+
10121062 // Always invalidate the code after it is used.
10131063 f .codeToStateMap .Delete (code )
10141064
0 commit comments