Browse Source

libs/os: add test case for TrapSignal (#5646)

pull/5650/head
Alessio Treglia 4 years ago
committed by GitHub
parent
commit
eb0d353767
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 10 deletions
  1. +8
    -8
      libs/os/os.go
  2. +54
    -2
      libs/os/os_test.go

+ 8
- 8
libs/os/os.go View File

@ -13,19 +13,19 @@ type logger interface {
Info(msg string, keyvals ...interface{})
}
// TrapSignal catches the SIGTERM/SIGINT and executes cb function. After that it exits
// with code 0.
// TrapSignal catches SIGTERM and SIGINT, executes the cleanup function,
// and exits with code 0.
func TrapSignal(logger logger, cb func()) {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
for sig := range c {
logger.Info(fmt.Sprintf("captured %v, exiting...", sig))
if cb != nil {
cb()
}
os.Exit(0)
sig := <-c
logger.Info(fmt.Sprintf("captured %v, exiting...", sig))
if cb != nil {
cb()
}
os.Exit(0)
}()
}


+ 54
- 2
libs/os/os_test.go View File

@ -1,11 +1,15 @@
package os
package os_test
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"os/exec"
"testing"
"time"
tmos "github.com/tendermint/tendermint/libs/os"
)
func TestCopyFile(t *testing.T) {
@ -20,7 +24,7 @@ func TestCopyFile(t *testing.T) {
}
copyfile := fmt.Sprintf("%s.copy", tmpfile.Name())
if err := CopyFile(tmpfile.Name(), copyfile); err != nil {
if err := tmos.CopyFile(tmpfile.Name(), copyfile); err != nil {
t.Fatal(err)
}
if _, err := os.Stat(copyfile); os.IsNotExist(err) {
@ -35,3 +39,51 @@ func TestCopyFile(t *testing.T) {
}
os.Remove(copyfile)
}
func TestTrapSignal(t *testing.T) {
if os.Getenv("TM_TRAP_SIGNAL_TEST") == "1" {
t.Log("inside test process")
killer()
return
}
cmd := exec.Command(os.Args[0], "-test.run="+t.Name())
mockStderr := bytes.NewBufferString("")
cmd.Env = append(os.Environ(), "TM_TRAP_SIGNAL_TEST=1")
cmd.Stderr = mockStderr
err := cmd.Run()
if err == nil {
wantStderr := "exiting"
if mockStderr.String() != wantStderr {
t.Fatalf("stderr: want %q, got %q", wantStderr, mockStderr.String())
}
return
}
if e, ok := err.(*exec.ExitError); ok && !e.Success() {
t.Fatalf("wrong exit code, want 0, got %d", e.ExitCode())
}
t.Fatal("this error should not be triggered")
}
type mockLogger struct{}
func (ml mockLogger) Info(msg string, keyvals ...interface{}) {}
func killer() {
logger := mockLogger{}
tmos.TrapSignal(logger, func() { _, _ = fmt.Fprintf(os.Stderr, "exiting") })
time.Sleep(1 * time.Second)
// use Kill() to test SIGTERM
if err := tmos.Kill(); err != nil {
panic(err)
}
time.Sleep(1 * time.Second)
}

Loading…
Cancel
Save