feat: Discord Oauth2 #1

Open
phundrak wants to merge 5 commits from feature/authentication into develop
53 changed files with 2700 additions and 753 deletions
Showing only changes of commit a2ea5a157d - Show all commits

View File

@ -1,15 +1,14 @@
;;; Directory Local Variables -*- no-byte-compile: t -*- ;;; Directory Local Variables -*- no-byte-compile: t -*-
;;; For more information see (info "(emacs) Directory Variables") ;;; For more information see (info "(emacs) Directory Variables")
((sql-mode ((rustic-mode . ((fill-column . 80)))
. (sql-mode . ((eval . (progn
((eval . (progn (setq-local lsp-sqls-connections
(setq-local lsp-sqls-connections `(((driver . "postgresql")
`(((driver . "postgresql") (dataSourceName \,
(dataSourceName . (format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable" (getenv "DB_HOST")
(getenv "DB_HOST") (getenv "DB_PORT")
(getenv "DB_PORT") (getenv "DB_USER")
(getenv "DB_USER") (getenv "DB_PASSWORD")
(getenv "DB_PASSWORD") (getenv "DB_NAME")))))))))))
(getenv "DB_NAME")))))))))))

View File

@ -45,8 +45,6 @@ jobs:
run: nix develop --command -- just lint run: nix develop --command -- just lint
- name: Audit - name: Audit
run: nix develop --command -- just audit run: nix develop --command -- just audit
- name: Minimum supported Rust version check
run: nix develop --command -- just msrv
- name: Tests - name: Tests
run: nix develop --command -- just test run: nix develop --command -- just test
- name: Coverage - name: Coverage

View File

@ -9,6 +9,7 @@ on:
pull_request: pull_request:
branches: branches:
- 'main' - 'main'
- 'develop'
jobs: jobs:
publish: publish:
@ -20,14 +21,26 @@ jobs:
with: with:
username: ${{ secrets.DOCKER_REGISTRY_USERNAME }} username: ${{ secrets.DOCKER_REGISTRY_USERNAME }}
password: ${{ secrets.DOCKER_REGISTRY_PASSWORD }} password: ${{ secrets.DOCKER_REGISTRY_PASSWORD }}
registry: ${{ vars.REGISTRY }}
- uses: cachix/install-nix-action@v27 - uses: cachix/install-nix-action@v27
with: with:
nix_path: nixpkgs=channel:nixos-unstable nix_path: nixpkgs=channel:nixos-unstable
- name: Build Docker image - name: Build Docker image
run: nix develop --command -- just docker-build run: nix develop --command -- just backend build-docker
- name: Load Docker image - name: Load Docker image
run: docker load < result run: docker load < result
- name: Docker Metadata action - name: Docker Metadata action
uses: docker/metadata-action@v5.5.1 uses: docker/metadata-action@v5.6.1
with: with:
image: tal-backend:latest images: gejdr-backend
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=sha
labels: |
org.opencontainers.image.title=Backend GéJDR
org.opencontainers.image.description=Backend for GéJDR
org.opencontainers.image.vendor=Lucien Cartier-Tilet <lucien@phundrak.com>

2
.gitignore vendored
View File

@ -3,3 +3,5 @@
.env .env
/result /result
/coverage/ /coverage/
/gejdr-backend/result
/gejdr-bot/result

View File

@ -0,0 +1,64 @@
{
"db_name": "PostgreSQL",
"query": "\nINSERT INTO users (id, username, email, avatar, name, created_at, last_updated)\nVALUES ($1, $2, $3, $4, $5, $6, $7)\nRETURNING *\n",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Timestamptz",
"Timestamptz"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "24bbd63a324e36f4a0c559c44909621cc21b493b4c9ae4c14e30c99a6d8072bd"
}

View File

@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "DELETE FROM users WHERE id = $1",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text"
]
},
"nullable": []
},
"hash": "50293c2e54af11d4c2a553e29b671cef087a159c6ee7182d8ca929ecb748f3b7"
}

View File

@ -0,0 +1,58 @@
{
"db_name": "PostgreSQL",
"query": "SELECT * FROM users WHERE id = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "843923b9a0257cf80f1dff554e7dc8fdfc05f489328e8376513124dfb42996e3"
}

View File

@ -0,0 +1,63 @@
{
"db_name": "PostgreSQL",
"query": "\nUPDATE users\nSET username = $1, email = $2, avatar = $3, name = $4, last_updated = $5\nWHERE id = $6\nRETURNING *\n",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Timestamptz",
"Text"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "c36467a81bad236a0c1a8d3fc1b5f8efda9b9cf9eaab140c10246218f0240600"
}

View File

@ -3,3 +3,5 @@ out = ["Xml"]
target-dir = "coverage" target-dir = "coverage"
output-dir = "coverage" output-dir = "coverage"
fail-under = 60 fail-under = 60
exclude-files = ["target/*"]
run-types = ["AllTargets"]

View File

@ -4,3 +4,5 @@ skip-clean = true
target-dir = "coverage" target-dir = "coverage"
output-dir = "coverage" output-dir = "coverage"
fail-under = 60 fail-under = 60
exclude-files = ["target/*"]
run-types = ["AllTargets"]

1616
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,61 +1,8 @@
[package] [workspace]
name = "gege-jdr-backend"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["phundrak"]
rust-version = "1.78"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html members = [
"gejdr-core",
[lib] "gejdr-bot",
path = "src/lib.rs" "gejdr-backend",
[[bin]]
path = "src/main.rs"
name = "gege-jdr-backend"
[dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.0", features = ["yaml"] }
dotenvy = "0.15.7"
serde = "1.0.204"
serde_json = "1.0.120"
thiserror = "1.0.63"
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
[dependencies.lettre]
version = "0.11.7"
default-features = false
features = [
"builder",
"hostname",
"pool",
"rustls-tls",
"tokio1",
"tokio1-rustls-tls",
"smtp-transport"
] ]
resolver = "2"
[dependencies.poem]
version = "3.0.4"
default-features = false
features = [
"csrf",
"rustls",
"cookie",
"test"
]
[dependencies.poem-openapi]
version = "5.0.3"
features = ["chrono", "swagger-ui", "uuid"]
[dependencies.sqlx]
version = "0.8.0"
default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]

View File

@ -31,25 +31,6 @@ services:
depends_on: depends_on:
- db - db
# If you run GegeJdrBackend in production, DO NOT use mailpit.
# This tool is for testing only. Instead, you should use a real SMTP
# provider, such as Mailgun, Mailwhale, or Postal.
mailpit:
image: axllent/mailpit:latest
restart: unless-stopped
container_name: gege-jdr-backend-mailpit
ports:
- 127.0.0.1:8025:8025 # WebUI
- 127.0.0.1:1025:1025 # SMTP
volumes:
- gege_jdr_backend_mailpit:/data
environment:
MP_MAX_MESSAGES: 5000
MP_DATABASE: /data/mailpit.db
MP_SMTP_AUTH_ACCEPT_ANY: 1
MP_SMTP_AUTH_ALLOW_INSECURE: 1
volumes: volumes:
gege_jdr_backend_db_data: gege_jdr_backend_db_data:
gege_jdr_backend_pgadmin_data: gege_jdr_backend_pgadmin_data:
gege_jdr_backend_mailpit:

14
docker/mod.just Normal file
View File

@ -0,0 +1,14 @@
default: start
start:
docker compose -f compose.dev.yml up -d
stop:
docker compose -f compose.dev.yml down
logs:
docker compose -f compose.dev.yml logs -f
## Local Variables:
## mode: makefile
## End:

12
flake.lock generated
View File

