[Refactor] Link Play TCP data transmission

- Code refactor of Link Play TCP data transmission for better security and scalability
This commit is contained in:
Lost-MSth
2023-12-03 00:38:43 +08:00
parent 3e93082a3c
commit 150686d9f8
7 changed files with 222 additions and 75 deletions

View File

@@ -23,6 +23,7 @@ class Config:
LINKPLAY_TCP_PORT = 10901 LINKPLAY_TCP_PORT = 10901
LINKPLAY_AUTHENTICATION = 'my_link_play_server' LINKPLAY_AUTHENTICATION = 'my_link_play_server'
LINKPLAY_DISPLAY_HOST = '' LINKPLAY_DISPLAY_HOST = ''
LINKPLAY_TCP_SECRET_KEY = '1145141919810'
SSL_CERT = '' SSL_CERT = ''
SSL_KEY = '' SSL_KEY = ''

View File

@@ -56,6 +56,8 @@ class Constant:
LINKPLAY_TCP_PORT = Config.LINKPLAY_TCP_PORT LINKPLAY_TCP_PORT = Config.LINKPLAY_TCP_PORT
LINKPLAY_UDP_PORT = Config.LINKPLAY_UDP_PORT LINKPLAY_UDP_PORT = Config.LINKPLAY_UDP_PORT
LINKPLAY_AUTHENTICATION = Config.LINKPLAY_AUTHENTICATION LINKPLAY_AUTHENTICATION = Config.LINKPLAY_AUTHENTICATION
LINKPLAY_TCP_SECRET_KEY = Config.LINKPLAY_TCP_SECRET_KEY
LINKPLAY_TCP_MAX_LENGTH = 0x0FFFFFFF
# Well, I can't say a word when I see this. # Well, I can't say a word when I see this.
FINALE_SWITCH = [ FINALE_SWITCH = [

View File

@@ -1,10 +1,12 @@
import socket import socket
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from json import dumps, loads
from core.error import ArcError, Timeout from core.error import ArcError, Timeout
from .constant import Constant from .constant import Constant
from .user import UserInfo from .user import UserInfo
from .util import aes_gcm_128_decrypt, aes_gcm_128_encrypt
socket.setdefaulttimeout(Constant.LINKPLAY_TIMEOUT) socket.setdefaulttimeout(Constant.LINKPLAY_TIMEOUT)
@@ -86,53 +88,106 @@ class Room:
class RemoteMultiPlayer: class RemoteMultiPlayer:
TCP_AES_KEY = Constant.LINKPLAY_TCP_SECRET_KEY.encode(
'utf-8').ljust(16, b'\x00')[:16]
def __init__(self) -> None: def __init__(self) -> None:
self.user: 'Player' = None self.user: 'Player' = None
self.room: 'Room' = None self.room: 'Room' = None
self.data_recv: tuple = None self.data_recv: 'dict | list' = None
def to_dict(self) -> dict: def to_dict(self) -> dict:
return dict(self.room.to_dict(), **self.user.to_dict()) return dict(self.room.to_dict(), **self.user.to_dict())
@staticmethod @staticmethod
def tcp(data: str) -> str: def tcp(data: bytes) -> bytes:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect((Constant.LINKPLAY_HOST, sock.connect((Constant.LINKPLAY_HOST,
Constant.LINKPLAY_TCP_PORT)) Constant.LINKPLAY_TCP_PORT))
sock.sendall(bytes(data + "\n", "utf-8"))
sock.sendall(data)
try: try:
received = str(sock.recv(1024), "utf-8").strip() cipher_len = int.from_bytes(sock.recv(8), byteorder='little')
if cipher_len > Constant.LINKPLAY_TCP_MAX_LENGTH:
raise ArcError(
'Too long body from link play server', status=400)
iv = sock.recv(12)
tag = sock.recv(12)
ciphertext = sock.recv(cipher_len)
received = aes_gcm_128_decrypt(
RemoteMultiPlayer.TCP_AES_KEY, b'', iv, ciphertext, tag)
except socket.timeout as e: except socket.timeout as e:
raise Timeout( raise Timeout(
'Timeout when waiting for data from link play server.', status=400) from e 'Timeout when waiting for data from link play server.', status=400) from e
# print(received) # print(received)
return received return received
def data_swap(self, data: tuple) -> tuple: def data_swap(self, data: dict) -> dict:
received = self.tcp(Constant.LINKPLAY_AUTHENTICATION + iv, ciphertext, tag = aes_gcm_128_encrypt(
'|' + '|'.join([str(x) for x in data])) self.TCP_AES_KEY, dumps(data).encode('utf-8'), b'')
send_data = Constant.LINKPLAY_AUTHENTICATION.encode(
'utf-8') + len(ciphertext).to_bytes(8, byteorder='little') + iv + tag[:12] + ciphertext
recv_data = self.tcp(send_data)
self.data_recv = loads(recv_data)
self.data_recv = received.split('|') code = self.data_recv['code']
if self.data_recv[0] != '0': if code != 0:
code = int(self.data_recv[0])
raise ArcError(f'Link Play error code: {code}', code, status=400) raise ArcError(f'Link Play error code: {code}', code, status=400)
return self.data_recv
# if self.data_recv[0] != '0':
# code = int(self.data_recv[0])
# raise ArcError(f'Link Play error code: {code}', code, status=400)
# @staticmethod
# def tcp(data: str) -> str:
# with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
# sock.connect((Constant.LINKPLAY_HOST,
# Constant.LINKPLAY_TCP_PORT))
# send_data =
# sock.sendall(bytes(data + "\n", "utf-8"))
# try:
# received = str(sock.recv(1024), "utf-8").strip()
# except socket.timeout as e:
# raise Timeout(
# 'Timeout when waiting for data from link play server.', status=400) from e
# # print(received)
# return received
# def data_swap(self, data: tuple) -> tuple:
# received = self.tcp(Constant.LINKPLAY_AUTHENTICATION +
# '|' + '|'.join([str(x) for x in data]))
# self.data_recv = received.split('|')
# if self.data_recv[0] != '0':
# code = int(self.data_recv[0])
# raise ArcError(f'Link Play error code: {code}', code, status=400)
def create_room(self, user: 'Player' = None) -> None: def create_room(self, user: 'Player' = None) -> None:
'''创建房间''' '''创建房间'''
if user is not None: if user is not None:
self.user = user self.user = user
user.select_user_one_column('name') user.select_user_one_column('name')
self.data_swap((1, self.user.name, b64encode( self.data_swap({
self.user.song_unlock).decode('utf-8'))) 'endpoint': 'create_room',
'data': {
'name': self.user.name,
'song_unlock': b64encode(self.user.song_unlock).decode('utf-8')
}
})
self.room = Room() self.room = Room()
self.room.room_code = self.data_recv[1] x = self.data_recv['data']
self.room.room_id = int(self.data_recv[2]) self.room.room_code = x['room_code']
self.room.room_id = int(x['room_id'])
self.room.song_unlock = self.user.song_unlock self.room.song_unlock = self.user.song_unlock
self.user.token = int(self.data_recv[3]) self.user.token = int(x['token'])
self.user.key = b64decode(self.data_recv[4]) self.user.key = b64decode(x['key'])
self.user.player_id = int(self.data_recv[5]) self.user.player_id = int(x['player_id'])
def join_room(self, room: 'Room' = None, user: 'Player' = None) -> None: def join_room(self, room: 'Room' = None, user: 'Player' = None) -> None:
'''加入房间''' '''加入房间'''
@@ -142,23 +197,37 @@ class RemoteMultiPlayer:
self.room = room self.room = room
self.user.select_user_one_column('name') self.user.select_user_one_column('name')
self.data_swap( self.data_swap({
(2, self.user.name, b64encode(self.user.song_unlock).decode('utf-8'), room.room_code)) 'endpoint': 'join_room',
self.room.room_code = self.data_recv[1] 'data': {
self.room.room_id = int(self.data_recv[2]) 'name': self.user.name,
self.room.song_unlock = b64decode(self.data_recv[6]) 'song_unlock': b64encode(self.user.song_unlock).decode('utf-8'),
self.user.token = int(self.data_recv[3]) 'room_code': self.room.room_code
self.user.key = b64decode(self.data_recv[4]) }
self.user.player_id = int(self.data_recv[5]) })
x = self.data_recv['data']
self.room.room_code = x['room_code']
self.room.room_id = int(x['room_id'])
self.room.song_unlock = b64decode(x['song_unlock'])
self.user.token = int(x['token'])
self.user.key = b64decode(x['key'])
self.user.player_id = int(x['player_id'])
def update_room(self, user: 'Player' = None) -> None: def update_room(self, user: 'Player' = None) -> None:
'''更新房间''' '''更新房间'''
if user is not None: if user is not None:
self.user = user self.user = user
self.data_swap((3, self.user.token)) self.data_swap({
'endpoint': 'update_room',
'data': {
'token': self.user.token
}
})
self.room = Room() self.room = Room()
self.room.room_code = self.data_recv[1] x = self.data_recv['data']
self.room.room_id = int(self.data_recv[2]) self.room.room_code = x['room_code']
self.room.song_unlock = b64decode(self.data_recv[5]) self.room.room_id = int(x['room_id'])
self.user.key = b64decode(self.data_recv[3]) self.room.song_unlock = b64decode(x['song_unlock'])
self.user.player_id = int(self.data_recv[4]) self.user.key = b64decode(x['key'])
self.user.player_id = int(x['player_id'])

View File

@@ -1,9 +1,30 @@
import hashlib import hashlib
import os import os
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from datetime import date from datetime import date
from time import mktime from time import mktime
def aes_gcm_128_encrypt(key, plaintext, associated_data):
iv = os.urandom(12)
encryptor = Cipher(
algorithms.AES(key),
modes.GCM(iv, min_tag_length=12),
).encryptor()
encryptor.authenticate_additional_data(associated_data)
ciphertext = encryptor.update(plaintext) + encryptor.finalize()
return (iv, ciphertext, encryptor.tag)
def aes_gcm_128_decrypt(key, associated_data, iv, ciphertext, tag):
decryptor = Cipher(
algorithms.AES(key),
modes.GCM(iv, tag, min_tag_length=12),
).decryptor()
decryptor.authenticate_additional_data(associated_data)
return decryptor.update(ciphertext) + decryptor.finalize()
def md5(code: str) -> str: def md5(code: str) -> str:
# md5加密算法 # md5加密算法
code = code.encode() code = code.encode()

View File

@@ -4,19 +4,22 @@ class Config:
''' '''
''' '''
服务器地址、端口号、校验码 服务器地址、端口号、校验码、传输加密密钥
Server address, port and verification code Server address, port, verification code, and encryption key
''' '''
HOST = '0.0.0.0' HOST = '0.0.0.0'
UDP_PORT = 10900 UDP_PORT = 10900
TCP_PORT = 10901 TCP_PORT = 10901
AUTHENTICATION = 'my_link_play_server' AUTHENTICATION = 'my_link_play_server'
TCP_SECRET_KEY = '1145141919810'
''' '''
-------------------------------------------------- --------------------------------------------------
''' '''
DEBUG = False DEBUG = False
TCP_MAX_LENGTH = 0x0FFFFFFF
TIME_LIMIT = 3600000 TIME_LIMIT = 3600000
COMMAND_INTERVAL = 1000000 COMMAND_INTERVAL = 1000000

View File

@@ -2,6 +2,7 @@
import logging import logging
import socketserver import socketserver
import threading import threading
from json import dumps, loads
from .aes import decrypt, encrypt from .aes import decrypt, encrypt
from .config import Config from .config import Config
@@ -55,27 +56,52 @@ class UDP_handler(socketserver.BaseRequestHandler):
ciphertext, self.client_address) ciphertext, self.client_address)
AUTH_LEN = len(Config.AUTHENTICATION)
TCP_AES_KEY = Config.TCP_SECRET_KEY.encode('utf-8').ljust(16, b'\x00')[:16]
class TCP_handler(socketserver.StreamRequestHandler): class TCP_handler(socketserver.StreamRequestHandler):
def handle(self): def handle(self):
try: try:
self.data = self.rfile.readline().strip() if self.rfile.read(AUTH_LEN).decode('utf-8') != Config.AUTHENTICATION:
self.wfile.write(b'No authentication')
logging.warning(
f'TCP-{self.client_address[0]}-No authentication')
return None
cipher_len = int.from_bytes(self.rfile.read(8), byteorder='little')
if cipher_len > Config.TCP_MAX_LENGTH:
self.wfile.write(b'Body too long')
logging.warning(f'TCP-{self.client_address[0]}-Body too long')
return None
iv = self.rfile.read(12)
tag = self.rfile.read(12)
ciphertext = self.rfile.read(cipher_len)
self.data = decrypt(TCP_AES_KEY, b'', iv, ciphertext, tag)
message = self.data.decode('utf-8') message = self.data.decode('utf-8')
data = loads(message)
except Exception as e: except Exception as e:
logging.error(e) logging.error(e)
return None return None
if Config.DEBUG: if Config.DEBUG:
logging.info(f'TCP-From-{self.client_address[0]}-{message}') logging.info(f'TCP-From-{self.client_address[0]}-{message}')
data = message.split('|')
if data[0] != Config.AUTHENTICATION:
self.wfile.write(b'No authentication')
logging.warning(f'TCP-{self.client_address[0]}-No authentication')
return None
r = TCPRouter(data[1:]).handle() r = TCPRouter(data).handle()
if Config.DEBUG: try:
logging.info(f'TCP-To-{self.client_address[0]}-{r}') r = dumps(r)
self.wfile.write(r.encode('utf-8')) if Config.DEBUG:
logging.info(f'TCP-To-{self.client_address[0]}-{r}')
iv, ciphertext, tag = encrypt(TCP_AES_KEY, r.encode('utf-8'), b'')
r = len(ciphertext).to_bytes(8, byteorder='little') + \
iv + tag[:12] + ciphertext
except Exception as e:
logging.error(e)
return None
self.wfile.write(r)
def link_play(ip: str = Config.HOST, udp_port: int = Config.UDP_PORT, tcp_port: int = Config.TCP_PORT): def link_play(ip: str = Config.HOST, udp_port: int = Config.UDP_PORT, tcp_port: int = Config.TCP_PORT):

