diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..cae2f4c9f --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,78 @@ +# Changelog + +## 0.5.0 (April 21, 2017) + +BREAKING CHANGES: + +- Remove or unexport methods from FuzzedConnection: Active, Mode, ProbDropRW, ProbDropConn, ProbSleep, MaxDelayMilliseconds, Fuzz +- switch.AddPeerWithConnection is unexported and replaced by switch.AddPeer +- switch.DialPeerWithAddress takes a bool, setting the peer as persistent or not + +FEATURES: + +- Persistent peers: any peer considered a "seed" will be reconnected to when the connection is dropped + + +IMPROVEMENTS: + +- Many more tests and comments +- Refactor configurations for less dependence on go-config. Introduces new structs PeerConfig, MConnConfig, FuzzConnConfig +- New methods on peer: CloseConn, HandshakeTimeout, IsPersistent, Addr, PubKey +- NewNetAddress supports a testing mode where the address defaults to 0.0.0.0:0 + + +## 0.4.0 (March 6, 2017) + +BREAKING CHANGES: + +- DialSeeds now takes an AddrBook and returns an error: `DialSeeds(*AddrBook, []string) error` +- NewNetAddressString now returns an error: `NewNetAddressString(string) (*NetAddress, error)` + +FEATURES: + +- `NewNetAddressStrings([]string) ([]*NetAddress, error)` +- `AddrBook.Save()` + +IMPROVEMENTS: + +- PexReactor responsible for starting and stopping the AddrBook + +BUG FIXES: + +- DialSeeds returns an error instead of panicking on bad addresses + +## 0.3.5 (January 12, 2017) + +FEATURES + +- Toggle strict routability in the AddrBook + +BUG FIXES + +- Close filtered out connections +- Fixes for MakeConnectedSwitches and Connect2Switches + +## 0.3.4 (August 10, 2016) + +FEATURES: + +- Optionally filter connections by address or public key + +## 0.3.3 (May 12, 2016) + +FEATURES: + +- FuzzConn + +## 0.3.2 (March 12, 2016) + +IMPROVEMENTS: + +- Memory optimizations + +## 0.3.1 () + +FEATURES: + +- Configurable parameters + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..3716185f2 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM golang:latest + +RUN curl https://glide.sh/get | sh + +RUN mkdir -p /go/src/github.com/tendermint/go-p2p +WORKDIR /go/src/github.com/tendermint/go-p2p + +COPY glide.yaml /go/src/github.com/tendermint/go-p2p/ +COPY glide.lock /go/src/github.com/tendermint/go-p2p/ + +RUN glide install + +COPY . /go/src/github.com/tendermint/go-p2p diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..e908e0f95 --- /dev/null +++ b/LICENSE @@ -0,0 +1,193 @@ +Tendermint Go-P2P +Copyright (C) 2015 Tendermint + + + + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 000000000..236efd367 --- /dev/null +++ b/README.md @@ -0,0 +1,79 @@ +# `tendermint/go-p2p` + +[![CircleCI](https://circleci.com/gh/tendermint/go-p2p.svg?style=svg)](https://circleci.com/gh/tendermint/go-p2p) + +`tendermint/go-p2p` provides an abstraction around peer-to-peer communication.
+ +## Peer/MConnection/Channel + +Each peer has one `MConnection` (multiplex connection) instance. + +__multiplex__ *noun* a system or signal involving simultaneous transmission of +several messages along a single channel of communication. + +Each `MConnection` handles message transmission on multiple abstract communication +`Channel`s. Each channel has a globally unique byte id. +The byte id and the relative priorities of each `Channel` are configured upon +initialization of the connection. + +There are two methods for sending messages: +```go +func (m MConnection) Send(chID byte, msg interface{}) bool {} +func (m MConnection) TrySend(chID byte, msg interface{}) bool {} +``` + +`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued +for the channel with the given id byte `chID`. The message `msg` is serialized +using the `tendermint/wire` submodule's `WriteBinary()` reflection routine. + +`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's +queue is full. + +`Send()` and `TrySend()` are also exposed for each `Peer`. + +## Switch/Reactor + +The `Switch` handles peer connections and exposes an API to receive incoming messages +on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one +or more `Channels`. So while sending outgoing messages is typically performed on the peer, +incoming messages are received on the reactor. + +```go +// Declare a MyReactor reactor that handles messages on MyChannelID. +type MyReactor struct{} + +func (reactor MyReactor) GetChannels() []*ChannelDescriptor { + return []*ChannelDescriptor{ChannelDescriptor{ID:MyChannelID, Priority: 1}} +} + +func (reactor MyReactor) Receive(chID byte, peer *Peer, msgBytes []byte) { + r, n, err := bytes.NewBuffer(msgBytes), new(int64), new(error) + msgString := ReadString(r, n, err) + fmt.Println(msgString) +} + +// Other Reactor methods omitted for brevity +... + +switch := NewSwitch([]Reactor{MyReactor{}}) + +... + +// Send a random message to all outbound connections +for _, peer := range switch.Peers().List() { + if peer.IsOutbound() { + peer.Send(MyChannelID, "Here's a random message") + } +} +``` + +### PexReactor/AddrBook + +A `PEXReactor` reactor implementation is provided to automate peer discovery. + +```go +book := p2p.NewAddrBook(addrBookFilePath) +pexReactor := p2p.NewPEXReactor(book) +... +switch := NewSwitch([]Reactor{pexReactor, myReactor, ...}) +``` diff --git a/addrbook.go b/addrbook.go new file mode 100644 index 000000000..e68cc7b3a --- /dev/null +++ b/addrbook.go @@ -0,0 +1,839 @@ +// Modified for Tendermint +// Originally Copyright (c) 2013-2014 Conformal Systems LLC. +// https://github.com/conformal/btcd/blob/master/LICENSE + +package p2p + +import ( + "encoding/binary" + "encoding/json" + "math" + "math/rand" + "net" + "os" + "sync" + "time" + + . "github.com/tendermint/tmlibs/common" + crypto "github.com/tendermint/go-crypto" +) + +const ( + // addresses under which the address manager will claim to need more addresses. + needAddressThreshold = 1000 + + // interval used to dump the address cache to disk for future use. + dumpAddressInterval = time.Minute * 2 + + // max addresses in each old address bucket. + oldBucketSize = 64 + + // buckets we split old addresses over. + oldBucketCount = 64 + + // max addresses in each new address bucket. + newBucketSize = 64 + + // buckets that we spread new addresses over. + newBucketCount = 256 + + // old buckets over which an address group will be spread. 
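+	// (an address "group" is derived from its network prefix; see groupKey below)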
+ oldBucketsPerGroup = 4 + + // new buckets over which an source address group will be spread. + newBucketsPerGroup = 32 + + // buckets a frequently seen new address may end up in. + maxNewBucketsPerAddress = 4 + + // days before which we assume an address has vanished + // if we have not seen it announced in that long. + numMissingDays = 30 + + // tries without a single success before we assume an address is bad. + numRetries = 3 + + // max failures we will accept without a success before considering an address bad. + maxFailures = 10 + + // days since the last success before we will consider evicting an address. + minBadDays = 7 + + // % of total addresses known returned by GetSelection. + getSelectionPercent = 23 + + // min addresses that must be returned by GetSelection. Useful for bootstrapping. + minGetSelection = 32 + + // max addresses returned by GetSelection + // NOTE: this must match "maxPexMessageSize" + maxGetSelection = 250 + + // current version of the on-disk format. + serializationVersion = 1 +) + +const ( + bucketTypeNew = 0x01 + bucketTypeOld = 0x02 +) + +// AddrBook - concurrency safe peer address manager. +type AddrBook struct { + BaseService + + mtx sync.Mutex + filePath string + routabilityStrict bool + rand *rand.Rand + key string + ourAddrs map[string]*NetAddress + addrLookup map[string]*knownAddress // new & old + addrNew []map[string]*knownAddress + addrOld []map[string]*knownAddress + wg sync.WaitGroup + nOld int + nNew int +} + +// NewAddrBook creates a new address book. +// Use Start to begin processing asynchronous address updates. +func NewAddrBook(filePath string, routabilityStrict bool) *AddrBook { + am := &AddrBook{ + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + ourAddrs: make(map[string]*NetAddress), + addrLookup: make(map[string]*knownAddress), + filePath: filePath, + routabilityStrict: routabilityStrict, + } + am.init() + am.BaseService = *NewBaseService(log, "AddrBook", am) + return am +} + +// When modifying this, don't forget to update loadFromFile() +func (a *AddrBook) init() { + a.key = crypto.CRandHex(24) // 24/2 * 8 = 96 bits + // New addr buckets + a.addrNew = make([]map[string]*knownAddress, newBucketCount) + for i := range a.addrNew { + a.addrNew[i] = make(map[string]*knownAddress) + } + // Old addr buckets + a.addrOld = make([]map[string]*knownAddress, oldBucketCount) + for i := range a.addrOld { + a.addrOld[i] = make(map[string]*knownAddress) + } +} + +// OnStart implements Service. +func (a *AddrBook) OnStart() error { + a.BaseService.OnStart() + a.loadFromFile(a.filePath) + a.wg.Add(1) + go a.saveRoutine() + return nil +} + +// OnStop implements Service. 
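+// Note: the final save is performed by saveRoutine when the Quit channel closes; use Wait to block until it finishes.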
+func (a *AddrBook) OnStop() { + a.BaseService.OnStop() +} + +func (a *AddrBook) Wait() { + a.wg.Wait() +} + +func (a *AddrBook) AddOurAddress(addr *NetAddress) { + a.mtx.Lock() + defer a.mtx.Unlock() + log.Info("Add our address to book", "addr", addr) + a.ourAddrs[addr.String()] = addr +} + +func (a *AddrBook) OurAddresses() []*NetAddress { + addrs := []*NetAddress{} + for _, addr := range a.ourAddrs { + addrs = append(addrs, addr) + } + return addrs +} + +// NOTE: addr must not be nil +func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) { + a.mtx.Lock() + defer a.mtx.Unlock() + log.Info("Add address to book", "addr", addr, "src", src) + a.addAddress(addr, src) +} + +func (a *AddrBook) NeedMoreAddrs() bool { + return a.Size() < needAddressThreshold +} + +func (a *AddrBook) Size() int { + a.mtx.Lock() + defer a.mtx.Unlock() + return a.size() +} + +func (a *AddrBook) size() int { + return a.nNew + a.nOld +} + +// Pick an address to connect to with new/old bias. +func (a *AddrBook) PickAddress(newBias int) *NetAddress { + a.mtx.Lock() + defer a.mtx.Unlock() + + if a.size() == 0 { + return nil + } + if newBias > 100 { + newBias = 100 + } + if newBias < 0 { + newBias = 0 + } + + // Bias between new and old addresses. + oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias)) + newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias) + + if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation { + // pick random Old bucket. + var bucket map[string]*knownAddress = nil + for len(bucket) == 0 { + bucket = a.addrOld[a.rand.Intn(len(a.addrOld))] + } + // pick a random ka from bucket. + randIndex := a.rand.Intn(len(bucket)) + for _, ka := range bucket { + if randIndex == 0 { + return ka.Addr + } + randIndex-- + } + PanicSanity("Should not happen") + } else { + // pick random New bucket. + var bucket map[string]*knownAddress = nil + for len(bucket) == 0 { + bucket = a.addrNew[a.rand.Intn(len(a.addrNew))] + } + // pick a random ka from bucket. + randIndex := a.rand.Intn(len(bucket)) + for _, ka := range bucket { + if randIndex == 0 { + return ka.Addr + } + randIndex-- + } + PanicSanity("Should not happen") + } + return nil +} + +func (a *AddrBook) MarkGood(addr *NetAddress) { + a.mtx.Lock() + defer a.mtx.Unlock() + ka := a.addrLookup[addr.String()] + if ka == nil { + return + } + ka.markGood() + if ka.isNew() { + a.moveToOld(ka) + } +} + +func (a *AddrBook) MarkAttempt(addr *NetAddress) { + a.mtx.Lock() + defer a.mtx.Unlock() + ka := a.addrLookup[addr.String()] + if ka == nil { + return + } + ka.markAttempt() +} + +// MarkBad currently just ejects the address. In the future, consider +// blacklisting. +func (a *AddrBook) MarkBad(addr *NetAddress) { + a.RemoveAddress(addr) +} + +// RemoveAddress removes the address from the book. +func (a *AddrBook) RemoveAddress(addr *NetAddress) { + a.mtx.Lock() + defer a.mtx.Unlock() + ka := a.addrLookup[addr.String()] + if ka == nil { + return + } + log.Info("Remove address from book", "addr", addr) + a.removeFromAllBuckets(ka) +} + +/* Peer exchange */ + +// GetSelection randomly selects some addresses (old & new). Suitable for peer-exchange protocols. 
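+// It returns roughly getSelectionPercent of the book, with at least minGetSelection addresses (or the whole book, if smaller) and at most maxGetSelection.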
+func (a *AddrBook) GetSelection() []*NetAddress { + a.mtx.Lock() + defer a.mtx.Unlock() + + if a.size() == 0 { + return nil + } + + allAddr := make([]*NetAddress, a.size()) + i := 0 + for _, v := range a.addrLookup { + allAddr[i] = v.Addr + i++ + } + + numAddresses := MaxInt( + MinInt(minGetSelection, len(allAddr)), + len(allAddr)*getSelectionPercent/100) + numAddresses = MinInt(maxGetSelection, numAddresses) + + // Fisher-Yates shuffle the array. We only need to do the first + // `numAddresses' since we are throwing the rest. + for i := 0; i < numAddresses; i++ { + // pick a number between current index and the end + j := rand.Intn(len(allAddr)-i) + i + allAddr[i], allAddr[j] = allAddr[j], allAddr[i] + } + + // slice off the limit we are willing to share. + return allAddr[:numAddresses] +} + +/* Loading & Saving */ + +type addrBookJSON struct { + Key string + Addrs []*knownAddress +} + +func (a *AddrBook) saveToFile(filePath string) { + log.Info("Saving AddrBook to file", "size", a.Size()) + + a.mtx.Lock() + defer a.mtx.Unlock() + // Compile Addrs + addrs := []*knownAddress{} + for _, ka := range a.addrLookup { + addrs = append(addrs, ka) + } + + aJSON := &addrBookJSON{ + Key: a.key, + Addrs: addrs, + } + + jsonBytes, err := json.MarshalIndent(aJSON, "", "\t") + if err != nil { + log.Error("Failed to save AddrBook to file", "err", err) + return + } + err = WriteFileAtomic(filePath, jsonBytes, 0644) + if err != nil { + log.Error("Failed to save AddrBook to file", "file", filePath, "error", err) + } +} + +// Returns false if file does not exist. +// Panics if file is corrupt. +func (a *AddrBook) loadFromFile(filePath string) bool { + // If doesn't exist, do nothing. + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return false + } + + // Load addrBookJSON{} + r, err := os.Open(filePath) + if err != nil { + PanicCrisis(Fmt("Error opening file %s: %v", filePath, err)) + } + defer r.Close() + aJSON := &addrBookJSON{} + dec := json.NewDecoder(r) + err = dec.Decode(aJSON) + if err != nil { + PanicCrisis(Fmt("Error reading file %s: %v", filePath, err)) + } + + // Restore all the fields... + // Restore the key + a.key = aJSON.Key + // Restore .addrNew & .addrOld + for _, ka := range aJSON.Addrs { + for _, bucketIndex := range ka.Buckets { + bucket := a.getBucket(ka.BucketType, bucketIndex) + bucket[ka.Addr.String()] = ka + } + a.addrLookup[ka.Addr.String()] = ka + if ka.BucketType == bucketTypeNew { + a.nNew++ + } else { + a.nOld++ + } + } + return true +} + +// Save saves the book. +func (a *AddrBook) Save() { + log.Info("Saving AddrBook to file", "size", a.Size()) + a.saveToFile(a.filePath) +} + +/* Private methods */ + +func (a *AddrBook) saveRoutine() { + dumpAddressTicker := time.NewTicker(dumpAddressInterval) +out: + for { + select { + case <-dumpAddressTicker.C: + a.saveToFile(a.filePath) + case <-a.Quit: + break out + } + } + dumpAddressTicker.Stop() + a.saveToFile(a.filePath) + a.wg.Done() + log.Notice("Address handler done") +} + +func (a *AddrBook) getBucket(bucketType byte, bucketIdx int) map[string]*knownAddress { + switch bucketType { + case bucketTypeNew: + return a.addrNew[bucketIdx] + case bucketTypeOld: + return a.addrOld[bucketIdx] + default: + PanicSanity("Should not happen") + return nil + } +} + +// Adds ka to new bucket. Returns false if it couldn't do it cuz buckets full. +// NOTE: currently it always returns true. 
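+// The only false case is the sanity check below: an address already in an old bucket is rejected.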
+func (a *AddrBook) addToNewBucket(ka *knownAddress, bucketIdx int) bool { + // Sanity check + if ka.isOld() { + log.Warn(Fmt("Cannot add address already in old bucket to a new bucket: %v", ka)) + return false + } + + addrStr := ka.Addr.String() + bucket := a.getBucket(bucketTypeNew, bucketIdx) + + // Already exists? + if _, ok := bucket[addrStr]; ok { + return true + } + + // Enforce max addresses. + if len(bucket) > newBucketSize { + log.Notice("new bucket is full, expiring old ") + a.expireNew(bucketIdx) + } + + // Add to bucket. + bucket[addrStr] = ka + if ka.addBucketRef(bucketIdx) == 1 { + a.nNew++ + } + + // Ensure in addrLookup + a.addrLookup[addrStr] = ka + + return true +} + +// Adds ka to old bucket. Returns false if it couldn't do it cuz buckets full. +func (a *AddrBook) addToOldBucket(ka *knownAddress, bucketIdx int) bool { + // Sanity check + if ka.isNew() { + log.Warn(Fmt("Cannot add new address to old bucket: %v", ka)) + return false + } + if len(ka.Buckets) != 0 { + log.Warn(Fmt("Cannot add already old address to another old bucket: %v", ka)) + return false + } + + addrStr := ka.Addr.String() + bucket := a.getBucket(bucketTypeNew, bucketIdx) + + // Already exists? + if _, ok := bucket[addrStr]; ok { + return true + } + + // Enforce max addresses. + if len(bucket) > oldBucketSize { + return false + } + + // Add to bucket. + bucket[addrStr] = ka + if ka.addBucketRef(bucketIdx) == 1 { + a.nOld++ + } + + // Ensure in addrLookup + a.addrLookup[addrStr] = ka + + return true +} + +func (a *AddrBook) removeFromBucket(ka *knownAddress, bucketType byte, bucketIdx int) { + if ka.BucketType != bucketType { + log.Warn(Fmt("Bucket type mismatch: %v", ka)) + return + } + bucket := a.getBucket(bucketType, bucketIdx) + delete(bucket, ka.Addr.String()) + if ka.removeBucketRef(bucketIdx) == 0 { + if bucketType == bucketTypeNew { + a.nNew-- + } else { + a.nOld-- + } + delete(a.addrLookup, ka.Addr.String()) + } +} + +func (a *AddrBook) removeFromAllBuckets(ka *knownAddress) { + for _, bucketIdx := range ka.Buckets { + bucket := a.getBucket(ka.BucketType, bucketIdx) + delete(bucket, ka.Addr.String()) + } + ka.Buckets = nil + if ka.BucketType == bucketTypeNew { + a.nNew-- + } else { + a.nOld-- + } + delete(a.addrLookup, ka.Addr.String()) +} + +func (a *AddrBook) pickOldest(bucketType byte, bucketIdx int) *knownAddress { + bucket := a.getBucket(bucketType, bucketIdx) + var oldest *knownAddress + for _, ka := range bucket { + if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt) { + oldest = ka + } + } + return oldest +} + +func (a *AddrBook) addAddress(addr, src *NetAddress) { + if a.routabilityStrict && !addr.Routable() { + log.Warn(Fmt("Cannot add non-routable address %v", addr)) + return + } + if _, ok := a.ourAddrs[addr.String()]; ok { + // Ignore our own listener address. + return + } + + ka := a.addrLookup[addr.String()] + + if ka != nil { + // Already old. + if ka.isOld() { + return + } + // Already in max new buckets. + if len(ka.Buckets) == maxNewBucketsPerAddress { + return + } + // The more entries we have, the less likely we are to add more. + factor := int32(2 * len(ka.Buckets)) + if a.rand.Int31n(factor) != 0 { + return + } + } else { + ka = newKnownAddress(addr, src) + } + + bucket := a.calcNewBucket(addr, src) + a.addToNewBucket(ka, bucket) + + log.Notice("Added new address", "address", addr, "total", a.size()) +} + +// Make space in the new buckets by expiring the really bad entries. +// If no bad entries are available we remove the oldest. 
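+// Only a single entry is removed per call.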
+func (a *AddrBook) expireNew(bucketIdx int) { + for addrStr, ka := range a.addrNew[bucketIdx] { + // If an entry is bad, throw it away + if ka.isBad() { + log.Notice(Fmt("expiring bad address %v", addrStr)) + a.removeFromBucket(ka, bucketTypeNew, bucketIdx) + return + } + } + + // If we haven't thrown out a bad entry, throw out the oldest entry + oldest := a.pickOldest(bucketTypeNew, bucketIdx) + a.removeFromBucket(oldest, bucketTypeNew, bucketIdx) +} + +// Promotes an address from new to old. +// TODO: Move to old probabilistically. +// The better a node is, the less likely it should be evicted from an old bucket. +func (a *AddrBook) moveToOld(ka *knownAddress) { + // Sanity check + if ka.isOld() { + log.Warn(Fmt("Cannot promote address that is already old %v", ka)) + return + } + if len(ka.Buckets) == 0 { + log.Warn(Fmt("Cannot promote address that isn't in any new buckets %v", ka)) + return + } + + // Remember one of the buckets in which ka is in. + freedBucket := ka.Buckets[0] + // Remove from all (new) buckets. + a.removeFromAllBuckets(ka) + // It's officially old now. + ka.BucketType = bucketTypeOld + + // Try to add it to its oldBucket destination. + oldBucketIdx := a.calcOldBucket(ka.Addr) + added := a.addToOldBucket(ka, oldBucketIdx) + if !added { + // No room, must evict something + oldest := a.pickOldest(bucketTypeOld, oldBucketIdx) + a.removeFromBucket(oldest, bucketTypeOld, oldBucketIdx) + // Find new bucket to put oldest in + newBucketIdx := a.calcNewBucket(oldest.Addr, oldest.Src) + added := a.addToNewBucket(oldest, newBucketIdx) + // No space in newBucket either, just put it in freedBucket from above. + if !added { + added := a.addToNewBucket(oldest, freedBucket) + if !added { + log.Warn(Fmt("Could not migrate oldest %v to freedBucket %v", oldest, freedBucket)) + } + } + // Finally, add to bucket again. + added = a.addToOldBucket(ka, oldBucketIdx) + if !added { + log.Warn(Fmt("Could not re-add ka %v to oldBucketIdx %v", ka, oldBucketIdx)) + } + } +} + +// doublesha256( key + sourcegroup + +// int64(doublesha256(key + group + sourcegroup))%bucket_per_group ) % num_new_buckets +func (a *AddrBook) calcNewBucket(addr, src *NetAddress) int { + data1 := []byte{} + data1 = append(data1, []byte(a.key)...) + data1 = append(data1, []byte(a.groupKey(addr))...) + data1 = append(data1, []byte(a.groupKey(src))...) + hash1 := doubleSha256(data1) + hash64 := binary.BigEndian.Uint64(hash1) + hash64 %= newBucketsPerGroup + var hashbuf [8]byte + binary.BigEndian.PutUint64(hashbuf[:], hash64) + data2 := []byte{} + data2 = append(data2, []byte(a.key)...) + data2 = append(data2, a.groupKey(src)...) + data2 = append(data2, hashbuf[:]...) + + hash2 := doubleSha256(data2) + return int(binary.BigEndian.Uint64(hash2) % newBucketCount) +} + +// doublesha256( key + group + +// int64(doublesha256(key + addr))%buckets_per_group ) % num_old_buckets +func (a *AddrBook) calcOldBucket(addr *NetAddress) int { + data1 := []byte{} + data1 = append(data1, []byte(a.key)...) + data1 = append(data1, []byte(addr.String())...) + hash1 := doubleSha256(data1) + hash64 := binary.BigEndian.Uint64(hash1) + hash64 %= oldBucketsPerGroup + var hashbuf [8]byte + binary.BigEndian.PutUint64(hashbuf[:], hash64) + data2 := []byte{} + data2 = append(data2, []byte(a.key)...) + data2 = append(data2, a.groupKey(addr)...) + data2 = append(data2, hashbuf[:]...) + + hash2 := doubleSha256(data2) + return int(binary.BigEndian.Uint64(hash2) % oldBucketCount) +} + +// Return a string representing the network group of this address. 
+// This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string +// "local" for a local address and the string "unroutable for an unroutable +// address. +func (a *AddrBook) groupKey(na *NetAddress) string { + if a.routabilityStrict && na.Local() { + return "local" + } + if a.routabilityStrict && !na.Routable() { + return "unroutable" + } + + if ipv4 := na.IP.To4(); ipv4 != nil { + return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String() + } + if na.RFC6145() || na.RFC6052() { + // last four bytes are the ip address + ip := net.IP(na.IP[12:16]) + return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() + } + + if na.RFC3964() { + ip := net.IP(na.IP[2:7]) + return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() + + } + if na.RFC4380() { + // teredo tunnels have the last 4 bytes as the v4 address XOR + // 0xff. + ip := net.IP(make([]byte, 4)) + for i, byte := range na.IP[12:16] { + ip[i] = byte ^ 0xff + } + return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() + } + + // OK, so now we know ourselves to be a IPv6 address. + // bitcoind uses /32 for everything, except for Hurricane Electric's + // (he.net) IP range, which it uses /36 for. + bits := 32 + heNet := &net.IPNet{IP: net.ParseIP("2001:470::"), + Mask: net.CIDRMask(32, 128)} + if heNet.Contains(na.IP) { + bits = 36 + } + + return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String() +} + +//----------------------------------------------------------------------------- + +/* + knownAddress + + tracks information about a known network address that is used + to determine how viable an address is. +*/ +type knownAddress struct { + Addr *NetAddress + Src *NetAddress + Attempts int32 + LastAttempt time.Time + LastSuccess time.Time + BucketType byte + Buckets []int +} + +func newKnownAddress(addr *NetAddress, src *NetAddress) *knownAddress { + return &knownAddress{ + Addr: addr, + Src: src, + Attempts: 0, + LastAttempt: time.Now(), + BucketType: bucketTypeNew, + Buckets: nil, + } +} + +func (ka *knownAddress) isOld() bool { + return ka.BucketType == bucketTypeOld +} + +func (ka *knownAddress) isNew() bool { + return ka.BucketType == bucketTypeNew +} + +func (ka *knownAddress) markAttempt() { + now := time.Now() + ka.LastAttempt = now + ka.Attempts += 1 +} + +func (ka *knownAddress) markGood() { + now := time.Now() + ka.LastAttempt = now + ka.Attempts = 0 + ka.LastSuccess = now +} + +func (ka *knownAddress) addBucketRef(bucketIdx int) int { + for _, bucket := range ka.Buckets { + if bucket == bucketIdx { + log.Warn(Fmt("Bucket already exists in ka.Buckets: %v", ka)) + return -1 + } + } + ka.Buckets = append(ka.Buckets, bucketIdx) + return len(ka.Buckets) +} + +func (ka *knownAddress) removeBucketRef(bucketIdx int) int { + buckets := []int{} + for _, bucket := range ka.Buckets { + if bucket != bucketIdx { + buckets = append(buckets, bucket) + } + } + if len(buckets) != len(ka.Buckets)-1 { + log.Warn(Fmt("bucketIdx not found in ka.Buckets: %v", ka)) + return -1 + } + ka.Buckets = buckets + return len(ka.Buckets) +} + +/* + An address is bad if the address in question has not been tried in the last + minute and meets one of the following criteria: + + 1) It claims to be from the future + 2) It hasn't been seen in over a month + 3) It has failed at least three times and never succeeded + 4) It has failed ten times in the last week + + All addresses that meet these criteria are assumed to be worthless and not + worth keeping hold of. 
+*/ +func (ka *knownAddress) isBad() bool { + // Has been attempted in the last minute --> good + if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) { + return false + } + + // Over a month old? + if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) { + return true + } + + // Never succeeded? + if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries { + return true + } + + // Hasn't succeeded in too long? + if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) && + ka.Attempts >= maxFailures { + return true + } + + return false +} diff --git a/addrbook_test.go b/addrbook_test.go new file mode 100644 index 000000000..16aea8ef9 --- /dev/null +++ b/addrbook_test.go @@ -0,0 +1,166 @@ +package p2p + +import ( + "fmt" + "io/ioutil" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" +) + +func createTempFileName(prefix string) string { + f, err := ioutil.TempFile("", prefix) + if err != nil { + panic(err) + } + fname := f.Name() + err = f.Close() + if err != nil { + panic(err) + } + return fname +} + +func TestAddrBookSaveLoad(t *testing.T) { + fname := createTempFileName("addrbook_test") + + // 0 addresses + book := NewAddrBook(fname, true) + book.saveToFile(fname) + + book = NewAddrBook(fname, true) + book.loadFromFile(fname) + + assert.Zero(t, book.Size()) + + // 100 addresses + randAddrs := randNetAddressPairs(t, 100) + + for _, addrSrc := range randAddrs { + book.AddAddress(addrSrc.addr, addrSrc.src) + } + + assert.Equal(t, 100, book.Size()) + book.saveToFile(fname) + + book = NewAddrBook(fname, true) + book.loadFromFile(fname) + + assert.Equal(t, 100, book.Size()) +} + +func TestAddrBookLookup(t *testing.T) { + fname := createTempFileName("addrbook_test") + + randAddrs := randNetAddressPairs(t, 100) + + book := NewAddrBook(fname, true) + for _, addrSrc := range randAddrs { + addr := addrSrc.addr + src := addrSrc.src + book.AddAddress(addr, src) + + ka := book.addrLookup[addr.String()] + assert.NotNil(t, ka, "Expected to find KnownAddress %v but wasn't there.", addr) + + if !(ka.Addr.Equals(addr) && ka.Src.Equals(src)) { + t.Fatalf("KnownAddress doesn't match addr & src") + } + } +} + +func TestAddrBookPromoteToOld(t *testing.T) { + fname := createTempFileName("addrbook_test") + + randAddrs := randNetAddressPairs(t, 100) + + book := NewAddrBook(fname, true) + for _, addrSrc := range randAddrs { + book.AddAddress(addrSrc.addr, addrSrc.src) + } + + // Attempt all addresses. 
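+	// (MarkAttempt only records the attempt; it does not move addresses between buckets.)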
+ for _, addrSrc := range randAddrs { + book.MarkAttempt(addrSrc.addr) + } + + // Promote half of them + for i, addrSrc := range randAddrs { + if i%2 == 0 { + book.MarkGood(addrSrc.addr) + } + } + + // TODO: do more testing :) + + selection := book.GetSelection() + t.Logf("selection: %v", selection) + + if len(selection) > book.Size() { + t.Errorf("selection could not be bigger than the book") + } +} + +func TestAddrBookHandlesDuplicates(t *testing.T) { + fname := createTempFileName("addrbook_test") + + book := NewAddrBook(fname, true) + + randAddrs := randNetAddressPairs(t, 100) + + differentSrc := randIPv4Address(t) + for _, addrSrc := range randAddrs { + book.AddAddress(addrSrc.addr, addrSrc.src) + book.AddAddress(addrSrc.addr, addrSrc.src) // duplicate + book.AddAddress(addrSrc.addr, differentSrc) // different src + } + + assert.Equal(t, 100, book.Size()) +} + +type netAddressPair struct { + addr *NetAddress + src *NetAddress +} + +func randNetAddressPairs(t *testing.T, n int) []netAddressPair { + randAddrs := make([]netAddressPair, n) + for i := 0; i < n; i++ { + randAddrs[i] = netAddressPair{addr: randIPv4Address(t), src: randIPv4Address(t)} + } + return randAddrs +} + +func randIPv4Address(t *testing.T) *NetAddress { + for { + ip := fmt.Sprintf("%v.%v.%v.%v", + rand.Intn(254)+1, + rand.Intn(255), + rand.Intn(255), + rand.Intn(255), + ) + port := rand.Intn(65535-1) + 1 + addr, err := NewNetAddressString(fmt.Sprintf("%v:%v", ip, port)) + assert.Nil(t, err, "error generating rand network address") + if addr.Routable() { + return addr + } + } +} + +func TestAddrBookRemoveAddress(t *testing.T) { + fname := createTempFileName("addrbook_test") + book := NewAddrBook(fname, true) + + addr := randIPv4Address(t) + book.AddAddress(addr, addr) + assert.Equal(t, 1, book.Size()) + + book.RemoveAddress(addr) + assert.Equal(t, 0, book.Size()) + + nonExistingAddr := randIPv4Address(t) + book.RemoveAddress(nonExistingAddr) + assert.Equal(t, 0, book.Size()) +} diff --git a/config.go b/config.go new file mode 100644 index 000000000..a8b7e343b --- /dev/null +++ b/config.go @@ -0,0 +1,45 @@ +package p2p + +import ( + cfg "github.com/tendermint/go-config" +) + +const ( + // Switch config keys + configKeyDialTimeoutSeconds = "dial_timeout_seconds" + configKeyHandshakeTimeoutSeconds = "handshake_timeout_seconds" + configKeyMaxNumPeers = "max_num_peers" + configKeyAuthEnc = "authenticated_encryption" + + // MConnection config keys + configKeySendRate = "send_rate" + configKeyRecvRate = "recv_rate" + + // Fuzz params + configFuzzEnable = "fuzz_enable" // use the fuzz wrapped conn + configFuzzMode = "fuzz_mode" // eg. 
drop, delay + configFuzzMaxDelayMilliseconds = "fuzz_max_delay_milliseconds" + configFuzzProbDropRW = "fuzz_prob_drop_rw" + configFuzzProbDropConn = "fuzz_prob_drop_conn" + configFuzzProbSleep = "fuzz_prob_sleep" +) + +func setConfigDefaults(config cfg.Config) { + // Switch default config + config.SetDefault(configKeyDialTimeoutSeconds, 3) + config.SetDefault(configKeyHandshakeTimeoutSeconds, 20) + config.SetDefault(configKeyMaxNumPeers, 50) + config.SetDefault(configKeyAuthEnc, true) + + // MConnection default config + config.SetDefault(configKeySendRate, 512000) // 500KB/s + config.SetDefault(configKeyRecvRate, 512000) // 500KB/s + + // Fuzz defaults + config.SetDefault(configFuzzEnable, false) + config.SetDefault(configFuzzMode, FuzzModeDrop) + config.SetDefault(configFuzzMaxDelayMilliseconds, 3000) + config.SetDefault(configFuzzProbDropRW, 0.2) + config.SetDefault(configFuzzProbDropConn, 0.00) + config.SetDefault(configFuzzProbSleep, 0.00) +} diff --git a/connection.go b/connection.go new file mode 100644 index 000000000..629ab7b0d --- /dev/null +++ b/connection.go @@ -0,0 +1,686 @@ +package p2p + +import ( + "bufio" + "fmt" + "io" + "math" + "net" + "runtime/debug" + "sync/atomic" + "time" + + wire "github.com/tendermint/go-wire" + cmn "github.com/tendermint/tmlibs/common" + flow "github.com/tendermint/tmlibs/flowrate" +) + +const ( + numBatchMsgPackets = 10 + minReadBufferSize = 1024 + minWriteBufferSize = 65536 + updateState = 2 * time.Second + pingTimeout = 40 * time.Second + flushThrottle = 100 * time.Millisecond + + defaultSendQueueCapacity = 1 + defaultSendRate = int64(512000) // 500KB/s + defaultRecvBufferCapacity = 4096 + defaultRecvMessageCapacity = 22020096 // 21MB + defaultRecvRate = int64(512000) // 500KB/s + defaultSendTimeout = 10 * time.Second +) + +type receiveCbFunc func(chID byte, msgBytes []byte) +type errorCbFunc func(interface{}) + +/* +Each peer has one `MConnection` (multiplex connection) instance. + +__multiplex__ *noun* a system or signal involving simultaneous transmission of +several messages along a single channel of communication. + +Each `MConnection` handles message transmission on multiple abstract communication +`Channel`s. Each channel has a globally unique byte id. +The byte id and the relative priorities of each `Channel` are configured upon +initialization of the connection. + +There are two methods for sending messages: + func (m MConnection) Send(chID byte, msg interface{}) bool {} + func (m MConnection) TrySend(chID byte, msg interface{}) bool {} + +`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued +for the channel with the given id byte `chID`, or until the request times out. +The message `msg` is serialized using the `tendermint/wire` submodule's +`WriteBinary()` reflection routine. + +`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's +queue is full. + +Inbound message bytes are handled with an onReceive callback function. +*/ +type MConnection struct { + cmn.BaseService + + conn net.Conn + bufReader *bufio.Reader + bufWriter *bufio.Writer + sendMonitor *flow.Monitor + recvMonitor *flow.Monitor + send chan struct{} + pong chan struct{} + channels []*Channel + channelsIdx map[byte]*Channel + onReceive receiveCbFunc + onError errorCbFunc + errored uint32 + config *MConnConfig + + quit chan struct{} + flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. 
+ pingTimer *cmn.RepeatTimer // send pings periodically + chStatsTimer *cmn.RepeatTimer // update channel stats periodically + + LocalAddress *NetAddress + RemoteAddress *NetAddress +} + +// MConnConfig is a MConnection configuration. +type MConnConfig struct { + SendRate int64 + RecvRate int64 +} + +// DefaultMConnConfig returns the default config. +func DefaultMConnConfig() *MConnConfig { + return &MConnConfig{ + SendRate: defaultSendRate, + RecvRate: defaultRecvRate, + } +} + +// NewMConnection wraps net.Conn and creates multiplex connection +func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection { + return NewMConnectionWithConfig( + conn, + chDescs, + onReceive, + onError, + DefaultMConnConfig()) +} + +// NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config +func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection { + mconn := &MConnection{ + conn: conn, + bufReader: bufio.NewReaderSize(conn, minReadBufferSize), + bufWriter: bufio.NewWriterSize(conn, minWriteBufferSize), + sendMonitor: flow.New(0, 0), + recvMonitor: flow.New(0, 0), + send: make(chan struct{}, 1), + pong: make(chan struct{}), + onReceive: onReceive, + onError: onError, + config: config, + + LocalAddress: NewNetAddress(conn.LocalAddr()), + RemoteAddress: NewNetAddress(conn.RemoteAddr()), + } + + // Create channels + var channelsIdx = map[byte]*Channel{} + var channels = []*Channel{} + + for _, desc := range chDescs { + descCopy := *desc // copy the desc else unsafe access across connections + channel := newChannel(mconn, &descCopy) + channelsIdx[channel.id] = channel + channels = append(channels, channel) + } + mconn.channels = channels + mconn.channelsIdx = channelsIdx + + mconn.BaseService = *cmn.NewBaseService(log, "MConnection", mconn) + + return mconn +} + +func (c *MConnection) OnStart() error { + c.BaseService.OnStart() + c.quit = make(chan struct{}) + c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle) + c.pingTimer = cmn.NewRepeatTimer("ping", pingTimeout) + c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateState) + go c.sendRoutine() + go c.recvRoutine() + return nil +} + +func (c *MConnection) OnStop() { + c.BaseService.OnStop() + c.flushTimer.Stop() + c.pingTimer.Stop() + c.chStatsTimer.Stop() + if c.quit != nil { + close(c.quit) + } + c.conn.Close() + // We can't close pong safely here because + // recvRoutine may write to it after we've stopped. + // Though it doesn't need to get closed at all, + // we close it @ recvRoutine. + // close(c.pong) +} + +func (c *MConnection) String() string { + return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr()) +} + +func (c *MConnection) flush() { + log.Debug("Flush", "conn", c) + err := c.bufWriter.Flush() + if err != nil { + log.Warn("MConnection flush failed", "error", err) + } +} + +// Catch panics, usually caused by remote disconnects. +func (c *MConnection) _recover() { + if r := recover(); r != nil { + stack := debug.Stack() + err := cmn.StackError{r, stack} + c.stopForError(err) + } +} + +func (c *MConnection) stopForError(r interface{}) { + c.Stop() + if atomic.CompareAndSwapUint32(&c.errored, 0, 1) { + if c.onError != nil { + c.onError(r) + } + } +} + +// Queues a message to be sent to channel. 
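+// Send blocks until the message is queued on the channel or the underlying send times out (see Channel.sendBytes).
+// It returns false if the connection is not running, the channel ID is unknown, or the queue stays full.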
+func (c *MConnection) Send(chID byte, msg interface{}) bool { + if !c.IsRunning() { + return false + } + + log.Debug("Send", "channel", chID, "conn", c, "msg", msg) //, "bytes", wire.BinaryBytes(msg)) + + // Send message to channel. + channel, ok := c.channelsIdx[chID] + if !ok { + log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID)) + return false + } + + success := channel.sendBytes(wire.BinaryBytes(msg)) + if success { + // Wake up sendRoutine if necessary + select { + case c.send <- struct{}{}: + default: + } + } else { + log.Warn("Send failed", "channel", chID, "conn", c, "msg", msg) + } + return success +} + +// Queues a message to be sent to channel. +// Nonblocking, returns true if successful. +func (c *MConnection) TrySend(chID byte, msg interface{}) bool { + if !c.IsRunning() { + return false + } + + log.Debug("TrySend", "channel", chID, "conn", c, "msg", msg) + + // Send message to channel. + channel, ok := c.channelsIdx[chID] + if !ok { + log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID)) + return false + } + + ok = channel.trySendBytes(wire.BinaryBytes(msg)) + if ok { + // Wake up sendRoutine if necessary + select { + case c.send <- struct{}{}: + default: + } + } + + return ok +} + +// CanSend returns true if you can send more data onto the chID, false +// otherwise. Use only as a heuristic. +func (c *MConnection) CanSend(chID byte) bool { + if !c.IsRunning() { + return false + } + + channel, ok := c.channelsIdx[chID] + if !ok { + log.Error(cmn.Fmt("Unknown channel %X", chID)) + return false + } + return channel.canSend() +} + +// sendRoutine polls for packets to send from channels. +func (c *MConnection) sendRoutine() { + defer c._recover() + +FOR_LOOP: + for { + var n int + var err error + select { + case <-c.flushTimer.Ch: + // NOTE: flushTimer.Set() must be called every time + // something is written to .bufWriter. + c.flush() + case <-c.chStatsTimer.Ch: + for _, channel := range c.channels { + channel.updateStats() + } + case <-c.pingTimer.Ch: + log.Debug("Send Ping") + wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) + c.sendMonitor.Update(int(n)) + c.flush() + case <-c.pong: + log.Debug("Send Pong") + wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) + c.sendMonitor.Update(int(n)) + c.flush() + case <-c.quit: + break FOR_LOOP + case <-c.send: + // Send some msgPackets + eof := c.sendSomeMsgPackets() + if !eof { + // Keep sendRoutine awake. + select { + case c.send <- struct{}{}: + default: + } + } + } + + if !c.IsRunning() { + break FOR_LOOP + } + if err != nil { + log.Warn("Connection failed @ sendRoutine", "conn", c, "error", err) + c.stopForError(err) + break FOR_LOOP + } + } + + // Cleanup +} + +// Returns true if messages from channels were exhausted. +// Blocks in accordance to .sendMonitor throttling. +func (c *MConnection) sendSomeMsgPackets() bool { + // Block until .sendMonitor says we can write. + // Once we're ready we send more than we asked for, + // but amortized it should even out. + c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true) + + // Now send some msgPackets. + for i := 0; i < numBatchMsgPackets; i++ { + if c.sendMsgPacket() { + return true + } + } + return false +} + +// Returns true if messages from channels were exhausted. +func (c *MConnection) sendMsgPacket() bool { + // Choose a channel to create a msgPacket from. + // The chosen channel will be the one whose recentlySent/priority is the least. 
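+	// recentlySent decays exponentially over time (see updateStats), so lightly used, high-priority channels are favored.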
+ var leastRatio float32 = math.MaxFloat32 + var leastChannel *Channel + for _, channel := range c.channels { + // If nothing to send, skip this channel + if !channel.isSendPending() { + continue + } + // Get ratio, and keep track of lowest ratio. + ratio := float32(channel.recentlySent) / float32(channel.priority) + if ratio < leastRatio { + leastRatio = ratio + leastChannel = channel + } + } + + // Nothing to send? + if leastChannel == nil { + return true + } else { + // log.Info("Found a msgPacket to send") + } + + // Make & send a msgPacket from this channel + n, err := leastChannel.writeMsgPacketTo(c.bufWriter) + if err != nil { + log.Warn("Failed to write msgPacket", "error", err) + c.stopForError(err) + return true + } + c.sendMonitor.Update(int(n)) + c.flushTimer.Set() + return false +} + +// recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer. +// After a whole message has been assembled, it's pushed to onReceive(). +// Blocks depending on how the connection is throttled. +func (c *MConnection) recvRoutine() { + defer c._recover() + +FOR_LOOP: + for { + // Block until .recvMonitor says we can read. + c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true) + + /* + // Peek into bufReader for debugging + if numBytes := c.bufReader.Buffered(); numBytes > 0 { + log.Info("Peek connection buffer", "numBytes", numBytes, "bytes", log15.Lazy{func() []byte { + bytes, err := c.bufReader.Peek(MinInt(numBytes, 100)) + if err == nil { + return bytes + } else { + log.Warn("Error peeking connection buffer", "error", err) + return nil + } + }}) + } + */ + + // Read packet type + var n int + var err error + pktType := wire.ReadByte(c.bufReader, &n, &err) + c.recvMonitor.Update(int(n)) + if err != nil { + if c.IsRunning() { + log.Warn("Connection failed @ recvRoutine (reading byte)", "conn", c, "error", err) + c.stopForError(err) + } + break FOR_LOOP + } + + // Read more depending on packet type. + switch pktType { + case packetTypePing: + // TODO: prevent abuse, as they cause flush()'s. + log.Debug("Receive Ping") + c.pong <- struct{}{} + case packetTypePong: + // do nothing + log.Debug("Receive Pong") + case packetTypeMsg: + pkt, n, err := msgPacket{}, int(0), error(nil) + wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err) + c.recvMonitor.Update(int(n)) + if err != nil { + if c.IsRunning() { + log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err) + c.stopForError(err) + } + break FOR_LOOP + } + channel, ok := c.channelsIdx[pkt.ChannelID] + if !ok || channel == nil { + cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID)) + } + msgBytes, err := channel.recvMsgPacket(pkt) + if err != nil { + if c.IsRunning() { + log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err) + c.stopForError(err) + } + break FOR_LOOP + } + if msgBytes != nil { + log.Debug("Received bytes", "chID", pkt.ChannelID, "msgBytes", msgBytes) + c.onReceive(pkt.ChannelID, msgBytes) + } + default: + cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType)) + } + + // TODO: shouldn't this go in the sendRoutine? + // Better to send a ping packet when *we* haven't sent anything for a while. 
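+		// For now, receiving any packet resets the ping timer.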
+ c.pingTimer.Reset() + } + + // Cleanup + close(c.pong) + for _ = range c.pong { + // Drain + } +} + +type ConnectionStatus struct { + SendMonitor flow.Status + RecvMonitor flow.Status + Channels []ChannelStatus +} + +type ChannelStatus struct { + ID byte + SendQueueCapacity int + SendQueueSize int + Priority int + RecentlySent int64 +} + +func (c *MConnection) Status() ConnectionStatus { + var status ConnectionStatus + status.SendMonitor = c.sendMonitor.Status() + status.RecvMonitor = c.recvMonitor.Status() + status.Channels = make([]ChannelStatus, len(c.channels)) + for i, channel := range c.channels { + status.Channels[i] = ChannelStatus{ + ID: channel.id, + SendQueueCapacity: cap(channel.sendQueue), + SendQueueSize: int(channel.sendQueueSize), // TODO use atomic + Priority: channel.priority, + RecentlySent: channel.recentlySent, + } + } + return status +} + +//----------------------------------------------------------------------------- + +type ChannelDescriptor struct { + ID byte + Priority int + SendQueueCapacity int + RecvBufferCapacity int + RecvMessageCapacity int +} + +func (chDesc *ChannelDescriptor) FillDefaults() { + if chDesc.SendQueueCapacity == 0 { + chDesc.SendQueueCapacity = defaultSendQueueCapacity + } + if chDesc.RecvBufferCapacity == 0 { + chDesc.RecvBufferCapacity = defaultRecvBufferCapacity + } + if chDesc.RecvMessageCapacity == 0 { + chDesc.RecvMessageCapacity = defaultRecvMessageCapacity + } +} + +// TODO: lowercase. +// NOTE: not goroutine-safe. +type Channel struct { + conn *MConnection + desc *ChannelDescriptor + id byte + sendQueue chan []byte + sendQueueSize int32 // atomic. + recving []byte + sending []byte + priority int + recentlySent int64 // exponential moving average +} + +func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel { + desc.FillDefaults() + if desc.Priority <= 0 { + cmn.PanicSanity("Channel default priority must be a postive integer") + } + return &Channel{ + conn: conn, + desc: desc, + id: desc.ID, + sendQueue: make(chan []byte, desc.SendQueueCapacity), + recving: make([]byte, 0, desc.RecvBufferCapacity), + priority: desc.Priority, + } +} + +// Queues message to send to this channel. +// Goroutine-safe +// Times out (and returns false) after defaultSendTimeout +func (ch *Channel) sendBytes(bytes []byte) bool { + select { + case ch.sendQueue <- bytes: + atomic.AddInt32(&ch.sendQueueSize, 1) + return true + case <-time.After(defaultSendTimeout): + return false + } +} + +// Queues message to send to this channel. +// Nonblocking, returns true if successful. +// Goroutine-safe +func (ch *Channel) trySendBytes(bytes []byte) bool { + select { + case ch.sendQueue <- bytes: + atomic.AddInt32(&ch.sendQueueSize, 1) + return true + default: + return false + } +} + +// Goroutine-safe +func (ch *Channel) loadSendQueueSize() (size int) { + return int(atomic.LoadInt32(&ch.sendQueueSize)) +} + +// Goroutine-safe +// Use only as a heuristic. +func (ch *Channel) canSend() bool { + return ch.loadSendQueueSize() < defaultSendQueueCapacity +} + +// Returns true if any msgPackets are pending to be sent. +// Call before calling nextMsgPacket() +// Goroutine-safe +func (ch *Channel) isSendPending() bool { + if len(ch.sending) == 0 { + if len(ch.sendQueue) == 0 { + return false + } + ch.sending = <-ch.sendQueue + } + return true +} + +// Creates a new msgPacket to send. 
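+// The packet carries at most maxMsgPacketPayloadSize bytes of ch.sending and sets EOF only when the message is fully consumed.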
+// Not goroutine-safe +func (ch *Channel) nextMsgPacket() msgPacket { + packet := msgPacket{} + packet.ChannelID = byte(ch.id) + packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))] + if len(ch.sending) <= maxMsgPacketPayloadSize { + packet.EOF = byte(0x01) + ch.sending = nil + atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize + } else { + packet.EOF = byte(0x00) + ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):] + } + return packet +} + +// Writes next msgPacket to w. +// Not goroutine-safe +func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) { + packet := ch.nextMsgPacket() + log.Debug("Write Msg Packet", "conn", ch.conn, "packet", packet) + wire.WriteByte(packetTypeMsg, w, &n, &err) + wire.WriteBinary(packet, w, &n, &err) + if err == nil { + ch.recentlySent += int64(n) + } + return +} + +// Handles incoming msgPackets. Returns a msg bytes if msg is complete. +// Not goroutine-safe +func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) { + // log.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet) + if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) { + return nil, wire.ErrBinaryReadOverflow + } + ch.recving = append(ch.recving, packet.Bytes...) + if packet.EOF == byte(0x01) { + msgBytes := ch.recving + // clear the slice without re-allocating. + // http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go + // suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes, + // at which point the recving slice stops being used and should be garbage collected + ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity) + return msgBytes, nil + } + return nil, nil +} + +// Call this periodically to update stats for throttling purposes. +// Not goroutine-safe +func (ch *Channel) updateStats() { + // Exponential decay of stats. + // TODO: optimize. + ch.recentlySent = int64(float64(ch.recentlySent) * 0.8) +} + +//----------------------------------------------------------------------------- + +const ( + maxMsgPacketPayloadSize = 1024 + maxMsgPacketOverheadSize = 10 // It's actually lower but good enough + maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize + packetTypePing = byte(0x01) + packetTypePong = byte(0x02) + packetTypeMsg = byte(0x03) +) + +// Messages in channels are chopped into smaller msgPackets for multiplexing. +type msgPacket struct { + ChannelID byte + EOF byte // 1 means message ends here. 
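+	// Bytes holds at most maxMsgPacketPayloadSize bytes of the chopped message.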
+ Bytes []byte +} + +func (p msgPacket) String() string { + return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) +} diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 000000000..33d8adfd1 --- /dev/null +++ b/connection_test.go @@ -0,0 +1,139 @@ +package p2p_test + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + p2p "github.com/tendermint/go-p2p" +) + +func createMConnection(conn net.Conn) *p2p.MConnection { + onReceive := func(chID byte, msgBytes []byte) { + } + onError := func(r interface{}) { + } + return createMConnectionWithCallbacks(conn, onReceive, onError) +} + +func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *p2p.MConnection { + chDescs := []*p2p.ChannelDescriptor{&p2p.ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} + return p2p.NewMConnection(conn, chDescs, onReceive, onError) +} + +func TestMConnectionSend(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + mconn := createMConnection(client) + _, err := mconn.Start() + require.Nil(err) + defer mconn.Stop() + + msg := "Ant-Man" + assert.True(mconn.Send(0x01, msg)) + // Note: subsequent Send/TrySend calls could pass because we are reading from + // the send queue in a separate goroutine. + server.Read(make([]byte, len(msg))) + assert.True(mconn.CanSend(0x01)) + + msg = "Spider-Man" + assert.True(mconn.TrySend(0x01, msg)) + server.Read(make([]byte, len(msg))) + + assert.False(mconn.CanSend(0x05), "CanSend should return false because channel is unknown") + assert.False(mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown") +} + +func TestMConnectionReceive(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn1 := createMConnectionWithCallbacks(client, onReceive, onError) + _, err := mconn1.Start() + require.Nil(err) + defer mconn1.Stop() + + mconn2 := createMConnection(server) + _, err = mconn2.Start() + require.Nil(err) + defer mconn2.Stop() + + msg := "Cyclops" + assert.True(mconn2.Send(0x01, msg)) + + select { + case receivedBytes := <-receivedCh: + assert.Equal([]byte(msg), receivedBytes[2:]) // first 3 bytes are internal + case err := <-errorsCh: + t.Fatalf("Expected %s, got %+v", msg, err) + case <-time.After(500 * time.Millisecond): + t.Fatalf("Did not receive %s message in 500ms", msg) + } +} + +func TestMConnectionStatus(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + mconn := createMConnection(client) + _, err := mconn.Start() + require.Nil(err) + defer mconn.Stop() + + status := mconn.Status() + assert.NotNil(status) + assert.Zero(status.Channels[0].SendQueueSize) +} + +func TestMConnectionStopsAndReturnsError(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + 
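+		// No message is expected on this connection; forwarding lets the test
+		// fail loudly if one arrives.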
receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + _, err := mconn.Start() + require.Nil(err) + defer mconn.Stop() + + client.Close() + + select { + case receivedBytes := <-receivedCh: + t.Fatalf("Expected error, got %v", receivedBytes) + case err := <-errorsCh: + assert.NotNil(err) + assert.False(mconn.IsRunning()) + case <-time.After(500 * time.Millisecond): + t.Fatal("Did not receive error in 500ms") + } +} diff --git a/fuzz.go b/fuzz.go new file mode 100644 index 000000000..aefac986a --- /dev/null +++ b/fuzz.go @@ -0,0 +1,173 @@ +package p2p + +import ( + "math/rand" + "net" + "sync" + "time" +) + +const ( + // FuzzModeDrop is a mode in which we randomly drop reads/writes, connections or sleep + FuzzModeDrop = iota + // FuzzModeDelay is a mode in which we randomly sleep + FuzzModeDelay +) + +// FuzzedConnection wraps any net.Conn and depending on the mode either delays +// reads/writes or randomly drops reads/writes/connections. +type FuzzedConnection struct { + conn net.Conn + + mtx sync.Mutex + start <-chan time.Time + active bool + + config *FuzzConnConfig +} + +// FuzzConnConfig is a FuzzedConnection configuration. +type FuzzConnConfig struct { + Mode int + MaxDelay time.Duration + ProbDropRW float64 + ProbDropConn float64 + ProbSleep float64 +} + +// DefaultFuzzConnConfig returns the default config. +func DefaultFuzzConnConfig() *FuzzConnConfig { + return &FuzzConnConfig{ + Mode: FuzzModeDrop, + MaxDelay: 3 * time.Second, + ProbDropRW: 0.2, + ProbDropConn: 0.00, + ProbSleep: 0.00, + } +} + +// FuzzConn creates a new FuzzedConnection. Fuzzing starts immediately. +func FuzzConn(conn net.Conn) net.Conn { + return FuzzConnFromConfig(conn, DefaultFuzzConnConfig()) +} + +// FuzzConnFromConfig creates a new FuzzedConnection from a config. Fuzzing +// starts immediately. +func FuzzConnFromConfig(conn net.Conn, config *FuzzConnConfig) net.Conn { + return &FuzzedConnection{ + conn: conn, + start: make(<-chan time.Time), + active: true, + config: config, + } +} + +// FuzzConnAfter creates a new FuzzedConnection. Fuzzing starts when the +// duration elapses. +func FuzzConnAfter(conn net.Conn, d time.Duration) net.Conn { + return FuzzConnAfterFromConfig(conn, d, DefaultFuzzConnConfig()) +} + +// FuzzConnAfterFromConfig creates a new FuzzedConnection from a config. +// Fuzzing starts when the duration elapses. +func FuzzConnAfterFromConfig(conn net.Conn, d time.Duration, config *FuzzConnConfig) net.Conn { + return &FuzzedConnection{ + conn: conn, + start: time.After(d), + active: false, + config: config, + } +} + +// Config returns the connection's config. +func (fc *FuzzedConnection) Config() *FuzzConnConfig { + return fc.config +} + +// Read implements net.Conn. +func (fc *FuzzedConnection) Read(data []byte) (n int, err error) { + if fc.fuzz() { + return 0, nil + } + return fc.conn.Read(data) +} + +// Write implements net.Conn. +func (fc *FuzzedConnection) Write(data []byte) (n int, err error) { + if fc.fuzz() { + return 0, nil + } + return fc.conn.Write(data) +} + +// Close implements net.Conn. +func (fc *FuzzedConnection) Close() error { return fc.conn.Close() } + +// LocalAddr implements net.Conn. +func (fc *FuzzedConnection) LocalAddr() net.Addr { return fc.conn.LocalAddr() } + +// RemoteAddr implements net.Conn. +func (fc *FuzzedConnection) RemoteAddr() net.Addr { return fc.conn.RemoteAddr() } + +// SetDeadline implements net.Conn. 
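+// Deadlines are passed straight through to the wrapped conn; they are not fuzzed.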
+func (fc *FuzzedConnection) SetDeadline(t time.Time) error { return fc.conn.SetDeadline(t) } + +// SetReadDeadline implements net.Conn. +func (fc *FuzzedConnection) SetReadDeadline(t time.Time) error { + return fc.conn.SetReadDeadline(t) +} + +// SetWriteDeadline implements net.Conn. +func (fc *FuzzedConnection) SetWriteDeadline(t time.Time) error { + return fc.conn.SetWriteDeadline(t) +} + +func (fc *FuzzedConnection) randomDuration() time.Duration { + maxDelayMillis := int(fc.config.MaxDelay.Nanoseconds() / 1000) + return time.Millisecond * time.Duration(rand.Int()%maxDelayMillis) +} + +// implements the fuzz (delay, kill conn) +// and returns whether or not the read/write should be ignored +func (fc *FuzzedConnection) fuzz() bool { + if !fc.shouldFuzz() { + return false + } + + switch fc.config.Mode { + case FuzzModeDrop: + // randomly drop the r/w, drop the conn, or sleep + r := rand.Float64() + if r <= fc.config.ProbDropRW { + return true + } else if r < fc.config.ProbDropRW+fc.config.ProbDropConn { + // XXX: can't this fail because machine precision? + // XXX: do we need an error? + fc.Close() + return true + } else if r < fc.config.ProbDropRW+fc.config.ProbDropConn+fc.config.ProbSleep { + time.Sleep(fc.randomDuration()) + } + case FuzzModeDelay: + // sleep a bit + time.Sleep(fc.randomDuration()) + } + return false +} + +func (fc *FuzzedConnection) shouldFuzz() bool { + if fc.active { + return true + } + + fc.mtx.Lock() + defer fc.mtx.Unlock() + + select { + case <-fc.start: + fc.active = true + return true + default: + return false + } +} diff --git a/glide.lock b/glide.lock new file mode 100644 index 000000000..5da4e6230 --- /dev/null +++ b/glide.lock @@ -0,0 +1,73 @@ +hash: f3d76bef9548cc37ad6038cb55f0812bac7e64735a99995c9da85010eef27f50 +updated: 2017-04-19T00:00:50.949249104-04:00 +imports: +- name: github.com/btcsuite/btcd + version: b8df516b4b267acf2de46be593a9d948d1d2c420 + subpackages: + - btcec +- name: github.com/btcsuite/fastsha256 + version: 637e656429416087660c84436a2a035d69d54e2e +- name: github.com/BurntSushi/toml + version: 99064174e013895bbd9b025c31100bd1d9b590ca +- name: github.com/go-stack/stack + version: 100eb0c0a9c5b306ca2fb4f165df21d80ada4b82 +- name: github.com/mattn/go-colorable + version: 9fdad7c47650b7d2e1da50644c1f4ba7f172f252 +- name: github.com/mattn/go-isatty + version: 56b76bdf51f7708750eac80fa38b952bb9f32639 +- name: github.com/pkg/errors + version: 645ef00459ed84a119197bfb8d8205042c6df63d +- name: github.com/tendermint/ed25519 + version: 1f52c6f8b8a5c7908aff4497c186af344b428925 + subpackages: + - edwards25519 + - extra25519 +- name: github.com/tendermint/tmlibs/common + version: f9e3db037330c8a8d61d3966de8473eaf01154fa +- name: github.com/tendermint/go-config + version: 620dcbbd7d587cf3599dedbf329b64311b0c307a +- name: github.com/tendermint/go-crypto + version: 0ca2c6fdb0706001ca4c4b9b80c9f428e8cf39da +- name: github.com/tendermint/go-wire/data + version: e7fcc6d081ec8518912fcdc103188275f83a3ee5 +- name: github.com/tendermint/tmlibs/flowrate + version: a20c98e61957faa93b4014fbd902f20ab9317a6a + subpackages: + - flowrate +- name: github.com/tendermint/tmlibs/logger + version: cefb3a45c0bf3c493a04e9bcd9b1540528be59f2 +- name: github.com/tendermint/go-wire + version: c1c9a57ab8038448ddea1714c0698f8051e5748c +- name: github.com/tendermint/log15 + version: ae0f3d6450da9eac7074b439c8e1c3cabf0d5ce6 + subpackages: + - term +- name: golang.org/x/crypto + version: 1f22c0103821b9390939b6776727195525381532 + subpackages: + - curve25519 + - nacl/box + - 
nacl/secretbox + - openpgp/armor + - openpgp/errors + - poly1305 + - ripemd160 + - salsa20/salsa +- name: golang.org/x/sys + version: 50c6bc5e4292a1d4e65c6e9be5f53be28bcbe28e + subpackages: + - unix +testImports: +- name: github.com/davecgh/go-spew + version: 6d212800a42e8ab5c146b8ace3490ee17e5225f9 + subpackages: + - spew +- name: github.com/pmezard/go-difflib + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d + subpackages: + - difflib +- name: github.com/stretchr/testify + version: 69483b4bd14f5845b5a1e55bca19e954e827f1d0 + subpackages: + - assert + - require diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 000000000..5bf7a015e --- /dev/null +++ b/glide.yaml @@ -0,0 +1,29 @@ +package: github.com/tendermint/go-p2p +import: +- package: github.com/tendermint/tmlibs/common + version: develop +- package: github.com/tendermint/go-config + version: develop +- package: github.com/tendermint/go-crypto + version: develop +- package: github.com/tendermint/go-wire/data + version: develop +- package: github.com/tendermint/tmlibs/flowrate + subpackages: + - flowrate +- package: github.com/tendermint/tmlibs/logger + version: develop +- package: github.com/tendermint/go-wire + version: develop +- package: github.com/tendermint/log15 +- package: golang.org/x/crypto + subpackages: + - nacl/box + - nacl/secretbox + - ripemd160 +- package: github.com/pkg/errors +testImport: +- package: github.com/stretchr/testify + subpackages: + - assert + - require diff --git a/ip_range_counter.go b/ip_range_counter.go new file mode 100644 index 000000000..85d9d407a --- /dev/null +++ b/ip_range_counter.go @@ -0,0 +1,29 @@ +package p2p + +import ( + "strings" +) + +// TODO Test +func AddToIPRangeCounts(counts map[string]int, ip string) map[string]int { + changes := make(map[string]int) + ipParts := strings.Split(ip, ":") + for i := 1; i < len(ipParts); i++ { + prefix := strings.Join(ipParts[:i], ":") + counts[prefix] += 1 + changes[prefix] = counts[prefix] + } + return changes +} + +// TODO Test +func CheckIPRangeCounts(counts map[string]int, limits []int) bool { + for prefix, count := range counts { + ipParts := strings.Split(prefix, ":") + numParts := len(ipParts) + if limits[numParts] < count { + return false + } + } + return true +} diff --git a/listener.go b/listener.go new file mode 100644 index 000000000..51beb5e27 --- /dev/null +++ b/listener.go @@ -0,0 +1,217 @@ +package p2p + +import ( + "fmt" + "net" + "strconv" + "time" + + . 
"github.com/tendermint/tmlibs/common" + "github.com/tendermint/go-p2p/upnp" +) + +type Listener interface { + Connections() <-chan net.Conn + InternalAddress() *NetAddress + ExternalAddress() *NetAddress + String() string + Stop() bool +} + +// Implements Listener +type DefaultListener struct { + BaseService + + listener net.Listener + intAddr *NetAddress + extAddr *NetAddress + connections chan net.Conn +} + +const ( + numBufferedConnections = 10 + defaultExternalPort = 8770 + tryListenSeconds = 5 +) + +func splitHostPort(addr string) (host string, port int) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + PanicSanity(err) + } + port, err = strconv.Atoi(portStr) + if err != nil { + PanicSanity(err) + } + return host, port +} + +// skipUPNP: If true, does not try getUPNPExternalAddress() +func NewDefaultListener(protocol string, lAddr string, skipUPNP bool) Listener { + // Local listen IP & port + lAddrIP, lAddrPort := splitHostPort(lAddr) + + // Create listener + var listener net.Listener + var err error + for i := 0; i < tryListenSeconds; i++ { + listener, err = net.Listen(protocol, lAddr) + if err == nil { + break + } else if i < tryListenSeconds-1 { + time.Sleep(time.Second * 1) + } + } + if err != nil { + PanicCrisis(err) + } + // Actual listener local IP & port + listenerIP, listenerPort := splitHostPort(listener.Addr().String()) + log.Info("Local listener", "ip", listenerIP, "port", listenerPort) + + // Determine internal address... + var intAddr *NetAddress + intAddr, err = NewNetAddressString(lAddr) + if err != nil { + PanicCrisis(err) + } + + // Determine external address... + var extAddr *NetAddress + if !skipUPNP { + // If the lAddrIP is INADDR_ANY, try UPnP + if lAddrIP == "" || lAddrIP == "0.0.0.0" { + extAddr = getUPNPExternalAddress(lAddrPort, listenerPort) + } + } + // Otherwise just use the local address... + if extAddr == nil { + extAddr = getNaiveExternalAddress(listenerPort) + } + if extAddr == nil { + PanicCrisis("Could not determine external address!") + } + + dl := &DefaultListener{ + listener: listener, + intAddr: intAddr, + extAddr: extAddr, + connections: make(chan net.Conn, numBufferedConnections), + } + dl.BaseService = *NewBaseService(log, "DefaultListener", dl) + dl.Start() // Started upon construction + return dl +} + +func (l *DefaultListener) OnStart() error { + l.BaseService.OnStart() + go l.listenRoutine() + return nil +} + +func (l *DefaultListener) OnStop() { + l.BaseService.OnStop() + l.listener.Close() +} + +// Accept connections and pass on the channel +func (l *DefaultListener) listenRoutine() { + for { + conn, err := l.listener.Accept() + + if !l.IsRunning() { + break // Go to cleanup + } + + // listener wasn't stopped, + // yet we encountered an error. + if err != nil { + PanicCrisis(err) + } + + l.connections <- conn + } + + // Cleanup + close(l.connections) + for _ = range l.connections { + // Drain + } +} + +// A channel of inbound connections. +// It gets closed when the listener closes. +func (l *DefaultListener) Connections() <-chan net.Conn { + return l.connections +} + +func (l *DefaultListener) InternalAddress() *NetAddress { + return l.intAddr +} + +func (l *DefaultListener) ExternalAddress() *NetAddress { + return l.extAddr +} + +// NOTE: The returned listener is already Accept()'ing. +// So it's not suitable to pass into http.Serve(). 
+func (l *DefaultListener) NetListener() net.Listener { + return l.listener +} + +func (l *DefaultListener) String() string { + return fmt.Sprintf("Listener(@%v)", l.extAddr) +} + +/* external address helpers */ + +// UPNP external address discovery & port mapping +func getUPNPExternalAddress(externalPort, internalPort int) *NetAddress { + log.Info("Getting UPNP external address") + nat, err := upnp.Discover() + if err != nil { + log.Info("Could not perform UPNP discover", "error", err) + return nil + } + + ext, err := nat.GetExternalAddress() + if err != nil { + log.Info("Could not get UPNP external address", "error", err) + return nil + } + + // UPnP can't seem to get the external port, so let's just be explicit. + if externalPort == 0 { + externalPort = defaultExternalPort + } + + externalPort, err = nat.AddPortMapping("tcp", externalPort, internalPort, "tendermint", 0) + if err != nil { + log.Info("Could not add UPNP port mapping", "error", err) + return nil + } + + log.Info("Got UPNP external address", "address", ext) + return NewNetAddressIPPort(ext, uint16(externalPort)) +} + +// TODO: use syscalls: http://pastebin.com/9exZG4rh +func getNaiveExternalAddress(port int) *NetAddress { + addrs, err := net.InterfaceAddrs() + if err != nil { + PanicCrisis(Fmt("Could not fetch interface addresses: %v", err)) + } + + for _, a := range addrs { + ipnet, ok := a.(*net.IPNet) + if !ok { + continue + } + v4 := ipnet.IP.To4() + if v4 == nil || v4[0] == 127 { + continue + } // loopback + return NewNetAddressIPPort(ipnet.IP, uint16(port)) + } + return nil +} diff --git a/listener_test.go b/listener_test.go new file mode 100644 index 000000000..0f8a54946 --- /dev/null +++ b/listener_test.go @@ -0,0 +1,40 @@ +package p2p + +import ( + "bytes" + "testing" +) + +func TestListener(t *testing.T) { + // Create a listener + l := NewDefaultListener("tcp", ":8001", true) + + // Dial the listener + lAddr := l.ExternalAddress() + connOut, err := lAddr.Dial() + if err != nil { + t.Fatalf("Could not connect to listener address %v", lAddr) + } else { + t.Logf("Created a connection to listener address %v", lAddr) + } + connIn, ok := <-l.Connections() + if !ok { + t.Fatalf("Could not get inbound connection from listener") + } + + msg := []byte("hi!") + go connIn.Write(msg) + b := make([]byte, 32) + n, err := connOut.Read(b) + if err != nil { + t.Fatalf("Error reading off connection: %v", err) + } + + b = b[:n] + if !bytes.Equal(msg, b) { + t.Fatalf("Got %s, expected %s", b, msg) + } + + // Close the server, no longer needed. + l.Stop() +} diff --git a/log.go b/log.go new file mode 100644 index 000000000..af3203409 --- /dev/null +++ b/log.go @@ -0,0 +1,7 @@ +package p2p + +import ( + "github.com/tendermint/tmlibs/logger" +) + +var log = logger.New("module", "p2p") diff --git a/netaddress.go b/netaddress.go new file mode 100644 index 000000000..09787481c --- /dev/null +++ b/netaddress.go @@ -0,0 +1,253 @@ +// Modified for Tendermint +// Originally Copyright (c) 2013-2014 Conformal Systems LLC. +// https://github.com/conformal/btcd/blob/master/LICENSE + +package p2p + +import ( + "errors" + "flag" + "net" + "strconv" + "time" + + cmn "github.com/tendermint/tmlibs/common" +) + +// NetAddress defines information about a peer on the network +// including its IP address, and port. +type NetAddress struct { + IP net.IP + Port uint16 + str string +} + +// NewNetAddress returns a new NetAddress using the provided TCP +// address. When testing, other net.Addr (except TCP) will result in +// using 0.0.0.0:0. 
When normal run, other net.Addr (except TCP) will +// panic. +// TODO: socks proxies? +func NewNetAddress(addr net.Addr) *NetAddress { + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + if flag.Lookup("test.v") == nil { // normal run + cmn.PanicSanity(cmn.Fmt("Only TCPAddrs are supported. Got: %v", addr)) + } else { // in testing + return NewNetAddressIPPort(net.IP("0.0.0.0"), 0) + } + } + ip := tcpAddr.IP + port := uint16(tcpAddr.Port) + return NewNetAddressIPPort(ip, port) +} + +// NewNetAddressString returns a new NetAddress using the provided +// address in the form of "IP:Port". Also resolves the host if host +// is not an IP. +func NewNetAddressString(addr string) (*NetAddress, error) { + + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + ip := net.ParseIP(host) + if ip == nil { + if len(host) > 0 { + ips, err := net.LookupIP(host) + if err != nil { + return nil, err + } + ip = ips[0] + } + } + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, err + } + + na := NewNetAddressIPPort(ip, uint16(port)) + return na, nil +} + +// NewNetAddressStrings returns an array of NetAddress'es build using +// the provided strings. +func NewNetAddressStrings(addrs []string) ([]*NetAddress, error) { + netAddrs := make([]*NetAddress, len(addrs)) + for i, addr := range addrs { + netAddr, err := NewNetAddressString(addr) + if err != nil { + return nil, errors.New(cmn.Fmt("Error in address %s: %v", addr, err)) + } + netAddrs[i] = netAddr + } + return netAddrs, nil +} + +// NewNetAddressIPPort returns a new NetAddress using the provided IP +// and port number. +func NewNetAddressIPPort(ip net.IP, port uint16) *NetAddress { + na := &NetAddress{ + IP: ip, + Port: port, + str: net.JoinHostPort( + ip.String(), + strconv.FormatUint(uint64(port), 10), + ), + } + return na +} + +// Equals reports whether na and other are the same addresses. +func (na *NetAddress) Equals(other interface{}) bool { + if o, ok := other.(*NetAddress); ok { + return na.String() == o.String() + } + + return false +} + +func (na *NetAddress) Less(other interface{}) bool { + if o, ok := other.(*NetAddress); ok { + return na.String() < o.String() + } + + cmn.PanicSanity("Cannot compare unequal types") + return false +} + +// String representation. +func (na *NetAddress) String() string { + if na.str == "" { + na.str = net.JoinHostPort( + na.IP.String(), + strconv.FormatUint(uint64(na.Port), 10), + ) + } + return na.str +} + +// Dial calls net.Dial on the address. +func (na *NetAddress) Dial() (net.Conn, error) { + conn, err := net.Dial("tcp", na.String()) + if err != nil { + return nil, err + } + return conn, nil +} + +// DialTimeout calls net.DialTimeout on the address. +func (na *NetAddress) DialTimeout(timeout time.Duration) (net.Conn, error) { + conn, err := net.DialTimeout("tcp", na.String(), timeout) + if err != nil { + return nil, err + } + return conn, nil +} + +// Routable returns true if the address is routable. +func (na *NetAddress) Routable() bool { + // TODO(oga) bitcoind doesn't include RFC3849 here, but should we? + return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() || + na.RFC4193() || na.RFC4843() || na.Local()) +} + +// For IPv4 these are either a 0 or all bits set address. For IPv6 a zero +// address or one that matches the RFC3849 documentation address format. 
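+// Valid reports whether the address is well formed and none of the above.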
+func (na *NetAddress) Valid() bool { + return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() || + na.IP.Equal(net.IPv4bcast)) +} + +// Local returns true if it is a local address. +func (na *NetAddress) Local() bool { + return na.IP.IsLoopback() || zero4.Contains(na.IP) +} + +// ReachabilityTo checks whenever o can be reached from na. +func (na *NetAddress) ReachabilityTo(o *NetAddress) int { + const ( + Unreachable = 0 + Default = iota + Teredo + Ipv6_weak + Ipv4 + Ipv6_strong + Private + ) + if !na.Routable() { + return Unreachable + } else if na.RFC4380() { + if !o.Routable() { + return Default + } else if o.RFC4380() { + return Teredo + } else if o.IP.To4() != nil { + return Ipv4 + } else { // ipv6 + return Ipv6_weak + } + } else if na.IP.To4() != nil { + if o.Routable() && o.IP.To4() != nil { + return Ipv4 + } + return Default + } else /* ipv6 */ { + var tunnelled bool + // Is our v6 is tunnelled? + if o.RFC3964() || o.RFC6052() || o.RFC6145() { + tunnelled = true + } + if !o.Routable() { + return Default + } else if o.RFC4380() { + return Teredo + } else if o.IP.To4() != nil { + return Ipv4 + } else if tunnelled { + // only prioritise ipv6 if we aren't tunnelling it. + return Ipv6_weak + } + return Ipv6_strong + } +} + +// RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12) +// RFC3849: IPv6 Documentation address (2001:0DB8::/32) +// RFC3927: IPv4 Autoconfig (169.254.0.0/16) +// RFC3964: IPv6 6to4 (2002::/16) +// RFC4193: IPv6 unique local (FC00::/7) +// RFC4380: IPv6 Teredo tunneling (2001::/32) +// RFC4843: IPv6 ORCHID: (2001:10::/28) +// RFC4862: IPv6 Autoconfig (FE80::/64) +// RFC6052: IPv6 well known prefix (64:FF9B::/96) +// RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96 +var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)} +var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)} +var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)} +var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)} +var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)} +var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)} +var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)} +var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)} +var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)} +var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)} +var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)} +var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)} +var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)} + +func (na *NetAddress) RFC1918() bool { + return rfc1918_10.Contains(na.IP) || + rfc1918_192.Contains(na.IP) || + rfc1918_172.Contains(na.IP) +} +func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) } +func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) } +func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) } +func (na *NetAddress) RFC4193() bool { return rfc4193.Contains(na.IP) } +func (na *NetAddress) RFC4380() bool { return rfc4380.Contains(na.IP) } +func (na *NetAddress) RFC4843() bool { return rfc4843.Contains(na.IP) } +func (na *NetAddress) RFC4862() bool { return rfc4862.Contains(na.IP) } +func (na *NetAddress) RFC6052() bool { return 
rfc6052.Contains(na.IP) } +func (na *NetAddress) RFC6145() bool { return rfc6145.Contains(na.IP) } diff --git a/netaddress_test.go b/netaddress_test.go new file mode 100644 index 000000000..db871fdec --- /dev/null +++ b/netaddress_test.go @@ -0,0 +1,113 @@ +package p2p + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewNetAddress(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") + require.Nil(err) + addr := NewNetAddress(tcpAddr) + + assert.Equal("127.0.0.1:8080", addr.String()) + + assert.NotPanics(func() { + NewNetAddress(&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8000}) + }, "Calling NewNetAddress with UDPAddr should not panic in testing") +} + +func TestNewNetAddressString(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + tests := []struct { + addr string + correct bool + }{ + {"127.0.0.1:8080", true}, + {"127.0.0:8080", false}, + {"a", false}, + {"127.0.0.1:a", false}, + {"a:8080", false}, + {"8082", false}, + {"127.0.0:8080000", false}, + } + + for _, t := range tests { + addr, err := NewNetAddressString(t.addr) + if t.correct { + require.Nil(err) + assert.Equal(t.addr, addr.String()) + } else { + require.NotNil(err) + } + } +} + +func TestNewNetAddressStrings(t *testing.T) { + assert, require := assert.New(t), require.New(t) + addrs, err := NewNetAddressStrings([]string{"127.0.0.1:8080", "127.0.0.2:8080"}) + require.Nil(err) + + assert.Equal(2, len(addrs)) +} + +func TestNewNetAddressIPPort(t *testing.T) { + assert := assert.New(t) + addr := NewNetAddressIPPort(net.ParseIP("127.0.0.1"), 8080) + + assert.Equal("127.0.0.1:8080", addr.String()) +} + +func TestNetAddressProperties(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + // TODO add more test cases + tests := []struct { + addr string + valid bool + local bool + routable bool + }{ + {"127.0.0.1:8080", true, true, false}, + {"ya.ru:80", true, false, true}, + } + + for _, t := range tests { + addr, err := NewNetAddressString(t.addr) + require.Nil(err) + + assert.Equal(t.valid, addr.Valid()) + assert.Equal(t.local, addr.Local()) + assert.Equal(t.routable, addr.Routable()) + } +} + +func TestNetAddressReachabilityTo(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + // TODO add more test cases + tests := []struct { + addr string + other string + reachability int + }{ + {"127.0.0.1:8080", "127.0.0.1:8081", 0}, + {"ya.ru:80", "127.0.0.1:8080", 1}, + } + + for _, t := range tests { + addr, err := NewNetAddressString(t.addr) + require.Nil(err) + + other, err := NewNetAddressString(t.other) + require.Nil(err) + + assert.Equal(t.reachability, addr.ReachabilityTo(other)) + } +} diff --git a/peer.go b/peer.go new file mode 100644 index 000000000..bf819087d --- /dev/null +++ b/peer.go @@ -0,0 +1,304 @@ +package p2p + +import ( + "fmt" + "io" + "net" + "time" + + "github.com/pkg/errors" + crypto "github.com/tendermint/go-crypto" + wire "github.com/tendermint/go-wire" + cmn "github.com/tendermint/tmlibs/common" +) + +// Peer could be marked as persistent, in which case you can use +// Redial function to reconnect. Note that inbound peers can't be +// made persistent. They should be made persistent on the other end. +// +// Before using a peer, you will need to perform a handshake on connection. 
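+//
+// Rough outbound lifecycle (a sketch; error handling elided, see peer_test.go
+// for a complete example):
+//
+//	p, err := newOutboundPeer(addr, reactorsByCh, chDescs, onPeerError, ourNodePrivKey)
+//	err = p.HandshakeTimeout(ourNodeInfo, 2*time.Second)
+//	p.Start()
+//	defer p.Stop()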
+type Peer struct { + cmn.BaseService + + outbound bool + + conn net.Conn // source connection + mconn *MConnection // multiplex connection + + persistent bool + config *PeerConfig + + *NodeInfo + Key string + Data *cmn.CMap // User data. +} + +// PeerConfig is a Peer configuration. +type PeerConfig struct { + AuthEnc bool // authenticated encryption + + HandshakeTimeout time.Duration + DialTimeout time.Duration + + MConfig *MConnConfig + + Fuzz bool // fuzz connection (for testing) + FuzzConfig *FuzzConnConfig +} + +// DefaultPeerConfig returns the default config. +func DefaultPeerConfig() *PeerConfig { + return &PeerConfig{ + AuthEnc: true, + HandshakeTimeout: 2 * time.Second, + DialTimeout: 3 * time.Second, + MConfig: DefaultMConnConfig(), + Fuzz: false, + FuzzConfig: DefaultFuzzConnConfig(), + } +} + +func newOutboundPeer(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) { + return newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig()) +} + +func newOutboundPeerWithConfig(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { + conn, err := dial(addr, config) + if err != nil { + return nil, errors.Wrap(err, "Error creating peer") + } + + peer, err := newPeerFromConnAndConfig(conn, true, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config) + if err != nil { + conn.Close() + return nil, err + } + return peer, nil +} + +func newInboundPeer(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) { + return newInboundPeerWithConfig(conn, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig()) +} + +func newInboundPeerWithConfig(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { + return newPeerFromConnAndConfig(conn, false, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config) +} + +func newPeerFromConnAndConfig(rawConn net.Conn, outbound bool, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { + conn := rawConn + + // Fuzz connection + if config.Fuzz { + // so we have time to do peer handshakes and get set up + conn = FuzzConnAfterFromConfig(conn, 10*time.Second, config.FuzzConfig) + } + + // Encrypt connection + if config.AuthEnc { + conn.SetDeadline(time.Now().Add(config.HandshakeTimeout)) + + var err error + conn, err = MakeSecretConnection(conn, ourNodePrivKey) + if err != nil { + return nil, errors.Wrap(err, "Error creating peer") + } + } + + // Key and NodeInfo are set after Handshake + p := &Peer{ + outbound: outbound, + conn: conn, + config: config, + Data: cmn.NewCMap(), + } + + p.mconn = createMConnection(conn, p, reactorsByCh, chDescs, onPeerError, config.MConfig) + + p.BaseService = *cmn.NewBaseService(log, "Peer", p) + + return p, nil +} + +// CloseConn should be used when the peer was created, but never started. +func (p *Peer) CloseConn() { + p.conn.Close() +} + +// makePersistent marks the peer as persistent. 
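+// It panics if the peer is inbound.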
+func (p *Peer) makePersistent() { + if !p.outbound { + panic("inbound peers can't be made persistent") + } + + p.persistent = true +} + +// IsPersistent returns true if the peer is persitent, false otherwise. +func (p *Peer) IsPersistent() bool { + return p.persistent +} + +// HandshakeTimeout performs a handshake between a given node and the peer. +// NOTE: blocking +func (p *Peer) HandshakeTimeout(ourNodeInfo *NodeInfo, timeout time.Duration) error { + // Set deadline for handshake so we don't block forever on conn.ReadFull + p.conn.SetDeadline(time.Now().Add(timeout)) + + var peerNodeInfo = new(NodeInfo) + var err1 error + var err2 error + cmn.Parallel( + func() { + var n int + wire.WriteBinary(ourNodeInfo, p.conn, &n, &err1) + }, + func() { + var n int + wire.ReadBinary(peerNodeInfo, p.conn, maxNodeInfoSize, &n, &err2) + log.Notice("Peer handshake", "peerNodeInfo", peerNodeInfo) + }) + if err1 != nil { + return errors.Wrap(err1, "Error during handshake/write") + } + if err2 != nil { + return errors.Wrap(err2, "Error during handshake/read") + } + + if p.config.AuthEnc { + // Check that the professed PubKey matches the sconn's. + if !peerNodeInfo.PubKey.Equals(p.PubKey().Wrap()) { + return fmt.Errorf("Ignoring connection with unmatching pubkey: %v vs %v", + peerNodeInfo.PubKey, p.PubKey()) + } + } + + // Remove deadline + p.conn.SetDeadline(time.Time{}) + + peerNodeInfo.RemoteAddr = p.Addr().String() + + p.NodeInfo = peerNodeInfo + p.Key = peerNodeInfo.PubKey.KeyString() + + return nil +} + +// Addr returns peer's network address. +func (p *Peer) Addr() net.Addr { + return p.conn.RemoteAddr() +} + +// PubKey returns peer's public key. +func (p *Peer) PubKey() crypto.PubKeyEd25519 { + if p.config.AuthEnc { + return p.conn.(*SecretConnection).RemotePubKey() + } + if p.NodeInfo == nil { + panic("Attempt to get peer's PubKey before calling Handshake") + } + return p.PubKey() +} + +// OnStart implements BaseService. +func (p *Peer) OnStart() error { + p.BaseService.OnStart() + _, err := p.mconn.Start() + return err +} + +// OnStop implements BaseService. +func (p *Peer) OnStop() { + p.BaseService.OnStop() + p.mconn.Stop() +} + +// Connection returns underlying MConnection. +func (p *Peer) Connection() *MConnection { + return p.mconn +} + +// IsOutbound returns true if the connection is outbound, false otherwise. +func (p *Peer) IsOutbound() bool { + return p.outbound +} + +// Send msg to the channel identified by chID byte. Returns false if the send +// queue is full after timeout, specified by MConnection. +func (p *Peer) Send(chID byte, msg interface{}) bool { + if !p.IsRunning() { + // see Switch#Broadcast, where we fetch the list of peers and loop over + // them - while we're looping, one peer may be removed and stopped. + return false + } + return p.mconn.Send(chID, msg) +} + +// TrySend msg to the channel identified by chID byte. Immediately returns +// false if the send queue is full. +func (p *Peer) TrySend(chID byte, msg interface{}) bool { + if !p.IsRunning() { + return false + } + return p.mconn.TrySend(chID, msg) +} + +// CanSend returns true if the send queue is not full, false otherwise. +func (p *Peer) CanSend(chID byte) bool { + if !p.IsRunning() { + return false + } + return p.mconn.CanSend(chID) +} + +// WriteTo writes the peer's public key to w. +func (p *Peer) WriteTo(w io.Writer) (n int64, err error) { + var n_ int + wire.WriteString(p.Key, w, &n_, &err) + n += int64(n_) + return +} + +// String representation. 
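+// Shows the underlying MConnection, the first 12 characters of the key and
+// whether the connection is outbound ("out") or inbound ("in").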
+func (p *Peer) String() string { + if p.outbound { + return fmt.Sprintf("Peer{%v %v out}", p.mconn, p.Key[:12]) + } + + return fmt.Sprintf("Peer{%v %v in}", p.mconn, p.Key[:12]) +} + +// Equals reports whenever 2 peers are actually represent the same node. +func (p *Peer) Equals(other *Peer) bool { + return p.Key == other.Key +} + +// Get the data for a given key. +func (p *Peer) Get(key string) interface{} { + return p.Data.Get(key) +} + +func dial(addr *NetAddress, config *PeerConfig) (net.Conn, error) { + log.Info("Dialing address", "address", addr) + conn, err := addr.DialTimeout(config.DialTimeout) + if err != nil { + log.Info("Failed dialing address", "address", addr, "error", err) + return nil, err + } + return conn, nil +} + +func createMConnection(conn net.Conn, p *Peer, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), config *MConnConfig) *MConnection { + onReceive := func(chID byte, msgBytes []byte) { + reactor := reactorsByCh[chID] + if reactor == nil { + cmn.PanicSanity(cmn.Fmt("Unknown channel %X", chID)) + } + reactor.Receive(chID, p, msgBytes) + } + + onError := func(r interface{}) { + onPeerError(p, r) + } + + return NewMConnectionWithConfig(conn, chDescs, onReceive, onError, config) +} diff --git a/peer_set.go b/peer_set.go new file mode 100644 index 000000000..f3bc1edaf --- /dev/null +++ b/peer_set.go @@ -0,0 +1,115 @@ +package p2p + +import ( + "sync" +) + +// IPeerSet has a (immutable) subset of the methods of PeerSet. +type IPeerSet interface { + Has(key string) bool + Get(key string) *Peer + List() []*Peer + Size() int +} + +//----------------------------------------------------------------------------- + +// PeerSet is a special structure for keeping a table of peers. +// Iteration over the peers is super fast and thread-safe. +// We also track how many peers per IP range and avoid too many +type PeerSet struct { + mtx sync.Mutex + lookup map[string]*peerSetItem + list []*Peer +} + +type peerSetItem struct { + peer *Peer + index int +} + +func NewPeerSet() *PeerSet { + return &PeerSet{ + lookup: make(map[string]*peerSetItem), + list: make([]*Peer, 0, 256), + } +} + +// Returns false if peer with key (PubKeyEd25519) is already in set +// or if we have too many peers from the peer's IP range +func (ps *PeerSet) Add(peer *Peer) error { + ps.mtx.Lock() + defer ps.mtx.Unlock() + if ps.lookup[peer.Key] != nil { + return ErrSwitchDuplicatePeer + } + + index := len(ps.list) + // Appending is safe even with other goroutines + // iterating over the ps.list slice. + ps.list = append(ps.list, peer) + ps.lookup[peer.Key] = &peerSetItem{peer, index} + return nil +} + +func (ps *PeerSet) Has(peerKey string) bool { + ps.mtx.Lock() + defer ps.mtx.Unlock() + _, ok := ps.lookup[peerKey] + return ok +} + +func (ps *PeerSet) Get(peerKey string) *Peer { + ps.mtx.Lock() + defer ps.mtx.Unlock() + item, ok := ps.lookup[peerKey] + if ok { + return item.peer + } else { + return nil + } +} + +func (ps *PeerSet) Remove(peer *Peer) { + ps.mtx.Lock() + defer ps.mtx.Unlock() + item := ps.lookup[peer.Key] + if item == nil { + return + } + + index := item.index + // Copy the list but without the last element. + // (we must copy because we're mutating the list) + newList := make([]*Peer, len(ps.list)-1) + copy(newList, ps.list) + // If it's the last peer, that's an easy special case. + if index == len(ps.list)-1 { + ps.list = newList + delete(ps.lookup, peer.Key) + return + } + + // Move the last item from ps.list to "index" in list. 
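+	// Swap the last peer into the vacated slot and fix its index in the
+	// lookup map, keeping removal O(1).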
+ lastPeer := ps.list[len(ps.list)-1] + lastPeerKey := lastPeer.Key + lastPeerItem := ps.lookup[lastPeerKey] + newList[index] = lastPeer + lastPeerItem.index = index + ps.list = newList + delete(ps.lookup, peer.Key) + +} + +func (ps *PeerSet) Size() int { + ps.mtx.Lock() + defer ps.mtx.Unlock() + return len(ps.list) +} + +// threadsafe list of peers. +func (ps *PeerSet) List() []*Peer { + ps.mtx.Lock() + defer ps.mtx.Unlock() + return ps.list +} diff --git a/peer_set_test.go b/peer_set_test.go new file mode 100644 index 000000000..a17f9d658 --- /dev/null +++ b/peer_set_test.go @@ -0,0 +1,67 @@ +package p2p + +import ( + "math/rand" + "testing" + + . "github.com/tendermint/tmlibs/common" +) + +// Returns an empty dummy peer +func randPeer() *Peer { + return &Peer{ + Key: RandStr(12), + NodeInfo: &NodeInfo{ + RemoteAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256), + ListenAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256), + }, + } +} + +func TestAddRemoveOne(t *testing.T) { + peerSet := NewPeerSet() + + peer := randPeer() + err := peerSet.Add(peer) + if err != nil { + t.Errorf("Failed to add new peer") + } + if peerSet.Size() != 1 { + t.Errorf("Failed to add new peer and increment size") + } + + peerSet.Remove(peer) + if peerSet.Has(peer.Key) { + t.Errorf("Failed to remove peer") + } + if peerSet.Size() != 0 { + t.Errorf("Failed to remove peer and decrement size") + } +} + +func TestAddRemoveMany(t *testing.T) { + peerSet := NewPeerSet() + + peers := []*Peer{} + N := 100 + for i := 0; i < N; i++ { + peer := randPeer() + if err := peerSet.Add(peer); err != nil { + t.Errorf("Failed to add new peer") + } + if peerSet.Size() != i+1 { + t.Errorf("Failed to add new peer and increment size") + } + peers = append(peers, peer) + } + + for i, peer := range peers { + peerSet.Remove(peer) + if peerSet.Has(peer.Key) { + t.Errorf("Failed to remove peer") + } + if peerSet.Size() != len(peers)-i-1 { + t.Errorf("Failed to remove peer and decrement size") + } + } +} diff --git a/peer_test.go b/peer_test.go new file mode 100644 index 000000000..0ac776347 --- /dev/null +++ b/peer_test.go @@ -0,0 +1,156 @@ +package p2p + +import ( + golog "log" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + crypto "github.com/tendermint/go-crypto" +) + +func TestPeerBasic(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + // simulate remote peer + rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} + rp.Start() + defer rp.Stop() + + p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), DefaultPeerConfig()) + require.Nil(err) + + p.Start() + defer p.Stop() + + assert.True(p.IsRunning()) + assert.True(p.IsOutbound()) + assert.False(p.IsPersistent()) + p.makePersistent() + assert.True(p.IsPersistent()) + assert.Equal(rp.Addr().String(), p.Addr().String()) + assert.Equal(rp.PubKey(), p.PubKey()) +} + +func TestPeerWithoutAuthEnc(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + config := DefaultPeerConfig() + config.AuthEnc = false + + // simulate remote peer + rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config} + rp.Start() + defer rp.Stop() + + p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config) + require.Nil(err) + + p.Start() + defer p.Stop() + + assert.True(p.IsRunning()) +} + +func TestPeerSend(t *testing.T) { + assert, require := assert.New(t), require.New(t) + 
+ config := DefaultPeerConfig() + config.AuthEnc = false + + // simulate remote peer + rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config} + rp.Start() + defer rp.Stop() + + p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config) + require.Nil(err) + + p.Start() + defer p.Stop() + + assert.True(p.CanSend(0x01)) + assert.True(p.Send(0x01, "Asylum")) +} + +func createOutboundPeerAndPerformHandshake(addr *NetAddress, config *PeerConfig) (*Peer, error) { + chDescs := []*ChannelDescriptor{ + &ChannelDescriptor{ID: 0x01, Priority: 1}, + } + reactorsByCh := map[byte]Reactor{0x01: NewTestReactor(chDescs, true)} + pk := crypto.GenPrivKeyEd25519() + p, err := newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, func(p *Peer, r interface{}) {}, pk, config) + if err != nil { + return nil, err + } + err = p.HandshakeTimeout(&NodeInfo{ + PubKey: pk.PubKey().Unwrap().(crypto.PubKeyEd25519), + Moniker: "host_peer", + Network: "testing", + Version: "123.123.123", + }, 1*time.Second) + if err != nil { + return nil, err + } + return p, nil +} + +type remotePeer struct { + PrivKey crypto.PrivKeyEd25519 + Config *PeerConfig + addr *NetAddress + quit chan struct{} +} + +func (p *remotePeer) Addr() *NetAddress { + return p.addr +} + +func (p *remotePeer) PubKey() crypto.PubKeyEd25519 { + return p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) +} + +func (p *remotePeer) Start() { + l, e := net.Listen("tcp", "127.0.0.1:0") // any available address + if e != nil { + golog.Fatalf("net.Listen tcp :0: %+v", e) + } + p.addr = NewNetAddress(l.Addr()) + p.quit = make(chan struct{}) + go p.accept(l) +} + +func (p *remotePeer) Stop() { + close(p.quit) +} + +func (p *remotePeer) accept(l net.Listener) { + for { + conn, err := l.Accept() + if err != nil { + golog.Fatalf("Failed to accept conn: %+v", err) + } + peer, err := newInboundPeerWithConfig(conn, make(map[byte]Reactor), make([]*ChannelDescriptor, 0), func(p *Peer, r interface{}) {}, p.PrivKey, p.Config) + if err != nil { + golog.Fatalf("Failed to create a peer: %+v", err) + } + err = peer.HandshakeTimeout(&NodeInfo{ + PubKey: p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519), + Moniker: "remote_peer", + Network: "testing", + Version: "123.123.123", + }, 1*time.Second) + if err != nil { + golog.Fatalf("Failed to perform handshake: %+v", err) + } + select { + case <-p.quit: + conn.Close() + return + default: + } + } +} diff --git a/pex_reactor.go b/pex_reactor.go new file mode 100644 index 000000000..03a383c85 --- /dev/null +++ b/pex_reactor.go @@ -0,0 +1,358 @@ +package p2p + +import ( + "bytes" + "fmt" + "math/rand" + "reflect" + "time" + + cmn "github.com/tendermint/tmlibs/common" + wire "github.com/tendermint/go-wire" +) + +const ( + // PexChannel is a channel for PEX messages + PexChannel = byte(0x00) + + // period to ensure peers connected + defaultEnsurePeersPeriod = 30 * time.Second + minNumOutboundPeers = 10 + maxPexMessageSize = 1048576 // 1MB + + // maximum messages one peer can send to us during `msgCountByPeerFlushInterval` + defaultMaxMsgCountByPeer = 1000 + msgCountByPeerFlushInterval = 1 * time.Hour +) + +// PEXReactor handles PEX (peer exchange) and ensures that an +// adequate number of peers are connected to the switch. +// +// It uses `AddrBook` (address book) to store `NetAddress`es of the peers. +// +// ## Preventing abuse +// +// For now, it just limits the number of messages from one peer to +// `defaultMaxMsgCountByPeer` messages per `msgCountByPeerFlushInterval` (1000 +// msg/hour). 
+// +// NOTE [2017-01-17]: +// Limiting is fine for now. Maybe down the road we want to keep track of the +// quality of peer messages so if peerA keeps telling us about peers we can't +// connect to then maybe we should care less about peerA. But I don't think +// that kind of complexity is priority right now. +type PEXReactor struct { + BaseReactor + + sw *Switch + book *AddrBook + ensurePeersPeriod time.Duration + + // tracks message count by peer, so we can prevent abuse + msgCountByPeer *cmn.CMap + maxMsgCountByPeer uint16 +} + +// NewPEXReactor creates new PEX reactor. +func NewPEXReactor(b *AddrBook) *PEXReactor { + r := &PEXReactor{ + book: b, + ensurePeersPeriod: defaultEnsurePeersPeriod, + msgCountByPeer: cmn.NewCMap(), + maxMsgCountByPeer: defaultMaxMsgCountByPeer, + } + r.BaseReactor = *NewBaseReactor(log, "PEXReactor", r) + return r +} + +// OnStart implements BaseService +func (r *PEXReactor) OnStart() error { + r.BaseReactor.OnStart() + r.book.Start() + go r.ensurePeersRoutine() + go r.flushMsgCountByPeer() + return nil +} + +// OnStop implements BaseService +func (r *PEXReactor) OnStop() { + r.BaseReactor.OnStop() + r.book.Stop() +} + +// GetChannels implements Reactor +func (r *PEXReactor) GetChannels() []*ChannelDescriptor { + return []*ChannelDescriptor{ + &ChannelDescriptor{ + ID: PexChannel, + Priority: 1, + SendQueueCapacity: 10, + }, + } +} + +// AddPeer implements Reactor by adding peer to the address book (if inbound) +// or by requesting more addresses (if outbound). +func (r *PEXReactor) AddPeer(p *Peer) { + if p.IsOutbound() { + // For outbound peers, the address is already in the books. + // Either it was added in DialSeeds or when we + // received the peer's address in r.Receive + if r.book.NeedMoreAddrs() { + r.RequestPEX(p) + } + } else { // For inbound connections, the peer is its own source + addr, err := NewNetAddressString(p.ListenAddr) + if err != nil { + // this should never happen + log.Error("Error in AddPeer: invalid peer address", "addr", p.ListenAddr, "error", err) + return + } + r.book.AddAddress(addr, addr) + } +} + +// RemovePeer implements Reactor. +func (r *PEXReactor) RemovePeer(p *Peer, reason interface{}) { + // If we aren't keeping track of local temp data for each peer here, then we + // don't have to do anything. +} + +// Receive implements Reactor by handling incoming PEX messages. +func (r *PEXReactor) Receive(chID byte, src *Peer, msgBytes []byte) { + srcAddr := src.Connection().RemoteAddress + srcAddrStr := srcAddr.String() + + r.IncrementMsgCountForPeer(srcAddrStr) + if r.ReachedMaxMsgCountForPeer(srcAddrStr) { + log.Warn("Maximum number of messages reached for peer", "peer", srcAddrStr) + // TODO remove src from peers? + return + } + + _, msg, err := DecodeMessage(msgBytes) + if err != nil { + log.Warn("Error decoding message", "error", err) + return + } + log.Notice("Received message", "msg", msg) + + switch msg := msg.(type) { + case *pexRequestMessage: + // src requested some peers. + r.SendAddrs(src, r.book.GetSelection()) + case *pexAddrsMessage: + // We received some peer addresses from src. + // (We don't want to get spammed with bad peers) + for _, addr := range msg.Addrs { + if addr != nil { + r.book.AddAddress(addr, srcAddr) + } + } + default: + log.Warn(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg))) + } +} + +// RequestPEX asks peer for more addresses. +func (r *PEXReactor) RequestPEX(p *Peer) { + p.Send(PexChannel, struct{ PexMessage }{&pexRequestMessage{}}) +} + +// SendAddrs sends addrs to the peer. 
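+// It is used to reply to a pexRequestMessage (see Receive).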
+func (r *PEXReactor) SendAddrs(p *Peer, addrs []*NetAddress) { + p.Send(PexChannel, struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) +} + +// SetEnsurePeersPeriod sets period to ensure peers connected. +func (r *PEXReactor) SetEnsurePeersPeriod(d time.Duration) { + r.ensurePeersPeriod = d +} + +// SetMaxMsgCountByPeer sets maximum messages one peer can send to us during 'msgCountByPeerFlushInterval'. +func (r *PEXReactor) SetMaxMsgCountByPeer(v uint16) { + r.maxMsgCountByPeer = v +} + +// ReachedMaxMsgCountForPeer returns true if we received too many +// messages from peer with address `addr`. +// NOTE: assumes the value in the CMap is non-nil +func (r *PEXReactor) ReachedMaxMsgCountForPeer(addr string) bool { + return r.msgCountByPeer.Get(addr).(uint16) >= r.maxMsgCountByPeer +} + +// Increment or initialize the msg count for the peer in the CMap +func (r *PEXReactor) IncrementMsgCountForPeer(addr string) { + var count uint16 + countI := r.msgCountByPeer.Get(addr) + if countI != nil { + count = countI.(uint16) + } + count++ + r.msgCountByPeer.Set(addr, count) +} + +// Ensures that sufficient peers are connected. (continuous) +func (r *PEXReactor) ensurePeersRoutine() { + // Randomize when routine starts + ensurePeersPeriodMs := r.ensurePeersPeriod.Nanoseconds() / 1e6 + time.Sleep(time.Duration(rand.Int63n(ensurePeersPeriodMs)) * time.Millisecond) + + // fire once immediately. + r.ensurePeers() + + // fire periodically + ticker := time.NewTicker(r.ensurePeersPeriod) + + for { + select { + case <-ticker.C: + r.ensurePeers() + case <-r.Quit: + ticker.Stop() + return + } + } +} + +// ensurePeers ensures that sufficient peers are connected. (once) +// +// Old bucket / New bucket are arbitrary categories to denote whether an +// address is vetted or not, and this needs to be determined over time via a +// heuristic that we haven't perfected yet, or, perhaps is manually edited by +// the node operator. It should not be used to compute what addresses are +// already connected or not. +// +// TODO Basically, we need to work harder on our good-peer/bad-peer marking. +// What we're currently doing in terms of marking good/bad peers is just a +// placeholder. It should not be the case that an address becomes old/vetted +// upon a single successful connection. +func (r *PEXReactor) ensurePeers() { + numOutPeers, _, numDialing := r.Switch.NumPeers() + numToDial := minNumOutboundPeers - (numOutPeers + numDialing) + log.Info("Ensure peers", "numOutPeers", numOutPeers, "numDialing", numDialing, "numToDial", numToDial) + if numToDial <= 0 { + return + } + + toDial := make(map[string]*NetAddress) + + // Try to pick numToDial addresses to dial. + for i := 0; i < numToDial; i++ { + // The purpose of newBias is to first prioritize old (more vetted) peers + // when we have few connections, but to allow for new (less vetted) peers + // if we already have many connections. This algorithm isn't perfect, but + // it somewhat ensures that we prioritize connecting to more-vetted + // peers. + newBias := cmn.MinInt(numOutPeers, 8)*10 + 10 + var picked *NetAddress + // Try to fetch a new peer 3 times. + // This caps the maximum number of tries to 3 * numToDial. 
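+		// A candidate is skipped if it was already picked in this pass, is
+		// currently being dialed, or is already connected.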
+ for j := 0; j < 3; j++ { + try := r.book.PickAddress(newBias) + if try == nil { + break + } + _, alreadySelected := toDial[try.IP.String()] + alreadyDialing := r.Switch.IsDialing(try) + alreadyConnected := r.Switch.Peers().Has(try.IP.String()) + if alreadySelected || alreadyDialing || alreadyConnected { + // log.Info("Cannot dial address", "addr", try, + // "alreadySelected", alreadySelected, + // "alreadyDialing", alreadyDialing, + // "alreadyConnected", alreadyConnected) + continue + } else { + log.Info("Will dial address", "addr", try) + picked = try + break + } + } + if picked == nil { + continue + } + toDial[picked.IP.String()] = picked + } + + // Dial picked addresses + for _, item := range toDial { + go func(picked *NetAddress) { + _, err := r.Switch.DialPeerWithAddress(picked, false) + if err != nil { + r.book.MarkAttempt(picked) + } + }(item) + } + + // If we need more addresses, pick a random peer and ask for more. + if r.book.NeedMoreAddrs() { + if peers := r.Switch.Peers().List(); len(peers) > 0 { + i := rand.Int() % len(peers) + peer := peers[i] + log.Info("No addresses to dial. Sending pexRequest to random peer", "peer", peer) + r.RequestPEX(peer) + } + } +} + +func (r *PEXReactor) flushMsgCountByPeer() { + ticker := time.NewTicker(msgCountByPeerFlushInterval) + + for { + select { + case <-ticker.C: + r.msgCountByPeer.Clear() + case <-r.Quit: + ticker.Stop() + return + } + } +} + +//----------------------------------------------------------------------------- +// Messages + +const ( + msgTypeRequest = byte(0x01) + msgTypeAddrs = byte(0x02) +) + +// PexMessage is a primary type for PEX messages. Underneath, it could contain +// either pexRequestMessage, or pexAddrsMessage messages. +type PexMessage interface{} + +var _ = wire.RegisterInterface( + struct{ PexMessage }{}, + wire.ConcreteType{&pexRequestMessage{}, msgTypeRequest}, + wire.ConcreteType{&pexAddrsMessage{}, msgTypeAddrs}, +) + +// DecodeMessage implements interface registered above. +func DecodeMessage(bz []byte) (msgType byte, msg PexMessage, err error) { + msgType = bz[0] + n := new(int) + r := bytes.NewReader(bz) + msg = wire.ReadBinary(struct{ PexMessage }{}, r, maxPexMessageSize, n, &err).(struct{ PexMessage }).PexMessage + return +} + +/* +A pexRequestMessage requests additional peer addresses. +*/ +type pexRequestMessage struct { +} + +func (m *pexRequestMessage) String() string { + return "[pexRequest]" +} + +/* +A message with announced peer addresses. 
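+It is sent in reply to a pexRequestMessage (see SendAddrs).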
+*/ +type pexAddrsMessage struct { + Addrs []*NetAddress +} + +func (m *pexAddrsMessage) String() string { + return fmt.Sprintf("[pexAddrs %v]", m.Addrs) +} diff --git a/pex_reactor_test.go b/pex_reactor_test.go new file mode 100644 index 000000000..aed6c758d --- /dev/null +++ b/pex_reactor_test.go @@ -0,0 +1,163 @@ +package p2p + +import ( + "io/ioutil" + "math/rand" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tmlibs/common" + wire "github.com/tendermint/go-wire" +) + +func TestPEXReactorBasic(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + dir, err := ioutil.TempDir("", "pex_reactor") + require.Nil(err) + defer os.RemoveAll(dir) + book := NewAddrBook(dir+"addrbook.json", true) + + r := NewPEXReactor(book) + + assert.NotNil(r) + assert.NotEmpty(r.GetChannels()) +} + +func TestPEXReactorAddRemovePeer(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + dir, err := ioutil.TempDir("", "pex_reactor") + require.Nil(err) + defer os.RemoveAll(dir) + book := NewAddrBook(dir+"addrbook.json", true) + + r := NewPEXReactor(book) + + size := book.Size() + peer := createRandomPeer(false) + + r.AddPeer(peer) + assert.Equal(size+1, book.Size()) + + r.RemovePeer(peer, "peer not available") + assert.Equal(size+1, book.Size()) + + outboundPeer := createRandomPeer(true) + + r.AddPeer(outboundPeer) + assert.Equal(size+1, book.Size(), "outbound peers should not be added to the address book") + + r.RemovePeer(outboundPeer, "peer not available") + assert.Equal(size+1, book.Size()) +} + +func TestPEXReactorRunning(t *testing.T) { + require := require.New(t) + + N := 3 + switches := make([]*Switch, N) + + dir, err := ioutil.TempDir("", "pex_reactor") + require.Nil(err) + defer os.RemoveAll(dir) + book := NewAddrBook(dir+"addrbook.json", false) + + // create switches + for i := 0; i < N; i++ { + switches[i] = makeSwitch(i, "127.0.0.1", "123.123.123", func(i int, sw *Switch) *Switch { + r := NewPEXReactor(book) + r.SetEnsurePeersPeriod(250 * time.Millisecond) + sw.AddReactor("pex", r) + return sw + }) + } + + // fill the address book and add listeners + for _, s := range switches { + addr, _ := NewNetAddressString(s.NodeInfo().ListenAddr) + book.AddAddress(addr, addr) + s.AddListener(NewDefaultListener("tcp", s.NodeInfo().ListenAddr, true)) + } + + // start switches + for _, s := range switches { + _, err := s.Start() // start switch and reactors + require.Nil(err) + } + + time.Sleep(1 * time.Second) + + // check peers are connected after some time + for _, s := range switches { + outbound, inbound, _ := s.NumPeers() + if outbound+inbound == 0 { + t.Errorf("%v expected to be connected to at least one peer", s.NodeInfo().ListenAddr) + } + } + + // stop them + for _, s := range switches { + s.Stop() + } +} + +func TestPEXReactorReceive(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + dir, err := ioutil.TempDir("", "pex_reactor") + require.Nil(err) + defer os.RemoveAll(dir) + book := NewAddrBook(dir+"addrbook.json", true) + + r := NewPEXReactor(book) + + peer := createRandomPeer(false) + + size := book.Size() + netAddr, _ := NewNetAddressString(peer.ListenAddr) + addrs := []*NetAddress{netAddr} + msg := wire.BinaryBytes(struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) + r.Receive(PexChannel, peer, msg) + assert.Equal(size+1, book.Size()) + + msg = wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}}) + r.Receive(PexChannel, peer, msg) +} + +func 
TestPEXReactorAbuseFromPeer(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + dir, err := ioutil.TempDir("", "pex_reactor") + require.Nil(err) + defer os.RemoveAll(dir) + book := NewAddrBook(dir+"addrbook.json", true) + + r := NewPEXReactor(book) + r.SetMaxMsgCountByPeer(5) + + peer := createRandomPeer(false) + + msg := wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}}) + for i := 0; i < 10; i++ { + r.Receive(PexChannel, peer, msg) + } + + assert.True(r.ReachedMaxMsgCountForPeer(peer.ListenAddr)) +} + +func createRandomPeer(outbound bool) *Peer { + addr := cmn.Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256) + netAddr, _ := NewNetAddressString(addr) + return &Peer{ + Key: cmn.RandStr(12), + NodeInfo: &NodeInfo{ + ListenAddr: addr, + }, + outbound: outbound, + mconn: &MConnection{RemoteAddress: netAddr}, + } +} diff --git a/secret_connection.go b/secret_connection.go new file mode 100644 index 000000000..446c4f185 --- /dev/null +++ b/secret_connection.go @@ -0,0 +1,346 @@ +// Uses nacl's secret_box to encrypt a net.Conn. +// It is (meant to be) an implementation of the STS protocol. +// Note we do not (yet) assume that a remote peer's pubkey +// is known ahead of time, and thus we are technically +// still vulnerable to MITM. (TODO!) +// See docs/sts-final.pdf for more info +package p2p + +import ( + "bytes" + crand "crypto/rand" + "crypto/sha256" + "encoding/binary" + "errors" + "io" + "net" + "time" + + "golang.org/x/crypto/nacl/box" + "golang.org/x/crypto/nacl/secretbox" + "golang.org/x/crypto/ripemd160" + + "github.com/tendermint/go-crypto" + "github.com/tendermint/go-wire" + . "github.com/tendermint/tmlibs/common" +) + +// 2 + 1024 == 1026 total frame size +const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize +const dataMaxSize = 1024 +const totalFrameSize = dataMaxSize + dataLenSize +const sealedFrameSize = totalFrameSize + secretbox.Overhead +const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays + +// Implements net.Conn +type SecretConnection struct { + conn io.ReadWriteCloser + recvBuffer []byte + recvNonce *[24]byte + sendNonce *[24]byte + remPubKey crypto.PubKeyEd25519 + shrSecret *[32]byte // shared secret +} + +// Performs handshake and returns a new authenticated SecretConnection. +// Returns nil if error in handshake. +// Caller should call conn.Close() +// See docs/sts-final.pdf for more information. +func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25519) (*SecretConnection, error) { + + locPubKey := locPrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) + + // Generate ephemeral keys for perfect forward secrecy. + locEphPub, locEphPriv := genEphKeys() + + // Write local ephemeral pubkey and receive one too. + // NOTE: every 32-byte string is accepted as a Curve25519 public key + // (see DJB's Curve25519 paper: http://cr.yp.to/ecdh/curve25519-20060209.pdf) + remEphPub, err := shareEphPubKey(conn, locEphPub) + if err != nil { + return nil, err + } + + // Compute common shared secret. + shrSecret := computeSharedSecret(remEphPub, locEphPriv) + + // Sort by lexical order. + loEphPub, hiEphPub := sort32(locEphPub, remEphPub) + + // Generate nonces to use for secretbox. + recvNonce, sendNonce := genNonces(loEphPub, hiEphPub, locEphPub == loEphPub) + + // Generate common challenge to sign. + challenge := genChallenge(loEphPub, hiEphPub) + + // Construct SecretConnection. 
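+	// remPubKey is filled in further below, once the remote side has
+	// authenticated itself by signing the shared challenge.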
+ sc := &SecretConnection{ + conn: conn, + recvBuffer: nil, + recvNonce: recvNonce, + sendNonce: sendNonce, + shrSecret: shrSecret, + } + + // Sign the challenge bytes for authentication. + locSignature := signChallenge(challenge, locPrivKey) + + // Share (in secret) each other's pubkey & challenge signature + authSigMsg, err := shareAuthSignature(sc, locPubKey, locSignature) + if err != nil { + return nil, err + } + remPubKey, remSignature := authSigMsg.Key, authSigMsg.Sig + if !remPubKey.VerifyBytes(challenge[:], remSignature) { + return nil, errors.New("Challenge verification failed") + } + + // We've authorized. + sc.remPubKey = remPubKey.Unwrap().(crypto.PubKeyEd25519) + return sc, nil +} + +// Returns authenticated remote pubkey +func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 { + return sc.remPubKey +} + +// Writes encrypted frames of `sealedFrameSize` +// CONTRACT: data smaller than dataMaxSize is read atomically. +func (sc *SecretConnection) Write(data []byte) (n int, err error) { + for 0 < len(data) { + var frame []byte = make([]byte, totalFrameSize) + var chunk []byte + if dataMaxSize < len(data) { + chunk = data[:dataMaxSize] + data = data[dataMaxSize:] + } else { + chunk = data + data = nil + } + chunkLength := len(chunk) + binary.BigEndian.PutUint16(frame, uint16(chunkLength)) + copy(frame[dataLenSize:], chunk) + + // encrypt the frame + var sealedFrame = make([]byte, sealedFrameSize) + secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret) + // fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret) + incr2Nonce(sc.sendNonce) + // end encryption + + _, err := sc.conn.Write(sealedFrame) + if err != nil { + return n, err + } else { + n += len(chunk) + } + } + return +} + +// CONTRACT: data smaller than dataMaxSize is read atomically. 
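+// Read decrypts one sealed frame at a time; plaintext beyond len(data) is
+// buffered in recvBuffer and returned by subsequent calls.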
+func (sc *SecretConnection) Read(data []byte) (n int, err error) { + if 0 < len(sc.recvBuffer) { + n_ := copy(data, sc.recvBuffer) + sc.recvBuffer = sc.recvBuffer[n_:] + return + } + + sealedFrame := make([]byte, sealedFrameSize) + _, err = io.ReadFull(sc.conn, sealedFrame) + if err != nil { + return + } + + // decrypt the frame + var frame = make([]byte, totalFrameSize) + // fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret) + _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret) + if !ok { + return n, errors.New("Failed to decrypt SecretConnection") + } + incr2Nonce(sc.recvNonce) + // end decryption + + var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes + if chunkLength > dataMaxSize { + return 0, errors.New("chunkLength is greater than dataMaxSize") + } + var chunk = frame[dataLenSize : dataLenSize+chunkLength] + + n = copy(data, chunk) + sc.recvBuffer = chunk[n:] + return +} + +// Implements net.Conn +func (sc *SecretConnection) Close() error { return sc.conn.Close() } +func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() } +func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() } +func (sc *SecretConnection) SetDeadline(t time.Time) error { return sc.conn.(net.Conn).SetDeadline(t) } +func (sc *SecretConnection) SetReadDeadline(t time.Time) error { + return sc.conn.(net.Conn).SetReadDeadline(t) +} +func (sc *SecretConnection) SetWriteDeadline(t time.Time) error { + return sc.conn.(net.Conn).SetWriteDeadline(t) +} + +func genEphKeys() (ephPub, ephPriv *[32]byte) { + var err error + ephPub, ephPriv, err = box.GenerateKey(crand.Reader) + if err != nil { + PanicCrisis("Could not generate ephemeral keypairs") + } + return +} + +func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { + var err1, err2 error + + Parallel( + func() { + _, err1 = conn.Write(locEphPub[:]) + }, + func() { + remEphPub = new([32]byte) + _, err2 = io.ReadFull(conn, remEphPub[:]) + }, + ) + + if err1 != nil { + return nil, err1 + } + if err2 != nil { + return nil, err2 + } + + return remEphPub, nil +} + +func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) { + shrSecret = new([32]byte) + box.Precompute(shrSecret, remPubKey, locPrivKey) + return +} + +func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) { + if bytes.Compare(foo[:], bar[:]) < 0 { + lo = foo + hi = bar + } else { + lo = bar + hi = foo + } + return +} + +func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce *[24]byte) { + nonce1 := hash24(append(loPubKey[:], hiPubKey[:]...)) + nonce2 := new([24]byte) + copy(nonce2[:], nonce1[:]) + nonce2[len(nonce2)-1] ^= 0x01 + if locIsLo { + recvNonce = nonce1 + sendNonce = nonce2 + } else { + recvNonce = nonce2 + sendNonce = nonce1 + } + return +} + +func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) { + return hash32(append(loPubKey[:], hiPubKey[:]...)) +} + +func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signature crypto.SignatureEd25519) { + signature = locPrivKey.Sign(challenge[:]).Unwrap().(crypto.SignatureEd25519) + return +} + +type authSigMessage struct { + Key crypto.PubKey + Sig crypto.Signature +} + +func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signature crypto.SignatureEd25519) (*authSigMessage, error) { + var recvMsg authSigMessage + var err1, err2 error + + 
Parallel( + func() { + msgBytes := wire.BinaryBytes(authSigMessage{pubKey.Wrap(), signature.Wrap()}) + _, err1 = sc.Write(msgBytes) + }, + func() { + readBuffer := make([]byte, authSigMsgSize) + _, err2 = io.ReadFull(sc, readBuffer) + if err2 != nil { + return + } + n := int(0) // not used. + recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage) + }) + + if err1 != nil { + return nil, err1 + } + if err2 != nil { + return nil, err2 + } + + return &recvMsg, nil +} + +func verifyChallengeSignature(challenge *[32]byte, remPubKey crypto.PubKeyEd25519, remSignature crypto.SignatureEd25519) bool { + return remPubKey.VerifyBytes(challenge[:], remSignature.Wrap()) +} + +//-------------------------------------------------------------------------------- + +// sha256 +func hash32(input []byte) (res *[32]byte) { + hasher := sha256.New() + hasher.Write(input) // does not error + resSlice := hasher.Sum(nil) + res = new([32]byte) + copy(res[:], resSlice) + return +} + +// We only fill in the first 20 bytes with ripemd160 +func hash24(input []byte) (res *[24]byte) { + hasher := ripemd160.New() + hasher.Write(input) // does not error + resSlice := hasher.Sum(nil) + res = new([24]byte) + copy(res[:], resSlice) + return +} + +// ripemd160 +func hash20(input []byte) (res *[20]byte) { + hasher := ripemd160.New() + hasher.Write(input) // does not error + resSlice := hasher.Sum(nil) + res = new([20]byte) + copy(res[:], resSlice) + return +} + +// increment nonce big-endian by 2 with wraparound. +func incr2Nonce(nonce *[24]byte) { + incrNonce(nonce) + incrNonce(nonce) +} + +// increment nonce big-endian by 1 with wraparound. +func incrNonce(nonce *[24]byte) { + for i := 23; 0 <= i; i-- { + nonce[i] += 1 + if nonce[i] != 0 { + return + } + } +} diff --git a/secret_connection_test.go b/secret_connection_test.go new file mode 100644 index 000000000..3dd962f88 --- /dev/null +++ b/secret_connection_test.go @@ -0,0 +1,202 @@ +package p2p + +import ( + "bytes" + "io" + "testing" + + "github.com/tendermint/go-crypto" + . "github.com/tendermint/tmlibs/common" +) + +type dummyConn struct { + *io.PipeReader + *io.PipeWriter +} + +func (drw dummyConn) Close() (err error) { + err2 := drw.PipeWriter.CloseWithError(io.EOF) + err1 := drw.PipeReader.Close() + if err2 != nil { + return err + } + return err1 +} + +// Each returned ReadWriteCloser is akin to a net.Connection +func makeDummyConnPair() (fooConn, barConn dummyConn) { + barReader, fooWriter := io.Pipe() + fooReader, barWriter := io.Pipe() + return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter} +} + +func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { + fooConn, barConn := makeDummyConnPair() + fooPrvKey := crypto.GenPrivKeyEd25519() + fooPubKey := fooPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519) + barPrvKey := crypto.GenPrivKeyEd25519() + barPubKey := barPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519) + + Parallel( + func() { + var err error + fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) + if err != nil { + tb.Errorf("Failed to establish SecretConnection for foo: %v", err) + return + } + remotePubBytes := fooSecConn.RemotePubKey() + if !bytes.Equal(remotePubBytes[:], barPubKey[:]) { + tb.Errorf("Unexpected fooSecConn.RemotePubKey. 
Expected %v, got %v", + barPubKey, fooSecConn.RemotePubKey()) + } + }, + func() { + var err error + barSecConn, err = MakeSecretConnection(barConn, barPrvKey) + if barSecConn == nil { + tb.Errorf("Failed to establish SecretConnection for bar: %v", err) + return + } + remotePubBytes := barSecConn.RemotePubKey() + if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) { + tb.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v", + fooPubKey, barSecConn.RemotePubKey()) + } + }) + + return +} + +func TestSecretConnectionHandshake(t *testing.T) { + fooSecConn, barSecConn := makeSecretConnPair(t) + fooSecConn.Close() + barSecConn.Close() +} + +func TestSecretConnectionReadWrite(t *testing.T) { + fooConn, barConn := makeDummyConnPair() + fooWrites, barWrites := []string{}, []string{} + fooReads, barReads := []string{}, []string{} + + // Pre-generate the things to write (for foo & bar) + for i := 0; i < 100; i++ { + fooWrites = append(fooWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) + barWrites = append(barWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) + } + + // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa + genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() { + return func() { + // Node handskae + nodePrvKey := crypto.GenPrivKeyEd25519() + nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) + if err != nil { + t.Errorf("Failed to establish SecretConnection for node: %v", err) + return + } + // In parallel, handle reads and writes + Parallel( + func() { + // Node writes + for _, nodeWrite := range nodeWrites { + n, err := nodeSecretConn.Write([]byte(nodeWrite)) + if err != nil { + t.Errorf("Failed to write to nodeSecretConn: %v", err) + return + } + if n != len(nodeWrite) { + t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) + return + } + } + nodeConn.PipeWriter.Close() + }, + func() { + // Node reads + readBuffer := make([]byte, dataMaxSize) + for { + n, err := nodeSecretConn.Read(readBuffer) + if err == io.EOF { + return + } else if err != nil { + t.Errorf("Failed to read from nodeSecretConn: %v", err) + return + } + *nodeReads = append(*nodeReads, string(readBuffer[:n])) + } + nodeConn.PipeReader.Close() + }) + } + } + + // Run foo & bar in parallel + Parallel( + genNodeRunner(fooConn, fooWrites, &fooReads), + genNodeRunner(barConn, barWrites, &barReads), + ) + + // A helper to ensure that the writes and reads match. + // Additionally, small writes (<= dataMaxSize) must be atomically read. 
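+	// Writes larger than dataMaxSize may arrive split across several reads,
+	// so the helper keeps consuming read chunks until the write is covered.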
+ compareWritesReads := func(writes []string, reads []string) { + for { + // Pop next write & corresponding reads + var read, write string = "", writes[0] + var readCount = 0 + for _, readChunk := range reads { + read += readChunk + readCount += 1 + if len(write) <= len(read) { + break + } + if len(write) <= dataMaxSize { + break // atomicity of small writes + } + } + // Compare + if write != read { + t.Errorf("Expected to read %X, got %X", write, read) + } + // Iterate + writes = writes[1:] + reads = reads[readCount:] + if len(writes) == 0 { + break + } + } + } + + compareWritesReads(fooWrites, barReads) + compareWritesReads(barWrites, fooReads) + +} + +func BenchmarkSecretConnection(b *testing.B) { + b.StopTimer() + fooSecConn, barSecConn := makeSecretConnPair(b) + fooWriteText := RandStr(dataMaxSize) + // Consume reads from bar's reader + go func() { + readBuffer := make([]byte, dataMaxSize) + for { + _, err := barSecConn.Read(readBuffer) + if err == io.EOF { + return + } else if err != nil { + b.Fatalf("Failed to read from barSecConn: %v", err) + } + } + }() + + b.StartTimer() + for i := 0; i < b.N; i++ { + _, err := fooSecConn.Write([]byte(fooWriteText)) + if err != nil { + b.Fatalf("Failed to write to fooSecConn: %v", err) + } + } + b.StopTimer() + + fooSecConn.Close() + //barSecConn.Close() race condition +} diff --git a/switch.go b/switch.go new file mode 100644 index 000000000..8c771df1c --- /dev/null +++ b/switch.go @@ -0,0 +1,593 @@ +package p2p + +import ( + "errors" + "fmt" + "math/rand" + "net" + "time" + + cfg "github.com/tendermint/go-config" + crypto "github.com/tendermint/go-crypto" + "github.com/tendermint/log15" + . "github.com/tendermint/tmlibs/common" +) + +const ( + reconnectAttempts = 30 + reconnectInterval = 3 * time.Second +) + +type Reactor interface { + Service // Start, Stop + + SetSwitch(*Switch) + GetChannels() []*ChannelDescriptor + AddPeer(peer *Peer) + RemovePeer(peer *Peer, reason interface{}) + Receive(chID byte, peer *Peer, msgBytes []byte) +} + +//-------------------------------------- + +type BaseReactor struct { + BaseService // Provides Start, Stop, .Quit + Switch *Switch +} + +func NewBaseReactor(log log15.Logger, name string, impl Reactor) *BaseReactor { + return &BaseReactor{ + BaseService: *NewBaseService(log, name, impl), + Switch: nil, + } +} + +func (br *BaseReactor) SetSwitch(sw *Switch) { + br.Switch = sw +} +func (_ *BaseReactor) GetChannels() []*ChannelDescriptor { return nil } +func (_ *BaseReactor) AddPeer(peer *Peer) {} +func (_ *BaseReactor) RemovePeer(peer *Peer, reason interface{}) {} +func (_ *BaseReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {} + +//----------------------------------------------------------------------------- + +/* +The `Switch` handles peer connections and exposes an API to receive incoming messages +on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one +or more `Channels`. So while sending outgoing messages is typically performed on the peer, +incoming messages are received on the reactor. 
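+
+A minimal wiring sketch (config, book, nodeInfo and privKey stand in for
+values constructed by the caller; see makeSwitch and the tests for concrete
+usage):
+
+	sw := NewSwitch(config)
+	sw.AddReactor("pex", NewPEXReactor(book))
+	sw.SetNodeInfo(nodeInfo)
+	sw.SetNodePrivKey(privKey)
+	sw.AddListener(NewDefaultListener("tcp", nodeInfo.ListenAddr, true))
+	sw.Start()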
+*/ +type Switch struct { + BaseService + + config cfg.Config + listeners []Listener + reactors map[string]Reactor + chDescs []*ChannelDescriptor + reactorsByCh map[byte]Reactor + peers *PeerSet + dialing *CMap + nodeInfo *NodeInfo // our node info + nodePrivKey crypto.PrivKeyEd25519 // our node privkey + + filterConnByAddr func(net.Addr) error + filterConnByPubKey func(crypto.PubKeyEd25519) error +} + +var ( + ErrSwitchDuplicatePeer = errors.New("Duplicate peer") + ErrSwitchMaxPeersPerIPRange = errors.New("IP range has too many peers") +) + +func NewSwitch(config cfg.Config) *Switch { + setConfigDefaults(config) + + sw := &Switch{ + config: config, + reactors: make(map[string]Reactor), + chDescs: make([]*ChannelDescriptor, 0), + reactorsByCh: make(map[byte]Reactor), + peers: NewPeerSet(), + dialing: NewCMap(), + nodeInfo: nil, + } + sw.BaseService = *NewBaseService(log, "P2P Switch", sw) + return sw +} + +// Not goroutine safe. +func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor { + // Validate the reactor. + // No two reactors can share the same channel. + reactorChannels := reactor.GetChannels() + for _, chDesc := range reactorChannels { + chID := chDesc.ID + if sw.reactorsByCh[chID] != nil { + PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor)) + } + sw.chDescs = append(sw.chDescs, chDesc) + sw.reactorsByCh[chID] = reactor + } + sw.reactors[name] = reactor + reactor.SetSwitch(sw) + return reactor +} + +// Not goroutine safe. +func (sw *Switch) Reactors() map[string]Reactor { + return sw.reactors +} + +// Not goroutine safe. +func (sw *Switch) Reactor(name string) Reactor { + return sw.reactors[name] +} + +// Not goroutine safe. +func (sw *Switch) AddListener(l Listener) { + sw.listeners = append(sw.listeners, l) +} + +// Not goroutine safe. +func (sw *Switch) Listeners() []Listener { + return sw.listeners +} + +// Not goroutine safe. +func (sw *Switch) IsListening() bool { + return len(sw.listeners) > 0 +} + +// Not goroutine safe. +func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) { + sw.nodeInfo = nodeInfo +} + +// Not goroutine safe. +func (sw *Switch) NodeInfo() *NodeInfo { + return sw.nodeInfo +} + +// Not goroutine safe. +// NOTE: Overwrites sw.nodeInfo.PubKey +func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) { + sw.nodePrivKey = nodePrivKey + if sw.nodeInfo != nil { + sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) + } +} + +// Switch.Start() starts all the reactors, peers, and listeners. +func (sw *Switch) OnStart() error { + sw.BaseService.OnStart() + // Start reactors + for _, reactor := range sw.reactors { + _, err := reactor.Start() + if err != nil { + return err + } + } + // Start peers + for _, peer := range sw.peers.List() { + sw.startInitPeer(peer) + } + // Start listeners + for _, listener := range sw.listeners { + go sw.listenerRoutine(listener) + } + return nil +} + +func (sw *Switch) OnStop() { + sw.BaseService.OnStop() + // Stop listeners + for _, listener := range sw.listeners { + listener.Stop() + } + sw.listeners = nil + // Stop peers + for _, peer := range sw.peers.List() { + peer.Stop() + sw.peers.Remove(peer) + } + // Stop reactors + for _, reactor := range sw.reactors { + reactor.Stop() + } +} + +// NOTE: This performs a blocking handshake before the peer is added. +// CONTRACT: If error is returned, peer is nil, and conn is immediately closed. 
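+// AddPeer runs the address and pubkey filters, performs the handshake, and
+// rejects self-connections and incompatible versions/networks before adding
+// the peer to the set and starting it.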
+func (sw *Switch) AddPeer(peer *Peer) error { + if err := sw.FilterConnByAddr(peer.Addr()); err != nil { + return err + } + + if err := sw.FilterConnByPubKey(peer.PubKey()); err != nil { + return err + } + + if err := peer.HandshakeTimeout(sw.nodeInfo, time.Duration(sw.config.GetInt(configKeyHandshakeTimeoutSeconds))*time.Second); err != nil { + return err + } + + // Avoid self + if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) { + return errors.New("Ignoring connection from self") + } + + // Check version, chain id + if err := sw.nodeInfo.CompatibleWith(peer.NodeInfo); err != nil { + return err + } + + // Add the peer to .peers + // ignore if duplicate or if we already have too many for that IP range + if err := sw.peers.Add(peer); err != nil { + log.Notice("Ignoring peer", "error", err, "peer", peer) + peer.Stop() + return err + } + + // Start peer + if sw.IsRunning() { + sw.startInitPeer(peer) + } + + log.Notice("Added peer", "peer", peer) + return nil +} + +func (sw *Switch) FilterConnByAddr(addr net.Addr) error { + if sw.filterConnByAddr != nil { + return sw.filterConnByAddr(addr) + } + return nil +} + +func (sw *Switch) FilterConnByPubKey(pubkey crypto.PubKeyEd25519) error { + if sw.filterConnByPubKey != nil { + return sw.filterConnByPubKey(pubkey) + } + return nil + +} + +func (sw *Switch) SetAddrFilter(f func(net.Addr) error) { + sw.filterConnByAddr = f +} + +func (sw *Switch) SetPubKeyFilter(f func(crypto.PubKeyEd25519) error) { + sw.filterConnByPubKey = f +} + +func (sw *Switch) startInitPeer(peer *Peer) { + peer.Start() // spawn send/recv routines + for _, reactor := range sw.reactors { + reactor.AddPeer(peer) + } +} + +// Dial a list of seeds asynchronously in random order +func (sw *Switch) DialSeeds(addrBook *AddrBook, seeds []string) error { + + netAddrs, err := NewNetAddressStrings(seeds) + if err != nil { + return err + } + + if addrBook != nil { + // add seeds to `addrBook` + ourAddrS := sw.nodeInfo.ListenAddr + ourAddr, _ := NewNetAddressString(ourAddrS) + for _, netAddr := range netAddrs { + // do not add ourselves + if netAddr.Equals(ourAddr) { + continue + } + addrBook.AddAddress(netAddr, ourAddr) + } + addrBook.Save() + } + + // permute the list, dial them in random order. 
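+	// Each seed is dialed in its own goroutine after a random delay of up to
+	// 3 seconds, so dialing is both staggered and randomly ordered.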
+ perm := rand.Perm(len(netAddrs)) + for i := 0; i < len(perm); i++ { + go func(i int) { + time.Sleep(time.Duration(rand.Int63n(3000)) * time.Millisecond) + j := perm[i] + sw.dialSeed(netAddrs[j]) + }(i) + } + return nil +} + +func (sw *Switch) dialSeed(addr *NetAddress) { + peer, err := sw.DialPeerWithAddress(addr, true) + if err != nil { + log.Error("Error dialing seed", "error", err) + return + } else { + log.Notice("Connected to seed", "peer", peer) + } +} + +func (sw *Switch) DialPeerWithAddress(addr *NetAddress, persistent bool) (*Peer, error) { + sw.dialing.Set(addr.IP.String(), addr) + defer sw.dialing.Delete(addr.IP.String()) + + peer, err := newOutboundPeerWithConfig(addr, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, peerConfigFromGoConfig(sw.config)) + if err != nil { + log.Info("Failed dialing peer", "address", addr, "error", err) + return nil, err + } + if persistent { + peer.makePersistent() + } + err = sw.AddPeer(peer) + if err != nil { + log.Info("Failed adding peer", "address", addr, "error", err) + peer.CloseConn() + return nil, err + } + log.Notice("Dialed and added peer", "address", addr, "peer", peer) + return peer, nil +} + +func (sw *Switch) IsDialing(addr *NetAddress) bool { + return sw.dialing.Has(addr.IP.String()) +} + +// Broadcast runs a go routine for each attempted send, which will block +// trying to send for defaultSendTimeoutSeconds. Returns a channel +// which receives success values for each attempted send (false if times out) +// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. +func (sw *Switch) Broadcast(chID byte, msg interface{}) chan bool { + successChan := make(chan bool, len(sw.peers.List())) + log.Debug("Broadcast", "channel", chID, "msg", msg) + for _, peer := range sw.peers.List() { + go func(peer *Peer) { + success := peer.Send(chID, msg) + successChan <- success + }(peer) + } + return successChan +} + +// Returns the count of outbound/inbound and outbound-dialing peers. +func (sw *Switch) NumPeers() (outbound, inbound, dialing int) { + peers := sw.peers.List() + for _, peer := range peers { + if peer.outbound { + outbound++ + } else { + inbound++ + } + } + dialing = sw.dialing.Size() + return +} + +func (sw *Switch) Peers() IPeerSet { + return sw.peers +} + +// Disconnect from a peer due to external error, retry if it is a persistent peer. +// TODO: make record depending on reason. +func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) { + addr := NewNetAddress(peer.Addr()) + log.Notice("Stopping peer for error", "peer", peer, "error", reason) + sw.stopAndRemovePeer(peer, reason) + + if peer.IsPersistent() { + go func() { + log.Notice("Reconnecting to peer", "peer", peer) + for i := 1; i < reconnectAttempts; i++ { + if !sw.IsRunning() { + return + } + + peer, err := sw.DialPeerWithAddress(addr, true) + if err != nil { + if i == reconnectAttempts { + log.Notice("Error reconnecting to peer. Giving up", "tries", i, "error", err) + return + } + log.Notice("Error reconnecting to peer. Trying again", "tries", i, "error", err) + time.Sleep(reconnectInterval) + continue + } + + log.Notice("Reconnected to peer", "peer", peer) + return + } + }() + } +} + +// Disconnect from a peer gracefully. +// TODO: handle graceful disconnects. 
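+// Unlike StopPeerForError, no reconnect is attempted, even for persistent peers.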
+func (sw *Switch) StopPeerGracefully(peer *Peer) { + log.Notice("Stopping peer gracefully") + sw.stopAndRemovePeer(peer, nil) +} + +func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) { + sw.peers.Remove(peer) + peer.Stop() + for _, reactor := range sw.reactors { + reactor.RemovePeer(peer, reason) + } +} + +func (sw *Switch) listenerRoutine(l Listener) { + for { + inConn, ok := <-l.Connections() + if !ok { + break + } + + // ignore connection if we already have enough + maxPeers := sw.config.GetInt(configKeyMaxNumPeers) + if maxPeers <= sw.peers.Size() { + log.Info("Ignoring inbound connection: already have enough peers", "address", inConn.RemoteAddr().String(), "numPeers", sw.peers.Size(), "max", maxPeers) + continue + } + + // New inbound connection! + err := sw.addPeerWithConnectionAndConfig(inConn, peerConfigFromGoConfig(sw.config)) + if err != nil { + log.Notice("Ignoring inbound connection: error while adding peer", "address", inConn.RemoteAddr().String(), "error", err) + continue + } + + // NOTE: We don't yet have the listening port of the + // remote (if they have a listener at all). + // The peerHandshake will handle that + } + + // cleanup +} + +//----------------------------------------------------------------------------- + +type SwitchEventNewPeer struct { + Peer *Peer +} + +type SwitchEventDonePeer struct { + Peer *Peer + Error interface{} +} + +//------------------------------------------------------------------ +// Switches connected via arbitrary net.Conn; useful for testing + +// Returns n switches, connected according to the connect func. +// If connect==Connect2Switches, the switches will be fully connected. +// initSwitch defines how the ith switch should be initialized (ie. with what reactors). +// NOTE: panics if any switch fails to start. +func MakeConnectedSwitches(n int, initSwitch func(int, *Switch) *Switch, connect func([]*Switch, int, int)) []*Switch { + switches := make([]*Switch, n) + for i := 0; i < n; i++ { + switches[i] = makeSwitch(i, "testing", "123.123.123", initSwitch) + } + + if err := StartSwitches(switches); err != nil { + panic(err) + } + + for i := 0; i < n; i++ { + for j := i; j < n; j++ { + connect(switches, i, j) + } + } + + return switches +} + +var PanicOnAddPeerErr = false + +// Will connect switches i and j via net.Pipe() +// Blocks until a conection is established. +// NOTE: caller ensures i and j are within bounds +func Connect2Switches(switches []*Switch, i, j int) { + switchI := switches[i] + switchJ := switches[j] + c1, c2 := net.Pipe() + doneCh := make(chan struct{}) + go func() { + err := switchI.addPeerWithConnection(c1) + if PanicOnAddPeerErr && err != nil { + panic(err) + } + doneCh <- struct{}{} + }() + go func() { + err := switchJ.addPeerWithConnection(c2) + if PanicOnAddPeerErr && err != nil { + panic(err) + } + doneCh <- struct{}{} + }() + <-doneCh + <-doneCh +} + +func StartSwitches(switches []*Switch) error { + for _, s := range switches { + _, err := s.Start() // start switch and reactors + if err != nil { + return err + } + } + return nil +} + +func makeSwitch(i int, network, version string, initSwitch func(int, *Switch) *Switch) *Switch { + privKey := crypto.GenPrivKeyEd25519() + // new switch, add reactors + // TODO: let the config be passed in? 
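+	// NewSwitch(cfg.NewMapConfig(nil)) picks up defaults via setConfigDefaults;
+	// the node addresses below are built from the network string plus a random port.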
+ s := initSwitch(i, NewSwitch(cfg.NewMapConfig(nil))) + s.SetNodeInfo(&NodeInfo{ + PubKey: privKey.PubKey().Unwrap().(crypto.PubKeyEd25519), + Moniker: Fmt("switch%d", i), + Network: network, + Version: version, + RemoteAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023), + ListenAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023), + }) + s.SetNodePrivKey(privKey) + return s +} + +func (sw *Switch) addPeerWithConnection(conn net.Conn) error { + peer, err := newInboundPeer(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) + if err != nil { + conn.Close() + return err + } + + if err = sw.AddPeer(peer); err != nil { + conn.Close() + return err + } + + return nil +} + +func (sw *Switch) addPeerWithConnectionAndConfig(conn net.Conn, config *PeerConfig) error { + peer, err := newInboundPeerWithConfig(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config) + if err != nil { + conn.Close() + return err + } + + if err = sw.AddPeer(peer); err != nil { + conn.Close() + return err + } + + return nil +} + +func peerConfigFromGoConfig(config cfg.Config) *PeerConfig { + return &PeerConfig{ + AuthEnc: config.GetBool(configKeyAuthEnc), + Fuzz: config.GetBool(configFuzzEnable), + HandshakeTimeout: time.Duration(config.GetInt(configKeyHandshakeTimeoutSeconds)) * time.Second, + DialTimeout: time.Duration(config.GetInt(configKeyDialTimeoutSeconds)) * time.Second, + MConfig: &MConnConfig{ + SendRate: int64(config.GetInt(configKeySendRate)), + RecvRate: int64(config.GetInt(configKeyRecvRate)), + }, + FuzzConfig: &FuzzConnConfig{ + Mode: config.GetInt(configFuzzMode), + MaxDelay: time.Duration(config.GetInt(configFuzzMaxDelayMilliseconds)) * time.Millisecond, + ProbDropRW: config.GetFloat64(configFuzzProbDropRW), + ProbDropConn: config.GetFloat64(configFuzzProbDropConn), + ProbSleep: config.GetFloat64(configFuzzProbSleep), + }, + } +} diff --git a/switch_test.go b/switch_test.go new file mode 100644 index 000000000..1f1fe69f0 --- /dev/null +++ b/switch_test.go @@ -0,0 +1,330 @@ +package p2p + +import ( + "bytes" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + . 
"github.com/tendermint/tmlibs/common" + cfg "github.com/tendermint/go-config" + crypto "github.com/tendermint/go-crypto" + wire "github.com/tendermint/go-wire" +) + +var ( + config cfg.Config +) + +func init() { + config = cfg.NewMapConfig(nil) + setConfigDefaults(config) +} + +type PeerMessage struct { + PeerKey string + Bytes []byte + Counter int +} + +type TestReactor struct { + BaseReactor + + mtx sync.Mutex + channels []*ChannelDescriptor + peersAdded []*Peer + peersRemoved []*Peer + logMessages bool + msgsCounter int + msgsReceived map[byte][]PeerMessage +} + +func NewTestReactor(channels []*ChannelDescriptor, logMessages bool) *TestReactor { + tr := &TestReactor{ + channels: channels, + logMessages: logMessages, + msgsReceived: make(map[byte][]PeerMessage), + } + tr.BaseReactor = *NewBaseReactor(log, "TestReactor", tr) + return tr +} + +func (tr *TestReactor) GetChannels() []*ChannelDescriptor { + return tr.channels +} + +func (tr *TestReactor) AddPeer(peer *Peer) { + tr.mtx.Lock() + defer tr.mtx.Unlock() + tr.peersAdded = append(tr.peersAdded, peer) +} + +func (tr *TestReactor) RemovePeer(peer *Peer, reason interface{}) { + tr.mtx.Lock() + defer tr.mtx.Unlock() + tr.peersRemoved = append(tr.peersRemoved, peer) +} + +func (tr *TestReactor) Receive(chID byte, peer *Peer, msgBytes []byte) { + if tr.logMessages { + tr.mtx.Lock() + defer tr.mtx.Unlock() + //fmt.Printf("Received: %X, %X\n", chID, msgBytes) + tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.Key, msgBytes, tr.msgsCounter}) + tr.msgsCounter++ + } +} + +func (tr *TestReactor) getMsgs(chID byte) []PeerMessage { + tr.mtx.Lock() + defer tr.mtx.Unlock() + return tr.msgsReceived[chID] +} + +//----------------------------------------------------------------------------- + +// convenience method for creating two switches connected to each other. +// XXX: note this uses net.Pipe and not a proper TCP conn +func makeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) { + // Create two switches that will be interconnected. + switches := MakeConnectedSwitches(2, initSwitch, Connect2Switches) + return switches[0], switches[1] +} + +func initSwitchFunc(i int, sw *Switch) *Switch { + // Make two reactors of two channels each + sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ + &ChannelDescriptor{ID: byte(0x00), Priority: 10}, + &ChannelDescriptor{ID: byte(0x01), Priority: 10}, + }, true)) + sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ + &ChannelDescriptor{ID: byte(0x02), Priority: 10}, + &ChannelDescriptor{ID: byte(0x03), Priority: 10}, + }, true)) + return sw +} + +func TestSwitches(t *testing.T) { + s1, s2 := makeSwitchPair(t, initSwitchFunc) + defer s1.Stop() + defer s2.Stop() + + if s1.Peers().Size() != 1 { + t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size()) + } + if s2.Peers().Size() != 1 { + t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size()) + } + + // Lets send some messages + ch0Msg := "channel zero" + ch1Msg := "channel foo" + ch2Msg := "channel bar" + + s1.Broadcast(byte(0x00), ch0Msg) + s1.Broadcast(byte(0x01), ch1Msg) + s1.Broadcast(byte(0x02), ch2Msg) + + // Wait for things to settle... + time.Sleep(5000 * time.Millisecond) + + // Check message on ch0 + ch0Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x00)) + if len(ch0Msgs) != 1 { + t.Errorf("Expected to have received 1 message in ch0") + } + if !bytes.Equal(ch0Msgs[0].Bytes, wire.BinaryBytes(ch0Msg)) { + t.Errorf("Unexpected message bytes. 
Wanted: %X, Got: %X", wire.BinaryBytes(ch0Msg), ch0Msgs[0].Bytes) + } + + // Check message on ch1 + ch1Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x01)) + if len(ch1Msgs) != 1 { + t.Errorf("Expected to have received 1 message in ch1") + } + if !bytes.Equal(ch1Msgs[0].Bytes, wire.BinaryBytes(ch1Msg)) { + t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch1Msg), ch1Msgs[0].Bytes) + } + + // Check message on ch2 + ch2Msgs := s2.Reactor("bar").(*TestReactor).getMsgs(byte(0x02)) + if len(ch2Msgs) != 1 { + t.Errorf("Expected to have received 1 message in ch2") + } + if !bytes.Equal(ch2Msgs[0].Bytes, wire.BinaryBytes(ch2Msg)) { + t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch2Msg), ch2Msgs[0].Bytes) + } + +} + +func TestConnAddrFilter(t *testing.T) { + s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + + c1, c2 := net.Pipe() + + s1.SetAddrFilter(func(addr net.Addr) error { + if addr.String() == c1.RemoteAddr().String() { + return fmt.Errorf("Error: pipe is blacklisted") + } + return nil + }) + + // connect to good peer + go func() { + s1.addPeerWithConnection(c1) + }() + go func() { + s2.addPeerWithConnection(c2) + }() + + // Wait for things to happen, peers to get added... + time.Sleep(100 * time.Millisecond * time.Duration(4)) + + defer s1.Stop() + defer s2.Stop() + if s1.Peers().Size() != 0 { + t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) + } + if s2.Peers().Size() != 0 { + t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) + } +} + +func TestConnPubKeyFilter(t *testing.T) { + s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + + c1, c2 := net.Pipe() + + // set pubkey filter + s1.SetPubKeyFilter(func(pubkey crypto.PubKeyEd25519) error { + if bytes.Equal(pubkey.Bytes(), s2.nodeInfo.PubKey.Bytes()) { + return fmt.Errorf("Error: pipe is blacklisted") + } + return nil + }) + + // connect to good peer + go func() { + s1.addPeerWithConnection(c1) + }() + go func() { + s2.addPeerWithConnection(c2) + }() + + // Wait for things to happen, peers to get added... 
+ time.Sleep(100 * time.Millisecond * time.Duration(4)) + + defer s1.Stop() + defer s2.Stop() + if s1.Peers().Size() != 0 { + t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) + } + if s2.Peers().Size() != 0 { + t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) + } +} + +func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + sw.Start() + defer sw.Stop() + + // simulate remote peer + rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} + rp.Start() + defer rp.Stop() + + peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) + require.Nil(err) + err = sw.AddPeer(peer) + require.Nil(err) + + // simulate failure by closing connection + peer.CloseConn() + + time.Sleep(100 * time.Millisecond) + + assert.Zero(sw.Peers().Size()) + assert.False(peer.IsRunning()) +} + +func TestSwitchReconnectsToPersistentPeer(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) + sw.Start() + defer sw.Stop() + + // simulate remote peer + rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} + rp.Start() + defer rp.Stop() + + peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) + peer.makePersistent() + require.Nil(err) + err = sw.AddPeer(peer) + require.Nil(err) + + // simulate failure by closing connection + peer.CloseConn() + + time.Sleep(100 * time.Millisecond) + + assert.NotZero(sw.Peers().Size()) + assert.False(peer.IsRunning()) +} + +func BenchmarkSwitches(b *testing.B) { + + b.StopTimer() + + s1, s2 := makeSwitchPair(b, func(i int, sw *Switch) *Switch { + // Make bar reactors of bar channels each + sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ + &ChannelDescriptor{ID: byte(0x00), Priority: 10}, + &ChannelDescriptor{ID: byte(0x01), Priority: 10}, + }, false)) + sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ + &ChannelDescriptor{ID: byte(0x02), Priority: 10}, + &ChannelDescriptor{ID: byte(0x03), Priority: 10}, + }, false)) + return sw + }) + defer s1.Stop() + defer s2.Stop() + + // Allow time for goroutines to boot up + time.Sleep(1000 * time.Millisecond) + b.StartTimer() + + numSuccess, numFailure := 0, 0 + + // Send random message from foo channel to another + for i := 0; i < b.N; i++ { + chID := byte(i % 4) + successChan := s1.Broadcast(chID, "test data") + for s := range successChan { + if s { + numSuccess++ + } else { + numFailure++ + } + } + } + + log.Warn(Fmt("success: %v, failure: %v", numSuccess, numFailure)) + + // Allow everything to flush before stopping switches & closing connections. 
+ b.StopTimer() + time.Sleep(1000 * time.Millisecond) + +} diff --git a/types.go b/types.go new file mode 100644 index 000000000..4f3e4c1d8 --- /dev/null +++ b/types.go @@ -0,0 +1,77 @@ +package p2p + +import ( + "fmt" + "net" + "strconv" + "strings" + + "github.com/tendermint/go-crypto" +) + +const maxNodeInfoSize = 10240 // 10Kb + +type NodeInfo struct { + PubKey crypto.PubKeyEd25519 `json:"pub_key"` + Moniker string `json:"moniker"` + Network string `json:"network"` + RemoteAddr string `json:"remote_addr"` + ListenAddr string `json:"listen_addr"` + Version string `json:"version"` // major.minor.revision + Other []string `json:"other"` // other application specific data +} + +// CONTRACT: two nodes are compatible if the major/minor versions match and network match +func (info *NodeInfo) CompatibleWith(other *NodeInfo) error { + iMajor, iMinor, _, iErr := splitVersion(info.Version) + oMajor, oMinor, _, oErr := splitVersion(other.Version) + + // if our own version number is not formatted right, we messed up + if iErr != nil { + return iErr + } + + // version number must be formatted correctly ("x.x.x") + if oErr != nil { + return oErr + } + + // major version must match + if iMajor != oMajor { + return fmt.Errorf("Peer is on a different major version. Got %v, expected %v", oMajor, iMajor) + } + + // minor version must match + if iMinor != oMinor { + return fmt.Errorf("Peer is on a different minor version. Got %v, expected %v", oMinor, iMinor) + } + + // nodes must be on the same network + if info.Network != other.Network { + return fmt.Errorf("Peer is on a different network. Got %v, expected %v", other.Network, info.Network) + } + + return nil +} + +func (info *NodeInfo) ListenHost() string { + host, _, _ := net.SplitHostPort(info.ListenAddr) + return host +} + +func (info *NodeInfo) ListenPort() int { + _, port, _ := net.SplitHostPort(info.ListenAddr) + port_i, err := strconv.Atoi(port) + if err != nil { + return -1 + } + return port_i +} + +func splitVersion(version string) (string, string, string, error) { + spl := strings.Split(version, ".") + if len(spl) != 3 { + return "", "", "", fmt.Errorf("Invalid version format %v", version) + } + return spl[0], spl[1], spl[2], nil +} diff --git a/upnp/README.md b/upnp/README.md new file mode 100644 index 000000000..557d05bdc --- /dev/null +++ b/upnp/README.md @@ -0,0 +1,5 @@ +# `tendermint/p2p/upnp` + +## Resources + +* http://www.upnp-hacks.org/upnp.html diff --git a/upnp/log.go b/upnp/log.go new file mode 100644 index 000000000..45e44439c --- /dev/null +++ b/upnp/log.go @@ -0,0 +1,7 @@ +package upnp + +import ( + "github.com/tendermint/tmlibs/logger" +) + +var log = logger.New("module", "upnp") diff --git a/upnp/probe.go b/upnp/probe.go new file mode 100644 index 000000000..5488de587 --- /dev/null +++ b/upnp/probe.go @@ -0,0 +1,111 @@ +package upnp + +import ( + "errors" + "fmt" + "net" + "time" + + . 
"github.com/tendermint/tmlibs/common" +) + +type UPNPCapabilities struct { + PortMapping bool + Hairpin bool +} + +func makeUPNPListener(intPort int, extPort int) (NAT, net.Listener, net.IP, error) { + nat, err := Discover() + if err != nil { + return nil, nil, nil, errors.New(fmt.Sprintf("NAT upnp could not be discovered: %v", err)) + } + log.Info(Fmt("ourIP: %v", nat.(*upnpNAT).ourIP)) + + ext, err := nat.GetExternalAddress() + if err != nil { + return nat, nil, nil, errors.New(fmt.Sprintf("External address error: %v", err)) + } + log.Info(Fmt("External address: %v", ext)) + + port, err := nat.AddPortMapping("tcp", extPort, intPort, "Tendermint UPnP Probe", 0) + if err != nil { + return nat, nil, ext, errors.New(fmt.Sprintf("Port mapping error: %v", err)) + } + log.Info(Fmt("Port mapping mapped: %v", port)) + + // also run the listener, open for all remote addresses. + listener, err := net.Listen("tcp", fmt.Sprintf(":%v", intPort)) + if err != nil { + return nat, nil, ext, errors.New(fmt.Sprintf("Error establishing listener: %v", err)) + } + return nat, listener, ext, nil +} + +func testHairpin(listener net.Listener, extAddr string) (supportsHairpin bool) { + // Listener + go func() { + inConn, err := listener.Accept() + if err != nil { + log.Notice(Fmt("Listener.Accept() error: %v", err)) + return + } + log.Info(Fmt("Accepted incoming connection: %v -> %v", inConn.LocalAddr(), inConn.RemoteAddr())) + buf := make([]byte, 1024) + n, err := inConn.Read(buf) + if err != nil { + log.Notice(Fmt("Incoming connection read error: %v", err)) + return + } + log.Info(Fmt("Incoming connection read %v bytes: %X", n, buf)) + if string(buf) == "test data" { + supportsHairpin = true + return + } + }() + + // Establish outgoing + outConn, err := net.Dial("tcp", extAddr) + if err != nil { + log.Notice(Fmt("Outgoing connection dial error: %v", err)) + return + } + + n, err := outConn.Write([]byte("test data")) + if err != nil { + log.Notice(Fmt("Outgoing connection write error: %v", err)) + return + } + log.Info(Fmt("Outgoing connection wrote %v bytes", n)) + + // Wait for data receipt + time.Sleep(1 * time.Second) + return +} + +func Probe() (caps UPNPCapabilities, err error) { + log.Info("Probing for UPnP!") + + intPort, extPort := 8001, 8001 + + nat, listener, ext, err := makeUPNPListener(intPort, extPort) + if err != nil { + return + } + caps.PortMapping = true + + // Deferred cleanup + defer func() { + err = nat.DeletePortMapping("tcp", intPort, extPort) + if err != nil { + log.Warn(Fmt("Port mapping delete error: %v", err)) + } + listener.Close() + }() + + supportsHairpin := testHairpin(listener, fmt.Sprintf("%v:%v", ext, extPort)) + if supportsHairpin { + caps.Hairpin = true + } + + return +} diff --git a/upnp/upnp.go b/upnp/upnp.go new file mode 100644 index 000000000..3d6c55035 --- /dev/null +++ b/upnp/upnp.go @@ -0,0 +1,380 @@ +/* +Taken from taipei-torrent + +Just enough UPnP to be able to forward ports +*/ +package upnp + +// BUG(jae): TODO: use syscalls to get actual ourIP. 
http://pastebin.com/9exZG4rh + +import ( + "bytes" + "encoding/xml" + "errors" + "io/ioutil" + "net" + "net/http" + "strconv" + "strings" + "time" +) + +type upnpNAT struct { + serviceURL string + ourIP string + urnDomain string +} + +// protocol is either "udp" or "tcp" +type NAT interface { + GetExternalAddress() (addr net.IP, err error) + AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) + DeletePortMapping(protocol string, externalPort, internalPort int) (err error) +} + +func Discover() (nat NAT, err error) { + ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") + if err != nil { + return + } + conn, err := net.ListenPacket("udp4", ":0") + if err != nil { + return + } + socket := conn.(*net.UDPConn) + defer socket.Close() + + err = socket.SetDeadline(time.Now().Add(3 * time.Second)) + if err != nil { + return + } + + st := "InternetGatewayDevice:1" + + buf := bytes.NewBufferString( + "M-SEARCH * HTTP/1.1\r\n" + + "HOST: 239.255.255.250:1900\r\n" + + "ST: ssdp:all\r\n" + + "MAN: \"ssdp:discover\"\r\n" + + "MX: 2\r\n\r\n") + message := buf.Bytes() + answerBytes := make([]byte, 1024) + for i := 0; i < 3; i++ { + _, err = socket.WriteToUDP(message, ssdp) + if err != nil { + return + } + var n int + n, _, err = socket.ReadFromUDP(answerBytes) + for { + n, _, err = socket.ReadFromUDP(answerBytes) + if err != nil { + break + } + answer := string(answerBytes[0:n]) + if strings.Index(answer, st) < 0 { + continue + } + // HTTP header field names are case-insensitive. + // http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + locString := "\r\nlocation:" + answer = strings.ToLower(answer) + locIndex := strings.Index(answer, locString) + if locIndex < 0 { + continue + } + loc := answer[locIndex+len(locString):] + endIndex := strings.Index(loc, "\r\n") + if endIndex < 0 { + continue + } + locURL := strings.TrimSpace(loc[0:endIndex]) + var serviceURL, urnDomain string + serviceURL, urnDomain, err = getServiceURL(locURL) + if err != nil { + return + } + var ourIP net.IP + ourIP, err = localIPv4() + if err != nil { + return + } + nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} + return + } + } + err = errors.New("UPnP port discovery failed.") + return +} + +type Envelope struct { + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` + Soap *SoapBody +} +type SoapBody struct { + XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` + ExternalIP *ExternalIPAddressResponse +} + +type ExternalIPAddressResponse struct { + XMLName xml.Name `xml:"GetExternalIPAddressResponse"` + IPAddress string `xml:"NewExternalIPAddress"` +} + +type ExternalIPAddress struct { + XMLName xml.Name `xml:"NewExternalIPAddress"` + IP string +} + +type UPNPService struct { + ServiceType string `xml:"serviceType"` + ControlURL string `xml:"controlURL"` +} + +type DeviceList struct { + Device []Device `xml:"device"` +} + +type ServiceList struct { + Service []UPNPService `xml:"service"` +} + +type Device struct { + XMLName xml.Name `xml:"device"` + DeviceType string `xml:"deviceType"` + DeviceList DeviceList `xml:"deviceList"` + ServiceList ServiceList `xml:"serviceList"` +} + +type Root struct { + Device Device +} + +func getChildDevice(d *Device, deviceType string) *Device { + dl := d.DeviceList.Device + for i := 0; i < len(dl); i++ { + if strings.Index(dl[i].DeviceType, deviceType) >= 0 { + return &dl[i] + } + } + return nil +} + +func 
getChildService(d *Device, serviceType string) *UPNPService {
+	sl := d.ServiceList.Service
+	for i := 0; i < len(sl); i++ {
+		if strings.Index(sl[i].ServiceType, serviceType) >= 0 {
+			return &sl[i]
+		}
+	}
+	return nil
+}
+
+func localIPv4() (net.IP, error) {
+	tt, err := net.Interfaces()
+	if err != nil {
+		return nil, err
+	}
+	for _, t := range tt {
+		aa, err := t.Addrs()
+		if err != nil {
+			return nil, err
+		}
+		for _, a := range aa {
+			ipnet, ok := a.(*net.IPNet)
+			if !ok {
+				continue
+			}
+			v4 := ipnet.IP.To4()
+			if v4 == nil || v4[0] == 127 { // loopback address
+				continue
+			}
+			return v4, nil
+		}
+	}
+	return nil, errors.New("cannot find local IP address")
+}
+
+func getServiceURL(rootURL string) (url, urnDomain string, err error) {
+	r, err := http.Get(rootURL)
+	if err != nil {
+		return
+	}
+	defer r.Body.Close()
+	if r.StatusCode >= 400 {
+		err = errors.New(strconv.Itoa(r.StatusCode))
+		return
+	}
+	var root Root
+	err = xml.NewDecoder(r.Body).Decode(&root)
+	if err != nil {
+		return
+	}
+	a := &root.Device
+	if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 {
+		err = errors.New("No InternetGatewayDevice")
+		return
+	}
+	b := getChildDevice(a, "WANDevice:1")
+	if b == nil {
+		err = errors.New("No WANDevice")
+		return
+	}
+	c := getChildDevice(b, "WANConnectionDevice:1")
+	if c == nil {
+		err = errors.New("No WANConnectionDevice")
+		return
+	}
+	d := getChildService(c, "WANIPConnection:1")
+	if d == nil {
+		// Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice,
+		// instead of under WanConnectionDevice
+		d = getChildService(b, "WANIPConnection:1")
+
+		if d == nil {
+			err = errors.New("No WANIPConnection")
+			return
+		}
+	}
+	// Extract the domain name, which isn't always 'schemas-upnp-org'
+	urnDomain = strings.Split(d.ServiceType, ":")[1]
+	url = combineURL(rootURL, d.ControlURL)
+	return
+}
+
+func combineURL(rootURL, subURL string) string {
+	protocolEnd := "://"
+	protoEndIndex := strings.Index(rootURL, protocolEnd)
+	a := rootURL[protoEndIndex+len(protocolEnd):]
+	rootIndex := strings.Index(a, "/")
+	return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL
+}
+
+func soapRequest(url, function, message, domain string) (r *http.Response, err error) {
+	fullMessage := "<?xml version=\"1.0\" ?>" +
+		"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" +
+		"<s:Body>" + message + "</s:Body></s:Envelope>"
+
+	req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"")
+	req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3")
+	//req.Header.Set("Transfer-Encoding", "chunked")
+	req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"")
+	req.Header.Set("Connection", "Close")
+	req.Header.Set("Cache-Control", "no-cache")
+	req.Header.Set("Pragma", "no-cache")
+
+	// log.Stderr("soapRequest ", req)
+
+	r, err = http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	/*if r.Body != nil {
+		defer r.Body.Close()
+	}*/
+
+	if r.StatusCode >= 400 {
+		// log.Stderr(function, r.StatusCode)
+		err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function)
+		r = nil
+		return
+	}
+	return
+}
+
+type statusInfo struct {
+	externalIpAddress string
+}
+
+func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) {
+
+	message := "<u:GetExternalIPAddress xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
+		"</u:GetExternalIPAddress>"
+
+	var response *http.Response
+	response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain)
+	if response != nil {
+		defer response.Body.Close()
+	}
+	if err != nil {
+		return
+	}
+	var envelope Envelope
+	data, err := ioutil.ReadAll(response.Body)
+	reader := bytes.NewReader(data)
+	xml.NewDecoder(reader).Decode(&envelope)
+
+	info = statusInfo{envelope.Soap.ExternalIP.IPAddress}
+
+	if err != nil {
+		return
+	}
+
+	return
+}
+
+func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) {
+	info, err := n.getExternalIPAddress()
+	if err != nil {
+		return
+	}
+	addr = net.ParseIP(info.externalIpAddress)
+	return
+}
+
+func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) {
+	// A single concatenation would break ARM compilation.
+	message := "<u:AddPortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
+		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort)
+	message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>"
+	message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" +
+		"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" +
+		"<NewEnabled>1</NewEnabled><NewPortMappingDescription>"
+	message += description +
+		"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) +
+		"</NewLeaseDuration></u:AddPortMapping>"
+
+	var response *http.Response
+	response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain)
+	if response != nil {
+		defer response.Body.Close()
+	}
+	if err != nil {
+		return
+	}
+
+	// TODO: check response to see if the port was forwarded
+	// log.Println(message, response)
+	// JAE:
+	// body, err := ioutil.ReadAll(response.Body)
+	// fmt.Println(string(body), err)
+	mappedExternalPort = externalPort
+	_ = response
+	return
+}
+
+func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) {
+
+	message := "<u:DeletePortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" +
+		"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) +
+		"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" +
+		"</u:DeletePortMapping>"
+
+	var response *http.Response
+	response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain)
+	if response != nil {
+		defer response.Body.Close()
+	}
+	if err != nil {
+		return
+	}
+
+	// TODO: check response to see if the port was deleted
+	// log.Println(message, response)
+	_ = response
+	return
+}
diff --git a/util.go b/util.go
new file mode 100644
index 000000000..2be320263
--- /dev/null
+++ b/util.go
@@ -0,0 +1,15 @@
+package p2p
+
+import (
+	"crypto/sha256"
+)
+
+// doubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes.
+func doubleSha256(b []byte) []byte {
+	hasher := sha256.New()
+	hasher.Write(b)
+	sum := hasher.Sum(nil)
+	hasher.Reset()
+	hasher.Write(sum)
+	return hasher.Sum(nil)
+}
diff --git a/version.go b/version.go
new file mode 100644
index 000000000..9a4c7bbaf
--- /dev/null
+++ b/version.go
@@ -0,0 +1,3 @@
+package p2p
+
+const Version = "0.5.0"