aboutsummaryrefslogtreecommitdiffstats
path: root/src/server
diff options
context:
space:
mode:
Diffstat (limited to 'src/server')
-rw-r--r--src/server/driver/mod.rs16
-rw-r--r--src/server/routes/auth/mod.rs17
-rw-r--r--src/server/state/cache/cluster.rs58
-rw-r--r--src/server/state/cache/mod.rs240
-rw-r--r--src/server/state/cache/sentinel.rs56
-rw-r--r--src/server/state/mod.rs1
6 files changed, 376 insertions, 12 deletions
diff --git a/src/server/driver/mod.rs b/src/server/driver/mod.rs
index 68bd18c..2eaf7dc 100644
--- a/src/server/driver/mod.rs
+++ b/src/server/driver/mod.rs
@@ -5,26 +5,28 @@ pub mod auth;
use async_session::Session;
use async_trait::async_trait;
#[cfg(feature = "oauth")]
-use axum::{
- http::HeaderMap,
- response::{IntoResponse, Redirect},
-};
+use axum::{http::HeaderMap, response::Redirect};
#[cfg(feature = "oauth")]
use oauth2::CsrfToken;
use sqlx::PgPool;
-use crate::{config::DatabaseOptions, server::state::database};
+use crate::{
+ config::{DatabaseOptions, cache::CacheConfig},
+ server::state::{cache::RedisManager, database},
+};
#[derive(Debug, Clone)]
pub struct Services {
database: PgPool,
+ cache: RedisManager,
}
impl Services {
- pub async fn new(database: &DatabaseOptions) -> anyhow::Result<Self> {
+ pub async fn new(database: &DatabaseOptions, cache: &CacheConfig) -> anyhow::Result<Self> {
let database = database::connect(database).await?;
+ let cache = RedisManager::new(cache).await?;
- Ok(Self { database })
+ Ok(Self { database, cache })
}
}
diff --git a/src/server/routes/auth/mod.rs b/src/server/routes/auth/mod.rs
index 485983a..7d7ecf3 100644
--- a/src/server/routes/auth/mod.rs
+++ b/src/server/routes/auth/mod.rs
@@ -55,11 +55,14 @@ pub async fn auth(
Query(params): Query<Params>,
data: Data<AppState>,
) -> Result<impl IntoResponse, AppError> {
- match params.provider {
- #[cfg(feature = "oauth-discord")]
+ #[cfg(feature = "oauth-discord")]
+ return match params.provider {
OauthProvider::Discord => discord::discord_auth(data),
}
- .await
+ .await;
+
+ #[cfg(not(feature = "oauth-discord"))]
+ Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
}
#[utoipa::path(
@@ -79,9 +82,13 @@ pub async fn authorised(
Query(params): Query<Params>,
data: Data<AppState>,
) -> Result<impl IntoResponse, AppError> {
- match params.provider {
+ #[cfg(feature = "oauth-discord")]
+ return match params.provider {
#[cfg(feature = "oauth-discord")]
OauthProvider::Discord => discord::discord_auth(data),
}
- .await
+ .await;
+
+ #[cfg(not(feature = "oauth-discord"))]
+ Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
}
diff --git a/src/server/state/cache/cluster.rs b/src/server/state/cache/cluster.rs
new file mode 100644
index 0000000..ea71954
--- /dev/null
+++ b/src/server/state/cache/cluster.rs
@@ -0,0 +1,58 @@
+use bb8_redis::bb8;
+use redis::{
+ ErrorKind, FromRedisValue, IntoConnectionInfo, RedisError,
+ cluster::{ClusterClient, ClusterClientBuilder},
+ cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo},
+};
+
+/// ConnectionManager that implements `bb8::ManageConnection` and supports
+/// asynchronous clustered connections via `redis_cluster_async::Connection`
+#[derive(Clone)]
+pub struct RedisClusterConnectionManager {
+ client: ClusterClient,
+}
+
+impl RedisClusterConnectionManager {
+ pub fn new<T: IntoConnectionInfo>(
+ info: T,
+ ) -> Result<RedisClusterConnectionManager, RedisError> {
+ Ok(RedisClusterConnectionManager {
+ client: ClusterClientBuilder::new(vec![info]).retries(0).build()?,
+ })
+ }
+}
+
+impl bb8::ManageConnection for RedisClusterConnectionManager {
+ type Connection = redis::cluster_async::ClusterConnection;
+ type Error = RedisError;
+
+ async fn connect(&self) -> Result<Self::Connection, Self::Error> {
+ self.client.get_async_connection().await
+ }
+
+ async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
+ let cmd = redis::cmd("PING");
+ let pong = conn
+ .route_command(
+ cmd,
+ RoutingInfo::MultiNode((
+ MultipleNodeRoutingInfo::AllMasters,
+ Some(ResponsePolicy::OneSucceeded),
+ )),
+ )
+ .await
+ .and_then(|v| Ok(String::from_redis_value(v)?))?;
+ match pong.as_str() {
+ "PONG" => Ok(()),
+ _ => Err((
+ ErrorKind::Server(redis::ServerErrorKind::ResponseError),
+ "ping request",
+ )
+ .into()),
+ }
+ }
+
+ fn has_broken(&self, _: &mut Self::Connection) -> bool {
+ false
+ }
+}
diff --git a/src/server/state/cache/mod.rs b/src/server/state/cache/mod.rs
new file mode 100644
index 0000000..09af5f7
--- /dev/null
+++ b/src/server/state/cache/mod.rs
@@ -0,0 +1,240 @@
+mod cluster;
+mod sentinel;
+
+use anyhow::Result;
+use redis::{
+ AsyncConnectionConfig, ProtocolVersion, RedisConnectionInfo, RedisError, TlsMode,
+ aio::ConnectionManagerConfig, sentinel::SentinelNodeConnectionInfo,
+};
+use std::{fmt::Debug, sync::Arc};
+
+use bb8_redis::{
+ RedisConnectionManager,
+ bb8::{self, Pool, RunError},
+};
+use tokio::sync::Mutex;
+
+use crate::{
+ config::cache::{CacheConfig, RedisVariant},
+ server::state::cache::{
+ cluster::RedisClusterConnectionManager, sentinel::RedisSentinelConnectionManager,
+ },
+};
+
+const REDIS_CONN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);
+
+#[derive(Clone)]
+pub enum RedisManager {
+ Clustered(Pool<RedisClusterConnectionManager>),
+ NonClustered(Pool<RedisConnectionManager>),
+ Sentinel(Pool<RedisSentinelConnectionManager>),
+ ClusteredUnpooled(redis::cluster_async::ClusterConnection),
+ NonClusteredUnpooled(redis::aio::ConnectionManager),
+ SentinelUnpooled(Arc<Mutex<redis::sentinel::SentinelClient>>),
+}
+
+impl Debug for RedisManager {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::Clustered(arg0) => f.debug_tuple("Clustered").field(arg0).finish(),
+ Self::NonClustered(arg0) => f.debug_tuple("NonClustered").field(arg0).finish(),
+ Self::Sentinel(arg0) => f.debug_tuple("Sentinel").field(arg0).finish(),
+ Self::ClusteredUnpooled(_arg0) => f.debug_tuple("ClusteredUnpooled").finish(),
+ Self::NonClusteredUnpooled(arg0) => {
+ f.debug_tuple("NonClusteredUnpooled").field(arg0).finish()
+ }
+ Self::SentinelUnpooled(_arg0) => f.debug_tuple("SentinelUnpooled").finish(),
+ }
+ }
+}
+
+pub enum RedisConnection<'a> {
+ Clustered(bb8::PooledConnection<'a, RedisClusterConnectionManager>),
+ NonClustered(bb8::PooledConnection<'a, RedisConnectionManager>),
+ SentinelPooled(bb8::PooledConnection<'a, RedisSentinelConnectionManager>),
+ ClusteredUnpooled(redis::cluster_async::ClusterConnection),
+ NonClusteredUnpooled(redis::aio::ConnectionManager),
+ SentinelUnpooled(redis::aio::MultiplexedConnection),
+}
+
+impl redis::aio::ConnectionLike for RedisConnection<'_> {
+ fn req_packed_command<'a>(
+ &'a mut self,
+ cmd: &'a redis::Cmd,
+ ) -> redis::RedisFuture<'a, redis::Value> {
+ match self {
+ RedisConnection::Clustered(conn) => conn.req_packed_command(cmd),
+ RedisConnection::NonClustered(conn) => conn.req_packed_command(cmd),
+ RedisConnection::ClusteredUnpooled(conn) => conn.req_packed_command(cmd),
+ RedisConnection::NonClusteredUnpooled(conn) => conn.req_packed_command(cmd),
+ RedisConnection::SentinelPooled(conn) => conn.req_packed_command(cmd),
+ RedisConnection::SentinelUnpooled(conn) => conn.req_packed_command(cmd),
+ }
+ }
+
+ fn req_packed_commands<'a>(
+ &'a mut self,
+ cmd: &'a redis::Pipeline,
+ offset: usize,
+ count: usize,
+ ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
+ match self {
+ RedisConnection::Clustered(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::NonClustered(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::ClusteredUnpooled(conn) => {
+ conn.req_packed_commands(cmd, offset, count)
+ }
+ RedisConnection::NonClusteredUnpooled(conn) => {
+ conn.req_packed_commands(cmd, offset, count)
+ }
+ RedisConnection::SentinelPooled(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::SentinelUnpooled(conn) => conn.req_packed_commands(cmd, offset, count),
+ }
+ }
+
+ fn get_db(&self) -> i64 {
+ match self {
+ RedisConnection::Clustered(conn) => conn.get_db(),
+ RedisConnection::NonClustered(conn) => conn.get_db(),
+ RedisConnection::ClusteredUnpooled(conn) => conn.get_db(),
+ RedisConnection::NonClusteredUnpooled(conn) => conn.get_db(),
+ RedisConnection::SentinelPooled(conn) => conn.get_db(),
+ RedisConnection::SentinelUnpooled(conn) => conn.get_db(),
+ }
+ }
+}
+
+impl RedisManager {
+ pub async fn new(config: &CacheConfig) -> Result<Self> {
+ if config.pooled {
+ Self::new_pooled(
+ config.redis_dsn.as_ref(),
+ &config.kind,
+ config.max_connections,
+ )
+ .await
+ } else {
+ Self::new_unpooled(config.redis_dsn.as_ref(), &config.kind).await
+ }
+ }
+ async fn new_pooled(dsn: &str, variant: &RedisVariant, max_conns: u16) -> Result<Self> {
+ match variant {
+ RedisVariant::Clustered => {
+ let mgr = RedisClusterConnectionManager::new(dsn)?;
+ let pool = bb8::Pool::builder()
+ .max_size(max_conns.into())
+ .build(mgr)
+ .await?;
+ Ok(RedisManager::Clustered(pool))
+ }
+ RedisVariant::NonClustered => {
+ let mgr = RedisConnectionManager::new(dsn)?;
+ let pool = bb8::Pool::builder()
+ .max_size(max_conns.into())
+ .build(mgr)
+ .await?;
+ Ok(RedisManager::NonClustered(pool))
+ }
+ RedisVariant::Sentinel(cfg) => {
+ let mgr = RedisSentinelConnectionManager::new(
+ vec![dsn],
+ cfg.service_name.clone(),
+ Some(create_config(cfg)),
+ )?;
+ let pool = bb8::Pool::builder()
+ .max_size(max_conns.into())
+ .build(mgr)
+ .await?;
+ Ok(RedisManager::Sentinel(pool))
+ }
+ }
+ }
+
+ async fn new_unpooled(dsn: &str, variant: &RedisVariant) -> Result<Self> {
+ match variant {
+ RedisVariant::Clustered => {
+ let cli = redis::cluster::ClusterClient::builder(vec![dsn])
+ .retries(1)
+ .connection_timeout(REDIS_CONN_TIMEOUT)
+ .build()?;
+ let con = cli.get_async_connection().await?;
+ Ok(RedisManager::ClusteredUnpooled(con))
+ }
+ RedisVariant::NonClustered => {
+ let cli = redis::Client::open(dsn)?;
+ let con = redis::aio::ConnectionManager::new_with_config(
+ cli,
+ ConnectionManagerConfig::new()
+ .set_number_of_retries(1)
+ .set_connection_timeout(Some(REDIS_CONN_TIMEOUT)),
+ )
+ .await?;
+ Ok(RedisManager::NonClusteredUnpooled(con))
+ }
+ RedisVariant::Sentinel(cfg) => {
+ let cli = redis::sentinel::SentinelClient::build(
+ vec![dsn],
+ cfg.service_name.clone(),
+ Some(create_config(cfg)),
+ redis::sentinel::SentinelServerType::Master,
+ )?;
+
+ Ok(RedisManager::SentinelUnpooled(Arc::new(Mutex::new(cli))))
+ }
+ }
+ }
+
+ pub async fn get(&self) -> Result<RedisConnection<'_>, RunError<RedisError>> {
+ match self {
+ Self::Clustered(pool) => Ok(RedisConnection::Clustered(pool.get().await?)),
+ Self::NonClustered(pool) => Ok(RedisConnection::NonClustered(pool.get().await?)),
+ Self::Sentinel(pool) => Ok(RedisConnection::SentinelPooled(pool.get().await?)),
+ Self::ClusteredUnpooled(conn) => Ok(RedisConnection::ClusteredUnpooled(conn.clone())),
+ Self::NonClusteredUnpooled(conn) => {
+ Ok(RedisConnection::NonClusteredUnpooled(conn.clone()))
+ }
+ Self::SentinelUnpooled(conn) => {
+ let mut conn = conn.lock().await;
+ let con = conn
+ .get_async_connection_with_config(
+ &AsyncConnectionConfig::new()
+ .set_response_timeout(Some(REDIS_CONN_TIMEOUT)),
+ )
+ .await?;
+ Ok(RedisConnection::SentinelUnpooled(con))
+ }
+ }
+ }
+}
+
+fn create_config(cfg: &crate::config::cache::SentinelConfig) -> SentinelNodeConnectionInfo {
+ let tls_mode = cfg.redis_tls_mode_secure.then_some(TlsMode::Secure);
+ let protocol = if cfg.redis_use_resp3 {
+ ProtocolVersion::RESP3
+ } else {
+ ProtocolVersion::default()
+ };
+ let info = RedisConnectionInfo::default();
+ let info = if let Some(pass) = &cfg.redis_password {
+ info.set_password(pass.clone())
+ } else {
+ info
+ };
+
+ let info = if let Some(user) = &cfg.redis_username {
+ info.set_username(user.clone())
+ } else {
+ info
+ }
+ .set_protocol(protocol.clone())
+ .set_db(cfg.redis_db.unwrap_or(0));
+
+ let sent_info = SentinelNodeConnectionInfo::default();
+
+ if let Some(tls) = tls_mode {
+ sent_info.set_tls_mode(tls)
+ } else {
+ sent_info
+ }
+ .set_redis_connection_info(info)
+}
diff --git a/src/server/state/cache/sentinel.rs b/src/server/state/cache/sentinel.rs
new file mode 100644
index 0000000..8dcf394
--- /dev/null
+++ b/src/server/state/cache/sentinel.rs
@@ -0,0 +1,56 @@
+use bb8_redis::bb8;
+use redis::{
+ ErrorKind, IntoConnectionInfo, RedisError,
+ sentinel::{SentinelClient, SentinelNodeConnectionInfo, SentinelServerType},
+};
+use tokio::sync::Mutex;
+
+struct LockedSentinelClient(pub(crate) Mutex<SentinelClient>);
+
+/// ConnectionManager that implements `bb8::ManageConnection` and supports
+/// asynchronous Sentinel connections via `redis::sentinel::SentinelClient`
+pub struct RedisSentinelConnectionManager {
+ client: LockedSentinelClient,
+}
+
+impl RedisSentinelConnectionManager {
+ pub fn new<T: IntoConnectionInfo>(
+ info: Vec<T>,
+ service_name: String,
+ node_connection_info: Option<SentinelNodeConnectionInfo>,
+ ) -> Result<RedisSentinelConnectionManager, RedisError> {
+ Ok(RedisSentinelConnectionManager {
+ client: LockedSentinelClient(Mutex::new(SentinelClient::build(
+ info,
+ service_name,
+ node_connection_info,
+ SentinelServerType::Master,
+ )?)),
+ })
+ }
+}
+
+impl bb8::ManageConnection for RedisSentinelConnectionManager {
+ type Connection = redis::aio::MultiplexedConnection;
+ type Error = RedisError;
+
+ async fn connect(&self) -> Result<Self::Connection, Self::Error> {
+ self.client.0.lock().await.get_async_connection().await
+ }
+
+ async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
+ let pong: String = redis::cmd("PING").query_async(conn).await?;
+ match pong.as_str() {
+ "PONG" => Ok(()),
+ _ => Err((
+ ErrorKind::Server(redis::ServerErrorKind::ResponseError),
+ "ping request",
+ )
+ .into()),
+ }
+ }
+
+ fn has_broken(&self, _: &mut Self::Connection) -> bool {
+ false
+ }
+}
diff --git a/src/server/state/mod.rs b/src/server/state/mod.rs
index c86052d..f256949 100644
--- a/src/server/state/mod.rs
+++ b/src/server/state/mod.rs
@@ -1,3 +1,4 @@
+pub mod cache;
pub mod database;
pub mod federation;