aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock82
-rw-r--r--Cargo.toml1
-rw-r--r--crates/api-auth/Cargo.toml1
-rw-r--r--crates/api-auth/src/client.rs58
-rw-r--r--crates/api-auth/src/discord/mod.rs77
-rw-r--r--crates/api-auth/src/error.rs7
-rw-r--r--crates/api-auth/src/lib.rs10
-rw-r--r--crates/api-core/src/models/user.rs3
-rw-r--r--crates/sellershut/Cargo.toml4
-rw-r--r--crates/sellershut/src/config/auth/discord.rs15
-rw-r--r--crates/sellershut/src/main.rs1
-rw-r--r--crates/sellershut/src/server/api/routes/auth/authorised.rs64
-rw-r--r--crates/sellershut/src/server/api/routes/auth/mod.rs6
-rw-r--r--crates/sellershut/src/state/mod.rs3
14 files changed, 312 insertions, 20 deletions
diff --git a/Cargo.lock b/Cargo.lock
index b488efb..e2663f8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -97,6 +97,7 @@ dependencies = [
"async-trait",
"oauth2",
"redis",
+ "reqwest 0.13.2",
"secrecy",
"serde",
"sh-util",
@@ -280,6 +281,28 @@ dependencies = [
]
[[package]]
+name = "axum-extra"
+version = "0.12.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fef252edff26ddba56bbcdf2ee3307b8129acb86f5749b68990c168a6fcc9c76"
+dependencies = [
+ "axum",
+ "axum-core",
+ "bytes",
+ "futures-core",
+ "futures-util",
+ "headers",
+ "http",
+ "http-body",
+ "http-body-util",
+ "mime",
+ "pin-project-lite",
+ "tower-layer",
+ "tower-service",
+ "tracing",
+]
+
+[[package]]
name = "axum-macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -996,6 +1019,30 @@ dependencies = [
]
[[package]]
+name = "headers"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb"
+dependencies = [
+ "base64 0.22.1",
+ "bytes",
+ "headers-core",
+ "http",
+ "httpdate",
+ "mime",
+ "sha1",
+]
+
+[[package]]
+name = "headers-core"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
+dependencies = [
+ "http",
+]
+
+[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1545,7 +1592,7 @@ dependencies = [
"getrandom 0.2.17",
"http",
"rand 0.8.5",
- "reqwest",
+ "reqwest 0.12.28",
"serde",
"serde_json",
"serde_path_to_error",
@@ -1955,6 +2002,37 @@ dependencies = [
]
[[package]]
+name = "reqwest"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801"
+dependencies = [
+ "base64 0.22.1",
+ "bytes",
+ "futures-core",
+ "http",
+ "http-body",
+ "http-body-util",
+ "hyper",
+ "hyper-util",
+ "js-sys",
+ "log",
+ "percent-encoding",
+ "pin-project-lite",
+ "serde",
+ "serde_json",
+ "sync_wrapper",
+ "tokio",
+ "tower",
+ "tower-http",
+ "tower-service",
+ "url",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
+
+[[package]]
name = "ring"
version = "0.17.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2108,8 +2186,10 @@ dependencies = [
"api-auth",
"api-core",
"axum",
+ "axum-extra",
"bon",
"clap",
+ "reqwest 0.13.2",
"secrecy",
"serde",
"serde_json",
diff --git a/Cargo.toml b/Cargo.toml
index a93b10e..5f10aa4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,6 +14,7 @@ async-trait = "0.1.89"
axum = "0.8.8"
futures-util = "0.3.32"
redis = { version = "1.1.0", default-features = false }
+reqwest = { version = "0.13.2", default-features = false }
secrecy = "0.10.3"
serde = "1.0.228"
serde_json = "1.0.149"
diff --git a/crates/api-auth/Cargo.toml b/crates/api-auth/Cargo.toml
index 5ce0647..518762b 100644
--- a/crates/api-auth/Cargo.toml
+++ b/crates/api-auth/Cargo.toml
@@ -12,6 +12,7 @@ api-core = { workspace = true, features = ["auth", "users"] }
async-trait.workspace = true
oauth2 = "5.0.0"
redis.workspace = true
+reqwest = { workspace = true, features = ["json"] }
secrecy.workspace = true
serde.workspace = true
sh-util = { workspace = true, optional = true }
diff --git a/crates/api-auth/src/client.rs b/crates/api-auth/src/client.rs
new file mode 100644
index 0000000..d696162
--- /dev/null
+++ b/crates/api-auth/src/client.rs
@@ -0,0 +1,58 @@
+use std::pin::Pin;
+use std::{future::Future, ops::Deref};
+
+#[cfg(not(target_arch = "wasm32"))]
+use oauth2::HttpResponse;
+use oauth2::{AsyncHttpClient, HttpClientError, HttpRequest, http};
+
+#[derive(Clone)]
+pub struct AuthHttpClient(reqwest::Client);
+
+impl Deref for AuthHttpClient {
+ type Target = reqwest::Client;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<reqwest::Client> for AuthHttpClient {
+ fn from(value: reqwest::Client) -> Self {
+ Self(value)
+ }
+}
+
+impl<'c> AsyncHttpClient<'c> for AuthHttpClient {
+ type Error = HttpClientError<reqwest::Error>;
+
+ #[cfg(target_arch = "wasm32")]
+ type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + 'c>>;
+ #[cfg(not(target_arch = "wasm32"))]
+ type Future =
+ Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + Sync + 'c>>;
+
+ fn call(&'c self, request: HttpRequest) -> Self::Future {
+ Box::pin(async move {
+ let response = self
+ .0
+ .execute(request.try_into().map_err(Box::new)?)
+ .await
+ .map_err(Box::new)?;
+
+ let mut builder = http::Response::builder().status(response.status());
+
+ #[cfg(not(target_arch = "wasm32"))]
+ {
+ builder = builder.version(response.version());
+ }
+
+ for (name, value) in response.headers().iter() {
+ builder = builder.header(name, value);
+ }
+
+ builder
+ .body(response.bytes().await.map_err(Box::new)?.to_vec())
+ .map_err(HttpClientError::Http)
+ })
+ }
+}
diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs
index 1a7d47d..0844f58 100644
--- a/crates/api-auth/src/discord/mod.rs
+++ b/crates/api-auth/src/discord/mod.rs
@@ -1,12 +1,31 @@
use api_core::models::user::User;
use async_session::{Session, serde_json};
use async_trait::async_trait;
-use oauth2::{CsrfToken, Scope};
+use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse};
use redis::AsyncCommands;
+use serde::{Deserialize, Serialize};
use sh_util::cache::{CacheKey, RedisManager};
use sqlx::PgPool;
-use crate::{BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, error::AuthError};
+use crate::{
+ BasicClient, CSRF_TOKEN, OauthDriver, SessionResponse, client::AuthHttpClient, error::AuthError,
+};
+
+// The user data we'll get back from Discord.
+// https://discord.com/developers/docs/resources/user#user-object-user-structure
+#[derive(Debug, Serialize, Deserialize)]
+struct DiscordUser {
+ id: String,
+ avatar: Option<String>,
+ username: String,
+ discriminator: String,
+}
+
+impl From<DiscordUser> for User {
+ fn from(value: DiscordUser) -> Self {
+ todo!()
+ }
+}
#[derive(Clone)]
pub struct AuthServiceDiscord {
@@ -27,11 +46,57 @@ impl AuthServiceDiscord {
#[async_trait]
impl OauthDriver for AuthServiceDiscord {
- async fn get_auth_token(&self) -> Result<String, AuthError> {
- todo!()
+ async fn get_user(&self, client: &AuthHttpClient, code: &str) -> Result<User, AuthError> {
+ // Get an auth token
+ let token = self
+ .client
+ .exchange_code(AuthorizationCode::new(code.to_owned()))
+ .request_async(client)
+ .await
+ .unwrap();
+ // Fetch user data from discord
+ let user_data: DiscordUser = client
+ // https://discord.com/developers/docs/resources/user#get-current-user
+ .get("https://discordapp.com/api/users/@me")
+ .bearer_auth(token.access_token().secret())
+ .send()
+ .await
+ .unwrap()
+ .json::<DiscordUser>()
+ .await
+ .unwrap();
+
+ Ok(user_data.into())
}
- async fn get_user(&self) -> Result<User, AuthError> {
- todo!()
+ async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError> {
+ let id = Session::id_from_cookie_value(cookie)?;
+ let cache_key = CacheKey::Session(&id);
+ let mut cache = self.cache.get().await.unwrap();
+ let session = cache.get::<_, String>(&cache_key).await?;
+ let session: Session =
+ serde_json::from_str(&session).map_err(|_e| AuthError::InvalidSession)?;
+
+ match session.validate() {
+ Some(session) => {
+ // Extract the CSRF token from the session
+ let stored_csrf_token = session.get::<CsrfToken>(CSRF_TOKEN);
+
+ if let Some(stored) = stored_csrf_token {
+ // Cleanup the CSRF token session
+ cache.del::<_, ()>(cache_key).await?;
+
+ // Validate CSRF token is the same as the one in the auth request
+ if *stored.secret() != state {
+ return Err(AuthError::TokenMismatch);
+ } else {
+ return Ok(());
+ }
+ } else {
+ return Err(AuthError::NoCSRFToken);
+ }
+ }
+ None => return Err(AuthError::MissingSession),
+ }
}
async fn create_oauth_session(&self) -> Result<SessionResponse, AuthError> {
let (auth_url, csrf_token) = self
diff --git a/crates/api-auth/src/error.rs b/crates/api-auth/src/error.rs
index 72a7fba..2db3281 100644
--- a/crates/api-auth/src/error.rs
+++ b/crates/api-auth/src/error.rs
@@ -1,3 +1,4 @@
+use async_session::base64;
use thiserror::Error;
#[derive(Debug, Error)]
@@ -28,4 +29,10 @@ pub enum AuthError {
MissingSession,
#[error("invalid session")]
InvalidSession,
+ #[error("invalid session")]
+ CorruptedCookie(#[from] base64::DecodeError),
+ #[error("CSRF token mismatch")]
+ TokenMismatch,
+ #[error("CSRF token missing")]
+ NoCSRFToken,
}
diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs
index 85fdb01..815b170 100644
--- a/crates/api-auth/src/lib.rs
+++ b/crates/api-auth/src/lib.rs
@@ -1,6 +1,8 @@
#[cfg(feature = "discord")]
pub mod discord;
+pub mod client;
+
mod error;
use api_core::auth::AuthClientConfig;
use api_core::models::user::User;
@@ -21,8 +23,12 @@ pub struct BasicClient(C);
#[async_trait::async_trait]
pub trait OauthDriver: Send + Sync {
- async fn get_auth_token(&self) -> Result<String, AuthError>;
- async fn get_user(&self) -> Result<User, AuthError>;
+ async fn get_user(
+ &self,
+ client: &client::AuthHttpClient,
+ code: &str,
+ ) -> Result<User, AuthError>;
+ async fn validate_session(&self, cookie: &str, state: &str) -> Result<(), AuthError>;
async fn create_oauth_session(&self) -> Result<SessionResponse, AuthError>;
async fn save_session(&self, user: &User) -> Result<(), AuthError>;
}
diff --git a/crates/api-core/src/models/user.rs b/crates/api-core/src/models/user.rs
index e6ad9f0..7b70234 100644
--- a/crates/api-core/src/models/user.rs
+++ b/crates/api-core/src/models/user.rs
@@ -1 +1,4 @@
+use serde::Deserialize;
+
+#[derive(Deserialize)]
pub struct User {}
diff --git a/crates/sellershut/Cargo.toml b/crates/sellershut/Cargo.toml
index df619f8..1151373 100644
--- a/crates/sellershut/Cargo.toml
+++ b/crates/sellershut/Cargo.toml
@@ -13,8 +13,10 @@ anyhow = "1.0.102"
api-auth = { path = "../api-auth", features = ["discord", "utoipa"] }
api-core = { workspace = true, features = ["auth-discord", "utoipa"] }
axum = { version = "0.8.8", features = ["macros"] }
+axum-extra = { version = "0.12.5", features = ["typed-header"] }
bon = "3.9.1"
clap = { version = "4.6.0", features = ["derive", "env"] }
+reqwest.workspace = true
secrecy = { workspace = true, features = ["serde"] }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
@@ -37,7 +39,7 @@ utoipa-swagger-ui = { version = "9.0.2", features = ["axum"], optional = true }
tower = { workspace = true, features = ["util"] }
[features]
-default = ["auth-discord"]
+default = ["auth-discord", "scalar"]
auth-discord = []
swagger = ["dep:utoipa-swagger-ui"]
redoc = ["dep:utoipa-redoc"]
diff --git a/crates/sellershut/src/config/auth/discord.rs b/crates/sellershut/src/config/auth/discord.rs
index 24ad711..cfbca91 100644
--- a/crates/sellershut/src/config/auth/discord.rs
+++ b/crates/sellershut/src/config/auth/discord.rs
@@ -15,11 +15,11 @@ pub struct DiscordClientConfig {
pub discord_client_secret: Option<String>,
/// Redirect URI registered with Discord OAuth.
- #[arg(long, env = "HUT_DISCORD_REDIRECT_URI")]
+ #[arg(long, env = "HUT_DISCORD_REDIRECT_URL")]
pub discord_redirect_uri: Option<String>,
/// Discord token endpoint URI.
- #[arg(long, env = "HUT_DISCORD_TOKEN_URI")]
+ #[arg(long, env = "HUT_DISCORD_TOKEN_URL")]
pub discord_token_uri: Option<String>,
/// Discord authorization URL.
@@ -42,10 +42,9 @@ impl DiscordClientConfig {
Self {
discord_client_id: self.discord_client_id,
discord_client_secret: self.discord_client_secret,
- discord_redirect_uri: Some(
- self.discord_redirect_uri
- .unwrap_or_else(|| "http://localhost:2210/auth/discord/callback".to_string()),
- ),
+ discord_redirect_uri: Some(self.discord_redirect_uri.unwrap_or_else(|| {
+ "http://localhost:2210/api/auth/discord/authorised".to_string()
+ })),
discord_token_uri: Some(
self.discord_token_uri
.unwrap_or_else(|| "https://discord.com/api/oauth2/token".to_string()),
@@ -61,7 +60,9 @@ impl DiscordClientConfig {
Self {
discord_client_id: None,
discord_client_secret: None,
- discord_redirect_uri: Some("http://localhost:2210/auth/discord/callback".to_string()),
+ discord_redirect_uri: Some(
+ "http://localhost:2210/api/auth/discord/authorised".to_string(),
+ ),
discord_token_uri: Some("https://discord.com/api/oauth2/token".to_string()),
discord_auth_url: Some("https://discord.com/api/oauth2/authorize".to_string()),
}
diff --git a/crates/sellershut/src/main.rs b/crates/sellershut/src/main.rs
index a46cf3e..25ee315 100644
--- a/crates/sellershut/src/main.rs
+++ b/crates/sellershut/src/main.rs
@@ -50,6 +50,7 @@ async fn main() -> Result<()> {
.log_handle(log_handle)
.base_service(Arc::new(BaseService))
.auth_clients(auth_clients)
+ .http_client(reqwest::Client::new().into())
.build();
let addr = SocketAddr::from((
diff --git a/crates/sellershut/src/server/api/routes/auth/authorised.rs b/crates/sellershut/src/server/api/routes/auth/authorised.rs
index 8b13789..94eaeca 100644
--- a/crates/sellershut/src/server/api/routes/auth/authorised.rs
+++ b/crates/sellershut/src/server/api/routes/auth/authorised.rs
@@ -1 +1,65 @@
+use anyhow::Context;
+use api_core::auth::provider::OauthProvider;
+use axum::{
+ extract::{Path, Query, State},
+ response::IntoResponse,
+};
+use axum_extra::{TypedHeader, headers};
+use serde::Deserialize;
+use utoipa::{IntoParams, ToSchema};
+use crate::{server::api::error::AppError, state::AppState};
+
+#[derive(Debug, Deserialize, IntoParams, ToSchema)]
+#[allow(dead_code)]
+pub struct AuthRequest {
+ pub code: String,
+ pub state: String,
+}
+
+#[derive(Debug, Deserialize, IntoParams, ToSchema)]
+#[allow(dead_code)]
+pub struct Params {
+ provider: OauthProvider,
+}
+
+/// Authorised callback
+#[utoipa::path(
+ get,
+ responses(
+ (
+ status = 200,
+ description = "Application authorised",
+ headers(
+ ("x-request-id", description = "Request identifier"),
+ ("set-cookie", description = "Oauth session cookie")
+ )
+ ),
+ ),
+ operation_id = "authorised", // https://github.com/juhaku/utoipa/issues/1170
+ path = "/auth/{provider}/authorised",
+ tag = super::AUTH,
+ params(AuthRequest, Params)
+)]
+pub async fn authorised(
+ Query(params): Query<AuthRequest>,
+ State(state): State<AppState>,
+ Path(provider): Path<OauthProvider>,
+ TypedHeader(cookies): TypedHeader<headers::Cookie>,
+) -> Result<impl IntoResponse, AppError> {
+ let cookie = cookies
+ .get(super::COOKIE_NAME)
+ .context("unexpected error getting cookie name")?
+ .to_string();
+
+ let client = state
+ .auth_clients
+ .get(&provider)
+ .context("missing oauth driver")?;
+
+ client.validate_session(&cookie, &params.state).await?;
+
+ let user = client.get_user(&state.http_client, &params.code).await?;
+
+ Ok(String::default())
+}
diff --git a/crates/sellershut/src/server/api/routes/auth/mod.rs b/crates/sellershut/src/server/api/routes/auth/mod.rs
index 7414b77..340525e 100644
--- a/crates/sellershut/src/server/api/routes/auth/mod.rs
+++ b/crates/sellershut/src/server/api/routes/auth/mod.rs
@@ -17,7 +17,7 @@ const AUTH: &str = "Authentication";
static COOKIE_NAME: &str = "SESSION";
#[derive(OpenApi)]
-#[openapi(tags((name = AUTH, description = "Transaction monitoring endpoints")),components(schemas(OauthProvider))) ]
+#[openapi(tags((name = AUTH, description = "Auth endpoints")),components(schemas(OauthProvider))) ]
pub struct AuthDoc;
/// Oauth provider
@@ -30,7 +30,9 @@ pub fn router(store: AppState) -> OpenApiRouter<AppState> {
let router = OpenApiRouter::new();
#[cfg(feature = "auth-discord")]
- let router = router.routes(utoipa_axum::routes!(auth));
+ let router = router
+ .routes(utoipa_axum::routes!(auth))
+ .routes(utoipa_axum::routes!(authorised::authorised));
router.with_state(store)
}
diff --git a/crates/sellershut/src/state/mod.rs b/crates/sellershut/src/state/mod.rs
index 821d4eb..1febc6a 100644
--- a/crates/sellershut/src/state/mod.rs
+++ b/crates/sellershut/src/state/mod.rs
@@ -1,6 +1,6 @@
use std::{collections::HashMap, sync::Arc};
-use api_auth::OauthDriver;
+use api_auth::{OauthDriver, client::AuthHttpClient};
use api_core::{auth::provider::OauthProvider, health::HealthDriver};
use bon::Builder;
use sqlx::PgPool;
@@ -12,6 +12,7 @@ pub struct AppState {
pub base_service: Arc<dyn HealthDriver>,
pub log_handle: LogHandle,
pub auth_clients: HashMap<OauthProvider, Arc<dyn OauthDriver>>,
+ pub http_client: AuthHttpClient,
}
pub async fn postgres(config: &str, pool_size: u32) -> anyhow::Result<PgPool> {