1
0
Files
common/AMWD.Common.EntityFrameworkCore/Extensions/DbContextOptionsBuilderExtensions.cs
2024-01-14 13:10:33 +01:00

249 lines
11 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using Microsoft.Extensions.Configuration;
namespace Microsoft.EntityFrameworkCore
{
/// <summary>
/// Extends the <see cref="DbContextOptionsBuilder"/> to use a configurable database provider.
/// </summary>
/// <remarks>
/// The provider packages are not referenced at compile time: the provider-specific options builder types
/// and <c>Use…</c> extension methods are looked up by name via reflection, so only the package of the
/// configured provider has to be available at runtime.
/// </remarks>
public static class DbContextOptionsBuilderExtensions
{
    /// <summary>
    /// Adds the supported database provider to the context.
    /// </summary>
    /// <remarks>
    /// The configuration provided requires the following entries:
    /// <list type="bullet">
    /// <item><strong>Provider</strong>: MySQL | Oracle | PostgreSQL (alias: Postgres) | SQLite | SQLServer (alias: MSSQL) | InMemory (alias: Memory); matched case-insensitively</item>
    /// <item><strong>Host</strong>: hostname or IP address</item>
    /// <item><strong>Port</strong>: port number (defaults: MySQL 3306, Oracle 1521, PostgreSQL 5432, SQL Server 1433)</item>
    /// <item><strong>Name</strong>: database name (Oracle: service name; InMemory: database name, defaults to the lower-cased provider value)</item>
    /// <item><strong>Schema</strong>: schema or search path (PostgreSQL only, defaults to <c>public</c>)</item>
    /// <item><strong>Username</strong>: username credential on the database (SQL Server: integrated security is used when empty)</item>
    /// <item><strong>Password</strong>: password credential on the database</item>
    /// <item><strong>File</strong>: file name / path (for SQLite; relative paths are resolved against <see cref="DatabaseProviderOptions.AbsoluteBasePath"/>, or the application base directory when that is empty)</item>
    /// </list>
    /// </remarks>
    /// <param name="optionsBuilder">The options builder.</param>
    /// <param name="configuration">The application configuration section for the database.</param>
    /// <param name="optionsAction">An optional action to set additional options.</param>
    /// <returns>The <see cref="DbContextOptionsBuilder"/> with applied settings.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="optionsBuilder"/> or <paramref name="configuration"/> is <c>null</c>.</exception>
    /// <exception cref="DatabaseProviderException">
    /// The provider is unknown, a required setting is missing, or the provider package (its types or a
    /// matching <c>Use…</c> method) cannot be found.
    /// </exception>
    /// <exception cref="TargetInvocationException">
    /// The reflectively invoked provider method threw; the original exception is available as the inner exception.
    /// </exception>
    public static DbContextOptionsBuilder UseDatabaseProvider(this DbContextOptionsBuilder optionsBuilder, IConfiguration configuration, Action<DatabaseProviderOptions> optionsAction = null)
    {
#if NET8_0_OR_GREATER
        ArgumentNullException.ThrowIfNull(optionsBuilder);
        ArgumentNullException.ThrowIfNull(configuration);
#else
        if (optionsBuilder == null)
            throw new ArgumentNullException(nameof(optionsBuilder));
        if (configuration == null)
            throw new ArgumentNullException(nameof(configuration));
#endif

        var options = new DatabaseProviderOptions();
        optionsAction?.Invoke(options);

        string provider = GetProviderName(configuration);

        // Validates the provider name first: unknown providers fail here with DatabaseProviderException.
        string connectionString = GetConnectionString(configuration, options, provider);

        var builderType = GetBuilderType(provider)
            ?? throw new DatabaseProviderException($"The options builder type for the database provider '{provider}' could not be loaded. Is the provider package installed?");
        var extensionType = GetExtensionType(provider)
            ?? throw new DatabaseProviderException($"The extension type for the database provider '{provider}' could not be loaded. Is the provider package installed?");
        var actionType = typeof(Action<>).MakeGenericType(builderType);

        // Signature shared by most providers: Use…(DbContextOptionsBuilder, string, Action<TProviderBuilder>).
        var defaultSignature = new Type[] { typeof(DbContextOptionsBuilder), typeof(string), actionType };

        object[] arguments = null;
        MethodInfo methodInfo;
        switch (provider)
        {
            case "mysql":
                // Pomelo before v5 exposes "UseMySql", Oracle's MySQL provider exposes "UseMySQL".
                methodInfo = extensionType.GetMethod("UseMySql", defaultSignature)
                    ?? extensionType.GetMethod("UseMySQL", defaultSignature);
                if (methodInfo == null) // Pomelo MySQL v5+: an explicit ServerVersion argument is required
                {
                    var serverVersionType = Type.GetType("Microsoft.EntityFrameworkCore.ServerVersion, Pomelo.EntityFrameworkCore.MySql")
                        ?? throw new DatabaseProviderException("No supported MySQL configuration method was found and the Pomelo 'ServerVersion' type could not be loaded.");
                    var autoDetectMethodInfo = serverVersionType.GetMethod("AutoDetect", new Type[] { typeof(string) })
                        ?? throw new DatabaseProviderException("The method 'ServerVersion.AutoDetect(string)' of Pomelo.EntityFrameworkCore.MySql could not be found.");
                    methodInfo = extensionType.GetMethod("UseMySql", new Type[] { typeof(DbContextOptionsBuilder), typeof(string), serverVersionType, actionType });

                    // Only detect the server version when there is a method to hand the result to.
                    if (methodInfo != null)
                    {
                        object serverVersion = autoDetectMethodInfo.Invoke(null, new object[] { connectionString });
                        arguments = new object[] { optionsBuilder, connectionString, serverVersion, null };
                    }
                }
                break;

            case "oracle":
                methodInfo = extensionType.GetMethod("UseOracle", defaultSignature);
                break;

            case "postgres":
            case "postgresql":
                methodInfo = extensionType.GetMethod("UseNpgsql", defaultSignature);
                break;

            case "sqlite":
                methodInfo = extensionType.GetMethod("UseSqlite", defaultSignature);
                break;

            case "sqlserver":
            case "mssql":
                methodInfo = extensionType.GetMethod("UseSqlServer", defaultSignature);
                break;

            case "memory":
            case "inmemory":
                methodInfo = extensionType.GetMethod("UseInMemoryDatabase", defaultSignature);
                break;

            default:
                throw new DatabaseProviderException($"Unknown database provider: {provider}");
        }

        // Previously a missing method was silently skipped, leaving the context without any provider.
        if (methodInfo == null)
            throw new DatabaseProviderException($"No matching configuration method was found on '{extensionType.FullName}' for the database provider '{provider}'. Is the installed provider package version supported?");

        // The trailing null is the provider-specific options action, which is not used here.
        if (arguments == null)
            arguments = new object[] { optionsBuilder, connectionString, null };

        methodInfo.Invoke(null, arguments);
        return optionsBuilder;
    }

    /// <summary>
    /// Reads the configured provider name, lower-cased with the invariant culture
    /// (so e.g. "SQLITE" still matches under a Turkish culture); <c>null</c> when the entry is missing.
    /// </summary>
    private static string GetProviderName(IConfiguration configuration)
        => configuration.GetValue<string>("provider")?.ToLowerInvariant();

    /// <summary>
    /// Resolves the provider-specific options builder type (the type argument of the provider's options action).
    /// Returns <c>null</c> when the provider package cannot be loaded.
    /// </summary>
    /// <exception cref="DatabaseProviderException">The provider name is unknown.</exception>
    private static Type GetBuilderType(string provider)
    {
        switch (provider)
        {
            case "mysql":
                return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.MySqlDbContextOptionsBuilder, Pomelo.EntityFrameworkCore.MySql")
                    ?? Type.GetType("MySql.Data.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.Data.EntityFrameworkCore")
                    // as MySql.Data.EntityFrameworkCore is marked as deprecated on NuGet
                    ?? Type.GetType("MySql.EntityFrameworkCore.Infrastructure.MySQLDbContextOptionsBuilder, MySql.EntityFrameworkCore");

            case "oracle":
                return Type.GetType("Oracle.EntityFrameworkCore.Infrastructure.OracleDbContextOptionsBuilder, Oracle.EntityFrameworkCore");

            case "postgres":
            case "postgresql":
                return Type.GetType("Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure.NpgsqlDbContextOptionsBuilder, Npgsql.EntityFrameworkCore.PostgreSQL");

            case "sqlite":
                return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqliteDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.Sqlite");

            case "sqlserver":
            case "mssql":
                return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.SqlServerDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.SqlServer");

            case "memory":
            case "inmemory":
                return Type.GetType("Microsoft.EntityFrameworkCore.Infrastructure.InMemoryDbContextOptionsBuilder, Microsoft.EntityFrameworkCore.InMemory");

            default:
                throw new DatabaseProviderException($"Unknown database provider: {provider}");
        }
    }

