import os, sys, re, multiprocess as mp
from typing import Callable
from contextlib import AbstractContextManager, nullcontext
import torch
__all__ = ['import_star', 'global_imports', 'ranch', 'TorchDDPCtx', 'in_torchddp']
# NOTE: inside a function defined in a package module, globals() is that module's namespace,
# not the interactive/script '__main__' one -- hence the default namespaces in this file are
# taken from sys.modules['__main__'].__dict__ instead.
def import_star(modules:[str], ns:dict=None):
    """Run ``from <m> import *`` for every module name ``m`` in `modules`.

    Args:
        modules: names of the modules or packages to star-import from.
        ns: namespace dict receiving the imported names. When omitted, ``global_imports``
            falls back to the '__main__' module's namespace.
    """
    statements = []
    for module_name in modules:
        statements.append(f"from {module_name} import *")
    global_imports(statements, ns)
def global_imports(imports:[str], ns:dict=None):
    """
    Parse and execute multiple import statements, and import into target namespace 'ns'.

    Names are bound the way Python itself binds them:

    * ``import a.b`` imports ``a.b`` and binds the top-level package ``a``;
    * ``import a.b as c`` binds ``c`` to the submodule ``a.b``;
    * ``from A import x`` binds ``x``; a submodule ``A.x`` that is not loaded yet is imported;
    * ``from A import *`` binds every name in ``A.__all__``, or, without ``__all__``,
      every name of ``A`` not starting with an underscore.

    Items that are not import statements (blank lines, comments, other code) are ignored.

    Args:
        imports: list of import statements, as in Python code, one statement per item.
            Supported formats include:
            * import x, y, z as z_alias
            * from A import x
            * from A import z as z_alias
            * from A import x, y, z as z_alias
            * from A import (x, y)        -- on a single line only
            A trailing ``# comment`` and a trailing comma are tolerated.
            Relative imports (``from . import x``) are not supported.
        ns: target namespace to import into. Default to '__main__'

    Raises:
        SyntaxError: for ``import *`` without a ``from`` clause, or an unparseable name.
        ImportError: if a module cannot be imported.
        AttributeError: if ``from A import x`` names something ``A`` does not have.
    """
    import importlib
    if ns is None:
        ns = sys.modules['__main__'].__dict__
    pat = re.compile(r'^\s*(?:from\s+(\S+)\s+)?import\s+(.+)$')
    pat_as = re.compile(r'^\s*(\S+)(?:\s+as\s+(\S+))?\s*$')
    for stmt in imports:
        parsed = pat.match(stmt)
        if parsed is None:
            continue  # not an import statement
        from_, imp_ = parsed.groups()
        imp_ = imp_.split('#', 1)[0].strip()  # drop a trailing comment
        if from_ and imp_.startswith('(') and imp_.endswith(')'):
            imp_ = imp_[1:-1]  # single-line 'from A import (x, y)'
        names = []  # (name, alias-or-None) pairs
        for item in imp_.split(','):
            if not item.strip():
                continue  # tolerate a trailing comma
            m = pat_as.match(item)
            if m is None:
                raise SyntaxError(f"Cannot parse '{item.strip()}' in import statement: {stmt}")
            names.append(m.groups())
        if not names:
            raise SyntaxError(f"Nothing to import in: {stmt}")

        if from_ is None:  # "import x, a.b, c.d as e"
            for x, y in names:
                if x == '*':
                    raise SyntaxError(f"From what <module> are you trying to 'import *': {stmt}")
                if y is None:
                    ns[x.partition('.')[0]] = __import__(x)  # __import__('a.b') returns top-level 'a'
                else:
                    ns[y] = importlib.import_module(x)  # 'import a.b as c' binds c to a.b itself
            continue

        # A non-empty fromlist makes __import__ return module `from_` itself and also load
        # submodules named in fromlist (and in __all__ for '*'), like Python's own
        # 'from pkg import sub'. Names that are neither attributes nor submodules are skipped
        # here and surface as AttributeError from getattr() below.
        from_mod = __import__(from_, fromlist=[x for x, _ in names])
        for x, y in names:
            if x == '*':
                importables = getattr(from_mod, "__all__", None)
                if importables is None:
                    importables = [n for n in dir(from_mod) if not n.startswith('_')]
                for o in importables:
                    ns[o] = getattr(from_mod, o)
            else:
                ns[y or x] = getattr(from_mod, x)
