Using torchtnt to count FLOPs

In this example, we demonstrate how to use torchtnt to count the number of floating-point operations (FLOPs) performed during a model's forward and backward passes. Note that this is a total operation count, not a rate (FLOPS, operations per second).

We will use these FLOP counts to measure the computational cost of a base FNO model.

from copy import deepcopy
import torch
from torchtnt.utils.flops import FlopTensorDispatchMode

from neuralop.models import FNO

# Device on which both the model and its input live; change to e.g. 'cuda'
# to count FLOPs for a GPU run.
device = 'cpu'

fno = FNO(n_modes=(64, 64),
          in_channels=1,
          out_channels=1,
          hidden_channels=64,
          projection_channel_ratio=1).to(device)

batch_size = 4
# Input is (batch, in_channels, 128, 128); created directly on `device` so it
# matches the model's parameters.
model_input = torch.randn(batch_size, 1, 128, 128, device=device)


with FlopTensorDispatchMode(fno) as ftdm:
    # count forward flops
    res = fno(model_input).mean()
    # snapshot the counters: flop_counts keeps being mutated by the mode
    fno_forward_flops = deepcopy(ftdm.flop_counts)

    # clear the counters so the backward pass is tallied on its own
    ftdm.reset()
    res.backward()
    fno_backward_flops = deepcopy(ftdm.flop_counts)

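The result is a nested defaultdict. The outer keys are submodule names (the empty string '' denotes the top-level model). Each inner dict maps an operator name (e.g. 'convolution.default', 'bmm.default') to the number of FLOPs attributed to that submodule. A parent module's counts include those of its submodules; for example, 'lifting' is the sum of 'lifting.fcs.0' and 'lifting.fcs.1'.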

# Show the full nested per-module / per-operator FLOP counts for the forward pass.
print(fno_forward_flops)
defaultdict(<function FlopTensorDispatchMode.__init__.<locals>.<lambda> at 0x7f56127fe310>, {'': defaultdict(<class 'int'>, {'convolution.default': 2982150144, 'bmm.default': 138412032}), 'lifting': defaultdict(<class 'int'>, {'convolution.default': 562036736}), 'lifting.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 25165824}), 'lifting.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 536870912}), 'fno_blocks': defaultdict(<class 'int'>, {'convolution.default': 2147483648, 'bmm.default': 138412032}), 'fno_blocks.fno_skips.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.0.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.0': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.0.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.0.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.1.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.1': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.1.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.1.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.2.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.2': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 
'fno_blocks.channel_mlp.2.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.2.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.3.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.3': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.3.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.3.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'projection': defaultdict(<class 'int'>, {'convolution.default': 272629760}), 'projection.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'projection.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 4194304})})

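To find the largest FLOP count recorded for any single module/operator entry during each pass, let's write a recursive function that searches the nested dict. Because parent modules include their submodules' counts, this ends up being the top-level model's largest per-operator count.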

from collections import defaultdict
def get_max_flops(flop_count_dict, max_value = 0):
    """Return the largest integer leaf value in a nested FLOP-count dict.

    Args:
        flop_count_dict: Mapping whose values are either integer counts or
            further nested dicts, such as the per-module, per-operator
            ``defaultdict`` produced by ``FlopTensorDispatchMode``. Any
            ``dict`` subclass (``defaultdict`` included) is searched.
        max_value: Starting value for the running maximum. It is returned
            unchanged when no leaf exceeds it, including for an empty dict.

    Returns:
        The maximum of ``max_value`` and every integer leaf in the tree.
        Values that are neither ``int`` nor ``dict`` are ignored.
    """
    for value in flop_count_dict.values():
        if isinstance(value, dict):
            # nested node: the recursive call already returns >= max_value
            max_value = get_max_flops(value, max_value)
        elif isinstance(value, int):
            # leaf count: compare directly against the running maximum
            max_value = max(max_value, value)
    return max_value

# Report, for each pass, the largest single per-module/per-operator FLOP entry.
for pass_name, pass_counts in (("forward", fno_forward_flops),
                               ("backward", fno_backward_flops)):
    print(f"Max FLOPS required for FNO.{pass_name}: {get_max_flops(pass_counts)}")
Max FLOPS required for FNO.forward: 2982150144
Max FLOPS required for FNO.backward: 5939134464

Total running time of the script: (0 minutes 3.747 seconds)

Gallery generated by Sphinx-Gallery