
Per-parameter Options

Here we demonstrate SIRFShampoo's more fine-grained configuration options.

We will use parameter groups, which allow training different parameters of a neural network differently. To demonstrate this, we take a CNN and train the parameters of its linear layers differently from those of its convolutional layer.
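Parameter groups are a general PyTorch mechanism. As a point of reference, here is the same idea with the stock torch.optim.SGD optimizer rather than SIRFShampoo: each dict is one group, and keys other than "params" override the defaults given to the constructor. The learning rates are arbitrary values chosen for illustration.

```python
import torch
from torch.nn import Conv2d, Flatten, Linear, ReLU, Sequential

model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
)

# model[0] is the convolution; model[3] and model[5] are the linear layers.
conv_params = list(model[0].parameters())
linear_params = list(model[3].parameters()) + list(model[5].parameters())

optimizer = torch.optim.SGD(
    [
        {"params": conv_params, "lr": 0.1},  # group-specific learning rate
        {"params": linear_params},  # no "lr" key: falls back to the default below
    ],
    lr=0.01,  # default for every group that does not set its own value
)

for i, group in enumerate(optimizer.param_groups):
    print(f"Group {i}: {len(group['params'])} tensors, lr={group['lr']}")
```

Here, the convolution's two tensors train with lr=0.1, while the four linear-layer tensors use the default lr=0.01.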

First, the imports.

from torch import cuda, device, manual_seed
from torch.nn import Conv2d, CrossEntropyLoss, Flatten, Linear, ReLU, Sequential
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from sirfshampoo import SIRFShampoo

manual_seed(0)  # make deterministic
MAX_STEPS = 200  # quit training after this many steps (or one epoch)
DEV = device("cuda" if cuda.is_available() else "cpu")

Problem Setup

Next, we load the data set and define the neural network and the loss function:

BATCH_SIZE = 32
train_dataset = MNIST("./data", train=True, download=True, transform=ToTensor())
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
).to(DEV)
loss_func = CrossEntropyLoss().to(DEV)
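Where does the 432 in Linear(432, 50) come from? A short check, assuming 28x28 MNIST inputs and PyTorch's Conv2d output-size formula with no padding and no dilation (the helper name conv_output_size is ours):

```python
def conv_output_size(size: int, kernel_size: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution without dilation."""
    return (size + 2 * padding - kernel_size) // stride + 1


side = conv_output_size(28, kernel_size=5, stride=2)  # (28 - 5) // 2 + 1 = 12
out_channels = 3  # Conv2d(1, 3, ...) produces 3 feature maps
flattened = out_channels * side * side  # what Flatten() hands to the first Linear layer
print(side, flattened)  # 12 432
```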

Optimizer Setup

As mentioned above, we will train the parameters of convolutions differently from those of linear layers. We do so by specifying two groups and passing them to the optimizer in a list, param_groups.

Specifically, we will use a dense pre-conditioner for convolutions and a diagonal pre-conditioner for linear layers. We will also update the pre-conditioners on different schedules.

First, we identify each group's parameters:

conv_params = [
    p
    for m in model.modules()
    if isinstance(m, Conv2d)
    for p in m.parameters()
    if p.requires_grad
]
linear_params = [
    p
    for m in model.modules()
    if isinstance(m, Linear)
    for p in m.parameters()
    if p.requires_grad
]
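Before handing the lists to the optimizer, it is worth confirming that they are disjoint and together cover every trainable parameter: torch.optim optimizers reject a parameter that appears in more than one group, and a parameter left out of every group would typically not be updated. A sketch of that check, rebuilding the model on the CPU:

```python
from torch.nn import Conv2d, Flatten, Linear, ReLU, Sequential

model = Sequential(
    Conv2d(1, 3, kernel_size=5, stride=2),
    ReLU(),
    Flatten(),
    Linear(432, 50),
    ReLU(),
    Linear(50, 10),
)

# Same selection logic as in the example above.
conv_params = [
    p for m in model.modules() if isinstance(m, Conv2d) for p in m.parameters() if p.requires_grad
]
linear_params = [
    p for m in model.modules() if isinstance(m, Linear) for p in m.parameters() if p.requires_grad
]

# Compare tensors by identity: the groups must not overlap and must cover everything.
conv_ids = {id(p) for p in conv_params}
linear_ids = {id(p) for p in linear_params}
all_ids = {id(p) for p in model.parameters() if p.requires_grad}

assert conv_ids.isdisjoint(linear_ids)
assert conv_ids | linear_ids == all_ids
print(len(conv_params), len(linear_params))  # 2 4 (a weight and a bias per layer)
```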

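Second, let's set up the pre-conditioner structures and the schedules for updating them. A schedule can be either a function that receives the optimizer's global step and returns whether to update, or an integer update interval: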

def T_conv(step: int) -> bool:
    """Pre-conditioner update schedule for parameters in convolutional layers.

    Args:
        step: Global step of the optimizer.

    Returns:
        Whether to update the pre-conditioner.
    """
    steps = [0, 1, 2, 4, 8, 16, 32, 64, 128]
    if step in steps:
        print(f"Updating pre-conditioner of a convolution parameter at step {step}.")
    return step in steps


T_linear = 5  # every 5 steps


structures_conv = "dense"
structures_linear = "diagonal"
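To see which steps each schedule triggers, here is a pure-Python sketch. It assumes that an integer T means "update whenever step % T == 0", which matches the "every 5 steps" comment but is our reading rather than SIRFShampoo's documented behavior; should_update and update_steps are hypothetical helpers, and T_conv is repeated without its print statement.

```python
from typing import Callable, List, Union

Schedule = Union[int, Callable[[int], bool]]


def should_update(T: Schedule, step: int) -> bool:
    """Hypothetical helper: a callable is queried directly, an int is an interval."""
    if callable(T):
        return T(step)
    return step % T == 0  # assumption: interval schedules fire at step 0, T, 2T, ...


def T_conv(step: int) -> bool:
    return step in [0, 1, 2, 4, 8, 16, 32, 64, 128]


T_linear = 5


def update_steps(T: Schedule, max_step: int) -> List[int]:
    """All steps in [0, max_step] at which the schedule requests an update."""
    return [s for s in range(max_step + 1) if should_update(T, s)]


print(update_steps(T_conv, 20))  # [0, 1, 2, 4, 8, 16]
print(update_steps(T_linear, 20))  # [0, 5, 10, 15, 20]
```

The convolution schedule updates often early on and then backs off exponentially, while the linear schedule keeps a fixed cadence.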

We are now ready to set up the two groups:

conv_group = {
    "params": conv_params,
    "structures": structures_conv,
    "T": T_conv,
}
linear_group = {
    "params": linear_params,
    "structures": structures_linear,
    "T": T_linear,
}
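The lr=0.01 passed to the constructor below applies to every group because, in torch.optim, constructor keyword arguments act as defaults that each group may override (Optimizer.add_param_group fills missing keys via setdefault). The verbose output suggests SIRFShampoo follows the same convention. A pure-Python sketch of that merging, where merge_group_defaults is a hypothetical helper, the parameters are stand-in names, and the linear group's lr override is invented for illustration:

```python
from typing import Any, Dict, List


def merge_group_defaults(
    groups: List[Dict[str, Any]], defaults: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Fill each group's missing keys from the optimizer-wide defaults."""
    merged = []
    for group in groups:
        full = dict(group)  # copy so the caller's dicts stay untouched
        for key, value in defaults.items():
            full.setdefault(key, value)  # a value set in the group takes precedence
        merged.append(full)
    return merged


# Default values mirror those visible in the verbose output.
defaults = {"lr": 0.01, "beta2": 0.01, "alpha1": 0.9}
groups = [
    {"params": ["0.weight", "0.bias"], "structures": "dense"},
    # Hypothetical override: a smaller learning rate for the linear layers.
    {
        "params": ["3.weight", "3.bias", "5.weight", "5.bias"],
        "structures": "diagonal",
        "lr": 0.001,
    },
]

for group in merge_group_defaults(groups, defaults):
    print(group["structures"], group["lr"], group["beta2"])
# dense 0.01 0.01
# diagonal 0.001 0.01
```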

param_groups is simply a list containing the groups, and we pass it to the optimizer's params argument. Let's also turn on the verbose_init flag to inspect the pre-conditioner structures:

param_groups = [conv_group, linear_group]
optimizer = SIRFShampoo(
    model,
    params=param_groups,
    lr=0.01,  # shared across all groups
    verbose_init=True,
)

Out:

Parameter groups:
Group 0
        - Parameter names: ['0.weight']
        - Pre-conditioner: ['3x3 (DenseMatrix)', '5x5 (DenseMatrix)', '5x5 (DenseMatrix)']
        - Other: {'structures': ('dense', 'dense', 'dense'), 'T': <function T_conv at 0x7fb4a0f40b80>, 'lr': 0.01, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'preconditioner_dtypes': (torch.float32, torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0d6ef40>}
Group 1
        - Parameter names: ['0.bias']
        - Pre-conditioner: ['3x3 (DenseMatrix)']
        - Other: {'structures': ('dense',), 'T': <function T_conv at 0x7fb4a0f40b80>, 'lr': 0.01, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb572252700>}
Group 2
        - Parameter names: ['3.weight']
        - Pre-conditioner: ['50x50 (DiagonalMatrix)', '432x432 (DiagonalMatrix)']
        - Other: {'structures': ('diagonal', 'diagonal'), 'T': 5, 'lr': 0.01, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb572252040>}
Group 3
        - Parameter names: ['3.bias']
        - Pre-conditioner: ['50x50 (DiagonalMatrix)']
        - Other: {'structures': ('diagonal',), 'T': 5, 'lr': 0.01, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb572252070>}
Group 4
        - Parameter names: ['5.weight']
        - Pre-conditioner: ['10x10 (DiagonalMatrix)', '50x50 (DiagonalMatrix)']
        - Other: {'structures': ('diagonal', 'diagonal'), 'T': 5, 'lr': 0.01, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb5722522e0>}
Group 5
        - Parameter names: ['5.bias']
        - Pre-conditioner: ['10x10 (DiagonalMatrix)']
        - Other: {'structures': ('diagonal',), 'T': 5, 'lr': 0.01, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb5722523d0>}
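The pre-conditioner shapes in this output follow a simple pattern: one square factor per dimension of the parameter tensor, with dimensions of size 1 (the single input channel of the convolution) apparently skipped. A sketch that reproduces the listing, inferred from the output above rather than from SIRFShampoo's source:

```python
from typing import List, Tuple


def preconditioner_dims(shape: Tuple[int, ...]) -> List[int]:
    """One square factor per dimension, skipping size-1 dimensions (inferred rule)."""
    return [d for d in shape if d > 1]


shapes = {
    "0.weight": (3, 1, 5, 5),  # Conv2d(1, 3, kernel_size=5): (out, in, kH, kW)
    "0.bias": (3,),
    "3.weight": (50, 432),  # Linear(432, 50): (out_features, in_features)
    "3.bias": (50,),
    "5.weight": (10, 50),
    "5.bias": (10,),
}
for name, shape in shapes.items():
    print(name, [f"{d}x{d}" for d in preconditioner_dims(shape)])
```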

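The output shows that the optimizer split our two groups into six, one per parameter tensor, each inheriting the structures and schedule of the group it came from. That's everything; what follows is a canonical training loop.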

Training

Let's train for up to MAX_STEPS steps and periodically print the loss. SIRFShampoo works like most other PyTorch optimizers:

PRINT_LOSS_EVERY = 25  # logging interval

for step, (inputs, target) in enumerate(train_loader):
    optimizer.zero_grad()  # clear gradients from previous iterations

    # regular forward-backward pass
    loss = loss_func(model(inputs.to(DEV)), target.to(DEV))
    loss.backward()
    if step % PRINT_LOSS_EVERY == 0:
        print(f"Step {step}, Loss {loss.item():.3f}")

    optimizer.step()  # update neural network parameters

    if step >= MAX_STEPS:  # don't train a full epoch, to keep the example lightweight
        break

Out:

Step 0, Loss 2.317
Updating pre-conditioner of a convolution parameter at step 0.
Updating pre-conditioner of a convolution parameter at step 0.
Updating pre-conditioner of a convolution parameter at step 1.
Updating pre-conditioner of a convolution parameter at step 1.
Updating pre-conditioner of a convolution parameter at step 2.
Updating pre-conditioner of a convolution parameter at step 2.
Updating pre-conditioner of a convolution parameter at step 4.
Updating pre-conditioner of a convolution parameter at step 4.
Updating pre-conditioner of a convolution parameter at step 8.
Updating pre-conditioner of a convolution parameter at step 8.
Updating pre-conditioner of a convolution parameter at step 16.
Updating pre-conditioner of a convolution parameter at step 16.
Step 25, Loss 2.289
Updating pre-conditioner of a convolution parameter at step 32.
Updating pre-conditioner of a convolution parameter at step 32.
Step 50, Loss 2.007
Updating pre-conditioner of a convolution parameter at step 64.
Updating pre-conditioner of a convolution parameter at step 64.
Step 75, Loss 0.894
Step 100, Loss 0.682
Step 125, Loss 0.538
Updating pre-conditioner of a convolution parameter at step 128.
Updating pre-conditioner of a convolution parameter at step 128.
Step 150, Loss 0.578
Step 175, Loss 0.180
Step 200, Loss 0.949
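Each convolution message appears twice because the convolution's weight and bias ended up in separate groups (Groups 0 and 1) that share T_conv, while the linear groups' integer schedule prints nothing. Assuming each group queries its schedule once per step, this sketch reproduces the 18 messages:

```python
def T_conv(step: int) -> bool:
    return step in [0, 1, 2, 4, 8, 16, 32, 64, 128]


MAX_STEPS = 200
# The loop runs for steps 0..MAX_STEPS inclusive: the break happens after
# optimizer.step() has already been called at step == MAX_STEPS.
steps_run = range(MAX_STEPS + 1)
conv_groups = 2  # '0.weight' and '0.bias' were placed in separate groups

conv_update_steps = [s for s in steps_run if T_conv(s)]
messages = len(conv_update_steps) * conv_groups
print(len(steps_run), conv_update_steps, messages)
# 201 [0, 1, 2, 4, 8, 16, 32, 64, 128] 18
```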

Conclusion

Congratulations! You now know how to train each layer of a neural network differently with SIRFShampoo.

This is useful, for example, when some layers have large pre-conditioner dimensions: choosing a cheaper pre-conditioner structure (e.g. 'diagonal') for those layers reduces cost. Parameter groups also let you tune learning rates, momenta, and other hyperparameters per layer.
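To put a number on the cost argument: a dense d-by-d factor stores d*d entries, a diagonal one only d. For the first linear layer's weight, whose factors are 50x50 and 432x432 in the verbose output, a quick count (full matrices, ignoring any storage tricks such as exploiting symmetry; preconditioner_entries is a hypothetical helper):

```python
from typing import List


def preconditioner_entries(dims: List[int], structure: str) -> int:
    """Stored entries across all Kronecker factors of the given sizes."""
    if structure == "dense":
        return sum(d * d for d in dims)
    if structure == "diagonal":
        return sum(dims)
    raise ValueError(f"Unknown structure: {structure}")


dims = [50, 432]  # factors of '3.weight', i.e. Linear(432, 50)
print(preconditioner_entries(dims, "dense"))  # 189124
print(preconditioner_entries(dims, "diagonal"))  # 482
```

The 432x432 factor dominates the dense count, which is why the larger linear layer benefits most from a diagonal structure.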

To find out more about SIRFShampoo's configuration options, check out the optimizer's docstring.
