package org.knuchel.playground;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;

/**
 * K-nearest neighbor classifier for iris samples.
 *
 * <p>A sample is described by four measurements; its species is predicted by a majority vote
 * among the {@code k} samples of a labelled dataset that are closest to it (Euclidean distance).
 *
 * <p>Value ranges noted for the measurements:
 * <pre>
 * 0.0 &lt; sepalLength &lt; 7.9
 * 0.0 &lt; sepalWidth  &lt; 4.4
 * 0.0 &lt; petalLength &lt; 6.9
 * 0.0 &lt; petalWidth  &lt; 2.5
 * </pre>
 */
public class KNNJava {

    /** Dataset used by {@link #main(String[])} when no path is given on the command line. */
    private static final String DEFAULT_DATASET = "data/iris.data.csv";

    /** Separator between the cells of a dataset row. */
    private static final String CSV_SEPARATOR = ",";

    /** Number of cells a dataset row must have: four measurements followed by the species. */
    private static final int CSV_COLUMNS = 5;

    /** Largest value of k tried by {@link #evaluate(List, Random)}. */
    private static final int MAX_EVALUATED_K = 20;

    /**
     * Loads the dataset and predicts the species of one hard-coded sample.
     *
     * @param args optional; {@code args[0]} overrides the default dataset path
     */
    public static void main(String[] args) {
        String csvFile = (args != null && args.length > 0) ? args[0] : DEFAULT_DATASET;
        List<Iris> dataset = loadDataset(csvFile);
        if (dataset.isEmpty()) {
            // Without samples the prediction would be null; say so instead of printing "null".
            System.err.println("No sample could be loaded from " + csvFile + ", nothing to predict from");
            return;
        }

        int k = 5;
        double sepalLength = 5.7d, sepalWidth = 2.6d, petalLength = 3.5d, petalWidth = 1d;
        String predictedSpecie = predictSpecie(dataset, k, sepalLength, sepalWidth, petalLength, petalWidth);
        System.out.println("Iris with [sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth
                + ", petalLength=" + petalLength + ", petalWidth=" + petalWidth + "] should be a "
                + predictedSpecie);

        // try different values of k and see how prediction errors change
        // evaluate(dataset);
    }

    /**
     * Predicts the species of a sample from its {@code k} nearest neighbors in {@code dataset}.
     *
     * <p>The number of neighbors actually used is {@code k} clamped to {@code [1, dataset.size()]}.
     * When several species are equally frequent among the neighbors, the one owning the nearest
     * neighbor wins, so the result does not depend on hash ordering.
     *
     * @param dataset     labelled samples to compare against
     * @param k           requested number of neighbors taking part in the vote
     * @param sepalLength sepal length of the sample to classify
     * @param sepalWidth  sepal width of the sample to classify
     * @param petalLength petal length of the sample to classify
     * @param petalWidth  petal width of the sample to classify
     * @return the most frequent species among the nearest neighbors, or {@code null} when
     *         {@code dataset} is {@code null} or empty
     */
    public static String predictSpecie(List<Iris> dataset, Integer k, Double sepalLength, Double sepalWidth,
            Double petalLength, Double petalWidth) {
        if (dataset == null || dataset.isEmpty()) {
            return null;
        }

        // calculate distance for each sample in dataset
        Iris unknownIris = new Iris(sepalLength, sepalWidth, petalLength, petalWidth, null);
        List<Score> scores = new ArrayList<>(dataset.size());
        for (Iris iris : dataset) {
            scores.add(new Score(unknownIris.distance(iris), iris.specie));
        }
        Collections.sort(scores, Score.COMPARATOR);

        // count occurrences for the K nearest neighbors; insertion order is nearest first,
        // which is what makes the tie-break below deterministic
        int requested = k;
        int neighbors = Math.min(Math.max(requested, 1), scores.size());
        Map<String, Integer> occurrenceCount = new LinkedHashMap<>();
        for (Score score : scores.subList(0, neighbors)) {
            Integer count = occurrenceCount.get(score.specie);
            occurrenceCount.put(score.specie, count == null ? 1 : count + 1);
        }

        // find the most frequent occurrence; strict comparison keeps the earliest (nearest) on ties
        String mostFrequentSpecie = null;
        int nbOccurrence = 0;
        for (Entry<String, Integer> entry : occurrenceCount.entrySet()) {
            if (nbOccurrence < entry.getValue()) {
                nbOccurrence = entry.getValue();
                mostFrequentSpecie = entry.getKey();
            }
        }
        return mostFrequentSpecie;
    }

