Basic usage.
This example demonstrates the simplest usage of SIRFShampoo. The algorithm works
much like any other torch.optim.Optimizer, but there are a few additional
aspects that are worth knowing.
First, the imports.
from torch import cuda, device, manual_seed
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from sirfshampoo import SIRFShampoo
manual_seed(0) # make deterministic
MAX_STEPS = 200 # quit training after this many steps (or one epoch)
DEV = device("cuda" if cuda.is_available() else "cpu")
Problem Setup
We will train a simple neural network on MNIST using cross-entropy loss:
BATCH_SIZE = 32
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
).to(DEV)
loss_func = CrossEntropyLoss().to(DEV)
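The input size 432 of the first Linear layer follows from the convolution's output shape on 28x28 MNIST images. A quick arithmetic check, using the standard output-size formula for a convolution (the helper name `conv_out_size` is introduced here only for illustration):

```python
def conv_out_size(size, kernel_size, stride, padding=0):
    """Spatial output size of a convolution along one axis."""
    return (size + 2 * padding - kernel_size) // stride + 1


side = conv_out_size(28, kernel_size=5, stride=2)  # MNIST images are 28x28
flat_features = 3 * side * side  # 3 output channels, flattened by Flatten()
print(side, flat_features)
```

If you change the kernel size, stride, or number of channels of the convolution, the first Linear layer's input size has to be adjusted accordingly.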
Optimizer Setup
One difference from many built-in PyTorch optimizers is that SIRFShampoo requires
access to the neural network (a torch.nn.Module). Let's also turn on verbose
initialization to get some insights into the pre-conditioners:
Out:
Parameter groups:
Group 0
- Parameter names: ['0.weight']
- Pre-conditioner: ['3x3 (DenseMatrix)', '5x5 (DenseMatrix)', '5x5 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fdaca0>}
Group 1
- Parameter names: ['0.bias']
- Pre-conditioner: ['3x3 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fdac70>}
Group 2
- Parameter names: ['3.weight']
- Pre-conditioner: ['50x50 (DenseMatrix)', '432x432 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbda60>}
Group 3
- Parameter names: ['3.bias']
- Pre-conditioner: ['50x50 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbd3a0>}
Group 4
- Parameter names: ['5.weight']
- Pre-conditioner: ['10x10 (DenseMatrix)', '50x50 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbd400>}
Group 5
- Parameter names: ['5.bias']
- Pre-conditioner: ['10x10 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbd340>}
SIRFShampoo needs access to the network because it installs a hook on it to
detect the batch size.
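To illustrate the general mechanism, here is a minimal sketch of how a forward pre-hook on a torch.nn.Module can observe the batch size. This is not SIRFShampoo's own code; the names `record_batch_size` and `seen` are made up for this illustration.

```python
import torch
from torch.nn import Linear, ReLU, Sequential

model = Sequential(Linear(4, 3), ReLU(), Linear(3, 2))
seen = {}


def record_batch_size(module, args):
    """Forward pre-hook: store the leading dimension of the first input."""
    seen["batch_size"] = args[0].shape[0]


# The hook runs right before every forward pass of `model`
handle = model.register_forward_pre_hook(record_batch_size)
model(torch.zeros(32, 4))
handle.remove()  # detach the hook once it is no longer needed

print(seen["batch_size"])
```

A plain optimizer that only receives the parameters has no such entry point, which is why the module itself has to be handed over.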
You can also adjust its other arguments, such as learning rates and momenta. See the hyper-parameter documentation for a complete overview.
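Returning to the verbose output above: the pre-conditioner sizes listed for each group can be read off the parameter shapes, with one square factor per tensor axis. The sketch below reproduces the printed sizes. That axes of size 1 (such as the single input channel of the convolution) receive no factor is inferred from that output, not taken from the library's source, and the helper name `kronecker_factor_sizes` is hypothetical.

```python
def kronecker_factor_sizes(shape):
    """One square factor per axis of size > 1, as in the verbose output."""
    return [f"{d}x{d}" for d in shape if d > 1]


# Parameter shapes of the network defined in the problem setup
param_shapes = {
    "0.weight": (3, 1, 5, 5),  # Conv2d(1, 3, kernel_size=5, stride=2)
    "0.bias": (3,),
    "3.weight": (50, 432),  # Linear(432, 50)
    "3.bias": (50,),
    "5.weight": (10, 50),  # Linear(50, 10)
    "5.bias": (10,),
}

for name, shape in param_shapes.items():
    print(name, kronecker_factor_sizes(shape))
```

The largest factor here is the dense 432x432 matrix of the first linear layer, which gives a sense of how the pre-conditioner cost grows with layer width.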
Training
When it comes to training, SIRFShampoo can be used in exactly the same way as
other optimizers (see the PyTorch documentation on torch.optim for an
introduction). Let's train for up to MAX_STEPS steps and print the loss at
regular intervals.
PRINT_LOSS_EVERY = 25 # logging interval
for step, (inputs, target) in enumerate(train_loader):
    optimizer.zero_grad()  # clear gradients from previous iterations

    # regular forward-backward pass
    loss = loss_func(model(inputs.to(DEV)), target.to(DEV))
    loss.backward()
    if step % PRINT_LOSS_EVERY == 0:
        print(f"Step {step}, Loss {loss.item():.3f}")

    optimizer.step()  # update neural network parameters

    if step >= MAX_STEPS:  # don't train a full epoch to keep the example light-weight
        break
Out:
Step 0, Loss 2.317
Step 25, Loss 2.322
Step 50, Loss 2.311
Step 75, Loss 2.297
Step 100, Loss 2.266
Step 125, Loss 2.163
Step 150, Loss 1.978
Step 175, Loss 0.866
Step 200, Loss 1.076
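The first few loss values sit near 2.3, which is what cross-entropy gives when a 10-class model assigns roughly equal probability to every class. A quick check of that reference value:

```python
import math

num_classes = 10  # MNIST digits 0-9

# Cross-entropy of a uniform prediction: -log(1 / num_classes)
uniform_loss = -math.log(1.0 / num_classes)
print(f"{uniform_loss:.3f}")  # 2.303
```

Losses clearly below this value, as in the later steps above, indicate that the network has started to learn.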
Conclusion
You now know the most basic way to train a neural network with SIRFShampoo. From
here, you might be interested in
- checking out the more advanced examples
- taking a closer look at SIRFShampoo's hyper-parameters.