@ -0,0 +1,92 @@ | |||
import sys | |||
from tmsp.wire import hex2bytes, decode_big_endian, encode_big_endian | |||
from tmsp.server import TMSPServer | |||
from tmsp.reader import BytesBuffer | |||
class CounterApplication(): | |||
def __init__(self): | |||
self.hashCount = 0 | |||
self.txCount = 0 | |||
self.commitCount = 0 | |||
def open(self): | |||
return CounterAppContext(self) | |||
class CounterAppContext(): | |||
def __init__(self, app): | |||
self.app = app | |||
self.hashCount = app.hashCount | |||
self.txCount = app.txCount | |||
self.commitCount = app.commitCount | |||
self.serial = False | |||
def echo(self, msg): | |||
return msg, 0 | |||
def info(self): | |||
return ["hash, tx, commit counts:%d, %d, %d" % (self.hashCount, | |||
self.txCount, | |||
self.commitCount)], 0 | |||
def set_option(self, key, value): | |||
if key == "serial" and value == "on": | |||
self.serial = True | |||
return 0 | |||
def append_tx(self, txBytes): | |||
if self.serial: | |||
txByteArray = bytearray(txBytes) | |||
if len(txBytes) >= 2 and txBytes[:2] == "0x": | |||
txByteArray = hex2bytes(txBytes[2:]) | |||
txValue = decode_big_endian( | |||
BytesBuffer(txByteArray), len(txBytes)) | |||
if txValue != self.txCount: | |||
return None, 1 | |||
self.txCount += 1 | |||
return None, 0 | |||
def get_hash(self): | |||
self.hashCount += 1 | |||
if self.txCount == 0: | |||
return "", 0 | |||
h = encode_big_endian(self.txCount, 8) | |||
h.reverse() | |||
return h.decode(), 0 | |||
def commit(self): | |||
self.commitCount += 1 | |||
return 0 | |||
def rollback(self): | |||
return 0 | |||
def add_listener(self): | |||
return 0 | |||
def rm_listener(self): | |||
return 0 | |||
def event(self): | |||
return | |||
if __name__ == '__main__': | |||
l = len(sys.argv) | |||
if l == 1: | |||
port = 46658 | |||
elif l == 2: | |||
port = int(sys.argv[1]) | |||
else: | |||
print("too many arguments") | |||
quit() | |||
print('TMSP Demo APP (Python)') | |||
app = CounterApplication() | |||
server = TMSPServer(app, port) | |||
server.main_loop() |
@ -0,0 +1,55 @@ | |||
from .wire import decode_string | |||
# map type_byte to message name | |||
message_types = { | |||
0x01: "echo", | |||
0x02: "flush", | |||
0x03: "info", | |||
0x04: "set_option", | |||
0x21: "append_tx", | |||
0x22: "get_hash", | |||
0x23: "commit", | |||
0x24: "rollback", | |||
0x25: "add_listener", | |||
0x26: "rm_listener", | |||
} | |||
# return the decoded arguments of tmsp messages | |||
class RequestDecoder(): | |||
def __init__(self, reader): | |||
self.reader = reader | |||
def echo(self): | |||
return decode_string(self.reader) | |||
def flush(self): | |||
return | |||
def info(self): | |||
return | |||
def set_option(self): | |||
return decode_string(self.reader), decode_string(self.reader) | |||
def append_tx(self): | |||
return decode_string(self.reader) | |||
def get_hash(self): | |||
return | |||
def commit(self): | |||
return | |||
def rollback(self): | |||
return | |||
def add_listener(self): | |||
# TODO | |||
return | |||
def rm_listener(self): | |||
# TODO | |||
return |
@ -0,0 +1,56 @@ | |||
# Simple read() method around a bytearray | |||
class BytesBuffer(): | |||
def __init__(self, b): | |||
self.buf = b | |||
self.readCount = 0 | |||
def count(self): | |||
return self.readCount | |||
def reset_count(self): | |||
self.readCount = 0 | |||
def size(self): | |||
return len(self.buf) | |||
def peek(self): | |||
return self.buf[0] | |||
def write(self, b): | |||
# b should be castable to byte array | |||
self.buf += bytearray(b) | |||
def read(self, n): | |||
if len(self.buf) < n: | |||
print("reader err: buf less than n") | |||
# TODO: exception | |||
return | |||
self.readCount += n | |||
r = self.buf[:n] | |||
self.buf = self.buf[n:] | |||
return r | |||
# Buffer bytes off a tcp connection and read them off in chunks | |||
class ConnReader(): | |||
def __init__(self, conn): | |||
self.conn = conn | |||
self.buf = bytearray() | |||
# blocking | |||
def read(self, n): | |||
while n > len(self.buf): | |||
moreBuf = self.conn.recv(1024) | |||
if not moreBuf: | |||
raise IOError("dead connection") | |||
self.buf = self.buf + bytearray(moreBuf) | |||
r = self.buf[:n] | |||
self.buf = self.buf[n:] | |||
return r |
@ -0,0 +1,199 @@ | |||
import socket | |||
import select | |||
import sys | |||
import logging | |||
from .wire import decode_varint, encode | |||
from .reader import BytesBuffer | |||
from .msg import RequestDecoder, message_types | |||
# hold the asyncronous state of a connection | |||
# ie. we may not get enough bytes on one read to decode the message | |||
logger = logging.getLogger(__name__) | |||
class Connection(): | |||
def __init__(self, fd, appCtx): | |||
self.fd = fd | |||
self.appCtx = appCtx | |||
self.recBuf = BytesBuffer(bytearray()) | |||
self.resBuf = BytesBuffer(bytearray()) | |||
self.msgLength = 0 | |||
self.decoder = RequestDecoder(self.recBuf) | |||
self.inProgress = False # are we in the middle of a message | |||
def recv(this): | |||
data = this.fd.recv(1024) | |||
if not data: # what about len(data) == 0 | |||
raise IOError("dead connection") | |||
this.recBuf.write(data) | |||
# TMSP server responds to messges by calling methods on the app | |||
class TMSPServer(): | |||
def __init__(self, app, port=5410): | |||
self.app = app | |||
# map conn file descriptors to (appContext, reqBuf, resBuf, msgDecoder) | |||
self.appMap = {} | |||
self.port = port | |||
self.listen_backlog = 10 | |||
self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |||
self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | |||
self.listener.setblocking(0) | |||
self.listener.bind(('', port)) | |||
self.listener.listen(self.listen_backlog) | |||
self.shutdown = False | |||
self.read_list = [self.listener] | |||
self.write_list = [] | |||
def handle_new_connection(self, r): | |||
new_fd, new_addr = r.accept() | |||
new_fd.setblocking(0) # non-blocking | |||
self.read_list.append(new_fd) | |||
self.write_list.append(new_fd) | |||
print('new connection to', new_addr) | |||
appContext = self.app.open() | |||
self.appMap[new_fd] = Connection(new_fd, appContext) | |||
def handle_conn_closed(self, r): | |||
self.read_list.remove(r) | |||
self.write_list.remove(r) | |||
r.close() | |||
print("connection closed") | |||
def handle_recv(self, r): | |||
# appCtx, recBuf, resBuf, conn | |||
conn = self.appMap[r] | |||
while True: | |||
try: | |||
print("recv loop") | |||
# check if we need more data first | |||
if conn.inProgress: | |||
if (conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength): | |||
conn.recv() | |||
else: | |||
if conn.recBuf.size() == 0: | |||
conn.recv() | |||
conn.inProgress = True | |||
# see if we have enough to get the message length | |||
if conn.msgLength == 0: | |||
ll = conn.recBuf.peek() | |||
if conn.recBuf.size() < 1 + ll: | |||
# we don't have enough bytes to read the length yet | |||
return | |||
print("decoding msg length") | |||
conn.msgLength = decode_varint(conn.recBuf) | |||
# see if we have enough to decode the message | |||
if conn.recBuf.size() < conn.msgLength: | |||
return | |||
# now we can decode the message | |||
# first read the request type and get the particular msg | |||
# decoder | |||
typeByte = conn.recBuf.read(1) | |||
typeByte = int(typeByte[0]) | |||
resTypeByte = typeByte + 0x10 | |||
req_type = message_types[typeByte] | |||
if req_type == "flush": | |||
# messages are length prefixed | |||
conn.resBuf.write(encode(1)) | |||
conn.resBuf.write([resTypeByte]) | |||
conn.fd.send(conn.resBuf.buf) | |||
conn.msgLength = 0 | |||
conn.inProgress = False | |||
conn.resBuf = BytesBuffer(bytearray()) | |||
return | |||
decoder = getattr(conn.decoder, req_type) | |||
print("decoding args") | |||
req_args = decoder() | |||
print("got args", req_args) | |||
# done decoding message | |||
conn.msgLength = 0 | |||
conn.inProgress = False | |||
req_f = getattr(conn.appCtx, req_type) | |||
if req_args is None: | |||
res = req_f() | |||
elif isinstance(req_args, tuple): | |||
res = req_f(*req_args) | |||
else: | |||
res = req_f(req_args) | |||
if isinstance(res, tuple): | |||
res, ret_code = res | |||
else: | |||
ret_code = res | |||
res = None | |||
print("called", req_type, "ret code:", ret_code, 'res:', res) | |||
if ret_code != 0: | |||
print("non-zero retcode:", ret_code) | |||
if req_type in ("echo", "info"): # these dont return a ret code | |||
enc = encode(res) | |||
# messages are length prefixed | |||
conn.resBuf.write(encode(len(enc) + 1)) | |||
conn.resBuf.write([resTypeByte]) | |||
conn.resBuf.write(enc) | |||
else: | |||
enc, encRet = encode(res), encode(ret_code) | |||
# messages are length prefixed | |||
conn.resBuf.write(encode(len(enc) + len(encRet) + 1)) | |||
conn.resBuf.write([resTypeByte]) | |||
conn.resBuf.write(encRet) | |||
conn.resBuf.write(enc) | |||
except IOError as e: | |||
print("IOError on reading from connection:", e) | |||
self.handle_conn_closed(r) | |||
return | |||
except Exception as e: | |||
logger.exception("error reading from connection") | |||
self.handle_conn_closed(r) | |||
return | |||
def main_loop(self): | |||
while not self.shutdown: | |||
r_list, w_list, _ = select.select( | |||
self.read_list, self.write_list, [], 2.5) | |||
for r in r_list: | |||
if (r == self.listener): | |||
try: | |||
self.handle_new_connection(r) | |||
# undo adding to read list ... | |||
except NameError as e: | |||
print("Could not connect due to NameError:", e) | |||
except TypeError as e: | |||
print("Could not connect due to TypeError:", e) | |||
except: | |||
print("Could not connect due to unexpected error:", sys.exc_info()[0]) | |||
else: | |||
self.handle_recv(r) | |||
def handle_shutdown(self): | |||
for r in self.read_list: | |||
r.close() | |||
for w in self.write_list: | |||
try: | |||
w.close() | |||
except Exception as e: | |||
print(e) # TODO: add logging | |||
self.shutdown = True |
@ -0,0 +1,119 @@ | |||
# the decoder works off a reader | |||
# the encoder returns bytearray | |||
def hex2bytes(h): | |||
return bytearray(h.decode('hex')) | |||
def bytes2hex(b): | |||
if type(b) in (str, str): | |||
return "".join([hex(ord(c))[2:].zfill(2) for c in b]) | |||
else: | |||
return bytes2hex(b.decode()) | |||
# expects uvarint64 (no crazy big nums!) | |||
def uvarint_size(i): | |||
if i == 0: | |||
return 0 | |||
for j in range(1, 8): | |||
if i < 1 << j * 8: | |||
return j | |||
return 8 | |||
# expects i < 2**size | |||
def encode_big_endian(i, size): | |||
if size == 0: | |||
return bytearray() | |||
return encode_big_endian(i // 256, size - 1) + bytearray([i % 256]) | |||
def decode_big_endian(reader, size): | |||
if size == 0: | |||
return 0 | |||
firstByte = reader.read(1)[0] | |||
return firstByte * (256 ** (size - 1)) + decode_big_endian(reader, size - 1) | |||
# ints are max 16 bytes long | |||
def encode_varint(i): | |||
negate = False | |||
if i < 0: | |||
negate = True | |||
i = -i | |||
size = uvarint_size(i) | |||
if size == 0: | |||
return bytearray([0]) | |||
big_end = encode_big_endian(i, size) | |||
if negate: | |||
size += 0xF0 | |||
return bytearray([size]) + big_end | |||
# returns the int and whats left of the byte array | |||
def decode_varint(reader): | |||
size = reader.read(1)[0] | |||
if size == 0: | |||
return 0 | |||
negate = True if size > int(0xF0) else False | |||
if negate: | |||
size = size - 0xF0 | |||
i = decode_big_endian(reader, size) | |||
if negate: | |||
i = i * (-1) | |||
return i | |||
def encode_string(s): | |||
size = encode_varint(len(s)) | |||
return size + bytearray(s, 'utf8') | |||
def decode_string(reader): | |||
length = decode_varint(reader) | |||
raw_data = reader.read(length) | |||
return raw_data.decode() | |||
def encode_list(s): | |||
b = bytearray() | |||
list(map(b.extend, list(map(encode, s)))) | |||
return encode_varint(len(s)) + b | |||
def encode(s): | |||
print('encoding', repr(s)) | |||
if s is None: | |||
return bytearray() | |||
if isinstance(s, int): | |||
return encode_varint(s) | |||
elif isinstance(s, str): | |||
return encode_string(s) | |||
elif isinstance(s, list): | |||
return encode_list(s) | |||
elif isinstance(s, bytearray): | |||
return encode_string(s) | |||
else: | |||
print("UNSUPPORTED TYPE!", type(s), s) | |||
if __name__ == '__main__': | |||
ns = [100, 100, 1000, 256] | |||
ss = [2, 5, 5, 2] | |||
bs = list(map(encode_big_endian, ns, ss)) | |||
ds = list(map(decode_big_endian, bs, ss)) | |||
print(ns) | |||
print([i[0] for i in ds]) | |||
ss = ["abc", "hi there jim", "ok now what"] | |||
e = list(map(encode_string, ss)) | |||
d = list(map(decode_string, e)) | |||
print(ss) | |||
print([i[0] for i in d]) |