Sequential Monte Carlo
import list
import option
import array
In this case study we implement the Sequential Monte Carlo (SMC) algorithm for probabilistic inference. The idea is to run multiple instances of a probabilistic process (so-called particles) and, while they are still running, to occasionally resample from the collection of these instances according to the weights they have accumulated so far.
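We define the following SMC effect to model probabilistic processes. The operation resample marks a point at which the running particle may be resampled, uniform draws a uniformly distributed random number, and score(d) multiplies the weight of the current particle by d. The results of finished particles are collected as weighted measurements.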
effect SMC {
def resample(): Unit
def uniform(): Double
def score(d: Double): Unit
}
We can use the SMC effect to define some probabilistic programs.
def bernoulli(p: Double) = do uniform() < p
def biasedGeometric(p: Double): Int / SMC = {
do resample();
val x = bernoulli(p);
if (x) {
do score(log(1.5));
1 + biasedGeometric(p)
} else { 1 }
}
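Here bernoulli draws from a Bernoulli distribution (a biased coin flip) and biasedGeometric draws from a geometric distribution with a bias towards smaller numbers. The bias comes from score(log(1.5)): the handlers multiply the weight by the scored value, and since log(1.5) ≈ 0.405 < 1, every additional successful flip lowers the weight of the run.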
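To make the bias precise, write c = log(1.5) for the factor scored on each success and assume, as in the handlers below, that weights are multiplied. A run of biasedGeometric(p) that returns k consists of k − 1 successful flips followed by one failure, so the sampled probability P(k), the weight w(k), and the resulting weighted (target) distribution q(k) are:

```latex
P(k) = p^{k-1}(1-p), \qquad w(k) = c^{k-1}, \qquad k = 1, 2, \dots

q(k) = \frac{w(k)\,P(k)}{\sum_{j \ge 1} w(j)\,P(j)}
     = \frac{(pc)^{k-1}(1-p)}{(1-p)/(1-pc)}
     = (1-pc)\,(pc)^{k-1}
```

So the target is again geometric, but with continuation probability pc instead of p. For p = 0.5 this is about 0.20 instead of 0.5, and q(1) = 1 − 0.5·log(1.5) ≈ 0.80.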
An SMC Handler
A particle consists of its current weight (or “score”), the age (number of resampling generations it survived – not used at the moment), and the continuation that corresponds to the remainder of its computation.
record Particle(weight: Double, age: Int, cont: Cont[Unit, Unit])
record Measurement[R](weight: Double, data: R)
record Particles[R](moving: List[Particle], done: List[Measurement[R]])
Using the above data types, we can define our SMC handler as follows:
def smcHandler[R](numberOfParticles: Int) {
// should maintain the number of particles.
resample: Particles[R] => Particles[R]
} { p: () => R / SMC } = {
var currentWeight = 1.0;
var particles: List[Particle] = Nil()
var measurements: List[Measurement[R]] = Nil()
var currentAge = 0;
def checkpoint(cont: Cont[Unit, Unit]) =
particles = Cons(Particle(currentWeight, currentAge, cont), particles)
def run(p: Particle): Unit = {
currentWeight = p.weight;
currentAge = p.age;
p.cont.apply(())
}
def run() = {
val Particles(ps, ms) = resample(Particles(particles, measurements));
particles = Nil();
measurements = ms;
ps.foreach { p => p.run }
}
repeat(numberOfParticles) {
currentWeight = 1.0;
try {
val res = p();
measurements = Cons(Measurement(currentWeight, res), measurements)
} with SMC {
def resample() = checkpoint(cont { t => resume(t) })
def uniform() = resume(random())
def score(d) = { currentWeight = currentWeight * d; resume(()) }
}
}
while (not(particles.isEmpty)) { run() }
measurements
}
It runs numberOfParticles-many instances of the provided program p under a handler that captures the continuation whenever a particle encounters a call to resample. Once all particles have either finished their computation or hit a resample, the handler passes the live particles, together with the measurements of the particles that have already finished, to the argument function resample.

This argument function then draws from the given lists according to their weights to obtain new lists. Thus, the new lists will likely not contain particles with small weights, while particles with large weights will likely appear multiple times.
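To make the role of the argument function concrete, here is a minimal sketch that builds on smcHandler and the SMC effect above. The name smcNoResampling is hypothetical and not part of the case study. Passing a function that returns its input unchanged never duplicates or drops a particle, so smcHandler then behaves like plain importance sampling: particles merely pause at each resample and continue with their accumulated weights.

```effekt
// Hypothetical instantiation: the identity resampling function.
// Every particle keeps its continuation and its accumulated weight, and the
// number of particles trivially stays the same, so smcHandler only pauses
// all particles at each resample() and then resumes them unchanged.
def smcNoResampling[R](numberOfParticles: Int) { p: () => R / SMC } =
  smcHandler[R](numberOfParticles) { ps => ps } { p() }
```

Since every particle eventually finishes exactly once, smcNoResampling(10) { biasedGeometric(0.5) } returns ten measurements.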
Resampling
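We now implement such a (naive) resampling function. For each of the numberOfParticles draws, it picks a random target in the range [0, total), where total is the combined weight of all live particles and finished measurements. It then scans the live particles followed by the measurements, accumulating their weights until the running sum reaches the target. The entry found this way is copied into the new lists with its weight reset to 1.0, so each entry is selected with probability proportional to its weight.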
def resampleUniform[R](particles: Particles[R]): Particles[R] = {
val Particles(ps, ms) = particles;
val total = ps.totalWeight + ms.totalWeight
val numberOfParticles = ps.size + ms.size
var newParticles: List[Particle] = Nil();
var newMeasurements: List[Measurement[R]] = Nil();
// select new particles by drawing at random
// this is a very naive implementation with O(numberOfParticles^2) worst case.
def draw() = {
val targetWeight = random() * total;
var currentWeight = 0.0;
var remainingPs = ps
var remainingMs = ms
while (currentWeight < targetWeight) {
(remainingPs, remainingMs) match {
case (Nil(), Nil()) => <> // ERROR should not happen
case (Cons(p, rest), _) =>
currentWeight = currentWeight + p.weight
if (currentWeight >= targetWeight) {
newParticles = Cons(Particle(1.0, p.age, p.cont), newParticles)
} else { remainingPs = rest }
case (Nil(), Cons(m, rest)) =>
currentWeight = currentWeight + m.weight
if (currentWeight >= targetWeight) {
newMeasurements = Cons(Measurement(1.0, m.data), newMeasurements)
} else { remainingMs = rest }
}
}
}
repeat(numberOfParticles) { draw() }
Particles(newParticles, newMeasurements)
}
// helper function to compute the total weight
def totalWeight(ps: List[Particle]): Double = {
var totalWeight = 0.0
ps.foreach { case Particle(w, _, _) =>
totalWeight = totalWeight + w
}
totalWeight
}
def totalWeight[R](ps: List[Measurement[R]]): Double = {
var totalWeight = 0.0
ps.foreach { case Measurement(w, _) =>
totalWeight = totalWeight + w
}
totalWeight
}
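The scan inside draw is the classic roulette-wheel (multinomial) selection. Stripped of particles and continuations, it could be sketched over a plain list of weights as follows. selectIndex is a hypothetical helper and is not used by resampleUniform. Choosing target = random() * total and calling selectIndex picks entry i with probability weights[i] / total.

```effekt
// Hypothetical helper (not used by resampleUniform): the cumulative-weight
// scan from draw(), applied to a plain list of weights. Returns the index of
// the first entry at which the running sum reaches `target`, or the length
// of the list if all weights together sum to less than `target`.
def selectIndex(weights: List[Double], target: Double): Int = {
  var sum = 0.0;
  var index = 0;
  var found = false;
  weights.foreach { w =>
    if (not(found)) {
      sum = sum + w;
      if (sum >= target) { found = true } else { index = index + 1 }
    } else { () }
  }
  index
}
```

For the weights 1.0 and 3.0, any target up to 1.0 selects index 0 and any target above 1.0 up to 4.0 selects index 1, so index 1 is chosen three times as often.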
Now we have everything in place to define smc by instantiating smcHandler with resampleUniform:
def smc[R](numberOfParticles: Int) { p: () => R / SMC } =
smcHandler[R](numberOfParticles) { ps => resampleUniform(ps) } { p() }
Importance Sampling
Of course the above handler is not the only one. We can define an even simpler handler that performs importance sampling by sequentially running each particle to the end.
def importance[R](n: Int) { p: () => R / SMC } = {
var measurements: List[Measurement[R]] = Nil()
n.repeat {
var currentWeight = 1.0;
try {
val result = p();
measurements = Cons(Measurement(currentWeight, result), measurements)
} with SMC {
def resample() = resume(())
def uniform() = resume(random())
def score(d) = { currentWeight = currentWeight * d; resume(()) }
}
}
measurements
}
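Handlers can also help with testing probabilistic programs. The following sketch, which builds on the SMC effect above and uses the hypothetical names scripted and Scripted, replaces randomness by a script: the first successes calls to uniform() return 0.0 and all later calls return 1.0, resample() is a no-op, and score multiplies the weight as in importance.

```effekt
// Hypothetical result type: the final weight and the program's result.
record Scripted[R](weight: Double, value: R)

// Hypothetical handler (not part of the case study): deterministic answers
// to uniform(), so a program can be checked without randomness.
def scripted[R](successes: Int) { p: () => R / SMC } = {
  var remaining = successes;
  var weight = 1.0;
  try {
    val r = p();
    Scripted(weight, r)
  } with SMC {
    // no other particles exist, so resampling does nothing
    def resample() = resume(())
    // 0.0 < p makes bernoulli(p) succeed; 1.0 < p makes it fail for p <= 1
    def uniform() = if (remaining > 0) { remaining = remaining - 1; resume(0.0) } else { resume(1.0) }
    // same multiplicative weighting as the other handlers
    def score(d) = { weight = weight * d; resume(()) }
  }
}
```

For instance, scripted(2) { biasedGeometric(0.5) } performs two successful flips, each scored with log(1.5), followed by a failure, so it yields the result 3 with weight log(1.5)·log(1.5) ≈ 0.164.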
Running the Examples
extern control def sleep(n: Int): Unit =
"$effekt.callcc(k => window.setTimeout(() => k(null), n))"
// here we set a time out to allow rerendering
extern control def reportMeasurementJS[R](w: Double, d: R): Unit =
"$effekt.callcc(k => { showPoint(w, d); window.setTimeout(() => k(null), 0)})"
// here we set a time out to allow rerendering
extern control def reportDiscreteMeasurementJS[R](w: Double, d: R): Unit =
"$effekt.callcc(k => { showPoint(w, d, { discrete: true }); window.setTimeout(() => k(null), 0)})"
extern io def setupGraphJS(): Unit =
"setup()"
To visualize the results, we define the following helper function report that adds each measurement as a data point to a graph (below), waiting interval milliseconds between points.
def report[R](interval: Int, ms: List[Measurement[R]]) = {
setupGraphJS();
ms.foreach { m =>
reportMeasurementJS(m.weight, m.data);
sleep(interval)
}
}
def reportDiscrete[R](interval: Int, ms: List[Measurement[R]]) = {
setupGraphJS();
ms.foreach { m =>
reportDiscreteMeasurementJS(m.weight, m.data);
sleep(interval)
}
}
Running SMC and importance sampling now is a matter of composing the handlers.
def runSMC(numberOfParticles: Int) =
report(20, smc(numberOfParticles) { biasedGeometric(0.5) })
def runImportance(numberOfParticles: Int) =
report(20, importance(numberOfParticles) { biasedGeometric(0.5) })
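Besides plotting, the list of measurements returned by smc or importance can also be summarized numerically. The following sketch, using the hypothetical helper probability, computes the self-normalized estimate of the probability that a result satisfies a given predicate: the total weight of the matching measurements divided by the total weight of all measurements.

```effekt
// Hypothetical helper (not part of the case study): self-normalized
// estimate of the probability that a result satisfies `event`.
def probability[R](ms: List[Measurement[R]]) { event: R => Boolean } = {
  var matching = 0.0;
  var total = 0.0;
  ms.foreach { m =>
    total = total + m.weight;
    if (event(m.data)) { matching = matching + m.weight } else { () }
  }
  matching / total
}
```

For example, probability(importance(1000) { biasedGeometric(0.5) }) { k => k == 1 } estimates the weighted probability of the result 1, which for these weights is 1 − 0.5·log(1.5) ≈ 0.80; individual runs will scatter around that value.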
We have also prepared a helper function called reportDiscrete to experiment with examples that have non-integer return types:
def runDiscrete(numberOfParticles: Int) =
reportDiscrete(0, smc(numberOfParticles) {
if (bernoulli(0.5)) { "hello" } else { "world" }
})
In the REPL below you can try the examples. Click run and then try entering runSMC(100) (then click run again).