Skip to content

Commit

Permalink
Merge pull request hpcaitech#409 from 1SAA/develop
Browse files: browse the repository at this point in the history
[hotfix] fixed error when there is no collective communication in CommProfiler
  • Loading branch information
FrankLeeeee authored Mar 14, 2022
2 parents 62b08ac + 907ac4a commit 32296cf
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 7 additions & 4 deletions colossalai/utils/profiler/comm_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,16 +93,16 @@ def disable(self):
dist.reduce = torch_reduce

def to_tensorboard(self, writer):
writer.add_text(tag="Collective Communication", text_string=self.result_list("\n\n"))
writer.add_text(tag="Collective Communication", text_string=self.result_str("\n\n"))

def to_file(self, filename: Path):
with open(filename, "w") as f:
f.write(self.result_list())
f.write(self.result_str())

def show(self):
print(self.result_list())
print(self.result_str())

def result_list(self, sep: str = "\n"):
def result_str(self, sep: str = "\n"):
res = []

def append(s: str = None):
Expand All @@ -114,6 +114,9 @@ def append(s: str = None):
append("Warnning: there exists multiple communication operations in the same time. As a result, "
"the profiling result is not accurate.")

if self.total_cuda_time == 0:
return "No collective communication has been called yet!"

append("Collective communication profiling result:")
append("total cuda time: {}".format(_format_time(self.total_cuda_time)))
append("average bandwidth: {}".format(_format_bandwidth(self.total_comm_vol, self.total_cuda_time)))
Expand Down
8 changes: 4 additions & 4 deletions colossalai/utils/profiler/pcie_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,16 @@ def disable(self):
self.profiler = None

def to_tensorboard(self, writer):
writer.add_text(tag="Data Transmission", text_string=self.result_list("\n\n"))
writer.add_text(tag="Data Transmission", text_string=self.result_str("\n\n"))

def to_file(self, filename: Path):
with open(filename, "w") as f:
f.write(self.result_list())
f.write(self.result_str())

def show(self):
print(self.result_list())
print(self.result_str())

def result_list(self, sep: str = "\n"):
def result_str(self, sep: str = "\n"):
res = []

def append(s: str = None):
Expand Down

0 comments on commit 32296cf

Please sign in to comment.