Java Code for Feature Selection
July 6th, 2012

Feature selection is the technique of selecting a subset of relevant features for building robust learning models, as described in a Wikipedia article. We need to go through this process basically whenever we apply a machine learning model to a task such as classification or regression. There are various approaches to feature selection in general, but mutual information, chi-square, and information gain are common metrics for figuring out how effectively an individual feature characterizes a particular class in natural language processing tasks such as text classification.
Just for future reference, I will put my Java code for calculating those scores below.
/**
 * Holds the four cells of a 2x2 contingency table that relates the
 * occurrence of a feature to membership in a gold standard class.
 *
 * <pre>
 *                            Feature occurrence
 *                            yes     no
 * Gold standard class  yes   n11     n10
 *                      no    n01     n00
 * </pre>
 *
 * The total {@code n} is n11 + n10 + n01 + n00. It is not kept in sync
 * automatically: it is refreshed by {@link #calculateSum()}, which
 * {@link #getN()} invokes before returning.
 *
 * @author Jun Araki
 */
public class FeatureOccurrenceCounter {

    /** Count for class = yes, feature = yes. */
    protected double n11;

    /** Count for class = yes, feature = no. */
    protected double n10;

    /** Count for class = no, feature = yes. */
    protected double n01;

    /** Count for class = no, feature = no. */
    protected double n00;

    /** Sum of the four cells as of the last call to {@link #calculateSum()}. */
    protected double n;

    /**
     * Creates a counter whose four cells all start at zero.
     */
    public FeatureOccurrenceCounter() {
        this(0.0, 0.0, 0.0, 0.0);
    }

    /**
     * Creates a counter preloaded with the given cell counts.
     *
     * @param n11 count for class = yes, feature = yes
     * @param n10 count for class = yes, feature = no
     * @param n01 count for class = no, feature = yes
     * @param n00 count for class = no, feature = no
     */
    public FeatureOccurrenceCounter(double n11, double n10, double n01, double n00) {
        this.n11 = n11;
        this.n10 = n10;
        this.n01 = n01;
        this.n00 = n00;
    }

    /**
     * Recomputes {@code n} from the current values of the four cells.
     */
    public void calculateSum() {
        // Summation order is fixed so the floating-point result is reproducible.
        n = n11 + n10 + n01 + n00;
    }

    /**
     * Refreshes and returns the total count.
     *
     * @return n11 + n10 + n01 + n00
     */
    public double getN() {
        calculateSum();
        return n;
    }

    // ----- n11 -----

    public double getN11() {
        return n11;
    }

    public void setN11(double n11) {
        this.n11 = n11;
    }

    public void incrementN11() {
        n11 += 1.0;
    }

    // ----- n10 -----

    public double getN10() {
        return n10;
    }

    public void setN10(double n10) {
        this.n10 = n10;
    }

    public void incrementN10() {
        n10 += 1.0;
    }

    // ----- n01 -----

    public double getN01() {
        return n01;
    }

    public void setN01(double n01) {
        this.n01 = n01;
    }

    public void incrementN01() {
        n01 += 1.0;
    }

    // ----- n00 -----

    public double getN00() {
        return n00;
    }

    public void setN00(double n00) {
        this.n00 = n00;
    }

    public void incrementN00() {
        n00 += 1.0;
    }
}
/**
 * This class gives the following popular metrics for feature selection.
 * <ul>
 * <li>Mutual Information (MI)
 * <li>Chi-square
 * <li>Information Gain (IG)
 * </ul>
 *
 * In order to calculate the scores above, it needs to first count the number
 * of feature occurrences. Specifically, it assumes the following confusion
 * matrix for feature selection.
 *
 * <pre>
 *                            Feature occurrence
 *                            yes     no
 * Gold standard class  yes   n11     n10
 *                      no    n01     n00
 * </pre>
 *
 * The variable n is the sum of n11, n10, n01, and n00. Every score is
 * recomputed from the current counts each time its getter is called, and a
 * getter returns {@code null} when the score is treated as undefined for the
 * current counts (see the individual getters). For more information on
 * feature selection, see:
 *
 * Christopher D. Manning, Prabhakar Raghavan, and Hinrich Sch&uuml;tze. 2008.
 * Introduction to Information Retrieval. Cambridge University Press.
 *
 * George Forman, Isabelle Guyon, and Andr&eacute; Elisseeff. 2003. An Extensive
 * Empirical Study of Feature Selection Metrics for Text Classification.
 * Journal of Machine Learning Research, 3:1289-1305.
 *
 * @author Jun Araki
 */
public class FeatureSelectionMetrics extends FeatureOccurrenceCounter {

    /** Mutual information score (in bits); null when undefined. */
    private Double mi;

    /** Chi-square score; null when undefined. */
    private Double chiSquare;

    /** Information gain score (in nats); null when undefined. */
    private Double ig;

    /**
     * Constructor. All four counts start at zero.
     */
    public FeatureSelectionMetrics() {
        super();
    }

    /**
     * Constructor taking respective counts.
     *
     * @param n11 count for class = yes, feature = yes
     * @param n10 count for class = yes, feature = no
     * @param n01 count for class = no, feature = yes
     * @param n00 count for class = no, feature = no
     */
    public FeatureSelectionMetrics(double n11, double n10, double n01, double n00) {
        super(n11, n10, n01, n00);
    }

