Utilities
Zygote provides a set of helpful utilities. These are all "user-level" tools. In other words, you could have written them easily yourself, but they live in Zygote for convenience.
Zygote.@showgrad - Macro.
@showgrad(x) -> x
Much like @show, but shows the gradient about to accumulate to x. Useful for debugging gradients.
julia> gradient(2, 3) do a, b
         @showgrad(a)*b
       end
∂(a) = 3
(3, 2)
Note that the gradient depends on how the output of @showgrad is used, and is not the overall gradient of the variable a. For example:
julia> gradient(2) do a
         @showgrad(a)*a
       end
∂(a) = 2
(4,)
julia> gradient(2, 3) do a, b
         @showgrad(a) # not used, so no gradient
         a*b
       end
∂(a) = nothing
(3, 2)
Zygote.hook - Function.
hook(x̄ -> ..., x) -> x
Gradient hooks. Allows you to apply an arbitrary function to the gradient for x.
julia> gradient(2, 3) do a, b
         hook(ā -> @show(ā), a)*b
       end
ā = 3
(3, 2)
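Because the hook receives the incoming gradient and returns its replacement, it can also be used to bound a gradient. The following is a minimal sketch, assuming Float64 inputs; clip is a hypothetical helper written for this illustration and is not part of Zygote. hook is called with its module prefix on the assumption that it has not been imported into scope.

```julia
using Zygote

# Hypothetical helper (not a Zygote function): returns x unchanged in the
# forward pass, and clamps the gradient flowing back to x into [-limit, limit].
clip(x, limit) = Zygote.hook(g -> clamp(g, -limit, limit), x)

g = gradient(2.0, 3.0) do a, b
    # Without the hook the gradient for a would be b = 3.0.
    clip(a, 1.0) * b
end
# g == (1.0, 2.0): the gradient for a is clamped to 1.0, b is unaffected
```

Only the gradient passing through the hooked value is changed; the forward result and the gradient for b are the same as without the hook.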
julia> gradient(2, 3) do a, b
         hook(-, a)*b
       end
(-3, 2)
Zygote.dropgrad - Function.
dropgrad(x) -> x
Drop the gradient of x.
julia> gradient(2, 3) do a, b
         dropgrad(a)*b
       end
(nothing, 2)
Zygote.hessian - Function.
hessian(f, x)
Construct the Hessian of f, where x is a real or real array and f(x) is a real.
julia> hessian(((a, b),) -> a*b, [2, 3])
2×2 Array{Int64,2}:
 0  1
 1  0
Zygote.Buffer - Type.
Buffer(xs, ...)
Buffer is an array-like type which is mutable when taking gradients. You can construct a Buffer with the same syntax as similar (e.g. Buffer(xs, 5)) and then use normal indexing. Finally, use copy to get back a normal array.
For example:
julia> function vstack(xs)
           buf = Buffer(xs, length(xs), 5)
           for i = 1:5
               buf[:, i] = xs
           end
           return copy(buf)
       end
vstack (generic function with 1 method)
julia> vstack([1, 2, 3])
3×5 Array{Int64,2}:
 1  1  1  1  1
 2  2  2  2  2
 3  3  3  3  3
julia> gradient(x -> sum(vstack(x)), [1, 2, 3])
([5.0, 5.0, 5.0],)
Buffer is not an AbstractArray and can't be used for linear algebra operations like matrix multiplication. This prevents it from being captured by pullbacks.
copy is a semantic copy, but does not allocate memory. Instead the Buffer is made immutable after copying.
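As a further sketch of the construct, fill, copy pattern described above, the hypothetical helper squares below (not part of Zygote) writes into a Buffer one element at a time and is then differentiated through the loop. It assumes a vector of Float64 values, and Buffer is called with its module prefix on the assumption that it has not been imported into scope.

```julia
using Zygote

# Hypothetical helper for illustration: square each element of xs,
# writing the results into a Buffer instead of an ordinary array.
function squares(xs)
    buf = Zygote.Buffer(xs)      # same shape and element type as xs, as with similar(xs)
    for i = 1:length(xs)
        buf[i] = xs[i]^2         # normal indexing; this write is allowed under gradient
    end
    return copy(buf)             # convert back to a normal array
end

grad = gradient(x -> sum(squares(x)), [1.0, 2.0, 3.0])[1]
# grad ≈ [2.0, 4.0, 6.0], i.e. 2 .* x
```

Writing the same loop against an ordinary array created with similar would mutate that array inside the differentiated function, which is the case Buffer exists to handle.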
Zygote.forwarddiff - Function.
forwarddiff(f, x) -> f(x)
Runs f(x) as usual, but instructs Zygote to differentiate f using forward mode, rather than the usual reverse mode.
Forward mode takes time linear in length(x) but only has constant memory overhead, and is very efficient for scalars, so in some cases this can be a useful optimisation.
julia> function pow(x, n)
           r = one(x)
           for i = 1:n
               r *= x
           end
           return r
       end
pow (generic function with 1 method)
julia> gradient(5) do x
         forwarddiff(x) do x
           pow(x, 2)
         end
       end
(10,)
Note that the function f will drop gradients for any closed-over values.
julia> gradient(2, 3) do a, b
         forwarddiff(a) do a
           a*b
         end
       end
(3, nothing)
This can be rewritten by explicitly passing through b, i.e.
gradient(2, 3) do a, b
  forwarddiff([a, b]) do (a, b)
    a*b
  end
end
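One way to check that nothing is dropped after such a rewrite is to compare the forward-mode result against plain reverse mode. The sketch below assumes Float64 inputs; the function f and the names rev, fwd and agree are introduced only for this illustration, and forwarddiff is called with its module prefix on the assumption that it has not been imported into scope.

```julia
using Zygote

f(a, b) = a * b

# Reverse mode (Zygote's default) serves as the reference result.
rev = gradient(f, 2.0, 3.0)

# Forward mode for both inputs: pack them into one vector so that
# neither value is closed over by the inner function.
fwd = gradient(2.0, 3.0) do a, b
    Zygote.forwarddiff([a, b]) do v
        f(v[1], v[2])
    end
end

# Both modes should agree on every partial derivative.
agree = all(fwd .≈ rev)
```

If one of the inputs were closed over instead of passed through, its entry in fwd would be nothing and the comparison would fail rather than silently succeed.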