
Runner

danling.runner.Runner

Bases: BaseRunner

Dynamic runner class that selects the appropriate platform based on configuration.

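This runner dynamically modifies the __class__ attribute to adapt to the platform. On construction, it replaces the instance's __class__ with a class created on the fly. The bases of that class are the instance's own class followed by the platform-specific runner, so methods defined in a subclass take precedence over those of the platform runner.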

It’s safe (and recommended) to inherit from this class to extend the Runner.
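The class swap can be reproduced in a small self-contained sketch. Everything in it is a hypothetical stand-in, not danling's API: the stub `BaseRunner`, `TorchRunner`, `AccelerateRunner`, and `Runner` classes, the `PLATFORMS` table, `MyRunner`, and the `init_calls`/`backend`/`describe` members. Only the `type(...)` call and the zero-argument `super().__init__(config)` mirror `Runner.__init__`.

```python
# Hypothetical stand-ins for danling's runner classes; the real ones do far more.
class BaseRunner:
    def __init__(self, config: dict) -> None:
        self.config = config
        self.init_calls = ["BaseRunner"]


class TorchRunner(BaseRunner):
    def __init__(self, config: dict) -> None:
        super().__init__(config)
        self.init_calls.append("TorchRunner")

    def backend(self) -> str:
        return "torch"


class AccelerateRunner(TorchRunner):
    def __init__(self, config: dict) -> None:
        super().__init__(config)
        self.init_calls.append("AccelerateRunner")

    def backend(self) -> str:
        return "accelerate"


# Hypothetical lookup table; danling's Runner uses an if/elif chain instead.
PLATFORMS = {"torch": TorchRunner, "accelerate": AccelerateRunner}


class Runner(BaseRunner):
    def __init__(self, config: dict) -> None:
        platform_cls = PLATFORMS[config.get("platform", "torch")]
        # Swap in a new class: the caller's class first, then the platform runner.
        self.__class__ = type(platform_cls.__name__, (self.__class__, platform_cls), {})
        # Zero-argument super() is bound to Runner, so this calls the next class
        # after Runner in the new MRO, i.e. the platform runner's __init__.
        super().__init__(config)


class MyRunner(Runner):
    # A user extension: it can call platform methods it never inherited statically.
    def describe(self) -> str:
        return f"MyRunner on {self.backend()}"
```

For example, `MyRunner({"platform": "accelerate"}).describe()` returns `"MyRunner on accelerate"`, and the instance passes `isinstance` checks for both `MyRunner` and `AccelerateRunner`. Putting `self.__class__` first in the bases means the generated MRO is `MyRunner, Runner, AccelerateRunner, TorchRunner, BaseRunner, object`, which is why subclass overrides win over the platform runner.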

Valid platform options (case-insensitive) are:

  • “auto” (default): “deepspeed” if the DeepSpeed runner can be imported, otherwise “torch”
  • “torch”
  • “accelerate”
  • “deepspeed”
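The platform-selection rules can be sketched as a standalone function. `resolve_platform` and its `deepspeed_available` flag are hypothetical. The flag stands in for `ds.is_successful()`, and a plain dict stands in for danling's `Config`. Before switching to Accelerate or DeepSpeed, the real `__init__` also calls `ac.check()` or `ds.check()`. That step is omitted here.

```python
# Hypothetical helper mirroring the platform-resolution steps in Runner.__init__.
VALID_PLATFORMS = ("torch", "accelerate", "deepspeed")


def resolve_platform(config: dict, deepspeed_available: bool) -> str:
    # Names are lowercased first, and a missing key means "auto".
    platform = config.get("platform", "auto").lower()
    if platform == "auto":
        # "auto" picks DeepSpeed when available, otherwise plain PyTorch.
        platform = "deepspeed" if deepspeed_available else "torch"
    if platform not in VALID_PLATFORMS:
        raise ValueError(f"Unknown platform: {platform}. Valid options are: torch, accelerate, deepspeed")
    return platform
```

Note that “auto” never resolves to “accelerate”. Accelerate has to be requested explicitly.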

Examples:

>>> config = Config({"platform": "accelerate"})
>>> runner = Runner(config)
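See Also

  • BaseRunner: Base class for all runners.
  • TorchRunner: PyTorch runner.
  • AccelerateRunner: PyTorch runner with Accelerate.
  • DeepSpeedRunner: PyTorch runner with DeepSpeed.
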
Source code in danling/runner/runner.py
class Runner(BaseRunner):
    r"""
    Dynamic runner class that selects the appropriate platform based on configuration.

    This runner dynamically modifies the `__class__` attribute to adapt to the platform.

    It's safe (and recommended) to inherit from this class to extend the Runner.

    Valid platform options are:

    - "auto" (default)
    - "torch"
    - "accelerate"
    - "deepspeed"

    Examples:
        >>> config = Config({"platform": "accelerate"})
        >>> runner = Runner(config)

    See Also:
        - [`BaseRunner`][danling.runner.BaseRunner]: Base class for all runners.
        - [`TorchRunner`][danling.runner.TorchRunner]: PyTorch runner.
        - [`AccelerateRunner`][danling.runner.AccelerateRunner]: PyTorch runner with Accelerate.
        - [`DeepSpeedRunner`][danling.runner.DeepSpeedRunner]: PyTorch runner with DeepSpeed.
    """

    def __init__(self, config: Config) -> None:
        platform = config.get("platform", "auto").lower()

        if platform == "auto":
            platform = "deepspeed" if ds.is_successful() else "torch"

        if platform == "accelerate":
            ac.check()
            self.__class__ = type("AccelerateRunner", (self.__class__, AccelerateRunner), {})
        elif platform == "deepspeed":
            ds.check()
            self.__class__ = type("DeepSpeedRunner", (self.__class__, DeepSpeedRunner), {})
        elif platform == "torch":
            self.__class__ = type("TorchRunner", (self.__class__, TorchRunner), {})
        else:
            raise ValueError(f"Unknown platform: {platform}. Valid options are: torch, accelerate, deepspeed")

        super().__init__(config)