t
This commit is contained in:
178
websocket_client.py
Normal file
178
websocket_client.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import usocket as socket
|
||||
import ubinascii
|
||||
import uos
|
||||
|
||||
class WebSocketError(Exception):
|
||||
pass
|
||||
|
||||
class WebSocketClient:
|
||||
def __init__(self, uri, timeout=5):
|
||||
self.sock = None
|
||||
self.uri = uri
|
||||
self.timeout = timeout
|
||||
self.connect()
|
||||
|
||||
def connect(self):
|
||||
uri = self.uri
|
||||
assert uri.startswith("ws://")
|
||||
|
||||
uri = uri[5:]
|
||||
if "/" in uri:
|
||||
host, path = uri.split("/", 1)
|
||||
else:
|
||||
host, path = uri, ""
|
||||
path = "/" + path
|
||||
|
||||
if ":" in host:
|
||||
host, port = host.split(":")
|
||||
port = int(port)
|
||||
else:
|
||||
port = 80
|
||||
|
||||
print(f"Connecting to {host}:{port}{path}...")
|
||||
self.sock = socket.socket()
|
||||
|
||||
# Add timeout
|
||||
self.sock.settimeout(self.timeout)
|
||||
|
||||
addr_info = socket.getaddrinfo(host, port)
|
||||
addr = addr_info[0][-1]
|
||||
print(f"Resolved address: {addr}")
|
||||
|
||||
try:
|
||||
self.sock.connect(addr)
|
||||
except OSError as e:
|
||||
print(f"Socket connect failed: {e}")
|
||||
if e.args[0] == 113:
|
||||
print("Hint: Check firewall settings on server or if server is running.")
|
||||
raise
|
||||
|
||||
# Random key
|
||||
key = ubinascii.b2a_base64(uos.urandom(16)).strip()
|
||||
|
||||
|
||||
req = "GET {} HTTP/1.1\r\n".format(path)
|
||||
req += "Host: {}:{}\r\n".format(host, port)
|
||||
req += "Connection: Upgrade\r\n"
|
||||
req += "Upgrade: websocket\r\n"
|
||||
req += "Sec-WebSocket-Key: {}\r\n".format(key.decode())
|
||||
req += "Sec-WebSocket-Version: 13\r\n"
|
||||
req += "\r\n"
|
||||
|
||||
self.sock.write(req.encode())
|
||||
|
||||
# Read handshake response
|
||||
header = b""
|
||||
while b"\r\n\r\n" not in header:
|
||||
chunk = self.sock.read(1)
|
||||
if not chunk:
|
||||
raise WebSocketError("Connection closed during handshake")
|
||||
header += chunk
|
||||
|
||||
if b" 101 " not in header:
|
||||
raise WebSocketError("Handshake failed: " + header.decode())
|
||||
|
||||
print("WebSocket connected!")
|
||||
|
||||
def is_connected(self):
|
||||
return self.sock is not None
|
||||
|
||||
def send(self, data, opcode=1): # 1=Text, 2=Binary
|
||||
if not self.sock:
|
||||
print("WebSocket is not connected (send called on closed socket)")
|
||||
raise WebSocketError("Connection closed")
|
||||
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
header = bytearray()
|
||||
header.append(0x80 | opcode) # FIN + Opcode
|
||||
|
||||
length = len(data)
|
||||
if length < 126:
|
||||
header.append(0x80 | length) # Masked + length
|
||||
elif length < 65536:
|
||||
header.append(0x80 | 126)
|
||||
header.extend(length.to_bytes(2, 'big'))
|
||||
else:
|
||||
header.append(0x80 | 127)
|
||||
header.extend(length.to_bytes(8, 'big'))
|
||||
|
||||
mask = uos.urandom(4)
|
||||
header.extend(mask)
|
||||
|
||||
masked_data = bytearray(length)
|
||||
for i in range(length):
|
||||
masked_data[i] = data[i] ^ mask[i % 4]
|
||||
|
||||
self.sock.write(header)
|
||||
self.sock.write(masked_data)
|
||||
|
||||
def recv(self):
|
||||
# Read header
|
||||
try:
|
||||
# Read 2 bytes at once
|
||||
header = self.sock.read(2)
|
||||
if not header or len(header) < 2: return None
|
||||
|
||||
b1 = header[0]
|
||||
b2 = header[1]
|
||||
|
||||
fin = b1 & 0x80
|
||||
opcode = b1 & 0x0f
|
||||
|
||||
mask = b2 & 0x80
|
||||
length = b2 & 0x7f
|
||||
|
||||
if length == 126:
|
||||
length_bytes = self.sock.read(2)
|
||||
if not length_bytes: return None
|
||||
length = int.from_bytes(length_bytes, 'big')
|
||||
elif length == 127:
|
||||
length_bytes = self.sock.read(8)
|
||||
if not length_bytes: return None
|
||||
length = int.from_bytes(length_bytes, 'big')
|
||||
|
||||
if mask:
|
||||
mask_key = self.sock.read(4)
|
||||
if not mask_key: return None
|
||||
|
||||
# Read payload
|
||||
data = bytearray(length)
|
||||
view = memoryview(data)
|
||||
pos = 0
|
||||
while pos < length:
|
||||
read_len = self.sock.readinto(view[pos:])
|
||||
if read_len == 0:
|
||||
return None
|
||||
pos += read_len
|
||||
|
||||
if mask:
|
||||
unmasked = bytearray(length)
|
||||
for i in range(length):
|
||||
unmasked[i] = data[i] ^ mask_key[i % 4]
|
||||
data = unmasked
|
||||
|
||||
if opcode == 1: # Text
|
||||
return data.decode('utf-8')
|
||||
elif opcode == 2: # Binary
|
||||
return data
|
||||
elif opcode == 8: # Close
|
||||
self.close()
|
||||
return None
|
||||
elif opcode == 9: # Ping
|
||||
self.send(data, opcode=10) # Pong
|
||||
return self.recv()
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
# Don't print timeout errors as they are expected in non-blocking polling
|
||||
if "ETIMEDOUT" not in str(e) and "110" not in str(e):
|
||||
print(f"WS Recv Error: {e}")
|
||||
return None
|
||||
|
||||
def close(self):
|
||||
if self.sock:
|
||||
self.sock.close()
|
||||
self.sock = None
|
||||
Reference in New Issue
Block a user