diff --git a/AMWD.Common.AspNetCore/TagHelpers/IntegrityHashTagHelper.cs b/AMWD.Common.AspNetCore/TagHelpers/IntegrityHashTagHelper.cs
new file mode 100644
index 0000000..d4a3e83
--- /dev/null
+++ b/AMWD.Common.AspNetCore/TagHelpers/IntegrityHashTagHelper.cs
@@ -0,0 +1,162 @@
+using System;
+using System.IO;
+using System.Linq;
+using System.Net.Http;
+using System.Security.Cryptography;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.Hosting;
+using Microsoft.AspNetCore.Html;
+using Microsoft.Extensions.Configuration;
+
+namespace Microsoft.AspNetCore.Razor.TagHelpers
+{
+ ///
+ /// A tag helper to dynamically create integrity checks for linked sources.
+ ///
+ [HtmlTargetElement("link")]
+ [HtmlTargetElement("script")]
+ public class IntegrityHashTagHelper : TagHelper
+ {
+ private readonly IWebHostEnvironment env;
+ private readonly string hostUrl;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The web host environment.
+ /// The application configuration.
+ public IntegrityHashTagHelper(IWebHostEnvironment env, IConfiguration configuration)
+ {
+ this.env = env;
+ hostUrl = configuration.GetValue("ASPNETCORE_APPL_URL", "http://localhost/");
+ }
+
+ ///
+ /// Gets or sets a value indicating whether the integrity should be calculated.
+ ///
+ [HtmlAttributeName("asp-integrity")]
+ public bool IsIntegrityEnabled { get; set; }
+
+ ///
+ /// Gets or sets the hash strength to use.
+ ///
+ [HtmlAttributeName("asp-integrity-strength")]
+ public int IntegrityStrength { get; set; }
+
+ ///
+ public override async Task ProcessAsync(TagHelperContext context, TagHelperOutput output)
+ {
+ if (context.AllAttributes.Where(a => a.Name.Equals("integrity", StringComparison.OrdinalIgnoreCase)).Any())
+ return;
+
+ if (!IsIntegrityEnabled)
+ return;
+
+ string source = null;
+ switch (context.TagName.ToLower())
+ {
+ case "link":
+ var rel = context.AllAttributes.Where(a => a.Name.ToLower() == "rel").FirstOrDefault();
+ if (rel == null || rel.Value.ToString().ToLower() == "stylesheet")
+ {
+ var href = context.AllAttributes.Where(a => a.Name.ToLower() == "href").FirstOrDefault();
+ source = href?.Value?.ToString().Trim();
+ }
+ break;
+ case "script":
+ var src = context.AllAttributes.Where(a => a.Name.ToLower() == "src").FirstOrDefault();
+ source = src?.Value?.ToString().Trim();
+ break;
+ }
+
+ // no source given, no hash to calculate.
+ if (string.IsNullOrWhiteSpace(source))
+ return;
+
+ byte[] fileBytes = null;
+ if (source.StartsWith("http") || source.StartsWith("//"))
+ {
+ if (source.StartsWith("//"))
+ source = $"http:{source}";
+
+ try
+ {
+ using var client = new HttpClient();
+
+ if (!string.IsNullOrWhiteSpace(hostUrl))
+ client.DefaultRequestHeaders.Referrer = new Uri(hostUrl);
+
+ var response = await client.GetAsync(source);
+ fileBytes = await response.Content.ReadAsByteArrayAsync();
+ }
+ catch
+ {
+ return;
+ }
+ }
+ else
+ {
+ if (source.StartsWith("~"))
+ source = source[1..];
+
+ if (source.StartsWith("/"))
+ source = source[1..];
+
+ if (source.Contains("?"))
+ source = source[..source.IndexOf("?")];
+
+ try
+ {
+ string path = Path.Combine(env.WebRootPath, source);
+ fileBytes = await File.ReadAllBytesAsync(path);
+ }
+ catch
+ {
+ return;
+ }
+ }
+
+ string type;
+ byte[] hashBytes = new byte[0];
+ switch (IntegrityStrength)
+ {
+ case 512:
+ type = "sha512";
+ using (var sha = SHA512.Create())
+ {
+ hashBytes = sha.ComputeHash(fileBytes);
+ }
+ break;
+ case 384:
+ type = "sha384";
+ using (var sha = SHA384.Create())
+ {
+ hashBytes = sha.ComputeHash(fileBytes);
+ }
+ break;
+ default: // 256
+ type = "sha256";
+ using (var sha = SHA256.Create())
+ {
+ hashBytes = sha.ComputeHash(fileBytes);
+ }
+ break;
+ }
+ string hash = Convert.ToBase64String(hashBytes);
+
+ output.Attributes.RemoveAll("integrity");
+ output.Attributes.RemoveAll("crossorigin");
+
+ output.Attributes.Add(new TagHelperAttribute("integrity", new HtmlString($"{type}-{hash}")));
+ output.Attributes.Add(new TagHelperAttribute("crossorigin", new HtmlString("anonymous")));
+ }
+
+ ///
+ public override void Process(TagHelperContext context, TagHelperOutput output)
+ {
+ // ensure leaving context to prevent a deadlock.
+ var task = Task.Run(() => ProcessAsync(context, output));
+ task.Wait();
+ }
+ }
+}
diff --git a/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs b/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs
index 7dd8dab..d4fecd5 100644
--- a/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs
+++ b/AMWD.Common.EntityFrameworkCore/Extensions/DatabaseFacadeExtensions.cs
@@ -31,6 +31,9 @@ namespace Microsoft.EntityFrameworkCore
var options = new DatabaseMigrationOptions();
optionsAction?.Invoke(options);
+ if (database.GetProviderType() == DatabaseProvider.InMemory)
+ return true;
+
if (string.IsNullOrWhiteSpace(options.MigrationsTableName))
throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required.");
@@ -41,6 +44,7 @@ namespace Microsoft.EntityFrameworkCore
try
{
await connection.OpenAsync(cancellationToken);
+
if (!await connection.CreateMigrationsTable(options, cancellationToken))
return false;
@@ -52,10 +56,14 @@ namespace Microsoft.EntityFrameworkCore
}
}
- private static DatabaseProvider GetProviderType(this DbConnection connection)
- {
- string provider = connection.GetType().FullName;
+ private static DatabaseProvider GetProviderType(this DatabaseFacade database)
+ => GetProviderType(database.ProviderName);
+ private static DatabaseProvider GetProviderType(this DbConnection connection)
+ => GetProviderType(connection.GetType().FullName);
+
+ private static DatabaseProvider GetProviderType(string provider)
+ {
if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.MySQL;
if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase))
@@ -74,9 +82,6 @@ namespace Microsoft.EntityFrameworkCore
private static async Task CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{
- if (connection.GetProviderType() == DatabaseProvider.InMemory)
- return true;
-
try
{
using var command = connection.CreateCommand();
@@ -140,9 +145,6 @@ END;"
private static async Task Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{
- if (connection.GetProviderType() == DatabaseProvider.InMemory)
- return true;
-
try
{
List availableMigrationFiles;