Source code for neuralop.models.base_model
import inspect
import torch
import warnings
from pathlib import Path
# Author: Jean Kossaifi
class BaseModel(torch.nn.Module):
    """Base class for all models

    This class has two main functionalities:

    * It monitors the creation of subclasses, which are automatically
      registered for users to use by name through the library's config system.
    * When a new instance of this class is created, the init call is
      intercepted so we can store the parameters used to create the instance.
      This makes it possible to save trained models along with their init
      parameters, and therefore to load saved models easily.

    Notes
    -----
    Models can be versioned using the ``_version`` class attribute.
    This can be used as a sanity check when loading models from checkpoints,
    to verify that the model hasn't been updated since.
    """
    _models = dict()
    _version = '0.1.0'
    def __init_subclass__(cls, name=None, **kwargs):
        """When a subclass is created, register it in _models

        We look for an existing name attribute.
        If none is given, we use the class's name instead.
        """
        super().__init_subclass__(**kwargs)

        if name is not None:
            BaseModel._models[name.lower()] = cls
            cls._name = name
        else:
            # warnings.warn(f'Creating a subclass of BaseModel {cls.__name__} with no name, initializing with {cls.__name__}.')
            BaseModel._models[cls.__name__.lower()] = cls
            cls._name = cls.__name__
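
    # Illustrative sketch (MyOperator is hypothetical, not part of the
    # library): defining a subclass registers it under its lowercased name.
    # The **kwargs in __init__ lets the subclass absorb the bookkeeping
    # entries (e.g. '_name') that from_checkpoint passes back at re-creation.
    #
    #     class MyOperator(BaseModel, name='MyOperator'):
    #         def __init__(self, in_channels=3, verbose=False, **kwargs):
    #             super().__init__()
    #             self.linear = torch.nn.Linear(in_channels, 1)
    #
    #     BaseModel._models['myoperator']  # -> <class 'MyOperator'>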
    def __new__(cls, *args, **kwargs):
        """Verify arguments and save init kwargs for loading/saving

        We inspect the class's signature and check for unused parameters,
        as well as for parameters that were not passed.

        We store all the args and kwargs given so we can duplicate the
        instance transparently.
        """
        sig = inspect.signature(cls)
        model_name = cls.__name__
        verbose = kwargs.get('verbose', False)

        # Verify that the given parameters are actually arguments of the model
        for key in kwargs:
            if key not in sig.parameters:
                if verbose:
                    print(f"Given argument key={key} "
                          f"that is not in {model_name}'s signature.")

        # Check for model arguments not specified in the configuration
        for key, value in sig.parameters.items():
            if (value.default is not inspect._empty) and (key not in kwargs):
                if verbose:
                    print(
                        f"Keyword argument {key} not specified for model {model_name}, "
                        f"using default={value.default}."
                    )
                kwargs[key] = value.default

        if hasattr(cls, '_version'):
            kwargs['_version'] = cls._version
        kwargs['args'] = args
        kwargs['_name'] = cls._name

        instance = super().__new__(cls)
        instance._init_kwargs = kwargs

        return instance
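
    # Illustrative sketch, continuing the hypothetical MyOperator: after
    # construction, the kwargs used to build the instance (with defaults
    # filled in above) are stored on the model for later re-creation.
    #
    #     model = MyOperator(in_channels=5)
    #     model._init_kwargs['in_channels']  # -> 5
    #     model._init_kwargs['_version']     # -> '0.1.0'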
    def save_checkpoint(self, save_folder, save_name):
        """Saves the model's state dict and init kwargs in the given folder,
        under the given name.
        """
        save_folder = Path(save_folder)
        if not save_folder.exists():
            save_folder.mkdir(parents=True)

        state_dict_filepath = save_folder.joinpath(f'{save_name}_state_dict.pt').as_posix()
        torch.save(self.state_dict(), state_dict_filepath)

        metadata_filepath = save_folder.joinpath(f'{save_name}_metadata.pkl').as_posix()
        # Objects (e.g. GeLU) are not serializable by json - find a better solution in the future
        torch.save(self._init_kwargs, metadata_filepath)
        # with open(metadata_filepath, 'w') as f:
        #     json.dump(self._init_kwargs, f)
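
    # Illustrative sketch: saving writes two files into the folder,
    # '<save_name>_state_dict.pt' and '<save_name>_metadata.pkl'.
    #
    #     model.save_checkpoint('./checkpoints', 'my_operator')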
    def load_checkpoint(self, save_folder, save_name, map_location=None):
        """Loads the model's state dict from the given folder under the given name."""
        save_folder = Path(save_folder)
        state_dict_filepath = save_folder.joinpath(f'{save_name}_state_dict.pt').as_posix()
        self.load_state_dict(torch.load(state_dict_filepath, map_location=map_location))
    @classmethod
    def from_checkpoint(cls, save_folder, save_name, map_location=None):
        save_folder = Path(save_folder)

        metadata_filepath = save_folder.joinpath(f'{save_name}_metadata.pkl').as_posix()
        init_kwargs = torch.load(metadata_filepath)
        # with open(metadata_filepath, 'r') as f:
        #     init_kwargs = json.load(f)

        version = init_kwargs.pop('_version')
        if hasattr(cls, '_version') and version != cls._version:
            warnings.warn(f'Checkpoint saved for version {version} of model {cls._name} '
                          f'but current code is version {cls._version}')

        if 'args' in init_kwargs:
            init_args = init_kwargs.pop('args')
        else:
            init_args = []

        instance = cls(*init_args, **init_kwargs)
        instance.load_checkpoint(save_folder, save_name, map_location=map_location)
        return instance
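
    # Illustrative sketch of the full round trip with the hypothetical
    # MyOperator: re-creation passes the stored kwargs (including '_name')
    # back to the constructor, then restores the saved weights.
    #
    #     model = MyOperator(in_channels=5)
    #     model.save_checkpoint('./checkpoints', 'my_operator')
    #     restored = MyOperator.from_checkpoint('./checkpoints', 'my_operator')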
def available_models():
    """List the available neural operators"""
    return list(BaseModel._models.keys())
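
# Illustrative sketch: every BaseModel subclass created so far shows up here
# by its lowercased registered name, so with the hypothetical MyOperator
# defined above:
#
#     print(available_models())  # e.g. [..., 'myoperator']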
def get_model(config):
    """Returns an instantiated model for the given config

    * Reads the model to be used from config['arch']
    * Adjusts config[arch]['data_channels'] accordingly if multi-grid
      patching is used

    Also prints warnings for safety, in case:

    * some given arguments aren't actually used by the model
    * some keyword arguments of the model aren't provided by the config

    Parameters
    ----------
    config : Bunch or dict-like
        configuration; must have
        arch = config['arch'] (string)
        and the corresponding config[arch] (a sub-dict with the kwargs of the model)

    Returns
    -------
    model : nn.Module
        the instantiated model
    """
    arch = config["arch"].lower()
    config_arch = config.get(arch)

    # Set the number of input channels depending on channels in data + mg patching
    data_channels = config_arch.pop("data_channels")
    try:
        patching_levels = config["patching"]["levels"]
    except KeyError:
        patching_levels = 0
    if patching_levels:
        data_channels *= patching_levels + 1
    config_arch["in_channels"] = data_channels

    # Dispatch model creation; look up the class first so a KeyError raised
    # inside the model's constructor isn't mistaken for an unknown arch
    try:
        model_cls = BaseModel._models[arch]
    except KeyError:
        raise ValueError(f"Got config.arch={arch}, expected one of {available_models()}.")
    return model_cls(**config_arch)
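
# Illustrative sketch (the architecture name and its kwargs are assumptions,
# not taken from this file): building a model from a dict-like config. With
# one level of multi-grid patching, the 3 data channels become
# 3 * (1 + 1) = 6 input channels.
#
#     config = {
#         'arch': 'tfno2d',
#         'tfno2d': {'data_channels': 3, 'n_modes_height': 16,
#                    'n_modes_width': 16, 'hidden_channels': 32},
#         'patching': {'levels': 1},
#     }
#     model = get_model(config)  # instantiates the registered 'tfno2d' model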