diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 000000000..cae2f4c9f
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,78 @@
+# Changelog
+
+## 0.5.0 (April 21, 2017)
+
+BREAKING CHANGES:
+
+- Remove or unexport methods from FuzzedConnection: Active, Mode, ProbDropRW, ProbDropConn, ProbSleep, MaxDelayMilliseconds, Fuzz
+- switch.AddPeerWithConnection is unexported and replaced by switch.AddPeer
+- switch.DialPeerWithAddress takes a bool, setting the peer as persistent or not
+
+FEATURES:
+
+- Persistent peers: any peer considered a "seed" will be reconnected to when the connection is dropped
+
+
+IMPROVEMENTS:
+
+- Many more tests and comments
+- Refactor configurations for less dependence on go-config. Introduces new structs PeerConfig, MConnConfig, FuzzConnConfig
+- New methods on peer: CloseConn, HandshakeTimeout, IsPersistent, Addr, PubKey
+- NewNetAddress supports a testing mode where the address defaults to 0.0.0.0:0
+
+
+## 0.4.0 (March 6, 2017)
+
+BREAKING CHANGES:
+
+- DialSeeds now takes an AddrBook and returns an error: `DialSeeds(*AddrBook, []string) error`
+- NewNetAddressString now returns an error: `NewNetAddressString(string) (*NetAddress, error)`
+
+FEATURES:
+
+- `NewNetAddressStrings([]string) ([]*NetAddress, error)`
+- `AddrBook.Save()`
+
+IMPROVEMENTS:
+
+- PexReactor responsible for starting and stopping the AddrBook
+
+BUG FIXES:
+
+- DialSeeds returns an error instead of panicking on bad addresses
+
+## 0.3.5 (January 12, 2017)
+
+FEATURES
+
+- Toggle strict routability in the AddrBook
+
+BUG FIXES
+
+- Close filtered out connections
+- Fixes for MakeConnectedSwitches and Connect2Switches
+
+## 0.3.4 (August 10, 2016)
+
+FEATURES:
+
+- Optionally filter connections by address or public key
+
+## 0.3.3 (May 12, 2016)
+
+FEATURES:
+
+- FuzzConn
+
+## 0.3.2 (March 12, 2016)
+
+IMPROVEMENTS:
+
+- Memory optimizations
+
+## 0.3.1 ()
+
+FEATURES:
+
+- Configurable parameters
+
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 000000000..3716185f2
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,13 @@
+FROM golang:latest
+
+RUN curl https://glide.sh/get | sh
+
+RUN mkdir -p /go/src/github.com/tendermint/go-p2p
+WORKDIR /go/src/github.com/tendermint/go-p2p
+
+COPY glide.yaml /go/src/github.com/tendermint/go-p2p/
+COPY glide.lock /go/src/github.com/tendermint/go-p2p/
+
+RUN glide install
+
+COPY . /go/src/github.com/tendermint/go-p2p
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 000000000..e908e0f95
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,193 @@
+Tendermint Go-P2P
+Copyright (C) 2015 Tendermint
+
+
+
+ Apache License
+ Version 2.0, January 2004
+ https://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 000000000..236efd367
--- /dev/null
+++ b/README.md
@@ -0,0 +1,79 @@
+# `tendermint/go-p2p`
+
+[![CircleCI](https://circleci.com/gh/tendermint/go-p2p.svg?style=svg)](https://circleci.com/gh/tendermint/go-p2p)
+
+`tendermint/go-p2p` provides an abstraction around peer-to-peer communication.
+
+## Peer/MConnection/Channel
+
+Each peer has one `MConnection` (multiplex connection) instance.
+
+__multiplex__ *noun* a system or signal involving simultaneous transmission of
+several messages along a single channel of communication.
+
+Each `MConnection` handles message transmission on multiple abstract communication
+`Channel`s. Each channel has a globally unique byte id.
+The byte id and the relative priorities of each `Channel` are configured upon
+initialization of the connection.
+
+There are two methods for sending messages:
+```go
+func (m MConnection) Send(chID byte, msg interface{}) bool {}
+func (m MConnection) TrySend(chID byte, msg interface{}) bool {}
+```
+
+`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued
+for the channel with the given id byte `chID`. The message `msg` is serialized
+using the `tendermint/wire` submodule's `WriteBinary()` reflection routine.
+
+`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's
+queue is full.
+
+`Send()` and `TrySend()` are also exposed for each `Peer`.
+
+## Switch/Reactor
+
+The `Switch` handles peer connections and exposes an API to receive incoming messages
+on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one
+or more `Channels`. So while sending outgoing messages is typically performed on the peer,
+incoming messages are received on the reactor.
+
+```go
+// Declare a MyReactor reactor that handles messages on MyChannelID.
+type MyReactor struct{}
+
+func (reactor MyReactor) GetChannels() []*ChannelDescriptor {
+ return []*ChannelDescriptor{ChannelDescriptor{ID:MyChannelID, Priority: 1}}
+}
+
+func (reactor MyReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {
+ r, n, err := bytes.NewBuffer(msgBytes), new(int64), new(error)
+ msgString := ReadString(r, n, err)
+ fmt.Println(msgString)
+}
+
+// Other Reactor methods omitted for brevity
+...
+
+switch := NewSwitch([]Reactor{MyReactor{}})
+
+...
+
+// Send a random message to all outbound connections
+for _, peer := range switch.Peers().List() {
+ if peer.IsOutbound() {
+ peer.Send(MyChannelID, "Here's a random message")
+ }
+}
+```
+
+### PexReactor/AddrBook
+
+A `PEXReactor` reactor implementation is provided to automate peer discovery.
+
+```go
+book := p2p.NewAddrBook(addrBookFilePath)
+pexReactor := p2p.NewPEXReactor(book)
+...
+switch := NewSwitch([]Reactor{pexReactor, myReactor, ...})
+```
diff --git a/addrbook.go b/addrbook.go
new file mode 100644
index 000000000..e68cc7b3a
--- /dev/null
+++ b/addrbook.go
@@ -0,0 +1,839 @@
+// Modified for Tendermint
+// Originally Copyright (c) 2013-2014 Conformal Systems LLC.
+// https://github.com/conformal/btcd/blob/master/LICENSE
+
+package p2p
+
+import (
+ "encoding/binary"
+ "encoding/json"
+ "math"
+ "math/rand"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ . "github.com/tendermint/tmlibs/common"
+ crypto "github.com/tendermint/go-crypto"
+)
+
+const (
+ // addresses under which the address manager will claim to need more addresses.
+ needAddressThreshold = 1000
+
+ // interval used to dump the address cache to disk for future use.
+ dumpAddressInterval = time.Minute * 2
+
+ // max addresses in each old address bucket.
+ oldBucketSize = 64
+
+ // buckets we split old addresses over.
+ oldBucketCount = 64
+
+ // max addresses in each new address bucket.
+ newBucketSize = 64
+
+ // buckets that we spread new addresses over.
+ newBucketCount = 256
+
+ // old buckets over which an address group will be spread.
+ oldBucketsPerGroup = 4
+
+ // new buckets over which an source address group will be spread.
+ newBucketsPerGroup = 32
+
+ // buckets a frequently seen new address may end up in.
+ maxNewBucketsPerAddress = 4
+
+ // days before which we assume an address has vanished
+ // if we have not seen it announced in that long.
+ numMissingDays = 30
+
+ // tries without a single success before we assume an address is bad.
+ numRetries = 3
+
+ // max failures we will accept without a success before considering an address bad.
+ maxFailures = 10
+
+ // days since the last success before we will consider evicting an address.
+ minBadDays = 7
+
+ // % of total addresses known returned by GetSelection.
+ getSelectionPercent = 23
+
+ // min addresses that must be returned by GetSelection. Useful for bootstrapping.
+ minGetSelection = 32
+
+ // max addresses returned by GetSelection
+ // NOTE: this must match "maxPexMessageSize"
+ maxGetSelection = 250
+
+ // current version of the on-disk format.
+ serializationVersion = 1
+)
+
+const (
+ bucketTypeNew = 0x01
+ bucketTypeOld = 0x02
+)
+
+// AddrBook - concurrency safe peer address manager.
+type AddrBook struct {
+ BaseService
+
+ mtx sync.Mutex
+ filePath string
+ routabilityStrict bool
+ rand *rand.Rand
+ key string
+ ourAddrs map[string]*NetAddress
+ addrLookup map[string]*knownAddress // new & old
+ addrNew []map[string]*knownAddress
+ addrOld []map[string]*knownAddress
+ wg sync.WaitGroup
+ nOld int
+ nNew int
+}
+
+// NewAddrBook creates a new address book.
+// Use Start to begin processing asynchronous address updates.
+func NewAddrBook(filePath string, routabilityStrict bool) *AddrBook {
+ am := &AddrBook{
+ rand: rand.New(rand.NewSource(time.Now().UnixNano())),
+ ourAddrs: make(map[string]*NetAddress),
+ addrLookup: make(map[string]*knownAddress),
+ filePath: filePath,
+ routabilityStrict: routabilityStrict,
+ }
+ am.init()
+ am.BaseService = *NewBaseService(log, "AddrBook", am)
+ return am
+}
+
+// When modifying this, don't forget to update loadFromFile()
+func (a *AddrBook) init() {
+ a.key = crypto.CRandHex(24) // 24/2 * 8 = 96 bits
+ // New addr buckets
+ a.addrNew = make([]map[string]*knownAddress, newBucketCount)
+ for i := range a.addrNew {
+ a.addrNew[i] = make(map[string]*knownAddress)
+ }
+ // Old addr buckets
+ a.addrOld = make([]map[string]*knownAddress, oldBucketCount)
+ for i := range a.addrOld {
+ a.addrOld[i] = make(map[string]*knownAddress)
+ }
+}
+
+// OnStart implements Service.
+func (a *AddrBook) OnStart() error {
+ a.BaseService.OnStart()
+ a.loadFromFile(a.filePath)
+ a.wg.Add(1)
+ go a.saveRoutine()
+ return nil
+}
+
+// OnStop implements Service.
+func (a *AddrBook) OnStop() {
+ a.BaseService.OnStop()
+}
+
+func (a *AddrBook) Wait() {
+ a.wg.Wait()
+}
+
+func (a *AddrBook) AddOurAddress(addr *NetAddress) {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+ log.Info("Add our address to book", "addr", addr)
+ a.ourAddrs[addr.String()] = addr
+}
+
+func (a *AddrBook) OurAddresses() []*NetAddress {
+ addrs := []*NetAddress{}
+ for _, addr := range a.ourAddrs {
+ addrs = append(addrs, addr)
+ }
+ return addrs
+}
+
+// NOTE: addr must not be nil
+func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+ log.Info("Add address to book", "addr", addr, "src", src)
+ a.addAddress(addr, src)
+}
+
+func (a *AddrBook) NeedMoreAddrs() bool {
+ return a.Size() < needAddressThreshold
+}
+
+func (a *AddrBook) Size() int {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+ return a.size()
+}
+
+func (a *AddrBook) size() int {
+ return a.nNew + a.nOld
+}
+
+// Pick an address to connect to with new/old bias.
+func (a *AddrBook) PickAddress(newBias int) *NetAddress {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+
+ if a.size() == 0 {
+ return nil
+ }
+ if newBias > 100 {
+ newBias = 100
+ }
+ if newBias < 0 {
+ newBias = 0
+ }
+
+ // Bias between new and old addresses.
+ oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias))
+ newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias)
+
+ if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation {
+ // pick random Old bucket.
+ var bucket map[string]*knownAddress = nil
+ for len(bucket) == 0 {
+ bucket = a.addrOld[a.rand.Intn(len(a.addrOld))]
+ }
+ // pick a random ka from bucket.
+ randIndex := a.rand.Intn(len(bucket))
+ for _, ka := range bucket {
+ if randIndex == 0 {
+ return ka.Addr
+ }
+ randIndex--
+ }
+ PanicSanity("Should not happen")
+ } else {
+ // pick random New bucket.
+ var bucket map[string]*knownAddress = nil
+ for len(bucket) == 0 {
+ bucket = a.addrNew[a.rand.Intn(len(a.addrNew))]
+ }
+ // pick a random ka from bucket.
+ randIndex := a.rand.Intn(len(bucket))
+ for _, ka := range bucket {
+ if randIndex == 0 {
+ return ka.Addr
+ }
+ randIndex--
+ }
+ PanicSanity("Should not happen")
+ }
+ return nil
+}
+
+func (a *AddrBook) MarkGood(addr *NetAddress) {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+ ka := a.addrLookup[addr.String()]
+ if ka == nil {
+ return
+ }
+ ka.markGood()
+ if ka.isNew() {
+ a.moveToOld(ka)
+ }
+}
+
+func (a *AddrBook) MarkAttempt(addr *NetAddress) {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+ ka := a.addrLookup[addr.String()]
+ if ka == nil {
+ return
+ }
+ ka.markAttempt()
+}
+
+// MarkBad currently just ejects the address. In the future, consider
+// blacklisting.
+func (a *AddrBook) MarkBad(addr *NetAddress) {
+ a.RemoveAddress(addr)
+}
+
+// RemoveAddress removes the address from the book.
+func (a *AddrBook) RemoveAddress(addr *NetAddress) {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+ ka := a.addrLookup[addr.String()]
+ if ka == nil {
+ return
+ }
+ log.Info("Remove address from book", "addr", addr)
+ a.removeFromAllBuckets(ka)
+}
+
+/* Peer exchange */
+
+// GetSelection randomly selects some addresses (old & new). Suitable for peer-exchange protocols.
+func (a *AddrBook) GetSelection() []*NetAddress {
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+
+ if a.size() == 0 {
+ return nil
+ }
+
+ allAddr := make([]*NetAddress, a.size())
+ i := 0
+ for _, v := range a.addrLookup {
+ allAddr[i] = v.Addr
+ i++
+ }
+
+ numAddresses := MaxInt(
+ MinInt(minGetSelection, len(allAddr)),
+ len(allAddr)*getSelectionPercent/100)
+ numAddresses = MinInt(maxGetSelection, numAddresses)
+
+ // Fisher-Yates shuffle the array. We only need to do the first
+ // `numAddresses' since we are throwing the rest.
+ for i := 0; i < numAddresses; i++ {
+ // pick a number between current index and the end
+ j := rand.Intn(len(allAddr)-i) + i
+ allAddr[i], allAddr[j] = allAddr[j], allAddr[i]
+ }
+
+ // slice off the limit we are willing to share.
+ return allAddr[:numAddresses]
+}
+
+/* Loading & Saving */
+
+type addrBookJSON struct {
+ Key string
+ Addrs []*knownAddress
+}
+
+func (a *AddrBook) saveToFile(filePath string) {
+ log.Info("Saving AddrBook to file", "size", a.Size())
+
+ a.mtx.Lock()
+ defer a.mtx.Unlock()
+ // Compile Addrs
+ addrs := []*knownAddress{}
+ for _, ka := range a.addrLookup {
+ addrs = append(addrs, ka)
+ }
+
+ aJSON := &addrBookJSON{
+ Key: a.key,
+ Addrs: addrs,
+ }
+
+ jsonBytes, err := json.MarshalIndent(aJSON, "", "\t")
+ if err != nil {
+ log.Error("Failed to save AddrBook to file", "err", err)
+ return
+ }
+ err = WriteFileAtomic(filePath, jsonBytes, 0644)
+ if err != nil {
+ log.Error("Failed to save AddrBook to file", "file", filePath, "error", err)
+ }
+}
+
+// Returns false if file does not exist.
+// Panics if file is corrupt.
+func (a *AddrBook) loadFromFile(filePath string) bool {
+ // If doesn't exist, do nothing.
+ _, err := os.Stat(filePath)
+ if os.IsNotExist(err) {
+ return false
+ }
+
+ // Load addrBookJSON{}
+ r, err := os.Open(filePath)
+ if err != nil {
+ PanicCrisis(Fmt("Error opening file %s: %v", filePath, err))
+ }
+ defer r.Close()
+ aJSON := &addrBookJSON{}
+ dec := json.NewDecoder(r)
+ err = dec.Decode(aJSON)
+ if err != nil {
+ PanicCrisis(Fmt("Error reading file %s: %v", filePath, err))
+ }
+
+ // Restore all the fields...
+ // Restore the key
+ a.key = aJSON.Key
+ // Restore .addrNew & .addrOld
+ for _, ka := range aJSON.Addrs {
+ for _, bucketIndex := range ka.Buckets {
+ bucket := a.getBucket(ka.BucketType, bucketIndex)
+ bucket[ka.Addr.String()] = ka
+ }
+ a.addrLookup[ka.Addr.String()] = ka
+ if ka.BucketType == bucketTypeNew {
+ a.nNew++
+ } else {
+ a.nOld++
+ }
+ }
+ return true
+}
+
+// Save saves the book.
+func (a *AddrBook) Save() {
+ log.Info("Saving AddrBook to file", "size", a.Size())
+ a.saveToFile(a.filePath)
+}
+
+/* Private methods */
+
+func (a *AddrBook) saveRoutine() {
+ dumpAddressTicker := time.NewTicker(dumpAddressInterval)
+out:
+ for {
+ select {
+ case <-dumpAddressTicker.C:
+ a.saveToFile(a.filePath)
+ case <-a.Quit:
+ break out
+ }
+ }
+ dumpAddressTicker.Stop()
+ a.saveToFile(a.filePath)
+ a.wg.Done()
+ log.Notice("Address handler done")
+}
+
+func (a *AddrBook) getBucket(bucketType byte, bucketIdx int) map[string]*knownAddress {
+ switch bucketType {
+ case bucketTypeNew:
+ return a.addrNew[bucketIdx]
+ case bucketTypeOld:
+ return a.addrOld[bucketIdx]
+ default:
+ PanicSanity("Should not happen")
+ return nil
+ }
+}
+
+// Adds ka to new bucket. Returns false if it couldn't do it cuz buckets full.
+// NOTE: currently it always returns true.
+func (a *AddrBook) addToNewBucket(ka *knownAddress, bucketIdx int) bool {
+ // Sanity check
+ if ka.isOld() {
+ log.Warn(Fmt("Cannot add address already in old bucket to a new bucket: %v", ka))
+ return false
+ }
+
+ addrStr := ka.Addr.String()
+ bucket := a.getBucket(bucketTypeNew, bucketIdx)
+
+ // Already exists?
+ if _, ok := bucket[addrStr]; ok {
+ return true
+ }
+
+ // Enforce max addresses.
+ if len(bucket) > newBucketSize {
+ log.Notice("new bucket is full, expiring old ")
+ a.expireNew(bucketIdx)
+ }
+
+ // Add to bucket.
+ bucket[addrStr] = ka
+ if ka.addBucketRef(bucketIdx) == 1 {
+ a.nNew++
+ }
+
+ // Ensure in addrLookup
+ a.addrLookup[addrStr] = ka
+
+ return true
+}
+
+// Adds ka to old bucket. Returns false if it couldn't do it cuz buckets full.
+func (a *AddrBook) addToOldBucket(ka *knownAddress, bucketIdx int) bool {
+ // Sanity check
+ if ka.isNew() {
+ log.Warn(Fmt("Cannot add new address to old bucket: %v", ka))
+ return false
+ }
+ if len(ka.Buckets) != 0 {
+ log.Warn(Fmt("Cannot add already old address to another old bucket: %v", ka))
+ return false
+ }
+
+ addrStr := ka.Addr.String()
+ bucket := a.getBucket(bucketTypeNew, bucketIdx)
+
+ // Already exists?
+ if _, ok := bucket[addrStr]; ok {
+ return true
+ }
+
+ // Enforce max addresses.
+ if len(bucket) > oldBucketSize {
+ return false
+ }
+
+ // Add to bucket.
+ bucket[addrStr] = ka
+ if ka.addBucketRef(bucketIdx) == 1 {
+ a.nOld++
+ }
+
+ // Ensure in addrLookup
+ a.addrLookup[addrStr] = ka
+
+ return true
+}
+
+func (a *AddrBook) removeFromBucket(ka *knownAddress, bucketType byte, bucketIdx int) {
+ if ka.BucketType != bucketType {
+ log.Warn(Fmt("Bucket type mismatch: %v", ka))
+ return
+ }
+ bucket := a.getBucket(bucketType, bucketIdx)
+ delete(bucket, ka.Addr.String())
+ if ka.removeBucketRef(bucketIdx) == 0 {
+ if bucketType == bucketTypeNew {
+ a.nNew--
+ } else {
+ a.nOld--
+ }
+ delete(a.addrLookup, ka.Addr.String())
+ }
+}
+
+func (a *AddrBook) removeFromAllBuckets(ka *knownAddress) {
+ for _, bucketIdx := range ka.Buckets {
+ bucket := a.getBucket(ka.BucketType, bucketIdx)
+ delete(bucket, ka.Addr.String())
+ }
+ ka.Buckets = nil
+ if ka.BucketType == bucketTypeNew {
+ a.nNew--
+ } else {
+ a.nOld--
+ }
+ delete(a.addrLookup, ka.Addr.String())
+}
+
+func (a *AddrBook) pickOldest(bucketType byte, bucketIdx int) *knownAddress {
+ bucket := a.getBucket(bucketType, bucketIdx)
+ var oldest *knownAddress
+ for _, ka := range bucket {
+ if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt) {
+ oldest = ka
+ }
+ }
+ return oldest
+}
+
+func (a *AddrBook) addAddress(addr, src *NetAddress) {
+ if a.routabilityStrict && !addr.Routable() {
+ log.Warn(Fmt("Cannot add non-routable address %v", addr))
+ return
+ }
+ if _, ok := a.ourAddrs[addr.String()]; ok {
+ // Ignore our own listener address.
+ return
+ }
+
+ ka := a.addrLookup[addr.String()]
+
+ if ka != nil {
+ // Already old.
+ if ka.isOld() {
+ return
+ }
+ // Already in max new buckets.
+ if len(ka.Buckets) == maxNewBucketsPerAddress {
+ return
+ }
+ // The more entries we have, the less likely we are to add more.
+ factor := int32(2 * len(ka.Buckets))
+ if a.rand.Int31n(factor) != 0 {
+ return
+ }
+ } else {
+ ka = newKnownAddress(addr, src)
+ }
+
+ bucket := a.calcNewBucket(addr, src)
+ a.addToNewBucket(ka, bucket)
+
+ log.Notice("Added new address", "address", addr, "total", a.size())
+}
+
+// Make space in the new buckets by expiring the really bad entries.
+// If no bad entries are available we remove the oldest.
+func (a *AddrBook) expireNew(bucketIdx int) {
+ for addrStr, ka := range a.addrNew[bucketIdx] {
+ // If an entry is bad, throw it away
+ if ka.isBad() {
+ log.Notice(Fmt("expiring bad address %v", addrStr))
+ a.removeFromBucket(ka, bucketTypeNew, bucketIdx)
+ return
+ }
+ }
+
+ // If we haven't thrown out a bad entry, throw out the oldest entry
+ oldest := a.pickOldest(bucketTypeNew, bucketIdx)
+ a.removeFromBucket(oldest, bucketTypeNew, bucketIdx)
+}
+
+// Promotes an address from new to old.
+// TODO: Move to old probabilistically.
+// The better a node is, the less likely it should be evicted from an old bucket.
+func (a *AddrBook) moveToOld(ka *knownAddress) {
+ // Sanity check
+ if ka.isOld() {
+ log.Warn(Fmt("Cannot promote address that is already old %v", ka))
+ return
+ }
+ if len(ka.Buckets) == 0 {
+ log.Warn(Fmt("Cannot promote address that isn't in any new buckets %v", ka))
+ return
+ }
+
+ // Remember one of the buckets in which ka is in.
+ freedBucket := ka.Buckets[0]
+ // Remove from all (new) buckets.
+ a.removeFromAllBuckets(ka)
+ // It's officially old now.
+ ka.BucketType = bucketTypeOld
+
+ // Try to add it to its oldBucket destination.
+ oldBucketIdx := a.calcOldBucket(ka.Addr)
+ added := a.addToOldBucket(ka, oldBucketIdx)
+ if !added {
+ // No room, must evict something
+ oldest := a.pickOldest(bucketTypeOld, oldBucketIdx)
+ a.removeFromBucket(oldest, bucketTypeOld, oldBucketIdx)
+ // Find new bucket to put oldest in
+ newBucketIdx := a.calcNewBucket(oldest.Addr, oldest.Src)
+ added := a.addToNewBucket(oldest, newBucketIdx)
+ // No space in newBucket either, just put it in freedBucket from above.
+ if !added {
+ added := a.addToNewBucket(oldest, freedBucket)
+ if !added {
+ log.Warn(Fmt("Could not migrate oldest %v to freedBucket %v", oldest, freedBucket))
+ }
+ }
+ // Finally, add to bucket again.
+ added = a.addToOldBucket(ka, oldBucketIdx)
+ if !added {
+ log.Warn(Fmt("Could not re-add ka %v to oldBucketIdx %v", ka, oldBucketIdx))
+ }
+ }
+}
+
+// doublesha256( key + sourcegroup +
+// int64(doublesha256(key + group + sourcegroup))%bucket_per_group ) % num_new_buckets
+func (a *AddrBook) calcNewBucket(addr, src *NetAddress) int {
+ data1 := []byte{}
+ data1 = append(data1, []byte(a.key)...)
+ data1 = append(data1, []byte(a.groupKey(addr))...)
+ data1 = append(data1, []byte(a.groupKey(src))...)
+ hash1 := doubleSha256(data1)
+ hash64 := binary.BigEndian.Uint64(hash1)
+ hash64 %= newBucketsPerGroup
+ var hashbuf [8]byte
+ binary.BigEndian.PutUint64(hashbuf[:], hash64)
+ data2 := []byte{}
+ data2 = append(data2, []byte(a.key)...)
+ data2 = append(data2, a.groupKey(src)...)
+ data2 = append(data2, hashbuf[:]...)
+
+ hash2 := doubleSha256(data2)
+ return int(binary.BigEndian.Uint64(hash2) % newBucketCount)
+}
+
+// doublesha256( key + group +
+// int64(doublesha256(key + addr))%buckets_per_group ) % num_old_buckets
+func (a *AddrBook) calcOldBucket(addr *NetAddress) int {
+ data1 := []byte{}
+ data1 = append(data1, []byte(a.key)...)
+ data1 = append(data1, []byte(addr.String())...)
+ hash1 := doubleSha256(data1)
+ hash64 := binary.BigEndian.Uint64(hash1)
+ hash64 %= oldBucketsPerGroup
+ var hashbuf [8]byte
+ binary.BigEndian.PutUint64(hashbuf[:], hash64)
+ data2 := []byte{}
+ data2 = append(data2, []byte(a.key)...)
+ data2 = append(data2, a.groupKey(addr)...)
+ data2 = append(data2, hashbuf[:]...)
+
+ hash2 := doubleSha256(data2)
+ return int(binary.BigEndian.Uint64(hash2) % oldBucketCount)
+}
+
+// Return a string representing the network group of this address.
+// This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string
+// "local" for a local address and the string "unroutable for an unroutable
+// address.
+func (a *AddrBook) groupKey(na *NetAddress) string {
+ if a.routabilityStrict && na.Local() {
+ return "local"
+ }
+ if a.routabilityStrict && !na.Routable() {
+ return "unroutable"
+ }
+
+ if ipv4 := na.IP.To4(); ipv4 != nil {
+ return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String()
+ }
+ if na.RFC6145() || na.RFC6052() {
+ // last four bytes are the ip address
+ ip := net.IP(na.IP[12:16])
+ return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
+ }
+
+ if na.RFC3964() {
+ ip := net.IP(na.IP[2:7])
+ return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
+
+ }
+ if na.RFC4380() {
+ // teredo tunnels have the last 4 bytes as the v4 address XOR
+ // 0xff.
+ ip := net.IP(make([]byte, 4))
+ for i, byte := range na.IP[12:16] {
+ ip[i] = byte ^ 0xff
+ }
+ return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String()
+ }
+
+ // OK, so now we know ourselves to be a IPv6 address.
+ // bitcoind uses /32 for everything, except for Hurricane Electric's
+ // (he.net) IP range, which it uses /36 for.
+ bits := 32
+ heNet := &net.IPNet{IP: net.ParseIP("2001:470::"),
+ Mask: net.CIDRMask(32, 128)}
+ if heNet.Contains(na.IP) {
+ bits = 36
+ }
+
+ return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String()
+}
+
+//-----------------------------------------------------------------------------
+
+/*
+ knownAddress
+
+ tracks information about a known network address that is used
+ to determine how viable an address is.
+*/
+type knownAddress struct {
+ Addr *NetAddress
+ Src *NetAddress
+ Attempts int32
+ LastAttempt time.Time
+ LastSuccess time.Time
+ BucketType byte
+ Buckets []int
+}
+
+func newKnownAddress(addr *NetAddress, src *NetAddress) *knownAddress {
+ return &knownAddress{
+ Addr: addr,
+ Src: src,
+ Attempts: 0,
+ LastAttempt: time.Now(),
+ BucketType: bucketTypeNew,
+ Buckets: nil,
+ }
+}
+
+func (ka *knownAddress) isOld() bool {
+ return ka.BucketType == bucketTypeOld
+}
+
+func (ka *knownAddress) isNew() bool {
+ return ka.BucketType == bucketTypeNew
+}
+
+func (ka *knownAddress) markAttempt() {
+ now := time.Now()
+ ka.LastAttempt = now
+ ka.Attempts += 1
+}
+
+func (ka *knownAddress) markGood() {
+ now := time.Now()
+ ka.LastAttempt = now
+ ka.Attempts = 0
+ ka.LastSuccess = now
+}
+
+func (ka *knownAddress) addBucketRef(bucketIdx int) int {
+ for _, bucket := range ka.Buckets {
+ if bucket == bucketIdx {
+ log.Warn(Fmt("Bucket already exists in ka.Buckets: %v", ka))
+ return -1
+ }
+ }
+ ka.Buckets = append(ka.Buckets, bucketIdx)
+ return len(ka.Buckets)
+}
+
+func (ka *knownAddress) removeBucketRef(bucketIdx int) int {
+ buckets := []int{}
+ for _, bucket := range ka.Buckets {
+ if bucket != bucketIdx {
+ buckets = append(buckets, bucket)
+ }
+ }
+ if len(buckets) != len(ka.Buckets)-1 {
+ log.Warn(Fmt("bucketIdx not found in ka.Buckets: %v", ka))
+ return -1
+ }
+ ka.Buckets = buckets
+ return len(ka.Buckets)
+}
+
+/*
+ An address is bad if the address in question has not been tried in the last
+ minute and meets one of the following criteria:
+
+ 1) It claims to be from the future
+ 2) It hasn't been seen in over a month
+ 3) It has failed at least three times and never succeeded
+ 4) It has failed ten times in the last week
+
+ All addresses that meet these criteria are assumed to be worthless and not
+ worth keeping hold of.
+*/
+func (ka *knownAddress) isBad() bool {
+ // Has been attempted in the last minute --> good
+ if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) {
+ return false
+ }
+
+ // Over a month old?
+ if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) {
+ return true
+ }
+
+ // Never succeeded?
+ if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries {
+ return true
+ }
+
+ // Hasn't succeeded in too long?
+ if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) &&
+ ka.Attempts >= maxFailures {
+ return true
+ }
+
+ return false
+}
diff --git a/addrbook_test.go b/addrbook_test.go
new file mode 100644
index 000000000..16aea8ef9
--- /dev/null
+++ b/addrbook_test.go
@@ -0,0 +1,166 @@
+package p2p
+
+import (
+ "fmt"
+ "io/ioutil"
+ "math/rand"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func createTempFileName(prefix string) string {
+ f, err := ioutil.TempFile("", prefix)
+ if err != nil {
+ panic(err)
+ }
+ fname := f.Name()
+ err = f.Close()
+ if err != nil {
+ panic(err)
+ }
+ return fname
+}
+
+func TestAddrBookSaveLoad(t *testing.T) {
+ fname := createTempFileName("addrbook_test")
+
+ // 0 addresses
+ book := NewAddrBook(fname, true)
+ book.saveToFile(fname)
+
+ book = NewAddrBook(fname, true)
+ book.loadFromFile(fname)
+
+ assert.Zero(t, book.Size())
+
+ // 100 addresses
+ randAddrs := randNetAddressPairs(t, 100)
+
+ for _, addrSrc := range randAddrs {
+ book.AddAddress(addrSrc.addr, addrSrc.src)
+ }
+
+ assert.Equal(t, 100, book.Size())
+ book.saveToFile(fname)
+
+ book = NewAddrBook(fname, true)
+ book.loadFromFile(fname)
+
+ assert.Equal(t, 100, book.Size())
+}
+
+func TestAddrBookLookup(t *testing.T) {
+ fname := createTempFileName("addrbook_test")
+
+ randAddrs := randNetAddressPairs(t, 100)
+
+ book := NewAddrBook(fname, true)
+ for _, addrSrc := range randAddrs {
+ addr := addrSrc.addr
+ src := addrSrc.src
+ book.AddAddress(addr, src)
+
+ ka := book.addrLookup[addr.String()]
+ assert.NotNil(t, ka, "Expected to find KnownAddress %v but wasn't there.", addr)
+
+ if !(ka.Addr.Equals(addr) && ka.Src.Equals(src)) {
+ t.Fatalf("KnownAddress doesn't match addr & src")
+ }
+ }
+}
+
+func TestAddrBookPromoteToOld(t *testing.T) {
+ fname := createTempFileName("addrbook_test")
+
+ randAddrs := randNetAddressPairs(t, 100)
+
+ book := NewAddrBook(fname, true)
+ for _, addrSrc := range randAddrs {
+ book.AddAddress(addrSrc.addr, addrSrc.src)
+ }
+
+ // Attempt all addresses.
+ for _, addrSrc := range randAddrs {
+ book.MarkAttempt(addrSrc.addr)
+ }
+
+ // Promote half of them
+ for i, addrSrc := range randAddrs {
+ if i%2 == 0 {
+ book.MarkGood(addrSrc.addr)
+ }
+ }
+
+ // TODO: do more testing :)
+
+ selection := book.GetSelection()
+ t.Logf("selection: %v", selection)
+
+ if len(selection) > book.Size() {
+ t.Errorf("selection could not be bigger than the book")
+ }
+}
+
+func TestAddrBookHandlesDuplicates(t *testing.T) {
+ fname := createTempFileName("addrbook_test")
+
+ book := NewAddrBook(fname, true)
+
+ randAddrs := randNetAddressPairs(t, 100)
+
+ differentSrc := randIPv4Address(t)
+ for _, addrSrc := range randAddrs {
+ book.AddAddress(addrSrc.addr, addrSrc.src)
+ book.AddAddress(addrSrc.addr, addrSrc.src) // duplicate
+ book.AddAddress(addrSrc.addr, differentSrc) // different src
+ }
+
+ assert.Equal(t, 100, book.Size())
+}
+
+type netAddressPair struct {
+ addr *NetAddress
+ src *NetAddress
+}
+
+func randNetAddressPairs(t *testing.T, n int) []netAddressPair {
+ randAddrs := make([]netAddressPair, n)
+ for i := 0; i < n; i++ {
+ randAddrs[i] = netAddressPair{addr: randIPv4Address(t), src: randIPv4Address(t)}
+ }
+ return randAddrs
+}
+
+func randIPv4Address(t *testing.T) *NetAddress {
+ for {
+ ip := fmt.Sprintf("%v.%v.%v.%v",
+ rand.Intn(254)+1,
+ rand.Intn(255),
+ rand.Intn(255),
+ rand.Intn(255),
+ )
+ port := rand.Intn(65535-1) + 1
+ addr, err := NewNetAddressString(fmt.Sprintf("%v:%v", ip, port))
+ assert.Nil(t, err, "error generating rand network address")
+ if addr.Routable() {
+ return addr
+ }
+ }
+}
+
+func TestAddrBookRemoveAddress(t *testing.T) {
+ fname := createTempFileName("addrbook_test")
+ book := NewAddrBook(fname, true)
+
+ addr := randIPv4Address(t)
+ book.AddAddress(addr, addr)
+ assert.Equal(t, 1, book.Size())
+
+ book.RemoveAddress(addr)
+ assert.Equal(t, 0, book.Size())
+
+ nonExistingAddr := randIPv4Address(t)
+ book.RemoveAddress(nonExistingAddr)
+ assert.Equal(t, 0, book.Size())
+}
diff --git a/config.go b/config.go
new file mode 100644
index 000000000..a8b7e343b
--- /dev/null
+++ b/config.go
@@ -0,0 +1,45 @@
+package p2p
+
+import (
+ cfg "github.com/tendermint/go-config"
+)
+
+const (
+ // Switch config keys
+ configKeyDialTimeoutSeconds = "dial_timeout_seconds"
+ configKeyHandshakeTimeoutSeconds = "handshake_timeout_seconds"
+ configKeyMaxNumPeers = "max_num_peers"
+ configKeyAuthEnc = "authenticated_encryption"
+
+ // MConnection config keys
+ configKeySendRate = "send_rate"
+ configKeyRecvRate = "recv_rate"
+
+ // Fuzz params
+ configFuzzEnable = "fuzz_enable" // use the fuzz wrapped conn
+ configFuzzMode = "fuzz_mode" // eg. drop, delay
+ configFuzzMaxDelayMilliseconds = "fuzz_max_delay_milliseconds"
+ configFuzzProbDropRW = "fuzz_prob_drop_rw"
+ configFuzzProbDropConn = "fuzz_prob_drop_conn"
+ configFuzzProbSleep = "fuzz_prob_sleep"
+)
+
+func setConfigDefaults(config cfg.Config) {
+ // Switch default config
+ config.SetDefault(configKeyDialTimeoutSeconds, 3)
+ config.SetDefault(configKeyHandshakeTimeoutSeconds, 20)
+ config.SetDefault(configKeyMaxNumPeers, 50)
+ config.SetDefault(configKeyAuthEnc, true)
+
+ // MConnection default config
+ config.SetDefault(configKeySendRate, 512000) // 500KB/s
+ config.SetDefault(configKeyRecvRate, 512000) // 500KB/s
+
+ // Fuzz defaults
+ config.SetDefault(configFuzzEnable, false)
+ config.SetDefault(configFuzzMode, FuzzModeDrop)
+ config.SetDefault(configFuzzMaxDelayMilliseconds, 3000)
+ config.SetDefault(configFuzzProbDropRW, 0.2)
+ config.SetDefault(configFuzzProbDropConn, 0.00)
+ config.SetDefault(configFuzzProbSleep, 0.00)
+}
diff --git a/connection.go b/connection.go
new file mode 100644
index 000000000..629ab7b0d
--- /dev/null
+++ b/connection.go
@@ -0,0 +1,686 @@
+package p2p
+
+import (
+ "bufio"
+ "fmt"
+ "io"
+ "math"
+ "net"
+ "runtime/debug"
+ "sync/atomic"
+ "time"
+
+ wire "github.com/tendermint/go-wire"
+ cmn "github.com/tendermint/tmlibs/common"
+ flow "github.com/tendermint/tmlibs/flowrate"
+)
+
+const (
+ numBatchMsgPackets = 10
+ minReadBufferSize = 1024
+ minWriteBufferSize = 65536
+ updateState = 2 * time.Second
+ pingTimeout = 40 * time.Second
+ flushThrottle = 100 * time.Millisecond
+
+ defaultSendQueueCapacity = 1
+ defaultSendRate = int64(512000) // 500KB/s
+ defaultRecvBufferCapacity = 4096
+ defaultRecvMessageCapacity = 22020096 // 21MB
+ defaultRecvRate = int64(512000) // 500KB/s
+ defaultSendTimeout = 10 * time.Second
+)
+
+type receiveCbFunc func(chID byte, msgBytes []byte)
+type errorCbFunc func(interface{})
+
+/*
+Each peer has one `MConnection` (multiplex connection) instance.
+
+__multiplex__ *noun* a system or signal involving simultaneous transmission of
+several messages along a single channel of communication.
+
+Each `MConnection` handles message transmission on multiple abstract communication
+`Channel`s. Each channel has a globally unique byte id.
+The byte id and the relative priorities of each `Channel` are configured upon
+initialization of the connection.
+
+There are two methods for sending messages:
+ func (m MConnection) Send(chID byte, msg interface{}) bool {}
+ func (m MConnection) TrySend(chID byte, msg interface{}) bool {}
+
+`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued
+for the channel with the given id byte `chID`, or until the request times out.
+The message `msg` is serialized using the `tendermint/wire` submodule's
+`WriteBinary()` reflection routine.
+
+`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's
+queue is full.
+
+Inbound message bytes are handled with an onReceive callback function.
+*/
+type MConnection struct {
+ cmn.BaseService
+
+ conn net.Conn
+ bufReader *bufio.Reader
+ bufWriter *bufio.Writer
+ sendMonitor *flow.Monitor
+ recvMonitor *flow.Monitor
+ send chan struct{}
+ pong chan struct{}
+ channels []*Channel
+ channelsIdx map[byte]*Channel
+ onReceive receiveCbFunc
+ onError errorCbFunc
+ errored uint32
+ config *MConnConfig
+
+ quit chan struct{}
+ flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled.
+ pingTimer *cmn.RepeatTimer // send pings periodically
+ chStatsTimer *cmn.RepeatTimer // update channel stats periodically
+
+ LocalAddress *NetAddress
+ RemoteAddress *NetAddress
+}
+
+// MConnConfig is a MConnection configuration.
+type MConnConfig struct {
+ SendRate int64
+ RecvRate int64
+}
+
+// DefaultMConnConfig returns the default config.
+func DefaultMConnConfig() *MConnConfig {
+ return &MConnConfig{
+ SendRate: defaultSendRate,
+ RecvRate: defaultRecvRate,
+ }
+}
+
+// NewMConnection wraps net.Conn and creates multiplex connection
+func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection {
+ return NewMConnectionWithConfig(
+ conn,
+ chDescs,
+ onReceive,
+ onError,
+ DefaultMConnConfig())
+}
+
+// NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config
+func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection {
+ mconn := &MConnection{
+ conn: conn,
+ bufReader: bufio.NewReaderSize(conn, minReadBufferSize),
+ bufWriter: bufio.NewWriterSize(conn, minWriteBufferSize),
+ sendMonitor: flow.New(0, 0),
+ recvMonitor: flow.New(0, 0),
+ send: make(chan struct{}, 1),
+ pong: make(chan struct{}),
+ onReceive: onReceive,
+ onError: onError,
+ config: config,
+
+ LocalAddress: NewNetAddress(conn.LocalAddr()),
+ RemoteAddress: NewNetAddress(conn.RemoteAddr()),
+ }
+
+ // Create channels
+ var channelsIdx = map[byte]*Channel{}
+ var channels = []*Channel{}
+
+ for _, desc := range chDescs {
+ descCopy := *desc // copy the desc else unsafe access across connections
+ channel := newChannel(mconn, &descCopy)
+ channelsIdx[channel.id] = channel
+ channels = append(channels, channel)
+ }
+ mconn.channels = channels
+ mconn.channelsIdx = channelsIdx
+
+ mconn.BaseService = *cmn.NewBaseService(log, "MConnection", mconn)
+
+ return mconn
+}
+
+func (c *MConnection) OnStart() error {
+ c.BaseService.OnStart()
+ c.quit = make(chan struct{})
+ c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle)
+ c.pingTimer = cmn.NewRepeatTimer("ping", pingTimeout)
+ c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateState)
+ go c.sendRoutine()
+ go c.recvRoutine()
+ return nil
+}
+
+func (c *MConnection) OnStop() {
+ c.BaseService.OnStop()
+ c.flushTimer.Stop()
+ c.pingTimer.Stop()
+ c.chStatsTimer.Stop()
+ if c.quit != nil {
+ close(c.quit)
+ }
+ c.conn.Close()
+ // We can't close pong safely here because
+ // recvRoutine may write to it after we've stopped.
+ // Though it doesn't need to get closed at all,
+ // we close it @ recvRoutine.
+ // close(c.pong)
+}
+
+func (c *MConnection) String() string {
+ return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr())
+}
+
+func (c *MConnection) flush() {
+ log.Debug("Flush", "conn", c)
+ err := c.bufWriter.Flush()
+ if err != nil {
+ log.Warn("MConnection flush failed", "error", err)
+ }
+}
+
+// Catch panics, usually caused by remote disconnects.
+func (c *MConnection) _recover() {
+ if r := recover(); r != nil {
+ stack := debug.Stack()
+ err := cmn.StackError{r, stack}
+ c.stopForError(err)
+ }
+}
+
+func (c *MConnection) stopForError(r interface{}) {
+ c.Stop()
+ if atomic.CompareAndSwapUint32(&c.errored, 0, 1) {
+ if c.onError != nil {
+ c.onError(r)
+ }
+ }
+}
+
+// Queues a message to be sent to channel.
+func (c *MConnection) Send(chID byte, msg interface{}) bool {
+ if !c.IsRunning() {
+ return false
+ }
+
+ log.Debug("Send", "channel", chID, "conn", c, "msg", msg) //, "bytes", wire.BinaryBytes(msg))
+
+ // Send message to channel.
+ channel, ok := c.channelsIdx[chID]
+ if !ok {
+ log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID))
+ return false
+ }
+
+ success := channel.sendBytes(wire.BinaryBytes(msg))
+ if success {
+ // Wake up sendRoutine if necessary
+ select {
+ case c.send <- struct{}{}:
+ default:
+ }
+ } else {
+ log.Warn("Send failed", "channel", chID, "conn", c, "msg", msg)
+ }
+ return success
+}
+
+// Queues a message to be sent to channel.
+// Nonblocking, returns true if successful.
+func (c *MConnection) TrySend(chID byte, msg interface{}) bool {
+ if !c.IsRunning() {
+ return false
+ }
+
+ log.Debug("TrySend", "channel", chID, "conn", c, "msg", msg)
+
+ // Send message to channel.
+ channel, ok := c.channelsIdx[chID]
+ if !ok {
+ log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID))
+ return false
+ }
+
+ ok = channel.trySendBytes(wire.BinaryBytes(msg))
+ if ok {
+ // Wake up sendRoutine if necessary
+ select {
+ case c.send <- struct{}{}:
+ default:
+ }
+ }
+
+ return ok
+}
+
+// CanSend returns true if you can send more data onto the chID, false
+// otherwise. Use only as a heuristic.
+func (c *MConnection) CanSend(chID byte) bool {
+ if !c.IsRunning() {
+ return false
+ }
+
+ channel, ok := c.channelsIdx[chID]
+ if !ok {
+ log.Error(cmn.Fmt("Unknown channel %X", chID))
+ return false
+ }
+ return channel.canSend()
+}
+
+// sendRoutine polls for packets to send from channels.
+func (c *MConnection) sendRoutine() {
+ defer c._recover()
+
+FOR_LOOP:
+ for {
+ var n int
+ var err error
+ select {
+ case <-c.flushTimer.Ch:
+ // NOTE: flushTimer.Set() must be called every time
+ // something is written to .bufWriter.
+ c.flush()
+ case <-c.chStatsTimer.Ch:
+ for _, channel := range c.channels {
+ channel.updateStats()
+ }
+ case <-c.pingTimer.Ch:
+ log.Debug("Send Ping")
+ wire.WriteByte(packetTypePing, c.bufWriter, &n, &err)
+ c.sendMonitor.Update(int(n))
+ c.flush()
+ case <-c.pong:
+ log.Debug("Send Pong")
+ wire.WriteByte(packetTypePong, c.bufWriter, &n, &err)
+ c.sendMonitor.Update(int(n))
+ c.flush()
+ case <-c.quit:
+ break FOR_LOOP
+ case <-c.send:
+ // Send some msgPackets
+ eof := c.sendSomeMsgPackets()
+ if !eof {
+ // Keep sendRoutine awake.
+ select {
+ case c.send <- struct{}{}:
+ default:
+ }
+ }
+ }
+
+ if !c.IsRunning() {
+ break FOR_LOOP
+ }
+ if err != nil {
+ log.Warn("Connection failed @ sendRoutine", "conn", c, "error", err)
+ c.stopForError(err)
+ break FOR_LOOP
+ }
+ }
+
+ // Cleanup
+}
+
+// Returns true if messages from channels were exhausted.
+// Blocks in accordance to .sendMonitor throttling.
+func (c *MConnection) sendSomeMsgPackets() bool {
+ // Block until .sendMonitor says we can write.
+ // Once we're ready we send more than we asked for,
+ // but amortized it should even out.
+ c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true)
+
+ // Now send some msgPackets.
+ for i := 0; i < numBatchMsgPackets; i++ {
+ if c.sendMsgPacket() {
+ return true
+ }
+ }
+ return false
+}
+
+// Returns true if messages from channels were exhausted.
+func (c *MConnection) sendMsgPacket() bool {
+ // Choose a channel to create a msgPacket from.
+ // The chosen channel will be the one whose recentlySent/priority is the least.
+ var leastRatio float32 = math.MaxFloat32
+ var leastChannel *Channel
+ for _, channel := range c.channels {
+ // If nothing to send, skip this channel
+ if !channel.isSendPending() {
+ continue
+ }
+ // Get ratio, and keep track of lowest ratio.
+ ratio := float32(channel.recentlySent) / float32(channel.priority)
+ if ratio < leastRatio {
+ leastRatio = ratio
+ leastChannel = channel
+ }
+ }
+
+ // Nothing to send?
+ if leastChannel == nil {
+ return true
+ } else {
+ // log.Info("Found a msgPacket to send")
+ }
+
+ // Make & send a msgPacket from this channel
+ n, err := leastChannel.writeMsgPacketTo(c.bufWriter)
+ if err != nil {
+ log.Warn("Failed to write msgPacket", "error", err)
+ c.stopForError(err)
+ return true
+ }
+ c.sendMonitor.Update(int(n))
+ c.flushTimer.Set()
+ return false
+}
+
+// recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer.
+// After a whole message has been assembled, it's pushed to onReceive().
+// Blocks depending on how the connection is throttled.
+func (c *MConnection) recvRoutine() {
+ defer c._recover()
+
+FOR_LOOP:
+ for {
+ // Block until .recvMonitor says we can read.
+ c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true)
+
+ /*
+ // Peek into bufReader for debugging
+ if numBytes := c.bufReader.Buffered(); numBytes > 0 {
+ log.Info("Peek connection buffer", "numBytes", numBytes, "bytes", log15.Lazy{func() []byte {
+ bytes, err := c.bufReader.Peek(MinInt(numBytes, 100))
+ if err == nil {
+ return bytes
+ } else {
+ log.Warn("Error peeking connection buffer", "error", err)
+ return nil
+ }
+ }})
+ }
+ */
+
+ // Read packet type
+ var n int
+ var err error
+ pktType := wire.ReadByte(c.bufReader, &n, &err)
+ c.recvMonitor.Update(int(n))
+ if err != nil {
+ if c.IsRunning() {
+ log.Warn("Connection failed @ recvRoutine (reading byte)", "conn", c, "error", err)
+ c.stopForError(err)
+ }
+ break FOR_LOOP
+ }
+
+ // Read more depending on packet type.
+ switch pktType {
+ case packetTypePing:
+ // TODO: prevent abuse, as they cause flush()'s.
+ log.Debug("Receive Ping")
+ c.pong <- struct{}{}
+ case packetTypePong:
+ // do nothing
+ log.Debug("Receive Pong")
+ case packetTypeMsg:
+ pkt, n, err := msgPacket{}, int(0), error(nil)
+ wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err)
+ c.recvMonitor.Update(int(n))
+ if err != nil {
+ if c.IsRunning() {
+ log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err)
+ c.stopForError(err)
+ }
+ break FOR_LOOP
+ }
+ channel, ok := c.channelsIdx[pkt.ChannelID]
+ if !ok || channel == nil {
+ cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID))
+ }
+ msgBytes, err := channel.recvMsgPacket(pkt)
+ if err != nil {
+ if c.IsRunning() {
+ log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err)
+ c.stopForError(err)
+ }
+ break FOR_LOOP
+ }
+ if msgBytes != nil {
+ log.Debug("Received bytes", "chID", pkt.ChannelID, "msgBytes", msgBytes)
+ c.onReceive(pkt.ChannelID, msgBytes)
+ }
+ default:
+ cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType))
+ }
+
+ // TODO: shouldn't this go in the sendRoutine?
+ // Better to send a ping packet when *we* haven't sent anything for a while.
+ c.pingTimer.Reset()
+ }
+
+ // Cleanup
+ close(c.pong)
+ for _ = range c.pong {
+ // Drain
+ }
+}
+
+type ConnectionStatus struct {
+ SendMonitor flow.Status
+ RecvMonitor flow.Status
+ Channels []ChannelStatus
+}
+
+type ChannelStatus struct {
+ ID byte
+ SendQueueCapacity int
+ SendQueueSize int
+ Priority int
+ RecentlySent int64
+}
+
+func (c *MConnection) Status() ConnectionStatus {
+ var status ConnectionStatus
+ status.SendMonitor = c.sendMonitor.Status()
+ status.RecvMonitor = c.recvMonitor.Status()
+ status.Channels = make([]ChannelStatus, len(c.channels))
+ for i, channel := range c.channels {
+ status.Channels[i] = ChannelStatus{
+ ID: channel.id,
+ SendQueueCapacity: cap(channel.sendQueue),
+ SendQueueSize: int(channel.sendQueueSize), // TODO use atomic
+ Priority: channel.priority,
+ RecentlySent: channel.recentlySent,
+ }
+ }
+ return status
+}
+
+//-----------------------------------------------------------------------------
+
+type ChannelDescriptor struct {
+ ID byte
+ Priority int
+ SendQueueCapacity int
+ RecvBufferCapacity int
+ RecvMessageCapacity int
+}
+
+func (chDesc *ChannelDescriptor) FillDefaults() {
+ if chDesc.SendQueueCapacity == 0 {
+ chDesc.SendQueueCapacity = defaultSendQueueCapacity
+ }
+ if chDesc.RecvBufferCapacity == 0 {
+ chDesc.RecvBufferCapacity = defaultRecvBufferCapacity
+ }
+ if chDesc.RecvMessageCapacity == 0 {
+ chDesc.RecvMessageCapacity = defaultRecvMessageCapacity
+ }
+}
+
+// TODO: lowercase.
+// NOTE: not goroutine-safe.
+type Channel struct {
+ conn *MConnection
+ desc *ChannelDescriptor
+ id byte
+ sendQueue chan []byte
+ sendQueueSize int32 // atomic.
+ recving []byte
+ sending []byte
+ priority int
+ recentlySent int64 // exponential moving average
+}
+
+func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel {
+ desc.FillDefaults()
+ if desc.Priority <= 0 {
+ cmn.PanicSanity("Channel default priority must be a postive integer")
+ }
+ return &Channel{
+ conn: conn,
+ desc: desc,
+ id: desc.ID,
+ sendQueue: make(chan []byte, desc.SendQueueCapacity),
+ recving: make([]byte, 0, desc.RecvBufferCapacity),
+ priority: desc.Priority,
+ }
+}
+
+// Queues message to send to this channel.
+// Goroutine-safe
+// Times out (and returns false) after defaultSendTimeout
+func (ch *Channel) sendBytes(bytes []byte) bool {
+ select {
+ case ch.sendQueue <- bytes:
+ atomic.AddInt32(&ch.sendQueueSize, 1)
+ return true
+ case <-time.After(defaultSendTimeout):
+ return false
+ }
+}
+
+// Queues message to send to this channel.
+// Nonblocking, returns true if successful.
+// Goroutine-safe
+func (ch *Channel) trySendBytes(bytes []byte) bool {
+ select {
+ case ch.sendQueue <- bytes:
+ atomic.AddInt32(&ch.sendQueueSize, 1)
+ return true
+ default:
+ return false
+ }
+}
+
+// Goroutine-safe
+func (ch *Channel) loadSendQueueSize() (size int) {
+ return int(atomic.LoadInt32(&ch.sendQueueSize))
+}
+
+// Goroutine-safe
+// Use only as a heuristic.
+func (ch *Channel) canSend() bool {
+ return ch.loadSendQueueSize() < defaultSendQueueCapacity
+}
+
+// Returns true if any msgPackets are pending to be sent.
+// Call before calling nextMsgPacket()
+// Goroutine-safe
+func (ch *Channel) isSendPending() bool {
+ if len(ch.sending) == 0 {
+ if len(ch.sendQueue) == 0 {
+ return false
+ }
+ ch.sending = <-ch.sendQueue
+ }
+ return true
+}
+
+// Creates a new msgPacket to send.
+// Not goroutine-safe
+func (ch *Channel) nextMsgPacket() msgPacket {
+ packet := msgPacket{}
+ packet.ChannelID = byte(ch.id)
+ packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))]
+ if len(ch.sending) <= maxMsgPacketPayloadSize {
+ packet.EOF = byte(0x01)
+ ch.sending = nil
+ atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize
+ } else {
+ packet.EOF = byte(0x00)
+ ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):]
+ }
+ return packet
+}
+
+// Writes next msgPacket to w.
+// Not goroutine-safe
+func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) {
+ packet := ch.nextMsgPacket()
+ log.Debug("Write Msg Packet", "conn", ch.conn, "packet", packet)
+ wire.WriteByte(packetTypeMsg, w, &n, &err)
+ wire.WriteBinary(packet, w, &n, &err)
+ if err == nil {
+ ch.recentlySent += int64(n)
+ }
+ return
+}
+
+// Handles incoming msgPackets. Returns a msg bytes if msg is complete.
+// Not goroutine-safe
+func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) {
+ // log.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet)
+ if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) {
+ return nil, wire.ErrBinaryReadOverflow
+ }
+ ch.recving = append(ch.recving, packet.Bytes...)
+ if packet.EOF == byte(0x01) {
+ msgBytes := ch.recving
+ // clear the slice without re-allocating.
+ // http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go
+ // suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes,
+ // at which point the recving slice stops being used and should be garbage collected
+ ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity)
+ return msgBytes, nil
+ }
+ return nil, nil
+}
+
+// Call this periodically to update stats for throttling purposes.
+// Not goroutine-safe
+func (ch *Channel) updateStats() {
+ // Exponential decay of stats.
+ // TODO: optimize.
+ ch.recentlySent = int64(float64(ch.recentlySent) * 0.8)
+}
+
+//-----------------------------------------------------------------------------
+
+const (
+ maxMsgPacketPayloadSize = 1024
+ maxMsgPacketOverheadSize = 10 // It's actually lower but good enough
+ maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize
+ packetTypePing = byte(0x01)
+ packetTypePong = byte(0x02)
+ packetTypeMsg = byte(0x03)
+)
+
+// Messages in channels are chopped into smaller msgPackets for multiplexing.
+type msgPacket struct {
+ ChannelID byte
+ EOF byte // 1 means message ends here.
+ Bytes []byte
+}
+
+func (p msgPacket) String() string {
+ return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF)
+}
diff --git a/connection_test.go b/connection_test.go
new file mode 100644
index 000000000..33d8adfd1
--- /dev/null
+++ b/connection_test.go
@@ -0,0 +1,139 @@
+package p2p_test
+
+import (
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ p2p "github.com/tendermint/go-p2p"
+)
+
+func createMConnection(conn net.Conn) *p2p.MConnection {
+ onReceive := func(chID byte, msgBytes []byte) {
+ }
+ onError := func(r interface{}) {
+ }
+ return createMConnectionWithCallbacks(conn, onReceive, onError)
+}
+
+func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *p2p.MConnection {
+ chDescs := []*p2p.ChannelDescriptor{&p2p.ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
+ return p2p.NewMConnection(conn, chDescs, onReceive, onError)
+}
+
+func TestMConnectionSend(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ server, client := net.Pipe()
+ defer server.Close()
+ defer client.Close()
+
+ mconn := createMConnection(client)
+ _, err := mconn.Start()
+ require.Nil(err)
+ defer mconn.Stop()
+
+ msg := "Ant-Man"
+ assert.True(mconn.Send(0x01, msg))
+ // Note: subsequent Send/TrySend calls could pass because we are reading from
+ // the send queue in a separate goroutine.
+ server.Read(make([]byte, len(msg)))
+ assert.True(mconn.CanSend(0x01))
+
+ msg = "Spider-Man"
+ assert.True(mconn.TrySend(0x01, msg))
+ server.Read(make([]byte, len(msg)))
+
+ assert.False(mconn.CanSend(0x05), "CanSend should return false because channel is unknown")
+ assert.False(mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown")
+}
+
+func TestMConnectionReceive(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ server, client := net.Pipe()
+ defer server.Close()
+ defer client.Close()
+
+ receivedCh := make(chan []byte)
+ errorsCh := make(chan interface{})
+ onReceive := func(chID byte, msgBytes []byte) {
+ receivedCh <- msgBytes
+ }
+ onError := func(r interface{}) {
+ errorsCh <- r
+ }
+ mconn1 := createMConnectionWithCallbacks(client, onReceive, onError)
+ _, err := mconn1.Start()
+ require.Nil(err)
+ defer mconn1.Stop()
+
+ mconn2 := createMConnection(server)
+ _, err = mconn2.Start()
+ require.Nil(err)
+ defer mconn2.Stop()
+
+ msg := "Cyclops"
+ assert.True(mconn2.Send(0x01, msg))
+
+ select {
+ case receivedBytes := <-receivedCh:
+ assert.Equal([]byte(msg), receivedBytes[2:]) // first 3 bytes are internal
+ case err := <-errorsCh:
+ t.Fatalf("Expected %s, got %+v", msg, err)
+ case <-time.After(500 * time.Millisecond):
+ t.Fatalf("Did not receive %s message in 500ms", msg)
+ }
+}
+
+func TestMConnectionStatus(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ server, client := net.Pipe()
+ defer server.Close()
+ defer client.Close()
+
+ mconn := createMConnection(client)
+ _, err := mconn.Start()
+ require.Nil(err)
+ defer mconn.Stop()
+
+ status := mconn.Status()
+ assert.NotNil(status)
+ assert.Zero(status.Channels[0].SendQueueSize)
+}
+
+func TestMConnectionStopsAndReturnsError(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ server, client := net.Pipe()
+ defer server.Close()
+ defer client.Close()
+
+ receivedCh := make(chan []byte)
+ errorsCh := make(chan interface{})
+ onReceive := func(chID byte, msgBytes []byte) {
+ receivedCh <- msgBytes
+ }
+ onError := func(r interface{}) {
+ errorsCh <- r
+ }
+ mconn := createMConnectionWithCallbacks(client, onReceive, onError)
+ _, err := mconn.Start()
+ require.Nil(err)
+ defer mconn.Stop()
+
+ client.Close()
+
+ select {
+ case receivedBytes := <-receivedCh:
+ t.Fatalf("Expected error, got %v", receivedBytes)
+ case err := <-errorsCh:
+ assert.NotNil(err)
+ assert.False(mconn.IsRunning())
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("Did not receive error in 500ms")
+ }
+}
diff --git a/fuzz.go b/fuzz.go
new file mode 100644
index 000000000..aefac986a
--- /dev/null
+++ b/fuzz.go
@@ -0,0 +1,173 @@
+package p2p
+
+import (
+ "math/rand"
+ "net"
+ "sync"
+ "time"
+)
+
+const (
+ // FuzzModeDrop is a mode in which we randomly drop reads/writes, connections or sleep
+ FuzzModeDrop = iota
+ // FuzzModeDelay is a mode in which we randomly sleep
+ FuzzModeDelay
+)
+
+// FuzzedConnection wraps any net.Conn and depending on the mode either delays
+// reads/writes or randomly drops reads/writes/connections.
+type FuzzedConnection struct {
+ conn net.Conn
+
+ mtx sync.Mutex
+ start <-chan time.Time
+ active bool
+
+ config *FuzzConnConfig
+}
+
+// FuzzConnConfig is a FuzzedConnection configuration.
+type FuzzConnConfig struct {
+ Mode int
+ MaxDelay time.Duration
+ ProbDropRW float64
+ ProbDropConn float64
+ ProbSleep float64
+}
+
+// DefaultFuzzConnConfig returns the default config.
+func DefaultFuzzConnConfig() *FuzzConnConfig {
+ return &FuzzConnConfig{
+ Mode: FuzzModeDrop,
+ MaxDelay: 3 * time.Second,
+ ProbDropRW: 0.2,
+ ProbDropConn: 0.00,
+ ProbSleep: 0.00,
+ }
+}
+
+// FuzzConn creates a new FuzzedConnection. Fuzzing starts immediately.
+func FuzzConn(conn net.Conn) net.Conn {
+ return FuzzConnFromConfig(conn, DefaultFuzzConnConfig())
+}
+
+// FuzzConnFromConfig creates a new FuzzedConnection from a config. Fuzzing
+// starts immediately.
+func FuzzConnFromConfig(conn net.Conn, config *FuzzConnConfig) net.Conn {
+ return &FuzzedConnection{
+ conn: conn,
+ start: make(<-chan time.Time),
+ active: true,
+ config: config,
+ }
+}
+
+// FuzzConnAfter creates a new FuzzedConnection. Fuzzing starts when the
+// duration elapses.
+func FuzzConnAfter(conn net.Conn, d time.Duration) net.Conn {
+ return FuzzConnAfterFromConfig(conn, d, DefaultFuzzConnConfig())
+}
+
+// FuzzConnAfterFromConfig creates a new FuzzedConnection from a config.
+// Fuzzing starts when the duration elapses.
+func FuzzConnAfterFromConfig(conn net.Conn, d time.Duration, config *FuzzConnConfig) net.Conn {
+ return &FuzzedConnection{
+ conn: conn,
+ start: time.After(d),
+ active: false,
+ config: config,
+ }
+}
+
+// Config returns the connection's config.
+func (fc *FuzzedConnection) Config() *FuzzConnConfig {
+ return fc.config
+}
+
+// Read implements net.Conn.
+func (fc *FuzzedConnection) Read(data []byte) (n int, err error) {
+ if fc.fuzz() {
+ return 0, nil
+ }
+ return fc.conn.Read(data)
+}
+
+// Write implements net.Conn.
+func (fc *FuzzedConnection) Write(data []byte) (n int, err error) {
+ if fc.fuzz() {
+ return 0, nil
+ }
+ return fc.conn.Write(data)
+}
+
+// Close implements net.Conn.
+func (fc *FuzzedConnection) Close() error { return fc.conn.Close() }
+
+// LocalAddr implements net.Conn.
+func (fc *FuzzedConnection) LocalAddr() net.Addr { return fc.conn.LocalAddr() }
+
+// RemoteAddr implements net.Conn.
+func (fc *FuzzedConnection) RemoteAddr() net.Addr { return fc.conn.RemoteAddr() }
+
+// SetDeadline implements net.Conn.
+func (fc *FuzzedConnection) SetDeadline(t time.Time) error { return fc.conn.SetDeadline(t) }
+
+// SetReadDeadline implements net.Conn.
+func (fc *FuzzedConnection) SetReadDeadline(t time.Time) error {
+ return fc.conn.SetReadDeadline(t)
+}
+
+// SetWriteDeadline implements net.Conn.
+func (fc *FuzzedConnection) SetWriteDeadline(t time.Time) error {
+ return fc.conn.SetWriteDeadline(t)
+}
+
+func (fc *FuzzedConnection) randomDuration() time.Duration {
+ maxDelayMillis := int(fc.config.MaxDelay.Nanoseconds() / 1000)
+ return time.Millisecond * time.Duration(rand.Int()%maxDelayMillis)
+}
+
+// implements the fuzz (delay, kill conn)
+// and returns whether or not the read/write should be ignored
+func (fc *FuzzedConnection) fuzz() bool {
+ if !fc.shouldFuzz() {
+ return false
+ }
+
+ switch fc.config.Mode {
+ case FuzzModeDrop:
+ // randomly drop the r/w, drop the conn, or sleep
+ r := rand.Float64()
+ if r <= fc.config.ProbDropRW {
+ return true
+ } else if r < fc.config.ProbDropRW+fc.config.ProbDropConn {
+ // XXX: can't this fail because machine precision?
+ // XXX: do we need an error?
+ fc.Close()
+ return true
+ } else if r < fc.config.ProbDropRW+fc.config.ProbDropConn+fc.config.ProbSleep {
+ time.Sleep(fc.randomDuration())
+ }
+ case FuzzModeDelay:
+ // sleep a bit
+ time.Sleep(fc.randomDuration())
+ }
+ return false
+}
+
+func (fc *FuzzedConnection) shouldFuzz() bool {
+ if fc.active {
+ return true
+ }
+
+ fc.mtx.Lock()
+ defer fc.mtx.Unlock()
+
+ select {
+ case <-fc.start:
+ fc.active = true
+ return true
+ default:
+ return false
+ }
+}
diff --git a/glide.lock b/glide.lock
new file mode 100644
index 000000000..5da4e6230
--- /dev/null
+++ b/glide.lock
@@ -0,0 +1,73 @@
+hash: f3d76bef9548cc37ad6038cb55f0812bac7e64735a99995c9da85010eef27f50
+updated: 2017-04-19T00:00:50.949249104-04:00
+imports:
+- name: github.com/btcsuite/btcd
+ version: b8df516b4b267acf2de46be593a9d948d1d2c420
+ subpackages:
+ - btcec
+- name: github.com/btcsuite/fastsha256
+ version: 637e656429416087660c84436a2a035d69d54e2e
+- name: github.com/BurntSushi/toml
+ version: 99064174e013895bbd9b025c31100bd1d9b590ca
+- name: github.com/go-stack/stack
+ version: 100eb0c0a9c5b306ca2fb4f165df21d80ada4b82
+- name: github.com/mattn/go-colorable
+ version: 9fdad7c47650b7d2e1da50644c1f4ba7f172f252
+- name: github.com/mattn/go-isatty
+ version: 56b76bdf51f7708750eac80fa38b952bb9f32639
+- name: github.com/pkg/errors
+ version: 645ef00459ed84a119197bfb8d8205042c6df63d
+- name: github.com/tendermint/ed25519
+ version: 1f52c6f8b8a5c7908aff4497c186af344b428925
+ subpackages:
+ - edwards25519
+ - extra25519
+- name: github.com/tendermint/tmlibs/common
+ version: f9e3db037330c8a8d61d3966de8473eaf01154fa
+- name: github.com/tendermint/go-config
+ version: 620dcbbd7d587cf3599dedbf329b64311b0c307a
+- name: github.com/tendermint/go-crypto
+ version: 0ca2c6fdb0706001ca4c4b9b80c9f428e8cf39da
+- name: github.com/tendermint/go-wire/data
+ version: e7fcc6d081ec8518912fcdc103188275f83a3ee5
+- name: github.com/tendermint/tmlibs/flowrate
+ version: a20c98e61957faa93b4014fbd902f20ab9317a6a
+ subpackages:
+ - flowrate
+- name: github.com/tendermint/tmlibs/logger
+ version: cefb3a45c0bf3c493a04e9bcd9b1540528be59f2
+- name: github.com/tendermint/go-wire
+ version: c1c9a57ab8038448ddea1714c0698f8051e5748c
+- name: github.com/tendermint/log15
+ version: ae0f3d6450da9eac7074b439c8e1c3cabf0d5ce6
+ subpackages:
+ - term
+- name: golang.org/x/crypto
+ version: 1f22c0103821b9390939b6776727195525381532
+ subpackages:
+ - curve25519
+ - nacl/box
+ - nacl/secretbox
+ - openpgp/armor
+ - openpgp/errors
+ - poly1305
+ - ripemd160
+ - salsa20/salsa
+- name: golang.org/x/sys
+ version: 50c6bc5e4292a1d4e65c6e9be5f53be28bcbe28e
+ subpackages:
+ - unix
+testImports:
+- name: github.com/davecgh/go-spew
+ version: 6d212800a42e8ab5c146b8ace3490ee17e5225f9
+ subpackages:
+ - spew
+- name: github.com/pmezard/go-difflib
+ version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
+ subpackages:
+ - difflib
+- name: github.com/stretchr/testify
+ version: 69483b4bd14f5845b5a1e55bca19e954e827f1d0
+ subpackages:
+ - assert
+ - require
diff --git a/glide.yaml b/glide.yaml
new file mode 100644
index 000000000..5bf7a015e
--- /dev/null
+++ b/glide.yaml
@@ -0,0 +1,29 @@
+package: github.com/tendermint/go-p2p
+import:
+- package: github.com/tendermint/tmlibs/common
+ version: develop
+- package: github.com/tendermint/go-config
+ version: develop
+- package: github.com/tendermint/go-crypto
+ version: develop
+- package: github.com/tendermint/go-wire/data
+ version: develop
+- package: github.com/tendermint/tmlibs/flowrate
+ subpackages:
+ - flowrate
+- package: github.com/tendermint/tmlibs/logger
+ version: develop
+- package: github.com/tendermint/go-wire
+ version: develop
+- package: github.com/tendermint/log15
+- package: golang.org/x/crypto
+ subpackages:
+ - nacl/box
+ - nacl/secretbox
+ - ripemd160
+- package: github.com/pkg/errors
+testImport:
+- package: github.com/stretchr/testify
+ subpackages:
+ - assert
+ - require
diff --git a/ip_range_counter.go b/ip_range_counter.go
new file mode 100644
index 000000000..85d9d407a
--- /dev/null
+++ b/ip_range_counter.go
@@ -0,0 +1,29 @@
+package p2p
+
+import (
+ "strings"
+)
+
+// TODO Test
+func AddToIPRangeCounts(counts map[string]int, ip string) map[string]int {
+ changes := make(map[string]int)
+ ipParts := strings.Split(ip, ":")
+ for i := 1; i < len(ipParts); i++ {
+ prefix := strings.Join(ipParts[:i], ":")
+ counts[prefix] += 1
+ changes[prefix] = counts[prefix]
+ }
+ return changes
+}
+
+// TODO Test
+func CheckIPRangeCounts(counts map[string]int, limits []int) bool {
+ for prefix, count := range counts {
+ ipParts := strings.Split(prefix, ":")
+ numParts := len(ipParts)
+ if limits[numParts] < count {
+ return false
+ }
+ }
+ return true
+}
diff --git a/listener.go b/listener.go
new file mode 100644
index 000000000..51beb5e27
--- /dev/null
+++ b/listener.go
@@ -0,0 +1,217 @@
+package p2p
+
+import (
+ "fmt"
+ "net"
+ "strconv"
+ "time"
+
+ . "github.com/tendermint/tmlibs/common"
+ "github.com/tendermint/go-p2p/upnp"
+)
+
+type Listener interface {
+ Connections() <-chan net.Conn
+ InternalAddress() *NetAddress
+ ExternalAddress() *NetAddress
+ String() string
+ Stop() bool
+}
+
+// Implements Listener
+type DefaultListener struct {
+ BaseService
+
+ listener net.Listener
+ intAddr *NetAddress
+ extAddr *NetAddress
+ connections chan net.Conn
+}
+
+const (
+ numBufferedConnections = 10
+ defaultExternalPort = 8770
+ tryListenSeconds = 5
+)
+
+func splitHostPort(addr string) (host string, port int) {
+ host, portStr, err := net.SplitHostPort(addr)
+ if err != nil {
+ PanicSanity(err)
+ }
+ port, err = strconv.Atoi(portStr)
+ if err != nil {
+ PanicSanity(err)
+ }
+ return host, port
+}
+
+// skipUPNP: If true, does not try getUPNPExternalAddress()
+func NewDefaultListener(protocol string, lAddr string, skipUPNP bool) Listener {
+ // Local listen IP & port
+ lAddrIP, lAddrPort := splitHostPort(lAddr)
+
+ // Create listener
+ var listener net.Listener
+ var err error
+ for i := 0; i < tryListenSeconds; i++ {
+ listener, err = net.Listen(protocol, lAddr)
+ if err == nil {
+ break
+ } else if i < tryListenSeconds-1 {
+ time.Sleep(time.Second * 1)
+ }
+ }
+ if err != nil {
+ PanicCrisis(err)
+ }
+ // Actual listener local IP & port
+ listenerIP, listenerPort := splitHostPort(listener.Addr().String())
+ log.Info("Local listener", "ip", listenerIP, "port", listenerPort)
+
+ // Determine internal address...
+ var intAddr *NetAddress
+ intAddr, err = NewNetAddressString(lAddr)
+ if err != nil {
+ PanicCrisis(err)
+ }
+
+ // Determine external address...
+ var extAddr *NetAddress
+ if !skipUPNP {
+ // If the lAddrIP is INADDR_ANY, try UPnP
+ if lAddrIP == "" || lAddrIP == "0.0.0.0" {
+ extAddr = getUPNPExternalAddress(lAddrPort, listenerPort)
+ }
+ }
+ // Otherwise just use the local address...
+ if extAddr == nil {
+ extAddr = getNaiveExternalAddress(listenerPort)
+ }
+ if extAddr == nil {
+ PanicCrisis("Could not determine external address!")
+ }
+
+ dl := &DefaultListener{
+ listener: listener,
+ intAddr: intAddr,
+ extAddr: extAddr,
+ connections: make(chan net.Conn, numBufferedConnections),
+ }
+ dl.BaseService = *NewBaseService(log, "DefaultListener", dl)
+ dl.Start() // Started upon construction
+ return dl
+}
+
+func (l *DefaultListener) OnStart() error {
+ l.BaseService.OnStart()
+ go l.listenRoutine()
+ return nil
+}
+
+func (l *DefaultListener) OnStop() {
+ l.BaseService.OnStop()
+ l.listener.Close()
+}
+
+// Accept connections and pass on the channel
+func (l *DefaultListener) listenRoutine() {
+ for {
+ conn, err := l.listener.Accept()
+
+ if !l.IsRunning() {
+ break // Go to cleanup
+ }
+
+ // listener wasn't stopped,
+ // yet we encountered an error.
+ if err != nil {
+ PanicCrisis(err)
+ }
+
+ l.connections <- conn
+ }
+
+ // Cleanup
+ close(l.connections)
+ for _ = range l.connections {
+ // Drain
+ }
+}
+
+// A channel of inbound connections.
+// It gets closed when the listener closes.
+func (l *DefaultListener) Connections() <-chan net.Conn {
+ return l.connections
+}
+
+func (l *DefaultListener) InternalAddress() *NetAddress {
+ return l.intAddr
+}
+
+func (l *DefaultListener) ExternalAddress() *NetAddress {
+ return l.extAddr
+}
+
+// NOTE: The returned listener is already Accept()'ing.
+// So it's not suitable to pass into http.Serve().
+func (l *DefaultListener) NetListener() net.Listener {
+ return l.listener
+}
+
+func (l *DefaultListener) String() string {
+ return fmt.Sprintf("Listener(@%v)", l.extAddr)
+}
+
+/* external address helpers */
+
+// UPNP external address discovery & port mapping
+func getUPNPExternalAddress(externalPort, internalPort int) *NetAddress {
+ log.Info("Getting UPNP external address")
+ nat, err := upnp.Discover()
+ if err != nil {
+ log.Info("Could not perform UPNP discover", "error", err)
+ return nil
+ }
+
+ ext, err := nat.GetExternalAddress()
+ if err != nil {
+ log.Info("Could not get UPNP external address", "error", err)
+ return nil
+ }
+
+ // UPnP can't seem to get the external port, so let's just be explicit.
+ if externalPort == 0 {
+ externalPort = defaultExternalPort
+ }
+
+ externalPort, err = nat.AddPortMapping("tcp", externalPort, internalPort, "tendermint", 0)
+ if err != nil {
+ log.Info("Could not add UPNP port mapping", "error", err)
+ return nil
+ }
+
+ log.Info("Got UPNP external address", "address", ext)
+ return NewNetAddressIPPort(ext, uint16(externalPort))
+}
+
+// TODO: use syscalls: http://pastebin.com/9exZG4rh
+func getNaiveExternalAddress(port int) *NetAddress {
+ addrs, err := net.InterfaceAddrs()
+ if err != nil {
+ PanicCrisis(Fmt("Could not fetch interface addresses: %v", err))
+ }
+
+ for _, a := range addrs {
+ ipnet, ok := a.(*net.IPNet)
+ if !ok {
+ continue
+ }
+ v4 := ipnet.IP.To4()
+ if v4 == nil || v4[0] == 127 {
+ continue
+ } // loopback
+ return NewNetAddressIPPort(ipnet.IP, uint16(port))
+ }
+ return nil
+}
diff --git a/listener_test.go b/listener_test.go
new file mode 100644
index 000000000..0f8a54946
--- /dev/null
+++ b/listener_test.go
@@ -0,0 +1,40 @@
+package p2p
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestListener(t *testing.T) {
+ // Create a listener
+ l := NewDefaultListener("tcp", ":8001", true)
+
+ // Dial the listener
+ lAddr := l.ExternalAddress()
+ connOut, err := lAddr.Dial()
+ if err != nil {
+ t.Fatalf("Could not connect to listener address %v", lAddr)
+ } else {
+ t.Logf("Created a connection to listener address %v", lAddr)
+ }
+ connIn, ok := <-l.Connections()
+ if !ok {
+ t.Fatalf("Could not get inbound connection from listener")
+ }
+
+ msg := []byte("hi!")
+ go connIn.Write(msg)
+ b := make([]byte, 32)
+ n, err := connOut.Read(b)
+ if err != nil {
+ t.Fatalf("Error reading off connection: %v", err)
+ }
+
+ b = b[:n]
+ if !bytes.Equal(msg, b) {
+ t.Fatalf("Got %s, expected %s", b, msg)
+ }
+
+ // Close the server, no longer needed.
+ l.Stop()
+}
diff --git a/log.go b/log.go
new file mode 100644
index 000000000..af3203409
--- /dev/null
+++ b/log.go
@@ -0,0 +1,7 @@
+package p2p
+
+import (
+ "github.com/tendermint/tmlibs/logger"
+)
+
+var log = logger.New("module", "p2p")
diff --git a/netaddress.go b/netaddress.go
new file mode 100644
index 000000000..09787481c
--- /dev/null
+++ b/netaddress.go
@@ -0,0 +1,253 @@
+// Modified for Tendermint
+// Originally Copyright (c) 2013-2014 Conformal Systems LLC.
+// https://github.com/conformal/btcd/blob/master/LICENSE
+
+package p2p
+
+import (
+ "errors"
+ "flag"
+ "net"
+ "strconv"
+ "time"
+
+ cmn "github.com/tendermint/tmlibs/common"
+)
+
+// NetAddress defines information about a peer on the network
+// including its IP address, and port.
+type NetAddress struct {
+ IP net.IP
+ Port uint16
+ str string
+}
+
+// NewNetAddress returns a new NetAddress using the provided TCP
+// address. When testing, other net.Addr (except TCP) will result in
+// using 0.0.0.0:0. When normal run, other net.Addr (except TCP) will
+// panic.
+// TODO: socks proxies?
+func NewNetAddress(addr net.Addr) *NetAddress {
+ tcpAddr, ok := addr.(*net.TCPAddr)
+ if !ok {
+ if flag.Lookup("test.v") == nil { // normal run
+ cmn.PanicSanity(cmn.Fmt("Only TCPAddrs are supported. Got: %v", addr))
+ } else { // in testing
+ return NewNetAddressIPPort(net.IP("0.0.0.0"), 0)
+ }
+ }
+ ip := tcpAddr.IP
+ port := uint16(tcpAddr.Port)
+ return NewNetAddressIPPort(ip, port)
+}
+
+// NewNetAddressString returns a new NetAddress using the provided
+// address in the form of "IP:Port". Also resolves the host if host
+// is not an IP.
+func NewNetAddressString(addr string) (*NetAddress, error) {
+
+ host, portStr, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+
+ ip := net.ParseIP(host)
+ if ip == nil {
+ if len(host) > 0 {
+ ips, err := net.LookupIP(host)
+ if err != nil {
+ return nil, err
+ }
+ ip = ips[0]
+ }
+ }
+
+ port, err := strconv.ParseUint(portStr, 10, 16)
+ if err != nil {
+ return nil, err
+ }
+
+ na := NewNetAddressIPPort(ip, uint16(port))
+ return na, nil
+}
+
+// NewNetAddressStrings returns an array of NetAddress'es build using
+// the provided strings.
+func NewNetAddressStrings(addrs []string) ([]*NetAddress, error) {
+ netAddrs := make([]*NetAddress, len(addrs))
+ for i, addr := range addrs {
+ netAddr, err := NewNetAddressString(addr)
+ if err != nil {
+ return nil, errors.New(cmn.Fmt("Error in address %s: %v", addr, err))
+ }
+ netAddrs[i] = netAddr
+ }
+ return netAddrs, nil
+}
+
+// NewNetAddressIPPort returns a new NetAddress using the provided IP
+// and port number.
+func NewNetAddressIPPort(ip net.IP, port uint16) *NetAddress {
+ na := &NetAddress{
+ IP: ip,
+ Port: port,
+ str: net.JoinHostPort(
+ ip.String(),
+ strconv.FormatUint(uint64(port), 10),
+ ),
+ }
+ return na
+}
+
+// Equals reports whether na and other are the same addresses.
+func (na *NetAddress) Equals(other interface{}) bool {
+ if o, ok := other.(*NetAddress); ok {
+ return na.String() == o.String()
+ }
+
+ return false
+}
+
+func (na *NetAddress) Less(other interface{}) bool {
+ if o, ok := other.(*NetAddress); ok {
+ return na.String() < o.String()
+ }
+
+ cmn.PanicSanity("Cannot compare unequal types")
+ return false
+}
+
+// String representation.
+func (na *NetAddress) String() string {
+ if na.str == "" {
+ na.str = net.JoinHostPort(
+ na.IP.String(),
+ strconv.FormatUint(uint64(na.Port), 10),
+ )
+ }
+ return na.str
+}
+
+// Dial calls net.Dial on the address.
+func (na *NetAddress) Dial() (net.Conn, error) {
+ conn, err := net.Dial("tcp", na.String())
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+// DialTimeout calls net.DialTimeout on the address.
+func (na *NetAddress) DialTimeout(timeout time.Duration) (net.Conn, error) {
+ conn, err := net.DialTimeout("tcp", na.String(), timeout)
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+// Routable returns true if the address is routable.
+func (na *NetAddress) Routable() bool {
+ // TODO(oga) bitcoind doesn't include RFC3849 here, but should we?
+ return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() ||
+ na.RFC4193() || na.RFC4843() || na.Local())
+}
+
+// For IPv4 these are either a 0 or all bits set address. For IPv6 a zero
+// address or one that matches the RFC3849 documentation address format.
+func (na *NetAddress) Valid() bool {
+ return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() ||
+ na.IP.Equal(net.IPv4bcast))
+}
+
+// Local returns true if it is a local address.
+func (na *NetAddress) Local() bool {
+ return na.IP.IsLoopback() || zero4.Contains(na.IP)
+}
+
+// ReachabilityTo checks whenever o can be reached from na.
+func (na *NetAddress) ReachabilityTo(o *NetAddress) int {
+ const (
+ Unreachable = 0
+ Default = iota
+ Teredo
+ Ipv6_weak
+ Ipv4
+ Ipv6_strong
+ Private
+ )
+ if !na.Routable() {
+ return Unreachable
+ } else if na.RFC4380() {
+ if !o.Routable() {
+ return Default
+ } else if o.RFC4380() {
+ return Teredo
+ } else if o.IP.To4() != nil {
+ return Ipv4
+ } else { // ipv6
+ return Ipv6_weak
+ }
+ } else if na.IP.To4() != nil {
+ if o.Routable() && o.IP.To4() != nil {
+ return Ipv4
+ }
+ return Default
+ } else /* ipv6 */ {
+ var tunnelled bool
+ // Is our v6 is tunnelled?
+ if o.RFC3964() || o.RFC6052() || o.RFC6145() {
+ tunnelled = true
+ }
+ if !o.Routable() {
+ return Default
+ } else if o.RFC4380() {
+ return Teredo
+ } else if o.IP.To4() != nil {
+ return Ipv4
+ } else if tunnelled {
+ // only prioritise ipv6 if we aren't tunnelling it.
+ return Ipv6_weak
+ }
+ return Ipv6_strong
+ }
+}
+
+// RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12)
+// RFC3849: IPv6 Documentation address (2001:0DB8::/32)
+// RFC3927: IPv4 Autoconfig (169.254.0.0/16)
+// RFC3964: IPv6 6to4 (2002::/16)
+// RFC4193: IPv6 unique local (FC00::/7)
+// RFC4380: IPv6 Teredo tunneling (2001::/32)
+// RFC4843: IPv6 ORCHID: (2001:10::/28)
+// RFC4862: IPv6 Autoconfig (FE80::/64)
+// RFC6052: IPv6 well known prefix (64:FF9B::/96)
+// RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96
+var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}
+var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}
+var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)}
+var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)}
+var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)}
+var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)}
+var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)}
+var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)}
+var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)}
+var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)}
+var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)}
+var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)}
+var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)}
+
+func (na *NetAddress) RFC1918() bool {
+ return rfc1918_10.Contains(na.IP) ||
+ rfc1918_192.Contains(na.IP) ||
+ rfc1918_172.Contains(na.IP)
+}
+func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) }
+func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) }
+func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) }
+func (na *NetAddress) RFC4193() bool { return rfc4193.Contains(na.IP) }
+func (na *NetAddress) RFC4380() bool { return rfc4380.Contains(na.IP) }
+func (na *NetAddress) RFC4843() bool { return rfc4843.Contains(na.IP) }
+func (na *NetAddress) RFC4862() bool { return rfc4862.Contains(na.IP) }
+func (na *NetAddress) RFC6052() bool { return rfc6052.Contains(na.IP) }
+func (na *NetAddress) RFC6145() bool { return rfc6145.Contains(na.IP) }
diff --git a/netaddress_test.go b/netaddress_test.go
new file mode 100644
index 000000000..db871fdec
--- /dev/null
+++ b/netaddress_test.go
@@ -0,0 +1,113 @@
+package p2p
+
+import (
+ "net"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestNewNetAddress(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
+ require.Nil(err)
+ addr := NewNetAddress(tcpAddr)
+
+ assert.Equal("127.0.0.1:8080", addr.String())
+
+ assert.NotPanics(func() {
+ NewNetAddress(&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8000})
+ }, "Calling NewNetAddress with UDPAddr should not panic in testing")
+}
+
+func TestNewNetAddressString(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ tests := []struct {
+ addr string
+ correct bool
+ }{
+ {"127.0.0.1:8080", true},
+ {"127.0.0:8080", false},
+ {"a", false},
+ {"127.0.0.1:a", false},
+ {"a:8080", false},
+ {"8082", false},
+ {"127.0.0:8080000", false},
+ }
+
+ for _, t := range tests {
+ addr, err := NewNetAddressString(t.addr)
+ if t.correct {
+ require.Nil(err)
+ assert.Equal(t.addr, addr.String())
+ } else {
+ require.NotNil(err)
+ }
+ }
+}
+
+func TestNewNetAddressStrings(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+ addrs, err := NewNetAddressStrings([]string{"127.0.0.1:8080", "127.0.0.2:8080"})
+ require.Nil(err)
+
+ assert.Equal(2, len(addrs))
+}
+
+func TestNewNetAddressIPPort(t *testing.T) {
+ assert := assert.New(t)
+ addr := NewNetAddressIPPort(net.ParseIP("127.0.0.1"), 8080)
+
+ assert.Equal("127.0.0.1:8080", addr.String())
+}
+
+func TestNetAddressProperties(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ // TODO add more test cases
+ tests := []struct {
+ addr string
+ valid bool
+ local bool
+ routable bool
+ }{
+ {"127.0.0.1:8080", true, true, false},
+ {"ya.ru:80", true, false, true},
+ }
+
+ for _, t := range tests {
+ addr, err := NewNetAddressString(t.addr)
+ require.Nil(err)
+
+ assert.Equal(t.valid, addr.Valid())
+ assert.Equal(t.local, addr.Local())
+ assert.Equal(t.routable, addr.Routable())
+ }
+}
+
+func TestNetAddressReachabilityTo(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ // TODO add more test cases
+ tests := []struct {
+ addr string
+ other string
+ reachability int
+ }{
+ {"127.0.0.1:8080", "127.0.0.1:8081", 0},
+ {"ya.ru:80", "127.0.0.1:8080", 1},
+ }
+
+ for _, t := range tests {
+ addr, err := NewNetAddressString(t.addr)
+ require.Nil(err)
+
+ other, err := NewNetAddressString(t.other)
+ require.Nil(err)
+
+ assert.Equal(t.reachability, addr.ReachabilityTo(other))
+ }
+}
diff --git a/peer.go b/peer.go
new file mode 100644
index 000000000..bf819087d
--- /dev/null
+++ b/peer.go
@@ -0,0 +1,304 @@
+package p2p
+
+import (
+ "fmt"
+ "io"
+ "net"
+ "time"
+
+ "github.com/pkg/errors"
+ crypto "github.com/tendermint/go-crypto"
+ wire "github.com/tendermint/go-wire"
+ cmn "github.com/tendermint/tmlibs/common"
+)
+
+// Peer could be marked as persistent, in which case you can use
+// Redial function to reconnect. Note that inbound peers can't be
+// made persistent. They should be made persistent on the other end.
+//
+// Before using a peer, you will need to perform a handshake on connection.
+type Peer struct {
+ cmn.BaseService
+
+ outbound bool
+
+ conn net.Conn // source connection
+ mconn *MConnection // multiplex connection
+
+ persistent bool
+ config *PeerConfig
+
+ *NodeInfo
+ Key string
+ Data *cmn.CMap // User data.
+}
+
+// PeerConfig is a Peer configuration.
+type PeerConfig struct {
+ AuthEnc bool // authenticated encryption
+
+ HandshakeTimeout time.Duration
+ DialTimeout time.Duration
+
+ MConfig *MConnConfig
+
+ Fuzz bool // fuzz connection (for testing)
+ FuzzConfig *FuzzConnConfig
+}
+
+// DefaultPeerConfig returns the default config.
+func DefaultPeerConfig() *PeerConfig {
+ return &PeerConfig{
+ AuthEnc: true,
+ HandshakeTimeout: 2 * time.Second,
+ DialTimeout: 3 * time.Second,
+ MConfig: DefaultMConnConfig(),
+ Fuzz: false,
+ FuzzConfig: DefaultFuzzConnConfig(),
+ }
+}
+
+func newOutboundPeer(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) {
+ return newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig())
+}
+
+func newOutboundPeerWithConfig(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) {
+ conn, err := dial(addr, config)
+ if err != nil {
+ return nil, errors.Wrap(err, "Error creating peer")
+ }
+
+ peer, err := newPeerFromConnAndConfig(conn, true, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config)
+ if err != nil {
+ conn.Close()
+ return nil, err
+ }
+ return peer, nil
+}
+
+func newInboundPeer(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) {
+ return newInboundPeerWithConfig(conn, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig())
+}
+
+func newInboundPeerWithConfig(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) {
+ return newPeerFromConnAndConfig(conn, false, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config)
+}
+
+func newPeerFromConnAndConfig(rawConn net.Conn, outbound bool, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) {
+ conn := rawConn
+
+ // Fuzz connection
+ if config.Fuzz {
+ // so we have time to do peer handshakes and get set up
+ conn = FuzzConnAfterFromConfig(conn, 10*time.Second, config.FuzzConfig)
+ }
+
+ // Encrypt connection
+ if config.AuthEnc {
+ conn.SetDeadline(time.Now().Add(config.HandshakeTimeout))
+
+ var err error
+ conn, err = MakeSecretConnection(conn, ourNodePrivKey)
+ if err != nil {
+ return nil, errors.Wrap(err, "Error creating peer")
+ }
+ }
+
+ // Key and NodeInfo are set after Handshake
+ p := &Peer{
+ outbound: outbound,
+ conn: conn,
+ config: config,
+ Data: cmn.NewCMap(),
+ }
+
+ p.mconn = createMConnection(conn, p, reactorsByCh, chDescs, onPeerError, config.MConfig)
+
+ p.BaseService = *cmn.NewBaseService(log, "Peer", p)
+
+ return p, nil
+}
+
+// CloseConn should be used when the peer was created, but never started.
+func (p *Peer) CloseConn() {
+ p.conn.Close()
+}
+
+// makePersistent marks the peer as persistent.
+func (p *Peer) makePersistent() {
+ if !p.outbound {
+ panic("inbound peers can't be made persistent")
+ }
+
+ p.persistent = true
+}
+
+// IsPersistent returns true if the peer is persitent, false otherwise.
+func (p *Peer) IsPersistent() bool {
+ return p.persistent
+}
+
+// HandshakeTimeout performs a handshake between a given node and the peer.
+// NOTE: blocking
+func (p *Peer) HandshakeTimeout(ourNodeInfo *NodeInfo, timeout time.Duration) error {
+ // Set deadline for handshake so we don't block forever on conn.ReadFull
+ p.conn.SetDeadline(time.Now().Add(timeout))
+
+ var peerNodeInfo = new(NodeInfo)
+ var err1 error
+ var err2 error
+ cmn.Parallel(
+ func() {
+ var n int
+ wire.WriteBinary(ourNodeInfo, p.conn, &n, &err1)
+ },
+ func() {
+ var n int
+ wire.ReadBinary(peerNodeInfo, p.conn, maxNodeInfoSize, &n, &err2)
+ log.Notice("Peer handshake", "peerNodeInfo", peerNodeInfo)
+ })
+ if err1 != nil {
+ return errors.Wrap(err1, "Error during handshake/write")
+ }
+ if err2 != nil {
+ return errors.Wrap(err2, "Error during handshake/read")
+ }
+
+ if p.config.AuthEnc {
+ // Check that the professed PubKey matches the sconn's.
+ if !peerNodeInfo.PubKey.Equals(p.PubKey().Wrap()) {
+ return fmt.Errorf("Ignoring connection with unmatching pubkey: %v vs %v",
+ peerNodeInfo.PubKey, p.PubKey())
+ }
+ }
+
+ // Remove deadline
+ p.conn.SetDeadline(time.Time{})
+
+ peerNodeInfo.RemoteAddr = p.Addr().String()
+
+ p.NodeInfo = peerNodeInfo
+ p.Key = peerNodeInfo.PubKey.KeyString()
+
+ return nil
+}
+
+// Addr returns peer's network address.
+func (p *Peer) Addr() net.Addr {
+ return p.conn.RemoteAddr()
+}
+
+// PubKey returns peer's public key.
+func (p *Peer) PubKey() crypto.PubKeyEd25519 {
+ if p.config.AuthEnc {
+ return p.conn.(*SecretConnection).RemotePubKey()
+ }
+ if p.NodeInfo == nil {
+ panic("Attempt to get peer's PubKey before calling Handshake")
+ }
+ return p.PubKey()
+}
+
+// OnStart implements BaseService.
+func (p *Peer) OnStart() error {
+ p.BaseService.OnStart()
+ _, err := p.mconn.Start()
+ return err
+}
+
+// OnStop implements BaseService.
+func (p *Peer) OnStop() {
+ p.BaseService.OnStop()
+ p.mconn.Stop()
+}
+
+// Connection returns underlying MConnection.
+func (p *Peer) Connection() *MConnection {
+ return p.mconn
+}
+
+// IsOutbound returns true if the connection is outbound, false otherwise.
+func (p *Peer) IsOutbound() bool {
+ return p.outbound
+}
+
+// Send msg to the channel identified by chID byte. Returns false if the send
+// queue is full after timeout, specified by MConnection.
+func (p *Peer) Send(chID byte, msg interface{}) bool {
+ if !p.IsRunning() {
+ // see Switch#Broadcast, where we fetch the list of peers and loop over
+ // them - while we're looping, one peer may be removed and stopped.
+ return false
+ }
+ return p.mconn.Send(chID, msg)
+}
+
+// TrySend msg to the channel identified by chID byte. Immediately returns
+// false if the send queue is full.
+func (p *Peer) TrySend(chID byte, msg interface{}) bool {
+ if !p.IsRunning() {
+ return false
+ }
+ return p.mconn.TrySend(chID, msg)
+}
+
+// CanSend returns true if the send queue is not full, false otherwise.
+func (p *Peer) CanSend(chID byte) bool {
+ if !p.IsRunning() {
+ return false
+ }
+ return p.mconn.CanSend(chID)
+}
+
+// WriteTo writes the peer's public key to w.
+func (p *Peer) WriteTo(w io.Writer) (n int64, err error) {
+ var n_ int
+ wire.WriteString(p.Key, w, &n_, &err)
+ n += int64(n_)
+ return
+}
+
+// String representation.
+func (p *Peer) String() string {
+ if p.outbound {
+ return fmt.Sprintf("Peer{%v %v out}", p.mconn, p.Key[:12])
+ }
+
+ return fmt.Sprintf("Peer{%v %v in}", p.mconn, p.Key[:12])
+}
+
+// Equals reports whenever 2 peers are actually represent the same node.
+func (p *Peer) Equals(other *Peer) bool {
+ return p.Key == other.Key
+}
+
+// Get the data for a given key.
+func (p *Peer) Get(key string) interface{} {
+ return p.Data.Get(key)
+}
+
+func dial(addr *NetAddress, config *PeerConfig) (net.Conn, error) {
+ log.Info("Dialing address", "address", addr)
+ conn, err := addr.DialTimeout(config.DialTimeout)
+ if err != nil {
+ log.Info("Failed dialing address", "address", addr, "error", err)
+ return nil, err
+ }
+ return conn, nil
+}
+
+func createMConnection(conn net.Conn, p *Peer, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), config *MConnConfig) *MConnection {
+ onReceive := func(chID byte, msgBytes []byte) {
+ reactor := reactorsByCh[chID]
+ if reactor == nil {
+ cmn.PanicSanity(cmn.Fmt("Unknown channel %X", chID))
+ }
+ reactor.Receive(chID, p, msgBytes)
+ }
+
+ onError := func(r interface{}) {
+ onPeerError(p, r)
+ }
+
+ return NewMConnectionWithConfig(conn, chDescs, onReceive, onError, config)
+}
diff --git a/peer_set.go b/peer_set.go
new file mode 100644
index 000000000..f3bc1edaf
--- /dev/null
+++ b/peer_set.go
@@ -0,0 +1,115 @@
+package p2p
+
+import (
+ "sync"
+)
+
+// IPeerSet has a (immutable) subset of the methods of PeerSet.
+type IPeerSet interface {
+ Has(key string) bool
+ Get(key string) *Peer
+ List() []*Peer
+ Size() int
+}
+
+//-----------------------------------------------------------------------------
+
+// PeerSet is a special structure for keeping a table of peers.
+// Iteration over the peers is super fast and thread-safe.
+// We also track how many peers per IP range and avoid too many
+type PeerSet struct {
+ mtx sync.Mutex
+ lookup map[string]*peerSetItem
+ list []*Peer
+}
+
+type peerSetItem struct {
+ peer *Peer
+ index int
+}
+
+func NewPeerSet() *PeerSet {
+ return &PeerSet{
+ lookup: make(map[string]*peerSetItem),
+ list: make([]*Peer, 0, 256),
+ }
+}
+
+// Returns false if peer with key (PubKeyEd25519) is already in set
+// or if we have too many peers from the peer's IP range
+func (ps *PeerSet) Add(peer *Peer) error {
+ ps.mtx.Lock()
+ defer ps.mtx.Unlock()
+ if ps.lookup[peer.Key] != nil {
+ return ErrSwitchDuplicatePeer
+ }
+
+ index := len(ps.list)
+ // Appending is safe even with other goroutines
+ // iterating over the ps.list slice.
+ ps.list = append(ps.list, peer)
+ ps.lookup[peer.Key] = &peerSetItem{peer, index}
+ return nil
+}
+
+func (ps *PeerSet) Has(peerKey string) bool {
+ ps.mtx.Lock()
+ defer ps.mtx.Unlock()
+ _, ok := ps.lookup[peerKey]
+ return ok
+}
+
+func (ps *PeerSet) Get(peerKey string) *Peer {
+ ps.mtx.Lock()
+ defer ps.mtx.Unlock()
+ item, ok := ps.lookup[peerKey]
+ if ok {
+ return item.peer
+ } else {
+ return nil
+ }
+}
+
+func (ps *PeerSet) Remove(peer *Peer) {
+ ps.mtx.Lock()
+ defer ps.mtx.Unlock()
+ item := ps.lookup[peer.Key]
+ if item == nil {
+ return
+ }
+
+ index := item.index
+ // Copy the list but without the last element.
+ // (we must copy because we're mutating the list)
+ newList := make([]*Peer, len(ps.list)-1)
+ copy(newList, ps.list)
+ // If it's the last peer, that's an easy special case.
+ if index == len(ps.list)-1 {
+ ps.list = newList
+ delete(ps.lookup, peer.Key)
+ return
+ }
+
+ // Move the last item from ps.list to "index" in list.
+ lastPeer := ps.list[len(ps.list)-1]
+ lastPeerKey := lastPeer.Key
+ lastPeerItem := ps.lookup[lastPeerKey]
+ newList[index] = lastPeer
+ lastPeerItem.index = index
+ ps.list = newList
+ delete(ps.lookup, peer.Key)
+
+}
+
+func (ps *PeerSet) Size() int {
+ ps.mtx.Lock()
+ defer ps.mtx.Unlock()
+ return len(ps.list)
+}
+
+// threadsafe list of peers.
+func (ps *PeerSet) List() []*Peer {
+ ps.mtx.Lock()
+ defer ps.mtx.Unlock()
+ return ps.list
+}
diff --git a/peer_set_test.go b/peer_set_test.go
new file mode 100644
index 000000000..a17f9d658
--- /dev/null
+++ b/peer_set_test.go
@@ -0,0 +1,67 @@
+package p2p
+
+import (
+ "math/rand"
+ "testing"
+
+ . "github.com/tendermint/tmlibs/common"
+)
+
+// Returns an empty dummy peer
+func randPeer() *Peer {
+ return &Peer{
+ Key: RandStr(12),
+ NodeInfo: &NodeInfo{
+ RemoteAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256),
+ ListenAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256),
+ },
+ }
+}
+
+func TestAddRemoveOne(t *testing.T) {
+ peerSet := NewPeerSet()
+
+ peer := randPeer()
+ err := peerSet.Add(peer)
+ if err != nil {
+ t.Errorf("Failed to add new peer")
+ }
+ if peerSet.Size() != 1 {
+ t.Errorf("Failed to add new peer and increment size")
+ }
+
+ peerSet.Remove(peer)
+ if peerSet.Has(peer.Key) {
+ t.Errorf("Failed to remove peer")
+ }
+ if peerSet.Size() != 0 {
+ t.Errorf("Failed to remove peer and decrement size")
+ }
+}
+
+func TestAddRemoveMany(t *testing.T) {
+ peerSet := NewPeerSet()
+
+ peers := []*Peer{}
+ N := 100
+ for i := 0; i < N; i++ {
+ peer := randPeer()
+ if err := peerSet.Add(peer); err != nil {
+ t.Errorf("Failed to add new peer")
+ }
+ if peerSet.Size() != i+1 {
+ t.Errorf("Failed to add new peer and increment size")
+ }
+ peers = append(peers, peer)
+ }
+
+ for i, peer := range peers {
+ peerSet.Remove(peer)
+ if peerSet.Has(peer.Key) {
+ t.Errorf("Failed to remove peer")
+ }
+ if peerSet.Size() != len(peers)-i-1 {
+ t.Errorf("Failed to remove peer and decrement size")
+ }
+ }
+}
diff --git a/peer_test.go b/peer_test.go
new file mode 100644
index 000000000..0ac776347
--- /dev/null
+++ b/peer_test.go
@@ -0,0 +1,156 @@
+package p2p
+
+import (
+ golog "log"
+ "net"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ crypto "github.com/tendermint/go-crypto"
+)
+
+func TestPeerBasic(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ // simulate remote peer
+ rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()}
+ rp.Start()
+ defer rp.Stop()
+
+ p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), DefaultPeerConfig())
+ require.Nil(err)
+
+ p.Start()
+ defer p.Stop()
+
+ assert.True(p.IsRunning())
+ assert.True(p.IsOutbound())
+ assert.False(p.IsPersistent())
+ p.makePersistent()
+ assert.True(p.IsPersistent())
+ assert.Equal(rp.Addr().String(), p.Addr().String())
+ assert.Equal(rp.PubKey(), p.PubKey())
+}
+
+func TestPeerWithoutAuthEnc(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ config := DefaultPeerConfig()
+ config.AuthEnc = false
+
+ // simulate remote peer
+ rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config}
+ rp.Start()
+ defer rp.Stop()
+
+ p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config)
+ require.Nil(err)
+
+ p.Start()
+ defer p.Stop()
+
+ assert.True(p.IsRunning())
+}
+
+func TestPeerSend(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ config := DefaultPeerConfig()
+ config.AuthEnc = false
+
+ // simulate remote peer
+ rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config}
+ rp.Start()
+ defer rp.Stop()
+
+ p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config)
+ require.Nil(err)
+
+ p.Start()
+ defer p.Stop()
+
+ assert.True(p.CanSend(0x01))
+ assert.True(p.Send(0x01, "Asylum"))
+}
+
+func createOutboundPeerAndPerformHandshake(addr *NetAddress, config *PeerConfig) (*Peer, error) {
+ chDescs := []*ChannelDescriptor{
+ &ChannelDescriptor{ID: 0x01, Priority: 1},
+ }
+ reactorsByCh := map[byte]Reactor{0x01: NewTestReactor(chDescs, true)}
+ pk := crypto.GenPrivKeyEd25519()
+ p, err := newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, func(p *Peer, r interface{}) {}, pk, config)
+ if err != nil {
+ return nil, err
+ }
+ err = p.HandshakeTimeout(&NodeInfo{
+ PubKey: pk.PubKey().Unwrap().(crypto.PubKeyEd25519),
+ Moniker: "host_peer",
+ Network: "testing",
+ Version: "123.123.123",
+ }, 1*time.Second)
+ if err != nil {
+ return nil, err
+ }
+ return p, nil
+}
+
+type remotePeer struct {
+ PrivKey crypto.PrivKeyEd25519
+ Config *PeerConfig
+ addr *NetAddress
+ quit chan struct{}
+}
+
+func (p *remotePeer) Addr() *NetAddress {
+ return p.addr
+}
+
+func (p *remotePeer) PubKey() crypto.PubKeyEd25519 {
+ return p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
+}
+
+func (p *remotePeer) Start() {
+ l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
+ if e != nil {
+ golog.Fatalf("net.Listen tcp :0: %+v", e)
+ }
+ p.addr = NewNetAddress(l.Addr())
+ p.quit = make(chan struct{})
+ go p.accept(l)
+}
+
+func (p *remotePeer) Stop() {
+ close(p.quit)
+}
+
+func (p *remotePeer) accept(l net.Listener) {
+ for {
+ conn, err := l.Accept()
+ if err != nil {
+ golog.Fatalf("Failed to accept conn: %+v", err)
+ }
+ peer, err := newInboundPeerWithConfig(conn, make(map[byte]Reactor), make([]*ChannelDescriptor, 0), func(p *Peer, r interface{}) {}, p.PrivKey, p.Config)
+ if err != nil {
+ golog.Fatalf("Failed to create a peer: %+v", err)
+ }
+ err = peer.HandshakeTimeout(&NodeInfo{
+ PubKey: p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519),
+ Moniker: "remote_peer",
+ Network: "testing",
+ Version: "123.123.123",
+ }, 1*time.Second)
+ if err != nil {
+ golog.Fatalf("Failed to perform handshake: %+v", err)
+ }
+ select {
+ case <-p.quit:
+ conn.Close()
+ return
+ default:
+ }
+ }
+}
diff --git a/pex_reactor.go b/pex_reactor.go
new file mode 100644
index 000000000..03a383c85
--- /dev/null
+++ b/pex_reactor.go
@@ -0,0 +1,358 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "math/rand"
+ "reflect"
+ "time"
+
+ cmn "github.com/tendermint/tmlibs/common"
+ wire "github.com/tendermint/go-wire"
+)
+
+const (
+ // PexChannel is a channel for PEX messages
+ PexChannel = byte(0x00)
+
+ // period to ensure peers connected
+ defaultEnsurePeersPeriod = 30 * time.Second
+ minNumOutboundPeers = 10
+ maxPexMessageSize = 1048576 // 1MB
+
+ // maximum messages one peer can send to us during `msgCountByPeerFlushInterval`
+ defaultMaxMsgCountByPeer = 1000
+ msgCountByPeerFlushInterval = 1 * time.Hour
+)
+
+// PEXReactor handles PEX (peer exchange) and ensures that an
+// adequate number of peers are connected to the switch.
+//
+// It uses `AddrBook` (address book) to store `NetAddress`es of the peers.
+//
+// ## Preventing abuse
+//
+// For now, it just limits the number of messages from one peer to
+// `defaultMaxMsgCountByPeer` messages per `msgCountByPeerFlushInterval` (1000
+// msg/hour).
+//
+// NOTE [2017-01-17]:
+// Limiting is fine for now. Maybe down the road we want to keep track of the
+// quality of peer messages so if peerA keeps telling us about peers we can't
+// connect to then maybe we should care less about peerA. But I don't think
+// that kind of complexity is priority right now.
+type PEXReactor struct {
+ BaseReactor
+
+ sw *Switch
+ book *AddrBook
+ ensurePeersPeriod time.Duration
+
+ // tracks message count by peer, so we can prevent abuse
+ msgCountByPeer *cmn.CMap
+ maxMsgCountByPeer uint16
+}
+
+// NewPEXReactor creates new PEX reactor.
+func NewPEXReactor(b *AddrBook) *PEXReactor {
+ r := &PEXReactor{
+ book: b,
+ ensurePeersPeriod: defaultEnsurePeersPeriod,
+ msgCountByPeer: cmn.NewCMap(),
+ maxMsgCountByPeer: defaultMaxMsgCountByPeer,
+ }
+ r.BaseReactor = *NewBaseReactor(log, "PEXReactor", r)
+ return r
+}
+
+// OnStart implements BaseService
+func (r *PEXReactor) OnStart() error {
+ r.BaseReactor.OnStart()
+ r.book.Start()
+ go r.ensurePeersRoutine()
+ go r.flushMsgCountByPeer()
+ return nil
+}
+
+// OnStop implements BaseService
+func (r *PEXReactor) OnStop() {
+ r.BaseReactor.OnStop()
+ r.book.Stop()
+}
+
+// GetChannels implements Reactor
+func (r *PEXReactor) GetChannels() []*ChannelDescriptor {
+ return []*ChannelDescriptor{
+ &ChannelDescriptor{
+ ID: PexChannel,
+ Priority: 1,
+ SendQueueCapacity: 10,
+ },
+ }
+}
+
+// AddPeer implements Reactor by adding peer to the address book (if inbound)
+// or by requesting more addresses (if outbound).
+func (r *PEXReactor) AddPeer(p *Peer) {
+ if p.IsOutbound() {
+ // For outbound peers, the address is already in the books.
+ // Either it was added in DialSeeds or when we
+ // received the peer's address in r.Receive
+ if r.book.NeedMoreAddrs() {
+ r.RequestPEX(p)
+ }
+ } else { // For inbound connections, the peer is its own source
+ addr, err := NewNetAddressString(p.ListenAddr)
+ if err != nil {
+ // this should never happen
+ log.Error("Error in AddPeer: invalid peer address", "addr", p.ListenAddr, "error", err)
+ return
+ }
+ r.book.AddAddress(addr, addr)
+ }
+}
+
+// RemovePeer implements Reactor.
+func (r *PEXReactor) RemovePeer(p *Peer, reason interface{}) {
+ // If we aren't keeping track of local temp data for each peer here, then we
+ // don't have to do anything.
+}
+
+// Receive implements Reactor by handling incoming PEX messages.
+func (r *PEXReactor) Receive(chID byte, src *Peer, msgBytes []byte) {
+ srcAddr := src.Connection().RemoteAddress
+ srcAddrStr := srcAddr.String()
+
+ r.IncrementMsgCountForPeer(srcAddrStr)
+ if r.ReachedMaxMsgCountForPeer(srcAddrStr) {
+ log.Warn("Maximum number of messages reached for peer", "peer", srcAddrStr)
+ // TODO remove src from peers?
+ return
+ }
+
+ _, msg, err := DecodeMessage(msgBytes)
+ if err != nil {
+ log.Warn("Error decoding message", "error", err)
+ return
+ }
+ log.Notice("Received message", "msg", msg)
+
+ switch msg := msg.(type) {
+ case *pexRequestMessage:
+ // src requested some peers.
+ r.SendAddrs(src, r.book.GetSelection())
+ case *pexAddrsMessage:
+ // We received some peer addresses from src.
+ // (We don't want to get spammed with bad peers)
+ for _, addr := range msg.Addrs {
+ if addr != nil {
+ r.book.AddAddress(addr, srcAddr)
+ }
+ }
+ default:
+ log.Warn(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg)))
+ }
+}
+
+// RequestPEX asks peer for more addresses.
+func (r *PEXReactor) RequestPEX(p *Peer) {
+ p.Send(PexChannel, struct{ PexMessage }{&pexRequestMessage{}})
+}
+
+// SendAddrs sends addrs to the peer.
+func (r *PEXReactor) SendAddrs(p *Peer, addrs []*NetAddress) {
+ p.Send(PexChannel, struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}})
+}
+
+// SetEnsurePeersPeriod sets period to ensure peers connected.
+func (r *PEXReactor) SetEnsurePeersPeriod(d time.Duration) {
+ r.ensurePeersPeriod = d
+}
+
+// SetMaxMsgCountByPeer sets maximum messages one peer can send to us during 'msgCountByPeerFlushInterval'.
+func (r *PEXReactor) SetMaxMsgCountByPeer(v uint16) {
+ r.maxMsgCountByPeer = v
+}
+
+// ReachedMaxMsgCountForPeer returns true if we received too many
+// messages from peer with address `addr`.
+// NOTE: assumes the value in the CMap is non-nil
+func (r *PEXReactor) ReachedMaxMsgCountForPeer(addr string) bool {
+ return r.msgCountByPeer.Get(addr).(uint16) >= r.maxMsgCountByPeer
+}
+
+// Increment or initialize the msg count for the peer in the CMap
+func (r *PEXReactor) IncrementMsgCountForPeer(addr string) {
+ var count uint16
+ countI := r.msgCountByPeer.Get(addr)
+ if countI != nil {
+ count = countI.(uint16)
+ }
+ count++
+ r.msgCountByPeer.Set(addr, count)
+}
+
+// Ensures that sufficient peers are connected. (continuous)
+func (r *PEXReactor) ensurePeersRoutine() {
+ // Randomize when routine starts
+ ensurePeersPeriodMs := r.ensurePeersPeriod.Nanoseconds() / 1e6
+ time.Sleep(time.Duration(rand.Int63n(ensurePeersPeriodMs)) * time.Millisecond)
+
+ // fire once immediately.
+ r.ensurePeers()
+
+ // fire periodically
+ ticker := time.NewTicker(r.ensurePeersPeriod)
+
+ for {
+ select {
+ case <-ticker.C:
+ r.ensurePeers()
+ case <-r.Quit:
+ ticker.Stop()
+ return
+ }
+ }
+}
+
+// ensurePeers ensures that sufficient peers are connected. (once)
+//
+// Old bucket / New bucket are arbitrary categories to denote whether an
+// address is vetted or not, and this needs to be determined over time via a
+// heuristic that we haven't perfected yet, or, perhaps is manually edited by
+// the node operator. It should not be used to compute what addresses are
+// already connected or not.
+//
+// TODO Basically, we need to work harder on our good-peer/bad-peer marking.
+// What we're currently doing in terms of marking good/bad peers is just a
+// placeholder. It should not be the case that an address becomes old/vetted
+// upon a single successful connection.
+func (r *PEXReactor) ensurePeers() {
+ numOutPeers, _, numDialing := r.Switch.NumPeers()
+ numToDial := minNumOutboundPeers - (numOutPeers + numDialing)
+ log.Info("Ensure peers", "numOutPeers", numOutPeers, "numDialing", numDialing, "numToDial", numToDial)
+ if numToDial <= 0 {
+ return
+ }
+
+ toDial := make(map[string]*NetAddress)
+
+ // Try to pick numToDial addresses to dial.
+ for i := 0; i < numToDial; i++ {
+ // The purpose of newBias is to first prioritize old (more vetted) peers
+ // when we have few connections, but to allow for new (less vetted) peers
+ // if we already have many connections. This algorithm isn't perfect, but
+ // it somewhat ensures that we prioritize connecting to more-vetted
+ // peers.
+ newBias := cmn.MinInt(numOutPeers, 8)*10 + 10
+ var picked *NetAddress
+ // Try to fetch a new peer 3 times.
+ // This caps the maximum number of tries to 3 * numToDial.
+ for j := 0; j < 3; j++ {
+ try := r.book.PickAddress(newBias)
+ if try == nil {
+ break
+ }
+ _, alreadySelected := toDial[try.IP.String()]
+ alreadyDialing := r.Switch.IsDialing(try)
+ alreadyConnected := r.Switch.Peers().Has(try.IP.String())
+ if alreadySelected || alreadyDialing || alreadyConnected {
+ // log.Info("Cannot dial address", "addr", try,
+ // "alreadySelected", alreadySelected,
+ // "alreadyDialing", alreadyDialing,
+ // "alreadyConnected", alreadyConnected)
+ continue
+ } else {
+ log.Info("Will dial address", "addr", try)
+ picked = try
+ break
+ }
+ }
+ if picked == nil {
+ continue
+ }
+ toDial[picked.IP.String()] = picked
+ }
+
+ // Dial picked addresses
+ for _, item := range toDial {
+ go func(picked *NetAddress) {
+ _, err := r.Switch.DialPeerWithAddress(picked, false)
+ if err != nil {
+ r.book.MarkAttempt(picked)
+ }
+ }(item)
+ }
+
+ // If we need more addresses, pick a random peer and ask for more.
+ if r.book.NeedMoreAddrs() {
+ if peers := r.Switch.Peers().List(); len(peers) > 0 {
+ i := rand.Int() % len(peers)
+ peer := peers[i]
+ log.Info("No addresses to dial. Sending pexRequest to random peer", "peer", peer)
+ r.RequestPEX(peer)
+ }
+ }
+}
+
+func (r *PEXReactor) flushMsgCountByPeer() {
+ ticker := time.NewTicker(msgCountByPeerFlushInterval)
+
+ for {
+ select {
+ case <-ticker.C:
+ r.msgCountByPeer.Clear()
+ case <-r.Quit:
+ ticker.Stop()
+ return
+ }
+ }
+}
+
+//-----------------------------------------------------------------------------
+// Messages
+
+const (
+ msgTypeRequest = byte(0x01)
+ msgTypeAddrs = byte(0x02)
+)
+
+// PexMessage is a primary type for PEX messages. Underneath, it could contain
+// either pexRequestMessage, or pexAddrsMessage messages.
+type PexMessage interface{}
+
+var _ = wire.RegisterInterface(
+ struct{ PexMessage }{},
+ wire.ConcreteType{&pexRequestMessage{}, msgTypeRequest},
+ wire.ConcreteType{&pexAddrsMessage{}, msgTypeAddrs},
+)
+
+// DecodeMessage implements interface registered above.
+func DecodeMessage(bz []byte) (msgType byte, msg PexMessage, err error) {
+ msgType = bz[0]
+ n := new(int)
+ r := bytes.NewReader(bz)
+ msg = wire.ReadBinary(struct{ PexMessage }{}, r, maxPexMessageSize, n, &err).(struct{ PexMessage }).PexMessage
+ return
+}
+
+/*
+A pexRequestMessage requests additional peer addresses.
+*/
+type pexRequestMessage struct {
+}
+
+func (m *pexRequestMessage) String() string {
+ return "[pexRequest]"
+}
+
+/*
+A message with announced peer addresses.
+*/
+type pexAddrsMessage struct {
+ Addrs []*NetAddress
+}
+
+func (m *pexAddrsMessage) String() string {
+ return fmt.Sprintf("[pexAddrs %v]", m.Addrs)
+}
diff --git a/pex_reactor_test.go b/pex_reactor_test.go
new file mode 100644
index 000000000..aed6c758d
--- /dev/null
+++ b/pex_reactor_test.go
@@ -0,0 +1,163 @@
+package p2p
+
+import (
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ cmn "github.com/tendermint/tmlibs/common"
+ wire "github.com/tendermint/go-wire"
+)
+
+func TestPEXReactorBasic(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ dir, err := ioutil.TempDir("", "pex_reactor")
+ require.Nil(err)
+ defer os.RemoveAll(dir)
+ book := NewAddrBook(dir+"addrbook.json", true)
+
+ r := NewPEXReactor(book)
+
+ assert.NotNil(r)
+ assert.NotEmpty(r.GetChannels())
+}
+
+func TestPEXReactorAddRemovePeer(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ dir, err := ioutil.TempDir("", "pex_reactor")
+ require.Nil(err)
+ defer os.RemoveAll(dir)
+ book := NewAddrBook(dir+"addrbook.json", true)
+
+ r := NewPEXReactor(book)
+
+ size := book.Size()
+ peer := createRandomPeer(false)
+
+ r.AddPeer(peer)
+ assert.Equal(size+1, book.Size())
+
+ r.RemovePeer(peer, "peer not available")
+ assert.Equal(size+1, book.Size())
+
+ outboundPeer := createRandomPeer(true)
+
+ r.AddPeer(outboundPeer)
+ assert.Equal(size+1, book.Size(), "outbound peers should not be added to the address book")
+
+ r.RemovePeer(outboundPeer, "peer not available")
+ assert.Equal(size+1, book.Size())
+}
+
+func TestPEXReactorRunning(t *testing.T) {
+ require := require.New(t)
+
+ N := 3
+ switches := make([]*Switch, N)
+
+ dir, err := ioutil.TempDir("", "pex_reactor")
+ require.Nil(err)
+ defer os.RemoveAll(dir)
+ book := NewAddrBook(dir+"addrbook.json", false)
+
+ // create switches
+ for i := 0; i < N; i++ {
+ switches[i] = makeSwitch(i, "127.0.0.1", "123.123.123", func(i int, sw *Switch) *Switch {
+ r := NewPEXReactor(book)
+ r.SetEnsurePeersPeriod(250 * time.Millisecond)
+ sw.AddReactor("pex", r)
+ return sw
+ })
+ }
+
+ // fill the address book and add listeners
+ for _, s := range switches {
+ addr, _ := NewNetAddressString(s.NodeInfo().ListenAddr)
+ book.AddAddress(addr, addr)
+ s.AddListener(NewDefaultListener("tcp", s.NodeInfo().ListenAddr, true))
+ }
+
+ // start switches
+ for _, s := range switches {
+ _, err := s.Start() // start switch and reactors
+ require.Nil(err)
+ }
+
+ time.Sleep(1 * time.Second)
+
+ // check peers are connected after some time
+ for _, s := range switches {
+ outbound, inbound, _ := s.NumPeers()
+ if outbound+inbound == 0 {
+ t.Errorf("%v expected to be connected to at least one peer", s.NodeInfo().ListenAddr)
+ }
+ }
+
+ // stop them
+ for _, s := range switches {
+ s.Stop()
+ }
+}
+
+func TestPEXReactorReceive(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ dir, err := ioutil.TempDir("", "pex_reactor")
+ require.Nil(err)
+ defer os.RemoveAll(dir)
+ book := NewAddrBook(dir+"addrbook.json", true)
+
+ r := NewPEXReactor(book)
+
+ peer := createRandomPeer(false)
+
+ size := book.Size()
+ netAddr, _ := NewNetAddressString(peer.ListenAddr)
+ addrs := []*NetAddress{netAddr}
+ msg := wire.BinaryBytes(struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}})
+ r.Receive(PexChannel, peer, msg)
+ assert.Equal(size+1, book.Size())
+
+ msg = wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}})
+ r.Receive(PexChannel, peer, msg)
+}
+
+func TestPEXReactorAbuseFromPeer(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ dir, err := ioutil.TempDir("", "pex_reactor")
+ require.Nil(err)
+ defer os.RemoveAll(dir)
+ book := NewAddrBook(dir+"addrbook.json", true)
+
+ r := NewPEXReactor(book)
+ r.SetMaxMsgCountByPeer(5)
+
+ peer := createRandomPeer(false)
+
+ msg := wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}})
+ for i := 0; i < 10; i++ {
+ r.Receive(PexChannel, peer, msg)
+ }
+
+ assert.True(r.ReachedMaxMsgCountForPeer(peer.ListenAddr))
+}
+
+func createRandomPeer(outbound bool) *Peer {
+ addr := cmn.Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256)
+ netAddr, _ := NewNetAddressString(addr)
+ return &Peer{
+ Key: cmn.RandStr(12),
+ NodeInfo: &NodeInfo{
+ ListenAddr: addr,
+ },
+ outbound: outbound,
+ mconn: &MConnection{RemoteAddress: netAddr},
+ }
+}
diff --git a/secret_connection.go b/secret_connection.go
new file mode 100644
index 000000000..446c4f185
--- /dev/null
+++ b/secret_connection.go
@@ -0,0 +1,346 @@
+// Uses nacl's secret_box to encrypt a net.Conn.
+// It is (meant to be) an implementation of the STS protocol.
+// Note we do not (yet) assume that a remote peer's pubkey
+// is known ahead of time, and thus we are technically
+// still vulnerable to MITM. (TODO!)
+// See docs/sts-final.pdf for more info
+package p2p
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "crypto/sha256"
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "time"
+
+ "golang.org/x/crypto/nacl/box"
+ "golang.org/x/crypto/nacl/secretbox"
+ "golang.org/x/crypto/ripemd160"
+
+ "github.com/tendermint/go-crypto"
+ "github.com/tendermint/go-wire"
+ . "github.com/tendermint/tmlibs/common"
+)
+
+// 2 + 1024 == 1026 total frame size
+const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize
+const dataMaxSize = 1024
+const totalFrameSize = dataMaxSize + dataLenSize
+const sealedFrameSize = totalFrameSize + secretbox.Overhead
+const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays
+
+// Implements net.Conn
+type SecretConnection struct {
+ conn io.ReadWriteCloser
+ recvBuffer []byte
+ recvNonce *[24]byte
+ sendNonce *[24]byte
+ remPubKey crypto.PubKeyEd25519
+ shrSecret *[32]byte // shared secret
+}
+
+// Performs handshake and returns a new authenticated SecretConnection.
+// Returns nil if error in handshake.
+// Caller should call conn.Close()
+// See docs/sts-final.pdf for more information.
+func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25519) (*SecretConnection, error) {
+
+ locPubKey := locPrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
+
+ // Generate ephemeral keys for perfect forward secrecy.
+ locEphPub, locEphPriv := genEphKeys()
+
+ // Write local ephemeral pubkey and receive one too.
+ // NOTE: every 32-byte string is accepted as a Curve25519 public key
+ // (see DJB's Curve25519 paper: http://cr.yp.to/ecdh/curve25519-20060209.pdf)
+ remEphPub, err := shareEphPubKey(conn, locEphPub)
+ if err != nil {
+ return nil, err
+ }
+
+ // Compute common shared secret.
+ shrSecret := computeSharedSecret(remEphPub, locEphPriv)
+
+ // Sort by lexical order.
+ loEphPub, hiEphPub := sort32(locEphPub, remEphPub)
+
+ // Generate nonces to use for secretbox.
+ recvNonce, sendNonce := genNonces(loEphPub, hiEphPub, locEphPub == loEphPub)
+
+ // Generate common challenge to sign.
+ challenge := genChallenge(loEphPub, hiEphPub)
+
+ // Construct SecretConnection.
+ sc := &SecretConnection{
+ conn: conn,
+ recvBuffer: nil,
+ recvNonce: recvNonce,
+ sendNonce: sendNonce,
+ shrSecret: shrSecret,
+ }
+
+ // Sign the challenge bytes for authentication.
+ locSignature := signChallenge(challenge, locPrivKey)
+
+ // Share (in secret) each other's pubkey & challenge signature
+ authSigMsg, err := shareAuthSignature(sc, locPubKey, locSignature)
+ if err != nil {
+ return nil, err
+ }
+ remPubKey, remSignature := authSigMsg.Key, authSigMsg.Sig
+ if !remPubKey.VerifyBytes(challenge[:], remSignature) {
+ return nil, errors.New("Challenge verification failed")
+ }
+
+ // We've authorized.
+ sc.remPubKey = remPubKey.Unwrap().(crypto.PubKeyEd25519)
+ return sc, nil
+}
+
+// Returns authenticated remote pubkey
+func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 {
+ return sc.remPubKey
+}
+
+// Writes encrypted frames of `sealedFrameSize`
+// CONTRACT: data smaller than dataMaxSize is read atomically.
+func (sc *SecretConnection) Write(data []byte) (n int, err error) {
+ for 0 < len(data) {
+ var frame []byte = make([]byte, totalFrameSize)
+ var chunk []byte
+ if dataMaxSize < len(data) {
+ chunk = data[:dataMaxSize]
+ data = data[dataMaxSize:]
+ } else {
+ chunk = data
+ data = nil
+ }
+ chunkLength := len(chunk)
+ binary.BigEndian.PutUint16(frame, uint16(chunkLength))
+ copy(frame[dataLenSize:], chunk)
+
+ // encrypt the frame
+ var sealedFrame = make([]byte, sealedFrameSize)
+ secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret)
+ // fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret)
+ incr2Nonce(sc.sendNonce)
+ // end encryption
+
+ _, err := sc.conn.Write(sealedFrame)
+ if err != nil {
+ return n, err
+ } else {
+ n += len(chunk)
+ }
+ }
+ return
+}
+
+// CONTRACT: data smaller than dataMaxSize is read atomically.
+func (sc *SecretConnection) Read(data []byte) (n int, err error) {
+ if 0 < len(sc.recvBuffer) {
+ n_ := copy(data, sc.recvBuffer)
+ sc.recvBuffer = sc.recvBuffer[n_:]
+ return
+ }
+
+ sealedFrame := make([]byte, sealedFrameSize)
+ _, err = io.ReadFull(sc.conn, sealedFrame)
+ if err != nil {
+ return
+ }
+
+ // decrypt the frame
+ var frame = make([]byte, totalFrameSize)
+ // fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret)
+ _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret)
+ if !ok {
+ return n, errors.New("Failed to decrypt SecretConnection")
+ }
+ incr2Nonce(sc.recvNonce)
+ // end decryption
+
+ var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes
+ if chunkLength > dataMaxSize {
+ return 0, errors.New("chunkLength is greater than dataMaxSize")
+ }
+ var chunk = frame[dataLenSize : dataLenSize+chunkLength]
+
+ n = copy(data, chunk)
+ sc.recvBuffer = chunk[n:]
+ return
+}
+
+// Implements net.Conn
+func (sc *SecretConnection) Close() error { return sc.conn.Close() }
+func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() }
+func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() }
+func (sc *SecretConnection) SetDeadline(t time.Time) error { return sc.conn.(net.Conn).SetDeadline(t) }
+func (sc *SecretConnection) SetReadDeadline(t time.Time) error {
+ return sc.conn.(net.Conn).SetReadDeadline(t)
+}
+func (sc *SecretConnection) SetWriteDeadline(t time.Time) error {
+ return sc.conn.(net.Conn).SetWriteDeadline(t)
+}
+
+func genEphKeys() (ephPub, ephPriv *[32]byte) {
+ var err error
+ ephPub, ephPriv, err = box.GenerateKey(crand.Reader)
+ if err != nil {
+ PanicCrisis("Could not generate ephemeral keypairs")
+ }
+ return
+}
+
+func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) {
+ var err1, err2 error
+
+ Parallel(
+ func() {
+ _, err1 = conn.Write(locEphPub[:])
+ },
+ func() {
+ remEphPub = new([32]byte)
+ _, err2 = io.ReadFull(conn, remEphPub[:])
+ },
+ )
+
+ if err1 != nil {
+ return nil, err1
+ }
+ if err2 != nil {
+ return nil, err2
+ }
+
+ return remEphPub, nil
+}
+
+func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) {
+ shrSecret = new([32]byte)
+ box.Precompute(shrSecret, remPubKey, locPrivKey)
+ return
+}
+
+func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) {
+ if bytes.Compare(foo[:], bar[:]) < 0 {
+ lo = foo
+ hi = bar
+ } else {
+ lo = bar
+ hi = foo
+ }
+ return
+}
+
+func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce *[24]byte) {
+ nonce1 := hash24(append(loPubKey[:], hiPubKey[:]...))
+ nonce2 := new([24]byte)
+ copy(nonce2[:], nonce1[:])
+ nonce2[len(nonce2)-1] ^= 0x01
+ if locIsLo {
+ recvNonce = nonce1
+ sendNonce = nonce2
+ } else {
+ recvNonce = nonce2
+ sendNonce = nonce1
+ }
+ return
+}
+
+func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) {
+ return hash32(append(loPubKey[:], hiPubKey[:]...))
+}
+
+func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signature crypto.SignatureEd25519) {
+ signature = locPrivKey.Sign(challenge[:]).Unwrap().(crypto.SignatureEd25519)
+ return
+}
+
+type authSigMessage struct {
+ Key crypto.PubKey
+ Sig crypto.Signature
+}
+
+func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signature crypto.SignatureEd25519) (*authSigMessage, error) {
+ var recvMsg authSigMessage
+ var err1, err2 error
+
+ Parallel(
+ func() {
+ msgBytes := wire.BinaryBytes(authSigMessage{pubKey.Wrap(), signature.Wrap()})
+ _, err1 = sc.Write(msgBytes)
+ },
+ func() {
+ readBuffer := make([]byte, authSigMsgSize)
+ _, err2 = io.ReadFull(sc, readBuffer)
+ if err2 != nil {
+ return
+ }
+ n := int(0) // not used.
+ recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage)
+ })
+
+ if err1 != nil {
+ return nil, err1
+ }
+ if err2 != nil {
+ return nil, err2
+ }
+
+ return &recvMsg, nil
+}
+
+func verifyChallengeSignature(challenge *[32]byte, remPubKey crypto.PubKeyEd25519, remSignature crypto.SignatureEd25519) bool {
+ return remPubKey.VerifyBytes(challenge[:], remSignature.Wrap())
+}
+
+//--------------------------------------------------------------------------------
+
+// sha256
+func hash32(input []byte) (res *[32]byte) {
+ hasher := sha256.New()
+ hasher.Write(input) // does not error
+ resSlice := hasher.Sum(nil)
+ res = new([32]byte)
+ copy(res[:], resSlice)
+ return
+}
+
+// We only fill in the first 20 bytes with ripemd160
+func hash24(input []byte) (res *[24]byte) {
+ hasher := ripemd160.New()
+ hasher.Write(input) // does not error
+ resSlice := hasher.Sum(nil)
+ res = new([24]byte)
+ copy(res[:], resSlice)
+ return
+}
+
+// ripemd160
+func hash20(input []byte) (res *[20]byte) {
+ hasher := ripemd160.New()
+ hasher.Write(input) // does not error
+ resSlice := hasher.Sum(nil)
+ res = new([20]byte)
+ copy(res[:], resSlice)
+ return
+}
+
+// increment nonce big-endian by 2 with wraparound.
+func incr2Nonce(nonce *[24]byte) {
+ incrNonce(nonce)
+ incrNonce(nonce)
+}
+
+// increment nonce big-endian by 1 with wraparound.
+func incrNonce(nonce *[24]byte) {
+ for i := 23; 0 <= i; i-- {
+ nonce[i] += 1
+ if nonce[i] != 0 {
+ return
+ }
+ }
+}
diff --git a/secret_connection_test.go b/secret_connection_test.go
new file mode 100644
index 000000000..3dd962f88
--- /dev/null
+++ b/secret_connection_test.go
@@ -0,0 +1,202 @@
+package p2p
+
+import (
+ "bytes"
+ "io"
+ "testing"
+
+ "github.com/tendermint/go-crypto"
+ . "github.com/tendermint/tmlibs/common"
+)
+
+type dummyConn struct {
+ *io.PipeReader
+ *io.PipeWriter
+}
+
+func (drw dummyConn) Close() (err error) {
+ err2 := drw.PipeWriter.CloseWithError(io.EOF)
+ err1 := drw.PipeReader.Close()
+ if err2 != nil {
+ return err
+ }
+ return err1
+}
+
+// Each returned ReadWriteCloser is akin to a net.Connection
+func makeDummyConnPair() (fooConn, barConn dummyConn) {
+ barReader, fooWriter := io.Pipe()
+ fooReader, barWriter := io.Pipe()
+ return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter}
+}
+
+func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
+ fooConn, barConn := makeDummyConnPair()
+ fooPrvKey := crypto.GenPrivKeyEd25519()
+ fooPubKey := fooPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
+ barPrvKey := crypto.GenPrivKeyEd25519()
+ barPubKey := barPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
+
+ Parallel(
+ func() {
+ var err error
+ fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
+ if err != nil {
+ tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
+ return
+ }
+ remotePubBytes := fooSecConn.RemotePubKey()
+ if !bytes.Equal(remotePubBytes[:], barPubKey[:]) {
+ tb.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
+ barPubKey, fooSecConn.RemotePubKey())
+ }
+ },
+ func() {
+ var err error
+ barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
+ if barSecConn == nil {
+ tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
+ return
+ }
+ remotePubBytes := barSecConn.RemotePubKey()
+ if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) {
+ tb.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
+ fooPubKey, barSecConn.RemotePubKey())
+ }
+ })
+
+ return
+}
+
+func TestSecretConnectionHandshake(t *testing.T) {
+ fooSecConn, barSecConn := makeSecretConnPair(t)
+ fooSecConn.Close()
+ barSecConn.Close()
+}
+
+func TestSecretConnectionReadWrite(t *testing.T) {
+ fooConn, barConn := makeDummyConnPair()
+ fooWrites, barWrites := []string{}, []string{}
+ fooReads, barReads := []string{}, []string{}
+
+ // Pre-generate the things to write (for foo & bar)
+ for i := 0; i < 100; i++ {
+ fooWrites = append(fooWrites, RandStr((RandInt()%(dataMaxSize*5))+1))
+ barWrites = append(barWrites, RandStr((RandInt()%(dataMaxSize*5))+1))
+ }
+
+ // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
+ genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() {
+ return func() {
+ // Node handskae
+ nodePrvKey := crypto.GenPrivKeyEd25519()
+ nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
+ if err != nil {
+ t.Errorf("Failed to establish SecretConnection for node: %v", err)
+ return
+ }
+ // In parallel, handle reads and writes
+ Parallel(
+ func() {
+ // Node writes
+ for _, nodeWrite := range nodeWrites {
+ n, err := nodeSecretConn.Write([]byte(nodeWrite))
+ if err != nil {
+ t.Errorf("Failed to write to nodeSecretConn: %v", err)
+ return
+ }
+ if n != len(nodeWrite) {
+ t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
+ return
+ }
+ }
+ nodeConn.PipeWriter.Close()
+ },
+ func() {
+ // Node reads
+ readBuffer := make([]byte, dataMaxSize)
+ for {
+ n, err := nodeSecretConn.Read(readBuffer)
+ if err == io.EOF {
+ return
+ } else if err != nil {
+ t.Errorf("Failed to read from nodeSecretConn: %v", err)
+ return
+ }
+ *nodeReads = append(*nodeReads, string(readBuffer[:n]))
+ }
+ nodeConn.PipeReader.Close()
+ })
+ }
+ }
+
+ // Run foo & bar in parallel
+ Parallel(
+ genNodeRunner(fooConn, fooWrites, &fooReads),
+ genNodeRunner(barConn, barWrites, &barReads),
+ )
+
+ // A helper to ensure that the writes and reads match.
+ // Additionally, small writes (<= dataMaxSize) must be atomically read.
+ compareWritesReads := func(writes []string, reads []string) {
+ for {
+ // Pop next write & corresponding reads
+ var read, write string = "", writes[0]
+ var readCount = 0
+ for _, readChunk := range reads {
+ read += readChunk
+ readCount += 1
+ if len(write) <= len(read) {
+ break
+ }
+ if len(write) <= dataMaxSize {
+ break // atomicity of small writes
+ }
+ }
+ // Compare
+ if write != read {
+ t.Errorf("Expected to read %X, got %X", write, read)
+ }
+ // Iterate
+ writes = writes[1:]
+ reads = reads[readCount:]
+ if len(writes) == 0 {
+ break
+ }
+ }
+ }
+
+ compareWritesReads(fooWrites, barReads)
+ compareWritesReads(barWrites, fooReads)
+
+}
+
+func BenchmarkSecretConnection(b *testing.B) {
+ b.StopTimer()
+ fooSecConn, barSecConn := makeSecretConnPair(b)
+ fooWriteText := RandStr(dataMaxSize)
+ // Consume reads from bar's reader
+ go func() {
+ readBuffer := make([]byte, dataMaxSize)
+ for {
+ _, err := barSecConn.Read(readBuffer)
+ if err == io.EOF {
+ return
+ } else if err != nil {
+ b.Fatalf("Failed to read from barSecConn: %v", err)
+ }
+ }
+ }()
+
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ _, err := fooSecConn.Write([]byte(fooWriteText))
+ if err != nil {
+ b.Fatalf("Failed to write to fooSecConn: %v", err)
+ }
+ }
+ b.StopTimer()
+
+ fooSecConn.Close()
+ //barSecConn.Close() race condition
+}
diff --git a/switch.go b/switch.go
new file mode 100644
index 000000000..8c771df1c
--- /dev/null
+++ b/switch.go
@@ -0,0 +1,593 @@
+package p2p
+
+import (
+ "errors"
+ "fmt"
+ "math/rand"
+ "net"
+ "time"
+
+ cfg "github.com/tendermint/go-config"
+ crypto "github.com/tendermint/go-crypto"
+ "github.com/tendermint/log15"
+ . "github.com/tendermint/tmlibs/common"
+)
+
+const (
+ reconnectAttempts = 30
+ reconnectInterval = 3 * time.Second
+)
+
+type Reactor interface {
+ Service // Start, Stop
+
+ SetSwitch(*Switch)
+ GetChannels() []*ChannelDescriptor
+ AddPeer(peer *Peer)
+ RemovePeer(peer *Peer, reason interface{})
+ Receive(chID byte, peer *Peer, msgBytes []byte)
+}
+
+//--------------------------------------
+
+type BaseReactor struct {
+ BaseService // Provides Start, Stop, .Quit
+ Switch *Switch
+}
+
+func NewBaseReactor(log log15.Logger, name string, impl Reactor) *BaseReactor {
+ return &BaseReactor{
+ BaseService: *NewBaseService(log, name, impl),
+ Switch: nil,
+ }
+}
+
+func (br *BaseReactor) SetSwitch(sw *Switch) {
+ br.Switch = sw
+}
+func (_ *BaseReactor) GetChannels() []*ChannelDescriptor { return nil }
+func (_ *BaseReactor) AddPeer(peer *Peer) {}
+func (_ *BaseReactor) RemovePeer(peer *Peer, reason interface{}) {}
+func (_ *BaseReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {}
+
+//-----------------------------------------------------------------------------
+
+/*
+The `Switch` handles peer connections and exposes an API to receive incoming messages
+on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one
+or more `Channels`. So while sending outgoing messages is typically performed on the peer,
+incoming messages are received on the reactor.
+*/
+type Switch struct {
+ BaseService
+
+ config cfg.Config
+ listeners []Listener
+ reactors map[string]Reactor
+ chDescs []*ChannelDescriptor
+ reactorsByCh map[byte]Reactor
+ peers *PeerSet
+ dialing *CMap
+ nodeInfo *NodeInfo // our node info
+ nodePrivKey crypto.PrivKeyEd25519 // our node privkey
+
+ filterConnByAddr func(net.Addr) error
+ filterConnByPubKey func(crypto.PubKeyEd25519) error
+}
+
+var (
+ ErrSwitchDuplicatePeer = errors.New("Duplicate peer")
+ ErrSwitchMaxPeersPerIPRange = errors.New("IP range has too many peers")
+)
+
+func NewSwitch(config cfg.Config) *Switch {
+ setConfigDefaults(config)
+
+ sw := &Switch{
+ config: config,
+ reactors: make(map[string]Reactor),
+ chDescs: make([]*ChannelDescriptor, 0),
+ reactorsByCh: make(map[byte]Reactor),
+ peers: NewPeerSet(),
+ dialing: NewCMap(),
+ nodeInfo: nil,
+ }
+ sw.BaseService = *NewBaseService(log, "P2P Switch", sw)
+ return sw
+}
+
+// Not goroutine safe.
+func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor {
+ // Validate the reactor.
+ // No two reactors can share the same channel.
+ reactorChannels := reactor.GetChannels()
+ for _, chDesc := range reactorChannels {
+ chID := chDesc.ID
+ if sw.reactorsByCh[chID] != nil {
+ PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor))
+ }
+ sw.chDescs = append(sw.chDescs, chDesc)
+ sw.reactorsByCh[chID] = reactor
+ }
+ sw.reactors[name] = reactor
+ reactor.SetSwitch(sw)
+ return reactor
+}
+
+// Not goroutine safe.
+func (sw *Switch) Reactors() map[string]Reactor {
+ return sw.reactors
+}
+
+// Not goroutine safe.
+func (sw *Switch) Reactor(name string) Reactor {
+ return sw.reactors[name]
+}
+
+// Not goroutine safe.
+func (sw *Switch) AddListener(l Listener) {
+ sw.listeners = append(sw.listeners, l)
+}
+
+// Not goroutine safe.
+func (sw *Switch) Listeners() []Listener {
+ return sw.listeners
+}
+
+// Not goroutine safe.
+func (sw *Switch) IsListening() bool {
+ return len(sw.listeners) > 0
+}
+
+// Not goroutine safe.
+func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) {
+ sw.nodeInfo = nodeInfo
+}
+
+// Not goroutine safe.
+func (sw *Switch) NodeInfo() *NodeInfo {
+ return sw.nodeInfo
+}
+
+// Not goroutine safe.
+// NOTE: Overwrites sw.nodeInfo.PubKey
+func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) {
+ sw.nodePrivKey = nodePrivKey
+ if sw.nodeInfo != nil {
+ sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
+ }
+}
+
+// Switch.Start() starts all the reactors, peers, and listeners.
+func (sw *Switch) OnStart() error {
+ sw.BaseService.OnStart()
+ // Start reactors
+ for _, reactor := range sw.reactors {
+ _, err := reactor.Start()
+ if err != nil {
+ return err
+ }
+ }
+ // Start peers
+ for _, peer := range sw.peers.List() {
+ sw.startInitPeer(peer)
+ }
+ // Start listeners
+ for _, listener := range sw.listeners {
+ go sw.listenerRoutine(listener)
+ }
+ return nil
+}
+
+func (sw *Switch) OnStop() {
+ sw.BaseService.OnStop()
+ // Stop listeners
+ for _, listener := range sw.listeners {
+ listener.Stop()
+ }
+ sw.listeners = nil
+ // Stop peers
+ for _, peer := range sw.peers.List() {
+ peer.Stop()
+ sw.peers.Remove(peer)
+ }
+ // Stop reactors
+ for _, reactor := range sw.reactors {
+ reactor.Stop()
+ }
+}
+
+// NOTE: This performs a blocking handshake before the peer is added.
+// CONTRACT: If error is returned, peer is nil, and conn is immediately closed.
+func (sw *Switch) AddPeer(peer *Peer) error {
+ if err := sw.FilterConnByAddr(peer.Addr()); err != nil {
+ return err
+ }
+
+ if err := sw.FilterConnByPubKey(peer.PubKey()); err != nil {
+ return err
+ }
+
+ if err := peer.HandshakeTimeout(sw.nodeInfo, time.Duration(sw.config.GetInt(configKeyHandshakeTimeoutSeconds))*time.Second); err != nil {
+ return err
+ }
+
+ // Avoid self
+ if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) {
+ return errors.New("Ignoring connection from self")
+ }
+
+ // Check version, chain id
+ if err := sw.nodeInfo.CompatibleWith(peer.NodeInfo); err != nil {
+ return err
+ }
+
+ // Add the peer to .peers
+ // ignore if duplicate or if we already have too many for that IP range
+ if err := sw.peers.Add(peer); err != nil {
+ log.Notice("Ignoring peer", "error", err, "peer", peer)
+ peer.Stop()
+ return err
+ }
+
+ // Start peer
+ if sw.IsRunning() {
+ sw.startInitPeer(peer)
+ }
+
+ log.Notice("Added peer", "peer", peer)
+ return nil
+}
+
+func (sw *Switch) FilterConnByAddr(addr net.Addr) error {
+ if sw.filterConnByAddr != nil {
+ return sw.filterConnByAddr(addr)
+ }
+ return nil
+}
+
+func (sw *Switch) FilterConnByPubKey(pubkey crypto.PubKeyEd25519) error {
+ if sw.filterConnByPubKey != nil {
+ return sw.filterConnByPubKey(pubkey)
+ }
+ return nil
+
+}
+
+func (sw *Switch) SetAddrFilter(f func(net.Addr) error) {
+ sw.filterConnByAddr = f
+}
+
+func (sw *Switch) SetPubKeyFilter(f func(crypto.PubKeyEd25519) error) {
+ sw.filterConnByPubKey = f
+}
+
+func (sw *Switch) startInitPeer(peer *Peer) {
+ peer.Start() // spawn send/recv routines
+ for _, reactor := range sw.reactors {
+ reactor.AddPeer(peer)
+ }
+}
+
+// Dial a list of seeds asynchronously in random order
+func (sw *Switch) DialSeeds(addrBook *AddrBook, seeds []string) error {
+
+ netAddrs, err := NewNetAddressStrings(seeds)
+ if err != nil {
+ return err
+ }
+
+ if addrBook != nil {
+ // add seeds to `addrBook`
+ ourAddrS := sw.nodeInfo.ListenAddr
+ ourAddr, _ := NewNetAddressString(ourAddrS)
+ for _, netAddr := range netAddrs {
+ // do not add ourselves
+ if netAddr.Equals(ourAddr) {
+ continue
+ }
+ addrBook.AddAddress(netAddr, ourAddr)
+ }
+ addrBook.Save()
+ }
+
+ // permute the list, dial them in random order.
+ perm := rand.Perm(len(netAddrs))
+ for i := 0; i < len(perm); i++ {
+ go func(i int) {
+ time.Sleep(time.Duration(rand.Int63n(3000)) * time.Millisecond)
+ j := perm[i]
+ sw.dialSeed(netAddrs[j])
+ }(i)
+ }
+ return nil
+}
+
+func (sw *Switch) dialSeed(addr *NetAddress) {
+ peer, err := sw.DialPeerWithAddress(addr, true)
+ if err != nil {
+ log.Error("Error dialing seed", "error", err)
+ return
+ } else {
+ log.Notice("Connected to seed", "peer", peer)
+ }
+}
+
+func (sw *Switch) DialPeerWithAddress(addr *NetAddress, persistent bool) (*Peer, error) {
+ sw.dialing.Set(addr.IP.String(), addr)
+ defer sw.dialing.Delete(addr.IP.String())
+
+ peer, err := newOutboundPeerWithConfig(addr, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, peerConfigFromGoConfig(sw.config))
+ if err != nil {
+ log.Info("Failed dialing peer", "address", addr, "error", err)
+ return nil, err
+ }
+ if persistent {
+ peer.makePersistent()
+ }
+ err = sw.AddPeer(peer)
+ if err != nil {
+ log.Info("Failed adding peer", "address", addr, "error", err)
+ peer.CloseConn()
+ return nil, err
+ }
+ log.Notice("Dialed and added peer", "address", addr, "peer", peer)
+ return peer, nil
+}
+
+func (sw *Switch) IsDialing(addr *NetAddress) bool {
+ return sw.dialing.Has(addr.IP.String())
+}
+
+// Broadcast runs a go routine for each attempted send, which will block
+// trying to send for defaultSendTimeoutSeconds. Returns a channel
+// which receives success values for each attempted send (false if times out)
+// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved.
+func (sw *Switch) Broadcast(chID byte, msg interface{}) chan bool {
+ successChan := make(chan bool, len(sw.peers.List()))
+ log.Debug("Broadcast", "channel", chID, "msg", msg)
+ for _, peer := range sw.peers.List() {
+ go func(peer *Peer) {
+ success := peer.Send(chID, msg)
+ successChan <- success
+ }(peer)
+ }
+ return successChan
+}
+
+// Returns the count of outbound/inbound and outbound-dialing peers.
+func (sw *Switch) NumPeers() (outbound, inbound, dialing int) {
+ peers := sw.peers.List()
+ for _, peer := range peers {
+ if peer.outbound {
+ outbound++
+ } else {
+ inbound++
+ }
+ }
+ dialing = sw.dialing.Size()
+ return
+}
+
+func (sw *Switch) Peers() IPeerSet {
+ return sw.peers
+}
+
+// Disconnect from a peer due to external error, retry if it is a persistent peer.
+// TODO: make record depending on reason.
+func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
+ addr := NewNetAddress(peer.Addr())
+ log.Notice("Stopping peer for error", "peer", peer, "error", reason)
+ sw.stopAndRemovePeer(peer, reason)
+
+ if peer.IsPersistent() {
+ go func() {
+ log.Notice("Reconnecting to peer", "peer", peer)
+ for i := 1; i < reconnectAttempts; i++ {
+ if !sw.IsRunning() {
+ return
+ }
+
+ peer, err := sw.DialPeerWithAddress(addr, true)
+ if err != nil {
+ if i == reconnectAttempts {
+ log.Notice("Error reconnecting to peer. Giving up", "tries", i, "error", err)
+ return
+ }
+ log.Notice("Error reconnecting to peer. Trying again", "tries", i, "error", err)
+ time.Sleep(reconnectInterval)
+ continue
+ }
+
+ log.Notice("Reconnected to peer", "peer", peer)
+ return
+ }
+ }()
+ }
+}
+
+// Disconnect from a peer gracefully.
+// TODO: handle graceful disconnects.
+func (sw *Switch) StopPeerGracefully(peer *Peer) {
+ log.Notice("Stopping peer gracefully")
+ sw.stopAndRemovePeer(peer, nil)
+}
+
+func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
+ sw.peers.Remove(peer)
+ peer.Stop()
+ for _, reactor := range sw.reactors {
+ reactor.RemovePeer(peer, reason)
+ }
+}
+
+func (sw *Switch) listenerRoutine(l Listener) {
+ for {
+ inConn, ok := <-l.Connections()
+ if !ok {
+ break
+ }
+
+ // ignore connection if we already have enough
+ maxPeers := sw.config.GetInt(configKeyMaxNumPeers)
+ if maxPeers <= sw.peers.Size() {
+ log.Info("Ignoring inbound connection: already have enough peers", "address", inConn.RemoteAddr().String(), "numPeers", sw.peers.Size(), "max", maxPeers)
+ continue
+ }
+
+ // New inbound connection!
+ err := sw.addPeerWithConnectionAndConfig(inConn, peerConfigFromGoConfig(sw.config))
+ if err != nil {
+ log.Notice("Ignoring inbound connection: error while adding peer", "address", inConn.RemoteAddr().String(), "error", err)
+ continue
+ }
+
+ // NOTE: We don't yet have the listening port of the
+ // remote (if they have a listener at all).
+ // The peerHandshake will handle that
+ }
+
+ // cleanup
+}
+
+//-----------------------------------------------------------------------------
+
+type SwitchEventNewPeer struct {
+ Peer *Peer
+}
+
+type SwitchEventDonePeer struct {
+ Peer *Peer
+ Error interface{}
+}
+
+//------------------------------------------------------------------
+// Switches connected via arbitrary net.Conn; useful for testing
+
+// Returns n switches, connected according to the connect func.
+// If connect==Connect2Switches, the switches will be fully connected.
+// initSwitch defines how the ith switch should be initialized (ie. with what reactors).
+// NOTE: panics if any switch fails to start.
+func MakeConnectedSwitches(n int, initSwitch func(int, *Switch) *Switch, connect func([]*Switch, int, int)) []*Switch {
+ switches := make([]*Switch, n)
+ for i := 0; i < n; i++ {
+ switches[i] = makeSwitch(i, "testing", "123.123.123", initSwitch)
+ }
+
+ if err := StartSwitches(switches); err != nil {
+ panic(err)
+ }
+
+ for i := 0; i < n; i++ {
+ for j := i; j < n; j++ {
+ connect(switches, i, j)
+ }
+ }
+
+ return switches
+}
+
+var PanicOnAddPeerErr = false
+
+// Will connect switches i and j via net.Pipe()
+// Blocks until a conection is established.
+// NOTE: caller ensures i and j are within bounds
+func Connect2Switches(switches []*Switch, i, j int) {
+ switchI := switches[i]
+ switchJ := switches[j]
+ c1, c2 := net.Pipe()
+ doneCh := make(chan struct{})
+ go func() {
+ err := switchI.addPeerWithConnection(c1)
+ if PanicOnAddPeerErr && err != nil {
+ panic(err)
+ }
+ doneCh <- struct{}{}
+ }()
+ go func() {
+ err := switchJ.addPeerWithConnection(c2)
+ if PanicOnAddPeerErr && err != nil {
+ panic(err)
+ }
+ doneCh <- struct{}{}
+ }()
+ <-doneCh
+ <-doneCh
+}
+
+func StartSwitches(switches []*Switch) error {
+ for _, s := range switches {
+ _, err := s.Start() // start switch and reactors
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func makeSwitch(i int, network, version string, initSwitch func(int, *Switch) *Switch) *Switch {
+ privKey := crypto.GenPrivKeyEd25519()
+ // new switch, add reactors
+ // TODO: let the config be passed in?
+ s := initSwitch(i, NewSwitch(cfg.NewMapConfig(nil)))
+ s.SetNodeInfo(&NodeInfo{
+ PubKey: privKey.PubKey().Unwrap().(crypto.PubKeyEd25519),
+ Moniker: Fmt("switch%d", i),
+ Network: network,
+ Version: version,
+ RemoteAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023),
+ ListenAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023),
+ })
+ s.SetNodePrivKey(privKey)
+ return s
+}
+
+func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
+ peer, err := newInboundPeer(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey)
+ if err != nil {
+ conn.Close()
+ return err
+ }
+
+ if err = sw.AddPeer(peer); err != nil {
+ conn.Close()
+ return err
+ }
+
+ return nil
+}
+
+func (sw *Switch) addPeerWithConnectionAndConfig(conn net.Conn, config *PeerConfig) error {
+ peer, err := newInboundPeerWithConfig(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config)
+ if err != nil {
+ conn.Close()
+ return err
+ }
+
+ if err = sw.AddPeer(peer); err != nil {
+ conn.Close()
+ return err
+ }
+
+ return nil
+}
+
+func peerConfigFromGoConfig(config cfg.Config) *PeerConfig {
+ return &PeerConfig{
+ AuthEnc: config.GetBool(configKeyAuthEnc),
+ Fuzz: config.GetBool(configFuzzEnable),
+ HandshakeTimeout: time.Duration(config.GetInt(configKeyHandshakeTimeoutSeconds)) * time.Second,
+ DialTimeout: time.Duration(config.GetInt(configKeyDialTimeoutSeconds)) * time.Second,
+ MConfig: &MConnConfig{
+ SendRate: int64(config.GetInt(configKeySendRate)),
+ RecvRate: int64(config.GetInt(configKeyRecvRate)),
+ },
+ FuzzConfig: &FuzzConnConfig{
+ Mode: config.GetInt(configFuzzMode),
+ MaxDelay: time.Duration(config.GetInt(configFuzzMaxDelayMilliseconds)) * time.Millisecond,
+ ProbDropRW: config.GetFloat64(configFuzzProbDropRW),
+ ProbDropConn: config.GetFloat64(configFuzzProbDropConn),
+ ProbSleep: config.GetFloat64(configFuzzProbSleep),
+ },
+ }
+}
diff --git a/switch_test.go b/switch_test.go
new file mode 100644
index 000000000..1f1fe69f0
--- /dev/null
+++ b/switch_test.go
@@ -0,0 +1,330 @@
+package p2p
+
+import (
+ "bytes"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ . "github.com/tendermint/tmlibs/common"
+ cfg "github.com/tendermint/go-config"
+ crypto "github.com/tendermint/go-crypto"
+ wire "github.com/tendermint/go-wire"
+)
+
+var (
+ config cfg.Config
+)
+
+func init() {
+ config = cfg.NewMapConfig(nil)
+ setConfigDefaults(config)
+}
+
+type PeerMessage struct {
+ PeerKey string
+ Bytes []byte
+ Counter int
+}
+
+type TestReactor struct {
+ BaseReactor
+
+ mtx sync.Mutex
+ channels []*ChannelDescriptor
+ peersAdded []*Peer
+ peersRemoved []*Peer
+ logMessages bool
+ msgsCounter int
+ msgsReceived map[byte][]PeerMessage
+}
+
+func NewTestReactor(channels []*ChannelDescriptor, logMessages bool) *TestReactor {
+ tr := &TestReactor{
+ channels: channels,
+ logMessages: logMessages,
+ msgsReceived: make(map[byte][]PeerMessage),
+ }
+ tr.BaseReactor = *NewBaseReactor(log, "TestReactor", tr)
+ return tr
+}
+
+func (tr *TestReactor) GetChannels() []*ChannelDescriptor {
+ return tr.channels
+}
+
+func (tr *TestReactor) AddPeer(peer *Peer) {
+ tr.mtx.Lock()
+ defer tr.mtx.Unlock()
+ tr.peersAdded = append(tr.peersAdded, peer)
+}
+
+func (tr *TestReactor) RemovePeer(peer *Peer, reason interface{}) {
+ tr.mtx.Lock()
+ defer tr.mtx.Unlock()
+ tr.peersRemoved = append(tr.peersRemoved, peer)
+}
+
+func (tr *TestReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {
+ if tr.logMessages {
+ tr.mtx.Lock()
+ defer tr.mtx.Unlock()
+ //fmt.Printf("Received: %X, %X\n", chID, msgBytes)
+ tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.Key, msgBytes, tr.msgsCounter})
+ tr.msgsCounter++
+ }
+}
+
+func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
+ tr.mtx.Lock()
+ defer tr.mtx.Unlock()
+ return tr.msgsReceived[chID]
+}
+
+//-----------------------------------------------------------------------------
+
+// convenience method for creating two switches connected to each other.
+// XXX: note this uses net.Pipe and not a proper TCP conn
+func makeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
+ // Create two switches that will be interconnected.
+ switches := MakeConnectedSwitches(2, initSwitch, Connect2Switches)
+ return switches[0], switches[1]
+}
+
+func initSwitchFunc(i int, sw *Switch) *Switch {
+ // Make two reactors of two channels each
+ sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{
+ &ChannelDescriptor{ID: byte(0x00), Priority: 10},
+ &ChannelDescriptor{ID: byte(0x01), Priority: 10},
+ }, true))
+ sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{
+ &ChannelDescriptor{ID: byte(0x02), Priority: 10},
+ &ChannelDescriptor{ID: byte(0x03), Priority: 10},
+ }, true))
+ return sw
+}
+
+func TestSwitches(t *testing.T) {
+ s1, s2 := makeSwitchPair(t, initSwitchFunc)
+ defer s1.Stop()
+ defer s2.Stop()
+
+ if s1.Peers().Size() != 1 {
+ t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size())
+ }
+ if s2.Peers().Size() != 1 {
+ t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size())
+ }
+
+ // Lets send some messages
+ ch0Msg := "channel zero"
+ ch1Msg := "channel foo"
+ ch2Msg := "channel bar"
+
+ s1.Broadcast(byte(0x00), ch0Msg)
+ s1.Broadcast(byte(0x01), ch1Msg)
+ s1.Broadcast(byte(0x02), ch2Msg)
+
+ // Wait for things to settle...
+ time.Sleep(5000 * time.Millisecond)
+
+ // Check message on ch0
+ ch0Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x00))
+ if len(ch0Msgs) != 1 {
+ t.Errorf("Expected to have received 1 message in ch0")
+ }
+ if !bytes.Equal(ch0Msgs[0].Bytes, wire.BinaryBytes(ch0Msg)) {
+ t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch0Msg), ch0Msgs[0].Bytes)
+ }
+
+ // Check message on ch1
+ ch1Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x01))
+ if len(ch1Msgs) != 1 {
+ t.Errorf("Expected to have received 1 message in ch1")
+ }
+ if !bytes.Equal(ch1Msgs[0].Bytes, wire.BinaryBytes(ch1Msg)) {
+ t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch1Msg), ch1Msgs[0].Bytes)
+ }
+
+ // Check message on ch2
+ ch2Msgs := s2.Reactor("bar").(*TestReactor).getMsgs(byte(0x02))
+ if len(ch2Msgs) != 1 {
+ t.Errorf("Expected to have received 1 message in ch2")
+ }
+ if !bytes.Equal(ch2Msgs[0].Bytes, wire.BinaryBytes(ch2Msg)) {
+ t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch2Msg), ch2Msgs[0].Bytes)
+ }
+
+}
+
+func TestConnAddrFilter(t *testing.T) {
+ s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc)
+ s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc)
+
+ c1, c2 := net.Pipe()
+
+ s1.SetAddrFilter(func(addr net.Addr) error {
+ if addr.String() == c1.RemoteAddr().String() {
+ return fmt.Errorf("Error: pipe is blacklisted")
+ }
+ return nil
+ })
+
+ // connect to good peer
+ go func() {
+ s1.addPeerWithConnection(c1)
+ }()
+ go func() {
+ s2.addPeerWithConnection(c2)
+ }()
+
+ // Wait for things to happen, peers to get added...
+ time.Sleep(100 * time.Millisecond * time.Duration(4))
+
+ defer s1.Stop()
+ defer s2.Stop()
+ if s1.Peers().Size() != 0 {
+ t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size())
+ }
+ if s2.Peers().Size() != 0 {
+ t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size())
+ }
+}
+
+func TestConnPubKeyFilter(t *testing.T) {
+ s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc)
+ s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc)
+
+ c1, c2 := net.Pipe()
+
+ // set pubkey filter
+ s1.SetPubKeyFilter(func(pubkey crypto.PubKeyEd25519) error {
+ if bytes.Equal(pubkey.Bytes(), s2.nodeInfo.PubKey.Bytes()) {
+ return fmt.Errorf("Error: pipe is blacklisted")
+ }
+ return nil
+ })
+
+ // connect to good peer
+ go func() {
+ s1.addPeerWithConnection(c1)
+ }()
+ go func() {
+ s2.addPeerWithConnection(c2)
+ }()
+
+ // Wait for things to happen, peers to get added...
+ time.Sleep(100 * time.Millisecond * time.Duration(4))
+
+ defer s1.Stop()
+ defer s2.Stop()
+ if s1.Peers().Size() != 0 {
+ t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size())
+ }
+ if s2.Peers().Size() != 0 {
+ t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size())
+ }
+}
+
+func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc)
+ sw.Start()
+ defer sw.Stop()
+
+ // simulate remote peer
+ rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()}
+ rp.Start()
+ defer rp.Stop()
+
+ peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey)
+ require.Nil(err)
+ err = sw.AddPeer(peer)
+ require.Nil(err)
+
+ // simulate failure by closing connection
+ peer.CloseConn()
+
+ time.Sleep(100 * time.Millisecond)
+
+ assert.Zero(sw.Peers().Size())
+ assert.False(peer.IsRunning())
+}
+
+func TestSwitchReconnectsToPersistentPeer(t *testing.T) {
+ assert, require := assert.New(t), require.New(t)
+
+ sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc)
+ sw.Start()
+ defer sw.Stop()
+
+ // simulate remote peer
+ rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()}
+ rp.Start()
+ defer rp.Stop()
+
+ peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey)
+ peer.makePersistent()
+ require.Nil(err)
+ err = sw.AddPeer(peer)
+ require.Nil(err)
+
+ // simulate failure by closing connection
+ peer.CloseConn()
+
+ time.Sleep(100 * time.Millisecond)
+
+ assert.NotZero(sw.Peers().Size())
+ assert.False(peer.IsRunning())
+}
+
+func BenchmarkSwitches(b *testing.B) {
+
+ b.StopTimer()
+
+ s1, s2 := makeSwitchPair(b, func(i int, sw *Switch) *Switch {
+ // Make bar reactors of bar channels each
+ sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{
+ &ChannelDescriptor{ID: byte(0x00), Priority: 10},
+ &ChannelDescriptor{ID: byte(0x01), Priority: 10},
+ }, false))
+ sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{
+ &ChannelDescriptor{ID: byte(0x02), Priority: 10},
+ &ChannelDescriptor{ID: byte(0x03), Priority: 10},
+ }, false))
+ return sw
+ })
+ defer s1.Stop()
+ defer s2.Stop()
+
+ // Allow time for goroutines to boot up
+ time.Sleep(1000 * time.Millisecond)
+ b.StartTimer()
+
+ numSuccess, numFailure := 0, 0
+
+ // Send random message from foo channel to another
+ for i := 0; i < b.N; i++ {
+ chID := byte(i % 4)
+ successChan := s1.Broadcast(chID, "test data")
+ for s := range successChan {
+ if s {
+ numSuccess++
+ } else {
+ numFailure++
+ }
+ }
+ }
+
+ log.Warn(Fmt("success: %v, failure: %v", numSuccess, numFailure))
+
+ // Allow everything to flush before stopping switches & closing connections.
+ b.StopTimer()
+ time.Sleep(1000 * time.Millisecond)
+
+}
diff --git a/types.go b/types.go
new file mode 100644
index 000000000..4f3e4c1d8
--- /dev/null
+++ b/types.go
@@ -0,0 +1,77 @@
+package p2p
+
+import (
+ "fmt"
+ "net"
+ "strconv"
+ "strings"
+
+ "github.com/tendermint/go-crypto"
+)
+
+const maxNodeInfoSize = 10240 // 10Kb
+
+type NodeInfo struct {
+ PubKey crypto.PubKeyEd25519 `json:"pub_key"`
+ Moniker string `json:"moniker"`
+ Network string `json:"network"`
+ RemoteAddr string `json:"remote_addr"`
+ ListenAddr string `json:"listen_addr"`
+ Version string `json:"version"` // major.minor.revision
+ Other []string `json:"other"` // other application specific data
+}
+
+// CONTRACT: two nodes are compatible if the major/minor versions match and network match
+func (info *NodeInfo) CompatibleWith(other *NodeInfo) error {
+ iMajor, iMinor, _, iErr := splitVersion(info.Version)
+ oMajor, oMinor, _, oErr := splitVersion(other.Version)
+
+ // if our own version number is not formatted right, we messed up
+ if iErr != nil {
+ return iErr
+ }
+
+ // version number must be formatted correctly ("x.x.x")
+ if oErr != nil {
+ return oErr
+ }
+
+ // major version must match
+ if iMajor != oMajor {
+ return fmt.Errorf("Peer is on a different major version. Got %v, expected %v", oMajor, iMajor)
+ }
+
+ // minor version must match
+ if iMinor != oMinor {
+ return fmt.Errorf("Peer is on a different minor version. Got %v, expected %v", oMinor, iMinor)
+ }
+
+ // nodes must be on the same network
+ if info.Network != other.Network {
+ return fmt.Errorf("Peer is on a different network. Got %v, expected %v", other.Network, info.Network)
+ }
+
+ return nil
+}
+
+func (info *NodeInfo) ListenHost() string {
+ host, _, _ := net.SplitHostPort(info.ListenAddr)
+ return host
+}
+
+func (info *NodeInfo) ListenPort() int {
+ _, port, _ := net.SplitHostPort(info.ListenAddr)
+ port_i, err := strconv.Atoi(port)
+ if err != nil {
+ return -1
+ }
+ return port_i
+}
+
+func splitVersion(version string) (string, string, string, error) {
+ spl := strings.Split(version, ".")
+ if len(spl) != 3 {
+ return "", "", "", fmt.Errorf("Invalid version format %v", version)
+ }
+ return spl[0], spl[1], spl[2], nil
+}
diff --git a/upnp/README.md b/upnp/README.md
new file mode 100644
index 000000000..557d05bdc
--- /dev/null
+++ b/upnp/README.md
@@ -0,0 +1,5 @@
+# `tendermint/p2p/upnp`
+
+## Resources
+
+* http://www.upnp-hacks.org/upnp.html
diff --git a/upnp/log.go b/upnp/log.go
new file mode 100644
index 000000000..45e44439c
--- /dev/null
+++ b/upnp/log.go
@@ -0,0 +1,7 @@
+package upnp
+
+import (
+ "github.com/tendermint/tmlibs/logger"
+)
+
+var log = logger.New("module", "upnp")
diff --git a/upnp/probe.go b/upnp/probe.go
new file mode 100644
index 000000000..5488de587
--- /dev/null
+++ b/upnp/probe.go
@@ -0,0 +1,111 @@
+package upnp
+
+import (
+ "errors"
+ "fmt"
+ "net"
+ "time"
+
+ . "github.com/tendermint/tmlibs/common"
+)
+
+type UPNPCapabilities struct {
+ PortMapping bool
+ Hairpin bool
+}
+
+func makeUPNPListener(intPort int, extPort int) (NAT, net.Listener, net.IP, error) {
+ nat, err := Discover()
+ if err != nil {
+ return nil, nil, nil, errors.New(fmt.Sprintf("NAT upnp could not be discovered: %v", err))
+ }
+ log.Info(Fmt("ourIP: %v", nat.(*upnpNAT).ourIP))
+
+ ext, err := nat.GetExternalAddress()
+ if err != nil {
+ return nat, nil, nil, errors.New(fmt.Sprintf("External address error: %v", err))
+ }
+ log.Info(Fmt("External address: %v", ext))
+
+ port, err := nat.AddPortMapping("tcp", extPort, intPort, "Tendermint UPnP Probe", 0)
+ if err != nil {
+ return nat, nil, ext, errors.New(fmt.Sprintf("Port mapping error: %v", err))
+ }
+ log.Info(Fmt("Port mapping mapped: %v", port))
+
+ // also run the listener, open for all remote addresses.
+ listener, err := net.Listen("tcp", fmt.Sprintf(":%v", intPort))
+ if err != nil {
+ return nat, nil, ext, errors.New(fmt.Sprintf("Error establishing listener: %v", err))
+ }
+ return nat, listener, ext, nil
+}
+
+func testHairpin(listener net.Listener, extAddr string) (supportsHairpin bool) {
+ // Listener
+ go func() {
+ inConn, err := listener.Accept()
+ if err != nil {
+ log.Notice(Fmt("Listener.Accept() error: %v", err))
+ return
+ }
+ log.Info(Fmt("Accepted incoming connection: %v -> %v", inConn.LocalAddr(), inConn.RemoteAddr()))
+ buf := make([]byte, 1024)
+ n, err := inConn.Read(buf)
+ if err != nil {
+ log.Notice(Fmt("Incoming connection read error: %v", err))
+ return
+ }
+ log.Info(Fmt("Incoming connection read %v bytes: %X", n, buf))
+ if string(buf) == "test data" {
+ supportsHairpin = true
+ return
+ }
+ }()
+
+ // Establish outgoing
+ outConn, err := net.Dial("tcp", extAddr)
+ if err != nil {
+ log.Notice(Fmt("Outgoing connection dial error: %v", err))
+ return
+ }
+
+ n, err := outConn.Write([]byte("test data"))
+ if err != nil {
+ log.Notice(Fmt("Outgoing connection write error: %v", err))
+ return
+ }
+ log.Info(Fmt("Outgoing connection wrote %v bytes", n))
+
+ // Wait for data receipt
+ time.Sleep(1 * time.Second)
+ return
+}
+
+func Probe() (caps UPNPCapabilities, err error) {
+ log.Info("Probing for UPnP!")
+
+ intPort, extPort := 8001, 8001
+
+ nat, listener, ext, err := makeUPNPListener(intPort, extPort)
+ if err != nil {
+ return
+ }
+ caps.PortMapping = true
+
+ // Deferred cleanup
+ defer func() {
+ err = nat.DeletePortMapping("tcp", intPort, extPort)
+ if err != nil {
+ log.Warn(Fmt("Port mapping delete error: %v", err))
+ }
+ listener.Close()
+ }()
+
+ supportsHairpin := testHairpin(listener, fmt.Sprintf("%v:%v", ext, extPort))
+ if supportsHairpin {
+ caps.Hairpin = true
+ }
+
+ return
+}
diff --git a/upnp/upnp.go b/upnp/upnp.go
new file mode 100644
index 000000000..3d6c55035
--- /dev/null
+++ b/upnp/upnp.go
@@ -0,0 +1,380 @@
+/*
+Taken from taipei-torrent
+
+Just enough UPnP to be able to forward ports
+*/
+package upnp
+
+// BUG(jae): TODO: use syscalls to get actual ourIP. http://pastebin.com/9exZG4rh
+
+import (
+ "bytes"
+ "encoding/xml"
+ "errors"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "strconv"
+ "strings"
+ "time"
+)
+
+type upnpNAT struct {
+ serviceURL string
+ ourIP string
+ urnDomain string
+}
+
+// protocol is either "udp" or "tcp"
+type NAT interface {
+ GetExternalAddress() (addr net.IP, err error)
+ AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error)
+ DeletePortMapping(protocol string, externalPort, internalPort int) (err error)
+}
+
+func Discover() (nat NAT, err error) {
+ ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900")
+ if err != nil {
+ return
+ }
+ conn, err := net.ListenPacket("udp4", ":0")
+ if err != nil {
+ return
+ }
+ socket := conn.(*net.UDPConn)
+ defer socket.Close()
+
+ err = socket.SetDeadline(time.Now().Add(3 * time.Second))
+ if err != nil {
+ return
+ }
+
+ st := "InternetGatewayDevice:1"
+
+ buf := bytes.NewBufferString(
+ "M-SEARCH * HTTP/1.1\r\n" +
+ "HOST: 239.255.255.250:1900\r\n" +
+ "ST: ssdp:all\r\n" +
+ "MAN: \"ssdp:discover\"\r\n" +
+ "MX: 2\r\n\r\n")
+ message := buf.Bytes()
+ answerBytes := make([]byte, 1024)
+ for i := 0; i < 3; i++ {
+ _, err = socket.WriteToUDP(message, ssdp)
+ if err != nil {
+ return
+ }
+ var n int
+ n, _, err = socket.ReadFromUDP(answerBytes)
+ for {
+ n, _, err = socket.ReadFromUDP(answerBytes)
+ if err != nil {
+ break
+ }
+ answer := string(answerBytes[0:n])
+ if strings.Index(answer, st) < 0 {
+ continue
+ }
+ // HTTP header field names are case-insensitive.
+ // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
+ locString := "\r\nlocation:"
+ answer = strings.ToLower(answer)
+ locIndex := strings.Index(answer, locString)
+ if locIndex < 0 {
+ continue
+ }
+ loc := answer[locIndex+len(locString):]
+ endIndex := strings.Index(loc, "\r\n")
+ if endIndex < 0 {
+ continue
+ }
+ locURL := strings.TrimSpace(loc[0:endIndex])
+ var serviceURL, urnDomain string
+ serviceURL, urnDomain, err = getServiceURL(locURL)
+ if err != nil {
+ return
+ }
+ var ourIP net.IP
+ ourIP, err = localIPv4()
+ if err != nil {
+ return
+ }
+ nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain}
+ return
+ }
+ }
+ err = errors.New("UPnP port discovery failed.")
+ return
+}
+
+type Envelope struct {
+ XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"`
+ Soap *SoapBody
+}
+type SoapBody struct {
+ XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"`
+ ExternalIP *ExternalIPAddressResponse
+}
+
+type ExternalIPAddressResponse struct {
+ XMLName xml.Name `xml:"GetExternalIPAddressResponse"`
+ IPAddress string `xml:"NewExternalIPAddress"`
+}
+
+type ExternalIPAddress struct {
+ XMLName xml.Name `xml:"NewExternalIPAddress"`
+ IP string
+}
+
+type UPNPService struct {
+ ServiceType string `xml:"serviceType"`
+ ControlURL string `xml:"controlURL"`
+}
+
+type DeviceList struct {
+ Device []Device `xml:"device"`
+}
+
+type ServiceList struct {
+ Service []UPNPService `xml:"service"`
+}
+
+type Device struct {
+ XMLName xml.Name `xml:"device"`
+ DeviceType string `xml:"deviceType"`
+ DeviceList DeviceList `xml:"deviceList"`
+ ServiceList ServiceList `xml:"serviceList"`
+}
+
+type Root struct {
+ Device Device
+}
+
+func getChildDevice(d *Device, deviceType string) *Device {
+ dl := d.DeviceList.Device
+ for i := 0; i < len(dl); i++ {
+ if strings.Index(dl[i].DeviceType, deviceType) >= 0 {
+ return &dl[i]
+ }
+ }
+ return nil
+}
+
+func getChildService(d *Device, serviceType string) *UPNPService {
+ sl := d.ServiceList.Service
+ for i := 0; i < len(sl); i++ {
+ if strings.Index(sl[i].ServiceType, serviceType) >= 0 {
+ return &sl[i]
+ }
+ }
+ return nil
+}
+
+func localIPv4() (net.IP, error) {
+ tt, err := net.Interfaces()
+ if err != nil {
+ return nil, err
+ }
+ for _, t := range tt {
+ aa, err := t.Addrs()
+ if err != nil {
+ return nil, err
+ }
+ for _, a := range aa {
+ ipnet, ok := a.(*net.IPNet)
+ if !ok {
+ continue
+ }
+ v4 := ipnet.IP.To4()
+ if v4 == nil || v4[0] == 127 { // loopback address
+ continue
+ }
+ return v4, nil
+ }
+ }
+ return nil, errors.New("cannot find local IP address")
+}
+
+func getServiceURL(rootURL string) (url, urnDomain string, err error) {
+ r, err := http.Get(rootURL)
+ if err != nil {
+ return
+ }
+ defer r.Body.Close()
+ if r.StatusCode >= 400 {
+ err = errors.New(string(r.StatusCode))
+ return
+ }
+ var root Root
+ err = xml.NewDecoder(r.Body).Decode(&root)
+ if err != nil {
+ return
+ }
+ a := &root.Device
+ if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 {
+ err = errors.New("No InternetGatewayDevice")
+ return
+ }
+ b := getChildDevice(a, "WANDevice:1")
+ if b == nil {
+ err = errors.New("No WANDevice")
+ return
+ }
+ c := getChildDevice(b, "WANConnectionDevice:1")
+ if c == nil {
+ err = errors.New("No WANConnectionDevice")
+ return
+ }
+ d := getChildService(c, "WANIPConnection:1")
+ if d == nil {
+ // Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice,
+ // instead of under WanConnectionDevice
+ d = getChildService(b, "WANIPConnection:1")
+
+ if d == nil {
+ err = errors.New("No WANIPConnection")
+ return
+ }
+ }
+ // Extract the domain name, which isn't always 'schemas-upnp-org'
+ urnDomain = strings.Split(d.ServiceType, ":")[1]
+ url = combineURL(rootURL, d.ControlURL)
+ return
+}
+
+func combineURL(rootURL, subURL string) string {
+ protocolEnd := "://"
+ protoEndIndex := strings.Index(rootURL, protocolEnd)
+ a := rootURL[protoEndIndex+len(protocolEnd):]
+ rootIndex := strings.Index(a, "/")
+ return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
+}
+
+func soapRequest(url, function, message, domain string) (r *http.Response, err error) {
+ fullMessage := "" +
+ "\r\n" +
+ "" + message + ""
+
+ req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
+ req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
+ //req.Header.Set("Transfer-Encoding", "chunked")
+ req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"")
+ req.Header.Set("Connection", "Close")
+ req.Header.Set("Cache-Control", "no-cache")
+ req.Header.Set("Pragma", "no-cache")
+
+ // log.Stderr("soapRequest ", req)
+
+ r, err = http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ /*if r.Body != nil {
+ defer r.Body.Close()
+ }*/
+
+ if r.StatusCode >= 400 {
+ // log.Stderr(function, r.StatusCode)
+ err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
+ r = nil
+ return
+ }
+ return
+}
+
+type statusInfo struct {
+ externalIpAddress string
+}
+
+func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) {
+
+ message := "\r\n" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+ var envelope Envelope
+ data, err := ioutil.ReadAll(response.Body)
+ reader := bytes.NewReader(data)
+ xml.NewDecoder(reader).Decode(&envelope)
+
+ info = statusInfo{envelope.Soap.ExternalIP.IPAddress}
+
+ if err != nil {
+ return
+ }
+
+ return
+}
+
+func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
+ info, err := n.getExternalIPAddress()
+ if err != nil {
+ return
+ }
+ addr = net.ParseIP(info.externalIpAddress)
+ return
+}
+
+func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
+ // A single concatenation would break ARM compilation.
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort)
+ message += "" + protocol + ""
+ message += "" + strconv.Itoa(internalPort) + "" +
+ "" + n.ourIP + "" +
+ "1"
+ message += description +
+ "" + strconv.Itoa(timeout) +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was forwarded
+ // log.Println(message, response)
+ // JAE:
+ // body, err := ioutil.ReadAll(response.Body)
+ // fmt.Println(string(body), err)
+ mappedExternalPort = externalPort
+ _ = response
+ return
+}
+
+func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
+
+ message := "\r\n" +
+ "" + strconv.Itoa(externalPort) +
+ "" + protocol + "" +
+ ""
+
+ var response *http.Response
+ response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain)
+ if response != nil {
+ defer response.Body.Close()
+ }
+ if err != nil {
+ return
+ }
+
+ // TODO: check response to see if the port was deleted
+ // log.Println(message, response)
+ _ = response
+ return
+}
diff --git a/util.go b/util.go
new file mode 100644
index 000000000..2be320263
--- /dev/null
+++ b/util.go
@@ -0,0 +1,15 @@
+package p2p
+
+import (
+ "crypto/sha256"
+)
+
+// doubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes.
+func doubleSha256(b []byte) []byte {
+ hasher := sha256.New()
+ hasher.Write(b)
+ sum := hasher.Sum(nil)
+ hasher.Reset()
+ hasher.Write(sum)
+ return hasher.Sum(nil)
+}
diff --git a/version.go b/version.go
new file mode 100644
index 000000000..9a4c7bbaf
--- /dev/null
+++ b/version.go
@@ -0,0 +1,3 @@
+package p2p
+
+const Version = "0.5.0"