Skip to content

Commit

Permalink
add device to CUDAEvent (pytorch#9415)
Browse files Browse the repository at this point in the history
Summary:
This PR adds a device_ member to CUDAEvent. This is necessary because creating a cudaEvent on one device and destroying it from another also creates an additional context on that other device. The device information is therefore needed to guard the cudaEventDestroy call. (cc: ngimel is this expected behavior? I can provide a simple cu script to repro this).

The c10d tests are probably not in CI yet; please let me know how the tests are run and I can double-check.

Thanks pietern apaszke for help debugging!
Pull Request resolved: pytorch#9415

Reviewed By: apaszke

Differential Revision: D8839688

Pulled By: ailzhang

fbshipit-source-id: b950ba37d57b9e3c5fe71726ec92f6a9601c4d0e
  • Loading branch information
Ailing Zhang authored and facebook-github-bot committed Jul 14, 2018
1 parent a4f6357 commit fda0340
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
9 changes: 8 additions & 1 deletion torch/lib/c10d/CUDAUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@
namespace c10d {

// Creates a CUDA event with the given flags and records the device that is
// current at the time of creation alongside it.
CUDAEvent CUDAEvent::create(unsigned int flags) {
// NOTE(review): this listing is a rendered diff; the declaration on the next
// line appears to be the pre-change line that the commit replaces with the
// two-argument form further down — confirm against the actual file.
CUDAEvent event;
int current_device;
// Query the current device so it can be stored in the wrapper.
C10D_CUDA_CHECK(cudaGetDevice(&current_device));
CUDAEvent event(nullptr, current_device);

// Fill in the event handle of the wrapper constructed above.
C10D_CUDA_CHECK(cudaEventCreateWithFlags(&event.event_, flags));
return event;
}

// Destroys the wrapped event, if there is one.
CUDAEvent::~CUDAEvent() {
  // Nothing to release for an empty wrapper.
  if (event_ == nullptr) {
    return;
  }

  // Switch to the device recorded for this event before destroying it;
  // calling cudaEventDestroy from a different device would create an
  // extra context on that device as well.
  at::DeviceGuard guard(device_);

  C10D_CUDA_CHECK(cudaEventDestroy(event_));
}
Expand Down
11 changes: 9 additions & 2 deletions torch/lib/c10d/CUDAUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ namespace c10d {
// RAII wrapper for CUDA events.
class CUDAEvent {
public:
// NOTE(review): this listing is a rendered diff, so the old and new
// constructor signatures appear as adjacent pairs below — confirm against
// the actual header.
CUDAEvent(cudaEvent_t event) : event_(event) {}
// Wraps the given event handle and stores the given device index with it.
CUDAEvent(cudaEvent_t event, int device) : device_(device), event_(event) {}

CUDAEvent() : CUDAEvent(nullptr) {}
// Default construction: no event (nullptr) and device index 0.
CUDAEvent() : CUDAEvent(nullptr, 0) {}

~CUDAEvent();

Expand All @@ -27,19 +27,26 @@ class CUDAEvent {
// Must be move constructible.
//
// Takes over the event handle and device index of `other`, leaving `other`
// in the default state (null event, device 0).
//
// The delegation to the default constructor is required: the members have no
// default initializers, so swapping into a freshly constructed object without
// it would read uninitialized values and hand `other` a garbage event handle
// that its destructor would then try to destroy.
CUDAEvent(CUDAEvent&& other) noexcept : CUDAEvent() {
  std::swap(event_, other.event_);
  std::swap(device_, other.device_);
}

// Must be move assignable.
//
// Exchanges the event handle and device index with `other`. Whatever this
// object held before ends up in `other` and is released when `other` is
// destroyed. Swapping plain handles and ints cannot throw, hence noexcept.
CUDAEvent& operator=(CUDAEvent&& other) noexcept {
  std::swap(event_, other.event_);
  std::swap(device_, other.device_);
  return *this;
}

// Returns the wrapped event handle (nullptr for a default-constructed
// wrapper). The wrapper keeps the handle; the destructor still destroys it.
cudaEvent_t getEvent() const {
return event_;
}

// Returns the device index stored with this event (0 for a
// default-constructed wrapper).
int getDevice() const {
return device_;
}

protected:
int device_;
cudaEvent_t event_;
};

Expand Down

0 comments on commit fda0340

Please sign in to comment.