@ -1,7 +1,7 @@ | |||
package blockchain | |||
import ( | |||
"github.com/tendermint/go-logger" | |||
"github.com/tendermint/tmlibs/logger" | |||
) | |||
var log = logger.New("module", "blockchain") |
@ -1,56 +1,57 @@ | |||
package: github.com/tendermint/tendermint | |||
import: | |||
- package: github.com/tendermint/go-autofile | |||
version: develop | |||
- package: github.com/tendermint/go-clist | |||
version: develop | |||
- package: github.com/tendermint/go-common | |||
- package: github.com/ebuchman/fail-test | |||
- package: github.com/gogo/protobuf | |||
subpackages: | |||
- proto | |||
- package: github.com/golang/protobuf | |||
subpackages: | |||
- proto | |||
- package: github.com/gorilla/websocket | |||
- package: github.com/pkg/errors | |||
- package: github.com/spf13/cobra | |||
- package: github.com/stretchr/testify | |||
subpackages: | |||
- require | |||
- package: github.com/tendermint/abci | |||
version: develop | |||
subpackages: | |||
- client | |||
- example/dummy | |||
- types | |||
- package: github.com/tendermint/go-config | |||
version: develop | |||
- package: github.com/tendermint/go-crypto | |||
version: develop | |||
- package: github.com/tendermint/go-data | |||
version: develop | |||
- package: github.com/tendermint/go-db | |||
version: develop | |||
- package: github.com/tendermint/go-events | |||
version: develop | |||
- package: github.com/tendermint/go-logger | |||
version: develop | |||
- package: github.com/tendermint/go-merkle | |||
version: develop | |||
- package: github.com/tendermint/go-p2p | |||
version: unstable | |||
- package: github.com/tendermint/go-rpc | |||
version: develop | |||
- package: github.com/tendermint/go-wire | |||
version: develop | |||
- package: github.com/tendermint/abci | |||
version: develop | |||
- package: github.com/tendermint/go-flowrate | |||
subpackages: | |||
- data | |||
- package: github.com/tendermint/log15 | |||
- package: github.com/tendermint/ed25519 | |||
- package: github.com/tendermint/merkleeyes | |||
- package: github.com/tendermint/tmlibs | |||
version: develop | |||
subpackages: | |||
- app | |||
- package: github.com/gogo/protobuf | |||
version: ^0.3 | |||
subpackages: | |||
- proto | |||
- package: github.com/gorilla/websocket | |||
version: ^1.1.0 | |||
- package: github.com/spf13/cobra | |||
- package: github.com/spf13/pflag | |||
- package: github.com/pkg/errors | |||
version: ^0.8.0 | |||
- autofile | |||
- clist | |||
- common | |||
- db | |||
- events | |||
- flowrate | |||
- logger | |||
- merkle | |||
- package: golang.org/x/crypto | |||
subpackages: | |||
- nacl/box | |||
- nacl/secretbox | |||
- ripemd160 | |||
- package: golang.org/x/net | |||
subpackages: | |||
- context | |||
- package: google.golang.org/grpc | |||
testImport: | |||
- package: github.com/stretchr/testify | |||
version: ^1.1.4 | |||
- package: github.com/tendermint/merkleeyes | |||
version: develop | |||
subpackages: | |||
- assert | |||
- require | |||
- app | |||
- iavl | |||
- testutil |
@ -1,7 +1,7 @@ | |||
package node | |||
import ( | |||
"github.com/tendermint/go-logger" | |||
"github.com/tendermint/tmlibs/logger" | |||
) | |||
var log = logger.New("module", "node") |
@ -0,0 +1,78 @@ | |||
# Changelog | |||
## 0.5.0 (April 21, 2017) | |||
BREAKING CHANGES: | |||
- Remove or unexport methods from FuzzedConnection: Active, Mode, ProbDropRW, ProbDropConn, ProbSleep, MaxDelayMilliseconds, Fuzz | |||
- switch.AddPeerWithConnection is unexported and replaced by switch.AddPeer | |||
- switch.DialPeerWithAddress takes a bool, setting the peer as persistent or not | |||
FEATURES: | |||
- Persistent peers: any peer considered a "seed" will be reconnected to when the connection is dropped | |||
IMPROVEMENTS: | |||
- Many more tests and comments | |||
- Refactor configurations for less dependence on go-config. Introduces new structs PeerConfig, MConnConfig, FuzzConnConfig | |||
- New methods on peer: CloseConn, HandshakeTimeout, IsPersistent, Addr, PubKey | |||
- NewNetAddress supports a testing mode where the address defaults to 0.0.0.0:0 | |||
## 0.4.0 (March 6, 2017) | |||
BREAKING CHANGES: | |||
- DialSeeds now takes an AddrBook and returns an error: `DialSeeds(*AddrBook, []string) error` | |||
- NewNetAddressString now returns an error: `NewNetAddressString(string) (*NetAddress, error)` | |||
FEATURES: | |||
- `NewNetAddressStrings([]string) ([]*NetAddress, error)` | |||
- `AddrBook.Save()` | |||
IMPROVEMENTS: | |||
- PexReactor responsible for starting and stopping the AddrBook | |||
BUG FIXES: | |||
- DialSeeds returns an error instead of panicking on bad addresses | |||
## 0.3.5 (January 12, 2017) | |||
FEATURES | |||
- Toggle strict routability in the AddrBook | |||
BUG FIXES | |||
- Close filtered out connections | |||
- Fixes for MakeConnectedSwitches and Connect2Switches | |||
## 0.3.4 (August 10, 2016) | |||
FEATURES: | |||
- Optionally filter connections by address or public key | |||
## 0.3.3 (May 12, 2016) | |||
FEATURES: | |||
- FuzzConn | |||
## 0.3.2 (March 12, 2016) | |||
IMPROVEMENTS: | |||
- Memory optimizations | |||
## 0.3.1 () | |||
FEATURES: | |||
- Configurable parameters | |||
@ -0,0 +1,13 @@ | |||
FROM golang:latest | |||
RUN curl https://glide.sh/get | sh | |||
RUN mkdir -p /go/src/github.com/tendermint/tendermint/p2p | |||
WORKDIR /go/src/github.com/tendermint/tendermint/p2p | |||
COPY glide.yaml /go/src/github.com/tendermint/tendermint/p2p/ | |||
COPY glide.lock /go/src/github.com/tendermint/tendermint/p2p/ | |||
RUN glide install | |||
COPY . /go/src/github.com/tendermint/tendermint/p2p |
@ -0,0 +1,79 @@ | |||
# `tendermint/tendermint/p2p` | |||
[![CircleCI](https://circleci.com/gh/tendermint/tendermint/p2p.svg?style=svg)](https://circleci.com/gh/tendermint/tendermint/p2p) | |||
`tendermint/tendermint/p2p` provides an abstraction around peer-to-peer communication.<br/> | |||
## Peer/MConnection/Channel | |||
Each peer has one `MConnection` (multiplex connection) instance. | |||
__multiplex__ *noun* a system or signal involving simultaneous transmission of | |||
several messages along a single channel of communication. | |||
Each `MConnection` handles message transmission on multiple abstract communication | |||
`Channel`s. Each channel has a globally unique byte id. | |||
The byte id and the relative priorities of each `Channel` are configured upon | |||
initialization of the connection. | |||
There are two methods for sending messages: | |||
```go | |||
func (m MConnection) Send(chID byte, msg interface{}) bool {} | |||
func (m MConnection) TrySend(chID byte, msg interface{}) bool {} | |||
``` | |||
`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued | |||
for the channel with the given id byte `chID`. The message `msg` is serialized | |||
using the `tendermint/wire` submodule's `WriteBinary()` reflection routine. | |||
`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's | |||
queue is full. | |||
`Send()` and `TrySend()` are also exposed for each `Peer`. | |||
## Switch/Reactor | |||
The `Switch` handles peer connections and exposes an API to receive incoming messages | |||
on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one | |||
or more `Channels`. So while sending outgoing messages is typically performed on the peer, | |||
incoming messages are received on the reactor. | |||
```go | |||
// Declare a MyReactor reactor that handles messages on MyChannelID. | |||
type MyReactor struct{} | |||
func (reactor MyReactor) GetChannels() []*ChannelDescriptor { | |||
return []*ChannelDescriptor{ChannelDescriptor{ID:MyChannelID, Priority: 1}} | |||
} | |||
func (reactor MyReactor) Receive(chID byte, peer *Peer, msgBytes []byte) { | |||
r, n, err := bytes.NewBuffer(msgBytes), new(int64), new(error) | |||
msgString := ReadString(r, n, err) | |||
fmt.Println(msgString) | |||
} | |||
// Other Reactor methods omitted for brevity | |||
... | |||
switch := NewSwitch([]Reactor{MyReactor{}}) | |||
... | |||
// Send a random message to all outbound connections | |||
for _, peer := range switch.Peers().List() { | |||
if peer.IsOutbound() { | |||
peer.Send(MyChannelID, "Here's a random message") | |||
} | |||
} | |||
``` | |||
### PexReactor/AddrBook | |||
A `PEXReactor` reactor implementation is provided to automate peer discovery. | |||
```go | |||
book := p2p.NewAddrBook(addrBookFilePath) | |||
pexReactor := p2p.NewPEXReactor(book) | |||
... | |||
switch := NewSwitch([]Reactor{pexReactor, myReactor, ...}) | |||
``` |
@ -0,0 +1,839 @@ | |||
// Modified for Tendermint | |||
// Originally Copyright (c) 2013-2014 Conformal Systems LLC. | |||
// https://github.com/conformal/btcd/blob/master/LICENSE | |||
package p2p | |||
import ( | |||
"encoding/binary" | |||
"encoding/json" | |||
"math" | |||
"math/rand" | |||
"net" | |||
"os" | |||
"sync" | |||
"time" | |||
. "github.com/tendermint/tmlibs/common" | |||
crypto "github.com/tendermint/go-crypto" | |||
) | |||
const ( | |||
// addresses under which the address manager will claim to need more addresses. | |||
needAddressThreshold = 1000 | |||
// interval used to dump the address cache to disk for future use. | |||
dumpAddressInterval = time.Minute * 2 | |||
// max addresses in each old address bucket. | |||
oldBucketSize = 64 | |||
// buckets we split old addresses over. | |||
oldBucketCount = 64 | |||
// max addresses in each new address bucket. | |||
newBucketSize = 64 | |||
// buckets that we spread new addresses over. | |||
newBucketCount = 256 | |||
// old buckets over which an address group will be spread. | |||
oldBucketsPerGroup = 4 | |||
// new buckets over which an source address group will be spread. | |||
newBucketsPerGroup = 32 | |||
// buckets a frequently seen new address may end up in. | |||
maxNewBucketsPerAddress = 4 | |||
// days before which we assume an address has vanished | |||
// if we have not seen it announced in that long. | |||
numMissingDays = 30 | |||
// tries without a single success before we assume an address is bad. | |||
numRetries = 3 | |||
// max failures we will accept without a success before considering an address bad. | |||
maxFailures = 10 | |||
// days since the last success before we will consider evicting an address. | |||
minBadDays = 7 | |||
// % of total addresses known returned by GetSelection. | |||
getSelectionPercent = 23 | |||
// min addresses that must be returned by GetSelection. Useful for bootstrapping. | |||
minGetSelection = 32 | |||
// max addresses returned by GetSelection | |||
// NOTE: this must match "maxPexMessageSize" | |||
maxGetSelection = 250 | |||
// current version of the on-disk format. | |||
serializationVersion = 1 | |||
) | |||
const ( | |||
bucketTypeNew = 0x01 | |||
bucketTypeOld = 0x02 | |||
) | |||
// AddrBook - concurrency safe peer address manager. | |||
type AddrBook struct { | |||
BaseService | |||
mtx sync.Mutex | |||
filePath string | |||
routabilityStrict bool | |||
rand *rand.Rand | |||
key string | |||
ourAddrs map[string]*NetAddress | |||
addrLookup map[string]*knownAddress // new & old | |||
addrNew []map[string]*knownAddress | |||
addrOld []map[string]*knownAddress | |||
wg sync.WaitGroup | |||
nOld int | |||
nNew int | |||
} | |||
// NewAddrBook creates a new address book. | |||
// Use Start to begin processing asynchronous address updates. | |||
func NewAddrBook(filePath string, routabilityStrict bool) *AddrBook { | |||
am := &AddrBook{ | |||
rand: rand.New(rand.NewSource(time.Now().UnixNano())), | |||
ourAddrs: make(map[string]*NetAddress), | |||
addrLookup: make(map[string]*knownAddress), | |||
filePath: filePath, | |||
routabilityStrict: routabilityStrict, | |||
} | |||
am.init() | |||
am.BaseService = *NewBaseService(log, "AddrBook", am) | |||
return am | |||
} | |||
// When modifying this, don't forget to update loadFromFile() | |||
func (a *AddrBook) init() { | |||
a.key = crypto.CRandHex(24) // 24/2 * 8 = 96 bits | |||
// New addr buckets | |||
a.addrNew = make([]map[string]*knownAddress, newBucketCount) | |||
for i := range a.addrNew { | |||
a.addrNew[i] = make(map[string]*knownAddress) | |||
} | |||
// Old addr buckets | |||
a.addrOld = make([]map[string]*knownAddress, oldBucketCount) | |||
for i := range a.addrOld { | |||
a.addrOld[i] = make(map[string]*knownAddress) | |||
} | |||
} | |||
// OnStart implements Service. | |||
func (a *AddrBook) OnStart() error { | |||
a.BaseService.OnStart() | |||
a.loadFromFile(a.filePath) | |||
a.wg.Add(1) | |||
go a.saveRoutine() | |||
return nil | |||
} | |||
// OnStop implements Service. | |||
func (a *AddrBook) OnStop() { | |||
a.BaseService.OnStop() | |||
} | |||
func (a *AddrBook) Wait() { | |||
a.wg.Wait() | |||
} | |||
func (a *AddrBook) AddOurAddress(addr *NetAddress) { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
log.Info("Add our address to book", "addr", addr) | |||
a.ourAddrs[addr.String()] = addr | |||
} | |||
func (a *AddrBook) OurAddresses() []*NetAddress { | |||
addrs := []*NetAddress{} | |||
for _, addr := range a.ourAddrs { | |||
addrs = append(addrs, addr) | |||
} | |||
return addrs | |||
} | |||
// NOTE: addr must not be nil | |||
func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
log.Info("Add address to book", "addr", addr, "src", src) | |||
a.addAddress(addr, src) | |||
} | |||
func (a *AddrBook) NeedMoreAddrs() bool { | |||
return a.Size() < needAddressThreshold | |||
} | |||
func (a *AddrBook) Size() int { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
return a.size() | |||
} | |||
func (a *AddrBook) size() int { | |||
return a.nNew + a.nOld | |||
} | |||
// Pick an address to connect to with new/old bias. | |||
func (a *AddrBook) PickAddress(newBias int) *NetAddress { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
if a.size() == 0 { | |||
return nil | |||
} | |||
if newBias > 100 { | |||
newBias = 100 | |||
} | |||
if newBias < 0 { | |||
newBias = 0 | |||
} | |||
// Bias between new and old addresses. | |||
oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias)) | |||
newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias) | |||
if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation { | |||
// pick random Old bucket. | |||
var bucket map[string]*knownAddress = nil | |||
for len(bucket) == 0 { | |||
bucket = a.addrOld[a.rand.Intn(len(a.addrOld))] | |||
} | |||
// pick a random ka from bucket. | |||
randIndex := a.rand.Intn(len(bucket)) | |||
for _, ka := range bucket { | |||
if randIndex == 0 { | |||
return ka.Addr | |||
} | |||
randIndex-- | |||
} | |||
PanicSanity("Should not happen") | |||
} else { | |||
// pick random New bucket. | |||
var bucket map[string]*knownAddress = nil | |||
for len(bucket) == 0 { | |||
bucket = a.addrNew[a.rand.Intn(len(a.addrNew))] | |||
} | |||
// pick a random ka from bucket. | |||
randIndex := a.rand.Intn(len(bucket)) | |||
for _, ka := range bucket { | |||
if randIndex == 0 { | |||
return ka.Addr | |||
} | |||
randIndex-- | |||
} | |||
PanicSanity("Should not happen") | |||
} | |||
return nil | |||
} | |||
func (a *AddrBook) MarkGood(addr *NetAddress) { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
ka := a.addrLookup[addr.String()] | |||
if ka == nil { | |||
return | |||
} | |||
ka.markGood() | |||
if ka.isNew() { | |||
a.moveToOld(ka) | |||
} | |||
} | |||
func (a *AddrBook) MarkAttempt(addr *NetAddress) { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
ka := a.addrLookup[addr.String()] | |||
if ka == nil { | |||
return | |||
} | |||
ka.markAttempt() | |||
} | |||
// MarkBad currently just ejects the address. In the future, consider | |||
// blacklisting. | |||
func (a *AddrBook) MarkBad(addr *NetAddress) { | |||
a.RemoveAddress(addr) | |||
} | |||
// RemoveAddress removes the address from the book. | |||
func (a *AddrBook) RemoveAddress(addr *NetAddress) { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
ka := a.addrLookup[addr.String()] | |||
if ka == nil { | |||
return | |||
} | |||
log.Info("Remove address from book", "addr", addr) | |||
a.removeFromAllBuckets(ka) | |||
} | |||
/* Peer exchange */ | |||
// GetSelection randomly selects some addresses (old & new). Suitable for peer-exchange protocols. | |||
func (a *AddrBook) GetSelection() []*NetAddress { | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
if a.size() == 0 { | |||
return nil | |||
} | |||
allAddr := make([]*NetAddress, a.size()) | |||
i := 0 | |||
for _, v := range a.addrLookup { | |||
allAddr[i] = v.Addr | |||
i++ | |||
} | |||
numAddresses := MaxInt( | |||
MinInt(minGetSelection, len(allAddr)), | |||
len(allAddr)*getSelectionPercent/100) | |||
numAddresses = MinInt(maxGetSelection, numAddresses) | |||
// Fisher-Yates shuffle the array. We only need to do the first | |||
// `numAddresses' since we are throwing the rest. | |||
for i := 0; i < numAddresses; i++ { | |||
// pick a number between current index and the end | |||
j := rand.Intn(len(allAddr)-i) + i | |||
allAddr[i], allAddr[j] = allAddr[j], allAddr[i] | |||
} | |||
// slice off the limit we are willing to share. | |||
return allAddr[:numAddresses] | |||
} | |||
/* Loading & Saving */ | |||
type addrBookJSON struct { | |||
Key string | |||
Addrs []*knownAddress | |||
} | |||
func (a *AddrBook) saveToFile(filePath string) { | |||
log.Info("Saving AddrBook to file", "size", a.Size()) | |||
a.mtx.Lock() | |||
defer a.mtx.Unlock() | |||
// Compile Addrs | |||
addrs := []*knownAddress{} | |||
for _, ka := range a.addrLookup { | |||
addrs = append(addrs, ka) | |||
} | |||
aJSON := &addrBookJSON{ | |||
Key: a.key, | |||
Addrs: addrs, | |||
} | |||
jsonBytes, err := json.MarshalIndent(aJSON, "", "\t") | |||
if err != nil { | |||
log.Error("Failed to save AddrBook to file", "err", err) | |||
return | |||
} | |||
err = WriteFileAtomic(filePath, jsonBytes, 0644) | |||
if err != nil { | |||
log.Error("Failed to save AddrBook to file", "file", filePath, "error", err) | |||
} | |||
} | |||
// Returns false if file does not exist. | |||
// Panics if file is corrupt. | |||
func (a *AddrBook) loadFromFile(filePath string) bool { | |||
// If doesn't exist, do nothing. | |||
_, err := os.Stat(filePath) | |||
if os.IsNotExist(err) { | |||
return false | |||
} | |||
// Load addrBookJSON{} | |||
r, err := os.Open(filePath) | |||
if err != nil { | |||
PanicCrisis(Fmt("Error opening file %s: %v", filePath, err)) | |||
} | |||
defer r.Close() | |||
aJSON := &addrBookJSON{} | |||
dec := json.NewDecoder(r) | |||
err = dec.Decode(aJSON) | |||
if err != nil { | |||
PanicCrisis(Fmt("Error reading file %s: %v", filePath, err)) | |||
} | |||
// Restore all the fields... | |||
// Restore the key | |||
a.key = aJSON.Key | |||
// Restore .addrNew & .addrOld | |||
for _, ka := range aJSON.Addrs { | |||
for _, bucketIndex := range ka.Buckets { | |||
bucket := a.getBucket(ka.BucketType, bucketIndex) | |||
bucket[ka.Addr.String()] = ka | |||
} | |||
a.addrLookup[ka.Addr.String()] = ka | |||
if ka.BucketType == bucketTypeNew { | |||
a.nNew++ | |||
} else { | |||
a.nOld++ | |||
} | |||
} | |||
return true | |||
} | |||
// Save saves the book. | |||
func (a *AddrBook) Save() { | |||
log.Info("Saving AddrBook to file", "size", a.Size()) | |||
a.saveToFile(a.filePath) | |||
} | |||
/* Private methods */ | |||
func (a *AddrBook) saveRoutine() { | |||
dumpAddressTicker := time.NewTicker(dumpAddressInterval) | |||
out: | |||
for { | |||
select { | |||
case <-dumpAddressTicker.C: | |||
a.saveToFile(a.filePath) | |||
case <-a.Quit: | |||
break out | |||
} | |||
} | |||
dumpAddressTicker.Stop() | |||
a.saveToFile(a.filePath) | |||
a.wg.Done() | |||
log.Notice("Address handler done") | |||
} | |||
func (a *AddrBook) getBucket(bucketType byte, bucketIdx int) map[string]*knownAddress { | |||
switch bucketType { | |||
case bucketTypeNew: | |||
return a.addrNew[bucketIdx] | |||
case bucketTypeOld: | |||
return a.addrOld[bucketIdx] | |||
default: | |||
PanicSanity("Should not happen") | |||
return nil | |||
} | |||
} | |||
// Adds ka to new bucket. Returns false if it couldn't do it cuz buckets full. | |||
// NOTE: currently it always returns true. | |||
func (a *AddrBook) addToNewBucket(ka *knownAddress, bucketIdx int) bool { | |||
// Sanity check | |||
if ka.isOld() { | |||
log.Warn(Fmt("Cannot add address already in old bucket to a new bucket: %v", ka)) | |||
return false | |||
} | |||
addrStr := ka.Addr.String() | |||
bucket := a.getBucket(bucketTypeNew, bucketIdx) | |||
// Already exists? | |||
if _, ok := bucket[addrStr]; ok { | |||
return true | |||
} | |||
// Enforce max addresses. | |||
if len(bucket) > newBucketSize { | |||
log.Notice("new bucket is full, expiring old ") | |||
a.expireNew(bucketIdx) | |||
} | |||
// Add to bucket. | |||
bucket[addrStr] = ka | |||
if ka.addBucketRef(bucketIdx) == 1 { | |||
a.nNew++ | |||
} | |||
// Ensure in addrLookup | |||
a.addrLookup[addrStr] = ka | |||
return true | |||
} | |||
// Adds ka to old bucket. Returns false if it couldn't do it cuz buckets full. | |||
func (a *AddrBook) addToOldBucket(ka *knownAddress, bucketIdx int) bool { | |||
// Sanity check | |||
if ka.isNew() { | |||
log.Warn(Fmt("Cannot add new address to old bucket: %v", ka)) | |||
return false | |||
} | |||
if len(ka.Buckets) != 0 { | |||
log.Warn(Fmt("Cannot add already old address to another old bucket: %v", ka)) | |||
return false | |||
} | |||
addrStr := ka.Addr.String() | |||
bucket := a.getBucket(bucketTypeNew, bucketIdx) | |||
// Already exists? | |||
if _, ok := bucket[addrStr]; ok { | |||
return true | |||
} | |||
// Enforce max addresses. | |||
if len(bucket) > oldBucketSize { | |||
return false | |||
} | |||
// Add to bucket. | |||
bucket[addrStr] = ka | |||
if ka.addBucketRef(bucketIdx) == 1 { | |||
a.nOld++ | |||
} | |||
// Ensure in addrLookup | |||
a.addrLookup[addrStr] = ka | |||
return true | |||
} | |||
func (a *AddrBook) removeFromBucket(ka *knownAddress, bucketType byte, bucketIdx int) { | |||
if ka.BucketType != bucketType { | |||
log.Warn(Fmt("Bucket type mismatch: %v", ka)) | |||
return | |||
} | |||
bucket := a.getBucket(bucketType, bucketIdx) | |||
delete(bucket, ka.Addr.String()) | |||
if ka.removeBucketRef(bucketIdx) == 0 { | |||
if bucketType == bucketTypeNew { | |||
a.nNew-- | |||
} else { | |||
a.nOld-- | |||
} | |||
delete(a.addrLookup, ka.Addr.String()) | |||
} | |||
} | |||
func (a *AddrBook) removeFromAllBuckets(ka *knownAddress) { | |||
for _, bucketIdx := range ka.Buckets { | |||
bucket := a.getBucket(ka.BucketType, bucketIdx) | |||
delete(bucket, ka.Addr.String()) | |||
} | |||
ka.Buckets = nil | |||
if ka.BucketType == bucketTypeNew { | |||
a.nNew-- | |||
} else { | |||
a.nOld-- | |||
} | |||
delete(a.addrLookup, ka.Addr.String()) | |||
} | |||
func (a *AddrBook) pickOldest(bucketType byte, bucketIdx int) *knownAddress { | |||
bucket := a.getBucket(bucketType, bucketIdx) | |||
var oldest *knownAddress | |||
for _, ka := range bucket { | |||
if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt) { | |||
oldest = ka | |||
} | |||
} | |||
return oldest | |||
} | |||
func (a *AddrBook) addAddress(addr, src *NetAddress) { | |||
if a.routabilityStrict && !addr.Routable() { | |||
log.Warn(Fmt("Cannot add non-routable address %v", addr)) | |||
return | |||
} | |||
if _, ok := a.ourAddrs[addr.String()]; ok { | |||
// Ignore our own listener address. | |||
return | |||
} | |||
ka := a.addrLookup[addr.String()] | |||
if ka != nil { | |||
// Already old. | |||
if ka.isOld() { | |||
return | |||
} | |||
// Already in max new buckets. | |||
if len(ka.Buckets) == maxNewBucketsPerAddress { | |||
return | |||
} | |||
// The more entries we have, the less likely we are to add more. | |||
factor := int32(2 * len(ka.Buckets)) | |||
if a.rand.Int31n(factor) != 0 { | |||
return | |||
} | |||
} else { | |||
ka = newKnownAddress(addr, src) | |||
} | |||
bucket := a.calcNewBucket(addr, src) | |||
a.addToNewBucket(ka, bucket) | |||
log.Notice("Added new address", "address", addr, "total", a.size()) | |||
} | |||
// Make space in the new buckets by expiring the really bad entries. | |||
// If no bad entries are available we remove the oldest. | |||
func (a *AddrBook) expireNew(bucketIdx int) { | |||
for addrStr, ka := range a.addrNew[bucketIdx] { | |||
// If an entry is bad, throw it away | |||
if ka.isBad() { | |||
log.Notice(Fmt("expiring bad address %v", addrStr)) | |||
a.removeFromBucket(ka, bucketTypeNew, bucketIdx) | |||
return | |||
} | |||
} | |||
// If we haven't thrown out a bad entry, throw out the oldest entry | |||
oldest := a.pickOldest(bucketTypeNew, bucketIdx) | |||
a.removeFromBucket(oldest, bucketTypeNew, bucketIdx) | |||
} | |||
// Promotes an address from new to old. | |||
// TODO: Move to old probabilistically. | |||
// The better a node is, the less likely it should be evicted from an old bucket. | |||
func (a *AddrBook) moveToOld(ka *knownAddress) { | |||
// Sanity check | |||
if ka.isOld() { | |||
log.Warn(Fmt("Cannot promote address that is already old %v", ka)) | |||
return | |||
} | |||
if len(ka.Buckets) == 0 { | |||
log.Warn(Fmt("Cannot promote address that isn't in any new buckets %v", ka)) | |||
return | |||
} | |||
// Remember one of the buckets in which ka is in. | |||
freedBucket := ka.Buckets[0] | |||
// Remove from all (new) buckets. | |||
a.removeFromAllBuckets(ka) | |||
// It's officially old now. | |||
ka.BucketType = bucketTypeOld | |||
// Try to add it to its oldBucket destination. | |||
oldBucketIdx := a.calcOldBucket(ka.Addr) | |||
added := a.addToOldBucket(ka, oldBucketIdx) | |||
if !added { | |||
// No room, must evict something | |||
oldest := a.pickOldest(bucketTypeOld, oldBucketIdx) | |||
a.removeFromBucket(oldest, bucketTypeOld, oldBucketIdx) | |||
// Find new bucket to put oldest in | |||
newBucketIdx := a.calcNewBucket(oldest.Addr, oldest.Src) | |||
added := a.addToNewBucket(oldest, newBucketIdx) | |||
// No space in newBucket either, just put it in freedBucket from above. | |||
if !added { | |||
added := a.addToNewBucket(oldest, freedBucket) | |||
if !added { | |||
log.Warn(Fmt("Could not migrate oldest %v to freedBucket %v", oldest, freedBucket)) | |||
} | |||
} | |||
// Finally, add to bucket again. | |||
added = a.addToOldBucket(ka, oldBucketIdx) | |||
if !added { | |||
log.Warn(Fmt("Could not re-add ka %v to oldBucketIdx %v", ka, oldBucketIdx)) | |||
} | |||
} | |||
} | |||
// doublesha256( key + sourcegroup + | |||
// int64(doublesha256(key + group + sourcegroup))%bucket_per_group ) % num_new_buckets | |||
func (a *AddrBook) calcNewBucket(addr, src *NetAddress) int { | |||
data1 := []byte{} | |||
data1 = append(data1, []byte(a.key)...) | |||
data1 = append(data1, []byte(a.groupKey(addr))...) | |||
data1 = append(data1, []byte(a.groupKey(src))...) | |||
hash1 := doubleSha256(data1) | |||
hash64 := binary.BigEndian.Uint64(hash1) | |||
hash64 %= newBucketsPerGroup | |||
var hashbuf [8]byte | |||
binary.BigEndian.PutUint64(hashbuf[:], hash64) | |||
data2 := []byte{} | |||
data2 = append(data2, []byte(a.key)...) | |||
data2 = append(data2, a.groupKey(src)...) | |||
data2 = append(data2, hashbuf[:]...) | |||
hash2 := doubleSha256(data2) | |||
return int(binary.BigEndian.Uint64(hash2) % newBucketCount) | |||
} | |||
// doublesha256( key + group + | |||
// int64(doublesha256(key + addr))%buckets_per_group ) % num_old_buckets | |||
func (a *AddrBook) calcOldBucket(addr *NetAddress) int { | |||
data1 := []byte{} | |||
data1 = append(data1, []byte(a.key)...) | |||
data1 = append(data1, []byte(addr.String())...) | |||
hash1 := doubleSha256(data1) | |||
hash64 := binary.BigEndian.Uint64(hash1) | |||
hash64 %= oldBucketsPerGroup | |||
var hashbuf [8]byte | |||
binary.BigEndian.PutUint64(hashbuf[:], hash64) | |||
data2 := []byte{} | |||
data2 = append(data2, []byte(a.key)...) | |||
data2 = append(data2, a.groupKey(addr)...) | |||
data2 = append(data2, hashbuf[:]...) | |||
hash2 := doubleSha256(data2) | |||
return int(binary.BigEndian.Uint64(hash2) % oldBucketCount) | |||
} | |||
// Return a string representing the network group of this address. | |||
// This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string | |||
// "local" for a local address and the string "unroutable for an unroutable | |||
// address. | |||
func (a *AddrBook) groupKey(na *NetAddress) string { | |||
if a.routabilityStrict && na.Local() { | |||
return "local" | |||
} | |||
if a.routabilityStrict && !na.Routable() { | |||
return "unroutable" | |||
} | |||
if ipv4 := na.IP.To4(); ipv4 != nil { | |||
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String() | |||
} | |||
if na.RFC6145() || na.RFC6052() { | |||
// last four bytes are the ip address | |||
ip := net.IP(na.IP[12:16]) | |||
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() | |||
} | |||
if na.RFC3964() { | |||
ip := net.IP(na.IP[2:7]) | |||
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() | |||
} | |||
if na.RFC4380() { | |||
// teredo tunnels have the last 4 bytes as the v4 address XOR | |||
// 0xff. | |||
ip := net.IP(make([]byte, 4)) | |||
for i, byte := range na.IP[12:16] { | |||
ip[i] = byte ^ 0xff | |||
} | |||
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() | |||
} | |||
// OK, so now we know ourselves to be a IPv6 address. | |||
// bitcoind uses /32 for everything, except for Hurricane Electric's | |||
// (he.net) IP range, which it uses /36 for. | |||
bits := 32 | |||
heNet := &net.IPNet{IP: net.ParseIP("2001:470::"), | |||
Mask: net.CIDRMask(32, 128)} | |||
if heNet.Contains(na.IP) { | |||
bits = 36 | |||
} | |||
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String() | |||
} | |||
//----------------------------------------------------------------------------- | |||
/* | |||
knownAddress | |||
tracks information about a known network address that is used | |||
to determine how viable an address is. | |||
*/ | |||
type knownAddress struct { | |||
Addr *NetAddress | |||
Src *NetAddress | |||
Attempts int32 | |||
LastAttempt time.Time | |||
LastSuccess time.Time | |||
BucketType byte | |||
Buckets []int | |||
} | |||
func newKnownAddress(addr *NetAddress, src *NetAddress) *knownAddress { | |||
return &knownAddress{ | |||
Addr: addr, | |||
Src: src, | |||
Attempts: 0, | |||
LastAttempt: time.Now(), | |||
BucketType: bucketTypeNew, | |||
Buckets: nil, | |||
} | |||
} | |||
func (ka *knownAddress) isOld() bool { | |||
return ka.BucketType == bucketTypeOld | |||
} | |||
func (ka *knownAddress) isNew() bool { | |||
return ka.BucketType == bucketTypeNew | |||
} | |||
func (ka *knownAddress) markAttempt() { | |||
now := time.Now() | |||
ka.LastAttempt = now | |||
ka.Attempts += 1 | |||
} | |||
func (ka *knownAddress) markGood() { | |||
now := time.Now() | |||
ka.LastAttempt = now | |||
ka.Attempts = 0 | |||
ka.LastSuccess = now | |||
} | |||
func (ka *knownAddress) addBucketRef(bucketIdx int) int { | |||
for _, bucket := range ka.Buckets { | |||
if bucket == bucketIdx { | |||
log.Warn(Fmt("Bucket already exists in ka.Buckets: %v", ka)) | |||
return -1 | |||
} | |||
} | |||
ka.Buckets = append(ka.Buckets, bucketIdx) | |||
return len(ka.Buckets) | |||
} | |||
func (ka *knownAddress) removeBucketRef(bucketIdx int) int { | |||
buckets := []int{} | |||
for _, bucket := range ka.Buckets { | |||
if bucket != bucketIdx { | |||
buckets = append(buckets, bucket) | |||
} | |||
} | |||
if len(buckets) != len(ka.Buckets)-1 { | |||
log.Warn(Fmt("bucketIdx not found in ka.Buckets: %v", ka)) | |||
return -1 | |||
} | |||
ka.Buckets = buckets | |||
return len(ka.Buckets) | |||
} | |||
/* | |||
An address is bad if the address in question has not been tried in the last | |||
minute and meets one of the following criteria: | |||
1) It claims to be from the future | |||
2) It hasn't been seen in over a month | |||
3) It has failed at least three times and never succeeded | |||
4) It has failed ten times in the last week | |||
All addresses that meet these criteria are assumed to be worthless and not | |||
worth keeping hold of. | |||
*/ | |||
func (ka *knownAddress) isBad() bool { | |||
// Has been attempted in the last minute --> good | |||
if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) { | |||
return false | |||
} | |||
// Over a month old? | |||
if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) { | |||
return true | |||
} | |||
// Never succeeded? | |||
if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries { | |||
return true | |||
} | |||
// Hasn't succeeded in too long? | |||
if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) && | |||
ka.Attempts >= maxFailures { | |||
return true | |||
} | |||
return false | |||
} |
@ -0,0 +1,166 @@ | |||
package p2p | |||
import ( | |||
"fmt" | |||
"io/ioutil" | |||
"math/rand" | |||
"testing" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func createTempFileName(prefix string) string { | |||
f, err := ioutil.TempFile("", prefix) | |||
if err != nil { | |||
panic(err) | |||
} | |||
fname := f.Name() | |||
err = f.Close() | |||
if err != nil { | |||
panic(err) | |||
} | |||
return fname | |||
} | |||
func TestAddrBookSaveLoad(t *testing.T) { | |||
fname := createTempFileName("addrbook_test") | |||
// 0 addresses | |||
book := NewAddrBook(fname, true) | |||
book.saveToFile(fname) | |||
book = NewAddrBook(fname, true) | |||
book.loadFromFile(fname) | |||
assert.Zero(t, book.Size()) | |||
// 100 addresses | |||
randAddrs := randNetAddressPairs(t, 100) | |||
for _, addrSrc := range randAddrs { | |||
book.AddAddress(addrSrc.addr, addrSrc.src) | |||
} | |||
assert.Equal(t, 100, book.Size()) | |||
book.saveToFile(fname) | |||
book = NewAddrBook(fname, true) | |||
book.loadFromFile(fname) | |||
assert.Equal(t, 100, book.Size()) | |||
} | |||
func TestAddrBookLookup(t *testing.T) { | |||
fname := createTempFileName("addrbook_test") | |||
randAddrs := randNetAddressPairs(t, 100) | |||
book := NewAddrBook(fname, true) | |||
for _, addrSrc := range randAddrs { | |||
addr := addrSrc.addr | |||
src := addrSrc.src | |||
book.AddAddress(addr, src) | |||
ka := book.addrLookup[addr.String()] | |||
assert.NotNil(t, ka, "Expected to find KnownAddress %v but wasn't there.", addr) | |||
if !(ka.Addr.Equals(addr) && ka.Src.Equals(src)) { | |||
t.Fatalf("KnownAddress doesn't match addr & src") | |||
} | |||
} | |||
} | |||
func TestAddrBookPromoteToOld(t *testing.T) { | |||
fname := createTempFileName("addrbook_test") | |||
randAddrs := randNetAddressPairs(t, 100) | |||
book := NewAddrBook(fname, true) | |||
for _, addrSrc := range randAddrs { | |||
book.AddAddress(addrSrc.addr, addrSrc.src) | |||
} | |||
// Attempt all addresses. | |||
for _, addrSrc := range randAddrs { | |||
book.MarkAttempt(addrSrc.addr) | |||
} | |||
// Promote half of them | |||
for i, addrSrc := range randAddrs { | |||
if i%2 == 0 { | |||
book.MarkGood(addrSrc.addr) | |||
} | |||
} | |||
// TODO: do more testing :) | |||
selection := book.GetSelection() | |||
t.Logf("selection: %v", selection) | |||
if len(selection) > book.Size() { | |||
t.Errorf("selection could not be bigger than the book") | |||
} | |||
} | |||
func TestAddrBookHandlesDuplicates(t *testing.T) { | |||
fname := createTempFileName("addrbook_test") | |||
book := NewAddrBook(fname, true) | |||
randAddrs := randNetAddressPairs(t, 100) | |||
differentSrc := randIPv4Address(t) | |||
for _, addrSrc := range randAddrs { | |||
book.AddAddress(addrSrc.addr, addrSrc.src) | |||
book.AddAddress(addrSrc.addr, addrSrc.src) // duplicate | |||
book.AddAddress(addrSrc.addr, differentSrc) // different src | |||
} | |||
assert.Equal(t, 100, book.Size()) | |||
} | |||
type netAddressPair struct { | |||
addr *NetAddress | |||
src *NetAddress | |||
} | |||
func randNetAddressPairs(t *testing.T, n int) []netAddressPair { | |||
randAddrs := make([]netAddressPair, n) | |||
for i := 0; i < n; i++ { | |||
randAddrs[i] = netAddressPair{addr: randIPv4Address(t), src: randIPv4Address(t)} | |||
} | |||
return randAddrs | |||
} | |||
func randIPv4Address(t *testing.T) *NetAddress { | |||
for { | |||
ip := fmt.Sprintf("%v.%v.%v.%v", | |||
rand.Intn(254)+1, | |||
rand.Intn(255), | |||
rand.Intn(255), | |||
rand.Intn(255), | |||
) | |||
port := rand.Intn(65535-1) + 1 | |||
addr, err := NewNetAddressString(fmt.Sprintf("%v:%v", ip, port)) | |||
assert.Nil(t, err, "error generating rand network address") | |||
if addr.Routable() { | |||
return addr | |||
} | |||
} | |||
} | |||
func TestAddrBookRemoveAddress(t *testing.T) { | |||
fname := createTempFileName("addrbook_test") | |||
book := NewAddrBook(fname, true) | |||
addr := randIPv4Address(t) | |||
book.AddAddress(addr, addr) | |||
assert.Equal(t, 1, book.Size()) | |||
book.RemoveAddress(addr) | |||
assert.Equal(t, 0, book.Size()) | |||
nonExistingAddr := randIPv4Address(t) | |||
book.RemoveAddress(nonExistingAddr) | |||
assert.Equal(t, 0, book.Size()) | |||
} |
@ -0,0 +1,45 @@ | |||
package p2p | |||
import ( | |||
cfg "github.com/tendermint/go-config" | |||
) | |||
const ( | |||
// Switch config keys | |||
configKeyDialTimeoutSeconds = "dial_timeout_seconds" | |||
configKeyHandshakeTimeoutSeconds = "handshake_timeout_seconds" | |||
configKeyMaxNumPeers = "max_num_peers" | |||
configKeyAuthEnc = "authenticated_encryption" | |||
// MConnection config keys | |||
configKeySendRate = "send_rate" | |||
configKeyRecvRate = "recv_rate" | |||
// Fuzz params | |||
configFuzzEnable = "fuzz_enable" // use the fuzz wrapped conn | |||
configFuzzMode = "fuzz_mode" // eg. drop, delay | |||
configFuzzMaxDelayMilliseconds = "fuzz_max_delay_milliseconds" | |||
configFuzzProbDropRW = "fuzz_prob_drop_rw" | |||
configFuzzProbDropConn = "fuzz_prob_drop_conn" | |||
configFuzzProbSleep = "fuzz_prob_sleep" | |||
) | |||
func setConfigDefaults(config cfg.Config) { | |||
// Switch default config | |||
config.SetDefault(configKeyDialTimeoutSeconds, 3) | |||
config.SetDefault(configKeyHandshakeTimeoutSeconds, 20) | |||
config.SetDefault(configKeyMaxNumPeers, 50) | |||
config.SetDefault(configKeyAuthEnc, true) | |||
// MConnection default config | |||
config.SetDefault(configKeySendRate, 512000) // 500KB/s | |||
config.SetDefault(configKeyRecvRate, 512000) // 500KB/s | |||
// Fuzz defaults | |||
config.SetDefault(configFuzzEnable, false) | |||
config.SetDefault(configFuzzMode, FuzzModeDrop) | |||
config.SetDefault(configFuzzMaxDelayMilliseconds, 3000) | |||
config.SetDefault(configFuzzProbDropRW, 0.2) | |||
config.SetDefault(configFuzzProbDropConn, 0.00) | |||
config.SetDefault(configFuzzProbSleep, 0.00) | |||
} |
@ -0,0 +1,686 @@ | |||
package p2p | |||
import ( | |||
"bufio" | |||
"fmt" | |||
"io" | |||
"math" | |||
"net" | |||
"runtime/debug" | |||
"sync/atomic" | |||
"time" | |||
wire "github.com/tendermint/go-wire" | |||
cmn "github.com/tendermint/tmlibs/common" | |||
flow "github.com/tendermint/tmlibs/flowrate" | |||
) | |||
const ( | |||
numBatchMsgPackets = 10 | |||
minReadBufferSize = 1024 | |||
minWriteBufferSize = 65536 | |||
updateState = 2 * time.Second | |||
pingTimeout = 40 * time.Second | |||
flushThrottle = 100 * time.Millisecond | |||
defaultSendQueueCapacity = 1 | |||
defaultSendRate = int64(512000) // 500KB/s | |||
defaultRecvBufferCapacity = 4096 | |||
defaultRecvMessageCapacity = 22020096 // 21MB | |||
defaultRecvRate = int64(512000) // 500KB/s | |||
defaultSendTimeout = 10 * time.Second | |||
) | |||
type receiveCbFunc func(chID byte, msgBytes []byte) | |||
type errorCbFunc func(interface{}) | |||
/* | |||
Each peer has one `MConnection` (multiplex connection) instance. | |||
__multiplex__ *noun* a system or signal involving simultaneous transmission of | |||
several messages along a single channel of communication. | |||
Each `MConnection` handles message transmission on multiple abstract communication | |||
`Channel`s. Each channel has a globally unique byte id. | |||
The byte id and the relative priorities of each `Channel` are configured upon | |||
initialization of the connection. | |||
There are two methods for sending messages: | |||
func (m MConnection) Send(chID byte, msg interface{}) bool {} | |||
func (m MConnection) TrySend(chID byte, msg interface{}) bool {} | |||
`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued | |||
for the channel with the given id byte `chID`, or until the request times out. | |||
The message `msg` is serialized using the `tendermint/wire` submodule's | |||
`WriteBinary()` reflection routine. | |||
`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's | |||
queue is full. | |||
Inbound message bytes are handled with an onReceive callback function. | |||
*/ | |||
type MConnection struct { | |||
cmn.BaseService | |||
conn net.Conn | |||
bufReader *bufio.Reader | |||
bufWriter *bufio.Writer | |||
sendMonitor *flow.Monitor | |||
recvMonitor *flow.Monitor | |||
send chan struct{} | |||
pong chan struct{} | |||
channels []*Channel | |||
channelsIdx map[byte]*Channel | |||
onReceive receiveCbFunc | |||
onError errorCbFunc | |||
errored uint32 | |||
config *MConnConfig | |||
quit chan struct{} | |||
flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. | |||
pingTimer *cmn.RepeatTimer // send pings periodically | |||
chStatsTimer *cmn.RepeatTimer // update channel stats periodically | |||
LocalAddress *NetAddress | |||
RemoteAddress *NetAddress | |||
} | |||
// MConnConfig is a MConnection configuration. | |||
type MConnConfig struct { | |||
SendRate int64 | |||
RecvRate int64 | |||
} | |||
// DefaultMConnConfig returns the default config. | |||
func DefaultMConnConfig() *MConnConfig { | |||
return &MConnConfig{ | |||
SendRate: defaultSendRate, | |||
RecvRate: defaultRecvRate, | |||
} | |||
} | |||
// NewMConnection wraps net.Conn and creates multiplex connection | |||
func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection { | |||
return NewMConnectionWithConfig( | |||
conn, | |||
chDescs, | |||
onReceive, | |||
onError, | |||
DefaultMConnConfig()) | |||
} | |||
// NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config | |||
func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection { | |||
mconn := &MConnection{ | |||
conn: conn, | |||
bufReader: bufio.NewReaderSize(conn, minReadBufferSize), | |||
bufWriter: bufio.NewWriterSize(conn, minWriteBufferSize), | |||
sendMonitor: flow.New(0, 0), | |||
recvMonitor: flow.New(0, 0), | |||
send: make(chan struct{}, 1), | |||
pong: make(chan struct{}), | |||
onReceive: onReceive, | |||
onError: onError, | |||
config: config, | |||
LocalAddress: NewNetAddress(conn.LocalAddr()), | |||
RemoteAddress: NewNetAddress(conn.RemoteAddr()), | |||
} | |||
// Create channels | |||
var channelsIdx = map[byte]*Channel{} | |||
var channels = []*Channel{} | |||
for _, desc := range chDescs { | |||
descCopy := *desc // copy the desc else unsafe access across connections | |||
channel := newChannel(mconn, &descCopy) | |||
channelsIdx[channel.id] = channel | |||
channels = append(channels, channel) | |||
} | |||
mconn.channels = channels | |||
mconn.channelsIdx = channelsIdx | |||
mconn.BaseService = *cmn.NewBaseService(log, "MConnection", mconn) | |||
return mconn | |||
} | |||
func (c *MConnection) OnStart() error { | |||
c.BaseService.OnStart() | |||
c.quit = make(chan struct{}) | |||
c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle) | |||
c.pingTimer = cmn.NewRepeatTimer("ping", pingTimeout) | |||
c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateState) | |||
go c.sendRoutine() | |||
go c.recvRoutine() | |||
return nil | |||
} | |||
func (c *MConnection) OnStop() { | |||
c.BaseService.OnStop() | |||
c.flushTimer.Stop() | |||
c.pingTimer.Stop() | |||
c.chStatsTimer.Stop() | |||
if c.quit != nil { | |||
close(c.quit) | |||
} | |||
c.conn.Close() | |||
// We can't close pong safely here because | |||
// recvRoutine may write to it after we've stopped. | |||
// Though it doesn't need to get closed at all, | |||
// we close it @ recvRoutine. | |||
// close(c.pong) | |||
} | |||
func (c *MConnection) String() string { | |||
return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr()) | |||
} | |||
func (c *MConnection) flush() { | |||
log.Debug("Flush", "conn", c) | |||
err := c.bufWriter.Flush() | |||
if err != nil { | |||
log.Warn("MConnection flush failed", "error", err) | |||
} | |||
} | |||
// Catch panics, usually caused by remote disconnects. | |||
func (c *MConnection) _recover() { | |||
if r := recover(); r != nil { | |||
stack := debug.Stack() | |||
err := cmn.StackError{r, stack} | |||
c.stopForError(err) | |||
} | |||
} | |||
func (c *MConnection) stopForError(r interface{}) { | |||
c.Stop() | |||
if atomic.CompareAndSwapUint32(&c.errored, 0, 1) { | |||
if c.onError != nil { | |||
c.onError(r) | |||
} | |||
} | |||
} | |||
// Queues a message to be sent to channel. | |||
func (c *MConnection) Send(chID byte, msg interface{}) bool { | |||
if !c.IsRunning() { | |||
return false | |||
} | |||
log.Debug("Send", "channel", chID, "conn", c, "msg", msg) //, "bytes", wire.BinaryBytes(msg)) | |||
// Send message to channel. | |||
channel, ok := c.channelsIdx[chID] | |||
if !ok { | |||
log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID)) | |||
return false | |||
} | |||
success := channel.sendBytes(wire.BinaryBytes(msg)) | |||
if success { | |||
// Wake up sendRoutine if necessary | |||
select { | |||
case c.send <- struct{}{}: | |||
default: | |||
} | |||
} else { | |||
log.Warn("Send failed", "channel", chID, "conn", c, "msg", msg) | |||
} | |||
return success | |||
} | |||
// Queues a message to be sent to channel. | |||
// Nonblocking, returns true if successful. | |||
func (c *MConnection) TrySend(chID byte, msg interface{}) bool { | |||
if !c.IsRunning() { | |||
return false | |||
} | |||
log.Debug("TrySend", "channel", chID, "conn", c, "msg", msg) | |||
// Send message to channel. | |||
channel, ok := c.channelsIdx[chID] | |||
if !ok { | |||
log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID)) | |||
return false | |||
} | |||
ok = channel.trySendBytes(wire.BinaryBytes(msg)) | |||
if ok { | |||
// Wake up sendRoutine if necessary | |||
select { | |||
case c.send <- struct{}{}: | |||
default: | |||
} | |||
} | |||
return ok | |||
} | |||
// CanSend returns true if you can send more data onto the chID, false | |||
// otherwise. Use only as a heuristic. | |||
func (c *MConnection) CanSend(chID byte) bool { | |||
if !c.IsRunning() { | |||
return false | |||
} | |||
channel, ok := c.channelsIdx[chID] | |||
if !ok { | |||
log.Error(cmn.Fmt("Unknown channel %X", chID)) | |||
return false | |||
} | |||
return channel.canSend() | |||
} | |||
// sendRoutine polls for packets to send from channels. | |||
func (c *MConnection) sendRoutine() { | |||
defer c._recover() | |||
FOR_LOOP: | |||
for { | |||
var n int | |||
var err error | |||
select { | |||
case <-c.flushTimer.Ch: | |||
// NOTE: flushTimer.Set() must be called every time | |||
// something is written to .bufWriter. | |||
c.flush() | |||
case <-c.chStatsTimer.Ch: | |||
for _, channel := range c.channels { | |||
channel.updateStats() | |||
} | |||
case <-c.pingTimer.Ch: | |||
log.Debug("Send Ping") | |||
wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) | |||
c.sendMonitor.Update(int(n)) | |||
c.flush() | |||
case <-c.pong: | |||
log.Debug("Send Pong") | |||
wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) | |||
c.sendMonitor.Update(int(n)) | |||
c.flush() | |||
case <-c.quit: | |||
break FOR_LOOP | |||
case <-c.send: | |||
// Send some msgPackets | |||
eof := c.sendSomeMsgPackets() | |||
if !eof { | |||
// Keep sendRoutine awake. | |||
select { | |||
case c.send <- struct{}{}: | |||
default: | |||
} | |||
} | |||
} | |||
if !c.IsRunning() { | |||
break FOR_LOOP | |||
} | |||
if err != nil { | |||
log.Warn("Connection failed @ sendRoutine", "conn", c, "error", err) | |||
c.stopForError(err) | |||
break FOR_LOOP | |||
} | |||
} | |||
// Cleanup | |||
} | |||
// Returns true if messages from channels were exhausted. | |||
// Blocks in accordance to .sendMonitor throttling. | |||
func (c *MConnection) sendSomeMsgPackets() bool { | |||
// Block until .sendMonitor says we can write. | |||
// Once we're ready we send more than we asked for, | |||
// but amortized it should even out. | |||
c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true) | |||
// Now send some msgPackets. | |||
for i := 0; i < numBatchMsgPackets; i++ { | |||
if c.sendMsgPacket() { | |||
return true | |||
} | |||
} | |||
return false | |||
} | |||
// Returns true if messages from channels were exhausted. | |||
func (c *MConnection) sendMsgPacket() bool { | |||
// Choose a channel to create a msgPacket from. | |||
// The chosen channel will be the one whose recentlySent/priority is the least. | |||
var leastRatio float32 = math.MaxFloat32 | |||
var leastChannel *Channel | |||
for _, channel := range c.channels { | |||
// If nothing to send, skip this channel | |||
if !channel.isSendPending() { | |||
continue | |||
} | |||
// Get ratio, and keep track of lowest ratio. | |||
ratio := float32(channel.recentlySent) / float32(channel.priority) | |||
if ratio < leastRatio { | |||
leastRatio = ratio | |||
leastChannel = channel | |||
} | |||
} | |||
// Nothing to send? | |||
if leastChannel == nil { | |||
return true | |||
} else { | |||
// log.Info("Found a msgPacket to send") | |||
} | |||
// Make & send a msgPacket from this channel | |||
n, err := leastChannel.writeMsgPacketTo(c.bufWriter) | |||
if err != nil { | |||
log.Warn("Failed to write msgPacket", "error", err) | |||
c.stopForError(err) | |||
return true | |||
} | |||
c.sendMonitor.Update(int(n)) | |||
c.flushTimer.Set() | |||
return false | |||
} | |||
// recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer. | |||
// After a whole message has been assembled, it's pushed to onReceive(). | |||
// Blocks depending on how the connection is throttled. | |||
func (c *MConnection) recvRoutine() { | |||
defer c._recover() | |||
FOR_LOOP: | |||
for { | |||
// Block until .recvMonitor says we can read. | |||
c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true) | |||
/* | |||
// Peek into bufReader for debugging | |||
if numBytes := c.bufReader.Buffered(); numBytes > 0 { | |||
log.Info("Peek connection buffer", "numBytes", numBytes, "bytes", log15.Lazy{func() []byte { | |||
bytes, err := c.bufReader.Peek(MinInt(numBytes, 100)) | |||
if err == nil { | |||
return bytes | |||
} else { | |||
log.Warn("Error peeking connection buffer", "error", err) | |||
return nil | |||
} | |||
}}) | |||
} | |||
*/ | |||
// Read packet type | |||
var n int | |||
var err error | |||
pktType := wire.ReadByte(c.bufReader, &n, &err) | |||
c.recvMonitor.Update(int(n)) | |||
if err != nil { | |||
if c.IsRunning() { | |||
log.Warn("Connection failed @ recvRoutine (reading byte)", "conn", c, "error", err) | |||
c.stopForError(err) | |||
} | |||
break FOR_LOOP | |||
} | |||
// Read more depending on packet type. | |||
switch pktType { | |||
case packetTypePing: | |||
// TODO: prevent abuse, as they cause flush()'s. | |||
log.Debug("Receive Ping") | |||
c.pong <- struct{}{} | |||
case packetTypePong: | |||
// do nothing | |||
log.Debug("Receive Pong") | |||
case packetTypeMsg: | |||
pkt, n, err := msgPacket{}, int(0), error(nil) | |||
wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err) | |||
c.recvMonitor.Update(int(n)) | |||
if err != nil { | |||
if c.IsRunning() { | |||
log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err) | |||
c.stopForError(err) | |||
} | |||
break FOR_LOOP | |||
} | |||
channel, ok := c.channelsIdx[pkt.ChannelID] | |||
if !ok || channel == nil { | |||
cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID)) | |||
} | |||
msgBytes, err := channel.recvMsgPacket(pkt) | |||
if err != nil { | |||
if c.IsRunning() { | |||
log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err) | |||
c.stopForError(err) | |||
} | |||
break FOR_LOOP | |||
} | |||
if msgBytes != nil { | |||
log.Debug("Received bytes", "chID", pkt.ChannelID, "msgBytes", msgBytes) | |||
c.onReceive(pkt.ChannelID, msgBytes) | |||
} | |||
default: | |||
cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType)) | |||
} | |||
// TODO: shouldn't this go in the sendRoutine? | |||
// Better to send a ping packet when *we* haven't sent anything for a while. | |||
c.pingTimer.Reset() | |||
} | |||
// Cleanup | |||
close(c.pong) | |||
for _ = range c.pong { | |||
// Drain | |||
} | |||
} | |||
type ConnectionStatus struct { | |||
SendMonitor flow.Status | |||
RecvMonitor flow.Status | |||
Channels []ChannelStatus | |||
} | |||
type ChannelStatus struct { | |||
ID byte | |||
SendQueueCapacity int | |||
SendQueueSize int | |||
Priority int | |||
RecentlySent int64 | |||
} | |||
func (c *MConnection) Status() ConnectionStatus { | |||
var status ConnectionStatus | |||
status.SendMonitor = c.sendMonitor.Status() | |||
status.RecvMonitor = c.recvMonitor.Status() | |||
status.Channels = make([]ChannelStatus, len(c.channels)) | |||
for i, channel := range c.channels { | |||
status.Channels[i] = ChannelStatus{ | |||
ID: channel.id, | |||
SendQueueCapacity: cap(channel.sendQueue), | |||
SendQueueSize: int(channel.sendQueueSize), // TODO use atomic | |||
Priority: channel.priority, | |||
RecentlySent: channel.recentlySent, | |||
} | |||
} | |||
return status | |||
} | |||
//----------------------------------------------------------------------------- | |||
type ChannelDescriptor struct { | |||
ID byte | |||
Priority int | |||
SendQueueCapacity int | |||
RecvBufferCapacity int | |||
RecvMessageCapacity int | |||
} | |||
func (chDesc *ChannelDescriptor) FillDefaults() { | |||
if chDesc.SendQueueCapacity == 0 { | |||
chDesc.SendQueueCapacity = defaultSendQueueCapacity | |||
} | |||
if chDesc.RecvBufferCapacity == 0 { | |||
chDesc.RecvBufferCapacity = defaultRecvBufferCapacity | |||
} | |||
if chDesc.RecvMessageCapacity == 0 { | |||
chDesc.RecvMessageCapacity = defaultRecvMessageCapacity | |||
} | |||
} | |||
// TODO: lowercase. | |||
// NOTE: not goroutine-safe. | |||
type Channel struct { | |||
conn *MConnection | |||
desc *ChannelDescriptor | |||
id byte | |||
sendQueue chan []byte | |||
sendQueueSize int32 // atomic. | |||
recving []byte | |||
sending []byte | |||
priority int | |||
recentlySent int64 // exponential moving average | |||
} | |||
func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel { | |||
desc.FillDefaults() | |||
if desc.Priority <= 0 { | |||
cmn.PanicSanity("Channel default priority must be a postive integer") | |||
} | |||
return &Channel{ | |||
conn: conn, | |||
desc: desc, | |||
id: desc.ID, | |||
sendQueue: make(chan []byte, desc.SendQueueCapacity), | |||
recving: make([]byte, 0, desc.RecvBufferCapacity), | |||
priority: desc.Priority, | |||
} | |||
} | |||
// Queues message to send to this channel. | |||
// Goroutine-safe | |||
// Times out (and returns false) after defaultSendTimeout | |||
func (ch *Channel) sendBytes(bytes []byte) bool { | |||
select { | |||
case ch.sendQueue <- bytes: | |||
atomic.AddInt32(&ch.sendQueueSize, 1) | |||
return true | |||
case <-time.After(defaultSendTimeout): | |||
return false | |||
} | |||
} | |||
// Queues message to send to this channel. | |||
// Nonblocking, returns true if successful. | |||
// Goroutine-safe | |||
func (ch *Channel) trySendBytes(bytes []byte) bool { | |||
select { | |||
case ch.sendQueue <- bytes: | |||
atomic.AddInt32(&ch.sendQueueSize, 1) | |||
return true | |||
default: | |||
return false | |||
} | |||
} | |||
// Goroutine-safe | |||
func (ch *Channel) loadSendQueueSize() (size int) { | |||
return int(atomic.LoadInt32(&ch.sendQueueSize)) | |||
} | |||
// Goroutine-safe | |||
// Use only as a heuristic. | |||
func (ch *Channel) canSend() bool { | |||
return ch.loadSendQueueSize() < defaultSendQueueCapacity | |||
} | |||
// Returns true if any msgPackets are pending to be sent. | |||
// Call before calling nextMsgPacket() | |||
// Goroutine-safe | |||
func (ch *Channel) isSendPending() bool { | |||
if len(ch.sending) == 0 { | |||
if len(ch.sendQueue) == 0 { | |||
return false | |||
} | |||
ch.sending = <-ch.sendQueue | |||
} | |||
return true | |||
} | |||
// Creates a new msgPacket to send. | |||
// Not goroutine-safe | |||
func (ch *Channel) nextMsgPacket() msgPacket { | |||
packet := msgPacket{} | |||
packet.ChannelID = byte(ch.id) | |||
packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))] | |||
if len(ch.sending) <= maxMsgPacketPayloadSize { | |||
packet.EOF = byte(0x01) | |||
ch.sending = nil | |||
atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize | |||
} else { | |||
packet.EOF = byte(0x00) | |||
ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):] | |||
} | |||
return packet | |||
} | |||
// Writes next msgPacket to w. | |||
// Not goroutine-safe | |||
func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) { | |||
packet := ch.nextMsgPacket() | |||
log.Debug("Write Msg Packet", "conn", ch.conn, "packet", packet) | |||
wire.WriteByte(packetTypeMsg, w, &n, &err) | |||
wire.WriteBinary(packet, w, &n, &err) | |||
if err == nil { | |||
ch.recentlySent += int64(n) | |||
} | |||
return | |||
} | |||
// Handles incoming msgPackets. Returns a msg bytes if msg is complete. | |||
// Not goroutine-safe | |||
func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) { | |||
// log.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet) | |||
if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) { | |||
return nil, wire.ErrBinaryReadOverflow | |||
} | |||
ch.recving = append(ch.recving, packet.Bytes...) | |||
if packet.EOF == byte(0x01) { | |||
msgBytes := ch.recving | |||
// clear the slice without re-allocating. | |||
// http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go | |||
// suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes, | |||
// at which point the recving slice stops being used and should be garbage collected | |||
ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity) | |||
return msgBytes, nil | |||
} | |||
return nil, nil | |||
} | |||
// Call this periodically to update stats for throttling purposes. | |||
// Not goroutine-safe | |||
func (ch *Channel) updateStats() { | |||
// Exponential decay of stats. | |||
// TODO: optimize. | |||
ch.recentlySent = int64(float64(ch.recentlySent) * 0.8) | |||
} | |||
//----------------------------------------------------------------------------- | |||
const ( | |||
maxMsgPacketPayloadSize = 1024 | |||
maxMsgPacketOverheadSize = 10 // It's actually lower but good enough | |||
maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize | |||
packetTypePing = byte(0x01) | |||
packetTypePong = byte(0x02) | |||
packetTypeMsg = byte(0x03) | |||
) | |||
// Messages in channels are chopped into smaller msgPackets for multiplexing. | |||
type msgPacket struct { | |||
ChannelID byte | |||
EOF byte // 1 means message ends here. | |||
Bytes []byte | |||
} | |||
func (p msgPacket) String() string { | |||
return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) | |||
} |
@ -0,0 +1,139 @@ | |||
package p2p_test | |||
import ( | |||
"net" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
p2p "github.com/tendermint/tendermint/p2p" | |||
) | |||
func createMConnection(conn net.Conn) *p2p.MConnection { | |||
onReceive := func(chID byte, msgBytes []byte) { | |||
} | |||
onError := func(r interface{}) { | |||
} | |||
return createMConnectionWithCallbacks(conn, onReceive, onError) | |||
} | |||
func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *p2p.MConnection { | |||
chDescs := []*p2p.ChannelDescriptor{&p2p.ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} | |||
return p2p.NewMConnection(conn, chDescs, onReceive, onError) | |||
} | |||
func TestMConnectionSend(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
server, client := net.Pipe() | |||
defer server.Close() | |||
defer client.Close() | |||
mconn := createMConnection(client) | |||
_, err := mconn.Start() | |||
require.Nil(err) | |||
defer mconn.Stop() | |||
msg := "Ant-Man" | |||
assert.True(mconn.Send(0x01, msg)) | |||
// Note: subsequent Send/TrySend calls could pass because we are reading from | |||
// the send queue in a separate goroutine. | |||
server.Read(make([]byte, len(msg))) | |||
assert.True(mconn.CanSend(0x01)) | |||
msg = "Spider-Man" | |||
assert.True(mconn.TrySend(0x01, msg)) | |||
server.Read(make([]byte, len(msg))) | |||
assert.False(mconn.CanSend(0x05), "CanSend should return false because channel is unknown") | |||
assert.False(mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown") | |||
} | |||
func TestMConnectionReceive(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
server, client := net.Pipe() | |||
defer server.Close() | |||
defer client.Close() | |||
receivedCh := make(chan []byte) | |||
errorsCh := make(chan interface{}) | |||
onReceive := func(chID byte, msgBytes []byte) { | |||
receivedCh <- msgBytes | |||
} | |||
onError := func(r interface{}) { | |||
errorsCh <- r | |||
} | |||
mconn1 := createMConnectionWithCallbacks(client, onReceive, onError) | |||
_, err := mconn1.Start() | |||
require.Nil(err) | |||
defer mconn1.Stop() | |||
mconn2 := createMConnection(server) | |||
_, err = mconn2.Start() | |||
require.Nil(err) | |||
defer mconn2.Stop() | |||
msg := "Cyclops" | |||
assert.True(mconn2.Send(0x01, msg)) | |||
select { | |||
case receivedBytes := <-receivedCh: | |||
assert.Equal([]byte(msg), receivedBytes[2:]) // first 3 bytes are internal | |||
case err := <-errorsCh: | |||
t.Fatalf("Expected %s, got %+v", msg, err) | |||
case <-time.After(500 * time.Millisecond): | |||
t.Fatalf("Did not receive %s message in 500ms", msg) | |||
} | |||
} | |||
func TestMConnectionStatus(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
server, client := net.Pipe() | |||
defer server.Close() | |||
defer client.Close() | |||
mconn := createMConnection(client) | |||
_, err := mconn.Start() | |||
require.Nil(err) | |||
defer mconn.Stop() | |||
status := mconn.Status() | |||
assert.NotNil(status) | |||
assert.Zero(status.Channels[0].SendQueueSize) | |||
} | |||
func TestMConnectionStopsAndReturnsError(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
server, client := net.Pipe() | |||
defer server.Close() | |||
defer client.Close() | |||
receivedCh := make(chan []byte) | |||
errorsCh := make(chan interface{}) | |||
onReceive := func(chID byte, msgBytes []byte) { | |||
receivedCh <- msgBytes | |||
} | |||
onError := func(r interface{}) { | |||
errorsCh <- r | |||
} | |||
mconn := createMConnectionWithCallbacks(client, onReceive, onError) | |||
_, err := mconn.Start() | |||
require.Nil(err) | |||
defer mconn.Stop() | |||
client.Close() | |||
select { | |||
case receivedBytes := <-receivedCh: | |||
t.Fatalf("Expected error, got %v", receivedBytes) | |||
case err := <-errorsCh: | |||
assert.NotNil(err) | |||
assert.False(mconn.IsRunning()) | |||
case <-time.After(500 * time.Millisecond): | |||
t.Fatal("Did not receive error in 500ms") | |||
} | |||
} |
@ -0,0 +1,173 @@ | |||
package p2p | |||
import ( | |||
"math/rand" | |||
"net" | |||
"sync" | |||
"time" | |||
) | |||
const ( | |||
// FuzzModeDrop is a mode in which we randomly drop reads/writes, connections or sleep | |||
FuzzModeDrop = iota | |||
// FuzzModeDelay is a mode in which we randomly sleep | |||
FuzzModeDelay | |||
) | |||
// FuzzedConnection wraps any net.Conn and depending on the mode either delays | |||
// reads/writes or randomly drops reads/writes/connections. | |||
type FuzzedConnection struct { | |||
conn net.Conn | |||
mtx sync.Mutex | |||
start <-chan time.Time | |||
active bool | |||
config *FuzzConnConfig | |||
} | |||
// FuzzConnConfig is a FuzzedConnection configuration. | |||
type FuzzConnConfig struct { | |||
Mode int | |||
MaxDelay time.Duration | |||
ProbDropRW float64 | |||
ProbDropConn float64 | |||
ProbSleep float64 | |||
} | |||
// DefaultFuzzConnConfig returns the default config. | |||
func DefaultFuzzConnConfig() *FuzzConnConfig { | |||
return &FuzzConnConfig{ | |||
Mode: FuzzModeDrop, | |||
MaxDelay: 3 * time.Second, | |||
ProbDropRW: 0.2, | |||
ProbDropConn: 0.00, | |||
ProbSleep: 0.00, | |||
} | |||
} | |||
// FuzzConn creates a new FuzzedConnection. Fuzzing starts immediately. | |||
func FuzzConn(conn net.Conn) net.Conn { | |||
return FuzzConnFromConfig(conn, DefaultFuzzConnConfig()) | |||
} | |||
// FuzzConnFromConfig creates a new FuzzedConnection from a config. Fuzzing | |||
// starts immediately. | |||
func FuzzConnFromConfig(conn net.Conn, config *FuzzConnConfig) net.Conn { | |||
return &FuzzedConnection{ | |||
conn: conn, | |||
start: make(<-chan time.Time), | |||
active: true, | |||
config: config, | |||
} | |||
} | |||
// FuzzConnAfter creates a new FuzzedConnection. Fuzzing starts when the | |||
// duration elapses. | |||
func FuzzConnAfter(conn net.Conn, d time.Duration) net.Conn { | |||
return FuzzConnAfterFromConfig(conn, d, DefaultFuzzConnConfig()) | |||
} | |||
// FuzzConnAfterFromConfig creates a new FuzzedConnection from a config. | |||
// Fuzzing starts when the duration elapses. | |||
func FuzzConnAfterFromConfig(conn net.Conn, d time.Duration, config *FuzzConnConfig) net.Conn { | |||
return &FuzzedConnection{ | |||
conn: conn, | |||
start: time.After(d), | |||
active: false, | |||
config: config, | |||
} | |||
} | |||
// Config returns the connection's config. | |||
func (fc *FuzzedConnection) Config() *FuzzConnConfig { | |||
return fc.config | |||
} | |||
// Read implements net.Conn. | |||
func (fc *FuzzedConnection) Read(data []byte) (n int, err error) { | |||
if fc.fuzz() { | |||
return 0, nil | |||
} | |||
return fc.conn.Read(data) | |||
} | |||
// Write implements net.Conn. | |||
func (fc *FuzzedConnection) Write(data []byte) (n int, err error) { | |||
if fc.fuzz() { | |||
return 0, nil | |||
} | |||
return fc.conn.Write(data) | |||
} | |||
// Close implements net.Conn. | |||
func (fc *FuzzedConnection) Close() error { return fc.conn.Close() } | |||
// LocalAddr implements net.Conn. | |||
func (fc *FuzzedConnection) LocalAddr() net.Addr { return fc.conn.LocalAddr() } | |||
// RemoteAddr implements net.Conn. | |||
func (fc *FuzzedConnection) RemoteAddr() net.Addr { return fc.conn.RemoteAddr() } | |||
// SetDeadline implements net.Conn. | |||
func (fc *FuzzedConnection) SetDeadline(t time.Time) error { return fc.conn.SetDeadline(t) } | |||
// SetReadDeadline implements net.Conn. | |||
func (fc *FuzzedConnection) SetReadDeadline(t time.Time) error { | |||
return fc.conn.SetReadDeadline(t) | |||
} | |||
// SetWriteDeadline implements net.Conn. | |||
func (fc *FuzzedConnection) SetWriteDeadline(t time.Time) error { | |||
return fc.conn.SetWriteDeadline(t) | |||
} | |||
func (fc *FuzzedConnection) randomDuration() time.Duration { | |||
maxDelayMillis := int(fc.config.MaxDelay.Nanoseconds() / 1000) | |||
return time.Millisecond * time.Duration(rand.Int()%maxDelayMillis) | |||
} | |||
// implements the fuzz (delay, kill conn) | |||
// and returns whether or not the read/write should be ignored | |||
func (fc *FuzzedConnection) fuzz() bool { | |||
if !fc.shouldFuzz() { | |||
return false | |||
} | |||
switch fc.config.Mode { | |||
case FuzzModeDrop: | |||
// randomly drop the r/w, drop the conn, or sleep | |||
r := rand.Float64() | |||
if r <= fc.config.ProbDropRW { | |||
return true | |||
} else if r < fc.config.ProbDropRW+fc.config.ProbDropConn { | |||
// XXX: can't this fail because machine precision? | |||
// XXX: do we need an error? | |||
fc.Close() | |||
return true | |||
} else if r < fc.config.ProbDropRW+fc.config.ProbDropConn+fc.config.ProbSleep { | |||
time.Sleep(fc.randomDuration()) | |||
} | |||
case FuzzModeDelay: | |||
// sleep a bit | |||
time.Sleep(fc.randomDuration()) | |||
} | |||
return false | |||
} | |||
func (fc *FuzzedConnection) shouldFuzz() bool { | |||
if fc.active { | |||
return true | |||
} | |||
fc.mtx.Lock() | |||
defer fc.mtx.Unlock() | |||
select { | |||
case <-fc.start: | |||
fc.active = true | |||
return true | |||
default: | |||
return false | |||
} | |||
} |
@ -0,0 +1,29 @@ | |||
package p2p | |||
import ( | |||
"strings" | |||
) | |||
// TODO Test | |||
func AddToIPRangeCounts(counts map[string]int, ip string) map[string]int { | |||
changes := make(map[string]int) | |||
ipParts := strings.Split(ip, ":") | |||
for i := 1; i < len(ipParts); i++ { | |||
prefix := strings.Join(ipParts[:i], ":") | |||
counts[prefix] += 1 | |||
changes[prefix] = counts[prefix] | |||
} | |||
return changes | |||
} | |||
// TODO Test | |||
func CheckIPRangeCounts(counts map[string]int, limits []int) bool { | |||
for prefix, count := range counts { | |||
ipParts := strings.Split(prefix, ":") | |||
numParts := len(ipParts) | |||
if limits[numParts] < count { | |||
return false | |||
} | |||
} | |||
return true | |||
} |
@ -0,0 +1,217 @@ | |||
package p2p | |||
import ( | |||
"fmt" | |||
"net" | |||
"strconv" | |||
"time" | |||
. "github.com/tendermint/tmlibs/common" | |||
"github.com/tendermint/tendermint/p2p/upnp" | |||
) | |||
type Listener interface { | |||
Connections() <-chan net.Conn | |||
InternalAddress() *NetAddress | |||
ExternalAddress() *NetAddress | |||
String() string | |||
Stop() bool | |||
} | |||
// Implements Listener | |||
type DefaultListener struct { | |||
BaseService | |||
listener net.Listener | |||
intAddr *NetAddress | |||
extAddr *NetAddress | |||
connections chan net.Conn | |||
} | |||
const ( | |||
numBufferedConnections = 10 | |||
defaultExternalPort = 8770 | |||
tryListenSeconds = 5 | |||
) | |||
func splitHostPort(addr string) (host string, port int) { | |||
host, portStr, err := net.SplitHostPort(addr) | |||
if err != nil { | |||
PanicSanity(err) | |||
} | |||
port, err = strconv.Atoi(portStr) | |||
if err != nil { | |||
PanicSanity(err) | |||
} | |||
return host, port | |||
} | |||
// skipUPNP: If true, does not try getUPNPExternalAddress() | |||
func NewDefaultListener(protocol string, lAddr string, skipUPNP bool) Listener { | |||
// Local listen IP & port | |||
lAddrIP, lAddrPort := splitHostPort(lAddr) | |||
// Create listener | |||
var listener net.Listener | |||
var err error | |||
for i := 0; i < tryListenSeconds; i++ { | |||
listener, err = net.Listen(protocol, lAddr) | |||
if err == nil { | |||
break | |||
} else if i < tryListenSeconds-1 { | |||
time.Sleep(time.Second * 1) | |||
} | |||
} | |||
if err != nil { | |||
PanicCrisis(err) | |||
} | |||
// Actual listener local IP & port | |||
listenerIP, listenerPort := splitHostPort(listener.Addr().String()) | |||
log.Info("Local listener", "ip", listenerIP, "port", listenerPort) | |||
// Determine internal address... | |||
var intAddr *NetAddress | |||
intAddr, err = NewNetAddressString(lAddr) | |||
if err != nil { | |||
PanicCrisis(err) | |||
} | |||
// Determine external address... | |||
var extAddr *NetAddress | |||
if !skipUPNP { | |||
// If the lAddrIP is INADDR_ANY, try UPnP | |||
if lAddrIP == "" || lAddrIP == "0.0.0.0" { | |||
extAddr = getUPNPExternalAddress(lAddrPort, listenerPort) | |||
} | |||
} | |||
// Otherwise just use the local address... | |||
if extAddr == nil { | |||
extAddr = getNaiveExternalAddress(listenerPort) | |||
} | |||
if extAddr == nil { | |||
PanicCrisis("Could not determine external address!") | |||
} | |||
dl := &DefaultListener{ | |||
listener: listener, | |||
intAddr: intAddr, | |||
extAddr: extAddr, | |||
connections: make(chan net.Conn, numBufferedConnections), | |||
} | |||
dl.BaseService = *NewBaseService(log, "DefaultListener", dl) | |||
dl.Start() // Started upon construction | |||
return dl | |||
} | |||
func (l *DefaultListener) OnStart() error { | |||
l.BaseService.OnStart() | |||
go l.listenRoutine() | |||
return nil | |||
} | |||
func (l *DefaultListener) OnStop() { | |||
l.BaseService.OnStop() | |||
l.listener.Close() | |||
} | |||
// Accept connections and pass on the channel | |||
func (l *DefaultListener) listenRoutine() { | |||
for { | |||
conn, err := l.listener.Accept() | |||
if !l.IsRunning() { | |||
break // Go to cleanup | |||
} | |||
// listener wasn't stopped, | |||
// yet we encountered an error. | |||
if err != nil { | |||
PanicCrisis(err) | |||
} | |||
l.connections <- conn | |||
} | |||
// Cleanup | |||
close(l.connections) | |||
for _ = range l.connections { | |||
// Drain | |||
} | |||
} | |||
// A channel of inbound connections. | |||
// It gets closed when the listener closes. | |||
func (l *DefaultListener) Connections() <-chan net.Conn { | |||
return l.connections | |||
} | |||
func (l *DefaultListener) InternalAddress() *NetAddress { | |||
return l.intAddr | |||
} | |||
func (l *DefaultListener) ExternalAddress() *NetAddress { | |||
return l.extAddr | |||
} | |||
// NOTE: The returned listener is already Accept()'ing. | |||
// So it's not suitable to pass into http.Serve(). | |||
func (l *DefaultListener) NetListener() net.Listener { | |||
return l.listener | |||
} | |||
func (l *DefaultListener) String() string { | |||
return fmt.Sprintf("Listener(@%v)", l.extAddr) | |||
} | |||
/* external address helpers */ | |||
// UPNP external address discovery & port mapping | |||
func getUPNPExternalAddress(externalPort, internalPort int) *NetAddress { | |||
log.Info("Getting UPNP external address") | |||
nat, err := upnp.Discover() | |||
if err != nil { | |||
log.Info("Could not perform UPNP discover", "error", err) | |||
return nil | |||
} | |||
ext, err := nat.GetExternalAddress() | |||
if err != nil { | |||
log.Info("Could not get UPNP external address", "error", err) | |||
return nil | |||
} | |||
// UPnP can't seem to get the external port, so let's just be explicit. | |||
if externalPort == 0 { | |||
externalPort = defaultExternalPort | |||
} | |||
externalPort, err = nat.AddPortMapping("tcp", externalPort, internalPort, "tendermint", 0) | |||
if err != nil { | |||
log.Info("Could not add UPNP port mapping", "error", err) | |||
return nil | |||
} | |||
log.Info("Got UPNP external address", "address", ext) | |||
return NewNetAddressIPPort(ext, uint16(externalPort)) | |||
} | |||
// TODO: use syscalls: http://pastebin.com/9exZG4rh | |||
func getNaiveExternalAddress(port int) *NetAddress { | |||
addrs, err := net.InterfaceAddrs() | |||
if err != nil { | |||
PanicCrisis(Fmt("Could not fetch interface addresses: %v", err)) | |||
} | |||
for _, a := range addrs { | |||
ipnet, ok := a.(*net.IPNet) | |||
if !ok { | |||
continue | |||
} | |||
v4 := ipnet.IP.To4() | |||
if v4 == nil || v4[0] == 127 { | |||
continue | |||
} // loopback | |||
return NewNetAddressIPPort(ipnet.IP, uint16(port)) | |||
} | |||
return nil | |||
} |
@ -0,0 +1,40 @@ | |||
package p2p | |||
import ( | |||
"bytes" | |||
"testing" | |||
) | |||
func TestListener(t *testing.T) { | |||
// Create a listener | |||
l := NewDefaultListener("tcp", ":8001", true) | |||
// Dial the listener | |||
lAddr := l.ExternalAddress() | |||
connOut, err := lAddr.Dial() | |||
if err != nil { | |||
t.Fatalf("Could not connect to listener address %v", lAddr) | |||
} else { | |||
t.Logf("Created a connection to listener address %v", lAddr) | |||
} | |||
connIn, ok := <-l.Connections() | |||
if !ok { | |||
t.Fatalf("Could not get inbound connection from listener") | |||
} | |||
msg := []byte("hi!") | |||
go connIn.Write(msg) | |||
b := make([]byte, 32) | |||
n, err := connOut.Read(b) | |||
if err != nil { | |||
t.Fatalf("Error reading off connection: %v", err) | |||
} | |||
b = b[:n] | |||
if !bytes.Equal(msg, b) { | |||
t.Fatalf("Got %s, expected %s", b, msg) | |||
} | |||
// Close the server, no longer needed. | |||
l.Stop() | |||
} |
@ -0,0 +1,7 @@ | |||
package p2p | |||
import ( | |||
"github.com/tendermint/tmlibs/logger" | |||
) | |||
var log = logger.New("module", "p2p") |
@ -0,0 +1,253 @@ | |||
// Modified for Tendermint | |||
// Originally Copyright (c) 2013-2014 Conformal Systems LLC. | |||
// https://github.com/conformal/btcd/blob/master/LICENSE | |||
package p2p | |||
import ( | |||
"errors" | |||
"flag" | |||
"net" | |||
"strconv" | |||
"time" | |||
cmn "github.com/tendermint/tmlibs/common" | |||
) | |||
// NetAddress defines information about a peer on the network | |||
// including its IP address, and port. | |||
type NetAddress struct { | |||
IP net.IP | |||
Port uint16 | |||
str string | |||
} | |||
// NewNetAddress returns a new NetAddress using the provided TCP | |||
// address. When testing, other net.Addr (except TCP) will result in | |||
// using 0.0.0.0:0. When normal run, other net.Addr (except TCP) will | |||
// panic. | |||
// TODO: socks proxies? | |||
func NewNetAddress(addr net.Addr) *NetAddress { | |||
tcpAddr, ok := addr.(*net.TCPAddr) | |||
if !ok { | |||
if flag.Lookup("test.v") == nil { // normal run | |||
cmn.PanicSanity(cmn.Fmt("Only TCPAddrs are supported. Got: %v", addr)) | |||
} else { // in testing | |||
return NewNetAddressIPPort(net.IP("0.0.0.0"), 0) | |||
} | |||
} | |||
ip := tcpAddr.IP | |||
port := uint16(tcpAddr.Port) | |||
return NewNetAddressIPPort(ip, port) | |||
} | |||
// NewNetAddressString returns a new NetAddress using the provided | |||
// address in the form of "IP:Port". Also resolves the host if host | |||
// is not an IP. | |||
func NewNetAddressString(addr string) (*NetAddress, error) { | |||
host, portStr, err := net.SplitHostPort(addr) | |||
if err != nil { | |||
return nil, err | |||
} | |||
ip := net.ParseIP(host) | |||
if ip == nil { | |||
if len(host) > 0 { | |||
ips, err := net.LookupIP(host) | |||
if err != nil { | |||
return nil, err | |||
} | |||
ip = ips[0] | |||
} | |||
} | |||
port, err := strconv.ParseUint(portStr, 10, 16) | |||
if err != nil { | |||
return nil, err | |||
} | |||
na := NewNetAddressIPPort(ip, uint16(port)) | |||
return na, nil | |||
} | |||
// NewNetAddressStrings returns an array of NetAddress'es build using | |||
// the provided strings. | |||
func NewNetAddressStrings(addrs []string) ([]*NetAddress, error) { | |||
netAddrs := make([]*NetAddress, len(addrs)) | |||
for i, addr := range addrs { | |||
netAddr, err := NewNetAddressString(addr) | |||
if err != nil { | |||
return nil, errors.New(cmn.Fmt("Error in address %s: %v", addr, err)) | |||
} | |||
netAddrs[i] = netAddr | |||
} | |||
return netAddrs, nil | |||
} | |||
// NewNetAddressIPPort returns a new NetAddress using the provided IP | |||
// and port number. | |||
func NewNetAddressIPPort(ip net.IP, port uint16) *NetAddress { | |||
na := &NetAddress{ | |||
IP: ip, | |||
Port: port, | |||
str: net.JoinHostPort( | |||
ip.String(), | |||
strconv.FormatUint(uint64(port), 10), | |||
), | |||
} | |||
return na | |||
} | |||
// Equals reports whether na and other are the same addresses. | |||
func (na *NetAddress) Equals(other interface{}) bool { | |||
if o, ok := other.(*NetAddress); ok { | |||
return na.String() == o.String() | |||
} | |||
return false | |||
} | |||
func (na *NetAddress) Less(other interface{}) bool { | |||
if o, ok := other.(*NetAddress); ok { | |||
return na.String() < o.String() | |||
} | |||
cmn.PanicSanity("Cannot compare unequal types") | |||
return false | |||
} | |||
// String representation. | |||
func (na *NetAddress) String() string { | |||
if na.str == "" { | |||
na.str = net.JoinHostPort( | |||
na.IP.String(), | |||
strconv.FormatUint(uint64(na.Port), 10), | |||
) | |||
} | |||
return na.str | |||
} | |||
// Dial calls net.Dial on the address. | |||
func (na *NetAddress) Dial() (net.Conn, error) { | |||
conn, err := net.Dial("tcp", na.String()) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return conn, nil | |||
} | |||
// DialTimeout calls net.DialTimeout on the address. | |||
func (na *NetAddress) DialTimeout(timeout time.Duration) (net.Conn, error) { | |||
conn, err := net.DialTimeout("tcp", na.String(), timeout) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return conn, nil | |||
} | |||
// Routable returns true if the address is routable. | |||
func (na *NetAddress) Routable() bool { | |||
// TODO(oga) bitcoind doesn't include RFC3849 here, but should we? | |||
return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() || | |||
na.RFC4193() || na.RFC4843() || na.Local()) | |||
} | |||
// For IPv4 these are either a 0 or all bits set address. For IPv6 a zero | |||
// address or one that matches the RFC3849 documentation address format. | |||
func (na *NetAddress) Valid() bool { | |||
return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() || | |||
na.IP.Equal(net.IPv4bcast)) | |||
} | |||
// Local returns true if it is a local address. | |||
func (na *NetAddress) Local() bool { | |||
return na.IP.IsLoopback() || zero4.Contains(na.IP) | |||
} | |||
// ReachabilityTo checks whenever o can be reached from na. | |||
func (na *NetAddress) ReachabilityTo(o *NetAddress) int { | |||
const ( | |||
Unreachable = 0 | |||
Default = iota | |||
Teredo | |||
Ipv6_weak | |||
Ipv4 | |||
Ipv6_strong | |||
Private | |||
) | |||
if !na.Routable() { | |||
return Unreachable | |||
} else if na.RFC4380() { | |||
if !o.Routable() { | |||
return Default | |||
} else if o.RFC4380() { | |||
return Teredo | |||
} else if o.IP.To4() != nil { | |||
return Ipv4 | |||
} else { // ipv6 | |||
return Ipv6_weak | |||
} | |||
} else if na.IP.To4() != nil { | |||
if o.Routable() && o.IP.To4() != nil { | |||
return Ipv4 | |||
} | |||
return Default | |||
} else /* ipv6 */ { | |||
var tunnelled bool | |||
// Is our v6 is tunnelled? | |||
if o.RFC3964() || o.RFC6052() || o.RFC6145() { | |||
tunnelled = true | |||
} | |||
if !o.Routable() { | |||
return Default | |||
} else if o.RFC4380() { | |||
return Teredo | |||
} else if o.IP.To4() != nil { | |||
return Ipv4 | |||
} else if tunnelled { | |||
// only prioritise ipv6 if we aren't tunnelling it. | |||
return Ipv6_weak | |||
} | |||
return Ipv6_strong | |||
} | |||
} | |||
// RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12) | |||
// RFC3849: IPv6 Documentation address (2001:0DB8::/32) | |||
// RFC3927: IPv4 Autoconfig (169.254.0.0/16) | |||
// RFC3964: IPv6 6to4 (2002::/16) | |||
// RFC4193: IPv6 unique local (FC00::/7) | |||
// RFC4380: IPv6 Teredo tunneling (2001::/32) | |||
// RFC4843: IPv6 ORCHID: (2001:10::/28) | |||
// RFC4862: IPv6 Autoconfig (FE80::/64) | |||
// RFC6052: IPv6 well known prefix (64:FF9B::/96) | |||
// RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96 | |||
var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)} | |||
var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)} | |||
var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)} | |||
var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)} | |||
var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)} | |||
var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)} | |||
var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)} | |||
var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)} | |||
var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)} | |||
var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)} | |||
var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)} | |||
var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)} | |||
var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)} | |||
func (na *NetAddress) RFC1918() bool { | |||
return rfc1918_10.Contains(na.IP) || | |||
rfc1918_192.Contains(na.IP) || | |||
rfc1918_172.Contains(na.IP) | |||
} | |||
func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) } | |||
func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) } | |||
func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) } | |||
func (na *NetAddress) RFC4193() bool { return rfc4193.Contains(na.IP) } | |||
func (na *NetAddress) RFC4380() bool { return rfc4380.Contains(na.IP) } | |||
func (na *NetAddress) RFC4843() bool { return rfc4843.Contains(na.IP) } | |||
func (na *NetAddress) RFC4862() bool { return rfc4862.Contains(na.IP) } | |||
func (na *NetAddress) RFC6052() bool { return rfc6052.Contains(na.IP) } | |||
func (na *NetAddress) RFC6145() bool { return rfc6145.Contains(na.IP) } |
@ -0,0 +1,113 @@ | |||
package p2p | |||
import ( | |||
"net" | |||
"testing" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
) | |||
func TestNewNetAddress(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") | |||
require.Nil(err) | |||
addr := NewNetAddress(tcpAddr) | |||
assert.Equal("127.0.0.1:8080", addr.String()) | |||
assert.NotPanics(func() { | |||
NewNetAddress(&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8000}) | |||
}, "Calling NewNetAddress with UDPAddr should not panic in testing") | |||
} | |||
func TestNewNetAddressString(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
tests := []struct { | |||
addr string | |||
correct bool | |||
}{ | |||
{"127.0.0.1:8080", true}, | |||
{"127.0.0:8080", false}, | |||
{"a", false}, | |||
{"127.0.0.1:a", false}, | |||
{"a:8080", false}, | |||
{"8082", false}, | |||
{"127.0.0:8080000", false}, | |||
} | |||
for _, t := range tests { | |||
addr, err := NewNetAddressString(t.addr) | |||
if t.correct { | |||
require.Nil(err) | |||
assert.Equal(t.addr, addr.String()) | |||
} else { | |||
require.NotNil(err) | |||
} | |||
} | |||
} | |||
func TestNewNetAddressStrings(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
addrs, err := NewNetAddressStrings([]string{"127.0.0.1:8080", "127.0.0.2:8080"}) | |||
require.Nil(err) | |||
assert.Equal(2, len(addrs)) | |||
} | |||
func TestNewNetAddressIPPort(t *testing.T) { | |||
assert := assert.New(t) | |||
addr := NewNetAddressIPPort(net.ParseIP("127.0.0.1"), 8080) | |||
assert.Equal("127.0.0.1:8080", addr.String()) | |||
} | |||
func TestNetAddressProperties(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
// TODO add more test cases | |||
tests := []struct { | |||
addr string | |||
valid bool | |||
local bool | |||
routable bool | |||
}{ | |||
{"127.0.0.1:8080", true, true, false}, | |||
{"ya.ru:80", true, false, true}, | |||
} | |||
for _, t := range tests { | |||
addr, err := NewNetAddressString(t.addr) | |||
require.Nil(err) | |||
assert.Equal(t.valid, addr.Valid()) | |||
assert.Equal(t.local, addr.Local()) | |||
assert.Equal(t.routable, addr.Routable()) | |||
} | |||
} | |||
func TestNetAddressReachabilityTo(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
// TODO add more test cases | |||
tests := []struct { | |||
addr string | |||
other string | |||
reachability int | |||
}{ | |||
{"127.0.0.1:8080", "127.0.0.1:8081", 0}, | |||
{"ya.ru:80", "127.0.0.1:8080", 1}, | |||
} | |||
for _, t := range tests { | |||
addr, err := NewNetAddressString(t.addr) | |||
require.Nil(err) | |||
other, err := NewNetAddressString(t.other) | |||
require.Nil(err) | |||
assert.Equal(t.reachability, addr.ReachabilityTo(other)) | |||
} | |||
} |
@ -0,0 +1,304 @@ | |||
package p2p | |||
import ( | |||
"fmt" | |||
"io" | |||
"net" | |||
"time" | |||
"github.com/pkg/errors" | |||
crypto "github.com/tendermint/go-crypto" | |||
wire "github.com/tendermint/go-wire" | |||
cmn "github.com/tendermint/tmlibs/common" | |||
) | |||
// Peer could be marked as persistent, in which case you can use | |||
// Redial function to reconnect. Note that inbound peers can't be | |||
// made persistent. They should be made persistent on the other end. | |||
// | |||
// Before using a peer, you will need to perform a handshake on connection. | |||
type Peer struct { | |||
cmn.BaseService | |||
outbound bool | |||
conn net.Conn // source connection | |||
mconn *MConnection // multiplex connection | |||
persistent bool | |||
config *PeerConfig | |||
*NodeInfo | |||
Key string | |||
Data *cmn.CMap // User data. | |||
} | |||
// PeerConfig is a Peer configuration. | |||
type PeerConfig struct { | |||
AuthEnc bool // authenticated encryption | |||
HandshakeTimeout time.Duration | |||
DialTimeout time.Duration | |||
MConfig *MConnConfig | |||
Fuzz bool // fuzz connection (for testing) | |||
FuzzConfig *FuzzConnConfig | |||
} | |||
// DefaultPeerConfig returns the default config. | |||
func DefaultPeerConfig() *PeerConfig { | |||
return &PeerConfig{ | |||
AuthEnc: true, | |||
HandshakeTimeout: 2 * time.Second, | |||
DialTimeout: 3 * time.Second, | |||
MConfig: DefaultMConnConfig(), | |||
Fuzz: false, | |||
FuzzConfig: DefaultFuzzConnConfig(), | |||
} | |||
} | |||
func newOutboundPeer(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) { | |||
return newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig()) | |||
} | |||
func newOutboundPeerWithConfig(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { | |||
conn, err := dial(addr, config) | |||
if err != nil { | |||
return nil, errors.Wrap(err, "Error creating peer") | |||
} | |||
peer, err := newPeerFromConnAndConfig(conn, true, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config) | |||
if err != nil { | |||
conn.Close() | |||
return nil, err | |||
} | |||
return peer, nil | |||
} | |||
func newInboundPeer(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) { | |||
return newInboundPeerWithConfig(conn, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig()) | |||
} | |||
func newInboundPeerWithConfig(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { | |||
return newPeerFromConnAndConfig(conn, false, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config) | |||
} | |||
func newPeerFromConnAndConfig(rawConn net.Conn, outbound bool, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { | |||
conn := rawConn | |||
// Fuzz connection | |||
if config.Fuzz { | |||
// so we have time to do peer handshakes and get set up | |||
conn = FuzzConnAfterFromConfig(conn, 10*time.Second, config.FuzzConfig) | |||
} | |||
// Encrypt connection | |||
if config.AuthEnc { | |||
conn.SetDeadline(time.Now().Add(config.HandshakeTimeout)) | |||
var err error | |||
conn, err = MakeSecretConnection(conn, ourNodePrivKey) | |||
if err != nil { | |||
return nil, errors.Wrap(err, "Error creating peer") | |||
} | |||
} | |||
// Key and NodeInfo are set after Handshake | |||
p := &Peer{ | |||
outbound: outbound, | |||
conn: conn, | |||
config: config, | |||
Data: cmn.NewCMap(), | |||
} | |||
p.mconn = createMConnection(conn, p, reactorsByCh, chDescs, onPeerError, config.MConfig) | |||
p.BaseService = *cmn.NewBaseService(log, "Peer", p) | |||
return p, nil | |||
} | |||
// CloseConn should be used when the peer was created, but never started. | |||
func (p *Peer) CloseConn() { | |||
p.conn.Close() | |||
} | |||
// makePersistent marks the peer as persistent. | |||
func (p *Peer) makePersistent() { | |||
if !p.outbound { | |||
panic("inbound peers can't be made persistent") | |||
} | |||
p.persistent = true | |||
} | |||
// IsPersistent returns true if the peer is persitent, false otherwise. | |||
func (p *Peer) IsPersistent() bool { | |||
return p.persistent | |||
} | |||
// HandshakeTimeout performs a handshake between a given node and the peer. | |||
// NOTE: blocking | |||
func (p *Peer) HandshakeTimeout(ourNodeInfo *NodeInfo, timeout time.Duration) error { | |||
// Set deadline for handshake so we don't block forever on conn.ReadFull | |||
p.conn.SetDeadline(time.Now().Add(timeout)) | |||
var peerNodeInfo = new(NodeInfo) | |||
var err1 error | |||
var err2 error | |||
cmn.Parallel( | |||
func() { | |||
var n int | |||
wire.WriteBinary(ourNodeInfo, p.conn, &n, &err1) | |||
}, | |||
func() { | |||
var n int | |||
wire.ReadBinary(peerNodeInfo, p.conn, maxNodeInfoSize, &n, &err2) | |||
log.Notice("Peer handshake", "peerNodeInfo", peerNodeInfo) | |||
}) | |||
if err1 != nil { | |||
return errors.Wrap(err1, "Error during handshake/write") | |||
} | |||
if err2 != nil { | |||
return errors.Wrap(err2, "Error during handshake/read") | |||
} | |||
if p.config.AuthEnc { | |||
// Check that the professed PubKey matches the sconn's. | |||
if !peerNodeInfo.PubKey.Equals(p.PubKey().Wrap()) { | |||
return fmt.Errorf("Ignoring connection with unmatching pubkey: %v vs %v", | |||
peerNodeInfo.PubKey, p.PubKey()) | |||
} | |||
} | |||
// Remove deadline | |||
p.conn.SetDeadline(time.Time{}) | |||
peerNodeInfo.RemoteAddr = p.Addr().String() | |||
p.NodeInfo = peerNodeInfo | |||
p.Key = peerNodeInfo.PubKey.KeyString() | |||
return nil | |||
} | |||
// Addr returns peer's network address. | |||
func (p *Peer) Addr() net.Addr { | |||
return p.conn.RemoteAddr() | |||
} | |||
// PubKey returns peer's public key. | |||
func (p *Peer) PubKey() crypto.PubKeyEd25519 { | |||
if p.config.AuthEnc { | |||
return p.conn.(*SecretConnection).RemotePubKey() | |||
} | |||
if p.NodeInfo == nil { | |||
panic("Attempt to get peer's PubKey before calling Handshake") | |||
} | |||
return p.PubKey() | |||
} | |||
// OnStart implements BaseService. | |||
func (p *Peer) OnStart() error { | |||
p.BaseService.OnStart() | |||
_, err := p.mconn.Start() | |||
return err | |||
} | |||
// OnStop implements BaseService. | |||
func (p *Peer) OnStop() { | |||
p.BaseService.OnStop() | |||
p.mconn.Stop() | |||
} | |||
// Connection returns underlying MConnection. | |||
func (p *Peer) Connection() *MConnection { | |||
return p.mconn | |||
} | |||
// IsOutbound returns true if the connection is outbound, false otherwise. | |||
func (p *Peer) IsOutbound() bool { | |||
return p.outbound | |||
} | |||
// Send msg to the channel identified by chID byte. Returns false if the send | |||
// queue is full after timeout, specified by MConnection. | |||
func (p *Peer) Send(chID byte, msg interface{}) bool { | |||
if !p.IsRunning() { | |||
// see Switch#Broadcast, where we fetch the list of peers and loop over | |||
// them - while we're looping, one peer may be removed and stopped. | |||
return false | |||
} | |||
return p.mconn.Send(chID, msg) | |||
} | |||
// TrySend msg to the channel identified by chID byte. Immediately returns | |||
// false if the send queue is full. | |||
func (p *Peer) TrySend(chID byte, msg interface{}) bool { | |||
if !p.IsRunning() { | |||
return false | |||
} | |||
return p.mconn.TrySend(chID, msg) | |||
} | |||
// CanSend returns true if the send queue is not full, false otherwise. | |||
func (p *Peer) CanSend(chID byte) bool { | |||
if !p.IsRunning() { | |||
return false | |||
} | |||
return p.mconn.CanSend(chID) | |||
} | |||
// WriteTo writes the peer's public key to w. | |||
func (p *Peer) WriteTo(w io.Writer) (n int64, err error) { | |||
var n_ int | |||
wire.WriteString(p.Key, w, &n_, &err) | |||
n += int64(n_) | |||
return | |||
} | |||
// String representation. | |||
func (p *Peer) String() string { | |||
if p.outbound { | |||
return fmt.Sprintf("Peer{%v %v out}", p.mconn, p.Key[:12]) | |||
} | |||
return fmt.Sprintf("Peer{%v %v in}", p.mconn, p.Key[:12]) | |||
} | |||
// Equals reports whenever 2 peers are actually represent the same node. | |||
func (p *Peer) Equals(other *Peer) bool { | |||
return p.Key == other.Key | |||
} | |||
// Get the data for a given key. | |||
func (p *Peer) Get(key string) interface{} { | |||
return p.Data.Get(key) | |||
} | |||
func dial(addr *NetAddress, config *PeerConfig) (net.Conn, error) { | |||
log.Info("Dialing address", "address", addr) | |||
conn, err := addr.DialTimeout(config.DialTimeout) | |||
if err != nil { | |||
log.Info("Failed dialing address", "address", addr, "error", err) | |||
return nil, err | |||
} | |||
return conn, nil | |||
} | |||
func createMConnection(conn net.Conn, p *Peer, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), config *MConnConfig) *MConnection { | |||
onReceive := func(chID byte, msgBytes []byte) { | |||
reactor := reactorsByCh[chID] | |||
if reactor == nil { | |||
cmn.PanicSanity(cmn.Fmt("Unknown channel %X", chID)) | |||
} | |||
reactor.Receive(chID, p, msgBytes) | |||
} | |||
onError := func(r interface{}) { | |||
onPeerError(p, r) | |||
} | |||
return NewMConnectionWithConfig(conn, chDescs, onReceive, onError, config) | |||
} |
@ -0,0 +1,115 @@ | |||
package p2p | |||
import ( | |||
"sync" | |||
) | |||
// IPeerSet has a (immutable) subset of the methods of PeerSet. | |||
type IPeerSet interface { | |||
Has(key string) bool | |||
Get(key string) *Peer | |||
List() []*Peer | |||
Size() int | |||
} | |||
//----------------------------------------------------------------------------- | |||
// PeerSet is a special structure for keeping a table of peers. | |||
// Iteration over the peers is super fast and thread-safe. | |||
// We also track how many peers per IP range and avoid too many | |||
type PeerSet struct { | |||
mtx sync.Mutex | |||
lookup map[string]*peerSetItem | |||
list []*Peer | |||
} | |||
type peerSetItem struct { | |||
peer *Peer | |||
index int | |||
} | |||
func NewPeerSet() *PeerSet { | |||
return &PeerSet{ | |||
lookup: make(map[string]*peerSetItem), | |||
list: make([]*Peer, 0, 256), | |||
} | |||
} | |||
// Returns false if peer with key (PubKeyEd25519) is already in set | |||
// or if we have too many peers from the peer's IP range | |||
func (ps *PeerSet) Add(peer *Peer) error { | |||
ps.mtx.Lock() | |||
defer ps.mtx.Unlock() | |||
if ps.lookup[peer.Key] != nil { | |||
return ErrSwitchDuplicatePeer | |||
} | |||
index := len(ps.list) | |||
// Appending is safe even with other goroutines | |||
// iterating over the ps.list slice. | |||
ps.list = append(ps.list, peer) | |||
ps.lookup[peer.Key] = &peerSetItem{peer, index} | |||
return nil | |||
} | |||
func (ps *PeerSet) Has(peerKey string) bool { | |||
ps.mtx.Lock() | |||
defer ps.mtx.Unlock() | |||
_, ok := ps.lookup[peerKey] | |||
return ok | |||
} | |||
func (ps *PeerSet) Get(peerKey string) *Peer { | |||
ps.mtx.Lock() | |||
defer ps.mtx.Unlock() | |||
item, ok := ps.lookup[peerKey] | |||
if ok { | |||
return item.peer | |||
} else { | |||
return nil | |||
} | |||
} | |||
func (ps *PeerSet) Remove(peer *Peer) { | |||
ps.mtx.Lock() | |||
defer ps.mtx.Unlock() | |||
item := ps.lookup[peer.Key] | |||
if item == nil { | |||
return | |||
} | |||
index := item.index | |||
// Copy the list but without the last element. | |||
// (we must copy because we're mutating the list) | |||
newList := make([]*Peer, len(ps.list)-1) | |||
copy(newList, ps.list) | |||
// If it's the last peer, that's an easy special case. | |||
if index == len(ps.list)-1 { | |||
ps.list = newList | |||
delete(ps.lookup, peer.Key) | |||
return | |||
} | |||
// Move the last item from ps.list to "index" in list. | |||
lastPeer := ps.list[len(ps.list)-1] | |||
lastPeerKey := lastPeer.Key | |||
lastPeerItem := ps.lookup[lastPeerKey] | |||
newList[index] = lastPeer | |||
lastPeerItem.index = index | |||
ps.list = newList | |||
delete(ps.lookup, peer.Key) | |||
} | |||
func (ps *PeerSet) Size() int { | |||
ps.mtx.Lock() | |||
defer ps.mtx.Unlock() | |||
return len(ps.list) | |||
} | |||
// threadsafe list of peers. | |||
func (ps *PeerSet) List() []*Peer { | |||
ps.mtx.Lock() | |||
defer ps.mtx.Unlock() | |||
return ps.list | |||
} |
@ -0,0 +1,67 @@ | |||
package p2p | |||
import ( | |||
"math/rand" | |||
"testing" | |||
. "github.com/tendermint/tmlibs/common" | |||
) | |||
// Returns an empty dummy peer | |||
func randPeer() *Peer { | |||
return &Peer{ | |||
Key: RandStr(12), | |||
NodeInfo: &NodeInfo{ | |||
RemoteAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256), | |||
ListenAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256), | |||
}, | |||
} | |||
} | |||
func TestAddRemoveOne(t *testing.T) { | |||
peerSet := NewPeerSet() | |||
peer := randPeer() | |||
err := peerSet.Add(peer) | |||
if err != nil { | |||
t.Errorf("Failed to add new peer") | |||
} | |||
if peerSet.Size() != 1 { | |||
t.Errorf("Failed to add new peer and increment size") | |||
} | |||
peerSet.Remove(peer) | |||
if peerSet.Has(peer.Key) { | |||
t.Errorf("Failed to remove peer") | |||
} | |||
if peerSet.Size() != 0 { | |||
t.Errorf("Failed to remove peer and decrement size") | |||
} | |||
} | |||
func TestAddRemoveMany(t *testing.T) { | |||
peerSet := NewPeerSet() | |||
peers := []*Peer{} | |||
N := 100 | |||
for i := 0; i < N; i++ { | |||
peer := randPeer() | |||
if err := peerSet.Add(peer); err != nil { | |||
t.Errorf("Failed to add new peer") | |||
} | |||
if peerSet.Size() != i+1 { | |||
t.Errorf("Failed to add new peer and increment size") | |||
} | |||
peers = append(peers, peer) | |||
} | |||
for i, peer := range peers { | |||
peerSet.Remove(peer) | |||
if peerSet.Has(peer.Key) { | |||
t.Errorf("Failed to remove peer") | |||
} | |||
if peerSet.Size() != len(peers)-i-1 { | |||
t.Errorf("Failed to remove peer and decrement size") | |||
} | |||
} | |||
} |
@ -0,0 +1,156 @@ | |||
package p2p | |||
import ( | |||
golog "log" | |||
"net" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
crypto "github.com/tendermint/go-crypto" | |||
) | |||
func TestPeerBasic(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
// simulate remote peer | |||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} | |||
rp.Start() | |||
defer rp.Stop() | |||
p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), DefaultPeerConfig()) | |||
require.Nil(err) | |||
p.Start() | |||
defer p.Stop() | |||
assert.True(p.IsRunning()) | |||
assert.True(p.IsOutbound()) | |||
assert.False(p.IsPersistent()) | |||
p.makePersistent() | |||
assert.True(p.IsPersistent()) | |||
assert.Equal(rp.Addr().String(), p.Addr().String()) | |||
assert.Equal(rp.PubKey(), p.PubKey()) | |||
} | |||
func TestPeerWithoutAuthEnc(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
config := DefaultPeerConfig() | |||
config.AuthEnc = false | |||
// simulate remote peer | |||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config} | |||
rp.Start() | |||
defer rp.Stop() | |||
p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config) | |||
require.Nil(err) | |||
p.Start() | |||
defer p.Stop() | |||
assert.True(p.IsRunning()) | |||
} | |||
func TestPeerSend(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
config := DefaultPeerConfig() | |||
config.AuthEnc = false | |||
// simulate remote peer | |||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config} | |||
rp.Start() | |||
defer rp.Stop() | |||
p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config) | |||
require.Nil(err) | |||
p.Start() | |||
defer p.Stop() | |||
assert.True(p.CanSend(0x01)) | |||
assert.True(p.Send(0x01, "Asylum")) | |||
} | |||
func createOutboundPeerAndPerformHandshake(addr *NetAddress, config *PeerConfig) (*Peer, error) { | |||
chDescs := []*ChannelDescriptor{ | |||
&ChannelDescriptor{ID: 0x01, Priority: 1}, | |||
} | |||
reactorsByCh := map[byte]Reactor{0x01: NewTestReactor(chDescs, true)} | |||
pk := crypto.GenPrivKeyEd25519() | |||
p, err := newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, func(p *Peer, r interface{}) {}, pk, config) | |||
if err != nil { | |||
return nil, err | |||
} | |||
err = p.HandshakeTimeout(&NodeInfo{ | |||
PubKey: pk.PubKey().Unwrap().(crypto.PubKeyEd25519), | |||
Moniker: "host_peer", | |||
Network: "testing", | |||
Version: "123.123.123", | |||
}, 1*time.Second) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return p, nil | |||
} | |||
type remotePeer struct { | |||
PrivKey crypto.PrivKeyEd25519 | |||
Config *PeerConfig | |||
addr *NetAddress | |||
quit chan struct{} | |||
} | |||
func (p *remotePeer) Addr() *NetAddress { | |||
return p.addr | |||
} | |||
func (p *remotePeer) PubKey() crypto.PubKeyEd25519 { | |||
return p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||
} | |||
func (p *remotePeer) Start() { | |||
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address | |||
if e != nil { | |||
golog.Fatalf("net.Listen tcp :0: %+v", e) | |||
} | |||
p.addr = NewNetAddress(l.Addr()) | |||
p.quit = make(chan struct{}) | |||
go p.accept(l) | |||
} | |||
func (p *remotePeer) Stop() { | |||
close(p.quit) | |||
} | |||
func (p *remotePeer) accept(l net.Listener) { | |||
for { | |||
conn, err := l.Accept() | |||
if err != nil { | |||
golog.Fatalf("Failed to accept conn: %+v", err) | |||
} | |||
peer, err := newInboundPeerWithConfig(conn, make(map[byte]Reactor), make([]*ChannelDescriptor, 0), func(p *Peer, r interface{}) {}, p.PrivKey, p.Config) | |||
if err != nil { | |||
golog.Fatalf("Failed to create a peer: %+v", err) | |||
} | |||
err = peer.HandshakeTimeout(&NodeInfo{ | |||
PubKey: p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519), | |||
Moniker: "remote_peer", | |||
Network: "testing", | |||
Version: "123.123.123", | |||
}, 1*time.Second) | |||
if err != nil { | |||
golog.Fatalf("Failed to perform handshake: %+v", err) | |||
} | |||
select { | |||
case <-p.quit: | |||
conn.Close() | |||
return | |||
default: | |||
} | |||
} | |||
} |
@ -0,0 +1,358 @@ | |||
package p2p | |||
import ( | |||
"bytes" | |||
"fmt" | |||
"math/rand" | |||
"reflect" | |||
"time" | |||
cmn "github.com/tendermint/tmlibs/common" | |||
wire "github.com/tendermint/go-wire" | |||
) | |||
const ( | |||
// PexChannel is a channel for PEX messages | |||
PexChannel = byte(0x00) | |||
// period to ensure peers connected | |||
defaultEnsurePeersPeriod = 30 * time.Second | |||
minNumOutboundPeers = 10 | |||
maxPexMessageSize = 1048576 // 1MB | |||
// maximum messages one peer can send to us during `msgCountByPeerFlushInterval` | |||
defaultMaxMsgCountByPeer = 1000 | |||
msgCountByPeerFlushInterval = 1 * time.Hour | |||
) | |||
// PEXReactor handles PEX (peer exchange) and ensures that an | |||
// adequate number of peers are connected to the switch. | |||
// | |||
// It uses `AddrBook` (address book) to store `NetAddress`es of the peers. | |||
// | |||
// ## Preventing abuse | |||
// | |||
// For now, it just limits the number of messages from one peer to | |||
// `defaultMaxMsgCountByPeer` messages per `msgCountByPeerFlushInterval` (1000 | |||
// msg/hour). | |||
// | |||
// NOTE [2017-01-17]: | |||
// Limiting is fine for now. Maybe down the road we want to keep track of the | |||
// quality of peer messages so if peerA keeps telling us about peers we can't | |||
// connect to then maybe we should care less about peerA. But I don't think | |||
// that kind of complexity is priority right now. | |||
type PEXReactor struct { | |||
BaseReactor | |||
sw *Switch | |||
book *AddrBook | |||
ensurePeersPeriod time.Duration | |||
// tracks message count by peer, so we can prevent abuse | |||
msgCountByPeer *cmn.CMap | |||
maxMsgCountByPeer uint16 | |||
} | |||
// NewPEXReactor creates new PEX reactor. | |||
func NewPEXReactor(b *AddrBook) *PEXReactor { | |||
r := &PEXReactor{ | |||
book: b, | |||
ensurePeersPeriod: defaultEnsurePeersPeriod, | |||
msgCountByPeer: cmn.NewCMap(), | |||
maxMsgCountByPeer: defaultMaxMsgCountByPeer, | |||
} | |||
r.BaseReactor = *NewBaseReactor(log, "PEXReactor", r) | |||
return r | |||
} | |||
// OnStart implements BaseService | |||
func (r *PEXReactor) OnStart() error { | |||
r.BaseReactor.OnStart() | |||
r.book.Start() | |||
go r.ensurePeersRoutine() | |||
go r.flushMsgCountByPeer() | |||
return nil | |||
} | |||
// OnStop implements BaseService | |||
func (r *PEXReactor) OnStop() { | |||
r.BaseReactor.OnStop() | |||
r.book.Stop() | |||
} | |||
// GetChannels implements Reactor | |||
func (r *PEXReactor) GetChannels() []*ChannelDescriptor { | |||
return []*ChannelDescriptor{ | |||
&ChannelDescriptor{ | |||
ID: PexChannel, | |||
Priority: 1, | |||
SendQueueCapacity: 10, | |||
}, | |||
} | |||
} | |||
// AddPeer implements Reactor by adding peer to the address book (if inbound) | |||
// or by requesting more addresses (if outbound). | |||
func (r *PEXReactor) AddPeer(p *Peer) { | |||
if p.IsOutbound() { | |||
// For outbound peers, the address is already in the books. | |||
// Either it was added in DialSeeds or when we | |||
// received the peer's address in r.Receive | |||
if r.book.NeedMoreAddrs() { | |||
r.RequestPEX(p) | |||
} | |||
} else { // For inbound connections, the peer is its own source | |||
addr, err := NewNetAddressString(p.ListenAddr) | |||
if err != nil { | |||
// this should never happen | |||
log.Error("Error in AddPeer: invalid peer address", "addr", p.ListenAddr, "error", err) | |||
return | |||
} | |||
r.book.AddAddress(addr, addr) | |||
} | |||
} | |||
// RemovePeer implements Reactor. | |||
func (r *PEXReactor) RemovePeer(p *Peer, reason interface{}) { | |||
// If we aren't keeping track of local temp data for each peer here, then we | |||
// don't have to do anything. | |||
} | |||
// Receive implements Reactor by handling incoming PEX messages. | |||
func (r *PEXReactor) Receive(chID byte, src *Peer, msgBytes []byte) { | |||
srcAddr := src.Connection().RemoteAddress | |||
srcAddrStr := srcAddr.String() | |||
r.IncrementMsgCountForPeer(srcAddrStr) | |||
if r.ReachedMaxMsgCountForPeer(srcAddrStr) { | |||
log.Warn("Maximum number of messages reached for peer", "peer", srcAddrStr) | |||
// TODO remove src from peers? | |||
return | |||
} | |||
_, msg, err := DecodeMessage(msgBytes) | |||
if err != nil { | |||
log.Warn("Error decoding message", "error", err) | |||
return | |||
} | |||
log.Notice("Received message", "msg", msg) | |||
switch msg := msg.(type) { | |||
case *pexRequestMessage: | |||
// src requested some peers. | |||
r.SendAddrs(src, r.book.GetSelection()) | |||
case *pexAddrsMessage: | |||
// We received some peer addresses from src. | |||
// (We don't want to get spammed with bad peers) | |||
for _, addr := range msg.Addrs { | |||
if addr != nil { | |||
r.book.AddAddress(addr, srcAddr) | |||
} | |||
} | |||
default: | |||
log.Warn(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg))) | |||
} | |||
} | |||
// RequestPEX asks peer for more addresses. | |||
func (r *PEXReactor) RequestPEX(p *Peer) { | |||
p.Send(PexChannel, struct{ PexMessage }{&pexRequestMessage{}}) | |||
} | |||
// SendAddrs sends addrs to the peer. | |||
func (r *PEXReactor) SendAddrs(p *Peer, addrs []*NetAddress) { | |||
p.Send(PexChannel, struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) | |||
} | |||
// SetEnsurePeersPeriod sets period to ensure peers connected. | |||
func (r *PEXReactor) SetEnsurePeersPeriod(d time.Duration) { | |||
r.ensurePeersPeriod = d | |||
} | |||
// SetMaxMsgCountByPeer sets maximum messages one peer can send to us during 'msgCountByPeerFlushInterval'. | |||
func (r *PEXReactor) SetMaxMsgCountByPeer(v uint16) { | |||
r.maxMsgCountByPeer = v | |||
} | |||
// ReachedMaxMsgCountForPeer returns true if we received too many | |||
// messages from peer with address `addr`. | |||
// NOTE: assumes the value in the CMap is non-nil | |||
func (r *PEXReactor) ReachedMaxMsgCountForPeer(addr string) bool { | |||
return r.msgCountByPeer.Get(addr).(uint16) >= r.maxMsgCountByPeer | |||
} | |||
// Increment or initialize the msg count for the peer in the CMap | |||
func (r *PEXReactor) IncrementMsgCountForPeer(addr string) { | |||
var count uint16 | |||
countI := r.msgCountByPeer.Get(addr) | |||
if countI != nil { | |||
count = countI.(uint16) | |||
} | |||
count++ | |||
r.msgCountByPeer.Set(addr, count) | |||
} | |||
// Ensures that sufficient peers are connected. (continuous) | |||
func (r *PEXReactor) ensurePeersRoutine() { | |||
// Randomize when routine starts | |||
ensurePeersPeriodMs := r.ensurePeersPeriod.Nanoseconds() / 1e6 | |||
time.Sleep(time.Duration(rand.Int63n(ensurePeersPeriodMs)) * time.Millisecond) | |||
// fire once immediately. | |||
r.ensurePeers() | |||
// fire periodically | |||
ticker := time.NewTicker(r.ensurePeersPeriod) | |||
for { | |||
select { | |||
case <-ticker.C: | |||
r.ensurePeers() | |||
case <-r.Quit: | |||
ticker.Stop() | |||
return | |||
} | |||
} | |||
} | |||
// ensurePeers ensures that sufficient peers are connected. (once) | |||
// | |||
// Old bucket / New bucket are arbitrary categories to denote whether an | |||
// address is vetted or not, and this needs to be determined over time via a | |||
// heuristic that we haven't perfected yet, or, perhaps is manually edited by | |||
// the node operator. It should not be used to compute what addresses are | |||
// already connected or not. | |||
// | |||
// TODO Basically, we need to work harder on our good-peer/bad-peer marking. | |||
// What we're currently doing in terms of marking good/bad peers is just a | |||
// placeholder. It should not be the case that an address becomes old/vetted | |||
// upon a single successful connection. | |||
func (r *PEXReactor) ensurePeers() { | |||
numOutPeers, _, numDialing := r.Switch.NumPeers() | |||
numToDial := minNumOutboundPeers - (numOutPeers + numDialing) | |||
log.Info("Ensure peers", "numOutPeers", numOutPeers, "numDialing", numDialing, "numToDial", numToDial) | |||
if numToDial <= 0 { | |||
return | |||
} | |||
toDial := make(map[string]*NetAddress) | |||
// Try to pick numToDial addresses to dial. | |||
for i := 0; i < numToDial; i++ { | |||
// The purpose of newBias is to first prioritize old (more vetted) peers | |||
// when we have few connections, but to allow for new (less vetted) peers | |||
// if we already have many connections. This algorithm isn't perfect, but | |||
// it somewhat ensures that we prioritize connecting to more-vetted | |||
// peers. | |||
newBias := cmn.MinInt(numOutPeers, 8)*10 + 10 | |||
var picked *NetAddress | |||
// Try to fetch a new peer 3 times. | |||
// This caps the maximum number of tries to 3 * numToDial. | |||
for j := 0; j < 3; j++ { | |||
try := r.book.PickAddress(newBias) | |||
if try == nil { | |||
break | |||
} | |||
_, alreadySelected := toDial[try.IP.String()] | |||
alreadyDialing := r.Switch.IsDialing(try) | |||
alreadyConnected := r.Switch.Peers().Has(try.IP.String()) | |||
if alreadySelected || alreadyDialing || alreadyConnected { | |||
// log.Info("Cannot dial address", "addr", try, | |||
// "alreadySelected", alreadySelected, | |||
// "alreadyDialing", alreadyDialing, | |||
// "alreadyConnected", alreadyConnected) | |||
continue | |||
} else { | |||
log.Info("Will dial address", "addr", try) | |||
picked = try | |||
break | |||
} | |||
} | |||
if picked == nil { | |||
continue | |||
} | |||
toDial[picked.IP.String()] = picked | |||
} | |||
// Dial picked addresses | |||
for _, item := range toDial { | |||
go func(picked *NetAddress) { | |||
_, err := r.Switch.DialPeerWithAddress(picked, false) | |||
if err != nil { | |||
r.book.MarkAttempt(picked) | |||
} | |||
}(item) | |||
} | |||
// If we need more addresses, pick a random peer and ask for more. | |||
if r.book.NeedMoreAddrs() { | |||
if peers := r.Switch.Peers().List(); len(peers) > 0 { | |||
i := rand.Int() % len(peers) | |||
peer := peers[i] | |||
log.Info("No addresses to dial. Sending pexRequest to random peer", "peer", peer) | |||
r.RequestPEX(peer) | |||
} | |||
} | |||
} | |||
func (r *PEXReactor) flushMsgCountByPeer() { | |||
ticker := time.NewTicker(msgCountByPeerFlushInterval) | |||
for { | |||
select { | |||
case <-ticker.C: | |||
r.msgCountByPeer.Clear() | |||
case <-r.Quit: | |||
ticker.Stop() | |||
return | |||
} | |||
} | |||
} | |||
//----------------------------------------------------------------------------- | |||
// Messages | |||
const ( | |||
msgTypeRequest = byte(0x01) | |||
msgTypeAddrs = byte(0x02) | |||
) | |||
// PexMessage is a primary type for PEX messages. Underneath, it could contain | |||
// either pexRequestMessage, or pexAddrsMessage messages. | |||
type PexMessage interface{} | |||
var _ = wire.RegisterInterface( | |||
struct{ PexMessage }{}, | |||
wire.ConcreteType{&pexRequestMessage{}, msgTypeRequest}, | |||
wire.ConcreteType{&pexAddrsMessage{}, msgTypeAddrs}, | |||
) | |||
// DecodeMessage implements interface registered above. | |||
func DecodeMessage(bz []byte) (msgType byte, msg PexMessage, err error) { | |||
msgType = bz[0] | |||
n := new(int) | |||
r := bytes.NewReader(bz) | |||
msg = wire.ReadBinary(struct{ PexMessage }{}, r, maxPexMessageSize, n, &err).(struct{ PexMessage }).PexMessage | |||
return | |||
} | |||
/* | |||
A pexRequestMessage requests additional peer addresses. | |||
*/ | |||
type pexRequestMessage struct { | |||
} | |||
func (m *pexRequestMessage) String() string { | |||
return "[pexRequest]" | |||
} | |||
/* | |||
A message with announced peer addresses. | |||
*/ | |||
type pexAddrsMessage struct { | |||
Addrs []*NetAddress | |||
} | |||
func (m *pexAddrsMessage) String() string { | |||
return fmt.Sprintf("[pexAddrs %v]", m.Addrs) | |||
} |
@ -0,0 +1,163 @@ | |||
package p2p | |||
import ( | |||
"io/ioutil" | |||
"math/rand" | |||
"os" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
cmn "github.com/tendermint/tmlibs/common" | |||
wire "github.com/tendermint/go-wire" | |||
) | |||
func TestPEXReactorBasic(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
dir, err := ioutil.TempDir("", "pex_reactor") | |||
require.Nil(err) | |||
defer os.RemoveAll(dir) | |||
book := NewAddrBook(dir+"addrbook.json", true) | |||
r := NewPEXReactor(book) | |||
assert.NotNil(r) | |||
assert.NotEmpty(r.GetChannels()) | |||
} | |||
func TestPEXReactorAddRemovePeer(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
dir, err := ioutil.TempDir("", "pex_reactor") | |||
require.Nil(err) | |||
defer os.RemoveAll(dir) | |||
book := NewAddrBook(dir+"addrbook.json", true) | |||
r := NewPEXReactor(book) | |||
size := book.Size() | |||
peer := createRandomPeer(false) | |||
r.AddPeer(peer) | |||
assert.Equal(size+1, book.Size()) | |||
r.RemovePeer(peer, "peer not available") | |||
assert.Equal(size+1, book.Size()) | |||
outboundPeer := createRandomPeer(true) | |||
r.AddPeer(outboundPeer) | |||
assert.Equal(size+1, book.Size(), "outbound peers should not be added to the address book") | |||
r.RemovePeer(outboundPeer, "peer not available") | |||
assert.Equal(size+1, book.Size()) | |||
} | |||
func TestPEXReactorRunning(t *testing.T) { | |||
require := require.New(t) | |||
N := 3 | |||
switches := make([]*Switch, N) | |||
dir, err := ioutil.TempDir("", "pex_reactor") | |||
require.Nil(err) | |||
defer os.RemoveAll(dir) | |||
book := NewAddrBook(dir+"addrbook.json", false) | |||
// create switches | |||
for i := 0; i < N; i++ { | |||
switches[i] = makeSwitch(i, "127.0.0.1", "123.123.123", func(i int, sw *Switch) *Switch { | |||
r := NewPEXReactor(book) | |||
r.SetEnsurePeersPeriod(250 * time.Millisecond) | |||
sw.AddReactor("pex", r) | |||
return sw | |||
}) | |||
} | |||
// fill the address book and add listeners | |||
for _, s := range switches { | |||
addr, _ := NewNetAddressString(s.NodeInfo().ListenAddr) | |||
book.AddAddress(addr, addr) | |||
s.AddListener(NewDefaultListener("tcp", s.NodeInfo().ListenAddr, true)) | |||
} | |||
// start switches | |||
for _, s := range switches { | |||
_, err := s.Start() // start switch and reactors | |||
require.Nil(err) | |||
} | |||
time.Sleep(1 * time.Second) | |||
// check peers are connected after some time | |||
for _, s := range switches { | |||
outbound, inbound, _ := s.NumPeers() | |||
if outbound+inbound == 0 { | |||
t.Errorf("%v expected to be connected to at least one peer", s.NodeInfo().ListenAddr) | |||
} | |||
} | |||
// stop them | |||
for _, s := range switches { | |||
s.Stop() | |||
} | |||
} | |||
func TestPEXReactorReceive(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
dir, err := ioutil.TempDir("", "pex_reactor") | |||
require.Nil(err) | |||
defer os.RemoveAll(dir) | |||
book := NewAddrBook(dir+"addrbook.json", true) | |||
r := NewPEXReactor(book) | |||
peer := createRandomPeer(false) | |||
size := book.Size() | |||
netAddr, _ := NewNetAddressString(peer.ListenAddr) | |||
addrs := []*NetAddress{netAddr} | |||
msg := wire.BinaryBytes(struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) | |||
r.Receive(PexChannel, peer, msg) | |||
assert.Equal(size+1, book.Size()) | |||
msg = wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}}) | |||
r.Receive(PexChannel, peer, msg) | |||
} | |||
func TestPEXReactorAbuseFromPeer(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
dir, err := ioutil.TempDir("", "pex_reactor") | |||
require.Nil(err) | |||
defer os.RemoveAll(dir) | |||
book := NewAddrBook(dir+"addrbook.json", true) | |||
r := NewPEXReactor(book) | |||
r.SetMaxMsgCountByPeer(5) | |||
peer := createRandomPeer(false) | |||
msg := wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}}) | |||
for i := 0; i < 10; i++ { | |||
r.Receive(PexChannel, peer, msg) | |||
} | |||
assert.True(r.ReachedMaxMsgCountForPeer(peer.ListenAddr)) | |||
} | |||
func createRandomPeer(outbound bool) *Peer { | |||
addr := cmn.Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256) | |||
netAddr, _ := NewNetAddressString(addr) | |||
return &Peer{ | |||
Key: cmn.RandStr(12), | |||
NodeInfo: &NodeInfo{ | |||
ListenAddr: addr, | |||
}, | |||
outbound: outbound, | |||
mconn: &MConnection{RemoteAddress: netAddr}, | |||
} | |||
} |
@ -0,0 +1,346 @@ | |||
// Uses nacl's secret_box to encrypt a net.Conn. | |||
// It is (meant to be) an implementation of the STS protocol. | |||
// Note we do not (yet) assume that a remote peer's pubkey | |||
// is known ahead of time, and thus we are technically | |||
// still vulnerable to MITM. (TODO!) | |||
// See docs/sts-final.pdf for more info | |||
package p2p | |||
import ( | |||
"bytes" | |||
crand "crypto/rand" | |||
"crypto/sha256" | |||
"encoding/binary" | |||
"errors" | |||
"io" | |||
"net" | |||
"time" | |||
"golang.org/x/crypto/nacl/box" | |||
"golang.org/x/crypto/nacl/secretbox" | |||
"golang.org/x/crypto/ripemd160" | |||
"github.com/tendermint/go-crypto" | |||
"github.com/tendermint/go-wire" | |||
. "github.com/tendermint/tmlibs/common" | |||
) | |||
// 2 + 1024 == 1026 total frame size | |||
const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize | |||
const dataMaxSize = 1024 | |||
const totalFrameSize = dataMaxSize + dataLenSize | |||
const sealedFrameSize = totalFrameSize + secretbox.Overhead | |||
const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays | |||
// Implements net.Conn | |||
type SecretConnection struct { | |||
conn io.ReadWriteCloser | |||
recvBuffer []byte | |||
recvNonce *[24]byte | |||
sendNonce *[24]byte | |||
remPubKey crypto.PubKeyEd25519 | |||
shrSecret *[32]byte // shared secret | |||
} | |||
// Performs handshake and returns a new authenticated SecretConnection. | |||
// Returns nil if error in handshake. | |||
// Caller should call conn.Close() | |||
// See docs/sts-final.pdf for more information. | |||
func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25519) (*SecretConnection, error) { | |||
locPubKey := locPrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||
// Generate ephemeral keys for perfect forward secrecy. | |||
locEphPub, locEphPriv := genEphKeys() | |||
// Write local ephemeral pubkey and receive one too. | |||
// NOTE: every 32-byte string is accepted as a Curve25519 public key | |||
// (see DJB's Curve25519 paper: http://cr.yp.to/ecdh/curve25519-20060209.pdf) | |||
remEphPub, err := shareEphPubKey(conn, locEphPub) | |||
if err != nil { | |||
return nil, err | |||
} | |||
// Compute common shared secret. | |||
shrSecret := computeSharedSecret(remEphPub, locEphPriv) | |||
// Sort by lexical order. | |||
loEphPub, hiEphPub := sort32(locEphPub, remEphPub) | |||
// Generate nonces to use for secretbox. | |||
recvNonce, sendNonce := genNonces(loEphPub, hiEphPub, locEphPub == loEphPub) | |||
// Generate common challenge to sign. | |||
challenge := genChallenge(loEphPub, hiEphPub) | |||
// Construct SecretConnection. | |||
sc := &SecretConnection{ | |||
conn: conn, | |||
recvBuffer: nil, | |||
recvNonce: recvNonce, | |||
sendNonce: sendNonce, | |||
shrSecret: shrSecret, | |||
} | |||
// Sign the challenge bytes for authentication. | |||
locSignature := signChallenge(challenge, locPrivKey) | |||
// Share (in secret) each other's pubkey & challenge signature | |||
authSigMsg, err := shareAuthSignature(sc, locPubKey, locSignature) | |||
if err != nil { | |||
return nil, err | |||
} | |||
remPubKey, remSignature := authSigMsg.Key, authSigMsg.Sig | |||
if !remPubKey.VerifyBytes(challenge[:], remSignature) { | |||
return nil, errors.New("Challenge verification failed") | |||
} | |||
// We've authorized. | |||
sc.remPubKey = remPubKey.Unwrap().(crypto.PubKeyEd25519) | |||
return sc, nil | |||
} | |||
// Returns authenticated remote pubkey | |||
func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 { | |||
return sc.remPubKey | |||
} | |||
// Writes encrypted frames of `sealedFrameSize` | |||
// CONTRACT: data smaller than dataMaxSize is read atomically. | |||
func (sc *SecretConnection) Write(data []byte) (n int, err error) { | |||
for 0 < len(data) { | |||
var frame []byte = make([]byte, totalFrameSize) | |||
var chunk []byte | |||
if dataMaxSize < len(data) { | |||
chunk = data[:dataMaxSize] | |||
data = data[dataMaxSize:] | |||
} else { | |||
chunk = data | |||
data = nil | |||
} | |||
chunkLength := len(chunk) | |||
binary.BigEndian.PutUint16(frame, uint16(chunkLength)) | |||
copy(frame[dataLenSize:], chunk) | |||
// encrypt the frame | |||
var sealedFrame = make([]byte, sealedFrameSize) | |||
secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret) | |||
// fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret) | |||
incr2Nonce(sc.sendNonce) | |||
// end encryption | |||
_, err := sc.conn.Write(sealedFrame) | |||
if err != nil { | |||
return n, err | |||
} else { | |||
n += len(chunk) | |||
} | |||
} | |||
return | |||
} | |||
// CONTRACT: data smaller than dataMaxSize is read atomically. | |||
func (sc *SecretConnection) Read(data []byte) (n int, err error) { | |||
if 0 < len(sc.recvBuffer) { | |||
n_ := copy(data, sc.recvBuffer) | |||
sc.recvBuffer = sc.recvBuffer[n_:] | |||
return | |||
} | |||
sealedFrame := make([]byte, sealedFrameSize) | |||
_, err = io.ReadFull(sc.conn, sealedFrame) | |||
if err != nil { | |||
return | |||
} | |||
// decrypt the frame | |||
var frame = make([]byte, totalFrameSize) | |||
// fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret) | |||
_, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret) | |||
if !ok { | |||
return n, errors.New("Failed to decrypt SecretConnection") | |||
} | |||
incr2Nonce(sc.recvNonce) | |||
// end decryption | |||
var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes | |||
if chunkLength > dataMaxSize { | |||
return 0, errors.New("chunkLength is greater than dataMaxSize") | |||
} | |||
var chunk = frame[dataLenSize : dataLenSize+chunkLength] | |||
n = copy(data, chunk) | |||
sc.recvBuffer = chunk[n:] | |||
return | |||
} | |||
// Implements net.Conn | |||
func (sc *SecretConnection) Close() error { return sc.conn.Close() } | |||
func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() } | |||
func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() } | |||
func (sc *SecretConnection) SetDeadline(t time.Time) error { return sc.conn.(net.Conn).SetDeadline(t) } | |||
func (sc *SecretConnection) SetReadDeadline(t time.Time) error { | |||
return sc.conn.(net.Conn).SetReadDeadline(t) | |||
} | |||
func (sc *SecretConnection) SetWriteDeadline(t time.Time) error { | |||
return sc.conn.(net.Conn).SetWriteDeadline(t) | |||
} | |||
func genEphKeys() (ephPub, ephPriv *[32]byte) { | |||
var err error | |||
ephPub, ephPriv, err = box.GenerateKey(crand.Reader) | |||
if err != nil { | |||
PanicCrisis("Could not generate ephemeral keypairs") | |||
} | |||
return | |||
} | |||
func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { | |||
var err1, err2 error | |||
Parallel( | |||
func() { | |||
_, err1 = conn.Write(locEphPub[:]) | |||
}, | |||
func() { | |||
remEphPub = new([32]byte) | |||
_, err2 = io.ReadFull(conn, remEphPub[:]) | |||
}, | |||
) | |||
if err1 != nil { | |||
return nil, err1 | |||
} | |||
if err2 != nil { | |||
return nil, err2 | |||
} | |||
return remEphPub, nil | |||
} | |||
func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) { | |||
shrSecret = new([32]byte) | |||
box.Precompute(shrSecret, remPubKey, locPrivKey) | |||
return | |||
} | |||
func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) { | |||
if bytes.Compare(foo[:], bar[:]) < 0 { | |||
lo = foo | |||
hi = bar | |||
} else { | |||
lo = bar | |||
hi = foo | |||
} | |||
return | |||
} | |||
func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce *[24]byte) { | |||
nonce1 := hash24(append(loPubKey[:], hiPubKey[:]...)) | |||
nonce2 := new([24]byte) | |||
copy(nonce2[:], nonce1[:]) | |||
nonce2[len(nonce2)-1] ^= 0x01 | |||
if locIsLo { | |||
recvNonce = nonce1 | |||
sendNonce = nonce2 | |||
} else { | |||
recvNonce = nonce2 | |||
sendNonce = nonce1 | |||
} | |||
return | |||
} | |||
func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) { | |||
return hash32(append(loPubKey[:], hiPubKey[:]...)) | |||
} | |||
func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signature crypto.SignatureEd25519) { | |||
signature = locPrivKey.Sign(challenge[:]).Unwrap().(crypto.SignatureEd25519) | |||
return | |||
} | |||
type authSigMessage struct { | |||
Key crypto.PubKey | |||
Sig crypto.Signature | |||
} | |||
func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signature crypto.SignatureEd25519) (*authSigMessage, error) { | |||
var recvMsg authSigMessage | |||
var err1, err2 error | |||
Parallel( | |||
func() { | |||
msgBytes := wire.BinaryBytes(authSigMessage{pubKey.Wrap(), signature.Wrap()}) | |||
_, err1 = sc.Write(msgBytes) | |||
}, | |||
func() { | |||
readBuffer := make([]byte, authSigMsgSize) | |||
_, err2 = io.ReadFull(sc, readBuffer) | |||
if err2 != nil { | |||
return | |||
} | |||
n := int(0) // not used. | |||
recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage) | |||
}) | |||
if err1 != nil { | |||
return nil, err1 | |||
} | |||
if err2 != nil { | |||
return nil, err2 | |||
} | |||
return &recvMsg, nil | |||
} | |||
func verifyChallengeSignature(challenge *[32]byte, remPubKey crypto.PubKeyEd25519, remSignature crypto.SignatureEd25519) bool { | |||
return remPubKey.VerifyBytes(challenge[:], remSignature.Wrap()) | |||
} | |||
//-------------------------------------------------------------------------------- | |||
// sha256 | |||
func hash32(input []byte) (res *[32]byte) { | |||
hasher := sha256.New() | |||
hasher.Write(input) // does not error | |||
resSlice := hasher.Sum(nil) | |||
res = new([32]byte) | |||
copy(res[:], resSlice) | |||
return | |||
} | |||
// We only fill in the first 20 bytes with ripemd160 | |||
func hash24(input []byte) (res *[24]byte) { | |||
hasher := ripemd160.New() | |||
hasher.Write(input) // does not error | |||
resSlice := hasher.Sum(nil) | |||
res = new([24]byte) | |||
copy(res[:], resSlice) | |||
return | |||
} | |||
// ripemd160 | |||
func hash20(input []byte) (res *[20]byte) { | |||
hasher := ripemd160.New() | |||
hasher.Write(input) // does not error | |||
resSlice := hasher.Sum(nil) | |||
res = new([20]byte) | |||
copy(res[:], resSlice) | |||
return | |||
} | |||
// increment nonce big-endian by 2 with wraparound. | |||
func incr2Nonce(nonce *[24]byte) { | |||
incrNonce(nonce) | |||
incrNonce(nonce) | |||
} | |||
// increment nonce big-endian by 1 with wraparound. | |||
func incrNonce(nonce *[24]byte) { | |||
for i := 23; 0 <= i; i-- { | |||
nonce[i] += 1 | |||
if nonce[i] != 0 { | |||
return | |||
} | |||
} | |||
} |
@ -0,0 +1,202 @@ | |||
package p2p | |||
import ( | |||
"bytes" | |||
"io" | |||
"testing" | |||
"github.com/tendermint/go-crypto" | |||
. "github.com/tendermint/tmlibs/common" | |||
) | |||
type dummyConn struct { | |||
*io.PipeReader | |||
*io.PipeWriter | |||
} | |||
func (drw dummyConn) Close() (err error) { | |||
err2 := drw.PipeWriter.CloseWithError(io.EOF) | |||
err1 := drw.PipeReader.Close() | |||
if err2 != nil { | |||
return err | |||
} | |||
return err1 | |||
} | |||
// Each returned ReadWriteCloser is akin to a net.Connection | |||
func makeDummyConnPair() (fooConn, barConn dummyConn) { | |||
barReader, fooWriter := io.Pipe() | |||
fooReader, barWriter := io.Pipe() | |||
return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter} | |||
} | |||
func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { | |||
fooConn, barConn := makeDummyConnPair() | |||
fooPrvKey := crypto.GenPrivKeyEd25519() | |||
fooPubKey := fooPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||
barPrvKey := crypto.GenPrivKeyEd25519() | |||
barPubKey := barPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||
Parallel( | |||
func() { | |||
var err error | |||
fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) | |||
if err != nil { | |||
tb.Errorf("Failed to establish SecretConnection for foo: %v", err) | |||
return | |||
} | |||
remotePubBytes := fooSecConn.RemotePubKey() | |||
if !bytes.Equal(remotePubBytes[:], barPubKey[:]) { | |||
tb.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v", | |||
barPubKey, fooSecConn.RemotePubKey()) | |||
} | |||
}, | |||
func() { | |||
var err error | |||
barSecConn, err = MakeSecretConnection(barConn, barPrvKey) | |||
if barSecConn == nil { | |||
tb.Errorf("Failed to establish SecretConnection for bar: %v", err) | |||
return | |||
} | |||
remotePubBytes := barSecConn.RemotePubKey() | |||
if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) { | |||
tb.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v", | |||
fooPubKey, barSecConn.RemotePubKey()) | |||
} | |||
}) | |||
return | |||
} | |||
func TestSecretConnectionHandshake(t *testing.T) { | |||
fooSecConn, barSecConn := makeSecretConnPair(t) | |||
fooSecConn.Close() | |||
barSecConn.Close() | |||
} | |||
func TestSecretConnectionReadWrite(t *testing.T) { | |||
fooConn, barConn := makeDummyConnPair() | |||
fooWrites, barWrites := []string{}, []string{} | |||
fooReads, barReads := []string{}, []string{} | |||
// Pre-generate the things to write (for foo & bar) | |||
for i := 0; i < 100; i++ { | |||
fooWrites = append(fooWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) | |||
barWrites = append(barWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) | |||
} | |||
// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa | |||
genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() { | |||
return func() { | |||
// Node handskae | |||
nodePrvKey := crypto.GenPrivKeyEd25519() | |||
nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) | |||
if err != nil { | |||
t.Errorf("Failed to establish SecretConnection for node: %v", err) | |||
return | |||
} | |||
// In parallel, handle reads and writes | |||
Parallel( | |||
func() { | |||
// Node writes | |||
for _, nodeWrite := range nodeWrites { | |||
n, err := nodeSecretConn.Write([]byte(nodeWrite)) | |||
if err != nil { | |||
t.Errorf("Failed to write to nodeSecretConn: %v", err) | |||
return | |||
} | |||
if n != len(nodeWrite) { | |||
t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) | |||
return | |||
} | |||
} | |||
nodeConn.PipeWriter.Close() | |||
}, | |||
func() { | |||
// Node reads | |||
readBuffer := make([]byte, dataMaxSize) | |||
for { | |||
n, err := nodeSecretConn.Read(readBuffer) | |||
if err == io.EOF { | |||
return | |||
} else if err != nil { | |||
t.Errorf("Failed to read from nodeSecretConn: %v", err) | |||
return | |||
} | |||
*nodeReads = append(*nodeReads, string(readBuffer[:n])) | |||
} | |||
nodeConn.PipeReader.Close() | |||
}) | |||
} | |||
} | |||
// Run foo & bar in parallel | |||
Parallel( | |||
genNodeRunner(fooConn, fooWrites, &fooReads), | |||
genNodeRunner(barConn, barWrites, &barReads), | |||
) | |||
// A helper to ensure that the writes and reads match. | |||
// Additionally, small writes (<= dataMaxSize) must be atomically read. | |||
compareWritesReads := func(writes []string, reads []string) { | |||
for { | |||
// Pop next write & corresponding reads | |||
var read, write string = "", writes[0] | |||
var readCount = 0 | |||
for _, readChunk := range reads { | |||
read += readChunk | |||
readCount += 1 | |||
if len(write) <= len(read) { | |||
break | |||
} | |||
if len(write) <= dataMaxSize { | |||
break // atomicity of small writes | |||
} | |||
} | |||
// Compare | |||
if write != read { | |||
t.Errorf("Expected to read %X, got %X", write, read) | |||
} | |||
// Iterate | |||
writes = writes[1:] | |||
reads = reads[readCount:] | |||
if len(writes) == 0 { | |||
break | |||
} | |||
} | |||
} | |||
compareWritesReads(fooWrites, barReads) | |||
compareWritesReads(barWrites, fooReads) | |||
} | |||
func BenchmarkSecretConnection(b *testing.B) { | |||
b.StopTimer() | |||
fooSecConn, barSecConn := makeSecretConnPair(b) | |||
fooWriteText := RandStr(dataMaxSize) | |||
// Consume reads from bar's reader | |||
go func() { | |||
readBuffer := make([]byte, dataMaxSize) | |||
for { | |||
_, err := barSecConn.Read(readBuffer) | |||
if err == io.EOF { | |||
return | |||
} else if err != nil { | |||
b.Fatalf("Failed to read from barSecConn: %v", err) | |||
} | |||
} | |||
}() | |||
b.StartTimer() | |||
for i := 0; i < b.N; i++ { | |||
_, err := fooSecConn.Write([]byte(fooWriteText)) | |||
if err != nil { | |||
b.Fatalf("Failed to write to fooSecConn: %v", err) | |||
} | |||
} | |||
b.StopTimer() | |||
fooSecConn.Close() | |||
//barSecConn.Close() race condition | |||
} |
@ -0,0 +1,593 @@ | |||
package p2p | |||
import ( | |||
"errors" | |||
"fmt" | |||
"math/rand" | |||
"net" | |||
"time" | |||
cfg "github.com/tendermint/go-config" | |||
crypto "github.com/tendermint/go-crypto" | |||
"github.com/tendermint/log15" | |||
. "github.com/tendermint/tmlibs/common" | |||
) | |||
const ( | |||
reconnectAttempts = 30 | |||
reconnectInterval = 3 * time.Second | |||
) | |||
type Reactor interface { | |||
Service // Start, Stop | |||
SetSwitch(*Switch) | |||
GetChannels() []*ChannelDescriptor | |||
AddPeer(peer *Peer) | |||
RemovePeer(peer *Peer, reason interface{}) | |||
Receive(chID byte, peer *Peer, msgBytes []byte) | |||
} | |||
//-------------------------------------- | |||
type BaseReactor struct { | |||
BaseService // Provides Start, Stop, .Quit | |||
Switch *Switch | |||
} | |||
func NewBaseReactor(log log15.Logger, name string, impl Reactor) *BaseReactor { | |||
return &BaseReactor{ | |||
BaseService: *NewBaseService(log, name, impl), | |||
Switch: nil, | |||
} | |||
} | |||
func (br *BaseReactor) SetSwitch(sw *Switch) { | |||
br.Switch = sw | |||
} | |||
func (_ *BaseReactor) GetChannels() []*ChannelDescriptor { return nil } | |||
func (_ *BaseReactor) AddPeer(peer *Peer) {} | |||
func (_ *BaseReactor) RemovePeer(peer *Peer, reason interface{}) {} | |||
func (_ *BaseReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {} | |||
//----------------------------------------------------------------------------- | |||
/* | |||
The `Switch` handles peer connections and exposes an API to receive incoming messages | |||
on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one | |||
or more `Channels`. So while sending outgoing messages is typically performed on the peer, | |||
incoming messages are received on the reactor. | |||
*/ | |||
type Switch struct { | |||
BaseService | |||
config cfg.Config | |||
listeners []Listener | |||
reactors map[string]Reactor | |||
chDescs []*ChannelDescriptor | |||
reactorsByCh map[byte]Reactor | |||
peers *PeerSet | |||
dialing *CMap | |||
nodeInfo *NodeInfo // our node info | |||
nodePrivKey crypto.PrivKeyEd25519 // our node privkey | |||
filterConnByAddr func(net.Addr) error | |||
filterConnByPubKey func(crypto.PubKeyEd25519) error | |||
} | |||
var ( | |||
ErrSwitchDuplicatePeer = errors.New("Duplicate peer") | |||
ErrSwitchMaxPeersPerIPRange = errors.New("IP range has too many peers") | |||
) | |||
func NewSwitch(config cfg.Config) *Switch { | |||
setConfigDefaults(config) | |||
sw := &Switch{ | |||
config: config, | |||
reactors: make(map[string]Reactor), | |||
chDescs: make([]*ChannelDescriptor, 0), | |||
reactorsByCh: make(map[byte]Reactor), | |||
peers: NewPeerSet(), | |||
dialing: NewCMap(), | |||
nodeInfo: nil, | |||
} | |||
sw.BaseService = *NewBaseService(log, "P2P Switch", sw) | |||
return sw | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor { | |||
// Validate the reactor. | |||
// No two reactors can share the same channel. | |||
reactorChannels := reactor.GetChannels() | |||
for _, chDesc := range reactorChannels { | |||
chID := chDesc.ID | |||
if sw.reactorsByCh[chID] != nil { | |||
PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor)) | |||
} | |||
sw.chDescs = append(sw.chDescs, chDesc) | |||
sw.reactorsByCh[chID] = reactor | |||
} | |||
sw.reactors[name] = reactor | |||
reactor.SetSwitch(sw) | |||
return reactor | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) Reactors() map[string]Reactor { | |||
return sw.reactors | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) Reactor(name string) Reactor { | |||
return sw.reactors[name] | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) AddListener(l Listener) { | |||
sw.listeners = append(sw.listeners, l) | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) Listeners() []Listener { | |||
return sw.listeners | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) IsListening() bool { | |||
return len(sw.listeners) > 0 | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) { | |||
sw.nodeInfo = nodeInfo | |||
} | |||
// Not goroutine safe. | |||
func (sw *Switch) NodeInfo() *NodeInfo { | |||
return sw.nodeInfo | |||
} | |||
// Not goroutine safe. | |||
// NOTE: Overwrites sw.nodeInfo.PubKey | |||
func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) { | |||
sw.nodePrivKey = nodePrivKey | |||
if sw.nodeInfo != nil { | |||
sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||
} | |||
} | |||
// Switch.Start() starts all the reactors, peers, and listeners. | |||
func (sw *Switch) OnStart() error { | |||
sw.BaseService.OnStart() | |||
// Start reactors | |||
for _, reactor := range sw.reactors { | |||
_, err := reactor.Start() | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
// Start peers | |||
for _, peer := range sw.peers.List() { | |||
sw.startInitPeer(peer) | |||
} | |||
// Start listeners | |||
for _, listener := range sw.listeners { | |||
go sw.listenerRoutine(listener) | |||
} | |||
return nil | |||
} | |||
func (sw *Switch) OnStop() { | |||
sw.BaseService.OnStop() | |||
// Stop listeners | |||
for _, listener := range sw.listeners { | |||
listener.Stop() | |||
} | |||
sw.listeners = nil | |||
// Stop peers | |||
for _, peer := range sw.peers.List() { | |||
peer.Stop() | |||
sw.peers.Remove(peer) | |||
} | |||
// Stop reactors | |||
for _, reactor := range sw.reactors { | |||
reactor.Stop() | |||
} | |||
} | |||
// NOTE: This performs a blocking handshake before the peer is added. | |||
// CONTRACT: If error is returned, peer is nil, and conn is immediately closed. | |||
func (sw *Switch) AddPeer(peer *Peer) error { | |||
if err := sw.FilterConnByAddr(peer.Addr()); err != nil { | |||
return err | |||
} | |||
if err := sw.FilterConnByPubKey(peer.PubKey()); err != nil { | |||
return err | |||
} | |||
if err := peer.HandshakeTimeout(sw.nodeInfo, time.Duration(sw.config.GetInt(configKeyHandshakeTimeoutSeconds))*time.Second); err != nil { | |||
return err | |||
} | |||
// Avoid self | |||
if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) { | |||
return errors.New("Ignoring connection from self") | |||
} | |||
// Check version, chain id | |||
if err := sw.nodeInfo.CompatibleWith(peer.NodeInfo); err != nil { | |||
return err | |||
} | |||
// Add the peer to .peers | |||
// ignore if duplicate or if we already have too many for that IP range | |||
if err := sw.peers.Add(peer); err != nil { | |||
log.Notice("Ignoring peer", "error", err, "peer", peer) | |||
peer.Stop() | |||
return err | |||
} | |||
// Start peer | |||
if sw.IsRunning() { | |||
sw.startInitPeer(peer) | |||
} | |||
log.Notice("Added peer", "peer", peer) | |||
return nil | |||
} | |||
func (sw *Switch) FilterConnByAddr(addr net.Addr) error { | |||
if sw.filterConnByAddr != nil { | |||
return sw.filterConnByAddr(addr) | |||
} | |||
return nil | |||
} | |||
func (sw *Switch) FilterConnByPubKey(pubkey crypto.PubKeyEd25519) error { | |||
if sw.filterConnByPubKey != nil { | |||
return sw.filterConnByPubKey(pubkey) | |||
} | |||
return nil | |||
} | |||
func (sw *Switch) SetAddrFilter(f func(net.Addr) error) { | |||
sw.filterConnByAddr = f | |||
} | |||
func (sw *Switch) SetPubKeyFilter(f func(crypto.PubKeyEd25519) error) { | |||
sw.filterConnByPubKey = f | |||
} | |||
func (sw *Switch) startInitPeer(peer *Peer) { | |||
peer.Start() // spawn send/recv routines | |||
for _, reactor := range sw.reactors { | |||
reactor.AddPeer(peer) | |||
} | |||
} | |||
// Dial a list of seeds asynchronously in random order | |||
func (sw *Switch) DialSeeds(addrBook *AddrBook, seeds []string) error { | |||
netAddrs, err := NewNetAddressStrings(seeds) | |||
if err != nil { | |||
return err | |||
} | |||
if addrBook != nil { | |||
// add seeds to `addrBook` | |||
ourAddrS := sw.nodeInfo.ListenAddr | |||
ourAddr, _ := NewNetAddressString(ourAddrS) | |||
for _, netAddr := range netAddrs { | |||
// do not add ourselves | |||
if netAddr.Equals(ourAddr) { | |||
continue | |||
} | |||
addrBook.AddAddress(netAddr, ourAddr) | |||
} | |||
addrBook.Save() | |||
} | |||
// permute the list, dial them in random order. | |||
perm := rand.Perm(len(netAddrs)) | |||
for i := 0; i < len(perm); i++ { | |||
go func(i int) { | |||
time.Sleep(time.Duration(rand.Int63n(3000)) * time.Millisecond) | |||
j := perm[i] | |||
sw.dialSeed(netAddrs[j]) | |||
}(i) | |||
} | |||
return nil | |||
} | |||
func (sw *Switch) dialSeed(addr *NetAddress) { | |||
peer, err := sw.DialPeerWithAddress(addr, true) | |||
if err != nil { | |||
log.Error("Error dialing seed", "error", err) | |||
return | |||
} else { | |||
log.Notice("Connected to seed", "peer", peer) | |||
} | |||
} | |||
func (sw *Switch) DialPeerWithAddress(addr *NetAddress, persistent bool) (*Peer, error) { | |||
sw.dialing.Set(addr.IP.String(), addr) | |||
defer sw.dialing.Delete(addr.IP.String()) | |||
peer, err := newOutboundPeerWithConfig(addr, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, peerConfigFromGoConfig(sw.config)) | |||
if err != nil { | |||
log.Info("Failed dialing peer", "address", addr, "error", err) | |||
return nil, err | |||
} | |||
if persistent { | |||
peer.makePersistent() | |||
} | |||
err = sw.AddPeer(peer) | |||
if err != nil { | |||
log.Info("Failed adding peer", "address", addr, "error", err) | |||
peer.CloseConn() | |||
return nil, err | |||
} | |||
log.Notice("Dialed and added peer", "address", addr, "peer", peer) | |||
return peer, nil | |||
} | |||
func (sw *Switch) IsDialing(addr *NetAddress) bool { | |||
return sw.dialing.Has(addr.IP.String()) | |||
} | |||
// Broadcast runs a go routine for each attempted send, which will block | |||
// trying to send for defaultSendTimeoutSeconds. Returns a channel | |||
// which receives success values for each attempted send (false if times out) | |||
// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. | |||
func (sw *Switch) Broadcast(chID byte, msg interface{}) chan bool { | |||
successChan := make(chan bool, len(sw.peers.List())) | |||
log.Debug("Broadcast", "channel", chID, "msg", msg) | |||
for _, peer := range sw.peers.List() { | |||
go func(peer *Peer) { | |||
success := peer.Send(chID, msg) | |||
successChan <- success | |||
}(peer) | |||
} | |||
return successChan | |||
} | |||
// Returns the count of outbound/inbound and outbound-dialing peers. | |||
func (sw *Switch) NumPeers() (outbound, inbound, dialing int) { | |||
peers := sw.peers.List() | |||
for _, peer := range peers { | |||
if peer.outbound { | |||
outbound++ | |||
} else { | |||
inbound++ | |||
} | |||
} | |||
dialing = sw.dialing.Size() | |||
return | |||
} | |||
func (sw *Switch) Peers() IPeerSet { | |||
return sw.peers | |||
} | |||
// Disconnect from a peer due to external error, retry if it is a persistent peer. | |||
// TODO: make record depending on reason. | |||
func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) { | |||
addr := NewNetAddress(peer.Addr()) | |||
log.Notice("Stopping peer for error", "peer", peer, "error", reason) | |||
sw.stopAndRemovePeer(peer, reason) | |||
if peer.IsPersistent() { | |||
go func() { | |||
log.Notice("Reconnecting to peer", "peer", peer) | |||
for i := 1; i < reconnectAttempts; i++ { | |||
if !sw.IsRunning() { | |||
return | |||
} | |||
peer, err := sw.DialPeerWithAddress(addr, true) | |||
if err != nil { | |||
if i == reconnectAttempts { | |||
log.Notice("Error reconnecting to peer. Giving up", "tries", i, "error", err) | |||
return | |||
} | |||
log.Notice("Error reconnecting to peer. Trying again", "tries", i, "error", err) | |||
time.Sleep(reconnectInterval) | |||
continue | |||
} | |||
log.Notice("Reconnected to peer", "peer", peer) | |||
return | |||
} | |||
}() | |||
} | |||
} | |||
// Disconnect from a peer gracefully. | |||
// TODO: handle graceful disconnects. | |||
func (sw *Switch) StopPeerGracefully(peer *Peer) { | |||
log.Notice("Stopping peer gracefully") | |||
sw.stopAndRemovePeer(peer, nil) | |||
} | |||
func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) { | |||
sw.peers.Remove(peer) | |||
peer.Stop() | |||
for _, reactor := range sw.reactors { | |||
reactor.RemovePeer(peer, reason) | |||
} | |||
} | |||
func (sw *Switch) listenerRoutine(l Listener) { | |||
for { | |||
inConn, ok := <-l.Connections() | |||
if !ok { | |||
break | |||
} | |||
// ignore connection if we already have enough | |||
maxPeers := sw.config.GetInt(configKeyMaxNumPeers) | |||
if maxPeers <= sw.peers.Size() { | |||
log.Info("Ignoring inbound connection: already have enough peers", "address", inConn.RemoteAddr().String(), "numPeers", sw.peers.Size(), "max", maxPeers) | |||
continue | |||
} | |||
// New inbound connection! | |||
err := sw.addPeerWithConnectionAndConfig(inConn, peerConfigFromGoConfig(sw.config)) | |||
if err != nil { | |||
log.Notice("Ignoring inbound connection: error while adding peer", "address", inConn.RemoteAddr().String(), "error", err) | |||
continue | |||
} | |||
// NOTE: We don't yet have the listening port of the | |||
// remote (if they have a listener at all). | |||
// The peerHandshake will handle that | |||
} | |||
// cleanup | |||
} | |||
//----------------------------------------------------------------------------- | |||
type SwitchEventNewPeer struct { | |||
Peer *Peer | |||
} | |||
type SwitchEventDonePeer struct { | |||
Peer *Peer | |||
Error interface{} | |||
} | |||
//------------------------------------------------------------------ | |||
// Switches connected via arbitrary net.Conn; useful for testing | |||
// Returns n switches, connected according to the connect func. | |||
// If connect==Connect2Switches, the switches will be fully connected. | |||
// initSwitch defines how the ith switch should be initialized (ie. with what reactors). | |||
// NOTE: panics if any switch fails to start. | |||
func MakeConnectedSwitches(n int, initSwitch func(int, *Switch) *Switch, connect func([]*Switch, int, int)) []*Switch { | |||
switches := make([]*Switch, n) | |||
for i := 0; i < n; i++ { | |||
switches[i] = makeSwitch(i, "testing", "123.123.123", initSwitch) | |||
} | |||
if err := StartSwitches(switches); err != nil { | |||
panic(err) | |||
} | |||
for i := 0; i < n; i++ { | |||
for j := i; j < n; j++ { | |||
connect(switches, i, j) | |||
} | |||
} | |||
return switches | |||
} | |||
var PanicOnAddPeerErr = false | |||
// Will connect switches i and j via net.Pipe() | |||
// Blocks until a conection is established. | |||
// NOTE: caller ensures i and j are within bounds | |||
func Connect2Switches(switches []*Switch, i, j int) { | |||
switchI := switches[i] | |||
switchJ := switches[j] | |||
c1, c2 := net.Pipe() | |||
doneCh := make(chan struct{}) | |||
go func() { | |||
err := switchI.addPeerWithConnection(c1) | |||
if PanicOnAddPeerErr && err != nil { | |||
panic(err) | |||
} | |||
doneCh <- struct{}{} | |||
}() | |||
go func() { | |||
err := switchJ.addPeerWithConnection(c2) | |||
if PanicOnAddPeerErr && err != nil { | |||
panic(err) | |||
} | |||
doneCh <- struct{}{} | |||
}() | |||
<-doneCh | |||
<-doneCh | |||
} | |||
func StartSwitches(switches []*Switch) error { | |||
for _, s := range switches { | |||
_, err := s.Start() // start switch and reactors | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
func makeSwitch(i int, network, version string, initSwitch func(int, *Switch) *Switch) *Switch { | |||
privKey := crypto.GenPrivKeyEd25519() | |||
// new switch, add reactors | |||
// TODO: let the config be passed in? | |||
s := initSwitch(i, NewSwitch(cfg.NewMapConfig(nil))) | |||
s.SetNodeInfo(&NodeInfo{ | |||
PubKey: privKey.PubKey().Unwrap().(crypto.PubKeyEd25519), | |||
Moniker: Fmt("switch%d", i), | |||
Network: network, | |||
Version: version, | |||
RemoteAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023), | |||
ListenAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023), | |||
}) | |||
s.SetNodePrivKey(privKey) | |||
return s | |||
} | |||
func (sw *Switch) addPeerWithConnection(conn net.Conn) error { | |||
peer, err := newInboundPeer(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) | |||
if err != nil { | |||
conn.Close() | |||
return err | |||
} | |||
if err = sw.AddPeer(peer); err != nil { | |||
conn.Close() | |||
return err | |||
} | |||
return nil | |||
} | |||
func (sw *Switch) addPeerWithConnectionAndConfig(conn net.Conn, config *PeerConfig) error { | |||
peer, err := newInboundPeerWithConfig(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config) | |||
if err != nil { | |||
conn.Close() | |||
return err | |||
} | |||
if err = sw.AddPeer(peer); err != nil { | |||
conn.Close() | |||
return err | |||
} | |||
return nil | |||
} | |||
func peerConfigFromGoConfig(config cfg.Config) *PeerConfig { | |||
return &PeerConfig{ | |||
AuthEnc: config.GetBool(configKeyAuthEnc), | |||
Fuzz: config.GetBool(configFuzzEnable), | |||
HandshakeTimeout: time.Duration(config.GetInt(configKeyHandshakeTimeoutSeconds)) * time.Second, | |||
DialTimeout: time.Duration(config.GetInt(configKeyDialTimeoutSeconds)) * time.Second, | |||
MConfig: &MConnConfig{ | |||
SendRate: int64(config.GetInt(configKeySendRate)), | |||
RecvRate: int64(config.GetInt(configKeyRecvRate)), | |||
}, | |||
FuzzConfig: &FuzzConnConfig{ | |||
Mode: config.GetInt(configFuzzMode), | |||
MaxDelay: time.Duration(config.GetInt(configFuzzMaxDelayMilliseconds)) * time.Millisecond, | |||
ProbDropRW: config.GetFloat64(configFuzzProbDropRW), | |||
ProbDropConn: config.GetFloat64(configFuzzProbDropConn), | |||
ProbSleep: config.GetFloat64(configFuzzProbSleep), | |||
}, | |||
} | |||
} |
@ -0,0 +1,330 @@ | |||
package p2p | |||
import ( | |||
"bytes" | |||
"fmt" | |||
"net" | |||
"sync" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
. "github.com/tendermint/tmlibs/common" | |||
cfg "github.com/tendermint/go-config" | |||
crypto "github.com/tendermint/go-crypto" | |||
wire "github.com/tendermint/go-wire" | |||
) | |||
var ( | |||
config cfg.Config | |||
) | |||
func init() { | |||
config = cfg.NewMapConfig(nil) | |||
setConfigDefaults(config) | |||
} | |||
type PeerMessage struct { | |||
PeerKey string | |||
Bytes []byte | |||
Counter int | |||
} | |||
type TestReactor struct { | |||
BaseReactor | |||
mtx sync.Mutex | |||
channels []*ChannelDescriptor | |||
peersAdded []*Peer | |||
peersRemoved []*Peer | |||
logMessages bool | |||
msgsCounter int | |||
msgsReceived map[byte][]PeerMessage | |||
} | |||
func NewTestReactor(channels []*ChannelDescriptor, logMessages bool) *TestReactor { | |||
tr := &TestReactor{ | |||
channels: channels, | |||
logMessages: logMessages, | |||
msgsReceived: make(map[byte][]PeerMessage), | |||
} | |||
tr.BaseReactor = *NewBaseReactor(log, "TestReactor", tr) | |||
return tr | |||
} | |||
func (tr *TestReactor) GetChannels() []*ChannelDescriptor { | |||
return tr.channels | |||
} | |||
func (tr *TestReactor) AddPeer(peer *Peer) { | |||
tr.mtx.Lock() | |||
defer tr.mtx.Unlock() | |||
tr.peersAdded = append(tr.peersAdded, peer) | |||
} | |||
func (tr *TestReactor) RemovePeer(peer *Peer, reason interface{}) { | |||
tr.mtx.Lock() | |||
defer tr.mtx.Unlock() | |||
tr.peersRemoved = append(tr.peersRemoved, peer) | |||
} | |||
func (tr *TestReactor) Receive(chID byte, peer *Peer, msgBytes []byte) { | |||
if tr.logMessages { | |||
tr.mtx.Lock() | |||
defer tr.mtx.Unlock() | |||
//fmt.Printf("Received: %X, %X\n", chID, msgBytes) | |||
tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.Key, msgBytes, tr.msgsCounter}) | |||
tr.msgsCounter++ | |||
} | |||
} | |||
func (tr *TestReactor) getMsgs(chID byte) []PeerMessage { | |||
tr.mtx.Lock() | |||
defer tr.mtx.Unlock() | |||
return tr.msgsReceived[chID] | |||
} | |||
//----------------------------------------------------------------------------- | |||
// convenience method for creating two switches connected to each other. | |||
// XXX: note this uses net.Pipe and not a proper TCP conn | |||
func makeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) { | |||
// Create two switches that will be interconnected. | |||
switches := MakeConnectedSwitches(2, initSwitch, Connect2Switches) | |||
return switches[0], switches[1] | |||
} | |||
func initSwitchFunc(i int, sw *Switch) *Switch { | |||
// Make two reactors of two channels each | |||
sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ | |||
&ChannelDescriptor{ID: byte(0x00), Priority: 10}, | |||
&ChannelDescriptor{ID: byte(0x01), Priority: 10}, | |||
}, true)) | |||
sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ | |||
&ChannelDescriptor{ID: byte(0x02), Priority: 10}, | |||
&ChannelDescriptor{ID: byte(0x03), Priority: 10}, | |||
}, true)) | |||
return sw | |||
} | |||
func TestSwitches(t *testing.T) { | |||
s1, s2 := makeSwitchPair(t, initSwitchFunc) | |||
defer s1.Stop() | |||
defer s2.Stop() | |||
if s1.Peers().Size() != 1 { | |||
t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size()) | |||
} | |||
if s2.Peers().Size() != 1 { | |||
t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size()) | |||
} | |||
// Lets send some messages | |||
ch0Msg := "channel zero" | |||
ch1Msg := "channel foo" | |||
ch2Msg := "channel bar" | |||
s1.Broadcast(byte(0x00), ch0Msg) | |||
s1.Broadcast(byte(0x01), ch1Msg) | |||
s1.Broadcast(byte(0x02), ch2Msg) | |||
// Wait for things to settle... | |||
time.Sleep(5000 * time.Millisecond) | |||
// Check message on ch0 | |||
ch0Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x00)) | |||
if len(ch0Msgs) != 1 { | |||
t.Errorf("Expected to have received 1 message in ch0") | |||
} | |||
if !bytes.Equal(ch0Msgs[0].Bytes, wire.BinaryBytes(ch0Msg)) { | |||
t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch0Msg), ch0Msgs[0].Bytes) | |||
} | |||
// Check message on ch1 | |||
ch1Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x01)) | |||
if len(ch1Msgs) != 1 { | |||
t.Errorf("Expected to have received 1 message in ch1") | |||
} | |||
if !bytes.Equal(ch1Msgs[0].Bytes, wire.BinaryBytes(ch1Msg)) { | |||
t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch1Msg), ch1Msgs[0].Bytes) | |||
} | |||
// Check message on ch2 | |||
ch2Msgs := s2.Reactor("bar").(*TestReactor).getMsgs(byte(0x02)) | |||
if len(ch2Msgs) != 1 { | |||
t.Errorf("Expected to have received 1 message in ch2") | |||
} | |||
if !bytes.Equal(ch2Msgs[0].Bytes, wire.BinaryBytes(ch2Msg)) { | |||
t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch2Msg), ch2Msgs[0].Bytes) | |||
} | |||
} | |||
func TestConnAddrFilter(t *testing.T) { | |||
s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||
s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||
c1, c2 := net.Pipe() | |||
s1.SetAddrFilter(func(addr net.Addr) error { | |||
if addr.String() == c1.RemoteAddr().String() { | |||
return fmt.Errorf("Error: pipe is blacklisted") | |||
} | |||
return nil | |||
}) | |||
// connect to good peer | |||
go func() { | |||
s1.addPeerWithConnection(c1) | |||
}() | |||
go func() { | |||
s2.addPeerWithConnection(c2) | |||
}() | |||
// Wait for things to happen, peers to get added... | |||
time.Sleep(100 * time.Millisecond * time.Duration(4)) | |||
defer s1.Stop() | |||
defer s2.Stop() | |||
if s1.Peers().Size() != 0 { | |||
t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) | |||
} | |||
if s2.Peers().Size() != 0 { | |||
t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) | |||
} | |||
} | |||
func TestConnPubKeyFilter(t *testing.T) { | |||
s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||
s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||
c1, c2 := net.Pipe() | |||
// set pubkey filter | |||
s1.SetPubKeyFilter(func(pubkey crypto.PubKeyEd25519) error { | |||
if bytes.Equal(pubkey.Bytes(), s2.nodeInfo.PubKey.Bytes()) { | |||
return fmt.Errorf("Error: pipe is blacklisted") | |||
} | |||
return nil | |||
}) | |||
// connect to good peer | |||
go func() { | |||
s1.addPeerWithConnection(c1) | |||
}() | |||
go func() { | |||
s2.addPeerWithConnection(c2) | |||
}() | |||
// Wait for things to happen, peers to get added... | |||
time.Sleep(100 * time.Millisecond * time.Duration(4)) | |||
defer s1.Stop() | |||
defer s2.Stop() | |||
if s1.Peers().Size() != 0 { | |||
t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) | |||
} | |||
if s2.Peers().Size() != 0 { | |||
t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) | |||
} | |||
} | |||
func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||
sw.Start() | |||
defer sw.Stop() | |||
// simulate remote peer | |||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} | |||
rp.Start() | |||
defer rp.Stop() | |||
peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) | |||
require.Nil(err) | |||
err = sw.AddPeer(peer) | |||
require.Nil(err) | |||
// simulate failure by closing connection | |||
peer.CloseConn() | |||
time.Sleep(100 * time.Millisecond) | |||
assert.Zero(sw.Peers().Size()) | |||
assert.False(peer.IsRunning()) | |||
} | |||
func TestSwitchReconnectsToPersistentPeer(t *testing.T) { | |||
assert, require := assert.New(t), require.New(t) | |||
sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||
sw.Start() | |||
defer sw.Stop() | |||
// simulate remote peer | |||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} | |||
rp.Start() | |||
defer rp.Stop() | |||
peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) | |||
peer.makePersistent() | |||
require.Nil(err) | |||
err = sw.AddPeer(peer) | |||
require.Nil(err) | |||
// simulate failure by closing connection | |||
peer.CloseConn() | |||
time.Sleep(100 * time.Millisecond) | |||
assert.NotZero(sw.Peers().Size()) | |||
assert.False(peer.IsRunning()) | |||
} | |||
func BenchmarkSwitches(b *testing.B) { | |||
b.StopTimer() | |||
s1, s2 := makeSwitchPair(b, func(i int, sw *Switch) *Switch { | |||
// Make bar reactors of bar channels each | |||
sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ | |||
&ChannelDescriptor{ID: byte(0x00), Priority: 10}, | |||
&ChannelDescriptor{ID: byte(0x01), Priority: 10}, | |||
}, false)) | |||
sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ | |||
&ChannelDescriptor{ID: byte(0x02), Priority: 10}, | |||
&ChannelDescriptor{ID: byte(0x03), Priority: 10}, | |||
}, false)) | |||
return sw | |||
}) | |||
defer s1.Stop() | |||
defer s2.Stop() | |||
// Allow time for goroutines to boot up | |||
time.Sleep(1000 * time.Millisecond) | |||
b.StartTimer() | |||
numSuccess, numFailure := 0, 0 | |||
// Send random message from foo channel to another | |||
for i := 0; i < b.N; i++ { | |||
chID := byte(i % 4) | |||
successChan := s1.Broadcast(chID, "test data") | |||
for s := range successChan { | |||
if s { | |||
numSuccess++ | |||
} else { | |||
numFailure++ | |||
} | |||
} | |||
} | |||
log.Warn(Fmt("success: %v, failure: %v", numSuccess, numFailure)) | |||
// Allow everything to flush before stopping switches & closing connections. | |||
b.StopTimer() | |||
time.Sleep(1000 * time.Millisecond) | |||
} |
@ -0,0 +1,77 @@ | |||
package p2p | |||
import ( | |||
"fmt" | |||
"net" | |||
"strconv" | |||
"strings" | |||
"github.com/tendermint/go-crypto" | |||
) | |||
const maxNodeInfoSize = 10240 // 10Kb | |||
type NodeInfo struct { | |||
PubKey crypto.PubKeyEd25519 `json:"pub_key"` | |||
Moniker string `json:"moniker"` | |||
Network string `json:"network"` | |||
RemoteAddr string `json:"remote_addr"` | |||
ListenAddr string `json:"listen_addr"` | |||
Version string `json:"version"` // major.minor.revision | |||
Other []string `json:"other"` // other application specific data | |||
} | |||
// CONTRACT: two nodes are compatible if the major/minor versions match and network match | |||
func (info *NodeInfo) CompatibleWith(other *NodeInfo) error { | |||
iMajor, iMinor, _, iErr := splitVersion(info.Version) | |||
oMajor, oMinor, _, oErr := splitVersion(other.Version) | |||
// if our own version number is not formatted right, we messed up | |||
if iErr != nil { | |||
return iErr | |||
} | |||
// version number must be formatted correctly ("x.x.x") | |||
if oErr != nil { | |||
return oErr | |||
} | |||
// major version must match | |||
if iMajor != oMajor { | |||
return fmt.Errorf("Peer is on a different major version. Got %v, expected %v", oMajor, iMajor) | |||
} | |||
// minor version must match | |||
if iMinor != oMinor { | |||
return fmt.Errorf("Peer is on a different minor version. Got %v, expected %v", oMinor, iMinor) | |||
} | |||
// nodes must be on the same network | |||
if info.Network != other.Network { | |||
return fmt.Errorf("Peer is on a different network. Got %v, expected %v", other.Network, info.Network) | |||
} | |||
return nil | |||
} | |||
func (info *NodeInfo) ListenHost() string { | |||
host, _, _ := net.SplitHostPort(info.ListenAddr) | |||
return host | |||
} | |||
func (info *NodeInfo) ListenPort() int { | |||
_, port, _ := net.SplitHostPort(info.ListenAddr) | |||
port_i, err := strconv.Atoi(port) | |||
if err != nil { | |||
return -1 | |||
} | |||
return port_i | |||
} | |||
func splitVersion(version string) (string, string, string, error) { | |||
spl := strings.Split(version, ".") | |||
if len(spl) != 3 { | |||
return "", "", "", fmt.Errorf("Invalid version format %v", version) | |||
} | |||
return spl[0], spl[1], spl[2], nil | |||
} |
@ -0,0 +1,5 @@ | |||
# `tendermint/p2p/upnp` | |||
## Resources | |||
* http://www.upnp-hacks.org/upnp.html |
@ -0,0 +1,7 @@ | |||
package upnp | |||
import ( | |||
"github.com/tendermint/tmlibs/logger" | |||
) | |||
var log = logger.New("module", "upnp") |
@ -0,0 +1,111 @@ | |||
package upnp | |||
import ( | |||
"errors" | |||
"fmt" | |||
"net" | |||
"time" | |||
. "github.com/tendermint/tmlibs/common" | |||
) | |||
type UPNPCapabilities struct { | |||
PortMapping bool | |||
Hairpin bool | |||
} | |||
func makeUPNPListener(intPort int, extPort int) (NAT, net.Listener, net.IP, error) { | |||
nat, err := Discover() | |||
if err != nil { | |||
return nil, nil, nil, errors.New(fmt.Sprintf("NAT upnp could not be discovered: %v", err)) | |||
} | |||
log.Info(Fmt("ourIP: %v", nat.(*upnpNAT).ourIP)) | |||
ext, err := nat.GetExternalAddress() | |||
if err != nil { | |||
return nat, nil, nil, errors.New(fmt.Sprintf("External address error: %v", err)) | |||
} | |||
log.Info(Fmt("External address: %v", ext)) | |||
port, err := nat.AddPortMapping("tcp", extPort, intPort, "Tendermint UPnP Probe", 0) | |||
if err != nil { | |||
return nat, nil, ext, errors.New(fmt.Sprintf("Port mapping error: %v", err)) | |||
} | |||
log.Info(Fmt("Port mapping mapped: %v", port)) | |||
// also run the listener, open for all remote addresses. | |||
listener, err := net.Listen("tcp", fmt.Sprintf(":%v", intPort)) | |||
if err != nil { | |||
return nat, nil, ext, errors.New(fmt.Sprintf("Error establishing listener: %v", err)) | |||
} | |||
return nat, listener, ext, nil | |||
} | |||
func testHairpin(listener net.Listener, extAddr string) (supportsHairpin bool) { | |||
// Listener | |||
go func() { | |||
inConn, err := listener.Accept() | |||
if err != nil { | |||
log.Notice(Fmt("Listener.Accept() error: %v", err)) | |||
return | |||
} | |||
log.Info(Fmt("Accepted incoming connection: %v -> %v", inConn.LocalAddr(), inConn.RemoteAddr())) | |||
buf := make([]byte, 1024) | |||
n, err := inConn.Read(buf) | |||
if err != nil { | |||
log.Notice(Fmt("Incoming connection read error: %v", err)) | |||
return | |||
} | |||
log.Info(Fmt("Incoming connection read %v bytes: %X", n, buf)) | |||
if string(buf) == "test data" { | |||
supportsHairpin = true | |||
return | |||
} | |||
}() | |||
// Establish outgoing | |||
outConn, err := net.Dial("tcp", extAddr) | |||
if err != nil { | |||
log.Notice(Fmt("Outgoing connection dial error: %v", err)) | |||
return | |||
} | |||
n, err := outConn.Write([]byte("test data")) | |||
if err != nil { | |||
log.Notice(Fmt("Outgoing connection write error: %v", err)) | |||
return | |||
} | |||
log.Info(Fmt("Outgoing connection wrote %v bytes", n)) | |||
// Wait for data receipt | |||
time.Sleep(1 * time.Second) | |||
return | |||
} | |||
func Probe() (caps UPNPCapabilities, err error) { | |||
log.Info("Probing for UPnP!") | |||
intPort, extPort := 8001, 8001 | |||
nat, listener, ext, err := makeUPNPListener(intPort, extPort) | |||
if err != nil { | |||
return | |||
} | |||
caps.PortMapping = true | |||
// Deferred cleanup | |||
defer func() { | |||
err = nat.DeletePortMapping("tcp", intPort, extPort) | |||
if err != nil { | |||
log.Warn(Fmt("Port mapping delete error: %v", err)) | |||
} | |||
listener.Close() | |||
}() | |||
supportsHairpin := testHairpin(listener, fmt.Sprintf("%v:%v", ext, extPort)) | |||
if supportsHairpin { | |||
caps.Hairpin = true | |||
} | |||
return | |||
} |
@ -0,0 +1,380 @@ | |||
/* | |||
Taken from taipei-torrent | |||
Just enough UPnP to be able to forward ports | |||
*/ | |||
package upnp | |||
// BUG(jae): TODO: use syscalls to get actual ourIP. http://pastebin.com/9exZG4rh | |||
import ( | |||
"bytes" | |||
"encoding/xml" | |||
"errors" | |||
"io/ioutil" | |||
"net" | |||
"net/http" | |||
"strconv" | |||
"strings" | |||
"time" | |||
) | |||
type upnpNAT struct { | |||
serviceURL string | |||
ourIP string | |||
urnDomain string | |||
} | |||
// protocol is either "udp" or "tcp" | |||
type NAT interface { | |||
GetExternalAddress() (addr net.IP, err error) | |||
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) | |||
DeletePortMapping(protocol string, externalPort, internalPort int) (err error) | |||
} | |||
func Discover() (nat NAT, err error) { | |||
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") | |||
if err != nil { | |||
return | |||
} | |||
conn, err := net.ListenPacket("udp4", ":0") | |||
if err != nil { | |||
return | |||
} | |||
socket := conn.(*net.UDPConn) | |||
defer socket.Close() | |||
err = socket.SetDeadline(time.Now().Add(3 * time.Second)) | |||
if err != nil { | |||
return | |||
} | |||
st := "InternetGatewayDevice:1" | |||
buf := bytes.NewBufferString( | |||
"M-SEARCH * HTTP/1.1\r\n" + | |||
"HOST: 239.255.255.250:1900\r\n" + | |||
"ST: ssdp:all\r\n" + | |||
"MAN: \"ssdp:discover\"\r\n" + | |||
"MX: 2\r\n\r\n") | |||
message := buf.Bytes() | |||
answerBytes := make([]byte, 1024) | |||
for i := 0; i < 3; i++ { | |||
_, err = socket.WriteToUDP(message, ssdp) | |||
if err != nil { | |||
return | |||
} | |||
var n int | |||
n, _, err = socket.ReadFromUDP(answerBytes) | |||
for { | |||
n, _, err = socket.ReadFromUDP(answerBytes) | |||
if err != nil { | |||
break | |||
} | |||
answer := string(answerBytes[0:n]) | |||
if strings.Index(answer, st) < 0 { | |||
continue | |||
} | |||
// HTTP header field names are case-insensitive. | |||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 | |||
locString := "\r\nlocation:" | |||
answer = strings.ToLower(answer) | |||
locIndex := strings.Index(answer, locString) | |||
if locIndex < 0 { | |||
continue | |||
} | |||
loc := answer[locIndex+len(locString):] | |||
endIndex := strings.Index(loc, "\r\n") | |||
if endIndex < 0 { | |||
continue | |||
} | |||
locURL := strings.TrimSpace(loc[0:endIndex]) | |||
var serviceURL, urnDomain string | |||
serviceURL, urnDomain, err = getServiceURL(locURL) | |||
if err != nil { | |||
return | |||
} | |||
var ourIP net.IP | |||
ourIP, err = localIPv4() | |||
if err != nil { | |||
return | |||
} | |||
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} | |||
return | |||
} | |||
} | |||
err = errors.New("UPnP port discovery failed.") | |||
return | |||
} | |||
type Envelope struct { | |||
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` | |||
Soap *SoapBody | |||
} | |||
type SoapBody struct { | |||
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` | |||
ExternalIP *ExternalIPAddressResponse | |||
} | |||
type ExternalIPAddressResponse struct { | |||
XMLName xml.Name `xml:"GetExternalIPAddressResponse"` | |||
IPAddress string `xml:"NewExternalIPAddress"` | |||
} | |||
type ExternalIPAddress struct { | |||
XMLName xml.Name `xml:"NewExternalIPAddress"` | |||
IP string | |||
} | |||
type UPNPService struct { | |||
ServiceType string `xml:"serviceType"` | |||
ControlURL string `xml:"controlURL"` | |||
} | |||
type DeviceList struct { | |||
Device []Device `xml:"device"` | |||
} | |||
type ServiceList struct { | |||
Service []UPNPService `xml:"service"` | |||
} | |||
type Device struct { | |||
XMLName xml.Name `xml:"device"` | |||
DeviceType string `xml:"deviceType"` | |||
DeviceList DeviceList `xml:"deviceList"` | |||
ServiceList ServiceList `xml:"serviceList"` | |||
} | |||
type Root struct { | |||
Device Device | |||
} | |||
func getChildDevice(d *Device, deviceType string) *Device { | |||
dl := d.DeviceList.Device | |||
for i := 0; i < len(dl); i++ { | |||
if strings.Index(dl[i].DeviceType, deviceType) >= 0 { | |||
return &dl[i] | |||
} | |||
} | |||
return nil | |||
} | |||
func getChildService(d *Device, serviceType string) *UPNPService { | |||
sl := d.ServiceList.Service | |||
for i := 0; i < len(sl); i++ { | |||
if strings.Index(sl[i].ServiceType, serviceType) >= 0 { | |||
return &sl[i] | |||
} | |||
} | |||
return nil | |||
} | |||
func localIPv4() (net.IP, error) { | |||
tt, err := net.Interfaces() | |||
if err != nil { | |||
return nil, err | |||
} | |||
for _, t := range tt { | |||
aa, err := t.Addrs() | |||
if err != nil { | |||
return nil, err | |||
} | |||
for _, a := range aa { | |||
ipnet, ok := a.(*net.IPNet) | |||
if !ok { | |||
continue | |||
} | |||
v4 := ipnet.IP.To4() | |||
if v4 == nil || v4[0] == 127 { // loopback address | |||
continue | |||
} | |||
return v4, nil | |||
} | |||
} | |||
return nil, errors.New("cannot find local IP address") | |||
} | |||
func getServiceURL(rootURL string) (url, urnDomain string, err error) { | |||
r, err := http.Get(rootURL) | |||
if err != nil { | |||
return | |||
} | |||
defer r.Body.Close() | |||
if r.StatusCode >= 400 { | |||
err = errors.New(string(r.StatusCode)) | |||
return | |||
} | |||
var root Root | |||
err = xml.NewDecoder(r.Body).Decode(&root) | |||
if err != nil { | |||
return | |||
} | |||
a := &root.Device | |||
if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 { | |||
err = errors.New("No InternetGatewayDevice") | |||
return | |||
} | |||
b := getChildDevice(a, "WANDevice:1") | |||
if b == nil { | |||
err = errors.New("No WANDevice") | |||
return | |||
} | |||
c := getChildDevice(b, "WANConnectionDevice:1") | |||
if c == nil { | |||
err = errors.New("No WANConnectionDevice") | |||
return | |||
} | |||
d := getChildService(c, "WANIPConnection:1") | |||
if d == nil { | |||
// Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, | |||
// instead of under WanConnectionDevice | |||
d = getChildService(b, "WANIPConnection:1") | |||
if d == nil { | |||
err = errors.New("No WANIPConnection") | |||
return | |||
} | |||
} | |||
// Extract the domain name, which isn't always 'schemas-upnp-org' | |||
urnDomain = strings.Split(d.ServiceType, ":")[1] | |||
url = combineURL(rootURL, d.ControlURL) | |||
return | |||
} | |||
func combineURL(rootURL, subURL string) string { | |||
protocolEnd := "://" | |||
protoEndIndex := strings.Index(rootURL, protocolEnd) | |||
a := rootURL[protoEndIndex+len(protocolEnd):] | |||
rootIndex := strings.Index(a, "/") | |||
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL | |||
} | |||
func soapRequest(url, function, message, domain string) (r *http.Response, err error) { | |||
fullMessage := "<?xml version=\"1.0\" ?>" + | |||
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" + | |||
"<s:Body>" + message + "</s:Body></s:Envelope>" | |||
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) | |||
if err != nil { | |||
return nil, err | |||
} | |||
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") | |||
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") | |||
//req.Header.Set("Transfer-Encoding", "chunked") | |||
req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") | |||
req.Header.Set("Connection", "Close") | |||
req.Header.Set("Cache-Control", "no-cache") | |||
req.Header.Set("Pragma", "no-cache") | |||
// log.Stderr("soapRequest ", req) | |||
r, err = http.DefaultClient.Do(req) | |||
if err != nil { | |||
return nil, err | |||
} | |||
/*if r.Body != nil { | |||
defer r.Body.Close() | |||
}*/ | |||
if r.StatusCode >= 400 { | |||
// log.Stderr(function, r.StatusCode) | |||
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) | |||
r = nil | |||
return | |||
} | |||
return | |||
} | |||
type statusInfo struct { | |||
externalIpAddress string | |||
} | |||
func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) { | |||
message := "<u:GetExternalIPAddress xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + | |||
"</u:GetExternalIPAddress>" | |||
var response *http.Response | |||
response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) | |||
if response != nil { | |||
defer response.Body.Close() | |||
} | |||
if err != nil { | |||
return | |||
} | |||
var envelope Envelope | |||
data, err := ioutil.ReadAll(response.Body) | |||
reader := bytes.NewReader(data) | |||
xml.NewDecoder(reader).Decode(&envelope) | |||
info = statusInfo{envelope.Soap.ExternalIP.IPAddress} | |||
if err != nil { | |||
return | |||
} | |||
return | |||
} | |||
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { | |||
info, err := n.getExternalIPAddress() | |||
if err != nil { | |||
return | |||
} | |||
addr = net.ParseIP(info.externalIpAddress) | |||
return | |||
} | |||
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { | |||
// A single concatenation would break ARM compilation. | |||
message := "<u:AddPortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + | |||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) | |||
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" | |||
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + | |||
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + | |||
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" | |||
message += description + | |||
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + | |||
"</NewLeaseDuration></u:AddPortMapping>" | |||
var response *http.Response | |||
response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) | |||
if response != nil { | |||
defer response.Body.Close() | |||
} | |||
if err != nil { | |||
return | |||
} | |||
// TODO: check response to see if the port was forwarded | |||
// log.Println(message, response) | |||
// JAE: | |||
// body, err := ioutil.ReadAll(response.Body) | |||
// fmt.Println(string(body), err) | |||
mappedExternalPort = externalPort | |||
_ = response | |||
return | |||
} | |||
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { | |||
message := "<u:DeletePortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + | |||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + | |||
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + | |||
"</u:DeletePortMapping>" | |||
var response *http.Response | |||
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) | |||
if response != nil { | |||
defer response.Body.Close() | |||
} | |||
if err != nil { | |||
return | |||
} | |||
// TODO: check response to see if the port was deleted | |||
// log.Println(message, response) | |||
_ = response | |||
return | |||
} |
@ -0,0 +1,15 @@ | |||
package p2p | |||
import ( | |||
"crypto/sha256" | |||
) | |||
// doubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes. | |||
func doubleSha256(b []byte) []byte { | |||
hasher := sha256.New() | |||
hasher.Write(b) | |||
sum := hasher.Sum(nil) | |||
hasher.Reset() | |||
hasher.Write(sum) | |||
return hasher.Sum(nil) | |||
} |
@ -0,0 +1,3 @@ | |||
package p2p | |||
const Version = "0.5.0" |
@ -1,7 +1,7 @@ | |||
package proxy | |||
import ( | |||
"github.com/tendermint/go-logger" | |||
"github.com/tendermint/tmlibs/logger" | |||
) | |||
var log = logger.New("module", "proxy") |
@ -0,0 +1,12 @@ | |||
FROM golang:latest | |||
RUN mkdir -p /go/src/github.com/tendermint/tendermint/rpc | |||
WORKDIR /go/src/github.com/tendermint/tendermint/rpc | |||
COPY Makefile /go/src/github.com/tendermint/tendermint/rpc/ | |||
# COPY glide.yaml /go/src/github.com/tendermint/tendermint/rpc/ | |||
# COPY glide.lock /go/src/github.com/tendermint/tendermint/rpc/ | |||
COPY . /go/src/github.com/tendermint/tendermint/rpc | |||
RUN make get_deps |
@ -0,0 +1,18 @@ | |||
PACKAGES=$(shell go list ./... | grep -v "test") | |||
all: get_deps test | |||
test: | |||
@echo "--> Running go test --race" | |||
@go test --race $(PACKAGES) | |||
@echo "--> Running integration tests" | |||
@bash ./test/integration_test.sh | |||
get_deps: | |||
@echo "--> Running go get" | |||
@go get -v -d $(PACKAGES) | |||
@go list -f '{{join .TestImports "\n"}}' ./... | \ | |||
grep -v /vendor/ | sort | uniq | \ | |||
xargs go get -v -d | |||
.PHONY: all test get_deps |
@ -0,0 +1,128 @@ | |||
# tendermint/rpc | |||
[![CircleCI](https://circleci.com/gh/tendermint/tendermint/rpc.svg?style=svg)](https://circleci.com/gh/tendermint/tendermint/rpc) | |||
HTTP RPC server supporting calls via uri params, jsonrpc, and jsonrpc over websockets | |||
# Client Requests | |||
Suppose we want to expose the rpc function `HelloWorld(name string, num int)`. | |||
## GET (URI) | |||
As a GET request, it would have URI encoded parameters, and look like: | |||
``` | |||
curl 'http://localhost:8008/hello_world?name="my_world"&num=5' | |||
``` | |||
Note the `'` around the url, which is just so bash doesn't ignore the quotes in `"my_world"`. | |||
This should also work: | |||
``` | |||
curl http://localhost:8008/hello_world?name=\"my_world\"&num=5 | |||
``` | |||
A GET request to `/` returns a list of available endpoints. | |||
For those which take arguments, the arguments will be listed in order, with `_` where the actual value should be. | |||
## POST (JSONRPC) | |||
As a POST request, we use JSONRPC. For instance, the same request would have this as the body: | |||
``` | |||
{ | |||
"jsonrpc": "2.0", | |||
"id": "anything", | |||
"method": "hello_world", | |||
"params": { | |||
"name": "my_world", | |||
"num": 5 | |||
} | |||
} | |||
``` | |||
With the above saved in file `data.json`, we can make the request with | |||
``` | |||
curl --data @data.json http://localhost:8008 | |||
``` | |||
## WebSocket (JSONRPC) | |||
All requests are exposed over websocket in the same form as the POST JSONRPC. | |||
Websocket connections are available at their own endpoint, typically `/websocket`, | |||
though this is configurable when starting the server. | |||
# Server Definition | |||
Define some types and routes: | |||
``` | |||
// Define a type for results and register concrete versions with go-wire | |||
type Result interface{} | |||
type ResultStatus struct { | |||
Value string | |||
} | |||
var _ = wire.RegisterInterface( | |||
struct{ Result }{}, | |||
wire.ConcreteType{&ResultStatus{}, 0x1}, | |||
) | |||
// Define some routes | |||
var Routes = map[string]*rpcserver.RPCFunc{ | |||
"status": rpcserver.NewRPCFunc(StatusResult, "arg"), | |||
} | |||
// an rpc function | |||
func StatusResult(v string) (Result, error) { | |||
return &ResultStatus{v}, nil | |||
} | |||
``` | |||
Now start the server: | |||
``` | |||
mux := http.NewServeMux() | |||
rpcserver.RegisterRPCFuncs(mux, Routes) | |||
wm := rpcserver.NewWebsocketManager(Routes, nil) | |||
mux.HandleFunc("/websocket", wm.WebsocketHandler) | |||
go func() { | |||
_, err := rpcserver.StartHTTPServer("0.0.0.0:8008", mux) | |||
if err != nil { | |||
panic(err) | |||
} | |||
}() | |||
``` | |||
Note that unix sockets are supported as well (eg. `/path/to/socket` instead of `0.0.0.0:8008`) | |||
Now see all available endpoints by sending a GET request to `0.0.0.0:8008`. | |||
Each route is available as a GET request, as a JSONRPCv2 POST request, and via JSONRPCv2 over websockets. | |||
# Examples | |||
* [Tendermint](https://github.com/tendermint/tendermint/blob/master/rpc/core/routes.go) | |||
* [tm-monitor](https://github.com/tendermint/tools/blob/master/tm-monitor/rpc.go) | |||
## CHANGELOG | |||
### 0.7.0 | |||
BREAKING CHANGES: | |||
- removed `Client` empty interface | |||
- `ClientJSONRPC#Call` `params` argument became a map | |||
- rename `ClientURI` -> `URIClient`, `ClientJSONRPC` -> `JSONRPCClient` | |||
IMPROVEMENTS: | |||
- added `HTTPClient` interface, which can be used for both `ClientURI` | |||
and `ClientJSONRPC` | |||
- all params are now optional (Golang's default will be used if some param is missing) | |||
- added `Call` method to `WSClient` (see method's doc for details) |
@ -0,0 +1,21 @@ | |||
machine: | |||
environment: | |||
GOPATH: /home/ubuntu/.go_workspace | |||
REPO: $GOPATH/src/github.com/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME | |||
hosts: | |||
circlehost: 127.0.0.1 | |||
localhost: 127.0.0.1 | |||
checkout: | |||
post: | |||
- rm -rf $REPO | |||
- mkdir -p $HOME/.go_workspace/src/github.com/$CIRCLE_PROJECT_USERNAME | |||
- mv $HOME/$CIRCLE_PROJECT_REPONAME $REPO | |||
dependencies: | |||
override: | |||
- "cd $REPO && make get_deps" | |||
test: | |||
override: | |||
- "cd $REPO && make test" |
@ -0,0 +1,39 @@ | |||
package rpcclient | |||
import ( | |||
"testing" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
) | |||
type Tx []byte | |||
type Foo struct { | |||
Bar int | |||
Baz string | |||
} | |||
func TestArgToJSON(t *testing.T) { | |||
assert := assert.New(t) | |||
require := require.New(t) | |||
cases := []struct { | |||
input interface{} | |||
expected string | |||
}{ | |||
{[]byte("1234"), "0x31323334"}, | |||
{Tx("654"), "0x363534"}, | |||
{Foo{7, "hello"}, `{"Bar":7,"Baz":"hello"}`}, | |||
} | |||
for i, tc := range cases { | |||
args := map[string]interface{}{"data": tc.input} | |||
err := argsToJson(args) | |||
require.Nil(err, "%d: %+v", i, err) | |||
require.Equal(1, len(args), "%d", i) | |||
data, ok := args["data"].(string) | |||
require.True(ok, "%d: %#v", i, args["data"]) | |||
assert.Equal(tc.expected, data, "%d", i) | |||
} | |||
} |
@ -0,0 +1,199 @@ | |||
package rpcclient | |||
import ( | |||
"bytes" | |||
"encoding/json" | |||
"fmt" | |||
"io/ioutil" | |||
"net" | |||
"net/http" | |||
"net/url" | |||
"reflect" | |||
"strings" | |||
"github.com/pkg/errors" | |||
wire "github.com/tendermint/go-wire" | |||
types "github.com/tendermint/tendermint/rpc/types" | |||
) | |||
// HTTPClient is a common interface for JSONRPCClient and URIClient. | |||
type HTTPClient interface { | |||
Call(method string, params map[string]interface{}, result interface{}) (interface{}, error) | |||
} | |||
// TODO: Deprecate support for IP:PORT or /path/to/socket | |||
func makeHTTPDialer(remoteAddr string) (string, func(string, string) (net.Conn, error)) { | |||
parts := strings.SplitN(remoteAddr, "://", 2) | |||
var protocol, address string | |||
if len(parts) != 2 { | |||
log.Warn("WARNING (tendermint/rpc): Please use fully formed listening addresses, including the tcp:// or unix:// prefix") | |||
protocol = types.SocketType(remoteAddr) | |||
address = remoteAddr | |||
} else { | |||
protocol, address = parts[0], parts[1] | |||
} | |||
trimmedAddress := strings.Replace(address, "/", ".", -1) // replace / with . for http requests (dummy domain) | |||
return trimmedAddress, func(proto, addr string) (net.Conn, error) { | |||
return net.Dial(protocol, address) | |||
} | |||
} | |||
// We overwrite the http.Client.Dial so we can do http over tcp or unix. | |||
// remoteAddr should be fully featured (eg. with tcp:// or unix://) | |||
func makeHTTPClient(remoteAddr string) (string, *http.Client) { | |||
address, dialer := makeHTTPDialer(remoteAddr) | |||
return "http://" + address, &http.Client{ | |||
Transport: &http.Transport{ | |||
Dial: dialer, | |||
}, | |||
} | |||
} | |||
//------------------------------------------------------------------------------------ | |||
// JSON rpc takes params as a slice | |||
type JSONRPCClient struct { | |||
address string | |||
client *http.Client | |||
} | |||
func NewJSONRPCClient(remote string) *JSONRPCClient { | |||
address, client := makeHTTPClient(remote) | |||
return &JSONRPCClient{ | |||
address: address, | |||
client: client, | |||
} | |||
} | |||
func (c *JSONRPCClient) Call(method string, params map[string]interface{}, result interface{}) (interface{}, error) { | |||
// we need this step because we attempt to decode values using `go-wire` | |||
// (handlers.go:176) on the server side | |||
encodedParams := make(map[string]interface{}) | |||
for k, v := range params { | |||
bytes := json.RawMessage(wire.JSONBytes(v)) | |||
encodedParams[k] = &bytes | |||
} | |||
request := types.RPCRequest{ | |||
JSONRPC: "2.0", | |||
Method: method, | |||
Params: encodedParams, | |||
ID: "", | |||
} | |||
requestBytes, err := json.Marshal(request) | |||
if err != nil { | |||
return nil, err | |||
} | |||
// log.Info(string(requestBytes)) | |||
requestBuf := bytes.NewBuffer(requestBytes) | |||
// log.Info(Fmt("RPC request to %v (%v): %v", c.remote, method, string(requestBytes))) | |||
httpResponse, err := c.client.Post(c.address, "text/json", requestBuf) | |||
if err != nil { | |||
return nil, err | |||
} | |||
defer httpResponse.Body.Close() | |||
responseBytes, err := ioutil.ReadAll(httpResponse.Body) | |||
if err != nil { | |||
return nil, err | |||
} | |||
// log.Info(Fmt("RPC response: %v", string(responseBytes))) | |||
return unmarshalResponseBytes(responseBytes, result) | |||
} | |||
//------------------------------------------------------------- | |||
// URI takes params as a map | |||
type URIClient struct { | |||
address string | |||
client *http.Client | |||
} | |||
func NewURIClient(remote string) *URIClient { | |||
address, client := makeHTTPClient(remote) | |||
return &URIClient{ | |||
address: address, | |||
client: client, | |||
} | |||
} | |||
func (c *URIClient) Call(method string, params map[string]interface{}, result interface{}) (interface{}, error) { | |||
values, err := argsToURLValues(params) | |||
if err != nil { | |||
return nil, err | |||
} | |||
// log.Info(Fmt("URI request to %v (%v): %v", c.address, method, values)) | |||
resp, err := c.client.PostForm(c.address+"/"+method, values) | |||
if err != nil { | |||
return nil, err | |||
} | |||
defer resp.Body.Close() | |||
responseBytes, err := ioutil.ReadAll(resp.Body) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return unmarshalResponseBytes(responseBytes, result) | |||
} | |||
//------------------------------------------------ | |||
func unmarshalResponseBytes(responseBytes []byte, result interface{}) (interface{}, error) { | |||
// read response | |||
// if rpc/core/types is imported, the result will unmarshal | |||
// into the correct type | |||
// log.Notice("response", "response", string(responseBytes)) | |||
var err error | |||
response := &types.RPCResponse{} | |||
err = json.Unmarshal(responseBytes, response) | |||
if err != nil { | |||
return nil, errors.Errorf("Error unmarshalling rpc response: %v", err) | |||
} | |||
errorStr := response.Error | |||
if errorStr != "" { | |||
return nil, errors.Errorf("Response error: %v", errorStr) | |||
} | |||
// unmarshal the RawMessage into the result | |||
result = wire.ReadJSONPtr(result, *response.Result, &err) | |||
if err != nil { | |||
return nil, errors.Errorf("Error unmarshalling rpc response result: %v", err) | |||
} | |||
return result, nil | |||
} | |||
func argsToURLValues(args map[string]interface{}) (url.Values, error) { | |||
values := make(url.Values) | |||
if len(args) == 0 { | |||
return values, nil | |||
} | |||
err := argsToJson(args) | |||
if err != nil { | |||
return nil, err | |||
} | |||
for key, val := range args { | |||
values.Set(key, val.(string)) | |||
} | |||
return values, nil | |||
} | |||
func argsToJson(args map[string]interface{}) error { | |||
var n int | |||
var err error | |||
for k, v := range args { | |||
rt := reflect.TypeOf(v) | |||
isByteSlice := rt.Kind() == reflect.Slice && rt.Elem().Kind() == reflect.Uint8 | |||
if isByteSlice { | |||
bytes := reflect.ValueOf(v).Bytes() | |||
args[k] = fmt.Sprintf("0x%X", bytes) | |||
continue | |||
} | |||
// Pass everything else to go-wire | |||
buf := new(bytes.Buffer) | |||
wire.WriteJSON(v, buf, &n, &err) | |||
if err != nil { | |||
return err | |||
} | |||
args[k] = buf.String() | |||
} | |||
return nil | |||
} |
@ -0,0 +1,7 @@ | |||
package rpcclient | |||
import ( | |||
"github.com/tendermint/log15" | |||
) | |||
var log = log15.New("module", "rpcclient") |
@ -0,0 +1,172 @@ | |||
package rpcclient | |||
import ( | |||
"encoding/json" | |||
"net" | |||
"net/http" | |||
"time" | |||
"github.com/gorilla/websocket" | |||
"github.com/pkg/errors" | |||
cmn "github.com/tendermint/tmlibs/common" | |||
types "github.com/tendermint/tendermint/rpc/types" | |||
wire "github.com/tendermint/go-wire" | |||
) | |||
const ( | |||
wsResultsChannelCapacity = 10 | |||
wsErrorsChannelCapacity = 1 | |||
wsWriteTimeoutSeconds = 10 | |||
) | |||
type WSClient struct { | |||
cmn.BaseService | |||
Address string // IP:PORT or /path/to/socket | |||
Endpoint string // /websocket/url/endpoint | |||
Dialer func(string, string) (net.Conn, error) | |||
*websocket.Conn | |||
ResultsCh chan json.RawMessage // closes upon WSClient.Stop() | |||
ErrorsCh chan error // closes upon WSClient.Stop() | |||
} | |||
// create a new connection | |||
func NewWSClient(remoteAddr, endpoint string) *WSClient { | |||
addr, dialer := makeHTTPDialer(remoteAddr) | |||
wsClient := &WSClient{ | |||
Address: addr, | |||
Dialer: dialer, | |||
Endpoint: endpoint, | |||
Conn: nil, | |||
} | |||
wsClient.BaseService = *cmn.NewBaseService(log, "WSClient", wsClient) | |||
return wsClient | |||
} | |||
func (wsc *WSClient) String() string { | |||
return wsc.Address + ", " + wsc.Endpoint | |||
} | |||
// OnStart implements cmn.BaseService interface | |||
func (wsc *WSClient) OnStart() error { | |||
wsc.BaseService.OnStart() | |||
err := wsc.dial() | |||
if err != nil { | |||
return err | |||
} | |||
wsc.ResultsCh = make(chan json.RawMessage, wsResultsChannelCapacity) | |||
wsc.ErrorsCh = make(chan error, wsErrorsChannelCapacity) | |||
go wsc.receiveEventsRoutine() | |||
return nil | |||
} | |||
// OnReset implements cmn.BaseService interface | |||
func (wsc *WSClient) OnReset() error { | |||
return nil | |||
} | |||
func (wsc *WSClient) dial() error { | |||
// Dial | |||
dialer := &websocket.Dialer{ | |||
NetDial: wsc.Dialer, | |||
Proxy: http.ProxyFromEnvironment, | |||
} | |||
rHeader := http.Header{} | |||
con, _, err := dialer.Dial("ws://"+wsc.Address+wsc.Endpoint, rHeader) | |||
if err != nil { | |||
return err | |||
} | |||
// Set the ping/pong handlers | |||
con.SetPingHandler(func(m string) error { | |||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||
go con.WriteControl(websocket.PongMessage, []byte(m), time.Now().Add(time.Second*wsWriteTimeoutSeconds)) | |||
return nil | |||
}) | |||
con.SetPongHandler(func(m string) error { | |||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||
return nil | |||
}) | |||
wsc.Conn = con | |||
return nil | |||
} | |||
// OnStop implements cmn.BaseService interface | |||
func (wsc *WSClient) OnStop() { | |||
wsc.BaseService.OnStop() | |||
wsc.Conn.Close() | |||
// ResultsCh/ErrorsCh is closed in receiveEventsRoutine. | |||
} | |||
func (wsc *WSClient) receiveEventsRoutine() { | |||
for { | |||
_, data, err := wsc.ReadMessage() | |||
if err != nil { | |||
log.Info("WSClient failed to read message", "error", err, "data", string(data)) | |||
wsc.Stop() | |||
break | |||
} else { | |||
var response types.RPCResponse | |||
err := json.Unmarshal(data, &response) | |||
if err != nil { | |||
log.Info("WSClient failed to parse message", "error", err, "data", string(data)) | |||
wsc.ErrorsCh <- err | |||
continue | |||
} | |||
if response.Error != "" { | |||
wsc.ErrorsCh <- errors.Errorf(response.Error) | |||
continue | |||
} | |||
wsc.ResultsCh <- *response.Result | |||
} | |||
} | |||
// this must be modified in the same go-routine that reads from the | |||
// connection to avoid race conditions | |||
wsc.Conn = nil | |||
// Cleanup | |||
close(wsc.ResultsCh) | |||
close(wsc.ErrorsCh) | |||
} | |||
// Subscribe to an event. Note the server must have a "subscribe" route | |||
// defined. | |||
func (wsc *WSClient) Subscribe(eventid string) error { | |||
err := wsc.WriteJSON(types.RPCRequest{ | |||
JSONRPC: "2.0", | |||
ID: "", | |||
Method: "subscribe", | |||
Params: map[string]interface{}{"event": eventid}, | |||
}) | |||
return err | |||
} | |||
// Unsubscribe from an event. Note the server must have a "unsubscribe" route | |||
// defined. | |||
func (wsc *WSClient) Unsubscribe(eventid string) error { | |||
err := wsc.WriteJSON(types.RPCRequest{ | |||
JSONRPC: "2.0", | |||
ID: "", | |||
Method: "unsubscribe", | |||
Params: map[string]interface{}{"event": eventid}, | |||
}) | |||
return err | |||
} | |||
// Call asynchronously calls a given method by sending an RPCRequest to the | |||
// server. Results will be available on ResultsCh, errors, if any, on ErrorsCh. | |||
func (wsc *WSClient) Call(method string, params map[string]interface{}) error { | |||
// we need this step because we attempt to decode values using `go-wire` | |||
// (handlers.go:470) on the server side | |||
encodedParams := make(map[string]interface{}) | |||
for k, v := range params { | |||
bytes := json.RawMessage(wire.JSONBytes(v)) | |||
encodedParams[k] = &bytes | |||
} | |||
err := wsc.WriteJSON(types.RPCRequest{ | |||
JSONRPC: "2.0", | |||
Method: method, | |||
Params: encodedParams, | |||
ID: "", | |||
}) | |||
return err | |||
} |
@ -0,0 +1,298 @@ | |||
package rpc | |||
import ( | |||
"bytes" | |||
crand "crypto/rand" | |||
"fmt" | |||
"math/rand" | |||
"net/http" | |||
"os/exec" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
"github.com/stretchr/testify/require" | |||
wire "github.com/tendermint/go-wire" | |||
client "github.com/tendermint/tendermint/rpc/client" | |||
server "github.com/tendermint/tendermint/rpc/server" | |||
types "github.com/tendermint/tendermint/rpc/types" | |||
) | |||
// Client and Server should work over tcp or unix sockets | |||
const ( | |||
tcpAddr = "tcp://0.0.0.0:46657" | |||
unixSocket = "/tmp/rpc.sock" | |||
unixAddr = "unix:///tmp/rpc.sock" | |||
websocketEndpoint = "/websocket/endpoint" | |||
) | |||
// Define a type for results and register concrete versions | |||
type Result interface{} | |||
type ResultEcho struct { | |||
Value string | |||
} | |||
type ResultEchoBytes struct { | |||
Value []byte | |||
} | |||
var _ = wire.RegisterInterface( | |||
struct{ Result }{}, | |||
wire.ConcreteType{&ResultEcho{}, 0x1}, | |||
wire.ConcreteType{&ResultEchoBytes{}, 0x2}, | |||
) | |||
// Define some routes | |||
var Routes = map[string]*server.RPCFunc{ | |||
"echo": server.NewRPCFunc(EchoResult, "arg"), | |||
"echo_ws": server.NewWSRPCFunc(EchoWSResult, "arg"), | |||
"echo_bytes": server.NewRPCFunc(EchoBytesResult, "arg"), | |||
} | |||
func EchoResult(v string) (Result, error) { | |||
return &ResultEcho{v}, nil | |||
} | |||
func EchoWSResult(wsCtx types.WSRPCContext, v string) (Result, error) { | |||
return &ResultEcho{v}, nil | |||
} | |||
func EchoBytesResult(v []byte) (Result, error) { | |||
return &ResultEchoBytes{v}, nil | |||
} | |||
// launch unix and tcp servers | |||
func init() { | |||
cmd := exec.Command("rm", "-f", unixSocket) | |||
err := cmd.Start() | |||
if err != nil { | |||
panic(err) | |||
} | |||
if err = cmd.Wait(); err != nil { | |||
panic(err) | |||
} | |||
mux := http.NewServeMux() | |||
server.RegisterRPCFuncs(mux, Routes) | |||
wm := server.NewWebsocketManager(Routes, nil) | |||
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler) | |||
go func() { | |||
_, err := server.StartHTTPServer(tcpAddr, mux) | |||
if err != nil { | |||
panic(err) | |||
} | |||
}() | |||
mux2 := http.NewServeMux() | |||
server.RegisterRPCFuncs(mux2, Routes) | |||
wm = server.NewWebsocketManager(Routes, nil) | |||
mux2.HandleFunc(websocketEndpoint, wm.WebsocketHandler) | |||
go func() { | |||
_, err := server.StartHTTPServer(unixAddr, mux2) | |||
if err != nil { | |||
panic(err) | |||
} | |||
}() | |||
// wait for servers to start | |||
time.Sleep(time.Second * 2) | |||
} | |||
func echoViaHTTP(cl client.HTTPClient, val string) (string, error) { | |||
params := map[string]interface{}{ | |||
"arg": val, | |||
} | |||
var result Result | |||
if _, err := cl.Call("echo", params, &result); err != nil { | |||
return "", err | |||
} | |||
return result.(*ResultEcho).Value, nil | |||
} | |||
func echoBytesViaHTTP(cl client.HTTPClient, bytes []byte) ([]byte, error) { | |||
params := map[string]interface{}{ | |||
"arg": bytes, | |||
} | |||
var result Result | |||
if _, err := cl.Call("echo_bytes", params, &result); err != nil { | |||
return []byte{}, err | |||
} | |||
return result.(*ResultEchoBytes).Value, nil | |||
} | |||
func testWithHTTPClient(t *testing.T, cl client.HTTPClient) { | |||
val := "acbd" | |||
got, err := echoViaHTTP(cl, val) | |||
require.Nil(t, err) | |||
assert.Equal(t, got, val) | |||
val2 := randBytes(t) | |||
got2, err := echoBytesViaHTTP(cl, val2) | |||
require.Nil(t, err) | |||
assert.Equal(t, got2, val2) | |||
} | |||
func echoViaWS(cl *client.WSClient, val string) (string, error) { | |||
params := map[string]interface{}{ | |||
"arg": val, | |||
} | |||
err := cl.Call("echo", params) | |||
if err != nil { | |||
return "", err | |||
} | |||
select { | |||
case msg := <-cl.ResultsCh: | |||
result := new(Result) | |||
wire.ReadJSONPtr(result, msg, &err) | |||
if err != nil { | |||
return "", nil | |||
} | |||
return (*result).(*ResultEcho).Value, nil | |||
case err := <-cl.ErrorsCh: | |||
return "", err | |||
} | |||
} | |||
func echoBytesViaWS(cl *client.WSClient, bytes []byte) ([]byte, error) { | |||
params := map[string]interface{}{ | |||
"arg": bytes, | |||
} | |||
err := cl.Call("echo_bytes", params) | |||
if err != nil { | |||
return []byte{}, err | |||
} | |||
select { | |||
case msg := <-cl.ResultsCh: | |||
result := new(Result) | |||
wire.ReadJSONPtr(result, msg, &err) | |||
if err != nil { | |||
return []byte{}, nil | |||
} | |||
return (*result).(*ResultEchoBytes).Value, nil | |||
case err := <-cl.ErrorsCh: | |||
return []byte{}, err | |||
} | |||
} | |||
func testWithWSClient(t *testing.T, cl *client.WSClient) { | |||
val := "acbd" | |||
got, err := echoViaWS(cl, val) | |||
require.Nil(t, err) | |||
assert.Equal(t, got, val) | |||
val2 := randBytes(t) | |||
got2, err := echoBytesViaWS(cl, val2) | |||
require.Nil(t, err) | |||
assert.Equal(t, got2, val2) | |||
} | |||
//------------- | |||
func TestServersAndClientsBasic(t *testing.T) { | |||
serverAddrs := [...]string{tcpAddr, unixAddr} | |||
for _, addr := range serverAddrs { | |||
cl1 := client.NewURIClient(addr) | |||
fmt.Printf("=== testing server on %s using %v client", addr, cl1) | |||
testWithHTTPClient(t, cl1) | |||
cl2 := client.NewJSONRPCClient(tcpAddr) | |||
fmt.Printf("=== testing server on %s using %v client", addr, cl2) | |||
testWithHTTPClient(t, cl2) | |||
cl3 := client.NewWSClient(tcpAddr, websocketEndpoint) | |||
_, err := cl3.Start() | |||
require.Nil(t, err) | |||
fmt.Printf("=== testing server on %s using %v client", addr, cl3) | |||
testWithWSClient(t, cl3) | |||
cl3.Stop() | |||
} | |||
} | |||
func TestHexStringArg(t *testing.T) { | |||
cl := client.NewURIClient(tcpAddr) | |||
// should NOT be handled as hex | |||
val := "0xabc" | |||
got, err := echoViaHTTP(cl, val) | |||
require.Nil(t, err) | |||
assert.Equal(t, got, val) | |||
} | |||
func TestQuotedStringArg(t *testing.T) { | |||
cl := client.NewURIClient(tcpAddr) | |||
// should NOT be unquoted | |||
val := "\"abc\"" | |||
got, err := echoViaHTTP(cl, val) | |||
require.Nil(t, err) | |||
assert.Equal(t, got, val) | |||
} | |||
func TestWSNewWSRPCFunc(t *testing.T) { | |||
cl := client.NewWSClient(tcpAddr, websocketEndpoint) | |||
_, err := cl.Start() | |||
require.Nil(t, err) | |||
defer cl.Stop() | |||
val := "acbd" | |||
params := map[string]interface{}{ | |||
"arg": val, | |||
} | |||
err = cl.WriteJSON(types.RPCRequest{ | |||
JSONRPC: "2.0", | |||
ID: "", | |||
Method: "echo_ws", | |||
Params: params, | |||
}) | |||
require.Nil(t, err) | |||
select { | |||
case msg := <-cl.ResultsCh: | |||
result := new(Result) | |||
wire.ReadJSONPtr(result, msg, &err) | |||
require.Nil(t, err) | |||
got := (*result).(*ResultEcho).Value | |||
assert.Equal(t, got, val) | |||
case err := <-cl.ErrorsCh: | |||
t.Fatal(err) | |||
} | |||
} | |||
func TestWSHandlesArrayParams(t *testing.T) { | |||
cl := client.NewWSClient(tcpAddr, websocketEndpoint) | |||
_, err := cl.Start() | |||
require.Nil(t, err) | |||
defer cl.Stop() | |||
val := "acbd" | |||
params := []interface{}{val} | |||
err = cl.WriteJSON(types.RPCRequest{ | |||
JSONRPC: "2.0", | |||
ID: "", | |||
Method: "echo_ws", | |||
Params: params, | |||
}) | |||
require.Nil(t, err) | |||
select { | |||
case msg := <-cl.ResultsCh: | |||
result := new(Result) | |||
wire.ReadJSONPtr(result, msg, &err) | |||
require.Nil(t, err) | |||
got := (*result).(*ResultEcho).Value | |||
assert.Equal(t, got, val) | |||
case err := <-cl.ErrorsCh: | |||
t.Fatalf("%+v", err) | |||
} | |||
} | |||
func randBytes(t *testing.T) []byte { | |||
n := rand.Intn(10) + 2 | |||
buf := make([]byte, n) | |||
_, err := crand.Read(buf) | |||
require.Nil(t, err) | |||
return bytes.Replace(buf, []byte("="), []byte{100}, -1) | |||
} |
@ -0,0 +1,649 @@ | |||
package rpcserver | |||
import ( | |||
"bytes" | |||
"encoding/hex" | |||
"encoding/json" | |||
"fmt" | |||
"io/ioutil" | |||
"net/http" | |||
"reflect" | |||
"sort" | |||
"strings" | |||
"time" | |||
"github.com/gorilla/websocket" | |||
"github.com/pkg/errors" | |||
types "github.com/tendermint/tendermint/rpc/types" | |||
wire "github.com/tendermint/go-wire" | |||
cmn "github.com/tendermint/tmlibs/common" | |||
events "github.com/tendermint/tmlibs/events" | |||
) | |||
// Adds a route for each function in the funcMap, as well as general jsonrpc and websocket handlers for all functions. | |||
// "result" is the interface on which the result objects are registered, and is popualted with every RPCResponse | |||
func RegisterRPCFuncs(mux *http.ServeMux, funcMap map[string]*RPCFunc) { | |||
// HTTP endpoints | |||
for funcName, rpcFunc := range funcMap { | |||
mux.HandleFunc("/"+funcName, makeHTTPHandler(rpcFunc)) | |||
} | |||
// JSONRPC endpoints | |||
mux.HandleFunc("/", makeJSONRPCHandler(funcMap)) | |||
} | |||
//------------------------------------- | |||
// function introspection | |||
// holds all type information for each function | |||
type RPCFunc struct { | |||
f reflect.Value // underlying rpc function | |||
args []reflect.Type // type of each function arg | |||
returns []reflect.Type // type of each return arg | |||
argNames []string // name of each argument | |||
ws bool // websocket only | |||
} | |||
// wraps a function for quicker introspection | |||
// f is the function, args are comma separated argument names | |||
func NewRPCFunc(f interface{}, args string) *RPCFunc { | |||
return newRPCFunc(f, args, false) | |||
} | |||
func NewWSRPCFunc(f interface{}, args string) *RPCFunc { | |||
return newRPCFunc(f, args, true) | |||
} | |||
func newRPCFunc(f interface{}, args string, ws bool) *RPCFunc { | |||
var argNames []string | |||
if args != "" { | |||
argNames = strings.Split(args, ",") | |||
} | |||
return &RPCFunc{ | |||
f: reflect.ValueOf(f), | |||
args: funcArgTypes(f), | |||
returns: funcReturnTypes(f), | |||
argNames: argNames, | |||
ws: ws, | |||
} | |||
} | |||
// return a function's argument types | |||
func funcArgTypes(f interface{}) []reflect.Type { | |||
t := reflect.TypeOf(f) | |||
n := t.NumIn() | |||
typez := make([]reflect.Type, n) | |||
for i := 0; i < n; i++ { | |||
typez[i] = t.In(i) | |||
} | |||
return typez | |||
} | |||
// return a function's return types | |||
func funcReturnTypes(f interface{}) []reflect.Type { | |||
t := reflect.TypeOf(f) | |||
n := t.NumOut() | |||
typez := make([]reflect.Type, n) | |||
for i := 0; i < n; i++ { | |||
typez[i] = t.Out(i) | |||
} | |||
return typez | |||
} | |||
// function introspection | |||
//----------------------------------------------------------------------------- | |||
// rpc.json | |||
// jsonrpc calls grab the given method's function info and runs reflect.Call | |||
func makeJSONRPCHandler(funcMap map[string]*RPCFunc) http.HandlerFunc { | |||
return func(w http.ResponseWriter, r *http.Request) { | |||
b, _ := ioutil.ReadAll(r.Body) | |||
// if its an empty request (like from a browser), | |||
// just display a list of functions | |||
if len(b) == 0 { | |||
writeListOfEndpoints(w, r, funcMap) | |||
return | |||
} | |||
var request types.RPCRequest | |||
err := json.Unmarshal(b, &request) | |||
if err != nil { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, fmt.Sprintf("Error unmarshalling request: %v", err.Error()))) | |||
return | |||
} | |||
if len(r.URL.Path) > 1 { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, fmt.Sprintf("Invalid JSONRPC endpoint %s", r.URL.Path))) | |||
return | |||
} | |||
rpcFunc := funcMap[request.Method] | |||
if rpcFunc == nil { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, "RPC method unknown: "+request.Method)) | |||
return | |||
} | |||
if rpcFunc.ws { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, "RPC method is only for websockets: "+request.Method)) | |||
return | |||
} | |||
args, err := jsonParamsToArgsRPC(rpcFunc, request.Params) | |||
if err != nil { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, fmt.Sprintf("Error converting json params to arguments: %v", err.Error()))) | |||
return | |||
} | |||
returns := rpcFunc.f.Call(args) | |||
log.Info("HTTPJSONRPC", "method", request.Method, "args", args, "returns", returns) | |||
result, err := unreflectResult(returns) | |||
if err != nil { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, result, err.Error())) | |||
return | |||
} | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, result, "")) | |||
} | |||
} | |||
// Convert a []interface{} OR a map[string]interface{} to properly typed values | |||
// | |||
// argsOffset should be 0 for RPC calls, and 1 for WS requests, where len(rpcFunc.args) != len(rpcFunc.argNames). | |||
// Example: | |||
// rpcFunc.args = [rpctypes.WSRPCContext string] | |||
// rpcFunc.argNames = ["arg"] | |||
func jsonParamsToArgs(rpcFunc *RPCFunc, paramsI interface{}, argsOffset int) ([]reflect.Value, error) { | |||
values := make([]reflect.Value, len(rpcFunc.argNames)) | |||
switch params := paramsI.(type) { | |||
case map[string]interface{}: | |||
for i, argName := range rpcFunc.argNames { | |||
argType := rpcFunc.args[i+argsOffset] | |||
// decode param if provided | |||
if param, ok := params[argName]; ok && "" != param { | |||
v, err := _jsonObjectToArg(argType, param) | |||
if err != nil { | |||
return nil, err | |||
} | |||
values[i] = v | |||
} else { // use default for that type | |||
values[i] = reflect.Zero(argType) | |||
} | |||
} | |||
case []interface{}: | |||
if len(rpcFunc.argNames) != len(params) { | |||
return nil, errors.New(fmt.Sprintf("Expected %v parameters (%v), got %v (%v)", | |||
len(rpcFunc.argNames), rpcFunc.argNames, len(params), params)) | |||
} | |||
values := make([]reflect.Value, len(params)) | |||
for i, p := range params { | |||
ty := rpcFunc.args[i+argsOffset] | |||
v, err := _jsonObjectToArg(ty, p) | |||
if err != nil { | |||
return nil, err | |||
} | |||
values[i] = v | |||
} | |||
return values, nil | |||
default: | |||
return nil, fmt.Errorf("Unknown type for JSON params %v. Expected map[string]interface{} or []interface{}", reflect.TypeOf(paramsI)) | |||
} | |||
return values, nil | |||
} | |||
// Convert a []interface{} OR a map[string]interface{} to properly typed values | |||
func jsonParamsToArgsRPC(rpcFunc *RPCFunc, paramsI interface{}) ([]reflect.Value, error) { | |||
return jsonParamsToArgs(rpcFunc, paramsI, 0) | |||
} | |||
// Same as above, but with the first param the websocket connection | |||
func jsonParamsToArgsWS(rpcFunc *RPCFunc, paramsI interface{}, wsCtx types.WSRPCContext) ([]reflect.Value, error) { | |||
values, err := jsonParamsToArgs(rpcFunc, paramsI, 1) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return append([]reflect.Value{reflect.ValueOf(wsCtx)}, values...), nil | |||
} | |||
func _jsonObjectToArg(ty reflect.Type, object interface{}) (reflect.Value, error) { | |||
var err error | |||
v := reflect.New(ty) | |||
wire.ReadJSONObjectPtr(v.Interface(), object, &err) | |||
if err != nil { | |||
return v, err | |||
} | |||
v = v.Elem() | |||
return v, nil | |||
} | |||
// rpc.json | |||
//----------------------------------------------------------------------------- | |||
// rpc.http | |||
// convert from a function name to the http handler | |||
func makeHTTPHandler(rpcFunc *RPCFunc) func(http.ResponseWriter, *http.Request) { | |||
// Exception for websocket endpoints | |||
if rpcFunc.ws { | |||
return func(w http.ResponseWriter, r *http.Request) { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, "This RPC method is only for websockets")) | |||
} | |||
} | |||
// All other endpoints | |||
return func(w http.ResponseWriter, r *http.Request) { | |||
log.Debug("HTTP HANDLER", "req", r) | |||
args, err := httpParamsToArgs(rpcFunc, r) | |||
if err != nil { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, fmt.Sprintf("Error converting http params to args: %v", err.Error()))) | |||
return | |||
} | |||
returns := rpcFunc.f.Call(args) | |||
log.Info("HTTPRestRPC", "method", r.URL.Path, "args", args, "returns", returns) | |||
result, err := unreflectResult(returns) | |||
if err != nil { | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, err.Error())) | |||
return | |||
} | |||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", result, "")) | |||
} | |||
} | |||
// Covert an http query to a list of properly typed values. | |||
// To be properly decoded the arg must be a concrete type from tendermint (if its an interface). | |||
func httpParamsToArgs(rpcFunc *RPCFunc, r *http.Request) ([]reflect.Value, error) { | |||
values := make([]reflect.Value, len(rpcFunc.args)) | |||
for i, name := range rpcFunc.argNames { | |||
argType := rpcFunc.args[i] | |||
values[i] = reflect.Zero(argType) // set default for that type | |||
arg := GetParam(r, name) | |||
// log.Notice("param to arg", "argType", argType, "name", name, "arg", arg) | |||
if "" == arg { | |||
continue | |||
} | |||
v, err, ok := nonJsonToArg(argType, arg) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if ok { | |||
values[i] = v | |||
continue | |||
} | |||
// Pass values to go-wire | |||
values[i], err = _jsonStringToArg(argType, arg) | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
return values, nil | |||
} | |||
func _jsonStringToArg(ty reflect.Type, arg string) (reflect.Value, error) { | |||
var err error | |||
v := reflect.New(ty) | |||
wire.ReadJSONPtr(v.Interface(), []byte(arg), &err) | |||
if err != nil { | |||
return v, err | |||
} | |||
v = v.Elem() | |||
return v, nil | |||
} | |||
func nonJsonToArg(ty reflect.Type, arg string) (reflect.Value, error, bool) { | |||
isQuotedString := strings.HasPrefix(arg, `"`) && strings.HasSuffix(arg, `"`) | |||
isHexString := strings.HasPrefix(strings.ToLower(arg), "0x") | |||
expectingString := ty.Kind() == reflect.String | |||
expectingByteSlice := ty.Kind() == reflect.Slice && ty.Elem().Kind() == reflect.Uint8 | |||
if isHexString { | |||
if !expectingString && !expectingByteSlice { | |||
err := errors.Errorf("Got a hex string arg, but expected '%s'", | |||
ty.Kind().String()) | |||
return reflect.ValueOf(nil), err, false | |||
} | |||
var value []byte | |||
value, err := hex.DecodeString(arg[2:]) | |||
if err != nil { | |||
return reflect.ValueOf(nil), err, false | |||
} | |||
if ty.Kind() == reflect.String { | |||
return reflect.ValueOf(string(value)), nil, true | |||
} | |||
return reflect.ValueOf([]byte(value)), nil, true | |||
} | |||
if isQuotedString && expectingByteSlice { | |||
var err error | |||
v := reflect.New(reflect.TypeOf("")) | |||
wire.ReadJSONPtr(v.Interface(), []byte(arg), &err) | |||
if err != nil { | |||
return reflect.ValueOf(nil), err, false | |||
} | |||
v = v.Elem() | |||
return reflect.ValueOf([]byte(v.String())), nil, true | |||
} | |||
return reflect.ValueOf(nil), nil, false | |||
} | |||
// rpc.http | |||
//----------------------------------------------------------------------------- | |||
// rpc.websocket | |||
const ( | |||
writeChanCapacity = 1000 | |||
wsWriteTimeoutSeconds = 30 // each write times out after this | |||
wsReadTimeoutSeconds = 30 // connection times out if we haven't received *anything* in this long, not even pings. | |||
wsPingTickerSeconds = 10 // send a ping every PingTickerSeconds. | |||
) | |||
// a single websocket connection | |||
// contains listener id, underlying ws connection, | |||
// and the event switch for subscribing to events | |||
type wsConnection struct { | |||
cmn.BaseService | |||
remoteAddr string | |||
baseConn *websocket.Conn | |||
writeChan chan types.RPCResponse | |||
readTimeout *time.Timer | |||
pingTicker *time.Ticker | |||
funcMap map[string]*RPCFunc | |||
evsw events.EventSwitch | |||
} | |||
// new websocket connection wrapper | |||
func NewWSConnection(baseConn *websocket.Conn, funcMap map[string]*RPCFunc, evsw events.EventSwitch) *wsConnection { | |||
wsc := &wsConnection{ | |||
remoteAddr: baseConn.RemoteAddr().String(), | |||
baseConn: baseConn, | |||
writeChan: make(chan types.RPCResponse, writeChanCapacity), // error when full. | |||
funcMap: funcMap, | |||
evsw: evsw, | |||
} | |||
wsc.BaseService = *cmn.NewBaseService(log, "wsConnection", wsc) | |||
return wsc | |||
} | |||
// wsc.Start() blocks until the connection closes. | |||
func (wsc *wsConnection) OnStart() error { | |||
wsc.BaseService.OnStart() | |||
// these must be set before the readRoutine is created, as it may | |||
// call wsc.Stop(), which accesses these timers | |||
wsc.readTimeout = time.NewTimer(time.Second * wsReadTimeoutSeconds) | |||
wsc.pingTicker = time.NewTicker(time.Second * wsPingTickerSeconds) | |||
// Read subscriptions/unsubscriptions to events | |||
go wsc.readRoutine() | |||
// Custom Ping handler to touch readTimeout | |||
wsc.baseConn.SetPingHandler(func(m string) error { | |||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||
go wsc.baseConn.WriteControl(websocket.PongMessage, []byte(m), time.Now().Add(time.Second*wsWriteTimeoutSeconds)) | |||
wsc.readTimeout.Reset(time.Second * wsReadTimeoutSeconds) | |||
return nil | |||
}) | |||
wsc.baseConn.SetPongHandler(func(m string) error { | |||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||
wsc.readTimeout.Reset(time.Second * wsReadTimeoutSeconds) | |||
return nil | |||
}) | |||
go wsc.readTimeoutRoutine() | |||
// Write responses, BLOCKING. | |||
wsc.writeRoutine() | |||
return nil | |||
} | |||
func (wsc *wsConnection) OnStop() { | |||
wsc.BaseService.OnStop() | |||
if wsc.evsw != nil { | |||
wsc.evsw.RemoveListener(wsc.remoteAddr) | |||
} | |||
wsc.readTimeout.Stop() | |||
wsc.pingTicker.Stop() | |||
// The write loop closes the websocket connection | |||
// when it exits its loop, and the read loop | |||
// closes the writeChan | |||
} | |||
func (wsc *wsConnection) readTimeoutRoutine() { | |||
select { | |||
case <-wsc.readTimeout.C: | |||
log.Notice("Stopping connection due to read timeout") | |||
wsc.Stop() | |||
case <-wsc.Quit: | |||
return | |||
} | |||
} | |||
// Implements WSRPCConnection | |||
func (wsc *wsConnection) GetRemoteAddr() string { | |||
return wsc.remoteAddr | |||
} | |||
// Implements WSRPCConnection | |||
func (wsc *wsConnection) GetEventSwitch() events.EventSwitch { | |||
return wsc.evsw | |||
} | |||
// Implements WSRPCConnection | |||
// Blocking write to writeChan until service stops. | |||
// Goroutine-safe | |||
func (wsc *wsConnection) WriteRPCResponse(resp types.RPCResponse) { | |||
select { | |||
case <-wsc.Quit: | |||
return | |||
case wsc.writeChan <- resp: | |||
} | |||
} | |||
// Implements WSRPCConnection | |||
// Nonblocking write. | |||
// Goroutine-safe | |||
func (wsc *wsConnection) TryWriteRPCResponse(resp types.RPCResponse) bool { | |||
select { | |||
case <-wsc.Quit: | |||
return false | |||
case wsc.writeChan <- resp: | |||
return true | |||
default: | |||
return false | |||
} | |||
} | |||
// Read from the socket and subscribe to or unsubscribe from events | |||
func (wsc *wsConnection) readRoutine() { | |||
// Do not close writeChan, to allow WriteRPCResponse() to fail. | |||
// defer close(wsc.writeChan) | |||
for { | |||
select { | |||
case <-wsc.Quit: | |||
return | |||
default: | |||
var in []byte | |||
// Do not set a deadline here like below: | |||
// wsc.baseConn.SetReadDeadline(time.Now().Add(time.Second * wsReadTimeoutSeconds)) | |||
// The client may not send anything for a while. | |||
// We use `readTimeout` to handle read timeouts. | |||
_, in, err := wsc.baseConn.ReadMessage() | |||
if err != nil { | |||
log.Notice("Failed to read from connection", "remote", wsc.remoteAddr, "err", err.Error()) | |||
// an error reading the connection, | |||
// kill the connection | |||
wsc.Stop() | |||
return | |||
} | |||
var request types.RPCRequest | |||
err = json.Unmarshal(in, &request) | |||
if err != nil { | |||
errStr := fmt.Sprintf("Error unmarshaling data: %s", err.Error()) | |||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, errStr)) | |||
continue | |||
} | |||
// Now, fetch the RPCFunc and execute it. | |||
rpcFunc := wsc.funcMap[request.Method] | |||
if rpcFunc == nil { | |||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, "RPC method unknown: "+request.Method)) | |||
continue | |||
} | |||
var args []reflect.Value | |||
if rpcFunc.ws { | |||
wsCtx := types.WSRPCContext{Request: request, WSRPCConnection: wsc} | |||
args, err = jsonParamsToArgsWS(rpcFunc, request.Params, wsCtx) | |||
} else { | |||
args, err = jsonParamsToArgsRPC(rpcFunc, request.Params) | |||
} | |||
if err != nil { | |||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, err.Error())) | |||
continue | |||
} | |||
returns := rpcFunc.f.Call(args) | |||
log.Info("WSJSONRPC", "method", request.Method, "args", args, "returns", returns) | |||
result, err := unreflectResult(returns) | |||
if err != nil { | |||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, err.Error())) | |||
continue | |||
} else { | |||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, result, "")) | |||
continue | |||
} | |||
} | |||
} | |||
} | |||
// receives on a write channel and writes out on the socket | |||
func (wsc *wsConnection) writeRoutine() { | |||
defer wsc.baseConn.Close() | |||
for { | |||
select { | |||
case <-wsc.Quit: | |||
return | |||
case <-wsc.pingTicker.C: | |||
err := wsc.baseConn.WriteMessage(websocket.PingMessage, []byte{}) | |||
if err != nil { | |||
log.Error("Failed to write ping message on websocket", "error", err) | |||
wsc.Stop() | |||
return | |||
} | |||
case msg := <-wsc.writeChan: | |||
jsonBytes, err := json.Marshal(msg) | |||
if err != nil { | |||
log.Error("Failed to marshal RPCResponse to JSON", "error", err) | |||
} else { | |||
wsc.baseConn.SetWriteDeadline(time.Now().Add(time.Second * wsWriteTimeoutSeconds)) | |||
if err = wsc.baseConn.WriteMessage(websocket.TextMessage, jsonBytes); err != nil { | |||
log.Warn("Failed to write response on websocket", "error", err) | |||
wsc.Stop() | |||
return | |||
} | |||
} | |||
} | |||
} | |||
} | |||
//---------------------------------------- | |||
// Main manager for all websocket connections | |||
// Holds the event switch | |||
// NOTE: The websocket path is defined externally, e.g. in node/node.go | |||
type WebsocketManager struct { | |||
websocket.Upgrader | |||
funcMap map[string]*RPCFunc | |||
evsw events.EventSwitch | |||
} | |||
func NewWebsocketManager(funcMap map[string]*RPCFunc, evsw events.EventSwitch) *WebsocketManager { | |||
return &WebsocketManager{ | |||
funcMap: funcMap, | |||
evsw: evsw, | |||
Upgrader: websocket.Upgrader{ | |||
ReadBufferSize: 1024, | |||
WriteBufferSize: 1024, | |||
CheckOrigin: func(r *http.Request) bool { | |||
// TODO | |||
return true | |||
}, | |||
}, | |||
} | |||
} | |||
// Upgrade the request/response (via http.Hijack) and starts the wsConnection. | |||
func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Request) { | |||
wsConn, err := wm.Upgrade(w, r, nil) | |||
if err != nil { | |||
// TODO - return http error | |||
log.Error("Failed to upgrade to websocket connection", "error", err) | |||
return | |||
} | |||
// register connection | |||
con := NewWSConnection(wsConn, wm.funcMap, wm.evsw) | |||
log.Notice("New websocket connection", "remote", con.remoteAddr) | |||
con.Start() // Blocking | |||
} | |||
// rpc.websocket | |||
//----------------------------------------------------------------------------- | |||
// NOTE: assume returns is result struct and error. If error is not nil, return it | |||
func unreflectResult(returns []reflect.Value) (interface{}, error) { | |||
errV := returns[1] | |||
if errV.Interface() != nil { | |||
return nil, errors.Errorf("%v", errV.Interface()) | |||
} | |||
rv := returns[0] | |||
// the result is a registered interface, | |||
// we need a pointer to it so we can marshal with type byte | |||
rvp := reflect.New(rv.Type()) | |||
rvp.Elem().Set(rv) | |||
return rvp.Interface(), nil | |||
} | |||
// writes a list of available rpc endpoints as an html page | |||
func writeListOfEndpoints(w http.ResponseWriter, r *http.Request, funcMap map[string]*RPCFunc) { | |||
noArgNames := []string{} | |||
argNames := []string{} | |||
for name, funcData := range funcMap { | |||
if len(funcData.args) == 0 { | |||
noArgNames = append(noArgNames, name) | |||
} else { | |||
argNames = append(argNames, name) | |||
} | |||
} | |||
sort.Strings(noArgNames) | |||
sort.Strings(argNames) | |||
buf := new(bytes.Buffer) | |||
buf.WriteString("<html><body>") | |||
buf.WriteString("<br>Available endpoints:<br>") | |||
for _, name := range noArgNames { | |||
link := fmt.Sprintf("http://%s/%s", r.Host, name) | |||
buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link)) | |||
} | |||
buf.WriteString("<br>Endpoints that require arguments:<br>") | |||
for _, name := range argNames { | |||
link := fmt.Sprintf("http://%s/%s?", r.Host, name) | |||
funcData := funcMap[name] | |||
for i, argName := range funcData.argNames { | |||
link += argName + "=_" | |||
if i < len(funcData.argNames)-1 { | |||
link += "&" | |||
} | |||
} | |||
buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link)) | |||
} | |||
buf.WriteString("</body></html>") | |||
w.Header().Set("Content-Type", "text/html") | |||
w.WriteHeader(200) | |||
w.Write(buf.Bytes()) | |||
} |
@ -0,0 +1,90 @@ | |||
package rpcserver | |||
import ( | |||
"encoding/hex" | |||
"net/http" | |||
"regexp" | |||
"strconv" | |||
"github.com/pkg/errors" | |||
) | |||
var ( | |||
// Parts of regular expressions | |||
atom = "[A-Z0-9!#$%&'*+\\-/=?^_`{|}~]+" | |||
dotAtom = atom + `(?:\.` + atom + `)*` | |||
domain = `[A-Z0-9.-]+\.[A-Z]{2,4}` | |||
RE_HEX = regexp.MustCompile(`^(?i)[a-f0-9]+$`) | |||
RE_EMAIL = regexp.MustCompile(`^(?i)(` + dotAtom + `)@(` + dotAtom + `)$`) | |||
RE_ADDRESS = regexp.MustCompile(`^(?i)[a-z0-9]{25,34}$`) | |||
RE_HOST = regexp.MustCompile(`^(?i)(` + domain + `)$`) | |||
//RE_ID12 = regexp.MustCompile(`^[a-zA-Z0-9]{12}$`) | |||
) | |||
func GetParam(r *http.Request, param string) string { | |||
s := r.URL.Query().Get(param) | |||
if s == "" { | |||
s = r.FormValue(param) | |||
} | |||
return s | |||
} | |||
func GetParamByteSlice(r *http.Request, param string) ([]byte, error) { | |||
s := GetParam(r, param) | |||
return hex.DecodeString(s) | |||
} | |||
func GetParamInt64(r *http.Request, param string) (int64, error) { | |||
s := GetParam(r, param) | |||
i, err := strconv.ParseInt(s, 10, 64) | |||
if err != nil { | |||
return 0, errors.Errorf(param, err.Error()) | |||
} | |||
return i, nil | |||
} | |||
func GetParamInt32(r *http.Request, param string) (int32, error) { | |||
s := GetParam(r, param) | |||
i, err := strconv.ParseInt(s, 10, 32) | |||
if err != nil { | |||
return 0, errors.Errorf(param, err.Error()) | |||
} | |||
return int32(i), nil | |||
} | |||
func GetParamUint64(r *http.Request, param string) (uint64, error) { | |||
s := GetParam(r, param) | |||
i, err := strconv.ParseUint(s, 10, 64) | |||
if err != nil { | |||
return 0, errors.Errorf(param, err.Error()) | |||
} | |||
return i, nil | |||
} | |||
func GetParamUint(r *http.Request, param string) (uint, error) { | |||
s := GetParam(r, param) | |||
i, err := strconv.ParseUint(s, 10, 64) | |||
if err != nil { | |||
return 0, errors.Errorf(param, err.Error()) | |||
} | |||
return uint(i), nil | |||
} | |||
func GetParamRegexp(r *http.Request, param string, re *regexp.Regexp) (string, error) { | |||
s := GetParam(r, param) | |||
if !re.MatchString(s) { | |||
return "", errors.Errorf(param, "Did not match regular expression %v", re.String()) | |||
} | |||
return s, nil | |||
} | |||
func GetParamFloat64(r *http.Request, param string) (float64, error) { | |||
s := GetParam(r, param) | |||
f, err := strconv.ParseFloat(s, 64) | |||
if err != nil { | |||
return 0, errors.Errorf(param, err.Error()) | |||
} | |||
return f, nil | |||
} |
@ -0,0 +1,125 @@ | |||
// Commons for HTTP handling | |||
package rpcserver | |||
import ( | |||
"bufio" | |||
"encoding/json" | |||
"fmt" | |||
"net" | |||
"net/http" | |||
"runtime/debug" | |||
"strings" | |||
"time" | |||
"github.com/pkg/errors" | |||
types "github.com/tendermint/tendermint/rpc/types" | |||
) | |||
func StartHTTPServer(listenAddr string, handler http.Handler) (listener net.Listener, err error) { | |||
// listenAddr should be fully formed including tcp:// or unix:// prefix | |||
var proto, addr string | |||
parts := strings.SplitN(listenAddr, "://", 2) | |||
if len(parts) != 2 { | |||
log.Warn("WARNING (tendermint/rpc): Please use fully formed listening addresses, including the tcp:// or unix:// prefix") | |||
// we used to allow addrs without tcp/unix prefix by checking for a colon | |||
// TODO: Deprecate | |||
proto = types.SocketType(listenAddr) | |||
addr = listenAddr | |||
// return nil, errors.Errorf("Invalid listener address %s", lisenAddr) | |||
} else { | |||
proto, addr = parts[0], parts[1] | |||
} | |||
log.Notice(fmt.Sprintf("Starting RPC HTTP server on %s socket %v", proto, addr)) | |||
listener, err = net.Listen(proto, addr) | |||
if err != nil { | |||
return nil, errors.Errorf("Failed to listen to %v: %v", listenAddr, err) | |||
} | |||
go func() { | |||
res := http.Serve( | |||
listener, | |||
RecoverAndLogHandler(handler), | |||
) | |||
log.Crit("RPC HTTP server stopped", "result", res) | |||
}() | |||
return listener, nil | |||
} | |||
func WriteRPCResponseHTTP(w http.ResponseWriter, res types.RPCResponse) { | |||
// jsonBytes := wire.JSONBytesPretty(res) | |||
jsonBytes, err := json.Marshal(res) | |||
if err != nil { | |||
panic(err) | |||
} | |||
w.Header().Set("Content-Type", "application/json") | |||
w.WriteHeader(200) | |||
w.Write(jsonBytes) | |||
} | |||
//----------------------------------------------------------------------------- | |||
// Wraps an HTTP handler, adding error logging. | |||
// If the inner function panics, the outer function recovers, logs, sends an | |||
// HTTP 500 error response. | |||
func RecoverAndLogHandler(handler http.Handler) http.Handler { | |||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||
// Wrap the ResponseWriter to remember the status | |||
rww := &ResponseWriterWrapper{-1, w} | |||
begin := time.Now() | |||
// Common headers | |||
origin := r.Header.Get("Origin") | |||
rww.Header().Set("Access-Control-Allow-Origin", origin) | |||
rww.Header().Set("Access-Control-Allow-Credentials", "true") | |||
rww.Header().Set("Access-Control-Expose-Headers", "X-Server-Time") | |||
rww.Header().Set("X-Server-Time", fmt.Sprintf("%v", begin.Unix())) | |||
defer func() { | |||
// Send a 500 error if a panic happens during a handler. | |||
// Without this, Chrome & Firefox were retrying aborted ajax requests, | |||
// at least to my localhost. | |||
if e := recover(); e != nil { | |||
// If RPCResponse | |||
if res, ok := e.(types.RPCResponse); ok { | |||
WriteRPCResponseHTTP(rww, res) | |||
} else { | |||
// For the rest, | |||
log.Error("Panic in RPC HTTP handler", "error", e, "stack", string(debug.Stack())) | |||
rww.WriteHeader(http.StatusInternalServerError) | |||
WriteRPCResponseHTTP(rww, types.NewRPCResponse("", nil, fmt.Sprintf("Internal Server Error: %v", e))) | |||
} | |||
} | |||
// Finally, log. | |||
durationMS := time.Since(begin).Nanoseconds() / 1000000 | |||
if rww.Status == -1 { | |||
rww.Status = 200 | |||
} | |||
log.Info("Served RPC HTTP response", | |||
"method", r.Method, "url", r.URL, | |||
"status", rww.Status, "duration", durationMS, | |||
"remoteAddr", r.RemoteAddr, | |||
) | |||
}() | |||
handler.ServeHTTP(rww, r) | |||
}) | |||
} | |||
// Remember the status for logging | |||
type ResponseWriterWrapper struct { | |||
Status int | |||
http.ResponseWriter | |||
} | |||
func (w *ResponseWriterWrapper) WriteHeader(status int) { | |||
w.Status = status | |||
w.ResponseWriter.WriteHeader(status) | |||
} | |||
// implements http.Hijacker | |||
func (w *ResponseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { | |||
return w.ResponseWriter.(http.Hijacker).Hijack() | |||
} |
@ -0,0 +1,7 @@ | |||
package rpcserver | |||
import ( | |||
"github.com/tendermint/log15" | |||
) | |||
var log = log15.New("module", "rpcserver") |