diff --git a/src/auth/session.rs b/src/auth/session.rs index 57e1502..1792ded 100644 --- a/src/auth/session.rs +++ b/src/auth/session.rs @@ -67,11 +67,58 @@ impl Session { }) } - pub async fn find(pool: &Pool, id: Uuid) -> Result> { - Ok(sqlx::query_as("SELECT * FROM sessions WHERE id = $1") - .bind(id) - .fetch_optional(pool) - .await?) + pub async fn find(pool: &Pool, session_id: Uuid) -> Result> { + let record = sqlx::query!( + "SELECT + sessions.id AS session_id, + sessions.actor AS session_actor, + sessions.secret, + sessions.created_at AS session_created_at, + sessions.expires_at, + users.id AS user_id, + users.name, + users.email, + users.display_name, + users.bio, + users.roles, + users.created_at AS user_created_at, + users.modified_at, + users.deleted_at + FROM + sessions + JOIN + users ON sessions.actor = users.id + WHERE + sessions.id = $1", + session_id + ) + .fetch_optional(pool) + .await?; + + match record { + None => Ok(None), + Some(record) => Ok(Some(( + Self { + id: record.session_id, + actor: record.session_actor, + secret: record.secret, + created_at: record.session_created_at, + expires_at: record.expires_at, + }, + User { + id: record.user_id, + name: record.name, + email: record.email, + password: None, + display_name: record.display_name, + bio: record.bio, + roles: record.roles, + created_at: record.user_created_at, + modified_at: record.modified_at, + deleted_at: record.deleted_at, + }, + ))), + } } pub async fn refresh(self: Self, pool: &Pool, duration: Duration) -> Result { @@ -116,14 +163,16 @@ impl Session { } } +pub struct RequireSession(pub Session, pub User); + #[async_trait] -impl FromRequestParts> for Session { +impl FromRequestParts> for RequireSession { type Rejection = AppError<'static>; async fn from_request_parts( parts: &mut Parts, state: &Arc, - ) -> Result { + ) -> Result { if let Some(cookie) = parts.headers.get(COOKIE) { let cookie_str = cookie.to_str()?; let cookie = Cookie::parse(cookie_str)?; @@ -133,7 +182,7 @@ impl FromRequestParts> for Session { match Session::find(&state.database, session_id).await? { None => Err(INVALID_SESSION), - Some(session) => { + Some((session, user)) => { println!("{:?}<{:?}", session.expires_at, Utc::now().naive_utc()); if session.secret != session_secret { return Err(INVALID_SESSION); @@ -141,7 +190,7 @@ impl FromRequestParts> for Session { if session.expires_at < Utc::now().naive_utc() { return Err(INVALID_SESSION); } - Ok(session) + Ok(RequireSession(session, user)) } } } else { diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 9c94a31..9284fde 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -11,7 +11,11 @@ use serde_json::json; use std::sync::Arc; use crate::{ - auth::{hash::verify, session::Session, user::User}, + auth::{ + hash::verify, + session::{RequireSession, Session}, + user::User, + }, error::AppError, state::AppState, }; @@ -72,10 +76,6 @@ pub async fn login( Ok(response) } -pub async fn me( - State(state): State>, - session: Session, -) -> Result> { - let user = session.user(&state.database).await?; +pub async fn me(RequireSession(_, user): RequireSession) -> Result> { Ok(user.name) }