using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using Microsoft.Extensions.Configuration;
namespace Microsoft.EntityFrameworkCore
{
    /// <summary>
    /// Extends the <see cref="DbContextOptionsBuilder"/> to use a configurable database provider.
    /// </summary>
    public static class DbContextOptionsBuilderExtensions
    {
        /// <summary>
        /// Adds the supported database provider to the context.
        /// </summary>
        /// <remarks>
        /// The configuration provided requires the following entries:
        /// <list type="bullet">
        /// <item>Provider: MySQL | Oracle | PostgreSQL (alias: Postgres) | SQLite | SQLServer (alias: MSSQL) | InMemory (alias: Memory); case-insensitive</item>
        /// <item>Host: hostname or IP address</item>
        /// <item>Port: port number (defaults: MySQL 3306, Oracle 1521, PostgreSQL 5432, SQL Server 1433)</item>
        /// <item>Name: database name (Oracle: service name; InMemory: database name, defaults to the provider value)</item>
        /// <item>Schema: schema or search path (PostgreSQL only, defaults to <c>public</c>)</item>
        /// <item>Username: username credential on the database (SQL Server: integrated security is used when empty)</item>
        /// <item>Password: password credential on the database</item>
        /// <item>File: file name / path (SQLite only; relative paths are resolved against
        /// <see cref="DatabaseProviderOptions.AbsoluteBasePath"/>, or <see cref="AppContext.BaseDirectory"/> when that is not set)</item>
        /// </list>
        /// </remarks>
        /// <param name="optionsBuilder">The options builder.</param>
        /// <param name="configuration">The application configuration section for the database.</param>
        /// <param name="optionsAction">An optional action to set additional options.</param>
        /// <returns>The <see cref="DbContextOptionsBuilder"/> with applied settings.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="optionsBuilder"/> or <paramref name="configuration"/> is <c>null</c>.</exception>
        /// <exception cref="DatabaseProviderException">
        /// The provider is unknown, its EF Core package/types could not be loaded, or a required setting is missing.
        /// </exception>
        /// <exception cref="TargetInvocationException">The provider's configuration method (invoked via reflection) threw.</exception>
        public static DbContextOptionsBuilder UseDatabaseProvider(this DbContextOptionsBuilder optionsBuilder, IConfiguration configuration, Action<DatabaseProviderOptions> optionsAction = null)
        {
            if (optionsBuilder == null)
                throw new ArgumentNullException(nameof(optionsBuilder));
            if (configuration == null)
                throw new ArgumentNullException(nameof(configuration));

            var options = new DatabaseProviderOptions();
            optionsAction?.Invoke(options);

            // Building the connection string first also rejects unknown providers.
            string connectionString = GetConnectionString(configuration, options);
            string provider = GetProvider(configuration);
            Type builderType = GetBuilderType(provider);
            Type extensionType = GetExtensionType(provider);
            Type actionType = typeof(Action<>).MakeGenericType(builderType);

            object serverVersion = null;
            MethodInfo methodInfo;
            switch (provider)
            {
                case "mysql":
                    methodInfo = FindMethod(extensionType, "UseMySql", actionType)
                        ?? FindMethod(extensionType, "UseMySQL", actionType);
                    if (methodInfo == null) // Pomelo MySQL v5+: the overload takes an explicit ServerVersion
                    {
                        Type serverVersionType = Type.GetType("Microsoft.EntityFrameworkCore.ServerVersion, Pomelo.EntityFrameworkCore.MySql");
                        MethodInfo autoDetectMethodInfo = serverVersionType?.GetMethod("AutoDetect", new[] { typeof(string) });
                        if (autoDetectMethodInfo != null)
                        {
                            methodInfo = extensionType.GetMethod("UseMySql", new[] { typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType });
                            // Only auto-detect the server version when the matching overload actually exists.
                            if (methodInfo != null)
                                serverVersion = autoDetectMethodInfo.Invoke(null, new object[] { connectionString });
                        }
                    }
                    break;

                case "oracle":
                    methodInfo = FindMethod(extensionType, "UseOracle", actionType);
                    break;

                case "postgres":
                case "postgresql":
                    methodInfo = FindMethod(extensionType, "UseNpgsql", actionType);
                    break;

                case "sqlite":
                    methodInfo = FindMethod(extensionType, "UseSqlite", actionType);
                    break;

                case "sqlserver":
                case "mssql":
                    methodInfo = FindMethod(extensionType, "UseSqlServer", actionType);
                    break;

                case "memory":
                case "inmemory":
                    methodInfo = FindMethod(extensionType, "UseInMemoryDatabase", actionType);
                    break;

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }

            // Fail loudly instead of silently leaving the context without a configured provider.
            if (methodInfo == null)
                throw new DatabaseProviderException($"No matching configuration method found for database provider: {provider}");

            // The trailing null is the optional provider-specific options action.
            object[] arguments = serverVersion == null
                ? new object[] { optionsBuilder, connectionString, null }
                : new object[] { optionsBuilder, connectionString, serverVersion, null };
            methodInfo.Invoke(null, arguments);

            return optionsBuilder;
        }

