using System;
using System.Collections.Generic;
using System.Net;
using Microsoft.AspNetCore.Antiforgery;
using Microsoft.AspNetCore.Http;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;

namespace UnitTests.AspNetCore.Extensions
{
    /// <summary>
    /// Unit tests for the <see cref="HttpContext"/> extension methods
    /// <c>GetAntiforgeryToken</c>, <c>GetRemoteIpAddress</c>, <c>IsLocalRequest</c>,
    /// <c>GetReturnUrl</c> and <c>ClearSession</c>.
    /// </summary>
    /// <remarks>
    /// Every test builds a mocked <see cref="HttpContext"/> via <see cref="GetContext"/>.
    /// The mock is driven by the per-test fields below, which <see cref="InitializeTests"/>
    /// resets before each test.
    /// </remarks>
    [TestClass]
    [System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
    public class HttpContextExtensionsTests
    {
        // Session mock created by the most recent GetContext call; used to verify Clear().
        private Mock<ISession> sessionMock;

        // Token name/value handed out by the mocked IAntiforgery.
        // A null or blank tokenName makes the mock return no token set at all.
        private string tokenName;
        private string tokenValue;

        // Entries exposed through Request.Headers, Request.Query and HttpContext.Items.
        private Dictionary<string, string> requestHeaders;
        private Dictionary<string, string> requestQueries;
        private Dictionary<object, object> items;

        // Address reported by ConnectionInfo.RemoteIpAddress.
        private IPAddress remote;

        [TestInitialize]
        public void InitializeTests()
        {
            tokenName = null;
            tokenValue = null;

            requestHeaders = new Dictionary<string, string>();
            requestQueries = new Dictionary<string, string>();
            items = new Dictionary<object, object>();

            remote = IPAddress.Loopback;
        }

        #region Antiforgery

        [TestMethod]
        public void ShouldReturnAntiforgery()
        {
            // arrange
            tokenName = "af-token";
            tokenValue = "security_first";
            var context = GetContext();

            // act
            var result = context.GetAntiforgeryToken();

            // assert
            Assert.AreEqual(tokenName, result.Name);
            Assert.AreEqual(tokenValue, result.Value);
        }

        [TestMethod]
        public void ShouldReturnAntiforgeryNullService()
        {
            // arrange: tokens are configured, but no IAntiforgery service is resolvable
            tokenName = "af-token";
            tokenValue = "security_first";
            var context = GetContext(hasAntiforgery: false);

            // act
            var result = context.GetAntiforgeryToken();

            // assert
            Assert.IsNull(result.Name);
            Assert.IsNull(result.Value);
        }

        [TestMethod]
        public void ShouldReturnAntiforgeryNullToken()
        {
            // arrange: the service exists, but tokenName is null, so it returns no token set
            var context = GetContext();

            // act
            var result = context.GetAntiforgeryToken();

            // assert
            Assert.IsNull(result.Name);
            Assert.IsNull(result.Value);
        }

        #endregion Antiforgery

        #region RemoteAddress

        [TestMethod]
        public void ShouldReturnRemoteAddress()
        {
            // arrange
            remote = IPAddress.Parse("1.2.3.4");
            var context = GetContext();

            // act
            var result = context.GetRemoteIpAddress();

            // assert
            Assert.AreEqual(remote, result);
        }

        [TestMethod]
        public void ShouldReturnDefaultHeader()
        {
            // arrange
            remote = IPAddress.Parse("1.2.3.4");
            var header = IPAddress.Parse("5.6.7.8");
            requestHeaders.Add("X-Forwarded-For", header.ToString());
            var context = GetContext();

            // act
            var result = context.GetRemoteIpAddress();

            // assert
            Assert.AreNotEqual(remote, result);
            Assert.AreEqual(header, result);
        }

        [TestMethod]
        public void ShouldReturnCustomHeader()
        {
            // arrange
            remote = IPAddress.Parse("1.2.3.4");
            string headerName = "FooBar";
            var headerIp = IPAddress.Parse("5.6.7.8");
            requestHeaders.Add(headerName, headerIp.ToString());
            var context = GetContext();

            // act
            var result = context.GetRemoteIpAddress(headerName: headerName);

            // assert
            Assert.AreNotEqual(remote, result);
            Assert.AreEqual(headerIp, result);
        }

        [TestMethod]
        public void ShouldReturnAddressInvalidHeader()
        {
            // arrange: an unparsable forwarded header must fall back to the connection address
            remote = IPAddress.Parse("1.2.3.4");
            requestHeaders.Add("X-Forwarded-For", "1.2.3:4");
            var context = GetContext();

            // act
            var result = context.GetRemoteIpAddress();

            // assert
            Assert.AreEqual(remote, result);
        }

        #endregion RemoteAddress

        #region Local Request

        [TestMethod]
        public void ShouldReturnTrueOnLocal()
        {
            // arrange
            remote = IPAddress.Loopback;
            var context = GetContext();

            // act
            bool result = context.IsLocalRequest();

            // assert
            Assert.IsTrue(result);
        }

        [TestMethod]
        public void ShouldReturnFalseOnRemote()
        {
            // arrange
            remote = IPAddress.Parse("1.2.3.4");
            var context = GetContext();

            // act
            bool result = context.IsLocalRequest();

            // assert
            Assert.IsFalse(result);
        }

        [TestMethod]
        public void ShouldReturnTrueOnDefaultHeader()
        {
            // arrange
            remote = IPAddress.Parse("1.2.3.4");
            var headerIp = IPAddress.Loopback;
            requestHeaders.Add("X-Forwarded-For", headerIp.ToString());
            var context = GetContext();

            // act
            bool result = context.IsLocalRequest();

            // assert
            Assert.IsTrue(result);
        }

        [TestMethod]
        public void ShouldReturnTrueOnCustomHeader()
        {
            // arrange
            remote = IPAddress.Parse("1.2.3.4");
            string headerName = "FooBar";
            var headerIp = IPAddress.Loopback;
            requestHeaders.Add(headerName, headerIp.ToString());
            var context = GetContext();

            // act
            bool result = context.IsLocalRequest(headerName: headerName);

            // assert
            Assert.IsTrue(result);
        }

        [TestMethod]
        public void ShouldReturnFalseOnDefaultHeader()
        {
            // arrange: the connection itself is local, but the forwarded address is not
            remote = IPAddress.Loopback;
            var headerIp = IPAddress.Parse("1.2.3.4");
            requestHeaders.Add("X-Forwarded-For", headerIp.ToString());
            var context = GetContext();

            // act
            bool result = context.IsLocalRequest();

            // assert
            Assert.IsFalse(result);
        }

        [TestMethod]
        public void ShouldReturnFalseOnCustomHeader()
        {
            // arrange: the connection itself is local, but the custom header address is not
            remote = IPAddress.Loopback;
            string headerName = "FooBar";
            var headerIp = IPAddress.Parse("1.2.3.4");
            requestHeaders.Add(headerName, headerIp.ToString());
            var context = GetContext();

            // act
            bool result = context.IsLocalRequest(headerName: headerName);

            // assert
            Assert.IsFalse(result);
        }

        #endregion Local Request

        #region ReturnUrl

        [TestMethod]
        public void ShouldReturnNull()
        {
            // arrange
            var context = GetContext();

            // act
            string result = context.GetReturnUrl();

            // assert
            Assert.IsNull(result);
        }

        [TestMethod]
        public void ShouldReturnOriginalRequest()
        {
            // arrange: both sources are present; the Items entry is expected to take precedence
            string request = "abc";
            string query = "def";
            items.Add("OriginalRequest", request);
            requestQueries.Add("ReturnUrl", query);
            var context = GetContext();

            // act
            string result = context.GetReturnUrl();

            // assert
            Assert.AreEqual(request, result);
            Assert.AreNotEqual(query, result);
        }

        [TestMethod]
        public void ShouldReturnUrl()
        {
            // arrange
            string query = "def";
            requestQueries.Add("ReturnUrl", query);
            var context = GetContext();

            // act
            string result = context.GetReturnUrl();

            // assert
            Assert.AreEqual(query, result);
        }

        #endregion ReturnUrl

        #region Session

        [TestMethod]
        public void ShouldClearSession()
        {
            // arrange
            var context = GetContext();

            // act
            context.ClearSession();

            // assert
            sessionMock.Verify(s => s.Clear(), Times.Once);
        }

        [TestMethod]
        public void ShouldSkipWhenNoSession()
        {
            // arrange
            var context = GetContext(hasSession: false);

            // act
            context.ClearSession();

            // assert
            sessionMock.Verify(s => s.Clear(), Times.Never);
        }

        #endregion Session

        /// <summary>
        /// Builds a mocked <see cref="HttpContext"/> from the current test fields.
        /// </summary>
        /// <param name="hasAntiforgery">
        /// When <c>false</c>, no <see cref="IAntiforgery"/> is set up on the request services.
        /// </param>
        /// <param name="hasSession">
        /// When <c>false</c>, <see cref="HttpContext.Session"/> is left un-setup on the mock.
        /// </param>
        /// <returns>The mocked context.</returns>
        private HttpContext GetContext(bool hasAntiforgery = true, bool hasSession = true)
        {
            // Created unconditionally so tests can verify it was never touched when hasSession is false.
            sessionMock = new Mock<ISession>();

            var contextMock = new Mock<HttpContext>();
            contextMock
                .Setup(c => c.Request)
                .Returns(CreateRequest());
            contextMock
                .Setup(c => c.RequestServices)
                .Returns(CreateRequestServices(hasAntiforgery));
            contextMock
                .Setup(c => c.Connection)
                .Returns(CreateConnection());
            contextMock
                .Setup(c => c.Items)
                .Returns(items);

            if (hasSession)
            {
                contextMock
                    .Setup(c => c.Session)
                    .Returns(sessionMock.Object);
            }

            return contextMock.Object;
        }

        /// <summary>
        /// Builds an <see cref="HttpRequest"/> whose headers and query collection expose the
        /// entries of <see cref="requestHeaders"/> and <see cref="requestQueries"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): only <c>ContainsKey</c> and the indexer are set up, and Moq matches keys
        /// with case-sensitive string equality. Lookups through other members (e.g. <c>TryGetValue</c>)
        /// or with different casing will not see these entries.
        /// </remarks>
        private HttpRequest CreateRequest()
        {
            var headersMock = new Mock<IHeaderDictionary>();
            foreach (var header in requestHeaders)
            {
                headersMock
                    .Setup(h => h.ContainsKey(header.Key))
                    .Returns(true);
                headersMock
                    .Setup(h => h[header.Key])
                    .Returns(header.Value);
            }

            var queryMock = new Mock<IQueryCollection>();
            foreach (var query in requestQueries)
            {
                queryMock
                    .Setup(q => q.ContainsKey(query.Key))
                    .Returns(true);
                queryMock
                    .Setup(q => q[query.Key])
                    .Returns(query.Value);
            }

            var requestMock = new Mock<HttpRequest>();
            requestMock
                .Setup(r => r.Headers)
                .Returns(headersMock.Object);
            requestMock
                .Setup(r => r.Query)
                .Returns(queryMock.Object);

            return requestMock.Object;
        }

        /// <summary>
        /// Builds the request service provider, optionally resolving a mocked <see cref="IAntiforgery"/>.
        /// </summary>
        /// <param name="hasAntiforgery">Whether <c>GetService(typeof(IAntiforgery))</c> is set up.</param>
        /// <returns>The mocked service provider.</returns>
        private IServiceProvider CreateRequestServices(bool hasAntiforgery)
        {
            var servicesMock = new Mock<IServiceProvider>();

            if (hasAntiforgery)
            {
                // A blank tokenName simulates the antiforgery service producing no token set.
                var tokenSet = string.IsNullOrWhiteSpace(tokenName)
                    ? null
                    : new AntiforgeryTokenSet(tokenValue, tokenValue, tokenName, tokenName);

                var antiforgeryMock = new Mock<IAntiforgery>();
                antiforgeryMock
                    .Setup(af => af.GetAndStoreTokens(It.IsAny<HttpContext>()))
                    .Returns(tokenSet);

                servicesMock
                    .Setup(rs => rs.GetService(typeof(IAntiforgery)))
                    .Returns(antiforgeryMock.Object);
            }

            return servicesMock.Object;
        }

        /// <summary>
        /// Builds the connection info: local address is always loopback, remote address is <see cref="remote"/>.
        /// </summary>
        private ConnectionInfo CreateConnection()
        {
            var connectionMock = new Mock<ConnectionInfo>();
            connectionMock
                .Setup(ci => ci.LocalIpAddress)
                .Returns(IPAddress.Loopback);
            connectionMock
                .Setup(ci => ci.RemoteIpAddress)
                .Returns(remote);

            return connectionMock.Object;
        }
    }
}