feat: OAuth implementation with Discord
All checks were successful
CI / tests (push) Successful in 17m15s
CI / tests (pull_request) Successful in 16m18s

This commit separates the core features of géjdr from the backend as
these will also be used by the bot in the future.

This commit also updates the dependencies of the project. It also
removes the dependency lettre as well as the mailpit docker service
for developers as it appears clearer this project won’t send emails
anytime soon.

The publication of a docker image is also postponed until later.
This commit is contained in:
Lucien Cartier-Tilet 2024-11-23 10:00:53 +01:00
parent ae10711e41
commit d9e29b62b7
Signed by: phundrak
GPG Key ID: 347803E8073EACE0
52 changed files with 2642 additions and 734 deletions

View File

@ -1,13 +1,12 @@
;;; Directory Local Variables -*- no-byte-compile: t -*- ;;; Directory Local Variables -*- no-byte-compile: t -*-
;;; For more information see (info "(emacs) Directory Variables") ;;; For more information see (info "(emacs) Directory Variables")
((sql-mode ((rustic-mode . ((fill-column . 80)))
. (sql-mode . ((eval . (progn
((eval . (progn
(setq-local lsp-sqls-connections (setq-local lsp-sqls-connections
`(((driver . "postgresql") `(((driver . "postgresql")
(dataSourceName . (dataSourceName \,
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable" (format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
(getenv "DB_HOST") (getenv "DB_HOST")
(getenv "DB_PORT") (getenv "DB_PORT")
(getenv "DB_USER") (getenv "DB_USER")

View File

@ -45,8 +45,6 @@ jobs:
run: nix develop --command -- just lint run: nix develop --command -- just lint
- name: Audit - name: Audit
run: nix develop --command -- just audit run: nix develop --command -- just audit
- name: Minimum supported Rust version check
run: nix develop --command -- just msrv
- name: Tests - name: Tests
run: nix develop --command -- just test run: nix develop --command -- just test
- name: Coverage - name: Coverage

View File

@ -0,0 +1,64 @@
{
"db_name": "PostgreSQL",
"query": "\nINSERT INTO users (id, username, email, avatar, name, created_at, last_updated)\nVALUES ($1, $2, $3, $4, $5, $6, $7)\nRETURNING *\n",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Timestamptz",
"Timestamptz"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "24bbd63a324e36f4a0c559c44909621cc21b493b4c9ae4c14e30c99a6d8072bd"
}

View File

@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "DELETE FROM users WHERE id = $1",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text"
]
},
"nullable": []
},
"hash": "50293c2e54af11d4c2a553e29b671cef087a159c6ee7182d8ca929ecb748f3b7"
}

View File

@ -0,0 +1,58 @@
{
"db_name": "PostgreSQL",
"query": "SELECT * FROM users WHERE id = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "843923b9a0257cf80f1dff554e7dc8fdfc05f489328e8376513124dfb42996e3"
}

View File

@ -0,0 +1,63 @@
{
"db_name": "PostgreSQL",
"query": "\nUPDATE users\nSET username = $1, email = $2, avatar = $3, name = $4, last_updated = $5\nWHERE id = $6\nRETURNING *\n",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Timestamptz",
"Text"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "c36467a81bad236a0c1a8d3fc1b5f8efda9b9cf9eaab140c10246218f0240600"
}

View File

@ -3,3 +3,5 @@ out = ["Xml"]
target-dir = "coverage" target-dir = "coverage"
output-dir = "coverage" output-dir = "coverage"
fail-under = 60 fail-under = 60
exclude-files = ["target/*"]
run-types = ["AllTargets"]

View File

@ -4,3 +4,5 @@ skip-clean = true
target-dir = "coverage" target-dir = "coverage"
output-dir = "coverage" output-dir = "coverage"
fail-under = 60 fail-under = 60
exclude-files = ["target/*"]
run-types = ["AllTargets"]

1616
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,61 +1,8 @@
[package] [workspace]
name = "gege-jdr-backend"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["phundrak"]
rust-version = "1.78"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html members = [
"gejdr-core",
[lib] "gejdr-bot",
path = "src/lib.rs" "gejdr-backend",
[[bin]]
path = "src/main.rs"
name = "gege-jdr-backend"
[dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.0", features = ["yaml"] }
dotenvy = "0.15.7"
serde = "1.0.204"
serde_json = "1.0.120"
thiserror = "1.0.63"
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
[dependencies.lettre]
version = "0.11.7"
default-features = false
features = [
"builder",
"hostname",
"pool",
"rustls-tls",
"tokio1",
"tokio1-rustls-tls",
"smtp-transport"
] ]
resolver = "2"
[dependencies.poem]
version = "3.0.4"
default-features = false
features = [
"csrf",
"rustls",
"cookie",
"test"
]
[dependencies.poem-openapi]
version = "5.0.3"
features = ["chrono", "swagger-ui", "uuid"]
[dependencies.sqlx]
version = "0.8.0"
default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]

19
backend.just Normal file
View File

@ -0,0 +1,19 @@
default: run
prepare:
pushd gejdr-backend
cargo sqlx prepare
popd
build:
cargo auditable build
build-release:
cargo auditable build --release
run:
cargo auditable run --bin gejdr-backend
## Local Variables:
## mode: makefile
## End:

View File

@ -31,25 +31,6 @@ services:
depends_on: depends_on:
- db - db
# If you run GegeJdrBackend in production, DO NOT use mailpit.
# This tool is for testing only. Instead, you should use a real SMTP
# provider, such as Mailgun, Mailwhale, or Postal.
mailpit:
image: axllent/mailpit:latest
restart: unless-stopped
container_name: gege-jdr-backend-mailpit
ports:
- 127.0.0.1:8025:8025 # WebUI
- 127.0.0.1:1025:1025 # SMTP
volumes:
- gege_jdr_backend_mailpit:/data
environment:
MP_MAX_MESSAGES: 5000
MP_DATABASE: /data/mailpit.db
MP_SMTP_AUTH_ACCEPT_ANY: 1
MP_SMTP_AUTH_ALLOW_INSECURE: 1
volumes: volumes:
gege_jdr_backend_db_data: gege_jdr_backend_db_data:
gege_jdr_backend_pgadmin_data: gege_jdr_backend_pgadmin_data:
gege_jdr_backend_mailpit:

14
docker/mod.just Normal file
View File

@ -0,0 +1,14 @@
default: start
start:
docker compose -f compose.dev.yml up -d
stop:
docker compose -f compose.dev.yml down
logs:
docker compose -f compose.dev.yml logs -f
## Local Variables:
## mode: makefile
## End:

12
flake.lock generated
View File

@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1732014248, "lastModified": 1736344531,
"narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=", "narHash": "sha256-8YVQ9ZbSfuUk2bUf2KRj60NRraLPKPS0Q4QFTbc+c2c=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "23e89b7da85c3640bbc2173fe04f4bd114342367", "rev": "bffc22eb12172e6db3c5dde9e3e5628f8e3e7912",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2" "nixpkgs": "nixpkgs_2"
}, },
"locked": { "locked": {
"lastModified": 1732242723, "lastModified": 1736476219,
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=", "narHash": "sha256-+qyv3QqdZCdZ3cSO/cbpEY6tntyYjfe1bB12mdpNFaY=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a", "rev": "de30cc5963da22e9742bbbbb9a3344570ed237b9",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -18,7 +18,7 @@
rustc = rustVersion; rustc = rustVersion;
}; };
appName = "gege-jdr-backend"; appName = "gejdr-backend";
appRustBuild = rustPlatform.buildRustPackage { appRustBuild = rustPlatform.buildRustPackage {
pname = appName; pname = appName;
@ -49,7 +49,6 @@
cargo-audit cargo-audit
cargo-auditable cargo-auditable
cargo-tarpaulin cargo-tarpaulin
cargo-msrv
just just
rust-analyzer rust-analyzer
(rustVersion.override { extensions = [ "rust-src" ]; }) (rustVersion.override { extensions = [ "rust-src" ]; })

View File

@ -0,0 +1,6 @@
[all]
out = ["Xml"]
target-dir = "coverage"
output-dir = "coverage"
fail-under = 60
exclude-files = ["target/*"]

View File

@ -0,0 +1,7 @@
[all]
out = ["Html", "Lcov"]
skip-clean = true
target-dir = "coverage"
output-dir = "coverage"
fail-under = 60
exclude-files = ["target/*"]

43
gejdr-backend/Cargo.toml Normal file
View File

@ -0,0 +1,43 @@
[package]
name = "gejdr-backend"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
path = "src/lib.rs"
[[bin]]
path = "src/main.rs"
name = "gejdr-backend"
[dependencies]
gejdr-core = { path = "../gejdr-core" }
chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.1", features = ["yaml"] }
dotenvy = "0.15.7"
oauth2 = "4.4.2"
quote = "1.0.37"
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
serde = "1.0.215"
serde_json = "1.0.133"
thiserror = "1.0.69"
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] }
[dependencies.poem]
version = "3.1.3"
default-features = false
features = ["csrf", "rustls", "cookie", "test", "session"]
[dependencies.poem-openapi]
version = "5.1.2"
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }

View File

@ -3,5 +3,5 @@ debug: true
application: application:
protocol: http protocol: http
host: 127.0.0.1 host: localhost
base_url: http://127.0.0.1:3000 base_url: http://localhost:3000

View File

@ -0,0 +1,55 @@
use super::{ApiError, DiscordErrorResponse};
use gejdr_core::models::accounts::RemoteUser;
static DISCORD_URL: &str = "https://discord.com/api/v10/";
pub async fn get_user_profile(token: &str) -> Result<RemoteUser, ApiError> {
let client = reqwest::Client::new();
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {token}").parse().unwrap(),
);
let response = client
.get(format!("{DISCORD_URL}/users/@me"))
.headers(headers)
.send()
.await;
match response {
Ok(resp) => {
if resp.status().is_success() {
resp.json::<RemoteUser>()
.await
.map_err(std::convert::Into::into)
} else {
let error_response = resp.json::<DiscordErrorResponse>().await;
match error_response {
Ok(val) => Err(ApiError::Api(val)),
Err(e) => Err(ApiError::Reqwest(e)),
}
}
}
Err(e) => Err(ApiError::Reqwest(e)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn user_profile_invalid_token_results_in_401() {
let res = get_user_profile("invalid").await;
assert!(res.is_err());
let err = res.err().unwrap();
println!("Error: {err:?}");
let expected = DiscordErrorResponse {
code: 0,
message: "401: Unauthorized".into(),
};
assert!(matches!(ApiError::Api(expected), _err));
}
// TODO: Find a way to mock calls to discord.com API with a
// successful reply
}

View File

@ -0,0 +1,41 @@
use reqwest::Error as ReqwestError;
use std::fmt::{self, Display};
use thiserror::Error;
pub mod discord;
#[derive(Debug, serde::Deserialize, PartialEq, Eq)]
pub struct DiscordErrorResponse {
pub message: String,
pub code: u16,
}
impl Display for DiscordErrorResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "DiscordErrorResponse: {} ({})", self.message, self.code)
}
}
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Reqwest error: {0}")]
Reqwest(#[from] ReqwestError),
#[error("API Error: {0}")]
Api(DiscordErrorResponse),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display_discord_error_response() {
let error = DiscordErrorResponse {
message: "Message".into(),
code: 42,
};
let error_str = error.to_string();
let expected = "DiscordErrorResponse: Message (42)".to_string();
assert_eq!(expected, error_str);
}
}

View File

@ -0,0 +1,14 @@
use thiserror::Error;
#[allow(dead_code)]
#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
Sql(#[from] gejdr_core::sqlx::Error),
#[error("OAuth token error: {0}")]
TokenError(String),
#[error("Unauthorized")]
Unauthorized,
#[error("Attempted to get a value, none found")]
OptionError,
}

View File

@ -5,10 +5,14 @@
#![allow(clippy::unused_async)] #![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling #![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
pub mod route; use gejdr_core::sqlx;
pub mod settings;
pub mod startup; mod api_wrapper;
pub mod telemetry; mod errors;
mod oauth;
mod route;
mod settings;
mod startup;
type MaybeListener = Option<poem::listener::TcpListener<String>>; type MaybeListener = Option<poem::listener::TcpListener<String>>;
@ -16,8 +20,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
dotenvy::dotenv().ok(); dotenvy::dotenv().ok();
let settings = settings::Settings::new().expect("Failed to read settings"); let settings = settings::Settings::new().expect("Failed to read settings");
if !cfg!(test) { if !cfg!(test) {
let subscriber = telemetry::get_subscriber(settings.clone().debug); let subscriber = gejdr_core::telemetry::get_subscriber(settings.clone().debug);
telemetry::init_subscriber(subscriber); gejdr_core::telemetry::init_subscriber(subscriber);
} }
tracing::event!( tracing::event!(
target: "gege-jdr-backend", target: "gege-jdr-backend",
@ -29,8 +33,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
tracing::event!( tracing::event!(
target: "gege-jdr-backend", target: "gege-jdr-backend",
tracing::Level::INFO, tracing::Level::INFO,
"Listening on http://127.0.0.1:{}/", "Listening on {}",
application.port() application.settings.web_address()
); );
application application
} }

View File

@ -1,5 +1,5 @@
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), std::io::Error> { async fn main() -> Result<(), std::io::Error> {
gege_jdr_backend::run(None).await gejdr_backend::run(None).await
} }

View File

@ -0,0 +1,62 @@
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
};
use reqwest::Url;
use crate::{errors::ApiError, settings::Settings};
use super::OauthProvider;
#[derive(Debug, Clone)]
pub struct DiscordOauthProvider {
client: BasicClient,
}
impl DiscordOauthProvider {
pub fn new(settings: &Settings) -> Self {
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
.expect("Invalid token endpoint URL");
let revocation_url =
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
.expect("Invalid revocation URL");
let client = BasicClient::new(
ClientId::new(settings.discord.client_id.clone()),
Some(ClientSecret::new(settings.discord.client_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
.set_revocation_uri(revocation_url);
Self { client }
}
}
impl OauthProvider for DiscordOauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
.set_pkce_challenge(pkce_challenge)
.url();
(auth_url, csrf_token, pkce_verifier)
}
async fn token(
&self,
code: String,
verifier: PkceCodeVerifier,
) -> Result<super::Token, ApiError> {
self.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(verifier)
.request_async(oauth2::reqwest::async_http_client)
.await
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
}
}

View File

@ -0,0 +1,17 @@
mod discord;
pub use discord::DiscordOauthProvider;
use oauth2::{
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
StandardTokenResponse,
};
use reqwest::Url;
use crate::errors::ApiError;
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
pub trait OauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
}

View File

@ -0,0 +1,220 @@
use gejdr_core::models::accounts::User;
use gejdr_core::sqlx::PgPool;
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
use poem::web::Data;
use poem::{session::Session, web::Form};
use poem_openapi::payload::{Json, PlainText};
use poem_openapi::{ApiResponse, Object, OpenApi};
use crate::oauth::{DiscordOauthProvider, OauthProvider};
use super::errors::ErrorResponse;
use super::ApiCategory;
type Token =
oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>;
pub struct AuthApi;
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
struct DiscordCallbackRequest {
code: String,
state: String,
}
impl DiscordCallbackRequest {
pub fn check_token(&self, token: &CsrfToken) -> Result<(), LoginStatusResponse> {
if *token.secret().to_string() == self.state {
Ok(())
} else {
Err(LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some(
"OAuth provider did not send a message that matches what was expected".into(),
),
})))
}
}
}
#[derive(ApiResponse)]
enum LoginStatusResponse {
#[oai(status = 201)]
LoggedIn(Json<UserInfo>),
#[oai(status = 201)]
LoggedOut(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 301)]
LoginRedirect(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 500)]
TokenError(Json<ErrorResponse>),
#[oai(status = 500)]
DatabaseError(Json<ErrorResponse>),
#[oai(status = 503)]
DiscordError(Json<ErrorResponse>),
}
#[derive(Debug, Eq, PartialEq, serde::Serialize, Object)]
struct UserInfo {
id: String,
username: String,
display_name: Option<String>,
avatar: Option<String>,
}
impl From<User> for UserInfo {
fn from(value: User) -> Self {
Self {
id: value.id,
username: value.username,
display_name: value.name,
avatar: value.avatar,
}
}
}
#[derive(ApiResponse)]
enum UserInfoResponse {
#[oai(status = 201)]
UserInfo(Json<UserInfo>),
#[oai(status = 401)]
Unauthorized,
#[oai(status = 500)]
DatabaseError(Json<ErrorResponse>),
#[oai(status = 503)]
DiscordError(Json<ErrorResponse>),
}
impl From<UserInfoResponse> for LoginStatusResponse {
fn from(value: UserInfoResponse) -> Self {
match value {
UserInfoResponse::UserInfo(json) => Self::LoggedIn(json),
UserInfoResponse::Unauthorized => unimplemented!(),
UserInfoResponse::DatabaseError(json) => Self::DatabaseError(json),
UserInfoResponse::DiscordError(json) => Self::DiscordError(json),
}
}
}
#[derive(ApiResponse)]
enum CsrfResponse {
#[oai(status = 201)]
Token(PlainText<String>),
}
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
impl AuthApi {
async fn fetch_remote_user(
pool: Data<&PgPool>,
token: Token,
) -> Result<UserInfoResponse, UserInfoResponse> {
crate::api_wrapper::discord::get_user_profile(token.access_token().secret())
.await
.map_err(|e| {
tracing::event!(
target: "auth-discord",
tracing::Level::ERROR,
"Failed to communicate with Discord: {}",
e
);
UserInfoResponse::DiscordError(Json(e.into()))
})?
.refresh_in_database(&pool)
.await
.map(|user| UserInfoResponse::UserInfo(Json(user.into())))
.map_err(|e| {
tracing::event!(
target: "auth-discord",
tracing::Level::ERROR,
"Database error: {}",
e
);
UserInfoResponse::DatabaseError(Json(e.into()))
})
}
#[oai(path = "/signin/discord", method = "get")]
async fn signin_discord(
&self,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
session.set("csrf", csrf_token);
session.set("pkce", pkce_verifier);
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Signin through Discord",
);
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
}
#[oai(path = "/callback/discord", method = "get")]
async fn callback_discord(
&self,
Form(auth_request): Form<DiscordCallbackRequest>,
oauth: Data<&DiscordOauthProvider>,
pool: Data<&PgPool>,
session: &Session,
) -> Result<LoginStatusResponse, LoginStatusResponse> {
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Discord callback",
);
let csrf_token = session.get::<CsrfToken>("csrf").ok_or_else(|| {
LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "Cannot fetch csrf token from session".to_string(),
..Default::default()
}))
})?;
auth_request.check_token(&csrf_token)?;
let pkce_verifier = session.get::<PkceCodeVerifier>("pkce").ok_or_else(|| {
LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "Cannot fetch pkce verifier from session".to_string(),
..Default::default()
}))
})?;
let token = oauth
.token(auth_request.code, pkce_verifier)
.await
.map_err(|e| LoginStatusResponse::TokenError(Json(e.into())))?;
session.set("token", token.clone());
Self::fetch_remote_user(pool, token)
.await
.map(std::convert::Into::into)
.map_err(std::convert::Into::into)
}
#[oai(path = "/csrf", method = "get")]
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
CsrfResponse::Token(PlainText(token.0.clone()))
}
#[oai(path = "/signout", method = "post")]
async fn signout(&self, session: &Session) -> LoginStatusResponse {
session.purge();
LoginStatusResponse::LoggedOut("/".to_string(), "no-cache".to_string())
}
#[oai(path = "/me", method = "get")]
async fn user_info(
&self,
session: &Session,
pool: Data<&PgPool>,
) -> Result<UserInfoResponse, UserInfoResponse> {
let token = session
.get::<Token>("token")
.ok_or(UserInfoResponse::Unauthorized)?;
Self::fetch_remote_user(pool, token).await
}
}