        /// <summary>
        /// Reads the <c>Provider</c> configuration entry and normalizes it for comparison.
        /// </summary>
        /// <param name="configuration">The database configuration section.</param>
        /// <returns>The provider name lower-cased with the invariant culture, or <c>null</c> if the entry is missing.</returns>
        private static string GetProvider(IConfiguration configuration)
        {
            // Invariant casing: culture-specific ToLower() maps 'I' to a dotless 'ı' under tr-TR,
            // which would make e.g. "SQLITE" or "INMEMORY" unrecognizable.
            return configuration.GetValue<string>("provider")?.ToLowerInvariant();
        }

        /// <summary>
        /// Looks up the public static <c>name(DbContextOptionsBuilder, string, Action&lt;TBuilder&gt;)</c> method on <paramref name="extensionType"/>.
        /// </summary>
        /// <returns>The method, or <c>null</c> when no exactly matching overload exists.</returns>
        private static MethodInfo FindMethod(Type extensionType, string name, Type actionType)
        {
            return extensionType.GetMethod(name, new[] { typeof(DbContextOptionsBuilder), typeof(string), actionType });
        }

        /// <summary>
        /// Returns the first of the given assembly-qualified type names that can be loaded.
        /// </summary>
        /// <param name="provider">The normalized provider name (used in the error message).</param>
        /// <param name="assemblyQualifiedNames">Candidate type names, in order of preference.</param>
        /// <exception cref="DatabaseProviderException">None of the candidates could be resolved.</exception>
        private static Type LoadType(string provider, params string[] assemblyQualifiedNames)
        {
            foreach (string typeName in assemblyQualifiedNames)
            {
                Type type = Type.GetType(typeName);
                if (type != null)
                    return type;
            }

            throw new DatabaseProviderException($"Could not load the EF Core types for database provider '{provider}'. Make sure the provider package is referenced.");
        }

        /// <summary>
        /// Resolves the provider-specific options builder type (the <c>T</c> of the provider's <c>Action&lt;T&gt;</c> parameter).
        /// </summary>
        /// <exception cref="DatabaseProviderException">The provider is unknown or its package is not available.</exception>
        private static Type GetBuilderType(string provider)
        {
            switch (provider)
            {
                case "mysql":
                    return LoadType(provider,
                        "Microsoft.EntityFrameworkCore.Infrastructure.MySqlDbContextOptionsBuilder, Pomelo.EntityFrameworkCore.MySql",
                        "MySql.Data.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.Data.EntityFrameworkCore",
                        // MySql.Data.EntityFrameworkCore is marked as deprecated on NuGet, so also try its successor package.
                        "MySql.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.EntityFrameworkCore");

                case "oracle":
                    return LoadType(provider, "Oracle.EntityFrameworkCore.Infrastructure.OracleDbContextOptionsBuilder, Oracle.EntityFrameworkCore");

                case "postgres":
                case "postgresql":
                    return LoadType(provider, "Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.NpgsqlDbContextOptionsBuilder, Npgsql.EntityFrameworkCore.PostgreSQL");

                case "sqlite":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.Infrastructure.SqliteDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.Sqlite");

                case "sqlserver":
                case "mssql":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.Infrastructure.SqlServerDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.SqlServer");

                case "memory":
                case "inmemory":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.Infrastructure.InMemoryDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.InMemory");

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }
        }

        /// <summary>
        /// Resolves the static class that declares the provider's <c>Use…</c> extension methods.
        /// </summary>
        /// <exception cref="DatabaseProviderException">The provider is unknown or its package is not available.</exception>
        private static Type GetExtensionType(string provider)
        {
            switch (provider)
            {
                case "mysql":
                    return LoadType(provider,
                        "Microsoft.EntityFrameworkCore.MySqlDbContextOptionsBuilderExtensions, Pomelo.EntityFrameworkCore.MySql",
                        "Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.Data.EntityFrameworkCore",
                        "Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.EntityFrameworkCore");

                case "oracle":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.OracleDbContextOptionsBuilderExtensions, Oracle.EntityFrameworkCore");

                case "postgres":
                case "postgresql":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.NpgsqlDbContextOptionsBuilderExtensions, Npgsql.EntityFrameworkCore.PostgreSQL");

                case "sqlite":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.SqliteDbContextOptionsBuilderExtensions, Microsoft.EntityFrameworkCore.Sqlite");

                case "sqlserver":
                case "mssql":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.SqlServerDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.SqlServer");

                case "memory":
                case "inmemory":
                    return LoadType(provider, "Microsoft.EntityFrameworkCore.InMemoryDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.InMemory");

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }
        }

