### Example: Vehicle Localization

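Sample code implementing max-product and sum-product belief propagation on a chain-structured factor graph: every variable has one unary factor, and each pair of neighbouring variables is linked by one pairwise factor. Set `max_product` in the first cell to switch between the two algorithms.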

In [1]:
# setup model
import sys
import numpy as np
from numpy import array as arr

# general parameters
max_product = False  # True: max-product (max-marginals + MAP estimate); False: sum-product (marginals)
num_iters = 10       # number of full message-passing sweeps

# model parameters
# unary[i, s]: potential of variable i being in state s (one row per variable)
unary = np.array([[0.7, 0.1, 0.2], [0.7, 0.2, 0.1], [0.2, 0.1, 0.7], [0.7, 0.2, 0.1],
                  [0.2, 0.6, 0.2], [0.1, 0.8, 0.1], [0.4, 0.3, 0.3], [0.1, 0.8, 0.1],
                  [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])
# pairwise[s, t]: potential shared by every neighbouring pair (i, i+1);
# rows index the state of variable i, columns the state of variable i+1
pairwise = arr([[0.8, 0.2, 0.0], [0.2, 0.6, 0.2], [0.0, 0.2, 0.8]])
num, dim = unary.shape
# every variable has `dim` states, so the pairwise table must be dim x dim
assert pairwise.shape == (dim, dim), "pairwise potential must be a dim x dim matrix"
np.set_printoptions(precision=2)

# print unaries
# single-argument print(...) behaves the same under Python 2 and Python 3
print("=============== INPUT ====================")
print("unary potentials:")
print(unary)
print("pairwise potential:")
print(pairwise)

=============== INPUT ====================
unary potentials:
[[ 0.7  0.1  0.2]
[ 0.7  0.2  0.1]
[ 0.2  0.1  0.7]
[ 0.7  0.2  0.1]
[ 0.2  0.6  0.2]
[ 0.1  0.8  0.1]
[ 0.4  0.3  0.3]
[ 0.1  0.8  0.1]
[ 0.1  0.1  0.8]
[ 0.1  0.5  0.4]]
pairwise potential:
[[ 0.8  0.2  0. ]
[ 0.2  0.6  0.2]
[ 0.   0.2  0.8]]

In [2]:
# initialize variables and factors
# variables[i] = number of states of variable i (every variable has `dim` states)
variables = np.full(num, dim, dtype=np.int64)
factors = []

# unary factors: one per variable, holding that variable's row of `unary`
for i in range(num):
    factors.append({'vars': arr([i]), 'vals': arr(unary[i])})

# pairwise factors: one per neighbouring pair (i, i+1), forming a chain
for i in range(num - 1):
    factors.append({'vars': arr([i, i + 1]), 'vals': pairwise})

# init all messages to one (uniform, i.e. uninformative).
# msg_fv is keyed by (factor, variable) and msg_vf by (variable, factor).
msg_fv = {}
msg_vf = {}
# ne_var[v] lists the indices of all factors that involve variable v
ne_var = [[] for _ in range(num)]
for f_idx, f in enumerate(factors):
    for v_idx in f['vars']:
        msg_fv[(f_idx, v_idx)] = np.ones(variables[v_idx])
        msg_vf[(v_idx, f_idx)] = np.ones(variables[v_idx])
        ne_var[v_idx].append(f_idx)

In [3]:
# run inference
def _normalize(msg):
    """Rescale a message so it sums to one.

    Messages are only defined up to a constant factor, so rescaling does not
    change the normalized (max-)marginals, but it stops repeated products from
    underflowing on long chains / many iterations. A message whose total is not
    positive is returned unchanged to avoid dividing by zero.
    """
    total = np.sum(msg)
    return msg / total if total > 0 else msg


# eliminate the other variable of a pairwise factor by max (max-product) or sum (sum-product)
_reduce = np.max if max_product else np.sum

for it in range(num_iters):

    # for all factor-to-variable messages do
    # (these read only msg_vf, so the order of updates within this sweep does not matter)
    for f_idx, v_idx in list(msg_fv):
        f_vars = factors[f_idx]['vars']
        f_vals = factors[f_idx]['vals']

        if np.size(f_vars) == 1:
            # unary factor: the message is the potential itself
            msg = f_vals
        elif v_idx == f_vars[0]:
            # target = first variable (rows of f_vals): weight each column by the
            # incoming message from the second variable, then reduce over columns
            msg = _reduce(f_vals * msg_vf[(f_vars[1], f_idx)][np.newaxis, :], axis=1)
        else:
            # target = second variable (columns of f_vals): weight each row by the
            # incoming message from the first variable, then reduce over rows
            msg = _reduce(f_vals * msg_vf[(f_vars[0], f_idx)][:, np.newaxis], axis=0)
        msg_fv[(f_idx, v_idx)] = _normalize(msg)

    # for all variable-to-factor messages do
    # product of the messages from every *other* factor attached to the variable
    for v_idx, f_idx in list(msg_vf):
        msg = np.ones(variables[v_idx])
        for f_idx2 in ne_var[v_idx]:
            if f_idx2 != f_idx:
                msg = msg * msg_fv[(f_idx2, v_idx)]
        msg_vf[(v_idx, f_idx)] = _normalize(msg)

# combine all incoming factor messages per variable and normalize -> (max-)marginals
marginals = np.zeros([num, dim])
for v_idx in range(num):
    belief = np.ones(variables[v_idx])
    for f_idx in ne_var[v_idx]:
        belief = belief * msg_fv[(f_idx, v_idx)]
    marginals[v_idx] = belief / np.sum(belief)

# output marginals / map state
print("=============== OUTPUT ====================")
if max_product:
    print("max marginals:")
    print(marginals)
    print("map estimate:")
    # per-variable argmax of the max-marginals
    # NOTE: with tied maxima these may not form one consistent joint assignment
    print(np.argmax(marginals, axis=1))
else:
    print("marginals:")
    print(marginals)

=============== OUTPUT ====================
marginals:
[[ 0.9   0.05  0.04]
[ 0.89  0.07  0.04]
[ 0.85  0.09  0.06]
[ 0.77  0.2   0.03]
[ 0.28  0.7   0.02]
[ 0.09  0.88  0.03]
[ 0.15  0.67  0.18]
[ 0.02  0.8   0.18]
[ 0.04  0.22  0.74]
[ 0.03  0.36  0.61]]

In [4]:
# plot observations and marginals graphically
import matplotlib.pyplot as plt
import time

plt.close()


def _plot_bars(values, title, color):
    """Draw one horizontal bar panel per variable, side by side in a single row.

    values: 2-D array with one row per variable and one column per state.
    Each panel draws a white, black-edged frame of width 1 for every state,
    then fills each state's bar up to its value in `color`.
    Returns (figure, 1-D array of axes), mirroring the original `f, axarr`.
    """
    n_vars, n_states = values.shape
    # squeeze=False keeps a 2-D axes array even when there is a single variable
    fig, axes = plt.subplots(1, n_vars, figsize=(n_vars, 2), squeeze=False)
    fig.suptitle(title, fontsize=16, fontweight='bold')
    states = np.arange(n_states)
    # 1-D widths matching `states`; a (3, 1) array would broadcast to (3, 3) inside barh
    frame = np.ones(n_states)
    for ax, row in zip(axes[0], values):
        ax.barh(states, frame, color='white', edgecolor='black', linewidth=2)
        ax.barh(states, row, color=color)
        ax.axis('off')
    return fig, axes[0]


# plot observations
f, axarr = _plot_bars(unary, 'Observations', 'red')

# plot marginals
f, axarr = _plot_bars(marginals, 'Max-Marginals' if max_product else 'Marginals', 'green')

plt.show()

In [ ]: