Compare commits

..

7 Commits

Author SHA1 Message Date
578abef488
ci: restore publishing Docker images
Some checks failed
CI / tests (push) Failing after 1s
CI / tests (pull_request) Failing after 2s
Publish Docker image / publish (pull_request) Failing after 1h12m17s
2025-01-21 22:32:12 +01:00
31dd1c41e3
chore: update project metadata 2025-01-21 00:21:51 +01:00
582a11ac40
feat: add new crud macro for easier entity manipulation in DB 2025-01-21 00:21:51 +01:00
c783411159
feat: OAuth implementation with Discord
This commit separates the core features of géjdr from the backend as
these will also be used by the bot in the future.

This commit also updates the dependencies of the project. It also
removes the dependency lettre as well as the mailpit docker service
for developers as it appears clearer this project won’t send emails
anytime soon.

The publication of a docker image is also postponed until later.
2025-01-21 00:21:51 +01:00
ae10711e41
chore: update rust toolchain 2024-11-23 09:39:52 +01:00
ff90b1959f
feat: authentication through Discord OAuth2 2024-08-10 12:13:05 +02:00
1125bc4a38
Initial commit 2024-08-10 12:12:58 +02:00
54 changed files with 5995 additions and 253 deletions

View File

@ -1,15 +1,14 @@
;;; Directory Local Variables -*- no-byte-compile: t -*-
;;; For more information see (info "(emacs) Directory Variables")
((sql-mode
.
((eval . (progn
(setq-local lsp-sqls-connections
`(((driver . "postgresql")
(dataSourceName .
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
(getenv "DB_HOST")
(getenv "DB_PORT")
(getenv "DB_USER")
(getenv "DB_PASSWORD")
(getenv "DB_NAME")))))))))))
((rustic-mode . ((fill-column . 80)))
(sql-mode . ((eval . (progn
(setq-local lsp-sqls-connections
`(((driver . "postgresql")
(dataSourceName \,
(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
(getenv "DB_HOST")
(getenv "DB_PORT")
(getenv "DB_USER")
(getenv "DB_PASSWORD")
(getenv "DB_NAME")))))))))))

View File

@ -45,8 +45,6 @@ jobs:
run: nix develop --command -- just lint
- name: Audit
run: nix develop --command -- just audit
- name: Minimum supported Rust version check
run: nix develop --command -- just msrv
- name: Tests
run: nix develop --command -- just test
- name: Coverage
@ -62,7 +60,7 @@ jobs:
hide_complexity: false
indicators: true
output: both
thresholds: '60 80'
thresholds: '40 80'
- name: Add Coverage PR Comment
uses: mshick/add-pr-comment@v2
if: gitea.event_name == 'pull_request'

View File

@ -9,6 +9,7 @@ on:
pull_request:
branches:
- 'main'
- 'develop'
jobs:
publish:
@ -20,14 +21,40 @@ jobs:
with:
username: ${{ secrets.DOCKER_REGISTRY_USERNAME }}
password: ${{ secrets.DOCKER_REGISTRY_PASSWORD }}
- uses: cachix/install-nix-action@v27
registry: ${{ vars.REGISTRY }}
- uses: cachix/install-nix-action@v30
with:
nix_path: nixpkgs=channel:nixos-unstable
extra_nix_config: |
sandbox = true
- name: Build Docker image
run: nix develop --command -- just docker-build
- name: Load Docker image
run: docker load < result
run: export HOME=/tmp/nix-home && nix build .#dockerBackend
- name: Load Docker Image
run: |
docker load < ./gejdr-backend/result
- name: Docker Metadata action
uses: docker/metadata-action@v5.5.1
uses: docker/metadata-action@v5.6.1
id: meta
with:
image: tal-backend:latest
image: gejdr-backend:latest
tags:
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=sha
labels: |
org.opencontainers.image.title=Backend GéJDR
org.opencontainers.image.description=Backend for GéJDR
org.opencontainers.image.vendor=Lucien Cartier-Tilet <lucien@phundrak.com>
- name: Retag and publish Docker image for backend
env:
TAGS: ${{ steps.meta.outputs.tags }}
run: |
for tag in ${{ steps.meta.outputs.tags }}; do
newtag=${{ vars.REGISTRY }}/$tag
echo $newtag
# docker tag gejdr-backend:latest $newtag
# docker push $newtag
done

2
.gitignore vendored
View File

@ -3,3 +3,5 @@
.env
/result
/coverage/
/gejdr-backend/result
/gejdr-bot/result

View File

@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "DELETE FROM users WHERE id = $1",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Text"
]
},
"nullable": []
},
"hash": "50293c2e54af11d4c2a553e29b671cef087a159c6ee7182d8ca929ecb748f3b7"
}

View File

@ -0,0 +1,64 @@
{
"db_name": "PostgreSQL",
"query": "UPDATE users SET username = $1, email = $2, avatar = $3, name = $4, created_at = $5, last_updated = $6 WHERE id = $7 RETURNING *",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Timestamptz",
"Timestamptz",
"Text"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "752d93e987bb126c321656bb3977ea3ae61ec20d641e6c9adc72d580b2fcc538"
}

View File

@ -0,0 +1,58 @@
{
"db_name": "PostgreSQL",
"query": "SELECT * FROM users WHERE id = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "843923b9a0257cf80f1dff554e7dc8fdfc05f489328e8376513124dfb42996e3"
}

View File

@ -0,0 +1,64 @@
{
"db_name": "PostgreSQL",
"query": "INSERT INTO users (id, username, email, avatar, name, created_at, last_updated) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING *",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Varchar"
},
{
"ordinal": 1,
"name": "username",
"type_info": "Varchar"
},
{
"ordinal": 2,
"name": "email",
"type_info": "Varchar"
},
{
"ordinal": 3,
"name": "avatar",
"type_info": "Varchar"
},
{
"ordinal": 4,
"name": "name",
"type_info": "Varchar"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "last_updated",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Varchar",
"Timestamptz",
"Timestamptz"
]
},
"nullable": [
false,
false,
true,
true,
true,
false,
false
]
},
"hash": "fe91cc30858aaf2f0d328a64d7da3a5dee255f85c130c1d6d7ee3e41b647bbf5"
}

View File

@ -2,4 +2,6 @@
out = ["Xml"]
target-dir = "coverage"
output-dir = "coverage"
fail-under = 60
fail-under = 40
exclude-files = ["target/*"]
run-types = ["AllTargets"]

View File

@ -3,4 +3,6 @@ out = ["Html", "Lcov"]
skip-clean = true
target-dir = "coverage"
output-dir = "coverage"
fail-under = 60
fail-under = 40
exclude-files = ["target/*"]
run-types = ["AllTargets"]