        /// <summary>
        /// Builds the provider-specific connection string from the configuration entries.
        /// </summary>
        /// <remarks>
        /// For the InMemory provider the returned value is the database name, not a connection string.
        /// For SQLite, <see cref="DatabaseProviderOptions.AbsoluteBasePath"/> is set to
        /// <see cref="AppContext.BaseDirectory"/> when it is empty and the file path is relative.
        /// </remarks>
        /// <exception cref="DatabaseProviderException">The provider is unknown, or SQLite has no <c>File</c> entry.</exception>
        private static string GetConnectionString(IConfiguration configuration, DatabaseProviderOptions options)
        {
            var cs = new List<string>();
            string provider = GetProvider(configuration);
            switch (provider)
            {
                case "mysql":
                    cs.Add($"Server={Quote(configuration.GetValue<string>("Host"))}");
                    cs.Add($"Port={configuration.GetValue("Port", 3306)}");
                    cs.Add($"Database={Quote(configuration.GetValue<string>("Name"))}");
                    cs.Add($"Uid={Quote(configuration.GetValue<string>("Username"))}");
                    cs.Add($"Password={Quote(configuration.GetValue<string>("Password"))}");
                    cs.Add("Connection Timeout=15");
                    break;

                case "oracle":
                    // Host and service name are embedded in the connect descriptor, so they are not quoted.
                    cs.Add($"Data Source=(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={configuration.GetValue<string>("Host")})(PORT={configuration.GetValue("Port", 1521)}))(CONNECT_DATA=(SERVICE_NAME={configuration.GetValue<string>("Name")})))");
                    cs.Add($"User Id={Quote(configuration.GetValue<string>("Username"))}");
                    cs.Add($"Password={Quote(configuration.GetValue<string>("Password"))}");
                    cs.Add("Connection Timeout=15");
                    break;

                case "postgres":
                case "postgresql":
                    cs.Add($"Server={Quote(configuration.GetValue<string>("Host"))}");
                    cs.Add($"Port={configuration.GetValue("Port", 5432)}");
                    cs.Add($"Database={Quote(configuration.GetValue<string>("Name"))}");
                    cs.Add($"Search Path={Quote(configuration.GetValue("Schema", "public"))}");
                    cs.Add($"User Id={Quote(configuration.GetValue<string>("Username"))}");
                    cs.Add($"Password={Quote(configuration.GetValue<string>("Password"))}");
                    cs.Add("Timeout=15");
                    break;

                case "sqlite":
                {
                    string path = configuration.GetValue<string>("File");
                    if (string.IsNullOrWhiteSpace(path))
                        throw new DatabaseProviderException("The SQLite database provider requires a 'File' entry.");

                    if (!Path.IsPathRooted(path))
                    {
                        if (string.IsNullOrWhiteSpace(options.AbsoluteBasePath))
                            options.AbsoluteBasePath = AppContext.BaseDirectory;
                        path = Path.Combine(options.AbsoluteBasePath, path);
                    }
                    cs.Add($"Data Source={Quote(path)}");
                    cs.Add("Foreign Keys=True");
                    break;
                }

                case "sqlserver":
                case "mssql":
                {
                    string server = $"{configuration.GetValue<string>("Host")},{configuration.GetValue("Port", 1433)}";
                    cs.Add($"Server={Quote(server)}");
                    cs.Add($"Database={Quote(configuration.GetValue<string>("Name"))}");

                    string username = configuration.GetValue<string>("Username");
                    if (!string.IsNullOrWhiteSpace(username))
                    {
                        cs.Add($"User Id={Quote(username)}");
                        cs.Add($"Password={Quote(configuration.GetValue<string>("Password"))}");
                        cs.Add("Integrated Security=False");
                    }
                    else
                    {
                        // No credentials configured: authenticate with the current Windows identity.
                        cs.Add("Integrated Security=True");
                    }
                    cs.Add("Connect Timeout=15");
                    break;
                }

                case "memory":
                case "inmemory":
                    cs.Add(configuration.GetValue("Name", provider));
                    break;

                default:
                    throw new DatabaseProviderException($"Unknown database provider: {provider}");
            }

            return string.Join(";", cs);
        }

        /// <summary>
        /// Quotes a connection-string value when it would otherwise break the <c>key=value;</c> syntax.
        /// </summary>
        /// <remarks>
        /// Values containing <c>;</c> or starting with a quote character are wrapped in double quotes,
        /// with embedded double quotes doubled, following the ADO.NET connection string syntax.
        /// All other values (including <c>null</c> and empty strings) are returned unchanged.
        /// </remarks>
        private static string Quote(string value)
        {
            if (string.IsNullOrEmpty(value))
                return value;

            bool needsQuotes = value.IndexOf(';') >= 0 || value[0] == '"' || value[0] == '\'';
            if (!needsQuotes)
                return value;

            return "\"" + value.Replace("\"", "\"\"") + "\"";
        }
    }
}