Compare commits

..

4 Commits

Author SHA1 Message Date
9489f78224
feat: initial OAuth implementation with Discord
Some checks failed
CI / tests (push) Failing after 3m53s
For now, only a basic implementation of OAuth with Discord is
implemented. If the user calls the Discord signin endpoint, they get
redirected to Discord’s OAuth page. Once they accept, they get
redirected back to the backend’s callback API endpoint. The token the
user got from Discord is stored in the user’s session.

When the user wants to log out, the user’s session’s token is wiped.

This commit also updates the dependencies of the project. It also
removes the dependency lettre as well as the mailpit docker service
for developers as it appears clearer this project won’t send emails
anytime soon.
2024-11-23 10:09:03 +01:00
ae10711e41
chore: update rust toolchain 2024-11-23 09:39:52 +01:00
ff90b1959f
feat: authentication through Discord OAuth2 2024-08-10 12:13:05 +02:00
1125bc4a38
Initial commit 2024-08-10 12:12:58 +02:00
23 changed files with 4477 additions and 99 deletions

View File

@ -1,13 +1,12 @@
;;; Directory Local Variables -*- no-byte-compile: t -*- ;;; Directory Local Variables -*- no-byte-compile: t -*-
;;; For more information see (info "(emacs) Directory Variables") ;;; For more information see (info "(emacs) Directory Variables")
((sql-mode ((rustic-mode . ((fill-column . 80)))
. (sql-mode . ((eval . (progn
((eval . (progn
(setq-local lsp-sqls-connections (setq-local lsp-sqls-connections
`(((driver . "postgresql") `(((driver . "postgresql")
(dataSourceName . (dataSourceName \,
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable" (format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
(getenv "DB_HOST") (getenv "DB_HOST")
(getenv "DB_PORT") (getenv "DB_PORT")
(getenv "DB_USER") (getenv "DB_USER")

4021
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@ -17,45 +17,32 @@ name = "gege-jdr-backend"
[dependencies] [dependencies]
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.0", features = ["yaml"] } config = { version = "0.14.1", features = ["yaml"] }
dotenvy = "0.15.7" dotenvy = "0.15.7"
serde = "1.0.204" oauth2 = "4.4.2"
serde_json = "1.0.120" quote = "1.0.37"
thiserror = "1.0.63" reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] } serde = "1.0.215"
serde_json = "1.0.133"
thiserror = "1.0.69"
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] } tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] } uuid = { version = "1.11.0", features = ["v4", "serde"] }
[dependencies.lettre]
version = "0.11.7"
default-features = false
features = [
"builder",
"hostname",
"pool",
"rustls-tls",
"tokio1",
"tokio1-rustls-tls",
"smtp-transport"
]
[dependencies.poem] [dependencies.poem]
version = "3.0.4" version = "3.1.3"
default-features = false default-features = false
features = [ features = ["csrf", "rustls", "cookie", "test", "session"]
"csrf",
"rustls",
"cookie",
"test"
]
[dependencies.poem-openapi] [dependencies.poem-openapi]
version = "5.0.3" version = "5.1.2"
features = ["chrono", "swagger-ui", "uuid"] features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
[dependencies.sqlx] [dependencies.sqlx]
version = "0.8.0" version = "0.8.2"
default-features = false default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"] features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }

View File

@ -31,25 +31,6 @@ services:
depends_on: depends_on:
- db - db
# If you run GegeJdrBackend in production, DO NOT use mailpit.
# This tool is for testing only. Instead, you should use a real SMTP
# provider, such as Mailgun, Mailwhale, or Postal.
mailpit:
image: axllent/mailpit:latest
restart: unless-stopped
container_name: gege-jdr-backend-mailpit
ports:
- 127.0.0.1:8025:8025 # WebUI
- 127.0.0.1:1025:1025 # SMTP
volumes:
- gege_jdr_backend_mailpit:/data
environment:
MP_MAX_MESSAGES: 5000
MP_DATABASE: /data/mailpit.db
MP_SMTP_AUTH_ACCEPT_ANY: 1
MP_SMTP_AUTH_ALLOW_INSECURE: 1
volumes: volumes:
gege_jdr_backend_db_data: gege_jdr_backend_db_data:
gege_jdr_backend_pgadmin_data: gege_jdr_backend_pgadmin_data:
gege_jdr_backend_mailpit:

View File

@ -5,11 +5,11 @@
"systems": "systems" "systems": "systems"
}, },
"locked": { "locked": {
"lastModified": 1710146030, "lastModified": 1731533236,
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide", "owner": "numtide",
"repo": "flake-utils", "repo": "flake-utils",
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1723175592, "lastModified": 1732014248,
"narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=", "narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "5e0ca22929f3342b19569b21b2f3462f053e497b", "rev": "23e89b7da85c3640bbc2173fe04f4bd114342367",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -36,11 +36,11 @@
}, },
"nixpkgs_2": { "nixpkgs_2": {
"locked": { "locked": {
"lastModified": 1718428119, "lastModified": 1728538411,
"narHash": "sha256-WdWDpNaq6u1IPtxtYHHWpl5BmabtpmLnMAx0RdJ/vo8=", "narHash": "sha256-f0SBJz1eZ2yOuKUr5CA9BHULGXVSn6miBuUWdTyhUhU=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "e6cea36f83499eb4e9cd184c8a8e823296b50ad5", "rev": "b69de56fac8c2b6f8fd27f2eca01dcda8e0a4221",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2" "nixpkgs": "nixpkgs_2"
}, },
"locked": { "locked": {
"lastModified": 1723256423, "lastModified": 1732242723,
"narHash": "sha256-9iDTrfVM+mbcad31a47oqW8t8tfSA4C/si6F8F2DO/w=", "narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "615cfd85b4d9c51811a8d875374268fab5bd4089", "rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -0,0 +1,5 @@
-- Add down migration script here
ALTER TABLE IF EXISTS public.sessions DROP CONSTRAINT IF EXISTS sessions_user_id_users_fk;
DROP TABLE IF EXISTS public.sessions;
DROP TABLE IF EXISTS public.users;
DROP EXTENSION IF EXISTS "uuid-ossp";

View File

@ -0,0 +1,29 @@
-- Add up migration script here
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS public.users
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
email character varying(255) NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id),
CONSTRAINT users_email_unique UNIQUE (email)
);
CREATE TABLE IF NOT EXISTS public.sessions
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
user_id uuid NOT NULL,
session_id character varying NOT NULL,
expires_at timestamp with time zone NOT NULL,
PRIMARY KEY (id),
CONSTRAINT sessions_user_id_unique UNIQUE (user_id)
);
ALTER TABLE IF EXISTS public.sessions
ADD CONSTRAINT sessions_user_id_users_fk FOREIGN KEY (user_id)
REFERENCES public.users (id) MATCH SIMPLE
ON UPDATE CASCADE
ON DELETE CASCADE
NOT VALID;

View File

@ -1,4 +1,4 @@
[toolchain] [toolchain]
channel = "1.78.0" channel = "1.81.0"
components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ] components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ]
profile = "default" profile = "default"

View File

@ -16,3 +16,7 @@ email:
user: user@gege-jdr-backend.example user: user@gege-jdr-backend.example
from: GegeJdrBackend <noreply@gege-jdr-backend.example> from: GegeJdrBackend <noreply@gege-jdr-backend.example>
password: hunter2 password: hunter2
discord:
client_id: changeme
client_secret: changeme

View File

@ -3,5 +3,5 @@ debug: true
application: application:
protocol: http protocol: http
host: 127.0.0.1 host: localhost
base_url: http://127.0.0.1:3000 base_url: http://localhost:3000

19
src/errors.rs Normal file
View File

@ -0,0 +1,19 @@
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
Sql(#[from] sqlx::Error),
#[error("HTTP request error: {0}")]
Request(#[from] reqwest::Error),
#[error("OAuth token error: {0}")]
TokenError(String),
#[error("Unauthorized")]
Unauthorized,
#[error("Attempted to get a value, none found")]
OptionError,
#[error("Failed to parse a number as an integer")]
ParseIntError(#[from] std::num::TryFromIntError),
#[error("Encountered an error trying to convert an infaillible value")]
FromRequestPartsError(#[from] std::convert::Infallible),
}

View File

@ -5,10 +5,13 @@
#![allow(clippy::unused_async)] #![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling #![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
pub mod route; mod errors;
pub mod settings; mod models;
pub mod startup; mod oauth;
pub mod telemetry; mod route;
mod settings;
mod startup;
mod telemetry;
type MaybeListener = Option<poem::listener::TcpListener<String>>; type MaybeListener = Option<poem::listener::TcpListener<String>>;
@ -29,8 +32,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
tracing::event!( tracing::event!(
target: "gege-jdr-backend", target: "gege-jdr-backend",
tracing::Level::INFO, tracing::Level::INFO,
"Listening on http://127.0.0.1:{}/", "Listening on {}",
application.port() application.settings.web_address()
); );
application application
} }

18
src/models/accounts.rs Normal file
View File

@ -0,0 +1,18 @@
type Timestampz = chrono::DateTime<chrono::Utc>;
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
pub struct User {
pub id: uuid::Uuid,
pub email: String,
pub created_at: Timestampz,
pub last_updated: Timestampz,
}
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
pub struct Session {
pub id: uuid::Uuid,
pub user_id: uuid::Uuid,
#[allow(clippy::struct_field_names)]
pub session_id: String,
pub expires_at: Timestampz,
}

1
src/models/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod accounts;

62
src/oauth/discord.rs Normal file
View File

@ -0,0 +1,62 @@
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
};
use reqwest::Url;
use crate::{errors::ApiError, settings::Settings};
use super::OauthProvider;
#[derive(Debug, Clone)]
pub struct DiscordOauthProvider {
client: BasicClient,
}
impl DiscordOauthProvider {
pub fn new(settings: &Settings) -> Self {
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
.expect("Invalid token endpoint URL");
let revocation_url =
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
.expect("Invalid revocation URL");
let client = BasicClient::new(
ClientId::new(settings.discord.client_id.clone()),
Some(ClientSecret::new(settings.discord.client_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
.set_revocation_uri(revocation_url);
Self { client }
}
}
impl OauthProvider for DiscordOauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
.set_pkce_challenge(pkce_challenge)
.url();
(auth_url, csrf_token, pkce_verifier)
}
async fn token(
&self,
code: String,
verifier: PkceCodeVerifier,
) -> Result<super::Token, ApiError> {
self.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(verifier)
.request_async(oauth2::reqwest::async_http_client)
.await
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
}
}

17
src/oauth/mod.rs Normal file
View File

@ -0,0 +1,17 @@
mod discord;
pub use discord::DiscordOauthProvider;
use oauth2::{
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
StandardTokenResponse,
};
use reqwest::Url;
use crate::errors::ApiError;
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
pub trait OauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
}

125
src/route/auth.rs Normal file
View File

@ -0,0 +1,125 @@
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
use poem::web::Data;
use poem::{session::Session, web::Form};
use poem_openapi::payload::{Json, PlainText};
use poem_openapi::{ApiResponse, Object, OpenApi};
use crate::oauth::{DiscordOauthProvider, OauthProvider};
use super::errors::ErrorResponse;
use super::ApiCategory;
pub struct AuthApi;
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
struct DiscordCallbackRequest {
code: String,
state: String,
}
#[derive(ApiResponse)]
enum LoginStatusResponse {
#[oai(status = 201)]
LoggedIn,
#[oai(status = 201)]
LoggedOut,
#[oai(status = 301)]
LoginRedirect(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 500)]
TokenError(Json<ErrorResponse>),
}
#[derive(ApiResponse)]
enum CsrfResponse {
#[oai(status = 201)]
Token(PlainText<String>),
}
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
impl AuthApi {
// TODO: implement the following endpoints:
// - /signin
// - /signout
// - /session
// - /providers
// See https://next-auth.js.org/getting-started/rest-api
#[oai(path = "/signin/discord", method = "get")]
async fn signin_discord(
&self,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
session.set("csrf", csrf_token);
session.set("pkce", pkce_verifier);
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Redirect URL: {}",
auth_url
);
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
}
#[oai(path = "/callback/discord", method = "get")]
async fn callback_discord(
&self,
Form(auth_request): Form<DiscordCallbackRequest>,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
tracing::event!(
target: "auth-discord",
tracing::Level::DEBUG,
"Discord replied with: {:?}",
auth_request
);
let csrf_token = session
.get::<CsrfToken>("csrf")
.expect("Failed to retrieve Csrf token from session");
if *csrf_token.secret().to_string() != auth_request.state {
return LoginStatusResponse::TokenError(
ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some(
"OAuth provider did not send a message that matches what was expected"
.into(),
),
}
.into(),
);
}
let pkce_verifier = session
.get::<PkceCodeVerifier>("pkce")
.expect("Failed to retrieve pkce verifier from session");
let token = oauth.token(auth_request.code, pkce_verifier).await;
if let Err(e) = token {
return LoginStatusResponse::TokenError(Json(e.into()));
}
let token = token.unwrap();
tracing::event!(
target: "auth-discord",
tracing::Level::DEBUG,
"Token: {:?}",
token
);
session.set("token", token);
LoginStatusResponse::LoggedIn
}
#[oai(path = "/csrf", method = "get")]
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
CsrfResponse::Token(PlainText(token.0.clone()))
}
#[oai(path = "/signout", method = "post")]
async fn signout(&self, session: &Session) -> LoginStatusResponse {
session.remove("token");
LoginStatusResponse::LoggedOut
}
}

88
src/route/errors.rs Normal file
View File

@ -0,0 +1,88 @@
use poem::error::ResponseError;
use poem_openapi::Object;
use crate::errors::ApiError;
#[derive(Debug, serde::Serialize, Default, Object)]
pub struct ErrorResponse {
pub code: u16,
pub message: String,
pub details: Option<String>,
}
impl From<ApiError> for ErrorResponse {
fn from(value: ApiError) -> Self {
match value {
ApiError::Sql(e) => Self {
code: 500,
message: "SQL error".into(),
details: Some(e.to_string()),
},
ApiError::Request(e) => Self {
code: 500,
message: "HTTP request error".into(),
details: Some(e.to_string()),
},
ApiError::TokenError(e) => Self {
code: 500,
message: "OAuth token error".into(),
details: Some(e),
},
ApiError::Unauthorized => Self {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
},
ApiError::OptionError => Self {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
},
ApiError::ParseIntError(e) => Self {
code: 500,
message: "Failed to parse a number as an integer".into(),
details: Some(e.to_string()),
},
ApiError::FromRequestPartsError(e) => Self {
code: 500,
message: "Encountered an error trying to convert an infaillible value".to_string(),
details: Some(e.to_string()),
},
}
}
}
impl ResponseError for ApiError {
fn status(&self) -> reqwest::StatusCode {
match self {
Self::FromRequestPartsError(_)
| Self::ParseIntError(_)
| Self::OptionError
| Self::Sql(_)
| Self::Request(_)
| Self::TokenError(_) => reqwest::StatusCode::INTERNAL_SERVER_ERROR,
Self::Unauthorized => reqwest::StatusCode::UNAUTHORIZED,
}
}
fn as_response(&self) -> poem::Response
where
Self: std::error::Error + Send + Sync + 'static,
{
match self {
Self::Sql(_) => todo!(),
Self::Request(_) => todo!(),
Self::TokenError(_) => todo!(),
Self::Unauthorized => todo!(),
Self::OptionError => todo!(),
Self::ParseIntError(_) => todo!(),
Self::FromRequestPartsError(_) => todo!(),
}
}
}
impl From<ErrorResponse> for poem_openapi::payload::Json<ErrorResponse> {
fn from(value: ErrorResponse) -> Self {
Self(value)
}
}

View File

@ -10,7 +10,7 @@ enum HealthResponse {
pub struct HealthApi; pub struct HealthApi;
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")] #[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
impl HealthApi { impl HealthApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn health_check(&self) -> HealthResponse { async fn health_check(&self) -> HealthResponse {

View File

@ -6,13 +6,19 @@ pub use health::HealthApi;
mod version; mod version;
pub use version::VersionApi; pub use version::VersionApi;
mod errors;
mod auth;
pub use auth::AuthApi;
#[derive(Tags)] #[derive(Tags)]
enum ApiCategory { enum ApiCategory {
Auth,
Health, Health,
Version, Version,
} }
pub(crate) struct Api; pub struct Api;
#[OpenApi] #[OpenApi]
impl Api {} impl Api {}

View File

@ -25,7 +25,7 @@ enum VersionResponse {
pub struct VersionApi; pub struct VersionApi;
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")] #[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
impl VersionApi { impl VersionApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> { async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {

View File

@ -4,6 +4,7 @@ use sqlx::ConnectOptions;
pub struct Settings { pub struct Settings {
pub application: ApplicationSettings, pub application: ApplicationSettings,
pub database: Database, pub database: Database,
pub discord: Discord,
pub debug: bool, pub debug: bool,
pub email: EmailSettings, pub email: EmailSettings,
pub frontend_url: String, pub frontend_url: String,
@ -12,16 +13,8 @@ pub struct Settings {
impl Settings { impl Settings {
#[must_use] #[must_use]
pub fn web_address(&self) -> String { pub fn web_address(&self) -> String {
if self.debug {
format!(
"{}:{}",
self.application.base_url.clone(),
self.application.port
)
} else {
self.application.base_url.clone() self.application.base_url.clone()
} }
}
/// Multipurpose function that helps detect the current /// Multipurpose function that helps detect the current
/// environment the application is running in using the /// environment the application is running in using the
@ -65,7 +58,7 @@ impl Settings {
)) ))
.add_source( .add_source(
config::Environment::with_prefix("APP") config::Environment::with_prefix("APP")
.prefix_separator("_") .prefix_separator("__")
.separator("__"), .separator("__"),
) )
.build()?; .build()?;
@ -164,6 +157,12 @@ pub struct EmailSettings {
pub from: String, pub from: String,
} }
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Discord {
pub client_id: String,
pub client_secret: String,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -1,8 +1,11 @@
use poem::middleware::Cors;
use poem::middleware::{AddDataEndpoint, CorsEndpoint}; use poem::middleware::{AddDataEndpoint, CorsEndpoint};
use poem::middleware::{CookieJarManagerEndpoint, Cors};
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
use poem::{EndpointExt, Route}; use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
use crate::oauth::DiscordOauthProvider;
use crate::route::AuthApi;
use crate::{ use crate::{
route::{Api, HealthApi, VersionApi}, route::{Api, HealthApi, VersionApi},
settings::Settings, settings::Settings,
@ -22,14 +25,23 @@ pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgr
} }
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>; type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>; pub type App = AddDataEndpoint<
AddDataEndpoint<
AddDataEndpoint<
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
DiscordOauthProvider,
>,
sqlx::Pool<sqlx::Postgres>,
>,
Settings,
>;
pub struct Application { pub struct Application {
server: Server, server: Server,
app: poem::Route, app: poem::Route,
port: u16, port: u16,
database: sqlx::postgres::PgPool, database: sqlx::postgres::PgPool,
settings: Settings, pub settings: Settings,
} }
pub struct RunnableApplication { pub struct RunnableApplication {
@ -61,6 +73,8 @@ impl From<Application> for RunnableApplication {
let app = val let app = val
.app .app
.with(Cors::new()) .with(Cors::new())
.with(CookieSession::new(CookieConfig::default().secure(true)))
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
.data(val.database) .data(val.database)
.data(val.settings); .data(val.settings);
let server = val.server; let server = val.server;
@ -83,7 +97,7 @@ impl Application {
fn setup_app(settings: &Settings) -> poem::Route { fn setup_app(settings: &Settings) -> poem::Route {
let api_service = OpenApiService::new( let api_service = OpenApiService::new(
(Api, HealthApi, VersionApi), (Api, AuthApi, HealthApi, VersionApi),
settings.application.clone().name, settings.application.clone().name,
settings.application.clone().version, settings.application.clone().version,
); );