Skip to content

Commit

Permalink
Validate features of input request (pytorch#482)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#482

Why is validation in `populateTrecRequestFromTensorMap` in `TrecGpuUtils.cpp`?
- because I didn't want to iterate through all the features twice. we already have to iterate through all the features once to convert the request to a torchrec request, to be efficient, the validation should happen during this conversion as well
- if `populateTrecRequestFromTensorMap` throws an exception, it is handled by `TrecGpuMon.cpp` already

Reviewed By: zyan0

Differential Revision: D37385890

fbshipit-source-id: 119fb13fbe8e37fb28c08c3e726ab914162b97be
  • Loading branch information
s4ayub authored and facebook-github-bot committed Jun 29, 2022
1 parent 5fc0ae4 commit f80c740
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
26 changes: 26 additions & 0 deletions torchrec/inference/include/torchrec/inference/Validation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include "torchrec/inference/Types.h"

namespace torchrec {

// Returns whether sparse features (KeyedJaggedTensor) are valid.
// Currently validates:
// 1. Whether sum(lengths) == size(values)
// 2. Whether there are negative values in lengths
bool validateSparseFeatures(at::Tensor& values, at::Tensor& lengths);

// Returns whether dense features are valid.
// Currently validates:
// 1. Whether the size of values is divisable by batch size (request level)
bool validateDenseFeatures(at::Tensor& values, size_t batchSize);

} // namespace torchrec
50 changes: 50 additions & 0 deletions torchrec/inference/src/Validation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "torchrec/inference/Validation.h"
#include "ATen/Functions.h"

namespace torchrec {

bool validateSparseFeatures(at::Tensor& values, at::Tensor& lengths) {
auto flatLengths = lengths.view(-1);

// validate sum of lengths equals number of values
auto lengthsTotal = at::sum(flatLengths).item<int>();
if (lengthsTotal != values.size(0)) {
return false;
}

// Validate no negative values in lengths.
// Use faster path if contiguous.
if (flatLengths.is_contiguous()) {
int* ptr = (int*)flatLengths.data_ptr();
for (int i = 0; i < flatLengths.numel(); ++i) {
if (*ptr < 0) {
return false;
}
ptr++;
}
} else {
// accessor does boundary check (slower)
auto acc = flatLengths.accessor<int, 1>();
for (int i = 0; i < acc.size(0); i++) {
if (acc[i] < 0) {
return false;
}
}
}

return true;
}

bool validateDenseFeatures(at::Tensor& values, size_t batchSize) {
return values.size(0) % batchSize == 0;
}

} // namespace torchrec
39 changes: 39 additions & 0 deletions torchrec/inference/tests/ValidationTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "torchrec/inference/Validation.h"

#include <ATen/ATen.h>
#include <gtest/gtest.h>

TEST(ValidationTest, validateSparseFeatures) {
auto values = at::tensor({1, 2, 3, 4});
auto lengths = at::tensor({1, 1, 1, 1});

// pass 1D
EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths));

// pass 2D
lengths.reshape({2, 2});
EXPECT_TRUE(torchrec::validateSparseFeatures(values, lengths));

// fail 1D
auto invalidLengths = at::tensor({1, 2, 1, 1});
EXPECT_FALSE(torchrec::validateSparseFeatures(values, invalidLengths));

// fail 2D
invalidLengths.reshape({2, 2});
EXPECT_FALSE(torchrec::validateSparseFeatures(values, invalidLengths));
}

TEST(ValidationTest, validateDenseFeatures) {
auto values = at::tensor({1, 2, 3, 4});
EXPECT_TRUE(torchrec::validateDenseFeatures(values, 1));
EXPECT_TRUE(torchrec::validateDenseFeatures(values, 4));
EXPECT_FALSE(torchrec::validateDenseFeatures(values, 3));
}

0 comments on commit f80c740

Please sign in to comment.