// Original listing metadata: C#, 229 lines, 6.9 KiB
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Net;
|
|
using System.Security.Claims;
|
|
using System.Text;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using AMWD.Common.AspNetCore.Security.BasicAuthentication;
|
|
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.Extensions.Primitives;
|
|
using Microsoft.Net.Http.Headers;
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting;
|
|
using Moq;
|
|
|
|
namespace UnitTests.AspNetCore.Security.BasicAuthentication
{
    /// <summary>
    /// Unit tests for <see cref="BasicAuthenticationMiddleware"/>.
    /// The HTTP context, the validator and the next delegate are all mocked;
    /// everything the middleware writes or calls is captured in the fields below
    /// so the individual tests can assert on it.
    /// </summary>
    [TestClass]
    [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
    public class BasicAuthenticationMiddlewareTests
    {
        // Headers the mocked request exposes to the middleware.
        private Dictionary<string, string> _requestHeaders;

        // Captures of what was written to the mocked response.
        private Dictionary<string, string> _responseHeadersCallback;
        private int _responseStatusCodeCallback;

        // Configuration of the mocked validator and a log of the calls it received.
        private string _validatorRealm;
        private ClaimsPrincipal _validatorResponse;
        private List<(string username, string password, IPAddress ipAddr)> _validatorCallback;

        // The delegate following the middleware in the pipeline. Kept as a field so the tests
        // can verify whether the request was passed on or short-circuited.
        private Mock<RequestDelegate> _nextMock;

        /// <summary>
        /// Resets all captured state and mock configuration before each test.
        /// </summary>
        [TestInitialize]
        public void InitializeTests()
        {
            _requestHeaders = [];

            _responseHeadersCallback = [];
            _responseStatusCodeCallback = 0;

            _validatorRealm = null;
            _validatorResponse = null;
            _validatorCallback = [];

            _nextMock = new Mock<RequestDelegate>();
        }

        /// <summary>
        /// Valid credentials and a non-null principal from the validator:
        /// no status code or header is written, the validator receives the decoded
        /// credentials and the request is handed to the next delegate.
        /// </summary>
        [TestMethod]
        public async Task ShouldAllowAccess()
        {
            // arrange
            string username = "user";
            string password = "pass:word"; // the password itself contains a ':' and must reach the validator unchanged

            SetBasicAuthorization($"{username}:{password}");
            _validatorResponse = new ClaimsPrincipal();

            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual(0, _responseStatusCodeCallback); // not triggered
            Assert.AreEqual(0, _responseHeadersCallback.Count);

            AssertValidatorCalledOnce(username, password, IPAddress.Loopback);

            // Access is only really granted when the pipeline continues.
            _nextMock.Verify(next => next(It.IsAny<HttpContext>()), Times.Once());
        }

        /// <summary>
        /// No Authorization header: a 401 with a plain "Basic" challenge is expected,
        /// the validator is not consulted and the pipeline does not continue.
        /// </summary>
        [TestMethod]
        public async Task ShouldDenyMissingHeader()
        {
            // arrange
            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            AssertChallenge("Basic");

            Assert.AreEqual(0, _validatorCallback.Count);

            _nextMock.Verify(next => next(It.IsAny<HttpContext>()), Times.Never());
        }

        /// <summary>
        /// The validator returns no principal: a 401 with a challenge carrying the configured
        /// realm is expected, the validator saw the credentials and the remote address,
        /// and the pipeline does not continue.
        /// </summary>
        [TestMethod]
        public async Task ShouldDenyNoResult()
        {
            // arrange
            string username = "user";
            string password = "pw";

            _validatorRealm = "TEST";
            var remote = IPAddress.Parse("1.2.3.4");

            SetBasicAuthorization($"{username}:{password}");

            var middleware = GetMiddleware();
            var context = GetContext(remote);

            // act
            await middleware.InvokeAsync(context);

            // assert
            AssertChallenge($"Basic realm=\"{_validatorRealm}\"");

            AssertValidatorCalledOnce(username, password, remote);

            _nextMock.Verify(next => next(It.IsAny<HttpContext>()), Times.Never());
        }

        /// <summary>
        /// Credentials without a ':' separator: the test expects the middleware to answer with
        /// a 500 status code and not to pass the request on.
        /// </summary>
        [TestMethod]
        public async Task ShouldBreakOnException()
        {
            // arrange
            string username = "user";

            // Only the user name is encoded - there is no ":password" part.
            SetBasicAuthorization($"{username}");

            var middleware = GetMiddleware();
            var context = GetContext();

            // act
            await middleware.InvokeAsync(context);

            // assert
            Assert.AreEqual(500, _responseStatusCodeCallback);

            _nextMock.Verify(next => next(It.IsAny<HttpContext>()), Times.Never());
        }

        /// <summary>
        /// Adds an Authorization request header using the Basic scheme.
        /// </summary>
        /// <param name="credentials">The plain text that is UTF-8 encoded and Base64 encoded into the header value.</param>
        private void SetBasicAuthorization(string credentials)
        {
            string encoded = Convert.ToBase64String(Encoding.UTF8.GetBytes(credentials));
            _requestHeaders.Add("Authorization", $"Basic {encoded}");
        }

        /// <summary>
        /// Asserts that a 401 status code and exactly one response header,
        /// WWW-Authenticate with the expected value, were written.
        /// </summary>
        /// <param name="expectedHeaderValue">The expected value of the WWW-Authenticate header.</param>
        private void AssertChallenge(string expectedHeaderValue)
        {
            Assert.AreEqual(401, _responseStatusCodeCallback);

            Assert.AreEqual(1, _responseHeadersCallback.Count);
            Assert.AreEqual("WWW-Authenticate", _responseHeadersCallback.Keys.First());
            Assert.AreEqual(expectedHeaderValue, _responseHeadersCallback.Values.First());
        }

        /// <summary>
        /// Asserts that the validator was called exactly once with the given arguments.
        /// </summary>
        /// <param name="username">The expected user name.</param>
        /// <param name="password">The expected password.</param>
        /// <param name="ipAddress">The expected IP address.</param>
        private void AssertValidatorCalledOnce(string username, string password, IPAddress ipAddress)
        {
            Assert.AreEqual(1, _validatorCallback.Count);

            var call = _validatorCallback[0];
            Assert.AreEqual(username, call.username);
            Assert.AreEqual(password, call.password);
            Assert.AreEqual(ipAddress, call.ipAddr);
        }

        /// <summary>
        /// Creates the middleware under test, wired to the next-delegate mock and to a validator mock
        /// that reports <see cref="_validatorRealm"/>, records every call in <see cref="_validatorCallback"/>
        /// and returns <see cref="_validatorResponse"/>.
        /// </summary>
        /// <remarks>
        /// The realm and the response are read when this method runs, so they have to be set beforehand.
        /// </remarks>
        /// <returns>The middleware instance to test.</returns>
        private BasicAuthenticationMiddleware GetMiddleware()
        {
            var validatorMock = new Mock<IBasicAuthenticationValidator>();
            validatorMock
                .Setup(v => v.Realm)
                .Returns(_validatorRealm);
            validatorMock
                .Setup(v => v.ValidateAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<IPAddress>(), It.IsAny<CancellationToken>()))
                .Callback<string, string, IPAddress, CancellationToken>((username, password, ipAddress, _) => _validatorCallback.Add((username, password, ipAddress)))
                .ReturnsAsync(_validatorResponse);

            return new BasicAuthenticationMiddleware(_nextMock.Object, validatorMock.Object);
        }

        /// <summary>
        /// Builds a mocked <see cref="HttpContext"/> whose request exposes <see cref="_requestHeaders"/>
        /// and whose response forwards header and status code writes into the capture fields.
        /// </summary>
        /// <param name="remote">The remote IP address of the connection; <see cref="IPAddress.Loopback"/> when <c>null</c>.</param>
        /// <returns>The mocked context.</returns>
        private HttpContext GetContext(IPAddress remote = null)
        {
            // Request
            var requestHeaderMock = new Mock<IHeaderDictionary>();
            foreach (var header in _requestHeaders)
            {
                // Every way of reading a header (ContainsKey, indexer, TryGetValue) yields the same value.
                var strVal = new StringValues(header.Value);
                requestHeaderMock
                    .Setup(h => h.ContainsKey(header.Key))
                    .Returns(true);
                requestHeaderMock
                    .Setup(h => h[header.Key])
                    .Returns(strVal);
                requestHeaderMock
                    .Setup(h => h.TryGetValue(header.Key, out strVal))
                    .Returns(true);
            }

            var requestMock = new Mock<HttpRequest>();
            requestMock
                .Setup(r => r.Headers)
                .Returns(requestHeaderMock.Object);

            // Response
            var responseHeaderMock = new Mock<IHeaderDictionary>();
            responseHeaderMock
                .SetupSet(h => h[It.IsAny<string>()] = It.IsAny<StringValues>())
                .Callback<string, StringValues>((key, value) => _responseHeadersCallback[key] = value);

            // The header may also be written through the typed WWWAuthenticate property; capture it
            // under the same key. CS0618 (use of an obsolete member) is suppressed for this setup only.
#pragma warning disable CS0618
            responseHeaderMock
                .SetupSet(h => h.WWWAuthenticate)
                .Callback((value) => _responseHeadersCallback[HeaderNames.WWWAuthenticate] = value);
#pragma warning restore CS0618

            var responseMock = new Mock<HttpResponse>();
            responseMock
                .Setup(r => r.Headers)
                .Returns(responseHeaderMock.Object);
            responseMock
                .SetupSet(r => r.StatusCode = It.IsAny<int>())
                .Callback<int>((code) => _responseStatusCodeCallback = code);

            // Connection
            var connectionInfoMock = new Mock<ConnectionInfo>();
            connectionInfoMock
                .Setup(ci => ci.LocalIpAddress)
                .Returns(IPAddress.Loopback);
            connectionInfoMock
                .Setup(ci => ci.RemoteIpAddress)
                .Returns(remote ?? IPAddress.Loopback);

            // Request Services (no services are registered on this provider mock)
            var requestServicesMock = new Mock<IServiceProvider>();

            var contextMock = new Mock<HttpContext>();
            contextMock
                .Setup(c => c.Request)
                .Returns(requestMock.Object);
            contextMock
                .Setup(c => c.Response)
                .Returns(responseMock.Object);
            contextMock
                .Setup(c => c.Connection)
                .Returns(connectionInfoMock.Object);
            contextMock
                .Setup(c => c.RequestServices)
                .Returns(requestServicesMock.Object);
            contextMock
                .Setup(c => c.RequestAborted)
                .Returns(CancellationToken.None);

            return contextMock.Object;
        }
    }
}
|