generated from phundrak/rust-poem-openapi-template
This commit separates the core features of géjdr from the backend as these will also be used by the bot in the future. This commit also updates the dependencies of the project. It also removes the dependency lettre as well as the mailpit docker service for developers as it appears clearer this project won’t send emails anytime soon. The publication of a docker image is also postponed until later.
This commit is contained in:
parent
ae10711e41
commit
c5688ecb86
@ -1,13 +1,12 @@
|
|||||||
;;; Directory Local Variables -*- no-byte-compile: t -*-
|
;;; Directory Local Variables -*- no-byte-compile: t -*-
|
||||||
;;; For more information see (info "(emacs) Directory Variables")
|
;;; For more information see (info "(emacs) Directory Variables")
|
||||||
|
|
||||||
((sql-mode
|
((rustic-mode . ((fill-column . 80)))
|
||||||
.
|
(sql-mode . ((eval . (progn
|
||||||
((eval . (progn
|
|
||||||
(setq-local lsp-sqls-connections
|
(setq-local lsp-sqls-connections
|
||||||
`(((driver . "postgresql")
|
`(((driver . "postgresql")
|
||||||
(dataSourceName .
|
(dataSourceName \,
|
||||||
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
||||||
(getenv "DB_HOST")
|
(getenv "DB_HOST")
|
||||||
(getenv "DB_PORT")
|
(getenv "DB_PORT")
|
||||||
(getenv "DB_USER")
|
(getenv "DB_USER")
|
||||||
|
@ -3,3 +3,5 @@ out = ["Xml"]
|
|||||||
target-dir = "coverage"
|
target-dir = "coverage"
|
||||||
output-dir = "coverage"
|
output-dir = "coverage"
|
||||||
fail-under = 60
|
fail-under = 60
|
||||||
|
exclude-files = ["target/*"]
|
||||||
|
run-types = ["AllTargets"]
|
||||||
|
@ -4,3 +4,5 @@ skip-clean = true
|
|||||||
target-dir = "coverage"
|
target-dir = "coverage"
|
||||||
output-dir = "coverage"
|
output-dir = "coverage"
|
||||||
fail-under = 60
|
fail-under = 60
|
||||||
|
exclude-files = ["target/*"]
|
||||||
|
run-types = ["AllTargets"]
|
1616
Cargo.lock
generated
1616
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
65
Cargo.toml
65
Cargo.toml
@ -1,61 +1,8 @@
|
|||||||
[package]
|
[workspace]
|
||||||
name = "gege-jdr-backend"
|
|
||||||
version = "0.1.0"
|
|
||||||
edition = "2021"
|
|
||||||
publish = false
|
|
||||||
authors = ["phundrak"]
|
|
||||||
rust-version = "1.78"
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
members = [
|
||||||
|
"gejdr-core",
|
||||||
[lib]
|
"gejdr-bot",
|
||||||
path = "src/lib.rs"
|
"gejdr-backend",
|
||||||
|
|
||||||
[[bin]]
|
|
||||||
path = "src/main.rs"
|
|
||||||
name = "gege-jdr-backend"
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
chrono = { version = "0.4.38", features = ["serde"] }
|
|
||||||
config = { version = "0.14.0", features = ["yaml"] }
|
|
||||||
dotenvy = "0.15.7"
|
|
||||||
serde = "1.0.204"
|
|
||||||
serde_json = "1.0.120"
|
|
||||||
thiserror = "1.0.63"
|
|
||||||
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
|
|
||||||
tracing = "0.1.40"
|
|
||||||
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
|
||||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
|
||||||
|
|
||||||
[dependencies.lettre]
|
|
||||||
version = "0.11.7"
|
|
||||||
default-features = false
|
|
||||||
features = [
|
|
||||||
"builder",
|
|
||||||
"hostname",
|
|
||||||
"pool",
|
|
||||||
"rustls-tls",
|
|
||||||
"tokio1",
|
|
||||||
"tokio1-rustls-tls",
|
|
||||||
"smtp-transport"
|
|
||||||
]
|
]
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
[dependencies.poem]
|
|
||||||
version = "3.0.4"
|
|
||||||
default-features = false
|
|
||||||
features = [
|
|
||||||
"csrf",
|
|
||||||
"rustls",
|
|
||||||
"cookie",
|
|
||||||
"test"
|
|
||||||
]
|
|
||||||
|
|
||||||
[dependencies.poem-openapi]
|
|
||||||
version = "5.0.3"
|
|
||||||
features = ["chrono", "swagger-ui", "uuid"]
|
|
||||||
|
|
||||||
[dependencies.sqlx]
|
|
||||||
version = "0.8.0"
|
|
||||||
default-features = false
|
|
||||||
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
|
|
19
backend.just
Normal file
19
backend.just
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
default: run
|
||||||
|
|
||||||
|
prepare:
|
||||||
|
pushd gejdr-backend
|
||||||
|
cargo sqlx prepare
|
||||||
|
popd
|
||||||
|
|
||||||
|
build:
|
||||||
|
cargo auditable build
|
||||||
|
|
||||||
|
build-release:
|
||||||
|
cargo auditable build --release
|
||||||
|
|
||||||
|
run:
|
||||||
|
cargo auditable run --bin gejdr-backend
|
||||||
|
|
||||||
|
## Local Variables:
|
||||||
|
## mode: makefile
|
||||||
|
## End:
|
@ -31,25 +31,6 @@ services:
|
|||||||
depends_on:
|
depends_on:
|
||||||
- db
|
- db
|
||||||
|
|
||||||
# If you run GegeJdrBackend in production, DO NOT use mailpit.
|
|
||||||
# This tool is for testing only. Instead, you should use a real SMTP
|
|
||||||
# provider, such as Mailgun, Mailwhale, or Postal.
|
|
||||||
mailpit:
|
|
||||||
image: axllent/mailpit:latest
|
|
||||||
restart: unless-stopped
|
|
||||||
container_name: gege-jdr-backend-mailpit
|
|
||||||
ports:
|
|
||||||
- 127.0.0.1:8025:8025 # WebUI
|
|
||||||
- 127.0.0.1:1025:1025 # SMTP
|
|
||||||
volumes:
|
|
||||||
- gege_jdr_backend_mailpit:/data
|
|
||||||
environment:
|
|
||||||
MP_MAX_MESSAGES: 5000
|
|
||||||
MP_DATABASE: /data/mailpit.db
|
|
||||||
MP_SMTP_AUTH_ACCEPT_ANY: 1
|
|
||||||
MP_SMTP_AUTH_ALLOW_INSECURE: 1
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
gege_jdr_backend_db_data:
|
gege_jdr_backend_db_data:
|
||||||
gege_jdr_backend_pgadmin_data:
|
gege_jdr_backend_pgadmin_data:
|
||||||
gege_jdr_backend_mailpit:
|
|
||||||
|
14
docker/mod.just
Normal file
14
docker/mod.just
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
default: start
|
||||||
|
|
||||||
|
start:
|
||||||
|
docker compose -f compose.dev.yml up -d
|
||||||
|
|
||||||
|
stop:
|
||||||
|
docker compose -f compose.dev.yml down
|
||||||
|
|
||||||
|
logs:
|
||||||
|
docker compose -f compose.dev.yml logs -f
|
||||||
|
|
||||||
|
## Local Variables:
|
||||||
|
## mode: makefile
|
||||||
|
## End:
|
12
flake.lock
generated
12
flake.lock
generated
@ -20,11 +20,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1732014248,
|
"lastModified": 1735291276,
|
||||||
"narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=",
|
"narHash": "sha256-NYVcA06+blsLG6wpAbSPTCyLvxD/92Hy4vlY9WxFI1M=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "23e89b7da85c3640bbc2173fe04f4bd114342367",
|
"rev": "634fd46801442d760e09493a794c4f15db2d0cbb",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -62,11 +62,11 @@
|
|||||||
"nixpkgs": "nixpkgs_2"
|
"nixpkgs": "nixpkgs_2"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1732242723,
|
"lastModified": 1735439489,
|
||||||
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=",
|
"narHash": "sha256-IysonaW/cItfmMuvg43flOqMgS4N0C6yKJobCa09XOQ=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a",
|
"rev": "915d7c42a706f9191696d1b779cf1ea1769d34a8",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
@ -18,7 +18,7 @@
|
|||||||
rustc = rustVersion;
|
rustc = rustVersion;
|
||||||
};
|
};
|
||||||
|
|
||||||
appName = "gege-jdr-backend";
|
appName = "gejdr-backend";
|
||||||
|
|
||||||
appRustBuild = rustPlatform.buildRustPackage {
|
appRustBuild = rustPlatform.buildRustPackage {
|
||||||
pname = appName;
|
pname = appName;
|
||||||
|
6
gejdr-backend/.tarpaulin.ci.toml
Normal file
6
gejdr-backend/.tarpaulin.ci.toml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
[all]
|
||||||
|
out = ["Xml"]
|
||||||
|
target-dir = "coverage"
|
||||||
|
output-dir = "coverage"
|
||||||
|
fail-under = 60
|
||||||
|
exclude-files = ["target/*"]
|
7
gejdr-backend/.tarpaulin.local.toml
Normal file
7
gejdr-backend/.tarpaulin.local.toml
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
[all]
|
||||||
|
out = ["Html", "Lcov"]
|
||||||
|
skip-clean = true
|
||||||
|
target-dir = "coverage"
|
||||||
|
output-dir = "coverage"
|
||||||
|
fail-under = 60
|
||||||
|
exclude-files = ["target/*"]
|
43
gejdr-backend/Cargo.toml
Normal file
43
gejdr-backend/Cargo.toml
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
[package]
|
||||||
|
name = "gejdr-backend"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
publish = false
|
||||||
|
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
path = "src/main.rs"
|
||||||
|
name = "gejdr-backend"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
gejdr-core = { path = "../gejdr-core" }
|
||||||
|
chrono = { version = "0.4.38", features = ["serde"] }
|
||||||
|
config = { version = "0.14.1", features = ["yaml"] }
|
||||||
|
dotenvy = "0.15.7"
|
||||||
|
oauth2 = "4.4.2"
|
||||||
|
quote = "1.0.37"
|
||||||
|
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
|
||||||
|
serde = "1.0.215"
|
||||||
|
serde_json = "1.0.133"
|
||||||
|
thiserror = "1.0.69"
|
||||||
|
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
|
||||||
|
tracing = "0.1.40"
|
||||||
|
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
||||||
|
uuid = { version = "1.11.0", features = ["v4", "serde"] }
|
||||||
|
|
||||||
|
[dependencies.poem]
|
||||||
|
version = "3.1.3"
|
||||||
|
default-features = false
|
||||||
|
features = ["csrf", "rustls", "cookie", "test", "session"]
|
||||||
|
|
||||||
|
[dependencies.poem-openapi]
|
||||||
|
version = "5.1.2"
|
||||||
|
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
|
||||||
|
|
||||||
|
[lints.rust]
|
||||||
|
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }
|
@ -3,5 +3,5 @@ debug: true
|
|||||||
|
|
||||||
application:
|
application:
|
||||||
protocol: http
|
protocol: http
|
||||||
host: 127.0.0.1
|
host: localhost
|
||||||
base_url: http://127.0.0.1:3000
|
base_url: http://localhost:3000
|
55
gejdr-backend/src/api_wrapper/discord.rs
Normal file
55
gejdr-backend/src/api_wrapper/discord.rs
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
use super::{ApiError, DiscordErrorResponse};
|
||||||
|
use gejdr_core::models::accounts::RemoteUser;
|
||||||
|
|
||||||
|
static DISCORD_URL: &str = "https://discord.com/api/v10/";
|
||||||
|
|
||||||
|
pub async fn get_user_profile(token: &str) -> Result<RemoteUser, ApiError> {
|
||||||
|
let client = reqwest::Client::new();
|
||||||
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
|
headers.insert(
|
||||||
|
reqwest::header::AUTHORIZATION,
|
||||||
|
format!("Bearer {token}").parse().unwrap(),
|
||||||
|
);
|
||||||
|
let response = client
|
||||||
|
.get(format!("{DISCORD_URL}/users/@me"))
|
||||||
|
.headers(headers)
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
match response {
|
||||||
|
Ok(resp) => {
|
||||||
|
if resp.status().is_success() {
|
||||||
|
resp.json::<RemoteUser>()
|
||||||
|
.await
|
||||||
|
.map_err(std::convert::Into::into)
|
||||||
|
} else {
|
||||||
|
let error_response = resp.json::<DiscordErrorResponse>().await;
|
||||||
|
match error_response {
|
||||||
|
Ok(val) => Err(ApiError::Api(val)),
|
||||||
|
Err(e) => Err(ApiError::Reqwest(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Err(ApiError::Reqwest(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn user_profile_invalid_token_results_in_401() {
|
||||||
|
let res = get_user_profile("invalid").await;
|
||||||
|
assert!(res.is_err());
|
||||||
|
let err = res.err().unwrap();
|
||||||
|
println!("Error: {err:?}");
|
||||||
|
let expected = DiscordErrorResponse {
|
||||||
|
code: 0,
|
||||||
|
message: "401: Unauthorized".into(),
|
||||||
|
};
|
||||||
|
assert!(matches!(ApiError::Api(expected), _err));
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Find a way to mock calls to discord.com API with a
|
||||||
|
// successful reply
|
||||||
|
}
|
41
gejdr-backend/src/api_wrapper/mod.rs
Normal file
41
gejdr-backend/src/api_wrapper/mod.rs
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
use reqwest::Error as ReqwestError;
|
||||||
|
use std::fmt::{self, Display};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
pub mod discord;
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize, PartialEq, Eq)]
|
||||||
|
pub struct DiscordErrorResponse {
|
||||||
|
pub message: String,
|
||||||
|
pub code: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for DiscordErrorResponse {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
write!(f, "DiscordErrorResponse: {} ({})", self.message, self.code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ApiError {
|
||||||
|
#[error("Reqwest error: {0}")]
|
||||||
|
Reqwest(#[from] ReqwestError),
|
||||||
|
#[error("API Error: {0}")]
|
||||||
|
Api(DiscordErrorResponse),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn display_discord_error_response() {
|
||||||
|
let error = DiscordErrorResponse {
|
||||||
|
message: "Message".into(),
|
||||||
|
code: 42,
|
||||||
|
};
|
||||||
|
let error_str = error.to_string();
|
||||||
|
let expected = "DiscordErrorResponse: Message (42)".to_string();
|
||||||
|
assert_eq!(expected, error_str);
|
||||||
|
}
|
||||||
|
}
|
14
gejdr-backend/src/errors.rs
Normal file
14
gejdr-backend/src/errors.rs
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ApiError {
|
||||||
|
#[error("SQL error: {0}")]
|
||||||
|
Sql(#[from] gejdr_core::sqlx::Error),
|
||||||
|
#[error("OAuth token error: {0}")]
|
||||||
|
TokenError(String),
|
||||||
|
#[error("Unauthorized")]
|
||||||
|
Unauthorized,
|
||||||
|
#[error("Attempted to get a value, none found")]
|
||||||
|
OptionError,
|
||||||
|
}
|
@ -5,10 +5,14 @@
|
|||||||
#![allow(clippy::unused_async)]
|
#![allow(clippy::unused_async)]
|
||||||
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
|
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
|
||||||
|
|
||||||
pub mod route;
|
use gejdr_core::sqlx;
|
||||||
pub mod settings;
|
|
||||||
pub mod startup;
|
mod api_wrapper;
|
||||||
pub mod telemetry;
|
mod errors;
|
||||||
|
mod oauth;
|
||||||
|
mod route;
|
||||||
|
mod settings;
|
||||||
|
mod startup;
|
||||||
|
|
||||||
type MaybeListener = Option<poem::listener::TcpListener<String>>;
|
type MaybeListener = Option<poem::listener::TcpListener<String>>;
|
||||||
|
|
||||||
@ -16,8 +20,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
|
|||||||
dotenvy::dotenv().ok();
|
dotenvy::dotenv().ok();
|
||||||
let settings = settings::Settings::new().expect("Failed to read settings");
|
let settings = settings::Settings::new().expect("Failed to read settings");
|
||||||
if !cfg!(test) {
|
if !cfg!(test) {
|
||||||
let subscriber = telemetry::get_subscriber(settings.clone().debug);
|
let subscriber = gejdr_core::telemetry::get_subscriber(settings.clone().debug);
|
||||||
telemetry::init_subscriber(subscriber);
|
gejdr_core::telemetry::init_subscriber(subscriber);
|
||||||
}
|
}
|
||||||
tracing::event!(
|
tracing::event!(
|
||||||
target: "gege-jdr-backend",
|
target: "gege-jdr-backend",
|
||||||
@ -29,8 +33,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
|
|||||||
tracing::event!(
|
tracing::event!(
|
||||||
target: "gege-jdr-backend",
|
target: "gege-jdr-backend",
|
||||||
tracing::Level::INFO,
|
tracing::Level::INFO,
|
||||||
"Listening on http://127.0.0.1:{}/",
|
"Listening on {}",
|
||||||
application.port()
|
application.settings.web_address()
|
||||||
);
|
);
|
||||||
application
|
application
|
||||||
}
|
}
|
@ -1,5 +1,5 @@
|
|||||||
#[cfg(not(tarpaulin_include))]
|
#[cfg(not(tarpaulin_include))]
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), std::io::Error> {
|
async fn main() -> Result<(), std::io::Error> {
|
||||||
gege_jdr_backend::run(None).await
|
gejdr_backend::run(None).await
|
||||||
}
|
}
|
62
gejdr-backend/src/oauth/discord.rs
Normal file
62
gejdr-backend/src/oauth/discord.rs
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
use oauth2::{
|
||||||
|
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
|
||||||
|
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
|
||||||
|
};
|
||||||
|
use reqwest::Url;
|
||||||
|
|
||||||
|
use crate::{errors::ApiError, settings::Settings};
|
||||||
|
|
||||||
|
use super::OauthProvider;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DiscordOauthProvider {
|
||||||
|
client: BasicClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiscordOauthProvider {
|
||||||
|
pub fn new(settings: &Settings) -> Self {
|
||||||
|
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
|
||||||
|
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
|
||||||
|
.expect("Invalid authorization endpoint URL");
|
||||||
|
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
|
||||||
|
.expect("Invalid token endpoint URL");
|
||||||
|
let revocation_url =
|
||||||
|
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
|
||||||
|
.expect("Invalid revocation URL");
|
||||||
|
let client = BasicClient::new(
|
||||||
|
ClientId::new(settings.discord.client_id.clone()),
|
||||||
|
Some(ClientSecret::new(settings.discord.client_secret.clone())),
|
||||||
|
auth_url,
|
||||||
|
Some(token_url),
|
||||||
|
)
|
||||||
|
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
|
||||||
|
.set_revocation_uri(revocation_url);
|
||||||
|
Self { client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OauthProvider for DiscordOauthProvider {
|
||||||
|
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
|
||||||
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
let (auth_url, csrf_token) = self
|
||||||
|
.client
|
||||||
|
.authorize_url(CsrfToken::new_random)
|
||||||
|
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
|
||||||
|
.set_pkce_challenge(pkce_challenge)
|
||||||
|
.url();
|
||||||
|
(auth_url, csrf_token, pkce_verifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn token(
|
||||||
|
&self,
|
||||||
|
code: String,
|
||||||
|
verifier: PkceCodeVerifier,
|
||||||
|
) -> Result<super::Token, ApiError> {
|
||||||
|
self.client
|
||||||
|
.exchange_code(AuthorizationCode::new(code))
|
||||||
|
.set_pkce_verifier(verifier)
|
||||||
|
.request_async(oauth2::reqwest::async_http_client)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
|
||||||
|
}
|
||||||
|
}
|
17
gejdr-backend/src/oauth/mod.rs
Normal file
17
gejdr-backend/src/oauth/mod.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
mod discord;
|
||||||
|
pub use discord::DiscordOauthProvider;
|
||||||
|
|
||||||
|
use oauth2::{
|
||||||
|
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
|
||||||
|
StandardTokenResponse,
|
||||||
|
};
|
||||||
|
use reqwest::Url;
|
||||||
|
|
||||||
|
use crate::errors::ApiError;
|
||||||
|
|
||||||
|
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
|
||||||
|
|
||||||
|
pub trait OauthProvider {
|
||||||
|
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
|
||||||
|
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
|
||||||
|
}
|
220
gejdr-backend/src/route/auth.rs
Normal file
220
gejdr-backend/src/route/auth.rs
Normal file
@ -0,0 +1,220 @@
|
|||||||
|
use gejdr_core::models::accounts::User;
|
||||||
|
use gejdr_core::sqlx::PgPool;
|
||||||
|
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
|
||||||
|
use poem::web::Data;
|
||||||
|
use poem::{session::Session, web::Form};
|
||||||
|
use poem_openapi::payload::{Json, PlainText};
|
||||||
|
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||||
|
|
||||||
|
use crate::oauth::{DiscordOauthProvider, OauthProvider};
|
||||||
|
|
||||||
|
use super::errors::ErrorResponse;
|
||||||
|
use super::ApiCategory;
|
||||||
|
|
||||||
|
type Token =
|
||||||
|
oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>;
|
||||||
|
|
||||||
|
pub struct AuthApi;
|
||||||
|
|
||||||
|
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
|
||||||
|
struct DiscordCallbackRequest {
|
||||||
|
code: String,
|
||||||
|
state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiscordCallbackRequest {
|
||||||
|
pub fn check_token(&self, token: &CsrfToken) -> Result<(), LoginStatusResponse> {
|
||||||
|
if *token.secret().to_string() == self.state {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(LoginStatusResponse::TokenError(Json(ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "OAuth token error".into(),
|
||||||
|
details: Some(
|
||||||
|
"OAuth provider did not send a message that matches what was expected".into(),
|
||||||
|
),
|
||||||
|
})))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(ApiResponse)]
|
||||||
|
enum LoginStatusResponse {
|
||||||
|
#[oai(status = 201)]
|
||||||
|
LoggedIn(Json<UserInfo>),
|
||||||
|
#[oai(status = 201)]
|
||||||
|
LoggedOut(
|
||||||
|
#[oai(header = "Location")] String,
|
||||||
|
#[oai(header = "Cache-Control")] String,
|
||||||
|
),
|
||||||
|
#[oai(status = 301)]
|
||||||
|
LoginRedirect(
|
||||||
|
#[oai(header = "Location")] String,
|
||||||
|
#[oai(header = "Cache-Control")] String,
|
||||||
|
),
|
||||||
|
#[oai(status = 500)]
|
||||||
|
TokenError(Json<ErrorResponse>),
|
||||||
|
#[oai(status = 500)]
|
||||||
|
DatabaseError(Json<ErrorResponse>),
|
||||||
|
#[oai(status = 503)]
|
||||||
|
DiscordError(Json<ErrorResponse>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Eq, PartialEq, serde::Serialize, Object)]
|
||||||
|
struct UserInfo {
|
||||||
|
id: String,
|
||||||
|
username: String,
|
||||||
|
display_name: Option<String>,
|
||||||
|
avatar: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<User> for UserInfo {
|
||||||
|
fn from(value: User) -> Self {
|
||||||
|
Self {
|
||||||
|
id: value.id,
|
||||||
|
username: value.username,
|
||||||
|
display_name: value.name,
|
||||||
|
avatar: value.avatar,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(ApiResponse)]
|
||||||
|
enum UserInfoResponse {
|
||||||
|
#[oai(status = 201)]
|
||||||
|
UserInfo(Json<UserInfo>),
|
||||||
|
#[oai(status = 401)]
|
||||||
|
Unauthorized,
|
||||||
|
#[oai(status = 500)]
|
||||||
|
DatabaseError(Json<ErrorResponse>),
|
||||||
|
#[oai(status = 503)]
|
||||||
|
DiscordError(Json<ErrorResponse>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<UserInfoResponse> for LoginStatusResponse {
|
||||||
|
fn from(value: UserInfoResponse) -> Self {
|
||||||
|
match value {
|
||||||
|
UserInfoResponse::UserInfo(json) => Self::LoggedIn(json),
|
||||||
|
UserInfoResponse::Unauthorized => unimplemented!(),
|
||||||
|
UserInfoResponse::DatabaseError(json) => Self::DatabaseError(json),
|
||||||
|
UserInfoResponse::DiscordError(json) => Self::DiscordError(json),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(ApiResponse)]
|
||||||
|
enum CsrfResponse {
|
||||||
|
#[oai(status = 201)]
|
||||||
|
Token(PlainText<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
|
||||||
|
impl AuthApi {
|
||||||
|
async fn fetch_remote_user(
|
||||||
|
pool: Data<&PgPool>,
|
||||||
|
token: Token,
|
||||||
|
) -> Result<UserInfoResponse, UserInfoResponse> {
|
||||||
|
crate::api_wrapper::discord::get_user_profile(token.access_token().secret())
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::event!(
|
||||||
|
target: "auth-discord",
|
||||||
|
tracing::Level::ERROR,
|
||||||
|
"Failed to communicate with Discord: {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
UserInfoResponse::DiscordError(Json(e.into()))
|
||||||
|
})?
|
||||||
|
.refresh_in_database(&pool)
|
||||||
|
.await
|
||||||
|
.map(|user| UserInfoResponse::UserInfo(Json(user.into())))
|
||||||
|
.map_err(|e| {
|
||||||
|
tracing::event!(
|
||||||
|
target: "auth-discord",
|
||||||
|
tracing::Level::ERROR,
|
||||||
|
"Database error: {}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
UserInfoResponse::DatabaseError(Json(e.into()))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/signin/discord", method = "get")]
|
||||||
|
async fn signin_discord(
|
||||||
|
&self,
|
||||||
|
oauth: Data<&DiscordOauthProvider>,
|
||||||
|
session: &Session,
|
||||||
|
) -> LoginStatusResponse {
|
||||||
|
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
|
||||||
|
session.set("csrf", csrf_token);
|
||||||
|
session.set("pkce", pkce_verifier);
|
||||||
|
tracing::event!(
|
||||||
|
target: "auth-discord",
|
||||||
|
tracing::Level::INFO,
|
||||||
|
"Signin through Discord",
|
||||||
|
);
|
||||||
|
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/callback/discord", method = "get")]
|
||||||
|
async fn callback_discord(
|
||||||
|
&self,
|
||||||
|
Form(auth_request): Form<DiscordCallbackRequest>,
|
||||||
|
oauth: Data<&DiscordOauthProvider>,
|
||||||
|
pool: Data<&PgPool>,
|
||||||
|
session: &Session,
|
||||||
|
) -> Result<LoginStatusResponse, LoginStatusResponse> {
|
||||||
|
tracing::event!(
|
||||||
|
target: "auth-discord",
|
||||||
|
tracing::Level::INFO,
|
||||||
|
"Discord callback",
|
||||||
|
);
|
||||||
|
let csrf_token = session.get::<CsrfToken>("csrf").ok_or_else(|| {
|
||||||
|
LoginStatusResponse::TokenError(Json(ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "Cannot fetch csrf token from session".to_string(),
|
||||||
|
..Default::default()
|
||||||
|
}))
|
||||||
|
})?;
|
||||||
|
auth_request.check_token(&csrf_token)?;
|
||||||
|
let pkce_verifier = session.get::<PkceCodeVerifier>("pkce").ok_or_else(|| {
|
||||||
|
LoginStatusResponse::TokenError(Json(ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "Cannot fetch pkce verifier from session".to_string(),
|
||||||
|
..Default::default()
|
||||||
|
}))
|
||||||
|
})?;
|
||||||
|
let token = oauth
|
||||||
|
.token(auth_request.code, pkce_verifier)
|
||||||
|
.await
|
||||||
|
.map_err(|e| LoginStatusResponse::TokenError(Json(e.into())))?;
|
||||||
|
session.set("token", token.clone());
|
||||||
|
Self::fetch_remote_user(pool, token)
|
||||||
|
.await
|
||||||
|
.map(std::convert::Into::into)
|
||||||
|
.map_err(std::convert::Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/csrf", method = "get")]
|
||||||
|
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
|
||||||
|
CsrfResponse::Token(PlainText(token.0.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/signout", method = "post")]
|
||||||
|
async fn signout(&self, session: &Session) -> LoginStatusResponse {
|
||||||
|
session.purge();
|
||||||
|
LoginStatusResponse::LoggedOut("/".to_string(), "no-cache".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/me", method = "get")]
|
||||||
|
async fn user_info(
|
||||||
|
&self,
|
||||||
|
session: &Session,
|
||||||
|
pool: Data<&PgPool>,
|
||||||
|
) -> Result<UserInfoResponse, UserInfoResponse> {
|
||||||
|
let token = session
|
||||||
|
.get::<Token>("token")
|
||||||
|
.ok_or(UserInfoResponse::Unauthorized)?;
|
||||||
|
Self::fetch_remote_user(pool, token).await
|
||||||
|
}
|
||||||
|
}
|
204
gejdr-backend/src/route/errors.rs
Normal file
204
gejdr-backend/src/route/errors.rs
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
use poem_openapi::Object;
|
||||||
|
use reqwest::Error as ReqwestError;
|
||||||
|
|
||||||
|
use crate::api_wrapper::ApiError as ApiWrapperError;
|
||||||
|
use crate::errors::ApiError;
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Serialize, Default, Object, PartialEq, Eq)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
pub code: u16,
|
||||||
|
pub message: String,
|
||||||
|
pub details: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ApiError> for ErrorResponse {
|
||||||
|
fn from(value: ApiError) -> Self {
|
||||||
|
match value {
|
||||||
|
ApiError::Sql(e) => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "SQL error".into(),
|
||||||
|
details: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
ApiError::TokenError(e) => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "OAuth token error".into(),
|
||||||
|
details: Some(e),
|
||||||
|
},
|
||||||
|
ApiError::Unauthorized => Self {
|
||||||
|
code: 401,
|
||||||
|
message: "Unauthorized!".into(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
ApiError::OptionError => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "Attempted to get a value, but none found".into(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ReqwestError> for ErrorResponse {
|
||||||
|
fn from(value: ReqwestError) -> Self {
|
||||||
|
Self {
|
||||||
|
code: 503,
|
||||||
|
message: "Failed to communicate with Discord".into(),
|
||||||
|
details: Some(value.status().map_or_else(
|
||||||
|
|| "Communication failed before we could hear back from Discord".into(),
|
||||||
|
|status| format!("Discord sent back the error code {status}"),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ApiWrapperError> for ErrorResponse {
|
||||||
|
fn from(source: ApiWrapperError) -> Self {
|
||||||
|
match source {
|
||||||
|
ApiWrapperError::Reqwest(e) => e.into(),
|
||||||
|
ApiWrapperError::Api(e) => Self {
|
||||||
|
code: if e.message.as_str().starts_with("401") {
|
||||||
|
401
|
||||||
|
} else {
|
||||||
|
e.code
|
||||||
|
},
|
||||||
|
message: e.message,
|
||||||
|
details: None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<gejdr_core::sqlx::Error> for ErrorResponse {
|
||||||
|
fn from(_value: gejdr_core::sqlx::Error) -> Self {
|
||||||
|
Self {
|
||||||
|
code: 500,
|
||||||
|
message: "Internal database error".into(),
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::api_wrapper::{ApiError as ApiWrapperError, DiscordErrorResponse};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversion_from_sql_api_error_works() {
|
||||||
|
let sql_error = ApiError::Sql(gejdr_core::sqlx::Error::ColumnNotFound(
|
||||||
|
"COLUMN_NAME".to_string(),
|
||||||
|
));
|
||||||
|
let final_error = ErrorResponse::from(sql_error);
|
||||||
|
let expected_error = ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "SQL error".into(),
|
||||||
|
details: Some("no column found for name: COLUMN_NAME".into()),
|
||||||
|
};
|
||||||
|
assert_eq!(expected_error, final_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversion_from_token_error_works() {
|
||||||
|
let initial_error = ApiError::TokenError("TOKEN ERROR".into());
|
||||||
|
let final_error: ErrorResponse = initial_error.into();
|
||||||
|
let expected_error = ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "OAuth token error".into(),
|
||||||
|
details: Some("TOKEN ERROR".into()),
|
||||||
|
};
|
||||||
|
assert_eq!(expected_error, final_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversion_from_unauthorized_works() {
|
||||||
|
let initial_error = ApiError::Unauthorized;
|
||||||
|
let final_error: ErrorResponse = initial_error.into();
|
||||||
|
let expected_error = ErrorResponse {
|
||||||
|
code: 401,
|
||||||
|
message: "Unauthorized!".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
assert_eq!(expected_error, final_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversion_from_option_error_works() {
|
||||||
|
let initial_error = ApiError::OptionError;
|
||||||
|
let final_error: ErrorResponse = initial_error.into();
|
||||||
|
let expected_error = ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "Attempted to get a value, but none found".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
assert_eq!(expected_error, final_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn conversion_from_reqwest_error() {
|
||||||
|
let err = reqwest::get("https://example.example/401").await;
|
||||||
|
assert!(err.is_err());
|
||||||
|
let expected = ErrorResponse {
|
||||||
|
code: 503,
|
||||||
|
message: "Failed to communicate with Discord".into(),
|
||||||
|
details: Some("Communication failed before we could hear back from Discord".into()),
|
||||||
|
};
|
||||||
|
let actual: ErrorResponse = err.err().unwrap().into();
|
||||||
|
assert_eq!(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn conversion_from_apiwrappererror_with_reqwest_error() {
|
||||||
|
let err = reqwest::get("https://example.example/401").await;
|
||||||
|
assert!(err.is_err());
|
||||||
|
let err = ApiWrapperError::Reqwest(err.err().unwrap());
|
||||||
|
let expected = ErrorResponse {
|
||||||
|
code: 503,
|
||||||
|
message: "Failed to communicate with Discord".into(),
|
||||||
|
details: Some("Communication failed before we could hear back from Discord".into()),
|
||||||
|
};
|
||||||
|
let actual: ErrorResponse = err.into();
|
||||||
|
assert_eq!(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversion_from_apiwrappererror_with_401_discord_error() {
|
||||||
|
let err = ApiWrapperError::Api(DiscordErrorResponse {
|
||||||
|
code: 0,
|
||||||
|
message: "401: Unauthorized".into(),
|
||||||
|
});
|
||||||
|
let expected = ErrorResponse {
|
||||||
|
code: 401,
|
||||||
|
message: "401: Unauthorized".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let actual: ErrorResponse = err.into();
|
||||||
|
assert_eq!(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversion_from_apiwrappererror_with_generic_discord_error() {
|
||||||
|
let err = ApiWrapperError::Api(DiscordErrorResponse {
|
||||||
|
code: 0,
|
||||||
|
message: "Something else".into(),
|
||||||
|
});
|
||||||
|
let expected = ErrorResponse {
|
||||||
|
code: 0,
|
||||||
|
message: "Something else".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let actual: ErrorResponse = err.into();
|
||||||
|
assert_eq!(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversion_from_database_error() {
|
||||||
|
let err = gejdr_core::sqlx::Error::PoolClosed;
|
||||||
|
let expected = ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "Internal database error".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let actual: ErrorResponse = err.into();
|
||||||
|
assert_eq!(expected, actual);
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ enum HealthResponse {
|
|||||||
|
|
||||||
pub struct HealthApi;
|
pub struct HealthApi;
|
||||||
|
|
||||||
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")]
|
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
|
||||||
impl HealthApi {
|
impl HealthApi {
|
||||||
#[oai(path = "/", method = "get")]
|
#[oai(path = "/", method = "get")]
|
||||||
async fn health_check(&self) -> HealthResponse {
|
async fn health_check(&self) -> HealthResponse {
|
||||||
@ -23,7 +23,7 @@ impl HealthApi {
|
|||||||
async fn health_check_works() {
|
async fn health_check_works() {
|
||||||
let app = crate::get_test_app(None).await;
|
let app = crate::get_test_app(None).await;
|
||||||
let cli = poem::test::TestClient::new(app);
|
let cli = poem::test::TestClient::new(app);
|
||||||
let resp = cli.get("/v1/health-check").send().await;
|
let resp = cli.get("/v1/api/health-check").send().await;
|
||||||
resp.assert_status_is_ok();
|
resp.assert_status_is_ok();
|
||||||
resp.assert_text("").await;
|
resp.assert_text("").await;
|
||||||
}
|
}
|
@ -6,13 +6,19 @@ pub use health::HealthApi;
|
|||||||
mod version;
|
mod version;
|
||||||
pub use version::VersionApi;
|
pub use version::VersionApi;
|
||||||
|
|
||||||
|
mod errors;
|
||||||
|
|
||||||
|
mod auth;
|
||||||
|
pub use auth::AuthApi;
|
||||||
|
|
||||||
#[derive(Tags)]
|
#[derive(Tags)]
|
||||||
enum ApiCategory {
|
enum ApiCategory {
|
||||||
|
Auth,
|
||||||
Health,
|
Health,
|
||||||
Version,
|
Version,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct Api;
|
pub struct Api;
|
||||||
|
|
||||||
#[OpenApi]
|
#[OpenApi]
|
||||||
impl Api {}
|
impl Api {}
|
@ -25,7 +25,7 @@ enum VersionResponse {
|
|||||||
|
|
||||||
pub struct VersionApi;
|
pub struct VersionApi;
|
||||||
|
|
||||||
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")]
|
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
|
||||||
impl VersionApi {
|
impl VersionApi {
|
||||||
#[oai(path = "/", method = "get")]
|
#[oai(path = "/", method = "get")]
|
||||||
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
|
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
|
||||||
@ -38,7 +38,7 @@ impl VersionApi {
|
|||||||
async fn version_works() {
|
async fn version_works() {
|
||||||
let app = crate::get_test_app(None).await;
|
let app = crate::get_test_app(None).await;
|
||||||
let cli = poem::test::TestClient::new(app);
|
let cli = poem::test::TestClient::new(app);
|
||||||
let resp = cli.get("/v1/version").send().await;
|
let resp = cli.get("/v1/api/version").send().await;
|
||||||
resp.assert_status_is_ok();
|
resp.assert_status_is_ok();
|
||||||
let json = resp.json().await;
|
let json = resp.json().await;
|
||||||
let json_value = json.value();
|
let json_value = json.value();
|
@ -1,4 +1,4 @@
|
|||||||
use sqlx::ConnectOptions;
|
use gejdr_core::database::Database;
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||||
pub struct Settings {
|
pub struct Settings {
|
||||||
@ -13,16 +13,8 @@ pub struct Settings {
|
|||||||
impl Settings {
|
impl Settings {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn web_address(&self) -> String {
|
pub fn web_address(&self) -> String {
|
||||||
if self.debug {
|
|
||||||
format!(
|
|
||||||
"{}:{}",
|
|
||||||
self.application.base_url.clone(),
|
|
||||||
self.application.port
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
self.application.base_url.clone()
|
self.application.base_url.clone()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Multipurpose function that helps detect the current
|
/// Multipurpose function that helps detect the current
|
||||||
/// environment the application is running in using the
|
/// environment the application is running in using the
|
||||||
@ -66,7 +58,7 @@ impl Settings {
|
|||||||
))
|
))
|
||||||
.add_source(
|
.add_source(
|
||||||
config::Environment::with_prefix("APP")
|
config::Environment::with_prefix("APP")
|
||||||
.prefix_separator("_")
|
.prefix_separator("__")
|
||||||
.separator("__"),
|
.separator("__"),
|
||||||
)
|
)
|
||||||
.build()?;
|
.build()?;
|
||||||
@ -84,35 +76,6 @@ pub struct ApplicationSettings {
|
|||||||
pub protocol: String,
|
pub protocol: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
|
||||||
pub struct Database {
|
|
||||||
pub host: String,
|
|
||||||
pub port: u16,
|
|
||||||
pub name: String,
|
|
||||||
pub user: String,
|
|
||||||
pub password: String,
|
|
||||||
pub require_ssl: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Database {
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
|
|
||||||
let ssl_mode = if self.require_ssl {
|
|
||||||
sqlx::postgres::PgSslMode::Require
|
|
||||||
} else {
|
|
||||||
sqlx::postgres::PgSslMode::Prefer
|
|
||||||
};
|
|
||||||
sqlx::postgres::PgConnectOptions::new()
|
|
||||||
.host(&self.host)
|
|
||||||
.username(&self.user)
|
|
||||||
.password(&self.password)
|
|
||||||
.port(self.port)
|
|
||||||
.ssl_mode(ssl_mode)
|
|
||||||
.database(&self.name)
|
|
||||||
.log_statements(tracing::log::LevelFilter::Trace)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq)]
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
pub enum Environment {
|
pub enum Environment {
|
||||||
Development,
|
Development,
|
||||||
@ -167,8 +130,8 @@ pub struct EmailSettings {
|
|||||||
|
|
||||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||||
pub struct Discord {
|
pub struct Discord {
|
||||||
client_id: String,
|
pub client_id: String,
|
||||||
client_secret: String,
|
pub client_secret: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@ -235,7 +198,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn web_address_works() {
|
fn web_address_works() {
|
||||||
let mut settings = Settings {
|
let settings = Settings {
|
||||||
debug: false,
|
debug: false,
|
||||||
application: ApplicationSettings {
|
application: ApplicationSettings {
|
||||||
base_url: "127.0.0.1".to_string(),
|
base_url: "127.0.0.1".to_string(),
|
||||||
@ -244,10 +207,7 @@ mod tests {
|
|||||||
},
|
},
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let expected_no_debug = "127.0.0.1".to_string();
|
let expected = "127.0.0.1".to_string();
|
||||||
let expected_debug = "127.0.0.1:3000".to_string();
|
assert_eq!(expected, settings.web_address());
|
||||||
assert_eq!(expected_no_debug, settings.web_address());
|
|
||||||
settings.debug = true;
|
|
||||||
assert_eq!(expected_debug, settings.web_address());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,35 +1,36 @@
|
|||||||
use poem::middleware::Cors;
|
use gejdr_core::sqlx;
|
||||||
|
|
||||||
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
|
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
|
||||||
|
use poem::middleware::{CookieJarManagerEndpoint, Cors};
|
||||||
|
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
|
||||||
use poem::{EndpointExt, Route};
|
use poem::{EndpointExt, Route};
|
||||||
use poem_openapi::OpenApiService;
|
use poem_openapi::OpenApiService;
|
||||||
|
|
||||||
|
use crate::oauth::DiscordOauthProvider;
|
||||||
|
use crate::route::AuthApi;
|
||||||
use crate::{
|
use crate::{
|
||||||
route::{Api, HealthApi, VersionApi},
|
route::{Api, HealthApi, VersionApi},
|
||||||
settings::Settings,
|
settings::Settings,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[must_use]
|
|
||||||
pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgres::PgPool {
|
|
||||||
tracing::event!(
|
|
||||||
target: "startup",
|
|
||||||
tracing::Level::INFO,
|
|
||||||
"connecting to database with configuration {:?}",
|
|
||||||
settings.clone()
|
|
||||||
);
|
|
||||||
sqlx::postgres::PgPoolOptions::new()
|
|
||||||
.acquire_timeout(std::time::Duration::from_secs(2))
|
|
||||||
.connect_lazy_with(settings.get_connect_options())
|
|
||||||
}
|
|
||||||
|
|
||||||
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
|
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
|
||||||
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>;
|
pub type App = AddDataEndpoint<
|
||||||
|
AddDataEndpoint<
|
||||||
|
AddDataEndpoint<
|
||||||
|
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
|
||||||
|
DiscordOauthProvider,
|
||||||
|
>,
|
||||||
|
sqlx::Pool<sqlx::Postgres>,
|
||||||
|
>,
|
||||||
|
Settings,
|
||||||
|
>;
|
||||||
|
|
||||||
pub struct Application {
|
pub struct Application {
|
||||||
server: Server,
|
server: Server,
|
||||||
app: poem::Route,
|
app: poem::Route,
|
||||||
port: u16,
|
port: u16,
|
||||||
database: sqlx::postgres::PgPool,
|
database: sqlx::postgres::PgPool,
|
||||||
settings: Settings,
|
pub settings: Settings,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RunnableApplication {
|
pub struct RunnableApplication {
|
||||||
@ -61,6 +62,8 @@ impl From<Application> for RunnableApplication {
|
|||||||
let app = val
|
let app = val
|
||||||
.app
|
.app
|
||||||
.with(Cors::new())
|
.with(Cors::new())
|
||||||
|
.with(CookieSession::new(CookieConfig::default().secure(true)))
|
||||||
|
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
|
||||||
.data(val.database)
|
.data(val.database)
|
||||||
.data(val.settings);
|
.data(val.settings);
|
||||||
let server = val.server;
|
let server = val.server;
|
||||||
@ -74,16 +77,16 @@ impl Application {
|
|||||||
test_pool: Option<sqlx::postgres::PgPool>,
|
test_pool: Option<sqlx::postgres::PgPool>,
|
||||||
) -> sqlx::postgres::PgPool {
|
) -> sqlx::postgres::PgPool {
|
||||||
let database_pool =
|
let database_pool =
|
||||||
test_pool.map_or_else(|| get_connection_pool(&settings.database), |pool| pool);
|
test_pool.map_or_else(|| settings.database.get_connection_pool(), |pool| pool);
|
||||||
if !cfg!(test) {
|
if !cfg!(test) {
|
||||||
migrate_database(&database_pool).await;
|
gejdr_core::database::Database::migrate(&database_pool).await;
|
||||||
}
|
}
|
||||||
database_pool
|
database_pool
|
||||||
}
|
}
|
||||||
|
|
||||||
fn setup_app(settings: &Settings) -> poem::Route {
|
fn setup_app(settings: &Settings) -> poem::Route {
|
||||||
let api_service = OpenApiService::new(
|
let api_service = OpenApiService::new(
|
||||||
(Api, HealthApi, VersionApi),
|
(Api, AuthApi, HealthApi, VersionApi),
|
||||||
settings.application.clone().name,
|
settings.application.clone().name,
|
||||||
settings.application.clone().version,
|
settings.application.clone().version,
|
||||||
);
|
);
|
||||||
@ -104,6 +107,7 @@ impl Application {
|
|||||||
});
|
});
|
||||||
poem::Server::new(tcp_listener)
|
poem::Server::new(tcp_listener)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn build(
|
pub async fn build(
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
test_pool: Option<sqlx::postgres::PgPool>,
|
test_pool: Option<sqlx::postgres::PgPool>,
|
||||||
@ -133,10 +137,3 @@ impl Application {
|
|||||||
self.port
|
self.port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn migrate_database(pool: &sqlx::postgres::PgPool) {
|
|
||||||
sqlx::migrate!()
|
|
||||||
.run(pool)
|
|
||||||
.await
|
|
||||||
.expect("Failed to migrate the database");
|
|
||||||
}
|
|
1
gejdr-bot/.gitignore
vendored
Normal file
1
gejdr-bot/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
6
gejdr-bot/Cargo.toml
Normal file
6
gejdr-bot/Cargo.toml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
[package]
|
||||||
|
name = "gejdr-bot"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
3
gejdr-bot/src/main.rs
Normal file
3
gejdr-bot/src/main.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
fn main() {
|
||||||
|
println!("Hello, world!");
|
||||||
|
}
|
1
gejdr-core/.gitignore
vendored
Normal file
1
gejdr-core/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
16
gejdr-core/Cargo.toml
Normal file
16
gejdr-core/Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
[package]
|
||||||
|
name = "gejdr-core"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
chrono = { version = "0.4.38", features = ["serde"] }
|
||||||
|
serde = "1.0.215"
|
||||||
|
tracing = "0.1.40"
|
||||||
|
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
||||||
|
uuid = { version = "1.11.0", features = ["v4", "serde"] }
|
||||||
|
|
||||||
|
[dependencies.sqlx]
|
||||||
|
version = "0.8.2"
|
||||||
|
default-features = false
|
||||||
|
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
|
3
gejdr-core/migrations/20240809173617_users.down.sql
Normal file
3
gejdr-core/migrations/20240809173617_users.down.sql
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
-- Add down migration script here
|
||||||
|
DROP TABLE IF EXISTS public.users;
|
||||||
|
DROP EXTENSION IF EXISTS "uuid-ossp";
|
15
gejdr-core/migrations/20240809173617_users.up.sql
Normal file
15
gejdr-core/migrations/20240809173617_users.up.sql
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
-- Add up migration script here
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS public.users
|
||||||
|
(
|
||||||
|
id character varying(255) NOT NULL,
|
||||||
|
username character varying(255) NOT NULL,
|
||||||
|
email character varying(255),
|
||||||
|
avatar character varying(511),
|
||||||
|
name character varying(255),
|
||||||
|
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
PRIMARY KEY (id),
|
||||||
|
CONSTRAINT users_email_unique UNIQUE (email)
|
||||||
|
);
|
50
gejdr-core/src/database.rs
Normal file
50
gejdr-core/src/database.rs
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
use sqlx::ConnectOptions;
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||||
|
pub struct Database {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub name: String,
|
||||||
|
pub user: String,
|
||||||
|
pub password: String,
|
||||||
|
pub require_ssl: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database {
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
|
||||||
|
let ssl_mode = if self.require_ssl {
|
||||||
|
sqlx::postgres::PgSslMode::Require
|
||||||
|
} else {
|
||||||
|
sqlx::postgres::PgSslMode::Prefer
|
||||||
|
};
|
||||||
|
sqlx::postgres::PgConnectOptions::new()
|
||||||
|
.host(&self.host)
|
||||||
|
.username(&self.user)
|
||||||
|
.password(&self.password)
|
||||||
|
.port(self.port)
|
||||||
|
.ssl_mode(ssl_mode)
|
||||||
|
.database(&self.name)
|
||||||
|
.log_statements(tracing::log::LevelFilter::Trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[must_use]
|
||||||
|
pub fn get_connection_pool(&self) -> sqlx::postgres::PgPool {
|
||||||
|
tracing::event!(
|
||||||
|
target: "startup",
|
||||||
|
tracing::Level::INFO,
|
||||||
|
"connecting to database with configuration {:?}",
|
||||||
|
self.clone()
|
||||||
|
);
|
||||||
|
sqlx::postgres::PgPoolOptions::new()
|
||||||
|
.acquire_timeout(std::time::Duration::from_secs(2))
|
||||||
|
.connect_lazy_with(self.get_connect_options())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn migrate(pool: &sqlx::PgPool) {
|
||||||
|
sqlx::migrate!()
|
||||||
|
.run(pool)
|
||||||
|
.await
|
||||||
|
.expect("Failed to migrate the database");
|
||||||
|
}
|
||||||
|
}
|
4
gejdr-core/src/lib.rs
Normal file
4
gejdr-core/src/lib.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
pub mod database;
|
||||||
|
pub mod models;
|
||||||
|
pub mod telemetry;
|
||||||
|
pub use sqlx;
|
396
gejdr-core/src/models/accounts.rs
Normal file
396
gejdr-core/src/models/accounts.rs
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
use sqlx::PgPool;
|
||||||
|
|
||||||
|
type Timestampz = chrono::DateTime<chrono::Utc>;
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, PartialEq, Eq, Debug, Clone, Default)]
|
||||||
|
pub struct RemoteUser {
|
||||||
|
id: String,
|
||||||
|
username: String,
|
||||||
|
global_name: Option<String>,
|
||||||
|
email: Option<String>,
|
||||||
|
avatar: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteUser {
|
||||||
|
/// Refresh in database the row related to the remote user. Maybe
|
||||||
|
/// create a row for this user if needed.
|
||||||
|
pub async fn refresh_in_database(self, pool: &PgPool) -> Result<User, sqlx::Error> {
|
||||||
|
match User::find(pool, &self.id).await? {
|
||||||
|
Some(local_user) => local_user.update_from_remote(self).update(pool).await,
|
||||||
|
None => User::from(self).save(pool).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq, Default, Clone)]
|
||||||
|
pub struct User {
|
||||||
|
pub id: String,
|
||||||
|
pub username: String,
|
||||||
|
pub email: Option<String>,
|
||||||
|
pub avatar: Option<String>,
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub created_at: Timestampz,
|
||||||
|
pub last_updated: Timestampz,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RemoteUser> for User {
|
||||||
|
fn from(value: RemoteUser) -> Self {
|
||||||
|
Self {
|
||||||
|
id: value.id,
|
||||||
|
username: value.username,
|
||||||
|
email: value.email,
|
||||||
|
avatar: value.avatar,
|
||||||
|
name: value.global_name,
|
||||||
|
created_at: chrono::offset::Utc::now(),
|
||||||
|
last_updated: chrono::offset::Utc::now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<RemoteUser> for User {
|
||||||
|
#[allow(clippy::suspicious_operation_groupings)]
|
||||||
|
fn eq(&self, other: &RemoteUser) -> bool {
|
||||||
|
self.id == other.id
|
||||||
|
&& self.username == other.username
|
||||||
|
&& self.email == other.email
|
||||||
|
&& self.avatar == other.avatar
|
||||||
|
&& self.name == other.global_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq<User> for RemoteUser {
|
||||||
|
fn eq(&self, other: &User) -> bool {
|
||||||
|
other == self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl User {
|
||||||
|
pub fn update_from_remote(self, from: RemoteUser) -> Self {
|
||||||
|
if self == from {
|
||||||
|
self
|
||||||
|
} else {
|
||||||
|
Self {
|
||||||
|
username: from.username,
|
||||||
|
email: from.email,
|
||||||
|
avatar: from.avatar,
|
||||||
|
name: from.global_name,
|
||||||
|
last_updated: chrono::offset::Utc::now(),
|
||||||
|
..self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn find(pool: &PgPool, id: &String) -> Result<Option<Self>, sqlx::Error> {
|
||||||
|
sqlx::query_as!(Self, r#"SELECT * FROM users WHERE id = $1"#, id)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn save(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
|
||||||
|
sqlx::query_as!(
|
||||||
|
Self,
|
||||||
|
r#"
|
||||||
|
INSERT INTO users (id, username, email, avatar, name, created_at, last_updated)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||||
|
RETURNING *
|
||||||
|
"#,
|
||||||
|
self.id,
|
||||||
|
self.username,
|
||||||
|
self.email,
|
||||||
|
self.avatar,
|
||||||
|
self.name,
|
||||||
|
self.created_at,
|
||||||
|
self.last_updated
|
||||||
|
)
|
||||||
|
.fetch_one(pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
|
||||||
|
sqlx::query_as!(
|
||||||
|
Self,
|
||||||
|
r#"
|
||||||
|
UPDATE users
|
||||||
|
SET username = $1, email = $2, avatar = $3, name = $4, last_updated = $5
|
||||||
|
WHERE id = $6
|
||||||
|
RETURNING *
|
||||||
|
"#,
|
||||||
|
self.username,
|
||||||
|
self.email,
|
||||||
|
self.avatar,
|
||||||
|
self.name,
|
||||||
|
self.last_updated,
|
||||||
|
self.id
|
||||||
|
)
|
||||||
|
.fetch_one(pool)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn save_or_update(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
|
||||||
|
if Self::find(pool, &self.id).await?.is_some() {
|
||||||
|
self.update(pool).await
|
||||||
|
} else {
|
||||||
|
self.save(pool).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete(pool: &PgPool, id: &String) -> Result<u64, sqlx::Error> {
|
||||||
|
let rows_affected = sqlx::query!("DELETE FROM users WHERE id = $1", id)
|
||||||
|
.execute(pool)
|
||||||
|
.await?
|
||||||
|
.rows_affected();
|
||||||
|
Ok(rows_affected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_remote_user_to_local_user() {
|
||||||
|
let remote = RemoteUser {
|
||||||
|
id: "user-id".into(),
|
||||||
|
username: "username".into(),
|
||||||
|
global_name: None,
|
||||||
|
email: Some("user@example.com".into()),
|
||||||
|
avatar: None,
|
||||||
|
};
|
||||||
|
let local: User = remote.into();
|
||||||
|
let expected = User {
|
||||||
|
id: "user-id".into(),
|
||||||
|
username: "username".into(),
|
||||||
|
email: Some("user@example.com".into()),
|
||||||
|
avatar: None,
|
||||||
|
name: None,
|
||||||
|
created_at: local.created_at,
|
||||||
|
last_updated: local.last_updated,
|
||||||
|
};
|
||||||
|
assert_eq!(expected, local);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn can_compare_remote_and_local_user() {
|
||||||
|
let remote_same = RemoteUser {
|
||||||
|
id: "user-id".into(),
|
||||||
|
username: "username".into(),
|
||||||
|
global_name: None,
|
||||||
|
email: Some("user@example.com".into()),
|
||||||
|
avatar: None,
|
||||||
|
};
|
||||||
|
let remote_different = RemoteUser {
|
||||||
|
id: "user-id".into(),
|
||||||
|
username: "username".into(),
|
||||||
|
global_name: None,
|
||||||
|
email: Some("user@example.com".into()),
|
||||||
|
avatar: Some("some-hash".into()),
|
||||||
|
};
|
||||||
|
let local = User {
|
||||||
|
id: "user-id".into(),
|
||||||
|
username: "username".into(),
|
||||||
|
email: Some("user@example.com".into()),
|
||||||
|
avatar: None,
|
||||||
|
name: None,
|
||||||
|
created_at: chrono::offset::Utc::now(),
|
||||||
|
last_updated: chrono::offset::Utc::now(),
|
||||||
|
};
|
||||||
|
assert_eq!(remote_same, local);
|
||||||
|
assert_ne!(remote_different, local);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test]
|
||||||
|
async fn add_new_remote_users_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let remote1 = RemoteUser {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let remote2 = RemoteUser {
|
||||||
|
id: "id2".into(),
|
||||||
|
username: "user2".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
remote1.refresh_in_database(&pool).await?;
|
||||||
|
remote2.refresh_in_database(&pool).await?;
|
||||||
|
let users = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, users.len());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test(fixtures("accounts"))]
|
||||||
|
async fn update_local_users_in_db_from_remote(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let users = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, users.len());
|
||||||
|
let remote1 = RemoteUser {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1-new".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let remote2 = RemoteUser {
|
||||||
|
id: "id2".into(),
|
||||||
|
username: "user2-new".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
remote1.refresh_in_database(&pool).await?;
|
||||||
|
remote2.refresh_in_database(&pool).await?;
|
||||||
|
let users = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, users.len());
|
||||||
|
users
|
||||||
|
.iter()
|
||||||
|
.for_each(|user| assert!(user.last_updated > user.created_at));
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn update_local_user_from_identical_remote_shouldnt_change_local() {
|
||||||
|
let remote = RemoteUser {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let local = User {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let new_local = local.clone().update_from_remote(remote);
|
||||||
|
assert_eq!(local, new_local);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn update_local_user_from_different_remote() {
|
||||||
|
let remote = RemoteUser {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user2".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let local = User {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let new_local = local.clone().update_from_remote(remote.clone());
|
||||||
|
assert_ne!(remote, local);
|
||||||
|
assert_eq!(remote, new_local);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test]
|
||||||
|
async fn save_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let user = User {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
user.save(&pool).await?;
|
||||||
|
let users = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(1, users.len());
|
||||||
|
assert_eq!(Some(user), users.first().cloned());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test(fixtures("accounts"))]
|
||||||
|
async fn update_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
|
||||||
|
.fetch_one(&pool)
|
||||||
|
.await?;
|
||||||
|
assert!(db_user.name.is_none());
|
||||||
|
let user = User {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
name: Some("Cool Name".into()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
user.update(&pool).await?;
|
||||||
|
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
|
||||||
|
.fetch_one(&pool)
|
||||||
|
.await?;
|
||||||
|
assert!(db_user.name.is_some());
|
||||||
|
assert_eq!(Some("Cool Name".to_string()), db_user.name);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test]
|
||||||
|
async fn save_or_update_saves_if_no_exist(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(0, rows.len());
|
||||||
|
let user = User {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
user.save_or_update(&pool).await?;
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(1, rows.len());
|
||||||
|
let db_user = rows.first();
|
||||||
|
assert_eq!(Some(user), db_user.cloned());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test(fixtures("accounts"))]
|
||||||
|
async fn save_or_update_updates_if_exists(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, rows.len());
|
||||||
|
let user = User {
|
||||||
|
id: "id1".into(),
|
||||||
|
username: "user1".into(),
|
||||||
|
name: Some("Cool Nam".into()),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
user.save_or_update(&pool).await?;
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, rows.len());
|
||||||
|
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
|
||||||
|
.fetch_one(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(user.name, db_user.name);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test(fixtures("accounts"))]
|
||||||
|
async fn delete_removes_account_from_db(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, rows.len());
|
||||||
|
let id = "id1".to_string();
|
||||||
|
let deletions = User::delete(&pool, &id).await?;
|
||||||
|
assert_eq!(1, deletions);
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(1, rows.len());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[sqlx::test(fixtures("accounts"))]
|
||||||
|
async fn delete_with_wrong_id_shouldnt_delete_anything(pool: sqlx::PgPool) -> sqlx::Result<()> {
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, rows.len());
|
||||||
|
let id = "invalid".to_string();
|
||||||
|
let deletions = User::delete(&pool, &id).await?;
|
||||||
|
assert_eq!(0, deletions);
|
||||||
|
let rows = sqlx::query_as!(User, "SELECT * FROM users")
|
||||||
|
.fetch_all(&pool)
|
||||||
|
.await?;
|
||||||
|
assert_eq!(2, rows.len());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
2
gejdr-core/src/models/fixtures/accounts.sql
Normal file
2
gejdr-core/src/models/fixtures/accounts.sql
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
INSERT INTO users (id, username) VALUES ('id1', 'user1');
|
||||||
|
INSERT INTO users (id, username) VALUES ('id2', 'user2');
|
1
gejdr-core/src/models/mod.rs
Normal file
1
gejdr-core/src/models/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod accounts;
|
@ -5,7 +5,7 @@ pub fn get_subscriber(debug: bool) -> impl tracing::Subscriber + Send + Sync {
|
|||||||
let env_filter = if debug { "debug" } else { "info" }.to_string();
|
let env_filter = if debug { "debug" } else { "info" }.to_string();
|
||||||
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
|
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter));
|
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter));
|
||||||
let stdout_log = tracing_subscriber::fmt::layer().pretty();
|
let stdout_log = tracing_subscriber::fmt::layer().pretty().with_test_writer();
|
||||||
let subscriber = tracing_subscriber::Registry::default()
|
let subscriber = tracing_subscriber::Registry::default()
|
||||||
.with(env_filter)
|
.with(env_filter)
|
||||||
.with(stdout_log);
|
.with(stdout_log);
|
50
justfile
50
justfile
@ -1,10 +1,7 @@
|
|||||||
default: run
|
mod backend
|
||||||
|
mod docker
|
||||||
|
|
||||||
prepare:
|
default: lint
|
||||||
cargo sqlx prepare
|
|
||||||
|
|
||||||
migrate:
|
|
||||||
sqlx migrate run
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
cargo fmt --all
|
cargo fmt --all
|
||||||
@ -12,17 +9,12 @@ format:
|
|||||||
format-check:
|
format-check:
|
||||||
cargo fmt --check --all
|
cargo fmt --check --all
|
||||||
|
|
||||||
|
migrate:
|
||||||
|
sqlx migrate run --source gejdr-core/migrations
|
||||||
|
|
||||||
build:
|
build:
|
||||||
cargo auditable build
|
cargo auditable build --bin gejdr-backend
|
||||||
|
cargo auditable build --bin gejdr-bot
|
||||||
build-release:
|
|
||||||
cargo auditable build --release
|
|
||||||
|
|
||||||
run: docker-start
|
|
||||||
cargo auditable run
|
|
||||||
|
|
||||||
run-no-docker:
|
|
||||||
cargo auditable run
|
|
||||||
|
|
||||||
lint:
|
lint:
|
||||||
cargo clippy --all-targets
|
cargo clippy --all-targets
|
||||||
@ -31,19 +23,22 @@ msrv:
|
|||||||
cargo msrv verify
|
cargo msrv verify
|
||||||
|
|
||||||
release-build:
|
release-build:
|
||||||
cargo auditable build --release
|
cargo auditable build --release --bin gejdr-backend
|
||||||
|
cargo auditable build --release --bin gejdr-bot
|
||||||
|
|
||||||
release-run:
|
release-run:
|
||||||
cargo auditable run --release
|
cargo auditable run --release
|
||||||
|
|
||||||
audit: build
|
audit: build
|
||||||
cargo audit bin target/debug/gege-jdr-backend
|
cargo audit bin target/debug/gejdr-backend
|
||||||
|
cargo audit bin target/debug/gejdr-bot
|
||||||
|
|
||||||
audit-release: build-release
|
audit-release:
|
||||||
cargo audit bin target/release/gege-jdr-backend
|
cargo audit bin target/release/gejdr-backend
|
||||||
|
cargo audit bin target/release/gejdr-bot
|
||||||
|
|
||||||
test:
|
test:
|
||||||
cargo test
|
cargo test --all-targets --all
|
||||||
|
|
||||||
coverage:
|
coverage:
|
||||||
mkdir -p coverage
|
mkdir -p coverage
|
||||||
@ -55,17 +50,8 @@ coverage-ci:
|
|||||||
|
|
||||||
check-all: format-check lint msrv coverage audit
|
check-all: format-check lint msrv coverage audit
|
||||||
|
|
||||||
docker-build:
|
# docker-build:
|
||||||
nix build .#docker
|
# nix build .#docker
|
||||||
|
|
||||||
docker-start:
|
|
||||||
docker compose -f docker/compose.dev.yml up -d
|
|
||||||
|
|
||||||
docker-stop:
|
|
||||||
docker compose -f docker/compose.dev.yml down
|
|
||||||
|
|
||||||
docker-logs:
|
|
||||||
docker compose -f docker/compose.dev.yml logs -f
|
|
||||||
|
|
||||||
## Local Variables:
|
## Local Variables:
|
||||||
## mode: makefile
|
## mode: makefile
|
||||||
|
@ -1,5 +0,0 @@
|
|||||||
-- Add down migration script here
|
|
||||||
ALTER TABLE IF EXISTS public.sessions DROP CONSTRAINT IF EXISTS sessions_user_id_users_fk;
|
|
||||||
DROP TABLE IF EXISTS public.sessions;
|
|
||||||
DROP TABLE IF EXISTS public.users;
|
|
||||||
DROP EXTENSION IF EXISTS "uuid-ossp";
|
|
@ -1,29 +0,0 @@
|
|||||||
-- Add up migration script here
|
|
||||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS public.users
|
|
||||||
(
|
|
||||||
id uuid NOT NULL DEFAULT uuid_generate_v4(),
|
|
||||||
email character varying(255) NOT NULL,
|
|
||||||
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
|
||||||
PRIMARY KEY (id),
|
|
||||||
CONSTRAINT users_email_unique UNIQUE (email)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS public.sessions
|
|
||||||
(
|
|
||||||
id uuid NOT NULL DEFAULT uuid_generate_v4(),
|
|
||||||
user_id uuid NOT NULL,
|
|
||||||
session_id character varying NOT NULL,
|
|
||||||
expires_at timestamp with time zone NOT NULL,
|
|
||||||
PRIMARY KEY (id),
|
|
||||||
CONSTRAINT sessions_user_id_unique UNIQUE (user_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
ALTER TABLE IF EXISTS public.sessions
|
|
||||||
ADD CONSTRAINT sessions_user_id_users_fk FOREIGN KEY (user_id)
|
|
||||||
REFERENCES public.users (id) MATCH SIMPLE
|
|
||||||
ON UPDATE CASCADE
|
|
||||||
ON DELETE CASCADE
|
|
||||||
NOT VALID;
|
|
Loading…
Reference in New Issue
Block a user