// common/UnitTests/AspNetCore/Extensions/HttpContextExtensionsTests.cs

using System;
using System.Collections.Generic;
using System.Net;
using Microsoft.AspNetCore.Antiforgery;
using Microsoft.AspNetCore.Http;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
namespace UnitTests.AspNetCore.Extensions
{
/// <summary>
/// Unit tests for the <c>HttpContext</c> extension methods exercised below:
/// <c>GetAntiforgeryToken</c>, <c>GetRemoteIpAddress</c>, <c>IsLocalRequest</c>,
/// <c>GetReturnUrl</c> and <c>ClearSession</c>.
/// </summary>
/// <remarks>
/// Each test first adjusts the fixture fields (token, headers, query values, items,
/// remote address) and then calls <c>GetContext</c>, which builds a mocked
/// <c>HttpContext</c> from the current field values.
/// </remarks>
[TestClass]
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public class HttpContextExtensionsTests
{
    // Session mock of the most recently built context; recreated by every GetContext call.
    private Mock<ISession> sessionMock;

    // Antiforgery token name/value handed to the mocked IAntiforgery service.
    // A null or whitespace name makes the mock return a null token set.
    private string tokenName;
    private string tokenValue;

    // Request headers and query-string values exposed by the mocked request.
    private Dictionary<string, string> requestHeaders;
    private Dictionary<string, string> requestQueries;

    // Backing store returned by HttpContext.Items.
    private Dictionary<object, object> items;

    // Value returned by ConnectionInfo.RemoteIpAddress.
    private IPAddress remote;

    /// <summary>
    /// Resets every fixture field to its neutral default before each test:
    /// no token, no headers, no query values, no items, loopback remote address.
    /// </summary>
    [TestInitialize]
    public void InitializeTests()
    {
        tokenName = null;
        tokenValue = null;
        requestHeaders = new Dictionary<string, string>();
        requestQueries = new Dictionary<string, string>();
        items = new Dictionary<object, object>();
        remote = IPAddress.Loopback;
    }

    #region Antiforgery

    /// <summary>
    /// With an antiforgery service that yields a token set, the returned name and
    /// value must match the configured token.
    /// </summary>
    [TestMethod]
    public void ShouldReturnAntiforgery()
    {
        // arrange
        tokenName = "af-token";
        tokenValue = "security_first";
        var context = GetContext();

        // act
        var result = context.GetAntiforgeryToken();

        // assert
        Assert.AreEqual(tokenName, result.Name);
        Assert.AreEqual(tokenValue, result.Value);
    }

    /// <summary>
    /// When no antiforgery service is registered, name and value must both be null
    /// even though a token has been configured on the fixture.
    /// </summary>
    [TestMethod]
    public void ShouldReturnAntiforgeryNullService()
    {
        // arrange
        tokenName = "af-token";
        tokenValue = "security_first";
        var context = GetContext(hasAntiforgery: false);

        // act
        var result = context.GetAntiforgeryToken();

        // assert
        Assert.IsNull(result.Name);
        Assert.IsNull(result.Value);
    }

    /// <summary>
    /// When the antiforgery service returns a null token set (no token name was
    /// configured), name and value must both be null.
    /// </summary>
    [TestMethod]
    public void ShouldReturnAntiforgeryNullToken()
    {
        // arrange
        var context = GetContext();

        // act
        var result = context.GetAntiforgeryToken();

        // assert
        Assert.IsNull(result.Name);
        Assert.IsNull(result.Value);
    }

    #endregion Antiforgery

    #region RemoteAddress

    /// <summary>
    /// Without any forwarding header the connection's remote address is expected.
    /// </summary>
    [TestMethod]
    public void ShouldReturnRemoteAddress()
    {
        // arrange
        remote = IPAddress.Parse("1.2.3.4");
        var context = GetContext();

        // act
        var result = context.GetRemoteIpAddress();

        // assert
        Assert.AreEqual(remote, result);
    }

    /// <summary>
    /// An address supplied through the "X-Forwarded-For" header is expected to win
    /// over the connection's remote address when no header name is passed.
    /// </summary>
    [TestMethod]
    public void ShouldReturnDefaultHeader()
    {
        // arrange
        remote = IPAddress.Parse("1.2.3.4");
        var header = IPAddress.Parse("5.6.7.8");
        requestHeaders.Add("X-Forwarded-For", header.ToString());
        var context = GetContext();

        // act
        var result = context.GetRemoteIpAddress();

        // assert
        Assert.AreNotEqual(remote, result);
        Assert.AreEqual(header, result);
    }

    /// <summary>
    /// An address supplied through a caller-chosen header is expected to win over
    /// the connection's remote address when that header name is passed explicitly.
    /// </summary>
    [TestMethod]
    public void ShouldReturnCustomHeader()
    {
        // arrange
        remote = IPAddress.Parse("1.2.3.4");
        string headerName = "FooBar";
        var headerIp = IPAddress.Parse("5.6.7.8");
        requestHeaders.Add(headerName, headerIp.ToString());
        var context = GetContext();

        // act
        var result = context.GetRemoteIpAddress(headerName: headerName);

        // assert
        Assert.AreNotEqual(remote, result);
        Assert.AreEqual(headerIp, result);
    }

    /// <summary>
    /// A malformed header value ("1.2.3:4") must be ignored so that the
    /// connection's remote address is returned instead.
    /// </summary>
    [TestMethod]
    public void ShouldReturnAddressInvalidHeader()
    {
        // arrange
        remote = IPAddress.Parse("1.2.3.4");
        requestHeaders.Add("X-Forwarded-For", "1.2.3:4");
        var context = GetContext();

        // act
        var result = context.GetRemoteIpAddress();

        // assert
        Assert.AreEqual(remote, result);
    }

    #endregion RemoteAddress

    #region Local Request

    /// <summary>
    /// A loopback remote address with no headers counts as a local request.
    /// </summary>
    [TestMethod]
    public void ShouldReturnTrueOnLocal()
    {
        // arrange
        remote = IPAddress.Loopback;
        var context = GetContext();

        // act
        bool result = context.IsLocalRequest();

        // assert
        Assert.IsTrue(result);
    }

    /// <summary>
    /// A non-loopback remote address with no headers is not a local request.
    /// </summary>
    [TestMethod]
    public void ShouldReturnFalseOnRemote()
    {
        // arrange
        remote = IPAddress.Parse("1.2.3.4");
        var context = GetContext();

        // act
        bool result = context.IsLocalRequest();

        // assert
        Assert.IsFalse(result);
    }

    /// <summary>
    /// A loopback address in "X-Forwarded-For" makes the request local even though
    /// the connection's remote address is not loopback.
    /// </summary>
    [TestMethod]
    public void ShouldReturnTrueOnDefaultHeader()
    {
        // arrange
        remote = IPAddress.Parse("1.2.3.4");
        var headerIp = IPAddress.Loopback;
        requestHeaders.Add("X-Forwarded-For", headerIp.ToString());
        var context = GetContext();

        // act
        bool result = context.IsLocalRequest();

        // assert
        Assert.IsTrue(result);
    }

    /// <summary>
    /// A loopback address in a caller-chosen header makes the request local even
    /// though the connection's remote address is not loopback.
    /// </summary>
    [TestMethod]
    public void ShouldReturnTrueOnCustomHeader()
    {
        // arrange
        remote = IPAddress.Parse("1.2.3.4");
        string headerName = "FooBar";
        var headerIp = IPAddress.Loopback;
        requestHeaders.Add(headerName, headerIp.ToString());
        var context = GetContext();

        // act
        bool result = context.IsLocalRequest(headerName: headerName);

        // assert
        Assert.IsTrue(result);
    }

    /// <summary>
    /// A non-loopback address in "X-Forwarded-For" makes the request non-local;
    /// the remote address is left at the loopback default set by InitializeTests.
    /// </summary>
    [TestMethod]
    public void ShouldReturnFalseOnDefaultHeader()
    {
        // arrange
        var headerIp = IPAddress.Parse("1.2.3.4");
        requestHeaders.Add("X-Forwarded-For", headerIp.ToString());
        var context = GetContext();

        // act
        bool result = context.IsLocalRequest();

        // assert
        Assert.IsFalse(result);
    }

    /// <summary>
    /// A non-loopback address in a caller-chosen header makes the request non-local;
    /// the remote address is left at the loopback default set by InitializeTests.
    /// </summary>
    [TestMethod]
    public void ShouldReturnFalseOnCustomHeader()
    {
        // arrange
        string headerName = "FooBar";
        var headerIp = IPAddress.Parse("1.2.3.4");
        requestHeaders.Add(headerName, headerIp.ToString());
        var context = GetContext();

        // act
        bool result = context.IsLocalRequest(headerName: headerName);

        // assert
        Assert.IsFalse(result);
    }

    #endregion Local Request

    #region ReturnUrl

    /// <summary>
    /// With neither an "OriginalRequest" item nor a "ReturnUrl" query value the
    /// return URL must be null.
    /// </summary>
    [TestMethod]
    public void ShouldReturnNull()
    {
        // arrange
        var context = GetContext();

        // act
        string result = context.GetReturnUrl();

        // assert
        Assert.IsNull(result);
    }

    /// <summary>
    /// When both sources are present, the "OriginalRequest" item must be returned
    /// rather than the "ReturnUrl" query value.
    /// </summary>
    [TestMethod]
    public void ShouldReturnOriginalRequest()
    {
        // arrange
        string request = "abc";
        string query = "def";
        items.Add("OriginalRequest", request);
        requestQueries.Add("ReturnUrl", query);
        var context = GetContext();

        // act
        string result = context.GetReturnUrl();

        // assert
        Assert.AreEqual(request, result);
        Assert.AreNotEqual(query, result);
    }

    /// <summary>
    /// With only a "ReturnUrl" query value present, that value must be returned.
    /// </summary>
    [TestMethod]
    public void ShouldReturnUrl()
    {
        // arrange
        string query = "def";
        requestQueries.Add("ReturnUrl", query);
        var context = GetContext();

        // act
        string result = context.GetReturnUrl();

        // assert
        Assert.AreEqual(query, result);
    }

    #endregion ReturnUrl

    #region Session

    /// <summary>
    /// When the context exposes a session, it must be cleared exactly once.
    /// </summary>
    [TestMethod]
    public void ShouldClearSession()
    {
        // arrange
        var context = GetContext();

        // act
        context.ClearSession();

        // assert
        sessionMock.Verify(s => s.Clear(), Times.Once);
    }

    /// <summary>
    /// When the context has no session configured, the call must complete and the
    /// (unattached) session mock must never be cleared.
    /// </summary>
    [TestMethod]
    public void ShouldSkipWhenNoSession()
    {
        // arrange
        var context = GetContext(hasSession: false);

        // act
        context.ClearSession();

        // assert
        sessionMock.Verify(s => s.Clear(), Times.Never);
    }

    #endregion Session

    /// <summary>
    /// Builds a mocked <c>HttpContext</c> from the current fixture fields.
    /// </summary>
    /// <param name="hasAntiforgery">
    /// When true, an <c>IAntiforgery</c> mock is registered in the request services;
    /// when false, no setup is made for that service.
    /// </param>
    /// <param name="hasSession">
    /// When true, <c>HttpContext.Session</c> returns the session mock; when false,
    /// the property is left without a setup.
    /// </param>
    /// <returns>The mocked context.</returns>
    private HttpContext GetContext(bool hasAntiforgery = true, bool hasSession = true)
    {
        // Always recreated, even when it is not attached, because tests call
        // Verify on this field after the act step.
        sessionMock = new Mock<ISession>();

        var contextMock = new Mock<HttpContext>();
        contextMock
            .Setup(c => c.Request)
            .Returns(CreateRequest());
        contextMock
            .Setup(c => c.RequestServices)
            .Returns(CreateRequestServices(hasAntiforgery));
        contextMock
            .Setup(c => c.Connection)
            .Returns(CreateConnection());
        contextMock
            .Setup(c => c.Items)
            .Returns(items);

        if (hasSession)
        {
            contextMock
                .Setup(c => c.Session)
                .Returns(sessionMock.Object);
        }

        return contextMock.Object;
    }

    /// <summary>
    /// Creates the mocked request exposing the configured headers and query values.
    /// </summary>
    private HttpRequest CreateRequest()
    {
        var requestMock = new Mock<HttpRequest>();
        requestMock
            .Setup(r => r.Headers)
            .Returns(CreateRequestHeaders());
        requestMock
            .Setup(r => r.Query)
            .Returns(CreateRequestQuery());

        return requestMock.Object;
    }

    /// <summary>
    /// Creates a header dictionary mock that answers <c>ContainsKey</c> and the
    /// indexer for every entry of <c>requestHeaders</c>.
    /// </summary>
    private IHeaderDictionary CreateRequestHeaders()
    {
        var headersMock = new Mock<IHeaderDictionary>();
        foreach (var header in requestHeaders)
        {
            headersMock
                .Setup(h => h.ContainsKey(header.Key))
                .Returns(true);
            headersMock
                .Setup(h => h[header.Key])
                .Returns(header.Value);
        }

        return headersMock.Object;
    }

    /// <summary>
    /// Creates a query collection mock that answers <c>ContainsKey</c> and the
    /// indexer for every entry of <c>requestQueries</c>.
    /// </summary>
    private IQueryCollection CreateRequestQuery()
    {
        var queryMock = new Mock<IQueryCollection>();
        foreach (var query in requestQueries)
        {
            queryMock
                .Setup(q => q.ContainsKey(query.Key))
                .Returns(true);
            queryMock
                .Setup(q => q[query.Key])
                .Returns(query.Value);
        }

        return queryMock.Object;
    }

    /// <summary>
    /// Creates the service provider mock, optionally registering an antiforgery
    /// service that returns a token set built from <c>tokenName</c>/<c>tokenValue</c>.
    /// </summary>
    private IServiceProvider CreateRequestServices(bool hasAntiforgery)
    {
        var servicesMock = new Mock<IServiceProvider>();
        if (!hasAntiforgery)
        {
            return servicesMock.Object;
        }

        // No token name configured => the service hands back a null token set.
        // The same value/name is used for both the request/cookie token and the
        // form-field/header name arguments.
        AntiforgeryTokenSet tokenSet = string.IsNullOrWhiteSpace(tokenName)
            ? null
            : new AntiforgeryTokenSet(tokenValue, tokenValue, tokenName, tokenName);

        var antiforgeryMock = new Mock<IAntiforgery>();
        antiforgeryMock
            .Setup(af => af.GetAndStoreTokens(It.IsAny<HttpContext>()))
            .Returns(tokenSet);
        servicesMock
            .Setup(rs => rs.GetService(typeof(IAntiforgery)))
            .Returns(antiforgeryMock.Object);

        return servicesMock.Object;
    }

    /// <summary>
    /// Creates the connection mock: local address is loopback, remote address is
    /// the current value of <c>remote</c>.
    /// </summary>
    private ConnectionInfo CreateConnection()
    {
        var connectionMock = new Mock<ConnectionInfo>();
        connectionMock
            .Setup(ci => ci.LocalIpAddress)
            .Returns(IPAddress.Loopback);
        connectionMock
            .Setup(ci => ci.RemoteIpAddress)
            .Returns(remote);

        return connectionMock.Object;
    }
}
}