|
|
- import socket
- import select
- import sys
- import logging
-
- from .wire import decode_varint, encode
- from .reader import BytesBuffer
- from .msg import RequestDecoder, message_types
-
- # hold the asynchronous state of a connection
- # i.e. we may not get enough bytes on one read to decode the message
-
- logger = logging.getLogger(__name__)
-
class Connection:
    """Per-socket state for a message that may arrive over several reads.

    Attributes (all set in ``__init__``):
        fd: the connected socket this state belongs to.
        app: the application object whose methods answer requests.
        recBuf: buffer of bytes received but not yet decoded.
        resBuf: buffer of response bytes waiting to be sent.
        msgLength: length of the message being assembled (starts at 0).
        decoder: ``RequestDecoder`` reading from ``recBuf``.
        inProgress: True while we are in the middle of a message.
    """

    def __init__(self, fd, app):
        """Create fresh buffers and decoder state for socket *fd*."""
        self.fd = fd
        self.app = app
        self.recBuf = BytesBuffer(bytearray())
        self.resBuf = BytesBuffer(bytearray())
        self.msgLength = 0
        self.decoder = RequestDecoder(self.recBuf)
        self.inProgress = False  # are we in the middle of a message

    def recv(self):
        """Read up to 1024 bytes from the socket and append them to ``recBuf``.

        Raises:
            IOError: with message "dead connection" when the socket returns
                no bytes. Exceptions raised by the socket's own ``recv``
                propagate unchanged.
        """
        data = self.fd.recv(1024)
        # ``not data`` is also true for a zero-length read, so no separate
        # ``len(data) == 0`` check is needed.
        if not data:
            raise IOError("dead connection")
        self.recBuf.write(data)