@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1732014248, "lastModified": 1736344531,
"narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=", "narHash": "sha256-8YVQ9ZbSfuUk2bUf2KRj60NRraLPKPS0Q4QFTbc+c2c=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "23e89b7da85c3640bbc2173fe04f4bd114342367", "rev": "bffc22eb12172e6db3c5dde9e3e5628f8e3e7912",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2" "nixpkgs": "nixpkgs_2"
}, },
"locked": { "locked": {
"lastModified": 1732242723, "lastModified": 1736476219,
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=", "narHash": "sha256-+qyv3QqdZCdZ3cSO/cbpEY6tntyYjfe1bB12mdpNFaY=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a", "rev": "de30cc5963da22e9742bbbbb9a3344570ed237b9",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -18,30 +18,61 @@
rustc = rustVersion; rustc = rustVersion;
}; };
appName = "gege-jdr-backend"; appNameBackend = "gejdr-backend";
appNameBot = "gejdr-bot";
appRustBuild = rustPlatform.buildRustPackage { appRustBuildBackend = rustPlatform.buildRustPackage {
pname = appName; pname = appNameBackend;
version = "0.1.0"; version = "0.1.0";
src = ./.; src = ./.;
cargoLock.lockFile = ./Cargo.lock; cargoLock.lockFile = ./Cargo.lock;
buildPhase = ''
cd gejdr-backend
SQLX_OFFLINE="1" cargo build --release --bin gejdr-backend
'';
installPhase = ''
cargo install --path . --root "$out/bin/"
'';
};
appRustBuildBot = rustPlatform.buildRustPackage {
pname = appNameBot;
version = "0.1.0";
src = ./.;
cargoLock.lockFile = ./Cargo.lock;
buildPhase = ''
SQLX_OFFLINE="1" cargo build --release --bin gejdr-bot
'';
}; };
dockerImage = pkgs.dockerTools.buildLayeredImage { dockerImageBackend = pkgs.dockerTools.buildLayeredImage {
name = appName; name = appNameBackend;
tag = "latest";
config = { config = {
Entrypoint = [ "${appRustBuild}/bin/${appName}" ]; Entrypoint = [ "${appRustBuildBackend}/bin/${appNameBackend}" ];
Env = [ "SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt" ]; Env = [ "SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt" ];
Tag = "latest"; Tag = "latest";
}; };
contents = [ appRustBuild pkgs.cacert ]; contents = [ appRustBuildBackend pkgs.cacert ];
};
dockerImageBot = pkgs.dockerTools.buildLayeredImage {
name = appNameBot;
tag = "latest";
fromImageTag = "latest";
config = {
Entrypoint = [ "${appRustBuildBot}/bin/${appNameBot}" ];
Env = [ "SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt" ];
Tag = "latest";
};
contents = [ appRustBuildBot pkgs.cacert ];
}; };
in { in {
packages = { packages = {
rustPackage = appRustBuild; backend = appRustBuildBackend;
docker = dockerImage; bot = appRustBuildBot;
dockerBackend = dockerImageBackend;
dockerBot = dockerImageBot;
}; };
defaultPackage = dockerImage; defaultPackage = dockerImageBackend;
devShell = with pkgs; mkShell { devShell = with pkgs; mkShell {
buildInputs = [ buildInputs = [
bacon bacon
@ -49,7 +80,6 @@
cargo-audit cargo-audit
cargo-auditable cargo-auditable
cargo-tarpaulin cargo-tarpaulin
cargo-msrv
just just
rust-analyzer rust-analyzer
(rustVersion.override { extensions = [ "rust-src" ]; }) (rustVersion.override { extensions = [ "rust-src" ]; })

View File

@ -0,0 +1,6 @@
[all]
out = ["Xml"]
target-dir = "coverage"
output-dir = "coverage"
fail-under = 60
exclude-files = ["target/*"]

View File

@ -0,0 +1,7 @@
[all]
out = ["Html", "Lcov"]
skip-clean = true
target-dir = "coverage"
output-dir = "coverage"
fail-under = 60
exclude-files = ["target/*"]

43
gejdr-backend/Cargo.toml Normal file
View File

@ -0,0 +1,43 @@
[package]
name = "gejdr-backend"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
path = "src/lib.rs"
[[bin]]
path = "src/main.rs"
name = "gejdr-backend"
[dependencies]
gejdr-core = { path = "../gejdr-core" }
chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.1", features = ["yaml"] }
dotenvy = "0.15.7"
oauth2 = "4.4.2"
quote = "1.0.37"
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
serde = "1.0.215"
serde_json = "1.0.133"
thiserror = "1.0.69"
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] }
[dependencies.poem]
version = "3.1.3"
default-features = false
features = ["csrf", "rustls", "cookie", "test", "session"]
[dependencies.poem-openapi]
version = "5.1.2"
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }

View File

@ -0,0 +1,18 @@
default: run

# Build the backend with SQLx offline metadata so no live database is needed.
# (Removed a stray debugging `pwd` that was left in this recipe.)
build $SQLX_OFFLINE="1":
    cargo auditable build --bin gejdr-backend

build-release $SQLX_OFFLINE="1":
    cargo auditable build --release --bin gejdr-backend

# Build the backend Docker image through the Nix flake.
build-docker:
    nix build .#dockerBackend

run:
    cargo auditable run --bin gejdr-backend

## Local Variables:
## mode: makefile
## End:

View File

@ -3,5 +3,5 @@ debug: true
application: application:
protocol: http protocol: http
host: 127.0.0.1 host: localhost
base_url: http://127.0.0.1:3000 base_url: http://localhost:3000

View File

@ -0,0 +1,55 @@
use super::{ApiError, DiscordErrorResponse};
use gejdr_core::models::accounts::RemoteUser;

// No trailing slash: paths below are joined with an explicit "/". The
// original value ended in "/" and produced ".../v10//users/@me".
static DISCORD_URL: &str = "https://discord.com/api/v10";

/// Fetch the Discord profile of the user owning `token`.
///
/// Performs `GET /users/@me` against the Discord REST API (v10) with the
/// token sent as a `Bearer` authorization. On a success status the body is
/// decoded as a [`RemoteUser`]; on a non-success status the body is decoded
/// as Discord's structured error payload and returned as [`ApiError::Api`].
/// Transport or JSON-decoding failures surface as [`ApiError::Reqwest`].
pub async fn get_user_profile(token: &str) -> Result<RemoteUser, ApiError> {
    let client = reqwest::Client::new();
    let response = client
        .get(format!("{DISCORD_URL}/users/@me"))
        // bearer_auth sets the Authorization header directly, avoiding the
        // manual HeaderMap + `.parse().unwrap()` which could panic on a
        // token containing non-header-safe bytes.
        .bearer_auth(token)
        .send()
        .await;
    match response {
        Ok(resp) => {
            if resp.status().is_success() {
                resp.json::<RemoteUser>()
                    .await
                    .map_err(std::convert::Into::into)
            } else {
                // Discord answered with an error status: prefer its
                // structured error body, fall back to the decode error.
                match resp.json::<DiscordErrorResponse>().await {
                    Ok(val) => Err(ApiError::Api(val)),
                    Err(e) => Err(ApiError::Reqwest(e)),
                }
            }
        }
        Err(e) => Err(ApiError::Reqwest(e)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn user_profile_invalid_token_results_in_401() {
        let res = get_user_profile("invalid").await;
        assert!(res.is_err());
        let err = res.err().unwrap();
        println!("Error: {err:?}");
        let expected = DiscordErrorResponse {
            code: 0,
            message: "401: Unauthorized".into(),
        };
        // The original `assert!(matches!(ApiError::Api(expected), _err))`
        // was vacuous: `_err` is an irrefutable binding pattern, so the
        // assertion always passed. Compare the payload only when Discord
        // actually answered; a `Reqwest` transport error (e.g. offline CI)
        // is also acceptable.
        if let ApiError::Api(resp) = err {
            assert_eq!(expected, resp);
        }
    }
    // TODO: Find a way to mock calls to discord.com API with a
    // successful reply
}

View File

@ -0,0 +1,41 @@
use reqwest::Error as ReqwestError;
use std::fmt::{self, Display};
use thiserror::Error;
pub mod discord;
#[derive(Debug, serde::Deserialize, PartialEq, Eq)]
pub struct DiscordErrorResponse {
pub message: String,
pub code: u16,
}
impl Display for DiscordErrorResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "DiscordErrorResponse: {} ({})", self.message, self.code)
}
}
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Reqwest error: {0}")]
Reqwest(#[from] ReqwestError),
#[error("API Error: {0}")]
Api(DiscordErrorResponse),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display_discord_error_response() {
let error = DiscordErrorResponse {
message: "Message".into(),
code: 42,
};
let error_str = error.to_string();
let expected = "DiscordErrorResponse: Message (42)".to_string();
assert_eq!(expected, error_str);
}
}

View File

@ -0,0 +1,14 @@
use thiserror::Error;
#[allow(dead_code)]
#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
Sql(#[from] gejdr_core::sqlx::Error),
#[error("OAuth token error: {0}")]
TokenError(String),
#[error("Unauthorized")]
Unauthorized,
#[error("Attempted to get a value, none found")]
OptionError,
}

View File

@ -5,10 +5,14 @@
#![allow(clippy::unused_async)] #![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling #![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
pub mod route; use gejdr_core::sqlx;
pub mod settings;
pub mod startup; mod api_wrapper;
pub mod telemetry; mod errors;
mod oauth;
mod route;
mod settings;
mod startup;
type MaybeListener = Option<poem::listener::TcpListener<String>>; type MaybeListener = Option<poem::listener::TcpListener<String>>;
@ -16,8 +20,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
dotenvy::dotenv().ok(); dotenvy::dotenv().ok();
let settings = settings::Settings::new().expect("Failed to read settings"); let settings = settings::Settings::new().expect("Failed to read settings");
if !cfg!(test) { if !cfg!(test) {
let subscriber = telemetry::get_subscriber(settings.clone().debug); let subscriber = gejdr_core::telemetry::get_subscriber(settings.clone().debug);
telemetry::init_subscriber(subscriber); gejdr_core::telemetry::init_subscriber(subscriber);
} }
tracing::event!( tracing::event!(
target: "gege-jdr-backend", target: "gege-jdr-backend",
@ -29,8 +33,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
tracing::event!( tracing::event!(
target: "gege-jdr-backend", target: "gege-jdr-backend",
tracing::Level::INFO, tracing::Level::INFO,
"Listening on http://127.0.0.1:{}/", "Listening on {}",
application.port() application.settings.web_address()
); );
application application
} }

View File

@ -1,5 +1,5 @@
#[cfg(not(tarpaulin_include))] #[cfg(not(tarpaulin_include))]
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), std::io::Error> { async fn main() -> Result<(), std::io::Error> {
gege_jdr_backend::run(None).await gejdr_backend::run(None).await
} }

View File

@ -0,0 +1,62 @@
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
};
use reqwest::Url;
use crate::{errors::ApiError, settings::Settings};
use super::OauthProvider;
#[derive(Debug, Clone)]
pub struct DiscordOauthProvider {
client: BasicClient,
}
impl DiscordOauthProvider {
pub fn new(settings: &Settings) -> Self {
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
.expect("Invalid token endpoint URL");
let revocation_url =
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
.expect("Invalid revocation URL");
let client = BasicClient::new(
ClientId::new(settings.discord.client_id.clone()),
Some(ClientSecret::new(settings.discord.client_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
.set_revocation_uri(revocation_url);
Self { client }
}
}
impl OauthProvider for DiscordOauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
.set_pkce_challenge(pkce_challenge)
.url();
(auth_url, csrf_token, pkce_verifier)
}
async fn token(
&self,
code: String,
verifier: PkceCodeVerifier,
) -> Result<super::Token, ApiError> {
self.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(verifier)
.request_async(oauth2::reqwest::async_http_client)
.await
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
}
}

View File

@ -0,0 +1,17 @@
mod discord;
pub use discord::DiscordOauthProvider;
use oauth2::{
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
StandardTokenResponse,
};
use reqwest::Url;
use crate::errors::ApiError;
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
pub trait OauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
}

