diff --git a/example/python/app.py b/example/python/app.py index a439aa748..caed17830 100644 --- a/example/python/app.py +++ b/example/python/app.py @@ -1,82 +1,85 @@ - import sys -sys.path.insert(0, './tmsp') - -from wire import * -from server import * +from wire import hex2bytes, decode_big_endian, encode_big_endian +from server import TMSPServer +from reader import BytesBuffer -# tmsp application interface class CounterApplication(): + def __init__(self): self.hashCount = 0 - self.txCount = 0 - self.commitCount = 0 + self.txCount = 0 + self.commitCount = 0 def open(self): - return CounterAppContext(self) + return CounterAppContext(self) + class CounterAppContext(): - def __init__(self, app): - self.app = app - self.hashCount = app.hashCount - self.txCount = app.txCount - self.commitCount = app.commitCount - self.serial = False - - def echo(self, msg): - return msg, 0 - - def info(self): - return ["hash, tx, commit counts:%d, %d, %d"%(self.hashCount, self.txCount, self.commitCount)], 0 - - def set_option(self, key, value): - if key == "serial" and value == "on": - self.serial = True - return 0 - - def append_tx(self, txBytes): - if self.serial: - txByteArray = bytearray(txBytes) - if len(txBytes) >= 2 and txBytes[:2] == "0x": - txByteArray = hex2bytes(txBytes[2:]) - txValue = decode_big_endian(BytesBuffer(txByteArray), len(txBytes)) - if txValue != self.txCount: - return None, 1 - self.txCount += 1 - return None, 0 - - def get_hash(self): - self.hashCount += 1 - if self.txCount == 0: - return "", 0 - h = encode_big_endian(self.txCount, 8) - h.reverse() - return str(h), 0 - - def commit(self): - self.commitCount += 1 - return 0 - - def rollback(self): - return 0 - - def add_listener(self): - return 0 - - def rm_listener(self): - return 0 - - def event(self): - return - - + + def __init__(self, app): + self.app = app + self.hashCount = app.hashCount + self.txCount = app.txCount + self.commitCount = app.commitCount + self.serial = False + + def echo(self, msg): + return msg, 0 + + def info(self): + return ["hash, tx, commit counts:%d, %d, %d" % (self.hashCount, + self.txCount, + self.commitCount)], 0 + + def set_option(self, key, value): + if key == "serial" and value == "on": + self.serial = True + return 0 + + def append_tx(self, txBytes): + if self.serial: + txByteArray = bytearray(txBytes) + if len(txBytes) >= 2 and txBytes[:2] == "0x": + txByteArray = hex2bytes(txBytes[2:]) + txValue = decode_big_endian( + BytesBuffer(txByteArray), len(txBytes)) + if txValue != self.txCount: + return None, 1 + self.txCount += 1 + return None, 0 + + def get_hash(self): + self.hashCount += 1 + if self.txCount == 0: + return "", 0 + h = encode_big_endian(self.txCount, 8) + h.reverse() + return str(h), 0 + + def commit(self): + self.commitCount += 1 + return 0 + + def rollback(self): + return 0 + + def add_listener(self): + return 0 + + def rm_listener(self): + return 0 + + def event(self): + return + + if __name__ == '__main__': l = len(sys.argv) if l == 1: - port = 46658 - elif l == 2: + port = 46658 + elif l == 2: port = int(sys.argv[1]) else: print "too many arguments" @@ -84,6 +87,6 @@ if __name__ == '__main__': print 'TMSP Demo APP (Python)' - app = CounterApplication() + app = CounterApplication() server = TMSPServer(app, port) server.main_loop() diff --git a/example/python/tmsp/__init__.py b/example/python/tmsp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/example/python/tmsp/msg.py b/example/python/tmsp/msg.py index a99386a3b..f9339fe9b 100644 --- a/example/python/tmsp/msg.py +++ b/example/python/tmsp/msg.py @@ -1,54 +1,55 @@ -from wire import * +from wire import decode_string # map type_byte to message name message_types = { -0x01 : "echo", -0x02 : "flush", -0x03 : "info", -0x04 : "set_option", -0x21 : "append_tx", -0x22 : "get_hash", -0x23 : "commit", -0x24 : "rollback", -0x25 : "add_listener", -0x26 : "rm_listener", + 0x01: "echo", + 0x02: "flush", + 0x03: "info", + 0x04: "set_option", + 0x21: "append_tx", + 0x22: "get_hash", + 0x23: "commit", + 0x24: "rollback", + 0x25: "add_listener", + 0x26: "rm_listener", } # return the decoded arguments of tmsp messages -class RequestDecoder(): - def __init__(self, reader): - self.reader = reader - def echo(self): - return decode_string(self.reader) - def flush(self): - return +class RequestDecoder(): + + def __init__(self, reader): + self.reader = reader - def info(self): - return + def echo(self): + return decode_string(self.reader) - def set_option(self): - return decode_string(self.reader), decode_string(self.reader) + def flush(self): + return - def append_tx(self): - return decode_string(self.reader) + def info(self): + return - def get_hash(self): - return + def set_option(self): + return decode_string(self.reader), decode_string(self.reader) - def commit(self): - return + def append_tx(self): + return decode_string(self.reader) - def rollback(self): - return + def get_hash(self): + return - def add_listener(self): - # TODO - return + def commit(self): + return - def rm_listener(self): - # TODO - return + def rollback(self): + return + def add_listener(self): + # TODO + return + def rm_listener(self): + # TODO + return diff --git a/example/python/tmsp/reader.py b/example/python/tmsp/reader.py index 3b1f87fcb..6c0dad94e 100644 --- a/example/python/tmsp/reader.py +++ b/example/python/tmsp/reader.py @@ -1,50 +1,56 @@ # Simple read() method around a bytearray + + class BytesBuffer(): - def __init__(self, b): - self.buf = b - self.readCount = 0 - - def count(self): - return self.readCount - - def reset_count(self): - self.readCount = 0 - - def size(self): - return len(self.buf) - - def peek(self): - return self.buf[0] - - def write(self, b): - # b should be castable to byte array - self.buf += bytearray(b) - - def read(self, n): - if len(self.buf) < n: - print "reader err: buf less than n" - # TODO: exception - return - self.readCount += n - r = self.buf[:n] - self.buf = self.buf[n:] - return r + + def __init__(self, b): + self.buf = b + self.readCount = 0 + + def count(self): + return self.readCount + + def reset_count(self): + self.readCount = 0 + + def size(self): + return len(self.buf) + + def peek(self): + return self.buf[0] + + def write(self, b): + # b should be castable to byte array + self.buf += bytearray(b) + + def read(self, n): + if len(self.buf) < n: + print "reader err: buf less than n" + # TODO: exception + return + self.readCount += n + r = self.buf[:n] + self.buf = self.buf[n:] + return r # Buffer bytes off a tcp connection and read them off in chunks + + class ConnReader(): - def __init__(self, conn): - self.conn = conn - self.buf = bytearray() - - # blocking - def read(self, n): - while n > len(self.buf): - moreBuf = self.conn.recv(1024) - if not moreBuf: - raise IOError("dead connection") - self.buf = self.buf + bytearray(moreBuf) - - r = self.buf[:n] - self.buf = self.buf[n:] - return r + + def __init__(self, conn): + self.conn = conn + self.buf = bytearray() + + # blocking + def read(self, n): + while n > len(self.buf): + moreBuf = self.conn.recv(1024) + if not moreBuf: + raise IOError("dead connection") + self.buf = self.buf + bytearray(moreBuf) + + r = self.buf[:n] + self.buf = self.buf[n:] + return r diff --git a/example/python/tmsp/server.py b/example/python/tmsp/server.py index 0beb59d13..eeb974b8a 100644 --- a/example/python/tmsp/server.py +++ b/example/python/tmsp/server.py @@ -1,38 +1,44 @@ import socket import select import sys -import os -from wire import * -from reader import * -from msg import * +from wire import decode_varint, encode +from reader import BytesBuffer +from msg import RequestDecoder, message_types # hold the asyncronous state of a connection # ie. we may not get enough bytes on one read to decode the message + + class Connection(): - def __init__(self, fd, appCtx): - self.fd = fd - self.appCtx = appCtx - self.recBuf = BytesBuffer(bytearray()) - self.resBuf = BytesBuffer(bytearray()) - self.msgLength = 0 - self.decoder = RequestDecoder(self.recBuf) - self.inProgress = False # are we in the middle of a message - - def recv(this): - data = this.fd.recv(1024) - if not data: # what about len(data) == 0 - raise IOError("dead connection") - this.recBuf.write(data) + + def __init__(self, fd, appCtx): + self.fd = fd + self.appCtx = appCtx + self.recBuf = BytesBuffer(bytearray()) + self.resBuf = BytesBuffer(bytearray()) + self.msgLength = 0 + self.decoder = RequestDecoder(self.recBuf) + self.inProgress = False # are we in the middle of a message + + def recv(this): + data = this.fd.recv(1024) + if not data: # what about len(data) == 0 + raise IOError("dead connection") + this.recBuf.write(data) # TMSP server responds to messges by calling methods on the app + + class TMSPServer(): + def __init__(self, app, port=5410): - self.app = app - self.appMap = {} # map conn file descriptors to (appContext, reqBuf, resBuf, msgDecoder) + self.app = app + # map conn file descriptors to (appContext, reqBuf, resBuf, msgDecoder) + self.appMap = {} - self.port = port + self.port = port self.listen_backlog = 10 self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -49,13 +55,13 @@ class TMSPServer(): def handle_new_connection(self, r): new_fd, new_addr = r.accept() - new_fd.setblocking(0) # non-blocking + new_fd.setblocking(0) # non-blocking self.read_list.append(new_fd) self.write_list.append(new_fd) print 'new connection to', new_addr - appContext = self.app.open() - self.appMap[new_fd] = Connection(new_fd, appContext) + appContext = self.app.open() + self.appMap[new_fd] = Connection(new_fd, appContext) def handle_conn_closed(self, r): self.read_list.remove(r) @@ -64,137 +70,137 @@ class TMSPServer(): print "connection closed" def handle_recv(self, r): -# appCtx, recBuf, resBuf, conn - conn = self.appMap[r] - while True: - try: - print "recv loop" - # check if we need more data first - if conn.inProgress: - if conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength: - conn.recv() - else: - if conn.recBuf.size() == 0: - conn.recv() - - conn.inProgress = True - - # see if we have enough to get the message length - if conn.msgLength == 0: - ll = conn.recBuf.peek() - if conn.recBuf.size() < 1 + ll: - # we don't have enough bytes to read the length yet - return - print "decoding msg length" - conn.msgLength = decode_varint(conn.recBuf) - - # see if we have enough to decode the message - if conn.recBuf.size() < conn.msgLength: - return - - # now we can decode the message - - # first read the request type and get the particular msg decoder - typeByte = conn.recBuf.read(1) - typeByte = int(typeByte[0]) - resTypeByte = typeByte+0x10 - req_type = message_types[typeByte] - - if req_type == "flush": - # messages are length prefixed - conn.resBuf.write(encode(1)) - conn.resBuf.write([resTypeByte]) - sent = conn.fd.send(str(conn.resBuf.buf)) - conn.msgLength = 0 - conn.inProgress = False - conn.resBuf = BytesBuffer(bytearray()) - return - - decoder = getattr(conn.decoder, req_type) - - print "decoding args" - req_args = decoder() - print "got args", req_args - - # done decoding message - conn.msgLength = 0 - conn.inProgress = False - - req_f = getattr(conn.appCtx, req_type) - if req_args == None: - res = req_f() - elif isinstance(req_args, tuple): - res = req_f(*req_args) - else: - res = req_f(req_args) - - if isinstance(res, tuple): - res, ret_code = res - else: - ret_code = res - res = None - - print "called", req_type, "ret code:", ret_code - if ret_code != 0: - print "non-zero retcode:", ret_code - - if req_type in ("echo", "info"): # these dont return a ret code - enc = encode(res) - # messages are length prefixed - conn.resBuf.write(encode(len(enc) + 1)) - conn.resBuf.write([resTypeByte]) - conn.resBuf.write(enc) - else: - enc, encRet = encode(res), encode(ret_code) - # messages are length prefixed - conn.resBuf.write(encode(len(enc)+len(encRet)+1)) - conn.resBuf.write([resTypeByte]) - conn.resBuf.write(encRet) - conn.resBuf.write(enc) - except TypeError as e: - print "TypeError on reading from connection:", e - self.handle_conn_closed(r) - return - except ValueError as e: - print "ValueError on reading from connection:", e - self.handle_conn_closed(r) - return - except IOError as e: - print "IOError on reading from connection:", e - self.handle_conn_closed(r) - return - except Exception as e: - print "error reading from connection", str(e) # sys.exc_info()[0] # TODO better - self.handle_conn_closed(r) - return + # appCtx, recBuf, resBuf, conn + conn = self.appMap[r] + while True: + try: + print "recv loop" + # check if we need more data first + if conn.inProgress: + if (conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength): + conn.recv() + else: + if conn.recBuf.size() == 0: + conn.recv() + + conn.inProgress = True + + # see if we have enough to get the message length + if conn.msgLength == 0: + ll = conn.recBuf.peek() + if conn.recBuf.size() < 1 + ll: + # we don't have enough bytes to read the length yet + return + print "decoding msg length" + conn.msgLength = decode_varint(conn.recBuf) + + # see if we have enough to decode the message + if conn.recBuf.size() < conn.msgLength: + return + + # now we can decode the message + + # first read the request type and get the particular msg + # decoder + typeByte = conn.recBuf.read(1) + typeByte = int(typeByte[0]) + resTypeByte = typeByte + 0x10 + req_type = message_types[typeByte] + + if req_type == "flush": + # messages are length prefixed + conn.resBuf.write(encode(1)) + conn.resBuf.write([resTypeByte]) + conn.msgLength = 0 + conn.inProgress = False + conn.resBuf = BytesBuffer(bytearray()) + return + + decoder = getattr(conn.decoder, req_type) + + print "decoding args" + req_args = decoder() + print "got args", req_args + + # done decoding message + conn.msgLength = 0 + conn.inProgress = False + + req_f = getattr(conn.appCtx, req_type) + if req_args is None: + res = req_f() + elif isinstance(req_args, tuple): + res = req_f(*req_args) + else: + res = req_f(req_args) + + if isinstance(res, tuple): + res, ret_code = res + else: + ret_code = res + res = None + + print "called", req_type, "ret code:", ret_code + if ret_code != 0: + print "non-zero retcode:", ret_code + + if req_type in ("echo", "info"): # these dont return a ret code + enc = encode(res) + # messages are length prefixed + conn.resBuf.write(encode(len(enc) + 1)) + conn.resBuf.write([resTypeByte]) + conn.resBuf.write(enc) + else: + enc, encRet = encode(res), encode(ret_code) + # messages are length prefixed + conn.resBuf.write(encode(len(enc) + len(encRet) + 1)) + conn.resBuf.write([resTypeByte]) + conn.resBuf.write(encRet) + conn.resBuf.write(enc) + except TypeError as e: + print "TypeError on reading from connection:", e + self.handle_conn_closed(r) + return + except ValueError as e: + print "ValueError on reading from connection:", e + self.handle_conn_closed(r) + return + except IOError as e: + print "IOError on reading from connection:", e + self.handle_conn_closed(r) + return + except Exception as e: + # sys.exc_info()[0] # TODO better + print "error reading from connection", str(e) + self.handle_conn_closed(r) + return def main_loop(self): while not self.shutdown: - r_list, w_list, _ = select.select(self.read_list, self.write_list, [], 2.5) + r_list, w_list, _ = select.select( + self.read_list, self.write_list, [], 2.5) for r in r_list: - if (r == self.listener): + if (r == self.listener): try: self.handle_new_connection(r) - # undo adding to read list ... - except rameError as e: - print "Could not connect due to NameError:", e - except TypeError as e: - print "Could not connect due to TypeError:", e - except: - print "Could not connect due to unexpected error:", sys.exc_info()[0] - else: + # undo adding to read list ... + except NameError as e: + print "Could not connect due to NameError:", e + except TypeError as e: + print "Could not connect due to TypeError:", e + except: + print "Could not connect due to unexpected error:", sys.exc_info()[0] + else: self.handle_recv(r) - - def handle_shutdown(self): for r in self.read_list: r.close() for w in self.write_list: try: w.close() - except: pass + except: + pass self.shutdown = True - diff --git a/example/python/tmsp/wire.py b/example/python/tmsp/wire.py index 1f16855a7..1a07e89f1 100644 --- a/example/python/tmsp/wire.py +++ b/example/python/tmsp/wire.py @@ -2,101 +2,114 @@ # the decoder works off a reader # the encoder returns bytearray + def hex2bytes(h): - return bytearray(h.decode('hex')) + return bytearray(h.decode('hex')) + def bytes2hex(b): - if type(b) in (str, unicode): - return "".join([hex(ord(c))[2:].zfill(2) for c in b]) - else: - return bytes2hex(b.decode()) + if type(b) in (str, unicode): + return "".join([hex(ord(c))[2:].zfill(2) for c in b]) + else: + return bytes2hex(b.decode()) # expects uvarint64 (no crazy big nums!) def uvarint_size(i): - if i == 0: - return 0 - for j in xrange(1, 8): - if i < 1< int(0xF0) else False - if negate: size = size -0xF0 - i = decode_big_endian(reader, size) - if negate: i = i*(-1) - return i - + size = reader.read(1)[0] + if size == 0: + return 0 + + negate = True if size > int(0xF0) else False + if negate: + size = size - 0xF0 + i = decode_big_endian(reader, size) + if negate: + i = i * (-1) + return i + + def encode_string(s): - size = encode_varint(len(s)) - return size + bytearray(s) + size = encode_varint(len(s)) + return size + bytearray(s) + def decode_string(reader): - length = decode_varint(reader) - return str(reader.read(length)) + length = decode_varint(reader) + return str(reader.read(length)) + def encode_list(s): - b = bytearray() - map(b.extend, map(encode, s)) - return encode_varint(len(s)) + b + b = bytearray() + map(b.extend, map(encode, s)) + return encode_varint(len(s)) + b + def encode(s): - if s == None: - return bytearray() - if isinstance(s, int): - return encode_varint(s) - elif isinstance(s, str): - return encode_string(s) - elif isinstance(s, list): - return encode_list(s) - else: - print "UNSUPPORTED TYPE!", type(s), s - - -import binascii - + if s is None: + return bytearray() + if isinstance(s, int): + return encode_varint(s) + elif isinstance(s, str): + return encode_string(s) + elif isinstance(s, list): + return encode_list(s) + else: + print "UNSUPPORTED TYPE!", type(s), s + + if __name__ == '__main__': - ns = [100,100,1000,256] - ss = [2,5,5,2] - bs = map(encode_big_endian, ns,ss) - ds = map(decode_big_endian, bs,ss) - print ns - print [i[0] for i in ds] - - ss = ["abc", "hi there jim", "ok now what"] - e = map(encode_string, ss) - d = map(decode_string, e) - print ss - print [i[0] for i in d] + ns = [100, 100, 1000, 256] + ss = [2, 5, 5, 2] + bs = map(encode_big_endian, ns, ss) + ds = map(decode_big_endian, bs, ss) + print ns + print [i[0] for i in ds] + + ss = ["abc", "hi there jim", "ok now what"] + e = map(encode_string, ss) + d = map(decode_string, e) + print ss + print [i[0] for i in d]