Neural Networks

Neural networks are a very powerful class of models, especially in cases where the learning of representations from low-level information (such as pixels, audio samples or text) is key. sensAI provides many useful abstractions for dealing with this class of models, facilitating data handling, learning and evaluation.

sensAI mainly provides abstractions for PyTorch, but there is also rudimentary support for TensorFlow.

[1]:
%load_ext autoreload
%autoreload 2
[2]:
import sys; sys.path.extend(["../src", ".."])
import sensai
import pandas as pd
import numpy as np
from typing import *
import config
import warnings
import functools

cfg = config.get_config()
warnings.filterwarnings("ignore")
sensai.util.logging.configure()

Image Classification

As an example use case, let us address the task of classifying handwritten digits in pixel images from the MNIST dataset. The images are greyscale (no colour information) and 28x28 pixels in size.

[3]:
mnist_df = pd.read_csv(cfg.datafile_path("mnist_train.csv.zip"))

The data frame contains one column for every pixel, each pixel being represented by an 8-bit integer (0 to 255).

[4]:
mnist_df.head(5)
[4]:
label 1x1 1x2 1x3 1x4 1x5 1x6 1x7 1x8 1x9 ... 28x19 28x20 28x21 28x22 28x23 28x24 28x25 28x26 28x27 28x28
0 5 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 4 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
3 1 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
4 9 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0

5 rows × 785 columns
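As a quick sanity check, we can confirm the stated value range of the pixel columns; the snippet below uses only the data frame loaded above.

pixel_columns = mnist_df.columns.drop("label")
print(mnist_df[pixel_columns].values.min(), mnist_df[pixel_columns].values.max())  # expected: 0 255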

Let’s create the I/O data for our experiments.

[5]:
mnistIoData = sensai.InputOutputData.from_data_frame(mnist_df, "label")

Now that we have the image data separated from the labels, let’s write a function to restore the 2D image arrays and take a look at some of the images.

[6]:
import matplotlib.pyplot as plt

def reshape_2d_image(series):
    """Reshape a 784-element pixel series into a 28x28 image array."""
    return series.values.reshape(28, 28)

fig, axs = plt.subplots(nrows=1, ncols=5, figsize=(10, 5))
for i in range(5):
    axs[i].imshow(reshape_2d_image(mnistIoData.inputs.iloc[i]), cmap="binary")
_images/neural_networks_10_0.png

Applying Predefined Models

We create an evaluation utility in order to test the performance of our models, using a random split of the data.

[7]:
evaluator_params = sensai.evaluation.ClassificationEvaluatorParams(fractional_split_test_fraction=0.2)
eval_util = sensai.evaluation.ClassificationModelEvaluation(mnistIoData, evaluator_params=evaluator_params)

One pre-defined model we could try is a simple multi-layer perceptron. A PyTorch-based implementation is provided by the class MultiLayerPerceptronVectorClassificationModel. This implementation supports CUDA-accelerated computation (on Nvidia GPUs), but we shall stick to CPU-based computation (cuda=False) in this tutorial.

[8]:
import sensai.torch

nn_optimiser_params = sensai.torch.NNOptimiserParams(early_stopping_epochs=2, batch_size=54)
torch_mlp_model = sensai.torch.models.MultiLayerPerceptronVectorClassificationModel(hidden_dims=(50, 20),
        cuda=False, normalisation_mode=sensai.torch.NormalisationMode.MAX_ALL,
        nn_optimiser_params=nn_optimiser_params, p_dropout=0.0) \
    .with_name("MLP")

Neural networks work best on normalised inputs, so we have opted to apply basic normalisation by specifying a normalisation mode which transforms inputs by dividing by the maximum value found across all columns in the training data. For more elaborate normalisation options, we could instead have used a data frame transformer (DFT), particularly DFTNormalisation or DFTSkLearnTransformer.
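As a minimal sketch of the DFT-based alternative, assuming that DFTSkLearnTransformer wraps a scikit-learn transformer directly and that feature transformers are attached via with_feature_transformers (both of which should be verified against the API documentation):

from sklearn.preprocessing import StandardScaler

# Hypothetical sketch: delegate normalisation to a DFT wrapping a scikit-learn scaler.
# The DFTSkLearnTransformer constructor signature and the with_feature_transformers
# method name are assumptions, not confirmed by this tutorial.
dft_scaler = sensai.data_transformation.DFTSkLearnTransformer(StandardScaler())
mlp_with_dft = sensai.torch.models.MultiLayerPerceptronVectorClassificationModel(
        hidden_dims=(50, 20),
        cuda=False,
        normalisation_mode=sensai.torch.NormalisationMode.NONE,  # normalisation handled by the DFT instead
        nn_optimiser_params=nn_optimiser_params) \
    .with_feature_transformers(dft_scaler) \
    .with_name("MLP-DFT")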

sensAI’s default neural network training algorithm is based on early stopping: at regular intervals, the performance of the model is evaluated on a validation set (which is split off from the training set), and the model that performed best on the validation set is ultimately selected. You have full control over the loss evaluation method used to select the best model (by passing a corresponding NNLossEvaluator instance to NNOptimiserParams) as well as over the method that is used to split the training set into the actual training set and the validation set (by adding a DataFrameSplitter to the model or by using a custom TorchDataSetProvider).
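For instance, the size of the internal validation split can be configured directly in NNOptimiserParams; this sketch reuses only parameters that appear in this tutorial, and train_fraction=0.75 corresponds to the value visible in the training logs further down:

# Sketch: explicit control over early stopping and the internal train/validation split
custom_nn_optimiser_params = sensai.torch.NNOptimiserParams(
    early_stopping_epochs=2,  # stop after 2 epochs without validation improvement
    batch_size=54,
    train_fraction=0.75)  # use 75% of the training data for training, 25% for validation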

Given the vectorised nature of our MNIST dataset, we can apply any type of model that accepts numeric inputs. Let’s compare the neural network we defined above against another pre-defined model, which is based on a scikit-learn implementation and uses decision trees rather than neural networks.

[9]:
random_forest_model = sensai.sklearn.classification.SkLearnRandomForestVectorClassificationModel(
        min_samples_leaf=1,
        n_estimators=10) \
    .with_name("RandomForest")

Let’s compare the two models using our evaluation utility.

[10]:
eval_util.compare_models([random_forest_model, torch_mlp_model])
INFO  2024-04-30 08:29:54,805 sensai.evaluation.eval_util:compare_models:393 - Evaluating model 1/2 named 'RandomForest' ...
DEBUG 2024-04-30 08:29:55,019 sensai.evaluation.evaluator:__init__:182 - <sensai.data.DataSplitterFractional object at 0x7fe388bb8a90> created split with 48000 (80.00%) and 12000 (20.00%) training and test data points respectively
INFO  2024-04-30 08:29:55,020 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating SkLearnRandomForestVectorClassificationModel[featureGenerator=None, fitArgs={}, useBalancedClassWeights=False, useLabelEncoding=False, name=RandomForest, modelConstructor=RandomForestClassifier(random_state=42, min_samples_leaf=1, n_estimators=10)] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe388bb8890>
INFO  2024-04-30 08:29:55,021 sensai.vector_model:fit:371 - Fitting SkLearnRandomForestVectorClassificationModel instance
DEBUG 2024-04-30 08:29:55,045 sensai.vector_model:fit:394 - Fitting with outputs[1]=['label'], inputs[784]=[1x1/int64, 1x2/int64, ..., 28x28/int64]; N=48000 data points
INFO  2024-04-30 08:29:55,046 sensai.sklearn.sklearn_base:_fit_classifier:281 - Fitting sklearn classifier of type RandomForestClassifier
INFO  2024-04-30 08:29:57,807 sensai.vector_model:fit:400 - Fitting completed in 2.79 seconds: SkLearnRandomForestVectorClassificationModel[featureGenerator=None, fitArgs={}, useBalancedClassWeights=False, useLabelEncoding=False, name=RandomForest, model=RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None, criterion='gini', max_depth=None, max_features='auto', max_leaf_nodes=None, max_samples=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None, oob_score=False, random_state=42, verbose=0, warm_start=False)]
INFO  2024-04-30 08:29:57,884 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.9466666666666667, balancedAccuracy=0.945916926388699, N=12000]
INFO  2024-04-30 08:29:58,214 sensai.evaluation.eval_util:compare_models:393 - Evaluating model 2/2 named 'MLP' ...
INFO  2024-04-30 08:29:58,215 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating MultiLayerPerceptronVectorClassificationModel[hidden_dims=(50, 20), hid_activation_function=<built-in method sigmoid of type object at 0x7fe329788880>, output_activation_function=ActivationFunction.LOG_SOFTMAX, input_dim=None, cuda=False, p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=Method[_create_torch_model], normalisationMode=NormalisationMode.MAX_ALL, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=54, optimiser_lr=0.001, shrinkage_clip=10.0, optimiser=adam, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=2, shuffle=True], model=None, inputTensoriser=None, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=MLP] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe388bb8890>
INFO  2024-04-30 08:29:58,216 sensai.vector_model:fit:371 - Fitting MultiLayerPerceptronVectorClassificationModel instance
DEBUG 2024-04-30 08:29:58,224 sensai.vector_model:fit:394 - Fitting with outputs[1]=['label'], inputs[784]=[1x1/int64, 1x2/int64, ..., 28x28/int64]; N=48000 data points
INFO  2024-04-30 08:29:58,390 sensai.torch.torch_opt.NNOptimiser:fit:682 - Preparing parameter learning of MultiLayerPerceptronTorchModel[cuda=False, inputDim=784, outputDim=10, hidActivationFunction=<built-in method sigmoid of type object at 0x7fe329788880>, outputActivationFunction=functools.partial(<function log_softmax at 0x7fe3886bb8c0>, dim=1), hiddenDims=(50, 20), pDropout=0.0, overrideInputDim=None, bestEpoch=None, totalEpochs=None] via NNOptimiser[params=NNOptimiserParams[epochs=1000, batch_size=54, optimiser_lr=0.001, shrinkage_clip=10.0, optimiser=adam, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=2, shuffle=True]] with cuda=False
INFO  2024-04-30 08:29:58,393 sensai.torch.torch_opt.NNOptimiser:fit:716 - Obtaining input/output training instances
INFO  2024-04-30 08:29:58,566 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Data set 1/1: #train=36000, #validation=12000
INFO  2024-04-30 08:29:58,567 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Number of validation sets: 1
INFO  2024-04-30 08:29:58,569 sensai.torch.torch_opt.NNOptimiser:fit:746 - Learning parameters of MultiLayerPerceptronTorchModel[cuda=False, inputDim=784, outputDim=10, hidActivationFunction=<built-in method sigmoid of type object at 0x7fe329788880>, outputActivationFunction=functools.partial(<function log_softmax at 0x7fe3886bb8c0>, dim=1), hiddenDims=(50, 20), pDropout=0.0, overrideInputDim=None, bestEpoch=None, totalEpochs=None]
INFO  2024-04-30 08:29:58,569 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Number of parameters: 40480
INFO  2024-04-30 08:29:58,570 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Starting training process via NNOptimiser[params=NNOptimiserParams[epochs=1000, batch_size=54, optimiser_lr=0.001, shrinkage_clip=10.0, optimiser=adam, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=2, shuffle=True]]
INFO  2024-04-30 08:29:58,579 sensai.torch.torch_opt.NNOptimiser:fit:764 - Begin training with cuda=False
INFO  2024-04-30 08:29:58,579 sensai.torch.torch_opt.NNOptimiser:fit:765 - Press Ctrl+C to end training early
INFO  2024-04-30 08:29:59,727 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   1/1000 completed in  1.15s | train loss 1.4059 | validation NLL 0.7300 | best NLL 0.730044 from this epoch
INFO  2024-04-30 08:30:00,883 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   2/1000 completed in  1.15s | train loss 0.5021 | validation NLL 0.3806 | best NLL 0.380582 from this epoch
INFO  2024-04-30 08:30:02,052 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   3/1000 completed in  1.16s | train loss 0.3138 | validation NLL 0.2847 | best NLL 0.284687 from this epoch
INFO  2024-04-30 08:30:03,201 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   4/1000 completed in  1.14s | train loss 0.2409 | validation NLL 0.2393 | best NLL 0.239259 from this epoch
INFO  2024-04-30 08:30:04,360 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   5/1000 completed in  1.15s | train loss 0.1981 | validation NLL 0.2081 | best NLL 0.208118 from this epoch
INFO  2024-04-30 08:30:05,491 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   6/1000 completed in  1.12s | train loss 0.1676 | validation NLL 0.1872 | best NLL 0.187202 from this epoch
INFO  2024-04-30 08:30:06,621 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   7/1000 completed in  1.12s | train loss 0.1450 | validation NLL 0.1697 | best NLL 0.169709 from this epoch
INFO  2024-04-30 08:30:07,805 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   8/1000 completed in  1.18s | train loss 0.1274 | validation NLL 0.1593 | best NLL 0.159322 from this epoch
INFO  2024-04-30 08:30:08,986 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   9/1000 completed in  1.17s | train loss 0.1131 | validation NLL 0.1509 | best NLL 0.150902 from this epoch
INFO  2024-04-30 08:30:10,153 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  10/1000 completed in  1.16s | train loss 0.1013 | validation NLL 0.1440 | best NLL 0.144011 from this epoch
INFO  2024-04-30 08:30:11,369 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  11/1000 completed in  1.21s | train loss 0.0906 | validation NLL 0.1381 | best NLL 0.138148 from this epoch
INFO  2024-04-30 08:30:12,517 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  12/1000 completed in  1.14s | train loss 0.0816 | validation NLL 0.1334 | best NLL 0.133373 from this epoch
INFO  2024-04-30 08:30:13,710 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  13/1000 completed in  1.19s | train loss 0.0739 | validation NLL 0.1315 | best NLL 0.131459 from this epoch
INFO  2024-04-30 08:30:14,900 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  14/1000 completed in  1.18s | train loss 0.0667 | validation NLL 0.1281 | best NLL 0.128110 from this epoch
INFO  2024-04-30 08:30:16,051 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  15/1000 completed in  1.14s | train loss 0.0610 | validation NLL 0.1271 | best NLL 0.127124 from this epoch
INFO  2024-04-30 08:30:17,189 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  16/1000 completed in  1.13s | train loss 0.0553 | validation NLL 0.1222 | best NLL 0.122218 from this epoch
INFO  2024-04-30 08:30:18,260 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  17/1000 completed in  1.06s | train loss 0.0504 | validation NLL 0.1258 | best NLL 0.122218 from epoch 16
INFO  2024-04-30 08:30:19,414 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch  18/1000 completed in  1.15s | train loss 0.0461 | validation NLL 0.1233 | best NLL 0.122218 from epoch 16
INFO  2024-04-30 08:30:19,415 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Stopping early: 2 epochs without validation metric improvement
INFO  2024-04-30 08:30:19,416 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Training complete
INFO  2024-04-30 08:30:19,416 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Best model is from epoch 16 with NLL 0.12221813999116421 on validation set
INFO  2024-04-30 08:30:19,426 sensai.vector_model:fit:400 - Fitting completed in 21.21 seconds: MultiLayerPerceptronVectorClassificationModel[hidden_dims=(50, 20), hid_activation_function=<built-in method sigmoid of type object at 0x7fe329788880>, output_activation_function=ActivationFunction.LOG_SOFTMAX, input_dim=None, cuda=False, p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=Method[_create_torch_model], normalisationMode=NormalisationMode.MAX_ALL, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=54, optimiser_lr=0.001, shrinkage_clip=10.0, optimiser=adam, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=2, shuffle=True], model=MultiLayerPerceptronTorchModel[cuda=False, inputDim=784, outputDim=10, hidActivationFunction=<built-in method sigmoid of type object at 0x7fe329788880>, outputActivationFunction=functools.partial(<function log_softmax at 0x7fe3886bb8c0>, dim=1), hiddenDims=(50, 20), pDropout=0.0, overrideInputDim=None, bestEpoch=16, totalEpochs=18], inputTensoriser=None, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=MLP]
DEBUG 2024-04-30 08:30:19,427 sensai.torch.torch_data:__init__:546 - Applying <sensai.torch.torch_data.TensoriserDataFrameFloatValuesMatrix object at 0x7fe3b02eb8d0> to data frame of length 12000 ...
INFO  2024-04-30 08:30:19,555 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.96225, balancedAccuracy=0.9618968610010363, N=12000]
INFO  2024-04-30 08:30:19,876 sensai.evaluation.eval_util:compare_models:462 - Model comparison results:
              accuracy  balancedAccuracy
model_name
RandomForest  0.946667          0.945917
MLP           0.962250          0.961897
[10]:
<sensai.evaluation.eval_util.ModelComparisonData at 0x7fe385f6ead0>
_images/neural_networks_19_2.png
_images/neural_networks_19_3.png
_images/neural_networks_19_4.png
_images/neural_networks_19_5.png

Both models perform reasonably well.
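The compare_models call returns a ModelComparisonData instance (see the output above), so the comparison can also be post-processed programmatically. A hedged sketch follows; the attribute name results_df is an assumption to be checked against the API, and note that re-running compare_models retrains both models:

# Hypothetical: programmatic access to the comparison table
comparison_data = eval_util.compare_models([random_forest_model, torch_mlp_model])
print(comparison_data.results_df.sort_values("accuracy", ascending=False))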

Creating a Custom CNN Model

Given that this is an image recognition problem, it can be sensible to apply convolutional neural networks (CNNs), which can analyse patches of the image in order to generate higher-level features from them. Specifically, we shall apply a neural network model which uses multiple convolutions, a max-pooling layer and a multi-layer perceptron at the end in order to produce the classification.
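To make the dimensionality concrete, here is the shape arithmetic for the configuration we shall use below (28x28 images, 32 convolutions with kernel size 5, pooling kernel size 2):

image_dim, kernel_size, pooling_kernel_size, num_conv = 28, 5, 2, 32  # values used below
conv_dim = image_dim - kernel_size + 1        # 24: spatial size after the (unpadded) convolution
pooled_dim = conv_dim // pooling_kernel_size  # 12: spatial size after max-pooling
mlp_input_dim = num_conv * pooled_dim ** 2    # 4608: flattened feature count fed to the MLP
print(conv_dim, pooled_dim, mlp_input_dim)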

For classification and regression, sensAI provides the fundamental classes TorchVectorClassificationModel and TorchVectorRegressionModel respectively. Ultimately, these classes will wrap an instance of torch.nn.Module, the base class for neural networks in PyTorch.

Wrapping a Custom torch.nn.Module Instance

If we already have an implementation of a torch.nn.Module, it can straightforwardly be adapted to become a sensAI VectorModel.

Let’s say we have the following implementation of a torch module, which performs the steps described above.

[11]:
import torch

class MnistCnnModule(torch.nn.Module):
    def __init__(self, image_dim: int, output_dim: int, num_conv: int, kernel_size: int, pooling_kernel_size: int,
            mlp_hidden_dims: Sequence[int], output_activation_fn: sensai.torch.ActivationFunction, p_dropout=0.0):
        super().__init__()
        k = kernel_size
        p = pooling_kernel_size
        self.cnn = torch.nn.Conv2d(1, num_conv, (k, k))  # single input channel (greyscale images)
        self.pool = torch.nn.MaxPool2d((p, p))
        self.dropout = torch.nn.Dropout(p=p_dropout)
        # spatial dimension after the (unpadded) convolution and max-pooling
        reduced_dim = (image_dim - k + 1) / p
        if int(reduced_dim) != reduced_dim:
            raise ValueError(f"Pooling kernel size {p} is not a divisor of post-convolution dimension {image_dim - k + 1}")
        self.mlp = sensai.torch.models.MultiLayerPerceptron(num_conv * int(reduced_dim) ** 2, output_dim, mlp_hidden_dims,
            output_activation_fn=output_activation_fn.get_torch_function(),
            hid_activation_fn=sensai.torch.ActivationFunction.RELU.get_torch_function(),
            p_dropout=p_dropout)

    def forward(self, x):
        x = self.cnn(x.unsqueeze(1))  # add the channel dimension and apply the convolution
        x = self.pool(x)
        x = x.view(x.shape[0], -1)  # flatten the feature maps into a vector
        x = self.dropout(x)
        return self.mlp(x)

Since this module requires 2D images as input, we need a component that transforms the vector input given in our data frame into a tensor that can serve as input to the module. In sensAI, the abstraction for this purpose is a sensai.torch.Tensoriser. A Tensoriser can, in principle, perform arbitrary computations in order to produce, from a data frame with N rows, one or more tensors of length N (first dimension equal to N) that will ultimately be fed to the neural network.

Luckily, for the case at hand, we already have the function reshape_2d_image from above to assist in the implementation of the tensoriser.

[12]:
class ImageReshapingInputTensoriser(sensai.torch.RuleBasedTensoriser):
    def _tensorise(self, df: pd.DataFrame) -> Union[torch.Tensor, List[torch.Tensor]]:
        images = [reshape_2d_image(row) for _, row in df.iterrows()]
        # stack into an (N, 28, 28) tensor and normalise the 8-bit pixel values to [0, 1]
        return torch.tensor(np.stack(images)).float() / 255

In this case, we derived the class from RuleBasedTensoriser rather than Tensoriser, because our tensoriser does not require fitting. We additionally took care of normalisation by scaling the 8-bit pixel values to the unit interval.
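As a quick smoke test (calling the protected _tensorise method directly, which is fine for interactive inspection), we can verify that the tensoriser produces tensors of the expected shape:

# Smoke test: tensorise the first four rows and inspect the resulting shape
sample_tensor = ImageReshapingInputTensoriser()._tensorise(mnistIoData.inputs.iloc[:4])
print(sample_tensor.shape)  # expected: torch.Size([4, 28, 28])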

Now we have all we need to create a sensAI TorchVectorClassificationModel that will work on the input/output data we loaded earlier.

[13]:
# arguments: image_dim=28, output_dim=10, num_conv=32, kernel_size=5, pooling_kernel_size=2, mlp_hidden_dims=(200, 20)
cnn_module = MnistCnnModule(28, 10, 32, 5, 2, (200, 20), sensai.torch.ActivationFunction.LOG_SOFTMAX)
nn_optimiser_params = sensai.torch.NNOptimiserParams(
    optimiser=sensai.torch.Optimiser.ADAMW,
    optimiser_lr=0.01,
    batch_size=1024,
    early_stopping_epochs=3)
cnn_model_from_module = sensai.torch.TorchVectorClassificationModel.from_module(
        cnn_module, sensai.torch.ClassificationOutputMode.LOG_PROBABILITIES,
        cuda=False, nn_optimiser_params=nn_optimiser_params) \
    .with_input_tensoriser(ImageReshapingInputTensoriser()) \
    .with_name("CNN")

We have now fully defined all the necessary parameters, including those controlling the training of the model.

We are now ready to evaluate the model.

[14]:
eval_util.perform_simple_evaluation(cnn_model_from_module);
DEBUG 2024-04-30 08:30:21,859 sensai.evaluation.evaluator:__init__:182 - <sensai.data.DataSplitterFractional object at 0x7fe388bb8a90> created split with 48000 (80.00%) and 12000 (20.00%) training and test data points respectively
INFO  2024-04-30 08:30:21,860 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating TorchVectorClassificationModel[featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=<sensai.torch.torch_base.TorchModelFactoryFromModule object at 0x7fe385b7ab90>, normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=None, inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385b6a150>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe385b7a850>
INFO  2024-04-30 08:30:21,861 sensai.vector_model:fit:371 - Fitting TorchVectorClassificationModel instance
DEBUG 2024-04-30 08:30:21,884 sensai.vector_model:fit:394 - Fitting with outputs[1]=['label'], inputs[784]=[1x1/int64, 1x2/int64, ..., 28x28/int64]; N=48000 data points
INFO  2024-04-30 08:30:21,885 sensai.torch.torch_base:_fit_classifier:780 - Fitting <__main__.ImageReshapingInputTensoriser object at 0x7fe385b6a150> ...
INFO  2024-04-30 08:30:21,944 sensai.torch.torch_opt.NNOptimiser:fit:682 - Preparing parameter learning of TorchModelFromModule[cuda=False, bestEpoch=None, totalEpochs=None] via NNOptimiser[params=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True]] with cuda=False
INFO  2024-04-30 08:30:21,945 sensai.torch.torch_opt.NNOptimiser:fit:716 - Obtaining input/output training instances
INFO  2024-04-30 08:30:24,396 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Data set 1/1: #train=36000, #validation=12000
INFO  2024-04-30 08:30:24,398 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Number of validation sets: 1
INFO  2024-04-30 08:30:24,399 sensai.torch.torch_opt.NNOptimiser:fit:746 - Learning parameters of TorchModelFromModule[cuda=False, bestEpoch=None, totalEpochs=None]
INFO  2024-04-30 08:30:24,399 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Number of parameters: 926862
INFO  2024-04-30 08:30:24,400 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Starting training process via NNOptimiser[params=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True]]
INFO  2024-04-30 08:30:24,411 sensai.torch.torch_opt.NNOptimiser:fit:764 - Begin training with cuda=False
INFO  2024-04-30 08:30:24,411 sensai.torch.torch_opt.NNOptimiser:fit:765 - Press Ctrl+C to end training early
INFO  2024-04-30 08:30:34,292 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   1/1000 completed in  9.88s | train loss 0.8318 | validation NLL 0.1962 | best NLL 0.196246 from this epoch
INFO  2024-04-30 08:30:44,162 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   2/1000 completed in  9.86s | train loss 0.1365 | validation NLL 0.1341 | best NLL 0.134100 from this epoch
INFO  2024-04-30 08:30:54,059 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   3/1000 completed in  9.89s | train loss 0.0724 | validation NLL 0.1154 | best NLL 0.115383 from this epoch
INFO  2024-04-30 08:31:03,915 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   4/1000 completed in  9.85s | train loss 0.0524 | validation NLL 0.1286 | best NLL 0.115383 from epoch 3
INFO  2024-04-30 08:31:13,755 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   5/1000 completed in  9.84s | train loss 0.0418 | validation NLL 0.1073 | best NLL 0.107293 from this epoch
INFO  2024-04-30 08:31:23,543 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   6/1000 completed in  9.78s | train loss 0.0345 | validation NLL 0.0818 | best NLL 0.081762 from this epoch
INFO  2024-04-30 08:31:33,418 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   7/1000 completed in  9.86s | train loss 0.0293 | validation NLL 0.0953 | best NLL 0.081762 from epoch 6
INFO  2024-04-30 08:31:43,175 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   8/1000 completed in  9.76s | train loss 0.0245 | validation NLL 0.1157 | best NLL 0.081762 from epoch 6
INFO  2024-04-30 08:31:52,911 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   9/1000 completed in  9.73s | train loss 0.0209 | validation NLL 0.0996 | best NLL 0.081762 from epoch 6
INFO  2024-04-30 08:31:52,911 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Stopping early: 3 epochs without validation metric improvement
INFO  2024-04-30 08:31:52,912 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Training complete
INFO  2024-04-30 08:31:52,912 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Best model is from epoch 6 with NLL 0.08176225980122884 on validation set
INFO  2024-04-30 08:31:52,923 sensai.vector_model:fit:400 - Fitting completed in 91.06 seconds: TorchVectorClassificationModel[featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=<sensai.torch.torch_base.TorchModelFactoryFromModule object at 0x7fe385b7ab90>, normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=TorchModelFromModule[cuda=False, bestEpoch=6, totalEpochs=9], inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385b6a150>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN]
DEBUG 2024-04-30 08:31:52,924 sensai.torch.torch_data:__init__:546 - Applying <__main__.ImageReshapingInputTensoriser object at 0x7fe385b6a150> to data frame of length 12000 ...
INFO  2024-04-30 08:31:54,696 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.9784166666666667, balancedAccuracy=0.9782716826784611, N=12000]
_images/neural_networks_28_1.png
_images/neural_networks_28_2.png

Creating an Input-/Output-Adaptive Custom Model

While the above approach makes it straightforward to encapsulate a torch.nn.Module, it does not follow sensAI's principle of adapting model hyperparameters to the inputs and outputs received during training wherever possible. Notice that in the above example, we had to hard-code the image dimension (28) as well as the number of classes (10), even though both could easily have been determined from the data. Especially in other domains where feature engineering is possible, we may want to experiment with different combinations of features, so automatically adapting to the inputs is key if we are to avoid editing the model hyperparameters time and time again; similarly, the set of target labels in our classification problem may change, and the model should simply adapt to the new output dimension.

To design a model that can fully adapt to its inputs and outputs, we can simply subclass TorchVectorClassificationModel, which caters for the late instantiation of the underlying model. Naturally, this delayed construction necessitates the use of factories and thus introduces some indirection.

If we had designed the above model to be within the sensAI VectorModel realm from the beginning, here’s what we might have written:

[15]:
import torch

class CnnModel(sensai.torch.TorchVectorClassificationModel):
    def __init__(self, cuda: bool, kernel_size: int, num_conv: int, pooling_kernel_size: int, mlp_hidden_dims: Sequence[int],
            nn_optimiser_params: sensai.torch.NNOptimiserParams, p_dropout=0.0):
        self.cuda = cuda
        self.output_activation_fn = sensai.torch.ActivationFunction.LOG_SOFTMAX
        self.kernel_size = kernel_size
        self.num_conv = num_conv
        self.pooling_kernel_size = pooling_kernel_size
        self.mlp_hidden_dims = mlp_hidden_dims
        self.p_dropout = p_dropout
        super().__init__(sensai.torch.ClassificationOutputMode.for_activation_fn(self.output_activation_fn),
            torch_model_factory=functools.partial(self.VectorTorchModel, self),
            nn_optimiser_params=nn_optimiser_params)

    class VectorTorchModel(sensai.torch.VectorTorchModel):
        def __init__(self, parent: "CnnModel"):
            super().__init__(parent.cuda)
            self._parent = parent

        def create_torch_module_for_dims(self, input_dim: int, output_dim: int) -> torch.nn.Module:
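            # the inputs are flattened square images, so the side length is the square root of the input dimension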
            return self.Module(int(np.sqrt(input_dim)), output_dim, self._parent)

        class Module(torch.nn.Module):
            def __init__(self, image_dim, output_dim, parent: "CnnModel"):
                super().__init__()
                k = parent.kernel_size
                p = parent.pooling_kernel_size
                self.cnn = torch.nn.Conv2d(1, parent.num_conv, (k, k))
                self.pool = torch.nn.MaxPool2d((p, p))
                self.dropout = torch.nn.Dropout(p=parent.p_dropout)
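                # the convolution shrinks each spatial dimension to image_dim - k + 1,
                # and max pooling then divides it by p, so p must be an exact divisor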
                reduced_dim = (image_dim - k + 1) / p
                if int(reduced_dim) != reduced_dim:
                    raise ValueError(f"Pooling kernel size {p} is not a divisor of post-convolution dimension {image_dim - k + 1}")
                self.mlp = sensai.torch.models.MultiLayerPerceptron(parent.num_conv * int(reduced_dim) ** 2, output_dim, parent.mlp_hidden_dims,
                    output_activation_fn=parent.output_activation_fn.get_torch_function(),
                    hid_activation_fn=sensai.torch.ActivationFunction.RELU.get_torch_function(),
                    p_dropout=parent.p_dropout)

            def forward(self, x):
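                # add a channel dimension: (N, image_dim, image_dim) -> (N, 1, image_dim, image_dim)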
                x = self.cnn(x.unsqueeze(1))
                x = self.pool(x)
                x = x.view(x.shape[0], -1)
                x = self.dropout(x)
                return self.mlp(x)

This is only marginally more code than the previous implementation. The outer class, which provides the sensAI VectorModel features, serves mainly to hold the parameters, while the inner class inheriting from VectorTorchModel serves as a factory for the torch.nn.Module, providing us with the input and output dimensions (the number of input columns and the number of classes, respectively) based on the data, thus enabling the model to adapt. Had we required even more adaptiveness, we could have learnt more about the data from within the fitting process of a custom input tensoriser, i.e. we could have added an inner Tensoriser class which derives further hyperparameters from the data in its implementation of the fitting method; a minimal sketch of this idea follows.
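For illustration, here is what such a fitted tensoriser might look like: one that learns the maximum pixel value from the training data and uses it for normalisation when tensorising. This is a sketch only; the class name is hypothetical, and we assume that sensai.torch.Tensoriser exposes a fit hook with the signature shown (the rule-based tensoriser we used earlier requires no fitting).

class MaxNormalisingImageTensoriser(sensai.torch.Tensoriser):  # hypothetical class for illustration
    def __init__(self):
        self.max_value = None

    def fit(self, df: pd.DataFrame, model=None):  # fit hook signature assumed
        # learn the normalisation constant from the training data
        self.max_value = df.values.max()

    def _tensorise(self, df: pd.DataFrame) -> torch.Tensor:
        # restore the 2D images and normalise by the learnt maximum
        images = np.stack([reshape_2d_image(row) for _, row in df.iterrows()])
        return torch.tensor(images).float() / self.max_value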

Let’s instantiate our model and evaluate it.

[16]:
cnn_model = CnnModel(cuda=False, kernel_size=5, num_conv=32, pooling_kernel_size=2, mlp_hidden_dims=(200,20),
        nn_optimiser_params=nn_optimiser_params) \
    .with_name("CNN'") \
    .with_input_tensoriser(ImageReshapingInputTensoriser())

eval_data = eval_util.perform_simple_evaluation(cnn_model)
DEBUG 2024-04-30 08:31:56,187 sensai.evaluation.evaluator:__init__:182 - <sensai.data.DataSplitterFractional object at 0x7fe388bb8a90> created split with 48000 (80.00%) and 12000 (20.00%) training and test data points respectively
INFO  2024-04-30 08:31:56,188 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating CnnModel[cuda=False, output_activation_fn=ActivationFunction.LOG_SOFTMAX, kernel_size=5, num_conv=32, pooling_kernel_size=2, mlp_hidden_dims=(200, 20), p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=functools.partial(<class '__main__.CnnModel.VectorTorchModel'>, CnnModel[id=140615177024144, cuda=False, output_activation_fn=ActivationFunction.LOG_SOFTMAX, kernel_size=5, num_conv=32, pooling_kernel_size=2, mlp_hidden_dims=(200, 20), p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=..., normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=None, inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN']), normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=None, inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN'] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe385ad3710>
INFO  2024-04-30 08:31:56,188 sensai.vector_model:fit:371 - Fitting CnnModel instance
DEBUG 2024-04-30 08:31:56,211 sensai.vector_model:fit:394 - Fitting with outputs[1]=['label'], inputs[784]=[1x1/int64, 1x2/int64, 1x3/int64, ..., 28x27/int64, 28x28/int64]; N=48000 data points
INFO  2024-04-30 08:31:56,212 sensai.torch.torch_base:_fit_classifier:780 - Fitting <__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090> ...
INFO  2024-04-30 08:31:56,271 sensai.torch.torch_opt.NNOptimiser:fit:682 - Preparing parameter learning of CnnModel.VectorTorchModel[cuda=False, inputDim=784, outputDim=10, bestEpoch=None, totalEpochs=None] via NNOptimiser[params=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True]] with cuda=False
INFO  2024-04-30 08:31:56,272 sensai.torch.torch_opt.NNOptimiser:fit:716 - Obtaining input/output training instances
INFO  2024-04-30 08:31:58,730 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Data set 1/1: #train=36000, #validation=12000
INFO  2024-04-30 08:31:58,731 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Number of validation sets: 1
INFO  2024-04-30 08:31:58,736 sensai.torch.torch_opt.NNOptimiser:fit:746 - Learning parameters of CnnModel.VectorTorchModel[cuda=False, inputDim=784, outputDim=10, bestEpoch=None, totalEpochs=None]
INFO  2024-04-30 08:31:58,737 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Number of parameters: 926862
INFO  2024-04-30 08:31:58,738 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Starting training process via NNOptimiser[params=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True]]
INFO  2024-04-30 08:31:58,748 sensai.torch.torch_opt.NNOptimiser:fit:764 - Begin training with cuda=False
INFO  2024-04-30 08:31:58,749 sensai.torch.torch_opt.NNOptimiser:fit:765 - Press Ctrl+C to end training early
INFO  2024-04-30 08:32:08,447 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   1/1000 completed in  9.70s | train loss 0.8259 | validation NLL 0.2525 | best NLL 0.252539 from this epoch
INFO  2024-04-30 08:32:18,180 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   2/1000 completed in  9.72s | train loss 0.1533 | validation NLL 0.1433 | best NLL 0.143283 from this epoch
INFO  2024-04-30 08:32:27,916 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   3/1000 completed in  9.73s | train loss 0.0881 | validation NLL 0.0999 | best NLL 0.099880 from this epoch
INFO  2024-04-30 08:32:37,697 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   4/1000 completed in  9.77s | train loss 0.0596 | validation NLL 0.1221 | best NLL 0.099880 from epoch 3
INFO  2024-04-30 08:32:47,434 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   5/1000 completed in  9.74s | train loss 0.0452 | validation NLL 0.0962 | best NLL 0.096233 from this epoch
INFO  2024-04-30 08:32:57,257 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   6/1000 completed in  9.81s | train loss 0.0380 | validation NLL 0.0974 | best NLL 0.096233 from epoch 5
INFO  2024-04-30 08:33:07,038 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   7/1000 completed in  9.78s | train loss 0.0294 | validation NLL 0.0986 | best NLL 0.096233 from epoch 5
INFO  2024-04-30 08:33:16,921 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Epoch   8/1000 completed in  9.88s | train loss 0.0223 | validation NLL 0.1068 | best NLL 0.096233 from epoch 5
INFO  2024-04-30 08:33:16,922 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Stopping early: 3 epochs without validation metric improvement
INFO  2024-04-30 08:33:16,923 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Training complete
INFO  2024-04-30 08:33:16,923 sensai.torch.torch_opt.NNOptimiser:training_log:697 - Best model is from epoch 5 with NLL 0.09623306624094645 on validation set
INFO  2024-04-30 08:33:16,935 sensai.vector_model:fit:400 - Fitting completed in 80.75 seconds: CnnModel[cuda=False, output_activation_fn=ActivationFunction.LOG_SOFTMAX, kernel_size=5, num_conv=32, pooling_kernel_size=2, mlp_hidden_dims=(200, 20), p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=functools.partial(<class '__main__.CnnModel.VectorTorchModel'>, CnnModel[id=140615177024144, cuda=False, output_activation_fn=ActivationFunction.LOG_SOFTMAX, kernel_size=5, num_conv=32, pooling_kernel_size=2, mlp_hidden_dims=(200, 20), p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=..., normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=CnnModel.VectorTorchModel[cuda=False, inputDim=784, outputDim=10, bestEpoch=5, totalEpochs=8], inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN']), normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=CnnModel.VectorTorchModel[cuda=False, inputDim=784, outputDim=10, bestEpoch=5, totalEpochs=8], inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN']
DEBUG 2024-04-30 08:33:16,936 sensai.torch.torch_data:__init__:546 - Applying <__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090> to data frame of length 12000 ...
INFO  2024-04-30 08:33:18,708 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.97775, balancedAccuracy=0.9776402722798222, N=12000]
_images/neural_networks_32_1.png
_images/neural_networks_32_2.png

Our CNN models do improve upon the MLP model we evaluated earlier. Let's compare all the models we have trained thus far:

[17]:
comparison_data = eval_util.compare_models([torch_mlp_model, cnn_model_from_module, cnn_model, random_forest_model], fit_models=False)
comparison_data.results_df
INFO  2024-04-30 08:33:19,874 sensai.evaluation.eval_util:compare_models:393 - Evaluating model 1/4 named 'MLP' ...
DEBUG 2024-04-30 08:33:20,084 sensai.evaluation.evaluator:__init__:182 - <sensai.data.DataSplitterFractional object at 0x7fe388bb8a90> created split with 48000 (80.00%) and 12000 (20.00%) training and test data points respectively
INFO  2024-04-30 08:33:20,085 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating MultiLayerPerceptronVectorClassificationModel[hidden_dims=(50, 20), hid_activation_function=<built-in method sigmoid of type object at 0x7fe329788880>, output_activation_function=ActivationFunction.LOG_SOFTMAX, input_dim=None, cuda=False, p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=Method[_create_torch_model], normalisationMode=NormalisationMode.MAX_ALL, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=54, optimiser_lr=0.001, shrinkage_clip=10.0, optimiser=adam, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=2, shuffle=True], model=MultiLayerPerceptronTorchModel[cuda=False, inputDim=784, outputDim=10, hidActivationFunction=<built-in method sigmoid of type object at 0x7fe329788880>, outputActivationFunction=functools.partial(<function log_softmax at 0x7fe3886bb8c0>, dim=1), hiddenDims=(50, 20), pDropout=0.0, overrideInputDim=None, bestEpoch=16, totalEpochs=18], inputTensoriser=None, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=MLP] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe386290090>
DEBUG 2024-04-30 08:33:20,085 sensai.torch.torch_data:__init__:546 - Applying <sensai.torch.torch_data.TensoriserDataFrameFloatValuesMatrix object at 0x7fe386fd2d90> to data frame of length 12000 ...
INFO  2024-04-30 08:33:20,218 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.96225, balancedAccuracy=0.9618968610010363, N=12000]
INFO  2024-04-30 08:33:20,629 sensai.evaluation.eval_util:compare_models:393 - Evaluating model 2/4 named 'CNN' ...
INFO  2024-04-30 08:33:20,630 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating TorchVectorClassificationModel[featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=<sensai.torch.torch_base.TorchModelFactoryFromModule object at 0x7fe385b7ab90>, normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=TorchModelFromModule[cuda=False, bestEpoch=6, totalEpochs=9], inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385b6a150>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe386290090>
DEBUG 2024-04-30 08:33:20,631 sensai.torch.torch_data:__init__:546 - Applying <__main__.ImageReshapingInputTensoriser object at 0x7fe385b6a150> to data frame of length 12000 ...
INFO  2024-04-30 08:33:22,446 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.9784166666666667, balancedAccuracy=0.9782716826784611, N=12000]
INFO  2024-04-30 08:33:22,763 sensai.evaluation.eval_util:compare_models:393 - Evaluating model 3/4 named 'CNN'' ...
INFO  2024-04-30 08:33:22,764 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating CnnModel[cuda=False, output_activation_fn=ActivationFunction.LOG_SOFTMAX, kernel_size=5, num_conv=32, pooling_kernel_size=2, mlp_hidden_dims=(200, 20), p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=functools.partial(<class '__main__.CnnModel.VectorTorchModel'>, CnnModel[id=140615177024144, cuda=False, output_activation_fn=ActivationFunction.LOG_SOFTMAX, kernel_size=5, num_conv=32, pooling_kernel_size=2, mlp_hidden_dims=(200, 20), p_dropout=0.0, featureGenerator=None, outputMode=ClassificationOutputMode.LOG_PROBABILITIES, torch_model_factory=..., normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=CnnModel.VectorTorchModel[cuda=False, inputDim=784, outputDim=10, bestEpoch=5, totalEpochs=8], inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN']), normalisationMode=NormalisationMode.NONE, nnOptimiserParams=NNOptimiserParams[epochs=1000, batch_size=1024, optimiser_lr=0.01, shrinkage_clip=10.0, optimiser=Optimiser.ADAMW, gpu=None, train_fraction=0.75, scaled_outputs=False, loss_evaluator=NNLossEvaluatorClassification[LossFunction.NLL], optimiser_args={}, use_shrinkage=True, early_stopping_epochs=3, shuffle=True], model=CnnModel.VectorTorchModel[cuda=False, inputDim=784, outputDim=10, bestEpoch=5, totalEpochs=8], inputTensoriser=<__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090>, outputTensoriser=None, torchDataSetProviderFactory=None, dataFrameSplitter=None, name=CNN'] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe386290090>
DEBUG 2024-04-30 08:33:22,765 sensai.torch.torch_data:__init__:546 - Applying <__main__.ImageReshapingInputTensoriser object at 0x7fe385ad3090> to data frame of length 12000 ...
INFO  2024-04-30 08:33:24,597 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.97775, balancedAccuracy=0.9776402722798222, N=12000]
INFO  2024-04-30 08:33:24,861 sensai.evaluation.eval_util:compare_models:393 - Evaluating model 4/4 named 'RandomForest' ...
INFO  2024-04-30 08:33:24,862 sensai.evaluation.eval_util:perform_simple_evaluation:281 - Evaluating SkLearnRandomForestVectorClassificationModel[featureGenerator=None, fitArgs={}, useBalancedClassWeights=False, useLabelEncoding=False, name=RandomForest, model=RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None, criterion='gini', max_depth=None, max_features='auto', max_leaf_nodes=None, max_samples=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None, oob_score=False, random_state=42, verbose=0, warm_start=False)] via <sensai.evaluation.evaluator.VectorClassificationModelEvaluator object at 0x7fe386290090>
INFO  2024-04-30 08:33:24,937 sensai.evaluation.eval_util:gather_results:289 - Evaluation results for label: ClassificationEvalStats[accuracy=0.9466666666666667, balancedAccuracy=0.945916926388699, N=12000]
INFO  2024-04-30 08:33:25,202 sensai.evaluation.eval_util:compare_models:462 - Model comparison results:
              accuracy  balancedAccuracy
model_name
MLP           0.962250          0.961897
CNN           0.978417          0.978272
CNN'          0.977750          0.977640
RandomForest  0.946667          0.945917
[17]:
              accuracy  balancedAccuracy
model_name
MLP           0.962250          0.961897
CNN           0.978417          0.978272
CNN'          0.977750          0.977640
RandomForest  0.946667          0.945917
_images/neural_networks_34_2.png
_images/neural_networks_34_3.png
_images/neural_networks_34_4.png
_images/neural_networks_34_5.png
_images/neural_networks_34_6.png
_images/neural_networks_34_7.png
_images/neural_networks_34_8.png
_images/neural_networks_34_9.png

Note that any differences between the two CNN models are due solely to randomness in parameter initialisation and batch shuffling; they are functionally identical.

Could the CNN model have produced even better results? Let’s take a look at some examples where the CNN model went wrong by inspecting the evaluation data that was returned earlier.

[18]:
misclassified = eval_data.get_misclassified_triples_pred_true_input()
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(9, 9))
for i, (pred_class, true_class, input_series) in enumerate(misclassified[:9]):
    axs[i // 3][i % 3].imshow(reshape_2d_image(input_series), cmap="binary")
    axs[i // 3][i % 3].set_title(f"{true_class} misclassified as {pred_class}")
plt.tight_layout()
_images/neural_networks_36_0.png

While some of these examples are indeed ambiguous, there is still room for improvement.
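One readily accessible knob is the dropout probability which our CnnModel already exposes. As a sketch (the value below is illustrative, not tuned), we might evaluate a regularised variant:

cnn_model_dropout = CnnModel(cuda=False, kernel_size=5, num_conv=32, pooling_kernel_size=2,
        mlp_hidden_dims=(200, 20), p_dropout=0.2,  # illustrative value, not tuned
        nn_optimiser_params=nn_optimiser_params) \
    .with_name("CNN-dropout") \
    .with_input_tensoriser(ImageReshapingInputTensoriser())
eval_util.perform_simple_evaluation(cnn_model_dropout)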