# Automatic Differentiation

In this case study we reimplement automatic differentiation as presented in

*Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator*. Fei Wang et al. ICFP 2019.

Instead of using the control operators `shift` and `reset`, we use effects and
handlers.

```
module examples/casestudies/ad
import immutable/list
import mutable/heap
```

## Representing Differentiable Expressions

Like Wang et al., we start by defining our own representation of differentiable
expressions. We represent the language of differentiable expressions by an
effect `AD`, which operates on `Num`. The effect `AD` can be thought of as an
embedded DSL.

```
record Num(value: Double, d: Ref[Double])
effect AD {
  def num(x: Double): Num
  def add(x: Num, y: Num): Num
  def mul(x: Num, y: Num): Num
}
```

If Effekt supported polymorphic effects (or abstract type members on
effects), we could abstract over the domain `Num`, which is hard-coded above to
contain a double value and a mutable reference.
We make use of operator overloading by defining `infixAdd` and `infixMul`, which are
special-cased in the Effekt compiler.

```
def infixAdd(x: Num, y: Num) = add(x, y)
def infixMul(x: Num, y: Num) = mul(x, y)
```

We can use the number DSL to express an example program.

```
// prog(x) = 3x + x^3, so its derivative is d = 3 + 3x^2
def prog(x: Num): Num / AD =
  (num(3.0) * x) + (x * x * x)
```

This program uses effect operations for multiplication, addition, and for embedding constant literals.
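
As a further illustration, a second program can be written in exactly the same style. The function below is a hypothetical addition (the name `squarePlusOne` does not appear in the original case study) and only uses the operations already declared by `AD`.

```
// hypothetical example, not part of the original case study
// squarePlusOne(x) = x^2 + 1, so its derivative is 2x
def squarePlusOne(x: Num): Num / AD =
  (x * x) + num(1.0)
```

Like `prog`, it says nothing about how derivatives are computed; that is left entirely to whichever handler for `AD` is installed around it.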

## Forwards Propagation

To compute the derivative of an expression like `prog`, we start with
forwards propagation. A differentiable function has type `Num => Num / AD`,
and computing the derivative is expressed as a handler.

For forwards propagation, we do not use the fact that the derivative stored
in `Num` is a mutable box. This is only necessary for backwards propagation.

```
def forwards(in: Double) { prog: Num => Num / AD }: Double =
  try { prog(Num(in, fresh(1.0))).d.get } with AD {
    def num(v) = resume(Num(v, fresh(0.0)))
    def add(x, y) = resume(Num(
      x.value + y.value,
      fresh(x.d.get + y.d.get)))
    def mul(x, y) = resume(Num(
      x.value * y.value,
      fresh(x.d.get * y.value + y.d.get * x.value)))
  }
```

Except for the wrapping and unwrapping of the references, the definitions
of `add` and `mul` are exactly the ones we would expect from a textbook.
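
For reference, the two handler clauses correspond to the usual sum and product rules. Writing $x'$ for the derivative stored in `x.d` (a notation introduced here only for illustration), they read:

$$
\begin{aligned}
(x + y)' &= x' + y' \\
(x \cdot y)' &= x' \cdot y + y' \cdot x
\end{aligned}
$$

Applied to `prog`, with the input seeded with $x' = 1$ and the constant `num(3.0)` carrying derivative $0$, these rules give $(3x + x^3)' = 3 + 3x^2$, which is $15$ at $x = 2$.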

## Backwards Propagation

We can use the same differentiable expression and compute its derivative
by using *backwards propagation*. Since we modeled the DSL as an effect,
we automatically get access to the continuation in the implementation of
`add` and `mul`. We thus do not have to use `shift` and `reset`.
Otherwise the implementation closely follows the one by Wang et al.

```
def backwards(in: Double) { prog: Num => Num / AD }: Double = {
  // the representation of our input
  val input = Num(in, fresh(0.0))
  // a helper function to update the derivative of a given number by adding v
  def push(n: Num)(v: Double): Unit = n.d.put(n.d.get + v)
  // the result of `prog` is seeded with derivative 1.0
  try { prog(input).push(1.0) } with AD {
    def num(v) = resume(Num(v, fresh(0.0)))
    def add(x, y) = {
      val z = Num(x.value + y.value, fresh(0.0))
      // run the rest of the program first; afterwards z.d holds its derivative
      resume(z)
      x.push(z.d.get);
      y.push(z.d.get)
    }
    def mul(x, y) = {
      val z = Num(x.value * y.value, fresh(0.0))
      // as above: propagate to x and y only after the continuation returns
      resume(z)
      x.push(y.value * z.d.get);
      y.push(x.value * z.d.get)
    }
  }
  // the derivative of `prog` at `in` is stored in the mutable reference
  input.d.get
}
```
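
The updates performed by the handler can also be written as accumulation rules. Writing $\bar{n}$ for the current content of the reference `n.d` (again a notation introduced here only for illustration), each clause performs the following updates once `resume(z)` has returned:

$$
\begin{aligned}
z = x + y &: \quad \bar{x} \leftarrow \bar{x} + \bar{z}, \qquad \bar{y} \leftarrow \bar{y} + \bar{z} \\
z = x \cdot y &: \quad \bar{x} \leftarrow \bar{x} + y \cdot \bar{z}, \qquad \bar{y} \leftarrow \bar{y} + x \cdot \bar{z}
\end{aligned}
$$

Because every update runs after the continuation has finished, the updates happen in the reverse order of the operations in the program, starting from the result, whose reference is set to $1$ by `push(1.0)`.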

## Example Usages

We can use forwards and backwards propagation to compute derivatives of a few examples.

```
def main() = {
  println(forwards(2.0) { x => prog(x) })
  println(backwards(2.0) { x => prog(x) })
  println(forwards(3.0) { x => prog(x) })
  println(backwards(3.0) { x => prog(x) })
  println(forwards(0.0) { x => prog(x) })
  println(backwards(0.0) { x => prog(x) })
  // we have the same perturbation confusion as in Lantern
  val result = forwards(1.0) { x =>
    val shouldBeOne = forwards(1.0) { y => x + y }
    val z = num(shouldBeOne)
    x * z
  }
  println(result)
  val result2 = backwards(1.0) { x =>
    val shouldBeOne = backwards(1.0) { y => x + y }
    val z = num(shouldBeOne)
    x * z
  }
  println(result2)
  // this is proposed by Wang et al. as a solution to perturbation confusion
  val result3 = backwards(1.0) { x =>
    val shouldBeOne = forwards(1.0) { y => x + y }
    val z = num(shouldBeOne)
    x * z
  }
  println(result3)
}
```
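
To see why the nested examples are interesting, it helps to write down the value they are meant to compute. The inner derivative is taken with respect to $y$ only, so mathematically:

$$
\frac{d}{dx}\left( x \cdot \frac{d}{dy}(x + y) \right) = \frac{d}{dx}\left( x \cdot 1 \right) = 1
$$

Tracing the handlers by hand (this is a reading of the code above, not a recorded program run) suggests where the first example goes wrong. In the nested `forwards` call, the outer `x` already carries the derivative $1$, so the inner `add` computes $x' + y' = 1 + 1 = 2$ instead of $1$, and the outer result is off accordingly. In the third example the outer handler is `backwards`, where `x.d` is still $0$ while the inner `forwards` runs, so the inner derivative comes out as $0 + 1 = 1$, as intended.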