Commit

fix tests
yzh119 committed Nov 19, 2022
1 parent f1411ff commit e4199ac
Showing 6 changed files with 60 additions and 67 deletions.
17 changes: 8 additions & 9 deletions .github/workflows/unittest.yml → .github/workflows/build.yml
@@ -1,14 +1,14 @@
-name: Unittest
+name: Build
 on:
   push:
     branches:
       - main
 jobs:
-  cleaner:
-    runs-on: self-hosted
-    steps:
-      - name: Clean up previous runs
-        run: rm -rf "${{ github.workspace }}"
+  # cleaner:
+  #   runs-on: self-hosted
+  #   steps:
+  #     - name: Clean up previous runs
+  #       run: rm -rf "${{ github.workspace }}"
 
   build-docker-image:
     runs-on: self-hosted
@@ -26,9 +26,8 @@ jobs:
       - name: Build Docker image
         run: |
-          cd docker
           DOCKER_BUILDKIT=1 docker build \
-            . \
+            docker/ \
             --file Dockerfile.ci_gpu \
             --tag ${{ steps.generate-tag.outputs.tag }}
@@ -37,4 +36,4 @@ jobs:
     needs: build-docker-image
     steps:
       - name: Run tests
-        run: docker run ${{ needs.build-docker-image.outputs.tag }} bash tests/scripts/task_python_sparsetir_unittest.sh
+        run: docker run ${{ needs.build-docker-image.outputs.tag }} bash ../tests/scripts/task_python_sparsetir_unittest.sh
3 changes: 3 additions & 0 deletions README.md
@@ -21,6 +21,9 @@ SparseTIR: Sparse Tensor Compiler for Deep Learning
 [Documentation](https://sampl.cs.washington.edu/sparsetir/) |
 [Paper](https://arxiv.org/abs/2207.04606)
 
+[![Build Status](https://github.com/uwsampl/sparsetir/actions/workflows/build.yml/badge.svg)](https://github.com/uwsampl/sparsetir/actions/workflows/build.yml)
+[![Documentation Status](https://github.com/uwsampl/sparsetir/actions/workflows/docs.yml/badge.svg)](https://github.com/uwsampl/sparsetir/actions/workflows/docs.yml)
+
 SparseTIR is a tensor-level compiler for sparse/irregular operators in Deep Learning. The design goal of SparseTIR is to provide a general programming abstraction that can cover both sparse and irregular (e.g. Ragged Tensors) workloads in Deep Learning including Graph Neural Networks, Sparse Transformers, Sparse Convolutions, Network Pruning, etc. while generating high-performance code on heterogeneous hardware.
 
 The key innovation of SparseTIR is *composability*:
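As a rough illustration of the programming abstraction this README paragraph describes, a CSR-style format declaration in SparseTIR might look like the sketch below. It is assembled from the frontend constructs that appear verbatim in the test diffs later in this commit (`T.dense_fixed`, `T.sparse_variable`, `T.alloc_sparse_buffer`); the kernel name and default arguments are illustrative assumptions, not code from the repository.

```python
import tvm
from tvm.script import tir as T

@T.prim_func
def csr_format_demo(indptr: T.handle, indices: T.handle,
                    m: T.int32, n: T.int32, nnz: T.int32) -> None:
    # A dense row axis I of length m, and a sparse column axis J whose
    # per-row nonzeros are described by the CSR arrays indptr/indices.
    I = T.dense_fixed(m, idtype="int32")
    J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32")
    # A value buffer laid out over the (I, J) sparse iteration space.
    A = T.alloc_sparse_buffer([I, J], dtype="float32", extra_storage=0)
```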
25 changes: 1 addition & 24 deletions tests/python/sparsetir/test_horizontal_fuse.py
@@ -28,7 +28,7 @@ def original(
     C1: T.Buffer[(128,), "float32"],
     C2: T.Buffer[(64,), "float32"],
 ) -> None:
-    T.func_attr({"horizontal_fuse": "sequential"})
+    T.func_attr({"horizontal_fuse": 1})
     for i, j in T.grid(128, 128):
         with T.block("first"):
             vi, vj = T.axis.remap("SR", [i, j])
@@ … @@
             C2[vi] = T.float32(0)
             C2[vi] = C2[vi] + B[vi, vj]
 
-# from tvm.script import tir as T
-@T.prim_func
-def local_alloc(A: T.Buffer[(200,), "float32"], B: T.Buffer[(200,), "float32"]) -> None:
-    # var definition
-    blockIdx_x = T.env_thread("blockIdx.x")
-    # body
-    T.launch_thread(blockIdx_x, 200)
-    # if blockIdx_x < 100:
-    C_local = T.allocate([1], "float32", "local")
-    C_local[0] = T.float32(0)
-    A[blockIdx_x] = C_local[0]
-    # else:
-    C_local_1 = T.allocate([1], "float32", "local")
-    C_local_1[0] = T.float32(0)
-    B[blockIdx_x] = C_local_1[0]
-
 
 def test_end_to_end():
     sch = tvm.tir.Schedule(original)
@@ -101,12 +85,5 @@ def test_end_to_end():
     # tvm.testing.assert_allclose(z2.numpy(), z2_golden, rtol=1e-5, atol=1e-5)
 
 
-def test_local_alloc():
-    mod = tvm.IRModule.from_expr(local_alloc)
-    mod = tir.transform.StorageRewrite()(mod)
-    print(mod["main"].script())
-
-
 if __name__ == "__main__":
     test_end_to_end()
-    # test_local_alloc()
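The "horizontal_fuse" marker changed above is an ordinary PrimFunc attribute (now the integer 1 rather than the string "sequential"), so it can be inspected directly. A minimal sketch, assuming the `original` prim_func from this test is in scope:

```python
# `original` is the @T.prim_func defined in the hunk above.
# After this commit, the attribute holds the integer 1.
print(original.attrs["horizontal_fuse"])  # -> 1
```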
51 changes: 30 additions & 21 deletions tests/python/sparsetir/test_merging_binary_search.py
@@ -34,44 +34,53 @@ def func(indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, nnz: T.int32) -> None:
 def lowered(indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, nnz: T.int32) -> None:
     # function attr dict
     T.func_attr({"global_symbol": "main", "tir.noalias": True, "sparse_tir_level": 1})
-    I = T.dense_fixed(m, "int32")
-    J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32")
-    J_dense = T.dense_variable(I, (n, nnz), indptr, "int32")
+    I = T.dense_fixed(m, idtype="int32")
+    J = T.sparse_variable(I, (n, nnz), (indptr, indices), idtype="int32", sorted=True)
+    J_dense = T.dense_variable(I, (n, nnz), indptr, idtype="int32")
     J_indptr = T.match_sparse_buffer(indptr, [I], dtype="int32", extra_storage=1)
     J_indices = T.match_sparse_buffer(indices, [I, J_dense], dtype="int32")
     # body
     # with T.block("root")
     A = T.alloc_sparse_buffer([I, J], dtype="float32", extra_storage=0)
-    low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
-    high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
-    mid_0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
-    for v_vj in T.serial(nnz):
-        with T.block("binary_search_0"):
-            ax1 = T.axis.spatial(nnz, v_vj)
+    mid_0 = T.alloc_sparse_buffer([I, J], dtype="int32", extra_storage=0)
+    T.assume_buffer_domain(J_indptr, [0, nnz])
+    T.assume_buffer_domain(J_indices, [0, n])
+    T.assume_buffer_domain(mid_0, [0, m])
+    for vj in T.serial(nnz):
+        with T.block("binary_search_block_0_0"):
+            vvi = T.axis.spatial(1, 0)
+            vvj = T.axis.spatial(nnz, vj)
             T.reads(J_indptr[0 : m + 1])
-            T.writes(mid_0[0])
-            T.block_attr({"sparse": True})
+            T.writes(mid_0[vvi, vvj])
+            T.block_attr({"preprocess": True, "sparse": True})
+            low = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
+            high = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local")
             low[0] = 0
             high[0] = m + 1
+            mid_0[vvi, vvj] = low[0] + (high[0] - low[0]) // 2
             while low[0] < high[0]:
-                mid_0[0] = low[0] + (high[0] - low[0]) // 2
-                if J_indptr[mid_0[0]] > ax1:
-                    high[0] = mid_0[0]
+                if J_indptr[mid_0[vvi, vvj]] > vvj:
+                    high[0] = mid_0[vvi, vvj]
                 else:
-                    low[0] = mid_0[0] + 1
-            mid_0[0] = mid_0[0] - 1
+                    low[0] = mid_0[vvi, vvj] + 1
+                mid_0[vvi, vvj] = low[0] + (high[0] - low[0]) // 2
+            mid_0[vvi, vvj] = mid_0[vvi, vvj] - 1
+    for vj in T.serial(nnz):
         with T.block("test0"):
-            vi = T.axis.spatial(1, 0)
-            vj = T.axis.spatial(nnz, v_vj)
-            T.reads(mid_0[0], J_indices[vi, vj])
-            T.writes(A[vi, vj])
+            vvi = T.axis.spatial(1, 0)
+            vvj = T.axis.spatial(nnz, vj)
+            T.reads(mid_0[vvi, vvj], J_indices[vvi, vvj])
+            T.writes(A[vvi, vvj])
             T.block_attr({"sparse": True})
-            A[vi, vj] = (mid_0[0] + J_indices[vi, vj]) * (mid_0[0] - J_indices[vi, vj])
+            A[vvi, vvj] = (mid_0[vvi, vvj] + J_indices[vvi, vvj]) * (
+                mid_0[vvi, vvj] - J_indices[vvi, vvj]
+            )
 
 
 def test_merging_binary_search():
     mod = tvm.IRModule.from_expr(func)
     mod = lower_sparse_iter(mod)
     print(mod["main"].script())
     tvm.ir.assert_structural_equal(mod["main"], lowered)


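The lowered `binary_search_block_0_0` block above implements the standard nonzero-to-row mapping for CSR data: for each nonzero position `vj`, an upper-bound search over `J_indptr` followed by a decrement yields the row `i` with `indptr[i] <= vj < indptr[i+1]`. A plain NumPy sketch of the same semantics (an illustration of what the lowered code computes, not SparseTIR API):

```python
import numpy as np

def nnz_to_row(indptr, nnz):
    # np.searchsorted(..., side="right") - 1 mirrors the lowered loop:
    # an upper-bound binary search over indptr, then mid - 1.
    return np.searchsorted(indptr, np.arange(nnz), side="right") - 1

indptr = np.array([0, 2, 2, 5], dtype=np.int32)  # 3 rows, 5 nonzeros
print(nnz_to_row(indptr, 5))  # [0 0 2 2 2] -- row index of each nonzero
```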
15 changes: 10 additions & 5 deletions tests/python/sparsetir/test_tir_sparse_script_roundtrip.py
@@ -29,14 +29,15 @@
     "bsrmm",
     "ellmm",
     "csr_element_wise",
-    # "bmm",
+    "bmm",
     "sddmm",
     "fused_sddmm",
     "square_sum",
     "square_sum_two_K",
     "fused_reduction_4d_2d",
     "fused_reduction_4d_3d",
-    "rgcn_forward",
+    "rgcn_homo_forward",
+    "rgcn_hetero_forward",
     "sparse_softmax",
     "csr2bsr",
 ]
@@ -122,9 +123,13 @@ def specialize_fused_reduction_4d_3d(f):
     return f.specialize({N: 16, NNZ_J: 128, NNZ_K: 256, NNZ_L: 1024})
 
 
-def specialize_rgcn_forward(f):
-    N, R, FEAT_SIZE, NNZ = f.params[-4:]
-    return f.specialize({N: 128, R: 16, FEAT_SIZE: 128, NNZ: 1024})
+def specialize_rgcn_homo_forward(f):
+    M, N, R, FEAT_SIZE, NNZ = f.params[-5:]
+    return f.specialize({M: 128, N: 128, R: 16, FEAT_SIZE: 128, NNZ: 1024})
+
+def specialize_rgcn_hetero_forward(f):
+    M, N, R, FEAT_SIZE, NNZ_I, NNZ_J = f.params[-6:]
+    return f.specialize({M: 128, N: 128, R: 16, FEAT_SIZE: 128, NNZ_I: 32, NNZ_J: 1024})
 
 
 def specialize_sparse_softmax(f):
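For context, the `specialize_*` helpers in this file bind the symbolic shape parameters of each TVMScript function (taken from `f.params`) to concrete values so the printed script can be re-parsed and compared. A minimal sketch of such a roundtrip check, assuming this harness shape (the helper name `check_roundtrip` is an assumption, not code from this file):

```python
import tvm
import tvm.script

def check_roundtrip(func, specializer=None):
    # Bind symbolic shapes (e.g. M, N, NNZ) to constants first, if needed.
    if specializer is not None:
        func = specializer(func)
    # Print back to TVMScript, re-parse, and require structural equality.
    reparsed = tvm.script.from_source(func.script())
    tvm.ir.assert_structural_equal(func, reparsed)
```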
16 changes: 8 additions & 8 deletions tests/scripts/task_lint.sh
@@ -49,8 +49,8 @@ tests/lint/cpplint.sh
 echo "clang-format check..."
 tests/lint/git-clang-format.sh
 
-echo "Rust check..."
-tests/lint/rust_format.sh
+# echo "Rust check..."
+# tests/lint/rust_format.sh
 
 echo "black check..."
 tests/lint/git-black.sh
@@ -59,11 +59,11 @@ echo "Linting the Python code..."
 tests/lint/pylint.sh
 tests/lint/flake8.sh
 
-echo "Linting the JNI code..."
-tests/lint/jnilint.sh
+# echo "Linting the JNI code..."
+# tests/lint/jnilint.sh
 
-echo "Checking C++ documentation..."
-tests/lint/cppdocs.sh
+# echo "Checking C++ documentation..."
+# tests/lint/cppdocs.sh
 
-echo "Type checking with MyPy ..."
-tests/scripts/task_mypy.sh
+# echo "Type checking with MyPy ..."
+# tests/scripts/task_mypy.sh
