Training Neural Networks with JAX

JAX is a Python library designed to accelerate machine learning research using accelerators such as GPUs and TPUs. Thanks to its speed and efficiency, combined with the familiarity of Python and NumPy, it has been widely adopted by machine learning researchers. Besides training neural networks faster, JAX can also reduce memory and energy costs. In this tutorial, we'll use JAX to create a simple neural network and apply it to a regression task. If you are new to JAX, this article here is a solid introduction. For our example, we'll use a small dataset available from Yellowbrick, an open-source, pure-Python project that extends scikit-learn with visual analysis and diagnostic tools.

What is a neural network?

Simply put, a neural network is a mathematical function that maps a given input, in conjunction with information from other nodes, to an output. It is loosely inspired by the structure of the human brain. This tutorial won't cover the basics of neural networks in depth, so if you're new to them, I refer you to this article here.

Regression

Regression is a method of investigating the relationship between independent variables (features) and a dependent variable (outcome). It is used for predictive modeling in machine learning, where an algorithm predicts continuous outcomes. We will create a neural network with JAX to solve a regression task on the concrete compressive strength dataset from Yellowbrick. Below we import JAX and two of its submodules, stax and optimizers, which we will use to build and train the network. We also import the jax.numpy module, which we need to convert input data to JAX arrays and for a few other calculations. Here is the link to our notebook.
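A minimal import cell for this tutorial might look like the following sketch; it assumes a recent JAX release, where stax and optimizers live under jax.example_libraries, and also pulls in the scikit-learn helpers used later.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax, optimizers

from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
```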

Load dataset

We first load the concrete strength dataset available from Yellowbrick. Concrete is one of the most important materials in civil engineering, and its compressive strength is a highly nonlinear function of age and ingredients; predicting it is our regression problem. We load the data features into variable X and the target values into variable Y. We split the dataset into train (80%) and test (20%) sets, roughly following the Pareto principle, which states that "for many events, roughly 80% of the effects come from 20% of the causes". After dividing the dataset, we convert each NumPy array to a JAX array using the jax.numpy.array() constructor. We also print the shapes of the train and test sets at the end.
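A sketch of this loading step, assuming Yellowbrick's load_concrete() helper and scikit-learn's train_test_split:

```python
from yellowbrick.datasets import load_concrete

# Load features (X) and target compressive strength (Y)
X, Y = load_concrete()

# 80/20 train/test split
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, train_size=0.8, random_state=123)

# Convert to JAX arrays
X_train = jnp.array(X_train.values, dtype=jnp.float32)
X_test = jnp.array(X_test.values, dtype=jnp.float32)
Y_train = jnp.array(Y_train.values, dtype=jnp.float32)
Y_test = jnp.array(Y_test.values, dtype=jnp.float32)

print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)
```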

Normalize data

To normalize the data, we first calculate the mean and standard deviation of each feature on the training dataset. We then subtract the mean from both the training and testing sets and divide the result by the standard deviation. The main reason to normalize data is to bring the values of all features onto roughly the same scale, which helps the gradient descent optimization algorithm converge faster. When features are on very different scales, training can take longer because gradient descent has a harder time converging.
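A sketch of the normalization step, with the statistics computed on the training set only:

```python
# Per-feature statistics from the training data
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)

# Apply the same transformation to both splits
X_train = (X_train - mean) / std
X_test = (X_test - mean) / std
```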

Creating the Neural Network

The stax module of JAX (available from jax.example_libraries) provides various ready-made layers that we can stack together to create a neural network. The process is much like building a network with Keras's Sequential() API. The stax module provides a serial() method that accepts a list of layers and activation functions as input and creates a neural network; it applies them in the given order when performing a forward pass through the data. Using the Dense() method we can create fully connected layers. We can also provide our own weight and bias initialization functions if we don't want the default initialization performed by JAX when the layer is created with Dense().

Most stax module methods return 2 callable functions as output when executed:

  1. init_fun — This function takes a seed for weight initialization and the input shape of the layer/network as input. It returns the output shape together with the initialized weights and biases: a single (weights, biases) pair for one layer, or a list of such pairs for a whole network.

  2. apply_fun — This function takes the weights and biases of the layer/network and the data as input. It then executes the layer/network on the input data using those weights, i.e. it performs a forward pass through the network.

All activation functions are available as attributes of the stax module, and we don't call them with brackets; we simply pass them to the serial() method after the layers they should follow, and they will be applied to those layers' outputs.

Below is an example of a Dense() layer with 5 units, showing the output it returns. Notice that it returns the two callables described above.
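A sketch of what that cell might look like; the 8-feature input shape matches the concrete dataset, and the variable names are illustrative.

```python
# A single fully connected layer with 5 units returns two callables
layer_init, layer_apply = stax.Dense(5)
print(layer_init, layer_apply)

# Initialize it with a seed and an input shape (-1 stands for the batch dimension)
rng = jax.random.PRNGKey(123)
out_shape, (W, b) = layer_init(rng, (-1, 8))
print(out_shape, W.shape, b.shape)   # (-1, 5) (8, 5) (5,)
```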

We then create our neural network with layer sizes [5, 10, 15, 1]; the last layer is the output layer and all others are hidden layers. Each hidden layer is created with the Dense() method followed by the Relu (Rectified Linear Unit) activation function, and all layers and activations are given to the serial() method in sequence. The Relu function takes an array as input and returns a new array of the same size in which all values less than 0 are replaced by 0.
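A sketch of the network definition under those assumptions:

```python
# Layer sizes [5, 10, 15, 1]; Relu after each hidden layer, no activation on the output
init_fun, apply_fun = stax.serial(
    stax.Dense(5),  stax.Relu,
    stax.Dense(10), stax.Relu,
    stax.Dense(15), stax.Relu,
    stax.Dense(1),
)
```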

By calling the init_fun() function we initialize the weights of our neural network. We give it a seed (jax.random.PRNGKey(123)) and the input data shape; this information is used to initialize the weights and biases of each layer of the network.

Below I have printed the weights and biases for each layer after initializing the weights.
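A sketch of that initialization and the printed shapes (variable names are illustrative; activation layers carry no parameters, so they are skipped):

```python
rng = jax.random.PRNGKey(123)
out_shape, weights = init_fun(rng, (-1, X_train.shape[1]))

# weights holds one (W, b) tuple per Dense layer
for i, layer in enumerate(weights):
    if layer:
        W, b = layer
        print(f"Layer {i}: weights {W.shape}, biases {b.shape}")
```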

We can now perform a forward pass through our neural network. To do so, we take a few samples of our data and give them to the apply_fun() function along with the weights: the weights come first, followed by a small batch of data. apply_fun() then performs one forward pass through the data using the weights and returns the predictions.
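For example, on the first five training samples:

```python
# Weights first, then a small batch of data
preds = apply_fun(weights, X_train[:5])
print(preds)
```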

Define loss function

In this part, we define the loss whose gradient with respect to the weights we will later use to update those weights. We'll use mean squared error (MSE) as our loss function: it subtracts the predictions from the actual values, squares the differences, and takes their mean. Our loss function takes weights, data, and actual target values as input. It performs a forward pass through the neural network using the apply_fun() function, providing the weights and data to it, stores the network's predictions in a variable, and then calculates the MSE from the actual target values and the predictions.
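A minimal sketch of such a loss function, assuming the apply_fun returned by serial() above:

```python
def mse_loss(weights, X, Y):
    preds = apply_fun(weights, X)        # forward pass with the given weights
    preds = preds.squeeze()              # (n, 1) -> (n,)
    return jnp.mean((Y - preds) ** 2)    # mean squared error
```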

Train Neural Network

We now create a function that we will call to train our neural network. It takes the data features, target values, number of epochs, and optimizer state as input. The optimizer state is an object created by the optimizer that holds our model's weights.

The function loops for the given number of epochs. Each iteration first calculates the loss value and gradients using the value_and_grad() function. value_and_grad() takes another function as input, the MSE loss function in our case, and returns a new callable which, when called, returns both the value of the function and its gradient with respect to the first parameter, the weights in our case. So we give our loss function to value_and_grad() and call the returned function with the weights, data features, and target values; the call returns the MSE value and the gradients for the weights and biases of each layer of our neural network.

We then call the optimizer's state update method, which takes the epoch number, the gradients, and the current optimizer state (holding the current weights) as inputs. It returns a new optimizer state in which the weights have been updated by subtracting the learning rate times the gradients. We print the MSE every 100 epochs to keep track of training progress, and finally we return the last optimizer state (the final updated weights).
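A sketch of this training routine; the name TrainModel is illustrative, and opt_update and params_fn refer to the optimizer callables introduced in the next section:

```python
def TrainModel(X, Y, epochs, opt_state):
    for epoch in range(1, epochs + 1):
        # MSE value and gradients w.r.t. the weights, computed in one call
        loss, gradients = jax.value_and_grad(mse_loss)(params_fn(opt_state), X, Y)
        # The optimizer returns a new state with updated weights
        opt_state = opt_update(epoch, gradients, opt_state)
        if epoch % 100 == 0:
            print(f"MSE after epoch {epoch}: {float(loss):.3f}")
    return opt_state
```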

Before initializing an optimizer for our neural network, let's look at what an optimizer is. The optimizer is the algorithm responsible for finding the minimum of our loss function. The optimizers module, available from the example_libraries module of JAX, provides a list of different optimizers. In our case we'll use the sgd() (gradient descent) optimizer, initialized with a learning rate of 0.001.

Calling sgd() returns three callables necessary for maintaining and updating the weights of the neural network.

  1. init — This function takes weights of a neural network as input and returns the OptimizerState object which is a wrapper for holding and updating weights.

  2. update_fn — This function takes the epoch number, gradients, and optimizer state as input. It updates the weights held in the optimizer state by subtracting the learning rate times the gradients, and returns a new OptimizerState object with the updated weights.

  3. params_fn — This function takes the OptimizerState object as input and returns the actual weights of the neural network.

Here we train the neural network with the function we created earlier. After initializing the optimizer with the weights, we call our training routine, providing the data, target values, number of epochs, and optimizer state (weights). We train the network for 2500 epochs.
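A sketch of these two steps, assuming the weights and the TrainModel function from the earlier cells:

```python
# SGD optimizer with a learning rate of 0.001
opt_init, opt_update, params_fn = optimizers.sgd(step_size=0.001)

# Wrap the initialized weights in an optimizer state and train for 2500 epochs
opt_state = opt_init(weights)
opt_state = TrainModel(X_train, Y_train, 2500, opt_state)
```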


As we can see from the MSE printed every 100 epochs, the model is getting better at the task.

Make Predictions

In this section, we make predictions for both the train and test datasets. We retrieve the weights of the neural network using the params_fn optimizer function, and then give the weights and data features to the apply_fun method, which makes the predictions.
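A sketch of the prediction step:

```python
# Retrieve the trained weights from the optimizer state
weights = params_fn(opt_state)

# Forward pass on the full train and test sets
train_preds = apply_fun(weights, X_train).squeeze()
test_preds = apply_fun(weights, X_test).squeeze()
```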

Evaluating the Model performance

Here we evaluate how our model is actually performing by calculating the R² score for both the train and test predictions, using the r2_score() method of scikit-learn. The score generally falls in the range [0, 1] (it can be negative for very poor fits), and a value near 1 indicates a good model.
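For example:

```python
from sklearn.metrics import r2_score

print("Train R²:", r2_score(Y_train, train_preds))
print("Test  R²:", r2_score(Y_test, test_preds))
```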

The R² scores suggest that our model is doing a good job.

Train the Model on Batches of Data

Some datasets are too large to fit into the main memory of the computer. In such cases, we bring only a small batch of data into main memory at a time and train the model batch by batch until the whole dataset is covered. The optimization algorithm used this way is referred to as stochastic gradient descent, as it works on a small batch of data at a time.

The function below takes the data features, target values, number of epochs, optimizer state (weights), and batch size (default 32) as input. It loops for the given number of epochs and, within each epoch, computes the start and end indices of each batch, performs a forward pass, calculates the loss, and updates the weights on that single batch before moving to the next, until the whole dataset is covered. When training in batches, the weights are therefore updated once per batch rather than once per epoch.
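A sketch of such a batched training routine; the name TrainModelInBatches is illustrative, and it reuses the mse_loss, opt_update, and params_fn defined above:

```python
def TrainModelInBatches(X, Y, epochs, opt_state, batch_size=32):
    n = X.shape[0]
    for epoch in range(1, epochs + 1):
        losses = []
        for start in range(0, n, batch_size):
            X_batch = X[start:start + batch_size]
            Y_batch = Y[start:start + batch_size]
            # Loss and gradients on this batch only, followed by a weight update
            loss, gradients = jax.value_and_grad(mse_loss)(params_fn(opt_state), X_batch, Y_batch)
            opt_state = opt_update(epoch, gradients, opt_state)
            losses.append(loss)
        if epoch % 100 == 0:
            print(f"MSE after epoch {epoch}: {float(jnp.mean(jnp.array(losses))):.3f}")
    return opt_state
```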

Now that we have the function for training our neural network in batches, we initialize the weights of the network using init_fun, giving it a seed and the input shape. Next, we initialize our optimizer by calling the sgd() function with a learning rate of 0.001, create the initial optimizer state from the weights, and then call the function from the previous cell to perform training in batches. We train the neural network for 500 epochs.
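A sketch of those steps:

```python
# Fresh weights, optimizer, and optimizer state, then 500 epochs of mini-batch training
out_shape, weights = init_fun(jax.random.PRNGKey(123), (-1, X_train.shape[1]))
opt_init, opt_update, params_fn = optimizers.sgd(step_size=0.001)
opt_state = opt_init(weights)
opt_state = TrainModelInBatches(X_train, Y_train, 500, opt_state, batch_size=32)
```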

Making predictions in batches

Because all the data cannot fit into main memory, we also make predictions in batches. Below is a function that takes the weights and data as input and makes predictions on the data batch by batch.
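A sketch of such a function; the name MakePredictionsInBatches is illustrative:

```python
def MakePredictionsInBatches(weights, X, batch_size=32):
    preds = []
    for start in range(0, X.shape[0], batch_size):
        batch_preds = apply_fun(weights, X[start:start + batch_size])
        preds.append(batch_preds)
    # Stitch the per-batch predictions back into one array
    return jnp.concatenate(preds).squeeze()
```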

We call the function above to make predictions on the train and test datasets in batches and combine the predictions of the batches.
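For example:

```python
weights = params_fn(opt_state)
train_preds = MakePredictionsInBatches(weights, X_train, batch_size=32)
test_preds = MakePredictionsInBatches(weights, X_test, batch_size=32)
```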

Evaluate model performance

We will calculate our R² score on train and test predictions to see how our network is performing.

Conclusion

You can try the example above yourself, and even adapt it to other tasks such as classification, which uses very similar code. With modules like stax and optimizers, JAX lets you write less code, which helps with efficiency. There are more JAX libraries and modules to explore that may improve your machine learning research, and depending on your field, JAX can substantially speed up that research.

References

https://coderzcolumn.com/tutorials/artificial-intelligence/create-neural-networks-using-high-level-jax-api

https://coderzcolumn.com/tutorials/artificial-intelligence/guide-to-create-simple-neural-networks-using-jax

https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
