Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the number of memory allocations in def-use #4904

Merged
merged 16 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Ensure AllDefinitions own StorageMap
Signed-off-by: Anton Korobeynikov <anton@korobeynikov.info>
  • Loading branch information
asl committed Sep 10, 2024
commit b90f7caea353a066b410ce0dfa00e396531a4566
29 changes: 14 additions & 15 deletions frontends/p4/def_use.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ void ComputeWriteSet::enterScope(const IR::ParameterList *parameters,

if (parameters != nullptr) {
for (auto p : parameters->parameters) {
StorageLocation *loc = allDefinitions->storageMap->getOrAdd(p);
StorageLocation *loc = allDefinitions->getOrAddStorage(p);
if (loc == nullptr) continue;
if (p->direction == IR::Direction::In || p->direction == IR::Direction::InOut ||
p->direction == IR::Direction::None)
Expand All @@ -408,7 +408,7 @@ void ComputeWriteSet::enterScope(const IR::ParameterList *parameters,
if (locals != nullptr) {
for (auto d : *locals) {
if (d->is<IR::Declaration_Variable>()) {
StorageLocation *loc = allDefinitions->storageMap->getOrAdd(d);
StorageLocation *loc = allDefinitions->getOrAddStorage(d);
if (loc != nullptr) {
defs->setDefinition(loc, uninit);
auto valid = loc->getValidBits();
Expand All @@ -432,14 +432,14 @@ void ComputeWriteSet::exitScope(const IR::ParameterList *parameters,
currentDefinitions = currentDefinitions->cloneDefinitions();
if (parameters != nullptr) {
for (auto p : parameters->parameters) {
StorageLocation *loc = allDefinitions->storageMap->getStorage(p);
StorageLocation *loc = allDefinitions->getStorage(p);
if (loc != nullptr) currentDefinitions->removeLocation(loc);
}
}
if (locals != nullptr) {
for (auto d : *locals) {
if (d->is<IR::Declaration_Variable>()) {
StorageLocation *loc = allDefinitions->storageMap->getStorage(d);
StorageLocation *loc = allDefinitions->getStorage(d);
if (loc != nullptr) currentDefinitions->removeLocation(loc);
}
}
Expand Down Expand Up @@ -549,8 +549,8 @@ bool ComputeWriteSet::preorder(const IR::PathExpression *expression) {
expressionWrites(expression, LocationSet::empty);
return false;
}
auto decl = storageMap->refMap->getDeclaration(expression->path, true);
auto storage = storageMap->getStorage(decl);
auto decl = refMap->getDeclaration(expression->path, true);
auto storage = allDefinitions->getStorage(decl);
const LocationSet *result;
if (storage != nullptr)
result = new LocationSet(storage);
Expand All @@ -566,15 +566,15 @@ bool ComputeWriteSet::preorder(const IR::Member *expression) {
expressionWrites(expression, LocationSet::empty);
return false;
}
auto type = storageMap->typeMap->getType(expression, true);
auto type = typeMap->getType(expression, true);
if (type->is<IR::Type_Method>()) return false;
if (TableApplySolver::isHit(expression, storageMap->refMap, storageMap->typeMap) ||
TableApplySolver::isMiss(expression, storageMap->refMap, storageMap->typeMap) ||
TableApplySolver::isActionRun(expression, storageMap->refMap, storageMap->typeMap))
if (TableApplySolver::isHit(expression, refMap, typeMap) ||
TableApplySolver::isMiss(expression, refMap, typeMap) ||
TableApplySolver::isActionRun(expression, refMap, typeMap))
return false;
auto storage = getWrites(expression->expr);

auto basetype = storageMap->typeMap->getType(expression->expr, true);
auto basetype = typeMap->getType(expression->expr, true);
if (basetype->is<IR::Type_Stack>()) {
if (expression->member.name == IR::Type_Stack::next ||
expression->member.name == IR::Type_Stack::last) {
Expand Down Expand Up @@ -675,7 +675,7 @@ bool ComputeWriteSet::preorder(const IR::MethodCallExpression *expression) {
// The method call may modify the object, which is part of the method
visit(expression->method);
lhs = save;
auto mi = MethodInstance::resolve(expression, storageMap->refMap, storageMap->typeMap);
auto mi = MethodInstance::resolve(expression, refMap, typeMap);
if (auto bim = mi->to<BuiltInMethod>()) {
const loc_t *methodLoc = getLoc(expression->method, getChildContext());
auto base = getWrites(bim->appliedTo, methodLoc);
Expand Down Expand Up @@ -811,7 +811,7 @@ bool ComputeWriteSet::preorder(const IR::P4Parser *parser) {
visitVirtualMethods(parser->parserLocals);

ParserCallGraph transitions("transitions");
ComputeParserCG pcg(storageMap->refMap, &transitions);
ComputeParserCG pcg(refMap, &transitions);
pcg.setCalledBy(this);

(void)parser->apply(pcg);
Expand Down Expand Up @@ -1013,8 +1013,7 @@ bool ComputeWriteSet::preorder(const IR::SwitchStatement *statement) {
visit(s->statement);
result = result->joinDefinitions(currentDefinitions);
}
auto table = TableApplySolver::isActionRun(statement->expression, storageMap->refMap,
storageMap->typeMap);
auto table = TableApplySolver::isActionRun(statement->expression, refMap, typeMap);
if (table) {
auto al = table->getActionList();
bool allCases = statement->cases.size() == al->size();
Expand Down
33 changes: 24 additions & 9 deletions frontends/p4/def_use.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,11 @@ class AllDefinitions : public IHasDbPrint {
/// P4Table, P4Function -- the definitions are BEFORE the
/// ProgramPoint.
hvec_map<ProgramPoint, Definitions *> atPoint;
StorageMap storageMap;

public:
StorageMap *storageMap;
AllDefinitions(ReferenceMap *refMap, TypeMap *typeMap)
: storageMap(new StorageMap(refMap, typeMap)) {}
AllDefinitions(ReferenceMap *refMap, TypeMap *typeMap) : storageMap(refMap, typeMap) {}

Definitions *getDefinitions(ProgramPoint point, bool emptyIfNotFound = false) {
auto it = atPoint.find(point);
if (it == atPoint.end()) {
Expand All @@ -511,6 +511,15 @@ class AllDefinitions : public IHasDbPrint {
}
atPoint[point] = defs;
}

StorageLocation *getStorage(const IR::IDeclaration *decl) const {
return storageMap.getStorage(decl);
}

StorageLocation *getOrAddStorage(const IR::IDeclaration *decl) {
return storageMap.getOrAdd(decl);
}

void dbprint(std::ostream &out) const override {
for (auto e : atPoint) out << e.first << " => " << e.second << Log::endl;
}
Expand All @@ -529,16 +538,19 @@ class AllDefinitions : public IHasDbPrint {

class ComputeWriteSet : public Inspector, public IHasDbPrint {
public:
explicit ComputeWriteSet(AllDefinitions *allDefinitions)
: allDefinitions(allDefinitions),
explicit ComputeWriteSet(AllDefinitions *allDefinitions, ReferenceMap *refMap, TypeMap *typeMap)
: refMap(refMap),
typeMap(typeMap),
allDefinitions(allDefinitions),
currentDefinitions(nullptr),
returnedDefinitions(nullptr),
exitDefinitions(new Definitions()),
storageMap(allDefinitions->storageMap),
lhs(false),
virtualMethod(false),
cached_locs(*new std::unordered_set<loc_t>) {
CHECK_NULL(allDefinitions);
CHECK_NULL(refMap);
CHECK_NULL(typeMap);
visitDagOnce = false;
}

Expand Down Expand Up @@ -588,14 +600,15 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
}

protected:
ReferenceMap *refMap;
TypeMap *typeMap;
AllDefinitions *allDefinitions; /// Result computed by this pass.
Definitions *currentDefinitions; /// Before statement currently processed.
Definitions *returnedDefinitions; /// Definitions after return statements.
Definitions *exitDefinitions; /// Definitions after exit statements.
Definitions *breakDefinitions = nullptr; /// Definitions at break statements.
Definitions *continueDefinitions = nullptr; /// Definitions at continue statements.
ProgramPoint callingContext;
const StorageMap *storageMap;
/// if true we are processing an expression on the lhs of an assignment
bool lhs;
/// For each program location the location set it writes
Expand All @@ -609,14 +622,16 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint {
/// Needed to visit some program fragments repeatedly.
ComputeWriteSet(const ComputeWriteSet *source, ProgramPoint context, Definitions *definitions,
std::unordered_set<loc_t> &cached_locs)
: allDefinitions(source->allDefinitions),
: refMap(source->refMap),
typeMap(source->typeMap),

allDefinitions(source->allDefinitions),
currentDefinitions(definitions),
returnedDefinitions(nullptr),
exitDefinitions(source->exitDefinitions),
breakDefinitions(source->breakDefinitions),
continueDefinitions(source->continueDefinitions),
callingContext(context),
storageMap(source->storageMap),
lhs(false),
virtualMethod(false),
cached_locs(cached_locs) {
Expand Down
46 changes: 23 additions & 23 deletions frontends/p4/simplifyDefUse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include "simplifyDefUse.h"

#include "absl/container/flat_hash_set.h"
#include "frontends/common/resolveReferences/referenceMap.h"
#include "frontends/p4/def_use.h"
#include "frontends/p4/methodInstance.h"
#include "frontends/p4/parserCallGraph.h"
Expand Down Expand Up @@ -135,7 +136,7 @@ DeclarationToExpression *DeclarationToExpression::instance = nullptr;
class HeaderDefinitions : public IHasDbPrint {
ReferenceMap *refMap;
TypeMap *typeMap;
StorageMap *storageMap;
AllDefinitions *definitions;

/// The current values of the header valid bits are stored here. If the value in the map is Yes,
/// then the header is currently valid. If the value in the map is No, then the header is
Expand All @@ -151,11 +152,11 @@ class HeaderDefinitions : public IHasDbPrint {
absl::flat_hash_set<const StorageLocation *, Util::Hash> notReport;

public:
HeaderDefinitions(ReferenceMap *refMap, TypeMap *typeMap, StorageMap *storageMap)
: refMap(refMap), typeMap(typeMap), storageMap(storageMap) {
HeaderDefinitions(ReferenceMap *refMap, TypeMap *typeMap, AllDefinitions *definitions)
: refMap(refMap), typeMap(typeMap), definitions(definitions) {
CHECK_NULL(refMap);
CHECK_NULL(typeMap);
CHECK_NULL(storageMap);
CHECK_NULL(definitions);
}

void dbprint(std::ostream &out) const {
Expand All @@ -171,7 +172,7 @@ class HeaderDefinitions : public IHasDbPrint {
LocationSet *result = new LocationSet;
if (auto expr = expression->to<IR::PathExpression>()) {
auto decl = refMap->getDeclaration(expr->path, true);
result->add(storageMap->getStorage(decl));
result->add(definitions->getStorage(decl));
} else if (auto expr = expression->to<IR::Member>()) {
auto base_storage = getStorageLocation(expr->expr);
for (auto bs : *base_storage) {
Expand Down Expand Up @@ -330,7 +331,7 @@ class HeaderDefinitions : public IHasDbPrint {
bool operator!=(const HeaderDefinitions &other) const { return !(*this == other); }

HeaderDefinitions *intersect(const HeaderDefinitions *other) const {
HeaderDefinitions *result = new HeaderDefinitions(refMap, typeMap, storageMap);
HeaderDefinitions *result = new HeaderDefinitions(refMap, typeMap, definitions);
for (const auto &def : defs) {
auto valid = ::P4::get(other->defs, def.first, TernaryBool::Maybe);
result->defs.emplace(def.first, valid == def.second ? valid : TernaryBool::Maybe);
Expand Down Expand Up @@ -405,8 +406,8 @@ class FindUninitialized : public Inspector {

FindUninitialized(FindUninitialized *parent, ProgramPoint context)
: context(context),
refMap(parent->definitions->storageMap->refMap),
typeMap(parent->definitions->storageMap->typeMap),
refMap(parent->refMap),
typeMap(parent->typeMap),
definitions(parent->definitions),
currentPoint(context),
hasUses(parent->hasUses),
Expand All @@ -416,13 +417,14 @@ class FindUninitialized : public Inspector {
}

public:
FindUninitialized(AllDefinitions *definitions, HasUses &hasUses)
: refMap(definitions->storageMap->refMap),
typeMap(definitions->storageMap->typeMap),
FindUninitialized(AllDefinitions *definitions, ReferenceMap *refMap, TypeMap *typeMap,
HasUses &hasUses)
: refMap(refMap),
typeMap(typeMap),
definitions(definitions),
currentPoint(),
hasUses(hasUses),
headerDefs(new HeaderDefinitions(refMap, typeMap, definitions->storageMap)) {
headerDefs(new HeaderDefinitions(refMap, typeMap, definitions)) {
CHECK_NULL(refMap);
CHECK_NULL(typeMap);
CHECK_NULL(definitions);
Expand Down Expand Up @@ -453,7 +455,7 @@ class FindUninitialized : public Inspector {
void initHeaderParams(const IR::ParameterList *parameters) {
if (!parameters) return;
for (auto p : parameters->parameters)
if (auto storage = definitions->storageMap->getStorage(p)) {
if (auto storage = definitions->getStorage(p)) {
headerDefs->setValueToStorage(storage, p->direction != IR::Direction::Out
? TernaryBool::Yes
: TernaryBool::No);
Expand All @@ -466,7 +468,7 @@ class FindUninitialized : public Inspector {
<< defs);
for (auto p : parameters->parameters) {
if (p->direction == IR::Direction::Out || p->direction == IR::Direction::InOut) {
auto storage = definitions->storageMap->getStorage(p);
auto storage = definitions->getStorage(p);
LOG3("Checking parameter: " << p);
if (storage == nullptr) continue;

Expand Down Expand Up @@ -614,8 +616,7 @@ class FindUninitialized : public Inspector {
reportInvalidHeaders = true;
for (auto state : parser->states) {
if (inputHeaderDefs.find(state) == inputHeaderDefs.end()) {
inputHeaderDefs.emplace(
state, new HeaderDefinitions(refMap, typeMap, definitions->storageMap));
inputHeaderDefs.emplace(state, new HeaderDefinitions(refMap, typeMap, definitions));
}
headerDefs = inputHeaderDefs[state];
visit(state);
Expand Down Expand Up @@ -662,7 +663,7 @@ class FindUninitialized : public Inspector {
// else let loc be the whole array
} else if (auto pe = parent->to<IR::PathExpression>()) {
auto decl = refMap->getDeclaration(pe->path, true);
auto storage = definitions->storageMap->getStorage(decl);
auto storage = definitions->getStorage(decl);
if (storage != nullptr)
loc = new LocationSet(storage);
else
Expand Down Expand Up @@ -1141,7 +1142,7 @@ class FindUninitialized : public Inspector {
LOG4("Declaration for path '" << expression->path << "' is " << Log::indent << Log::endl
<< decl << Log::unindent);

auto storage = definitions->storageMap->getStorage(decl);
auto storage = definitions->getStorage(decl);
const LocationSet *result;
if (storage != nullptr)
result = new LocationSet(storage);
Expand Down Expand Up @@ -1269,7 +1270,7 @@ class FindUninitialized : public Inspector {
if (auto actionCall = mi->to<ActionCall>()) {
if (auto param = actionCall->action->parameters->getParameter(p->name)) {
if (p->direction == IR::Direction::Out) {
headerDefs->setValueToStorage(definitions->storageMap->getStorage(param),
headerDefs->setValueToStorage(definitions->getStorage(param),
TernaryBool::No);
} else {
// we can treat the argument passing as an assignment
Expand All @@ -1287,8 +1288,7 @@ class FindUninitialized : public Inspector {
if (auto actionCall = mi->to<ActionCall>()) {
for (auto p : actionCall->action->parameters->parameters) {
if (p->direction == IR::Direction::None && !mi->substitution.contains(p)) {
headerDefs->setValueToStorage(definitions->storageMap->getStorage(p),
TernaryBool::Yes);
headerDefs->setValueToStorage(definitions->getStorage(p), TernaryBool::Yes);
}
}
}
Expand Down Expand Up @@ -1542,8 +1542,8 @@ class ProcessDefUse : public PassManager {

public:
ProcessDefUse(ReferenceMap *refMap, TypeMap *typeMap) : definitions(refMap, typeMap) {
passes.push_back(new ComputeWriteSet(&definitions));
passes.push_back(new FindUninitialized(&definitions, hasUses));
passes.push_back(new ComputeWriteSet(&definitions, refMap, typeMap));
passes.push_back(new FindUninitialized(&definitions, refMap, typeMap, hasUses));
passes.push_back(new RemoveUnused(hasUses, refMap, typeMap));
setName("ProcessDefUse");
}
Expand Down