@ -0,0 +1,92 @@ | |||||
import sys | |||||
from tmsp.wire import hex2bytes, decode_big_endian, encode_big_endian | |||||
from tmsp.server import TMSPServer | |||||
from tmsp.reader import BytesBuffer | |||||
class CounterApplication(): | |||||
def __init__(self): | |||||
self.hashCount = 0 | |||||
self.txCount = 0 | |||||
self.commitCount = 0 | |||||
def open(self): | |||||
return CounterAppContext(self) | |||||
class CounterAppContext(): | |||||
def __init__(self, app): | |||||
self.app = app | |||||
self.hashCount = app.hashCount | |||||
self.txCount = app.txCount | |||||
self.commitCount = app.commitCount | |||||
self.serial = False | |||||
def echo(self, msg): | |||||
return msg, 0 | |||||
def info(self): | |||||
return ["hash, tx, commit counts:%d, %d, %d" % (self.hashCount, | |||||
self.txCount, | |||||
self.commitCount)], 0 | |||||
def set_option(self, key, value): | |||||
if key == "serial" and value == "on": | |||||
self.serial = True | |||||
return 0 | |||||
def append_tx(self, txBytes): | |||||
if self.serial: | |||||
txByteArray = bytearray(txBytes) | |||||
if len(txBytes) >= 2 and txBytes[:2] == "0x": | |||||
txByteArray = hex2bytes(txBytes[2:]) | |||||
txValue = decode_big_endian( | |||||
BytesBuffer(txByteArray), len(txBytes)) | |||||
if txValue != self.txCount: | |||||
return None, 1 | |||||
self.txCount += 1 | |||||
return None, 0 | |||||
def get_hash(self): | |||||
self.hashCount += 1 | |||||
if self.txCount == 0: | |||||
return "", 0 | |||||
h = encode_big_endian(self.txCount, 8) | |||||
h.reverse() | |||||
return str(h), 0 | |||||
def commit(self): | |||||
self.commitCount += 1 | |||||
return 0 | |||||
def rollback(self): | |||||
return 0 | |||||
def add_listener(self): | |||||
return 0 | |||||
def rm_listener(self): | |||||
return 0 | |||||
def event(self): | |||||
return | |||||
if __name__ == '__main__': | |||||
l = len(sys.argv) | |||||
if l == 1: | |||||
port = 46658 | |||||
elif l == 2: | |||||
port = int(sys.argv[1]) | |||||
else: | |||||
print("too many arguments") | |||||
quit() | |||||
print('TMSP Demo APP (Python)') | |||||
app = CounterApplication() | |||||
server = TMSPServer(app, port) | |||||
server.main_loop() |
@ -0,0 +1,55 @@ | |||||
from .wire import decode_string | |||||
# map type_byte to message name | |||||
message_types = { | |||||
0x01: "echo", | |||||
0x02: "flush", | |||||
0x03: "info", | |||||
0x04: "set_option", | |||||
0x21: "append_tx", | |||||
0x22: "get_hash", | |||||
0x23: "commit", | |||||
0x24: "rollback", | |||||
0x25: "add_listener", | |||||
0x26: "rm_listener", | |||||
} | |||||
# return the decoded arguments of tmsp messages | |||||
class RequestDecoder(): | |||||
def __init__(self, reader): | |||||
self.reader = reader | |||||
def echo(self): | |||||
return decode_string(self.reader) | |||||
def flush(self): | |||||
return | |||||
def info(self): | |||||
return | |||||
def set_option(self): | |||||
return decode_string(self.reader), decode_string(self.reader) | |||||
def append_tx(self): | |||||
return decode_string(self.reader) | |||||
def get_hash(self): | |||||
return | |||||
def commit(self): | |||||
return | |||||
def rollback(self): | |||||
return | |||||
def add_listener(self): | |||||
# TODO | |||||
return | |||||
def rm_listener(self): | |||||
# TODO | |||||
return |
@ -0,0 +1,56 @@ | |||||
# Simple read() method around a bytearray | |||||
class BytesBuffer(): | |||||
def __init__(self, b): | |||||
self.buf = b | |||||
self.readCount = 0 | |||||
def count(self): | |||||
return self.readCount | |||||
def reset_count(self): | |||||
self.readCount = 0 | |||||
def size(self): | |||||
return len(self.buf) | |||||
def peek(self): | |||||
return self.buf[0] | |||||
def write(self, b): | |||||
# b should be castable to byte array | |||||
self.buf += bytearray(b) | |||||
def read(self, n): | |||||
if len(self.buf) < n: | |||||
print("reader err: buf less than n") | |||||
# TODO: exception | |||||
return | |||||
self.readCount += n | |||||
r = self.buf[:n] | |||||
self.buf = self.buf[n:] | |||||
return r | |||||
# Buffer bytes off a tcp connection and read them off in chunks | |||||
class ConnReader(): | |||||
def __init__(self, conn): | |||||
self.conn = conn | |||||
self.buf = bytearray() | |||||
# blocking | |||||
def read(self, n): | |||||
while n > len(self.buf): | |||||
moreBuf = self.conn.recv(1024) | |||||
if not moreBuf: | |||||
raise IOError("dead connection") | |||||
self.buf = self.buf + bytearray(moreBuf) | |||||
r = self.buf[:n] | |||||
self.buf = self.buf[n:] | |||||
return r |
@ -0,0 +1,201 @@ | |||||
import socket | |||||
import select | |||||
import sys | |||||
import logging | |||||
from .wire import decode_varint, encode | |||||
from .reader import BytesBuffer | |||||
from .msg import RequestDecoder, message_types | |||||
# hold the asyncronous state of a connection | |||||
# ie. we may not get enough bytes on one read to decode the message | |||||
logger = logging.getLogger(__name__) | |||||
class Connection(): | |||||
def __init__(self, fd, appCtx): | |||||
self.fd = fd | |||||
self.appCtx = appCtx | |||||
self.recBuf = BytesBuffer(bytearray()) | |||||
self.resBuf = BytesBuffer(bytearray()) | |||||
self.msgLength = 0 | |||||
self.decoder = RequestDecoder(self.recBuf) | |||||
self.inProgress = False # are we in the middle of a message | |||||
def recv(this): | |||||
data = this.fd.recv(1024) | |||||
if not data: # what about len(data) == 0 | |||||
raise IOError("dead connection") | |||||
this.recBuf.write(data) | |||||
# TMSP server responds to messges by calling methods on the app | |||||
class TMSPServer(): | |||||
def __init__(self, app, port=5410): | |||||
self.app = app | |||||
# map conn file descriptors to (appContext, reqBuf, resBuf, msgDecoder) | |||||
self.appMap = {} | |||||
self.port = port | |||||
self.listen_backlog = 10 | |||||
self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||||
self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | |||||
self.listener.setblocking(0) | |||||
self.listener.bind(('', port)) | |||||
self.listener.listen(self.listen_backlog) | |||||
self.shutdown = False | |||||
self.read_list = [self.listener] | |||||
self.write_list = [] | |||||
def handle_new_connection(self, r): | |||||
new_fd, new_addr = r.accept() | |||||
new_fd.setblocking(0) # non-blocking | |||||
self.read_list.append(new_fd) | |||||
self.write_list.append(new_fd) | |||||
print('new connection to', new_addr) | |||||
appContext = self.app.open() | |||||
self.appMap[new_fd] = Connection(new_fd, appContext) | |||||
def handle_conn_closed(self, r): | |||||
self.read_list.remove(r) | |||||
self.write_list.remove(r) | |||||
r.close() | |||||
print("connection closed") | |||||
def handle_recv(self, r): | |||||
# appCtx, recBuf, resBuf, conn | |||||
conn = self.appMap[r] | |||||
while True: | |||||
try: | |||||
print("recv loop") | |||||
# check if we need more data first | |||||
if conn.inProgress: | |||||
if (conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength): | |||||
conn.recv() | |||||
else: | |||||
if conn.recBuf.size() == 0: | |||||
conn.recv() | |||||
conn.inProgress = True | |||||
# see if we have enough to get the message length | |||||
if conn.msgLength == 0: | |||||
ll = conn.recBuf.peek() | |||||
if conn.recBuf.size() < 1 + ll: | |||||
# we don't have enough bytes to read the length yet | |||||
return | |||||
print("decoding msg length") | |||||
conn.msgLength = decode_varint(conn.recBuf) | |||||
# see if we have enough to decode the message | |||||
if conn.recBuf.size() < conn.msgLength: | |||||
return | |||||
# now we can decode the message | |||||
# first read the request type and get the particular msg | |||||
# decoder | |||||
typeByte = conn.recBuf.read(1) | |||||
typeByte = int(typeByte[0]) | |||||
resTypeByte = typeByte + 0x10 | |||||
req_type = message_types[typeByte] | |||||
if req_type == "flush": | |||||
# messages are length prefixed | |||||
conn.resBuf.write(encode(1)) | |||||
conn.resBuf.write([resTypeByte]) | |||||
conn.fd.send(conn.resBuf.buf) | |||||
conn.msgLength = 0 | |||||
conn.inProgress = False | |||||
conn.resBuf = BytesBuffer(bytearray()) | |||||
return | |||||
decoder = getattr(conn.decoder, req_type) | |||||
print("decoding args") | |||||
req_args = decoder() | |||||
print("got args", req_args) | |||||
# done decoding message | |||||
conn.msgLength = 0 | |||||
conn.inProgress = False | |||||
req_f = getattr(conn.appCtx, req_type) | |||||
if req_args is None: | |||||
res = req_f() | |||||
elif isinstance(req_args, tuple): | |||||
res = req_f(*req_args) | |||||
else: | |||||
res = req_f(req_args) | |||||
if isinstance(res, tuple): | |||||
res, ret_code = res | |||||
else: | |||||
ret_code = res | |||||
res = None | |||||
print("called", req_type, "ret code:", ret_code, 'res:', res) | |||||
if ret_code != 0: | |||||
print("non-zero retcode:", ret_code) | |||||
if req_type in ("echo", "info"): # these dont return a ret code | |||||
enc = encode(res) | |||||
# messages are length prefixed | |||||
conn.resBuf.write(encode(len(enc) + 1)) | |||||
conn.resBuf.write([resTypeByte]) | |||||
conn.resBuf.write(enc) | |||||
else: | |||||
enc, encRet = encode(res), encode(ret_code) | |||||
# messages are length prefixed | |||||
conn.resBuf.write(encode(len(enc) + len(encRet) + 1)) | |||||
conn.resBuf.write([resTypeByte]) | |||||
conn.resBuf.write(encRet) | |||||
conn.resBuf.write(enc) | |||||
except IOError as e: | |||||
print("IOError on reading from connection:", e) | |||||
self.handle_conn_closed(r) | |||||
return | |||||
except Exception as e: | |||||
import sys | |||||
print(sys.exc_info()[0]) | |||||
print("error reading from connection", str(e)) | |||||
self.handle_conn_closed(r) | |||||
return | |||||
def main_loop(self): | |||||
while not self.shutdown: | |||||
r_list, w_list, _ = select.select( | |||||
self.read_list, self.write_list, [], 2.5) | |||||
for r in r_list: | |||||
if (r == self.listener): | |||||
try: | |||||
self.handle_new_connection(r) | |||||
# undo adding to read list ... | |||||
except NameError as e: | |||||
print("Could not connect due to NameError:", e) | |||||
except TypeError as e: | |||||
print("Could not connect due to TypeError:", e) | |||||
except: | |||||
print("Could not connect due to unexpected error:", sys.exc_info()[0]) | |||||
else: | |||||
self.handle_recv(r) | |||||
def handle_shutdown(self): | |||||
for r in self.read_list: | |||||
r.close() | |||||
for w in self.write_list: | |||||
try: | |||||
w.close() | |||||
except Exception as e: | |||||
print(e) # TODO: add logging | |||||
self.shutdown = True |
@ -0,0 +1,119 @@ | |||||
# the decoder works off a reader | |||||
# the encoder returns bytearray | |||||
def hex2bytes(h): | |||||
return bytearray(h.decode('hex')) | |||||
def bytes2hex(b): | |||||
if type(b) in (str, str): | |||||
return "".join([hex(ord(c))[2:].zfill(2) for c in b]) | |||||
else: | |||||
return bytes2hex(b.decode()) | |||||
# expects uvarint64 (no crazy big nums!) | |||||
def uvarint_size(i): | |||||
if i == 0: | |||||
return 0 | |||||
for j in range(1, 8): | |||||
if i < 1 << j * 8: | |||||
return j | |||||
return 8 | |||||
# expects i < 2**size | |||||
def encode_big_endian(i, size): | |||||
if size == 0: | |||||
return bytearray() | |||||
return encode_big_endian(i / 256, size - 1) + bytearray([i % 256]) | |||||
def decode_big_endian(reader, size): | |||||
if size == 0: | |||||
return 0 | |||||
firstByte = reader.read(1)[0] | |||||
return firstByte * (256 ** (size - 1)) + decode_big_endian(reader, size - 1) | |||||
# ints are max 16 bytes long | |||||
def encode_varint(i): | |||||
negate = False | |||||
if i < 0: | |||||
negate = True | |||||
i = -i | |||||
size = uvarint_size(i) | |||||
if size == 0: | |||||
return bytearray([0]) | |||||
big_end = encode_big_endian(i, size) | |||||
if negate: | |||||
size += 0xF0 | |||||
return bytearray([size]) + big_end | |||||
# returns the int and whats left of the byte array | |||||
def decode_varint(reader): | |||||
size = reader.read(1)[0] | |||||
if size == 0: | |||||
return 0 | |||||
negate = True if size > int(0xF0) else False | |||||
if negate: | |||||
size = size - 0xF0 | |||||
i = decode_big_endian(reader, size) | |||||
if negate: | |||||
i = i * (-1) | |||||
return i | |||||
def encode_string(s): | |||||
size = encode_varint(len(s)) | |||||
return size + bytearray(s, 'utf8') | |||||
def decode_string(reader): | |||||
length = decode_varint(reader) | |||||
raw_data = reader.read(length) | |||||
return raw_data.decode() | |||||
def encode_list(s): | |||||
b = bytearray() | |||||
list(map(b.extend, list(map(encode, s)))) | |||||
return encode_varint(len(s)) + b | |||||
def encode(s): | |||||
print('encoding', repr(s)) | |||||
if s is None: | |||||
return bytearray() | |||||
if isinstance(s, int): | |||||
return encode_varint(s) | |||||
elif isinstance(s, str): | |||||
return encode_string(s) | |||||
elif isinstance(s, list): | |||||
return encode_list(s) | |||||
elif isinstance(s, bytearray): | |||||
return encode_string(s) | |||||
else: | |||||
print("UNSUPPORTED TYPE!", type(s), s) | |||||
if __name__ == '__main__': | |||||
ns = [100, 100, 1000, 256] | |||||
ss = [2, 5, 5, 2] | |||||
bs = list(map(encode_big_endian, ns, ss)) | |||||
ds = list(map(decode_big_endian, bs, ss)) | |||||
print(ns) | |||||
print([i[0] for i in ds]) | |||||
ss = ["abc", "hi there jim", "ok now what"] | |||||
e = list(map(encode_string, ss)) | |||||
d = list(map(decode_string, e)) | |||||
print(ss) | |||||
print([i[0] for i in d]) |