@@ -125,6 +125,7 @@ type Server struct {
125
125
listeners map [net.Listener ]struct {}
126
126
conns map [net.Conn ]struct {}
127
127
sessions map [ssh.Session ]struct {}
128
+ processes map [* os.Process ]struct {}
128
129
closing chan struct {}
129
130
// Wait for goroutines to exit, waited without
130
131
// a lock on mu but protected by closing.
@@ -183,6 +184,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
183
184
fs : fs ,
184
185
conns : make (map [net.Conn ]struct {}),
185
186
sessions : make (map [ssh.Session ]struct {}),
187
+ processes : make (map [* os.Process ]struct {}),
186
188
logger : logger ,
187
189
188
190
config : config ,
@@ -587,7 +589,10 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
587
589
// otherwise context cancellation will not propagate properly
588
590
// and SSH server close may be delayed.
589
591
cmd .SysProcAttr = cmdSysProcAttr ()
590
- cmd .Cancel = cmdCancel (session .Context (), logger , cmd )
592
+
593
+ // to match OpenSSH, we don't actually tear a non-TTY command down, even if the session ends.
594
+ // c.f. https://github.com/coder/coder/issues/18519#issuecomment-3019118271
595
+ cmd .Cancel = nil
591
596
592
597
cmd .Stdout = session
593
598
cmd .Stderr = session .Stderr ()
@@ -610,6 +615,16 @@ func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, mag
610
615
s .metrics .sessionErrors .WithLabelValues (magicTypeLabel , "no" , "start_command" ).Add (1 )
611
616
return xerrors .Errorf ("start: %w" , err )
612
617
}
618
+
619
+ // Since we don't cancel the process when the session stops, we still need to tear it down if we are closing. So
620
+ // track it here.
621
+ if ! s .trackProcess (cmd .Process , true ) {
622
+ // must be closing
623
+ err = cmdCancel (logger , cmd .Process )
624
+ return xerrors .Errorf ("failed to track process: %w" , err )
625
+ }
626
+ defer s .trackProcess (cmd .Process , false )
627
+
613
628
sigs := make (chan ssh.Signal , 1 )
614
629
session .Signals (sigs )
615
630
defer func () {
@@ -1070,6 +1085,27 @@ func (s *Server) trackSession(ss ssh.Session, add bool) (ok bool) {
1070
1085
return true
1071
1086
}
1072
1087
1088
+ // trackCommand registers the process with the server. If the server is
1089
+ // closing, the process is not registered and should be closed.
1090
+ //
1091
+ //nolint:revive
1092
+ func (s * Server ) trackProcess (p * os.Process , add bool ) (ok bool ) {
1093
+ s .mu .Lock ()
1094
+ defer s .mu .Unlock ()
1095
+ if add {
1096
+ if s .closing != nil {
1097
+ // Server closed.
1098
+ return false
1099
+ }
1100
+ s .wg .Add (1 )
1101
+ s .processes [p ] = struct {}{}
1102
+ return true
1103
+ }
1104
+ s .wg .Done ()
1105
+ delete (s .processes , p )
1106
+ return true
1107
+ }
1108
+
1073
1109
// Close the server and all active connections. Server can be re-used
1074
1110
// after Close is done.
1075
1111
func (s * Server ) Close () error {
@@ -1109,6 +1145,10 @@ func (s *Server) Close() error {
1109
1145
_ = c .Close ()
1110
1146
}
1111
1147
1148
+ for p := range s .processes {
1149
+ _ = cmdCancel (s .logger , p )
1150
+ }
1151
+
1112
1152
s .logger .Debug (ctx , "closing SSH server" )
1113
1153
err := s .srv .Close ()
1114
1154
0 commit comments