refactor: simplify code, better organize it, and comment it

This commit is contained in:
Lucien Cartier-Tilet 2023-11-25 22:01:02 +01:00
parent 75cd5dd7cb
commit d6b208963d
Signed by: phundrak
GPG Key ID: BD7789E705CB8DCA
11 changed files with 303 additions and 130 deletions

87
Cargo.lock generated
View File

@ -96,6 +96,16 @@ dependencies = [
"num-traits",
]
[[package]]
name = "atomic-write-file"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c232177ba50b16fe7a4588495bd474a62a9e45a8e4ca6fd7d0b7ac29d164631e"
dependencies = [
"nix",
"rand",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@ -646,9 +656,9 @@ dependencies = [
[[package]]
name = "gimli"
version = "0.28.0"
version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
[[package]]
name = "h2"
@ -916,9 +926,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libsqlite3-sys"
version = "0.26.0"
version = "0.27.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "afc22eff61b133b115c6e8c74e818c628d6d5e7a502afea6f64dee076dd94326"
checksum = "cf4e226dcd58b4be396f7bd3c20da8fdee2911400705297ba7d2d7cc2c30f716"
dependencies = [
"cc",
"pkg-config",
@ -963,6 +973,15 @@ version = "2.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "mime"
version = "0.3.17"
@ -1005,6 +1024,19 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags 1.3.2",
"cfg-if",
"libc",
"memoffset",
"pin-utils",
]
[[package]]
name = "nom"
version = "7.1.3"
@ -1736,9 +1768,9 @@ dependencies = [
[[package]]
name = "sqlx"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e50c216e3624ec8e7ecd14c6a6a6370aad6ee5d8cfc3ab30b5162eeeef2ed33"
checksum = "dba03c279da73694ef99763320dea58b51095dfe87d001b1d4b5fe78ba8763cf"
dependencies = [
"sqlx-core",
"sqlx-macros",
@ -1749,9 +1781,9 @@ dependencies = [
[[package]]
name = "sqlx-core"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d6753e460c998bbd4cd8c6f0ed9a64346fcca0723d6e75e52fdc351c5d2169d"
checksum = "d84b0a3c3739e220d94b3239fd69fb1f74bc36e16643423bd99de3b43c21bfbd"
dependencies = [
"ahash",
"atoi",
@ -1787,14 +1819,14 @@ dependencies = [
"tokio-stream",
"tracing",
"url",
"webpki-roots 0.24.0",
"webpki-roots 0.25.3",
]
[[package]]
name = "sqlx-macros"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a793bb3ba331ec8359c1853bd39eed32cdd7baaf22c35ccf5c92a7e8d1189ec"
checksum = "89961c00dc4d7dffb7aee214964b065072bff69e36ddb9e2c107541f75e4f2a5"
dependencies = [
"proc-macro2",
"quote",
@ -1805,10 +1837,11 @@ dependencies = [
[[package]]
name = "sqlx-macros-core"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a4ee1e104e00dedb6aa5ffdd1343107b0a4702e862a84320ee7cc74782d96fc"
checksum = "d0bd4519486723648186a08785143599760f7cc81c52334a55d6a83ea1e20841"
dependencies = [
"atomic-write-file",
"dotenvy",
"either",
"heck",
@ -1830,9 +1863,9 @@ dependencies = [
[[package]]
name = "sqlx-mysql"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "864b869fdf56263f4c95c45483191ea0af340f9f3e3e7b4d57a61c7c87a970db"
checksum = "e37195395df71fd068f6e2082247891bc11e3289624bbc776a0cdfa1ca7f1ea4"
dependencies = [
"atoi",
"base64 0.21.5",
@ -1872,9 +1905,9 @@ dependencies = [
[[package]]
name = "sqlx-postgres"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb7ae0e6a97fb3ba33b23ac2671a5ce6e3cabe003f451abd5a56e7951d975624"
checksum = "d6ac0ac3b7ccd10cc96c7ab29791a7dd236bd94021f31eec7ba3d46a74aa1c24"
dependencies = [
"atoi",
"base64 0.21.5",
@ -1911,9 +1944,9 @@ dependencies = [
[[package]]
name = "sqlx-sqlite"
version = "0.7.2"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d59dc83cf45d89c555a577694534fcd1b55c545a816c816ce51f20bbe56a4f3f"
checksum = "210976b7d948c7ba9fced8ca835b11cbb2d677c59c79de41ac0d397e14547490"
dependencies = [
"atoi",
"flume",
@ -1929,6 +1962,7 @@ dependencies = [
"sqlx-core",
"tracing",
"url",
"urlencoding",
]
[[package]]
@ -2337,6 +2371,12 @@ dependencies = [
"serde",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf-8"
version = "0.7.6"
@ -2484,15 +2524,6 @@ dependencies = [
"webpki",
]
[[package]]
name = "webpki-roots"
version = "0.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b291546d5d9d1eab74f069c77749f2cb8504a12caa20f0f2de93ddbf6f411888"
dependencies = [
"rustls-webpki",
]
[[package]]
name = "webpki-roots"
version = "0.25.3"

View File

@ -8,45 +8,74 @@ use tracing::error;
pub type Result<T> = ::std::result::Result<T, sqlx::Error>;
pub struct Database {
pool: SqlitePool,
}
pub struct Database(SqlitePool);
impl Database {
/// Initialize Sqlite database.
///
/// The Sqlite database should already exist and have its
/// migrations already executed.
///
/// # Panics
///
/// Panics if the environment variable `DATABASE_URL` is not set.
///
/// # Errors
///
/// This function will return an error if the Sqlite pool fails to
/// create.
pub async fn new() -> Result<Self> {
Ok(Self {
pool: SqlitePool::connect(
Ok(Self(
SqlitePool::connect(
&env::var("DATABASE_URL")
.expect("Missing enviroment variable DATABASE_URL"),
)
.await?,
})
))
}
/// Return from database all channels registered as loggers for a
/// guild.
///
/// # Errors
///
/// This function will return an error if `sqlx` does so.
pub async fn get_logging_channels(
&self,
guild_id: GuildId,
) -> Result<Vec<u64>> {
) -> Result<Vec<ChannelId>> {
let guild_id = guild_id.0 as i64;
let channels = sqlx::query!(
sqlx::query!(
r#"
SELECT channel_id
FROM guild_log_channels
WHERE guild_id = ?1
"#,
WHERE guild_id = ?1"#,
guild_id
)
.fetch_all(&self.pool)
.fetch_all(&self.0)
.await
.map_err(|e| {
error!(
"Error getting logging channels for guild {guild_id}: {e:?}"
);
e
})?;
Ok(channels.iter().map(|id| id.channel_id as u64).collect())
})
.map(|channels| {
channels
.iter()
.map(|id| ChannelId(id.channel_id as u64))
.collect()
})
}
/// Adds a channel as a logger for a guild.
///
/// # Errors
///
/// This function will return an error if `sqlx` does so. This may
/// be either a database issue, or a channel is already registered
/// as a guild's logger, therefore violating the unicity
/// constraint for guild ID and channel ID pairs.
pub async fn set_logging_channel(
&self,
guild_id: GuildId,
@ -54,13 +83,11 @@ WHERE guild_id = ?1
) -> Result<()> {
let guild_id = guild_id.0 as i64;
let channel_id = channel_id.0 as i64;
let mut conn = self.pool.acquire().await?;
let mut conn = self.0.acquire().await?;
sqlx::query!(
r#"
sqlx::query!(r#"
INSERT INTO guild_log_channels (guild_id, channel_id)
VALUES ( ?1, ?2 )
"#,
VALUES ( ?1, ?2 )"#,
guild_id,
channel_id
)
@ -73,18 +100,25 @@ VALUES ( ?1, ?2 )
.map(|_| ())
}
/// Unregister a channel as a logger for a guild.
///
/// This function will return a success value even if `channel`
/// was not a logger of `guild` already.
///
/// # Errors
///
/// This function will return an error if `sqlx` does so.
pub async fn remove_logging_channel(
&self,
guild_id: GuildId,
channel_id: ChannelId,
guild: GuildId,
channel: ChannelId,
) -> Result<()> {
let guild_id = guild_id.0 as i64;
let channel_id = channel_id.0 as i64;
let mut conn = self.pool.acquire().await?;
let guild_id = guild.0 as i64;
let channel_id = channel.0 as i64;
let mut conn = self.0.acquire().await?;
sqlx::query!(r#"
DELETE FROM guild_log_channels
WHERE guild_id = ?1 AND channel_id = ?2
"#,
WHERE guild_id = ?1 AND channel_id = ?2"#,
guild_id,
channel_id)
.execute(&mut *conn)

View File

@ -1,7 +1,16 @@
use super::{Context, Result};
use super::super::{Context, Result};
use super::utils::serenity;
use poise::serenity_prelude as serenity;
/// Main command for logging subcommands.
///
/// This command cannot be called on its own and will do nothing by
/// itself.
///
/// # Errors
///
/// This command will never error out, even if its signature says it
/// can.
#[allow(clippy::unused_async)]
#[poise::command(
slash_command,
@ -12,8 +21,13 @@ pub async fn logging(_ctx: Context<'_>) -> Result {
Ok(())
}
/// Add a channel as a logger.
///
/// # Errors
///
/// This function will return an error if .
#[poise::command(slash_command)]
pub async fn add_channel(
async fn add_channel(
ctx: Context<'_>,
#[description = "New logging channel"] channel: serenity::Channel,
) -> Result {
@ -50,8 +64,16 @@ pub async fn add_channel(
Ok(())
}
/// List all channels registered as loggers for a guild.
///
/// This will list all channels that are logger channels in the server
/// from which the command was executed.
///
/// # Errors
///
/// This function will return an error if the database returns one.
#[poise::command(slash_command)]
pub async fn list_channels(ctx: Context<'_>) -> Result {
async fn list_channels(ctx: Context<'_>) -> Result {
let response = match ctx.guild_id() {
None => "Error: Could not determine the guild's ID".to_owned(),
Some(guild_id) => {
@ -78,8 +100,18 @@ pub async fn list_channels(ctx: Context<'_>) -> Result {
Ok(())
}
/// Remove a channel as a logger in a guild.
///
/// This will remove a channel from the list of logger channels in the
/// guild from which the command was executed. If the channel is not a
/// logger, the bot will still consider unsetting the channel as a
/// logger a success.
///
/// # Errors
///
/// This function will return an error if the database errors.
#[poise::command(slash_command)]
pub async fn remove_channel(
async fn remove_channel(
ctx: Context<'_>,
#[description = "Logger channel to remove"] channel: serenity::Channel,
) -> Result {

View File

@ -0,0 +1,3 @@
mod logging;
pub(crate) use logging::logging;

24
src/discord/error.rs Normal file
View File

@ -0,0 +1,24 @@
use std::error::Error as StdError;
use std::fmt::{self, Display};
#[derive(Debug, Clone, Copy)]
pub enum Error {
GuildIdNotFound,
}
impl Error {
pub fn boxed(self) -> Box<Self> {
Box::new(self)
}
}
impl Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// write!(f, "")
match self {
Self::GuildIdNotFound => write!(f, "Guild ID not found!"),
}
}
}
impl StdError for Error {}

View File

@ -1,72 +0,0 @@
use crate::db::Database;
use super::{utils::BotData, Error, Result};
use poise::{serenity_prelude as serenity, Event};
use tracing::{error, info};
async fn handle_everyone_mention(
ctx: &serenity::Context,
database: &Database,
message: &serenity::Message,
) -> Result {
use serenity::ChannelId;
if let Some(guild_id) = message.guild_id {
if message.mention_everyone {
let author = message.author.clone();
let message_channel = message.channel_id;
let channels: Vec<ChannelId> = database
.get_logging_channels(guild_id)
.await?
.iter()
.map(|channel_id| serenity::ChannelId(channel_id.to_owned()))
.collect();
for channel in &channels {
channel
.send_message(&ctx, |m| {
m.embed(|e| {
e.title("Someone mentioned everyone!")
.field("Author", author.clone(), true)
.field(
"When",
message.timestamp.naive_local().to_string(),
true,
)
.field(
"Channel",
format!("<#{message_channel}>"),
true,
)
.field("Link", format!("https://discord.com/channels/{guild_id}/{}/{}", channel.0, message.id), false)
})
})
.await
.map_err(|e| {
error!("Failed to send message: {e:?}");
e
})?;
}
}
} else {
error!("Could not determine guild id of message {message:?}");
}
Ok(())
}
pub async fn event_handler(
ctx: &serenity::Context,
event: &Event<'_>,
_framework: poise::FrameworkContext<'_, BotData, Error>,
data: &BotData,
) -> Result {
match event {
Event::Ready { data_about_bot } => {
info!("Logged in as {}", data_about_bot.user.name);
}
Event::Message { new_message } => {
handle_everyone_mention(ctx, &data.database, new_message).await?;
}
_ => {}
}
Ok(())
}

View File

@ -0,0 +1,68 @@
use crate::db::Database;
use super::super::Result;
use super::super::error::Error as DiscordError;
use poise::serenity_prelude::{self as serenity, CreateEmbed};
use tracing::{error, info};
fn message_for_everyone_mention(
embed: &mut CreateEmbed,
message: &serenity::Message,
guild_id: u64,
) {
let author = message.author.clone();
let message_channel = message.channel_id.0;
embed
.title("Someone mentioned everyone!")
.field("Author", author.clone(), true)
.field("When", message.timestamp.naive_local().to_string(), true)
.field("Channel", format!("<#{message_channel}>"), true)
.field(
"Link",
format!(
"https://discord.com/channels/{guild_id}/{message_channel}/{}",
message.id
),
false,
);
}
/// Handle messages mentioning everyone.
///
/// # Errors
///
/// This function will return an error if a message fails to be sent,
/// if retrieving the list of channels registered as loggers fails, or
/// if there is not guild ID that can be retrieved from the message.
pub async fn handle_everyone_mention(
ctx: &serenity::Context,
database: &Database,
message: &serenity::Message,
) -> Result {
info!("Message mentioning everyone: {message:?}");
if !message.mention_everyone {
return Ok(());
}
if message.guild_id.is_none() {
error!("Message without a guild_id! {message:?}");
return Err(DiscordError::GuildIdNotFound.boxed());
}
let guild_id = message.guild_id.unwrap();
let channels: Vec<serenity::ChannelId> =
database.get_logging_channels(guild_id).await?;
for channel in &channels {
// Ignore result, it'll be in the bot's logger
let _ = channel
.send_message(&ctx, |m| {
m.embed(|e| {
message_for_everyone_mention(e, message, guild_id.0);
e
})
})
.await
.map_err(|e| error!("Failed to send message: {e:?}"));
}
Ok(())
}

34
src/discord/events/mod.rs Normal file
View File

@ -0,0 +1,34 @@
use super::{utils::BotData, Error, Result};
use poise::{
serenity_prelude::{self as serenity},
Event,
};
use tracing::info;
mod everyone;
use everyone::handle_everyone_mention;
/// Function handling events the bot can see.
///
/// # Errors
///
/// This function will return an error if one of the functions error
/// themselves.
pub async fn event_handler(
ctx: &serenity::Context,
event: &Event<'_>,
_framework: poise::FrameworkContext<'_, BotData, Error>,
data: &BotData,
) -> Result {
match event {
Event::Ready { data_about_bot } => {
info!("Logged in as {}", data_about_bot.user.name);
}
Event::Message { new_message } => {
handle_everyone_mention(ctx, &data.database, new_message).await?;
}
_ => {}
}
Ok(())
}

View File

@ -1,6 +1,7 @@
mod commands;
mod events;
pub mod utils;
pub mod error;
use poise::FrameworkBuilder;
use utils::serenity;
@ -12,6 +13,11 @@ use self::events::event_handler;
pub type Result = ::std::result::Result<(), Error>;
/// Bootstraps the Discord bot.
///
/// # Panics
///
/// Panics if the environment `DISCORD_TOKEN` is unavailable.
pub fn make_bot() -> FrameworkBuilder<BotData, Error> {
poise::Framework::builder()
.options(poise::FrameworkOptions {

View File

@ -6,6 +6,14 @@ pub struct BotData {
}
impl BotData {
/// Initialize state data for bot.
///
/// For now, this only includes a connector to its database.
///
/// # Errors
///
/// This function will return an error if the database fails to
/// initialize.
pub async fn new() -> color_eyre::Result<Self> {
Ok(Self {
database: Database::new().await?,

View File

@ -1,6 +1,11 @@
use tracing::Level;
use tracing_subscriber::FmtSubscriber;
/// Initialize logging for the project.
///
/// # Panics
///
/// Panics if the logger fails to initialize.
pub fn setup_logging() {
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::INFO)