Skip to content

Commit

Permalink
Refine code by clang-format
Browse files Browse the repository at this point in the history
Signed-off-by: Gu, Jinghui <jinghui.gu@intel.com>
  • Loading branch information
gujinghui committed Jul 12, 2018
1 parent 1d801ed commit 90213c0
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 39 deletions.
73 changes: 37 additions & 36 deletions caffe2/opt/optimize_ideep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,51 +12,51 @@ namespace opt {
using namespace nom;

#ifndef CAFFE2_USE_IDEEP
// No-op fallback used when Caffe2 is built without IDEEP support:
// warns and leaves the graph untouched.
// (Diff artifact fixed: the scraped hunk carried both the pre- and
// post-clang-format signature; only the post-commit form is kept.)
void OptimizeForIdeep(repr::NNModule *nn, caffe2::Workspace *ws,
                      bool training_mode) {
  LOG(WARNING) << "Only support optimizations for IDEEP";
}

#else
USE_IDEEP_DEF_ALIASES();

// Returns the workspace Blob backing the tensor node `node`.
// Debug builds assert that the blob actually exists in `ws`; release
// builds fall through to GetBlob, which may return nullptr.
Blob *getBlob(repr::NNGraph::NodeRef node, caffe2::Workspace *ws) {
  auto tensor = repr::nn::get<repr::Tensor>(node);
  assert(ws->HasBlob(tensor->getName()) && "Blob not in workspace");
  return ws->GetBlob(tensor->getName());
}

// Returns the blob's payload as a mutable T*, or nullptr when the blob
// is null or holds a different type. Debug builds additionally assert
// on a null blob.
template <class T> T *getTensor(Blob *blob) {
  assert(blob && "Blob is invalid");
  // Re-check `blob` so release builds (assert compiled out) stay safe.
  if (blob && blob->template IsType<T>()) {
    return blob->template GetMutable<T>();
  }
  return nullptr;
}

// Returns the read-only OperatorDef attached to `nnOp` via its
// Caffe2Annotation; throws if the operator carries no annotation.
const caffe2::OperatorDef &getOpDef(const repr::NeuralNetOperator &nnOp) {
  auto annotation = nnOp.getAnnotation();
  if (annotation == nullptr) {
    CAFFE_THROW("Cannot get Operator annotation");
  }
  return dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();
}

// Mutable counterpart of getOpDef: returns the OperatorDef attached to
// `nnOp` for in-place editing; throws if no annotation is present.
caffe2::OperatorDef *getMutableOpDef(repr::NeuralNetOperator &nnOp) {
  auto annotation = nnOp.getMutableAnnotation();
  if (annotation == nullptr) {
    CAFFE_THROW("Cannot get Operator annotation");
  }
  return dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
}

// True iff the operator's device option names the IDEEP device type.
// Used as the gate for all IDEEP-only fusions below.
bool isOnIdeepDevice(const repr::NeuralNetOperator &nnOp) {
  // We only want to fuse for IDEEP convs
  const auto &op = getOpDef(nnOp);
  return op.device_option().device_type() == DeviceType::IDEEP;
}

// A conv is a fusion candidate only when it runs on IDEEP and is not a
// grouped convolution (group > 1 is not supported by the fused kernel).
bool shouldFuseConv(const repr::Conv &conv) {
  return isOnIdeepDevice(conv) ? (conv.getGroup() <= 1) : false;
}

Expand All @@ -71,14 +71,14 @@ void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
return;
}

auto* op = getMutableOpDef(*conv);
auto *op = getMutableOpDef(*conv);
if (op == nullptr) {
return;
}

