Note
Go to the end to download the full example code.
Using torchtnt to count FLOPS
A demo using torchtnt
to count the number of floating-point
operations (FLOPs) performed during a model’s forward and backward passes.
We will use these FLOP counts to measure the compute resources required by a base FNO model.
from copy import deepcopy
import torch
from torchtnt.utils.flops import FlopTensorDispatchMode
from neuralop.models import FNO
# Device on which both the model and its input are placed; set to e.g. "cuda"
# to count FLOPs on a GPU instead.
device = 'cpu'

# A 2D FNO: n_modes has one entry per spatial dimension of the input.
fno = FNO(n_modes=(64, 64),
          in_channels=1,
          out_channels=1,
          hidden_channels=64,
          projection_channel_ratio=1).to(device)

batch_size = 4
# Input layout: (batch, channels, height, width), matching in_channels=1.
model_input = torch.randn(batch_size, 1, 128, 128, device=device)

# FlopTensorDispatchMode records the FLOPs of the tensor operations executed
# inside this context and attributes them to the submodules of ``fno``.
with FlopTensorDispatchMode(fno) as ftdm:
    # count forward flops
    res = fno(model_input).mean()
    # Snapshot the counts before reset() clears them for the backward pass.
    fno_forward_flops = deepcopy(ftdm.flop_counts)
    ftdm.reset()
    # count backward flops (only the ops triggered by backward() after reset)
    res.backward()
    fno_backward_flops = deepcopy(ftdm.flop_counts)
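The result is a nested defaultdict: each key is a submodule name (the empty string '' denotes the whole model), and each value is a defaultdict mapping operator names such as 'convolution.default' to the number of FLOPs they performed in that submodule.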
# Show the full per-module, per-operator FLOP breakdown of the forward pass.
print(f"{fno_forward_flops}")
defaultdict(<function FlopTensorDispatchMode.__init__.<locals>.<lambda> at 0x7fceca8bc5e0>, {'': defaultdict(<class 'int'>, {'convolution.default': 2982150144, 'bmm.default': 138412032}), 'lifting': defaultdict(<class 'int'>, {'convolution.default': 562036736}), 'lifting.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 25165824}), 'lifting.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 536870912}), 'fno_blocks': defaultdict(<class 'int'>, {'convolution.default': 2147483648, 'bmm.default': 138412032}), 'fno_blocks.fno_skips.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.0.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.0': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.0.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.0.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.1.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.1': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.1.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.1.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.2.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.2': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 
'fno_blocks.channel_mlp.2.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.2.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.3.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.3': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.3.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.3.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'projection': defaultdict(<class 'int'>, {'convolution.default': 272629760}), 'projection.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'projection.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 4194304})})
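To find the largest FLOP count recorded for any single module and operator, let’s write a recursive function that searches the nested dict. Because the root entry '' aggregates the whole model, this is the model-wide count of its most expensive operator type: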
from collections import defaultdict
def get_max_flops(flop_count_dict, max_value = 0):
for _, value in flop_count_dict.items():
# if not nested, compare leaf value to max
if isinstance(value, int):
max_value = max(max_value, value)
# otherwise compute recursive max value below node
elif isinstance(value, defaultdict):
new_val = get_max_flops(value, max_value)
max_value = max(max_value, new_val)
return max_value
# Report the peak per-module, per-operator FLOP count for each pass.
for pass_name, pass_counts in (("forward", fno_forward_flops),
                               ("backward", fno_backward_flops)):
    print(f"Max FLOPS required for FNO.{pass_name}: {get_max_flops(pass_counts)}")
Max FLOPS required for FNO.forward: 2982150144
Max FLOPS required for FNO.backward: 5939134464
Total running time of the script: (0 minutes 2.550 seconds)