1use std::fmt;
4
5use base64::{engine::general_purpose, Engine};
6use jsonwebtoken::{encode, errors::Error, get_current_timestamp, Algorithm, EncodingKey, Header};
7use primitive_types::U256;
8use serde::{
9 de::{self, MapAccess, Unexpected, Visitor},
10 Deserialize, Serialize,
11};
12use serde_json::{value::RawValue, Value};
13use thiserror::Error;
14
15use neo3::prelude::Bytes;
16
17#[derive(Deserialize, Debug, Clone, Error, PartialEq)]
19pub struct JsonRpcError {
20 pub code: i64,
22 pub message: String,
24 pub data: Option<Value>,
26}
27
28fn spelunk_revert(value: &Value) -> Option<Bytes> {
33 match value {
34 Value::String(s) => Some(s.as_bytes().to_vec()),
35 Value::Object(o) => o.values().flat_map(spelunk_revert).next(),
36 _ => None,
37 }
38}
39
40impl JsonRpcError {
41 pub fn is_revert(&self) -> bool {
46 self.message.contains("revert")
48 }
49
50 pub fn as_revert_data(&self) -> Option<Bytes> {
62 self.is_revert()
63 .then(|| self.data.as_ref().and_then(spelunk_revert).unwrap_or_default())
64 }
65
66 }
71
72impl fmt::Display for JsonRpcError {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(f, "(code: {}, message: {}, data: {:?})", self.code, self.message, self.data)
75 }
76}
77
/// Returns `true` when the referenced value occupies zero bytes.
///
/// Used as the `skip_serializing_if` predicate for `Request::params`, so that
/// zero-sized parameter types such as `()` are left out of the serialized
/// request.
fn is_zst<T>(value: &T) -> bool {
    // `T` is `Sized` here, so the size of the value equals `size_of::<T>()`.
    std::mem::size_of_val(value) == 0
}
81
82#[derive(Serialize, Deserialize, Debug)]
83pub struct Request<'a, T> {
85 id: u64,
86 jsonrpc: &'a str,
87 method: &'a str,
88 #[serde(skip_serializing_if = "is_zst")]
89 params: T,
90}
91
92impl<'a, T> Request<'a, T> {
93 pub fn new(id: u64, method: &'a str, params: T) -> Self {
95 Self { id, jsonrpc: "2.0", method, params }
96 }
97}
98
99#[derive(Debug)]
101pub enum Response<'a> {
102 Success { id: u64, result: &'a RawValue },
103 Error { id: u64, error: JsonRpcError },
104 Notification { method: &'a str, params: Params<'a> },
105}
106
107#[derive(Deserialize, Debug)]
108pub struct Params<'a> {
109 pub subscription: U256,
110 #[serde(borrow)]
111 pub result: &'a RawValue,
112}
113
114impl<'de: 'a, 'a> Deserialize<'de> for Response<'a> {
117 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
118 where
119 D: serde::Deserializer<'de>,
120 {
121 struct ResponseVisitor<'a>(&'a ());
122 impl<'de: 'a, 'a> Visitor<'de> for ResponseVisitor<'a> {
123 type Value = Response<'a>;
124
125 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
126 formatter.write_str("a valid jsonrpc 2.0 response object")
127 }
128
129 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
130 where
131 A: MapAccess<'de>,
132 {
133 let mut jsonrpc = false;
134
135 let mut id = None;
137 let mut result = None;
139 let mut error = None;
141 let mut method = None;
143 let mut params = None;
144
145 while let Some(key) = map.next_key()? {
146 match key {
147 "jsonrpc" => {
148 if jsonrpc {
149 return Err(de::Error::duplicate_field("jsonrpc"));
150 }
151
152 let value = map.next_value()?;
153 if value != "2.0" {
154 return Err(de::Error::invalid_value(
155 Unexpected::Str(value),
156 &"2.0",
157 ));
158 }
159
160 jsonrpc = true;
161 },
162 "id" => {
163 if id.is_some() {
164 return Err(de::Error::duplicate_field("id"));
165 }
166
167 let value: u64 = map.next_value()?;
168 id = Some(value);
169 },
170 "result" => {
171 if result.is_some() {
172 return Err(de::Error::duplicate_field("result"));
173 }
174
175 let value: &RawValue = map.next_value()?;
176 result = Some(value);
177 },
178 "error" => {
179 if error.is_some() {
180 return Err(de::Error::duplicate_field("error"));
181 }
182
183 let value: JsonRpcError = map.next_value()?;
184 error = Some(value);
185 },
186 "method" => {
187 if method.is_some() {
188 return Err(de::Error::duplicate_field("method"));
189 }
190
191 let value: &str = map.next_value()?;
192 method = Some(value);
193 },
194 "params" => {
195 if params.is_some() {
196 return Err(de::Error::duplicate_field("params"));
197 }
198
199 let value: Params = map.next_value()?;
200 params = Some(value);
201 },
202 key => {
203 return Err(de::Error::unknown_field(
204 key,
205 &["id", "jsonrpc", "result", "error", "params", "method"],
206 ))
207 },
208 }
209 }
210
211 if !jsonrpc {
213 return Err(de::Error::missing_field("jsonrpc"));
214 }
215
216 match (id, result, error, method, params) {
217 (Some(id), Some(result), None, None, None) => {
218 Ok(Response::Success { id, result })
219 },
220 (Some(id), None, Some(error), None, None) => Ok(Response::Error { id, error }),
221 (None, None, None, Some(method), Some(params)) => {
222 Ok(Response::Notification { method, params })
223 },
224 _ => Err(de::Error::custom(
225 "response must be either a success/error or notification object",
226 )),
227 }
228 }
229 }
230
231 deserializer.deserialize_map(ResponseVisitor(&()))
232 }
233}
234
/// Authorization credentials, rendered to text through `Display`.
///
/// NOTE(review): presumably used as the value of an HTTP `Authorization`
/// header — confirm against callers.
#[derive(Debug, Clone)]
pub enum Authorization {
    /// Base64-encoded `username:password`; displayed as `Basic <secret>`.
    Basic(String),
    /// A token; displayed as `Bearer <token>`.
    Bearer(String),
    /// An arbitrary string; displayed verbatim with no prefix.
    Raw(String),
}
247
248impl Authorization {
249 pub fn basic(username: impl AsRef<str>, password: impl AsRef<str>) -> Self {
251 let username = username.as_ref();
252 let password = password.as_ref();
253 let auth_secret = general_purpose::STANDARD.encode(format!("{username}:{password}"));
254 Self::Basic(auth_secret)
255 }
256
257 pub fn bearer(token: impl Into<String>) -> Self {
259 Self::Bearer(token.into())
260 }
261
262 pub fn raw(token: impl Into<String>) -> Self {
264 Self::Raw(token.into())
265 }
266}
267
268impl fmt::Display for Authorization {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 match self {
271 Authorization::Basic(auth_secret) => write!(f, "Basic {auth_secret}"),
272 Authorization::Bearer(token) => write!(f, "Bearer {token}"),
273 Authorization::Raw(s) => write!(f, "{s}"),
274 }
275 }
276}
277
278const DEFAULT_ALGORITHM: Algorithm = Algorithm::HS256;
280
281pub const JWT_SECRET_LENGTH: usize = 32;
283
284pub struct JwtKey([u8; JWT_SECRET_LENGTH]);
286
287impl JwtKey {
288 pub fn from_slice(key: &[u8]) -> Result<Self, String> {
290 if key.len() != JWT_SECRET_LENGTH {
291 return Err(format!(
292 "Invalid key length. Expected {} got {}",
293 JWT_SECRET_LENGTH,
294 key.len()
295 ));
296 }
297 let mut res = [0; JWT_SECRET_LENGTH];
298 res.copy_from_slice(key);
299 Ok(Self(res))
300 }
301
302 pub fn from_hex(hex: &str) -> Result<Self, String> {
304 let bytes = hex::decode(hex).map_err(|e| format!("Invalid hex: {}", e))?;
305 Self::from_slice(&bytes)
306 }
307
308 pub fn as_bytes(&self) -> &[u8; JWT_SECRET_LENGTH] {
310 &self.0
311 }
312
313 pub fn into_bytes(self) -> [u8; JWT_SECRET_LENGTH] {
315 self.0
316 }
317}
318
319pub struct JwtAuth {
321 key: EncodingKey,
322 id: Option<String>,
323 clv: Option<String>,
324}
325
326impl JwtAuth {
327 pub fn new(secret: JwtKey, id: Option<String>, clv: Option<String>) -> Self {
329 Self { key: EncodingKey::from_secret(secret.as_bytes()), id, clv }
330 }
331
332 pub fn generate_token(&self) -> Result<String, Error> {
334 let claims = self.generate_claims_at_timestamp();
335 self.generate_token_with_claims(&claims)
336 }
337
338 fn generate_token_with_claims(&self, claims: &Claims) -> Result<String, Error> {
340 let header = Header::new(DEFAULT_ALGORITHM);
341 encode(&header, claims, &self.key)
342 }
343
344 fn generate_claims_at_timestamp(&self) -> Claims {
346 Claims { iat: get_current_timestamp(), id: self.id.clone(), clv: self.clv.clone() }
347 }
348
349 pub fn validate_token(
351 token: &str,
352 secret: &JwtKey,
353 ) -> Result<jsonwebtoken::TokenData<Claims>, Error> {
354 let mut validation = jsonwebtoken::Validation::new(DEFAULT_ALGORITHM);
355 validation.validate_exp = false;
356 validation.required_spec_claims.remove("exp");
357
358 jsonwebtoken::decode::<Claims>(
359 token,
360 &jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()),
361 &validation,
362 )
363 .map_err(Into::into)
364 }
365}
366
367#[derive(Debug, Serialize, Deserialize, PartialEq)]
369pub struct Claims {
370 iat: u64,
372 id: Option<String>,
374 clv: Option<String>,
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    /// Covers rejection of malformed responses and parsing of success and
    /// error responses.
    #[test]
    fn deser_response() {
        // A result without an id is neither a reply nor a notification.
        let _ =
            serde_json::from_str::<Response<'_>>(r#"{"jsonrpc":"2.0","result":19}"#).unwrap_err();
        // Any protocol version other than "2.0" is rejected.
        let _ = serde_json::from_str::<Response<'_>>(r#"{"jsonrpc":"3.0","result":19,"id":1}"#)
            .unwrap_err();

        let response: Response<'_> =
            serde_json::from_str(r#"{"jsonrpc":"2.0","result":19,"id":1}"#).unwrap();

        match response {
            Response::Success { id, result } => {
                assert_eq!(id, 1);
                let result: u64 = serde_json::from_str(result.get()).unwrap();
                assert_eq!(result, 19);
            },
            // `panic!` states the failure directly instead of asserting on a
            // constant `false`.
            other => panic!("Expected Success response but got: {:?}", other),
        }

        let response: Response<'_> = serde_json::from_str(
            r#"{"jsonrpc":"2.0","error":{"code":-32000,"message":"error occurred"},"id":2}"#,
        )
        .unwrap();

        match response {
            Response::Error { id, error } => {
                assert_eq!(id, 2);
                assert_eq!(error.code, -32000);
                assert_eq!(error.message, "error occurred");
                assert!(error.data.is_none());
            },
            other => panic!("Expected Error response but got: {:?}", other),
        }

        let response: Response<'_> =
            serde_json::from_str(r#"{"jsonrpc":"2.0","result":"0xfa","id":0}"#).unwrap();

        match response {
            Response::Success { id, result } => {
                assert_eq!(id, 0);
                let result: String = serde_json::from_str(result.get()).unwrap();
                assert_eq!(i64::from_str_radix(result.trim_start_matches("0x"), 16).unwrap(), 250);
            },
            other => panic!("Expected Success response but got: {:?}", other),
        }
    }

    /// Zero-sized params are omitted from the output; other params are kept.
    #[test]
    fn ser_request() {
        let request: Request<()> = Request::new(0, "neo_chainId", ());
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":0,"jsonrpc":"2.0","method":"neo_chainId"}"#
        );

        let request: Request<()> = Request::new(300, "method_name", ());
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":300,"jsonrpc":"2.0","method":"method_name"}"#
        );

        let request: Request<u32> = Request::new(300, "method_name", 1);
        assert_eq!(
            &serde_json::to_string(&request).unwrap(),
            r#"{"id":300,"jsonrpc":"2.0","method":"method_name","params":1}"#
        );
    }

    /// A generated token validates against the same secret and returns the
    /// claims it was built from.
    #[test]
    fn test_roundtrip() {
        let jwt_secret = [42; 32];
        let auth = JwtAuth::new(
            JwtKey::from_slice(&jwt_secret).unwrap(),
            Some("42".into()),
            Some("Lighthouse".into()),
        );
        let claims = auth.generate_claims_at_timestamp();
        let token = auth.generate_token_with_claims(&claims).unwrap();

        assert_eq!(
            JwtAuth::validate_token(&token, &JwtKey::from_slice(&jwt_secret).unwrap())
                .unwrap()
                .claims,
            claims
        );
    }
}