Refactoring connection to use an idle timeout and automatically close the underlying data channel

This commit is contained in:
2024-03-31 22:29:07 +02:00
parent 967d80ff3f
commit a58af4d75f
16 changed files with 812 additions and 1198 deletions

View File

@@ -16,23 +16,14 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
string Name { get; } string Name { get; }
/// <summary> /// <summary>
/// Gets a value indicating whether the connection is open. /// Gets or sets the idle time after that the connection is closed.
/// </summary> /// </summary>
bool IsConnected { get; } /// <remarks>
/// Set to <see cref="Timeout.InfiniteTimeSpan"/> to disable idle closing the connection.
/// <summary> /// <br/>
/// Opens the connection to the remote device. /// Set to <see cref="TimeSpan.Zero"/> to close the connection immediately after each request.
/// </summary> /// </remarks>
/// <param name="cancellationToken">A cancellation token used to propagate notification that this operation should be canceled.</param> TimeSpan IdleTimeout { get; set; }
/// <returns>An awaitable <see cref="Task"/>.</returns>
Task ConnectAsync(CancellationToken cancellationToken = default);
/// <summary>
/// Closes the connection to the remote device.
/// </summary>
/// <param name="cancellationToken">A cancellation token used to propagate notification that this operation should be canceled.</param>
/// <returns>An awaitable <see cref="Task"/>.</returns>
Task DisconnectAsync(CancellationToken cancellationToken = default);
/// <summary> /// <summary>
/// Invokes a Modbus request. /// Invokes a Modbus request.

View File

@@ -46,11 +46,6 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
this.disposeConnection = disposeConnection; this.disposeConnection = disposeConnection;
} }
/// <summary>
/// Gets a value indicating whether the client is connected.
/// </summary>
public bool IsConnected => connection.IsConnected;
/// <summary> /// <summary>
/// Gets or sets the protocol type to use. /// Gets or sets the protocol type to use.
/// </summary> /// </summary>
@@ -59,28 +54,6 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
/// </remarks> /// </remarks>
public abstract IModbusProtocol Protocol { get; set; } public abstract IModbusProtocol Protocol { get; set; }
/// <summary>
/// Starts the connection to the remote endpoint.
/// </summary>
/// <param name="cancellationToken">A cancellation token used to propagate notification that this operation should be canceled.</param>
/// <returns>An awaitable <see cref="Task"/>.</returns>
public virtual Task ConnectAsync(CancellationToken cancellationToken = default)
{
Assertions(false);
return connection.ConnectAsync(cancellationToken);
}
/// <summary>
/// Stops the connection to the remote endpoint.
/// </summary>
/// <param name="cancellationToken">A cancellation token used to propagate notification that this operation should be canceled.</param>
/// <returns>An awaitable <see cref="Task"/>.</returns>
public virtual Task DisconnectAsync(CancellationToken cancellationToken = default)
{
Assertions(false);
return connection.DisconnectAsync(cancellationToken);
}
/// <summary> /// <summary>
/// Reads multiple <see cref="Coil"/>s. /// Reads multiple <see cref="Coil"/>s.
/// </summary> /// </summary>
@@ -222,31 +195,31 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
switch ((ModbusDeviceIdentificationObject)item.Key) switch ((ModbusDeviceIdentificationObject)item.Key)
{ {
case ModbusDeviceIdentificationObject.VendorName: case ModbusDeviceIdentificationObject.VendorName:
devIdent.VendorName = Encoding.ASCII.GetString(item.Value); devIdent.VendorName = Encoding.UTF8.GetString(item.Value);
break; break;
case ModbusDeviceIdentificationObject.ProductCode: case ModbusDeviceIdentificationObject.ProductCode:
devIdent.ProductCode = Encoding.ASCII.GetString(item.Value); devIdent.ProductCode = Encoding.UTF8.GetString(item.Value);
break; break;
case ModbusDeviceIdentificationObject.MajorMinorRevision: case ModbusDeviceIdentificationObject.MajorMinorRevision:
devIdent.MajorMinorRevision = Encoding.ASCII.GetString(item.Value); devIdent.MajorMinorRevision = Encoding.UTF8.GetString(item.Value);
break; break;
case ModbusDeviceIdentificationObject.VendorUrl: case ModbusDeviceIdentificationObject.VendorUrl:
devIdent.VendorUrl = Encoding.ASCII.GetString(item.Value); devIdent.VendorUrl = Encoding.UTF8.GetString(item.Value);
break; break;
case ModbusDeviceIdentificationObject.ProductName: case ModbusDeviceIdentificationObject.ProductName:
devIdent.ProductName = Encoding.ASCII.GetString(item.Value); devIdent.ProductName = Encoding.UTF8.GetString(item.Value);
break; break;
case ModbusDeviceIdentificationObject.ModelName: case ModbusDeviceIdentificationObject.ModelName:
devIdent.ModelName = Encoding.ASCII.GetString(item.Value); devIdent.ModelName = Encoding.UTF8.GetString(item.Value);
break; break;
case ModbusDeviceIdentificationObject.UserApplicationName: case ModbusDeviceIdentificationObject.UserApplicationName:
devIdent.UserApplicationName = Encoding.ASCII.GetString(item.Value); devIdent.UserApplicationName = Encoding.UTF8.GetString(item.Value);
break; break;
default: default:
@@ -375,7 +348,7 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
/// <summary> /// <summary>
/// Performs basic assertions. /// Performs basic assertions.
/// </summary> /// </summary>
protected virtual void Assertions(bool checkConnected = true) protected virtual void Assertions()
{ {
#if NET8_0_OR_GREATER #if NET8_0_OR_GREATER
ObjectDisposedException.ThrowIf(_isDisposed, this); ObjectDisposedException.ThrowIf(_isDisposed, this);
@@ -390,12 +363,6 @@ namespace AMWD.Protocols.Modbus.Common.Contracts
if (Protocol == null) if (Protocol == null)
throw new ArgumentNullException(nameof(Protocol)); throw new ArgumentNullException(nameof(Protocol));
#endif #endif
if (!checkConnected)
return;
if (!IsConnected)
throw new ApplicationException($"Connection is not open");
} }
} }
} }

View File

@@ -7,6 +7,7 @@ This package contains all basic tools to build your own clients.
**IModbusConnection** **IModbusConnection**
This is the interface used on the base client to communicate with the remote device. This is the interface used on the base client to communicate with the remote device.
If you want to use a custom connection type, you should implement this interface yourself. If you want to use a custom connection type, you should implement this interface yourself.
The `IModbusConnection` is responsible to open and close the data channel in the background.
**IModbusProtocol** **IModbusProtocol**
If you want to speak a custom type of protocol with the clients, you can implement this interface. If you want to speak a custom type of protocol with the clients, you can implement this interface.

View File

@@ -17,6 +17,7 @@
<Compile Include="../AMWD.Protocols.Modbus.Common/InternalsVisibleTo.cs" Link="InternalsVisibleTo.cs" /> <Compile Include="../AMWD.Protocols.Modbus.Common/InternalsVisibleTo.cs" Link="InternalsVisibleTo.cs" />
<Compile Include="../AMWD.Protocols.Modbus.Common/Extensions/ArrayExtensions.cs" Link="Extensions/ArrayExtensions.cs" /> <Compile Include="../AMWD.Protocols.Modbus.Common/Extensions/ArrayExtensions.cs" Link="Extensions/ArrayExtensions.cs" />
<Compile Include="../AMWD.Protocols.Modbus.Common/Extensions/ReaderWriterLockSlimExtensions.cs" Link="Extensions/ReaderWriterLockSlimExtensions.cs" /> <Compile Include="../AMWD.Protocols.Modbus.Common/Extensions/ReaderWriterLockSlimExtensions.cs" Link="Extensions/ReaderWriterLockSlimExtensions.cs" />
<Compile Include="../AMWD.Protocols.Modbus.Common/Utils/AsyncQueue.cs" Link="Utils/AsyncQueue.cs" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@@ -111,37 +111,37 @@ namespace AMWD.Protocols.Modbus.Tcp
} }
} }
/// <inheritdoc cref="ModbusTcpConnection.ReconnectTimeout"/> /// <inheritdoc cref="ModbusTcpConnection.ConnectTimeout"/>
public TimeSpan ReconnectTimeout public TimeSpan ReconnectTimeout
{ {
get get
{ {
if (connection is ModbusTcpConnection tcpConnection) if (connection is ModbusTcpConnection tcpConnection)
return tcpConnection.ReconnectTimeout; return tcpConnection.ConnectTimeout;
return default; return default;
} }
set set
{ {
if (connection is ModbusTcpConnection tcpConnection) if (connection is ModbusTcpConnection tcpConnection)
tcpConnection.ReconnectTimeout = value; tcpConnection.ConnectTimeout = value;
} }
} }
/// <inheritdoc cref="ModbusTcpConnection.KeepAliveInterval"/> /// <inheritdoc cref="ModbusTcpConnection.IdleTimeout"/>
public TimeSpan KeepAliveInterval public TimeSpan IdleTimeout
{ {
get get
{ {
if (connection is ModbusTcpConnection tcpConnection) if (connection is ModbusTcpConnection tcpConnection)
return tcpConnection.KeepAliveInterval; return tcpConnection.IdleTimeout;
return default; return default;
} }
set set
{ {
if (connection is ModbusTcpConnection tcpConnection) if (connection is ModbusTcpConnection tcpConnection)
tcpConnection.KeepAliveInterval = value; tcpConnection.IdleTimeout = value;
} }
} }
} }

View File

