diff --git a/src/auth/hash.rs b/src/auth/hash.rs index f8b49cb..8cc171c 100644 --- a/src/auth/hash.rs +++ b/src/auth/hash.rs @@ -21,3 +21,42 @@ pub fn verify(plaintext: &str, hash: &str) -> Result { .verify_password(plaintext.as_bytes(), &parsed_hash) .is_ok()) } + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + #[test] + fn test_hash_and_verify() { + // Create random password + let random = super::random(); + + // Hash password + let hash = super::hash(&random).unwrap(); + + // Verify should be true + assert!(super::verify(&random, &hash).unwrap()); + } + + #[test] + fn test_random_requirements() { + // Test that a large number of random strings are unique + const NUM_STRINGS: usize = 10000; + + let mut strings = HashSet::new(); + + for _ in 0..NUM_STRINGS { + let random_string = super::random(); + + // Strings should also be long enough + assert!(random_string.len() >= 20); + + assert!( + !strings.contains(&random_string), + "Duplicate string found: {}", + random_string + ); + strings.insert(random_string); + } + } +} diff --git a/src/auth/session.rs b/src/auth/session.rs index 7b5c471..fd92347 100644 --- a/src/auth/session.rs +++ b/src/auth/session.rs @@ -1,14 +1,15 @@ -use anyhow::{anyhow, Result}; +use anyhow::anyhow; +use async_trait::async_trait; use axum::http::StatusCode; -use chrono::{Duration, NaiveDateTime, Utc}; +use chrono::{Duration, NaiveDateTime}; use cookie::Cookie; use serde::{Deserialize, Serialize}; -use sqlx::{FromRow, Pool, Postgres}; +use sqlx::FromRow; use uuid::Uuid; -use crate::http::error::ApiError; +use crate::{content::Error, http::error::ApiError}; -use super::{hash::random, user::User}; +use super::user::User; pub const USER_NOT_FOUND: ApiError<'static> = ApiError::Client { status: StatusCode::UNAUTHORIZED, @@ -32,126 +33,39 @@ pub struct Session { pub expires_at: NaiveDateTime, } +#[async_trait] +pub trait SessionRepository { + async fn create_session(&self, user_id: Uuid, duration: Duration) -> Result; + + async fn find_session(&self, session_id: Uuid) -> Result, Error>; + + async fn refresh_session(&self, session: Session, duration: Duration) + -> Result; + + async fn destroy_session(&self, session_id: Uuid) -> Result<(), Error>; + + async fn prune_dead_sessions(&self) -> Result; +} + impl Session { - pub async fn create(pool: &Pool, user_id: Uuid, duration: Duration) -> Result { - let now = Utc::now().naive_utc(); - let expires = now + duration; - let secret = random(); - let result = sqlx::query!( - "INSERT INTO sessions (id, actor, secret, created_at, expires_at) VALUES ($1, $2, $3, $4, $5) RETURNING id", - Uuid::now_v7(), - user_id, - secret, - now, - expires - ) - .fetch_one(pool) - .await?; - Ok(Self { - id: result.id, - actor: user_id, - secret, - created_at: now, - expires_at: now + duration, - }) - } - - pub async fn find(pool: &Pool, session_id: Uuid) -> Result> { - let record = sqlx::query!( - "SELECT - sessions.id AS session_id, - sessions.actor AS session_actor, - sessions.secret, - sessions.created_at AS session_created_at, - sessions.expires_at, - users.id AS user_id, - users.name, - users.email, - users.display_name, - users.bio, - users.roles, - users.created_at AS user_created_at, - users.modified_at, - users.deleted_at - FROM - sessions - JOIN - users ON sessions.actor = users.id - WHERE - sessions.id = $1", - session_id - ) - .fetch_optional(pool) - .await?; - - Ok(record.map(|record| { - ( - Self { - id: record.session_id, - actor: record.session_actor, - secret: record.secret, - created_at: record.session_created_at, - expires_at: record.expires_at, - }, - User { - id: record.user_id, - name: record.name, - email: record.email, - password: None, - display_name: record.display_name, - bio: record.bio, - roles: record.roles, - created_at: record.user_created_at, - modified_at: record.modified_at, - deleted_at: record.deleted_at, - }, - ) - })) - } - - pub async fn refresh(self, pool: &Pool, duration: Duration) -> Result { - let expires_at = (Utc::now() + duration).naive_utc(); - - sqlx::query!( - "UPDATE sessions SET expires_at = $1 WHERE id = $2 RETURNING id", - expires_at, - self.id - ) - .fetch_one(pool) - .await?; - - Ok(Session { expires_at, ..self }) - } - pub fn token(&self) -> String { format!("{}:{}", self.id.as_u128(), self.secret) } - pub fn parse_token(token: &str) -> Result<(Uuid, String)> { + pub fn parse_token(token: &str) -> Result<(Uuid, String), Error> { let (uuid_str, token_str) = token .split_once(':') - .ok_or_else(|| anyhow!("malformed token"))?; + .ok_or_else(|| Error::Internal(anyhow!("malformed token")))?; Ok(( - Uuid::from_u128(uuid_str.parse::()?), + Uuid::from_u128( + uuid_str + .parse::() + .map_err(|e| Error::Internal(e.into()))?, + ), token_str.to_string(), )) } - pub async fn destroy(&self, pool: &Pool) -> Result<()> { - sqlx::query!("DELETE FROM sessions WHERE id = $1", self.id) - .execute(pool) - .await?; - Ok(()) - } - - pub async fn prune_dead(pool: &Pool) -> Result { - let now = Utc::now().naive_utc(); - let result = sqlx::query!("DELETE FROM sessions WHERE expires_at < $1", now) - .execute(pool) - .await?; - Ok(result.rows_affected()) - } - pub fn cookie(&self, domain: &str, secure: bool) -> String { Cookie::build("session", self.token()) .domain(domain) diff --git a/src/auth/user.rs b/src/auth/user.rs index ad4f8b9..a5276ca 100644 --- a/src/auth/user.rs +++ b/src/auth/user.rs @@ -1,10 +1,10 @@ -use anyhow::Result; +use async_trait::async_trait; use chrono::NaiveDateTime; use serde::{Deserialize, Serialize}; -use sqlx::{FromRow, Pool, Postgres}; +use sqlx::FromRow; use uuid::Uuid; -use super::hash::hash; +use crate::content::Error; #[derive(Deserialize, Serialize, Clone, FromRow)] pub struct User { @@ -52,38 +52,18 @@ impl Default for User { } } -impl User { - pub async fn create( - pool: &Pool, +#[async_trait] +pub trait UserRepository { + async fn create_user( + &self, username: &str, password: &str, roles: &Vec, - ) -> Result { - let result = sqlx::query!( - r#"INSERT INTO users ( id, name, password, roles ) - VALUES ( $1,$2,$3,$4 ) RETURNING id, created_at"#, - Uuid::now_v7(), - username, - hash(&password)?, - roles, - ) - .fetch_one(pool) - .await?; - Ok(Self { - id: result.id, - name: username.to_owned(), - roles: roles.to_owned(), - created_at: result.created_at, - ..Default::default() - }) - } + ) -> Result; - pub async fn find(pool: &Pool, name: &str) -> Result> { - Ok(sqlx::query_as("SELECT * FROM users WHERE name = $1") - .bind(name) - .fetch_optional(pool) - .await?) - } + async fn find_user(&self, name: &str) -> Result, Error>; + + async fn has_no_users(&self) -> Result; } #[derive(Deserialize, Serialize, FromRow)] diff --git a/src/content/mod.rs b/src/content/mod.rs index 8b86f0a..1c94774 100644 --- a/src/content/mod.rs +++ b/src/content/mod.rs @@ -15,4 +15,7 @@ pub enum Error { #[error("Database error: {0}")] QueryFailed(#[from] sqlx::Error), + + #[error("Internal error: {0}")] + Internal(anyhow::Error), } diff --git a/src/database/mod.rs b/src/database/mod.rs index 15e2415..03311f4 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -5,7 +5,9 @@ use crate::state::AppState; pub mod collection; pub mod post; +pub mod session; pub mod site; +pub mod user; pub struct Database { pool: PgPool, diff --git a/src/database/session.rs b/src/database/session.rs new file mode 100644 index 0000000..291ce69 --- /dev/null +++ b/src/database/session.rs @@ -0,0 +1,129 @@ +use async_trait::async_trait; +use chrono::{Duration, Utc}; +use uuid::Uuid; + +use crate::{ + auth::{ + hash::random, + session::{Session, SessionRepository}, + user::User, + }, + content::Error, +}; + +use super::Database; + +#[async_trait] +impl SessionRepository for Database { + async fn create_session(&self, user_id: Uuid, duration: Duration) -> Result { + let now = Utc::now().naive_utc(); + let expires = now + duration; + let secret = random(); + let result = sqlx::query!( + "INSERT INTO sessions (id, actor, secret, created_at, expires_at) VALUES ($1, $2, $3, $4, $5) RETURNING id", + Uuid::now_v7(), + user_id, + secret, + now, + expires + ) + .fetch_one(&self.pool) + .await?; + Ok(Session { + id: result.id, + actor: user_id, + secret, + created_at: now, + expires_at: now + duration, + }) + } + + async fn find_session(&self, session_id: Uuid) -> Result, Error> { + let record = sqlx::query!( + "SELECT + sessions.id AS session_id, + sessions.actor AS session_actor, + sessions.secret, + sessions.created_at AS session_created_at, + sessions.expires_at, + users.id AS user_id, + users.name, + users.email, + users.display_name, + users.bio, + users.roles, + users.created_at AS user_created_at, + users.modified_at, + users.deleted_at + FROM + sessions + JOIN + users ON sessions.actor = users.id + WHERE + sessions.id = $1", + session_id + ) + .fetch_optional(&self.pool) + .await?; + + Ok(record.map(|record| { + ( + Session { + id: record.session_id, + actor: record.session_actor, + secret: record.secret, + created_at: record.session_created_at, + expires_at: record.expires_at, + }, + User { + id: record.user_id, + name: record.name, + email: record.email, + password: None, + display_name: record.display_name, + bio: record.bio, + roles: record.roles, + created_at: record.user_created_at, + modified_at: record.modified_at, + deleted_at: record.deleted_at, + }, + ) + })) + } + + async fn refresh_session( + &self, + session: Session, + duration: Duration, + ) -> Result { + let expires_at = (Utc::now() + duration).naive_utc(); + + sqlx::query!( + "UPDATE sessions SET expires_at = $1 WHERE id = $2 RETURNING id", + expires_at, + session.id + ) + .fetch_one(&self.pool) + .await?; + + Ok(Session { + expires_at, + ..session + }) + } + + async fn destroy_session(&self, session_id: Uuid) -> Result<(), Error> { + sqlx::query!("DELETE FROM sessions WHERE id = $1", session_id) + .execute(&self.pool) + .await?; + Ok(()) + } + + async fn prune_dead_sessions(&self) -> Result { + let now = Utc::now().naive_utc(); + let result = sqlx::query!("DELETE FROM sessions WHERE expires_at < $1", now) + .execute(&self.pool) + .await?; + Ok(result.rows_affected()) + } +} diff --git a/src/database/user.rs b/src/database/user.rs new file mode 100644 index 0000000..0c78a65 --- /dev/null +++ b/src/database/user.rs @@ -0,0 +1,57 @@ +use async_trait::async_trait; +use uuid::Uuid; + +use crate::{ + auth::hash::hash, + auth::user::{User, UserRepository}, + content::Error, +}; + +use super::Database; + +#[async_trait] +impl UserRepository for Database { + async fn create_user( + &self, + username: &str, + password: &str, + roles: &Vec, + ) -> Result { + let result = sqlx::query!( + r#"INSERT INTO users ( id, name, password, roles ) + VALUES ( $1,$2,$3,$4 ) RETURNING id, created_at"#, + Uuid::now_v7(), + username, + hash(&password).map_err(|e| Error::Internal(e))?, + roles, + ) + .fetch_one(&self.pool) + .await?; + Ok(User { + id: result.id, + name: username.to_owned(), + roles: roles.to_owned(), + created_at: result.created_at, + ..Default::default() + }) + } + + async fn find_user(&self, name: &str) -> Result, Error> { + Ok(sqlx::query_as("SELECT * FROM users WHERE name = $1") + .bind(name) + .fetch_optional(&self.pool) + .await?) + } + + async fn has_no_users(&self) -> Result { + // Check if the user table is completely empty + let empty = sqlx::query!( + "SELECT CASE WHEN EXISTS(SELECT 1 FROM users) THEN false ELSE true END AS empty;" + ) + .map(|row| row.empty.unwrap_or(true)) + .fetch_one(&self.pool) + .await?; + + Ok(empty) + } +} diff --git a/src/http/error.rs b/src/http/error.rs index 7a71874..d6e4314 100644 --- a/src/http/error.rs +++ b/src/http/error.rs @@ -87,6 +87,7 @@ impl From for ApiError<'_> { content::Error::IdentifierNotAvailable => ERR_NOT_AVAILABLE, content::Error::AccessDenied => ERR_UNAUTHORIZED, content::Error::QueryFailed(err) => err.into(), + content::Error::Internal(err) => err.into(), } } } diff --git a/src/http/session.rs b/src/http/session.rs index e66a8de..c02ca41 100644 --- a/src/http/session.rs +++ b/src/http/session.rs @@ -5,7 +5,7 @@ use axum::{ http::{ header::{COOKIE, SET_COOKIE}, request::Parts, - HeaderValue, Request, StatusCode, + Request, StatusCode, }, middleware::Next, response::Response, @@ -14,10 +14,13 @@ use axum::{ use chrono::{Duration, Utc}; use cookie::Cookie; use std::sync::Arc; -use uuid::Uuid; use crate::{ - auth::{session::Session, user::User}, + auth::{ + session::{Session, SessionRepository}, + user::User, + }, + database::Database, http::error::ApiError, state::AppState, }; @@ -28,10 +31,6 @@ pub const INVALID_SESSION: ApiError = ApiError::Client { message: "Please log-in and submit a valid session as a cookie", }; -fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> { - Session::parse_token(Cookie::parse(header.to_str()?)?.value()) -} - pub struct RequireUser(pub User); #[async_trait] @@ -91,17 +90,16 @@ pub async fn refresh_sessions( if let Some((session_id, session_secret)) = req .headers() .get(COOKIE) - .and_then(|header| extract_session_token(header).ok()) + .and_then(|header| Cookie::parse(header.to_str().unwrap_or_default()).ok()) + .and_then(|cookie| Session::parse_token(cookie.value()).ok()) { - if let Ok(Some((session, user))) = Session::find(&state.database, session_id).await { + let database = Database::from(&state); + if let Ok(Some((session, user))) = database.find_session(session_id).await { // session validity requirements: secret must match, session must not have been expired if session.secret == session_secret && session.expires_at >= Utc::now().naive_utc() { // in the future we might wanna change the session secret, if we do, do it here! - if let Ok((session, user)) = session - .refresh( - &state.database, - Duration::seconds(state.config.session_duration), - ) + if let Ok((session, user)) = database + .refresh_session(session, Duration::seconds(state.config.session_duration)) .await .map(|s| (s, user)) { diff --git a/src/routes/admin.rs b/src/routes/admin.rs index 2260f5e..057a0f3 100644 --- a/src/routes/admin.rs +++ b/src/routes/admin.rs @@ -6,24 +6,15 @@ use axum::{extract::State, Router}; use serde_json::json; use std::sync::Arc; +use crate::auth::user::UserRepository; +use crate::database::Database; use crate::{ - auth::{hash::random, user::User}, - builtins::ROLE_SUPERADMIN, - http::error::ApiError, - state::AppState, + auth::hash::random, builtins::ROLE_SUPERADMIN, http::error::ApiError, state::AppState, }; -async fn bootstrap(State(state): State>) -> impl IntoResponse { +async fn bootstrap(repository: Repo) -> impl IntoResponse { // Only allow this request if the user table is completely empty! - let empty = sqlx::query!( - "SELECT CASE WHEN EXISTS(SELECT 1 FROM users) THEN false ELSE true END AS empty;" - ) - .map(|row| row.empty.unwrap_or(true)) - .fetch_one(&state.database) - .await - .map_err(anyhow::Error::from)?; - - if !empty { + if !repository.has_no_users().await? { return Err(ApiError::Client { status: StatusCode::BAD_REQUEST, code: "already-setup", @@ -34,18 +25,13 @@ async fn bootstrap(State(state): State>) -> impl IntoResponse { let username = "admin"; let password = random(); - User::create( - &state.database, - username, - &password, - &[ROLE_SUPERADMIN].to_vec(), - ) - .await - .map_err(ApiError::from)?; + repository + .create_user(username, &password, &[ROLE_SUPERADMIN].to_vec()) + .await?; Ok(Json(json!({"username": username, "password": password}))) } pub fn router() -> Router> { - Router::new().route("/bootstrap", post(bootstrap)) + Router::new().route("/bootstrap", post(bootstrap::)) } diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 0df61a4..8e2f096 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -11,7 +11,12 @@ use serde_json::json; use std::sync::Arc; use crate::{ - auth::{hash::verify, session::Session, user::User}, + auth::{ + hash::verify, + session::{Session, SessionRepository}, + user::UserRepository, + }, + database::Database, http::{ error::ApiError, json::JsonBody, @@ -26,11 +31,13 @@ struct LoginRequest { pub password: String, } -async fn login( +async fn login( + repository: Repo, State(state): State>, JsonBody(payload): JsonBody, ) -> impl IntoResponse { - let user = User::find(&state.database, payload.username.as_str()) + let user = repository + .find_user(payload.username.as_str()) .await .map_err(ApiError::from)?; @@ -49,13 +56,9 @@ async fn login( return Err(invalid()); } - let session = Session::create( - &state.database, - user.id, - Duration::seconds(state.config.session_duration), - ) - .await - .map_err(ApiError::from)?; + let session = repository + .create_session(user.id, Duration::seconds(state.config.session_duration)) + .await?; let token = session.token(); let mut response: Response = @@ -76,11 +79,12 @@ async fn me(RequireUser(user): RequireUser) -> Result> Ok(user.name) } -async fn logout( +async fn logout( + repository: Repo, State(state): State>, RequireSession(session): RequireSession, ) -> Result> { - session.destroy(&state.database).await?; + repository.destroy_session(session.id).await?; let mut response: Response = Json(json!({ "ok": true })).into_response(); response.headers_mut().insert( @@ -94,7 +98,7 @@ async fn logout( pub fn router() -> Router> { Router::new() - .route("/login", post(login)) - .route("/logout", post(logout)) + .route("/login", post(login::)) + .route("/logout", post(logout::)) .route("/me", get(me)) }