# Created June 22, 2021 13:54
# Gist: ChrisRackauckas/1665257903b62462f6f970682636c5a8
#
# hasbranching: automatically decide when ReverseDiff tape compilation can be specialized.
# A compiled tape replays the single control-flow path it recorded, so it is only valid
# for code whose execution does not take value-dependent branches.
#
# Note (GitHub): this file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open the file
# in an editor that reveals hidden Unicode characters.
using Cassette, DiffRules | |
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot | |
# When `true`, `_pass` prints the lowered IR of every method whose code contains a
# conditional jump (`GotoIfNot`).
const printbranch = true

# Cassette context used to trace a call and record whether any executed method branches.
Cassette.@context HasBranchingCtx

"""
    Cassette.overdub(ctx::HasBranchingCtx, f, args...)

Default tracing rule: when Cassette is able to recurse into `f(args...)`, keep tracing
inside it with `ctx`; otherwise call `f` through `Cassette.fallback`.
"""
function Cassette.overdub(ctx::HasBranchingCtx, f, args...)
    Cassette.canrecurse(ctx, f, args...) || return Cassette.fallback(ctx, f, args...)
    return Cassette.recurse(ctx, f, args...)
end
# Treat every primitive that has a DiffRules derivative rule as a leaf: call it directly
# instead of recursing into its implementation, so branches inside e.g. `exp` are not
# reported. A rule is skipped when its module is not loaded here, or when the loaded
# module does not define the function (for example, a different package version).
# Without the second guard, evaluating `$mod.$f` would throw `UndefVarError` at load time.
for (mod, f, n) in DiffRules.diffrules()
    isdefined(@__MODULE__, mod) || continue
    isdefined(getfield(@__MODULE__, mod), f) || continue
    @eval Cassette.overdub(::HasBranchingCtx, f::Core.Typeof($mod.$f), x::Vararg{Any, $n}) = f(x...)
end
# Rewrite one lowered statement so that its callee is not traced any further.
# Handles a bare `f(args...)` as well as `lhs = f(args...)`. Any other statement is
# returned unchanged. `:invoke` is excluded on purpose: its `args[1]` is a
# `MethodInstance`, not the callee, so it must not be wrapped in `:nooverdub`.
function _nooverdub_call(stmt::Expr)
    isassign = Meta.isexpr(stmt, :(=))
    call = isassign ? stmt.args[2] : stmt
    Meta.isexpr(call, :call) || return stmt
    newcall = Expr(:call, Expr(:nooverdub, call.args[1]), call.args[2:end]...)
    return isassign ? Expr(:(=), stmt.args[1], newcall) : newcall
end

"""
    _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection)

Cassette pass for `HasBranchingCtx`. If the method's lowered code contains a
`GotoIfNot`, the pass does three things:

- Optionally prints the IR (see `printbranch`).
- Prepends statements that set `metadata[:has_branching] = true` on the context.
- Marks every later call as `:nooverdub`. A branch has already been found, so there is
  no need to trace deeper.

Code without a `GotoIfNot` is returned untouched. Returns the (mutated) `CodeInfo`.
"""
function _pass(::Type{<:HasBranchingCtx}, reflection::Cassette.Reflection)
    ir = reflection.code_info
    any(x -> isa(x, GotoIfNot), ir.code) || return ir

    if printbranch
        # The C-level printer is used since passes run during code generation.
        # The message goes through "%s": printed IR contains `%1`-style SSA names
        # that must not be interpreted as printf conversions.
        msg = "GotoIfNot detected in $(reflection.method)\nir = $ir\n"
        ccall(:jl_safe_printf, Cvoid, (Cstring, Cstring...), "%s", msg)
    end

    # Statement 1 becomes:
    #   %1 = getfield(ctx, :metadata); setindex!(%1, true, :has_branching); <old stmt 1>
    Cassette.insert_statements!(
        ir.code, ir.codelocs,
        (stmt, i) -> i == 1 ? 3 : nothing,
        (stmt, i) -> Any[
            Expr(
                :call, Expr(:nooverdub, GlobalRef(Base, :getfield)),
                Expr(:contextslot), QuoteNode(:metadata),
            ),
            Expr(
                :call, Expr(:nooverdub, GlobalRef(Base, :setindex!)),
                SSAValue(1), true, QuoteNode(:has_branching),
            ),
            stmt,
        ],
    )

    # Skip the two statements inserted above (they are already `:nooverdub`) and stop
    # recursing into every remaining call of this method.
    Cassette.insert_statements!(
        ir.code, ir.codelocs,
        (stmt, i) -> i > 2 && isa(stmt, Expr) ? 1 : nothing,
        (stmt, i) -> Any[_nooverdub_call(stmt)],
    )
    return ir
end

const pass = Cassette.@pass _pass
"""
    hasbranching(f, x...) -> Bool

Trace the call `f(x...)` under `HasBranchingCtx` (with hooks disabled). Return `true`
if any method executed through the traced call graph contains a conditional jump.
Callees with a dedicated `Cassette.overdub` method that calls them directly are not
inspected.
"""
function hasbranching(f, x...)
    flags = Dict(:has_branching => false)
    ctx = HasBranchingCtx(; pass = pass, metadata = flags)
    Cassette.overdub(Cassette.disablehooks(ctx), f, x...)
    return flags[:has_branching]
end
# Call these functions directly (untraced), for any number of arguments, instead of
# recursing into their implementations, so branches inside them are not reported.
for op in (+, *, Base.materialize, Base.literal_pow, Base.getindex, Core.Typeof)
    @eval Cassette.overdub(::HasBranchingCtx, ::typeof($op), x...) = $op(x...)
end
# A ternary lowers to a conditional jump, so this call is expected to report `true`.
hasbranching((x, y) -> (x < 0 ? -x : x) + exp(y), 1, 2)

# `ifelse` evaluates both arms and selects one without a jump, so this call is
# expected to report `false`.
hasbranching((x, y) -> ifelse(x < 0, -x, x) + exp(y), 1, 2)
using DiffEqFlux | |
# Override FastDense to exclude its internal branch from the check. The pre-activation is
# computed directly (untraced), and only the activation `f.σ` is traced.
# NOTE(review): only the first `f.out * f.in` entries of `p` (reshaped into the weight
# matrix) are used. Any bias entries in `p` are ignored, which is presumably acceptable
# because only branching matters here, not the numeric result. Confirm against
# FastDense's own definition.
function Cassette.overdub(ctx::HasBranchingCtx, f::FastDense, x, p)
    nweights = f.out * f.in
    W = reshape(p[1:nweights], f.out, f.in)
    y = W * x
    return Cassette.@overdub ctx f.σ.(y)
end
# Example: check a small FastChain network (element-wise cube, then two dense layers).
u0 = Float32[2.0, 0.0]
dudt2 = FastChain(
    (x, p) -> x .^ 3,
    FastDense(2, 50, tanh),
    FastDense(50, 2),
)
p = initial_params(dudt2)
hasbranching(dudt2, u0, p)
# (GitHub) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.