mirror of
https://github.com/microsoft/mscclpp.git
synced 2026-05-13 01:36:10 +00:00
Merge mscclpp-lang to mscclpp project (#442)
First step to merge msccl-tools into mscclpp repo. In this step will move all msccl related code, pass the current tests and do some necessary refactor. Add `mscclpp.language` module Add `_InstructionOptimizer` and `DagOptimizer` class to optimize the dag Add `DagLower` to lower dag to intermediate representation Add documents for mscclpp.language Remove msccl related code
This commit is contained in:
33
python/mscclpp/language/rank.py
Normal file
33
python/mscclpp/language/rank.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class BarrierInfo:
|
||||
def __init__(self, tb_list):
|
||||
self.tb_list = tb_list
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.tb_list == other.tb_list
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(self.tb_list))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Rank:
|
||||
rank_id: int
|
||||
current_max_barrier_id: int = 0
|
||||
current_barriers: Dict[BarrierInfo, int] = field(default_factory=dict)
|
||||
|
||||
def get_barrier_id(self, tb_list):
|
||||
barrier_info = BarrierInfo(tb_list)
|
||||
if barrier_info in self.current_barriers:
|
||||
return self.current_barriers[barrier_info]
|
||||
else:
|
||||
self.current_barriers[barrier_info] = self.current_max_barrier_id
|
||||
barrier_id = self.current_max_barrier_id
|
||||
self.current_max_barrier_id += 1
|
||||
return barrier_id
|
||||
Reference in New Issue
Block a user