Skip to content

Instantly share code, notes, and snippets.

@devsnek
Created July 9, 2025 15:51
Show Gist options
  • Save devsnek/e4d381c7f2a1f0b0d392a132e9835ce7 to your computer and use it in GitHub Desktop.
Save devsnek/e4d381c7f2a1f0b0d392a132e9835ce7 to your computer and use it in GitHub Desktop.
diff --git a/src/client/legacy/client.rs b/src/client/legacy/client.rs
index c6dc778..905aa8d 100644
--- a/src/client/legacy/client.rs
+++ b/src/client/legacy/client.rs
@@ -8,6 +8,7 @@ use std::error::Error as StdError;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
+use std::sync::Arc;
use std::task::{self, Poll};
use std::time::Duration;
@@ -35,7 +36,7 @@ type BoxSendFuture = Pin<Box<dyn Future<Output = ()> + Send>>;
/// `Client` is cheap to clone and cloning is the recommended way to share a `Client`. The
/// underlying connection pool will be reused.
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
-pub struct Client<C, B> {
+pub struct Client<C, B, PK: pool::Key = DefaultPoolKey> {
config: Config,
connector: C,
exec: Exec,
@@ -43,7 +44,8 @@ pub struct Client<C, B> {
h1_builder: hyper::client::conn::http1::Builder,
#[cfg(feature = "http2")]
h2_builder: hyper::client::conn::http2::Builder<Exec>,
- pool: pool::Pool<PoolClient<B>, PoolKey>,
+ pool_key: Arc<dyn Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static>,
+ pool: pool::Pool<PoolClient<B>, PK>,
}
#[derive(Clone, Copy, Debug)]
@@ -90,7 +92,19 @@ macro_rules! e {
}
// We might change this... :shrug:
-type PoolKey = (http::uri::Scheme, http::uri::Authority);
+/// The pool key used by default: a URI scheme paired with an authority.
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+pub struct DefaultPoolKey(http::uri::Scheme, http::uri::Authority);
+
+impl pool::Key for DefaultPoolKey {
+    /// The connector is handed a `Uri` assembled from this key.
+    type Connect = Uri;
+
+    /// Consume the key and build the `<scheme>://<authority>/` destination `Uri`.
+    fn connect(self) -> Uri {
+        let Self(scheme, authority) = self;
+        let destination = http::uri::Builder::new()
+            .scheme(scheme)
+            .authority(authority)
+            .path_and_query("/")
+            .build();
+        // Both parts are already-parsed values and the path is the fixed "/",
+        // so a build failure here would indicate a broken invariant.
+        destination.expect("domain is valid Uri")
+    }
+}
enum TrySendError<B> {
Retryable {
@@ -143,12 +157,13 @@ impl Client<(), ()> {
}
}
-impl<C, B> Client<C, B>
+impl<C, B, PK> Client<C, B, PK>
where
- C: Connect + Clone + Send + Sync + 'static,
+ C: Connect<PK> + Clone + Send + Sync + 'static,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
+ PK: pool::Key,
{
/// Send a `GET` request to the supplied `Uri`.
///
@@ -214,27 +229,15 @@ where
/// # }
/// # fn main() {}
/// ```
-    pub fn request(&self, mut req: Request<B>) -> ResponseFuture {
+    pub fn request(&self, req: Request<B>) -> ResponseFuture {
+        // Validate the method/version for every request, independent of how the
+        // pool key is computed, so a custom `pool_key` function cannot skip it.
         let is_http_connect = req.method() == Method::CONNECT;
         match req.version() {
             Version::HTTP_11 => (),
             Version::HTTP_10 => {
                 if is_http_connect {
                     warn!("CONNECT is not allowed for HTTP/1.0");
                     return ResponseFuture::new(future::err(e!(UserUnsupportedRequestMethod)));
                 }
             }
             Version::HTTP_2 => (),
             // completely unsupported HTTP version (like HTTP/0.9)!
-            other => return ResponseFuture::error_version(other),
+            other => {
+                warn!("Request has unsupported version \"{:?}\"", other);
+                return ResponseFuture::new(future::err(e!(UserUnsupportedVersion)));
+            }
         };
-        let pool_key = match extract_domain(req.uri_mut(), is_http_connect) {
+        // The key function receives the request head mutably; the request is
+        // reassembled from the same parts afterwards.
+        let (mut parts, body) = req.into_parts();
+        let pool_key = match (self.pool_key)(&mut parts) {
             Ok(s) => s,
             Err(err) => {
                 return ResponseFuture::new(future::err(err));
             }
         };
+        let req = Request::from_parts(parts, body);
         ResponseFuture::new(self.clone().send_request(req, pool_key))
     }
@@ -242,12 +245,13 @@ where
async fn send_request(
self,
mut req: Request<B>,
- pool_key: PoolKey,
+ pool_key: PK,
) -> Result<Response<hyper::body::Incoming>, Error> {
let uri = req.uri().clone();
loop {
- req = match self.try_send_request(req, pool_key.clone()).await {
+ let pk: PK = pool_key.clone();
+ req = match self.try_send_request(req, pk).await {
Ok(resp) => return Ok(resp),
Err(TrySendError::Nope(err)) => return Err(err),
Err(TrySendError::Retryable {
@@ -275,7 +279,7 @@ where
async fn try_send_request(
&self,
mut req: Request<B>,
- pool_key: PoolKey,
+ pool_key: PK,
) -> Result<Response<hyper::body::Incoming>, TrySendError<B>> {
let mut pooled = self
.connection_for(pool_key)
@@ -366,12 +370,10 @@ where
Ok(res)
}
- async fn connection_for(
- &self,
- pool_key: PoolKey,
- ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, Error> {
+ async fn connection_for(&self, pool_key: PK) -> Result<pool::Pooled<PoolClient<B>, PK>, Error> {
loop {
- match self.one_connection_for(pool_key.clone()).await {
+ let pk: PK = pool_key.clone();
+ match self.one_connection_for(pk).await {
Ok(pooled) => return Ok(pooled),
Err(ClientConnectError::Normal(err)) => return Err(err),
Err(ClientConnectError::CheckoutIsClosed(reason)) => {
@@ -391,8 +393,8 @@ where
async fn one_connection_for(
&self,
- pool_key: PoolKey,
- ) -> Result<pool::Pooled<PoolClient<B>, PoolKey>, ClientConnectError> {
+ pool_key: PK,
+ ) -> Result<pool::Pooled<PoolClient<B>, PK>, ClientConnectError> {
// Return a single connection if pooling is not enabled
if !self.pool.is_enabled() {
return self
@@ -484,9 +486,8 @@ where
#[cfg(any(feature = "http1", feature = "http2"))]
fn connect_to(
&self,
- pool_key: PoolKey,
- ) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PoolKey>, Error>> + Send + Unpin
- {
+ pool_key: PK,
+ ) -> impl Lazy<Output = Result<pool::Pooled<PoolClient<B>, PK>, Error>> + Send + Unpin {
let executor = self.exec.clone();
let pool = self.pool.clone();
#[cfg(feature = "http1")]
@@ -496,7 +497,6 @@ where
let ver = self.config.ver;
let is_ver_h2 = ver == Ver::Http2;
let connector = self.connector.clone();
- let dst = domain_as_uri(pool_key.clone());
hyper_lazy(move || {
// Try to take a "connecting lock".
//
@@ -514,7 +514,7 @@ where
};
Either::Left(
connector
- .connect(super::connect::sealed::Internal, dst)
+ .connect(super::connect::sealed::Internal, pool_key.connect())
.map_err(|src| e!(Connect, src))
.and_then(move |io| {
let connected = io.connected();
@@ -669,7 +669,7 @@ where
impl<C, B> tower_service::Service<Request<B>> for Client<C, B>
where
- C: Connect + Clone + Send + Sync + 'static,
+ C: Connect<DefaultPoolKey> + Clone + Send + Sync + 'static,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
@@ -689,7 +689,7 @@ where
impl<C, B> tower_service::Service<Request<B>> for &'_ Client<C, B>
where
- C: Connect + Clone + Send + Sync + 'static,
+ C: Connect<DefaultPoolKey> + Clone + Send + Sync + 'static,
B: Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
@@ -707,8 +707,8 @@ where
}
}
-impl<C: Clone, B> Clone for Client<C, B> {
- fn clone(&self) -> Client<C, B> {
+impl<C: Clone, B, PK: pool::Key> Clone for Client<C, B, PK> {
+ fn clone(&self) -> Client<C, B, PK> {
Client {
config: self.config,
exec: self.exec.clone(),
@@ -717,6 +717,7 @@ impl<C: Clone, B> Clone for Client<C, B> {
#[cfg(feature = "http2")]
h2_builder: self.h2_builder.clone(),
connector: self.connector.clone(),
+ pool_key: self.pool_key.clone(),
pool: self.pool.clone(),
}
}
@@ -739,11 +740,6 @@ impl ResponseFuture {
inner: SyncWrapper::new(Box::pin(value)),
}
}
-
- fn error_version(ver: Version) -> Self {
- warn!("Request has unsupported version \"{:?}\"", ver);
- ResponseFuture::new(Box::pin(future::err(e!(UserUnsupportedVersion))))
- }
}
impl fmt::Debug for ResponseFuture {
@@ -937,10 +933,31 @@ fn authority_form(uri: &mut Uri) {
};
}
-fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error> {
+/// Compute the default pool key for a request head.
+///
+/// Fails for `CONNECT` over HTTP/1.0 and for any version other than HTTP/1.0,
+/// HTTP/1.1 or HTTP/2; otherwise defers to `extract_domain`.
+fn default_pool_key(req: &mut http::request::Parts) -> Result<DefaultPoolKey, Error> {
+    let is_connect = req.method == Method::CONNECT;
+    match req.version {
+        Version::HTTP_10 if is_connect => {
+            warn!("CONNECT is not allowed for HTTP/1.0");
+            return Err(e!(UserUnsupportedRequestMethod));
+        }
+        Version::HTTP_10 | Version::HTTP_11 | Version::HTTP_2 => {}
+        // completely unsupported HTTP version (like HTTP/0.9)!
+        unsupported => {
+            warn!("Request has unsupported version \"{:?}\"", unsupported);
+            return Err(e!(UserUnsupportedVersion));
+        }
+    }
+
+    extract_domain(&mut req.uri, is_connect)
+}
+
+fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<DefaultPoolKey, Error> {
let uri_clone = uri.clone();
match (uri_clone.scheme(), uri_clone.authority()) {
- (Some(scheme), Some(auth)) => Ok((scheme.clone(), auth.clone())),
+ (Some(scheme), Some(auth)) => Ok(DefaultPoolKey(scheme.clone(), auth.clone())),
(None, Some(auth)) if is_http_connect => {
let scheme = match auth.port_u16() {
Some(443) => {
@@ -952,7 +969,7 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error
Scheme::HTTP
}
};
- Ok((scheme, auth.clone()))
+ Ok(DefaultPoolKey(scheme, auth.clone()))
}
_ => {
debug!("Client requires absolute-form URIs, received: {:?}", uri);
@@ -961,15 +978,6 @@ fn extract_domain(uri: &mut Uri, is_http_connect: bool) -> Result<PoolKey, Error
}
}
-fn domain_as_uri((scheme, auth): PoolKey) -> Uri {
- http::uri::Builder::new()
- .scheme(scheme)
- .authority(auth)
- .path_and_query("/")
- .build()
- .expect("domain is valid Uri")
-}
-
fn set_scheme(uri: &mut Uri, scheme: Scheme) {
debug_assert!(
uri.scheme().is_none(),
@@ -1589,11 +1597,27 @@ impl Builder {
}
/// Combine the configuration of this builder with a connector to create a `Client`.
- pub fn build<C, B>(&self, connector: C) -> Client<C, B>
+    pub fn build<C, B>(&self, connector: C) -> Client<C, B, DefaultPoolKey>
+    where
+        C: Connect<DefaultPoolKey> + Clone,
+        B: Body + Send,
+        B::Data: Send,
+    {
+        // Delegates to `build_with_pool_key`, keyed by `default_pool_key`.
+        self.build_with_pool_key::<C, B, DefaultPoolKey>(connector, default_pool_key)
+    }
+
+    /// Combine the configuration of this builder with a connector to create a `Client` that pools
+    /// connections under a custom key. `pool_key` is called with each request's parts to compute
+    /// that key; if it returns an error, the request fails with that error.
+ pub fn build_with_pool_key<C, B, PK>(
+ &self,
+ connector: C,
+ pool_key: impl Fn(&mut http::request::Parts) -> Result<PK, Error> + Send + Sync + 'static,
+ ) -> Client<C, B, PK>
where
- C: Connect + Clone,
+ C: Connect<PK> + Clone,
B: Body + Send,
B::Data: Send,
+ PK: pool::Key,
{
let exec = self.exec.clone();
let timer = self.pool_timer.clone();
@@ -1605,7 +1629,8 @@ impl Builder {
#[cfg(feature = "http2")]
h2_builder: self.h2_builder.clone(),
connector,
- pool: pool::Pool::new(self.pool_config, exec, timer),
+ pool_key: Arc::new(pool_key),
+ pool: pool::Pool::<_, PK>::new(self.pool_config, exec, timer),
}
}
}
diff --git a/src/client/legacy/connect/http.rs b/src/client/legacy/connect/http.rs
index f19a78e..64592f2 100644
--- a/src/client/legacy/connect/http.rs
+++ b/src/client/legacy/connect/http.rs
@@ -1020,6 +1020,7 @@ mod tests {
use ::http::Uri;
+ use crate::client::legacy::client::DefaultPoolKey;
use crate::client::legacy::connect::http::TcpKeepaliveConfig;
use super::super::sealed::{Connect, ConnectSvc};
@@ -1030,7 +1031,7 @@ mod tests {
async fn connect<C>(
connector: C,
dst: Uri,
- ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
+ ) -> Result<<C::_Svc as ConnectSvc<DefaultPoolKey>>::Connection, <C::_Svc as ConnectSvc>::Error>
where
C: Connect,
{
diff --git a/src/client/legacy/connect/mod.rs b/src/client/legacy/connect/mod.rs
index 90a9767..32ba693 100644
--- a/src/client/legacy/connect/mod.rs
+++ b/src/client/legacy/connect/mod.rs
@@ -304,10 +304,11 @@ pub(super) mod sealed {
use std::error::Error as StdError;
use std::future::Future;
- use ::http::Uri;
use hyper::rt::{Read, Write};
use super::Connection;
+ use crate::client::legacy::client::DefaultPoolKey;
+ use crate::client::legacy::pool;
/// Connect to a destination, returning an IO transport.
///
@@ -321,61 +322,72 @@ pub(super) mod sealed {
/// implement this trait, but `tower::Service<Uri>` instead.
// The `Sized` bound is to prevent creating `dyn Connect`, since they cannot
// fit the `Connect` bounds because of the blanket impl for `Service`.
- pub trait Connect: Sealed + Sized {
+ pub trait Connect<PK: pool::Key = DefaultPoolKey>: Sealed<PK> + Sized {
#[doc(hidden)]
- type _Svc: ConnectSvc;
+ type _Svc: ConnectSvc<PK>;
#[doc(hidden)]
- fn connect(self, internal_only: Internal, dst: Uri) -> <Self::_Svc as ConnectSvc>::Future;
+ fn connect(
+ self,
+ internal_only: Internal,
+ dst: PK::Connect,
+ ) -> <Self::_Svc as ConnectSvc<PK>>::Future;
}
- pub trait ConnectSvc {
+ pub trait ConnectSvc<PK: pool::Key = DefaultPoolKey> {
type Connection: Read + Write + Connection + Unpin + Send + 'static;
type Error: Into<Box<dyn StdError + Send + Sync>>;
type Future: Future<Output = Result<Self::Connection, Self::Error>> + Unpin + Send + 'static;
- fn connect(self, internal_only: Internal, dst: Uri) -> Self::Future;
+ fn connect(self, internal_only: Internal, dst: PK::Connect) -> Self::Future;
}
- impl<S, T> Connect for S
+ impl<S, T, PK> Connect<PK> for S
where
- S: tower_service::Service<Uri, Response = T> + Send + 'static,
+ PK: pool::Key,
+ S: tower_service::Service<PK::Connect, Response = T> + Send + 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
S::Future: Unpin + Send,
T: Read + Write + Connection + Unpin + Send + 'static,
{
type _Svc = S;
- fn connect(self, _: Internal, dst: Uri) -> crate::service::Oneshot<S, Uri> {
+ fn connect(self, _: Internal, dst: PK::Connect) -> crate::service::Oneshot<S, PK::Connect> {
crate::service::Oneshot::new(self, dst)
}
}
- impl<S, T> ConnectSvc for S
+ impl<S, T, PK> ConnectSvc<PK> for S
where
- S: tower_service::Service<Uri, Response = T> + Send + 'static,
+ PK: pool::Key,
+ S: tower_service::Service<PK::Connect, Response = T> + Send + 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
S::Future: Unpin + Send,
T: Read + Write + Connection + Unpin + Send + 'static,
{
type Connection = T;
type Error = S::Error;
- type Future = crate::service::Oneshot<S, Uri>;
+ type Future = crate::service::Oneshot<S, PK::Connect>;
- fn connect(self, _: Internal, dst: Uri) -> Self::Future {
+ fn connect(self, _: Internal, dst: PK::Connect) -> Self::Future {
crate::service::Oneshot::new(self, dst)
}
}
- impl<S, T> Sealed for S
+ impl<S, T, PK> Sealed<PK> for S
where
- S: tower_service::Service<Uri, Response = T> + Send,
+ PK: pool::Key,
+ S: tower_service::Service<PK::Connect, Response = T> + Send,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
S::Future: Unpin + Send,
T: Read + Write + Connection + Unpin + Send + 'static,
{
}
- pub trait Sealed {}
+    /// Marker supertrait of `Connect`, parameterized by the pool key type.
+    pub trait Sealed<PK: pool::Key> {}
#[allow(missing_debug_implementations)]
pub struct Internal;
}
diff --git a/src/client/legacy/pool.rs b/src/client/legacy/pool.rs
index 727f54b..8ffc356 100644
--- a/src/client/legacy/pool.rs
+++ b/src/client/legacy/pool.rs
@@ -43,9 +43,10 @@ pub trait Poolable: Unpin + Send + Sized + 'static {
fn can_share(&self) -> bool;
}
-pub trait Key: Eq + Hash + Clone + Debug + Unpin + Send + 'static {}
-
-impl<T> Key for T where T: Eq + Hash + Clone + Debug + Unpin + Send + 'static {}
+/// A key under which pooled connections are stored.
+pub trait Key: Eq + Hash + Clone + Debug + Unpin + Send + 'static {
+    /// The value produced by `connect`; the client passes it to the connector
+    /// as the destination.
+    type Connect: Send + 'static;
+    /// Consume the key, producing its `Connect` value.
+    fn connect(self) -> Self::Connect;
+}
/// A marker to identify what version a pooled connection is.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -845,8 +846,17 @@ mod tests {
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct KeyImpl(http::uri::Scheme, http::uri::Authority);
-
- type KeyTuple = (http::uri::Scheme, http::uri::Authority);
+    impl super::Key for KeyImpl {
+        type Connect = http::uri::Uri;
+
+        /// Assemble `<scheme>://<authority>/` from the key's two fields.
+        fn connect(self) -> http::uri::Uri {
+            let KeyImpl(scheme, authority) = self;
+            http::uri::Builder::new()
+                .scheme(scheme)
+                .authority(authority)
+                .path_and_query("/")
+                .build()
+                .expect("domain is valid Uri")
+        }
+    }
/// Test unique reservations.
#[derive(Debug, PartialEq, Eq)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment