un-couple messaging between connections; use non-blocking io

This commit is contained in:
Jakob Ketterl
2019-09-21 22:10:16 +02:00
parent 6ec85aa349
commit 1ed69de5b0
2 changed files with 92 additions and 47 deletions

View File

@ -3,6 +3,7 @@ import hashlib
import json
import os
import select
import threading
import logging
@ -16,9 +17,11 @@ class IncompleteRead(Exception):
class WebSocketConnection(object):
def __init__(self, handler, messageHandler):
self.handler = handler
self.handler.connection.setblocking(0)
self.messageHandler = messageHandler
(self.interruptPipeRecv, self.interruptPipeSend) = os.pipe()
self.open = True
self.sendLock = threading.Lock()
my_headers = self.handler.headers.items()
my_header_keys = list(map(lambda x: x[0], my_headers))
h_key_exists = lambda x: my_header_keys.count(x)
@ -79,53 +82,80 @@ class WebSocketConnection(object):
else:
header = self.get_header(len(data), 2)
data_to_send = header + data
written = self.handler.wfile.write(data_to_send)
if written != len(data_to_send):
logger.error("incomplete write! closing socket!")
self.close()
else:
self.handler.wfile.flush()
def chunks(l, n):
"""Yield successive n-sized chunks from l."""
for i in range(0, len(l), n):
yield l[i : i + n]
with self.sendLock:
for chunk in chunks(data_to_send, 1024):
(_, write, _) = select.select([], [self.handler.wfile], [], 10)
if self.handler.wfile in write:
written = self.handler.wfile.write(chunk)
if written != len(chunk):
logger.error("incomplete write! closing socket!")
self.close()
else:
logger.debug("socket not returned from select; closing")
self.close()
def protected_read(self, num):
data = self.handler.rfile.read(num)
if len(data) != num:
if data is None or len(data) != num:
raise IncompleteRead()
return data
def protected_send(self, data):
try:
self.send(data)
# these exception happen when the socket is closed
except OSError:
logger.exception("OSError while writing data")
self.close()
except ValueError:
logger.exception("ValueError while writing data")
self.close()
def interrupt(self):
os.write(self.interruptPipeSend, bytes(0x00))
def read_loop(self):
self.open = True
while self.open:
try:
(read, _, _) = select.select([self.interruptPipeRecv, self.handler.rfile], [], [])
if read[0] == self.handler.rfile:
header = self.protected_read(2)
opcode = header[0] & 0x0F
length = header[1] & 0x7F
mask = (header[1] & 0x80) >> 7
if length == 126:
(read, _, _) = select.select([self.interruptPipeRecv, self.handler.rfile], [], [])
if self.handler.rfile in read:
available = True
while available:
try:
header = self.protected_read(2)
length = (header[0] << 8) + header[1]
if mask:
masking_key = self.protected_read(4)
data = self.protected_read(length)
if mask:
data = bytes([b ^ masking_key[index % 4] for (index, b) in enumerate(data)])
if opcode == 1:
message = data.decode("utf-8")
self.messageHandler.handleTextMessage(self, message)
elif opcode == 2:
self.messageHandler.handleBinaryMessage(self, data)
elif opcode == 8:
logger.debug("websocket close frame received; closing connection")
self.open = False
else:
logger.warning("unsupported opcode: {0}".format(opcode))
except IncompleteRead:
logger.warning("incomplete websocket read; closing socket")
self.open = False
opcode = header[0] & 0x0F
length = header[1] & 0x7F
mask = (header[1] & 0x80) >> 7
if length == 126:
header = self.protected_read(2)
length = (header[0] << 8) + header[1]
if mask:
masking_key = self.protected_read(4)
data = self.protected_read(length)
if mask:
data = bytes([b ^ masking_key[index % 4] for (index, b) in enumerate(data)])
if opcode == 1:
message = data.decode("utf-8")
self.messageHandler.handleTextMessage(self, message)
elif opcode == 2:
self.messageHandler.handleBinaryMessage(self, data)
elif opcode == 8:
logger.debug("websocket close frame received; closing connection")
self.open = False
else:
logger.warning("unsupported opcode: {0}".format(opcode))
except IncompleteRead:
available = False
logger.debug("websocket loop ended; shutting down")
self.messageHandler.handleClose()
logger.debug("websocket loop ended; sending close frame")