View File

@ -0,0 +1,220 @@
use gejdr_core::models::accounts::User;
use gejdr_core::sqlx::PgPool;
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
use poem::web::Data;
use poem::{session::Session, web::Form};
use poem_openapi::payload::{Json, PlainText};
use poem_openapi::{ApiResponse, Object, OpenApi};
use crate::oauth::{DiscordOauthProvider, OauthProvider};
use super::errors::ErrorResponse;
use super::ApiCategory;
type Token =
oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>;
pub struct AuthApi;
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
struct DiscordCallbackRequest {
code: String,
state: String,
}
impl DiscordCallbackRequest {
    /// Verify that the `state` echoed back by the OAuth provider matches the
    /// CSRF token we stored in the session; mismatch yields a 500-style
    /// `TokenError` response.
    pub fn check_token(&self, token: &CsrfToken) -> Result<(), LoginStatusResponse> {
        // Guard clause instead of if/else; direct string comparison instead
        // of the original's `*token.secret().to_string() == self.state`.
        if token.secret() != &self.state {
            return Err(LoginStatusResponse::TokenError(Json(ErrorResponse {
                code: 500,
                message: "OAuth token error".into(),
                details: Some(
                    "OAuth provider did not send a message that matches what was expected".into(),
                ),
            })));
        }
        Ok(())
    }
}
/// Responses for the signin/callback/signout endpoints.
#[derive(ApiResponse)]
enum LoginStatusResponse {
    /// User successfully authenticated; body carries their profile.
    #[oai(status = 201)]
    LoggedIn(Json<UserInfo>),
    /// Session purged; points the client back to `/`.
    // NOTE(review): status 201 with a Location header is unusual for a
    // logout redirect (302/303 expected), and it duplicates LoggedIn's
    // status — confirm this is intentional.
    #[oai(status = 201)]
    LoggedOut(
        #[oai(header = "Location")] String,
        #[oai(header = "Cache-Control")] String,
    ),
    /// Redirect the browser to the OAuth provider's authorization URL.
    #[oai(status = 301)]
    LoginRedirect(
        #[oai(header = "Location")] String,
        #[oai(header = "Cache-Control")] String,
    ),
    /// CSRF/PKCE/token-exchange failure.
    #[oai(status = 500)]
    TokenError(Json<ErrorResponse>),
    /// Database failure while persisting or refreshing the user.
    #[oai(status = 500)]
    DatabaseError(Json<ErrorResponse>),
    /// Discord could not be reached or returned an error.
    #[oai(status = 503)]
    DiscordError(Json<ErrorResponse>),
}

/// Public view of a user profile returned by the auth endpoints.
#[derive(Debug, Eq, PartialEq, serde::Serialize, Object)]
struct UserInfo {
    id: String,
    username: String,
    display_name: Option<String>,
    avatar: Option<String>,
}
impl From<User> for UserInfo {
    /// Project a database `User` onto the public API profile shape; the
    /// account's `name` column becomes the API's `display_name` field.
    fn from(value: User) -> Self {
        let display_name = value.name;
        Self {
            id: value.id,
            username: value.username,
            display_name,
            avatar: value.avatar,
        }
    }
}
/// Responses for the `/me` current-user lookup.
#[derive(ApiResponse)]
enum UserInfoResponse {
    // NOTE(review): 201 for a plain GET lookup is unusual (200 expected);
    // confirm before clients start depending on it.
    #[oai(status = 201)]
    UserInfo(Json<UserInfo>),
    /// No session token present.
    #[oai(status = 401)]
    Unauthorized,
    #[oai(status = 500)]
    DatabaseError(Json<ErrorResponse>),
    #[oai(status = 503)]
    DiscordError(Json<ErrorResponse>),
}

impl From<UserInfoResponse> for LoginStatusResponse {
    /// Re-wrap a user-info response for the login endpoints.
    ///
    /// `Unauthorized` has no login-side equivalent and hits
    /// `unimplemented!()`: the callback flow always holds a fresh token, so
    /// this arm should be unreachable — it panics if that invariant breaks.
    fn from(value: UserInfoResponse) -> Self {
        match value {
            UserInfoResponse::UserInfo(json) => Self::LoggedIn(json),
            UserInfoResponse::Unauthorized => unimplemented!(),
            UserInfoResponse::DatabaseError(json) => Self::DatabaseError(json),
            UserInfoResponse::DiscordError(json) => Self::DiscordError(json),
        }
    }
}