4125
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,61 +1,9 @@
[package]
name = "gege-jdr-backend"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["phundrak"]
rust-version = "1.78"
[workspace]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
path = "src/lib.rs"
[[bin]]
path = "src/main.rs"
name = "gege-jdr-backend"
[dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.0", features = ["yaml"] }
dotenvy = "0.15.7"
serde = "1.0.204"
serde_json = "1.0.120"
thiserror = "1.0.63"
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] }
[dependencies.lettre]
version = "0.11.7"
default-features = false
features = [
"builder",
"hostname",
"pool",
"rustls-tls",
"tokio1",
"tokio1-rustls-tls",
"smtp-transport"
members = [
"gejdr-core",
"gejdr-bot",
"gejdr-backend",
"gejdr-macros"
]
[dependencies.poem]
version = "3.0.4"
default-features = false
features = [
"csrf",
"rustls",
"cookie",
"test"
]
[dependencies.poem-openapi]
version = "5.0.3"
features = ["chrono", "swagger-ui", "uuid"]
[dependencies.sqlx]
version = "0.8.0"
default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
resolver = "2"

View File

@ -31,25 +31,6 @@ services:
depends_on:
- db
# If you run GegeJdrBackend in production, DO NOT use mailpit.
# This tool is for testing only. Instead, you should use a real SMTP
# provider, such as Mailgun, Mailwhale, or Postal.
mailpit:
image: axllent/mailpit:latest
restart: unless-stopped
container_name: gege-jdr-backend-mailpit
ports:
- 127.0.0.1:8025:8025 # WebUI
- 127.0.0.1:1025:1025 # SMTP
volumes:
- gege_jdr_backend_mailpit:/data
environment:
MP_MAX_MESSAGES: 5000
MP_DATABASE: /data/mailpit.db
MP_SMTP_AUTH_ACCEPT_ANY: 1
MP_SMTP_AUTH_ALLOW_INSECURE: 1
volumes:
gege_jdr_backend_db_data:
gege_jdr_backend_pgadmin_data:
gege_jdr_backend_mailpit:

14
docker/mod.just Normal file
View File

@ -0,0 +1,14 @@
default: start
start:
docker compose -f compose.dev.yml up -d
stop:
docker compose -f compose.dev.yml down
logs:
docker compose -f compose.dev.yml logs -f
## Local Variables:
## mode: makefile
## End:

24
flake.lock generated
View File

@ -5,11 +5,11 @@
"systems": "systems"
},
"locked": {
"lastModified": 1710146030,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
@ -20,11 +20,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1723175592,
"narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=",
"lastModified": 1736344531,
"narHash": "sha256-8YVQ9ZbSfuUk2bUf2KRj60NRraLPKPS0Q4QFTbc+c2c=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "5e0ca22929f3342b19569b21b2f3462f053e497b",
"rev": "bffc22eb12172e6db3c5dde9e3e5628f8e3e7912",
"type": "github"
},
"original": {
@ -36,11 +36,11 @@
},
"nixpkgs_2": {
"locked": {
"lastModified": 1718428119,
"narHash": "sha256-WdWDpNaq6u1IPtxtYHHWpl5BmabtpmLnMAx0RdJ/vo8=",
"lastModified": 1728538411,
"narHash": "sha256-f0SBJz1eZ2yOuKUr5CA9BHULGXVSn6miBuUWdTyhUhU=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "e6cea36f83499eb4e9cd184c8a8e823296b50ad5",
"rev": "b69de56fac8c2b6f8fd27f2eca01dcda8e0a4221",
"type": "github"
},
"original": {
@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2"
},
"locked": {
"lastModified": 1723256423,
"narHash": "sha256-9iDTrfVM+mbcad31a47oqW8t8tfSA4C/si6F8F2DO/w=",
"lastModified": 1736476219,
"narHash": "sha256-+qyv3QqdZCdZ3cSO/cbpEY6tntyYjfe1bB12mdpNFaY=",
"owner": "oxalica",
"repo": "rust-overlay",
"rev": "615cfd85b4d9c51811a8d875374268fab5bd4089",
"rev": "de30cc5963da22e9742bbbbb9a3344570ed237b9",
"type": "github"
},
"original": {

View File

@ -18,30 +18,69 @@
rustc = rustVersion;
};
appName = "gege-jdr-backend";
backendVersion = "0.1.0";
appNameBackend = "gejdr-backend";
appNameBot = "gejdr-bot";
appRustBuild = rustPlatform.buildRustPackage {
pname = appName;
appRustBuildBackend = rustPlatform.buildRustPackage {
pname = appNameBackend;
version = backendVersion;
src = ./.;
cargoLock.lockFile = ./Cargo.lock;
buildPhase = ''
cd gejdr-backend
SQLX_OFFLINE="1" cargo build --release --bin gejdr-backend
'';
installPhase = ''
cargo install --path . --root "$out/bin/"
'';
};
appRustBuildBot = rustPlatform.buildRustPackage {
pname = appNameBot;
version = "0.1.0";
src = ./.;
cargoLock.lockFile = ./Cargo.lock;
buildPhase = ''
SQLX_OFFLINE="1" cargo build --release --bin gejdr-bot
'';
};
dockerImage = pkgs.dockerTools.buildLayeredImage {
name = appName;
dockerImageBackend = pkgs.dockerTools.buildLayeredImage {
name = appNameBackend;
tag = "latest";
config = {
Entrypoint = [ "${appRustBuild}/bin/${appName}" ];
Entrypoint = [ "${appRustBuildBackend}/bin/${appNameBackend}" ];
Env = [ "SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt" ];
Tag = "latest";
};
contents = [ appRustBuild pkgs.cacert ];
contents = [ appRustBuildBackend pkgs.cacert ];
};
dockerImageBot = pkgs.dockerTools.buildLayeredImage {
name = appNameBot;
tag = "latest";
fromImageTag = "latest";
config = {
Entrypoint = [ "${appRustBuildBot}/bin/${appNameBot}" ];
Env = [ "SSL_CERT_FILE=${pkgs.cacert}/etc/ssl/certs/ca-bundle.crt" ];
Tag = "latest";
Label = {
"maintainer" = "Lucien Cartier-Tilet <lucien@phundrak.com>";
"version" = backendVersion;
"description" = "Backend for GéJDR, a Discord helper tool for the TTRPG L'Anneau Unique";
"source" = "https://labs.phundrak.com/phundrak/gejdr-rs";
"licenses" = "AGPL-3.0-or-later";
};
};
contents = [ appRustBuildBot pkgs.cacert ];
};
in {
packages = {
rustPackage = appRustBuild;
docker = dockerImage;
backend = appRustBuildBot;
bot = appRustBuildBot;
dockerBackend = dockerImageBackend;
dockerBot = dockerImageBot;
};
defaultPackage = dockerImage;
defaultPackage = dockerImageBackend;
devShell = with pkgs; mkShell {
buildInputs = [
bacon
@ -49,7 +88,6 @@
cargo-audit
cargo-auditable
cargo-tarpaulin
cargo-msrv
just
rust-analyzer
(rustVersion.override { extensions = [ "rust-src" ]; })

45
gejdr-backend/Cargo.toml Normal file
View File

@ -0,0 +1,45 @@
[package]
name = "gejdr-backend"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
license = "AGPL-3.0-or-later"
repository = "https://labs.phundrak.com/phundrak/gejdr-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
path = "src/lib.rs"
[[bin]]
path = "src/main.rs"
name = "gejdr-backend"
[dependencies]
gejdr-core = { path = "../gejdr-core" }
chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.1", features = ["yaml"] }
dotenvy = "0.15.7"
oauth2 = "4.4.2"
quote = "1.0.37"
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
serde = "1.0.215"
serde_json = "1.0.133"
thiserror = "1.0.69"
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] }
[dependencies.poem]
version = "3.1.3"
default-features = false
features = ["csrf", "rustls", "cookie", "test", "session"]
[dependencies.poem-openapi]
version = "5.1.2"
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }

View File

@ -0,0 +1,18 @@
default: run
build $SQLX_OFFLINE="1":
pwd
cargo auditable build --bin gejdr-backend
build-release $SQLX_OFFLINE="1":
cargo auditable build --release --bin gejdr-backend
build-docker:
nix build .#dockerBackend
run:
cargo auditable run --bin gejdr-backend
## Local Variables:
## mode: makefile
## End:

View File

@ -16,3 +16,7 @@ email:
user: user@gege-jdr-backend.example
from: GegeJdrBackend <noreply@gege-jdr-backend.example>
password: hunter2
discord:
client_id: changeme
client_secret: changeme

View File

@ -3,5 +3,5 @@ debug: true
application:
protocol: http
host: 127.0.0.1
base_url: http://127.0.0.1:3000
host: localhost
base_url: http://localhost:3000

View File

@ -0,0 +1,55 @@
use super::{ApiError, DiscordErrorResponse};
use gejdr_core::models::accounts::RemoteUser;
static DISCORD_URL: &str = "https://discord.com/api/v10/";
pub async fn get_user_profile(token: &str) -> Result<RemoteUser, ApiError> {
let client = reqwest::Client::new();
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {token}").parse().unwrap(),
);
let response = client
.get(format!("{DISCORD_URL}/users/@me"))
.headers(headers)
.send()
.await;
match response {
Ok(resp) => {
if resp.status().is_success() {
resp.json::<RemoteUser>()
.await
.map_err(std::convert::Into::into)
} else {
let error_response = resp.json::<DiscordErrorResponse>().await;
match error_response {
Ok(val) => Err(ApiError::Api(val)),
Err(e) => Err(ApiError::Reqwest(e)),
}
}
}
Err(e) => Err(ApiError::Reqwest(e)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn user_profile_invalid_token_results_in_401() {
let res = get_user_profile("invalid").await;
assert!(res.is_err());
let err = res.err().unwrap();
println!("Error: {err:?}");
let expected = DiscordErrorResponse {
code: 0,
message: "401: Unauthorized".into(),
};
assert!(matches!(ApiError::Api(expected), _err));
}
// TODO: Find a way to mock calls to discord.com API with a
// successful reply
}

View File

@ -0,0 +1,41 @@
use reqwest::Error as ReqwestError;
use std::fmt::{self, Display};
use thiserror::Error;
pub mod discord;
#[derive(Debug, serde::Deserialize, PartialEq, Eq)]
pub struct DiscordErrorResponse {
pub message: String,
pub code: u16,
}
impl Display for DiscordErrorResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "DiscordErrorResponse: {} ({})", self.message, self.code)
}
}
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Reqwest error: {0}")]
Reqwest(#[from] ReqwestError),
#[error("API Error: {0}")]
Api(DiscordErrorResponse),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display_discord_error_response() {
let error = DiscordErrorResponse {
message: "Message".into(),
code: 42,
};
let error_str = error.to_string();
let expected = "DiscordErrorResponse: Message (42)".to_string();
assert_eq!(expected, error_str);
}
}

View File

@ -0,0 +1,14 @@
use thiserror::Error;
#[allow(dead_code)]
#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
Sql(#[from] gejdr_core::sqlx::Error),
#[error("OAuth token error: {0}")]
TokenError(String),
#[error("Unauthorized")]
Unauthorized,
#[error("Attempted to get a value, none found")]
OptionError,
}

View File

@ -5,10 +5,14 @@
#![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
pub mod route;
pub mod settings;
pub mod startup;
pub mod telemetry;
use gejdr_core::sqlx;
mod api_wrapper;
mod errors;
mod oauth;
mod route;
mod settings;
mod startup;
type MaybeListener = Option<poem::listener::TcpListener<String>>;
@ -16,8 +20,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
dotenvy::dotenv().ok();
let settings = settings::Settings::new().expect("Failed to read settings");
if !cfg!(test) {
let subscriber = telemetry::get_subscriber(settings.clone().debug);
telemetry::init_subscriber(subscriber);
let subscriber = gejdr_core::telemetry::get_subscriber(settings.clone().debug);
gejdr_core::telemetry::init_subscriber(subscriber);
}
tracing::event!(
target: "gege-jdr-backend",
@ -29,8 +33,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
tracing::event!(
target: "gege-jdr-backend",
tracing::Level::INFO,
"Listening on http://127.0.0.1:{}/",
application.port()
"Listening on {}",
application.settings.web_address()
);
application
}

View File

@ -1,5 +1,5 @@
#[cfg(not(tarpaulin_include))]
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
gege_jdr_backend::run(None).await
gejdr_backend::run(None).await
}

View File

@ -0,0 +1,62 @@
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
};
use reqwest::Url;
use crate::{errors::ApiError, settings::Settings};
use super::OauthProvider;
#[derive(Debug, Clone)]
pub struct DiscordOauthProvider {
client: BasicClient,
}
impl DiscordOauthProvider {
pub fn new(settings: &Settings) -> Self {
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
.expect("Invalid token endpoint URL");
let revocation_url =
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
.expect("Invalid revocation URL");
let client = BasicClient::new(
ClientId::new(settings.discord.client_id.clone()),
Some(ClientSecret::new(settings.discord.client_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
.set_revocation_uri(revocation_url);
Self { client }
}
}
impl OauthProvider for DiscordOauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
.set_pkce_challenge(pkce_challenge)
.url();
(auth_url, csrf_token, pkce_verifier)
}
async fn token(
&self,
code: String,
verifier: PkceCodeVerifier,
) -> Result<super::Token, ApiError> {
self.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(verifier)
.request_async(oauth2::reqwest::async_http_client)
.await
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
}
}

View File

@ -0,0 +1,17 @@
mod discord;
pub use discord::DiscordOauthProvider;
use oauth2::{
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
StandardTokenResponse,
};
use reqwest::Url;
use crate::errors::ApiError;
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
pub trait OauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
}

View File

@ -0,0 +1,220 @@
use gejdr_core::models::accounts::User;
use gejdr_core::sqlx::PgPool;
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
use poem::web::Data;
use poem::{session::Session, web::Form};
use poem_openapi::payload::{Json, PlainText};
use poem_openapi::{ApiResponse, Object, OpenApi};
use crate::oauth::{DiscordOauthProvider, OauthProvider};
use super::errors::ErrorResponse;
use super::ApiCategory;
type Token =
oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>;
pub struct AuthApi;
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
struct DiscordCallbackRequest {
code: String,
state: String,
}
impl DiscordCallbackRequest {
pub fn check_token(&self, token: &CsrfToken) -> Result<(), LoginStatusResponse> {
if *token.secret().to_string() == self.state {
Ok(())
} else {
Err(LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some(
"OAuth provider did not send a message that matches what was expected".into(),
),
})))
}
}
}
#[derive(ApiResponse)]
enum LoginStatusResponse {
#[oai(status = 201)]
LoggedIn(Json<UserInfo>),
#[oai(status = 201)]
LoggedOut(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 301)]
LoginRedirect(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 500)]
TokenError(Json<ErrorResponse>),
#[oai(status = 500)]
DatabaseError(Json<ErrorResponse>),
#[oai(status = 503)]
DiscordError(Json<ErrorResponse>),
}
#[derive(Debug, Eq, PartialEq, serde::Serialize, Object)]
struct UserInfo {
id: String,
username: String,
display_name: Option<String>,
avatar: Option<String>,
}
impl From<User> for UserInfo {
fn from(value: User) -> Self {
Self {
id: value.id,
username: value.username,
display_name: value.name,
avatar: value.avatar,
}
}
}
#[derive(ApiResponse)]
enum UserInfoResponse {
#[oai(status = 201)]
UserInfo(Json<UserInfo>),
#[oai(status = 401)]
Unauthorized,
#[oai(status = 500)]
DatabaseError(Json<ErrorResponse>),
#[oai(status = 503)]
DiscordError(Json<ErrorResponse>),
}
impl From<UserInfoResponse> for LoginStatusResponse {
fn from(value: UserInfoResponse) -> Self {
match value {
UserInfoResponse::UserInfo(json) => Self::LoggedIn(json),
UserInfoResponse::Unauthorized => unimplemented!(),
UserInfoResponse::DatabaseError(json) => Self::DatabaseError(json),
UserInfoResponse::DiscordError(json) => Self::DiscordError(json),
}
}
}
#[derive(ApiResponse)]
enum CsrfResponse {
#[oai(status = 201)]
Token(PlainText<String>),
}
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
impl AuthApi {
async fn fetch_remote_user(
pool: Data<&PgPool>,
token: Token,
) -> Result<UserInfoResponse, UserInfoResponse> {
crate::api_wrapper::discord::get_user_profile(token.access_token().secret())
.await
.map_err(|e| {
tracing::event!(
target: "auth-discord",
tracing::Level::ERROR,
"Failed to communicate with Discord: {}",
e
);
UserInfoResponse::DiscordError(Json(e.into()))
})?
.refresh_in_database(&pool)
.await
.map(|user| UserInfoResponse::UserInfo(Json(user.into())))
.map_err(|e| {
tracing::event!(
target: "auth-discord",
tracing::Level::ERROR,
"Database error: {}",
e
);
UserInfoResponse::DatabaseError(Json(e.into()))
})
}
#[oai(path = "/signin/discord", method = "get")]
async fn signin_discord(
&self,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
session.set("csrf", csrf_token);
session.set("pkce", pkce_verifier);
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Signin through Discord",
);
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
}
#[oai(path = "/callback/discord", method = "get")]
async fn callback_discord(
&self,
Form(auth_request): Form<DiscordCallbackRequest>,
oauth: Data<&DiscordOauthProvider>,
pool: Data<&PgPool>,
session: &Session,
) -> Result<LoginStatusResponse, LoginStatusResponse> {
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Discord callback",
);
let csrf_token = session.get::<CsrfToken>("csrf").ok_or_else(|| {
LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "Cannot fetch csrf token from session".to_string(),
..Default::default()
}))
})?;
auth_request.check_token(&csrf_token)?;
let pkce_verifier = session.get::<PkceCodeVerifier>("pkce").ok_or_else(|| {
LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "Cannot fetch pkce verifier from session".to_string(),
..Default::default()
}))
})?;
let token = oauth
.token(auth_request.code, pkce_verifier)
.await
.map_err(|e| LoginStatusResponse::TokenError(Json(e.into())))?;
session.set("token", token.clone());
Self::fetch_remote_user(pool, token)
.await
.map(std::convert::Into::into)
.map_err(std::convert::Into::into)
}
#[oai(path = "/csrf", method = "get")]
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
CsrfResponse::Token(PlainText(token.0.clone()))
}
#[oai(path = "/signout", method = "post")]
async fn signout(&self, session: &Session) -> LoginStatusResponse {
session.purge();
LoginStatusResponse::LoggedOut("/".to_string(), "no-cache".to_string())
}
#[oai(path = "/me", method = "get")]
async fn user_info(
&self,
session: &Session,
pool: Data<&PgPool>,
) -> Result<UserInfoResponse, UserInfoResponse> {
let token = session
.get::<Token>("token")
.ok_or(UserInfoResponse::Unauthorized)?;
Self::fetch_remote_user(pool, token).await
}
}

