219 lines
6.6 KiB
C#
219 lines
6.6 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Net;
|
|
using System.Security.Claims;
|
|
using System.Text;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using AMWD.Common.AspNetCore.Security.BasicAuthentication;
|
|
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.Extensions.Primitives;
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting;
|
|
using Moq;
|
|
|
|
namespace UnitTests.AspNetCore.Security.BasicAuthentication
{
    /// <summary>
    /// Unit tests for <see cref="BasicAuthenticationMiddleware"/>.
    /// </summary>
    /// <remarks>
    /// The <see cref="HttpContext"/> and the <see cref="IBasicAuthenticationValidator"/> are mocked.
    /// Every interaction the middleware has with them is recorded in the fields below, so each test
    /// can assert on the status code, the response headers, the validator input and whether the
    /// request was passed on to the next delegate of the pipeline.
    /// </remarks>
    [TestClass]
    [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
    public class BasicAuthenticationMiddlewareTests
    {
        // Request headers presented to the middleware (filled per test before GetContext() is called).
        private Dictionary<string, string> requestHeaders;

        // Response headers written by the middleware, captured through the header indexer setter.
        private Dictionary<string, string> responseHeadersCallback;

        // Last status code assigned to the response; 0 means the setter was never called.
        private int responseStatusCodeCallback;

        // Realm reported by the mocked validator.
        private string validatorRealm;

        // Principal returned by the mocked validator; null simulates rejected credentials.
        private ClaimsPrincipal validatorResponse;

        // Every (username, password, remote address) tuple the middleware passed to the validator.
        private List<(string username, string password, IPAddress ipAddr)> validatorCallback;

        // Number of times the middleware invoked the next delegate of the pipeline.
        private int nextCallCount;

        [TestInitialize]
        public void InitializeTests()
        {
            requestHeaders = new Dictionary<string, string>();

            responseHeadersCallback = new Dictionary<string, string>();
            responseStatusCodeCallback = 0;

            validatorRealm = null;
            validatorResponse = null;
            validatorCallback = new List<(string username, string password, IPAddress ipAddr)>();

            nextCallCount = 0;
        }

        /// <summary>
        /// Credentials whose password itself contains a ':' are forwarded to the validator intact;
        /// once the validator accepts them, the request continues down the pipeline with no
        /// status code or header written by the middleware.
        /// </summary>
        [TestMethod]
        public async Task ShouldAllowAccess()
        {
            // arrange
            string username = "user";
            string password = "pass:word";

            requestHeaders.Add("Authorization", GetBasicAuthorizationHeader($"{username}:{password}"));
            validatorResponse = new ClaimsPrincipal();

            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual(0, responseStatusCodeCallback); // not triggered
            Assert.AreEqual(0, responseHeadersCallback.Count);
            Assert.AreEqual(1, nextCallCount); // request must reach the rest of the pipeline

            Assert.AreEqual(1, validatorCallback.Count);
            var (actualUsername, actualPassword, actualIpAddr) = validatorCallback.Single();
            Assert.AreEqual(username, actualUsername);
            Assert.AreEqual(password, actualPassword);
            Assert.AreEqual(IPAddress.Loopback, actualIpAddr);
        }

        /// <summary>
        /// Without an Authorization header the middleware answers 401 with a plain
        /// "Basic" challenge and neither consults the validator nor continues the pipeline.
        /// </summary>
        [TestMethod]
        public async Task ShouldDenyMissingHeader()
        {
            // arrange
            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual((int)HttpStatusCode.Unauthorized, responseStatusCodeCallback);
            Assert.AreEqual(0, nextCallCount); // a rejected request must not leak through

            Assert.AreEqual(0, validatorCallback.Count);

            Assert.AreEqual(1, responseHeadersCallback.Count);
            var challenge = responseHeadersCallback.Single();
            Assert.AreEqual("WWW-Authenticate", challenge.Key);
            Assert.AreEqual("Basic", challenge.Value);
        }

        /// <summary>
        /// When the validator rejects the credentials (returns null), the middleware answers 401
        /// with a challenge that carries the validator's realm and does not continue the pipeline.
        /// </summary>
        [TestMethod]
        public async Task ShouldDenyNoResult()
        {
            // arrange
            string username = "user";
            string password = "pw";

            validatorRealm = "TEST";
            var remote = IPAddress.Parse("1.2.3.4");

            requestHeaders.Add("Authorization", GetBasicAuthorizationHeader($"{username}:{password}"));

            var middleware = GetMiddleware();
            var context = GetContext(remote);

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual((int)HttpStatusCode.Unauthorized, responseStatusCodeCallback);
            Assert.AreEqual(0, nextCallCount); // a rejected request must not leak through

            Assert.AreEqual(1, responseHeadersCallback.Count);
            var challenge = responseHeadersCallback.Single();
            Assert.AreEqual("WWW-Authenticate", challenge.Key);
            Assert.AreEqual($"Basic realm=\"{validatorRealm}\"", challenge.Value);

            Assert.AreEqual(1, validatorCallback.Count);
            var (actualUsername, actualPassword, actualIpAddr) = validatorCallback.Single();
            Assert.AreEqual(username, actualUsername);
            Assert.AreEqual(password, actualPassword);
            Assert.AreEqual(remote, actualIpAddr);
        }

        /// <summary>
        /// A Basic header whose decoded payload has no ':' separator makes the middleware answer 500.
        /// Because the mocked validator returns null here (which would yield 401), a 500 also implies
        /// the validator was never reached; neither it nor the rest of the pipeline may be invoked.
        /// </summary>
        [TestMethod]
        public async Task ShouldBreakOnException()
        {
            // arrange
            string username = "user";

            // Deliberately no ":password" part.
            requestHeaders.Add("Authorization", GetBasicAuthorizationHeader(username));

            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual((int)HttpStatusCode.InternalServerError, responseStatusCodeCallback);
            Assert.AreEqual(0, validatorCallback.Count);
            Assert.AreEqual(0, nextCallCount);
        }

        // Builds an "Authorization" header value for the Basic scheme from the plain credential text
        // (normally "username:password"), UTF-8 encoded and Base64 wrapped.
        private static string GetBasicAuthorizationHeader(string plainCredentials)
            => $"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(plainCredentials))}";

        // Creates the middleware under test, wired to a counting "next" delegate and a mocked
        // validator that reports validatorRealm, records its input and returns validatorResponse.
        private BasicAuthenticationMiddleware GetMiddleware()
        {
            // A real delegate instead of a mock: it only counts invocations so every test can
            // assert whether the middleware let the request through.
            RequestDelegate next = _ =>
            {
                nextCallCount++;
                return Task.CompletedTask;
            };

            var validatorMock = new Mock<IBasicAuthenticationValidator>();
            validatorMock
                .Setup(v => v.Realm)
                .Returns(validatorRealm);
            validatorMock
                .Setup(v => v.ValidateAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<IPAddress>(), It.IsAny<CancellationToken>()))
                .Callback<string, string, IPAddress, CancellationToken>((username, password, ipAddress, _) => validatorCallback.Add((username, password, ipAddress)))
                .ReturnsAsync(validatorResponse);

            return new BasicAuthenticationMiddleware(next, validatorMock.Object);
        }

        // Creates a mocked HttpContext:
        // - request headers come from requestHeaders;
        // - response header writes and status-code writes are captured into the callback fields;
        // - the connection reports loopback locally and `remote` (loopback when null) as the peer.
        private HttpContext GetContext(IPAddress remote = null)
        {
            // Request
            var requestHeaderMock = new Mock<IHeaderDictionary>();
            foreach (var header in requestHeaders)
            {
                requestHeaderMock
                    .Setup(h => h.ContainsKey(header.Key))
                    .Returns(true);
                requestHeaderMock
                    .Setup(h => h[header.Key])
                    .Returns(header.Value);
            }

            var requestMock = new Mock<HttpRequest>();
            requestMock
                .Setup(r => r.Headers)
                .Returns(requestHeaderMock.Object);

            // Response
            var responseHeaderMock = new Mock<IHeaderDictionary>();
            responseHeaderMock
                .SetupSet(h => h[It.IsAny<string>()] = It.IsAny<StringValues>())
                .Callback<string, StringValues>((key, value) => responseHeadersCallback[key] = value);

            var responseMock = new Mock<HttpResponse>();
            responseMock
                .Setup(r => r.Headers)
                .Returns(responseHeaderMock.Object);
            responseMock
                .SetupSet(r => r.StatusCode = It.IsAny<int>())
                .Callback<int>((code) => responseStatusCodeCallback = code);

            // Connection
            var connectionInfoMock = new Mock<ConnectionInfo>();
            connectionInfoMock
                .Setup(ci => ci.LocalIpAddress)
                .Returns(IPAddress.Loopback);
            connectionInfoMock
                .Setup(ci => ci.RemoteIpAddress)
                .Returns(remote ?? IPAddress.Loopback);

            // Request Services
            var requestServicesMock = new Mock<IServiceProvider>();

            var contextMock = new Mock<HttpContext>();
            contextMock
                .Setup(c => c.Request)
                .Returns(requestMock.Object);
            contextMock
                .Setup(c => c.Response)
                .Returns(responseMock.Object);
            contextMock
                .Setup(c => c.Connection)
                .Returns(connectionInfoMock.Object);
            contextMock
                .Setup(c => c.RequestServices)
                .Returns(requestServicesMock.Object);
            contextMock
                .Setup(c => c.RequestAborted)
                .Returns(CancellationToken.None);

            return contextMock.Object;
        }
    }
}
|