    /**
     * Calculates and returns the mutual information score.
     *
     * @return the mutual information score in bits (base-2 logarithm), or
     *         {@code null} if any of the four counts is zero
     */
    public Double getMI() {
        calculateMutualInformation();
        return mi;
    }

    /**
     * Calculates and returns the chi-square score.
     *
     * @return the chi-square score, or {@code null} if any row or column
     *         total of the confusion matrix is zero
     */
    public Double getChiSquare() {
        calculateChiSquare();
        return chiSquare;
    }

    /**
     * Calculates and returns the information gain score.
     *
     * @return the information gain score in nats (natural logarithm), or
     *         {@code null} if any of the four counts is zero
     */
    public Double getIG() {
        calculateInformationGain();
        return ig;
    }

    /**
     * Calculates mutual information given the counts from n11 to n00. For more
     * information, see (Manning et al., 2008).
     */
    private void calculateMutualInformation() {
        if (n11 == 0 || n10 == 0 || n01 == 0 || n00 == 0) {
            // Boundary cases: a zero cell would require evaluating log(0).
            mi = null;
            return;
        }

        calculateSum();

        // Marginal totals: n1x/n0x are the class rows, nx1/nx0 the feature columns.
        double n1x = n10 + n11;
        double n0x = n00 + n01;
        double nx1 = n01 + n11;
        double nx0 = n00 + n10;

        // Sum over the four cells of P(c,f) * log2( P(c,f) / (P(c) * P(f)) ).
        mi = (n11 / n) * log2((n * n11) / (n1x * nx1))
                + (n01 / n) * log2((n * n01) / (n0x * nx1))
                + (n10 / n) * log2((n * n10) / (n1x * nx0))
                + (n00 / n) * log2((n * n00) / (n0x * nx0));
    }

    /**
     * Calculates the chi-square score given the counts from n11 to n00. In
     * order to know statistical significance of the score, you can refer to
     * the following relationship between the p value and chi-square score
     * (Manning et al., 2008).
     *
     * <pre>
     * p value    chi-square
     * 0.1         2.71
     * 0.05        3.84
     * 0.01        6.63
     * 0.005       7.88
     * 0.001      10.83
     * </pre>
     */
    private void calculateChiSquare() {
        if (n11 + n01 == 0 || n11 + n10 == 0 || n10 + n00 == 0 || n01 + n00 == 0) {
            // Boundary cases: a zero row or column total makes the denominator zero.
            chiSquare = null;
            return;
        }

        calculateSum();

        // An arithmetically simpler way of computing chi-square:
        // n * (n11*n00 - n10*n01)^2 / (product of the four marginal totals).
        double crossDifference = n11 * n00 - n10 * n01;
        chiSquare = (n * crossDifference * crossDifference)
                / ((n11 + n01) * (n11 + n10) * (n10 + n00) * (n01 + n00));
    }

    /**
     * Calculates the information gain score given the counts from n11 to n00.
     * For more information, see (Forman et al., 2003).
     *
     * Each cell contributes P(c,f) * ln( P(c,f) / (P(c) * P(f)) ). The joint
     * probability is compared against the product of the <em>marginal</em>
     * probabilities, so the total is non-negative. The result is the same
     * quantity as the mutual information score expressed in nats, i.e.
     * MI (bits) multiplied by ln 2.
     */
    private void calculateInformationGain() {
        if (n11 == 0 || n10 == 0 || n01 == 0 || n00 == 0) {
            // Boundary cases: a zero cell would require evaluating log(0).
            ig = null;
            return;
        }

        calculateSum();

        // Marginal totals: n1x/n0x are the class rows, nx1/nx0 the feature columns.
        double n1x = n10 + n11;
        double n0x = n00 + n01;
        double nx1 = n01 + n11;
        double nx0 = n00 + n10;

        // (nij / n) / ((row / n) * (col / n)) simplifies to (n * nij) / (row * col).
        ig = (n11 / n) * Math.log((n * n11) / (n1x * nx1))
                + (n10 / n) * Math.log((n * n10) / (n1x * nx0))
                + (n01 / n) * Math.log((n * n01) / (n0x * nx1))
                + (n00 / n) * Math.log((n * n00) / (n0x * nx0));
    }

    /**
     * Returns the base-2 logarithm of the given value.
     *
     * @param value the value whose logarithm is taken
     * @return log2(value)
     */
    private static double log2(double value) {
        return Math.log(value) / Math.log(2);
    }

    /**
     * A simple tester with a couple of examples.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        FeatureSelectionMetrics fsm1 = new FeatureSelectionMetrics(49, 141, 27652, 774106);
        Double mi1 = fsm1.getMI();
        Double chiSquare1 = fsm1.getChiSquare();
        Double ig1 = fsm1.getIG();

        FeatureSelectionMetrics fsm2 = new FeatureSelectionMetrics(0, 4, 0, 164);
        Double mi2 = fsm2.getMI();
        Double chiSquare2 = fsm2.getChiSquare();
        Double ig2 = fsm2.getIG();

        System.out.println("mi1: " + mi1); // Should be approximately 0.0001105
        System.out.println("chiSquare1: " + chiSquare1); // Should be approximately 284
        System.out.println("ig1: " + ig1); // Should be approximately 0.0000766 (= mi1 * ln 2)

        // The scores below should be undefined (null) due to boundary cases.
        System.out.println("mi2: " + mi2);
        System.out.println("chiSquare2: " + chiSquare2);
        System.out.println("ig2: " + ig2);
    }
}














