Burn Torch Backend

This crate provides a Torch backend for Burn utilizing the tch-rs crate, which offers a Rust interface to the PyTorch C++ API.

The backend supports CPU (multithreaded), CUDA (multiple GPUs), and MPS devices (MacOS).

Installation

tch-rs requires the C++ PyTorch library (LibTorch) to be available on your system.

By default, the CPU distribution of LibTorch v2.2.0 (the version required by tch-rs) is installed.

CUDA

To install the latest compatible CUDA distribution, set the TORCH_CUDA_VERSION environment variable before the tch-rs dependency is retrieved with cargo.

export TORCH_CUDA_VERSION=cu121

On Windows (PowerShell):

$Env:TORCH_CUDA_VERSION = "cu121"
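To illustrate how the variable drives the choice of distribution, the hypothetical helper below maps a TORCH_CUDA_VERSION value to the distribution tag that appears in the LibTorch download URLs, falling back to `cpu` when the variable is unset or empty. This is a minimal sketch under that assumption, not tch-rs's actual build script; `distribution_tag` is an illustrative name.

```rust
use std::env;

/// Hypothetical helper (not tch-rs's real build logic): picks the LibTorch
/// distribution tag from the value of TORCH_CUDA_VERSION.
fn distribution_tag(cuda_version: Option<&str>) -> String {
    match cuda_version.map(str::trim) {
        // Unset or empty: fall back to the default CPU distribution.
        None | Some("") => "cpu".to_string(),
        // Otherwise use the tag as given, e.g. "cu121" or "cu118".
        Some(tag) => tag.to_lowercase(),
    }
}

fn main() {
    // Read the variable the way a build script would see it at compile time.
    let value = env::var("TORCH_CUDA_VERSION").ok();
    println!("LibTorch distribution: {}", distribution_tag(value.as_deref()));
}
```

Because the tag is resolved when the dependency is fetched and built, changing the variable afterwards has no effect until the crate is rebuilt.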

For example, to run the CUDA validation sample for the first time:

export TORCH_CUDA_VERSION=cu121
cargo run --bin cuda --release

Important: make sure your driver version is compatible with the selected CUDA version. A CUDA Toolkit installation is not required, since LibTorch ships with the appropriate CUDA runtimes. Using the latest driver is recommended; otherwise, consult NVIDIA's CUDA Toolkit driver version table or the minimum required driver versions (which only provide a limited feature set and might not work with all operations).
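Driver versions have to be compared numerically, part by part, rather than as strings. The sketch below assumes dotted numeric versions such as those reported by `nvidia-smi`; `meets_minimum` is a hypothetical helper, and the numbers in `main` are illustrative placeholders, not authoritative minimums (take those from NVIDIA's table).

```rust
/// Parses a dotted driver version such as "535.104.05" into numeric parts.
fn parse_version(v: &str) -> Option<Vec<u32>> {
    v.trim().split('.').map(|part| part.parse::<u32>().ok()).collect()
}

/// Returns Some(true) when `installed` is at least `minimum`, comparing
/// numerically part by part (so "535.9" < "535.10"). Missing parts count
/// as zero; None means one of the strings could not be parsed.
fn meets_minimum(installed: &str, minimum: &str) -> Option<bool> {
    let a = parse_version(installed)?;
    let b = parse_version(minimum)?;
    for i in 0..a.len().max(b.len()) {
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        if x != y {
            return Some(x > y);
        }
    }
    Some(true) // equal versions satisfy the minimum
}

fn main() {
    // Illustrative values only: look up the real minimum for your CUDA version.
    let installed = "535.104.05";
    let minimum = "525.60.13";
    match meets_minimum(installed, minimum) {
        Some(true) => println!("driver {} is new enough", installed),
        Some(false) => println!("driver {} is older than {}", installed, minimum),
        None => println!("could not parse a version string"),
    }
}
```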


Once the installation is complete, you should be able to build and run your project. You can also validate the installation by running the sample that matches your device (cpu, cuda or mps):

cargo run --bin cpu --release
cargo run --bin cuda --release
cargo run --bin mps --release

Note: no MPS distribution is currently available for automatic download; please follow the manual Metal (MPS) instructions below.

Manual Download

To use tch-rs with a different LibTorch distribution, you have to download that distribution manually. The instructions for each platform are detailed in the sections below.

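| Compute Platform | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
|------------------|---------|-----|-------|-------|---------|---------|-----|------|
| CPU | Yes | No | Yes | Yes | Yes | Yes | Yes | No |
| CUDA | Yes [1] | Yes | Yes | No | Yes | No | No | No |
| Metal (MPS) | No | Yes | No | Yes | No | No | No | No |
| Vulkan | Yes | Yes | Yes | Yes | Yes | Yes | No | No |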

[1] The LibTorch CUDA distribution also comes with CPU support.
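The operating-system columns of this table can also be encoded as data, for instance so that a setup script rejects an unsupported platform/distribution combination early. The sketch below transcribes only those columns (CPU/GPU columns omitted); the `Distribution` and `Os` enums and `is_supported` are hypothetical names, not part of burn-tch or tch-rs.

```rust
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Distribution { Cpu, Cuda, Mps, Vulkan }

#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Os { Linux, MacOs, Windows, Android, Ios, Wasm }

/// Operating-system support, transcribed row by row from the table.
fn is_supported(dist: Distribution, os: Os) -> bool {
    match dist {
        // CPU: every listed OS except WASM.
        Distribution::Cpu => !matches!(os, Os::Wasm),
        // CUDA: Linux and Windows only.
        Distribution::Cuda => matches!(os, Os::Linux | Os::Windows),
        // Metal (MPS): macOS only.
        Distribution::Mps => matches!(os, Os::MacOs),
        // Vulkan: desktop platforms plus Android.
        Distribution::Vulkan => matches!(os, Os::Linux | Os::MacOs | Os::Windows | Os::Android),
    }
}

fn main() {
    for os in [Os::Linux, Os::MacOs, Os::Windows] {
        println!(
            "{:?}: CUDA = {}, MPS = {}",
            os,
            is_supported(Distribution::Cuda, os),
            is_supported(Distribution::Mps, os)
        );
    }
}
```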

CPU

🐧 Linux

First, download the LibTorch CPU distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcpu.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and LD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH
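A small preflight check can catch a mistyped path before a long build fails. The sketch below assumes the unzipped distribution contains `lib` and `include` subdirectories (the layout of the LibTorch archives), then verifies that `LIBTORCH` points at such a directory and that its `lib` directory is listed in `LD_LIBRARY_PATH`. `check_libtorch` and `on_search_path` are hypothetical helpers, not part of burn-tch or tch-rs.

```rust
use std::env;
use std::path::{Path, PathBuf};

/// Hypothetical preflight check: returns the LibTorch `lib` directory if
/// `root` looks like an unzipped LibTorch distribution.
fn check_libtorch(root: &Path) -> Result<PathBuf, String> {
    for sub in ["lib", "include"] {
        if !root.join(sub).is_dir() {
            return Err(format!("{} has no `{}` directory", root.display(), sub));
        }
    }
    Ok(root.join("lib"))
}

/// True when `dir` is one of the entries of a PATH-style variable value.
/// Path comparison is component-wise, so trailing slashes do not matter.
fn on_search_path(dir: &Path, search_path: &str) -> bool {
    env::split_paths(search_path).any(|entry| entry.as_path() == dir)
}

fn main() {
    let root = match env::var("LIBTORCH") {
        Ok(root) => root,
        Err(_) => {
            eprintln!("LIBTORCH is not set");
            return;
        }
    };
    match check_libtorch(Path::new(&root)) {
        Ok(lib) => {
            let ld = env::var("LD_LIBRARY_PATH").unwrap_or_default();
            if on_search_path(&lib, &ld) {
                println!("LibTorch looks usable at {}", root);
            } else {
                eprintln!("{} is not listed in LD_LIBRARY_PATH", lib.display());
            }
        }
        Err(msg) => eprintln!("{}", msg),
    }
}
```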

🍎 Mac

First, download the LibTorch CPU distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-2.2.0.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and DYLD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH
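The `export DYLD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$DYLD_LIBRARY_PATH` idiom prepends the LibTorch `lib` directory so that it is searched first. If a launcher program needs to build the same value (for example to pass it to a child process with `std::process::Command::env`), the standard library's `split_paths` and `join_paths` take care of the platform separator. This is a minimal sketch; `prepend_dir` is a hypothetical helper, and it also drops empty entries and duplicates, which the shell idiom does not.

```rust
use std::env;
use std::ffi::OsString;
use std::path::{Path, PathBuf};

/// Returns `existing` with `dir` placed in front, mirroring
/// `export DYLD_LIBRARY_PATH=<dir>:$DYLD_LIBRARY_PATH`. Empty entries and
/// duplicates of `dir` are dropped.
fn prepend_dir(dir: &Path, existing: Option<OsString>) -> Result<OsString, env::JoinPathsError> {
    let mut entries: Vec<PathBuf> = vec![dir.to_path_buf()];
    if let Some(value) = existing {
        entries.extend(
            env::split_paths(&value)
                .filter(|p| !p.as_os_str().is_empty() && p.as_path() != dir),
        );
    }
    // Fails only if a path itself contains the separator character.
    env::join_paths(entries)
}

fn main() {
    let lib = Path::new("/absolute/path/to/libtorch/lib");
    match prepend_dir(lib, env::var_os("DYLD_LIBRARY_PATH")) {
        Ok(value) => println!("DYLD_LIBRARY_PATH={}", value.to_string_lossy()),
        Err(err) => eprintln!("cannot build the search path: {}", err),
    }
}
```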

🪟 Windows

First, download the LibTorch CPU distribution.

wget https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.2.0%2Bcpu.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

Then, set the LIBTORCH environment variable and append the LibTorch lib directory to your Path with the PowerShell commands below before building burn-tch or a crate which depends on it.

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/lib"

CUDA

LibTorch 2.2.0 currently includes binary distributions with CUDA 11.8 or 12.1 runtimes. The manual installation instructions are detailed below.
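All the manual download links in this document follow one naming pattern, so the URL for a given version and distribution tag (`cpu`, `cu118` or `cu121`) can be assembled programmatically. The sketch below reproduces the Linux, Windows and macOS CPU links used here; the `Platform` enum and `libtorch_url` are hypothetical, and the pattern is inferred from these links only, so other LibTorch versions may use different file names.

```rust
#[allow(dead_code)]
#[derive(Clone, Copy)]
enum Platform {
    Linux,
    Windows,
    MacOsX86,
}

/// Builds a LibTorch download URL following the pattern of the links in
/// this README. Returns None for combinations without a linked archive.
fn libtorch_url(platform: Platform, version: &str, tag: &str) -> Option<String> {
    const BASE: &str = "https://download.pytorch.org/libtorch";
    let file = match platform {
        Platform::Linux => format!("libtorch-cxx11-abi-shared-with-deps-{}%2B{}.zip", version, tag),
        Platform::Windows => format!("libtorch-win-shared-with-deps-{}%2B{}.zip", version, tag),
        // Only a CPU archive is linked for macOS.
        Platform::MacOsX86 if tag == "cpu" => format!("libtorch-macos-x86_64-{}.zip", version),
        Platform::MacOsX86 => return None,
    };
    Some(format!("{}/{}/{}", BASE, tag, file))
}

fn main() {
    for tag in ["cpu", "cu118", "cu121"] {
        if let Some(url) = libtorch_url(Platform::Linux, "2.2.0", tag) {
            println!("{}", url);
        }
    }
}
```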

CUDA 11.8

🐧 Linux

First, download the LibTorch CUDA 11.8 distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu118.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and LD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

Note: make sure your CUDA installation is in your PATH and LD_LIBRARY_PATH.


🪟 Windows

First, download the LibTorch CUDA 11.8 distribution.

wget https://download.pytorch.org/libtorch/cu118/libtorch-win-shared-with-deps-2.2.0%2Bcu118.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

Then, set the LIBTORCH environment variable and append the LibTorch lib directory to your Path with the PowerShell commands below before building burn-tch or a crate which depends on it.

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/lib"

CUDA 12.1

🐧 Linux

First, download the LibTorch CUDA 12.1 distribution.

wget -O libtorch.zip https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip
unzip libtorch.zip

Then, point to that installation using the LIBTORCH and LD_LIBRARY_PATH environment variables before building burn-tch or a crate which depends on it.

export LIBTORCH=/absolute/path/to/libtorch/
export LD_LIBRARY_PATH=/absolute/path/to/libtorch/lib:$LD_LIBRARY_PATH

Note: make sure your CUDA installation is in your PATH and LD_LIBRARY_PATH.


🪟 Windows

First, download the LibTorch CUDA 12.1 distribution.

wget https://download.pytorch.org/libtorch/cu121/libtorch-win-shared-with-deps-2.2.0%2Bcu121.zip -OutFile libtorch.zip
Expand-Archive libtorch.zip

Then, set the LIBTORCH environment variable and append the LibTorch lib directory to your Path with the PowerShell commands below before building burn-tch or a crate which depends on it.

$Env:LIBTORCH = "/absolute/path/to/libtorch/"
$Env:Path += ";/absolute/path/to/libtorch/lib"

Metal (MPS)

There is no official LibTorch distribution with MPS support at this time, so the easiest alternative is to use a PyTorch installation. This requires a Python installation.

Note: MPS acceleration is available on MacOS 12.3+.

pip install torch==2.2.0
export LIBTORCH_USE_PYTORCH=1
export DYLD_LIBRARY_PATH=/path/to/pytorch/lib:$DYLD_LIBRARY_PATH
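To fill in `/path/to/pytorch/lib`, the directory can be derived from the installed package: `torch/__init__.py` sits next to the `torch/lib` directory that holds the shared libraries. The sketch below asks a `python3` interpreter for `torch.__file__` and prints the matching export line. It assumes `python3` is on your PATH and is the interpreter that `pip install torch` used; `torch_lib_dir` is a hypothetical helper.

```rust
use std::path::{Path, PathBuf};
use std::process::Command;

/// Given the path of torch's `__init__.py`, returns the sibling `lib` directory.
fn torch_lib_dir(torch_init: &Path) -> Option<PathBuf> {
    torch_init.parent().map(|pkg| pkg.join("lib"))
}

fn main() {
    // Assumes `python3` is the interpreter that `pip install torch` installed into.
    let output = Command::new("python3")
        .args(["-c", "import torch; print(torch.__file__)"])
        .output();
    match output {
        Ok(out) if out.status.success() => {
            let init = String::from_utf8_lossy(&out.stdout).trim().to_string();
            match torch_lib_dir(Path::new(&init)) {
                Some(lib) => println!("export DYLD_LIBRARY_PATH={}:$DYLD_LIBRARY_PATH", lib.display()),
                None => eprintln!("unexpected torch path: {}", init),
            }
        }
        _ => eprintln!("could not import torch with python3"),
    }
}
```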

Example Usage

For a simple example, check out any of the test programs in src/bin/. Each program sets the device to use and performs a simple element-wise addition.

For a more complete example using the tch backend, take a look at the Burn MNIST example.