add ping / pong to keep the websockets running

This commit is contained in:
Jakob Ketterl 2019-09-26 22:57:10 +02:00
parent 2c4add6aad
commit 76fe11741a
1 changed files with 56 additions and 25 deletions

View File

@ -9,6 +9,12 @@ import logging
logger = logging.getLogger(__name__)
OPCODE_TEXT_MESSAGE = 0x01
OPCODE_BINARY_MESSAGE = 0x02
OPCODE_CLOSE = 0x08
OPCODE_PING = 0x09
OPCODE_PONG = 0x0A
class IncompleteRead(Exception):
pass
@ -55,6 +61,8 @@ class WebSocketConnection(object):
ws_key_toreturn.decode()
).encode()
)
self.pingTimer = None
self.resetPing()
def get_header(self, size, opcode):
ws_first_byte = 0b10000000 | (opcode & 0x0F)
@ -90,13 +98,17 @@ class WebSocketConnection(object):
# string-type messages are sent as text frames
if type(data) == str:
header = self.get_header(len(data), 1)
header = self.get_header(len(data), OPCODE_TEXT_MESSAGE)
data_to_send = header + data.encode("utf-8")
# anything else as binary
else:
header = self.get_header(len(data), 2)
header = self.get_header(len(data), OPCODE_BINARY_MESSAGE)
data_to_send = header + data
self._sendBytes(data_to_send)
def _sendBytes(self, data_to_send):
def chunks(l, n):
"""Yield successive n-sized chunks from l."""
for i in range(0, len(l), n):
@ -122,44 +134,50 @@ class WebSocketConnection(object):
logger.exception("ValueError while writing data")
self.close()
def protected_read(self, num):
data = self.handler.rfile.read(num)
if data is None:
raise Drained()
if len(data) != num:
raise IncompleteRead()
return data
def interrupt(self):
self.interruptPipeSend.send(bytes(0x00))
def read_loop(self):
def protected_read(num):
data = self.handler.rfile.read(num)
if data is None:
raise Drained()
if len(data) != num:
raise IncompleteRead()
return data
WebSocketConnection.connections.append(self)
self.open = True
while self.open:
(read, _, _) = select.select([self.interruptPipeRecv, self.handler.rfile], [], [])
if self.handler.rfile in read:
available = True
self.resetPing()
while self.open and available:
try:
header = self.protected_read(2)
header = protected_read(2)
opcode = header[0] & 0x0F
length = header[1] & 0x7F
mask = (header[1] & 0x80) >> 7
if length == 126:
header = self.protected_read(2)
header = protected_read(2)
length = (header[0] << 8) + header[1]
if mask:
masking_key = self.protected_read(4)
data = self.protected_read(length)
masking_key = protected_read(4)
data = protected_read(length)
if mask:
data = bytes([b ^ masking_key[index % 4] for (index, b) in enumerate(data)])
if opcode == 1:
if opcode == OPCODE_TEXT_MESSAGE:
message = data.decode("utf-8")
self.messageHandler.handleTextMessage(self, message)
elif opcode == 2:
elif opcode == OPCODE_BINARY_MESSAGE:
self.messageHandler.handleBinaryMessage(self, data)
elif opcode == 8:
elif opcode == OPCODE_PING:
self.sendPong()
elif opcode == OPCODE_PONG:
# since every read resets the ping timer, there's nothing to do here.
pass
elif opcode == OPCODE_CLOSE:
logger.debug("websocket close frame received; closing connection")
self.open = False
else:
@ -176,17 +194,12 @@ class WebSocketConnection(object):
logger.debug("websocket loop ended; shutting down")
self.messageHandler.handleClose()
self.cancelPing()
logger.debug("websocket loop ended; sending close frame")
try:
header = self.get_header(0, 8)
self.handler.wfile.write(header)
self.handler.wfile.flush()
except ValueError:
logger.exception("ValueError while writing close frame:")
except OSError:
logger.exception("OSError while writing close frame:")
header = self.get_header(0, OPCODE_CLOSE)
self._sendBytes(header)
try:
WebSocketConnection.connections.remove(self)
@ -197,6 +210,24 @@ class WebSocketConnection(object):
self.open = False
self.interrupt()
def cancelPing(self):
if self.pingTimer:
self.pingTimer.cancel()
def resetPing(self):
self.cancelPing()
self.pingTimer = threading.Timer(30, self.sendPing)
self.pingTimer.start()
def sendPing(self):
header = self.get_header(0, OPCODE_PING)
self._sendBytes(header)
self.resetPing()
def sendPong(self):
header = self.get_header(0, OPCODE_PONG)
self._sendBytes(header)
class WebSocketException(Exception):
pass