Skip to content

Commit

Permalink
[OpenMP] Add initial support for omp [begin/end] assumes
Browse files Browse the repository at this point in the history
The `assumes` directive is an OpenMP 5.1 feature that allows the user to
provide assumptions to the optimizer. Assumptions can refer to
directives (`absent` and `contains` clauses), expressions (`holds`
clause), or generic properties (`no_openmp_routines`, `ext_ABCD`, ...).

The `assumes` spelling is used for assumptions in the global scope while
`assume` is used for executable contexts with an associated structured
block.

This patch only implements the global spellings. While clauses with
arguments are "accepted" by the parser, they will simply be ignored for
now. The implementation lowers the assumptions directly to the
`AssumptionAttr`.

Reviewed By: ABataev

Differential Revision: https://reviews.llvm.org/D91980
  • Loading branch information
jdoerfert committed Dec 17, 2020
1 parent f48dae3 commit 2e6e4e6
Show file tree
Hide file tree
Showing 15 changed files with 740 additions and 3 deletions.
10 changes: 10 additions & 0 deletions clang/include/clang/Basic/DiagnosticParseKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,16 @@ def err_expected_end_declare_target_or_variant : Error<
def err_expected_begin_declare_variant
: Error<"'#pragma omp end declare variant' with no matching '#pragma omp "
"begin declare variant'">;
def err_expected_begin_assumes
: Error<"'#pragma omp end assumes' with no matching '#pragma omp begin assumes'">;
def warn_omp_unknown_assumption_clause_missing_id
: Warning<"valid %0 clauses start with %1; %select{token|tokens}2 will be ignored">,
InGroup<OpenMPClauses>;
def warn_omp_unknown_assumption_clause_without_args
: Warning<"%0 clause should not be followed by arguments; tokens will be ignored">,
InGroup<OpenMPClauses>;
def note_omp_assumption_clause_continue_here
: Note<"the ignored tokens spans until here">;
def err_omp_declare_target_unexpected_clause: Error<
"unexpected '%0' clause, only %select{'to' or 'link'|'to', 'link' or 'device_type'}1 clauses expected">;
def err_omp_expected_clause: Error<
Expand Down
7 changes: 7 additions & 0 deletions clang/include/clang/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -3104,6 +3104,13 @@ class Parser : public CodeCompletionHandler {
void ParseOMPDeclareVariantClauses(DeclGroupPtrTy Ptr, CachedTokens &Toks,
SourceLocation Loc);

/// Parse 'omp [begin] assume[s]' directive.
void ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
SourceLocation Loc);

/// Parse 'omp end assumes' directive.
void ParseOpenMPEndAssumesDirective(SourceLocation Loc);

