@ -0,0 +1,83 @@ | |||
import sys | |||
sys.path.insert(0, './tmsp') | |||
from wire import * | |||
from server import * | |||
# tmsp application interface | |||
class CounterApplication(): | |||
def __init__(self): | |||
self.hashCount = 0 | |||
self.txCount = 0 | |||
self.commitCount = 0 | |||
def open(self): | |||
return CounterAppContext(self) | |||
class CounterAppContext(): | |||
def __init__(self, app): | |||
self.app = app | |||
self.hashCount = app.hashCount | |||
self.txCount = app.txCount | |||
self.commitCount = app.commitCount | |||
self.serial = False | |||
def echo(self, msg): | |||
return msg, 0 | |||
def info(self): | |||
return ["hash, tx, commit counts:%d, %d, %d"%(self.hashCount, self.txCount, self.commitCount)], 0 | |||
def set_option(self, key, value): | |||
if key == "serial" and value == "on": | |||
self.serial = True | |||
return 0 | |||
def append_tx(self, txBytes): | |||
if self.serial: | |||
txValue = decode_big_endian(BytesReader(txBytes), len(txBytes)) | |||
if txValue != self.txCount: | |||
return [], 1 | |||
self.txCount += 1 | |||
return None, 0 | |||
def get_hash(self): | |||
self.hashCount += 1 | |||
if self.txCount == 0: | |||
return "", 0 | |||
return str(encode_big_endian(self.txCount, 8)), 0 | |||
def commit(self): | |||
return 0 | |||
def rollback(self): | |||
return 0 | |||
def add_listener(self): | |||
return 0 | |||
def rm_listener(self): | |||
return 0 | |||
def event(self): | |||
return | |||
if __name__ == '__main__': | |||
l = len(sys.argv) | |||
if l == 1: | |||
port = 46658 | |||
elif l == 2: | |||
port = int(sys.argv[1]) | |||
else: | |||
print "too many arguments" | |||
quit() | |||
print 'TMSP Demo APP (Python)' | |||
app = CounterApplication() | |||
server = TMSPServer(app, port) | |||
server.main_loop() |
@ -0,0 +1,54 @@ | |||
from wire import * | |||
# map type_byte to message name | |||
message_types = { | |||
0x01 : "echo", | |||
0x02 : "flush", | |||
0x03 : "info", | |||
0x04 : "set_option", | |||
0x21 : "append_tx", | |||
0x22 : "get_hash", | |||
0x23 : "commit", | |||
0x24 : "rollback", | |||
0x25 : "add_listener", | |||
0x26 : "rm_listener", | |||
} | |||
# return the decoded arguments of tmsp messages | |||
class RequestDecoder(): | |||
def __init__(self, reader): | |||
self.reader = reader | |||
def echo(self): | |||
return decode_string(self.reader) | |||
def flush(self): | |||
return | |||
def info(self): | |||
return | |||
def set_option(self): | |||
return decode_string(self.reader), decode_string(self.reader) | |||
def append_tx(self): | |||
return decode_string(self.reader) | |||
def get_hash(self): | |||
return | |||
def commit(self): | |||
return | |||
def rollback(self): | |||
return | |||
def add_listener(self): | |||
# TODO | |||
return | |||
def rm_listener(self): | |||
# TODO | |||
return | |||
@ -0,0 +1,31 @@ | |||
# Simple read() method around a bytearray | |||
class BytesReader(): | |||
def __init__(self, b): | |||
self.buf = b | |||
def read(self, n): | |||
if len(self.buf) < n: | |||
# TODO: exception | |||
return | |||
r = self.buf[:n] | |||
self.buf = self.buf[n:] | |||
return r | |||
# Buffer bytes off a tcp connection and read them off in chunks | |||
class ConnReader(): | |||
def __init__(self, conn): | |||
self.conn = conn | |||
self.buf = bytearray() | |||
# blocking | |||
def read(self, n): | |||
while n > len(self.buf): | |||
moreBuf = self.conn.recv(1024) | |||
if not moreBuf: | |||
raise IOError("dead connection") | |||
self.buf = self.buf + bytearray(moreBuf) | |||
r = self.buf[:n] | |||
self.buf = self.buf[n:] | |||
return r |
@ -0,0 +1,134 @@ | |||
import socket | |||
import select | |||
import sys | |||
import os | |||
from wire import * | |||
from reader import * | |||
from msg import * | |||
# TMSP server responds to messges by calling methods on the app | |||
class TMSPServer(): | |||
def __init__(self, app, port=5410): | |||
self.app = app | |||
self.appMap = {} # map conn file descriptors to (appContext, msgDecoder) | |||
self.port = port | |||
self.listen_backlog = 10 | |||
self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||
self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | |||
self.listener.setblocking(0) | |||
self.listener.bind(('', port)) | |||
self.listener.listen(self.listen_backlog) | |||
self.shutdown = False | |||
self.read_list = [self.listener] | |||
self.write_list = [] | |||
def handle_new_connection(self, r): | |||
new_fd, new_addr = r.accept() | |||
self.read_list.append(new_fd) | |||
self.write_list.append(new_fd) | |||
print 'new connection to', new_addr | |||
appContext = self.app.open() | |||
self.appMap[new_fd] = (appContext, RequestDecoder(ConnReader(new_fd))) | |||
def handle_conn_closed(self, r): | |||
self.read_list.remove(r) | |||
self.write_list.remove(r) | |||
r.close() | |||
print "connection closed" | |||
def handle_recv(self, r): | |||
appCtx, conn = self.appMap[r] | |||
response = bytearray() | |||
while True: | |||
try: | |||
# first read the request type and get the msg decoder | |||
typeByte = conn.reader.read(1) | |||
typeByte = int(typeByte[0]) | |||
resTypeByte = typeByte+0x10 | |||
req_type = message_types[typeByte] | |||
if req_type == "flush": | |||
response += bytearray([resTypeByte]) | |||
sent = r.send(str(response)) | |||
return | |||
decoder = getattr(conn, req_type) | |||
req_args = decoder() | |||
req_f = getattr(appCtx, req_type) | |||
if req_args == None: | |||
res = req_f() | |||
elif isinstance(req_args, tuple): | |||
res = req_f(*req_args) | |||
else: | |||
res = req_f(req_args) | |||
if isinstance(res, tuple): | |||
res, ret_code = res | |||
else: | |||
ret_code = res | |||
res = None | |||
if ret_code != 0: | |||
print "non-zero retcode:", ret_code | |||
return | |||
if req_type in ("echo", "info"): # these dont return a ret code | |||
response += bytearray([resTypeByte]) + encode(res) | |||
else: | |||
response += bytearray([resTypeByte]) + encode(ret_code) + encode(res) | |||
except TypeError as e: | |||
print "TypeError on reading from connection:", e | |||
self.handle_conn_closed(r) | |||
return | |||
except ValueError as e: | |||
print "ValueError on reading from connection:", e | |||
self.handle_conn_closed(r) | |||
return | |||
except IOError as e: | |||
print "IOError on reading from connection:", e | |||
self.handle_conn_closed(r) | |||
return | |||
except: | |||
print "error reading from connection", sys.exc_info()[0] # TODO better | |||
self.handle_conn_closed(r) | |||
return | |||
def main_loop(self): | |||
while not self.shutdown: | |||
r_list, w_list, _ = select.select(self.read_list, self.write_list, [], 2.5) | |||
for r in r_list: | |||
if (r == self.listener): | |||
try: | |||
self.handle_new_connection(r) | |||
# undo adding to read list ... | |||
except NameError as e: | |||
print "Could not connect due to NameError:", e | |||
except TypeError as e: | |||
print "Could not connect due to TypeError:", e | |||
except: | |||
print "Could not connect due to unexpected error:", sys.exc_info()[0] | |||
else: | |||
self.handle_recv(r) | |||
def handle_shutdown(self): | |||
for r in self.read_list: | |||
r.close() | |||
for w in self.write_list: | |||
try: | |||
w.close() | |||
except: pass | |||
self.shutdown = True | |||
@ -0,0 +1,99 @@ | |||
# the decoder works off a reader | |||
# the encoder returns bytearray | |||
def bytes2hex(b): | |||
if type(b) in (str, unicode): | |||
return "".join([hex(ord(c))[2:].zfill(2) for c in b]) | |||
else: | |||
return bytes2hex(b.decode()) | |||
# expects uvarint64 (no crazy big nums!) | |||
def uvarint_size(i): | |||
if i == 0: | |||
return 0 | |||
for j in xrange(1, 8): | |||
if i < 1<<j*8: | |||
return j | |||
return 8 | |||
# expects i < 2**size | |||
def encode_big_endian(i, size): | |||
if size == 0: | |||
return bytearray() | |||
return encode_big_endian(i/256, size-1) + bytearray([i%256]) | |||
def decode_big_endian(reader, size): | |||
if size == 0: | |||
return 0 | |||
firstByte = reader.read(1)[0] | |||
return firstByte*(256**(size-1)) + decode_big_endian(reader, size-1) | |||
# ints are max 16 bytes long | |||
def encode_varint(i): | |||
negate = False | |||
if i < 0: | |||
negate = True | |||
i = -i | |||
size = uvarint_size(i) | |||
if size == 0: | |||
return bytearray([0]) | |||
big_end = encode_big_endian(i, size) | |||
if negate: | |||
size += 0xF0 | |||
return bytearray([size]) + big_end | |||
# returns the int and whats left of the byte array | |||
def decode_varint(reader): | |||
size = reader.read(1)[0] | |||
if size == 0: | |||
return 0 | |||
negate = True if size > int(0xF0) else False | |||
if negate: size = size -0xF0 | |||
i = decode_big_endian(reader, size) | |||
if negate: i = i*(-1) | |||
return i | |||
def encode_string(s): | |||
size = encode_varint(len(s)) | |||
return size + bytearray(s) | |||
def decode_string(reader): | |||
length = decode_varint(reader) | |||
return str(reader.read(length)) | |||
def encode_list(s): | |||
b = bytearray() | |||
map(b.extend, map(encode, s)) | |||
return encode_varint(len(s)) + b | |||
def encode(s): | |||
if s == None: | |||
return bytearray() | |||
if isinstance(s, int): | |||
return encode_varint(s) | |||
elif isinstance(s, str): | |||
return encode_string(s) | |||
elif isinstance(s, list): | |||
return encode_list(s) | |||
else: | |||
print "UNSUPPORTED TYPE!", type(s), s | |||
import binascii | |||
if __name__ == '__main__': | |||
ns = [100,100,1000,256] | |||
ss = [2,5,5,2] | |||
bs = map(encode_big_endian, ns,ss) | |||
ds = map(decode_big_endian, bs,ss) | |||
print ns | |||
print [i[0] for i in ds] | |||
ss = ["abc", "hi there jim", "ok now what"] | |||
e = map(encode_string, ss) | |||
d = map(decode_string, e) | |||
print ss | |||
print [i[0] for i in d] |