Sequential Monte Carlo

import immutable/list
import immutable/option
import mutable/array
import unsafe/cont

In this case study we implement the Sequential Monte Carlo (SMC) algorithm for probabilistic inference. The idea is to run multiple instances of a probabilistic process (so-called particles) and, while they are still running, to occasionally resample from this collection according to the weights the particles have accumulated so far.

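We define the following SMC effect to model probabilistic processes. uniform draws a uniformly distributed random number between 0 and 1, score(d) multiplies the weight of the current particle by d, and resample marks a point at which the handler may resample the particles. The results of finished particles are collected as weighted measurements.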

effect SMC {
  def resample(): Unit
  def uniform(): Double
  def score(d: Double): Unit
}
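Effekt handlers receive each operation call together with the rest of the computation. For readers who want to experiment outside Effekt, the following Python sketch assumes a generator-based encoding of the same interface. The encoding and every Python name in it (run_once, scored, the request tuples) are hypothetical and not part of the case study. A program yields one request per operation, and the handler sends the answer back. run_once handles a single run: it answers uniform with a random number, multiplies the weight on score, and ignores resample.

```python
import random

# Hypothetical encoding of the SMC effect: a program is a generator that
# yields one request per operation; the handler sends the answer back in.
RESAMPLE, UNIFORM, SCORE = "resample", "uniform", "score"

def bernoulli(p):
    # mirrors `uniform() < p`: request a uniform draw, then compare with p
    u = yield (UNIFORM,)
    return u < p

def run_once(program, rng):
    """Handle one run of `program`: uniform -> rng.random(), score(d) ->
    multiply the weight by d, resample -> do nothing. Returns (weight, result)."""
    weight = 1.0
    gen = program()
    answer = None  # the first send must be None to start the generator
    try:
        while True:
            op = gen.send(answer)
            answer = None
            if op[0] == UNIFORM:
                answer = rng.random()
            elif op[0] == SCORE:
                weight *= op[1]
            # RESAMPLE needs no answer when there is only one particle
    except StopIteration as stop:
        return weight, stop.value

# Example program: scores twice around a resample point.
def scored():
    yield (SCORE, 0.5)
    yield (RESAMPLE,)
    yield (SCORE, 0.5)
    return "done"

print(run_once(scored, random.Random(0)))  # (0.25, 'done')
```

Python generators can only be resumed once. Unlike Effekt continuations, they cannot be copied when a particle is duplicated during resampling.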

We can use the SMC effect to define some probabilistic programs.

def bernoulli(p: Double) = uniform() < p

def biasedGeometric(p: Double): Int / SMC = {
  resample();
  val x = bernoulli(p);
  if (x) {
    score(log(1.5));
    1 + biasedGeometric(p)
  } else { 1 }
}

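Here bernoulli(p) draws from a Bernoulli distribution (a biased coin flip) and returns true with probability p. biasedGeometric(p) draws from a geometric distribution: it counts the trials up to and including the first failure. Before each trial it calls resample, and each success is scored with log(1.5). The handlers below multiply the weight by this factor, which is less than 1, so larger results receive smaller weights. This biases the distribution towards smaller numbers.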
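To see what the weights do, we can work out the distribution that biasedGeometric(p) describes once the scores are taken into account. This derivation assumes that score multiplies the weight by its argument, as the handlers in this case study do, and that log is the natural logarithm. The symbols \(N\), \(w\), \(\pi\) and \(q\) are introduced here only for illustration.

```latex
\Pr[N = n] = p^{\,n-1}(1-p), \qquad w(n) = (\log 1.5)^{\,n-1}, \qquad n \ge 1,

\pi(n) = \frac{\Pr[N = n]\, w(n)}{\sum_{m \ge 1} \Pr[N = m]\, w(m)}
       = (1-q)\, q^{\,n-1}, \qquad q = p \log 1.5,
\qquad \mathbb{E}_\pi[N] = \frac{1}{1-q}.
```

For p = 0.5 this gives q ≈ 0.203 and a weighted mean of about 1.254. The unweighted geometric distribution has a mean of 1/(1 - p) = 2.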

An SMC Handler

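A particle consists of its current weight (or “score”), its age (the number of resampling generations it has survived, not used at the moment), and the continuation that corresponds to the remainder of its computation. A Measurement pairs the weight of a finished particle with its result. Particles bundles the particles that are still running with the measurements of those that have finished.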

record Particle(weight: Double, age: Int, cont: Cont[Unit, Unit])
record Measurement[R](weight: Double, data: R)

record Particles[R](moving: List[Particle], done: List[Measurement[R]])

Using the above data types, we can define our SMC handler as follows:

def smcHandler[R](numberOfParticles: Int) {
  // must preserve the total number of particles (moving plus done)
  resample: Particles[R] => Particles[R]
} { p: () => R / SMC } = {
  var weight = 1.0;
  var particles: List[Particle] = Nil()
  var measurements: List[Measurement[R]] = Nil()
  var age = 0;

  def checkpoint(cont: Cont[Unit, Unit]) =
    particles = Cons(Particle(weight, age, cont), particles)

  def run(p: Particle): Unit = {
    weight = p.weight;
    age = p.age;
    p.cont.apply(())
  }

  def run() = {
    val Particles(ps, ms) = resample(Particles(particles, measurements));
    particles = Nil();
    measurements = ms;
    ps.foreach { p => p.run }
  }

  repeat(numberOfParticles) {
    weight = 1.0;
    try {
      val res = p();
      measurements = Cons(Measurement(weight, res), measurements)
    } with SMC {
      def resample() = checkpoint(cont { t => resume(t) })
      def uniform() = resume(random())
      def score(d) = { weight = weight * d; resume(()) }
    }
  }

  while (not(particles.isEmpty)) { run() }
  measurements
}

It runs numberOfParticles-many instances of the provided program p under a handler that captures the continuation whenever a particle calls resample. Once every particle has either finished its computation or hit a resample, the handler passes the live particles, together with the measurements of the finished ones, to the argument function resample. This function draws from the given lists according to the weights to obtain new lists. As a result, particles with small weights will likely disappear, while particles with large weights will likely be duplicated. The handler then resumes every particle in the new list and repeats this process until no particle is left running.
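An Effekt continuation can be resumed more than once, and this is what lets resampling duplicate a particle. The Python sketch below mimics the handler using a different technique, swapped in because generators cannot be copied: replay. A particle is represented by the uniform draws it has consumed so far and the number of resample points it has passed. Resuming it means replaying the program from the start with those draws. Scores seen during the replay are skipped, because the weight was reset at the last resampling. All names are hypothetical. The resampling step uses the standard random.Random.choices. math.log(1.5) assumes that Effekt's log is the natural logarithm.

```python
import math
import random

# Hypothetical generator encoding of the SMC effect (not Effekt's API).
RESAMPLE, UNIFORM, SCORE = "resample", "uniform", "score"

def bernoulli(p):
    u = yield (UNIFORM,)
    return u < p

def biased_geometric(p):
    # same structure as the Effekt program; assumes log is the natural log
    yield (RESAMPLE,)
    if (yield from bernoulli(p)):
        yield (SCORE, math.log(1.5))
        return 1 + (yield from biased_geometric(p))
    return 1

def advance(program, weight, draws, passed, rng):
    """Replay-based stand-in for resuming a continuation: re-run `program`
    from the start, feeding it the recorded uniform `draws` and skipping
    scores until it has passed `passed` resample points, then run live until
    the next resample point or the end.
    Returns ("moving", weight, draws, passed) or ("done", weight, result)."""
    gen, answer = program(), None
    draws, used, seen = list(draws), 0, 0
    try:
        while True:
            op = gen.send(answer)
            answer = None
            if op[0] == UNIFORM:
                if used == len(draws):  # past the recorded prefix: draw fresh
                    draws.append(rng.random())
                answer = draws[used]
                used += 1
            elif op[0] == SCORE:
                if seen >= passed:  # earlier scores were consumed by resampling
                    weight *= op[1]
            else:  # RESAMPLE
                seen += 1
                if seen > passed:  # a new resample point: suspend here
                    return ("moving", weight, tuple(draws), seen)
    except StopIteration as stop:
        return ("done", weight, stop.value)

def resample(moving, done, rng):
    """Like resampleUniform: draw len(moving) + len(done) entries with
    replacement, proportionally to weight, and reset their weights to 1.0."""
    pool = [(True, m) for m in moving] + [(False, d) for d in done]
    picks = rng.choices(pool, weights=[item[0] for _, item in pool], k=len(pool))
    new_moving = [(1.0,) + item[1:] for is_moving, item in picks if is_moving]
    new_done = [(1.0,) + item[1:] for is_moving, item in picks if not is_moving]
    return new_moving, new_done

def smc(program, n, rng):
    """Advance every live particle to its next resample point, then resample
    live and finished particles together, until nothing is left running."""
    moving, done = [(1.0, (), 0)] * n, []
    while True:
        still_moving = []
        for weight, draws, passed in moving:
            status, *rest = advance(program, weight, draws, passed, rng)
            (still_moving if status == "moving" else done).append(tuple(rest))
        if not still_moving:
            return done  # list of (weight, result) measurements
        moving, done = resample(still_moving, done, rng)
```

With replay, resuming a particle costs time proportional to the length of its history. That is acceptable for a sketch, but it is exactly the overhead that first-class continuations avoid.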

Resampling

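We now implement such a (naive) resampling function. It treats live particles and finished measurements alike and draws numberOfParticles new entries with replacement. For each draw it picks a random target between 0 and the total weight, then walks the cumulative weights until it reaches the target. Each entry is therefore chosen with probability proportional to its weight. The weights of the drawn entries are reset to 1.0.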

def resampleUniform[R](particles: Particles[R]): Particles[R] = {
  val Particles(ps, ms) = particles;
  val total = ps.totalWeight + ms.totalWeight
  val numberOfParticles = ps.size + ms.size

  var newParticles: List[Particle] = Nil();
  var newMeasurements: List[Measurement[R]] = Nil();

  // select new particles by drawing at random
  // this is a very naive implementation with O(numberOfParticles^2) worst case.
  def draw() = {
    val targetWeight = random() * total;
    var currentWeight = 0.0;
    var remainingPs = ps
    var remainingMs = ms
    while (currentWeight < targetWeight) {
      (remainingPs, remainingMs) match {
        case (Nil(), Nil()) => <> // ERROR should not happen
        case (Cons(p, rest), _) =>
          currentWeight = currentWeight + p.weight
          if (currentWeight >= targetWeight) {
            newParticles = Cons(Particle(1.0, p.age, p.cont), newParticles)
          } else { remainingPs = rest }
        case (Nil(), Cons(m, rest)) =>
          currentWeight = currentWeight + m.weight
          if (currentWeight >= targetWeight) {
            newMeasurements = Cons(Measurement(1.0, m.data), newMeasurements)
          } else { remainingMs = rest }
      }
    }
  }

  repeat(numberOfParticles) { draw() }

  Particles(newParticles, newMeasurements)
}

// helper function to compute the total weight
def totalWeight(ps: List[Particle]): Double = {
  var totalWeight = 0.0
  ps.foreach { case Particle(w, _, _) =>
    totalWeight = totalWeight + w
  }
  totalWeight
}
def totalWeight[R](ps: List[Measurement[R]]): Double = {
  var totalWeight = 0.0
  ps.foreach { case Measurement(w, _) =>
    totalWeight = totalWeight + w
  }
  totalWeight
}
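The comment in draw notes the quadratic worst case: each of the numberOfParticles draws scans the lists linearly. A common alternative, sketched here in Python with the hypothetical name resample_indices, computes the cumulative weights once. It then finds each target by binary search with the standard bisect module, so each draw takes O(log n) time. Python's random.Random.choices(population, weights=..., k=...) performs the same weighted draw with replacement.

```python
import bisect
import itertools
import random

def resample_indices(weights, k, rng):
    """Multinomial resampling: draw k indices with replacement, index i with
    probability weights[i] / sum(weights). Assumes at least one weight > 0."""
    cumulative = list(itertools.accumulate(weights))
    total = cumulative[-1]
    picks = []
    for _ in range(k):
        target = rng.random() * total
        # first index whose cumulative weight exceeds the target
        # (entries with zero weight are never selected)
        i = bisect.bisect_right(cumulative, target)
        picks.append(min(i, len(weights) - 1))  # guard against rounding at the top
    return picks
```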

We now have everything we need to define smc as smcHandler instantiated with resampleUniform:

def smc[R](numberOfParticles: Int) { p: () => R / SMC } =
  smcHandler[R](numberOfParticles) { ps => resampleUniform(ps) } { p() }

Importance Sampling

Of course, the above handler is not the only possible one. We can define an even simpler handler that performs importance sampling: it runs each particle to completion, one after another, and treats resample as a no-op.

def importance[R](n: Int) { p: () => R / SMC } = {
  var measurements: List[Measurement[R]] = Nil()
  n.repeat {
    var currentWeight = 1.0;
    try {
      val result = p();
      measurements = Cons(Measurement(currentWeight, result), measurements)
    } with SMC {
      def resample() = resume(())
      def uniform() = resume(random())
      def score(d) = { currentWeight = currentWeight * d; resume(()) }
    }
  }
  measurements
}
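For comparison, here is the importance sampler as a Python sketch. It again assumes a generator-based encoding of the SMC effect and uses hypothetical names. It also computes the self-normalized estimate Σ w·x / Σ w from the weighted measurements. If log is the natural logarithm, this estimate should come out close to 1/(1 - 0.5·log 1.5) ≈ 1.254 for biasedGeometric(0.5).

```python
import math
import random

# Hypothetical generator encoding of the SMC effect (not Effekt's API).
RESAMPLE, UNIFORM, SCORE = "resample", "uniform", "score"

def bernoulli(p):
    u = yield (UNIFORM,)
    return u < p

def biased_geometric(p):
    yield (RESAMPLE,)
    if (yield from bernoulli(p)):
        yield (SCORE, math.log(1.5))  # assumes Effekt's log is the natural log
        return 1 + (yield from biased_geometric(p))
    return 1

def importance(program, n, rng):
    """Run n independent particles to completion; resample is a no-op."""
    measurements = []
    for _ in range(n):
        weight, gen, answer = 1.0, program(), None
        try:
            while True:
                op = gen.send(answer)
                answer = None
                if op[0] == UNIFORM:
                    answer = rng.random()
                elif op[0] == SCORE:
                    weight *= op[1]
                # RESAMPLE: nothing to do, just continue
        except StopIteration as stop:
            measurements.append((weight, stop.value))
    return measurements

def weighted_mean(measurements):
    """Self-normalized estimate of E[x]: sum(w * x) / sum(w)."""
    total = sum(w for w, _ in measurements)
    return sum(w * x for w, x in measurements) / total
```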

Running the Examples

extern def sleep(n: Int): Unit =
  "$effekt.callcc(k => window.setTimeout(() => k(null), n))"

// yield to the browser with a zero-delay timeout so that the graph can re-render
extern def reportMeasurementJS[R](w: Double, d: R): Unit =
  "$effekt.callcc(k => { showPoint(w, d); window.setTimeout(() => k(null), 0)})"

extern def reportDiscreteMeasurementJS[R](w: Double, d: R): Unit =
  "$effekt.callcc(k => { showPoint(w, d, { discrete: true }); window.setTimeout(() => k(null), 0)})"


extern pure def setupGraphJS(): Unit =
  "setup()"

To visualize the results, we define the following helper function report. It adds each measurement to the graph below as a data point, pausing interval milliseconds between points so that the graph can be redrawn.

def report[R](interval: Int, ms: List[Measurement[R]]) = {
  setupGraphJS();
  ms.foreach { m =>
    reportMeasurementJS(m.weight, m.data);
    sleep(interval)
  }
}
def reportDiscrete[R](interval: Int, ms: List[Measurement[R]]) = {
  setupGraphJS();
  ms.foreach { m =>
    reportDiscreteMeasurementJS(m.weight, m.data);
    sleep(interval)
  }
}

Running SMC and importance sampling now is a matter of composing the handlers.

def runSMC(numberOfParticles: Int) =
  report(20, smc(numberOfParticles) { biasedGeometric(0.5) })
def runImportance(numberOfParticles: Int) =
  report(20, importance(numberOfParticles) { biasedGeometric(0.5) })

We have also prepared a helper called reportDiscrete to experiment with programs whose results are discrete values other than numbers, such as strings:

def runDiscrete(numberOfParticles: Int) =
  reportDiscrete(0, smc(numberOfParticles) {
    if (bernoulli(0.5)) { "hello" } else { "world" }
  })

In the REPL below you can try the examples: click run, then enter runSMC(100) and click run again.

