//! GraphQL API (via juniper) over loaded Slack logs, served by actix-web
//! together with a directory of static files.
//!
//! NOTE(review): the generic type arguments in this file appear to have been
//! lost in the copy this documentation was written against (e.g. `Vec,`,
//! `FieldResult>`); comments below avoid asserting the exact element types.
extern crate juniper;
use crate::database::{DBLog, DBMessage};
use chrono::prelude::*;
use juniper::http::GraphQLRequest;
use juniper::Value::Null;
use juniper::{FieldError, FieldResult};
use std::collections::HashSet;
use std::convert::TryInto;
use std::sync::Arc;
use actix_cors::Cors;
use actix_files as fs;
use actix_web::{http, middleware, web, App, Error, HttpRequest, HttpResponse, HttpServer};

// GraphQL output: one page of messages plus the cursor for the next page.
#[derive(Debug, juniper::GraphQLObject)]
#[graphql(description = "Paginated list of messages")]
struct MessageList {
    #[graphql(description = "List of messages")]
    messages: Vec,
    #[graphql(description = "Next message, if any (when using pagination)")]
    next: Option,
}

// GraphQL output for a single message; built from a `DBMessage` in `from_db`.
#[derive(Debug, Clone, juniper::GraphQLObject)]
#[graphql(description = "A single message in a Slack workspace")]
struct Message {
    #[graphql(description = "Message timestamp")]
    time: DateTime,
    #[graphql(description = "Message content")]
    content: String,
    #[graphql(description = "Slack username, if applicable")]
    username: String,
    #[graphql(description = "Slack real name, if applicable")]
    user_realname: String,
    // The description string below intentionally spans two source lines.
    #[graphql(
        description = "Channel/Private chat name. 
Channels are prefixed with #, Private chats with @"
    )]
    channel_name: String,
    // Derived by `message_id()`; see the collision note there.
    #[graphql(description = "Unique message ID (hopefully)")]
    message_id: juniper::ID,
}

// GraphQL output describing one loaded workspace and its channels.
#[derive(Debug, juniper::GraphQLObject)]
#[graphql(description = "A slack workspace info")]
struct Workspace {
    #[graphql(description = "Workspace name / ID")]
    name: String,
    #[graphql(description = "URL to workspace icon")]
    icon: String,
    #[graphql(description = "List of channels and private chats")]
    channels: Vec,
}

// GraphQL output for a channel or private chat name.
#[derive(Debug, juniper::GraphQLObject)]
#[graphql(description = "A slack channel or private chat")]
struct Channel {
    #[graphql(description = "Channel/Chat name")]
    name: String,
    #[graphql(description = "True if a private chat (or group chat), False if channel")]
    is_private: bool,
}

/// In-memory copy of one `DBLog`, with messages already converted to the
/// GraphQL `Message` type. Built by `from_db` and stored in `Context`.
struct WorkspaceData {
    name: String,
    icon: String,
    messages: Vec,
}

// GraphQL input for cursor-style pagination of `messages`.
#[derive(Debug, juniper::GraphQLInputObject)]
struct Pagination {
    #[graphql(description = "Skip messages before this one")]
    after: Option,
    #[graphql(description = "Show at most the first X messages")]
    first: Option,
}

// GraphQL input restricting which messages `messages` returns.
#[derive(Debug, juniper::GraphQLInputObject)]
struct MessageFilter {
    #[graphql(description = "Only show messages from this channel/chat")]
    channel: Option,
}

// GraphQL enum selecting the time ordering of `messages`.
#[derive(juniper::GraphQLEnum)]
enum SortOrder {
    #[graphql(description = "Sort from oldest")]
    DateAsc,
    #[graphql(description = "Sort from newest")]
    DateDesc,
}

/// Resolver context: one entry per loaded workspace.
///
/// Resolvers only receive `&Context` and `Mutation` is empty, so this data is
/// never modified after `server` builds it.
struct Context {
    databases: Vec,
}

impl juniper::Context for Context {}

/// Builds the GraphQL ID of a message as `username@channel/unix_seconds`.
///
/// The timestamp has one-second resolution. Two messages from the same user
/// in the same channel within the same second therefore get the same ID.
/// That also affects `after` pagination, which stops at the first matching ID.
fn message_id(msg: &DBMessage) -> juniper::ID {
    juniper::ID::new(format!(
        "{}@{}/{}",
        msg.username,
        msg.channel_name,
        msg.time.timestamp()
    ))
}

