diff --git a/cli/agent.go b/cli/agent.go index 5465aeedd9302..073581bd950cb 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -50,6 +50,8 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { slogJSONPath string slogStackdriverPath string blockFileTransfer bool + agentHeaderCommand string + agentHeader []string ) cmd := &serpent.Command{ Use: "agent", @@ -176,6 +178,14 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { // with large payloads can take a bit. e.g. startup scripts // may take a while to insert. client.SDK.HTTPClient.Timeout = 30 * time.Second + // Attach header transport so we process --agent-header and + // --agent-header-command flags + headerTransport, err := headerTransport(ctx, r.agentURL, agentHeader, agentHeaderCommand) + if err != nil { + return xerrors.Errorf("configure header transport: %w", err) + } + headerTransport.Transport = client.SDK.HTTPClient.Transport + client.SDK.HTTPClient.Transport = headerTransport // Enable pprof handler // This prevents the pprof import from being accidentally deleted. @@ -361,6 +371,18 @@ func (r *RootCmd) workspaceAgent() *serpent.Command { Value: serpent.StringOf(&pprofAddress), Description: "The address to serve pprof.", }, + { + Flag: "agent-header-command", + Env: "CODER_AGENT_HEADER_COMMAND", + Value: serpent.StringOf(&agentHeaderCommand), + Description: "An external command that outputs additional HTTP headers added to all requests. The command must output each header as `key=value` on its own line.", + }, + { + Flag: "agent-header", + Env: "CODER_AGENT_HEADER", + Value: serpent.StringArrayOf(&agentHeader), + Description: "Additional HTTP headers added to all requests. Provide as " + `key=value` + ". Can be specified multiple times.", + }, { Flag: "no-reap", diff --git a/cli/agent_test.go b/cli/agent_test.go index 9571bf03e1a09..f30d12b012d88 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -3,10 +3,13 @@ package cli_test import ( "context" "fmt" + "net/http" + "net/http/httptest" "os" "path/filepath" "runtime" "strings" + "sync/atomic" "testing" "github.com/google/uuid" @@ -229,6 +232,43 @@ func TestWorkspaceAgent(t *testing.T) { require.Equal(t, codersdk.AgentSubsystemEnvbox, resources[0].Agents[0].Subsystems[0]) require.Equal(t, codersdk.AgentSubsystemExectrace, resources[0].Agents[0].Subsystems[1]) }) + t.Run("Header", func(t *testing.T) { + t.Parallel() + + var url string + var called int64 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "wow", r.Header.Get("X-Testing")) + assert.Equal(t, "Ethan was Here!", r.Header.Get("Cool-Header")) + assert.Equal(t, "very-wow-"+url, r.Header.Get("X-Process-Testing")) + assert.Equal(t, "more-wow", r.Header.Get("X-Process-Testing2")) + atomic.AddInt64(&called, 1) + w.WriteHeader(http.StatusGone) + })) + defer srv.Close() + url = srv.URL + coderURLEnv := "$CODER_URL" + if runtime.GOOS == "windows" { + coderURLEnv = "%CODER_URL%" + } + + logDir := t.TempDir() + inv, _ := clitest.New(t, + "agent", + "--auth", "token", + "--agent-token", "fake-token", + "--agent-url", srv.URL, + "--log-dir", logDir, + "--agent-header", "X-Testing=wow", + "--agent-header", "Cool-Header=Ethan was Here!", + "--agent-header-command", "printf X-Process-Testing=very-wow-"+coderURLEnv+"'\\r\\n'X-Process-Testing2=more-wow", + ) + + clitest.Start(t, inv) + require.Eventually(t, func() bool { + return atomic.LoadInt64(&called) > 0 + }, testutil.WaitShort, testutil.IntervalFast) + }) } func matchAgentWithVersion(rs []codersdk.WorkspaceResource) bool { diff --git a/cli/root.go b/cli/root.go index fdebdc74bedde..da7e48f2feae4 100644 --- a/cli/root.go +++ b/cli/root.go @@ -550,44 +550,7 @@ func (r *RootCmd) InitClient(client *codersdk.Client) serpent.MiddlewareFunc { // HeaderTransport creates a new transport that executes `--header-command` // if it is set to add headers for all outbound requests. func (r *RootCmd) HeaderTransport(ctx context.Context, serverURL *url.URL) (*codersdk.HeaderTransport, error) { - transport := &codersdk.HeaderTransport{ - Transport: http.DefaultTransport, - Header: http.Header{}, - } - headers := r.header - if r.headerCommand != "" { - shell := "sh" - caller := "-c" - if runtime.GOOS == "windows" { - shell = "cmd.exe" - caller = "/c" - } - var outBuf bytes.Buffer - // #nosec - cmd := exec.CommandContext(ctx, shell, caller, r.headerCommand) - cmd.Env = append(os.Environ(), "CODER_URL="+serverURL.String()) - cmd.Stdout = &outBuf - cmd.Stderr = io.Discard - err := cmd.Run() - if err != nil { - return nil, xerrors.Errorf("failed to run %v: %w", cmd.Args, err) - } - scanner := bufio.NewScanner(&outBuf) - for scanner.Scan() { - headers = append(headers, scanner.Text()) - } - if err := scanner.Err(); err != nil { - return nil, xerrors.Errorf("scan %v: %w", cmd.Args, err) - } - } - for _, header := range headers { - parts := strings.SplitN(header, "=", 2) - if len(parts) < 2 { - return nil, xerrors.Errorf("split header %q had less than two parts", header) - } - transport.Header.Add(parts[0], parts[1]) - } - return transport, nil + return headerTransport(ctx, serverURL, r.header, r.headerCommand) } func (r *RootCmd) configureClient(ctx context.Context, client *codersdk.Client, serverURL *url.URL, inv *serpent.Invocation) error { @@ -1273,3 +1236,46 @@ type roundTripper func(req *http.Request) (*http.Response, error) func (r roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return r(req) } + +// HeaderTransport creates a new transport that executes `--header-command` +// if it is set to add headers for all outbound requests. +func headerTransport(ctx context.Context, serverURL *url.URL, header []string, headerCommand string) (*codersdk.HeaderTransport, error) { + transport := &codersdk.HeaderTransport{ + Transport: http.DefaultTransport, + Header: http.Header{}, + } + headers := header + if headerCommand != "" { + shell := "sh" + caller := "-c" + if runtime.GOOS == "windows" { + shell = "cmd.exe" + caller = "/c" + } + var outBuf bytes.Buffer + // #nosec + cmd := exec.CommandContext(ctx, shell, caller, headerCommand) + cmd.Env = append(os.Environ(), "CODER_URL="+serverURL.String()) + cmd.Stdout = &outBuf + cmd.Stderr = io.Discard + err := cmd.Run() + if err != nil { + return nil, xerrors.Errorf("failed to run %v: %w", cmd.Args, err) + } + scanner := bufio.NewScanner(&outBuf) + for scanner.Scan() { + headers = append(headers, scanner.Text()) + } + if err := scanner.Err(); err != nil { + return nil, xerrors.Errorf("scan %v: %w", cmd.Args, err) + } + } + for _, header := range headers { + parts := strings.SplitN(header, "=", 2) + if len(parts) < 2 { + return nil, xerrors.Errorf("split header %q had less than two parts", header) + } + transport.Header.Add(parts[0], parts[1]) + } + return transport, nil +} diff --git a/cli/testdata/coder_agent_--help.golden b/cli/testdata/coder_agent_--help.golden index d6982fda18e7c..3394b43a9e900 100644 --- a/cli/testdata/coder_agent_--help.golden +++ b/cli/testdata/coder_agent_--help.golden @@ -15,6 +15,15 @@ OPTIONS: --log-stackdriver string, $CODER_AGENT_LOGGING_STACKDRIVER Output Stackdriver compatible logs to a given file. + --agent-header string-array, $CODER_AGENT_HEADER + Additional HTTP headers added to all requests. Provide as key=value. + Can be specified multiple times. + + --agent-header-command string, $CODER_AGENT_HEADER_COMMAND + An external command that outputs additional HTTP headers added to all + requests. The command must output each header as `key=value` on its + own line. + --auth string, $CODER_AGENT_AUTH (default: token) Specify the authentication type to use for the agent.