View File

@ -0,0 +1,204 @@
use poem_openapi::Object;
use reqwest::Error as ReqwestError;
use crate::api_wrapper::ApiError as ApiWrapperError;
use crate::errors::ApiError;
#[derive(Debug, serde::Serialize, Default, Object, PartialEq, Eq)]
pub struct ErrorResponse {
pub code: u16,
pub message: String,
pub details: Option<String>,
}
impl From<ApiError> for ErrorResponse {
fn from(value: ApiError) -> Self {
match value {
ApiError::Sql(e) => Self {
code: 500,
message: "SQL error".into(),
details: Some(e.to_string()),
},
ApiError::TokenError(e) => Self {
code: 500,
message: "OAuth token error".into(),
details: Some(e),
},
ApiError::Unauthorized => Self {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
},
ApiError::OptionError => Self {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
},
}
}
}
impl From<ReqwestError> for ErrorResponse {
fn from(value: ReqwestError) -> Self {
Self {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some(value.status().map_or_else(
|| "Communication failed before we could hear back from Discord".into(),
|status| format!("Discord sent back the error code {status}"),
)),
}
}
}
impl From<ApiWrapperError> for ErrorResponse {
fn from(source: ApiWrapperError) -> Self {
match source {
ApiWrapperError::Reqwest(e) => e.into(),
ApiWrapperError::Api(e) => Self {
code: if e.message.as_str().starts_with("401") {
401
} else {
e.code
},
message: e.message,
details: None,
},
}
}
}
impl From<gejdr_core::sqlx::Error> for ErrorResponse {
fn from(_value: gejdr_core::sqlx::Error) -> Self {
Self {
code: 500,
message: "Internal database error".into(),
..Default::default()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::api_wrapper::{ApiError as ApiWrapperError, DiscordErrorResponse};
#[test]
fn conversion_from_sql_api_error_works() {
let sql_error = ApiError::Sql(gejdr_core::sqlx::Error::ColumnNotFound(
"COLUMN_NAME".to_string(),
));
let final_error = ErrorResponse::from(sql_error);
let expected_error = ErrorResponse {
code: 500,
message: "SQL error".into(),
details: Some("no column found for name: COLUMN_NAME".into()),
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_token_error_works() {
let initial_error = ApiError::TokenError("TOKEN ERROR".into());
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some("TOKEN ERROR".into()),
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_unauthorized_works() {
let initial_error = ApiError::Unauthorized;
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_option_error_works() {
let initial_error = ApiError::OptionError;
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
};
assert_eq!(expected_error, final_error);
}
#[tokio::test]
async fn conversion_from_reqwest_error() {
let err = reqwest::get("https://example.example/401").await;
assert!(err.is_err());
let expected = ErrorResponse {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some("Communication failed before we could hear back from Discord".into()),
};
let actual: ErrorResponse = err.err().unwrap().into();
assert_eq!(expected, actual);
}
#[tokio::test]
async fn conversion_from_apiwrappererror_with_reqwest_error() {
let err = reqwest::get("https://example.example/401").await;
assert!(err.is_err());
let err = ApiWrapperError::Reqwest(err.err().unwrap());
let expected = ErrorResponse {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some("Communication failed before we could hear back from Discord".into()),
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
#[test]
fn conversion_from_apiwrappererror_with_401_discord_error() {
let err = ApiWrapperError::Api(DiscordErrorResponse {
code: 0,
message: "401: Unauthorized".into(),
});
let expected = ErrorResponse {
code: 401,
message: "401: Unauthorized".into(),
..Default::default()
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
#[test]
fn conversion_from_apiwrappererror_with_generic_discord_error() {
let err = ApiWrapperError::Api(DiscordErrorResponse {
code: 0,
message: "Something else".into(),
});
let expected = ErrorResponse {
code: 0,
message: "Something else".into(),
..Default::default()
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
#[test]
fn conversion_from_database_error() {
let err = gejdr_core::sqlx::Error::PoolClosed;
let expected = ErrorResponse {
code: 500,
message: "Internal database error".into(),
..Default::default()
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
}

View File

@ -10,7 +10,7 @@ enum HealthResponse {
pub struct HealthApi; pub struct HealthApi;
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")] #[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
impl HealthApi { impl HealthApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn health_check(&self) -> HealthResponse { async fn health_check(&self) -> HealthResponse {
@ -23,7 +23,7 @@ impl HealthApi {
async fn health_check_works() { async fn health_check_works() {
let app = crate::get_test_app(None).await; let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app); let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/health-check").send().await; let resp = cli.get("/v1/api/health-check").send().await;
resp.assert_status_is_ok(); resp.assert_status_is_ok();
resp.assert_text("").await; resp.assert_text("").await;
} }

View File

@ -6,13 +6,19 @@ pub use health::HealthApi;
mod version; mod version;
pub use version::VersionApi; pub use version::VersionApi;
mod errors;
mod auth;
pub use auth::AuthApi;
#[derive(Tags)] #[derive(Tags)]
enum ApiCategory { enum ApiCategory {
Auth,
Health, Health,
Version, Version,
} }
pub(crate) struct Api; pub struct Api;
#[OpenApi] #[OpenApi]
impl Api {} impl Api {}

View File

@ -25,7 +25,7 @@ enum VersionResponse {
pub struct VersionApi; pub struct VersionApi;
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")] #[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
impl VersionApi { impl VersionApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> { async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
@ -38,7 +38,7 @@ impl VersionApi {
async fn version_works() { async fn version_works() {
let app = crate::get_test_app(None).await; let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app); let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/version").send().await; let resp = cli.get("/v1/api/version").send().await;
resp.assert_status_is_ok(); resp.assert_status_is_ok();
let json = resp.json().await; let json = resp.json().await;
let json_value = json.value(); let json_value = json.value();

View File

@ -1,4 +1,4 @@
use sqlx::ConnectOptions; use gejdr_core::database::Database;
#[derive(Debug, serde::Deserialize, Clone, Default)] #[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Settings { pub struct Settings {
@ -13,16 +13,8 @@ pub struct Settings {
impl Settings { impl Settings {
#[must_use] #[must_use]
pub fn web_address(&self) -> String { pub fn web_address(&self) -> String {
if self.debug {
format!(
"{}:{}",
self.application.base_url.clone(),
self.application.port
)
} else {
self.application.base_url.clone() self.application.base_url.clone()
} }
}
/// Multipurpose function that helps detect the current /// Multipurpose function that helps detect the current
/// environment the application is running in using the /// environment the application is running in using the
@ -66,7 +58,7 @@ impl Settings {
)) ))
.add_source( .add_source(
config::Environment::with_prefix("APP") config::Environment::with_prefix("APP")
.prefix_separator("_") .prefix_separator("__")
.separator("__"), .separator("__"),
) )
.build()?; .build()?;
@ -84,35 +76,6 @@ pub struct ApplicationSettings {
pub protocol: String, pub protocol: String,
} }
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Database {
pub host: String,
pub port: u16,
pub name: String,
pub user: String,
pub password: String,
pub require_ssl: bool,
}
impl Database {
#[must_use]
pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
let ssl_mode = if self.require_ssl {
sqlx::postgres::PgSslMode::Require
} else {
sqlx::postgres::PgSslMode::Prefer
};
sqlx::postgres::PgConnectOptions::new()
.host(&self.host)
.username(&self.user)
.password(&self.password)
.port(self.port)
.ssl_mode(ssl_mode)
.database(&self.name)
.log_statements(tracing::log::LevelFilter::Trace)
}
}
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub enum Environment { pub enum Environment {
Development, Development,
@ -167,8 +130,8 @@ pub struct EmailSettings {
#[derive(Debug, serde::Deserialize, Clone, Default)] #[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Discord { pub struct Discord {
client_id: String, pub client_id: String,
client_secret: String, pub client_secret: String,
} }
#[cfg(test)] #[cfg(test)]
@ -235,7 +198,7 @@ mod tests {
#[test] #[test]
fn web_address_works() { fn web_address_works() {
let mut settings = Settings { let settings = Settings {
debug: false, debug: false,
application: ApplicationSettings { application: ApplicationSettings {
base_url: "127.0.0.1".to_string(), base_url: "127.0.0.1".to_string(),
@ -244,10 +207,7 @@ mod tests {
}, },
..Default::default() ..Default::default()
}; };
let expected_no_debug = "127.0.0.1".to_string(); let expected = "127.0.0.1".to_string();
let expected_debug = "127.0.0.1:3000".to_string(); assert_eq!(expected, settings.web_address());
assert_eq!(expected_no_debug, settings.web_address());
settings.debug = true;
assert_eq!(expected_debug, settings.web_address());
} }
} }

View File

@ -1,35 +1,36 @@
use poem::middleware::Cors; use gejdr_core::sqlx;
use poem::middleware::{AddDataEndpoint, CorsEndpoint}; use poem::middleware::{AddDataEndpoint, CorsEndpoint};
use poem::middleware::{CookieJarManagerEndpoint, Cors};
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
use poem::{EndpointExt, Route}; use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
use crate::oauth::DiscordOauthProvider;
use crate::route::AuthApi;
use crate::{ use crate::{
route::{Api, HealthApi, VersionApi}, route::{Api, HealthApi, VersionApi},
settings::Settings, settings::Settings,
}; };
#[must_use]
pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgres::PgPool {
tracing::event!(
target: "startup",
tracing::Level::INFO,
"connecting to database with configuration {:?}",
settings.clone()
);
sqlx::postgres::PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(2))
.connect_lazy_with(settings.get_connect_options())
}
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>; type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>; pub type App = AddDataEndpoint<
AddDataEndpoint<
AddDataEndpoint<
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
DiscordOauthProvider,
>,
sqlx::Pool<sqlx::Postgres>,
>,
Settings,
>;
pub struct Application { pub struct Application {
server: Server, server: Server,
app: poem::Route, app: poem::Route,
port: u16, port: u16,
database: sqlx::postgres::PgPool, database: sqlx::postgres::PgPool,
settings: Settings, pub settings: Settings,
} }
pub struct RunnableApplication { pub struct RunnableApplication {
@ -61,6 +62,8 @@ impl From<Application> for RunnableApplication {
let app = val let app = val
.app .app
.with(Cors::new()) .with(Cors::new())
.with(CookieSession::new(CookieConfig::default().secure(true)))
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
.data(val.database) .data(val.database)
.data(val.settings); .data(val.settings);
let server = val.server; let server = val.server;
@ -74,16 +77,16 @@ impl Application {
test_pool: Option<sqlx::postgres::PgPool>, test_pool: Option<sqlx::postgres::PgPool>,
) -> sqlx::postgres::PgPool { ) -> sqlx::postgres::PgPool {
let database_pool = let database_pool =
test_pool.map_or_else(|| get_connection_pool(&settings.database), |pool| pool); test_pool.map_or_else(|| settings.database.get_connection_pool(), |pool| pool);
if !cfg!(test) { if !cfg!(test) {
migrate_database(&database_pool).await; gejdr_core::database::Database::migrate(&database_pool).await;
} }
database_pool database_pool
} }
fn setup_app(settings: &Settings) -> poem::Route { fn setup_app(settings: &Settings) -> poem::Route {
let api_service = OpenApiService::new( let api_service = OpenApiService::new(
(Api, HealthApi, VersionApi), (Api, AuthApi, HealthApi, VersionApi),
settings.application.clone().name, settings.application.clone().name,
settings.application.clone().version, settings.application.clone().version,
); );
@ -104,6 +107,7 @@ impl Application {
}); });
poem::Server::new(tcp_listener) poem::Server::new(tcp_listener)
} }
pub async fn build( pub async fn build(
settings: Settings, settings: Settings,
test_pool: Option<sqlx::postgres::PgPool>, test_pool: Option<sqlx::postgres::PgPool>,
@ -133,10 +137,3 @@ impl Application {
self.port self.port
} }
} }
async fn migrate_database(pool: &sqlx::postgres::PgPool) {
sqlx::migrate!()
.run(pool)
.await
.expect("Failed to migrate the database");
}

1
gejdr-bot/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

6
gejdr-bot/Cargo.toml Normal file
View File

@ -0,0 +1,6 @@
[package]
name = "gejdr-bot"
version = "0.1.0"
edition = "2021"
[dependencies]

3
gejdr-bot/src/main.rs Normal file
View File

@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

1
gejdr-core/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

16
gejdr-core/Cargo.toml Normal file
View File

@ -0,0 +1,16 @@
[package]
name = "gejdr-core"
version = "0.1.0"
edition = "2021"
[dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
serde = "1.0.215"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] }
[dependencies.sqlx]
version = "0.8.2"
default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]

View File

@ -0,0 +1,3 @@
-- Add down migration script here
DROP TABLE IF EXISTS public.users;
DROP EXTENSION IF EXISTS "uuid-ossp";

View File

@ -0,0 +1,15 @@
-- Add up migration script here
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS public.users
(
id character varying(255) NOT NULL,
username character varying(255) NOT NULL,
email character varying(255),
avatar character varying(511),
name character varying(255),
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id),
CONSTRAINT users_email_unique UNIQUE (email)
);

View File

@ -0,0 +1,50 @@
use sqlx::ConnectOptions;
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Database {
pub host: String,
pub port: u16,
pub name: String,
pub user: String,
pub password: String,
pub require_ssl: bool,
}
impl Database {
#[must_use]
pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
let ssl_mode = if self.require_ssl {
sqlx::postgres::PgSslMode::Require
} else {
sqlx::postgres::PgSslMode::Prefer
};
sqlx::postgres::PgConnectOptions::new()
.host(&self.host)
.username(&self.user)
.password(&self.password)
.port(self.port)
.ssl_mode(ssl_mode)
.database(&self.name)
.log_statements(tracing::log::LevelFilter::Trace)
}
#[must_use]
pub fn get_connection_pool(&self) -> sqlx::postgres::PgPool {
tracing::event!(
target: "startup",
tracing::Level::INFO,
"connecting to database with configuration {:?}",
self.clone()
);
sqlx::postgres::PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(2))
.connect_lazy_with(self.get_connect_options())
}
pub async fn migrate(pool: &sqlx::PgPool) {
sqlx::migrate!()
.run(pool)
.await
.expect("Failed to migrate the database");
}
}

4
gejdr-core/src/lib.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod database;
pub mod models;
pub mod telemetry;
pub use sqlx;

View File

@ -0,0 +1,396 @@
use sqlx::PgPool;
type Timestampz = chrono::DateTime<chrono::Utc>;
#[derive(serde::Deserialize, PartialEq, Eq, Debug, Clone, Default)]
pub struct RemoteUser {
id: String,
username: String,
global_name: Option<String>,
email: Option<String>,
avatar: Option<String>,
}
impl RemoteUser {
/// Refresh in database the row related to the remote user. Maybe
/// create a row for this user if needed.
pub async fn refresh_in_database(self, pool: &PgPool) -> Result<User, sqlx::Error> {
match User::find(pool, &self.id).await? {
Some(local_user) => local_user.update_from_remote(self).update(pool).await,
None => User::from(self).save(pool).await,
}
}
}
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq, Default, Clone)]
pub struct User {
pub id: String,
pub username: String,
pub email: Option<String>,
pub avatar: Option<String>,
pub name: Option<String>,
pub created_at: Timestampz,
pub last_updated: Timestampz,
}
impl From<RemoteUser> for User {
fn from(value: RemoteUser) -> Self {
Self {
id: value.id,
username: value.username,
email: value.email,
avatar: value.avatar,
name: value.global_name,
created_at: chrono::offset::Utc::now(),
last_updated: chrono::offset::Utc::now(),
}
}
}
impl PartialEq<RemoteUser> for User {
#[allow(clippy::suspicious_operation_groupings)]
fn eq(&self, other: &RemoteUser) -> bool {
self.id == other.id
&& self.username == other.username
&& self.email == other.email
&& self.avatar == other.avatar
&& self.name == other.global_name
}
}
impl PartialEq<User> for RemoteUser {
fn eq(&self, other: &User) -> bool {
other == self
}
}
impl User {
pub fn update_from_remote(self, from: RemoteUser) -> Self {
if self == from {
self
} else {
Self {
username: from.username,
email: from.email,
avatar: from.avatar,
name: from.global_name,
last_updated: chrono::offset::Utc::now(),
..self
}
}
}
pub async fn find(pool: &PgPool, id: &String) -> Result<Option<Self>, sqlx::Error> {
sqlx::query_as!(Self, r#"SELECT * FROM users WHERE id = $1"#, id)
.fetch_optional(pool)
.await
}
pub async fn save(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
sqlx::query_as!(
Self,
r#"
INSERT INTO users (id, username, email, avatar, name, created_at, last_updated)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING *
"#,
self.id,
self.username,
self.email,
self.avatar,
self.name,
self.created_at,
self.last_updated
)
.fetch_one(pool)
.await
}
pub async fn update(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
sqlx::query_as!(
Self,
r#"
UPDATE users
SET username = $1, email = $2, avatar = $3, name = $4, last_updated = $5
WHERE id = $6
RETURNING *
"#,
self.username,
self.email,
self.avatar,
self.name,
self.last_updated,
self.id
)
.fetch_one(pool)
.await
}
pub async fn save_or_update(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
if Self::find(pool, &self.id).await?.is_some() {
self.update(pool).await
} else {
self.save(pool).await
}
}
pub async fn delete(pool: &PgPool, id: &String) -> Result<u64, sqlx::Error> {
let rows_affected = sqlx::query!("DELETE FROM users WHERE id = $1", id)
.execute(pool)
.await?
.rows_affected();
Ok(rows_affected)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert_remote_user_to_local_user() {
let remote = RemoteUser {
id: "user-id".into(),
username: "username".into(),
global_name: None,
email: Some("user@example.com".into()),
avatar: None,
};
let local: User = remote.into();
let expected = User {
id: "user-id".into(),
username: "username".into(),
email: Some("user@example.com".into()),
avatar: None,
name: None,
created_at: local.created_at,
last_updated: local.last_updated,
};
assert_eq!(expected, local);
}
#[test]
fn can_compare_remote_and_local_user() {
let remote_same = RemoteUser {
id: "user-id".into(),
username: "username".into(),
global_name: None,
email: Some("user@example.com".into()),
avatar: None,
};
let remote_different = RemoteUser {
id: "user-id".into(),
username: "username".into(),
global_name: None,
email: Some("user@example.com".into()),
avatar: Some("some-hash".into()),
};
let local = User {
id: "user-id".into(),
username: "username".into(),
email: Some("user@example.com".into()),
avatar: None,
name: None,
created_at: chrono::offset::Utc::now(),
last_updated: chrono::offset::Utc::now(),
};
assert_eq!(remote_same, local);
assert_ne!(remote_different, local);
}
#[sqlx::test]
async fn add_new_remote_users_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
let remote1 = RemoteUser {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let remote2 = RemoteUser {
id: "id2".into(),
username: "user2".into(),
..Default::default()
};
remote1.refresh_in_database(&pool).await?;
remote2.refresh_in_database(&pool).await?;
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, users.len());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn update_local_users_in_db_from_remote(pool: sqlx::PgPool) -> sqlx::Result<()> {
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, users.len());
let remote1 = RemoteUser {
id: "id1".into(),
username: "user1-new".into(),
..Default::default()
};
let remote2 = RemoteUser {
id: "id2".into(),
username: "user2-new".into(),
..Default::default()
};
remote1.refresh_in_database(&pool).await?;
remote2.refresh_in_database(&pool).await?;
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, users.len());
users
.iter()
.for_each(|user| assert!(user.last_updated > user.created_at));
Ok(())
}
#[test]
fn update_local_user_from_identical_remote_shouldnt_change_local() {
let remote = RemoteUser {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let local = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let new_local = local.clone().update_from_remote(remote);
assert_eq!(local, new_local);
}
#[test]
fn update_local_user_from_different_remote() {
let remote = RemoteUser {
id: "id1".into(),
username: "user2".into(),
..Default::default()
};
let local = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let new_local = local.clone().update_from_remote(remote.clone());
assert_ne!(remote, local);
assert_eq!(remote, new_local);
}
#[sqlx::test]
async fn save_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
let user = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
user.save(&pool).await?;
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(1, users.len());
assert_eq!(Some(user), users.first().cloned());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn update_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
.fetch_one(&pool)
.await?;
assert!(db_user.name.is_none());
let user = User {
id: "id1".into(),
username: "user1".into(),
name: Some("Cool Name".into()),
..Default::default()
};
user.update(&pool).await?;
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
.fetch_one(&pool)
.await?;
assert!(db_user.name.is_some());
assert_eq!(Some("Cool Name".to_string()), db_user.name);
Ok(())
}
#[sqlx::test]
async fn save_or_update_saves_if_no_exist(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(0, rows.len());
let user = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
user.save_or_update(&pool).await?;
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(1, rows.len());
let db_user = rows.first();
assert_eq!(Some(user), db_user.cloned());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn save_or_update_updates_if_exists(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let user = User {
id: "id1".into(),
username: "user1".into(),
name: Some("Cool Nam".into()),
..Default::default()
};
user.save_or_update(&pool).await?;
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
.fetch_one(&pool)
.await?;
assert_eq!(user.name, db_user.name);
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn delete_removes_account_from_db(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let id = "id1".to_string();
let deletions = User::delete(&pool, &id).await?;
assert_eq!(1, deletions);
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(1, rows.len());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn delete_with_wrong_id_shouldnt_delete_anything(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let id = "invalid".to_string();
let deletions = User::delete(&pool, &id).await?;
assert_eq!(0, deletions);
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
Ok(())
}
}

View File

@ -0,0 +1,2 @@
INSERT INTO users (id, username) VALUES ('id1', 'user1');
INSERT INTO users (id, username) VALUES ('id2', 'user2');

View File

@ -0,0 +1 @@
pub mod accounts;

View File

@ -5,7 +5,7 @@ pub fn get_subscriber(debug: bool) -> impl tracing::Subscriber + Send + Sync {
let env_filter = if debug { "debug" } else { "info" }.to_string(); let env_filter = if debug { "debug" } else { "info" }.to_string();
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter)); .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter));
let stdout_log = tracing_subscriber::fmt::layer().pretty(); let stdout_log = tracing_subscriber::fmt::layer().pretty().with_test_writer();
let subscriber = tracing_subscriber::Registry::default() let subscriber = tracing_subscriber::Registry::default()
.with(env_filter) .with(env_filter)
.with(stdout_log); .with(stdout_log);

View File

@ -1,10 +1,7 @@
default: run mod backend
mod docker
prepare: default: lint
cargo sqlx prepare
migrate:
sqlx migrate run
format: format:
cargo fmt --all cargo fmt --all
@ -12,17 +9,12 @@ format:
format-check: format-check:
cargo fmt --check --all cargo fmt --check --all
migrate:
sqlx migrate run --source gejdr-core/migrations
build: build:
cargo auditable build cargo auditable build --bin gejdr-backend
cargo auditable build --bin gejdr-bot
build-release:
cargo auditable build --release
run: docker-start
cargo auditable run
run-no-docker:
cargo auditable run
lint: lint:
cargo clippy --all-targets cargo clippy --all-targets
@ -31,19 +23,22 @@ msrv:
cargo msrv verify cargo msrv verify
release-build: release-build:
cargo auditable build --release cargo auditable build --release --bin gejdr-backend
cargo auditable build --release --bin gejdr-bot
release-run: release-run:
cargo auditable run --release cargo auditable run --release
audit: build audit: build
cargo audit bin target/debug/gege-jdr-backend cargo audit bin target/debug/gejdr-backend
cargo audit bin target/debug/gejdr-bot
audit-release: build-release audit-release:
cargo audit bin target/release/gege-jdr-backend cargo audit bin target/release/gejdr-backend
cargo audit bin target/release/gejdr-bot
test: test:
cargo test cargo test --all-targets --all
coverage: coverage:
mkdir -p coverage mkdir -p coverage
@ -55,17 +50,8 @@ coverage-ci:
check-all: format-check lint msrv coverage audit check-all: format-check lint msrv coverage audit
docker-build: # docker-build:
nix build .#docker # nix build .#docker
docker-start:
docker compose -f docker/compose.dev.yml up -d
docker-stop:
docker compose -f docker/compose.dev.yml down
docker-logs:
docker compose -f docker/compose.dev.yml logs -f
## Local Variables: ## Local Variables:
## mode: makefile ## mode: makefile

View File

@ -1,5 +0,0 @@
-- Add down migration script here
ALTER TABLE IF EXISTS public.sessions DROP CONSTRAINT IF EXISTS sessions_user_id_users_fk;
DROP TABLE IF EXISTS public.sessions;
DROP TABLE IF EXISTS public.users;
DROP EXTENSION IF EXISTS "uuid-ossp";

View File

@ -1,29 +0,0 @@
-- Add up migration script here
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS public.users
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
email character varying(255) NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id),
CONSTRAINT users_email_unique UNIQUE (email)
);
CREATE TABLE IF NOT EXISTS public.sessions
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
user_id uuid NOT NULL,
session_id character varying NOT NULL,
expires_at timestamp with time zone NOT NULL,
PRIMARY KEY (id),
CONSTRAINT sessions_user_id_unique UNIQUE (user_id)
);
ALTER TABLE IF EXISTS public.sessions
ADD CONSTRAINT sessions_user_id_users_fk FOREIGN KEY (user_id)
REFERENCES public.users (id) MATCH SIMPLE
ON UPDATE CASCADE
ON DELETE CASCADE
NOT VALID;