//! Application configuration: settings deserialised with serde, plus the CLI type used to
//! override them.

pub mod cache;
mod cli;
mod logging;
mod port;

pub use cli::Cli;

use serde::Deserialize;
use url::Url;

use crate::config::{
    cache::{CacheConfig, RedisVariant},
    logging::LogLevel,
};

/// Deployment environment the service runs in.
///
/// Deserialised from kebab-case names (`dev`, `prod`); [`Environment::Dev`] is the default.
#[derive(Debug, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum Environment {
    /// Development environment (the default variant).
    #[default]
    Dev,
    /// Production environment.
    Prod,
}

/// Top-level configuration document.
///
/// Each section carries `#[serde(default)]`, so a section that is absent from the source
/// document is filled in from that section type's `Default` implementation.
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Config {
    /// Database connection settings (key `database`).
    #[serde(default)]
    pub database: DatabaseOptions,
    /// Cache settings (key `cache`).
    #[serde(default)]
    pub cache: CacheConfig,
    /// HTTP server settings (key `server`).
    #[serde(default)]
    pub server: Api,
    /// OAuth settings (key `oauth`); the field only exists with the `oauth` feature.
    #[cfg(feature = "oauth")]
    #[serde(default)]
    pub oauth: OAuth,
}

/// HTTP server settings. Every field falls back to its `default_*` helper when omitted;
/// field names are kebab-case in the source document (e.g. `request-timeout`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Api {
    /// Public domain the server is reached under.
    #[serde(default = "default_domain")]
    pub domain: String,
    /// Request timeout value (unit not fixed by this module — TODO confirm at the use site).
    #[serde(default = "default_request_timeout")]
    pub request_timeout: u64,
    /// TCP port to listen on.
    #[serde(default = "default_port")]
    pub port: u16,
    /// Log verbosity.
    #[serde(default = "default_log_level")]
    pub log_level: LogLevel,
    /// System name reported by the service.
    #[serde(default = "default_sys_name")]
    pub system_name: String,
    /// Deployment environment.
    #[serde(default)]
    pub environment: Environment,
}

/// OAuth provider settings.
#[cfg(feature = "oauth")]
#[derive(Debug, Clone, Deserialize)]
pub struct OAuth {
    /// Discord provider credentials and endpoints.
    #[cfg(feature = "oauth-discord")]
    pub discord: DiscordOauth,
    /// URL the provider redirects back to after authorisation (key `redirect-url`).
    ///
    /// NOTE(review): this has no serde default, so an `oauth` section must set it explicitly,
    /// even though `redirect_url()` exists as a fallback for `Default`. Confirm that this is
    /// intentional.
    #[serde(rename = "redirect-url")]
    pub oauth_redirect_url: Url,
}

/// Discord OAuth client credentials and endpoint URLs.
#[cfg(feature = "oauth-discord")]
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct DiscordOauth {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth client secret, kept in a `SecretString`.
    pub client_secret: secrecy::SecretString,
    /// Token-exchange endpoint; defaults to Discord's public token URL.
    #[serde(default = "discord_token_url")]
    pub token_url: Url,
    /// Authorisation endpoint; defaults to Discord's authorize URL requesting a code.
    #[serde(default = "discord_auth_url")]
    pub auth_url: Url,
}

/// Fallback OAuth redirect target (local listener on port 2210).
#[cfg(feature = "oauth")]
fn redirect_url() -> Url {
    "http://127.0.0.1:2210/auth/authorised"
        .parse()
        .expect("valid url")
}

/// Discord's OAuth2 token endpoint.
#[cfg(feature = "oauth-discord")]
fn discord_token_url() -> Url {
    "https://discord.com/api/oauth2/token"
        .parse()
        .expect("valid url")
}

/// Discord's OAuth2 authorisation endpoint, pre-set to the authorization-code response type.
#[cfg(feature = "oauth-discord")]
fn discord_auth_url() -> Url {
    "https://discord.com/api/oauth2/authorize?response_type=code"
        .parse()
        .expect("valid url")
}
#[cfg(feature = "oauth")] impl Default for OAuth { fn default() -> Self { Self { #[cfg(feature = "oauth-discord")] discord: DiscordOauth { client_id: String::default(), client_secret: secrecy::SecretString::default(), token_url: discord_token_url(), auth_url: discord_auth_url(), }, oauth_redirect_url: redirect_url(), } } } impl Default for Api { fn default() -> Self { Self { domain: default_domain(), request_timeout: default_request_timeout(), port: default_port(), log_level: default_log_level(), system_name: default_sys_name(), environment: Environment::default(), } } } #[derive(Clone, Debug, Deserialize)] #[serde(rename_all = "kebab-case")] pub struct DatabaseOptions { #[serde(default = "default_database")] pub url: Url, pub pool_size: u32, } impl DatabaseOptions { pub fn create(url: &Url, pool_size: Option) -> Self { Self { url: url.to_owned(), pool_size: pool_size.unwrap_or_else(|| { let def = 100; tracing::debug!(size = def, "Setting default db pool size"); def }), } } } fn default_database() -> Url { Url::parse("postgres://postgres:password@localhost:5432/sellershut") .expect("valid default DATABASE url") } impl Default for DatabaseOptions { fn default() -> Self { Self { url: default_database(), pool_size: 100, } } } fn default_sys_name() -> String { "sellershut".to_string() } fn default_domain() -> String { "localhost".to_string() } fn default_request_timeout() -> u64 { 10 } fn default_port() -> u16 { 2210 } fn default_log_level() -> LogLevel { LogLevel::Debug } impl Config { pub fn merge_with_cli(&mut self, cli: &Cli) { let server = &mut self.server; let dsn = &mut self.database; let cache = &mut self.cache; if let Some(port) = cli.port { server.port = port; } if let Some(domain) = &cli.domain { server.domain = domain.to_string(); } if let Some(log_level) = &cli.log_level { server.log_level = *log_level; } if let Some(timeout) = cli.timeout_duration { server.request_timeout = timeout; } if let Some(db_url) = &cli.db { dsn.url = db_url.clone(); } if let 
Some(c) = cli.cache.as_ref().and_then(|v| v.cache_url.clone()) { cache.redis_dsn = c; } if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_pooled) { cache.pooled = c; } if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_max_conn) { cache.max_connections = c; } if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_type.clone()) { cache.kind = match c { cli::cache::RedisVariant::Clustered => RedisVariant::Clustered, cli::cache::RedisVariant::NonClustered => RedisVariant::NonClustered, cli::cache::RedisVariant::Sentinel => cache.kind.clone(), }; } } } #[cfg(test)] mod tests { use crate::config::Config; #[test] fn config_file() { let s = include_str!("../../misc/sellershut.toml"); assert!(toml::from_str::(s).is_ok()) } }