Solution restructured to use multiple test projects
This commit is contained in:
@@ -0,0 +1,380 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Data;
|
||||
using System.Data.Common;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text.RegularExpressions;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.EntityFrameworkCore.Infrastructure;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.EntityFrameworkCore
|
||||
{
|
||||
/// <summary>
|
||||
/// Extensions for the <see cref="DatabaseFacade"/>.
|
||||
/// </summary>
|
||||
#if NET8_0_OR_GREATER
|
||||
public static partial class DatabaseFacadeExtensions
|
||||
#else
|
||||
public static class DatabaseFacadeExtensions
|
||||
#endif
|
||||
{
|
||||
/// <summary>
|
||||
/// Applies migration files to the database.
|
||||
/// </summary>
|
||||
/// <param name="database">The database connection.</param>
|
||||
/// <param name="optionsAction">An action to set additional options.</param>
|
||||
/// <param name="cancellationToken">The cancellation token.</param>
|
||||
/// <returns><see langword="true"/> on success, otherwise false or an exception is thrown.</returns>
|
||||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "CA2208")]
|
||||
public static async Task<bool> ApplyMigrationsAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (database == null)
|
||||
throw new ArgumentNullException(nameof(database));
|
||||
|
||||
if (database.GetProviderType() == DatabaseProvider.InMemory)
|
||||
return true;
|
||||
|
||||
var options = new DatabaseMigrationOptions();
|
||||
optionsAction?.Invoke(options);
|
||||
|
||||
if (string.IsNullOrWhiteSpace(options.MigrationsTableName))
|
||||
throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required.");
|
||||
|
||||
if (string.IsNullOrWhiteSpace(options.Path))
|
||||
throw new ArgumentNullException(nameof(options.Path), $"The property {nameof(options.Path)} of the {nameof(options)} parameter is required.");
|
||||
|
||||
await database.WaitAvailableAsync(opts =>
|
||||
{
|
||||
opts.WaitDelay = options.WaitDelay;
|
||||
opts.Logger = options.Logger;
|
||||
}, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
var connection = database.GetDbConnection();
|
||||
try
|
||||
{
|
||||
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
if (!await connection.CreateMigrationsTable(options, cancellationToken).ConfigureAwait(false))
|
||||
return false;
|
||||
|
||||
return await connection.Migrate(options, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
finally
|
||||
{
|
||||
await connection.CloseAsync().ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Waits until the database connection is available.
|
||||
/// </summary>
|
||||
/// <param name="database">The database connection.</param>
|
||||
/// <param name="optionsAction">An action to set additional options.</param>
|
||||
/// <param name="cancellationToken">A cancellation token.</param>
|
||||
/// <returns>An awaitable task to wait until the database is available.</returns>
|
||||
public static async Task WaitAvailableAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction = null, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (database == null)
|
||||
throw new ArgumentNullException(nameof(database));
|
||||
|
||||
if (database.GetProviderType() == DatabaseProvider.InMemory)
|
||||
return;
|
||||
|
||||
var options = new DatabaseMigrationOptions();
|
||||
optionsAction?.Invoke(options);
|
||||
|
||||
options.Logger?.LogInformation("Waiting for a database connection");
|
||||
var connection = database.GetDbConnection();
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
try
|
||||
{
|
||||
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
|
||||
options.Logger?.LogInformation("Database connection available");
|
||||
return;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// keep things quiet
|
||||
try
|
||||
{
|
||||
await Task.Delay(options.WaitDelay, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
|
||||
{
|
||||
return;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// keep things quiet
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
await connection.CloseAsync().ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal static DatabaseProvider GetProviderType(this DatabaseFacade database)
|
||||
=> GetProviderType(database.ProviderName);
|
||||
|
||||
private static DatabaseProvider GetProviderType(this DbConnection connection)
|
||||
=> GetProviderType(connection.GetType().FullName);
|
||||
|
||||
private static DatabaseProvider GetProviderType(string provider)
|
||||
{
|
||||
if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.MySQL;
|
||||
if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.Oracle;
|
||||
if (provider.Contains("npgsql", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.PostgreSQL;
|
||||
if (provider.Contains("sqlite", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.SQLite;
|
||||
if (provider.Contains("sqlclient", StringComparison.OrdinalIgnoreCase)
|
||||
|| provider.Contains("sqlserver", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.SQLServer;
|
||||
if (provider.Contains("inmemory", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.InMemory;
|
||||
|
||||
throw new DatabaseProviderException($"The database provider '{provider}' is unknown");
|
||||
}
|
||||
|
||||
private static async Task<bool> CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
using var command = connection.CreateCommand();
|
||||
|
||||
#pragma warning disable CS8509 // ignore missing cases
|
||||
command.CommandText = connection.GetProviderType() switch
|
||||
#pragma warning restore CS8509 // ignore missing cases
|
||||
{
|
||||
DatabaseProvider.MySQL => $@"CREATE TABLE IF NOT EXISTS `{options.MigrationsTableName}` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT,
|
||||
`schema_file` VARCHAR(250) NOT NULL,
|
||||
`installed_at` VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
);",
|
||||
DatabaseProvider.Oracle => $@"DECLARE ncount NUMBER;
|
||||
BEGIN
|
||||
SELECT count(*) INTO ncount FROM dba_tables WHERE table_name = '{options.MigrationsTableName}';
|
||||
IF (ncount <= 0)
|
||||
THEN
|
||||
EXECUTE IMMEDIATE 'CREATE TABLE ""{options.MigrationsTableName}"" (
|
||||
""id"" NUMBER GENERATED by default on null as IDENTITY,
|
||||
""schema_file"" VARCHAR2(250) NOT NULL,
|
||||
""installed_at"" VARCHAR2(16) NOT NULL,
|
||||
PRIMARY KEY (""id""),
|
||||
CONSTRAINT uq_schema_file UNIQUE (""schema_file"")
|
||||
)';
|
||||
END IF;
|
||||
END;",
|
||||
DatabaseProvider.PostgreSQL => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
|
||||
""id"" SERIAL4 PRIMARY KEY,
|
||||
""schema_file"" VARCHAR(250) NOT NULL,
|
||||
""installed_at"" VARCHAR(16) NOT NULL,
|
||||
CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"")
|
||||
);",
|
||||
DatabaseProvider.SQLite => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
|
||||
""id"" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
""schema_file"" TEXT(250) NOT NULL,
|
||||
""installed_at"" TEXT(16) NOT NULL,
|
||||
CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"")
|
||||
);",
|
||||
DatabaseProvider.SQLServer => $@"IF NOT EXISTS (SELECT * FROM [sysobjects] WHERE [name] = '{options.MigrationsTableName}' AND [xtype] = 'U')
|
||||
BEGIN
|
||||
CREATE TABLE [{options.MigrationsTableName}] (
|
||||
[id] int IDENTITY(1,1) NOT NULL PRIMARY KEY,
|
||||
[schema_file] varchar(250) NOT NULL,
|
||||
[installed_at] varchar(16) NOT NULL,
|
||||
CONSTRAINT uq_schema_file UNIQUE (schema_file)
|
||||
)
|
||||
END;"
|
||||
};
|
||||
|
||||
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
|
||||
return true;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
options.Logger?.LogCritical(ex, $"Creating migrations table '{options.MigrationsTableName}' failed: {ex.InnerException?.Message ?? ex.Message}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<bool> Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
List<string> availableMigrationFiles;
|
||||
if (options.SourceAssembly == null)
|
||||
{
|
||||
availableMigrationFiles = Directory.GetFiles(options.Path)
|
||||
.Where(f => f.StartsWith(options.Path, StringComparison.OrdinalIgnoreCase))
|
||||
.Where(f => f.EndsWith(".sql", StringComparison.OrdinalIgnoreCase))
|
||||
.ToList();
|
||||
}
|
||||
else
|
||||
{
|
||||
availableMigrationFiles = options.SourceAssembly
|
||||
.GetManifestResourceNames()
|
||||
.Where(f => f.StartsWith(options.Path, StringComparison.OrdinalIgnoreCase))
|
||||
.Where(f => f.EndsWith(".sql", StringComparison.OrdinalIgnoreCase))
|
||||
.ToList();
|
||||
}
|
||||
|
||||
if (availableMigrationFiles.Count == 0)
|
||||
return true;
|
||||
|
||||
using var command = connection.CreateCommand();
|
||||
|
||||
var migratedFiles = new List<string>();
|
||||
command.CommandText = connection.GetProviderType() switch
|
||||
{
|
||||
DatabaseProvider.MySQL => $"SELECT `schema_file` FROM `{options.MigrationsTableName}`;",
|
||||
DatabaseProvider.SQLServer => $"SELECT [schema_file] FROM [{options.MigrationsTableName}];",
|
||||
_ => $@"SELECT ""schema_file"" FROM ""{options.MigrationsTableName}"";",
|
||||
};
|
||||
using (var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
|
||||
migratedFiles.Add(reader.GetString(0));
|
||||
}
|
||||
|
||||
int pathLength = options.Path.Length + 1;
|
||||
foreach (string migrationFile in availableMigrationFiles)
|
||||
{
|
||||
// remove path including the separator
|
||||
string fileName = migrationFile.Replace(options.Path, "")[1..];
|
||||
using var transaction = await connection.BeginTransactionAsync(cancellationToken).ConfigureAwait(false);
|
||||
try
|
||||
{
|
||||
// max length in the database: 250 chars
|
||||
string trimmedFileName = fileName;
|
||||
if (trimmedFileName.Length > 250)
|
||||
fileName = fileName[..250];
|
||||
|
||||
if (migratedFiles.Contains(trimmedFileName))
|
||||
{
|
||||
options.Logger?.LogDebug($" Migrating file '{fileName}' done");
|
||||
continue;
|
||||
}
|
||||
|
||||
string sqlScript = null;
|
||||
if (options.SourceAssembly == null)
|
||||
{
|
||||
sqlScript = await File.ReadAllTextAsync(migrationFile, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
else
|
||||
{
|
||||
using var stream = options.SourceAssembly.GetManifestResourceStream(migrationFile);
|
||||
using var sr = new StreamReader(stream);
|
||||
#if NET8_0_OR_GREATER
|
||||
sqlScript = await sr.ReadToEndAsync(cancellationToken).ConfigureAwait(false);
|
||||
#else
|
||||
sqlScript = await sr.ReadToEndAsync().ConfigureAwait(false);
|
||||
#endif
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(sqlScript))
|
||||
continue;
|
||||
|
||||
options.Logger?.LogDebug($" Migrating file '{fileName}' started");
|
||||
command.Transaction = transaction;
|
||||
|
||||
await command.ExecuteScript(sqlScript, cancellationToken).ConfigureAwait(false);
|
||||
|
||||
command.CommandText = connection.GetProviderType() switch
|
||||
{
|
||||
DatabaseProvider.MySQL => $"INSERT INTO `{options.MigrationsTableName}` (`schema_file`, `installed_at`) VALUES ('{trimmedFileName.Replace("'", "\\'")}', '{DateTime.UtcNow:yyyy-MM-dd HH:mm}');",
|
||||
DatabaseProvider.SQLServer => $"INSERT INTO [{options.MigrationsTableName}] ([schema_file], [installed_at]) VALUES ('{trimmedFileName.Replace("'", "\\'")}', '{DateTime.UtcNow:yyyy-MM-dd HH:mm}');",
|
||||
_ => $@"INSERT INTO ""{options.MigrationsTableName}"" (""schema_file"", ""installed_at"") VALUES ('{trimmedFileName.Replace("'", "\\'")}', '{DateTime.UtcNow:yyyy-MM-dd HH:mm}');",
|
||||
};
|
||||
await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await transaction.CommitAsync(cancellationToken).ConfigureAwait(false);
|
||||
command.Transaction = null;
|
||||
options.Logger?.LogDebug($" Migrating file '{fileName}' successful");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await transaction.RollbackAsync(cancellationToken).ConfigureAwait(false);
|
||||
options.Logger?.LogError($"Migrating file '{fileName}' failed: {ex.InnerException?.Message ?? ex.Message}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
options.Logger?.LogCritical(ex, $"Migrating the database failed ({ex.GetType().Name}): {ex.InnerException?.Message ?? ex.Message}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<int> ExecuteScript(this DbCommand command, string text, CancellationToken cancellationToken)
|
||||
{
|
||||
if (command.Connection.GetProviderType() == DatabaseProvider.Oracle)
|
||||
{
|
||||
int affectedRows = 0;
|
||||
// Split script by a single slash in a line
|
||||
#if NET8_0_OR_GREATER
|
||||
string[] parts = FindSingleSlashInLine().Split(text);
|
||||
#else
|
||||
string[] parts = Regex.Split(text, @"\r?\n[ \t]*/[ \t]*\r?\n");
|
||||
#endif
|
||||
foreach (string part in parts)
|
||||
{
|
||||
// Make writable copy
|
||||
string pt = part;
|
||||
|
||||
// Remove the trailing semicolon from commands where they're not supported
|
||||
// (Oracle doesn't like semicolons. To keep the semicolon, it must be directly
|
||||
// preceeded by "end".)
|
||||
#if NET8_0_OR_GREATER
|
||||
pt = FindEndCommand().Replace(pt.TrimEnd(), "");
|
||||
#else
|
||||
pt = Regex.Replace(pt, @"(?<!end);$", "", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);
|
||||
#endif
|
||||
|
||||
// Execute all non-empty parts as individual commands
|
||||
if (!string.IsNullOrWhiteSpace(pt))
|
||||
{
|
||||
command.CommandText = pt;
|
||||
affectedRows += await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
return affectedRows;
|
||||
}
|
||||
else
|
||||
{
|
||||
command.CommandText = text;
|
||||
return await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
internal enum DatabaseProvider
|
||||
{
|
||||
MySQL = 1,
|
||||
Oracle = 2,
|
||||
PostgreSQL = 3,
|
||||
SQLite = 4,
|
||||
SQLServer = 5,
|
||||
InMemory = 6,
|
||||
}
|
||||
|
||||
#if NET8_0_OR_GREATER
|
||||
[GeneratedRegex(@"\r?\n[ \t]*/[ \t]*\r?\n")]
|
||||
private static partial Regex FindSingleSlashInLine();
|
||||
|
||||
[GeneratedRegex(@"(?<!end);$", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant)]
|
||||
private static partial Regex FindEndCommand();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.EntityFrameworkCore.Storage;
|
||||
|
||||
namespace Microsoft.EntityFrameworkCore
|
||||
{
|
||||
/// <summary>
|
||||
/// Extensions for the <see cref="DbContext"/>.
|
||||
/// </summary>
|
||||
public static class DbContextExensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Starts a new transaction.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// See <see href="https://aka.ms/efcore-docs-transactions">Transactions in EF Core</see> for more information.
|
||||
/// </remarks>
|
||||
/// <param name="dbContext">The current <see cref="DbContext"/>.</param>
|
||||
/// <returns>
|
||||
/// A <see cref="IDbContextTransaction" /> that represents the started transaction.
|
||||
/// </returns>
|
||||
public static IDbContextTransaction BeginTransaction(this DbContext dbContext)
|
||||
{
|
||||
if (dbContext.Database.GetProviderType() == DatabaseFacadeExtensions.DatabaseProvider.InMemory)
|
||||
return new DbContextTransactionStub();
|
||||
|
||||
return dbContext.Database.BeginTransaction();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Asynchronously starts a new transaction.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// <para>
|
||||
/// Entity Framework Core does not support multiple parallel operations being run on the same DbContext instance. This
|
||||
/// includes both parallel execution of async queries and any explicit concurrent use from multiple threads.
|
||||
/// Therefore, always await async calls immediately, or use separate DbContext instances for operations that execute
|
||||
/// in parallel. See <see href="https://aka.ms/efcore-docs-threading">Avoiding DbContext threading issues</see>
|
||||
/// for more information.
|
||||
/// </para>
|
||||
/// <para>
|
||||
/// See <see href="https://aka.ms/efcore-docs-transactions">Transactions in EF Core</see> for more information.
|
||||
/// </para>
|
||||
/// </remarks>
|
||||
/// <param name="dbContext">The current <see cref="DbContext"/>.</param>
|
||||
/// <param name="cancellationToken">A <see cref="CancellationToken" /> to observe while waiting for the task to complete.</param>
|
||||
/// <returns>
|
||||
/// A task that represents the asynchronous transaction initialization. The task result contains a <see cref="IDbContextTransaction" /> that represents the started transaction.
|
||||
/// </returns>
|
||||
/// <exception cref="OperationCanceledException">If the <see cref="CancellationToken" /> is canceled.</exception>
|
||||
public static Task<IDbContextTransaction> BeginTransactionAsync(this DbContext dbContext, CancellationToken cancellationToken)
|
||||
{
|
||||
if (dbContext.Database.GetProviderType() == DatabaseFacadeExtensions.DatabaseProvider.InMemory)
|
||||
return Task.FromResult<IDbContextTransaction>(new DbContextTransactionStub());
|
||||
|
||||
return dbContext.Database.BeginTransactionAsync(cancellationToken);
|
||||
}
|
||||
|
||||
/// <inheritdoc cref="IDbContextTransaction" />
|
||||
private class DbContextTransactionStub : IDbContextTransaction
|
||||
{
|
||||
/// <inheritdoc />
|
||||
public Guid TransactionId { get; private set; } = Guid.NewGuid();
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Commit()
|
||||
{ }
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task CommitAsync(CancellationToken cancellationToken = default)
|
||||
=> Task.CompletedTask;
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Dispose()
|
||||
{ }
|
||||
|
||||
/// <inheritdoc />
|
||||
public ValueTask DisposeAsync()
|
||||
=> new(Task.CompletedTask);
|
||||
|
||||
/// <inheritdoc />
|
||||
public void Rollback()
|
||||
{ }
|
||||
|
||||
/// <inheritdoc />
|
||||
public Task RollbackAsync(CancellationToken cancellationToken = default)
|
||||
=> Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Reflection;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
|
||||
namespace Microsoft.EntityFrameworkCore
|
||||
{
|
||||
/// <summary>
|
||||
/// Extends the <see cref="DbContextOptionsBuilder"/> to use a configurable database provider.
|
||||
/// </summary>
|
||||
public static class DbContextOptionsBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds the supported database provider to the context.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// The configuration provided requires the following entries:
|
||||
/// <list type="bullet">
|
||||
/// <item><strong>Provider</strong>: MySQL | Oracle | PostgreSQL | SQLite | SQLServer</item>
|
||||
/// <item><strong>Host</strong>: hostname or IP address</item>
|
||||
/// <item><strong>Port</strong>: port number</item>
|
||||
/// <item><strong>Name</strong>: database name</item>
|
||||
/// <item><strong>Schema</strong>: schema or search path (e.g. PostgreSQL: public)</item>
|
||||
/// <item><strong>Username</strong>: username credential on the database</item>
|
||||
/// <item><strong>Password</strong>: password credential on the database</item>
|
||||
/// <item><strong>File</strong>: file name / path (for SQLite)</item>
|
||||
/// </list>
|
||||
/// </remarks>
|
||||
/// <param name="optionsBuilder">The options builder.</param>
|
||||
/// <param name="configuration">The application configuration section for the database.</param>
|
||||
/// <param name="optionsAction">An optional action to set additional options.</param>
|
||||
/// <returns>The <see cref="DbContextOptionsBuilder"/> with applied settings.</returns>
|
||||
public static DbContextOptionsBuilder UseDatabaseProvider(this DbContextOptionsBuilder optionsBuilder, IConfiguration configuration, Action<DatabaseProviderOptions> optionsAction = null)
|
||||
{
|
||||
#if NET8_0_OR_GREATER
|
||||
ArgumentNullException.ThrowIfNull(optionsBuilder);
|
||||
ArgumentNullException.ThrowIfNull(configuration);
|
||||
#else
|
||||
if (optionsBuilder == null)
|
||||
throw new ArgumentNullException(nameof(optionsBuilder));
|
||||
if (configuration == null)
|
||||
throw new ArgumentNullException(nameof(configuration));
|
||||
#endif
|
||||
|
||||
var options = new DatabaseProviderOptions();
|
||||
optionsAction?.Invoke(options);
|
||||
|
||||
string connectionString = GetConnectionString(configuration, options);
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
|
||||
var builderType = GetBuilderType(configuration);
|
||||
var extensionType = GetExtensionType(configuration);
|
||||
var actionType = typeof(Action<>).MakeGenericType(builderType);
|
||||
|
||||
object serverVersion = null;
|
||||
MethodInfo methodInfo;
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
if (methodInfo == null)
|
||||
methodInfo = extensionType.GetMethod("UseMySQL", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
if (methodInfo == null) // Pomelo MySQL v5
|
||||
{
|
||||
var serverVersionType = Type.GetType("Microsoft.EntityFrameworkCore.ServerVersion, Pomelo.EntityFrameworkCore.MySql");
|
||||
var autoDetectMethodInfo = serverVersionType.GetMethod("AutoDetect", new Type[] { typeof(string) });
|
||||
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType });
|
||||
serverVersion = autoDetectMethodInfo.Invoke(null, new object[] { connectionString });
|
||||
}
|
||||
break;
|
||||
case "oracle":
|
||||
methodInfo = extensionType.GetMethod("UseOracle", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
methodInfo = extensionType.GetMethod("UseNpgsql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
case "sqlite":
|
||||
methodInfo = extensionType.GetMethod("UseSqlite", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
methodInfo = extensionType.GetMethod("UseSqlServer", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
case "memory":
|
||||
case "inmemory":
|
||||
methodInfo = extensionType.GetMethod("UseInMemoryDatabase", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
default:
|
||||
throw new DatabaseProviderException($"Unknown database provider: {provider}");
|
||||
}
|
||||
|
||||
if (serverVersion == null)
|
||||
{
|
||||
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, null });
|
||||
}
|
||||
else
|
||||
{
|
||||
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, serverVersion, null });
|
||||
}
|
||||
|
||||
return optionsBuilder;
|
||||
}
|
||||
|
||||
private static Type GetBuilderType(IConfiguration configuration)
|
||||
{
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
Type builderType;
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.MySqlDbContextOptionsBuilder, Pomelo.EntityFrameworkCore.MySql");
|
||||
if (builderType == null)
|
||||
builderType = Type.GetType("MySql.Data.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.Data.EntityFrameworkCore");
|
||||
if (builderType == null) // as MySql.Data.EntityFrameworkCore is marked as deprecated on NuGet
|
||||
builderType = Type.GetType("MySql.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.EntityFrameworkCore");
|
||||
break;
|
||||
case "oracle":
|
||||
builderType = Type.GetType("Oracle.EntityFrameworkCore.Infrastructure.OracleDbContextOptionsBuilder, Oracle.EntityFrameworkCore");
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
builderType = Type.GetType("Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.NpgsqlDbContextOptionsBuilder, Npgsql.EntityFrameworkCore.PostgreSQL");
|
||||
break;
|
||||
case "sqlite":
|
||||
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqliteDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.Sqlite");
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqlServerDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.SqlServer");
|
||||
break;
|
||||
case "memory":
|
||||
case "inmemory":
|
||||
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.InMemoryDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.InMemory");
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException($"Unknown database provider: {provider}");
|
||||
}
|
||||
return builderType;
|
||||
}
|
||||
|
||||
private static Type GetExtensionType(IConfiguration configuration)
|
||||
{
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
Type extensionType;
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySqlDbContextOptionsBuilderExtensions, Pomelo.EntityFrameworkCore.MySql");
|
||||
if (extensionType == null)
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.Data.EntityFrameworkCore");
|
||||
if (extensionType == null)
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.EntityFrameworkCore");
|
||||
break;
|
||||
case "oracle":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.OracleDbContextOptionsBuilderExtensions, Oracle.EntityFrameworkCore");
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.NpgsqlDbContextOptionsBuilderExtensions, Npgsql.EntityFrameworkCore.PostgreSQL");
|
||||
break;
|
||||
case "sqlite":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqliteDbContextOptionsBuilderExtensions, Microsoft.EntityFrameworkCore.Sqlite");
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqlServerDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.SqlServer");
|
||||
break;
|
||||
case "memory":
|
||||
case "inmemory":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.InMemoryDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.InMemory");
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException($"Unknown database provider: {provider}");
|
||||
}
|
||||
return extensionType;
|
||||
}
|
||||
|
||||
private static string GetConnectionString(IConfiguration configuration, DatabaseProviderOptions options)
|
||||
{
|
||||
var cs = new List<string>();
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
cs.Add($"Server={configuration.GetValue<string>("Host")}");
|
||||
cs.Add($"Port={configuration.GetValue("Port", 3306)}");
|
||||
cs.Add($"Database={configuration.GetValue<string>("Name")}");
|
||||
cs.Add($"Uid={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add($"Connection Timeout=15");
|
||||
break;
|
||||
case "oracle":
|
||||
cs.Add($"Data Source=(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={configuration.GetValue<string>("Host")})(PORT={configuration.GetValue("Port", 1521)}))(CONNECT_DATA=(SERVICE_NAME={configuration.GetValue<string>("Name")})))");
|
||||
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add($"Connection Timeout=15");
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
cs.Add($"Server={configuration.GetValue<string>("Host")}");
|
||||
cs.Add($"Port={configuration.GetValue("Port", 5432)}");
|
||||
cs.Add($"Database={configuration.GetValue<string>("Name")}");
|
||||
cs.Add($"Search Path={configuration.GetValue("Schema", "public")}");
|
||||
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add($"Timeout=15");
|
||||
break;
|
||||
case "sqlite":
|
||||
string path = configuration.GetValue<string>("File");
|
||||
if (!Path.IsPathRooted(path))
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(options.AbsoluteBasePath))
|
||||
options.AbsoluteBasePath = AppContext.BaseDirectory;
|
||||
|
||||
path = Path.Combine(options.AbsoluteBasePath, path);
|
||||
}
|
||||
cs.Add($"Data Source={path}");
|
||||
cs.Add("Foreign Keys=True");
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
cs.Add($"Server={configuration.GetValue<string>("Host")},{configuration.GetValue("Port", 1433)}");
|
||||
cs.Add($"Database={configuration.GetValue<string>("Name")}");
|
||||
if (!string.IsNullOrWhiteSpace(configuration.GetValue<string>("Username")))
|
||||
{
|
||||
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add("Integrated Security=False");
|
||||
}
|
||||
else
|
||||
{
|
||||
cs.Add("Integrated Security=True");
|
||||
}
|
||||
cs.Add("Connect Timeout=15");
|
||||
break;
|
||||
case "memory":
|
||||
case "inmemory":
|
||||
cs.Add(configuration.GetValue("Name", provider));
|
||||
break;
|
||||
default:
|
||||
throw new DatabaseProviderException($"Unknown database provider: {provider}");
|
||||
}
|
||||
return string.Join(";", cs);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
using System.ComponentModel.DataAnnotations.Schema;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using AMWD.Common.EntityFrameworkCore.Attributes;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
using Microsoft.EntityFrameworkCore.Metadata;
|
||||
|
||||
namespace AMWD.Common.EntityFrameworkCore.Extensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Extensions for the <see cref="ModelBuilder"/> of entity framework core.
|
||||
/// </summary>
|
||||
public static class ModelBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Applies indices and unique constraints to the properties.
|
||||
/// </summary>
|
||||
/// <param name="builder">The database model builder.</param>
|
||||
/// <returns>A reference to this instance after the operation has completed.</returns>
|
||||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0019", Justification = "No pattern comparison in this case due to readability.")]
|
||||
public static ModelBuilder ApplyIndexAttributes(this ModelBuilder builder)
|
||||
{
|
||||
foreach (var entityType in builder.Model.GetEntityTypes())
|
||||
{
|
||||
foreach (var property in entityType.GetProperties())
|
||||
{
|
||||
var indexAttribute = entityType.ClrType
|
||||
.GetProperty(property.Name)
|
||||
?.GetCustomAttribute(typeof(DatabaseIndexAttribute), false) as DatabaseIndexAttribute;
|
||||
if (indexAttribute != null)
|
||||
{
|
||||
var index = entityType.AddIndex(property);
|
||||
index.IsUnique = indexAttribute.IsUnique;
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(indexAttribute.Name))
|
||||
index.SetDatabaseName(indexAttribute.Name.Trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts all table and column names to snake_case_names.
|
||||
/// </summary>
|
||||
/// <param name="builder">The database model builder.</param>
|
||||
/// <returns>A reference to this instance after the operation has completed.</returns>
|
||||
public static ModelBuilder ApplySnakeCase(this ModelBuilder builder)
|
||||
{
|
||||
foreach (var entityType in builder.Model.GetEntityTypes())
|
||||
{
|
||||
// skip conversion when table name is explicitly set
|
||||
if ((entityType.ClrType.GetCustomAttribute(typeof(TableAttribute), false) as TableAttribute) == null)
|
||||
entityType.SetTableName(ConvertToSnakeCase(entityType.GetTableName()));
|
||||
|
||||
var identifier = StoreObjectIdentifier.Table(entityType.GetTableName(), entityType.GetSchema());
|
||||
foreach (var property in entityType.GetProperties())
|
||||
{
|
||||
// skip conversion when column name is explicitly set
|
||||
if ((entityType.ClrType.GetProperty(property.Name)?.GetCustomAttribute(typeof(ColumnAttribute), false) as ColumnAttribute) == null)
|
||||
property.SetColumnName(ConvertToSnakeCase(property.GetColumnName(identifier)));
|
||||
}
|
||||
}
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts a string to its snake_case equivalent.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Code borrowed from Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.
|
||||
/// See https://github.com/npgsql/npgsql/blob/f2b2c98f45df6d2a78eec00ae867f18944d717ca/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs#L76-L136.
|
||||
/// </remarks>
|
||||
/// <param name="value">The value to convert.</param>
|
||||
private static string ConvertToSnakeCase(string value)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
var state = SnakeCaseState.Start;
|
||||
|
||||
for (int i = 0; i < value.Length; i++)
|
||||
{
|
||||
if (value[i] == ' ')
|
||||
{
|
||||
if (state != SnakeCaseState.Start)
|
||||
state = SnakeCaseState.NewWord;
|
||||
}
|
||||
else if (char.IsUpper(value[i]))
|
||||
{
|
||||
switch (state)
|
||||
{
|
||||
case SnakeCaseState.Upper:
|
||||
bool hasNext = i + 1 < value.Length;
|
||||
if (i > 0 && hasNext)
|
||||
{
|
||||
char nextChar = value[i + 1];
|
||||
if (!char.IsUpper(nextChar) && nextChar != '_')
|
||||
{
|
||||
sb.Append('_');
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case SnakeCaseState.Lower:
|
||||
case SnakeCaseState.NewWord:
|
||||
sb.Append('_');
|
||||
break;
|
||||
}
|
||||
|
||||
sb.Append(char.ToLowerInvariant(value[i]));
|
||||
state = SnakeCaseState.Upper;
|
||||
}
|
||||
else if (value[i] == '_')
|
||||
{
|
||||
sb.Append('_');
|
||||
state = SnakeCaseState.Start;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (state == SnakeCaseState.NewWord)
|
||||
sb.Append('_');
|
||||
|
||||
sb.Append(value[i]);
|
||||
state = SnakeCaseState.Lower;
|
||||
}
|
||||
}
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private enum SnakeCaseState
|
||||
{
|
||||
Start,
|
||||
Lower,
|
||||
Upper,
|
||||
NewWord
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
using System;
|
||||
using AMWD.Common.EntityFrameworkCore.Converters;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
|
||||
namespace AMWD.Common.EntityFrameworkCore.Extensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Extensions for the <see cref="ModelConfigurationBuilder"/> of entity framework core.
|
||||
/// </summary>
|
||||
public static class ModelConfigurationBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds converters for the <see cref="DateOnly"/> datatype introduced with .NET 6.0.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// As of 2022-06-04 only required for Microsoft SQL server on .NET 6.0.
|
||||
/// </remarks>
|
||||
/// <param name="builder">The <see cref="ModelConfigurationBuilder"/> instance.</param>
|
||||
/// <returns>The <see cref="ModelConfigurationBuilder"/> instance after applying the converters.</returns>
|
||||
public static ModelConfigurationBuilder AddDateOnlyConverters(this ModelConfigurationBuilder builder)
|
||||
{
|
||||
builder.Properties<DateOnly>()
|
||||
.HaveConversion<DateOnlyConverter>()
|
||||
.HaveColumnType("date");
|
||||
builder.Properties<DateOnly?>()
|
||||
.HaveConversion<NullableDateOnlyConverter>()
|
||||
.HaveColumnType("date");
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds converters for the <see cref="TimeOnly"/> datatype introduced with .NET 6.0.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// As of 2022-06-04 only required for Microsoft SQL server on .NET 6.0.
|
||||
/// </remarks>
|
||||
/// <param name="builder">The <see cref="ModelConfigurationBuilder"/> instance.</param>
|
||||
/// <returns>The <see cref="ModelConfigurationBuilder"/> instance after applying the converters.</returns>
|
||||
public static ModelConfigurationBuilder AddTimeOnlyConverters(this ModelConfigurationBuilder builder)
|
||||
{
|
||||
builder.Properties<TimeOnly>()
|
||||
.HaveConversion<TimeOnlyConverter>()
|
||||
.HaveColumnType("time");
|
||||
builder.Properties<TimeOnly?>()
|
||||
.HaveConversion<NullableTimeOnlyConverter>()
|
||||
.HaveColumnType("time");
|
||||
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user