/// Response for the CSRF-token helper endpoint.
#[derive(ApiResponse)]
enum CsrfResponse {
    #[oai(status = 201)]
    Token(PlainText<String>),
}
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
impl AuthApi {
    /// Fetch the caller's Discord profile with `token` and refresh it in the
    /// database, returning the stored profile as `UserInfo`.
    ///
    /// Discord failures map to `DiscordError` (503), persistence failures to
    /// `DatabaseError` (500); both are logged before being returned.
    async fn fetch_remote_user(
        pool: Data<&PgPool>,
        token: Token,
    ) -> Result<UserInfoResponse, UserInfoResponse> {
        crate::api_wrapper::discord::get_user_profile(token.access_token().secret())
            .await
            .map_err(|e| {
                tracing::event!(
                    target: "auth-discord",
                    tracing::Level::ERROR,
                    "Failed to communicate with Discord: {}",
                    e
                );
                UserInfoResponse::DiscordError(Json(e.into()))
            })?
            .refresh_in_database(&pool)
            .await
            .map(|user| UserInfoResponse::UserInfo(Json(user.into())))
            .map_err(|e| {
                tracing::event!(
                    target: "auth-discord",
                    tracing::Level::ERROR,
                    "Database error: {}",
                    e
                );
                UserInfoResponse::DatabaseError(Json(e.into()))
            })
    }

    /// Begin the Discord OAuth flow: stash the CSRF token and PKCE verifier
    /// in the session, then redirect the client to Discord's consent page.
    #[oai(path = "/signin/discord", method = "get")]
    async fn signin_discord(
        &self,
        oauth: Data<&DiscordOauthProvider>,
        session: &Session,
    ) -> LoginStatusResponse {
        let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
        session.set("csrf", csrf_token);
        session.set("pkce", pkce_verifier);
        tracing::event!(
            target: "auth-discord",
            tracing::Level::INFO,
            "Signin through Discord",
        );
        LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
    }

    /// OAuth callback hit by Discord after user consent.
    ///
    /// Validates the echoed CSRF `state`, retrieves the PKCE verifier from
    /// the session, exchanges the code for a token (stored in the session
    /// under "token"), then fetches and persists the user's profile.
    #[oai(path = "/callback/discord", method = "get")]
    async fn callback_discord(
        &self,
        Form(auth_request): Form<DiscordCallbackRequest>,
        oauth: Data<&DiscordOauthProvider>,
        pool: Data<&PgPool>,
        session: &Session,
    ) -> Result<LoginStatusResponse, LoginStatusResponse> {
        tracing::event!(
            target: "auth-discord",
            tracing::Level::INFO,
            "Discord callback",
        );
        // Missing session entries mean the flow was not started here (or
        // the session expired) — treated as a token error.
        let csrf_token = session.get::<CsrfToken>("csrf").ok_or_else(|| {
            LoginStatusResponse::TokenError(Json(ErrorResponse {
                code: 500,
                message: "Cannot fetch csrf token from session".to_string(),
                ..Default::default()
            }))
        })?;
        auth_request.check_token(&csrf_token)?;
        let pkce_verifier = session.get::<PkceCodeVerifier>("pkce").ok_or_else(|| {
            LoginStatusResponse::TokenError(Json(ErrorResponse {
                code: 500,
                message: "Cannot fetch pkce verifier from session".to_string(),
                ..Default::default()
            }))
        })?;
        let token = oauth
            .token(auth_request.code, pkce_verifier)
            .await
            .map_err(|e| LoginStatusResponse::TokenError(Json(e.into())))?;
        session.set("token", token.clone());
        Self::fetch_remote_user(pool, token)
            .await
            .map(std::convert::Into::into)
            .map_err(std::convert::Into::into)
    }

    /// Hand the client a fresh CSRF token (poem's CSRF middleware value).
    #[oai(path = "/csrf", method = "get")]
    async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
        CsrfResponse::Token(PlainText(token.0.clone()))
    }

    /// Destroy the whole session (token, csrf, pkce) and point back to `/`.
    #[oai(path = "/signout", method = "post")]
    async fn signout(&self, session: &Session) -> LoginStatusResponse {
        session.purge();
        LoginStatusResponse::LoggedOut("/".to_string(), "no-cache".to_string())
    }

    /// Return the current user's profile, re-fetched from Discord with the
    /// session's stored token; 401 when no token is in the session.
    #[oai(path = "/me", method = "get")]
    async fn user_info(
        &self,
        session: &Session,
        pool: Data<&PgPool>,
    ) -> Result<UserInfoResponse, UserInfoResponse> {
        let token = session
            .get::<Token>("token")
            .ok_or(UserInfoResponse::Unauthorized)?;
        Self::fetch_remote_user(pool, token).await
    }
}

View File

@ -0,0 +1,204 @@
use poem_openapi::Object;
use reqwest::Error as ReqwestError;

use crate::api_wrapper::ApiError as ApiWrapperError;
use crate::errors::ApiError;

/// JSON error body returned by the API: an HTTP-style `code`, a short
/// `message`, and optional free-form `details`.
#[derive(Debug, serde::Serialize, Default, Object, PartialEq, Eq)]
pub struct ErrorResponse {
    pub code: u16,
    pub message: String,
    pub details: Option<String>,
}
impl From<ApiError> for ErrorResponse {
fn from(value: ApiError) -> Self {
match value {
ApiError::Sql(e) => Self {
code: 500,
message: "SQL error".into(),
details: Some(e.to_string()),
},
ApiError::TokenError(e) => Self {
code: 500,
message: "OAuth token error".into(),
details: Some(e),
},
ApiError::Unauthorized => Self {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
},
ApiError::OptionError => Self {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
},
}
}
}
impl From<ReqwestError> for ErrorResponse {
    /// Transport-level failures talking to Discord become a 503.
    fn from(value: ReqwestError) -> Self {
        Self {
            code: 503,
            message: "Failed to communicate with Discord".into(),
            // Distinguish "Discord answered with an error status" from
            // "the request never completed at all".
            details: Some(value.status().map_or_else(
                || "Communication failed before we could hear back from Discord".into(),
                |status| format!("Discord sent back the error code {status}"),
            )),
        }
    }
}

impl From<ApiWrapperError> for ErrorResponse {
    /// Flatten a Discord API-wrapper error into the public error body.
    fn from(source: ApiWrapperError) -> Self {
        match source {
            ApiWrapperError::Reqwest(e) => e.into(),
            ApiWrapperError::Api(e) => Self {
                // Discord reports auth failures with code 0 and a message
                // beginning with "401"; surface those as a proper 401.
                code: if e.message.as_str().starts_with("401") {
                    401
                } else {
                    e.code
                },
                message: e.message,
                details: None,
            },
        }
    }
}

impl From<gejdr_core::sqlx::Error> for ErrorResponse {
    /// Database failures are deliberately opaque to clients: a generic 500
    /// with no details leaked.
    fn from(_value: gejdr_core::sqlx::Error) -> Self {
        Self {
            code: 500,
            message: "Internal database error".into(),
            ..Default::default()
        }
    }
}
// Conversion tests: one per From impl above, each pinning the exact
// ErrorResponse produced.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api_wrapper::{ApiError as ApiWrapperError, DiscordErrorResponse};

    #[test]
    fn conversion_from_sql_api_error_works() {
        let sql_error = ApiError::Sql(gejdr_core::sqlx::Error::ColumnNotFound(
            "COLUMN_NAME".to_string(),
        ));
        let final_error = ErrorResponse::from(sql_error);
        let expected_error = ErrorResponse {
            code: 500,
            message: "SQL error".into(),
            // Message text comes from sqlx's Display for ColumnNotFound.
            details: Some("no column found for name: COLUMN_NAME".into()),
        };
        assert_eq!(expected_error, final_error);
    }

    #[test]
    fn conversion_from_token_error_works() {
        let initial_error = ApiError::TokenError("TOKEN ERROR".into());
        let final_error: ErrorResponse = initial_error.into();
        let expected_error = ErrorResponse {
            code: 500,
            message: "OAuth token error".into(),
            details: Some("TOKEN ERROR".into()),
        };
        assert_eq!(expected_error, final_error);
    }

    #[test]
    fn conversion_from_unauthorized_works() {
        let initial_error = ApiError::Unauthorized;
        let final_error: ErrorResponse = initial_error.into();
        let expected_error = ErrorResponse {
            code: 401,
            message: "Unauthorized!".into(),
            ..Default::default()
        };
        assert_eq!(expected_error, final_error);
    }

    #[test]
    fn conversion_from_option_error_works() {
        let initial_error = ApiError::OptionError;
        let final_error: ErrorResponse = initial_error.into();
        let expected_error = ErrorResponse {
            code: 500,
            message: "Attempted to get a value, but none found".into(),
            ..Default::default()
        };
        assert_eq!(expected_error, final_error);
    }

    // Uses an unresolvable domain so the request fails before any HTTP
    // status exists — exercising the "no status" branch of the conversion.
    // NOTE(review): this needs DNS resolution to fail fast; may be slow on
    // networks that hijack NXDOMAIN.
    #[tokio::test]
    async fn conversion_from_reqwest_error() {
        let err = reqwest::get("https://example.example/401").await;
        assert!(err.is_err());
        let expected = ErrorResponse {
            code: 503,
            message: "Failed to communicate with Discord".into(),
            details: Some("Communication failed before we could hear back from Discord".into()),
        };
        let actual: ErrorResponse = err.err().unwrap().into();
        assert_eq!(expected, actual);
    }

    #[tokio::test]
    async fn conversion_from_apiwrappererror_with_reqwest_error() {
        let err = reqwest::get("https://example.example/401").await;
        assert!(err.is_err());
        let err = ApiWrapperError::Reqwest(err.err().unwrap());
        let expected = ErrorResponse {
            code: 503,
            message: "Failed to communicate with Discord".into(),
            details: Some("Communication failed before we could hear back from Discord".into()),
        };
        let actual: ErrorResponse = err.into();
        assert_eq!(expected, actual);
    }

    // A Discord error whose message starts with "401" must be remapped to
    // HTTP code 401 regardless of Discord's own numeric code.
    #[test]
    fn conversion_from_apiwrappererror_with_401_discord_error() {
        let err = ApiWrapperError::Api(DiscordErrorResponse {
            code: 0,
            message: "401: Unauthorized".into(),
        });
        let expected = ErrorResponse {
            code: 401,
            message: "401: Unauthorized".into(),
            ..Default::default()
        };
        let actual: ErrorResponse = err.into();
        assert_eq!(expected, actual);
    }

    // Non-401 Discord errors keep Discord's own code untouched.
    #[test]
    fn conversion_from_apiwrappererror_with_generic_discord_error() {
        let err = ApiWrapperError::Api(DiscordErrorResponse {
            code: 0,
            message: "Something else".into(),
        });
        let expected = ErrorResponse {
            code: 0,
            message: "Something else".into(),
            ..Default::default()
        };
        let actual: ErrorResponse = err.into();
        assert_eq!(expected, actual);
    }

    #[test]
    fn conversion_from_database_error() {
        let err = gejdr_core::sqlx::Error::PoolClosed;
        let expected = ErrorResponse {
            code: 500,
            message: "Internal database error".into(),
            ..Default::default()
        };
        let actual: ErrorResponse = err.into();
        assert_eq!(expected, actual);
    }
}

View File

