// riplog-view/backend/graphql.rs
//
// 268 lines
// 7.9 KiB
// Rust

extern crate juniper;
use crate::database::{DBLog, DBMessage};
use chrono::prelude::*;
use juniper::Value::Null;
use juniper::{FieldError, FieldResult};
use std::collections::HashSet;
use std::convert::TryInto;
use warp::Filter;
// One page of results from the `messages` query.
// `next` is `None` when no pagination was requested or when every remaining
// message fit into this page.
#[derive(Debug, juniper::GraphQLObject)]
#[graphql(description = "Paginated list of messages")]
struct MessageList {
#[graphql(description = "List of messages")]
messages: Vec<Message>,
// ID of the first message after this page; passing it back as
// `Pagination.after` fetches the following page (that message included).
#[graphql(description = "Next message, if any (when using pagination)")]
next: Option<juniper::ID>,
}
// A message as exposed over GraphQL, built from a `DBMessage` by `from_db`.
// `Clone` is needed because the `messages` resolver returns owned copies of
// the messages stored in the context.
#[derive(Debug, Clone, juniper::GraphQLObject)]
#[graphql(description = "A single message in a Slack workspace")]
struct Message {
#[graphql(description = "Message timestamp")]
time: DateTime<Utc>,
#[graphql(description = "Message content")]
content: String,
#[graphql(description = "Slack username, if applicable")]
username: String,
#[graphql(description = "Slack real name, if applicable")]
user_realname: String,
#[graphql(
description = "Channel/Private chat name. Channels are prefixed with #, Private chats with @"
)]
channel_name: String,
// Produced by `message_id` from the channel name and the timestamp in whole
// seconds, so two messages in one channel within the same second collide.
#[graphql(description = "Unique message ID (hopefully)")]
message_id: juniper::ID,
}
// Name and icon of one loaded workspace, as listed by the `workspace` query.
#[derive(Debug, juniper::GraphQLObject)]
#[graphql(description = "A slack workspace info")]
struct Workspace {
#[graphql(description = "Workspace name / ID")]
name: String,
#[graphql(description = "URL to workspace icon")]
icon: String,
}
// A distinct channel/chat name found in a workspace's messages, as returned
// by the `channels` query. `is_private` is derived purely from the name:
// anything that does not start with `#` is treated as private.
#[derive(Debug, juniper::GraphQLObject)]
#[graphql(description = "A slack channel or private chat")]
struct Channel {
#[graphql(description = "Channel/Chat name")]
name: String,
#[graphql(description = "True if a private chat (or group chat), False if channel")]
is_private: bool,
}
/// In-memory form of one workspace log, converted from a `DBLog` by `from_db`.
///
/// Not a GraphQL type itself: resolvers reach it through `Context::databases`.
struct WorkspaceData {
name: String,
icon: String,
messages: Vec<Message>,
}
// Optional pagination arguments for the `messages` query.
// - `after`: messages preceding the one with this ID are dropped; that message
//   itself is kept as the first result. An ID that is not found yields an
//   empty page.
// - `first`: page size. Defaults to 1000 when pagination is supplied without
//   it; a negative value yields an empty page. Without any pagination
//   argument the query is not limited at all.
#[derive(Debug, juniper::GraphQLInputObject)]
struct Pagination {
#[graphql(description = "Skip messages before this one")]
after: Option<juniper::ID>,
#[graphql(description = "Show at most the first X messages")]
first: Option<i32>,
}
// Optional filters for the `messages` query; `channel` must match the
// stored channel name exactly (including its `#`/`@` prefix).
#[derive(Debug, juniper::GraphQLInputObject)]
struct MessageFilter {
#[graphql(description = "Only show messages from this channel/chat")]
channel: Option<String>,
}
// Ordering for the `messages` query; `DateAsc` is used when none is given.
#[derive(juniper::GraphQLEnum)]
enum SortOrder {
#[graphql(description = "Sort from oldest")]
DateAsc,
#[graphql(description = "Sort from newest")]
DateDesc,
}
/// GraphQL context handed to every resolver; `server` builds one per request.
struct Context {
// Every loaded workspace; resolvers look workspaces up by `name`.
databases: Vec<WorkspaceData>,
}
// Marker impl so juniper accepts `Context` as the resolvers' context type.
impl juniper::Context for Context {}
/// Build the GraphQL ID of a stored Slack message.
///
/// The ID has the form `"<channel_name>/<unix seconds>"`. Because only whole
/// seconds are used, two messages posted in the same channel within the same
/// second share an ID. Pagination cursors match the first message carrying
/// that ID.
fn message_id(msg: &DBMessage) -> juniper::ID {
    let seconds = msg.time.timestamp();
    let raw = format!("{}/{}", msg.channel_name, seconds);
    juniper::ID::new(raw)
}
/// Convert a database log into the in-memory form served over GraphQL.
///
/// Consumes `log`, moving its strings into the result instead of cloning them.
/// Each message gets its ID from [`message_id`].
fn from_db(log: DBLog) -> WorkspaceData {
    WorkspaceData {
        name: log.name,
        icon: log.icon,
        messages: log
            .messages
            .into_iter()
            .map(|m| Message {
                // Struct fields are evaluated in source order: the ID borrows `m`
                // before the string fields below are moved out of it.
                message_id: message_id(&m),
                time: m.time,
                content: m.content,
                username: m.username,
                user_realname: m.user_realname,
                channel_name: m.channel_name,
            })
            .collect(),
    }
}
struct Query;
#[juniper::object(
Context = Context,
)]
impl Query {
fn apiVersion() -> &str {
"1.0"
}
fn workspace(context: &Context) -> FieldResult<Vec<Workspace>> {
let mut results = vec![];
for ws in context.databases.as_slice() {
results.push(Workspace {
name: ws.name.clone(),
icon: ws.icon.clone(),
})
}
Ok(results)
}
fn channels(context: &Context, workspace: String) -> FieldResult<Vec<Channel>> {
let dbs = context
.databases
.iter()
.filter(|db| db.name == workspace)
.take(1)
.next();
match dbs {
None => Err(FieldError::new("workspace not found", Null)),
Some(db) => {
let mut channels = HashSet::new();
for msg in &db.messages {
channels.insert(msg.channel_name.clone());
}
Ok(channels
.iter()
.map(|name| Channel {
name: name.clone(),
is_private: !name.starts_with("#"),
})
.collect())
}
}
}
fn messages(
context: &Context,
workspace: String,
filter: Option<MessageFilter>,
order: Option<SortOrder>,
pagination: Option<Pagination>,
) -> FieldResult<MessageList> {
let dbs = context
.databases
.iter()
.filter(|db| db.name == workspace)
.take(1)
.next();
match dbs {
None => Err(FieldError::new("workspace not found", Null)),
Some(db) => {
let mut messages = db.messages.clone();
// Apply filters
if filter.is_some() {
let filters = filter.unwrap();
if filters.channel.is_some() {
let channel = filters.channel.unwrap();
messages = messages
.iter()
.filter(|x| x.channel_name == channel)
.cloned()
.collect();
}
}
// Apply order
match order.unwrap_or(SortOrder::DateAsc) {
SortOrder::DateAsc => messages.sort_by(|a, b| a.time.cmp(&b.time)),
SortOrder::DateDesc => messages.sort_by(|a, b| b.time.cmp(&a.time)),
}
// Apply pagination
let (messages, next) = match pagination {
None => (messages, None),
Some(pdata) => {
// Apply after, if specified
let skipped = match pdata.after {
None => messages,
Some(after) => messages
.iter()
.skip_while(|m| m.message_id != after)
.cloned()
.collect(),
};
// Apply limit, if specified
let limit: usize = pdata.first.unwrap_or(1000).try_into().unwrap_or(0);
if limit >= skipped.len() {
(skipped, None)
} else {
(
skipped.iter().take(limit).cloned().collect(),
Some(skipped.get(limit).unwrap().message_id.clone()),
)
}
}
};
Ok(MessageList { messages, next })
}
}
}
}
/// Root mutation type. It exposes no fields, so the API is read-only.
struct Mutation;
#[juniper::object(
Context = Context,
)]
impl Mutation {}
/// The complete GraphQL schema: `Query` plus the empty `Mutation` root.
type Schema = juniper::RootNode<'static, Query, Mutation>;
pub fn server(bind: &str, port: u16, databases: Vec<DBLog>) {
let schema = Schema::new(Query, Mutation);
let state = warp::any().map(move || Context {
databases: databases.clone().into_iter().map(from_db).collect(),
});
let graphql_filter = juniper_warp::make_graphql_filter(schema, state.boxed());
println!("Starting server at {}:{}\n\nEndpoints:\n graphql: http://{}:{}/graphql\n graphiql: http://{}:{}/graphiql", bind, port, bind, port, bind, port);
warp::serve(
warp::get2()
.and(warp::path("graphiql"))
.and(juniper_warp::graphiql_filter("/graphql"))
.or(warp::path("graphql").and(graphql_filter)),
)
.run(std::net::SocketAddr::new(bind.parse().unwrap(), port));
}