def _contextualize(i:int, nprocs:int, fn:Callable, cm:AbstractContextManager, l=None, env:dict={}, imports=""):
"Return a function that will setup os.environ and execute a target function within a context manager."
if l: assert i < len(l), ValueError("Invalid index {i}, exceeds size of the result list: {len(l)}")
def _cfn(*args, **kwargs):
import os
os.environ.update({"LOCAL_RANK":str(i), "LOCAL_WORLD_SIZE":str(nprocs)})
try:
import sys
from mpify import global_imports
# import env into '__main__', which can be in a subprocess here.
g = sys.modules['__main__'].__dict__
global_imports(imports.split('\n'), g)
g.update(env)
with cm or nullcontext(): r = fn(*args, **kwargs)
if l: l[i] = r
return r
finally: map(lambda k: os.environ.pop(k, None), ("LOCAL_RANK", "LOCAL_WORLD_SIZE"))
return _cfn
def ranch(nprocs:int, fn:Callable, *args, caller_rank:int=0, gather:bool=True, ctx:AbstractContextManager=None, need:str="", imports="", **kwargs):
    r""" Execute `fn(\*args, \*\*kwargs)` distributedly in `nprocs` processes. User can
    serialize over objects and functions, spell out import statements, manage execution
    context, gather results, and the parent process can participate as one of the workers.

    Worker processes are created with the "spawn" start method.
    If `caller_rank` is `0 <= caller_rank < nprocs`, only `nprocs - 1` processes will be started, and the caller process will be a worker to run its share of `fn(..)`.
    If `caller_rank` is ``None``, `nprocs` processes will be started.

    Inside each worker process, its relative rank among all workers is set up in `os.environ['LOCAL_RANK']`, and the total
    number of workers is set up in `os.environ['LOCAL_WORLD_SIZE']`, both as strings.
    Then import statements in `imports`, followed by any objects/functions in `need`, are brought
    into the python global namespace.
    Then, context manager `ctx` is applied around the call `fn(\*args, \*\*kwargs)`.
    Return value of each worker can be gathered in a list (indexed by the process's rank)
    and returned to the caller of `ranch()`.

    Args:
        nprocs: Number of worker processes (including the caller, if it participates). Visible as a string
            in `os.environ['LOCAL_WORLD_SIZE']` in all worker processes.
        fn: Function to execute on the worker pool
        \*args: Positional arguments by values to `fn(\*args....)`
        \*\*kwargs: Named parameters to `fn(x=..., y=....)`
        caller_rank: Rank of the parent process. ``0 <= caller_rank < nprocs`` to join, ``None`` to opt out. Default to ``0``.
            In distributed data parallel, 0 means the leading process.
        gather: if ``True``, `ranch` will return a list of return values from each worker, indexed by their ranks.
            If ``False``, and if 'caller_rank' is not None (meaning parent process is a worker),
            `ranch()` will return whatever the parent process' `fn(...)` returns.
        ctx: User defined context manager to be used in a 'with'-clause around the 'fn(...)' call in worker processes.
            Subclassed from AbstractContextManager, ctx needs to define '__enter__()' and '__exit__()' methods.
        need: Space-separated names of objects/functions in the caller's '__main__' namespace
            to be serialized over to the subprocesses.
        imports: A multiline string of `import` statements to execute in the subprocesses
            before `fn()` execution. Supported formats:
                * `import x, y, z as zoo`
                * `from A import x`
                * `from A import z as zoo`
                * `from A import x, y, z as zoo`
                * Not supported: `from A import (x, y)` spanning several lines

    Returns:
        If `gather` is ``True``: a plain list of results from the workers, indexed by their `LOCAL_RANK`:
        ``[res_0, res_1, .... res_{nprocs-1}]``. A spawned worker that raised leaves ``None`` in its slot.
        Otherwise: the caller's own `fn(...)` result, or ``None`` if `caller_rank` is ``None``.
    """
    assert nprocs > 0, ValueError("nprocs: # of processes to launch must be > 0")
    children_ranks = list(range(nprocs))
    if caller_rank is not None:
        assert 0 <= caller_rank < nprocs, ValueError(f"Invalid caller_rank {caller_rank}, must satisfy 0 <= caller_rank < {nprocs}")
        children_ranks.pop(caller_rank)  # children_ranks[k] == k, so this removes rank caller_rank
    multiproc_ctx, procs, manager, result_list = mp.get_context("spawn"), [], None, None
    try:
        if gather:
            # Shared list writable from every worker; its server process is shut down in `finally`.
            manager = multiproc_ctx.Manager()
            result_list = manager.list([None] * nprocs)
        # pass globals in this process to subprocess via fn's wrapper, 'target_fn'
        env = {k: sys.modules['__main__'].__dict__[k] for k in need.split()}
        for rank in children_ranks:
            target_fn = _contextualize(rank, nprocs, fn, cm=ctx, l=result_list, env=env, imports=imports)
            p = multiproc_ctx.Process(target=target_fn, args=args, kwargs=kwargs)
            procs.append(p)
            p.start()
        p_res = None
        if caller_rank is not None:
            p_res = _contextualize(caller_rank, nprocs, fn, cm=ctx, l=result_list, env=env, imports=imports)(*args, **kwargs)
        for p in procs:
            p.join()
        # Copy the results out of the manager before it is shut down.
        return list(result_list) if gather else p_res
    finally:
        for p in procs:
            p.terminate()
            p.join()
        if manager is not None:
            manager.shutdown()