/// Converts a loaded `DBLog` into `WorkspaceData`.
///
/// Name and icon are moved. Each message's text fields are cloned from the
/// log and its ID is derived with `message_id`.
fn from_db(log: DBLog) -> WorkspaceData {
    WorkspaceData {
        name: log.name,
        icon: log.icon,
        messages: log
            .messages
            .iter()
            .map(|m| Message {
                message_id: message_id(&m),
                time: m.time,
                content: m.content.clone(),
                username: m.username.clone(),
                user_realname: m.user_realname.clone(),
                channel_name: m.channel_name.clone(),
            })
            .collect(),
    }
}

// GraphQL query root; resolvers are defined in the `#[juniper::object]` impl below.
struct Query;

#[juniper::object(
    Context = Context,
)]
impl Query {
    // Constant API version string.
    fn apiVersion() -> &str {
        "1.0"
    }

    // Distinct channel/chat names that appear in the messages of `workspace`.
    // Fails with "workspace not found" when no loaded workspace has that name.
    // Result order is unspecified because names are collected in a HashSet.
    fn channels(context: &Context, workspace: String) -> FieldResult> {
        // First workspace whose name matches, if any.
        let dbs = context
            .databases
            .iter()
            .filter(|db| db.name == workspace)
            .take(1)
            .next();
        match dbs {
            None => Err(FieldError::new("workspace not found", Null)),
            Some(db) => {
                let mut channels = HashSet::new();
                for msg in &db.messages {
                    channels.insert(msg.channel_name.clone());
                }
                Ok(channels
                    .iter()
                    .map(|name| Channel {
                        name: name.clone(),
                        // Any name without a leading '#' is reported as a private chat.
                        is_private: !name.starts_with('#'),
                    })
                    .collect())
            }
        }
    }

    // Every loaded workspace with its distinct channel list (unspecified order).
    // NOTE(review): the channel-collection logic duplicates `channels`; a shared
    // helper would keep the two resolvers in sync.
    fn workspace(context: &Context) -> FieldResult> {
        let mut results = vec![];
        for ws in context.databases.as_slice() {
            let mut channels = HashSet::new();
            for msg in &ws.messages {
                channels.insert(msg.channel_name.clone());
            }
            results.push(Workspace {
                name: ws.name.clone(),
                icon: ws.icon.clone(),
                channels: channels
                    .iter()
                    .map(|name| Channel {
                        name: name.clone(),
                        is_private: !name.starts_with('#'),
                    })
                    .collect(),
            })
        }
        Ok(results)
    }

    // Messages of `workspace`, processed in order: filter, then sort, then paginate.
    // Fails with "workspace not found" when no loaded workspace has that name.
    fn messages(
        context: &Context,
        workspace: String,
        filter: Option,
        order: Option,
        pagination: Option,
    ) -> FieldResult {
        let dbs = context
            .databases
            .iter()
            .filter(|db| db.name == workspace)
            .take(1)
            .next();
        match dbs {
            None => Err(FieldError::new("workspace not found", Null)),
            Some(db) => {
                // Work on a copy so the shared context is never reordered.
                let mut messages = db.messages.clone();
                // Apply filters: keep only the requested channel/chat, if one was given.
                if filter.is_some() {
                    let filters = filter.unwrap();
                    if filters.channel.is_some() {
                        let channel = filters.channel.unwrap();
                        messages = messages
                            .iter()
                            .filter(|x| x.channel_name == channel)
                            .cloned()
                            .collect();
                    }
                }
                // Apply order: oldest-first unless DateDesc is requested. `sort_by` is
                // stable, so messages with equal times keep their stored order.
                match order.unwrap_or(SortOrder::DateAsc) {
                    SortOrder::DateAsc => messages.sort_by(|a, b| a.time.cmp(&b.time)),
                    SortOrder::DateDesc => messages.sort_by(|a, b| b.time.cmp(&a.time)),
                }
                // Apply pagination. Without `pagination`, every matching message is
                // returned with no cap and `next` is None.
                let (messages, next) = match pagination {
                    None => (messages, None),
                    Some(pdata) => {
                        // `after` is inclusive: messages before the first one whose ID
                        // equals `after` are dropped, and that message is kept. If no
                        // message has that ID, nothing remains.
                        let skipped = match pdata.after {
                            None => messages,
                            Some(after) => messages
                                .iter()
                                .skip_while(|m| m.message_id != after)
                                .cloned()
                                .collect(),
                        };
                        // `first` defaults to 1000. A value that does not convert to
                        // `usize` (e.g. a negative number; `first` is presumably an
                        // integer type such as i32, verify) becomes a limit of 0.
                        let limit: usize = pdata.first.unwrap_or(1000).try_into().unwrap_or(0);
                        if limit >= skipped.len() {
                            (skipped, None)
                        } else {
                            // `next` is the ID of the first message not returned; the
                            // client passes it back as `after` to fetch the next page.
                            // `get(limit)` cannot fail here because limit < len.
                            (
                                skipped.iter().take(limit).cloned().collect(),
                                Some(skipped.get(limit).unwrap().message_id.clone()),
                            )
                        }
                    }
                };
                Ok(MessageList { messages, next })
            }
        }
    }
}

