Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ function DI.prepare_pushforward_nokwarg(
contexts::Vararg{DI.Context, C}
) where {F, C, B}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
df = function_shadow(f, backend, Val(B))
mode = forward_withprimal(backend)
df = function_shadow(f, backend, mode, Val(B))
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
return EnzymeOneArgPushforwardPrep(_sig, df, context_shadows)
end
Expand Down Expand Up @@ -146,8 +146,8 @@ function DI.prepare_gradient_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, contexts...; strict)
valB = to_val(DI.pick_batchsize(backend, x))
df = function_shadow(f, backend, valB)
mode = forward_withprimal(backend)
df = function_shadow(f, backend, mode, valB)
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
basis_shadows = create_shadows(valB, x)
return EnzymeForwardGradientPrep(_sig, valB, df, context_shadows, basis_shadows)
Expand Down Expand Up @@ -236,7 +236,7 @@ function DI.prepare_jacobian_nokwarg(
y = f(x, map(DI.unwrap, contexts)...)
valB = to_val(DI.pick_batchsize(backend, x))
mode = forward_withprimal(backend)
df = function_shadow(f, backend, valB)
df = function_shadow(f, backend, mode, valB)
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
basis_shadows = create_shadows(valB, x)
return EnzymeForwardOneArgJacobianPrep(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ function DI.prepare_pushforward_nokwarg(
contexts::Vararg{DI.Context, C}
) where {F, B, C}
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
df! = function_shadow(f!, backend, Val(B))
mode = forward_noprimal(backend)
df! = function_shadow(f!, backend, mode, Val(B))
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
return EnzymeTwoArgPushforwardPrep(_sig, df!, context_shadows)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ function DI.prepare_pullback_nokwarg(
contexts::Vararg{DI.Context, C}
) where {F, B, C}
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
df = function_shadow(f, backend, Val(B))
mode = reverse_split_withprimal(backend)
df = function_shadow(f, backend, mode, Val(B))
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
y = f(x, map(DI.unwrap, contexts)...)
return EnzymeReverseOneArgPullbackPrep(_sig, df, context_shadows, y)
Expand Down Expand Up @@ -214,8 +214,8 @@ function DI.prepare_gradient_nokwarg(
contexts::Vararg{DI.Context, C}
) where {F, C}
_sig = DI.signature(f, backend, x, contexts...; strict)
df = function_shadow(f, backend, Val(1))
mode = reverse_withprimal(backend)
df = function_shadow(f, backend, mode, Val(1))
context_shadows = make_context_shadows(backend, mode, Val(1), contexts...)
return EnzymeGradientPrep(_sig, df, context_shadows)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ function DI.prepare_pullback_nokwarg(
contexts::Vararg{DI.Context, C}
) where {F, B, C}
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
df! = function_shadow(f!, backend, Val(B))
mode = reverse_noprimal(backend)
df! = function_shadow(f!, backend, mode, Val(B))
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
ty_copy = map(copy, ty)
return EnzymeReverseTwoArgPullbackPrep(_sig, df!, context_shadows, ty_copy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,33 @@ end
function get_f_and_df_prepared!(
df, f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B}
) where {F, M, B}
#=
It is not obvious why we don't need a `make_zero` here, in the case of mutable constant data in `f`.
- In forward mode, `df` is never incremented if `f` is not mutated, so it remains equal to its initial value of `0`.
- In reverse mode, `df` gets incremented but it does not influence the input cotangent `dx`.
=#
if B == 1
return Duplicated(f, df)
if isnothing(df)
return Const(f)
else
return BatchDuplicated(f, df)
if B == 1
return Duplicated(f, df)
else
return BatchDuplicated(f, df)
end
end
end

function function_shadow(
::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Val{B}
::F, ::AutoEnzyme{M, <:Union{Const, Nothing}}, ::Mode, ::Val{B}
) where {M, B, F}
return nothing
end

function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, ::Val{B}) where {F, M, B}
if B == 1
return make_zero(f)
function function_shadow(f::F, ::AutoEnzyme{M, <:AnyDuplicated}, mode::Mode, ::Val{B}) where {F, M, B}
IA = guess_activity(F, mode)
return if IA <: Const
nothing
else
return ntuple(_ -> make_zero(f), Val(B))
if B == 1
return make_zero(f)
else
return ntuple(_ -> make_zero(f), Val(B))
end
end
end

Expand Down Expand Up @@ -87,13 +91,13 @@ function _shadow(
end

function _shadow(
backend::AutoEnzyme{M, <:Union{Const, Nothing}},
::Mode,
backend::AutoEnzyme,
mode::Mode,
::Val{B},
c_wrapped::DI.FunctionContext,
) where {M, B}
) where {B}
f = DI.unwrap(c_wrapped)
return function_shadow(f, backend, Val(B))
return function_shadow(f, backend, mode, Val(B))
end

function make_context_shadows(
Expand Down Expand Up @@ -122,11 +126,6 @@ end
function _translate_prepared!(
dc, c_wrapped::Union{DI.ConstantOrCache, DI.FunctionContext}, ::Val{B}
) where {B}
#=
It is not obvious why we don't need a `make_zero` here, in the case of mutable constant contexts.
- In forward mode, `dc` is never incremented because `c` is not mutated, so it remains equal to its initial value of `0`.
- In reverse mode, `dc` gets incremented but it does not influence the input cotangent `dx`.
=#
c = DI.unwrap(c_wrapped)
if isnothing(dc)
return Const(c)
Expand Down
Loading