Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion testutil/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,21 @@ import (
"time"
)

func Context(t testing.TB, dur time.Duration) context.Context {
// Context returns a context with a timeout that starts on first use and resets
// when accessed from new lines in test files. Each call to Done, Deadline, or
// Err from a new line resets the deadline.
//
// To prevent resets, store the Done channel or wrap with a child context:
//
// done := ctx.Done()
// <-done // Uses stored channel, no reset.
func Context(t testing.TB, timeout time.Duration) context.Context {
return newLazyTimeoutContext(t, timeout)
}

// ContextFixed returns a context with a timeout that starts immediately and
// does not reset.
func ContextFixed(t testing.TB, dur time.Duration) context.Context {
ctx, cancel := context.WithTimeout(context.Background(), dur)
t.Cleanup(cancel)
return ctx
Expand Down
170 changes: 170 additions & 0 deletions testutil/lazy_ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package testutil

import (
"context"
"fmt"
"runtime"
"strings"
"sync"
"testing"
"time"
)

var _ context.Context = (*lazyTimeoutContext)(nil)

// lazyTimeoutContext implements context.Context with a timeout that starts on
// first use and resets when accessed from new locations in test files.
type lazyTimeoutContext struct {
t testing.TB
timeout time.Duration

mu sync.Mutex // Protects following fields.
testDone bool // True after cancel, prevents post-test logging.
deadline time.Time
timer *time.Timer
done chan struct{}
err error
seenLocations map[string]struct{}
}

func newLazyTimeoutContext(t testing.TB, timeout time.Duration) context.Context {
ctx := &lazyTimeoutContext{
t: t,
timeout: timeout,
done: make(chan struct{}),
seenLocations: make(map[string]struct{}),
}
t.Cleanup(ctx.cancel)
return ctx
}

// Deadline returns the current deadline. The deadline is set on first access
// and may be extended when accessed from new locations in test files.
func (c *lazyTimeoutContext) Deadline() (deadline time.Time, ok bool) {
c.maybeResetForLocation()

c.mu.Lock()
defer c.mu.Unlock()
return c.deadline, true
}

// Done returns a channel that's closed when the context is canceled.
func (c *lazyTimeoutContext) Done() <-chan struct{} {
c.maybeResetForLocation()
return c.done
}

// Err returns the error indicating why this context was canceled.
func (c *lazyTimeoutContext) Err() error {
c.maybeResetForLocation()

c.mu.Lock()
defer c.mu.Unlock()
return c.err
}

// Value returns nil. It does not trigger initialization or reset.
func (*lazyTimeoutContext) Value(any) any {
return nil
}

// maybeResetForLocation starts the timer on first access and resets the
// deadline when called from a previously unseen location in a test file.
func (c *lazyTimeoutContext) maybeResetForLocation() {
loc := callerLocation()

c.mu.Lock()
defer c.mu.Unlock()

// Already canceled.
if c.err != nil {
return
}

// First access, start timer.
if c.timer == nil {
c.startLocked()
if loc != "" {
c.seenLocations[loc] = struct{}{}
}
if testing.Verbose() && !c.testDone {
c.t.Logf("lazyTimeoutContext: started timeout for location: %s", loc)
}
return
}

// Non-test location, ignore.
if loc == "" {
return
}

if _, seen := c.seenLocations[loc]; seen {
return
}
c.seenLocations[loc] = struct{}{}

// New location, reset deadline.
c.deadline = time.Now().Add(c.timeout)
if c.timer.Stop() {
c.timer.Reset(c.timeout)
}

if testing.Verbose() && !c.testDone {
c.t.Logf("lazyTimeoutContext: reset timeout for new location: %s", loc)
}
}

// startLocked initializes the deadline and timer. It must be called with mu held.
func (c *lazyTimeoutContext) startLocked() {
c.deadline = time.Now().Add(c.timeout)
c.timer = time.AfterFunc(c.timeout, func() {
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.err = context.DeadlineExceeded
close(c.done)
}
})
}

// cancel stops the timer and marks the context as canceled. It is called by
// t.Cleanup when the test ends.
func (c *lazyTimeoutContext) cancel() {
c.mu.Lock()
defer c.mu.Unlock()
c.testDone = true
if c.timer != nil {
c.timer.Stop()
}
if c.err == nil {
c.err = context.Canceled
close(c.done)
}
}

// callerLocation returns the file:line of the first caller in a _test.go file,
// or the empty string if none is found.
func callerLocation() string {
// Skip runtime.Callers, callerLocation, maybeResetForLocation, and the
// context method (Done/Deadline/Err).
pc := make([]uintptr, 50)
n := runtime.Callers(4, pc)
if n == 0 {
return ""
}

frames := runtime.CallersFrames(pc[:n])
for {
frame, more := frames.Next()

if strings.HasSuffix(frame.File, "_test.go") {
return fmt.Sprintf("%s:%d", frame.File, frame.Line)
}

if !more {
break
}
}

return ""
}
200 changes: 200 additions & 0 deletions testutil/lazy_ctx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package testutil_test

import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/testutil"
)