@@ -4,10 +4,10 @@ using System.IO;
using System.Linq; using System.Linq;
using System.Net; using System.Net;
using System.Net.Sockets; using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using AMWD.Protocols.Modbus.Common.Contracts; using AMWD.Protocols.Modbus.Common.Contracts;
using AMWD.Protocols.Modbus.Common.Protocols;
using AMWD.Protocols.Modbus.Tcp.Utils; using AMWD.Protocols.Modbus.Tcp.Utils;
namespace AMWD.Protocols.Modbus.Tcp namespace AMWD.Protocols.Modbus.Tcp
@@ -23,26 +23,33 @@ namespace AMWD.Protocols.Modbus.Tcp
private int _port; private int _port;
private bool _isDisposed; private bool _isDisposed;
private bool _isConnected; private readonly CancellationTokenSource _disposeCts = new();
private readonly SemaphoreSlim _clientLock = new(1, 1);
private readonly TcpClientWrapper _client = new(); private readonly TcpClientWrapper _client = new();
private readonly Timer _idleTimer;
private CancellationTokenSource _disconnectCts; private readonly Task _processingTask;
private Task _reconnectTask = Task.CompletedTask;
private readonly SemaphoreSlim _reconnectLock = new(1, 1);
private CancellationTokenSource _processingCts;
private Task _processingTask = Task.CompletedTask;
private readonly AsyncQueue<RequestQueueItem> _requestQueue = new(); private readonly AsyncQueue<RequestQueueItem> _requestQueue = new();
#endregion Fields #endregion Fields
/// <summary>
/// Initializes a new instance of the <see cref="ModbusTcpConnection"/> class.
/// </summary>
public ModbusTcpConnection()
{
_idleTimer = new Timer(OnIdleTimer);
_processingTask = ProcessAsync(_disposeCts.Token);
}
#region Properties #region Properties
/// <inheritdoc/> /// <inheritdoc/>
public string Name => "TCP"; public string Name => "TCP";
/// <inheritdoc/> /// <inheritdoc/>
public bool IsConnected => _isConnected && _client.Connected; public virtual TimeSpan IdleTimeout { get; set; } = TimeSpan.FromSeconds(6);
/// <summary> /// <summary>
/// The DNS name of the remote host to which the connection is intended to. /// The DNS name of the remote host to which the connection is intended to.
@@ -93,55 +100,12 @@ namespace AMWD.Protocols.Modbus.Tcp
} }
/// <summary> /// <summary>
/// Gets or sets the maximum time until the reconnect is given up. /// Gets or sets the maximum time until the connect attempt is given up.
/// </summary> /// </summary>
public virtual TimeSpan ReconnectTimeout { get; set; } = TimeSpan.MaxValue; public virtual TimeSpan ConnectTimeout { get; set; } = TimeSpan.MaxValue;
/// <summary>
/// Gets or sets the interval in which a keep alive package should be sent.
/// </summary>
public virtual TimeSpan KeepAliveInterval { get; set; } = TimeSpan.Zero;
#endregion Properties #endregion Properties
/// <inheritdoc/>
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
#if NET8_0_OR_GREATER
ObjectDisposedException.ThrowIf(_isDisposed, this);
#else
if (_isDisposed)
throw new ObjectDisposedException(GetType().FullName);
#endif
if (_disconnectCts != null)
{
await _reconnectTask;
return;
}
_disconnectCts = new CancellationTokenSource();
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_disconnectCts.Token, cancellationToken);
_reconnectTask = ReconnectInternalAsync(linkedCts.Token);
await _reconnectTask.ConfigureAwait(false);
}
/// <inheritdoc/>
public Task DisconnectAsync(CancellationToken cancellationToken = default)
{
#if NET8_0_OR_GREATER
ObjectDisposedException.ThrowIf(_isDisposed, this);
#else
if (_isDisposed)
throw new ObjectDisposedException(GetType().FullName);
#endif
if (_disconnectCts == null)
return Task.CompletedTask;
return DisconnectInternalAsync(cancellationToken);
}
/// <inheritdoc/> /// <inheritdoc/>
public void Dispose() public void Dispose()
{ {
@@ -149,13 +113,36 @@ namespace AMWD.Protocols.Modbus.Tcp
return; return;
_isDisposed = true; _isDisposed = true;
DisconnectInternalAsync(CancellationToken.None).Wait(); _disposeCts.Cancel();
_idleTimer.Dispose();
try
{
_processingTask.Wait();
_processingTask.Dispose();
}
catch
{ /* keep it quiet */ }
OnIdleTimer(null);
_client.Dispose(); _client.Dispose();
_clientLock.Dispose();
while (_requestQueue.TryDequeue(out var item))
{
item.CancellationTokenRegistration.Dispose();
item.CancellationTokenSource.Dispose();
item.TaskCompletionSource.TrySetException(new ObjectDisposedException(GetType().FullName));
}
_disposeCts.Dispose();
GC.SuppressFinalize(this); GC.SuppressFinalize(this);
} }
#region Request processing
/// <inheritdoc/> /// <inheritdoc/>
public Task<IReadOnlyList<byte>> InvokeAsync(IReadOnlyList<byte> request, Func<IReadOnlyList<byte>, bool> validateResponseComplete, CancellationToken cancellationToken = default) public Task<IReadOnlyList<byte>> InvokeAsync(IReadOnlyList<byte> request, Func<IReadOnlyList<byte>, bool> validateResponseComplete, CancellationToken cancellationToken = default)
{ {
@@ -166,10 +153,7 @@ namespace AMWD.Protocols.Modbus.Tcp
throw new ObjectDisposedException(GetType().FullName); throw new ObjectDisposedException(GetType().FullName);
#endif #endif
if (!IsConnected) if (request == null || request.Count < 1)
throw new ApplicationException($"Connection is not open");
if (request?.Count < 1)
throw new ArgumentNullException(nameof(request)); throw new ArgumentNullException(nameof(request));
#if NET8_0_OR_GREATER #if NET8_0_OR_GREATER
@@ -184,7 +168,7 @@ namespace AMWD.Protocols.Modbus.Tcp
Request = [.. request], Request = [.. request],
ValidateResponseComplete = validateResponseComplete, ValidateResponseComplete = validateResponseComplete,
TaskCompletionSource = new(), TaskCompletionSource = new(),
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken), CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken)
}; };
item.CancellationTokenRegistration = item.CancellationTokenSource.Token.Register(() => item.CancellationTokenRegistration = item.CancellationTokenSource.Token.Register(() =>
@@ -199,231 +183,183 @@ namespace AMWD.Protocols.Modbus.Tcp
return item.TaskCompletionSource.Task; return item.TaskCompletionSource.Task;
} }
private async Task ReconnectInternalAsync(CancellationToken cancellationToken)
{
if (!_reconnectLock.Wait(0, cancellationToken))
return;
try
{
_isConnected = false;
_processingCts?.Cancel();
await _processingTask.ConfigureAwait(false);
int delay = 1;
int maxDelay = 60;
var ipAddresses = Resolve(Hostname);
if (ipAddresses.Count == 0)
throw new ApplicationException($"Could not resolve hostname '{Hostname}'");
var startTime = DateTime.UtcNow;
while (!cancellationToken.IsCancellationRequested)
{
try
{
foreach (var ipAddress in ipAddresses)
{
_client.Close();
#if NET6_0_OR_GREATER
using var connectTask = _client.ConnectAsync(ipAddress, Port, cancellationToken);
#else
using var connectTask = _client.ConnectAsync(ipAddress, Port);
#endif
if (await Task.WhenAny(connectTask, Task.Delay(ReadTimeout, cancellationToken)) == connectTask)
{
await connectTask;
if (_client.Connected)
{
_isConnected = true;
_processingCts?.Dispose();
_processingCts = new();
_processingTask = ProcessAsync(_processingCts.Token);
SetKeepAlive();
return;
}
}
}
throw new SocketException((int)SocketError.TimedOut);
}
catch (SocketException) when (ReconnectTimeout == TimeSpan.MaxValue || DateTime.UtcNow.Subtract(startTime) < ReconnectTimeout)
{
delay *= 2;
if (delay > maxDelay)
delay = maxDelay;
try
{
await Task.Delay(TimeSpan.FromSeconds(delay), cancellationToken).ConfigureAwait(false);
}
catch
{ /* keep it quiet */ }
}
}
}
finally
{
_reconnectLock.Release();
}
}
private async Task DisconnectInternalAsync(CancellationToken cancellationToken)
{
_disconnectCts?.Cancel();
_processingCts?.Cancel();
try
{
await _reconnectTask.ConfigureAwait(false);
await _processingTask.ConfigureAwait(false);
}
catch
{ /* keep it quiet */ }
// Ensure that the client is closed
await _reconnectLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
_isConnected = false;
_client.Close();
}
finally
{
_reconnectLock.Release();
}
_disconnectCts?.Dispose();
_disconnectCts = null;
_processingCts?.Dispose();
_processingCts = null;
while (_requestQueue.TryDequeue(out var item))
{
item.CancellationTokenRegistration.Dispose();
item.CancellationTokenSource.Dispose();
item.TaskCompletionSource.TrySetCanceled(CancellationToken.None);
}
}
#region Processing
private async Task ProcessAsync(CancellationToken cancellationToken) private async Task ProcessAsync(CancellationToken cancellationToken)
{ {
while (!cancellationToken.IsCancellationRequested) while (!cancellationToken.IsCancellationRequested)
{ {
var item = await _requestQueue.DequeueAsync(cancellationToken).ConfigureAwait(false);
item.CancellationTokenRegistration.Dispose();
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, item.CancellationTokenSource.Token);
try try
{ {
var stream = _client.GetStream(); // Get next request to process
await stream.FlushAsync(linkedCts.Token).ConfigureAwait(false); var item = await _requestQueue.DequeueAsync(cancellationToken).ConfigureAwait(false);
#if NET6_0_OR_GREATER // Remove registration => already removed from queue
await stream.WriteAsync(item.Request, linkedCts.Token).ConfigureAwait(false); item.CancellationTokenRegistration.Dispose();
#else
await stream.WriteAsync(item.Request, 0, item.Request.Length, linkedCts.Token).ConfigureAwait(false);
#endif
linkedCts.Token.ThrowIfCancellationRequested(); // Build combined cancellation token
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, item.CancellationTokenSource.Token);
var bytes = new List<byte>(); // Wait for exclusive access
byte[] buffer = new byte[260]; await _clientLock.WaitAsync(linkedCts.Token).ConfigureAwait(false);
try
do
{ {
#if NET6_0_OR_GREATER // Ensure connection is up
int readCount = await stream.ReadAsync(buffer, linkedCts.Token).ConfigureAwait(false); await AssertConnection(linkedCts.Token).ConfigureAwait(false);
#else
int readCount = await stream.ReadAsync(buffer, 0, buffer.Length, linkedCts.Token).ConfigureAwait(false);
#endif
if (readCount < 1)
throw new EndOfStreamException();
bytes.AddRange(buffer.Take(readCount)); var stream = _client.GetStream();
await stream.FlushAsync(linkedCts.Token).ConfigureAwait(false);
#if NET6_0_OR_GREATER
await stream.WriteAsync(item.Request, linkedCts.Token).ConfigureAwait(false);
#else
await stream.WriteAsync(item.Request, 0, item.Request.Length, linkedCts.Token).ConfigureAwait(false);
#endif
linkedCts.Token.ThrowIfCancellationRequested(); linkedCts.Token.ThrowIfCancellationRequested();
}
while (!item.ValidateResponseComplete(bytes));
item.TaskCompletionSource.TrySetResult(bytes); var bytes = new List<byte>();
byte[] buffer = new byte[TcpProtocol.MAX_ADU_LENGTH];
do
{
#if NET6_0_OR_GREATER
int readCount = await stream.ReadAsync(buffer, linkedCts.Token).ConfigureAwait(false);
#else
int readCount = await stream.ReadAsync(buffer, 0, buffer.Length, linkedCts.Token).ConfigureAwait(false);
#endif
if (readCount < 1)
throw new EndOfStreamException();
bytes.AddRange(buffer.Take(readCount));
linkedCts.Token.ThrowIfCancellationRequested();
}
while (!item.ValidateResponseComplete(bytes));
item.TaskCompletionSource.TrySetResult(bytes);
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// Dispose() called
item.TaskCompletionSource.TrySetCanceled(cancellationToken);
}
catch (OperationCanceledException) when (item.CancellationTokenSource.IsCancellationRequested)
{
// Cancellation requested by user
item.TaskCompletionSource.TrySetCanceled(item.CancellationTokenSource.Token);
}
catch (Exception ex)
{
item.TaskCompletionSource.TrySetException(ex);
}
finally
{
_clientLock.Release();
_idleTimer.Change(IdleTimeout, Timeout.InfiniteTimeSpan);
}
} }
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{ {
// DisconnectAsync() called // Dispose() called while waiting for request item
item.TaskCompletionSource.TrySetCanceled(cancellationToken);
return;
}
catch (OperationCanceledException) when (item.CancellationTokenSource.IsCancellationRequested)
{
item.TaskCompletionSource.TrySetCanceled(item.CancellationTokenSource.Token);
continue;
}
catch (IOException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (SocketException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (TimeoutException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (InvalidOperationException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (Exception ex)
{
item.TaskCompletionSource.TrySetException(ex);
} }
} }
} }
internal class RequestQueueItem #endregion Request processing
#region Connection handling
// Has to be called within _clientLock!
private async Task AssertConnection(CancellationToken cancellationToken)
{ {
public byte[] Request { get; set; } if (_client.Connected)
return;
public Func<IReadOnlyList<byte>, bool> ValidateResponseComplete { get; set; } int delay = 1;
int maxDelay = 60;
public TaskCompletionSource<IReadOnlyList<byte>> TaskCompletionSource { get; set; } var ipAddresses = Resolve(Hostname);
if (ipAddresses.Length == 0)
throw new ApplicationException($"Could not resolve hostname '{Hostname}'");
public CancellationTokenSource CancellationTokenSource { get; set; } var startTime = DateTime.UtcNow;
while (!cancellationToken.IsCancellationRequested)
{
try
{
foreach (var ipAddress in ipAddresses)
{
_client.Close();
public CancellationTokenRegistration CancellationTokenRegistration { get; set; } #if NET6_0_OR_GREATER
using var connectTask = _client.ConnectAsync(ipAddress, Port, cancellationToken);
#else
using var connectTask = _client.ConnectAsync(ipAddress, Port);
#endif
if (await Task.WhenAny(connectTask, Task.Delay(ReadTimeout, cancellationToken)) == connectTask)
{
await connectTask;
if (_client.Connected)
return;
}
}
throw new SocketException((int)SocketError.TimedOut);
}
catch (SocketException) when (ConnectTimeout == TimeSpan.MaxValue || DateTime.UtcNow.Subtract(startTime) < ConnectTimeout)
{
delay *= 2;
if (delay > maxDelay)
delay = maxDelay;
try
{
await Task.Delay(TimeSpan.FromSeconds(delay), cancellationToken).ConfigureAwait(false);
}
catch
{ /* keep it quiet */ }
}
}
} }
#endregion Processing private void OnIdleTimer(object _)
{
try
{
_clientLock.Wait(_disposeCts.Token);
try
{
if (!_client.Connected)
return;
_client.Close();
}
finally
{
_clientLock.Release();
}
}
catch
{ /* keep it quiet */ }
}
#endregion Connection handling
#region Helpers #region Helpers
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
private static List<IPAddress> Resolve(string hostname) private static IPAddress[] Resolve(string hostname)
{ {
if (string.IsNullOrWhiteSpace(hostname)) if (string.IsNullOrWhiteSpace(hostname))
return []; return [];
if (IPAddress.TryParse(hostname, out var ipAddress)) if (IPAddress.TryParse(hostname, out var address))
return [ipAddress]; return [address];
try try
{ {
return Dns.GetHostAddresses(hostname) return Dns.GetHostAddresses(hostname)
.Where(a => a.AddressFamily == AddressFamily.InterNetwork || a.AddressFamily == AddressFamily.InterNetworkV6) .Where(a => a.AddressFamily == AddressFamily.InterNetwork || a.AddressFamily == AddressFamily.InterNetworkV6)
.OrderBy(a => a.AddressFamily) // Prefer IPv4 .OrderBy(a => a.AddressFamily) // prefer IPv4
.ToList(); .ToArray();
} }
catch catch
{ {
@@ -431,32 +367,6 @@ namespace AMWD.Protocols.Modbus.Tcp
} }
} }
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
private void SetKeepAlive()
{
#if NET6_0_OR_GREATER
_client.Client?.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, KeepAliveInterval.TotalMilliseconds > 0);
_client.Client?.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, (int)KeepAliveInterval.TotalSeconds);
_client.Client?.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, (int)KeepAliveInterval.TotalSeconds);
#else
// See: https://github.com/dotnet/runtime/issues/25555
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return;
bool isEnabled = KeepAliveInterval.TotalMilliseconds > 0;
uint interval = KeepAliveInterval.TotalMilliseconds > uint.MaxValue
? uint.MaxValue
: (uint)KeepAliveInterval.TotalMilliseconds;
int uIntSize = sizeof(uint);
byte[] config = new byte[uIntSize * 3];
Array.Copy(BitConverter.GetBytes(isEnabled ? 1U : 0U), 0, config, uIntSize * 0, uIntSize);
Array.Copy(BitConverter.GetBytes(interval), 0, config, uIntSize * 1, uIntSize);
Array.Copy(BitConverter.GetBytes(interval), 0, config, uIntSize * 2, uIntSize);
_client.Client?.IOControl(IOControlCode.KeepAliveValues, config, null);
#endif
}
#endregion Helpers #endregion Helpers
} }
} }

