aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorrtkay123 <dev@kanjala.com>2026-02-03 13:45:46 +0200
committerrtkay123 <dev@kanjala.com>2026-02-03 13:45:46 +0200
commiteb2e86997d47249aa31b703598de13ab2eb96caa (patch)
tree9a591adee7d027b305d07a04987b5559b99f4d37 /src
parent0ea3cb1d4743b922fbc6e07037096e75caffba8f (diff)
downloadsellershut-eb2e86997d47249aa31b703598de13ab2eb96caa.tar.bz2
sellershut-eb2e86997d47249aa31b703598de13ab2eb96caa.zip
feat: add cacheHEADmaster
Diffstat (limited to 'src')
-rw-r--r--src/config/cache.rs57
-rw-r--r--src/config/cli/cache/mod.rs46
-rw-r--r--src/config/cli/mod.rs (renamed from src/config/cli.rs)50
-rw-r--r--src/config/cli/oauth/mod.rs36
-rw-r--r--src/config/mod.rs35
-rw-r--r--src/main.rs4
-rw-r--r--src/server/driver/mod.rs16
-rw-r--r--src/server/routes/auth/mod.rs17
-rw-r--r--src/server/state/cache/cluster.rs58
-rw-r--r--src/server/state/cache/mod.rs240
-rw-r--r--src/server/state/cache/sentinel.rs56
-rw-r--r--src/server/state/mod.rs1
12 files changed, 560 insertions, 56 deletions
diff --git a/src/config/cache.rs b/src/config/cache.rs
new file mode 100644
index 0000000..96f3a9b
--- /dev/null
+++ b/src/config/cache.rs
@@ -0,0 +1,57 @@
+use serde::Deserialize;
+use url::Url;
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub struct CacheConfig {
+ #[serde(rename = "dsn")]
+ pub redis_dsn: Url,
+ #[serde(default)]
+ pub pooled: bool,
+ #[serde(rename = "type")]
+ pub kind: RedisVariant,
+ #[serde(default = "default_max_conns")]
+ #[serde(rename = "max-connections")]
+ pub max_connections: u16,
+}
+
+#[derive(Debug, Deserialize, Clone, Default)]
+#[serde(rename_all = "kebab-case")]
+pub enum RedisVariant {
+ Clustered,
+ #[default]
+ NonClustered,
+ Sentinel(SentinelConfig),
+}
+
+#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
+pub struct SentinelConfig {
+ #[serde(rename = "sentinel_service_name")]
+ pub service_name: String,
+ #[serde(default)]
+ pub redis_tls_mode_secure: bool,
+ pub redis_db: Option<i64>,
+ pub redis_username: Option<String>,
+ pub redis_password: Option<String>,
+ #[serde(default)]
+ pub redis_use_resp3: bool,
+}
+
+fn default_max_conns() -> u16 {
+ 100
+}
+
+fn default_cache() -> Url {
+ Url::parse("redis://localhost:6379").expect("valid default DATABASE url")
+}
+
+impl Default for CacheConfig {
+ fn default() -> Self {
+ Self {
+ redis_dsn: default_cache(),
+ pooled: Default::default(),
+ kind: Default::default(),
+ max_connections: default_max_conns(),
+ }
+ }
+}
diff --git a/src/config/cli/cache/mod.rs b/src/config/cli/cache/mod.rs
new file mode 100644
index 0000000..04b36bc
--- /dev/null
+++ b/src/config/cli/cache/mod.rs
@@ -0,0 +1,46 @@
+use clap::{Parser, ValueEnum};
+use serde::Deserialize;
+use url::Url;
+
+#[derive(Debug, Clone, Parser, Deserialize, Default)]
+pub struct Cache {
+ /// Cache connection string
+ #[arg(long, env = "CACHE_URL", default_value = "redis://localhost:6379")]
+ pub cache_url: Option<Url>,
+ #[arg(long, env = "CACHE_POOL_ENABLED", default_value = "true")]
+ pub cache_pooled: Option<bool>,
+ #[serde(rename = "type")]
+ #[arg(long, env = "CACHE_TYPE", default_value = "non-clustered")]
+ pub cache_type: Option<RedisVariant>,
+ #[serde(default = "default_max_conns")]
+ #[serde(rename = "max-connections")]
+ #[arg(long, env = "CACHE_MAX_CONNECTIONS", default_value = "100")]
+ pub cache_max_conn: Option<u16>,
+ #[command(flatten)]
+ pub sentinel_config: SentinelConfig,
+}
+
+#[derive(Debug, Deserialize, Clone, ValueEnum)]
+#[serde(rename_all = "kebab-case")]
+pub enum RedisVariant {
+ Clustered,
+ NonClustered,
+ Sentinel,
+}
+
+fn default_max_conns() -> Option<u16> {
+ Some(100)
+}
+
+#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Parser, Default)]
+pub struct SentinelConfig {
+ #[serde(rename = "sentinel_service_name")]
+ #[arg(long, env = "CACHE_SENTINEL_NAME", default_value = "true")]
+ pub service_name: Option<String>,
+ #[serde(default)]
+ #[arg(long, env = "CACHE_TLS_MODE_SECURE")]
+ pub cache_tls_mode_secure: bool,
+ #[serde(default)]
+ #[arg(long, env = "CACHE_USE_RESP3")]
+ pub cache_use_resp3: bool,
+}
diff --git a/src/config/cli.rs b/src/config/cli/mod.rs
index be1b913..81eb2fe 100644
--- a/src/config/cli.rs
+++ b/src/config/cli/mod.rs
@@ -1,9 +1,11 @@
+pub mod cache;
+
+#[cfg(feature = "oauth")]
+pub mod oauth;
+
use std::path::PathBuf;
use clap::Parser;
-#[cfg(feature = "oauth-discord")]
-use secrecy::SecretString;
-use serde::Deserialize;
use url::Url;
use crate::config::{logging::LogLevel, port::port_in_range};
@@ -31,14 +33,17 @@ pub struct Cli {
#[arg(short, long, env = "TIMEOUT_SECONDS", default_value = "10")]
pub timeout_duration: Option<u64>,
- /// Users database connection string
+ /// Database connection string
#[arg(
long,
- env = "USERS_DATABASE_URL",
+ env = "DATABASE_URL",
default_value = "postgres://postgres:password@localhost:5432/sellershut"
)]
pub db: Option<Url>,
+ #[command(flatten)]
+ pub cache: Option<cache::Cache>,
+
/// Server's system name
#[arg(short, long, default_value = "sellershut", env = "SYSTEM_NAME")]
pub system_name: Option<String>,
@@ -50,37 +55,7 @@ pub struct Cli {
/// Oauth optionas
#[command(flatten)]
#[cfg(feature = "oauth")]
- pub oauth: OAuth,
-}
-
-#[derive(Debug, Clone, Parser, Deserialize)]
-pub struct OAuth {
- #[cfg(feature = "oauth-discord")]
- #[command(flatten)]
- discord: DiscordOauth,
- #[arg(long, env = "OAUTH_REDIRECT_URL")]
- oauth_redirect_url: Option<Url>,
-}
-
-#[cfg(feature = "oauth-discord")]
-#[derive(Debug, Clone, Parser, Deserialize, Default)]
-pub struct DiscordOauth {
- #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")]
- discord_client_id: Option<String>,
- #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")]
- discord_client_secret: Option<SecretString>,
- #[arg(
- long,
- env = "OAUTH_DISCORD_TOKEN_URL",
- default_value = "https://discord.com/api/oauth2/token"
- )]
- discord_token_url: Option<Url>,
- #[arg(
- long,
- env = "OAUTH_DISCORD_AUTH_URL",
- default_value = "https://discord.com/api/oauth2/authorize?response_type=code"
- )]
- discord_auth_url: Option<Url>,
+ pub oauth: oauth::OAuth,
}
#[cfg(test)]
@@ -95,7 +70,8 @@ impl Default for Cli {
domain: Default::default(),
system_name: Default::default(),
environment: Default::default(),
- oauth: None,
+ oauth: Default::default(),
+ cache: Default::default(),
db: url,
}
}
diff --git a/src/config/cli/oauth/mod.rs b/src/config/cli/oauth/mod.rs
new file mode 100644
index 0000000..4bf1c34
--- /dev/null
+++ b/src/config/cli/oauth/mod.rs
@@ -0,0 +1,36 @@
+use clap::Parser;
+#[cfg(feature = "oauth-discord")]
+use secrecy::SecretString;
+use serde::Deserialize;
+#[cfg(feature = "oauth")]
+use url::Url;
+
+#[derive(Debug, Clone, Parser, Deserialize, Default)]
+pub struct OAuth {
+ #[cfg(feature = "oauth-discord")]
+ #[command(flatten)]
+ discord: DiscordOauth,
+ #[arg(long, env = "OAUTH_REDIRECT_URL")]
+ oauth_redirect_url: Option<Url>,
+}
+
+#[cfg(feature = "oauth-discord")]
+#[derive(Debug, Clone, Parser, Deserialize, Default)]
+pub struct DiscordOauth {
+ #[arg(long, env = "OAUTH_DISCORD_CLIENT_ID")]
+ discord_client_id: Option<String>,
+ #[arg(long, env = "OAUTH_DISCORD_CLIENT_SECRET")]
+ discord_client_secret: Option<SecretString>,
+ #[arg(
+ long,
+ env = "OAUTH_DISCORD_TOKEN_URL",
+ default_value = "https://discord.com/api/oauth2/token"
+ )]
+ discord_token_url: Option<Url>,
+ #[arg(
+ long,
+ env = "OAUTH_DISCORD_AUTH_URL",
+ default_value = "https://discord.com/api/oauth2/authorize?response_type=code"
+ )]
+ discord_auth_url: Option<Url>,
+}
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 7495b22..e64ae5c 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -1,13 +1,15 @@
+pub mod cache;
mod cli;
mod logging;
mod port;
pub use cli::Cli;
-#[cfg(feature = "oauth")]
-use secrecy::SecretString;
use serde::Deserialize;
use url::Url;
-use crate::config::logging::LogLevel;
+use crate::config::{
+ cache::{CacheConfig, RedisVariant},
+ logging::LogLevel,
+};
#[derive(Default, Deserialize, Debug, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
@@ -23,6 +25,8 @@ pub struct Config {
#[serde(default)]
pub database: DatabaseOptions,
#[serde(default)]
+ pub cache: CacheConfig,
+ #[serde(default)]
pub server: Api,
#[serde(default)]
#[cfg(feature = "oauth")]
@@ -65,7 +69,7 @@ pub struct OAuth {
#[serde(rename_all = "kebab-case")]
pub struct DiscordOauth {
pub client_id: String,
- pub client_secret: SecretString,
+ pub client_secret: secrecy::SecretString,
#[serde(default = "discord_token_url")]
pub token_url: Url,
#[serde(default = "discord_auth_url")]
@@ -94,7 +98,7 @@ impl Default for OAuth {
#[cfg(feature = "oauth-discord")]
discord: DiscordOauth {
client_id: String::default(),
- client_secret: SecretString::default(),
+ client_secret: secrecy::SecretString::default(),
token_url: discord_token_url(),
auth_url: discord_auth_url(),
},
@@ -175,6 +179,7 @@ impl Config {
pub fn merge_with_cli(&mut self, cli: &Cli) {
let server = &mut self.server;
let dsn = &mut self.database;
+ let cache = &mut self.cache;
if let Some(port) = cli.port {
server.port = port;
@@ -195,6 +200,26 @@ impl Config {
if let Some(db_url) = &cli.db {
dsn.url = db_url.clone();
}
+
+ if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_url.clone()) {
+ cache.redis_dsn = c;
+ }
+
+ if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_pooled) {
+ cache.pooled = c;
+ }
+
+ if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_max_conn) {
+ cache.max_connections = c;
+ }
+
+ if let Some(c) = cli.cache.as_ref().and_then(|v| v.cache_type.clone()) {
+ cache.kind = match c {
+ cli::cache::RedisVariant::Clustered => RedisVariant::Clustered,
+ cli::cache::RedisVariant::NonClustered => RedisVariant::NonClustered,
+ cli::cache::RedisVariant::Sentinel => cache.kind.clone(),
+ };
+ }
}
}
diff --git a/src/main.rs b/src/main.rs
index 2018956..971c08c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,4 @@
-mod config;
+pub mod config;
mod logging;
mod server;
@@ -29,7 +29,7 @@ async fn main() -> anyhow::Result<()> {
initialise_logging(&config);
- let driver = Services::new(&config.database).await?;
+ let driver = Services::new(&config.database, &config.cache).await?;
let state = AppState::new(&config, driver).await?;
let router = server::router(&config, state).await?;
diff --git a/src/server/driver/mod.rs b/src/server/driver/mod.rs
index 68bd18c..2eaf7dc 100644
--- a/src/server/driver/mod.rs
+++ b/src/server/driver/mod.rs
@@ -5,26 +5,28 @@ pub mod auth;
use async_session::Session;
use async_trait::async_trait;
#[cfg(feature = "oauth")]
-use axum::{
- http::HeaderMap,
- response::{IntoResponse, Redirect},
-};
+use axum::{http::HeaderMap, response::Redirect};
#[cfg(feature = "oauth")]
use oauth2::CsrfToken;
use sqlx::PgPool;
-use crate::{config::DatabaseOptions, server::state::database};
+use crate::{
+ config::{DatabaseOptions, cache::CacheConfig},
+ server::state::{cache::RedisManager, database},
+};
#[derive(Debug, Clone)]
pub struct Services {
database: PgPool,
+ cache: RedisManager,
}
impl Services {
- pub async fn new(database: &DatabaseOptions) -> anyhow::Result<Self> {
+ pub async fn new(database: &DatabaseOptions, cache: &CacheConfig) -> anyhow::Result<Self> {
let database = database::connect(database).await?;
+ let cache = RedisManager::new(cache).await?;
- Ok(Self { database })
+ Ok(Self { database, cache })
}
}
diff --git a/src/server/routes/auth/mod.rs b/src/server/routes/auth/mod.rs
index 485983a..7d7ecf3 100644
--- a/src/server/routes/auth/mod.rs
+++ b/src/server/routes/auth/mod.rs
@@ -55,11 +55,14 @@ pub async fn auth(
Query(params): Query<Params>,
data: Data<AppState>,
) -> Result<impl IntoResponse, AppError> {
- match params.provider {
- #[cfg(feature = "oauth-discord")]
+ #[cfg(feature = "oauth-discord")]
+ return match params.provider {
OauthProvider::Discord => discord::discord_auth(data),
}
- .await
+ .await;
+
+ #[cfg(not(feature = "oauth-discord"))]
+ Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
}
#[utoipa::path(
@@ -79,9 +82,13 @@ pub async fn authorised(
Query(params): Query<Params>,
data: Data<AppState>,
) -> Result<impl IntoResponse, AppError> {
- match params.provider {
+ #[cfg(feature = "oauth-discord")]
+ return match params.provider {
#[cfg(feature = "oauth-discord")]
OauthProvider::Discord => discord::discord_auth(data),
}
- .await
+ .await;
+
+ #[cfg(not(feature = "oauth-discord"))]
+ Ok(axum::http::StatusCode::INTERNAL_SERVER_ERROR)
}
diff --git a/src/server/state/cache/cluster.rs b/src/server/state/cache/cluster.rs
new file mode 100644
index 0000000..ea71954
--- /dev/null
+++ b/src/server/state/cache/cluster.rs
@@ -0,0 +1,58 @@
+use bb8_redis::bb8;
+use redis::{
+ ErrorKind, FromRedisValue, IntoConnectionInfo, RedisError,
+ cluster::{ClusterClient, ClusterClientBuilder},
+ cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo},
+};
+
+/// ConnectionManager that implements `bb8::ManageConnection` and supports
+/// asynchronous clustered connections via `redis_cluster_async::Connection`
+#[derive(Clone)]
+pub struct RedisClusterConnectionManager {
+ client: ClusterClient,
+}
+
+impl RedisClusterConnectionManager {
+ pub fn new<T: IntoConnectionInfo>(
+ info: T,
+ ) -> Result<RedisClusterConnectionManager, RedisError> {
+ Ok(RedisClusterConnectionManager {
+ client: ClusterClientBuilder::new(vec![info]).retries(0).build()?,
+ })
+ }
+}
+
+impl bb8::ManageConnection for RedisClusterConnectionManager {
+ type Connection = redis::cluster_async::ClusterConnection;
+ type Error = RedisError;
+
+ async fn connect(&self) -> Result<Self::Connection, Self::Error> {
+ self.client.get_async_connection().await
+ }
+
+ async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
+ let cmd = redis::cmd("PING");
+ let pong = conn
+ .route_command(
+ cmd,
+ RoutingInfo::MultiNode((
+ MultipleNodeRoutingInfo::AllMasters,
+ Some(ResponsePolicy::OneSucceeded),
+ )),
+ )
+ .await
+ .and_then(|v| Ok(String::from_redis_value(v)?))?;
+ match pong.as_str() {
+ "PONG" => Ok(()),
+ _ => Err((
+ ErrorKind::Server(redis::ServerErrorKind::ResponseError),
+ "ping request",
+ )
+ .into()),
+ }
+ }
+
+ fn has_broken(&self, _: &mut Self::Connection) -> bool {
+ false
+ }
+}
diff --git a/src/server/state/cache/mod.rs b/src/server/state/cache/mod.rs
new file mode 100644
index 0000000..09af5f7
--- /dev/null
+++ b/src/server/state/cache/mod.rs
@@ -0,0 +1,240 @@
+mod cluster;
+mod sentinel;
+
+use anyhow::Result;
+use redis::{
+ AsyncConnectionConfig, ProtocolVersion, RedisConnectionInfo, RedisError, TlsMode,
+ aio::ConnectionManagerConfig, sentinel::SentinelNodeConnectionInfo,
+};
+use std::{fmt::Debug, sync::Arc};
+
+use bb8_redis::{
+ RedisConnectionManager,
+ bb8::{self, Pool, RunError},
+};
+use tokio::sync::Mutex;
+
+use crate::{
+ config::cache::{CacheConfig, RedisVariant},
+ server::state::cache::{
+ cluster::RedisClusterConnectionManager, sentinel::RedisSentinelConnectionManager,
+ },
+};
+
+const REDIS_CONN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);
+
+#[derive(Clone)]
+pub enum RedisManager {
+ Clustered(Pool<RedisClusterConnectionManager>),
+ NonClustered(Pool<RedisConnectionManager>),
+ Sentinel(Pool<RedisSentinelConnectionManager>),
+ ClusteredUnpooled(redis::cluster_async::ClusterConnection),
+ NonClusteredUnpooled(redis::aio::ConnectionManager),
+ SentinelUnpooled(Arc<Mutex<redis::sentinel::SentinelClient>>),
+}
+
+impl Debug for RedisManager {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::Clustered(arg0) => f.debug_tuple("Clustered").field(arg0).finish(),
+ Self::NonClustered(arg0) => f.debug_tuple("NonClustered").field(arg0).finish(),
+ Self::Sentinel(arg0) => f.debug_tuple("Sentinel").field(arg0).finish(),
+ Self::ClusteredUnpooled(_arg0) => f.debug_tuple("ClusteredUnpooled").finish(),
+ Self::NonClusteredUnpooled(arg0) => {
+ f.debug_tuple("NonClusteredUnpooled").field(arg0).finish()
+ }
+ Self::SentinelUnpooled(_arg0) => f.debug_tuple("SentinelUnpooled").finish(),
+ }
+ }
+}
+
+pub enum RedisConnection<'a> {
+ Clustered(bb8::PooledConnection<'a, RedisClusterConnectionManager>),
+ NonClustered(bb8::PooledConnection<'a, RedisConnectionManager>),
+ SentinelPooled(bb8::PooledConnection<'a, RedisSentinelConnectionManager>),
+ ClusteredUnpooled(redis::cluster_async::ClusterConnection),
+ NonClusteredUnpooled(redis::aio::ConnectionManager),
+ SentinelUnpooled(redis::aio::MultiplexedConnection),
+}
+
+impl redis::aio::ConnectionLike for RedisConnection<'_> {
+ fn req_packed_command<'a>(
+ &'a mut self,
+ cmd: &'a redis::Cmd,
+ ) -> redis::RedisFuture<'a, redis::Value> {
+ match self {
+ RedisConnection::Clustered(conn) => conn.req_packed_command(cmd),
+ RedisConnection::NonClustered(conn) => conn.req_packed_command(cmd),
+ RedisConnection::ClusteredUnpooled(conn) => conn.req_packed_command(cmd),
+ RedisConnection::NonClusteredUnpooled(conn) => conn.req_packed_command(cmd),
+ RedisConnection::SentinelPooled(conn) => conn.req_packed_command(cmd),
+ RedisConnection::SentinelUnpooled(conn) => conn.req_packed_command(cmd),
+ }
+ }
+
+ fn req_packed_commands<'a>(
+ &'a mut self,
+ cmd: &'a redis::Pipeline,
+ offset: usize,
+ count: usize,
+ ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
+ match self {
+ RedisConnection::Clustered(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::NonClustered(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::ClusteredUnpooled(conn) => {
+ conn.req_packed_commands(cmd, offset, count)
+ }
+ RedisConnection::NonClusteredUnpooled(conn) => {
+ conn.req_packed_commands(cmd, offset, count)
+ }
+ RedisConnection::SentinelPooled(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::SentinelUnpooled(conn) => conn.req_packed_commands(cmd, offset, count),
+ }
+ }
+
+ fn get_db(&self) -> i64 {
+ match self {
+ RedisConnection::Clustered(conn) => conn.get_db(),
+ RedisConnection::NonClustered(conn) => conn.get_db(),
+ RedisConnection::ClusteredUnpooled(conn) => conn.get_db(),
+ RedisConnection::NonClusteredUnpooled(conn) => conn.get_db(),
+ RedisConnection::SentinelPooled(conn) => conn.get_db(),
+ RedisConnection::SentinelUnpooled(conn) => conn.get_db(),
+ }
+ }
+}
+
+impl RedisManager {
+ pub async fn new(config: &CacheConfig) -> Result<Self> {
+ if config.pooled {
+ Self::new_pooled(
+ config.redis_dsn.as_ref(),
+ &config.kind,
+ config.max_connections,
+ )
+ .await
+ } else {
+ Self::new_unpooled(config.redis_dsn.as_ref(), &config.kind).await
+ }
+ }
+ async fn new_pooled(dsn: &str, variant: &RedisVariant, max_conns: u16) -> Result<Self> {
+ match variant {
+ RedisVariant::Clustered => {
+ let mgr = RedisClusterConnectionManager::new(dsn)?;
+ let pool = bb8::Pool::builder()
+ .max_size(max_conns.into())
+ .build(mgr)
+ .await?;
+ Ok(RedisManager::Clustered(pool))
+ }
+ RedisVariant::NonClustered => {
+ let mgr = RedisConnectionManager::new(dsn)?;
+ let pool = bb8::Pool::builder()
+ .max_size(max_conns.into())
+ .build(mgr)
+ .await?;
+ Ok(RedisManager::NonClustered(pool))
+ }
+ RedisVariant::Sentinel(cfg) => {
+ let mgr = RedisSentinelConnectionManager::new(
+ vec![dsn],
+ cfg.service_name.clone(),
+ Some(create_config(cfg)),
+ )?;
+ let pool = bb8::Pool::builder()
+ .max_size(max_conns.into())
+ .build(mgr)
+ .await?;
+ Ok(RedisManager::Sentinel(pool))
+ }
+ }
+ }
+
+ async fn new_unpooled(dsn: &str, variant: &RedisVariant) -> Result<Self> {
+ match variant {
+ RedisVariant::Clustered => {
+ let cli = redis::cluster::ClusterClient::builder(vec![dsn])
+ .retries(1)
+ .connection_timeout(REDIS_CONN_TIMEOUT)
+ .build()?;
+ let con = cli.get_async_connection().await?;
+ Ok(RedisManager::ClusteredUnpooled(con))
+ }
+ RedisVariant::NonClustered => {
+ let cli = redis::Client::open(dsn)?;
+ let con = redis::aio::ConnectionManager::new_with_config(
+ cli,
+ ConnectionManagerConfig::new()
+ .set_number_of_retries(1)
+ .set_connection_timeout(Some(REDIS_CONN_TIMEOUT)),
+ )
+ .await?;
+ Ok(RedisManager::NonClusteredUnpooled(con))
+ }
+ RedisVariant::Sentinel(cfg) => {
+ let cli = redis::sentinel::SentinelClient::build(
+ vec![dsn],
+ cfg.service_name.clone(),
+ Some(create_config(cfg)),
+ redis::sentinel::SentinelServerType::Master,
+ )?;
+
+ Ok(RedisManager::SentinelUnpooled(Arc::new(Mutex::new(cli))))
+ }
+ }
+ }
+
+ pub async fn get(&self) -> Result<RedisConnection<'_>, RunError<RedisError>> {
+ match self {
+ Self::Clustered(pool) => Ok(RedisConnection::Clustered(pool.get().await?)),
+ Self::NonClustered(pool) => Ok(RedisConnection::NonClustered(pool.get().await?)),
+ Self::Sentinel(pool) => Ok(RedisConnection::SentinelPooled(pool.get().await?)),
+ Self::ClusteredUnpooled(conn) => Ok(RedisConnection::ClusteredUnpooled(conn.clone())),
+ Self::NonClusteredUnpooled(conn) => {
+ Ok(RedisConnection::NonClusteredUnpooled(conn.clone()))
+ }
+ Self::SentinelUnpooled(conn) => {
+ let mut conn = conn.lock().await;
+ let con = conn
+ .get_async_connection_with_config(
+ &AsyncConnectionConfig::new()
+ .set_response_timeout(Some(REDIS_CONN_TIMEOUT)),
+ )
+ .await?;
+ Ok(RedisConnection::SentinelUnpooled(con))
+ }
+ }
+ }
+}
+
+fn create_config(cfg: &crate::config::cache::SentinelConfig) -> SentinelNodeConnectionInfo {
+ let tls_mode = cfg.redis_tls_mode_secure.then_some(TlsMode::Secure);
+ let protocol = if cfg.redis_use_resp3 {
+ ProtocolVersion::RESP3
+ } else {
+ ProtocolVersion::default()
+ };
+ let info = RedisConnectionInfo::default();
+ let info = if let Some(pass) = &cfg.redis_password {
+ info.set_password(pass.clone())
+ } else {
+ info
+ };
+
+ let info = if let Some(user) = &cfg.redis_username {
+ info.set_username(user.clone())
+ } else {
+ info
+ }
+ .set_protocol(protocol.clone())
+ .set_db(cfg.redis_db.unwrap_or(0));
+
+ let sent_info = SentinelNodeConnectionInfo::default();
+
+ if let Some(tls) = tls_mode {
+ sent_info.set_tls_mode(tls)
+ } else {
+ sent_info
+ }
+ .set_redis_connection_info(info)
+}
diff --git a/src/server/state/cache/sentinel.rs b/src/server/state/cache/sentinel.rs
new file mode 100644
index 0000000..8dcf394
--- /dev/null
+++ b/src/server/state/cache/sentinel.rs
@@ -0,0 +1,56 @@
+use bb8_redis::bb8;
+use redis::{
+ ErrorKind, IntoConnectionInfo, RedisError,
+ sentinel::{SentinelClient, SentinelNodeConnectionInfo, SentinelServerType},
+};
+use tokio::sync::Mutex;
+
+struct LockedSentinelClient(pub(crate) Mutex<SentinelClient>);
+
+/// ConnectionManager that implements `bb8::ManageConnection` and supports
+/// asynchronous Sentinel connections via `redis::sentinel::SentinelClient`
+pub struct RedisSentinelConnectionManager {
+ client: LockedSentinelClient,
+}
+
+impl RedisSentinelConnectionManager {
+ pub fn new<T: IntoConnectionInfo>(
+ info: Vec<T>,
+ service_name: String,
+ node_connection_info: Option<SentinelNodeConnectionInfo>,
+ ) -> Result<RedisSentinelConnectionManager, RedisError> {
+ Ok(RedisSentinelConnectionManager {
+ client: LockedSentinelClient(Mutex::new(SentinelClient::build(
+ info,
+ service_name,
+ node_connection_info,
+ SentinelServerType::Master,
+ )?)),
+ })
+ }
+}
+
+impl bb8::ManageConnection for RedisSentinelConnectionManager {
+ type Connection = redis::aio::MultiplexedConnection;
+ type Error = RedisError;
+
+ async fn connect(&self) -> Result<Self::Connection, Self::Error> {
+ self.client.0.lock().await.get_async_connection().await
+ }
+
+ async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
+ let pong: String = redis::cmd("PING").query_async(conn).await?;
+ match pong.as_str() {
+ "PONG" => Ok(()),
+ _ => Err((
+ ErrorKind::Server(redis::ServerErrorKind::ResponseError),
+ "ping request",
+ )
+ .into()),
+ }
+ }
+
+ fn has_broken(&self, _: &mut Self::Connection) -> bool {
+ false
+ }
+}
diff --git a/src/server/state/mod.rs b/src/server/state/mod.rs
index c86052d..f256949 100644
--- a/src/server/state/mod.rs
+++ b/src/server/state/mod.rs
@@ -1,3 +1,4 @@
+pub mod cache;
pub mod database;
pub mod federation;