@ -10,7 +10,7 @@ enum HealthResponse {
pub struct HealthApi; pub struct HealthApi;
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")] #[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
impl HealthApi { impl HealthApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn health_check(&self) -> HealthResponse { async fn health_check(&self) -> HealthResponse {
@ -23,7 +23,7 @@ impl HealthApi {
async fn health_check_works() { async fn health_check_works() {
let app = crate::get_test_app(None).await; let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app); let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/health-check").send().await; let resp = cli.get("/v1/api/health-check").send().await;
resp.assert_status_is_ok(); resp.assert_status_is_ok();
resp.assert_text("").await; resp.assert_text("").await;
} }

View File

@ -6,13 +6,19 @@ pub use health::HealthApi;
mod version; mod version;
pub use version::VersionApi; pub use version::VersionApi;
mod errors;
mod auth;
pub use auth::AuthApi;
#[derive(Tags)] #[derive(Tags)]
enum ApiCategory { enum ApiCategory {
Auth,
Health, Health,
Version, Version,
} }
pub(crate) struct Api; pub struct Api;
#[OpenApi] #[OpenApi]
impl Api {} impl Api {}

View File

@ -25,7 +25,7 @@ enum VersionResponse {
pub struct VersionApi; pub struct VersionApi;
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")] #[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
impl VersionApi { impl VersionApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> { async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
@ -38,7 +38,7 @@ impl VersionApi {
async fn version_works() { async fn version_works() {
let app = crate::get_test_app(None).await; let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app); let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/version").send().await; let resp = cli.get("/v1/api/version").send().await;
resp.assert_status_is_ok(); resp.assert_status_is_ok();
let json = resp.json().await; let json = resp.json().await;
let json_value = json.value(); let json_value = json.value();

View File

@ -1,4 +1,4 @@
use sqlx::ConnectOptions; use gejdr_core::database::Database;
#[derive(Debug, serde::Deserialize, Clone, Default)] #[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Settings { pub struct Settings {
@ -13,15 +13,7 @@ pub struct Settings {
impl Settings { impl Settings {
#[must_use] #[must_use]
pub fn web_address(&self) -> String { pub fn web_address(&self) -> String {
if self.debug { self.application.base_url.clone()
format!(
"{}:{}",
self.application.base_url.clone(),
self.application.port
)
} else {
self.application.base_url.clone()
}
} }
/// Multipurpose function that helps detect the current /// Multipurpose function that helps detect the current
@ -66,7 +58,7 @@ impl Settings {
)) ))
.add_source( .add_source(
config::Environment::with_prefix("APP") config::Environment::with_prefix("APP")
.prefix_separator("_") .prefix_separator("__")
.separator("__"), .separator("__"),
) )
.build()?; .build()?;
@ -84,35 +76,6 @@ pub struct ApplicationSettings {
pub protocol: String, pub protocol: String,
} }
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Database {
pub host: String,
pub port: u16,
pub name: String,
pub user: String,
pub password: String,
pub require_ssl: bool,
}
impl Database {
#[must_use]
pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
let ssl_mode = if self.require_ssl {
sqlx::postgres::PgSslMode::Require
} else {
sqlx::postgres::PgSslMode::Prefer
};
sqlx::postgres::PgConnectOptions::new()
.host(&self.host)
.username(&self.user)
.password(&self.password)
.port(self.port)
.ssl_mode(ssl_mode)
.database(&self.name)
.log_statements(tracing::log::LevelFilter::Trace)
}
}
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub enum Environment { pub enum Environment {
Development, Development,
@ -167,8 +130,8 @@ pub struct EmailSettings {
#[derive(Debug, serde::Deserialize, Clone, Default)] #[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Discord { pub struct Discord {
client_id: String, pub client_id: String,
client_secret: String, pub client_secret: String,
} }
#[cfg(test)] #[cfg(test)]
@ -235,7 +198,7 @@ mod tests {
#[test] #[test]
fn web_address_works() { fn web_address_works() {
let mut settings = Settings { let settings = Settings {
debug: false, debug: false,
application: ApplicationSettings { application: ApplicationSettings {
base_url: "127.0.0.1".to_string(), base_url: "127.0.0.1".to_string(),
@ -244,10 +207,7 @@ mod tests {
}, },
..Default::default() ..Default::default()
}; };
let expected_no_debug = "127.0.0.1".to_string(); let expected = "127.0.0.1".to_string();
let expected_debug = "127.0.0.1:3000".to_string(); assert_eq!(expected, settings.web_address());
assert_eq!(expected_no_debug, settings.web_address());
settings.debug = true;
assert_eq!(expected_debug, settings.web_address());
} }
} }

View File

@ -1,35 +1,36 @@
use poem::middleware::Cors; use gejdr_core::sqlx;
use poem::middleware::{AddDataEndpoint, CorsEndpoint}; use poem::middleware::{AddDataEndpoint, CorsEndpoint};
use poem::middleware::{CookieJarManagerEndpoint, Cors};
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
use poem::{EndpointExt, Route}; use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
use crate::oauth::DiscordOauthProvider;
use crate::route::AuthApi;
use crate::{ use crate::{
route::{Api, HealthApi, VersionApi}, route::{Api, HealthApi, VersionApi},
settings::Settings, settings::Settings,
}; };
#[must_use]
pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgres::PgPool {
tracing::event!(
target: "startup",
tracing::Level::INFO,
"connecting to database with configuration {:?}",
settings.clone()
);
sqlx::postgres::PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(2))
.connect_lazy_with(settings.get_connect_options())
}
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>; type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>; pub type App = AddDataEndpoint<
AddDataEndpoint<
AddDataEndpoint<
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
DiscordOauthProvider,
>,
sqlx::Pool<sqlx::Postgres>,
>,
Settings,
>;
pub struct Application { pub struct Application {
server: Server, server: Server,
app: poem::Route, app: poem::Route,
port: u16, port: u16,
database: sqlx::postgres::PgPool, database: sqlx::postgres::PgPool,
settings: Settings, pub settings: Settings,
} }
pub struct RunnableApplication { pub struct RunnableApplication {
@ -61,6 +62,8 @@ impl From<Application> for RunnableApplication {
let app = val let app = val
.app .app
.with(Cors::new()) .with(Cors::new())
.with(CookieSession::new(CookieConfig::default().secure(true)))
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
.data(val.database) .data(val.database)
.data(val.settings); .data(val.settings);
let server = val.server; let server = val.server;
@ -74,16 +77,16 @@ impl Application {
test_pool: Option<sqlx::postgres::PgPool>, test_pool: Option<sqlx::postgres::PgPool>,
) -> sqlx::postgres::PgPool { ) -> sqlx::postgres::PgPool {
let database_pool = let database_pool =
test_pool.map_or_else(|| get_connection_pool(&settings.database), |pool| pool); test_pool.map_or_else(|| settings.database.get_connection_pool(), |pool| pool);
if !cfg!(test) { if !cfg!(test) {
migrate_database(&database_pool).await; gejdr_core::database::Database::migrate(&database_pool).await;
} }
database_pool database_pool
} }
fn setup_app(settings: &Settings) -> poem::Route { fn setup_app(settings: &Settings) -> poem::Route {
let api_service = OpenApiService::new( let api_service = OpenApiService::new(
(Api, HealthApi, VersionApi), (Api, AuthApi, HealthApi, VersionApi),
settings.application.clone().name, settings.application.clone().name,
settings.application.clone().version, settings.application.clone().version,
); );
@ -104,6 +107,7 @@ impl Application {
}); });
poem::Server::new(tcp_listener) poem::Server::new(tcp_listener)
} }
pub async fn build( pub async fn build(
settings: Settings, settings: Settings,
test_pool: Option<sqlx::postgres::PgPool>, test_pool: Option<sqlx::postgres::PgPool>,
@ -133,10 +137,3 @@ impl Application {
self.port self.port
} }
} }
async fn migrate_database(pool: &sqlx::postgres::PgPool) {
sqlx::migrate!()
.run(pool)
.await
.expect("Failed to migrate the database");
}

1
gejdr-bot/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

6
gejdr-bot/Cargo.toml Normal file
View File

@ -0,0 +1,6 @@
# Discord bot binary crate; currently a placeholder entry point only.
[package]
name = "gejdr-bot"
version = "0.1.0"
edition = "2021"

# No dependencies yet.
[dependencies]

3
gejdr-bot/src/main.rs Normal file
View File

@ -0,0 +1,3 @@
/// Placeholder entry point for the bot binary; only prints a greeting.
fn main() {
    println!("Hello, world!");
}

1
gejdr-core/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

16
gejdr-core/Cargo.toml Normal file
View File

@ -0,0 +1,16 @@
# Shared core library crate: database access, models, telemetry.
[package]
name = "gejdr-core"
version = "0.1.0"
edition = "2021"

[dependencies]
# "serde" feature so chrono timestamps (de)serialize with the models.
chrono = { version = "0.4.38", features = ["serde"] }
serde = "1.0.215"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] }

# Default features disabled so the async runtime is chosen explicitly
# ("runtime-tokio") rather than inherited from sqlx's defaults.
[dependencies.sqlx]
version = "0.8.2"
default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]

View File

@ -0,0 +1,3 @@
-- Add down migration script here
-- Revert the initial schema: drop the users table first, then the
-- uuid-ossp extension installed by the matching up migration.
DROP TABLE IF EXISTS public.users;
DROP EXTENSION IF EXISTS "uuid-ossp";

View File

@ -0,0 +1,15 @@
-- Add up migration script here

-- Installed for uuid_generate_v4(). NOTE(review): the only table in
-- this migration uses a varchar primary key, so the extension looks
-- unused here -- confirm it is needed before keeping it.
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

CREATE TABLE IF NOT EXISTS public.users
(
    -- Remote account id: supplied by the identity provider, not
    -- generated locally (hence varchar rather than uuid).
    id character varying(255) NOT NULL,
    username character varying(255) NOT NULL,
    email character varying(255),
    -- Avatar reference (hash or URL -- cannot tell from here).
    avatar character varying(511),
    name character varying(255),
    created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
    last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (id),
    CONSTRAINT users_email_unique UNIQUE (email)
);

View File

@ -0,0 +1,50 @@
use sqlx::ConnectOptions;
/// Connection settings for the PostgreSQL database, deserialized from
/// the application configuration.
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Database {
    // Database server hostname or IP address.
    pub host: String,
    // Database server port.
    pub port: u16,
    // Name of the database to connect to.
    pub name: String,
    // User to authenticate as.
    pub user: String,
    // Password for `user`.
    pub password: String,
    // When true, SSL is required; otherwise it is preferred but a
    // plain connection is still accepted.
    pub require_ssl: bool,
}
impl Database {
    /// Build PostgreSQL connection options from these settings.
    ///
    /// SSL is required when `require_ssl` is set, otherwise preferred.
    /// Executed SQL statements are logged at TRACE level.
    #[must_use]
    pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
        let ssl_mode = if self.require_ssl {
            sqlx::postgres::PgSslMode::Require
        } else {
            sqlx::postgres::PgSslMode::Prefer
        };
        sqlx::postgres::PgConnectOptions::new()
            .host(&self.host)
            .username(&self.user)
            .password(&self.password)
            .port(self.port)
            .ssl_mode(ssl_mode)
            .database(&self.name)
            .log_statements(tracing::log::LevelFilter::Trace)
    }

    /// Create a lazy connection pool for the configured database.
    ///
    /// The pool connects on first use and gives up acquiring a
    /// connection after two seconds.
    #[must_use]
    pub fn get_connection_pool(&self) -> sqlx::postgres::PgPool {
        // Log only the connection coordinates. The previous `{:?}` on
        // the whole struct leaked the password into INFO-level logs.
        tracing::event!(
            target: "startup",
            tracing::Level::INFO,
            "connecting to database {} at {}:{} as {} (require_ssl: {})",
            self.name,
            self.host,
            self.port,
            self.user,
            self.require_ssl
        );
        sqlx::postgres::PgPoolOptions::new()
            .acquire_timeout(std::time::Duration::from_secs(2))
            .connect_lazy_with(self.get_connect_options())
    }

    /// Run all pending migrations bundled with this crate.
    ///
    /// # Panics
    /// Panics if a migration fails, since the application cannot run
    /// against an out-of-date schema.
    pub async fn migrate(pool: &sqlx::PgPool) {
        sqlx::migrate!()
            .run(pool)
            .await
            .expect("Failed to migrate the database");
    }
}

4
gejdr-core/src/lib.rs Normal file
View File

@ -0,0 +1,4 @@
// Shared core crate: database configuration/pooling/migrations, data
// models, and tracing setup.
pub mod database;
pub mod models;
pub mod telemetry;
// Re-exported; presumably so dependent crates share one sqlx version
// -- confirm against the workspace layout.
pub use sqlx;

View File

@ -0,0 +1,396 @@
use sqlx::PgPool;
// UTC timestamp type used for the users table's timestamptz columns.
type Timestampz = chrono::DateTime<chrono::Utc>;

/// User profile as returned by the remote identity provider (Discord,
/// judging by the `global_name` field -- TODO confirm).
#[derive(serde::Deserialize, PartialEq, Eq, Debug, Clone, Default)]
pub struct RemoteUser {
    // Remote account identifier; used as the local primary key.
    id: String,
    username: String,
    // Display name; may be absent.
    global_name: Option<String>,
    email: Option<String>,
    // Avatar reference (hash or URL -- cannot tell from here).
    avatar: Option<String>,
}
impl RemoteUser {
    /// Refresh in database the row related to the remote user. Maybe
    /// create a row for this user if needed.
    ///
    /// Returns the up-to-date local [`User`], or any database error
    /// bubbled up from the lookup, update, or insert.
    pub async fn refresh_in_database(self, pool: &PgPool) -> Result<User, sqlx::Error> {
        match User::find(pool, &self.id).await? {
            // Known id: merge the remote fields into the local row.
            Some(local_user) => local_user.update_from_remote(self).update(pool).await,
            // Unknown id: insert a fresh row.
            None => User::from(self).save(pool).await,
        }
    }
}
/// A row of the `users` table.
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq, Default, Clone)]
pub struct User {
    // Primary key; the remote account id.
    pub id: String,
    pub username: String,
    pub email: Option<String>,
    pub avatar: Option<String>,
    // Display name, filled from the remote `global_name`.
    pub name: Option<String>,
    // Set once when the row is first created.
    pub created_at: Timestampz,
    // Bumped whenever the row is refreshed from remote data.
    pub last_updated: Timestampz,
}
impl From<RemoteUser> for User {
    /// Promote a freshly fetched remote profile to a local user row,
    /// stamping both timestamps with the current time.
    fn from(remote: RemoteUser) -> Self {
        Self {
            id: remote.id,
            username: remote.username,
            name: remote.global_name,
            email: remote.email,
            avatar: remote.avatar,
            created_at: chrono::offset::Utc::now(),
            last_updated: chrono::offset::Utc::now(),
        }
    }
}
impl PartialEq<RemoteUser> for User {
    /// Compare a local user to a remote profile, ignoring the
    /// local-only timestamp fields.
    // The allow is needed because `self.name == other.global_name`
    // deliberately mixes differently-named fields across the types.
    #[allow(clippy::suspicious_operation_groupings)]
    fn eq(&self, other: &RemoteUser) -> bool {
        self.id == other.id
            && self.username == other.username
            && self.email == other.email
            && self.avatar == other.avatar
            && self.name == other.global_name
    }
}
impl PartialEq<User> for RemoteUser {
    /// Symmetric form: delegate to the `User == RemoteUser` impl.
    fn eq(&self, other: &User) -> bool {
        other == self
    }
}
impl User {
pub fn update_from_remote(self, from: RemoteUser) -> Self {
if self == from {
self
} else {
Self {
username: from.username,
email: from.email,
avatar: from.avatar,
name: from.global_name,
last_updated: chrono::offset::Utc::now(),
..self
}
}
}
pub async fn find(pool: &PgPool, id: &String) -> Result<Option<Self>, sqlx::Error> {
sqlx::query_as!(Self, r#"SELECT * FROM users WHERE id = $1"#, id)
.fetch_optional(pool)
.await
}
pub async fn save(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
sqlx::query_as!(
Self,
r#"
INSERT INTO users (id, username, email, avatar, name, created_at, last_updated)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING *
"#,
self.id,
self.username,
self.email,
self.avatar,
self.name,
self.created_at,
self.last_updated
)
.fetch_one(pool)
.await
}
pub async fn update(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
sqlx::query_as!(
Self,
r#"
UPDATE users
SET username = $1, email = $2, avatar = $3, name = $4, last_updated = $5
WHERE id = $6
RETURNING *
"#,
self.username,
self.email,
self.avatar,
self.name,
self.last_updated,
self.id
)
.fetch_one(pool)
.await
}
pub async fn save_or_update(&self, pool: &PgPool) -> Result<Self, sqlx::Error> {
if Self::find(pool, &self.id).await?.is_some() {
self.update(pool).await
} else {
self.save(pool).await
}
}
pub async fn delete(pool: &PgPool, id: &String) -> Result<u64, sqlx::Error> {
let rows_affected = sqlx::query!("DELETE FROM users WHERE id = $1", id)
.execute(pool)
.await?
.rows_affected();
Ok(rows_affected)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // From<RemoteUser> copies every field; timestamps are compared
    // against the produced value itself since the exact instant
    // cannot be predicted.
    #[test]
    fn convert_remote_user_to_local_user() {
        let remote = RemoteUser {
            id: "user-id".into(),
            username: "username".into(),
            global_name: None,
            email: Some("user@example.com".into()),
            avatar: None,
        };
        let local: User = remote.into();
        let expected = User {
            id: "user-id".into(),
            username: "username".into(),
            email: Some("user@example.com".into()),
            avatar: None,
            name: None,
            created_at: local.created_at,
            last_updated: local.last_updated,
        };
        assert_eq!(expected, local);
    }

    // Cross-type equality matches on the shared fields and detects a
    // difference in any of them (here: avatar).
    #[test]
    fn can_compare_remote_and_local_user() {
        let remote_same = RemoteUser {
            id: "user-id".into(),
            username: "username".into(),
            global_name: None,
            email: Some("user@example.com".into()),
            avatar: None,
        };
        let remote_different = RemoteUser {
            id: "user-id".into(),
            username: "username".into(),
            global_name: None,
            email: Some("user@example.com".into()),
            avatar: Some("some-hash".into()),
        };
        let local = User {
            id: "user-id".into(),
            username: "username".into(),
            email: Some("user@example.com".into()),
            avatar: None,
            name: None,
            created_at: chrono::offset::Utc::now(),
            last_updated: chrono::offset::Utc::now(),
        };
        assert_eq!(remote_same, local);
        assert_ne!(remote_different, local);
    }

    // Unknown remote users get inserted (runs against a fresh DB).
    #[sqlx::test]
    async fn add_new_remote_users_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let remote1 = RemoteUser {
            id: "id1".into(),
            username: "user1".into(),
            ..Default::default()
        };
        let remote2 = RemoteUser {
            id: "id2".into(),
            username: "user2".into(),
            ..Default::default()
        };
        remote1.refresh_in_database(&pool).await?;
        remote2.refresh_in_database(&pool).await?;
        let users = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, users.len());
        Ok(())
    }

