-
-
Save petvana/9e583dcc6471368322afec63c0fa81e3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module testmacro | |
using IntervalArithmetic | |
"""
    replace_interval(ex, old::Symbol, new)

Return `ex` with every occurrence of the symbol `old` replaced by `new`.

A leaf (anything that is not an `Expr`) is returned unchanged unless it is the symbol
`old`, in which case `new` is returned. An `Expr` is rebuilt recursively into a fresh
`Expr`, so the expression passed in is never mutated.
"""
replace_interval(ex::Any, old::Symbol, new) = ex === old ? new : ex

# Recursively replace the `old` symbol by the `new` expression in `ex`.
function replace_interval(ex::Expr, old::Symbol, new)
    # Build a new node instead of writing into `ex.args`, so callers that still hold
    # `ex` (e.g. a macro reusing the original branch) keep seeing the original code.
    out = Expr(ex.head)
    for arg in ex.args
        push!(out.args, replace_interval(arg, old, new))
    end
    return out
end
"""
    @interval x > value ? a : b
    @interval if x > value
        a
    else
        b
    end

Rewrite a two-way branch on a comparison between the symbol `x` and `value` (one of
`<`, `<=`, `>`, `>=`) into a three-way branch:

1. if the comparison holds, evaluate `a` as written;
2. else if the negated comparison holds, evaluate `b` as written;
3. otherwise evaluate `union(a′, b′)`, where `a′` and `b′` are `a` and `b` with every
   occurrence of `x` replaced by `max(x, value)` / `min(x, value)` (which of the two
   goes to which branch depends on the direction of the comparison).

The third case is presumably meant for an interval `x` that straddles `value`, where
neither comparison is decided. `value` is spliced into the generated code several
times, so it should be free of side effects. The generated code is escaped as a whole,
so `union`, `max` and `min` are resolved in the caller's scope.

Throws an `ArgumentError` at macro-expansion time if the input does not have the
expected shape.
"""
macro interval(code)
    asserttext = """Unsupported usage of @interval.
        Expected format is: x > 1 ? f(x) : g(x)
        where x is Interval and 1 is number."""
    # Malformed input is a caller error; report it with an exception that cannot be
    # compiled out (unlike `@assert`).
    check(ok) = ok || throw(ArgumentError(asserttext))

    check(code isa Expr)
    check(code.head == :if)
    check(length(code.args) == 3)
    condition = code.args[1]
    check(condition isa Expr)
    check(condition.head == :call)
    check(length(condition.args) == 3)
    comp, variable, value = condition.args
    check(variable isa Symbol)
    check(comp in (:<, :<=, :>, :>=))

    neg_comp = Dict(:< => :>=, :<= => :>, :> => :<=, :>= => :<)[comp]
    then_branch = code.args[2]
    else_branch = code.args[3]
    # For `>`/`>=` the "then" side sees the part of `x` above `value`, the "else" side
    # the part below it; for `<`/`<=` the roles are swapped.
    m1, m2 = comp in (:>, :>=) ? (:max, :min) : (:min, :max)
    clamp1 = Expr(:call, m1, variable, value)
    clamp2 = Expr(:call, m2, variable, value)
    # Substitute on copies: the untouched branches are still needed verbatim for the two
    # definite cases below, and must not pick up the `max`/`min` clamping.
    ret1 = replace_interval(deepcopy(then_branch), variable, clamp1)
    ret2 = replace_interval(deepcopy(else_branch), variable, clamp2)

    return esc(quote
        if $(Expr(:call, comp, variable, value))
            $then_branch
        elseif $(Expr(:call, neg_comp, variable, value))
            $else_branch
        else
            union($ret1, $ret2)
        end
    end)
end
"""
    f(x)

Reference implementation using a plain ternary: subtract 1 from `x` when `x > 1`,
keep `x` otherwise, and return the square of the result.
"""
function f(x)
    y = x > 1 ? x - 1 : x
    return y^2
end
# Print `f` for a value below 1, a value above 1, and for `0.9..1.1`
# (`..` comes from IntervalArithmetic; presumably the interval [0.9, 1.1]).
@show f(0.9)
@show f(1.1)
@show f(0.9..1.1)
"""
    g(x)

Same computation as the plain-ternary version, but with the branch wrapped in
`@interval` using the ternary form, so the macro generates the extra `union` case for
inputs where neither `x > 1` nor `x <= 1` holds.
"""
function g(x)
    y = @interval x > 1 ? x - 1 : x
    return y^2
end
# Print `g` (ternary form of `@interval`) for the same three inputs: below 1, above 1,
# and `0.9..1.1` (presumably the interval [0.9, 1.1] from IntervalArithmetic).
@show g(0.9)
@show g(1.1)
@show g(0.9..1.1)
"""
    h(x)

Same computation as the ternary `@interval` version, but written with the block
`if`/`else` form to show that the macro accepts it as well.
"""
function h(x)
    y = @interval if x > 1
        x - 1
    else
        x
    end
    return y^2
end
# Print `h` (block `if` form of `@interval`) for the same three inputs: below 1,
# above 1, and `0.9..1.1` (presumably the interval [0.9, 1.1] from IntervalArithmetic).
@show h(0.9)
@show h(1.1)
@show h(0.9..1.1)
end |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment