class AccelerateRunner(TorchRunner, Accelerator): # pylint: disable=too-many-public-methods
r"""
Set up everything for running a job.
`AccelerateRunner` uses [`accelerate`][accelerate] as distributed backend to
provide seamless distributed training experience.
`AccelerateRunner` will automatically [`prepare`][accelerate.Accelerator.prepare] everything,
including `model`, `criterion`, `optimizer`, `scheduler`, and `dataloaders` for distribute training,
mixed precision, and deepspeed (optional).
In fact, you don't even need to create `dataloaders`, just define
`datasets` and `AccelerateRunner` will create `dataloaders` for you.
`AccelerateRunner` will inspect the `train` flag in corresponding dataset to
set `shuffle` and `drop_last` automatically.
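
    Examples:
        A minimal sketch of how a subclass might be wired up. The attribute names
        (`model`, `criterion`, `optimizer`, `datasets`) follow this class; `MyRunner`,
        `MyModel`, `MyDataset`, and the config keys shown are illustrative assumptions:

        ```python
        class MyRunner(AccelerateRunner):
            def __init__(self, config: Config) -> None:
                super().__init__(config)
                self.model = MyModel()  # hypothetical torch.nn.Module
                self.criterion = torch.nn.CrossEntropyLoss()
                self.optimizer = torch.optim.AdamW(self.model.parameters())
                self.datasets["train"] = MyDataset(train=True)  # hypothetical dataset with a `train` flag

        runner = MyRunner(Config(precision="bf16", accum_steps=4))
        ```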
"""
_accelerate: FlatDict | None = None

    def __init__(self, config: Config) -> None:
ac.check()
TorchRunner.__init__(self, config)
Accelerator.__init__(self, **self.accelerate)
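        # synchronise id and timestamp across processes so every rank shares the same run identifiers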
if self.distributed:
object_list = [self.id, self.timestamp]
dist.broadcast_object_list(object_list)
self.id, self.timestamp = object_list

    def __post_init__(self) -> None:
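        r"""
        Build `dataloaders` from `datasets` if necessary, derive a default `log_interval`,
        and [`prepare`][accelerate.Accelerator.prepare] the `model`, `criterion`, `optimizer`,
        and `scheduler`.
        """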
BaseRunner.__post_init__(self)
self.project_configuration.set_directories(self.dir)
if self.datasets:
self.build_dataloaders()
if self.config.get("log_interval") is None:
self.config.log_interval = max(ceil(max(len(d) for d in self.dataloaders.values()) / 10), 1)
self.model, self.criterion, self.optimizer, self.scheduler = self.prepare(
self.model, self.criterion, self.optimizer, self.scheduler
)

    def train_step(self, data) -> torch.Tensor:
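        r"""
        Run a single training step: forward pass under autocast and gradient accumulation,
        loss computation, metric update, and optimizer/scheduler advance.
        """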
with self.autocast(), self.accumulate():
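            # batches may be mapping-style ({"input": ..., "target": ...}) or sequence-style (input, target)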
input = data["input"] if isinstance(data, Mapping) else data[0]
target = data["target"] if isinstance(data, Mapping) else data[1]
pred = self.model(**input) if isinstance(input, Mapping) else self.model(input)
loss = self.criterion(pred, target)
if self.metrics is not None:
self.metrics.update(pred.squeeze(-1), target)
self.advance(loss)
return loss

    def advance(self, loss) -> None:
        r"""
        Backward the loss and step the optimizer and scheduler.

        Args:
            loss: The loss tensor to back-propagate.
        """
self.backward(loss)
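        # clip, step, and zero gradients only when gradients are synchronised, i.e. at accumulation boundaries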
if self.sync_gradients:
if self.config.get("max_grad_value") is not None:
self.clip_grad_value_(self.model.parameters(), self.config.get("max_grad_value"))
if self.config.get("max_grad_norm") is not None:
self.clip_grad_norm_(self.model.parameters(), self.config.get("max_grad_norm"))
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()
self.optimizer.zero_grad()
self.config.step = self.step

    def unwrap(self, model: nn.Module) -> nn.Module:
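        r"""
        Unwrap a model from the wrappers added by [`prepare`][accelerate.Accelerator.prepare].
        """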
return self.unwrap_model(model)

    @property
def accelerate(self) -> FlatDict:
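        r"""
        Keyword arguments used to initialise the underlying [`Accelerator`][accelerate.Accelerator],
        lazily derived from the runner config.
        """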
if self._accelerate is None:
self._accelerate = self.get_accelerate_config(self.config)
return self._accelerate

    @accelerate.setter
def accelerate(self, config: FlatDict) -> None:
self._accelerate = config

    @property
def deepspeed(self) -> dict | None:
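        r"""
        DeepSpeed config of the active DeepSpeed plugin, or `None` if DeepSpeed is not in use.
        """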
if self.state.deepspeed_plugin is not None:
return self.state.deepspeed_plugin.deepspeed_config
return None

    @contextmanager
def accumulate(self, *models: nn.Module):
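        r"""
        Context manager for gradient accumulation; defaults to the runner's own `model`
        when no models are given.
        """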
if not models:
models = (self.model,)
        with Accelerator.accumulate(self, *models):
            yield

    @property
def device(self) -> torch.device:
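        r"""
        Device used by the current process.
        """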
return self.state.device

    @property
def world_size(self) -> int:
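        r"""
        Total number of processes; defaults to 1 before the accelerator state is initialised.
        """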
if "state" in self.__dict__:
return self.state.num_processes
return 1

    @property
def rank(self) -> int:
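        r"""
        Global process index; defaults to 0 before the accelerator state is initialised.
        """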
if "state" in self.__dict__:
return self.state.process_index
return 0

    @property
def local_rank(self) -> int:
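        r"""
        Process index on the local node; defaults to 0 before the accelerator state is initialised.
        """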
if "state" in self.__dict__:
return self.state.local_process_index
return 0

    @cached_property
def accum_steps(self) -> int:
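        r"""
        Number of gradient accumulation steps.
        """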
return self.gradient_accumulation_steps

    def get_accelerate_config(self, config) -> FlatDict:
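        r"""
        Translate runner config entries (`accelerate`, `precision`, `dynamo`, `accum_steps`,
        and DeepSpeed settings) into keyword arguments for [`Accelerator`][accelerate.Accelerator].
        """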
accelerate = FlatDict()
if "accelerate" in config:
accelerate.update(config.accelerate)
if "precision" in config:
accelerate.mixed_precision = config.precision
if "dynamo" in config:
accelerate.dynamo_backend = config.dynamo.upper()
if "accum_steps" in config:
accelerate.gradient_accumulation_steps = config.accum_steps
if "kwargs_handlers" not in accelerate:
accelerate.kwargs_handlers = []
# Must NOT set project_dir here as timestamp is not synced yet
# config.project_dir = self.dir
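        # pick up DeepSpeed settings when launched via `accelerate launch --use_deepspeed`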
if os.getenv("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
deepspeed_config = config.get("deepspeed", os.getenv("ACCELERATE_DEEPSPEED_CONFIG_FILE"))
accelerate.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.get_deepspeed_config(deepspeed_config))
return accelerate

    def build_dataloaders(self):
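        r"""
        Build dataloaders for every dataset that does not yet have one, merging per-dataset
        settings from `config.dataloader`, inferring `shuffle` and `drop_last` from the
        dataset's `train` flag, and finally preparing all dataloaders for distributed training.
        """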
datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
default_kwargs = self.config.setdefault("dataloader", NestedDict())
dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
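        # merge the shared defaults into each dataset's own settings without overriding them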
for k, d in datasets.items():
dataloader_kwargs.setdefault(k, NestedDict())
dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
dataloader_kwargs[k].setdefault("shuffle", getattr(d, "train", True))
dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True))
self.dataloaders[k] = utils.data.DataLoader(d, collate_fn=self.collate_fn, **dataloader_kwargs[k])
default_kwargs.update(dataloader_kwargs)
for k, d in self.dataloaders.items():
self.dataloaders[k] = self.prepare(d)