View File

@@ -21,6 +21,7 @@ namespace AMWD.Protocols.Modbus.Tcp.Utils
_stream = stream; _stream = stream;
} }
/// <inheritdoc cref="NetworkStream.Dispose" />
public virtual void Dispose() public virtual void Dispose()
=> _stream.Dispose(); => _stream.Dispose();

View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace AMWD.Protocols.Modbus.Tcp.Utils
{
internal class RequestQueueItem
{
public byte[] Request { get; set; }
public Func<IReadOnlyList<byte>, bool> ValidateResponseComplete { get; set; }
public TaskCompletionSource<IReadOnlyList<byte>> TaskCompletionSource { get; set; }
public CancellationTokenSource CancellationTokenSource { get; set; }
public CancellationTokenRegistration CancellationTokenRegistration { get; set; }
}
}

View File

@@ -1,39 +0,0 @@
using System;
using System.Net.Sockets;
namespace AMWD.Protocols.Modbus.Tcp.Utils
{
/// <inheritdoc cref="Socket"/>
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
internal class SocketWrapper : IDisposable
{
[Obsolete("Constructor only for mocking on UnitTests!", error: true)]
public SocketWrapper()
{ }
public SocketWrapper(Socket socket)
{
Client = socket;
}
public virtual Socket Client { get; }
/// <inheritdoc cref="Socket.Dispose()"/>
public virtual void Dispose()
=> Client.Dispose();
/// <inheritdoc cref="Socket.IOControl(IOControlCode, byte[], byte[])"/>
public virtual int IOControl(IOControlCode ioControlCode, byte[] optionInValue, byte[] optionOutValue)
=> Client.IOControl(ioControlCode, optionInValue, optionOutValue);
#if NET6_0_OR_GREATER
/// <inheritdoc cref="Socket.SetSocketOption(SocketOptionLevel, SocketOptionName, bool)"/>
public virtual void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName optionName, bool optionValue)
=> Client.SetSocketOption(optionLevel, optionName, optionValue);
/// <inheritdoc cref="Socket.SetSocketOption(SocketOptionLevel, SocketOptionName, int)"/>
public virtual void SetSocketOption(SocketOptionLevel optionLevel, SocketOptionName optionName, int optionValue)
=> Client.SetSocketOption(optionLevel, optionName, optionValue);
#endif
}
}

