generated from phundrak/rust-poem-openapi-template
	add get_id and relationships, untested
	
		
			
	
		
	
	
		
	
		
			All checks were successful
		
		
	
	
		
			
				
	
				CI / tests (push) Successful in 8m40s
				
			
		
		
	
	
				
					
				
			
		
			All checks were successful
		
		
	
	CI / tests (push) Successful in 8m40s
				
			This commit is contained in:
		
							parent
							
								
									642d7bae0d
								
							
						
					
					
						commit
						915bd8387e
					
				@ -70,4 +70,7 @@ pub trait Crud<Id> {
 | 
			
		||||
        pool: &sqlx::PgPool,
 | 
			
		||||
        id: &Id,
 | 
			
		||||
    ) -> impl std::future::Future<Output = sqlx::Result<u64>> + Send;
 | 
			
		||||
 | 
			
		||||
    /// Returns the identifier of the entity.
 | 
			
		||||
    fn get_id(&self) -> &Id;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,23 +1,215 @@
 | 
			
		||||
use quote::quote;
 | 
			
		||||
 | 
			
		||||
#[derive(deluxe::ExtractAttributes)]
 | 
			
		||||
#[deluxe(attributes(crud))]
 | 
			
		||||
pub struct CrudStructAttributes {
 | 
			
		||||
    pub table: String,
 | 
			
		||||
    #[deluxe(default = Vec::new())]
 | 
			
		||||
    pub one_to_many: Vec<O2MRelationship>,
 | 
			
		||||
    #[deluxe(default = Vec::new())]
 | 
			
		||||
    pub many_to_many: Vec<M2MRelationship>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(deluxe::ParseMetaItem)]
 | 
			
		||||
pub struct O2MRelationship {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub remote_id: String,
 | 
			
		||||
    pub table: String,
 | 
			
