Async optimization

This commit is contained in:
2025-02-03 22:28:31 +01:00
parent 9283b04971
commit 241a9d114c
8 changed files with 133 additions and 104 deletions

View File

@@ -10,19 +10,27 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
/// <summary> /// <summary>
/// Base implementation of a Modbus client. /// Base implementation of a Modbus client.
/// </summary> /// </summary>
public abstract class ModbusClientBase : IDisposable /// <remarks>
/// Initializes a new instance of the <see cref="ModbusClientBase"/> class with a specific <see cref="IModbusConnection"/>.
/// </remarks>
/// <param name="connection">The <see cref="IModbusConnection"/> responsible for invoking the requests.</param>
/// <param name="disposeConnection">
/// <see langword="true"/> if the connection should be disposed of by Dispose(),
/// <see langword="false"/> otherwise if you inted to reuse the connection.
/// </param>
public abstract class ModbusClientBase(IModbusConnection connection, bool disposeConnection) : IDisposable
{ {
private bool _isDisposed; private bool _isDisposed;
/// <summary> /// <summary>
/// Gets or sets a value indicating whether the connection should be disposed of by <see cref="Dispose()"/>. /// Gets or sets a value indicating whether the connection should be disposed of by <see cref="Dispose()"/>.
/// </summary> /// </summary>
protected readonly bool disposeConnection; protected readonly bool disposeConnection = disposeConnection;
/// <summary> /// <summary>
/// Gets or sets the <see cref="IModbusConnection"/> responsible for invoking the requests. /// Gets or sets the <see cref="IModbusConnection"/> responsible for invoking the requests.
/// </summary> /// </summary>
protected readonly IModbusConnection connection; protected readonly IModbusConnection connection = connection ?? throw new ArgumentNullException(nameof(connection));
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="ModbusClientBase"/> class with a specific <see cref="IModbusConnection"/>. /// Initializes a new instance of the <see cref="ModbusClientBase"/> class with a specific <see cref="IModbusConnection"/>.
@@ -32,20 +40,6 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
: this(connection, true) : this(connection, true)
{ } { }
/// <summary>
/// Initializes a new instance of the <see cref="ModbusClientBase"/> class with a specific <see cref="IModbusConnection"/>.
/// </summary>
/// <param name="connection">The <see cref="IModbusConnection"/> responsible for invoking the requests.</param>
/// <param name="disposeConnection">
/// <see langword="true"/> if the connection should be disposed of by Dispose(),
/// <see langword="false"/> otherwise if you inted to reuse the connection.
/// </param>
public ModbusClientBase(IModbusConnection connection, bool disposeConnection)
{
this.connection = connection ?? throw new ArgumentNullException(nameof(connection));
this.disposeConnection = disposeConnection;
}
/// <summary> /// <summary>
/// Gets or sets the protocol type to use. /// Gets or sets the protocol type to use.
/// </summary> /// </summary>
@@ -67,7 +61,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeReadCoils(unitId, startAddress, count); var request = Protocol.SerializeReadCoils(unitId, startAddress, count);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
// The protocol processes complete bytes from the response. // The protocol processes complete bytes from the response.
@@ -92,7 +86,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeReadDiscreteInputs(unitId, startAddress, count); var request = Protocol.SerializeReadDiscreteInputs(unitId, startAddress, count);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
// The protocol processes complete bytes from the response. // The protocol processes complete bytes from the response.
@@ -117,7 +111,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeReadHoldingRegisters(unitId, startAddress, count); var request = Protocol.SerializeReadHoldingRegisters(unitId, startAddress, count);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
var holdingRegisters = Protocol.DeserializeReadHoldingRegisters(response).ToList(); var holdingRegisters = Protocol.DeserializeReadHoldingRegisters(response).ToList();
@@ -140,7 +134,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeReadInputRegisters(unitId, startAddress, count); var request = Protocol.SerializeReadInputRegisters(unitId, startAddress, count);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
var inputRegisters = Protocol.DeserializeReadInputRegisters(response).ToList(); var inputRegisters = Protocol.DeserializeReadInputRegisters(response).ToList();
@@ -184,7 +178,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
do do
{ {
var request = Protocol.SerializeReadDeviceIdentification(unitId, category, requestObjectId); var request = Protocol.SerializeReadDeviceIdentification(unitId, category, requestObjectId);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
result = Protocol.DeserializeReadDeviceIdentification(response); result = Protocol.DeserializeReadDeviceIdentification(response);
@@ -247,7 +241,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeWriteSingleCoil(unitId, coil); var request = Protocol.SerializeWriteSingleCoil(unitId, coil);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
var result = Protocol.DeserializeWriteSingleCoil(response); var result = Protocol.DeserializeWriteSingleCoil(response);
@@ -268,7 +262,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeWriteSingleHoldingRegister(unitId, register); var request = Protocol.SerializeWriteSingleHoldingRegister(unitId, register);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
var result = Protocol.DeserializeWriteSingleHoldingRegister(response); var result = Protocol.DeserializeWriteSingleHoldingRegister(response);
@@ -289,7 +283,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeWriteMultipleCoils(unitId, coils); var request = Protocol.SerializeWriteMultipleCoils(unitId, coils);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
var (firstAddress, count) = Protocol.DeserializeWriteMultipleCoils(response); var (firstAddress, count) = Protocol.DeserializeWriteMultipleCoils(response);
@@ -309,7 +303,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
Assertions(); Assertions();
var request = Protocol.SerializeWriteMultipleHoldingRegisters(unitId, registers); var request = Protocol.SerializeWriteMultipleHoldingRegisters(unitId, registers);
var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken); var response = await connection.InvokeAsync(request, Protocol.CheckResponseComplete, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
Protocol.ValidateResponse(request, response); Protocol.ValidateResponse(request, response);
var (firstAddress, count) = Protocol.DeserializeWriteMultipleHoldingRegisters(response); var (firstAddress, count) = Protocol.DeserializeWriteMultipleHoldingRegisters(response);

View File

@@ -192,17 +192,16 @@ namespace AMWD.Protocols.Modbus.Serial
public Task StopAsync(CancellationToken cancellationToken = default) public Task StopAsync(CancellationToken cancellationToken = default)
{ {
Assertions(); Assertions();
return StopAsyncInternal(cancellationToken); StopAsyncInternal();
return Task.CompletedTask;
} }
private Task StopAsyncInternal(CancellationToken cancellationToken) private void StopAsyncInternal()
{ {
_stopCts?.Cancel(); _stopCts?.Cancel();
_serialPort.Close(); _serialPort.Close();
_serialPort.DataReceived -= OnDataReceived; _serialPort.DataReceived -= OnDataReceived;
return Task.CompletedTask;
} }
/// <summary> /// <summary>
@@ -215,7 +214,7 @@ namespace AMWD.Protocols.Modbus.Serial
_isDisposed = true; _isDisposed = true;
StopAsyncInternal(CancellationToken.None).Wait(); StopAsyncInternal();
_serialPort.Dispose(); _serialPort.Dispose();
_stopCts?.Dispose(); _stopCts?.Dispose();
@@ -332,7 +331,7 @@ namespace AMWD.Protocols.Modbus.Serial
responseBytes.AddRange(requestBytes.Take(2)); responseBytes.AddRange(requestBytes.Take(2));
try try
{ {
var coils = await Client.ReadCoilsAsync(unitId, firstAddress, count, cancellationToken); var coils = await Client.ReadCoilsAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[(int)Math.Ceiling(coils.Count / 8.0)]; byte[] values = new byte[(int)Math.Ceiling(coils.Count / 8.0)];
for (int i = 0; i < coils.Count; i++) for (int i = 0; i < coils.Count; i++)
@@ -371,7 +370,7 @@ namespace AMWD.Protocols.Modbus.Serial
responseBytes.AddRange(requestBytes.Take(2)); responseBytes.AddRange(requestBytes.Take(2));
try try
{ {
var discreteInputs = await Client.ReadDiscreteInputsAsync(unitId, firstAddress, count, cancellationToken); var discreteInputs = await Client.ReadDiscreteInputsAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[(int)Math.Ceiling(discreteInputs.Count / 8.0)]; byte[] values = new byte[(int)Math.Ceiling(discreteInputs.Count / 8.0)];
for (int i = 0; i < discreteInputs.Count; i++) for (int i = 0; i < discreteInputs.Count; i++)
@@ -410,7 +409,7 @@ namespace AMWD.Protocols.Modbus.Serial
responseBytes.AddRange(requestBytes.Take(2)); responseBytes.AddRange(requestBytes.Take(2));
try try
{ {
var holdingRegisters = await Client.ReadHoldingRegistersAsync(unitId, firstAddress, count, cancellationToken); var holdingRegisters = await Client.ReadHoldingRegistersAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[holdingRegisters.Count * 2]; byte[] values = new byte[holdingRegisters.Count * 2];
for (int i = 0; i < holdingRegisters.Count; i++) for (int i = 0; i < holdingRegisters.Count; i++)
@@ -444,7 +443,7 @@ namespace AMWD.Protocols.Modbus.Serial
responseBytes.AddRange(requestBytes.Take(2)); responseBytes.AddRange(requestBytes.Take(2));
try try
{ {
var inputRegisters = await Client.ReadInputRegistersAsync(unitId, firstAddress, count, cancellationToken); var inputRegisters = await Client.ReadInputRegistersAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[count * 2]; byte[] values = new byte[count * 2];
for (int i = 0; i < count; i++) for (int i = 0; i < count; i++)
@@ -492,7 +491,7 @@ namespace AMWD.Protocols.Modbus.Serial
LowByte = requestBytes[5], LowByte = requestBytes[5],
}; };
bool isSuccess = await Client.WriteSingleCoilAsync(requestBytes[0], coil, cancellationToken); bool isSuccess = await Client.WriteSingleCoilAsync(requestBytes[0], coil, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -531,7 +530,7 @@ namespace AMWD.Protocols.Modbus.Serial
LowByte = requestBytes[5] LowByte = requestBytes[5]
}; };
bool isSuccess = await Client.WriteSingleHoldingRegisterAsync(requestBytes[0], register, cancellationToken); bool isSuccess = await Client.WriteSingleHoldingRegisterAsync(requestBytes[0], register, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -591,7 +590,7 @@ namespace AMWD.Protocols.Modbus.Serial
}); });
} }
bool isSuccess = await Client.WriteMultipleCoilsAsync(requestBytes[0], coils, cancellationToken); bool isSuccess = await Client.WriteMultipleCoilsAsync(requestBytes[0], coils, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -648,7 +647,7 @@ namespace AMWD.Protocols.Modbus.Serial
}); });
} }
bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[0], list, cancellationToken); bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[0], list, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -705,7 +704,7 @@ namespace AMWD.Protocols.Modbus.Serial
try try
{ {
var deviceInfo = await Client.ReadDeviceIdentificationAsync(requestBytes[0], category, firstObject, cancellationToken); var deviceInfo = await Client.ReadDeviceIdentificationAsync(requestBytes[0], category, firstObject, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
var bodyBytes = new List<byte>(); var bodyBytes = new List<byte>();
@@ -855,5 +854,21 @@ namespace AMWD.Protocols.Modbus.Serial
} }
#endregion Request Handling #endregion Request Handling
/// <inheritdoc/>
public override string ToString()
{
var sb = new StringBuilder();
sb.AppendLine($"RTU Proxy");
sb.AppendLine($" {nameof(PortName)}: {PortName}");
sb.AppendLine($" {nameof(BaudRate)}: {(int)BaudRate}");
sb.AppendLine($" {nameof(DataBits)}: {DataBits}");
sb.AppendLine($" {nameof(StopBits)}: {StopBits}");
sb.AppendLine($" {nameof(Parity)}: {Parity}");
sb.AppendLine($" {nameof(Client)}: {Client.GetType().Name}");
return sb.ToString();
}
} }
} }

