generated from phundrak/rust-poem-openapi-template
feat: OAuth implementation with Discord
This commit separates the core features of GéJDR from the backend as these will also be used by the bot in the future. This commit also updates the dependencies of the project. It also removes the dependency lettre as well as the mailpit docker service for developers as it appears clearer this project won’t send emails anytime soon. The publication of a docker image is also postponed until later.
This commit is contained in:
43
gejdr-backend/Cargo.toml
Normal file
43
gejdr-backend/Cargo.toml
Normal file
@@ -0,0 +1,43 @@
|
||||
[package]
|
||||
name = "gejdr-backend"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
authors = ["Lucien Cartier-Tilet <lucien@phundrak.com>"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/main.rs"
|
||||
name = "gejdr-backend"
|
||||
|
||||
[dependencies]
|
||||
gejdr-core = { path = "../gejdr-core" }
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
config = { version = "0.14.1", features = ["yaml"] }
|
||||
dotenvy = "0.15.7"
|
||||
oauth2 = "4.4.2"
|
||||
quote = "1.0.37"
|
||||
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
|
||||
serde = "1.0.215"
|
||||
serde_json = "1.0.133"
|
||||
thiserror = "1.0.69"
|
||||
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
||||
uuid = { version = "1.11.0", features = ["v4", "serde"] }
|
||||
|
||||
[dependencies.poem]
|
||||
version = "3.1.3"
|
||||
default-features = false
|
||||
features = ["csrf", "rustls", "cookie", "test", "session"]
|
||||
|
||||
[dependencies.poem-openapi]
|
||||
version = "5.1.2"
|
||||
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }
|
||||
18
gejdr-backend/backend.just
Normal file
18
gejdr-backend/backend.just
Normal file
@@ -0,0 +1,18 @@
|
||||
default: run
|
||||
|
||||
build $SQLX_OFFLINE="1":
|
||||
pwd
|
||||
cargo auditable build --bin gejdr-backend
|
||||
|
||||
build-release $SQLX_OFFLINE="1":
|
||||
cargo auditable build --release --bin gejdr-backend
|
||||
|
||||
build-docker:
|
||||
nix build .#dockerBackend
|
||||
|
||||
run:
|
||||
cargo auditable run --bin gejdr-backend
|
||||
|
||||
## Local Variables:
|
||||
## mode: makefile
|
||||
## End:
|
||||
22
gejdr-backend/settings/base.yaml
Normal file
22
gejdr-backend/settings/base.yaml
Normal file
@@ -0,0 +1,22 @@
|
||||
application:
|
||||
port: 3000
|
||||
name: "GegeJdrBackend"
|
||||
version: 0.1.0
|
||||
|
||||
database:
|
||||
host: localhost
|
||||
port: 5432
|
||||
name: gege-jdr-backend
|
||||
user: dev
|
||||
password: password
|
||||
require_ssl: false
|
||||
|
||||
email:
|
||||
host: smtp.gege-jdr-backend.example
|
||||
user: user@gege-jdr-backend.example
|
||||
from: GegeJdrBackend <noreply@gege-jdr-backend.example>
|
||||
password: hunter2
|
||||
|
||||
discord:
|
||||
client_id: changeme
|
||||
client_secret: changeme
|
||||
7
gejdr-backend/settings/development.yaml
Normal file
7
gejdr-backend/settings/development.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
frontend_url: http://localhost:5173
|
||||
debug: true
|
||||
|
||||
application:
|
||||
protocol: http
|
||||
host: localhost
|
||||
base_url: http://localhost:3000
|
||||
7
gejdr-backend/settings/production.yaml
Normal file
7
gejdr-backend/settings/production.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
debug: false
|
||||
frontend_url: ""
|
||||
|
||||
application:
|
||||
protocol: https
|
||||
host: 0.0.0.0
|
||||
base_url: ""
|
||||
55
gejdr-backend/src/api_wrapper/discord.rs
Normal file
55
gejdr-backend/src/api_wrapper/discord.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use super::{ApiError, DiscordErrorResponse};
|
||||
use gejdr_core::models::accounts::RemoteUser;
|
||||
|
||||
static DISCORD_URL: &str = "https://discord.com/api/v10/";
|
||||
|
||||
pub async fn get_user_profile(token: &str) -> Result<RemoteUser, ApiError> {
|
||||
let client = reqwest::Client::new();
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
format!("Bearer {token}").parse().unwrap(),
|
||||
);
|
||||
let response = client
|
||||
.get(format!("{DISCORD_URL}/users/@me"))
|
||||
.headers(headers)
|
||||
.send()
|
||||
.await;
|
||||
match response {
|
||||
Ok(resp) => {
|
||||
if resp.status().is_success() {
|
||||
resp.json::<RemoteUser>()
|
||||
.await
|
||||
.map_err(std::convert::Into::into)
|
||||
} else {
|
||||
let error_response = resp.json::<DiscordErrorResponse>().await;
|
||||
match error_response {
|
||||
Ok(val) => Err(ApiError::Api(val)),
|
||||
Err(e) => Err(ApiError::Reqwest(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => Err(ApiError::Reqwest(e)),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn user_profile_invalid_token_results_in_401() {
|
||||
let res = get_user_profile("invalid").await;
|
||||
assert!(res.is_err());
|
||||
let err = res.err().unwrap();
|
||||
println!("Error: {err:?}");
|
||||
let expected = DiscordErrorResponse {
|
||||
code: 0,
|
||||
message: "401: Unauthorized".into(),
|
||||
};
|
||||
assert!(matches!(ApiError::Api(expected), _err));
|
||||
}
|
||||
|
||||
// TODO: Find a way to mock calls to discord.com API with a
|
||||
// successful reply
|
||||
}
|
||||
41
gejdr-backend/src/api_wrapper/mod.rs
Normal file
41
gejdr-backend/src/api_wrapper/mod.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use reqwest::Error as ReqwestError;
|
||||
use std::fmt::{self, Display};
|
||||
use thiserror::Error;
|
||||
|
||||
pub mod discord;
|
||||
|
||||
#[derive(Debug, serde::Deserialize, PartialEq, Eq)]
|
||||
pub struct DiscordErrorResponse {
|
||||
pub message: String,
|
||||
pub code: u16,
|
||||
}
|
||||
|
||||
impl Display for DiscordErrorResponse {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "DiscordErrorResponse: {} ({})", self.message, self.code)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
#[error("Reqwest error: {0}")]
|
||||
Reqwest(#[from] ReqwestError),
|
||||
#[error("API Error: {0}")]
|
||||
Api(DiscordErrorResponse),
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn display_discord_error_response() {
|
||||
let error = DiscordErrorResponse {
|
||||
message: "Message".into(),
|
||||
code: 42,
|
||||
};
|
||||
let error_str = error.to_string();
|
||||
let expected = "DiscordErrorResponse: Message (42)".to_string();
|
||||
assert_eq!(expected, error_str);
|
||||
}
|
||||
}
|
||||
14
gejdr-backend/src/errors.rs
Normal file
14
gejdr-backend/src/errors.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
#[error("SQL error: {0}")]
|
||||
Sql(#[from] gejdr_core::sqlx::Error),
|
||||
#[error("OAuth token error: {0}")]
|
||||
TokenError(String),
|
||||
#[error("Unauthorized")]
|
||||
Unauthorized,
|
||||
#[error("Attempted to get a value, none found")]
|
||||
OptionError,
|
||||
}
|
||||
67
gejdr-backend/src/lib.rs
Normal file
67
gejdr-backend/src/lib.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
#![deny(clippy::all)]
|
||||
#![deny(clippy::pedantic)]
|
||||
#![deny(clippy::nursery)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
#![allow(clippy::unused_async)]
|
||||
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
|
||||
|
||||
use gejdr_core::sqlx;
|
||||
|
||||
mod api_wrapper;
|
||||
mod errors;
|
||||
mod oauth;
|
||||
mod route;
|
||||
mod settings;
|
||||
mod startup;
|
||||
|
||||
type MaybeListener = Option<poem::listener::TcpListener<String>>;
|
||||
|
||||
async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> startup::Application {
|
||||
dotenvy::dotenv().ok();
|
||||
let settings = settings::Settings::new().expect("Failed to read settings");
|
||||
if !cfg!(test) {
|
||||
let subscriber = gejdr_core::telemetry::get_subscriber(settings.clone().debug);
|
||||
gejdr_core::telemetry::init_subscriber(subscriber);
|
||||
}
|
||||
tracing::event!(
|
||||
target: "gege-jdr-backend",
|
||||
tracing::Level::DEBUG,
|
||||
"Using these settings: {:?}",
|
||||
settings.clone()
|
||||
);
|
||||
let application = startup::Application::build(settings.clone(), test_db, listener).await;
|
||||
tracing::event!(
|
||||
target: "gege-jdr-backend",
|
||||
tracing::Level::INFO,
|
||||
"Listening on {}",
|
||||
application.settings.web_address()
|
||||
);
|
||||
application
|
||||
}
|
||||
|
||||
/// # Errors
|
||||
///
|
||||
/// May return an error if the server encounters an error it cannot
|
||||
/// recover from.
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
pub async fn run(listener: MaybeListener) -> Result<(), std::io::Error> {
|
||||
let application = prepare(listener, None).await;
|
||||
application.make_app().run().await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn make_random_tcp_listener() -> poem::listener::TcpListener<String> {
|
||||
let tcp_listener =
|
||||
std::net::TcpListener::bind("127.0.0.1:0").expect("Failed to bind a random TCP listener");
|
||||
let port = tcp_listener.local_addr().unwrap().port();
|
||||
poem::listener::TcpListener::bind(format!("127.0.0.1:{port}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn get_test_app(test_db: Option<sqlx::PgPool>) -> startup::App {
|
||||
let tcp_listener = crate::make_random_tcp_listener().await;
|
||||
crate::prepare(Some(tcp_listener), test_db)
|
||||
.await
|
||||
.make_app()
|
||||
.into()
|
||||
}
|
||||
5
gejdr-backend/src/main.rs
Normal file
5
gejdr-backend/src/main.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
#[cfg(not(tarpaulin_include))]
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), std::io::Error> {
|
||||
gejdr_backend::run(None).await
|
||||
}
|
||||
62
gejdr-backend/src/oauth/discord.rs
Normal file
62
gejdr-backend/src/oauth/discord.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
|
||||
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
|
||||
};
|
||||
use reqwest::Url;
|
||||
|
||||
use crate::{errors::ApiError, settings::Settings};
|
||||
|
||||
use super::OauthProvider;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DiscordOauthProvider {
|
||||
client: BasicClient,
|
||||
}
|
||||
|
||||
impl DiscordOauthProvider {
|
||||
pub fn new(settings: &Settings) -> Self {
|
||||
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
|
||||
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
|
||||
.expect("Invalid authorization endpoint URL");
|
||||
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
|
||||
.expect("Invalid token endpoint URL");
|
||||
let revocation_url =
|
||||
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
|
||||
.expect("Invalid revocation URL");
|
||||
let client = BasicClient::new(
|
||||
ClientId::new(settings.discord.client_id.clone()),
|
||||
Some(ClientSecret::new(settings.discord.client_secret.clone())),
|
||||
auth_url,
|
||||
Some(token_url),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
|
||||
.set_revocation_uri(revocation_url);
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
impl OauthProvider for DiscordOauthProvider {
|
||||
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let (auth_url, csrf_token) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.url();
|
||||
(auth_url, csrf_token, pkce_verifier)
|
||||
}
|
||||
|
||||
async fn token(
|
||||
&self,
|
||||
code: String,
|
||||
verifier: PkceCodeVerifier,
|
||||
) -> Result<super::Token, ApiError> {
|
||||
self.client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(verifier)
|
||||
.request_async(oauth2::reqwest::async_http_client)
|
||||
.await
|
||||
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
|
||||
}
|
||||
}
|
||||
17
gejdr-backend/src/oauth/mod.rs
Normal file
17
gejdr-backend/src/oauth/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
mod discord;
|
||||
pub use discord::DiscordOauthProvider;
|
||||
|
||||
use oauth2::{
|
||||
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
|
||||
StandardTokenResponse,
|
||||
};
|
||||
use reqwest::Url;
|
||||
|
||||
use crate::errors::ApiError;
|
||||
|
||||
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
|
||||
|
||||
pub trait OauthProvider {
|
||||
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
|
||||
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
|
||||
}
|
||||
220
gejdr-backend/src/route/auth.rs
Normal file
220
gejdr-backend/src/route/auth.rs
Normal file
@@ -0,0 +1,220 @@
|
||||
use gejdr_core::models::accounts::User;
|
||||
use gejdr_core::sqlx::PgPool;
|
||||
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
|
||||
use poem::web::Data;
|
||||
use poem::{session::Session, web::Form};
|
||||
use poem_openapi::payload::{Json, PlainText};
|
||||
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||
|
||||
use crate::oauth::{DiscordOauthProvider, OauthProvider};
|
||||
|
||||
use super::errors::ErrorResponse;
|
||||
use super::ApiCategory;
|
||||
|
||||
type Token =
|
||||
oauth2::StandardTokenResponse<oauth2::EmptyExtraTokenFields, oauth2::basic::BasicTokenType>;
|
||||
|
||||
pub struct AuthApi;
|
||||
|
||||
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
|
||||
struct DiscordCallbackRequest {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
impl DiscordCallbackRequest {
|
||||
pub fn check_token(&self, token: &CsrfToken) -> Result<(), LoginStatusResponse> {
|
||||
if *token.secret().to_string() == self.state {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(LoginStatusResponse::TokenError(Json(ErrorResponse {
|
||||
code: 500,
|
||||
message: "OAuth token error".into(),
|
||||
details: Some(
|
||||
"OAuth provider did not send a message that matches what was expected".into(),
|
||||
),
|
||||
})))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum LoginStatusResponse {
|
||||
#[oai(status = 201)]
|
||||
LoggedIn(Json<UserInfo>),
|
||||
#[oai(status = 201)]
|
||||
LoggedOut(
|
||||
#[oai(header = "Location")] String,
|
||||
#[oai(header = "Cache-Control")] String,
|
||||
),
|
||||
#[oai(status = 301)]
|
||||
LoginRedirect(
|
||||
#[oai(header = "Location")] String,
|
||||
#[oai(header = "Cache-Control")] String,
|
||||
),
|
||||
#[oai(status = 500)]
|
||||
TokenError(Json<ErrorResponse>),
|
||||
#[oai(status = 500)]
|
||||
DatabaseError(Json<ErrorResponse>),
|
||||
#[oai(status = 503)]
|
||||
DiscordError(Json<ErrorResponse>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Eq, PartialEq, serde::Serialize, Object)]
|
||||
struct UserInfo {
|
||||
id: String,
|
||||
username: String,
|
||||
display_name: Option<String>,
|
||||
avatar: Option<String>,
|
||||
}
|
||||
|
||||
impl From<User> for UserInfo {
|
||||
fn from(value: User) -> Self {
|
||||
Self {
|
||||
id: value.id,
|
||||
username: value.username,
|
||||
display_name: value.name,
|
||||
avatar: value.avatar,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum UserInfoResponse {
|
||||
#[oai(status = 201)]
|
||||
UserInfo(Json<UserInfo>),
|
||||
#[oai(status = 401)]
|
||||
Unauthorized,
|
||||
#[oai(status = 500)]
|
||||
DatabaseError(Json<ErrorResponse>),
|
||||
#[oai(status = 503)]
|
||||
DiscordError(Json<ErrorResponse>),
|
||||
}
|
||||
|
||||
impl From<UserInfoResponse> for LoginStatusResponse {
|
||||
fn from(value: UserInfoResponse) -> Self {
|
||||
match value {
|
||||
UserInfoResponse::UserInfo(json) => Self::LoggedIn(json),
|
||||
UserInfoResponse::Unauthorized => unimplemented!(),
|
||||
UserInfoResponse::DatabaseError(json) => Self::DatabaseError(json),
|
||||
UserInfoResponse::DiscordError(json) => Self::DiscordError(json),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum CsrfResponse {
|
||||
#[oai(status = 201)]
|
||||
Token(PlainText<String>),
|
||||
}
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
|
||||
impl AuthApi {
|
||||
async fn fetch_remote_user(
|
||||
pool: Data<&PgPool>,
|
||||
token: Token,
|
||||
) -> Result<UserInfoResponse, UserInfoResponse> {
|
||||
crate::api_wrapper::discord::get_user_profile(token.access_token().secret())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::ERROR,
|
||||
"Failed to communicate with Discord: {}",
|
||||
e
|
||||
);
|
||||
UserInfoResponse::DiscordError(Json(e.into()))
|
||||
})?
|
||||
.refresh_in_database(&pool)
|
||||
.await
|
||||
.map(|user| UserInfoResponse::UserInfo(Json(user.into())))
|
||||
.map_err(|e| {
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::ERROR,
|
||||
"Database error: {}",
|
||||
e
|
||||
);
|
||||
UserInfoResponse::DatabaseError(Json(e.into()))
|
||||
})
|
||||
}
|
||||
|
||||
#[oai(path = "/signin/discord", method = "get")]
|
||||
async fn signin_discord(
|
||||
&self,
|
||||
oauth: Data<&DiscordOauthProvider>,
|
||||
session: &Session,
|
||||
) -> LoginStatusResponse {
|
||||
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
|
||||
session.set("csrf", csrf_token);
|
||||
session.set("pkce", pkce_verifier);
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::INFO,
|
||||
"Signin through Discord",
|
||||
);
|
||||
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
|
||||
}
|
||||
|
||||
#[oai(path = "/callback/discord", method = "get")]
|
||||
async fn callback_discord(
|
||||
&self,
|
||||
Form(auth_request): Form<DiscordCallbackRequest>,
|
||||
oauth: Data<&DiscordOauthProvider>,
|
||||
pool: Data<&PgPool>,
|
||||
session: &Session,
|
||||
) -> Result<LoginStatusResponse, LoginStatusResponse> {
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::INFO,
|
||||
"Discord callback",
|
||||
);
|
||||
let csrf_token = session.get::<CsrfToken>("csrf").ok_or_else(|| {
|
||||
LoginStatusResponse::TokenError(Json(ErrorResponse {
|
||||
code: 500,
|
||||
message: "Cannot fetch csrf token from session".to_string(),
|
||||
..Default::default()
|
||||
}))
|
||||
})?;
|
||||
auth_request.check_token(&csrf_token)?;
|
||||
let pkce_verifier = session.get::<PkceCodeVerifier>("pkce").ok_or_else(|| {
|
||||
LoginStatusResponse::TokenError(Json(ErrorResponse {
|
||||
code: 500,
|
||||
message: "Cannot fetch pkce verifier from session".to_string(),
|
||||
..Default::default()
|
||||
}))
|
||||
})?;
|
||||
let token = oauth
|
||||
.token(auth_request.code, pkce_verifier)
|
||||
.await
|
||||
.map_err(|e| LoginStatusResponse::TokenError(Json(e.into())))?;
|
||||
session.set("token", token.clone());
|
||||
Self::fetch_remote_user(pool, token)
|
||||
.await
|
||||
.map(std::convert::Into::into)
|
||||
.map_err(std::convert::Into::into)
|
||||
}
|
||||
|
||||
#[oai(path = "/csrf", method = "get")]
|
||||
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
|
||||
CsrfResponse::Token(PlainText(token.0.clone()))
|
||||
}
|
||||
|
||||
#[oai(path = "/signout", method = "post")]
|
||||
async fn signout(&self, session: &Session) -> LoginStatusResponse {
|
||||
session.purge();
|
||||
LoginStatusResponse::LoggedOut("/".to_string(), "no-cache".to_string())
|
||||
}
|
||||
|
||||
#[oai(path = "/me", method = "get")]
|
||||
async fn user_info(
|
||||
&self,
|
||||
session: &Session,
|
||||
pool: Data<&PgPool>,
|
||||
) -> Result<UserInfoResponse, UserInfoResponse> {
|
||||
let token = session
|
||||
.get::<Token>("token")
|
||||
.ok_or(UserInfoResponse::Unauthorized)?;
|
||||
Self::fetch_remote_user(pool, token).await
|
||||
}
|
||||
}
|
||||
204
gejdr-backend/src/route/errors.rs
Normal file
204
gejdr-backend/src/route/errors.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
use poem_openapi::Object;
|
||||
use reqwest::Error as ReqwestError;
|
||||
|
||||
use crate::api_wrapper::ApiError as ApiWrapperError;
|
||||
use crate::errors::ApiError;
|
||||
|
||||
#[derive(Debug, serde::Serialize, Default, Object, PartialEq, Eq)]
|
||||
pub struct ErrorResponse {
|
||||
pub code: u16,
|
||||
pub message: String,
|
||||
pub details: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ApiError> for ErrorResponse {
|
||||
fn from(value: ApiError) -> Self {
|
||||
match value {
|
||||
ApiError::Sql(e) => Self {
|
||||
code: 500,
|
||||
message: "SQL error".into(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
ApiError::TokenError(e) => Self {
|
||||
code: 500,
|
||||
message: "OAuth token error".into(),
|
||||
details: Some(e),
|
||||
},
|
||||
ApiError::Unauthorized => Self {
|
||||
code: 401,
|
||||
message: "Unauthorized!".into(),
|
||||
..Default::default()
|
||||
},
|
||||
ApiError::OptionError => Self {
|
||||
code: 500,
|
||||
message: "Attempted to get a value, but none found".into(),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ReqwestError> for ErrorResponse {
|
||||
fn from(value: ReqwestError) -> Self {
|
||||
Self {
|
||||
code: 503,
|
||||
message: "Failed to communicate with Discord".into(),
|
||||
details: Some(value.status().map_or_else(
|
||||
|| "Communication failed before we could hear back from Discord".into(),
|
||||
|status| format!("Discord sent back the error code {status}"),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ApiWrapperError> for ErrorResponse {
|
||||
fn from(source: ApiWrapperError) -> Self {
|
||||
match source {
|
||||
ApiWrapperError::Reqwest(e) => e.into(),
|
||||
ApiWrapperError::Api(e) => Self {
|
||||
code: if e.message.as_str().starts_with("401") {
|
||||
401
|
||||
} else {
|
||||
e.code
|
||||
},
|
||||
message: e.message,
|
||||
details: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<gejdr_core::sqlx::Error> for ErrorResponse {
|
||||
fn from(_value: gejdr_core::sqlx::Error) -> Self {
|
||||
Self {
|
||||
code: 500,
|
||||
message: "Internal database error".into(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::api_wrapper::{ApiError as ApiWrapperError, DiscordErrorResponse};
|
||||
|
||||
#[test]
|
||||
fn conversion_from_sql_api_error_works() {
|
||||
let sql_error = ApiError::Sql(gejdr_core::sqlx::Error::ColumnNotFound(
|
||||
"COLUMN_NAME".to_string(),
|
||||
));
|
||||
let final_error = ErrorResponse::from(sql_error);
|
||||
let expected_error = ErrorResponse {
|
||||
code: 500,
|
||||
message: "SQL error".into(),
|
||||
details: Some("no column found for name: COLUMN_NAME".into()),
|
||||
};
|
||||
assert_eq!(expected_error, final_error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversion_from_token_error_works() {
|
||||
let initial_error = ApiError::TokenError("TOKEN ERROR".into());
|
||||
let final_error: ErrorResponse = initial_error.into();
|
||||
let expected_error = ErrorResponse {
|
||||
code: 500,
|
||||
message: "OAuth token error".into(),
|
||||
details: Some("TOKEN ERROR".into()),
|
||||
};
|
||||
assert_eq!(expected_error, final_error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversion_from_unauthorized_works() {
|
||||
let initial_error = ApiError::Unauthorized;
|
||||
let final_error: ErrorResponse = initial_error.into();
|
||||
let expected_error = ErrorResponse {
|
||||
code: 401,
|
||||
message: "Unauthorized!".into(),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(expected_error, final_error);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversion_from_option_error_works() {
|
||||
let initial_error = ApiError::OptionError;
|
||||
let final_error: ErrorResponse = initial_error.into();
|
||||
let expected_error = ErrorResponse {
|
||||
code: 500,
|
||||
message: "Attempted to get a value, but none found".into(),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(expected_error, final_error);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn conversion_from_reqwest_error() {
|
||||
let err = reqwest::get("https://example.example/401").await;
|
||||
assert!(err.is_err());
|
||||
let expected = ErrorResponse {
|
||||
code: 503,
|
||||
message: "Failed to communicate with Discord".into(),
|
||||
details: Some("Communication failed before we could hear back from Discord".into()),
|
||||
};
|
||||
let actual: ErrorResponse = err.err().unwrap().into();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn conversion_from_apiwrappererror_with_reqwest_error() {
|
||||
let err = reqwest::get("https://example.example/401").await;
|
||||
assert!(err.is_err());
|
||||
let err = ApiWrapperError::Reqwest(err.err().unwrap());
|
||||
let expected = ErrorResponse {
|
||||
code: 503,
|
||||
message: "Failed to communicate with Discord".into(),
|
||||
details: Some("Communication failed before we could hear back from Discord".into()),
|
||||
};
|
||||
let actual: ErrorResponse = err.into();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversion_from_apiwrappererror_with_401_discord_error() {
|
||||
let err = ApiWrapperError::Api(DiscordErrorResponse {
|
||||
code: 0,
|
||||
message: "401: Unauthorized".into(),
|
||||
});
|
||||
let expected = ErrorResponse {
|
||||
code: 401,
|
||||
message: "401: Unauthorized".into(),
|
||||
..Default::default()
|
||||
};
|
||||
let actual: ErrorResponse = err.into();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversion_from_apiwrappererror_with_generic_discord_error() {
|
||||
let err = ApiWrapperError::Api(DiscordErrorResponse {
|
||||
code: 0,
|
||||
message: "Something else".into(),
|
||||
});
|
||||
let expected = ErrorResponse {
|
||||
code: 0,
|
||||
message: "Something else".into(),
|
||||
..Default::default()
|
||||
};
|
||||
let actual: ErrorResponse = err.into();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conversion_from_database_error() {
|
||||
let err = gejdr_core::sqlx::Error::PoolClosed;
|
||||
let expected = ErrorResponse {
|
||||
code: 500,
|
||||
message: "Internal database error".into(),
|
||||
..Default::default()
|
||||
};
|
||||
let actual: ErrorResponse = err.into();
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
}
|
||||
29
gejdr-backend/src/route/health.rs
Normal file
29
gejdr-backend/src/route/health.rs
Normal file
@@ -0,0 +1,29 @@
|
||||
use poem_openapi::{ApiResponse, OpenApi};
|
||||
|
||||
use super::ApiCategory;
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum HealthResponse {
|
||||
#[oai(status = 200)]
|
||||
Ok,
|
||||
}
|
||||
|
||||
pub struct HealthApi;
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
|
||||
impl HealthApi {
|
||||
#[oai(path = "/", method = "get")]
|
||||
async fn health_check(&self) -> HealthResponse {
|
||||
tracing::event!(target: "gege-jdr-backend", tracing::Level::DEBUG, "Accessing health-check endpoint.");
|
||||
HealthResponse::Ok
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn health_check_works() {
|
||||
let app = crate::get_test_app(None).await;
|
||||
let cli = poem::test::TestClient::new(app);
|
||||
let resp = cli.get("/v1/api/health-check").send().await;
|
||||
resp.assert_status_is_ok();
|
||||
resp.assert_text("").await;
|
||||
}
|
||||
24
gejdr-backend/src/route/mod.rs
Normal file
24
gejdr-backend/src/route/mod.rs
Normal file
@@ -0,0 +1,24 @@
|
||||
use poem_openapi::{OpenApi, Tags};
|
||||
|
||||
mod health;
|
||||
pub use health::HealthApi;
|
||||
|
||||
mod version;
|
||||
pub use version::VersionApi;
|
||||
|
||||
mod errors;
|
||||
|
||||
mod auth;
|
||||
pub use auth::AuthApi;
|
||||
|
||||
#[derive(Tags)]
|
||||
enum ApiCategory {
|
||||
Auth,
|
||||
Health,
|
||||
Version,
|
||||
}
|
||||
|
||||
pub struct Api;
|
||||
|
||||
#[OpenApi]
|
||||
impl Api {}
|
||||
46
gejdr-backend/src/route/version.rs
Normal file
46
gejdr-backend/src/route/version.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
use poem::Result;
|
||||
use poem_openapi::{payload::Json, ApiResponse, Object, OpenApi};
|
||||
|
||||
use crate::settings::Settings;
|
||||
|
||||
use super::ApiCategory;
|
||||
|
||||
#[derive(Object, Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
struct Meta {
|
||||
version: String,
|
||||
}
|
||||
|
||||
impl From<poem::web::Data<&Settings>> for Meta {
|
||||
fn from(value: poem::web::Data<&Settings>) -> Self {
|
||||
let version = value.application.version.clone();
|
||||
Self { version }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum VersionResponse {
|
||||
#[oai(status = 200)]
|
||||
Version(Json<Meta>),
|
||||
}
|
||||
|
||||
pub struct VersionApi;
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
|
||||
impl VersionApi {
|
||||
#[oai(path = "/", method = "get")]
|
||||
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
|
||||
tracing::event!(target: "gege-jdr-backend", tracing::Level::DEBUG, "Accessing version endpoint.");
|
||||
Ok(VersionResponse::Version(Json(settings.into())))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn version_works() {
|
||||
let app = crate::get_test_app(None).await;
|
||||
let cli = poem::test::TestClient::new(app);
|
||||
let resp = cli.get("/v1/api/version").send().await;
|
||||
resp.assert_status_is_ok();
|
||||
let json = resp.json().await;
|
||||
let json_value = json.value();
|
||||
json_value.object().get("version").assert_not_null();
|
||||
}
|
||||
213
gejdr-backend/src/settings.rs
Normal file
213
gejdr-backend/src/settings.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use gejdr_core::database::Database;
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||
pub struct Settings {
|
||||
pub application: ApplicationSettings,
|
||||
pub database: Database,
|
||||
pub discord: Discord,
|
||||
pub debug: bool,
|
||||
pub email: EmailSettings,
|
||||
pub frontend_url: String,
|
||||
}
|
||||
|
||||
impl Settings {
|
||||
#[must_use]
|
||||
pub fn web_address(&self) -> String {
|
||||
self.application.base_url.clone()
|
||||
}
|
||||
|
||||
/// Multipurpose function that helps detect the current
|
||||
/// environment the application is running in using the
|
||||
/// `APP_ENVIRONMENT` environment variable.
|
||||
///
|
||||
/// ```text
|
||||
/// APP_ENVIRONMENT = development | dev | production | prod
|
||||
/// ```
|
||||
///
|
||||
/// After detection, it loads the appropriate `.yaml` file. It
|
||||
/// then loads the environment variables that overrides whatever
|
||||
/// is set in the `.yaml` files. For this to work, the environment
|
||||
/// variable MUST be in uppercase and start with `APP`, a `_`
|
||||
/// separator, then the category of settings, followed by a `__`
|
||||
/// separator, and finally the variable itself. For instance,
|
||||
/// `APP__APPLICATION_PORT=3001` for `port` to be set as `3001`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Function may return an error if it fails to parse its config
|
||||
/// files.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the program fails to detect the directory it is
|
||||
/// running in. Can also panic if it fails to parse the
|
||||
/// environment variable `APP_ENVIRONMENT` and it fails to fall
|
||||
/// back to its default value.
|
||||
pub fn new() -> Result<Self, config::ConfigError> {
|
||||
let base_path = std::env::current_dir().expect("Failed to determine the current directory");
|
||||
let settings_directory = base_path.join("settings");
|
||||
let environment: Environment = std::env::var("APP_ENVIRONMENT")
|
||||
.unwrap_or_else(|_| "development".into())
|
||||
.try_into()
|
||||
.expect("Failed to parse APP_ENVIRONMENT");
|
||||
let environment_filename = format!("{environment}.yaml");
|
||||
let settings = config::Config::builder()
|
||||
.add_source(config::File::from(settings_directory.join("base.yaml")))
|
||||
.add_source(config::File::from(
|
||||
settings_directory.join(environment_filename),
|
||||
))
|
||||
.add_source(
|
||||
config::Environment::with_prefix("APP")
|
||||
.prefix_separator("__")
|
||||
.separator("__"),
|
||||
)
|
||||
.build()?;
|
||||
settings.try_deserialize::<Self>()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||
pub struct ApplicationSettings {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub port: u16,
|
||||
pub host: String,
|
||||
pub base_url: String,
|
||||
pub protocol: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub enum Environment {
|
||||
Development,
|
||||
Production,
|
||||
}
|
||||
|
||||
impl Default for Environment {
|
||||
fn default() -> Self {
|
||||
Self::Development
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Environment {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let self_str = match self {
|
||||
Self::Development => "development",
|
||||
Self::Production => "production",
|
||||
};
|
||||
write!(f, "{self_str}")
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Environment {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: String) -> Result<Self, Self::Error> {
|
||||
match value.to_lowercase().as_str() {
|
||||
"development" | "dev" => Ok(Self::Development),
|
||||
"production" | "prod" => Ok(Self::Production),
|
||||
other => Err(format!(
|
||||
"{other} is not a supported environment. Use either `development` or `production`"
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&str> for Environment {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self, Self::Error> {
|
||||
Self::try_from(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, Clone, Debug, Default)]
|
||||
pub struct EmailSettings {
|
||||
pub host: String,
|
||||
pub user: String,
|
||||
pub password: String,
|
||||
pub from: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||
pub struct Discord {
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_environment_works() {
|
||||
let default_environment = Environment::default();
|
||||
assert_eq!(Environment::Development, default_environment);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display_environment_works() {
|
||||
let expected_prod = "production".to_string();
|
||||
let expected_dev = "development".to_string();
|
||||
let prod = Environment::Production.to_string();
|
||||
let dev = Environment::Development.to_string();
|
||||
assert_eq!(expected_prod, prod);
|
||||
assert_eq!(expected_dev, dev);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_from_works() {
|
||||
[
|
||||
"DEVELOPMENT",
|
||||
"DEVEloPmENT",
|
||||
"Development",
|
||||
"DEV",
|
||||
"Dev",
|
||||
"dev",
|
||||
]
|
||||
.iter()
|
||||
.map(|v| (*v).to_string())
|
||||
.for_each(|v| {
|
||||
let environment = Environment::try_from(v);
|
||||
assert!(environment.is_ok());
|
||||
assert_eq!(Environment::Development, environment.unwrap());
|
||||
});
|
||||
[
|
||||
"PRODUCTION",
|
||||
"Production",
|
||||
"PRODuction",
|
||||
"production",
|
||||
"PROD",
|
||||
"Prod",
|
||||
"prod",
|
||||
]
|
||||
.iter()
|
||||
.map(|v| (*v).to_string())
|
||||
.for_each(|v| {
|
||||
let environment = Environment::try_from(v);
|
||||
assert!(environment.is_ok());
|
||||
assert_eq!(Environment::Production, environment.unwrap());
|
||||
});
|
||||
let environment = Environment::try_from("invalid");
|
||||
assert!(environment.is_err());
|
||||
assert_eq!(
|
||||
"invalid is not a supported environment. Use either `development` or `production`"
|
||||
.to_string(),
|
||||
environment.err().unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn web_address_works() {
|
||||
let settings = Settings {
|
||||
debug: false,
|
||||
application: ApplicationSettings {
|
||||
base_url: "127.0.0.1".to_string(),
|
||||
port: 3000,
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let expected = "127.0.0.1".to_string();
|
||||
assert_eq!(expected, settings.web_address());
|
||||
}
|
||||
}
|
||||
139
gejdr-backend/src/startup.rs
Normal file
139
gejdr-backend/src/startup.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use gejdr_core::sqlx;
|
||||
|
||||
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
|
||||
use poem::middleware::{CookieJarManagerEndpoint, Cors};
|
||||
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
|
||||
use poem::{EndpointExt, Route};
|
||||
use poem_openapi::OpenApiService;
|
||||
|
||||
use crate::oauth::DiscordOauthProvider;
|
||||
use crate::route::AuthApi;
|
||||
use crate::{
|
||||
route::{Api, HealthApi, VersionApi},
|
||||
settings::Settings,
|
||||
};
|
||||
|
||||
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
|
||||
pub type App = AddDataEndpoint<
|
||||
AddDataEndpoint<
|
||||
AddDataEndpoint<
|
||||
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
|
||||
DiscordOauthProvider,
|
||||
>,
|
||||
sqlx::Pool<sqlx::Postgres>,
|
||||
>,
|
||||
Settings,
|
||||
>;
|
||||
|
||||
pub struct Application {
|
||||
server: Server,
|
||||
app: poem::Route,
|
||||
port: u16,
|
||||
database: sqlx::postgres::PgPool,
|
||||
pub settings: Settings,
|
||||
}
|
||||
|
||||
pub struct RunnableApplication {
|
||||
server: Server,
|
||||
app: App,
|
||||
}
|
||||
|
||||
impl RunnableApplication {
|
||||
/// Runs the application until it decides to stop by itself.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If the server encounters an internal error it cannot recover
|
||||
/// from, it will forward it to this function which will forward
|
||||
/// it to its caller.
|
||||
pub async fn run(self) -> Result<(), std::io::Error> {
|
||||
self.server.run(self.app).await
|
||||
}
|
||||
}
|
||||
|
||||
impl From<RunnableApplication> for App {
|
||||
fn from(value: RunnableApplication) -> Self {
|
||||
value.app
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Application> for RunnableApplication {
|
||||
fn from(val: Application) -> Self {
|
||||
let app = val
|
||||
.app
|
||||
.with(Cors::new())
|
||||
.with(CookieSession::new(CookieConfig::default().secure(true)))
|
||||
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
|
||||
.data(val.database)
|
||||
.data(val.settings);
|
||||
let server = val.server;
|
||||
Self { server, app }
|
||||
}
|
||||
}
|
||||
|
||||
impl Application {
|
||||
async fn setup_db(
|
||||
settings: &Settings,
|
||||
test_pool: Option<sqlx::postgres::PgPool>,
|
||||
) -> sqlx::postgres::PgPool {
|
||||
let database_pool =
|
||||
test_pool.map_or_else(|| settings.database.get_connection_pool(), |pool| pool);
|
||||
if !cfg!(test) {
|
||||
gejdr_core::database::Database::migrate(&database_pool).await;
|
||||
}
|
||||
database_pool
|
||||
}
|
||||
|
||||
fn setup_app(settings: &Settings) -> poem::Route {
|
||||
let api_service = OpenApiService::new(
|
||||
(Api, AuthApi, HealthApi, VersionApi),
|
||||
settings.application.clone().name,
|
||||
settings.application.clone().version,
|
||||
);
|
||||
let ui = api_service.swagger_ui();
|
||||
poem::Route::new().nest("/", api_service).nest("/docs", ui)
|
||||
}
|
||||
|
||||
fn setup_server(
|
||||
settings: &Settings,
|
||||
tcp_listener: Option<poem::listener::TcpListener<String>>,
|
||||
) -> Server {
|
||||
let tcp_listener = tcp_listener.unwrap_or_else(|| {
|
||||
let address = format!(
|
||||
"{}:{}",
|
||||
settings.application.host, settings.application.port
|
||||
);
|
||||
poem::listener::TcpListener::bind(address)
|
||||
});
|
||||
poem::Server::new(tcp_listener)
|
||||
}
|
||||
|
||||
pub async fn build(
|
||||
settings: Settings,
|
||||
test_pool: Option<sqlx::postgres::PgPool>,
|
||||
tcp_listener: Option<poem::listener::TcpListener<String>>,
|
||||
) -> Self {
|
||||
let database_pool = Self::setup_db(&settings, test_pool).await;
|
||||
let port = settings.application.port;
|
||||
let app = Self::setup_app(&settings);
|
||||
let server = Self::setup_server(&settings, tcp_listener);
|
||||
Self {
|
||||
server,
|
||||
app,
|
||||
port,
|
||||
database: database_pool,
|
||||
settings,
|
||||
}
|
||||
}
|
||||
|
||||
/// Make the app runnable.
|
||||
#[must_use]
|
||||
pub fn make_app(self) -> RunnableApplication {
|
||||
self.into()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub const fn port(&self) -> u16 {
|
||||
self.port
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user