feat: implement preliminary composite primary key support

Add support for entities with composite primary keys using multiple
#[georm(id)] fields. Automatically generates {EntityName}Id structs for
type-safe composite key handling.

Features:
- Multi-field primary key detection and ID struct generation
- Full CRUD operations (find, create, update, delete, create_or_update)
- Proper SQL generation with AND clauses for composite keys
- Updated documNtation in README and lib.rs

Note: Relationships not yet supported for composite key entities
This commit is contained in:
2025-06-07 16:16:46 +02:00
parent 190c4d7b1d
commit 19284665e6
17 changed files with 712 additions and 75 deletions

View File

@@ -0,0 +1,87 @@
use super::ir::GeormField;
use quote::quote;
#[derive(Debug)]
pub enum IdType {
Simple {
field_name: syn::Ident,
field_type: syn::Type,
},
Composite {
fields: Vec<IdField>,
field_type: syn::Ident,
},
}
#[derive(Debug, Clone)]
pub struct IdField {
pub name: syn::Ident,
pub ty: syn::Type,
}
fn field_to_code(field: &GeormField) -> proc_macro2::TokenStream {
let ident = field.ident.clone();
let ty = field.ty.clone();
quote! {
pub #ident: #ty
}
}
fn generate_struct(
ast: &syn::DeriveInput,
fields: &[GeormField],
) -> (syn::Ident, proc_macro2::TokenStream) {
let struct_name = &ast.ident;
let id_struct_name = quote::format_ident!("{struct_name}Id");
let vis = &ast.vis;
let fields: Vec<proc_macro2::TokenStream> = fields
.iter()
.filter_map(|field| {
if field.id {
Some(field_to_code(field))
} else {
None
}
})
.collect();
let code = quote! {
#vis struct #id_struct_name {
#(#fields),*
}
};
(id_struct_name, code)
}
pub fn create_primary_key(
ast: &syn::DeriveInput,
fields: &[GeormField],
) -> (IdType, proc_macro2::TokenStream) {
let georm_id_fields: Vec<&GeormField> = fields.iter().filter(|field| field.id).collect();
let id_fields: Vec<IdField> = georm_id_fields
.iter()
.map(|field| IdField {
name: field.ident.clone(),
ty: field.ty.clone(),
})
.collect();
match id_fields.len() {
0 => panic!("No ID field found"),
1 => (
IdType::Simple {
field_name: id_fields[0].name.clone(),
field_type: id_fields[0].ty.clone(),
},
quote! {},
),
_ => {
let (struct_name, struct_code) = generate_struct(ast, fields);
(
IdType::Composite {
fields: id_fields.clone(),
field_type: struct_name,
},
struct_code,
)
}
}
}

View File

@@ -138,7 +138,6 @@ pub fn derive_defaultable_struct(
);
quote! {
#[derive(Debug, Clone)]
#vis struct #defaultable_struct_name {
#(#defaultable_fields),*
}

View File

@@ -31,14 +31,14 @@ pub struct M2MRelationshipComplete {
}
impl M2MRelationshipComplete {
pub fn new(other: &M2MRelationship, local_table: &String, local_id: String) -> Self {
pub fn new(other: &M2MRelationship, local_table: &String, local_id: &String) -> Self {
Self {
name: other.name.clone(),
entity: other.entity.clone(),
link: other.link.clone(),
local: Identifier {
table: local_table.to_string(),
id: local_id,
id: local_id.to_string(),
},
remote: Identifier {
table: other.table.clone(),

View File

@@ -1,14 +1,13 @@
use ir::GeormField;
use quote::quote;
mod composite_keys;
mod defaultable_struct;
mod ir;
mod relationships;
mod trait_implementation;
fn extract_georm_field_attrs(
ast: &mut syn::DeriveInput,
) -> deluxe::Result<(Vec<GeormField>, GeormField)> {
fn extract_georm_field_attrs(ast: &mut syn::DeriveInput) -> deluxe::Result<Vec<GeormField>> {
let syn::Data::Struct(s) = &mut ast.data else {
return Err(syn::Error::new_spanned(
ast,
@@ -26,23 +25,13 @@ fn extract_georm_field_attrs(
.into_iter()
.filter(|field| field.id)
.collect();
match identifiers.len() {
0 => Err(syn::Error::new_spanned(
if identifiers.is_empty() {
Err(syn::Error::new_spanned(
ast,
"Struct {name} must have one identifier",
)),
1 => Ok((fields, identifiers.first().unwrap().clone())),
_ => {
let id1 = identifiers.first().unwrap();
let id2 = identifiers.get(1).unwrap();
Err(syn::Error::new_spanned(
id2.field.clone(),
format!(
"Field {} cannot be an identifier, {} already is one.\nOnly one identifier is supported.",
id1.ident, id2.ident
),
))
}
))
} else {
Ok(fields)
}
}
@@ -52,16 +41,23 @@ pub fn georm_derive_macro2(
let mut ast: syn::DeriveInput = syn::parse2(item).expect("Failed to parse input");
let struct_attrs: ir::GeormStructAttributes =
deluxe::extract_attributes(&mut ast).expect("Could not extract attributes from struct");
let (fields, id) = extract_georm_field_attrs(&mut ast)?;
let relationships = relationships::derive_relationships(&ast, &struct_attrs, &fields, &id);
let trait_impl = trait_implementation::derive_trait(&ast, &struct_attrs.table, &fields, &id);
let fields = extract_georm_field_attrs(&mut ast)?;
let defaultable_struct =
defaultable_struct::derive_defaultable_struct(&ast, &struct_attrs, &fields);
let from_row_impl = generate_from_row_impl(&ast, &fields);
let (identifier, id_struct) = composite_keys::create_primary_key(&ast, &fields);
let relationships =
relationships::derive_relationships(&ast, &struct_attrs, &fields, &identifier);
let trait_impl =
trait_implementation::derive_trait(&ast, &struct_attrs.table, &fields, &identifier);
let code = quote! {
#id_struct
#defaultable_struct
#relationships
#trait_impl
#defaultable_struct
#from_row_impl
};
Ok(code)

View File

@@ -2,6 +2,7 @@ use std::str::FromStr;
use crate::georm::ir::m2m_relationship::M2MRelationshipComplete;
use super::composite_keys::IdType;
use super::ir::GeormField;
use proc_macro2::TokenStream;
use quote::quote;
@@ -28,8 +29,24 @@ pub fn derive_relationships(
ast: &syn::DeriveInput,
struct_attrs: &super::ir::GeormStructAttributes,
fields: &[GeormField],
id: &GeormField,
id: &IdType,
) -> TokenStream {
let id = match id {
IdType::Simple {
field_name,
field_type: _,
} => field_name.to_string(),
IdType::Composite {
fields: _,
field_type: _,
} => {
eprintln!(
"Warning: entity {}: Relationships are not supported for entities with composite primary keys yet",
ast.ident
);
return quote! {};
}
};
let struct_name = &ast.ident;
let one_to_one_local = derive(fields);
let one_to_one_remote = derive(&struct_attrs.one_to_one);
@@ -37,7 +54,7 @@ pub fn derive_relationships(
let many_to_many: Vec<M2MRelationshipComplete> = struct_attrs
.many_to_many
.iter()
.map(|v| M2MRelationshipComplete::new(v, &struct_attrs.table, id.ident.to_string()))
.map(|v| M2MRelationshipComplete::new(v, &struct_attrs.table, &id))
.collect();
let many_to_many = derive(&many_to_many);

View File

@@ -1,3 +1,4 @@
use super::composite_keys::IdType;
use super::ir::GeormField;
use quote::quote;
@@ -10,14 +11,38 @@ fn generate_find_all_query(table: &str) -> proc_macro2::TokenStream {
}
}
fn generate_find_query(table: &str, id: &GeormField) -> proc_macro2::TokenStream {
let find_string = format!("SELECT * FROM {table} WHERE {} = $1", id.ident);
let ty = &id.ty;
quote! {
async fn find(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<Option<Self>> {
::sqlx::query_as!(Self, #find_string, id)
.fetch_optional(pool)
.await
fn generate_find_query(table: &str, id: &IdType) -> proc_macro2::TokenStream {
match id {
IdType::Simple {
field_name,
field_type,
} => {
let find_string = format!("SELECT * FROM {table} WHERE {} = $1", field_name);
quote! {
async fn find(pool: &::sqlx::PgPool, id: &#field_type) -> ::sqlx::Result<Option<Self>> {
::sqlx::query_as!(Self, #find_string, id)
.fetch_optional(pool)
.await
}
}
}
IdType::Composite { fields, field_type } => {
let id_match_string = fields
.iter()
.enumerate()
.map(|(i, field)| format!("{} = ${}", field.name, i + 1))
.collect::<Vec<String>>()
.join(" AND ");
let id_members: Vec<syn::Ident> =
fields.iter().map(|field| field.name.clone()).collect();
let find_string = format!("SELECT * FROM {table} WHERE {id_match_string}");
quote! {
async fn find(pool: &::sqlx::PgPool, id: &#field_type) -> ::sqlx::Result<Option<Self>> {
::sqlx::query_as!(Self, #find_string, #(id.#id_members),*)
.fetch_optional(pool)
.await
}
}
}
}
}
@@ -50,28 +75,42 @@ fn generate_create_query(table: &str, fields: &[GeormField]) -> proc_macro2::Tok
fn generate_update_query(
table: &str,
fields: &[GeormField],
id: &GeormField,
id: &IdType,
) -> proc_macro2::TokenStream {
let mut fields: Vec<&GeormField> = fields.iter().filter(|f| !f.id).collect();
let update_columns = fields
let non_id_fields: Vec<syn::Ident> = fields
.iter()
.filter_map(|f| if f.id { None } else { Some(f.ident.clone()) })
.collect();
let update_columns = non_id_fields
.iter()
.enumerate()
.map(|(i, &field)| format!("{} = ${}", field.ident, i + 1))
.map(|(i, field)| format!("{} = ${}", field, i + 1))
.collect::<Vec<String>>()
.join(", ");
let update_string = format!(
"UPDATE {table} SET {update_columns} WHERE {} = ${} RETURNING *",
id.ident,
fields.len() + 1
);
fields.push(id);
let field_idents: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect();
let mut all_fields = non_id_fields.clone();
let where_clause = match id {
IdType::Simple { field_name, .. } => {
let where_clause = format!("{} = ${}", field_name, non_id_fields.len() + 1);
all_fields.push(field_name.clone());
where_clause
}
IdType::Composite { fields, .. } => fields
.iter()
.enumerate()
.map(|(i, field)| {
let where_clause = format!("{} = ${}", field.name, non_id_fields.len() + i + 1);
all_fields.push(field.name.clone());
where_clause
})
.collect::<Vec<String>>()
.join(" AND "),
};
let update_string =
format!("UPDATE {table} SET {update_columns} WHERE {where_clause} RETURNING *");
quote! {
async fn update(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<Self> {
::sqlx::query_as!(
Self,
#update_string,
#(self.#field_idents),*
Self, #update_string, #(self.#all_fields),*
)
.fetch_one(pool)
.await
@@ -79,12 +118,31 @@ fn generate_update_query(
}
}
fn generate_delete_query(table: &str, id: &GeormField) -> proc_macro2::TokenStream {
let delete_string = format!("DELETE FROM {table} WHERE {} = $1", id.ident);
let ty = &id.ty;
fn generate_delete_query(table: &str, id: &IdType) -> proc_macro2::TokenStream {
let where_clause = match id {
IdType::Simple { field_name, .. } => format!("{} = $1", field_name),
IdType::Composite { fields, .. } => fields
.iter()
.enumerate()
.map(|(i, field)| format!("{} = ${}", field.name, i + 1))
.collect::<Vec<String>>()
.join(" AND "),
};
let query_args = match id {
IdType::Simple { .. } => quote! { id },
IdType::Composite { fields, .. } => {
let fields: Vec<syn::Ident> = fields.iter().map(|f| f.name.clone()).collect();
quote! { #(id.#fields), * }
}
};
let id_type = match id {
IdType::Simple { field_type, .. } => quote! { #field_type },
IdType::Composite { field_type, .. } => quote! { #field_type },
};
let delete_string = format!("DELETE FROM {table} WHERE {where_clause}");
quote! {
async fn delete_by_id(pool: &::sqlx::PgPool, id: &#ty) -> ::sqlx::Result<u64> {
let rows_affected = ::sqlx::query!(#delete_string, id)
async fn delete_by_id(pool: &::sqlx::PgPool, id: &#id_type) -> ::sqlx::Result<u64> {
let rows_affected = ::sqlx::query!(#delete_string, #query_args)
.execute(pool)
.await?
.rows_affected();
@@ -92,7 +150,7 @@ fn generate_delete_query(table: &str, id: &GeormField) -> proc_macro2::TokenStre
}
async fn delete(&self, pool: &::sqlx::PgPool) -> ::sqlx::Result<u64> {
Self::delete_by_id(pool, self.get_id()).await
Self::delete_by_id(pool, &self.get_id()).await
}
}
}
@@ -100,7 +158,7 @@ fn generate_delete_query(table: &str, id: &GeormField) -> proc_macro2::TokenStre
fn generate_upsert_query(
table: &str,
fields: &[GeormField],
id: &GeormField,
id: &IdType,
) -> proc_macro2::TokenStream {
let inputs: Vec<String> = (1..=fields.len()).map(|num| format!("${num}")).collect();
let columns = fields
@@ -109,6 +167,16 @@ fn generate_upsert_query(
.collect::<Vec<String>>()
.join(", ");
let primary_key: proc_macro2::TokenStream = match id {
IdType::Simple { field_name, .. } => quote! {#field_name},
IdType::Composite { fields, .. } => {
let field_names: Vec<syn::Ident> = fields.iter().map(|f| f.name.clone()).collect();
quote! {
#(#field_names),*
}
}
};
// For ON CONFLICT DO UPDATE, exclude the ID field from updates
let update_assignments = fields
.iter()
@@ -120,7 +188,7 @@ fn generate_upsert_query(
let upsert_string = format!(
"INSERT INTO {table} ({columns}) VALUES ({}) ON CONFLICT ({}) DO UPDATE SET {update_assignments} RETURNING *",
inputs.join(", "),
id.ident
primary_key
);
let field_idents: Vec<syn::Ident> = fields.iter().map(|f| f.ident.clone()).collect();
@@ -138,12 +206,27 @@ fn generate_upsert_query(
}
}
fn generate_get_id(id: &GeormField) -> proc_macro2::TokenStream {
let ident = &id.ident;
let ty = &id.ty;
quote! {
fn get_id(&self) -> &#ty {
&self.#ident
fn generate_get_id(id: &IdType) -> proc_macro2::TokenStream {
match id {
IdType::Simple {
field_name,
field_type,
} => {
quote! {
fn get_id(&self) -> #field_type {
self.#field_name.clone()
}
}
}
IdType::Composite { fields, field_type } => {
let field_names: Vec<syn::Ident> = fields.iter().map(|f| f.name.clone()).collect();
quote! {
fn get_id(&self) -> #field_type {
#field_type {
#(#field_names: self.#field_names),*
}
}
}
}
}
}
@@ -152,9 +235,12 @@ pub fn derive_trait(
ast: &syn::DeriveInput,
table: &str,
fields: &[GeormField],
id: &GeormField,
id: &IdType,
) -> proc_macro2::TokenStream {
let ty = &id.ty;
let ty = match id {
IdType::Simple { field_type, .. } => quote! {#field_type},
IdType::Composite { field_type, .. } => quote! {#field_type},
};
// define impl variables
let ident = &ast.ident;