using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.Extensions.Logging;

namespace Microsoft.EntityFrameworkCore
{
    /// <summary>
    /// Extensions for the <see cref="DatabaseFacade"/>.
    /// </summary>
    public static class DatabaseFacadeExtensions
    {
        /// <summary>
        /// Applies migration files to the database.
        /// </summary>
        /// <remarks>
        /// Returns <c>true</c> immediately for the in-memory provider. Otherwise waits for the
        /// database to accept connections, makes sure the migrations table exists and then runs
        /// every script that is not yet recorded in that table.
        /// </remarks>
        /// <param name="database">The database connection.</param>
        /// <param name="optionsAction">An action to set additional options.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns><c>true</c> on success, otherwise <c>false</c> or an exception is thrown.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="database"/> is <c>null</c>, or the configured migrations table name or path is empty.
        /// </exception>
        [System.Diagnostics.CodeAnalysis.SuppressMessage("Usage", "CA2208")]
        public static async Task<bool> ApplyMigrationsAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction, CancellationToken cancellationToken = default)
        {
            if (database == null)
                throw new ArgumentNullException(nameof(database));

            if (database.GetProviderType() == DatabaseProvider.InMemory)
                return true;

            var options = new DatabaseMigrationOptions();
            optionsAction?.Invoke(options);

            if (string.IsNullOrWhiteSpace(options.MigrationsTableName))
                throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required.");

            if (string.IsNullOrWhiteSpace(options.Path))
                throw new ArgumentNullException(nameof(options.Path), $"The property {nameof(options.Path)} of the {nameof(options)} parameter is required.");

            // Forward only the settings the wait loop reads.
            await database.WaitAvailableAsync(opts =>
            {
                opts.WaitDelay = options.WaitDelay;
                opts.Logger = options.Logger;
            }, cancellationToken).ConfigureAwait(false);

            var connection = database.GetDbConnection();
            try
            {
                await connection.OpenAsync(cancellationToken).ConfigureAwait(false);

                if (!await connection.CreateMigrationsTable(options, cancellationToken).ConfigureAwait(false))
                    return false;

                return await connection.Migrate(options, cancellationToken).ConfigureAwait(false);
            }
            finally
            {
                await connection.CloseAsync().ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Waits until the database connection is available.
        /// </summary>
        /// <remarks>
        /// Repeatedly tries to open the connection, pausing for the configured wait delay after
        /// each failed attempt. The connection is closed again after every attempt, including the
        /// successful one. When cancellation is requested the method returns without throwing.
        /// </remarks>
        /// <param name="database">The database connection.</param>
        /// <param name="optionsAction">An action to set additional options.</param>
        /// <param name="cancellationToken">A cancellation token.</param>
        /// <returns>An awaitable task to wait until the database is available.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="database"/> is <c>null</c>.</exception>
        public static async Task WaitAvailableAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction = null, CancellationToken cancellationToken = default)
        {
            if (database == null)
                throw new ArgumentNullException(nameof(database));

            if (database.GetProviderType() == DatabaseProvider.InMemory)
                return;

            var options = new DatabaseMigrationOptions();
            optionsAction?.Invoke(options);

            options.Logger?.LogInformation("Waiting for a database connection");

            var connection = database.GetDbConnection();
            while (!cancellationToken.IsCancellationRequested)
            {
                try
                {
                    await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
                    options.Logger?.LogInformation("Database connection available");
                    return;
                }
                catch
                {
                    // keep things quiet: a failed attempt is expected while the database starts up
                    try
                    {
                        await Task.Delay(options.WaitDelay, cancellationToken).ConfigureAwait(false);
                    }
                    catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
                    {
                        return;
                    }
                    catch
                    {
                        // keep things quiet: any other failure while waiting is ignored
                        // and the next attempt starts right away
                    }
                }
                finally
                {
                    await connection.CloseAsync().ConfigureAwait(false);
                }
            }
        }

        /// <summary>
        /// Determines the database provider from the provider name of the facade.
        /// </summary>
        internal static DatabaseProvider GetProviderType(this DatabaseFacade database)
            => GetProviderType(database.ProviderName);

        /// <summary>
        /// Determines the database provider from the full type name of the connection.
        /// </summary>
        private static DatabaseProvider GetProviderType(this DbConnection connection)
            => GetProviderType(connection.GetType().FullName);

        /// <summary>
        /// Maps a provider or type name to a <see cref="DatabaseProvider"/> by a
        /// case-insensitive substring match. The checks run in a fixed order and the first hit wins.
        /// </summary>
        /// <exception cref="DatabaseProviderException">No known provider keyword is contained in the name.</exception>
        private static DatabaseProvider GetProviderType(string provider)
        {
            if (provider.Contains("mysql", StringComparison.OrdinalIgnoreCase))
                return DatabaseProvider.MySQL;

            if (provider.Contains("oracle", StringComparison.OrdinalIgnoreCase))
                return DatabaseProvider.Oracle;

            if (provider.Contains("npgsql", StringComparison.OrdinalIgnoreCase))
                return DatabaseProvider.PostgreSQL;

            if (provider.Contains("sqlite", StringComparison.OrdinalIgnoreCase))
                return DatabaseProvider.SQLite;

            if (provider.Contains("sqlclient", StringComparison.OrdinalIgnoreCase) || provider.Contains("sqlserver", StringComparison.OrdinalIgnoreCase))
                return DatabaseProvider.SQLServer;

            if (provider.Contains("inmemory", StringComparison.OrdinalIgnoreCase))
                return DatabaseProvider.InMemory;

            throw new DatabaseProviderException($"The database provider '{provider}' is unknown");
        }

        /// <summary>
        /// Creates the migrations table if it does not exist yet.
        /// </summary>
        /// <remarks>
        /// Uses a provider specific statement. Any failure (including a provider without a
        /// matching statement) is logged as critical and reported as <c>false</c>.
        /// </remarks>
        /// <returns><c>true</c> when the statement ran without error, otherwise <c>false</c>.</returns>
        private static async Task<bool> CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
        {
            try
            {
                using var command = connection.CreateCommand();

#pragma warning disable CS8509 // ignore missing cases
                command.CommandText = connection.GetProviderType() switch
#pragma warning restore CS8509 // ignore missing cases
                {
                    DatabaseProvider.MySQL => $@"CREATE TABLE IF NOT EXISTS `{options.MigrationsTableName}` ( `id` INT NOT NULL AUTO_INCREMENT, `schema_file` VARCHAR(250) NOT NULL, `installed_at` VARCHAR(16) NOT NULL, PRIMARY KEY (`id`) );",
                    DatabaseProvider.Oracle => $@"DECLARE ncount NUMBER; BEGIN SELECT count(*) INTO ncount FROM dba_tables WHERE table_name = '{options.MigrationsTableName}'; IF (ncount <= 0) THEN EXECUTE IMMEDIATE 'CREATE TABLE ""{options.MigrationsTableName}"" ( ""id"" NUMBER GENERATED by default on null as IDENTITY, ""schema_file"" VARCHAR2(250) NOT NULL, ""installed_at"" VARCHAR2(16) NOT NULL, PRIMARY KEY (""id""), CONSTRAINT uq_schema_file UNIQUE (""schema_file"") )'; END IF; END;",
                    DatabaseProvider.PostgreSQL => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" ( ""id"" SERIAL4 PRIMARY KEY, ""schema_file"" VARCHAR(250) NOT NULL, ""installed_at"" VARCHAR(16) NOT NULL, CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"") );",
                    DatabaseProvider.SQLite => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" ( ""id"" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, ""schema_file"" TEXT(250) NOT NULL, ""installed_at"" TEXT(16) NOT NULL, CONSTRAINT ""uq_schema_file"" UNIQUE (""schema_file"") );",
                    DatabaseProvider.SQLServer => $@"IF NOT EXISTS (SELECT * FROM [sysobjects] WHERE [name] = '{options.MigrationsTableName}' AND [xtype] = 'U') BEGIN CREATE TABLE [{options.MigrationsTableName}] ( [id] int IDENTITY(1,1) NOT NULL PRIMARY KEY, [schema_file] varchar(250) NOT NULL, [installed_at] varchar(16) NOT NULL, CONSTRAINT uq_schema_file UNIQUE (schema_file) ) END;"
                };

                await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
                return true;
            }
            catch (Exception ex)
            {
                options.Logger?.LogCritical(ex, $"Creating migrations table '{options.MigrationsTableName}' failed: {ex.InnerException?.Message ?? ex.Message}");
                return false;
            }
        }

        /// <summary>
        /// Runs every migration script that is not yet recorded in the migrations table.
        /// </summary>
        /// <remarks>
        /// Scripts are taken from the directory <c>options.Path</c>, or from the manifest resources
        /// of <c>options.SourceAssembly</c> when one is set. Each script is executed together with
        /// its bookkeeping INSERT inside its own transaction. Processing stops at the first failure.
        /// </remarks>
        /// <returns><c>true</c> when all pending scripts were applied (or none exist), otherwise <c>false</c>.</returns>
        private static async Task<bool> Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
        {
            try
            {
                IEnumerable<string> candidates;
                if (options.SourceAssembly == null)
                {
                    // The file system does not guarantee any order, so sort the files by name
                    // to apply the scripts in a reproducible sequence.
                    candidates = Directory.GetFiles(options.Path)
                        .OrderBy(f => f, StringComparer.OrdinalIgnoreCase)
                        .ThenBy(f => f, StringComparer.Ordinal);
                }
                else
                {
                    candidates = options.SourceAssembly.GetManifestResourceNames();
                }

                List<string> availableMigrationFiles = candidates
                    .Where(f => f.StartsWith(options.Path, StringComparison.OrdinalIgnoreCase))
                    .Where(f => f.EndsWith(".sql", StringComparison.OrdinalIgnoreCase))
                    .ToList();

                if (!availableMigrationFiles.Any())
                    return true;

                DatabaseProvider provider = connection.GetProviderType();

                using var command = connection.CreateCommand();
                command.CommandText = provider switch
                {
                    DatabaseProvider.MySQL => $"SELECT `schema_file` FROM `{options.MigrationsTableName}`;",
                    DatabaseProvider.SQLServer => $"SELECT [schema_file] FROM [{options.MigrationsTableName}];",
                    // Oracle rejects a trailing semicolon on a plain SQL statement
                    DatabaseProvider.Oracle => $@"SELECT ""schema_file"" FROM ""{options.MigrationsTableName}""",
                    _ => $@"SELECT ""schema_file"" FROM ""{options.MigrationsTableName}"";",
                };

                var migratedFiles = new HashSet<string>();
                using (var reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false))
                {
                    while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
                        migratedFiles.Add(reader.GetString(0));
                }

                foreach (string migrationFile in availableMigrationFiles)
                {
                    // remove path including the separator
                    // NOTE(review): this derivation is kept unchanged on purpose, because the result
                    // is what gets recorded in the migrations table; altering it could make
                    // already recorded scripts look pending again.
                    string fileName = migrationFile.Replace(options.Path, "")[1..];

                    // max length in the database: 250 chars
                    string trimmedFileName = fileName.Length > 250 ? fileName[..250] : fileName;

                    // Already applied: nothing to do, so no transaction is needed.
                    if (migratedFiles.Contains(trimmedFileName))
                    {
                        options.Logger?.LogDebug($" Migrating file '{fileName}' done");
                        continue;
                    }

                    using var transaction = await connection.BeginTransactionAsync(cancellationToken).ConfigureAwait(false);
                    try
                    {
                        string sqlScript = await ReadMigrationScript(options, migrationFile, cancellationToken).ConfigureAwait(false);
                        if (string.IsNullOrWhiteSpace(sqlScript))
                            continue;

                        options.Logger?.LogDebug($" Migrating file '{fileName}' started");

                        command.Transaction = transaction;
                        await command.ExecuteScript(sqlScript, cancellationToken).ConfigureAwait(false);

                        await RecordMigration(connection, transaction, provider, options, trimmedFileName, cancellationToken).ConfigureAwait(false);

                        await transaction.CommitAsync(cancellationToken).ConfigureAwait(false);
                        command.Transaction = null;

                        options.Logger?.LogDebug($" Migrating file '{fileName}' successful");
                    }
                    catch (Exception ex)
                    {
                        options.Logger?.LogError(ex, $"Migrating file '{fileName}' failed: {ex.InnerException?.Message ?? ex.Message}");

                        try
                        {
                            // Do not pass the caller's token: if the failure was a cancellation,
                            // the rollback itself would be cancelled right away.
                            await transaction.RollbackAsync(CancellationToken.None).ConfigureAwait(false);
                        }
                        catch (Exception rollbackEx)
                        {
                            // The transaction is disposed at the end of this scope anyway;
                            // keep the original failure as the reported one.
                            options.Logger?.LogWarning(rollbackEx, $"Rolling back the migration of file '{fileName}' failed: {rollbackEx.Message}");
                        }

                        return false;
                    }
                }

                return true;
            }
            catch (Exception ex)
            {
                options.Logger?.LogCritical(ex, $"Migrating the database failed ({ex.GetType().Name}): {ex.InnerException?.Message ?? ex.Message}");
                return false;
            }
        }

