generated from phundrak/rust-poem-openapi-template
Compare commits
4 Commits
develop
...
feature/au
Author | SHA1 | Date | |
---|---|---|---|
9489f78224 | |||
ae10711e41 | |||
ff90b1959f | |||
1125bc4a38 |
@ -1,13 +1,12 @@
|
|||||||
;;; Directory Local Variables -*- no-byte-compile: t -*-
|
;;; Directory Local Variables -*- no-byte-compile: t -*-
|
||||||
;;; For more information see (info "(emacs) Directory Variables")
|
;;; For more information see (info "(emacs) Directory Variables")
|
||||||
|
|
||||||
((sql-mode
|
((rustic-mode . ((fill-column . 80)))
|
||||||
.
|
(sql-mode . ((eval . (progn
|
||||||
((eval . (progn
|
|
||||||
(setq-local lsp-sqls-connections
|
(setq-local lsp-sqls-connections
|
||||||
`(((driver . "postgresql")
|
`(((driver . "postgresql")
|
||||||
(dataSourceName .
|
(dataSourceName \,
|
||||||
,(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
(format "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable"
|
||||||
(getenv "DB_HOST")
|
(getenv "DB_HOST")
|
||||||
(getenv "DB_PORT")
|
(getenv "DB_PORT")
|
||||||
(getenv "DB_USER")
|
(getenv "DB_USER")
|
||||||
|
4021
Cargo.lock
generated
Normal file
4021
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
47
Cargo.toml
47
Cargo.toml
@ -17,45 +17,32 @@ name = "gege-jdr-backend"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
chrono = { version = "0.4.38", features = ["serde"] }
|
chrono = { version = "0.4.38", features = ["serde"] }
|
||||||
config = { version = "0.14.0", features = ["yaml"] }
|
config = { version = "0.14.1", features = ["yaml"] }
|
||||||
dotenvy = "0.15.7"
|
dotenvy = "0.15.7"
|
||||||
serde = "1.0.204"
|
oauth2 = "4.4.2"
|
||||||
serde_json = "1.0.120"
|
quote = "1.0.37"
|
||||||
thiserror = "1.0.63"
|
reqwest = { version = "0.12.9", default-features = false, features = ["charset", "h2", "http2", "rustls-tls", "json"] }
|
||||||
tokio = { version = "1.39.2", features = ["macros", "rt-multi-thread"] }
|
serde = "1.0.215"
|
||||||
|
serde_json = "1.0.133"
|
||||||
|
thiserror = "1.0.69"
|
||||||
|
tokio = { version = "1.41.1", features = ["macros", "rt-multi-thread"] }
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
tracing-subscriber = { version = "0.3.18", features = ["fmt", "std", "env-filter", "registry", "json", "tracing-log"] }
|
||||||
uuid = { version = "1.10.0", features = ["v4", "serde"] }
|
uuid = { version = "1.11.0", features = ["v4", "serde"] }
|
||||||
|
|
||||||
[dependencies.lettre]
|
|
||||||
version = "0.11.7"
|
|
||||||
default-features = false
|
|
||||||
features = [
|
|
||||||
"builder",
|
|
||||||
"hostname",
|
|
||||||
"pool",
|
|
||||||
"rustls-tls",
|
|
||||||
"tokio1",
|
|
||||||
"tokio1-rustls-tls",
|
|
||||||
"smtp-transport"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
[dependencies.poem]
|
[dependencies.poem]
|
||||||
version = "3.0.4"
|
version = "3.1.3"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = [
|
features = ["csrf", "rustls", "cookie", "test", "session"]
|
||||||
"csrf",
|
|
||||||
"rustls",
|
|
||||||
"cookie",
|
|
||||||
"test"
|
|
||||||
]
|
|
||||||
|
|
||||||
[dependencies.poem-openapi]
|
[dependencies.poem-openapi]
|
||||||
version = "5.0.3"
|
version = "5.1.2"
|
||||||
features = ["chrono", "swagger-ui", "uuid"]
|
features = ["chrono", "swagger-ui", "redoc", "rapidoc", "uuid"]
|
||||||
|
|
||||||
[dependencies.sqlx]
|
[dependencies.sqlx]
|
||||||
version = "0.8.0"
|
version = "0.8.2"
|
||||||
default-features = false
|
default-features = false
|
||||||
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
|
features = ["postgres", "uuid", "chrono", "migrate", "runtime-tokio", "macros"]
|
||||||
|
|
||||||
|
[lints.rust]
|
||||||
|
unexpected_cfgs = { level = "allow", check-cfg = ['cfg(tarpaulin_include)'] }
|
||||||
|
@ -31,25 +31,6 @@ services:
|
|||||||
depends_on:
|
depends_on:
|
||||||
- db
|
- db
|
||||||
|
|
||||||
# If you run GegeJdrBackend in production, DO NOT use mailpit.
|
|
||||||
# This tool is for testing only. Instead, you should use a real SMTP
|
|
||||||
# provider, such as Mailgun, Mailwhale, or Postal.
|
|
||||||
mailpit:
|
|
||||||
image: axllent/mailpit:latest
|
|
||||||
restart: unless-stopped
|
|
||||||
container_name: gege-jdr-backend-mailpit
|
|
||||||
ports:
|
|
||||||
- 127.0.0.1:8025:8025 # WebUI
|
|
||||||
- 127.0.0.1:1025:1025 # SMTP
|
|
||||||
volumes:
|
|
||||||
- gege_jdr_backend_mailpit:/data
|
|
||||||
environment:
|
|
||||||
MP_MAX_MESSAGES: 5000
|
|
||||||
MP_DATABASE: /data/mailpit.db
|
|
||||||
MP_SMTP_AUTH_ACCEPT_ANY: 1
|
|
||||||
MP_SMTP_AUTH_ALLOW_INSECURE: 1
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
gege_jdr_backend_db_data:
|
gege_jdr_backend_db_data:
|
||||||
gege_jdr_backend_pgadmin_data:
|
gege_jdr_backend_pgadmin_data:
|
||||||
gege_jdr_backend_mailpit:
|
|
||||||
|
24
flake.lock
24
flake.lock
@ -5,11 +5,11 @@
|
|||||||
"systems": "systems"
|
"systems": "systems"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1710146030,
|
"lastModified": 1731533236,
|
||||||
"narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=",
|
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||||
"owner": "numtide",
|
"owner": "numtide",
|
||||||
"repo": "flake-utils",
|
"repo": "flake-utils",
|
||||||
"rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a",
|
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -20,11 +20,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs": {
|
"nixpkgs": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1723175592,
|
"lastModified": 1732014248,
|
||||||
"narHash": "sha256-M0xJ3FbDUc4fRZ84dPGx5VvgFsOzds77KiBMW/mMTnI=",
|
"narHash": "sha256-y/MEyuJ5oBWrWAic/14LaIr/u5E0wRVzyYsouYY3W6w=",
|
||||||
"owner": "nixos",
|
"owner": "nixos",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "5e0ca22929f3342b19569b21b2f3462f053e497b",
|
"rev": "23e89b7da85c3640bbc2173fe04f4bd114342367",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -36,11 +36,11 @@
|
|||||||
},
|
},
|
||||||
"nixpkgs_2": {
|
"nixpkgs_2": {
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1718428119,
|
"lastModified": 1728538411,
|
||||||
"narHash": "sha256-WdWDpNaq6u1IPtxtYHHWpl5BmabtpmLnMAx0RdJ/vo8=",
|
"narHash": "sha256-f0SBJz1eZ2yOuKUr5CA9BHULGXVSn6miBuUWdTyhUhU=",
|
||||||
"owner": "NixOS",
|
"owner": "NixOS",
|
||||||
"repo": "nixpkgs",
|
"repo": "nixpkgs",
|
||||||
"rev": "e6cea36f83499eb4e9cd184c8a8e823296b50ad5",
|
"rev": "b69de56fac8c2b6f8fd27f2eca01dcda8e0a4221",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
@ -62,11 +62,11 @@
|
|||||||
"nixpkgs": "nixpkgs_2"
|
"nixpkgs": "nixpkgs_2"
|
||||||
},
|
},
|
||||||
"locked": {
|
"locked": {
|
||||||
"lastModified": 1723256423,
|
"lastModified": 1732242723,
|
||||||
"narHash": "sha256-9iDTrfVM+mbcad31a47oqW8t8tfSA4C/si6F8F2DO/w=",
|
"narHash": "sha256-NWI8csIK0ujFlFuEXKnoc+7hWoCiEtINK9r48LUUMeU=",
|
||||||
"owner": "oxalica",
|
"owner": "oxalica",
|
||||||
"repo": "rust-overlay",
|
"repo": "rust-overlay",
|
||||||
"rev": "615cfd85b4d9c51811a8d875374268fab5bd4089",
|
"rev": "a229311fcb45b88a95fdfa5cecd8349c809a272a",
|
||||||
"type": "github"
|
"type": "github"
|
||||||
},
|
},
|
||||||
"original": {
|
"original": {
|
||||||
|
5
migrations/20240809173617_users.down.sql
Normal file
5
migrations/20240809173617_users.down.sql
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
-- Add down migration script here
|
||||||
|
ALTER TABLE IF EXISTS public.sessions DROP CONSTRAINT IF EXISTS sessions_user_id_users_fk;
|
||||||
|
DROP TABLE IF EXISTS public.sessions;
|
||||||
|
DROP TABLE IF EXISTS public.users;
|
||||||
|
DROP EXTENSION IF EXISTS "uuid-ossp";
|
29
migrations/20240809173617_users.up.sql
Normal file
29
migrations/20240809173617_users.up.sql
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
-- Add up migration script here
|
||||||
|
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS public.users
|
||||||
|
(
|
||||||
|
id uuid NOT NULL DEFAULT uuid_generate_v4(),
|
||||||
|
email character varying(255) NOT NULL,
|
||||||
|
created_at timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_updated timestamp with time zone NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
PRIMARY KEY (id),
|
||||||
|
CONSTRAINT users_email_unique UNIQUE (email)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS public.sessions
|
||||||
|
(
|
||||||
|
id uuid NOT NULL DEFAULT uuid_generate_v4(),
|
||||||
|
user_id uuid NOT NULL,
|
||||||
|
session_id character varying NOT NULL,
|
||||||
|
expires_at timestamp with time zone NOT NULL,
|
||||||
|
PRIMARY KEY (id),
|
||||||
|
CONSTRAINT sessions_user_id_unique UNIQUE (user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
ALTER TABLE IF EXISTS public.sessions
|
||||||
|
ADD CONSTRAINT sessions_user_id_users_fk FOREIGN KEY (user_id)
|
||||||
|
REFERENCES public.users (id) MATCH SIMPLE
|
||||||
|
ON UPDATE CASCADE
|
||||||
|
ON DELETE CASCADE
|
||||||
|
NOT VALID;
|
@ -1,4 +1,4 @@
|
|||||||
[toolchain]
|
[toolchain]
|
||||||
channel = "1.78.0"
|
channel = "1.81.0"
|
||||||
components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ]
|
components = [ "rustfmt", "rust-src", "clippy", "rust-analyzer" ]
|
||||||
profile = "default"
|
profile = "default"
|
||||||
|
@ -16,3 +16,7 @@ email:
|
|||||||
user: user@gege-jdr-backend.example
|
user: user@gege-jdr-backend.example
|
||||||
from: GegeJdrBackend <noreply@gege-jdr-backend.example>
|
from: GegeJdrBackend <noreply@gege-jdr-backend.example>
|
||||||
password: hunter2
|
password: hunter2
|
||||||
|
|
||||||
|
discord:
|
||||||
|
client_id: changeme
|
||||||
|
client_secret: changeme
|
||||||
|
@ -3,5 +3,5 @@ debug: true
|
|||||||
|
|
||||||
application:
|
application:
|
||||||
protocol: http
|
protocol: http
|
||||||
host: 127.0.0.1
|
host: localhost
|
||||||
base_url: http://127.0.0.1:3000
|
base_url: http://localhost:3000
|
||||||
|
19
src/errors.rs
Normal file
19
src/errors.rs
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum ApiError {
|
||||||
|
#[error("SQL error: {0}")]
|
||||||
|
Sql(#[from] sqlx::Error),
|
||||||
|
#[error("HTTP request error: {0}")]
|
||||||
|
Request(#[from] reqwest::Error),
|
||||||
|
#[error("OAuth token error: {0}")]
|
||||||
|
TokenError(String),
|
||||||
|
#[error("Unauthorized")]
|
||||||
|
Unauthorized,
|
||||||
|
#[error("Attempted to get a value, none found")]
|
||||||
|
OptionError,
|
||||||
|
#[error("Failed to parse a number as an integer")]
|
||||||
|
ParseIntError(#[from] std::num::TryFromIntError),
|
||||||
|
#[error("Encountered an error trying to convert an infaillible value")]
|
||||||
|
FromRequestPartsError(#[from] std::convert::Infallible),
|
||||||
|
}
|
15
src/lib.rs
15
src/lib.rs
@ -5,10 +5,13 @@
|
|||||||
#![allow(clippy::unused_async)]
|
#![allow(clippy::unused_async)]
|
||||||
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
|
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
|
||||||
|
|
||||||
pub mod route;
|
mod errors;
|
||||||
pub mod settings;
|
mod models;
|
||||||
pub mod startup;
|
mod oauth;
|
||||||
pub mod telemetry;
|
mod route;
|
||||||
|
mod settings;
|
||||||
|
mod startup;
|
||||||
|
mod telemetry;
|
||||||
|
|
||||||
type MaybeListener = Option<poem::listener::TcpListener<String>>;
|
type MaybeListener = Option<poem::listener::TcpListener<String>>;
|
||||||
|
|
||||||
@ -29,8 +32,8 @@ async fn prepare(listener: MaybeListener, test_db: Option<sqlx::PgPool>) -> star
|
|||||||
tracing::event!(
|
tracing::event!(
|
||||||
target: "gege-jdr-backend",
|
target: "gege-jdr-backend",
|
||||||
tracing::Level::INFO,
|
tracing::Level::INFO,
|
||||||
"Listening on http://127.0.0.1:{}/",
|
"Listening on {}",
|
||||||
application.port()
|
application.settings.web_address()
|
||||||
);
|
);
|
||||||
application
|
application
|
||||||
}
|
}
|
||||||
|
18
src/models/accounts.rs
Normal file
18
src/models/accounts.rs
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
type Timestampz = chrono::DateTime<chrono::Utc>;
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
|
||||||
|
pub struct User {
|
||||||
|
pub id: uuid::Uuid,
|
||||||
|
pub email: String,
|
||||||
|
pub created_at: Timestampz,
|
||||||
|
pub last_updated: Timestampz,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq, Eq)]
|
||||||
|
pub struct Session {
|
||||||
|
pub id: uuid::Uuid,
|
||||||
|
pub user_id: uuid::Uuid,
|
||||||
|
#[allow(clippy::struct_field_names)]
|
||||||
|
pub session_id: String,
|
||||||
|
pub expires_at: Timestampz,
|
||||||
|
}
|
1
src/models/mod.rs
Normal file
1
src/models/mod.rs
Normal file
@ -0,0 +1 @@
|
|||||||
|
pub mod accounts;
|
62
src/oauth/discord.rs
Normal file
62
src/oauth/discord.rs
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
use oauth2::{
|
||||||
|
basic::BasicClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken,
|
||||||
|
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RevocationUrl, Scope, TokenUrl,
|
||||||
|
};
|
||||||
|
use reqwest::Url;
|
||||||
|
|
||||||
|
use crate::{errors::ApiError, settings::Settings};
|
||||||
|
|
||||||
|
use super::OauthProvider;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DiscordOauthProvider {
|
||||||
|
client: BasicClient,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DiscordOauthProvider {
|
||||||
|
pub fn new(settings: &Settings) -> Self {
|
||||||
|
let redirect_url = format!("{}/v1/api/auth/callback/discord", settings.web_address());
|
||||||
|
let auth_url = AuthUrl::new("https://discord.com/oauth2/authorize".to_string())
|
||||||
|
.expect("Invalid authorization endpoint URL");
|
||||||
|
let token_url = TokenUrl::new("https://discord.com/api/oauth2/token".to_string())
|
||||||
|
.expect("Invalid token endpoint URL");
|
||||||
|
let revocation_url =
|
||||||
|
RevocationUrl::new("https://discord.com/api/oauth2/token/revoke".to_string())
|
||||||
|
.expect("Invalid revocation URL");
|
||||||
|
let client = BasicClient::new(
|
||||||
|
ClientId::new(settings.discord.client_id.clone()),
|
||||||
|
Some(ClientSecret::new(settings.discord.client_secret.clone())),
|
||||||
|
auth_url,
|
||||||
|
Some(token_url),
|
||||||
|
)
|
||||||
|
.set_redirect_uri(RedirectUrl::new(redirect_url).expect("Invalid redirect URL"))
|
||||||
|
.set_revocation_uri(revocation_url);
|
||||||
|
Self { client }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OauthProvider for DiscordOauthProvider {
|
||||||
|
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier) {
|
||||||
|
let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
|
||||||
|
let (auth_url, csrf_token) = self
|
||||||
|
.client
|
||||||
|
.authorize_url(CsrfToken::new_random)
|
||||||
|
.add_scopes(["identify", "openid", "email"].map(|v| Scope::new(v.to_string())))
|
||||||
|
.set_pkce_challenge(pkce_challenge)
|
||||||
|
.url();
|
||||||
|
(auth_url, csrf_token, pkce_verifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn token(
|
||||||
|
&self,
|
||||||
|
code: String,
|
||||||
|
verifier: PkceCodeVerifier,
|
||||||
|
) -> Result<super::Token, ApiError> {
|
||||||
|
self.client
|
||||||
|
.exchange_code(AuthorizationCode::new(code))
|
||||||
|
.set_pkce_verifier(verifier)
|
||||||
|
.request_async(oauth2::reqwest::async_http_client)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError::TokenError(format!("{e:?}")))
|
||||||
|
}
|
||||||
|
}
|
17
src/oauth/mod.rs
Normal file
17
src/oauth/mod.rs
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
mod discord;
|
||||||
|
pub use discord::DiscordOauthProvider;
|
||||||
|
|
||||||
|
use oauth2::{
|
||||||
|
basic::BasicTokenType, CsrfToken, EmptyExtraTokenFields, PkceCodeVerifier,
|
||||||
|
StandardTokenResponse,
|
||||||
|
};
|
||||||
|
use reqwest::Url;
|
||||||
|
|
||||||
|
use crate::errors::ApiError;
|
||||||
|
|
||||||
|
pub type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
|
||||||
|
|
||||||
|
pub trait OauthProvider {
|
||||||
|
fn auth_and_csrf(&self) -> (Url, CsrfToken, PkceCodeVerifier);
|
||||||
|
async fn token(&self, code: String, verifier: PkceCodeVerifier) -> Result<Token, ApiError>;
|
||||||
|
}
|
125
src/route/auth.rs
Normal file
125
src/route/auth.rs
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
use oauth2::{CsrfToken, PkceCodeVerifier, TokenResponse};
|
||||||
|
use poem::web::Data;
|
||||||
|
use poem::{session::Session, web::Form};
|
||||||
|
use poem_openapi::payload::{Json, PlainText};
|
||||||
|
use poem_openapi::{ApiResponse, Object, OpenApi};
|
||||||
|
|
||||||
|
use crate::oauth::{DiscordOauthProvider, OauthProvider};
|
||||||
|
|
||||||
|
use super::errors::ErrorResponse;
|
||||||
|
use super::ApiCategory;
|
||||||
|
|
||||||
|
pub struct AuthApi;
|
||||||
|
|
||||||
|
#[derive(Debug, Object, Clone, Eq, PartialEq, serde::Deserialize)]
|
||||||
|
struct DiscordCallbackRequest {
|
||||||
|
code: String,
|
||||||
|
state: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(ApiResponse)]
|
||||||
|
enum LoginStatusResponse {
|
||||||
|
#[oai(status = 201)]
|
||||||
|
LoggedIn,
|
||||||
|
#[oai(status = 201)]
|
||||||
|
LoggedOut,
|
||||||
|
#[oai(status = 301)]
|
||||||
|
LoginRedirect(
|
||||||
|
#[oai(header = "Location")] String,
|
||||||
|
#[oai(header = "Cache-Control")] String,
|
||||||
|
),
|
||||||
|
#[oai(status = 500)]
|
||||||
|
TokenError(Json<ErrorResponse>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(ApiResponse)]
|
||||||
|
enum CsrfResponse {
|
||||||
|
#[oai(status = 201)]
|
||||||
|
Token(PlainText<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[OpenApi(prefix_path = "/v1/api/auth", tag = "ApiCategory::Auth")]
|
||||||
|
impl AuthApi {
|
||||||
|
// TODO: implement the following endpoints:
|
||||||
|
// - /signin
|
||||||
|
// - /signout
|
||||||
|
// - /session
|
||||||
|
// - /providers
|
||||||
|
// See https://next-auth.js.org/getting-started/rest-api
|
||||||
|
|
||||||
|
#[oai(path = "/signin/discord", method = "get")]
|
||||||
|
async fn signin_discord(
|
||||||
|
&self,
|
||||||
|
oauth: Data<&DiscordOauthProvider>,
|
||||||
|
session: &Session,
|
||||||
|
) -> LoginStatusResponse {
|
||||||
|
let (auth_url, csrf_token, pkce_verifier) = oauth.0.auth_and_csrf();
|
||||||
|
session.set("csrf", csrf_token);
|
||||||
|
session.set("pkce", pkce_verifier);
|
||||||
|
tracing::event!(
|
||||||
|
target: "auth-discord",
|
||||||
|
tracing::Level::INFO,
|
||||||
|
"Redirect URL: {}",
|
||||||
|
auth_url
|
||||||
|
);
|
||||||
|
LoginStatusResponse::LoginRedirect(auth_url.to_string(), "no-cache".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/callback/discord", method = "get")]
|
||||||
|
async fn callback_discord(
|
||||||
|
&self,
|
||||||
|
Form(auth_request): Form<DiscordCallbackRequest>,
|
||||||
|
oauth: Data<&DiscordOauthProvider>,
|
||||||
|
session: &Session,
|
||||||
|
) -> LoginStatusResponse {
|
||||||
|
tracing::event!(
|
||||||
|
target: "auth-discord",
|
||||||
|
tracing::Level::DEBUG,
|
||||||
|
"Discord replied with: {:?}",
|
||||||
|
auth_request
|
||||||
|
);
|
||||||
|
let csrf_token = session
|
||||||
|
.get::<CsrfToken>("csrf")
|
||||||
|
.expect("Failed to retrieve Csrf token from session");
|
||||||
|
if *csrf_token.secret().to_string() != auth_request.state {
|
||||||
|
return LoginStatusResponse::TokenError(
|
||||||
|
ErrorResponse {
|
||||||
|
code: 500,
|
||||||
|
message: "OAuth token error".into(),
|
||||||
|
details: Some(
|
||||||
|
"OAuth provider did not send a message that matches what was expected"
|
||||||
|
.into(),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let pkce_verifier = session
|
||||||
|
.get::<PkceCodeVerifier>("pkce")
|
||||||
|
.expect("Failed to retrieve pkce verifier from session");
|
||||||
|
let token = oauth.token(auth_request.code, pkce_verifier).await;
|
||||||
|
if let Err(e) = token {
|
||||||
|
return LoginStatusResponse::TokenError(Json(e.into()));
|
||||||
|
}
|
||||||
|
let token = token.unwrap();
|
||||||
|
tracing::event!(
|
||||||
|
target: "auth-discord",
|
||||||
|
tracing::Level::DEBUG,
|
||||||
|
"Token: {:?}",
|
||||||
|
token
|
||||||
|
);
|
||||||
|
session.set("token", token);
|
||||||
|
LoginStatusResponse::LoggedIn
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/csrf", method = "get")]
|
||||||
|
async fn csrf(&self, token: &poem::web::CsrfToken) -> CsrfResponse {
|
||||||
|
CsrfResponse::Token(PlainText(token.0.clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[oai(path = "/signout", method = "post")]
|
||||||
|
async fn signout(&self, session: &Session) -> LoginStatusResponse {
|
||||||
|
session.remove("token");
|
||||||
|
LoginStatusResponse::LoggedOut
|
||||||
|
}
|
||||||
|
}
|
88
src/route/errors.rs
Normal file
88
src/route/errors.rs
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
use poem::error::ResponseError;
|
||||||
|
use poem_openapi::Object;
|
||||||
|
|
||||||
|
use crate::errors::ApiError;
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Serialize, Default, Object)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
pub code: u16,
|
||||||
|
pub message: String,
|
||||||
|
pub details: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ApiError> for ErrorResponse {
|
||||||
|
fn from(value: ApiError) -> Self {
|
||||||
|
match value {
|
||||||
|
ApiError::Sql(e) => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "SQL error".into(),
|
||||||
|
details: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
ApiError::Request(e) => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "HTTP request error".into(),
|
||||||
|
details: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
ApiError::TokenError(e) => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "OAuth token error".into(),
|
||||||
|
details: Some(e),
|
||||||
|
},
|
||||||
|
ApiError::Unauthorized => Self {
|
||||||
|
code: 401,
|
||||||
|
message: "Unauthorized!".into(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
ApiError::OptionError => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "Attempted to get a value, but none found".into(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
ApiError::ParseIntError(e) => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "Failed to parse a number as an integer".into(),
|
||||||
|
details: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
ApiError::FromRequestPartsError(e) => Self {
|
||||||
|
code: 500,
|
||||||
|
message: "Encountered an error trying to convert an infaillible value".to_string(),
|
||||||
|
details: Some(e.to_string()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResponseError for ApiError {
|
||||||
|
fn status(&self) -> reqwest::StatusCode {
|
||||||
|
match self {
|
||||||
|
Self::FromRequestPartsError(_)
|
||||||
|
| Self::ParseIntError(_)
|
||||||
|
| Self::OptionError
|
||||||
|
| Self::Sql(_)
|
||||||
|
| Self::Request(_)
|
||||||
|
| Self::TokenError(_) => reqwest::StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Self::Unauthorized => reqwest::StatusCode::UNAUTHORIZED,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_response(&self) -> poem::Response
|
||||||
|
where
|
||||||
|
Self: std::error::Error + Send + Sync + 'static,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Self::Sql(_) => todo!(),
|
||||||
|
Self::Request(_) => todo!(),
|
||||||
|
Self::TokenError(_) => todo!(),
|
||||||
|
Self::Unauthorized => todo!(),
|
||||||
|
Self::OptionError => todo!(),
|
||||||
|
Self::ParseIntError(_) => todo!(),
|
||||||
|
Self::FromRequestPartsError(_) => todo!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ErrorResponse> for poem_openapi::payload::Json<ErrorResponse> {
|
||||||
|
fn from(value: ErrorResponse) -> Self {
|
||||||
|
Self(value)
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@ enum HealthResponse {
|
|||||||
|
|
||||||
pub struct HealthApi;
|
pub struct HealthApi;
|
||||||
|
|
||||||
#[OpenApi(prefix_path = "/v1/health-check", tag = "ApiCategory::Health")]
|
#[OpenApi(prefix_path = "/v1/api/health-check", tag = "ApiCategory::Health")]
|
||||||
impl HealthApi {
|
impl HealthApi {
|
||||||
#[oai(path = "/", method = "get")]
|
#[oai(path = "/", method = "get")]
|
||||||
async fn health_check(&self) -> HealthResponse {
|
async fn health_check(&self) -> HealthResponse {
|
||||||
|
@ -6,13 +6,19 @@ pub use health::HealthApi;
|
|||||||
mod version;
|
mod version;
|
||||||
pub use version::VersionApi;
|
pub use version::VersionApi;
|
||||||
|
|
||||||
|
mod errors;
|
||||||
|
|
||||||
|
mod auth;
|
||||||
|
pub use auth::AuthApi;
|
||||||
|
|
||||||
#[derive(Tags)]
|
#[derive(Tags)]
|
||||||
enum ApiCategory {
|
enum ApiCategory {
|
||||||
|
Auth,
|
||||||
Health,
|
Health,
|
||||||
Version,
|
Version,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) struct Api;
|
pub struct Api;
|
||||||
|
|
||||||
#[OpenApi]
|
#[OpenApi]
|
||||||
impl Api {}
|
impl Api {}
|
||||||
|
@ -25,7 +25,7 @@ enum VersionResponse {
|
|||||||
|
|
||||||
pub struct VersionApi;
|
pub struct VersionApi;
|
||||||
|
|
||||||
#[OpenApi(prefix_path = "/v1/version", tag = "ApiCategory::Version")]
|
#[OpenApi(prefix_path = "/v1/api/version", tag = "ApiCategory::Version")]
|
||||||
impl VersionApi {
|
impl VersionApi {
|
||||||
#[oai(path = "/", method = "get")]
|
#[oai(path = "/", method = "get")]
|
||||||
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
|
async fn version(&self, settings: poem::web::Data<&Settings>) -> Result<VersionResponse> {
|
||||||
|
@ -4,6 +4,7 @@ use sqlx::ConnectOptions;
|
|||||||
pub struct Settings {
|
pub struct Settings {
|
||||||
pub application: ApplicationSettings,
|
pub application: ApplicationSettings,
|
||||||
pub database: Database,
|
pub database: Database,
|
||||||
|
pub discord: Discord,
|
||||||
pub debug: bool,
|
pub debug: bool,
|
||||||
pub email: EmailSettings,
|
pub email: EmailSettings,
|
||||||
pub frontend_url: String,
|
pub frontend_url: String,
|
||||||
@ -12,16 +13,8 @@ pub struct Settings {
|
|||||||
impl Settings {
|
impl Settings {
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn web_address(&self) -> String {
|
pub fn web_address(&self) -> String {
|
||||||
if self.debug {
|
|
||||||
format!(
|
|
||||||
"{}:{}",
|
|
||||||
self.application.base_url.clone(),
|
|
||||||
self.application.port
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
self.application.base_url.clone()
|
self.application.base_url.clone()
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// Multipurpose function that helps detect the current
|
/// Multipurpose function that helps detect the current
|
||||||
/// environment the application is running in using the
|
/// environment the application is running in using the
|
||||||
@ -65,7 +58,7 @@ impl Settings {
|
|||||||
))
|
))
|
||||||
.add_source(
|
.add_source(
|
||||||
config::Environment::with_prefix("APP")
|
config::Environment::with_prefix("APP")
|
||||||
.prefix_separator("_")
|
.prefix_separator("__")
|
||||||
.separator("__"),
|
.separator("__"),
|
||||||
)
|
)
|
||||||
.build()?;
|
.build()?;
|
||||||
@ -164,6 +157,12 @@ pub struct EmailSettings {
|
|||||||
pub from: String,
|
pub from: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize, Clone, Default)]
|
||||||
|
pub struct Discord {
|
||||||
|
pub client_id: String,
|
||||||
|
pub client_secret: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
use poem::middleware::Cors;
|
|
||||||
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
|
use poem::middleware::{AddDataEndpoint, CorsEndpoint};
|
||||||
|
use poem::middleware::{CookieJarManagerEndpoint, Cors};
|
||||||
|
use poem::session::{CookieConfig, CookieSession, CookieSessionEndpoint};
|
||||||
use poem::{EndpointExt, Route};
|
use poem::{EndpointExt, Route};
|
||||||
use poem_openapi::OpenApiService;
|
use poem_openapi::OpenApiService;
|
||||||
|
|
||||||
|
use crate::oauth::DiscordOauthProvider;
|
||||||
|
use crate::route::AuthApi;
|
||||||
use crate::{
|
use crate::{
|
||||||
route::{Api, HealthApi, VersionApi},
|
route::{Api, HealthApi, VersionApi},
|
||||||
settings::Settings,
|
settings::Settings,
|
||||||
@ -22,14 +25,23 @@ pub fn get_connection_pool(settings: &crate::settings::Database) -> sqlx::postgr
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
|
type Server = poem::Server<poem::listener::TcpListener<String>, std::convert::Infallible>;
|
||||||
pub type App = AddDataEndpoint<AddDataEndpoint<CorsEndpoint<Route>, sqlx::PgPool>, Settings>;
|
pub type App = AddDataEndpoint<
|
||||||
|
AddDataEndpoint<
|
||||||
|
AddDataEndpoint<
|
||||||
|
CookieJarManagerEndpoint<CookieSessionEndpoint<CorsEndpoint<Route>>>,
|
||||||
|
DiscordOauthProvider,
|
||||||
|
>,
|
||||||
|
sqlx::Pool<sqlx::Postgres>,
|
||||||
|
>,
|
||||||
|
Settings,
|
||||||
|
>;
|
||||||
|
|
||||||
pub struct Application {
|
pub struct Application {
|
||||||
server: Server,
|
server: Server,
|
||||||
app: poem::Route,
|
app: poem::Route,
|
||||||
port: u16,
|
port: u16,
|
||||||
database: sqlx::postgres::PgPool,
|
database: sqlx::postgres::PgPool,
|
||||||
settings: Settings,
|
pub settings: Settings,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct RunnableApplication {
|
pub struct RunnableApplication {
|
||||||
@ -61,6 +73,8 @@ impl From<Application> for RunnableApplication {
|
|||||||
let app = val
|
let app = val
|
||||||
.app
|
.app
|
||||||
.with(Cors::new())
|
.with(Cors::new())
|
||||||
|
.with(CookieSession::new(CookieConfig::default().secure(true)))
|
||||||
|
.data(crate::oauth::DiscordOauthProvider::new(&val.settings))
|
||||||
.data(val.database)
|
.data(val.database)
|
||||||
.data(val.settings);
|
.data(val.settings);
|
||||||
let server = val.server;
|
let server = val.server;
|
||||||
@ -83,7 +97,7 @@ impl Application {
|
|||||||
|
|
||||||
fn setup_app(settings: &Settings) -> poem::Route {
|
fn setup_app(settings: &Settings) -> poem::Route {
|
||||||
let api_service = OpenApiService::new(
|
let api_service = OpenApiService::new(
|
||||||
(Api, HealthApi, VersionApi),
|
(Api, AuthApi, HealthApi, VersionApi),
|
||||||
settings.application.clone().name,
|
settings.application.clone().name,
|
||||||
settings.application.clone().version,
|
settings.application.clone().version,
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user