avoid repeat session checking

This commit is contained in:
Hamcha 2023-07-02 11:24:14 +02:00
parent 87c9ea5aca
commit 1e10162fb3
Signed by: hamcha
GPG key ID: 1669C533B8CF6D89
4 changed files with 42 additions and 60 deletions

View file

@ -2,17 +2,13 @@ use anyhow::Result;
use axum::{ use axum::{
async_trait, async_trait,
extract::{FromRequestParts, State}, extract::{FromRequestParts, State},
http::{ http::{header::COOKIE, request::Parts, HeaderValue, Request, StatusCode},
header::{COOKIE, SET_COOKIE},
request::Parts,
HeaderValue, Request, StatusCode,
},
middleware::Next, middleware::Next,
response::Response, response::Response,
Extension,
}; };
use chrono::{Duration, Utc}; use chrono::{Duration, Utc};
use cookie::Cookie; use cookie::Cookie;
use sqlx::{Pool, Postgres};
use std::sync::Arc; use std::sync::Arc;
use uuid::Uuid; use uuid::Uuid;
@ -26,72 +22,58 @@ pub const INVALID_SESSION: AppError = AppError::ClientError {
message: "Please log-in and submit a valid session as a cookie", message: "Please log-in and submit a valid session as a cookie",
}; };
pub struct RequireSession(pub Session, pub User);
#[async_trait]
impl FromRequestParts<Arc<AppState>> for RequireSession {
type Rejection = AppError<'static>;
async fn from_request_parts(
parts: &mut Parts,
state: &Arc<AppState>,
) -> Result<RequireSession, Self::Rejection> {
if let Some(cookie) = parts.headers.get(COOKIE) {
let (session_id, session_secret) = extract_session_token(cookie)?;
match Session::find(&state.database, session_id).await? {
None => Err(INVALID_SESSION),
Some((session, user)) => {
if session.secret != session_secret {
return Err(INVALID_SESSION);
}
if session.expires_at < Utc::now().naive_utc() {
return Err(INVALID_SESSION);
}
Ok(RequireSession(session, user))
}
}
} else {
return Err(INVALID_SESSION);
}
}
}
fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> { fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> {
Ok(Session::parse_token( Ok(Session::parse_token(
Cookie::parse(header.to_str()?)?.value(), Cookie::parse(header.to_str()?)?.value(),
)?) )?)
} }
async fn find_and_refresh( pub struct RequireUser(pub User);
pool: &Pool<Postgres>,
session_id: Uuid, #[async_trait]
duration: Duration, impl<S> FromRequestParts<S> for RequireUser
) -> Option<Session> { where
if let Some(Some((session, _))) = Session::find(pool, session_id).await.ok() { S: Send + Sync,
session.refresh(pool, duration).await.ok() {
} else { type Rejection = AppError<'static>;
None
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
match Extension::<User>::from_request_parts(parts, state).await {
Ok(Extension(user)) => Ok(RequireUser(user)),
_ => Err(INVALID_SESSION),
}
} }
} }
pub async fn refresh_sessions<B>( pub async fn refresh_sessions<B>(
State(state): State<Arc<AppState>>, State(state): State<Arc<AppState>>,
req: Request<B>, mut req: Request<B>,
next: Next<B>, next: Next<B>,
) -> Response { ) -> Response {
if let Some((session_id, _)) = req if let Some((session_id, session_secret)) = req
.headers() .headers()
.get(COOKIE) .get(COOKIE)
.and_then(|header| extract_session_token(header).ok()) .and_then(|header| extract_session_token(header).ok())
{ {
// in the future we might wanna change the session secret, if we do, do it here! if let Some(Some((session, user))) = Session::find(&state.database, session_id).await.ok() {
find_and_refresh( // session validity requirements: secret must match, session must not have been expired
&state.database, if session.secret == session_secret && session.expires_at >= Utc::now().naive_utc() {
session_id, // in the future we might wanna change the session secret, if we do, do it here!
Duration::seconds(state.config.session_duration), if let Some((session, user)) = session
) .refresh(
.await; &state.database,
Duration::seconds(state.config.session_duration),
)
.await
.map(|s| (s, user))
.ok()
{
let extensions = req.extensions_mut();
extensions.insert(session);
extensions.insert(user);
}
}
}
} }
next.run(req).await next.run(req).await

View file

@ -16,7 +16,7 @@ pub const USER_NOT_FOUND: AppError = AppError::ClientError {
message: "The logged-in user was not found", message: "The logged-in user was not found",
}; };
#[derive(Deserialize, Serialize, FromRow)] #[derive(Deserialize, Serialize, Clone, FromRow)]
pub struct Session { pub struct Session {
/// Role ID /// Role ID
pub id: Uuid, pub id: Uuid,

View file

@ -6,7 +6,7 @@ use uuid::Uuid;
use super::hash::hash; use super::hash::hash;
#[derive(Deserialize, Serialize, FromRow)] #[derive(Deserialize, Serialize, Clone, FromRow)]
pub struct User { pub struct User {
/// User internal ID /// User internal ID
pub id: Uuid, pub id: Uuid,

View file

@ -2,7 +2,7 @@ use axum::{
extract::State, extract::State,
http::{header::SET_COOKIE, StatusCode}, http::{header::SET_COOKIE, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json, Extension, Json,
}; };
use chrono::Duration; use chrono::Duration;
use serde::Deserialize; use serde::Deserialize;
@ -10,7 +10,7 @@ use serde_json::json;
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
auth::{hash::verify, http::RequireSession, session::Session, user::User}, auth::{hash::verify, http::RequireUser, session::Session, user::User},
error::AppError, error::AppError,
state::AppState, state::AppState,
}; };
@ -68,6 +68,6 @@ pub async fn login(
Ok(response) Ok(response)
} }
pub async fn me(RequireSession(_, user): RequireSession) -> Result<String, AppError<'static>> { pub async fn me(RequireUser(user): RequireUser) -> Result<String, AppError<'static>> {
Ok(user.name) Ok(user.name)
} }