From 48d45dcfb19d3897a15b5b36a2c43150ce9e0cef Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Jun 2021 11:40:06 +0100 Subject: [PATCH 1/6] fix incorrected frule fallback --- src/ruleset_loading.jl | 3 ++- test/ruleset_loading.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index 876b8b8..220df6c 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -58,7 +58,8 @@ function _rule_list end _rule_list(rule_kind) = (m for m in methods(rule_kind) if !_is_fallback(rule_kind, m)) "check if this is the fallback-frule/rrule that always returns `nothing`" -_is_fallback(rule_kind, m::Method) = m.sig === Tuple{typeof(rule_kind), Any, Vararg{Any}} +_is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Vararg{Any}} +_is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}} const LAST_REFRESH_RRULE = Ref(0) const LAST_REFRESH_FRULE = Ref(0) diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 4f3b02d..e7e7a71 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -73,6 +73,6 @@ @testset "_is_fallback" begin _is_fallback = ChainRulesOverloadGeneration._is_fallback @test _is_fallback(rrule, first(methods(rrule, (Nothing,)))) - @test _is_fallback(frule, first(methods(frule, (Nothing,)))) + @test _is_fallback(frule, first(methods(frule, (Tuple{}, Nothing,)))) end end From 5468d285a8b81f98c970fe550095959e5aefef94 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Jun 2021 11:40:50 +0100 Subject: [PATCH 2/6] drop ChainRulesCore 0.9 and remove deprecations --- Project.toml | 4 ++-- test/demos/forwarddiffzero.jl | 2 +- test/demos/reversediffzero.jl | 2 +- test/runtests.jl | 3 --- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index c12d422..8e10b5e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,12 @@ name = "ChainRulesOverloadGeneration" uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" -version = "0.1.1" +version = "0.1.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [compat] -ChainRulesCore = "0.9, 0.10" +ChainRulesCore = "0.10.4" julia = "1" [extras] diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl index 3283dd3..ae93459 100644 --- a/test/demos/forwarddiffzero.jl +++ b/test/demos/forwarddiffzero.jl @@ -44,7 +44,7 @@ function define_dual_overload(sig) # we use the function call overloading form as it lets us avoid namespacing issues # as we can directly interpolate the function type into to the AST. function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) - ȧrgs = (NO_FIELDS, partial.(dual_args)...) + ȧrgs = (NoTangent(), partial.(dual_args)...) args = (op, primal.(dual_args)...) y, ẏ = frule(ȧrgs, args...; kwargs...) return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl index a410ad5..578c6e0 100644 --- a/test/demos/reversediffzero.jl +++ b/test/demos/reversediffzero.jl @@ -116,7 +116,7 @@ end function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number) function times_pullback(ΔΩ) # we will use thunks here to show we handle them fine. - return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) + return (NoTangent(), @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) end return x * y, times_pullback end diff --git a/test/runtests.jl b/test/runtests.jl index 8ddd53f..8795d8f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,5 @@ using ChainRulesCore using ChainRulesOverloadGeneration -# resolve conflicts while this code exists in both. -const on_new_rule = ChainRulesOverloadGeneration.on_new_rule -const refresh_rules = ChainRulesOverloadGeneration.refresh_rules using Test From b3880663fcb8f881327f4b72e875e5ba7bbafe82 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Jun 2021 12:10:04 +0100 Subject: [PATCH 3/6] Filter out rules that require config --- src/ruleset_loading.jl | 13 ++++++++++--- test/ruleset_loading.jl | 25 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index 220df6c..00e8f07 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -52,15 +52,22 @@ clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) _rule_list(frule | rrule) Returns a list of all the methods of the currently defined rules of the given kind. -Excluding the fallback rule that returns `nothing` for every input. +Excluding the fallback rule that returns `nothing` for every input; +and excluding rules that require a particular `RuleConfig`. """ -function _rule_list end -_rule_list(rule_kind) = (m for m in methods(rule_kind) if !_is_fallback(rule_kind, m)) +function _rule_list(rule_kind) + return Iterators.filter(methods(rule_kind)) do m + !_is_fallback(rule_kind, m) && !_requires_config(m) + end +end "check if this is the fallback-frule/rrule that always returns `nothing`" _is_fallback(::typeof(rrule), m::Method) = m.sig === Tuple{typeof(rrule),Any,Vararg{Any}} _is_fallback(::typeof(frule), m::Method) = m.sig === Tuple{typeof(frule),Any,Any,Vararg{Any}} +"check if this rule requires a particular configuation (`RuleConfig`)" +_requires_config(m::Method) = m.sig.parameters[2] <: RuleConfig + const LAST_REFRESH_RRULE = Ref(0) const LAST_REFRESH_FRULE = Ref(0) last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index e7e7a71..63dc0c9 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -75,4 +75,29 @@ @test _is_fallback(rrule, first(methods(rrule, (Nothing,)))) @test _is_fallback(frule, first(methods(frule, (Tuple{}, Nothing,)))) end + + @test "_rule_list" begin + _rule_list = ChainRulesOverloadGeneration._rule_list + @testset "should not have frules that need RuleConfig" begin + old_frule_list = collect(_rule_list(frule)) + function ChainRulesCore.frule( + ::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, dargs, sum, f, xs + ) + return 1.0, 1.0 # this will not be call so return doesn't matter + end + # New rule should not have appeared + @test collect(_rule_list(frule)) == old_frule_list + end + + @testset "should not have rrules that need RuleConfig" begin + old_rrule_list = collect(_rule_list(rrule)) + function ChainRulesCore.rrule( + ::RuleConfig{>:Union{HasForwardsMode,HasReverseMode}}, sum, f, xs + ) + return 1.0, x->(x,x,x) # this will not be call so return doesn't matter + end + # New rule should not have appeared + @test collect(_rule_list(rrule)) == old_rrule_list + end + end end From 2d9748687646f98685be1acec051178d8b6bf3bc Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Jun 2021 13:37:47 +0100 Subject: [PATCH 4/6] style --- src/ruleset_loading.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl index 00e8f07..d59c306 100644 --- a/src/ruleset_loading.jl +++ b/src/ruleset_loading.jl @@ -57,7 +57,7 @@ and excluding rules that require a particular `RuleConfig`. """ function _rule_list(rule_kind) return Iterators.filter(methods(rule_kind)) do m - !_is_fallback(rule_kind, m) && !_requires_config(m) + return !_is_fallback(rule_kind, m) && !_requires_config(m) end end From c55eeb5bbfdaba2a304ff7cc45c935eee8679d33 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Jun 2021 14:56:51 +0100 Subject: [PATCH 5/6] testset not test --- test/ruleset_loading.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl index 63dc0c9..662a4e8 100644 --- a/test/ruleset_loading.jl +++ b/test/ruleset_loading.jl @@ -76,7 +76,7 @@ @test _is_fallback(frule, first(methods(frule, (Tuple{}, Nothing,)))) end - @test "_rule_list" begin + @testset "_rule_list" begin _rule_list = ChainRulesOverloadGeneration._rule_list @testset "should not have frules that need RuleConfig" begin old_frule_list = collect(_rule_list(frule)) From 03b82336edb71d736debc3d4688098d6d4fecc20 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Jun 2021 16:04:25 +0100 Subject: [PATCH 6/6] update docs --- docs/Manifest.toml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 3b5c092..7eeb942 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -11,15 +11,15 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" [[ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "b391f22252b8754f4440de1f37ece49d8a7314bb" +git-tree-sha1 = "d659e42240c2162300b321f05173cab5cc40a5ba" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.44" +version = "0.10.4" [[ChainRulesOverloadGeneration]] deps = ["ChainRulesCore"] path = ".." uuid = "f51149dc-2911-5acf-81fc-2076a2a81d4f" -version = "0.1.0" +version = "0.1.2" [[Compat]] deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] @@ -40,10 +40,10 @@ deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[DocStringExtensions]] -deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" +deps = ["LibGit2"] +git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.4" +version = "0.8.5" [[DocThemeIndigo]] deps = ["Sass"]