feat: OAuth implementation with Discord
All checks were successful
CI / tests (push) Successful in 10m39s
CI / tests (pull_request) Successful in 11m17s

This commit separates the core features of GéJDR from the backend as
these will also be used by the bot in the future.

This commit also updates the dependencies of the project. It also
removes the dependency lettre as well as the mailpit docker service
for developers as it appears clearer this project won’t send emails
anytime soon.

The publication of a docker image is also postponed until later.
This commit is contained in:
2024-08-10 11:06:18 +02:00
parent 2013d04cf7
commit aac70e4131
49 changed files with 2699 additions and 720 deletions

43
gejdr-backend/Cargo.toml Normal file
View File

@@ -0,0 +1,43 @@
[package]
name = "gejdr-backend"
version = "0.1.0"
edition = "2021"
publish = false
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
path = "src/lib.rs"
[[bin]]
path = "src/main.rs"
name = "gejdr-backend"
[dependencies]
gejdr-core = { path = "../gejdr-core" }
chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.1", features = ["yaml"] }
dotenvy = "0.15.7"
oauth2 = "4.4.2"
quote = "1.0.37"
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
serde = "1.0.215"
serde_json = "1.0.133"
thiserror = "1.0.69"
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] }
[dependencies.poem]
version = "3.1.3"
default-features = false
features = ["csrf", "rustls", "cookie", "test", "session"]
[dependencies.poem-openapi]
version = "5.1.2"
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }

View File

@@ -0,0 +1,18 @@
default: run
build $SQLX_OFFLINE="1":
pwd
cargo auditable build --bin gejdr-backend
build-release $SQLX_OFFLINE="1":
cargo auditable build --release --bin gejdr-backend
build-docker:
nix build .#dockerBackend
run:
cargo auditable run --bin gejdr-backend
## Local Variables:
## mode: makefile
## End:

View File

@@ -0,0 +1,22 @@
application:
port: 3000
name: "GegeJdrBackend"
version: 0.1.0
database:
host: localhost
port: 5432
name: gege-jdr-backend
user: dev
password: password
require_ssl: false
email:
host: smtp.gege-jdr-backend.example
user: user@gege-jdr-backend.example
from: GegeJdrBackend <noreply@gege-jdr-backend.example>
password: hunter2
discord:
client_id: changeme
client_secret: changeme

View File

@@ -0,0 +1,7 @@
frontend_url: http://localhost:5173
debug: true
application:
protocol: http
host: localhost
base_url: http://localhost:3000

View File

@@ -0,0 +1,7 @@
debug: false
frontend_url: ""
application:
protocol: https
host: 0.0.0.0
base_url: ""

View File

