org.knowceans.sandbox.gauss
Class GmmGibbsSampler

java.lang.Object
  extended by org.knowceans.sandbox.gauss.GmmGibbsSampler

public class GmmGibbsSampler
extends java.lang.Object

GmmGibbsSampler tests a simple Gibbs sampler that simulates the posterior of a Gaussian mixture model, i.e., the parameter distribution as a function of the evidence. Adapted from http://www.sph.umich.edu/csg/abecasis/class/815.23.pdf .

Sample each of the mixture parameters from its conditional distribution; Dirichlet, normal, and gamma distributions are typical choices.

A simpler alternative, used here, is to sample the origin of each observation and assign the observation to a specific component.

Author:
heinrich

Field Summary
private static int BURN_IN
          burn-in period
(package private)  double[] data
           
private static int ITERATIONS
          max iterations
private static double MIN_GROUP
          minimum members of a group
private static int THIN_INTERVAL
          sampling lag (thinning interval)
 
Constructor Summary
GmmGibbsSampler(double[] data)
          Initialise the Gibbs sampler with data.
 
Method Summary
(package private)  void addObservation(double x, int group, double[] counts, double[] sum, double[] sumsq)
          Adds an observation to a group's count, sum, and sum of squares, for updating mixture parameters
 void configure(int iterations, int burnIn, int minGroup, int thinInterval)
           
private  double dmix(double x, int k, double[] probs, double[] mean, double[] sigma)
          GMM likelihood
private  double dnorm(double x, double mu, double sigma)
          Normal likelihood
private  void gibbs(int k, double[] probs, double[] mean, double[] sigma)
          Main method: Select initial state.
(package private)  void initialState(int k, int[] group, double[] counts, double[] sum, double[] sumsq)
          Initialisation: Must start with an assignment of observations to groupings.
static void main(java.lang.String[] args)
          Driver with example data.
(package private)  void removeObservation(double x, int group, double[] counts, double[] sum, double[] sumsq)
          Removes an observation from a group's count, sum, and sum of squares, for updating mixture parameters
static double[] rmix(int n, double[] probs, double[] mean, double[] sigma)
          GMM sampling
(package private)  int sampleGroup(double x, int k, double[] probs, double[] mean, double[] sigma)
          Sampling a component, P(z_j = i | x_j, theta) = pi_i * f(x_j | theta_i) / sum_m pi_m * f(x_j | theta_m). Calculate the probability that the observation originated from each component.
(package private)  void updateEstimates(int k, int n, double[] prob, double[] mean, double[] sigma, double[] counts, double[] sum, double[] sumsq)
          Update the parameters.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

data

double[] data

MIN_GROUP

private static double MIN_GROUP
minimum members of a group


THIN_INTERVAL

private static int THIN_INTERVAL
sampling lag (thinning interval)


BURN_IN

private static int BURN_IN
burn-in period


ITERATIONS

private static int ITERATIONS
max iterations

Constructor Detail

GmmGibbsSampler

public GmmGibbsSampler(double[] data)
Initialise the Gibbs sampler with data.

Parameters:
data - observations to be modelled by the mixture
Method Detail

updateEstimates

void updateEstimates(int k,
                     int n,
                     double[] prob,
                     double[] mean,
                     double[] sigma,
                     double[] counts,
                     double[] sum,
                     double[] sumsq)
Update the parameters. Before sampling a new origin for an observation, update mixture parameters given current assignments. Could be expensive!

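Parameters:
k - number of mixture components
n - number of observations currently assigned
prob - mixture weights, updated in place
mean - component means, updated in place
sigma - component standard deviations, updated in place
counts - number of observations assigned to each component
sum - sum of the observations assigned to each component
sumsq - sum of squares of the observations assigned to each component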
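
The update described above can be sketched as follows. This is a minimal illustration, not this class's source: it assumes `counts`, `sum`, and `sumsq` hold each component's count, sum, and sum of squares, that no component is empty, and that a small variance floor (`1e-7`, an assumed value) keeps `sigma` positive. The class name `UpdateEstimatesSketch` is hypothetical.

```java
final class UpdateEstimatesSketch {
    /** Recomputes weights, means, and standard deviations from sufficient statistics. */
    static void updateEstimates(int k, int n, double[] prob, double[] mean,
            double[] sigma, double[] counts, double[] sum, double[] sumsq) {
        for (int i = 0; i < k; i++) {
            // Assumes counts[i] > 0; an empty component would yield NaN here.
            prob[i] = counts[i] / n;
            mean[i] = sum[i] / counts[i];
            double variance = sumsq[i] / counts[i] - mean[i] * mean[i];
            // Clamp rounding noise and add an assumed floor so sigma stays positive.
            sigma[i] = Math.sqrt(Math.max(variance, 0.0) + 1e-7);
        }
    }
}
```

Each call costs O(k), which is why keeping running sums is cheaper than rescanning the data for every resampled observation.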

sampleGroup

int sampleGroup(double x,
                int k,
                double[] probs,
                double[] mean,
                double[] sigma)
Sampling a component, P(z_j = i | x_j, theta) = pi_i * f(x_j | theta_i) / sum_m pi_m * f(x_j | theta_m)

Calculate the probability that the observation originated from each component. Use random number(s) to assign component membership.

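Parameters:
x - the observation
k - number of mixture components
probs - mixture weights
mean - component means
sigma - component standard deviations
Returns:
index of the sampled component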
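
One way to draw a component from these probabilities is sketched below. It is an assumption-laden illustration rather than this class's implementation: the class name `SampleGroupSketch`, the helper `dnorm`, and the extra `rng` parameter are all hypothetical additions made so the example is self-contained and reproducible.

```java
import java.util.Random;

final class SampleGroupSketch {
    /** Normal density with mean mu and standard deviation sigma. */
    static double dnorm(double x, double mu, double sigma) {
        double z = (x - mu) / sigma;
        return Math.exp(-0.5 * z * z) / (sigma * Math.sqrt(2.0 * Math.PI));
    }

    /** Draws component i with probability pi_i f(x | theta_i) / sum_m pi_m f(x | theta_m). */
    static int sampleGroup(double x, int k, double[] probs, double[] mean,
            double[] sigma, Random rng) {
        double[] weight = new double[k];
        double total = 0.0;
        for (int i = 0; i < k; i++) {
            weight[i] = probs[i] * dnorm(x, mean[i], sigma[i]);
            total += weight[i];
        }
        // Scale a uniform draw by the normaliser instead of dividing each weight.
        double u = rng.nextDouble() * total;
        for (int i = 0; i < k - 1; i++) {
            if (u < weight[i]) {
                return i;
            }
            u -= weight[i];
        }
        return k - 1; // last component takes the remaining probability mass
    }
}
```

For example, `SampleGroupSketch.sampleGroup(0.1, 2, probs, mean, sigma, new Random(1))` returns 0 or 1 according to the two posterior membership probabilities.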

dnorm

private double dnorm(double x,
                     double mu,
                     double sigma)
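Normal likelihood: the density of x under a normal distribution.

Parameters:
x - the observation
mu - mean
sigma - standard deviation
Returns:
the density value at x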
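
As a minimal sketch, the normal density can be computed directly from its closed form. The code below assumes `sigma` is a standard deviation (consistent with the `rmix` parameter description) and is strictly positive; the class name `DnormSketch` is hypothetical.

```java
final class DnormSketch {
    /** Density of N(mu, sigma^2) at x: exp(-z^2 / 2) / (sigma * sqrt(2 * pi)). */
    static double dnorm(double x, double mu, double sigma) {
        double z = (x - mu) / sigma; // standardised distance from the mean
        return Math.exp(-0.5 * z * z) / (sigma * Math.sqrt(2.0 * Math.PI));
    }
}
```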

dmix

private double dmix(double x,
                    int k,
                    double[] probs,
                    double[] mean,
                    double[] sigma)
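GMM likelihood: the density of x under the k-component mixture.

Parameters:
x - the observation
k - number of mixture components
probs - mixture weights
mean - component means
sigma - component standard deviations
Returns:
the mixture density value at x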
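
The mixture likelihood is the weighted sum of the component densities. The sketch below illustrates this under the assumption that `probs` sums to one and each `sigma[i]` is a positive standard deviation; the class name `DmixSketch` and its local `dnorm` helper are hypothetical.

```java
final class DmixSketch {
    static double dnorm(double x, double mu, double sigma) {
        double z = (x - mu) / sigma;
        return Math.exp(-0.5 * z * z) / (sigma * Math.sqrt(2.0 * Math.PI));
    }

    /** Mixture density: sum over components of probs[i] * N(x; mean[i], sigma[i]^2). */
    static double dmix(double x, int k, double[] probs, double[] mean, double[] sigma) {
        double density = 0.0;
        for (int i = 0; i < k; i++) {
            density += probs[i] * dnorm(x, mean[i], sigma[i]);
        }
        return density;
    }
}
```

This value is also the normalising constant in the component-membership probability used by `sampleGroup`.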

removeObservation

void removeObservation(double x,
                       int group,
                       double[] counts,
                       double[] sum,
                       double[] sumsq)
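For updating mixture parameters: removes an observation from its group's count, sum, and sum of squares.

Parameters:
x - the observation to remove
group - index of the component the observation is currently assigned to
counts - number of observations assigned to each component
sum - sum of the observations assigned to each component
sumsq - sum of squares of the observations assigned to each component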

addObservation

void addObservation(double x,
                    int group,
                    double[] counts,
                    double[] sum,
                    double[] sumsq)
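For updating mixture parameters: adds an observation to its group's count, sum, and sum of squares.

Parameters:
x - the observation to add
group - index of the component the observation is assigned to
counts - number of observations assigned to each component
sum - sum of the observations assigned to each component
sumsq - sum of squares of the observations assigned to each component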
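
Adding and removing an observation can be done in constant time by maintaining running statistics per component. The following is a sketch under the assumption that `counts`, `sum`, and `sumsq` are exactly those statistics; the class name `SufficientStatsSketch` is hypothetical.

```java
final class SufficientStatsSketch {
    /** Adds x to the running count, sum, and sum of squares of the given group. */
    static void addObservation(double x, int group, double[] counts,
            double[] sum, double[] sumsq) {
        counts[group]++;
        sum[group] += x;
        sumsq[group] += x * x;
    }

    /** Reverses addObservation for an x previously added to the given group. */
    static void removeObservation(double x, int group, double[] counts,
            double[] sum, double[] sumsq) {
        counts[group]--;
        sum[group] -= x;
        sumsq[group] -= x * x;
    }
}
```

Because the two operations are inverses, moving an observation between components is a remove from the old group followed by an add to the new one.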

initialState

void initialState(int k,
                  int[] group,
                  double[] counts,
                  double[] sum,
                  double[] sumsq)
Initialisation: Must start with an assignment of observations to groupings. Many alternatives are possible; I chose to perform random assignments with equal probabilities.

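Parameters:
k - number of mixture components
group - component assignment of each observation
counts - number of observations assigned to each component
sum - sum of the observations assigned to each component
sumsq - sum of squares of the observations assigned to each component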
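
A random initial assignment with equal probabilities might look like the sketch below. It differs from the signature above in two assumed ways: `data` and `rng` are passed explicitly so the example is self-contained (the class itself holds `data` as a field). The class name `InitialStateSketch` is hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

final class InitialStateSketch {
    /** Assigns each observation to one of k groups uniformly at random. */
    static void initialState(double[] data, int k, int[] group, double[] counts,
            double[] sum, double[] sumsq, Random rng) {
        // Reset the running statistics before accumulating.
        Arrays.fill(counts, 0.0);
        Arrays.fill(sum, 0.0);
        Arrays.fill(sumsq, 0.0);
        for (int j = 0; j < data.length; j++) {
            int g = rng.nextInt(k); // each group has probability 1/k
            group[j] = g;
            counts[g]++;
            sum[g] += data[j];
            sumsq[g] += data[j] * data[j];
        }
    }
}
```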

gibbs

private void gibbs(int k,
                   double[] probs,
                   double[] mean,
                   double[] sigma)
Main method: Select initial state. Repeat a large number of times: 1. Select an element. 2. Update conditional on other elements. If appropriate, output summary for each run.

Parameters:
k - number of mixture components
probs - mixture weights
mean - component means
sigma - component standard deviations
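
The loop described above can be sketched end to end as follows. This is an illustrative reading of the steps, not this class's source: the signature (explicit `data`, `rng`, and configuration arguments), the returned averages, the `1e-7` variance floor, and the class name `GibbsLoopSketch` are all assumptions. The sketch also assumes the data set is large enough that no component starts empty.

```java
import java.util.Random;

final class GibbsLoopSketch {
    static double dnorm(double x, double mu, double sigma) {
        double z = (x - mu) / sigma;
        return Math.exp(-0.5 * z * z) / (sigma * Math.sqrt(2.0 * Math.PI));
    }

    /** Returns {weights, means, standard deviations} averaged over retained samples. */
    static double[][] gibbs(double[] data, int k, int iterations, int burnIn,
            int thinInterval, double minGroup, Random rng) {
        int n = data.length;
        int[] group = new int[n];
        double[] counts = new double[k], sum = new double[k], sumsq = new double[k];
        double[] probs = new double[k], mean = new double[k], sigma = new double[k];
        double[] weight = new double[k];
        double[][] avg = new double[3][k];
        int kept = 0;

        // Initial state: random assignment with equal probabilities.
        for (int j = 0; j < n; j++) {
            int g = group[j] = rng.nextInt(k);
            counts[g]++;
            sum[g] += data[j];
            sumsq[g] += data[j] * data[j];
        }

        for (int it = 0; it < iterations; it++) {
            // 1. Select an element.
            int j = rng.nextInt(n);
            int g = group[j];
            if (counts[g] <= minGroup) {
                continue; // never shrink a group below minGroup members
            }

            // 2. Update conditional on the others: hold observation j out,
            //    re-estimate the parameters, then resample its component.
            counts[g]--;
            sum[g] -= data[j];
            sumsq[g] -= data[j] * data[j];
            double total = 0.0;
            for (int i = 0; i < k; i++) {
                probs[i] = counts[i] / (n - 1);
                mean[i] = sum[i] / counts[i];
                double variance = sumsq[i] / counts[i] - mean[i] * mean[i];
                sigma[i] = Math.sqrt(Math.max(variance, 0.0) + 1e-7);
                weight[i] = probs[i] * dnorm(data[j], mean[i], sigma[i]);
                total += weight[i];
            }
            double u = rng.nextDouble() * total;
            g = 0;
            while (g < k - 1 && u >= weight[g]) {
                u -= weight[g++];
            }
            group[j] = g;
            counts[g]++;
            sum[g] += data[j];
            sumsq[g] += data[j] * data[j];

            // After burn-in, keep every thinInterval-th iteration; the retained
            // estimates are those computed with observation j held out.
            if (it >= burnIn && it % thinInterval == 0) {
                for (int i = 0; i < k; i++) {
                    avg[0][i] += probs[i];
                    avg[1][i] += mean[i];
                    avg[2][i] += sigma[i];
                }
                kept++;
            }
        }
        // Assumes at least one sample was retained (kept > 0).
        for (double[] row : avg) {
            for (int i = 0; i < k; i++) {
                row[i] /= kept;
            }
        }
        return avg;
    }
}
```

Discarding the burn-in iterations reduces the influence of the random starting assignment, and thinning reduces the correlation between successive retained samples, since each iteration changes at most one assignment.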

rmix

public static double[] rmix(int n,
                            double[] probs,
                            double[] mean,
                            double[] sigma)
GMM sampling: draws n values from the mixture.

Parameters:
n - number of values to draw
probs - mixture weights
mean - mean vector
sigma - stddev vector
Returns:
the n sampled values
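
Sampling from a Gaussian mixture can be done in two steps per draw: pick a component according to the weights, then draw from that component's normal distribution. The sketch below assumes `probs` sums to one and adds a hypothetical `rng` parameter for reproducibility; the class name `RmixSketch` is also hypothetical.

```java
import java.util.Random;

final class RmixSketch {
    /** Draws n values from the mixture defined by probs, mean, and sigma. */
    static double[] rmix(int n, double[] probs, double[] mean, double[] sigma,
            Random rng) {
        double[] x = new double[n];
        for (int j = 0; j < n; j++) {
            // Pick a component by walking the cumulative weights.
            double u = rng.nextDouble();
            int i = 0;
            while (i < probs.length - 1 && u >= probs[i]) {
                u -= probs[i];
                i++;
            }
            // Draw from the chosen component: mean + stddev * standard normal.
            x[j] = mean[i] + sigma[i] * rng.nextGaussian();
        }
        return x;
    }
}
```

Data generated this way with known parameters is a convenient check for the sampler, since the estimates can be compared against the values used to generate it.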

configure

public void configure(int iterations,
                      int burnIn,
                      int minGroup,
                      int thinInterval)

main

public static void main(java.lang.String[] args)
Driver with example data.

Parameters:
args -