diff --git a/Cargo.lock b/Cargo.lock index 2140954..2fff480 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -390,6 +390,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.28" @@ -406,6 +421,17 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +[[package]] +name = "futures-executor" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + [[package]] name = "futures-intrusive" version = "0.4.2" @@ -417,6 +443,23 @@ dependencies = [ "parking_lot 0.11.2", ] +[[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] +name = "futures-macro" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + [[package]] name = "futures-sink" version = "0.3.28" @@ -435,11 +478,16 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -708,6 +756,7 @@ dependencies = [ "chrono", "cookie", "figment", + "futures", "serde", "serde_json", "sqlx", @@ -1200,6 +1249,15 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.10.0" diff --git a/Cargo.toml b/Cargo.toml index d4c7053..26b0956 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ chrono = { version = "0.4", features = ["serde", "clock"] } anyhow = "1.0" argon2 = { version = "0.5", features = ["std", "alloc"] } url = "2.4" +futures = "0.3" [profile.dev.package.sqlx-macros] opt-level = 3 \ No newline at end of file diff --git a/src/auth/http.rs b/src/auth/http.rs new file mode 100644 index 0000000..55b33b7 --- /dev/null +++ b/src/auth/http.rs @@ -0,0 +1,97 @@ +use anyhow::Result; +use axum::{ + async_trait, + extract::{FromRequestParts, State}, + http::{ + header::{COOKIE, SET_COOKIE}, + request::Parts, + HeaderValue, Request, StatusCode, + }, + middleware::Next, + response::Response, +}; +use chrono::{Duration, Utc}; +use cookie::Cookie; +use sqlx::{Pool, Postgres}; +use std::sync::Arc; +use uuid::Uuid; + +use crate::{error::AppError, state::AppState}; + +use super::{session::Session, user::User}; + +pub const INVALID_SESSION: AppError = AppError::ClientError { + status: StatusCode::UNAUTHORIZED, + code: "authentication-required", + message: "Please log-in and submit a valid session as a cookie", +}; + +pub struct RequireSession(pub Session, pub User); + +#[async_trait] +impl FromRequestParts> for RequireSession { + type Rejection = AppError<'static>; + + async fn from_request_parts( + parts: &mut Parts, + state: &Arc, + ) -> Result { + if let Some(cookie) = parts.headers.get(COOKIE) { + let (session_id, session_secret) = extract_session_token(cookie)?; + + match Session::find(&state.database, session_id).await? { + None => Err(INVALID_SESSION), + Some((session, user)) => { + if session.secret != session_secret { + return Err(INVALID_SESSION); + } + if session.expires_at < Utc::now().naive_utc() { + return Err(INVALID_SESSION); + } + Ok(RequireSession(session, user)) + } + } + } else { + return Err(INVALID_SESSION); + } + } +} + +fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> { + Ok(Session::parse_token( + Cookie::parse(header.to_str()?)?.value(), + )?) +} + +async fn find_and_refresh( + pool: &Pool, + session_id: Uuid, + duration: Duration, +) -> Option { + if let Some(Some((session, _))) = Session::find(pool, session_id).await.ok() { + session.refresh(pool, duration).await.ok() + } else { + None + } +} + +pub async fn refresh_sessions( + State(state): State>, + req: Request, + next: Next, +) -> Response { + if let Some((session_id, _)) = req + .headers() + .get(COOKIE) + .and_then(|header| extract_session_token(header).ok()) + { + find_and_refresh( + &state.database, + session_id, + Duration::seconds(state.config.session_duration), + ) + .await; + } + + next.run(req).await +} diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 3a1b4ad..ae26496 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,3 +1,4 @@ pub mod hash; +pub mod http; pub mod session; pub mod user; diff --git a/src/auth/session.rs b/src/auth/session.rs index 13d6434..801ce83 100644 --- a/src/auth/session.rs +++ b/src/auth/session.rs @@ -1,28 +1,16 @@ -use std::sync::Arc; - use anyhow::{anyhow, Result}; -use axum::{ - async_trait, - extract::{FromRequestParts, State}, - http::{header::COOKIE, request::Parts, StatusCode}, -}; +use axum::http::StatusCode; use chrono::{Duration, NaiveDateTime, Utc}; use cookie::Cookie; use serde::{Deserialize, Serialize}; use sqlx::{FromRow, Pool, Postgres}; use uuid::Uuid; -use crate::{error::AppError, state::AppState}; +use crate::error::AppError; use super::{hash::random, user::User}; -const INVALID_SESSION: AppError = AppError::ClientError { - status: StatusCode::UNAUTHORIZED, - code: "authentication-required", - message: "Please log-in and submit a valid session as a cookie", -}; - -const USER_NOT_FOUND: AppError = AppError::ClientError { +pub const USER_NOT_FOUND: AppError = AppError::ClientError { status: StatusCode::UNAUTHORIZED, code: "user-not-found", message: "The logged-in user was not found", @@ -123,22 +111,16 @@ impl Session { pub async fn refresh(self: Self, pool: &Pool, duration: Duration) -> Result { let expires_at = (Utc::now() + duration).naive_utc(); - let secret = random(); sqlx::query!( - "UPDATE sessions SET secret = $1, expires_at = $2 WHERE id = $3 RETURNING id", - secret, + "UPDATE sessions SET expires_at = $1 WHERE id = $2 RETURNING id", expires_at, self.id ) .fetch_one(pool) .await?; - Ok(Session { - secret, - expires_at, - ..self - }) + Ok(Session { expires_at, ..self }) } pub fn token(self: &Self) -> String { @@ -161,39 +143,22 @@ impl Session { None => Err(USER_NOT_FOUND), } } -} -pub struct RequireSession(pub Session, pub User); + pub async fn prune_dead(pool: &Pool) -> Result { + let now = Utc::now().naive_utc(); + let result = sqlx::query!("DELETE FROM sessions WHERE expires_at < $1", now) + .execute(pool) + .await?; + Ok(result.rows_affected()) + } -#[async_trait] -impl FromRequestParts> for RequireSession { - type Rejection = AppError<'static>; - - async fn from_request_parts( - parts: &mut Parts, - state: &Arc, - ) -> Result { - if let Some(cookie) = parts.headers.get(COOKIE) { - let cookie_str = cookie.to_str()?; - let cookie = Cookie::parse(cookie_str)?; - let (session_id, session_secret) = Session::parse_token(cookie.value())?; - - let State(state) = State::>::from_request_parts(parts, state).await?; - - match Session::find(&state.database, session_id).await? { - None => Err(INVALID_SESSION), - Some((session, user)) => { - if session.secret != session_secret { - return Err(INVALID_SESSION); - } - if session.expires_at < Utc::now().naive_utc() { - return Err(INVALID_SESSION); - } - Ok(RequireSession(session, user)) - } - } - } else { - return Err(INVALID_SESSION); - } + pub fn cookie(self: &Self, domain: &str, secure: bool) -> String { + Cookie::build("session", self.token()) + .domain(domain) + .secure(secure) + .http_only(!secure) + .path("/") + .finish() + .to_string() } } diff --git a/src/main.rs b/src/main.rs index 2c243f2..fa0e93f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,8 +5,10 @@ mod roles; mod routes; mod state; +use crate::{auth::http::refresh_sessions, state::Config}; use anyhow::Result; use axum::{ + middleware, routing::{get, post}, Router, Server, }; @@ -18,8 +20,6 @@ use sqlx::postgres::PgPoolOptions; use state::AppState; use std::{net::SocketAddr, sync::Arc}; -use crate::state::Config; - #[tokio::main] async fn main() -> Result<()> { tracing_subscriber::fmt::init(); @@ -44,6 +44,10 @@ async fn main() -> Result<()> { .route("/me", get(routes::auth::me)) .route("/pages/:site/:slug", get(routes::content::page)) .route("/admin/bootstrap", post(routes::admin::bootstrap)) + .route_layer(middleware::from_fn_with_state( + shared_state.clone(), + refresh_sessions, + )) .with_state(shared_state); tracing::debug!("listening on {}", addr); diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 9284fde..9667bc1 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -5,17 +5,12 @@ use axum::{ Json, }; use chrono::Duration; -use cookie::Cookie; use serde::Deserialize; use serde_json::json; use std::sync::Arc; use crate::{ - auth::{ - hash::verify, - session::{RequireSession, Session}, - user::User, - }, + auth::{hash::verify, http::RequireSession, session::Session, user::User}, error::AppError, state::AppState, }; @@ -62,16 +57,13 @@ pub async fn login( Json(json!({ "session_token": token, "expires_at": session.expires_at })).into_response(); let secure = state.config.secure(); - let cookie = Cookie::build("session", token) - .domain(state.config.domain()) - .secure(secure) - .http_only(!secure) - .path("/") - .finish(); - response - .headers_mut() - .insert(SET_COOKIE, cookie.to_string().parse()?); + response.headers_mut().insert( + SET_COOKIE, + session + .cookie(state.config.domain().as_str(), secure) + .parse()?, + ); Ok(response) } diff --git a/src/state.rs b/src/state.rs index 5674a66..b1cae8a 100644 --- a/src/state.rs +++ b/src/state.rs @@ -6,7 +6,8 @@ use url::Url; pub struct Config { pub bind: String, pub database_url: String, - pub session_duration: i32, // in seconds + pub session_duration: i64, // in seconds + pub prune_interval: u64, // in seconds pub base_url: String, } @@ -31,6 +32,7 @@ impl Default for Config { bind: "127.0.0.1:3000".into(), database_url: "postgres://artificiale:changeme@localhost/artificiale".into(), session_duration: 3600, // 60min + prune_interval: 10, // 60min base_url: "http://localhost".into(), } }