You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

162 lines
4.8 KiB

  1. package cli
import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"strings"

	"github.com/spf13/cobra"
	"github.com/spf13/viper"
)
  12. const (
  13. HomeFlag = "home"
  14. TraceFlag = "trace"
  15. OutputFlag = "output"
  16. EncodingFlag = "encoding"
  17. )
  18. // Executable is the minimal interface to *corba.Command, so we can
  19. // wrap if desired before the test
  20. type Executable interface {
  21. Execute() error
  22. Context() context.Context
  23. }
  24. // PrepareBaseCmd is meant for tendermint and other servers
  25. func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor {
  26. cobra.OnInitialize(func() { InitEnv(envPrefix) })
  27. cmd.PersistentFlags().StringP(HomeFlag, "", defaultHome, "directory for config and data")
  28. cmd.PersistentFlags().Bool(TraceFlag, false, "print out full stack trace on errors")
  29. cmd.PersistentPreRunE = concatCobraCmdFuncs(BindFlagsLoadViper, cmd.PersistentPreRunE)
  30. return Executor{cmd, os.Exit}
  31. }
  32. // PrepareMainCmd is meant for client side libs that want some more flags
  33. //
  34. // This adds --encoding (hex, btc, base64) and --output (text, json) to
  35. // the command. These only really make sense in interactive commands.
  36. func PrepareMainCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor {
  37. cmd.PersistentFlags().StringP(EncodingFlag, "e", "hex", "Binary encoding (hex|b64|btc)")
  38. cmd.PersistentFlags().StringP(OutputFlag, "o", "text", "Output format (text|json)")
  39. cmd.PersistentPreRunE = concatCobraCmdFuncs(validateOutput, cmd.PersistentPreRunE)
  40. return PrepareBaseCmd(cmd, envPrefix, defaultHome)
  41. }
  42. // InitEnv sets to use ENV variables if set.
  43. func InitEnv(prefix string) {
  44. copyEnvVars(prefix)
  45. // env variables with TM prefix (eg. TM_ROOT)
  46. viper.SetEnvPrefix(prefix)
  47. viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_", "-", "_"))
  48. viper.AutomaticEnv()
  49. }
  50. // This copies all variables like TMROOT to TM_ROOT,
  51. // so we can support both formats for the user
  52. func copyEnvVars(prefix string) {
  53. prefix = strings.ToUpper(prefix)
  54. ps := prefix + "_"
  55. for _, e := range os.Environ() {
  56. kv := strings.SplitN(e, "=", 2)
  57. if len(kv) == 2 {
  58. k, v := kv[0], kv[1]
  59. if strings.HasPrefix(k, prefix) && !strings.HasPrefix(k, ps) {
  60. k2 := strings.Replace(k, prefix, ps, 1)
  61. os.Setenv(k2, v)
  62. }
  63. }
  64. }
  65. }
// Executor wraps a *cobra.Command with an Execute method that reports a
// failed command's error on stderr and terminates via Exit with an exit code.
//
// Keep the field order stable: an unkeyed composite literal such as
// Executor{cmd, os.Exit} depends on it.
type Executor struct {
	// Command is embedded, so its fields and methods are promoted onto Executor.
	*cobra.Command
	// Exit is called with the exit code when Execute fails. PrepareBaseCmd
	// sets it to os.Exit; tests can substitute a stub that records the code.
	Exit func(int)
}
  71. type ExitCoder interface {
  72. ExitCode() int
  73. }
  74. // execute adds all child commands to the root command sets flags appropriately.
  75. // This is called by main.main(). It only needs to happen once to the rootCmd.
  76. func (e Executor) Execute() error {
  77. e.SilenceUsage = true
  78. e.SilenceErrors = true
  79. err := e.Command.Execute()
  80. if err != nil {
  81. if viper.GetBool(TraceFlag) {
  82. const size = 64 << 10
  83. buf := make([]byte, size)
  84. buf = buf[:runtime.Stack(buf, false)]
  85. fmt.Fprintf(os.Stderr, "ERROR: %v\n%s\n", err, buf)
  86. } else {
  87. fmt.Fprintf(os.Stderr, "ERROR: %v\n", err)
  88. }
  89. // return error code 1 by default, can override it with a special error type
  90. exitCode := 1
  91. if ec, ok := err.(ExitCoder); ok {
  92. exitCode = ec.ExitCode()
  93. }
  94. e.Exit(exitCode)
  95. }
  96. return err
  97. }
  98. type cobraCmdFunc func(cmd *cobra.Command, args []string) error
  99. // Returns a single function that calls each argument function in sequence
  100. // RunE, PreRunE, PersistentPreRunE, etc. all have this same signature
  101. func concatCobraCmdFuncs(fs ...cobraCmdFunc) cobraCmdFunc {
  102. return func(cmd *cobra.Command, args []string) error {
  103. for _, f := range fs {
  104. if f != nil {
  105. if err := f(cmd, args); err != nil {
  106. return err
  107. }
  108. }
  109. }
  110. return nil
  111. }
  112. }
  113. // Bind all flags and read the config into viper
  114. func BindFlagsLoadViper(cmd *cobra.Command, args []string) error {
  115. // cmd.Flags() includes flags from this command and all persistent flags from the parent
  116. if err := viper.BindPFlags(cmd.Flags()); err != nil {
  117. return err
  118. }
  119. homeDir := viper.GetString(HomeFlag)
  120. viper.Set(HomeFlag, homeDir)
  121. viper.SetConfigName("config") // name of config file (without extension)
  122. viper.AddConfigPath(homeDir) // search root directory
  123. viper.AddConfigPath(filepath.Join(homeDir, "config")) // search root directory /config
  124. // If a config file is found, read it in.
  125. if err := viper.ReadInConfig(); err == nil {
  126. // stderr, so if we redirect output to json file, this doesn't appear
  127. // fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed())
  128. } else if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
  129. // ignore not found error, return other errors
  130. return err
  131. }
  132. return nil
  133. }
  134. func validateOutput(cmd *cobra.Command, args []string) error {
  135. // validate output format
  136. output := viper.GetString(OutputFlag)
  137. switch output {
  138. case "text", "json":
  139. default:
  140. return fmt.Errorf("unsupported output format: %s", output)
  141. }
  142. return nil
  143. }