use async_trait::async_trait; use std::result::Result as SResult; use std::sync::Arc; use tokio::sync::{Mutex, MutexGuard}; use tokio::time::error::Elapsed; use tokio::time::{Duration, timeout}; use tokio_modbus::client::Context; use crate::domain::relay::controller::{ControllerError, RelayController, Result}; use crate::domain::relay::types::{RelayId, RelayState}; use tokio_modbus::prelude::*; /// Modbus TCP relay controller for real hardware communication. /// /// This implementation communicates with physical Modbus relay hardware over TCP, /// supporting 8-channel relay control via the Modbus protocol. It provides thread-safe /// access using `Arc` and includes configurable timeout handling. pub struct ModbusRelayController { ctx: Arc>, timeout_duration: Duration, } impl std::fmt::Debug for ModbusRelayController { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ModbusRelayController") .field("timeout_duration", &self.timeout_duration) .field("ctx", &"") .finish() } } const ALL_ADDRS: tokio_modbus::Address = 0x0000; const FIRMWARE_ADDR: tokio_modbus::Address = 0x8000; impl ModbusRelayController { /// Creates a new Modbus relay controller connected to the specified device. /// /// Establishes a TCP connection to the Modbus device and configures timeout behavior. /// /// # Errors /// /// Returns `ControllerError::ConnectionError` if: /// - The host/port address is invalid /// - Connection to the Modbus device fails /// - The device is unreachable pub async fn new(host: &str, port: u16, slave_id: u8, timeout_secs: u64) -> Result { if slave_id != 1 { tracing::warn!("Device typically uses slave_id=1, got {slave_id}"); } let socket_addr = format!("{host}:{port}") .parse() .map_err(|e| ControllerError::ConnectionError(format!("Invalid address: {e}")))?; let ctx = tcp::connect_slave(socket_addr, Slave(slave_id)) .await .map_err(|e| ControllerError::ConnectionError(e.to_string()))?; Ok(Self { ctx: Arc::new(Mutex::new(ctx)), timeout_duration: Duration::from_secs(timeout_secs), }) } async fn context(&self) -> MutexGuard<'_, Context> { self.ctx.lock().await } fn handle_modbus_result( &self, result: SResult, tokio_modbus::Error>, Elapsed>, ) -> Result { result .map_err(|_| ControllerError::Timeout(self.timeout_duration.as_secs()))? .map_err(|e| ControllerError::ConnectionError(e.to_string()))? .map_err(|e| ControllerError::ModbusException(e.to_string())) } async fn read_coils_with_timeout(&self, addr: u16, count: u16) -> Result> { let result = timeout( self.timeout_duration, self.context().await.read_coils(addr, count), ) .await; self.handle_modbus_result(result) } async fn write_single_coil_with_timeout(&self, addr: u16, value: bool) -> Result<()> { let result = timeout( self.timeout_duration, self.context().await.write_single_coil(addr, value), ) .await; self.handle_modbus_result(result) } } #[async_trait] impl RelayController for ModbusRelayController { async fn read_relay_state(&self, id: RelayId) -> Result { let addr = id.to_modbus_address(); let coils = self.read_coils_with_timeout(addr, 1).await?; let state = RelayState::from( *coils .first() .ok_or_else(|| ControllerError::InvalidRelayId(id.as_u8()))?, ); tracing::debug!(target: "modbus", relay_id = id.as_u8(), ?state, "Read relay state"); Ok(state) } async fn write_relay_state(&self, id: RelayId, state: RelayState) -> Result<()> { let addr = id.to_modbus_address(); let value: bool = state.into(); self.write_single_coil_with_timeout(addr, value).await?; tracing::info!(target: "modbus", relay_id = id.as_u8(), ?state, "Wrote relay state"); Ok(()) } async fn read_all_states(&self) -> Result> { let coils = self.read_coils_with_timeout(ALL_ADDRS, 8).await?; let states: Vec = coils.into_iter().map(RelayState::from).collect(); tracing::debug!(target: "modbus", "Read all relay states"); Ok(states) } async fn write_all_states(&self, states: Vec) -> Result<()> { if states.len() != 8 { return Err(ControllerError::InvalidInput(format!( "Expected 8 relay states, got {}", states.len() ))); } let coils: Vec = states.iter().map(|&s| s.into()).collect(); let result = timeout( self.timeout_duration, self.context().await.write_multiple_coils(ALL_ADDRS, &coils), ) .await; self.handle_modbus_result(result)?; tracing::info!(target: "modbus", "Wrote all relay states"); Ok(()) } async fn check_connection(&self) -> Result<()> { // Try reading first coil as health check self.read_coils_with_timeout(ALL_ADDRS, 1).await?; Ok(()) } async fn get_firmware_version(&self) -> Result> { let result = timeout( self.timeout_duration, self.context() .await .read_holding_registers(FIRMWARE_ADDR, 1), ) .await; let result = self.handle_modbus_result(result)?; if let Some(&version_raw) = result.first() { let version = f32::from(version_raw) / 100.0; Ok(Some(format!("v{version:.2}"))) } else { Ok(None) } } } #[cfg(test)] #[path = "client_test.rs"] mod tests;