using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using AMWD.Protocols.Modbus.Common.Contracts;
using AMWD.Protocols.Modbus.Tcp.Utils;
namespace AMWD.Protocols.Modbus.Tcp
{
///
/// The default Modbus TCP connection.
///
public class ModbusTcpConnection : IModbusConnection
{
#region Fields
private string _hostname;
private int _port;
private bool _isDisposed;
private bool _isConnected;
private readonly TcpClientWrapper _client = new();
private CancellationTokenSource _disconnectCts;
private Task _reconnectTask = Task.CompletedTask;
private readonly SemaphoreSlim _reconnectLock = new(1, 1);
private CancellationTokenSource _processingCts;
private Task _processingTask = Task.CompletedTask;
private readonly AsyncQueue _requestQueue = new();
#endregion Fields
#region Properties
///
public string Name => "TCP";
///
public bool IsConnected => _isConnected && _client.Connected;
///
/// The DNS name of the remote host to which the connection is intended to.
///
public virtual string Hostname
{
get => _hostname;
set
{
if (string.IsNullOrWhiteSpace(value))
throw new ArgumentNullException(nameof(value));
_hostname = value;
}
}
///
/// The port number of the remote host to which the connection is intended to.
///
public virtual int Port
{
get => _port;
set
{
if (value < 1 || ushort.MaxValue < value)
throw new ArgumentOutOfRangeException(nameof(value));
_port = value;
}
}
///
/// Gets or sets the receive time out value of the connection.
///
public virtual TimeSpan ReadTimeout
{
get => TimeSpan.FromMilliseconds(_client.ReceiveTimeout);
set => _client.ReceiveTimeout = (int)value.TotalMilliseconds;
}
///
/// Gets or sets the send time out value of the connection.
///
public virtual TimeSpan WriteTimeout
{
get => TimeSpan.FromMilliseconds(_client.SendTimeout);
set => _client.SendTimeout = (int)value.TotalMilliseconds;
}
///
/// Gets or sets the maximum time until the reconnect is given up.
///
public virtual TimeSpan ReconnectTimeout { get; set; } = TimeSpan.MaxValue;
///
/// Gets or sets the interval in which a keep alive package should be sent.
///
public virtual TimeSpan KeepAliveInterval { get; set; } = TimeSpan.Zero;
#endregion Properties
///
public async Task ConnectAsync(CancellationToken cancellationToken = default)
{
#if NET8_0_OR_GREATER
ObjectDisposedException.ThrowIf(_isDisposed, this);
#else
if (_isDisposed)
throw new ObjectDisposedException(GetType().FullName);
#endif
if (_disconnectCts != null)
{
await _reconnectTask;
return;
}
_disconnectCts = new CancellationTokenSource();
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_disconnectCts.Token, cancellationToken);
_reconnectTask = ReconnectInternalAsync(linkedCts.Token);
await _reconnectTask.ConfigureAwait(false);
}
///
public Task DisconnectAsync(CancellationToken cancellationToken = default)
{
#if NET8_0_OR_GREATER
ObjectDisposedException.ThrowIf(_isDisposed, this);
#else
if (_isDisposed)
throw new ObjectDisposedException(GetType().FullName);
#endif
if (_disconnectCts == null)
return Task.CompletedTask;
return DisconnectInternalAsync(cancellationToken);
}
///
public void Dispose()
{
if (_isDisposed)
return;
_isDisposed = true;
DisconnectInternalAsync(CancellationToken.None).Wait();
_client.Dispose();
GC.SuppressFinalize(this);
}
///
public Task> InvokeAsync(IReadOnlyList request, Func, bool> validateResponseComplete, CancellationToken cancellationToken = default)
{
#if NET8_0_OR_GREATER
ObjectDisposedException.ThrowIf(_isDisposed, this);
#else
if (_isDisposed)
throw new ObjectDisposedException(GetType().FullName);
#endif
if (!IsConnected)
throw new ApplicationException($"Connection is not open");
if (request?.Count < 1)
throw new ArgumentNullException(nameof(request));
#if NET8_0_OR_GREATER
ArgumentNullException.ThrowIfNull(validateResponseComplete);
#else
if (validateResponseComplete == null)
throw new ArgumentNullException(nameof(validateResponseComplete));
#endif
var item = new RequestQueueItem
{
Request = [.. request],
ValidateResponseComplete = validateResponseComplete,
TaskCompletionSource = new(),
CancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken),
};
item.CancellationTokenRegistration = item.CancellationTokenSource.Token.Register(() =>
{
_requestQueue.Remove(item);
item.CancellationTokenSource.Dispose();
item.TaskCompletionSource.TrySetCanceled(cancellationToken);
item.CancellationTokenRegistration.Dispose();
});
_requestQueue.Enqueue(item);
return item.TaskCompletionSource.Task;
}
private async Task ReconnectInternalAsync(CancellationToken cancellationToken)
{
if (!_reconnectLock.Wait(0, cancellationToken))
return;
try
{
_isConnected = false;
_processingCts?.Cancel();
await _processingTask.ConfigureAwait(false);
int delay = 1;
int maxDelay = 60;
var ipAddresses = Resolve(Hostname);
if (ipAddresses.Count == 0)
throw new ApplicationException($"Could not resolve hostname '{Hostname}'");
var startTime = DateTime.UtcNow;
while (!cancellationToken.IsCancellationRequested)
{
try
{
foreach (var ipAddress in ipAddresses)
{
_client.Close();
#if NET6_0_OR_GREATER
using var connectTask = _client.ConnectAsync(ipAddress, Port, cancellationToken);
#else
using var connectTask = _client.ConnectAsync(ipAddress, Port);
#endif
if (await Task.WhenAny(connectTask, Task.Delay(ReadTimeout, cancellationToken)) == connectTask)
{
await connectTask;
if (_client.Connected)
{
_isConnected = true;
_processingCts?.Dispose();
_processingCts = new();
_processingTask = ProcessAsync(_processingCts.Token);
SetKeepAlive();
return;
}
}
}
throw new SocketException((int)SocketError.TimedOut);
}
catch (SocketException) when (ReconnectTimeout == TimeSpan.MaxValue || DateTime.UtcNow.Subtract(startTime) < ReconnectTimeout)
{
delay *= 2;
if (delay > maxDelay)
delay = maxDelay;
try
{
await Task.Delay(TimeSpan.FromSeconds(delay), cancellationToken).ConfigureAwait(false);
}
catch
{ /* keep it quiet */ }
}
}
}
finally
{
_reconnectLock.Release();
}
}
private async Task DisconnectInternalAsync(CancellationToken cancellationToken)
{
_disconnectCts?.Cancel();
_processingCts?.Cancel();
try
{
await _reconnectTask.ConfigureAwait(false);
await _processingTask.ConfigureAwait(false);
}
catch
{ /* keep it quiet */ }
// Ensure that the client is closed
await _reconnectLock.WaitAsync(cancellationToken).ConfigureAwait(false);
try
{
_isConnected = false;
_client.Close();
}
finally
{
_reconnectLock.Release();
}
_disconnectCts?.Dispose();
_disconnectCts = null;
_processingCts?.Dispose();
_processingCts = null;
while (_requestQueue.TryDequeue(out var item))
{
item.CancellationTokenRegistration.Dispose();
item.CancellationTokenSource.Dispose();
item.TaskCompletionSource.TrySetCanceled(CancellationToken.None);
}
}
#region Processing
private async Task ProcessAsync(CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
var item = await _requestQueue.DequeueAsync(cancellationToken).ConfigureAwait(false);
item.CancellationTokenRegistration.Dispose();
using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, item.CancellationTokenSource.Token);
try
{
var stream = _client.GetStream();
await stream.FlushAsync(linkedCts.Token).ConfigureAwait(false);
#if NET6_0_OR_GREATER
await stream.WriteAsync(item.Request, linkedCts.Token).ConfigureAwait(false);
#else
await stream.WriteAsync(item.Request, 0, item.Request.Length, linkedCts.Token).ConfigureAwait(false);
#endif
linkedCts.Token.ThrowIfCancellationRequested();
var bytes = new List();
byte[] buffer = new byte[260];
do
{
#if NET6_0_OR_GREATER
int readCount = await stream.ReadAsync(buffer, linkedCts.Token).ConfigureAwait(false);
#else
int readCount = await stream.ReadAsync(buffer, 0, buffer.Length, linkedCts.Token).ConfigureAwait(false);
#endif
if (readCount < 1)
throw new EndOfStreamException();
bytes.AddRange(buffer.Take(readCount));
linkedCts.Token.ThrowIfCancellationRequested();
}
while (!item.ValidateResponseComplete(bytes));
item.TaskCompletionSource.TrySetResult(bytes);
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
// DisconnectAsync() called
item.TaskCompletionSource.TrySetCanceled(cancellationToken);
return;
}
catch (OperationCanceledException) when (item.CancellationTokenSource.IsCancellationRequested)
{
item.TaskCompletionSource.TrySetCanceled(item.CancellationTokenSource.Token);
continue;
}
catch (IOException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (SocketException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (TimeoutException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (InvalidOperationException ex)
{
item.TaskCompletionSource.TrySetException(ex);
_reconnectTask = ReconnectInternalAsync(_disconnectCts.Token);
}
catch (Exception ex)
{
item.TaskCompletionSource.TrySetException(ex);
}
}
}
internal class RequestQueueItem
{
public byte[] Request { get; set; }
public Func, bool> ValidateResponseComplete { get; set; }
public TaskCompletionSource> TaskCompletionSource { get; set; }
public CancellationTokenSource CancellationTokenSource { get; set; }
public CancellationTokenRegistration CancellationTokenRegistration { get; set; }
}
#endregion Processing
#region Helpers
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
private static List Resolve(string hostname)
{
if (string.IsNullOrWhiteSpace(hostname))
return [];
if (IPAddress.TryParse(hostname, out var ipAddress))
return [ipAddress];
try
{
return Dns.GetHostAddresses(hostname)
.Where(a => a.AddressFamily == AddressFamily.InterNetwork || a.AddressFamily == AddressFamily.InterNetworkV6)
.OrderBy(a => a.AddressFamily) // Prefer IPv4
.ToList();
}
catch
{
return [];
}
}
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
private void SetKeepAlive()
{
#if NET6_0_OR_GREATER
_client.Client?.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, KeepAliveInterval.TotalMilliseconds > 0);
_client.Client?.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, (int)KeepAliveInterval.TotalSeconds);
_client.Client?.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, (int)KeepAliveInterval.TotalSeconds);
#else
// See: https://github.com/dotnet/runtime/issues/25555
if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return;
bool isEnabled = KeepAliveInterval.TotalMilliseconds > 0;
uint interval = KeepAliveInterval.TotalMilliseconds > uint.MaxValue
? uint.MaxValue
: (uint)KeepAliveInterval.TotalMilliseconds;
int uIntSize = sizeof(uint);
byte[] config = new byte[uIntSize * 3];
Array.Copy(BitConverter.GetBytes(isEnabled ? 1U : 0U), 0, config, uIntSize * 0, uIntSize);
Array.Copy(BitConverter.GetBytes(interval), 0, config, uIntSize * 1, uIntSize);
Array.Copy(BitConverter.GetBytes(interval), 0, config, uIntSize * 2, uIntSize);
_client.Client?.IOControl(IOControlCode.KeepAliveValues, config, null);
#endif
}
#endregion Helpers
}
}