diff --git a/src/auth.rs b/src/auth.rs index ca4deb2..b936291 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -23,7 +23,7 @@ pub struct User { pub email: Option, /// Hashed password - pub password: Option>, + pub password: Option, /// User's chosen displayed name pub display_name: Option, @@ -43,6 +43,55 @@ pub struct User { pub deleted_at: Option>, } +impl Default for User { + fn default() -> Self { + Self { + id: Uuid::nil(), + name: Default::default(), + email: Default::default(), + password: Default::default(), + display_name: Default::default(), + bio: Default::default(), + roles: Default::default(), + created_at: Default::default(), + modified_at: Default::default(), + deleted_at: Default::default(), + } + } +} + +impl User { + pub async fn create( + pool: &Pool, + username: &str, + password: &str, + roles: &Vec, + ) -> Result { + let result = sqlx::query!( + r#"INSERT INTO users ( name, password, roles ) VALUES ( $1,$2,$3 ) RETURNING id, created_at"#, + username, + hash(&password)?, + roles, + ) + .fetch_one(pool) + .await?; + Ok(Self { + id: result.id, + name: username.to_owned(), + roles: roles.to_owned(), + created_at: result.created_at.and_utc(), + ..Default::default() + }) + } + + pub async fn find(pool: &Pool, name: &str) -> Result> { + Ok(sqlx::query_as("SELECT * FROM users WHERE name = $1") + .bind(name) + .fetch_optional(pool) + .await?) + } +} + #[derive(Deserialize, Serialize, FromRow)] pub struct Role { /// Role ID @@ -137,7 +186,7 @@ pub fn random() -> String { SaltString::generate(&mut OsRng).to_string() } -pub fn hash(plaintext: &String) -> Result { +pub fn hash(plaintext: &str) -> Result { let salt = SaltString::generate(&mut OsRng); let hashed = Argon2::default() .hash_password(plaintext.as_bytes(), &salt) @@ -146,7 +195,7 @@ pub fn hash(plaintext: &String) -> Result { Ok(hashed) } -pub fn verify(plaintext: &String, hash: &String) -> Result { +pub fn verify(plaintext: &str, hash: &str) -> Result { let parsed_hash = PasswordHash::new(hash).map_err(|err| anyhow!(err))?; Ok(Argon2::default() .verify_password(plaintext.as_bytes(), &parsed_hash) diff --git a/src/routes/admin.rs b/src/routes/admin.rs index 73c1353..80e89ef 100644 --- a/src/routes/admin.rs +++ b/src/routes/admin.rs @@ -4,7 +4,7 @@ use axum::Json; use serde_json::{json, Value}; use std::sync::Arc; -use crate::auth::{hash, random}; +use crate::auth::{random, User}; use crate::error::AppError; use crate::roles::ROLE_SUPERADMIN; use crate::state::AppState; @@ -29,15 +29,13 @@ pub async fn bootstrap(State(state): State>) -> Result let username = "admin"; let password = random(); - sqlx::query!( - r#"INSERT INTO users ( name, display_name, password, roles ) VALUES ( $1, $2, $3, $4 ) RETURNING id"#, - username, - "Administrator", - hash(&password)?, - &[ROLE_SUPERADMIN], - ) - .fetch_one(&state.database) - .await?; + User::create( + &state.database, + &username, + &password, + &[ROLE_SUPERADMIN].to_vec(), + ) + .await?; Ok(Json(json!({"username": username, "password": password}))) } diff --git a/src/routes/auth.rs b/src/routes/auth.rs index 5cb5584..4e8171c 100644 --- a/src/routes/auth.rs +++ b/src/routes/auth.rs @@ -5,7 +5,7 @@ use serde_json::{json, Value}; use std::sync::Arc; use crate::{ - auth::{verify, Session}, + auth::{verify, Session, User}, error::AppError, state::AppState, }; @@ -20,9 +20,7 @@ pub async fn login( State(state): State>, Json(payload): Json, ) -> Result, AppError> { - let user = sqlx::query!("SELECT * FROM users WHERE name = $1", payload.username) - .fetch_optional(&state.database) - .await?; + let user = User::find(&state.database, payload.username.as_str()).await?; let invalid = || -> AppError { AppError::ClientError { diff --git a/src/routes/content.rs b/src/routes/content.rs index 909756a..5287735 100644 --- a/src/routes/content.rs +++ b/src/routes/content.rs @@ -1,7 +1,8 @@ -use crate::{content::Page, error::AppError, state::AppState}; use axum::extract::{Path, State}; use std::sync::Arc; +use crate::{content::Page, error::AppError, state::AppState}; + pub async fn page( State(state): State>, Path((site, slug)): Path<(String, String)>,