Added TCP wrappers to TCP proxy

This commit is contained in:
2025-01-28 14:00:14 +01:00
parent 4ef7500c3b
commit 56664cdac5
2 changed files with 65 additions and 96 deletions

View File

@@ -10,6 +10,7 @@ using System.Threading.Tasks;
using AMWD.Protocols.Modbus.Common;
using AMWD.Protocols.Modbus.Common.Contracts;
using AMWD.Protocols.Modbus.Common.Protocols;
using AMWD.Protocols.Modbus.Tcp.Utils;
namespace AMWD.Protocols.Modbus.Tcp
{
@@ -24,12 +25,12 @@ namespace AMWD.Protocols.Modbus.Tcp
private TimeSpan _readWriteTimeout = TimeSpan.FromSeconds(100);
private TcpListener _listener;
private TcpListenerWrapper _tcpListener;
private CancellationTokenSource _stopCts;
private Task _clientConnectTask = Task.CompletedTask;
private readonly SemaphoreSlim _clientListLock = new(1, 1);
private readonly List<TcpClient> _clients = [];
private readonly List<TcpClientWrapper> _clients = [];
private readonly List<Task> _clientTasks = [];
#endregion Fields
@@ -41,31 +42,11 @@ namespace AMWD.Protocols.Modbus.Tcp
/// </summary>
/// <param name="client">The <see cref="ModbusClientBase"/> used to request the remote device, that should be proxied.</param>
/// <param name="listenAddress">An <see cref="IPAddress"/> to listen on (Default: <see cref="IPAddress.Loopback"/>).</param>
/// <param name="listenPort">A port to listen on (Default: 502).</param>
public ModbusTcpProxy(ModbusClientBase client, IPAddress listenAddress = null, int listenPort = 502)
public ModbusTcpProxy(ModbusClientBase client, IPAddress listenAddress)
{
Client = client ?? throw new ArgumentNullException(nameof(client));
ListenAddress = listenAddress ?? IPAddress.Loopback;
if (listenPort < ushort.MinValue || ushort.MaxValue < listenPort)
throw new ArgumentOutOfRangeException(nameof(listenPort));
try
{
#if NET8_0_OR_GREATER
using var testListener = new TcpListener(ListenAddress, listenPort);
#else
var testListener = new TcpListener(ListenAddress, listenPort);
#endif
testListener.Start(1);
ListenPort = (testListener.LocalEndpoint as IPEndPoint).Port;
testListener.Stop();
}
catch (Exception ex)
{
throw new ArgumentException($"{nameof(ListenPort)} ({listenPort}) is already in use.", ex);
}
_tcpListener = new TcpListenerWrapper(listenAddress, 502);
}
#endregion Constructors
@@ -80,17 +61,25 @@ namespace AMWD.Protocols.Modbus.Tcp
/// <summary>
/// Gets the <see cref="IPAddress"/> to listen on.
/// </summary>
public IPAddress ListenAddress { get; }
public IPAddress ListenAddress
{
get => _tcpListener.LocalIPEndPoint.Address;
set => _tcpListener.LocalIPEndPoint.Address = value;
}
/// <summary>
/// Get the port to listen on.
/// </summary>
public int ListenPort { get; }
public int ListenPort
{
get => _tcpListener.LocalIPEndPoint.Port;
set => _tcpListener.LocalIPEndPoint.Port = value;
}
/// <summary>
/// Gets a value indicating whether the server is running.
/// </summary>
public bool IsRunning => _listener?.Server.IsBound ?? false;
public bool IsRunning => _tcpListener.Socket.IsBound;
/// <summary>
/// Gets or sets the read/write timeout for the incoming connections (not the <see cref="Client"/>!).
@@ -121,20 +110,14 @@ namespace AMWD.Protocols.Modbus.Tcp
Assertions();
_stopCts?.Cancel();
_listener?.Stop();
#if NET8_0_OR_GREATER
_listener?.Dispose();
#endif
_tcpListener.Stop();
_stopCts?.Dispose();
_stopCts = new CancellationTokenSource();
_listener = new TcpListener(ListenAddress, ListenPort);
if (ListenAddress.AddressFamily == AddressFamily.InterNetworkV6)
_listener.Server.DualMode = true;
_tcpListener.Socket.DualMode = ListenAddress.AddressFamily == AddressFamily.InterNetworkV6;
_listener.Start();
_tcpListener.Start();
_clientConnectTask = WaitForClientAsync(_stopCts.Token);
return Task.CompletedTask;
@@ -152,12 +135,9 @@ namespace AMWD.Protocols.Modbus.Tcp
private async Task StopAsyncInternal(CancellationToken cancellationToken = default)
{
_stopCts.Cancel();
_stopCts?.Cancel();
_tcpListener.Stop();
_listener.Stop();
#if NET8_0_OR_GREATER
_listener.Dispose();
#endif
try
{
await Task.WhenAny(_clientConnectTask, Task.Delay(Timeout.Infinite, cancellationToken));
@@ -191,6 +171,7 @@ namespace AMWD.Protocols.Modbus.Tcp
_clientListLock.Dispose();
_clients.Clear();
_tcpListener.Dispose();
_stopCts?.Dispose();
}
@@ -215,11 +196,7 @@ namespace AMWD.Protocols.Modbus.Tcp
{
try
{
#if NET8_0_OR_GREATER
var client = await _listener.AcceptTcpClientAsync(cancellationToken);
#else
var client = await _listener.AcceptTcpClientAsync();
#endif
var client = await _tcpListener.AcceptTcpClientAsync(cancellationToken);
await _clientListLock.WaitAsync(cancellationToken);
try
{
@@ -238,7 +215,7 @@ namespace AMWD.Protocols.Modbus.Tcp
}
}
private async Task HandleClientAsync(TcpClient client, CancellationToken cancellationToken)
private async Task HandleClientAsync(TcpClientWrapper client, CancellationToken cancellationToken)
{
try
{
@@ -253,11 +230,11 @@ namespace AMWD.Protocols.Modbus.Tcp
byte[] headerBytes = await stream.ReadExpectedBytesAsync(6, cts.Token);
requestBytes.AddRange(headerBytes);
byte[] followingCountBytes = headerBytes.Skip(4).Take(2).ToArray();
followingCountBytes.SwapBigEndian();
int followingCount = BitConverter.ToUInt16(followingCountBytes, 0);
ushort length = headerBytes
.Skip(4).Take(2).ToArray()
.GetBigEndianUInt16();
byte[] bodyBytes = await stream.ReadExpectedBytesAsync(followingCount, cts.Token);
byte[] bodyBytes = await stream.ReadExpectedBytesAsync(length, cts.Token);
requestBytes.AddRange(bodyBytes);
}
@@ -322,14 +299,14 @@ namespace AMWD.Protocols.Modbus.Tcp
default: // unknown function
{
byte[] responseBytes = new byte[9];
Array.Copy(requestBytes, 0, responseBytes, 0, 8);
var responseBytes = new List<byte>();
responseBytes.AddRange(requestBytes.Take(8));
responseBytes.Add((byte)ModbusErrorCode.IllegalFunction);
// Mark as error
responseBytes[7] |= 0x80;
responseBytes[8] = (byte)ModbusErrorCode.IllegalFunction;
return Task.FromResult(responseBytes);
return Task.FromResult(ReturnResponse(responseBytes));
}
}
}
@@ -662,18 +639,18 @@ namespace AMWD.Protocols.Modbus.Tcp
HighByte = requestBytes[baseOffset + i * 2],
LowByte = requestBytes[baseOffset + i * 2 + 1]
});
}
bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[6], list, cancellationToken);
if (isSuccess)
{
// Response is an echo of the request
responseBytes.AddRange(requestBytes.Skip(8).Take(4));
}
else
{
responseBytes[7] |= 0x80;
responseBytes.Add((byte)ModbusErrorCode.SlaveDeviceFailure);
}
bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[6], list, cancellationToken);
if (isSuccess)
{
// Response is an echo of the request
responseBytes.AddRange(requestBytes.Skip(8).Take(4));
}
else
{
responseBytes[7] |= 0x80;
responseBytes.Add((byte)ModbusErrorCode.SlaveDeviceFailure);
}
}
catch
@@ -687,6 +664,9 @@ namespace AMWD.Protocols.Modbus.Tcp
private async Task<byte[]> HandleEncapsulatedInterfaceAsync(byte[] requestBytes, CancellationToken cancellationToken)
{
if (requestBytes.Length < 11)
return null;
var responseBytes = new List<byte>();
responseBytes.AddRange(requestBytes.Take(8));
@@ -718,7 +698,7 @@ namespace AMWD.Protocols.Modbus.Tcp
try
{
var res = await Client.ReadDeviceIdentificationAsync(requestBytes[6], category, firstObject, cancellationToken);
var deviceInfo = await Client.ReadDeviceIdentificationAsync(requestBytes[6], category, firstObject, cancellationToken);
var bodyBytes = new List<byte>();
@@ -727,31 +707,20 @@ namespace AMWD.Protocols.Modbus.Tcp
// Conformity
bodyBytes.Add((byte)category);
if (res.IsIndividualAccessAllowed)
if (deviceInfo.IsIndividualAccessAllowed)
bodyBytes[2] |= 0x80;
// More, NextId, NumberOfObjects
bodyBytes.AddRange(new byte[3]);
int maxObjectId;
switch (category)
int maxObjectId = category switch
{
case ModbusDeviceIdentificationCategory.Basic:
maxObjectId = 0x02;
break;
case ModbusDeviceIdentificationCategory.Regular:
maxObjectId = 0x06;
break;
case ModbusDeviceIdentificationCategory.Extended:
maxObjectId = 0xFF;
break;
default: // Individual
maxObjectId = requestBytes[10];
break;
}
ModbusDeviceIdentificationCategory.Basic => 0x02,
ModbusDeviceIdentificationCategory.Regular => 0x06,
ModbusDeviceIdentificationCategory.Extended => 0xFF,
// Individual
_ => requestBytes[10],
};
byte numberOfObjects = 0;
for (int i = requestBytes[10]; i <= maxObjectId; i++)
@@ -760,7 +729,7 @@ namespace AMWD.Protocols.Modbus.Tcp
if (0x07 <= i && i <= 0x7F)
continue;
byte[] objBytes = GetDeviceObject((byte)i, res);
byte[] objBytes = GetDeviceObject((byte)i, deviceInfo);
// We need to split the response if it would exceed the max ADU size
if (responseBytes.Count + bodyBytes.Count + objBytes.Length > TcpProtocol.MAX_ADU_LENGTH)
@@ -799,7 +768,7 @@ namespace AMWD.Protocols.Modbus.Tcp
{
case ModbusDeviceIdentificationObject.VendorName:
{
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorName);
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorName ?? "");
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}
@@ -807,7 +776,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.ProductCode:
{
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductCode);
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductCode ?? "");
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}
@@ -815,7 +784,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.MajorMinorRevision:
{
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.MajorMinorRevision);
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.MajorMinorRevision ?? "");
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}
@@ -823,7 +792,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.VendorUrl:
{
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorUrl);
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorUrl ?? "");
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}
@@ -831,7 +800,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.ProductName:
{
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductName);
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductName ?? "");
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}
@@ -839,7 +808,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.ModelName:
{
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ModelName);
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ModelName ?? "");
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}
@@ -847,7 +816,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.UserApplicationName:
{
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.UserApplicationName);
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.UserApplicationName ?? "");
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}
@@ -855,9 +824,8 @@ namespace AMWD.Protocols.Modbus.Tcp
default:
{
if (deviceIdentification.ExtendedObjects.ContainsKey(objectId))
if (deviceIdentification.ExtendedObjects.TryGetValue(objectId, out byte[] bytes))
{
byte[] bytes = deviceIdentification.ExtendedObjects[objectId];
result.Add((byte)bytes.Length);
result.AddRange(bytes);
}