extern crate juniper; use crate::database::{DBLog, DBMessage}; use chrono::prelude::*; use juniper::Value::Null; use juniper::{FieldError, FieldResult}; use std::collections::HashSet; use std::convert::TryInto; use warp::Filter; #[derive(Debug, juniper::GraphQLObject)] #[graphql(description = "Paginated list of messages")] struct MessageList { #[graphql(description = "List of messages")] messages: Vec, #[graphql(description = "Next message, if any (when using pagination)")] next: Option, } #[derive(Debug, Clone, juniper::GraphQLObject)] #[graphql(description = "A single message in a Slack workspace")] struct Message { #[graphql(description = "Message timestamp")] time: DateTime, #[graphql(description = "Message content")] content: String, #[graphql(description = "Slack username, if applicable")] username: String, #[graphql(description = "Slack real name, if applicable")] user_realname: String, #[graphql( description = "Channel/Private chat name. Channels are prefixed with #, Private chats with @" )] channel_name: String, #[graphql(description = "Unique message ID (hopefully)")] message_id: juniper::ID, } #[derive(Debug, juniper::GraphQLObject)] #[graphql(description = "A slack workspace info")] struct Workspace { #[graphql(description = "Workspace name / ID")] name: String, #[graphql(description = "URL to workspace icon")] icon: String, } #[derive(Debug, juniper::GraphQLObject)] #[graphql(description = "A slack channel or private chat")] struct Channel { #[graphql(description = "Channel/Chat name")] name: String, #[graphql(description = "True if a private chat (or group chat), False if channel")] is_private: bool, } struct WorkspaceData { name: String, icon: String, messages: Vec, } #[derive(Debug, juniper::GraphQLInputObject)] struct Pagination { #[graphql(description = "Skip messages before this one")] after: Option, #[graphql(description = "Show at most the first X messages")] first: Option, } #[derive(Debug, juniper::GraphQLInputObject)] struct MessageFilter { #[graphql(description = "Only show messages from this channel/chat")] channel: Option, } #[derive(juniper::GraphQLEnum)] enum SortOrder { #[graphql(description = "Sort from oldest")] DateAsc, #[graphql(description = "Sort from newest")] DateDesc, } struct Context { databases: Vec, } impl juniper::Context for Context {} /// Get message id for slack message fn message_id(msg: &DBMessage) -> juniper::ID { juniper::ID::new(format!("{}/{}", msg.channel_name, msg.time.timestamp())) } /// Convert from DB struct to GQL fn from_db(log: DBLog) -> WorkspaceData { WorkspaceData { name: log.name, icon: log.icon, messages: log .messages .iter() .map(|m| Message { message_id: message_id(&m), time: m.time, content: m.content.clone(), username: m.username.clone(), user_realname: m.user_realname.clone(), channel_name: m.channel_name.clone(), }) .collect(), } } struct Query; #[juniper::object( Context = Context, )] impl Query { fn apiVersion() -> &str { "1.0" } fn workspace(context: &Context) -> FieldResult> { let mut results = vec![]; for ws in context.databases.as_slice() { results.push(Workspace { name: ws.name.clone(), icon: ws.icon.clone(), }) } Ok(results) } fn channels(context: &Context, workspace: String) -> FieldResult> { let dbs = context .databases .iter() .filter(|db| db.name == workspace) .take(1) .next(); match dbs { None => Err(FieldError::new("workspace not found", Null)), Some(db) => { let mut channels = HashSet::new(); for msg in &db.messages { channels.insert(msg.channel_name.clone()); } Ok(channels .iter() .map(|name| Channel { name: name.clone(), is_private: !name.starts_with("#"), }) .collect()) } } } fn messages( context: &Context, workspace: String, filter: Option, order: Option, pagination: Option, ) -> FieldResult { let dbs = context .databases .iter() .filter(|db| db.name == workspace) .take(1) .next(); match dbs { None => Err(FieldError::new("workspace not found", Null)), Some(db) => { let mut messages = db.messages.clone(); // Apply filters if filter.is_some() { let filters = filter.unwrap(); if filters.channel.is_some() { let channel = filters.channel.unwrap(); messages = messages .iter() .filter(|x| x.channel_name == channel) .cloned() .collect(); } } // Apply order match order.unwrap_or(SortOrder::DateAsc) { SortOrder::DateAsc => messages.sort_by(|a, b| a.time.cmp(&b.time)), SortOrder::DateDesc => messages.sort_by(|a, b| b.time.cmp(&a.time)), } // Apply pagination let (messages, next) = match pagination { None => (messages, None), Some(pdata) => { // Apply after, if specified let skipped = match pdata.after { None => messages, Some(after) => messages .iter() .skip_while(|m| m.message_id != after) .cloned() .collect(), }; // Apply limit, if specified let limit: usize = pdata.first.unwrap_or(1000).try_into().unwrap_or(0); if limit >= skipped.len() { (skipped, None) } else { ( skipped.iter().take(limit).cloned().collect(), Some(skipped.get(limit).unwrap().message_id.clone()), ) } } }; Ok(MessageList { messages, next }) } } } } struct Mutation; #[juniper::object( Context = Context, )] impl Mutation {} type Schema = juniper::RootNode<'static, Query, Mutation>; pub fn server(bind: &str, port: u16, databases: Vec) { let schema = Schema::new(Query, Mutation); let state = warp::any().map(move || Context { databases: databases.clone().into_iter().map(from_db).collect(), }); let graphql_filter = juniper_warp::make_graphql_filter(schema, state.boxed()); println!("Starting server at {}:{}\n\nEndpoints:\n graphql: http://{}:{}/graphql\n graphiql: http://{}:{}/graphiql", bind, port, bind, port, bind, port); warp::serve( warp::get2() .and(warp::path("graphiql")) .and(juniper_warp::graphiql_filter("/graphql")) .or(warp::path("graphql").and(graphql_filter)), ) .run(std::net::SocketAddr::new(bind.parse().unwrap(), port)); }