    /**
     * Reads iris samples from a comma-separated file, one sample per line:
     * {@code sepalLength,sepalWidth,petalLength,petalWidth,specie}. Blank lines are skipped.
     *
     * <p>I/O failures are reported on {@code System.err} and do not throw: the samples read so
     * far (possibly none) are returned.
     *
     * @param csvFile path of the file to read
     * @return the samples read, never {@code null}
     * @throws IllegalArgumentException if a non-blank line has fewer than five cells, or
     *         (as {@link NumberFormatException}) if a measurement is not a number
     */
    public static List<Iris> loadDataset(String csvFile) {
        List<Iris> dataset = new ArrayList<>();
        try (BufferedReader reader = new BufferedReader(new FileReader(csvFile))) {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.trim().isEmpty()) {
                    continue;
                }
                dataset.add(parseIris(line, lineNumber));
            }
        } catch (FileNotFoundException e) {
            System.err.println("Dataset file not found: " + csvFile + " (" + e.getMessage() + ")");
        } catch (IOException e) {
            System.err.println("Unable to read dataset " + csvFile + ": " + e);
        }
        return dataset;
    }

    /** Converts one non-blank dataset row into an {@link Iris}. */
    private static Iris parseIris(String line, int lineNumber) {
        String[] cell = line.split(CSV_SEPARATOR);
        if (cell.length < CSV_COLUMNS) {
            throw new IllegalArgumentException("Line " + lineNumber + ": expected " + CSV_COLUMNS
                    + " comma-separated values but found " + cell.length + " in \"" + line + "\"");
        }
        return new Iris(Double.parseDouble(cell[0]), Double.parseDouble(cell[1]), Double.parseDouble(cell[2]),
                Double.parseDouble(cell[3]), cell[4].trim());
    }

    /** Distance from the sample being classified to one dataset sample, with that sample's species. */
    static class Score {

        /** Orders scores by ascending distance. */
        public static final Comparator<Score> COMPARATOR = new Comparator<Score>() {
            @Override
            public int compare(Score o1, Score o2) {
                return o1.score.compareTo(o2.score);
            }
        };

        public Double score;
        public String specie;

        public Score(Double score, String specie) {
            this.score = score;
            this.specie = specie;
        }
    }

    /** One iris sample: four measurements and, when known, its species. */
    static class Iris {

        public Double sepalLength;
        public Double sepalWidth;
        public Double petalLength;
        public Double petalWidth;
        public String specie;

        public Iris(Double sepalLength, Double sepalWidth, Double petalLength, Double petalWidth, String specie) {
            this.sepalLength = sepalLength;
            this.sepalWidth = sepalWidth;
            this.petalLength = petalLength;
            this.petalWidth = petalWidth;
            this.specie = specie;
        }

        /**
         * Returns the Euclidean distance between this sample and {@code that} over the four
         * measurements.
         */
        public Double distance(Iris that) {
            return Math.sqrt(Math.pow(sepalLength - that.sepalLength, 2)
                    + Math.pow(sepalWidth - that.sepalWidth, 2)
                    + Math.pow(petalLength - that.petalLength, 2)
                    + Math.pow(petalWidth - that.petalWidth, 2));
        }

        @Override
        public String toString() {
            return "Iris [specie=" + specie + ", sepalLength=" + sepalLength + ", sepalWidth=" + sepalWidth
                    + ", petalLength=" + petalLength + ", petalWidth=" + petalWidth + "]";
        }
    }

    /**
     * Prints, for each k from 1 to 20, how many test samples are misclassified, using a random
     * split of {@code dataset}.
     *
     * @param dataset labelled samples to split into learning and test data
     */
    public static void evaluate(List<Iris> dataset) {
        evaluate(dataset, new Random());
    }

    /**
     * Same as {@link #evaluate(List)} but with a caller-supplied source of randomness, so a
     * seeded {@link Random} gives a reproducible split.
     *
     * <p>The samples of each species are shuffled and split in two: the first half goes to the
     * learning data, the rest to the test data. {@code dataset} itself is not modified.
     *
     * @param dataset labelled samples to split into learning and test data
     * @param random  source of randomness used to shuffle each species group
     */
    public static void evaluate(List<Iris> dataset, Random random) {
        // group by species, in order of first appearance, so each species is split evenly
        Map<String, List<Iris>> bySpecie = new LinkedHashMap<>();
        for (Iris iris : dataset) {
            List<Iris> group = bySpecie.get(iris.specie);
            if (group == null) {
                group = new ArrayList<>();
                bySpecie.put(iris.specie, group);
            }
            group.add(iris);
        }

        // split dataset in 2 parts : one part to learn, the other part to test
        List<Iris> learningData = new ArrayList<>();
        List<Iris> testData = new ArrayList<>();
        for (List<Iris> group : bySpecie.values()) {
            Collections.shuffle(group, random);
            int half = group.size() / 2;
            learningData.addAll(group.subList(0, half));
            testData.addAll(group.subList(half, group.size()));
        }

        // for each value of k, count the number of errors
        for (int k = 1; k <= MAX_EVALUATED_K; k++) {
            int errors = 0;
            for (Iris iris : testData) {
                String predicted = predictSpecie(learningData, k, iris.sepalLength, iris.sepalWidth,
                        iris.petalLength, iris.petalWidth);
                // a null prediction (no learning data) counts as an error instead of throwing
                if (predicted == null || !predicted.equals(iris.specie)) {
                    errors++;
                }
            }
            System.out.println(errors + " errors on " + testData.size() + " tests with k=" + k);
        }
    }
}