    /// <summary>
    /// Resolves the static class declaring the provider's <c>Use…</c> extension methods.
    /// Returns <c>null</c> when the provider package cannot be loaded.
    /// </summary>
    /// <exception cref="DatabaseProviderException">The provider name is unknown.</exception>
    private static Type GetExtensionType(string provider)
    {
        switch (provider)
        {
            case "mysql":
                return Type.GetType("Microsoft.EntityFrameworkCore.MySqlDbContextOptionsBuilderExtensions, Pomelo.EntityFrameworkCore.MySql")
                    ?? Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.Data.EntityFrameworkCore")
                    ?? Type.GetType("Microsoft.EntityFrameworkCore.MySQLDbContextOptionsExtensions, MySql.EntityFrameworkCore");

            case "oracle":
                return Type.GetType("Microsoft.EntityFrameworkCore.OracleDbContextOptionsBuilderExtensions, Oracle.EntityFrameworkCore");

            case "postgres":
            case "postgresql":
                return Type.GetType("Microsoft.EntityFrameworkCore.NpgsqlDbContextOptionsBuilderExtensions, Npgsql.EntityFrameworkCore.PostgreSQL");

            case "sqlite":
                return Type.GetType("Microsoft.EntityFrameworkCore.SqliteDbContextOptionsBuilderExtensions, Microsoft.EntityFrameworkCore.Sqlite");

            case "sqlserver":
            case "mssql":
                return Type.GetType("Microsoft.EntityFrameworkCore.SqlServerDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.SqlServer");

            case "memory":
            case "inmemory":
                return Type.GetType("Microsoft.EntityFrameworkCore.InMemoryDbContextOptionsExtensions, Microsoft.EntityFrameworkCore.InMemory");

            default:
                throw new DatabaseProviderException($"Unknown database provider: {provider}");
        }
    }

    /// <summary>
    /// Builds the provider-specific connection string from the configuration entries.
    /// For InMemory the result is just the database name, not a key/value connection string.
    /// </summary>
    /// <remarks>
    /// NOTE(review): values are inserted verbatim, without quoting; a value containing ';'
    /// (e.g. a password) would be split into an extra key/value pair. Consider ADO.NET-style quoting.
    /// For SQLite this may set <see cref="DatabaseProviderOptions.AbsoluteBasePath"/> on <paramref name="options"/>.
    /// </remarks>
    /// <exception cref="DatabaseProviderException">The provider is unknown or a required entry is missing.</exception>
    private static string GetConnectionString(IConfiguration configuration, DatabaseProviderOptions options, string provider)
    {
        var cs = new List<string>();
        switch (provider)
        {
            case "mysql":
                cs.Add($"Server={configuration.GetValue<string>("Host")}");
                cs.Add($"Port={configuration.GetValue("Port", 3306)}");
                cs.Add($"Database={configuration.GetValue<string>("Name")}");
                cs.Add($"Uid={configuration.GetValue<string>("Username")}");
                cs.Add($"Password={configuration.GetValue<string>("Password")}");
                cs.Add("Connection Timeout=15");
                break;

            case "oracle":
                cs.Add($"Data Source=(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST={configuration.GetValue<string>("Host")})(PORT={configuration.GetValue("Port", 1521)}))(CONNECT_DATA=(SERVICE_NAME={configuration.GetValue<string>("Name")})))");
                cs.Add($"User Id={configuration.GetValue<string>("Username")}");
                cs.Add($"Password={configuration.GetValue<string>("Password")}");
                cs.Add("Connection Timeout=15");
                break;

            case "postgres":
            case "postgresql":
                cs.Add($"Server={configuration.GetValue<string>("Host")}");
                cs.Add($"Port={configuration.GetValue("Port", 5432)}");
                cs.Add($"Database={configuration.GetValue<string>("Name")}");
                cs.Add($"Search Path={configuration.GetValue("Schema", "public")}");
                cs.Add($"User Id={configuration.GetValue<string>("Username")}");
                cs.Add($"Password={configuration.GetValue<string>("Password")}");
                cs.Add("Timeout=15");
                break;

            case "sqlite":
            {
                string path = configuration.GetValue<string>("File");
                // Path.Combine would otherwise throw ArgumentNullException (or yield the bare directory) here.
                if (string.IsNullOrWhiteSpace(path))
                    throw new DatabaseProviderException("The SQLite database provider requires a 'File' entry.");

                if (!Path.IsPathRooted(path))
                {
                    if (string.IsNullOrWhiteSpace(options.AbsoluteBasePath))
                        options.AbsoluteBasePath = AppContext.BaseDirectory;

                    path = Path.Combine(options.AbsoluteBasePath, path);
                }
                cs.Add($"Data Source={path}");
                cs.Add("Foreign Keys=True");
                break;
            }

            case "sqlserver":
            case "mssql":
            {
                cs.Add($"Server={configuration.GetValue<string>("Host")},{configuration.GetValue("Port", 1433)}");
                cs.Add($"Database={configuration.GetValue<string>("Name")}");

                string username = configuration.GetValue<string>("Username");
                if (!string.IsNullOrWhiteSpace(username))
                {
                    cs.Add($"User Id={username}");
                    cs.Add($"Password={configuration.GetValue<string>("Password")}");
                    cs.Add("Integrated Security=False");
                }
                else
                {
                    // No username configured: fall back to Windows / integrated authentication.
                    cs.Add("Integrated Security=True");
                }
                cs.Add("Connect Timeout=15");
                break;
            }

            case "memory":
            case "inmemory":
                cs.Add(configuration.GetValue("Name", provider));
                break;

            default:
                throw new DatabaseProviderException($"Unknown database provider: {provider}");
        }
        return string.Join(";", cs);
    }
}
}