using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Security.Claims;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using AMWD.Common.AspNetCore.BasicAuthentication;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Primitives;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;

namespace UnitTests.AspNetCore.BasicAuthentication
{
    /// <summary>
    /// Unit tests for <see cref="BasicAuthenticationMiddleware"/>.
    /// The HTTP context and the credential validator are replaced by mocks; everything the
    /// middleware writes to the response or passes to the validator is recorded in the
    /// "callback" fields below and asserted on afterwards.
    /// </summary>
    [TestClass]
    [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
    public class BasicAuthenticationMiddlewareTests
    {
        // Headers the mocked request exposes (header name => header value).
        private Dictionary<string, string> requestHeaders;

        // Headers written to the mocked response (header name => header value).
        private Dictionary<string, string> responseHeadersCallback;

        // Last status code assigned to the mocked response; 0 means "never assigned".
        private int responseStatusCodeCallback;

        // Value returned by the mocked validator's Realm property.
        private string validatorRealm;

        // Value returned by the mocked validator's ValidateAsync call.
        private ClaimsPrincipal validatorResponse;

        // One entry per ValidateAsync invocation on the mocked validator.
        private List<(string username, string password, IPAddress ipAddr)> validatorCallback;

        /// <summary>
        /// Resets all shared fixtures so that every test starts from a clean state.
        /// </summary>
        [TestInitialize]
        public void InitializeTests()
        {
            requestHeaders = new Dictionary<string, string>();
            responseHeadersCallback = new Dictionary<string, string>();
            responseStatusCodeCallback = 0;

            validatorRealm = null;
            validatorResponse = null;
            validatorCallback = new List<(string username, string password, IPAddress ipAddr)>();
        }

        /// <summary>
        /// With a Basic header and a validator that returns a principal, the middleware must
        /// neither set a status code nor write response headers. The password contains a ':'
        /// and has to reach the validator unchanged.
        /// </summary>
        [TestMethod]
        public async Task ShouldAllowAccess()
        {
            // arrange
            string username = "user";
            string password = "pass:word";

            requestHeaders.Add("Authorization", $"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:{password}"))}");
            validatorResponse = new ClaimsPrincipal();

            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual(0, responseStatusCodeCallback); // not triggered
            Assert.AreEqual(0, responseHeadersCallback.Count);

            Assert.AreEqual(1, validatorCallback.Count);
            Assert.AreEqual(username, validatorCallback.First().username);
            Assert.AreEqual(password, validatorCallback.First().password);
            Assert.AreEqual(IPAddress.Loopback, validatorCallback.First().ipAddr);
        }

        /// <summary>
        /// Without an Authorization header the middleware must answer 401 with a plain
        /// "Basic" challenge (no realm configured) and must not call the validator.
        /// </summary>
        [TestMethod]
        public async Task ShouldDenyMissingHeader()
        {
            // arrange
            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual(401, responseStatusCodeCallback);

            Assert.AreEqual(0, validatorCallback.Count);

            Assert.AreEqual(1, responseHeadersCallback.Count);
            Assert.AreEqual("WWW-Authenticate", responseHeadersCallback.Keys.First());
            Assert.AreEqual("Basic", responseHeadersCallback.Values.First());
        }

