Last active
September 14, 2024 00:01
-
-
Save JasonGross/8a70de4f5ae464fb9679c8acf1041963 to your computer and use it in GitHub Desktop.
Some notation for functional loops in Coq
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
*.aux | |
*.glob |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Require Import List. | |
(** [fold_cps init l f k] folds [f] over [l] from the left, starting from the
    accumulator [init], and passes the final accumulator to the continuation
    [k].  The type arguments [R], [A] and [B] are implicit. *)
Definition fold_cps [R A B] (init : A) (l : list B) (f : A -> B -> A) (k : A -> R) : R :=
  let folded := @fold_left A B f l init in
  k folded.

(* Sanity check: sums the numbers 0 through 9. *)
Compute fold_cps 0 (seq 0 10) Nat.add id.
Require Ltac2.Control. | |
Require Import Ltac2.Ltac2. | |
Require Import Ltac2.Notations. | |
Require Import Ltac2.Printf. | |
(* [replace_var_with_rel template_pair rels] rebuilds [template_pair], a
   (possibly nested) application of [@pair] whose leaves are variables, with
   each variable leaf replaced, from left to right, by the next term of [rels].
   The two type arguments of every [@pair] node are kept unchanged.

   Fails with [Invalid_argument] through [Control.zero] when:
   - [rels] runs out before every variable leaf has been replaced;
   - some terms of [rels] are left over after the traversal;
   - a node is neither a variable nor [@pair] applied to exactly 4 arguments. *)
Ltac2 replace_var_with_rel (template_pair : constr) (rels : constr list) :=
  let pair_c := '@pair in
  (* [go c rels] returns the rewritten [c] together with the unused rels. *)
  let rec go (c : constr) (rels : constr list) : constr * constr list :=
    match Constr.Unsafe.kind c with
    | Constr.Unsafe.Var v
      => match rels with
         | rel :: rels
           => (rel, rels)
         | []
           => Control.zero (Invalid_argument (Some (fprintf "Not enough rels to replace all vars in %t, missing rel at %I" template_pair v)))
         end
    | Constr.Unsafe.App f args =>
        if Constr.equal f pair_c then
          (* Check the arity before indexing, as [pair_to_list] does, so that a
             partially applied [@pair] is reported as [Invalid_argument]
             instead of failing inside [Array.get]. *)
          if Int.equal (Array.length args) 4 then
            let tA := Array.get args 0 in
            let tB := Array.get args 1 in
            (* Thread the remaining rels through the left component first. *)
            let (a, rels) := go (Array.get args 2) rels in
            let (b, rels) := go (Array.get args 3) rels in
            let args := Array.of_list [tA; tB; a; b] in
            (Constr.Unsafe.make (Constr.Unsafe.App f args), rels)
          else
            Control.zero (Invalid_argument (Some (fprintf "Expected a %t with 4 arguments, got %t of length %i" pair_c c (Array.length args))))
        else
          Control.zero (Invalid_argument (Some (fprintf "Expected a pair, got %t" c)))
    | _
      => Control.zero (Invalid_argument (Some (fprintf "Expected a var or pair, got %t" c)))
    end in
  let (result, rels) := go template_pair rels in
  (* Every rel must have been consumed by exactly one variable leaf. *)
  match rels with
  | [] => result
  | _ :: _
    => Control.zero (Invalid_argument (Some (fprintf "Too many rels to replace all vars in %t, extra rels: %i" template_pair (List.length rels))))
  end.
(** [continue] is the marker that ends a loop body.  Both of its arguments are
    implicit and it simply evaluates to [x]; [process_let_block_helper] looks
    for a term headed by [continue] at the bottom of a loop body and rewrites
    that occurrence. *)
Definition continue {T : Type} {x : T} : T := x.
(* [head c] returns the function part of [c] when [c] is an application, and
   [c] itself for any other kind of term. *)
Ltac2 head (c : constr) : constr :=
  let shape := Constr.Unsafe.kind c in
  match shape with
  | Constr.Unsafe.App hd _ => hd
  | _ => c
  end.
(** [unify_then A x y pf] evaluates to [y].  The proof [pf : x = y] is not used
    by the body; [process_let_block_helper] supplies [eq_refl] for it,
    presumably so that type checking the application forces [x] (an evar
    there) to be unified with [y]. *)
Definition unify_then (A : Type) (x y : A) (pf : x = y) : A := y.
(* [process_let_block_helper template_pair binders_and_indices let_block]
   walks a block of [let ... in ...], looking for assignments to the names
   listed in [binders_and_indices].  At the bottom of the block the
   [continue] marker is replaced by a term shaped like [template_pair] in which
   every variable is replaced by the most recently assigned version of that
   name.

   - [template_pair]: the (possibly nested) pair of variables giving the shape
     of the loop state.
   - [binders_and_indices]: for each tracked name, the term currently standing
     for it (initially the variable itself, later a de Bruijn index).
   - [let_block]: the term to rewrite.

   Supported shapes of [let_block]: a [let], a [match] (each branch is
   processed independently), an application of [fold_cps] (only the body of
   its last argument, which must be a lambda, is processed), or an application
   headed by [continue] whose last argument is an evar.  Anything else fails
   with [Invalid_argument] through [Control.zero]. *)
