You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

538 lines
12 KiB

  1. package privval
  2. import (
  3. "errors"
  4. "fmt"
  5. "io"
  6. "net"
  7. "time"
  8. "github.com/tendermint/go-amino"
  9. "github.com/tendermint/tendermint/crypto"
  10. cmn "github.com/tendermint/tendermint/libs/common"
  11. "github.com/tendermint/tendermint/libs/log"
  12. p2pconn "github.com/tendermint/tendermint/p2p/conn"
  13. "github.com/tendermint/tendermint/types"
  14. )
  // Default timing and retry parameters. The *Seconds values are plain counts
  // of seconds and must be multiplied by time.Second to become durations.
  const (
  	defaultAcceptDeadlineSeconds = 3
  	defaultConnDeadlineSeconds   = 3
  	defaultConnHeartBeatSeconds  = 30
  	defaultConnWaitSeconds       = 60
  	defaultDialRetries           = 10
  )

  // Socket errors.
  var (
  	// ErrDialRetryMax is returned by RemoteSigner.connect when every dial
  	// attempt has failed.
  	ErrDialRetryMax = errors.New("dialed maximum retries")
  	// ErrConnWaitTimeout is returned by SocketPV.waitConnection when no
  	// remote signer connected in time.
  	ErrConnWaitTimeout = errors.New("waited for remote signer for too long")
  	// ErrConnTimeout is used by readMsg and writeMsg to wrap timeout errors.
  	ErrConnTimeout = errors.New("remote signer timed out")
  )

  // Default durations used by NewSocketPV.
  var (
  	// acceptDeadline is the default listener accept deadline. It is the
  	// number of seconds scaled by time.Second; adding the two instead would
  	// yield one second plus three nanoseconds.
  	acceptDeadline = time.Second * defaultAcceptDeadlineSeconds
  	// connDeadline is the default read/write deadline for connections.
  	connDeadline = time.Second * defaultConnDeadlineSeconds
  	// connHeartbeat is the default connection heartbeat period.
  	connHeartbeat = time.Second * defaultConnHeartBeatSeconds
  )
  33. // SocketPVOption sets an optional parameter on the SocketPV.
  34. type SocketPVOption func(*SocketPV)
  35. // SocketPVAcceptDeadline sets the deadline for the SocketPV listener.
  36. // A zero time value disables the deadline.
  37. func SocketPVAcceptDeadline(deadline time.Duration) SocketPVOption {
  38. return func(sc *SocketPV) { sc.acceptDeadline = deadline }
  39. }
  40. // SocketPVConnDeadline sets the read and write deadline for connections
  41. // from external signing processes.
  42. func SocketPVConnDeadline(deadline time.Duration) SocketPVOption {
  43. return func(sc *SocketPV) { sc.connDeadline = deadline }
  44. }
  45. // SocketPVHeartbeat sets the period on which to check the liveness of the
  46. // connected Signer connections.
  47. func SocketPVHeartbeat(period time.Duration) SocketPVOption {
  48. return func(sc *SocketPV) { sc.connHeartbeat = period }
  49. }
  50. // SocketPVConnWait sets the timeout duration before connection of external
  51. // signing processes are considered to be unsuccessful.
  52. func SocketPVConnWait(timeout time.Duration) SocketPVOption {
  53. return func(sc *SocketPV) { sc.connWaitTimeout = timeout }
  54. }
  55. // SocketPV implements PrivValidator, it uses a socket to request signatures
  56. // from an external process.
  57. type SocketPV struct {
  58. cmn.BaseService
  59. addr string
  60. acceptDeadline time.Duration
  61. connDeadline time.Duration
  62. connHeartbeat time.Duration
  63. connWaitTimeout time.Duration
  64. privKey crypto.PrivKeyEd25519
  65. conn net.Conn
  66. listener net.Listener
  67. }
  68. // Check that SocketPV implements PrivValidator.
  69. var _ types.PrivValidator = (*SocketPV)(nil)
  70. // NewSocketPV returns an instance of SocketPV.
  71. func NewSocketPV(
  72. logger log.Logger,
  73. socketAddr string,
  74. privKey crypto.PrivKeyEd25519,
  75. ) *SocketPV {
  76. sc := &SocketPV{
  77. addr: socketAddr,
  78. acceptDeadline: acceptDeadline,
  79. connDeadline: connDeadline,
  80. connHeartbeat: connHeartbeat,
  81. connWaitTimeout: time.Second * defaultConnWaitSeconds,
  82. privKey: privKey,
  83. }
  84. sc.BaseService = *cmn.NewBaseService(logger, "SocketPV", sc)
  85. return sc
  86. }
  87. // GetAddress implements PrivValidator.
  88. func (sc *SocketPV) GetAddress() types.Address {
  89. addr, err := sc.getAddress()
  90. if err != nil {
  91. panic(err)
  92. }
  93. return addr
  94. }
  95. // Address is an alias for PubKey().Address().
  96. func (sc *SocketPV) getAddress() (cmn.HexBytes, error) {
  97. p, err := sc.getPubKey()
  98. if err != nil {
  99. return nil, err
  100. }
  101. return p.Address(), nil
  102. }
  103. // GetPubKey implements PrivValidator.
  104. func (sc *SocketPV) GetPubKey() crypto.PubKey {
  105. pubKey, err := sc.getPubKey()
  106. if err != nil {
  107. panic(err)
  108. }
  109. return pubKey
  110. }
  111. func (sc *SocketPV) getPubKey() (crypto.PubKey, error) {
  112. err := writeMsg(sc.conn, &PubKeyMsg{})
  113. if err != nil {
  114. return nil, err
  115. }
  116. res, err := readMsg(sc.conn)
  117. if err != nil {
  118. return nil, err
  119. }
  120. return res.(*PubKeyMsg).PubKey, nil
  121. }
  122. // SignVote implements PrivValidator.
  123. func (sc *SocketPV) SignVote(chainID string, vote *types.Vote) error {
  124. err := writeMsg(sc.conn, &SignVoteMsg{Vote: vote})
  125. if err != nil {
  126. return err
  127. }
  128. res, err := readMsg(sc.conn)
  129. if err != nil {
  130. return err
  131. }
  132. *vote = *res.(*SignVoteMsg).Vote
  133. return nil
  134. }
  135. // SignProposal implements PrivValidator.
  136. func (sc *SocketPV) SignProposal(
  137. chainID string,
  138. proposal *types.Proposal,
  139. ) error {
  140. err := writeMsg(sc.conn, &SignProposalMsg{Proposal: proposal})
  141. if err != nil {
  142. return err
  143. }
  144. res, err := readMsg(sc.conn)
  145. if err != nil {
  146. return err
  147. }
  148. *proposal = *res.(*SignProposalMsg).Proposal
  149. return nil
  150. }
  151. // SignHeartbeat implements PrivValidator.
  152. func (sc *SocketPV) SignHeartbeat(
  153. chainID string,
  154. heartbeat *types.Heartbeat,
  155. ) error {
  156. err := writeMsg(sc.conn, &SignHeartbeatMsg{Heartbeat: heartbeat})
  157. if err != nil {
  158. return err
  159. }
  160. res, err := readMsg(sc.conn)
  161. if err != nil {
  162. return err
  163. }
  164. *heartbeat = *res.(*SignHeartbeatMsg).Heartbeat
  165. return nil
  166. }
  167. // OnStart implements cmn.Service.
  168. func (sc *SocketPV) OnStart() error {
  169. if err := sc.listen(); err != nil {
  170. err = cmn.ErrorWrap(err, "failed to listen")
  171. sc.Logger.Error(
  172. "OnStart",
  173. "err", err,
  174. )
  175. return err
  176. }
  177. conn, err := sc.waitConnection()
  178. if err != nil {
  179. err = cmn.ErrorWrap(err, "failed to accept connection")
  180. sc.Logger.Error(
  181. "OnStart",
  182. "err", err,
  183. )
  184. return err
  185. }
  186. sc.conn = conn
  187. return nil
  188. }
  189. // OnStop implements cmn.Service.
  190. func (sc *SocketPV) OnStop() {
  191. if sc.conn != nil {
  192. if err := sc.conn.Close(); err != nil {
  193. err = cmn.ErrorWrap(err, "failed to close connection")
  194. sc.Logger.Error(
  195. "OnStop",
  196. "err", err,
  197. )
  198. }
  199. }
  200. if sc.listener != nil {
  201. if err := sc.listener.Close(); err != nil {
  202. err = cmn.ErrorWrap(err, "failed to close listener")
  203. sc.Logger.Error(
  204. "OnStop",
  205. "err", err,
  206. )
  207. }
  208. }
  209. }
  210. func (sc *SocketPV) acceptConnection() (net.Conn, error) {
  211. conn, err := sc.listener.Accept()
  212. if err != nil {
  213. if !sc.IsRunning() {
  214. return nil, nil // Ignore error from listener closing.
  215. }
  216. return nil, err
  217. }
  218. conn, err = p2pconn.MakeSecretConnection(conn, sc.privKey)
  219. if err != nil {
  220. return nil, err
  221. }
  222. return conn, nil
  223. }
  224. func (sc *SocketPV) listen() error {
  225. ln, err := net.Listen(cmn.ProtocolAndAddress(sc.addr))
  226. if err != nil {
  227. return err
  228. }
  229. sc.listener = newTCPTimeoutListener(
  230. ln,
  231. sc.acceptDeadline,
  232. sc.connDeadline,
  233. sc.connHeartbeat,
  234. )
  235. return nil
  236. }
  237. // waitConnection uses the configured wait timeout to error if no external
  238. // process connects in the time period.
  239. func (sc *SocketPV) waitConnection() (net.Conn, error) {
  240. var (
  241. connc = make(chan net.Conn, 1)
  242. errc = make(chan error, 1)
  243. )
  244. go func(connc chan<- net.Conn, errc chan<- error) {
  245. conn, err := sc.acceptConnection()
  246. if err != nil {
  247. errc <- err
  248. return
  249. }
  250. connc <- conn
  251. }(connc, errc)
  252. select {
  253. case conn := <-connc:
  254. return conn, nil
  255. case err := <-errc:
  256. if _, ok := err.(timeoutError); ok {
  257. return nil, cmn.ErrorWrap(ErrConnWaitTimeout, err.Error())
  258. }
  259. return nil, err
  260. case <-time.After(sc.connWaitTimeout):
  261. return nil, ErrConnWaitTimeout
  262. }
  263. }
  264. //---------------------------------------------------------
  265. // RemoteSignerOption sets an optional parameter on the RemoteSigner.
  266. type RemoteSignerOption func(*RemoteSigner)
  267. // RemoteSignerConnDeadline sets the read and write deadline for connections
  268. // from external signing processes.
  269. func RemoteSignerConnDeadline(deadline time.Duration) RemoteSignerOption {
  270. return func(ss *RemoteSigner) { ss.connDeadline = deadline }
  271. }
  272. // RemoteSignerConnRetries sets the amount of attempted retries to connect.
  273. func RemoteSignerConnRetries(retries int) RemoteSignerOption {
  274. return func(ss *RemoteSigner) { ss.connRetries = retries }
  275. }
  276. // RemoteSigner implements PrivValidator by dialing to a socket.
  277. type RemoteSigner struct {
  278. cmn.BaseService
  279. addr string
  280. chainID string
  281. connDeadline time.Duration
  282. connRetries int
  283. privKey crypto.PrivKeyEd25519
  284. privVal types.PrivValidator
  285. conn net.Conn
  286. }
  287. // NewRemoteSigner returns an instance of RemoteSigner.
  288. func NewRemoteSigner(
  289. logger log.Logger,
  290. chainID, socketAddr string,
  291. privVal types.PrivValidator,
  292. privKey crypto.PrivKeyEd25519,
  293. ) *RemoteSigner {
  294. rs := &RemoteSigner{
  295. addr: socketAddr,
  296. chainID: chainID,
  297. connDeadline: time.Second * defaultConnDeadlineSeconds,
  298. connRetries: defaultDialRetries,
  299. privKey: privKey,
  300. privVal: privVal,
  301. }
  302. rs.BaseService = *cmn.NewBaseService(logger, "RemoteSigner", rs)
  303. return rs
  304. }
  305. // OnStart implements cmn.Service.
  306. func (rs *RemoteSigner) OnStart() error {
  307. conn, err := rs.connect()
  308. if err != nil {
  309. err = cmn.ErrorWrap(err, "connect")
  310. rs.Logger.Error("OnStart", "err", err)
  311. return err
  312. }
  313. go rs.handleConnection(conn)
  314. return nil
  315. }
  316. // OnStop implements cmn.Service.
  317. func (rs *RemoteSigner) OnStop() {
  318. if rs.conn == nil {
  319. return
  320. }
  321. if err := rs.conn.Close(); err != nil {
  322. rs.Logger.Error("OnStop", "err", cmn.ErrorWrap(err, "closing listener failed"))
  323. }
  324. }
  325. func (rs *RemoteSigner) connect() (net.Conn, error) {
  326. for retries := rs.connRetries; retries > 0; retries-- {
  327. // Don't sleep if it is the first retry.
  328. if retries != rs.connRetries {
  329. time.Sleep(rs.connDeadline)
  330. }
  331. conn, err := cmn.Connect(rs.addr)
  332. if err != nil {
  333. err = cmn.ErrorWrap(err, "connection failed")
  334. rs.Logger.Error(
  335. "connect",
  336. "addr", rs.addr,
  337. "err", err,
  338. )
  339. continue
  340. }
  341. if err := conn.SetDeadline(time.Now().Add(connDeadline)); err != nil {
  342. err = cmn.ErrorWrap(err, "setting connection timeout failed")
  343. rs.Logger.Error(
  344. "connect",
  345. "err", err,
  346. )
  347. continue
  348. }
  349. conn, err = p2pconn.MakeSecretConnection(conn, rs.privKey)
  350. if err != nil {
  351. err = cmn.ErrorWrap(err, "encrypting connection failed")
  352. rs.Logger.Error(
  353. "connect",
  354. "err", err,
  355. )
  356. continue
  357. }
  358. return conn, nil
  359. }
  360. return nil, ErrDialRetryMax
  361. }
  362. func (rs *RemoteSigner) handleConnection(conn net.Conn) {
  363. for {
  364. if !rs.IsRunning() {
  365. return // Ignore error from listener closing.
  366. }
  367. req, err := readMsg(conn)
  368. if err != nil {
  369. if err != io.EOF {
  370. rs.Logger.Error("handleConnection", "err", err)
  371. }
  372. return
  373. }
  374. var res SocketPVMsg
  375. switch r := req.(type) {
  376. case *PubKeyMsg:
  377. var p crypto.PubKey
  378. p = rs.privVal.GetPubKey()
  379. res = &PubKeyMsg{p}
  380. case *SignVoteMsg:
  381. err = rs.privVal.SignVote(rs.chainID, r.Vote)
  382. res = &SignVoteMsg{r.Vote}
  383. case *SignProposalMsg:
  384. err = rs.privVal.SignProposal(rs.chainID, r.Proposal)
  385. res = &SignProposalMsg{r.Proposal}
  386. case *SignHeartbeatMsg:
  387. err = rs.privVal.SignHeartbeat(rs.chainID, r.Heartbeat)
  388. res = &SignHeartbeatMsg{r.Heartbeat}
  389. default:
  390. err = fmt.Errorf("unknown msg: %v", r)
  391. }
  392. if err != nil {
  393. rs.Logger.Error("handleConnection", "err", err)
  394. return
  395. }
  396. err = writeMsg(conn, res)
  397. if err != nil {
  398. rs.Logger.Error("handleConnection", "err", err)
  399. return
  400. }
  401. }
  402. }
  403. //---------------------------------------------------------
  // SocketPVMsg is the interface type under which all messages exchanged
  // between RemoteSigner and SocketPV are registered and decoded.
  type SocketPVMsg interface{}
  406. func RegisterSocketPVMsg(cdc *amino.Codec) {
  407. cdc.RegisterInterface((*SocketPVMsg)(nil), nil)
  408. cdc.RegisterConcrete(&PubKeyMsg{}, "tendermint/socketpv/PubKeyMsg", nil)
  409. cdc.RegisterConcrete(&SignVoteMsg{}, "tendermint/socketpv/SignVoteMsg", nil)
  410. cdc.RegisterConcrete(&SignProposalMsg{}, "tendermint/socketpv/SignProposalMsg", nil)
  411. cdc.RegisterConcrete(&SignHeartbeatMsg{}, "tendermint/socketpv/SignHeartbeatMsg", nil)
  412. }
  413. // PubKeyMsg is a PrivValidatorSocket message containing the public key.
  414. type PubKeyMsg struct {
  415. PubKey crypto.PubKey
  416. }
  417. // SignVoteMsg is a PrivValidatorSocket message containing a vote.
  418. type SignVoteMsg struct {
  419. Vote *types.Vote
  420. }
  421. // SignProposalMsg is a PrivValidatorSocket message containing a Proposal.
  422. type SignProposalMsg struct {
  423. Proposal *types.Proposal
  424. }
  425. // SignHeartbeatMsg is a PrivValidatorSocket message containing a Heartbeat.
  426. type SignHeartbeatMsg struct {
  427. Heartbeat *types.Heartbeat
  428. }
  429. func readMsg(r io.Reader) (msg SocketPVMsg, err error) {
  430. const maxSocketPVMsgSize = 1024 * 10
  431. _, err = cdc.UnmarshalBinaryReader(r, &msg, maxSocketPVMsgSize)
  432. if _, ok := err.(timeoutError); ok {
  433. err = cmn.ErrorWrap(ErrConnTimeout, err.Error())
  434. }
  435. return
  436. }
  437. func writeMsg(w io.Writer, msg interface{}) (err error) {
  438. _, err = cdc.MarshalBinaryWriter(w, msg)
  439. if _, ok := err.(timeoutError); ok {
  440. err = cmn.ErrorWrap(ErrConnTimeout, err.Error())
  441. }
  442. return
  443. }