diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/config/cli.rs | 23 | ||||
| -rw-r--r-- | src/config/mod.rs | 60 | ||||
| -rw-r--r-- | src/main.rs | 2 | ||||
| -rw-r--r-- | src/server/state/database.rs | 1 | ||||
| -rw-r--r-- | src/server/state/mod.rs | 27 |
5 files changed, 96 insertions, 17 deletions
diff --git a/src/config/cli.rs b/src/config/cli.rs index dab7216..5254135 100644 --- a/src/config/cli.rs +++ b/src/config/cli.rs @@ -1,6 +1,7 @@ use std::path::PathBuf; use clap::Parser; +use serde::Deserialize; use url::Url; use crate::config::{logging::LogLevel, port::port_in_range}; @@ -49,25 +50,33 @@ pub struct Cli { pub oauth: Option<OAuth>, } -#[derive(Debug, Clone, Parser)] +#[derive(Debug, Clone, Parser, Deserialize)] pub struct OAuth { #[cfg(feature = "oauth-discord")] #[command(flatten)] discord: DiscordOauth, - #[arg(long)] + #[arg(long, env = "OAUTH_REDIRECT_URL")] oauth_redirect_url: Option<Url>, } #[cfg(feature = "oauth-discord")] -#[derive(Debug, Clone, Parser)] +#[derive(Debug, Clone, Parser, Deserialize)] pub struct DiscordOauth { - #[arg(long)] + #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")] discord_client_id: Option<String>, - #[arg(long)] + #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")] discord_client_secret: Option<String>, - #[arg(long)] + #[arg( + long, + env = "OAUTH_DISCORD_TOKEN_URL", + default_value = "https://discord.com/api/oauth2/token" + )] discord_token_url: Option<Url>, - #[arg(long)] + #[arg( + long, + env = "OAUTH_DISCORD_AUTH_URL", + default_value = "https://discord.com/api/oauth2/authorize?response_type=code" + )] discord_auth_url: Option<Url>, } diff --git a/src/config/mod.rs b/src/config/mod.rs index 45e12c3..19ee241 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,10 +2,12 @@ mod cli; mod logging; mod port; pub use cli::Cli; +#[cfg(feature = "oauth-discord")] +use secrecy::SecretString; use serde::Deserialize; use url::Url; -use crate::{config::logging::LogLevel}; +use crate::config::logging::LogLevel; #[derive(Default, Deserialize, Debug, PartialEq, Eq)] #[serde(rename_all = "kebab-case")] @@ -22,6 +24,8 @@ pub struct Config { pub database: DatabaseOptions, #[serde(default)] pub server: Api, + #[serde(default)] + pub oauth: OAuth, } #[derive(Debug, Deserialize)] @@ -46,6 +50,52 @@ pub struct Api { pub environment: Environment, } +#[derive(Debug, Clone, Deserialize)] +pub struct OAuth { + #[cfg(feature = "oauth-discord")] + pub discord: DiscordOauth, + #[serde(rename = "redirect-url")] + pub oauth_redirect_url: Url, +} + +#[cfg(feature = "oauth-discord")] +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub struct DiscordOauth { + pub client_id: String, + pub client_secret: SecretString, + #[serde(default = "discord_token_url")] + pub token_url: Url, + #[serde(default = "discord_auth_url")] + pub auth_url: Url, +} + +fn discord_token_url() -> Url { + Url::parse("https://discord.com/api/oauth2/authorize?response_type=code").expect("valid url") +} + +fn discord_auth_url() -> Url { + Url::parse("https://discord.com/api/oauth2/authorize?response_type=code").expect("valid url") +} + +fn redirect_url() -> Url { + Url::parse("http://127.0.0.1:2210/auth/authorised").expect("valid url") +} + +impl Default for OAuth { + fn default() -> Self { + Self { + discord: DiscordOauth { + client_id: String::default(), + client_secret: SecretString::default(), + token_url: discord_token_url(), + auth_url: discord_auth_url(), + }, + oauth_redirect_url: redirect_url(), + } +} +} + impl Default for Api { fn default() -> Self { Self { @@ -68,7 +118,7 @@ pub struct DatabaseOptions { } impl DatabaseOptions { - pub fn create(url: & Url, pool_size: Option<u32>) -> Self { + pub fn create(url: &Url, pool_size: Option<u32>) -> Self { Self { url: url.to_owned(), pool_size: pool_size.unwrap_or_else(|| { @@ -89,8 +139,7 @@ impl Default for DatabaseOptions { fn default() -> Self { Self { url: default_database(), - pool_size: 100 - + pool_size: 100, } } } @@ -115,7 +164,6 @@ fn default_log_level() -> LogLevel { LogLevel::Debug } - impl Config { pub fn merge_with_cli(&mut self, cli: &Cli) { let server = &mut self.server; @@ -149,7 +197,7 @@ mod tests { #[test] fn config_file() { - let s = include_str!("../../sellershut.toml"); + let s = include_str!("../../misc/sellershut.toml"); assert!(toml::from_str::<Config>(s).is_ok()) } } diff --git a/src/main.rs b/src/main.rs index cb8c2a9..8ee10a1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ use clap::Parser; use tokio::net::TcpListener; use tracing::info; -use crate::{config::Config, logging::initialise_logging, server::state::{AppState }}; +use crate::{config::Config, logging::initialise_logging, server::state::AppState}; #[tokio::main] async fn main() -> anyhow::Result<()> { diff --git a/src/server/state/database.rs b/src/server/state/database.rs index 32d3f98..f8fd332 100644 --- a/src/server/state/database.rs +++ b/src/server/state/database.rs @@ -4,7 +4,6 @@ use tracing::{debug, trace}; use crate::config::DatabaseOptions; - pub(super) async fn connect(opts: &DatabaseOptions) -> Result<PgPool> { trace!(host = ?opts.url.host(), "connecting to database"); let pg = PgPoolOptions::new() diff --git a/src/server/state/mod.rs b/src/server/state/mod.rs index f4bf029..0726689 100644 --- a/src/server/state/mod.rs +++ b/src/server/state/mod.rs @@ -1,17 +1,40 @@ pub mod database; +use sellershut_auth::{ClientOptions, OauthClient}; use sqlx::PgPool; +#[cfg(feature = "oauth-discord")] +use url::Url; -use crate::{config::Config}; +use crate::config::Config; +#[cfg(feature = "oauth-discord")] +use crate::config::DiscordOauth; pub struct AppState { database: PgPool, + #[cfg(feature = "oauth-discord")] + oauth_discord: OauthClient, } impl AppState { pub async fn new(config: &Config) -> anyhow::Result<Self> { let database = database::connect(&config.database).await?; - Ok(Self{database}) + Ok(Self { + database, + oauth_discord: discord_client(&config.oauth.discord, &config.oauth.oauth_redirect_url)?, + }) } } + +#[cfg(feature = "oauth-discord")] +fn discord_client(disc: &DiscordOauth, redirect: &Url)->anyhow::Result<OauthClient> { + let discord_opts = ClientOptions::builder() + .client_id(disc.client_id.to_owned()) + .redirect_url(redirect.to_string()) + .auth_url(disc.auth_url.to_string()) + .client_secret(disc.client_secret.clone()) + .token_url(disc.token_url.to_string()) + .build(); + + Ok(sellershut_auth::oauth_client(&discord_opts)?) +} |
