add ping / pong to keep the websockets running

This commit is contained in:
Jakob Ketterl 2019-09-26 22:57:10 +02:00
parent 2c4add6aad
commit 76fe11741a

View File

@ -9,6 +9,12 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OPCODE_TEXT_MESSAGE = 0x01
OPCODE_BINARY_MESSAGE = 0x02
OPCODE_CLOSE = 0x08
OPCODE_PING = 0x09
OPCODE_PONG = 0x0A
class IncompleteRead(Exception): class IncompleteRead(Exception):
pass pass
@ -55,6 +61,8 @@ class WebSocketConnection(object):
ws_key_toreturn.decode() ws_key_toreturn.decode()
).encode() ).encode()
) )
self.pingTimer = None
self.resetPing()
def get_header(self, size, opcode): def get_header(self, size, opcode):
ws_first_byte = 0b10000000 | (opcode & 0x0F) ws_first_byte = 0b10000000 | (opcode & 0x0F)
@ -90,13 +98,17 @@ class WebSocketConnection(object):
# string-type messages are sent as text frames # string-type messages are sent as text frames
if type(data) == str: if type(data) == str:
header = self.get_header(len(data), 1) header = self.get_header(len(data), OPCODE_TEXT_MESSAGE)
data_to_send = header + data.encode("utf-8") data_to_send = header + data.encode("utf-8")
# anything else as binary # anything else as binary
else: else:
header = self.get_header(len(data), 2) header = self.get_header(len(data), OPCODE_BINARY_MESSAGE)
data_to_send = header + data data_to_send = header + data
self._sendBytes(data_to_send)
def _sendBytes(self, data_to_send):
def chunks(l, n): def chunks(l, n):
"""Yield successive n-sized chunks from l.""" """Yield successive n-sized chunks from l."""
for i in range(0, len(l), n): for i in range(0, len(l), n):
@ -122,7 +134,11 @@ class WebSocketConnection(object):
logger.exception("ValueError while writing data") logger.exception("ValueError while writing data")
self.close() self.close()
def protected_read(self, num): def interrupt(self):
self.interruptPipeSend.send(bytes(0x00))
def read_loop(self):
def protected_read(num):
data = self.handler.rfile.read(num) data = self.handler.rfile.read(num)
if data is None: if data is None:
raise Drained() raise Drained()
@ -130,36 +146,38 @@ class WebSocketConnection(object):
raise IncompleteRead() raise IncompleteRead()
return data return data
def interrupt(self):
self.interruptPipeSend.send(bytes(0x00))
def read_loop(self):
WebSocketConnection.connections.append(self) WebSocketConnection.connections.append(self)
self.open = True self.open = True
while self.open: while self.open:
(read, _, _) = select.select([self.interruptPipeRecv, self.handler.rfile], [], []) (read, _, _) = select.select([self.interruptPipeRecv, self.handler.rfile], [], [])
if self.handler.rfile in read: if self.handler.rfile in read:
available = True available = True
self.resetPing()
while self.open and available: while self.open and available:
try: try:
header = self.protected_read(2) header = protected_read(2)
opcode = header[0] & 0x0F opcode = header[0] & 0x0F
length = header[1] & 0x7F length = header[1] & 0x7F
mask = (header[1] & 0x80) >> 7 mask = (header[1] & 0x80) >> 7
if length == 126: if length == 126:
header = self.protected_read(2) header = protected_read(2)
length = (header[0] << 8) + header[1] length = (header[0] << 8) + header[1]
if mask: if mask:
masking_key = self.protected_read(4) masking_key = protected_read(4)
data = self.protected_read(length) data = protected_read(length)
if mask: if mask:
data = bytes([b ^ masking_key[index % 4] for (index, b) in enumerate(data)]) data = bytes([b ^ masking_key[index % 4] for (index, b) in enumerate(data)])
if opcode == 1: if opcode == OPCODE_TEXT_MESSAGE:
message = data.decode("utf-8") message = data.decode("utf-8")
self.messageHandler.handleTextMessage(self, message) self.messageHandler.handleTextMessage(self, message)
elif opcode == 2: elif opcode == OPCODE_BINARY_MESSAGE:
self.messageHandler.handleBinaryMessage(self, data) self.messageHandler.handleBinaryMessage(self, data)
elif opcode == 8: elif opcode == OPCODE_PING:
self.sendPong()
elif opcode == OPCODE_PONG:
# since every read resets the ping timer, there's nothing to do here.
pass
elif opcode == OPCODE_CLOSE:
logger.debug("websocket close frame received; closing connection") logger.debug("websocket close frame received; closing connection")
self.open = False self.open = False
else: else:
@ -176,17 +194,12 @@ class WebSocketConnection(object):
logger.debug("websocket loop ended; shutting down") logger.debug("websocket loop ended; shutting down")
self.messageHandler.handleClose() self.messageHandler.handleClose()
self.cancelPing()
logger.debug("websocket loop ended; sending close frame") logger.debug("websocket loop ended; sending close frame")
try: header = self.get_header(0, OPCODE_CLOSE)
header = self.get_header(0, 8) self._sendBytes(header)
self.handler.wfile.write(header)
self.handler.wfile.flush()
except ValueError:
logger.exception("ValueError while writing close frame:")
except OSError:
logger.exception("OSError while writing close frame:")
try: try:
WebSocketConnection.connections.remove(self) WebSocketConnection.connections.remove(self)
@ -197,6 +210,24 @@ class WebSocketConnection(object):
self.open = False self.open = False
self.interrupt() self.interrupt()
def cancelPing(self):
if self.pingTimer:
self.pingTimer.cancel()
def resetPing(self):
self.cancelPing()
self.pingTimer = threading.Timer(30, self.sendPing)
self.pingTimer.start()
def sendPing(self):
header = self.get_header(0, OPCODE_PING)
self._sendBytes(header)
self.resetPing()
def sendPong(self):
header = self.get_header(0, OPCODE_PONG)
self._sendBytes(header)
class WebSocketException(Exception): class WebSocketException(Exception):
pass pass