Basic usage.

This example demonstrates the simplest usage of SIRFShampoo. The algorithm works much like any other torch.optim.Optimizer, but there are a few additional aspects that are good to know.

First, the imports.

from torch import cuda, device, manual_seed
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from sirfshampoo import SIRFShampoo

manual_seed(0)  # make deterministic
MAX_STEPS = 200  # quit training after this many steps (or one epoch)
DEV = device("cuda" if cuda.is_available() else "cpu")

Problem Setup

We will train a simple neural network on MNIST using cross-entropy loss:

BATCH_SIZE = 32
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
).to(DEV)
loss_func = CrossEntropyLoss().to(DEV)
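The input size 432 of the first Linear layer follows from the convolution's output shape. As a quick sanity check, the sketch below pushes a dummy MNIST-sized image through the same first layers; the names feature_extractor, dummy_batch, and num_features are only used for this illustration.

```python
import torch
from torch.nn import Conv2d, Flatten, Sequential

# Same first layers as the model above, applied to one dummy MNIST-sized image.
feature_extractor = Sequential(Conv2d(1, 3, kernel_size=5, stride=2), Flatten())
dummy_batch = torch.zeros(1, 1, 28, 28)  # (batch, channels, height, width)

with torch.no_grad():
    num_features = feature_extractor(dummy_batch).shape[1]

# Conv output side length: floor((28 - 5) / 2) + 1 = 12, so 3 * 12 * 12 features.
print(num_features)  # 432
```

If you change the kernel size, stride, or number of output channels, the first Linear layer's input size has to change accordingly.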

Optimizer Setup

Unlike many built-in PyTorch optimizers, SIRFShampoo requires access to the neural network itself (a torch.nn.Module). Let's also turn on verbose initialization to get some insight into the pre-conditioners:

optimizer = SIRFShampoo(model, verbose_init=True)

Out:

Parameter groups:
Group 0
        - Parameter names: ['0.weight']
        - Pre-conditioner: ['3x3 (DenseMatrix)', '5x5 (DenseMatrix)', '5x5 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fdaca0>}
Group 1
        - Parameter names: ['0.bias']
        - Pre-conditioner: ['3x3 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fdac70>}
Group 2
        - Parameter names: ['3.weight']
        - Pre-conditioner: ['50x50 (DenseMatrix)', '432x432 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbda60>}
Group 3
        - Parameter names: ['3.bias']
        - Pre-conditioner: ['50x50 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbd3a0>}
Group 4
        - Parameter names: ['5.weight']
        - Pre-conditioner: ['10x10 (DenseMatrix)', '50x50 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbd400>}
Group 5
        - Parameter names: ['5.bias']
        - Pre-conditioner: ['10x10 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0fbd340>}

SIRFShampoo needs access to the neural network because it installs a hook onto it to detect the batch size.
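To give a rough idea of what such a hook can look like, here is a minimal sketch using PyTorch's register_forward_pre_hook. It is not SIRFShampoo's actual implementation; the BatchSizeRecorder class is a hypothetical helper, and it assumes the first positional input is a tensor whose leading dimension is the batch.

```python
import torch
from torch.nn import Linear, Module


class BatchSizeRecorder:
    """Hypothetical helper that remembers the batch size of the latest forward pass."""

    def __init__(self, module: Module):
        self.batch_size = None
        # Keep the handle so the hook can be removed again later.
        self.handle = module.register_forward_pre_hook(self._record)

    def _record(self, module, inputs):
        # `inputs` is the tuple of positional arguments passed to `module`.
        # Assumption: the first one is a tensor with the batch on dimension 0.
        self.batch_size = inputs[0].shape[0]


layer = Linear(4, 2)
recorder = BatchSizeRecorder(layer)
layer(torch.zeros(32, 4))  # the forward pass triggers the hook
print(recorder.batch_size)  # 32
recorder.handle.remove()  # detach the hook once it is no longer needed
```

A hook like this only sees data that passes through the module's forward call, which is one reason an optimizer relying on it needs the module rather than only a list of parameters.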

Of course, you can also tweak its other arguments, such as learning rates and momenta. See here for a complete overview.
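The verbose output above lists one square pre-conditioner per parameter dimension, and the convolution weight of shape (3, 1, 5, 5) only gets three of them. The sketch below reproduces these sizes under the assumption that dimensions of size 1 are skipped. This rule is inferred from the printed output, not taken from the library's source, and expected_preconditioner_shapes is a hypothetical helper.

```python
from torch.nn import Conv2d, Flatten, Linear, ReLU, Sequential

# Same architecture as in the problem setup (kept on the CPU here).
model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
)


def expected_preconditioner_shapes(param):
    """Guess one square factor per parameter dimension larger than 1 (assumption)."""
    return [f"{d}x{d}" for d in param.shape if d > 1]


shapes = {
    name: expected_preconditioner_shapes(param)
    for name, param in model.named_parameters()
}
for name, dims in shapes.items():
    print(name, dims)
```

For this model the guessed sizes agree with the verbose initialization output, for example 50x50 and 432x432 for 3.weight.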

Training

When it comes to training, SIRFShampoo can be used in exactly the same way as other optimizers (see here for an introduction). Let's train for a couple of steps and print the loss.

PRINT_LOSS_EVERY = 25  # logging interval

for step, (inputs, target) in enumerate(train_loader):
    optimizer.zero_grad()  # clear gradients from previous iterations

    # regular forward-backward pass
    loss = loss_func(model(inputs.to(DEV)), target.to(DEV))
    loss.backward()
    if step % PRINT_LOSS_EVERY == 0:
        print(f"Step {step}, Loss {loss.item():.3f}")

    optimizer.step()  # update neural network parameters

    if step >= MAX_STEPS:  # don't train a full epoch to keep the example light-weight
        break

Out:

Step 0, Loss 2.317
Step 25, Loss 2.322
Step 50, Loss 2.311
Step 75, Loss 2.297
Step 100, Loss 2.266
Step 125, Loss 2.163
Step 150, Loss 1.978
Step 175, Loss 0.866
Step 200, Loss 1.076

Conclusion

You now know the most basic way to train a neural network with SIRFShampoo.

Total running time of the script: ( 0 minutes 9.419 seconds)