-
- # TMSP server responds to messages by calling methods on the app
-
class TMSPServer:
    """Single-threaded, ``select``-driven TMSP server.

    Accepts TCP connections on ``port``, decodes length-prefixed requests
    and answers each one by calling the method of the same name on ``app``.
    Responses are accumulated in the connection's ``resBuf`` and written to
    the socket when a "flush" request arrives.
    """

    def __init__(self, app, port=5410):
        """Bind a non-blocking listening socket on *port* (all interfaces).

        Args:
            app: object whose methods are looked up by request type name.
            port: TCP port to listen on.
        """
        self.app = app
        # map each connection socket to its Connection state object
        self.appMap = {}

        self.port = port
        self.listen_backlog = 10

        self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.listener.setblocking(0)
        self.listener.bind(('', port))

        self.listener.listen(self.listen_backlog)

        self.shutdown = False

        self.read_list = [self.listener]
        self.write_list = []

    def handle_new_connection(self, r):
        """Accept a pending connection on listener *r* and register it.

        The ``Connection`` state is built before the socket is added to any
        bookkeeping list, so a failure here cannot leave a socket in
        ``read_list`` without a matching ``appMap`` entry. On failure the
        accepted socket is closed and the exception re-raised.
        """
        new_fd, new_addr = r.accept()
        try:
            new_fd.setblocking(0)  # non-blocking
            conn = Connection(new_fd, self.app)
        except Exception:
            new_fd.close()
            raise
        self.read_list.append(new_fd)
        self.write_list.append(new_fd)
        self.appMap[new_fd] = conn
        print('new connection to', new_addr)

    def handle_conn_closed(self, r):
        """Forget socket *r* everywhere and close it.

        Safe to call more than once for the same socket.
        """
        for sockets in (self.read_list, self.write_list):
            if r in sockets:
                sockets.remove(r)
        # Drop the per-connection state too, otherwise it is never released.
        self.appMap.pop(r, None)
        r.close()
        print("connection closed")

    def handle_recv(self, r):
        """Read from socket *r* and process as many requests as are buffered.

        Returns when more bytes are needed to continue, when the socket has
        no data available right now, or after a "flush" request has been
        answered. Any other error closes the connection.
        """
        conn = self.appMap[r]
        # Every handler below returns, so one try around the loop is enough.
        try:
            while True:
                print("recv loop")
                # check if we need more data first
                if conn.inProgress:
                    need_data = (conn.msgLength == 0
                                 or conn.recBuf.size() < conn.msgLength)
                else:
                    need_data = conn.recBuf.size() == 0
                if need_data:
                    try:
                        conn.recv()
                    except BlockingIOError:
                        # The socket is non-blocking and simply has nothing
                        # to read yet; this is not a dead connection. Keep
                        # the state and wait for select to report it again.
                        return

                conn.inProgress = True

                # see if we have enough to get the message length
                if conn.msgLength == 0:
                    ll = conn.recBuf.peek()
                    if conn.recBuf.size() < 1 + ll:
                        # we don't have enough bytes to read the length yet
                        return
                    print("decoding msg length")
                    conn.msgLength = decode_varint(conn.recBuf)

                # see if we have enough to decode the message
                if conn.recBuf.size() < conn.msgLength:
                    return

                # now we can decode the message: first read the request
                # type and get the particular msg decoder
                typeByte = int(conn.recBuf.read(1)[0])
                resTypeByte = typeByte + 0x10
                req_type = message_types[typeByte]

                if req_type == "flush":
                    # messages are length prefixed
                    conn.resBuf.write(encode(1))
                    conn.resBuf.write([resTypeByte])
                    # NOTE(review): send() may write only part of the
                    # buffer; a short write is not detected here.
                    conn.fd.send(conn.resBuf.buf)
                    conn.msgLength = 0
                    conn.inProgress = False
                    conn.resBuf = BytesBuffer(bytearray())
                    return

                decoder = getattr(conn.decoder, req_type)

                print("decoding args")
                req_args = decoder()
                print("got args", req_args)

                # done decoding message
                conn.msgLength = 0
                conn.inProgress = False

                res, ret_code = self._invoke(conn.app, req_type, req_args)

                print("called", req_type, "ret code:", ret_code, 'res:', res)
                if ret_code != 0:
                    print("non-zero retcode:", ret_code)

                self._queue_response(conn, req_type, resTypeByte, res, ret_code)
        except IOError as e:
            print("IOError on reading from connection:", e)
            self.handle_conn_closed(r)
            return
        except Exception:
            logger.exception("error reading from connection")
            self.handle_conn_closed(r)
            return

    @staticmethod
    def _invoke(app, req_type, req_args):
        """Call ``app.<req_type>`` with the decoded args; return (res, ret_code)."""
        req_f = getattr(app, req_type)
        if req_args is None:
            res = req_f()
        elif isinstance(req_args, tuple):
            res = req_f(*req_args)
        else:
            res = req_f(req_args)

        # A tuple result is (res, ret_code); anything else is the ret code.
        if isinstance(res, tuple):
            res, ret_code = res
        else:
            ret_code = res
            res = None
        return res, ret_code

    @staticmethod
    def _queue_response(conn, req_type, res_type_byte, res, ret_code):
        """Append the length-prefixed response for one request to ``conn.resBuf``."""
        if req_type in ("echo", "info"):  # these dont return a ret code
            enc = encode(res)
            # messages are length prefixed
            conn.resBuf.write(encode(len(enc) + 1))
            conn.resBuf.write([res_type_byte])
            conn.resBuf.write(enc)
        else:
            enc, enc_ret = encode(res), encode(ret_code)
            # messages are length prefixed
            conn.resBuf.write(encode(len(enc) + len(enc_ret) + 1))
            conn.resBuf.write([res_type_byte])
            conn.resBuf.write(enc_ret)
            conn.resBuf.write(enc)

    def main_loop(self):
        """Serve connections until ``self.shutdown`` becomes true."""
        while not self.shutdown:
            # Only wait for readability. Nothing here acts on writability,
            # and a connected socket is almost always writable, so passing
            # write_list would make select return immediately every time
            # and turn this loop into a busy spin.
            r_list, _, _ = select.select(self.read_list, [], [], 2.5)

            for r in r_list:
                if self.shutdown:
                    # handle_shutdown closed the remaining sockets.
                    break
                if r is self.listener:
                    try:
                        self.handle_new_connection(r)
                    except NameError as e:
                        print("Could not connect due to NameError:", e)
                    except TypeError as e:
                        print("Could not connect due to TypeError:", e)
                    except Exception:
                        # Not a bare except: KeyboardInterrupt and
                        # SystemExit must still stop the server.
                        print("Could not connect due to unexpected error:", sys.exc_info()[0])
                else:
                    self.handle_recv(r)

    def handle_shutdown(self):
        """Stop ``main_loop`` and close every socket the server tracks."""
        # Set the flag first so the loop stops even if a close fails.
        self.shutdown = True
        # Connection sockets appear in both lists; close each one once.
        for sock in dict.fromkeys(self.read_list + self.write_list):
            try:
                sock.close()
            except Exception:
                logger.exception("error closing socket during shutdown")
|