diff --git a/src/auth/http.rs b/src/auth/http.rs index d0b683d..bf151c6 100644 --- a/src/auth/http.rs +++ b/src/auth/http.rs @@ -2,17 +2,13 @@ use anyhow::Result; use axum::{ async_trait, extract::{FromRequestParts, State}, - http::{ - header::{COOKIE, SET_COOKIE}, - request::Parts, - HeaderValue, Request, StatusCode, - }, + http::{header::COOKIE, request::Parts, HeaderValue, Request, StatusCode}, middleware::Next, response::Response, + Extension, }; use chrono::{Duration, Utc}; use cookie::Cookie; -use sqlx::{Pool, Postgres}; use std::sync::Arc; use uuid::Uuid; @@ -26,72 +22,58 @@ pub const INVALID_SESSION: AppError = AppError::ClientError { message: "Please log-in and submit a valid session as a cookie", }; -pub struct RequireSession(pub Session, pub User); - -#[async_trait] -impl FromRequestParts> for RequireSession { - type Rejection = AppError<'static>; - - async fn from_request_parts( - parts: &mut Parts, - state: &Arc, - ) -> Result { - if let Some(cookie) = parts.headers.get(COOKIE) { - let (session_id, session_secret) = extract_session_token(cookie)?; - - match Session::find(&state.database, session_id).await? { - None => Err(INVALID_SESSION), - Some((session, user)) => { - if session.secret != session_secret { - return Err(INVALID_SESSION); - } - if session.expires_at < Utc::now().naive_utc() { - return Err(INVALID_SESSION); - } - Ok(RequireSession(session, user)) - } - } - } else { - return Err(INVALID_SESSION); - } - } -} - fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> { Ok(Session::parse_token( Cookie::parse(header.to_str()?)?.value(), )?) } -async fn find_and_refresh( - pool: &Pool, - session_id: Uuid, - duration: Duration, -) -> Option { - if let Some(Some((session, _))) = Session::find(pool, session_id).await.ok() { - session.refresh(pool, duration).await.ok() - } else { - None +pub struct RequireUser(pub User); + +#[async_trait] +impl FromRequestParts for RequireUser +where + S: Send + Sync, +{ + type Rejection = AppError<'static>; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + match Extension::::from_request_parts(parts, state).await { + Ok(Extension(user)) => Ok(RequireUser(user)), + _ => Err(INVALID_SESSION), + } } } pub async fn refresh_sessions( State(state): State>, - req: Request, + mut req: Request, next: Next, ) -> Response { - if let Some((session_id, _)) = req + if let Some((session_id, session_secret)) = req .headers() .get(COOKIE) .and_then(|header| extract_session_token(header).ok()) { - // in the future we might wanna change the session secret, if we do, do it here! - find_and_refresh( - &state.database, - session_id, - Duration::seconds(state.config.session_duration), - ) - .await; + if let Some(Some((session, user))) = Session::find(&state.database, session_id).await.ok() { + // session validity requirements: secret must match, session must not have been expired + if session.secret == session_secret && session.expires_at >= Utc::now().naive_utc() { + // in the future we might wanna change the session secret, if we do, do it here! + if let Some((session, user)) = session + .refresh( + &state.database, + Duration::seconds(state.config.session_duration), + ) + .await + .map(|s| (s, user)) + .ok() + { + let extensions = req.extensions_mut(); + extensions.insert(session); + extensions.insert(user); + } + } + } } next.run(req).await diff --git a/src/auth/session.rs b/src/auth/session.rs index 801ce83..4f05e01 100644 --- a/src/auth/session.rs +++ b/src/auth/session.rs @@ -16,7 +16,7 @@ pub const USER_NOT_FOUND: AppError = AppError::ClientError { message: "The logged-in user was not found", }; -#[derive(Deserialize, Serialize, FromRow)] +#[derive(Deserialize, Serialize, Clone, FromRow)] pub struct Session { /// Role ID pub id: Uuid, diff --git a/src/auth/user.rs b/src/auth/user.rs index d3cddb3..ca6ded5 100644 --- a/src/auth/user.rs +++ b/src/auth/user.rs @@ -6,7 +6,7 @@ use uuid::Uuid; use super::hash::hash; -#[derive(Deserialize, Serialize, FromRow)] +#[derive(Deserialize, Serialize, Clone, FromRow)] pub struct User { /// User internal ID pub id: Uuid, diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 9667bc1..4dd006a 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -2,7 +2,7 @@ use axum::{ extract::State, http::{header::SET_COOKIE, StatusCode}, response::{IntoResponse, Response}, - Json, + Extension, Json, }; use chrono::Duration; use serde::Deserialize; @@ -10,7 +10,7 @@ use serde_json::json; use std::sync::Arc; use crate::{ - auth::{hash::verify, http::RequireSession, session::Session, user::User}, + auth::{hash::verify, http::RequireUser, session::Session, user::User}, error::AppError, state::AppState, }; @@ -68,6 +68,6 @@ pub async fn login( Ok(response) } -pub async fn me(RequireSession(_, user): RequireSession) -> Result> { +pub async fn me(RequireUser(user): RequireUser) -> Result> { Ok(user.name) }