java实现adaboost算法
AdaBoost算法的主要原理是:训练若干个弱分类器,根据训练结果赋予它们不同的权值,最后再将这些弱分类器组合起来,形成一个强分类器。AdaBoost的基本原理在 http://wenku.baidu.com/view/49478920aaea998fcc220e98.html 中已经有很详细的描述。
这里使用上一篇博客中的感知器算法作为弱分类器,代码如下:
首先是adaboost算法的结果类
/**
 * Result of an AdaBoost training run: the collection of weak classifiers
 * together with the weight assigned to each of them.
 *
 * <p>Both properties are {@code null} until they are set explicitly.
 *
 * @author zhenhua.chen
 */
public class AdboostResult {

    /** One weight vector per weak classifier. */
    private ArrayList<ArrayList<Double>> weakClassifierSet;

    /** Weight of each weak classifier, parallel to {@link #weakClassifierSet}. */
    private ArrayList<Double> classifierWeightSet;

    /**
     * @return the stored weak classifiers, or {@code null} if none were set
     */
    public ArrayList<ArrayList<Double>> getWeakClassifierSet() {
        return this.weakClassifierSet;
    }

    /**
     * @param weakClassifierSet the weak classifiers to store (kept by reference)
     */
    public void setWeakClassifierSet(ArrayList<ArrayList<Double>> weakClassifierSet) {
        this.weakClassifierSet = weakClassifierSet;
    }

    /**
     * @return the stored classifier weights, or {@code null} if none were set
     */
    public ArrayList<Double> getClassifierWeightSet() {
        return this.classifierWeightSet;
    }

    /**
     * @param classifierWeightSet the classifier weights to store (kept by reference)
     */
    public void setClassifierWeightSet(ArrayList<Double> classifierWeightSet) {
        this.classifierWeightSet = classifierWeightSet;
    }
}
AdaBoost算法:
/** * http://wenku.baidu.com/view/49478920aaea998fcc220e98.html * @author zhenhua.chen * @Description: TODO * @date 2013-3-8 下午3:09:36 * */public class AdaboostAlgorithm {private static final int T = 30; // 迭代次数PerceptronApproach pa = new PerceptronApproach(); // 弱分类器/** * * @Title: adaboostClassify * @Description: 通过训练集计算出组合分类器* @return AdboostResult* @throws */public AdboostResult adaboostClassify(ArrayList<ArrayList<Double>> dataSet) {AdboostResult res = new AdboostResult();int dataDimension;if(null != dataSet && dataSet.size() > 0) {dataDimension = dataSet.get(0).size();} else {return null;}// 为每条数据的权重赋初值ArrayList<Double> dataWeightSet = new ArrayList<Double>();for(int i = 0; i < dataSet.size(); i ++) {dataWeightSet.add(1.0 / (double)dataSet.size());}// 存储每个弱分类器的权重ArrayList<Double> classifierWeightSet = new ArrayList<Double>();// 存储每个弱分类器ArrayList<ArrayList<Double>> weakClassifierSet = new ArrayList<ArrayList<Double>>();for(int i = 0; i < T; i++) {// 计算弱分类器ArrayList<Double> sensorWeightVector = pa.getWeightVector(dataSet, dataWeightSet);weakClassifierSet.add(sensorWeightVector);// 计算弱分类器误差double error = 0; //分类数int rightClassifyNum = 0;ArrayList<Double> cllassifyResult = new ArrayList<Double>();for(int j = 0; j < dataSet.size(); j++) { double result = 0;for(int k = 0; k < dataDimension - 1; k++) {result += dataSet.get(j).get(k) * sensorWeightVector.get(k);}result += sensorWeightVector.get(dataDimension - 1);if(result < 0) { // 说明预测错误error += dataWeightSet.get(j);cllassifyResult.add(-1d);} else{ cllassifyResult.add(1d);rightClassifyNum++;}}System.out.println("总数:" + dataSet.size() + "正确预测数" + rightClassifyNum);if(dataSet.size() == rightClassifyNum) {classifierWeightSet.clear();weakClassifierSet.clear();classifierWeightSet.add(1.0);weakClassifierSet.add(sensorWeightVector);break;}// 更新数据集中每条数据的权重并归一化double dataWeightSum = 0;for(int j = 0; j < dataSet.size(); j++) {dataWeightSet.set(j, dataWeightSet.get(j) * Math.pow(Math.E, (-1) * 0.5 * Math.log((1 - error) / 
error) * cllassifyResult.get(j))); // 按照http://wenku.baidu.com/view/49478920aaea998fcc220e98.html,更新的权重少除一个常数dataWeightSum += dataWeightSet.get(j);}for(int j = 0; j < dataSet.size(); j++) {dataWeightSet.set(j, dataWeightSet.get(j) / dataWeightSum);}// 计算次弱分类器的权重double currentWeight = (0.5 * Math.log((1 - error) / error));classifierWeightSet.add(currentWeight);System.out.println("classifier weight: " + currentWeight);}res.setClassifierWeightSet(classifierWeightSet);res.setWeakClassifierSet(weakClassifierSet);return res;}/** * * @Title: computeResult * @Description: 计算输入数据的类别* @return double* @throws */public int computeResult(ArrayList<Double> data, AdboostResult classifier) {double result = 0;int dataSize = data.size();ArrayList<ArrayList<Double>> weakClassifierSet = classifier.getWeakClassifierSet();ArrayList<Double> classifierWeightSet = classifier.getClassifierWeightSet();for(int i = 0; i < weakClassifierSet.size(); i++) {for(int j = 0; j < dataSize; j++) {result += weakClassifierSet.get(i).get(j) * data.get(j) * classifierWeightSet.get(i);}result += weakClassifierSet.get(i).get(dataSize);}if(result > 0) {return 1;} else {return -1;}}
测试类:
public static void main(String[] args) {/** * 测试数据,产生两类随机数据一类位于圆内,另一类位于包含小圆的大圆内,成环状 * 小圆半径为1,大圆半径为2,公共圆心位于(2, 2)内 */final int SMALL_CIRCLE_NUM = 24;final int RING_NUM = 34;ArrayList<ArrayList<Double>> dataSet = new ArrayList<ArrayList<Double>>();// 产生小圆数据for(int i = 0; i < SMALL_CIRCLE_NUM; i++) {double x = 1 + Math.random() * 2; // 1到3的随机数double y = 1 + Math.random() * 2; // 1到3的随机数if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内ArrayList<Double> smallCircle = new ArrayList<Double>();smallCircle.add(x);smallCircle.add(y);smallCircle.add(1d); // 列别1dataSet.add(smallCircle);}}// 产生外围环形数据for(int i = 0; i < RING_NUM; i++) {double x1 = Math.random() * 4;double y1 = Math.random() * 4;if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) { //说明位于环形区域内ArrayList<Double> ring = new ArrayList<Double>();ring.add(-x1);ring.add(-y1);ring.add(-1d); // 列别2dataSet.add(ring);}}AdaboostAlgorithm algo = new AdaboostAlgorithm();AdboostResult result = algo.adaboostClassify(dataSet);// 产生测试数据for(int i = 0; i < 10; i++) {ArrayList<Double> testData = new ArrayList<Double>();double x1 = Math.random() * 4;double y1 = Math.random() * 4;if((x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 4 < 0 && (x1 - 2) * (x1 - 2) + (y1 - 2) * (y1 - 2) - 1 > 0) {testData.add(x1);testData.add(y1);}//double x = 1 + Math.random() * 2; // 1到3的随机数//double y = 1 + Math.random() * 2; // 1到3的随机数//if((x - 2) * (x - 2) + (y - 2) * (y - 2) - 1 <= 0) { //说明位于圆内//testData.add(x);//testData.add(y);//}algo.computeResult(testData, result);System.out.println(algo.computeResult(testData, result));}}
?