class TorchDDPCtx(AbstractContextManager):
    """
    A context manager to set up and tear down a PyTorch distributed data parallel process group.

    `os.environ['LOCAL_RANK']` and `os.environ['LOCAL_WORLD_SIZE']` must be defined prior to `__enter__()`.
    On entry, the global rank is `LOCAL_RANK + base_rank`. WORLD_SIZE, RANK, MASTER_ADDR,
    MASTER_PORT and OMP_NUM_THREADS are exported to os.environ, and, unless a process group
    already exists, one is initialized with ``init_method='env://'``.
    On exit, the process group is destroyed only if this context created it, and the
    environment variables above are restored to their values from before `__enter__()`.

    Args:
        world_size: total number of members in the DDP group
        base_rank: the starting, lowest rank value of among the local worker processes
        use_gpu: if True and CUDA is available, set the default CUDA device to `os.environ['LOCAL_RANK']`
            and use the 'nccl' backend; otherwise the 'gloo' backend is used.
        addr, port, num_threads: see PyTorch distributed data parallel documentation.
    """
    # Environment variables exported by __enter__() and restored by __exit__().
    _ENV_KEYS = ("WORLD_SIZE", "RANK", "MASTER_ADDR", "MASTER_PORT", "OMP_NUM_THREADS")

    def __init__(self, *args, world_size:int=None, base_rank:int=0, use_gpu:bool=True,
                 addr:str="127.0.0.1", port:int=29500, num_threads:int=1, **kwargs):
        assert world_size and (base_rank >= 0 and world_size > base_rank), ValueError(f"Invalid world_size {world_size} or base_rank {base_rank}. Need to be: world_size > base_rank >=0 ")
        self._ws, self._base_rank = world_size, base_rank
        self._a, self._p, self._nt = addr, str(port), str(num_threads)
        self._use_gpu = use_gpu and torch.cuda.is_available()
        self._myddp, self._backend = False, 'gloo'  # default to CPU backend
        self._saved_env = {}  # prior values of _ENV_KEYS (None = was unset), filled by __enter__()

    def __enter__(self):
        import os
        try:
            local_rank, local_ws = int(os.environ['LOCAL_RANK']), int(os.environ['LOCAL_WORLD_SIZE'])
        except KeyError as e:
            raise KeyError("'LOCAL_RANK' or 'LOCAL_WORLD_SIZE' not found in os.environ") from e
        assert 0 < local_ws <= self._ws, ValueError(f"Invalid 'LOCAL_WORLD_SIZE': {local_ws}, should be 0 < ws <= {self._ws}!")
        rank = local_rank + self._base_rank
        assert rank < self._ws, ValueError(f"local_rank {local_rank} + base_rank {self._base_rank}, should be < ({self._ws})")
        if self._use_gpu:
            assert local_rank < torch.cuda.device_count(), ValueError(f"LOCAL_RANK {local_rank} >= available CUDA devices ({torch.cuda.device_count()})")
            try:
                torch.cuda.set_device(local_rank)
                self._backend = 'nccl'
                print(f"Rank [{rank}] using CUDA GPU {local_rank}", flush=True)
            except RuntimeError as e:
                self._use_gpu = False
                print(f"Unable to set cuda device {local_rank}, using CPU. {e}", flush=True)
        # Remember prior values so __exit__() restores, rather than blindly deletes, them.
        self._saved_env = {k: os.environ.get(k) for k in self._ENV_KEYS}
        os.environ.update({"WORLD_SIZE": str(self._ws), "RANK": str(rank),
                           "MASTER_ADDR": self._a, "MASTER_PORT": self._p, "OMP_NUM_THREADS": self._nt})
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=self._backend, init_method='env://')
            # Only a group created here is torn down in __exit__().
            self._myddp = torch.distributed.is_initialized()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self._myddp:
            torch.distributed.destroy_process_group()
            self._myddp = False
        if self._use_gpu:
            torch.cuda.empty_cache()
        for k, v in self._saved_env.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v
        self._saved_env = {}
        return exc_type is None
