
Adds DLPack support #57110

Closed
wants to merge 14 commits
Changes from 1 commit
Review changes and test
Emilio Castillo committed Sep 8, 2021
commit aa107ef2b7a9afb24c3938f1c1c6dfb3056cfaa7
44 changes: 42 additions & 2 deletions test/test_torch.py
@@ -7095,8 +7095,8 @@ def compare_strides(s1, s2, div):
_test_helper(x, op, unary=True)

@skipMeta
@dtypes(*get_all_dtypes())
def test_dlpack_conversion(self, device, dtype):
@dtypes(*torch.testing.get_all_dtypes())
def test_dlpack_capsule_conversion(self, device, dtype):
# DLpack does not explicitly support bool
# It does it through uint8 type
if dtype is torch.bool:
@@ -7105,6 +7105,46 @@ def test_dlpack_conversion(self, device, dtype):
z = from_dlpack(to_dlpack(x))
self.assertEqual(z, x)
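As the comment in the test notes, DLPack (at the time of this PR) has no bool type, so bool data goes through uint8. A minimal sketch of that workaround, assuming a CPU tensor:

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

# DLPack has no kDLBool here, so export a uint8 copy and cast back on import
b = torch.tensor([True, False, True])
z = from_dlpack(to_dlpack(b.to(torch.uint8))).to(torch.bool)
assert torch.equal(z, b)
```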

@skipMeta
@dtypes(*torch.testing.get_all_dtypes())
def test_dlpack_protocol_conversion(self, device, dtype):
Collaborator:
We should add a test that compares the DLPack semantics with our .numpy() and from_numpy() semantics. If/when NumPy implements the protocol we could really validate that the behavior is the same.

Collaborator Author (@emcastillo, Jul 7, 2021):
Sorry, I don't understand what should be compared here.

Collaborator Author (@emcastillo):
Now I think I understood it: a test that checks that tensors from DLPack can't be resized, or that tensors with gradients can't be exported, in the same sense as NumPy. Is this correct?

Collaborator:
Yes: tensors with gradients or the conjugate bit set, I think, can't be exported.

For importing we should verify that the underlying memory is shared by writing to it on both CPU and CUDA.

NumPy arrays can also be non-writable, which we check for on import. I'm not sure what (if any) special properties DLPack capsules have that PyTorch can't emulate.
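The import-side check suggested here can be sketched like this (a CPU-only illustration; the CUDA case would be analogous):

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

x = torch.arange(5)
z = from_dlpack(to_dlpack(x))
z[0] = 42                  # write through the imported tensor...
assert x[0].item() == 42   # ...and the exporter observes it: memory is shared
```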

# DLpack does not explicitly support bool
# It does it through uint8 type
Collaborator:
Maybe add (xref dmlc/dlpack#75) - adding kDLBool is being discussed there.

Collaborator:
The first 6 lines of each test are the same, maybe worth leaving the comment only in the first test, and shortening the rest to:

if dtype is torch.bool or 'xla' in device:
    return

Collaborator Author (@emcastillo, Jun 30, 2021):
Much better! Thanks :)

if dtype is torch.bool:
return
x = make_tensor((5,), device, dtype, low=-9, high=9)
z = from_dlpack(x)
self.assertEqual(z, x)

@skipMeta
@dtypes(*torch.testing.get_all_dtypes())
def test_dlpack_conversion_with_streams(self, device, dtype):
# DLpack does not explicitly support bool
# It does it through uint8 type
if dtype is torch.bool:
return
# Create a stream where the tensor will reside
if device == 'cuda':
x = make_tensor((5,), device, dtype, low=-9, high=9)
stream = torch.cuda.Stream()
Collaborator (@mruberry, Jul 6, 2021):
Do you need a synchronize call here to ensure x has been populated on a CUDA device before it's accessed on a different stream?

Collaborator Author:
Here we want to test how the dlpack interface does this synchronization, so we avoid doing it here.

Collaborator Author:
I added a comment to clarify this.

with torch.cuda.stream(stream):
assert stream.query()
z = from_dlpack(to_dlpack(x))
assert not stream.query()
assert stream.query()
self.assertEqual(z, x)

@skipMeta
@dtypes(*torch.testing.get_all_dtypes())
def test_dlpack_tensor_invalid_stream(self, device, dtype):
# DLpack does not explicitly support bool
# It does it through uint8 type
if dtype is torch.bool:
return
with self.assertRaises(TypeError):
x = make_tensor((5,), device, dtype, low=-9, high=9)
x.__dlpack__(stream=object())

@onlyCUDA
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
def test_pin_memory_from_constructor(self, device):
38 changes: 23 additions & 15 deletions torch/_tensor.py
@@ -1,4 +1,5 @@
from collections import OrderedDict
import enum
import functools
from numbers import Number
from typing import Any, Dict, Optional, Tuple, Union
@@ -1063,18 +1064,19 @@ def __dlpack__(self, stream=None):

Args:
stream (integer or None): A Python integer representing a pointer
to a stream. `stream` is provided by the consumer to the producer
to instruct the producer to ensure that operations can safely be
performed on the array. The pointer must be a positive integer or
to a stream (CUDA or ROCm). `stream` is provided by the consumer
to the producer to instruct the producer to ensure that operations
can safely be performed on the array.
The pointer must be a positive integer or
-1. If stream is -1, the value may be used by the consumer to
signal "producer must not perform any synchronization". Optional.
"""
if stream is not None and type(stream) is not int:
# currently in pytorch is not possible to create a stream
# from a given pointer
# Stream pointers in CUDA/ROCm are uniquely numbered and can
# be retrieved from their integer value.
raise TypeError('stream must be ``int`` or ``none``')
elif stream is not None and stream != -1:
if self.device.type in ('cuda', 'rocm'):
if self.device.type == 'cuda':
stream = torch.cuda.streams.ExternalStream(stream)
# Only synchronize on different streams
if stream != torch.cuda.current_stream():
@@ -1083,16 +1085,22 @@ def __dlpack__(self, stream=None):
torch.cuda.current_stream().wait_event(event)
return torch.utils.dlpack.to_dlpack(self)
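For illustration, the producer-side protocol above can be exercised directly. On CPU no stream is involved, so `stream` stays `None` (a sketch against the API this PR adds):

```python
import torch
from torch.utils.dlpack import from_dlpack

t = torch.arange(4, dtype=torch.float32)
capsule = t.__dlpack__()   # stream=None: consumer requests no synchronization
z = from_dlpack(capsule)   # hand the capsule to a consumer
assert torch.equal(z, t)
```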

def __dlpack_device__(self) -> Tuple[int, int]:
# TODO(ecastill)
# Add support for the following devices
# CPU = 1 CPU_PINNED = 3 OPENCL = 4 VULKAN = 7
# METAL = 8 VPI = 9
dlpack_ids = {'cpu': 1, 'cuda': 2, 'rocm': 10}
def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
class DLPackIds(enum.IntEnum):
cpu = 1
cpu_pinned = 3
cuda = 2
opencl = 4
vulkan = 7
rocm = 10

idx = self.device.index if self.device.index is not None else 0
Collaborator:
This still has TODOs. I think it would be nice if this returned Tuple[enum.IntEnum, int] as in the spec: https://data-apis.org/array-api/latest/API_specification/array_object.html#dlpack-device-self

Collaborator Author:
This is a bit out-of-scope, but if we were to support these other devices, how would the stream support work?
Should it be ignored in environments where a stream does not make any sense?

Contributor:
I think what @rgommers meant is to change the return type of this function:

def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:

For __dlpack_device__, whether a device has the concept of stream/queue doesn't matter. For __dlpack__, stream can be Any:
https://data-apis.org/array-api/latest/API_specification/array_object.html#dlpack-self-stream-none

# TODO(ecastill) detect HIP or CUDA
# in torch rocm device is cuda too
return (dlpack_ids[self.device.type], idx)
device_type = self.device.type
if device_type == 'cuda' and torch.version.hip is not None:
device_type = 'rocm'
elif device_type == 'cpu' and self.is_pinned():
device_type = 'cpu_pinned'
emcastillo marked this conversation as resolved.
Show resolved Hide resolved
return (DLPackIds[device_type], idx)
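A quick check of the returned device tuple: a plain CPU tensor should map to kDLCPU (= 1 in the DLPack device enum) with device index 0.

```python
import torch

t = torch.zeros(2)
dev_type, dev_id = t.__dlpack_device__()
assert int(dev_type) == 1 and dev_id == 0   # (kDLCPU, device index)
```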

__module__ = 'torch'

19 changes: 11 additions & 8 deletions torch/utils/dlpack.py
@@ -1,3 +1,5 @@
from typing import Any

import torch

from torch._C import _from_dlpack
@@ -15,17 +17,18 @@
""")


def from_dlpack(ext_tensor) -> torch.Tensor:
def from_dlpack(ext_tensor: Any) -> torch.Tensor:
"""from_dlpack(ext_tensor) -> Tensor

Decodes a DLPack to a tensor.

Args:
ext_tensor: a PyCapsule object with the dltensor
Converts a tensor from an external library into a ``torch.Tensor``
Collaborator (@mruberry, Jul 6, 2021):
"Converts a DLPack capsule or object implementing the DLPack protocol into a tensor."

The phrase "DLPack capsule" should be a link to the DLPack documentation.

Collaborator:
This comment still needs to be resolved.

Collaborator:
> The phrase "DLPack capsule" should be a link to the DLPack documentation

There's no such thing yet. DLPack itself only has comments in dlpack.h. __dlpack__ is part of the array API standard, see here.

Collaborator Author:
I am fine with adding the link to the API standard, would it be possible to do it in another PR?
Thanks!

by means of the ``__dlpack__`` protocol.

The tensor will share the memory with the object represented
in the dlpack.
Note that each dlpack can only be consumed once.

For backward compatibility, this function also accepts a DLPack capsule object.
Note that each dlpack capsule can only be consumed once.

Args:
ext_tensor (object with __dlpack__ attribute or dlpack capsule):
@@ -45,5 +48,5 @@ def from_dlpack(ext_tensor) -> torch.Tensor:
dlpack = ext_tensor.__dlpack__()
else:
# Old versions just call the converter
dlpack = tensor
_from_dlpack(dlpack)
dlpack = ext_tensor
return _from_dlpack(dlpack)
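Both acceptance paths of from_dlpack, plus the consume-once behavior of capsules noted in the docstring, can be sketched as:

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

x = torch.arange(3)
y = from_dlpack(x)       # new path: x itself implements __dlpack__
assert torch.equal(y, x)

cap = to_dlpack(x)
z = from_dlpack(cap)     # legacy path: a raw DLPack capsule
assert torch.equal(z, x)

# Each capsule can be consumed only once; a second import fails.
failed = False
try:
    from_dlpack(cap)
except Exception:
    failed = True
assert failed
```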