-
-
Save devsnek/e4d381c7f2a1f0b0d392a132e9835ce7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs | |
index c6dc778..905aa8d 100644 | |
--- a/src/client/legacy/client.rs | |
+++ b/src/client/legacy/client.rs | |
@@ -8,6 +8,7 @@ use std::error::Error as StdError; | |
use std::fmt; | |
use std::future::Future; | |
use std::pin::Pin; | |
+use std::sync::Arc; | |
use std::task::{self, Poll}; | |
use std::time::Duration; | |
@@ -35,7 +36,7 @@ type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>; | |
/// `Client` is cheap to clone and cloning is the recommended way to share a `Client`. The | |
/// underlying connection pool will be reused. | |
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))] | |
-pub struct Client<C, B> { | |
+pub struct Client<C, B, PK: pool::Key = DefaultPoolKey> { | |
config: Config, | |
connector: C, | |
exec: Exec, | |
@@ -43,7 +44,8 @@ pub struct Client<C, B> { | |
h1_builder: hyper::client::conn::http1::Builder, | |
#[cfg(feature = "http2")] | |
h2_builder: hyper::client::conn::http2::Builder<Exec>, | |
- pool: pool::Pool<PoolClient<B>, PoolKey>, | |
+ pool_key: Arc<dyn Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static>, | |
+ pool: pool::Pool<PoolClient<B>, PK>, | |
} | |
#[derive(Clone, Copy, Debug)] | |
@@ -90,7 +92,22 @@ macro_rules! e { | |
} | |
// We might change this... :shrug: | |
-type PoolKey = (http::uri::Scheme, http::uri::Authority); | |
+/// The pool key used by `Client` unless a custom one is supplied through | |
+/// `Builder::build_with_pool_key`: the request URI's scheme and authority. | |
+#[derive(Debug, Clone, Hash, PartialEq, Eq)] | |
+pub struct DefaultPoolKey(http::uri::Scheme, http::uri::Authority); | |
+impl pool::Key for DefaultPoolKey { | |
+    type Connect = Uri; | |
+    /// Rebuild a `scheme://authority/` URI to hand to the connector. | |
+    fn connect(self) -> Uri { | |
+        http::uri::Builder::new() | |
+            .scheme(self.0) | |
+            .authority(self.1) | |
+            .path_and_query("/") | |
+            .build() | |
+            .expect("domain is valid Uri") | |
+    } | |
+} | |
enum TrySendError<B> { | |
Retryable { | |
@@ -143,12 +157,13 @@ impl Client<(), ()> { | |
} | |
} | |
-impl<C, B> Client<C, B> | |
+impl<C, B, PK> Client<C, B, PK> | |
where | |
- C: Connect + Clone + Send + Sync + 'static, | |
+ C: Connect<PK> + Clone + Send + Sync + 'static, | |
B: Body + Send + 'static + Unpin, | |
B::Data: Send, | |
B::Error: Into<Box<dyn StdError + Send + Sync>>, | |
+ PK: pool::Key, | |
{ | |
/// Send a `GET` request to the supplied `Uri`. | |
/// | |
@@ -214,27 +229,35 @@ where | |
/// # } | |
/// # fn main() {} | |
/// ``` | |
-    pub fn request(&self, mut req: Request<B>) -> ResponseFuture { | |
+    pub fn request(&self, req: Request<B>) -> ResponseFuture { | |
+        // Validate the version/method here rather than only in the pool-key | |
+        // extractor, so a custom extractor supplied through | |
+        // `Builder::build_with_pool_key` cannot bypass these checks. | |
let is_http_connect = req.method() == Method::CONNECT; | |
match req.version() { | |
Version::HTTP_11 => (), | |
Version::HTTP_10 => { | |
if is_http_connect { | |
warn!("CONNECT is not allowed for HTTP/1.0"); | |
return ResponseFuture::new(future::err(e!(UserUnsupportedRequestMethod))); | |
} | |
} | |
Version::HTTP_2 => (), | |
// completely unsupported HTTP version (like HTTP/0.9)! | |
-            other => return ResponseFuture::error_version(other), | |
+            other => { | |
+                warn!("Request has unsupported version \"{:?}\"", other); | |
+                return ResponseFuture::new(future::err(e!(UserUnsupportedVersion))); | |
+            } | |
}; | |
-        let pool_key = match extract_domain(req.uri_mut(), is_http_connect) { | |
+        let (mut parts, body) = req.into_parts(); | |
+        let pool_key = match (self.pool_key)(&mut parts) { | |
Ok(s) => s, | |
Err(err) => { | |
return ResponseFuture::new(future::err(err)); | |
} | |
}; | |
+        let req = Request::from_parts(parts, body); | |
ResponseFuture::new(self.clone().send_request(req, pool_key)) | |
} | |
@@ -242,12 +245,13 @@ where | |
async fn send_request( | |
self, | |
mut req: Request<B>, | |
- pool_key: PoolKey, | |
+ pool_key: PK, | |
) -> Result<Response<hyper::body::Incoming>, Error> { | |
let uri = req.uri().clone(); | |
loop { | |
- req = match self.try_send_request(req, pool_key.clone()).await { | |
+ let pk: PK = pool_key.clone(); | |
+ req = match self.try_send_request(req, pk).await { | |
Ok(resp) => return Ok(resp), | |
Err(TrySendError::Nope(err)) => return Err(err), | |
Err(TrySendError::Retryable { | |
@@ -275,7 +279,7 @@ where | |
async fn try_send_request( | |
&self, | |
mut req: Request<B>, | |
- pool_key: PoolKey, | |
+ pool_key: PK, | |
) -> Result<Response<hyper::body::Incoming>, TrySendError<B>> { | |
let mut pooled = self | |
.connection_for(pool_key) | |
@@ -366,12 +370,10 @@ where | |
Ok(res) | |
} | |
- async fn connection_for( | |
- &self, | |
- pool_key: PoolKey, | |
- ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, Error> { | |
+ async fn connection_for(&self, pool_key: PK) -> Result<pool::Pooled<PoolClient<B>, PK>, Error> { | |
loop { | |
- match self.one_connection_for(pool_key.clone()).await { | |
+ let pk: PK = pool_key.clone(); | |
+ match self.one_connection_for(pk).await { | |
Ok(pooled) => return Ok(pooled), | |
Err(ClientConnectError::Normal(err)) => return Err(err), | |
Err(ClientConnectError::CheckoutIsClosed(reason)) => { | |
@@ -391,8 +393,8 @@ where | |
async fn one_connection_for( | |
&self, | |
- pool_key: PoolKey, | |
- ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, ClientConnectError> { | |
+ pool_key: PK, | |
+ ) -> Result<pool::Pooled<PoolClient<B>, PK>, ClientConnectError> { | |
// Return a single connection if pooling is not enabled | |
if !self.pool.is_enabled() { | |
return self | |
@@ -484,9 +486,8 @@ where | |
#[cfg(any(feature = "http1", feature = "http2"))] | |
fn connect_to( | |
&self, | |
- pool_key: PoolKey, | |
- ) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PoolKey>, Error>> + Send + Unpin | |
- { | |
+ pool_key: PK, | |
+ ) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PK>, Error>> + Send + Unpin { | |
let executor = self.exec.clone(); | |
let pool = self.pool.clone(); | |
#[cfg(feature = "http1")] | |
@@ -496,7 +497,6 @@ where | |
let ver = self.config.ver; | |
let is_ver_h2 = ver == Ver::Http2; | |
let connector = self.connector.clone(); | |
- let dst = domain_as_uri(pool_key.clone()); | |
hyper_lazy(move || { | |
// Try to take a "connecting lock". | |
// | |
@@ -514,7 +514,7 @@ where | |
}; | |
Either::Left( | |
connector | |
- .connect(super::connect::sealed::Internal, dst) | |
+ .connect(super::connect::sealed::Internal, pool_key.connect()) | |
.map_err(|src| e!(Connect, src)) | |
.and_then(move |io| { | |
let connected = io.connected(); | |
@@ -669,7 +669,7 @@ where | |
impl<C, B> tower_service::Service<Request<B>> for Client<C, B> | |
where | |
- C: Connect + Clone + Send + Sync + 'static, | |
+ C: Connect<DefaultPoolKey> + Clone + Send + Sync + 'static, | |
B: Body + Send + 'static + Unpin, | |
B::Data: Send, | |
B::Error: Into<Box<dyn StdError + Send + Sync>>, | |
@@ -689,7 +689,7 @@ where | |
impl<C, B> tower_service::Service<Request<B>> for &'_ Client<C, B> | |
where | |
- C: Connect + Clone + Send + Sync + 'static, | |
+ C: Connect<DefaultPoolKey> + Clone + Send + Sync + 'static, | |
B: Body + Send + 'static + Unpin, | |
B::Data: Send, | |
B::Error: Into<Box<dyn StdError + Send + Sync>>, | |
@@ -707,8 +707,8 @@ where | |
} | |
} | |
-impl<C: Clone, B> Clone for Client<C, B> { | |
- fn clone(&self) -> Client<C, B> { | |
+impl<C: Clone, B, PK: pool::Key> Clone for Client<C, B, PK> { | |
+ fn clone(&self) -> Client<C, B, PK> { | |
Client { | |
config: self.config, | |
exec: self.exec.clone(), | |
@@ -717,6 +717,7 @@ impl<C: Clone, B> Clone for Client<C, B> { | |
#[cfg(feature = "http2")] | |
h2_builder: self.h2_builder.clone(), | |
connector: self.connector.clone(), | |
+ pool_key: self.pool_key.clone(), | |
pool: self.pool.clone(), | |
} | |
} | |
@@ -739,11 +740,6 @@ impl ResponseFuture { | |
inner: SyncWrapper::new(Box::pin(value)), | |
} | |
} | |
- | |
- fn error_version(ver: Version) -> Self { | |
- warn!("Request has unsupported version \"{:?}\"", ver); | |
- ResponseFuture::new(Box::pin(future::err(e!(UserUnsupportedVersion)))) | |
- } | |
} | |
impl fmt::Debug for ResponseFuture { | |
@@ -937,10 +933,31 @@ fn authority_form(uri: &mut Uri) { | |
}; | |
} | |
-fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error> { | |
+fn default_pool_key(req: &mut http::request::Parts) -> Result<DefaultPoolKey, Error> { | |
+ let is_http_connect = req.method == Method::CONNECT; | |
+ match req.version { | |
+ Version::HTTP_11 => (), | |
+ Version::HTTP_10 => { | |
+ if is_http_connect { | |
+ warn!("CONNECT is not allowed for HTTP/1.0"); | |
+ return Err(e!(UserUnsupportedRequestMethod)); | |
+ } | |
+ } | |
+ Version::HTTP_2 => (), | |
+ // completely unsupported HTTP version (like HTTP/0.9)! | |
+ other => { | |
+ warn!("Request has unsupported version \"{:?}\"", other); | |
+ return Err(e!(UserUnsupportedVersion)); | |
+ } | |
+ }; | |
+ | |
+ extract_domain(&mut req.uri, is_http_connect) | |
+} | |
+ | |
+fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<DefaultPoolKey, Error> { | |
let uri_clone = uri.clone(); | |
match (uri_clone.scheme(), uri_clone.authority()) { | |
- (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())), | |
+ (Some(scheme), Some(auth)) => Ok(DefaultPoolKey(scheme.clone(), auth.clone())), | |
(None, Some(auth)) if is_http_connect => { | |
let scheme = match auth.port_u16() { | |
Some(443) => { | |
@@ -952,7 +969,7 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error | |
Scheme::HTTP | |
} | |
}; | |
- Ok((scheme, auth.clone())) | |
+ Ok(DefaultPoolKey(scheme, auth.clone())) | |
} | |
_ => { | |
debug!("Client requires absolute-form URIs, received: {:?}", uri); | |
@@ -961,15 +978,6 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error | |
} | |
} | |
-fn domain_as_uri((scheme, auth): PoolKey) -> Uri { | |
- http::uri::Builder::new() | |
- .scheme(scheme) | |
- .authority(auth) | |
- .path_and_query("/") | |
- .build() | |
- .expect("domain is valid Uri") | |
-} | |
- | |
fn set_scheme(uri: &mut Uri, scheme: Scheme) { | |
debug_assert!( | |
uri.scheme().is_none(), | |
@@ -1589,11 +1597,30 @@ impl Builder { | |
} | |
/// Combine the configuration of this builder with a connector to create a `Client`. | |
+    /// | |
+    /// Connections are pooled by the request URI's scheme and authority | |
+    /// (`DefaultPoolKey`); use `build_with_pool_key` to customize this. | |
-    pub fn build<C, B>(&self, connector: C) -> Client<C, B> | |
+    pub fn build<C, B>(&self, connector: C) -> Client<C, B, DefaultPoolKey> | |
+    where | |
+        C: Connect<DefaultPoolKey> + Clone, | |
+        B: Body + Send, | |
+        B::Data: Send, | |
+    { | |
+        self.build_with_pool_key(connector, default_pool_key) | |
+    } | |
+ | |
+ /// Combine the configuration of this builder with a connector to create a `Client`, with a custom pooling key. | |
+ /// A function to extract the pool key from the request is required. | |
+ pub fn build_with_pool_key<C, B, PK>( | |
+ &self, | |
+ connector: C, | |
+ pool_key: impl Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static, | |
+ ) -> Client<C, B, PK> | |
where | |
- C: Connect + Clone, | |
+ C: Connect<PK> + Clone, | |
B: Body + Send, | |
B::Data: Send, | |
+ PK: pool::Key, | |
{ | |
let exec = self.exec.clone(); | |
let timer = self.pool_timer.clone(); | |
@@ -1605,7 +1629,8 @@ impl Builder { | |
#[cfg(feature = "http2")] | |
h2_builder: self.h2_builder.clone(), | |
connector, | |
- pool: pool::Pool::new(self.pool_config, exec, timer), | |
+ pool_key: Arc::new(pool_key), | |
+ pool: pool::Pool::<_, PK>::new(self.pool_config, exec, timer), | |
} | |
} | |
} | |
diff --git a/src/client/legacy/connect/http.rs b/src/client/legacy/connect/http.rs | |
index f19a78e..64592f2 100644 | |
--- a/src/client/legacy/connect/http.rs | |
+++ b/src/client/legacy/connect/http.rs | |
@@ -1020,6 +1020,7 @@ mod tests { | |
use ::http::Uri; | |
+ use crate::client::legacy::client::DefaultPoolKey; | |
use crate::client::legacy::connect::http::TcpKeepaliveConfig; | |
use super::super::sealed::{Connect, ConnectSvc}; | |
@@ -1030,7 +1031,7 @@ mod tests { | |
async fn connect<C>( | |
connector: C, | |
dst: Uri, | |
- ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> | |
+ ) -> Result<<C::_Svc as ConnectSvc<DefaultPoolKey>>::Connection, <C::_Svc as ConnectSvc>::Error> | |
where | |
C: Connect, | |
{ | |
diff --git a/src/client/legacy/connect/mod.rs b/src/client/legacy/connect/mod.rs | |
index 90a9767..32ba693 100644 | |
--- a/src/client/legacy/connect/mod.rs | |
+++ b/src/client/legacy/connect/mod.rs | |
@@ -304,10 +304,11 @@ pub(super) mod sealed { | |
use std::error::Error as StdError; | |
use std::future::Future; | |
- use ::http::Uri; | |
use hyper::rt::{Read, Write}; | |
use super::Connection; | |
+ use crate::client::legacy::client::DefaultPoolKey; | |
+ use crate::client::legacy::pool; | |
/// Connect to a destination, returning an IO transport. | |
/// | |
@@ -321,61 +322,72 @@ pub(super) mod sealed { | |
/// implement this trait, but `tower::Service<Uri>` instead. | |
// The `Sized` bound is to prevent creating `dyn Connect`, since they cannot | |
// fit the `Connect` bounds because of the blanket impl for `Service`. | |
- pub trait Connect: Sealed + Sized { | |
+ pub trait Connect<PK: pool::Key = DefaultPoolKey>: Sealed<PK> + Sized { | |
#[doc(hidden)] | |
- type _Svc: ConnectSvc; | |
+ type _Svc: ConnectSvc<PK>; | |
#[doc(hidden)] | |
- fn connect(self, internal_only: Internal, dst: Uri) -> <Self::_Svc as ConnectSvc>::Future; | |
+ fn connect( | |
+ self, | |
+ internal_only: Internal, | |
+ dst: PK::Connect, | |
+ ) -> <Self::_Svc as ConnectSvc<PK>>::Future; | |
} | |
- pub trait ConnectSvc { | |
+ pub trait ConnectSvc<PK: pool::Key = DefaultPoolKey> { | |
type Connection: Read + Write + Connection + Unpin + Send + 'static; | |
type Error: Into<Box<dyn StdError + Send + Sync>>; | |
type Future: Future<Output = Result<Self::Connection, Self::Error>> + Unpin + Send + 'static; | |
- fn connect(self, internal_only: Internal, dst: Uri) -> Self::Future; | |
+ fn connect(self, internal_only: Internal, dst: PK::Connect) -> Self::Future; | |
} | |
- impl<S, T> Connect for S | |
+ impl<S, T, PK> Connect<PK> for S | |
where | |
- S: tower_service::Service<Uri, Response = T> + Send + 'static, | |
+ PK: pool::Key, | |
+ S: tower_service::Service<PK::Connect, Response = T> + Send + 'static, | |
S::Error: Into<Box<dyn StdError + Send + Sync>>, | |
S::Future: Unpin + Send, | |
T: Read + Write + Connection + Unpin + Send + 'static, | |
{ | |
type _Svc = S; | |
- fn connect(self, _: Internal, dst: Uri) -> crate::service::Oneshot<S, Uri> { | |
+ fn connect(self, _: Internal, dst: PK::Connect) -> crate::service::Oneshot<S, PK::Connect> { | |
crate::service::Oneshot::new(self, dst) | |
} | |
} | |
- impl<S, T> ConnectSvc for S | |
+ impl<S, T, PK> ConnectSvc<PK> for S | |
where | |
- S: tower_service::Service<Uri, Response = T> + Send + 'static, | |
+ PK: pool::Key, | |
+ S: tower_service::Service<PK::Connect, Response = T> + Send + 'static, | |
S::Error: Into<Box<dyn StdError + Send + Sync>>, | |
S::Future: Unpin + Send, | |
T: Read + Write + Connection + Unpin + Send + 'static, | |
{ | |
type Connection = T; | |
type Error = S::Error; | |
- type Future = crate::service::Oneshot<S, Uri>; | |
+ type Future = crate::service::Oneshot<S, PK::Connect>; | |
- fn connect(self, _: Internal, dst: Uri) -> Self::Future { | |
+ fn connect(self, _: Internal, dst: PK::Connect) -> Self::Future { | |
crate::service::Oneshot::new(self, dst) | |
} | |
} | |
- impl<S, T> Sealed for S | |
+ impl<S, T, PK> Sealed<PK> for S | |
where | |
- S: tower_service::Service<Uri, Response = T> + Send, | |
+ PK: pool::Key, | |
+ S: tower_service::Service<PK::Connect, Response = T> + Send, | |
S::Error: Into<Box<dyn StdError + Send + Sync>>, | |
S::Future: Unpin + Send, | |
T: Read + Write + Connection + Unpin + Send + 'static, | |
{ | |
} | |
- pub trait Sealed {} | |
+ pub trait Sealed<PK> | |
+ where | |
+ PK: pool::Key, | |
+ { | |
+ } | |
#[allow(missing_debug_implementations)] | |
pub struct Internal; | |
} | |
diff --git a/src/client/legacy/pool.rs b/src/client/legacy/pool.rs | |
index 727f54b..8ffc356 100644 | |
--- a/src/client/legacy/pool.rs | |
+++ b/src/client/legacy/pool.rs | |
@@ -43,9 +43,16 @@ pub trait Poolable: Unpin + Send + Sized + 'static { | |
fn can_share(&self) -> bool; | |
} | |
-pub trait Key: Eq + Hash + Clone + Debug + Unpin + Send + 'static {} | |
- | |
-impl<T> Key for T where T: Eq + Hash + Clone + Debug + Unpin + Send + 'static {} | |
+/// Key under which pooled connections are grouped and looked up. | |
+/// | |
+/// When a new connection has to be established, the key is converted with | |
+/// `connect` into the destination handed to the connector. | |
+pub trait Key: Eq + Hash + Clone + Debug + Unpin + Send + 'static { | |
+    /// The destination type passed to the connector (e.g. `Uri`). | |
+    type Connect: Send + 'static; | |
+    /// Consume the key, producing the destination to connect to. | |
+    fn connect(self) -> Self::Connect; | |
+} | |
/// A marker to identify what version a pooled connection is. | |
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] | |
@@ -845,8 +846,17 @@ mod tests { | |
#[derive(Clone, Debug, PartialEq, Eq, Hash)] | |
struct KeyImpl(http::uri::Scheme, http::uri::Authority); | |
- | |
- type KeyTuple = (http::uri::Scheme, http::uri::Authority); | |
+ impl super::Key for KeyImpl { | |
+ type Connect = http::uri::Uri; | |
+ fn connect(self) -> http::uri::Uri { | |
+ http::uri::Builder::new() | |
+ .scheme(self.0) | |
+ .authority(self.1) | |
+ .path_and_query("/") | |
+ .build() | |
+ .expect("domain is valid Uri") | |
+ } | |
+ } | |
/// Test unique reservations. | |
#[derive(Debug, PartialEq, Eq)] |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.