// GraphQL mutation root. It exposes nothing; the `Schema` type below takes a
// mutation root type parameter, so an empty one is supplied.
struct Mutation;

#[juniper::object(
    Context = Context,
)]
impl Mutation {}

type Schema = juniper::RootNode<'static, Query, Mutation>;

/// `POST /graphql` handler.
///
/// Executes the request against the shared schema and context inside
/// `web::block`, which keeps it off the async executor. Returns the
/// serialized result as `application/json`.
///
/// The response is always `200 OK`; GraphQL-level errors are not mapped to
/// HTTP status codes here. A serialization error, or a failure of the blocking
/// task, is propagated through `?`.
async fn graphql(
    st: web::Data>,
    data: web::Json,
) -> Result {
    // NOTE(review): despite its name, `user` holds the serialized JSON response body.
    let user = web::block(move || {
        let res = data.execute(&st.schema, &st.context);
        Ok::<_, serde_json::error::Error>(serde_json::to_string(&res)?)
    })
    .await?;
    Ok(HttpResponse::Ok()
        .content_type("application/json")
        .body(user))
}

/// Schema plus context, shared with every worker through an `Arc` (see `server`).
struct GQLData {
    schema: Schema,
    context: Context,
}

/// Builds the GraphQL schema and context from `databases`, then runs the HTTP server.
///
/// The server listens on `bind` and serves:
/// - `POST /graphql`, the GraphQL endpoint;
/// - the files under `static_dir`, mounted at `/`.
///
/// Requests are logged and CORS is configured for GET, POST and OPTIONS.
///
/// # Errors
/// Returns the I/O error if binding to `bind` fails, or if the running
/// server ends with one.
pub async fn server(bind: &str, static_dir: String, databases: Vec) -> std::io::Result<()> {
    // Create Juniper schema
    let schema = Schema::new(Query, Mutation);
    let context = Context {
        databases: databases.into_iter().map(from_db).collect(),
    };
    let data = std::sync::Arc::new(GQLData { schema, context });

    // Start http server. The factory closure runs once per worker; `data.clone()`
    // only clones the Arc, so all workers share one schema/context.
    HttpServer::new(move || {
        App::new()
            .data(data.clone())
            .wrap(middleware::Logger::default())
            .wrap(
                Cors::new()
                    .allowed_methods(vec!["GET", "POST", "OPTIONS"])
                    .allowed_headers(vec![
                        http::header::AUTHORIZATION,
                        http::header::ACCEPT,
                        http::header::CONTENT_TYPE,
                    ])
                    .max_age(3600)
                    .finish(),
            )
            .service(web::resource("/graphql").route(web::post().to(graphql)))
            // NOTE(review): a Files service mounted at "/" appears to match every
            // path, and services are tried in registration order. The "/" redirect
            // registered below is therefore likely unreachable; confirm. Registering
            // the redirect before this line (or using `Files::index_file`) would fix it.
            .service(fs::Files::new("/", &static_dir))
            .service(web::resource("/").route(web::get().to(|req: HttpRequest| {
                // NOTE(review): prints the whole request to stdout on every hit;
                // looks like leftover debugging.
                println!("{:?}", req);
                HttpResponse::Found()
                    .header(http::header::LOCATION, "/index.html")
                    .finish()
            })))
    })
    .bind(bind)?
    .run()
    .await
}