From 142917a7fd730c501e84f9005c364cc53ea71255 Mon Sep 17 00:00:00 2001 From: Andreas Mueller Date: Tue, 15 Feb 2022 22:04:00 +0100 Subject: [PATCH] Fixing errors with in-memory context, adding integrity hash tag helper --- .../TagHelpers/IntegrityHashTagHelper.cs | 162 ++++++++++++++++++ .../Extensions/DatabaseFacadeExtensions.cs | 20 ++- 2 files changed, 173 insertions(+), 9 deletions(-) create mode 100644 AMWD.Common.AspNetCore/TagHelpers/IntegrityHashTagHelper.cs diff --git a/AMWD.Common.AspNetCore/TagHelpers/IntegrityHashTagHelper.cs b/AMWD.Common.AspNetCore/TagHelpers/IntegrityHashTagHelper.cs new file mode 100644 index 0000000..d4a3e83 --- /dev/null +++ b/AMWD.Common.AspNetCore/TagHelpers/IntegrityHashTagHelper.cs @@ -0,0 +1,162 @@ +using System; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Security.Cryptography; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Html; +using Microsoft.Extensions.Configuration; + +namespace Microsoft.AspNetCore.Razor.TagHelpers +{ + /// + /// A tag helper to dynamically create integrity checks for linked sources. + /// + [HtmlTargetElement("link")] + [HtmlTargetElement("script")] + public class IntegrityHashTagHelper : TagHelper + { + private readonly IWebHostEnvironment env; + private readonly string hostUrl; + + /// + /// Initializes a new instance of the class. + /// + /// The web host environment. + /// The application configuration. + public IntegrityHashTagHelper(IWebHostEnvironment env, IConfiguration configuration) + { + this.env = env; + hostUrl = configuration.GetValue("ASPNETCORE_APPL_URL", "http://localhost/"); + } + + /// + /// Gets or sets a value indicating whether the integrity should be calculated. + /// + [HtmlAttributeName("asp-integrity")] + public bool IsIntegrityEnabled { get; set; } + + /// + /// Gets or sets the hash strength to use. + /// + [HtmlAttributeName("asp-integrity-strength")] + public int IntegrityStrength { get; set; } + + /// + public override async Task ProcessAsync(TagHelperContext context, TagHelperOutput output) + { + if (context.AllAttributes.Where(a => a.Name.Equals("integrity", StringComparison.OrdinalIgnoreCase)).Any()) + return; + + if (!IsIntegrityEnabled) + return; + + string source = null; + switch (context.TagName.ToLower()) + { + case "link": + var rel = context.AllAttributes.Where(a => a.Name.ToLower() == "rel").FirstOrDefault(); + if (rel == null || rel.Value.ToString().ToLower() == "stylesheet") + { + var href = context.AllAttributes.Where(a => a.Name.ToLower() == "href").FirstOrDefault(); + source = href?.Value?.ToString().Trim(); + } + break; + case "script": + var src = context.AllAttributes.Where(a => a.Name.ToLower() == "src").FirstOrDefault(); + source = src?.Value?.ToString().Trim(); + break; + } + + // no source given, no hash to calculate. + if (string.IsNullOrWhiteSpace(source)) + return; + + byte[] fileBytes = null; + if (source.StartsWith("http") || source.StartsWith("//")) + { + if (source.StartsWith("//")) + source = $"http:{source}"; + + try + { + using var client = new HttpClient(); + + if (!string.IsNullOrWhiteSpace(hostUrl)) + client.DefaultRequestHeaders.Referrer = new Uri(hostUrl); + + var response = await client.GetAsync(source); + fileBytes = await response.Content.ReadAsByteArrayAsync(); + } + catch + { + return; + } + } + else + { + if (source.StartsWith("~")) + source = source[1..]; + + if (source.StartsWith("/")) + source = source[1..]; + + if (source.Contains("?")) + source = source[..source.IndexOf("?")]; + + try + { + string path = Path.Combine(env.WebRootPath, source); + fileBytes = await File.ReadAllBytesAsync(path); + } + catch + { + return; + } + } + + string type; + byte[] hashBytes = new byte[0]; + switch (IntegrityStrength) + { + case 512: + type = "sha512"; + using (var sha = SHA512.Create()) + { + hashBytes = sha.ComputeHash(fileBytes); + } + break; + case 384: + type = "sha384"; + using (var sha = SHA384.Create()) + { + hashBytes = sha.ComputeHash(fileBytes); + } + break; + default: // 256 + type = "sha256"; + using (var sha = SHA256.Create()) + { + hashBytes = sha.ComputeHash(fileBytes); + } + break; + } + string hash = Convert.ToBase64String(hashBytes); + + output.Attributes.RemoveAll("integrity"); + output.Attributes.RemoveAll("crossorigin"); + + output.Attributes.Add(new TagHelperAttribute("integrity", new HtmlString($"{type}-{hash}"))); + output.Attributes.Add(new TagHelperAttribute("crossorigin", new HtmlString("anonymous"))); + } + + /// + public override void Process(TagHelperContext context, TagHelperOutput output) + { + // ensure leaving context to prevent a deadlock. + var task = Task.Run(() => ProcessAsync(context, output)); + task.Wait(); + } + } +} diff --git a/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs b/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs index 7dd8dab..d4fecd5 100644 --- a/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs +++ b/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs @@ -31,6 +31,9 @@ namespace Microsoft.EntityFrameworkCore var options = new DatabaseMigrationOptions(); optionsAction?.Invoke(options); + if (database.GetProviderType() == DatabaseProvider.InMemory) + return true; + if (string.IsNullOrWhiteSpace(options.MigrationsTableName)) throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required."); @@ -41,6 +44,7 @@ namespace Microsoft.EntityFrameworkCore try { await connection.OpenAsync(cancellationToken); + if (!await connection.CreateMigrationsTable(options, cancellationToken)) return false; @@ -52,10 +56,14 @@ namespace Microsoft.EntityFrameworkCore } } - private static DatabaseProvider GetProviderType(this DbConnection connection) - { - string provider = connection.GetType().FullName; + private static DatabaseProvider GetProviderType(this DatabaseFacade database) + => GetProviderType(database.ProviderName); + private static DatabaseProvider GetProviderType(this DbConnection connection) + => GetProviderType(connection.GetType().FullName); + + private static DatabaseProvider GetProviderType(string provider) + { if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase)) return DatabaseProvider.MySQL; if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase)) @@ -74,9 +82,6 @@ namespace Microsoft.EntityFrameworkCore private static async Task CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken) { - if (connection.GetProviderType() == DatabaseProvider.InMemory) - return true; - try { using var command = connection.CreateCommand(); @@ -140,9 +145,6 @@ END;" private static async Task Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken) { - if (connection.GetProviderType() == DatabaseProvider.InMemory) - return true; - try { List availableMigrationFiles;