You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

202 lines
7.0 KiB

9 years ago
  1. import socket
  2. import select
  3. import sys
  4. from wire import decode_varint, encode
  5. from reader import BytesBuffer
  6. from msg import RequestDecoder, message_types
  7. # hold the asyncronous state of a connection
  8. # ie. we may not get enough bytes on one read to decode the message
  9. class Connection():
  10. def __init__(self, fd, app):
  11. self.fd = fd
  12. self.app = app
  13. self.recBuf = BytesBuffer(bytearray())
  14. self.resBuf = BytesBuffer(bytearray())
  15. self.msgLength = 0
  16. self.decoder = RequestDecoder(self.recBuf)
  17. self.inProgress = False # are we in the middle of a message
  18. def recv(this):
  19. data = this.fd.recv(1024)
  20. if not data: # what about len(data) == 0
  21. raise IOError("dead connection")
  22. this.recBuf.write(data)
  23. # TMSP server responds to messges by calling methods on the app
  24. class TMSPServer():
  25. def __init__(self, app, port=5410):
  26. self.app = app
  27. # map conn file descriptors to (app, reqBuf, resBuf, msgDecoder)
  28. self.appMap = {}
  29. self.port = port
  30. self.listen_backlog = 10
  31. self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  32. self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  33. self.listener.setblocking(0)
  34. self.listener.bind(('', port))
  35. self.listener.listen(self.listen_backlog)
  36. self.shutdown = False
  37. self.read_list = [self.listener]
  38. self.write_list = []
  39. def handle_new_connection(self, r):
  40. new_fd, new_addr = r.accept()
  41. new_fd.setblocking(0) # non-blocking
  42. self.read_list.append(new_fd)
  43. self.write_list.append(new_fd)
  44. print 'new connection to', new_addr
  45. self.appMap[new_fd] = Connection(new_fd, self.app)
  46. def handle_conn_closed(self, r):
  47. self.read_list.remove(r)
  48. self.write_list.remove(r)
  49. r.close()
  50. print "connection closed"
  51. def handle_recv(self, r):
  52. # app, recBuf, resBuf, conn
  53. conn = self.appMap[r]
  54. while True:
  55. try:
  56. print "recv loop"
  57. # check if we need more data first
  58. if conn.inProgress:
  59. if (conn.msgLength == 0 or conn.recBuf.size() < conn.msgLength):
  60. conn.recv()
  61. else:
  62. if conn.recBuf.size() == 0:
  63. conn.recv()
  64. conn.inProgress = True
  65. # see if we have enough to get the message length
  66. if conn.msgLength == 0:
  67. ll = conn.recBuf.peek()
  68. if conn.recBuf.size() < 1 + ll:
  69. # we don't have enough bytes to read the length yet
  70. return
  71. print "decoding msg length"
  72. conn.msgLength = decode_varint(conn.recBuf)
  73. # see if we have enough to decode the message
  74. if conn.recBuf.size() < conn.msgLength:
  75. return
  76. # now we can decode the message
  77. # first read the request type and get the particular msg
  78. # decoder
  79. typeByte = conn.recBuf.read(1)
  80. typeByte = int(typeByte[0])
  81. resTypeByte = typeByte + 0x10
  82. req_type = message_types[typeByte]
  83. if req_type == "flush":
  84. # messages are length prefixed
  85. conn.resBuf.write(encode(1))
  86. conn.resBuf.write([resTypeByte])
  87. conn.fd.send(str(conn.resBuf.buf))
  88. conn.msgLength = 0
  89. conn.inProgress = False
  90. conn.resBuf = BytesBuffer(bytearray())
  91. return
  92. decoder = getattr(conn.decoder, req_type)
  93. print "decoding args"
  94. req_args = decoder()
  95. print "got args", req_args
  96. # done decoding message
  97. conn.msgLength = 0
  98. conn.inProgress = False
  99. req_f = getattr(conn.app, req_type)
  100. if req_args is None:
  101. res = req_f()
  102. elif isinstance(req_args, tuple):
  103. res = req_f(*req_args)
  104. else:
  105. res = req_f(req_args)
  106. if isinstance(res, tuple):
  107. res, ret_code = res
  108. else:
  109. ret_code = res
  110. res = None
  111. print "called", req_type, "ret code:", ret_code
  112. if ret_code != 0:
  113. print "non-zero retcode:", ret_code
  114. if req_type in ("echo", "info"): # these dont return a ret code
  115. enc = encode(res)
  116. # messages are length prefixed
  117. conn.resBuf.write(encode(len(enc) + 1))
  118. conn.resBuf.write([resTypeByte])
  119. conn.resBuf.write(enc)
  120. else:
  121. enc, encRet = encode(res), encode(ret_code)
  122. # messages are length prefixed
  123. conn.resBuf.write(encode(len(enc) + len(encRet) + 1))
  124. conn.resBuf.write([resTypeByte])
  125. conn.resBuf.write(encRet)
  126. conn.resBuf.write(enc)
  127. except TypeError as e:
  128. print "TypeError on reading from connection:", e
  129. self.handle_conn_closed(r)
  130. return
  131. except ValueError as e:
  132. print "ValueError on reading from connection:", e
  133. self.handle_conn_closed(r)
  134. return
  135. except IOError as e:
  136. print "IOError on reading from connection:", e
  137. self.handle_conn_closed(r)
  138. return
  139. except Exception as e:
  140. # sys.exc_info()[0] # TODO better
  141. print "error reading from connection", str(e)
  142. self.handle_conn_closed(r)
  143. return
  144. def main_loop(self):
  145. while not self.shutdown:
  146. r_list, w_list, _ = select.select(
  147. self.read_list, self.write_list, [], 2.5)
  148. for r in r_list:
  149. if (r == self.listener):
  150. try:
  151. self.handle_new_connection(r)
  152. # undo adding to read list ...
  153. except NameError as e:
  154. print "Could not connect due to NameError:", e
  155. except TypeError as e:
  156. print "Could not connect due to TypeError:", e
  157. except:
  158. print "Could not connect due to unexpected error:", sys.exc_info()[0]
  159. else:
  160. self.handle_recv(r)
  161. def handle_shutdown(self):
  162. for r in self.read_list:
  163. r.close()
  164. for w in self.write_list:
  165. try:
  166. w.close()
  167. except Exception as e:
  168. print(e) # TODO: add logging
  169. self.shutdown = True