diff --git a/Project.toml b/Project.toml index a93c816..6b4bdba 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DiffRules" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.0.2" +version = "1.1.0" [deps] NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" diff --git a/src/rules.jl b/src/rules.jl index e31f5d4..9c0fb90 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -60,10 +60,15 @@ @define_diffrule Base.deg2rad(x) = :( π / 180 ) @define_diffrule Base.mod2pi(x) = :( isinteger($x / 2pi) ? NaN : 1 ) @define_diffrule Base.rad2deg(x) = :( 180 / π ) + @define_diffrule SpecialFunctions.gamma(x) = :( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) ) @define_diffrule SpecialFunctions.loggamma(x) = :( SpecialFunctions.digamma($x) ) + +@define_diffrule Base.identity(x) = :( 1 ) +@define_diffrule Base.conj(x) = :( 1 ) +@define_diffrule Base.adjoint(x) = :( 1 ) @define_diffrule Base.transpose(x) = :( 1 ) @define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) ) @@ -87,12 +92,22 @@ else @define_diffrule Base.atan(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) ) end @define_diffrule Base.hypot(x, y) = :( $x / hypot($x, $y) ), :( $y / hypot($x, $y) ) +@define_diffrule Base.log(b, x) = :( log($x) * inv(-log($b)^2 * $b) ), :( inv($x) / log($b) ) + @define_diffrule Base.mod(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -floor(z)), NaN)) ) @define_diffrule Base.rem(x, y) = :( first(promote(ifelse(isinteger($x / $y), NaN, 1), NaN)) ), :( z = $x / $y; first(promote(ifelse(isinteger(z), NaN, -trunc(z)), NaN)) ) @define_diffrule Base.rem2pi(x, r) = :( 1 ), :NaN @define_diffrule Base.max(x, y) = :( $x > $y ? one($x) : zero($x) ), :( $x > $y ? zero($y) : one($y) ) @define_diffrule Base.min(x, y) = :( $x > $y ? zero($x) : one($x) ), :( $x > $y ? one($y) : zero($y) ) +# trinary # +#---------# + +@define_diffrule Base.muladd(x, y, z) = :($y), :($x), :(one($z)) +@define_diffrule Base.fma(x, y, z) = :($y), :($x), :(one($z)) + +@define_diffrule Base.ifelse(p, x, y) = false, :($p), :(!$p) + #################### # SpecialFunctions # #################### diff --git a/test/runtests.jl b/test/runtests.jl index ede8617..c41c29e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,7 +16,7 @@ function finitediff(f, x) end -non_numeric_arg_functions = [(:Base, :rem2pi, 2)] +non_numeric_arg_functions = [(:Base, :rem2pi, 2), (:Base, :ifelse, 3)] for (M, f, arity) in DiffRules.diffrules() (M, f, arity) ∈ non_numeric_arg_functions && continue @@ -46,6 +46,22 @@ for (M, f, arity) in DiffRules.diffrules() @test isapprox(dy, finitediff(z -> $M.$f(foo, z), bar), rtol=0.05) end end + elseif arity == 3 + @test DiffRules.hasdiffrule(M, f, 3) + derivs = DiffRules.diffrule(M, f, :foo, :bar, :goo) + @eval begin + foo, bar, goo = randn(3) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + if !(isnan(dx)) + @test isapprox(dx, finitediff(x -> $M.$f(x, bar, goo), foo), rtol=0.05) + end + if !(isnan(dy)) + @test isapprox(dy, finitediff(y -> $M.$f(foo, y, goo), bar), rtol=0.05) + end + if !(isnan(dz)) + @test isapprox(dz, finitediff(z -> $M.$f(foo, bar, z), goo), rtol=0.05) + end + end end end @@ -62,3 +78,17 @@ for xtype in [:Float64, :BigFloat, :Int64] end end end + +# Test ifelse separately as first argument is boolean +@test DiffRules.hasdiffrule(:Base, :ifelse, 3) +derivs = DiffRules.diffrule(:Base, :ifelse, :foo, :bar, :goo) +for cond in [true, false] + @eval begin + foo = $cond + bar, gee = randn(2) + dx, dy, dz = $(derivs[1]), $(derivs[2]), $(derivs[3]) + @test isapprox(dy, finitediff(y -> ifelse(foo, y, goo), bar), rtol=0.05) + @test isapprox(dz, finitediff(z -> ifelse(foo, bar, z), goo), rtol=0.05) + end +end +