@@ -0,0 +1,55 @@
use super::{ApiError, DiscordErrorResponse};
use gejdr_core::models::accounts::RemoteUser;
static DISCORD_URL: &str = "https://discord.com/api/v10/";
pub async fn get_user_profile(token: &str) -> Result<RemoteUser, ApiError> {
let client = reqwest::Client::new();
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::AUTHORIZATION,
format!("Bearer {token}").parse().unwrap(),
);
let response = client
.get(format!("{DISCORD_URL}/users/@me"))
.headers(headers)
.send()
.await;
match response {
Ok(resp) => {
if resp.status().is_success() {
resp.json::<RemoteUser>()
.await
.map_err(std::convert::Into::into)
} else {
let error_response = resp.json::<DiscordErrorResponse>().await;
match error_response {
Ok(val) => Err(ApiError::Api(val)),
Err(e) => Err(ApiError::Reqwest(e)),
}
}
}
Err(e) => Err(ApiError::Reqwest(e)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn user_profile_invalid_token_results_in_401() {
let res = get_user_profile("invalid").await;
assert!(res.is_err());
let err = res.err().unwrap();
println!("Error: {err:?}");
let expected = DiscordErrorResponse {
code: 0,
message: "401: Unauthorized".into(),
};
assert!(matches!(ApiError::Api(expected), _err));
}
// TODO: Find a way to mock calls to discord.com API with a
// successful reply
}

View File

@@ -0,0 +1,41 @@
use reqwest::Error as ReqwestError;
use std::fmt::{self, Display};
use thiserror::Error;
pub mod discord;
#[derive(Debug, serde::Deserialize, PartialEq, Eq)]
pub struct DiscordErrorResponse {
pub message: String,
pub code: u16,
}
impl Display for DiscordErrorResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "DiscordErrorResponse: {} ({})", self.message, self.code)
}
}
#[derive(Debug, Error)]
pub enum ApiError {
#[error("Reqwest error: {0}")]
Reqwest(#[from] ReqwestError),
#[error("API Error: {0}")]
Api(DiscordErrorResponse),
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn display_discord_error_response() {
let error = DiscordErrorResponse {
message: "Message".into(),
code: 42,
};
let error_str = error.to_string();
let expected = "DiscordErrorResponse: Message (42)".to_string();
assert_eq!(expected, error_str);
}
}

View File

@@ -0,0 +1,14 @@
use thiserror::Error;
#[allow(dead_code)]
#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
Sql(#[from] gejdr_core::sqlx::Error),
#[error("OAuth token error: {0}")]
TokenError(String),
#[error("Unauthorized")]
Unauthorized,
#[error("Attempted to get a value, none found")]
OptionError,
}

67
gejdr-backend/src/lib.rs Normal file
View File

@@ -0,0 +1,67 @@
#![deny(clippy::all)]
#![deny(clippy::pedantic)]
#![deny(clippy::nursery)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
use gejdr_core::sqlx;
mod api_wrapper;
mod errors;
mod oauth;
mod route;
mod settings;
mod startup;
type MaybeListener = Option<poem::listener::TcpListener<String>>;
async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> startup::Application {
dotenvy::dotenv().ok();
let settings = settings::Settings::new().expect("Failed to read settings");
if !cfg!(test) {
let subscriber = gejdr_core::telemetry::get_subscriber(settings.clone().debug);
gejdr_core::telemetry::init_subscriber(subscriber);
}
tracing::event!(
target: "gege-jdr-backend",
tracing::Level::DEBUG,
"Using these settings: {:?}",
settings.clone()
);
let application = startup::Application::build(settings.clone(), test_db, listener).await;
tracing::event!(
target: "gege-jdr-backend",
tracing::Level::INFO,
"Listening on {}",
application.settings.web_address()
);
application
}
/// # Errors
///
/// May return an error if the server encounters an error it cannot
/// recover from.
#[cfg(not(tarpaulin_include))]
pub async fn run(listener: MaybeListener) -> Result<(), std::io::Error> {
let application = prepare(listener, None).await;
application.make_app().run().await
}
#[cfg(test)]
async fn make_random_tcp_listener() -> poem::listener::TcpListener<String> {
let tcp_listener =
std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind a random TCP listener");
let port = tcp_listener.local_addr().unwrap().port();
poem::listener::TcpListener::bind(format!("127.0.0.1:{port}"))
}
#[cfg(test)]
async fn get_test_app(test_db: Option<sqlx::PgPool>) -> startup::App {
let tcp_listener = crate::make_random_tcp_listener().await;
crate::prepare(Some(tcp_listener), test_db)
.await
.make_app()
.into()
}

View File

@@ -0,0 +1,5 @@
#[cfg(not(tarpaulin_include))]
#[tokio::main]
async fn main() -> Result<(), std::io::Error> {
gejdr_backend::run(None).await
}

View File

@@ -0,0 +1,62 @@
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
};
use reqwest::Url;
use crate::{errors::ApiError, settings::Settings};
use super::OauthProvider;
#[derive(Debug, Clone)]
pub struct DiscordOauthProvider {
client: BasicClient,
}
impl DiscordOauthProvider {
pub fn new(settings: &Settings) -> Self {
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
.expect("Invalid token endpoint URL");
let revocation_url =
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
.expect("Invalid revocation URL");
let client = BasicClient::new(
ClientId::new(settings.discord.client_id.clone()),
Some(ClientSecret::new(settings.discord.client_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
.set_revocation_uri(revocation_url);
Self { client }
}
}
impl OauthProvider for DiscordOauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
.set_pkce_challenge(pkce_challenge)
.url();
(auth_url, csrf_token, pkce_verifier)
}
async fn token(
&self,
code: String,
verifier: PkceCodeVerifier,
) -> Result<super::Token, ApiError> {
self.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(verifier)
.request_async(oauth2::reqwest::async_http_client)
.await
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
}
}

View File

@@ -0,0 +1,17 @@
mod discord;
pub use discord::DiscordOauthProvider;
use oauth2::{
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
StandardTokenResponse,
};
use reqwest::Url;
use crate::errors::ApiError;
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
pub trait OauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
}

View File

@@ -0,0 +1,220 @@
use gejdr_core::models::accounts::User;
use gejdr_core::sqlx::PgPool;
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
use poem::web::Data;
use poem::{session::Session, web::Form};
use poem_openapi::payload::{Json, PlainText};
use poem_openapi::{ApiResponse, Object, OpenApi};
use crate::oauth::{DiscordOauthProvider, OauthProvider};
use super::errors::ErrorResponse;
use super::ApiCategory;
type Token =
oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>;
pub struct AuthApi;
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
struct DiscordCallbackRequest {
code: String,
state: String,
}
impl DiscordCallbackRequest {
pub fn check_token(&self, token: &CsrfToken) -> Result<(), LoginStatusResponse> {
if *token.secret().to_string() == self.state {
Ok(())
} else {
Err(LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some(
"OAuth provider did not send a message that matches what was expected".into(),
),
})))
}
}
}
#[derive(ApiResponse)]
enum LoginStatusResponse {
#[oai(status = 201)]
LoggedIn(Json<UserInfo>),
#[oai(status = 201)]
LoggedOut(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 301)]
LoginRedirect(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 500)]
TokenError(Json<ErrorResponse>),
#[oai(status = 500)]
DatabaseError(Json<ErrorResponse>),
#[oai(status = 503)]
DiscordError(Json<ErrorResponse>),
}
#[derive(Debug, Eq, PartialEq, serde::Serialize, Object)]
struct UserInfo {
id: String,
username: String,
display_name: Option<String>,
avatar: Option<String>,
}
impl From<User> for UserInfo {
fn from(value: User) -> Self {
Self {
id: value.id,
username: value.username,
display_name: value.name,
avatar: value.avatar,
}
}
}
#[derive(ApiResponse)]
enum UserInfoResponse {
#[oai(status = 201)]
UserInfo(Json<UserInfo>),
#[oai(status = 401)]
Unauthorized,
#[oai(status = 500)]
DatabaseError(Json<ErrorResponse>),
#[oai(status = 503)]
DiscordError(Json<ErrorResponse>),
}
impl From<UserInfoResponse> for LoginStatusResponse {
fn from(value: UserInfoResponse) -> Self {
match value {
UserInfoResponse::UserInfo(json) => Self::LoggedIn(json),
UserInfoResponse::Unauthorized => unimplemented!(),
UserInfoResponse::DatabaseError(json) => Self::DatabaseError(json),
UserInfoResponse::DiscordError(json) => Self::DiscordError(json),
}
}
}
#[derive(ApiResponse)]
enum CsrfResponse {
#[oai(status = 201)]
Token(PlainText<String>),
}
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
impl AuthApi {
async fn fetch_remote_user(
pool: Data<&PgPool>,
token: Token,
) -> Result<UserInfoResponse, UserInfoResponse> {
crate::api_wrapper::discord::get_user_profile(token.access_token().secret())
.await
.map_err(|e| {
tracing::event!(
target: "auth-discord",
tracing::Level::ERROR,
"Failed to communicate with Discord: {}",
e
);
UserInfoResponse::DiscordError(Json(e.into()))
})?
.refresh_in_database(&pool)
.await
.map(|user| UserInfoResponse::UserInfo(Json(user.into())))
.map_err(|e| {
tracing::event!(
target: "auth-discord",
tracing::Level::ERROR,
"Database error: {}",
e
);
UserInfoResponse::DatabaseError(Json(e.into()))
})
}
#[oai(path = "/signin/discord", method = "get")]
async fn signin_discord(
&self,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
session.set("csrf", csrf_token);
session.set("pkce", pkce_verifier);
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Signin through Discord",
);
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
}
#[oai(path = "/callback/discord", method = "get")]
async fn callback_discord(
&self,
Form(auth_request): Form<DiscordCallbackRequest>,
oauth: Data<&DiscordOauthProvider>,
pool: Data<&PgPool>,
session: &Session,
) -> Result<LoginStatusResponse, LoginStatusResponse> {
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Discord callback",
);
let csrf_token = session.get::<CsrfToken>("csrf").ok_or_else(|| {
LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "Cannot fetch csrf token from session".to_string(),
..Default::default()
}))
})?;
auth_request.check_token(&csrf_token)?;
let pkce_verifier = session.get::<PkceCodeVerifier>("pkce").ok_or_else(|| {
LoginStatusResponse::TokenError(Json(ErrorResponse {
code: 500,
message: "Cannot fetch pkce verifier from session".to_string(),
..Default::default()
}))
})?;
let token = oauth
.token(auth_request.code, pkce_verifier)
.await
.map_err(|e| LoginStatusResponse::TokenError(Json(e.into())))?;
session.set("token", token.clone());
Self::fetch_remote_user(pool, token)
.await
.map(std::convert::Into::into)
.map_err(std::convert::Into::into)
}
#[oai(path = "/csrf", method = "get")]
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
CsrfResponse::Token(PlainText(token.0.clone()))
}
#[oai(path = "/signout", method = "post")]
async fn signout(&self, session: &Session) -> LoginStatusResponse {
session.purge();
LoginStatusResponse::LoggedOut("/".to_string(), "no-cache".to_string())
}
#[oai(path = "/me", method = "get")]
async fn user_info(
&self,
session: &Session,
pool: Data<&PgPool>,
) -> Result<UserInfoResponse, UserInfoResponse> {
let token = session
.get::<Token>("token")
.ok_or(UserInfoResponse::Unauthorized)?;
Self::fetch_remote_user(pool, token).await
}
}

