Added TCP wrappers to TCP proxy

This commit is contained in:
2025-01-28 14:00:14 +01:00
parent 4ef7500c3b
commit 56664cdac5
2 changed files with 65 additions and 96 deletions

View File

@@ -10,6 +10,7 @@ using System.Threading.Tasks;
using AMWD.Protocols.Modbus.Common; using AMWD.Protocols.Modbus.Common;
using AMWD.Protocols.Modbus.Common.Contracts; using AMWD.Protocols.Modbus.Common.Contracts;
using AMWD.Protocols.Modbus.Common.Protocols; using AMWD.Protocols.Modbus.Common.Protocols;
using AMWD.Protocols.Modbus.Tcp.Utils;
namespace AMWD.Protocols.Modbus.Tcp namespace AMWD.Protocols.Modbus.Tcp
{ {
@@ -24,12 +25,12 @@ namespace AMWD.Protocols.Modbus.Tcp
private TimeSpan _readWriteTimeout = TimeSpan.FromSeconds(100); private TimeSpan _readWriteTimeout = TimeSpan.FromSeconds(100);
private TcpListener _listener; private TcpListenerWrapper _tcpListener;
private CancellationTokenSource _stopCts; private CancellationTokenSource _stopCts;
private Task _clientConnectTask = Task.CompletedTask; private Task _clientConnectTask = Task.CompletedTask;
private readonly SemaphoreSlim _clientListLock = new(1, 1); private readonly SemaphoreSlim _clientListLock = new(1, 1);
private readonly List<TcpClient> _clients = []; private readonly List<TcpClientWrapper> _clients = [];
private readonly List<Task> _clientTasks = []; private readonly List<Task> _clientTasks = [];
#endregion Fields #endregion Fields
@@ -41,31 +42,11 @@ namespace AMWD.Protocols.Modbus.Tcp
/// </summary> /// </summary>
/// <param name="client">The <see cref="ModbusClientBase"/> used to request the remote device, that should be proxied.</param> /// <param name="client">The <see cref="ModbusClientBase"/> used to request the remote device, that should be proxied.</param>
/// <param name="listenAddress">An <see cref="IPAddress"/> to listen on (Default: <see cref="IPAddress.Loopback"/>).</param> /// <param name="listenAddress">An <see cref="IPAddress"/> to listen on (Default: <see cref="IPAddress.Loopback"/>).</param>
/// <param name="listenPort">A port to listen on (Default: 502).</param> public ModbusTcpProxy(ModbusClientBase client, IPAddress listenAddress)
public ModbusTcpProxy(ModbusClientBase client, IPAddress listenAddress = null, int listenPort = 502)
{ {
Client = client ?? throw new ArgumentNullException(nameof(client)); Client = client ?? throw new ArgumentNullException(nameof(client));
ListenAddress = listenAddress ?? IPAddress.Loopback; _tcpListener = new TcpListenerWrapper(listenAddress, 502);
if (listenPort < ushort.MinValue || ushort.MaxValue < listenPort)
throw new ArgumentOutOfRangeException(nameof(listenPort));
try
{
#if NET8_0_OR_GREATER
using var testListener = new TcpListener(ListenAddress, listenPort);
#else
var testListener = new TcpListener(ListenAddress, listenPort);
#endif
testListener.Start(1);
ListenPort = (testListener.LocalEndpoint as IPEndPoint).Port;
testListener.Stop();
}
catch (Exception ex)
{
throw new ArgumentException($"{nameof(ListenPort)} ({listenPort}) is already in use.", ex);
}
} }
#endregion Constructors #endregion Constructors
@@ -80,17 +61,25 @@ namespace AMWD.Protocols.Modbus.Tcp
/// <summary> /// <summary>
/// Gets the <see cref="IPAddress"/> to listen on. /// Gets the <see cref="IPAddress"/> to listen on.
/// </summary> /// </summary>
public IPAddress ListenAddress { get; } public IPAddress ListenAddress
{
get => _tcpListener.LocalIPEndPoint.Address;
set => _tcpListener.LocalIPEndPoint.Address = value;
}
/// <summary> /// <summary>
/// Get the port to listen on. /// Get the port to listen on.
/// </summary> /// </summary>
public int ListenPort { get; } public int ListenPort
{
get => _tcpListener.LocalIPEndPoint.Port;
set => _tcpListener.LocalIPEndPoint.Port = value;
}
/// <summary> /// <summary>
/// Gets a value indicating whether the server is running. /// Gets a value indicating whether the server is running.
/// </summary> /// </summary>
public bool IsRunning => _listener?.Server.IsBound ?? false; public bool IsRunning => _tcpListener.Socket.IsBound;
/// <summary> /// <summary>
/// Gets or sets the read/write timeout for the incoming connections (not the <see cref="Client"/>!). /// Gets or sets the read/write timeout for the incoming connections (not the <see cref="Client"/>!).
@@ -121,20 +110,14 @@ namespace AMWD.Protocols.Modbus.Tcp
Assertions(); Assertions();
_stopCts?.Cancel(); _stopCts?.Cancel();
_tcpListener.Stop();
_listener?.Stop();
#if NET8_0_OR_GREATER
_listener?.Dispose();
#endif
_stopCts?.Dispose(); _stopCts?.Dispose();
_stopCts = new CancellationTokenSource(); _stopCts = new CancellationTokenSource();
_listener = new TcpListener(ListenAddress, ListenPort); _tcpListener.Socket.DualMode = ListenAddress.AddressFamily == AddressFamily.InterNetworkV6;
if (ListenAddress.AddressFamily == AddressFamily.InterNetworkV6)
_listener.Server.DualMode = true;
_listener.Start(); _tcpListener.Start();
_clientConnectTask = WaitForClientAsync(_stopCts.Token); _clientConnectTask = WaitForClientAsync(_stopCts.Token);
return Task.CompletedTask; return Task.CompletedTask;
@@ -152,12 +135,9 @@ namespace AMWD.Protocols.Modbus.Tcp
private async Task StopAsyncInternal(CancellationToken cancellationToken = default) private async Task StopAsyncInternal(CancellationToken cancellationToken = default)
{ {
_stopCts.Cancel(); _stopCts?.Cancel();
_tcpListener.Stop();
_listener.Stop();
#if NET8_0_OR_GREATER
_listener.Dispose();
#endif
try try
{ {
await Task.WhenAny(_clientConnectTask, Task.Delay(Timeout.Infinite, cancellationToken)); await Task.WhenAny(_clientConnectTask, Task.Delay(Timeout.Infinite, cancellationToken));
@@ -191,6 +171,7 @@ namespace AMWD.Protocols.Modbus.Tcp
_clientListLock.Dispose(); _clientListLock.Dispose();
_clients.Clear(); _clients.Clear();
_tcpListener.Dispose();
_stopCts?.Dispose(); _stopCts?.Dispose();
} }
@@ -215,11 +196,7 @@ namespace AMWD.Protocols.Modbus.Tcp
{ {
try try
{ {
#if NET8_0_OR_GREATER var client = await _tcpListener.AcceptTcpClientAsync(cancellationToken);
var client = await _listener.AcceptTcpClientAsync(cancellationToken);
#else
var client = await _listener.AcceptTcpClientAsync();
#endif
await _clientListLock.WaitAsync(cancellationToken); await _clientListLock.WaitAsync(cancellationToken);
try try
{ {
@@ -238,7 +215,7 @@ namespace AMWD.Protocols.Modbus.Tcp
} }
} }
private async Task HandleClientAsync(TcpClient client, CancellationToken cancellationToken) private async Task HandleClientAsync(TcpClientWrapper client, CancellationToken cancellationToken)
{ {
try try
{ {
@@ -253,11 +230,11 @@ namespace AMWD.Protocols.Modbus.Tcp
byte[] headerBytes = await stream.ReadExpectedBytesAsync(6, cts.Token); byte[] headerBytes = await stream.ReadExpectedBytesAsync(6, cts.Token);
requestBytes.AddRange(headerBytes); requestBytes.AddRange(headerBytes);
byte[] followingCountBytes = headerBytes.Skip(4).Take(2).ToArray(); ushort length = headerBytes
followingCountBytes.SwapBigEndian(); .Skip(4).Take(2).ToArray()
int followingCount = BitConverter.ToUInt16(followingCountBytes, 0); .GetBigEndianUInt16();
byte[] bodyBytes = await stream.ReadExpectedBytesAsync(followingCount, cts.Token); byte[] bodyBytes = await stream.ReadExpectedBytesAsync(length, cts.Token);
requestBytes.AddRange(bodyBytes); requestBytes.AddRange(bodyBytes);
} }
@@ -322,14 +299,14 @@ namespace AMWD.Protocols.Modbus.Tcp
default: // unknown function default: // unknown function
{ {
byte[] responseBytes = new byte[9]; var responseBytes = new List<byte>();
Array.Copy(requestBytes, 0, responseBytes, 0, 8); responseBytes.AddRange(requestBytes.Take(8));
responseBytes.Add((byte)ModbusErrorCode.IllegalFunction);
// Mark as error // Mark as error
responseBytes[7] |= 0x80; responseBytes[7] |= 0x80;
responseBytes[8] = (byte)ModbusErrorCode.IllegalFunction; return Task.FromResult(ReturnResponse(responseBytes));
return Task.FromResult(responseBytes);
} }
} }
} }
@@ -662,6 +639,7 @@ namespace AMWD.Protocols.Modbus.Tcp
HighByte = requestBytes[baseOffset + i * 2], HighByte = requestBytes[baseOffset + i * 2],
LowByte = requestBytes[baseOffset + i * 2 + 1] LowByte = requestBytes[baseOffset + i * 2 + 1]
}); });
}
bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[6], list, cancellationToken); bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[6], list, cancellationToken);
if (isSuccess) if (isSuccess)
@@ -675,7 +653,6 @@ namespace AMWD.Protocols.Modbus.Tcp
responseBytes.Add((byte)ModbusErrorCode.SlaveDeviceFailure); responseBytes.Add((byte)ModbusErrorCode.SlaveDeviceFailure);
} }
} }
}
catch catch
{ {
responseBytes[7] |= 0x80; responseBytes[7] |= 0x80;
@@ -687,6 +664,9 @@ namespace AMWD.Protocols.Modbus.Tcp
private async Task<byte[]> HandleEncapsulatedInterfaceAsync(byte[] requestBytes, CancellationToken cancellationToken) private async Task<byte[]> HandleEncapsulatedInterfaceAsync(byte[] requestBytes, CancellationToken cancellationToken)
{ {
if (requestBytes.Length < 11)
return null;
var responseBytes = new List<byte>(); var responseBytes = new List<byte>();
responseBytes.AddRange(requestBytes.Take(8)); responseBytes.AddRange(requestBytes.Take(8));
@@ -718,7 +698,7 @@ namespace AMWD.Protocols.Modbus.Tcp
try try
{ {
var res = await Client.ReadDeviceIdentificationAsync(requestBytes[6], category, firstObject, cancellationToken); var deviceInfo = await Client.ReadDeviceIdentificationAsync(requestBytes[6], category, firstObject, cancellationToken);
var bodyBytes = new List<byte>(); var bodyBytes = new List<byte>();
@@ -727,31 +707,20 @@ namespace AMWD.Protocols.Modbus.Tcp
// Conformity // Conformity
bodyBytes.Add((byte)category); bodyBytes.Add((byte)category);
if (res.IsIndividualAccessAllowed) if (deviceInfo.IsIndividualAccessAllowed)
bodyBytes[2] |= 0x80; bodyBytes[2] |= 0x80;
// More, NextId, NumberOfObjects // More, NextId, NumberOfObjects
bodyBytes.AddRange(new byte[3]); bodyBytes.AddRange(new byte[3]);
int maxObjectId; int maxObjectId = category switch
switch (category)
{ {
case ModbusDeviceIdentificationCategory.Basic: ModbusDeviceIdentificationCategory.Basic => 0x02,
maxObjectId = 0x02; ModbusDeviceIdentificationCategory.Regular => 0x06,
break; ModbusDeviceIdentificationCategory.Extended => 0xFF,
// Individual
case ModbusDeviceIdentificationCategory.Regular: _ => requestBytes[10],
maxObjectId = 0x06; };
break;
case ModbusDeviceIdentificationCategory.Extended:
maxObjectId = 0xFF;
break;
default: // Individual
maxObjectId = requestBytes[10];
break;
}
byte numberOfObjects = 0; byte numberOfObjects = 0;
for (int i = requestBytes[10]; i <= maxObjectId; i++) for (int i = requestBytes[10]; i <= maxObjectId; i++)
@@ -760,7 +729,7 @@ namespace AMWD.Protocols.Modbus.Tcp
if (0x07 <= i && i <= 0x7F) if (0x07 <= i && i <= 0x7F)
continue; continue;
byte[] objBytes = GetDeviceObject((byte)i, res); byte[] objBytes = GetDeviceObject((byte)i, deviceInfo);
// We need to split the response if it would exceed the max ADU size // We need to split the response if it would exceed the max ADU size
if (responseBytes.Count + bodyBytes.Count + objBytes.Length > TcpProtocol.MAX_ADU_LENGTH) if (responseBytes.Count + bodyBytes.Count + objBytes.Length > TcpProtocol.MAX_ADU_LENGTH)
@@ -799,7 +768,7 @@ namespace AMWD.Protocols.Modbus.Tcp
{ {
case ModbusDeviceIdentificationObject.VendorName: case ModbusDeviceIdentificationObject.VendorName:
{ {
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorName); byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorName ?? "");
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }
@@ -807,7 +776,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.ProductCode: case ModbusDeviceIdentificationObject.ProductCode:
{ {
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductCode); byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductCode ?? "");
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }
@@ -815,7 +784,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.MajorMinorRevision: case ModbusDeviceIdentificationObject.MajorMinorRevision:
{ {
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.MajorMinorRevision); byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.MajorMinorRevision ?? "");
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }
@@ -823,7 +792,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.VendorUrl: case ModbusDeviceIdentificationObject.VendorUrl:
{ {
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorUrl); byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.VendorUrl ?? "");
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }
@@ -831,7 +800,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.ProductName: case ModbusDeviceIdentificationObject.ProductName:
{ {
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductName); byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ProductName ?? "");
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }
@@ -839,7 +808,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.ModelName: case ModbusDeviceIdentificationObject.ModelName:
{ {
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ModelName); byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.ModelName ?? "");
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }
@@ -847,7 +816,7 @@ namespace AMWD.Protocols.Modbus.Tcp
case ModbusDeviceIdentificationObject.UserApplicationName: case ModbusDeviceIdentificationObject.UserApplicationName:
{ {
byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.UserApplicationName); byte[] bytes = Encoding.UTF8.GetBytes(deviceIdentification.UserApplicationName ?? "");
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }
@@ -855,9 +824,8 @@ namespace AMWD.Protocols.Modbus.Tcp
default: default:
{ {
if (deviceIdentification.ExtendedObjects.ContainsKey(objectId)) if (deviceIdentification.ExtendedObjects.TryGetValue(objectId, out byte[] bytes))
{ {
byte[] bytes = deviceIdentification.ExtendedObjects[objectId];
result.Add((byte)bytes.Length); result.Add((byte)bytes.Length);
result.AddRange(bytes); result.AddRange(bytes);
} }

View File

@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed ### Fixed
- Wrong _following bytes_ calculation in `ModbusTcpProxy`. - Wrong _following bytes_ calculation in `ModbusTcpProxy`.
- Wrong processing of `WriteMultipleHoldingRegisters` for proxies.
## [v0.3.2] (2024-09-04) ## [v0.3.2] (2024-09-04)