diff options
| author | rtkay123 <dev@kanjala.com> | 2026-02-02 13:05:49 +0200 |
|---|---|---|
| committer | rtkay123 <dev@kanjala.com> | 2026-02-02 13:05:49 +0200 |
| commit | e06094f23ca861ea5ae4864d11fa8ce8b7d7aa2c (patch) | |
| tree | 27bbff5fd21711f99aaf579a76b1a0aca7869003 /src | |
| parent | 78f61ccdf66572d7432b5b627994038479103653 (diff) | |
| download | sellershut-e06094f23ca861ea5ae4864d11fa8ce8b7d7aa2c.tar.bz2 sellershut-e06094f23ca861ea5ae4864d11fa8ce8b7d7aa2c.zip | |
feat: oauth route
Diffstat (limited to 'src')
| -rw-r--r-- | src/config/cli.rs | 1 | ||||
| -rw-r--r-- | src/config/mod.rs | 23 | ||||
| -rw-r--r-- | src/main.rs | 10 | ||||
| -rw-r--r-- | src/server/driver/mod.rs | 29 | ||||
| -rw-r--r-- | src/server/error/mod.rs | 27 | ||||
| -rw-r--r-- | src/server/middleware/mod.rs | 1 | ||||
| -rw-r--r-- | src/server/middleware/request_id.rs | 20 | ||||
| -rw-r--r-- | src/server/mod.rs | 125 | ||||
| -rw-r--r-- | src/server/routes/auth/discord.rs | 11 | ||||
| -rw-r--r-- | src/server/routes/auth/mod.rs | 59 | ||||
| -rw-r--r-- | src/server/routes/mod.rs | 49 | ||||
| -rw-r--r-- | src/server/state/database.rs | 2 | ||||
| -rw-r--r-- | src/server/state/federation.rs | 65 | ||||
| -rw-r--r-- | src/server/state/mod.rs | 17 |
14 files changed, 412 insertions, 27 deletions
diff --git a/src/config/cli.rs b/src/config/cli.rs index 5254135..7bc6312 100644 --- a/src/config/cli.rs +++ b/src/config/cli.rs @@ -47,6 +47,7 @@ pub struct Cli { /// Oauth optionas #[command(flatten)] + #[cfg(feature = "oauth")] pub oauth: Option<OAuth>, } diff --git a/src/config/mod.rs b/src/config/mod.rs index 19ee241..01af6d8 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,7 +2,7 @@ mod cli; mod logging; mod port; pub use cli::Cli; -#[cfg(feature = "oauth-discord")] +#[cfg(feature = "oauth")] use secrecy::SecretString; use serde::Deserialize; use url::Url; @@ -25,6 +25,7 @@ pub struct Config { #[serde(default)] pub server: Api, #[serde(default)] + #[cfg(feature = "oauth")] pub oauth: OAuth, } @@ -84,16 +85,16 @@ fn redirect_url() -> Url { impl Default for OAuth { fn default() -> Self { - Self { - discord: DiscordOauth { - client_id: String::default(), - client_secret: SecretString::default(), - token_url: discord_token_url(), - auth_url: discord_auth_url(), - }, - oauth_redirect_url: redirect_url(), - } -} + Self { + discord: DiscordOauth { + client_id: String::default(), + client_secret: SecretString::default(), + token_url: discord_token_url(), + auth_url: discord_auth_url(), + }, + oauth_redirect_url: redirect_url(), + } + } } impl Default for Api { diff --git a/src/main.rs b/src/main.rs index 8ee10a1..c529379 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,11 @@ use clap::Parser; use tokio::net::TcpListener; use tracing::info; -use crate::{config::Config, logging::initialise_logging, server::state::AppState}; +use crate::{ + config::Config, + logging::initialise_logging, + server::{driver::Services, state::AppState}, +}; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -23,7 +27,9 @@ async fn main() -> anyhow::Result<()> { initialise_logging(&config); - let state = AppState::new(&config).await?; + let driver = Services::new(&config.database).await?; + let state = AppState::new(&config, driver).await?; + let router = server::router(&config, state).await?; let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.server.port)); diff --git a/src/server/driver/mod.rs b/src/server/driver/mod.rs new file mode 100644 index 0000000..4c540cb --- /dev/null +++ b/src/server/driver/mod.rs @@ -0,0 +1,29 @@ +use async_trait::async_trait; +use sqlx::PgPool; + +use crate::{config::DatabaseOptions, server::state::database}; + +pub struct Services { + database: PgPool, + // oauth: OauthClient, +} + +impl Services { + pub async fn new(database: &DatabaseOptions) -> anyhow::Result<Self> { + let database = database::connect(database).await?; + + Ok(Self { database }) + } +} + +#[async_trait] +pub trait SellershutDriver: Send + Sync + 'static { + async fn hello(&self); +} + +#[async_trait] +impl SellershutDriver for Services { + async fn hello(&self) { + todo!() + } +} diff --git a/src/server/error/mod.rs b/src/server/error/mod.rs new file mode 100644 index 0000000..6d07f9f --- /dev/null +++ b/src/server/error/mod.rs @@ -0,0 +1,27 @@ +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; + +#[derive(Debug)] +pub struct AppError(anyhow::Error); + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for AppError { + fn into_response(self) -> Response { + tracing::error!("Application error: {:#}", self.0); + + (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response() + } +} + +// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into +// `Result<_, AppError>`. That way you don't need to do that manually. +impl<E> From<E> for AppError +where + E: Into<anyhow::Error>, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} diff --git a/src/server/middleware/mod.rs b/src/server/middleware/mod.rs new file mode 100644 index 0000000..f68f27a --- /dev/null +++ b/src/server/middleware/mod.rs @@ -0,0 +1 @@ +pub(super) mod request_id; diff --git a/src/server/middleware/request_id.rs b/src/server/middleware/request_id.rs new file mode 100644 index 0000000..7163c86 --- /dev/null +++ b/src/server/middleware/request_id.rs @@ -0,0 +1,20 @@ +use axum::{ + extract::Request, + http::{HeaderValue, StatusCode}, + middleware::Next, + response::Response, +}; +use uuid::Uuid; + +pub const REQUEST_ID_HEADER: &str = "x-request-id"; + +pub async fn add_request_id(mut request: Request, next: Next) -> Result<Response, StatusCode> { + let headers = request.headers_mut(); + let id = Uuid::now_v7().to_string(); + tracing::trace!(id = id, "attaching request id"); + let bytes = id.as_bytes(); + + headers.insert(REQUEST_ID_HEADER, HeaderValue::from_bytes(bytes).unwrap()); + + Ok(next.run(request).await) +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 803135f..3301035 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,10 +1,125 @@ -use axum::Router; - -use crate::{config::Config, server::state::AppState}; - +pub mod driver; +pub mod error; +mod middleware; +pub mod routes; pub mod shutdown; pub mod state; +use std::time::Duration; + +use activitypub_federation::config::{FederationConfig, FederationMiddleware}; +use axum::{ + Router, + http::{HeaderName, StatusCode}, +}; +use tower_http::{ + cors::{self, CorsLayer}, + request_id::PropagateRequestIdLayer, + timeout::TimeoutLayer, + trace::TraceLayer, +}; +use tracing::{error, info_span}; +use utoipa::OpenApi; +use utoipa_axum::router::OpenApiRouter; + +use crate::{ + config::Config, + server::{ + middleware::request_id::{REQUEST_ID_HEADER, add_request_id}, + routes::auth::OAuthDoc, + state::{AppState, federation}, + }, +}; + +#[derive(OpenApi)] +#[openapi( + tags( + (name = routes::HEALTH, description = "Check API health"), + ), +)] +pub struct ApiDoc; + pub async fn router(config: &Config, state: AppState) -> anyhow::Result<Router<()>> { - todo!() + let state = federation::add_federation(state, config).await?; + + let mut doc = ApiDoc::openapi(); + doc.merge(OAuthDoc::openapi()); + + let (router, _api) = OpenApiRouter::with_openapi(doc) + .routes(utoipa_axum::routes!(routes::health_check)) + .routes(utoipa_axum::routes!(routes::auth::auth)) + .split_for_parts(); + + #[cfg(feature = "swagger")] + let router = router.merge( + utoipa_swagger_ui::SwaggerUi::new("/swagger-ui") + .url("/api-docs/swaggerdoc.json", _api.clone()), + ); + + #[cfg(feature = "redoc")] + let router = { + use utoipa_redoc::Servable as _; + router.merge(utoipa_redoc::Redoc::with_url("/redoc", _api.clone())) + }; + + #[cfg(feature = "scalar")] + let router = { + use utoipa_scalar::Servable as _; + router.merge(utoipa_scalar::Scalar::with_url("/scalar", _api.clone())) + }; + + #[cfg(feature = "rapidoc")] + let router = router.merge( + utoipa_rapidoc::RapiDoc::with_openapi("/api-docs/rapidoc.json", _api).path("/rapidoc"), + ); + + let router = router + .layer( + TraceLayer::new_for_http().make_span_with(|request: &axum::http::Request<_>| { + if let Some(request_id) = request.headers().get(REQUEST_ID_HEADER) { + info_span!( + "http_request", + request_id = ?request_id, + ) + } else { + error!("could not extract request_id"); + info_span!("http_request") + } + }), + ) + .layer(TimeoutLayer::with_status_code( + StatusCode::REQUEST_TIMEOUT, + Duration::from_secs(config.server.request_timeout), + )) + .layer(FederationMiddleware::new(state)) + // send headers from request to response headers + .layer(PropagateRequestIdLayer::new(HeaderName::from_static( + REQUEST_ID_HEADER, + ))) + .layer(axum::middleware::from_fn(add_request_id)) + .layer( + CorsLayer::new() + .allow_origin(cors::Any) + .allow_headers(cors::Any) + .allow_methods(cors::Any), + ); + + Ok(router) +} + +#[cfg(test)] +pub mod bootstrap { + use async_trait::async_trait; + + use crate::server::driver::SellershutDriver; + + #[derive(Debug, Default)] + pub struct TestDriver {} + + #[async_trait] + impl SellershutDriver for TestDriver { + async fn hello(&self) { + todo!() + } + } } diff --git a/src/server/routes/auth/discord.rs b/src/server/routes/auth/discord.rs new file mode 100644 index 0000000..036a35a --- /dev/null +++ b/src/server/routes/auth/discord.rs @@ -0,0 +1,11 @@ +use std::sync::Arc; + +use axum::{extract::State, response::IntoResponse}; + +use crate::server::{driver::SellershutDriver, error::AppError}; + +async fn auth( + State(client): State<Arc<dyn SellershutDriver>>, +) -> Result<impl IntoResponse, AppError> { + Ok(()) +} diff --git a/src/server/routes/auth/mod.rs b/src/server/routes/auth/mod.rs new file mode 100644 index 0000000..b80c565 --- /dev/null +++ b/src/server/routes/auth/mod.rs @@ -0,0 +1,59 @@ +use activitypub_federation::config::Data; + +use serde::Deserialize; + +#[cfg(feature = "oauth-discord")] +pub mod discord; + +#[derive(Deserialize, Debug, Clone, Copy, ToSchema)] +#[serde(rename_all = "lowercase")] +pub enum OauthProvider { + /// Discord + #[cfg(feature = "oauth-discord")] + Discord, +} + +#[derive(Deserialize, Debug, Clone, Copy, IntoParams)] +#[into_params(parameter_in = Query)] +pub struct Params { + /// Set OAuth provider name + provider: OauthProvider, +} + +use axum::{extract::Query, response::IntoResponse}; +use utoipa::{IntoParams, OpenApi, ToSchema}; + +use crate::server::{error::AppError, state::AppState}; + +pub const AUTH: &str = "AUTH"; + +#[derive(OpenApi)] +#[openapi( + tags( + (name = AUTH, description = "OAuth integration") + ), + components( + schemas(OauthProvider) + ) +)] +pub struct OAuthDoc; + +#[utoipa::path( + method(get), + path = "/auth", + params( + Params + ), + tag = AUTH, + responses( + (status = OK, description = "Routes to oauth provider for login", body = str, content_type = "text/plain") + ) +)] +#[axum::debug_handler] +pub async fn auth( + Query(params): Query<Params>, + data: Data<AppState>, +) -> Result<impl IntoResponse, AppError> { + dbg!(¶ms); + Ok(String::default()) +} diff --git a/src/server/routes/mod.rs b/src/server/routes/mod.rs new file mode 100644 index 0000000..edd6fdf --- /dev/null +++ b/src/server/routes/mod.rs @@ -0,0 +1,49 @@ +#[cfg(feature = "oauth")] +pub mod auth; + +pub(super) const HEALTH: &str = "HEALTH"; + +#[utoipa::path( + method(get), + path = "/", + tag = HEALTH, + responses( + (status = OK, description = "Checks if the server is running", body = str, content_type = "text/plain") + ) +)] +pub async fn health_check() -> impl axum::response::IntoResponse { + let name = env!("CARGO_PKG_NAME"); + let version = env!("CARGO_PKG_VERSION"); + + format!("{name} v{version} is live") +} + +#[cfg(test)] +mod tests { + use crate::{ + config::Config, + server::{self, bootstrap::TestDriver, state::AppState}, + }; + + use axum::{ + body::Body, + http::{Request, StatusCode}, + }; + use tower::ServiceExt; + + #[tokio::test] + async fn health_check() { + let config = Config::default(); + let driver = TestDriver::default(); + let state = AppState::new(&config, driver).await.unwrap(); + + let app = server::router(&config, state).await.unwrap(); + + let response = app + .oneshot(Request::builder().uri("/").body(Body::empty()).unwrap()) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + } +} diff --git a/src/server/state/database.rs b/src/server/state/database.rs index f8fd332..156de0f 100644 --- a/src/server/state/database.rs +++ b/src/server/state/database.rs @@ -4,7 +4,7 @@ use tracing::{debug, trace}; use crate::config::DatabaseOptions; -pub(super) async fn connect(opts: &DatabaseOptions) -> Result<PgPool> { +pub async fn connect(opts: &DatabaseOptions) -> Result<PgPool> { trace!(host = ?opts.url.host(), "connecting to database"); let pg = PgPoolOptions::new() .max_connections(opts.pool_size) diff --git a/src/server/state/federation.rs b/src/server/state/federation.rs new file mode 100644 index 0000000..083741c --- /dev/null +++ b/src/server/state/federation.rs @@ -0,0 +1,65 @@ +use activitypub_federation::config::FederationConfig; +use url::Url; + +use crate::{ + config::{Config, Environment}, + server::state::AppState, +}; + +pub async fn add_federation( + state: AppState, + config: &Config, +) -> anyhow::Result<FederationConfig<AppState>> { + let url = match config.server.environment { + Environment::Dev => { + format!("http://{}", config.server.domain) + } + Environment::Prod => { + format!("https://{}", config.server.domain) + } + }; + let mut url = Url::parse(&url)?; + + if Environment::Dev == config.server.environment { + let _ = url.set_port(Some(config.server.port)); + } + + let mut ap_id = url.clone(); + + { + let mut ps = ap_id.path_segments_mut().expect("path segments in url"); + ps.push("users"); + ps.push(&config.server.system_name); + } + + // let user = if let Some(user) = state.users_service.get_by_ap_id(ap_id.as_str()).await? { + // user + // } else { + // let mut inbox = ap_id.clone(); + // { + // let mut ps = inbox.path_segments_mut().expect("path segments in url"); + // ps.push("inbox"); + // } + // state + // .users_service + // .create_user( + // &ap_id, + // &config.server.system_name, + // PersonType::Service, + // &inbox, + // true, + // ) + // .await? + // }; + // + // let user = User::from(user); + + let config = FederationConfig::builder() + .domain(url.domain().expect("system domain")) + //.signed_fetch_actor(&user) + .app_data(state) + .build() + .await?; + + Ok(config) +} diff --git a/src/server/state/mod.rs b/src/server/state/mod.rs index 0726689..f5f731e 100644 --- a/src/server/state/mod.rs +++ b/src/server/state/mod.rs @@ -1,33 +1,34 @@ pub mod database; +pub mod federation; + +use std::sync::Arc; use sellershut_auth::{ClientOptions, OauthClient}; -use sqlx::PgPool; #[cfg(feature = "oauth-discord")] use url::Url; -use crate::config::Config; #[cfg(feature = "oauth-discord")] use crate::config::DiscordOauth; +use crate::{config::Config, server::driver::SellershutDriver}; +#[derive(Clone)] pub struct AppState { - database: PgPool, + driver: Arc<dyn SellershutDriver>, #[cfg(feature = "oauth-discord")] oauth_discord: OauthClient, } impl AppState { - pub async fn new(config: &Config) -> anyhow::Result<Self> { - let database = database::connect(&config.database).await?; - + pub async fn new(config: &Config, driver: impl SellershutDriver) -> anyhow::Result<Self> { Ok(Self { - database, + driver: Arc::new(driver), oauth_discord: discord_client(&config.oauth.discord, &config.oauth.oauth_redirect_url)?, }) } } #[cfg(feature = "oauth-discord")] -fn discord_client(disc: &DiscordOauth, redirect: &Url)->anyhow::Result<OauthClient> { +fn discord_client(disc: &DiscordOauth, redirect: &Url) -> anyhow::Result<OauthClient> { let discord_opts = ClientOptions::builder() .client_id(disc.client_id.to_owned()) .redirect_url(redirect.to_string()) |
