1use crate::neo_error::{Neo3Error, Neo3Result};
2use std::{
3 sync::{
4 atomic::{AtomicU32, AtomicU64, Ordering},
5 Arc,
6 },
7 time::{Duration, Instant},
8};
9use tokio::sync::RwLock;
10
/// Operating mode of a [`CircuitBreaker`].
///
/// A new breaker starts in [`CircuitState::Closed`] (the derived `Default`).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum CircuitState {
    /// Calls are admitted normally; failures are counted toward tripping the breaker.
    #[default]
    Closed,
    /// Calls are rejected without running until the open timeout has elapsed.
    Open,
    /// A limited number of trial calls are admitted to probe whether the dependency recovered.
    HalfOpen,
}
27
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Failures counted in the closed state before the breaker opens.
    pub failure_threshold: u32,
    /// How long an open breaker waits before letting a trial request through (half-open).
    pub timeout: Duration,
    /// Successful half-open calls required to close the breaker again.
    pub success_threshold: u32,
    /// NOTE(review): not read by `CircuitBreaker` in this file; presumably meant to bound the
    /// period over which failures are counted — confirm before relying on it.
    pub failure_window: Duration,
    /// Limit on half-open admissions, compared against the breaker's half-open request counter.
    pub half_open_max_requests: u32,
}
42
43impl Default for CircuitBreakerConfig {
44 fn default() -> Self {
45 Self {
46 failure_threshold: 5,
47 timeout: Duration::from_secs(60),
48 success_threshold: 3,
49 failure_window: Duration::from_secs(60),
50 half_open_max_requests: 3,
51 }
52 }
53}
54
/// Counters and timestamps describing a [`CircuitBreaker`]'s activity.
#[derive(Debug, Default)]
pub struct CircuitBreakerStats {
    /// Every invocation of `call`, including ones that were rejected.
    pub total_requests: u64,
    /// Calls whose operation returned `Ok`.
    pub successful_requests: u64,
    /// Calls whose operation returned `Err`.
    pub failed_requests: u64,
    /// Calls refused by the breaker without running the operation.
    pub rejected_requests: u64,
    /// Number of state changes recorded since creation or the last reset.
    pub state_transitions: u64,
    /// Breaker state as of the most recent recorded transition.
    pub current_state: CircuitState,
    /// When the most recent failed call completed.
    pub last_failure_time: Option<Instant>,
    /// When the most recent successful call completed.
    pub last_success_time: Option<Instant>,
}
67
/// Async circuit breaker guarding calls to an unreliable dependency.
///
/// All methods take `&self`; mutable state is kept in atomics and `tokio::sync::RwLock`s.
pub struct CircuitBreaker {
    /// Thresholds and timings supplied at construction.
    config: CircuitBreakerConfig,
    /// Current breaker state.
    state: Arc<RwLock<CircuitState>>,
    /// Failures counted while closed; compared against `config.failure_threshold`.
    failure_count: AtomicU32,
    /// Successes counted while half-open; compared against `config.success_threshold`.
    success_count: AtomicU32,
    /// Counter checked against `config.half_open_max_requests` when admitting half-open calls.
    half_open_requests: AtomicU32,
    /// Time of the most recent failure; the open-state timeout is measured from it.
    last_failure_time: Arc<RwLock<Option<Instant>>>,
    /// Time of the most recent success.
    last_success_time: Arc<RwLock<Option<Instant>>>,
    /// Aggregated statistics handed out by `get_stats`.
    stats: Arc<RwLock<CircuitBreakerStats>>,
}
79
80impl CircuitBreaker {
81 pub fn new(config: CircuitBreakerConfig) -> Self {
83 Self {
84 config,
85 state: Arc::new(RwLock::new(CircuitState::Closed)),
86 failure_count: AtomicU32::new(0),
87 success_count: AtomicU32::new(0),
88 half_open_requests: AtomicU32::new(0),
89 last_failure_time: Arc::new(RwLock::new(None)),
90 last_success_time: Arc::new(RwLock::new(None)),
91 stats: Arc::new(RwLock::new(CircuitBreakerStats::default())),
92 }
93 }
94
95 pub async fn call<F, T>(&self, operation: F) -> Neo3Result<T>
97 where
98 F: std::future::Future<Output = Neo3Result<T>>,
99 {
100 {
102 let mut stats = self.stats.write().await;
103 stats.total_requests += 1;
104 }
105
106 if !self.should_allow_request().await {
108 let mut stats = self.stats.write().await;
109 stats.rejected_requests += 1;
110 return Err(Neo3Error::Network(crate::neo_error::NetworkError::RateLimitExceeded));
111 }
112
113 match operation.await {
115 Ok(result) => {
116 self.on_success().await;
117 Ok(result)
118 },
119 Err(error) => {
120 self.on_failure().await;
121 Err(error)
122 },
123 }
124 }
125
126 async fn should_allow_request(&self) -> bool {
128 let state = self.state.read().await;
129 match *state {
130 CircuitState::Closed => true,
131 CircuitState::Open => {
132 if let Some(last_failure) = *self.last_failure_time.read().await {
134 if last_failure.elapsed() >= self.config.timeout {
135 drop(state);
136 self.transition_to_half_open().await;
137 true
138 } else {
139 false
140 }
141 } else {
142 false
143 }
144 },
145 CircuitState::HalfOpen => {
146 let current_requests = self.half_open_requests.load(Ordering::Relaxed);
148 current_requests < self.config.half_open_max_requests
149 },
150 }
151 }
152
153 async fn on_success(&self) {
155 let mut stats = self.stats.write().await;
156 stats.successful_requests += 1;
157 stats.last_success_time = Some(Instant::now());
158 drop(stats);
159
160 *self.last_success_time.write().await = Some(Instant::now());
161
162 let state = self.state.read().await;
163 match *state {
164 CircuitState::Closed => {
165 self.failure_count.store(0, Ordering::Relaxed);
167 },
168 CircuitState::HalfOpen => {
169 let success_count = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
170 if success_count >= self.config.success_threshold {
171 drop(state);
172 self.transition_to_closed().await;
173 }
174 },
175 CircuitState::Open => {
176 drop(state);
178 self.transition_to_closed().await;
179 },
180 }
181 }
182
183 async fn on_failure(&self) {
185 let mut stats = self.stats.write().await;
186 stats.failed_requests += 1;
187 stats.last_failure_time = Some(Instant::now());
188 drop(stats);
189
190 *self.last_failure_time.write().await = Some(Instant::now());
191
192 let state = self.state.read().await;
193 match *state {
194 CircuitState::Closed => {
195 let failure_count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
196 if failure_count >= self.config.failure_threshold {
197 drop(state);
198 self.transition_to_open().await;
199 }
200 },
201 CircuitState::HalfOpen => {
202 drop(state);
204 self.transition_to_open().await;
205 },
206 CircuitState::Open => {
207 },
209 }
210 }
211
212 async fn transition_to_closed(&self) {
214 let mut state = self.state.write().await;
215 if *state != CircuitState::Closed {
216 *state = CircuitState::Closed;
217 self.failure_count.store(0, Ordering::Relaxed);
218 self.success_count.store(0, Ordering::Relaxed);
219 self.half_open_requests.store(0, Ordering::Relaxed);
220
221 let mut stats = self.stats.write().await;
222 stats.state_transitions += 1;
223 stats.current_state = CircuitState::Closed;
224 }
225 }
226
227 async fn transition_to_open(&self) {
229 let mut state = self.state.write().await;
230 if *state != CircuitState::Open {
231 *state = CircuitState::Open;
232 self.success_count.store(0, Ordering::Relaxed);
233 self.half_open_requests.store(0, Ordering::Relaxed);
234
235 let mut stats = self.stats.write().await;
236 stats.state_transitions += 1;
237 stats.current_state = CircuitState::Open;
238 }
239 }
240
241 async fn transition_to_half_open(&self) {
243 let mut state = self.state.write().await;
244 if *state != CircuitState::HalfOpen {
245 *state = CircuitState::HalfOpen;
246 self.success_count.store(0, Ordering::Relaxed);
247 self.half_open_requests.store(0, Ordering::Relaxed);
248
249 let mut stats = self.stats.write().await;
250 stats.state_transitions += 1;
251 stats.current_state = CircuitState::HalfOpen;
252 }
253 }
254
255 pub async fn get_state(&self) -> CircuitState {
257 let state = self.state.read().await;
258 state.clone()
259 }
260
261 pub async fn get_stats(&self) -> CircuitBreakerStats {
263 let stats = self.stats.read().await;
264 CircuitBreakerStats {
265 total_requests: stats.total_requests,
266 successful_requests: stats.successful_requests,
267 failed_requests: stats.failed_requests,
268 rejected_requests: stats.rejected_requests,
269 state_transitions: stats.state_transitions,
270 current_state: stats.current_state.clone(),
271 last_failure_time: stats.last_failure_time,
272 last_success_time: stats.last_success_time,
273 }
274 }
275
276 pub async fn reset(&self) {
278 self.transition_to_closed().await;
279 *self.last_failure_time.write().await = None;
280 *self.last_success_time.write().await = None;
281
282 let mut stats = self.stats.write().await;
283 *stats = CircuitBreakerStats::default();
284 }
285
286 pub async fn force_open(&self) {
288 self.transition_to_open().await;
289 }
290
291 pub async fn get_failure_rate(&self) -> f64 {
293 let stats = self.stats.read().await;
294 if stats.total_requests == 0 {
295 0.0
296 } else {
297 stats.failed_requests as f64 / stats.total_requests as f64
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// Operation that always succeeds.
    async fn ok_operation() -> Neo3Result<()> {
        Ok(())
    }

    /// Operation that always fails with a network connection error.
    async fn failing_operation() -> Neo3Result<()> {
        Err(Neo3Error::Network(crate::neo_error::NetworkError::ConnectionFailed(
            "test".to_string(),
        )))
    }

    /// Breaker that trips after `failure_threshold` failures, all other settings default.
    fn breaker_with_threshold(failure_threshold: u32) -> CircuitBreaker {
        CircuitBreaker::new(CircuitBreakerConfig { failure_threshold, ..Default::default() })
    }

    #[tokio::test]
    async fn test_circuit_breaker_closed_state() {
        let breaker = breaker_with_threshold(3);
        for _ in 0..5 {
            assert!(breaker.call(ok_operation()).await.is_ok());
        }
        assert_eq!(breaker.get_state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_on_failures() {
        let breaker = breaker_with_threshold(3);
        for _ in 0..3 {
            assert!(breaker.call(failing_operation()).await.is_err());
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_transition() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            timeout: Duration::from_millis(100),
            ..Default::default()
        });
        for _ in 0..2 {
            let _ = breaker.call(failing_operation()).await;
        }
        assert_eq!(breaker.get_state().await, CircuitState::Open);

        // Wait past the open timeout so the next call is admitted as a half-open trial.
        sleep(Duration::from_millis(150)).await;

        assert!(breaker.call(ok_operation()).await.is_ok());
        // One success is below the default success_threshold (3), so it stays half-open.
        assert_eq!(breaker.get_state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_stats() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
        let _ = breaker.call(ok_operation()).await;
        let _ = breaker.call(failing_operation()).await;

        let stats = breaker.get_stats().await;
        assert_eq!(
            (stats.total_requests, stats.successful_requests, stats.failed_requests),
            (2, 1, 1)
        );
    }
}