1
0

Solution restructured to use multiple test projects

This commit is contained in:
2024-07-04 18:22:26 +02:00
parent 508379d704
commit df6763b99b
144 changed files with 387 additions and 1693 deletions

View File

@@ -0,0 +1,380 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.Extensions.Logging;
namespace Microsoft.EntityFrameworkCore
{
/// <summary>
/// Extensions for the <see cref="DatabaseFacade"/>.
/// </summary>
#if NET8_0_OR_GREATER
public static partial class DatabaseFacadeExtensions
#else
public static class DatabaseFacadeExtensions
#endif
{
/// <summary>
/// Applies migration files to the database.
/// </summary>
/// <param name="database">The database connection.</param>
/// <param name="optionsAction">An action to set additional options.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns><see langword="true"/> on success, otherwise false or an exception is thrown.</returns>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "CA2208")]
public static async Task<bool> ApplyMigrationsAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction, CancellationToken cancellationToken = default)
{
if (database == null)
throw new ArgumentNullException(nameof(database));
if (database.GetProviderType() == DatabaseProvider.InMemory)
return true;
var options = new DatabaseMigrationOptions();
optionsAction?.Invoke(options);
if (string.IsNullOrWhiteSpace(options.MigrationsTableName))
throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required.");
if (string.IsNullOrWhiteSpace(options.Path))
throw new ArgumentNullException(nameof(options.Path), $"The property {nameof(options.Path)} of the {nameof(options)} parameter is required.");
await database.WaitAvailableAsync(opts =>
{
opts.WaitDelay = options.WaitDelay;
opts.Logger = options.Logger;
}, cancellationToken).ConfigureAwait(false);
var connection = database.GetDbConnection();
try
{
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
if (!await connection.CreateMigrationsTable(options, cancellationToken).ConfigureAwait(false))
return false;
return await connection.Migrate(options, cancellationToken).ConfigureAwait(false);
}
finally
{
await connection.CloseAsync().ConfigureAwait(false);
}
}
/// <summary>
/// Waits until the database connection is available.
/// </summary>
/// <param name="database">The database connection.</param>
/// <param name="optionsAction">An action to set additional options.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>An awaitable task to wait until the database is available.</returns>
public static async Task WaitAvailableAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction = null, CancellationToken cancellationToken = default)
{
if (database == null)
throw new ArgumentNullException(nameof(database));
if (database.GetProviderType() == DatabaseProvider.InMemory)
return;
var options = new DatabaseMigrationOptions();
optionsAction?.Invoke(options);
options.Logger?.LogInformation("Waiting for a database connection");
var connection = database.GetDbConnection();
while (!cancellationToken.IsCancellationRequested)
{
try
{
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
options.Logger?.LogInformation("Database connection available");
return;
}
catch
{
// keep things quiet
try
{
await Task.Delay(options.WaitDelay, cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
{
return;
}
catch
{
// keep things quiet
}
}
finally
{
await connection.CloseAsync().ConfigureAwait(false);
}
}
}
internal static DatabaseProvider GetProviderType(this DatabaseFacade database)
=> GetProviderType(database.ProviderName);
private static DatabaseProvider GetProviderType(this DbConnection connection)
=> GetProviderType(connection.GetType().FullName);
private static DatabaseProvider GetProviderType(string provider)
{
if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.MySQL;
if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.Oracle;
if (provider.Contains("npgsql", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.PostgreSQL;
if (provider.Contains("sqlite", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.SQLite;
if (provider.Contains("sqlclient", StringComparison.OrdinalIgnoreCase)
|| provider.Contains("sqlserver", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.SQLServer;
if (provider.Contains("inmemory", StringComparison.OrdinalIgnoreCase))
return DatabaseProvider.InMemory;
throw new DatabaseProviderException($"The database provider '{provider}' is unknown");
}
private static async Task<bool> CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{
try
{
using var command = connection.CreateCommand();
#pragma warning disable CS8509 // ignore missing cases
command.CommandText = connection.GetProviderType() switch
#pragma warning restore CS8509 // ignore missing cases
{
DatabaseProvider.MySQL => $@"CREATE TABLE IF NOT EXISTS `{options.MigrationsTableName}` (
`id` INT NOT NULL AUTO_INCREMENT,
`schema_file` VARCHAR(250) NOT NULL,
`installed_at` VARCHAR(16) NOT NULL,
PRIMARY KEY (`id`)
);",
DatabaseProvider.Oracle => $@"DECLARE ncount NUMBER;
BEGIN
SELECT count(*) INTO ncount FROM dba_tables WHERE table_name = '{options.MigrationsTableName}';
IF (ncount <= 0)
THEN
EXECUTE IMMEDIATE 'CREATE TABLE ""{options.MigrationsTableName}"" (
""id"" NUMBER GENERATED by default on null as IDENTITY,
""schema_file"" VARCHAR2(250) NOT NULL,
""installed_at"" VARCHAR2(16) NOT NULL,
PRIMARY KEY (""id""),
CONSTRAINT uq_schema_file UNIQUE (""schema_file"")
)';
END IF;
END;",
DatabaseProvider.PostgreSQL => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
""id"" SERIAL4 PRIMARY KEY,
""schema_file"" VARCHAR(250) NOT NULL,
""installed_at"" VARCHAR(16) NOT NULL,
CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"")
);",
DatabaseProvider.SQLite => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
""id"" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
""schema_file"" TEXT(250) NOT NULL,
""installed_at"" TEXT(16) NOT NULL,
CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"")
);",
DatabaseProvider.SQLServer => $@"IF NOT EXISTS (SELECT * FROM [sysobjects] WHERE [name] = '{options.MigrationsTableName}' AND [xtype] = 'U')
BEGIN
CREATE TABLE [{options.MigrationsTableName}] (
[id] int IDENTITY(1,1) NOT NULL PRIMARY KEY,
[schema_file] varchar(250) NOT NULL,
[installed_at] varchar(16) NOT NULL,
CONSTRAINT uq_schema_file UNIQUE (schema_file)
)
END;"
};
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
return true;
}
catch (Exception ex)
{
options.Logger?.LogCritical(ex, $"Creating migrations table '{options.MigrationsTableName}' failed: {ex.InnerException?.Message ?? ex.Message}");
return false;
}
}
private static async Task<bool> Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{
try
{
List<string> availableMigrationFiles;
if (options.SourceAssembly == null)
{
availableMigrationFiles = Directory.GetFiles(options.Path)
.Where(f => f.StartsWith(options.Path, StringComparison.OrdinalIgnoreCase))
.Where(f => f.EndsWith(".sql", StringComparison.OrdinalIgnoreCase))
.ToList();
}
else
{
availableMigrationFiles = options.SourceAssembly
.GetManifestResourceNames()
.Where(f => f.StartsWith(options.Path, StringComparison.OrdinalIgnoreCase))
.Where(f => f.EndsWith(".sql", StringComparison.OrdinalIgnoreCase))
.ToList();
}
if (availableMigrationFiles.Count == 0)
return true;
using var command = connection.CreateCommand();
var migratedFiles = new List<string>();
command.CommandText = connection.GetProviderType() switch
{
DatabaseProvider.MySQL => $"SELECT `schema_file` FROM `{options.MigrationsTableName}`;",
DatabaseProvider.SQLServer => $"SELECT [schema_file] FROM [{options.MigrationsTableName}];",
_ => $@"SELECT ""schema_file"" FROM ""{options.MigrationsTableName}"";",
};
using (var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false))
{
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
migratedFiles.Add(reader.GetString(0));
}
int pathLength = options.Path.Length + 1;
foreach (string migrationFile in availableMigrationFiles)
{
// remove path including the separator
string fileName = migrationFile.Replace(options.Path, "")[1..];
using var transaction = await connection.BeginTransactionAsync(cancellationToken).ConfigureAwait(false);
try
{
// max length in the database: 250 chars
string trimmedFileName = fileName;
if (trimmedFileName.Length > 250)
fileName = fileName[..250];
if (migratedFiles.Contains(trimmedFileName))
{
options.Logger?.LogDebug($" Migrating file '{fileName}' done");
continue;
}
string sqlScript = null;
if (options.SourceAssembly == null)
{
sqlScript = await File.ReadAllTextAsync(migrationFile, cancellationToken).ConfigureAwait(false);
}
else
{
using var stream = options.SourceAssembly.GetManifestResourceStream(migrationFile);
using var sr = new StreamReader(stream);
#if NET8_0_OR_GREATER
sqlScript = await sr.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
#else
sqlScript = await sr.ReadToEndAsync().ConfigureAwait(false);
#endif
}
if (string.IsNullOrWhiteSpace(sqlScript))
continue;
options.Logger?.LogDebug($" Migrating file '{fileName}' started");
command.Transaction = transaction;
await command.ExecuteScript(sqlScript, cancellationToken).ConfigureAwait(false);
command.CommandText = connection.GetProviderType() switch
{
DatabaseProvider.MySQL => $"INSERT INTO `{options.MigrationsTableName}` (`schema_file`, `installed_at`) VALUES ('{trimmedFileName.Replace("'", "\\'")}', '{DateTime.UtcNow:yyyy-MM-dd HH:mm}');",
DatabaseProvider.SQLServer => $"INSERT INTO [{options.MigrationsTableName}] ([schema_file], [installed_at]) VALUES ('{trimmedFileName.Replace("'", "\\'")}', '{DateTime.UtcNow:yyyy-MM-dd HH:mm}');",
_ => $@"INSERT INTO ""{options.MigrationsTableName}"" (""schema_file"", ""installed_at"") VALUES ('{trimmedFileName.Replace("'", "\\'")}', '{DateTime.UtcNow:yyyy-MM-dd HH:mm}');",
};
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
await transaction.CommitAsync(cancellationToken).ConfigureAwait(false);
command.Transaction = null;
options.Logger?.LogDebug($" Migrating file '{fileName}' successful");
}
catch (Exception ex)
{
await transaction.RollbackAsync(cancellationToken).ConfigureAwait(false);
options.Logger?.LogError($"Migrating file '{fileName}' failed: {ex.InnerException?.Message ?? ex.Message}");
return false;
}
}
return true;
}
catch (Exception ex)
{
options.Logger?.LogCritical(ex, $"Migrating the database failed ({ex.GetType().Name}): {ex.InnerException?.Message ?? ex.Message}");
return false;
}
}
private static async Task<int> ExecuteScript(this DbCommand command, string text, CancellationToken cancellationToken)
{
if (command.Connection.GetProviderType() == DatabaseProvider.Oracle)
{
int affectedRows = 0;
// Split script by a single slash in a line
#if NET8_0_OR_GREATER
string[] parts = FindSingleSlashInLine().Split(text);
#else
string[] parts = Regex.Split(text, @"\r?\n[ \t]*/[ \t]*\r?\n");
#endif
foreach (string part in parts)
{
// Make writable copy
string pt = part;
// Remove the trailing semicolon from commands where they're not supported
// (Oracle doesn't like semicolons. To keep the semicolon, it must be directly
// preceeded by "end".)
#if NET8_0_OR_GREATER
pt = FindEndCommand().Replace(pt.TrimEnd(), "");
#else
pt = Regex.Replace(pt, @"(?<!end);$", "", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);
#endif
// Execute all non-empty parts as individual commands
if (!string.IsNullOrWhiteSpace(pt))
{
command.CommandText = pt;
affectedRows += await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}
return affectedRows;
}
else
{
command.CommandText = text;
return await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
}
internal enum DatabaseProvider
{
MySQL = 1,
Oracle = 2,
PostgreSQL = 3,
SQLite = 4,
SQLServer = 5,
InMemory = 6,
}
#if NET8_0_OR_GREATER
[GeneratedRegex(@"\r?\n[ \t]*/[ \t]*\r?\n")]
private static partial Regex FindSingleSlashInLine();
[GeneratedRegex(@"(?<!end);$", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant)]
private static partial Regex FindEndCommand();
#endif
}
}

View File

@@ -0,0 +1,91 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Storage;
namespace Microsoft.EntityFrameworkCore
{
/// <summary>
/// Extensions for the <see cref="DbContext"/>.
/// </summary>
public static class DbContextExensions
{
/// <summary>
/// Starts a new transaction.
/// </summary>
/// <remarks>
/// See <see href="https://aka.ms/efcore-docs-transactions">Transactions in EF Core</see> for more information.
/// </remarks>
/// <param name="dbContext">The current <see cref="DbContext"/>.</param>
/// <returns>
/// A <see cref="IDbContextTransaction" /> that represents the started transaction.
/// </returns>
public static IDbContextTransaction BeginTransaction(this DbContext dbContext)
{
if (dbContext.Database.GetProviderType() == DatabaseFacadeExtensions.DatabaseProvider.InMemory)
return new DbContextTransactionStub();
return dbContext.Database.BeginTransaction();
}
/// <summary>
/// Asynchronously starts a new transaction.
/// </summary>
/// <remarks>
/// <para>
/// Entity Framework Core does not support multiple parallel operations being run on the same DbContext instance. This
/// includes both parallel execution of async queries and any explicit concurrent use from multiple threads.
/// Therefore, always await async calls immediately, or use separate DbContext instances for operations that execute
/// in parallel. See <see href="https://aka.ms/efcore-docs-threading">Avoiding DbContext threading issues</see>
/// for more information.
/// </para>
/// <para>
/// See <see href="https://aka.ms/efcore-docs-transactions">Transactions in EF Core</see> for more information.
/// </para>
/// </remarks>
/// <param name="dbContext">The current <see cref="DbContext"/>.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
/// <returns>
/// A task that represents the asynchronous transaction initialization. The task result contains a <see cref="IDbContextTransaction" /> that represents the started transaction.
/// </returns>
/// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
public static Task<IDbContextTransaction> BeginTransactionAsync(this DbContext dbContext, CancellationToken cancellationToken)
{
if (dbContext.Database.GetProviderType() == DatabaseFacadeExtensions.DatabaseProvider.InMemory)
return Task.FromResult<IDbContextTransaction>(new DbContextTransactionStub());
return dbContext.Database.BeginTransactionAsync(cancellationToken);
}
/// <inheritdoc cref="IDbContextTransaction" />
private class DbContextTransactionStub : IDbContextTransaction
{
/// <inheritdoc />
public Guid TransactionId { get; private set; } = Guid.NewGuid();
/// <inheritdoc />
public void Commit()
{ }
/// <inheritdoc />
public Task CommitAsync(CancellationToken cancellationToken = default)
=> Task.CompletedTask;
/// <inheritdoc />
public void Dispose()
{ }
/// <inheritdoc />
public ValueTask DisposeAsync()
=> new(Task.CompletedTask);
/// <inheritdoc />
public void Rollback()
{ }
/// <inheritdoc />
public Task RollbackAsync(CancellationToken cancellationToken = default)
=> Task.CompletedTask;
}
}
}

View File

@@ -0,0 +1,248 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using Microsoft.Extensions.Configuration;
namespace Microsoft.EntityFrameworkCore
{
/// <summary>
/// Extends the <see cref="DbContextOptionsBuilder"/> to use a configurable database provider.
/// </summary>
public static class DbContextOptionsBuilderExtensions
{
/// <summary>
/// Adds the supported database provider to the context.
/// </summary>
/// <remarks>
/// The configuration provided requires the following entries:
/// <list type="bullet">
/// <item><strong>Provider</strong>: MySQL | Oracle | PostgreSQL | SQLite | SQLServer</item>
/// <item><strong>Host</strong>: hostname or IP address</item>
/// <item><strong>Port</strong>: port number</item>
/// <item><strong>Name</strong>: database name</item>
/// <item><strong>Schema</strong>: schema or search path (e.g. PostgreSQL: public)</item>
/// <item><strong>Username</strong>: username credential on the database</item>
/// <item><strong>Password</strong>: password credential on the database</item>
/// <item><strong>File</strong>: file name / path (for SQLite)</item>
/// </list>
/// </remarks>
/// <param name="optionsBuilder">The options builder.</param>
/// <param name="configuration">The application configuration section for the database.</param>
/// <param name="optionsAction">An optional action to set additional options.</param>
/// <returns>The <see cref="DbContextOptionsBuilder"/> with applied settings.</returns>
public static DbContextOptionsBuilder UseDatabaseProvider(this DbContextOptionsBuilder optionsBuilder, IConfiguration configuration, Action<DatabaseProviderOptions> optionsAction = null)
{
#if NET8_0_OR_GREATER
ArgumentNullException.ThrowIfNull(optionsBuilder);
ArgumentNullException.ThrowIfNull(configuration);
#else
if (optionsBuilder == null)
throw new ArgumentNullException(nameof(optionsBuilder));
if (configuration == null)
throw new ArgumentNullException(nameof(configuration));
#endif
var options = new DatabaseProviderOptions();
optionsAction?.Invoke(options);
string connectionString = GetConnectionString(configuration, options);
string provider = configuration.GetValue<string>("provider")?.ToLower();
var builderType = GetBuilderType(configuration);
var extensionType = GetExtensionType(configuration);
var actionType = typeof(Action<>).MakeGenericType(builderType);
object serverVersion = null;
MethodInfo methodInfo;
switch (provider)
{
case "mysql":
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
if (methodInfo == null)
methodInfo = extensionType.GetMethod("UseMySQL", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
if (methodInfo == null) // Pomelo MySQL v5
{
var serverVersionType = Type.GetType("Microsoft.EntityFrameworkCore.ServerVersion, Pomelo.EntityFrameworkCore.MySql");
var autoDetectMethodInfo = serverVersionType.GetMethod("AutoDetect", new Type[] { typeof(string) });
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType });
serverVersion = autoDetectMethodInfo.Invoke(null, new object[] { connectionString });
}
break;
case "oracle":
methodInfo = extensionType.GetMethod("UseOracle", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
break;
case "postgres":
case "postgresql":
methodInfo = extensionType.GetMethod("UseNpgsql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
break;
case "sqlite":
methodInfo = extensionType.GetMethod("UseSqlite", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
break;
case "sqlserver":
case "mssql":
methodInfo = extensionType.GetMethod("UseSqlServer", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
break;
case "memory":
case "inmemory":
methodInfo = extensionType.GetMethod("UseInMemoryDatabase", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
break;
default:
throw new DatabaseProviderException($"Unknown database provider: {provider}");
}
if (serverVersion == null)
{
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, null });
}
else
{
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, serverVersion, null });
}
return optionsBuilder;
}
private static Type GetBuilderType(IConfiguration configuration)
{
string provider = configuration.GetValue<string>("provider")?.ToLower();
Type builderType;
switch (provider)
{
case "mysql":
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.MySqlDbContextOptionsBuilder, Pomelo.EntityFrameworkCore.MySql");
if (builderType == null)
builderType = Type.GetType("MySql.Data.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.Data.EntityFrameworkCore");
if (builderType == null) // as MySql.Data.EntityFrameworkCore is marked as deprecated on NuGet
builderType = Type.GetType("MySql.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.EntityFrameworkCore");
break;
case "oracle":
builderType = Type.GetType("Oracle.EntityFrameworkCore.Infrastructure.OracleDbContextOptionsBuilder, Oracle.EntityFrameworkCore");
break;
case "postgres":
case "postgresql":
builderType = Type.GetType("Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.NpgsqlDbContextOptionsBuilder, Npgsql.EntityFrameworkCore.PostgreSQL");
break;
case "sqlite":
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqliteDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.Sqlite");
break;
case "sqlserver":
case "mssql":
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqlServerDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.SqlServer");
break;
case "memory":
case "inmemory":
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.InMemoryDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.InMemory");
break;
default:
throw new ArgumentException($"Unknown database provider: {provider}");
}
return builderType;
}
private static Type GetExtensionType(IConfiguration configuration)
{
string provider = configuration.GetValue<string>("provider")?.ToLower();
Type extensionType;
switch (provider)
{
case "mysql":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySqlDbContextOptionsBuilderExtensions, Pomelo.EntityFrameworkCore.MySql");
if (extensionType == null)
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.Data.EntityFrameworkCore");
if (extensionType == null)
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.EntityFrameworkCore");
break;
case "oracle":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.OracleDbContextOptionsBuilderExtensions, Oracle.EntityFrameworkCore");
break;
case "postgres":
case "postgresql":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.NpgsqlDbContextOptionsBuilderExtensions, Npgsql.EntityFrameworkCore.PostgreSQL");
break;
case "sqlite":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqliteDbContextOptionsBuilderExtensions, Microsoft.EntityFrameworkCore.Sqlite");
break;
case "sqlserver":
case "mssql":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqlServerDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.SqlServer");
break;
case "memory":
case "inmemory":
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.InMemoryDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.InMemory");
break;
default:
throw new ArgumentException($"Unknown database provider: {provider}");
}
return extensionType;
}
private static string GetConnectionString(IConfiguration configuration, DatabaseProviderOptions options)
{
var cs = new List<string>();
string provider = configuration.GetValue<string>("provider")?.ToLower();
switch (provider)
{
case "mysql":
cs.Add($"Server={configuration.GetValue<string>("Host")}");
cs.Add($"Port={configuration.GetValue("Port", 3306)}");
cs.Add($"Database={configuration.GetValue<string>("Name")}");
cs.Add($"Uid={configuration.GetValue<string>("Username")}");
cs.Add($"Password={configuration.GetValue<string>("Password")}");
cs.Add($"Connection Timeout=15");
break;
case "oracle":
cs.Add($"Data Source=(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={configuration.GetValue<string>("Host")})(PORT={configuration.GetValue("Port", 1521)}))(CONNECT_DATA=(SERVICE_NAME={configuration.GetValue<string>("Name")})))");
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
cs.Add($"Password={configuration.GetValue<string>("Password")}");
cs.Add($"Connection Timeout=15");
break;
case "postgres":
case "postgresql":
cs.Add($"Server={configuration.GetValue<string>("Host")}");
cs.Add($"Port={configuration.GetValue("Port", 5432)}");
cs.Add($"Database={configuration.GetValue<string>("Name")}");
cs.Add($"Search Path={configuration.GetValue("Schema", "public")}");
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
cs.Add($"Password={configuration.GetValue<string>("Password")}");
cs.Add($"Timeout=15");
break;
case "sqlite":
string path = configuration.GetValue<string>("File");
if (!Path.IsPathRooted(path))
{
if (string.IsNullOrWhiteSpace(options.AbsoluteBasePath))
options.AbsoluteBasePath = AppContext.BaseDirectory;
path = Path.Combine(options.AbsoluteBasePath, path);
}
cs.Add($"Data Source={path}");
cs.Add("Foreign Keys=True");
break;
case "sqlserver":
case "mssql":
cs.Add($"Server={configuration.GetValue<string>("Host")},{configuration.GetValue("Port", 1433)}");
cs.Add($"Database={configuration.GetValue<string>("Name")}");
if (!string.IsNullOrWhiteSpace(configuration.GetValue<string>("Username")))
{
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
cs.Add($"Password={configuration.GetValue<string>("Password")}");
cs.Add("Integrated Security=False");
}
else
{
cs.Add("Integrated Security=True");
}
cs.Add("Connect Timeout=15");
break;
case "memory":
case "inmemory":
cs.Add(configuration.GetValue("Name", provider));
break;
default:
throw new DatabaseProviderException($"Unknown database provider: {provider}");
}
return string.Join(";", cs);
}
}
}

View File

@@ -0,0 +1,140 @@
using System.ComponentModel.DataAnnotations.Schema;
using System.Reflection;
using System.Text;
using AMWD.Common.EntityFrameworkCore.Attributes;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
namespace AMWD.Common.EntityFrameworkCore.Extensions
{
/// <summary>
/// Extensions for the <see cref="ModelBuilder"/> of entity framework core.
/// </summary>
public static class ModelBuilderExtensions
{
/// <summary>
/// Applies indices and unique constraints to the properties.
/// </summary>
/// <param name="builder">The database model builder.</param>
/// <returns>A reference to this instance after the operation has completed.</returns>
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0019", Justification = "No pattern comparison in this case due to readability.")]
public static ModelBuilder ApplyIndexAttributes(this ModelBuilder builder)
{
foreach (var entityType in builder.Model.GetEntityTypes())
{
foreach (var property in entityType.GetProperties())
{
var indexAttribute = entityType.ClrType
.GetProperty(property.Name)
?.GetCustomAttribute(typeof(DatabaseIndexAttribute), false) as DatabaseIndexAttribute;
if (indexAttribute != null)
{
var index = entityType.AddIndex(property);
index.IsUnique = indexAttribute.IsUnique;
if (!string.IsNullOrWhiteSpace(indexAttribute.Name))
index.SetDatabaseName(indexAttribute.Name.Trim());
}
}
}
return builder;
}
/// <summary>
/// Converts all table and column names to snake_case_names.
/// </summary>
/// <param name="builder">The database model builder.</param>
/// <returns>A reference to this instance after the operation has completed.</returns>
public static ModelBuilder ApplySnakeCase(this ModelBuilder builder)
{
foreach (var entityType in builder.Model.GetEntityTypes())
{
// skip conversion when table name is explicitly set
if ((entityType.ClrType.GetCustomAttribute(typeof(TableAttribute), false) as TableAttribute) == null)
entityType.SetTableName(ConvertToSnakeCase(entityType.GetTableName()));
var identifier = StoreObjectIdentifier.Table(entityType.GetTableName(), entityType.GetSchema());
foreach (var property in entityType.GetProperties())
{
// skip conversion when column name is explicitly set
if ((entityType.ClrType.GetProperty(property.Name)?.GetCustomAttribute(typeof(ColumnAttribute), false) as ColumnAttribute) == null)
property.SetColumnName(ConvertToSnakeCase(property.GetColumnName(identifier)));
}
}
return builder;
}
/// <summary>
/// Converts a string to its snake_case equivalent.
/// </summary>
/// <remarks>
/// Code borrowed from Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.
/// See https://github.com/npgsql/npgsql/blob/f2b2c98f45df6d2a78eec00ae867f18944d717ca/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs#L76-L136.
/// </remarks>
/// <param name="value">The value to convert.</param>
private static string ConvertToSnakeCase(string value)
{
var sb = new StringBuilder();
var state = SnakeCaseState.Start;
for (int i = 0; i < value.Length; i++)
{
if (value[i] == ' ')
{
if (state != SnakeCaseState.Start)
state = SnakeCaseState.NewWord;
}
else if (char.IsUpper(value[i]))
{
switch (state)
{
case SnakeCaseState.Upper:
bool hasNext = i + 1 < value.Length;
if (i > 0 && hasNext)
{
char nextChar = value[i + 1];
if (!char.IsUpper(nextChar) && nextChar != '_')
{
sb.Append('_');
}
}
break;
case SnakeCaseState.Lower:
case SnakeCaseState.NewWord:
sb.Append('_');
break;
}
sb.Append(char.ToLowerInvariant(value[i]));
state = SnakeCaseState.Upper;
}
else if (value[i] == '_')
{
sb.Append('_');
state = SnakeCaseState.Start;
}
else
{
if (state == SnakeCaseState.NewWord)
sb.Append('_');
sb.Append(value[i]);
state = SnakeCaseState.Lower;
}
}
return sb.ToString();
}
private enum SnakeCaseState
{
Start,
Lower,
Upper,
NewWord
}
}
}

View File

@@ -0,0 +1,52 @@
using System;
using AMWD.Common.EntityFrameworkCore.Converters;
using Microsoft.EntityFrameworkCore;
namespace AMWD.Common.EntityFrameworkCore.Extensions
{
/// <summary>
/// Extensions for the <see cref="ModelConfigurationBuilder"/> of entity framework core.
/// </summary>
public static class ModelConfigurationBuilderExtensions
{
/// <summary>
/// Adds converters for the <see cref="DateOnly"/> datatype introduced with .NET 6.0.
/// </summary>
/// <remarks>
/// As of 2022-06-04 only required for Microsoft SQL server on .NET 6.0.
/// </remarks>
/// <param name="builder">The <see cref="ModelConfigurationBuilder"/> instance.</param>
/// <returns>The <see cref="ModelConfigurationBuilder"/> instance after applying the converters.</returns>
public static ModelConfigurationBuilder AddDateOnlyConverters(this ModelConfigurationBuilder builder)
{
builder.Properties<DateOnly>()
.HaveConversion<DateOnlyConverter>()
.HaveColumnType("date");
builder.Properties<DateOnly?>()
.HaveConversion<NullableDateOnlyConverter>()
.HaveColumnType("date");
return builder;
}
/// <summary>
/// Adds converters for the <see cref="TimeOnly"/> datatype introduced with .NET 6.0.
/// </summary>
/// <remarks>
/// As of 2022-06-04 only required for Microsoft SQL server on .NET 6.0.
/// </remarks>
/// <param name="builder">The <see cref="ModelConfigurationBuilder"/> instance.</param>
/// <returns>The <see cref="ModelConfigurationBuilder"/> instance after applying the converters.</returns>
public static ModelConfigurationBuilder AddTimeOnlyConverters(this ModelConfigurationBuilder builder)
{
builder.Properties<TimeOnly>()
.HaveConversion<TimeOnlyConverter>()
.HaveColumnType("time");
builder.Properties<TimeOnly?>()
.HaveConversion<NullableTimeOnlyConverter>()
.HaveColumnType("time");
return builder;
}
}
}