using System;
using System.Net.Http.Headers;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace AMWD.Common.AspNetCore.BasicAuthentication
{
	/// <summary>
	/// A basic authentication as attribute to use for specific actions.
	/// </summary>
	/// <remarks>
	/// A request is accepted when its <c>Authorization: Basic …</c> credentials either match
	/// <see cref="Username"/> and <see cref="Password"/> (only when both are set), or are accepted by the
	/// validator resolved from a new service scope. Missing, malformed or non-Basic credentials are
	/// answered with a <c>401</c> challenge; only failures while running the validator produce a <c>500</c>.
	/// </remarks>
	public class BasicAuthenticationAttribute : ActionFilterAttribute
	{
		private const string AuthorizationHeaderName = "Authorization";
		private const string AuthenticateHeaderName = "WWW-Authenticate";
		private const string BasicScheme = "Basic";

		private readonly ILogger logger;
		private readonly IServiceScopeFactory serviceScopeFactory;

		/// <summary>
		/// Initializes a new instance of the <see cref="BasicAuthenticationAttribute"/> class.
		/// </summary>
		/// <param name="logger">A logger.</param>
		/// <param name="serviceScopeFactory">A service scope factory.</param>
		public BasicAuthenticationAttribute(ILogger logger, IServiceScopeFactory serviceScopeFactory)
		{
			this.logger = logger;
			this.serviceScopeFactory = serviceScopeFactory;
		}

		/// <summary>
		/// Gets or sets a username to validate.
		/// </summary>
		public string Username { get; set; }

		/// <summary>
		/// Gets or sets a password to validate.
		/// </summary>
		public string Password { get; set; }

		/// <summary>
		/// Gets or sets a realm used on authentication header.
		/// </summary>
		public string Realm { get; set; }

		/// <inheritdoc/>
		public override async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next)
		{
			await DoValidation(context);

			// The base implementation only invokes the action when context.Result is still null.
			await base.OnActionExecutionAsync(context, next);
		}

		/// <summary>
		/// Validates the Basic credentials of the current request and sets
		/// <see cref="ActionExecutingContext.Result"/> when the request has to be rejected.
		/// </summary>
		/// <param name="context">The action context of the current request.</param>
		private async Task DoValidation(ActionExecutingContext context)
		{
			// Another filter already short-circuited the pipeline.
			if (context.Result != null)
			{
				return;
			}

			// Missing, malformed or non-Basic credentials are a client error: challenge with 401, not 500.
			if (!context.HttpContext.Request.Headers.TryGetValue(AuthorizationHeaderName, out var headerValues)
				|| !TryReadCredentials(headerValues.ToString(), out string username, out string password))
			{
				SetAuthenticateRequest(context);
				return;
			}

			try
			{
				// Non-short-circuit '&' so the comparison time does not reveal whether the username matched.
				if (!string.IsNullOrWhiteSpace(Username) && !string.IsNullOrWhiteSpace(Password)
					&& (FixedTimeEquals(Username, username) & FixedTimeEquals(Password, password)))
				{
					return;
				}

				using var scope = serviceScopeFactory.CreateScope();
				var validator = scope.ServiceProvider.GetService<IBasicAuthenticationValidator>();
				if (validator == null)
				{
					// No validator registered: nothing else can accept these credentials.
					SetAuthenticateRequest(context);
					return;
				}

				var principal = await validator.ValidateAsync(username, password, context.HttpContext.GetRemoteIpAddress());
				if (principal == null)
				{
					SetAuthenticateRequest(context);
				}
			}
			catch (Exception ex)
			{
				logger.LogError(ex, "Failed to execute the basic authentication attribute: {Message}", ex.Message);
				context.Result = new StatusCodeResult(StatusCodes.Status500InternalServerError);
			}
		}

		/// <summary>
		/// Extracts user name and password from an <c>Authorization</c> header value that uses the Basic scheme.
		/// </summary>
		/// <param name="headerValue">The raw header value, e.g. <c>Basic dXNlcjpwYXNz</c>.</param>
		/// <param name="username">The decoded user name (everything before the first colon).</param>
		/// <param name="password">The decoded password (everything after the first colon; may itself contain colons).</param>
		/// <returns><see langword="true"/> when the header holds well-formed Basic credentials; otherwise <see langword="false"/>.</returns>
		private static bool TryReadCredentials(string headerValue, out string username, out string password)
		{
			username = null;
			password = null;

			if (!AuthenticationHeaderValue.TryParse(headerValue, out var authHeader)
				|| !string.Equals(authHeader.Scheme, BasicScheme, StringComparison.OrdinalIgnoreCase)
				|| string.IsNullOrEmpty(authHeader.Parameter))
			{
				return false;
			}

			string plain;
			try
			{
				plain = Encoding.UTF8.GetString(Convert.FromBase64String(authHeader.Parameter));
			}
			catch (FormatException)
			{
				return false;
			}

			// Split on the first colon only: the user-id cannot contain a colon, the password can.
			int separator = plain.IndexOf(':');
			if (separator < 0)
			{
				return false;
			}

			username = plain.Substring(0, separator);
			password = plain.Substring(separator + 1);
			return true;
		}

		/// <summary>
		/// Compares two strings ordinally (by their UTF-16 code units) in time that does not depend on
		/// where they first differ; only a length mismatch returns early.
		/// </summary>
		/// <param name="expected">The configured secret.</param>
		/// <param name="actual">The value supplied by the client.</param>
		/// <returns><see langword="true"/> when both strings are ordinally equal.</returns>
		private static bool FixedTimeEquals(string expected, string actual)
		{
			return CryptographicOperations.FixedTimeEquals(
				MemoryMarshal.AsBytes(expected.AsSpan()),
				MemoryMarshal.AsBytes(actual.AsSpan()));
		}

		/// <summary>
		/// Rejects the request with <c>401</c> and a <c>WWW-Authenticate: Basic</c> challenge,
		/// adding <see cref="Realm"/> (with any double quotes removed) when it is set.
		/// </summary>
		/// <param name="context">The action context of the current request.</param>
		private void SetAuthenticateRequest(ActionExecutingContext context)
		{
			string challenge = BasicScheme;
			if (!string.IsNullOrWhiteSpace(Realm))
			{
				challenge += $" realm=\"{Realm.Replace("\"", "")}\"";
			}

			context.HttpContext.Response.Headers[AuthenticateHeaderName] = challenge;
			context.HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
			context.Result = new UnauthorizedResult();
		}
	}
}