neo3/neo_clients/circuit_breaker.rs

1use crate::neo_error::{Neo3Error, Neo3Result};
2use std::{
3	sync::{
4		atomic::{AtomicU32, AtomicU64, Ordering},
5		Arc,
6	},
7	time::{Duration, Instant},
8};
9use tokio::sync::RwLock;
10
/// Circuit breaker states.
///
/// The enum is fieldless, so it is `Copy`: callers can pass and compare states by
/// value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CircuitState {
	/// Circuit is closed, requests flow normally
	Closed,
	/// Circuit is open, requests are rejected immediately
	Open,
	/// Circuit is half-open, testing if service has recovered
	HalfOpen,
}
21
22impl Default for CircuitState {
23	fn default() -> Self {
24		CircuitState::Closed
25	}
26}
27
/// Tunable thresholds and timings for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
	/// How many failures trip the breaker from Closed into Open.
	pub failure_threshold: u32,
	/// How long the breaker stays Open before it lets probe requests through (HalfOpen).
	pub timeout: Duration,
	/// How many successful probes are required to go from HalfOpen back to Closed.
	pub success_threshold: u32,
	/// Window within which failures are counted toward `failure_threshold`.
	pub failure_window: Duration,
	/// Upper bound on requests admitted while HalfOpen.
	pub half_open_max_requests: u32,
}
42
43impl Default for CircuitBreakerConfig {
44	fn default() -> Self {
45		Self {
46			failure_threshold: 5,
47			timeout: Duration::from_secs(60),
48			success_threshold: 3,
49			failure_window: Duration::from_secs(60),
50			half_open_max_requests: 3,
51		}
52	}
53}
54
55/// Circuit breaker statistics
56#[derive(Debug, Default)]
57pub struct CircuitBreakerStats {
58	pub total_requests: u64,
59	pub successful_requests: u64,
60	pub failed_requests: u64,
61	pub rejected_requests: u64,
62	pub state_transitions: u64,
63	pub current_state: CircuitState,
64	pub last_failure_time: Option<Instant>,
65	pub last_success_time: Option<Instant>,
66}
67
68/// Circuit breaker implementation for protecting against cascading failures
69pub struct CircuitBreaker {
70	config: CircuitBreakerConfig,
71	state: Arc<RwLock<CircuitState>>,
72	failure_count: AtomicU32,
73	success_count: AtomicU32,
74	half_open_requests: AtomicU32,
75	last_failure_time: Arc<RwLock<Option<Instant>>>,
76	last_success_time: Arc<RwLock<Option<Instant>>>,
77	stats: Arc<RwLock<CircuitBreakerStats>>,
78}
79
80impl CircuitBreaker {
81	/// Create a new circuit breaker with the given configuration
82	pub fn new(config: CircuitBreakerConfig) -> Self {
83		Self {
84			config,
85			state: Arc::new(RwLock::new(CircuitState::Closed)),
86			failure_count: AtomicU32::new(0),
87			success_count: AtomicU32::new(0),
88			half_open_requests: AtomicU32::new(0),
89			last_failure_time: Arc::new(RwLock::new(None)),
90			last_success_time: Arc::new(RwLock::new(None)),
91			stats: Arc::new(RwLock::new(CircuitBreakerStats::default())),
92		}
93	}
94
95	/// Execute a request through the circuit breaker
96	pub async fn call<F, T>(&self, operation: F) -> Neo3Result<T>
97	where
98		F: std::future::Future<Output = Neo3Result<T>>,
99	{
100		// Update total requests
101		{
102			let mut stats = self.stats.write().await;
103			stats.total_requests += 1;
104		}
105
106		// Check if we should allow the request
107		if !self.should_allow_request().await {
108			let mut stats = self.stats.write().await;
109			stats.rejected_requests += 1;
110			return Err(Neo3Error::Network(crate::neo_error::NetworkError::RateLimitExceeded));
111		}
112
113		// Execute the operation
114		match operation.await {
115			Ok(result) => {
116				self.on_success().await;
117				Ok(result)
118			},
119			Err(error) => {
120				self.on_failure().await;
121				Err(error)
122			},
123		}
124	}
125
126	/// Check if a request should be allowed based on current state
127	async fn should_allow_request(&self) -> bool {
128		let state = self.state.read().await;
129		match *state {
130			CircuitState::Closed => true,
131			CircuitState::Open => {
132				// Check if timeout has elapsed
133				if let Some(last_failure) = *self.last_failure_time.read().await {
134					if last_failure.elapsed() >= self.config.timeout {
135						drop(state);
136						self.transition_to_half_open().await;
137						true
138					} else {
139						false
140					}
141				} else {
142					false
143				}
144			},
145			CircuitState::HalfOpen => {
146				// Allow limited requests in half-open state
147				let current_requests = self.half_open_requests.load(Ordering::Relaxed);
148				current_requests < self.config.half_open_max_requests
149			},
150		}
151	}
152
153	/// Handle successful request
154	async fn on_success(&self) {
155		let mut stats = self.stats.write().await;
156		stats.successful_requests += 1;
157		stats.last_success_time = Some(Instant::now());
158		drop(stats);
159
160		*self.last_success_time.write().await = Some(Instant::now());
161
162		let state = self.state.read().await;
163		match *state {
164			CircuitState::Closed => {
165				// Reset failure count on success
166				self.failure_count.store(0, Ordering::Relaxed);
167			},
168			CircuitState::HalfOpen => {
169				let success_count = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
170				if success_count >= self.config.success_threshold {
171					drop(state);
172					self.transition_to_closed().await;
173				}
174			},
175			CircuitState::Open => {
176				// This shouldn't happen, but reset if it does
177				drop(state);
178				self.transition_to_closed().await;
179			},
180		}
181	}
182
183	/// Handle failed request
184	async fn on_failure(&self) {
185		let mut stats = self.stats.write().await;
186		stats.failed_requests += 1;
187		stats.last_failure_time = Some(Instant::now());
188		drop(stats);
189
190		*self.last_failure_time.write().await = Some(Instant::now());
191
192		let state = self.state.read().await;
193		match *state {
194			CircuitState::Closed => {
195				let failure_count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
196				if failure_count >= self.config.failure_threshold {
197					drop(state);
198					self.transition_to_open().await;
199				}
200			},
201			CircuitState::HalfOpen => {
202				// Any failure in half-open state transitions back to open
203				drop(state);
204				self.transition_to_open().await;
205			},
206			CircuitState::Open => {
207				// Already open, nothing to do
208			},
209		}
210	}
211
212	/// Transition to closed state
213	async fn transition_to_closed(&self) {
214		let mut state = self.state.write().await;
215		if *state != CircuitState::Closed {
216			*state = CircuitState::Closed;
217			self.failure_count.store(0, Ordering::Relaxed);
218			self.success_count.store(0, Ordering::Relaxed);
219			self.half_open_requests.store(0, Ordering::Relaxed);
220
221			let mut stats = self.stats.write().await;
222			stats.state_transitions += 1;
223			stats.current_state = CircuitState::Closed;
224		}
225	}
226
227	/// Transition to open state
228	async fn transition_to_open(&self) {
229		let mut state = self.state.write().await;
230		if *state != CircuitState::Open {
231			*state = CircuitState::Open;
232			self.success_count.store(0, Ordering::Relaxed);
233			self.half_open_requests.store(0, Ordering::Relaxed);
234
235			let mut stats = self.stats.write().await;
236			stats.state_transitions += 1;
237			stats.current_state = CircuitState::Open;
238		}
239	}
240
241	/// Transition to half-open state
242	async fn transition_to_half_open(&self) {
243		let mut state = self.state.write().await;
244		if *state != CircuitState::HalfOpen {
245			*state = CircuitState::HalfOpen;
246			self.success_count.store(0, Ordering::Relaxed);
247			self.half_open_requests.store(0, Ordering::Relaxed);
248
249			let mut stats = self.stats.write().await;
250			stats.state_transitions += 1;
251			stats.current_state = CircuitState::HalfOpen;
252		}
253	}
254
255	/// Get current circuit breaker state
256	pub async fn get_state(&self) -> CircuitState {
257		let state = self.state.read().await;
258		state.clone()
259	}
260
261	/// Get circuit breaker statistics
262	pub async fn get_stats(&self) -> CircuitBreakerStats {
263		let stats = self.stats.read().await;
264		CircuitBreakerStats {
265			total_requests: stats.total_requests,
266			successful_requests: stats.successful_requests,
267			failed_requests: stats.failed_requests,
268			rejected_requests: stats.rejected_requests,
269			state_transitions: stats.state_transitions,
270			current_state: stats.current_state.clone(),
271			last_failure_time: stats.last_failure_time,
272			last_success_time: stats.last_success_time,
273		}
274	}
275
276	/// Reset the circuit breaker to closed state
277	pub async fn reset(&self) {
278		self.transition_to_closed().await;
279		*self.last_failure_time.write().await = None;
280		*self.last_success_time.write().await = None;
281
282		let mut stats = self.stats.write().await;
283		*stats = CircuitBreakerStats::default();
284	}
285
286	/// Force the circuit breaker to open state
287	pub async fn force_open(&self) {
288		self.transition_to_open().await;
289	}
290
291	/// Get failure rate (failures / total requests)
292	pub async fn get_failure_rate(&self) -> f64 {
293		let stats = self.stats.read().await;
294		if stats.total_requests == 0 {
295			0.0
296		} else {
297			stats.failed_requests as f64 / stats.total_requests as f64
298		}
299	}
300}
301
#[cfg(test)]
mod tests {
	use super::*;
	use tokio::time::{sleep, Duration};