Ltac2 rec process_let_block_helper (template_pair : constr) (binders_and_indices : (ident * constr) list) (let_block : constr) :=
  let continue_c := '@continue in
  let fold_cps_c := '@fold_cps in
  (* Fails unless [c] is headed by [continue]. *)
  let assert_is_continue (c : constr) :=
    if Constr.equal (head c) continue_c then
      ()
    else
      Control.zero (Invalid_argument (Some (fprintf "The final value underneath lets should be headed by %t, not %t (head of %t)" continue_c (head c) c))) in
  match Constr.Unsafe.kind let_block with
  | Constr.Unsafe.LetIn b value let_block
    => (* we are going under a binder, so lift indices by 1 (* TODO: check that 1 1 is correct *) *)
      let binders_and_indices := List.map (fun (name, c) => (name, Constr.Unsafe.liftn 1 1 c)) binders_and_indices in
      let binders_and_indices
        := match Constr.Binder.name b with
           | Some new_name
             => (* if the bound name is one of the tracked names, that name now
                   refers to this binder, i.e. to de Bruijn index 1 in the body *)
               List.map (fun (name, i) => if Ident.equal name new_name
                                          then (name, Constr.Unsafe.make (Constr.Unsafe.Rel 1)) else (name, i))
                        binders_and_indices
           | None =>
               (* this let value is not named, so it is not one of the ones we
                  care about; the lifting above is all that is needed *)
               binders_and_indices
           end in
      let new_let_block := process_let_block_helper template_pair binders_and_indices let_block in
      Constr.Unsafe.make (Constr.Unsafe.LetIn b value new_let_block)
  | Constr.Unsafe.Case c x iv y bl (* (case, (constr * Binder.relevance), case_invert, constr, constr array) *)
    => (* every branch is rewritten with the same tracked names *)
      let new_bl := Array.map (fun b => process_let_block_helper template_pair binders_and_indices b) bl in
      Constr.Unsafe.make (Constr.Unsafe.Case c x iv y new_bl)
  | Constr.Unsafe.App nesting_construct args
    => if Constr.equal nesting_construct fold_cps_c then
         (* nested loop: only the continuation (the last argument) is processed *)
         let cont_idx := Int.sub (Array.length args) 1 in
         let cont := Array.get args cont_idx in
         let cont := match Constr.Unsafe.kind cont with
                     | Constr.Unsafe.Lambda b body
                       => (* going under the lambda binder, so lift as for a let *)
                         let binders_and_indices := List.map (fun (name, c) => (name, Constr.Unsafe.liftn 1 1 c)) binders_and_indices in
                         let body := process_let_block_helper template_pair binders_and_indices body in
                         (* TODO: implement automatic shadowing of outer loop variables with inner loop variables *)
                         Constr.Unsafe.make (Constr.Unsafe.Lambda b body)
                     | _
                       => (* report the offending continuation itself, not the whole application *)
                         Control.zero (Invalid_argument (Some (fprintf "%t continuation should be a lambda not %t" fold_cps_c cont)))
                     end in
         Array.set args cont_idx cont;
         Constr.Unsafe.make (Constr.Unsafe.App nesting_construct args)
       else
         (assert_is_continue let_block;
          let rels := List.map (fun (_, c) => c) binders_and_indices in
          let result := replace_var_with_rel template_pair rels in
          (* the last argument of [continue] must still be an evar; it is tied
             to [result] through [unify_then] and [eq_refl] below *)
          let final_continue_arg := Array.get args (Int.sub (Array.length args) 1) in
          (if Constr.is_evar final_continue_arg then () else Control.zero (Invalid_argument (Some (fprintf "final argument to %t (in %t) should be evar, not %t" continue_c let_block final_continue_arg))));
          let continue_ty := Array.get args 0 in
          let eq_refl_c := Constr.Unsafe.make (Constr.Unsafe.App '(@eq_refl) (Array.of_list [continue_ty; final_continue_arg])) in
          Constr.Unsafe.make (Constr.Unsafe.App '(@unify_then) (Array.of_list [continue_ty; final_continue_arg; result; eq_refl_c]))
         )
  | _ =>
      Control.zero (Invalid_argument (Some (fprintf "The final value underneath lets should be %t or %t applied to arguments, not %t" continue_c fold_cps_c let_block)))
  end.
(* [pair_to_list pair] flattens a (possibly nested) [@pair] of variables into
   the list of its variable leaves, from left to right.  Each leaf is returned
   as its identifier together with the variable term itself.

   Fails with [Invalid_argument] through [Control.zero] if a node is an
   application of something other than [@pair], a [@pair] that does not have
   exactly 4 arguments, or any term that is neither a variable nor an
   application. *)
Ltac2 rec pair_to_list (pair : constr) :=
  let pair_c := '@pair in
  match Constr.Unsafe.kind pair with
  | Constr.Unsafe.Var name => [(name, pair)]
  | Constr.Unsafe.App f args =>
      match Constr.equal f pair_c with
      | false =>
          Control.zero (Invalid_argument (Some (fprintf "Expected a pair, got %t" pair)))
      | true =>
          let arity := Array.length args in
          match Int.equal arity 4 with
          | true =>
              (* arguments 0 and 1 are the component types; 2 and 3 the components *)
              List.append (pair_to_list (Array.get args 2)) (pair_to_list (Array.get args 3))
          | false =>
              Control.zero (Invalid_argument (Some (fprintf "Expected a %t with 4 arguments, got %t of length %i" pair_c pair arity)))
          end
      end
  | _ => Control.zero (Invalid_argument (Some (fprintf "Expected a pair or var, got %t" pair)))
  end.
(* [process_let_block_tac init_indices let_block] rewrites [let_block] with
   [process_let_block_helper], using [init_indices] both as the template pair
   and, flattened by [pair_to_list], as the initial list of tracked names.
   The rewritten term is type checked with [Constr.Unsafe.check]; on success
   [unify_then] and [continue] are unfolded and beta-reduced away, on failure
   the checker's exception is re-raised through [Control.zero]. *)
Ltac2 process_let_block_tac (init_indices : constr) (let_block : constr) :=
  let tracked := pair_to_list init_indices in
  let rewritten := process_let_block_helper init_indices tracked let_block in
  match Constr.Unsafe.check rewritten with
  | Err exn => Control.zero exn
  | Val checked => eval cbv beta delta [unify_then continue] in $checked
  end.
(* [pretype_open_constr c] elaborates the preterm [c] into a term using the
   [open_constr_flags_no_tc] pretyping flags and no expected type. *)
Ltac2 pretype_open_constr (c : preterm) :=
  let flags := Constr.Pretype.Flags.open_constr_flags_no_tc in
  let expected := Constr.Pretype.expected_without_type_constraint in
  Constr.Pretype.pretype flags expected c.
(* [process_let_block init_indices pair_fun let_block] elaborates
   [init_indices] and [let_block] with [pretype_open_constr], rewrites the
   block with [process_let_block_tac] and refines the current hole with the
   result.  The [pair_fun] argument is accepted but not used by the
   expansion. *)
Notation process_let_block init_indices pair_fun let_block :=
  (ltac2:(
     let template := pretype_open_constr init_indices in
     let block := pretype_open_constr let_block in
     let processed := process_let_block_tac template block in
     Control.refine (fun () => processed)))
  (only parsing).
(* Surface syntax for functional loops.  The notation expands to [fold_cps]
   applied to:
   - [init_value], the initial loop state;
   - [fold_list], the list being iterated over;
   - a step function binding the state pattern [init_binder] and the loop
     variable [loop_var], whose body is [let_block] rewritten by
     [process_let_block];
   - a continuation binding [init_binder] again around [cont]. *)
Notation "'loop' ' init_binder := init_value 'for' loop_var 'in' fold_list 'do' let_block 'in' cont" :=
  (fold_cps
     init_value
     fold_list
     (fun 'init_binder loop_var =>
        process_let_block init_binder (fun 'init_binder => init_binder) let_block)
     (fun 'init_binder => cont))
  (init_binder pattern, at level 200, let_block at level 200, cont at level 200, only parsing).
(* Scratch goal used only to try out the [loop] notation; it is never proved
   and is aborted at the end. *)
Goal False.
  (* Single state variable: sum the elements of [seq 0 10]. *)
  epose (loop 'sum := 0
         for i in seq 0 10 do
           let sum := sum + i in
           continue
         in sum).
  (* Known problem recorded by the author for the next command:
     Anomaly: File pretyping/glob_ops.ml, line 30, characters 11-17:
     Assertion failed.
     Please report at http://coq.inria.fr/bugs/. *)
  (* Two state variables bound by a pair pattern. *)
  epose (loop '(sum, t) := (0, 0)
         for i in seq 0 10 do
           let sum := sum + i + t in
           let t := if Nat.odd sum then 7 else 42 in
           continue
         in (sum, t)).
  (* Nested loops: the inner loop sums one row, the outer loop sums the rows. *)
  pose (fun A =>
          loop 's := 0
          for i in seq 0 10 do (
            loop 'si := 0
            for j in seq 0 10 do (
              let si := si + nth_default 0 (nth_default nil A i) j in
              continue )
            in
            let s := s + si in
            continue )
          in s).
  Require Import Coq.Arith.Arith.
  (* Conditionals inside a loop body: every branch ends in [continue]. *)
  Compute loop 'sum := 0
          for i in seq 0 20 do
            if (i <? 10)%nat then let sum := sum + i in continue
            else if Nat.odd i then let sum := sum-1 in continue else let sum := sum+1 in continue
          in sum.
  (* Experiments with shadowed binder names under partial reduction. *)
  Eval cbv iota in (fun x x => x) 0.
  Eval cbv beta in (fun x x => x) 0.
(* Close the scratch goal so that the file does not end with a pending proof. *)
Abort.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment