Compare commits

..

No commits in common. "feature/authentication" and "develop" have entirely different histories.

23 changed files with 99 additions and 4477 deletions

View File

@ -1,14 +1,15 @@
;;; Directory Local Variables -*- no-byte-compile: t -*- ;;; Directory Local Variables -*- no-byte-compile: t -*-
;;; For more information see (info "(emacs) Directory Variables") ;;; For more information see (info "(emacs) Directory Variables")
((rustic-mode . ((fill-column . 80))) ((sql-mode
(sql-mode . ((eval . (progn .
(setq-local lsp-sqls-connections ((eval . (progn
`(((driver . "postgresql") (setq-local lsp-sqls-connections
(dataSourceName \, `(((driver . "postgresql")
(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable" (dataSourceName .
(getenv "DB_HOST") ,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
(getenv "DB_PORT") (getenv "DB_HOST")
(getenv "DB_USER") (getenv "DB_PORT")
(getenv "DB_PASSWORD") (getenv "DB_USER")
(getenv "DB_NAME"))))))))))) (getenv "DB_PASSWORD")
(getenv "DB_NAME")))))))))))

4021
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -17,32 +17,45 @@ name = "gege-jdr-backend"
[dependencies] [dependencies]
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }
config = { version = "0.14.1", features = ["yaml"] } config = { version = "0.14.0", features = ["yaml"] }
dotenvy = "0.15.7" dotenvy = "0.15.7"
oauth2 = "4.4.2" serde = "1.0.204"
quote = "1.0.37" serde_json = "1.0.120"
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] } thiserror = "1.0.63"
serde = "1.0.215" tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
serde_json = "1.0.133"
thiserror = "1.0.69"
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] } tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
uuid = { version = "1.11.0", features = ["v4", "serde"] } uuid = { version = "1.10.0", features = ["v4", "serde"] }
[dependencies.lettre]
version = "0.11.7"
default-features = false
features = [
"builder",
"hostname",
"pool",
"rustls-tls",
"tokio1",
"tokio1-rustls-tls",
"smtp-transport"
]
[dependencies.poem] [dependencies.poem]
version = "3.1.3" version = "3.0.4"
default-features = false default-features = false
features = ["csrf", "rustls", "cookie", "test", "session"] features = [
"csrf",
"rustls",
"cookie",
"test"
]
[dependencies.poem-openapi] [dependencies.poem-openapi]
version = "5.1.2" version = "5.0.3"
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"] features = ["chrono", "swagger-ui", "uuid"]
[dependencies.sqlx] [dependencies.sqlx]
version = "0.8.2" version = "0.8.0"
default-features = false default-features = false
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"] features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
[lints.rust]
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }

View File

@ -31,6 +31,25 @@ services:
depends_on: depends_on:
- db - db
# If you run GegeJdrBackend in production, DO NOT use mailpit.
# This tool is for testing only. Instead, you should use a real SMTP
# provider, such as Mailgun, Mailwhale, or Postal.
mailpit:
image: axllent/mailpit:latest
restart: unless-stopped
container_name: gege-jdr-backend-mailpit
ports:
- 127.0.0.1:8025:8025 # WebUI
- 127.0.0.1:1025:1025 # SMTP
volumes:
- gege_jdr_backend_mailpit:/data
environment:
MP_MAX_MESSAGES: 5000
MP_DATABASE: /data/mailpit.db
MP_SMTP_AUTH_ACCEPT_ANY: 1
MP_SMTP_AUTH_ALLOW_INSECURE: 1
volumes: volumes:
gege_jdr_backend_db_data: gege_jdr_backend_db_data:
gege_jdr_backend_pgadmin_data: gege_jdr_backend_pgadmin_data:
gege_jdr_backend_mailpit:

View File

