You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

459 lines
10 KiB

  1. package privval
  2. import (
  3. "fmt"
  4. "net"
  5. "testing"
  6. "time"
  7. "github.com/stretchr/testify/assert"
  8. "github.com/stretchr/testify/require"
  9. "github.com/tendermint/tendermint/crypto/ed25519"
  10. cmn "github.com/tendermint/tendermint/libs/common"
  11. "github.com/tendermint/tendermint/libs/log"
  12. p2pconn "github.com/tendermint/tendermint/p2p/conn"
  13. "github.com/tendermint/tendermint/types"
  14. )
  15. func TestSocketPVAddress(t *testing.T) {
  16. var (
  17. chainID = cmn.RandStr(12)
  18. sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV())
  19. )
  20. defer sc.Stop()
  21. defer rs.Stop()
  22. serverAddr := rs.privVal.GetAddress()
  23. clientAddr := sc.GetAddress()
  24. assert.Equal(t, serverAddr, clientAddr)
  25. // TODO(xla): Remove when PrivValidator2 replaced PrivValidator.
  26. assert.Equal(t, serverAddr, sc.GetAddress())
  27. }
  28. func TestSocketPVPubKey(t *testing.T) {
  29. var (
  30. chainID = cmn.RandStr(12)
  31. sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV())
  32. )
  33. defer sc.Stop()
  34. defer rs.Stop()
  35. clientKey, err := sc.getPubKey()
  36. require.NoError(t, err)
  37. privKey := rs.privVal.GetPubKey()
  38. assert.Equal(t, privKey, clientKey)
  39. // TODO(xla): Remove when PrivValidator2 replaced PrivValidator.
  40. assert.Equal(t, privKey, sc.GetPubKey())
  41. }
  42. func TestSocketPVProposal(t *testing.T) {
  43. var (
  44. chainID = cmn.RandStr(12)
  45. sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV())
  46. ts = time.Now()
  47. privProposal = &types.Proposal{Timestamp: ts}
  48. clientProposal = &types.Proposal{Timestamp: ts}
  49. )
  50. defer sc.Stop()
  51. defer rs.Stop()
  52. require.NoError(t, rs.privVal.SignProposal(chainID, privProposal))
  53. require.NoError(t, sc.SignProposal(chainID, clientProposal))
  54. assert.Equal(t, privProposal.Signature, clientProposal.Signature)
  55. }
  56. func TestSocketPVVote(t *testing.T) {
  57. var (
  58. chainID = cmn.RandStr(12)
  59. sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV())
  60. ts = time.Now()
  61. vType = types.PrecommitType
  62. want = &types.Vote{Timestamp: ts, Type: vType}
  63. have = &types.Vote{Timestamp: ts, Type: vType}
  64. )
  65. defer sc.Stop()
  66. defer rs.Stop()
  67. require.NoError(t, rs.privVal.SignVote(chainID, want))
  68. require.NoError(t, sc.SignVote(chainID, have))
  69. assert.Equal(t, want.Signature, have.Signature)
  70. }
  71. func TestSocketPVVoteResetDeadline(t *testing.T) {
  72. var (
  73. chainID = cmn.RandStr(12)
  74. sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV())
  75. ts = time.Now()
  76. vType = types.PrecommitType
  77. want = &types.Vote{Timestamp: ts, Type: vType}
  78. have = &types.Vote{Timestamp: ts, Type: vType}
  79. )
  80. defer sc.Stop()
  81. defer rs.Stop()
  82. time.Sleep(3 * time.Millisecond)
  83. require.NoError(t, rs.privVal.SignVote(chainID, want))
  84. require.NoError(t, sc.SignVote(chainID, have))
  85. assert.Equal(t, want.Signature, have.Signature)
  86. // This would exceed the deadline if it was not extended by the previous message
  87. time.Sleep(3 * time.Millisecond)
  88. require.NoError(t, rs.privVal.SignVote(chainID, want))
  89. require.NoError(t, sc.SignVote(chainID, have))
  90. assert.Equal(t, want.Signature, have.Signature)
  91. }
  92. func TestSocketPVVoteKeepalive(t *testing.T) {
  93. var (
  94. chainID = cmn.RandStr(12)
  95. sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV())
  96. ts = time.Now()
  97. vType = types.PrecommitType
  98. want = &types.Vote{Timestamp: ts, Type: vType}
  99. have = &types.Vote{Timestamp: ts, Type: vType}
  100. )
  101. defer sc.Stop()
  102. defer rs.Stop()
  103. time.Sleep(10 * time.Millisecond)
  104. require.NoError(t, rs.privVal.SignVote(chainID, want))
  105. require.NoError(t, sc.SignVote(chainID, have))
  106. assert.Equal(t, want.Signature, have.Signature)
  107. }
  108. func TestSocketPVHeartbeat(t *testing.T) {
  109. var (
  110. chainID = cmn.RandStr(12)
  111. sc, rs = testSetupSocketPair(t, chainID, types.NewMockPV())
  112. want = &types.Heartbeat{}
  113. have = &types.Heartbeat{}
  114. )
  115. defer sc.Stop()
  116. defer rs.Stop()
  117. require.NoError(t, rs.privVal.SignHeartbeat(chainID, want))
  118. require.NoError(t, sc.SignHeartbeat(chainID, have))
  119. assert.Equal(t, want.Signature, have.Signature)
  120. }
  121. func TestSocketPVDeadline(t *testing.T) {
  122. var (
  123. addr = testFreeAddr(t)
  124. listenc = make(chan struct{})
  125. sc = NewTCPVal(
  126. log.TestingLogger(),
  127. addr,
  128. ed25519.GenPrivKey(),
  129. )
  130. )
  131. TCPValConnTimeout(100 * time.Millisecond)(sc)
  132. go func(sc *TCPVal) {
  133. defer close(listenc)
  134. require.NoError(t, sc.Start())
  135. assert.True(t, sc.IsRunning())
  136. }(sc)
  137. for {
  138. conn, err := cmn.Connect(addr)
  139. if err != nil {
  140. continue
  141. }
  142. _, err = p2pconn.MakeSecretConnection(
  143. conn,
  144. ed25519.GenPrivKey(),
  145. )
  146. if err == nil {
  147. break
  148. }
  149. }
  150. <-listenc
  151. _, err := sc.getPubKey()
  152. assert.Equal(t, err.(cmn.Error).Data(), ErrConnTimeout)
  153. }
  154. func TestRemoteSignerRetry(t *testing.T) {
  155. var (
  156. attemptc = make(chan int)
  157. retries = 2
  158. )
  159. ln, err := net.Listen("tcp", "127.0.0.1:0")
  160. require.NoError(t, err)
  161. go func(ln net.Listener, attemptc chan<- int) {
  162. attempts := 0
  163. for {
  164. conn, err := ln.Accept()
  165. require.NoError(t, err)
  166. err = conn.Close()
  167. require.NoError(t, err)
  168. attempts++
  169. if attempts == retries {
  170. attemptc <- attempts
  171. break
  172. }
  173. }
  174. }(ln, attemptc)
  175. rs := NewRemoteSigner(
  176. log.TestingLogger(),
  177. cmn.RandStr(12),
  178. ln.Addr().String(),
  179. types.NewMockPV(),
  180. ed25519.GenPrivKey(),
  181. )
  182. defer rs.Stop()
  183. RemoteSignerConnDeadline(time.Millisecond)(rs)
  184. RemoteSignerConnRetries(retries)(rs)
  185. assert.Equal(t, rs.Start(), ErrDialRetryMax)
  186. select {
  187. case attempts := <-attemptc:
  188. assert.Equal(t, retries, attempts)
  189. case <-time.After(100 * time.Millisecond):
  190. t.Error("expected remote to observe connection attempts")
  191. }
  192. }
  193. func TestRemoteSignVoteErrors(t *testing.T) {
  194. var (
  195. chainID = cmn.RandStr(12)
  196. sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV())
  197. ts = time.Now()
  198. vType = types.PrecommitType
  199. vote = &types.Vote{Timestamp: ts, Type: vType}
  200. )
  201. defer sc.Stop()
  202. defer rs.Stop()
  203. err := writeMsg(sc.conn, &SignVoteRequest{Vote: vote})
  204. require.NoError(t, err)
  205. res, err := readMsg(sc.conn)
  206. require.NoError(t, err)
  207. resp := *res.(*SignedVoteResponse)
  208. require.NotNil(t, resp.Error)
  209. require.Equal(t, resp.Error.Description, types.ErroringMockPVErr.Error())
  210. err = rs.privVal.SignVote(chainID, vote)
  211. require.Error(t, err)
  212. err = sc.SignVote(chainID, vote)
  213. require.Error(t, err)
  214. }
  215. func TestRemoteSignProposalErrors(t *testing.T) {
  216. var (
  217. chainID = cmn.RandStr(12)
  218. sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV())
  219. ts = time.Now()
  220. proposal = &types.Proposal{Timestamp: ts}
  221. )
  222. defer sc.Stop()
  223. defer rs.Stop()
  224. err := writeMsg(sc.conn, &SignProposalRequest{Proposal: proposal})
  225. require.NoError(t, err)
  226. res, err := readMsg(sc.conn)
  227. require.NoError(t, err)
  228. resp := *res.(*SignedProposalResponse)
  229. require.NotNil(t, resp.Error)
  230. require.Equal(t, resp.Error.Description, types.ErroringMockPVErr.Error())
  231. err = rs.privVal.SignProposal(chainID, proposal)
  232. require.Error(t, err)
  233. err = sc.SignProposal(chainID, proposal)
  234. require.Error(t, err)
  235. }
  236. func TestRemoteSignHeartbeatErrors(t *testing.T) {
  237. var (
  238. chainID = cmn.RandStr(12)
  239. sc, rs = testSetupSocketPair(t, chainID, types.NewErroringMockPV())
  240. hb = &types.Heartbeat{}
  241. )
  242. defer sc.Stop()
  243. defer rs.Stop()
  244. err := writeMsg(sc.conn, &SignHeartbeatRequest{Heartbeat: hb})
  245. require.NoError(t, err)
  246. res, err := readMsg(sc.conn)
  247. require.NoError(t, err)
  248. resp := *res.(*SignedHeartbeatResponse)
  249. require.NotNil(t, resp.Error)
  250. require.Equal(t, resp.Error.Description, types.ErroringMockPVErr.Error())
  251. err = rs.privVal.SignHeartbeat(chainID, hb)
  252. require.Error(t, err)
  253. err = sc.SignHeartbeat(chainID, hb)
  254. require.Error(t, err)
  255. }
  256. func TestErrUnexpectedResponse(t *testing.T) {
  257. var (
  258. addr = testFreeAddr(t)
  259. logger = log.TestingLogger()
  260. chainID = cmn.RandStr(12)
  261. readyc = make(chan struct{})
  262. errc = make(chan error, 1)
  263. rs = NewRemoteSigner(
  264. logger,
  265. chainID,
  266. addr,
  267. types.NewMockPV(),
  268. ed25519.GenPrivKey(),
  269. )
  270. sc = NewTCPVal(
  271. logger,
  272. addr,
  273. ed25519.GenPrivKey(),
  274. )
  275. )
  276. testStartSocketPV(t, readyc, sc)
  277. defer sc.Stop()
  278. RemoteSignerConnDeadline(time.Millisecond)(rs)
  279. RemoteSignerConnRetries(1e6)(rs)
  280. // we do not want to Start() the remote signer here and instead use the connection to
  281. // reply with intentionally wrong replies below:
  282. rsConn, err := rs.connect()
  283. defer rsConn.Close()
  284. require.NoError(t, err)
  285. require.NotNil(t, rsConn)
  286. <-readyc
  287. // Heartbeat:
  288. go func(errc chan error) {
  289. errc <- sc.SignHeartbeat(chainID, &types.Heartbeat{})
  290. }(errc)
  291. // read request and write wrong response:
  292. go testReadWriteResponse(t, &SignedVoteResponse{}, rsConn)
  293. err = <-errc
  294. require.Error(t, err)
  295. require.Equal(t, err, ErrUnexpectedResponse)
  296. // Proposal:
  297. go func(errc chan error) {
  298. errc <- sc.SignProposal(chainID, &types.Proposal{})
  299. }(errc)
  300. // read request and write wrong response:
  301. go testReadWriteResponse(t, &SignedHeartbeatResponse{}, rsConn)
  302. err = <-errc
  303. require.Error(t, err)
  304. require.Equal(t, err, ErrUnexpectedResponse)
  305. // Vote:
  306. go func(errc chan error) {
  307. errc <- sc.SignVote(chainID, &types.Vote{})
  308. }(errc)
  309. // read request and write wrong response:
  310. go testReadWriteResponse(t, &SignedHeartbeatResponse{}, rsConn)
  311. err = <-errc
  312. require.Error(t, err)
  313. require.Equal(t, err, ErrUnexpectedResponse)
  314. }
  315. func testSetupSocketPair(
  316. t *testing.T,
  317. chainID string,
  318. privValidator types.PrivValidator,
  319. ) (*TCPVal, *RemoteSigner) {
  320. var (
  321. addr = testFreeAddr(t)
  322. logger = log.TestingLogger()
  323. privVal = privValidator
  324. readyc = make(chan struct{})
  325. rs = NewRemoteSigner(
  326. logger,
  327. chainID,
  328. addr,
  329. privVal,
  330. ed25519.GenPrivKey(),
  331. )
  332. sc = NewTCPVal(
  333. logger,
  334. addr,
  335. ed25519.GenPrivKey(),
  336. )
  337. )
  338. TCPValConnTimeout(5 * time.Millisecond)(sc)
  339. TCPValHeartbeat(2 * time.Millisecond)(sc)
  340. RemoteSignerConnDeadline(5 * time.Millisecond)(rs)
  341. RemoteSignerConnRetries(1e6)(rs)
  342. testStartSocketPV(t, readyc, sc)
  343. require.NoError(t, rs.Start())
  344. assert.True(t, rs.IsRunning())
  345. <-readyc
  346. return sc, rs
  347. }
  348. func testReadWriteResponse(t *testing.T, resp RemoteSignerMsg, rsConn net.Conn) {
  349. _, err := readMsg(rsConn)
  350. require.NoError(t, err)
  351. err = writeMsg(rsConn, resp)
  352. require.NoError(t, err)
  353. }
  354. func testStartSocketPV(t *testing.T, readyc chan struct{}, sc *TCPVal) {
  355. go func(sc *TCPVal) {
  356. require.NoError(t, sc.Start())
  357. assert.True(t, sc.IsRunning())
  358. readyc <- struct{}{}
  359. }(sc)
  360. }
  361. // testFreeAddr claims a free port so we don't block on listener being ready.
  362. func testFreeAddr(t *testing.T) string {
  363. ln, err := net.Listen("tcp", "127.0.0.1:0")
  364. require.NoError(t, err)
  365. defer ln.Close()
  366. return fmt.Sprintf("127.0.0.1:%d", ln.Addr().(*net.TCPAddr).Port)
  367. }