Skip to content

Instantly share code, notes, and snippets.

@JasonGross
Last active September 14, 2024 00:01
Show Gist options
  • Save JasonGross/8a70de4f5ae464fb9679c8acf1041963 to your computer and use it in GitHub Desktop.
Save JasonGross/8a70de4f5ae464fb9679c8acf1041963 to your computer and use it in GitHub Desktop.
Some notation for functional loops in Coq
*.aux
*.glob
Require Import List.
(** [fold_cps init l f k] left-folds [f] over [l], starting from the
    accumulator [init], and passes the final accumulator to the
    continuation [k]. *)
Definition fold_cps [R A B] (init : A) (l : list B) (f : A -> B -> A) (k : A -> R) : R :=
  let final_acc := fold_left f l init in
  k final_acc.
(* Sum of 0..9 handed to the identity continuation; evaluates to 45. *)
Compute fold_cps 0 (seq 0 10) Nat.add id.
Require Ltac2.Control.
Require Import Ltac2.Ltac2.
Require Import Ltac2.Notations.
Require Import Ltac2.Printf.
(** [replace_var_with_rel template_pair rels] rebuilds [template_pair], a
    (possibly nested) [@pair] whose leaves are named variables ([Var]s).
    It replaces the leaves, left to right, with successive terms of [rels].
    It fails via [Control.zero] with [Invalid_argument] when:
    - [rels] runs out before every leaf has been replaced;
    - entries of [rels] are left over after all leaves are replaced;
    - a node is neither a [Var] nor an [@pair] applied to exactly 4
      arguments. *)
Ltac2 replace_var_with_rel (template_pair : constr) (rels : constr list) :=
  let pair_c := '@pair in
  (* [go c rels] returns the rebuilt [c] together with the unused suffix of [rels]. *)
  let rec go (c : constr) (rels : constr list) : constr * constr list :=
    match Constr.Unsafe.kind c with
    | Constr.Unsafe.Var v
      => match rels with
         | rel :: rels
           => (rel, rels)
         | []
           => Control.zero (Invalid_argument (Some (fprintf "Not enough rels to replace all vars in %t, missing rel at %I" template_pair v)))
         end
    | Constr.Unsafe.App f args =>
      if Constr.equal f pair_c then
        (* Check the arity before indexing. A partially applied [@pair]
           would otherwise make [Array.get] go out of bounds instead of
           failing with a readable, backtrackable error (same guard as in
           [pair_to_list]). *)
        if Int.equal (Array.length args) 4 then
          let tA := Array.get args 0 in
          let tB := Array.get args 1 in
          let (a, rels) := go (Array.get args 2) rels in
          let (b, rels) := go (Array.get args 3) rels in
          let args := Array.of_list [tA; tB; a; b] in
          (Constr.Unsafe.make (Constr.Unsafe.App f args), rels)
        else
          Control.zero (Invalid_argument (Some (fprintf "Expected a %t with 4 arguments, got %t of length %i" pair_c c (Array.length args))))
      else
        Control.zero (Invalid_argument (Some (fprintf "Expected a pair, got %t" c)))
    | _
      => Control.zero (Invalid_argument (Some (fprintf "Expected a var or pair, got %t" c)))
    end in
  let (result, rels) := go template_pair rels in
  if Int.equal (List.length rels) 0 then
    result
  else
    Control.zero (Invalid_argument (Some (fprintf "Too many rels to replace all vars in %t, extra rels: %i" template_pair (List.length rels)))).
(* Marker that ends a loop body. Its argument [x] is a maximally inserted
   implicit, so a bare [continue] elaborates to [@continue _ ?x] with a fresh
   evar [?x]. [process_let_block_helper] requires that evar and builds the
   term that determines it. *)
Definition continue {T : Type} {x : T} : T := x.
(** [head c] is the function part of [c] when [c] is an application,
    and [c] itself otherwise. *)
Ltac2 head (c : constr) :=
  let view := Constr.Unsafe.kind c in
  match view with
  | Constr.Unsafe.App fn _ => fn
  | _ => c
  end.
(* [unify_then A x y pf] returns [y]; its real payload is the proof
   [pf : x = y]. When it is built with [pf := eq_refl x], type-checking the
   application requires [x] and [y] to be convertible. This is presumably how
   an evar in [x] gets instantiated with [y]. *)
Definition unify_then (A : Type) (x y : A) (pf : x = y) : A := y.
(** [process_let_block_helper template_pair binders_and_indices let_block]
    walks a block of [let ... in ...] bindings. It also descends into the
    branches of [match]/[if] and into the continuation (last argument) of
    nested [fold_cps] loops.
    [binders_and_indices] pairs each tracked loop-variable name with the
    term denoting its most recent value. When called from
    [process_let_block_tac] these terms start as the variables themselves.
    A named [let] that rebinds a tracked name switches its entry to [Rel 1].
    At the bottom the block must be [@continue T ?x] with [?x] an evar.
    That node is replaced by [unify_then T ?x result (@eq_refl T ?x)], where
    [result] is [template_pair] with its variable leaves replaced (left to
    right) by the tracked terms.
    Fails via [Control.zero] with [Invalid_argument] when:
    - the block ends in something other than [continue]/[fold_cps];
    - the argument of [continue] is not an evar;
    - a [fold_cps] continuation is not a [Lambda]. *)
