271 lines
9.3 KiB
Python
271 lines
9.3 KiB
Python
import usocket as socket
|
|
import ubinascii
|
|
import uos
|
|
|
|
class WebSocketError(Exception):
    """Raised by WebSocketClient on protocol failures.

    Examples: the server closed the connection or answered with a non-101
    status during the opening handshake, or ``send()`` was called after the
    connection was closed.
    """
|
|
|
|
class WebSocketClient:
    """Minimal WebSocket (RFC 6455) client over plain ``ws://`` for MicroPython.

    The constructor connects immediately. All socket operations use the
    ``timeout`` given to the constructor, so ``recv()`` can be polled: it
    returns ``None`` when nothing arrived in time.

    Limitations: no TLS (``wss://``), the server's ``Sec-WebSocket-Accept``
    header is not verified, and fragmented messages are not reassembled --
    every frame is returned on its own (continuation frames come back as
    ``bytes``).
    """

    # Frames whose payload exceeds this many bytes are read and discarded
    # instead of buffered (kept small for ESP32-class RAM).
    MAX_MESSAGE_SIZE = 50 * 1024
    # Maximum number of bytes requested from the socket per read while
    # receiving or discarding a payload.
    READ_CHUNK = 1024

    def __init__(self, uri, timeout=5):
        """Create the client and connect to ``uri``.

        Args:
            uri: ``ws://host[:port][/path]``; only the ``ws://`` scheme works.
            timeout: socket timeout in seconds (connect, handshake and reads).

        Raises:
            ValueError, OSError, WebSocketError: see ``connect()``.
        """
        self.sock = None
        self.uri = uri
        self.timeout = timeout
        self.unread_messages = []  # Messages recv() returns before touching the socket
        self.buffer = bytearray(4096)  # Reused receive buffer for small payloads
        self.connect()

    @staticmethod
    def _parse_uri(uri):
        """Split ``ws://host[:port][/path]`` into ``(host, port, path)``.

        The port defaults to 80 and the path to ``"/"``.

        Raises:
            ValueError: if the URI does not use the ``ws://`` scheme or the
                port is not an integer.
        """
        if not uri.startswith("ws://"):
            # Explicit check instead of ``assert`` so it survives optimisation.
            raise ValueError("Unsupported WebSocket URI (expected ws://...): " + uri)
        rest = uri[5:]
        if "/" in rest:
            host, path = rest.split("/", 1)
        else:
            host, path = rest, ""
        path = "/" + path

        if ":" in host:
            host, port = host.split(":")
            port = int(port)
        else:
            port = 80
        return host, port, path

    def connect(self):
        """Open the TCP connection and perform the WebSocket opening handshake.

        Any socket from a previous connection is closed first. If any step
        fails, the new socket is closed and ``self.sock`` is reset to
        ``None``, so ``is_connected()`` stays truthful and nothing leaks.

        Raises:
            ValueError: if ``self.uri`` is not a valid ``ws://`` URI.
            OSError: if address resolution or the TCP connect fails.
            WebSocketError: if the handshake is rejected or cut short.
        """
        host, port, path = self._parse_uri(self.uri)

        # Reconnecting on the same client must not leak the old socket.
        self._close_socket()

        print(f"Connecting to {host}:{port}{path}...")
        self.sock = socket.socket()
        connected = False
        try:
            # Add timeout
            self.sock.settimeout(self.timeout)

            addr_info = socket.getaddrinfo(host, port)
            addr = addr_info[0][-1]
            print(f"Resolved address: {addr}")

            try:
                self.sock.connect(addr)
            except OSError as e:
                print(f"Socket connect failed: {e}")
                if e.args and e.args[0] == 113:  # 113 == EHOSTUNREACH (lwIP/Linux)
                    print("Hint: Check firewall settings on server or if server is running.")
                raise

            self._handshake(host, port, path)
            connected = True
        finally:
            if not connected:
                self._close_socket()

        print("WebSocket connected!")

    def _handshake(self, host, port, path):
        """Send the HTTP Upgrade request and validate the ``101`` response.

        Raises:
            WebSocketError: if the connection closes before the end of the
                response headers, or the status code is not 101.
        """
        # Random key
        key = ubinascii.b2a_base64(uos.urandom(16)).strip()
        req = (
            "GET {} HTTP/1.1\r\n"
            "Host: {}:{}\r\n"
            "Connection: Upgrade\r\n"
            "Upgrade: websocket\r\n"
            "Sec-WebSocket-Key: {}\r\n"
            "Sec-WebSocket-Version: 13\r\n"
            "\r\n"
        ).format(path, host, port, key.decode())
        self.sock.write(req.encode())

        # Read one byte at a time so no frame data following the headers is
        # consumed (there is no buffer to hand it to recv()).
        header = b""
        while b"\r\n\r\n" not in header:
            chunk = self.sock.read(1)
            if not chunk:
                raise WebSocketError("Connection closed during handshake")
            header += chunk

        # Only the status line ("HTTP/1.1 101 Switching Protocols") decides.
        # NOTE: Sec-WebSocket-Accept is not verified.
        status = header.split(b"\r\n", 1)[0].split()
        if len(status) < 2 or status[1] != b"101":
            raise WebSocketError("Handshake failed: " + header.decode())

    def is_connected(self):
        """Return True while a socket is held (the peer is not probed)."""
        return self.sock is not None

    def send(self, data, opcode=1):  # 1=Text, 2=Binary
        """Send ``data`` as one final (FIN), masked frame.

        Args:
            data: ``str`` (sent UTF-8 encoded) or a bytes-like object.
            opcode: 1 = text, 2 = binary; also used internally with
                8 (close) and 10 (pong).

        Raises:
            WebSocketError: if the connection is closed.
        """
        if self.sock is None:
            print("WebSocket is not connected (send called on closed socket)")
            raise WebSocketError("Connection closed")

        if isinstance(data, str):
            data = data.encode('utf-8')
        length = len(data)

        header = bytearray()
        header.append(0x80 | opcode)  # FIN + Opcode
        # Client-to-server frames must be masked: bit 0x80 of the length byte.
        if length < 126:
            header.append(0x80 | length)
        elif length < 65536:
            header.append(0x80 | 126)
            header.extend(length.to_bytes(2, 'big'))
        else:
            header.append(0x80 | 127)
            header.extend(length.to_bytes(8, 'big'))

        mask = uos.urandom(4)
        header.extend(mask)

        masked_data = bytearray(length)
        for i in range(length):
            masked_data[i] = data[i] ^ mask[i % 4]

        self.sock.write(header)
        self.sock.write(masked_data)

    def recv(self):
        """Return the next message, or ``None`` if none is available.

        Queued ``unread_messages`` are returned first. Otherwise frames are
        read from the socket:

        * opcode 1 (text) is returned as ``str``;
        * opcode 2 (binary) and any other data/pong opcode as ``bytes``;
        * a ping is answered with a pong and reading continues;
        * a close frame is answered with a close frame, then the
          connection is closed and ``None`` is returned.

        ``None`` is also returned when the read times out with nothing
        received, when the connection is (or becomes) closed, when an
        oversized or unallocatable frame was discarded, or on an error.
        """
        if self.unread_messages:
            return self.unread_messages.pop(0)

        # A loop (not recursion) so a burst of pings cannot grow the stack.
        while self.sock is not None:
            try:
                frame = self._read_frame()
                if frame is None:
                    return None
                opcode, payload = frame

                if opcode == 1:  # Text
                    return str(payload, 'utf-8')
                if opcode == 2:  # Binary
                    return bytes(payload)  # copy: payload may alias self.buffer
                if opcode == 8:  # Close
                    # RFC 6455 section 5.5.1: answer with a Close frame that
                    # echoes the status code (first two payload bytes), if any.
                    try:
                        self.send(bytes(payload[:2]), opcode=8)
                    except Exception:
                        pass  # best effort: the peer may already be gone
                    self.close()
                    return None
                if opcode == 9:  # Ping
                    self.send(payload, opcode=10)  # Pong with the same payload
                    continue
                return bytes(payload)
            except Exception as e:
                # Don't print timeout errors as they are expected in non-blocking polling
                if "ETIMEDOUT" not in str(e) and "110" not in str(e):
                    print(f"WS Recv Error: {e}")
                return None
        return None

    def _read_frame(self):
        """Read one complete frame from the socket.

        Returns:
            ``(opcode, payload)``, where ``payload`` is an unmasked
            memoryview that may alias ``self.buffer`` (copy it before the
            next read), or ``None`` if no frame can be returned (nothing
            arrived, the frame was discarded, or the connection was dropped).
        """
        header = self._read_exact(2)
        if not header:
            return None  # nothing received yet, or the connection was closed

        opcode = header[0] & 0x0F
        masked = header[1] & 0x80
        length = header[1] & 0x7F

        # From here on part of the frame has been consumed: any failed read
        # leaves the stream mid-frame, so the connection has to be dropped.
        if length == 126 or length == 127:
            ext = self._read_exact(2 if length == 126 else 8)
            if not ext:
                self._abort("Truncated frame length")
                return None
            length = int.from_bytes(ext, 'big')

        # The mask key precedes the payload on the wire.
        mask_key = None
        if masked:
            mask_key = self._read_exact(4)
            if not mask_key:
                self._abort("Truncated mask key")
                return None

        if length > self.MAX_MESSAGE_SIZE:
            print(f"WS Recv: Message too large ({length} bytes)")
            self._skip_bytes(length)  # discard the payload to stay frame-aligned
            return None

        if length <= len(self.buffer):
            data = self.buffer  # reuse the preallocated buffer for small frames
        else:
            try:
                data = bytearray(length)
            except MemoryError:
                print(f"WS Recv: Memory allocation failed for {length} bytes")
                self._skip_bytes(length)
                return None

        payload = memoryview(data)[:length]
        if not self._read_into(payload):
            return None

        if mask_key:
            # Servers must not mask (RFC 6455), but unmask in place if one does.
            for i in range(length):
                payload[i] = payload[i] ^ mask_key[i % 4]
        return opcode, payload

    def _read_exact(self, n):
        """Read exactly ``n`` bytes from the socket.

        Returns the bytes, or ``None`` if they could not all be read.

        A timeout/error or "no data" result before any byte arrived simply
        returns ``None`` (normal while polling). If the peer closed the
        connection (empty read), or reading stopped after some bytes were
        already consumed, the stream is unusable and the connection is
        dropped via ``_abort()`` before returning ``None``.
        """
        data = b''
        while len(data) < n:
            try:
                chunk = self.sock.read(n - len(data))
            except Exception as e:
                if data:
                    self._abort(f"Socket read error: {e}")
                return None
            if chunk == b'':
                # An empty read means the peer closed the connection.
                self._abort("Connection closed by peer")
                return None
            if not chunk:
                # None: a non-blocking socket had no data available.
                if data:
                    self._abort("Incomplete read")
                return None
            data += chunk
        return data

    def _read_into(self, view):
        """Fill ``view`` completely, requesting at most READ_CHUNK bytes per call.

        Returns True on success. On a timeout, error or EOF the partially
        read frame cannot be recovered, so the connection is dropped and
        False is returned.
        """
        total = len(view)
        pos = 0
        while pos < total:
            end = min(total, pos + self.READ_CHUNK)
            try:
                count = self.sock.readinto(view[pos:end])
            except Exception as e:
                self._abort(f"WS Recv read error: {e}")
                return False
            if not count:
                self._abort("WS Recv: Incomplete payload read")
                return False
            pos += count
        return True

    def _skip_bytes(self, length):
        """Read and discard ``length`` payload bytes to stay frame-aligned.

        ``self.buffer`` is used as scratch space, so no payload-sized
        allocation happens (this path also runs after a MemoryError). Short
        reads are accounted for. Returns True if everything was skipped;
        otherwise the connection is dropped and False is returned.
        """
        if self.sock is None:
            return False
        scratch = memoryview(self.buffer)
        step_max = min(len(scratch), self.READ_CHUNK)
        remaining = length
        while remaining > 0:
            step = min(remaining, step_max)
            try:
                count = self.sock.readinto(scratch[:step])
            except Exception as e:
                self._abort(f"Socket read error while skipping payload: {e}")
                return False
            if not count:
                self._abort("Connection lost while skipping payload")
                return False
            remaining -= count  # count may be smaller than step
        return True

    def _abort(self, reason):
        """Drop the connection because the byte stream can no longer be used.

        Bytes already read for an incomplete frame cannot be pushed back, so
        the next read would start mid-frame. Closing the socket makes
        ``is_connected()`` return False so the caller can reconnect.
        ``unread_messages`` is left intact so queued messages can still be
        drained.
        """
        if self.sock is None:
            return
        print(f"WS: {reason}; closing connection")
        self._close_socket()

    def _close_socket(self):
        """Close and forget the socket; always leaves ``self.sock`` as None."""
        sock = self.sock
        self.sock = None
        if sock is not None:
            try:
                sock.close()
            except OSError:
                pass  # socket already broken; nothing more to release

    def close(self):
        """Close the connection (if open) and discard queued messages."""
        self._close_socket()
        self.unread_messages = []
|