@ -5,11 +5,11 @@
"systems": "systems" "systems": "systems"
}, },
"locked": { "locked": {
"lastModified": 1731533236, "lastModified": 1710146030,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
"owner": "numtide", "owner": "numtide",
"repo": "flake-utils", "repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -20,11 +20,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1732014248, "lastModified": 1723175592,
"narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=", "narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=",
"owner": "nixos", "owner": "nixos",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "23e89b7da85c3640bbc2173fe04f4bd114342367", "rev": "5e0ca22929f3342b19569b21b2f3462f053e497b",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -36,11 +36,11 @@
}, },
"nixpkgs_2": { "nixpkgs_2": {
"locked": { "locked": {
"lastModified": 1728538411, "lastModified": 1718428119,
"narHash": "sha256-f0SBJz1eZ2yOuKUr5CA9BHULGXVSn6miBuUWdTyhUhU=", "narHash": "sha256-WdWDpNaq6u1IPtxtYHHWpl5BmabtpmLnMAx0RdJ/vo8=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "b69de56fac8c2b6f8fd27f2eca01dcda8e0a4221", "rev": "e6cea36f83499eb4e9cd184c8a8e823296b50ad5",
"type": "github" "type": "github"
}, },
"original": { "original": {
@ -62,11 +62,11 @@
"nixpkgs": "nixpkgs_2" "nixpkgs": "nixpkgs_2"
}, },
"locked": { "locked": {
"lastModified": 1732242723, "lastModified": 1723256423,
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=", "narHash": "sha256-9iDTrfVM+mbcad31a47oqW8t8tfSA4C/si6F8F2DO/w=",
"owner": "oxalica", "owner": "oxalica",
"repo": "rust-overlay", "repo": "rust-overlay",
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a", "rev": "615cfd85b4d9c51811a8d875374268fab5bd4089",
"type": "github" "type": "github"
}, },
"original": { "original": {

View File

@ -1,5 +0,0 @@
-- Add down migration script here
ALTER TABLE IF EXISTS public.sessions DROP CONSTRAINT IF EXISTS sessions_user_id_users_fk;
DROP TABLE IF EXISTS public.sessions;
DROP TABLE IF EXISTS public.users;
DROP EXTENSION IF EXISTS "uuid-ossp";

View File

@ -1,29 +0,0 @@
-- Add up migration script here
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
CREATE TABLE IF NOT EXISTS public.users
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
email character varying(255) NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
PRIMARY KEY (id),
CONSTRAINT users_email_unique UNIQUE (email)
);
CREATE TABLE IF NOT EXISTS public.sessions
(
id uuid NOT NULL DEFAULT uuid_generate_v4(),
user_id uuid NOT NULL,
session_id character varying NOT NULL,
expires_at timestamp with time zone NOT NULL,
PRIMARY KEY (id),
CONSTRAINT sessions_user_id_unique UNIQUE (user_id)
);
ALTER TABLE IF EXISTS public.sessions
ADD CONSTRAINT sessions_user_id_users_fk FOREIGN KEY (user_id)
REFERENCES public.users (id) MATCH SIMPLE
ON UPDATE CASCADE
ON DELETE CASCADE
NOT VALID;

View File

@ -1,4 +1,4 @@
[toolchain] [toolchain]
channel = "1.81.0" channel = "1.78.0"
components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ] components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ]
profile = "default" profile = "default"

View File

@ -16,7 +16,3 @@ email:
user: user@gege-jdr-backend.example user: user@gege-jdr-backend.example
from: GegeJdrBackend <noreply@gege-jdr-backend.example> from: GegeJdrBackend <noreply@gege-jdr-backend.example>
password: hunter2 password: hunter2
discord:
client_id: changeme
client_secret: changeme

View File

@ -3,5 +3,5 @@ debug: true
application: application:
protocol: http protocol: http
host: localhost host: 127.0.0.1
base_url: http://localhost:3000 base_url: http://127.0.0.1:3000

View File

@ -1,19 +0,0 @@
use thiserror::Error;
#[derive(Debug, Error)]
pub enum ApiError {
#[error("SQL error: {0}")]
Sql(#[from] sqlx::Error),
#[error("HTTP request error: {0}")]
Request(#[from] reqwest::Error),
#[error("OAuth token error: {0}")]
TokenError(String),
#[error("Unauthorized")]
Unauthorized,
#[error("Attempted to get a value, none found")]
OptionError,
#[error("Failed to parse a number as an integer")]
ParseIntError(#[from] std::num::TryFromIntError),
#[error("Encountered an error trying to convert an infaillible value")]
FromRequestPartsError(#[from] std::convert::Infallible),
}

View File

@ -5,13 +5,10 @@
#![allow(clippy::unused_async)] #![allow(clippy::unused_async)]
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling #![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
mod errors; pub mod route;
mod models; pub mod settings;
mod oauth; pub mod startup;
mod route; pub mod telemetry;
mod settings;
mod startup;
mod telemetry;
type MaybeListener = Option<poem::listener::TcpListener<String>>; type MaybeListener = Option<poem::listener::TcpListener<String>>;
@ -32,8 +29,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
tracing::event!( tracing::event!(
target: "gege-jdr-backend", target: "gege-jdr-backend",
tracing::Level::INFO, tracing::Level::INFO,
"Listening on {}", "Listening on http://127.0.0.1:{}/",
application.settings.web_address() application.port()
); );
application application
} }

View File

@ -1,18 +0,0 @@
type Timestampz = chrono::DateTime<chrono::Utc>;
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
pub struct User {
pub id: uuid::Uuid,
pub email: String,
pub created_at: Timestampz,
pub last_updated: Timestampz,
}
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
pub struct Session {
pub id: uuid::Uuid,
pub user_id: uuid::Uuid,
#[allow(clippy::struct_field_names)]
pub session_id: String,
pub expires_at: Timestampz,
}

View File

@ -1 +0,0 @@
pub mod accounts;

View File

@ -1,62 +0,0 @@
use oauth2::{
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
};
use reqwest::Url;
use crate::{errors::ApiError, settings::Settings};
use super::OauthProvider;
#[derive(Debug, Clone)]
pub struct DiscordOauthProvider {
client: BasicClient,
}
impl DiscordOauthProvider {
pub fn new(settings: &Settings) -> Self {
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
.expect("Invalid authorization endpoint URL");
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
.expect("Invalid token endpoint URL");
let revocation_url =
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
.expect("Invalid revocation URL");
let client = BasicClient::new(
ClientId::new(settings.discord.client_id.clone()),
Some(ClientSecret::new(settings.discord.client_secret.clone())),
auth_url,
Some(token_url),
)
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
.set_revocation_uri(revocation_url);
Self { client }
}
}
impl OauthProvider for DiscordOauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
let (auth_url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
.set_pkce_challenge(pkce_challenge)
.url();
(auth_url, csrf_token, pkce_verifier)
}
async fn token(
&self,
code: String,
verifier: PkceCodeVerifier,
) -> Result<super::Token, ApiError> {
self.client
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(verifier)
.request_async(oauth2::reqwest::async_http_client)
.await
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
}
}

View File

@ -1,17 +0,0 @@
mod discord;
pub use discord::DiscordOauthProvider;
use oauth2::{
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
StandardTokenResponse,
};
use reqwest::Url;
use crate::errors::ApiError;
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
pub trait OauthProvider {
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
}

View File

@ -1,125 +0,0 @@
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
use poem::web::Data;
use poem::{session::Session, web::Form};
use poem_openapi::payload::{Json, PlainText};
use poem_openapi::{ApiResponse, Object, OpenApi};
use crate::oauth::{DiscordOauthProvider, OauthProvider};
use super::errors::ErrorResponse;
use super::ApiCategory;
pub struct AuthApi;
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
struct DiscordCallbackRequest {
code: String,
state: String,
}
#[derive(ApiResponse)]
enum LoginStatusResponse {
#[oai(status = 201)]
LoggedIn,
#[oai(status = 201)]
LoggedOut,
#[oai(status = 301)]
LoginRedirect(
#[oai(header = "Location")] String,
#[oai(header = "Cache-Control")] String,
),
#[oai(status = 500)]
TokenError(Json<ErrorResponse>),
}
#[derive(ApiResponse)]
enum CsrfResponse {
#[oai(status = 201)]
Token(PlainText<String>),
}
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
impl AuthApi {
// TODO: implement the following endpoints:
// - /signin
// - /signout
// - /session
// - /providers
// See https://next-auth.js.org/getting-started/rest-api
#[oai(path = "/signin/discord", method = "get")]
async fn signin_discord(
&self,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
session.set("csrf", csrf_token);
session.set("pkce", pkce_verifier);
tracing::event!(
target: "auth-discord",
tracing::Level::INFO,
"Redirect URL: {}",
auth_url
);
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
}
#[oai(path = "/callback/discord", method = "get")]
async fn callback_discord(
&self,
Form(auth_request): Form<DiscordCallbackRequest>,
oauth: Data<&DiscordOauthProvider>,
session: &Session,
) -> LoginStatusResponse {
tracing::event!(
target: "auth-discord",
tracing::Level::DEBUG,
"Discord replied with: {:?}",
auth_request
);
let csrf_token = session
.get::<CsrfToken>("csrf")
.expect("Failed to retrieve Csrf token from session");
if *csrf_token.secret().to_string() != auth_request.state {
return LoginStatusResponse::TokenError(
ErrorResponse {
code: 500,
message: "OAuth token error".into(),
details: Some(
"OAuth provider did not send a message that matches what was expected"
.into(),
),
}
.into(),
);
}
let pkce_verifier = session
.get::<PkceCodeVerifier>("pkce")
.expect("Failed to retrieve pkce verifier from session");
let token = oauth.token(auth_request.code, pkce_verifier).await;
if let Err(e) = token {
return LoginStatusResponse::TokenError(Json(e.into()));
}
let token = token.unwrap();
tracing::event!(
target: "auth-discord",
tracing::Level::DEBUG,
"Token: {:?}",
token
);
session.set("token", token);
LoginStatusResponse::LoggedIn
}
#[oai(path = "/csrf", method = "get")]
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
CsrfResponse::Token(PlainText(token.0.clone()))
}
#[oai(path = "/signout", method = "post")]
async fn signout(&self, session: &Session) -> LoginStatusResponse {
session.remove("token");
LoginStatusResponse::LoggedOut
}
}

View File

@ -1,88 +0,0 @@
use poem::error::ResponseError;
use poem_openapi::Object;
use crate::errors::ApiError;
#[derive(Debug, serde::Serialize, Default, Object)]
pub struct ErrorResponse {
pub code: u16,
pub message: String,
pub details: Option<String>,
}
impl From<ApiError> for ErrorResponse {
fn from(value: ApiError) -> Self {
match value {
ApiError::Sql(e) => Self {
code: 500,
message: "SQL error".into(),
details: Some(e.to_string()),
},
ApiError::Request(e) => Self {
code: 500,
message: "HTTP request error".into(),
details: Some(e.to_string()),
},
ApiError::TokenError(e) => Self {
code: 500,
message: "OAuth token error".into(),
details: Some(e),
},
ApiError::Unauthorized => Self {
code: 401,
message: "Unauthorized!".into(),
..Default::default()
},
ApiError::OptionError => Self {
code: 500,
message: "Attempted to get a value, but none found".into(),
..Default::default()
},
ApiError::ParseIntError(e) => Self {
code: 500,
message: "Failed to parse a number as an integer".into(),
details: Some(e.to_string()),
},
ApiError::FromRequestPartsError(e) => Self {
code: 500,
message: "Encountered an error trying to convert an infaillible value".to_string(),
details: Some(e.to_string()),
},
}
}
}
impl ResponseError for ApiError {
fn status(&self) -> reqwest::StatusCode {
match self {
Self::FromRequestPartsError(_)
| Self::ParseIntError(_)
| Self::OptionError
| Self::Sql(_)
| Self::Request(_)
| Self::TokenError(_) => reqwest::StatusCode::INTERNAL_SERVER_ERROR,
Self::Unauthorized => reqwest::StatusCode::UNAUTHORIZED,
}
}
fn as_response(&self) -> poem::Response
where
Self: std::error::Error + Send + Sync + 'static,
{
match self {
Self::Sql(_) => todo!(),
Self::Request(_) => todo!(),
Self::TokenError(_) => todo!(),
Self::Unauthorized => todo!(),
Self::OptionError => todo!(),
Self::ParseIntError(_) => todo!(),
Self::FromRequestPartsError(_) => todo!(),
}
}
}
impl From<ErrorResponse> for poem_openapi::payload::Json<ErrorResponse> {
fn from(value: ErrorResponse) -> Self {
Self(value)
}
}

View File

@ -10,7 +10,7 @@ enum HealthResponse {
pub struct HealthApi; pub struct HealthApi;
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")] #[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")]
impl HealthApi { impl HealthApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn health_check(&self) -> HealthResponse { async fn health_check(&self) -> HealthResponse {

View File

@ -6,19 +6,13 @@ pub use health::HealthApi;
mod version; mod version;
pub use version::VersionApi; pub use version::VersionApi;
mod errors;
mod auth;
pub use auth::AuthApi;
#[derive(Tags)] #[derive(Tags)]
enum ApiCategory { enum ApiCategory {
Auth,
Health, Health,
Version, Version,
} }
pub struct Api; pub(crate) struct Api;
#[OpenApi] #[OpenApi]
impl Api {} impl Api {}

View File

@ -25,7 +25,7 @@ enum VersionResponse {
pub struct VersionApi; pub struct VersionApi;
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")] #[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")]
impl VersionApi { impl VersionApi {
#[oai(path = "/", method = "get")] #[oai(path = "/", method = "get")]
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> { async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {

View File

@ -4,7 +4,6 @@ use sqlx::ConnectOptions;
pub struct Settings { pub struct Settings {
pub application: ApplicationSettings, pub application: ApplicationSettings,
pub database: Database, pub database: Database,
pub discord: Discord,
pub debug: bool, pub debug: bool,
pub email: EmailSettings, pub email: EmailSettings,
pub frontend_url: String, pub frontend_url: String,
@ -13,7 +12,15 @@ pub struct Settings {
impl Settings { impl Settings {
#[must_use] #[must_use]
pub fn web_address(&self) -> String { pub fn web_address(&self) -> String {
self.application.base_url.clone() if self.debug {
format!(
"{}:{}",
self.application.base_url.clone(),
self.application.port
)
} else {
self.application.base_url.clone()
}
} }
/// Multipurpose function that helps detect the current /// Multipurpose function that helps detect the current
@ -58,7 +65,7 @@ impl Settings {
)) ))
.add_source( .add_source(
config::Environment::with_prefix("APP") config::Environment::with_prefix("APP")
.prefix_separator("__") .prefix_separator("_")
.separator("__"), .separator("__"),
) )
.build()?; .build()?;
@ -157,12 +164,6 @@ pub struct EmailSettings {
pub from: String, pub from: String,
} }
#[derive(Debug, serde::Deserialize, Clone, Default)]
pub struct Discord {
pub client_id: String,
pub client_secret: String,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;

View File

@ -1,11 +1,8 @@
use poem::middleware::Cors;
use poem::middleware::{AddDataEndpoint, CorsEndpoint}; use poem::middleware::{AddDataEndpoint, CorsEndpoint};
use poem::middleware::{CookieJarManagerEndpoint, Cors};
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
use poem::{EndpointExt, Route}; use poem::{EndpointExt, Route};
use poem_openapi::OpenApiService; use poem_openapi::OpenApiService;
use crate::oauth::DiscordOauthProvider;
use crate::route::AuthApi;
use crate::{ use crate::{
route::{Api, HealthApi, VersionApi}, route::{Api, HealthApi, VersionApi},
settings::Settings, settings::Settings,
@ -25,23 +22,14 @@ pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgr
} }
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>; type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
pub type App = AddDataEndpoint< pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>;
AddDataEndpoint<
AddDataEndpoint<
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
DiscordOauthProvider,
>,
sqlx::Pool<sqlx::Postgres>,
>,
Settings,
>;
pub struct Application { pub struct Application {
server: Server, server: Server,
app: poem::Route, app: poem::Route,
port: u16, port: u16,
database: sqlx::postgres::PgPool, database: sqlx::postgres::PgPool,
pub settings: Settings, settings: Settings,
} }
pub struct RunnableApplication { pub struct RunnableApplication {
@ -73,8 +61,6 @@ impl From<Application> for RunnableApplication {
let app = val let app = val
.app .app
.with(Cors::new()) .with(Cors::new())
.with(CookieSession::new(CookieConfig::default().secure(true)))
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
.data(val.database) .data(val.database)
.data(val.settings); .data(val.settings);
let server = val.server; let server = val.server;
@ -97,7 +83,7 @@ impl Application {
fn setup_app(settings: &Settings) -> poem::Route { fn setup_app(settings: &Settings) -> poem::Route {
let api_service = OpenApiService::new( let api_service = OpenApiService::new(
(Api, AuthApi, HealthApi, VersionApi), (Api, HealthApi, VersionApi),
settings.application.clone().name, settings.application.clone().name,
settings.application.clone().version, settings.application.clone().version,
); );