Note
Go to the end to download the full example code.
Using torchtnt to count FLOPs
In this example, we demonstrate how to use torchtnt to estimate the number of floating-point operations (FLOPs) required for a model’s forward and backward passes.
We will use these FLOP counts to measure the computational cost of a base FNO.
from copy import deepcopy
import torch
from torchtnt.utils.flops import FlopTensorDispatchMode
from neuralop.models import FNO
# Device on which the model and input are placed; change to e.g. 'cuda'
# to count FLOPs for a GPU run.
device = 'cpu'

# Base FNO: 1 input channel -> 1 output channel, 64 hidden channels,
# n_modes=(64, 64) for a 2D input.
fno = FNO(n_modes=(64, 64),
          in_channels=1,
          out_channels=1,
          hidden_channels=64,
          projection_channel_ratio=1).to(device)

batch_size = 4
# Random batch of single-channel 128x128 inputs, created on the model's device.
model_input = torch.randn(batch_size, 1, 128, 128, device=device)

with FlopTensorDispatchMode(fno) as ftdm:
    # count forward flops
    res = fno(model_input).mean()
    # deepcopy so this snapshot is not affected by the reset() below
    fno_forward_flops = deepcopy(ftdm.flop_counts)
    ftdm.reset()
    # count backward flops (counters were just reset, so only backward ops remain)
    res.backward()
    fno_backward_flops = deepcopy(ftdm.flop_counts)
The output is a nested defaultdict that maps each submodule name (the empty string '' denotes the whole model) to a dict of FLOP counts per operator.
# Show the raw forward-pass counts: submodule name -> {operator: FLOP count}.
print(str(fno_forward_flops))
defaultdict(<function FlopTensorDispatchMode.__init__.<locals>.<lambda> at 0x7fb4d55f4900>, {'': defaultdict(<class 'int'>, {'convolution.default': 2982150144, 'bmm.default': 138412032}), 'lifting': defaultdict(<class 'int'>, {'convolution.default': 562036736}), 'lifting.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 25165824}), 'lifting.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 536870912}), 'fno_blocks': defaultdict(<class 'int'>, {'convolution.default': 2147483648, 'bmm.default': 138412032}), 'fno_blocks.fno_skips.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.0.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.0': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.0.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.0.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.1.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.1': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.1.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.1.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.2.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.2': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 
'fno_blocks.channel_mlp.2.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.2.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.3.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.3': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.3.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.3.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'projection': defaultdict(<class 'int'>, {'convolution.default': 272629760}), 'projection.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'projection.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 4194304})})
To find the largest FLOP count recorded for any single submodule and operator, let’s write a recursive function that searches the nested dict:
from collections import defaultdict
def get_max_flops(flop_count_dict, max_value=0):
    """Return the largest integer leaf value found in a nested dict of FLOP counts.

    Args:
        flop_count_dict: A mapping whose values are either ints (leaf counts)
            or further nested dicts (``defaultdict`` or plain ``dict``).
            Values of any other type are ignored.
        max_value: Running maximum to start from; also the result when no
            integer leaf exceeds it (e.g. for an empty dict).

    Returns:
        The maximum of ``max_value`` and every integer leaf in the structure.
    """
    for value in flop_count_dict.values():
        if isinstance(value, int):
            # leaf: compare directly against the running maximum
            max_value = max(max_value, value)
        elif isinstance(value, dict):
            # nested node (defaultdict is a dict subclass): the recursive call
            # starts from max_value, so its result is already >= max_value
            max_value = get_max_flops(value, max_value)
    return max_value
# Report the largest single per-submodule/per-operator count for each pass.
for pass_name, pass_counts in (("forward", fno_forward_flops),
                               ("backward", fno_backward_flops)):
    print(f"Max FLOPS required for FNO.{pass_name}: {get_max_flops(pass_counts)}")
Max FLOPS required for FNO.forward: 2982150144
Max FLOPS required for FNO.backward: 5939134464
Total running time of the script: (0 minutes 3.260 seconds)