View File

@@ -0,0 +1,204 @@
use poem_openapi::Object;
use reqwest::Error as ReqwestError;
use crate::api_wrapper::ApiError as ApiWrapperError;
use crate::errors::ApiError;
#[derive(Debug, serde::Serialize, Default, Object, PartialEq, Eq)]
pub struct ErrorResponse {
pub code: u16,
pub message: String,
pub details: Option<String>,
}
impl From<ApiError> for ErrorResponse {
fn from(value: ApiError) -> Self {
match value {
ApiError::Sql(e) => Self {
code: 500,
message: "SQL error".into(),
details: Some(e.to_string()),
},
ApiError::TokenError(e) => Self {
code: 500,
message: "OAuth token error".into(),
details: Some(e),
},
ApiError::Unauthorized => Self {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
},
ApiError::OptionError => Self {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
},
}
}
}
impl From<ReqwestError> for ErrorResponse {
fn from(value: ReqwestError) -> Self {
Self {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some(value.status().map_or_else(
|| "Communication failed before we could hear back from Discord".into(),
|status| format!("Discord sent back the error code {status}"),
)),
}
}
}
impl From<ApiWrapperError> for ErrorResponse {
fn from(source: ApiWrapperError) -> Self {
match source {
ApiWrapperError::Reqwest(e) => e.into(),
ApiWrapperError::Api(e) => Self {
code: if e.message.as_str().starts_with("401") {
401
} else {
e.code
},
message: e.message,
details: None,
},
}
}
}
impl From<gejdr_core::sqlx::Error> for ErrorResponse {
fn from(_value: gejdr_core::sqlx::Error) -> Self {
Self {
code: 500,
message: "Internal database error".into(),
..Default::default()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::api_wrapper::{ApiError as ApiWrapperError, DiscordErrorResponse};
#[test]
fn conversion_from_sql_api_error_works() {
let sql_error = ApiError::Sql(gejdr_core::sqlx::Error::ColumnNotFound(
"COLUMN_NAME".to_string(),
));
let final_error = ErrorResponse::from(sql_error);
let expected_error = ErrorResponse {
code: 500,
message: "SQL error".into(),
details: Some("no column found for name: COLUMN_NAME".into()),
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_token_error_works() {
let initial_error = ApiError::TokenError("TOKEN ERROR".into());
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some("TOKEN ERROR".into()),
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_unauthorized_works() {
let initial_error = ApiError::Unauthorized;
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
};
assert_eq!(expected_error, final_error);
}
#[test]
fn conversion_from_option_error_works() {
let initial_error = ApiError::OptionError;
let final_error: ErrorResponse = initial_error.into();
let expected_error = ErrorResponse {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
};
assert_eq!(expected_error, final_error);
}
#[tokio::test]
async fn conversion_from_reqwest_error() {
let err = reqwest::get("https://example.example/401").await;
assert!(err.is_err());
let expected = ErrorResponse {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some("Communication failed before we could hear back from Discord".into()),
};
let actual: ErrorResponse = err.err().unwrap().into();
assert_eq!(expected, actual);
}
#[tokio::test]
async fn conversion_from_apiwrappererror_with_reqwest_error() {
let err = reqwest::get("https://example.example/401").await;
assert!(err.is_err());
let err = ApiWrapperError::Reqwest(err.err().unwrap());
let expected = ErrorResponse {
code: 503,
message: "Failed to communicate with Discord".into(),
details: Some("Communication failed before we could hear back from Discord".into()),
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
#[test]
fn conversion_from_apiwrappererror_with_401_discord_error() {
let err = ApiWrapperError::Api(DiscordErrorResponse {
code: 0,
message: "401: Unauthorized".into(),
});
let expected = ErrorResponse {
code: 401,
message: "401: Unauthorized".into(),
..Default::default()
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
#[test]
fn conversion_from_apiwrappererror_with_generic_discord_error() {
let err = ApiWrapperError::Api(DiscordErrorResponse {
code: 0,
message: "Something else".into(),
});
let expected = ErrorResponse {
code: 0,
message: "Something else".into(),
..Default::default()
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
#[test]
fn conversion_from_database_error() {
let err = gejdr_core::sqlx::Error::PoolClosed;
let expected = ErrorResponse {
code: 500,
message: "Internal database error".into(),
..Default::default()
};
let actual: ErrorResponse = err.into();
assert_eq!(expected, actual);
}
}

View File

@@ -0,0 +1,29 @@
use poem_openapi::{ApiResponse, OpenApi};
use super::ApiCategory;
#[derive(ApiResponse)]
enum HealthResponse {
#[oai(status = 200)]
Ok,
}
pub struct HealthApi;
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
impl HealthApi {
#[oai(path = "/", method = "get")]
async fn health_check(&self) -> HealthResponse {
tracing::event!(target: "gege-jdr-backend", tracing::Level::DEBUG, "Accessing health-check endpoint.");
HealthResponse::Ok
}
}
#[tokio::test]
async fn health_check_works() {
let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/api/health-check").send().await;
resp.assert_status_is_ok();
resp.assert_text("").await;
}

View File

@@ -0,0 +1,24 @@
use poem_openapi::{OpenApi, Tags};
mod health;
pub use health::HealthApi;
mod version;
pub use version::VersionApi;
mod errors;
mod auth;
pub use auth::AuthApi;
#[derive(Tags)]
enum ApiCategory {
Auth,
Health,
Version,
}
pub struct Api;
#[OpenApi]
impl Api {}

View File

@@ -0,0 +1,46 @@
use poem::Result;
use poem_openapi::{payload::Json, ApiResponse, Object, OpenApi};
use crate::settings::Settings;
use super::ApiCategory;
#[derive(Object, Debug, Clone, serde::Serialize, serde::Deserialize)]
struct Meta {
version: String,
}
impl From<poem::web::Data<&Settings>> for Meta {
fn from(value: poem::web::Data<&Settings>) -> Self {
let version = value.application.version.clone();
Self { version }
}
}
#[derive(ApiResponse)]
enum VersionResponse {
#[oai(status = 200)]
Version(Json<Meta>),
}
pub struct VersionApi;
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
impl VersionApi {
#[oai(path = "/", method = "get")]
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
tracing::event!(target: "gege-jdr-backend", tracing::Level::DEBUG, "Accessing version endpoint.");
Ok(VersionResponse::Version(Json(settings.into())))
}
}
#[tokio::test]
async fn version_works() {
let app = crate::get_test_app(None).await;
let cli = poem::test::TestClient::new(app);
let resp = cli.get("/v1/api/version").send().await;
resp.assert_status_is_ok();
let json = resp.json().await;
let json_value = json.value();
json_value.object().get("version").assert_not_null();
}

View File

@@ -0,0 +1,213 @@
use gejdr_core::database::Database;
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Settings {
pub application: ApplicationSettings,
pub database: Database,
pub discord: Discord,
pub debug: bool,
pub email: EmailSettings,
pub frontend_url: String,
}
impl Settings {
#[must_use]
pub fn web_address(&self) -> String {
self.application.base_url.clone()
}
/// Multipurpose function that helps detect the current
/// environment the application is running in using the
/// `APP_ENVIRONMENT` environment variable.
///
/// ```text
/// APP_ENVIRONMENT = development | dev | production | prod
/// ```
///
/// After detection, it loads the appropriate `.yaml` file. It
/// then loads the environment variables that overrides whatever
/// is set in the `.yaml` files. For this to work, the environment
/// variable MUST be in uppercase and start with `APP`, a `_`
/// separator, then the category of settings, followed by a `__`
/// separator, and finally the variable itself. For instance,
/// `APP__APPLICATION_PORT=3001` for `port` to be set as `3001`.
///
/// # Errors
///
/// Function may return an error if it fails to parse its config
/// files.
///
/// # Panics
///
/// Panics if the program fails to detect the directory it is
/// running in. Can also panic if it fails to parse the
/// environment variable `APP_ENVIRONMENT` and it fails to fall
/// back to its default value.
pub fn new() -> Result<Self, config::ConfigError> {
let base_path = std::env::current_dir().expect("Failed to determine the current directory");
let settings_directory = base_path.join("settings");
let environment: Environment = std::env::var("APP_ENVIRONMENT")
.unwrap_or_else(|_| "development".into())
.try_into()
.expect("Failed to parse APP_ENVIRONMENT");
let environment_filename = format!("{environment}.yaml");
let settings = config::Config::builder()
.add_source(config::File::from(settings_directory.join("base.yaml")))
.add_source(config::File::from(
settings_directory.join(environment_filename),
))
.add_source(
config::Environment::with_prefix("APP")
.prefix_separator("__")
.separator("__"),
)
.build()?;
settings.try_deserialize::<Self>()
}
}
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct ApplicationSettings {
pub name: String,
pub version: String,
pub port: u16,
pub host: String,
pub base_url: String,
pub protocol: String,
}
#[derive(Debug, PartialEq, Eq)]
pub enum Environment {
Development,
Production,
}
impl Default for Environment {
fn default() -> Self {
Self::Development
}
}
impl std::fmt::Display for Environment {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let self_str = match self {
Self::Development => "development",
Self::Production => "production",
};
write!(f, "{self_str}")
}
}
impl TryFrom<String> for Environment {
type Error = String;
fn try_from(value: String) -> Result<Self, Self::Error> {
match value.to_lowercase().as_str() {
"development" | "dev" => Ok(Self::Development),
"production" | "prod" => Ok(Self::Production),
other => Err(format!(
"{other} is not a supported environment. Use either `development` or `production`"
)),
}
}
}
impl TryFrom<&str> for Environment {
type Error = String;
fn try_from(value: &str) -> Result<Self, Self::Error> {
Self::try_from(value.to_string())
}
}
#[derive(serde::Deserialize, Clone, Debug, Default)]
pub struct EmailSettings {
pub host: String,
pub user: String,
pub password: String,
pub from: String,
}
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Discord {
pub client_id: String,
pub client_secret: String,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_environment_works() {
let default_environment = Environment::default();
assert_eq!(Environment::Development, default_environment);
}
#[test]
fn display_environment_works() {
let expected_prod = "production".to_string();
let expected_dev = "development".to_string();
let prod = Environment::Production.to_string();
let dev = Environment::Development.to_string();
assert_eq!(expected_prod, prod);
assert_eq!(expected_dev, dev);
}
#[test]
fn try_from_works() {
[
"DEVELOPMENT",
"DEVEloPmENT",
"Development",
"DEV",
"Dev",
"dev",
]
.iter()
.map(|v| (*v).to_string())
.for_each(|v| {
let environment = Environment::try_from(v);
assert!(environment.is_ok());
assert_eq!(Environment::Development, environment.unwrap());
});
[
"PRODUCTION",
"Production",
"PRODuction",
"production",
"PROD",
"Prod",
"prod",
]
.iter()
.map(|v| (*v).to_string())
.for_each(|v| {
let environment = Environment::try_from(v);
assert!(environment.is_ok());
assert_eq!(Environment::Production, environment.unwrap());
});
let environment = Environment::try_from("invalid");
assert!(environment.is_err());
assert_eq!(
"invalid is not a supported environment. Use either `development` or `production`"
.to_string(),
environment.err().unwrap()
);
}
#[test]
fn web_address_works() {
let settings = Settings {
debug: false,
application: ApplicationSettings {
base_url: "127.0.0.1".to_string(),
port: 3000,
..Default::default()
},
..Default::default()
};
let expected = "127.0.0.1".to_string();
assert_eq!(expected, settings.web_address());
}
}

View File

@@ -0,0 +1,139 @@
use gejdr_core::sqlx;
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
use poem::middleware::{CookieJarManagerEndpoint, Cors};
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService;
use crate::oauth::DiscordOauthProvider;
use crate::route::AuthApi;
use crate::{
route::{Api, HealthApi, VersionApi},
settings::Settings,
};
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
pub type App = AddDataEndpoint<
AddDataEndpoint<
AddDataEndpoint<
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
DiscordOauthProvider,
>,
sqlx::Pool<sqlx::Postgres>,
>,
Settings,
>;
pub struct Application {
server: Server,
app: poem::Route,
port: u16,
database: sqlx::postgres::PgPool,
pub settings: Settings,
}
pub struct RunnableApplication {
server: Server,
app: App,
}
impl RunnableApplication {
/// Runs the application until it decides to stop by itself.
///
/// # Errors
///
/// If the server encounters an internal error it cannot recover
/// from, it will forward it to this function which will forward
/// it to its caller.
pub async fn run(self) -> Result<(), std::io::Error> {
self.server.run(self.app).await
}
}
impl From<RunnableApplication> for App {
fn from(value: RunnableApplication) -> Self {
value.app
}
}
impl From<Application> for RunnableApplication {
fn from(val: Application) -> Self {
let app = val
.app
.with(Cors::new())
.with(CookieSession::new(CookieConfig::default().secure(true)))
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
.data(val.database)
.data(val.settings);
let server = val.server;
Self { server, app }
}
}
impl Application {
async fn setup_db(
settings: &Settings,
test_pool: Option<sqlx::postgres::PgPool>,
) -> sqlx::postgres::PgPool {
let database_pool =
test_pool.map_or_else(|| settings.database.get_connection_pool(), |pool| pool);
if !cfg!(test) {
gejdr_core::database::Database::migrate(&database_pool).await;
}
database_pool
}
fn setup_app(settings: &Settings) -> poem::Route {
let api_service = OpenApiService::new(
(Api, AuthApi, HealthApi, VersionApi),
settings.application.clone().name,
settings.application.clone().version,
);
let ui = api_service.swagger_ui();
poem::Route::new().nest("/", api_service).nest("/docs", ui)
}
fn setup_server(
settings: &Settings,
tcp_listener: Option<poem::listener::TcpListener<String>>,
) -> Server {
let tcp_listener = tcp_listener.unwrap_or_else(|| {
let address = format!(
"{}:{}",
settings.application.host, settings.application.port
);
poem::listener::TcpListener::bind(address)
});
poem::Server::new(tcp_listener)
}
pub async fn build(
settings: Settings,
test_pool: Option<sqlx::postgres::PgPool>,
tcp_listener: Option<poem::listener::TcpListener<String>>,
) -> Self {
let database_pool = Self::setup_db(&settings, test_pool).await;
let port = settings.application.port;
let app = Self::setup_app(&settings);
let server = Self::setup_server(&settings, tcp_listener);
Self {
server,
app,
port,
database: database_pool,
settings,
}
}
/// Make the app runnable.
#[must_use]
pub fn make_app(self) -> RunnableApplication {
self.into()
}
#[must_use]
pub const fn port(&self) -> u16 {
self.port
}
}