1
0

Added ViewContext extensions

This commit is contained in:
2025-12-12 08:54:06 +01:00
parent 14b38135a0
commit 6f92908a81
7 changed files with 116 additions and 77 deletions

View File

@@ -0,0 +1,60 @@
using System;
namespace Microsoft.AspNetCore.Mvc.Rendering
{
/// <summary>
/// Provides extension methods for the <see cref="ViewContext"/>.
/// </summary>
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public static class ViewContextExtensions
{
/// <summary>
/// Determines whether the current view context matches the specified controller, action, and area names.
/// </summary>
/// <remarks>
/// This method is commonly used in web applications to determine whether a navigation element should be marked as active based on the current route data.
/// </remarks>
/// <param name="viewContext">The <see cref="ViewContext"/> containing routing information for the current request.</param>
/// <param name="controller">The name of the controller to compare against the current route.</param>
/// <param name="action">
/// The name of the action to compare against the current route.
/// If <see langword="null"/>, the action is not considered in the comparison.
/// </param>
/// <param name="area">
/// The name of the area to compare against the current route.
/// If <see langword="null"/>, the area is not considered in the comparison.
/// </param>
public static bool IsActive(this ViewContext viewContext, string controller, string action = null, string area = null)
{
string currentController = viewContext.RouteData.Values["Controller"]?.ToString() ?? "";
string currentAction = viewContext.RouteData.Values["Action"]?.ToString() ?? "";
string currentArea = viewContext.RouteData.Values["Area"]?.ToString() ?? "";
if (!string.IsNullOrWhiteSpace(area) && !string.Equals(currentArea, area, StringComparison.OrdinalIgnoreCase))
return false;
if (!string.Equals(currentController, controller, StringComparison.OrdinalIgnoreCase))
return false;
if (!string.IsNullOrWhiteSpace(action) && !string.Equals(currentAction, action, StringComparison.OrdinalIgnoreCase))
return false;
return true;
}
/// <summary>
/// Determines whether the specified area is the active area in the current view context.
/// </summary>
/// <remarks>
/// This method is typically used in web applications to highlight navigation elements or perform logic based on the active area.
/// </remarks>
/// <param name="viewContext">The <see cref="ViewContext"/> containing routing information for the current request.</param>
/// <param name="area">The name of the area to check for activity.</param>
public static bool IsAreaActive(this ViewContext viewContext, string area)
{
string currentArea = viewContext.RouteData.Values["Area"]?.ToString();
return string.Equals(currentArea, area, StringComparison.OrdinalIgnoreCase);
}
}
}

View File

