From 71d6ab6827744bcb556a761da236f4a07579427e Mon Sep 17 00:00:00 2001 From: Colin Cross Date: Tue, 4 May 2021 09:11:41 -0700 Subject: [PATCH] Fix concurrency issues in Test_runWithTimeout Use a concurrency-safe writer in runWithTimeout to avoid data races on the bytes.Buffer passed in during tests. Bug: 181095653 Fixes: 187149270 Test: Test_runWithTimeout Test: go test -race ./cmd/run_with_timeout Change-Id: I57a889765cb9ee7b42983f0906313e0c2d1e414e --- cmd/run_with_timeout/run_with_timeout.go | 37 ++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/cmd/run_with_timeout/run_with_timeout.go b/cmd/run_with_timeout/run_with_timeout.go index e2258726c..f2caaabd9 100644 --- a/cmd/run_with_timeout/run_with_timeout.go +++ b/cmd/run_with_timeout/run_with_timeout.go @@ -23,6 +23,7 @@ import ( "io" "os" "os/exec" + "sync" "syscall" "time" ) @@ -62,10 +63,42 @@ func main() { } } +// concurrentWriter wraps a writer to make it thread-safe to call Write. +type concurrentWriter struct { + w io.Writer + sync.Mutex +} + +// Write writes the data to the wrapped writer with a lock to allow for concurrent calls. +func (c *concurrentWriter) Write(data []byte) (n int, err error) { + c.Lock() + defer c.Unlock() + if c.w == nil { + return 0, nil + } + return c.w.Write(data) +} + +// Close ends the concurrentWriter, causing future calls to Write to be no-ops. It does not close +// the underlying writer. +func (c *concurrentWriter) Close() { + c.Lock() + defer c.Unlock() + c.w = nil +} + func runWithTimeout(command string, args []string, timeout time.Duration, onTimeoutCmdStr string, stdin io.Reader, stdout, stderr io.Writer) error { cmd := exec.Command(command, args...) - cmd.Stdin, cmd.Stdout, cmd.Stderr = stdin, stdout, stderr + + // Wrap the writers in a locking writer so that cmd and onTimeoutCmd don't try to write to + // stdout or stderr concurrently. + concurrentStdout := &concurrentWriter{w: stdout} + concurrentStderr := &concurrentWriter{w: stderr} + defer concurrentStdout.Close() + defer concurrentStderr.Close() + + cmd.Stdin, cmd.Stdout, cmd.Stderr = stdin, concurrentStdout, concurrentStderr err := cmd.Start() if err != nil { return err @@ -98,7 +131,7 @@ func runWithTimeout(command string, args []string, timeout time.Duration, onTime if onTimeoutCmdStr != "" { onTimeoutCmd := exec.Command("sh", "-c", onTimeoutCmdStr) - onTimeoutCmd.Stdin, onTimeoutCmd.Stdout, onTimeoutCmd.Stderr = stdin, stdout, stderr + onTimeoutCmd.Stdin, onTimeoutCmd.Stdout, onTimeoutCmd.Stderr = stdin, concurrentStdout, concurrentStderr onTimeoutCmd.Env = append(os.Environ(), fmt.Sprintf("PID=%d", cmd.Process.Pid)) err := onTimeoutCmd.Run() if err != nil {