diff --git a/Cargo.lock b/Cargo.lock index 3f0d669..2307a0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1121,7 +1121,7 @@ checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" [[package]] name = "p4bl0t" -version = "0.1.0" +version = "1.0.0" dependencies = [ "color-eyre", "dotenvy", diff --git a/Cargo.toml b/Cargo.toml index e5c8fa2..60f2fae 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "p4bl0t" -version = "0.1.0" +version = "1.0.0" edition = "2021" authors = ["Lucien Cartier-Tilet "] license-file = "LICENSE.md" @@ -11,8 +11,6 @@ repository = "https://github.com/phundrak/p4bl0t" keywords = ["discord", "bot", "logging"] publish = false -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] color-eyre = "0.6.2" dotenvy = "0.15.7" diff --git a/migrations/20231122220824_guild_log_channels.sql b/migrations/20231122220824_guild_log_channels.sql index f5f1863..90085f8 100644 --- a/migrations/20231122220824_guild_log_channels.sql +++ b/migrations/20231122220824_guild_log_channels.sql @@ -1,6 +1,14 @@ -- Add migration script here + +-- Discord IDs are kept as INTEGERs and not unsigned INTEGERs despite +-- their Rust type being u64. In order to properly manage them, you'll +-- need to cast any u64 to i64 with `as i64` before writing them to +-- the database, and cast any i64 to u64 with `as u64` when reading +-- them from the database. This operation is noop in Rust and should +-- therefore not cost a single CPU cycle. CREATE TABLE IF NOT EXISTS guild_log_channels ( - guild_id TEXT PRIMARY KEY, - channel_id TEXT NOT NULL, + guild_id INTEGER NOT NULL, + channel_id INTEGER NOT NULL, UNIQUE(guild_id, channel_id) ); +CREATE INDEX IF NOT EXISTS guild_log_channels_guild_id ON guild_log_channels(guild_id); diff --git a/src/db.rs b/src/db.rs deleted file mode 100644 index 34eef83..0000000 --- a/src/db.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::env; - -use sqlx::SqlitePool; - -pub struct Database { - pool: SqlitePool, -} - -impl Database { - pub async fn new() -> color_eyre::Result { - Ok(Self { - pool: SqlitePool::connect(&env::var("DATABASE_URL")?).await?, - }) - } - - pub async fn get_logging_channel( - &self, - guild_id: u64, - ) -> color_eyre::Result> { - let guild_str = guild_id.to_string(); - let channels = sqlx::query!( - r#" -SELECT channel_id -FROM guild_log_channels -WHERE guild_id = ?1 - "#, - guild_str - ) - .fetch_all(&self.pool) - .await?; - Ok(channels - .iter() - .map(|id| id.channel_id.parse::().unwrap()) - .collect()) - } - - pub async fn set_logging_channel( - &self, - guild_id: u64, - channel_id: u64, - ) -> color_eyre::Result<()> { - let guild_str = guild_id.to_string(); - let channel_str = channel_id.to_string(); - let mut conn = self.pool.acquire().await?; - - sqlx::query!( - r#" -INSERT INTO guild_log_channels (guild_id, channel_id) -VALUES ( ?1, ?2 ) - "#, - guild_str, - channel_str - ) - .execute(&mut *conn) - .await? - .last_insert_rowid(); - Ok(()) - } -} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..fbd4a3f --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,98 @@ +#![allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)] + +use std::env; + +use poise::serenity_prelude::{ChannelId, GuildId}; +use sqlx::SqlitePool; +use tracing::error; + +pub type Result = ::std::result::Result; + +pub struct Database { + pool: SqlitePool, +} + +impl Database { + pub async fn new() -> Result { + Ok(Self { + pool: SqlitePool::connect( + &env::var("DATABASE_URL") + .expect("Missing enviroment variable DATABASE_URL"), + ) + .await?, + }) + } + + pub async fn get_logging_channels( + &self, + guild_id: GuildId, + ) -> Result> { + let guild_id = guild_id.0 as i64; + let channels = sqlx::query!( + r#" +SELECT channel_id +FROM guild_log_channels +WHERE guild_id = ?1 + "#, + guild_id + ) + .fetch_all(&self.pool) + .await + .map_err(|e| { + error!( + "Error getting logging channels for guild {guild_id}: {e:?}" + ); + e + })?; + Ok(channels.iter().map(|id| id.channel_id as u64).collect()) + } + + pub async fn set_logging_channel( + &self, + guild_id: GuildId, + channel_id: ChannelId, + ) -> Result<()> { + let guild_id = guild_id.0 as i64; + let channel_id = channel_id.0 as i64; + let mut conn = self.pool.acquire().await?; + + sqlx::query!( + r#" +INSERT INTO guild_log_channels (guild_id, channel_id) +VALUES ( ?1, ?2 ) + "#, + guild_id, + channel_id + ) + .execute(&mut *conn) + .await + .map_err(|e| { + error!("Error setting channel {channel_id} as logger for guild {guild_id}: {e:?}"); + e + }) + .map(|_| ()) + } + + pub async fn remove_logging_channel( + &self, + guild_id: GuildId, + channel_id: ChannelId, + ) -> Result<()> { + let guild_id = guild_id.0 as i64; + let channel_id = channel_id.0 as i64; + let mut conn = self.pool.acquire().await?; + sqlx::query!(r#" +DELETE FROM guild_log_channels +WHERE guild_id = ?1 AND channel_id = ?2 + "#, + guild_id, + channel_id) + .execute(&mut *conn) + .await + .map_err(|e| { + error!("Error removing channel {channel_id} as a logger for guild {guild_id}: {e:?}"); + e + }) + .map(|_| ()) + } +} diff --git a/src/discord/commands.rs b/src/discord/commands.rs index b0b24e2..d48ee88 100644 --- a/src/discord/commands.rs +++ b/src/discord/commands.rs @@ -1,17 +1,105 @@ -use super::{Context, Error}; +use super::{Context, Result}; use super::utils::serenity; +#[allow(clippy::unused_async)] +#[poise::command( + slash_command, + subcommands("add_channel", "list_channels", "remove_channel"), + required_permissions = "ADMINISTRATOR" +)] +pub async fn logging(_ctx: Context<'_>) -> Result { + Ok(()) +} + #[poise::command(slash_command)] -pub async fn add_logging_channel( +pub async fn add_channel( ctx: Context<'_>, - #[description = "Selected channel"] channel: Option, -) -> Result<(), Error> { - let response = match channel { - None => "No channel selected. Please select one.".to_owned(), - Some(chan) => { - let channel_id = chan.id(); - format!("Selected channel <#{channel_id}>") + #[description = "New logging channel"] channel: serenity::Channel, +) -> Result { + let channel_id = channel.id(); + let response = match ctx.guild_id() { + None => "Error: Could not determine the guild's ID".to_owned(), + Some(guild_id) => { + match ctx + .data() + .database + .set_logging_channel(guild_id, channel_id) + .await + { + Ok(()) => format!( + "Added channel <#{channel_id}> as a logging channel" + ), + Err(e) => { + if let Some(db_error) = e.as_database_error() { + if db_error.is_unique_violation() { + format!("Channel <#{channel_id}> is already a logging channel") + } else { + format!("Error: {e:?}") + } + } else { + format!( + "Something bad happened with the database: {e:?}" + ) + } + } + } + } + }; + ctx.say(response).await?; + Ok(()) +} + +#[poise::command(slash_command)] +pub async fn list_channels(ctx: Context<'_>) -> Result { + let response = match ctx.guild_id() { + None => "Error: Could not determine the guild's ID".to_owned(), + Some(guild_id) => { + match ctx.data().database.get_logging_channels(guild_id).await { + Err(e) => format!("Could not retrieve loggers: {e:?}"), + Ok(channels) => { + if channels.is_empty() { + "No channels registered as loggers".to_owned() + } else { + format!( + "Here are the channels currently set as loggers:\n{}", + channels + .iter() + .map(|channel| format!("- <#{channel}>")) + .collect::>() + .join("\n") + ) + } + } + } + } + }; + ctx.say(response).await?; + Ok(()) +} + +#[poise::command(slash_command)] +pub async fn remove_channel( + ctx: Context<'_>, + #[description = "Logger channel to remove"] channel: serenity::Channel, +) -> Result { + let channel_id = channel.id(); + let response = match ctx.guild_id() { + None => "Error: Could not determine the guild's ID".to_owned(), + Some(guild_id) => { + match ctx + .data() + .database + .remove_logging_channel(guild_id, channel_id) + .await + { + Ok(()) => { + format!("Removed channel <#{channel_id}> as a logger") + } + Err(e) => { + format!("Could not remove channel as a logger: {e:?}") + } + } } }; ctx.say(response).await?; diff --git a/src/discord/events.rs b/src/discord/events.rs index e69de29..f60ba9e 100644 --- a/src/discord/events.rs +++ b/src/discord/events.rs @@ -0,0 +1,72 @@ +use crate::db::Database; + +use super::{utils::BotData, Error, Result}; + +use poise::{serenity_prelude as serenity, Event}; +use tracing::{error, info}; + +async fn handle_everyone_mention( + ctx: &serenity::Context, + database: &Database, + message: &serenity::Message, +) -> Result { + use serenity::ChannelId; + if let Some(guild_id) = message.guild_id { + if message.mention_everyone { + let author = message.author.clone(); + let message_channel = message.channel_id; + let channels: Vec = database + .get_logging_channels(guild_id) + .await? + .iter() + .map(|channel_id| serenity::ChannelId(channel_id.to_owned())) + .collect(); + for channel in &channels { + channel + .send_message(&ctx, |m| { + m.embed(|e| { + e.title("Someone mentioned everyone!") + .field("Author", author.clone(), true) + .field( + "When", + message.timestamp.naive_local().to_string(), + true, + ) + .field( + "Channel", + format!("<#{message_channel}>"), + true, + ) + .field("Link", format!("https://discord.com/channels/{guild_id}/{}/{}", channel.0, message.id), false) + }) + }) + .await + .map_err(|e| { + error!("Failed to send message: {e:?}"); + e + })?; + } + } + } else { + error!("Could not determine guild id of message {message:?}"); + } + Ok(()) +} + +pub async fn event_handler( + ctx: &serenity::Context, + event: &Event<'_>, + _framework: poise::FrameworkContext<'_, BotData, Error>, + data: &BotData, +) -> Result { + match event { + Event::Ready { data_about_bot } => { + info!("Logged in as {}", data_about_bot.user.name); + } + Event::Message { new_message } => { + handle_everyone_mention(ctx, &data.database, new_message).await?; + } + _ => {} + } + Ok(()) +} diff --git a/src/discord/mod.rs b/src/discord/mod.rs index 1e17966..a436a40 100644 --- a/src/discord/mod.rs +++ b/src/discord/mod.rs @@ -5,14 +5,20 @@ pub mod utils; use poise::FrameworkBuilder; use utils::serenity; -use commands::add_logging_channel; +use commands::logging; use utils::{BotData, Context, Error}; -pub async fn make_bot() -> color_eyre::Result> -{ - let framework = poise::Framework::builder() +use self::events::event_handler; + +pub type Result = ::std::result::Result<(), Error>; + +pub fn make_bot() -> FrameworkBuilder { + poise::Framework::builder() .options(poise::FrameworkOptions { - commands: vec![add_logging_channel()], + commands: vec![logging()], + event_handler: |ctx, event, framework, data| { + Box::pin(event_handler(ctx, event, framework, data)) + }, ..Default::default() }) .token(std::env::var("DISCORD_TOKEN").expect("missing DISCORD_TOKEN")) @@ -26,6 +32,5 @@ pub async fn make_bot() -> color_eyre::Result> .await?; Ok(BotData::new().await?) }) - }); - Ok(framework) + }) } diff --git a/src/discord/utils.rs b/src/discord/utils.rs index e08a8d0..ee599da 100644 --- a/src/discord/utils.rs +++ b/src/discord/utils.rs @@ -2,7 +2,7 @@ use crate::db::Database; pub use poise::serenity_prelude as serenity; pub struct BotData { - database: Database, + pub database: Database, } impl BotData { diff --git a/src/main.rs b/src/main.rs index 200e01a..4d8c8eb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +#![warn(clippy::style, clippy::pedantic)] + mod utils; mod db; mod discord; @@ -10,7 +12,7 @@ async fn main() -> Result<(), Box> { color_eyre::install()?; utils::setup_logging(); - let bot = discord::make_bot().await?; + let bot = discord::make_bot(); bot.run().await?; Ok(()) diff --git a/src/utils.rs b/src/utils.rs index a610e4b..b6258e0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,7 +3,7 @@ use tracing_subscriber::FmtSubscriber; pub fn setup_logging() { let subscriber = FmtSubscriber::builder() - .with_max_level(Level::DEBUG) + .with_max_level(Level::INFO) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("Setting default subscriber failed");