Skip to content

Instantly share code, notes, and snippets.

@Narsil
Created August 16, 2023 06:55
Show Gist options
  • Save Narsil/c15cdc4c9b8efd77017fcecc99bb20bf to your computer and use it in GitHub Desktop.
Save Narsil/c15cdc4c9b8efd77017fcecc99bb20bf to your computer and use it in GitHub Desktop.
Proposal: make `VarBuilder` more generic and implementable in user space.
/// Stand-in for a real tensor in this sketch: it only carries an integer `id`.
///
/// The common traits are derived so tensors can be compared with `assert_eq!`,
/// printed in failure messages, and duplicated when a caller needs a copy.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Tensor {
    id: usize,
}
/// Parameters of a `Linear` layer: a `weight` tensor and an optional `bias` tensor.
struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
}
/// A module built from two `Linear` projections, `in_proj` and `out_proj`.
struct Mlp {
    in_proj: Linear,
    out_proj: Linear,
}
/// Source of tensors used by `Module::load`.
///
/// `prefix` is the path of component names leading to the requested tensor,
/// e.g. `["in_proj", "weight"]` when an `Mlp` loads its first weight.
trait Loader {
    /// Returns a tensor for the name path `prefix`. An implementation may
    /// ignore the path entirely (the `Init` loader does).
    fn load_tensor(&mut self, prefix: &[&'static str]) -> Tensor;
}
/// Sink for tensors written by `Module::save`.
trait Saver {
    /// Receives `tensor` together with the name path `prefix` it belongs under.
    fn save_tensor(&mut self, prefix: &[&'static str], tensor: &Tensor);
}
/// A model component that can be built from a `Loader` and written to a `Saver`.
///
/// Implementors provide the `*_nested` methods. The top-level `load` and `save`
/// entry points only start them off with an empty name path.
trait Module: Sized {
    /// Builds `Self` from `loader`, starting with an empty prefix.
    fn load<L: Loader>(loader: &mut L) -> Self {
        Self::load_nested(&mut Vec::new(), loader)
    }

    /// Builds `Self` with every tensor name scoped under `prefix`.
    ///
    /// The impls in this file push their own segment before each nested call and
    /// pop it afterwards. Implementations are expected to leave `prefix` as they
    /// found it, so sibling components see the correct path.
    fn load_nested<L: Loader>(prefix: &mut Vec<&'static str>, loader: &mut L) -> Self;

    /// Writes every tensor of `self` to `saver`, starting with an empty prefix.
    fn save<S: Saver>(&self, saver: &mut S) {
        self.save_nested(&mut Vec::new(), saver)
    }

    /// Writes every tensor of `self` to `saver`, scoped under `prefix`.
    ///
    /// This should use the same segment names as `load_nested`, so saved tensors
    /// can be found again on load.
    fn save_nested<S: Saver>(&self, prefix: &mut Vec<&'static str>, saver: &mut S);
}
impl Module for Linear {
    /// Loads `weight` and then `bias`, each scoped under its field name.
    ///
    /// The bias is always requested from the loader, so the result always has
    /// `bias: Some(_)`.
    fn load_nested<L: Loader>(prefix: &mut Vec<&'static str>, loader: &mut L) -> Self {
        prefix.push("weight");
        let weight = loader.load_tensor(prefix);
        prefix.pop();
        prefix.push("bias");
        let bias = Some(loader.load_tensor(prefix));
        prefix.pop();
        Self { weight, bias }
    }

    /// Saves `weight` and, when present, `bias`, under the same scopes
    /// `load_nested` reads from.
    ///
    /// A layer without a bias (`bias: None`) emits no `"bias"` entry. It used to
    /// panic on an `unwrap`.
    fn save_nested<S: Saver>(&self, prefix: &mut Vec<&'static str>, saver: &mut S) {
        prefix.push("weight");
        saver.save_tensor(prefix, &self.weight);
        prefix.pop();
        if let Some(bias) = self.bias.as_ref() {
            prefix.push("bias");
            saver.save_tensor(prefix, bias);
            prefix.pop();
        }
    }
}
impl Module for Mlp {
    /// Loads `in_proj` and then `out_proj`, scoping each projection's tensors
    /// under its field name (e.g. `["in_proj", "weight"]`).
    fn load_nested<L: Loader>(prefix: &mut Vec<&'static str>, loader: &mut L) -> Self {
        prefix.push("in_proj");
        let in_proj = Linear::load_nested(prefix, loader);
        prefix.pop();
        prefix.push("out_proj");
        let out_proj = Linear::load_nested(prefix, loader);
        prefix.pop();
        Self { in_proj, out_proj }
    }

    /// Saves `in_proj` and then `out_proj` under the same scopes `load_nested`
    /// reads from.
    fn save_nested<S: Saver>(&self, prefix: &mut Vec<&'static str>, saver: &mut S) {
        // These segment names must mirror `load_nested` exactly. The previous
        // "weight"/"bias" scopes produced keys such as `weight.weight` that the
        // loader never requests, so save and load did not round-trip.
        prefix.push("in_proj");
        self.in_proj.save_nested(prefix, saver);
        prefix.pop();
        prefix.push("out_proj");
        self.out_proj.save_nested(prefix, saver);
        prefix.pop();
    }
}
/// A `Loader` that ignores tensor names and hands out sequential ids,
/// starting from `index`.
struct Init {
    index: usize,
}
impl Loader for Init {
    /// Returns a tensor whose `id` is the current counter value, then advances
    /// the counter by one. The requested name path is ignored.
    fn load_tensor(&mut self, _prefix: &[&'static str]) -> Tensor {
        let id = self.index;
        self.index = id + 1;
        Tensor { id }
    }
}
fn main() {
let mut init = Init{index:0};
let linear = Linear::load(&mut init);
assert_eq!(linear.weight.id, 0);
assert_eq!(linear.bias.unwrap().id, 1);
let mut init = Init{index:0};
let mlp = Mlp::load(&mut init);
assert_eq!(mlp.in_proj.weight.id, 0);
assert_eq!(mlp.in_proj.bias.unwrap().id, 1);
assert_eq!(mlp.out_proj.weight.id, 2);
assert_eq!(mlp.out_proj.bias.unwrap().id, 3);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment