org.knowceans.util
Class DirichletEstimation

java.lang.Object
  extended by org.knowceans.util.DirichletEstimation

public class DirichletEstimation
extends java.lang.Object

DirichletEstimation provides a number of methods to estimate the parameters of a Dirichlet distribution and of the Dirichlet-multinomial (Polya) distribution. Most of the algorithms are described in Minka (2003), "Estimating a Dirichlet distribution"; some are home-grown extensions.
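For reference, these are the two densities the estimators fit, in standard notation that matches the nmk / nm parameter names used below. The first is the Dirichlet density of a probability vector. The second is the Polya likelihood of one count row m, with the multinomial coefficient omitted because it does not depend on alpha:

```latex
\begin{aligned}
p(\mathbf{p} \mid \boldsymbol\alpha) &= \frac{\Gamma(\alpha_0)}{\prod_k \Gamma(\alpha_k)} \prod_k p_k^{\alpha_k - 1}, \\
p(\mathbf{n}_m \mid \boldsymbol\alpha) &= \frac{\Gamma(\alpha_0)}{\Gamma(n_m + \alpha_0)} \prod_k \frac{\Gamma(n_{mk} + \alpha_k)}{\Gamma(\alpha_k)},
\qquad \alpha_0 = \sum_k \alpha_k, \quad n_m = \sum_k n_{mk}.
\end{aligned}
```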

Author:
gregor

Constructor Summary
DirichletEstimation()
           
 
Method Summary
static int alphaFixedPoint(double[] suffstats, double[] alpha)
          Fixed-point iteration on alpha.
static void alphaNewton(int N, double[] suffstats, double[] alpha)
           
static double[] estimateAlpha(double[][] pp)
          Estimator for the Dirichlet parameters.
static double[] estimateAlpha(int[][] nmk)
          Estimator for the Dirichlet parameters from counts.
static int estimateAlphaLoo(double[] alpha, int[][] nmk)
          Polya estimation using the fixed-point iteration on the leave-one-out likelihood, after Minka (2003).
static double[] estimateAlphaMap(int[][] nmk, int[] nm, double[] alpha, double a, double b)
          Fixed-point iteration on alpha using counts as input, estimating the Polya distribution directly.
static double estimateAlphaMap(int[][] nmk, int[] nm, double alpha, double a, double b)
          Fixed-point iteration on alpha using counts as input, estimating the Polya distribution directly.
static double[][] estimateAlphaMapSub(int[][] nmk, int[] nm, int[] m2j, double[][] alphajk, double a, double b)
          Estimate several alphas, each based on a subset of rows.
static double estimateAlphaMapSub(int[][] nmk, int[] nm, int[] mrows, double alpha, double a, double b)
          Fixed-point iteration on alpha using counts as input, estimating the Polya distribution directly from a subset of rows.
static double estimateAlphaMomentMatch(int[][] nmk, int[] nm)
           
static double[] estimateMeanPrec(double[][] pp)
          Estimate mean and precision of the observations separately.
static double[] estimateMeanPrec(int[][] nn)
          Estimate mean and precision of the observations separately from counts.
static double[] getAlpha(double[] meanPrec)
          Get the alpha vector out of a combined mean-precision vector.
static double[] getMean(double[] meanPrec)
          Get the mean out of a combined mean-precision vector.
static double getPrec(double[] meanPrec)
          Get the precision out of a combined mean-precision vector.
static double[] guessAlpha(double[][] pp, double[] pmean)
          Estimate the Dirichlet parameters using the method of moments.
static double[] guessAlpha(int[][] nmk)
          Guess alpha via a Dirichlet parameter point estimate and Dirichlet moment matching.
static double[] guessAlphaDirect(int[][] nmk, int[] nm)
          Guess alpha via "direct" moment matching on the Polya distribution (which is just Dirichlet moment matching in disguise).
static double[] guessMean(double[][] pp)
          Estimate the Dirichlet mean of the data along columns.
static double guessPrecision(double[] pmean, double[] pmeansq)
          Estimate the Dirichlet precision using the moment-matching method.
static void main(java.lang.String[] args)
           
static void testDirichlet()
           
static void testPolya()
           
 
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

DirichletEstimation

public DirichletEstimation()
Method Detail

estimateAlpha

public static double[] estimateAlpha(double[][] pp)
Estimator for the Dirichlet parameters.

Parameters:
pp - multinomial parameters p, one observation per row
Returns:
ML estimate of the corresponding parameters alpha
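The Dirichlet log-likelihood depends on the data only through the per-dimension mean log-probabilities, so an ML estimator typically starts by computing them. A minimal sketch, assuming every entry of pp is strictly positive; the class DirichletSuffStats and its method meanLog are hypothetical and not part of this class:

```java
import java.util.Arrays;

/** Hypothetical helper, not part of org.knowceans.util. */
final class DirichletSuffStats {

    /**
     * Mean log-probability per dimension: result[k] = (1/N) * sum_m log(pp[m][k]).
     * Assumes pp has N >= 1 rows of equal length K and all entries are > 0.
     */
    static double[] meanLog(double[][] pp) {
        int n = pp.length;
        int dims = pp[0].length;
        double[] logMean = new double[dims];
        for (double[] row : pp) {
            for (int k = 0; k < dims; k++) {
                logMean[k] += Math.log(row[k]);
            }
        }
        for (int k = 0; k < dims; k++) {
            logMean[k] /= n;
        }
        return logMean;
    }

    public static void main(String[] args) {
        double[][] pp = {{0.5, 0.5}, {0.25, 0.75}};
        System.out.println(Arrays.toString(meanLog(pp)));
    }
}
```

Any zero probability makes the corresponding statistic minus infinity, which is why real data are often smoothed before this step.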

estimateMeanPrec

public static double[] estimateMeanPrec(double[][] pp)
Estimate mean and precision of the observations separately.

Parameters:
pp - input data with vectors in rows
Returns:
a vector holding the Dirichlet mean in elements 0..K-2 and the Dirichlet precision in element K-1, where K = pp[0].length is the dimensionality of the data. The last mean component is not stored; it equals 1 minus the sum of the others.

estimateAlpha

public static double[] estimateAlpha(int[][] nmk)
Estimator for the Dirichlet parameters from counts. This corresponds to the estimation of the Polya distribution but is done via Dirichlet parameter estimation.

Parameters:
nmk - counts in each multinomial experiment, one experiment per row
Returns:
ML estimate of the corresponding parameters alpha

estimateAlphaLoo

public static int estimateAlphaLoo(double[] alpha,
                                   int[][] nmk)
Polya estimation using the fixed-point iteration on the leave-one-out likelihood, after Minka (2003). TODO: Gibbs sampler for the Polya distribution.

Parameters:
alpha - [in/out] Dirichlet parameter with element for each k
nmk - count matrix with individual observations in rows ("documents") and categories in columns ("topics")
Returns:
number of iterations
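The leave-one-out fixed point from Minka (2003) updates each component as α_k ← α_k · Σ_m [n_mk / (n_mk − 1 + α_k)] / Σ_m [n_m / (n_m − 1 + α_0)], where n_m is the row sum and α_0 = Σ_k α_k. Unlike the full-likelihood updates, it needs no digamma evaluations. A minimal sketch, assuming at least one non-empty row; the class PolyaLoo and its maxIter and tol parameters are hypothetical, and this is not necessarily how estimateAlphaLoo is implemented:

```java
import java.util.Arrays;

/** Hypothetical sketch of the leave-one-out fixed point for the Polya distribution. */
final class PolyaLoo {

    /** One fixed-point update; returns the new alpha without modifying the input. */
    static double[] update(double[] alpha, int[][] nmk) {
        int dims = alpha.length;
        double alpha0 = 0;
        for (double a : alpha) {
            alpha0 += a;
        }
        // Denominator: sum over rows of n_m / (n_m - 1 + alpha0); empty rows contribute nothing.
        double den = 0;
        for (int[] row : nmk) {
            int nm = 0;
            for (int c : row) {
                nm += c;
            }
            if (nm > 0) {
                den += nm / (nm - 1 + alpha0);
            }
        }
        double[] next = new double[dims];
        for (int k = 0; k < dims; k++) {
            // Numerator: sum over rows of n_mk / (n_mk - 1 + alpha_k); zero counts contribute nothing.
            double num = 0;
            for (int[] row : nmk) {
                if (row[k] > 0) {
                    num += row[k] / (row[k] - 1 + alpha[k]);
                }
            }
            next[k] = alpha[k] * num / den;
        }
        return next;
    }

    /** Iterates alpha in place until the largest change is below tol; returns the iteration count. */
    static int estimate(double[] alpha, int[][] nmk, int maxIter, double tol) {
        for (int iter = 1; iter <= maxIter; iter++) {
            double[] next = update(alpha, nmk);
            double change = 0;
            for (int k = 0; k < alpha.length; k++) {
                change = Math.max(change, Math.abs(next[k] - alpha[k]));
            }
            System.arraycopy(next, 0, alpha, 0, alpha.length);
            if (change < tol) {
                return iter;
            }
        }
        return maxIter;
    }

    public static void main(String[] args) {
        int[][] nmk = {{3, 1, 0}, {2, 2, 1}, {0, 4, 1}, {1, 1, 3}, {5, 0, 0}};
        double[] alpha = {1, 1, 1};
        int iters = estimate(alpha, nmk, 100000, 1e-10);
        System.out.println(iters + " iterations: " + Arrays.toString(alpha));
    }
}
```

If the rows look like plain multinomial draws (no overdispersion), the optimal precision is infinite and the iteration keeps growing alpha until maxIter is reached.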

estimateMeanPrec

public static double[] estimateMeanPrec(int[][] nn)
Estimate mean and precision of the observations separately from counts. This corresponds to estimation of the Polya distribution.

Parameters:
nn - count data, one observation (count vector) per row
Returns:
a vector holding the Dirichlet mean in elements 0..K-2 and the Dirichlet precision in element K-1, where K = nn[0].length is the dimensionality of the data. The last mean component is not stored; it equals 1 minus the sum of the others.

getPrec

public static double getPrec(double[] meanPrec)
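Get the precision out of a combined mean-precision vector.

Parameters:
meanPrec - combined mean-precision vector, as returned by estimateMeanPrec
Returns:
the Dirichlet precision (the last element of meanPrec)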

getMean

public static double[] getMean(double[] meanPrec)
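Get the mean out of a combined mean-precision vector. The mean vector is copied.

Parameters:
meanPrec - combined mean-precision vector, as returned by estimateMeanPrec
Returns:
a copy of the Dirichlet mean vector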

getAlpha

public static double[] getAlpha(double[] meanPrec)
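Get the alpha vector out of a combined mean-precision vector. The vector is copied.

Parameters:
meanPrec - combined mean-precision vector, as returned by estimateMeanPrec
Returns:
the Dirichlet parameters alpha, i.e. the mean scaled by the precision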
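The three accessors can be sketched as follows, assuming the layout documented for estimateMeanPrec: elements 0..K-2 hold the mean, element K-1 holds the precision s, the last mean component is implied as 1 minus the sum of the others, and α_k = s · m_k. The class MeanPrecLayout is hypothetical, and whether getMean returns K or K-1 components is an assumption (this sketch returns all K):

```java
import java.util.Arrays;

/** Hypothetical sketch of the mean-precision layout; not the library's code. */
final class MeanPrecLayout {

    /** Precision s, stored in the last element. */
    static double prec(double[] meanPrec) {
        return meanPrec[meanPrec.length - 1];
    }

    /** Full K-dimensional mean: K-1 stored components plus the implied last one. */
    static double[] mean(double[] meanPrec) {
        int dims = meanPrec.length;
        double[] mean = new double[dims];
        double sum = 0;
        for (int k = 0; k < dims - 1; k++) {
            mean[k] = meanPrec[k];
            sum += meanPrec[k];
        }
        mean[dims - 1] = 1 - sum;
        return mean;
    }

    /** alpha_k = s * m_k; works on a fresh copy, so meanPrec is unchanged. */
    static double[] alpha(double[] meanPrec) {
        double s = prec(meanPrec);
        double[] alpha = mean(meanPrec);
        for (int k = 0; k < alpha.length; k++) {
            alpha[k] *= s;
        }
        return alpha;
    }

    public static void main(String[] args) {
        double[] meanPrec = {0.2, 0.3, 10.0};
        System.out.println(Arrays.toString(alpha(meanPrec)));
    }
}
```

Storing only K-1 mean components keeps the vector the same length as alpha while encoding the sum-to-one constraint implicitly.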

guessAlpha

public static double[] guessAlpha(double[][] pp,
                                  double[] pmean)
Estimate the Dirichlet parameters using the method of moments.

Parameters:
pp - data with items in rows and dimensions in columns
pmean - first moment (column means) of pp
Returns:
moment-matching estimate of alpha
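Moment matching sets α_k = s · m_k, where m_k is the column mean and the precision s comes from the second moments. A minimal sketch; the class DirichletMoments is hypothetical, and both averaging the per-dimension precision estimates and the fallback value used when no dimension varies are assumptions, not necessarily what guessAlpha does:

```java
import java.util.Arrays;

/** Hypothetical sketch of Dirichlet moment matching; not the library's implementation. */
final class DirichletMoments {

    /**
     * alpha_k = s * pmean[k], where s is obtained from the first two moments of each
     * dimension and averaged over dimensions (the averaging is an assumption).
     */
    static double[] guessAlpha(double[][] pp, double[] pmean) {
        int n = pp.length;
        int dims = pmean.length;
        double[] meanSq = new double[dims];
        for (double[] row : pp) {
            for (int k = 0; k < dims; k++) {
                meanSq[k] += row[k] * row[k] / n;
            }
        }
        double precSum = 0;
        int used = 0;
        for (int k = 0; k < dims; k++) {
            double variance = meanSq[k] - pmean[k] * pmean[k];
            if (variance > 0) {
                // From Var[p_k] = m_k (1 - m_k) / (s + 1).
                precSum += (pmean[k] - meanSq[k]) / variance;
                used++;
            }
        }
        // Fallback precision of 1 when no dimension varies (arbitrary choice).
        double prec = used > 0 ? precSum / used : 1.0;
        double[] alpha = new double[dims];
        for (int k = 0; k < dims; k++) {
            alpha[k] = prec * pmean[k];
        }
        return alpha;
    }

    public static void main(String[] args) {
        double[][] pp = {{0.2, 0.8}, {0.4, 0.6}, {0.6, 0.4}, {0.8, 0.2}};
        System.out.println(Arrays.toString(guessAlpha(pp, new double[] {0.5, 0.5})));
    }
}
```

For this data each column has mean 0.5 and mean square 0.3, so s = (0.5 − 0.3) / (0.3 − 0.25) = 4 and both alphas are about 2. Such guesses are mainly useful as starting points for the iterative ML estimators.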

guessMean

public static double[] guessMean(double[][] pp)
Estimate the Dirichlet mean of the data along columns.

Parameters:
pp - data with items in rows and dimensions in columns
Returns:
the column means of pp

guessAlpha

public static double[] guessAlpha(int[][] nmk)
Guess alpha via a Dirichlet parameter point estimate and Dirichlet moment matching.

Parameters:
nmk - count matrix with observations in rows and categories in columns
Returns:
the guessed alpha vector

guessAlphaDirect

public static double[] guessAlphaDirect(int[][] nmk,
                                        int[] nm)
Guess alpha via "direct" moment matching on the Polya distribution (which is just Dirichlet moment matching in disguise). After Minka (2003), Eq. (19) ff.

Parameters:
nmk - count matrix with observations in rows and categories in columns
nm - sums of observations over all categories in each row (e.g., document lengths, ndsum)
Returns:
the guessed alpha vector

guessPrecision

public static double guessPrecision(double[] pmean,
                                    double[] pmeansq)
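Estimate the Dirichlet precision using the moment-matching method.

Parameters:
pmean - first moment (mean) of the data in each dimension
pmeansq - second moment (mean of the squares) of the data in each dimension
Returns:
the estimated Dirichlet precision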
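The moment-matching precision follows directly from the Dirichlet moments. Writing m_k for pmean[k], v_k for pmeansq[k] (read here as the mean of the squared data) and s = Σ_k α_k:

```latex
\begin{aligned}
\mathbb{E}[p_k] &= m_k = \frac{\alpha_k}{s}, \qquad
\operatorname{Var}[p_k] = \frac{m_k (1 - m_k)}{s + 1}, \\
v_k - m_k^2 &= \frac{m_k (1 - m_k)}{s + 1}
\quad\Longrightarrow\quad
s = \frac{m_k - v_k}{v_k - m_k^2}.
\end{aligned}
```

Each dimension k yields its own estimate of s, so an implementation has to pick one dimension or combine the estimates.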

alphaNewton

public static void alphaNewton(int N,
                               double[] suffstats,
                               double[] alpha)
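Minka (2003) shows that the Hessian of the Dirichlet log-likelihood is a diagonal matrix plus a constant rank-one matrix, so a full Newton step costs only O(K). With g_k = N(ψ(α_0) − ψ(α_k)) + suffstats[k], q_k = −N ψ'(α_k), z = N ψ'(α_0) and b = Σ_j (g_j / q_j) / (1/z + Σ_j 1/q_j), the step is α_k ← α_k − (g_k − b) / q_k. The sketch below assumes suffstats[k] holds the sum over all N observations of log p_mk (the Javadoc does not say whether sums or means are expected). DirichletNewton, its series approximations of digamma and trigamma, and the step-halving safeguard that keeps alpha positive are hypothetical additions:

```java
import java.util.Arrays;

/** Hypothetical sketch of Newton iteration for Dirichlet ML estimation. */
final class DirichletNewton {

    /** Digamma via recurrence up to x >= 6, then the asymptotic series. Assumes x > 0. */
    static double digamma(double x) {
        double result = 0;
        while (x < 6) {
            result -= 1 / x;
            x += 1;
        }
        double f = 1 / (x * x);
        return result + Math.log(x) - 0.5 / x
                - f * (1.0 / 12 - f * (1.0 / 120 - f * (1.0 / 252 - f * (1.0 / 240 - f / 132))));
    }

    /** Trigamma via recurrence up to x >= 6, then the asymptotic series. Assumes x > 0. */
    static double trigamma(double x) {
        double result = 0;
        while (x < 6) {
            result += 1 / (x * x);
            x += 1;
        }
        double f = 1 / (x * x);
        return result + 1 / x + f / 2 + f / x * (1.0 / 6 - f * (1.0 / 30 - f * (1.0 / 42 - f / 30)));
    }

    /** Newton iterations on alpha (in place); suffstats[k] = sum over n observations of log p_mk. */
    static void newton(int n, double[] suffstats, double[] alpha, int maxIter, double tol) {
        int dims = alpha.length;
        double[] g = new double[dims];
        double[] q = new double[dims];
        double[] next = new double[dims];
        for (int iter = 0; iter < maxIter; iter++) {
            double alpha0 = 0;
            for (double a : alpha) {
                alpha0 += a;
            }
            double psi0 = digamma(alpha0);
            double z = n * trigamma(alpha0);
            double sumGOverQ = 0;
            double sumInvQ = 0;
            for (int k = 0; k < dims; k++) {
                g[k] = n * (psi0 - digamma(alpha[k])) + suffstats[k]; // gradient
                q[k] = -n * trigamma(alpha[k]);                       // Hessian diagonal
                sumGOverQ += g[k] / q[k];
                sumInvQ += 1 / q[k];
            }
            double b = sumGOverQ / (1 / z + sumInvQ);
            // Full Newton step, halved until every component stays positive.
            double step = 1;
            boolean positive;
            do {
                positive = true;
                for (int k = 0; k < dims; k++) {
                    next[k] = alpha[k] - step * (g[k] - b) / q[k];
                    positive &= next[k] > 0;
                }
                step /= 2;
            } while (!positive);
            double change = 0;
            for (int k = 0; k < dims; k++) {
                change = Math.max(change, Math.abs(next[k] - alpha[k]));
                alpha[k] = next[k];
            }
            if (change < tol) {
                return;
            }
        }
    }

    public static void main(String[] args) {
        double[][] pp = {{0.2, 0.3, 0.5}, {0.1, 0.6, 0.3}, {0.4, 0.4, 0.2}, {0.3, 0.3, 0.4}};
        double[] suffstats = new double[3];
        for (double[] row : pp) {
            for (int k = 0; k < 3; k++) {
                suffstats[k] += Math.log(row[k]);
            }
        }
        double[] alpha = {1, 1, 1};
        newton(pp.length, suffstats, alpha, 100, 1e-12);
        System.out.println(Arrays.toString(alpha));
    }
}
```

Newton usually needs far fewer iterations than the fixed point, but each step is only safe with a reasonable start, which is what the moment-matching guesses are for.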

alphaFixedPoint

public static int alphaFixedPoint(double[] suffstats,
                                  double[] alpha)
Fixed-point iteration on alpha.

Parameters:
suffstats - sufficient statistics of the data, one entry per dimension
alpha - [in/out]
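Minka's (2003) fixed point for Dirichlet ML estimation sets α_k ← ψ⁻¹(ψ(α_0) + log p̄_k), where log p̄_k is the mean log-probability in dimension k. Inverting the digamma function uses Minka's initialization followed by five Newton steps. A minimal sketch, assuming suffstats holds the mean log-probabilities (the Javadoc does not say); DirichletFixedPoint and its helpers are hypothetical:

```java
import java.util.Arrays;

/** Hypothetical sketch of the Dirichlet ML fixed point; not the library's implementation. */
final class DirichletFixedPoint {
    static final double EULER_GAMMA = 0.5772156649015329;

    /** Digamma via recurrence up to x >= 6, then the asymptotic series. Assumes x > 0. */
    static double digamma(double x) {
        double result = 0;
        while (x < 6) {
            result -= 1 / x;
            x += 1;
        }
        double f = 1 / (x * x);
        return result + Math.log(x) - 0.5 / x
                - f * (1.0 / 12 - f * (1.0 / 120 - f * (1.0 / 252 - f * (1.0 / 240 - f / 132))));
    }

    /** Trigamma via recurrence up to x >= 6, then the asymptotic series. Assumes x > 0. */
    static double trigamma(double x) {
        double result = 0;
        while (x < 6) {
            result += 1 / (x * x);
            x += 1;
        }
        double f = 1 / (x * x);
        return result + 1 / x + f / 2 + f / x * (1.0 / 6 - f * (1.0 / 30 - f * (1.0 / 42 - f / 30)));
    }

    /** Inverse digamma: Minka's starting point, then five Newton steps. */
    static double invDigamma(double y) {
        double x = y >= -2.22 ? Math.exp(y) + 0.5 : -1 / (y + EULER_GAMMA);
        for (int i = 0; i < 5; i++) {
            x -= (digamma(x) - y) / trigamma(x);
        }
        return x;
    }

    /** alpha_k <- invDigamma(digamma(alpha0) + logMean[k]), in place; returns the iteration count. */
    static int fixedPoint(double[] logMean, double[] alpha, int maxIter, double tol) {
        for (int iter = 1; iter <= maxIter; iter++) {
            double alpha0 = 0;
            for (double a : alpha) {
                alpha0 += a;
            }
            double psi0 = digamma(alpha0); // fixed for the whole sweep
            double change = 0;
            for (int k = 0; k < alpha.length; k++) {
                double next = invDigamma(psi0 + logMean[k]);
                change = Math.max(change, Math.abs(next - alpha[k]));
                alpha[k] = next;
            }
            if (change < tol) {
                return iter;
            }
        }
        return maxIter;
    }

    public static void main(String[] args) {
        double[][] pp = {{0.2, 0.3, 0.5}, {0.1, 0.6, 0.3}, {0.4, 0.4, 0.2}, {0.3, 0.3, 0.4}};
        double[] logMean = new double[3];
        for (double[] row : pp) {
            for (int k = 0; k < 3; k++) {
                logMean[k] += Math.log(row[k]) / pp.length;
            }
        }
        double[] alpha = {1, 1, 1};
        int iters = fixedPoint(logMean, alpha, 100000, 1e-10);
        System.out.println(iters + " iterations: " + Arrays.toString(alpha));
    }
}
```

The fixed point never leaves the positive orthant and converges reliably, at the cost of more iterations than a Newton step when the precision is large.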

estimateAlphaMomentMatch

public static double estimateAlphaMomentMatch(int[][] nmk,
                                              int[] nm)

estimateAlphaMap

public static double estimateAlphaMap(int[][] nmk,
                                      int[] nm,
                                      double alpha,
                                      double a,
                                      double b)
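Fixed-point iteration on alpha using counts as input, estimating the Polya distribution directly (Eq. 55 in Minka, 2003).

Parameters:
nmk - count data (documents in rows, topic associations in cols)
nm - row sums of nmk (total count of each document)
alpha - initial value of the scalar alpha
a - first hyperparameter of the prior on alpha
b - second hyperparameter of the prior on alpha
Returns:
the estimated alpha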
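For a symmetric alpha and an assumed Gamma prior with shape a >= 1 and rate b >= 0 (the Javadoc does not name the prior), a common form of the MAP fixed point is α ← (a − 1 + α Σ_m Σ_k [ψ(n_mk + α) − ψ(α)]) / (b + K Σ_m [ψ(n_m + Kα) − ψ(Kα)]). With a = 1 and b = 0 it reduces to the ML Polya fixed point. For integer counts, ψ(x + n) − ψ(x) = Σ_{j=0}^{n−1} 1/(x + j), so no digamma routine is needed. This is a sketch, not necessarily what estimateAlphaMap does internally; SymmetricAlphaMap, maxIter and tol are hypothetical:

```java
/** Hypothetical sketch of a MAP fixed point for a symmetric Dirichlet-multinomial alpha. */
final class SymmetricAlphaMap {

    /** digamma(x + n) - digamma(x) for integer n >= 0, computed exactly as a finite sum. */
    static double digammaDiff(double x, int n) {
        double sum = 0;
        for (int j = 0; j < n; j++) {
            sum += 1 / (x + j);
        }
        return sum;
    }

    /** One update under an assumed Gamma(shape a, rate b) prior; requires a >= 1, b >= 0. */
    static double update(int[][] nmk, int[] nm, double alpha, double a, double b) {
        int dims = nmk[0].length;
        double num = 0;
        double den = 0;
        for (int m = 0; m < nmk.length; m++) {
            for (int k = 0; k < dims; k++) {
                num += digammaDiff(alpha, nmk[m][k]);
            }
            den += dims * digammaDiff(dims * alpha, nm[m]);
        }
        return (a - 1 + alpha * num) / (b + den);
    }

    /** Iterates from the given initial alpha until the change is below tol. */
    static double estimate(int[][] nmk, int[] nm, double alpha, double a, double b,
                           int maxIter, double tol) {
        for (int iter = 0; iter < maxIter; iter++) {
            double next = update(nmk, nm, alpha, a, b);
            if (Math.abs(next - alpha) < tol) {
                return next;
            }
            alpha = next;
        }
        return alpha;
    }

    public static void main(String[] args) {
        int[][] nmk = {{3, 1, 0}, {2, 2, 1}, {0, 4, 1}, {1, 1, 3}, {5, 0, 0}};
        int[] nm = {4, 5, 5, 5, 5};
        System.out.println(estimate(nmk, nm, 1.0, 2.0, 1.0, 100000, 1e-12));
    }
}
```

The finite sum costs O(n) per count, which is cheap for typical document lengths; for very large counts a digamma approximation would be faster.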

estimateAlphaMapSub

public static double estimateAlphaMapSub(int[][] nmk,
                                         int[] nm,
                                         int[] mrows,
                                         double alpha,
                                         double a,
                                         double b)
Fixed-point iteration on alpha using counts as input, estimating the Polya distribution directly (Eq. 55 in Minka, 2003). This version uses only the rows of nmk indexed by mrows.

Parameters:
nmk - count data (documents in rows, topic associations in cols)
nm - row sums of nmk (total count of each document)
mrows - set of rows to be used for estimation
alpha - initial value of the scalar alpha
a - first hyperparameter of the prior on alpha
b - second hyperparameter of the prior on alpha
Returns:
the estimated alpha

estimateAlphaMapSub

public static double[][] estimateAlphaMapSub(int[][] nmk,
                                             int[] nm,
                                             int[] m2j,
                                             double[][] alphajk,
                                             double a,
                                             double b)
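Estimate several alphas, each based on a subset of rows.

Parameters:
nmk - full count array
nm - full array of row sums
m2j - index j of the set that each row m belongs to (max j = alphajk.length - 1)
alphajk - set-wise hyperparameters, one alpha vector per set j
a - first hyperparameter of the prior on alpha
b - second hyperparameter of the prior on alpha
Returns:
the estimated set-wise alphas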

estimateAlphaMap

public static double[] estimateAlphaMap(int[][] nmk,
                                        int[] nm,
                                        double[] alpha,
                                        double a,
                                        double b)
Fixed-point iteration on alpha using counts as input, estimating the Polya distribution directly (Eq. 55 in Minka, 2003).

Parameters:
nmk - count data (documents in rows, topic associations in cols)
nm - row sums of nmk (total count of each document)
alpha - [in/out]
a - first hyperparameter of the prior on alpha
b - second hyperparameter of the prior on alpha
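The asymmetric counterpart updates every component from the current alpha: α_k ← (a − 1 + α_k Σ_m [ψ(n_mk + α_k) − ψ(α_k)]) / (b + Σ_m [ψ(n_m + α_0) − ψ(α_0)]), again assuming a Gamma prior with shape a >= 1 and rate b >= 0 on each α_k and using the exact finite sum for the digamma differences. A minimal sketch; AsymmetricAlphaMap, maxIter and tol are hypothetical, and this is not necessarily the library's update:

```java
import java.util.Arrays;

/** Hypothetical sketch of a MAP fixed point for an asymmetric Dirichlet-multinomial alpha. */
final class AsymmetricAlphaMap {

    /** digamma(x + n) - digamma(x) for integer n >= 0, computed exactly as a finite sum. */
    static double digammaDiff(double x, int n) {
        double sum = 0;
        for (int j = 0; j < n; j++) {
            sum += 1 / (x + j);
        }
        return sum;
    }

    /** One simultaneous update of all components under an assumed Gamma(a, b) prior. */
    static double[] update(int[][] nmk, int[] nm, double[] alpha, double a, double b) {
        double alpha0 = 0;
        for (double v : alpha) {
            alpha0 += v;
        }
        // Shared denominator: b + sum_m [digamma(n_m + alpha0) - digamma(alpha0)].
        double den = b;
        for (int n : nm) {
            den += digammaDiff(alpha0, n);
        }
        double[] next = new double[alpha.length];
        for (int k = 0; k < alpha.length; k++) {
            double num = 0;
            for (int[] row : nmk) {
                num += digammaDiff(alpha[k], row[k]);
            }
            next[k] = (a - 1 + alpha[k] * num) / den;
        }
        return next;
    }

    /** Iterates alpha in place until the largest change is below tol; returns the iteration count. */
    static int estimate(int[][] nmk, int[] nm, double[] alpha, double a, double b,
                        int maxIter, double tol) {
        for (int iter = 1; iter <= maxIter; iter++) {
            double[] next = update(nmk, nm, alpha, a, b);
            double change = 0;
            for (int k = 0; k < alpha.length; k++) {
                change = Math.max(change, Math.abs(next[k] - alpha[k]));
            }
            System.arraycopy(next, 0, alpha, 0, alpha.length);
            if (change < tol) {
                return iter;
            }
        }
        return maxIter;
    }

    public static void main(String[] args) {
        int[][] nmk = {{3, 1, 0}, {2, 2, 1}, {0, 4, 1}, {1, 1, 3}, {5, 0, 0}};
        int[] nm = {4, 5, 5, 5, 5};
        double[] alpha = {1, 1, 1};
        int iters = estimate(nmk, nm, alpha, 2.0, 1.0, 100000, 1e-12);
        System.out.println(iters + " iterations: " + Arrays.toString(alpha));
    }
}
```

With a > 1 every update stays strictly positive; with a = 1 a category that never occurs is driven to zero.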

main

public static void main(java.lang.String[] args)
                 throws java.lang.Exception
Throws:
java.lang.Exception

testPolya

public static void testPolya()

testDirichlet

public static void testDirichlet()