Skip to content

Commit

Permalink
updated cuda-to-pytorch.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dusty-nv committed May 2, 2023
1 parent b29cf49 commit 5b2e2be
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions python/examples/cuda-to-pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
print("failed to import torch - if you wish to test PyTorch interoperability, please install it")
sys.exit(0)

from jetson_utils import cudaAllocMapped
from jetson_utils import cudaImage


# parse the command line
Expand All @@ -44,19 +44,18 @@
print(args)

# allocate cuda memory
cuda_img = cudaAllocMapped(width=args.width, height=args.height, format=args.format)
cuda_img = cudaImage(width=args.width, height=args.height, format=args.format)

print(cuda_img)
print(cuda_img.__cuda_array_interface__)

# map to torch tensor using numba __cuda_array_interface__
# map to torch tensor using __cuda_array_interface__
tensor = torch.as_tensor(cuda_img, device='cuda')

print("\nPyTorch tensor:\n")
print(type(tensor))
print(hex(tensor.data_ptr()))
print(tensor.dtype)
print(tensor.shape)
print(f" -- ptr: {hex(tensor.data_ptr())}")
print(f" -- type: {tensor.dtype}")
print(f" -- shape: {tensor.shape}\n")
print(tensor)

# modify PyTorch tensor
Expand Down

0 comments on commit 5b2e2be

Please sign in to comment.