Refactoring
This commit is contained in:
@@ -0,0 +1,277 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Data.Common;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text.RegularExpressions;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.EntityFrameworkCore.Infrastructure;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
namespace Microsoft.EntityFrameworkCore
|
||||
{
|
||||
/// <summary>
|
||||
/// Extensions for the <see cref="DatabaseFacade"/>.
|
||||
/// </summary>
|
||||
public static class DatabaseFacadeExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Applies migration files to the database.
|
||||
/// </summary>
|
||||
/// <param name="database">The database connection.</param>
|
||||
/// <param name="optionsAction">An action to set additional options.</param>
|
||||
/// <param name="cancellationToken">The cancellation token.</param>
|
||||
/// <returns>true on success, otherwise false or an exception is thrown.</returns>
|
||||
public static async Task<bool> ApplyMigrationsAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction, CancellationToken cancellationToken = default)
|
||||
{
|
||||
if (database == null)
|
||||
throw new ArgumentNullException(nameof(database));
|
||||
|
||||
var options = new DatabaseMigrationOptions();
|
||||
optionsAction?.Invoke(options);
|
||||
|
||||
if (string.IsNullOrWhiteSpace(options.MigrationsTableName))
|
||||
throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required.");
|
||||
|
||||
if (string.IsNullOrWhiteSpace(options.Path))
|
||||
throw new ArgumentNullException(nameof(options.Path), $"The property {nameof(options.Path)} of the {nameof(options)} parameter is required.");
|
||||
|
||||
var connection = database.GetDbConnection();
|
||||
try
|
||||
{
|
||||
await connection.OpenAsync(cancellationToken);
|
||||
if (!await connection.CreateMigrationsTable(options, cancellationToken))
|
||||
return false;
|
||||
|
||||
return await connection.Migrate(options, cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
connection.Close();
|
||||
}
|
||||
}
|
||||
|
||||
private static DatabaseProvider GetProviderType(this DbConnection connection)
|
||||
{
|
||||
string provider = connection.GetType().FullName;
|
||||
|
||||
if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.MySQL;
|
||||
if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.Oracle;
|
||||
if (provider.Contains("npgsql", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.PostgreSQL;
|
||||
if (provider.Contains("sqlite", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.SQLite;
|
||||
if (provider.Contains("sqlclient", StringComparison.OrdinalIgnoreCase))
|
||||
return DatabaseProvider.SQLServer;
|
||||
|
||||
throw new DatabaseProviderException($"The database provider '{provider}' is unknown");
|
||||
}
|
||||
|
||||
private static async Task<bool> CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
using var command = connection.CreateCommand();
|
||||
|
||||
#pragma warning disable CS8524 // missing default case
|
||||
command.CommandText = connection.GetProviderType() switch
|
||||
#pragma warning restore CS8524 // missing default case
|
||||
{
|
||||
DatabaseProvider.MySQL => $@"CREATE TABLE IF NOT EXISTS `{options.MigrationsTableName}` (
|
||||
`id` INT NOT NULL AUTO_INCREMENT,
|
||||
`schema_file` VARCHAR(250) NOT NULL,
|
||||
`installed_at` VARCHAR(16) NOT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
);",
|
||||
DatabaseProvider.Oracle => $@"DECLARE ncount NUMBER;
|
||||
BEGIN
|
||||
SELECT count(*) INTO ncount FROM dba_tables WHERE table_name = '{options.MigrationsTableName}';
|
||||
IF (ncount <= 0)
|
||||
THEN
|
||||
EXECUTE IMMEDIATE 'CREATE TABLE ""{options.MigrationsTableName}"" (
|
||||
""id"" NUMBER GENERATED by default on null as IDENTITY,
|
||||
""schema_file"" VARCHAR2(250) NOT NULL,
|
||||
""installed_at"" VARCHAR2(16) NOT NULL,
|
||||
PRIMARY KEY (""id""),
|
||||
CONSTRAINT uq_schema_file UNIQUE (""schema_file"")
|
||||
)';
|
||||
END IF;
|
||||
END;",
|
||||
DatabaseProvider.PostgreSQL => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
|
||||
""id"" SERIAL4 PRIMARY KEY,
|
||||
""schema_file"" VARCHAR(250) NOT NULL,
|
||||
""installed_at"" VARCHAR(16) NOT NULL,
|
||||
CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"")
|
||||
);",
|
||||
DatabaseProvider.SQLite => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
|
||||
""id"" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
""schema_file"" TEXT(250) NOT NULL,
|
||||
""installed_at"" TEXT(16) NOT NULL,
|
||||
CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"")
|
||||
);",
|
||||
DatabaseProvider.SQLServer => $@"IF NOT EXISTS (SELECT * FROM [sysobjects] WHERE [name] = '{options.MigrationsTableName}' AND [xtype] = 'U')
|
||||
BEGIN
|
||||
CREATE TABLE [{options.MigrationsTableName}] (
|
||||
[id] int IDENTITY(1,1) NOT NULL PRIMARY KEY,
|
||||
[schema_file] varchar(250) NOT NULL,
|
||||
[installed_at] varchar(16) NOT NULL,
|
||||
CONSTRAINT uq_schema_file UNIQUE (schema_file)
|
||||
)
|
||||
END;"
|
||||
};
|
||||
|
||||
await command.ExecuteNonQueryAsync(cancellationToken);
|
||||
return true;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
options.Logger?.LogCritical(ex, $"Creating migrations table '{options.MigrationsTableName}' failed: {ex.InnerException?.Message ?? ex.Message}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<bool> Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
|
||||
{
|
||||
try
|
||||
{
|
||||
List<string> availableMigrationFiles;
|
||||
if (options.SourceAssembly == null)
|
||||
{
|
||||
availableMigrationFiles = Directory.GetFiles(options.Path)
|
||||
.Where(f => f.ToLower().StartsWith(options.Path.ToLower()))
|
||||
.Where(f => f.ToLower().EndsWith(".sql"))
|
||||
.ToList();
|
||||
}
|
||||
else
|
||||
{
|
||||
availableMigrationFiles = options.SourceAssembly
|
||||
.GetManifestResourceNames()
|
||||
.Where(f => f.ToLower().StartsWith(options.Path.ToLower()))
|
||||
.Where(f => f.ToLower().EndsWith(".sql"))
|
||||
.ToList();
|
||||
}
|
||||
|
||||
if (!availableMigrationFiles.Any())
|
||||
return true;
|
||||
|
||||
using var command = connection.CreateCommand();
|
||||
|
||||
var migratedFiles = new List<string>();
|
||||
command.CommandText = connection.GetProviderType() switch
|
||||
{
|
||||
DatabaseProvider.MySQL => $"SELECT `schema_file` FROM `{options.MigrationsTableName}`;",
|
||||
DatabaseProvider.SQLServer => $"SELECT [schema_file] FROM [{options.MigrationsTableName}];",
|
||||
_ => $@"SELECT ""schema_file"" FROM ""{options.MigrationsTableName}"";",
|
||||
};
|
||||
using (var reader = await command.ExecuteReaderAsync(cancellationToken))
|
||||
{
|
||||
while (await reader.ReadAsync(cancellationToken))
|
||||
migratedFiles.Add(reader.GetString(0));
|
||||
}
|
||||
|
||||
int pathLength = options.Path.Length + 1;
|
||||
foreach (string migrationFile in availableMigrationFiles)
|
||||
{
|
||||
// remove path including the separator
|
||||
string fileName = migrationFile.Replace(options.Path, "")[1..];
|
||||
using var transaction = await connection.BeginTransactionAsync(cancellationToken);
|
||||
try
|
||||
{
|
||||
// max length in the database: 250 chars
|
||||
string trimmedFileName = fileName;
|
||||
if (trimmedFileName.Length > 250)
|
||||
fileName = fileName.Substring(0, 250);
|
||||
|
||||
if (migratedFiles.Contains(trimmedFileName))
|
||||
{
|
||||
options.Logger?.LogDebug($" Migrating file '{fileName}' done");
|
||||
continue;
|
||||
}
|
||||
|
||||
string sqlScript = null;
|
||||
if (options.SourceAssembly == null)
|
||||
{
|
||||
sqlScript = await File.ReadAllTextAsync(migrationFile, cancellationToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
using var stream = options.SourceAssembly.GetManifestResourceStream(migrationFile);
|
||||
using var sr = new StreamReader(stream);
|
||||
sqlScript = await sr.ReadToEndAsync();
|
||||
}
|
||||
|
||||
if (string.IsNullOrWhiteSpace(sqlScript))
|
||||
continue;
|
||||
|
||||
options.Logger?.LogDebug($" Migrating file '{fileName}' started");
|
||||
command.Transaction = transaction;
|
||||
|
||||
await command.ExecuteScript(sqlScript, cancellationToken);
|
||||
|
||||
await transaction.CommitAsync(cancellationToken);
|
||||
command.Transaction = null;
|
||||
options.Logger?.LogDebug($" Migrating file '{fileName}' successful");
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await transaction.RollbackAsync(cancellationToken);
|
||||
options.Logger?.LogError($"Migrating file '{fileName}' failed: {ex.InnerException?.Message ?? ex.Message}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
options.Logger?.LogCritical(ex, $"Migrating the database failed ({ex.GetType().Name}): {ex.InnerException?.Message ?? ex.Message}");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task<int> ExecuteScript(this DbCommand command, string text, CancellationToken cancellationToken)
|
||||
{
|
||||
if (command.Connection.GetProviderType() == DatabaseProvider.Oracle)
|
||||
{
|
||||
int affectedRows = 0;
|
||||
// Split script by a single slash in a line
|
||||
string[] parts = Regex.Split(text, @"\r?\n[ \t]*/[ \t]*\r?\n");
|
||||
foreach (string part in parts)
|
||||
{
|
||||
// Make writable copy
|
||||
string pt = part;
|
||||
|
||||
// Remove the trailing semicolon from commands where they're not supported
|
||||
// (Oracle doesn't like semicolons. To keep the semicolon, it must be directly
|
||||
// preceeded by "end".)
|
||||
pt = Regex.Replace(pt.TrimEnd(), @"(?<!end);$", "", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);
|
||||
|
||||
// Execute all non-empty parts as individual commands
|
||||
if (!string.IsNullOrWhiteSpace(pt))
|
||||
{
|
||||
command.CommandText = pt;
|
||||
affectedRows += await command.ExecuteNonQueryAsync(cancellationToken);
|
||||
}
|
||||
}
|
||||
return affectedRows;
|
||||
}
|
||||
else
|
||||
{
|
||||
command.CommandText = text;
|
||||
return await command.ExecuteNonQueryAsync(cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
private enum DatabaseProvider
|
||||
{
|
||||
MySQL = 1,
|
||||
Oracle = 2,
|
||||
PostgreSQL = 3,
|
||||
SQLite = 4,
|
||||
SQLServer = 5
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Reflection;
|
||||
using Microsoft.Extensions.Configuration;
|
||||
|
||||
namespace Microsoft.EntityFrameworkCore
|
||||
{
|
||||
/// <summary>
|
||||
/// Extends the <see cref="DbContextOptionsBuilder"/> to use a configurable database provider.
|
||||
/// </summary>
|
||||
public static class DbContextOptionsBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Adds the supported database provider to the context.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// The configuration provided requires the following entries:
|
||||
/// <list type="bullet">
|
||||
/// <item><strong>Provider</strong>: MySQL | Oracle | PostgreSQL | SQLite | SQLServer</item>
|
||||
/// <item><strong>Host</strong>: hostname or IP address</item>
|
||||
/// <item><strong>Port</strong>: port number</item>
|
||||
/// <item><strong>Name</strong>: database name</item>
|
||||
/// <item><strong>Schema</strong>: schema or search path (e.g. PostgreSQL: public)</item>
|
||||
/// <item><strong>Username</strong>: username credential on the database</item>
|
||||
/// <item><strong>Password</strong>: password credential on the database</item>
|
||||
/// <item><strong>File</strong>: file name / path (for SQLite)</item>
|
||||
/// </list>
|
||||
/// </remarks>
|
||||
/// <param name="optionsBuilder">The options builder.</param>
|
||||
/// <param name="configuration">The application configuration section for the database.</param>
|
||||
/// <param name="optionsAction">An optional action to set additional options.</param>
|
||||
/// <returns>The <see cref="DbContextOptionsBuilder"/> with applied settings.</returns>
|
||||
public static DbContextOptionsBuilder UseDatabaseProvider(this DbContextOptionsBuilder optionsBuilder, IConfiguration configuration, Action<DatabaseProviderOptions> optionsAction = null)
|
||||
{
|
||||
if (optionsBuilder == null)
|
||||
throw new ArgumentNullException(nameof(optionsBuilder));
|
||||
|
||||
if (configuration == null)
|
||||
throw new ArgumentNullException(nameof(configuration));
|
||||
|
||||
var options = new DatabaseProviderOptions();
|
||||
optionsAction?.Invoke(options);
|
||||
|
||||
string connectionString = GetConnectionString(configuration, options);
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
|
||||
var builderType = GetBuilderType(configuration);
|
||||
var extensionType = GetExtensionType(configuration);
|
||||
var actionType = typeof(Action<>).MakeGenericType(builderType);
|
||||
|
||||
object serverVersion = null;
|
||||
MethodInfo methodInfo;
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
if (methodInfo == null)
|
||||
methodInfo = extensionType.GetMethod("UseMySQL", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
if (methodInfo == null) // Pomelo MySQL v5
|
||||
{
|
||||
var serverVersionType = Type.GetType("Microsoft.EntityFrameworkCore.ServerVersion, Pomelo.EntityFrameworkCore.MySql");
|
||||
var autoDetectMethodInfo = serverVersionType.GetMethod("AutoDetect", new Type[] { typeof(string) });
|
||||
methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType });
|
||||
serverVersion = autoDetectMethodInfo.Invoke(null, new object[] { connectionString });
|
||||
}
|
||||
break;
|
||||
case "oracle":
|
||||
methodInfo = extensionType.GetMethod("UseOracle", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
methodInfo = extensionType.GetMethod("UseNpgsql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
case "sqlite":
|
||||
methodInfo = extensionType.GetMethod("UseSqlite", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
methodInfo = extensionType.GetMethod("UseSqlServer", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
|
||||
break;
|
||||
default:
|
||||
throw new DatabaseProviderException($"Unknown database provider: {provider}");
|
||||
}
|
||||
|
||||
if (serverVersion == null)
|
||||
{
|
||||
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, null });
|
||||
}
|
||||
else
|
||||
{
|
||||
methodInfo?.Invoke(null, new object[] { optionsBuilder, connectionString, serverVersion, null });
|
||||
}
|
||||
|
||||
return optionsBuilder;
|
||||
}
|
||||
|
||||
private static Type GetBuilderType(IConfiguration configuration)
|
||||
{
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
Type builderType;
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.MySqlDbContextOptionsBuilder, Pomelo.EntityFrameworkCore.MySql");
|
||||
if (builderType == null)
|
||||
builderType = Type.GetType("MySql.Data.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.Data.EntityFrameworkCore");
|
||||
break;
|
||||
case "oracle":
|
||||
builderType = Type.GetType("Oracle.EntityFrameworkCore.Infrastructure.OracleDbContextOptionsBuilder, Oracle.EntityFrameworkCore");
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
builderType = Type.GetType("Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.NpgsqlDbContextOptionsBuilder, Npgsql.EntityFrameworkCore.PostgreSQL");
|
||||
break;
|
||||
case "sqlite":
|
||||
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqliteDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.Sqlite");
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
builderType = Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqlServerDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.SqlServer");
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException($"Unknown database provider: {provider}");
|
||||
}
|
||||
return builderType;
|
||||
}
|
||||
|
||||
private static Type GetExtensionType(IConfiguration configuration)
|
||||
{
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
Type extensionType;
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySqlDbContextOptionsBuilderExtensions, Pomelo.EntityFrameworkCore.MySql");
|
||||
if (extensionType == null)
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsBuilderExtensions, MySql.Data.EntityFrameworkCore");
|
||||
break;
|
||||
case "oracle":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.OracleDbContextOptionsBuilderExtensions, Oracle.EntityFrameworkCore");
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.NpgsqlDbContextOptionsBuilderExtensions, Npgsql.EntityFrameworkCore.PostgreSQL");
|
||||
break;
|
||||
case "sqlite":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqliteDbContextOptionsBuilderExtensions, Microsoft.EntityFrameworkCore.Sqlite");
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
extensionType = Type.GetType("Microsoft.EntityFrameworkCore.SqlServerDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.SqlServer");
|
||||
break;
|
||||
default:
|
||||
throw new ArgumentException($"Unknown database provider: {provider}");
|
||||
}
|
||||
return extensionType;
|
||||
}
|
||||
|
||||
private static string GetConnectionString(IConfiguration configuration, DatabaseProviderOptions options)
|
||||
{
|
||||
var cs = new List<string>();
|
||||
string provider = configuration.GetValue<string>("provider")?.ToLower();
|
||||
switch (provider)
|
||||
{
|
||||
case "mysql":
|
||||
cs.Add($"Server={configuration.GetValue<string>("Host")}");
|
||||
cs.Add($"Port={configuration.GetValue("Port", 3306)}");
|
||||
cs.Add($"Database={configuration.GetValue<string>("Name")}");
|
||||
cs.Add($"Uid={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add($"Connection Timeout=15");
|
||||
break;
|
||||
case "oracle":
|
||||
cs.Add($"Data Source=(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={configuration.GetValue<string>("Host")})(PORT={configuration.GetValue("Port", 1521)}))(CONNECT_DATA=(SERVICE_NAME={configuration.GetValue<string>("Name")})))");
|
||||
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add($"Connection Timeout=15");
|
||||
break;
|
||||
case "postgres":
|
||||
case "postgresql":
|
||||
cs.Add($"Server={configuration.GetValue<string>("Host")}");
|
||||
cs.Add($"Port={configuration.GetValue("Port", 5432)}");
|
||||
cs.Add($"Database={configuration.GetValue<string>("Name")}");
|
||||
cs.Add($"Search Path={configuration.GetValue("Schema", "public")}");
|
||||
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add($"Timeout=15");
|
||||
break;
|
||||
case "sqlite":
|
||||
string path = configuration.GetValue<string>("File");
|
||||
if (!Path.IsPathRooted(path))
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(options.AbsoluteBasePath))
|
||||
options.AbsoluteBasePath = Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location);
|
||||
|
||||
path = Path.Combine(options.AbsoluteBasePath, path);
|
||||
}
|
||||
cs.Add($"Data Source={path}");
|
||||
cs.Add("Foreign Keys=True");
|
||||
break;
|
||||
case "sqlserver":
|
||||
case "mssql":
|
||||
cs.Add($"Server={configuration.GetValue<string>("Host")},{configuration.GetValue("Port", 1433)}");
|
||||
cs.Add($"Database={configuration.GetValue<string>("Name")}");
|
||||
if (!string.IsNullOrWhiteSpace(configuration.GetValue<string>("Username")))
|
||||
{
|
||||
cs.Add($"User Id={configuration.GetValue<string>("Username")}");
|
||||
cs.Add($"Password={configuration.GetValue<string>("Password")}");
|
||||
cs.Add("Integrated Security=False");
|
||||
}
|
||||
else
|
||||
{
|
||||
cs.Add("Integrated Security=True");
|
||||
}
|
||||
cs.Add("Connect Timeout=15");
|
||||
break;
|
||||
default:
|
||||
throw new DatabaseProviderException($"Unknown database provider: {provider}");
|
||||
}
|
||||
return string.Join(";", cs);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
using System.ComponentModel.DataAnnotations.Schema;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using AMWD.Common.EntityFrameworkCore.Attributes;
|
||||
using Microsoft.EntityFrameworkCore;
|
||||
#if NET5_0_OR_GREATER
|
||||
using Microsoft.EntityFrameworkCore.Metadata;
|
||||
#endif
|
||||
|
||||
namespace AMWD.Common.EntityFrameworkCore.Extensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Extensions for the <see cref="ModelBuilder"/> of entity framework core.
|
||||
/// </summary>
|
||||
public static class ModelBuilderExtensions
|
||||
{
|
||||
/// <summary>
|
||||
/// Applies indices and unique constraints to the properties.
|
||||
/// </summary>
|
||||
/// <param name="builder">The database model builder.</param>
|
||||
/// <returns>A reference to this instance after the operation has completed.</returns>
|
||||
[System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0019", Justification = "No pattern comparison in this case due to readability.")]
|
||||
public static ModelBuilder ApplyIndexAttributes(this ModelBuilder builder)
|
||||
{
|
||||
foreach (var entityType in builder.Model.GetEntityTypes())
|
||||
{
|
||||
foreach (var property in entityType.GetProperties())
|
||||
{
|
||||
var indexAttribute = entityType.ClrType
|
||||
.GetProperty(property.Name)
|
||||
?.GetCustomAttribute(typeof(DatabaseIndexAttribute), false) as DatabaseIndexAttribute;
|
||||
if (indexAttribute != null)
|
||||
{
|
||||
var index = entityType.AddIndex(property);
|
||||
index.IsUnique = indexAttribute.IsUnique;
|
||||
|
||||
if (!string.IsNullOrWhiteSpace(indexAttribute.Name))
|
||||
{
|
||||
#if NET5_0_OR_GREATER
|
||||
index.SetDatabaseName(indexAttribute.Name.Trim());
|
||||
#else
|
||||
index.SetName(indexAttribute.Name.Trim());
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts all table and column names to snake_case_names.
|
||||
/// </summary>
|
||||
/// <param name="builder">The database model builder.</param>
|
||||
/// <returns>A reference to this instance after the operation has completed.</returns>
|
||||
public static ModelBuilder ApplySnakeCase(this ModelBuilder builder)
|
||||
{
|
||||
foreach (var entityType in builder.Model.GetEntityTypes())
|
||||
{
|
||||
// skip conversion when table name is explicitly set
|
||||
if ((entityType.ClrType.GetCustomAttribute(typeof(TableAttribute), false) as TableAttribute) == null)
|
||||
{
|
||||
#if NET5_0_OR_GREATER
|
||||
entityType.SetTableName(ConvertToSnakeCase(entityType.GetTableName()));
|
||||
#else
|
||||
entityType.SetTableName(ConvertToSnakeCase(entityType.GetTableName()));
|
||||
#endif
|
||||
}
|
||||
|
||||
#if NET5_0_OR_GREATER
|
||||
var identifier = StoreObjectIdentifier.Table(entityType.GetTableName(), entityType.GetSchema());
|
||||
#endif
|
||||
foreach (var property in entityType.GetProperties())
|
||||
{
|
||||
// skip conversion when column name is explicitly set
|
||||
if ((entityType.ClrType.GetProperty(property.Name)?.GetCustomAttribute(typeof(ColumnAttribute), false) as ColumnAttribute) == null)
|
||||
{
|
||||
#if NET5_0_OR_GREATER
|
||||
property.SetColumnName(ConvertToSnakeCase(property.GetColumnName(identifier)));
|
||||
#else
|
||||
property.SetColumnName(ConvertToSnakeCase(property.GetColumnName()));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Converts a string to its snake_case equivalent.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// Code borrowed from Npgsql.NameTranslation.NpgsqlSnakeCaseNameTranslator.
|
||||
/// See https://github.com/npgsql/npgsql/blob/f2b2c98f45df6d2a78eec00ae867f18944d717ca/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs#L76-L136.
|
||||
/// </remarks>
|
||||
/// <param name="value">The value to convert.</param>
|
||||
private static string ConvertToSnakeCase(string value)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
var state = SnakeCaseState.Start;
|
||||
|
||||
for (int i = 0; i < value.Length; i++)
|
||||
{
|
||||
if (value[i] == ' ')
|
||||
{
|
||||
if (state != SnakeCaseState.Start)
|
||||
state = SnakeCaseState.NewWord;
|
||||
}
|
||||
else if (char.IsUpper(value[i]))
|
||||
{
|
||||
switch (state)
|
||||
{
|
||||
case SnakeCaseState.Upper:
|
||||
bool hasNext = (i + 1 < value.Length);
|
||||
if (i > 0 && hasNext)
|
||||
{
|
||||
char nextChar = value[i + 1];
|
||||
if (!char.IsUpper(nextChar) && nextChar != '_')
|
||||
{
|
||||
sb.Append('_');
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
||||
case SnakeCaseState.Lower:
|
||||
case SnakeCaseState.NewWord:
|
||||
sb.Append('_');
|
||||
break;
|
||||
}
|
||||
|
||||
sb.Append(char.ToLowerInvariant(value[i]));
|
||||
state = SnakeCaseState.Upper;
|
||||
}
|
||||
else if (value[i] == '_')
|
||||
{
|
||||
sb.Append('_');
|
||||
state = SnakeCaseState.Start;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (state == SnakeCaseState.NewWord)
|
||||
sb.Append('_');
|
||||
|
||||
sb.Append(value[i]);
|
||||
state = SnakeCaseState.Lower;
|
||||
}
|
||||
}
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private enum SnakeCaseState
|
||||
{
|
||||
Start,
|
||||
Lower,
|
||||
Upper,
|
||||
NewWord
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user