Last active: May 5, 2024 19:52
Save sagoez/5b5a805050b694cc4092cd067c9c5048 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use crate::{ | |
domain::{Error, Unit}, | |
service::Cache as CacheConfig, | |
}; | |
use async_trait::async_trait; | |
use deadpool_redis::{ | |
redis::{ | |
cmd, from_redis_value, AsyncCommands, FromRedisValue, JsonAsyncCommands, RedisError, | |
RedisResult, RedisWrite, ToRedisArgs, | |
}, | |
Connection, Pool, PoolConfig, Runtime, | |
}; | |
use futures::Future; | |
use secrecy::ExposeSecret; | |
use serde::{Deserialize, Serialize}; | |
use serde_json::Value; | |
#[async_trait] | |
pub trait CacheExt { | |
async fn get_or_insert_with<F, Fut, T>( | |
&self, | |
key: &str, | |
f: F, | |
expire: Option<u64>, | |
) -> Result<T, Error> | |
where | |
T: for<'de> Deserialize<'de> + Serialize + Send + Clone, | |
F: FnOnce() -> Fut + Send, | |
Fut: Future<Output = Result<T, Error>> + Send; | |
async fn get<T>(&self, key: &str) -> Result<Option<T>, Error> | |
where | |
T: for<'de> Deserialize<'de>; | |
async fn set<T>(&self, key: &str, value: T, expire: Option<u64>) -> Result<Unit, Error> | |
where | |
T: Serialize + Send; | |
async fn remove(&self, key: &str) -> Result<Unit, Error>; | |
async fn clear(&self) -> Result<Unit, Error>; | |
} | |
#[derive(Debug, Clone)] | |
struct CacheEntry { | |
key: String, | |
value: Option<ValueWrapper>, | |
} | |
impl CacheEntry { | |
pub fn new(key: String, value: Value) -> Self { | |
Self { | |
key, | |
value: Some(ValueWrapper(value)), | |
} | |
} | |
fn get_as<T>(&self) -> Option<T> | |
where | |
T: for<'de> Deserialize<'de>, | |
{ | |
match &self.value { | |
Some(value) => serde_json::from_value(value.0.clone()).ok(), | |
None => None, | |
} | |
} | |
} | |
#[derive(Debug, Clone, Serialize, Deserialize)] | |
#[serde(transparent)] | |
pub struct ValueWrapper(Value); | |
impl ToRedisArgs for ValueWrapper { | |
fn write_redis_args<W>(&self, out: &mut W) | |
where | |
W: ?Sized + RedisWrite, | |
{ | |
let json_str = serde_json::to_string(&self).unwrap(); | |
out.write_arg(json_str.as_bytes()); | |
} | |
} | |
impl FromRedisValue for ValueWrapper { | |
fn from_redis_value(v: &deadpool_redis::redis::Value) -> RedisResult<Self> { | |
let json_str: Option<String> = from_redis_value(v)?; | |
match json_str { | |
Some(json_str) => { | |
let json: Value = serde_json::from_str(json_str.as_str()).map_err(|error| { | |
RedisError::from(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
error.to_string(), | |
)) | |
})?; | |
Ok(Self(json)) | |
} | |
None => Ok(Self(Value::Null)), | |
} | |
} | |
} | |
impl ToRedisArgs for CacheEntry { | |
fn write_redis_args<W>(&self, out: &mut W) | |
where | |
W: ?Sized + RedisWrite, | |
{ | |
let json = serde_json::json!( | |
{ | |
"key": self.key, | |
"value": self.value | |
} | |
); | |
let json_str = json.to_string(); | |
out.write_arg(json_str.as_bytes()); | |
} | |
} | |
impl FromRedisValue for CacheEntry { | |
fn from_redis_value(v: &deadpool_redis::redis::Value) -> RedisResult<Self> { | |
let json_str: Option<String> = from_redis_value(v)?; | |
match json_str { | |
Some(json_str) => { | |
let json: Value = serde_json::from_str(json_str.as_str()).map_err(|error| { | |
RedisError::from(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
error.to_string(), | |
)) | |
})?; | |
let key = json.get("key").ok_or(RedisError::from(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"key not found", | |
)))?; | |
let key = key.as_str().ok_or(RedisError::from(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"key not found", | |
)))?; | |
let key = key.to_string(); | |
let value = json | |
.get("value") | |
.ok_or(RedisError::from(std::io::Error::new( | |
std::io::ErrorKind::InvalidData, | |
"value not found", | |
)))?; | |
let value = value.clone(); | |
Ok(Self { | |
key, | |
value: Some(ValueWrapper(value)), | |
}) | |
} | |
None => Ok(Self { | |
key: Default::default(), | |
value: None, | |
}), | |
} | |
} | |
} | |
pub struct Cache { | |
pool: Pool, | |
} | |
impl Cache { | |
pub fn new(configuration: &CacheConfig) -> Result<Self, Error> { | |
let redis = deadpool_redis::Config { | |
url: Some(configuration.url().expose_secret().into()), | |
connection: None, | |
pool: Some(PoolConfig { | |
max_size: configuration.pool_size(), | |
timeouts: configuration.timeouts(), | |
..Default::default() | |
}), | |
}; | |
let pool = redis.create_pool(Some(Runtime::Tokio1)).map_err(|error| { | |
Error::cache_generic( | |
"There was an error with the configuration", | |
None, | |
None, | |
error.into(), | |
) | |
})?; | |
Ok(Self { pool }) | |
} | |
pub async fn connection(&self) -> Result<Connection, Error> { | |
let connection = self.pool.get().await.map_err(|error| { | |
Error::cache_generic( | |
"There was an error with the connection", | |
None, | |
None, | |
error.into(), | |
) | |
})?; | |
Ok(connection) | |
} | |
} | |
#[async_trait] | |
impl CacheExt for Cache { | |
#[tracing::instrument(name = "cache::get_or_insert_with", skip(self, key, f))] | |
async fn get_or_insert_with<F, Fut, T>( | |
&self, | |
key: &str, | |
f: F, | |
expire: Option<u64>, | |
) -> Result<T, Error> | |
where | |
T: for<'de> Deserialize<'de> + Serialize + Send + Clone, | |
F: FnOnce() -> Fut + Send, | |
Fut: Future<Output = Result<T, Error>> + Send, | |
{ | |
match self.get(key).await? { | |
Some(entry) => { | |
tracing::debug!("Cache hit for key: {}", key); | |
Ok(entry) | |
} | |
None => { | |
let value = f().await?; | |
self.set(key, value.clone(), expire).await?; | |
Ok(value) | |
} | |
} | |
} | |
#[tracing::instrument(name = "cache::get", skip(self, key))] | |
async fn get<T>(&self, key: &str) -> Result<Option<T>, Error> | |
where | |
T: for<'de> Deserialize<'de>, | |
{ | |
let entry: CacheEntry = self.connection().await?.get(key).await?; | |
match entry.value { | |
Some(_) => Ok(entry.get_as()), | |
None => Ok(None), | |
} | |
} | |
#[tracing::instrument(name = "cache::insert", skip(self, key, value, expire))] | |
async fn set<T>(&self, key: &str, value: T, expire: Option<u64>) -> Result<Unit, Error> | |
where | |
T: Serialize + Send, | |
{ | |
let entry = CacheEntry::new(key.to_string(), serde_json::to_value(value)?); | |
self.connection() | |
.await? | |
.set_ex::<_, CacheEntry, Option<String>>( | |
entry.key.clone(), | |
entry.clone(), | |
expire.unwrap_or(86400), | |
) | |
.await | |
.map(|_| ())?; | |
Ok(()) | |
} | |
#[tracing::instrument(name = "cache::remove", skip(self, key))] | |
async fn remove(&self, key: &str) -> Result<Unit, Error> { | |
self.connection() | |
.await? | |
.del::<_, i64>(key) | |
.await | |
.map(|_| ())?; | |
Ok(()) | |
} | |
#[tracing::instrument(name = "cache::clear", skip(self))] | |
async fn clear(&self) -> Result<Unit, Error> { | |
cmd("FLUSHALL") | |
.query_async(&mut *self.connection().await?) | |
.await | |
.map(|_: ()| ())?; | |
Ok(()) | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.