1use std::borrow::Cow as StdCow;
28use std::fmt;
29use std::sync::Arc;
30use std::time::Duration;
31
32use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClient, HttpTransportClientBuilder};
33use crate::types::{NotificationSer, RequestSer, Response};
34use crate::{HttpRequest, HttpResponse};
35use async_trait::async_trait;
36use hyper::body::Bytes;
37use hyper::http::HeaderMap;
38use jsonrpsee_core::client::{
39 generate_batch_id_range, BatchResponse, ClientT, Error, IdKind, RequestIdManager, Subscription, SubscriptionClientT,
40};
41use jsonrpsee_core::params::BatchRequestBuilder;
42use jsonrpsee_core::traits::ToRpcParams;
43use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES};
44use jsonrpsee_types::{ErrorObject, InvalidRequestId, ResponseSuccess, TwoPointZero};
45use serde::de::DeserializeOwned;
46use tower::layer::util::Identity;
47use tower::{Layer, Service};
48use tracing::instrument;
49
50#[cfg(feature = "tls")]
51use crate::{CertificateStore, CustomCertStore};
52
/// Builder for [`HttpClient`].
///
/// `L` is the tower layer stack that `build` applies on top of
/// `transport::HttpBackend`; it defaults to [`Identity`] (no middleware).
#[derive(Debug)]
pub struct HttpClientBuilder<L = Identity> {
    // Maximum request size, forwarded to `HttpTransportClientBuilder`
    // (default `TEN_MB_SIZE_BYTES`).
    max_request_size: u32,
    // Maximum response size, forwarded to `HttpTransportClientBuilder`
    // (default `TEN_MB_SIZE_BYTES`).
    max_response_size: u32,
    // Deadline the built client applies to each call (default 60 seconds).
    request_timeout: Duration,
    // NOTE(review): `build` discards this value (it is skipped via `..`) and no
    // setter is visible here, so it currently appears to have no effect.
    max_concurrent_requests: usize,
    // Certificate store forwarded to the transport (only with the `tls` feature).
    #[cfg(feature = "tls")]
    certificate_store: CertificateStore,
    // Kind of request id the client's `RequestIdManager` generates.
    id_kind: IdKind,
    // Forwarded to the transport; presumably caps the length of logged
    // messages — TODO confirm against `HttpTransportClientBuilder`.
    max_log_length: u32,
    // Headers forwarded to the transport builder.
    headers: HeaderMap,
    // tower middleware stack layered over the HTTP backend.
    service_builder: tower::ServiceBuilder<L>,
    // Forwarded to the transport; presumably toggles TCP_NODELAY — verify.
    tcp_no_delay: bool,
}
90
91impl<L> HttpClientBuilder<L> {
92 pub fn max_request_size(mut self, size: u32) -> Self {
94 self.max_request_size = size;
95 self
96 }
97
98 pub fn max_response_size(mut self, size: u32) -> Self {
100 self.max_response_size = size;
101 self
102 }
103
104 pub fn request_timeout(mut self, timeout: Duration) -> Self {
106 self.request_timeout = timeout;
107 self
108 }
109
110 #[cfg(feature = "tls")]
176 pub fn with_custom_cert_store(mut self, cfg: CustomCertStore) -> Self {
177 self.certificate_store = CertificateStore::Custom(cfg);
178 self
179 }
180
181 pub fn id_format(mut self, id_kind: IdKind) -> Self {
183 self.id_kind = id_kind;
184 self
185 }
186
187 pub fn set_max_logging_length(mut self, max: u32) -> Self {
191 self.max_log_length = max;
192 self
193 }
194
195 pub fn set_headers(mut self, headers: HeaderMap) -> Self {
199 self.headers = headers;
200 self
201 }
202
203 pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
207 self.tcp_no_delay = no_delay;
208 self
209 }
210
211 pub fn set_http_middleware<T>(self, service_builder: tower::ServiceBuilder<T>) -> HttpClientBuilder<T> {
213 HttpClientBuilder {
214 #[cfg(feature = "tls")]
215 certificate_store: self.certificate_store,
216 id_kind: self.id_kind,
217 headers: self.headers,
218 max_log_length: self.max_log_length,
219 max_concurrent_requests: self.max_concurrent_requests,
220 max_request_size: self.max_request_size,
221 max_response_size: self.max_response_size,
222 service_builder,
223 request_timeout: self.request_timeout,
224 tcp_no_delay: self.tcp_no_delay,
225 }
226 }
227}
228
229impl<B, S, L> HttpClientBuilder<L>
230where
231 L: Layer<transport::HttpBackend, Service = S>,
232 S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Clone,
233 B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
234 B::Data: Send,
235 B::Error: Into<BoxError>,
236{
237 pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient<S>, Error> {
239 let Self {
240 max_request_size,
241 max_response_size,
242 request_timeout,
243 #[cfg(feature = "tls")]
244 certificate_store,
245 id_kind,
246 headers,
247 max_log_length,
248 service_builder,
249 tcp_no_delay,
250 ..
251 } = self;
252
253 let transport = HttpTransportClientBuilder {
254 max_request_size,
255 max_response_size,
256 headers,
257 max_log_length,
258 tcp_no_delay,
259 service_builder,
260 #[cfg(feature = "tls")]
261 certificate_store,
262 }
263 .build(target)
264 .map_err(|e| Error::Transport(e.into()))?;
265
266 Ok(HttpClient { transport, id_manager: Arc::new(RequestIdManager::new(id_kind)), request_timeout })
267 }
268}
269
270impl Default for HttpClientBuilder<Identity> {
271 fn default() -> Self {
272 Self {
273 max_request_size: TEN_MB_SIZE_BYTES,
274 max_response_size: TEN_MB_SIZE_BYTES,
275 request_timeout: Duration::from_secs(60),
276 max_concurrent_requests: 256,
277 #[cfg(feature = "tls")]
278 certificate_store: CertificateStore::Native,
279 id_kind: IdKind::Number,
280 max_log_length: 4096,
281 headers: HeaderMap::new(),
282 service_builder: tower::ServiceBuilder::new(),
283 tcp_no_delay: true,
284 }
285 }
286}
287
288impl HttpClientBuilder<Identity> {
289 pub fn new() -> HttpClientBuilder<Identity> {
291 HttpClientBuilder::default()
292 }
293}
294
/// JSON-RPC client that sends calls over an HTTP transport.
///
/// `S` is the tower service the transport drives; it defaults to [`HttpBackend`].
/// Cloning shares the request id manager, because it is held in an `Arc`.
#[derive(Debug, Clone)]
pub struct HttpClient<S = HttpBackend> {
    // Transport used to send notifications and to send requests and read response bodies.
    transport: HttpTransportClient<S>,
    // Deadline applied (via `tokio::time::timeout`) to each transport round trip.
    request_timeout: Duration,
    // Source of request ids; shared between clones of this client.
    id_manager: Arc<RequestIdManager>,
}
305
306impl HttpClient<HttpBackend> {
307 pub fn builder() -> HttpClientBuilder<Identity> {
309 HttpClientBuilder::new()
310 }
311}
312
313#[async_trait]
314impl<B, S> ClientT for HttpClient<S>
315where
316 S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Send + Sync + Clone,
317 <S as Service<HttpRequest>>::Future: Send,
318 B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
319 B::Error: Into<BoxError>,
320 B::Data: Send,
321{
322 #[instrument(name = "notification", skip(self, params), level = "trace")]
323 async fn notification<Params>(&self, method: &str, params: Params) -> Result<(), Error>
324 where
325 Params: ToRpcParams + Send,
326 {
327 let params = params.to_rpc_params()?;
328 let notif =
329 serde_json::to_string(&NotificationSer::borrowed(&method, params.as_deref())).map_err(Error::ParseError)?;
330
331 let fut = self.transport.send(notif);
332
333 match tokio::time::timeout(self.request_timeout, fut).await {
334 Ok(Ok(ok)) => Ok(ok),
335 Err(_) => Err(Error::RequestTimeout),
336 Ok(Err(e)) => Err(Error::Transport(e.into())),
337 }
338 }
339
340 #[instrument(name = "method_call", skip(self, params), level = "trace")]
341 async fn request<R, Params>(&self, method: &str, params: Params) -> Result<R, Error>
342 where
343 R: DeserializeOwned,
344 Params: ToRpcParams + Send,
345 {
346 let id = self.id_manager.next_request_id();
347 let params = params.to_rpc_params()?;
348
349 let request = RequestSer::borrowed(&id, &method, params.as_deref());
350 let raw = serde_json::to_string(&request).map_err(Error::ParseError)?;
351
352 let fut = self.transport.send_and_read_body(raw);
353 let body = match tokio::time::timeout(self.request_timeout, fut).await {
354 Ok(Ok(body)) => body,
355 Err(_e) => {
356 return Err(Error::RequestTimeout);
357 }
358 Ok(Err(e)) => {
359 return Err(Error::Transport(e.into()));
360 }
361 };
362
363 let response = ResponseSuccess::try_from(serde_json::from_slice::<Response<&JsonRawValue>>(&body)?)?;
366
367 let result = serde_json::from_str(response.result.get()).map_err(Error::ParseError)?;
368
369 if response.id == id {
370 Ok(result)
371 } else {
372 Err(InvalidRequestId::NotPendingRequest(response.id.to_string()).into())
373 }
374 }
375
376 #[instrument(name = "batch", skip(self, batch), level = "trace")]
377 async fn batch_request<'a, R>(&self, batch: BatchRequestBuilder<'a>) -> Result<BatchResponse<'a, R>, Error>
378 where
379 R: DeserializeOwned + fmt::Debug + 'a,
380 {
381 let batch = batch.build()?;
382 let id = self.id_manager.next_request_id();
383 let id_range = generate_batch_id_range(id, batch.len() as u64)?;
384
385 let mut batch_request = Vec::with_capacity(batch.len());
386 for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
387 let id = self.id_manager.as_id_kind().into_id(id);
388 batch_request.push(RequestSer {
389 jsonrpc: TwoPointZero,
390 id,
391 method: method.into(),
392 params: params.map(StdCow::Owned),
393 });
394 }
395
396 let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?);
397
398 let body = match tokio::time::timeout(self.request_timeout, fut).await {
399 Ok(Ok(body)) => body,
400 Err(_e) => return Err(Error::RequestTimeout),
401 Ok(Err(e)) => return Err(Error::Transport(e.into())),
402 };
403
404 let json_rps: Vec<Response<&JsonRawValue>> = serde_json::from_slice(&body).map_err(Error::ParseError)?;
405
406 let mut responses = Vec::with_capacity(json_rps.len());
407 let mut successful_calls = 0;
408 let mut failed_calls = 0;
409
410 for _ in 0..json_rps.len() {
411 responses.push(Err(ErrorObject::borrowed(0, "", None)));
412 }
413
414 for rp in json_rps {
415 let id = rp.id.try_parse_inner_as_number()?;
416
417 let res = match ResponseSuccess::try_from(rp) {
418 Ok(r) => {
419 let result = serde_json::from_str(r.result.get())?;
420 successful_calls += 1;
421 Ok(result)
422 }
423 Err(err) => {
424 failed_calls += 1;
425 Err(err)
426 }
427 };
428
429 let maybe_elem = id
430 .checked_sub(id_range.start)
431 .and_then(|p| p.try_into().ok())
432 .and_then(|p: usize| responses.get_mut(p));
433
434 if let Some(elem) = maybe_elem {
435 *elem = res;
436 } else {
437 return Err(InvalidRequestId::NotPendingRequest(id.to_string()).into());
438 }
439 }
440
441 Ok(BatchResponse::new(successful_calls, responses, failed_calls))
442 }
443}
444
445#[async_trait]
446impl<B, S> SubscriptionClientT for HttpClient<S>
447where
448 S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Send + Sync + Clone,
449 <S as Service<HttpRequest>>::Future: Send,
450 B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
451 B::Data: Send,
452 B::Error: Into<BoxError>,
453{
454 #[instrument(name = "subscription", fields(method = _subscribe_method), skip(self, _params, _subscribe_method, _unsubscribe_method), level = "trace")]
457 async fn subscribe<'a, N, Params>(
458 &self,
459 _subscribe_method: &'a str,
460 _params: Params,
461 _unsubscribe_method: &'a str,
462 ) -> Result<Subscription<N>, Error>
463 where
464 Params: ToRpcParams + Send,
465 N: DeserializeOwned,
466 {
467 Err(Error::HttpNotImplemented)
468 }
469
470 #[instrument(name = "subscribe_method", fields(method = _method), skip(self, _method), level = "trace")]
472 async fn subscribe_to_method<'a, N>(&self, _method: &'a str) -> Result<Subscription<N>, Error>
473 where
474 N: DeserializeOwned,
475 {
476 Err(Error::HttpNotImplemented)
477 }
478}