avoid repeat session checking
This commit is contained in:
parent
87c9ea5aca
commit
1e10162fb3
4 changed files with 42 additions and 60 deletions
|
@ -2,17 +2,13 @@ use anyhow::Result;
|
||||||
use axum::{
|
use axum::{
|
||||||
async_trait,
|
async_trait,
|
||||||
extract::{FromRequestParts, State},
|
extract::{FromRequestParts, State},
|
||||||
http::{
|
http::{header::COOKIE, request::Parts, HeaderValue, Request, StatusCode},
|
||||||
header::{COOKIE, SET_COOKIE},
|
|
||||||
request::Parts,
|
|
||||||
HeaderValue, Request, StatusCode,
|
|
||||||
},
|
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::Response,
|
response::Response,
|
||||||
|
Extension,
|
||||||
};
|
};
|
||||||
use chrono::{Duration, Utc};
|
use chrono::{Duration, Utc};
|
||||||
use cookie::Cookie;
|
use cookie::Cookie;
|
||||||
use sqlx::{Pool, Postgres};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
@ -26,72 +22,58 @@ pub const INVALID_SESSION: AppError = AppError::ClientError {
|
||||||
message: "Please log-in and submit a valid session as a cookie",
|
message: "Please log-in and submit a valid session as a cookie",
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct RequireSession(pub Session, pub User);
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl FromRequestParts<Arc<AppState>> for RequireSession {
|
|
||||||
type Rejection = AppError<'static>;
|
|
||||||
|
|
||||||
async fn from_request_parts(
|
|
||||||
parts: &mut Parts,
|
|
||||||
state: &Arc<AppState>,
|
|
||||||
) -> Result<RequireSession, Self::Rejection> {
|
|
||||||
if let Some(cookie) = parts.headers.get(COOKIE) {
|
|
||||||
let (session_id, session_secret) = extract_session_token(cookie)?;
|
|
||||||
|
|
||||||
match Session::find(&state.database, session_id).await? {
|
|
||||||
None => Err(INVALID_SESSION),
|
|
||||||
Some((session, user)) => {
|
|
||||||
if session.secret != session_secret {
|
|
||||||
return Err(INVALID_SESSION);
|
|
||||||
}
|
|
||||||
if session.expires_at < Utc::now().naive_utc() {
|
|
||||||
return Err(INVALID_SESSION);
|
|
||||||
}
|
|
||||||
Ok(RequireSession(session, user))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return Err(INVALID_SESSION);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> {
|
fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> {
|
||||||
Ok(Session::parse_token(
|
Ok(Session::parse_token(
|
||||||
Cookie::parse(header.to_str()?)?.value(),
|
Cookie::parse(header.to_str()?)?.value(),
|
||||||
)?)
|
)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn find_and_refresh(
|
pub struct RequireUser(pub User);
|
||||||
pool: &Pool<Postgres>,
|
|
||||||
session_id: Uuid,
|
#[async_trait]
|
||||||
duration: Duration,
|
impl<S> FromRequestParts<S> for RequireUser
|
||||||
) -> Option<Session> {
|
where
|
||||||
if let Some(Some((session, _))) = Session::find(pool, session_id).await.ok() {
|
S: Send + Sync,
|
||||||
session.refresh(pool, duration).await.ok()
|
{
|
||||||
} else {
|
type Rejection = AppError<'static>;
|
||||||
None
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
match Extension::<User>::from_request_parts(parts, state).await {
|
||||||
|
Ok(Extension(user)) => Ok(RequireUser(user)),
|
||||||
|
_ => Err(INVALID_SESSION),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn refresh_sessions<B>(
|
pub async fn refresh_sessions<B>(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
req: Request<B>,
|
mut req: Request<B>,
|
||||||
next: Next<B>,
|
next: Next<B>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
if let Some((session_id, _)) = req
|
if let Some((session_id, session_secret)) = req
|
||||||
.headers()
|
.headers()
|
||||||
.get(COOKIE)
|
.get(COOKIE)
|
||||||
.and_then(|header| extract_session_token(header).ok())
|
.and_then(|header| extract_session_token(header).ok())
|
||||||
{
|
{
|
||||||
// in the future we might wanna change the session secret, if we do, do it here!
|
if let Some(Some((session, user))) = Session::find(&state.database, session_id).await.ok() {
|
||||||
find_and_refresh(
|
// session validity requirements: secret must match, session must not have been expired
|
||||||
&state.database,
|
if session.secret == session_secret && session.expires_at >= Utc::now().naive_utc() {
|
||||||
session_id,
|
// in the future we might wanna change the session secret, if we do, do it here!
|
||||||
Duration::seconds(state.config.session_duration),
|
if let Some((session, user)) = session
|
||||||
)
|
.refresh(
|
||||||
.await;
|
&state.database,
|
||||||
|
Duration::seconds(state.config.session_duration),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map(|s| (s, user))
|
||||||
|
.ok()
|
||||||
|
{
|
||||||
|
let extensions = req.extensions_mut();
|
||||||
|
extensions.insert(session);
|
||||||
|
extensions.insert(user);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
next.run(req).await
|
next.run(req).await
|
||||||
|
|
|
@ -16,7 +16,7 @@ pub const USER_NOT_FOUND: AppError = AppError::ClientError {
|
||||||
message: "The logged-in user was not found",
|
message: "The logged-in user was not found",
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, FromRow)]
|
#[derive(Deserialize, Serialize, Clone, FromRow)]
|
||||||
pub struct Session {
|
pub struct Session {
|
||||||
/// Role ID
|
/// Role ID
|
||||||
pub id: Uuid,
|
pub id: Uuid,
|
||||||
|
|
|
@ -6,7 +6,7 @@ use uuid::Uuid;
|
||||||
|
|
||||||
use super::hash::hash;
|
use super::hash::hash;
|
||||||
|
|
||||||
#[derive(Deserialize, Serialize, FromRow)]
|
#[derive(Deserialize, Serialize, Clone, FromRow)]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
/// User internal ID
|
/// User internal ID
|
||||||
pub id: Uuid,
|
pub id: Uuid,
|
||||||
|
|
|
@ -2,7 +2,7 @@ use axum::{
|
||||||
extract::State,
|
extract::State,
|
||||||
http::{header::SET_COOKIE, StatusCode},
|
http::{header::SET_COOKIE, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
Extension, Json,
|
||||||
};
|
};
|
||||||
use chrono::Duration;
|
use chrono::Duration;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
@ -10,7 +10,7 @@ use serde_json::json;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
auth::{hash::verify, http::RequireSession, session::Session, user::User},
|
auth::{hash::verify, http::RequireUser, session::Session, user::User},
|
||||||
error::AppError,
|
error::AppError,
|
||||||
state::AppState,
|
state::AppState,
|
||||||
};
|
};
|
||||||
|
@ -68,6 +68,6 @@ pub async fn login(
|
||||||
Ok(response)
|
Ok(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn me(RequireSession(_, user): RequireSession) -> Result<String, AppError<'static>> {
|
pub async fn me(RequireUser(user): RequireUser) -> Result<String, AppError<'static>> {
|
||||||
Ok(user.name)
|
Ok(user.name)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue