Using k-NN Classification with Apache® Ignite™

In the previous article in this Machine Learning series, we looked at Linear Regression with Apache® Ignite™. Now let’s take the opportunity to try another Machine Learning algorithm. This time we’ll look at k-Nearest Neighbor (k-NN) Classification. This algorithm is useful for determining class membership, where we classify an object based upon the most common class amongst its k nearest neighbors.
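To make the idea concrete, here is a minimal plain-Java sketch of k-NN classification with a simple majority vote. It is independent of Ignite, and the names `KnnSketch` and `classify` are hypothetical. Each training sample is ranked by its Euclidean distance to the query point, and the most common label among the k closest samples wins.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.Map;

// Minimal plain-Java sketch of k-NN classification (not Ignite's implementation).
class KnnSketch {

    // Euclidean distance between two feature vectors of equal length.
    static double euclidean(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Returns the most common label among the k training samples nearest to x.
    static double classify(double[][] trainX, double[] trainY, double[] x, int k) {
        // Sort training-sample indices by their distance to the query point x.
        Integer[] idx = new Integer[trainX.length];
        for (int i = 0; i < idx.length; i++)
            idx[i] = i;
        Arrays.sort(idx, Comparator.comparingDouble(i -> euclidean(trainX[i], x)));

        // One vote per neighbor. LinkedHashMap keeps nearest-first insertion order,
        // so a tie goes to the label whose nearest neighbor is closest.
        Map<Double, Integer> votes = new LinkedHashMap<>();
        for (int j = 0; j < Math.min(k, idx.length); j++)
            votes.merge(trainY[idx[j]], 1, Integer::sum);

        double best = Double.NaN;
        int bestCount = 0;
        for (Map.Entry<Double, Integer> e : votes.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Two well-separated toy clusters labelled 0 and 1.
        double[][] trainX = {{1.0, 1.0}, {1.2, 0.8}, {0.9, 1.1}, {5.0, 5.0}, {5.2, 4.9}, {4.8, 5.1}};
        double[] trainY = {0, 0, 0, 1, 1, 1};
        System.out.println(classify(trainX, trainY, new double[]{1.1, 1.0}, 3)); // 0.0
        System.out.println(classify(trainX, trainY, new double[]{5.1, 5.0}, 3)); // 1.0
    }
}
```

Note that there is no real "training" step in this sketch: k-NN keeps the training samples and does all of its work at prediction time.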

A dataset that is a good candidate for k-NN Classification is the Iris flower dataset. It is available from the UCI Machine Learning Repository and is also bundled with Scikit-learn, which we use below.

The Iris flower dataset consists of 150 samples, with 50 each from 3 different species of Iris flowers (Iris Setosa, Iris Versicolour and Iris Virginica). The following four features are available for each sample:

  1. Sepal length (cm)
  2. Sepal width (cm)
  3. Petal length (cm)
  4. Petal width (cm)

We’ll create a model that can distinguish between the different species using these four features.

First, we need to split the raw data into training data (60%) and test data (40%). As in the previous article, we’ll use Scikit-learn for this task, with a modified version of the code we used there:


import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split

# Load the Iris dataset bundled with Scikit-learn.
iris_dataset = datasets.load_iris()
x = iris_dataset.data
y = iris_dataset.target

# Split it into train (60%) and test (40%) subsets.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.4, random_state=23)

# Save the train set (no index column, no header row).
train_ds = pd.DataFrame(x_train, columns=iris_dataset.feature_names)
train_ds["TARGET"] = y_train
train_ds.to_csv("iris-train.csv", index=False, header=False)

# Save the test set (no index column, no header row).
test_ds = pd.DataFrame(x_test, columns=iris_dataset.feature_names)
test_ds["TARGET"] = y_test
test_ds.to_csv("iris-test.csv", index=False, header=False)
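If you would rather keep every step in Java, a shuffled 60/40 split can be sketched as follows. This is a hypothetical helper (`SplitSketch.split`), not part of the original application, and because it uses `java.util.Random` it will not reproduce the exact rows that `random_state=23` selects in Python. It only shows the idea: shuffle the row indices with a fixed seed, then hold out the first 40% as test data.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical plain-Java stand-in for scikit-learn's train_test_split.
class SplitSketch {

    // Returns two lists of row indices: element 0 is the training set, element 1 the test set.
    static List<List<Integer>> split(int n, double testFraction, long seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++)
            idx.add(i);

        // Seeded shuffle so the split is repeatable (but different from Python's).
        Collections.shuffle(idx, new Random(seed));

        // Round the test size up, as train_test_split does for a float test_size.
        int testSize = (int) Math.ceil(n * testFraction);

        List<List<Integer>> result = new ArrayList<>();
        result.add(new ArrayList<>(idx.subList(testSize, n))); // train
        result.add(new ArrayList<>(idx.subList(0, testSize))); // test
        return result;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(150, 0.4, 23L);
        System.out.println(parts.get(0).size() + " train, " + parts.get(1).size() + " test"); // 90 train, 60 test
    }
}
```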

With our training and test data ready, we can start coding the application. You can download the code from GitHub if you would like to follow along. The application performs the following steps:

  1. Read the training data and test data
  2. Store the training data and test data in Ignite
  3. Use the training data to fit the k-NN Classification model
  4. Apply the model to the test data
  5. Determine the accuracy of the model

Read the training data and test data

We have two CSV files, neither with a header row, each containing the following 5 columns:

  1. Sepal length
  2. Sepal width
  3. Petal length
  4. Petal width
  5. Flower class (0 = Iris Setosa, 1 = Iris Versicolour, 2 = Iris Virginica)

We can use the following code to read in the values from each CSV file and store them in Ignite:


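import java.io.File;
import java.io.FileNotFoundException;
import java.util.Scanner;
import org.apache.ignite.IgniteCache;

private static void loadData(String fileName, IgniteCache<Integer, IrisObservation> cache)
        throws FileNotFoundException {

   // try-with-resources closes the Scanner (and the file) when we are done.
   try (Scanner scanner = new Scanner(new File(fileName))) {
      int cnt = 0;
      while (scanner.hasNextLine()) {
         String row = scanner.nextLine();
         String[] cells = row.split(",");

         // Every column except the last is a feature.
         double[] features = new double[cells.length - 1];
         for (int i = 0; i < cells.length - 1; i++)
            features[i] = Double.parseDouble(cells[i]);

         // The last column is the flower class label.
         double flowerClass = Double.parseDouble(cells[cells.length - 1]);

         // Use a sequential integer as the cache key.
         cache.put(cnt++, new IrisObservation(features, flowerClass));
      }
   }
}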

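The code reads the file line by line and splits each line on the CSV field separator (a comma). Every field is converted to a double: the first four values form the feature vector and the last one is the class label. Each resulting IrisObservation is then stored in the Ignite cache under a sequential integer key.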
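The article does not show the IrisObservation class itself. The following is a hypothetical sketch inferred only from how the class is used above: a constructor taking `(double[] features, double flowerClass)` plus `getFeatures()` and `getFlowerClass()`. The real class in the downloadable code may differ. The `fromCsvRow` helper is also hypothetical; it parses one row the same way `loadData()` does, which makes the row format easy to check without a running Ignite node.

```java
import java.util.Arrays;

// Hypothetical sketch of the IrisObservation value class (inferred from its usage).
class IrisObservation {

    // Sepal length, sepal width, petal length, petal width (cm).
    private double[] features;

    // 0 = Iris Setosa, 1 = Iris Versicolour, 2 = Iris Virginica.
    private double flowerClass;

    IrisObservation(double[] features, double flowerClass) {
        this.features = features;
        this.flowerClass = flowerClass;
    }

    public double[] getFeatures() {
        return features;
    }

    public double getFlowerClass() {
        return flowerClass;
    }

    // Hypothetical helper: every column except the last is a feature,
    // and the last column is the class label, as in loadData().
    static IrisObservation fromCsvRow(String row) {
        String[] cells = row.split(",");
        double[] features = new double[cells.length - 1];
        for (int i = 0; i < cells.length - 1; i++)
            features[i] = Double.parseDouble(cells[i].trim());
        double label = Double.parseDouble(cells[cells.length - 1].trim());
        return new IrisObservation(features, label);
    }

    @Override
    public String toString() {
        return Arrays.toString(features) + " -> " + flowerClass;
    }

    public static void main(String[] args) {
        System.out.println(fromCsvRow("5.1,3.5,1.4,0.2,0")); // [5.1, 3.5, 1.4, 0.2] -> 0.0
    }
}
```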

Store the training data and test data in Ignite

The loadData() method writes into an Ignite cache, so we first need to create one cache for the training data and one for the test data, and then load each CSV file into its cache:


IgniteCache<Integer, IrisObservation> trainData = getCache(ignite, "IRIS_TRAIN");

IgniteCache<Integer, IrisObservation> testData = getCache(ignite, "IRIS_TEST");

loadData("src/main/resources/iris-train.csv", trainData);

loadData("src/main/resources/iris-test.csv", testData);

The getCache() method is implemented as follows:


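import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;

private static IgniteCache<Integer, IrisObservation> getCache(Ignite ignite, String cacheName) {

   CacheConfiguration<Integer, IrisObservation> cacheConfiguration = new CacheConfiguration<>();
   cacheConfiguration.setName(cacheName);

   // Rendezvous affinity with 10 partitions; 'false' means primary and backup
   // copies are allowed on the same physical host.
   cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 10));

   // createCache() fails if a cache with this name already exists.
   IgniteCache<Integer, IrisObservation> cache = ignite.createCache(cacheConfiguration);

   return cache;
}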

Use the training data to fit the k-NN Classification model

Now that our data are stored, we can create the trainer, as follows:


KNNClassificationTrainer trainer = new KNNClassificationTrainer();

and fit a classification model to the training data, as follows:


KNNClassificationModel mdl = trainer.fit(
        ignite,
        trainData,
        (k, v) -> v.getFeatures(),     // Feature extractor.
        (k, v) -> v.getFlowerClass())  // Label extractor.
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withStrategy(KNNStrategy.WEIGHTED);

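Ignite stores data as key-value pairs, so each extractor lambda receives both a key and a value; both lambdas use only the value (an IrisObservation). The feature extractor returns the four measurements and the label extractor returns the flower class. We set k to 3, so each test sample is classified from its 3 nearest training samples. Note that k is the number of neighbors consulted, not the number of classes; it is a coincidence that there are also 3 species. For the distance measure we have several options, such as Euclidean, Hamming or Manhattan, and we’ll use Euclidean in this case. Finally, we can choose a SIMPLE or WEIGHTED k-NN strategy: SIMPLE gives each of the k neighbors an equal vote, whereas WEIGHTED gives closer neighbors more influence. We’ll use WEIGHTED in this case.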
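The distance measures and the SIMPLE versus WEIGHTED strategies can be illustrated in plain Java. This is not Ignite's code: `DistanceAndVoting` is a hypothetical class, and using 1 / distance as the vote weight is an assumption chosen for illustration, since the exact weighting formula Ignite uses is not given here. The example shows a case where the two strategies disagree: one very close neighbor of class 0 is outvoted under SIMPLE but wins under WEIGHTED.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketches of three distance measures and of simple vs. weighted voting.
class DistanceAndVoting {

    // Square root of the sum of squared coordinate differences.
    static double euclidean(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++)
            s += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(s);
    }

    // Sum of absolute coordinate differences.
    static double manhattan(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++)
            s += Math.abs(a[i] - b[i]);
        return s;
    }

    // Number of coordinates in which the two vectors differ.
    static double hamming(double[] a, double[] b) {
        int s = 0;
        for (int i = 0; i < a.length; i++)
            if (a[i] != b[i])
                s++;
        return s;
    }

    // Votes over k already-selected neighbors: labels[i] lies at distance dists[i].
    // weighted == false: each neighbor casts one vote (SIMPLE).
    // weighted == true: each vote counts 1 / distance (assumed weighting scheme);
    // a tiny floor avoids division by zero for exact matches.
    static double vote(double[] labels, double[] dists, boolean weighted) {
        Map<Double, Double> score = new LinkedHashMap<>();
        for (int i = 0; i < labels.length; i++) {
            double w = weighted ? 1.0 / Math.max(dists[i], 1e-12) : 1.0;
            score.merge(labels[i], w, Double::sum);
        }
        double best = Double.NaN;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (Map.Entry<Double, Double> e : score.entrySet()) {
            if (e.getValue() > bestScore) {
                best = e.getKey();
                bestScore = e.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // k = 3 neighbors: one very close class-0 sample, two farther class-1 samples.
        double[] labels = {0, 1, 1};
        double[] dists = {0.1, 1.0, 1.2};
        System.out.println(vote(labels, dists, false)); // 1.0 (two votes beat one)
        System.out.println(vote(labels, dists, true));  // 0.0 (weight 10 beats 1 + 0.83)
    }
}
```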

Apply the model to the test data

We are now ready to check the test data against the trained classification model. The following code will do this for us:


int amountOfErrors = 0;
int totalAmount = 0;

try (QueryCursor<Cache.Entry<Integer, IrisObservation>> cursor = testData.query(new ScanQuery<>())) {
   for (Cache.Entry<Integer, IrisObservation> testEntry : cursor) {
      IrisObservation observation = testEntry.getValue();

      double groundTruth = observation.getFlowerClass();
      double prediction = mdl.apply(new DenseLocalOnHeapVector(observation.getFeatures()));

      totalAmount++;
      if (groundTruth != prediction)
         amountOfErrors++;

      System.out.printf(">>> | %.0f\t\t\t | %.0f\t\t\t|\n", prediction, groundTruth);
   }

   System.out.println(">>> -----------------------------");

   System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
   System.out.printf("\n>>> Accuracy %.2f\n", (1 - amountOfErrors / (double) totalAmount));
}

Determine the accuracy of the model

The loop above already computes what we need: it counts every test sample whose predicted class differs from its actual class, and reports the accuracy as 1 - errors / total.
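The accuracy formula can be checked in isolation with a minimal sketch. `AccuracySketch` is a hypothetical class, and its data are synthetic: 60 labels (the size of a 40% test split of 150 samples) with 2 deliberately wrong predictions, chosen to mirror the error count reported below rather than the real test data. `Locale.ROOT` keeps the decimal separator a dot regardless of the machine's locale.

```java
import java.util.Locale;

// Hypothetical sketch of the accuracy calculation: 1 - errors / total.
class AccuracySketch {

    static double accuracy(double[] predictions, double[] groundTruth) {
        int errors = 0;
        for (int i = 0; i < predictions.length; i++)
            if (predictions[i] != groundTruth[i])
                errors++;
        return 1 - errors / (double) predictions.length;
    }

    public static void main(String[] args) {
        // Synthetic labels: 60 samples, predictions correct except for two.
        double[] truth = new double[60];
        double[] pred = new double[60];
        for (int i = 0; i < 60; i++)
            truth[i] = pred[i] = i % 3;
        pred[0] = 1; // wrong: truth is 0
        pred[1] = 2; // wrong: truth is 1
        System.out.println(String.format(Locale.ROOT, "Accuracy %.2f", accuracy(pred, truth))); // Accuracy 0.97
    }
}
```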

Running the code gives us the following summary:


>>> Absolute amount of errors 2

>>> Accuracy 0.97

Ignite was, therefore, able to correctly classify 97% of the test data (58 of the 60 test samples) into the 3 different species.

Summary

Apache Ignite provides a library of Machine Learning algorithms. Through a k-NN Classification example, we have seen how easily we can create a model, test it and determine its accuracy.

In the next part of this Apache Ignite Machine Learning series, we’ll look at another Machine Learning algorithm.