1
0

Fixing errors with in-memory context, adding integrity hash tag helper

This commit is contained in:
2022-02-15 22:04:00 +01:00
parent 1eb40237c5
commit 142917a7fd
2 changed files with 173 additions and 9 deletions

View File

@@ -0,0 +1,162 @@
using System;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Security.Cryptography;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Html;
using Microsoft.Extensions.Configuration;
namespace Microsoft.AspNetCore.Razor.TagHelpers
{
/// <summary>
/// A tag helper to dynamically create integrity checks for linked sources.
/// </summary>
[HtmlTargetElement("link")]
[HtmlTargetElement("script")]
public class IntegrityHashTagHelper : TagHelper
{
private readonly IWebHostEnvironment env;
private readonly string hostUrl;
/// <summary>
/// Initializes a new instance of the <see cref="IntegrityHashTagHelper"/> class.
/// </summary>
/// <param name="env">The web host environment.</param>
/// <param name="configuration">The application configuration.</param>
public IntegrityHashTagHelper(IWebHostEnvironment env, IConfiguration configuration)
{
this.env = env;
hostUrl = configuration.GetValue("ASPNETCORE_APPL_URL", "http://localhost/");
}
/// <summary>
/// Gets or sets a value indicating whether the integrity should be calculated.
/// </summary>
[HtmlAttributeName("asp-integrity")]
public bool IsIntegrityEnabled { get; set; }
/// <summary>
/// Gets or sets the hash strength to use.
/// </summary>
[HtmlAttributeName("asp-integrity-strength")]
public int IntegrityStrength { get; set; }
/// <inheritdoc/>
public override async Task ProcessAsync(TagHelperContext context, TagHelperOutput output)
{
if (context.AllAttributes.Where(a => a.Name.Equals("integrity", StringComparison.OrdinalIgnoreCase)).Any())
return;
if (!IsIntegrityEnabled)
return;
string source = null;
switch (context.TagName.ToLower())
{
case "link":
var rel = context.AllAttributes.Where(a => a.Name.ToLower() == "rel").FirstOrDefault();
if (rel == null || rel.Value.ToString().ToLower() == "stylesheet")
{
var href = context.AllAttributes.Where(a => a.Name.ToLower() == "href").FirstOrDefault();
source = href?.Value?.ToString().Trim();
}
break;
case "script":
var src = context.AllAttributes.Where(a => a.Name.ToLower() == "src").FirstOrDefault();
source = src?.Value?.ToString().Trim();
break;
}
// no source given, no hash to calculate.
if (string.IsNullOrWhiteSpace(source))
return;
byte[] fileBytes = null;
if (source.StartsWith("http") || source.StartsWith("//"))
{
if (source.StartsWith("//"))
source = $"http:{source}";
try
{
using var client = new HttpClient();
if (!string.IsNullOrWhiteSpace(hostUrl))
client.DefaultRequestHeaders.Referrer = new Uri(hostUrl);
var response = await client.GetAsync(source);
fileBytes = await response.Content.ReadAsByteArrayAsync();
}
catch
{
return;
}
}
else
{
if (source.StartsWith("~"))
source = source[1..];
if (source.StartsWith("/"))
source = source[1..];
if (source.Contains("?"))
source = source[..source.IndexOf("?")];
try
{
string path = Path.Combine(env.WebRootPath, source);
fileBytes = await File.ReadAllBytesAsync(path);
}
catch
{
return;
}
}
string type;
byte[] hashBytes = new byte[0];
switch (IntegrityStrength)
{
case 512:
type = "sha512";
using (var sha = SHA512.Create())
{
hashBytes = sha.ComputeHash(fileBytes);
}
break;
case 384:
type = "sha384";
using (var sha = SHA384.Create())
{
hashBytes = sha.ComputeHash(fileBytes);
}
break;
default: // 256
type = "sha256";
using (var sha = SHA256.Create())
{
hashBytes = sha.ComputeHash(fileBytes);
}
break;
}
string hash = Convert.ToBase64String(hashBytes);
output.Attributes.RemoveAll("integrity");
output.Attributes.RemoveAll("crossorigin");
output.Attributes.Add(new TagHelperAttribute("integrity", new HtmlString($"{type}-{hash}")));
output.Attributes.Add(new TagHelperAttribute("crossorigin", new HtmlString("anonymous")));
}
/// <inheritdoc/>
public override void Process(TagHelperContext context, TagHelperOutput output)
{
// ensure leaving context to prevent a deadlock.
var task = Task.Run(() => ProcessAsync(context, output));
task.Wait();
}
}
}

View File

@@ -31,6 +31,9 @@ namespace Microsoft.EntityFrameworkCore
var options = new DatabaseMigrationOptions(); var options = new DatabaseMigrationOptions();
optionsAction?.Invoke(options); optionsAction?.Invoke(options);
if (database.GetProviderType() == DatabaseProvider.InMemory)
return true;
if (string.IsNullOrWhiteSpace(options.MigrationsTableName)) if (string.IsNullOrWhiteSpace(options.MigrationsTableName))
throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required."); throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required.");
@@ -41,6 +44,7 @@ namespace Microsoft.EntityFrameworkCore
try try
{ {
await connection.OpenAsync(cancellationToken); await connection.OpenAsync(cancellationToken);
if (!await connection.CreateMigrationsTable(options, cancellationToken)) if (!await connection.CreateMigrationsTable(options, cancellationToken))
return false; return false;
@@ -52,10 +56,14 @@ namespace Microsoft.EntityFrameworkCore
} }
} }
private static DatabaseProvider GetProviderType(this DbConnection connection) private static DatabaseProvider GetProviderType(this DatabaseFacade database)
{ => GetProviderType(database.ProviderName);
string provider = connection.GetType().FullName;
private static DatabaseProvider GetProviderType(this DbConnection connection)
=> GetProviderType(connection.GetType().FullName);
private static DatabaseProvider GetProviderType(string provider)
{
if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase)) if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.MySQL; return DatabaseProvider.MySQL;
if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase)) if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase))
@@ -74,9 +82,6 @@ namespace Microsoft.EntityFrameworkCore
private static async Task<bool> CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken) private static async Task<bool> CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{ {
if (connection.GetProviderType() == DatabaseProvider.InMemory)
return true;
try try
{ {
using var command = connection.CreateCommand(); using var command = connection.CreateCommand();
@@ -140,9 +145,6 @@ END;"
private static async Task<bool> Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken) private static async Task<bool> Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{ {
if (connection.GetProviderType() == DatabaseProvider.InMemory)
return true;
try try
{ {
List<string> availableMigrationFiles; List<string> availableMigrationFiles;