View File

@@ -31,8 +31,7 @@ namespace AMWD.Protocols.Modbus.Serial
private readonly Task _processingTask; private readonly Task _processingTask;
private readonly AsyncQueue<RequestQueueItem> _requestQueue = new(); private readonly AsyncQueue<RequestQueueItem> _requestQueue = new();
// Only required to cover all logic branches on unit tests. private readonly bool _isLinux;
private bool _isUnitTest = false;
#endregion Fields #endregion Fields
@@ -41,6 +40,8 @@ namespace AMWD.Protocols.Modbus.Serial
/// </summary> /// </summary>
public ModbusSerialConnection(string portName) public ModbusSerialConnection(string portName)
{ {
_isLinux = RuntimeInformation.IsOSPlatform(OSPlatform.Linux);
if (string.IsNullOrWhiteSpace(portName)) if (string.IsNullOrWhiteSpace(portName))
throw new ArgumentNullException(nameof(portName)); throw new ArgumentNullException(nameof(portName));
@@ -268,7 +269,7 @@ namespace AMWD.Protocols.Modbus.Serial
try try
{ {
// Get next request to process // Get next request to process
var item = await _requestQueue.DequeueAsync(cancellationToken); var item = await _requestQueue.DequeueAsync(cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
// Remove registration => already removed from queue // Remove registration => already removed from queue
item.CancellationTokenRegistration.Dispose(); item.CancellationTokenRegistration.Dispose();
@@ -276,13 +277,13 @@ namespace AMWD.Protocols.Modbus.Serial
// Build combined cancellation token // Build combined cancellation token
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, item.CancellationTokenSource.Token); using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, item.CancellationTokenSource.Token);
// Wait for exclusive access // Wait for exclusive access
await _portLock.WaitAsync(linkedCts.Token); await _portLock.WaitAsync(linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
try try
{ {
// Ensure connection is up // Ensure connection is up
await AssertConnection(linkedCts.Token); await AssertConnection(linkedCts.Token);
await _serialPort.WriteAsync(item.Request, linkedCts.Token); await _serialPort.WriteAsync(item.Request, linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
linkedCts.Token.ThrowIfCancellationRequested(); linkedCts.Token.ThrowIfCancellationRequested();
@@ -291,7 +292,7 @@ namespace AMWD.Protocols.Modbus.Serial
do do
{ {
int readCount = await _serialPort.ReadAsync(buffer, 0, buffer.Length, linkedCts.Token); int readCount = await _serialPort.ReadAsync(buffer, 0, buffer.Length, linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
if (readCount < 1) if (readCount < 1)
throw new EndOfStreamException(); throw new EndOfStreamException();
@@ -322,7 +323,7 @@ namespace AMWD.Protocols.Modbus.Serial
_portLock.Release(); _portLock.Release();
_idleTimer.Change(IdleTimeout, Timeout.InfiniteTimeSpan); _idleTimer.Change(IdleTimeout, Timeout.InfiniteTimeSpan);
await Task.Delay(InterRequestDelay, cancellationToken); await Task.Delay(InterRequestDelay, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
} }
} }
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
@@ -353,7 +354,7 @@ namespace AMWD.Protocols.Modbus.Serial
_serialPort.Close(); _serialPort.Close();
_serialPort.ResetRS485DriverStateFlags(); _serialPort.ResetRS485DriverStateFlags();
if (DriverEnabledRS485 && (RuntimeInformation.IsOSPlatform(OSPlatform.Linux) || _isUnitTest)) if (DriverEnabledRS485 && _isLinux)
{ {
var flags = _serialPort.GetRS485DriverStateFlags(); var flags = _serialPort.GetRS485DriverStateFlags();
flags |= RS485Flags.Enabled; flags |= RS485Flags.Enabled;
@@ -361,7 +362,7 @@ namespace AMWD.Protocols.Modbus.Serial
_serialPort.ChangeRS485DriverStateFlags(flags); _serialPort.ChangeRS485DriverStateFlags(flags);
} }
using var connectTask = Task.Run(_serialPort.Open); using var connectTask = Task.Run(_serialPort.Open, cancellationToken);
if (await Task.WhenAny(connectTask, Task.Delay(ReadTimeout, cancellationToken)) == connectTask) if (await Task.WhenAny(connectTask, Task.Delay(ReadTimeout, cancellationToken)) == connectTask)
{ {
await connectTask; await connectTask;
@@ -379,7 +380,7 @@ namespace AMWD.Protocols.Modbus.Serial
try try
{ {
await Task.Delay(TimeSpan.FromSeconds(delay), cancellationToken); await Task.Delay(TimeSpan.FromSeconds(delay), cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
} }
catch catch
{ /* keep it quiet */ } { /* keep it quiet */ }

View File

@@ -12,7 +12,7 @@ namespace System.IO
int offset = 0; int offset = 0;
do do
{ {
int count = await stream.ReadAsync(buffer, offset, expectedBytes - offset, cancellationToken); int count = await stream.ReadAsync(buffer, offset, expectedBytes - offset, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (count < 1) if (count < 1)
throw new EndOfStreamException(); throw new EndOfStreamException();
@@ -30,7 +30,7 @@ namespace System.IO
int offset = 0; int offset = 0;
do do
{ {
int count = await stream.ReadAsync(buffer, offset, expectedBytes - offset, cancellationToken); int count = await stream.ReadAsync(buffer, offset, expectedBytes - offset, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (count < 1) if (count < 1)
throw new EndOfStreamException(); throw new EndOfStreamException();

View File

@@ -0,0 +1,17 @@
using System.Threading.Tasks;
namespace AMWD.Protocols.Modbus.Tcp.Extensions
{
internal static class TaskExtensions
{
public static async void Forget(this Task task)
{
try
{
await task;
}
catch
{ /* keep it quiet */ }
}
}
}

View File

@@ -208,7 +208,7 @@ namespace AMWD.Protocols.Modbus.Tcp
try try
{ {
// Get next request to process // Get next request to process
var item = await _requestQueue.DequeueAsync(cancellationToken); var item = await _requestQueue.DequeueAsync(cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
// Remove registration => already removed from queue // Remove registration => already removed from queue
item.CancellationTokenRegistration.Dispose(); item.CancellationTokenRegistration.Dispose();
@@ -216,19 +216,19 @@ namespace AMWD.Protocols.Modbus.Tcp
// Build combined cancellation token // Build combined cancellation token
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, item.CancellationTokenSource.Token); using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, item.CancellationTokenSource.Token);
// Wait for exclusive access // Wait for exclusive access
await _clientLock.WaitAsync(linkedCts.Token); await _clientLock.WaitAsync(linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
try try
{ {
// Ensure connection is up // Ensure connection is up
await AssertConnection(linkedCts.Token); await AssertConnection(linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
var stream = _tcpClient.GetStream(); var stream = _tcpClient.GetStream();
await stream.FlushAsync(linkedCts.Token); await stream.FlushAsync(linkedCts.Token);
#if NET6_0_OR_GREATER #if NET6_0_OR_GREATER
await stream.WriteAsync(item.Request, linkedCts.Token); await stream.WriteAsync(item.Request, linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
#else #else
await stream.WriteAsync(item.Request, 0, item.Request.Length, linkedCts.Token); await stream.WriteAsync(item.Request, 0, item.Request.Length, linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
#endif #endif
linkedCts.Token.ThrowIfCancellationRequested(); linkedCts.Token.ThrowIfCancellationRequested();
@@ -239,9 +239,9 @@ namespace AMWD.Protocols.Modbus.Tcp
do do
{ {
#if NET6_0_OR_GREATER #if NET6_0_OR_GREATER
int readCount = await stream.ReadAsync(buffer, linkedCts.Token); int readCount = await stream.ReadAsync(buffer, linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
#else #else
int readCount = await stream.ReadAsync(buffer, 0, buffer.Length, linkedCts.Token); int readCount = await stream.ReadAsync(buffer, 0, buffer.Length, linkedCts.Token).ConfigureAwait(continueOnCapturedContext: false);
#endif #endif
if (readCount < 1) if (readCount < 1)
throw new EndOfStreamException(); throw new EndOfStreamException();
@@ -332,7 +332,7 @@ namespace AMWD.Protocols.Modbus.Tcp
try try
{ {
await Task.Delay(TimeSpan.FromSeconds(delay), cancellationToken); await Task.Delay(TimeSpan.FromSeconds(delay), cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
} }
catch catch
{ /* keep it quiet */ } { /* keep it quiet */ }

View File

@@ -10,6 +10,7 @@ using System.Threading.Tasks;
using AMWD.Protocols.Modbus.Common; using AMWD.Protocols.Modbus.Common;
using AMWD.Protocols.Modbus.Common.Contracts; using AMWD.Protocols.Modbus.Common.Contracts;
using AMWD.Protocols.Modbus.Common.Protocols; using AMWD.Protocols.Modbus.Common.Protocols;
using AMWD.Protocols.Modbus.Tcp.Extensions;
using AMWD.Protocols.Modbus.Tcp.Utils; using AMWD.Protocols.Modbus.Tcp.Utils;
namespace AMWD.Protocols.Modbus.Tcp namespace AMWD.Protocols.Modbus.Tcp
@@ -17,7 +18,12 @@ namespace AMWD.Protocols.Modbus.Tcp
/// <summary> /// <summary>
/// Implements a Modbus TCP server proxying all requests to a Modbus client of choice. /// Implements a Modbus TCP server proxying all requests to a Modbus client of choice.
/// </summary> /// </summary>
public class ModbusTcpProxy : IModbusProxy /// <remarks>
/// Initializes a new instance of the <see cref="ModbusTcpProxy"/> class.
/// </remarks>
/// <param name="client">The <see cref="ModbusClientBase"/> used to request the remote device, that should be proxied.</param>
/// <param name="listenAddress">An <see cref="IPAddress"/> to listen on.</param>
public class ModbusTcpProxy(ModbusClientBase client, IPAddress listenAddress) : IModbusProxy
{ {
#region Fields #region Fields
@@ -25,30 +31,17 @@ namespace AMWD.Protocols.Modbus.Tcp
private TimeSpan _readWriteTimeout = TimeSpan.FromSeconds(100); private TimeSpan _readWriteTimeout = TimeSpan.FromSeconds(100);
private TcpListenerWrapper _tcpListener; private readonly TcpListenerWrapper _tcpListener = new(listenAddress, 502);
private CancellationTokenSource _stopCts; private CancellationTokenSource _stopCts;
private Task _clientConnectTask = Task.CompletedTask; private Task _clientConnectTask = Task.CompletedTask;
private readonly SemaphoreSlim _clientListLock = new(1, 1); private readonly SemaphoreSlim _clientListLock = new(1, 1);
private readonly List<TcpClientWrapper> _clients = []; private readonly List<TcpClientWrapper> _clients = [];
private readonly List<Task> _clientTasks = [];
#endregion Fields #endregion Fields
#region Constructors #region Constructors
/// <summary>
/// Initializes a new instance of the <see cref="ModbusTcpProxy"/> class.
/// </summary>
/// <param name="client">The <see cref="ModbusClientBase"/> used to request the remote device, that should be proxied.</param>
/// <param name="listenAddress">An <see cref="IPAddress"/> to listen on.</param>
public ModbusTcpProxy(ModbusClientBase client, IPAddress listenAddress)
{
Client = client ?? throw new ArgumentNullException(nameof(client));
_tcpListener = new TcpListenerWrapper(listenAddress, 502);
}
#endregion Constructors #endregion Constructors
#region Properties #region Properties
@@ -56,7 +49,7 @@ namespace AMWD.Protocols.Modbus.Tcp
/// <summary> /// <summary>
/// Gets the Modbus client used to request the remote device, that should be proxied. /// Gets the Modbus client used to request the remote device, that should be proxied.
/// </summary> /// </summary>
public ModbusClientBase Client { get; } public ModbusClientBase Client { get; } = client ?? throw new ArgumentNullException(nameof(client));
/// <summary> /// <summary>
/// Gets the <see cref="IPAddress"/> to listen on. /// Gets the <see cref="IPAddress"/> to listen on.
@@ -140,16 +133,7 @@ namespace AMWD.Protocols.Modbus.Tcp
try try
{ {
await Task.WhenAny(_clientConnectTask, Task.Delay(Timeout.Infinite, cancellationToken)); await Task.WhenAny(_clientConnectTask, Task.Delay(Timeout.Infinite, cancellationToken)).ConfigureAwait(continueOnCapturedContext: false);
}
catch (OperationCanceledException)
{
// Terminated
}
try
{
await Task.WhenAny(Task.WhenAll(_clientTasks), Task.Delay(Timeout.Infinite, cancellationToken));
} }
catch (OperationCanceledException) catch (OperationCanceledException)
{ {
@@ -174,6 +158,7 @@ namespace AMWD.Protocols.Modbus.Tcp
_tcpListener.Dispose(); _tcpListener.Dispose();
_stopCts?.Dispose(); _stopCts?.Dispose();
GC.SuppressFinalize(this);
} }
private void Assertions() private void Assertions()
@@ -196,12 +181,13 @@ namespace AMWD.Protocols.Modbus.Tcp
{ {
try try
{ {
var client = await _tcpListener.AcceptTcpClientAsync(cancellationToken); var client = await _tcpListener.AcceptTcpClientAsync(cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
await _clientListLock.WaitAsync(cancellationToken); await _clientListLock.WaitAsync(cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
try try
{ {
_clients.Add(client); _clients.Add(client);
_clientTasks.Add(HandleClientAsync(client, cancellationToken)); // Can be ignored as it will terminate by itself on cancellation
HandleClientAsync(client, cancellationToken).Forget();
} }
finally finally
{ {
@@ -227,20 +213,20 @@ namespace AMWD.Protocols.Modbus.Tcp
using (var cts = new CancellationTokenSource(ReadWriteTimeout)) using (var cts = new CancellationTokenSource(ReadWriteTimeout))
using (cancellationToken.Register(cts.Cancel)) using (cancellationToken.Register(cts.Cancel))
{ {
byte[] headerBytes = await stream.ReadExpectedBytesAsync(6, cts.Token); byte[] headerBytes = await stream.ReadExpectedBytesAsync(6, cts.Token).ConfigureAwait(continueOnCapturedContext: false);
requestBytes.AddRange(headerBytes); requestBytes.AddRange(headerBytes);
ushort length = headerBytes ushort length = headerBytes
.Skip(4).Take(2).ToArray() .Skip(4).Take(2).ToArray()
.GetBigEndianUInt16(); .GetBigEndianUInt16();
byte[] bodyBytes = await stream.ReadExpectedBytesAsync(length, cts.Token); byte[] bodyBytes = await stream.ReadExpectedBytesAsync(length, cts.Token).ConfigureAwait(continueOnCapturedContext: false);
requestBytes.AddRange(bodyBytes); requestBytes.AddRange(bodyBytes);
} }
byte[] responseBytes = await HandleRequestAsync([.. requestBytes], cancellationToken); byte[] responseBytes = await HandleRequestAsync([.. requestBytes], cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (responseBytes != null) if (responseBytes != null)
await stream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken); await stream.WriteAsync(responseBytes, 0, responseBytes.Length, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
} }
} }
catch catch
@@ -249,7 +235,7 @@ namespace AMWD.Protocols.Modbus.Tcp
} }
finally finally
{ {
await _clientListLock.WaitAsync(cancellationToken); await _clientListLock.WaitAsync(cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
try try
{ {
_clients.Remove(client); _clients.Remove(client);
@@ -324,7 +310,7 @@ namespace AMWD.Protocols.Modbus.Tcp
responseBytes.AddRange(requestBytes.Take(8)); responseBytes.AddRange(requestBytes.Take(8));
try try
{ {
var coils = await Client.ReadCoilsAsync(unitId, firstAddress, count, cancellationToken); var coils = await Client.ReadCoilsAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[(int)Math.Ceiling(coils.Count / 8.0)]; byte[] values = new byte[(int)Math.Ceiling(coils.Count / 8.0)];
for (int i = 0; i < coils.Count; i++) for (int i = 0; i < coils.Count; i++)
@@ -363,7 +349,7 @@ namespace AMWD.Protocols.Modbus.Tcp
responseBytes.AddRange(requestBytes.Take(8)); responseBytes.AddRange(requestBytes.Take(8));
try try
{ {
var discreteInputs = await Client.ReadDiscreteInputsAsync(unitId, firstAddress, count, cancellationToken); var discreteInputs = await Client.ReadDiscreteInputsAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[(int)Math.Ceiling(discreteInputs.Count / 8.0)]; byte[] values = new byte[(int)Math.Ceiling(discreteInputs.Count / 8.0)];
for (int i = 0; i < discreteInputs.Count; i++) for (int i = 0; i < discreteInputs.Count; i++)
@@ -402,7 +388,7 @@ namespace AMWD.Protocols.Modbus.Tcp
responseBytes.AddRange(requestBytes.Take(8)); responseBytes.AddRange(requestBytes.Take(8));
try try
{ {
var holdingRegisters = await Client.ReadHoldingRegistersAsync(unitId, firstAddress, count, cancellationToken); var holdingRegisters = await Client.ReadHoldingRegistersAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[holdingRegisters.Count * 2]; byte[] values = new byte[holdingRegisters.Count * 2];
for (int i = 0; i < holdingRegisters.Count; i++) for (int i = 0; i < holdingRegisters.Count; i++)
@@ -436,7 +422,7 @@ namespace AMWD.Protocols.Modbus.Tcp
responseBytes.AddRange(requestBytes.Take(8)); responseBytes.AddRange(requestBytes.Take(8));
try try
{ {
var inputRegisters = await Client.ReadInputRegistersAsync(unitId, firstAddress, count, cancellationToken); var inputRegisters = await Client.ReadInputRegistersAsync(unitId, firstAddress, count, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
byte[] values = new byte[count * 2]; byte[] values = new byte[count * 2];
for (int i = 0; i < count; i++) for (int i = 0; i < count; i++)
@@ -484,7 +470,7 @@ namespace AMWD.Protocols.Modbus.Tcp
LowByte = requestBytes[11], LowByte = requestBytes[11],
}; };
bool isSuccess = await Client.WriteSingleCoilAsync(requestBytes[6], coil, cancellationToken); bool isSuccess = await Client.WriteSingleCoilAsync(requestBytes[6], coil, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -524,7 +510,7 @@ namespace AMWD.Protocols.Modbus.Tcp
LowByte = requestBytes[11] LowByte = requestBytes[11]
}; };
bool isSuccess = await Client.WriteSingleHoldingRegisterAsync(requestBytes[6], register, cancellationToken); bool isSuccess = await Client.WriteSingleHoldingRegisterAsync(requestBytes[6], register, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -584,7 +570,7 @@ namespace AMWD.Protocols.Modbus.Tcp
}); });
} }
bool isSuccess = await Client.WriteMultipleCoilsAsync(requestBytes[6], coils, cancellationToken); bool isSuccess = await Client.WriteMultipleCoilsAsync(requestBytes[6], coils, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -641,7 +627,7 @@ namespace AMWD.Protocols.Modbus.Tcp
}); });
} }
bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[6], list, cancellationToken); bool isSuccess = await Client.WriteMultipleHoldingRegistersAsync(requestBytes[6], list, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
if (isSuccess) if (isSuccess)
{ {
// Response is an echo of the request // Response is an echo of the request
@@ -698,7 +684,7 @@ namespace AMWD.Protocols.Modbus.Tcp
try try
{ {
var deviceInfo = await Client.ReadDeviceIdentificationAsync(requestBytes[6], category, firstObject, cancellationToken); var deviceInfo = await Client.ReadDeviceIdentificationAsync(requestBytes[6], category, firstObject, cancellationToken).ConfigureAwait(continueOnCapturedContext: false);
var bodyBytes = new List<byte>(); var bodyBytes = new List<byte>();
@@ -761,7 +747,7 @@ namespace AMWD.Protocols.Modbus.Tcp
} }
} }
private byte[] GetDeviceObject(byte objectId, DeviceIdentification deviceIdentification) private static byte[] GetDeviceObject(byte objectId, DeviceIdentification deviceIdentification)
{ {
var result = new List<byte> { objectId }; var result = new List<byte> { objectId };
switch ((ModbusDeviceIdentificationObject)objectId) switch ((ModbusDeviceIdentificationObject)objectId)
@@ -851,5 +837,18 @@ namespace AMWD.Protocols.Modbus.Tcp
} }
#endregion Request Handling #endregion Request Handling
/// <inheritdoc/>
public override string ToString()
{
var sb = new StringBuilder();
sb.AppendLine($"TCP Proxy");
sb.AppendLine($" {nameof(ListenAddress)}: {ListenAddress}");
sb.AppendLine($" {nameof(ListenPort)}: {ListenPort}");
sb.AppendLine($" {nameof(Client)}: {Client.GetType().Name}");
return sb.ToString();
}
} }
} }

View File

@@ -3,6 +3,9 @@ using System.Net.Sockets;
namespace AMWD.Protocols.Modbus.Tcp.Utils namespace AMWD.Protocols.Modbus.Tcp.Utils
{ {
/// <summary>
/// Factory for creating <see cref="TcpClientWrapper"/> instances.
/// </summary>
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
internal class TcpClientWrapperFactory internal class TcpClientWrapperFactory
{ {