You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

539 lines
12 KiB

  1. package privval
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net"
  7. "time"
  8. amino "github.com/tendermint/go-amino"
  9. "github.com/tendermint/tendermint/crypto"
  10. "github.com/tendermint/tendermint/crypto/ed25519"
  11. cmn "github.com/tendermint/tendermint/libs/common"
  12. "github.com/tendermint/tendermint/libs/log"
  13. p2pconn "github.com/tendermint/tendermint/p2p/conn"
  14. "github.com/tendermint/tendermint/types"
  15. )
  16. const (
  17. defaultAcceptDeadlineSeconds = 3
  18. defaultConnDeadlineSeconds = 3
  19. defaultConnHeartBeatSeconds = 30
  20. defaultConnWaitSeconds = 60
  21. defaultDialRetries = 10
  22. )
  23. // Socket errors.
  24. var (
  25. ErrDialRetryMax = errors.New("dialed maximum retries")
  26. ErrConnWaitTimeout = errors.New("waited for remote signer for too long")
  27. ErrConnTimeout = errors.New("remote signer timed out")
  28. )
  29. var (
  30. acceptDeadline = time.Second * defaultAcceptDeadlineSeconds
  31. connDeadline = time.Second * defaultConnDeadlineSeconds
  32. connHeartbeat = time.Second * defaultConnHeartBeatSeconds
  33. )
  34. // SocketPVOption sets an optional parameter on the SocketPV.
  35. type SocketPVOption func(*SocketPV)
  36. // SocketPVAcceptDeadline sets the deadline for the SocketPV listener.
  37. // A zero time value disables the deadline.
  38. func SocketPVAcceptDeadline(deadline time.Duration) SocketPVOption {
  39. return func(sc *SocketPV) { sc.acceptDeadline = deadline }
  40. }
  41. // SocketPVConnDeadline sets the read and write deadline for connections
  42. // from external signing processes.
  43. func SocketPVConnDeadline(deadline time.Duration) SocketPVOption {
  44. return func(sc *SocketPV) { sc.connDeadline = deadline }
  45. }
  46. // SocketPVHeartbeat sets the period on which to check the liveness of the
  47. // connected Signer connections.
  48. func SocketPVHeartbeat(period time.Duration) SocketPVOption {
  49. return func(sc *SocketPV) { sc.connHeartbeat = period }
  50. }
  51. // SocketPVConnWait sets the timeout duration before connection of external
  52. // signing processes are considered to be unsuccessful.
  53. func SocketPVConnWait(timeout time.Duration) SocketPVOption {
  54. return func(sc *SocketPV) { sc.connWaitTimeout = timeout }
  55. }
  56. // SocketPV implements PrivValidator, it uses a socket to request signatures
  57. // from an external process.
  58. type SocketPV struct {
  59. cmn.BaseService
  60. addr string
  61. acceptDeadline time.Duration
  62. connDeadline time.Duration
  63. connHeartbeat time.Duration
  64. connWaitTimeout time.Duration
  65. privKey ed25519.PrivKeyEd25519
  66. conn net.Conn
  67. listener net.Listener
  68. }
  69. // Check that SocketPV implements PrivValidator.
  70. var _ types.PrivValidator = (*SocketPV)(nil)
  71. // NewSocketPV returns an instance of SocketPV.
  72. func NewSocketPV(
  73. logger log.Logger,
  74. socketAddr string,
  75. privKey ed25519.PrivKeyEd25519,
  76. ) *SocketPV {
  77. sc := &SocketPV{
  78. addr: socketAddr,
  79. acceptDeadline: acceptDeadline,
  80. connDeadline: connDeadline,
  81. connHeartbeat: connHeartbeat,
  82. connWaitTimeout: time.Second * defaultConnWaitSeconds,
  83. privKey: privKey,
  84. }
  85. sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc)
  86. return sc
  87. }
  88. // GetAddress implements PrivValidator.
  89. func (sc *SocketPV) GetAddress() types.Address {
  90. addr, err := sc.getAddress()
  91. if err != nil {
  92. panic(err)
  93. }
  94. return addr
  95. }
  96. // Address is an alias for PubKey().Address().
  97. func (sc *SocketPV) getAddress() (cmn.HexBytes, error) {
  98. p, err := sc.getPubKey()
  99. if err != nil {
  100. return nil, err
  101. }
  102. return p.Address(), nil
  103. }
  104. // GetPubKey implements PrivValidator.
  105. func (sc *SocketPV) GetPubKey() crypto.PubKey {
  106. pubKey, err := sc.getPubKey()
  107. if err != nil {
  108. panic(err)
  109. }
  110. return pubKey
  111. }
  112. func (sc *SocketPV) getPubKey() (crypto.PubKey, error) {
  113. err := writeMsg(sc.conn, &PubKeyMsg{})
  114. if err != nil {
  115. return nil, err
  116. }
  117. res, err := readMsg(sc.conn)
  118. if err != nil {
  119. return nil, err
  120. }
  121. return res.(*PubKeyMsg).PubKey, nil
  122. }
  123. // SignVote implements PrivValidator.
  124. func (sc *SocketPV) SignVote(chainID string, vote *types.Vote) error {
  125. err := writeMsg(sc.conn, &SignVoteMsg{Vote: vote})
  126. if err != nil {
  127. return err
  128. }
  129. res, err := readMsg(sc.conn)
  130. if err != nil {
  131. return err
  132. }
  133. *vote = *res.(*SignVoteMsg).Vote
  134. return nil
  135. }
  136. // SignProposal implements PrivValidator.
  137. func (sc *SocketPV) SignProposal(
  138. chainID string,
  139. proposal *types.Proposal,
  140. ) error {
  141. err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal})
  142. if err != nil {
  143. return err
  144. }
  145. res, err := readMsg(sc.conn)
  146. if err != nil {
  147. return err
  148. }
  149. *proposal = *res.(*SignProposalMsg).Proposal
  150. return nil
  151. }
  152. // SignHeartbeat implements PrivValidator.
  153. func (sc *SocketPV) SignHeartbeat(
  154. chainID string,
  155. heartbeat *types.Heartbeat,
  156. ) error {
  157. err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
  158. if err != nil {
  159. return err
  160. }
  161. res, err := readMsg(sc.conn)
  162. if err != nil {
  163. return err
  164. }
  165. *heartbeat = *res.(*SignHeartbeatMsg).Heartbeat
  166. return nil
  167. }
  168. // OnStart implements cmn.Service.
  169. func (sc *SocketPV) OnStart() error {
  170. if err := sc.listen(); err != nil {
  171. err = cmn.ErrorWrap(err, "failed to listen")
  172. sc.Logger.Error(
  173. "OnStart",
  174. "err", err,
  175. )
  176. return err
  177. }
  178. conn, err := sc.waitConnection()
  179. if err != nil {
  180. err = cmn.ErrorWrap(err, "failed to accept connection")
  181. sc.Logger.Error(
  182. "OnStart",
  183. "err", err,
  184. )
  185. return err
  186. }
  187. sc.conn = conn
  188. return nil
  189. }
  190. // OnStop implements cmn.Service.
  191. func (sc *SocketPV) OnStop() {
  192. if sc.conn != nil {
  193. if err := sc.conn.Close(); err != nil {
  194. err = cmn.ErrorWrap(err, "failed to close connection")
  195. sc.Logger.Error(
  196. "OnStop",
  197. "err", err,
  198. )
  199. }
  200. }
  201. if sc.listener != nil {
  202. if err := sc.listener.Close(); err != nil {
  203. err = cmn.ErrorWrap(err, "failed to close listener")
  204. sc.Logger.Error(
  205. "OnStop",
  206. "err", err,
  207. )
  208. }
  209. }
  210. }
  211. func (sc *SocketPV) acceptConnection() (net.Conn, error) {
  212. conn, err := sc.listener.Accept()
  213. if err != nil {
  214. if !sc.IsRunning() {
  215. return nil, nil // Ignore error from listener closing.
  216. }
  217. return nil, err
  218. }
  219. conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey)
  220. if err != nil {
  221. return nil, err
  222. }
  223. return conn, nil
  224. }
  225. func (sc *SocketPV) listen() error {
  226. ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr))
  227. if err != nil {
  228. return err
  229. }
  230. sc.listener = newTCPTimeoutListener(
  231. ln,
  232. sc.acceptDeadline,
  233. sc.connDeadline,
  234. sc.connHeartbeat,
  235. )
  236. return nil
  237. }
  238. // waitConnection uses the configured wait timeout to error if no external
  239. // process connects in the time period.
  240. func (sc *SocketPV) waitConnection() (net.Conn, error) {
  241. var (
  242. connc = make(chan net.Conn, 1)
  243. errc = make(chan error, 1)
  244. )
  245. go func(connc chan<- net.Conn, errc chan<- error) {
  246. conn, err := sc.acceptConnection()
  247. if err != nil {
  248. errc <- err
  249. return
  250. }
  251. connc <- conn
  252. }(connc, errc)
  253. select {
  254. case conn := <-connc:
  255. return conn, nil
  256. case err := <-errc:
  257. if _, ok := err.(timeoutError); ok {
  258. return nil, cmn.ErrorWrap(ErrConnWaitTimeout, err.Error())
  259. }
  260. return nil, err
  261. case <-time.After(sc.connWaitTimeout):
  262. return nil, ErrConnWaitTimeout
  263. }
  264. }
  265. //---------------------------------------------------------
  266. // RemoteSignerOption sets an optional parameter on the RemoteSigner.
  267. type RemoteSignerOption func(*RemoteSigner)
  268. // RemoteSignerConnDeadline sets the read and write deadline for connections
  269. // from external signing processes.
  270. func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption {
  271. return func(ss *RemoteSigner) { ss.connDeadline = deadline }
  272. }
  273. // RemoteSignerConnRetries sets the amount of attempted retries to connect.
  274. func RemoteSignerConnRetries(retries int) RemoteSignerOption {
  275. return func(ss *RemoteSigner) { ss.connRetries = retries }
  276. }
  277. // RemoteSigner implements PrivValidator by dialing to a socket.
  278. type RemoteSigner struct {
  279. cmn.BaseService
  280. addr string
  281. chainID string
  282. connDeadline time.Duration
  283. connRetries int
  284. privKey ed25519.PrivKeyEd25519
  285. privVal types.PrivValidator
  286. conn net.Conn
  287. }
  288. // NewRemoteSigner returns an instance of RemoteSigner.
  289. func NewRemoteSigner(
  290. logger log.Logger,
  291. chainID, socketAddr string,
  292. privVal types.PrivValidator,
  293. privKey ed25519.PrivKeyEd25519,
  294. ) *RemoteSigner {
  295. rs := &RemoteSigner{
  296. addr: socketAddr,
  297. chainID: chainID,
  298. connDeadline: time.Second * defaultConnDeadlineSeconds,
  299. connRetries: defaultDialRetries,
  300. privKey: privKey,
  301. privVal: privVal,
  302. }
  303. rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs)
  304. return rs
  305. }
  306. // OnStart implements cmn.Service.
  307. func (rs *RemoteSigner) OnStart() error {
  308. conn, err := rs.connect()
  309. if err != nil {
  310. err = cmn.ErrorWrap(err, "connect")
  311. rs.Logger.Error("OnStart", "err", err)
  312. return err
  313. }
  314. go rs.handleConnection(conn)
  315. return nil
  316. }
  317. // OnStop implements cmn.Service.
  318. func (rs *RemoteSigner) OnStop() {
  319. if rs.conn == nil {
  320. return
  321. }
  322. if err := rs.conn.Close(); err != nil {
  323. rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed"))
  324. }
  325. }
  326. func (rs *RemoteSigner) connect() (net.Conn, error) {
  327. for retries := rs.connRetries; retries > 0; retries-- {
  328. // Don't sleep if it is the first retry.
  329. if retries != rs.connRetries {
  330. time.Sleep(rs.connDeadline)
  331. }
  332. conn, err := cmn.Connect(rs.addr)
  333. if err != nil {
  334. err = cmn.ErrorWrap(err, "connection failed")
  335. rs.Logger.Error(
  336. "connect",
  337. "addr", rs.addr,
  338. "err", err,
  339. )
  340. continue
  341. }
  342. if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
  343. err = cmn.ErrorWrap(err, "setting connection timeout failed")
  344. rs.Logger.Error(
  345. "connect",
  346. "err", err,
  347. )
  348. continue
  349. }
  350. conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey)
  351. if err != nil {
  352. err = cmn.ErrorWrap(err, "encrypting connection failed")
  353. rs.Logger.Error(
  354. "connect",
  355. "err", err,
  356. )
  357. continue
  358. }
  359. return conn, nil
  360. }
  361. return nil, ErrDialRetryMax
  362. }
  363. func (rs *RemoteSigner) handleConnection(conn net.Conn) {
  364. for {
  365. if !rs.IsRunning() {
  366. return // Ignore error from listener closing.
  367. }
  368. req, err := readMsg(conn)
  369. if err != nil {
  370. if err != io.EOF {
  371. rs.Logger.Error("handleConnection", "err", err)
  372. }
  373. return
  374. }
  375. var res SocketPVMsg
  376. switch r := req.(type) {
  377. case *PubKeyMsg:
  378. var p crypto.PubKey
  379. p = rs.privVal.GetPubKey()
  380. res = &PubKeyMsg{p}
  381. case *SignVoteMsg:
  382. err = rs.privVal.SignVote(rs.chainID, r.Vote)
  383. res = &SignVoteMsg{r.Vote}
  384. case *SignProposalMsg:
  385. err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
  386. res = &SignProposalMsg{r.Proposal}
  387. case *SignHeartbeatMsg:
  388. err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
  389. res = &SignHeartbeatMsg{r.Heartbeat}
  390. default:
  391. err = fmt.Errorf("unknown msg: %v", r)
  392. }
  393. if err != nil {
  394. rs.Logger.Error("handleConnection", "err", err)
  395. return
  396. }
  397. err = writeMsg(conn, res)
  398. if err != nil {
  399. rs.Logger.Error("handleConnection", "err", err)
  400. return
  401. }
  402. }
  403. }
  404. //---------------------------------------------------------
  405. // SocketPVMsg is sent between RemoteSigner and SocketPV.
  406. type SocketPVMsg interface{}
  407. func RegisterSocketPVMsg(cdc *amino.Codec) {
  408. cdc.RegisterInterface((*SocketPVMsg)(nil), nil)
  409. cdc.RegisterConcrete(&PubKeyMsg{}, "tendermint/socketpv/PubKeyMsg", nil)
  410. cdc.RegisterConcrete(&SignVoteMsg{}, "tendermint/socketpv/SignVoteMsg", nil)
  411. cdc.RegisterConcrete(&SignProposalMsg{}, "tendermint/socketpv/SignProposalMsg", nil)
  412. cdc.RegisterConcrete(&SignHeartbeatMsg{}, "tendermint/socketpv/SignHeartbeatMsg", nil)
  413. }
  414. // PubKeyMsg is a PrivValidatorSocket message containing the public key.
  415. type PubKeyMsg struct {
  416. PubKey crypto.PubKey
  417. }
  418. // SignVoteMsg is a PrivValidatorSocket message containing a vote.
  419. type SignVoteMsg struct {
  420. Vote *types.Vote
  421. }
  422. // SignProposalMsg is a PrivValidatorSocket message containing a Proposal.
  423. type SignProposalMsg struct {
  424. Proposal *types.Proposal
  425. }
  426. // SignHeartbeatMsg is a PrivValidatorSocket message containing a Heartbeat.
  427. type SignHeartbeatMsg struct {
  428. Heartbeat *types.Heartbeat
  429. }
  430. func readMsg(r io.Reader) (msg SocketPVMsg, err error) {
  431. const maxSocketPVMsgSize = 1024 * 10
  432. _, err = cdc.UnmarshalBinaryReader(r, &msg, maxSocketPVMsgSize)
  433. if _, ok := err.(timeoutError); ok {
  434. err = cmn.ErrorWrap(ErrConnTimeout, err.Error())
  435. }
  436. return
  437. }
  438. func writeMsg(w io.Writer, msg interface{}) (err error) {
  439. _, err = cdc.MarshalBinaryWriter(w, msg)
  440. if _, ok := err.(timeoutError); ok {
  441. err = cmn.ErrorWrap(ErrConnTimeout, err.Error())
  442. }
  443. return
  444. }