/// Parse clauses for '#pragma omp declare target'.
DeclGroupPtrTy ParseOMPDeclareTargetClauses();
/// Parse '#pragma omp end declare target'.
Expand Down
30 changes: 27 additions & 3 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -10128,6 +10128,13 @@ class Sema final {
/// The current `omp begin/end declare variant` scopes.
SmallVector<OMPDeclareVariantScope, 4> OMPDeclareVariantScopes;

/// The current `omp begin/end assumes` scopes.
SmallVector<AssumptionAttr *, 4> OMPAssumeScoped;

/// All `omp assumes` we encountered so far.
SmallVector<AssumptionAttr *, 4> OMPAssumeGlobal;

public:
/// The declarator \p D defines a function in the scope \p S which is nested
/// in an `omp begin/end declare variant` scope. In this method we create a
/// declaration for \p D and rename \p D according to the OpenMP context
Expand All @@ -10141,10 +10148,11 @@ class Sema final {
void ActOnFinishedFunctionDefinitionInOpenMPDeclareVariantScope(
Decl *D, SmallVectorImpl<FunctionDecl *> &Bases);

public:
/// Act on \p D, a function definition inside of an `omp [begin/end] assumes`.
void ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(Decl *D);

/// Can we exit a scope at the moment.
bool isInOpenMPDeclareVariantScope() {
/// Can we exit an OpenMP declare variant scope at the moment.
bool isInOpenMPDeclareVariantScope() const {
return !OMPDeclareVariantScopes.empty();
}

Expand Down Expand Up @@ -10260,6 +10268,22 @@ class Sema final {
ArrayRef<Expr *> VarList,
ArrayRef<OMPClause *> Clauses,
DeclContext *Owner = nullptr);

/// Called on well-formed '#pragma omp [begin] assume[s]'.
void ActOnOpenMPAssumesDirective(SourceLocation Loc,
OpenMPDirectiveKind DKind,
ArrayRef<StringRef> Assumptions,
bool SkippedClauses);

/// Check if there is an active global `omp begin assumes` directive.
bool isInOpenMPAssumeScope() const { return !OMPAssumeScoped.empty(); }

/// Check if there is an active global `omp assumes` directive.
bool hasGlobalOpenMPAssumes() const { return !OMPAssumeGlobal.empty(); }

/// Called on well-formed '#pragma omp end assumes'.
void ActOnOpenMPEndAssumesDirective();

/// Called on well-formed '#pragma omp requires'.
DeclGroupPtrTy ActOnOpenMPRequiresDirective(SourceLocation Loc,
ArrayRef<OMPClause *> ClauseList);
Expand Down
115 changes: 115 additions & 0 deletions clang/lib/Parse/ParseOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "clang/Parse/RAIIObjectsForParser.h"
#include "clang/Sema/Scope.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/UniqueVector.h"
#include "llvm/Frontend/OpenMP/OMPContext.h"

Expand Down Expand Up @@ -115,7 +116,9 @@ static OpenMPDirectiveKindExWrapper parseOpenMPDirectiveKind(Parser &P) {
// TODO: add other combined directives in topological order.
static const OpenMPDirectiveKindExWrapper F[][3] = {
{OMPD_begin, OMPD_declare, OMPD_begin_declare},
{OMPD_begin, OMPD_assumes, OMPD_begin_assumes},
{OMPD_end, OMPD_declare, OMPD_end_declare},
{OMPD_end, OMPD_assumes, OMPD_end_assumes},
{OMPD_cancellation, OMPD_point, OMPD_cancellation_point},
{OMPD_declare, OMPD_reduction, OMPD_declare_reduction},
{OMPD_declare, OMPD_mapper, OMPD_declare_mapper},
Expand Down Expand Up @@ -1508,6 +1511,103 @@ bool Parser::parseOMPDeclareVariantMatchClause(SourceLocation Loc,
return false;
}

/// `omp assumes` or `omp begin/end assumes` <clause> [[,]<clause>]...
/// where
///
/// clause:
/// 'ext_IMPL_DEFINED'
/// 'absent' '(' directive-name [, directive-name]* ')'
/// 'contains' '(' directive-name [, directive-name]* ')'
/// 'holds' '(' scalar-expression ')'
/// 'no_openmp'
/// 'no_openmp_routines'
/// 'no_parallelism'
///
void Parser::ParseOpenMPAssumesDirective(OpenMPDirectiveKind DKind,
SourceLocation Loc) {
SmallVector<StringRef, 4> Assumptions;
bool SkippedClauses = false;

auto SkipBraces = [&](llvm::StringRef Spelling, bool IssueNote) {
BalancedDelimiterTracker T(*this, tok::l_paren,
tok::annot_pragma_openmp_end);
if (T.expectAndConsume(diag::err_expected_lparen_after, Spelling.data()))
return;
T.skipToEnd();
if (IssueNote && T.getCloseLocation().isValid())
Diag(T.getCloseLocation(),
diag::note_omp_assumption_clause_continue_here);
};

/// Helper to determine which AssumptionClauseMapping (ACM) in the
/// AssumptionClauseMappings table matches \p RawString. The return value is
/// the index of the matching ACM into the table or -1 if there was no match.
auto MatchACMClause = [&](StringRef RawString) {
llvm::StringSwitch<int> SS(RawString);
unsigned ACMIdx = 0;
for (const AssumptionClauseMappingInfo &ACMI : AssumptionClauseMappings) {
if (ACMI.StartsWith)
SS.StartsWith(ACMI.Identifier, ACMIdx++);
else
SS.Case(ACMI.Identifier, ACMIdx++);
}
return SS.Default(-1);
};

while (Tok.isNot(tok::annot_pragma_openmp_end)) {
IdentifierInfo *II = nullptr;
SourceLocation StartLoc = Tok.getLocation();
int Idx = -1;
if (Tok.isAnyIdentifier()) {
II = Tok.getIdentifierInfo();
Idx = MatchACMClause(II->getName());
}
ConsumeAnyToken();

bool NextIsLPar = Tok.is(tok::l_paren);
// Handle unknown clauses by skipping them.
if (Idx == -1) {
Diag(StartLoc, diag::warn_omp_unknown_assumption_clause_missing_id)
<< llvm::omp::getOpenMPDirectiveName(DKind)
<< llvm::omp::getAllAssumeClauseOptions() << NextIsLPar;
if (NextIsLPar)
SkipBraces(II ? II->getName() : "", /* IssueNote */ true);
SkippedClauses = true;
continue;
}
const AssumptionClauseMappingInfo &ACMI = AssumptionClauseMappings[Idx];
if (ACMI.HasDirectiveList || ACMI.HasExpression) {
// TODO: We ignore absent, contains, and holds assumptions for now. We
// also do not verify the content in the parenthesis at all.
SkippedClauses = true;
SkipBraces(II->getName(), /* IssueNote */ false);
continue;
}

if (NextIsLPar) {
Diag(Tok.getLocation(),
diag::warn_omp_unknown_assumption_clause_without_args)
<< II;
SkipBraces(II->getName(), /* IssueNote */ true);
}

assert(II && "Expected an identifier clause!");
StringRef Assumption = II->getName();
if (ACMI.StartsWith)
Assumption = Assumption.substr(ACMI.Identifier.size());
Assumptions.push_back(Assumption);
}

Actions.ActOnOpenMPAssumesDirective(Loc, DKind, Assumptions, SkippedClauses);
}

void Parser::ParseOpenMPEndAssumesDirective(SourceLocation Loc) {
if (Actions.isInOpenMPAssumeScope())
Actions.ActOnOpenMPEndAssumesDirective();
else
Diag(Loc, diag::err_expected_begin_assumes);
}

/// Parsing of simple OpenMP clauses like 'default' or 'proc_bind'.
///
/// default-clause:
Expand Down Expand Up @@ -1716,6 +1816,14 @@ void Parser::ParseOMPEndDeclareTargetDirective(OpenMPDirectiveKind DKind,
/// annot_pragma_openmp 'requires' <clause> [[[,] <clause>] ... ]
/// annot_pragma_openmp_end
///
/// assumes directive:
/// annot_pragma_openmp 'assumes' <clause> [[[,] <clause>] ... ]
/// annot_pragma_openmp_end
/// or
/// annot_pragma_openmp 'begin assumes' <clause> [[[,] <clause>] ... ]
/// annot_pragma_openmp 'end assumes'
/// annot_pragma_openmp_end
///
Parser::DeclGroupPtrTy Parser::ParseOpenMPDeclarativeDirectiveWithExtDecl(
AccessSpecifier &AS, ParsedAttributesWithRange &Attrs, bool Delayed,
DeclSpec::TST TagType, Decl *Tag) {
Expand Down Expand Up @@ -1853,6 +1961,13 @@ Parser::DeclGroupPtrTy Parser::ParseOpenMPDeclarativeDirectiveWithExtDecl(
ConsumeAnnotationToken();
return Actions.ActOnOpenMPRequiresDirective(StartLoc, Clauses);
}
case OMPD_assumes:
case OMPD_begin_assumes:
ParseOpenMPAssumesDirective(DKind, ConsumeToken());
break;
case OMPD_end_assumes:
ParseOpenMPEndAssumesDirective(ConsumeToken());
break;
case OMPD_declare_reduction:
ConsumeToken();
if (DeclGroupPtrTy Res = ParseOpenMPDeclareReductionDirective(AS)) {
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/Sema/SemaDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10843,6 +10843,9 @@ bool Sema::CheckFunctionDeclaration(Scope *S, FunctionDecl *NewFD,
}
}

if (LangOpts.OpenMP)
ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(NewFD);

// Semantic checking for this function declaration (in isolation).

if (getLangOpts().CPlusPlus) {
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/Sema/SemaLambda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,10 @@ void Sema::ActOnStartOfLambdaDefinition(LambdaIntroducer &Intro,
if (getLangOpts().CUDA)
CUDASetLambdaAttrs(Method);

// OpenMP lambdas might get assumumption attributes.
if (LangOpts.OpenMP)
ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(Method);

// Number the lambda for linkage purposes if necessary.
handleLambdaNumbering(Class, Method);

Expand Down
80 changes: 80 additions & 0 deletions clang/lib/Sema/SemaOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "llvm/ADT/IndexedMap.h"
#include "llvm/ADT/PointerEmbeddedInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include <set>

Expand Down Expand Up @@ -3195,6 +3196,64 @@ Sema::ActOnOpenMPRequiresDirective(SourceLocation Loc,
return DeclGroupPtrTy::make(DeclGroupRef(D));
}

void Sema::ActOnOpenMPAssumesDirective(SourceLocation Loc,
OpenMPDirectiveKind DKind,
ArrayRef<StringRef> Assumptions,
bool SkippedClauses) {
if (!SkippedClauses && Assumptions.empty())
Diag(Loc, diag::err_omp_no_clause_for_directive)
<< llvm::omp::getAllAssumeClauseOptions()
<< llvm::omp::getOpenMPDirectiveName(DKind);

auto *AA = AssumptionAttr::Create(Context, llvm::join(Assumptions, ","), Loc);
if (DKind == llvm::omp::Directive::OMPD_begin_assumes) {
OMPAssumeScoped.push_back(AA);
return;
}

// Global assumes without assumption clauses are ignored.
if (Assumptions.empty())
return;

assert(DKind == llvm::omp::Directive::OMPD_assumes &&
"Unexpected omp assumption directive!");
OMPAssumeGlobal.push_back(AA);

// The OMPAssumeGlobal scope above will take care of new declarations but
// we also want to apply the assumption to existing ones, e.g., to
// declarations in included headers. To this end, we traverse all existing
// declaration contexts and annotate function declarations here.
SmallVector<DeclContext *, 8> DeclContexts;
auto *Ctx = CurContext;
while (Ctx->getLexicalParent())
Ctx = Ctx->getLexicalParent();
DeclContexts.push_back(Ctx);
while (!DeclContexts.empty()) {
DeclContext *DC = DeclContexts.pop_back_val();
for (auto *SubDC : DC->decls()) {
if (SubDC->isInvalidDecl())
continue;
if (auto *CTD = dyn_cast<ClassTemplateDecl>(SubDC)) {
DeclContexts.push_back(CTD->getTemplatedDecl());
for (auto *S : CTD->specializations())
DeclContexts.push_back(S);
continue;
}
if (auto *DC = dyn_cast<DeclContext>(SubDC))
DeclContexts.push_back(DC);
if (auto *F = dyn_cast<FunctionDecl>(SubDC)) {
F->addAttr(AA);
continue;
}
}
}
}

void Sema::ActOnOpenMPEndAssumesDirective() {
assert(isInOpenMPAssumeScope() && "Not in OpenMP assumes scope!");
OMPAssumeScoped.pop_back();
}

OMPRequiresDecl *Sema::CheckOMPRequiresDecl(SourceLocation Loc,
ArrayRef<OMPClause *> ClauseList) {
/// For target specific clauses, the requires directive cannot be
Expand Down Expand Up @@ -5936,6 +5995,27 @@ static void setPrototype(Sema &S, FunctionDecl *FD, FunctionDecl *FDWithProto,
FD->setParams(Params);
}

void Sema::ActOnFinishedFunctionDefinitionInOpenMPAssumeScope(Decl *D) {
if (D->isInvalidDecl())
return;
FunctionDecl *FD = nullptr;
if (auto *UTemplDecl = dyn_cast<FunctionTemplateDecl>(D))
FD = UTemplDecl->getTemplatedDecl();
else
FD = cast<FunctionDecl>(D);
assert(FD && "Expected a function declaration!");

// If we are intantiating templates we do *not* apply scoped assumptions but
// only global ones. We apply scoped assumption to the template definition
// though.
if (!inTemplateInstantiation()) {
for (AssumptionAttr *AA : OMPAssumeScoped)
FD->addAttr(AA);
}
for (AssumptionAttr *AA : OMPAssumeGlobal)
FD->addAttr(AA);
}

Sema::OMPDeclareVariantScope::OMPDeclareVariantScope(OMPTraitInfo &TI)
: TI(&TI), NameSuffix(TI.getMangledName()) {}

Expand Down
Loading

0 comments on commit 2e6e4e6

Please sign in to comment.