    // With the "accounts" fixture pre-seeding two rows, refreshing
    // with changed usernames must update in place (no new rows) and
    // bump last_updated past created_at.
    #[sqlx::test(fixtures("accounts"))]
    async fn update_local_users_in_db_from_remote(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let users = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, users.len());
        let remote1 = RemoteUser {
            id: "id1".into(),
            username: "user1-new".into(),
            ..Default::default()
        };
        let remote2 = RemoteUser {
            id: "id2".into(),
            username: "user2-new".into(),
            ..Default::default()
        };
        remote1.refresh_in_database(&pool).await?;
        remote2.refresh_in_database(&pool).await?;
        let users = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, users.len());
        users
            .iter()
            .for_each(|user| assert!(user.last_updated > user.created_at));
        Ok(())
    }

    // Equal remote data must return the row unchanged (so
    // last_updated is not bumped).
    #[test]
    fn update_local_user_from_identical_remote_shouldnt_change_local() {
        let remote = RemoteUser {
            id: "id1".into(),
            username: "user1".into(),
            ..Default::default()
        };
        let local = User {
            id: "id1".into(),
            username: "user1".into(),
            ..Default::default()
        };
        let new_local = local.clone().update_from_remote(remote);
        assert_eq!(local, new_local);
    }

    // Differing remote data must produce a row equal to the remote
    // profile (via the cross-type PartialEq).
    #[test]
    fn update_local_user_from_different_remote() {
        let remote = RemoteUser {
            id: "id1".into(),
            username: "user2".into(),
            ..Default::default()
        };
        let local = User {
            id: "id1".into(),
            username: "user1".into(),
            ..Default::default()
        };
        let new_local = local.clone().update_from_remote(remote.clone());
        assert_ne!(remote, local);
        assert_eq!(remote, new_local);
    }

    // save() inserts exactly one row that round-trips intact.
    #[sqlx::test]
    async fn save_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let user = User {
            id: "id1".into(),
            username: "user1".into(),
            ..Default::default()
        };
        user.save(&pool).await?;
        let users = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(1, users.len());
        assert_eq!(Some(user), users.first().cloned());
        Ok(())
    }

    // update() overwrites the mutable columns of a fixture row.
    #[sqlx::test(fixtures("accounts"))]
    async fn update_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
            .fetch_one(&pool)
            .await?;
        assert!(db_user.name.is_none());
        let user = User {
            id: "id1".into(),
            username: "user1".into(),
            name: Some("Cool Name".into()),
            ..Default::default()
        };
        user.update(&pool).await?;
        let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
            .fetch_one(&pool)
            .await?;
        assert!(db_user.name.is_some());
        assert_eq!(Some("Cool Name".to_string()), db_user.name);
        Ok(())
    }

    // save_or_update() takes the insert path on an empty table.
    #[sqlx::test]
    async fn save_or_update_saves_if_no_exist(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(0, rows.len());
        let user = User {
            id: "id1".into(),
            username: "user1".into(),
            ..Default::default()
        };
        user.save_or_update(&pool).await?;
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(1, rows.len());
        let db_user = rows.first();
        assert_eq!(Some(user), db_user.cloned());
        Ok(())
    }

    // save_or_update() takes the update path when the id is seeded by
    // the fixture: the row count stays at two.
    #[sqlx::test(fixtures("accounts"))]
    async fn save_or_update_updates_if_exists(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, rows.len());
        let user = User {
            id: "id1".into(),
            username: "user1".into(),
            name: Some("Cool Nam".into()),
            ..Default::default()
        };
        user.save_or_update(&pool).await?;
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, rows.len());
        let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
            .fetch_one(&pool)
            .await?;
        assert_eq!(user.name, db_user.name);
        Ok(())
    }

    // delete() reports one affected row and leaves the other intact.
    #[sqlx::test(fixtures("accounts"))]
    async fn delete_removes_account_from_db(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, rows.len());
        let id = "id1".to_string();
        let deletions = User::delete(&pool, &id).await?;
        assert_eq!(1, deletions);
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(1, rows.len());
        Ok(())
    }

    // delete() with an unknown id affects zero rows.
    #[sqlx::test(fixtures("accounts"))]
    async fn delete_with_wrong_id_shouldnt_delete_anything(pool: sqlx::PgPool) -> sqlx::Result<()> {
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, rows.len());
        let id = "invalid".to_string();
        let deletions = User::delete(&pool, &id).await?;
        assert_eq!(0, deletions);
        let rows = sqlx::query_as!(User, "SELECT * FROM users")
            .fetch_all(&pool)
            .await?;
        assert_eq!(2, rows.len());
        Ok(())
    }
}

