@ -1,7 +1,7 @@ | |||||
package blockchain | package blockchain | ||||
import ( | import ( | ||||
"github.com/tendermint/go-logger" | |||||
"github.com/tendermint/tmlibs/logger" | |||||
) | ) | ||||
var log = logger.New("module", "blockchain") | var log = logger.New("module", "blockchain") |
@ -1,56 +1,57 @@ | |||||
package: github.com/tendermint/tendermint | package: github.com/tendermint/tendermint | ||||
import: | import: | ||||
- package: github.com/tendermint/go-autofile | |||||
version: develop | |||||
- package: github.com/tendermint/go-clist | |||||
version: develop | |||||
- package: github.com/tendermint/go-common | |||||
- package: github.com/ebuchman/fail-test | |||||
- package: github.com/gogo/protobuf | |||||
subpackages: | |||||
- proto | |||||
- package: github.com/golang/protobuf | |||||
subpackages: | |||||
- proto | |||||
- package: github.com/gorilla/websocket | |||||
- package: github.com/pkg/errors | |||||
- package: github.com/spf13/cobra | |||||
- package: github.com/stretchr/testify | |||||
subpackages: | |||||
- require | |||||
- package: github.com/tendermint/abci | |||||
version: develop | version: develop | ||||
subpackages: | |||||
- client | |||||
- example/dummy | |||||
- types | |||||
- package: github.com/tendermint/go-config | - package: github.com/tendermint/go-config | ||||
version: develop | version: develop | ||||
- package: github.com/tendermint/go-crypto | - package: github.com/tendermint/go-crypto | ||||
version: develop | version: develop | ||||
- package: github.com/tendermint/go-data | |||||
version: develop | |||||
- package: github.com/tendermint/go-db | |||||
version: develop | |||||
- package: github.com/tendermint/go-events | |||||
version: develop | |||||
- package: github.com/tendermint/go-logger | |||||
version: develop | |||||
- package: github.com/tendermint/go-merkle | |||||
version: develop | |||||
- package: github.com/tendermint/go-p2p | |||||
version: unstable | |||||
- package: github.com/tendermint/go-rpc | |||||
version: develop | |||||
- package: github.com/tendermint/go-wire | - package: github.com/tendermint/go-wire | ||||
version: develop | version: develop | ||||
- package: github.com/tendermint/abci | |||||
version: develop | |||||
- package: github.com/tendermint/go-flowrate | |||||
subpackages: | |||||
- data | |||||
- package: github.com/tendermint/log15 | - package: github.com/tendermint/log15 | ||||
- package: github.com/tendermint/ed25519 | |||||
- package: github.com/tendermint/merkleeyes | |||||
- package: github.com/tendermint/tmlibs | |||||
version: develop | version: develop | ||||
subpackages: | subpackages: | ||||
- app | |||||
- package: github.com/gogo/protobuf | |||||
version: ^0.3 | |||||
subpackages: | |||||
- proto | |||||
- package: github.com/gorilla/websocket | |||||
version: ^1.1.0 | |||||
- package: github.com/spf13/cobra | |||||
- package: github.com/spf13/pflag | |||||
- package: github.com/pkg/errors | |||||
version: ^0.8.0 | |||||
- autofile | |||||
- clist | |||||
- common | |||||
- db | |||||
- events | |||||
- flowrate | |||||
- logger | |||||
- merkle | |||||
- package: golang.org/x/crypto | - package: golang.org/x/crypto | ||||
subpackages: | subpackages: | ||||
- nacl/box | |||||
- nacl/secretbox | |||||
- ripemd160 | - ripemd160 | ||||
- package: golang.org/x/net | |||||
subpackages: | |||||
- context | |||||
- package: google.golang.org/grpc | |||||
testImport: | testImport: | ||||
- package: github.com/stretchr/testify | |||||
version: ^1.1.4 | |||||
- package: github.com/tendermint/merkleeyes | |||||
version: develop | |||||
subpackages: | subpackages: | ||||
- assert | |||||
- require | |||||
- app | |||||
- iavl | |||||
- testutil |
@ -1,7 +1,7 @@ | |||||
package node | package node | ||||
import ( | import ( | ||||
"github.com/tendermint/go-logger" | |||||
"github.com/tendermint/tmlibs/logger" | |||||
) | ) | ||||
var log = logger.New("module", "node") | var log = logger.New("module", "node") |
@ -0,0 +1,78 @@ | |||||
# Changelog | |||||
## 0.5.0 (April 21, 2017) | |||||
BREAKING CHANGES: | |||||
- Remove or unexport methods from FuzzedConnection: Active, Mode, ProbDropRW, ProbDropConn, ProbSleep, MaxDelayMilliseconds, Fuzz | |||||
- switch.AddPeerWithConnection is unexported and replaced by switch.AddPeer | |||||
- switch.DialPeerWithAddress takes a bool, setting the peer as persistent or not | |||||
FEATURES: | |||||
- Persistent peers: any peer considered a "seed" will be reconnected to when the connection is dropped | |||||
IMPROVEMENTS: | |||||
- Many more tests and comments | |||||
- Refactor configurations for less dependence on go-config. Introduces new structs PeerConfig, MConnConfig, FuzzConnConfig | |||||
- New methods on peer: CloseConn, HandshakeTimeout, IsPersistent, Addr, PubKey | |||||
- NewNetAddress supports a testing mode where the address defaults to 0.0.0.0:0 | |||||
## 0.4.0 (March 6, 2017) | |||||
BREAKING CHANGES: | |||||
- DialSeeds now takes an AddrBook and returns an error: `DialSeeds(*AddrBook, []string) error` | |||||
- NewNetAddressString now returns an error: `NewNetAddressString(string) (*NetAddress, error)` | |||||
FEATURES: | |||||
- `NewNetAddressStrings([]string) ([]*NetAddress, error)` | |||||
- `AddrBook.Save()` | |||||
IMPROVEMENTS: | |||||
- PexReactor responsible for starting and stopping the AddrBook | |||||
BUG FIXES: | |||||
- DialSeeds returns an error instead of panicking on bad addresses | |||||
## 0.3.5 (January 12, 2017) | |||||
FEATURES | |||||
- Toggle strict routability in the AddrBook | |||||
BUG FIXES | |||||
- Close filtered out connections | |||||
- Fixes for MakeConnectedSwitches and Connect2Switches | |||||
## 0.3.4 (August 10, 2016) | |||||
FEATURES: | |||||
- Optionally filter connections by address or public key | |||||
## 0.3.3 (May 12, 2016) | |||||
FEATURES: | |||||
- FuzzConn | |||||
## 0.3.2 (March 12, 2016) | |||||
IMPROVEMENTS: | |||||
- Memory optimizations | |||||
## 0.3.1 () | |||||
FEATURES: | |||||
- Configurable parameters | |||||
@ -0,0 +1,13 @@ | |||||
FROM golang:latest | |||||
RUN curl https://glide.sh/get | sh | |||||
RUN mkdir -p /go/src/github.com/tendermint/tendermint/p2p | |||||
WORKDIR /go/src/github.com/tendermint/tendermint/p2p | |||||
COPY glide.yaml /go/src/github.com/tendermint/tendermint/p2p/ | |||||
COPY glide.lock /go/src/github.com/tendermint/tendermint/p2p/ | |||||
RUN glide install | |||||
COPY . /go/src/github.com/tendermint/tendermint/p2p |
@ -0,0 +1,79 @@ | |||||
# `tendermint/tendermint/p2p` | |||||
[![CircleCI](https://circleci.com/gh/tendermint/tendermint/p2p.svg?style=svg)](https://circleci.com/gh/tendermint/tendermint/p2p) | |||||
`tendermint/tendermint/p2p` provides an abstraction around peer-to-peer communication.<br/> | |||||
## Peer/MConnection/Channel | |||||
Each peer has one `MConnection` (multiplex connection) instance. | |||||
__multiplex__ *noun* a system or signal involving simultaneous transmission of | |||||
several messages along a single channel of communication. | |||||
Each `MConnection` handles message transmission on multiple abstract communication | |||||
`Channel`s. Each channel has a globally unique byte id. | |||||
The byte id and the relative priorities of each `Channel` are configured upon | |||||
initialization of the connection. | |||||
There are two methods for sending messages: | |||||
```go | |||||
func (m MConnection) Send(chID byte, msg interface{}) bool {} | |||||
func (m MConnection) TrySend(chID byte, msg interface{}) bool {} | |||||
``` | |||||
`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued | |||||
for the channel with the given id byte `chID`. The message `msg` is serialized | |||||
using the `tendermint/wire` submodule's `WriteBinary()` reflection routine. | |||||
`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's | |||||
queue is full. | |||||
`Send()` and `TrySend()` are also exposed for each `Peer`. | |||||
## Switch/Reactor | |||||
The `Switch` handles peer connections and exposes an API to receive incoming messages | |||||
on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one | |||||
or more `Channels`. So while sending outgoing messages is typically performed on the peer, | |||||
incoming messages are received on the reactor. | |||||
```go | |||||
// Declare a MyReactor reactor that handles messages on MyChannelID. | |||||
type MyReactor struct{} | |||||
func (reactor MyReactor) GetChannels() []*ChannelDescriptor { | |||||
return []*ChannelDescriptor{ChannelDescriptor{ID:MyChannelID, Priority: 1}} | |||||
} | |||||
func (reactor MyReactor) Receive(chID byte, peer *Peer, msgBytes []byte) { | |||||
r, n, err := bytes.NewBuffer(msgBytes), new(int64), new(error) | |||||
msgString := ReadString(r, n, err) | |||||
fmt.Println(msgString) | |||||
} | |||||
// Other Reactor methods omitted for brevity | |||||
... | |||||
switch := NewSwitch([]Reactor{MyReactor{}}) | |||||
... | |||||
// Send a random message to all outbound connections | |||||
for _, peer := range switch.Peers().List() { | |||||
if peer.IsOutbound() { | |||||
peer.Send(MyChannelID, "Here's a random message") | |||||
} | |||||
} | |||||
``` | |||||
### PexReactor/AddrBook | |||||
A `PEXReactor` reactor implementation is provided to automate peer discovery. | |||||
```go | |||||
book := p2p.NewAddrBook(addrBookFilePath) | |||||
pexReactor := p2p.NewPEXReactor(book) | |||||
... | |||||
switch := NewSwitch([]Reactor{pexReactor, myReactor, ...}) | |||||
``` |
@ -0,0 +1,839 @@ | |||||
// Modified for Tendermint | |||||
// Originally Copyright (c) 2013-2014 Conformal Systems LLC. | |||||
// https://github.com/conformal/btcd/blob/master/LICENSE | |||||
package p2p | |||||
import ( | |||||
"encoding/binary" | |||||
"encoding/json" | |||||
"math" | |||||
"math/rand" | |||||
"net" | |||||
"os" | |||||
"sync" | |||||
"time" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
crypto "github.com/tendermint/go-crypto" | |||||
) | |||||
const ( | |||||
// addresses under which the address manager will claim to need more addresses. | |||||
needAddressThreshold = 1000 | |||||
// interval used to dump the address cache to disk for future use. | |||||
dumpAddressInterval = time.Minute * 2 | |||||
// max addresses in each old address bucket. | |||||
oldBucketSize = 64 | |||||
// buckets we split old addresses over. | |||||
oldBucketCount = 64 | |||||
// max addresses in each new address bucket. | |||||
newBucketSize = 64 | |||||
// buckets that we spread new addresses over. | |||||
newBucketCount = 256 | |||||
// old buckets over which an address group will be spread. | |||||
oldBucketsPerGroup = 4 | |||||
// new buckets over which an source address group will be spread. | |||||
newBucketsPerGroup = 32 | |||||
// buckets a frequently seen new address may end up in. | |||||
maxNewBucketsPerAddress = 4 | |||||
// days before which we assume an address has vanished | |||||
// if we have not seen it announced in that long. | |||||
numMissingDays = 30 | |||||
// tries without a single success before we assume an address is bad. | |||||
numRetries = 3 | |||||
// max failures we will accept without a success before considering an address bad. | |||||
maxFailures = 10 | |||||
// days since the last success before we will consider evicting an address. | |||||
minBadDays = 7 | |||||
// % of total addresses known returned by GetSelection. | |||||
getSelectionPercent = 23 | |||||
// min addresses that must be returned by GetSelection. Useful for bootstrapping. | |||||
minGetSelection = 32 | |||||
// max addresses returned by GetSelection | |||||
// NOTE: this must match "maxPexMessageSize" | |||||
maxGetSelection = 250 | |||||
// current version of the on-disk format. | |||||
serializationVersion = 1 | |||||
) | |||||
const ( | |||||
bucketTypeNew = 0x01 | |||||
bucketTypeOld = 0x02 | |||||
) | |||||
// AddrBook - concurrency safe peer address manager. | |||||
type AddrBook struct { | |||||
BaseService | |||||
mtx sync.Mutex | |||||
filePath string | |||||
routabilityStrict bool | |||||
rand *rand.Rand | |||||
key string | |||||
ourAddrs map[string]*NetAddress | |||||
addrLookup map[string]*knownAddress // new & old | |||||
addrNew []map[string]*knownAddress | |||||
addrOld []map[string]*knownAddress | |||||
wg sync.WaitGroup | |||||
nOld int | |||||
nNew int | |||||
} | |||||
// NewAddrBook creates a new address book. | |||||
// Use Start to begin processing asynchronous address updates. | |||||
func NewAddrBook(filePath string, routabilityStrict bool) *AddrBook { | |||||
am := &AddrBook{ | |||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())), | |||||
ourAddrs: make(map[string]*NetAddress), | |||||
addrLookup: make(map[string]*knownAddress), | |||||
filePath: filePath, | |||||
routabilityStrict: routabilityStrict, | |||||
} | |||||
am.init() | |||||
am.BaseService = *NewBaseService(log, "AddrBook", am) | |||||
return am | |||||
} | |||||
// When modifying this, don't forget to update loadFromFile() | |||||
func (a *AddrBook) init() { | |||||
a.key = crypto.CRandHex(24) // 24/2 * 8 = 96 bits | |||||
// New addr buckets | |||||
a.addrNew = make([]map[string]*knownAddress, newBucketCount) | |||||
for i := range a.addrNew { | |||||
a.addrNew[i] = make(map[string]*knownAddress) | |||||
} | |||||
// Old addr buckets | |||||
a.addrOld = make([]map[string]*knownAddress, oldBucketCount) | |||||
for i := range a.addrOld { | |||||
a.addrOld[i] = make(map[string]*knownAddress) | |||||
} | |||||
} | |||||
// OnStart implements Service. | |||||
func (a *AddrBook) OnStart() error { | |||||
a.BaseService.OnStart() | |||||
a.loadFromFile(a.filePath) | |||||
a.wg.Add(1) | |||||
go a.saveRoutine() | |||||
return nil | |||||
} | |||||
// OnStop implements Service. | |||||
func (a *AddrBook) OnStop() { | |||||
a.BaseService.OnStop() | |||||
} | |||||
func (a *AddrBook) Wait() { | |||||
a.wg.Wait() | |||||
} | |||||
func (a *AddrBook) AddOurAddress(addr *NetAddress) { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
log.Info("Add our address to book", "addr", addr) | |||||
a.ourAddrs[addr.String()] = addr | |||||
} | |||||
func (a *AddrBook) OurAddresses() []*NetAddress { | |||||
addrs := []*NetAddress{} | |||||
for _, addr := range a.ourAddrs { | |||||
addrs = append(addrs, addr) | |||||
} | |||||
return addrs | |||||
} | |||||
// NOTE: addr must not be nil | |||||
func (a *AddrBook) AddAddress(addr *NetAddress, src *NetAddress) { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
log.Info("Add address to book", "addr", addr, "src", src) | |||||
a.addAddress(addr, src) | |||||
} | |||||
func (a *AddrBook) NeedMoreAddrs() bool { | |||||
return a.Size() < needAddressThreshold | |||||
} | |||||
func (a *AddrBook) Size() int { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
return a.size() | |||||
} | |||||
func (a *AddrBook) size() int { | |||||
return a.nNew + a.nOld | |||||
} | |||||
// Pick an address to connect to with new/old bias. | |||||
func (a *AddrBook) PickAddress(newBias int) *NetAddress { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
if a.size() == 0 { | |||||
return nil | |||||
} | |||||
if newBias > 100 { | |||||
newBias = 100 | |||||
} | |||||
if newBias < 0 { | |||||
newBias = 0 | |||||
} | |||||
// Bias between new and old addresses. | |||||
oldCorrelation := math.Sqrt(float64(a.nOld)) * (100.0 - float64(newBias)) | |||||
newCorrelation := math.Sqrt(float64(a.nNew)) * float64(newBias) | |||||
if (newCorrelation+oldCorrelation)*a.rand.Float64() < oldCorrelation { | |||||
// pick random Old bucket. | |||||
var bucket map[string]*knownAddress = nil | |||||
for len(bucket) == 0 { | |||||
bucket = a.addrOld[a.rand.Intn(len(a.addrOld))] | |||||
} | |||||
// pick a random ka from bucket. | |||||
randIndex := a.rand.Intn(len(bucket)) | |||||
for _, ka := range bucket { | |||||
if randIndex == 0 { | |||||
return ka.Addr | |||||
} | |||||
randIndex-- | |||||
} | |||||
PanicSanity("Should not happen") | |||||
} else { | |||||
// pick random New bucket. | |||||
var bucket map[string]*knownAddress = nil | |||||
for len(bucket) == 0 { | |||||
bucket = a.addrNew[a.rand.Intn(len(a.addrNew))] | |||||
} | |||||
// pick a random ka from bucket. | |||||
randIndex := a.rand.Intn(len(bucket)) | |||||
for _, ka := range bucket { | |||||
if randIndex == 0 { | |||||
return ka.Addr | |||||
} | |||||
randIndex-- | |||||
} | |||||
PanicSanity("Should not happen") | |||||
} | |||||
return nil | |||||
} | |||||
func (a *AddrBook) MarkGood(addr *NetAddress) { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
ka := a.addrLookup[addr.String()] | |||||
if ka == nil { | |||||
return | |||||
} | |||||
ka.markGood() | |||||
if ka.isNew() { | |||||
a.moveToOld(ka) | |||||
} | |||||
} | |||||
func (a *AddrBook) MarkAttempt(addr *NetAddress) { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
ka := a.addrLookup[addr.String()] | |||||
if ka == nil { | |||||
return | |||||
} | |||||
ka.markAttempt() | |||||
} | |||||
// MarkBad currently just ejects the address. In the future, consider | |||||
// blacklisting. | |||||
func (a *AddrBook) MarkBad(addr *NetAddress) { | |||||
a.RemoveAddress(addr) | |||||
} | |||||
// RemoveAddress removes the address from the book. | |||||
func (a *AddrBook) RemoveAddress(addr *NetAddress) { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
ka := a.addrLookup[addr.String()] | |||||
if ka == nil { | |||||
return | |||||
} | |||||
log.Info("Remove address from book", "addr", addr) | |||||
a.removeFromAllBuckets(ka) | |||||
} | |||||
/* Peer exchange */ | |||||
// GetSelection randomly selects some addresses (old & new). Suitable for peer-exchange protocols. | |||||
func (a *AddrBook) GetSelection() []*NetAddress { | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
if a.size() == 0 { | |||||
return nil | |||||
} | |||||
allAddr := make([]*NetAddress, a.size()) | |||||
i := 0 | |||||
for _, v := range a.addrLookup { | |||||
allAddr[i] = v.Addr | |||||
i++ | |||||
} | |||||
numAddresses := MaxInt( | |||||
MinInt(minGetSelection, len(allAddr)), | |||||
len(allAddr)*getSelectionPercent/100) | |||||
numAddresses = MinInt(maxGetSelection, numAddresses) | |||||
// Fisher-Yates shuffle the array. We only need to do the first | |||||
// `numAddresses' since we are throwing the rest. | |||||
for i := 0; i < numAddresses; i++ { | |||||
// pick a number between current index and the end | |||||
j := rand.Intn(len(allAddr)-i) + i | |||||
allAddr[i], allAddr[j] = allAddr[j], allAddr[i] | |||||
} | |||||
// slice off the limit we are willing to share. | |||||
return allAddr[:numAddresses] | |||||
} | |||||
/* Loading & Saving */ | |||||
type addrBookJSON struct { | |||||
Key string | |||||
Addrs []*knownAddress | |||||
} | |||||
func (a *AddrBook) saveToFile(filePath string) { | |||||
log.Info("Saving AddrBook to file", "size", a.Size()) | |||||
a.mtx.Lock() | |||||
defer a.mtx.Unlock() | |||||
// Compile Addrs | |||||
addrs := []*knownAddress{} | |||||
for _, ka := range a.addrLookup { | |||||
addrs = append(addrs, ka) | |||||
} | |||||
aJSON := &addrBookJSON{ | |||||
Key: a.key, | |||||
Addrs: addrs, | |||||
} | |||||
jsonBytes, err := json.MarshalIndent(aJSON, "", "\t") | |||||
if err != nil { | |||||
log.Error("Failed to save AddrBook to file", "err", err) | |||||
return | |||||
} | |||||
err = WriteFileAtomic(filePath, jsonBytes, 0644) | |||||
if err != nil { | |||||
log.Error("Failed to save AddrBook to file", "file", filePath, "error", err) | |||||
} | |||||
} | |||||
// Returns false if file does not exist. | |||||
// Panics if file is corrupt. | |||||
func (a *AddrBook) loadFromFile(filePath string) bool { | |||||
// If doesn't exist, do nothing. | |||||
_, err := os.Stat(filePath) | |||||
if os.IsNotExist(err) { | |||||
return false | |||||
} | |||||
// Load addrBookJSON{} | |||||
r, err := os.Open(filePath) | |||||
if err != nil { | |||||
PanicCrisis(Fmt("Error opening file %s: %v", filePath, err)) | |||||
} | |||||
defer r.Close() | |||||
aJSON := &addrBookJSON{} | |||||
dec := json.NewDecoder(r) | |||||
err = dec.Decode(aJSON) | |||||
if err != nil { | |||||
PanicCrisis(Fmt("Error reading file %s: %v", filePath, err)) | |||||
} | |||||
// Restore all the fields... | |||||
// Restore the key | |||||
a.key = aJSON.Key | |||||
// Restore .addrNew & .addrOld | |||||
for _, ka := range aJSON.Addrs { | |||||
for _, bucketIndex := range ka.Buckets { | |||||
bucket := a.getBucket(ka.BucketType, bucketIndex) | |||||
bucket[ka.Addr.String()] = ka | |||||
} | |||||
a.addrLookup[ka.Addr.String()] = ka | |||||
if ka.BucketType == bucketTypeNew { | |||||
a.nNew++ | |||||
} else { | |||||
a.nOld++ | |||||
} | |||||
} | |||||
return true | |||||
} | |||||
// Save saves the book. | |||||
func (a *AddrBook) Save() { | |||||
log.Info("Saving AddrBook to file", "size", a.Size()) | |||||
a.saveToFile(a.filePath) | |||||
} | |||||
/* Private methods */ | |||||
func (a *AddrBook) saveRoutine() { | |||||
dumpAddressTicker := time.NewTicker(dumpAddressInterval) | |||||
out: | |||||
for { | |||||
select { | |||||
case <-dumpAddressTicker.C: | |||||
a.saveToFile(a.filePath) | |||||
case <-a.Quit: | |||||
break out | |||||
} | |||||
} | |||||
dumpAddressTicker.Stop() | |||||
a.saveToFile(a.filePath) | |||||
a.wg.Done() | |||||
log.Notice("Address handler done") | |||||
} | |||||
func (a *AddrBook) getBucket(bucketType byte, bucketIdx int) map[string]*knownAddress { | |||||
switch bucketType { | |||||
case bucketTypeNew: | |||||
return a.addrNew[bucketIdx] | |||||
case bucketTypeOld: | |||||
return a.addrOld[bucketIdx] | |||||
default: | |||||
PanicSanity("Should not happen") | |||||
return nil | |||||
} | |||||
} | |||||
// Adds ka to new bucket. Returns false if it couldn't do it cuz buckets full. | |||||
// NOTE: currently it always returns true. | |||||
func (a *AddrBook) addToNewBucket(ka *knownAddress, bucketIdx int) bool { | |||||
// Sanity check | |||||
if ka.isOld() { | |||||
log.Warn(Fmt("Cannot add address already in old bucket to a new bucket: %v", ka)) | |||||
return false | |||||
} | |||||
addrStr := ka.Addr.String() | |||||
bucket := a.getBucket(bucketTypeNew, bucketIdx) | |||||
// Already exists? | |||||
if _, ok := bucket[addrStr]; ok { | |||||
return true | |||||
} | |||||
// Enforce max addresses. | |||||
if len(bucket) > newBucketSize { | |||||
log.Notice("new bucket is full, expiring old ") | |||||
a.expireNew(bucketIdx) | |||||
} | |||||
// Add to bucket. | |||||
bucket[addrStr] = ka | |||||
if ka.addBucketRef(bucketIdx) == 1 { | |||||
a.nNew++ | |||||
} | |||||
// Ensure in addrLookup | |||||
a.addrLookup[addrStr] = ka | |||||
return true | |||||
} | |||||
// Adds ka to old bucket. Returns false if it couldn't do it cuz buckets full. | |||||
func (a *AddrBook) addToOldBucket(ka *knownAddress, bucketIdx int) bool { | |||||
// Sanity check | |||||
if ka.isNew() { | |||||
log.Warn(Fmt("Cannot add new address to old bucket: %v", ka)) | |||||
return false | |||||
} | |||||
if len(ka.Buckets) != 0 { | |||||
log.Warn(Fmt("Cannot add already old address to another old bucket: %v", ka)) | |||||
return false | |||||
} | |||||
addrStr := ka.Addr.String() | |||||
bucket := a.getBucket(bucketTypeNew, bucketIdx) | |||||
// Already exists? | |||||
if _, ok := bucket[addrStr]; ok { | |||||
return true | |||||
} | |||||
// Enforce max addresses. | |||||
if len(bucket) > oldBucketSize { | |||||
return false | |||||
} | |||||
// Add to bucket. | |||||
bucket[addrStr] = ka | |||||
if ka.addBucketRef(bucketIdx) == 1 { | |||||
a.nOld++ | |||||
} | |||||
// Ensure in addrLookup | |||||
a.addrLookup[addrStr] = ka | |||||
return true | |||||
} | |||||
func (a *AddrBook) removeFromBucket(ka *knownAddress, bucketType byte, bucketIdx int) { | |||||
if ka.BucketType != bucketType { | |||||
log.Warn(Fmt("Bucket type mismatch: %v", ka)) | |||||
return | |||||
} | |||||
bucket := a.getBucket(bucketType, bucketIdx) | |||||
delete(bucket, ka.Addr.String()) | |||||
if ka.removeBucketRef(bucketIdx) == 0 { | |||||
if bucketType == bucketTypeNew { | |||||
a.nNew-- | |||||
} else { | |||||
a.nOld-- | |||||
} | |||||
delete(a.addrLookup, ka.Addr.String()) | |||||
} | |||||
} | |||||
func (a *AddrBook) removeFromAllBuckets(ka *knownAddress) { | |||||
for _, bucketIdx := range ka.Buckets { | |||||
bucket := a.getBucket(ka.BucketType, bucketIdx) | |||||
delete(bucket, ka.Addr.String()) | |||||
} | |||||
ka.Buckets = nil | |||||
if ka.BucketType == bucketTypeNew { | |||||
a.nNew-- | |||||
} else { | |||||
a.nOld-- | |||||
} | |||||
delete(a.addrLookup, ka.Addr.String()) | |||||
} | |||||
func (a *AddrBook) pickOldest(bucketType byte, bucketIdx int) *knownAddress { | |||||
bucket := a.getBucket(bucketType, bucketIdx) | |||||
var oldest *knownAddress | |||||
for _, ka := range bucket { | |||||
if oldest == nil || ka.LastAttempt.Before(oldest.LastAttempt) { | |||||
oldest = ka | |||||
} | |||||
} | |||||
return oldest | |||||
} | |||||
func (a *AddrBook) addAddress(addr, src *NetAddress) { | |||||
if a.routabilityStrict && !addr.Routable() { | |||||
log.Warn(Fmt("Cannot add non-routable address %v", addr)) | |||||
return | |||||
} | |||||
if _, ok := a.ourAddrs[addr.String()]; ok { | |||||
// Ignore our own listener address. | |||||
return | |||||
} | |||||
ka := a.addrLookup[addr.String()] | |||||
if ka != nil { | |||||
// Already old. | |||||
if ka.isOld() { | |||||
return | |||||
} | |||||
// Already in max new buckets. | |||||
if len(ka.Buckets) == maxNewBucketsPerAddress { | |||||
return | |||||
} | |||||
// The more entries we have, the less likely we are to add more. | |||||
factor := int32(2 * len(ka.Buckets)) | |||||
if a.rand.Int31n(factor) != 0 { | |||||
return | |||||
} | |||||
} else { | |||||
ka = newKnownAddress(addr, src) | |||||
} | |||||
bucket := a.calcNewBucket(addr, src) | |||||
a.addToNewBucket(ka, bucket) | |||||
log.Notice("Added new address", "address", addr, "total", a.size()) | |||||
} | |||||
// Make space in the new buckets by expiring the really bad entries. | |||||
// If no bad entries are available we remove the oldest. | |||||
func (a *AddrBook) expireNew(bucketIdx int) { | |||||
for addrStr, ka := range a.addrNew[bucketIdx] { | |||||
// If an entry is bad, throw it away | |||||
if ka.isBad() { | |||||
log.Notice(Fmt("expiring bad address %v", addrStr)) | |||||
a.removeFromBucket(ka, bucketTypeNew, bucketIdx) | |||||
return | |||||
} | |||||
} | |||||
// If we haven't thrown out a bad entry, throw out the oldest entry | |||||
oldest := a.pickOldest(bucketTypeNew, bucketIdx) | |||||
a.removeFromBucket(oldest, bucketTypeNew, bucketIdx) | |||||
} | |||||
// Promotes an address from new to old. | |||||
// TODO: Move to old probabilistically. | |||||
// The better a node is, the less likely it should be evicted from an old bucket. | |||||
func (a *AddrBook) moveToOld(ka *knownAddress) { | |||||
// Sanity check | |||||
if ka.isOld() { | |||||
log.Warn(Fmt("Cannot promote address that is already old %v", ka)) | |||||
return | |||||
} | |||||
if len(ka.Buckets) == 0 { | |||||
log.Warn(Fmt("Cannot promote address that isn't in any new buckets %v", ka)) | |||||
return | |||||
} | |||||
// Remember one of the buckets in which ka is in. | |||||
freedBucket := ka.Buckets[0] | |||||
// Remove from all (new) buckets. | |||||
a.removeFromAllBuckets(ka) | |||||
// It's officially old now. | |||||
ka.BucketType = bucketTypeOld | |||||
// Try to add it to its oldBucket destination. | |||||
oldBucketIdx := a.calcOldBucket(ka.Addr) | |||||
added := a.addToOldBucket(ka, oldBucketIdx) | |||||
if !added { | |||||
// No room, must evict something | |||||
oldest := a.pickOldest(bucketTypeOld, oldBucketIdx) | |||||
a.removeFromBucket(oldest, bucketTypeOld, oldBucketIdx) | |||||
// Find new bucket to put oldest in | |||||
newBucketIdx := a.calcNewBucket(oldest.Addr, oldest.Src) | |||||
added := a.addToNewBucket(oldest, newBucketIdx) | |||||
// No space in newBucket either, just put it in freedBucket from above. | |||||
if !added { | |||||
added := a.addToNewBucket(oldest, freedBucket) | |||||
if !added { | |||||
log.Warn(Fmt("Could not migrate oldest %v to freedBucket %v", oldest, freedBucket)) | |||||
} | |||||
} | |||||
// Finally, add to bucket again. | |||||
added = a.addToOldBucket(ka, oldBucketIdx) | |||||
if !added { | |||||
log.Warn(Fmt("Could not re-add ka %v to oldBucketIdx %v", ka, oldBucketIdx)) | |||||
} | |||||
} | |||||
} | |||||
// doublesha256( key + sourcegroup + | |||||
// int64(doublesha256(key + group + sourcegroup))%bucket_per_group ) % num_new_buckets | |||||
func (a *AddrBook) calcNewBucket(addr, src *NetAddress) int { | |||||
data1 := []byte{} | |||||
data1 = append(data1, []byte(a.key)...) | |||||
data1 = append(data1, []byte(a.groupKey(addr))...) | |||||
data1 = append(data1, []byte(a.groupKey(src))...) | |||||
hash1 := doubleSha256(data1) | |||||
hash64 := binary.BigEndian.Uint64(hash1) | |||||
hash64 %= newBucketsPerGroup | |||||
var hashbuf [8]byte | |||||
binary.BigEndian.PutUint64(hashbuf[:], hash64) | |||||
data2 := []byte{} | |||||
data2 = append(data2, []byte(a.key)...) | |||||
data2 = append(data2, a.groupKey(src)...) | |||||
data2 = append(data2, hashbuf[:]...) | |||||
hash2 := doubleSha256(data2) | |||||
return int(binary.BigEndian.Uint64(hash2) % newBucketCount) | |||||
} | |||||
// doublesha256( key + group + | |||||
// int64(doublesha256(key + addr))%buckets_per_group ) % num_old_buckets | |||||
func (a *AddrBook) calcOldBucket(addr *NetAddress) int { | |||||
data1 := []byte{} | |||||
data1 = append(data1, []byte(a.key)...) | |||||
data1 = append(data1, []byte(addr.String())...) | |||||
hash1 := doubleSha256(data1) | |||||
hash64 := binary.BigEndian.Uint64(hash1) | |||||
hash64 %= oldBucketsPerGroup | |||||
var hashbuf [8]byte | |||||
binary.BigEndian.PutUint64(hashbuf[:], hash64) | |||||
data2 := []byte{} | |||||
data2 = append(data2, []byte(a.key)...) | |||||
data2 = append(data2, a.groupKey(addr)...) | |||||
data2 = append(data2, hashbuf[:]...) | |||||
hash2 := doubleSha256(data2) | |||||
return int(binary.BigEndian.Uint64(hash2) % oldBucketCount) | |||||
} | |||||
// Return a string representing the network group of this address. | |||||
// This is the /16 for IPv6, the /32 (/36 for he.net) for IPv6, the string | |||||
// "local" for a local address and the string "unroutable for an unroutable | |||||
// address. | |||||
func (a *AddrBook) groupKey(na *NetAddress) string { | |||||
if a.routabilityStrict && na.Local() { | |||||
return "local" | |||||
} | |||||
if a.routabilityStrict && !na.Routable() { | |||||
return "unroutable" | |||||
} | |||||
if ipv4 := na.IP.To4(); ipv4 != nil { | |||||
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(16, 32)}).String() | |||||
} | |||||
if na.RFC6145() || na.RFC6052() { | |||||
// last four bytes are the ip address | |||||
ip := net.IP(na.IP[12:16]) | |||||
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() | |||||
} | |||||
if na.RFC3964() { | |||||
ip := net.IP(na.IP[2:7]) | |||||
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() | |||||
} | |||||
if na.RFC4380() { | |||||
// teredo tunnels have the last 4 bytes as the v4 address XOR | |||||
// 0xff. | |||||
ip := net.IP(make([]byte, 4)) | |||||
for i, byte := range na.IP[12:16] { | |||||
ip[i] = byte ^ 0xff | |||||
} | |||||
return (&net.IPNet{IP: ip, Mask: net.CIDRMask(16, 32)}).String() | |||||
} | |||||
// OK, so now we know ourselves to be a IPv6 address. | |||||
// bitcoind uses /32 for everything, except for Hurricane Electric's | |||||
// (he.net) IP range, which it uses /36 for. | |||||
bits := 32 | |||||
heNet := &net.IPNet{IP: net.ParseIP("2001:470::"), | |||||
Mask: net.CIDRMask(32, 128)} | |||||
if heNet.Contains(na.IP) { | |||||
bits = 36 | |||||
} | |||||
return (&net.IPNet{IP: na.IP, Mask: net.CIDRMask(bits, 128)}).String() | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
/* | |||||
knownAddress | |||||
tracks information about a known network address that is used | |||||
to determine how viable an address is. | |||||
*/ | |||||
type knownAddress struct { | |||||
Addr *NetAddress | |||||
Src *NetAddress | |||||
Attempts int32 | |||||
LastAttempt time.Time | |||||
LastSuccess time.Time | |||||
BucketType byte | |||||
Buckets []int | |||||
} | |||||
func newKnownAddress(addr *NetAddress, src *NetAddress) *knownAddress { | |||||
return &knownAddress{ | |||||
Addr: addr, | |||||
Src: src, | |||||
Attempts: 0, | |||||
LastAttempt: time.Now(), | |||||
BucketType: bucketTypeNew, | |||||
Buckets: nil, | |||||
} | |||||
} | |||||
func (ka *knownAddress) isOld() bool { | |||||
return ka.BucketType == bucketTypeOld | |||||
} | |||||
func (ka *knownAddress) isNew() bool { | |||||
return ka.BucketType == bucketTypeNew | |||||
} | |||||
func (ka *knownAddress) markAttempt() { | |||||
now := time.Now() | |||||
ka.LastAttempt = now | |||||
ka.Attempts += 1 | |||||
} | |||||
func (ka *knownAddress) markGood() { | |||||
now := time.Now() | |||||
ka.LastAttempt = now | |||||
ka.Attempts = 0 | |||||
ka.LastSuccess = now | |||||
} | |||||
func (ka *knownAddress) addBucketRef(bucketIdx int) int { | |||||
for _, bucket := range ka.Buckets { | |||||
if bucket == bucketIdx { | |||||
log.Warn(Fmt("Bucket already exists in ka.Buckets: %v", ka)) | |||||
return -1 | |||||
} | |||||
} | |||||
ka.Buckets = append(ka.Buckets, bucketIdx) | |||||
return len(ka.Buckets) | |||||
} | |||||
func (ka *knownAddress) removeBucketRef(bucketIdx int) int { | |||||
buckets := []int{} | |||||
for _, bucket := range ka.Buckets { | |||||
if bucket != bucketIdx { | |||||
buckets = append(buckets, bucket) | |||||
} | |||||
} | |||||
if len(buckets) != len(ka.Buckets)-1 { | |||||
log.Warn(Fmt("bucketIdx not found in ka.Buckets: %v", ka)) | |||||
return -1 | |||||
} | |||||
ka.Buckets = buckets | |||||
return len(ka.Buckets) | |||||
} | |||||
/* | |||||
An address is bad if the address in question has not been tried in the last | |||||
minute and meets one of the following criteria: | |||||
1) It claims to be from the future | |||||
2) It hasn't been seen in over a month | |||||
3) It has failed at least three times and never succeeded | |||||
4) It has failed ten times in the last week | |||||
All addresses that meet these criteria are assumed to be worthless and not | |||||
worth keeping hold of. | |||||
*/ | |||||
func (ka *knownAddress) isBad() bool { | |||||
// Has been attempted in the last minute --> good | |||||
if ka.LastAttempt.Before(time.Now().Add(-1 * time.Minute)) { | |||||
return false | |||||
} | |||||
// Over a month old? | |||||
if ka.LastAttempt.After(time.Now().Add(-1 * numMissingDays * time.Hour * 24)) { | |||||
return true | |||||
} | |||||
// Never succeeded? | |||||
if ka.LastSuccess.IsZero() && ka.Attempts >= numRetries { | |||||
return true | |||||
} | |||||
// Hasn't succeeded in too long? | |||||
if ka.LastSuccess.Before(time.Now().Add(-1*minBadDays*time.Hour*24)) && | |||||
ka.Attempts >= maxFailures { | |||||
return true | |||||
} | |||||
return false | |||||
} |
@ -0,0 +1,166 @@ | |||||
package p2p | |||||
import ( | |||||
"fmt" | |||||
"io/ioutil" | |||||
"math/rand" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
) | |||||
func createTempFileName(prefix string) string { | |||||
f, err := ioutil.TempFile("", prefix) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
fname := f.Name() | |||||
err = f.Close() | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
return fname | |||||
} | |||||
func TestAddrBookSaveLoad(t *testing.T) { | |||||
fname := createTempFileName("addrbook_test") | |||||
// 0 addresses | |||||
book := NewAddrBook(fname, true) | |||||
book.saveToFile(fname) | |||||
book = NewAddrBook(fname, true) | |||||
book.loadFromFile(fname) | |||||
assert.Zero(t, book.Size()) | |||||
// 100 addresses | |||||
randAddrs := randNetAddressPairs(t, 100) | |||||
for _, addrSrc := range randAddrs { | |||||
book.AddAddress(addrSrc.addr, addrSrc.src) | |||||
} | |||||
assert.Equal(t, 100, book.Size()) | |||||
book.saveToFile(fname) | |||||
book = NewAddrBook(fname, true) | |||||
book.loadFromFile(fname) | |||||
assert.Equal(t, 100, book.Size()) | |||||
} | |||||
func TestAddrBookLookup(t *testing.T) { | |||||
fname := createTempFileName("addrbook_test") | |||||
randAddrs := randNetAddressPairs(t, 100) | |||||
book := NewAddrBook(fname, true) | |||||
for _, addrSrc := range randAddrs { | |||||
addr := addrSrc.addr | |||||
src := addrSrc.src | |||||
book.AddAddress(addr, src) | |||||
ka := book.addrLookup[addr.String()] | |||||
assert.NotNil(t, ka, "Expected to find KnownAddress %v but wasn't there.", addr) | |||||
if !(ka.Addr.Equals(addr) && ka.Src.Equals(src)) { | |||||
t.Fatalf("KnownAddress doesn't match addr & src") | |||||
} | |||||
} | |||||
} | |||||
func TestAddrBookPromoteToOld(t *testing.T) { | |||||
fname := createTempFileName("addrbook_test") | |||||
randAddrs := randNetAddressPairs(t, 100) | |||||
book := NewAddrBook(fname, true) | |||||
for _, addrSrc := range randAddrs { | |||||
book.AddAddress(addrSrc.addr, addrSrc.src) | |||||
} | |||||
// Attempt all addresses. | |||||
for _, addrSrc := range randAddrs { | |||||
book.MarkAttempt(addrSrc.addr) | |||||
} | |||||
// Promote half of them | |||||
for i, addrSrc := range randAddrs { | |||||
if i%2 == 0 { | |||||
book.MarkGood(addrSrc.addr) | |||||
} | |||||
} | |||||
// TODO: do more testing :) | |||||
selection := book.GetSelection() | |||||
t.Logf("selection: %v", selection) | |||||
if len(selection) > book.Size() { | |||||
t.Errorf("selection could not be bigger than the book") | |||||
} | |||||
} | |||||
func TestAddrBookHandlesDuplicates(t *testing.T) { | |||||
fname := createTempFileName("addrbook_test") | |||||
book := NewAddrBook(fname, true) | |||||
randAddrs := randNetAddressPairs(t, 100) | |||||
differentSrc := randIPv4Address(t) | |||||
for _, addrSrc := range randAddrs { | |||||
book.AddAddress(addrSrc.addr, addrSrc.src) | |||||
book.AddAddress(addrSrc.addr, addrSrc.src) // duplicate | |||||
book.AddAddress(addrSrc.addr, differentSrc) // different src | |||||
} | |||||
assert.Equal(t, 100, book.Size()) | |||||
} | |||||
type netAddressPair struct { | |||||
addr *NetAddress | |||||
src *NetAddress | |||||
} | |||||
func randNetAddressPairs(t *testing.T, n int) []netAddressPair { | |||||
randAddrs := make([]netAddressPair, n) | |||||
for i := 0; i < n; i++ { | |||||
randAddrs[i] = netAddressPair{addr: randIPv4Address(t), src: randIPv4Address(t)} | |||||
} | |||||
return randAddrs | |||||
} | |||||
func randIPv4Address(t *testing.T) *NetAddress { | |||||
for { | |||||
ip := fmt.Sprintf("%v.%v.%v.%v", | |||||
rand.Intn(254)+1, | |||||
rand.Intn(255), | |||||
rand.Intn(255), | |||||
rand.Intn(255), | |||||
) | |||||
port := rand.Intn(65535-1) + 1 | |||||
addr, err := NewNetAddressString(fmt.Sprintf("%v:%v", ip, port)) | |||||
assert.Nil(t, err, "error generating rand network address") | |||||
if addr.Routable() { | |||||
return addr | |||||
} | |||||
} | |||||
} | |||||
func TestAddrBookRemoveAddress(t *testing.T) { | |||||
fname := createTempFileName("addrbook_test") | |||||
book := NewAddrBook(fname, true) | |||||
addr := randIPv4Address(t) | |||||
book.AddAddress(addr, addr) | |||||
assert.Equal(t, 1, book.Size()) | |||||
book.RemoveAddress(addr) | |||||
assert.Equal(t, 0, book.Size()) | |||||
nonExistingAddr := randIPv4Address(t) | |||||
book.RemoveAddress(nonExistingAddr) | |||||
assert.Equal(t, 0, book.Size()) | |||||
} |
@ -0,0 +1,45 @@ | |||||
package p2p | |||||
import ( | |||||
cfg "github.com/tendermint/go-config" | |||||
) | |||||
const ( | |||||
// Switch config keys | |||||
configKeyDialTimeoutSeconds = "dial_timeout_seconds" | |||||
configKeyHandshakeTimeoutSeconds = "handshake_timeout_seconds" | |||||
configKeyMaxNumPeers = "max_num_peers" | |||||
configKeyAuthEnc = "authenticated_encryption" | |||||
// MConnection config keys | |||||
configKeySendRate = "send_rate" | |||||
configKeyRecvRate = "recv_rate" | |||||
// Fuzz params | |||||
configFuzzEnable = "fuzz_enable" // use the fuzz wrapped conn | |||||
configFuzzMode = "fuzz_mode" // eg. drop, delay | |||||
configFuzzMaxDelayMilliseconds = "fuzz_max_delay_milliseconds" | |||||
configFuzzProbDropRW = "fuzz_prob_drop_rw" | |||||
configFuzzProbDropConn = "fuzz_prob_drop_conn" | |||||
configFuzzProbSleep = "fuzz_prob_sleep" | |||||
) | |||||
func setConfigDefaults(config cfg.Config) { | |||||
// Switch default config | |||||
config.SetDefault(configKeyDialTimeoutSeconds, 3) | |||||
config.SetDefault(configKeyHandshakeTimeoutSeconds, 20) | |||||
config.SetDefault(configKeyMaxNumPeers, 50) | |||||
config.SetDefault(configKeyAuthEnc, true) | |||||
// MConnection default config | |||||
config.SetDefault(configKeySendRate, 512000) // 500KB/s | |||||
config.SetDefault(configKeyRecvRate, 512000) // 500KB/s | |||||
// Fuzz defaults | |||||
config.SetDefault(configFuzzEnable, false) | |||||
config.SetDefault(configFuzzMode, FuzzModeDrop) | |||||
config.SetDefault(configFuzzMaxDelayMilliseconds, 3000) | |||||
config.SetDefault(configFuzzProbDropRW, 0.2) | |||||
config.SetDefault(configFuzzProbDropConn, 0.00) | |||||
config.SetDefault(configFuzzProbSleep, 0.00) | |||||
} |
@ -0,0 +1,686 @@ | |||||
package p2p | |||||
import ( | |||||
"bufio" | |||||
"fmt" | |||||
"io" | |||||
"math" | |||||
"net" | |||||
"runtime/debug" | |||||
"sync/atomic" | |||||
"time" | |||||
wire "github.com/tendermint/go-wire" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
flow "github.com/tendermint/tmlibs/flowrate" | |||||
) | |||||
const ( | |||||
numBatchMsgPackets = 10 | |||||
minReadBufferSize = 1024 | |||||
minWriteBufferSize = 65536 | |||||
updateState = 2 * time.Second | |||||
pingTimeout = 40 * time.Second | |||||
flushThrottle = 100 * time.Millisecond | |||||
defaultSendQueueCapacity = 1 | |||||
defaultSendRate = int64(512000) // 500KB/s | |||||
defaultRecvBufferCapacity = 4096 | |||||
defaultRecvMessageCapacity = 22020096 // 21MB | |||||
defaultRecvRate = int64(512000) // 500KB/s | |||||
defaultSendTimeout = 10 * time.Second | |||||
) | |||||
type receiveCbFunc func(chID byte, msgBytes []byte) | |||||
type errorCbFunc func(interface{}) | |||||
/* | |||||
Each peer has one `MConnection` (multiplex connection) instance. | |||||
__multiplex__ *noun* a system or signal involving simultaneous transmission of | |||||
several messages along a single channel of communication. | |||||
Each `MConnection` handles message transmission on multiple abstract communication | |||||
`Channel`s. Each channel has a globally unique byte id. | |||||
The byte id and the relative priorities of each `Channel` are configured upon | |||||
initialization of the connection. | |||||
There are two methods for sending messages: | |||||
func (m MConnection) Send(chID byte, msg interface{}) bool {} | |||||
func (m MConnection) TrySend(chID byte, msg interface{}) bool {} | |||||
`Send(chID, msg)` is a blocking call that waits until `msg` is successfully queued | |||||
for the channel with the given id byte `chID`, or until the request times out. | |||||
The message `msg` is serialized using the `tendermint/wire` submodule's | |||||
`WriteBinary()` reflection routine. | |||||
`TrySend(chID, msg)` is a nonblocking call that returns false if the channel's | |||||
queue is full. | |||||
Inbound message bytes are handled with an onReceive callback function. | |||||
*/ | |||||
type MConnection struct { | |||||
cmn.BaseService | |||||
conn net.Conn | |||||
bufReader *bufio.Reader | |||||
bufWriter *bufio.Writer | |||||
sendMonitor *flow.Monitor | |||||
recvMonitor *flow.Monitor | |||||
send chan struct{} | |||||
pong chan struct{} | |||||
channels []*Channel | |||||
channelsIdx map[byte]*Channel | |||||
onReceive receiveCbFunc | |||||
onError errorCbFunc | |||||
errored uint32 | |||||
config *MConnConfig | |||||
quit chan struct{} | |||||
flushTimer *cmn.ThrottleTimer // flush writes as necessary but throttled. | |||||
pingTimer *cmn.RepeatTimer // send pings periodically | |||||
chStatsTimer *cmn.RepeatTimer // update channel stats periodically | |||||
LocalAddress *NetAddress | |||||
RemoteAddress *NetAddress | |||||
} | |||||
// MConnConfig is a MConnection configuration. | |||||
type MConnConfig struct { | |||||
SendRate int64 | |||||
RecvRate int64 | |||||
} | |||||
// DefaultMConnConfig returns the default config. | |||||
func DefaultMConnConfig() *MConnConfig { | |||||
return &MConnConfig{ | |||||
SendRate: defaultSendRate, | |||||
RecvRate: defaultRecvRate, | |||||
} | |||||
} | |||||
// NewMConnection wraps net.Conn and creates multiplex connection | |||||
func NewMConnection(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc) *MConnection { | |||||
return NewMConnectionWithConfig( | |||||
conn, | |||||
chDescs, | |||||
onReceive, | |||||
onError, | |||||
DefaultMConnConfig()) | |||||
} | |||||
// NewMConnectionWithConfig wraps net.Conn and creates multiplex connection with a config | |||||
func NewMConnectionWithConfig(conn net.Conn, chDescs []*ChannelDescriptor, onReceive receiveCbFunc, onError errorCbFunc, config *MConnConfig) *MConnection { | |||||
mconn := &MConnection{ | |||||
conn: conn, | |||||
bufReader: bufio.NewReaderSize(conn, minReadBufferSize), | |||||
bufWriter: bufio.NewWriterSize(conn, minWriteBufferSize), | |||||
sendMonitor: flow.New(0, 0), | |||||
recvMonitor: flow.New(0, 0), | |||||
send: make(chan struct{}, 1), | |||||
pong: make(chan struct{}), | |||||
onReceive: onReceive, | |||||
onError: onError, | |||||
config: config, | |||||
LocalAddress: NewNetAddress(conn.LocalAddr()), | |||||
RemoteAddress: NewNetAddress(conn.RemoteAddr()), | |||||
} | |||||
// Create channels | |||||
var channelsIdx = map[byte]*Channel{} | |||||
var channels = []*Channel{} | |||||
for _, desc := range chDescs { | |||||
descCopy := *desc // copy the desc else unsafe access across connections | |||||
channel := newChannel(mconn, &descCopy) | |||||
channelsIdx[channel.id] = channel | |||||
channels = append(channels, channel) | |||||
} | |||||
mconn.channels = channels | |||||
mconn.channelsIdx = channelsIdx | |||||
mconn.BaseService = *cmn.NewBaseService(log, "MConnection", mconn) | |||||
return mconn | |||||
} | |||||
func (c *MConnection) OnStart() error { | |||||
c.BaseService.OnStart() | |||||
c.quit = make(chan struct{}) | |||||
c.flushTimer = cmn.NewThrottleTimer("flush", flushThrottle) | |||||
c.pingTimer = cmn.NewRepeatTimer("ping", pingTimeout) | |||||
c.chStatsTimer = cmn.NewRepeatTimer("chStats", updateState) | |||||
go c.sendRoutine() | |||||
go c.recvRoutine() | |||||
return nil | |||||
} | |||||
func (c *MConnection) OnStop() { | |||||
c.BaseService.OnStop() | |||||
c.flushTimer.Stop() | |||||
c.pingTimer.Stop() | |||||
c.chStatsTimer.Stop() | |||||
if c.quit != nil { | |||||
close(c.quit) | |||||
} | |||||
c.conn.Close() | |||||
// We can't close pong safely here because | |||||
// recvRoutine may write to it after we've stopped. | |||||
// Though it doesn't need to get closed at all, | |||||
// we close it @ recvRoutine. | |||||
// close(c.pong) | |||||
} | |||||
func (c *MConnection) String() string { | |||||
return fmt.Sprintf("MConn{%v}", c.conn.RemoteAddr()) | |||||
} | |||||
func (c *MConnection) flush() { | |||||
log.Debug("Flush", "conn", c) | |||||
err := c.bufWriter.Flush() | |||||
if err != nil { | |||||
log.Warn("MConnection flush failed", "error", err) | |||||
} | |||||
} | |||||
// Catch panics, usually caused by remote disconnects. | |||||
func (c *MConnection) _recover() { | |||||
if r := recover(); r != nil { | |||||
stack := debug.Stack() | |||||
err := cmn.StackError{r, stack} | |||||
c.stopForError(err) | |||||
} | |||||
} | |||||
func (c *MConnection) stopForError(r interface{}) { | |||||
c.Stop() | |||||
if atomic.CompareAndSwapUint32(&c.errored, 0, 1) { | |||||
if c.onError != nil { | |||||
c.onError(r) | |||||
} | |||||
} | |||||
} | |||||
// Queues a message to be sent to channel. | |||||
func (c *MConnection) Send(chID byte, msg interface{}) bool { | |||||
if !c.IsRunning() { | |||||
return false | |||||
} | |||||
log.Debug("Send", "channel", chID, "conn", c, "msg", msg) //, "bytes", wire.BinaryBytes(msg)) | |||||
// Send message to channel. | |||||
channel, ok := c.channelsIdx[chID] | |||||
if !ok { | |||||
log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID)) | |||||
return false | |||||
} | |||||
success := channel.sendBytes(wire.BinaryBytes(msg)) | |||||
if success { | |||||
// Wake up sendRoutine if necessary | |||||
select { | |||||
case c.send <- struct{}{}: | |||||
default: | |||||
} | |||||
} else { | |||||
log.Warn("Send failed", "channel", chID, "conn", c, "msg", msg) | |||||
} | |||||
return success | |||||
} | |||||
// Queues a message to be sent to channel. | |||||
// Nonblocking, returns true if successful. | |||||
func (c *MConnection) TrySend(chID byte, msg interface{}) bool { | |||||
if !c.IsRunning() { | |||||
return false | |||||
} | |||||
log.Debug("TrySend", "channel", chID, "conn", c, "msg", msg) | |||||
// Send message to channel. | |||||
channel, ok := c.channelsIdx[chID] | |||||
if !ok { | |||||
log.Error(cmn.Fmt("Cannot send bytes, unknown channel %X", chID)) | |||||
return false | |||||
} | |||||
ok = channel.trySendBytes(wire.BinaryBytes(msg)) | |||||
if ok { | |||||
// Wake up sendRoutine if necessary | |||||
select { | |||||
case c.send <- struct{}{}: | |||||
default: | |||||
} | |||||
} | |||||
return ok | |||||
} | |||||
// CanSend returns true if you can send more data onto the chID, false | |||||
// otherwise. Use only as a heuristic. | |||||
func (c *MConnection) CanSend(chID byte) bool { | |||||
if !c.IsRunning() { | |||||
return false | |||||
} | |||||
channel, ok := c.channelsIdx[chID] | |||||
if !ok { | |||||
log.Error(cmn.Fmt("Unknown channel %X", chID)) | |||||
return false | |||||
} | |||||
return channel.canSend() | |||||
} | |||||
// sendRoutine polls for packets to send from channels. | |||||
func (c *MConnection) sendRoutine() { | |||||
defer c._recover() | |||||
FOR_LOOP: | |||||
for { | |||||
var n int | |||||
var err error | |||||
select { | |||||
case <-c.flushTimer.Ch: | |||||
// NOTE: flushTimer.Set() must be called every time | |||||
// something is written to .bufWriter. | |||||
c.flush() | |||||
case <-c.chStatsTimer.Ch: | |||||
for _, channel := range c.channels { | |||||
channel.updateStats() | |||||
} | |||||
case <-c.pingTimer.Ch: | |||||
log.Debug("Send Ping") | |||||
wire.WriteByte(packetTypePing, c.bufWriter, &n, &err) | |||||
c.sendMonitor.Update(int(n)) | |||||
c.flush() | |||||
case <-c.pong: | |||||
log.Debug("Send Pong") | |||||
wire.WriteByte(packetTypePong, c.bufWriter, &n, &err) | |||||
c.sendMonitor.Update(int(n)) | |||||
c.flush() | |||||
case <-c.quit: | |||||
break FOR_LOOP | |||||
case <-c.send: | |||||
// Send some msgPackets | |||||
eof := c.sendSomeMsgPackets() | |||||
if !eof { | |||||
// Keep sendRoutine awake. | |||||
select { | |||||
case c.send <- struct{}{}: | |||||
default: | |||||
} | |||||
} | |||||
} | |||||
if !c.IsRunning() { | |||||
break FOR_LOOP | |||||
} | |||||
if err != nil { | |||||
log.Warn("Connection failed @ sendRoutine", "conn", c, "error", err) | |||||
c.stopForError(err) | |||||
break FOR_LOOP | |||||
} | |||||
} | |||||
// Cleanup | |||||
} | |||||
// Returns true if messages from channels were exhausted. | |||||
// Blocks in accordance to .sendMonitor throttling. | |||||
func (c *MConnection) sendSomeMsgPackets() bool { | |||||
// Block until .sendMonitor says we can write. | |||||
// Once we're ready we send more than we asked for, | |||||
// but amortized it should even out. | |||||
c.sendMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.SendRate), true) | |||||
// Now send some msgPackets. | |||||
for i := 0; i < numBatchMsgPackets; i++ { | |||||
if c.sendMsgPacket() { | |||||
return true | |||||
} | |||||
} | |||||
return false | |||||
} | |||||
// Returns true if messages from channels were exhausted. | |||||
func (c *MConnection) sendMsgPacket() bool { | |||||
// Choose a channel to create a msgPacket from. | |||||
// The chosen channel will be the one whose recentlySent/priority is the least. | |||||
var leastRatio float32 = math.MaxFloat32 | |||||
var leastChannel *Channel | |||||
for _, channel := range c.channels { | |||||
// If nothing to send, skip this channel | |||||
if !channel.isSendPending() { | |||||
continue | |||||
} | |||||
// Get ratio, and keep track of lowest ratio. | |||||
ratio := float32(channel.recentlySent) / float32(channel.priority) | |||||
if ratio < leastRatio { | |||||
leastRatio = ratio | |||||
leastChannel = channel | |||||
} | |||||
} | |||||
// Nothing to send? | |||||
if leastChannel == nil { | |||||
return true | |||||
} else { | |||||
// log.Info("Found a msgPacket to send") | |||||
} | |||||
// Make & send a msgPacket from this channel | |||||
n, err := leastChannel.writeMsgPacketTo(c.bufWriter) | |||||
if err != nil { | |||||
log.Warn("Failed to write msgPacket", "error", err) | |||||
c.stopForError(err) | |||||
return true | |||||
} | |||||
c.sendMonitor.Update(int(n)) | |||||
c.flushTimer.Set() | |||||
return false | |||||
} | |||||
// recvRoutine reads msgPackets and reconstructs the message using the channels' "recving" buffer. | |||||
// After a whole message has been assembled, it's pushed to onReceive(). | |||||
// Blocks depending on how the connection is throttled. | |||||
func (c *MConnection) recvRoutine() { | |||||
defer c._recover() | |||||
FOR_LOOP: | |||||
for { | |||||
// Block until .recvMonitor says we can read. | |||||
c.recvMonitor.Limit(maxMsgPacketTotalSize, atomic.LoadInt64(&c.config.RecvRate), true) | |||||
/* | |||||
// Peek into bufReader for debugging | |||||
if numBytes := c.bufReader.Buffered(); numBytes > 0 { | |||||
log.Info("Peek connection buffer", "numBytes", numBytes, "bytes", log15.Lazy{func() []byte { | |||||
bytes, err := c.bufReader.Peek(MinInt(numBytes, 100)) | |||||
if err == nil { | |||||
return bytes | |||||
} else { | |||||
log.Warn("Error peeking connection buffer", "error", err) | |||||
return nil | |||||
} | |||||
}}) | |||||
} | |||||
*/ | |||||
// Read packet type | |||||
var n int | |||||
var err error | |||||
pktType := wire.ReadByte(c.bufReader, &n, &err) | |||||
c.recvMonitor.Update(int(n)) | |||||
if err != nil { | |||||
if c.IsRunning() { | |||||
log.Warn("Connection failed @ recvRoutine (reading byte)", "conn", c, "error", err) | |||||
c.stopForError(err) | |||||
} | |||||
break FOR_LOOP | |||||
} | |||||
// Read more depending on packet type. | |||||
switch pktType { | |||||
case packetTypePing: | |||||
// TODO: prevent abuse, as they cause flush()'s. | |||||
log.Debug("Receive Ping") | |||||
c.pong <- struct{}{} | |||||
case packetTypePong: | |||||
// do nothing | |||||
log.Debug("Receive Pong") | |||||
case packetTypeMsg: | |||||
pkt, n, err := msgPacket{}, int(0), error(nil) | |||||
wire.ReadBinaryPtr(&pkt, c.bufReader, maxMsgPacketTotalSize, &n, &err) | |||||
c.recvMonitor.Update(int(n)) | |||||
if err != nil { | |||||
if c.IsRunning() { | |||||
log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err) | |||||
c.stopForError(err) | |||||
} | |||||
break FOR_LOOP | |||||
} | |||||
channel, ok := c.channelsIdx[pkt.ChannelID] | |||||
if !ok || channel == nil { | |||||
cmn.PanicQ(cmn.Fmt("Unknown channel %X", pkt.ChannelID)) | |||||
} | |||||
msgBytes, err := channel.recvMsgPacket(pkt) | |||||
if err != nil { | |||||
if c.IsRunning() { | |||||
log.Warn("Connection failed @ recvRoutine", "conn", c, "error", err) | |||||
c.stopForError(err) | |||||
} | |||||
break FOR_LOOP | |||||
} | |||||
if msgBytes != nil { | |||||
log.Debug("Received bytes", "chID", pkt.ChannelID, "msgBytes", msgBytes) | |||||
c.onReceive(pkt.ChannelID, msgBytes) | |||||
} | |||||
default: | |||||
cmn.PanicSanity(cmn.Fmt("Unknown message type %X", pktType)) | |||||
} | |||||
// TODO: shouldn't this go in the sendRoutine? | |||||
// Better to send a ping packet when *we* haven't sent anything for a while. | |||||
c.pingTimer.Reset() | |||||
} | |||||
// Cleanup | |||||
close(c.pong) | |||||
for _ = range c.pong { | |||||
// Drain | |||||
} | |||||
} | |||||
type ConnectionStatus struct { | |||||
SendMonitor flow.Status | |||||
RecvMonitor flow.Status | |||||
Channels []ChannelStatus | |||||
} | |||||
type ChannelStatus struct { | |||||
ID byte | |||||
SendQueueCapacity int | |||||
SendQueueSize int | |||||
Priority int | |||||
RecentlySent int64 | |||||
} | |||||
func (c *MConnection) Status() ConnectionStatus { | |||||
var status ConnectionStatus | |||||
status.SendMonitor = c.sendMonitor.Status() | |||||
status.RecvMonitor = c.recvMonitor.Status() | |||||
status.Channels = make([]ChannelStatus, len(c.channels)) | |||||
for i, channel := range c.channels { | |||||
status.Channels[i] = ChannelStatus{ | |||||
ID: channel.id, | |||||
SendQueueCapacity: cap(channel.sendQueue), | |||||
SendQueueSize: int(channel.sendQueueSize), // TODO use atomic | |||||
Priority: channel.priority, | |||||
RecentlySent: channel.recentlySent, | |||||
} | |||||
} | |||||
return status | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
type ChannelDescriptor struct { | |||||
ID byte | |||||
Priority int | |||||
SendQueueCapacity int | |||||
RecvBufferCapacity int | |||||
RecvMessageCapacity int | |||||
} | |||||
func (chDesc *ChannelDescriptor) FillDefaults() { | |||||
if chDesc.SendQueueCapacity == 0 { | |||||
chDesc.SendQueueCapacity = defaultSendQueueCapacity | |||||
} | |||||
if chDesc.RecvBufferCapacity == 0 { | |||||
chDesc.RecvBufferCapacity = defaultRecvBufferCapacity | |||||
} | |||||
if chDesc.RecvMessageCapacity == 0 { | |||||
chDesc.RecvMessageCapacity = defaultRecvMessageCapacity | |||||
} | |||||
} | |||||
// TODO: lowercase. | |||||
// NOTE: not goroutine-safe. | |||||
type Channel struct { | |||||
conn *MConnection | |||||
desc *ChannelDescriptor | |||||
id byte | |||||
sendQueue chan []byte | |||||
sendQueueSize int32 // atomic. | |||||
recving []byte | |||||
sending []byte | |||||
priority int | |||||
recentlySent int64 // exponential moving average | |||||
} | |||||
func newChannel(conn *MConnection, desc *ChannelDescriptor) *Channel { | |||||
desc.FillDefaults() | |||||
if desc.Priority <= 0 { | |||||
cmn.PanicSanity("Channel default priority must be a postive integer") | |||||
} | |||||
return &Channel{ | |||||
conn: conn, | |||||
desc: desc, | |||||
id: desc.ID, | |||||
sendQueue: make(chan []byte, desc.SendQueueCapacity), | |||||
recving: make([]byte, 0, desc.RecvBufferCapacity), | |||||
priority: desc.Priority, | |||||
} | |||||
} | |||||
// Queues message to send to this channel. | |||||
// Goroutine-safe | |||||
// Times out (and returns false) after defaultSendTimeout | |||||
func (ch *Channel) sendBytes(bytes []byte) bool { | |||||
select { | |||||
case ch.sendQueue <- bytes: | |||||
atomic.AddInt32(&ch.sendQueueSize, 1) | |||||
return true | |||||
case <-time.After(defaultSendTimeout): | |||||
return false | |||||
} | |||||
} | |||||
// Queues message to send to this channel. | |||||
// Nonblocking, returns true if successful. | |||||
// Goroutine-safe | |||||
func (ch *Channel) trySendBytes(bytes []byte) bool { | |||||
select { | |||||
case ch.sendQueue <- bytes: | |||||
atomic.AddInt32(&ch.sendQueueSize, 1) | |||||
return true | |||||
default: | |||||
return false | |||||
} | |||||
} | |||||
// Goroutine-safe | |||||
func (ch *Channel) loadSendQueueSize() (size int) { | |||||
return int(atomic.LoadInt32(&ch.sendQueueSize)) | |||||
} | |||||
// Goroutine-safe | |||||
// Use only as a heuristic. | |||||
func (ch *Channel) canSend() bool { | |||||
return ch.loadSendQueueSize() < defaultSendQueueCapacity | |||||
} | |||||
// Returns true if any msgPackets are pending to be sent. | |||||
// Call before calling nextMsgPacket() | |||||
// Goroutine-safe | |||||
func (ch *Channel) isSendPending() bool { | |||||
if len(ch.sending) == 0 { | |||||
if len(ch.sendQueue) == 0 { | |||||
return false | |||||
} | |||||
ch.sending = <-ch.sendQueue | |||||
} | |||||
return true | |||||
} | |||||
// Creates a new msgPacket to send. | |||||
// Not goroutine-safe | |||||
func (ch *Channel) nextMsgPacket() msgPacket { | |||||
packet := msgPacket{} | |||||
packet.ChannelID = byte(ch.id) | |||||
packet.Bytes = ch.sending[:cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending))] | |||||
if len(ch.sending) <= maxMsgPacketPayloadSize { | |||||
packet.EOF = byte(0x01) | |||||
ch.sending = nil | |||||
atomic.AddInt32(&ch.sendQueueSize, -1) // decrement sendQueueSize | |||||
} else { | |||||
packet.EOF = byte(0x00) | |||||
ch.sending = ch.sending[cmn.MinInt(maxMsgPacketPayloadSize, len(ch.sending)):] | |||||
} | |||||
return packet | |||||
} | |||||
// Writes next msgPacket to w. | |||||
// Not goroutine-safe | |||||
func (ch *Channel) writeMsgPacketTo(w io.Writer) (n int, err error) { | |||||
packet := ch.nextMsgPacket() | |||||
log.Debug("Write Msg Packet", "conn", ch.conn, "packet", packet) | |||||
wire.WriteByte(packetTypeMsg, w, &n, &err) | |||||
wire.WriteBinary(packet, w, &n, &err) | |||||
if err == nil { | |||||
ch.recentlySent += int64(n) | |||||
} | |||||
return | |||||
} | |||||
// Handles incoming msgPackets. Returns a msg bytes if msg is complete. | |||||
// Not goroutine-safe | |||||
func (ch *Channel) recvMsgPacket(packet msgPacket) ([]byte, error) { | |||||
// log.Debug("Read Msg Packet", "conn", ch.conn, "packet", packet) | |||||
if ch.desc.RecvMessageCapacity < len(ch.recving)+len(packet.Bytes) { | |||||
return nil, wire.ErrBinaryReadOverflow | |||||
} | |||||
ch.recving = append(ch.recving, packet.Bytes...) | |||||
if packet.EOF == byte(0x01) { | |||||
msgBytes := ch.recving | |||||
// clear the slice without re-allocating. | |||||
// http://stackoverflow.com/questions/16971741/how-do-you-clear-a-slice-in-go | |||||
// suggests this could be a memory leak, but we might as well keep the memory for the channel until it closes, | |||||
// at which point the recving slice stops being used and should be garbage collected | |||||
ch.recving = ch.recving[:0] // make([]byte, 0, ch.desc.RecvBufferCapacity) | |||||
return msgBytes, nil | |||||
} | |||||
return nil, nil | |||||
} | |||||
// Call this periodically to update stats for throttling purposes. | |||||
// Not goroutine-safe | |||||
func (ch *Channel) updateStats() { | |||||
// Exponential decay of stats. | |||||
// TODO: optimize. | |||||
ch.recentlySent = int64(float64(ch.recentlySent) * 0.8) | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
const ( | |||||
maxMsgPacketPayloadSize = 1024 | |||||
maxMsgPacketOverheadSize = 10 // It's actually lower but good enough | |||||
maxMsgPacketTotalSize = maxMsgPacketPayloadSize + maxMsgPacketOverheadSize | |||||
packetTypePing = byte(0x01) | |||||
packetTypePong = byte(0x02) | |||||
packetTypeMsg = byte(0x03) | |||||
) | |||||
// Messages in channels are chopped into smaller msgPackets for multiplexing. | |||||
type msgPacket struct { | |||||
ChannelID byte | |||||
EOF byte // 1 means message ends here. | |||||
Bytes []byte | |||||
} | |||||
func (p msgPacket) String() string { | |||||
return fmt.Sprintf("MsgPacket{%X:%X T:%X}", p.ChannelID, p.Bytes, p.EOF) | |||||
} |
@ -0,0 +1,139 @@ | |||||
package p2p_test | |||||
import ( | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
p2p "github.com/tendermint/tendermint/p2p" | |||||
) | |||||
func createMConnection(conn net.Conn) *p2p.MConnection { | |||||
onReceive := func(chID byte, msgBytes []byte) { | |||||
} | |||||
onError := func(r interface{}) { | |||||
} | |||||
return createMConnectionWithCallbacks(conn, onReceive, onError) | |||||
} | |||||
func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *p2p.MConnection { | |||||
chDescs := []*p2p.ChannelDescriptor{&p2p.ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}} | |||||
return p2p.NewMConnection(conn, chDescs, onReceive, onError) | |||||
} | |||||
func TestMConnectionSend(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
server, client := net.Pipe() | |||||
defer server.Close() | |||||
defer client.Close() | |||||
mconn := createMConnection(client) | |||||
_, err := mconn.Start() | |||||
require.Nil(err) | |||||
defer mconn.Stop() | |||||
msg := "Ant-Man" | |||||
assert.True(mconn.Send(0x01, msg)) | |||||
// Note: subsequent Send/TrySend calls could pass because we are reading from | |||||
// the send queue in a separate goroutine. | |||||
server.Read(make([]byte, len(msg))) | |||||
assert.True(mconn.CanSend(0x01)) | |||||
msg = "Spider-Man" | |||||
assert.True(mconn.TrySend(0x01, msg)) | |||||
server.Read(make([]byte, len(msg))) | |||||
assert.False(mconn.CanSend(0x05), "CanSend should return false because channel is unknown") | |||||
assert.False(mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown") | |||||
} | |||||
func TestMConnectionReceive(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
server, client := net.Pipe() | |||||
defer server.Close() | |||||
defer client.Close() | |||||
receivedCh := make(chan []byte) | |||||
errorsCh := make(chan interface{}) | |||||
onReceive := func(chID byte, msgBytes []byte) { | |||||
receivedCh <- msgBytes | |||||
} | |||||
onError := func(r interface{}) { | |||||
errorsCh <- r | |||||
} | |||||
mconn1 := createMConnectionWithCallbacks(client, onReceive, onError) | |||||
_, err := mconn1.Start() | |||||
require.Nil(err) | |||||
defer mconn1.Stop() | |||||
mconn2 := createMConnection(server) | |||||
_, err = mconn2.Start() | |||||
require.Nil(err) | |||||
defer mconn2.Stop() | |||||
msg := "Cyclops" | |||||
assert.True(mconn2.Send(0x01, msg)) | |||||
select { | |||||
case receivedBytes := <-receivedCh: | |||||
assert.Equal([]byte(msg), receivedBytes[2:]) // first 3 bytes are internal | |||||
case err := <-errorsCh: | |||||
t.Fatalf("Expected %s, got %+v", msg, err) | |||||
case <-time.After(500 * time.Millisecond): | |||||
t.Fatalf("Did not receive %s message in 500ms", msg) | |||||
} | |||||
} | |||||
func TestMConnectionStatus(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
server, client := net.Pipe() | |||||
defer server.Close() | |||||
defer client.Close() | |||||
mconn := createMConnection(client) | |||||
_, err := mconn.Start() | |||||
require.Nil(err) | |||||
defer mconn.Stop() | |||||
status := mconn.Status() | |||||
assert.NotNil(status) | |||||
assert.Zero(status.Channels[0].SendQueueSize) | |||||
} | |||||
func TestMConnectionStopsAndReturnsError(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
server, client := net.Pipe() | |||||
defer server.Close() | |||||
defer client.Close() | |||||
receivedCh := make(chan []byte) | |||||
errorsCh := make(chan interface{}) | |||||
onReceive := func(chID byte, msgBytes []byte) { | |||||
receivedCh <- msgBytes | |||||
} | |||||
onError := func(r interface{}) { | |||||
errorsCh <- r | |||||
} | |||||
mconn := createMConnectionWithCallbacks(client, onReceive, onError) | |||||
_, err := mconn.Start() | |||||
require.Nil(err) | |||||
defer mconn.Stop() | |||||
client.Close() | |||||
select { | |||||
case receivedBytes := <-receivedCh: | |||||
t.Fatalf("Expected error, got %v", receivedBytes) | |||||
case err := <-errorsCh: | |||||
assert.NotNil(err) | |||||
assert.False(mconn.IsRunning()) | |||||
case <-time.After(500 * time.Millisecond): | |||||
t.Fatal("Did not receive error in 500ms") | |||||
} | |||||
} |
@ -0,0 +1,173 @@ | |||||
package p2p | |||||
import ( | |||||
"math/rand" | |||||
"net" | |||||
"sync" | |||||
"time" | |||||
) | |||||
const ( | |||||
// FuzzModeDrop is a mode in which we randomly drop reads/writes, connections or sleep | |||||
FuzzModeDrop = iota | |||||
// FuzzModeDelay is a mode in which we randomly sleep | |||||
FuzzModeDelay | |||||
) | |||||
// FuzzedConnection wraps any net.Conn and depending on the mode either delays | |||||
// reads/writes or randomly drops reads/writes/connections. | |||||
type FuzzedConnection struct { | |||||
conn net.Conn | |||||
mtx sync.Mutex | |||||
start <-chan time.Time | |||||
active bool | |||||
config *FuzzConnConfig | |||||
} | |||||
// FuzzConnConfig is a FuzzedConnection configuration. | |||||
type FuzzConnConfig struct { | |||||
Mode int | |||||
MaxDelay time.Duration | |||||
ProbDropRW float64 | |||||
ProbDropConn float64 | |||||
ProbSleep float64 | |||||
} | |||||
// DefaultFuzzConnConfig returns the default config. | |||||
func DefaultFuzzConnConfig() *FuzzConnConfig { | |||||
return &FuzzConnConfig{ | |||||
Mode: FuzzModeDrop, | |||||
MaxDelay: 3 * time.Second, | |||||
ProbDropRW: 0.2, | |||||
ProbDropConn: 0.00, | |||||
ProbSleep: 0.00, | |||||
} | |||||
} | |||||
// FuzzConn creates a new FuzzedConnection. Fuzzing starts immediately. | |||||
func FuzzConn(conn net.Conn) net.Conn { | |||||
return FuzzConnFromConfig(conn, DefaultFuzzConnConfig()) | |||||
} | |||||
// FuzzConnFromConfig creates a new FuzzedConnection from a config. Fuzzing | |||||
// starts immediately. | |||||
func FuzzConnFromConfig(conn net.Conn, config *FuzzConnConfig) net.Conn { | |||||
return &FuzzedConnection{ | |||||
conn: conn, | |||||
start: make(<-chan time.Time), | |||||
active: true, | |||||
config: config, | |||||
} | |||||
} | |||||
// FuzzConnAfter creates a new FuzzedConnection. Fuzzing starts when the | |||||
// duration elapses. | |||||
func FuzzConnAfter(conn net.Conn, d time.Duration) net.Conn { | |||||
return FuzzConnAfterFromConfig(conn, d, DefaultFuzzConnConfig()) | |||||
} | |||||
// FuzzConnAfterFromConfig creates a new FuzzedConnection from a config. | |||||
// Fuzzing starts when the duration elapses. | |||||
func FuzzConnAfterFromConfig(conn net.Conn, d time.Duration, config *FuzzConnConfig) net.Conn { | |||||
return &FuzzedConnection{ | |||||
conn: conn, | |||||
start: time.After(d), | |||||
active: false, | |||||
config: config, | |||||
} | |||||
} | |||||
// Config returns the connection's config. | |||||
func (fc *FuzzedConnection) Config() *FuzzConnConfig { | |||||
return fc.config | |||||
} | |||||
// Read implements net.Conn. | |||||
func (fc *FuzzedConnection) Read(data []byte) (n int, err error) { | |||||
if fc.fuzz() { | |||||
return 0, nil | |||||
} | |||||
return fc.conn.Read(data) | |||||
} | |||||
// Write implements net.Conn. | |||||
func (fc *FuzzedConnection) Write(data []byte) (n int, err error) { | |||||
if fc.fuzz() { | |||||
return 0, nil | |||||
} | |||||
return fc.conn.Write(data) | |||||
} | |||||
// Close implements net.Conn. | |||||
func (fc *FuzzedConnection) Close() error { return fc.conn.Close() } | |||||
// LocalAddr implements net.Conn. | |||||
func (fc *FuzzedConnection) LocalAddr() net.Addr { return fc.conn.LocalAddr() } | |||||
// RemoteAddr implements net.Conn. | |||||
func (fc *FuzzedConnection) RemoteAddr() net.Addr { return fc.conn.RemoteAddr() } | |||||
// SetDeadline implements net.Conn. | |||||
func (fc *FuzzedConnection) SetDeadline(t time.Time) error { return fc.conn.SetDeadline(t) } | |||||
// SetReadDeadline implements net.Conn. | |||||
func (fc *FuzzedConnection) SetReadDeadline(t time.Time) error { | |||||
return fc.conn.SetReadDeadline(t) | |||||
} | |||||
// SetWriteDeadline implements net.Conn. | |||||
func (fc *FuzzedConnection) SetWriteDeadline(t time.Time) error { | |||||
return fc.conn.SetWriteDeadline(t) | |||||
} | |||||
func (fc *FuzzedConnection) randomDuration() time.Duration { | |||||
maxDelayMillis := int(fc.config.MaxDelay.Nanoseconds() / 1000) | |||||
return time.Millisecond * time.Duration(rand.Int()%maxDelayMillis) | |||||
} | |||||
// implements the fuzz (delay, kill conn) | |||||
// and returns whether or not the read/write should be ignored | |||||
func (fc *FuzzedConnection) fuzz() bool { | |||||
if !fc.shouldFuzz() { | |||||
return false | |||||
} | |||||
switch fc.config.Mode { | |||||
case FuzzModeDrop: | |||||
// randomly drop the r/w, drop the conn, or sleep | |||||
r := rand.Float64() | |||||
if r <= fc.config.ProbDropRW { | |||||
return true | |||||
} else if r < fc.config.ProbDropRW+fc.config.ProbDropConn { | |||||
// XXX: can't this fail because machine precision? | |||||
// XXX: do we need an error? | |||||
fc.Close() | |||||
return true | |||||
} else if r < fc.config.ProbDropRW+fc.config.ProbDropConn+fc.config.ProbSleep { | |||||
time.Sleep(fc.randomDuration()) | |||||
} | |||||
case FuzzModeDelay: | |||||
// sleep a bit | |||||
time.Sleep(fc.randomDuration()) | |||||
} | |||||
return false | |||||
} | |||||
func (fc *FuzzedConnection) shouldFuzz() bool { | |||||
if fc.active { | |||||
return true | |||||
} | |||||
fc.mtx.Lock() | |||||
defer fc.mtx.Unlock() | |||||
select { | |||||
case <-fc.start: | |||||
fc.active = true | |||||
return true | |||||
default: | |||||
return false | |||||
} | |||||
} |
@ -0,0 +1,29 @@ | |||||
package p2p | |||||
import ( | |||||
"strings" | |||||
) | |||||
// TODO Test | |||||
func AddToIPRangeCounts(counts map[string]int, ip string) map[string]int { | |||||
changes := make(map[string]int) | |||||
ipParts := strings.Split(ip, ":") | |||||
for i := 1; i < len(ipParts); i++ { | |||||
prefix := strings.Join(ipParts[:i], ":") | |||||
counts[prefix] += 1 | |||||
changes[prefix] = counts[prefix] | |||||
} | |||||
return changes | |||||
} | |||||
// TODO Test | |||||
func CheckIPRangeCounts(counts map[string]int, limits []int) bool { | |||||
for prefix, count := range counts { | |||||
ipParts := strings.Split(prefix, ":") | |||||
numParts := len(ipParts) | |||||
if limits[numParts] < count { | |||||
return false | |||||
} | |||||
} | |||||
return true | |||||
} |
@ -0,0 +1,217 @@ | |||||
package p2p | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"strconv" | |||||
"time" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
"github.com/tendermint/tendermint/p2p/upnp" | |||||
) | |||||
type Listener interface { | |||||
Connections() <-chan net.Conn | |||||
InternalAddress() *NetAddress | |||||
ExternalAddress() *NetAddress | |||||
String() string | |||||
Stop() bool | |||||
} | |||||
// Implements Listener | |||||
type DefaultListener struct { | |||||
BaseService | |||||
listener net.Listener | |||||
intAddr *NetAddress | |||||
extAddr *NetAddress | |||||
connections chan net.Conn | |||||
} | |||||
const ( | |||||
numBufferedConnections = 10 | |||||
defaultExternalPort = 8770 | |||||
tryListenSeconds = 5 | |||||
) | |||||
func splitHostPort(addr string) (host string, port int) { | |||||
host, portStr, err := net.SplitHostPort(addr) | |||||
if err != nil { | |||||
PanicSanity(err) | |||||
} | |||||
port, err = strconv.Atoi(portStr) | |||||
if err != nil { | |||||
PanicSanity(err) | |||||
} | |||||
return host, port | |||||
} | |||||
// skipUPNP: If true, does not try getUPNPExternalAddress() | |||||
func NewDefaultListener(protocol string, lAddr string, skipUPNP bool) Listener { | |||||
// Local listen IP & port | |||||
lAddrIP, lAddrPort := splitHostPort(lAddr) | |||||
// Create listener | |||||
var listener net.Listener | |||||
var err error | |||||
for i := 0; i < tryListenSeconds; i++ { | |||||
listener, err = net.Listen(protocol, lAddr) | |||||
if err == nil { | |||||
break | |||||
} else if i < tryListenSeconds-1 { | |||||
time.Sleep(time.Second * 1) | |||||
} | |||||
} | |||||
if err != nil { | |||||
PanicCrisis(err) | |||||
} | |||||
// Actual listener local IP & port | |||||
listenerIP, listenerPort := splitHostPort(listener.Addr().String()) | |||||
log.Info("Local listener", "ip", listenerIP, "port", listenerPort) | |||||
// Determine internal address... | |||||
var intAddr *NetAddress | |||||
intAddr, err = NewNetAddressString(lAddr) | |||||
if err != nil { | |||||
PanicCrisis(err) | |||||
} | |||||
// Determine external address... | |||||
var extAddr *NetAddress | |||||
if !skipUPNP { | |||||
// If the lAddrIP is INADDR_ANY, try UPnP | |||||
if lAddrIP == "" || lAddrIP == "0.0.0.0" { | |||||
extAddr = getUPNPExternalAddress(lAddrPort, listenerPort) | |||||
} | |||||
} | |||||
// Otherwise just use the local address... | |||||
if extAddr == nil { | |||||
extAddr = getNaiveExternalAddress(listenerPort) | |||||
} | |||||
if extAddr == nil { | |||||
PanicCrisis("Could not determine external address!") | |||||
} | |||||
dl := &DefaultListener{ | |||||
listener: listener, | |||||
intAddr: intAddr, | |||||
extAddr: extAddr, | |||||
connections: make(chan net.Conn, numBufferedConnections), | |||||
} | |||||
dl.BaseService = *NewBaseService(log, "DefaultListener", dl) | |||||
dl.Start() // Started upon construction | |||||
return dl | |||||
} | |||||
func (l *DefaultListener) OnStart() error { | |||||
l.BaseService.OnStart() | |||||
go l.listenRoutine() | |||||
return nil | |||||
} | |||||
func (l *DefaultListener) OnStop() { | |||||
l.BaseService.OnStop() | |||||
l.listener.Close() | |||||
} | |||||
// Accept connections and pass on the channel | |||||
func (l *DefaultListener) listenRoutine() { | |||||
for { | |||||
conn, err := l.listener.Accept() | |||||
if !l.IsRunning() { | |||||
break // Go to cleanup | |||||
} | |||||
// listener wasn't stopped, | |||||
// yet we encountered an error. | |||||
if err != nil { | |||||
PanicCrisis(err) | |||||
} | |||||
l.connections <- conn | |||||
} | |||||
// Cleanup | |||||
close(l.connections) | |||||
for _ = range l.connections { | |||||
// Drain | |||||
} | |||||
} | |||||
// A channel of inbound connections. | |||||
// It gets closed when the listener closes. | |||||
func (l *DefaultListener) Connections() <-chan net.Conn { | |||||
return l.connections | |||||
} | |||||
func (l *DefaultListener) InternalAddress() *NetAddress { | |||||
return l.intAddr | |||||
} | |||||
func (l *DefaultListener) ExternalAddress() *NetAddress { | |||||
return l.extAddr | |||||
} | |||||
// NOTE: The returned listener is already Accept()'ing. | |||||
// So it's not suitable to pass into http.Serve(). | |||||
func (l *DefaultListener) NetListener() net.Listener { | |||||
return l.listener | |||||
} | |||||
func (l *DefaultListener) String() string { | |||||
return fmt.Sprintf("Listener(@%v)", l.extAddr) | |||||
} | |||||
/* external address helpers */ | |||||
// UPNP external address discovery & port mapping | |||||
func getUPNPExternalAddress(externalPort, internalPort int) *NetAddress { | |||||
log.Info("Getting UPNP external address") | |||||
nat, err := upnp.Discover() | |||||
if err != nil { | |||||
log.Info("Could not perform UPNP discover", "error", err) | |||||
return nil | |||||
} | |||||
ext, err := nat.GetExternalAddress() | |||||
if err != nil { | |||||
log.Info("Could not get UPNP external address", "error", err) | |||||
return nil | |||||
} | |||||
// UPnP can't seem to get the external port, so let's just be explicit. | |||||
if externalPort == 0 { | |||||
externalPort = defaultExternalPort | |||||
} | |||||
externalPort, err = nat.AddPortMapping("tcp", externalPort, internalPort, "tendermint", 0) | |||||
if err != nil { | |||||
log.Info("Could not add UPNP port mapping", "error", err) | |||||
return nil | |||||
} | |||||
log.Info("Got UPNP external address", "address", ext) | |||||
return NewNetAddressIPPort(ext, uint16(externalPort)) | |||||
} | |||||
// TODO: use syscalls: http://pastebin.com/9exZG4rh | |||||
func getNaiveExternalAddress(port int) *NetAddress { | |||||
addrs, err := net.InterfaceAddrs() | |||||
if err != nil { | |||||
PanicCrisis(Fmt("Could not fetch interface addresses: %v", err)) | |||||
} | |||||
for _, a := range addrs { | |||||
ipnet, ok := a.(*net.IPNet) | |||||
if !ok { | |||||
continue | |||||
} | |||||
v4 := ipnet.IP.To4() | |||||
if v4 == nil || v4[0] == 127 { | |||||
continue | |||||
} // loopback | |||||
return NewNetAddressIPPort(ipnet.IP, uint16(port)) | |||||
} | |||||
return nil | |||||
} |
@ -0,0 +1,40 @@ | |||||
package p2p | |||||
import ( | |||||
"bytes" | |||||
"testing" | |||||
) | |||||
func TestListener(t *testing.T) { | |||||
// Create a listener | |||||
l := NewDefaultListener("tcp", ":8001", true) | |||||
// Dial the listener | |||||
lAddr := l.ExternalAddress() | |||||
connOut, err := lAddr.Dial() | |||||
if err != nil { | |||||
t.Fatalf("Could not connect to listener address %v", lAddr) | |||||
} else { | |||||
t.Logf("Created a connection to listener address %v", lAddr) | |||||
} | |||||
connIn, ok := <-l.Connections() | |||||
if !ok { | |||||
t.Fatalf("Could not get inbound connection from listener") | |||||
} | |||||
msg := []byte("hi!") | |||||
go connIn.Write(msg) | |||||
b := make([]byte, 32) | |||||
n, err := connOut.Read(b) | |||||
if err != nil { | |||||
t.Fatalf("Error reading off connection: %v", err) | |||||
} | |||||
b = b[:n] | |||||
if !bytes.Equal(msg, b) { | |||||
t.Fatalf("Got %s, expected %s", b, msg) | |||||
} | |||||
// Close the server, no longer needed. | |||||
l.Stop() | |||||
} |
@ -0,0 +1,7 @@ | |||||
package p2p | |||||
import ( | |||||
"github.com/tendermint/tmlibs/logger" | |||||
) | |||||
var log = logger.New("module", "p2p") |
@ -0,0 +1,253 @@ | |||||
// Modified for Tendermint | |||||
// Originally Copyright (c) 2013-2014 Conformal Systems LLC. | |||||
// https://github.com/conformal/btcd/blob/master/LICENSE | |||||
package p2p | |||||
import ( | |||||
"errors" | |||||
"flag" | |||||
"net" | |||||
"strconv" | |||||
"time" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
) | |||||
// NetAddress defines information about a peer on the network | |||||
// including its IP address, and port. | |||||
type NetAddress struct { | |||||
IP net.IP | |||||
Port uint16 | |||||
str string | |||||
} | |||||
// NewNetAddress returns a new NetAddress using the provided TCP | |||||
// address. When testing, other net.Addr (except TCP) will result in | |||||
// using 0.0.0.0:0. When normal run, other net.Addr (except TCP) will | |||||
// panic. | |||||
// TODO: socks proxies? | |||||
func NewNetAddress(addr net.Addr) *NetAddress { | |||||
tcpAddr, ok := addr.(*net.TCPAddr) | |||||
if !ok { | |||||
if flag.Lookup("test.v") == nil { // normal run | |||||
cmn.PanicSanity(cmn.Fmt("Only TCPAddrs are supported. Got: %v", addr)) | |||||
} else { // in testing | |||||
return NewNetAddressIPPort(net.IP("0.0.0.0"), 0) | |||||
} | |||||
} | |||||
ip := tcpAddr.IP | |||||
port := uint16(tcpAddr.Port) | |||||
return NewNetAddressIPPort(ip, port) | |||||
} | |||||
// NewNetAddressString returns a new NetAddress using the provided | |||||
// address in the form of "IP:Port". Also resolves the host if host | |||||
// is not an IP. | |||||
func NewNetAddressString(addr string) (*NetAddress, error) { | |||||
host, portStr, err := net.SplitHostPort(addr) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
ip := net.ParseIP(host) | |||||
if ip == nil { | |||||
if len(host) > 0 { | |||||
ips, err := net.LookupIP(host) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
ip = ips[0] | |||||
} | |||||
} | |||||
port, err := strconv.ParseUint(portStr, 10, 16) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
na := NewNetAddressIPPort(ip, uint16(port)) | |||||
return na, nil | |||||
} | |||||
// NewNetAddressStrings returns an array of NetAddress'es build using | |||||
// the provided strings. | |||||
func NewNetAddressStrings(addrs []string) ([]*NetAddress, error) { | |||||
netAddrs := make([]*NetAddress, len(addrs)) | |||||
for i, addr := range addrs { | |||||
netAddr, err := NewNetAddressString(addr) | |||||
if err != nil { | |||||
return nil, errors.New(cmn.Fmt("Error in address %s: %v", addr, err)) | |||||
} | |||||
netAddrs[i] = netAddr | |||||
} | |||||
return netAddrs, nil | |||||
} | |||||
// NewNetAddressIPPort returns a new NetAddress using the provided IP | |||||
// and port number. | |||||
func NewNetAddressIPPort(ip net.IP, port uint16) *NetAddress { | |||||
na := &NetAddress{ | |||||
IP: ip, | |||||
Port: port, | |||||
str: net.JoinHostPort( | |||||
ip.String(), | |||||
strconv.FormatUint(uint64(port), 10), | |||||
), | |||||
} | |||||
return na | |||||
} | |||||
// Equals reports whether na and other are the same addresses. | |||||
func (na *NetAddress) Equals(other interface{}) bool { | |||||
if o, ok := other.(*NetAddress); ok { | |||||
return na.String() == o.String() | |||||
} | |||||
return false | |||||
} | |||||
func (na *NetAddress) Less(other interface{}) bool { | |||||
if o, ok := other.(*NetAddress); ok { | |||||
return na.String() < o.String() | |||||
} | |||||
cmn.PanicSanity("Cannot compare unequal types") | |||||
return false | |||||
} | |||||
// String representation. | |||||
func (na *NetAddress) String() string { | |||||
if na.str == "" { | |||||
na.str = net.JoinHostPort( | |||||
na.IP.String(), | |||||
strconv.FormatUint(uint64(na.Port), 10), | |||||
) | |||||
} | |||||
return na.str | |||||
} | |||||
// Dial calls net.Dial on the address. | |||||
func (na *NetAddress) Dial() (net.Conn, error) { | |||||
conn, err := net.Dial("tcp", na.String()) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return conn, nil | |||||
} | |||||
// DialTimeout calls net.DialTimeout on the address. | |||||
func (na *NetAddress) DialTimeout(timeout time.Duration) (net.Conn, error) { | |||||
conn, err := net.DialTimeout("tcp", na.String(), timeout) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return conn, nil | |||||
} | |||||
// Routable returns true if the address is routable. | |||||
func (na *NetAddress) Routable() bool { | |||||
// TODO(oga) bitcoind doesn't include RFC3849 here, but should we? | |||||
return na.Valid() && !(na.RFC1918() || na.RFC3927() || na.RFC4862() || | |||||
na.RFC4193() || na.RFC4843() || na.Local()) | |||||
} | |||||
// For IPv4 these are either a 0 or all bits set address. For IPv6 a zero | |||||
// address or one that matches the RFC3849 documentation address format. | |||||
func (na *NetAddress) Valid() bool { | |||||
return na.IP != nil && !(na.IP.IsUnspecified() || na.RFC3849() || | |||||
na.IP.Equal(net.IPv4bcast)) | |||||
} | |||||
// Local returns true if it is a local address. | |||||
func (na *NetAddress) Local() bool { | |||||
return na.IP.IsLoopback() || zero4.Contains(na.IP) | |||||
} | |||||
// ReachabilityTo checks whenever o can be reached from na. | |||||
func (na *NetAddress) ReachabilityTo(o *NetAddress) int { | |||||
const ( | |||||
Unreachable = 0 | |||||
Default = iota | |||||
Teredo | |||||
Ipv6_weak | |||||
Ipv4 | |||||
Ipv6_strong | |||||
Private | |||||
) | |||||
if !na.Routable() { | |||||
return Unreachable | |||||
} else if na.RFC4380() { | |||||
if !o.Routable() { | |||||
return Default | |||||
} else if o.RFC4380() { | |||||
return Teredo | |||||
} else if o.IP.To4() != nil { | |||||
return Ipv4 | |||||
} else { // ipv6 | |||||
return Ipv6_weak | |||||
} | |||||
} else if na.IP.To4() != nil { | |||||
if o.Routable() && o.IP.To4() != nil { | |||||
return Ipv4 | |||||
} | |||||
return Default | |||||
} else /* ipv6 */ { | |||||
var tunnelled bool | |||||
// Is our v6 is tunnelled? | |||||
if o.RFC3964() || o.RFC6052() || o.RFC6145() { | |||||
tunnelled = true | |||||
} | |||||
if !o.Routable() { | |||||
return Default | |||||
} else if o.RFC4380() { | |||||
return Teredo | |||||
} else if o.IP.To4() != nil { | |||||
return Ipv4 | |||||
} else if tunnelled { | |||||
// only prioritise ipv6 if we aren't tunnelling it. | |||||
return Ipv6_weak | |||||
} | |||||
return Ipv6_strong | |||||
} | |||||
} | |||||
// RFC1918: IPv4 Private networks (10.0.0.0/8, 192.168.0.0/16, 172.16.0.0/12) | |||||
// RFC3849: IPv6 Documentation address (2001:0DB8::/32) | |||||
// RFC3927: IPv4 Autoconfig (169.254.0.0/16) | |||||
// RFC3964: IPv6 6to4 (2002::/16) | |||||
// RFC4193: IPv6 unique local (FC00::/7) | |||||
// RFC4380: IPv6 Teredo tunneling (2001::/32) | |||||
// RFC4843: IPv6 ORCHID: (2001:10::/28) | |||||
// RFC4862: IPv6 Autoconfig (FE80::/64) | |||||
// RFC6052: IPv6 well known prefix (64:FF9B::/96) | |||||
// RFC6145: IPv6 IPv4 translated address ::FFFF:0:0:0/96 | |||||
var rfc1918_10 = net.IPNet{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)} | |||||
var rfc1918_192 = net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)} | |||||
var rfc1918_172 = net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)} | |||||
var rfc3849 = net.IPNet{IP: net.ParseIP("2001:0DB8::"), Mask: net.CIDRMask(32, 128)} | |||||
var rfc3927 = net.IPNet{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)} | |||||
var rfc3964 = net.IPNet{IP: net.ParseIP("2002::"), Mask: net.CIDRMask(16, 128)} | |||||
var rfc4193 = net.IPNet{IP: net.ParseIP("FC00::"), Mask: net.CIDRMask(7, 128)} | |||||
var rfc4380 = net.IPNet{IP: net.ParseIP("2001::"), Mask: net.CIDRMask(32, 128)} | |||||
var rfc4843 = net.IPNet{IP: net.ParseIP("2001:10::"), Mask: net.CIDRMask(28, 128)} | |||||
var rfc4862 = net.IPNet{IP: net.ParseIP("FE80::"), Mask: net.CIDRMask(64, 128)} | |||||
var rfc6052 = net.IPNet{IP: net.ParseIP("64:FF9B::"), Mask: net.CIDRMask(96, 128)} | |||||
var rfc6145 = net.IPNet{IP: net.ParseIP("::FFFF:0:0:0"), Mask: net.CIDRMask(96, 128)} | |||||
var zero4 = net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(8, 32)} | |||||
func (na *NetAddress) RFC1918() bool { | |||||
return rfc1918_10.Contains(na.IP) || | |||||
rfc1918_192.Contains(na.IP) || | |||||
rfc1918_172.Contains(na.IP) | |||||
} | |||||
func (na *NetAddress) RFC3849() bool { return rfc3849.Contains(na.IP) } | |||||
func (na *NetAddress) RFC3927() bool { return rfc3927.Contains(na.IP) } | |||||
func (na *NetAddress) RFC3964() bool { return rfc3964.Contains(na.IP) } | |||||
func (na *NetAddress) RFC4193() bool { return rfc4193.Contains(na.IP) } | |||||
func (na *NetAddress) RFC4380() bool { return rfc4380.Contains(na.IP) } | |||||
func (na *NetAddress) RFC4843() bool { return rfc4843.Contains(na.IP) } | |||||
func (na *NetAddress) RFC4862() bool { return rfc4862.Contains(na.IP) } | |||||
func (na *NetAddress) RFC6052() bool { return rfc6052.Contains(na.IP) } | |||||
func (na *NetAddress) RFC6145() bool { return rfc6145.Contains(na.IP) } |
@ -0,0 +1,113 @@ | |||||
package p2p | |||||
import ( | |||||
"net" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
) | |||||
func TestNewNetAddress(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
tcpAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") | |||||
require.Nil(err) | |||||
addr := NewNetAddress(tcpAddr) | |||||
assert.Equal("127.0.0.1:8080", addr.String()) | |||||
assert.NotPanics(func() { | |||||
NewNetAddress(&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8000}) | |||||
}, "Calling NewNetAddress with UDPAddr should not panic in testing") | |||||
} | |||||
func TestNewNetAddressString(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
tests := []struct { | |||||
addr string | |||||
correct bool | |||||
}{ | |||||
{"127.0.0.1:8080", true}, | |||||
{"127.0.0:8080", false}, | |||||
{"a", false}, | |||||
{"127.0.0.1:a", false}, | |||||
{"a:8080", false}, | |||||
{"8082", false}, | |||||
{"127.0.0:8080000", false}, | |||||
} | |||||
for _, t := range tests { | |||||
addr, err := NewNetAddressString(t.addr) | |||||
if t.correct { | |||||
require.Nil(err) | |||||
assert.Equal(t.addr, addr.String()) | |||||
} else { | |||||
require.NotNil(err) | |||||
} | |||||
} | |||||
} | |||||
func TestNewNetAddressStrings(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
addrs, err := NewNetAddressStrings([]string{"127.0.0.1:8080", "127.0.0.2:8080"}) | |||||
require.Nil(err) | |||||
assert.Equal(2, len(addrs)) | |||||
} | |||||
func TestNewNetAddressIPPort(t *testing.T) { | |||||
assert := assert.New(t) | |||||
addr := NewNetAddressIPPort(net.ParseIP("127.0.0.1"), 8080) | |||||
assert.Equal("127.0.0.1:8080", addr.String()) | |||||
} | |||||
func TestNetAddressProperties(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
// TODO add more test cases | |||||
tests := []struct { | |||||
addr string | |||||
valid bool | |||||
local bool | |||||
routable bool | |||||
}{ | |||||
{"127.0.0.1:8080", true, true, false}, | |||||
{"ya.ru:80", true, false, true}, | |||||
} | |||||
for _, t := range tests { | |||||
addr, err := NewNetAddressString(t.addr) | |||||
require.Nil(err) | |||||
assert.Equal(t.valid, addr.Valid()) | |||||
assert.Equal(t.local, addr.Local()) | |||||
assert.Equal(t.routable, addr.Routable()) | |||||
} | |||||
} | |||||
func TestNetAddressReachabilityTo(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
// TODO add more test cases | |||||
tests := []struct { | |||||
addr string | |||||
other string | |||||
reachability int | |||||
}{ | |||||
{"127.0.0.1:8080", "127.0.0.1:8081", 0}, | |||||
{"ya.ru:80", "127.0.0.1:8080", 1}, | |||||
} | |||||
for _, t := range tests { | |||||
addr, err := NewNetAddressString(t.addr) | |||||
require.Nil(err) | |||||
other, err := NewNetAddressString(t.other) | |||||
require.Nil(err) | |||||
assert.Equal(t.reachability, addr.ReachabilityTo(other)) | |||||
} | |||||
} |
@ -0,0 +1,304 @@ | |||||
package p2p | |||||
import ( | |||||
"fmt" | |||||
"io" | |||||
"net" | |||||
"time" | |||||
"github.com/pkg/errors" | |||||
crypto "github.com/tendermint/go-crypto" | |||||
wire "github.com/tendermint/go-wire" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
) | |||||
// Peer could be marked as persistent, in which case you can use | |||||
// Redial function to reconnect. Note that inbound peers can't be | |||||
// made persistent. They should be made persistent on the other end. | |||||
// | |||||
// Before using a peer, you will need to perform a handshake on connection. | |||||
type Peer struct { | |||||
cmn.BaseService | |||||
outbound bool | |||||
conn net.Conn // source connection | |||||
mconn *MConnection // multiplex connection | |||||
persistent bool | |||||
config *PeerConfig | |||||
*NodeInfo | |||||
Key string | |||||
Data *cmn.CMap // User data. | |||||
} | |||||
// PeerConfig is a Peer configuration. | |||||
type PeerConfig struct { | |||||
AuthEnc bool // authenticated encryption | |||||
HandshakeTimeout time.Duration | |||||
DialTimeout time.Duration | |||||
MConfig *MConnConfig | |||||
Fuzz bool // fuzz connection (for testing) | |||||
FuzzConfig *FuzzConnConfig | |||||
} | |||||
// DefaultPeerConfig returns the default config. | |||||
func DefaultPeerConfig() *PeerConfig { | |||||
return &PeerConfig{ | |||||
AuthEnc: true, | |||||
HandshakeTimeout: 2 * time.Second, | |||||
DialTimeout: 3 * time.Second, | |||||
MConfig: DefaultMConnConfig(), | |||||
Fuzz: false, | |||||
FuzzConfig: DefaultFuzzConnConfig(), | |||||
} | |||||
} | |||||
func newOutboundPeer(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) { | |||||
return newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig()) | |||||
} | |||||
func newOutboundPeerWithConfig(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { | |||||
conn, err := dial(addr, config) | |||||
if err != nil { | |||||
return nil, errors.Wrap(err, "Error creating peer") | |||||
} | |||||
peer, err := newPeerFromConnAndConfig(conn, true, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config) | |||||
if err != nil { | |||||
conn.Close() | |||||
return nil, err | |||||
} | |||||
return peer, nil | |||||
} | |||||
func newInboundPeer(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519) (*Peer, error) { | |||||
return newInboundPeerWithConfig(conn, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig()) | |||||
} | |||||
func newInboundPeerWithConfig(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { | |||||
return newPeerFromConnAndConfig(conn, false, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config) | |||||
} | |||||
func newPeerFromConnAndConfig(rawConn net.Conn, outbound bool, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*Peer, error) { | |||||
conn := rawConn | |||||
// Fuzz connection | |||||
if config.Fuzz { | |||||
// so we have time to do peer handshakes and get set up | |||||
conn = FuzzConnAfterFromConfig(conn, 10*time.Second, config.FuzzConfig) | |||||
} | |||||
// Encrypt connection | |||||
if config.AuthEnc { | |||||
conn.SetDeadline(time.Now().Add(config.HandshakeTimeout)) | |||||
var err error | |||||
conn, err = MakeSecretConnection(conn, ourNodePrivKey) | |||||
if err != nil { | |||||
return nil, errors.Wrap(err, "Error creating peer") | |||||
} | |||||
} | |||||
// Key and NodeInfo are set after Handshake | |||||
p := &Peer{ | |||||
outbound: outbound, | |||||
conn: conn, | |||||
config: config, | |||||
Data: cmn.NewCMap(), | |||||
} | |||||
p.mconn = createMConnection(conn, p, reactorsByCh, chDescs, onPeerError, config.MConfig) | |||||
p.BaseService = *cmn.NewBaseService(log, "Peer", p) | |||||
return p, nil | |||||
} | |||||
// CloseConn should be used when the peer was created, but never started. | |||||
func (p *Peer) CloseConn() { | |||||
p.conn.Close() | |||||
} | |||||
// makePersistent marks the peer as persistent. | |||||
func (p *Peer) makePersistent() { | |||||
if !p.outbound { | |||||
panic("inbound peers can't be made persistent") | |||||
} | |||||
p.persistent = true | |||||
} | |||||
// IsPersistent returns true if the peer is persitent, false otherwise. | |||||
func (p *Peer) IsPersistent() bool { | |||||
return p.persistent | |||||
} | |||||
// HandshakeTimeout performs a handshake between a given node and the peer. | |||||
// NOTE: blocking | |||||
func (p *Peer) HandshakeTimeout(ourNodeInfo *NodeInfo, timeout time.Duration) error { | |||||
// Set deadline for handshake so we don't block forever on conn.ReadFull | |||||
p.conn.SetDeadline(time.Now().Add(timeout)) | |||||
var peerNodeInfo = new(NodeInfo) | |||||
var err1 error | |||||
var err2 error | |||||
cmn.Parallel( | |||||
func() { | |||||
var n int | |||||
wire.WriteBinary(ourNodeInfo, p.conn, &n, &err1) | |||||
}, | |||||
func() { | |||||
var n int | |||||
wire.ReadBinary(peerNodeInfo, p.conn, maxNodeInfoSize, &n, &err2) | |||||
log.Notice("Peer handshake", "peerNodeInfo", peerNodeInfo) | |||||
}) | |||||
if err1 != nil { | |||||
return errors.Wrap(err1, "Error during handshake/write") | |||||
} | |||||
if err2 != nil { | |||||
return errors.Wrap(err2, "Error during handshake/read") | |||||
} | |||||
if p.config.AuthEnc { | |||||
// Check that the professed PubKey matches the sconn's. | |||||
if !peerNodeInfo.PubKey.Equals(p.PubKey().Wrap()) { | |||||
return fmt.Errorf("Ignoring connection with unmatching pubkey: %v vs %v", | |||||
peerNodeInfo.PubKey, p.PubKey()) | |||||
} | |||||
} | |||||
// Remove deadline | |||||
p.conn.SetDeadline(time.Time{}) | |||||
peerNodeInfo.RemoteAddr = p.Addr().String() | |||||
p.NodeInfo = peerNodeInfo | |||||
p.Key = peerNodeInfo.PubKey.KeyString() | |||||
return nil | |||||
} | |||||
// Addr returns peer's network address. | |||||
func (p *Peer) Addr() net.Addr { | |||||
return p.conn.RemoteAddr() | |||||
} | |||||
// PubKey returns peer's public key. | |||||
func (p *Peer) PubKey() crypto.PubKeyEd25519 { | |||||
if p.config.AuthEnc { | |||||
return p.conn.(*SecretConnection).RemotePubKey() | |||||
} | |||||
if p.NodeInfo == nil { | |||||
panic("Attempt to get peer's PubKey before calling Handshake") | |||||
} | |||||
return p.PubKey() | |||||
} | |||||
// OnStart implements BaseService. | |||||
func (p *Peer) OnStart() error { | |||||
p.BaseService.OnStart() | |||||
_, err := p.mconn.Start() | |||||
return err | |||||
} | |||||
// OnStop implements BaseService. | |||||
func (p *Peer) OnStop() { | |||||
p.BaseService.OnStop() | |||||
p.mconn.Stop() | |||||
} | |||||
// Connection returns underlying MConnection. | |||||
func (p *Peer) Connection() *MConnection { | |||||
return p.mconn | |||||
} | |||||
// IsOutbound returns true if the connection is outbound, false otherwise. | |||||
func (p *Peer) IsOutbound() bool { | |||||
return p.outbound | |||||
} | |||||
// Send msg to the channel identified by chID byte. Returns false if the send | |||||
// queue is full after timeout, specified by MConnection. | |||||
func (p *Peer) Send(chID byte, msg interface{}) bool { | |||||
if !p.IsRunning() { | |||||
// see Switch#Broadcast, where we fetch the list of peers and loop over | |||||
// them - while we're looping, one peer may be removed and stopped. | |||||
return false | |||||
} | |||||
return p.mconn.Send(chID, msg) | |||||
} | |||||
// TrySend msg to the channel identified by chID byte. Immediately returns | |||||
// false if the send queue is full. | |||||
func (p *Peer) TrySend(chID byte, msg interface{}) bool { | |||||
if !p.IsRunning() { | |||||
return false | |||||
} | |||||
return p.mconn.TrySend(chID, msg) | |||||
} | |||||
// CanSend returns true if the send queue is not full, false otherwise. | |||||
func (p *Peer) CanSend(chID byte) bool { | |||||
if !p.IsRunning() { | |||||
return false | |||||
} | |||||
return p.mconn.CanSend(chID) | |||||
} | |||||
// WriteTo writes the peer's public key to w. | |||||
func (p *Peer) WriteTo(w io.Writer) (n int64, err error) { | |||||
var n_ int | |||||
wire.WriteString(p.Key, w, &n_, &err) | |||||
n += int64(n_) | |||||
return | |||||
} | |||||
// String representation. | |||||
func (p *Peer) String() string { | |||||
if p.outbound { | |||||
return fmt.Sprintf("Peer{%v %v out}", p.mconn, p.Key[:12]) | |||||
} | |||||
return fmt.Sprintf("Peer{%v %v in}", p.mconn, p.Key[:12]) | |||||
} | |||||
// Equals reports whenever 2 peers are actually represent the same node. | |||||
func (p *Peer) Equals(other *Peer) bool { | |||||
return p.Key == other.Key | |||||
} | |||||
// Get the data for a given key. | |||||
func (p *Peer) Get(key string) interface{} { | |||||
return p.Data.Get(key) | |||||
} | |||||
func dial(addr *NetAddress, config *PeerConfig) (net.Conn, error) { | |||||
log.Info("Dialing address", "address", addr) | |||||
conn, err := addr.DialTimeout(config.DialTimeout) | |||||
if err != nil { | |||||
log.Info("Failed dialing address", "address", addr, "error", err) | |||||
return nil, err | |||||
} | |||||
return conn, nil | |||||
} | |||||
func createMConnection(conn net.Conn, p *Peer, reactorsByCh map[byte]Reactor, chDescs []*ChannelDescriptor, onPeerError func(*Peer, interface{}), config *MConnConfig) *MConnection { | |||||
onReceive := func(chID byte, msgBytes []byte) { | |||||
reactor := reactorsByCh[chID] | |||||
if reactor == nil { | |||||
cmn.PanicSanity(cmn.Fmt("Unknown channel %X", chID)) | |||||
} | |||||
reactor.Receive(chID, p, msgBytes) | |||||
} | |||||
onError := func(r interface{}) { | |||||
onPeerError(p, r) | |||||
} | |||||
return NewMConnectionWithConfig(conn, chDescs, onReceive, onError, config) | |||||
} |
@ -0,0 +1,115 @@ | |||||
package p2p | |||||
import ( | |||||
"sync" | |||||
) | |||||
// IPeerSet has a (immutable) subset of the methods of PeerSet. | |||||
type IPeerSet interface { | |||||
Has(key string) bool | |||||
Get(key string) *Peer | |||||
List() []*Peer | |||||
Size() int | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
// PeerSet is a special structure for keeping a table of peers. | |||||
// Iteration over the peers is super fast and thread-safe. | |||||
// We also track how many peers per IP range and avoid too many | |||||
type PeerSet struct { | |||||
mtx sync.Mutex | |||||
lookup map[string]*peerSetItem | |||||
list []*Peer | |||||
} | |||||
type peerSetItem struct { | |||||
peer *Peer | |||||
index int | |||||
} | |||||
func NewPeerSet() *PeerSet { | |||||
return &PeerSet{ | |||||
lookup: make(map[string]*peerSetItem), | |||||
list: make([]*Peer, 0, 256), | |||||
} | |||||
} | |||||
// Returns false if peer with key (PubKeyEd25519) is already in set | |||||
// or if we have too many peers from the peer's IP range | |||||
func (ps *PeerSet) Add(peer *Peer) error { | |||||
ps.mtx.Lock() | |||||
defer ps.mtx.Unlock() | |||||
if ps.lookup[peer.Key] != nil { | |||||
return ErrSwitchDuplicatePeer | |||||
} | |||||
index := len(ps.list) | |||||
// Appending is safe even with other goroutines | |||||
// iterating over the ps.list slice. | |||||
ps.list = append(ps.list, peer) | |||||
ps.lookup[peer.Key] = &peerSetItem{peer, index} | |||||
return nil | |||||
} | |||||
func (ps *PeerSet) Has(peerKey string) bool { | |||||
ps.mtx.Lock() | |||||
defer ps.mtx.Unlock() | |||||
_, ok := ps.lookup[peerKey] | |||||
return ok | |||||
} | |||||
func (ps *PeerSet) Get(peerKey string) *Peer { | |||||
ps.mtx.Lock() | |||||
defer ps.mtx.Unlock() | |||||
item, ok := ps.lookup[peerKey] | |||||
if ok { | |||||
return item.peer | |||||
} else { | |||||
return nil | |||||
} | |||||
} | |||||
func (ps *PeerSet) Remove(peer *Peer) { | |||||
ps.mtx.Lock() | |||||
defer ps.mtx.Unlock() | |||||
item := ps.lookup[peer.Key] | |||||
if item == nil { | |||||
return | |||||
} | |||||
index := item.index | |||||
// Copy the list but without the last element. | |||||
// (we must copy because we're mutating the list) | |||||
newList := make([]*Peer, len(ps.list)-1) | |||||
copy(newList, ps.list) | |||||
// If it's the last peer, that's an easy special case. | |||||
if index == len(ps.list)-1 { | |||||
ps.list = newList | |||||
delete(ps.lookup, peer.Key) | |||||
return | |||||
} | |||||
// Move the last item from ps.list to "index" in list. | |||||
lastPeer := ps.list[len(ps.list)-1] | |||||
lastPeerKey := lastPeer.Key | |||||
lastPeerItem := ps.lookup[lastPeerKey] | |||||
newList[index] = lastPeer | |||||
lastPeerItem.index = index | |||||
ps.list = newList | |||||
delete(ps.lookup, peer.Key) | |||||
} | |||||
func (ps *PeerSet) Size() int { | |||||
ps.mtx.Lock() | |||||
defer ps.mtx.Unlock() | |||||
return len(ps.list) | |||||
} | |||||
// threadsafe list of peers. | |||||
func (ps *PeerSet) List() []*Peer { | |||||
ps.mtx.Lock() | |||||
defer ps.mtx.Unlock() | |||||
return ps.list | |||||
} |
@ -0,0 +1,67 @@ | |||||
package p2p | |||||
import ( | |||||
"math/rand" | |||||
"testing" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
) | |||||
// Returns an empty dummy peer | |||||
func randPeer() *Peer { | |||||
return &Peer{ | |||||
Key: RandStr(12), | |||||
NodeInfo: &NodeInfo{ | |||||
RemoteAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256), | |||||
ListenAddr: Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256), | |||||
}, | |||||
} | |||||
} | |||||
func TestAddRemoveOne(t *testing.T) { | |||||
peerSet := NewPeerSet() | |||||
peer := randPeer() | |||||
err := peerSet.Add(peer) | |||||
if err != nil { | |||||
t.Errorf("Failed to add new peer") | |||||
} | |||||
if peerSet.Size() != 1 { | |||||
t.Errorf("Failed to add new peer and increment size") | |||||
} | |||||
peerSet.Remove(peer) | |||||
if peerSet.Has(peer.Key) { | |||||
t.Errorf("Failed to remove peer") | |||||
} | |||||
if peerSet.Size() != 0 { | |||||
t.Errorf("Failed to remove peer and decrement size") | |||||
} | |||||
} | |||||
func TestAddRemoveMany(t *testing.T) { | |||||
peerSet := NewPeerSet() | |||||
peers := []*Peer{} | |||||
N := 100 | |||||
for i := 0; i < N; i++ { | |||||
peer := randPeer() | |||||
if err := peerSet.Add(peer); err != nil { | |||||
t.Errorf("Failed to add new peer") | |||||
} | |||||
if peerSet.Size() != i+1 { | |||||
t.Errorf("Failed to add new peer and increment size") | |||||
} | |||||
peers = append(peers, peer) | |||||
} | |||||
for i, peer := range peers { | |||||
peerSet.Remove(peer) | |||||
if peerSet.Has(peer.Key) { | |||||
t.Errorf("Failed to remove peer") | |||||
} | |||||
if peerSet.Size() != len(peers)-i-1 { | |||||
t.Errorf("Failed to remove peer and decrement size") | |||||
} | |||||
} | |||||
} |
@ -0,0 +1,156 @@ | |||||
package p2p | |||||
import ( | |||||
golog "log" | |||||
"net" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
crypto "github.com/tendermint/go-crypto" | |||||
) | |||||
func TestPeerBasic(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
// simulate remote peer | |||||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} | |||||
rp.Start() | |||||
defer rp.Stop() | |||||
p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), DefaultPeerConfig()) | |||||
require.Nil(err) | |||||
p.Start() | |||||
defer p.Stop() | |||||
assert.True(p.IsRunning()) | |||||
assert.True(p.IsOutbound()) | |||||
assert.False(p.IsPersistent()) | |||||
p.makePersistent() | |||||
assert.True(p.IsPersistent()) | |||||
assert.Equal(rp.Addr().String(), p.Addr().String()) | |||||
assert.Equal(rp.PubKey(), p.PubKey()) | |||||
} | |||||
func TestPeerWithoutAuthEnc(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
config := DefaultPeerConfig() | |||||
config.AuthEnc = false | |||||
// simulate remote peer | |||||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config} | |||||
rp.Start() | |||||
defer rp.Stop() | |||||
p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config) | |||||
require.Nil(err) | |||||
p.Start() | |||||
defer p.Stop() | |||||
assert.True(p.IsRunning()) | |||||
} | |||||
func TestPeerSend(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
config := DefaultPeerConfig() | |||||
config.AuthEnc = false | |||||
// simulate remote peer | |||||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config} | |||||
rp.Start() | |||||
defer rp.Stop() | |||||
p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config) | |||||
require.Nil(err) | |||||
p.Start() | |||||
defer p.Stop() | |||||
assert.True(p.CanSend(0x01)) | |||||
assert.True(p.Send(0x01, "Asylum")) | |||||
} | |||||
func createOutboundPeerAndPerformHandshake(addr *NetAddress, config *PeerConfig) (*Peer, error) { | |||||
chDescs := []*ChannelDescriptor{ | |||||
&ChannelDescriptor{ID: 0x01, Priority: 1}, | |||||
} | |||||
reactorsByCh := map[byte]Reactor{0x01: NewTestReactor(chDescs, true)} | |||||
pk := crypto.GenPrivKeyEd25519() | |||||
p, err := newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, func(p *Peer, r interface{}) {}, pk, config) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
err = p.HandshakeTimeout(&NodeInfo{ | |||||
PubKey: pk.PubKey().Unwrap().(crypto.PubKeyEd25519), | |||||
Moniker: "host_peer", | |||||
Network: "testing", | |||||
Version: "123.123.123", | |||||
}, 1*time.Second) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return p, nil | |||||
} | |||||
type remotePeer struct { | |||||
PrivKey crypto.PrivKeyEd25519 | |||||
Config *PeerConfig | |||||
addr *NetAddress | |||||
quit chan struct{} | |||||
} | |||||
func (p *remotePeer) Addr() *NetAddress { | |||||
return p.addr | |||||
} | |||||
func (p *remotePeer) PubKey() crypto.PubKeyEd25519 { | |||||
return p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||||
} | |||||
func (p *remotePeer) Start() { | |||||
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address | |||||
if e != nil { | |||||
golog.Fatalf("net.Listen tcp :0: %+v", e) | |||||
} | |||||
p.addr = NewNetAddress(l.Addr()) | |||||
p.quit = make(chan struct{}) | |||||
go p.accept(l) | |||||
} | |||||
func (p *remotePeer) Stop() { | |||||
close(p.quit) | |||||
} | |||||
func (p *remotePeer) accept(l net.Listener) { | |||||
for { | |||||
conn, err := l.Accept() | |||||
if err != nil { | |||||
golog.Fatalf("Failed to accept conn: %+v", err) | |||||
} | |||||
peer, err := newInboundPeerWithConfig(conn, make(map[byte]Reactor), make([]*ChannelDescriptor, 0), func(p *Peer, r interface{}) {}, p.PrivKey, p.Config) | |||||
if err != nil { | |||||
golog.Fatalf("Failed to create a peer: %+v", err) | |||||
} | |||||
err = peer.HandshakeTimeout(&NodeInfo{ | |||||
PubKey: p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519), | |||||
Moniker: "remote_peer", | |||||
Network: "testing", | |||||
Version: "123.123.123", | |||||
}, 1*time.Second) | |||||
if err != nil { | |||||
golog.Fatalf("Failed to perform handshake: %+v", err) | |||||
} | |||||
select { | |||||
case <-p.quit: | |||||
conn.Close() | |||||
return | |||||
default: | |||||
} | |||||
} | |||||
} |
@ -0,0 +1,358 @@ | |||||
package p2p | |||||
import ( | |||||
"bytes" | |||||
"fmt" | |||||
"math/rand" | |||||
"reflect" | |||||
"time" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
wire "github.com/tendermint/go-wire" | |||||
) | |||||
const ( | |||||
// PexChannel is a channel for PEX messages | |||||
PexChannel = byte(0x00) | |||||
// period to ensure peers connected | |||||
defaultEnsurePeersPeriod = 30 * time.Second | |||||
minNumOutboundPeers = 10 | |||||
maxPexMessageSize = 1048576 // 1MB | |||||
// maximum messages one peer can send to us during `msgCountByPeerFlushInterval` | |||||
defaultMaxMsgCountByPeer = 1000 | |||||
msgCountByPeerFlushInterval = 1 * time.Hour | |||||
) | |||||
// PEXReactor handles PEX (peer exchange) and ensures that an | |||||
// adequate number of peers are connected to the switch. | |||||
// | |||||
// It uses `AddrBook` (address book) to store `NetAddress`es of the peers. | |||||
// | |||||
// ## Preventing abuse | |||||
// | |||||
// For now, it just limits the number of messages from one peer to | |||||
// `defaultMaxMsgCountByPeer` messages per `msgCountByPeerFlushInterval` (1000 | |||||
// msg/hour). | |||||
// | |||||
// NOTE [2017-01-17]: | |||||
// Limiting is fine for now. Maybe down the road we want to keep track of the | |||||
// quality of peer messages so if peerA keeps telling us about peers we can't | |||||
// connect to then maybe we should care less about peerA. But I don't think | |||||
// that kind of complexity is priority right now. | |||||
type PEXReactor struct { | |||||
BaseReactor | |||||
sw *Switch | |||||
book *AddrBook | |||||
ensurePeersPeriod time.Duration | |||||
// tracks message count by peer, so we can prevent abuse | |||||
msgCountByPeer *cmn.CMap | |||||
maxMsgCountByPeer uint16 | |||||
} | |||||
// NewPEXReactor creates new PEX reactor. | |||||
func NewPEXReactor(b *AddrBook) *PEXReactor { | |||||
r := &PEXReactor{ | |||||
book: b, | |||||
ensurePeersPeriod: defaultEnsurePeersPeriod, | |||||
msgCountByPeer: cmn.NewCMap(), | |||||
maxMsgCountByPeer: defaultMaxMsgCountByPeer, | |||||
} | |||||
r.BaseReactor = *NewBaseReactor(log, "PEXReactor", r) | |||||
return r | |||||
} | |||||
// OnStart implements BaseService | |||||
func (r *PEXReactor) OnStart() error { | |||||
r.BaseReactor.OnStart() | |||||
r.book.Start() | |||||
go r.ensurePeersRoutine() | |||||
go r.flushMsgCountByPeer() | |||||
return nil | |||||
} | |||||
// OnStop implements BaseService | |||||
func (r *PEXReactor) OnStop() { | |||||
r.BaseReactor.OnStop() | |||||
r.book.Stop() | |||||
} | |||||
// GetChannels implements Reactor | |||||
func (r *PEXReactor) GetChannels() []*ChannelDescriptor { | |||||
return []*ChannelDescriptor{ | |||||
&ChannelDescriptor{ | |||||
ID: PexChannel, | |||||
Priority: 1, | |||||
SendQueueCapacity: 10, | |||||
}, | |||||
} | |||||
} | |||||
// AddPeer implements Reactor by adding peer to the address book (if inbound) | |||||
// or by requesting more addresses (if outbound). | |||||
func (r *PEXReactor) AddPeer(p *Peer) { | |||||
if p.IsOutbound() { | |||||
// For outbound peers, the address is already in the books. | |||||
// Either it was added in DialSeeds or when we | |||||
// received the peer's address in r.Receive | |||||
if r.book.NeedMoreAddrs() { | |||||
r.RequestPEX(p) | |||||
} | |||||
} else { // For inbound connections, the peer is its own source | |||||
addr, err := NewNetAddressString(p.ListenAddr) | |||||
if err != nil { | |||||
// this should never happen | |||||
log.Error("Error in AddPeer: invalid peer address", "addr", p.ListenAddr, "error", err) | |||||
return | |||||
} | |||||
r.book.AddAddress(addr, addr) | |||||
} | |||||
} | |||||
// RemovePeer implements Reactor. | |||||
func (r *PEXReactor) RemovePeer(p *Peer, reason interface{}) { | |||||
// If we aren't keeping track of local temp data for each peer here, then we | |||||
// don't have to do anything. | |||||
} | |||||
// Receive implements Reactor by handling incoming PEX messages. | |||||
func (r *PEXReactor) Receive(chID byte, src *Peer, msgBytes []byte) { | |||||
srcAddr := src.Connection().RemoteAddress | |||||
srcAddrStr := srcAddr.String() | |||||
r.IncrementMsgCountForPeer(srcAddrStr) | |||||
if r.ReachedMaxMsgCountForPeer(srcAddrStr) { | |||||
log.Warn("Maximum number of messages reached for peer", "peer", srcAddrStr) | |||||
// TODO remove src from peers? | |||||
return | |||||
} | |||||
_, msg, err := DecodeMessage(msgBytes) | |||||
if err != nil { | |||||
log.Warn("Error decoding message", "error", err) | |||||
return | |||||
} | |||||
log.Notice("Received message", "msg", msg) | |||||
switch msg := msg.(type) { | |||||
case *pexRequestMessage: | |||||
// src requested some peers. | |||||
r.SendAddrs(src, r.book.GetSelection()) | |||||
case *pexAddrsMessage: | |||||
// We received some peer addresses from src. | |||||
// (We don't want to get spammed with bad peers) | |||||
for _, addr := range msg.Addrs { | |||||
if addr != nil { | |||||
r.book.AddAddress(addr, srcAddr) | |||||
} | |||||
} | |||||
default: | |||||
log.Warn(fmt.Sprintf("Unknown message type %v", reflect.TypeOf(msg))) | |||||
} | |||||
} | |||||
// RequestPEX asks peer for more addresses. | |||||
func (r *PEXReactor) RequestPEX(p *Peer) { | |||||
p.Send(PexChannel, struct{ PexMessage }{&pexRequestMessage{}}) | |||||
} | |||||
// SendAddrs sends addrs to the peer. | |||||
func (r *PEXReactor) SendAddrs(p *Peer, addrs []*NetAddress) { | |||||
p.Send(PexChannel, struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) | |||||
} | |||||
// SetEnsurePeersPeriod sets period to ensure peers connected. | |||||
func (r *PEXReactor) SetEnsurePeersPeriod(d time.Duration) { | |||||
r.ensurePeersPeriod = d | |||||
} | |||||
// SetMaxMsgCountByPeer sets maximum messages one peer can send to us during 'msgCountByPeerFlushInterval'. | |||||
func (r *PEXReactor) SetMaxMsgCountByPeer(v uint16) { | |||||
r.maxMsgCountByPeer = v | |||||
} | |||||
// ReachedMaxMsgCountForPeer returns true if we received too many | |||||
// messages from peer with address `addr`. | |||||
// NOTE: assumes the value in the CMap is non-nil | |||||
func (r *PEXReactor) ReachedMaxMsgCountForPeer(addr string) bool { | |||||
return r.msgCountByPeer.Get(addr).(uint16) >= r.maxMsgCountByPeer | |||||
} | |||||
// Increment or initialize the msg count for the peer in the CMap | |||||
func (r *PEXReactor) IncrementMsgCountForPeer(addr string) { | |||||
var count uint16 | |||||
countI := r.msgCountByPeer.Get(addr) | |||||
if countI != nil { | |||||
count = countI.(uint16) | |||||
} | |||||
count++ | |||||
r.msgCountByPeer.Set(addr, count) | |||||
} | |||||
// Ensures that sufficient peers are connected. (continuous) | |||||
func (r *PEXReactor) ensurePeersRoutine() { | |||||
// Randomize when routine starts | |||||
ensurePeersPeriodMs := r.ensurePeersPeriod.Nanoseconds() / 1e6 | |||||
time.Sleep(time.Duration(rand.Int63n(ensurePeersPeriodMs)) * time.Millisecond) | |||||
// fire once immediately. | |||||
r.ensurePeers() | |||||
// fire periodically | |||||
ticker := time.NewTicker(r.ensurePeersPeriod) | |||||
for { | |||||
select { | |||||
case <-ticker.C: | |||||
r.ensurePeers() | |||||
case <-r.Quit: | |||||
ticker.Stop() | |||||
return | |||||
} | |||||
} | |||||
} | |||||
// ensurePeers ensures that sufficient peers are connected. (once) | |||||
// | |||||
// Old bucket / New bucket are arbitrary categories to denote whether an | |||||
// address is vetted or not, and this needs to be determined over time via a | |||||
// heuristic that we haven't perfected yet, or, perhaps is manually edited by | |||||
// the node operator. It should not be used to compute what addresses are | |||||
// already connected or not. | |||||
// | |||||
// TODO Basically, we need to work harder on our good-peer/bad-peer marking. | |||||
// What we're currently doing in terms of marking good/bad peers is just a | |||||
// placeholder. It should not be the case that an address becomes old/vetted | |||||
// upon a single successful connection. | |||||
func (r *PEXReactor) ensurePeers() { | |||||
numOutPeers, _, numDialing := r.Switch.NumPeers() | |||||
numToDial := minNumOutboundPeers - (numOutPeers + numDialing) | |||||
log.Info("Ensure peers", "numOutPeers", numOutPeers, "numDialing", numDialing, "numToDial", numToDial) | |||||
if numToDial <= 0 { | |||||
return | |||||
} | |||||
toDial := make(map[string]*NetAddress) | |||||
// Try to pick numToDial addresses to dial. | |||||
for i := 0; i < numToDial; i++ { | |||||
// The purpose of newBias is to first prioritize old (more vetted) peers | |||||
// when we have few connections, but to allow for new (less vetted) peers | |||||
// if we already have many connections. This algorithm isn't perfect, but | |||||
// it somewhat ensures that we prioritize connecting to more-vetted | |||||
// peers. | |||||
newBias := cmn.MinInt(numOutPeers, 8)*10 + 10 | |||||
var picked *NetAddress | |||||
// Try to fetch a new peer 3 times. | |||||
// This caps the maximum number of tries to 3 * numToDial. | |||||
for j := 0; j < 3; j++ { | |||||
try := r.book.PickAddress(newBias) | |||||
if try == nil { | |||||
break | |||||
} | |||||
_, alreadySelected := toDial[try.IP.String()] | |||||
alreadyDialing := r.Switch.IsDialing(try) | |||||
alreadyConnected := r.Switch.Peers().Has(try.IP.String()) | |||||
if alreadySelected || alreadyDialing || alreadyConnected { | |||||
// log.Info("Cannot dial address", "addr", try, | |||||
// "alreadySelected", alreadySelected, | |||||
// "alreadyDialing", alreadyDialing, | |||||
// "alreadyConnected", alreadyConnected) | |||||
continue | |||||
} else { | |||||
log.Info("Will dial address", "addr", try) | |||||
picked = try | |||||
break | |||||
} | |||||
} | |||||
if picked == nil { | |||||
continue | |||||
} | |||||
toDial[picked.IP.String()] = picked | |||||
} | |||||
// Dial picked addresses | |||||
for _, item := range toDial { | |||||
go func(picked *NetAddress) { | |||||
_, err := r.Switch.DialPeerWithAddress(picked, false) | |||||
if err != nil { | |||||
r.book.MarkAttempt(picked) | |||||
} | |||||
}(item) | |||||
} | |||||
// If we need more addresses, pick a random peer and ask for more. | |||||
if r.book.NeedMoreAddrs() { | |||||
if peers := r.Switch.Peers().List(); len(peers) > 0 { | |||||
i := rand.Int() % len(peers) | |||||
peer := peers[i] | |||||
log.Info("No addresses to dial. Sending pexRequest to random peer", "peer", peer) | |||||
r.RequestPEX(peer) | |||||
} | |||||
} | |||||
} | |||||
func (r *PEXReactor) flushMsgCountByPeer() { | |||||
ticker := time.NewTicker(msgCountByPeerFlushInterval) | |||||
for { | |||||
select { | |||||
case <-ticker.C: | |||||
r.msgCountByPeer.Clear() | |||||
case <-r.Quit: | |||||
ticker.Stop() | |||||
return | |||||
} | |||||
} | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
// Messages | |||||
const ( | |||||
msgTypeRequest = byte(0x01) | |||||
msgTypeAddrs = byte(0x02) | |||||
) | |||||
// PexMessage is a primary type for PEX messages. Underneath, it could contain | |||||
// either pexRequestMessage, or pexAddrsMessage messages. | |||||
type PexMessage interface{} | |||||
var _ = wire.RegisterInterface( | |||||
struct{ PexMessage }{}, | |||||
wire.ConcreteType{&pexRequestMessage{}, msgTypeRequest}, | |||||
wire.ConcreteType{&pexAddrsMessage{}, msgTypeAddrs}, | |||||
) | |||||
// DecodeMessage implements interface registered above. | |||||
func DecodeMessage(bz []byte) (msgType byte, msg PexMessage, err error) { | |||||
msgType = bz[0] | |||||
n := new(int) | |||||
r := bytes.NewReader(bz) | |||||
msg = wire.ReadBinary(struct{ PexMessage }{}, r, maxPexMessageSize, n, &err).(struct{ PexMessage }).PexMessage | |||||
return | |||||
} | |||||
/* | |||||
A pexRequestMessage requests additional peer addresses. | |||||
*/ | |||||
type pexRequestMessage struct { | |||||
} | |||||
func (m *pexRequestMessage) String() string { | |||||
return "[pexRequest]" | |||||
} | |||||
/* | |||||
A message with announced peer addresses. | |||||
*/ | |||||
type pexAddrsMessage struct { | |||||
Addrs []*NetAddress | |||||
} | |||||
func (m *pexAddrsMessage) String() string { | |||||
return fmt.Sprintf("[pexAddrs %v]", m.Addrs) | |||||
} |
@ -0,0 +1,163 @@ | |||||
package p2p | |||||
import ( | |||||
"io/ioutil" | |||||
"math/rand" | |||||
"os" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
wire "github.com/tendermint/go-wire" | |||||
) | |||||
func TestPEXReactorBasic(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
dir, err := ioutil.TempDir("", "pex_reactor") | |||||
require.Nil(err) | |||||
defer os.RemoveAll(dir) | |||||
book := NewAddrBook(dir+"addrbook.json", true) | |||||
r := NewPEXReactor(book) | |||||
assert.NotNil(r) | |||||
assert.NotEmpty(r.GetChannels()) | |||||
} | |||||
func TestPEXReactorAddRemovePeer(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
dir, err := ioutil.TempDir("", "pex_reactor") | |||||
require.Nil(err) | |||||
defer os.RemoveAll(dir) | |||||
book := NewAddrBook(dir+"addrbook.json", true) | |||||
r := NewPEXReactor(book) | |||||
size := book.Size() | |||||
peer := createRandomPeer(false) | |||||
r.AddPeer(peer) | |||||
assert.Equal(size+1, book.Size()) | |||||
r.RemovePeer(peer, "peer not available") | |||||
assert.Equal(size+1, book.Size()) | |||||
outboundPeer := createRandomPeer(true) | |||||
r.AddPeer(outboundPeer) | |||||
assert.Equal(size+1, book.Size(), "outbound peers should not be added to the address book") | |||||
r.RemovePeer(outboundPeer, "peer not available") | |||||
assert.Equal(size+1, book.Size()) | |||||
} | |||||
func TestPEXReactorRunning(t *testing.T) { | |||||
require := require.New(t) | |||||
N := 3 | |||||
switches := make([]*Switch, N) | |||||
dir, err := ioutil.TempDir("", "pex_reactor") | |||||
require.Nil(err) | |||||
defer os.RemoveAll(dir) | |||||
book := NewAddrBook(dir+"addrbook.json", false) | |||||
// create switches | |||||
for i := 0; i < N; i++ { | |||||
switches[i] = makeSwitch(i, "127.0.0.1", "123.123.123", func(i int, sw *Switch) *Switch { | |||||
r := NewPEXReactor(book) | |||||
r.SetEnsurePeersPeriod(250 * time.Millisecond) | |||||
sw.AddReactor("pex", r) | |||||
return sw | |||||
}) | |||||
} | |||||
// fill the address book and add listeners | |||||
for _, s := range switches { | |||||
addr, _ := NewNetAddressString(s.NodeInfo().ListenAddr) | |||||
book.AddAddress(addr, addr) | |||||
s.AddListener(NewDefaultListener("tcp", s.NodeInfo().ListenAddr, true)) | |||||
} | |||||
// start switches | |||||
for _, s := range switches { | |||||
_, err := s.Start() // start switch and reactors | |||||
require.Nil(err) | |||||
} | |||||
time.Sleep(1 * time.Second) | |||||
// check peers are connected after some time | |||||
for _, s := range switches { | |||||
outbound, inbound, _ := s.NumPeers() | |||||
if outbound+inbound == 0 { | |||||
t.Errorf("%v expected to be connected to at least one peer", s.NodeInfo().ListenAddr) | |||||
} | |||||
} | |||||
// stop them | |||||
for _, s := range switches { | |||||
s.Stop() | |||||
} | |||||
} | |||||
func TestPEXReactorReceive(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
dir, err := ioutil.TempDir("", "pex_reactor") | |||||
require.Nil(err) | |||||
defer os.RemoveAll(dir) | |||||
book := NewAddrBook(dir+"addrbook.json", true) | |||||
r := NewPEXReactor(book) | |||||
peer := createRandomPeer(false) | |||||
size := book.Size() | |||||
netAddr, _ := NewNetAddressString(peer.ListenAddr) | |||||
addrs := []*NetAddress{netAddr} | |||||
msg := wire.BinaryBytes(struct{ PexMessage }{&pexAddrsMessage{Addrs: addrs}}) | |||||
r.Receive(PexChannel, peer, msg) | |||||
assert.Equal(size+1, book.Size()) | |||||
msg = wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}}) | |||||
r.Receive(PexChannel, peer, msg) | |||||
} | |||||
func TestPEXReactorAbuseFromPeer(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
dir, err := ioutil.TempDir("", "pex_reactor") | |||||
require.Nil(err) | |||||
defer os.RemoveAll(dir) | |||||
book := NewAddrBook(dir+"addrbook.json", true) | |||||
r := NewPEXReactor(book) | |||||
r.SetMaxMsgCountByPeer(5) | |||||
peer := createRandomPeer(false) | |||||
msg := wire.BinaryBytes(struct{ PexMessage }{&pexRequestMessage{}}) | |||||
for i := 0; i < 10; i++ { | |||||
r.Receive(PexChannel, peer, msg) | |||||
} | |||||
assert.True(r.ReachedMaxMsgCountForPeer(peer.ListenAddr)) | |||||
} | |||||
func createRandomPeer(outbound bool) *Peer { | |||||
addr := cmn.Fmt("%v.%v.%v.%v:46656", rand.Int()%256, rand.Int()%256, rand.Int()%256, rand.Int()%256) | |||||
netAddr, _ := NewNetAddressString(addr) | |||||
return &Peer{ | |||||
Key: cmn.RandStr(12), | |||||
NodeInfo: &NodeInfo{ | |||||
ListenAddr: addr, | |||||
}, | |||||
outbound: outbound, | |||||
mconn: &MConnection{RemoteAddress: netAddr}, | |||||
} | |||||
} |
@ -0,0 +1,346 @@ | |||||
// Uses nacl's secret_box to encrypt a net.Conn. | |||||
// It is (meant to be) an implementation of the STS protocol. | |||||
// Note we do not (yet) assume that a remote peer's pubkey | |||||
// is known ahead of time, and thus we are technically | |||||
// still vulnerable to MITM. (TODO!) | |||||
// See docs/sts-final.pdf for more info | |||||
package p2p | |||||
import ( | |||||
"bytes" | |||||
crand "crypto/rand" | |||||
"crypto/sha256" | |||||
"encoding/binary" | |||||
"errors" | |||||
"io" | |||||
"net" | |||||
"time" | |||||
"golang.org/x/crypto/nacl/box" | |||||
"golang.org/x/crypto/nacl/secretbox" | |||||
"golang.org/x/crypto/ripemd160" | |||||
"github.com/tendermint/go-crypto" | |||||
"github.com/tendermint/go-wire" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
) | |||||
// 2 + 1024 == 1026 total frame size | |||||
const dataLenSize = 2 // uint16 to describe the length, is <= dataMaxSize | |||||
const dataMaxSize = 1024 | |||||
const totalFrameSize = dataMaxSize + dataLenSize | |||||
const sealedFrameSize = totalFrameSize + secretbox.Overhead | |||||
const authSigMsgSize = (32 + 1) + (64 + 1) // fixed size (length prefixed) byte arrays | |||||
// Implements net.Conn | |||||
type SecretConnection struct { | |||||
conn io.ReadWriteCloser | |||||
recvBuffer []byte | |||||
recvNonce *[24]byte | |||||
sendNonce *[24]byte | |||||
remPubKey crypto.PubKeyEd25519 | |||||
shrSecret *[32]byte // shared secret | |||||
} | |||||
// Performs handshake and returns a new authenticated SecretConnection. | |||||
// Returns nil if error in handshake. | |||||
// Caller should call conn.Close() | |||||
// See docs/sts-final.pdf for more information. | |||||
func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey crypto.PrivKeyEd25519) (*SecretConnection, error) { | |||||
locPubKey := locPrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||||
// Generate ephemeral keys for perfect forward secrecy. | |||||
locEphPub, locEphPriv := genEphKeys() | |||||
// Write local ephemeral pubkey and receive one too. | |||||
// NOTE: every 32-byte string is accepted as a Curve25519 public key | |||||
// (see DJB's Curve25519 paper: http://cr.yp.to/ecdh/curve25519-20060209.pdf) | |||||
remEphPub, err := shareEphPubKey(conn, locEphPub) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
// Compute common shared secret. | |||||
shrSecret := computeSharedSecret(remEphPub, locEphPriv) | |||||
// Sort by lexical order. | |||||
loEphPub, hiEphPub := sort32(locEphPub, remEphPub) | |||||
// Generate nonces to use for secretbox. | |||||
recvNonce, sendNonce := genNonces(loEphPub, hiEphPub, locEphPub == loEphPub) | |||||
// Generate common challenge to sign. | |||||
challenge := genChallenge(loEphPub, hiEphPub) | |||||
// Construct SecretConnection. | |||||
sc := &SecretConnection{ | |||||
conn: conn, | |||||
recvBuffer: nil, | |||||
recvNonce: recvNonce, | |||||
sendNonce: sendNonce, | |||||
shrSecret: shrSecret, | |||||
} | |||||
// Sign the challenge bytes for authentication. | |||||
locSignature := signChallenge(challenge, locPrivKey) | |||||
// Share (in secret) each other's pubkey & challenge signature | |||||
authSigMsg, err := shareAuthSignature(sc, locPubKey, locSignature) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
remPubKey, remSignature := authSigMsg.Key, authSigMsg.Sig | |||||
if !remPubKey.VerifyBytes(challenge[:], remSignature) { | |||||
return nil, errors.New("Challenge verification failed") | |||||
} | |||||
// We've authorized. | |||||
sc.remPubKey = remPubKey.Unwrap().(crypto.PubKeyEd25519) | |||||
return sc, nil | |||||
} | |||||
// Returns authenticated remote pubkey | |||||
func (sc *SecretConnection) RemotePubKey() crypto.PubKeyEd25519 { | |||||
return sc.remPubKey | |||||
} | |||||
// Writes encrypted frames of `sealedFrameSize` | |||||
// CONTRACT: data smaller than dataMaxSize is read atomically. | |||||
func (sc *SecretConnection) Write(data []byte) (n int, err error) { | |||||
for 0 < len(data) { | |||||
var frame []byte = make([]byte, totalFrameSize) | |||||
var chunk []byte | |||||
if dataMaxSize < len(data) { | |||||
chunk = data[:dataMaxSize] | |||||
data = data[dataMaxSize:] | |||||
} else { | |||||
chunk = data | |||||
data = nil | |||||
} | |||||
chunkLength := len(chunk) | |||||
binary.BigEndian.PutUint16(frame, uint16(chunkLength)) | |||||
copy(frame[dataLenSize:], chunk) | |||||
// encrypt the frame | |||||
var sealedFrame = make([]byte, sealedFrameSize) | |||||
secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret) | |||||
// fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret) | |||||
incr2Nonce(sc.sendNonce) | |||||
// end encryption | |||||
_, err := sc.conn.Write(sealedFrame) | |||||
if err != nil { | |||||
return n, err | |||||
} else { | |||||
n += len(chunk) | |||||
} | |||||
} | |||||
return | |||||
} | |||||
// CONTRACT: data smaller than dataMaxSize is read atomically. | |||||
func (sc *SecretConnection) Read(data []byte) (n int, err error) { | |||||
if 0 < len(sc.recvBuffer) { | |||||
n_ := copy(data, sc.recvBuffer) | |||||
sc.recvBuffer = sc.recvBuffer[n_:] | |||||
return | |||||
} | |||||
sealedFrame := make([]byte, sealedFrameSize) | |||||
_, err = io.ReadFull(sc.conn, sealedFrame) | |||||
if err != nil { | |||||
return | |||||
} | |||||
// decrypt the frame | |||||
var frame = make([]byte, totalFrameSize) | |||||
// fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret) | |||||
_, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret) | |||||
if !ok { | |||||
return n, errors.New("Failed to decrypt SecretConnection") | |||||
} | |||||
incr2Nonce(sc.recvNonce) | |||||
// end decryption | |||||
var chunkLength = binary.BigEndian.Uint16(frame) // read the first two bytes | |||||
if chunkLength > dataMaxSize { | |||||
return 0, errors.New("chunkLength is greater than dataMaxSize") | |||||
} | |||||
var chunk = frame[dataLenSize : dataLenSize+chunkLength] | |||||
n = copy(data, chunk) | |||||
sc.recvBuffer = chunk[n:] | |||||
return | |||||
} | |||||
// Implements net.Conn | |||||
func (sc *SecretConnection) Close() error { return sc.conn.Close() } | |||||
func (sc *SecretConnection) LocalAddr() net.Addr { return sc.conn.(net.Conn).LocalAddr() } | |||||
func (sc *SecretConnection) RemoteAddr() net.Addr { return sc.conn.(net.Conn).RemoteAddr() } | |||||
func (sc *SecretConnection) SetDeadline(t time.Time) error { return sc.conn.(net.Conn).SetDeadline(t) } | |||||
func (sc *SecretConnection) SetReadDeadline(t time.Time) error { | |||||
return sc.conn.(net.Conn).SetReadDeadline(t) | |||||
} | |||||
func (sc *SecretConnection) SetWriteDeadline(t time.Time) error { | |||||
return sc.conn.(net.Conn).SetWriteDeadline(t) | |||||
} | |||||
func genEphKeys() (ephPub, ephPriv *[32]byte) { | |||||
var err error | |||||
ephPub, ephPriv, err = box.GenerateKey(crand.Reader) | |||||
if err != nil { | |||||
PanicCrisis("Could not generate ephemeral keypairs") | |||||
} | |||||
return | |||||
} | |||||
func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { | |||||
var err1, err2 error | |||||
Parallel( | |||||
func() { | |||||
_, err1 = conn.Write(locEphPub[:]) | |||||
}, | |||||
func() { | |||||
remEphPub = new([32]byte) | |||||
_, err2 = io.ReadFull(conn, remEphPub[:]) | |||||
}, | |||||
) | |||||
if err1 != nil { | |||||
return nil, err1 | |||||
} | |||||
if err2 != nil { | |||||
return nil, err2 | |||||
} | |||||
return remEphPub, nil | |||||
} | |||||
func computeSharedSecret(remPubKey, locPrivKey *[32]byte) (shrSecret *[32]byte) { | |||||
shrSecret = new([32]byte) | |||||
box.Precompute(shrSecret, remPubKey, locPrivKey) | |||||
return | |||||
} | |||||
func sort32(foo, bar *[32]byte) (lo, hi *[32]byte) { | |||||
if bytes.Compare(foo[:], bar[:]) < 0 { | |||||
lo = foo | |||||
hi = bar | |||||
} else { | |||||
lo = bar | |||||
hi = foo | |||||
} | |||||
return | |||||
} | |||||
func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce *[24]byte) { | |||||
nonce1 := hash24(append(loPubKey[:], hiPubKey[:]...)) | |||||
nonce2 := new([24]byte) | |||||
copy(nonce2[:], nonce1[:]) | |||||
nonce2[len(nonce2)-1] ^= 0x01 | |||||
if locIsLo { | |||||
recvNonce = nonce1 | |||||
sendNonce = nonce2 | |||||
} else { | |||||
recvNonce = nonce2 | |||||
sendNonce = nonce1 | |||||
} | |||||
return | |||||
} | |||||
func genChallenge(loPubKey, hiPubKey *[32]byte) (challenge *[32]byte) { | |||||
return hash32(append(loPubKey[:], hiPubKey[:]...)) | |||||
} | |||||
func signChallenge(challenge *[32]byte, locPrivKey crypto.PrivKeyEd25519) (signature crypto.SignatureEd25519) { | |||||
signature = locPrivKey.Sign(challenge[:]).Unwrap().(crypto.SignatureEd25519) | |||||
return | |||||
} | |||||
type authSigMessage struct { | |||||
Key crypto.PubKey | |||||
Sig crypto.Signature | |||||
} | |||||
func shareAuthSignature(sc *SecretConnection, pubKey crypto.PubKeyEd25519, signature crypto.SignatureEd25519) (*authSigMessage, error) { | |||||
var recvMsg authSigMessage | |||||
var err1, err2 error | |||||
Parallel( | |||||
func() { | |||||
msgBytes := wire.BinaryBytes(authSigMessage{pubKey.Wrap(), signature.Wrap()}) | |||||
_, err1 = sc.Write(msgBytes) | |||||
}, | |||||
func() { | |||||
readBuffer := make([]byte, authSigMsgSize) | |||||
_, err2 = io.ReadFull(sc, readBuffer) | |||||
if err2 != nil { | |||||
return | |||||
} | |||||
n := int(0) // not used. | |||||
recvMsg = wire.ReadBinary(authSigMessage{}, bytes.NewBuffer(readBuffer), authSigMsgSize, &n, &err2).(authSigMessage) | |||||
}) | |||||
if err1 != nil { | |||||
return nil, err1 | |||||
} | |||||
if err2 != nil { | |||||
return nil, err2 | |||||
} | |||||
return &recvMsg, nil | |||||
} | |||||
func verifyChallengeSignature(challenge *[32]byte, remPubKey crypto.PubKeyEd25519, remSignature crypto.SignatureEd25519) bool { | |||||
return remPubKey.VerifyBytes(challenge[:], remSignature.Wrap()) | |||||
} | |||||
//-------------------------------------------------------------------------------- | |||||
// sha256 | |||||
func hash32(input []byte) (res *[32]byte) { | |||||
hasher := sha256.New() | |||||
hasher.Write(input) // does not error | |||||
resSlice := hasher.Sum(nil) | |||||
res = new([32]byte) | |||||
copy(res[:], resSlice) | |||||
return | |||||
} | |||||
// We only fill in the first 20 bytes with ripemd160 | |||||
func hash24(input []byte) (res *[24]byte) { | |||||
hasher := ripemd160.New() | |||||
hasher.Write(input) // does not error | |||||
resSlice := hasher.Sum(nil) | |||||
res = new([24]byte) | |||||
copy(res[:], resSlice) | |||||
return | |||||
} | |||||
// ripemd160 | |||||
func hash20(input []byte) (res *[20]byte) { | |||||
hasher := ripemd160.New() | |||||
hasher.Write(input) // does not error | |||||
resSlice := hasher.Sum(nil) | |||||
res = new([20]byte) | |||||
copy(res[:], resSlice) | |||||
return | |||||
} | |||||
// increment nonce big-endian by 2 with wraparound. | |||||
func incr2Nonce(nonce *[24]byte) { | |||||
incrNonce(nonce) | |||||
incrNonce(nonce) | |||||
} | |||||
// increment nonce big-endian by 1 with wraparound. | |||||
func incrNonce(nonce *[24]byte) { | |||||
for i := 23; 0 <= i; i-- { | |||||
nonce[i] += 1 | |||||
if nonce[i] != 0 { | |||||
return | |||||
} | |||||
} | |||||
} |
@ -0,0 +1,202 @@ | |||||
package p2p | |||||
import ( | |||||
"bytes" | |||||
"io" | |||||
"testing" | |||||
"github.com/tendermint/go-crypto" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
) | |||||
type dummyConn struct { | |||||
*io.PipeReader | |||||
*io.PipeWriter | |||||
} | |||||
func (drw dummyConn) Close() (err error) { | |||||
err2 := drw.PipeWriter.CloseWithError(io.EOF) | |||||
err1 := drw.PipeReader.Close() | |||||
if err2 != nil { | |||||
return err | |||||
} | |||||
return err1 | |||||
} | |||||
// Each returned ReadWriteCloser is akin to a net.Connection | |||||
func makeDummyConnPair() (fooConn, barConn dummyConn) { | |||||
barReader, fooWriter := io.Pipe() | |||||
fooReader, barWriter := io.Pipe() | |||||
return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter} | |||||
} | |||||
func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { | |||||
fooConn, barConn := makeDummyConnPair() | |||||
fooPrvKey := crypto.GenPrivKeyEd25519() | |||||
fooPubKey := fooPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||||
barPrvKey := crypto.GenPrivKeyEd25519() | |||||
barPubKey := barPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||||
Parallel( | |||||
func() { | |||||
var err error | |||||
fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) | |||||
if err != nil { | |||||
tb.Errorf("Failed to establish SecretConnection for foo: %v", err) | |||||
return | |||||
} | |||||
remotePubBytes := fooSecConn.RemotePubKey() | |||||
if !bytes.Equal(remotePubBytes[:], barPubKey[:]) { | |||||
tb.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v", | |||||
barPubKey, fooSecConn.RemotePubKey()) | |||||
} | |||||
}, | |||||
func() { | |||||
var err error | |||||
barSecConn, err = MakeSecretConnection(barConn, barPrvKey) | |||||
if barSecConn == nil { | |||||
tb.Errorf("Failed to establish SecretConnection for bar: %v", err) | |||||
return | |||||
} | |||||
remotePubBytes := barSecConn.RemotePubKey() | |||||
if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) { | |||||
tb.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v", | |||||
fooPubKey, barSecConn.RemotePubKey()) | |||||
} | |||||
}) | |||||
return | |||||
} | |||||
func TestSecretConnectionHandshake(t *testing.T) { | |||||
fooSecConn, barSecConn := makeSecretConnPair(t) | |||||
fooSecConn.Close() | |||||
barSecConn.Close() | |||||
} | |||||
func TestSecretConnectionReadWrite(t *testing.T) { | |||||
fooConn, barConn := makeDummyConnPair() | |||||
fooWrites, barWrites := []string{}, []string{} | |||||
fooReads, barReads := []string{}, []string{} | |||||
// Pre-generate the things to write (for foo & bar) | |||||
for i := 0; i < 100; i++ { | |||||
fooWrites = append(fooWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) | |||||
barWrites = append(barWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) | |||||
} | |||||
// A helper that will run with (fooConn, fooWrites, fooReads) and vice versa | |||||
genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() { | |||||
return func() { | |||||
// Node handskae | |||||
nodePrvKey := crypto.GenPrivKeyEd25519() | |||||
nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) | |||||
if err != nil { | |||||
t.Errorf("Failed to establish SecretConnection for node: %v", err) | |||||
return | |||||
} | |||||
// In parallel, handle reads and writes | |||||
Parallel( | |||||
func() { | |||||
// Node writes | |||||
for _, nodeWrite := range nodeWrites { | |||||
n, err := nodeSecretConn.Write([]byte(nodeWrite)) | |||||
if err != nil { | |||||
t.Errorf("Failed to write to nodeSecretConn: %v", err) | |||||
return | |||||
} | |||||
if n != len(nodeWrite) { | |||||
t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) | |||||
return | |||||
} | |||||
} | |||||
nodeConn.PipeWriter.Close() | |||||
}, | |||||
func() { | |||||
// Node reads | |||||
readBuffer := make([]byte, dataMaxSize) | |||||
for { | |||||
n, err := nodeSecretConn.Read(readBuffer) | |||||
if err == io.EOF { | |||||
return | |||||
} else if err != nil { | |||||
t.Errorf("Failed to read from nodeSecretConn: %v", err) | |||||
return | |||||
} | |||||
*nodeReads = append(*nodeReads, string(readBuffer[:n])) | |||||
} | |||||
nodeConn.PipeReader.Close() | |||||
}) | |||||
} | |||||
} | |||||
// Run foo & bar in parallel | |||||
Parallel( | |||||
genNodeRunner(fooConn, fooWrites, &fooReads), | |||||
genNodeRunner(barConn, barWrites, &barReads), | |||||
) | |||||
// A helper to ensure that the writes and reads match. | |||||
// Additionally, small writes (<= dataMaxSize) must be atomically read. | |||||
compareWritesReads := func(writes []string, reads []string) { | |||||
for { | |||||
// Pop next write & corresponding reads | |||||
var read, write string = "", writes[0] | |||||
var readCount = 0 | |||||
for _, readChunk := range reads { | |||||
read += readChunk | |||||
readCount += 1 | |||||
if len(write) <= len(read) { | |||||
break | |||||
} | |||||
if len(write) <= dataMaxSize { | |||||
break // atomicity of small writes | |||||
} | |||||
} | |||||
// Compare | |||||
if write != read { | |||||
t.Errorf("Expected to read %X, got %X", write, read) | |||||
} | |||||
// Iterate | |||||
writes = writes[1:] | |||||
reads = reads[readCount:] | |||||
if len(writes) == 0 { | |||||
break | |||||
} | |||||
} | |||||
} | |||||
compareWritesReads(fooWrites, barReads) | |||||
compareWritesReads(barWrites, fooReads) | |||||
} | |||||
func BenchmarkSecretConnection(b *testing.B) { | |||||
b.StopTimer() | |||||
fooSecConn, barSecConn := makeSecretConnPair(b) | |||||
fooWriteText := RandStr(dataMaxSize) | |||||
// Consume reads from bar's reader | |||||
go func() { | |||||
readBuffer := make([]byte, dataMaxSize) | |||||
for { | |||||
_, err := barSecConn.Read(readBuffer) | |||||
if err == io.EOF { | |||||
return | |||||
} else if err != nil { | |||||
b.Fatalf("Failed to read from barSecConn: %v", err) | |||||
} | |||||
} | |||||
}() | |||||
b.StartTimer() | |||||
for i := 0; i < b.N; i++ { | |||||
_, err := fooSecConn.Write([]byte(fooWriteText)) | |||||
if err != nil { | |||||
b.Fatalf("Failed to write to fooSecConn: %v", err) | |||||
} | |||||
} | |||||
b.StopTimer() | |||||
fooSecConn.Close() | |||||
//barSecConn.Close() race condition | |||||
} |
@ -0,0 +1,593 @@ | |||||
package p2p | |||||
import ( | |||||
"errors" | |||||
"fmt" | |||||
"math/rand" | |||||
"net" | |||||
"time" | |||||
cfg "github.com/tendermint/go-config" | |||||
crypto "github.com/tendermint/go-crypto" | |||||
"github.com/tendermint/log15" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
) | |||||
const ( | |||||
reconnectAttempts = 30 | |||||
reconnectInterval = 3 * time.Second | |||||
) | |||||
type Reactor interface { | |||||
Service // Start, Stop | |||||
SetSwitch(*Switch) | |||||
GetChannels() []*ChannelDescriptor | |||||
AddPeer(peer *Peer) | |||||
RemovePeer(peer *Peer, reason interface{}) | |||||
Receive(chID byte, peer *Peer, msgBytes []byte) | |||||
} | |||||
//-------------------------------------- | |||||
type BaseReactor struct { | |||||
BaseService // Provides Start, Stop, .Quit | |||||
Switch *Switch | |||||
} | |||||
func NewBaseReactor(log log15.Logger, name string, impl Reactor) *BaseReactor { | |||||
return &BaseReactor{ | |||||
BaseService: *NewBaseService(log, name, impl), | |||||
Switch: nil, | |||||
} | |||||
} | |||||
func (br *BaseReactor) SetSwitch(sw *Switch) { | |||||
br.Switch = sw | |||||
} | |||||
func (_ *BaseReactor) GetChannels() []*ChannelDescriptor { return nil } | |||||
func (_ *BaseReactor) AddPeer(peer *Peer) {} | |||||
func (_ *BaseReactor) RemovePeer(peer *Peer, reason interface{}) {} | |||||
func (_ *BaseReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {} | |||||
//----------------------------------------------------------------------------- | |||||
/* | |||||
The `Switch` handles peer connections and exposes an API to receive incoming messages | |||||
on `Reactors`. Each `Reactor` is responsible for handling incoming messages of one | |||||
or more `Channels`. So while sending outgoing messages is typically performed on the peer, | |||||
incoming messages are received on the reactor. | |||||
*/ | |||||
type Switch struct { | |||||
BaseService | |||||
config cfg.Config | |||||
listeners []Listener | |||||
reactors map[string]Reactor | |||||
chDescs []*ChannelDescriptor | |||||
reactorsByCh map[byte]Reactor | |||||
peers *PeerSet | |||||
dialing *CMap | |||||
nodeInfo *NodeInfo // our node info | |||||
nodePrivKey crypto.PrivKeyEd25519 // our node privkey | |||||
filterConnByAddr func(net.Addr) error | |||||
filterConnByPubKey func(crypto.PubKeyEd25519) error | |||||
} | |||||
var ( | |||||
ErrSwitchDuplicatePeer = errors.New("Duplicate peer") | |||||
ErrSwitchMaxPeersPerIPRange = errors.New("IP range has too many peers") | |||||
) | |||||
func NewSwitch(config cfg.Config) *Switch { | |||||
setConfigDefaults(config) | |||||
sw := &Switch{ | |||||
config: config, | |||||
reactors: make(map[string]Reactor), | |||||
chDescs: make([]*ChannelDescriptor, 0), | |||||
reactorsByCh: make(map[byte]Reactor), | |||||
peers: NewPeerSet(), | |||||
dialing: NewCMap(), | |||||
nodeInfo: nil, | |||||
} | |||||
sw.BaseService = *NewBaseService(log, "P2P Switch", sw) | |||||
return sw | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor { | |||||
// Validate the reactor. | |||||
// No two reactors can share the same channel. | |||||
reactorChannels := reactor.GetChannels() | |||||
for _, chDesc := range reactorChannels { | |||||
chID := chDesc.ID | |||||
if sw.reactorsByCh[chID] != nil { | |||||
PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor)) | |||||
} | |||||
sw.chDescs = append(sw.chDescs, chDesc) | |||||
sw.reactorsByCh[chID] = reactor | |||||
} | |||||
sw.reactors[name] = reactor | |||||
reactor.SetSwitch(sw) | |||||
return reactor | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) Reactors() map[string]Reactor { | |||||
return sw.reactors | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) Reactor(name string) Reactor { | |||||
return sw.reactors[name] | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) AddListener(l Listener) { | |||||
sw.listeners = append(sw.listeners, l) | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) Listeners() []Listener { | |||||
return sw.listeners | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) IsListening() bool { | |||||
return len(sw.listeners) > 0 | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) { | |||||
sw.nodeInfo = nodeInfo | |||||
} | |||||
// Not goroutine safe. | |||||
func (sw *Switch) NodeInfo() *NodeInfo { | |||||
return sw.nodeInfo | |||||
} | |||||
// Not goroutine safe. | |||||
// NOTE: Overwrites sw.nodeInfo.PubKey | |||||
func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) { | |||||
sw.nodePrivKey = nodePrivKey | |||||
if sw.nodeInfo != nil { | |||||
sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519) | |||||
} | |||||
} | |||||
// Switch.Start() starts all the reactors, peers, and listeners. | |||||
func (sw *Switch) OnStart() error { | |||||
sw.BaseService.OnStart() | |||||
// Start reactors | |||||
for _, reactor := range sw.reactors { | |||||
_, err := reactor.Start() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
} | |||||
// Start peers | |||||
for _, peer := range sw.peers.List() { | |||||
sw.startInitPeer(peer) | |||||
} | |||||
// Start listeners | |||||
for _, listener := range sw.listeners { | |||||
go sw.listenerRoutine(listener) | |||||
} | |||||
return nil | |||||
} | |||||
func (sw *Switch) OnStop() { | |||||
sw.BaseService.OnStop() | |||||
// Stop listeners | |||||
for _, listener := range sw.listeners { | |||||
listener.Stop() | |||||
} | |||||
sw.listeners = nil | |||||
// Stop peers | |||||
for _, peer := range sw.peers.List() { | |||||
peer.Stop() | |||||
sw.peers.Remove(peer) | |||||
} | |||||
// Stop reactors | |||||
for _, reactor := range sw.reactors { | |||||
reactor.Stop() | |||||
} | |||||
} | |||||
// NOTE: This performs a blocking handshake before the peer is added. | |||||
// CONTRACT: If error is returned, peer is nil, and conn is immediately closed. | |||||
func (sw *Switch) AddPeer(peer *Peer) error { | |||||
if err := sw.FilterConnByAddr(peer.Addr()); err != nil { | |||||
return err | |||||
} | |||||
if err := sw.FilterConnByPubKey(peer.PubKey()); err != nil { | |||||
return err | |||||
} | |||||
if err := peer.HandshakeTimeout(sw.nodeInfo, time.Duration(sw.config.GetInt(configKeyHandshakeTimeoutSeconds))*time.Second); err != nil { | |||||
return err | |||||
} | |||||
// Avoid self | |||||
if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) { | |||||
return errors.New("Ignoring connection from self") | |||||
} | |||||
// Check version, chain id | |||||
if err := sw.nodeInfo.CompatibleWith(peer.NodeInfo); err != nil { | |||||
return err | |||||
} | |||||
// Add the peer to .peers | |||||
// ignore if duplicate or if we already have too many for that IP range | |||||
if err := sw.peers.Add(peer); err != nil { | |||||
log.Notice("Ignoring peer", "error", err, "peer", peer) | |||||
peer.Stop() | |||||
return err | |||||
} | |||||
// Start peer | |||||
if sw.IsRunning() { | |||||
sw.startInitPeer(peer) | |||||
} | |||||
log.Notice("Added peer", "peer", peer) | |||||
return nil | |||||
} | |||||
func (sw *Switch) FilterConnByAddr(addr net.Addr) error { | |||||
if sw.filterConnByAddr != nil { | |||||
return sw.filterConnByAddr(addr) | |||||
} | |||||
return nil | |||||
} | |||||
func (sw *Switch) FilterConnByPubKey(pubkey crypto.PubKeyEd25519) error { | |||||
if sw.filterConnByPubKey != nil { | |||||
return sw.filterConnByPubKey(pubkey) | |||||
} | |||||
return nil | |||||
} | |||||
func (sw *Switch) SetAddrFilter(f func(net.Addr) error) { | |||||
sw.filterConnByAddr = f | |||||
} | |||||
func (sw *Switch) SetPubKeyFilter(f func(crypto.PubKeyEd25519) error) { | |||||
sw.filterConnByPubKey = f | |||||
} | |||||
func (sw *Switch) startInitPeer(peer *Peer) { | |||||
peer.Start() // spawn send/recv routines | |||||
for _, reactor := range sw.reactors { | |||||
reactor.AddPeer(peer) | |||||
} | |||||
} | |||||
// Dial a list of seeds asynchronously in random order | |||||
func (sw *Switch) DialSeeds(addrBook *AddrBook, seeds []string) error { | |||||
netAddrs, err := NewNetAddressStrings(seeds) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
if addrBook != nil { | |||||
// add seeds to `addrBook` | |||||
ourAddrS := sw.nodeInfo.ListenAddr | |||||
ourAddr, _ := NewNetAddressString(ourAddrS) | |||||
for _, netAddr := range netAddrs { | |||||
// do not add ourselves | |||||
if netAddr.Equals(ourAddr) { | |||||
continue | |||||
} | |||||
addrBook.AddAddress(netAddr, ourAddr) | |||||
} | |||||
addrBook.Save() | |||||
} | |||||
// permute the list, dial them in random order. | |||||
perm := rand.Perm(len(netAddrs)) | |||||
for i := 0; i < len(perm); i++ { | |||||
go func(i int) { | |||||
time.Sleep(time.Duration(rand.Int63n(3000)) * time.Millisecond) | |||||
j := perm[i] | |||||
sw.dialSeed(netAddrs[j]) | |||||
}(i) | |||||
} | |||||
return nil | |||||
} | |||||
func (sw *Switch) dialSeed(addr *NetAddress) { | |||||
peer, err := sw.DialPeerWithAddress(addr, true) | |||||
if err != nil { | |||||
log.Error("Error dialing seed", "error", err) | |||||
return | |||||
} else { | |||||
log.Notice("Connected to seed", "peer", peer) | |||||
} | |||||
} | |||||
func (sw *Switch) DialPeerWithAddress(addr *NetAddress, persistent bool) (*Peer, error) { | |||||
sw.dialing.Set(addr.IP.String(), addr) | |||||
defer sw.dialing.Delete(addr.IP.String()) | |||||
peer, err := newOutboundPeerWithConfig(addr, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, peerConfigFromGoConfig(sw.config)) | |||||
if err != nil { | |||||
log.Info("Failed dialing peer", "address", addr, "error", err) | |||||
return nil, err | |||||
} | |||||
if persistent { | |||||
peer.makePersistent() | |||||
} | |||||
err = sw.AddPeer(peer) | |||||
if err != nil { | |||||
log.Info("Failed adding peer", "address", addr, "error", err) | |||||
peer.CloseConn() | |||||
return nil, err | |||||
} | |||||
log.Notice("Dialed and added peer", "address", addr, "peer", peer) | |||||
return peer, nil | |||||
} | |||||
func (sw *Switch) IsDialing(addr *NetAddress) bool { | |||||
return sw.dialing.Has(addr.IP.String()) | |||||
} | |||||
// Broadcast runs a go routine for each attempted send, which will block | |||||
// trying to send for defaultSendTimeoutSeconds. Returns a channel | |||||
// which receives success values for each attempted send (false if times out) | |||||
// NOTE: Broadcast uses goroutines, so order of broadcast may not be preserved. | |||||
func (sw *Switch) Broadcast(chID byte, msg interface{}) chan bool { | |||||
successChan := make(chan bool, len(sw.peers.List())) | |||||
log.Debug("Broadcast", "channel", chID, "msg", msg) | |||||
for _, peer := range sw.peers.List() { | |||||
go func(peer *Peer) { | |||||
success := peer.Send(chID, msg) | |||||
successChan <- success | |||||
}(peer) | |||||
} | |||||
return successChan | |||||
} | |||||
// Returns the count of outbound/inbound and outbound-dialing peers. | |||||
func (sw *Switch) NumPeers() (outbound, inbound, dialing int) { | |||||
peers := sw.peers.List() | |||||
for _, peer := range peers { | |||||
if peer.outbound { | |||||
outbound++ | |||||
} else { | |||||
inbound++ | |||||
} | |||||
} | |||||
dialing = sw.dialing.Size() | |||||
return | |||||
} | |||||
func (sw *Switch) Peers() IPeerSet { | |||||
return sw.peers | |||||
} | |||||
// Disconnect from a peer due to external error, retry if it is a persistent peer. | |||||
// TODO: make record depending on reason. | |||||
func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) { | |||||
addr := NewNetAddress(peer.Addr()) | |||||
log.Notice("Stopping peer for error", "peer", peer, "error", reason) | |||||
sw.stopAndRemovePeer(peer, reason) | |||||
if peer.IsPersistent() { | |||||
go func() { | |||||
log.Notice("Reconnecting to peer", "peer", peer) | |||||
for i := 1; i < reconnectAttempts; i++ { | |||||
if !sw.IsRunning() { | |||||
return | |||||
} | |||||
peer, err := sw.DialPeerWithAddress(addr, true) | |||||
if err != nil { | |||||
if i == reconnectAttempts { | |||||
log.Notice("Error reconnecting to peer. Giving up", "tries", i, "error", err) | |||||
return | |||||
} | |||||
log.Notice("Error reconnecting to peer. Trying again", "tries", i, "error", err) | |||||
time.Sleep(reconnectInterval) | |||||
continue | |||||
} | |||||
log.Notice("Reconnected to peer", "peer", peer) | |||||
return | |||||
} | |||||
}() | |||||
} | |||||
} | |||||
// Disconnect from a peer gracefully. | |||||
// TODO: handle graceful disconnects. | |||||
func (sw *Switch) StopPeerGracefully(peer *Peer) { | |||||
log.Notice("Stopping peer gracefully") | |||||
sw.stopAndRemovePeer(peer, nil) | |||||
} | |||||
func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) { | |||||
sw.peers.Remove(peer) | |||||
peer.Stop() | |||||
for _, reactor := range sw.reactors { | |||||
reactor.RemovePeer(peer, reason) | |||||
} | |||||
} | |||||
func (sw *Switch) listenerRoutine(l Listener) { | |||||
for { | |||||
inConn, ok := <-l.Connections() | |||||
if !ok { | |||||
break | |||||
} | |||||
// ignore connection if we already have enough | |||||
maxPeers := sw.config.GetInt(configKeyMaxNumPeers) | |||||
if maxPeers <= sw.peers.Size() { | |||||
log.Info("Ignoring inbound connection: already have enough peers", "address", inConn.RemoteAddr().String(), "numPeers", sw.peers.Size(), "max", maxPeers) | |||||
continue | |||||
} | |||||
// New inbound connection! | |||||
err := sw.addPeerWithConnectionAndConfig(inConn, peerConfigFromGoConfig(sw.config)) | |||||
if err != nil { | |||||
log.Notice("Ignoring inbound connection: error while adding peer", "address", inConn.RemoteAddr().String(), "error", err) | |||||
continue | |||||
} | |||||
// NOTE: We don't yet have the listening port of the | |||||
// remote (if they have a listener at all). | |||||
// The peerHandshake will handle that | |||||
} | |||||
// cleanup | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
type SwitchEventNewPeer struct { | |||||
Peer *Peer | |||||
} | |||||
type SwitchEventDonePeer struct { | |||||
Peer *Peer | |||||
Error interface{} | |||||
} | |||||
//------------------------------------------------------------------ | |||||
// Switches connected via arbitrary net.Conn; useful for testing | |||||
// Returns n switches, connected according to the connect func. | |||||
// If connect==Connect2Switches, the switches will be fully connected. | |||||
// initSwitch defines how the ith switch should be initialized (ie. with what reactors). | |||||
// NOTE: panics if any switch fails to start. | |||||
func MakeConnectedSwitches(n int, initSwitch func(int, *Switch) *Switch, connect func([]*Switch, int, int)) []*Switch { | |||||
switches := make([]*Switch, n) | |||||
for i := 0; i < n; i++ { | |||||
switches[i] = makeSwitch(i, "testing", "123.123.123", initSwitch) | |||||
} | |||||
if err := StartSwitches(switches); err != nil { | |||||
panic(err) | |||||
} | |||||
for i := 0; i < n; i++ { | |||||
for j := i; j < n; j++ { | |||||
connect(switches, i, j) | |||||
} | |||||
} | |||||
return switches | |||||
} | |||||
var PanicOnAddPeerErr = false | |||||
// Will connect switches i and j via net.Pipe() | |||||
// Blocks until a conection is established. | |||||
// NOTE: caller ensures i and j are within bounds | |||||
func Connect2Switches(switches []*Switch, i, j int) { | |||||
switchI := switches[i] | |||||
switchJ := switches[j] | |||||
c1, c2 := net.Pipe() | |||||
doneCh := make(chan struct{}) | |||||
go func() { | |||||
err := switchI.addPeerWithConnection(c1) | |||||
if PanicOnAddPeerErr && err != nil { | |||||
panic(err) | |||||
} | |||||
doneCh <- struct{}{} | |||||
}() | |||||
go func() { | |||||
err := switchJ.addPeerWithConnection(c2) | |||||
if PanicOnAddPeerErr && err != nil { | |||||
panic(err) | |||||
} | |||||
doneCh <- struct{}{} | |||||
}() | |||||
<-doneCh | |||||
<-doneCh | |||||
} | |||||
func StartSwitches(switches []*Switch) error { | |||||
for _, s := range switches { | |||||
_, err := s.Start() // start switch and reactors | |||||
if err != nil { | |||||
return err | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func makeSwitch(i int, network, version string, initSwitch func(int, *Switch) *Switch) *Switch { | |||||
privKey := crypto.GenPrivKeyEd25519() | |||||
// new switch, add reactors | |||||
// TODO: let the config be passed in? | |||||
s := initSwitch(i, NewSwitch(cfg.NewMapConfig(nil))) | |||||
s.SetNodeInfo(&NodeInfo{ | |||||
PubKey: privKey.PubKey().Unwrap().(crypto.PubKeyEd25519), | |||||
Moniker: Fmt("switch%d", i), | |||||
Network: network, | |||||
Version: version, | |||||
RemoteAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023), | |||||
ListenAddr: Fmt("%v:%v", network, rand.Intn(64512)+1023), | |||||
}) | |||||
s.SetNodePrivKey(privKey) | |||||
return s | |||||
} | |||||
func (sw *Switch) addPeerWithConnection(conn net.Conn) error { | |||||
peer, err := newInboundPeer(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) | |||||
if err != nil { | |||||
conn.Close() | |||||
return err | |||||
} | |||||
if err = sw.AddPeer(peer); err != nil { | |||||
conn.Close() | |||||
return err | |||||
} | |||||
return nil | |||||
} | |||||
func (sw *Switch) addPeerWithConnectionAndConfig(conn net.Conn, config *PeerConfig) error { | |||||
peer, err := newInboundPeerWithConfig(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config) | |||||
if err != nil { | |||||
conn.Close() | |||||
return err | |||||
} | |||||
if err = sw.AddPeer(peer); err != nil { | |||||
conn.Close() | |||||
return err | |||||
} | |||||
return nil | |||||
} | |||||
func peerConfigFromGoConfig(config cfg.Config) *PeerConfig { | |||||
return &PeerConfig{ | |||||
AuthEnc: config.GetBool(configKeyAuthEnc), | |||||
Fuzz: config.GetBool(configFuzzEnable), | |||||
HandshakeTimeout: time.Duration(config.GetInt(configKeyHandshakeTimeoutSeconds)) * time.Second, | |||||
DialTimeout: time.Duration(config.GetInt(configKeyDialTimeoutSeconds)) * time.Second, | |||||
MConfig: &MConnConfig{ | |||||
SendRate: int64(config.GetInt(configKeySendRate)), | |||||
RecvRate: int64(config.GetInt(configKeyRecvRate)), | |||||
}, | |||||
FuzzConfig: &FuzzConnConfig{ | |||||
Mode: config.GetInt(configFuzzMode), | |||||
MaxDelay: time.Duration(config.GetInt(configFuzzMaxDelayMilliseconds)) * time.Millisecond, | |||||
ProbDropRW: config.GetFloat64(configFuzzProbDropRW), | |||||
ProbDropConn: config.GetFloat64(configFuzzProbDropConn), | |||||
ProbSleep: config.GetFloat64(configFuzzProbSleep), | |||||
}, | |||||
} | |||||
} |
@ -0,0 +1,330 @@ | |||||
package p2p | |||||
import ( | |||||
"bytes" | |||||
"fmt" | |||||
"net" | |||||
"sync" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
cfg "github.com/tendermint/go-config" | |||||
crypto "github.com/tendermint/go-crypto" | |||||
wire "github.com/tendermint/go-wire" | |||||
) | |||||
var ( | |||||
config cfg.Config | |||||
) | |||||
func init() { | |||||
config = cfg.NewMapConfig(nil) | |||||
setConfigDefaults(config) | |||||
} | |||||
type PeerMessage struct { | |||||
PeerKey string | |||||
Bytes []byte | |||||
Counter int | |||||
} | |||||
type TestReactor struct { | |||||
BaseReactor | |||||
mtx sync.Mutex | |||||
channels []*ChannelDescriptor | |||||
peersAdded []*Peer | |||||
peersRemoved []*Peer | |||||
logMessages bool | |||||
msgsCounter int | |||||
msgsReceived map[byte][]PeerMessage | |||||
} | |||||
func NewTestReactor(channels []*ChannelDescriptor, logMessages bool) *TestReactor { | |||||
tr := &TestReactor{ | |||||
channels: channels, | |||||
logMessages: logMessages, | |||||
msgsReceived: make(map[byte][]PeerMessage), | |||||
} | |||||
tr.BaseReactor = *NewBaseReactor(log, "TestReactor", tr) | |||||
return tr | |||||
} | |||||
func (tr *TestReactor) GetChannels() []*ChannelDescriptor { | |||||
return tr.channels | |||||
} | |||||
func (tr *TestReactor) AddPeer(peer *Peer) { | |||||
tr.mtx.Lock() | |||||
defer tr.mtx.Unlock() | |||||
tr.peersAdded = append(tr.peersAdded, peer) | |||||
} | |||||
func (tr *TestReactor) RemovePeer(peer *Peer, reason interface{}) { | |||||
tr.mtx.Lock() | |||||
defer tr.mtx.Unlock() | |||||
tr.peersRemoved = append(tr.peersRemoved, peer) | |||||
} | |||||
func (tr *TestReactor) Receive(chID byte, peer *Peer, msgBytes []byte) { | |||||
if tr.logMessages { | |||||
tr.mtx.Lock() | |||||
defer tr.mtx.Unlock() | |||||
//fmt.Printf("Received: %X, %X\n", chID, msgBytes) | |||||
tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.Key, msgBytes, tr.msgsCounter}) | |||||
tr.msgsCounter++ | |||||
} | |||||
} | |||||
func (tr *TestReactor) getMsgs(chID byte) []PeerMessage { | |||||
tr.mtx.Lock() | |||||
defer tr.mtx.Unlock() | |||||
return tr.msgsReceived[chID] | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
// convenience method for creating two switches connected to each other. | |||||
// XXX: note this uses net.Pipe and not a proper TCP conn | |||||
func makeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) { | |||||
// Create two switches that will be interconnected. | |||||
switches := MakeConnectedSwitches(2, initSwitch, Connect2Switches) | |||||
return switches[0], switches[1] | |||||
} | |||||
func initSwitchFunc(i int, sw *Switch) *Switch { | |||||
// Make two reactors of two channels each | |||||
sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ | |||||
&ChannelDescriptor{ID: byte(0x00), Priority: 10}, | |||||
&ChannelDescriptor{ID: byte(0x01), Priority: 10}, | |||||
}, true)) | |||||
sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ | |||||
&ChannelDescriptor{ID: byte(0x02), Priority: 10}, | |||||
&ChannelDescriptor{ID: byte(0x03), Priority: 10}, | |||||
}, true)) | |||||
return sw | |||||
} | |||||
func TestSwitches(t *testing.T) { | |||||
s1, s2 := makeSwitchPair(t, initSwitchFunc) | |||||
defer s1.Stop() | |||||
defer s2.Stop() | |||||
if s1.Peers().Size() != 1 { | |||||
t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size()) | |||||
} | |||||
if s2.Peers().Size() != 1 { | |||||
t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size()) | |||||
} | |||||
// Lets send some messages | |||||
ch0Msg := "channel zero" | |||||
ch1Msg := "channel foo" | |||||
ch2Msg := "channel bar" | |||||
s1.Broadcast(byte(0x00), ch0Msg) | |||||
s1.Broadcast(byte(0x01), ch1Msg) | |||||
s1.Broadcast(byte(0x02), ch2Msg) | |||||
// Wait for things to settle... | |||||
time.Sleep(5000 * time.Millisecond) | |||||
// Check message on ch0 | |||||
ch0Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x00)) | |||||
if len(ch0Msgs) != 1 { | |||||
t.Errorf("Expected to have received 1 message in ch0") | |||||
} | |||||
if !bytes.Equal(ch0Msgs[0].Bytes, wire.BinaryBytes(ch0Msg)) { | |||||
t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch0Msg), ch0Msgs[0].Bytes) | |||||
} | |||||
// Check message on ch1 | |||||
ch1Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x01)) | |||||
if len(ch1Msgs) != 1 { | |||||
t.Errorf("Expected to have received 1 message in ch1") | |||||
} | |||||
if !bytes.Equal(ch1Msgs[0].Bytes, wire.BinaryBytes(ch1Msg)) { | |||||
t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch1Msg), ch1Msgs[0].Bytes) | |||||
} | |||||
// Check message on ch2 | |||||
ch2Msgs := s2.Reactor("bar").(*TestReactor).getMsgs(byte(0x02)) | |||||
if len(ch2Msgs) != 1 { | |||||
t.Errorf("Expected to have received 1 message in ch2") | |||||
} | |||||
if !bytes.Equal(ch2Msgs[0].Bytes, wire.BinaryBytes(ch2Msg)) { | |||||
t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch2Msg), ch2Msgs[0].Bytes) | |||||
} | |||||
} | |||||
func TestConnAddrFilter(t *testing.T) { | |||||
s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||||
s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||||
c1, c2 := net.Pipe() | |||||
s1.SetAddrFilter(func(addr net.Addr) error { | |||||
if addr.String() == c1.RemoteAddr().String() { | |||||
return fmt.Errorf("Error: pipe is blacklisted") | |||||
} | |||||
return nil | |||||
}) | |||||
// connect to good peer | |||||
go func() { | |||||
s1.addPeerWithConnection(c1) | |||||
}() | |||||
go func() { | |||||
s2.addPeerWithConnection(c2) | |||||
}() | |||||
// Wait for things to happen, peers to get added... | |||||
time.Sleep(100 * time.Millisecond * time.Duration(4)) | |||||
defer s1.Stop() | |||||
defer s2.Stop() | |||||
if s1.Peers().Size() != 0 { | |||||
t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) | |||||
} | |||||
if s2.Peers().Size() != 0 { | |||||
t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) | |||||
} | |||||
} | |||||
func TestConnPubKeyFilter(t *testing.T) { | |||||
s1 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||||
s2 := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||||
c1, c2 := net.Pipe() | |||||
// set pubkey filter | |||||
s1.SetPubKeyFilter(func(pubkey crypto.PubKeyEd25519) error { | |||||
if bytes.Equal(pubkey.Bytes(), s2.nodeInfo.PubKey.Bytes()) { | |||||
return fmt.Errorf("Error: pipe is blacklisted") | |||||
} | |||||
return nil | |||||
}) | |||||
// connect to good peer | |||||
go func() { | |||||
s1.addPeerWithConnection(c1) | |||||
}() | |||||
go func() { | |||||
s2.addPeerWithConnection(c2) | |||||
}() | |||||
// Wait for things to happen, peers to get added... | |||||
time.Sleep(100 * time.Millisecond * time.Duration(4)) | |||||
defer s1.Stop() | |||||
defer s2.Stop() | |||||
if s1.Peers().Size() != 0 { | |||||
t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size()) | |||||
} | |||||
if s2.Peers().Size() != 0 { | |||||
t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size()) | |||||
} | |||||
} | |||||
func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||||
sw.Start() | |||||
defer sw.Stop() | |||||
// simulate remote peer | |||||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} | |||||
rp.Start() | |||||
defer rp.Stop() | |||||
peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) | |||||
require.Nil(err) | |||||
err = sw.AddPeer(peer) | |||||
require.Nil(err) | |||||
// simulate failure by closing connection | |||||
peer.CloseConn() | |||||
time.Sleep(100 * time.Millisecond) | |||||
assert.Zero(sw.Peers().Size()) | |||||
assert.False(peer.IsRunning()) | |||||
} | |||||
func TestSwitchReconnectsToPersistentPeer(t *testing.T) { | |||||
assert, require := assert.New(t), require.New(t) | |||||
sw := makeSwitch(1, "testing", "123.123.123", initSwitchFunc) | |||||
sw.Start() | |||||
defer sw.Stop() | |||||
// simulate remote peer | |||||
rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig()} | |||||
rp.Start() | |||||
defer rp.Stop() | |||||
peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey) | |||||
peer.makePersistent() | |||||
require.Nil(err) | |||||
err = sw.AddPeer(peer) | |||||
require.Nil(err) | |||||
// simulate failure by closing connection | |||||
peer.CloseConn() | |||||
time.Sleep(100 * time.Millisecond) | |||||
assert.NotZero(sw.Peers().Size()) | |||||
assert.False(peer.IsRunning()) | |||||
} | |||||
func BenchmarkSwitches(b *testing.B) { | |||||
b.StopTimer() | |||||
s1, s2 := makeSwitchPair(b, func(i int, sw *Switch) *Switch { | |||||
// Make bar reactors of bar channels each | |||||
sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{ | |||||
&ChannelDescriptor{ID: byte(0x00), Priority: 10}, | |||||
&ChannelDescriptor{ID: byte(0x01), Priority: 10}, | |||||
}, false)) | |||||
sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{ | |||||
&ChannelDescriptor{ID: byte(0x02), Priority: 10}, | |||||
&ChannelDescriptor{ID: byte(0x03), Priority: 10}, | |||||
}, false)) | |||||
return sw | |||||
}) | |||||
defer s1.Stop() | |||||
defer s2.Stop() | |||||
// Allow time for goroutines to boot up | |||||
time.Sleep(1000 * time.Millisecond) | |||||
b.StartTimer() | |||||
numSuccess, numFailure := 0, 0 | |||||
// Send random message from foo channel to another | |||||
for i := 0; i < b.N; i++ { | |||||
chID := byte(i % 4) | |||||
successChan := s1.Broadcast(chID, "test data") | |||||
for s := range successChan { | |||||
if s { | |||||
numSuccess++ | |||||
} else { | |||||
numFailure++ | |||||
} | |||||
} | |||||
} | |||||
log.Warn(Fmt("success: %v, failure: %v", numSuccess, numFailure)) | |||||
// Allow everything to flush before stopping switches & closing connections. | |||||
b.StopTimer() | |||||
time.Sleep(1000 * time.Millisecond) | |||||
} |
@ -0,0 +1,77 @@ | |||||
package p2p | |||||
import ( | |||||
"fmt" | |||||
"net" | |||||
"strconv" | |||||
"strings" | |||||
"github.com/tendermint/go-crypto" | |||||
) | |||||
const maxNodeInfoSize = 10240 // 10Kb | |||||
type NodeInfo struct { | |||||
PubKey crypto.PubKeyEd25519 `json:"pub_key"` | |||||
Moniker string `json:"moniker"` | |||||
Network string `json:"network"` | |||||
RemoteAddr string `json:"remote_addr"` | |||||
ListenAddr string `json:"listen_addr"` | |||||
Version string `json:"version"` // major.minor.revision | |||||
Other []string `json:"other"` // other application specific data | |||||
} | |||||
// CONTRACT: two nodes are compatible if the major/minor versions match and network match | |||||
func (info *NodeInfo) CompatibleWith(other *NodeInfo) error { | |||||
iMajor, iMinor, _, iErr := splitVersion(info.Version) | |||||
oMajor, oMinor, _, oErr := splitVersion(other.Version) | |||||
// if our own version number is not formatted right, we messed up | |||||
if iErr != nil { | |||||
return iErr | |||||
} | |||||
// version number must be formatted correctly ("x.x.x") | |||||
if oErr != nil { | |||||
return oErr | |||||
} | |||||
// major version must match | |||||
if iMajor != oMajor { | |||||
return fmt.Errorf("Peer is on a different major version. Got %v, expected %v", oMajor, iMajor) | |||||
} | |||||
// minor version must match | |||||
if iMinor != oMinor { | |||||
return fmt.Errorf("Peer is on a different minor version. Got %v, expected %v", oMinor, iMinor) | |||||
} | |||||
// nodes must be on the same network | |||||
if info.Network != other.Network { | |||||
return fmt.Errorf("Peer is on a different network. Got %v, expected %v", other.Network, info.Network) | |||||
} | |||||
return nil | |||||
} | |||||
func (info *NodeInfo) ListenHost() string { | |||||
host, _, _ := net.SplitHostPort(info.ListenAddr) | |||||
return host | |||||
} | |||||
func (info *NodeInfo) ListenPort() int { | |||||
_, port, _ := net.SplitHostPort(info.ListenAddr) | |||||
port_i, err := strconv.Atoi(port) | |||||
if err != nil { | |||||
return -1 | |||||
} | |||||
return port_i | |||||
} | |||||
func splitVersion(version string) (string, string, string, error) { | |||||
spl := strings.Split(version, ".") | |||||
if len(spl) != 3 { | |||||
return "", "", "", fmt.Errorf("Invalid version format %v", version) | |||||
} | |||||
return spl[0], spl[1], spl[2], nil | |||||
} |
@ -0,0 +1,5 @@ | |||||
# `tendermint/p2p/upnp` | |||||
## Resources | |||||
* http://www.upnp-hacks.org/upnp.html |
@ -0,0 +1,7 @@ | |||||
package upnp | |||||
import ( | |||||
"github.com/tendermint/tmlibs/logger" | |||||
) | |||||
var log = logger.New("module", "upnp") |
@ -0,0 +1,111 @@ | |||||
package upnp | |||||
import ( | |||||
"errors" | |||||
"fmt" | |||||
"net" | |||||
"time" | |||||
. "github.com/tendermint/tmlibs/common" | |||||
) | |||||
type UPNPCapabilities struct { | |||||
PortMapping bool | |||||
Hairpin bool | |||||
} | |||||
func makeUPNPListener(intPort int, extPort int) (NAT, net.Listener, net.IP, error) { | |||||
nat, err := Discover() | |||||
if err != nil { | |||||
return nil, nil, nil, errors.New(fmt.Sprintf("NAT upnp could not be discovered: %v", err)) | |||||
} | |||||
log.Info(Fmt("ourIP: %v", nat.(*upnpNAT).ourIP)) | |||||
ext, err := nat.GetExternalAddress() | |||||
if err != nil { | |||||
return nat, nil, nil, errors.New(fmt.Sprintf("External address error: %v", err)) | |||||
} | |||||
log.Info(Fmt("External address: %v", ext)) | |||||
port, err := nat.AddPortMapping("tcp", extPort, intPort, "Tendermint UPnP Probe", 0) | |||||
if err != nil { | |||||
return nat, nil, ext, errors.New(fmt.Sprintf("Port mapping error: %v", err)) | |||||
} | |||||
log.Info(Fmt("Port mapping mapped: %v", port)) | |||||
// also run the listener, open for all remote addresses. | |||||
listener, err := net.Listen("tcp", fmt.Sprintf(":%v", intPort)) | |||||
if err != nil { | |||||
return nat, nil, ext, errors.New(fmt.Sprintf("Error establishing listener: %v", err)) | |||||
} | |||||
return nat, listener, ext, nil | |||||
} | |||||
func testHairpin(listener net.Listener, extAddr string) (supportsHairpin bool) { | |||||
// Listener | |||||
go func() { | |||||
inConn, err := listener.Accept() | |||||
if err != nil { | |||||
log.Notice(Fmt("Listener.Accept() error: %v", err)) | |||||
return | |||||
} | |||||
log.Info(Fmt("Accepted incoming connection: %v -> %v", inConn.LocalAddr(), inConn.RemoteAddr())) | |||||
buf := make([]byte, 1024) | |||||
n, err := inConn.Read(buf) | |||||
if err != nil { | |||||
log.Notice(Fmt("Incoming connection read error: %v", err)) | |||||
return | |||||
} | |||||
log.Info(Fmt("Incoming connection read %v bytes: %X", n, buf)) | |||||
if string(buf) == "test data" { | |||||
supportsHairpin = true | |||||
return | |||||
} | |||||
}() | |||||
// Establish outgoing | |||||
outConn, err := net.Dial("tcp", extAddr) | |||||
if err != nil { | |||||
log.Notice(Fmt("Outgoing connection dial error: %v", err)) | |||||
return | |||||
} | |||||
n, err := outConn.Write([]byte("test data")) | |||||
if err != nil { | |||||
log.Notice(Fmt("Outgoing connection write error: %v", err)) | |||||
return | |||||
} | |||||
log.Info(Fmt("Outgoing connection wrote %v bytes", n)) | |||||
// Wait for data receipt | |||||
time.Sleep(1 * time.Second) | |||||
return | |||||
} | |||||
func Probe() (caps UPNPCapabilities, err error) { | |||||
log.Info("Probing for UPnP!") | |||||
intPort, extPort := 8001, 8001 | |||||
nat, listener, ext, err := makeUPNPListener(intPort, extPort) | |||||
if err != nil { | |||||
return | |||||
} | |||||
caps.PortMapping = true | |||||
// Deferred cleanup | |||||
defer func() { | |||||
err = nat.DeletePortMapping("tcp", intPort, extPort) | |||||
if err != nil { | |||||
log.Warn(Fmt("Port mapping delete error: %v", err)) | |||||
} | |||||
listener.Close() | |||||
}() | |||||
supportsHairpin := testHairpin(listener, fmt.Sprintf("%v:%v", ext, extPort)) | |||||
if supportsHairpin { | |||||
caps.Hairpin = true | |||||
} | |||||
return | |||||
} |
@ -0,0 +1,380 @@ | |||||
/* | |||||
Taken from taipei-torrent | |||||
Just enough UPnP to be able to forward ports | |||||
*/ | |||||
package upnp | |||||
// BUG(jae): TODO: use syscalls to get actual ourIP. http://pastebin.com/9exZG4rh | |||||
import ( | |||||
"bytes" | |||||
"encoding/xml" | |||||
"errors" | |||||
"io/ioutil" | |||||
"net" | |||||
"net/http" | |||||
"strconv" | |||||
"strings" | |||||
"time" | |||||
) | |||||
type upnpNAT struct { | |||||
serviceURL string | |||||
ourIP string | |||||
urnDomain string | |||||
} | |||||
// protocol is either "udp" or "tcp" | |||||
type NAT interface { | |||||
GetExternalAddress() (addr net.IP, err error) | |||||
AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) | |||||
DeletePortMapping(protocol string, externalPort, internalPort int) (err error) | |||||
} | |||||
func Discover() (nat NAT, err error) { | |||||
ssdp, err := net.ResolveUDPAddr("udp4", "239.255.255.250:1900") | |||||
if err != nil { | |||||
return | |||||
} | |||||
conn, err := net.ListenPacket("udp4", ":0") | |||||
if err != nil { | |||||
return | |||||
} | |||||
socket := conn.(*net.UDPConn) | |||||
defer socket.Close() | |||||
err = socket.SetDeadline(time.Now().Add(3 * time.Second)) | |||||
if err != nil { | |||||
return | |||||
} | |||||
st := "InternetGatewayDevice:1" | |||||
buf := bytes.NewBufferString( | |||||
"M-SEARCH * HTTP/1.1\r\n" + | |||||
"HOST: 239.255.255.250:1900\r\n" + | |||||
"ST: ssdp:all\r\n" + | |||||
"MAN: \"ssdp:discover\"\r\n" + | |||||
"MX: 2\r\n\r\n") | |||||
message := buf.Bytes() | |||||
answerBytes := make([]byte, 1024) | |||||
for i := 0; i < 3; i++ { | |||||
_, err = socket.WriteToUDP(message, ssdp) | |||||
if err != nil { | |||||
return | |||||
} | |||||
var n int | |||||
n, _, err = socket.ReadFromUDP(answerBytes) | |||||
for { | |||||
n, _, err = socket.ReadFromUDP(answerBytes) | |||||
if err != nil { | |||||
break | |||||
} | |||||
answer := string(answerBytes[0:n]) | |||||
if strings.Index(answer, st) < 0 { | |||||
continue | |||||
} | |||||
// HTTP header field names are case-insensitive. | |||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 | |||||
locString := "\r\nlocation:" | |||||
answer = strings.ToLower(answer) | |||||
locIndex := strings.Index(answer, locString) | |||||
if locIndex < 0 { | |||||
continue | |||||
} | |||||
loc := answer[locIndex+len(locString):] | |||||
endIndex := strings.Index(loc, "\r\n") | |||||
if endIndex < 0 { | |||||
continue | |||||
} | |||||
locURL := strings.TrimSpace(loc[0:endIndex]) | |||||
var serviceURL, urnDomain string | |||||
serviceURL, urnDomain, err = getServiceURL(locURL) | |||||
if err != nil { | |||||
return | |||||
} | |||||
var ourIP net.IP | |||||
ourIP, err = localIPv4() | |||||
if err != nil { | |||||
return | |||||
} | |||||
nat = &upnpNAT{serviceURL: serviceURL, ourIP: ourIP.String(), urnDomain: urnDomain} | |||||
return | |||||
} | |||||
} | |||||
err = errors.New("UPnP port discovery failed.") | |||||
return | |||||
} | |||||
type Envelope struct { | |||||
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Envelope"` | |||||
Soap *SoapBody | |||||
} | |||||
type SoapBody struct { | |||||
XMLName xml.Name `xml:"http://schemas.xmlsoap.org/soap/envelope/ Body"` | |||||
ExternalIP *ExternalIPAddressResponse | |||||
} | |||||
type ExternalIPAddressResponse struct { | |||||
XMLName xml.Name `xml:"GetExternalIPAddressResponse"` | |||||
IPAddress string `xml:"NewExternalIPAddress"` | |||||
} | |||||
type ExternalIPAddress struct { | |||||
XMLName xml.Name `xml:"NewExternalIPAddress"` | |||||
IP string | |||||
} | |||||
type UPNPService struct { | |||||
ServiceType string `xml:"serviceType"` | |||||
ControlURL string `xml:"controlURL"` | |||||
} | |||||
type DeviceList struct { | |||||
Device []Device `xml:"device"` | |||||
} | |||||
type ServiceList struct { | |||||
Service []UPNPService `xml:"service"` | |||||
} | |||||
type Device struct { | |||||
XMLName xml.Name `xml:"device"` | |||||
DeviceType string `xml:"deviceType"` | |||||
DeviceList DeviceList `xml:"deviceList"` | |||||
ServiceList ServiceList `xml:"serviceList"` | |||||
} | |||||
type Root struct { | |||||
Device Device | |||||
} | |||||
func getChildDevice(d *Device, deviceType string) *Device { | |||||
dl := d.DeviceList.Device | |||||
for i := 0; i < len(dl); i++ { | |||||
if strings.Index(dl[i].DeviceType, deviceType) >= 0 { | |||||
return &dl[i] | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func getChildService(d *Device, serviceType string) *UPNPService { | |||||
sl := d.ServiceList.Service | |||||
for i := 0; i < len(sl); i++ { | |||||
if strings.Index(sl[i].ServiceType, serviceType) >= 0 { | |||||
return &sl[i] | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
func localIPv4() (net.IP, error) { | |||||
tt, err := net.Interfaces() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
for _, t := range tt { | |||||
aa, err := t.Addrs() | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
for _, a := range aa { | |||||
ipnet, ok := a.(*net.IPNet) | |||||
if !ok { | |||||
continue | |||||
} | |||||
v4 := ipnet.IP.To4() | |||||
if v4 == nil || v4[0] == 127 { // loopback address | |||||
continue | |||||
} | |||||
return v4, nil | |||||
} | |||||
} | |||||
return nil, errors.New("cannot find local IP address") | |||||
} | |||||
func getServiceURL(rootURL string) (url, urnDomain string, err error) { | |||||
r, err := http.Get(rootURL) | |||||
if err != nil { | |||||
return | |||||
} | |||||
defer r.Body.Close() | |||||
if r.StatusCode >= 400 { | |||||
err = errors.New(string(r.StatusCode)) | |||||
return | |||||
} | |||||
var root Root | |||||
err = xml.NewDecoder(r.Body).Decode(&root) | |||||
if err != nil { | |||||
return | |||||
} | |||||
a := &root.Device | |||||
if strings.Index(a.DeviceType, "InternetGatewayDevice:1") < 0 { | |||||
err = errors.New("No InternetGatewayDevice") | |||||
return | |||||
} | |||||
b := getChildDevice(a, "WANDevice:1") | |||||
if b == nil { | |||||
err = errors.New("No WANDevice") | |||||
return | |||||
} | |||||
c := getChildDevice(b, "WANConnectionDevice:1") | |||||
if c == nil { | |||||
err = errors.New("No WANConnectionDevice") | |||||
return | |||||
} | |||||
d := getChildService(c, "WANIPConnection:1") | |||||
if d == nil { | |||||
// Some routers don't follow the UPnP spec, and put WanIPConnection under WanDevice, | |||||
// instead of under WanConnectionDevice | |||||
d = getChildService(b, "WANIPConnection:1") | |||||
if d == nil { | |||||
err = errors.New("No WANIPConnection") | |||||
return | |||||
} | |||||
} | |||||
// Extract the domain name, which isn't always 'schemas-upnp-org' | |||||
urnDomain = strings.Split(d.ServiceType, ":")[1] | |||||
url = combineURL(rootURL, d.ControlURL) | |||||
return | |||||
} | |||||
func combineURL(rootURL, subURL string) string { | |||||
protocolEnd := "://" | |||||
protoEndIndex := strings.Index(rootURL, protocolEnd) | |||||
a := rootURL[protoEndIndex+len(protocolEnd):] | |||||
rootIndex := strings.Index(a, "/") | |||||
return rootURL[0:protoEndIndex+len(protocolEnd)+rootIndex] + subURL | |||||
} | |||||
func soapRequest(url, function, message, domain string) (r *http.Response, err error) { | |||||
fullMessage := "<?xml version=\"1.0\" ?>" + | |||||
"<s:Envelope xmlns:s=\"http://schemas.xmlsoap.org/soap/envelope/\" s:encodingStyle=\"http://schemas.xmlsoap.org/soap/encoding/\">\r\n" + | |||||
"<s:Body>" + message + "</s:Body></s:Envelope>" | |||||
req, err := http.NewRequest("POST", url, strings.NewReader(fullMessage)) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
req.Header.Set("Content-Type", "text/xml ; charset=\"utf-8\"") | |||||
req.Header.Set("User-Agent", "Darwin/10.0.0, UPnP/1.0, MiniUPnPc/1.3") | |||||
//req.Header.Set("Transfer-Encoding", "chunked") | |||||
req.Header.Set("SOAPAction", "\"urn:"+domain+":service:WANIPConnection:1#"+function+"\"") | |||||
req.Header.Set("Connection", "Close") | |||||
req.Header.Set("Cache-Control", "no-cache") | |||||
req.Header.Set("Pragma", "no-cache") | |||||
// log.Stderr("soapRequest ", req) | |||||
r, err = http.DefaultClient.Do(req) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
/*if r.Body != nil { | |||||
defer r.Body.Close() | |||||
}*/ | |||||
if r.StatusCode >= 400 { | |||||
// log.Stderr(function, r.StatusCode) | |||||
err = errors.New("Error " + strconv.Itoa(r.StatusCode) + " for " + function) | |||||
r = nil | |||||
return | |||||
} | |||||
return | |||||
} | |||||
type statusInfo struct { | |||||
externalIpAddress string | |||||
} | |||||
func (n *upnpNAT) getExternalIPAddress() (info statusInfo, err error) { | |||||
message := "<u:GetExternalIPAddress xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + | |||||
"</u:GetExternalIPAddress>" | |||||
var response *http.Response | |||||
response, err = soapRequest(n.serviceURL, "GetExternalIPAddress", message, n.urnDomain) | |||||
if response != nil { | |||||
defer response.Body.Close() | |||||
} | |||||
if err != nil { | |||||
return | |||||
} | |||||
var envelope Envelope | |||||
data, err := ioutil.ReadAll(response.Body) | |||||
reader := bytes.NewReader(data) | |||||
xml.NewDecoder(reader).Decode(&envelope) | |||||
info = statusInfo{envelope.Soap.ExternalIP.IPAddress} | |||||
if err != nil { | |||||
return | |||||
} | |||||
return | |||||
} | |||||
func (n *upnpNAT) GetExternalAddress() (addr net.IP, err error) { | |||||
info, err := n.getExternalIPAddress() | |||||
if err != nil { | |||||
return | |||||
} | |||||
addr = net.ParseIP(info.externalIpAddress) | |||||
return | |||||
} | |||||
func (n *upnpNAT) AddPortMapping(protocol string, externalPort, internalPort int, description string, timeout int) (mappedExternalPort int, err error) { | |||||
// A single concatenation would break ARM compilation. | |||||
message := "<u:AddPortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + | |||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) | |||||
message += "</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" | |||||
message += "<NewInternalPort>" + strconv.Itoa(internalPort) + "</NewInternalPort>" + | |||||
"<NewInternalClient>" + n.ourIP + "</NewInternalClient>" + | |||||
"<NewEnabled>1</NewEnabled><NewPortMappingDescription>" | |||||
message += description + | |||||
"</NewPortMappingDescription><NewLeaseDuration>" + strconv.Itoa(timeout) + | |||||
"</NewLeaseDuration></u:AddPortMapping>" | |||||
var response *http.Response | |||||
response, err = soapRequest(n.serviceURL, "AddPortMapping", message, n.urnDomain) | |||||
if response != nil { | |||||
defer response.Body.Close() | |||||
} | |||||
if err != nil { | |||||
return | |||||
} | |||||
// TODO: check response to see if the port was forwarded | |||||
// log.Println(message, response) | |||||
// JAE: | |||||
// body, err := ioutil.ReadAll(response.Body) | |||||
// fmt.Println(string(body), err) | |||||
mappedExternalPort = externalPort | |||||
_ = response | |||||
return | |||||
} | |||||
func (n *upnpNAT) DeletePortMapping(protocol string, externalPort, internalPort int) (err error) { | |||||
message := "<u:DeletePortMapping xmlns:u=\"urn:" + n.urnDomain + ":service:WANIPConnection:1\">\r\n" + | |||||
"<NewRemoteHost></NewRemoteHost><NewExternalPort>" + strconv.Itoa(externalPort) + | |||||
"</NewExternalPort><NewProtocol>" + protocol + "</NewProtocol>" + | |||||
"</u:DeletePortMapping>" | |||||
var response *http.Response | |||||
response, err = soapRequest(n.serviceURL, "DeletePortMapping", message, n.urnDomain) | |||||
if response != nil { | |||||
defer response.Body.Close() | |||||
} | |||||
if err != nil { | |||||
return | |||||
} | |||||
// TODO: check response to see if the port was deleted | |||||
// log.Println(message, response) | |||||
_ = response | |||||
return | |||||
} |
@ -0,0 +1,15 @@ | |||||
package p2p | |||||
import ( | |||||
"crypto/sha256" | |||||
) | |||||
// doubleSha256 calculates sha256(sha256(b)) and returns the resulting bytes. | |||||
func doubleSha256(b []byte) []byte { | |||||
hasher := sha256.New() | |||||
hasher.Write(b) | |||||
sum := hasher.Sum(nil) | |||||
hasher.Reset() | |||||
hasher.Write(sum) | |||||
return hasher.Sum(nil) | |||||
} |
@ -0,0 +1,3 @@ | |||||
package p2p | |||||
const Version = "0.5.0" |
@ -1,7 +1,7 @@ | |||||
package proxy | package proxy | ||||
import ( | import ( | ||||
"github.com/tendermint/go-logger" | |||||
"github.com/tendermint/tmlibs/logger" | |||||
) | ) | ||||
var log = logger.New("module", "proxy") | var log = logger.New("module", "proxy") |
@ -0,0 +1,12 @@ | |||||
FROM golang:latest | |||||
RUN mkdir -p /go/src/github.com/tendermint/tendermint/rpc | |||||
WORKDIR /go/src/github.com/tendermint/tendermint/rpc | |||||
COPY Makefile /go/src/github.com/tendermint/tendermint/rpc/ | |||||
# COPY glide.yaml /go/src/github.com/tendermint/tendermint/rpc/ | |||||
# COPY glide.lock /go/src/github.com/tendermint/tendermint/rpc/ | |||||
COPY . /go/src/github.com/tendermint/tendermint/rpc | |||||
RUN make get_deps |
@ -0,0 +1,18 @@ | |||||
PACKAGES=$(shell go list ./... | grep -v "test") | |||||
all: get_deps test | |||||
test: | |||||
@echo "--> Running go test --race" | |||||
@go test --race $(PACKAGES) | |||||
@echo "--> Running integration tests" | |||||
@bash ./test/integration_test.sh | |||||
get_deps: | |||||
@echo "--> Running go get" | |||||
@go get -v -d $(PACKAGES) | |||||
@go list -f '{{join .TestImports "\n"}}' ./... | \ | |||||
grep -v /vendor/ | sort | uniq | \ | |||||
xargs go get -v -d | |||||
.PHONY: all test get_deps |
@ -0,0 +1,128 @@ | |||||
# tendermint/rpc | |||||
[![CircleCI](https://circleci.com/gh/tendermint/tendermint/rpc.svg?style=svg)](https://circleci.com/gh/tendermint/tendermint/rpc) | |||||
HTTP RPC server supporting calls via uri params, jsonrpc, and jsonrpc over websockets | |||||
# Client Requests | |||||
Suppose we want to expose the rpc function `HelloWorld(name string, num int)`. | |||||
## GET (URI) | |||||
As a GET request, it would have URI encoded parameters, and look like: | |||||
``` | |||||
curl 'http://localhost:8008/hello_world?name="my_world"&num=5' | |||||
``` | |||||
Note the `'` around the url, which is just so bash doesn't ignore the quotes in `"my_world"`. | |||||
This should also work: | |||||
``` | |||||
curl http://localhost:8008/hello_world?name=\"my_world\"&num=5 | |||||
``` | |||||
A GET request to `/` returns a list of available endpoints. | |||||
For those which take arguments, the arguments will be listed in order, with `_` where the actual value should be. | |||||
## POST (JSONRPC) | |||||
As a POST request, we use JSONRPC. For instance, the same request would have this as the body: | |||||
``` | |||||
{ | |||||
"jsonrpc": "2.0", | |||||
"id": "anything", | |||||
"method": "hello_world", | |||||
"params": { | |||||
"name": "my_world", | |||||
"num": 5 | |||||
} | |||||
} | |||||
``` | |||||
With the above saved in file `data.json`, we can make the request with | |||||
``` | |||||
curl --data @data.json http://localhost:8008 | |||||
``` | |||||
## WebSocket (JSONRPC) | |||||
All requests are exposed over websocket in the same form as the POST JSONRPC. | |||||
Websocket connections are available at their own endpoint, typically `/websocket`, | |||||
though this is configurable when starting the server. | |||||
# Server Definition | |||||
Define some types and routes: | |||||
``` | |||||
// Define a type for results and register concrete versions with go-wire | |||||
type Result interface{} | |||||
type ResultStatus struct { | |||||
Value string | |||||
} | |||||
var _ = wire.RegisterInterface( | |||||
struct{ Result }{}, | |||||
wire.ConcreteType{&ResultStatus{}, 0x1}, | |||||
) | |||||
// Define some routes | |||||
var Routes = map[string]*rpcserver.RPCFunc{ | |||||
"status": rpcserver.NewRPCFunc(StatusResult, "arg"), | |||||
} | |||||
// an rpc function | |||||
func StatusResult(v string) (Result, error) { | |||||
return &ResultStatus{v}, nil | |||||
} | |||||
``` | |||||
Now start the server: | |||||
``` | |||||
mux := http.NewServeMux() | |||||
rpcserver.RegisterRPCFuncs(mux, Routes) | |||||
wm := rpcserver.NewWebsocketManager(Routes, nil) | |||||
mux.HandleFunc("/websocket", wm.WebsocketHandler) | |||||
go func() { | |||||
_, err := rpcserver.StartHTTPServer("0.0.0.0:8008", mux) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
}() | |||||
``` | |||||
Note that unix sockets are supported as well (eg. `/path/to/socket` instead of `0.0.0.0:8008`) | |||||
Now see all available endpoints by sending a GET request to `0.0.0.0:8008`. | |||||
Each route is available as a GET request, as a JSONRPCv2 POST request, and via JSONRPCv2 over websockets. | |||||
# Examples | |||||
* [Tendermint](https://github.com/tendermint/tendermint/blob/master/rpc/core/routes.go) | |||||
* [tm-monitor](https://github.com/tendermint/tools/blob/master/tm-monitor/rpc.go) | |||||
## CHANGELOG | |||||
### 0.7.0 | |||||
BREAKING CHANGES: | |||||
- removed `Client` empty interface | |||||
- `ClientJSONRPC#Call` `params` argument became a map | |||||
- rename `ClientURI` -> `URIClient`, `ClientJSONRPC` -> `JSONRPCClient` | |||||
IMPROVEMENTS: | |||||
- added `HTTPClient` interface, which can be used for both `ClientURI` | |||||
and `ClientJSONRPC` | |||||
- all params are now optional (Golang's default will be used if some param is missing) | |||||
- added `Call` method to `WSClient` (see method's doc for details) |
@ -0,0 +1,21 @@ | |||||
machine: | |||||
environment: | |||||
GOPATH: /home/ubuntu/.go_workspace | |||||
REPO: $GOPATH/src/github.com/$CIRCLE_PROJECT_USERNAME/$CIRCLE_PROJECT_REPONAME | |||||
hosts: | |||||
circlehost: 127.0.0.1 | |||||
localhost: 127.0.0.1 | |||||
checkout: | |||||
post: | |||||
- rm -rf $REPO | |||||
- mkdir -p $HOME/.go_workspace/src/github.com/$CIRCLE_PROJECT_USERNAME | |||||
- mv $HOME/$CIRCLE_PROJECT_REPONAME $REPO | |||||
dependencies: | |||||
override: | |||||
- "cd $REPO && make get_deps" | |||||
test: | |||||
override: | |||||
- "cd $REPO && make test" |
@ -0,0 +1,39 @@ | |||||
package rpcclient | |||||
import ( | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
) | |||||
type Tx []byte | |||||
type Foo struct { | |||||
Bar int | |||||
Baz string | |||||
} | |||||
func TestArgToJSON(t *testing.T) { | |||||
assert := assert.New(t) | |||||
require := require.New(t) | |||||
cases := []struct { | |||||
input interface{} | |||||
expected string | |||||
}{ | |||||
{[]byte("1234"), "0x31323334"}, | |||||
{Tx("654"), "0x363534"}, | |||||
{Foo{7, "hello"}, `{"Bar":7,"Baz":"hello"}`}, | |||||
} | |||||
for i, tc := range cases { | |||||
args := map[string]interface{}{"data": tc.input} | |||||
err := argsToJson(args) | |||||
require.Nil(err, "%d: %+v", i, err) | |||||
require.Equal(1, len(args), "%d", i) | |||||
data, ok := args["data"].(string) | |||||
require.True(ok, "%d: %#v", i, args["data"]) | |||||
assert.Equal(tc.expected, data, "%d", i) | |||||
} | |||||
} |
@ -0,0 +1,199 @@ | |||||
package rpcclient | |||||
import ( | |||||
"bytes" | |||||
"encoding/json" | |||||
"fmt" | |||||
"io/ioutil" | |||||
"net" | |||||
"net/http" | |||||
"net/url" | |||||
"reflect" | |||||
"strings" | |||||
"github.com/pkg/errors" | |||||
wire "github.com/tendermint/go-wire" | |||||
types "github.com/tendermint/tendermint/rpc/types" | |||||
) | |||||
// HTTPClient is a common interface for JSONRPCClient and URIClient. | |||||
type HTTPClient interface { | |||||
Call(method string, params map[string]interface{}, result interface{}) (interface{}, error) | |||||
} | |||||
// TODO: Deprecate support for IP:PORT or /path/to/socket | |||||
func makeHTTPDialer(remoteAddr string) (string, func(string, string) (net.Conn, error)) { | |||||
parts := strings.SplitN(remoteAddr, "://", 2) | |||||
var protocol, address string | |||||
if len(parts) != 2 { | |||||
log.Warn("WARNING (tendermint/rpc): Please use fully formed listening addresses, including the tcp:// or unix:// prefix") | |||||
protocol = types.SocketType(remoteAddr) | |||||
address = remoteAddr | |||||
} else { | |||||
protocol, address = parts[0], parts[1] | |||||
} | |||||
trimmedAddress := strings.Replace(address, "/", ".", -1) // replace / with . for http requests (dummy domain) | |||||
return trimmedAddress, func(proto, addr string) (net.Conn, error) { | |||||
return net.Dial(protocol, address) | |||||
} | |||||
} | |||||
// We overwrite the http.Client.Dial so we can do http over tcp or unix. | |||||
// remoteAddr should be fully featured (eg. with tcp:// or unix://) | |||||
func makeHTTPClient(remoteAddr string) (string, *http.Client) { | |||||
address, dialer := makeHTTPDialer(remoteAddr) | |||||
return "http://" + address, &http.Client{ | |||||
Transport: &http.Transport{ | |||||
Dial: dialer, | |||||
}, | |||||
} | |||||
} | |||||
//------------------------------------------------------------------------------------ | |||||
// JSON rpc takes params as a slice | |||||
type JSONRPCClient struct { | |||||
address string | |||||
client *http.Client | |||||
} | |||||
func NewJSONRPCClient(remote string) *JSONRPCClient { | |||||
address, client := makeHTTPClient(remote) | |||||
return &JSONRPCClient{ | |||||
address: address, | |||||
client: client, | |||||
} | |||||
} | |||||
func (c *JSONRPCClient) Call(method string, params map[string]interface{}, result interface{}) (interface{}, error) { | |||||
// we need this step because we attempt to decode values using `go-wire` | |||||
// (handlers.go:176) on the server side | |||||
encodedParams := make(map[string]interface{}) | |||||
for k, v := range params { | |||||
bytes := json.RawMessage(wire.JSONBytes(v)) | |||||
encodedParams[k] = &bytes | |||||
} | |||||
request := types.RPCRequest{ | |||||
JSONRPC: "2.0", | |||||
Method: method, | |||||
Params: encodedParams, | |||||
ID: "", | |||||
} | |||||
requestBytes, err := json.Marshal(request) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
// log.Info(string(requestBytes)) | |||||
requestBuf := bytes.NewBuffer(requestBytes) | |||||
// log.Info(Fmt("RPC request to %v (%v): %v", c.remote, method, string(requestBytes))) | |||||
httpResponse, err := c.client.Post(c.address, "text/json", requestBuf) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
defer httpResponse.Body.Close() | |||||
responseBytes, err := ioutil.ReadAll(httpResponse.Body) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
// log.Info(Fmt("RPC response: %v", string(responseBytes))) | |||||
return unmarshalResponseBytes(responseBytes, result) | |||||
} | |||||
//------------------------------------------------------------- | |||||
// URI takes params as a map | |||||
type URIClient struct { | |||||
address string | |||||
client *http.Client | |||||
} | |||||
func NewURIClient(remote string) *URIClient { | |||||
address, client := makeHTTPClient(remote) | |||||
return &URIClient{ | |||||
address: address, | |||||
client: client, | |||||
} | |||||
} | |||||
func (c *URIClient) Call(method string, params map[string]interface{}, result interface{}) (interface{}, error) { | |||||
values, err := argsToURLValues(params) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
// log.Info(Fmt("URI request to %v (%v): %v", c.address, method, values)) | |||||
resp, err := c.client.PostForm(c.address+"/"+method, values) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
defer resp.Body.Close() | |||||
responseBytes, err := ioutil.ReadAll(resp.Body) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return unmarshalResponseBytes(responseBytes, result) | |||||
} | |||||
//------------------------------------------------ | |||||
func unmarshalResponseBytes(responseBytes []byte, result interface{}) (interface{}, error) { | |||||
// read response | |||||
// if rpc/core/types is imported, the result will unmarshal | |||||
// into the correct type | |||||
// log.Notice("response", "response", string(responseBytes)) | |||||
var err error | |||||
response := &types.RPCResponse{} | |||||
err = json.Unmarshal(responseBytes, response) | |||||
if err != nil { | |||||
return nil, errors.Errorf("Error unmarshalling rpc response: %v", err) | |||||
} | |||||
errorStr := response.Error | |||||
if errorStr != "" { | |||||
return nil, errors.Errorf("Response error: %v", errorStr) | |||||
} | |||||
// unmarshal the RawMessage into the result | |||||
result = wire.ReadJSONPtr(result, *response.Result, &err) | |||||
if err != nil { | |||||
return nil, errors.Errorf("Error unmarshalling rpc response result: %v", err) | |||||
} | |||||
return result, nil | |||||
} | |||||
func argsToURLValues(args map[string]interface{}) (url.Values, error) { | |||||
values := make(url.Values) | |||||
if len(args) == 0 { | |||||
return values, nil | |||||
} | |||||
err := argsToJson(args) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
for key, val := range args { | |||||
values.Set(key, val.(string)) | |||||
} | |||||
return values, nil | |||||
} | |||||
func argsToJson(args map[string]interface{}) error { | |||||
var n int | |||||
var err error | |||||
for k, v := range args { | |||||
rt := reflect.TypeOf(v) | |||||
isByteSlice := rt.Kind() == reflect.Slice && rt.Elem().Kind() == reflect.Uint8 | |||||
if isByteSlice { | |||||
bytes := reflect.ValueOf(v).Bytes() | |||||
args[k] = fmt.Sprintf("0x%X", bytes) | |||||
continue | |||||
} | |||||
// Pass everything else to go-wire | |||||
buf := new(bytes.Buffer) | |||||
wire.WriteJSON(v, buf, &n, &err) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
args[k] = buf.String() | |||||
} | |||||
return nil | |||||
} |
@ -0,0 +1,7 @@ | |||||
package rpcclient | |||||
import ( | |||||
"github.com/tendermint/log15" | |||||
) | |||||
var log = log15.New("module", "rpcclient") |
@ -0,0 +1,172 @@ | |||||
package rpcclient | |||||
import ( | |||||
"encoding/json" | |||||
"net" | |||||
"net/http" | |||||
"time" | |||||
"github.com/gorilla/websocket" | |||||
"github.com/pkg/errors" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
types "github.com/tendermint/tendermint/rpc/types" | |||||
wire "github.com/tendermint/go-wire" | |||||
) | |||||
const ( | |||||
wsResultsChannelCapacity = 10 | |||||
wsErrorsChannelCapacity = 1 | |||||
wsWriteTimeoutSeconds = 10 | |||||
) | |||||
type WSClient struct { | |||||
cmn.BaseService | |||||
Address string // IP:PORT or /path/to/socket | |||||
Endpoint string // /websocket/url/endpoint | |||||
Dialer func(string, string) (net.Conn, error) | |||||
*websocket.Conn | |||||
ResultsCh chan json.RawMessage // closes upon WSClient.Stop() | |||||
ErrorsCh chan error // closes upon WSClient.Stop() | |||||
} | |||||
// create a new connection | |||||
func NewWSClient(remoteAddr, endpoint string) *WSClient { | |||||
addr, dialer := makeHTTPDialer(remoteAddr) | |||||
wsClient := &WSClient{ | |||||
Address: addr, | |||||
Dialer: dialer, | |||||
Endpoint: endpoint, | |||||
Conn: nil, | |||||
} | |||||
wsClient.BaseService = *cmn.NewBaseService(log, "WSClient", wsClient) | |||||
return wsClient | |||||
} | |||||
func (wsc *WSClient) String() string { | |||||
return wsc.Address + ", " + wsc.Endpoint | |||||
} | |||||
// OnStart implements cmn.BaseService interface | |||||
func (wsc *WSClient) OnStart() error { | |||||
wsc.BaseService.OnStart() | |||||
err := wsc.dial() | |||||
if err != nil { | |||||
return err | |||||
} | |||||
wsc.ResultsCh = make(chan json.RawMessage, wsResultsChannelCapacity) | |||||
wsc.ErrorsCh = make(chan error, wsErrorsChannelCapacity) | |||||
go wsc.receiveEventsRoutine() | |||||
return nil | |||||
} | |||||
// OnReset implements cmn.BaseService interface | |||||
func (wsc *WSClient) OnReset() error { | |||||
return nil | |||||
} | |||||
func (wsc *WSClient) dial() error { | |||||
// Dial | |||||
dialer := &websocket.Dialer{ | |||||
NetDial: wsc.Dialer, | |||||
Proxy: http.ProxyFromEnvironment, | |||||
} | |||||
rHeader := http.Header{} | |||||
con, _, err := dialer.Dial("ws://"+wsc.Address+wsc.Endpoint, rHeader) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
// Set the ping/pong handlers | |||||
con.SetPingHandler(func(m string) error { | |||||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||||
go con.WriteControl(websocket.PongMessage, []byte(m), time.Now().Add(time.Second*wsWriteTimeoutSeconds)) | |||||
return nil | |||||
}) | |||||
con.SetPongHandler(func(m string) error { | |||||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||||
return nil | |||||
}) | |||||
wsc.Conn = con | |||||
return nil | |||||
} | |||||
// OnStop implements cmn.BaseService interface | |||||
func (wsc *WSClient) OnStop() { | |||||
wsc.BaseService.OnStop() | |||||
wsc.Conn.Close() | |||||
// ResultsCh/ErrorsCh is closed in receiveEventsRoutine. | |||||
} | |||||
func (wsc *WSClient) receiveEventsRoutine() { | |||||
for { | |||||
_, data, err := wsc.ReadMessage() | |||||
if err != nil { | |||||
log.Info("WSClient failed to read message", "error", err, "data", string(data)) | |||||
wsc.Stop() | |||||
break | |||||
} else { | |||||
var response types.RPCResponse | |||||
err := json.Unmarshal(data, &response) | |||||
if err != nil { | |||||
log.Info("WSClient failed to parse message", "error", err, "data", string(data)) | |||||
wsc.ErrorsCh <- err | |||||
continue | |||||
} | |||||
if response.Error != "" { | |||||
wsc.ErrorsCh <- errors.Errorf(response.Error) | |||||
continue | |||||
} | |||||
wsc.ResultsCh <- *response.Result | |||||
} | |||||
} | |||||
// this must be modified in the same go-routine that reads from the | |||||
// connection to avoid race conditions | |||||
wsc.Conn = nil | |||||
// Cleanup | |||||
close(wsc.ResultsCh) | |||||
close(wsc.ErrorsCh) | |||||
} | |||||
// Subscribe to an event. Note the server must have a "subscribe" route | |||||
// defined. | |||||
func (wsc *WSClient) Subscribe(eventid string) error { | |||||
err := wsc.WriteJSON(types.RPCRequest{ | |||||
JSONRPC: "2.0", | |||||
ID: "", | |||||
Method: "subscribe", | |||||
Params: map[string]interface{}{"event": eventid}, | |||||
}) | |||||
return err | |||||
} | |||||
// Unsubscribe from an event. Note the server must have a "unsubscribe" route | |||||
// defined. | |||||
func (wsc *WSClient) Unsubscribe(eventid string) error { | |||||
err := wsc.WriteJSON(types.RPCRequest{ | |||||
JSONRPC: "2.0", | |||||
ID: "", | |||||
Method: "unsubscribe", | |||||
Params: map[string]interface{}{"event": eventid}, | |||||
}) | |||||
return err | |||||
} | |||||
// Call asynchronously calls a given method by sending an RPCRequest to the | |||||
// server. Results will be available on ResultsCh, errors, if any, on ErrorsCh. | |||||
func (wsc *WSClient) Call(method string, params map[string]interface{}) error { | |||||
// we need this step because we attempt to decode values using `go-wire` | |||||
// (handlers.go:470) on the server side | |||||
encodedParams := make(map[string]interface{}) | |||||
for k, v := range params { | |||||
bytes := json.RawMessage(wire.JSONBytes(v)) | |||||
encodedParams[k] = &bytes | |||||
} | |||||
err := wsc.WriteJSON(types.RPCRequest{ | |||||
JSONRPC: "2.0", | |||||
Method: method, | |||||
Params: encodedParams, | |||||
ID: "", | |||||
}) | |||||
return err | |||||
} |
@ -0,0 +1,298 @@ | |||||
package rpc | |||||
import ( | |||||
"bytes" | |||||
crand "crypto/rand" | |||||
"fmt" | |||||
"math/rand" | |||||
"net/http" | |||||
"os/exec" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
wire "github.com/tendermint/go-wire" | |||||
client "github.com/tendermint/tendermint/rpc/client" | |||||
server "github.com/tendermint/tendermint/rpc/server" | |||||
types "github.com/tendermint/tendermint/rpc/types" | |||||
) | |||||
// Client and Server should work over tcp or unix sockets | |||||
const ( | |||||
tcpAddr = "tcp://0.0.0.0:46657" | |||||
unixSocket = "/tmp/rpc.sock" | |||||
unixAddr = "unix:///tmp/rpc.sock" | |||||
websocketEndpoint = "/websocket/endpoint" | |||||
) | |||||
// Define a type for results and register concrete versions | |||||
type Result interface{} | |||||
type ResultEcho struct { | |||||
Value string | |||||
} | |||||
type ResultEchoBytes struct { | |||||
Value []byte | |||||
} | |||||
var _ = wire.RegisterInterface( | |||||
struct{ Result }{}, | |||||
wire.ConcreteType{&ResultEcho{}, 0x1}, | |||||
wire.ConcreteType{&ResultEchoBytes{}, 0x2}, | |||||
) | |||||
// Define some routes | |||||
var Routes = map[string]*server.RPCFunc{ | |||||
"echo": server.NewRPCFunc(EchoResult, "arg"), | |||||
"echo_ws": server.NewWSRPCFunc(EchoWSResult, "arg"), | |||||
"echo_bytes": server.NewRPCFunc(EchoBytesResult, "arg"), | |||||
} | |||||
func EchoResult(v string) (Result, error) { | |||||
return &ResultEcho{v}, nil | |||||
} | |||||
func EchoWSResult(wsCtx types.WSRPCContext, v string) (Result, error) { | |||||
return &ResultEcho{v}, nil | |||||
} | |||||
func EchoBytesResult(v []byte) (Result, error) { | |||||
return &ResultEchoBytes{v}, nil | |||||
} | |||||
// launch unix and tcp servers | |||||
func init() { | |||||
cmd := exec.Command("rm", "-f", unixSocket) | |||||
err := cmd.Start() | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
if err = cmd.Wait(); err != nil { | |||||
panic(err) | |||||
} | |||||
mux := http.NewServeMux() | |||||
server.RegisterRPCFuncs(mux, Routes) | |||||
wm := server.NewWebsocketManager(Routes, nil) | |||||
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler) | |||||
go func() { | |||||
_, err := server.StartHTTPServer(tcpAddr, mux) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
}() | |||||
mux2 := http.NewServeMux() | |||||
server.RegisterRPCFuncs(mux2, Routes) | |||||
wm = server.NewWebsocketManager(Routes, nil) | |||||
mux2.HandleFunc(websocketEndpoint, wm.WebsocketHandler) | |||||
go func() { | |||||
_, err := server.StartHTTPServer(unixAddr, mux2) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
}() | |||||
// wait for servers to start | |||||
time.Sleep(time.Second * 2) | |||||
} | |||||
func echoViaHTTP(cl client.HTTPClient, val string) (string, error) { | |||||
params := map[string]interface{}{ | |||||
"arg": val, | |||||
} | |||||
var result Result | |||||
if _, err := cl.Call("echo", params, &result); err != nil { | |||||
return "", err | |||||
} | |||||
return result.(*ResultEcho).Value, nil | |||||
} | |||||
func echoBytesViaHTTP(cl client.HTTPClient, bytes []byte) ([]byte, error) { | |||||
params := map[string]interface{}{ | |||||
"arg": bytes, | |||||
} | |||||
var result Result | |||||
if _, err := cl.Call("echo_bytes", params, &result); err != nil { | |||||
return []byte{}, err | |||||
} | |||||
return result.(*ResultEchoBytes).Value, nil | |||||
} | |||||
func testWithHTTPClient(t *testing.T, cl client.HTTPClient) { | |||||
val := "acbd" | |||||
got, err := echoViaHTTP(cl, val) | |||||
require.Nil(t, err) | |||||
assert.Equal(t, got, val) | |||||
val2 := randBytes(t) | |||||
got2, err := echoBytesViaHTTP(cl, val2) | |||||
require.Nil(t, err) | |||||
assert.Equal(t, got2, val2) | |||||
} | |||||
func echoViaWS(cl *client.WSClient, val string) (string, error) { | |||||
params := map[string]interface{}{ | |||||
"arg": val, | |||||
} | |||||
err := cl.Call("echo", params) | |||||
if err != nil { | |||||
return "", err | |||||
} | |||||
select { | |||||
case msg := <-cl.ResultsCh: | |||||
result := new(Result) | |||||
wire.ReadJSONPtr(result, msg, &err) | |||||
if err != nil { | |||||
return "", nil | |||||
} | |||||
return (*result).(*ResultEcho).Value, nil | |||||
case err := <-cl.ErrorsCh: | |||||
return "", err | |||||
} | |||||
} | |||||
func echoBytesViaWS(cl *client.WSClient, bytes []byte) ([]byte, error) { | |||||
params := map[string]interface{}{ | |||||
"arg": bytes, | |||||
} | |||||
err := cl.Call("echo_bytes", params) | |||||
if err != nil { | |||||
return []byte{}, err | |||||
} | |||||
select { | |||||
case msg := <-cl.ResultsCh: | |||||
result := new(Result) | |||||
wire.ReadJSONPtr(result, msg, &err) | |||||
if err != nil { | |||||
return []byte{}, nil | |||||
} | |||||
return (*result).(*ResultEchoBytes).Value, nil | |||||
case err := <-cl.ErrorsCh: | |||||
return []byte{}, err | |||||
} | |||||
} | |||||
func testWithWSClient(t *testing.T, cl *client.WSClient) { | |||||
val := "acbd" | |||||
got, err := echoViaWS(cl, val) | |||||
require.Nil(t, err) | |||||
assert.Equal(t, got, val) | |||||
val2 := randBytes(t) | |||||
got2, err := echoBytesViaWS(cl, val2) | |||||
require.Nil(t, err) | |||||
assert.Equal(t, got2, val2) | |||||
} | |||||
//------------- | |||||
func TestServersAndClientsBasic(t *testing.T) { | |||||
serverAddrs := [...]string{tcpAddr, unixAddr} | |||||
for _, addr := range serverAddrs { | |||||
cl1 := client.NewURIClient(addr) | |||||
fmt.Printf("=== testing server on %s using %v client", addr, cl1) | |||||
testWithHTTPClient(t, cl1) | |||||
cl2 := client.NewJSONRPCClient(tcpAddr) | |||||
fmt.Printf("=== testing server on %s using %v client", addr, cl2) | |||||
testWithHTTPClient(t, cl2) | |||||
cl3 := client.NewWSClient(tcpAddr, websocketEndpoint) | |||||
_, err := cl3.Start() | |||||
require.Nil(t, err) | |||||
fmt.Printf("=== testing server on %s using %v client", addr, cl3) | |||||
testWithWSClient(t, cl3) | |||||
cl3.Stop() | |||||
} | |||||
} | |||||
func TestHexStringArg(t *testing.T) { | |||||
cl := client.NewURIClient(tcpAddr) | |||||
// should NOT be handled as hex | |||||
val := "0xabc" | |||||
got, err := echoViaHTTP(cl, val) | |||||
require.Nil(t, err) | |||||
assert.Equal(t, got, val) | |||||
} | |||||
func TestQuotedStringArg(t *testing.T) { | |||||
cl := client.NewURIClient(tcpAddr) | |||||
// should NOT be unquoted | |||||
val := "\"abc\"" | |||||
got, err := echoViaHTTP(cl, val) | |||||
require.Nil(t, err) | |||||
assert.Equal(t, got, val) | |||||
} | |||||
func TestWSNewWSRPCFunc(t *testing.T) { | |||||
cl := client.NewWSClient(tcpAddr, websocketEndpoint) | |||||
_, err := cl.Start() | |||||
require.Nil(t, err) | |||||
defer cl.Stop() | |||||
val := "acbd" | |||||
params := map[string]interface{}{ | |||||
"arg": val, | |||||
} | |||||
err = cl.WriteJSON(types.RPCRequest{ | |||||
JSONRPC: "2.0", | |||||
ID: "", | |||||
Method: "echo_ws", | |||||
Params: params, | |||||
}) | |||||
require.Nil(t, err) | |||||
select { | |||||
case msg := <-cl.ResultsCh: | |||||
result := new(Result) | |||||
wire.ReadJSONPtr(result, msg, &err) | |||||
require.Nil(t, err) | |||||
got := (*result).(*ResultEcho).Value | |||||
assert.Equal(t, got, val) | |||||
case err := <-cl.ErrorsCh: | |||||
t.Fatal(err) | |||||
} | |||||
} | |||||
func TestWSHandlesArrayParams(t *testing.T) { | |||||
cl := client.NewWSClient(tcpAddr, websocketEndpoint) | |||||
_, err := cl.Start() | |||||
require.Nil(t, err) | |||||
defer cl.Stop() | |||||
val := "acbd" | |||||
params := []interface{}{val} | |||||
err = cl.WriteJSON(types.RPCRequest{ | |||||
JSONRPC: "2.0", | |||||
ID: "", | |||||
Method: "echo_ws", | |||||
Params: params, | |||||
}) | |||||
require.Nil(t, err) | |||||
select { | |||||
case msg := <-cl.ResultsCh: | |||||
result := new(Result) | |||||
wire.ReadJSONPtr(result, msg, &err) | |||||
require.Nil(t, err) | |||||
got := (*result).(*ResultEcho).Value | |||||
assert.Equal(t, got, val) | |||||
case err := <-cl.ErrorsCh: | |||||
t.Fatalf("%+v", err) | |||||
} | |||||
} | |||||
func randBytes(t *testing.T) []byte { | |||||
n := rand.Intn(10) + 2 | |||||
buf := make([]byte, n) | |||||
_, err := crand.Read(buf) | |||||
require.Nil(t, err) | |||||
return bytes.Replace(buf, []byte("="), []byte{100}, -1) | |||||
} |
@ -0,0 +1,649 @@ | |||||
package rpcserver | |||||
import ( | |||||
"bytes" | |||||
"encoding/hex" | |||||
"encoding/json" | |||||
"fmt" | |||||
"io/ioutil" | |||||
"net/http" | |||||
"reflect" | |||||
"sort" | |||||
"strings" | |||||
"time" | |||||
"github.com/gorilla/websocket" | |||||
"github.com/pkg/errors" | |||||
types "github.com/tendermint/tendermint/rpc/types" | |||||
wire "github.com/tendermint/go-wire" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
events "github.com/tendermint/tmlibs/events" | |||||
) | |||||
// Adds a route for each function in the funcMap, as well as general jsonrpc and websocket handlers for all functions. | |||||
// "result" is the interface on which the result objects are registered, and is popualted with every RPCResponse | |||||
func RegisterRPCFuncs(mux *http.ServeMux, funcMap map[string]*RPCFunc) { | |||||
// HTTP endpoints | |||||
for funcName, rpcFunc := range funcMap { | |||||
mux.HandleFunc("/"+funcName, makeHTTPHandler(rpcFunc)) | |||||
} | |||||
// JSONRPC endpoints | |||||
mux.HandleFunc("/", makeJSONRPCHandler(funcMap)) | |||||
} | |||||
//------------------------------------- | |||||
// function introspection | |||||
// holds all type information for each function | |||||
type RPCFunc struct { | |||||
f reflect.Value // underlying rpc function | |||||
args []reflect.Type // type of each function arg | |||||
returns []reflect.Type // type of each return arg | |||||
argNames []string // name of each argument | |||||
ws bool // websocket only | |||||
} | |||||
// wraps a function for quicker introspection | |||||
// f is the function, args are comma separated argument names | |||||
func NewRPCFunc(f interface{}, args string) *RPCFunc { | |||||
return newRPCFunc(f, args, false) | |||||
} | |||||
func NewWSRPCFunc(f interface{}, args string) *RPCFunc { | |||||
return newRPCFunc(f, args, true) | |||||
} | |||||
func newRPCFunc(f interface{}, args string, ws bool) *RPCFunc { | |||||
var argNames []string | |||||
if args != "" { | |||||
argNames = strings.Split(args, ",") | |||||
} | |||||
return &RPCFunc{ | |||||
f: reflect.ValueOf(f), | |||||
args: funcArgTypes(f), | |||||
returns: funcReturnTypes(f), | |||||
argNames: argNames, | |||||
ws: ws, | |||||
} | |||||
} | |||||
// return a function's argument types | |||||
func funcArgTypes(f interface{}) []reflect.Type { | |||||
t := reflect.TypeOf(f) | |||||
n := t.NumIn() | |||||
typez := make([]reflect.Type, n) | |||||
for i := 0; i < n; i++ { | |||||
typez[i] = t.In(i) | |||||
} | |||||
return typez | |||||
} | |||||
// return a function's return types | |||||
func funcReturnTypes(f interface{}) []reflect.Type { | |||||
t := reflect.TypeOf(f) | |||||
n := t.NumOut() | |||||
typez := make([]reflect.Type, n) | |||||
for i := 0; i < n; i++ { | |||||
typez[i] = t.Out(i) | |||||
} | |||||
return typez | |||||
} | |||||
// function introspection | |||||
//----------------------------------------------------------------------------- | |||||
// rpc.json | |||||
// jsonrpc calls grab the given method's function info and runs reflect.Call | |||||
func makeJSONRPCHandler(funcMap map[string]*RPCFunc) http.HandlerFunc { | |||||
return func(w http.ResponseWriter, r *http.Request) { | |||||
b, _ := ioutil.ReadAll(r.Body) | |||||
// if its an empty request (like from a browser), | |||||
// just display a list of functions | |||||
if len(b) == 0 { | |||||
writeListOfEndpoints(w, r, funcMap) | |||||
return | |||||
} | |||||
var request types.RPCRequest | |||||
err := json.Unmarshal(b, &request) | |||||
if err != nil { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, fmt.Sprintf("Error unmarshalling request: %v", err.Error()))) | |||||
return | |||||
} | |||||
if len(r.URL.Path) > 1 { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, fmt.Sprintf("Invalid JSONRPC endpoint %s", r.URL.Path))) | |||||
return | |||||
} | |||||
rpcFunc := funcMap[request.Method] | |||||
if rpcFunc == nil { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, "RPC method unknown: "+request.Method)) | |||||
return | |||||
} | |||||
if rpcFunc.ws { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, "RPC method is only for websockets: "+request.Method)) | |||||
return | |||||
} | |||||
args, err := jsonParamsToArgsRPC(rpcFunc, request.Params) | |||||
if err != nil { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, nil, fmt.Sprintf("Error converting json params to arguments: %v", err.Error()))) | |||||
return | |||||
} | |||||
returns := rpcFunc.f.Call(args) | |||||
log.Info("HTTPJSONRPC", "method", request.Method, "args", args, "returns", returns) | |||||
result, err := unreflectResult(returns) | |||||
if err != nil { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, result, err.Error())) | |||||
return | |||||
} | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse(request.ID, result, "")) | |||||
} | |||||
} | |||||
// Convert a []interface{} OR a map[string]interface{} to properly typed values | |||||
// | |||||
// argsOffset should be 0 for RPC calls, and 1 for WS requests, where len(rpcFunc.args) != len(rpcFunc.argNames). | |||||
// Example: | |||||
// rpcFunc.args = [rpctypes.WSRPCContext string] | |||||
// rpcFunc.argNames = ["arg"] | |||||
func jsonParamsToArgs(rpcFunc *RPCFunc, paramsI interface{}, argsOffset int) ([]reflect.Value, error) { | |||||
values := make([]reflect.Value, len(rpcFunc.argNames)) | |||||
switch params := paramsI.(type) { | |||||
case map[string]interface{}: | |||||
for i, argName := range rpcFunc.argNames { | |||||
argType := rpcFunc.args[i+argsOffset] | |||||
// decode param if provided | |||||
if param, ok := params[argName]; ok && "" != param { | |||||
v, err := _jsonObjectToArg(argType, param) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
values[i] = v | |||||
} else { // use default for that type | |||||
values[i] = reflect.Zero(argType) | |||||
} | |||||
} | |||||
case []interface{}: | |||||
if len(rpcFunc.argNames) != len(params) { | |||||
return nil, errors.New(fmt.Sprintf("Expected %v parameters (%v), got %v (%v)", | |||||
len(rpcFunc.argNames), rpcFunc.argNames, len(params), params)) | |||||
} | |||||
values := make([]reflect.Value, len(params)) | |||||
for i, p := range params { | |||||
ty := rpcFunc.args[i+argsOffset] | |||||
v, err := _jsonObjectToArg(ty, p) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
values[i] = v | |||||
} | |||||
return values, nil | |||||
default: | |||||
return nil, fmt.Errorf("Unknown type for JSON params %v. Expected map[string]interface{} or []interface{}", reflect.TypeOf(paramsI)) | |||||
} | |||||
return values, nil | |||||
} | |||||
// Convert a []interface{} OR a map[string]interface{} to properly typed values | |||||
func jsonParamsToArgsRPC(rpcFunc *RPCFunc, paramsI interface{}) ([]reflect.Value, error) { | |||||
return jsonParamsToArgs(rpcFunc, paramsI, 0) | |||||
} | |||||
// Same as above, but with the first param the websocket connection | |||||
func jsonParamsToArgsWS(rpcFunc *RPCFunc, paramsI interface{}, wsCtx types.WSRPCContext) ([]reflect.Value, error) { | |||||
values, err := jsonParamsToArgs(rpcFunc, paramsI, 1) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return append([]reflect.Value{reflect.ValueOf(wsCtx)}, values...), nil | |||||
} | |||||
func _jsonObjectToArg(ty reflect.Type, object interface{}) (reflect.Value, error) { | |||||
var err error | |||||
v := reflect.New(ty) | |||||
wire.ReadJSONObjectPtr(v.Interface(), object, &err) | |||||
if err != nil { | |||||
return v, err | |||||
} | |||||
v = v.Elem() | |||||
return v, nil | |||||
} | |||||
// rpc.json | |||||
//----------------------------------------------------------------------------- | |||||
// rpc.http | |||||
// convert from a function name to the http handler | |||||
func makeHTTPHandler(rpcFunc *RPCFunc) func(http.ResponseWriter, *http.Request) { | |||||
// Exception for websocket endpoints | |||||
if rpcFunc.ws { | |||||
return func(w http.ResponseWriter, r *http.Request) { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, "This RPC method is only for websockets")) | |||||
} | |||||
} | |||||
// All other endpoints | |||||
return func(w http.ResponseWriter, r *http.Request) { | |||||
log.Debug("HTTP HANDLER", "req", r) | |||||
args, err := httpParamsToArgs(rpcFunc, r) | |||||
if err != nil { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, fmt.Sprintf("Error converting http params to args: %v", err.Error()))) | |||||
return | |||||
} | |||||
returns := rpcFunc.f.Call(args) | |||||
log.Info("HTTPRestRPC", "method", r.URL.Path, "args", args, "returns", returns) | |||||
result, err := unreflectResult(returns) | |||||
if err != nil { | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", nil, err.Error())) | |||||
return | |||||
} | |||||
WriteRPCResponseHTTP(w, types.NewRPCResponse("", result, "")) | |||||
} | |||||
} | |||||
// Covert an http query to a list of properly typed values. | |||||
// To be properly decoded the arg must be a concrete type from tendermint (if its an interface). | |||||
func httpParamsToArgs(rpcFunc *RPCFunc, r *http.Request) ([]reflect.Value, error) { | |||||
values := make([]reflect.Value, len(rpcFunc.args)) | |||||
for i, name := range rpcFunc.argNames { | |||||
argType := rpcFunc.args[i] | |||||
values[i] = reflect.Zero(argType) // set default for that type | |||||
arg := GetParam(r, name) | |||||
// log.Notice("param to arg", "argType", argType, "name", name, "arg", arg) | |||||
if "" == arg { | |||||
continue | |||||
} | |||||
v, err, ok := nonJsonToArg(argType, arg) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
if ok { | |||||
values[i] = v | |||||
continue | |||||
} | |||||
// Pass values to go-wire | |||||
values[i], err = _jsonStringToArg(argType, arg) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
} | |||||
return values, nil | |||||
} | |||||
func _jsonStringToArg(ty reflect.Type, arg string) (reflect.Value, error) { | |||||
var err error | |||||
v := reflect.New(ty) | |||||
wire.ReadJSONPtr(v.Interface(), []byte(arg), &err) | |||||
if err != nil { | |||||
return v, err | |||||
} | |||||
v = v.Elem() | |||||
return v, nil | |||||
} | |||||
func nonJsonToArg(ty reflect.Type, arg string) (reflect.Value, error, bool) { | |||||
isQuotedString := strings.HasPrefix(arg, `"`) && strings.HasSuffix(arg, `"`) | |||||
isHexString := strings.HasPrefix(strings.ToLower(arg), "0x") | |||||
expectingString := ty.Kind() == reflect.String | |||||
expectingByteSlice := ty.Kind() == reflect.Slice && ty.Elem().Kind() == reflect.Uint8 | |||||
if isHexString { | |||||
if !expectingString && !expectingByteSlice { | |||||
err := errors.Errorf("Got a hex string arg, but expected '%s'", | |||||
ty.Kind().String()) | |||||
return reflect.ValueOf(nil), err, false | |||||
} | |||||
var value []byte | |||||
value, err := hex.DecodeString(arg[2:]) | |||||
if err != nil { | |||||
return reflect.ValueOf(nil), err, false | |||||
} | |||||
if ty.Kind() == reflect.String { | |||||
return reflect.ValueOf(string(value)), nil, true | |||||
} | |||||
return reflect.ValueOf([]byte(value)), nil, true | |||||
} | |||||
if isQuotedString && expectingByteSlice { | |||||
var err error | |||||
v := reflect.New(reflect.TypeOf("")) | |||||
wire.ReadJSONPtr(v.Interface(), []byte(arg), &err) | |||||
if err != nil { | |||||
return reflect.ValueOf(nil), err, false | |||||
} | |||||
v = v.Elem() | |||||
return reflect.ValueOf([]byte(v.String())), nil, true | |||||
} | |||||
return reflect.ValueOf(nil), nil, false | |||||
} | |||||
// rpc.http | |||||
//----------------------------------------------------------------------------- | |||||
// rpc.websocket | |||||
const ( | |||||
writeChanCapacity = 1000 | |||||
wsWriteTimeoutSeconds = 30 // each write times out after this | |||||
wsReadTimeoutSeconds = 30 // connection times out if we haven't received *anything* in this long, not even pings. | |||||
wsPingTickerSeconds = 10 // send a ping every PingTickerSeconds. | |||||
) | |||||
// a single websocket connection | |||||
// contains listener id, underlying ws connection, | |||||
// and the event switch for subscribing to events | |||||
type wsConnection struct { | |||||
cmn.BaseService | |||||
remoteAddr string | |||||
baseConn *websocket.Conn | |||||
writeChan chan types.RPCResponse | |||||
readTimeout *time.Timer | |||||
pingTicker *time.Ticker | |||||
funcMap map[string]*RPCFunc | |||||
evsw events.EventSwitch | |||||
} | |||||
// new websocket connection wrapper | |||||
func NewWSConnection(baseConn *websocket.Conn, funcMap map[string]*RPCFunc, evsw events.EventSwitch) *wsConnection { | |||||
wsc := &wsConnection{ | |||||
remoteAddr: baseConn.RemoteAddr().String(), | |||||
baseConn: baseConn, | |||||
writeChan: make(chan types.RPCResponse, writeChanCapacity), // error when full. | |||||
funcMap: funcMap, | |||||
evsw: evsw, | |||||
} | |||||
wsc.BaseService = *cmn.NewBaseService(log, "wsConnection", wsc) | |||||
return wsc | |||||
} | |||||
// wsc.Start() blocks until the connection closes. | |||||
func (wsc *wsConnection) OnStart() error { | |||||
wsc.BaseService.OnStart() | |||||
// these must be set before the readRoutine is created, as it may | |||||
// call wsc.Stop(), which accesses these timers | |||||
wsc.readTimeout = time.NewTimer(time.Second * wsReadTimeoutSeconds) | |||||
wsc.pingTicker = time.NewTicker(time.Second * wsPingTickerSeconds) | |||||
// Read subscriptions/unsubscriptions to events | |||||
go wsc.readRoutine() | |||||
// Custom Ping handler to touch readTimeout | |||||
wsc.baseConn.SetPingHandler(func(m string) error { | |||||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||||
go wsc.baseConn.WriteControl(websocket.PongMessage, []byte(m), time.Now().Add(time.Second*wsWriteTimeoutSeconds)) | |||||
wsc.readTimeout.Reset(time.Second * wsReadTimeoutSeconds) | |||||
return nil | |||||
}) | |||||
wsc.baseConn.SetPongHandler(func(m string) error { | |||||
// NOTE: https://github.com/gorilla/websocket/issues/97 | |||||
wsc.readTimeout.Reset(time.Second * wsReadTimeoutSeconds) | |||||
return nil | |||||
}) | |||||
go wsc.readTimeoutRoutine() | |||||
// Write responses, BLOCKING. | |||||
wsc.writeRoutine() | |||||
return nil | |||||
} | |||||
func (wsc *wsConnection) OnStop() { | |||||
wsc.BaseService.OnStop() | |||||
if wsc.evsw != nil { | |||||
wsc.evsw.RemoveListener(wsc.remoteAddr) | |||||
} | |||||
wsc.readTimeout.Stop() | |||||
wsc.pingTicker.Stop() | |||||
// The write loop closes the websocket connection | |||||
// when it exits its loop, and the read loop | |||||
// closes the writeChan | |||||
} | |||||
func (wsc *wsConnection) readTimeoutRoutine() { | |||||
select { | |||||
case <-wsc.readTimeout.C: | |||||
log.Notice("Stopping connection due to read timeout") | |||||
wsc.Stop() | |||||
case <-wsc.Quit: | |||||
return | |||||
} | |||||
} | |||||
// Implements WSRPCConnection | |||||
func (wsc *wsConnection) GetRemoteAddr() string { | |||||
return wsc.remoteAddr | |||||
} | |||||
// Implements WSRPCConnection | |||||
func (wsc *wsConnection) GetEventSwitch() events.EventSwitch { | |||||
return wsc.evsw | |||||
} | |||||
// Implements WSRPCConnection | |||||
// Blocking write to writeChan until service stops. | |||||
// Goroutine-safe | |||||
func (wsc *wsConnection) WriteRPCResponse(resp types.RPCResponse) { | |||||
select { | |||||
case <-wsc.Quit: | |||||
return | |||||
case wsc.writeChan <- resp: | |||||
} | |||||
} | |||||
// Implements WSRPCConnection | |||||
// Nonblocking write. | |||||
// Goroutine-safe | |||||
func (wsc *wsConnection) TryWriteRPCResponse(resp types.RPCResponse) bool { | |||||
select { | |||||
case <-wsc.Quit: | |||||
return false | |||||
case wsc.writeChan <- resp: | |||||
return true | |||||
default: | |||||
return false | |||||
} | |||||
} | |||||
// Read from the socket and subscribe to or unsubscribe from events | |||||
func (wsc *wsConnection) readRoutine() { | |||||
// Do not close writeChan, to allow WriteRPCResponse() to fail. | |||||
// defer close(wsc.writeChan) | |||||
for { | |||||
select { | |||||
case <-wsc.Quit: | |||||
return | |||||
default: | |||||
var in []byte | |||||
// Do not set a deadline here like below: | |||||
// wsc.baseConn.SetReadDeadline(time.Now().Add(time.Second * wsReadTimeoutSeconds)) | |||||
// The client may not send anything for a while. | |||||
// We use `readTimeout` to handle read timeouts. | |||||
_, in, err := wsc.baseConn.ReadMessage() | |||||
if err != nil { | |||||
log.Notice("Failed to read from connection", "remote", wsc.remoteAddr, "err", err.Error()) | |||||
// an error reading the connection, | |||||
// kill the connection | |||||
wsc.Stop() | |||||
return | |||||
} | |||||
var request types.RPCRequest | |||||
err = json.Unmarshal(in, &request) | |||||
if err != nil { | |||||
errStr := fmt.Sprintf("Error unmarshaling data: %s", err.Error()) | |||||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, errStr)) | |||||
continue | |||||
} | |||||
// Now, fetch the RPCFunc and execute it. | |||||
rpcFunc := wsc.funcMap[request.Method] | |||||
if rpcFunc == nil { | |||||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, "RPC method unknown: "+request.Method)) | |||||
continue | |||||
} | |||||
var args []reflect.Value | |||||
if rpcFunc.ws { | |||||
wsCtx := types.WSRPCContext{Request: request, WSRPCConnection: wsc} | |||||
args, err = jsonParamsToArgsWS(rpcFunc, request.Params, wsCtx) | |||||
} else { | |||||
args, err = jsonParamsToArgsRPC(rpcFunc, request.Params) | |||||
} | |||||
if err != nil { | |||||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, err.Error())) | |||||
continue | |||||
} | |||||
returns := rpcFunc.f.Call(args) | |||||
log.Info("WSJSONRPC", "method", request.Method, "args", args, "returns", returns) | |||||
result, err := unreflectResult(returns) | |||||
if err != nil { | |||||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, nil, err.Error())) | |||||
continue | |||||
} else { | |||||
wsc.WriteRPCResponse(types.NewRPCResponse(request.ID, result, "")) | |||||
continue | |||||
} | |||||
} | |||||
} | |||||
} | |||||
// receives on a write channel and writes out on the socket | |||||
func (wsc *wsConnection) writeRoutine() { | |||||
defer wsc.baseConn.Close() | |||||
for { | |||||
select { | |||||
case <-wsc.Quit: | |||||
return | |||||
case <-wsc.pingTicker.C: | |||||
err := wsc.baseConn.WriteMessage(websocket.PingMessage, []byte{}) | |||||
if err != nil { | |||||
log.Error("Failed to write ping message on websocket", "error", err) | |||||
wsc.Stop() | |||||
return | |||||
} | |||||
case msg := <-wsc.writeChan: | |||||
jsonBytes, err := json.Marshal(msg) | |||||
if err != nil { | |||||
log.Error("Failed to marshal RPCResponse to JSON", "error", err) | |||||
} else { | |||||
wsc.baseConn.SetWriteDeadline(time.Now().Add(time.Second * wsWriteTimeoutSeconds)) | |||||
if err = wsc.baseConn.WriteMessage(websocket.TextMessage, jsonBytes); err != nil { | |||||
log.Warn("Failed to write response on websocket", "error", err) | |||||
wsc.Stop() | |||||
return | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
//---------------------------------------- | |||||
// Main manager for all websocket connections | |||||
// Holds the event switch | |||||
// NOTE: The websocket path is defined externally, e.g. in node/node.go | |||||
type WebsocketManager struct { | |||||
websocket.Upgrader | |||||
funcMap map[string]*RPCFunc | |||||
evsw events.EventSwitch | |||||
} | |||||
func NewWebsocketManager(funcMap map[string]*RPCFunc, evsw events.EventSwitch) *WebsocketManager { | |||||
return &WebsocketManager{ | |||||
funcMap: funcMap, | |||||
evsw: evsw, | |||||
Upgrader: websocket.Upgrader{ | |||||
ReadBufferSize: 1024, | |||||
WriteBufferSize: 1024, | |||||
CheckOrigin: func(r *http.Request) bool { | |||||
// TODO | |||||
return true | |||||
}, | |||||
}, | |||||
} | |||||
} | |||||
// Upgrade the request/response (via http.Hijack) and starts the wsConnection. | |||||
func (wm *WebsocketManager) WebsocketHandler(w http.ResponseWriter, r *http.Request) { | |||||
wsConn, err := wm.Upgrade(w, r, nil) | |||||
if err != nil { | |||||
// TODO - return http error | |||||
log.Error("Failed to upgrade to websocket connection", "error", err) | |||||
return | |||||
} | |||||
// register connection | |||||
con := NewWSConnection(wsConn, wm.funcMap, wm.evsw) | |||||
log.Notice("New websocket connection", "remote", con.remoteAddr) | |||||
con.Start() // Blocking | |||||
} | |||||
// rpc.websocket | |||||
//----------------------------------------------------------------------------- | |||||
// NOTE: assume returns is result struct and error. If error is not nil, return it | |||||
func unreflectResult(returns []reflect.Value) (interface{}, error) { | |||||
errV := returns[1] | |||||
if errV.Interface() != nil { | |||||
return nil, errors.Errorf("%v", errV.Interface()) | |||||
} | |||||
rv := returns[0] | |||||
// the result is a registered interface, | |||||
// we need a pointer to it so we can marshal with type byte | |||||
rvp := reflect.New(rv.Type()) | |||||
rvp.Elem().Set(rv) | |||||
return rvp.Interface(), nil | |||||
} | |||||
// writes a list of available rpc endpoints as an html page | |||||
func writeListOfEndpoints(w http.ResponseWriter, r *http.Request, funcMap map[string]*RPCFunc) { | |||||
noArgNames := []string{} | |||||
argNames := []string{} | |||||
for name, funcData := range funcMap { | |||||
if len(funcData.args) == 0 { | |||||
noArgNames = append(noArgNames, name) | |||||
} else { | |||||
argNames = append(argNames, name) | |||||
} | |||||
} | |||||
sort.Strings(noArgNames) | |||||
sort.Strings(argNames) | |||||
buf := new(bytes.Buffer) | |||||
buf.WriteString("<html><body>") | |||||
buf.WriteString("<br>Available endpoints:<br>") | |||||
for _, name := range noArgNames { | |||||
link := fmt.Sprintf("http://%s/%s", r.Host, name) | |||||
buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link)) | |||||
} | |||||
buf.WriteString("<br>Endpoints that require arguments:<br>") | |||||
for _, name := range argNames { | |||||
link := fmt.Sprintf("http://%s/%s?", r.Host, name) | |||||
funcData := funcMap[name] | |||||
for i, argName := range funcData.argNames { | |||||
link += argName + "=_" | |||||
if i < len(funcData.argNames)-1 { | |||||
link += "&" | |||||
} | |||||
} | |||||
buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link)) | |||||
} | |||||
buf.WriteString("</body></html>") | |||||
w.Header().Set("Content-Type", "text/html") | |||||
w.WriteHeader(200) | |||||
w.Write(buf.Bytes()) | |||||
} |
@ -0,0 +1,90 @@ | |||||
package rpcserver | |||||
import ( | |||||
"encoding/hex" | |||||
"net/http" | |||||
"regexp" | |||||
"strconv" | |||||
"github.com/pkg/errors" | |||||
) | |||||
var ( | |||||
// Parts of regular expressions | |||||
atom = "[A-Z0-9!#$%&'*+\\-/=?^_`{|}~]+" | |||||
dotAtom = atom + `(?:\.` + atom + `)*` | |||||
domain = `[A-Z0-9.-]+\.[A-Z]{2,4}` | |||||
RE_HEX = regexp.MustCompile(`^(?i)[a-f0-9]+$`) | |||||
RE_EMAIL = regexp.MustCompile(`^(?i)(` + dotAtom + `)@(` + dotAtom + `)$`) | |||||
RE_ADDRESS = regexp.MustCompile(`^(?i)[a-z0-9]{25,34}$`) | |||||
RE_HOST = regexp.MustCompile(`^(?i)(` + domain + `)$`) | |||||
//RE_ID12 = regexp.MustCompile(`^[a-zA-Z0-9]{12}$`) | |||||
) | |||||
func GetParam(r *http.Request, param string) string { | |||||
s := r.URL.Query().Get(param) | |||||
if s == "" { | |||||
s = r.FormValue(param) | |||||
} | |||||
return s | |||||
} | |||||
func GetParamByteSlice(r *http.Request, param string) ([]byte, error) { | |||||
s := GetParam(r, param) | |||||
return hex.DecodeString(s) | |||||
} | |||||
func GetParamInt64(r *http.Request, param string) (int64, error) { | |||||
s := GetParam(r, param) | |||||
i, err := strconv.ParseInt(s, 10, 64) | |||||
if err != nil { | |||||
return 0, errors.Errorf(param, err.Error()) | |||||
} | |||||
return i, nil | |||||
} | |||||
func GetParamInt32(r *http.Request, param string) (int32, error) { | |||||
s := GetParam(r, param) | |||||
i, err := strconv.ParseInt(s, 10, 32) | |||||
if err != nil { | |||||
return 0, errors.Errorf(param, err.Error()) | |||||
} | |||||
return int32(i), nil | |||||
} | |||||
func GetParamUint64(r *http.Request, param string) (uint64, error) { | |||||
s := GetParam(r, param) | |||||
i, err := strconv.ParseUint(s, 10, 64) | |||||
if err != nil { | |||||
return 0, errors.Errorf(param, err.Error()) | |||||
} | |||||
return i, nil | |||||
} | |||||
func GetParamUint(r *http.Request, param string) (uint, error) { | |||||
s := GetParam(r, param) | |||||
i, err := strconv.ParseUint(s, 10, 64) | |||||
if err != nil { | |||||
return 0, errors.Errorf(param, err.Error()) | |||||
} | |||||
return uint(i), nil | |||||
} | |||||
func GetParamRegexp(r *http.Request, param string, re *regexp.Regexp) (string, error) { | |||||
s := GetParam(r, param) | |||||
if !re.MatchString(s) { | |||||
return "", errors.Errorf(param, "Did not match regular expression %v", re.String()) | |||||
} | |||||
return s, nil | |||||
} | |||||
func GetParamFloat64(r *http.Request, param string) (float64, error) { | |||||
s := GetParam(r, param) | |||||
f, err := strconv.ParseFloat(s, 64) | |||||
if err != nil { | |||||
return 0, errors.Errorf(param, err.Error()) | |||||
} | |||||
return f, nil | |||||
} |
@ -0,0 +1,125 @@ | |||||
// Commons for HTTP handling | |||||
package rpcserver | |||||
import ( | |||||
"bufio" | |||||
"encoding/json" | |||||
"fmt" | |||||
"net" | |||||
"net/http" | |||||
"runtime/debug" | |||||
"strings" | |||||
"time" | |||||
"github.com/pkg/errors" | |||||
types "github.com/tendermint/tendermint/rpc/types" | |||||
) | |||||
func StartHTTPServer(listenAddr string, handler http.Handler) (listener net.Listener, err error) { | |||||
// listenAddr should be fully formed including tcp:// or unix:// prefix | |||||
var proto, addr string | |||||
parts := strings.SplitN(listenAddr, "://", 2) | |||||
if len(parts) != 2 { | |||||
log.Warn("WARNING (tendermint/rpc): Please use fully formed listening addresses, including the tcp:// or unix:// prefix") | |||||
// we used to allow addrs without tcp/unix prefix by checking for a colon | |||||
// TODO: Deprecate | |||||
proto = types.SocketType(listenAddr) | |||||
addr = listenAddr | |||||
// return nil, errors.Errorf("Invalid listener address %s", lisenAddr) | |||||
} else { | |||||
proto, addr = parts[0], parts[1] | |||||
} | |||||
log.Notice(fmt.Sprintf("Starting RPC HTTP server on %s socket %v", proto, addr)) | |||||
listener, err = net.Listen(proto, addr) | |||||
if err != nil { | |||||
return nil, errors.Errorf("Failed to listen to %v: %v", listenAddr, err) | |||||
} | |||||
go func() { | |||||
res := http.Serve( | |||||
listener, | |||||
RecoverAndLogHandler(handler), | |||||
) | |||||
log.Crit("RPC HTTP server stopped", "result", res) | |||||
}() | |||||
return listener, nil | |||||
} | |||||
func WriteRPCResponseHTTP(w http.ResponseWriter, res types.RPCResponse) { | |||||
// jsonBytes := wire.JSONBytesPretty(res) | |||||
jsonBytes, err := json.Marshal(res) | |||||
if err != nil { | |||||
panic(err) | |||||
} | |||||
w.Header().Set("Content-Type", "application/json") | |||||
w.WriteHeader(200) | |||||
w.Write(jsonBytes) | |||||
} | |||||
//----------------------------------------------------------------------------- | |||||
// Wraps an HTTP handler, adding error logging. | |||||
// If the inner function panics, the outer function recovers, logs, sends an | |||||
// HTTP 500 error response. | |||||
func RecoverAndLogHandler(handler http.Handler) http.Handler { | |||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |||||
// Wrap the ResponseWriter to remember the status | |||||
rww := &ResponseWriterWrapper{-1, w} | |||||
begin := time.Now() | |||||
// Common headers | |||||
origin := r.Header.Get("Origin") | |||||
rww.Header().Set("Access-Control-Allow-Origin", origin) | |||||
rww.Header().Set("Access-Control-Allow-Credentials", "true") | |||||
rww.Header().Set("Access-Control-Expose-Headers", "X-Server-Time") | |||||
rww.Header().Set("X-Server-Time", fmt.Sprintf("%v", begin.Unix())) | |||||
defer func() { | |||||
// Send a 500 error if a panic happens during a handler. | |||||
// Without this, Chrome & Firefox were retrying aborted ajax requests, | |||||
// at least to my localhost. | |||||
if e := recover(); e != nil { | |||||
// If RPCResponse | |||||
if res, ok := e.(types.RPCResponse); ok { | |||||
WriteRPCResponseHTTP(rww, res) | |||||
} else { | |||||
// For the rest, | |||||
log.Error("Panic in RPC HTTP handler", "error", e, "stack", string(debug.Stack())) | |||||
rww.WriteHeader(http.StatusInternalServerError) | |||||
WriteRPCResponseHTTP(rww, types.NewRPCResponse("", nil, fmt.Sprintf("Internal Server Error: %v", e))) | |||||
} | |||||
} | |||||
// Finally, log. | |||||
durationMS := time.Since(begin).Nanoseconds() / 1000000 | |||||
if rww.Status == -1 { | |||||
rww.Status = 200 | |||||
} | |||||
log.Info("Served RPC HTTP response", | |||||
"method", r.Method, "url", r.URL, | |||||
"status", rww.Status, "duration", durationMS, | |||||
"remoteAddr", r.RemoteAddr, | |||||
) | |||||
}() | |||||
handler.ServeHTTP(rww, r) | |||||
}) | |||||
} | |||||
// Remember the status for logging | |||||
type ResponseWriterWrapper struct { | |||||
Status int | |||||
http.ResponseWriter | |||||
} | |||||
func (w *ResponseWriterWrapper) WriteHeader(status int) { | |||||
w.Status = status | |||||
w.ResponseWriter.WriteHeader(status) | |||||
} | |||||
// implements http.Hijacker | |||||
func (w *ResponseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { | |||||
return w.ResponseWriter.(http.Hijacker).Hijack() | |||||
} |
@ -0,0 +1,7 @@ | |||||
package rpcserver | |||||
import ( | |||||
"github.com/tendermint/log15" | |||||
) | |||||
var log = log15.New("module", "rpcserver") |