diff --git a/agent/agent_test.go b/agent/agent_test.go index d4c8b568319c3..7c309a59fa882 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -994,42 +994,77 @@ func TestAgent_UnixRemoteForwarding(t *testing.T) { func TestAgent_SFTP(t *testing.T) { t.Parallel() - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - u, err := user.Current() - require.NoError(t, err, "get current user") - home := u.HomeDir - if runtime.GOOS == "windows" { - home = "/" + strings.ReplaceAll(home, "\\", "/") - } - //nolint:dogsled - conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) - sshClient, err := conn.SSHClient(ctx) - require.NoError(t, err) - defer sshClient.Close() - client, err := sftp.NewClient(sshClient) - require.NoError(t, err) - defer client.Close() - wd, err := client.Getwd() - require.NoError(t, err, "get working directory") - require.Equal(t, home, wd, "working directory should be home user home") - tempFile := filepath.Join(t.TempDir(), "sftp") - // SFTP only accepts unix-y paths. - remoteFile := filepath.ToSlash(tempFile) - if !path.IsAbs(remoteFile) { - // On Windows, e.g. "/C:/Users/...". - remoteFile = path.Join("/", remoteFile) - } - file, err := client.Create(remoteFile) - require.NoError(t, err) - err = file.Close() - require.NoError(t, err) - _, err = os.Stat(tempFile) - require.NoError(t, err) - // Close the client to trigger disconnect event. - _ = client.Close() - assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "") + t.Run("DefaultWorkingDirectory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + u, err := user.Current() + require.NoError(t, err, "get current user") + home := u.HomeDir + if runtime.GOOS == "windows" { + home = "/" + strings.ReplaceAll(home, "\\", "/") + } + //nolint:dogsled + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + defer sshClient.Close() + client, err := sftp.NewClient(sshClient) + require.NoError(t, err) + defer client.Close() + wd, err := client.Getwd() + require.NoError(t, err, "get working directory") + require.Equal(t, home, wd, "working directory should be user home") + tempFile := filepath.Join(t.TempDir(), "sftp") + // SFTP only accepts unix-y paths. + remoteFile := filepath.ToSlash(tempFile) + if !path.IsAbs(remoteFile) { + // On Windows, e.g. "/C:/Users/...". + remoteFile = path.Join("/", remoteFile) + } + file, err := client.Create(remoteFile) + require.NoError(t, err) + err = file.Close() + require.NoError(t, err) + _, err = os.Stat(tempFile) + require.NoError(t, err) + + // Close the client to trigger disconnect event. + _ = client.Close() + assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "") + }) + + t.Run("CustomWorkingDirectory", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + // Create a custom directory for the agent to use. + customDir := t.TempDir() + expectedDir := customDir + if runtime.GOOS == "windows" { + expectedDir = "/" + strings.ReplaceAll(customDir, "\\", "/") + } + + //nolint:dogsled + conn, agentClient, _, _, _ := setupAgent(t, agentsdk.Manifest{ + Directory: customDir, + }, 0) + sshClient, err := conn.SSHClient(ctx) + require.NoError(t, err) + defer sshClient.Close() + client, err := sftp.NewClient(sshClient) + require.NoError(t, err) + defer client.Close() + wd, err := client.Getwd() + require.NoError(t, err, "get working directory") + require.Equal(t, expectedDir, wd, "working directory should be custom directory") + + // Close the client to trigger disconnect event. + _ = client.Close() + assertConnectionReport(t, agentClient, proto.Connection_SSH, 0, "") + }) } func TestAgent_SCP(t *testing.T) { diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index 625c5e67205c4..c769e5f07f56f 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -829,13 +829,19 @@ func (s *Server) sftpHandler(logger slog.Logger, session ssh.Session) error { session.DisablePTYEmulation() var opts []sftp.ServerOption - // Change current working directory to the users home - // directory so that SFTP connections land there. - homedir, err := userHomeDir() - if err != nil { - logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) - } else { - opts = append(opts, sftp.WithServerWorkingDirectory(homedir)) + // Change current working directory to the configured + // directory (or home directory if not set) so that SFTP + // connections land there. + dir := s.config.WorkingDirectory() + if dir == "" { + var err error + dir, err = userHomeDir() + if err != nil { + logger.Warn(ctx, "get sftp working directory failed, unable to get home dir", slog.Error(err)) + } + } + if dir != "" { + opts = append(opts, sftp.WithServerWorkingDirectory(dir)) } server, err := sftp.NewServer(session, opts...)