View File

@ -0,0 +1,2 @@
-- Fixture for the #[sqlx::test(fixtures("accounts"))] tests: two
-- minimal user rows with ids 'id1' and 'id2'.
INSERT INTO users (id, username) VALUES ('id1', 'user1');
INSERT INTO users (id, username) VALUES ('id2', 'user2');

View File

@ -0,0 +1 @@
// Database-backed data models.
pub mod accounts;

View File

@ -5,7 +5,7 @@ pub fn get_subscriber(debug: bool) -> impl tracing::Subscriber + Send + Sync {
let env_filter = if debug { "debug" } else { "info" }.to_string(); let env_filter = if debug { "debug" } else { "info" }.to_string();
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter)); .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter));
let stdout_log = tracing_subscriber::fmt::layer().pretty(); let stdout_log = tracing_subscriber::fmt::layer().pretty().with_test_writer();
let subscriber = tracing_subscriber::Registry::default() let subscriber = tracing_subscriber::Registry::default()
.with(env_filter) .with(env_filter)
.with(stdout_log); .with(stdout_log);

View File

@ -1,10 +1,7 @@
default: run mod backend 'gejdr-backend/backend.just'
mod docker
prepare: default: lint
cargo sqlx prepare
migrate:
sqlx migrate run
format: format:
cargo fmt --all cargo fmt --all
@ -12,38 +9,30 @@ format:
format-check: format-check:
cargo fmt --check --all cargo fmt --check --all
build: migrate:
cargo auditable build sqlx migrate run --source gejdr-core/migrations
build-release: build $SQLX_OFFLINE="1":
cargo auditable build --release cargo auditable build --bin gejdr-backend
cargo auditable build --bin gejdr-bot
run: docker-start build-release $SQLX_OFFLINE="1":
cargo auditable run cargo auditable build --release --bin gejdr-backend
cargo auditable build --release --bin gejdr-bot
run-no-docker:
cargo auditable run
lint: lint:
cargo clippy --all-targets cargo clippy --all-targets
msrv:
cargo msrv verify
release-build:
cargo auditable build --release
release-run:
cargo auditable run --release
audit: build audit: build
cargo audit bin target/debug/gege-jdr-backend cargo audit bin target/debug/gejdr-backend
cargo audit bin target/debug/gejdr-bot
audit-release: build-release audit-release: build-release
cargo audit bin target/release/gege-jdr-backend cargo audit bin target/release/gejdr-backend
cargo audit bin target/release/gejdr-bot
test: test:
cargo test cargo test --all-targets --all
coverage: coverage:
mkdir -p coverage mkdir -p coverage
@ -53,19 +42,10 @@ coverage-ci:
mkdir -p coverage mkdir -p coverage
cargo tarpaulin --config .tarpaulin.ci.toml cargo tarpaulin --config .tarpaulin.ci.toml
check-all: format-check lint msrv coverage audit check-all: format-check lint coverage audit
docker-build: docker-backend $SQLX_OFFLINE="1":
nix build .#docker nix build .#dockerBackend
docker-start:
docker compose -f docker/compose.dev.yml up -d
docker-stop:
docker compose -f docker/compose.dev.yml down
docker-logs:
docker compose -f docker/compose.dev.yml logs -f
## Local Variables: ## Local Variables:
## mode: makefile ## mode: makefile

View File

@ -1,5 +0,0 @@
-- Add down migration script here
ALTER TABLE IF EXISTS public.sessions DROP CONSTRAINT IF EXISTS sessions_user_id_users_fk;
DROP TABLE IF EXISTS public.sessions;
DROP TABLE IF EXISTS public.users;
DROP EXTENSION IF EXISTS "uuid-ossp";

View File

@ -1,29 +0,0 @@
-- Add up migration script here
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS public.users
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
email character varying(255) NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id),
CONSTRAINT users_email_unique UNIQUE (email)
);
CREATE TABLE IF NOT EXISTS public.sessions
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
user_id uuid NOT NULL,
session_id character varying NOT NULL,
expires_at timestamp with time zone NOT NULL,
PRIMARY KEY (id),
CONSTRAINT sessions_user_id_unique UNIQUE (user_id)
);
ALTER TABLE IF EXISTS public.sessions
ADD CONSTRAINT sessions_user_id_users_fk FOREIGN KEY (user_id)
REFERENCES public.users (id) MATCH SIMPLE
ON UPDATE CASCADE
ON DELETE CASCADE
NOT VALID;