		||||
    pub entity: syn::Type,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<&O2MRelationship> for proc_macro2::TokenStream {
 | 
			
		||||
    fn from(value: &O2MRelationship) -> Self {
 | 
			
		||||
        let query = format!(
 | 
			
		||||
            "SELECT * FROM {} WHERE {} = $1",
 | 
			
		||||
            value.table, value.remote_id
 | 
			
		||||
        );
 | 
			
		||||
        let entity = &value.entity;
 | 
			
		||||
        let function = syn::Ident::new(
 | 
			
		||||
            &format!("get_{}", value.name),
 | 
			
		||||
            proc_macro2::Span::call_site(),
 | 
			
		||||
        );
 | 
			
		||||
        quote! {
 | 
			
		||||
            pub async fn #function(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Vec<#entity>> {
 | 
			
		||||
                query_as!(#entity, #query, self.get_id()).fetch_all(pool).await
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(deluxe::ParseMetaItem, Clone)]
 | 
			
		||||
pub struct M2MLink {
 | 
			
		||||
    pub table: String,
 | 
			
		||||
    pub from: String,
 | 
			
		||||
    pub to: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//#[crud(
 | 
			
		||||
//    table = "users",
 | 
			
		||||
//    many_to_many = [
 | 
			
		||||
//        {
 | 
			
		||||
//            name = friends,
 | 
			
		||||
//            entity: User,
 | 
			
		||||
//            link = { table = "user_friendships", from: "user1", to "user2" }
 | 
			
		||||
//        }
 | 
			
		||||
//    ]
 | 
			
		||||
//)]
 | 
			
		||||
#[derive(deluxe::ParseMetaItem)]
 | 
			
		||||
pub struct M2MRelationship {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub entity: syn::Type,
 | 
			
		||||
    pub table: String,
 | 
			
		||||
    #[deluxe(default = String::from("id"))]
 | 
			
		||||
    pub remote_id: String,
 | 
			
		||||
    pub link: M2MLink,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct Identifier {
 | 
			
		||||
    pub table: String,
 | 
			
		||||
    pub id: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct M2MRelationshipComplete {
 | 
			
		||||
    pub name: String,
 | 
			
		||||
    pub entity: syn::Type,
 | 
			
		||||
    pub local: Identifier,
 | 
			
		||||
    pub remote: Identifier,
 | 
			
		||||
    pub link: M2MLink,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl M2MRelationshipComplete {
 | 
			
		||||
    pub fn new(other: &M2MRelationship, local_table: &String, local_id: &String) -> Self {
 | 
			
		||||
        Self {
 | 
			
		||||
            name: other.name.clone(),
 | 
			
		||||
            entity: other.entity.clone(),
 | 
			
		||||
            link: other.link.clone(),
 | 
			
		||||
            local: Identifier {
 | 
			
		||||
                table: local_table.to_string(),
 | 
			
		||||
                id: local_id.to_string(),
 | 
			
		||||
            },
 | 
			
		||||
            remote: Identifier {
 | 
			
		||||
                table: other.table.clone(),
 | 
			
		||||
                id: other.remote_id.clone(),
 | 
			
		||||
            },
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<&M2MRelationshipComplete> for proc_macro2::TokenStream {
 | 
			
		||||
    fn from(value: &M2MRelationshipComplete) -> Self {
 | 
			
		||||
        let function = syn::Ident::new(
 | 
			
		||||
            &format!("get_{}", value.name),
 | 
			
		||||
            proc_macro2::Span::call_site(),
 | 
			
		||||
        );
 | 
			
		||||
        let entity = &value.entity;
 | 
			
		||||
        let query = format!(
 | 
			
		||||
            "
 | 
			
		||||
SELECT remote.*
 | 
			
		||||
FROM {} local
 | 
			
		||||
JOIN {} link ON link.{} = local.{}
 | 
			
		||||
JOIN {} remote ON link.{} = remote.{}
 | 
			
		||||
WHERE local.{} = $1
 | 
			
		||||
",
 | 
			
		||||
            value.local.table,
 | 
			
		||||
            value.link.table,
 | 
			
		||||
            value.link.from,
 | 
			
		||||
            value.local.id,
 | 
			
		||||
            value.remote.table,
 | 
			
		||||
            value.link.to,
 | 
			
		||||
            value.remote.id,
 | 
			
		||||
            value.local.id
 | 
			
		||||
        );
 | 
			
		||||
        quote! {
 | 
			
		||||
            pub async fn #function(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Vec<#entity>> {
 | 
			
		||||
                query_as!(#entity, #query, self.get_id()).fetch_all(pool).await
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(deluxe::ExtractAttributes, Clone)]
 | 
			
		||||
#[deluxe(attributes(crud))]
 | 
			
		||||
pub struct CrudFieldAttributes {
 | 
			
		||||
struct CrudFieldAttributes {
 | 
			
		||||
    #[deluxe(default = false)]
 | 
			
		||||
    pub id: bool,
 | 
			
		||||
    #[deluxe(default = None)]
 | 
			
		||||
    pub column: Option<String>,
 | 
			
		||||
    #[deluxe(default = None)]
 | 
			
		||||
    pub relation: Option<O2ORelationship>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// #[crud(
 | 
			
		||||
//     table = "profileId",
 | 
			
		||||
//     one_to_one = { name = profile, id = "id", entity = Profile, nullable }
 | 
			
		||||
//  )]
 | 
			
		||||
#[derive(deluxe::ParseMetaItem, Clone)]
 | 
			
		||||
pub struct O2ORelationship {
 | 
			
		||||
    pub entity: syn::Type,
 | 
			
		||||
    pub table: String,
 | 
			
		||||
    #[deluxe(default = String::from("id"))]
 | 
			
		||||
    pub remote_id: String,
 | 
			
		||||
    #[deluxe(default = false)]
 | 
			
		||||
    pub nullable: bool,
 | 
			
		||||
    pub name: String,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Clone)]
 | 
			
		||||
pub struct CrudField {
 | 
			
		||||
    pub ident: syn::Ident,
 | 
			
		||||
    pub field: syn::Field,
 | 
			
		||||
    pub ty: syn::Type,
 | 
			
		||||
    pub column: String,
 | 
			
		||||
    pub id: bool,
 | 
			
		||||
    pub ty: syn::Type,
 | 
			
		||||
    pub relation: Option<O2ORelationship>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl CrudField {
 | 
			
		||||
    pub fn new(field: &mut syn::Field) -> Self {
 | 
			
		||||
        let ident = field.clone().ident.unwrap();
 | 
			
		||||
        let ty = field.clone().ty;
 | 
			
		||||
        let attrs: CrudFieldAttributes =
 | 
			
		||||
            deluxe::extract_attributes(field).expect("Could not extract attributes from field");
 | 
			
		||||
        Self {
 | 
			
		||||
            ident: ident.clone(),
 | 
			
		||||
            field: field.to_owned(),
 | 
			
		||||
            column: attrs.column.unwrap_or_else(|| ident.to_string()),
 | 
			
		||||
            id: attrs.id,
 | 
			
		||||
            ty,
 | 
			
		||||
            relation: attrs.relation,
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl From<&CrudField> for proc_macro2::TokenStream {
 | 
			
		||||
    fn from(value: &CrudField) -> Self {
 | 
			
		||||
        let Some(relation) = value.relation.clone() else {
 | 
			
		||||
            return quote! {};
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        let function = syn::Ident::new(
 | 
			
		||||
            &format!("get_{}", relation.name),
 | 
			
		||||
            proc_macro2::Span::call_site(),
 | 
			
		||||
        );
 | 
			
		||||
        let entity = &relation.entity;
 | 
			
		||||
        let return_type = if relation.nullable {
 | 
			
		||||
            quote! { Option<#entity> }
 | 
			
		||||
        } else {
 | 
			
		||||
            quote! { #entity }
 | 
			
		||||
        };
 | 
			
		||||
        let query = format!(
 | 
			
		||||
            "SELECT * FROM {} WHERE {} = $1",
 | 
			
		||||
            relation.table, relation.remote_id
 | 
			
		||||
        );
 | 
			
		||||
        let local_ident = &value.field.ident;
 | 
			
		||||
        let fetch = if relation.nullable {
 | 
			
		||||
            quote! { fetch_optional }
 | 
			
		||||
        } else {
 | 
			
		||||
            quote! { fetch_one }
 | 
			
		||||
        };
 | 
			
		||||
        quote! {
 | 
			
		||||
            pub async fn #function(&value, pool: &::sqlx::PgPool) -> ::sqlx::Result<#return_type> {
 | 
			
		||||
                query_as!(#entity, #query, value.#local_ident).#fetch(pool).await
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1,142 +1,43 @@
 | 
			
		||||
use ir::{CrudField, CrudFieldAttributes, CrudStructAttributes};
 | 
			
		||||
use ir::CrudField;
 | 
			
		||||
use quote::quote;
 | 
			
		||||
use syn::DeriveInput;
 | 
			
		||||
 | 
			
		||||
mod ir;
 | 
			
		||||
mod relationships;
 | 
			
		||||
mod trait_implementation;
 | 
			
		||||
 | 
			
		||||
fn extract_crud_field_attrs(ast: &mut DeriveInput) -> deluxe::Result<(Vec<CrudField>, CrudField)> {
 | 
			
		||||
    let mut field_attrs: Vec<CrudField> = Vec::new();
 | 
			
		||||
    // let mut identifier: Option<CrudIdentifier> = None;
 | 
			
		||||
    let mut identifier: Option<CrudField> = None;
 | 
			
		||||
    let mut identifier_counter = 0;
 | 
			
		||||
    if let syn::Data::Struct(s) = &mut ast.data {
 | 
			
		||||
        for field in &mut s.fields {
 | 
			
		||||
            let ident = field.clone().ident.unwrap();
 | 
			
		||||
            let ty = field.clone().ty;
 | 
			
		||||
            let attrs: CrudFieldAttributes =
 | 
			
		||||
                deluxe::extract_attributes(field).expect("Could not extract attributes from field");
 | 
			
		||||
            let field = CrudField {
 | 
			
		||||
                ident: ident.clone(),
 | 
			
		||||
                field: field.to_owned(),
 | 
			
		||||
                column: attrs.column.unwrap_or_else(|| ident.to_string()),
 | 
			
		||||
                id: attrs.id,
 | 
			
		||||
                ty,
 | 
			
		||||
            };
 | 
			
		||||
            if attrs.id {
 | 
			
		||||
                identifier_counter += 1;
 | 
			
		||||
                identifier = Some(field.clone());
 | 
			
		||||
            }
 | 
			
		||||
            if identifier_counter > 1 {
 | 
			
		||||
                return Err(syn::Error::new_spanned(
 | 
			
		||||
                    field.field,
 | 
			
		||||
                    "Struct {name} can only have one identifier",
 | 
			
		||||
                ));
 | 
			
		||||
            }
 | 
			
		||||
            field_attrs.push(field);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    if identifier_counter < 1 {
 | 
			
		||||
        Err(syn::Error::new_spanned(
 | 
			
		||||
fn extract_crud_field_attrs(
 | 
			
		||||
    ast: &mut syn::DeriveInput,
 | 
			
		||||
) -> deluxe::Result<(Vec<CrudField>, CrudField)> {
 | 
			
		||||
    let syn::Data::Struct(s) = &mut ast.data else {
 | 
			
		||||
        return Err(syn::Error::new_spanned(
 | 
			
		||||
            ast,
 | 
			
		||||
            "Cannot apply to something other than a struct",
 | 
			
		||||
        ));
 | 
			
		||||
    };
 | 
			
		||||
    let fields = s
 | 
			
		||||
        .fields
 | 
			
		||||
        .clone()
 | 
			
		||||
        .into_iter()
 | 
			
		||||
        .map(|mut field| CrudField::new(&mut field))
 | 
			
		||||
        .collect::<Vec<CrudField>>();
 | 
			
		||||
    let identifiers: Vec<CrudField> = fields
 | 
			
		||||
        .clone()
 | 
			
		||||
        .into_iter()
 | 
			
		||||
        .filter(|field| field.id)
 | 
			
		||||
        .collect();
 | 
			
		||||
    match identifiers.len() {
 | 
			
		||||
        0 => Err(syn::Error::new_spanned(
 | 
			
		||||
            ast,
 | 
			
		||||
            "Struct {name} must have one identifier",
 | 
			
		||||
        ))
 | 
			
		||||
    } else {
 | 
			
		||||
        Ok((field_attrs, identifier.unwrap()))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_find_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream {
 | 
			
		||||
    let find_string = format!("SELECT * FROM {} WHERE {} = $1", table, id.column);
 | 
			
		||||
    let ty = &id.ty;
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn find(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<Option<Self>> {
 | 
			
		||||
            ::sqlx::query_as!(Self, #find_string, id)
 | 
			
		||||
                .fetch_optional(pool)
 | 
			
		||||
                .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_create_query(table: &str, fields: &[CrudField]) -> proc_macro2::TokenStream {
 | 
			
		||||
    let inputs: Vec<String> = (1..=fields.len()).map(|num| format!("${num}")).collect();
 | 
			
		||||
    let create_string = format!(
 | 
			
		||||
        "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
 | 
			
		||||
        table,
 | 
			
		||||
        fields
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|v| v.column.clone())
 | 
			
		||||
            .collect::<Vec<String>>()
 | 
			
		||||
            .join(", "),
 | 
			
		||||
        inputs.join(", ")
 | 
			
		||||
    );
 | 
			
		||||
    let field_idents: Vec<syn::Ident> = fields.iter().map(|f| f.ident.clone()).collect();
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn create(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
 | 
			
		||||
            ::sqlx::query_as!(
 | 
			
		||||
                Self,
 | 
			
		||||
                #create_string,
 | 
			
		||||
                #(self.#field_idents),*
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_one(pool)
 | 
			
		||||
            .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_update_query(
 | 
			
		||||
    table: &str,
 | 
			
		||||
    fields: &[CrudField],
 | 
			
		||||
    id: &CrudField,
 | 
			
		||||
) -> proc_macro2::TokenStream {
 | 
			
		||||
    let mut fields: Vec<&CrudField> = fields.iter().filter(|f| !f.id).collect();
 | 
			
		||||
    let update_columns = fields
 | 
			
		||||
        .iter()
 | 
			
		||||
        .enumerate()
 | 
			
		||||
        .map(|(i, &field)| format!("{} = ${}", field.column, i + 1))
 | 
			
		||||
        .collect::<Vec<String>>()
 | 
			
		||||
        .join(", ");
 | 
			
		||||
    let update_string = format!(
 | 
			
		||||
        "UPDATE {} SET {} WHERE {} = ${} RETURNING *",
 | 
			
		||||
        table,
 | 
			
		||||
        update_columns,
 | 
			
		||||
        id.column,
 | 
			
		||||
        fields.len() + 1
 | 
			
		||||
    );
 | 
			
		||||
    fields.push(id);
 | 
			
		||||
    let field_idents: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
 | 
			
		||||
            ::sqlx::query_as!(
 | 
			
		||||
                Self,
 | 
			
		||||
                #update_string,
 | 
			
		||||
                #(self.#field_idents),*
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_one(pool)
 | 
			
		||||
            .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_delete_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream {
 | 
			
		||||
    let delete_string = format!("DELETE FROM {} WHERE {} = $1", table, id.column);
 | 
			
		||||
    let ty = &id.ty;
 | 
			
		||||
    let ident = &id.ident;
 | 
			
		||||
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn delete_by_id(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<u64> {
 | 
			
		||||
            let rows_affected = ::sqlx::query!(#delete_string, id)
 | 
			
		||||
                .execute(pool)
 | 
			
		||||
                .await?
 | 
			
		||||
                .rows_affected();
 | 
			
		||||
            Ok(rows_affected)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        async fn delete(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<u64> {
 | 
			
		||||
            let rows_affected = ::sqlx::query!(#delete_string, self.#ident)
 | 
			
		||||
                .execute(pool)
 | 
			
		||||
                .await?
 | 
			
		||||
                .rows_affected();
 | 
			
		||||
            Ok(rows_affected)
 | 
			
		||||
        )),
 | 
			
		||||
        1 => Ok((fields, identifiers.first().unwrap().clone())),
 | 
			
		||||
        _ => {
 | 
			
		||||
            let id1 = identifiers.first().unwrap();
 | 
			
		||||
            let id2 = identifiers.get(1).unwrap();
 | 
			
		||||
            Err(syn::Error::new_spanned(id2.field.clone(), format!(
 | 
			
		||||
                "Field {} cannot be an identifier, {} already is one.\nOnly one identifier is supported.",
 | 
			
		||||
                id1.ident, id2.ident
 | 
			
		||||
            )))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -144,45 +45,15 @@ fn generate_delete_query(table: &str, id: &CrudField) -> proc_macro2::TokenStrea
 | 
			
		||||
pub fn crud_derive_macro2(
 | 
			
		||||
    item: proc_macro2::TokenStream,
 | 
			
		||||
) -> deluxe::Result<proc_macro2::TokenStream> {
 | 
			
		||||
    // parse
 | 
			
		||||
    let mut ast: DeriveInput = syn::parse2(item).expect("Failed to parse input");
 | 
			
		||||
 | 
			
		||||
    // extract struct attributes
 | 
			
		||||
    let CrudStructAttributes { table } =
 | 
			
		||||
    let mut ast: syn::DeriveInput = syn::parse2(item).expect("Failed to parse input");
 | 
			
		||||
    let struct_attrs: ir::CrudStructAttributes =
 | 
			
		||||
        deluxe::extract_attributes(&mut ast).expect("Could not extract attributes from struct");
 | 
			
		||||
 | 
			
		||||
    // extract field attributes
 | 
			
		||||
    let (fields, id) = extract_crud_field_attrs(&mut ast)?;
 | 
			
		||||
    let ty = &id.ty;
 | 
			
		||||
    let id_ident = &id.ident;
 | 
			
		||||
 | 
			
		||||
    // define impl variables
 | 
			
		||||
    let ident = &ast.ident;
 | 
			
		||||
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
 | 
			
		||||
 | 
			
		||||
    // generate
 | 
			
		||||
    let find_query = generate_find_query(&table, &id);
 | 
			
		||||
    let create_query = generate_create_query(&table, &fields);
 | 
			
		||||
    let update_query = generate_update_query(&table, &fields, &id);
 | 
			
		||||
    let delete_query = generate_delete_query(&table, &id);
 | 
			
		||||
    let trait_impl = trait_implementation::derive_trait(&ast, &struct_attrs.table, &fields, &id);
 | 
			
		||||
    let relationships = relationships::derive_relationships(&ast, &struct_attrs, &fields, &id);
 | 
			
		||||
    let code = quote! {
 | 
			
		||||
        impl #impl_generics Crud<#ty> for #ident #type_generics #where_clause {
 | 
			
		||||
            #find_query
 | 
			
		||||
 | 
			
		||||
            #create_query
 | 
			
		||||
 | 
			
		||||
            #update_query
 | 
			
		||||
 | 
			
		||||
            async fn create_or_update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
 | 
			
		||||
                if Self::find(pool, &self.#id_ident).await?.is_some() {
 | 
			
		||||
                    self.update(pool).await
 | 
			
		||||
                } else {
 | 
			
		||||
                    self.create(pool).await
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            #delete_query
 | 
			
		||||
        }
 | 
			
		||||
        #trait_impl
 | 
			
		||||
        #relationships
 | 
			
		||||
    };
 | 
			
		||||
    Ok(code)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										54
									
								
								gejdr-macros/src/crud/relationships.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								gejdr-macros/src/crud/relationships.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,54 @@
 | 
			
		||||
use std::str::FromStr;
 | 
			
		||||
 | 
			
		||||
use crate::crud::ir::M2MRelationshipComplete;
 | 
			
		||||
 | 
			
		||||
use super::ir::CrudField;
 | 
			
		||||
use proc_macro2::TokenStream;
 | 
			
		||||
use quote::quote;
 | 
			
		||||
 | 
			
		||||
fn join_token_streams(token_streams: &[TokenStream]) -> TokenStream {
 | 
			
		||||
    let newline = TokenStream::from_str("\n").unwrap();
 | 
			
		||||
    token_streams
 | 
			
		||||
        .iter()
 | 
			
		||||
        .cloned()
 | 
			
		||||
        .flat_map(|ts| std::iter::once(ts).chain(std::iter::once(newline.clone())))
 | 
			
		||||
        .collect()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn derive<T, P>(relationships: &[T], condition: P) -> TokenStream
 | 
			
		||||
where
 | 
			
		||||
    for<'a> &'a T: Into<TokenStream>,
 | 
			
		||||
    P: FnMut(&&T) -> bool,
 | 
			
		||||
{
 | 
			
		||||
    let implementations: Vec<TokenStream> = relationships
 | 
			
		||||
        .iter()
 | 
			
		||||
        .filter(condition)
 | 
			
		||||
        .map(std::convert::Into::into)
 | 
			
		||||
        .collect();
 | 
			
		||||
    join_token_streams(&implementations)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn derive_relationships(
 | 
			
		||||
    ast: &syn::DeriveInput,
 | 
			
		||||
    struct_attrs: &super::ir::CrudStructAttributes,
 | 
			
		||||
    fields: &[CrudField],
 | 
			
		||||
    id: &CrudField,
 | 
			
		||||
) -> TokenStream {
 | 
			
		||||
    let struct_name = &ast.ident;
 | 
			
		||||
    let one_to_one = derive(fields, |field| field.relation.is_none());
 | 
			
		||||
    let one_to_many = derive(&struct_attrs.one_to_many, |_| true);
 | 
			
		||||
    let many_to_many: Vec<M2MRelationshipComplete> = struct_attrs
 | 
			
		||||
        .many_to_many
 | 
			
		||||
        .iter()
 | 
			
		||||
        .map(|v| M2MRelationshipComplete::new(v, &struct_attrs.table, &id.column))
 | 
			
		||||
        .collect();
 | 
			
		||||
    let many_to_many = derive(&many_to_many, |_| true);
 | 
			
		||||
 | 
			
		||||
    quote! {
 | 
			
		||||
        impl #struct_name {
 | 
			
		||||
            #one_to_one
 | 
			
		||||
            #one_to_many
 | 
			
		||||
            #many_to_many
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										149
									
								
								gejdr-macros/src/crud/trait_implementation.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										149
									
								
								gejdr-macros/src/crud/trait_implementation.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,149 @@
 | 
			
		||||
use super::ir::CrudField;
 | 
			
		||||
use quote::quote;
 | 
			
		||||
 | 
			
		||||
fn generate_find_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream {
 | 
			
		||||
    let find_string = format!("SELECT * FROM {} WHERE {} = $1", table, id.column);
 | 
			
		||||
    let ty = &id.ty;
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn find(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<Option<Self>> {
 | 
			
		||||
            ::sqlx::query_as!(Self, #find_string, id)
 | 
			
		||||
                .fetch_optional(pool)
 | 
			
		||||
                .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_create_query(table: &str, fields: &[CrudField]) -> proc_macro2::TokenStream {
 | 
			
		||||
    let inputs: Vec<String> = (1..=fields.len()).map(|num| format!("${num}")).collect();
 | 
			
		||||
    let create_string = format!(
 | 
			
		||||
        "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
 | 
			
		||||
        table,
 | 
			
		||||
        fields
 | 
			
		||||
            .iter()
 | 
			
		||||
            .map(|v| v.column.clone())
 | 
			
		||||
            .collect::<Vec<String>>()
 | 
			
		||||
            .join(", "),
 | 
			
		||||
        inputs.join(", ")
 | 
			
		||||
    );
 | 
			
		||||
    let field_idents: Vec<syn::Ident> = fields.iter().map(|f| f.ident.clone()).collect();
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn create(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
 | 
			
		||||
            ::sqlx::query_as!(
 | 
			
		||||
                Self,
 | 
			
		||||
                #create_string,
 | 
			
		||||
                #(self.#field_idents),*
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_one(pool)
 | 
			
		||||
            .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_update_query(
 | 
			
		||||
    table: &str,
 | 
			
		||||
    fields: &[CrudField],
 | 
			
		||||
    id: &CrudField,
 | 
			
		||||
) -> proc_macro2::TokenStream {
 | 
			
		||||
    let mut fields: Vec<&CrudField> = fields.iter().filter(|f| !f.id).collect();
 | 
			
		||||
    let update_columns = fields
 | 
			
		||||
        .iter()
 | 
			
		||||
        .enumerate()
 | 
			
		||||
        .map(|(i, &field)| format!("{} = ${}", field.column, i + 1))
 | 
			
		||||
        .collect::<Vec<String>>()
 | 
			
		||||
        .join(", ");
 | 
			
		||||
    let update_string = format!(
 | 
			
		||||
        "UPDATE {} SET {} WHERE {} = ${} RETURNING *",
 | 
			
		||||
        table,
 | 
			
		||||
        update_columns,
 | 
			
		||||
        id.column,
 | 
			
		||||
        fields.len() + 1
 | 
			
		||||
    );
 | 
			
		||||
    fields.push(id);
 | 
			
		||||
    let field_idents: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
 | 
			
		||||
            ::sqlx::query_as!(
 | 
			
		||||
                Self,
 | 
			
		||||
                #update_string,
 | 
			
		||||
                #(self.#field_idents),*
 | 
			
		||||
            )
 | 
			
		||||
            .fetch_one(pool)
 | 
			
		||||
            .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_delete_query(table: &str, id: &CrudField) -> proc_macro2::TokenStream {
 | 
			
		||||
    let delete_string = format!("DELETE FROM {} WHERE {} = $1", table, id.column);
 | 
			
		||||
    let ty = &id.ty;
 | 
			
		||||
 | 
			
		||||
    quote! {
 | 
			
		||||
        async fn delete_by_id(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<u64> {
 | 
			
		||||
            let rows_affected = ::sqlx::query!(#delete_string, id)
 | 
			
		||||
                .execute(pool)
 | 
			
		||||
                .await?
 | 
			
		||||
                .rows_affected();
 | 
			
		||||
            Ok(rows_affected)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        async fn delete(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<u64> {
 | 
			
		||||
            let rows_affected = ::sqlx::query!(#delete_string, self.get_id())
 | 
			
		||||
                .execute(pool)
 | 
			
		||||
                .await?
 | 
			
		||||
                .rows_affected();
 | 
			
		||||
            Ok(rows_affected)
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn generate_get_id(id: &CrudField) -> proc_macro2::TokenStream {
 | 
			
		||||
    let ident = &id.ident;
 | 
			
		||||
    let ty = &id.ty;
 | 
			
		||||
    quote! {
 | 
			
		||||
        fn get_id(&self) -> &#ty {
 | 
			
		||||
            &self.#ident
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub fn derive_trait(
 | 
			
		||||
    ast: &syn::DeriveInput,
 | 
			
		||||
    table: &str,
 | 
			
		||||
    fields: &[CrudField],
 | 
			
		||||
    id: &CrudField,
 | 
			
		||||
) -> proc_macro2::TokenStream {
 | 
			
		||||
    let ty = &id.ty;
 | 
			
		||||
    let id_ident = &id.ident;
 | 
			
		||||
 | 
			
		||||
    // define impl variables
 | 
			
		||||
    let ident = &ast.ident;
 | 
			
		||||
    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
 | 
			
		||||
 | 
			
		||||
    // generate
 | 
			
		||||
    let get_id = generate_get_id(id);
 | 
			
		||||
    let find_query = generate_find_query(table, id);
 | 
			
		||||
    let create_query = generate_create_query(table, fields);
 | 
			
		||||
    let update_query = generate_update_query(table, fields, id);
 | 
			
		||||
    let delete_query = generate_delete_query(table, id);
 | 
			
		||||
    quote! {
 | 
			
		||||
        impl #impl_generics Crud<#ty> for #ident #type_generics #where_clause {
 | 
			
		||||
            #get_id
 | 
			
		||||
 | 
			
		||||
            #find_query
 | 
			
		||||
 | 
			
		||||
            #create_query
 | 
			
		||||
 | 
			
		||||
            #update_query
 | 
			
		||||
 | 
			
		||||
            async fn create_or_update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
 | 
			
		||||
                if Self::find(pool, &self.#id_ident).await?.is_some() {
 | 
			
		||||
                    self.update(pool).await
 | 
			
		||||
                } else {
 | 
			
		||||
                    self.create(pool).await
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            #delete_query
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@ -3,7 +3,6 @@
 | 
			
		||||
#![deny(clippy::nursery)]
 | 
			
		||||
#![allow(clippy::module_name_repetitions)]
 | 
			
		||||
#![allow(clippy::unused_async)]
 | 
			
		||||
#![allow(clippy::useless_let_if_seq)] // Reason: prevents some OpenApi structs from compiling
 | 
			
		||||
#![forbid(unsafe_code)]
 | 
			
		||||
 | 
			
		||||
//! Create ``SQLx`` CRUD code for a struct in Postgres.
 | 
			
		||||
@ -19,11 +18,12 @@
 | 
			
		||||
//! - update an entity in the database
 | 
			
		||||
//!
 | 
			
		||||
//!   SQL query: `UPDATE ... SET ... WHERE <id> = ... RETURNING *`
 | 
			
		||||
//! - delete an entity from the database using its id
 | 
			
		||||
//! - delete an entity from the database using its id or an id
 | 
			
		||||
//!   provided by the interface’s user
 | 
			
		||||
//!
 | 
			
		||||
//!   SQL query: `DELETE FROM ... WHERE <id> = ...`
 | 
			
		||||
//! - update an entity or create it if it does not already exist in
 | 
			
		||||
//! - the database
 | 
			
		||||
//!   the database
 | 
			
		||||
//!
 | 
			
		||||
//! This macro relies on the trait `Crud` found in the `gejdr-core`
 | 
			
		||||
//! crate.
 | 
			
		||||
@ -34,9 +34,16 @@
 | 
			
		||||
//! # Usage
 | 
			
		||||
//!
 | 
			
		||||
//! Add `#[crud(table = "my_table_name")]` atop of the structure,
 | 
			
		||||
//! after the `Crud` derive. You will also need to add `#[crud(id)]`
 | 
			
		||||
//! atop of the field of your struct that will be used as the
 | 
			
		||||
//! identifier of your entity.
 | 
			
		||||
//! after the `Crud` derive.
 | 
			
		||||
//!
 | 
			
		||||
//! ## Entity Identifier
 | 
			
		||||
//! You will also need to add `#[crud(id)]` atop of the field of your
 | 
			
		||||
//! struct that will be used as the identifier of your entity.
 | 
			
		||||
//!
 | 
			
		||||
//! ## Column Name
 | 
			
		||||
//! If the name of a field does not match the name of its related
 | 
			
		||||
//! column, you can use `#[crud(column = "...")]` to specify the
 | 
			
		||||
//! correct value.
 | 
			
		||||
//!
 | 
			
		||||
//! ```ignore
 | 
			
		||||
//! #[derive(Crud)]
 | 
			
		||||
@ -44,6 +51,7 @@
 | 
			
		||||
//! pub struct User {
 | 
			
		||||
//!     #[crud(id)]
 | 
			
		||||
//!     id: String,
 | 
			
		||||
//!     #[crud(column = "name")]
 | 
			
		||||
//!     username: String,
 | 
			
		||||
//!     created_at: Timestampz,
 | 
			
		||||
//!     last_updated: Timestampz,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user