diff options
| -rw-r--r-- | Cargo.lock | 82 | ||||
| -rw-r--r-- | Cargo.toml | 1 | ||||
| -rw-r--r-- | crates/api-auth/Cargo.toml | 1 | ||||
| -rw-r--r-- | crates/api-auth/src/client.rs | 58 | ||||
| -rw-r--r-- | crates/api-auth/src/discord/mod.rs | 77 | ||||
| -rw-r--r-- | crates/api-auth/src/error.rs | 7 | ||||
| -rw-r--r-- | crates/api-auth/src/lib.rs | 10 | ||||
| -rw-r--r-- | crates/api-core/src/models/user.rs | 3 | ||||
| -rw-r--r-- | crates/sellershut/Cargo.toml | 4 | ||||
| -rw-r--r-- | crates/sellershut/src/config/auth/discord.rs | 15 | ||||
| -rw-r--r-- | crates/sellershut/src/main.rs | 1 | ||||
| -rw-r--r-- | crates/sellershut/src/server/api/routes/auth/authorised.rs | 64 | ||||
| -rw-r--r-- | crates/sellershut/src/server/api/routes/auth/mod.rs | 6 | ||||
| -rw-r--r-- | crates/sellershut/src/state/mod.rs | 3 |
14 files changed, 312 insertions, 20 deletions
@@ -97,6 +97,7 @@ dependencies = [ "async-trait", "oauth2", "redis", + "reqwest 0.13.2", "secrecy", "serde", "sh-util", @@ -280,6 +281,28 @@ dependencies = [ ] [[package]] +name = "axum-extra" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fef252edff26ddba56bbcdf2ee3307b8129acb86f5749b68990c168a6fcc9c76" +dependencies = [ + "axum", + "axum-core", + "bytes", + "futures-core", + "futures-util", + "headers", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] name = "axum-macros" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -996,6 +1019,30 @@ dependencies = [ ] [[package]] +name = "headers" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" +dependencies = [ + "base64 0.22.1", + "bytes", + "headers-core", + "http", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http", +] + +[[package]] name = "heck" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1545,7 +1592,7 @@ dependencies = [ "getrandom 0.2.17", "http", "rand 0.8.5", - "reqwest", + "reqwest 0.12.28", "serde", "serde_json", "serde_path_to_error", @@ -1955,6 +2002,37 @@ dependencies = [ ] [[package]] +name = "reqwest" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" +dependencies = [ + "base64 0.22.1", + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "sync_wrapper", + "tokio", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + +[[package]] name = "ring" version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -2108,8 +2186,10 @@ dependencies = [ "api-auth", "api-core", "axum", + "axum-extra", "bon", "clap", + "reqwest 0.13.2", "secrecy", "serde", "serde_json", @@ -14,6 +14,7 @@ async-trait = "0.1.89" axum = "0.8.8" futures-util = "0.3.32" redis = { version = "1.1.0", default-features = false } +reqwest = { version = "0.13.2", default-features = false } secrecy = "0.10.3" serde = "1.0.228" serde_json = "1.0.149" diff --git a/crates/api-auth/Cargo.toml b/crates/api-auth/Cargo.toml index 5ce0647..518762b 100644 --- a/crates/api-auth/Cargo.toml +++ b/crates/api-auth/Cargo.toml @@ -12,6 +12,7 @@ api-core = { workspace = true, features = ["auth", "users"] } async-trait.workspace = true oauth2 = "5.0.0" redis.workspace = true +reqwest = { workspace = true, features = ["json"] } secrecy.workspace = true serde.workspace = true sh-util = { workspace = true, optional = true } diff --git a/crates/api-auth/src/client.rs b/crates/api-auth/src/client.rs new file mode 100644 index 0000000..d696162 --- /dev/null +++ b/crates/api-auth/src/client.rs @@ -0,0 +1,58 @@ +use std::pin::Pin; +use std::{future::Future, ops::Deref}; + +#[cfg(not(target_arch = "wasm32"))] +use oauth2::HttpResponse; +use oauth2::{AsyncHttpClient, HttpClientError, HttpRequest, http}; + +#[derive(Clone)] +pub struct AuthHttpClient(reqwest::Client); + +impl Deref for AuthHttpClient { + type Target = reqwest::Client; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From<reqwest::Client> for AuthHttpClient { + fn from(value: reqwest::Client) -> Self { + Self(value) + } +} + +impl<'c> AsyncHttpClient<'c> for AuthHttpClient { + type Error = HttpClientError<reqwest::Error>; + + #[cfg(target_arch = "wasm32")] + type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + 'c>>; + #[cfg(not(target_arch = "wasm32"))] + type Future = + Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>; + + fn call(&'c self, request: HttpRequest) -> Self::Future { + Box::pin(async move { + let response = self + .0 + .execute(request.try_into().map_err(Box::new)?) + .await + .map_err(Box::new)?; + + let mut builder = http::Response::builder().status(response.status()); + + #[cfg(not(target_arch = "wasm32"))] + { + builder = builder.version(response.version()); + } + + for (name, value) in response.headers().iter() { + builder = builder.header(name, value); + } + + builder + .body(response.bytes().await.map_err(Box::new)?.to_vec()) + .map_err(HttpClientError::Http) + }) + } +} diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs index 1a7d47d..0844f58 100644 --- a/crates/api-auth/src/discord/mod.rs +++ b/crates/api-auth/src/discord/mod.rs @@ -1,12 +1,31 @@ use api_core::models::user::User; use async_session::{Session, serde_json}; use async_trait::async_trait; -use oauth2::{CsrfToken, Scope}; +use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; use redis::AsyncCommands; +use serde::{Deserialize, Serialize}; use sh_util::cache::{CacheKey, RedisManager}; use sqlx::PgPool; -use crate::{BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, error::AuthError}; +use crate::{ + BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, client::AuthHttpClient, error::AuthError, +}; + +// The user data we'll get back from Discord. +// https://discord.com/developers/docs/resources/user#user-object-user-structure +#[derive(Debug, Serialize, Deserialize)] +struct DiscordUser { + id: String, + avatar: Option<String>, + username: String, + discriminator: String, +} + +impl From<DiscordUser> for User { + fn from(value: DiscordUser) -> Self { + todo!() + } +} #[derive(Clone)] pub struct AuthServiceDiscord { @@ -27,11 +46,57 @@ impl AuthServiceDiscord { #[async_trait] impl OauthDriver for AuthServiceDiscord { - async fn get_auth_token(&self) -> Result<String, AuthError> { - todo!() + async fn get_user(&self, client: &AuthHttpClient, code: &str) -> Result<User, AuthError> { + // Get an auth token + let token = self + .client + .exchange_code(AuthorizationCode::new(code.to_owned())) + .request_async(client) + .await + .unwrap(); + // Fetch user data from discord + let user_data: DiscordUser = client + // https://discord.com/developers/docs/resources/user#get-current-user + .get("https://discordapp.com/api/users/@me") + .bearer_auth(token.access_token().secret()) + .send() + .await + .unwrap() + .json::<DiscordUser>() + .await + .unwrap(); + + Ok(user_data.into()) } - async fn get_user(&self) -> Result<User, AuthError> { - todo!() + async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError> { + let id = Session::id_from_cookie_value(cookie)?; + let cache_key = CacheKey::Session(&id); + let mut cache = self.cache.get().await.unwrap(); + let session = cache.get::<_, String>(&cache_key).await?; + let session: Session = + serde_json::from_str(&session).map_err(|_e| AuthError::InvalidSession)?; + + match session.validate() { + Some(session) => { + // Extract the CSRF token from the session + let stored_csrf_token = session.get::<CsrfToken>(CSRF_TOKEN); + + if let Some(stored) = stored_csrf_token { + // Cleanup the CSRF token session + cache.del::<_, ()>(cache_key).await?; + + // Validate CSRF token is the same as the one in the auth request + if *stored.secret() != state { + return Err(AuthError::TokenMismatch); + } else { + return Ok(()); + } + } else { + return Err(AuthError::NoCSRFToken); + } + } + None => return Err(AuthError::MissingSession), + } } async fn create_oauth_session(&self) -> Result<SessionResponse, AuthError> { let (auth_url, csrf_token) = self diff --git a/crates/api-auth/src/error.rs b/crates/api-auth/src/error.rs index 72a7fba..2db3281 100644 --- a/crates/api-auth/src/error.rs +++ b/crates/api-auth/src/error.rs @@ -1,3 +1,4 @@ +use async_session::base64; use thiserror::Error; #[derive(Debug, Error)] @@ -28,4 +29,10 @@ pub enum AuthError { MissingSession, #[error("invalid session")] InvalidSession, + #[error("invalid session")] + CorruptedCookie(#[from] base64::DecodeError), + #[error("CSRF token mismatch")] + TokenMismatch, + #[error("CSRF token missing")] + NoCSRFToken, } diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs index 85fdb01..815b170 100644 --- a/crates/api-auth/src/lib.rs +++ b/crates/api-auth/src/lib.rs @@ -1,6 +1,8 @@ #[cfg(feature = "discord")] pub mod discord; +pub mod client; + mod error; use api_core::auth::AuthClientConfig; use api_core::models::user::User; @@ -21,8 +23,12 @@ pub struct BasicClient(C); #[async_trait::async_trait] pub trait OauthDriver: Send + Sync { - async fn get_auth_token(&self) -> Result<String, AuthError>; - async fn get_user(&self) -> Result<User, AuthError>; + async fn get_user( + &self, + client: &client::AuthHttpClient, + code: &str, + ) -> Result<User, AuthError>; + async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError>; async fn create_oauth_session(&self) -> Result<SessionResponse, AuthError>; async fn save_session(&self, user: &User) -> Result<(), AuthError>; } diff --git a/crates/api-core/src/models/user.rs b/crates/api-core/src/models/user.rs index e6ad9f0..7b70234 100644 --- a/crates/api-core/src/models/user.rs +++ b/crates/api-core/src/models/user.rs @@ -1 +1,4 @@ +use serde::Deserialize; + +#[derive(Deserialize)] pub struct User {} diff --git a/crates/sellershut/Cargo.toml b/crates/sellershut/Cargo.toml index df619f8..1151373 100644 --- a/crates/sellershut/Cargo.toml +++ b/crates/sellershut/Cargo.toml @@ -13,8 +13,10 @@ anyhow = "1.0.102" api-auth = { path = "../api-auth", features = ["discord", "utoipa"] } api-core = { workspace = true, features = ["auth-discord", "utoipa"] } axum = { version = "0.8.8", features = ["macros"] } +axum-extra = { version = "0.12.5", features = ["typed-header"] } bon = "3.9.1" clap = { version = "4.6.0", features = ["derive", "env"] } +reqwest.workspace = true secrecy = { workspace = true, features = ["serde"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true @@ -37,7 +39,7 @@ utoipa-swagger-ui = { version = "9.0.2", features = ["axum"], optional = true } tower = { workspace = true, features = ["util"] } [features] -default = ["auth-discord"] +default = ["auth-discord", "scalar"] auth-discord = [] swagger = ["dep:utoipa-swagger-ui"] redoc = ["dep:utoipa-redoc"] diff --git a/crates/sellershut/src/config/auth/discord.rs b/crates/sellershut/src/config/auth/discord.rs index 24ad711..cfbca91 100644 --- a/crates/sellershut/src/config/auth/discord.rs +++ b/crates/sellershut/src/config/auth/discord.rs @@ -15,11 +15,11 @@ pub struct DiscordClientConfig { pub discord_client_secret: Option<String>, /// Redirect URI registered with Discord OAuth. - #[arg(long, env = "HUT_DISCORD_REDIRECT_URI")] + #[arg(long, env = "HUT_DISCORD_REDIRECT_URL")] pub discord_redirect_uri: Option<String>, /// Discord token endpoint URI. - #[arg(long, env = "HUT_DISCORD_TOKEN_URI")] + #[arg(long, env = "HUT_DISCORD_TOKEN_URL")] pub discord_token_uri: Option<String>, /// Discord authorization URL. @@ -42,10 +42,9 @@ impl DiscordClientConfig { Self { discord_client_id: self.discord_client_id, discord_client_secret: self.discord_client_secret, - discord_redirect_uri: Some( - self.discord_redirect_uri - .unwrap_or_else(|| "http://localhost:2210/auth/discord/callback".to_string()), - ), + discord_redirect_uri: Some(self.discord_redirect_uri.unwrap_or_else(|| { + "http://localhost:2210/api/auth/discord/authorised".to_string() + })), discord_token_uri: Some( self.discord_token_uri .unwrap_or_else(|| "https://discord.com/api/oauth2/token".to_string()), @@ -61,7 +60,9 @@ impl DiscordClientConfig { Self { discord_client_id: None, discord_client_secret: None, - discord_redirect_uri: Some("http://localhost:2210/auth/discord/callback".to_string()), + discord_redirect_uri: Some( + "http://localhost:2210/api/auth/discord/authorised".to_string(), + ), discord_token_uri: Some("https://discord.com/api/oauth2/token".to_string()), discord_auth_url: Some("https://discord.com/api/oauth2/authorize".to_string()), } diff --git a/crates/sellershut/src/main.rs b/crates/sellershut/src/main.rs index a46cf3e..25ee315 100644 --- a/crates/sellershut/src/main.rs +++ b/crates/sellershut/src/main.rs @@ -50,6 +50,7 @@ async fn main() -> Result<()> { .log_handle(log_handle) .base_service(Arc::new(BaseService)) .auth_clients(auth_clients) + .http_client(reqwest::Client::new().into()) .build(); let addr = SocketAddr::from(( diff --git a/crates/sellershut/src/server/api/routes/auth/authorised.rs b/crates/sellershut/src/server/api/routes/auth/authorised.rs index 8b13789..94eaeca 100644 --- a/crates/sellershut/src/server/api/routes/auth/authorised.rs +++ b/crates/sellershut/src/server/api/routes/auth/authorised.rs @@ -1 +1,65 @@ +use anyhow::Context; +use api_core::auth::provider::OauthProvider; +use axum::{ + extract::{Path, Query, State}, + response::IntoResponse, +}; +use axum_extra::{TypedHeader, headers}; +use serde::Deserialize; +use utoipa::{IntoParams, ToSchema}; +use crate::{server::api::error::AppError, state::AppState}; + +#[derive(Debug, Deserialize, IntoParams, ToSchema)] +#[allow(dead_code)] +pub struct AuthRequest { + pub code: String, + pub state: String, +} + +#[derive(Debug, Deserialize, IntoParams, ToSchema)] +#[allow(dead_code)] +pub struct Params { + provider: OauthProvider, +} + +/// Authorised callback +#[utoipa::path( + get, + responses( + ( + status = 200, + description = "Application authorised", + headers( + ("x-request-id", description = "Request identifier"), + ("set-cookie", description = "Oauth session cookie") + ) + ), + ), + operation_id = "authorised", // https://github.com/juhaku/utoipa/issues/1170 + path = "/auth/{provider}/authorised", + tag = super::AUTH, + params(AuthRequest, Params) +)] +pub async fn authorised( + Query(params): Query<AuthRequest>, + State(state): State<AppState>, + Path(provider): Path<OauthProvider>, + TypedHeader(cookies): TypedHeader<headers::Cookie>, +) -> Result<impl IntoResponse, AppError> { + let cookie = cookies + .get(super::COOKIE_NAME) + .context("unexpected error getting cookie name")? + .to_string(); + + let client = state + .auth_clients + .get(&provider) + .context("missing oauth driver")?; + + client.validate_session(&cookie, ¶ms.state).await?; + + let user = client.get_user(&state.http_client, ¶ms.code).await?; + + Ok(String::default()) +} diff --git a/crates/sellershut/src/server/api/routes/auth/mod.rs b/crates/sellershut/src/server/api/routes/auth/mod.rs index 7414b77..340525e 100644 --- a/crates/sellershut/src/server/api/routes/auth/mod.rs +++ b/crates/sellershut/src/server/api/routes/auth/mod.rs @@ -17,7 +17,7 @@ const AUTH: &str = "Authentication"; static COOKIE_NAME: &str = "SESSION"; #[derive(OpenApi)] -#[openapi(tags((name = AUTH, description = "Transaction monitoring endpoints")),components(schemas(OauthProvider))) ] +#[openapi(tags((name = AUTH, description = "Auth endpoints")),components(schemas(OauthProvider))) ] pub struct AuthDoc; /// Oauth provider @@ -30,7 +30,9 @@ pub fn router(store: AppState) -> OpenApiRouter<AppState> { let router = OpenApiRouter::new(); #[cfg(feature = "auth-discord")] - let router = router.routes(utoipa_axum::routes!(auth)); + let router = router + .routes(utoipa_axum::routes!(auth)) + .routes(utoipa_axum::routes!(authorised::authorised)); router.with_state(store) } diff --git a/crates/sellershut/src/state/mod.rs b/crates/sellershut/src/state/mod.rs index 821d4eb..1febc6a 100644 --- a/crates/sellershut/src/state/mod.rs +++ b/crates/sellershut/src/state/mod.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, sync::Arc}; -use api_auth::OauthDriver; +use api_auth::{OauthDriver, client::AuthHttpClient}; use api_core::{auth::provider::OauthProvider, health::HealthDriver}; use bon::Builder; use sqlx::PgPool; @@ -12,6 +12,7 @@ pub struct AppState { pub base_service: Arc<dyn HealthDriver>, pub log_handle: LogHandle, pub auth_clients: HashMap<OauthProvider, Arc<dyn OauthDriver>>, + pub http_client: AuthHttpClient, } pub async fn postgres(config: &str, pool_size: u32) -> anyhow::Result<PgPool> { |
