Created
August 16, 2023 06:55
-
-
Save Narsil/c15cdc4c9b8efd77017fcecc99bb20bf to your computer and use it in GitHub Desktop.
Proposal: make `VarBuilder` more generic and implementable in user space.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Minimal tensor placeholder for this sketch: it carries only an integer `id`,
/// which is enough to check the order in which tensors are loaded.
struct Tensor {
    id: usize,
}
struct Linear{ | |
weight: Tensor, | |
bias: Option<Tensor>, | |
} | |
struct Mlp{ | |
in_proj: Linear, | |
out_proj: Linear | |
} | |
trait Loader{ | |
fn load_tensor(&mut self, prefix: &[&'static str]) -> Tensor; | |
} | |
trait Saver{ | |
fn save_tensor(&mut self, prefix: &[&'static str], tensor: &Tensor); | |
} | |
trait Module: Sized{ | |
fn load<L: Loader>(loader: &mut L) -> Self{ | |
let mut prefix = vec![]; | |
Self::load_nested(&mut prefix, loader) | |
} | |
fn load_nested<L: Loader>(prefix: &mut Vec<&'static str>, loader:&mut L) -> Self; | |
fn save<S: Saver>(&self, saver: &mut S){ | |
let mut prefix = vec![]; | |
self.save_nested(&mut prefix, saver); | |
} | |
fn save_nested<S: Saver>(&self, prefix: &mut Vec<&'static str>, saver:&mut S); | |
} | |
impl Module for Linear { | |
fn load_nested<L: Loader>(prefix: &mut Vec<&'static str>, loader: &mut L) -> Self{ | |
prefix.push("weight"); | |
let weight = loader.load_tensor(prefix); | |
prefix.pop(); | |
prefix.push("bias"); | |
let bias = Some(loader.load_tensor(prefix)); | |
prefix.pop(); | |
Self{ | |
weight, | |
bias | |
} | |
} | |
fn save_nested<S: Saver>(&self, prefix: &mut Vec<&'static str>, saver: &mut S){ | |
prefix.push("weight"); | |
saver.save_tensor(prefix, &self.weight); | |
prefix.pop(); | |
prefix.push("bias"); | |
saver.save_tensor(prefix, self.bias.as_ref().unwrap()); | |
prefix.pop(); | |
} | |
} | |
impl Module for Mlp { | |
fn load_nested<L: Loader>(prefix: &mut Vec<&'static str>, loader: &mut L) -> Self{ | |
prefix.push("in_proj"); | |
let in_proj = Linear::load_nested(prefix, loader); | |
prefix.pop(); | |
prefix.push("out_proj"); | |
let out_proj = Linear::load_nested(prefix, loader); | |
prefix.pop(); | |
Self{ | |
in_proj, | |
out_proj | |
} | |
} | |
fn save_nested<S: Saver>(&self, prefix: &mut Vec<&'static str>, saver: &mut S){ | |
prefix.push("weight"); | |
self.in_proj.save_nested(prefix, saver); | |
prefix.pop(); | |
prefix.push("bias"); | |
self.out_proj.save_nested(prefix, saver); | |
prefix.pop(); | |
} | |
} | |
/// Counter-based [`Loader`] state: `index` is the id given to the next tensor handed out.
struct Init {
    index: usize,
}
impl Loader for Init{ | |
fn load_tensor(&mut self, _prefix: &[&'static str]) -> Tensor{ | |
let tensor = Tensor{id: self.index}; | |
self.index += 1; | |
tensor | |
} | |
} | |
fn main() { | |
let mut init = Init{index:0}; | |
let linear = Linear::load(&mut init); | |
assert_eq!(linear.weight.id, 0); | |
assert_eq!(linear.bias.unwrap().id, 1); | |
let mut init = Init{index:0}; | |
let mlp = Mlp::load(&mut init); | |
assert_eq!(mlp.in_proj.weight.id, 0); | |
assert_eq!(mlp.in_proj.bias.unwrap().id, 1); | |
assert_eq!(mlp.out_proj.weight.id, 2); | |
assert_eq!(mlp.out_proj.bias.unwrap().id, 3); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment