diff --git a/migrations/20231122220824_guild_log_channels.sql b/migrations/20231122220824_guild_log_channels.sql index f5f1863..192e410 100644 --- a/migrations/20231122220824_guild_log_channels.sql +++ b/migrations/20231122220824_guild_log_channels.sql @@ -1,6 +1,13 @@ -- Add migration script here + +-- Discord IDs are kept as INTEGERs and not unsigned INTEGERs despite +-- their Rust type being u64. In order to properly manage them, you'll +-- need to cast any u64 to i64 with `as i64` before writing them to +-- the database, and cast any i64 to u64 with `as u64` when reading +-- them from the database. This operation is noop in Rust and should +-- therefore not cost a single CPU cycle. CREATE TABLE IF NOT EXISTS guild_log_channels ( - guild_id TEXT PRIMARY KEY, - channel_id TEXT NOT NULL, + guild_id INTEGER PRIMARY KEY, + channel_id INTEGER NOT NULL, UNIQUE(guild_id, channel_id) ); diff --git a/src/db.rs b/src/db.rs deleted file mode 100644 index 34eef83..0000000 --- a/src/db.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::env; - -use sqlx::SqlitePool; - -pub struct Database { - pool: SqlitePool, -} - -impl Database { - pub async fn new() -> color_eyre::Result { - Ok(Self { - pool: SqlitePool::connect(&env::var("DATABASE_URL")?).await?, - }) - } - - pub async fn get_logging_channel( - &self, - guild_id: u64, - ) -> color_eyre::Result> { - let guild_str = guild_id.to_string(); - let channels = sqlx::query!( - r#" -SELECT channel_id -FROM guild_log_channels -WHERE guild_id = ?1 - "#, - guild_str - ) - .fetch_all(&self.pool) - .await?; - Ok(channels - .iter() - .map(|id| id.channel_id.parse::().unwrap()) - .collect()) - } - - pub async fn set_logging_channel( - &self, - guild_id: u64, - channel_id: u64, - ) -> color_eyre::Result<()> { - let guild_str = guild_id.to_string(); - let channel_str = channel_id.to_string(); - let mut conn = self.pool.acquire().await?; - - sqlx::query!( - r#" -INSERT INTO guild_log_channels (guild_id, channel_id) -VALUES ( ?1, ?2 ) - "#, - guild_str, - channel_str - ) - .execute(&mut *conn) - .await? - .last_insert_rowid(); - Ok(()) - } -} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..3720f66 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,62 @@ +use std::env; + +use poise::serenity_prelude::{ChannelId, GuildId}; +use sqlx::SqlitePool; + +pub type Result = ::std::result::Result; + +pub struct Database { + pool: SqlitePool, +} + +impl Database { + pub async fn new() -> Result { + Ok(Self { + pool: SqlitePool::connect( + &env::var("DATABASE_URL") + .expect("Missing enviroment variable DATABASE_URL"), + ) + .await?, + }) + } + + pub async fn get_logging_channel( + &self, + guild_id: u64, + ) -> Result> { + let guild_id = guild_id as i64; + let channels = sqlx::query!( + r#" +SELECT channel_id +FROM guild_log_channels +WHERE guild_id = ?1 + "#, + guild_id + ) + .fetch_all(&self.pool) + .await?; + Ok(channels.iter().map(|id| id.channel_id as u64).collect()) + } + + pub async fn set_logging_channel( + &self, + guild_id: GuildId, + channel_id: ChannelId, + ) -> Result<()> { + let guild_id = guild_id.0 as i64; + let channel_id = channel_id.0 as i64; + let mut conn = self.pool.acquire().await?; + + sqlx::query!( + r#" +INSERT INTO guild_log_channels (guild_id, channel_id) +VALUES ( ?1, ?2 ) + "#, + guild_id, + channel_id + ) + .execute(&mut *conn) + .await + .map(|_| ()) + } +} diff --git a/src/discord/commands.rs b/src/discord/commands.rs index b0b24e2..93ee94f 100644 --- a/src/discord/commands.rs +++ b/src/discord/commands.rs @@ -2,16 +2,47 @@ use super::{Context, Error}; use super::utils::serenity; -#[poise::command(slash_command)] -pub async fn add_logging_channel( +#[poise::command( + slash_command, + subcommands("add_channel"), + required_permissions = "ADMINISTRATOR" +)] +pub async fn logging(_ctx: Context<'_>, _arg: String) -> Result<(), Error> { + Ok(()) +} + +#[poise::command(slash_command, aliases("add-channel"))] +pub async fn add_channel( ctx: Context<'_>, - #[description = "Selected channel"] channel: Option, + #[description = "New logging channel"] channel: serenity::Channel, ) -> Result<(), Error> { - let response = match channel { - None => "No channel selected. Please select one.".to_owned(), - Some(chan) => { - let channel_id = chan.id(); - format!("Selected channel <#{channel_id}>") + let channel_id = channel.id(); + let response = match ctx.guild_id() { + None => "Error: Could not determine the guild's ID.".to_owned(), + Some(guild_id) => { + match ctx + .data() + .database + .set_logging_channel(guild_id, channel_id) + .await + { + Ok(_) => format!( + "Added channel <#{channel_id}> as a logging channel" + ), + Err(e) => { + if let Some(db_error) = e.as_database_error() { + if db_error.is_unique_violation() { + format!("Channel <#{channel_id}> is already a logging channel") + } else { + format!("Error: {:?}", e) + } + } else { + format!( + "Something bad happened with the database: {e:?}" + ) + } + } + } } }; ctx.say(response).await?; diff --git a/src/discord/mod.rs b/src/discord/mod.rs index 1e17966..2818095 100644 --- a/src/discord/mod.rs +++ b/src/discord/mod.rs @@ -5,14 +5,14 @@ pub mod utils; use poise::FrameworkBuilder; use utils::serenity; -use commands::add_logging_channel; +use commands::logging; use utils::{BotData, Context, Error}; pub async fn make_bot() -> color_eyre::Result> { let framework = poise::Framework::builder() .options(poise::FrameworkOptions { - commands: vec![add_logging_channel()], + commands: vec![logging()], ..Default::default() }) .token(std::env::var("DISCORD_TOKEN").expect("missing DISCORD_TOKEN")) diff --git a/src/discord/utils.rs b/src/discord/utils.rs index e08a8d0..ee599da 100644 --- a/src/discord/utils.rs +++ b/src/discord/utils.rs @@ -2,7 +2,7 @@ use crate::db::Database; pub use poise::serenity_prelude as serenity; pub struct BotData { - database: Database, + pub database: Database, } impl BotData { diff --git a/src/utils.rs b/src/utils.rs index a610e4b..b6258e0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -3,7 +3,7 @@ use tracing_subscriber::FmtSubscriber; pub fn setup_logging() { let subscriber = FmtSubscriber::builder() - .with_max_level(Level::DEBUG) + .with_max_level(Level::INFO) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("Setting default subscriber failed");