View File

@@ -87,19 +87,21 @@ def memory_clean(now):
class TCPRouter: class TCPRouter:
clean_timer = 0 clean_timer = 0
router = { router = {
'0': 'debug', 'debug',
'1': 'create_room', 'create_room',
'2': 'join_room', 'join_room',
'3': 'update_room', 'update_room',
} }
def __init__(self, data: list): def __init__(self, raw_data: 'dict | list'):
self.data = data # data: list[str] = [command, ...] self.raw_data = raw_data # data: dict {endpoint: str, data: dict}
self.data = raw_data['data']
self.endpoint = raw_data['endpoint']
def debug(self): def debug(self) -> dict:
if Config.DEBUG: if Config.DEBUG:
return eval(self.data[1]) return {'result': eval(self.data['code'])}
return 'ok' return {'hello_world': 'ok'}
@staticmethod @staticmethod
def clean_check(): def clean_check():
@@ -109,14 +111,17 @@ class TCPRouter:
TCPRouter.clean_timer = now TCPRouter.clean_timer = now
memory_clean(now) memory_clean(now)
def handle(self) -> str: def handle(self) -> dict:
self.clean_check() self.clean_check()
if self.data[0] not in self.router: if self.endpoint not in self.router:
return None return None
r = getattr(self, self.router[self.data[0]])() r = getattr(self, self.endpoint)()
if isinstance(r, tuple): if isinstance(r, int):
return '|'.join(map(str, r)) return {'code': r}
return str(r) return {
'code': 0,
'data': r
}
@staticmethod @staticmethod
def generate_player(name: str) -> Player: def generate_player(name: str) -> Player:
@@ -144,12 +149,12 @@ class TCPRouter:
return room return room
def create_room(self) -> tuple: def create_room(self) -> dict:
# 开房 # 开房
# data = ['1', name, song_unlock, ] # data = ['1', name, song_unlock, ]
# song_unlock: base64 str # song_unlock: base64 str
name = self.data[1] name = self.data['name']
song_unlock = b64decode(self.data[2]) song_unlock = b64decode(self.data['song_unlock'])
key = urandom(16) key = urandom(16)
with Store.lock: with Store.lock:
@@ -172,33 +177,39 @@ class TCPRouter:
} }
logging.info(f'TCP-Create room `{room.room_code}` by player `{name}`') logging.info(f'TCP-Create room `{room.room_code}` by player `{name}`')
return (0, room.room_code, room.room_id, token, b64encode(key).decode('utf-8'), player.player_id) return {
'room_code': room.room_code,
'room_id': room.room_id,
'token': token,
'key': b64encode(key).decode('utf-8'),
'player_id': player.player_id
}
def join_room(self) -> tuple: def join_room(self) -> 'dict | int':
# 入房 # 入房
# data = ['2', name, song_unlock, room_code] # data = ['2', name, song_unlock, room_code]
# song_unlock: base64 str # song_unlock: base64 str
room_code = self.data[3].upper() room_code = self.data['room_code'].upper()
key = urandom(16) key = urandom(16)
name = self.data[1] name = self.data['name']
song_unlock = b64decode(self.data[2]) song_unlock = b64decode(self.data['song_unlock'])
with Store.lock: with Store.lock:
if room_code not in Store.room_code_dict: if room_code not in Store.room_code_dict:
# 房间号错误 / 房间不存在 # 房间号错误 / 房间不存在
return '1202' return 1202
room: Room = Store.room_code_dict[room_code] room: Room = Store.room_code_dict[room_code]
player_num = room.player_num player_num = room.player_num
if player_num == 4: if player_num == 4:
# 满人 # 满人
return '1201' return 1201
if player_num == 0: if player_num == 0:
# 房间不存在 # 房间不存在
return '1202' return 1202
if room.state != 2: if room.state != 2:
# 无法加入 # 无法加入
return '1205' return 1205
token = unique_random(Store.link_play_data) token = unique_random(Store.link_play_data)
@@ -219,16 +230,30 @@ class TCPRouter:
} }
logging.info(f'TCP-Player `{name}` joins room `{room_code}`') logging.info(f'TCP-Player `{name}` joins room `{room_code}`')
return (0, room_code, room.room_id, token, b64encode(key).decode('utf-8'), player.player_id, b64encode(room.song_unlock).decode('utf-8')) return {
'room_code': room_code,
'room_id': room.room_id,
'token': token,
'key': b64encode(key).decode('utf-8'),
'player_id': player.player_id,
'song_unlock': b64encode(room.song_unlock).decode('utf-8')
}
def update_room(self) -> tuple: def update_room(self) -> dict:
# 房间信息更新 # 房间信息更新
# data = ['3', token] # data = ['3', token]
token = int(self.data[1]) token = int(self.data['token'])
with Store.lock: with Store.lock:
if token not in Store.link_play_data: if token not in Store.link_play_data:
return '108' return 108
r = Store.link_play_data[token] r = Store.link_play_data[token]
room = r['room'] room = r['room']
logging.info(f'TCP-Room `{room.room_code}` info update') logging.info(f'TCP-Room `{room.room_code}` info update')
return (0, room.room_code, room.room_id, b64encode(r['key']).decode('utf-8'), room.players[r['player_index']].player_id, b64encode(room.song_unlock).decode('utf-8')) return {
'room_code': room.room_code,
'room_id': room.room_id,
'key': b64encode(r['key']).decode('utf-8'),
# changed from room.players[r['player_index']].player_id,
'player_id': r['player_id'],
'song_unlock': b64encode(room.song_unlock).decode('utf-8')
}