func TestLazyTimeoutContext_LazyStart(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 10*time.Millisecond)

time.Sleep(50 * time.Millisecond) // Longer than timeout.

// Timer hasn't started, context should be valid.
select {
case <-ctx.Done():
t.Fatal("context should not be done yet - timer should not have started")
default:
}

// First select started the timer, wait for expiration.
select {
case <-ctx.Done():
case <-time.After(50 * time.Millisecond):
t.Fatal("context should have expired")
}
}

func TestLazyTimeoutContext_ValueDoesNotTriggerStart(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 10*time.Millisecond)

_ = ctx.Value("key") // Must not start timer.

time.Sleep(50 * time.Millisecond)

select {
case <-ctx.Done():
t.Fatal("Value() should not start timer")
default:
}
}

func TestLazyTimeoutContext_Expiration(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 5*time.Millisecond)

done := ctx.Done() // Store to avoid reset in select.

select {
case <-done:
case <-time.After(50 * time.Millisecond):
t.Fatal("context should have expired")
}

require.ErrorIs(t, ctx.Err(), context.DeadlineExceeded)
}

func TestLazyTimeoutContext_ResetOnNewLocation(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 50*time.Millisecond)

done := ctx.Done() // Store to check expiration.
time.Sleep(30 * time.Millisecond) // 60% of timeout.
_ = ctx.Done() // New line, resets timeout.
time.Sleep(30 * time.Millisecond) // 60% again, would be 120% without reset.

select {
case <-done:
t.Fatal("timeout should have been reset")
default:
}

select {
case <-done:
case <-time.After(50 * time.Millisecond):
t.Fatal("context should have expired")
}
}

func TestLazyTimeoutContext_NoResetOnSameLocation(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 50*time.Millisecond)

var done <-chan struct{}
// Same line, no reset. 5*15ms = 75ms > 50ms timeout.
for i := 0; i < 5; i++ {
done = ctx.Done()
time.Sleep(15 * time.Millisecond)
}

select {
case <-done:
default:
t.Fatal("context should be done - same location should not reset")
}
}

func TestLazyTimeoutContext_AlreadyExpiredNoResurrection(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 5*time.Millisecond)

<-ctx.Done()
require.ErrorIs(t, ctx.Err(), context.DeadlineExceeded)

_ = ctx.Err() // New location, must not resurrect.

select {
case <-ctx.Done():
default:
t.Fatal("expired context should not be resurrected")
}

require.ErrorIs(t, ctx.Err(), context.DeadlineExceeded)
}

func TestLazyTimeoutContext_ThreadSafety(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 100*time.Millisecond)

var wg sync.WaitGroup
const numGoroutines = 10
// Relies on -race to detect issues.
for i := range numGoroutines {

Check failure on line 138 in testutil/lazy_ctx_test.go

View workflow job for this annotation

GitHub Actions / test-go-pg-17

declared and not used: i

Check failure on line 138 in testutil/lazy_ctx_test.go

View workflow job for this annotation

GitHub Actions / test-go-pg (ubuntu-latest)

declared and not used: i

Check failure on line 138 in testutil/lazy_ctx_test.go

View workflow job for this annotation

GitHub Actions / test-go-pg (ubuntu-latest)

declared and not used: i

Check failure on line 138 in testutil/lazy_ctx_test.go

View workflow job for this annotation

GitHub Actions / lint

declared and not used: i (typecheck)

Check failure on line 138 in testutil/lazy_ctx_test.go

View workflow job for this annotation

GitHub Actions / test-go-race-pg

declared and not used: i
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 100; j++ {
_ = ctx.Done()
_, _ = ctx.Deadline()
_ = ctx.Err()
_ = ctx.Value("key")
}
}()
}

wg.Wait()
}

func TestLazyTimeoutContext_WithChildContext(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 50*time.Millisecond)

childCtx, cancel := context.WithCancel(ctx)
defer cancel()

select {
case <-childCtx.Done():
t.Fatal("child context should not be done yet")
default:
}

cancel()

select {
case <-childCtx.Done():
case <-time.After(50 * time.Millisecond):
t.Fatal("child context should be done after cancel")
}
}

func TestLazyTimeoutContext_ErrBeforeExpiration(t *testing.T) {
t.Parallel()

ctx := testutil.Context(t, 50*time.Millisecond)

err := ctx.Err()
assert.NoError(t, err, "Err() should return nil before expiration")
}

func TestLazyTimeoutContext_DeadlineReturnsCorrectValue(t *testing.T) {
t.Parallel()

timeout := 50 * time.Millisecond
before := time.Now()
ctx := testutil.Context(t, timeout)

deadline, ok := ctx.Deadline()
after := time.Now()

require.True(t, ok, "deadline should be set after Deadline() call")
require.False(t, deadline.IsZero(), "deadline should not be zero")
require.True(t, deadline.After(before.Add(timeout-time.Millisecond)))
require.True(t, deadline.Before(after.Add(timeout+10*time.Millisecond)))
}
Loading