def in_torchddp(nprocs:int, fn:Callable, *args, world_size:int=None, base_rank:int=0,
                ctx:TorchDDPCtx=None, use_gpu:bool=True, need:str="", imports:str="", **kwargs):
    r"""A convenience routine to prepare a context manager for PyTorch Distributed Data Parallel group setup/teardown,
    then calls `ranch()` to start processes and execute `fn(*args, **kwargs)`.

    The calling process always participates as local rank 0 (global rank `base_rank`),
    and only its own result is returned (``gather=False``).

    Args:
        nprocs: Number of local worker processes, including the caller
        fn, \*args, \*\*kwargs: the functions and its arguments
        world_size: total number of members in the entire PyTorch DDP group. Defaults to `nprocs`.
        base_rank: the lowest, starting rank of in the local processes
        ctx: by default will use `mpify.TorchDDPCtx` to set up torch distributed group,
            but user can override it with their own if necessary.
        use_gpu: a hint to suggest using GPU if available. Only used when `ctx` is not given.
        need: names of local objects to serialize over, space-separated
        imports: multi-line import statements, to apply in each worker process.

    Returns:
        The result of `fn(*args, **kwargs)` in the rank `base_rank` execution.
    """
    if world_size is None:
        world_size = nprocs
    assert base_rank + nprocs <= world_size, ValueError(f"nprocs({nprocs}) + base_rank({base_rank}) must be <= world_size({world_size})")
    if ctx is None:
        ctx = TorchDDPCtx(world_size=world_size, base_rank=base_rank, use_gpu=use_gpu)
    return ranch(nprocs, fn, *args, caller_rank=0, gather=False, ctx=ctx, need=need, imports=imports, **kwargs)