if (op->type() == "ConvFusion") {
assert(fusion_type == 1 && "Invalid nest fusion");
for (auto& arg : *op->mutable_arg()) {
for (auto &arg : *op->mutable_arg()) {
if (arg.name() == "fusion_type") {
// Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU
assert(arg.i() == 2 && "Invalid nest fusion");
Expand All @@ -91,24 +91,24 @@ void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {

assert(fusion_type < 3 && "Invalid fusion type");
op->set_type("ConvFusion");
auto* arg = op->add_arg();
auto *arg = op->add_arg();
arg->set_name("fusion_type");
arg->set_i(fusion_type);
}

bool fuseConvBNHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
bool fuseConvBNHelperForIdeep(repr::NNModule *nn, caffe2::Workspace *ws) {
for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
bool no_bias = false;
repr::NNGraph::NodeRef convNode;
repr::Conv* conv;
repr::Conv *conv;
std::tie(conv, convNode) = node_pair;

if (!isOnIdeepDevice(*conv)) {
LOG(WARNING) << "Not a IDEEP operator";
continue;
}

const auto& op = getOpDef(*conv);
const auto &op = getOpDef(*conv);
if (op.type() == "ConvFusion") {
continue;
}
Expand Down Expand Up @@ -148,15 +148,15 @@ bool fuseConvBNHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
}

#define EXPOSE_TENSOR_DATA(name, index, nodes) \
auto* name = getTensor<itensor>(getBlob(nodes[index], ws)); \
if (name == nullptr) { \
LOG(WARNING) << #name " not a IDEEP tensor"; \
continue; \
} \
itensor name##Tensor({name->get_dims(), name->get_data_type()}); \
name##Tensor.reorder_from(*name); \
assert(name##Tensor.is_public_format() && #name " not with public format");\
auto* name##Data = static_cast<float*>(name##Tensor.get_data_handle());
auto *name = getTensor<itensor>(getBlob(nodes[index], ws)); \
if (name == nullptr) { \
LOG(WARNING) << #name " not a IDEEP tensor"; \
continue; \
} \
itensor name##Tensor({name->get_dims(), name->get_data_type()}); \
name##Tensor.reorder_from(*name); \
assert(name##Tensor.is_public_format() && #name " not with public format"); \
auto *name##Data = static_cast<float *>(name##Tensor.get_data_handle());

EXPOSE_TENSOR_DATA(filter, 1, convInputs);
EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
Expand All @@ -170,7 +170,7 @@ bool fuseConvBNHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {

// Assume M{CHW,HWC}
auto chwDim = filterTensor.get_dim(1) * filterTensor.get_dim(2) *
filterTensor.get_dim(3);
filterTensor.get_dim(3);
for (auto c = 0; c < filterTensor.get_dim(0); ++c) {
float coeff =
scaleData[c] / std::sqrt(varianceData[c] + bn->getEpsilon());
Expand All @@ -180,7 +180,8 @@ bool fuseConvBNHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
if (no_bias) {
biasConvData[c] = biasBNData[c] - meanData[c] * coeff;
} else {
biasConvData[c] = biasBNData[c] + (biasConvData[c] - meanData[c]) * coeff;
biasConvData[c] =
biasBNData[c] + (biasConvData[c] - meanData[c]) * coeff;
}
}

Expand All @@ -197,12 +198,12 @@ bool fuseConvBNHelperForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
return false;
}

void fuseConvBNForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
void fuseConvBNForIdeep(repr::NNModule *nn, caffe2::Workspace *ws) {
while (fuseConvBNHelperForIdeep(nn, ws)) {
}
}

void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
void fuseConvSumForIdeep(repr::NNModule *nn, caffe2::Workspace *ws) {
// Assume the order of nodes from getMutableNodes conforms to
// the original topo order of operators
auto allNodes = nn->dataFlow.getMutableNodes();
Expand Down Expand Up @@ -268,7 +269,7 @@ void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {

auto convOutput = repr::nn::getOutputs(convNode).front();
repr::NNGraph::NodeRef sumInputX =
(sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]);
(sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]);
assert(sumInputX != nullptr && "Invalid sum inputs");

auto preNode = repr::nn::getProducer(sumInputX);
Expand Down Expand Up @@ -296,34 +297,34 @@ void fuseConvSumForIdeep(repr::NNModule* nn, caffe2::Workspace* ws) {
}
}

// Fuses Conv followed by Relu into a single ConvFusion op (fusion_type 1),
// guarded by shouldFuseConv (IDEEP device, non-grouped conv).
void fuseActivationForIdeep(repr::NNModule *nn) {
  // Conv+Relu fusion
  auto should_fuse = shouldFuseConv;
  auto postprocess = std::bind(resetConvForFusion, std::placeholders::_1, 1);
  fuseActivation<repr::Conv, repr::Relu>(nn, should_fuse, postprocess);
}

void enforceFusionInplaceForIdeep(repr::NNModule* nn) {
void enforceFusionInplaceForIdeep(repr::NNModule *nn) {
// For fusions of Conv+Sum or Conv+Sum+ReLU, the last input and output must
// be inplaced. To enforce inplace, here to re-check whole graph and correct
// the ConvFusion Ops.
for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
repr::NNGraph::NodeRef convNode;
repr::Conv* conv;
repr::Conv *conv;
std::tie(conv, convNode) = node_pair;

if (!isOnIdeepDevice(*conv)) {
LOG(WARNING) << "Not a IDEEP operator";
continue;
}

const auto& op = getOpDef(*conv);
const auto &op = getOpDef(*conv);
if (op.type() != "ConvFusion") {
continue;
}

bool enforce_inplace = false;
for (const auto& arg : op.arg()) {
for (const auto &arg : op.arg()) {
// Only check FUSION_SUM & FUSION_SUM_RELU
if (arg.name() == "fusion_type" && (arg.i() == 2 || arg.i() == 3)) {
enforce_inplace = true;
Expand Down Expand Up @@ -358,7 +359,8 @@ void enforceFusionInplaceForIdeep(repr::NNModule* nn) {
}
}

void OptimizeForIdeep(repr::NNModule* nn, caffe2::Workspace* ws, bool training_mode) {
void OptimizeForIdeep(repr::NNModule *nn, caffe2::Workspace *ws,
bool training_mode) {
if (training_mode) {
// Only support inference so far
return;
Expand All @@ -371,7 +373,6 @@ void OptimizeForIdeep(repr::NNModule* nn, caffe2::Workspace* ws, bool training_m
fuseActivationForIdeep(nn);

enforceFusionInplaceForIdeep(nn);

}

#endif // CAFFE2_USE_IDEEP
Expand Down
6 changes: 3 additions & 3 deletions caffe2/opt/optimize_ideep.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#pragma once

#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2.pb.h"
#include "nomnigraph/Representations/NeuralNet.h"

namespace caffe2 {
namespace opt {

// Applies IDEEP-specific graph optimizations (conv/BN/sum/activation
// fusions) to `nn`. No-op when training_mode is true — inference only.
void OptimizeForIdeep(nom::repr::NNModule *nn, caffe2::Workspace *ws,
                      bool training_mode = false);
}
} // namespace caffe2

0 comments on commit 90213c0

Please sign in to comment.