Automatic Differentiation
In this case study we reimplement automatic differentiation as presented in
"Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator" (Fei Wang et al., ICFP 2019).
Instead of the control operators shift and reset, we use effects and handlers.
module examples/casestudies/ad
import immutable/list
import mutable/heap
Representing Differentiable Expressions
Like Wang et al., we start by defining our own representation of differentiable
expressions. We represent the language of differentiable expressions by an
effect AD, which operates on Num. The effect AD can be thought of as an
embedded DSL.
record Num(value: Double, d: Ref[Double])
effect AD {
  def num(x: Double): Num
  def add(x: Num, y: Num): Num
  def mul(x: Num, y: Num): Num
}
If Effekt supported polymorphic effects (or abstract type members on
effects), we could abstract over the domain Num, which is hard-coded above to
contain a double value and a mutable reference.
We make use of operator overloading by defining infixAdd and infixMul, which are
special-cased in the Effekt compiler.
def infixAdd(x: Num, y: Num) = add(x, y)
def infixMul(x: Num, y: Num) = mul(x, y)
We can use the number DSL to express an example program.
// prog(x) = 3x + x^3, so its derivative is d = 3 + 3x^2
def prog(x: Num): Num / AD =
  (num(3.0) * x) + (x * x * x)
This program uses effect operations for multiplication, addition, and for embedding constant literals.
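As a quick check of the comment above, the derivative of prog can be worked out by hand. Here f stands for the function that prog denotes; the name f is introduced only for this illustration.

```latex
\begin{aligned}
f(x)  &= 3x + x^3 \\
f'(x) &= 3 + 3x^2 \\
f'(2) &= 15, \qquad f'(3) = 30, \qquad f'(0) = 3
\end{aligned}
```

These are the values that a correct differentiation of prog at 2.0, 3.0, and 0.0 should produce.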
Forwards Propagation
To compute the derivative of an expression like prog, we start with
forwards propagation. A differentiable function has type Num => Num / AD,
and computing the derivative is expressed as a handler.
For forwards propagation, we do not use the fact that the derivative stored
in Num is a mutable box. This is only necessary for backwards propagation.
def forwards(in: Double) { prog: Num => Num / AD }: Double =
  // the input is seeded with derivative 1.0
  try { prog(Num(in, fresh(1.0))).d.get } with AD {
    // constants have derivative 0.0
    def num(v) = resume(Num(v, fresh(0.0)))
    // sum rule
    def add(x, y) = resume(Num(
      x.value + y.value,
      fresh(x.d.get + y.d.get)))
    // product rule
    def mul(x, y) = resume(Num(
      x.value * y.value,
      fresh(x.d.get * y.value + y.d.get * x.value)))
  }
Except for the wrapping and unwrapping of the references, the definitions
of add and mul are exactly the ones we would expect from a textbook.
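In textbook notation, with x' standing for the derivative that the handler stores in x.d, the two cases read as follows. The pair notation is introduced here purely for illustration.

```latex
\begin{aligned}
(x, x') + (y, y') &= (x + y,\; x' + y') \\
(x, x') \cdot (y, y') &= (x \cdot y,\; x' \cdot y + y' \cdot x)
\end{aligned}
```

The input is seeded with derivative 1 and every constant with derivative 0. As a hand-worked sketch, propagating prog at x = 2 proceeds like this, where c, a, b, and e are hypothetical names for the intermediate results:

```latex
\begin{aligned}
x &= (2, 1), \qquad c = \mathrm{num}(3) = (3, 0) \\
a &= c \cdot x = (6,\; 0 \cdot 2 + 1 \cdot 3) = (6, 3) \\
b &= x \cdot x = (4,\; 1 \cdot 2 + 1 \cdot 2) = (4, 4) \\
e &= b \cdot x = (8,\; 4 \cdot 2 + 1 \cdot 4) = (8, 12) \\
a + e &= (14,\; 3 + 12) = (14, 15)
\end{aligned}
```

The second component of the final pair, 15, is the derivative of prog at 2.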
Backwards Propagation
We can use the same differentiable expression and compute its derivative
by using backwards propagation. Since we modeled the DSL as an effect,
we automatically get access to the continuation in the implementation of
add and mul. We thus do not have to use shift and reset.
Otherwise the implementation closely follows the one by Wang et al.
def backwards(in: Double) { prog: Num => Num / AD }: Double = {
  // the representation of our input
  val input = Num(in, fresh(0.0))
  // a helper function to update the derivative of a given number by adding v
  def push(n: Num)(v: Double): Unit = n.d.put(n.d.get + v)
  try { prog(input).push(1.0) } with AD {
    def num(v) = resume(Num(v, fresh(0.0)))
    def add(x, y) = {
      val z = Num(x.value + y.value, fresh(0.0))
      resume(z)
      x.push(z.d.get);
      y.push(z.d.get)
    }
    def mul(x, y) = {
      val z = Num(x.value * y.value, fresh(0.0))
      resume(z)
      x.push(y.value * z.d.get);
      y.push(x.value * z.d.get)
    }
  }
  // the derivative of `prog` at `in` is stored in the mutable reference
  input.d.get
}
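The updates performed after resume correspond to the usual adjoint rules of reverse-mode differentiation. Below, a bar over a name stands for the contents of that number's d reference (a notation introduced here for illustration), and the result of prog is seeded with adjoint 1 by push(1.0).

```latex
\begin{aligned}
z = x + y &: \quad \bar{x} \mathrel{+}= \bar{z}, \qquad \bar{y} \mathrel{+}= \bar{z} \\
z = x \cdot y &: \quad \bar{x} \mathrel{+}= y \cdot \bar{z}, \qquad \bar{y} \mathrel{+}= x \cdot \bar{z}
\end{aligned}
```

Because each handler clause resumes first and updates afterwards, the updates run in the reverse order of the operations. As a hand-worked sketch for prog at x = 2, assuming left-to-right evaluation and the hypothetical intermediate names c = 3, a = c · x = 6, b = x · x = 4, and e = b · x = 8, the backwards pass accumulates:

```latex
\begin{aligned}
a + e &: \quad \bar{a} = 1, \quad \bar{e} = 1 \\
e = b \cdot x &: \quad \bar{b} = 2 \cdot 1 = 2, \quad \bar{x} = 0 + 4 \cdot 1 = 4 \\
b = x \cdot x &: \quad \bar{x} = 4 + 2 \cdot 2 + 2 \cdot 2 = 12 \\
a = c \cdot x &: \quad \bar{c} = 2 \cdot 1 = 2, \quad \bar{x} = 12 + 3 \cdot 1 = 15
\end{aligned}
```

The final adjoint of the input, 15, agrees with the derivative 3 + 3x^2 at x = 2.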
Example Usages
We can use forwards and backwards propagation to compute derivatives of a few examples.
def main() = {
  println(forwards(2.0) { x => prog(x) })
  println(backwards(2.0) { x => prog(x) })
  println(forwards(3.0) { x => prog(x) })
  println(backwards(3.0) { x => prog(x) })
  println(forwards(0.0) { x => prog(x) })
  println(backwards(0.0) { x => prog(x) })
  // we have the same perturbation confusion as in Lantern
  val result = forwards(1.0) { x =>
    val shouldBeOne = forwards(1.0) { y => x + y }
    val z = num(shouldBeOne)
    x * z
  }
  println(result)
  val result2 = backwards(1.0) { x =>
    val shouldBeOne = backwards(1.0) { y => x + y }
    val z = num(shouldBeOne)
    x * z
  }
  println(result2)
  // this is proposed by Wang et al. as a solution to perturbation confusion
  val result3 = backwards(1.0) { x =>
    val shouldBeOne = forwards(1.0) { y => x + y }
    val z = num(shouldBeOne)
    x * z
  }
  println(result3)
}
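Following the handler definitions by hand suggests what the three nested examples compute. The mathematically expected derivative is 1 in each case: the derivative of x + y with respect to y is 1, so the outer function is x · 1. The trace below is a hand-worked sketch, not captured program output. As before, a prime marks a forwards derivative and a bar marks the contents of a d reference during backwards propagation.

```latex
\begin{aligned}
\text{result (forwards in forwards)} &: \quad \text{inner } x' + y' = 1 + 1 = 2, \quad \text{outer } x' \cdot z + z' \cdot x = 1 \cdot 2 + 0 \cdot 1 = 2 \\
\text{result2 (backwards in backwards)} &: \quad \text{inner } \bar{y} = 1 \text{ and also } \bar{x} = 1, \quad \text{outer } \bar{x} = 1 + z \cdot 1 = 1 + 1 = 2 \\
\text{result3 (forwards in backwards)} &: \quad \text{inner } x' + y' = 0 + 1 = 1, \quad \text{outer } \bar{x} = 0 + z \cdot 1 = 1
\end{aligned}
```

In the first case the inner handler reads the outer seed stored in x.d. In the second, the inner handler writes into the reference of the outer input. Only the mixed variant yields the expected value 1.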