close websocket connections in an improved way
This commit is contained in:
@ -1,16 +1,24 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import select
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IncompleteRead(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class WebSocketConnection(object):
|
||||
def __init__(self, handler, messageHandler):
|
||||
self.handler = handler
|
||||
self.messageHandler = messageHandler
|
||||
(self.interruptPipeRecv, self.interruptPipeSend) = os.pipe()
|
||||
self.open = True
|
||||
my_headers = self.handler.headers.items()
|
||||
my_header_keys = list(map(lambda x: x[0], my_headers))
|
||||
h_key_exists = lambda x: my_header_keys.count(x)
|
||||
@ -78,33 +86,49 @@ class WebSocketConnection(object):
|
||||
else:
|
||||
self.handler.wfile.flush()
|
||||
|
||||
def read_loop(self):
|
||||
open = True
|
||||
while open:
|
||||
header = self.handler.rfile.read(2)
|
||||
opcode = header[0] & 0x0F
|
||||
length = header[1] & 0x7F
|
||||
mask = (header[1] & 0x80) >> 7
|
||||
if length == 126:
|
||||
header = self.handler.rfile.read(2)
|
||||
length = (header[0] << 8) + header[1]
|
||||
if mask:
|
||||
masking_key = self.handler.rfile.read(4)
|
||||
data = self.handler.rfile.read(length)
|
||||
if mask:
|
||||
data = bytes([b ^ masking_key[index % 4] for (index, b) in enumerate(data)])
|
||||
if opcode == 1:
|
||||
message = data.decode("utf-8")
|
||||
self.messageHandler.handleTextMessage(self, message)
|
||||
elif opcode == 2:
|
||||
self.messageHandler.handleBinaryMessage(self, data)
|
||||
elif opcode == 8:
|
||||
open = False
|
||||
self.messageHandler.handleClose(self)
|
||||
else:
|
||||
logger.warning("unsupported opcode: {0}".format(opcode))
|
||||
def protected_read(self, num):
|
||||
data = self.handler.rfile.read(num)
|
||||
if len(data) != num:
|
||||
raise IncompleteRead()
|
||||
return data
|
||||
|
||||
def interrupt(self):
|
||||
os.write(self.interruptPipeSend, bytes(0x00))
|
||||
|
||||
def read_loop(self):
|
||||
self.open = True
|
||||
while self.open:
|
||||
try:
|
||||
(read, _, _) = select.select([self.interruptPipeRecv, self.handler.rfile], [], [])
|
||||
if read[0] == self.handler.rfile:
|
||||
header = self.protected_read(2)
|
||||
opcode = header[0] & 0x0F
|
||||
length = header[1] & 0x7F
|
||||
mask = (header[1] & 0x80) >> 7
|
||||
if length == 126:
|
||||
header = self.protected_read(2)
|
||||
length = (header[0] << 8) + header[1]
|
||||
if mask:
|
||||
masking_key = self.protected_read(4)
|
||||
data = self.protected_read(length)
|
||||
if mask:
|
||||
data = bytes([b ^ masking_key[index % 4] for (index, b) in enumerate(data)])
|
||||
if opcode == 1:
|
||||
message = data.decode("utf-8")
|
||||
self.messageHandler.handleTextMessage(self, message)
|
||||
elif opcode == 2:
|
||||
self.messageHandler.handleBinaryMessage(self, data)
|
||||
elif opcode == 8:
|
||||
logger.debug("websocket close frame received; closing connection")
|
||||
self.open = False
|
||||
else:
|
||||
logger.warning("unsupported opcode: {0}".format(opcode))
|
||||
except IncompleteRead:
|
||||
logger.warning("incomplete websocket read; closing socket")
|
||||
self.open = False
|
||||
|
||||
logger.debug("websocket loop ended; sending close frame")
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
header = self.get_header(0, 8)
|
||||
self.handler.wfile.write(header)
|
||||
@ -114,11 +138,9 @@ class WebSocketConnection(object):
|
||||
except OSError:
|
||||
logger.exception("OSError while writing close frame:")
|
||||
|
||||
try:
|
||||
self.handler.finish()
|
||||
self.handler.connection.close()
|
||||
except Exception:
|
||||
logger.exception("while closing connection:")
|
||||
def close(self):
|
||||
self.open = False
|
||||
self.interrupt()
|
||||
|
||||
|
||||
class WebSocketException(Exception):
|
||||
|
Reference in New Issue
Block a user