using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using Microsoft.Extensions.Configuration;

namespace Microsoft.EntityFrameworkCore
{
    /// <summary>
    /// Extends the <see cref="DbContextOptionsBuilder"/> to use a configurable database provider.
    /// </summary>
    /// <remarks>
    /// The concrete provider packages are resolved through reflection at runtime, so this
    /// class has no compile-time dependency on any of them. The package for the configured
    /// provider must be referenced by the application.
    /// </remarks>
    public static class DbContextOptionsBuilderExtensions
    {
        /// <summary>
        /// Adds the supported database provider to the context.
        /// </summary>
        /// <remarks>
        /// The configuration provided requires the following entries:
        /// <list type="bullet">
        /// <item><description>Provider: MySQL | Oracle | PostgreSQL | SQLite | SQLServer</description></item>
        /// <item><description>Host: hostname or IP address</description></item>
        /// <item><description>Port: port number</description></item>
        /// <item><description>Name: database name</description></item>
        /// <item><description>Schema: schema or search path (e.g. PostgreSQL: public)</description></item>
        /// <item><description>Username: username credential on the database</description></item>
        /// <item><description>Password: password credential on the database</description></item>
        /// <item><description>File: file name / path (for SQLite)</description></item>
        /// </list>
        /// The aliases "postgres", "mssql", "memory" and "inmemory" are accepted as provider names too.
        /// </remarks>
        /// <param name="optionsBuilder">The options builder.</param>
        /// <param name="configuration">The application configuration section for the database.</param>
        /// <param name="optionsAction">An optional action to set additional options.</param>
        /// <returns>The <see cref="DbContextOptionsBuilder"/> with applied settings.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="optionsBuilder"/> or <paramref name="configuration"/> is <c>null</c>.
        /// </exception>
        /// <exception cref="DatabaseProviderException">
        /// The provider is unknown, its package could not be loaded, or no matching
        /// configuration method was found in the package.
        /// </exception>
        public static DbContextOptionsBuilder UseDatabaseProvider(this DbContextOptionsBuilder optionsBuilder, IConfiguration configuration, Action<DatabaseProviderOptions> optionsAction = null)
        {
            if (optionsBuilder == null)
            {
                throw new ArgumentNullException(nameof(optionsBuilder));
            }

            if (configuration == null)
            {
                throw new ArgumentNullException(nameof(configuration));
            }

            var options = new DatabaseProviderOptions();
            optionsAction?.Invoke(options);

            // Building the connection string first also rejects unknown providers
            // before any reflection work is done.
            string connectionString = GetConnectionString(configuration, options);
            string provider = GetProviderName(configuration);

            Type builderType = GetBuilderType(configuration);
            Type extensionType = GetExtensionType(configuration);
            if (builderType == null || extensionType == null)
            {
                // Type.GetType returns null when the provider assembly is not available.
                throw new DatabaseProviderException($"The Entity Framework Core package for the database provider '{provider}' could not be loaded.");
            }

            Type actionType = typeof(Action<>).MakeGenericType(builderType);
            var defaultSignature = new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType };

            object serverVersion = null;
            MethodInfo methodInfo;

            switch (provider)
            {
                case "mysql":
                    // The method is spelled "UseMySql" or "UseMySQL" depending on the package.
                    methodInfo = extensionType.GetMethod("UseMySql", defaultSignature)
                        ?? extensionType.GetMethod("UseMySQL", defaultSignature);

                    if (methodInfo == null) // Pomelo MySQL v5
                    {
                        // This overload additionally expects a ServerVersion argument,
                        // which is detected by connecting to the server.
                        Type serverVersionType = Type.GetType("Microsoft.EntityFrameworkCore.ServerVersion, Pomelo.EntityFrameworkCore.MySql");
                        MethodInfo autoDetectMethodInfo = serverVersionType?.GetMethod("AutoDetect", new Type[] { typeof(string) });
                        if (autoDetectMethodInfo != null)
                        {
                            methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType });
                            if (methodInfo != null)
                            {
                                serverVersion = autoDetectMethodInfo.Invoke(null, new object[] { connectionString });
                            }
                        }
                    }

                    break;

                case "oracle":
                    methodInfo = extensionType.GetMethod("UseOracle", defaultSignature);
                    break;

                case "postgres":
                case "postgresql":
                    methodInfo = extensionType.GetMethod("UseNpgsql", defaultSignature);
                    break;

                case "sqlite":
                    methodInfo = extensionType.GetMethod("UseSqlite", defaultSignature);
                    break;

