feat: initial OAuth implementation with Discord
Some checks failed
CI / tests (push) Failing after 3m53s

For now, only a basic implementation of OAuth with Discord is
implemented. If the user calls the Discord signin endpoint, they get
redirected to Discord’s OAuth page. Once they accept, they get
redirected back to the backend’s callback API endpoint. The token the
user got from Discord is stored in the user’s session.

When the user wants to log out, the user’s session’s token is wiped.

This commit also updates the dependencies of the project. It also
removes the dependency lettre as well as the mailpit docker service
for developers as it appears clearer this project won’t send emails
anytime soon.
This commit is contained in:
Lucien Cartier-Tilet 2024-11-23 10:00:53 +01:00
parent ae10711e41
commit 9489f78224
Signed by: phundrak
GPG Key ID: 347803E8073EACE0
18 changed files with 1520 additions and 562 deletions

View File

@ -1,13 +1,12 @@
;;; Directory Local Variables -*- no-byte-compile: t -*- ;;; Directory Local Variables -*- no-byte-compile: t -*-
;;; For more information see (info "(emacs) Directory Variables") ;;; For more information see (info "(emacs) Directory Variables")
((sql-mode ((rustic-mode . ((fill-column . 80)))
. (sql-mode . ((eval . (progn
((eval . (progn
(setq-local lsp-sqls-connections (setq-local lsp-sqls-connections
`(((driver . "postgresql") `(((driver . "postgresql")
(dataSourceName . (dataSourceName \,
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable" (format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
(getenv "DB_HOST") (getenv "DB_HOST")
(getenv "DB_PORT") (getenv "DB_PORT")
(getenv "DB_USER") (getenv "DB_USER")

1594
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -17,45 +17,32 @@ name = "gege-jdr-backend"
[dependencies] [dependencies]
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.0", features = ["yaml"] } config = { version = "0.14.1", features = ["yaml"] }
dotenvy = "0.15.7" dotenvy = "0.15.7"
serde = "1.0.204" oauth2 = "4.4.2"
serde_json = "1.0.120" quote = "1.0.37"
thiserror = "1.0.63" reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] } serde = "1.0.215"
serde_json = "1.0.133"
thiserror = "1.0.69"
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] } tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.10.0", features = ["v4", "serde"] } uuid = { version = "1.11.0", features = ["v4", "serde"] }
[dependencies.lettre]
version = "0.11.7"
default-features = false
features = [
"builder",
"hostname",
"pool",
"rustls-tls",
"tokio1",
"tokio1-rustls-tls",
"smtp-transport"
]
[dependencies.poem] [dependencies.poem]
version = "3.0.4" version = "3.1.3"
default-features = false default-features = false
features = [ features = ["csrf", "rustls", "cookie", "test", "session"]
"csrf",
"rustls",
"cookie",
"test"
]
[dependencies.poem-openapi] [dependencies.poem-openapi]
version = "5.0.3" version = "5.1.2"
features = ["chrono", "swagger-ui", "uuid"] features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
[dependencies.sqlx] [dependencies.sqlx]
version = "0.8.0" version = "0.8.2"
default-features = false default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"] features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }

View File

@ -31,25 +31,6 @@ services:
depends_on: depends_on:
- db - db
# If you run GegeJdrBackend in production, DO NOT use mailpit.
# This tool is for testing only. Instead, you should use a real SMTP
# provider, such as Mailgun, Mailwhale, or Postal.
mailpit:
image: axllent/mailpit:latest
restart: unless-stopped
container_name: gege-jdr-backend-mailpit
ports:
- 127.0.0.1:8025:8025 # WebUI
- 127.0.0.1:1025:1025 # SMTP
volumes:
- gege_jdr_backend_mailpit:/data
environment:
MP_MAX_MESSAGES: 5000
MP_DATABASE: /data/mailpit.db
MP_SMTP_AUTH_ACCEPT_ANY: 1
MP_SMTP_AUTH_ALLOW_INSECURE: 1
volumes: volumes:
gege_jdr_backend_db_data: gege_jdr_backend_db_data:
gege_jdr_backend_pgadmin_data: gege_jdr_backend_pgadmin_data:
gege_jdr_backend_mailpit:

View File

@ -3,5 +3,5 @@ debug: true
application: application:
protocol: http protocol: http
host: 127.0.0.1 host: localhost
base_url: http://127.0.0.1:3000 base_url: http://localhost:3000

19
src/errors.rs Normal file
View File

@ -0,0 +1,19 @@
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
Sql(#[from] sqlx::Error),
#[error("HTTP request error: {0}")]
Request(#[from] reqwest::Error),
#[error("OAuth token error: {0}")]
TokenError(String),
#[error("Unauthorized")]
Unauthorized,
#[error("Attempted to get a value, none found")]
OptionError,
#[error("Failed to parse a number as an integer")]
ParseIntError(#[from] std::num::TryFromIntError),
#[error("Encountered an error trying to convert an infaillible value")]
FromRequestPartsError(#[from] std::convert::Infallible),
}

View File

@ -5,10 +5,13 @@
#![allow(clippy::unused_async)] #![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling #![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
pub mod route; mod errors;
pub mod settings; mod models;
pub mod startup; mod oauth;
pub mod telemetry; mod route;
mod settings;
mod startup;
mod telemetry;
type MaybeListener = Option<poem::listener::TcpListener<String>>; type MaybeListener = Option<poem::listener::TcpListener<String>>;
@ -29,8 +32,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
tracing::event!( tracing::event!(
target: "gege-jdr-backend", target: "gege-jdr-backend",
tracing::Level::INFO, tracing::Level::INFO,
"Listening on http://127.0.0.1:{}/", "Listening on {}",
application.port() application.settings.web_address()
); );
application application
} }

18
src/models/accounts.rs Normal file
View File

@ -0,0 +1,18 @@
type Timestampz = chrono::DateTime<chrono::Utc>;
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
pub struct User {
pub id: uuid::Uuid,
pub email: String,
pub created_at: Timestampz,
pub last_updated: Timestampz,
}
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
pub struct Session {
pub id: uuid::Uuid,
pub user_id: uuid::Uuid,
#[allow(clippy::struct_field_names)]
pub session_id: String,
pub expires_at: Timestampz,
}

1
src/models/mod.rs Normal file
View File

@ -0,0 +1 @@
pub mod accounts;

62
src/oauth/discord.rs Normal file
View File

@ -0,0 +1,62 @@
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
};
use reqwest::Url;
use crate::{errors::ApiError, settings::Settings};
use super::OauthProvider;
#[derive(Debug, Clone)]
pub struct DiscordOauthProvider {
client: BasicClient,
}
impl DiscordOauthProvider {
pub fn new(settings: &Settings) -> Self {
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
.expect("Invalid token endpoint URL");
let revocation_url =
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
.expect("Invalid revocation URL");
let client = BasicClient::new(
ClientId::new(settings.discord.client_id.clone()),
Some(ClientSecret::new(settings.discord.client_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
.set_revocation_uri(revocation_url);
Self { client }
}
}
impl OauthProvider for DiscordOauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
.set_pkce_challenge(pkce_challenge)
.url();
(auth_url, csrf_token, pkce_verifier)
}
async fn token(
&self,
code: String,
verifier: PkceCodeVerifier,
) -> Result<super::Token, ApiError> {
self.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(verifier)
.request_async(oauth2::reqwest::async_http_client)
.await
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
}
}

17
src/oauth/mod.rs Normal file
View File

@ -0,0 +1,17 @@
mod discord;
pub use discord::DiscordOauthProvider;
use oauth2::{
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
StandardTokenResponse,
};
use reqwest::Url;
use crate::errors::ApiError;
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
pub trait OauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
}

125
src/route/auth.rs Normal file
View File

@ -0,0 +1,125 @@
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
use poem::web::Data;
use poem::{session::Session, web::Form};
use poem_openapi::payload::{Json, PlainText};
use poem_openapi::{ApiResponse, Object, OpenApi};
use crate::oauth::{DiscordOauthProvider, OauthProvider};
use super::errors::ErrorResponse;
use super::ApiCategory;
pub struct AuthApi;
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
struct DiscordCallbackRequest {
code: String,
state: String,
}
#[derive(ApiResponse)]
enum LoginStatusResponse {
#[oai(status = 201)]
LoggedIn,
#[oai(status = 201)]
LoggedOut,
#[oai(status = 301)]
LoginRedirect(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 500)]
TokenError(Json<ErrorResponse>),
}
#[derive(ApiResponse)]
enum CsrfResponse {
#[oai(status = 201)]
Token(PlainText<String>),
}
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
impl AuthApi {
// TODO: implement the following endpoints:
// - /signin
// - /signout
// - /session
// - /providers
// See https://next-auth.js.org/getting-started/rest-api
#[oai(path = "/signin/discord", method = "get")]
async fn signin_discord(
&self,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
session.set("csrf", csrf_token);
session.set("pkce", pkce_verifier);
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Redirect URL: {}",
auth_url
);
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
}
#[oai(path = "/callback/discord", method = "get")]
async fn callback_discord(
&self,
Form(auth_request): Form<DiscordCallbackRequest>,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
tracing::event!(
target: "auth-discord",
tracing::Level::DEBUG,
"Discord replied with: {:?}",
auth_request
);
let csrf_token = session
.get::<CsrfToken>("csrf")
.expect("Failed to retrieve Csrf token from session");
if *csrf_token.secret().to_string() != auth_request.state {
return LoginStatusResponse::TokenError(
ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some(
"OAuth provider did not send a message that matches what was expected"
.into(),
),
}
.into(),
);
}
let pkce_verifier = session
.get::<PkceCodeVerifier>("pkce")
.expect("Failed to retrieve pkce verifier from session");
let token = oauth.token(auth_request.code, pkce_verifier).await;
if let Err(e) = token {
return LoginStatusResponse::TokenError(Json(e.into()));
}
let token = token.unwrap();
tracing::event!(
target: "auth-discord",
tracing::Level::DEBUG,
"Token: {:?}",
token
);
session.set("token", token);
LoginStatusResponse::LoggedIn
}
#[oai(path = "/csrf", method = "get")]
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
CsrfResponse::Token(PlainText(token.0.clone()))
}
#[oai(path = "/signout", method = "post")]
async fn signout(&self, session: &Session) -> LoginStatusResponse {
session.remove("token");
LoginStatusResponse::LoggedOut
}
}

88
src/route/errors.rs Normal file
View File

@ -0,0 +1,88 @@
use poem::error::ResponseError;
use poem_openapi::Object;
use crate::errors::ApiError;
#[derive(Debug, serde::Serialize, Default, Object)]
pub struct ErrorResponse {
pub code: u16,
pub message: String,
pub details: Option<String>,
}
impl From<ApiError> for ErrorResponse {
fn from(value: ApiError) -> Self {
match value {
ApiError::Sql(e) => Self {
code: 500,
message: "SQL error".into(),
details: Some(e.to_string()),
},
ApiError::Request(e) => Self {
code: 500,
message: "HTTP request error".into(),
details: Some(e.to_string()),
},
ApiError::TokenError(e) => Self {
code: 500,
message: "OAuth token error".into(),
details: Some(e),
},
ApiError::Unauthorized => Self {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
},
ApiError::OptionError => Self {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
},
ApiError::ParseIntError(e) => Self {
code: 500,
message: "Failed to parse a number as an integer".into(),
details: Some(e.to_string()),
},
ApiError::FromRequestPartsError(e) => Self {
code: 500,
message: "Encountered an error trying to convert an infaillible value".to_string(),
details: Some(e.to_string()),
},
}
}
}
impl ResponseError for ApiError {
fn status(&self) -> reqwest::StatusCode {
match self {
Self::FromRequestPartsError(_)
| Self::ParseIntError(_)
| Self::OptionError
| Self::Sql(_)
| Self::Request(_)
| Self::TokenError(_) => reqwest::StatusCode::INTERNAL_SERVER_ERROR,
Self::Unauthorized => reqwest::StatusCode::UNAUTHORIZED,
}
}
fn as_response(&self) -> poem::Response
where
Self: std::error::Error + Send + Sync + 'static,
{
match self {
Self::Sql(_) => todo!(),
Self::Request(_) => todo!(),
Self::TokenError(_) => todo!(),
Self::Unauthorized => todo!(),
Self::OptionError => todo!(),
Self::ParseIntError(_) => todo!(),
Self::FromRequestPartsError(_) => todo!(),
}
}
}
impl From<ErrorResponse> for poem_openapi::payload::Json<ErrorResponse> {
fn from(value: ErrorResponse) -> Self {
Self(value)
}
}

View File

@ -10,7 +10,7 @@ enum HealthResponse {
pub struct HealthApi; pub struct HealthApi;
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")] #[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
impl HealthApi { impl HealthApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn health_check(&self) -> HealthResponse { async fn health_check(&self) -> HealthResponse {

View File

@ -6,13 +6,19 @@ pub use health::HealthApi;
mod version; mod version;
pub use version::VersionApi; pub use version::VersionApi;
mod errors;
mod auth;
pub use auth::AuthApi;
#[derive(Tags)] #[derive(Tags)]
enum ApiCategory { enum ApiCategory {
Auth,
Health, Health,
Version, Version,
} }
pub(crate) struct Api; pub struct Api;
#[OpenApi] #[OpenApi]
impl Api {} impl Api {}

View File

@ -25,7 +25,7 @@ enum VersionResponse {
pub struct VersionApi; pub struct VersionApi;
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")] #[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
impl VersionApi { impl VersionApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> { async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {

View File

@ -13,16 +13,8 @@ pub struct Settings {
impl Settings { impl Settings {
#[must_use] #[must_use]
pub fn web_address(&self) -> String { pub fn web_address(&self) -> String {
if self.debug {
format!(
"{}:{}",
self.application.base_url.clone(),
self.application.port
)
} else {
self.application.base_url.clone() self.application.base_url.clone()
} }
}
/// Multipurpose function that helps detect the current /// Multipurpose function that helps detect the current
/// environment the application is running in using the /// environment the application is running in using the
@ -66,7 +58,7 @@ impl Settings {
)) ))
.add_source( .add_source(
config::Environment::with_prefix("APP") config::Environment::with_prefix("APP")
.prefix_separator("_") .prefix_separator("__")
.separator("__"), .separator("__"),
) )
.build()?; .build()?;
@ -167,8 +159,8 @@ pub struct EmailSettings {
#[derive(Debug, serde::Deserialize, Clone, Default)] #[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Discord { pub struct Discord {
client_id: String, pub client_id: String,
client_secret: String, pub client_secret: String,
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1,8 +1,11 @@
use poem::middleware::Cors;
use poem::middleware::{AddDataEndpoint, CorsEndpoint}; use poem::middleware::{AddDataEndpoint, CorsEndpoint};
use poem::middleware::{CookieJarManagerEndpoint, Cors};
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
use poem::{EndpointExt, Route}; use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
use crate::oauth::DiscordOauthProvider;
use crate::route::AuthApi;
use crate::{ use crate::{
route::{Api, HealthApi, VersionApi}, route::{Api, HealthApi, VersionApi},
settings::Settings, settings::Settings,
@ -22,14 +25,23 @@ pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgr
} }
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>; type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>; pub type App = AddDataEndpoint<
AddDataEndpoint<
AddDataEndpoint<
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
DiscordOauthProvider,
>,
sqlx::Pool<sqlx::Postgres>,
>,
Settings,
>;
pub struct Application { pub struct Application {
server: Server, server: Server,
app: poem::Route, app: poem::Route,
port: u16, port: u16,
database: sqlx::postgres::PgPool, database: sqlx::postgres::PgPool,
settings: Settings, pub settings: Settings,
} }
pub struct RunnableApplication { pub struct RunnableApplication {
@ -61,6 +73,8 @@ impl From<Application> for RunnableApplication {
let app = val let app = val
.app .app
.with(Cors::new()) .with(Cors::new())
.with(CookieSession::new(CookieConfig::default().secure(true)))
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
.data(val.database) .data(val.database)
.data(val.settings); .data(val.settings);
let server = val.server; let server = val.server;
@ -83,7 +97,7 @@ impl Application {
fn setup_app(settings: &Settings) -> poem::Route { fn setup_app(settings: &Settings) -> poem::Route {
let api_service = OpenApiService::new( let api_service = OpenApiService::new(
(Api, HealthApi, VersionApi), (Api, AuthApi, HealthApi, VersionApi),
settings.application.clone().name, settings.application.clone().name,
settings.application.clone().version, settings.application.clone().version,
); );