
Multi-node training expects num_nodes and devices, but that may be variable in a slurm cluster. #13804

Open
TheShadow29 opened this issue Jul 22, 2022 · 1 comment
Labels
environment: slurm · priority: 2 Low priority task · strategy: ddp DistributedDataParallel

Comments


TheShadow29 commented Jul 22, 2022

🐛 Bug

Currently, Trainer requires num_nodes and devices, but the number of devices may differ across nodes. For instance, SLURM may provide 1 node with 6 GPUs and 2 other nodes with 1 GPU each, for a total of 8 GPUs across 3 nodes. Right now, this gives the following error:

..../python3.9/site-packages/pytorch_lightning/strategies/ddp.py", line 118, in root_device
    return self.parallel_devices[self.local_rank]
IndexError: list index out of range
srun: error: <node-name>: tasks 6-7: Exited with exit code 1
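
For illustration, a minimal sketch of the mismatch (the values below are assumed for illustration, not taken from Lightning internals): the strategy's device list is built from the devices argument, while the local rank follows the SLURM task layout, so a node that received more tasks than it has visible GPUs indexes past the end of the list.

# Hypothetical illustration of the failing lookup, with assumed values:
# a node reports 1 visible GPU but was assigned 2 SLURM tasks.
parallel_devices = [f"cuda:{i}" for i in range(1)]  # device list built from devices=1 on this node
local_rank = 1                                      # SLURM_LOCALID of the second task on that node
print(parallel_devices[local_rank])                 # IndexError: list index out of range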

To Reproduce

Note: SL_NUM_NODES is set externally in the sbatch script below.

# dummy_run.py
import os
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
import socket
import datetime
from pytorch_lightning.utilities.rank_zero import _get_rank

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def main():

    hostname = socket.gethostname()
    ngpus = torch.cuda.device_count()

    num_nodes = int(os.environ.get("SL_NUM_NODES", 1))
    wsize = int(os.environ.get("SLURM_NTASKS", ngpus * num_nodes))
    grank = _get_rank()
    print(
        f"Hostname={hostname}",
        f"nGPUs in host={torch.cuda.device_count()}",
        f"Start Time={datetime.datetime.now()}",
        f"num_nodes={num_nodes}",
        f"nCPUs = {torch.multiprocessing.cpu_count()}",
        f"wSize={wsize}",
        f"grank={grank}",
    )

    train_data = DataLoader(RandomDataset(32, 6400000), batch_size=2)

    model = BoringModel()

    trainer = Trainer(
        devices=ngpus,  # GPUs visible on *this* node; differs across nodes in this allocation
        num_nodes=num_nodes,  # taken from SL_NUM_NODES exported in the sbatch script
        accelerator="gpu",
        strategy="ddp",
        limit_train_batches=100000,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        log_every_n_steps=1,
    )
    trainer.fit(model, train_dataloaders=train_data)

    print("DONE")


if __name__ == "__main__":
    main()

And here is the SLURM script (substitute <partition-name> and <env-name>):

#!/bin/bash
#SBATCH --job-name=check_dummy_run
#SBATCH --nodes=3
#SBATCH --time=1:00:00
#SBATCH --partition=<partition-name>
#SBATCH --gpus=8
#SBATCH --ntasks=8
#SBATCH --cpus-per-task=10

source ~/miniconda/etc/profile.d/conda.sh
conda activate <env-name>

export SL_NUM_NODES=3
export PYTHONPATH=$(pwd)

srun python dummy_run.py
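
As an aside, the custom SL_NUM_NODES export could likely be replaced by SLURM's own job variables (a sketch, assuming the standard SLURM_JOB_NUM_NODES and SLURM_NTASKS variables exported by sbatch); this does not resolve the device-count mismatch, it only removes one hand-maintained value:

import os
import torch

# Sketch (not part of the original report): read the allocation shape from
# SLURM's own variables instead of exporting a custom SL_NUM_NODES.
num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1))  # 3 for the script above
world_size = int(os.environ.get("SLURM_NTASKS", 1))        # 8 tasks in total
local_gpus = torch.cuda.device_count()                     # varies per node in this allocation
print(f"num_nodes={num_nodes} world_size={world_size} local_gpus={local_gpus}")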

Expected behavior

Ideally, the world size should be provided by the cluster environment, and the Trainer should launch subprocesses based only on the number of GPUs available on the current node.
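
In other words, a hedged sketch of the scheme being asked for (not Lightning's implementation): each task takes its global rank and the world size from the scheduler and binds to a GPU by its local rank, with no assumption that every node contributes the same number of devices.

import os
import torch
import torch.distributed as dist

# Sketch only: the bookkeeping the cluster environment already supplies,
# independent of a uniform devices-per-node count.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")  # in a real job: hostname of the first node
os.environ.setdefault("MASTER_PORT", "29500")

world_size = int(os.environ["SLURM_NTASKS"])   # 8 in the allocation above
global_rank = int(os.environ["SLURM_PROCID"])  # 0..7, assigned by srun
local_rank = int(os.environ["SLURM_LOCALID"])  # rank of this task within its node

torch.cuda.set_device(local_rank)  # valid as long as a node never gets more tasks than GPUs
dist.init_process_group("nccl", rank=global_rank, world_size=world_size)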

Environment

  • PyTorch Lightning Version: 1.6.4
  • PyTorch Version: 1.10.1
  • Python version: 3.9.12
  • OS: Linux
  • CUDA/cuDNN version: 11.3
  • GPU models and configuration: 8x A100
  • How you installed PyTorch: conda

cc @awaelchli @tchaton @rohitgr7 @justusschock @kaushikb11 @akihironitta

@TheShadow29 added the needs triage label Jul 22, 2022
@carmocca added the bug, environment: slurm, and strategy: ddp labels and removed the needs triage label Aug 5, 2022
@carmocca added the priority: 2 label Aug 5, 2022
@carmocca added this to the pl:1.7.x milestone Aug 5, 2022

awaelchli commented Aug 7, 2022

Hi

Yes, this is currently a known limitation. It was a genuine limitation in the past; today it is somewhat artificial.
I opened proposal #14078, which should pave the way to eventually remove this limitation.

After #14078, you would simply set devices="auto" or devices=-1 and then the actual number of devices can be different per node.
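
For context, a sketch of what that configuration might look like once #14078 lands (this is an assumption about a future API, not something that works today):

from pytorch_lightning import Trainer

# Sketch only, contingent on #14078: devices="auto" would resolve to whatever
# GPUs are visible on each node, so the count need not match across nodes.
trainer = Trainer(
    accelerator="gpu",
    devices="auto",
    num_nodes=3,    # number of nodes in the SLURM allocation above
    strategy="ddp",
)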

I'm removing the bug label because this can't really be delivered as a bug fix, and depends on the decision in #14078.

@awaelchli removed the bug label Aug 7, 2022
@awaelchli modified the milestones: pl:1.7.x → pl:1.8 Aug 7, 2022
@carmocca modified the milestones: v1.8 → future Oct 13, 2022