You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

196 lines
6.7 KiB

9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
9 years ago
  1. import socket
  2. import select
  3. import sys
  4. import logging
  5. from .wire import decode_varint, encode
  6. from .reader import BytesBuffer
  7. from .msg import RequestDecoder, message_types
  8. # hold the asynchronous state of a connection
  9. # i.e. we may not get enough bytes on one read to decode the message
  10. logger = logging.getLogger(__name__)
  11. class Connection():
  12. def __init__(self, fd, app):
  13. self.fd = fd
  14. self.app = app
  15. self.recBuf = BytesBuffer(bytearray())
  16. self.resBuf = BytesBuffer(bytearray())
  17. self.msgLength = 0
  18. self.decoder = RequestDecoder(self.recBuf)
  19. self.inProgress = False # are we in the middle of a message
  20. def recv(this):
  21. data = this.fd.recv(1024)
  22. if not data: # what about len(data) == 0
  23. raise IOError("dead connection")
  24. this.recBuf.write(data)
  25. # TMSP server responds to messages by calling methods on the app
  26. class TMSPServer():
  27. def __init__(self, app, port=5410):
  28. self.app = app
  29. # map conn file descriptors to (app, reqBuf, resBuf, msgDecoder)
  30. self.appMap = {}
  31. self.port = port
  32. self.listen_backlog = 10
  33. self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  34. self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  35. self.listener.setblocking(0)
  36. self.listener.bind(('', port))
  37. self.listener.listen(self.listen_backlog)
  38. self.shutdown = False
  39. self.read_list = [self.listener]
  40. self.write_list = []
  41. def handle_new_connection(self, r):
  42. new_fd, new_addr = r.accept()
  43. new_fd.setblocking(0) # non-blocking
  44. self.read_list.append(new_fd)
  45. self.write_list.append(new_fd)
  46. print('new connection to', new_addr)
  47. self.appMap[new_fd] = Connection(new_fd, self.app)
  48. def handle_conn_closed(self, r):
  49. self.read_list.remove(r)
  50. self.write_list.remove(r)
  51. r.close()
  52. print("connection closed")
  53. def handle_recv(self, r):
  54. # app, recBuf, resBuf, conn
  55. conn = self.appMap[r]
  56. while True:
  57. try:
  58. print("recv loop")
  59. # check if we need more data first
  60. if conn.inProgress:
  61. if (conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength):
  62. conn.recv()
  63. else:
  64. if conn.recBuf.size() == 0:
  65. conn.recv()
  66. conn.inProgress = True
  67. # see if we have enough to get the message length
  68. if conn.msgLength == 0:
  69. ll = conn.recBuf.peek()
  70. if conn.recBuf.size() < 1 + ll:
  71. # we don't have enough bytes to read the length yet
  72. return
  73. print("decoding msg length")
  74. conn.msgLength = decode_varint(conn.recBuf)
  75. # see if we have enough to decode the message
  76. if conn.recBuf.size() < conn.msgLength:
  77. return
  78. # now we can decode the message
  79. # first read the request type and get the particular msg
  80. # decoder
  81. typeByte = conn.recBuf.read(1)
  82. typeByte = int(typeByte[0])
  83. resTypeByte = typeByte + 0x10
  84. req_type = message_types[typeByte]
  85. if req_type == "flush":
  86. # messages are length prefixed
  87. conn.resBuf.write(encode(1))
  88. conn.resBuf.write([resTypeByte])
  89. conn.fd.send(conn.resBuf.buf)
  90. conn.msgLength = 0
  91. conn.inProgress = False
  92. conn.resBuf = BytesBuffer(bytearray())
  93. return
  94. decoder = getattr(conn.decoder, req_type)
  95. print("decoding args")
  96. req_args = decoder()
  97. print("got args", req_args)
  98. # done decoding message
  99. conn.msgLength = 0
  100. conn.inProgress = False
  101. req_f = getattr(conn.app, req_type)
  102. if req_args is None:
  103. res = req_f()
  104. elif isinstance(req_args, tuple):
  105. res = req_f(*req_args)
  106. else:
  107. res = req_f(req_args)
  108. if isinstance(res, tuple):
  109. res, ret_code = res
  110. else:
  111. ret_code = res
  112. res = None
  113. print("called", req_type, "ret code:", ret_code, 'res:', res)
  114. if ret_code != 0:
  115. print("non-zero retcode:", ret_code)
  116. if req_type in ("echo", "info"): # these dont return a ret code
  117. enc = encode(res)
  118. # messages are length prefixed
  119. conn.resBuf.write(encode(len(enc) + 1))
  120. conn.resBuf.write([resTypeByte])
  121. conn.resBuf.write(enc)
  122. else:
  123. enc, encRet = encode(res), encode(ret_code)
  124. # messages are length prefixed
  125. conn.resBuf.write(encode(len(enc) + len(encRet) + 1))
  126. conn.resBuf.write([resTypeByte])
  127. conn.resBuf.write(encRet)
  128. conn.resBuf.write(enc)
  129. except IOError as e:
  130. print("IOError on reading from connection:", e)
  131. self.handle_conn_closed(r)
  132. return
  133. except Exception as e:
  134. logger.exception("error reading from connection")
  135. self.handle_conn_closed(r)
  136. return
  137. def main_loop(self):
  138. while not self.shutdown:
  139. r_list, w_list, _ = select.select(
  140. self.read_list, self.write_list, [], 2.5)
  141. for r in r_list:
  142. if (r == self.listener):
  143. try:
  144. self.handle_new_connection(r)
  145. # undo adding to read list ...
  146. except NameError as e:
  147. print("Could not connect due to NameError:", e)
  148. except TypeError as e:
  149. print("Could not connect due to TypeError:", e)
  150. except:
  151. print("Could not connect due to unexpected error:", sys.exc_info()[0])
  152. else:
  153. self.handle_recv(r)
  154. def handle_shutdown(self):
  155. for r in self.read_list:
  156. r.close()
  157. for w in self.write_list:
  158. try:
  159. w.close()
  160. except Exception as e:
  161. print(e) # TODO: add logging
  162. self.shutdown = True