Utilities

Zygote provides a set of helpful utilities. These are all "user-level" tools; in other words, you could have written them easily yourself, but they live in Zygote for convenience.

Zygote.@showgrad – Macro.
@showgrad(x) -> x

Much like @show, but prints the gradient that is about to be accumulated into x. Useful for debugging gradients.

julia> gradient(2, 3) do a, b
         @showgrad(a)*b
       end
∂(a) = 3
(3, 2)

Note that the gradient depends on how the output of @showgrad is used, and is not the overall gradient of the variable a. For example:

julia> gradient(2) do a
         @showgrad(a)*a
       end
∂(a) = 2
(4,)

julia> gradient(2, 3) do a, b
         @showgrad(a) # not used, so no gradient
         a*b
       end
∂(a) = nothing
(3, 2)
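To see why the value shown for a in the a*a example is 2 rather than the full gradient 4, here is a toy model in plain Julia. It is not Zygote's implementation: the hypothetical primitives toy_mul and toy_showgrad each return their value together with a pullback that maps an output gradient to input gradients, and the pullbacks are chained by hand. The printed value is only the gradient flowing through the wrapped use of a; the second, unwrapped use contributes separately.

```julia
# Toy pullback convention (hypothetical, far simpler than Zygote's machinery):
# each primitive returns (value, back), where `back` maps the gradient of the
# output to the gradients of the inputs.
toy_mul(a, b) = (a * b, dy -> (dy * b, dy * a))

# Toy showgrad: identity on the forward pass; on the way back it prints the
# gradient arriving through this particular use of x and passes it on.
toy_showgrad(x) = (x, dx -> (println("∂(x) = ", dx); dx))

# Gradient of a -> showgrad(a) * a, with the pullbacks chained by hand.
function grad_square(a)
    s, back_s = toy_showgrad(a)    # first factor goes through showgrad
    _, back_mul = toy_mul(s, a)    # second factor is plain a
    ds, da2 = back_mul(1)          # seed the output gradient with 1
    da1 = back_s(ds)               # prints ∂(x) = 2
    return (da1 + da2,)            # gradients from both uses of a accumulate
end

grad_square(2)   # (4,)
```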
Zygote.hook – Function.
hook(x̄ -> ..., x) -> x

Gradient hooks. hook(f, x) returns x unchanged on the forward pass, but applies the arbitrary function f to the gradient of x on the backward pass.

julia> gradient(2, 3) do a, b
         hook(ā -> @show(ā), a)*b
       end
ā = 3
(3, 2)

julia> gradient(2, 3) do a, b
         hook(-, a)*b
       end
(-3, 2)
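The behaviour of hook can be sketched with a toy pullback convention in plain Julia. The names toy_mul and toy_hook are hypothetical and this is not Zygote's code: toy_hook is the identity on the forward pass and applies f to the gradient on the way back. Passing a clamping function instead of - sketches one possible use, clipping a gradient.

```julia
# Toy pullback convention (hypothetical): primitives return (value, back).
toy_mul(a, b) = (a * b, dy -> (dy * b, dy * a))

# Toy hook: identity forward, applies f to the incoming gradient on the way back.
toy_hook(f, x) = (x, dx -> f(dx))

# Gradient of (a, b) -> hook(f, a) * b, with the pullbacks chained by hand.
function grad_hooked(f, a, b)
    h, back_h = toy_hook(f, a)
    _, back_mul = toy_mul(h, b)
    dh, db = back_mul(1)          # seed the output gradient with 1
    return (back_h(dh), db)       # only a's gradient passes through f
end

grad_hooked(-, 2, 3)                        # (-3, 2), as in the example above
grad_hooked(da -> clamp(da, -1, 1), 2, 3)   # (1, 2): a's gradient clipped to [-1, 1]
```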
Zygote.dropgrad – Function.
dropgrad(x) -> x

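Drops the gradient of x: x is treated as a constant, so no gradient flows back to it and its gradient is reported as nothing.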

julia> gradient(2, 3) do a, b
         dropgrad(a)*b
       end
(nothing, 2)
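In a toy pullback model written in plain Julia (hypothetical toy_mul and toy_dropgrad, not Zygote's code), dropgrad is simply an identity whose pullback discards the incoming gradient and returns nothing:

```julia
# Toy pullback convention (hypothetical): primitives return (value, back).
toy_mul(a, b) = (a * b, dy -> (dy * b, dy * a))

# Toy dropgrad: identity forward, throws away the gradient on the way back.
toy_dropgrad(x) = (x, _ -> nothing)

# Gradient of (a, b) -> dropgrad(a) * b, with the pullbacks chained by hand.
function grad_dropped(a, b)
    d, back_d = toy_dropgrad(a)
    _, back_mul = toy_mul(d, b)
    dd, db = back_mul(1)          # seed the output gradient with 1
    return (back_d(dd), db)       # a's gradient is discarded
end

grad_dropped(2, 3)   # (nothing, 2)
```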
Zygote.hessian – Function.
hessian(f, x)

Constructs the Hessian of f at x, i.e. the matrix of second derivatives ∂²f/∂xᵢ∂xⱼ, where x is a real number or real array and f(x) is a real number.

julia> hessian(((a, b),) -> a*b, [2, 3])
2×2 Array{Int64,2}:
 0  1
 1  0
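As an independent sanity check of that result, the Hessian of a*b can be approximated with central finite differences in plain Julia. fd_hessian is a hypothetical helper written for this illustration, not part of Zygote. Because a*b is quadratic, the four-point stencil recovers [0 1; 1 0] up to rounding error.

```julia
# Hypothetical helper (not part of Zygote): central-difference approximation
# of the Hessian of a scalar function f at the point x.
function fd_hessian(f, x::Vector{Float64}; h = 1e-3)
    n = length(x)
    H = zeros(n, n)
    for i in 1:n, j in 1:n
        ei = zeros(n); ei[i] = h   # step along coordinate i
        ej = zeros(n); ej[j] = h   # step along coordinate j
        # Four-point central difference for ∂²f/∂xᵢ∂xⱼ
        H[i, j] = (f(x + ei + ej) - f(x + ei - ej) -
                   f(x - ei + ej) + f(x - ei - ej)) / (4 * h^2)
    end
    return H
end

bilinear = ((a, b),) -> a * b   # same function as in the hessian example
H = fd_hessian(bilinear, [2.0, 3.0])   # ≈ [0.0 1.0; 1.0 0.0]
```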
Zygote.Buffer – Type.
Buffer(xs, ...)

Buffer is an array-like type that can be mutated while taking gradients, which ordinary arrays cannot be under Zygote. You can construct a Buffer with the same syntax as similar (e.g. Buffer(xs, 5)) and then use normal indexing. Finally, use copy to get back a normal array.

For example:

julia> function vstack(xs)
         buf = Buffer(xs, length(xs), 5)
         for i = 1:5
           buf[:, i] = xs
         end
         return copy(buf)
       end
vstack (generic function with 1 method)

julia> vstack([1, 2, 3])
3×5 Array{Int64,2}:
 1  1  1  1  1
 2  2  2  2  2
 3  3  3  3  3

julia> gradient(x -> sum(vstack(x)), [1, 2, 3])
([5.0, 5.0, 5.0],)

Buffer is not an AbstractArray and can't be used for linear algebra operations like matrix multiplication. This prevents it from being captured by pullbacks.

copy is a semantic copy, but it does not allocate memory. Instead, the Buffer is made immutable after copying, so later writes cannot alter the array that copy returned.
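A rough model of these copy semantics in plain Julia, assuming a hypothetical FreezeBuffer type (this is not Zygote's Buffer): writes are accepted until copy is called, copy hands back the underlying storage without allocating, and any later write raises an error, since it would otherwise silently change the array that copy returned.

```julia
# Hypothetical stand-in for Buffer's copy semantics (NOT Zygote's implementation).
mutable struct FreezeBuffer{T}
    data::Vector{T}   # underlying storage
    frozen::Bool      # set once copy has been called
end

# Construct like `similar`: same element type and length, contents uninitialised.
FreezeBuffer(xs::Vector{T}) where {T} = FreezeBuffer{T}(similar(xs), false)

Base.getindex(b::FreezeBuffer, i::Int) = b.data[i]

function Base.setindex!(b::FreezeBuffer, v, i::Int)
    b.frozen && error("buffer is frozen")   # writes after copy are rejected
    b.data[i] = v
    return b
end

function Base.copy(b::FreezeBuffer)
    b.frozen = true   # freeze, because the caller now shares b.data
    return b.data     # no allocation: hand out the same storage
end

buf = FreezeBuffer([1, 2, 3])
for i in 1:3
    buf[i] = 10i
end
out = copy(buf)   # [10, 20, 30], the same array object as buf.data
```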

Zygote.forwarddiff – Function.
forwarddiff(f, x) -> f(x)

Runs f(x) as usual, but instructs Zygote to differentiate f using forward mode, rather than the usual reverse mode.

Forward mode takes time linear in length(x) but has only constant memory overhead, and it is very efficient for scalars, so in some cases it can be a useful optimisation.

julia> function pow(x, n)
         r = one(x)
         for i = 1:n
           r *= x
         end
         return r
       end
pow (generic function with 1 method)

julia> gradient(5) do x
         forwarddiff(x) do x
           pow(x, 2)
         end
       end
(10,)
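Forward mode can be illustrated with a minimal dual number that carries a value together with its derivative. This is a sketch with a hypothetical Dual type, not how Zygote implements forwarddiff: pow is reused unchanged, and the input's derivative is seeded with 1. Only one extra number travels alongside each value, which is where the constant memory overhead comes from.

```julia
# Hypothetical minimal dual number: a value paired with its derivative.
struct Dual
    val::Float64   # primal value
    der::Float64   # derivative carried alongside it
end

# Product rule: (uv)' = u'v + uv'
Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.der * y.val + x.val * y.der)
Base.one(::Dual) = Dual(1.0, 0.0)   # the constant 1 has derivative 0

function pow(x, n)
    r = one(x)
    for i = 1:n
        r *= x
    end
    return r
end

d = pow(Dual(5.0, 1.0), 2)   # seed dx/dx = 1 at x = 5
d.der                        # 10.0, the same derivative as the gradient example
```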

Note that the function f will drop gradients for any closed-over values.

julia> gradient(2, 3) do a, b
         forwarddiff(a) do a
           a*b
         end
       end
(3, nothing)

To get a gradient for b as well, pass b through explicitly:

gradient(2, 3) do a, b
  forwarddiff([a, b]) do (a, b)
    a*b
  end
end
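Why passing b explicitly helps can be sketched with a hypothetical Dual type in plain Julia (again not Zygote's implementation). In this toy model, forward mode only produces a derivative for inputs whose derivative is seeded, and it needs one pass per seeded input; the per-input passes are also why its cost grows linearly with length(x).

```julia
# Hypothetical minimal dual number: a value paired with its derivative.
struct Dual
    val::Float64
    der::Float64
end

# Product rule: (uv)' = u'v + uv'
Base.:*(x::Dual, y::Dual) = Dual(x.val * y.val, x.der * y.val + x.val * y.der)

bilinear(a, b) = a * b

# One forward pass per input: seed that input's derivative with 1, the other with 0.
function forward_gradient(f, a, b)
    da = f(Dual(a, 1.0), Dual(b, 0.0)).der   # ∂f/∂a
    db = f(Dual(a, 0.0), Dual(b, 1.0)).der   # ∂f/∂b
    return (da, db)
end

forward_gradient(bilinear, 2.0, 3.0)   # (3.0, 2.0)
```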