                case "sqlserver":
                case "mssql":
                    methodInfo = extensionType.GetMethod("UseSqlServer", defaultSignature);
                    break;

                case "memory":
                case "inmemory":
                    methodInfo = extensionType.GetMethod("UseInMemoryDatabase", defaultSignature);
                    break;

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }

            if (methodInfo == null)
            {
                // Fail loudly instead of returning a builder without any provider configured.
                throw new DatabaseProviderException($"No supported configuration method was found for the database provider '{provider}'.");
            }

            // The trailing null is the (unused) provider specific options action.
            var arguments = new List<object> { optionsBuilder, connectionString };
            if (serverVersion != null)
            {
                arguments.Add(serverVersion);
            }

            arguments.Add(null);
            methodInfo.Invoke(null, arguments.ToArray());

            return optionsBuilder;
        }

        /// <summary>
        /// Reads the provider name from the configuration and normalizes it to lower case.
        /// </summary>
        /// <param name="configuration">The application configuration section for the database.</param>
        /// <returns>The lower-cased provider name or <c>null</c> if none is configured.</returns>
        private static string GetProviderName(IConfiguration configuration)
        {
            // Invariant casing keeps the comparison independent of the current culture
            // (e.g. the Turkish dotless i would otherwise break "SQLITE").
            return configuration.GetValue<string>("provider")?.ToLowerInvariant();
        }

        /// <summary>
        /// Resolves the provider specific options builder type (e.g. the SQLite options builder).
        /// </summary>
        /// <param name="configuration">The application configuration section for the database.</param>
        /// <returns>The builder type or <c>null</c> if the provider assembly could not be loaded.</returns>
        /// <exception cref="DatabaseProviderException">The provider is unknown.</exception>
        private static Type GetBuilderType(IConfiguration configuration)
        {
            string provider = GetProviderName(configuration);

            switch (provider)
            {
                case "mysql":
                    // Try the known MySQL packages one after another.
                    return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.MySqlDbContextOptionsBuilder, Pomelo.EntityFrameworkCore.MySql")
                        ?? Type.GetType("MySql.Data.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.Data.EntityFrameworkCore")
                        // as MySql.Data.EntityFrameworkCore is marked as deprecated on NuGet
                        ?? Type.GetType("MySql.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.EntityFrameworkCore");

                case "oracle":
                    return Type.GetType("Oracle.EntityFrameworkCore.Infrastructure.OracleDbContextOptionsBuilder, Oracle.EntityFrameworkCore");

                case "postgres":
                case "postgresql":
                    return Type.GetType("Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.NpgsqlDbContextOptionsBuilder, Npgsql.EntityFrameworkCore.PostgreSQL");

                case "sqlite":
                    return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqliteDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.Sqlite");

                case "sqlserver":
                case "mssql":
                    return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqlServerDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.SqlServer");

                case "memory":
                case "inmemory":
                    return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.InMemoryDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.InMemory");

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }
        }

        /// <summary>
        /// Resolves the static class that holds the provider's <c>Use...</c> extension methods.
        /// </summary>
        /// <param name="configuration">The application configuration section for the database.</param>
        /// <returns>The extension type or <c>null</c> if the provider assembly could not be loaded.</returns>
        /// <exception cref="DatabaseProviderException">The provider is unknown.</exception>
        private static Type GetExtensionType(IConfiguration configuration)
        {
            string provider = GetProviderName(configuration);

            switch (provider)
            {
                case "mysql":
                    // Same lookup order as for the builder type.
                    return Type.GetType("Microsoft.EntityFrameworkCore.MySqlDbContextOptionsBuilderExtensions, Pomelo.EntityFrameworkCore.MySql")
                        ?? Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.Data.EntityFrameworkCore")
                        ?? Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.EntityFrameworkCore");

                case "oracle":
                    return Type.GetType("Microsoft.EntityFrameworkCore.OracleDbContextOptionsBuilderExtensions, Oracle.EntityFrameworkCore");

                case "postgres":
                case "postgresql":
                    return Type.GetType("Microsoft.EntityFrameworkCore.NpgsqlDbContextOptionsBuilderExtensions, Npgsql.EntityFrameworkCore.PostgreSQL");

                case "sqlite":
                    return Type.GetType("Microsoft.EntityFrameworkCore.SqliteDbContextOptionsBuilderExtensions, Microsoft.EntityFrameworkCore.Sqlite");

                case "sqlserver":
                case "mssql":
                    return Type.GetType("Microsoft.EntityFrameworkCore.SqlServerDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.SqlServer");

                case "memory":
                case "inmemory":
                    return Type.GetType("Microsoft.EntityFrameworkCore.InMemoryDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.InMemory");

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }
        }

        /// <summary>
        /// Builds the provider specific connection string from the configuration entries.
        /// </summary>
        /// <remarks>
        /// For the in-memory provider the returned value is just the database name
        /// (defaulting to the provider name), not a real connection string.
        /// </remarks>
        /// <param name="configuration">The application configuration section for the database.</param>
        /// <param name="options">Additional options, e.g. the base path for relative SQLite files.</param>
        /// <returns>The connection string with its parts joined by semicolons.</returns>
        /// <exception cref="DatabaseProviderException">
        /// The provider is unknown or the SQLite file entry is missing.
        /// </exception>
        private static string GetConnectionString(IConfiguration configuration, DatabaseProviderOptions options)
        {
            var cs = new List<string>();
            string provider = GetProviderName(configuration);

            switch (provider)
            {
                case "mysql":
                    cs.Add($"Server={configuration.GetValue<string>("Host")}");
                    cs.Add($"Port={configuration.GetValue<int>("Port", 3306)}");
                    cs.Add($"Database={configuration.GetValue<string>("Name")}");
                    cs.Add($"Uid={configuration.GetValue<string>("Username")}");
                    cs.Add($"Password={configuration.GetValue<string>("Password")}");
                    cs.Add("Connection Timeout=15");
                    break;

                case "oracle":
                    cs.Add($"Data Source=(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={configuration.GetValue<string>("Host")})(PORT={configuration.GetValue<int>("Port", 1521)}))(CONNECT_DATA=(SERVICE_NAME={configuration.GetValue<string>("Name")})))");
                    cs.Add($"User Id={configuration.GetValue<string>("Username")}");
                    cs.Add($"Password={configuration.GetValue<string>("Password")}");
                    cs.Add("Connection Timeout=15");
                    break;

                case "postgres":
                case "postgresql":
                    cs.Add($"Server={configuration.GetValue<string>("Host")}");
                    cs.Add($"Port={configuration.GetValue<int>("Port", 5432)}");
                    cs.Add($"Database={configuration.GetValue<string>("Name")}");
                    cs.Add($"Search Path={configuration.GetValue<string>("Schema", "public")}");
                    cs.Add($"User Id={configuration.GetValue<string>("Username")}");
                    cs.Add($"Password={configuration.GetValue<string>("Password")}");
                    cs.Add("Timeout=15");
                    break;

                case "sqlite":
                {
                    string path = configuration.GetValue<string>("File");
                    if (string.IsNullOrWhiteSpace(path))
                    {
                        throw new DatabaseProviderException("The SQLite database provider requires the 'File' configuration entry.");
                    }

                    if (!Path.IsPathRooted(path))
                    {
                        // Relative paths are resolved against the configured base path or,
                        // if none is set, against the directory of this assembly.
                        string basePath = options.AbsoluteBasePath;
                        if (string.IsNullOrWhiteSpace(basePath))
                        {
                            // Assembly.Location is empty for single-file deployments,
                            // so fall back to the application's base directory.
                            string location = Assembly.GetExecutingAssembly().Location;
                            basePath = string.IsNullOrEmpty(location)
                                ? AppContext.BaseDirectory
                                : Path.GetDirectoryName(location);
                        }

                        path = Path.Combine(basePath, path);
                    }

                    cs.Add($"Data Source={path}");
                    cs.Add("Foreign Keys=True");
                    break;
                }

                case "sqlserver":
                case "mssql":
                {
                    cs.Add($"Server={configuration.GetValue<string>("Host")},{configuration.GetValue<int>("Port", 1433)}");
                    cs.Add($"Database={configuration.GetValue<string>("Name")}");

                    // Without a username the connection uses integrated (Windows) security.
                    string username = configuration.GetValue<string>("Username");
                    if (!string.IsNullOrWhiteSpace(username))
                    {
                        cs.Add($"User Id={username}");
                        cs.Add($"Password={configuration.GetValue<string>("Password")}");
                        cs.Add("Integrated Security=False");
                    }
                    else
                    {
                        cs.Add("Integrated Security=True");
                    }

                    cs.Add("Connect Timeout=15");
                    break;
                }

                case "memory":
                case "inmemory":
                    cs.Add(configuration.GetValue<string>("Name", provider));
                    break;

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }

            return string.Join(";", cs);
        }
    }
}