Pre-conditioner Groups.
In this tutorial, we show how to customize SIRFShampoo to treat multiple parameters
with a single pre-conditioner, i.e. how to form a pre-conditioner group.
First, we will illustrate the default behaviour, then talk about other built-in options.
Last, we will show how to define custom rules.
Some use cases where this is useful are:
- Combining the weight matrix and bias vector of an nn.Linear layer into one matrix by appending the bias as last column
- Combining the d-dimensional weight and bias vectors of a normalization layer into a d x 2 matrix
- Combining multiple (say L) weights of shape d_out x d_in into a 3d tensor of shape L x d_out x d_in (think LLMs)
- Reshaping a parameter tensor with a large dimension into a tensor with higher rank but smaller dimensions per axis (think embedding layers or large last linear layers)
- Reshaping a parameter tensor with a large dimension into a vector and treating that with a light-weight pre-conditioner (e.g. diagonal, think embedding layers)
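To make the shapes in this list concrete, here is a minimal sketch in plain PyTorch. It only illustrates the tensor operations behind each use case and does not call sirfshampoo. The layer sizes (128, 32, 1024, ...) and the variable names are assumptions chosen for illustration.

```python
import torch
from torch import nn

# Linear layer: append the bias as last column of the weight matrix
linear = nn.Linear(128, 32)
linear_combined = torch.cat([linear.weight, linear.bias.unsqueeze(1)], dim=1)

# Normalization layer: stack the d-dimensional weight and bias into d x 2
norm = nn.LayerNorm(128)
norm_combined = torch.stack([norm.weight, norm.bias], dim=1)

# L weights of shape d_out x d_in: stack into L x d_out x d_in
layers = [nn.Linear(16, 8, bias=False) for _ in range(3)]
stacked = torch.stack([layer.weight for layer in layers])

# Large embedding: reshape into higher rank, or flatten into a vector
embedding = nn.Embedding(1024, 64)
higher_rank = embedding.weight.reshape(32, 32, 64)
flat = embedding.weight.flatten()

for name, tensor in [
    ("linear_combined", linear_combined),
    ("norm_combined", norm_combined),
    ("stacked", stacked),
    ("higher_rank", higher_rank),
    ("flat", flat),
]:
    print(f"{name}: {tuple(tensor.shape)}")
```

Each combined tensor has the same number of entries as the parameters it was built from, so the operation can always be undone by slicing or reshaping.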
First, the imports.
from collections import OrderedDict
from typing import List
from pytest import raises
from torch import Size, Tensor, cat, cuda, device, manual_seed, rand, randint
from torch.nn import Embedding, Linear, Module, MSELoss, Parameter, ReLU, Sequential
from sirfshampoo import SIRFShampoo
from sirfshampoo.combiner import (
    FlattenEmbedding,
    LinearWeightBias,
    PerParameter,
    PreconditionerGroup,
)
manual_seed(0) # make deterministic
DEV = device("cuda" if cuda.is_available() else "cpu")
Setup
We will not train neural networks in this tutorial, but only look at SIRFShampoo's
configuration after setting up the optimizer. We will use a simple MLP with an
embedding layer (for demonstration purposes), and two fully-connected layers activated
by ReLU:
model = Sequential(
    OrderedDict(
        {
            "embedding": Embedding(64, 128),
            "linear1": Linear(128, 32, bias=False),
            "relu1": ReLU(),
            "linear2": Linear(32, 4),
        }
    )
).to(DEV)
Default Behaviour (One Pre-conditioner per Parameter)
By default, SIRFShampoo will treat each parameter with a separate pre-conditioner,
and the number of Kronecker factors is determined by the parameter's axes.
We can observe this while setting up the optimizer by turning on verbose_init=True:
optimizer = SIRFShampoo(model, verbose_init=True)
Out:
Parameter groups:
Group 0
- Parameter names: ['embedding.weight']
- Pre-conditioner: ['64x64 (DenseMatrix)', '128x128 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0db82e0>}
Group 1
- Parameter names: ['linear1.weight']
- Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0db8340>}
Group 2
- Parameter names: ['linear2.weight']
- Pre-conditioner: ['4x4 (DenseMatrix)', '32x32 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0dbd6a0>}
Group 3
- Parameter names: ['linear2.bias']
- Pre-conditioner: ['4x4 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0dbd580>}
We can read off that each parameter forms its own group that is handled with an
independent pre-conditioner. For instance, 'linear2.bias' has its own 4 x 4
pre-conditioner.
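The relation between a parameter's axes and its Kronecker factors can be reproduced with a few lines of plain Python. The sketch below hard-codes the parameter shapes of the model above. The helper `dense_factor_shapes` is hypothetical and not part of sirfshampoo; it only mirrors what the verbose output reports for the default rule.

```python
from typing import Dict, List, Tuple

# Shapes of the trainable parameters of the tutorial's model
PARAM_SHAPES: Dict[str, Tuple[int, ...]] = {
    "embedding.weight": (64, 128),
    "linear1.weight": (32, 128),
    "linear2.weight": (4, 32),
    "linear2.bias": (4,),
}


def dense_factor_shapes(shape: Tuple[int, ...]) -> List[Tuple[int, int]]:
    """Return one square (d, d) factor shape per axis of size d."""
    return [(d, d) for d in shape]


for name, shape in PARAM_SHAPES.items():
    factors = [f"{rows}x{cols}" for rows, cols in dense_factor_shapes(shape)]
    print(f"{name}: {factors}")
```

A matrix-shaped parameter therefore gets two factors, and a vector-shaped parameter gets one, which matches the groups listed above.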
Built-in: Combining Weights and Biases of a Linear Layer
If you look carefully in the example above, you can also see under the
'Other' entry of each group that there is a key 'combine_params'. The
associated value is an instance of PerParameter,
which is one defined rule to assign parameters to pre-conditioners (in this
case, one pre-conditioner per parameter).
We also provide a rule for treating weights and biases of an nn.Linear
layer jointly. This is done by appending the bias as additional column to the
weight matrix. To use this rule, we need to create an instance of
LinearWeightBias
and pass it to the optimizer's combine_params argument:
optimizer = SIRFShampoo(
    model,
    combine_params=(LinearWeightBias(), PerParameter()),
    verbose_init=True,
)
Out:
Parameter groups:
Group 0
- Parameter names: ['linear2.weight', 'linear2.bias']
- Pre-conditioner: ['4x4 (DenseMatrix)', '33x33 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.LinearWeightBias object at 0x7fb572256a30>}
Group 1
- Parameter names: ['embedding.weight']
- Pre-conditioner: ['64x64 (DenseMatrix)', '128x128 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb572256a90>}
Group 2
- Parameter names: ['linear1.weight']
- Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb572256a00>}
Note that now there are only three parameter groups, and the first one contains the
weight and bias of the last linear layer (its second pre-conditioner factor grows from
32 x 32 to 33 x 33 because the bias is appended as a column). The third group is
simply the weight of the second layer ('linear1'), which does not have a bias term.
You can also see under 'Other' that the first group uses a LinearWeightBias instance
under 'combine_params', while the other two groups use PerParameter instances.
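The 33 in the pre-conditioner dimensions comes from the extra bias column. The following sketch reproduces this in plain PyTorch for a layer with the same shape as 'linear2', and shows that the combination can be undone. It illustrates the idea only and is not the code of LinearWeightBias.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Linear(32, 4)  # same shape as `linear2` in the tutorial's model

# Combine: append the bias as last column of the weight matrix
combined = torch.cat([layer.weight, layer.bias.unsqueeze(1)], dim=1)
print(tuple(combined.shape))  # (4, 33)

# Split: the first 32 columns are the weight, the last column is the bias
weight, bias = combined.split([layer.in_features, 1], dim=1)
bias = bias.squeeze(1)

assert torch.equal(weight, layer.weight)
assert torch.equal(bias, layer.bias)
```

With one Kronecker factor per axis, the 4 x 33 matrix leads to a 4 x 4 and a 33 x 33 factor.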
We had to pass the tuple combine_params=(LinearWeightBias(), PerParameter()) to
the optimizer, which will iterate over the supplied rules and identify
pre-conditioner groups, prioritizing the rules that were supplied first.
Had we only passed combine_params=(LinearWeightBias(),), then the optimizer would
have crashed because it would not have been able to assign the weight of the second
layer ('linear1', which has no bias) to a pre-conditioner:
with raises(ValueError):
    optimizer = SIRFShampoo(
        model,
        # no fall-back option to `PerParameter` leads to crash because the net has
        # a linear layer without bias
        combine_params=(LinearWeightBias(),),
    )
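The priority logic described above can be pictured as a first-match-wins loop over the rules. The sketch below is a simplified stand-in written in plain Python that works on parameter names only. The functions `weight_bias_rule`, `per_parameter_rule`, and `assign_groups` are hypothetical and do not reflect how sirfshampoo implements this internally.

```python
from typing import Callable, List, Sequence

# A rule receives the unassigned parameter names and returns the groups it claims
Rule = Callable[[Sequence[str]], List[List[str]]]


def weight_bias_rule(names: Sequence[str]) -> List[List[str]]:
    """Claim `<layer>.weight` together with `<layer>.bias` if both exist."""
    groups = []
    for name in names:
        if name.endswith(".weight"):
            bias = name[: -len("weight")] + "bias"
            if bias in names:
                groups.append([name, bias])
    return groups


def per_parameter_rule(names: Sequence[str]) -> List[List[str]]:
    """Claim every remaining parameter as its own group."""
    return [[name] for name in names]


def assign_groups(names: Sequence[str], rules: Sequence[Rule]) -> List[List[str]]:
    """Apply rules in order; earlier rules take priority over later ones."""
    unassigned = list(names)
    assigned = []
    for rule in rules:
        for group in rule(unassigned):
            assigned.append(group)
            unassigned = [n for n in unassigned if n not in group]
    if unassigned:
        raise ValueError(f"No rule matched parameters: {unassigned}")
    return assigned


names = ["embedding.weight", "linear1.weight", "linear2.weight", "linear2.bias"]
print(assign_groups(names, [weight_bias_rule, per_parameter_rule]))
```

Calling `assign_groups(names, [weight_bias_rule])` without the fallback raises a `ValueError` in this sketch, because the weights without a matching bias remain unassigned.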
Built-in: Flattening Embedding Weights
Another option that we found useful in practice is to treat the (2d) weight matrix of
an embedding layer as a (1d) vector, and pre-condition it with a diagonal matrix.
We can use the built-in FlattenEmbedding rule
to achieve this. Because the diagonal structure should only apply to the embedding
weights, we need to specify two parameter groups:
embedding_params = [
    p
    for m in model.modules()
    if isinstance(m, Embedding)
    for p in m.parameters()
    if p.requires_grad
]
other_params = [
    p
    for m in model.modules()
    if isinstance(m, Linear)
    for p in m.parameters()
    if p.requires_grad
]
param_groups = [
    # Flatten embedding layer weights, pre-condition with diagonal matrix
    {
        "params": embedding_params,
        "combine_params": (FlattenEmbedding(),),
        "structures": "diagonal",
    },
    # Handle all other parameters with the default rule
    {"params": other_params},
]
optimizer = SIRFShampoo(model, params=param_groups, verbose_init=True)
Out:
Parameter groups:
Group 0
- Parameter names: ['embedding.weight']
- Pre-conditioner: ['8192x8192 (DiagonalMatrix)']
- Other: {'combine_params': <sirfshampoo.combiner.FlattenEmbedding object at 0x7fb4a09dbb80>, 'structures': ('diagonal',), 'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'preconditioner_dtypes': (torch.float32,)}
Group 1
- Parameter names: ['linear1.weight']
- Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a09dbbe0>}
Group 2
- Parameter names: ['linear2.weight']
- Pre-conditioner: ['4x4 (DenseMatrix)', '32x32 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a09dbfa0>}
Group 3
- Parameter names: ['linear2.bias']
- Pre-conditioner: ['4x4 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a09dbf40>}
If you look at the first parameter group, you can see that the pre-conditioner now consists of a single factor of higher dimension (8192 = 64 * 128) with diagonal structure.
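A quick count of matrix entries shows why this can be attractive. The sketch below compares the two dense Kronecker factors of the default rule with the diagonal of the flattened parameter. It counts entries only and makes no claim about how sirfshampoo stores its factors; the helper `count_entries` and the larger example size (50,000 x 768) are assumptions for illustration.

```python
from typing import Tuple


def count_entries(num_embeddings: int, embedding_dim: int) -> Tuple[int, int]:
    """Count entries of two dense factors vs. one diagonal over the flat vector."""
    dense = num_embeddings**2 + embedding_dim**2
    diagonal = num_embeddings * embedding_dim
    return dense, diagonal


# Embedding layer from this tutorial
dense, diagonal = count_entries(64, 128)
print(f"64 x 128: dense factors {dense}, diagonal {diagonal}")

# A larger, hypothetical embedding layer
dense, diagonal = count_entries(50_000, 768)
print(f"50000 x 768: dense factors {dense}, diagonal {diagonal}")
```

For the small layer in this tutorial the difference is modest, but the dense factor along the vocabulary axis grows quadratically, while the diagonal grows linearly in the number of embeddings.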
Writing Custom Pre-conditioner Groups
So far, we discussed the default option to treat each parameter with its own
pre-conditioner via PerParameter,
and to combine weight and bias of a linear layer with LinearWeightBias.
Here, we discuss how to implement a custom rule to group parameters together.
sirfshampoo offers a PreconditionerGroup
interface which can be implemented to create new rules.
Let's implement our custom (albeit a little artificial) rule, which will treat:
- each weight parameter independently
- all biases jointly with one pre-conditioner by stacking them together
Let's call this rule SeparateWeightsJointBiases. Here is its implementation:
class SeparateWeightsJointBiases(PreconditionerGroup):
    """Pre-conditioner group to treat weights independently and biases jointly."""

    def identify(self, model: Module) -> List[List[Parameter]]:
        """Detect parameters that should be treated jointly.

        Args:
            model: The neural network.

        Returns:
            A list of lists. Each sub-list contains the parameters that are treated
            jointly, i.e. either a single weight or all biases.
        """
        independent = []
        biases = []

        for name, param in model.named_parameters():
            if "weight" in name:
                independent.append([param])
            elif "bias" in name:
                biases.append(param)

        independent.append(biases)
        return independent

    def group(self, tensors: List[Tensor]) -> Tensor:
        """Combine tensors that are pre-conditioned together into one tensor.

        Args:
            tensors: List of tensors to combine. The list either has a single entry
                that is a weight-shaped tensor, or multiple entries that are bias-
                shaped tensors.

        Returns:
            The combined tensor.
        """
        # does nothing if `tensors` has one entry, otherwise
        # concatenates bias-shaped entries
        combined = cat(tensors)
        # NOTE It is good practice to remove axes of size 1 from the combined tensor,
        # because this will otherwise create 1x1 pre-conditioners which are unnecessary.
        combined = combined.squeeze()
        # However, the combined tensor must have at least one axis
        combined = combined.unsqueeze(0) if combined.ndim == 0 else combined
        return combined

    def ungroup(
        self, grouped_tensor: Tensor, tensor_shapes: List[Size]
    ) -> List[Tensor]:
        """Split the combined tensor into the original components.

        This is the inverse operation of `group`.

        Args:
            grouped_tensor: Combined tensor.
            tensor_shapes: Shapes of the tensors to split into.

        Returns:
            List of tensors that have the specified shapes.
        """
        if len(tensor_shapes) == 1:  # weight case or just one overall bias
            return [grouped_tensor.reshape(tensor_shapes[0])]

        # bias case
        bias_sizes = [s.numel() for s in tensor_shapes]
        tensors = grouped_tensor.split(bias_sizes)
        return [t.reshape(s) for t, s in zip(tensors, tensor_shapes)]
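Since ungroup must invert group, a round-trip check is a cheap way to test a custom rule before handing it to the optimizer. The sketch below copies the logic of the two methods into standalone functions so that it runs without sirfshampoo. The test tensors and their shapes are assumptions for illustration.

```python
from typing import List

import torch
from torch import Size, Tensor


def group(tensors: List[Tensor]) -> Tensor:
    """Concatenate, drop axes of size 1, but keep at least one axis."""
    combined = torch.cat(tensors).squeeze()
    return combined.unsqueeze(0) if combined.ndim == 0 else combined


def ungroup(grouped_tensor: Tensor, tensor_shapes: List[Size]) -> List[Tensor]:
    """Undo `group` given the shapes of the original tensors."""
    if len(tensor_shapes) == 1:
        return [grouped_tensor.reshape(tensor_shapes[0])]
    sizes = [s.numel() for s in tensor_shapes]
    pieces = grouped_tensor.split(sizes)
    return [t.reshape(s) for t, s in zip(pieces, tensor_shapes)]


cases = [
    [torch.rand(4, 32)],  # a single weight matrix
    [torch.rand(1, 32)],  # a weight with an axis of size 1
    [torch.rand(1)],  # a single bias with one entry
    [torch.rand(4), torch.rand(32), torch.rand(1)],  # several biases
]

for tensors in cases:
    shapes = [t.shape for t in tensors]
    restored = ungroup(group(tensors), shapes)
    assert len(restored) == len(tensors)
    assert all(torch.equal(r, t) for r, t in zip(restored, tensors))
```

The cases with axes of size 1 are worth including, because the squeeze step changes the shape of the combined tensor and ungroup has to restore it.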
Let's create an optimizer with this custom rule. Specifying the per-parameter rule as fallback is not necessary here, because our custom rule matches all trainable parameters, but doing so is in general good practice.
optimizer = SIRFShampoo(
    model,
    combine_params=(SeparateWeightsJointBiases(), PerParameter()),
    verbose_init=True,
)
Out:
Parameter groups:
Group 0
- Parameter names: ['embedding.weight']
- Pre-conditioner: ['64x64 (DenseMatrix)', '128x128 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a0986100>}
Group 1
- Parameter names: ['linear1.weight']
- Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a09861f0>}
Group 2
- Parameter names: ['linear2.weight']
- Pre-conditioner: ['4x4 (DenseMatrix)', '32x32 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a0986430>}
Group 3
- Parameter names: ['linear2.bias']
- Pre-conditioner: ['4x4 (DenseMatrix)']
- Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a0986640>}
As expected, the optimizer has four groups. Three contain one weight matrix each.
The last group contains all bias parameters (here only 'linear2.bias', the model's
single bias). We can also see a difference in the 'combine_params' values in the
'Other' section.
To make sure everything works, let's train on synthetic data for a couple of steps.
# synthetic data
BATCH_SIZE = 32
X, y = randint(0, 32, (BATCH_SIZE,), device=DEV), rand(BATCH_SIZE, 4, device=DEV)
loss_func = MSELoss().to(DEV)
STEPS = 200
PRINT_LOSS_EVERY = 25 # logging interval
initial_loss = loss_func(model(X), y).item()
for step in range(STEPS):
    optimizer.zero_grad()  # clear gradients from previous iterations

    # regular forward-backward pass
    loss = loss_func(model(X), y)
    loss.backward()
    if step % PRINT_LOSS_EVERY == 0:
        print(f"Step: {step}, Loss: {loss.item():.3f}")

    optimizer.step()  # update neural network parameters

# make sure the loss decreased
final_loss = loss_func(model(X), y).item()
assert final_loss < initial_loss
Out:
Step: 0, Loss: 0.359
Step: 25, Loss: 0.225
Step: 50, Loss: 0.128
Step: 75, Loss: 0.081
Step: 100, Loss: 0.058
Step: 125, Loss: 0.045
Step: 150, Loss: 0.038
Step: 175, Loss: 0.033
Conclusion
Congratulations! You now know how to jointly pre-condition multiple parameters
using sirfshampoo's built-in rules, and how to write your custom rules via
the PreconditionerGroup interface.
To learn more about the PreconditionerGroup interface, check out its
documentation.