Internals

What Zygote Does

These notebooks and the Zygote paper provide useful background on Zygote's transform; this page is particularly focused on implementation details.

Given a function like

function foo(x)
  a = bar(x)
  b = baz(a)
  return b
end

how do we differentiate it? The key is that we can differentiate foo if we can differentiate bar and baz. If we assume we can get pullbacks for those functions, the pullback for foo looks as follows.

function J(::typeof(foo), x)
  a, da = J(bar, x)
  b, db = J(baz, a)
  return b, function(b̄)
    ā = db(b̄)
    x̄ = da(ā)
    return x̄
  end
end

Thus, where the forward pass calculates x -> a -> b, the backward pass takes b̄ -> ā -> x̄ via the pullbacks. The AD transform is recursive: we differentiate bar and baz in the same way, until we hit a function whose gradient is explicitly defined.

Here's a working example that illustrates the concepts.

J(::typeof(sin), x) = sin(x), ȳ -> ȳ*cos(x)
J(::typeof(cos), x) = cos(x), ȳ -> -ȳ*sin(x)

foo(x) = sin(cos(x))

function J(::typeof(foo), x)
  a, da = J(cos, x)   # cos is applied first...
  b, db = J(sin, a)   # ...then sin
  return b, function(b̄)
    ā = db(b̄)         # pullbacks run in reverse order
    x̄ = da(ā)
    return x̄
  end
end

gradient(f, x) = J(f, x)[2](1)

gradient(foo, 1)

Now, clearly this is a mechanical transformation, so the only remaining thing is to automate it: a small matter of programming.
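As a small taste of that automation, here is a sketch of a single recursive rule that derives the J(foo, x) pattern for any composition of functions that already have J rules. It assumes Julia 1.6 or later, where `sin ∘ cos` builds a `Base.ComposedFunction` with `outer` and `inner` fields. Zygote does not work this way; it transforms lowered code, as described below, so the composition rule is purely illustrative.

```julia
# Primitive rules: J(f, x) returns the primal value and a pullback closure.
J(::typeof(sin), x) = sin(x), ȳ -> ȳ * cos(x)
J(::typeof(cos), x) = cos(x), ȳ -> -ȳ * sin(x)

# Recursive rule for any composition `outer ∘ inner`: differentiate the inner
# call, then the outer call, and chain the pullbacks in reverse order. This is
# the hand-written J(foo, x) above, derived mechanically for every `∘`.
function J(c::Base.ComposedFunction, x)
    a, da = J(c.inner, x)
    b, db = J(c.outer, a)
    return b, b̄ -> da(db(b̄))
end

gradient(f, x) = J(f, x)[2](1)

foo = sin ∘ cos                  # foo(x) = sin(cos(x))
gradient(foo, 1.0)
gradient(sin ∘ cos ∘ sin, 1.0)   # nested compositions are handled by recursion
```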

Closures

The J function here corresponds to forward in Zygote. However, forward is actually a wrapper around the lower-level _forward function.

julia> y, back = Zygote._forward(sin, 0.5);

julia> back(1)
(nothing, 0.8775825618903728)

Why the extra nothing here? It represents the gradient with respect to the function sin itself. This is usually nothing, but a closure contains data (its captured variables), and we need gradients for that data too.

julia> f = let a = 3; x -> x*a; end
#19 (generic function with 1 method)

julia> y, back = Zygote._forward(f, 2);

julia> back(1)
((a = 2,), 3)

This is a minor point for the most part, but _forward will come up in future examples.
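To see where `(a = 2,)` comes from, we can write the closure out by hand as a callable struct whose field holds the captured variable. The struct `MulBy` and its J method below are hypothetical illustrations, not Zygote code. They only mimic the shape of the result: a named tuple of field gradients for the function object, followed by the gradient for `x`.

```julia
# A closure like `let a = 3; x -> x*a; end` behaves like a callable struct
# storing the captured variable `a` as a field. MulBy is a hypothetical stand-in.
struct MulBy
    a::Float64
end
(m::MulBy)(x) = x * m.a

# Hypothetical hand-written J for the callable struct. The pullback returns
# (gradient for the function object, gradient for x); the function object's
# gradient is a named tuple with one entry per field.
function J(m::MulBy, x)
    y = m(x)
    back(ȳ) = ((a = ȳ * x,), ȳ * m.a)
    return y, back
end

f = MulBy(3.0)
y, back = J(f, 2.0)
back(1.0)   # ((a = 2.0,), 3.0)
```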

Entry Points

We could do this transform with a macro, but we don't want to require that all differentiable code is annotated. Instead, a generated function gets us much of the power of a macro without this annotation, because we can use it to get the lowered code for a function. We can then modify the code as we please and return it to implement J(foo, x).
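For readers unfamiliar with generated functions, the toy below shows the key property. Inside the body, arguments are bound to their types, and the returned expression becomes the method body for that particular signature, with no annotation needed at the call site. `unrolled_sum` is a hypothetical example; unlike Zygote's generated function, it does not look up or rewrite any lowered code.

```julia
# Inside a @generated function, `t` is the argument's *type* (e.g.
# Tuple{Int,Float64,Int}); the returned Expr is compiled as the body for
# that signature, where `t` once again refers to the runtime value.
@generated function unrolled_sum(t::Tuple)
    n = fieldcount(t)
    n == 0 && return :(0)
    terms = [:(t[$i]) for i in 1:n]       # t[1], t[2], ..., t[n]
    return Expr(:call, :+, terms...)      # t[1] + t[2] + ... + t[n]
end

unrolled_sum((1, 2.0, 3))   # body generated for Tuple{Int,Float64,Int}
unrolled_sum(())            # body generated for Tuple{}
```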

julia> foo(x) = baz(bar(x))
foo (generic function with 1 method)

julia> @code_lowered foo(1)
CodeInfo(
1 ─ %1 = (Main.bar)(x)
│   %2 = (Main.baz)(%1)
└──      return %2
)

We convert the code to SSA form using Julia's built-in IR data structure, after which it looks like this.

julia> Zygote.@code_ir foo(1)
1 1 ─ %1 = (Main.bar)(_2)::Any
  │   %2 = (Main.baz)(%1)::Any
  └──      return %2

(There isn't much difference unless there's some control flow.)

The code is then differentiated by the code in compiler/reverse.jl. You can see the output with @code_adjoint.

julia> Zygote.@code_adjoint foo(1)
1 1 ─ %1  = (Zygote._forward)(_2, Zygote.unwrap, Main.bar)::Any
  │   %2  = (Base.getindex)(%1, 1)::Any
  │         (Base.getindex)(%1, 2)::Any
  │   %4  = (Zygote._forward)(_2, %2, _4)::Any
  │   %5  = (Base.getindex)(%4, 1)::Any
  │         (Base.getindex)(%4, 2)::Any
  │   %7  = (Zygote._forward)(_2, Zygote.unwrap, Main.baz)::Any
  │   %8  = (Base.getindex)(%7, 1)::Any
  │         (Base.getindex)(%7, 2)::Any
  │   %10 = (Zygote._forward)(_2, %8, %5)::Any
  │   %11 = (Base.getindex)(%10, 1)::Any
  │         (Base.getindex)(%10, 2)::Any
  └──       return %11
  1 ─ %1  = Δ()::Any
1 │   %2  = (@12)(%1)::Any
  │   %3  = (Zygote.gradindex)(%2, 1)::Any
  │   %4  = (Zygote.gradindex)(%2, 2)::Any
  │         (@9)(%3)::Any
  │   %6  = (@6)(%4)::Any
  │   %7  = (Zygote.gradindex)(%6, 1)::Any
  │   %8  = (Zygote.gradindex)(%6, 2)::Any
  │         (@3)(%7)::Any
  │   %10 = (Zygote.tuple)(nothing, %8)::Any
  └──       return %10
, [1])

This code is quite verbose, mainly due to all the tuple unpacking (gradindex is just like getindex, but handles nothing gracefully). There are two pieces of IR here: one for the modified forward pass and one for the pullback closure. The @ nodes allow the closure to refer to values from the forward pass, and Δ() represents the incoming gradient ȳ. In essence, this is just what we wrote above by hand for J(::typeof(foo), x).
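The same calling convention can be mimicked in plain Julia. In this sketch, the hypothetical `fwd` plays the role of `_forward` (minus the context argument): it returns the output and a pullback, and each pullback returns a tuple whose first element is the gradient for the function object. The toy `gradindex` follows the description above, behaving like `getindex` but passing `nothing` through. Neither is Zygote's actual definition.

```julia
# Toy gradindex: like getindex, but `nothing` (no gradient) passes through.
gradindex(x, i) = x[i]
gradindex(::Nothing, i) = nothing

bar(x) = x^2
baz(x) = 3x

# Hypothetical fwd rules: back(ȳ) returns (f̄, x̄); f̄ is `nothing` because
# bar and baz capture no data.
fwd(::typeof(bar), x) = bar(x), ȳ -> (nothing, 2x * ȳ)
fwd(::typeof(baz), x) = baz(x), ȳ -> (nothing, 3 * ȳ)

foo(x) = baz(bar(x))

function fwd(::typeof(foo), x)
    # Forward pass: call each primal and keep each pullback.
    a, back_bar = fwd(bar, x)
    b, back_baz = fwd(baz, a)
    function back(Δ)
        # Reverse pass: unpack (f̄, x̄) tuples in reverse call order.
        g_baz = back_baz(Δ)
        ā = gradindex(g_baz, 2)
        g_bar = back_bar(ā)
        x̄ = gradindex(g_bar, 2)
        return (nothing, x̄)   # foo captures nothing, so f̄ is `nothing`
    end
    return b, back
end

y, back = fwd(foo, 2.0)
back(1.0)   # (nothing, 12.0)
```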

compiler/emit.jl lowers this code into runnable IR (e.g. by turning @ references into getfields and stacks), and it's then turned back into lowered code for Julia to run.
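The stacks matter once control flow is involved. Here is a minimal sketch, under the assumption that a loop is handled by recording one pullback per iteration: the forward pass pushes each pullback onto a stack, and the backward pass pops them so the last iteration is differentiated first. `J_sinloop` is a hypothetical hand-written example, not the code `emit.jl` produces.

```julia
J(::typeof(sin), x) = sin(x), ȳ -> ȳ * cos(x)

# Hypothetical hand-written J for a loop computing x = sin(x) n times.
# The number of iterations is only known at run time, so a fixed set of
# variables cannot hold the pullbacks; a stack can.
function J_sinloop(x, n)
    stack = Any[]
    for _ in 1:n
        x, back = J(sin, x)
        push!(stack, back)        # forward pass: record in execution order
    end
    function pullback(ȳ)
        s = copy(stack)           # copy so the pullback can be called again
        x̄ = ȳ
        while !isempty(s)
            x̄ = pop!(s)(x̄)        # last iteration's pullback runs first
        end
        return x̄
    end
    return x, pullback
end

y, back = J_sinloop(0.5, 3)
back(1.0)   # derivative of sin(sin(sin(x))) at x = 0.5
```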

Closure Conversion

There are no closures in lowered Julia code, so we can't actually emit one directly in lowered code. To work around this we have a trick: we have a generic struct like

struct Pullback{F}
  data
end

We can put whatever we want in data, and the F will be the signature for the original call, like Tuple{typeof(foo),Int}. When the pullback gets called it hits another generated function which emits the pullback code.

In hand-written code, this would look like:

struct Pullback{F}
  data
end

# for foo(x) = sin(cos(x)), as in the working example above
function J(::typeof(foo), x)
  a, da = J(cos, x)
  b, db = J(sin, a)
  return b, Pullback{typeof(foo)}((da, db))   # store the pullbacks as data
end

function (p::Pullback{typeof(foo)})(b̄)
  da, db = p.data[1], p.data[2]
  ā = db(b̄)
  x̄ = da(ā)
  return x̄
end
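The text above says that the type parameter F is the signature of the original call, such as `Tuple{typeof(foo),Int}`. A runnable variant of the hand-written code using the full signature follows. The call method on `Pullback` is still written by hand here, whereas in Zygote it is a generated function that emits the pullback code.

```julia
J(::typeof(sin), x) = sin(x), ȳ -> ȳ * cos(x)
J(::typeof(cos), x) = cos(x), ȳ -> -ȳ * sin(x)

foo(x) = sin(cos(x))

# A plain struct standing in for a closure: F records the signature of the
# original call, and `data` holds whatever the backward pass needs.
struct Pullback{F}
    data
end

function J(::typeof(foo), x)
    a, da = J(cos, x)
    b, db = J(sin, a)
    return b, Pullback{Tuple{typeof(foo),typeof(x)}}((da, db))
end

# Dispatch on the recorded signature selects the right backward pass.
function (p::Pullback{Tuple{typeof(foo),T}})(b̄) where {T}
    da, db = p.data
    ā = db(b̄)
    x̄ = da(ā)
    return x̄
end

y, back = J(foo, 1.0)
back(1.0)
```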

Debugging

Say some of our code is throwing an error.

bad(x) = x

Zygote.@adjoint bad(x) = x, _ -> error("bad")

foo(x) = bad(sin(x))

gradient(foo, 1) # error!

Zygote can usually give a stacktrace pointing right to the issue here, but in some cases there are compiler crashes that make this harder. In these cases it's best to (a) use _forward and (b) take advantage of Zygote's recursion to narrow down the problem function.

julia> y, back = Zygote._forward(foo, 1);

julia> back(1) # make up a value here; it only needs to look similar to `y`
ERROR: bad

# Ok, so we try functions that foo calls

julia> y, back = Zygote._forward(sin, 1);

julia> back(1)
(nothing, 0.5403023058681398)

# Looks like that's fine

julia> y, back = Zygote._forward(bad, 1);

julia> back(1) # ok, here's our issue. Lather, rinse, repeat.
ERROR: bad

Of course, our goal is that you never have to do this, but until Zygote is more mature it can be a useful way to narrow down test cases.
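The narrowing-down loop can also be scripted. The sketch below uses hand-written J rules in place of Zygote so that it is self-contained. `failing_pullbacks` is a hypothetical helper that runs the forward pass and pullback for each candidate and collects the ones that throw. With Zygote loaded, calling `Zygote._forward(f, x)` in place of `J(f, x)` should work the same way, as in the REPL session above.

```julia
J(::typeof(sin), x) = sin(x), ȳ -> ȳ * cos(x)

bad(x) = x
J(::typeof(bad), x) = x, _ -> error("bad")   # deliberately broken pullback

# Hypothetical helper: try each candidate's pullback and record the failures.
function failing_pullbacks(fs, x)
    failures = Pair{Any,Any}[]
    for f in fs
        y, back = J(f, x)
        try
            back(one(y))          # seed with a value shaped like y
        catch err
            push!(failures, f => err)
        end
    end
    return failures
end

failing_pullbacks([sin, bad], 1.0)   # one entry: bad => ErrorException("bad")
```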