        /// <summary>
        /// When the validator returns no principal the middleware must answer 401 with a
        /// challenge that carries the configured realm; the validator must have received the
        /// credentials and the remote address of the connection.
        /// </summary>
        [TestMethod]
        public async Task ShouldDenyNoResult()
        {
            // arrange
            string username = "user";
            string password = "pw";

            validatorRealm = "TEST";
            var remote = IPAddress.Parse("1.2.3.4");

            requestHeaders.Add("Authorization", $"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}:{password}"))}");

            var middleware = GetMiddleware();
            var context = GetContext(remote);

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual(401, responseStatusCodeCallback);

            Assert.AreEqual(1, responseHeadersCallback.Count);
            Assert.AreEqual("WWW-Authenticate", responseHeadersCallback.Keys.First());
            Assert.AreEqual($"Basic realm=\"{validatorRealm}\"", responseHeadersCallback.Values.First());

            Assert.AreEqual(1, validatorCallback.Count);
            Assert.AreEqual(username, validatorCallback.First().username);
            Assert.AreEqual(password, validatorCallback.First().password);
            Assert.AreEqual(remote, validatorCallback.First().ipAddr);
        }

        /// <summary>
        /// The header payload holds only a user name without the ':' separator; the test
        /// expects the middleware to answer with status 500.
        /// </summary>
        [TestMethod]
        public async Task ShouldBreakOnException()
        {
            // arrange
            string username = "user";

            requestHeaders.Add("Authorization", $"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes($"{username}"))}");

            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual(500, responseStatusCodeCallback);
        }

        /// <summary>
        /// Creates the middleware under test with a mocked next delegate and a mocked validator.
        /// The validator reports <see cref="validatorRealm"/>, records each call in
        /// <see cref="validatorCallback"/> and returns <see cref="validatorResponse"/>.
        /// </summary>
        /// <returns>A middleware instance wired to the mocks.</returns>
        private BasicAuthenticationMiddleware GetMiddleware()
        {
            var nextMock = new Mock<RequestDelegate>();

            // NOTE(review): validator interface name assumed to be IBasicAuthenticationValidator
            // (from AMWD.Common.AspNetCore.BasicAuthentication) - confirm against the library.
            var validatorMock = new Mock<IBasicAuthenticationValidator>();
            validatorMock
                .Setup(v => v.Realm)
                .Returns(validatorRealm);
            validatorMock
                .Setup(v => v.ValidateAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<IPAddress>(), It.IsAny<CancellationToken>()))
                .Callback<string, string, IPAddress, CancellationToken>((username, password, ipAddress, _) => validatorCallback.Add((username, password, ipAddress)))
                .ReturnsAsync(validatorResponse);

            return new BasicAuthenticationMiddleware(nextMock.Object, validatorMock.Object);
        }

        /// <summary>
        /// Builds a mocked <see cref="HttpContext"/> whose request exposes
        /// <see cref="requestHeaders"/> and whose response forwards header and status code
        /// writes into the callback fields of this class.
        /// </summary>
        /// <param name="remote">Remote address of the connection; loopback when <c>null</c>.</param>
        /// <returns>The mocked context.</returns>
        private HttpContext GetContext(IPAddress remote = null)
        {
            // Request
            var requestHeaderMock = new Mock<IHeaderDictionary>();
            foreach (var header in requestHeaders)
            {
                requestHeaderMock
                    .Setup(h => h.ContainsKey(header.Key))
                    .Returns(true);
                requestHeaderMock
                    .Setup(h => h[header.Key])
                    .Returns(header.Value);
            }

            var requestMock = new Mock<HttpRequest>();
            requestMock
                .Setup(r => r.Headers)
                .Returns(requestHeaderMock.Object);

            // Response
            var responseHeaderMock = new Mock<IHeaderDictionary>();
            responseHeaderMock
                .SetupSet(h => h[It.IsAny<string>()] = It.IsAny<StringValues>())
                .Callback<string, StringValues>((key, value) => responseHeadersCallback[key] = value);

            var responseMock = new Mock<HttpResponse>();
            responseMock
                .Setup(r => r.Headers)
                .Returns(responseHeaderMock.Object);
            responseMock
                .SetupSet(r => r.StatusCode = It.IsAny<int>())
                .Callback<int>((code) => responseStatusCodeCallback = code);

            // Connection
            var connectionInfoMock = new Mock<ConnectionInfo>();
            connectionInfoMock
                .Setup(ci => ci.LocalIpAddress)
                .Returns(IPAddress.Loopback);
            connectionInfoMock
                .Setup(ci => ci.RemoteIpAddress)
                .Returns(remote ?? IPAddress.Loopback);

            // Request Services
            var requestServicesMock = new Mock<IServiceProvider>();

            var contextMock = new Mock<HttpContext>();
            contextMock
                .Setup(c => c.Request)
                .Returns(requestMock.Object);
            contextMock
                .Setup(c => c.Response)
                .Returns(responseMock.Object);
            contextMock
                .Setup(c => c.Connection)
                .Returns(connectionInfoMock.Object);
            contextMock
                .Setup(c => c.RequestServices)
                .Returns(requestServicesMock.Object);
            contextMock
                .Setup(c => c.RequestAborted)
                .Returns(CancellationToken.None);

            return contextMock.Object;
        }
    }
}