Skip to content

Commit e6e4582

Browse files
committed
feat(testutil): add lazy timeout context with location-based reset
It's common to create a context early in a test body, then do setup work unrelated to that context. By the time the context is actually used, it may have already timed out. This was detected as test failures in #21091. The new Context() function returns a context that resets its timeout when accessed from new lines in the test file. The timeout does not begin until the context is first used (lazy initialization). This is useful for integration tests that pass contexts through many subsystems, where each subsystem should get a fresh timeout window. Key behaviors: - Timer starts on first Done(), Deadline(), or Err() call - Value() does not trigger initialization (used for tracing/logging) - Each unique line in a _test.go file gets a fresh timeout window - Same-line access (e.g., in loops) does not reset - Expired contexts cannot be resurrected Limitations: - Wrapping with a child context (e.g., context.WithCancel) prevents resets since the child's methods don't call through to the parent - Storing the Done() channel prevents resets on subsequent accesses The original fixed-timeout behavior is available via ContextFixed().
1 parent 770fdb3 commit e6e4582

File tree

4 files changed

+422
-2
lines changed

4 files changed

+422
-2
lines changed

cli/ssh_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2052,7 +2052,6 @@ func TestSSH_Container(t *testing.T) {
20522052
t.Parallel()
20532053

20542054
client, workspace, agentToken := setupWorkspaceForAgent(t)
2055-
ctx := testutil.Context(t, testutil.WaitLong)
20562055
pool, err := dockertest.NewPool("")
20572056
require.NoError(t, err, "Could not connect to docker")
20582057
ct, err := pool.RunWithOptions(&dockertest.RunOptions{
@@ -2083,6 +2082,7 @@ func TestSSH_Container(t *testing.T) {
20832082
})
20842083
_ = coderdtest.NewWorkspaceAgentWaiter(t, client, workspace.ID).Wait()
20852084

2085+
ctx := testutil.Context(t, testutil.WaitLong)
20862086
inv, root := clitest.New(t, "ssh", workspace.Name, "-c", ct.Container.ID)
20872087
clitest.SetupConfig(t, client, root)
20882088
ptty := ptytest.New(t).Attach(inv)

testutil/ctx.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,33 @@ import (
66
"time"
77
)
88

9-
func Context(t testing.TB, dur time.Duration) context.Context {
9+
// Context returns a context that resets its timeout when accessed from new
10+
// locations in the test file. The timeout does not begin until the context is
11+
// first used.
12+
//
13+
// This is useful for integration tests that pass contexts through many
14+
// subsystems, where each subsystem should get a fresh timeout window.
15+
//
16+
// Note: Each call to Done(), Deadline(), or Err() from a new line in the test
17+
// file resets the timeout. If you need to prevent resets (e.g., to test actual
18+
// timeout behavior), store the channel:
19+
//
20+
// done := ctx.Done() // Timeout starts, channel stored
21+
// // ... do work ...
22+
// select {
23+
// case <-done: // No reset, using stored channel
24+
// // handle timeout
25+
// }
26+
//
27+
// Wrapping with a child context (e.g., context.WithCancel) will also prevent
28+
// resets since the child's methods don't call through to the parent.
29+
func Context(t testing.TB, timeout time.Duration) context.Context {
30+
return newLazyTimeoutContext(t, timeout)
31+
}
32+
33+
// ContextFixed returns a context with a fixed timeout that starts immediately.
34+
// Use Context() instead for contexts that should reset on new package access.
35+
func ContextFixed(t testing.TB, dur time.Duration) context.Context {
1036
ctx, cancel := context.WithTimeout(context.Background(), dur)
1137
t.Cleanup(cancel)
1238
return ctx

testutil/lazy_ctx.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package testutil
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"runtime"
7+
"strings"
8+
"sync"
9+
"testing"
10+
"time"
11+
)
12+
13+
var _ context.Context = (*lazyTimeoutContext)(nil)
14+
15+
// lazyTimeoutContext is a context.Context that resets its timeout when accessed
16+
// from new locations in the test file. The timeout does not begin until the
17+
// context is first used.
18+
type lazyTimeoutContext struct {
19+
t testing.TB
20+
timeout time.Duration
21+
22+
mu sync.Mutex
23+
started bool
24+
deadline time.Time
25+
timer *time.Timer
26+
done chan struct{}
27+
err error
28+
seenLocations map[string]struct{}
29+
}
30+
31+
func newLazyTimeoutContext(t testing.TB, timeout time.Duration) context.Context {
32+
ctx := &lazyTimeoutContext{
33+
t: t,
34+
timeout: timeout,
35+
done: make(chan struct{}),
36+
seenLocations: make(map[string]struct{}),
37+
}
38+
t.Cleanup(ctx.cancel)
39+
return ctx
40+
}
41+
42+
// Deadline returns the current deadline, if any. The deadline is set lazily
43+
// on first access and may be extended when accessed from new locations.
44+
func (c *lazyTimeoutContext) Deadline() (deadline time.Time, ok bool) {
45+
c.maybeResetForLocation()
46+
47+
c.mu.Lock()
48+
defer c.mu.Unlock()
49+
if !c.started {
50+
return time.Time{}, false
51+
}
52+
return c.deadline, true
53+
}
54+
55+
// Done returns a channel that's closed when the context is canceled.
56+
func (c *lazyTimeoutContext) Done() <-chan struct{} {
57+
c.maybeResetForLocation()
58+
return c.done
59+
}
60+
61+
// Err returns the error indicating why this context was canceled.
62+
func (c *lazyTimeoutContext) Err() error {
63+
c.maybeResetForLocation()
64+
65+
c.mu.Lock()
66+
defer c.mu.Unlock()
67+
return c.err
68+
}
69+
70+
// Value returns nil; this context carries no values.
71+
// Note: Value() does NOT trigger lazy initialization or timeout reset.
72+
func (*lazyTimeoutContext) Value(any) any {
73+
return nil
74+
}
75+
76+
// maybeResetForLocation starts the timer on first access and resets it when
77+
// accessed from a new location in the test file.
78+
func (c *lazyTimeoutContext) maybeResetForLocation() {
79+
loc := callerLocation()
80+
81+
c.mu.Lock()
82+
defer c.mu.Unlock()
83+
84+
// Don't reset if already canceled.
85+
if c.err != nil {
86+
return
87+
}
88+
89+
// Always start the timer on first access, regardless of location.
90+
if !c.started {
91+
c.startLocked()
92+
if loc != "" {
93+
c.seenLocations[loc] = struct{}{}
94+
}
95+
if testing.Verbose() {
96+
c.t.Logf("lazyTimeoutContext: started timeout for location: %s", loc)
97+
}
98+
return
99+
}
100+
101+
// Only reset for known test file locations.
102+
if loc == "" {
103+
return
104+
}
105+
106+
if _, seen := c.seenLocations[loc]; seen {
107+
return
108+
}
109+
c.seenLocations[loc] = struct{}{}
110+
111+
// Reset deadline.
112+
c.deadline = time.Now().Add(c.timeout)
113+
if c.timer != nil && c.timer.Stop() {
114+
c.timer.Reset(c.timeout)
115+
}
116+
117+
if testing.Verbose() {
118+
c.t.Logf("lazyTimeoutContext: reset timeout for new location: %s", loc)
119+
}
120+
}
121+
122+
// startLocked initializes the timer. Must be called with mu held.
123+
func (c *lazyTimeoutContext) startLocked() {
124+
c.started = true
125+
c.deadline = time.Now().Add(c.timeout)
126+
c.timer = time.AfterFunc(c.timeout, func() {
127+
c.mu.Lock()
128+
defer c.mu.Unlock()
129+
if c.err == nil {
130+
c.err = context.DeadlineExceeded
131+
close(c.done)
132+
}
133+
})
134+
}
135+
136+
// cancel stops the timer and marks the context as canceled.
137+
func (c *lazyTimeoutContext) cancel() {
138+
c.mu.Lock()
139+
defer c.mu.Unlock()
140+
if c.timer != nil {
141+
c.timer.Stop()
142+
}
143+
if c.err == nil {
144+
c.err = context.Canceled
145+
close(c.done)
146+
}
147+
}
148+
149+
// callerLocation walks the stack to find the line in a test file that
150+
// initiated the call. Returns empty string if not called from a test file.
151+
func callerLocation() string {
152+
// Skip: runtime.Callers, callerLocation, maybeResetForLocation,
153+
// Done/Deadline/Err, and we want to find the caller of those.
154+
pc := make([]uintptr, 50)
155+
n := runtime.Callers(4, pc)
156+
if n == 0 {
157+
return ""
158+
}
159+
160+
frames := runtime.CallersFrames(pc[:n])
161+
for {
162+
frame, more := frames.Next()
163+
164+
// Look for frames in _test.go files.
165+
if strings.HasSuffix(frame.File, "_test.go") {
166+
return fmt.Sprintf("%s:%d", frame.File, frame.Line)
167+
}
168+
169+
if !more {
170+
break
171+
}
172+
}
173+
174+
return ""
175+
}

0 commit comments

Comments
 (0)