generated from phundrak/rust-poem-openapi-template
Compare commits
4 Commits
develop
...
feature/au
Author | SHA1 | Date | |
---|---|---|---|
9489f78224 | |||
ae10711e41 | |||
ff90b1959f | |||
1125bc4a38 |
@ -1,13 +1,12 @@
|
||||
;;; Directory Local Variables -*- no-byte-compile: t -*-
|
||||
;;; For more information see (info "(emacs) Directory Variables")
|
||||
|
||||
((sql-mode
|
||||
.
|
||||
((eval . (progn
|
||||
((rustic-mode . ((fill-column . 80)))
|
||||
(sql-mode . ((eval . (progn
|
||||
(setq-local lsp-sqls-connections
|
||||
`(((driver . "postgresql")
|
||||
(dataSourceName .
|
||||
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
||||
(dataSourceName \,
|
||||
(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
||||
(getenv "DB_HOST")
|
||||
(getenv "DB_PORT")
|
||||
(getenv "DB_USER")
|
||||
|
4021
Cargo.lock
generated
Normal file
4021
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
47
Cargo.toml
47
Cargo.toml
@ -17,45 +17,32 @@ name = "gege-jdr-backend"
|
||||
|
||||
[dependencies]
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
config = { version = "0.14.0", features = ["yaml"] }
|
||||
config = { version = "0.14.1", features = ["yaml"] }
|
||||
dotenvy = "0.15.7"
|
||||
serde = "1.0.204"
|
||||
serde_json = "1.0.120"
|
||||
thiserror = "1.0.63"
|
||||
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
|
||||
oauth2 = "4.4.2"
|
||||
quote = "1.0.37"
|
||||
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
|
||||
serde = "1.0.215"
|
||||
serde_json = "1.0.133"
|
||||
thiserror = "1.0.69"
|
||||
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
|
||||
tracing = "0.1.40"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
||||
|
||||
[dependencies.lettre]
|
||||
version = "0.11.7"
|
||||
default-features = false
|
||||
features = [
|
||||
"builder",
|
||||
"hostname",
|
||||
"pool",
|
||||
"rustls-tls",
|
||||
"tokio1",
|
||||
"tokio1-rustls-tls",
|
||||
"smtp-transport"
|
||||
]
|
||||
|
||||
uuid = { version = "1.11.0", features = ["v4", "serde"] }
|
||||
|
||||
[dependencies.poem]
|
||||
version = "3.0.4"
|
||||
version = "3.1.3"
|
||||
default-features = false
|
||||
features = [
|
||||
"csrf",
|
||||
"rustls",
|
||||
"cookie",
|
||||
"test"
|
||||
]
|
||||
features = ["csrf", "rustls", "cookie", "test", "session"]
|
||||
|
||||
[dependencies.poem-openapi]
|
||||
version = "5.0.3"
|
||||
features = ["chrono", "swagger-ui", "uuid"]
|
||||
version = "5.1.2"
|
||||
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
|
||||
|
||||
[dependencies.sqlx]
|
||||
version = "0.8.0"
|
||||
version = "0.8.2"
|
||||
default-features = false
|
||||
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
|
||||
|
||||
[lints.rust]
|
||||
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }
|
||||
|
@ -31,25 +31,6 @@ services:
|
||||
depends_on:
|
||||
- db
|
||||
|
||||
# If you run GegeJdrBackend in production, DO NOT use mailpit.
|
||||
# This tool is for testing only. Instead, you should use a real SMTP
|
||||
# provider, such as Mailgun, Mailwhale, or Postal.
|
||||
mailpit:
|
||||
image: axllent/mailpit:latest
|
||||
restart: unless-stopped
|
||||
container_name: gege-jdr-backend-mailpit
|
||||
ports:
|
||||
- 127.0.0.1:8025:8025 # WebUI
|
||||
- 127.0.0.1:1025:1025 # SMTP
|
||||
volumes:
|
||||
- gege_jdr_backend_mailpit:/data
|
||||
environment:
|
||||
MP_MAX_MESSAGES: 5000
|
||||
MP_DATABASE: /data/mailpit.db
|
||||
MP_SMTP_AUTH_ACCEPT_ANY: 1
|
||||
MP_SMTP_AUTH_ALLOW_INSECURE: 1
|
||||
|
||||
volumes:
|
||||
gege_jdr_backend_db_data:
|
||||
gege_jdr_backend_pgadmin_data:
|
||||
gege_jdr_backend_mailpit:
|
||||
|
24
flake.lock
24
flake.lock
@ -5,11 +5,11 @@
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1710146030,
|
||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@ -20,11 +20,11 @@
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1723175592,
|
||||
"narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=",
|
||||
"lastModified": 1732014248,
|
||||
"narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "5e0ca22929f3342b19569b21b2f3462f053e497b",
|
||||
"rev": "23e89b7da85c3640bbc2173fe04f4bd114342367",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@ -36,11 +36,11 @@
|
||||
},
|
||||
"nixpkgs_2": {
|
||||
"locked": {
|
||||
"lastModified": 1718428119,
|
||||
"narHash": "sha256-WdWDpNaq6u1IPtxtYHHWpl5BmabtpmLnMAx0RdJ/vo8=",
|
||||
"lastModified": 1728538411,
|
||||
"narHash": "sha256-f0SBJz1eZ2yOuKUr5CA9BHULGXVSn6miBuUWdTyhUhU=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "e6cea36f83499eb4e9cd184c8a8e823296b50ad5",
|
||||
"rev": "b69de56fac8c2b6f8fd27f2eca01dcda8e0a4221",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@ -62,11 +62,11 @@
|
||||
"nixpkgs": "nixpkgs_2"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1723256423,
|
||||
"narHash": "sha256-9iDTrfVM+mbcad31a47oqW8t8tfSA4C/si6F8F2DO/w=",
|
||||
"lastModified": 1732242723,
|
||||
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=",
|
||||
"owner": "oxalica",
|
||||
"repo": "rust-overlay",
|
||||
"rev": "615cfd85b4d9c51811a8d875374268fab5bd4089",
|
||||
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
5
migrations/20240809173617_users.down.sql
Normal file
5
migrations/20240809173617_users.down.sql
Normal file
@ -0,0 +1,5 @@
|
||||
-- Add down migration script here
|
||||
ALTER TABLE IF EXISTS public.sessions DROP CONSTRAINT IF EXISTS sessions_user_id_users_fk;
|
||||
DROP TABLE IF EXISTS public.sessions;
|
||||
DROP TABLE IF EXISTS public.users;
|
||||
DROP EXTENSION IF EXISTS "uuid-ossp";
|
29
migrations/20240809173617_users.up.sql
Normal file
29
migrations/20240809173617_users.up.sql
Normal file
@ -0,0 +1,29 @@
|
||||
-- Add up migration script here
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||
|
||||
CREATE TABLE IF NOT EXISTS public.users
|
||||
(
|
||||
id uuid NOT NULL DEFAULT uuid_generate_v4(),
|
||||
email character varying(255) NOT NULL,
|
||||
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT users_email_unique UNIQUE (email)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS public.sessions
|
||||
(
|
||||
id uuid NOT NULL DEFAULT uuid_generate_v4(),
|
||||
user_id uuid NOT NULL,
|
||||
session_id character varying NOT NULL,
|
||||
expires_at timestamp with time zone NOT NULL,
|
||||
PRIMARY KEY (id),
|
||||
CONSTRAINT sessions_user_id_unique UNIQUE (user_id)
|
||||
);
|
||||
|
||||
ALTER TABLE IF EXISTS public.sessions
|
||||
ADD CONSTRAINT sessions_user_id_users_fk FOREIGN KEY (user_id)
|
||||
REFERENCES public.users (id) MATCH SIMPLE
|
||||
ON UPDATE CASCADE
|
||||
ON DELETE CASCADE
|
||||
NOT VALID;
|
@ -1,4 +1,4 @@
|
||||
[toolchain]
|
||||
channel = "1.78.0"
|
||||
channel = "1.81.0"
|
||||
components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ]
|
||||
profile = "default"
|
||||
|
@ -16,3 +16,7 @@ email:
|
||||
user: user@gege-jdr-backend.example
|
||||
from: GegeJdrBackend <noreply@gege-jdr-backend.example>
|
||||
password: hunter2
|
||||
|
||||
discord:
|
||||
client_id: changeme
|
||||
client_secret: changeme
|
||||
|
@ -3,5 +3,5 @@ debug: true
|
||||
|
||||
application:
|
||||
protocol: http
|
||||
host: 127.0.0.1
|
||||
base_url: http://127.0.0.1:3000
|
||||
host: localhost
|
||||
base_url: http://localhost:3000
|
||||
|
19
src/errors.rs
Normal file
19
src/errors.rs
Normal file
@ -0,0 +1,19 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
#[error("SQL error: {0}")]
|
||||
Sql(#[from] sqlx::Error),
|
||||
#[error("HTTP request error: {0}")]
|
||||
Request(#[from] reqwest::Error),
|
||||
#[error("OAuth token error: {0}")]
|
||||
TokenError(String),
|
||||
#[error("Unauthorized")]
|
||||
Unauthorized,
|
||||
#[error("Attempted to get a value, none found")]
|
||||
OptionError,
|
||||
#[error("Failed to parse a number as an integer")]
|
||||
ParseIntError(#[from] std::num::TryFromIntError),
|
||||
#[error("Encountered an error trying to convert an infaillible value")]
|
||||
FromRequestPartsError(#[from] std::convert::Infallible),
|
||||
}
|
15
src/lib.rs
15
src/lib.rs
@ -5,10 +5,13 @@
|
||||
#![allow(clippy::unused_async)]
|
||||
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
|
||||
|
||||
pub mod route;
|
||||
pub mod settings;
|
||||
pub mod startup;
|
||||
pub mod telemetry;
|
||||
mod errors;
|
||||
mod models;
|
||||
mod oauth;
|
||||
mod route;
|
||||
mod settings;
|
||||
mod startup;
|
||||
mod telemetry;
|
||||
|
||||
type MaybeListener = Option<poem::listener::TcpListener<String>>;
|
||||
|
||||
@ -29,8 +32,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
|
||||
tracing::event!(
|
||||
target: "gege-jdr-backend",
|
||||
tracing::Level::INFO,
|
||||
"Listening on http://127.0.0.1:{}/",
|
||||
application.port()
|
||||
"Listening on {}",
|
||||
application.settings.web_address()
|
||||
);
|
||||
application
|
||||
}
|
||||
|
18
src/models/accounts.rs
Normal file
18
src/models/accounts.rs
Normal file
@ -0,0 +1,18 @@
|
||||
type Timestampz = chrono::DateTime<chrono::Utc>;
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
|
||||
pub struct User {
|
||||
pub id: uuid::Uuid,
|
||||
pub email: String,
|
||||
pub created_at: Timestampz,
|
||||
pub last_updated: Timestampz,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
|
||||
pub struct Session {
|
||||
pub id: uuid::Uuid,
|
||||
pub user_id: uuid::Uuid,
|
||||
#[allow(clippy::struct_field_names)]
|
||||
pub session_id: String,
|
||||
pub expires_at: Timestampz,
|
||||
}
|
1
src/models/mod.rs
Normal file
1
src/models/mod.rs
Normal file
@ -0,0 +1 @@
|
||||
pub mod accounts;
|
62
src/oauth/discord.rs
Normal file
62
src/oauth/discord.rs
Normal file
@ -0,0 +1,62 @@
|
||||
use oauth2::{
|
||||
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
|
||||
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
|
||||
};
|
||||
use reqwest::Url;
|
||||
|
||||
use crate::{errors::ApiError, settings::Settings};
|
||||
|
||||
use super::OauthProvider;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DiscordOauthProvider {
|
||||
client: BasicClient,
|
||||
}
|
||||
|
||||
impl DiscordOauthProvider {
|
||||
pub fn new(settings: &Settings) -> Self {
|
||||
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
|
||||
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
|
||||
.expect("Invalid authorization endpoint URL");
|
||||
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
|
||||
.expect("Invalid token endpoint URL");
|
||||
let revocation_url =
|
||||
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
|
||||
.expect("Invalid revocation URL");
|
||||
let client = BasicClient::new(
|
||||
ClientId::new(settings.discord.client_id.clone()),
|
||||
Some(ClientSecret::new(settings.discord.client_secret.clone())),
|
||||
auth_url,
|
||||
Some(token_url),
|
||||
)
|
||||
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
|
||||
.set_revocation_uri(revocation_url);
|
||||
Self { client }
|
||||
}
|
||||
}
|
||||
|
||||
impl OauthProvider for DiscordOauthProvider {
|
||||
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
|
||||
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||
let (auth_url, csrf_token) = self
|
||||
.client
|
||||
.authorize_url(CsrfToken::new_random)
|
||||
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
|
||||
.set_pkce_challenge(pkce_challenge)
|
||||
.url();
|
||||
(auth_url, csrf_token, pkce_verifier)
|
||||
}
|
||||
|
||||
async fn token(
|
||||
&self,
|
||||
code: String,
|
||||
verifier: PkceCodeVerifier,
|
||||
) -> Result<super::Token, ApiError> {
|
||||
self.client
|
||||
.exchange_code(AuthorizationCode::new(code))
|
||||
.set_pkce_verifier(verifier)
|
||||
.request_async(oauth2::reqwest::async_http_client)
|
||||
.await
|
||||
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
|
||||
}
|
||||
}
|
17
src/oauth/mod.rs
Normal file
17
src/oauth/mod.rs
Normal file
@ -0,0 +1,17 @@
|
||||
mod discord;
|
||||
pub use discord::DiscordOauthProvider;
|
||||
|
||||
use oauth2::{
|
||||
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
|
||||
StandardTokenResponse,
|
||||
};
|
||||
use reqwest::Url;
|
||||
|
||||
use crate::errors::ApiError;
|
||||
|
||||
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
|
||||
|
||||
pub trait OauthProvider {
|
||||
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
|
||||
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
|
||||
}
|
125
src/route/auth.rs
Normal file
125
src/route/auth.rs
Normal file
@ -0,0 +1,125 @@
|
||||
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
|
||||
use poem::web::Data;
|
||||
use poem::{session::Session, web::Form};
|
||||
use poem_openapi::payload::{Json, PlainText};
|
||||
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||
|
||||
use crate::oauth::{DiscordOauthProvider, OauthProvider};
|
||||
|
||||
use super::errors::ErrorResponse;
|
||||
use super::ApiCategory;
|
||||
|
||||
pub struct AuthApi;
|
||||
|
||||
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
|
||||
struct DiscordCallbackRequest {
|
||||
code: String,
|
||||
state: String,
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum LoginStatusResponse {
|
||||
#[oai(status = 201)]
|
||||
LoggedIn,
|
||||
#[oai(status = 201)]
|
||||
LoggedOut,
|
||||
#[oai(status = 301)]
|
||||
LoginRedirect(
|
||||
#[oai(header = "Location")] String,
|
||||
#[oai(header = "Cache-Control")] String,
|
||||
),
|
||||
#[oai(status = 500)]
|
||||
TokenError(Json<ErrorResponse>),
|
||||
}
|
||||
|
||||
#[derive(ApiResponse)]
|
||||
enum CsrfResponse {
|
||||
#[oai(status = 201)]
|
||||
Token(PlainText<String>),
|
||||
}
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
|
||||
impl AuthApi {
|
||||
// TODO: implement the following endpoints:
|
||||
// - /signin
|
||||
// - /signout
|
||||
// - /session
|
||||
// - /providers
|
||||
// See https://next-auth.js.org/getting-started/rest-api
|
||||
|
||||
#[oai(path = "/signin/discord", method = "get")]
|
||||
async fn signin_discord(
|
||||
&self,
|
||||
oauth: Data<&DiscordOauthProvider>,
|
||||
session: &Session,
|
||||
) -> LoginStatusResponse {
|
||||
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
|
||||
session.set("csrf", csrf_token);
|
||||
session.set("pkce", pkce_verifier);
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::INFO,
|
||||
"Redirect URL: {}",
|
||||
auth_url
|
||||
);
|
||||
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
|
||||
}
|
||||
|
||||
#[oai(path = "/callback/discord", method = "get")]
|
||||
async fn callback_discord(
|
||||
&self,
|
||||
Form(auth_request): Form<DiscordCallbackRequest>,
|
||||
oauth: Data<&DiscordOauthProvider>,
|
||||
session: &Session,
|
||||
) -> LoginStatusResponse {
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::DEBUG,
|
||||
"Discord replied with: {:?}",
|
||||
auth_request
|
||||
);
|
||||
let csrf_token = session
|
||||
.get::<CsrfToken>("csrf")
|
||||
.expect("Failed to retrieve Csrf token from session");
|
||||
if *csrf_token.secret().to_string() != auth_request.state {
|
||||
return LoginStatusResponse::TokenError(
|
||||
ErrorResponse {
|
||||
code: 500,
|
||||
message: "OAuth token error".into(),
|
||||
details: Some(
|
||||
"OAuth provider did not send a message that matches what was expected"
|
||||
.into(),
|
||||
),
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
let pkce_verifier = session
|
||||
.get::<PkceCodeVerifier>("pkce")
|
||||
.expect("Failed to retrieve pkce verifier from session");
|
||||
let token = oauth.token(auth_request.code, pkce_verifier).await;
|
||||
if let Err(e) = token {
|
||||
return LoginStatusResponse::TokenError(Json(e.into()));
|
||||
}
|
||||
let token = token.unwrap();
|
||||
tracing::event!(
|
||||
target: "auth-discord",
|
||||
tracing::Level::DEBUG,
|
||||
"Token: {:?}",
|
||||
token
|
||||
);
|
||||
session.set("token", token);
|
||||
LoginStatusResponse::LoggedIn
|
||||
}
|
||||
|
||||
#[oai(path = "/csrf", method = "get")]
|
||||
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
|
||||
CsrfResponse::Token(PlainText(token.0.clone()))
|
||||
}
|
||||
|
||||
#[oai(path = "/signout", method = "post")]
|
||||
async fn signout(&self, session: &Session) -> LoginStatusResponse {
|
||||
session.remove("token");
|
||||
LoginStatusResponse::LoggedOut
|
||||
}
|
||||
}
|
88
src/route/errors.rs
Normal file
88
src/route/errors.rs
Normal file
@ -0,0 +1,88 @@
|
||||
use poem::error::ResponseError;
|
||||
use poem_openapi::Object;
|
||||
|
||||
use crate::errors::ApiError;
|
||||
|
||||
#[derive(Debug, serde::Serialize, Default, Object)]
|
||||
pub struct ErrorResponse {
|
||||
pub code: u16,
|
||||
pub message: String,
|
||||
pub details: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ApiError> for ErrorResponse {
|
||||
fn from(value: ApiError) -> Self {
|
||||
match value {
|
||||
ApiError::Sql(e) => Self {
|
||||
code: 500,
|
||||
message: "SQL error".into(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
ApiError::Request(e) => Self {
|
||||
code: 500,
|
||||
message: "HTTP request error".into(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
ApiError::TokenError(e) => Self {
|
||||
code: 500,
|
||||
message: "OAuth token error".into(),
|
||||
details: Some(e),
|
||||
},
|
||||
ApiError::Unauthorized => Self {
|
||||
code: 401,
|
||||
message: "Unauthorized!".into(),
|
||||
..Default::default()
|
||||
},
|
||||
ApiError::OptionError => Self {
|
||||
code: 500,
|
||||
message: "Attempted to get a value, but none found".into(),
|
||||
..Default::default()
|
||||
},
|
||||
ApiError::ParseIntError(e) => Self {
|
||||
code: 500,
|
||||
message: "Failed to parse a number as an integer".into(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
ApiError::FromRequestPartsError(e) => Self {
|
||||
code: 500,
|
||||
message: "Encountered an error trying to convert an infaillible value".to_string(),
|
||||
details: Some(e.to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponseError for ApiError {
|
||||
fn status(&self) -> reqwest::StatusCode {
|
||||
match self {
|
||||
Self::FromRequestPartsError(_)
|
||||
| Self::ParseIntError(_)
|
||||
| Self::OptionError
|
||||
| Self::Sql(_)
|
||||
| Self::Request(_)
|
||||
| Self::TokenError(_) => reqwest::StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Self::Unauthorized => reqwest::StatusCode::UNAUTHORIZED,
|
||||
}
|
||||
}
|
||||
|
||||
fn as_response(&self) -> poem::Response
|
||||
where
|
||||
Self: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
match self {
|
||||
Self::Sql(_) => todo!(),
|
||||
Self::Request(_) => todo!(),
|
||||
Self::TokenError(_) => todo!(),
|
||||
Self::Unauthorized => todo!(),
|
||||
Self::OptionError => todo!(),
|
||||
Self::ParseIntError(_) => todo!(),
|
||||
Self::FromRequestPartsError(_) => todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ErrorResponse> for poem_openapi::payload::Json<ErrorResponse> {
|
||||
fn from(value: ErrorResponse) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
@ -10,7 +10,7 @@ enum HealthResponse {
|
||||
|
||||
pub struct HealthApi;
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")]
|
||||
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
|
||||
impl HealthApi {
|
||||
#[oai(path = "/", method = "get")]
|
||||
async fn health_check(&self) -> HealthResponse {
|
||||
|
@ -6,13 +6,19 @@ pub use health::HealthApi;
|
||||
mod version;
|
||||
pub use version::VersionApi;
|
||||
|
||||
mod errors;
|
||||
|
||||
mod auth;
|
||||
pub use auth::AuthApi;
|
||||
|
||||
#[derive(Tags)]
|
||||
enum ApiCategory {
|
||||
Auth,
|
||||
Health,
|
||||
Version,
|
||||
}
|
||||
|
||||
pub(crate) struct Api;
|
||||
pub struct Api;
|
||||
|
||||
#[OpenApi]
|
||||
impl Api {}
|
||||
|
@ -25,7 +25,7 @@ enum VersionResponse {
|
||||
|
||||
pub struct VersionApi;
|
||||
|
||||
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")]
|
||||
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
|
||||
impl VersionApi {
|
||||
#[oai(path = "/", method = "get")]
|
||||
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
|
||||
|
@ -4,6 +4,7 @@ use sqlx::ConnectOptions;
|
||||
pub struct Settings {
|
||||
pub application: ApplicationSettings,
|
||||
pub database: Database,
|
||||
pub discord: Discord,
|
||||
pub debug: bool,
|
||||
pub email: EmailSettings,
|
||||
pub frontend_url: String,
|
||||
@ -12,16 +13,8 @@ pub struct Settings {
|
||||
impl Settings {
|
||||
#[must_use]
|
||||
pub fn web_address(&self) -> String {
|
||||
if self.debug {
|
||||
format!(
|
||||
"{}:{}",
|
||||
self.application.base_url.clone(),
|
||||
self.application.port
|
||||
)
|
||||
} else {
|
||||
self.application.base_url.clone()
|
||||
}
|
||||
}
|
||||
|
||||
/// Multipurpose function that helps detect the current
|
||||
/// environment the application is running in using the
|
||||
@ -65,7 +58,7 @@ impl Settings {
|
||||
))
|
||||
.add_source(
|
||||
config::Environment::with_prefix("APP")
|
||||
.prefix_separator("_")
|
||||
.prefix_separator("__")
|
||||
.separator("__"),
|
||||
)
|
||||
.build()?;
|
||||
@ -164,6 +157,12 @@ pub struct EmailSettings {
|
||||
pub from: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||
pub struct Discord {
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
@ -1,8 +1,11 @@
|
||||
use poem::middleware::Cors;
|
||||
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
|
||||
use poem::middleware::{CookieJarManagerEndpoint, Cors};
|
||||
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
|
||||
use poem::{EndpointExt, Route};
|
||||
use poem_openapi::OpenApiService;
|
||||
|
||||
use crate::oauth::DiscordOauthProvider;
|
||||
use crate::route::AuthApi;
|
||||
use crate::{
|
||||
route::{Api, HealthApi, VersionApi},
|
||||
settings::Settings,
|
||||
@ -22,14 +25,23 @@ pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgr
|
||||
}
|
||||
|
||||
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
|
||||
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>;
|
||||
pub type App = AddDataEndpoint<
|
||||
AddDataEndpoint<
|
||||
AddDataEndpoint<
|
||||
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
|
||||
DiscordOauthProvider,
|
||||
>,
|
||||
sqlx::Pool<sqlx::Postgres>,
|
||||
>,
|
||||
Settings,
|
||||
>;
|
||||
|
||||
pub struct Application {
|
||||
server: Server,
|
||||
app: poem::Route,
|
||||
port: u16,
|
||||
database: sqlx::postgres::PgPool,
|
||||
settings: Settings,
|
||||
pub settings: Settings,
|
||||
}
|
||||
|
||||
pub struct RunnableApplication {
|
||||
@ -61,6 +73,8 @@ impl From<Application> for RunnableApplication {
|
||||
let app = val
|
||||
.app
|
||||
.with(Cors::new())
|
||||
.with(CookieSession::new(CookieConfig::default().secure(true)))
|
||||
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
|
||||
.data(val.database)
|
||||
.data(val.settings);
|
||||
let server = val.server;
|
||||
@ -83,7 +97,7 @@ impl Application {
|
||||
|
||||
fn setup_app(settings: &Settings) -> poem::Route {
|
||||
let api_service = OpenApiService::new(
|
||||
(Api, HealthApi, VersionApi),
|
||||
(Api, AuthApi, HealthApi, VersionApi),
|
||||
settings.application.clone().name,
|
||||
settings.application.clone().version,
|
||||
);
|
||||
|
Loading…
Reference in New Issue
Block a user