
Pre-conditioner Groups.

In this tutorial, we show how to customize SIRFShampoo to treat multiple parameters with a single pre-conditioner, i.e. how to form a pre-conditioner group. First, we will illustrate the default behaviour, then talk about other built-in options. Last, we will show how to define custom rules.

Some use cases where this is useful are:

  • Combining the weight matrix and bias vector of an nn.Linear layer into one matrix by appending the bias as last column
  • Combining the d-dimensional weight and bias vectors of a normalization layer into a d x 2 matrix
  • Combining multiple (say L) weights of shape d_out x d_in into a 3d tensor of shape L x d_out x d_in (think LLMs)
  • Reshaping a parameter tensor with a large dimension into a tensor with higher rank but smaller dimensions per axis (think embedding layers or large last linear layers)
  • Reshaping a parameter tensor with a large dimension into a vector and treating that with a light-weight pre-conditioner (e.g. diagonal, think embedding layers)
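To make these use cases concrete, here is a plain-PyTorch sketch of the tensor manipulations they describe. It only illustrates the shapes involved; the sizes are made up for illustration, and this is not how SIRFShampoo implements any of its rules internally.

```python
import torch

# Normalization layer: d-dimensional weight and bias -> one d x 2 matrix
d = 16
norm_weight, norm_bias = torch.ones(d), torch.zeros(d)
norm_combined = torch.stack([norm_weight, norm_bias], dim=1)

# L weights of shape d_out x d_in -> one L x d_out x d_in tensor
L, d_out, d_in = 3, 8, 8
layer_weights = [torch.randn(d_out, d_in) for _ in range(L)]
stacked = torch.stack(layer_weights)

# Large axis split into smaller ones: 64 x 128 -> 8 x 8 x 128
embedding_weight = torch.randn(64, 128)
reshaped = embedding_weight.reshape(8, 8, 128)

# Flattened into a vector, e.g. for a diagonal pre-conditioner
flattened = embedding_weight.flatten()
```

All of these operations are reversible (stacking via indexing, reshaping via another reshape), which is what allows a pre-conditioned update of the combined tensor to be mapped back onto the original parameters.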

First, the imports.

from collections import OrderedDict
from typing import List

from pytest import raises
from torch import Size, Tensor, cat, cuda, device, manual_seed, rand, randint
from torch.nn import Embedding, Linear, Module, MSELoss, Parameter, ReLU, Sequential

from sirfshampoo import SIRFShampoo
from sirfshampoo.combiner import (
    FlattenEmbedding,
    LinearWeightBias,
    PerParameter,
    PreconditionerGroup,
)

manual_seed(0)  # make deterministic
DEV = device("cuda" if cuda.is_available() else "cpu")

Setup

This tutorial focuses on SIRFShampoo's configuration after setting up the optimizer rather than on training (we only run a short training loop at the end as a sanity check). We will use a simple MLP consisting of an embedding layer (for demonstration purposes) and two fully-connected layers with a ReLU activation in between:

model = Sequential(
    OrderedDict(
        {
            "embedding": Embedding(64, 128),
            "linear1": Linear(128, 32, bias=False),
            "relu1": ReLU(),
            "linear2": Linear(32, 4),
        }
    )
).to(DEV)
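Since the pre-conditioner structure follows the parameter shapes, it helps to know which trainable parameters this model has. The following snippet re-creates the same architecture on the CPU so it runs on its own, and collects the name and shape of every trainable parameter using `named_parameters`:

```python
from collections import OrderedDict

from torch.nn import Embedding, Linear, ReLU, Sequential

# Same architecture as above, re-created so this snippet is self-contained
model = Sequential(
    OrderedDict(
        {
            "embedding": Embedding(64, 128),
            "linear1": Linear(128, 32, bias=False),
            "relu1": ReLU(),
            "linear2": Linear(32, 4),
        }
    )
)

# Map each trainable parameter's name to its shape (ReLU has no parameters)
param_shapes = {
    name: tuple(param.shape)
    for name, param in model.named_parameters()
    if param.requires_grad
}
for name, shape in param_shapes.items():
    print(name, shape)
```

Note that `linear1` has no bias, so the model has four trainable parameters in total.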

Default Behaviour (One Pre-conditioner per Parameter)

By default, SIRFShampoo will treat each parameter with a separate pre-conditioner, and the number of Kronecker factors is determined by the parameter's axes. We can observe this while setting up the optimizer by turning on verbose_init=True:

optimizer = SIRFShampoo(model, verbose_init=True)

Out:

Parameter groups:
Group 0
        - Parameter names: ['embedding.weight']
        - Pre-conditioner: ['64x64 (DenseMatrix)', '128x128 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0db82e0>}
Group 1
        - Parameter names: ['linear1.weight']
        - Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0db8340>}
Group 2
        - Parameter names: ['linear2.weight']
        - Pre-conditioner: ['4x4 (DenseMatrix)', '32x32 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0dbd6a0>}
Group 3
        - Parameter names: ['linear2.bias']
        - Pre-conditioner: ['4x4 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a0dbd580>}

We can read off that each parameter forms its own group that is handled with an independent pre-conditioner. For instance, 'linear2.bias' has its own 4 x 4 pre-conditioner.
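The output above suggests a simple pattern: a parameter with axes of sizes d_1, ..., d_k receives k dense Kronecker factors of sizes d_1 x d_1, ..., d_k x d_k. The following pure-Python sketch reproduces the printed factor sizes under that assumption and compares the storage cost of the Kronecker factors with that of one dense matrix over all parameter entries. The helper functions are hypothetical and not part of sirfshampoo.

```python
from math import prod
from typing import Dict, List, Tuple

# Parameter shapes of the model defined above
PARAM_SHAPES: Dict[str, Tuple[int, ...]] = {
    "embedding.weight": (64, 128),
    "linear1.weight": (32, 128),
    "linear2.weight": (4, 32),
    "linear2.bias": (4,),
}


def factor_labels(shape: Tuple[int, ...]) -> List[str]:
    """Assumed default: one dense d x d Kronecker factor per axis of size d."""
    return [f"{d}x{d}" for d in shape]


def kronecker_entries(shape: Tuple[int, ...]) -> int:
    """Number of stored entries across all dense Kronecker factors."""
    return sum(d * d for d in shape)


def full_matrix_entries(shape: Tuple[int, ...]) -> int:
    """Entries of one dense matrix acting on all parameter entries at once."""
    return prod(shape) ** 2


for name, shape in PARAM_SHAPES.items():
    print(name, factor_labels(shape), kronecker_entries(shape), full_matrix_entries(shape))
```

For the embedding weight, the two factors store 64² + 128² = 20,480 entries, whereas a single dense pre-conditioner over all 8,192 entries would need 8,192² (roughly 67 million) entries.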

Built-in: Combining Weights and Biases of a Linear Layer

If you look carefully at the output above, you can see that the 'Other' entry of each group contains a key 'combine_params'. Its value is an instance of PerParameter, one of the built-in rules for assigning parameters to pre-conditioners (in this case, one pre-conditioner per parameter).

We also provide a rule for treating the weight and bias of an nn.Linear layer jointly, by appending the bias as an additional column to the weight matrix. To use this rule, we create an instance of LinearWeightBias and pass it to the optimizer's combine_params argument:

optimizer = SIRFShampoo(
    model,
    combine_params=(LinearWeightBias(), PerParameter()),
    verbose_init=True,
)

Out:

Parameter groups:
Group 0
        - Parameter names: ['linear2.weight', 'linear2.bias']
        - Pre-conditioner: ['4x4 (DenseMatrix)', '33x33 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.LinearWeightBias object at 0x7fb572256a30>}
Group 1
        - Parameter names: ['embedding.weight']
        - Pre-conditioner: ['64x64 (DenseMatrix)', '128x128 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb572256a90>}
Group 2
        - Parameter names: ['linear1.weight']
        - Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb572256a00>}

Now there are only three parameter groups. The first one contains the weight and bias of linear2; its second pre-conditioner factor is 33x33 instead of 32x32 because the bias was appended as an extra column to the 4x32 weight matrix. The last group contains the weight of linear1, which has no bias term. You can also see under 'Other' that the first group uses a LinearWeightBias instance under 'combine_params', while the other two groups use PerParameter instances.
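The shape change can be reproduced with plain PyTorch. The sketch below appends linear2's bias as a last column to its weight, giving a 4 x 33 matrix whose two Kronecker factors would be 4x4 and 33x33, and then splits it back. It illustrates the idea described above; the exact internals of LinearWeightBias are an assumption here.

```python
import torch

# Shapes of linear2 = Linear(32, 4): weight is 4 x 32, bias has 4 entries
weight = torch.randn(4, 32)
bias = torch.randn(4)

# Append the bias as an additional (last) column -> 4 x 33 matrix
combined = torch.cat([weight, bias.unsqueeze(1)], dim=1)

# Inverse operation: split off the last column again
weight_back, bias_column = combined.split([32, 1], dim=1)
bias_back = bias_column.squeeze(1)
```

Because the split exactly inverts the concatenation, an update computed for the combined 4 x 33 tensor can be mapped back onto the separate weight and bias.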

We had to pass the tuple combine_params=(LinearWeightBias(), PerParameter()) to the optimizer, which will iterate over the supplied rules and identify pre-conditioner groups, prioritizing the rules that were supplied first.

Had we only passed combine_params=(LinearWeightBias(),), the optimizer would have raised an error, because no rule would have been able to assign the embedding weight and the weight of linear1 (which has no bias) to a pre-conditioner:

with raises(ValueError):
    optimizer = SIRFShampoo(
        model,
        # no fall-back to `PerParameter` leads to a crash: `LinearWeightBias` matches
        # neither the embedding weight nor `linear1`, a linear layer without bias
        combine_params=(LinearWeightBias(),),
    )
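The priority order and the error above can be mimicked with a toy, name-based sketch. The functions `linear_weight_bias`, `per_parameter`, and `assign_groups` below are hypothetical stand-ins that operate on parameter names only; they are not SIRFShampoo's implementation, which works with the model and its parameters directly.

```python
from typing import Callable, List

# A hypothetical rule maps the not-yet-assigned parameter names to groups
Rule = Callable[[List[str]], List[List[str]]]


def linear_weight_bias(names: List[str]) -> List[List[str]]:
    """Toy rule: pair 'X.weight' with 'X.bias' if both exist."""
    groups = []
    for name in names:
        if name.endswith(".weight"):
            bias_name = name[: -len(".weight")] + ".bias"
            if bias_name in names:
                groups.append([name, bias_name])
    return groups


def per_parameter(names: List[str]) -> List[List[str]]:
    """Toy fallback: one group per parameter."""
    return [[name] for name in names]


def assign_groups(names: List[str], rules: List[Rule]) -> List[List[str]]:
    """Apply rules in order; parameters claimed by earlier rules are skipped later."""
    remaining = list(names)
    groups: List[List[str]] = []
    for rule in rules:
        for group in rule(remaining):
            groups.append(group)
            for name in group:
                remaining.remove(name)
    if remaining:
        raise ValueError(f"No rule matched: {remaining}")
    return groups


NAMES = ["embedding.weight", "linear1.weight", "linear2.weight", "linear2.bias"]
print(assign_groups(NAMES, [linear_weight_bias, per_parameter]))
```

With the fallback, this toy version happens to produce the same three groups, in the same order, as the optimizer output above. Without it, `embedding.weight` and `linear1.weight` remain unassigned and a `ValueError` is raised.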

Built-in: Flattening Embedding Weights

Another option that we found useful in practice is to treat the (2d) weight matrix of an embedding layer as a (1d) vector and pre-condition it with a diagonal matrix. The built-in FlattenEmbedding rule achieves this. Because the diagonal structure should only apply to the embedding weight, we need to specify two parameter groups:

embedding_params = [
    p
    for m in model.modules()
    if isinstance(m, Embedding)
    for p in m.parameters()
    if p.requires_grad
]
other_params = [
    p
    for m in model.modules()
    if isinstance(m, Linear)
    for p in m.parameters()
    if p.requires_grad
]

param_groups = [
    # Flatten embedding layer weights, pre-condition with diagonal matrix
    {
        "params": embedding_params,
        "combine_params": (FlattenEmbedding(),),
        "structures": "diagonal",
    },
    # Handle all other parameters with the default rule
    {"params": other_params},
]

optimizer = SIRFShampoo(model, params=param_groups, verbose_init=True)

Out:

Parameter groups:
Group 0
        - Parameter names: ['embedding.weight']
        - Pre-conditioner: ['8192x8192 (DiagonalMatrix)']
        - Other: {'combine_params': <sirfshampoo.combiner.FlattenEmbedding object at 0x7fb4a09dbb80>, 'structures': ('diagonal',), 'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'preconditioner_dtypes': (torch.float32,)}
Group 1
        - Parameter names: ['linear1.weight']
        - Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a09dbbe0>}
Group 2
        - Parameter names: ['linear2.weight']
        - Pre-conditioner: ['4x4 (DenseMatrix)', '32x32 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a09dbfa0>}
Group 3
        - Parameter names: ['linear2.bias']
        - Pre-conditioner: ['4x4 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <sirfshampoo.combiner.PerParameter object at 0x7fb4a09dbf40>}

If you look at the first parameter group, you can see that the pre-conditioner now consists of a single diagonal factor of dimension 8192 (= 64 x 128) instead of two dense factors of sizes 64x64 and 128x128.
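To see why this is light-weight, the sketch below flattens a stand-in for the embedding gradient and applies a diagonal pre-conditioner, which amounts to an element-wise scaling. The diagonal values are a hypothetical constant placeholder, not what SIRFShampoo would actually estimate; the point is the storage cost and the shape bookkeeping.

```python
import torch

rows, cols = 64, 128
grad = torch.randn(rows, cols)  # stand-in for the embedding weight's gradient

# Flattened view: one vector with 64 * 128 = 8192 entries
flat_grad = grad.flatten()

# A diagonal pre-conditioner stores one scale per entry; 0.5 is a placeholder value
diag = torch.full_like(flat_grad, 0.5)
preconditioned = (diag * flat_grad).reshape(rows, cols)

# Storage: the diagonal vs. the two dense Kronecker factors used by default
diag_entries = diag.numel()
dense_kron_entries = rows**2 + cols**2
```

The diagonal stores 8,192 numbers, compared with 20,480 for the dense 64x64 and 128x128 factors, and applying it costs only an element-wise product.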

Writing Custom Pre-conditioner Groups

So far, we discussed the default option to treat each parameter with its own pre-conditioner via PerParameter, to combine the weight and bias of a linear layer with LinearWeightBias, and to flatten embedding weights with FlattenEmbedding.

Here, we discuss how to implement a custom rule to group parameters together. sirfshampoo offers a PreconditionerGroup interface which can be implemented to create new rules.

Let's implement our custom (albeit a little artificial) rule, which will treat:

  • each weight parameter independently
  • all biases jointly with one pre-conditioner by stacking them together

Let's call this rule SeparateWeightsJointBiases. Here is its implementation:

class SeparateWeightsJointBiases(PreconditionerGroup):
    """Pre-conditioner group to treat weights independently and biases jointly."""

    def identify(self, model: Module) -> List[List[Parameter]]:
        """Detect parameters that should be treated jointly.

        Args:
            model: The neural network.

        Returns:
            A list of lists. Each sub-list contains the parameters that are treated
            jointly, i.e. either a single weight or all biases.
        """
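        independent = []

        biases = []
        for name, param in model.named_parameters():
            if "weight" in name:
                independent.append([param])
            elif "bias" in name:
                biases.append(param)

        # only form a bias group if the model actually has biases; an empty group
        # would not correspond to any parameter
        if biases:
            independent.append(biases)

        return independent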

    def group(self, tensors: List[Tensor]) -> Tensor:
        """Combine tensors that are pre-conditioned together into one tensor.

        Args:
            tensors: List of tensors to combine. The list either has a single entry
                that is a weight-shaped tensor, or multiple entries that are bias-
                shaped tensors.

        Returns:
            The combined tensor.
        """
        # does nothing if `tensors` has one entry, otherwise
        # concatenates bias-shaped entries
        combined = cat(tensors)
        # NOTE It is good practice to remove axes of size 1 from the combined tensor,
        # because they would otherwise create unnecessary 1x1 pre-conditioners.
        combined = combined.squeeze()
        # However, the combined tensor must have at least one axis
        combined = combined.unsqueeze(0) if combined.ndim == 0 else combined

        return combined

    def ungroup(
        self, grouped_tensor: Tensor, tensor_shapes: List[Size]
    ) -> List[Tensor]:
        """Split the combined tensor into the original components.

        This is the inverse operation of `group`.

        Args:
            grouped_tensor: Combined tensor.
            tensor_shapes: Shapes of the tensors to split into.

        Returns:
            List of tensors that have the specified shapes.
        """
        if len(tensor_shapes) == 1:  # weight case or just one overall bias
            return [grouped_tensor.reshape(tensor_shapes[0])]

        # bias case
        bias_sizes = [s.numel() for s in tensor_shapes]
        tensors = grouped_tensor.split(bias_sizes)
        return [t.reshape(s) for t, s in zip(tensors, tensor_shapes)]
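Since `ungroup` must invert `group`, a quick round-trip check is worthwhile. The snippet below copies the logic of the two methods into free functions (so it runs without sirfshampoo) and verifies the round trip for a weight, a weight with a size-1 axis that gets squeezed, several biases, and a single one-entry bias.

```python
from typing import List

import torch
from torch import Size, Tensor, cat


def group(tensors: List[Tensor]) -> Tensor:
    """Same logic as SeparateWeightsJointBiases.group above."""
    combined = cat(tensors).squeeze()
    return combined.unsqueeze(0) if combined.ndim == 0 else combined


def ungroup(grouped_tensor: Tensor, tensor_shapes: List[Size]) -> List[Tensor]:
    """Same logic as SeparateWeightsJointBiases.ungroup above."""
    if len(tensor_shapes) == 1:
        return [grouped_tensor.reshape(tensor_shapes[0])]
    sizes = [s.numel() for s in tensor_shapes]
    return [t.reshape(s) for t, s in zip(grouped_tensor.split(sizes), tensor_shapes)]


def round_trips(tensors: List[Tensor]) -> bool:
    """Check that ungroup(group(...)) recovers the original tensors exactly."""
    restored = ungroup(group(tensors), [t.shape for t in tensors])
    return len(restored) == len(tensors) and all(
        torch.equal(r, t) for r, t in zip(restored, tensors)
    )


cases = [
    [torch.randn(4, 32)],  # a single weight
    [torch.randn(1, 8)],  # weight with a size-1 axis (squeezed to 8 entries)
    [torch.randn(4), torch.randn(3)],  # several biases, concatenated
    [torch.randn(1)],  # a single one-entry bias (kept 1d)
]
results = [round_trips(c) for c in cases]
print(results)
```

Checks like this are cheap and catch shape bugs in custom rules before they surface as errors during `optimizer.step()`.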

Let's create an optimizer with this custom rule. Specifying PerParameter as a fallback is not strictly necessary here, because our custom rule matches all trainable parameters, but doing so is good practice in general.

optimizer = SIRFShampoo(
    model,
    combine_params=(SeparateWeightsJointBiases(), PerParameter()),
    verbose_init=True,
)

Out:

Parameter groups:
Group 0
        - Parameter names: ['embedding.weight']
        - Pre-conditioner: ['64x64 (DenseMatrix)', '128x128 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a0986100>}
Group 1
        - Parameter names: ['linear1.weight']
        - Pre-conditioner: ['32x32 (DenseMatrix)', '128x128 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a09861f0>}
Group 2
        - Parameter names: ['linear2.weight']
        - Pre-conditioner: ['4x4 (DenseMatrix)', '32x32 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense', 'dense'), 'preconditioner_dtypes': (torch.float32, torch.float32), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a0986430>}
Group 3
        - Parameter names: ['linear2.bias']
        - Pre-conditioner: ['4x4 (DenseMatrix)']
        - Other: {'lr': 0.001, 'beta2': 0.01, 'alpha1': 0.9, 'alpha2': 0.5, 'lam': 0.001, 'kappa': 0.0, 'T': 1, 'structures': ('dense',), 'preconditioner_dtypes': (torch.float32,), 'combine_params': <__main__.SeparateWeightsJointBiases object at 0x7fb4a0986640>}

As expected, the optimizer has four groups. The first three contain one weight matrix each. The last group contains all bias parameters; since linear2.bias is the model's only bias, it holds a single parameter here. The 'combine_params' values in the 'Other' section now show SeparateWeightsJointBiases instances.

To make sure everything works, let's train on synthetic data for a couple of steps.

# synthetic data
BATCH_SIZE = 32
X, y = randint(0, 32, (BATCH_SIZE,), device=DEV), rand(BATCH_SIZE, 4, device=DEV)
loss_func = MSELoss().to(DEV)

STEPS = 200
PRINT_LOSS_EVERY = 25  # logging interval
initial_loss = loss_func(model(X), y).item()

for step in range(STEPS):
    optimizer.zero_grad()  # clear gradients from previous iterations

    # regular forward-backward pass
    loss = loss_func(model(X), y)
    loss.backward()
    if step % PRINT_LOSS_EVERY == 0:
        print(f"Step: {step}, Loss: {loss.item():.3f}")

    optimizer.step()  # update neural network parameters

# make sure the loss decreased
final_loss = loss_func(model(X), y).item()
assert final_loss < initial_loss

Out:

Step: 0, Loss: 0.359
Step: 25, Loss: 0.225
Step: 50, Loss: 0.128
Step: 75, Loss: 0.081
Step: 100, Loss: 0.058
Step: 125, Loss: 0.045
Step: 150, Loss: 0.038
Step: 175, Loss: 0.033

Conclusion

Congratulations! You now know how to jointly pre-condition multiple parameters using sirfshampoo's built-in rules, and how to write your custom rules via the PreconditionerGroup interface.

To learn more about the PreconditionerGroup interface, check out its documentation.

Total running time of the script: ( 0 minutes 1.151 seconds)
