Using torchtnt to count FLOPS
A demo using torchtnt to count the floating-point operations (FLOPS)
required for a model’s forward and backward pass.
This tutorial demonstrates how to profile neural operator models to understand their computational requirements. FLOPS counting is crucial for:
- Comparing different model architectures
- Understanding computational bottlenecks
- Optimizing model efficiency
- Making informed decisions about model deployment
We will use the FLOPS count to analyze the computational resources used by an FNO model.
Import dependencies
We import the necessary modules for FLOPS counting and model creation
from copy import deepcopy
import torch
from torchtnt.utils.flops import FlopTensorDispatchMode
from neuralop.models import FNO
device = "cpu"
Creating the FNO model for analysis
We create a moderately-sized FNO model to demonstrate FLOPS counting
fno = FNO(
    n_modes=(64, 64),
    in_channels=1,
    out_channels=1,
    hidden_channels=64,
    projection_channel_ratio=1,
)
# Create a sample input tensor for FLOPS counting
batch_size = 4
model_input = torch.randn(batch_size, 1, 128, 128)
Counting FLOPS for forward and backward passes
We use the FlopTensorDispatchMode to count FLOPS during both forward and backward passes
with FlopTensorDispatchMode(fno) as ftdm:
    # Count forward pass FLOPS
    res = fno(model_input).mean()
    fno_forward_flops = deepcopy(ftdm.flop_counts)

    # Reset the counter and count backward pass FLOPS
    ftdm.reset()
    res.backward()
    fno_backward_flops = deepcopy(ftdm.flop_counts)
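The `deepcopy` calls matter if the counter keeps writing to, or on `reset()` empties, the same dictionary object; that behavior is an assumption here, not something this tutorial states. The sketch below illustrates the aliasing issue with a hand-built nested `defaultdict`. It does not call torchtnt, and `live_counts` is a hypothetical stand-in for `ftdm.flop_counts`.

```python
from collections import defaultdict
from copy import deepcopy

# Hypothetical stand-in for ftdm.flop_counts: module name -> operator -> count.
live_counts = defaultdict(lambda: defaultdict(int))
live_counts["lifting"]["convolution.default"] += 562036736

alias = live_counts               # same object, no copy made
snapshot = deepcopy(live_counts)  # independent copy of the nested structure

# Simulate a reset that empties the live dictionary in place (assumed behavior).
live_counts.clear()

print(len(alias))                                  # the alias was emptied too
print(snapshot["lifting"]["convolution.default"])  # the snapshot kept its count
```

Under that assumption, a plain assignment would lose the forward-pass counts as soon as the counter is reset, while the deep copy preserves them.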
Analyzing FLOPS breakdown
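The output is a nested defaultdict. The outer keys are submodule names (the empty string '' is the model as a whole), and each inner dictionary maps an operator, such as convolution.default or bmm.default, to the FLOPS counted for it. This shows which parts of the model are computationally expensive.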
print("Forward pass FLOPS breakdown:")
print(fno_forward_flops)
Forward pass FLOPS breakdown:
defaultdict(<function FlopTensorDispatchMode.__init__.<locals>.<lambda> at 0x7f564ef81c60>, {'': defaultdict(<class 'int'>, {'convolution.default': 2982150144, 'bmm.default': 138412032}), 'lifting': defaultdict(<class 'int'>, {'convolution.default': 562036736}), 'lifting.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 25165824}), 'lifting.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 536870912}), 'fno_blocks': defaultdict(<class 'int'>, {'convolution.default': 2147483648, 'bmm.default': 138412032}), 'fno_blocks.fno_skips.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.0.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.0': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.0.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.0.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.1.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.1': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.1': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.1.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.1.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.2.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.2': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.2': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 
'fno_blocks.channel_mlp.2.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.2.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.fno_skips.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.fno_skips.3.conv': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.convs.3': defaultdict(<class 'int'>, {'bmm.default': 34603008}), 'fno_blocks.channel_mlp.3': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'fno_blocks.channel_mlp.3.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'fno_blocks.channel_mlp.3.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 134217728}), 'projection': defaultdict(<class 'int'>, {'convolution.default': 272629760}), 'projection.fcs.0': defaultdict(<class 'int'>, {'convolution.default': 268435456}), 'projection.fcs.1': defaultdict(<class 'int'>, {'convolution.default': 4194304})})
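Some entries in this printout can be checked by hand. As a minimal sketch, assume the counter charges one FLOP per multiply-accumulate for a pointwise (1x1) convolution, so the cost is batch times height times width times input channels times output channels. This convention is inferred from the numbers above rather than stated by the tutorial, and `pointwise_conv_flops` is a hypothetical helper.

```python
def pointwise_conv_flops(batch, height, width, c_in, c_out):
    """Multiply-accumulates of a 1x1 convolution on a 2D grid (assumed convention)."""
    return batch * height * width * c_in * c_out

# Shapes from the tutorial: batch 4, 128x128 grid, hidden_channels=64.
skip = pointwise_conv_flops(4, 128, 128, 64, 64)     # a 64 -> 64 channel layer
proj_out = pointwise_conv_flops(4, 128, 128, 64, 1)  # a 64 -> 1 channel layer

print(skip)      # compare with fno_blocks.fno_skips.0.conv in the printout
print(proj_out)  # compare with projection.fcs.1 in the printout
```

If the assumption holds, the two values should match the 268435456 and 4194304 reported for those layers.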
Finding maximum FLOPS usage
To find the largest single FLOPS count recorded during a pass, let’s write a recursive function that searches the nested dictionary structure:
from collections import defaultdict
def get_max_flops(flop_count_dict, max_value=0):
    for _, value in flop_count_dict.items():
        # If not nested, compare leaf value to max
        if isinstance(value, int):
            max_value = max(max_value, value)
        # Otherwise compute recursive max value below node
        elif isinstance(value, defaultdict):
            new_val = get_max_flops(value, max_value)
            max_value = max(max_value, new_val)
    return max_value
print(f"Max FLOPS required for FNO.forward: {get_max_flops(fno_forward_flops)}")
print(f"Max FLOPS required for FNO.backward: {get_max_flops(fno_backward_flops)}")
Max FLOPS required for FNO.forward: 2982150144
Max FLOPS required for FNO.backward: 5939134464
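The forward maximum equals the convolution.default count of the root entry '', so it describes the whole model rather than one layer. A per-submodule view can be more informative. The sketch below uses a hand-copied subset of the forward-pass numbers printed earlier (plain dicts, not the live `fno_forward_flops` object) and a hypothetical helper `module_totals`.

```python
# Forward-pass counts copied from the printout above (top-level modules only).
forward_counts = {
    "": {"convolution.default": 2982150144, "bmm.default": 138412032},
    "lifting": {"convolution.default": 562036736},
    "fno_blocks": {"convolution.default": 2147483648, "bmm.default": 138412032},
    "projection": {"convolution.default": 272629760},
}

def module_totals(counts):
    """Sum the per-operator counts of each module (hypothetical helper)."""
    return {name: sum(ops.values()) for name, ops in counts.items()}

totals = module_totals(forward_counts)
model_total = totals.pop("")  # the root entry covers the whole model

# Rank the top-level submodules by their share of the forward pass.
for name, flops in sorted(totals.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:<12} {flops:>13,d}  {flops / model_total:6.1%}")
```

In this run the three submodule totals add up to the root total, with fno_blocks accounting for the largest share.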
Total running time of the script: (0 minutes 2.600 seconds)