diff --git a/Cargo.lock b/Cargo.lock index 05edb3b..fb7298b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -102,7 +102,7 @@ checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39" dependencies = [ "async-trait", "axum-core", - "bitflags", + "bitflags 1.3.2", "bytes", "futures-util", "http", @@ -179,6 +179,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" + [[package]] name = "blake2" version = "0.10.6" @@ -593,6 +599,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfe8eed0a9285ef776bb792479ea3834e8b94e13d615c2f66d03dd50a435a29" + [[package]] name = "httparse" version = "1.8.0" @@ -772,6 +784,7 @@ dependencies = [ "sqlx", "thiserror", "tokio", + "tower-http", "tracing", "tracing-subscriber", "url", @@ -1072,7 +1085,7 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -1081,7 +1094,7 @@ version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -1116,7 +1129,7 @@ version = "0.37.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8818fa822adcc98b18fedbb3632a6a33213c070556b5aa7c4c8cc21cff565c4c" dependencies = [ - "bitflags", + "bitflags 1.3.2", "errno", "io-lifetimes", "libc", @@ -1326,7 +1339,7 @@ dependencies = [ "ahash 0.7.6", "atoi", "base64 0.13.1", - "bitflags", + "bitflags 1.3.2", "byteorder", "bytes", "chrono", @@ -1648,6 +1661,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c" +dependencies = [ + "bitflags 2.3.3", + "bytes", + "futures-core", + "futures-util", + "http", + "http-body", + "http-range-header", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index 797560c..0a2f216 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ argon2 = { version = "0.5", features = ["std", "alloc"] } url = "2.4" thiserror = "1.0" async-trait = "0.1" +tower-http = { version = "0.4", features = ["cors"] } [profile.dev.package.sqlx-macros] opt-level = 3 \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index a4455d7..eda763a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,13 @@ mod state; use crate::{http::session::refresh_sessions, state::Config}; use anyhow::Result; -use axum::{middleware, Server}; +use axum::{ + http::{ + header::{ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_TYPE}, + Method, + }, + middleware, Server, +}; use figment::{ providers::{Env, Format, Serialized, Toml}, Figment, @@ -16,6 +22,7 @@ use figment::{ use sqlx::postgres::PgPoolOptions; use state::AppState; use std::{net::SocketAddr, sync::Arc}; +use tower_http::cors::CorsLayer; #[tokio::main] async fn main() -> Result<()> { @@ -27,6 +34,25 @@ async fn main() -> Result<()> { .extract()?; let addr: SocketAddr = config.bind.parse()?; + let origins = config + .cors_domains + .split(',') + .map(|x| x.to_string().parse().unwrap()) + .collect::>(); + println!("{:?}", origins); + let cors = CorsLayer::permissive() + .allow_methods([ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]) + .allow_origin(origins) + .allow_credentials(true) + .allow_headers([AUTHORIZATION, ACCEPT, CONTENT_TYPE]) + .expose_headers([CONTENT_ENCODING]); + let database = PgPoolOptions::new() .max_connections(5) .connect(config.database_url.as_str()) @@ -44,6 +70,7 @@ async fn main() -> Result<()> { shared_state.clone(), refresh_sessions, )) + .layer(cors) .with_state(shared_state); tracing::debug!("listening on {}", addr); diff --git a/src/state.rs b/src/state.rs index a1b690b..721c125 100644 --- a/src/state.rs +++ b/src/state.rs @@ -9,6 +9,7 @@ pub struct Config { pub session_duration: i64, // in seconds pub prune_interval: u64, // in seconds pub base_url: String, + pub cors_domains: String, // CORS-allowed domains, separated by comma } impl Config { @@ -46,6 +47,7 @@ impl Default for Config { session_duration: 3600, // 60min prune_interval: 3600, // 60min base_url: "http://localhost".into(), + cors_domains: "http://localhost:3000".into(), } } }