diff --git a/gejdr-core/src/models/mod.rs b/gejdr-core/src/models/mod.rs index eabca60..cd80f50 100644 --- a/gejdr-core/src/models/mod.rs +++ b/gejdr-core/src/models/mod.rs @@ -70,4 +70,7 @@ pub trait Crud { pool: &sqlx::PgPool, id: &Id, ) -> impl std::future::Future> + Send; + + /// Returns the identifier of the entity. + fn get_id(&self) -> &Id; } diff --git a/gejdr-macros/src/crud/ir.rs b/gejdr-macros/src/crud/ir.rs index 950434d..a5d6e29 100644 --- a/gejdr-macros/src/crud/ir.rs +++ b/gejdr-macros/src/crud/ir.rs @@ -1,23 +1,215 @@ +use quote::quote; + #[derive(deluxe::ExtractAttributes)] #[deluxe(attributes(crud))] pub struct CrudStructAttributes { pub table: String, + #[deluxe(default = Vec::new())] + pub one_to_many: Vec, + #[deluxe(default = Vec::new())] + pub many_to_many: Vec, +} + +#[derive(deluxe::ParseMetaItem)] +pub struct O2MRelationship { + pub name: String, + pub remote_id: String, + pub table: String, + pub entity: syn::Type, +} + +impl From<&O2MRelationship> for proc_macro2::TokenStream { + fn from(value: &O2MRelationship) -> Self { + let query = format!( + "SELECT * FROM {} WHERE {} = $1", + value.table, value.remote_id + ); + let entity = &value.entity; + let function = syn::Ident::new( + &format!("get_{}", value.name), + proc_macro2::Span::call_site(), + ); + quote! { + pub async fn #function(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result> { + query_as!(#entity, #query, self.get_id()).fetch_all(pool).await + } + } + } +} + +#[derive(deluxe::ParseMetaItem, Clone)] +pub struct M2MLink { + pub table: String, + pub from: String, + pub to: String, +} + +//#[crud( +// table = "users", +// many_to_many = [ +// { +// name = friends, +// entity: User, +// link = { table = "user_friendships", from: "user1", to "user2" } +// } +// ] +//)] +#[derive(deluxe::ParseMetaItem)] +pub struct M2MRelationship { + pub name: String, + pub entity: syn::Type, + pub table: String, + #[deluxe(default = String::from("id"))] + pub remote_id: String, + pub link: M2MLink, +} + +pub struct Identifier { + pub table: String, + pub id: String, +} + +pub struct M2MRelationshipComplete { + pub name: String, + pub entity: syn::Type, + pub local: Identifier, + pub remote: Identifier, + pub link: M2MLink, +} + +impl M2MRelationshipComplete { + pub fn new(other: &M2MRelationship, local_table: &String, local_id: &String) -> Self { + Self { + name: other.name.clone(), + entity: other.entity.clone(), + link: other.link.clone(), + local: Identifier { + table: local_table.to_string(), + id: local_id.to_string(), + }, + remote: Identifier { + table: other.table.clone(), + id: other.remote_id.clone(), + }, + } + } +} + +impl From<&M2MRelationshipComplete> for proc_macro2::TokenStream { + fn from(value: &M2MRelationshipComplete) -> Self { + let function = syn::Ident::new( + &format!("get_{}", value.name), + proc_macro2::Span::call_site(), + ); + let entity = &value.entity; + let query = format!( + " +SELECT remote.* +FROM {} local +JOIN {} link ON link.{} = local.{} +JOIN {} remote ON link.{} = remote.{} +WHERE local.{} = $1 +", + value.local.table, + value.link.table, + value.link.from, + value.local.id, + value.remote.table, + value.link.to, + value.remote.id, + value.local.id + ); + quote! { + pub async fn #function(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result> { + query_as!(#entity, #query, self.get_id()).fetch_all(pool).await + } + } + } } #[derive(deluxe::ExtractAttributes, Clone)] #[deluxe(attributes(crud))] -pub struct CrudFieldAttributes { +struct CrudFieldAttributes { #[deluxe(default = false)] pub id: bool, #[deluxe(default = None)] pub column: Option, + #[deluxe(default = None)] + pub relation: Option, +} + +// #[crud( +// table = "profileId", +// one_to_one = { name = profile, id = "id", entity = Profile, nullable } +// )] +#[derive(deluxe::ParseMetaItem, Clone)] +pub struct O2ORelationship { + pub entity: syn::Type, + pub table: String, + #[deluxe(default = String::from("id"))] + pub remote_id: String, + #[deluxe(default = false)] + pub nullable: bool, + pub name: String, } #[derive(Clone)] pub struct CrudField { pub ident: syn::Ident, pub field: syn::Field, + pub ty: syn::Type, pub column: String, pub id: bool, - pub ty: syn::Type, + pub relation: Option, +} + +impl CrudField { + pub fn new(field: &mut syn::Field) -> Self { + let ident = field.clone().ident.unwrap(); + let ty = field.clone().ty; + let attrs: CrudFieldAttributes = + deluxe::extract_attributes(field).expect("Could not extract attributes from field"); + Self { + ident: ident.clone(), + field: field.to_owned(), + column: attrs.column.unwrap_or_else(|| ident.to_string()), + id: attrs.id, + ty, + relation: attrs.relation, + } + } +} + +impl From<&CrudField> for proc_macro2::TokenStream { + fn from(value: &CrudField) -> Self { + let Some(relation) = value.relation.clone() else { + return quote! {}; + }; + + let function = syn::Ident::new( + &format!("get_{}", relation.name), + proc_macro2::Span::call_site(), + ); + let entity = &relation.entity; + let return_type = if relation.nullable { + quote! { Option<#entity> } + } else { + quote! { #entity } + }; + let query = format!( + "SELECT * FROM {} WHERE {} = $1", + relation.table, relation.remote_id + ); + let local_ident = &value.field.ident; + let fetch = if relation.nullable { + quote! { fetch_optional } + } else { + quote! { fetch_one } + }; + quote! { + pub async fn #function(&value, pool: &::sqlx::PgPool) -> ::sqlx::Result<#return_type> { + query_as!(#entity, #query, value.#local_ident).#fetch(pool).await + } + } + } } diff --git a/gejdr-macros/src/crud/mod.rs b/gejdr-macros/src/crud/mod.rs index fa3e835..e9fb51d 100644 --- a/gejdr-macros/src/crud/mod.rs +++ b/gejdr-macros/src/crud/mod.rs @@ -1,142 +1,43 @@ -use ir::{CrudField, CrudFieldAttributes, CrudStructAttributes}; +use ir::CrudField; use quote::quote; -use syn::DeriveInput; mod ir; +mod relationships; +mod trait_implementation; -fn extract_crud_field_attrs(ast: &mut DeriveInput) -> deluxe::Result<(Vec, CrudField)> { - let mut field_attrs: Vec = Vec::new(); - // let mut identifier: Option = None; - let mut identifier: Option = None; - let mut identifier_counter = 0; - if let syn::Data::Struct(s) = &mut ast.data { - for field in &mut s.fields { - let ident = field.clone().ident.unwrap(); - let ty = field.clone().ty; - let attrs: CrudFieldAttributes = - deluxe::extract_attributes(field).expect("Could not extract attributes from field"); - let field = CrudField { - ident: ident.clone(), - field: field.to_owned(), - column: attrs.column.unwrap_or_else(|| ident.to_string()), - id: attrs.id, - ty, - }; - if attrs.id { - identifier_counter += 1; - identifier = Some(field.clone()); - } - if identifier_counter > 1 { - return Err(syn::Error::new_spanned( - field.field, - "Struct {name} can only have one identifier", - )); - } - field_attrs.push(field); - } - } - if identifier_counter < 1 { - Err(syn::Error::new_spanned( +fn extract_crud_field_attrs( + ast: &mut syn::DeriveInput, +) -> deluxe::Result<(Vec, CrudField)> { + let syn::Data::Struct(s) = &mut ast.data else { + return Err(syn::Error::new_spanned( + ast, + "Cannot apply to something other than a struct", + )); + }; + let fields = s + .fields + .clone() + .into_iter() + .map(|mut field| CrudField::new(&mut field)) + .collect::>(); + let identifiers: Vec = fields + .clone() + .into_iter() + .filter(|field| field.id) + .collect(); + match identifiers.len() { + 0 => Err(syn::Error::new_spanned( ast, "Struct {name} must have one identifier", - )) - } else { - Ok((field_attrs, identifier.unwrap())) - } -} - -fn generate_find_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream { - let find_string = format!("SELECT * FROM {} WHERE {} = $1", table, id.column); - let ty = &id.ty; - quote! { - async fn find(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result> { - ::sqlx::query_as!(Self, #find_string, id) - .fetch_optional(pool) - .await - } - } -} - -fn generate_create_query(table: &str, fields: &[CrudField]) -> proc_macro2::TokenStream { - let inputs: Vec = (1..=fields.len()).map(|num| format!("${num}")).collect(); - let create_string = format!( - "INSERT INTO {} ({}) VALUES ({}) RETURNING *", - table, - fields - .iter() - .map(|v| v.column.clone()) - .collect::>() - .join(", "), - inputs.join(", ") - ); - let field_idents: Vec = fields.iter().map(|f| f.ident.clone()).collect(); - quote! { - async fn create(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { - ::sqlx::query_as!( - Self, - #create_string, - #(self.#field_idents),* - ) - .fetch_one(pool) - .await - } - } -} - -fn generate_update_query( - table: &str, - fields: &[CrudField], - id: &CrudField, -) -> proc_macro2::TokenStream { - let mut fields: Vec<&CrudField> = fields.iter().filter(|f| !f.id).collect(); - let update_columns = fields - .iter() - .enumerate() - .map(|(i, &field)| format!("{} = ${}", field.column, i + 1)) - .collect::>() - .join(", "); - let update_string = format!( - "UPDATE {} SET {} WHERE {} = ${} RETURNING *", - table, - update_columns, - id.column, - fields.len() + 1 - ); - fields.push(id); - let field_idents: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect(); - quote! { - async fn update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { - ::sqlx::query_as!( - Self, - #update_string, - #(self.#field_idents),* - ) - .fetch_one(pool) - .await - } - } -} - -fn generate_delete_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream { - let delete_string = format!("DELETE FROM {} WHERE {} = $1", table, id.column); - let ty = &id.ty; - let ident = &id.ident; - - quote! { - async fn delete_by_id(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result { - let rows_affected = ::sqlx::query!(#delete_string, id) - .execute(pool) - .await? - .rows_affected(); - Ok(rows_affected) - } - - async fn delete(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { - let rows_affected = ::sqlx::query!(#delete_string, self.#ident) - .execute(pool) - .await? - .rows_affected(); - Ok(rows_affected) + )), + 1 => Ok((fields, identifiers.first().unwrap().clone())), + _ => { + let id1 = identifiers.first().unwrap(); + let id2 = identifiers.get(1).unwrap(); + Err(syn::Error::new_spanned(id2.field.clone(), format!( + "Field {} cannot be an identifier, {} already is one.\nOnly one identifier is supported.", + id1.ident, id2.ident + ))) } } } @@ -144,45 +45,15 @@ fn generate_delete_query(table: &str, id: &CrudField) -> proc_macro2::TokenStrea pub fn crud_derive_macro2( item: proc_macro2::TokenStream, ) -> deluxe::Result { - // parse - let mut ast: DeriveInput = syn::parse2(item).expect("Failed to parse input"); - - // extract struct attributes - let CrudStructAttributes { table } = + let mut ast: syn::DeriveInput = syn::parse2(item).expect("Failed to parse input"); + let struct_attrs: ir::CrudStructAttributes = deluxe::extract_attributes(&mut ast).expect("Could not extract attributes from struct"); - - // extract field attributes let (fields, id) = extract_crud_field_attrs(&mut ast)?; - let ty = &id.ty; - let id_ident = &id.ident; - - // define impl variables - let ident = &ast.ident; - let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl(); - - // generate - let find_query = generate_find_query(&table, &id); - let create_query = generate_create_query(&table, &fields); - let update_query = generate_update_query(&table, &fields, &id); - let delete_query = generate_delete_query(&table, &id); + let trait_impl = trait_implementation::derive_trait(&ast, &struct_attrs.table, &fields, &id); + let relationships = relationships::derive_relationships(&ast, &struct_attrs, &fields, &id); let code = quote! { - impl #impl_generics Crud<#ty> for #ident #type_generics #where_clause { - #find_query - - #create_query - - #update_query - - async fn create_or_update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { - if Self::find(pool, &self.#id_ident).await?.is_some() { - self.update(pool).await - } else { - self.create(pool).await - } - } - - #delete_query - } + #trait_impl + #relationships }; Ok(code) } diff --git a/gejdr-macros/src/crud/relationships.rs b/gejdr-macros/src/crud/relationships.rs new file mode 100644 index 0000000..a50b9cb --- /dev/null +++ b/gejdr-macros/src/crud/relationships.rs @@ -0,0 +1,54 @@ +use std::str::FromStr; + +use crate::crud::ir::M2MRelationshipComplete; + +use super::ir::CrudField; +use proc_macro2::TokenStream; +use quote::quote; + +fn join_token_streams(token_streams: &[TokenStream]) -> TokenStream { + let newline = TokenStream::from_str("\n").unwrap(); + token_streams + .iter() + .cloned() + .flat_map(|ts| std::iter::once(ts).chain(std::iter::once(newline.clone()))) + .collect() +} + +fn derive(relationships: &[T], condition: P) -> TokenStream +where + for<'a> &'a T: Into, + P: FnMut(&&T) -> bool, +{ + let implementations: Vec = relationships + .iter() + .filter(condition) + .map(std::convert::Into::into) + .collect(); + join_token_streams(&implementations) +} + +pub fn derive_relationships( + ast: &syn::DeriveInput, + struct_attrs: &super::ir::CrudStructAttributes, + fields: &[CrudField], + id: &CrudField, +) -> TokenStream { + let struct_name = &ast.ident; + let one_to_one = derive(fields, |field| field.relation.is_none()); + let one_to_many = derive(&struct_attrs.one_to_many, |_| true); + let many_to_many: Vec = struct_attrs + .many_to_many + .iter() + .map(|v| M2MRelationshipComplete::new(v, &struct_attrs.table, &id.column)) + .collect(); + let many_to_many = derive(&many_to_many, |_| true); + + quote! { + impl #struct_name { + #one_to_one + #one_to_many + #many_to_many + } + } +} diff --git a/gejdr-macros/src/crud/trait_implementation.rs b/gejdr-macros/src/crud/trait_implementation.rs new file mode 100644 index 0000000..3b6d71b --- /dev/null +++ b/gejdr-macros/src/crud/trait_implementation.rs @@ -0,0 +1,149 @@ +use super::ir::CrudField; +use quote::quote; + +fn generate_find_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream { + let find_string = format!("SELECT * FROM {} WHERE {} = $1", table, id.column); + let ty = &id.ty; + quote! { + async fn find(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result> { + ::sqlx::query_as!(Self, #find_string, id) + .fetch_optional(pool) + .await + } + } +} + +fn generate_create_query(table: &str, fields: &[CrudField]) -> proc_macro2::TokenStream { + let inputs: Vec = (1..=fields.len()).map(|num| format!("${num}")).collect(); + let create_string = format!( + "INSERT INTO {} ({}) VALUES ({}) RETURNING *", + table, + fields + .iter() + .map(|v| v.column.clone()) + .collect::>() + .join(", "), + inputs.join(", ") + ); + let field_idents: Vec = fields.iter().map(|f| f.ident.clone()).collect(); + quote! { + async fn create(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { + ::sqlx::query_as!( + Self, + #create_string, + #(self.#field_idents),* + ) + .fetch_one(pool) + .await + } + } +} + +fn generate_update_query( + table: &str, + fields: &[CrudField], + id: &CrudField, +) -> proc_macro2::TokenStream { + let mut fields: Vec<&CrudField> = fields.iter().filter(|f| !f.id).collect(); + let update_columns = fields + .iter() + .enumerate() + .map(|(i, &field)| format!("{} = ${}", field.column, i + 1)) + .collect::>() + .join(", "); + let update_string = format!( + "UPDATE {} SET {} WHERE {} = ${} RETURNING *", + table, + update_columns, + id.column, + fields.len() + 1 + ); + fields.push(id); + let field_idents: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect(); + quote! { + async fn update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { + ::sqlx::query_as!( + Self, + #update_string, + #(self.#field_idents),* + ) + .fetch_one(pool) + .await + } + } +} + +fn generate_delete_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream { + let delete_string = format!("DELETE FROM {} WHERE {} = $1", table, id.column); + let ty = &id.ty; + + quote! { + async fn delete_by_id(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result { + let rows_affected = ::sqlx::query!(#delete_string, id) + .execute(pool) + .await? + .rows_affected(); + Ok(rows_affected) + } + + async fn delete(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { + let rows_affected = ::sqlx::query!(#delete_string, self.get_id()) + .execute(pool) + .await? + .rows_affected(); + Ok(rows_affected) + } + } +} + +fn generate_get_id(id: &CrudField) -> proc_macro2::TokenStream { + let ident = &id.ident; + let ty = &id.ty; + quote! { + fn get_id(&self) -> &#ty { + &self.#ident + } + } +} + +pub fn derive_trait( + ast: &syn::DeriveInput, + table: &str, + fields: &[CrudField], + id: &CrudField, +) -> proc_macro2::TokenStream { + let ty = &id.ty; + let id_ident = &id.ident; + + // define impl variables + let ident = &ast.ident; + let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl(); + + // generate + let get_id = generate_get_id(id); + let find_query = generate_find_query(table, id); + let create_query = generate_create_query(table, fields); + let update_query = generate_update_query(table, fields, id); + let delete_query = generate_delete_query(table, id); + quote! { + impl #impl_generics Crud<#ty> for #ident #type_generics #where_clause { + #get_id + + #find_query + + #create_query + + #update_query + + async fn create_or_update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { + if Self::find(pool, &self.#id_ident).await?.is_some() { + self.update(pool).await + } else { + self.create(pool).await + } + } + + #delete_query + } + } +} diff --git a/gejdr-macros/src/lib.rs b/gejdr-macros/src/lib.rs index 4889dc7..a7d2436 100644 --- a/gejdr-macros/src/lib.rs +++ b/gejdr-macros/src/lib.rs @@ -3,7 +3,6 @@ #![deny(clippy::nursery)] #![allow(clippy::module_name_repetitions)] #![allow(clippy::unused_async)] -#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling #![forbid(unsafe_code)] //! Create ``SQLx`` CRUD code for a struct in Postgres. @@ -19,11 +18,12 @@ //! - update an entity in the database //! //! SQL query: `UPDATE ... SET ... WHERE = ... RETURNING *` -//! - delete an entity from the database using its id +//! - delete an entity from the database using its id or an id +//! provided by the interface’s user //! //! SQL query: `DELETE FROM ... WHERE = ...` //! - update an entity or create it if it does not already exist in -//! - the database +//! the database //! //! This macro relies on the trait `Crud` found in the `gejdr-core` //! crate. @@ -34,9 +34,16 @@ //! # Usage //! //! Add `#[crud(table = "my_table_name")]` atop of the structure, -//! after the `Crud` derive. You will also need to add `#[crud(id)]` -//! atop of the field of your struct that will be used as the -//! identifier of your entity. +//! after the `Crud` derive. +//! +//! ## Entity Identifier +//! You will also need to add `#[crud(id)]` atop of the field of your +//! struct that will be used as the identifier of your entity. +//! +//! ## Column Name +//! If the name of a field does not match the name of its related +//! column, you can use `#[crud(column = "...")]` to specify the +//! correct value. //! //! ```ignore //! #[derive(Crud)] @@ -44,6 +51,7 @@ //! pub struct User { //! #[crud(id)] //! id: String, +//! #[crud(column = "name")] //! username: String, //! created_at: Timestampz, //! last_updated: Timestampz,