org.knowceans.dirichlet.dmm
Class DmmGibbsSampler

java.lang.Object
  extended by org.knowceans.dirichlet.dmm.DmmGibbsSampler

public class DmmGibbsSampler
extends java.lang.Object

DmmGibbsSampler implements a simple Gibbs sampler that simulates the posterior of a Dirichlet mixture model. It uses a fixed-point iteration to estimate the component parameters.

In each step, we sample the originating component of each observation, assign the observation to that component, and re-estimate the component statistics.

TODO: increase estimation accuracy, e.g., by using Newton-Raphson instead of the fixed-point estimator for component precisions.

TODO: fix NaN and Inf behaviour for very small data sets.

Author:
Gregor Heinrich

Field Summary
private  int[] component
          Component assignment of each data item.
private  int[] counts
           
(package private)  double[][] data
           
private  boolean fullPrecEst
          If true, use full estimation (fixed-point iteration) of the precisions; if false, use a moment-based guess.
private  int iterations
          Maximum number of iterations.
private  int K
          Dimensionality of the data
private  int L
          Number of components.
private  double[][] means
          Componentwise sum of all data samples, from which the means are calculated. dim = components x data dimension.
private  double[][] meansq
          Componentwise sum of the squares of all data samples. dim = components x data dimension.
private  double minComponent
          Minimum number of members of a component.
private  int N
          Length of the data
private  double[] precs
          Component precisions
private  double[] probs
          Component probabilities
private  double[][] suffstats
          Componentwise sufficient statistics for alpha * N. dim = components x data dimension.
 
Constructor Summary
DmmGibbsSampler(double[][] data)
          Initialise the Gibbs sampler with data.
 
Method Summary
(package private)  void addObservation(double[] x, int comp)
          Adds observation x to component comp, for updating mixture parameters.
 void configure(int iterations, int minComponent, boolean fullPrecision)
          Configure the Gibbs sampler.
 int[] getComponent()
          Get the data-component associations.
 double[][] getMeans()
          Calculates the means from the component sums and returns them.
 double[] getPrecs()
          Get the component precisions
 double[] getProbs()
          Get the component weights.
private  void gibbs(int L)
          Main method: select an initial state, then repeatedly resample each element conditional on the others.
(package private)  void initialState(int L)
          Initialisation: random assignments with equal probabilities.
static void main(java.lang.String[] args)
          Driver with example data.
(package private)  void removeObservation(double[] x, int comp)
          Removes observation x from component comp, for updating mixture parameters.
(package private)  int sampleComponent(double[] x)
          Sampling a component, P(z_j=i | x_j, theta) = pi_i * f(x_j | theta_i) / sum_m pi_m * f(x_j | theta_m). Calculates the probability that the observation originated from each component.
(package private)  void updateParams()
          Update the parameters.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

data

double[][] data

counts

private int[] counts

means

private double[][] means
Componentwise sum of all data samples, from which the means are calculated. dim = components x data dimension.


meansq

private double[][] meansq
Componentwise sum of the squares of all data samples. dim = components x data dimension.


suffstats

private double[][] suffstats
Componentwise sufficient statistics for alpha * N. dim = components x data dimension.


component

private int[] component
Component assignment of each data item.


probs

private double[] probs
Component probabilities


precs

private double[] precs
Component precisions


L

private int L
Number of components.


N

private int N
Length of the data


K

private int K
Dimensionality of the data


fullPrecEst

private boolean fullPrecEst
If true, use full estimation (fixed-point iteration) of the precisions; if false, use a moment-based guess. Can be set to false first and then set to true to fine-adjust the results.
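
One possible form of a moment-based guess is sketched below. It relies on the standard Dirichlet identity Var[x_k] = m_k * (1 - m_k) / (s + 1), where m is the mean vector and s the precision, and combines the per-dimension estimates by a geometric mean. The class and method names are hypothetical, and the source does not state which moment estimator DmmGibbsSampler actually uses.

```java
final class MomentPrecisionSketch {

    /**
     * Estimates the Dirichlet precision s from per-dimension means and
     * means of squares (both assumed already divided by the sample count).
     */
    static double estimatePrecision(double[] mean, double[] meanSq) {
        double logSum = 0.0;
        int used = 0;
        for (int k = 0; k < mean.length; k++) {
            double var = meanSq[k] - mean[k] * mean[k];
            if (var <= 0.0) {
                continue; // degenerate dimension: would give Inf or NaN
            }
            // Solve Var[x_k] = m_k (1 - m_k) / (s + 1) for s.
            double s = mean[k] * (1.0 - mean[k]) / var - 1.0;
            if (s > 0.0) {
                logSum += Math.log(s);
                used++;
            }
        }
        if (used == 0) {
            throw new IllegalArgumentException("no dimension with usable variance");
        }
        // Geometric mean of the per-dimension estimates (an assumption of this sketch).
        return Math.exp(logSum / used);
    }
}
```

Dimensions with zero variance are skipped here, which is one way to avoid the NaN and Inf values that small components can otherwise produce.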


minComponent

private double minComponent
Minimum number of members of a component.


iterations

private int iterations
Maximum number of iterations.

Constructor Detail

DmmGibbsSampler

public DmmGibbsSampler(double[][] data)
Initialise the Gibbs sampler with data.

Parameters:
data -
Method Detail

main

public static void main(java.lang.String[] args)
Driver with example data.

Parameters:
args -

updateParams

void updateParams()
Update the parameters. Before sampling a new origin for an observation, update mixture parameters given current assignments. Could be expensive!


sampleComponent

int sampleComponent(double[] x)
Sampling a component, P(z_j=i | x_j, theta) = pi_i * f(x_j | theta_i) / sum_m pi_m * f(x_j | theta_m)

Calculate the probability that the observation originated from each component. Use random number(s) to assign component membership.

Parameters:
x - data item
Returns:
the index of the sampled component
addObservation

void addObservation(double[] x,
                    int comp)
Adds observation x to component comp, for updating mixture parameters.

Parameters:
x - the observation to add
comp - the component to add it to

removeObservation

void removeObservation(double[] x,
                       int comp)
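Removes observation x from component comp, for updating mixture parameters.

Parameters:
x - the observation to remove
comp - the component to remove it from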
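
The bookkeeping behind addObservation and removeObservation can be sketched as incremental updates of per-component counts, sums, and sums of squares. The class ComponentStatsSketch and its field names are hypothetical. The real class also maintains suffstats, which is not modelled here.

```java
final class ComponentStatsSketch {
    final int[] counts;     // observations per component
    final double[][] sums;  // componentwise sum of observations
    final double[][] sumSq; // componentwise sum of squared observations

    ComponentStatsSketch(int components, int dim) {
        counts = new int[components];
        sums = new double[components][dim];
        sumSq = new double[components][dim];
    }

    /** Adds observation x to component comp. */
    void add(double[] x, int comp) {
        update(x, comp, +1);
    }

    /** Removes observation x from component comp (x must have been added before). */
    void remove(double[] x, int comp) {
        update(x, comp, -1);
    }

    private void update(double[] x, int comp, int sign) {
        counts[comp] += sign;
        for (int k = 0; k < x.length; k++) {
            sums[comp][k] += sign * x[k];
            sumSq[comp][k] += sign * x[k] * x[k];
        }
    }
}
```

Keeping running sums makes each reassignment cost proportional to the data dimension instead of requiring a pass over all observations.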

initialState

void initialState(int L)
Initialisation: random assignments with equal probabilities.

Parameters:
L -

gibbs

private void gibbs(int L)
Main method. Select an initial state, then repeat a large number of times: (1) select an element; (2) update it conditional on the other elements. If appropriate, output a summary for each run.

Parameters:
L - number of components
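
The procedure described for gibbs can be sketched generically as below. The names GibbsLoopSketch, Conditional, and run are hypothetical, and the conditional distribution is supplied by the caller. In the real class this role would presumably be played by sampleComponent, with removeObservation and addObservation doing the bookkeeping that the counts array stands in for here.

```java
import java.util.Random;

final class GibbsLoopSketch {

    /** Draws a new component for element j given the current state of all other elements. */
    interface Conditional {
        int sample(int j, int[] assignment, int[] counts, Random rng);
    }

    /**
     * Runs the given number of sweeps over n elements with numComponents components
     * and returns the final assignment.
     */
    static int[] run(int n, int numComponents, int iterations, Random rng, Conditional cond) {
        int[] z = new int[n];
        int[] counts = new int[numComponents];

        // Initial state: random assignments with equal probabilities.
        for (int j = 0; j < n; j++) {
            z[j] = rng.nextInt(numComponents);
            counts[z[j]]++;
        }

        for (int it = 0; it < iterations; it++) {
            for (int j = 0; j < n; j++) {
                counts[z[j]]--;                        // take element j out of its component
                z[j] = cond.sample(j, z, counts, rng); // resample given all other elements
                counts[z[j]]++;                        // put it into the sampled component
            }
        }
        return z;
    }
}
```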

configure

public void configure(int iterations,
                      int minComponent,
                      boolean fullPrecision)
Configure the Gibbs sampler.

Parameters:
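iterations - maximum number of iterations
minComponent - minimum number of members of a component
fullPrecision - true for full estimation of the precisions (fixed-point iteration), false for a moment-based guess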

getComponent

public final int[] getComponent()
Get the data-component associations.

Returns:
the component assignment of each data item

getPrecs

public final double[] getPrecs()
Get the component precisions.

Returns:
the component precisions

getProbs

public final double[] getProbs()
Get the component weights.

Returns:
the component weights

getMeans

public final double[][] getMeans()
Calculates the means from the component sums and returns them.

Returns:
the component means, dim = components x data dimension
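
Deriving means from component sums can be sketched as a division by the component counts. The class MeansSketch and the method meansFromSums are hypothetical. Leaving the rows of empty components at zero is an assumption of this sketch, chosen to avoid 0/0 producing NaN, and is not necessarily what getMeans does.

```java
final class MeansSketch {

    /** Returns means[i][k] = sums[i][k] / counts[i]; rows of empty components stay zero. */
    static double[][] meansFromSums(double[][] sums, int[] counts) {
        double[][] means = new double[sums.length][];
        for (int i = 0; i < sums.length; i++) {
            means[i] = new double[sums[i].length];
            if (counts[i] == 0) {
                continue; // empty component: avoid 0/0 = NaN
            }
            for (int k = 0; k < sums[i].length; k++) {
                means[i][k] = sums[i][k] / counts[i];
            }
        }
        return means;
    }
}
```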