View File

@@ -10,8 +10,17 @@ namespace AMWD.Protocols.Modbus.Tcp.Utils
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage] [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
internal class TcpClientWrapper : IDisposable internal class TcpClientWrapper : IDisposable
{ {
#region Fields
private readonly TcpClient _client = new(); private readonly TcpClient _client = new();
#endregion Fields
#region Properties
/// <inheritdoc cref="TcpClient.Connected" />
public virtual bool Connected => _client.Connected;
/// <inheritdoc cref="TcpClient.ReceiveTimeout" /> /// <inheritdoc cref="TcpClient.ReceiveTimeout" />
public virtual int ReceiveTimeout public virtual int ReceiveTimeout
{ {
@@ -26,15 +35,9 @@ namespace AMWD.Protocols.Modbus.Tcp.Utils
set => _client.SendTimeout = value; set => _client.SendTimeout = value;
} }
/// <inheritdoc cref="TcpClient.Connected" /> #endregion Properties
public virtual bool Connected => _client.Connected;
/// <inheritdoc cref="TcpClient.Client" /> #region Methods
public virtual SocketWrapper Client
{
get => new(_client.Client);
set => _client.Client = value.Client;
}
/// <inheritdoc cref="TcpClient.Close" /> /// <inheritdoc cref="TcpClient.Close" />
public virtual void Close() public virtual void Close()
@@ -52,12 +55,18 @@ namespace AMWD.Protocols.Modbus.Tcp.Utils
#endif #endif
/// <inheritdoc cref="TcpClient.GetStream" />
public virtual NetworkStreamWrapper GetStream()
=> new(_client.GetStream());
#endregion Methods
#region IDisposable
/// <inheritdoc cref="TcpClient.Dispose()" /> /// <inheritdoc cref="TcpClient.Dispose()" />
public virtual void Dispose() public virtual void Dispose()
=> _client.Dispose(); => _client.Dispose();
/// <inheritdoc cref="TcpClient.GetStream" /> #endregion IDisposable
public virtual NetworkStreamWrapper GetStream()
=> new(_client.GetStream());
} }
} }

View File

@@ -20,7 +20,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
private Mock<IModbusProtocol> _protocol; private Mock<IModbusProtocol> _protocol;
// Responses // Responses
private bool _connectionIsConnectecd;
private List<Coil> _readCoilsResponse; private List<Coil> _readCoilsResponse;
private List<DiscreteInput> _readDiscreteInputsResponse; private List<DiscreteInput> _readDiscreteInputsResponse;
private List<HoldingRegister> _readHoldingRegistersResponse; private List<HoldingRegister> _readHoldingRegistersResponse;
@@ -35,8 +34,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
[TestInitialize] [TestInitialize]
public void Initialize() public void Initialize()
{ {
_connectionIsConnectecd = true;
_readCoilsResponse = []; _readCoilsResponse = [];
_readDiscreteInputsResponse = []; _readDiscreteInputsResponse = [];
_readHoldingRegistersResponse = []; _readHoldingRegistersResponse = [];
@@ -75,13 +72,13 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
MoreRequestsNeeded = false, MoreRequestsNeeded = false,
NextObjectIdToRequest = 0x00, NextObjectIdToRequest = 0x00,
}; };
_firstDeviceIdentificationResponse.Objects.Add(0x00, Encoding.ASCII.GetBytes("AM.WD")); _firstDeviceIdentificationResponse.Objects.Add(0x00, Encoding.UTF8.GetBytes("AM.WD"));
_firstDeviceIdentificationResponse.Objects.Add(0x01, Encoding.ASCII.GetBytes("AMWD-MB")); _firstDeviceIdentificationResponse.Objects.Add(0x01, Encoding.UTF8.GetBytes("AMWD-MB"));
_firstDeviceIdentificationResponse.Objects.Add(0x02, Encoding.ASCII.GetBytes("1.2.3")); _firstDeviceIdentificationResponse.Objects.Add(0x02, Encoding.UTF8.GetBytes("1.2.3"));
_firstDeviceIdentificationResponse.Objects.Add(0x03, Encoding.ASCII.GetBytes("https://github.com/AM-WD/AMWD.Protocols.Modbus")); _firstDeviceIdentificationResponse.Objects.Add(0x03, Encoding.UTF8.GetBytes("https://github.com/AM-WD/AMWD.Protocols.Modbus"));
_firstDeviceIdentificationResponse.Objects.Add(0x04, Encoding.ASCII.GetBytes("AM.WD Modbus Library")); _firstDeviceIdentificationResponse.Objects.Add(0x04, Encoding.UTF8.GetBytes("AM.WD Modbus Library"));
_firstDeviceIdentificationResponse.Objects.Add(0x05, Encoding.ASCII.GetBytes("UnitTests")); _firstDeviceIdentificationResponse.Objects.Add(0x05, Encoding.UTF8.GetBytes("UnitTests"));
_firstDeviceIdentificationResponse.Objects.Add(0x06, Encoding.ASCII.GetBytes("Modbus Client Base Unit Test")); _firstDeviceIdentificationResponse.Objects.Add(0x06, Encoding.UTF8.GetBytes("Modbus Client Base Unit Test"));
_deviceIdentificationResponseQueue = new Queue<DeviceIdentificationRaw>(); _deviceIdentificationResponseQueue = new Queue<DeviceIdentificationRaw>();
_deviceIdentificationResponseQueue.Enqueue(_firstDeviceIdentificationResponse); _deviceIdentificationResponseQueue.Enqueue(_firstDeviceIdentificationResponse);
@@ -121,38 +118,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert - ArgumentNullException // Assert - ArgumentNullException
} }
[TestMethod]
public async Task ShouldConnectSuccessfully()
{
// Arrange
var client = GetClient();
// Act
await client.ConnectAsync();
// Assert
_connection.Verify(c => c.ConnectAsync(It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls();
_protocol.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldDisconnectSuccessfully()
{
// Arrange
var client = GetClient();
// Act
await client.DisconnectAsync();
// Assert
_connection.Verify(c => c.DisconnectAsync(It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls();
_protocol.VerifyNoOtherCalls();
}
[DataTestMethod] [DataTestMethod]
[DataRow(true)] [DataRow(true)]
[DataRow(false)] [DataRow(false)]
@@ -218,20 +183,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert - ArgumentNullException // Assert - ArgumentNullException
} }
[TestMethod]
[ExpectedException(typeof(ApplicationException))]
public async Task ShouldAssertConnected()
{
// Arrange
_connectionIsConnectecd = false;
var client = GetClient();
// Act
await client.ReadCoilsAsync(UNIT_ID, START_ADDRESS, READ_COUNT);
// Assert - ApplicationException
}
#endregion Common/Connection/Assertions #endregion Common/Connection/Assertions
#region Read #region Read
@@ -256,7 +207,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
Assert.AreEqual(i % 2 == 0, result[i].Value); Assert.AreEqual(i % 2 == 0, result[i].Value);
} }
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -286,7 +236,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
Assert.AreEqual(i % 2 == 1, result[i].Value); Assert.AreEqual(i % 2 == 1, result[i].Value);
} }
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -315,7 +264,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
Assert.AreEqual(i + 10, result[i].Value); Assert.AreEqual(i + 10, result[i].Value);
} }
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -344,7 +292,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
Assert.AreEqual(i + 15, result[i].Value); Assert.AreEqual(i + 15, result[i].Value);
} }
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -376,7 +323,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
Assert.AreEqual(0, result.ExtendedObjects.Count); Assert.AreEqual(0, result.ExtendedObjects.Count);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -422,7 +368,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
Assert.AreEqual(0x07, result.ExtendedObjects.First().Key); Assert.AreEqual(0x07, result.ExtendedObjects.First().Key);
CollectionAssert.AreEqual(new byte[] { 0x01, 0x02, 0x03 }, result.ExtendedObjects.First().Value); CollectionAssert.AreEqual(new byte[] { 0x01, 0x02, 0x03 }, result.ExtendedObjects.First().Value);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Exactly(2)); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Exactly(2));
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -454,7 +399,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsTrue(result); Assert.IsTrue(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -481,7 +425,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -508,7 +451,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -535,7 +477,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsTrue(result); Assert.IsTrue(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -562,7 +503,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -589,7 +529,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -620,7 +559,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsTrue(result); Assert.IsTrue(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -652,7 +590,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -684,7 +621,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -715,7 +651,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsTrue(result); Assert.IsTrue(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -747,7 +682,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -779,7 +713,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
// Assert // Assert
Assert.IsFalse(result); Assert.IsFalse(result);
_connection.VerifyGet(c => c.IsConnected, Times.Once);
_connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once); _connection.Verify(c => c.InvokeAsync(It.IsAny<IReadOnlyList<byte>>(), It.IsAny<Func<IReadOnlyList<byte>, bool>>(), It.IsAny<CancellationToken>()), Times.Once);
_connection.VerifyNoOtherCalls(); _connection.VerifyNoOtherCalls();
@@ -797,9 +730,6 @@ namespace AMWD.Protocols.Modbus.Tests.Common.Contracts
_connection _connection
.SetupGet(c => c.Name) .SetupGet(c => c.Name)
.Returns("Mock"); .Returns("Mock");
_connection
.SetupGet(c => c.IsConnected)
.Returns(() => _connectionIsConnectecd);
_protocol = new Mock<IModbusProtocol>(); _protocol = new Mock<IModbusProtocol>();
_protocol _protocol

View File

@@ -19,8 +19,8 @@ namespace AMWD.Protocols.Modbus.Tests.Tcp
_tcpConnectionMock.Setup(c => c.Port).Returns(502); _tcpConnectionMock.Setup(c => c.Port).Returns(502);
_tcpConnectionMock.Setup(c => c.ReadTimeout).Returns(TimeSpan.FromSeconds(10)); _tcpConnectionMock.Setup(c => c.ReadTimeout).Returns(TimeSpan.FromSeconds(10));
_tcpConnectionMock.Setup(c => c.WriteTimeout).Returns(TimeSpan.FromSeconds(20)); _tcpConnectionMock.Setup(c => c.WriteTimeout).Returns(TimeSpan.FromSeconds(20));
_tcpConnectionMock.Setup(c => c.ReconnectTimeout).Returns(TimeSpan.FromSeconds(30)); _tcpConnectionMock.Setup(c => c.ConnectTimeout).Returns(TimeSpan.FromSeconds(30));
_tcpConnectionMock.Setup(c => c.KeepAliveInterval).Returns(TimeSpan.FromSeconds(40)); _tcpConnectionMock.Setup(c => c.IdleTimeout).Returns(TimeSpan.FromSeconds(40));
} }
[TestMethod] [TestMethod]
@@ -35,7 +35,7 @@ namespace AMWD.Protocols.Modbus.Tests.Tcp
TimeSpan readTimeout = client.ReadTimeout; TimeSpan readTimeout = client.ReadTimeout;
TimeSpan writeTimeout = client.WriteTimeout; TimeSpan writeTimeout = client.WriteTimeout;
TimeSpan reconnectTimeout = client.ReconnectTimeout; TimeSpan reconnectTimeout = client.ReconnectTimeout;
TimeSpan keepAliveInterval = client.KeepAliveInterval; TimeSpan idleTimeout = client.IdleTimeout;
// Assert // Assert
Assert.IsNull(hostname); Assert.IsNull(hostname);
@@ -43,7 +43,7 @@ namespace AMWD.Protocols.Modbus.Tests.Tcp
Assert.AreEqual(TimeSpan.Zero, readTimeout); Assert.AreEqual(TimeSpan.Zero, readTimeout);
Assert.AreEqual(TimeSpan.Zero, writeTimeout); Assert.AreEqual(TimeSpan.Zero, writeTimeout);
Assert.AreEqual(TimeSpan.Zero, reconnectTimeout); Assert.AreEqual(TimeSpan.Zero, reconnectTimeout);
Assert.AreEqual(TimeSpan.Zero, keepAliveInterval); Assert.AreEqual(TimeSpan.Zero, idleTimeout);
_genericConnectionMock.VerifyNoOtherCalls(); _genericConnectionMock.VerifyNoOtherCalls();
} }
@@ -60,7 +60,7 @@ namespace AMWD.Protocols.Modbus.Tests.Tcp
client.ReadTimeout = TimeSpan.FromSeconds(123); client.ReadTimeout = TimeSpan.FromSeconds(123);
client.WriteTimeout = TimeSpan.FromSeconds(456); client.WriteTimeout = TimeSpan.FromSeconds(456);
client.ReconnectTimeout = TimeSpan.FromSeconds(789); client.ReconnectTimeout = TimeSpan.FromSeconds(789);
client.KeepAliveInterval = TimeSpan.FromSeconds(321); client.IdleTimeout = TimeSpan.FromSeconds(321);
// Assert // Assert
_genericConnectionMock.VerifyNoOtherCalls(); _genericConnectionMock.VerifyNoOtherCalls();
@@ -78,7 +78,7 @@ namespace AMWD.Protocols.Modbus.Tests.Tcp
TimeSpan readTimeout = client.ReadTimeout; TimeSpan readTimeout = client.ReadTimeout;
TimeSpan writeTimeout = client.WriteTimeout; TimeSpan writeTimeout = client.WriteTimeout;
TimeSpan reconnectTimeout = client.ReconnectTimeout; TimeSpan reconnectTimeout = client.ReconnectTimeout;
TimeSpan keepAliveInterval = client.KeepAliveInterval; TimeSpan keepAliveInterval = client.IdleTimeout;
// Assert // Assert
Assert.AreEqual("127.0.0.1", hostname); Assert.AreEqual("127.0.0.1", hostname);
@@ -92,8 +92,8 @@ namespace AMWD.Protocols.Modbus.Tests.Tcp
_tcpConnectionMock.VerifyGet(c => c.Port, Times.Once); _tcpConnectionMock.VerifyGet(c => c.Port, Times.Once);
_tcpConnectionMock.VerifyGet(c => c.ReadTimeout, Times.Once); _tcpConnectionMock.VerifyGet(c => c.ReadTimeout, Times.Once);
_tcpConnectionMock.VerifyGet(c => c.WriteTimeout, Times.Once); _tcpConnectionMock.VerifyGet(c => c.WriteTimeout, Times.Once);
_tcpConnectionMock.VerifyGet(c => c.ReconnectTimeout, Times.Once); _tcpConnectionMock.VerifyGet(c => c.ConnectTimeout, Times.Once);
_tcpConnectionMock.VerifyGet(c => c.KeepAliveInterval, Times.Once); _tcpConnectionMock.VerifyGet(c => c.IdleTimeout, Times.Once);
_tcpConnectionMock.VerifyNoOtherCalls(); _tcpConnectionMock.VerifyNoOtherCalls();
} }
@@ -109,15 +109,15 @@ namespace AMWD.Protocols.Modbus.Tests.Tcp
client.ReadTimeout = TimeSpan.FromSeconds(123); client.ReadTimeout = TimeSpan.FromSeconds(123);
client.WriteTimeout = TimeSpan.FromSeconds(456); client.WriteTimeout = TimeSpan.FromSeconds(456);
client.ReconnectTimeout = TimeSpan.FromSeconds(789); client.ReconnectTimeout = TimeSpan.FromSeconds(789);
client.KeepAliveInterval = TimeSpan.FromSeconds(321); client.IdleTimeout = TimeSpan.FromSeconds(321);
// Assert // Assert
_tcpConnectionMock.VerifySet(c => c.Hostname = "localhost", Times.Once); _tcpConnectionMock.VerifySet(c => c.Hostname = "localhost", Times.Once);
_tcpConnectionMock.VerifySet(c => c.Port = 205, Times.Once); _tcpConnectionMock.VerifySet(c => c.Port = 205, Times.Once);
_tcpConnectionMock.VerifySet(c => c.ReadTimeout = TimeSpan.FromSeconds(123), Times.Once); _tcpConnectionMock.VerifySet(c => c.ReadTimeout = TimeSpan.FromSeconds(123), Times.Once);
_tcpConnectionMock.VerifySet(c => c.WriteTimeout = TimeSpan.FromSeconds(456), Times.Once); _tcpConnectionMock.VerifySet(c => c.WriteTimeout = TimeSpan.FromSeconds(456), Times.Once);
_tcpConnectionMock.VerifySet(c => c.ReconnectTimeout = TimeSpan.FromSeconds(789), Times.Once); _tcpConnectionMock.VerifySet(c => c.ConnectTimeout = TimeSpan.FromSeconds(789), Times.Once);
_tcpConnectionMock.VerifySet(c => c.KeepAliveInterval = TimeSpan.FromSeconds(321), Times.Once); _tcpConnectionMock.VerifySet(c => c.IdleTimeout = TimeSpan.FromSeconds(321), Times.Once);
_tcpConnectionMock.VerifyNoOtherCalls(); _tcpConnectionMock.VerifyNoOtherCalls();
} }
} }

View File

