using System;
using System.Collections.Generic;
using System.Data.Common;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.Extensions.Logging;
namespace Microsoft.EntityFrameworkCore
{
/// <summary>
/// Extensions for the <see cref="DatabaseFacade"/>.
/// </summary>
public static class DatabaseFacadeExtensions
{
/// <summary>
/// Applies migration files to the database.
/// </summary>
/// <param name="database">The database connection.</param>
/// <param name="optionsAction">An action to set additional options.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns><c>true</c> on success, otherwise <c>false</c> or an exception is thrown.</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="database"/> is <c>null</c>, or the configured
/// <see cref="DatabaseMigrationOptions.MigrationsTableName"/> or
/// <see cref="DatabaseMigrationOptions.Path"/> is null, empty or whitespace.
/// </exception>
/// <remarks>
/// If the underlying connection is already open it is used as-is and left open; otherwise it is
/// opened for the duration of the migration and closed again afterwards.
/// </remarks>
public static async Task<bool> ApplyMigrationsAsync(this DatabaseFacade database, Action<DatabaseMigrationOptions> optionsAction, CancellationToken cancellationToken = default)
{
    if (database == null)
        throw new ArgumentNullException(nameof(database));

    var options = new DatabaseMigrationOptions();
    optionsAction?.Invoke(options);

    if (string.IsNullOrWhiteSpace(options.MigrationsTableName))
        throw new ArgumentNullException(nameof(options.MigrationsTableName), $"The property {nameof(options.MigrationsTableName)} of the {nameof(options)} parameter is required.");
    if (string.IsNullOrWhiteSpace(options.Path))
        throw new ArgumentNullException(nameof(options.Path), $"The property {nameof(options.Path)} of the {nameof(options)} parameter is required.");

    var connection = database.GetDbConnection();

    // Only manage the connection's lifetime when this method opens it. Opening an already-open
    // connection throws for common providers, and closing it would take it away from its owner
    // (e.g. EF Core or the calling code).
    bool closeWhenDone = connection.State != System.Data.ConnectionState.Open;
    try
    {
        if (closeWhenDone)
            await connection.OpenAsync(cancellationToken);

        if (!await connection.CreateMigrationsTable(options, cancellationToken))
            return false;

        return await connection.Migrate(options, cancellationToken);
    }
    finally
    {
        if (closeWhenDone)
            await connection.CloseAsync();
    }
}
/// <summary>
/// Identifies the database engine behind a connection by looking for well-known fragments
/// in the full type name of the connection class (case-insensitive).
/// </summary>
/// <param name="connection">The ADO.NET connection to inspect.</param>
/// <returns>The first <see cref="DatabaseProvider"/> whose fragment occurs in the type name.</returns>
/// <exception cref="DatabaseProviderException">None of the known fragments occurs in the type name.</exception>
private static DatabaseProvider GetProviderType(this DbConnection connection)
{
    // Evaluated top to bottom; the first fragment found in the type name decides the provider.
    var knownProviders = new (string Fragment, DatabaseProvider Provider)[]
    {
        ("mysql", DatabaseProvider.MySQL),
        ("oracle", DatabaseProvider.Oracle),
        ("npgsql", DatabaseProvider.PostgreSQL),
        ("sqlite", DatabaseProvider.SQLite),
        ("sqlclient", DatabaseProvider.SQLServer),
    };

    string typeName = connection.GetType().FullName;

    foreach (var (fragment, provider) in knownProviders)
    {
        if (typeName.Contains(fragment, StringComparison.OrdinalIgnoreCase))
            return provider;
    }

    throw new DatabaseProviderException($"The database provider '{typeName}' is unknown");
}
/// <summary>
/// Creates the migrations bookkeeping table (<c>id</c>, <c>schema_file</c>, <c>installed_at</c>)
/// unless it already exists.
/// </summary>
/// <param name="connection">The connection to run the statement on; expected to be open already.</param>
/// <param name="options">Supplies the table name and the optional logger.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>
/// <c>true</c> when the statement executed; <c>false</c> when it failed, in which case the
/// error is logged as critical.
/// </returns>
private static async Task<bool> CreateMigrationsTable(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{
    try
    {
        using var command = connection.CreateCommand();

        // The UNIQUE constraints are deliberately unnamed. Constraint (and backing index) names must
        // be unique per schema in Oracle, PostgreSQL and SQL Server, so a fixed name such as
        // "uq_schema_file" made creating a second migrations table in the same schema fail.
#pragma warning disable CS8524 // missing default case
        command.CommandText = connection.GetProviderType() switch
#pragma warning restore CS8524 // missing default case
        {
            DatabaseProvider.MySQL => $@"CREATE TABLE IF NOT EXISTS `{options.MigrationsTableName}` (
    `id` INT NOT NULL AUTO_INCREMENT,
    `schema_file` VARCHAR(250) NOT NULL,
    `installed_at` VARCHAR(16) NOT NULL,
    PRIMARY KEY (`id`)
);",
            // Look the table up in the session's current schema, which is where CREATE TABLE puts it.
            // (dba_tables requires DBA privileges and also matched same-named tables in other schemas.)
            DatabaseProvider.Oracle => $@"DECLARE ncount NUMBER;
BEGIN
    SELECT count(*) INTO ncount FROM all_tables WHERE owner = SYS_CONTEXT('USERENV', 'CURRENT_SCHEMA') AND table_name = '{options.MigrationsTableName}';
    IF (ncount <= 0)
    THEN
        EXECUTE IMMEDIATE 'CREATE TABLE ""{options.MigrationsTableName}"" (
            ""id"" NUMBER GENERATED by default on null as IDENTITY,
            ""schema_file"" VARCHAR2(250) NOT NULL,
            ""installed_at"" VARCHAR2(16) NOT NULL,
            PRIMARY KEY (""id""),
            UNIQUE (""schema_file"")
        )';
    END IF;
END;",
            DatabaseProvider.PostgreSQL => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
    ""id"" SERIAL4 PRIMARY KEY,
    ""schema_file"" VARCHAR(250) NOT NULL,
    ""installed_at"" VARCHAR(16) NOT NULL,
    UNIQUE (""schema_file"")
);",
            DatabaseProvider.SQLite => $@"CREATE TABLE IF NOT EXISTS ""{options.MigrationsTableName}"" (
    ""id"" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
    ""schema_file"" TEXT(250) NOT NULL,
    ""installed_at"" TEXT(16) NOT NULL,
    UNIQUE (""schema_file"")
);",
            DatabaseProvider.SQLServer => $@"IF NOT EXISTS (SELECT * FROM [sysobjects] WHERE [name] = '{options.MigrationsTableName}' AND [xtype] = 'U')
BEGIN
    CREATE TABLE [{options.MigrationsTableName}] (
        [id] int IDENTITY(1,1) NOT NULL PRIMARY KEY,
        [schema_file] varchar(250) NOT NULL,
        [installed_at] varchar(16) NOT NULL,
        UNIQUE ([schema_file])
    )
END;"
        };

