diff --git a/README.md b/README.md index 61d63df..0501255 100644 --- a/README.md +++ b/README.md @@ -506,7 +506,6 @@ Georm is designed for zero runtime overhead: ### High Priority - **Transaction Support**: Comprehensive transaction handling with atomic operations -- **Race Condition Fix**: Database-native UPSERT operations to replace current `create_or_update` ### Medium Priority - **Multi-Database Support**: MySQL and SQLite support with feature flags diff --git a/georm-macros/src/georm/trait_implementation.rs b/georm-macros/src/georm/trait_implementation.rs index 6da75c7..4bc3647 100644 --- a/georm-macros/src/georm/trait_implementation.rs +++ b/georm-macros/src/georm/trait_implementation.rs @@ -97,6 +97,47 @@ fn generate_delete_query(table: &str, id: &GeormField) -> proc_macro2::TokenStre } } +fn generate_upsert_query( + table: &str, + fields: &[GeormField], + id: &GeormField, +) -> proc_macro2::TokenStream { + let inputs: Vec = (1..=fields.len()).map(|num| format!("${num}")).collect(); + let columns = fields + .iter() + .map(|f| f.ident.to_string()) + .collect::>() + .join(", "); + + // For ON CONFLICT DO UPDATE, exclude the ID field from updates + let update_assignments = fields + .iter() + .filter(|f| !f.id) + .map(|f| format!("{} = EXCLUDED.{}", f.ident, f.ident)) + .collect::>() + .join(", "); + + let upsert_string = format!( + "INSERT INTO {table} ({columns}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {update_assignments} RETURNING *", + inputs.join(", "), + id.ident + ); + + let field_idents: Vec = fields.iter().map(|f| f.ident.clone()).collect(); + + quote! { + async fn create_or_update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result { + ::sqlx::query_as!( + Self, + #upsert_string, + #(self.#field_idents),* + ) + .fetch_one(pool) + .await + } + } +} + fn generate_get_id(id: &GeormField) -> proc_macro2::TokenStream { let ident = &id.ident; let ty = &id.ty; @@ -125,6 +166,7 @@ pub fn derive_trait( let find_query = generate_find_query(table, id); let create_query = generate_create_query(table, fields); let update_query = generate_update_query(table, fields, id); + let upsert_query = generate_upsert_query(table, fields, id); let delete_query = generate_delete_query(table, id); quote! { impl #impl_generics Georm<#ty> for #ident #type_generics #where_clause { @@ -133,6 +175,7 @@ pub fn derive_trait( #find_query #create_query #update_query + #upsert_query #delete_query } } diff --git a/src/entity.rs b/src/entity.rs index 3f7759d..4b17100 100644 --- a/src/entity.rs +++ b/src/entity.rs @@ -50,18 +50,9 @@ pub trait Georm { fn create_or_update( &self, pool: &sqlx::PgPool, - ) -> impl ::std::future::Future> + ) -> impl std::future::Future> + Send where - Self: Sized, - { - async { - if Self::find(pool, self.get_id()).await?.is_some() { - self.update(pool).await - } else { - self.create(pool).await - } - } - } + Self: Sized; /// Delete the entity from the database if it exists. ///