K最近邻(KNN)算法原理和Java实现
原理部分:
请参考:KNN演算法
代码实现：

KNN结点类，用于存储k个最近邻元组的相关信息
/**
 * Holds the information about one candidate neighbour that is kept while searching for the k
 * nearest training tuples.
 */
public class KNNNode {

    /** Position (index) of the training tuple in the training data set. */
    private int tupleIndex;

    /** Distance between the training tuple and the test tuple. */
    private double distanceToTest;

    /** Class label of the training tuple. */
    private String label;

    /**
     * Creates a node describing one candidate neighbour.
     *
     * @param index position of the training tuple in the training data set
     * @param distance distance between the training tuple and the test tuple
     * @param c class label of the training tuple
     */
    public KNNNode(int index, double distance, String c) {
        tupleIndex = index;
        distanceToTest = distance;
        label = c;
    }

    /** @return position of the training tuple in the training data set */
    public int getIndex() {
        return tupleIndex;
    }

    /** @param index new position of the training tuple */
    public void setIndex(int index) {
        tupleIndex = index;
    }

    /** @return distance between the training tuple and the test tuple */
    public double getDistance() {
        return distanceToTest;
    }

    /** @param distance new distance to the test tuple */
    public void setDistance(double distance) {
        distanceToTest = distance;
    }

    /** @return class label of the training tuple */
    public String getC() {
        return label;
    }

    /** @param c new class label */
    public void setC(String c) {
        label = c;
    }
}
KNN算法主体类
/** * KNN算法主体类 */public class KNN {/** * 设置优先级队列的比较函数,距离越大,优先级越高 */private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {public int compare(KNNNode o1, KNNNode o2) {if (o1.getDistance() >= o2.getDistance()) {return 1;} else {return 0;}}};/** * 获取K个不同的随机数 * @param k 随机数的个数 * @param max 随机数最大的范围 * @return 生成的随机数数组 */public List<Integer> getRandKNum(int k, int max) {List<Integer> rand = new ArrayList<Integer>(k);for (int i = 0; i < k; i++) {int temp = (int) (Math.random() * max);if (!rand.contains(temp)) {rand.add(temp);} else {i--;}}return rand;}/** * 计算测试元组与训练元组之前的距离 * @param d1 测试元组 * @param d2 训练元组 * @return 距离值 */public double calDistance(List<Double> d1, List<Double> d2) {double distance = 0.00;for (int i = 0; i < d1.size(); i++) {distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));}return distance;}/** * 执行KNN算法,获取测试元组的类别 * @param datas 训练数据集 * @param testData 测试元组 * @param k 设定的K值 * @return 测试元组的类别 */public String knn(List<List<Double>> datas, List<Double> testData, int k) {PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);List<Integer> randNum = getRandKNum(k, datas.size());for (int i = 0; i < k; i++) {int index = randNum.get(i);List<Double> currData = datas.get(index);String c = currData.get(currData.size() - 1).toString();KNNNode node = new KNNNode(index, calDistance(testData, currData), c);pq.add(node);}for (int i = 0; i < datas.size(); i++) {List<Double> t = datas.get(i);double distance = calDistance(testData, t);KNNNode top = pq.peek();if (top.getDistance() > distance) {pq.remove();pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));}}return getMostClass(pq);}/** * 获取所得到的k个最近邻元组的多数类 * @param pq 存储k个最近近邻元组的优先级队列 * @return 多数类的名称 */private String getMostClass(PriorityQueue<KNNNode> pq) {Map<String, Integer> classCount = new HashMap<String, Integer>();for (int i = 0; i < pq.size(); i++) {KNNNode node = pq.remove();String c = node.getC();if (classCount.containsKey(c)) {classCount.put(c, 
classCount.get(c) + 1);} else {classCount.put(c, 1);}}int maxIndex = -1;int maxCount = 0;Object[] classes = classCount.keySet().toArray();for (int i = 0; i < classes.length; i++) {if (classCount.get(classes[i]) > maxCount) {maxIndex = i;maxCount = classCount.get(classes[i]);}}return classes[maxIndex].toString();}}?
KNN算法测试类
/** * KNN算法测试类 */public class TestKNN {/** * 从数据文件中读取数据 * @param datas 存储数据的集合对象 * @param path 数据文件的路径 */public void read(List<List<Double>> datas, String path){try {BufferedReader br = new BufferedReader(new FileReader(new File(path)));String data = br.readLine();List<Double> l = null;while (data != null) {String t[] = data.split(" ");l = new ArrayList<Double>();for (int i = 0; i < t.length; i++) {l.add(Double.parseDouble(t[i]));}datas.add(l);data = br.readLine();}} catch (Exception e) {e.printStackTrace();}}/** * 程序执行入口 * @param args */public static void main(String[] args) {TestKNN t = new TestKNN();String datafile = new File("").getAbsolutePath() + File.separator + "datafile";String testfile = new File("").getAbsolutePath() + File.separator + "testfile";try {List<List<Double>> datas = new ArrayList<List<Double>>();List<List<Double>> testDatas = new ArrayList<List<Double>>();t.read(datas, datafile);t.read(testDatas, testfile);KNN knn = new KNN();for (int i = 0; i < testDatas.size(); i++) {List<Double> test = testDatas.get(i);System.out.print("测试元组: ");for (int j = 0; j < test.size(); j++) {System.out.print(test.get(j) + " ");}System.out.print("类别为: ");System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));}} catch (Exception e) {e.printStackTrace();}}}?