Torch for R Now in the qeML Package

I’ve added a new function, qeNeuralTorch, to the qeML package, as an alternative to the package’s qeNeural. It is experimental at this point, but usable, and I urge everyone to try it out. In this post, I will (a) state why I felt it desirable to add such a function, (b) show a couple of examples, (c) explain how the function works, thereby giving you an introduction to Torch, and finally (d) explain what at present appears to be a fundamental problem with Torch.

If you need an introduction to neural networks, see the ML overview vignette in qeML.

Why an alternative to qeNeural?

The qeNeural function is a wrapper for regtools::krsFit, which is based on the R keras package. That in turn calls functions in the tensorflow package, which interfaces to the original Python library of the same name.

Though there has been much progress in interfacing R to Python, the interface is difficult to set up, especially in terms of environment variables. That makes the “Torch for R” package, torch, highly attractive, as it contains no Python. Torch itself was developed (in Python) as a clearer alternative to TensorFlow, and is now probably more popular than the latter. So it’s nice that a version for R has been produced, especially one that is Python-free.

Another advantage is easier access to GPUs, notably Mac GPUs, which are generally problematic. There is a companion R package, luz, that can be used to speed up Torch computation.

Examples

Here we use the svcensus data from qeML.

lyrsReg <- list(
   list('linear',0,100),
   list('relu'),
   list('linear',100,100),
   list('relu'),
   list('linear',100,1))

z <- qeNeuralTorch(svcensus,'wageinc',layers=lyrsReg,
   learnRate=0.05)

The call format is standard qeML. We need to specify the network structure via its layers, hence an argument of that name. (The qeNeural function does this a little differently.) We see a linear, i.e. fully-connected, layer with “0” inputs and 100 outputs, then a ReLU activation function, then another hidden layer and activation, and finally another linear layer with 1 output, our predicted values. The “0” is filled in at runtime with the number of features.

The above is for regression problems; here is code for classification problems:

lyrsClass <- list(
   list('linear',0,100),
   list('relu'),
   list('linear',100,100),
   list('relu'),
   list('linear',100,1),
   list('sigmoid'))

z <- qeNeuralTorch(svcensus,'gender',yesYVal='male',
   layers=lyrsClass,learnRate=0.003)

That last entry in the layers specification squeezes the result into the interval (0,1), to produce probabilities. Note that this list-of-lists code merely defines the network; it does not create it.

How is Torch used in qeNeuralTorch?

One still must work with tensors, which are essentially multidimensional arrays. So there are lines in the function like

xT <- torch_tensor(x)

where the matrix of features is converted to a tensor. The major work, though, is done in first setting up the network and then running the data through it iteratively. Note that one need not be aware of any of this to use qeNeuralTorch, but as usual, understanding the innards of a function sharpens one’s ability to use it well.
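As a quick standalone illustration of tensor conversion (this snippet is just for exposition, not taken from qeNeuralTorch):

library(torch)
x <- matrix(rnorm(6),nrow=3)  # a small numeric feature matrix
xT <- torch_tensor(x)         # convert to a Torch tensor
dim(xT)                       # 3 2, mirroring dim(x)
as_array(xT)                  # and back to an ordinary R matrix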

The network-formation code within qeNeuralTorch works on the above list of lists to form nnSeqArgs, which again simply defines the network for Torch. Then the network is created:

model <- do.call(nn_sequential,nnSeqArgs)

A side note is that torch::nn_sequential has the formal argument ‘…’. That is no problem for the ordinary author of Torch code; if they have, say, 4 arguments in their particular application, they just state them explicitly in a call to nn_sequential. But qeNeuralTorch must allow for a variable number of network layers, hence the use of R’s do.call.
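To give a feel for what nnSeqArgs might contain, here is a simplified sketch of the translation; the helper buildSeqArgs and its details are hypothetical, not the actual qeNeuralTorch code:

# hypothetical helper: turn the list-of-lists layer spec into
# constructed layer objects for nn_sequential
buildSeqArgs <- function(layers,nFeatures) {
   lapply(layers,function(lyr)
      switch(lyr[[1]],
         linear = nn_linear(if (lyr[[2]] == 0) nFeatures else lyr[[2]],lyr[[3]]),
         relu = nn_relu(),
         sigmoid = nn_sigmoid()))
}

nnSeqArgs <- buildSeqArgs(lyrsReg,nFeatures=ncol(x))
model <- do.call(nn_sequential,nnSeqArgs)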

Here is the code that runs the network:

for (i in 1:nEpochs) {
   preds <- model(xT)                               # forward pass
   loss <- nnf_mse_loss(preds,yT,reduction = "sum") # current L2 loss
   optimizer$zero_grad()                            # clear old gradients
   loss$backward()                                  # gradient of the loss
   optimizer$step()                                 # update the weights
}

First, our feature data xT is run through the network, which has been initialized with random weights. That produces predictions, preds. The current L2 loss is then computed, the gradient of the loss is determined, and the weights are updated. This repeats for the number of iterations specified by the user, nEpochs.
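The loop assumes that a response tensor yT and an optimizer object have been created beforehand. A minimal sketch of that setup (the actual optimizer and settings used in qeNeuralTorch may differ):

yT <- torch_tensor(as.matrix(y))                         # response as a tensor
optimizer <- optim_adam(model$parameters,lr=learnRate)   # a gradient-descent variant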

Our implementation is rather primitive; we use that same L2 loss function even for the classification case (actually this can be justified), and, for now, we are limited to the 2-class case.

Torch for R uses the R6 class structure, which is rather different from the more common S3 and S4. An example above is the line

loss$backward()

Here loss is an object, mainly containing the current loss value but also containing a function backward. The latter is called on the former, as those who’ve used, e.g., Python or Java will recognize.
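A tiny standalone example of this object$method syntax, using Torch’s automatic differentiation (again, not code from qeNeuralTorch):

v <- torch_tensor(3,requires_grad=TRUE)
u <- v^2          # u is a tensor object holding the value 9
u$backward()      # R6-style method call; computes du/dv
v$grad            # gradient of u with respect to v, i.e. 2*3 = 6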

Again, you need not know this in order to use qeNeuralTorch.

Performance

The package seems to be very sensitive to the learning rate.

Also, it turns out, at least in my implementation, that Torch’s accuracy is generally much weaker than that of other qeML functions in regression cases, but similar in classification cases.

I surmised that this was due to Torch producing overly large weights, and investigated by comparing the L2 norms of its predictions with those of other functions. Sure enough, Torch was producing larger predictions.

Torch has a parameter, weight_decay, to control this via regularization. However, it did not appear to help.
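For reference, weight decay is specified when the optimizer is created; a hedged example (the particular optimizer and value here are purely illustrative):

optimizer <- optim_adam(model$parameters,lr=0.05,weight_decay=0.001)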

My current guess (maybe someone who reads this will have suggestions) is that since ML applications tend more toward the classification case, the problem of large weights never really arose. Having predictions that are too extreme may not hurt in classification problems, as this simply pushes them closer to 0 or 1 when run through a sigmoid or similar function. Since qeNeuralTorch rounds the sigmoid output to 0 or 1 to produce class predictions, it all may work out well in the end.

Note that this also means one should be cautious if one takes the unrounded output of the network to be the actual probabilities.
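A small numerical illustration of that point (plain R, not package code):

sigmoid <- function(u) 1 / (1 + exp(-u))
z <- c(0.8,-1.2)       # moderate linear-layer outputs
zBig <- 10 * z         # same signs, but inflated
sigmoid(z)             # about 0.69 and 0.23 -- plausible probabilities
sigmoid(zBig)          # about 0.9997 and 0.000006 -- extreme “probabilities”
round(sigmoid(z))      # 1 0
round(sigmoid(zBig))   # 1 0 -- identical class predictions either way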