        /// <summary>
        /// Reads the text of a migration script, either from disk or, when a source assembly
        /// is configured, from the manifest resource with the given name.
        /// </summary>
        private static async Task<string> ReadMigrationScript(DatabaseMigrationOptions options, string migrationFile, CancellationToken cancellationToken)
        {
            if (options.SourceAssembly == null)
                return await File.ReadAllTextAsync(migrationFile, cancellationToken).ConfigureAwait(false);

            using var stream = options.SourceAssembly.GetManifestResourceStream(migrationFile);
            using var sr = new StreamReader(stream);
            return await sr.ReadToEndAsync().ConfigureAwait(false);
        }

        /// <summary>
        /// Inserts the bookkeeping row for an applied script into the migrations table.
        /// </summary>
        /// <remarks>
        /// The values are passed as parameters on a dedicated command, so quotes or backslashes
        /// in a file name cannot break the statement and no parameter is left behind on the
        /// command that executes the scripts.
        /// </remarks>
        private static async Task RecordMigration(DbConnection connection, DbTransaction transaction, DatabaseProvider provider, DatabaseMigrationOptions options, string schemaFile, CancellationToken cancellationToken)
        {
            using var insertCommand = connection.CreateCommand();
            insertCommand.Transaction = transaction;
            insertCommand.CommandText = provider switch
            {
                DatabaseProvider.MySQL => $"INSERT INTO `{options.MigrationsTableName}` (`schema_file`, `installed_at`) VALUES (@schema_file, @installed_at);",
                DatabaseProvider.SQLServer => $"INSERT INTO [{options.MigrationsTableName}] ([schema_file], [installed_at]) VALUES (@schema_file, @installed_at);",
                // Oracle uses a colon as parameter marker and rejects a trailing semicolon
                DatabaseProvider.Oracle => $@"INSERT INTO ""{options.MigrationsTableName}"" (""schema_file"", ""installed_at"") VALUES (:schema_file, :installed_at)",
                _ => $@"INSERT INTO ""{options.MigrationsTableName}"" (""schema_file"", ""installed_at"") VALUES (@schema_file, @installed_at);",
            };

            // Fixed, culture independent format that fits the 16 character column.
            string installedAt = DateTime.UtcNow.ToString("yyyy-MM-dd HH:mm", System.Globalization.CultureInfo.InvariantCulture);

            // Parameters are added in the order of their markers in the statement.
            string prefix = provider == DatabaseProvider.Oracle ? "" : "@";
            AddParameter(insertCommand, prefix + "schema_file", schemaFile);
            AddParameter(insertCommand, prefix + "installed_at", installedAt);

            await insertCommand.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Creates a provider specific parameter with the given name and value and adds it to the command.
        /// </summary>
        private static void AddParameter(DbCommand command, string name, object value)
        {
            DbParameter parameter = command.CreateParameter();
            parameter.ParameterName = name;
            parameter.Value = value;
            command.Parameters.Add(parameter);
        }

        private static async Task ExecuteScript(this DbCommand command, string text, CancellationToken cancellationToken)
        {
            if (command.Connection.GetProviderType() == DatabaseProvider.Oracle)
            {
                int affectedRows = 0;

                // Split script by a single slash in a line
                string[] parts = Regex.Split(text, @"\r?\n[ \t]*/[ \t]*\r?\n");
                foreach (string part in parts)
                {
                    // Make writable copy
                    string pt = part;

                    // Remove the trailing semicolon from commands where they're not supported
                    // (Oracle doesn't like semicolons. To keep the semicolon, it must be directly
                    // preceeded by "end".)
                    pt = Regex.Replace(pt.TrimEnd(), @"(?