// CatLink/Relay/FutariRelay.cs — 2026-01-18 17:59:01 +08:00 — 306 lines, 9.7 KiB, C#

using System.Net;
using System.Net.Sockets;
using System.Security.Cryptography;
using System.Text;
using CatLink.Models;
using Microsoft.Extensions.Logging;
namespace CatLink.Relay
{
public class FutariRelay
{
// Logger for relay-level events (listener start, client registration, routing warnings).
private readonly ILogger<FutariRelay> _logger;
// Used to resolve an ILogger<ActiveClient> for each accepted client.
private readonly IServiceProvider _serviceProvider;
// TCP port the listener binds to in StartAsync.
private readonly int _port;
// Token observed by the accept loop in StartAsync.
private readonly CancellationToken _cancellationToken;
// Registered clients keyed by stub IP (the uint produced by KeychipToStubIp).
private readonly Dictionary<uint, ActiveClient> _clients = new();
// Taken around every read and write of _clients in this class.
private readonly object _clientsLock = new();
/// <summary>
/// Creates a relay; nothing is bound or started until <see cref="StartAsync"/> is called.
/// </summary>
/// <param name="port">TCP port the listener will bind to.</param>
/// <param name="cancellationToken">Token that stops the accept loop.</param>
/// <param name="logger">Logger for relay-level events.</param>
/// <param name="serviceProvider">Provider used to resolve per-client loggers.</param>
public FutariRelay(int port, CancellationToken cancellationToken, ILogger<FutariRelay> logger, IServiceProvider serviceProvider)
    => (_port, _cancellationToken, _logger, _serviceProvider) = (port, cancellationToken, logger, serviceProvider);
/// <summary>
/// Binds a TCP listener on all interfaces at the configured port and hands every
/// accepted connection to <see cref="HandleClientAsync"/> until the relay's
/// cancellation token is signalled. The listener is always stopped on exit.
/// </summary>
/// <returns>A task that completes when the accept loop ends.</returns>
/// <exception cref="OperationCanceledException">
/// Propagated from the pending accept when the cancellation token fires.
/// </exception>
public async Task StartAsync()
{
    var listener = new TcpListener(IPAddress.Any, _port);
    listener.Start();
    _logger.LogInformation("TCP Relay server started on port {Port}", _port);
    try
    {
        while (!_cancellationToken.IsCancellationRequested)
        {
            var socket = await listener.AcceptSocketAsync(_cancellationToken);

            // The token is intentionally NOT passed to Task.Run. If it were, a
            // cancellation landing between the accept and the work item starting
            // would cancel the task before it ran, and the already-accepted socket
            // would never be handled or closed. HandleClientAsync owns the socket
            // from here on and closes it on every failure path.
            _ = Task.Run(() => HandleClientAsync(socket));
        }
    }
    finally
    {
        listener.Stop();
    }
}
/// <summary>Maximum time a new connection may take to deliver its first line.</summary>
private static readonly TimeSpan HandshakeTimeout = TimeSpan.FromSeconds(20);

/// <summary>Largest first line accepted, in bytes, before the connection is rejected.</summary>
private const int MaxHandshakeLineBytes = 4096;

/// <summary>
/// Performs the handshake for a freshly accepted connection: reads the first line,
/// requires it to be a CTL_START message carrying a non-empty client key, derives the
/// stub IP from that key, registers an <see cref="ActiveClient"/> under it and replies
/// with the version confirmation. The socket is closed on every rejection or error.
/// </summary>
/// <param name="socket">The accepted connection; ownership passes to the created client on success.</param>
private async Task HandleClientAsync(Socket socket)
{
    try
    {
        // ReceiveTimeout only affects synchronous Receive calls; the asynchronous
        // handshake read below is bounded by HandshakeTimeout instead.
        socket.ReceiveTimeout = 20000;
        socket.NoDelay = true;

        string firstLine;
        using (var handshakeCts = CancellationTokenSource.CreateLinkedTokenSource(_cancellationToken))
        {
            handshakeCts.CancelAfter(HandshakeTimeout);
            firstLine = await ReadHandshakeLineAsync(socket, handshakeCts.Token);
        }

        if (string.IsNullOrWhiteSpace(firstLine))
        {
            _logger.LogWarning("Received empty line from client");
            socket.Close();
            return;
        }

        var msg = Msg.FromString(firstLine);
        if (msg.Cmd != Command.CTL_START)
        {
            _logger.LogWarning("First message was not CTL_START");
            socket.Close();
            return;
        }

        var clientKey = msg.Data;
        if (string.IsNullOrEmpty(clientKey))
        {
            _logger.LogWarning("Client key is empty");
            socket.Close();
            return;
        }

        var stubIp = KeychipToStubIp(clientKey);
        var clientLogger = _serviceProvider.GetRequiredService<ILogger<ActiveClient>>();
        var client = new ActiveClient(
            clientKey,
            stubIp,
            socket,
            HandleMessage,
            HandleDisconnect,
            clientLogger
        );

        // Single atomic check-and-insert; the lock is held only for the dictionary operation.
        bool registered;
        lock (_clientsLock)
        {
            registered = _clients.TryAdd(stubIp, client);
        }

        if (!registered)
        {
            _logger.LogWarning("Client with stub IP {StubIp} already exists", stubIp);
            // Called outside the lock so external client code never runs while _clients is locked.
            client.Disconnect();
            return;
        }

        _logger.LogInformation("Client registered: {ClientKey} -> {StubIp}", clientKey, stubIp);

        // Send version confirmation
        client.Send(new Msg
        {
            Cmd = Command.CTL_START,
            Data = "version=1"
        });
    }
    catch (OperationCanceledException)
    {
        // Either the peer never completed its first line in time or the relay is stopping.
        _logger.LogWarning("Client handshake timed out or was cancelled");
        socket.Close();
    }
    catch (Exception ex)
    {
        _logger.LogError(ex, "Error handling client connection");
        socket.Close();
    }
}

/// <summary>
/// Reads bytes from <paramref name="socket"/> up to and including the first '\n' and
/// returns them decoded as UTF-8 without the line terminator ("\n" or "\r\n").
/// Exactly one line is consumed: no buffered reader is used, so any bytes the peer
/// sent after the first line stay in the socket for whoever reads from it next.
/// </summary>
/// <param name="socket">Connected socket to read from.</param>
/// <param name="cancellationToken">Aborts the read (used for the handshake timeout).</param>
/// <returns>The line read, or whatever arrived before the peer closed (possibly empty).</returns>
/// <exception cref="System.IO.InvalidDataException">The line exceeds <see cref="MaxHandshakeLineBytes"/>.</exception>
private static async Task<string> ReadHandshakeLineAsync(Socket socket, CancellationToken cancellationToken)
{
    var buffer = new byte[MaxHandshakeLineBytes];
    var length = 0;

    while (true)
    {
        // One byte per receive keeps us from reading past the newline. The handshake
        // line is short and read once per connection, so the extra calls are negligible.
        var read = await socket.ReceiveAsync(buffer.AsMemory(length, 1), SocketFlags.None, cancellationToken);
        if (read == 0 || buffer[length] == (byte)'\n')
        {
            break;
        }

        length++;
        if (length == buffer.Length)
        {
            throw new System.IO.InvalidDataException("Handshake line exceeds the maximum allowed length");
        }
    }

    if (length > 0 && buffer[length - 1] == (byte)'\r')
    {
        length--;
    }

    return Encoding.UTF8.GetString(buffer, 0, length);
}
/// <summary>
/// Dispatches a message received from <paramref name="client"/> according to its command.
/// Heartbeats are answered directly; the TCP control and data commands are passed to
/// their dedicated handlers. Any other command is ignored.
/// </summary>
/// <param name="client">The client the message came from.</param>
/// <param name="msg">The received message.</param>
private void HandleMessage(ActiveClient client, Msg msg)
{
    var command = msg.Cmd;

    if (command == Command.CTL_HEARTBEAT)
    {
        // Record when the heartbeat arrived, then echo one back.
        client.LastHeartbeat = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
        client.Send(new Msg { Cmd = Command.CTL_HEARTBEAT });
    }
    else if (command == Command.CTL_TCP_CONNECT)
    {
        HandleTcpConnect(client, msg);
    }
    else if (command == Command.CTL_TCP_ACCEPT)
    {
        HandleTcpAccept(client, msg);
    }
    else if (command == Command.DATA_SEND)
    {
        HandleDataSend(client, msg);
    }
    else if (command == Command.DATA_BROADCAST)
    {
        HandleDataBroadcast(client, msg);
    }
    else if (command == Command.CTL_TCP_CLOSE)
    {
        HandleTcpClose(client, msg);
    }
}
/// <summary>
/// Handles a CTL_TCP_CONNECT from <paramref name="client"/>: looks up the client registered
/// under the message's destination, rejects stream IDs the sender already has open or
/// pending, marks the stream as pending on the sender and forwards the message to the
/// destination with the sender's stub IP as source.
/// </summary>
/// <param name="client">The client requesting the connection.</param>
/// <param name="msg">The connect request.</param>
private void HandleTcpConnect(ActiveClient client, Msg msg)
{
    var destination = msg.Dst;
    var streamId = msg.Sid;

    ActiveClient? peer;
    lock (_clientsLock)
    {
        peer = _clients.TryGetValue(destination, out var found) ? found : null;
    }

    if (peer is null)
    {
        _logger.LogWarning("TCP connect: target not found {Dst}", destination);
        return;
    }

    if (client.TcpStreams.ContainsKey(streamId) || client.PendingStreams.Contains(streamId))
    {
        _logger.LogWarning("TCP connect: stream ID already in use {Sid}", streamId);
        return;
    }

    // The stream stays pending on the requester until the peer accepts it.
    client.PendingStreams.Add(streamId);

    var outgoing = msg.CloneWithNewData(msg.Data);
    outgoing.Src = client.StubIp;
    peer.Send(outgoing);

    _logger.LogDebug("TCP connect: {Src} -> {Dst}, stream {Sid}", client.StubIp, destination, streamId);
}
/// <summary>
/// Handles a CTL_TCP_ACCEPT from <paramref name="client"/>: the message's destination must
/// be a registered client that has the stream ID pending. The stream is then moved from
/// pending to established on both sides and the accept is forwarded to the destination
/// with the accepting client's stub IP as source.
/// </summary>
/// <param name="client">The client accepting the connection.</param>
/// <param name="msg">The accept message.</param>
private void HandleTcpAccept(ActiveClient client, Msg msg)
{
    var destination = msg.Dst;
    var streamId = msg.Sid;

    ActiveClient? initiator;
    lock (_clientsLock)
    {
        initiator = _clients.TryGetValue(destination, out var found) ? found : null;
    }

    if (initiator is null)
    {
        _logger.LogWarning("TCP accept: target not found {Dst}", destination);
        return;
    }

    var isPending = initiator.PendingStreams.Contains(streamId);
    if (!isPending)
    {
        _logger.LogWarning("TCP accept: stream not in pending {Sid}", streamId);
        return;
    }

    initiator.PendingStreams.Remove(streamId);

    // Each side records the other's stub IP under the shared stream ID.
    initiator.TcpStreams[streamId] = client.StubIp;
    client.TcpStreams[streamId] = initiator.StubIp;

    var outgoing = msg.CloneWithNewData(msg.Data);
    outgoing.Src = client.StubIp;
    initiator.Send(outgoing);

    _logger.LogDebug("TCP accept: {Src} <-> {Dst}, stream {Sid}", client.StubIp, destination, streamId);
}
/// <summary>
/// Handles a DATA_SEND from <paramref name="client"/>: resolves the peer from the sender's
/// established stream table using the message's stream ID and forwards the message with
/// source and destination rewritten to the sender's and peer's stub IPs.
/// </summary>
/// <param name="client">The sending client.</param>
/// <param name="msg">The data message.</param>
private void HandleDataSend(ActiveClient client, Msg msg)
{
    var hasStream = client.TcpStreams.TryGetValue(msg.Sid, out var peerStubIp);
    if (!hasStream)
    {
        _logger.LogWarning("Data send: stream not found {Sid}", msg.Sid);
        return;
    }

    ActiveClient? peer;
    lock (_clientsLock)
    {
        peer = _clients.TryGetValue(peerStubIp, out var found) ? found : null;
    }

    if (peer is null)
    {
        _logger.LogWarning("Data send: target not found {TargetStubIp}", peerStubIp);
        return;
    }

    var outgoing = msg.CloneWithNewData(msg.Data);
    outgoing.Src = client.StubIp;
    outgoing.Dst = peer.StubIp;
    peer.Send(outgoing);
}
/// <summary>
/// Handles a DATA_BROADCAST from <paramref name="client"/>: forwards one copy of the message,
/// with the sender's stub IP as source, to every registered client whose stub IP differs
/// from the sender's.
/// </summary>
/// <param name="client">The broadcasting client.</param>
/// <param name="msg">The broadcast message.</param>
private void HandleDataBroadcast(ActiveClient client, Msg msg)
{
    // Snapshot under the lock so sending happens without holding it.
    List<ActiveClient> snapshot;
    lock (_clientsLock)
    {
        snapshot = _clients.Values.ToList();
    }

    var outgoing = msg.CloneWithNewData(msg.Data);
    outgoing.Src = client.StubIp;

    foreach (var recipient in snapshot.Where(other => other.StubIp != client.StubIp))
    {
        recipient.Send(outgoing);
    }

    _logger.LogDebug("Data broadcast from {Src}", client.StubIp);
}
/// <summary>
/// Handles a CTL_TCP_CLOSE from <paramref name="client"/>: if the sender has the stream
/// established, drops it from the sender's table and, when the peer is still registered,
/// drops it from the peer's table and forwards the close with source and destination
/// rewritten. The debug line is logged whether or not the stream was known.
/// </summary>
/// <param name="client">The client closing the stream.</param>
/// <param name="msg">The close message.</param>
private void HandleTcpClose(ActiveClient client, Msg msg)
{
    var streamId = msg.Sid;

    if (client.TcpStreams.TryGetValue(streamId, out var peerStubIp))
    {
        client.TcpStreams.Remove(streamId);

        ActiveClient? peer;
        lock (_clientsLock)
        {
            peer = _clients.TryGetValue(peerStubIp, out var found) ? found : null;
        }

        if (peer is not null)
        {
            NotifyPeer(peer);
        }
    }

    _logger.LogDebug("TCP close: stream {Sid}", streamId);

    // Removes the stream on the peer's side and relays the close to it.
    void NotifyPeer(ActiveClient peer)
    {
        peer.TcpStreams.Remove(streamId);
        var outgoing = msg.CloneWithNewData(msg.Data);
        outgoing.Src = client.StubIp;
        outgoing.Dst = peer.StubIp;
        peer.Send(outgoing);
    }
}
/// <summary>
/// Disconnect callback handed to every <see cref="ActiveClient"/>. Removes the client
/// from the registry, but only when the entry stored under its stub IP is this very
/// instance, and logs the disconnect.
/// </summary>
/// <param name="client">The client that disconnected.</param>
private void HandleDisconnect(ActiveClient client)
{
    lock (_clientsLock)
    {
        // HandleClientAsync disconnects a rejected duplicate that was built with the
        // same stub IP as an already-registered client. Removing purely by key would
        // then evict the legitimate client, so the stored instance is compared first.
        if (_clients.TryGetValue(client.StubIp, out var registered) && ReferenceEquals(registered, client))
        {
            _clients.Remove(client.StubIp);
        }
    }

    _logger.LogInformation("Client disconnected: {ClientKey}", client.ClientKey);
}
/// <summary>
/// Derives a client's stub IP from its key: the first four bytes of the MD5 digest of
/// the UTF-8 encoded key, interpreted as a big-endian unsigned 32-bit integer.
/// MD5 is used here purely as a deterministic mapping, not as a security measure.
/// </summary>
/// <param name="keychip">The client key received in the CTL_START message.</param>
/// <returns>The 32-bit stub IP for the key.</returns>
private static uint KeychipToStubIp(string keychip)
{
    var hash = MD5.HashData(Encoding.UTF8.GetBytes(keychip));

    // Same value as (hash[0] << 24) | (hash[1] << 16) | (hash[2] << 8) | hash[3], but
    // without the signed intermediate whose cast to uint overflows when the project
    // is compiled with checked arithmetic and hash[0] >= 0x80.
    return System.Buffers.Binary.BinaryPrimitives.ReadUInt32BigEndian(hash);
}
}
}