View File

@ -0,0 +1,204 @@
use poem_openapi::Object;
use reqwest::Error as ReqwestError;
use crate::api_wrapper::ApiError as ApiWrapperError;
use crate::errors::ApiError;
#[derive(Debug, serde::Serialize, Default, Object, PartialEq, Eq)]
pub struct ErrorResponse {
pub code: u16,
pub message: String,
pub details: Option<String>,
}
impl From<ApiError> for ErrorResponse {
fn from(value: ApiError) -> Self {
match value {
ApiError::Sql(e) => Self {
code: 500,
message: "SQL error".into(),
details: Some(e.to_string()),
},
ApiError::TokenError(e) => Self {
code: 500,
message: "OAuth token error".into(),
details: Some(e),
},
ApiError::Unauthorized => Self {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
},
ApiError::OptionError => Self {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
},
}
}
}
impl From<ReqwestError> for ErrorResponse {
fn from(value: ReqwestError) -> Self {
Self {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some(value.status().map_or_else(
|| "Communication failed before we could hear back from Discord".into(),
|status| format!("Discord sent back the error code {status}"),
)),
}
}
}
impl From<ApiWrapperError> for ErrorResponse {
fn from(source: ApiWrapperError) -> Self {
match source {
ApiWrapperError::Reqwest(e) => e.into(),
ApiWrapperError::Api(e) => Self {
code: if e.message.as_str().starts_with("401") {
401
} else {
e.code
},
message: e.message,
details: None,
},
}
}
}
impl From<gejdr_core::sqlx::Error> for ErrorResponse {
fn from(_value: gejdr_core::sqlx::Error) -> Self {
Self {
code: 500,
message: "Internal database error".into(),
..Default::default()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::api_wrapper::{ApiError as ApiWrapperError, DiscordErrorResponse};
#[test]
fn conversion_from_sql_api_error_works() {
let sql_error = ApiError::Sql(gejdr_core::sqlx::Error::ColumnNotFound(
"COLUMN_NAME".to_string(),
));
let final_error = ErrorResponse::from(sql_error);
let expected_error = ErrorResponse {
code: 500,
message: "SQL error".into(),
details: Some("no column found for name: COLUMN_NAME".into()),
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_token_error_works() {
let initial_error = ApiError::TokenError("TOKEN ERROR".into());
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some("TOKEN ERROR".into()),
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_unauthorized_works() {
let initial_error = ApiError::Unauthorized;
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_option_error_works() {
let initial_error = ApiError::OptionError;
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
};
assert_eq!(expected_error, final_error);
}
#[tokio::test]
async fn conversion_from_reqwest_error() {
let err = reqwest::get("https://example.example/401").await;
assert!(err.is_err());
let expected = ErrorResponse {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some("Communication failed before we could hear back from Discord".into()),
};
let actual: ErrorResponse = err.err().unwrap().into();
assert_eq!(expected, actual);
}
#[tokio::test]
async fn conversion_from_apiwrappererror_with_reqwest_error() {
let err = reqwest::get("https://example.example/401").await;
assert!(err.is_err());
let err = ApiWrapperError::Reqwest(err.err().unwrap());
let expected = ErrorResponse {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some("Communication failed before we could hear back from Discord".into()),
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
#[test]
fn conversion_from_apiwrappererror_with_401_discord_error() {
    // Discord's "401: Unauthorized" message is special-cased: the
    // response code becomes 401 and the message is passed through.
    let err = ApiWrapperError::Api(DiscordErrorResponse {
        code: 0,
        message: "401: Unauthorized".into(),
    });
    let expected = ErrorResponse {
        code: 401,
        message: "401: Unauthorized".into(),
        ..Default::default()
    };
    let actual: ErrorResponse = err.into();
    assert_eq!(expected, actual);
}
#[test]
fn conversion_from_apiwrappererror_with_generic_discord_error() {
    // Any other Discord error is forwarded verbatim: code and message
    // are copied without remapping.
    let err = ApiWrapperError::Api(DiscordErrorResponse {
        code: 0,
        message: "Something else".into(),
    });
    let expected = ErrorResponse {
        code: 0,
        message: "Something else".into(),
        ..Default::default()
    };
    let actual: ErrorResponse = err.into();
    assert_eq!(expected, actual);
}
#[test]
fn conversion_from_database_error() {
    // Any sqlx error is hidden behind a generic HTTP 500 so database
    // internals never leak to API clients.
    let err = gejdr_core::sqlx::Error::PoolClosed;
    let expected = ErrorResponse {
        code: 500,
        message: "Internal database error".into(),
        ..Default::default()
    };
    let actual: ErrorResponse = err.into();
    assert_eq!(expected, actual);
}
}

View File

@ -10,7 +10,7 @@ enum HealthResponse {
pub struct HealthApi;
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")]
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
impl HealthApi {
#[oai(path = "/", method = "get")]
async fn health_check(&self) -> HealthResponse {
@ -23,7 +23,7 @@ impl HealthApi {
async fn health_check_works() {
let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/health-check").send().await;
let resp = cli.get("/v1/api/health-check").send().await;
resp.assert_status_is_ok();
resp.assert_text("").await;
}

View File

@ -6,13 +6,19 @@ pub use health::HealthApi;
mod version;
pub use version::VersionApi;
mod errors;
mod auth;
pub use auth::AuthApi;
#[derive(Tags)]
enum ApiCategory {
Auth,
Health,
Version,
}
pub(crate) struct Api;
pub struct Api;
#[OpenApi]
impl Api {}

View File

@ -25,7 +25,7 @@ enum VersionResponse {
pub struct VersionApi;
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")]
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
impl VersionApi {
#[oai(path = "/", method = "get")]
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
@ -38,7 +38,7 @@ impl VersionApi {
async fn version_works() {
let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/version").send().await;
let resp = cli.get("/v1/api/version").send().await;
resp.assert_status_is_ok();
let json = resp.json().await;
let json_value = json.value();

View File

@ -1,9 +1,10 @@
use sqlx::ConnectOptions;
use gejdr_core::database::Database;
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Settings {
pub application: ApplicationSettings,
pub database: Database,
pub discord: Discord,
pub debug: bool,
pub email: EmailSettings,
pub frontend_url: String,
@ -12,15 +13,7 @@ pub struct Settings {
impl Settings {
#[must_use]
pub fn web_address(&self) -> String {
if self.debug {
format!(
"{}:{}",
self.application.base_url.clone(),
self.application.port
)
} else {
self.application.base_url.clone()
}
self.application.base_url.clone()
}
/// Multipurpose function that helps detect the current
@ -65,7 +58,7 @@ impl Settings {
))
.add_source(
config::Environment::with_prefix("APP")
.prefix_separator("_")
.prefix_separator("__")
.separator("__"),
)
.build()?;
@ -83,35 +76,6 @@ pub struct ApplicationSettings {
pub protocol: String,
}
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Database {
pub host: String,
pub port: u16,
pub name: String,
pub user: String,
pub password: String,
pub require_ssl: bool,
}
impl Database {
#[must_use]
pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
let ssl_mode = if self.require_ssl {
sqlx::postgres::PgSslMode::Require
} else {
sqlx::postgres::PgSslMode::Prefer
};
sqlx::postgres::PgConnectOptions::new()
.host(&self.host)
.username(&self.user)
.password(&self.password)
.port(self.port)
.ssl_mode(ssl_mode)
.database(&self.name)
.log_statements(tracing::log::LevelFilter::Trace)
}
}
#[derive(Debug, PartialEq, Eq)]
pub enum Environment {
Development,
@ -164,6 +128,12 @@ pub struct EmailSettings {
pub from: String,
}
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Discord {
pub client_id: String,
pub client_secret: String,
}
#[cfg(test)]
mod tests {
use super::*;
@ -228,7 +198,7 @@ mod tests {
#[test]
fn web_address_works() {
let mut settings = Settings {
let settings = Settings {
debug: false,
application: ApplicationSettings {
base_url: "127.0.0.1".to_string(),
@ -237,10 +207,7 @@ mod tests {
},
..Default::default()
};
let expected_no_debug = "127.0.0.1".to_string();
let expected_debug = "127.0.0.1:3000".to_string();
assert_eq!(expected_no_debug, settings.web_address());
settings.debug = true;
assert_eq!(expected_debug, settings.web_address());
let expected = "127.0.0.1".to_string();
assert_eq!(expected, settings.web_address());
}
}

View File

@ -1,35 +1,36 @@
use poem::middleware::Cors;
use gejdr_core::sqlx;
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
use poem::middleware::{CookieJarManagerEndpoint, Cors};
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService;
use crate::oauth::DiscordOauthProvider;
use crate::route::AuthApi;
use crate::{
route::{Api, HealthApi, VersionApi},
settings::Settings,
};
#[must_use]
pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgres::PgPool {
tracing::event!(
target: "startup",
tracing::Level::INFO,
"connecting to database with configuration {:?}",
settings.clone()
);
sqlx::postgres::PgPoolOptions::new()
.acquire_timeout(std::time::Duration::from_secs(2))
.connect_lazy_with(settings.get_connect_options())
}
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>;
pub type App = AddDataEndpoint<
AddDataEndpoint<
AddDataEndpoint<
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
DiscordOauthProvider,
>,
sqlx::Pool<sqlx::Postgres>,
>,
Settings,
>;
pub struct Application {
server: Server,
app: poem::Route,
port: u16,
database: sqlx::postgres::PgPool,
settings: Settings,
pub settings: Settings,
}
pub struct RunnableApplication {
@ -61,6 +62,8 @@ impl From<Application> for RunnableApplication {
let app = val
.app
.with(Cors::new())
.with(CookieSession::new(CookieConfig::default().secure(true)))
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
.data(val.database)
.data(val.settings);
let server = val.server;
@ -74,16 +77,16 @@ impl Application {
test_pool: Option<sqlx::postgres::PgPool>,
) -> sqlx::postgres::PgPool {
let database_pool =
test_pool.map_or_else(|| get_connection_pool(&settings.database), |pool| pool);
test_pool.map_or_else(|| settings.database.get_connection_pool(), |pool| pool);
if !cfg!(test) {
migrate_database(&database_pool).await;
gejdr_core::database::Database::migrate(&database_pool).await;
}
database_pool
}
fn setup_app(settings: &Settings) -> poem::Route {
let api_service = OpenApiService::new(
(Api, HealthApi, VersionApi),
(Api, AuthApi, HealthApi, VersionApi),
settings.application.clone().name,
settings.application.clone().version,
);
@ -104,6 +107,7 @@ impl Application {
});
poem::Server::new(tcp_listener)
}
pub async fn build(
settings: Settings,
test_pool: Option<sqlx::postgres::PgPool>,
@ -133,10 +137,3 @@ impl Application {
self.port
}
}
async fn migrate_database(pool: &sqlx::postgres::PgPool) {
sqlx::migrate!()
.run(pool)
.await
.expect("Failed to migrate the database");
}

1
gejdr-bot/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

10
gejdr-bot/Cargo.toml Normal file
View File

@ -0,0 +1,10 @@
[package]
name = "gejdr-bot"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
license = "AGPL-3.0-or-later"
repository = "https://labs.phundrak.com/phundrak/gejdr-rs"
[dependencies]

3
gejdr-bot/src/main.rs Normal file
View File

@ -0,0 +1,3 @@
/// Entry point of the gejdr-bot binary; currently a placeholder that
/// only prints a greeting.
fn main() {
    println!("Hello, world!");
}

1
gejdr-core/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/target

21
gejdr-core/Cargo.toml Normal file
View File

@ -0,0 +1,21 @@
[package]
name = "gejdr-core"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
license = "AGPL-3.0-or-later"
repository = "https://labs.phundrak.com/phundrak/gejdr-rs"
[dependencies]
chrono = { version = "0.4.38", features = ["serde"] }
serde = "1.0.215"
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] }
gejdr-macros = { path = "../gejdr-macros" }
[dependencies.sqlx]
version = "0.8.3"
default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]

View File

@ -0,0 +1,3 @@
-- Add down migration script here
-- Reverts the matching up migration: drop the users table first, then
-- the uuid-ossp extension the up migration enabled.
DROP TABLE IF EXISTS public.users;
DROP EXTENSION IF EXISTS "uuid-ossp";

View File

@ -0,0 +1,15 @@
-- Add up migration script here
-- NOTE(review): uuid-ossp is enabled but no column below uses a uuid
-- type — confirm the extension is actually needed.
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";

CREATE TABLE IF NOT EXISTS public.users
(
    -- Primary key; stored as text (presumably the remote provider's
    -- user id — confirm against the OAuth layer).
    id character varying(255) NOT NULL,
    username character varying(255) NOT NULL,
    -- Optional, but unique across all users when present.
    email character varying(255),
    avatar character varying(511),
    name character varying(255),
    created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
    last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (id),
    CONSTRAINT users_email_unique UNIQUE (email)
);

View File

@ -0,0 +1,50 @@
use sqlx::ConnectOptions;
/// PostgreSQL connection settings, deserialized from the application
/// configuration.
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Database {
    pub host: String,
    pub port: u16,
    /// Name of the database to connect to.
    pub name: String,
    pub user: String,
    pub password: String,
    /// When true SSL is required; otherwise it is preferred but a
    /// plain connection is accepted.
    pub require_ssl: bool,
}

impl Database {
    /// Build `PgConnectOptions` from these settings. SQL statements
    /// are logged at TRACE level.
    #[must_use]
    pub fn get_connect_options(&self) -> sqlx::postgres::PgConnectOptions {
        let ssl_mode = if self.require_ssl {
            sqlx::postgres::PgSslMode::Require
        } else {
            sqlx::postgres::PgSslMode::Prefer
        };
        sqlx::postgres::PgConnectOptions::new()
            .host(&self.host)
            .username(&self.user)
            .password(&self.password)
            .port(self.port)
            .ssl_mode(ssl_mode)
            .database(&self.name)
            .log_statements(tracing::log::LevelFilter::Trace)
    }

    /// Create a lazily-connected pool (no connection is attempted until
    /// first use) with a 2-second acquire timeout.
    #[must_use]
    pub fn get_connection_pool(&self) -> sqlx::postgres::PgPool {
        // NOTE(review): this logs the whole struct via Debug, which
        // includes the database password — consider redacting it.
        tracing::event!(
            target: "startup",
            tracing::Level::INFO,
            "connecting to database with configuration {:?}",
            self.clone()
        );
        sqlx::postgres::PgPoolOptions::new()
            .acquire_timeout(std::time::Duration::from_secs(2))
            .connect_lazy_with(self.get_connect_options())
    }

    /// Run the embedded sqlx migrations against `pool`.
    ///
    /// # Panics
    /// Panics if any migration fails to apply.
    pub async fn migrate(pool: &sqlx::PgPool) {
        sqlx::migrate!()
            .run(pool)
            .await
            .expect("Failed to migrate the database");
    }
}

4
gejdr-core/src/lib.rs Normal file
View File

@ -0,0 +1,4 @@
pub mod database;
pub mod models;
pub mod telemetry;
pub use sqlx;

View File

@ -0,0 +1,337 @@
use super::Crud;
use sqlx::PgPool;
type Timestampz = chrono::DateTime<chrono::Utc>;
/// User profile as deserialized from a remote identity provider.
///
/// NOTE(review): the field set (`global_name`, `avatar` hash) looks
/// like Discord's user object — confirm against the OAuth layer.
#[derive(serde::Deserialize, PartialEq, Eq, Debug, Clone, Default)]
pub struct RemoteUser {
    id: String,
    username: String,
    /// Display name; mapped to the local `User::name`.
    global_name: Option<String>,
    email: Option<String>,
    avatar: Option<String>,
}
impl RemoteUser {
    /// Refresh in database the row related to the remote user. Maybe
    /// create a row for this user if needed.
    ///
    /// # Errors
    /// Returns any error Postgres may have encountered.
    pub async fn refresh_in_database(self, pool: &PgPool) -> Result<User, sqlx::Error> {
        match User::find(pool, &self.id).await? {
            // Known user: merge remote fields into the local row.
            Some(local_user) => local_user.update_from_remote(self).update(pool).await,
            // First sighting: insert a fresh row.
            None => User::from(self).create(pool).await,
        }
    }
}
/// Local database representation of a user, persisted in the `users`
/// table through the derived `Crud` implementation.
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq, Default, Clone, Crud)]
#[crud(table = "users")]
pub struct User {
    /// Primary key; mirrors `RemoteUser::id`.
    #[crud(id = true)]
    pub id: String,
    pub username: String,
    pub email: Option<String>,
    pub avatar: Option<String>,
    /// Mirrors `RemoteUser::global_name`.
    pub name: Option<String>,
    pub created_at: Timestampz,
    pub last_updated: Timestampz,
}
impl From<RemoteUser> for User {
    /// Build a brand-new local user from remote data; both timestamps
    /// are initialized to the current time.
    fn from(value: RemoteUser) -> Self {
        Self {
            id: value.id,
            username: value.username,
            email: value.email,
            avatar: value.avatar,
            // The remote "global name" is stored locally as `name`.
            name: value.global_name,
            created_at: chrono::offset::Utc::now(),
            last_updated: chrono::offset::Utc::now(),
        }
    }
}
impl PartialEq<RemoteUser> for User {
    /// A local user equals a remote user when every shared field
    /// matches; local-only timestamps are ignored.
    #[allow(clippy::suspicious_operation_groupings)]
    fn eq(&self, other: &RemoteUser) -> bool {
        self.id == other.id
            && self.username == other.username
            && self.email == other.email
            && self.avatar == other.avatar
            // Local `name` corresponds to the remote `global_name`.
            && self.name == other.global_name
    }
}
impl PartialEq<User> for RemoteUser {
    /// Symmetric counterpart of `User == RemoteUser`; delegates to it.
    fn eq(&self, other: &User) -> bool {
        other == self
    }
}
impl User {
    /// Merge freshly fetched remote fields into this local user.
    /// Returns the user unchanged when nothing differs; otherwise the
    /// remote fields replace the local ones and `last_updated` is set
    /// to the current time.
    pub fn update_from_remote(self, from: RemoteUser) -> Self {
        // Early return keeps both timestamps stable when the remote
        // data matches what is already stored.
        if self == from {
            return self;
        }
        Self {
            username: from.username,
            email: from.email,
            avatar: from.avatar,
            name: from.global_name,
            last_updated: chrono::offset::Utc::now(),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn convert_remote_user_to_local_user() {
let remote = RemoteUser {
id: "user-id".into(),
username: "username".into(),
global_name: None,
email: Some("user@example.com".into()),
avatar: None,
};
let local: User = remote.into();
let expected = User {
id: "user-id".into(),
username: "username".into(),
email: Some("user@example.com".into()),
avatar: None,
name: None,
created_at: local.created_at,
last_updated: local.last_updated,
};
assert_eq!(expected, local);
}
#[test]
fn can_compare_remote_and_local_user() {
let remote_same = RemoteUser {
id: "user-id".into(),
username: "username".into(),
global_name: None,
email: Some("user@example.com".into()),
avatar: None,
};
let remote_different = RemoteUser {
id: "user-id".into(),
username: "username".into(),
global_name: None,
email: Some("user@example.com".into()),
avatar: Some("some-hash".into()),
};
let local = User {
id: "user-id".into(),
username: "username".into(),
email: Some("user@example.com".into()),
avatar: None,
name: None,
created_at: chrono::offset::Utc::now(),
last_updated: chrono::offset::Utc::now(),
};
assert_eq!(remote_same, local);
assert_ne!(remote_different, local);
}
#[sqlx::test]
async fn add_new_remote_users_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
let remote1 = RemoteUser {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let remote2 = RemoteUser {
id: "id2".into(),
username: "user2".into(),
..Default::default()
};
remote1.refresh_in_database(&pool).await?;
remote2.refresh_in_database(&pool).await?;
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, users.len());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn update_local_users_in_db_from_remote(pool: sqlx::PgPool) -> sqlx::Result<()> {
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, users.len());
let remote1 = RemoteUser {
id: "id1".into(),
username: "user1-new".into(),
..Default::default()
};
let remote2 = RemoteUser {
id: "id2".into(),
username: "user2-new".into(),
..Default::default()
};
remote1.refresh_in_database(&pool).await?;
remote2.refresh_in_database(&pool).await?;
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, users.len());
users
.iter()
.for_each(|user| assert!(user.last_updated > user.created_at));
Ok(())
}
#[test]
fn update_local_user_from_identical_remote_shouldnt_change_local() {
let remote = RemoteUser {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let local = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let new_local = local.clone().update_from_remote(remote);
assert_eq!(local, new_local);
}
#[test]
fn update_local_user_from_different_remote() {
let remote = RemoteUser {
id: "id1".into(),
username: "user2".into(),
..Default::default()
};
let local = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
let new_local = local.clone().update_from_remote(remote.clone());
assert_ne!(remote, local);
assert_eq!(remote, new_local);
}
#[sqlx::test]
async fn save_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
let user = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
user.create(&pool).await?;
let users = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(1, users.len());
assert_eq!(Some(user), users.first().cloned());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn update_user_in_database(pool: sqlx::PgPool) -> sqlx::Result<()> {
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
.fetch_one(&pool)
.await?;
assert!(db_user.name.is_none());
let user = User {
id: "id1".into(),
username: "user1".into(),
name: Some("Cool Name".into()),
..Default::default()
};
user.update(&pool).await?;
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
.fetch_one(&pool)
.await?;
assert!(db_user.name.is_some());
assert_eq!(Some("Cool Name".to_string()), db_user.name);
Ok(())
}
#[sqlx::test]
async fn save_or_update_saves_if_no_exist(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(0, rows.len());
let user = User {
id: "id1".into(),
username: "user1".into(),
..Default::default()
};
user.create_or_update(&pool).await?;
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(1, rows.len());
let db_user = rows.first();
assert_eq!(Some(user), db_user.cloned());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn save_or_update_updates_if_exists(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let user = User {
id: "id1".into(),
username: "user1".into(),
name: Some("Cool Nam".into()),
..Default::default()
};
user.create_or_update(&pool).await?;
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let db_user = sqlx::query_as!(User, "SELECT * FROM users WHERE id = 'id1'")
.fetch_one(&pool)
.await?;
assert_eq!(user.name, db_user.name);
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn delete_removes_account_from_db(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let id = "id1".to_string();
let deletions = User::delete_by_id(&pool, &id).await?;
assert_eq!(1, deletions);
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(1, rows.len());
Ok(())
}
#[sqlx::test(fixtures("accounts"))]
async fn delete_with_wrong_id_shouldnt_delete_anything(pool: sqlx::PgPool) -> sqlx::Result<()> {
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
let id = "invalid".to_string();
let deletions = User::delete_by_id(&pool, &id).await?;
assert_eq!(0, deletions);
let rows = sqlx::query_as!(User, "SELECT * FROM users")
.fetch_all(&pool)
.await?;
assert_eq!(2, rows.len());
Ok(())
}
}

View File

@ -0,0 +1,2 @@
INSERT INTO users (id, username) VALUES ('id1', 'user1');
INSERT INTO users (id, username) VALUES ('id2', 'user2');

View File

@ -0,0 +1,73 @@
pub mod accounts;
pub use gejdr_macros::Crud;
/// Generic create/read/update/delete operations for an entity stored
/// in Postgres, keyed by an identifier of type `Id`. Implementations
/// are normally generated with `#[derive(Crud)]`.
pub trait Crud<Id> {
    /// Find the entity in the database based on its identifier.
    ///
    /// # Errors
    /// Returns any error Postgres may have encountered
    fn find(
        pool: &sqlx::PgPool,
        id: &Id,
    ) -> impl std::future::Future<Output = sqlx::Result<Option<Self>>> + Send
    where
        Self: Sized;

    /// Create the entity in the database.
    ///
    /// # Errors
    /// Returns any error Postgres may have encountered
    fn create(
        &self,
        pool: &sqlx::PgPool,
    ) -> impl std::future::Future<Output = sqlx::Result<Self>> + Send
    where
        Self: Sized;

    /// Update an entity with a matching identifier in the database.
    ///
    /// # Errors
    /// Returns any error Postgres may have encountered
    fn update(
        &self,
        pool: &sqlx::PgPool,
    ) -> impl std::future::Future<Output = sqlx::Result<Self>> + Send
    where
        Self: Sized;

    /// Update an entity with a matching identifier in the database if
    /// it exists, create it otherwise.
    ///
    /// # Errors
    /// Returns any error Postgres may have encountered
    fn create_or_update(
        &self,
        pool: &sqlx::PgPool,
    ) -> impl std::future::Future<Output = sqlx::Result<Self>> + Send
    where
        Self: Sized;

    /// Delete the entity from the database if it exists.
    ///
    /// # Returns
    /// Returns the amount of rows affected by the deletion.
    ///
    /// # Errors
    /// Returns any error Postgres may have encountered
    fn delete(
        &self,
        pool: &sqlx::PgPool,
    ) -> impl std::future::Future<Output = sqlx::Result<u64>> + Send;

    /// Delete any entity with the identifier `id`.
    ///
    /// # Returns
    /// Returns the amount of rows affected by the deletion.
    ///
    /// # Errors
    /// Returns any error Postgres may have encountered
    fn delete_by_id(
        pool: &sqlx::PgPool,
        id: &Id,
    ) -> impl std::future::Future<Output = sqlx::Result<u64>> + Send;
}

View File

@ -5,7 +5,7 @@ pub fn get_subscriber(debug: bool) -> impl tracing::Subscriber + Send + Sync {
let env_filter = if debug { "debug" } else { "info" }.to_string();
let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(env_filter));
let stdout_log = tracing_subscriber::fmt::layer().pretty();
let stdout_log = tracing_subscriber::fmt::layer().pretty().with_test_writer();
let subscriber = tracing_subscriber::Registry::default()
.with(env_filter)
.with(stdout_log);

22
gejdr-macros/Cargo.toml Normal file
View File

@ -0,0 +1,22 @@
[package]
name = "gejdr-macros"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
license = "AGPL-3.0-or-later"
repository = "https://labs.phundrak.com/phundrak/gejdr-rs"
[lib]
proc-macro = true
[dependencies]
deluxe = "0.5.0"
proc-macro2 = "1.0.93"
quote = "1.0.38"
syn = "2.0.96"
[dependencies.sqlx]
version = "0.8.3"
default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]

View File

@ -0,0 +1,23 @@
/// Struct-level `#[crud(...)]` attributes extracted by the derive.
#[derive(deluxe::ExtractAttributes)]
#[deluxe(attributes(crud))]
pub struct CrudStructAttributes {
    /// Name of the database table backing the annotated struct.
    pub table: String,
}
/// Field-level `#[crud(...)]` attributes extracted by the derive.
#[derive(deluxe::ExtractAttributes, Clone)]
#[deluxe(attributes(crud))]
pub struct CrudFieldAttributes {
    /// True when this field is the entity's identifier (primary key).
    #[deluxe(default = false)]
    pub id: bool,
    /// Column name override; the field's own name is used when absent.
    #[deluxe(default = None)]
    pub column: Option<String>,
}
/// Fully resolved description of one struct field, consumed by the
/// SQL query generators.
#[derive(Clone)]
pub struct CrudField {
    /// The field's identifier in the struct.
    pub ident: syn::Ident,
    /// The raw field, kept for error spans.
    pub field: syn::Field,
    /// Database column backing this field.
    pub column: String,
    /// Whether this field is the primary key.
    pub id: bool,
    /// The field's Rust type.
    pub ty: syn::Type,
}

View File

@ -0,0 +1,188 @@
use ir::{CrudField, CrudFieldAttributes, CrudStructAttributes};
use quote::quote;
use syn::DeriveInput;
mod ir;
/// Extract the `#[crud(...)]` attributes of every field of the struct,
/// returning all fields plus the single field marked `#[crud(id = true)]`.
///
/// # Errors
/// Returns a `syn::Error` when zero fields or more than one field are
/// marked as the identifier.
///
/// # Panics
/// Panics on tuple-struct fields (no identifier) or when deluxe fails
/// to parse a field's attributes.
fn extract_crud_field_attrs(ast: &mut DeriveInput) -> deluxe::Result<(Vec<CrudField>, CrudField)> {
    let mut field_attrs: Vec<CrudField> = Vec::new();
    let mut identifier: Option<CrudField> = None;
    let mut identifier_counter = 0;
    // Capture the struct name up front so error messages can mention
    // it (the original emitted the literal text "{name}").
    let name = ast.ident.clone();
    if let syn::Data::Struct(s) = &mut ast.data {
        for field in &mut s.fields {
            let ident = field.clone().ident.unwrap();
            let ty = field.clone().ty;
            let attrs: CrudFieldAttributes =
                deluxe::extract_attributes(field).expect("Could not extract attributes from field");
            let field = CrudField {
                ident: ident.clone(),
                field: field.to_owned(),
                column: attrs.column.unwrap_or_else(|| ident.to_string()),
                id: attrs.id,
                ty,
            };
            if attrs.id {
                identifier_counter += 1;
                identifier = Some(field.clone());
            }
            // Report the error on the second id field's span.
            if identifier_counter > 1 {
                return Err(syn::Error::new_spanned(
                    field.field,
                    format!("Struct {name} can only have one identifier"),
                ));
            }
            field_attrs.push(field);
        }
    }
    if identifier_counter < 1 {
        Err(syn::Error::new_spanned(
            ast,
            format!("Struct {name} must have one identifier"),
        ))
    } else {
        Ok((field_attrs, identifier.unwrap()))
    }
}
/// Generate the `find` method: `SELECT * FROM <table> WHERE <id> = $1`,
/// yielding `Ok(None)` when no row matches.
fn generate_find_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream {
    let find_string = format!("SELECT * FROM {} WHERE {} = $1", table, id.column);
    let ty = &id.ty;
    quote! {
        async fn find(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<Option<Self>> {
            ::sqlx::query_as!(Self, #find_string, id)
                .fetch_optional(pool)
                .await
        }
    }
}
/// Generate the `create` method: `INSERT INTO <table> (...) VALUES
/// (...) RETURNING *`, binding every field in declaration order to the
/// placeholders `$1..$n`.
fn generate_create_query(table: &str, fields: &[CrudField]) -> proc_macro2::TokenStream {
    // One positional placeholder per field: "$1", "$2", ...
    let inputs: Vec<String> = (1..=fields.len()).map(|num| format!("${num}")).collect();
    let create_string = format!(
        "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
        table,
        fields
            .iter()
            .map(|v| v.column.clone())
            .collect::<Vec<String>>()
            .join(", "),
        inputs.join(", ")
    );
    let field_idents: Vec<syn::Ident> = fields.iter().map(|f| f.ident.clone()).collect();
    quote! {
        async fn create(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
            ::sqlx::query_as!(
                Self,
                #create_string,
                #(self.#field_idents),*
            )
            .fetch_one(pool)
            .await
        }
    }
}
/// Generate the `update` method: `UPDATE <table> SET col = $i, ...
/// WHERE <id> = $n RETURNING *`. Non-identifier fields are bound to
/// `$1..$n-1`; the identifier is bound last as `$n`.
fn generate_update_query(
    table: &str,
    fields: &[CrudField],
    id: &CrudField,
) -> proc_macro2::TokenStream {
    // Only non-id columns appear in the SET clause.
    let mut fields: Vec<&CrudField> = fields.iter().filter(|f| !f.id).collect();
    let update_columns = fields
        .iter()
        .enumerate()
        .map(|(i, &field)| format!("{} = ${}", field.column, i + 1))
        .collect::<Vec<String>>()
        .join(", ");
    let update_string = format!(
        "UPDATE {} SET {} WHERE {} = ${} RETURNING *",
        table,
        update_columns,
        id.column,
        fields.len() + 1
    );
    // Append the id so the bind order matches the placeholder order.
    fields.push(id);
    let field_idents: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
    quote! {
        async fn update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
            ::sqlx::query_as!(
                Self,
                #update_string,
                #(self.#field_idents),*
            )
            .fetch_one(pool)
            .await
        }
    }
}
/// Generate both deletion methods (`delete_by_id` and `delete`) from a
/// single `DELETE FROM <table> WHERE <id> = $1` statement; each
/// returns the number of rows removed.
fn generate_delete_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream {
    let delete_string = format!("DELETE FROM {} WHERE {} = $1", table, id.column);
    let ty = &id.ty;
    let ident = &id.ident;
    quote! {
        async fn delete_by_id(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<u64> {
            let rows_affected = ::sqlx::query!(#delete_string, id)
                .execute(pool)
                .await?
                .rows_affected();
            Ok(rows_affected)
        }

        async fn delete(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<u64> {
            let rows_affected = ::sqlx::query!(#delete_string, self.#ident)
                .execute(pool)
                .await?
                .rows_affected();
            Ok(rows_affected)
        }
    }
}
/// proc-macro2 implementation of the `Crud` derive: parses the
/// annotated struct, extracts its `#[crud(...)]` attributes, and emits
/// an `impl Crud<IdType> for Struct` with the generated SQL methods.
///
/// # Errors
/// Returns a `syn::Error` when the struct does not have exactly one
/// field marked `#[crud(id = true)]`.
///
/// # Panics
/// Panics when the input cannot be parsed as a `DeriveInput` or when
/// the struct-level attributes cannot be extracted.
pub fn crud_derive_macro2(
    item: proc_macro2::TokenStream,
) -> deluxe::Result<proc_macro2::TokenStream> {
    // parse
    let mut ast: DeriveInput = syn::parse2(item).expect("Failed to parse input");
    // extract struct attributes (the backing table name)
    let CrudStructAttributes { table } =
        deluxe::extract_attributes(&mut ast).expect("Could not extract attributes from struct");
    // extract field attributes and the single identifier field
    let (fields, id) = extract_crud_field_attrs(&mut ast)?;
    let ty = &id.ty;
    let id_ident = &id.ident;
    // define impl variables
    let ident = &ast.ident;
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
    // generate one token stream per CRUD method
    let find_query = generate_find_query(&table, &id);
    let create_query = generate_create_query(&table, &fields);
    let update_query = generate_update_query(&table, &fields, &id);
    let delete_query = generate_delete_query(&table, &id);
    let code = quote! {
        impl #impl_generics Crud<#ty> for #ident #type_generics #where_clause {
            #find_query
            #create_query
            #update_query
            // create_or_update needs the id field directly, so it is
            // generated inline rather than in a helper.
            async fn create_or_update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
                if Self::find(pool, &self.#id_ident).await?.is_some() {
                    self.update(pool).await
                } else {
                    self.create(pool).await
                }
            }
            #delete_query
        }
    };
    Ok(code)
}

19
gejdr-macros/src/lib.rs Normal file
View File

@ -0,0 +1,19 @@
#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![deny(clippy::nursery)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
mod crud;
use crud::crud_derive_macro2;
/// Generates CRUD code for Sqlx for a struct.
///
/// Derive errors (e.g. a missing or duplicated `#[crud(id = true)]`
/// field) are reported as compile errors at the offending span instead
/// of panicking the compiler.
///
/// # Panics
///
/// May still panic if the input token stream cannot be parsed at all.
#[proc_macro_derive(Crud, attributes(crud))]
pub fn crud_derive_macro(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
    crud_derive_macro2(item.into())
        // Surface derive errors as `compile_error!` invocations rather
        // than aborting rustc with a panic.
        .unwrap_or_else(|e| e.to_compile_error())
        .into()
}

View File

@ -1,10 +1,7 @@
default: run
mod backend 'gejdr-backend/backend.just'
mod docker
prepare:
cargo sqlx prepare
migrate:
sqlx migrate run
default: lint
format:
cargo fmt --all
@ -12,38 +9,30 @@ format:
format-check:
cargo fmt --check --all
build:
cargo auditable build
migrate:
sqlx migrate run --source gejdr-core/migrations
build-release:
cargo auditable build --release
build $SQLX_OFFLINE="1":
cargo auditable build --bin gejdr-backend
cargo auditable build --bin gejdr-bot
run: docker-start
cargo auditable run
run-no-docker:
cargo auditable run
build-release $SQLX_OFFLINE="1":
cargo auditable build --release --bin gejdr-backend
cargo auditable build --release --bin gejdr-bot
lint:
cargo clippy --all-targets
msrv:
cargo msrv verify
release-build:
cargo auditable build --release
release-run:
cargo auditable run --release
audit: build
cargo audit bin target/debug/gege-jdr-backend
cargo audit bin target/debug/gejdr-backend
cargo audit bin target/debug/gejdr-bot
audit-release: build-release
cargo audit bin target/release/gege-jdr-backend
cargo audit bin target/release/gejdr-backend
cargo audit bin target/release/gejdr-bot
test:
cargo test
cargo test --all-targets --all
coverage:
mkdir -p coverage
@ -53,19 +42,10 @@ coverage-ci:
mkdir -p coverage
cargo tarpaulin --config .tarpaulin.ci.toml
check-all: format-check lint msrv coverage audit
check-all: format-check lint coverage audit
docker-build:
nix build .#docker
docker-start:
docker compose -f docker/compose.dev.yml up -d
docker-stop:
docker compose -f docker/compose.dev.yml down
docker-logs:
docker compose -f docker/compose.dev.yml logs -f
docker-backend $SQLX_OFFLINE="1":
nix build .#dockerBackend
## Local Variables:
## mode: makefile

View File

@ -1,4 +1,4 @@
[toolchain]
channel = "1.78.0"
channel = "1.81.0"
components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ]
profile = "default"