From 0744082393062d1d31c771d68cbaaecbb0d332e6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 9 Dec 2025 10:56:27 +0100 Subject: [PATCH] fix: function shadows for higher-order Enzyme --- .../forward_onearg.jl | 6 +-- .../forward_twoarg.jl | 2 +- .../reverse_onearg.jl | 4 +- .../reverse_twoarg.jl | 2 +- .../utils.jl | 43 +++++++++---------- 5 files changed, 28 insertions(+), 29 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 98ecc0db2..c44529caf 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -15,8 +15,8 @@ function DI.prepare_pushforward_nokwarg( contexts::Vararg{DI.Context, C} ) where {F, C, B} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - df = function_shadow(f, backend, Val(B)) mode = forward_withprimal(backend) + df = function_shadow(f, backend, mode, Val(B)) context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) return EnzymeOneArgPushforwardPrep(_sig, df, context_shadows) end @@ -146,8 +146,8 @@ function DI.prepare_gradient_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) valB = to_val(DI.pick_batchsize(backend, x)) - df = function_shadow(f, backend, valB) mode = forward_withprimal(backend) + df = function_shadow(f, backend, mode, valB) context_shadows = make_context_shadows(backend, mode, valB, contexts...) basis_shadows = create_shadows(valB, x) return EnzymeForwardGradientPrep(_sig, valB, df, context_shadows, basis_shadows) @@ -236,7 +236,7 @@ function DI.prepare_jacobian_nokwarg( y = f(x, map(DI.unwrap, contexts)...) valB = to_val(DI.pick_batchsize(backend, x)) mode = forward_withprimal(backend) - df = function_shadow(f, backend, valB) + df = function_shadow(f, backend, mode, valB) context_shadows = make_context_shadows(backend, mode, valB, contexts...) basis_shadows = create_shadows(valB, x) return EnzymeForwardOneArgJacobianPrep( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl index bd8bdc04b..629f046a0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl @@ -16,8 +16,8 @@ function DI.prepare_pushforward_nokwarg( contexts::Vararg{DI.Context, C} ) where {F, B, C} _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) - df! = function_shadow(f!, backend, Val(B)) mode = forward_noprimal(backend) + df! = function_shadow(f!, backend, mode, Val(B)) context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) return EnzymeTwoArgPushforwardPrep(_sig, df!, context_shadows) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index c81c5cd44..750db69cd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -63,8 +63,8 @@ function DI.prepare_pullback_nokwarg( contexts::Vararg{DI.Context, C} ) where {F, B, C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) - df = function_shadow(f, backend, Val(B)) mode = reverse_split_withprimal(backend) + df = function_shadow(f, backend, mode, Val(B)) context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) y = f(x, map(DI.unwrap, contexts)...) return EnzymeReverseOneArgPullbackPrep(_sig, df, context_shadows, y) @@ -214,8 +214,8 @@ function DI.prepare_gradient_nokwarg( contexts::Vararg{DI.Context, C} ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) - df = function_shadow(f, backend, Val(1)) mode = reverse_withprimal(backend) + df = function_shadow(f, backend, mode, Val(1)) context_shadows = make_context_shadows(backend, mode, Val(1), contexts...) return EnzymeGradientPrep(_sig, df, context_shadows) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl index ddf5f774c..616155a1a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl @@ -17,8 +17,8 @@ function DI.prepare_pullback_nokwarg( contexts::Vararg{DI.Context, C} ) where {F, B, C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) - df! = function_shadow(f!, backend, Val(B)) mode = reverse_noprimal(backend) + df! = function_shadow(f!, backend, mode, Val(B)) context_shadows = make_context_shadows(backend, mode, Val(B), contexts...) ty_copy = map(copy, ty) return EnzymeReverseTwoArgPullbackPrep(_sig, df!, context_shadows, ty_copy) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 57c654cf5..f53218249 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -28,29 +28,33 @@ end function get_f_and_df_prepared!( df, f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B} ) where {F, M, B} - #= - It is not obvious why we don't need a `make_zero` here, in the case of mutable constant data in `f`. - - In forward mode, `df` is never incremented if `f` is not mutated, so it remains equal to its initial value of `0`. - - In reverse mode, `df` gets incremented but it does not influence the input cotangent `dx`. - =# - if B == 1 - return Duplicated(f, df) + if isnothing(df) + return Const(f) else - return BatchDuplicated(f, df) + if B == 1 + return Duplicated(f, df) + else + return BatchDuplicated(f, df) + end end end function function_shadow( - ::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Val{B} + ::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Mode, ::Val{B} ) where {M, B, F} return nothing end -function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B}) where {F, M, B} - if B == 1 - return make_zero(f) +function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, mode::Mode, ::Val{B}) where {F, M, B} + IA = guess_activity(F, mode) + return if IA <: Const + nothing else - return ntuple(_ -> make_zero(f), Val(B)) + if B == 1 + return make_zero(f) + else + return ntuple(_ -> make_zero(f), Val(B)) + end end end @@ -87,13 +91,13 @@ function _shadow( end function _shadow( - backend::AutoEnzyme{M, <:Union{Const, Nothing}}, - ::Mode, + backend::AutoEnzyme, + mode::Mode, ::Val{B}, c_wrapped::DI.FunctionContext, - ) where {M, B} + ) where {B} f = DI.unwrap(c_wrapped) - return function_shadow(f, backend, Val(B)) + return function_shadow(f, backend, mode, Val(B)) end function make_context_shadows( @@ -122,11 +126,6 @@ end function _translate_prepared!( dc, c_wrapped::Union{DI.ConstantOrCache, DI.FunctionContext}, ::Val{B} ) where {B} - #= - It is not obvious why we don't need a `make_zero` here, in the case of mutable constant contexts. - - In forward mode, `dc` is never incremented because `c` is not mutated, so it remains equal to its initial value of `0`. - - In reverse mode, `dc` gets incremented but it does not influence the input cotangent `dx`. - =# c = DI.unwrap(c_wrapped) if isnothing(dc) return Const(c)