Set up everything for running a job with 🤗 accelerate.
AccelerateRunner extends the Accelerator class to provide a more user-friendly
and consistent interface.
AccelerateRunner provides the easiest-to-use interface for distributed training, but it can be slow and is not
very flexible.
Read the documentation of Accelerator for more details.
Source code in danling/runner/accelerate_runner.py
class AccelerateRunner(TorchRunner, Accelerator):  # pylint: disable=too-many-public-methods
    r"""
    Set up everything for running a job with 🤗 [`accelerate`](https://huggingface.co/docs/accelerate/).

    `AccelerateRunner` extends the [`Accelerator`][accelerate.Accelerator] class to provide a more user-friendly
    and consistent interface.

    `AccelerateRunner` provides the most easy-to-use interface for distributed training, but it can be slow, and not
    very flexible.

    Read the documentation of [`Accelerator`][accelerate.Accelerator] for more details.
    """

    # Lazily-built accelerate keyword arguments; populated on first access of `self.accelerate`.
    _accelerate: FlatDict | None = None

    def __init__(self, config: Config) -> None:
        ac.check()
        TorchRunner.__init__(self, config)
        # `self.accelerate` (property) builds the kwargs from `config` before Accelerator init.
        Accelerator.__init__(self, **self.accelerate)
        if self.distributed:
            # Sync run identity across processes so every rank agrees on id/timestamp
            # (and therefore on the run directory derived from them).
            object_list = [self.id, self.timestamp]
            dist.broadcast_object_list(object_list)
            self.id, self.timestamp = object_list

    def __post_init__(self) -> None:
        BaseRunner.__post_init__(self)
        # Directories can only be set now: timestamp was synced in __init__.
        self.project_configuration.set_directories(self.dir)
        if self.datasets:
            self.build_dataloaders()
        if self.config.get("log_interval") is None:
            # Default: log roughly 10 times per epoch of the largest dataloader.
            self.config.log_interval = (
                max(ceil(max(len(d) for d in self.dataloaders.values()) / 10), 1) if self.dataloaders else 1
            )
        self.model, self.criterion, self.optimizer, self.scheduler = self.prepare(
            self.model, self.criterion, self.optimizer, self.scheduler
        )

    def train_step(self, data) -> torch.Tensor:
        r"""
        Run one training step: forward, loss, metrics update, backward + optimizer advance.

        Args:
            data: Either a mapping with ``"input"``/``"target"`` keys or a sequence
                whose first two items are input and target.

        Returns:
            The loss tensor for this step.
        """
        with self.autocast(), self.accumulate():
            input = data["input"] if isinstance(data, Mapping) else data[0]
            target = data["target"] if isinstance(data, Mapping) else data[1]
            pred = self.model(**input) if isinstance(input, Mapping) else self.model(input)
            loss = self.criterion(pred, target)
            if self.metrics is not None:
                self.metrics.update(pred.squeeze(-1), target)
            self.advance(loss)
        return loss

    def advance(self, loss) -> None:
        r"""
        Backward loss and step optimizer & scheduler.

        Args:
            loss: The loss tensor from which to backpropagate.
        """
        self.backward(loss)
        if self.sync_gradients:
            # Clip only on sync steps, i.e. when gradients are fully accumulated.
            if self.config.get("max_grad_value") is not None:
                self.clip_grad_value_(self.model.parameters(), self.config["max_grad_value"])
            if self.config.get("max_grad_norm") is not None:
                self.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
        # step/zero_grad run every call; inside `Accelerator.accumulate` the optimizer
        # step is skipped by accelerate on non-sync micro-batches.
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        if self.ema is not None:
            self.ema.update()
        self.optimizer.zero_grad()
        self.config.steps = self.step

    def unwrap(self, model: nn.Module) -> nn.Module:
        r"""Return the underlying model, stripped of accelerate's distributed wrappers."""
        return self.unwrap_model(model)

    @property
    def accelerate(self) -> FlatDict:
        r"""Keyword arguments passed to `Accelerator.__init__`, derived from `self.config`."""
        if self._accelerate is None:
            self._accelerate = self.get_accelerate_config(self.config)
        return self._accelerate

    @accelerate.setter
    def accelerate(self, config: FlatDict) -> None:
        self._accelerate = config

    @property
    def deepspeed(self) -> dict | None:
        r"""The active DeepSpeed config dict, or `None` when DeepSpeed is not in use."""
        if self.state.deepspeed_plugin is not None:
            return self.state.deepspeed_plugin.deepspeed_config
        return None

    @contextmanager
    def accumulate(self, *models: nn.Module):
        r"""
        Context manager for gradient accumulation, defaulting to `self.model`.

        Args:
            *models: Models to accumulate gradients for; defaults to ``(self.model,)``.
        """
        if not models:
            models = (self.model,)
        # BUG FIX: the original `yield Accelerator.accumulate(self, *models)` yielded the
        # context-manager object WITHOUT entering it, so accelerate's accumulation context
        # (no_sync / sync_gradients toggling) was never activated. Enter it properly:
        with Accelerator.accumulate(self, *models):
            yield

    @property
    def device(self) -> torch.device:
        r"""The device accelerate assigned to this process."""
        return self.state.device

    @property
    def world_size(self) -> int:
        r"""Total number of processes; 1 before the accelerate state exists."""
        if "state" in self.__dict__:
            return self.state.num_processes
        return 1

    @property
    def rank(self) -> int:
        r"""Global process index; 0 before the accelerate state exists."""
        if "state" in self.__dict__:
            return self.state.process_index
        return 0

    @property
    def local_rank(self) -> int:
        r"""Process index on this node; 0 before the accelerate state exists."""
        if "state" in self.__dict__:
            return self.state.local_process_index
        return 0

    @cached_property
    def accum_steps(self) -> int:
        r"""Number of gradient accumulation steps, as configured in accelerate."""
        return self.gradient_accumulation_steps

    def get_accelerate_config(self, config) -> FlatDict:
        r"""
        Translate runner `config` entries into `Accelerator.__init__` keyword arguments.

        Args:
            config: Runner config; reads `accelerate`, `precision`, `dynamo`, `accum_steps`,
                and `deepspeed` keys when present.

        Returns:
            FlatDict of accelerate keyword arguments.
        """
        accelerate = FlatDict()
        if "accelerate" in config:
            accelerate.update(config.accelerate)
        if "precision" in config:
            accelerate.mixed_precision = config.precision
        if "dynamo" in config:
            accelerate.dynamo_backend = config.dynamo.upper()
        if "accum_steps" in config:
            accelerate.gradient_accumulation_steps = config.accum_steps
        if "kwargs_handlers" not in accelerate:
            accelerate.kwargs_handlers = []
        # Must NOT set project_dir here as timestamp is not synced yet
        # config.project_dir = self.dir
        if os.getenv("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
            deepspeed_config = config.get("deepspeed", os.getenv("ACCELERATE_DEEPSPEED_CONFIG_FILE"))
            accelerate.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.get_deepspeed_config(deepspeed_config))
        return accelerate

    def build_dataloaders(self):
        r"""
        Build (and `prepare`) a DataLoader for every dataset that does not have one yet.

        Per-dataset keys in `config.dataloader` override the shared defaults; `shuffle`
        and `drop_last` default from the dataset's `train` attribute when not set.
        """
        datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
        default_kwargs = self.config.setdefault("dataloader", NestedDict())
        # Pop per-dataset overrides out of the shared defaults first.
        dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
        for k, d in datasets.items():
            dataloader_kwargs.setdefault(k, NestedDict())
            dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
            dataloader_kwargs[k].setdefault("shuffle", getattr(d, "train", True))
            dataloader_kwargs[k].setdefault("drop_last", not getattr(d, "train", True))
            self.dataloaders[k] = utils.data.DataLoader(d, collate_fn=self.collate_fn, **dataloader_kwargs[k])
        default_kwargs.update(dataloader_kwargs)
        for k, d in self.dataloaders.items():
            self.dataloaders[k] = self.prepare(d)
def advance(self, loss) -> None:
    r"""
    Backward loss and step optimizer & scheduler.

    Args:
        loss: The loss tensor from which to backpropagate.
    """
    self.backward(loss)
    if self.sync_gradients:
        # Gradients are fully accumulated on sync steps; apply clipping here only.
        max_grad_value = self.config.get("max_grad_value")
        max_grad_norm = self.config.get("max_grad_norm")
        if max_grad_value is not None:
            self.clip_grad_value_(self.model.parameters(), self.config["max_grad_value"])
        if max_grad_norm is not None:
            self.clip_grad_norm_(self.model.parameters(), self.config["max_grad_norm"])
    self.optimizer.step()
    if self.scheduler is not None:
        self.scheduler.step()
    if self.ema is not None:
        self.ema.update()
    self.optimizer.zero_grad()
    # Record the global step count back into the config.
    self.config.steps = self.step