@@ -0,0 +1,545 @@
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using AMWD.Protocols.Modbus.Common.Contracts;
using AMWD.Protocols.Modbus.Tcp;
using AMWD.Protocols.Modbus.Tcp.Utils;
using Moq;
namespace AMWD.Protocols.Modbus.Tests.Tcp
{
[TestClass]
public class ModbusTcpConnectionTest
{
private readonly string _hostname = "127.0.0.1";
private Mock<TcpClientWrapper> _tcpClientMock;
private Mock<NetworkStreamWrapper> _networkStreamMock;
private bool _alwaysConnected;
private Queue<bool> _connectedQueue;
private readonly int _clientReceiveTimeout = 1000;
private readonly int _clientSendTimeout = 1000;
private readonly Task _clientConnectTask = Task.CompletedTask;
private List<byte[]> _networkRequestCallbacks;
private Queue<byte[]> _networkResponseQueue;
[TestInitialize]
public void Initialize()
{
_alwaysConnected = true;
_connectedQueue = new Queue<bool>();
_networkRequestCallbacks = [];
_networkResponseQueue = new Queue<byte[]>();
}
[TestMethod]
public void ShouldGetAndSetPropertiesOfBaseClient()
{
// Arrange
var connection = GetTcpConnection();
// Act
connection.ReadTimeout = TimeSpan.FromSeconds(123);
connection.WriteTimeout = TimeSpan.FromSeconds(456);
// Assert - part 1
Assert.AreEqual("TCP", connection.Name);
Assert.AreEqual(1, connection.ReadTimeout.TotalSeconds);
Assert.AreEqual(1, connection.WriteTimeout.TotalSeconds);
Assert.AreEqual(_hostname, connection.Hostname);
Assert.AreEqual(502, connection.Port);
// Assert - part 2
_tcpClientMock.VerifySet(c => c.ReceiveTimeout = 123000, Times.Once);
_tcpClientMock.VerifySet(c => c.SendTimeout = 456000, Times.Once);
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Once);
_tcpClientMock.VerifyGet(c => c.SendTimeout, Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[DataTestMethod]
[DataRow(null)]
[DataRow("")]
[DataRow(" ")]
[ExpectedException(typeof(ArgumentNullException))]
public void ShouldThrowArgumentNullExceptionForInvalidHostname(string hostname)
{
// Arrange
var connection = GetTcpConnection();
// Act
connection.Hostname = hostname;
// Assert - ArgumentNullException
}
[DataTestMethod]
[DataRow(0)]
[DataRow(65536)]
[ExpectedException(typeof(ArgumentOutOfRangeException))]
public void ShouldThrowArgumentOutOfRangeExceptionForInvalidPort(int port)
{
// Arrange
var connection = GetTcpConnection();
// Act
connection.Port = port;
// Assert - ArgumentOutOfRangeException
}
[TestMethod]
public void ShouldBeAbleToDisposeMultipleTimes()
{
// Arrange
var connection = GetConnection();
// Act
connection.Dispose();
connection.Dispose();
}
[TestMethod]
[ExpectedException(typeof(ObjectDisposedException))]
public async Task ShouldThrowDisposedExceptionOnInvokeAsync()
{
// Arrange
var connection = GetConnection();
connection.Dispose();
// Act
await connection.InvokeAsync(null, null);
// Assert - OjbectDisposedException
}
[DataTestMethod]
[DataRow(null)]
[DataRow(new byte[0])]
[ExpectedException(typeof(ArgumentNullException))]
public async Task ShouldThrowArgumentNullExceptionForMissingRequestOnInvokeAsync(byte[] request)
{
// Arrange
var connection = GetConnection();
// Act
await connection.InvokeAsync(request, null);
// Assert - ArgumentNullException
}
[TestMethod]
[ExpectedException(typeof(ArgumentNullException))]
public async Task ShouldThrowArgumentNullExceptionForMissingValidationOnInvokeAsync()
{
// Arrange
byte[] request = new byte[1];
var connection = GetConnection();
// Act
await connection.InvokeAsync(request, null);
// Assert - ArgumentNullException
}
[TestMethod]
public async Task ShouldInvokeAsync()
{
// Arrange
byte[] request = [1, 2, 3];
byte[] expectedResponse = [9, 8, 7];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
_networkResponseQueue.Enqueue(expectedResponse);
var connection = GetConnection();
// Act
var response = await connection.InvokeAsync(request, validation);
// Assert
Assert.IsNotNull(response);
CollectionAssert.AreEqual(expectedResponse, response.ToArray());
CollectionAssert.AreEqual(request, _networkRequestCallbacks.First());
_tcpClientMock.Verify(c => c.Connected, Times.Once);
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldConnectAndDisconnectOnInvokeAsync()
{
// Arrange
_alwaysConnected = false;
_connectedQueue.Enqueue(false);
_connectedQueue.Enqueue(true);
_connectedQueue.Enqueue(true);
byte[] request = [1, 2, 3];
byte[] expectedResponse = [9, 8, 7];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
_networkResponseQueue.Enqueue(expectedResponse);
var connection = GetConnection();
connection.IdleTimeout = TimeSpan.FromMilliseconds(200);
// Act
var response = await connection.InvokeAsync(request, validation);
await Task.Delay(500);
// Assert
Assert.IsNotNull(response);
CollectionAssert.AreEqual(expectedResponse, response.ToArray());
CollectionAssert.AreEqual(request, _networkRequestCallbacks.First());
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Once);
_tcpClientMock.Verify(c => c.Connected, Times.Exactly(3));
_tcpClientMock.Verify(c => c.Close(), Times.Exactly(2));
_tcpClientMock.Verify(c => c.ConnectAsync(It.IsAny<IPAddress>(), It.IsAny<int>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
[ExpectedException(typeof(EndOfStreamException))]
public async Task ShouldThrowEndOfStreamExceptionOnInvokeAsync()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
// Act
var response = await connection.InvokeAsync(request, validation);
// Assert - EndOfStreamException
}
[TestMethod]
[ExpectedException(typeof(ApplicationException))]
public async Task ShouldThrowApplicationExceptionWhenHostNotResolvableOnInvokeAsync()
{
// Arrange
_alwaysConnected = false;
_connectedQueue.Enqueue(false);
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
connection.GetType().GetField("_hostname", BindingFlags.NonPublic | BindingFlags.Instance).SetValue(connection, "");
// Act
var response = await connection.InvokeAsync(request, validation);
// Assert - ApplicationException
}
[TestMethod]
public async Task ShouldSkipCloseOnTimeoutOnInvokeAsync()
{
// Arrange
_alwaysConnected = false;
_connectedQueue.Enqueue(false);
_connectedQueue.Enqueue(true);
_connectedQueue.Enqueue(false);
byte[] request = [1, 2, 3];
byte[] expectedResponse = [9, 8, 7];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
_networkResponseQueue.Enqueue(expectedResponse);
var connection = GetConnection();
connection.IdleTimeout = TimeSpan.FromMilliseconds(200);
// Act
var response = await connection.InvokeAsync(request, validation);
await Task.Delay(500);
// Assert
Assert.IsNotNull(response);
CollectionAssert.AreEqual(expectedResponse, response.ToArray());
CollectionAssert.AreEqual(request, _networkRequestCallbacks.First());
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Once);
_tcpClientMock.Verify(c => c.Connected, Times.Exactly(3));
_tcpClientMock.Verify(c => c.Close(), Times.Once);
_tcpClientMock.Verify(c => c.ConnectAsync(It.IsAny<IPAddress>(), It.IsAny<int>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldRetryToConnectOnInvokeAsync()
{
// Arrange
_alwaysConnected = false;
_connectedQueue.Enqueue(false);
_connectedQueue.Enqueue(false);
_connectedQueue.Enqueue(true);
byte[] request = [1, 2, 3];
byte[] expectedResponse = [9, 8, 7];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
_networkResponseQueue.Enqueue(expectedResponse);
var connection = GetConnection();
// Act
var response = await connection.InvokeAsync(request, validation);
// Assert
Assert.IsNotNull(response);
CollectionAssert.AreEqual(expectedResponse, response.ToArray());
CollectionAssert.AreEqual(request, _networkRequestCallbacks.First());
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Exactly(2));
_tcpClientMock.Verify(c => c.Connected, Times.Exactly(3));
_tcpClientMock.Verify(c => c.Close(), Times.Exactly(2));
_tcpClientMock.Verify(c => c.ConnectAsync(It.IsAny<IPAddress>(), It.IsAny<int>(), It.IsAny<CancellationToken>()), Times.Exactly(2));
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
[ExpectedException(typeof(TaskCanceledException))]
public async Task ShouldThrowTaskCancelledExceptionForDisposeOnInvokeAsync()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask(Task.Delay(100)));
// Act
var task = connection.InvokeAsync(request, validation);
connection.Dispose();
await task;
// Assert - TaskCancelledException
}
[TestMethod]
[ExpectedException(typeof(TaskCanceledException))]
public async Task ShouldThrowTaskCancelledExceptionForCancelOnInvokeAsync()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
using var cts = new CancellationTokenSource();
var connection = GetConnection();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask(Task.Delay(100)));
// Act
var task = connection.InvokeAsync(request, validation, cts.Token);
cts.Cancel();
await task;
// Assert - TaskCancelledException
}
[TestMethod]
public async Task ShouldRemoveRequestFromQueueOnInvokeAsync()
{
// Arrange
byte[] request = [1, 2, 3];
byte[] expectedResponse = [9, 8, 7];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
_networkResponseQueue.Enqueue(expectedResponse);
using var cts = new CancellationTokenSource();
var connection = GetConnection();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.Returns(new ValueTask(Task.Delay(100)));
// Act
var taskToComplete = connection.InvokeAsync(request, validation);
var taskToCancel = connection.InvokeAsync(request, validation, cts.Token);
cts.Cancel();
var response = await taskToComplete;
// Assert - Part 1
try
{
await taskToCancel;
Assert.Fail();
}
catch (TaskCanceledException)
{ /* expected exception */ }
// Assert - Part 2
Assert.AreEqual(1, _networkRequestCallbacks.Count);
CollectionAssert.AreEqual(request, _networkRequestCallbacks.First());
CollectionAssert.AreEqual(expectedResponse, response.ToArray());
_tcpClientMock.Verify(c => c.Connected, Times.Once);
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldRemoveRequestFromQueueOnDispose()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.Returns(new ValueTask(Task.Delay(100)));
// Act
var taskToCancel = connection.InvokeAsync(request, validation);
var taskToDequeue = connection.InvokeAsync(request, validation);
connection.Dispose();
// Assert
try
{
await taskToCancel;
Assert.Fail();
}
catch (TaskCanceledException)
{ /* expected exception */ }
try
{
await taskToDequeue;
Assert.Fail();
}
catch (ObjectDisposedException)
{ /* expected exception */ }
Assert.AreEqual(1, _networkRequestCallbacks.Count);
CollectionAssert.AreEqual(request, _networkRequestCallbacks.First());
_tcpClientMock.Verify(c => c.Connected, Times.Once);
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_tcpClientMock.Verify(c => c.Dispose(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
private IModbusConnection GetConnection()
=> GetTcpConnection();
private ModbusTcpConnection GetTcpConnection()
{
_networkStreamMock = new Mock<NetworkStreamWrapper>();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.Returns(ValueTask.CompletedTask);
_networkStreamMock
.Setup(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()))
.Returns<Memory<byte>, CancellationToken>((buffer, _) =>
{
if (_networkResponseQueue.TryDequeue(out byte[] bytes))
{
bytes.CopyTo(buffer);
return ValueTask.FromResult(bytes.Length);
}
return ValueTask.FromResult(0);
});
_tcpClientMock = new Mock<TcpClientWrapper>();
_tcpClientMock.Setup(c => c.Connected).Returns(() => _alwaysConnected || _connectedQueue.Dequeue());
_tcpClientMock.Setup(c => c.ReceiveTimeout).Returns(() => _clientReceiveTimeout);
_tcpClientMock.Setup(c => c.SendTimeout).Returns(() => _clientSendTimeout);
_tcpClientMock
.Setup(c => c.ConnectAsync(It.IsAny<IPAddress>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns(() => _clientConnectTask);
_tcpClientMock
.Setup(c => c.GetStream())
.Returns(() => _networkStreamMock.Object);
var connection = new ModbusTcpConnection
{
Hostname = _hostname,
Port = 502
};
// Replace real TCP client with mock
var clientField = connection.GetType().GetField("_client", BindingFlags.NonPublic | BindingFlags.Instance);
(clientField.GetValue(connection) as TcpClientWrapper)?.Dispose();
clientField.SetValue(connection, _tcpClientMock.Object);
return connection;
}
private void ClearInvocations()
{
_networkStreamMock.Invocations.Clear();
_tcpClientMock.Invocations.Clear();
}
}
}

View File

@@ -1,723 +0,0 @@
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using AMWD.Protocols.Modbus.Common.Contracts;
using AMWD.Protocols.Modbus.Tcp;
using AMWD.Protocols.Modbus.Tcp.Utils;
using Moq;
namespace AMWD.Protocols.Modbus.Tests.Tcp.Utils
{
[TestClass]
public class ModbusTcpConnectionTest
{
private string _hostname = "127.0.0.1";
private Mock<TcpClientWrapper> _tcpClientMock;
private Mock<NetworkStreamWrapper> _networkStreamMock;
private Mock<SocketWrapper> _socketMock;
private bool _clientIsAlwaysConnected;
private Queue<bool> _clientIsConnectedQueue;
private int _clientReceiveTimeout = 1000;
private int _clientSendTimeout = 1000;
private Task _clientConnectTask = Task.CompletedTask;
private List<byte[]> _networkRequestCallbacks;
private Queue<byte[]> _networkResponseQueue;
[TestInitialize]
public void Initialize()
{
_clientIsAlwaysConnected = true;
_clientIsConnectedQueue = new Queue<bool>();
_networkRequestCallbacks = [];
_networkResponseQueue = new Queue<byte[]>();
}
[TestMethod]
public void ShouldGetAndSetPropertiesOfBaseClient()
{
// Arrange
_clientIsAlwaysConnected = false;
_clientIsConnectedQueue.Enqueue(true);
var connection = GetTcpConnection();
connection.GetType().GetField("_isConnected", BindingFlags.NonPublic | BindingFlags.Instance).SetValue(connection, true);
// Act
connection.ReadTimeout = TimeSpan.FromSeconds(123);
connection.WriteTimeout = TimeSpan.FromSeconds(456);
// Assert - part 1
Assert.AreEqual("TCP", connection.Name);
Assert.AreEqual(1, connection.ReadTimeout.TotalSeconds);
Assert.AreEqual(1, connection.WriteTimeout.TotalSeconds);
Assert.IsTrue(connection.IsConnected);
// Assert - part 2
_tcpClientMock.VerifySet(c => c.ReceiveTimeout = 123000, Times.Once);
_tcpClientMock.VerifySet(c => c.SendTimeout = 456000, Times.Once);
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Once);
_tcpClientMock.VerifyGet(c => c.SendTimeout, Times.Once);
_tcpClientMock.VerifyGet(c => c.Connected, Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[DataTestMethod]
[DataRow(null)]
[DataRow("")]
[DataRow(" ")]
[ExpectedException(typeof(ArgumentNullException))]
public void ShouldThrowArumentNullExceptionForInvalidHostname(string hostname)
{
// Arrange
var connection = GetTcpConnection();
// Act
connection.Hostname = hostname;
// Assert - ArgumentNullException
}
[DataTestMethod]
[DataRow(0)]
[DataRow(65536)]
[ExpectedException(typeof(ArgumentOutOfRangeException))]
public void ShouldThrowArumentOutOfRangeExceptionForInvalidPort(int port)
{
// Arrange
var connection = GetTcpConnection();
// Act
connection.Port = port;
// Assert - ArgumentOutOfRangeException
}
[TestMethod]
public async Task ShouldConnectAsync()
{
// Arrange
var connection = GetConnection();
// Act
await connection.ConnectAsync();
// Assert
Assert.IsTrue(connection.IsConnected);
_tcpClientMock.Verify(c => c.Close(), Times.Once);
_tcpClientMock.Verify(c => c.ConnectAsync(IPAddress.Loopback, 502, It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Once);
_tcpClientMock.VerifyGet(c => c.Connected, Times.Exactly(2));
_tcpClientMock.VerifyGet(c => c.Client, Times.Exactly(3));
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, false), Times.Once);
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, 0), Times.Once);
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, 0), Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldOnlyConnectAsyncOnce()
{
// Arrange
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
// Act
await connection.ConnectAsync();
// Assert
Assert.IsTrue(connection.IsConnected);
_tcpClientMock.VerifyGet(c => c.Connected, Times.Once);
_socketMock.VerifyNoOtherCalls();
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
[ExpectedException(typeof(ApplicationException))]
public async Task ShouldThrowApplicationExceptionHostnameNotResolvable()
{
// Arrange
var connection = GetConnection();
connection.GetType().GetField("_hostname", BindingFlags.NonPublic | BindingFlags.Instance).SetValue(connection, "");
// Act
await connection.ConnectAsync();
// Assert - ApplicationException
}
[TestMethod]
public async Task ShouldRetryConnectAsync()
{
// Arrange
_clientIsAlwaysConnected = false;
_clientIsConnectedQueue.Enqueue(false);
_clientIsConnectedQueue.Enqueue(true);
_clientIsConnectedQueue.Enqueue(true);
var connection = GetConnection();
// Act
await connection.ConnectAsync();
// Assert
Assert.IsTrue(connection.IsConnected);
_tcpClientMock.Verify(c => c.Close(), Times.Exactly(2));
_tcpClientMock.Verify(c => c.ConnectAsync(IPAddress.Loopback, 502, It.IsAny<CancellationToken>()), Times.Exactly(2));
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Exactly(2));
_tcpClientMock.VerifyGet(c => c.Connected, Times.Exactly(3));
_tcpClientMock.VerifyGet(c => c.Client, Times.Exactly(3));
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, false), Times.Once);
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, 0), Times.Once);
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, 0), Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
[ExpectedException(typeof(SocketException))]
public async Task ShouldThrowSocketExceptionOnConnectAsyncForNoReconnect()
{
// Arrange
_clientIsAlwaysConnected = false;
_clientIsConnectedQueue.Enqueue(false);
var connection = GetTcpConnection();
connection.ReconnectTimeout = TimeSpan.Zero;
// Act
await connection.ConnectAsync();
// Assert - SocketException
}
[TestMethod]
public async Task ShouldDisconnectAsync()
{
// Arrange
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
// Act
await connection.DisconnectAsync();
// Assert
Assert.IsFalse(connection.IsConnected);
_tcpClientMock.Verify(c => c.Close(), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldOnlyDisconnectAsyncOnce()
{
// Arrange
var connection = GetConnection();
await connection.ConnectAsync();
await connection.DisconnectAsync();
ClearInvocations();
// Act
await connection.DisconnectAsync();
// Assert
Assert.IsFalse(connection.IsConnected);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldCallDisconnectOnDispose()
{
// Arrange
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
// Act
connection.Dispose();
// Assert
_tcpClientMock.Verify(c => c.Close(), Times.Once);
_tcpClientMock.Verify(c => c.Dispose(), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public void ShouldAllowMultipleDispose()
{
// Arrange
var connection = GetConnection();
// Act
connection.Dispose();
connection.Dispose();
// Assert
_tcpClientMock.Verify(c => c.Close(), Times.Once);
_tcpClientMock.Verify(c => c.Dispose(), Times.Once);
_tcpClientMock.VerifyNoOtherCalls();
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
[ExpectedException(typeof(ApplicationException))]
public async Task ShouldThrowApplicationExceptionOnInvokeAsyncWhileNotConnected()
{
// Arrange
var connection = GetConnection();
// Act
await connection.InvokeAsync(null, null);
// Assert - ApplicationException
}
[DataTestMethod]
[DataRow(null)]
[DataRow(new byte[0])]
[ExpectedException(typeof(ArgumentNullException))]
public async Task ShouldThrowArgumentNullExceptionOnInvokeAsyncForRequest(byte[] request)
{
// Arrange
var connection = GetConnection();
await connection.ConnectAsync();
// Act
await connection.InvokeAsync(request, null);
// Assert - ArgumentNullException
}
[TestMethod]
[ExpectedException(typeof(ArgumentNullException))]
public async Task ShouldThrowArgumentNullExceptionOnInvokeAsyncForMissingValidation()
{
// Arrange
byte[] request = new byte[1];
var connection = GetConnection();
await connection.ConnectAsync();
// Act
await connection.InvokeAsync(request, null);
// Assert - ArgumentNullException
}
[TestMethod]
public async Task ShouldInvokeAsync()
{
// Arrange
_networkResponseQueue.Enqueue([9, 8, 7]);
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
// Act
var response = await connection.InvokeAsync(request, validation);
// Assert
Assert.AreEqual(1, _networkRequestCallbacks.Count);
CollectionAssert.AreEqual(new byte[] { 9, 8, 7 }, response.ToArray());
CollectionAssert.AreEqual(request, _networkRequestCallbacks[0]);
_tcpClientMock.Verify(c => c.Connected, Times.Once);
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
[ExpectedException(typeof(EndOfStreamException))]
public async Task ShouldThrowEndOfStreamOnInvokeAsync()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
// Act
_ = await connection.InvokeAsync(request, validation);
// Assert - EndOfStreamException
}
[TestMethod]
[ExpectedException(typeof(TaskCanceledException))]
public async Task ShouldCancelOnInvokeAsyncOnDisconnect()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask(Task.Delay(100)));
await connection.ConnectAsync();
ClearInvocations();
// Act
var task = connection.InvokeAsync(request, validation);
await connection.DisconnectAsync();
await task;
// Assert - TaskCanceledException
}
[TestMethod]
[ExpectedException(typeof(TaskCanceledException))]
public async Task ShouldCancelOnInvokeAsyncOnAbort()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var cts = new CancellationTokenSource();
var connection = GetConnection();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Returns(new ValueTask(Task.Delay(100)));
await connection.ConnectAsync();
ClearInvocations();
// Act
var task = connection.InvokeAsync(request, validation, cts.Token);
cts.Cancel();
await task;
// Assert - TaskCanceledException
}
[DataTestMethod]
[DataRow(typeof(IOException))]
[DataRow(typeof(SocketException))]
[DataRow(typeof(TimeoutException))]
[DataRow(typeof(InvalidOperationException))]
public async Task ShouldReconnectOnInvokeAsyncForExceptionType(Type exceptionType)
{
// Arrange
_networkResponseQueue.Enqueue([9, 8, 7]);
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.ThrowsAsync((Exception)Activator.CreateInstance(exceptionType));
// Act
try
{
await connection.InvokeAsync(request, validation);
}
catch (Exception ex)
{
// Assert - part 1
Assert.IsInstanceOfType(ex, exceptionType);
}
// Assert - part 2
Assert.AreEqual(1, _networkRequestCallbacks.Count);
CollectionAssert.AreEqual(request, _networkRequestCallbacks[0]);
_tcpClientMock.Verify(c => c.Close(), Times.Once);
_tcpClientMock.Verify(c => c.ConnectAsync(IPAddress.Loopback, 502, It.IsAny<CancellationToken>()), Times.Once);
_tcpClientMock.VerifyGet(c => c.ReceiveTimeout, Times.Once);
_tcpClientMock.VerifyGet(c => c.Connected, Times.Exactly(2));
_tcpClientMock.VerifyGet(c => c.Client, Times.Exactly(3));
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, false), Times.Once);
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, 0), Times.Once);
_socketMock.Verify(s => s.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, 0), Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldReturnWithUnknownExceptionOnInvokeAsync()
{
// Arrange
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.ThrowsAsync(new NotImplementedException());
// Act
try
{
await connection.InvokeAsync(request, validation);
}
catch (Exception ex)
{
// Assert - part 1
Assert.IsInstanceOfType(ex, typeof(NotImplementedException));
}
// Assert - part 2
Assert.AreEqual(1, _networkRequestCallbacks.Count);
CollectionAssert.AreEqual(request, _networkRequestCallbacks[0]);
_tcpClientMock.Verify(c => c.Connected, Times.Once);
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldRemoveRequestFromQueueOnInvokeAsync()
{
// Arrange
_networkResponseQueue.Enqueue([9, 8, 7]);
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.Returns(new ValueTask(Task.Delay(100)));
var cts = new CancellationTokenSource();
// Act
var taskToComplete = connection.InvokeAsync(request, validation);
var taskToCancel = connection.InvokeAsync(request, validation, cts.Token);
cts.Cancel();
var response = await taskToComplete;
// Assert
try
{
await taskToCancel;
Assert.Fail();
}
catch (TaskCanceledException)
{ /* expected exception */ }
Assert.AreEqual(1, _networkRequestCallbacks.Count);
CollectionAssert.AreEqual(new byte[] { 9, 8, 7 }, response.ToArray());
CollectionAssert.AreEqual(request, _networkRequestCallbacks[0]);
_tcpClientMock.Verify(c => c.Connected, Times.Exactly(2));
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
[TestMethod]
public async Task ShouldCancelQueuedRequestOnDisconnect()
{
// Arrange
_networkResponseQueue.Enqueue([9, 8, 7]);
byte[] request = [1, 2, 3];
var validation = new Func<IReadOnlyList<byte>, bool>(_ => true);
var connection = GetConnection();
await connection.ConnectAsync();
ClearInvocations();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.Returns(new ValueTask(Task.Delay(100)));
var cts = new CancellationTokenSource();
// Act
var taskToCancel = connection.InvokeAsync(request, validation);
var taskToDequeue = connection.InvokeAsync(request, validation);
await connection.DisconnectAsync();
// Assert
try
{
await taskToCancel;
Assert.Fail();
}
catch (TaskCanceledException ex)
{
/* expected exception */
Assert.AreNotEqual(CancellationToken.None, ex.CancellationToken);
}
try
{
await taskToDequeue;
Assert.Fail();
}
catch (TaskCanceledException ex)
{
/* expected exception */
Assert.AreEqual(CancellationToken.None, ex.CancellationToken);
}
Assert.AreEqual(1, _networkRequestCallbacks.Count);
CollectionAssert.AreEqual(request, _networkRequestCallbacks[0]);
_tcpClientMock.Verify(c => c.Connected, Times.Exactly(2));
_tcpClientMock.Verify(c => c.GetStream(), Times.Once);
_tcpClientMock.Verify(c => c.Close(), Times.Once);
_networkStreamMock.Verify(ns => ns.FlushAsync(It.IsAny<CancellationToken>()), Times.Once);
_networkStreamMock.Verify(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()), Times.Once);
_socketMock.VerifyNoOtherCalls();
_tcpClientMock.VerifyNoOtherCalls();
_networkStreamMock.VerifyNoOtherCalls();
}
private IModbusConnection GetConnection()
=> GetTcpConnection();
private ModbusTcpConnection GetTcpConnection()
{
_networkStreamMock = new Mock<NetworkStreamWrapper>();
_networkStreamMock
.Setup(ns => ns.WriteAsync(It.IsAny<ReadOnlyMemory<byte>>(), It.IsAny<CancellationToken>()))
.Callback<ReadOnlyMemory<byte>, CancellationToken>((req, _) => _networkRequestCallbacks.Add(req.ToArray()))
.Returns(ValueTask.CompletedTask);
_networkStreamMock
.Setup(ns => ns.ReadAsync(It.IsAny<Memory<byte>>(), It.IsAny<CancellationToken>()))
.Returns<Memory<byte>, CancellationToken>((buffer, _) =>
{
if (_networkResponseQueue.TryDequeue(out byte[] bytes))
{
bytes.CopyTo(buffer);
return ValueTask.FromResult(bytes.Length);
}
return ValueTask.FromResult(0);
});
_socketMock = new Mock<SocketWrapper>();
_tcpClientMock = new Mock<TcpClientWrapper>();
_tcpClientMock.Setup(c => c.Client).Returns(() => _socketMock.Object);
_tcpClientMock.Setup(c => c.Connected).Returns(() => _clientIsAlwaysConnected || _clientIsConnectedQueue.Dequeue());
_tcpClientMock.Setup(c => c.ReceiveTimeout).Returns(() => _clientReceiveTimeout);
_tcpClientMock.Setup(c => c.SendTimeout).Returns(() => _clientSendTimeout);
_tcpClientMock
.Setup(c => c.ConnectAsync(It.IsAny<IPAddress>(), It.IsAny<int>(), It.IsAny<CancellationToken>()))
.Returns(() => _clientConnectTask);
_tcpClientMock
.Setup(c => c.GetStream())
.Returns(() => _networkStreamMock.Object);
var connection = new ModbusTcpConnection
{
Hostname = _hostname,
Port = 502
};
// Replace real TCP client with mock
var clientField = connection.GetType().GetField("_client", BindingFlags.NonPublic | BindingFlags.Instance);
(clientField.GetValue(connection) as TcpClientWrapper)?.Dispose();
clientField.SetValue(connection, _tcpClientMock.Object);
return connection;
}
private void ClearInvocations()
{
_networkStreamMock.Invocations.Clear();
_socketMock.Invocations.Clear();
_tcpClientMock.Invocations.Clear();
}
}
}

View File

@@ -35,7 +35,7 @@ It uses a specific TCP connection implementation and plugs all things from the C
--- ---
Published under [MIT License] (see [**tl;dr**Legal]) Published under [MIT License] (see [**tl;dr**Legal])
[![built with Codeium](https://codeium.com/badges/main)](https://codeium.com) [![built with Codeium](https://codeium.com/badges/main)](https://codeium.com/profile/andreasmueller)