avoid repeat session checking

This commit is contained in:
Hamcha 2023-07-02 11:24:14 +02:00
parent 87c9ea5aca
commit 1e10162fb3
Signed by: hamcha
GPG Key ID: 1669C533B8CF6D89
4 changed files with 42 additions and 60 deletions

View File

@ -2,17 +2,13 @@ use anyhow::Result;
use axum::{
async_trait,
extract::{FromRequestParts, State},
http::{
header::{COOKIE, SET_COOKIE},
request::Parts,
HeaderValue, Request, StatusCode,
},
http::{header::COOKIE, request::Parts, HeaderValue, Request, StatusCode},
middleware::Next,
response::Response,
Extension,
};
use chrono::{Duration, Utc};
use cookie::Cookie;
use sqlx::{Pool, Postgres};
use std::sync::Arc;
use uuid::Uuid;
@ -26,72 +22,58 @@ pub const INVALID_SESSION: AppError = AppError::ClientError {
message: "Please log-in and submit a valid session as a cookie",
};
pub struct RequireSession(pub Session, pub User);
#[async_trait]
impl FromRequestParts<Arc<AppState>> for RequireSession {
type Rejection = AppError<'static>;
async fn from_request_parts(
parts: &mut Parts,
state: &Arc<AppState>,
) -> Result<RequireSession, Self::Rejection> {
if let Some(cookie) = parts.headers.get(COOKIE) {
let (session_id, session_secret) = extract_session_token(cookie)?;
match Session::find(&state.database, session_id).await? {
None => Err(INVALID_SESSION),
Some((session, user)) => {
if session.secret != session_secret {
return Err(INVALID_SESSION);
}
if session.expires_at < Utc::now().naive_utc() {
return Err(INVALID_SESSION);
}
Ok(RequireSession(session, user))
}
}
} else {
return Err(INVALID_SESSION);
}
}
}
fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> {
Ok(Session::parse_token(
Cookie::parse(header.to_str()?)?.value(),
)?)
}
async fn find_and_refresh(
pool: &Pool<Postgres>,
session_id: Uuid,
duration: Duration,
) -> Option<Session> {
if let Some(Some((session, _))) = Session::find(pool, session_id).await.ok() {
session.refresh(pool, duration).await.ok()
} else {
None
pub struct RequireUser(pub User);
#[async_trait]
impl<S> FromRequestParts<S> for RequireUser
where
S: Send + Sync,
{
type Rejection = AppError<'static>;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
match Extension::<User>::from_request_parts(parts, state).await {
Ok(Extension(user)) => Ok(RequireUser(user)),
_ => Err(INVALID_SESSION),
}
}
}
pub async fn refresh_sessions<B>(
State(state): State<Arc<AppState>>,
req: Request<B>,
mut req: Request<B>,
next: Next<B>,
) -> Response {
if let Some((session_id, _)) = req
if let Some((session_id, session_secret)) = req
.headers()
.get(COOKIE)
.and_then(|header| extract_session_token(header).ok())
{
// in the future we might wanna change the session secret, if we do, do it here!
find_and_refresh(
&state.database,
session_id,
Duration::seconds(state.config.session_duration),
)
.await;
if let Some(Some((session, user))) = Session::find(&state.database, session_id).await.ok() {
// session validity requirements: secret must match, session must not have been expired
if session.secret == session_secret && session.expires_at >= Utc::now().naive_utc() {
// in the future we might wanna change the session secret, if we do, do it here!
if let Some((session, user)) = session
.refresh(
&state.database,
Duration::seconds(state.config.session_duration),
)
.await
.map(|s| (s, user))
.ok()
{
let extensions = req.extensions_mut();
extensions.insert(session);
extensions.insert(user);
}
}
}
}
next.run(req).await

View File

@ -16,7 +16,7 @@ pub const USER_NOT_FOUND: AppError = AppError::ClientError {
message: "The logged-in user was not found",
};
#[derive(Deserialize, Serialize, FromRow)]
#[derive(Deserialize, Serialize, Clone, FromRow)]
pub struct Session {
/// Role ID
pub id: Uuid,

View File

@ -6,7 +6,7 @@ use uuid::Uuid;
use super::hash::hash;
#[derive(Deserialize, Serialize, FromRow)]
#[derive(Deserialize, Serialize, Clone, FromRow)]
pub struct User {
/// User internal ID
pub id: Uuid,

View File

@ -2,7 +2,7 @@ use axum::{
extract::State,
http::{header::SET_COOKIE, StatusCode},
response::{IntoResponse, Response},
Json,
Extension, Json,
};
use chrono::Duration;
use serde::Deserialize;
@ -10,7 +10,7 @@ use serde_json::json;
use std::sync::Arc;
use crate::{
auth::{hash::verify, http::RequireSession, session::Session, user::User},
auth::{hash::verify, http::RequireUser, session::Session, user::User},
error::AppError,
state::AppState,
};
@ -68,6 +68,6 @@ pub async fn login(
Ok(response)
}
pub async fn me(RequireSession(_, user): RequireSession) -> Result<String, AppError<'static>> {
pub async fn me(RequireUser(user): RequireUser) -> Result<String, AppError<'static>> {
Ok(user.name)
}