Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Conv fusion optimizations in optimizeForIdeep #9255

Closed
wants to merge 4 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Use CAFFE_ENFORCE, instead of assert
Signed-off-by: Gu, Jinghui <jinghui.gu@intel.com>
  • Loading branch information
gujinghui committed Jul 15, 2018
commit aedefb846360ef5234ec8bc895b061f884bc1afe
14 changes: 7 additions & 7 deletions caffe2/opt/optimize_ideep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ USE_IDEEP_DEF_ALIASES();

Blob *getBlob(repr::NNGraph::NodeRef node, caffe2::Workspace *ws) {
auto tensor = repr::nn::get<repr::Tensor>(node);
assert(ws->HasBlob(tensor->getName()) && "Blob not in workspace");
CAFFE_ENFORCE(ws->HasBlob(tensor->getName()), "Blob not in workspace");
return ws->GetBlob(tensor->getName());
}

template <class T> T *getTensor(Blob *blob) {
assert(blob && "Blob is invalid");
CAFFE_ENFORCE(blob, "Blob is invalid");
if (blob && blob->template IsType<T>()) {
return blob->template GetMutable<T>();
}
Expand Down Expand Up @@ -77,19 +77,19 @@ void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
}

if (op->type() == "ConvFusion") {
assert(fusion_type == 1 && "Invalid nest fusion");
CAFFE_ENFORCE(fusion_type == 1, "Invalid nest fusion");
for (auto &arg : *op->mutable_arg()) {
if (arg.name() == "fusion_type") {
// Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU
assert(arg.i() == 2 && "Invalid nest fusion");
CAFFE_ENFORCE(arg.i() == 2, "Invalid nest fusion");
arg.set_i(3);
return;
}
}
return;
}

assert(fusion_type < 3 && "Invalid fusion type");
CAFFE_ENFORCE(fusion_type < 3, "Invalid fusion type");
op->set_type("ConvFusion");
auto *arg = op->add_arg();
arg->set_name("fusion_type");
Expand Down Expand Up @@ -155,7 +155,7 @@ bool fuseConvBNHelperForIdeep(repr::NNModule *nn, caffe2::Workspace *ws) {
} \
itensor name##Tensor({name->get_dims(), name->get_data_type()}); \
name##Tensor.reorder_from(*name); \
assert(name##Tensor.is_public_format() && #name " not with public format"); \
CAFFE_ENFORCE(name##Tensor.is_public_format(), #name " not with public format"); \
auto *name##Data = static_cast<float *>(name##Tensor.get_data_handle());

EXPOSE_TENSOR_DATA(filter, 1, convInputs);
Expand Down Expand Up @@ -270,7 +270,7 @@ void fuseConvSumForIdeep(repr::NNModule *nn, caffe2::Workspace *ws) {
auto convOutput = repr::nn::getOutputs(convNode).front();
repr::NNGraph::NodeRef sumInputX =
(sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]);
assert(sumInputX != nullptr && "Invalid sum inputs");
CAFFE_ENFORCE(sumInputX != nullptr, "Invalid sum inputs");

auto preNode = repr::nn::getProducer(sumInputX);
if (preNode == nullptr || !repr::nn::is<repr::NeuralNetOperator>(preNode)) {
Expand Down