"""WebSocket relay between a Channels client and a connection gateway.

``TCPConsumer`` opens a ``GatewayConnection`` to
``ws(s)://localhost:9000/connect/<host>:<port>`` (host/port taken from the URL
route) and relays frames in both directions until either side closes.
"""
import asyncio
import os
import ssl
from typing import Optional
from urllib.parse import quote

import websockets
from channels.generic.websocket import AsyncWebsocketConsumer
from django.conf import settings


class GatewayConnection:
    """Client-side WebSocket connection to the connection gateway.

    ``wss://`` with a client certificate is used when
    ``settings.CONNECTION_GATEWAY_AUTH_KEY`` is set, plain ``ws://`` otherwise.
    """

    # Lazily-built TLS context shared by all instances; None means plain ws://.
    _ssl_context: Optional[ssl.SSLContext] = None

    def __init__(self, host: str, port: int):
        """Prepare (but do not open) a gateway connection targeting host:port.

        :param host: target host, embedded in the gateway URL path.
        :param port: target port, embedded in the gateway URL path.
        """
        if settings.CONNECTION_GATEWAY_AUTH_KEY and not GatewayConnection._ssl_context:
            GatewayConnection._ssl_context = self._build_ssl_context()
        proto = 'wss' if GatewayConnection._ssl_context else 'ws'
        # safe='' so a '/' in the host cannot inject extra URL path segments.
        target = f'{quote(host, safe="")}:{quote(str(port), safe="")}'
        self.url = f'{proto}://localhost:9000/connect/{target}'
        # Set by connect(); kept as None so close() is safe before connect().
        self.context = None
        self.socket = None

    @staticmethod
    def _build_ssl_context() -> ssl.SSLContext:
        """Build the client-side TLS context presenting the configured certificate.

        :returns: an SSLContext suitable for an outgoing ``wss://`` connection.
        """
        # Purpose.SERVER_AUTH produces a *client* context (PROTOCOL_TLS_CLIENT).
        # Purpose.CLIENT_AUTH produces a server-side context, which cannot be
        # used for an outgoing connection on Python 3.10+.
        ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
        # The gateway is dialled as "localhost", so the certificate hostname is
        # not checked; the chain is verified only when a CA is configured.
        # NOTE(review): enable check_hostname if the gateway cert names localhost.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        ctx.load_cert_chain(
            os.path.realpath(settings.CONNECTION_GATEWAY_AUTH_CERTIFICATE),
            os.path.realpath(settings.CONNECTION_GATEWAY_AUTH_KEY),
        )
        if settings.CONNECTION_GATEWAY_AUTH_CA:
            ctx.load_verify_locations(
                cafile=os.path.realpath(settings.CONNECTION_GATEWAY_AUTH_CA),
            )
            ctx.verify_mode = ssl.CERT_REQUIRED
        return ctx

    async def connect(self):
        """Open the WebSocket to the gateway."""
        self.context = websockets.connect(self.url, ssl=GatewayConnection._ssl_context)
        self.socket = await self.context.__aenter__()

    async def send(self, data):
        """Send one frame (bytes -> binary frame, str -> text frame) to the gateway."""
        await self.socket.send(data)

    async def recv(self, timeout=None):
        """Receive one frame from the gateway.

        :param timeout: seconds to wait, or None to wait forever.
        :raises asyncio.TimeoutError: if nothing arrives within ``timeout``.
        :raises websockets.exceptions.ConnectionClosed: if the gateway closed.
        """
        return await asyncio.wait_for(self.socket.recv(), timeout)

    async def close(self):
        """Close the gateway WebSocket; a no-op if connect() never succeeded."""
        if self.socket is not None:
            await self.socket.close()
        if self.context is not None:
            # Clear first so the connect() context manager is exited only once.
            context, self.context = self.context, None
            await context.__aexit__(None, None, None)


class TCPConsumer(AsyncWebsocketConsumer):
    """Relay frames between the client WebSocket and a GatewayConnection.

    The target comes from the ``host`` and ``port`` URL-route kwargs.
    """

    async def connect(self):
        """Open the gateway connection, accept the client and start the reader."""
        self.closed = False
        self.reader = None
        self.conn = GatewayConnection(
            self.scope['url_route']['kwargs']['host'],
            int(self.scope['url_route']['kwargs']['port']),
        )
        await self.conn.connect()
        await self.accept()
        self.reader = asyncio.get_running_loop().create_task(self.socket_reader())

    async def disconnect(self, close_code):
        """Stop relaying and close the gateway connection.

        :param close_code: WebSocket close code supplied by Channels (unused).
        """
        self.closed = True
        # Cancel the reader so it does not try to close the already-gone client
        # when the gateway socket is closed below.
        reader = getattr(self, 'reader', None)
        if reader is not None:
            reader.cancel()
        conn = getattr(self, 'conn', None)
        if conn is not None:
            await conn.close()

    async def receive(self, bytes_data=None, text_data=None):
        """Forward a client frame to the gateway.

        Channels passes frames as keyword arguments (``text_data`` for text
        frames, ``bytes_data`` for binary ones). Text is UTF-8 encoded so the
        gateway always receives bytes.
        """
        if bytes_data is None:
            if text_data is None:
                return
            bytes_data = text_data.encode('utf-8')
        await self.conn.send(bytes_data)

    async def socket_reader(self):
        """Copy frames from the gateway to the client until either side closes."""
        while not self.closed:
            try:
                data = await self.conn.recv(timeout=10)
            except asyncio.TimeoutError:
                # Periodic wake-up so the ``closed`` flag is re-checked.
                continue
            except websockets.exceptions.ConnectionClosed:
                # Gateway went away: close the client too, unless it already left.
                if not self.closed:
                    await self.close()
                return
            if isinstance(data, str):
                # Text frame from the gateway; the client always gets bytes.
                data = data.encode('utf-8')
            await self.send(bytes_data=data)