// Source listing metadata: 116 lines, 3.6 KiB, Rust.
use anyhow::Result;
|
|
use axum::{
|
|
async_trait,
|
|
extract::{FromRequestParts, State},
|
|
http::{
|
|
header::{COOKIE, SET_COOKIE},
|
|
request::Parts,
|
|
HeaderValue, Request, StatusCode,
|
|
},
|
|
middleware::Next,
|
|
response::Response,
|
|
Extension,
|
|
};
|
|
use chrono::{Duration, Utc};
|
|
use cookie::Cookie;
|
|
use std::sync::Arc;
|
|
use uuid::Uuid;
|
|
|
|
use crate::{http::error::ApiError, state::AppState};
|
|
|
|
use super::{session::Session, user::User};
|
|
|
|
pub const INVALID_SESSION: ApiError = ApiError::ClientError {
|
|
status: StatusCode::UNAUTHORIZED,
|
|
code: "authentication-required",
|
|
message: "Please log-in and submit a valid session as a cookie",
|
|
};
|
|
|
|
fn extract_session_token(header: &HeaderValue) -> Result<(Uuid, String)> {
|
|
Ok(Session::parse_token(
|
|
Cookie::parse(header.to_str()?)?.value(),
|
|
)?)
|
|
}
|
|
|
|
pub struct RequireUser(pub User);
|
|
|
|
#[async_trait]
|
|
impl<S> FromRequestParts<S> for RequireUser
|
|
where
|
|
S: Send + Sync,
|
|
{
|
|
type Rejection = ApiError<'static>;
|
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
|
match Extension::<User>::from_request_parts(parts, state).await {
|
|
Ok(Extension(user)) => Ok(RequireUser(user)),
|
|
_ => Err(INVALID_SESSION),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct RequireSession(pub Session);
|
|
|
|
#[async_trait]
|
|
impl<S> FromRequestParts<S> for RequireSession
|
|
where
|
|
S: Send + Sync,
|
|
{
|
|
type Rejection = ApiError<'static>;
|
|
|
|
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
|
|
match Extension::<Session>::from_request_parts(parts, state).await {
|
|
Ok(Extension(session)) => Ok(RequireSession(session)),
|
|
_ => Err(INVALID_SESSION),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn refresh_sessions<B>(
|
|
State(state): State<Arc<AppState>>,
|
|
mut req: Request<B>,
|
|
next: Next<B>,
|
|
) -> Response {
|
|
if let Some((session_id, session_secret)) = req
|
|
.headers()
|
|
.get(COOKIE)
|
|
.and_then(|header| extract_session_token(header).ok())
|
|
{
|
|
if let Some(Some((session, user))) = Session::find(&state.database, session_id).await.ok() {
|
|
// session validity requirements: secret must match, session must not have been expired
|
|
if session.secret == session_secret && session.expires_at >= Utc::now().naive_utc() {
|
|
// in the future we might wanna change the session secret, if we do, do it here!
|
|
if let Some((session, user)) = session
|
|
.refresh(
|
|
&state.database,
|
|
Duration::seconds(state.config.session_duration),
|
|
)
|
|
.await
|
|
.map(|s| (s, user))
|
|
.ok()
|
|
{
|
|
let extensions = req.extensions_mut();
|
|
extensions.insert(session.clone());
|
|
extensions.insert(user);
|
|
|
|
let mut response = next.run(req).await;
|
|
// Only set the session cookie if it hasn't been set yet (eg. logout)
|
|
let headers = response.headers_mut();
|
|
if !headers.contains_key(SET_COOKIE) {
|
|
headers.insert(
|
|
SET_COOKIE,
|
|
session
|
|
.cookie(state.config.domain().as_str(), state.config.secure())
|
|
.parse()
|
|
.unwrap(),
|
|
);
|
|
}
|
|
return response;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
next.run(req).await
|
|
}
|