diff options
| author | rtkay123 <dev@kanjala.com> | 2026-04-02 14:27:45 +0200 |
|---|---|---|
| committer | rtkay123 <dev@kanjala.com> | 2026-04-02 14:27:45 +0200 |
| commit | 8bc645b006080b860e40c0ff55b485125dc6157d (patch) | |
| tree | 8344e498fa50dc237f0f1fdbde2e38305da2fdfe | |
| parent | daeb5311840680599a0ce6e49d181b9289010f68 (diff) | |
| download | warden-master.tar.bz2 warden-master.zip | |
| -rw-r--r-- | Cargo.lock | 3 | ||||
| -rw-r--r-- | lib/api-config/src/lib.rs | 2 | ||||
| -rw-r--r-- | lib/api-config/src/schema/implementation.rs | 4 | ||||
| -rw-r--r-- | lib/api-config/src/schema/list_schemas.rs | 2 | ||||
| -rw-r--r-- | lib/api-config/src/schema/mod.rs | 1 | ||||
| -rw-r--r-- | lib/warden-core/src/config/cli/database.rs | 192 | ||||
| -rw-r--r-- | lib/warden-core/src/config/cli/mod.rs | 10 | ||||
| -rw-r--r-- | lib/warden-core/src/config/mod.rs | 231 | ||||
| -rw-r--r-- | warden/Cargo.toml | 5 | ||||
| -rw-r--r-- | warden/src/main.rs | 4 | ||||
| -rw-r--r-- | warden/src/server/api/version.rs | 1 | ||||
| -rw-r--r-- | warden/src/server/mod.rs | 41 | ||||
| -rw-r--r-- | warden/src/server/routes/config/logs.rs | 59 | ||||
| -rw-r--r-- | warden/src/server/routes/config/schema/create.rs | 149 | ||||
| -rw-r--r-- | warden/src/server/routes/config/schema/delete.rs | 37 | ||||
| -rw-r--r-- | warden/src/server/routes/config/schema/mod.rs | 258 | ||||
| -rw-r--r-- | warden/src/server/routes/config/schema/read.rs | 102 | ||||
| -rw-r--r-- | warden/src/server/routes/config/schema/update.rs | 172 | ||||
| -rw-r--r-- | warden/src/server/routes/mod.rs | 36 | ||||
| -rw-r--r-- | warden/src/state/mod.rs | 6 |
20 files changed, 1171 insertions, 144 deletions
@@ -3073,9 +3073,11 @@ version = "0.1.0" dependencies = [ "anyhow", "api-config", + "async-trait", "axum", "base64", "clap", + "http-body-util", "jsonschema", "secrecy", "serde", @@ -3084,6 +3086,7 @@ dependencies = [ "time", "tokio", "toml", + "tower", "tower-http", "tracing", "tracing-appender", diff --git a/lib/api-config/src/lib.rs b/lib/api-config/src/lib.rs index c2c6ccb..cdd48d1 100644 --- a/lib/api-config/src/lib.rs +++ b/lib/api-config/src/lib.rs @@ -1,5 +1,5 @@ //! Configuration -#![warn(missing_docs, missing_debug_implementations)] +#![warn(missing_debug_implementations)] mod error; /// Schema configuration implementation pub mod schema; diff --git a/lib/api-config/src/schema/implementation.rs b/lib/api-config/src/schema/implementation.rs index ca0757f..ed43c31 100644 --- a/lib/api-config/src/schema/implementation.rs +++ b/lib/api-config/src/schema/implementation.rs @@ -2,9 +2,7 @@ use async_trait::async_trait; use tracing::debug; use warden_core::pagination::{Connection, PaginationArgs}; -use crate::schema::{ - self, SchemaDriver, SchemaService, TransactionSchema, pagination::DecodedSchemaPagination, -}; +use crate::schema::{self, SchemaDriver, SchemaService, TransactionSchema}; #[async_trait] impl SchemaDriver for SchemaService { diff --git a/lib/api-config/src/schema/list_schemas.rs b/lib/api-config/src/schema/list_schemas.rs index 0b539ed..5e8c0aa 100644 --- a/lib/api-config/src/schema/list_schemas.rs +++ b/lib/api-config/src/schema/list_schemas.rs @@ -145,6 +145,7 @@ mod tests { migrator = "crate::MIGRATOR", fixtures(path = "../../tests/fixtures", scripts("schema")) )] + #[ignore = "requires live db"] async fn test_forward_pagination(pool: PgPool) -> anyhow::Result<()> { let get_count = 2; @@ -186,6 +187,7 @@ mod tests { migrator = "crate::MIGRATOR", fixtures(path = "../../tests/fixtures", scripts("schema")) )] + #[ignore = "requires live db"] async fn test_backward_pagination(pool: PgPool) -> anyhow::Result<()> { let get_count = 2; diff --git a/lib/api-config/src/schema/mod.rs b/lib/api-config/src/schema/mod.rs index 54bc015..38b93a8 100644 --- a/lib/api-config/src/schema/mod.rs +++ b/lib/api-config/src/schema/mod.rs @@ -123,6 +123,7 @@ mod tests { migrator = "crate::MIGRATOR", fixtures(path = "../../tests/fixtures", scripts("schema")) )] + #[ignore = "requires live db"] async fn schema(pool: PgPool) -> anyhow::Result<()> { let driver = SchemaService::new(pool); diff --git a/lib/warden-core/src/config/cli/database.rs b/lib/warden-core/src/config/cli/database.rs index 70bf600..90032a2 100644 --- a/lib/warden-core/src/config/cli/database.rs +++ b/lib/warden-core/src/config/cli/database.rs @@ -22,7 +22,7 @@ pub struct Database { pub database_password: Option<String>, /// Database host - #[arg(long, env = "DB_HOST", default_value = "localhost")] + #[arg(long, env = "DB_HOST")] #[serde(rename = "host")] pub database_host: Option<String>, @@ -37,7 +37,7 @@ pub struct Database { pub database_name: Option<String>, /// Database pool size - #[arg(long, env = "DATABASE_POOL_SIZE", default_value = "10")] + #[arg(long, env = "DATABASE_POOL_SIZE")] #[serde(rename = "pool-size")] pub database_pool_size: Option<u32>, } @@ -58,56 +58,67 @@ impl Default for Database { impl Database { pub fn merge(cli: &Self, file: &Self) -> Result<Self, WardenError> { - let url = cli.database_url.clone().or(file.database_url.clone()); - let pool_size = cli .database_pool_size .or(file.database_pool_size) .unwrap_or(10); - let final_url = match url { - Some(u) => u, - None => { - let host = cli - .database_host - .clone() - .or(file.database_host.clone()) - .unwrap_or_else(|| "localhost".to_string()); - - let mut u = Url::parse(&format!("postgresql://{}", host))?; - - let user = cli - .database_username - .as_ref() - .or(file.database_username.as_ref()); - let pass = cli - .database_password - .as_ref() - .or(file.database_password.as_ref()); - let port = cli.database_port.or(file.database_port); - let name = cli.database_name.as_ref().or(file.database_name.as_ref()); - - if let Some(user) = user { - u.set_username(user).ok(); - } - if let Some(pass) = pass { - u.set_password(Some(pass)).ok(); - } - if let Some(port) = port { - u.set_port(Some(port)).ok(); - } - if let Some(name) = name { - u.set_path(name); - } - - u - } - }; + if let Some(url) = cli + .database_url + .clone() + .or_else(|| file.database_url.clone()) + { + return Ok(Self { + database_url: Some(url), + database_pool_size: Some(pool_size), + ..Default::default() + }); + } + + let host = cli + .database_host + .as_deref() + .or(file.database_host.as_deref()) + .unwrap_or("localhost"); + + let mut u = Url::parse(&format!("postgresql://{}", host))?; + + let user = cli + .database_username + .as_deref() + .or(file.database_username.as_deref()); + let pass = cli + .database_password + .as_deref() + .or(file.database_password.as_deref()); + let port = cli.database_port.or(file.database_port); + let name = cli + .database_name + .as_deref() + .or(file.database_name.as_deref()); + + if let Some(user) = user { + u.set_username(user).ok(); + } + if let Some(pass) = pass { + u.set_password(Some(pass)).ok(); + } + if let Some(port) = port { + u.set_port(Some(port)).ok(); + } + if let Some(name) = name { + u.set_path(name); + } Ok(Self { - database_url: Some(final_url), + database_url: Some(u), database_pool_size: Some(pool_size), - ..cli.clone() + // Carry over the other fields for record-keeping + database_host: Some(host.to_string()), + database_username: user.map(String::from), + database_password: pass.map(String::from), + database_port: port, + database_name: name.map(String::from), }) } @@ -134,3 +145,96 @@ impl Database { Ok(url) } } + +#[cfg(test)] +mod tests { + use super::*; + use url::Url; + + /// Helper to create a "naked" Database struct with all Nones + /// Useful for testing merge logic without Default values interfering + fn empty_db() -> Database { + Database { + database_url: None, + database_username: None, + database_password: None, + database_host: None, + database_port: None, + database_name: None, + database_pool_size: None, + } + } + + #[test] + fn test_get_url_from_components() { + let db = Database { + database_host: Some("127.0.0.1".to_string()), + database_username: Some("admin".to_string()), + database_password: Some("secret".to_string()), + database_port: Some(5432), + database_name: Some("testdb".to_string()), + ..empty_db() + }; + + let url = db.get_url().expect("Should parse URL"); + // Note: get_url uses "postgres://" scheme + assert_eq!( + url.as_str(), + "postgres://admin:secret@127.0.0.1:5432/testdb" + ); + } + + #[test] + fn test_merge_cli_overrides_file() { + let mut file_config = empty_db(); + file_config.database_host = Some("file-host".to_string()); + file_config.database_port = Some(1111); + + let mut cli_config = empty_db(); + cli_config.database_host = Some("cli-host".to_string()); + // database_port is None in CLI + + let merged = Database::merge(&cli_config, &file_config).expect("Merge failed"); + + let url = merged.database_url.unwrap(); + // CLI host should win + assert_eq!(url.host_str(), Some("cli-host")); + // File port should win because CLI was None + assert_eq!(url.port(), Some(1111)); + } + + #[test] + fn test_merge_url_override_wins_all() { + let mut file_config = empty_db(); + file_config.database_host = Some("local-host".to_string()); + + let mut cli_config = empty_db(); + let expected_url = "postgresql://remote-host:9999/prod"; + cli_config.database_url = Some(Url::parse(expected_url).unwrap()); + + let merged = Database::merge(&cli_config, &file_config).expect("Merge failed"); + + assert_eq!(merged.database_url.unwrap().as_str(), expected_url); + } + + #[test] + fn test_merge_pool_size_logic() { + let mut file_config = empty_db(); + file_config.database_pool_size = Some(50); + + let cli_config = empty_db(); // pool_size is None + + let merged = Database::merge(&cli_config, &file_config).expect("Merge failed"); + + // Should take file value if CLI is None + assert_eq!(merged.database_pool_size, Some(50)); + } + + #[test] + fn test_default_trait_implementation() { + let db = Database::default(); + assert_eq!(db.database_port, Some(5432)); + assert_eq!(db.database_username, Some("postgres".to_string())); + assert_eq!(db.database_host, Some("localhost".to_string())); + } +} diff --git a/lib/warden-core/src/config/cli/mod.rs b/lib/warden-core/src/config/cli/mod.rs index e0c5450..f4ffe05 100644 --- a/lib/warden-core/src/config/cli/mod.rs +++ b/lib/warden-core/src/config/cli/mod.rs @@ -57,6 +57,15 @@ pub struct Server { default_value = "5" )] pub timeout_secs: Option<u64>, + /// Pagination limit + #[arg( + long, + value_name = "PAGINATION_LIMIT", + env = "PAGINATION_LIMIT", + default_value = "50" + )] + #[arg(value_parser = clap::value_parser!(i64).range(1..))] + pub pagination_limit: Option<i64>, } impl Default for Server { @@ -70,6 +79,7 @@ impl Default for Server { )), log_dir: Some(std::env::temp_dir()), timeout_secs: Some(5), + pagination_limit: Some(50), } } } diff --git a/lib/warden-core/src/config/mod.rs b/lib/warden-core/src/config/mod.rs index 9d0c937..8be7205 100644 --- a/lib/warden-core/src/config/mod.rs +++ b/lib/warden-core/src/config/mod.rs @@ -13,6 +13,16 @@ use crate::WardenError; use crate::config::cli::CliEnvironment; use crate::config::cli::database::Database; +macro_rules! pick { + ($cli:expr, $file:expr, $name:expr, $missing:expr) => {{ + let val = $cli.clone().or($file.clone()); + if val.is_none() { + $missing.push($name); + } + val + }}; +} + #[derive(Deserialize, Default, Debug, ValueEnum, Clone, Copy)] #[serde(rename_all = "lowercase")] pub enum Environment { @@ -43,67 +53,65 @@ pub struct Server { pub log_level: EnvFilter, pub log_dir: PathBuf, pub timeout_secs: u64, + pub pagination_limit: i64, } impl Server { - fn merge(cli: &Cli, file: &Cli, missing: &mut Vec<&str>) -> Result<Self, WardenError> { - let port = cli.server.port.or(file.server.port); - - if port.is_none() { - missing.push("server.port"); - } - - let timeout = cli.server.timeout_secs.or(file.server.timeout_secs); - - if timeout.is_none() { - missing.push("server.timeout"); - } - - let log_dir = cli.server.log_dir.clone().or(file.server.log_dir.clone()); - - if log_dir.is_none() { - missing.push("server.log_dir"); - } - - let log_level = cli - .server - .log_level - .as_ref() - .or(file.server.log_level.as_ref()) - .map(ToOwned::to_owned); - - if log_level.is_none() { - missing.push("server.log_level"); - } - - let environment = cli.server.environment.or(file.server.environment); - - if environment.is_none() { - missing.push("server.environment"); - } + pub fn merge(cli: &Cli, file: &Cli, missing: &mut Vec<&str>) -> Result<Self, WardenError> { + let port = pick!(cli.server.port, file.server.port, "server.port", missing); + let timeout = pick!( + cli.server.timeout_secs, + file.server.timeout_secs, + "server.timeout", + missing + ); + let log_dir = pick!( + cli.server.log_dir.clone(), + file.server.log_dir.clone(), + "server.log_dir", + missing + ); + let env = pick!( + cli.server.environment, + file.server.environment, + "server.environment", + missing + ); + let raw_log_level = pick!( + cli.server.log_level.clone(), + file.server.log_level.clone(), + "server.log_level", + missing + ); + let pagination_limit = pick!( + cli.server.pagination_limit.clone(), + file.server.pagination_limit.clone(), + "server.pagination_limit", + missing + ); if !missing.is_empty() { - let err = missing + let err_msg = missing .iter() .map(|f| format!(" - {}", f)) .collect::<Vec<_>>() .join("\n"); - return Err(WardenError::Config(err)); + return Err(WardenError::Config(format!( + "Missing required fields:\n{}", + err_msg + ))); } - let log_level = - tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { - // axum logs rejections from built-in extractors with the `axum::rejection` - // target, at `TRACE` level. `axum::rejection=trace` enables showing those events - log_level.unwrap().into() - }); + let log_level = tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| raw_log_level.unwrap().into()); Ok(Self { port: port.unwrap(), - environment: environment.unwrap().into(), + environment: env.unwrap().into(), log_dir: log_dir.unwrap(), timeout_secs: timeout.unwrap(), log_level, + pagination_limit: pagination_limit.unwrap(), }) } } @@ -118,3 +126,138 @@ impl Configuration { Ok(Self { server, database }) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_merge_config() { + let mut cli = Cli::default(); + cli.server.port = Some(8080); + let timeout = 30; + cli.server.timeout_secs = Some(timeout); + + let file = Cli { + server: cli::Server { + environment: Some(CliEnvironment::Dev), + log_level: Some("info".into()), + log_dir: Some(PathBuf::from("/tmp")), + timeout_secs: Some(timeout), + ..Default::default() + }, + ..Default::default() + }; + + let result = Configuration::merge(&cli, &file); + assert!( + result.is_ok(), + "Merge should succeed when all fields are covered" + ); + + let server = result.unwrap(); + assert_eq!(server.server.port, 8080); + } + + #[test] + fn test_merge_all_fields_present_success() { + let mut cli = Cli::default(); + cli.server.port = Some(8080); + let timeout = 30; + cli.server.timeout_secs = Some(timeout); + + let file = Cli { + server: cli::Server { + environment: Some(CliEnvironment::Dev), + log_level: Some("info".into()), + log_dir: Some(PathBuf::from("/tmp")), + timeout_secs: Some(timeout), + ..Default::default() + }, + ..Default::default() + }; + let mut missing = vec![]; + + let result = Server::merge(&cli, &file, &mut missing); + assert!( + result.is_ok(), + "Merge should succeed when all fields are covered" + ); + + let server = result.unwrap(); + assert_eq!(server.port, 8080); + assert_eq!(server.timeout_secs, timeout); + } + + #[test] + fn test_merge_error_accumulation() { + let cli = Cli { + server: cli::Server { + port: None, + environment: None, + log_level: None, + log_dir: None, + timeout_secs: None, + pagination_limit: None, + }, + ..Default::default() + }; + + let mut missing = vec![]; + + let result = Server::merge(&cli, &cli, &mut missing); + dbg!(&result); + + match result { + Err(WardenError::Config(msg)) => { + assert!(msg.contains("server.port")); + assert!(msg.contains("server.environment")); + } + _ => panic!("Expected a Config error with multiple missing fields"), + } + } + + #[test] + fn test_cli_priority_over_file() { + let mut cli = Cli::default(); + cli.server.port = Some(9999); + + let file = Cli { + server: cli::Server { + port: Some(1111), // This should be ignored + environment: Some(CliEnvironment::Prod), + log_level: Some("error".into()), + log_dir: Some(PathBuf::from("/var/log")), + timeout_secs: Some(60), + ..Default::default() + }, + ..Default::default() + }; + + let mut missing = vec![]; + + let server = Server::merge(&cli, &file, &mut missing).expect("Merge failed"); + assert_eq!(server.port, 9999, "CLI port must override File port"); + } + + #[test] + fn test_env_filter_from_raw_string() { + let log_level = "warn"; + let cli = Cli { + server: cli::Server { + port: Some(80), + environment: Some(CliEnvironment::Production), + log_level: Some(log_level.to_string()), + log_dir: Some(PathBuf::from(".")), + timeout_secs: Some(5), + ..Default::default() + }, + ..Default::default() + }; + let file = Cli::default(); + let mut missing = vec![]; + + let server = Server::merge(&cli, &file, &mut missing).unwrap(); + assert_eq!(&server.log_level.to_string(), log_level) + } +} diff --git a/warden/Cargo.toml b/warden/Cargo.toml index 47cd8b6..f9e85a2 100644 --- a/warden/Cargo.toml +++ b/warden/Cargo.toml @@ -43,6 +43,11 @@ features = ["json", "runtime-tokio-rustls", "time"] version = "1.50.0" features = ["macros", "rt", "rt-multi-thread"] +[dev-dependencies] +async-trait.workspace = true +http-body-util = "0.1.3" +tower = { version = "0.5.3", features = ["util"] } + [features] rapidoc = ["dep:utoipa-rapidoc", "utoipa-rapidoc/axum"] scalar = ["dep:utoipa-scalar", "utoipa-scalar/axum"] diff --git a/warden/src/main.rs b/warden/src/main.rs index 59068f5..d9fa004 100644 --- a/warden/src/main.rs +++ b/warden/src/main.rs @@ -40,8 +40,8 @@ async fn main() -> anyhow::Result<()> { let schema = SchemaService::new(state::database::connect(&config.database).await?); let schema = Arc::new(schema); - let state = state::AppState::new(log_handle, schema).await?; - let app = server::router(Arc::new(state), &config).await; + let state = state::AppState::new(log_handle, schema, config.server.pagination_limit).await?; + let app = server::router(Arc::new(state), config.server.timeout_secs).await; let addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, config.server.port)); info!(port = addr.port(), "starting server"); diff --git a/warden/src/server/api/version.rs b/warden/src/server/api/version.rs index f8d856a..1639cf3 100644 --- a/warden/src/server/api/version.rs +++ b/warden/src/server/api/version.rs @@ -11,6 +11,7 @@ use utoipa::{IntoParams, ToSchema}; #[derive(Deserialize, Debug, IntoParams)] #[serde(rename_all = "camelCase")] +#[allow(unused)] pub struct VersionPath { pub api_version: Version, } diff --git a/warden/src/server/mod.rs b/warden/src/server/mod.rs index 92ac12d..9a311b7 100644 --- a/warden/src/server/mod.rs +++ b/warden/src/server/mod.rs @@ -15,7 +15,6 @@ use utoipa::OpenApi; use utoipa_axum::router::OpenApiRouter; use crate::{ - config::Configuration, server::{ middleware::request_id::{REQUEST_ID_HEADER, middleware_request_id}, routes::{ @@ -32,7 +31,7 @@ pub mod error; pub mod middleware; pub mod routes; -pub async fn router(state: Arc<AppState>, config: &Configuration) -> Router<()> { +pub async fn router(state: Arc<AppState>, timeout_secs: u64) -> Router<()> { let mut doc = ApiDoc::openapi(); doc.merge(ConfigDoc::openapi()); @@ -92,7 +91,7 @@ pub async fn router(state: Arc<AppState>, config: &Configuration) -> Router<()> ) .layer(TimeoutLayer::with_status_code( StatusCode::REQUEST_TIMEOUT, - Duration::from_secs(config.server.timeout_secs), + Duration::from_secs(timeout_secs), )) .layer(PropagateRequestIdLayer::new(HeaderName::from_static( REQUEST_ID_HEADER, @@ -105,3 +104,39 @@ pub async fn router(state: Arc<AppState>, config: &Configuration) -> Router<()> .allow_methods(cors::Any), ) } + +#[cfg(test)] +mod tests { + use std::sync::{Arc, OnceLock}; + + use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, reload}; + + use crate::{ + logging::LogHandle, server::routes::config::schema::tests::MockSchemaDriver, + state::AppState, + }; + + static TEST_LOG_DATA: OnceLock<LogHandle> = OnceLock::new(); + + pub async fn test_app() -> axum::Router { + let log_handle = TEST_LOG_DATA + .get_or_init(|| { + let filter = EnvFilter::new("warn"); + let (layer, handle) = reload::Layer::new(filter); + + let subscriber = Registry::default().with(layer); + + let _ = tracing::subscriber::set_global_default(subscriber); + + handle + }) + .clone(); + let state = Arc::new(MockSchemaDriver::new()); + let state = AppState { + schema_service: state, + log_handle, + pagination_limit: 50, + }; + super::router(state.into(), 5).await + } +} diff --git a/warden/src/server/routes/config/logs.rs b/warden/src/server/routes/config/logs.rs index 0f2cbf8..e0e379e 100644 --- a/warden/src/server/routes/config/logs.rs +++ b/warden/src/server/routes/config/logs.rs @@ -50,9 +50,66 @@ pub async fn reload( if let Ok(value) = body.log_level.parse::<tracing_subscriber::EnvFilter>() { match state.log_handle.reload(value) { Ok(_) => StatusCode::OK, - Err(_e) => StatusCode::INTERNAL_SERVER_ERROR, + Err(e) => { + println!("{e:?}"); + StatusCode::INTERNAL_SERVER_ERROR + } } } else { StatusCode::BAD_REQUEST } } + +#[cfg(test)] +mod tests { + use axum::{ + Router, + body::Body, + http::{Request, StatusCode, header}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + async fn check( + app: Router, + method: &str, + body: String, + expected_result: StatusCode, + ) -> Result<()> { + let response = app + .oneshot( + Request::builder() + .method(method) + .header(header::CONTENT_TYPE, "application/json") + .uri("/api/logging") + .body(Body::from(body))?, + ) + .await?; + let actual_result = response.status(); + assert_eq!(expected_result, actual_result); + Ok(()) + } + + #[tokio::test] + async fn log_update() -> Result<()> { + let app = server::tests::test_app().await; + + let info = serde_json::json!({ + "logLevel": "info", + }); + + check( + app.clone(), + "GET", + info.to_string(), + StatusCode::METHOD_NOT_ALLOWED, + ) + .await?; + + check(app.clone(), "PATCH", info.to_string(), StatusCode::OK).await?; + Ok(()) + } +} diff --git a/warden/src/server/routes/config/schema/create.rs b/warden/src/server/routes/config/schema/create.rs index 767ff3f..e6de992 100644 --- a/warden/src/server/routes/config/schema/create.rs +++ b/warden/src/server/routes/config/schema/create.rs @@ -97,17 +97,144 @@ pub async fn create_schema( .schema_service .create_schema(&body.schema_type, &body.schema_version, &body.schema) .await - .map_err(|e| match e { - api_config::ConfigurationError::Database(ref error) => match error { - sqlx::Error::Database(db_err) if db_err.code() == Some("23505".into()) => { - AppError::new( - StatusCode::CONFLICT, - anyhow::anyhow!("Transaction schema already exists"), - ) - } - _ => e.into(), - }, - _ => e.into(), + .map_err(|e| { + if let api_config::ConfigurationError::Database(sqlx::Error::Database(ref error)) = e + && error.is_unique_violation() + { + AppError::new( + StatusCode::CONFLICT, + anyhow::anyhow!("Transaction schema already exists"), + ) + } else { + e.into() + } })?; Ok((StatusCode::CREATED, Json(result))) } + +#[cfg(test)] +mod tests { + use axum::{ + Router, + body::Body, + http::{Request, StatusCode, header}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + async fn check( + app: Router, + method: &str, + body: String, + expected_result: StatusCode, + ) -> Result<()> { + let response = app + .oneshot( + Request::builder() + .method(method) + .header(header::CONTENT_TYPE, "application/json") + .uri("/api/v0/config/schema") + .body(Body::from(body))?, + ) + .await?; + let actual_result = response.status(); + assert_eq!(expected_result, actual_result); + Ok(()) + } + + #[tokio::test] + async fn save_schema() -> Result<()> { + let app = server::tests::test_app().await; + + let schema = serde_json::json!({ + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "amount": { + "exclusiveMinimum": 0, + "type": "number" + }, + }, + "required": [ + "amount", + ], + "title": "FinancialTransaction", + "type": "object" + }, + "schemaType": "custom.schema", + "schemaVersion": "1.0.0" + }); + + check(app.clone(), "POST", schema.to_string(), StatusCode::CREATED).await?; + + //Already exists + check( + app.clone(), + "POST", + schema.to_string(), + StatusCode::CONFLICT, + ) + .await?; + + let schema = serde_json::json!({ + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "amount": { + "exclusiveMinimum": "0", + "type": "something" + }, + }, + "required": [ + "amount", + ], + "title": "FinancialTransaction", + "type": "object" + }, + "schemaType": "custom.schema", + "schemaVersion": "1.0.0" + }); + + // Bad schema + check( + app.clone(), + "POST", + schema.to_string(), + StatusCode::BAD_REQUEST, + ) + .await?; + + let schema = serde_json::json!({ + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "amount": { + "exclusiveMinimum": 0, + "type": "number" + }, + }, + "required": [ + "amount", + ], + "title": "FinancialTransaction", + "type": "object" + }, + "schemaType": "error", + "schemaVersion": "1.0.0" + }); + + // error type for tests + check( + app.clone(), + "POST", + schema.to_string(), + StatusCode::INTERNAL_SERVER_ERROR, + ) + .await?; + + Ok(()) + } +} diff --git a/warden/src/server/routes/config/schema/delete.rs b/warden/src/server/routes/config/schema/delete.rs index 55577ae..afd6fae 100644 --- a/warden/src/server/routes/config/schema/delete.rs +++ b/warden/src/server/routes/config/schema/delete.rs @@ -92,3 +92,40 @@ pub async fn delete_schema( .await?; Ok(StatusCode::NO_CONTENT) } + +#[cfg(test)] +mod tests { + use axum::{ + Router, + body::Body, + http::{Request, StatusCode}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + async fn check(app: Router, method: &str, expected_result: StatusCode) -> Result<()> { + let response = app + .oneshot( + Request::builder() + .method(method) + .uri("/api/v0/config/schema?schemaType=1&schemaVersion=2") + .body(Body::empty())?, + ) + .await?; + let actual_result = response.status(); + assert_eq!(expected_result, actual_result); + Ok(()) + } + + #[tokio::test] + async fn delete_schema() -> Result<()> { + let app = server::tests::test_app().await; + + check(app.clone(), "DELETE", StatusCode::NO_CONTENT).await?; + + Ok(()) + } +} diff --git a/warden/src/server/routes/config/schema/mod.rs b/warden/src/server/routes/config/schema/mod.rs index 17db5ce..901d116 100644 --- a/warden/src/server/routes/config/schema/mod.rs +++ b/warden/src/server/routes/config/schema/mod.rs @@ -25,3 +25,261 @@ pub fn router(store: Arc<AppState>) -> OpenApiRouter { .routes(utoipa_axum::routes!(update::update_schema)) .with_state(store) } + +#[cfg(test)] +pub mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use api_config::schema::TransactionSchema; + use api_config::{ConfigurationError, schema::SchemaDriver}; + use async_trait::async_trait; + use time::OffsetDateTime; + use tokio::sync::RwLock; + + use serde_json::Value; + use warden_core::pagination::{Connection, Edge, PageInfo, PaginationArgs}; + + #[derive(Default)] + pub struct MockSchemaDriver { + store: Arc<RwLock<HashMap<(String, String), TransactionSchema>>>, + } + + impl MockSchemaDriver { + pub fn new() -> Self { + Self::default() + } + } + + use sqlx::error::DatabaseError; + use std::fmt; + + #[derive(Debug)] + struct UniqueViolationError { + msg: String, + } + + impl fmt::Display for UniqueViolationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.msg) + } + } + + impl std::error::Error for UniqueViolationError {} + + impl DatabaseError for UniqueViolationError { + fn message(&self) -> &str { + &self.msg + } + + fn as_error(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { + unimplemented!() + } + + fn as_error_mut(&mut self) -> &mut (dyn std::error::Error + Send + Sync + 'static) { + unimplemented!() + } + + fn into_error(self: Box<Self>) -> Box<dyn std::error::Error + Send + Sync + 'static> { + unimplemented!() + } + + fn kind(&self) -> sqlx::error::ErrorKind { + sqlx::error::ErrorKind::UniqueViolation + } + } + + #[derive(Debug)] + struct OtherDbError { + msg: String, + } + + impl fmt::Display for OtherDbError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.msg) + } + } + + impl std::error::Error for OtherDbError {} + + impl DatabaseError for OtherDbError { + fn message(&self) -> &str { + &self.msg + } + + fn as_error(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { + unimplemented!() + } + + fn as_error_mut(&mut self) -> &mut (dyn std::error::Error + Send + Sync + 'static) { + unimplemented!() + } + + fn into_error(self: Box<Self>) -> Box<dyn std::error::Error + Send + Sync + 'static> { + unimplemented!() + } + + fn kind(&self) -> sqlx::error::ErrorKind { + sqlx::error::ErrorKind::Other + } + } + + #[async_trait] + impl SchemaDriver for MockSchemaDriver { + async fn create_schema( + &self, + kind: &str, + version: &str, + schema: &Value, + ) -> Result<TransactionSchema, ConfigurationError> { + let mut store = self.store.write().await; + let key = (kind.to_string(), version.to_string()); + + if store.contains_key(&key) { + let err = UniqueViolationError { + msg: "key".to_string(), + }; + dbg!(err.code()); + return Err(ConfigurationError::Database(sqlx::Error::Database( + Box::new(err), + ))); + }; + + if matches!(kind, "error") { + let err = OtherDbError { + msg: "key".to_string(), + }; + dbg!(err.code()); + return Err(ConfigurationError::Database(sqlx::Error::Database( + Box::new(err), + ))); + } + + let schema_obj = TransactionSchema { + id: 1, + schema_type: kind.to_string(), + schema_version: version.to_string(), + schema: schema.clone(), + created_at: OffsetDateTime::now_utc(), + updated_at: OffsetDateTime::now_utc(), + }; + + store.insert(key, schema_obj.clone()); + Ok(schema_obj) + } + + async fn delete_schema(&self, kind: &str, version: &str) -> Result<(), ConfigurationError> { + let mut store = self.store.write().await; + let key = (kind.to_string(), version.to_string()); + + store.remove(&key); + Ok(()) + } + + async fn get_schema( + &self, + kind: &str, + version: &str, + ) -> Result<Option<TransactionSchema>, ConfigurationError> { + let store = self.store.read().await; + let key = (kind.to_string(), version.to_string()); + + Ok(store.get(&key).cloned()) + } + + async fn update_schema( + &self, + kind: &str, + version: &str, + schema: &Value, + ) -> Result<Option<TransactionSchema>, ConfigurationError> { + let mut store = self.store.write().await; + let key = (kind.to_string(), version.to_string()); + + if let Some(existing) = store.get_mut(&key) { + existing.schema = schema.clone(); + return Ok(Some(existing.clone())); + } + + Ok(None) + } + + async fn list_schemas( + &self, + input: &PaginationArgs, + limit: i64, + ) -> Result<Connection<TransactionSchema>, ConfigurationError> { + let store = self.store.read().await; + + // 1. Collect + sort for stable pagination + let mut items: Vec<_> = store.values().cloned().collect(); + items.sort_by(|a, b| { + (a.schema_type.clone(), a.schema_version.clone()) + .cmp(&(b.schema_type.clone(), b.schema_version.clone())) + }); + + // 2. Convert to edges + let mut edges: Vec<Edge<TransactionSchema>> = items + .into_iter() + .map(|schema| Edge { + cursor: format!("{}:{}", schema.schema_type, schema.schema_version), + node: schema, + }) + .collect(); + + // 3. Apply cursors (after / before) + if let Some(after) = &input.after + && let Some(pos) = edges.iter().position(|e| &e.cursor == after) + { + edges = edges.into_iter().skip(pos + 1).collect(); + } + + if let Some(before) = &input.before + && let Some(pos) = edges.iter().position(|e| &e.cursor == before) + { + edges = edges.into_iter().take(pos).collect(); + } + + let total = edges.len(); + + // 4. Apply first / last + let mut sliced = edges; + + if let Some(first) = input.first { + let take = first.min(limit).max(0) as usize; + sliced = sliced.into_iter().take(take).collect(); + } else if let Some(last) = input.last { + let take = last.min(limit).max(0) as usize; + let len = sliced.len(); + sliced = sliced.into_iter().skip(len.saturating_sub(take)).collect(); + } else { + // default limit + sliced = sliced.into_iter().take(limit as usize).collect(); + } + + // 5. PageInfo + let start_cursor = sliced.first().map(|e| e.cursor.clone()); + let end_cursor = sliced.last().map(|e| e.cursor.clone()); + + let has_next_page = match input.first { + Some(first) => total > first as usize, + None => false, + }; + + let has_previous_page = match input.last { + Some(last) => total > last as usize, + None => false, + }; + + Ok(Connection { + edges: sliced, + page_info: PageInfo { + has_next_page, + has_previous_page, + start_cursor, + end_cursor, + }, + }) + } + } +} diff --git a/warden/src/server/routes/config/schema/read.rs b/warden/src/server/routes/config/schema/read.rs index 2d12935..33eaa30 100644 --- a/warden/src/server/routes/config/schema/read.rs +++ b/warden/src/server/routes/config/schema/read.rs @@ -90,19 +90,7 @@ pub async fn get_schema( let result = state .schema_service .get_schema(&schema_type, &schema_version) - .await - .map_err(|e| match e { - api_config::ConfigurationError::Database(ref error) => match error { - sqlx::Error::Database(db_err) if db_err.code() == Some("23505".into()) => { - AppError::new( - StatusCode::CONFLICT, - anyhow::anyhow!("Transaction schema already exists"), - ) - } - _ => e.into(), - }, - _ => e.into(), - })?; + .await?; if let Some(result) = result { Ok(Json(result)) } else { @@ -175,25 +163,79 @@ pub async fn get_schemas( params: Query<PaginationArgs>, ) -> Result<impl IntoResponse, AppError> { debug!("searching for schema"); - let limit = 10; // TODO: get from cache - let rows = state - .schema_service - .list_schemas(¶ms, limit) - .await - .map_err(|e| match e { - api_config::ConfigurationError::Database(ref error) => match error { - sqlx::Error::Database(db_err) if db_err.code() == Some("23505".into()) => { - AppError::new( - StatusCode::CONFLICT, - anyhow::anyhow!("Transaction schema already exists"), - ) - } - _ => e.into(), - }, - _ => e.into(), - })?; + let rows = state.schema_service.list_schemas(¶ms, state.pagination_limit).await?; Ok(Json(rows)) } + +#[cfg(test)] +mod tests { + use axum::{ + body::Body, + http::{Request, StatusCode, header}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + #[tokio::test] + async fn get_schema() -> Result<()> { + let app = server::tests::test_app().await; + + let info = serde_json::json!({ + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "amount": { + "exclusiveMinimum": 0, + "type": "number" + }, + }, + "required": [ + "amount", + ], + "title": "FinancialTransaction", + "type": "object" + }, + "schemaType": "1", + "schemaVersion": "1" + }); + + app.clone() + .oneshot( + Request::builder() + .method("POST") + .header(header::CONTENT_TYPE, "application/json") + .uri("/api/v0/config/schema") + .body(Body::from(info.to_string()))?, + ) + .await?; + + let response = app + .clone() + .oneshot( + Request::builder() + .uri("/api/v0/config/schema/1/1") + .body(Body::empty())?, + ) + .await?; + + assert_eq!(StatusCode::OK, response.status()); + + let response = app + .oneshot( + Request::builder() + .uri("/api/v0/config/schema/12/1") + .body(Body::empty())?, + ) + .await?; + + assert_eq!(StatusCode::NOT_FOUND, response.status()); + + Ok(()) + } +} diff --git a/warden/src/server/routes/config/schema/update.rs b/warden/src/server/routes/config/schema/update.rs index ff518a7..4eb8e76 100644 --- a/warden/src/server/routes/config/schema/update.rs +++ b/warden/src/server/routes/config/schema/update.rs @@ -1,12 +1,13 @@ use std::sync::Arc; -use api_config::schema::{CreateSchema, TransactionSchema}; +use api_config::schema::TransactionSchema; use axum::{ Json, debug_handler, extract::{Path, State}, http::{HeaderMap, StatusCode}, response::IntoResponse, }; +use tracing::trace; use crate::{ server::{ @@ -28,6 +29,24 @@ use crate::{ ("x-request-id" = Uuid, description = "Request identifier") ), body = TransactionSchema, + example = json!({ + "schemaType": "custom.schema", + "schemaVersion": "1.0.0", + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "FinancialTransaction", + "type": "object", + "required": ["transactionId"], + "properties": { + "transactionId": { + "type": "string", + "format": "uuid" + }, + } + }, + "createdAt": time::OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339).unwrap(), + "updatedAt": time::OffsetDateTime::now_utc().format(&time::format_description::well_known::Rfc3339).unwrap(), + }) ), ( status = 400, @@ -61,7 +80,22 @@ use crate::{ operation_id = "update_schema", // https://github.com/juhaku/utoipa/issues/1170 tag = SCHEMA, request_body( - content = CreateSchema + content = serde_json::Value, + description = "The schema to set", + example = json!({ + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "FinancialTransaction", + "type": "object", + "required": ["transactionId"], + "properties": { + "transactionId": { + "type": "string", + "format": "uuid" + }, + } + } + }) ), path = "/{apiVersion}/config/schema/{schemaType}/{schemaVersion}", params( @@ -82,10 +116,18 @@ use crate::{ #[debug_handler] pub async fn update_schema( State(state): State<Arc<AppState>>, - headers: HeaderMap, Path((version, schema_type, schema_version)): Path<(Version, String, String)>, + headers: HeaderMap, Json(body): Json<serde_json::Value>, ) -> Result<impl IntoResponse, AppError> { + trace!("checking schema validity"); + jsonschema::Validator::new(&body).map_err(|e| { + AppError::new( + StatusCode::BAD_REQUEST, + anyhow::anyhow!("Invalid schema: {e}"), + ) + })?; + // TODO: should also clear cached ones eventually let result = state .schema_service @@ -101,3 +143,127 @@ pub async fn update_schema( )) } } + +#[cfg(test)] +mod tests { + use axum::{ + Router, + body::Body, + http::{Request, StatusCode, header}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + async fn check( + app: Router, + method: &str, + body: String, + expected_result: StatusCode, + endpoint: &str, + ) -> Result<()> { + let response = app + .oneshot( + Request::builder() + .method(method) + .header(header::CONTENT_TYPE, "application/json") + .uri(endpoint) + .body(Body::from(body))?, + ) + .await?; + let actual_result = response.status(); + assert_eq!(expected_result, actual_result); + Ok(()) + } + + #[tokio::test] + async fn update_schema() -> Result<()> { + let app = server::tests::test_app().await; + let create_endpoint = "/api/v0/config/schema"; + + let schema = serde_json::json!({ + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "amount": { + "exclusiveMinimum": 0, + "type": "number" + }, + }, + "required": [ + "amount", + ], + "title": "FinancialTransaction", + "type": "object" + }, + "schemaType": "custom.schema", + "schemaVersion": "1.0.0" + }); + + check( + app.clone(), + "POST", + schema.to_string(), + StatusCode::CREATED, + create_endpoint, + ) + .await?; + + let schema = serde_json::json!({ + "schema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "currency": { + "type": "string" + }, + }, + "required": [ + "currency", + ], + "title": "FinancialTransaction", + "type": "object" + }, + "schemaType": "custom.schema", + "schemaVersion": "1.0.0" + }); + + // Successful update + check( + app.clone(), + "PATCH", + schema.to_string(), + StatusCode::OK, + "/api/v0/config/schema/custom.schema/1.0.0", + ) + .await?; + + let schema = serde_json::json!({ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "properties": { + "amount": { + "exclusiveMinimum": "0", + "type": "something" + }, + }, + "required": [ + "value", + ], + "title": "FinancialTransaction", + "type": "object" + }); + + // Bad schema + check( + app.clone(), + "PATCH", + schema.to_string(), + StatusCode::BAD_REQUEST, + "/api/v0/config/schema/custom.schema/1.0.0", + ) + .await?; + + Ok(()) + } +} diff --git a/warden/src/server/routes/mod.rs b/warden/src/server/routes/mod.rs index 5994987..09f0240 100644 --- a/warden/src/server/routes/mod.rs +++ b/warden/src/server/routes/mod.rs @@ -36,3 +36,39 @@ pub async fn health() -> impl IntoResponse { let ver = env!("CARGO_PKG_VERSION"); format!("{name} v{ver} is live") } + +#[cfg(test)] +mod tests { + use axum::{ + Router, + body::Body, + http::{Request, StatusCode}, + }; + + use anyhow::Result; + use tower::ServiceExt; + + use crate::server::{self}; + + async fn check(app: Router, method: &str, expected_result: StatusCode) -> Result<()> { + let response = app + .oneshot( + Request::builder() + .method(method) + .uri("/api/health") + .body(Body::empty())?, + ) + .await?; + let actual_result = response.status(); + assert_eq!(expected_result, actual_result); + Ok(()) + } + + #[tokio::test] + async fn health() -> Result<()> { + let app = server::tests::test_app().await; + check(app.clone(), "GET", StatusCode::OK).await?; + check(app.clone(), "HEAD", StatusCode::OK).await?; + Ok(()) + } +} diff --git a/warden/src/state/mod.rs b/warden/src/state/mod.rs index a6b36e1..960fe87 100644 --- a/warden/src/state/mod.rs +++ b/warden/src/state/mod.rs @@ -2,20 +2,21 @@ pub(crate) mod database; use std::sync::Arc; use api_config::schema::SchemaDriver; -use tracing_subscriber::EnvFilter; -pub type LogHandle = tracing_subscriber::reload::Handle<EnvFilter, tracing_subscriber::Registry>; +use crate::logging::LogHandle; #[derive(Clone)] pub struct AppState { pub log_handle: LogHandle, pub schema_service: Arc<dyn SchemaDriver>, + pub pagination_limit: i64, } impl AppState { pub async fn new( log_handle: LogHandle, schema_service: Arc<dyn SchemaDriver>, + pagination_limit: i64, ) -> anyhow::Result<Self> { // let database = database::connect(&config.database).await?; // trace!("running database migrations"); @@ -25,6 +26,7 @@ impl AppState { Ok(Self { log_handle, schema_service, + pagination_limit, }) } } |