        await command.ExecuteNonQueryAsync(cancellationToken);
        return true;
    }
    catch (Exception ex)
    {
        options.Logger?.LogCritical(ex, "Creating migrations table '{TableName}' failed: {Message}", options.MigrationsTableName, ex.InnerException?.Message ?? ex.Message);
        return false;
    }
}
/// <summary>
/// Runs every pending <c>*.sql</c> migration script and records it in the migrations table.
/// </summary>
/// <remarks>
/// Scripts come from the directory <see cref="DatabaseMigrationOptions.Path"/> or, when
/// <see cref="DatabaseMigrationOptions.SourceAssembly"/> is set, from the embedded resources whose
/// names start with that path. They are applied in case-insensitive ordinal name order, and each
/// script is committed together with its bookkeeping row in a single transaction.
/// </remarks>
/// <param name="connection">The connection to migrate; expected to be open already.</param>
/// <param name="options">Table name, script source and optional logger.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>
/// <c>true</c> when every script was applied or had been applied before; <c>false</c> when a step
/// failed (the failing script's transaction is rolled back and the error is logged).
/// </returns>
private static async Task<bool> Migrate(this DbConnection connection, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{
    // Width of the schema_file column of the migrations table.
    const int maxSchemaFileLength = 250;

    try
    {
        string[] candidates = options.SourceAssembly == null
            ? Directory.GetFiles(options.Path)
            : options.SourceAssembly.GetManifestResourceNames();

        // Neither Directory.GetFiles nor GetManifestResourceNames guarantees any order, but
        // migrations have to run in a stable, predictable sequence.
        List<string> availableMigrationFiles = candidates
            .Where(f => f.StartsWith(options.Path, StringComparison.OrdinalIgnoreCase))
            .Where(f => f.EndsWith(".sql", StringComparison.OrdinalIgnoreCase))
            .OrderBy(f => f, StringComparer.OrdinalIgnoreCase)
            .ToList();

        if (availableMigrationFiles.Count == 0)
            return true;

        DatabaseProvider provider = connection.GetProviderType();
        using var command = connection.CreateCommand();
        HashSet<string> migratedFiles = await ReadMigratedFilesAsync(command, provider, options, cancellationToken);

        foreach (string migrationFile in availableMigrationFiles)
        {
            // Remove the path including the separator. The recorded name depends on exactly this
            // computation (case-sensitive Replace of every occurrence); changing it would make
            // already-applied scripts look new and run them again.
            string fileName = migrationFile.Replace(options.Path, "")[1..];

            string trimmedFileName = fileName.Length > maxSchemaFileLength
                ? fileName[..maxSchemaFileLength]
                : fileName;

            // Over-long names used to be inserted untrimmed, and engines that do not enforce the
            // column length (e.g. SQLite) kept them that way, so accept either form as applied.
            if (migratedFiles.Contains(trimmedFileName) || migratedFiles.Contains(fileName))
            {
                options.Logger?.LogDebug(" Migrating file '{FileName}' done", fileName);
                continue;
            }

            if (!await ApplyMigrationFileAsync(connection, command, provider, options, migrationFile, fileName, trimmedFileName, cancellationToken))
                return false;
        }

        return true;
    }
    catch (Exception ex)
    {
        options.Logger?.LogCritical(ex, "Migrating the database failed ({ExceptionType}): {Message}", ex.GetType().Name, ex.InnerException?.Message ?? ex.Message);
        return false;
    }
}

// Loads the names of all scripts already recorded in the migrations table.
private static async Task<HashSet<string>> ReadMigratedFilesAsync(DbCommand command, DatabaseProvider provider, DatabaseMigrationOptions options, CancellationToken cancellationToken)
{
    string table = options.MigrationsTableName;
    command.CommandText = provider switch
    {
        DatabaseProvider.MySQL => $"SELECT `schema_file` FROM `{table}`;",
        DatabaseProvider.SQLServer => $"SELECT [schema_file] FROM [{table}];",
        // Oracle rejects a trailing semicolon on a plain SQL statement.
        DatabaseProvider.Oracle => $@"SELECT ""schema_file"" FROM ""{table}""",
        _ => $@"SELECT ""schema_file"" FROM ""{table}"";",
    };

    var migratedFiles = new HashSet<string>(StringComparer.Ordinal);
    using var reader = await command.ExecuteReaderAsync(cancellationToken);
    while (await reader.ReadAsync(cancellationToken))
        migratedFiles.Add(reader.GetString(0));

    return migratedFiles;
}

// Runs one script and records it inside a single transaction. Empty scripts are skipped without
// being recorded. Returns false after rolling back and logging when any step fails.
private static async Task<bool> ApplyMigrationFileAsync(DbConnection connection, DbCommand command, DatabaseProvider provider, DatabaseMigrationOptions options, string migrationFile, string fileName, string recordedName, CancellationToken cancellationToken)
{
    DbTransaction transaction = null;
    try
    {
        string sqlScript = await ReadMigrationScriptAsync(options, migrationFile, cancellationToken);
        if (string.IsNullOrWhiteSpace(sqlScript))
            return true;

        options.Logger?.LogDebug(" Migrating file '{FileName}' started", fileName);

        // Start the transaction only once there is actually something to execute.
        transaction = await connection.BeginTransactionAsync(cancellationToken);
        command.Transaction = transaction;
        await command.ExecuteScript(sqlScript, cancellationToken);
        await InsertMigrationRecordAsync(connection, transaction, provider, options, recordedName, cancellationToken);
        await transaction.CommitAsync(cancellationToken);

        options.Logger?.LogDebug(" Migrating file '{FileName}' successful", fileName);
        return true;
    }
    catch (Exception ex)
    {
        if (transaction != null)
        {
            try
            {
                // Not cancellable on purpose: an aborted migration must still be rolled back.
                await transaction.RollbackAsync(CancellationToken.None);
            }
            catch (Exception rollbackEx)
            {
                options.Logger?.LogWarning(rollbackEx, "Rolling back file '{FileName}' failed", fileName);
            }
        }

        options.Logger?.LogError(ex, "Migrating file '{FileName}' failed: {Message}", fileName, ex.InnerException?.Message ?? ex.Message);
        return false;
    }
    finally
    {
        command.Transaction = null;
        if (transaction != null)
            await transaction.DisposeAsync();
    }
}

// Reads a migration script from disk, or from the configured assembly's embedded resources.
private static async Task<string> ReadMigrationScriptAsync(DatabaseMigrationOptions options, string migrationFile, CancellationToken cancellationToken)
{
    if (options.SourceAssembly == null)
        return await File.ReadAllTextAsync(migrationFile, cancellationToken);

    using var stream = options.SourceAssembly.GetManifestResourceStream(migrationFile);
    using var reader = new StreamReader(stream);
    return await reader.ReadToEndAsync();
}

// Inserts the bookkeeping row for an applied script. Values are bound as parameters so that
// names containing quotes or backslashes need no dialect-specific escaping.
private static async Task InsertMigrationRecordAsync(DbConnection connection, DbTransaction transaction, DatabaseProvider provider, DatabaseMigrationOptions options, string schemaFile, CancellationToken cancellationToken)
{
    string table = options.MigrationsTableName;

    using var insert = connection.CreateCommand();
    insert.Transaction = transaction;
    insert.CommandText = provider switch
    {
        DatabaseProvider.MySQL => $"INSERT INTO `{table}` (`schema_file`, `installed_at`) VALUES (@schema_file, @installed_at);",
        DatabaseProvider.SQLServer => $"INSERT INTO [{table}] ([schema_file], [installed_at]) VALUES (@schema_file, @installed_at);",
        // Oracle binds with ':' placeholders and rejects a trailing semicolon.
        DatabaseProvider.Oracle => $@"INSERT INTO ""{table}"" (""schema_file"", ""installed_at"") VALUES (:schema_file, :installed_at)",
        _ => $@"INSERT INTO ""{table}"" (""schema_file"", ""installed_at"") VALUES (@schema_file, @installed_at);",
    };

    string prefix = provider == DatabaseProvider.Oracle ? "" : "@";
    AddMigrationParameter(insert, prefix + "schema_file", schemaFile);
    // Invariant culture: in custom format strings ':' is the culture-specific time separator.
    AddMigrationParameter(insert, prefix + "installed_at", DateTime.UtcNow.ToString("yyyy-MM-dd HH:mm", System.Globalization.CultureInfo.InvariantCulture));

    await insert.ExecuteNonQueryAsync(cancellationToken);
}

// Adds a named input parameter with the given value to the command.
private static void AddMigrationParameter(DbCommand command, string name, object value)
{
    DbParameter parameter = command.CreateParameter();
    parameter.ParameterName = name;
    parameter.Value = value;
    command.Parameters.Add(parameter);
}
private static async Task ExecuteScript(this DbCommand command, string text, CancellationToken cancellationToken)
{
if (command.Connection.GetProviderType() == DatabaseProvider.Oracle)
{
int affectedRows = 0;
// Split script by a single slash in a line
string[] parts = Regex.Split(text, @"\r?\n[ \t]*/[ \t]*\r?\n");
foreach (string part in parts)
{
// Make writable copy
string pt = part;
// Remove the trailing semicolon from commands where they're not supported
// (Oracle doesn't like semicolons. To keep the semicolon, it must be directly
// preceeded by "end".)
pt = Regex.Replace(pt.TrimEnd(), @"(?