"""Navier-Stokes dataset helpers (module ``neuralop.data.datasets.navier_stokes``)."""

import logging
import os
from pathlib import Path
from typing import Union, List

from torch.utils.data import DataLoader

from .pt_dataset import PTDataset
from .web_utils import download_from_zenodo_record
from neuralop.utils import get_project_root

# Module-level logger registered in the logging hierarchy under this module's
# dotted name, so levels/handlers configured on "neuralop" (or the root logger)
# apply to it. Constructing ``logging.Logger(...)`` directly would create a
# detached logger that application-level logging configuration never reaches.
logger = logging.getLogger(__name__)




# Default on-disk location of the example Navier-Stokes data, resolved
# relative to the project root.
# NOTE(review): this module lives in the ``neuralop.data.datasets`` package,
# but the path below points at ``neuralop/datasets/data`` - confirm the data
# directory location.
example_data_root = get_project_root() / "neuralop" / "datasets" / "data"


def load_navier_stokes_pt(
    n_train,
    n_tests,
    batch_size,
    test_batch_sizes,
    data_root=example_data_root,
    train_resolution=128,
    test_resolutions=None,
    encode_input=False,
    encode_output=True,
    encoding="channel-wise",
    channel_dim=1,
    subsampling_rate=None,
):
    """Load the Navier-Stokes ``.pt`` data and wrap it in DataLoaders.

    Kept for backwards compatibility: callers of this function expect
    DataLoaders back rather than the ``NavierStokesDataset`` object itself.

    Parameters
    ----------
    n_train : int
        Number of training samples, forwarded to ``NavierStokesDataset``.
    n_tests
        Forwarded to ``NavierStokesDataset``; presumably the number of test
        samples per entry of ``test_resolutions`` (confirm against the
        dataset class).
    batch_size : int
        Batch size of the returned training DataLoader (also forwarded to
        the dataset).
    test_batch_sizes : sequence of int
        Batch size of each test DataLoader, paired positionally with
        ``test_resolutions``.
    data_root : path-like, optional
        Passed to ``NavierStokesDataset`` as ``root_dir``.
    train_resolution : int, optional
        Forwarded to ``NavierStokesDataset``.
    test_resolutions : sequence of int, optional
        Resolutions to build test loaders for. Defaults to ``[128]``.
    encode_input, encode_output, encoding, channel_dim, subsampling_rate
        Forwarded unchanged to ``NavierStokesDataset``.

    Returns
    -------
    train_loader : torch.utils.data.DataLoader
        Loader over ``dataset.train_db``.
    test_loaders : dict
        Maps each resolution in ``test_resolutions`` to a non-shuffling
        DataLoader over ``dataset.test_dbs[resolution]``.
    data_processor
        The ``data_processor`` attribute of the constructed dataset.

    Raises
    ------
    ValueError
        If ``test_resolutions`` and ``test_batch_sizes`` differ in length.
    """
    # A None sentinel avoids one list object being shared as the default
    # across every call.
    if test_resolutions is None:
        test_resolutions = [128]
    test_resolutions = list(test_resolutions)
    test_batch_sizes = list(test_batch_sizes)

    # Validate before building the dataset (which may load or download data):
    # zip() below would otherwise silently drop any unpaired entries.
    if len(test_resolutions) != len(test_batch_sizes):
        raise ValueError(
            f"test_resolutions ({len(test_resolutions)} entries) and "
            f"test_batch_sizes ({len(test_batch_sizes)} entries) must have "
            "the same length"
        )

    dataset = NavierStokesDataset(
        root_dir=data_root,
        n_train=n_train,
        n_tests=n_tests,
        batch_size=batch_size,
        test_batch_sizes=test_batch_sizes,
        train_resolution=train_resolution,
        test_resolutions=test_resolutions,
        encode_input=encode_input,
        encode_output=encode_output,
        encoding=encoding,
        channel_dim=channel_dim,
        subsampling_rate=subsampling_rate,
    )

    # Return DataLoaders for backwards compatibility.
    # NOTE(review): ``shuffle`` is not set, so DataLoader's default (no
    # shuffling) applies to the training loader - confirm this is intended.
    train_loader = DataLoader(
        dataset.train_db,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True,
        persistent_workers=False,
    )

    test_loaders = {
        res: DataLoader(
            dataset.test_dbs[res],
            batch_size=test_bsize,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            persistent_workers=False,
        )
        for res, test_bsize in zip(test_resolutions, test_batch_sizes)
    }

    return train_loader, test_loaders, dataset.data_processor