Ltac2 rec process_let_block_helper (template_pair : constr) (binders_and_indices : (ident * constr) list) (let_block : constr) :=
  let continue_c := '@continue in
  let fold_cps_c := '@fold_cps in
  let assert_is_continue (c : constr) :=
    if Constr.equal (head c) continue_c then
      ()
    else
      Control.zero (Invalid_argument (Some (fprintf "The final value underneath lets should be headed by %t, not %t (head of %t)" continue_c (head c) c))) in
  match Constr.Unsafe.kind let_block with
  | Constr.Unsafe.LetIn b value let_block
    => (* Going under one binder: [liftn 1 1] shifts every free [Rel]
          (index >= 1) up by one, so existing entries keep pointing at the
          same binders. *)
       let binders_and_indices := List.map (fun (name, c) => (name, Constr.Unsafe.liftn 1 1 c)) binders_and_indices in
       let binders_and_indices
         := match Constr.Binder.name b with
            | Some new_name
              => (* This [let] rebinds a tracked name, so its latest value is
                    the binder just entered, [Rel 1] (de Bruijn indices are
                    1-based). The shift above has already been applied, so
                    [Rel 1] is installed as is. *)
                 List.map (fun (name, i) => if Ident.equal name new_name
                                             then (name, Constr.Unsafe.make (Constr.Unsafe.Rel 1)) else (name, i))
                          binders_and_indices
            | None =>
              (* Anonymous [let]: it cannot rebind a tracked name, so only the
                 shift above applies. *)
              binders_and_indices
            end in
       let new_let_block := process_let_block_helper template_pair binders_and_indices let_block in
       Constr.Unsafe.make (Constr.Unsafe.LetIn b value new_let_block)
  | Constr.Unsafe.Case c x iv y bl
    => (* Fields: (case, (constr * Binder.relevance), case_invert, constr, constr array).
          Every branch is processed with the same entries and no lifting.
          This suits branches that bind no variables, such as [if] on [bool].
          NOTE(review): a branch binding constructor arguments is presumably
          exposed as a [Lambda], which the catch-all case below rejects. *)
       let new_bl := Array.map (fun b => process_let_block_helper template_pair binders_and_indices b) bl in
       Constr.Unsafe.make (Constr.Unsafe.Case c x iv y new_bl)
  | Constr.Unsafe.App nesting_construct args
    => if Constr.equal nesting_construct fold_cps_c then
         (* Nested loop: the block continues inside the [fold_cps]
            continuation (its last argument), which must be a [Lambda]. *)
         let cont_idx := Int.sub (Array.length args) 1 in
         let cont := Array.get args cont_idx in
         let cont := match Constr.Unsafe.kind cont with
                     | Constr.Unsafe.Lambda b body
                       => let binders_and_indices := List.map (fun (name, c) => (name, Constr.Unsafe.liftn 1 1 c)) binders_and_indices in
                          let body := process_let_block_helper template_pair binders_and_indices body in
                          (* TODO: implement automatic shadowing of outer loop variables with inner loop variables *)
                          Constr.Unsafe.make (Constr.Unsafe.Lambda b body)
                     | _ =>
                       (* Report the offending continuation itself, not the whole block. *)
                       Control.zero (Invalid_argument (Some (fprintf "%t continuation should be a lambda not %t" fold_cps_c cont)))
                     end in
         (* [args] is the array returned by [Constr.Unsafe.kind]; it is
            updated in place and reused to rebuild the application. *)
         Array.set args cont_idx cont;
         Constr.Unsafe.make (Constr.Unsafe.App nesting_construct args)
       else
         (assert_is_continue let_block;
          let rels := List.map (fun (_, c) => c) binders_and_indices in
          let result := replace_var_with_rel template_pair rels in
          let final_continue_arg := Array.get args (Int.sub (Array.length args) 1) in
          (if Constr.is_evar final_continue_arg then () else Control.zero (Invalid_argument (Some (fprintf "final argument to %t (in %t) should be evar, not %t" continue_c let_block final_continue_arg))));
          let continue_ty := Array.get args 0 in
          let eq_refl_c := Constr.Unsafe.make (Constr.Unsafe.App '(@eq_refl) (Array.of_list [continue_ty; final_continue_arg])) in
          Constr.Unsafe.make (Constr.Unsafe.App '(@unify_then) (Array.of_list [continue_ty; final_continue_arg; result; eq_refl_c]))
         )
  | _ =>
    Control.zero (Invalid_argument (Some (fprintf "The final value underneath lets should be %t or %t applied to arguments, not %t" continue_c fold_cps_c let_block)))
  end.
(** [pair_to_list pair] flattens a (possibly nested) [@pair] whose leaves
    are named variables. The result lists those leaves left to right, each
    as [(name, leaf_term)] where [leaf_term] is the [Var] itself.
    It fails via [Control.zero] with [Invalid_argument] on any other node,
    including an [@pair] not applied to exactly 4 arguments. *)
Ltac2 rec pair_to_list (pair : constr) :=
  let pair_c := '@pair in
  match Constr.Unsafe.kind pair with
  | Constr.Unsafe.App f args =>
    (* Compare against the already-elaborated [pair_c] instead of
       elaborating the constructor a second time. *)
    if Constr.equal f pair_c then
      if Int.equal (Array.length args) 4 then
        let a := Array.get args 2 in
        let b := Array.get args 3 in
        List.append (pair_to_list a) (pair_to_list b)
      else
        Control.zero (Invalid_argument (Some (fprintf "Expected a %t with 4 arguments, got %t of length %i" pair_c pair (Array.length args))))
    else
      Control.zero (Invalid_argument (Some (fprintf "Expected a pair, got %t" pair)))
  | Constr.Unsafe.Var name => [(name, pair)]
  | _ => Control.zero (Invalid_argument (Some (fprintf "Expected a pair or var, got %t" pair)))
  end.
(** [process_let_block_tac init_indices let_block] proceeds in three steps:
    - rewrites [let_block] with [process_let_block_helper], tracking the
      variables listed by [pair_to_list init_indices];
    - type-checks the rewritten term with [Constr.Unsafe.check];
    - on success, unfolds [unify_then] and [continue] and beta-reduces with
      [cbv].
    A type-checking failure is re-raised with [Control.zero]. *)
Ltac2 process_let_block_tac (init_indices : constr) (let_block : constr) :=
  let tracked := pair_to_list init_indices in
  let rewritten := process_let_block_helper init_indices tracked let_block in
  match Constr.Unsafe.check rewritten with
  | Err err => Control.zero err
  | Val checked => eval cbv beta delta [unify_then continue] in $checked
  end.
(** [pretype_open_constr c] pretypes [c] with no expected type using the
    [open_constr_flags_no_tc] flags. As the flag name suggests, the result
    may still contain unresolved evars. *)
Ltac2 pretype_open_constr (c : preterm) :=
  let flags := Constr.Pretype.Flags.open_constr_flags_no_tc in
  let expectation := Constr.Pretype.expected_without_type_constraint in
  Constr.Pretype.pretype flags expectation c.
(* [process_let_block init_indices pair_fun let_block] pretypes
   [init_indices] and [let_block] as open terms with [pretype_open_constr].
   It runs [process_let_block_tac] on them and refines the current hole
   with the resulting term.
   NOTE(review): [pair_fun] is accepted but not used in the body. *)
Notation process_let_block init_indices pair_fun let_block := (ltac2:(
let result := process_let_block_tac (pretype_open_constr init_indices) (pretype_open_constr let_block) in (*printf "%t" result ;*)
Control.refine (fun () => result))) (only parsing).
(* Functional loop: [loop 'pat := init for x in l do block in cont].
   It expands to [fold_cps init l step (fun 'pat => cont)]:
   - at each step the accumulator is destructured by [pat] and [x] is bound
     to the current element of [l];
   - [block] is then rewritten by [process_let_block], which rebuilds the
     next accumulator in the shape of [pat] from the latest values of its
     names;
   - the final accumulator is destructured by [pat] again and passed to
     [cont]. *)
Notation "'loop' ' init_binder := init_value 'for' loop_var 'in' fold_list 'do' let_block 'in' cont" :=
(fold_cps init_value fold_list (fun 'init_binder loop_var => process_let_block init_binder (fun 'init_binder => init_binder) let_block) (fun 'init_binder => cont))
(init_binder pattern, at level 200, let_block at level 200, cont at level 200, only parsing).
(* Examples of the [loop] notation. They are elaborated inside a scratch
   goal so the generated terms can be inspected through [epose]/[pose]. *)
Goal False.
(* One accumulator: sums the elements of [seq 0 10]. *)
epose (loop 'sum := 0
for i in seq 0 10 do
let sum := sum + i in
continue
in sum).
(* NOTE(review): the anomaly below was recorded next to the example above;
   confirm whether it still reproduces on the Coq version in use.
Anomaly
"File "pretyping/glob_ops.ml", line 30, characters 11-17: Assertion failed."
Please report at http://coq.inria.fr/bugs/.
*)
(* Two accumulators destructured from a pair; each step rebinds both. *)
epose (loop '(sum, t) := (0, 0)
for i in seq 0 10 do
let sum := sum + i + t in
let t := if Nat.odd sum then 7 else 42 in
continue
in (sum, t)).
(* Nested loops: the inner loop adds entries 0..9 of row [i] of [A] into
   [si], and the outer loop adds each [si] to [s]. *)
pose (fun A =>
loop 's := 0
for i in seq 0 10 do (
loop 'si := 0
for j in seq 0 10 do (
let si := si + nth_default 0 (nth_default nil A i) j in
continue )
in
let s := s + si in
continue )
in s).
(* Close the scratch goal. It only served to elaborate the examples above;
   leaving it open would make the file end with a pending proof, and the
   following [Require] would run in proof mode. *)
Abort.
Require Import Coq.Arith.Arith.
(* Evaluates a loop whose body branches. For i < 10 it adds i. Otherwise
   it subtracts 1 on odd i (truncated nat subtraction) and adds 1 on even i. *)
Compute loop 'sum := 0
for i in seq 0 20 do
if (i <? 10)%nat then let sum := sum + i in continue
else if Nat.odd i then let sum := sum-1 in continue else let sum := sum+1 in continue
in sum.
(* Scratch experiments on a term with a shadowed binder, reduced with only
   [iota] and with only [beta]. *)
Eval cbv iota in (fun x x => x) 0.
Eval cbv beta in (fun x x => x) 0.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment