Skip to content

Instantly share code, notes, and snippets.

@collinvandyck
Created January 6, 2025 20:06
Show Gist options
  • Save collinvandyck/001c837d3a9f0a585b851ef246bfe350 to your computer and use it in GitHub Desktop.
Save collinvandyck/001c837d3a9f0a585b851ef246bfe350 to your computer and use it in GitHub Desktop.
#![allow(unused)]
use anyhow::{Result, bail};
use core::panic;
use itertools::Itertools;
/// Empty entry point: nothing in this binary runs outside the unit tests.
fn main() {}
/// A single token held on the evaluation stack used by [`eval`].
///
/// Both payloads are `Copy`, so the token itself is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Tok {
    /// A numeric literal, or the already-reduced value of a sub-expression.
    Num(f64),
    /// A binary operator character (`eval` only produces `+`, `-`, `*`, `/`).
    Op(char),
}
/// A stack of token groups, one group per nesting level; `eval` opens a new
/// group for every `(`. The outermost (root) group is created by [`Stack::new`].
///
/// Every method except `push_grp` works on the innermost (last) group and
/// panics if no group is left, i.e. once the root group itself has been popped.
struct Stack(Vec<Vec<Tok>>);

impl Stack {
    /// Creates a stack containing only an empty root group.
    fn new() -> Self {
        Stack(Vec::from([Vec::new()]))
    }

    /// Opens a new, empty innermost group.
    fn push_grp(&mut self) {
        self.0.push(Vec::new());
    }

    /// Removes and returns the innermost group (the root group included).
    fn pop_grp(&mut self) -> Vec<Tok> {
        self.0.pop().unwrap()
    }

    /// Appends `tok` to the end of the innermost group.
    fn push(&mut self, tok: Tok) {
        self.top_mut().push(tok);
    }

    /// Number of tokens currently in the innermost group.
    fn len(&self) -> usize {
        let top = self.0.last().unwrap();
        top.len()
    }

    /// Removes the first `n` tokens of the innermost group and returns them in
    /// order; panics if the group holds fewer than `n` tokens.
    fn drain(&mut self, n: usize) -> Vec<Tok> {
        self.top_mut().drain(..n).collect()
    }

    /// Mutable access to the innermost group; panics if every group was popped.
    fn top_mut(&mut self) -> &mut Vec<Tok> {
        self.0.last_mut().unwrap()
    }
}
/// Evaluates an arithmetic expression over `f64` values.
///
/// Accepted syntax: non-negative decimal literals (`12`, `3.5`), the binary
/// operators `+ - * /`, parentheses for grouping, and whitespace between
/// tokens. There is no unary minus, so `-3` is rejected.
///
/// Operators have NO relative precedence. Within each group the tokens are
/// reduced strictly left to right as soon as they form `number op number`, so
/// `1 + 2 * 3` evaluates to `9`; use parentheses to force another order.
/// Arithmetic follows plain `f64` semantics (e.g. `1 / 0` is infinity).
///
/// # Errors
///
/// Returns an error for an unexpected character, a malformed numeric literal
/// (e.g. `1.2.3`), unbalanced parentheses, a parenthesised group that does not
/// reduce to exactly one number, or any token sequence that is not of the form
/// `number (op number)*`.
fn eval(s: &str) -> Result<f64> {
    // Characters that may appear inside a numeric literal. ASCII only:
    // `char::is_numeric` would also accept e.g. '²', which `f64` cannot parse.
    let is_num_char = |c: char| c.is_ascii_digit() || c == '.';

    let chars: Vec<char> = s.chars().collect();
    let mut pos = 0;
    // Count of currently open `(` groups. It stops a stray `)` from popping the
    // root group (which would make the next `Stack::push` panic) and detects
    // groups that are never closed.
    let mut depth = 0usize;
    let mut stack = Stack::new();

    while pos < chars.len() {
        let ch = chars[pos];
        match ch {
            '(' => {
                stack.push_grp();
                depth += 1;
                pos += 1;
            }
            ')' => {
                if depth == 0 {
                    bail!("unbalanced ')' at position {pos}");
                }
                depth -= 1;
                let other = stack.pop_grp();
                // A closed group must have reduced to a single number; a lone
                // operator such as `(+)` is not a value.
                match other.as_slice() {
                    [Tok::Num(v)] => stack.push(Tok::Num(*v)),
                    _ => bail!("bad group expr: {other:?}"),
                }
                pos += 1;
            }
            c if is_num_char(c) => {
                let start = pos;
                while pos < chars.len() && is_num_char(chars[pos]) {
                    pos += 1;
                }
                let num: f64 = chars[start..pos].iter().collect::<String>().parse()?;
                stack.push(Tok::Num(num));
            }
            '+' | '-' | '*' | '/' => {
                stack.push(Tok::Op(ch));
                pos += 1;
            }
            c if c.is_whitespace() => pos += 1,
            _ => bail!("bad ch: {ch:?}"),
        }

        // Reduce eagerly: as soon as the innermost group holds three tokens
        // they must be `number op number`, which collapses into one number.
        // This eager reduction is what makes evaluation strictly left to right.
        if stack.len() >= 3 {
            match stack.drain(3).as_slice() {
                [Tok::Num(lhs), Tok::Op(op), Tok::Num(rhs)] => {
                    let num = match op {
                        '+' => lhs + rhs,
                        '-' => lhs - rhs,
                        '*' => lhs * rhs,
                        '/' => lhs / rhs,
                        _ => bail!("invalid op: {op:?}"),
                    };
                    stack.push(Tok::Num(num));
                }
                rest => bail!("invalid stack state: {rest:?}"),
            }
        }
    }

    if depth != 0 {
        bail!("unbalanced '(': {depth} group(s) never closed");
    }
    // Only the root group is left; it must hold exactly the final value.
    match stack.pop_grp().as_slice() {
        [Tok::Num(v)] => Ok(*v),
        xs => bail!("invalid state: {xs:?}"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_eval() {
        assert_eq!(eval("3").unwrap(), 3.0);
        assert_eq!(eval("3 + 4").unwrap(), 7.0);
        assert_eq!(eval("3 + 4 + 3").unwrap(), 10.0);
        assert_eq!(eval("3 * (4 + 3)").unwrap(), 21.0);
        assert_eq!(eval("3 * ((4/2) + 3)").unwrap(), 15.0);
    }

    /// All operators share one precedence level and reduce left to right.
    #[test]
    fn test_eval_left_to_right() {
        assert_eq!(eval("1 + 2 * 3").unwrap(), 9.0);
        assert_eq!(eval("10 - 4 - 3").unwrap(), 3.0);
        assert_eq!(eval("12 / 4 / 3").unwrap(), 1.0);
        assert_eq!(eval("12 + 30").unwrap(), 42.0);
    }

    /// Malformed input must produce an `Err`, never a value or a panic.
    #[test]
    fn test_eval_rejects_malformed_input() {
        for expr in ["", "3 +", "+ 3", "3 4", "()", "(3 +)", "x", "3 % 4"] {
            assert!(eval(expr).is_err(), "expected an error for {expr:?}");
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment