generated from phundrak/rust-poem-openapi-template
feat: initial OAuth implementation with Discord
Some checks failed
CI / tests (push) Failing after 3m53s
Some checks failed
CI / tests (push) Failing after 3m53s
For now, only a basic implementation of OAuth with Discord is implemented. If the user calls the Discord signin endpoint, they get redirected to Discord’s OAuth page. Once they accept, they get redirected back to the backend’s callback API endpoint. The token the user got from Discord is stored in the user’s session. When the user wants to log out, the user’s session’s token is wiped. This commit also updates the dependencies of the project. It also removes the dependency lettre as well as the mailpit docker service for developers as it appears clearer this project won’t send emails anytime soon.
This commit is contained in:
parent
ae10711e41
commit
9489f78224
@ -1,13 +1,12 @@
|
||||
;;; Directory Local Variables -*- no-byte-compile: t -*-
|
||||
;;; For more information see (info "(emacs) Directory Variables")
|
||||
|
||||
((sql-mode
|
||||
.
|
||||
((eval . (progn
|
||||
((rustic-mode . ((fill-column . 80)))
|
||||
(sql-mode . ((eval . (progn
|
||||
(setq-local lsp-sqls-connections
|
||||
`(((driver . "postgresql")
|
||||
(dataSourceName .
|
||||
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
||||
(dataSourceName \,
|
||||
(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
||||
(getenv "DB_HOST")
|
||||
(getenv "DB_PORT")
|
||||
(getenv "DB_USER")
|
||||
|
1594
Cargo.lock
generated
1594
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
47
Cargo.toml
47
Cargo.toml
@ -17,45 +17,32 @@ name = "gege-jdr-backend"
|
||||
|
||||
[dependencies]
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
config = { version = "0.14.0", features = ["yaml"] }
|
||||
config = { version = "0.14.1", features = ["yaml"] }
|
||||
dotenvy = "0.15.7"
|
||||
serde = "1.0.204"
|
||||
serde_json = "1.0.120"
|
||||
thiserror = "1.0.63"
|
||||
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
|
||||
oauth2 = "4.4.2"
|
||||
quote = "1.0.37"
|
||||
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
|
||||
serde = "1.0.215"
|
||||
serde_json = "1.0.133"
|
||||
thiserror = "1.0.69"
|
||||
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
|
||||
[dependencies.lettre]
|
||||
version = "0.11.7"
|
||||
default-features = false
|
||||
features = [
|
||||
"builder",
|
||||
"hostname",
|
||||
"pool",
|
||||
"rustls-tls",
|
||||
"tokio1",
|
||||
"tokio1-rustls-tls",
|
||||
"smtp-transport"
|
||||
]
|
||||
|
||||
uuid = { version = "1.11.0", features = ["v4", "serde"] }
|
||||
|
||||
[dependencies.poem]
|
||||
version = "3.0.4"
|
||||
version = "3.1.3"
|
||||
default-features = false
|
||||
features = [
|
||||
"csrf",
|
||||
"rustls",
|
||||
"cookie",
|
||||
"test"
|
||||
]
|
||||
features = ["csrf", "rustls", "cookie", "test", "session"]
|
||||
|
||||
[dependencies.poem-openapi]
|
||||
version = "5.0.3"
|
||||
features = ["chrono", "swagger-ui", "uuid"]
|
||||
version = "5.1.2"
|
||||
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
|
||||
|
||||
[dependencies.sqlx]
|
||||
version = "0.8.0"
|
||||
version = "0.8.2"
|
||||
default-features = false
|
||||
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }
|
||||
|
@ -31,25 +31,6 @@ services:
|
||||
depends_on:
|
||||
- db
|
||||
|
||||
# If you run GegeJdrBackend in production, DO NOT use mailpit.
|
||||
# This tool is for testing only. Instead, you should use a real SMTP
|
||||
# provider, such as Mailgun, Mailwhale, or Postal.
|
||||
mailpit:
|
||||
image: axllent/mailpit:latest
|
||||
restart: unless-stopped
|
||||
container_name: gege-jdr-backend-mailpit
|
||||
ports:
|
||||
- 127.0.0.1:8025:8025 # WebUI
|
||||
- 127.0.0.1:1025:1025 # SMTP
|
||||
volumes:
|
||||
- gege_jdr_backend_mailpit:/data
|
||||
environment:
|
||||
MP_MAX_MESSAGES: 5000
|
||||
MP_DATABASE: /data/mailpit.db
|
||||
MP_SMTP_AUTH_ACCEPT_ANY: 1
|
||||
MP_SMTP_AUTH_ALLOW_INSECURE: 1
|
||||
|
||||
volumes:
|
||||
gege_jdr_backend_db_data:
|
||||
gege_jdr_backend_pgadmin_data:
|
||||
gege_jdr_backend_mailpit:
|
||||
|
@ -3,5 +3,5 @@ debug: true
|
||||
|
||||
application:
|
||||
protocol: http
|
||||
host: 127.0.0.1
|
||||
base_url: http://127.0.0.1:3000
|
||||
host: localhost
|
||||
base_url: http://localhost:3000
|
||||
|
19
src/errors.rs
Normal file
19
src/errors.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
#[error("SQL error: {0}")]
|
||||
Sql(#[from] sqlx::Error),
|
||||
#[error("HTTP request error: {0}")]
|
||||
Request(#[from] reqwest::Error),
|
||||
#[error("OAuth token error: {0}")]
|
||||
TokenError(String),
|
||||
#[error("Unauthorized")]
|
||||
Unauthorized,
|
||||
#[error("Attempted to get a value, none found")]
|
||||
OptionError,
|
||||
#[error("Failed to parse a number as an integer")]
|
||||
ParseIntError(#[from] std::num::TryFromIntError),
|
||||
#[error("Encountered an error trying to convert an infaillible value")]
|
||||
FromRequestPartsError(#[from] std::convert::Infallible),
|
||||
}
|
15
src/lib.rs
15
src/lib.rs
@ -5,10 +5,13 @@
|
||||
#![allow(clippy::unused_async)]
|
||||
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
|
||||
|
||||
pub mod route;
|
||||
pub mod settings;
|
||||
pub mod startup;
|
||||
pub mod telemetry;
|
||||
mod errors;
|
||||
mod models;
|
||||
mod oauth;
|
||||
mod route;
|
||||
mod settings;
|
||||
mod startup;
|
||||
mod telemetry;
|
||||
|
||||
type MaybeListener = Option<poem::listener::TcpListener<String>>;
|
||||
|
||||
@ -29,8 +32,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
|
||||
tracing::event!(
|
||||
target: "gege-jdr-backend",
|
||||
tracing::Level::INFO,
|
||||
"Listening on http://127.0.0.1:{}/",
|
||||
application.port()
|
||||
"Listening on {}",
|
||||
application.settings.web_address()
|
||||
);
|
||||
application
|
||||
}
|
||||
|
18
src/models/accounts.rs
Normal file
18
src/models/accounts.rs
Normal file
@ -0,0 +1,18 @@
|
||||
type Timestampz = chrono::DateTime<chrono::Utc>;
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
|
||||
pub struct User {
|
||||
pub id: uuid::Uuid,
|
||||
pub email: String,
|
||||
pub created_at: Timestampz,
|
||||
pub last_updated: Timestampz,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
|
||||
pub struct Session {
|
||||
pub id: uuid::Uuid,
|
||||
pub user_id: uuid::Uuid,
|
||||
#[allow(clippy::struct_field_names)]
|
||||
pub session_id: String,
|
||||
pub expires_at: Timestampz,
|
||||
}
|
1
src/models/mod.rs
Normal file
1
src/models/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod accounts;
|
62
src/oauth/discord.rs
Normal file
62
src/oauth/discord.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
|
||||
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
|
||||
};
|
||||
use reqwest::Url;
|
||||
|
||||
use crate::{errors::ApiError, settings::Settings};
|
||||
|
||||
use super::OauthProvider;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DiscordOauthProvider {
|
||||
client: BasicClient,
|
||||
}
|
||||
|
||||
impl DiscordOauthProvider {
|
||||
pub fn new(settings: &Settings) -> Self {
|
||||
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
|
||||
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
|
||||
.expect("Invalid authorization endpoint URL");
|
||||
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
|
||||
.expect("Invalid token endpoint URL");
|
||||
let revocation_url =
|
||||
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
|
||||
.expect("Invalid revocation URL");
|
||||
let client = BasicClient::new(
|
||||
ClientId::new(settings.discord.client_id.clone()),
|
||||
Some(ClientSecret::new(settings.discord.client_secret.clone())),
|
||||
auth_url,
|
||||
Some(token_url),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
|
||||
.set_revocation_uri(revocation_url);
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
impl OauthProvider for DiscordOauthProvider {
|
||||
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let (auth_url, csrf_token) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.url();
|
||||
(auth_url, csrf_token, pkce_verifier)
|
||||
}
|
||||
|
||||
async fn token(
|
||||
&self,
|
||||
code: String,
|
||||
verifier: PkceCodeVerifier,
|
||||
) -> Result<super::Token, ApiError> {
|
||||
self.client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(verifier)
|
||||
.request_async(oauth2::reqwest::async_http_client)
|
||||
.await
|
||||
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
|
||||
}
|
||||
}
|
17
src/oauth/mod.rs
Normal file
17
src/oauth/mod.rs
Normal file
@ -0,0 +1,17 @@
|
||||
mod discord;
|
||||
pub use discord::DiscordOauthProvider;
|
||||
|
||||
use oauth2::{
|
||||
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
|
||||
StandardTokenResponse,
|
||||
};
|
||||
use reqwest::Url;
|
||||
|
||||
use crate::errors::ApiError;
|
||||
|
||||
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
|
||||
|
||||
pub trait OauthProvider {
|
||||
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
|
||||
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
|
||||
}
|
125
src/route/auth.rs
Normal file
125
src/route/auth.rs
Normal file
@ -0,0 +1,125 @@
|
||||
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
|
||||
use poem::web::Data;
|
||||
use poem::{session::Session, web::Form};
|
||||
use poem_openapi::payload::{Json, PlainText};
|
||||
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||
|
||||
use crate::oauth::{DiscordOauthProvider, OauthProvider};
|
||||
|
||||
use super::errors::ErrorResponse;
|
||||
use super::ApiCategory;
|
||||
|
||||
pub struct AuthApi;
|
||||
|
||||
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
|
||||
struct DiscordCallbackRequest {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum LoginStatusResponse {
|
||||
#[oai(status = 201)]
|
||||
LoggedIn,
|
||||
#[oai(status = 201)]
|
||||
LoggedOut,
|
||||
#[oai(status = 301)]
|
||||
LoginRedirect(
|
||||
#[oai(header = "Location")] String,
|
||||
#[oai(header = "Cache-Control")] String,
|
||||
),
|
||||
#[oai(status = 500)]
|
||||
TokenError(Json<ErrorResponse>),
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum CsrfResponse {
|
||||
#[oai(status = 201)]
|
||||
Token(PlainText<String>),
|
||||
}
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
|
||||
impl AuthApi {
|
||||
// TODO: implement the following endpoints:
|
||||
// - /signin
|
||||
// - /signout
|
||||
// - /session
|
||||
// - /providers
|
||||
// See https://next-auth.js.org/getting-started/rest-api
|
||||
|
||||
#[oai(path = "/signin/discord", method = "get")]
|
||||
async fn signin_discord(
|
||||
&self,
|
||||
oauth: Data<&DiscordOauthProvider>,
|
||||
session: &Session,
|
||||
) -> LoginStatusResponse {
|
||||
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
|
||||
session.set("csrf", csrf_token);
|
||||
session.set("pkce", pkce_verifier);
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::INFO,
|
||||
"Redirect URL: {}",
|
||||
auth_url
|
||||
);
|
||||
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
|
||||
}
|
||||
|
||||
#[oai(path = "/callback/discord", method = "get")]
|
||||
async fn callback_discord(
|
||||
&self,
|
||||
Form(auth_request): Form<DiscordCallbackRequest>,
|
||||
oauth: Data<&DiscordOauthProvider>,
|
||||
session: &Session,
|
||||
) -> LoginStatusResponse {
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::DEBUG,
|
||||
"Discord replied with: {:?}",
|
||||
auth_request
|
||||
);
|
||||
let csrf_token = session
|
||||
.get::<CsrfToken>("csrf")
|
||||
.expect("Failed to retrieve Csrf token from session");
|
||||
if *csrf_token.secret().to_string() != auth_request.state {
|
||||
return LoginStatusResponse::TokenError(
|
||||
ErrorResponse {
|
||||
code: 500,
|
||||
message: "OAuth token error".into(),
|
||||
details: Some(
|
||||
"OAuth provider did not send a message that matches what was expected"
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
let pkce_verifier = session
|
||||
.get::<PkceCodeVerifier>("pkce")
|
||||
.expect("Failed to retrieve pkce verifier from session");
|
||||
let token = oauth.token(auth_request.code, pkce_verifier).await;
|
||||
if let Err(e) = token {
|
||||
return LoginStatusResponse::TokenError(Json(e.into()));
|
||||
}
|
||||
let token = token.unwrap();
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::DEBUG,
|
||||
"Token: {:?}",
|
||||
token
|
||||
);
|
||||
session.set("token", token);
|
||||
LoginStatusResponse::LoggedIn
|
||||
}
|
||||
|
||||
#[oai(path = "/csrf", method = "get")]
|
||||
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
|
||||
CsrfResponse::Token(PlainText(token.0.clone()))
|
||||
}
|
||||
|
||||
#[oai(path = "/signout", method = "post")]
|
||||
async fn signout(&self, session: &Session) -> LoginStatusResponse {
|
||||
session.remove("token");
|
||||
LoginStatusResponse::LoggedOut
|
||||
}
|
||||
}
|
88
src/route/errors.rs
Normal file
88
src/route/errors.rs
Normal file
@ -0,0 +1,88 @@
|
||||
use poem::error::ResponseError;
|
||||
use poem_openapi::Object;
|
||||
|
||||
use crate::errors::ApiError;
|
||||
|
||||
#[derive(Debug, serde::Serialize, Default, Object)]
|
||||
pub struct ErrorResponse {
|
||||
pub code: u16,
|
||||
pub message: String,
|
||||
pub details: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ApiError> for ErrorResponse {
|
||||
fn from(value: ApiError) -> Self {
|
||||
match value {
|
||||
ApiError::Sql(e) => Self {
|
||||
code: 500,
|
||||
message: "SQL error".into(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
ApiError::Request(e) => Self {
|
||||
code: 500,
|
||||
message: "HTTP request error".into(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
ApiError::TokenError(e) => Self {
|
||||
code: 500,
|
||||
message: "OAuth token error".into(),
|
||||
details: Some(e),
|
||||
},
|
||||
ApiError::Unauthorized => Self {
|
||||
code: 401,
|
||||
message: "Unauthorized!".into(),
|
||||
..Default::default()
|
||||
},
|
||||
ApiError::OptionError => Self {
|
||||
code: 500,
|
||||
message: "Attempted to get a value, but none found".into(),
|
||||
..Default::default()
|
||||
},
|
||||
ApiError::ParseIntError(e) => Self {
|
||||
code: 500,
|
||||
message: "Failed to parse a number as an integer".into(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
ApiError::FromRequestPartsError(e) => Self {
|
||||
code: 500,
|
||||
message: "Encountered an error trying to convert an infaillible value".to_string(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseError for ApiError {
|
||||
fn status(&self) -> reqwest::StatusCode {
|
||||
match self {
|
||||
Self::FromRequestPartsError(_)
|
||||
| Self::ParseIntError(_)
|
||||
| Self::OptionError
|
||||
| Self::Sql(_)
|
||||
| Self::Request(_)
|
||||
| Self::TokenError(_) => reqwest::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::Unauthorized => reqwest::StatusCode::UNAUTHORIZED,
|
||||
}
|
||||
}
|
||||
|
||||
fn as_response(&self) -> poem::Response
|
||||
where
|
||||
Self: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
match self {
|
||||
Self::Sql(_) => todo!(),
|
||||
Self::Request(_) => todo!(),
|
||||
Self::TokenError(_) => todo!(),
|
||||
Self::Unauthorized => todo!(),
|
||||
Self::OptionError => todo!(),
|
||||
Self::ParseIntError(_) => todo!(),
|
||||
Self::FromRequestPartsError(_) => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ErrorResponse> for poem_openapi::payload::Json<ErrorResponse> {
|
||||
fn from(value: ErrorResponse) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
@ -10,7 +10,7 @@ enum HealthResponse {
|
||||
|
||||
pub struct HealthApi;
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")]
|
||||
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
|
||||
impl HealthApi {
|
||||
#[oai(path = "/", method = "get")]
|
||||
async fn health_check(&self) -> HealthResponse {
|
||||
|
@ -6,13 +6,19 @@ pub use health::HealthApi;
|
||||
mod version;
|
||||
pub use version::VersionApi;
|
||||
|
||||
mod errors;
|
||||
|
||||
mod auth;
|
||||
pub use auth::AuthApi;
|
||||
|
||||
#[derive(Tags)]
|
||||
enum ApiCategory {
|
||||
Auth,
|
||||
Health,
|
||||
Version,
|
||||
}
|
||||
|
||||
pub(crate) struct Api;
|
||||
pub struct Api;
|
||||
|
||||
#[OpenApi]
|
||||
impl Api {}
|
||||
|
@ -25,7 +25,7 @@ enum VersionResponse {
|
||||
|
||||
pub struct VersionApi;
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")]
|
||||
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
|
||||
impl VersionApi {
|
||||
#[oai(path = "/", method = "get")]
|
||||
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
|
||||
|
@ -13,16 +13,8 @@ pub struct Settings {
|
||||
impl Settings {
|
||||
#[must_use]
|
||||
pub fn web_address(&self) -> String {
|
||||
if self.debug {
|
||||
format!(
|
||||
"{}:{}",
|
||||
self.application.base_url.clone(),
|
||||
self.application.port
|
||||
)
|
||||
} else {
|
||||
self.application.base_url.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Multipurpose function that helps detect the current
|
||||
/// environment the application is running in using the
|
||||
@ -66,7 +58,7 @@ impl Settings {
|
||||
))
|
||||
.add_source(
|
||||
config::Environment::with_prefix("APP")
|
||||
.prefix_separator("_")
|
||||
.prefix_separator("__")
|
||||
.separator("__"),
|
||||
)
|
||||
.build()?;
|
||||
@ -167,8 +159,8 @@ pub struct EmailSettings {
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||
pub struct Discord {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1,8 +1,11 @@
|
||||
use poem::middleware::Cors;
|
||||
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
|
||||
use poem::middleware::{CookieJarManagerEndpoint, Cors};
|
||||
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
|
||||
use poem::{EndpointExt, Route};
|
||||
use poem_openapi::OpenApiService;
|
||||
|
||||
use crate::oauth::DiscordOauthProvider;
|
||||
use crate::route::AuthApi;
|
||||
use crate::{
|
||||
route::{Api, HealthApi, VersionApi},
|
||||
settings::Settings,
|
||||
@ -22,14 +25,23 @@ pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgr
|
||||
}
|
||||
|
||||
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
|
||||
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>;
|
||||
pub type App = AddDataEndpoint<
|
||||
AddDataEndpoint<
|
||||
AddDataEndpoint<
|
||||
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
|
||||
DiscordOauthProvider,
|
||||
>,
|
||||
sqlx::Pool<sqlx::Postgres>,
|
||||
>,
|
||||
Settings,
|
||||
>;
|
||||
|
||||
pub struct Application {
|
||||
server: Server,
|
||||
app: poem::Route,
|
||||
port: u16,
|
||||
database: sqlx::postgres::PgPool,
|
||||
settings: Settings,
|
||||
pub settings: Settings,
|
||||
}
|
||||
|
||||
pub struct RunnableApplication {
|
||||
@ -61,6 +73,8 @@ impl From<Application> for RunnableApplication {
|
||||
let app = val
|
||||
.app
|
||||
.with(Cors::new())
|
||||
.with(CookieSession::new(CookieConfig::default().secure(true)))
|
||||
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
|
||||
.data(val.database)
|
||||
.data(val.settings);
|
||||
let server = val.server;
|
||||
@ -83,7 +97,7 @@ impl Application {
|
||||
|
||||
fn setup_app(settings: &Settings) -> poem::Route {
|
||||
let api_service = OpenApiService::new(
|
||||
(Api, HealthApi, VersionApi),
|
||||
(Api, AuthApi, HealthApi, VersionApi),
|
||||
settings.application.clone().name,
|
||||
settings.application.clone().version,
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user