aboutsummaryrefslogtreecommitdiffstats
path: root/crates
diff options
context:
space:
mode:
authorrtkay123 <dev@kanjala.com>2026-04-05 15:17:55 +0200
committerrtkay123 <dev@kanjala.com>2026-04-05 15:17:55 +0200
commit3f708c5fffed105b27965f8e844a26de6bdf9662 (patch)
treefbed157ae7fc15a26a86fba5e0b8b9c5107ee07f /crates
parente86366c6d68b9d3d2af4ac4afb5cf7d5a8400dde (diff)
downloadsellershut-3f708c5fffed105b27965f8e844a26de6bdf9662.tar.bz2
sellershut-3f708c5fffed105b27965f8e844a26de6bdf9662.zip
feat(cli): cache
Diffstat (limited to 'crates')
-rw-r--r--crates/api-auth/Cargo.toml3
-rw-r--r--crates/api-auth/src/discord/mod.rs16
-rw-r--r--crates/api-auth/src/lib.rs6
-rw-r--r--crates/sellershut/Cargo.toml1
-rw-r--r--crates/sellershut/src/config/cache/mod.rs231
-rw-r--r--crates/sellershut/src/config/mod.rs7
-rw-r--r--crates/sellershut/src/main.rs8
-rw-r--r--crates/sellershut/src/server/api/routes/auth/discord.rs2
-rw-r--r--crates/sh-util/Cargo.toml26
-rw-r--r--crates/sh-util/src/cache/cluster.rs56
-rw-r--r--crates/sh-util/src/cache/mod.rs176
-rw-r--r--crates/sh-util/src/cache/sentinel.rs66
-rw-r--r--crates/sh-util/src/lib.rs2
13 files changed, 588 insertions, 12 deletions
diff --git a/crates/api-auth/Cargo.toml b/crates/api-auth/Cargo.toml
index 053bbb9..a0868a5 100644
--- a/crates/api-auth/Cargo.toml
+++ b/crates/api-auth/Cargo.toml
@@ -13,6 +13,7 @@ async-trait.workspace = true
oauth2 = "5.0.0"
secrecy.workspace = true
serde.workspace = true
+sh-util = { workspace = true, optional = true }
sqlx.workspace = true
thiserror.workspace = true
utoipa = { workspace = true, optional = true }
@@ -20,5 +21,5 @@ url.workspace = true
async-session = "3.0.0"
[features]
-discord = []
+discord = ["sh-util/cache"]
utoipa = ["dep:utoipa", "serde/derive"]
diff --git a/crates/api-auth/src/discord/mod.rs b/crates/api-auth/src/discord/mod.rs
index 29b9bc2..dbcb139 100644
--- a/crates/api-auth/src/discord/mod.rs
+++ b/crates/api-auth/src/discord/mod.rs
@@ -2,19 +2,25 @@ use api_core::models::user::User;
use async_session::Session;
use async_trait::async_trait;
use oauth2::{CsrfToken, Scope};
+use sh_util::cache::RedisManager;
use sqlx::PgPool;
use crate::{BasicClient, CSRF_TOKEN, OauthDriver, error::AuthError};
-#[derive(Clone, Debug)]
+#[derive(Clone)]
pub struct AuthServiceDiscord {
database: PgPool,
+ cache: RedisManager,
client: BasicClient,
}
impl AuthServiceDiscord {
- pub fn new(database: PgPool, client: BasicClient) -> Self {
- Self { database, client }
+ pub fn new(database: PgPool, client: BasicClient, cache: RedisManager) -> Self {
+ Self {
+ database,
+ client,
+ cache,
+ }
}
}
@@ -26,7 +32,7 @@ impl OauthDriver for AuthServiceDiscord {
async fn get_user(&self) -> Result<User, AuthError> {
todo!()
}
- async fn create_oauth_session(&self)->Result<String,AuthError> {
+ async fn create_oauth_session(&self) -> Result<String, AuthError> {
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
@@ -38,7 +44,7 @@ impl OauthDriver for AuthServiceDiscord {
Ok(String::default())
}
- async fn save_session(&self, user: &User)->Result<(), AuthError>{
+ async fn save_session(&self, user: &User) -> Result<(), AuthError> {
todo!()
}
}
diff --git a/crates/api-auth/src/lib.rs b/crates/api-auth/src/lib.rs
index 95a04c4..367d395 100644
--- a/crates/api-auth/src/lib.rs
+++ b/crates/api-auth/src/lib.rs
@@ -20,11 +20,11 @@ type C = oauth2::basic::BasicClient<
pub struct BasicClient(C);
#[async_trait::async_trait]
-pub trait OauthDriver: Send + Sync + std::fmt::Debug {
+pub trait OauthDriver: Send + Sync {
async fn get_auth_token(&self) -> Result<String, AuthError>;
async fn get_user(&self) -> Result<User, AuthError>;
- async fn create_oauth_session(&self)->Result<String, AuthError>;
- async fn save_session(&self, user: &User)->Result<(), AuthError>;
+ async fn create_oauth_session(&self) -> Result<String, AuthError>;
+ async fn save_session(&self, user: &User) -> Result<(), AuthError>;
}
use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
diff --git a/crates/sellershut/Cargo.toml b/crates/sellershut/Cargo.toml
index 14a686c..caf6fd0 100644
--- a/crates/sellershut/Cargo.toml
+++ b/crates/sellershut/Cargo.toml
@@ -18,6 +18,7 @@ clap = { version = "4.6.0", features = ["derive", "env"] }
secrecy = { workspace = true, features = ["serde"] }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
+sh-util = { workspace = true, features = ["cache"] }
sqlx = { workspace = true, features = ["migrate"] }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] }
toml = "1.1.2"
diff --git a/crates/sellershut/src/config/cache/mod.rs b/crates/sellershut/src/config/cache/mod.rs
new file mode 100644
index 0000000..136c3a4
--- /dev/null
+++ b/crates/sellershut/src/config/cache/mod.rs
@@ -0,0 +1,231 @@
+use anyhow::Context;
+use clap::{Args, ValueEnum};
+use serde::{Deserialize, Serialize};
+use sh_util::cache::{RedisVariant, SentinelConfig};
+
+#[derive(Debug, Clone, Copy, ValueEnum, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum CacheMode {
+ Standalone,
+ Clustered,
+ Sentinel,
+}
+
+#[derive(Debug, Clone, Args, Serialize, Deserialize, Default, PartialEq, Eq)]
+#[serde(default, rename_all = "kebab-case")]
+pub struct CacheConfig {
+ /// Cache mode: standalone, clustered, or sentinel.
+ #[arg(long, env = "HUT_CACHE_MODE", value_enum)]
+ #[serde(rename = "mode", skip_serializing_if = "Option::is_none")]
+ pub cache_mode: Option<CacheMode>,
+
+ /// Full Redis URL. Useful for standalone mode and can override host/port style inputs.
+ #[arg(long, env = "HUT_CACHE_URL")]
+ #[serde(rename = "url", skip_serializing_if = "Option::is_none")]
+ pub cache_url: Option<String>,
+
+ /// Redis host for standalone mode.
+ #[arg(long, env = "HUT_CACHE_HOST")]
+ #[serde(rename = "host", skip_serializing_if = "Option::is_none")]
+ pub cache_host: Option<String>,
+
+ /// Redis port for standalone mode.
+ #[arg(long, env = "HUT_CACHE_PORT")]
+ #[serde(rename = "port", skip_serializing_if = "Option::is_none")]
+ pub cache_port: Option<u16>,
+
+ /// Comma-delimited node list for clustered or sentinel discovery, e.g. host1:6379,host2:6379.
+ #[arg(long, env = "HUT_CACHE_NODES", value_delimiter = ',')]
+ #[serde(rename = "nodes", skip_serializing_if = "Vec::is_empty")]
+ pub cache_nodes: Vec<String>,
+
+ /// Redis username.
+ #[arg(long, env = "HUT_CACHE_USERNAME")]
+ #[serde(rename = "username", skip_serializing_if = "Option::is_none")]
+ pub cache_username: Option<String>,
+
+ /// Redis password.
+ #[arg(long, env = "HUT_CACHE_PASSWORD")]
+ #[serde(rename = "password", skip_serializing_if = "Option::is_none")]
+ pub cache_password: Option<String>,
+
+ /// Redis logical database number.
+ #[arg(long, env = "HUT_CACHE_DB")]
+ #[serde(rename = "database", skip_serializing_if = "Option::is_none")]
+ pub cache_database: Option<u32>,
+
+ /// Sentinel service name. Required for sentinel mode.
+ #[arg(long, env = "HUT_CACHE_SERVICE_NAME")]
+ #[serde(rename = "service_name", skip_serializing_if = "Option::is_none")]
+ pub cache_service_name: Option<String>,
+
+ /// Whether Redis TLS should use secure mode.
+ #[arg(long, env = "HUT_CACHE_TLS_MODE_SECURE")]
+ #[serde(rename = "tls_mode_secure", skip_serializing_if = "Option::is_none")]
+ pub cache_tls_mode_secure: Option<bool>,
+
+ /// Whether the client should use RESP3.
+ #[arg(long, env = "HUT_CACHE_USE_RESP3")]
+ #[serde(rename = "use_resp3", skip_serializing_if = "Option::is_none")]
+ pub cache_use_resp3: Option<bool>,
+}
+
+impl CacheConfig {
+ pub fn merge(self, higher: Self) -> Self {
+ Self {
+ cache_mode: higher.cache_mode.or(self.cache_mode),
+ cache_url: higher.cache_url.or(self.cache_url),
+ cache_host: higher.cache_host.or(self.cache_host),
+ cache_port: higher.cache_port.or(self.cache_port),
+ cache_nodes: if higher.cache_nodes.is_empty() {
+ self.cache_nodes
+ } else {
+ higher.cache_nodes
+ },
+ cache_username: higher.cache_username.or(self.cache_username),
+ cache_password: higher.cache_password.or(self.cache_password),
+ cache_database: higher.cache_database.or(self.cache_database),
+ cache_service_name: higher.cache_service_name.or(self.cache_service_name),
+ cache_tls_mode_secure: higher
+ .cache_tls_mode_secure
+ .or(self.cache_tls_mode_secure),
+ cache_use_resp3: higher
+ .cache_use_resp3
+ .or(self.cache_use_resp3),
+ }
+ }
+
+ pub fn with_defaults(self) -> Self {
+ Self {
+ cache_mode: Some(self.cache_mode.unwrap_or(CacheMode::Standalone)),
+ cache_url: self.cache_url,
+ cache_host: Some(self.cache_host.unwrap_or_else(|| "127.0.0.1".to_string())),
+ cache_port: Some(self.cache_port.unwrap_or(6379)),
+ cache_nodes: self.cache_nodes,
+ cache_username: self.cache_username,
+ cache_password: self.cache_password,
+ cache_database: Some(self.cache_database.unwrap_or(0)),
+ cache_service_name: self.cache_service_name,
+ cache_tls_mode_secure: Some(self.cache_tls_mode_secure.unwrap_or(false)),
+ cache_use_resp3: Some(self.cache_use_resp3.unwrap_or(false)),
+ }
+ }
+
+ pub fn defaults() -> Self {
+ Self::default().with_defaults()
+ }
+
+ pub fn mode(&self) -> CacheMode {
+ self.cache_mode.unwrap_or(CacheMode::Standalone)
+ }
+
+ pub fn url(&self) -> anyhow::Result<String> {
+ if let Some(url) = &self.cache_url {
+ return Ok(url.clone());
+ }
+
+ match self.mode() {
+ CacheMode::Standalone => {
+ let host = self
+ .cache_host
+ .as_deref()
+ .context("cache.host")?;
+ let port = self
+ .cache_port
+ .context("cache.port")?;
+ let db = self.cache_database.unwrap_or(0);
+
+ let auth = match (&self.cache_username, &self.cache_password) {
+ (Some(username), Some(password)) => format!("{username}:{password}@"),
+ (None, Some(password)) => format!(":{password}@"),
+ (Some(username), None) => format!("{username}@"),
+ (None, None) => String::new(),
+ };
+
+ Ok(format!("redis://{}{}:{}/{}", auth, host, port, db))
+ }
+ CacheMode::Clustered | CacheMode::Sentinel => {
+ self.cache_nodes
+ .first()
+ .cloned()
+ .context("cache.nodes[0]")
+ }
+ }
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum CacheConfigConversionError {
+ WrongMode(CacheMode),
+ MissingField(&'static str),
+}
+
+impl std::fmt::Display for CacheConfigConversionError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Self::WrongMode(mode) => write!(f, "cache mode must be sentinel, got {mode:?}"),
+ Self::MissingField(field) => write!(f, "missing required cache field: {field}"),
+ }
+ }
+}
+
+impl std::error::Error for CacheConfigConversionError {}
+
+impl TryFrom<&CacheConfig> for SentinelConfig {
+ type Error = CacheConfigConversionError;
+
+ fn try_from(value: &CacheConfig) -> Result<Self, Self::Error> {
+ if value.mode() != CacheMode::Sentinel {
+ return Err(CacheConfigConversionError::WrongMode(value.mode()));
+ }
+
+ Ok(SentinelConfig {
+ service_name: value
+ .cache_service_name
+ .clone()
+ .ok_or(CacheConfigConversionError::MissingField("cache.service_name"))?,
+ redis_tls_mode_secure: value.cache_tls_mode_secure.unwrap_or(false),
+ redis_db: value.cache_database.map(i64::from),
+ redis_username: value
+ .cache_username
+ .clone()
+ .ok_or(CacheConfigConversionError::MissingField("cache.username"))?,
+ redis_password: value
+ .cache_password
+ .clone()
+ .ok_or(CacheConfigConversionError::MissingField("cache.password"))?,
+ redis_use_resp3: value.cache_use_resp3.unwrap_or(false),
+ })
+ }
+}
+
+impl TryFrom<CacheConfig> for SentinelConfig {
+ type Error = CacheConfigConversionError;
+
+ fn try_from(value: CacheConfig) -> Result<Self, Self::Error> {
+ SentinelConfig::try_from(&value)
+ }
+}
+
+impl TryFrom<&CacheConfig> for RedisVariant {
+ type Error = anyhow::Error;
+
+ fn try_from(value: &CacheConfig) -> Result<Self, Self::Error> {
+ let s = SentinelConfig::try_from(value)?;
+
+ match value.mode() {
+ CacheMode::Standalone => Ok(RedisVariant::NonClustered),
+ CacheMode::Clustered => Ok(RedisVariant::Clustered),
+ CacheMode::Sentinel => Ok(RedisVariant::Sentinel(s)),
+ }
+ }
+
+}
+
+impl TryFrom<CacheConfig> for sh_util::cache::RedisVariant {
+ type Error = anyhow::Error;
+
+ fn try_from(value: CacheConfig) -> Result<Self, Self::Error> {
+ RedisVariant::try_from(&value)
+ }
+}
diff --git a/crates/sellershut/src/config/mod.rs b/crates/sellershut/src/config/mod.rs
index 389b4bc..156ad0f 100644
--- a/crates/sellershut/src/config/mod.rs
+++ b/crates/sellershut/src/config/mod.rs
@@ -1,4 +1,5 @@
pub mod auth;
+pub mod cache;
pub mod cli;
pub mod database;
mod server;
@@ -25,6 +26,9 @@ pub struct Config {
/// Database configuration.
#[command(flatten)]
pub database: database::DatabaseConfig,
+ /// Cache configuration.
+ #[command(flatten)]
+ pub cache: cache::CacheConfig,
}
impl Config {
pub fn load(cli: Self) -> Result<Self> {
@@ -43,6 +47,7 @@ impl Config {
server: self.server.merge(higher.server),
auth: self.auth.merge(higher.auth),
database: self.database.merge(higher.database),
+ cache: self.cache.merge(higher.cache),
}
}
@@ -52,6 +57,7 @@ impl Config {
server: self.server.with_defaults(),
auth: self.auth.with_defaults(),
database: self.database.with_defaults(),
+ cache: self.cache.with_defaults(),
}
}
@@ -61,6 +67,7 @@ impl Config {
server: server::ServerConfig::defaults(),
auth: auth::OauthConfig::defaults(),
database: database::DatabaseConfig::defaults(),
+ cache: cache::CacheConfig::defaults(),
}
}
}
diff --git a/crates/sellershut/src/main.rs b/crates/sellershut/src/main.rs
index ebae4ed..a46cf3e 100644
--- a/crates/sellershut/src/main.rs
+++ b/crates/sellershut/src/main.rs
@@ -15,6 +15,7 @@ use api_core::{
health::BaseService,
};
use clap::Parser;
+use sh_util::cache::{RedisManager, RedisVariant};
use sqlx::PgPool;
use tokio::net::TcpListener;
use tracing::info;
@@ -40,8 +41,10 @@ async fn main() -> Result<()> {
)?;
let database = state::postgres(&cfg.database.connection_url(), 100).await?;
+ let variant = RedisVariant::try_from(cfg.cache.clone())?;
+ let cache = RedisManager::new(&cfg.cache.url()?, variant).await;
- let auth_clients = build_oauth_client(&cfg.auth, database)?;
+ let auth_clients = build_oauth_client(&cfg.auth, database, cache)?;
let state = AppState::builder()
.log_handle(log_handle)
@@ -67,6 +70,7 @@ async fn main() -> Result<()> {
fn build_oauth_client(
config: &OauthConfig,
database: PgPool,
+ cache: RedisManager,
) -> Result<HashMap<OauthProvider, Arc<dyn OauthDriver>>> {
let auth = config.to_owned();
let mut collection: HashMap<OauthProvider, Arc<dyn OauthDriver>> = HashMap::new();
@@ -77,7 +81,7 @@ fn build_oauth_client(
let c = AuthClientConfig::try_from(auth.discord.context("missing discord config")?)?;
let client = BasicClient::try_from(c)?;
- let auth_service = Arc::new(AuthServiceDiscord::new(database, client));
+ let auth_service = Arc::new(AuthServiceDiscord::new(database, client, cache));
collection.insert(OauthProvider::Discord, auth_service);
}
diff --git a/crates/sellershut/src/server/api/routes/auth/discord.rs b/crates/sellershut/src/server/api/routes/auth/discord.rs
index 163619b..0296e48 100644
--- a/crates/sellershut/src/server/api/routes/auth/discord.rs
+++ b/crates/sellershut/src/server/api/routes/auth/discord.rs
@@ -32,7 +32,7 @@ pub async fn discord_auth(State(state): State<AppState>) -> Result<impl IntoResp
.context("missing discord driver")?;
let headers = HeaderMap::new();
- Ok((headers, Redirect::to(redirect_url)))
+ Ok((headers, Redirect::to("/")))
// let (auth_url, csrf_token) = client
// .authorize_url(CsrfToken::new_random)
diff --git a/crates/sh-util/Cargo.toml b/crates/sh-util/Cargo.toml
new file mode 100644
index 0000000..12bf7a4
--- /dev/null
+++ b/crates/sh-util/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "sh-util"
+version = "0.0.0"
+edition = "2024"
+license.workspace = true
+readme.workspace = true
+documentation.workspace = true
+homepage.workspace = true
+
+[dependencies]
+bb8 = { version = "0.9.1", optional = true }
+futures-util = { workspace = true, optional = true }
+redis = { workspace = true, optional = true }
+serde = { workspace = true, features = ["derive"] }
+
+[features]
+cache = [
+ "dep:redis",
+ "redis/cluster-async",
+ "redis/connection-manager",
+ "redis/tokio-comp",
+ "redis/sentinel",
+ "redis/bb8",
+ "dep:bb8",
+ "dep:futures-util",
+]
diff --git a/crates/sh-util/src/cache/cluster.rs b/crates/sh-util/src/cache/cluster.rs
new file mode 100644
index 0000000..de13629
--- /dev/null
+++ b/crates/sh-util/src/cache/cluster.rs
@@ -0,0 +1,56 @@
+use redis::{
+ ErrorKind, FromRedisValue, IntoConnectionInfo, RedisError,
+ cluster::{ClusterClient, ClusterClientBuilder},
+ cluster_routing::{MultipleNodeRoutingInfo, ResponsePolicy, RoutingInfo},
+};
+
+/// ConnectionManager that implements `bb8::ManageConnection` and supports
+/// asynchronous clustered connections via `redis_cluster_async::Connection`
+#[derive(Clone)]
+pub struct RedisClusterConnectionManager {
+ client: ClusterClient,
+}
+
+impl RedisClusterConnectionManager {
+ pub fn new<T: IntoConnectionInfo>(
+ info: T,
+ ) -> Result<RedisClusterConnectionManager, RedisError> {
+ Ok(RedisClusterConnectionManager {
+ client: ClusterClientBuilder::new(vec![info]).retries(0).build()?,
+ })
+ }
+}
+
+impl bb8::ManageConnection for RedisClusterConnectionManager {
+ type Connection = redis::cluster_async::ClusterConnection;
+ type Error = RedisError;
+
+ async fn connect(&self) -> Result<Self::Connection, Self::Error> {
+ self.client.get_async_connection().await
+ }
+
+ async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
+ let pong = conn
+ .route_command(
+ redis::cmd("PING"),
+ RoutingInfo::MultiNode((
+ MultipleNodeRoutingInfo::AllMasters,
+ Some(ResponsePolicy::OneSucceeded),
+ )),
+ )
+ .await
+ .and_then(|v| Ok(String::from_redis_value(v)?))?;
+ match pong.as_str() {
+ "PONG" => Ok(()),
+ _ => Err((
+ ErrorKind::Server(redis::ServerErrorKind::ResponseError),
+ "ping request",
+ )
+ .into()),
+ }
+ }
+
+ fn has_broken(&self, _: &mut Self::Connection) -> bool {
+ false
+ }
+}
diff --git a/crates/sh-util/src/cache/mod.rs b/crates/sh-util/src/cache/mod.rs
new file mode 100644
index 0000000..67a5121
--- /dev/null
+++ b/crates/sh-util/src/cache/mod.rs
@@ -0,0 +1,176 @@
+mod cluster;
+mod sentinel;
+pub use sentinel::SentinelConfig;
+
+use std::{sync::Arc, time::Duration};
+
+use bb8::RunError;
+// use bb8_redis::RedisConnectionManager;
+use futures_util::lock::Mutex;
+use redis::{
+ AsyncConnectionConfig, ProtocolVersion, RedisConnectionInfo, RedisError, TlsMode,
+ aio::ConnectionManagerConfig, sentinel::SentinelNodeConnectionInfo,
+};
+
+pub use self::cluster::RedisClusterConnectionManager;
+
+pub const REDIS_CONN_TIMEOUT: Duration = Duration::from_secs(2);
+
+pub enum RedisVariant {
+ Clustered,
+ NonClustered,
+ Sentinel(sentinel::SentinelConfig),
+}
+
+#[derive(Clone)]
+pub enum RedisManager {
+ Clustered(redis::cluster_async::ClusterConnection),
+ NonClustered(redis::aio::ConnectionManager),
+ Sentinel(Arc<Mutex<redis::sentinel::SentinelClient>>),
+}
+
+impl RedisManager {
+ pub async fn new(dsn: &str, variant: RedisVariant) -> Self {
+ match variant {
+ RedisVariant::Clustered => {
+ let cli = redis::cluster::ClusterClient::builder(vec![dsn])
+ .retries(1)
+ .connection_timeout(REDIS_CONN_TIMEOUT)
+ .build()
+ .expect("Error initializing redis-unpooled cluster client");
+ let con = cli
+ .get_async_connection()
+ .await
+ .expect("Failed to get redis-cluster-unpooled connection");
+ RedisManager::Clustered(con)
+ }
+ RedisVariant::NonClustered => {
+ let cli =
+ redis::Client::open(dsn).expect("Error initializing redis unpooled client");
+ let con = redis::aio::ConnectionManager::new_with_config(
+ cli,
+ ConnectionManagerConfig::new()
+ .set_number_of_retries(1)
+ .set_connection_timeout(Some(REDIS_CONN_TIMEOUT)),
+ )
+ .await
+ .expect("Failed to get redis-unpooled connection manager");
+ RedisManager::NonClustered(con)
+ }
+ RedisVariant::Sentinel(cfg) => {
+ let tls_mode = if cfg.redis_tls_mode_secure {
+ TlsMode::Secure
+ } else {
+ TlsMode::Insecure
+ };
+ let protocol = if cfg.redis_use_resp3 {
+ ProtocolVersion::RESP3
+ } else {
+ ProtocolVersion::default()
+ };
+
+ let redis_connection_info = RedisConnectionInfo::default()
+ .set_db(cfg.redis_db.unwrap_or(0))
+ .set_protocol(protocol)
+ .set_username(cfg.redis_username.clone())
+ .set_password(cfg.redis_password.clone());
+ let sentinel = SentinelNodeConnectionInfo::default()
+ .set_redis_connection_info(redis_connection_info)
+ .set_tls_mode(tls_mode);
+
+ let cli = redis::sentinel::SentinelClient::build(
+ vec![dsn],
+ cfg.service_name.clone(),
+ Some(sentinel),
+ redis::sentinel::SentinelServerType::Master,
+ )
+ .expect("Failed to build sentinel client");
+
+ RedisManager::Sentinel(Arc::new(Mutex::new(cli)))
+ }
+ }
+ }
+
+ pub async fn get(&self) -> Result<RedisConnection, RunError<RedisError>> {
+ match self {
+ Self::Clustered(conn) => Ok(RedisConnection::Clustered(conn.clone())),
+ Self::NonClustered(conn) => Ok(RedisConnection::NonClustere(conn.clone())),
+ Self::Sentinel(conn) => {
+ let mut conn = conn.lock().await;
+ let con = conn
+ .get_async_connection_with_config(
+ &AsyncConnectionConfig::new()
+ .set_response_timeout(Some(REDIS_CONN_TIMEOUT)),
+ )
+ .await?;
+ Ok(RedisConnection::Sentinel(con))
+ }
+ }
+ }
+}
+
+pub enum RedisConnection {
+ Clustered(redis::cluster_async::ClusterConnection),
+ NonClustere(redis::aio::ConnectionManager),
+ Sentinel(redis::aio::MultiplexedConnection),
+}
+
+impl redis::aio::ConnectionLike for RedisConnection {
+ fn req_packed_command<'a>(
+ &'a mut self,
+ cmd: &'a redis::Cmd,
+ ) -> redis::RedisFuture<'a, redis::Value> {
+ match self {
+ RedisConnection::Clustered(conn) => conn.req_packed_command(cmd),
+ RedisConnection::NonClustere(conn) => conn.req_packed_command(cmd),
+ RedisConnection::Sentinel(conn) => conn.req_packed_command(cmd),
+ }
+ }
+
+ fn req_packed_commands<'a>(
+ &'a mut self,
+ cmd: &'a redis::Pipeline,
+ offset: usize,
+ count: usize,
+ ) -> redis::RedisFuture<'a, Vec<redis::Value>> {
+ match self {
+ RedisConnection::Clustered(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::NonClustere(conn) => conn.req_packed_commands(cmd, offset, count),
+ RedisConnection::Sentinel(conn) => conn.req_packed_commands(cmd, offset, count),
+ }
+ }
+
+ fn get_db(&self) -> i64 {
+ match self {
+ RedisConnection::Clustered(conn) => conn.get_db(),
+ RedisConnection::NonClustere(conn) => conn.get_db(),
+ RedisConnection::Sentinel(conn) => conn.get_db(),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use redis::AsyncCommands;
+
+ use super::RedisManager;
+
+ // Ensure basic set/get works -- should test sharding as well:
+ #[tokio::test]
+ // run with `cargo test -- --ignored redis` only when redis is up and configured
+ #[ignore]
+ async fn test_set_read_random_keys() {
+ let mgr = RedisManager::new(
+ "redis://127.0.0.1:6379/0",
+ super::RedisVariant::NonClustered,
+ )
+ .await;
+ let mut conn = mgr.get().await.unwrap();
+
+ for (val, key) in "abcdefghijklmnopqrstuvwxyz".chars().enumerate() {
+ let key = key.to_string();
+ let _: () = conn.set(key.clone(), val).await.unwrap();
+ assert_eq!(conn.get::<_, usize>(&key).await.unwrap(), val);
+ }
+ }
+}
diff --git a/crates/sh-util/src/cache/sentinel.rs b/crates/sh-util/src/cache/sentinel.rs
new file mode 100644
index 0000000..e52b043
--- /dev/null
+++ b/crates/sh-util/src/cache/sentinel.rs
@@ -0,0 +1,66 @@
+use futures_util::lock::Mutex;
+use redis::{
+ ErrorKind, IntoConnectionInfo, RedisError,
+ sentinel::{SentinelClient, SentinelNodeConnectionInfo, SentinelServerType},
+};
+use serde::Deserialize;
+
+struct LockedSentinelClient(pub(crate) Mutex<SentinelClient>);
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct SentinelConfig {
+ pub service_name: String,
+ pub redis_tls_mode_secure: bool,
+ pub redis_db: Option<i64>,
+ pub redis_username: String,
+ pub redis_password: String,
+ pub redis_use_resp3: bool,
+}
+
+/// ConnectionManager that implements `bb8::ManageConnection` and supports
+/// asynchronous Sentinel connections via `redis::sentinel::SentinelClient`
+pub struct RedisSentinelConnectionManager {
+ client: LockedSentinelClient,
+}
+
+impl RedisSentinelConnectionManager {
+ pub fn new<T: IntoConnectionInfo>(
+ info: Vec<T>,
+ service_name: String,
+ node_connection_info: Option<SentinelNodeConnectionInfo>,
+ ) -> Result<RedisSentinelConnectionManager, RedisError> {
+ Ok(RedisSentinelConnectionManager {
+ client: LockedSentinelClient(Mutex::new(SentinelClient::build(
+ info,
+ service_name,
+ node_connection_info,
+ SentinelServerType::Master,
+ )?)),
+ })
+ }
+}
+
+impl bb8::ManageConnection for RedisSentinelConnectionManager {
+ type Connection = redis::aio::MultiplexedConnection;
+ type Error = RedisError;
+
+ async fn connect(&self) -> Result<Self::Connection, Self::Error> {
+ self.client.0.lock().await.get_async_connection().await
+ }
+
+ async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
+ let pong: String = redis::cmd("PING").query_async(conn).await?;
+ match pong.as_str() {
+ "PONG" => Ok(()),
+ _ => Err((
+ ErrorKind::Server(redis::ServerErrorKind::ResponseError),
+ "ping request",
+ )
+ .into()),
+ }
+ }
+
+ fn has_broken(&self, _: &mut Self::Connection) -> bool {
+ false
+ }
+}
diff --git a/crates/sh-util/src/lib.rs b/crates/sh-util/src/lib.rs
new file mode 100644
index 0000000..5501a81
--- /dev/null
+++ b/crates/sh-util/src/lib.rs
@@ -0,0 +1,2 @@
+#[cfg(feature = "cache")]
+pub mod cache;