@@ -13,7 +13,6 @@ using Microsoft.Extensions.Options;
namespace AMWD.Common.AspNetCore.Security.BasicAuthentication
{
#if NET8_0_OR_GREATER
/// <summary>
/// Implements the <see cref="AuthenticationHandler{TOptions}"/> for Basic Authentication.
/// </summary>
@@ -27,22 +26,6 @@ namespace AMWD.Common.AspNetCore.Security.BasicAuthentication
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public class BasicAuthenticationHandler(IOptionsMonitor<AuthenticationSchemeOptions> options, ILoggerFactory logger, UrlEncoder encoder, IBasicAuthenticationValidator validator)
: AuthenticationHandler<AuthenticationSchemeOptions>(options, logger, encoder)
#else
/// <summary>
/// Implements the <see cref="AuthenticationHandler{TOptions}"/> for Basic Authentication.
/// </summary>
/// <remarks>
/// Initializes a new instance of the <see cref="BasicAuthenticationHandler"/> class.
/// </remarks>
/// <param name="options" > The monitor for the options instance.</param>
/// <param name="logger">The <see cref="ILoggerFactory"/>.</param>
/// <param name="encoder">The <see cref="UrlEncoder"/>.</param>
/// <param name="clock">The <see cref="ISystemClock"/>.</param>
/// <param name="validator">An basic autentication validator implementation.</param>
[System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]
public class BasicAuthenticationHandler(IOptionsMonitor<AuthenticationSchemeOptions> options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock, IBasicAuthenticationValidator validator)
: AuthenticationHandler<AuthenticationSchemeOptions>(options, logger, encoder, clock)
#endif
{
private readonly ILogger _logger = logger.CreateLogger<BasicAuthenticationHandler>();
private readonly IBasicAuthenticationValidator _validator = validator;

View File

@@ -29,27 +29,15 @@ namespace AMWD.Common.AspNetCore.Security.BasicAuthentication
/// <returns>An awaitable task.</returns>
public async Task InvokeAsync(HttpContext httpContext)
{
#if NET8_0_OR_GREATER
if (!httpContext.Request.Headers.TryGetValue("Authorization", out var authHeaderValue))
{
SetAuthenticateRequest(httpContext, _validator.Realm);
return;
}
#else
if (!httpContext.Request.Headers.ContainsKey("Authorization"))
{
SetAuthenticateRequest(httpContext, _validator.Realm);
return;
}
#endif
try
{
#if NET8_0_OR_GREATER
var authHeader = AuthenticationHeaderValue.Parse(authHeaderValue);
#else
var authHeader = AuthenticationHeaderValue.Parse(httpContext.Request.Headers["Authorization"]);
#endif
byte[] decoded = Convert.FromBase64String(authHeader.Parameter);
string plain = Encoding.UTF8.GetString(decoded);

View File

@@ -15,11 +15,7 @@ namespace Microsoft.EntityFrameworkCore
/// <summary>
/// Extensions for the <see cref="DatabaseFacade"/>.
/// </summary>
#if NET8_0_OR_GREATER
public static partial class DatabaseFacadeExtensions
#else
public static class DatabaseFacadeExtensions
#endif
{
/// <summary>
/// Applies migration files to the database.
@@ -31,8 +27,7 @@ namespace Microsoft.EntityFrameworkCore
[System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "CA2208")]
public static async Task<bool> ApplyMigrationsAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction, CancellationToken cancellationToken = default)
{
if (database == null)
throw new ArgumentNullException(nameof(database));
ArgumentNullException.ThrowIfNull(database);
if (database.GetProviderType() == DatabaseProvider.InMemory)
return true;
@@ -77,8 +72,7 @@ namespace Microsoft.EntityFrameworkCore
/// <returns>An awaitable task to wait until the database is available.</returns>
public static async Task WaitAvailableAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction = null, CancellationToken cancellationToken = default)
{
if (database == null)
throw new ArgumentNullException(nameof(database));
ArgumentNullException.ThrowIfNull(database);
if (database.GetProviderType() == DatabaseProvider.InMemory)
return;
@@ -129,15 +123,20 @@ namespace Microsoft.EntityFrameworkCore
{
if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.MySQL;
if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.Oracle;
if (provider.Contains("npgsql", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.PostgreSQL;
if (provider.Contains("sqlite", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.SQLite;
if (provider.Contains("sqlclient", StringComparison.OrdinalIgnoreCase)
|| provider.Contains("sqlserver", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.SQLServer;
if (provider.Contains("inmemory", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.InMemory;
@@ -274,11 +273,7 @@ END;"
{
using var stream = options.SourceAssembly.GetManifestResourceStream(migrationFile);
using var sr = new StreamReader(stream);
#if NET8_0_OR_GREATER
sqlScript = await sr.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
#else
sqlScript = await sr.ReadToEndAsync().ConfigureAwait(false);
#endif
}
if (string.IsNullOrWhiteSpace(sqlScript))
@@ -324,11 +319,7 @@ END;"
{
int affectedRows = 0;
// Split script by a single slash in a line
#if NET8_0_OR_GREATER
string[] parts = FindSingleSlashInLine().Split(text);
#else
string[] parts = Regex.Split(text, @"\r?\n[ \t]*/[ \t]*\r?\n");
#endif
foreach (string part in parts)
{
// Make writable copy
@@ -337,11 +328,7 @@ END;"
// Remove the trailing semicolon from commands where they're not supported
// (Oracle doesn't like semicolons. To keep the semicolon, it must be directly
// preceeded by "end".)
#if NET8_0_OR_GREATER
pt = FindEndCommand().Replace(pt.TrimEnd(), "");
#else
pt = Regex.Replace(pt, @"(?<!end);$", "", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);
#endif
// Execute all non-empty parts as individual commands
if (!string.IsNullOrWhiteSpace(pt))
@@ -369,12 +356,10 @@ END;"
InMemory = 6,
}
#if NET8_0_OR_GREATER
[GeneratedRegex(@"\r?\n[ \t]*/[ \t]*\r?\n")]
private static partial Regex FindSingleSlashInLine();
[GeneratedRegex(@"(?<!end);$", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant)]
private static partial Regex FindEndCommand();
#endif
}
}

View File

@@ -33,15 +33,8 @@ namespace Microsoft.EntityFrameworkCore
/// <returns>The <see cref="DbContextOptionsBuilder"/> with applied settings.</returns>
public static DbContextOptionsBuilder UseDatabaseProvider(this DbContextOptionsBuilder optionsBuilder, IConfiguration configuration, Action<DatabaseProviderOptions> optionsAction = null)
{
#if NET8_0_OR_GREATER
ArgumentNullException.ThrowIfNull(optionsBuilder);
ArgumentNullException.ThrowIfNull(configuration);
#else
if (optionsBuilder == null)
throw new ArgumentNullException(nameof(optionsBuilder));
if (configuration == null)
throw new ArgumentNullException(nameof(configuration));
#endif
var options = new DatabaseProviderOptions();
optionsAction?.Invoke(options);
@@ -49,43 +42,47 @@ namespace Microsoft.EntityFrameworkCore
string connectionString = GetConnectionString(configuration, options);
string provider = configuration.GetValue<string>("provider")?.ToLower();
var builderType = GetBuilderType(configuration);
var extensionType = GetExtensionType(configuration);
var builderType = GetBuilderType(configuration)
?? throw new DatabaseProviderException($"Could not find the DbContextOptionsBuilder for provider: {provider}");
var extensionType = GetExtensionType(configuration)
?? throw new DatabaseProviderException($"Could not find the DbContextOptionsBuilder extensions for provider: {provider}");
var actionType = typeof(Action<>).MakeGenericType(builderType);
object serverVersion = null;
MethodInfo methodInfo;
object serverVersion = null;
switch (provider)
{
case "mysql":
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
methodInfo = extensionType.GetMethod("UseMySql", [typeof(DbContextOptionsBuilder), typeof(string), actionType]);
if (methodInfo == null)
methodInfo = extensionType.GetMethod("UseMySQL", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
methodInfo = extensionType.GetMethod("UseMySQL", [typeof(DbContextOptionsBuilder), typeof(string), actionType]);
if (methodInfo == null) // Pomelo MySQL v5
{
var serverVersionType = Type.GetType("Microsoft.EntityFrameworkCore.ServerVersion, Pomelo.EntityFrameworkCore.MySql");
var autoDetectMethodInfo = serverVersionType.GetMethod("AutoDetect", new Type[] { typeof(string) });
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType });
serverVersion = autoDetectMethodInfo.Invoke(null, new object[] { connectionString });
var autoDetectMethodInfo = serverVersionType.GetMethod("AutoDetect", [typeof(string)]);
methodInfo = extensionType.GetMethod("UseMySql", [typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType]);
serverVersion = autoDetectMethodInfo.Invoke(null, [connectionString]);
}
break;
case "oracle":
methodInfo = extensionType.GetMethod("UseOracle", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
methodInfo = extensionType.GetMethod("UseOracle", [typeof(DbContextOptionsBuilder), typeof(string), actionType]);
break;
case "postgres":
case "postgresql":
methodInfo = extensionType.GetMethod("UseNpgsql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
methodInfo = extensionType.GetMethod("UseNpgsql", [typeof(DbContextOptionsBuilder), typeof(string), actionType]);
break;
case "sqlite":
methodInfo = extensionType.GetMethod("UseSqlite", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
methodInfo = extensionType.GetMethod("UseSqlite", [typeof(DbContextOptionsBuilder), typeof(string), actionType]);
break;
case "sqlserver":
case "mssql":
methodInfo = extensionType.GetMethod("UseSqlServer", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
methodInfo = extensionType.GetMethod("UseSqlServer", [typeof(DbContextOptionsBuilder), typeof(string), actionType]);
break;
case "memory":
case "inmemory":
methodInfo = extensionType.GetMethod("UseInMemoryDatabase", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
methodInfo = extensionType.GetMethod("UseInMemoryDatabase", [typeof(DbContextOptionsBuilder), typeof(string), actionType]);
break;
default:
throw new DatabaseProviderException($"Unknown database provider: {provider}");
@@ -93,11 +90,11 @@ namespace Microsoft.EntityFrameworkCore
if (serverVersion == null)
{
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, null });
methodInfo?.Invoke(null, [optionsBuilder, connectionString, null]);
}
else
{
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, serverVersion, null });
methodInfo?.Invoke(null, [optionsBuilder, connectionString, serverVersion, null]);
}
return optionsBuilder;
@@ -105,8 +102,8 @@ namespace Microsoft.EntityFrameworkCore
private static Type GetBuilderType(IConfiguration configuration)
{
string provider = configuration.GetValue<string>("provider")?.ToLower();
Type builderType;
string provider = configuration.GetValue<string>("provider")?.ToLower();
switch (provider)
{
case "mysql":
@@ -116,34 +113,41 @@ namespace Microsoft.EntityFrameworkCore
if (builderType == null) // as MySql.Data.EntityFrameworkCore is marked as deprecated on NuGet
builderType = Type.GetType("MySql.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.EntityFrameworkCore");
break;
case "oracle":
builderType = Type.GetType("Oracle.EntityFrameworkCore.Infrastructure.OracleDbContextOptionsBuilder, Oracle.EntityFrameworkCore");
break;
case "postgres":
case "postgresql":
builderType = Type.GetType("Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.NpgsqlDbContextOptionsBuilder, Npgsql.EntityFrameworkCore.PostgreSQL");
break;
case "sqlite":
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqliteDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.Sqlite");
break;
case "sqlserver":
case "mssql":
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqlServerDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.SqlServer");
break;
case "memory":
case "inmemory":
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.InMemoryDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.InMemory");
break;
default:
throw new ArgumentException($"Unknown database provider: {provider}");
}
return builderType;
}
private static Type GetExtensionType(IConfiguration configuration)
{
string provider = configuration.GetValue<string>("provider")?.ToLower();
Type extensionType;
string provider = configuration.GetValue<string>("provider")?.ToLower();
switch (provider)
{
case "mysql":
@@ -153,27 +157,34 @@ namespace Microsoft.EntityFrameworkCore
if (extensionType == null)
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.EntityFrameworkCore");
break;
case "oracle":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.OracleDbContextOptionsBuilderExtensions, Oracle.EntityFrameworkCore");
break;
case "postgres":
case "postgresql":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.NpgsqlDbContextOptionsBuilderExtensions, Npgsql.EntityFrameworkCore.PostgreSQL");
break;
case "sqlite":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqliteDbContextOptionsBuilderExtensions, Microsoft.EntityFrameworkCore.Sqlite");
break;
case "sqlserver":
case "mssql":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqlServerDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.SqlServer");
break;
case "memory":
case "inmemory":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.InMemoryDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.InMemory");
break;
default:
throw new ArgumentException($"Unknown database provider: {provider}");
}
return extensionType;
}
@@ -191,12 +202,14 @@ namespace Microsoft.EntityFrameworkCore
cs.Add($"Password={configuration.GetValue<string>("Password")}");
cs.Add($"Connection Timeout=15");
break;
case "oracle":
cs.Add($"Data Source=(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={configuration.GetValue<string>("Host")})(PORT={configuration.GetValue("Port", 1521)}))(CONNECT_DATA=(SERVICE_NAME={configuration.GetValue<string>("Name")})))");
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
cs.Add($"Password={configuration.GetValue<string>("Password")}");
cs.Add($"Connection Timeout=15");
break;
case "postgres":
case "postgresql":
cs.Add($"Server={configuration.GetValue<string>("Host")}");
@@ -207,6 +220,7 @@ namespace Microsoft.EntityFrameworkCore
cs.Add($"Password={configuration.GetValue<string>("Password")}");
cs.Add($"Timeout=15");
break;
case "sqlite":
string path = configuration.GetValue<string>("File");
if (!Path.IsPathRooted(path))
@@ -219,6 +233,7 @@ namespace Microsoft.EntityFrameworkCore
cs.Add($"Data Source={path}");
cs.Add("Foreign Keys=True");
break;
case "sqlserver":
case "mssql":
cs.Add($"Server={configuration.GetValue<string>("Host")},{configuration.GetValue("Port", 1433)}");
@@ -235,13 +250,16 @@ namespace Microsoft.EntityFrameworkCore
}
cs.Add("Connect Timeout=15");
break;
case "memory":
case "inmemory":
cs.Add(configuration.GetValue("Name", provider));
break;
default:
throw new DatabaseProviderException($"Unknown database provider: {provider}");
}
return string.Join(";", cs);
}
}

View File

@@ -12,7 +12,6 @@ namespace AMWD.Common.Comparer
{
private readonly Regex _versionRegex = VersionRegex();
#else
public class VersionStringComparer : IComparer<string>
{
private readonly Regex _versionRegex = new("([0-9.]+)", RegexOptions.Compiled);