	/// Send one request whose operation succeeds.
	async fn succeed(cb: &CircuitBreaker) -> Neo3Result<()> {
		cb.call(async { Ok::<(), Neo3Error>(()) }).await
	}

	/// Send one request whose operation fails with a connection error.
	async fn fail(cb: &CircuitBreaker) -> Neo3Result<()> {
		cb.call(async {
			Err::<(), Neo3Error>(Neo3Error::Network(
				crate::neo_error::NetworkError::ConnectionFailed("test".to_string()),
			))
		})
		.await
	}

	#[tokio::test]
	async fn test_circuit_breaker_closed_state() {
		let cb =
			CircuitBreaker::new(CircuitBreakerConfig { failure_threshold: 3, ..Default::default() });

		// A run of successes must never trip the breaker.
		for _ in 0..5 {
			assert!(succeed(&cb).await.is_ok());
		}

		assert_eq!(cb.get_state().await, CircuitState::Closed);
	}

	#[tokio::test]
	async fn test_circuit_breaker_opens_on_failures() {
		let cb =
			CircuitBreaker::new(CircuitBreakerConfig { failure_threshold: 3, ..Default::default() });

		// Exactly `failure_threshold` failures should open the circuit.
		for _ in 0..3 {
			assert!(fail(&cb).await.is_err());
		}

		assert_eq!(cb.get_state().await, CircuitState::Open);
	}

	#[tokio::test]
	async fn test_circuit_breaker_half_open_transition() {
		let cb = CircuitBreaker::new(CircuitBreakerConfig {
			failure_threshold: 2,
			timeout: Duration::from_millis(100),
			..Default::default()
		});

		for _ in 0..2 {
			let _ = fail(&cb).await;
		}
		assert_eq!(cb.get_state().await, CircuitState::Open);

		// Outlast the open timeout so the next call becomes a probe.
		sleep(Duration::from_millis(150)).await;

		assert!(succeed(&cb).await.is_ok());
		assert_eq!(cb.get_state().await, CircuitState::HalfOpen);
	}

	#[tokio::test]
	async fn test_circuit_breaker_stats() {
		let cb = CircuitBreaker::new(CircuitBreakerConfig::default());

		let _ = succeed(&cb).await;
		let _ = fail(&cb).await;

		let stats = cb.get_stats().await;
		assert_eq!(stats.total_requests, 2);
		assert_eq!(stats.successful_requests, 1);
		assert_eq!(stats.failed_requests, 1);
	}
}