Skip to content

[MLIR][OpenMP] Make omp.wsloop into a loop wrapper (1/5) #89209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,9 @@ using TeamsClauseOps =
PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;

using WsloopClauseOps =
detail::Clauses<AllocateClauseOps, CollapseClauseOps, LinearClauseOps,
LoopRelatedOps, NowaitClauseOps, OrderClauseOps,
OrderedClauseOps, PrivateClauseOps, ReductionClauseOps,
ScheduleClauseOps>;
detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
ReductionClauseOps, ScheduleClauseOps>;

} // namespace omp
} // namespace mlir
Expand Down
52 changes: 21 additions & 31 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -600,29 +600,29 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
//===----------------------------------------------------------------------===//

def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
RecursiveMemoryEffects, ReductionClauseInterface]> {
RecursiveMemoryEffects, ReductionClauseInterface,
SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "worksharing-loop construct";
let description = [{
The worksharing-loop construct specifies that the iterations of the loop(s)
will be executed in parallel by threads in the current context. These
iterations are spread across threads that already exist in the enclosing
parallel region. The lower and upper bounds specify a half-open range: the
range includes the lower bound but does not include the upper bound. If the
`inclusive` attribute is specified then the upper bound is also included.
parallel region.

The body region can contain any number of blocks. The region is terminated
by "omp.yield" instruction without operands.
The body region must contain a single block, which in turn must hold exactly
one operation followed by a terminator. That operation must be either another
compatible loop wrapper or an `omp.loop_nest`.

```
omp.wsloop <clauses>
for (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
%a = load %arrA[%i1, %i2] : memref<?x?xf32>
%b = load %arrB[%i1, %i2] : memref<?x?xf32>
%sum = arith.addf %a, %b : f32
store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
omp.yield
omp.wsloop <clauses> {
omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
%a = load %arrA[%i1, %i2] : memref<?x?xf32>
%b = load %arrB[%i1, %i2] : memref<?x?xf32>
%sum = arith.addf %a, %b : f32
store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
omp.yield
}
}
```

Expand Down Expand Up @@ -665,10 +665,7 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
passed by reference.
}];

let arguments = (ins Variadic<IntLikeType>:$lowerBound,
Variadic<IntLikeType>:$upperBound,
Variadic<IntLikeType>:$step,
Variadic<AnyType>:$linear_vars,
let arguments = (ins Variadic<AnyType>:$linear_vars,
Variadic<I32>:$linear_step_vars,
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
OptionalAttr<SymbolRefArrayAttr>:$reductions,
Expand All @@ -679,22 +676,16 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
UnitAttr:$nowait,
UnitAttr:$byref,
ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$ordered_val,
OptionalAttr<OrderKindAttr>:$order_val,
UnitAttr:$inclusive);
OptionalAttr<OrderKindAttr>:$order_val);

let builders = [
OpBuilder<(ins "ValueRange":$lowerBound, "ValueRange":$upperBound,
"ValueRange":$step,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
OpBuilder<(ins CArg<"const WsloopClauseOps &">:$clauses)>
];

let regions = (region AnyRegion:$region);

let extraClassDeclaration = [{
/// Returns the number of loops in the worksharing-loop nest.
unsigned getNumLoops() { return getLowerBound().size(); }

/// Returns the number of reduction variables.
unsigned getNumReductionVars() { return getReductionVars().size(); }
}];
Expand All @@ -711,9 +702,8 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
|`byref` $byref
|`ordered` `(` $ordered_val `)`
|`order` `(` custom<ClauseAttr>($order_val) `)`
) custom<Wsloop>($region, $lowerBound, $upperBound, $step, type($step),
$reduction_vars, type($reduction_vars), $reductions,
$inclusive) attr-dict
) custom<Wsloop>($region, $reduction_vars, type($reduction_vars),
$reductions) attr-dict
}];
let hasVerifier = 1;
}
Expand Down Expand Up @@ -805,8 +795,8 @@ def SimdOp : OpenMP_Op<"simd", [AttrSizedOperandSegments,

def YieldOp : OpenMP_Op<"yield",
[Pure, ReturnLike, Terminator,
ParentOneOf<["LoopNestOp", "WsloopOp", "DeclareReductionOp",
"AtomicUpdateOp", "PrivateClauseOp"]>]> {
ParentOneOf<["AtomicUpdateOp", "DeclareReductionOp", "LoopNestOp",
"PrivateClauseOp"]>]> {
let summary = "loop yield and termination operation";
let description = [{
"omp.yield" yields SSA values from the OpenMP dialect op region and
Expand Down
150 changes: 50 additions & 100 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1468,86 +1468,72 @@ LogicalResult SingleOp::verify() {
// WsloopOp
//===----------------------------------------------------------------------===//

/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
/// steps := `step` `(`ssa-id-list`)`
ParseResult
parseWsloop(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
SmallVectorImpl<Type> &loopVarTypes,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionOperands,
SmallVectorImpl<Type> &reductionTypes, ArrayAttr &reductionSymbols,
UnitAttr &inclusive) {

SmallVectorImpl<Type> &reductionTypes,
ArrayAttr &reductionSymbols) {
// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
bool hasReduction = succeeded(parser.parseOptionalKeyword("reduction")) &&
succeeded(parseClauseWithRegionArgs(
parser, region, reductionOperands, reductionTypes,
reductionSymbols, privates));

if (parser.parseKeyword("for"))
return failure();

// Parse an opening `(` followed by induction variables followed by `)`
SmallVector<OpAsmParser::Argument> ivs;
Type loopVarType;
if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
parser.parseColonType(loopVarType) ||
// Parse loop bounds.
parser.parseEqual() ||
parser.parseOperandList(lowerBound, ivs.size(),
OpAsmParser::Delimiter::Paren) ||
parser.parseKeyword("to") ||
parser.parseOperandList(upperBound, ivs.size(),
OpAsmParser::Delimiter::Paren))
return failure();

if (succeeded(parser.parseOptionalKeyword("inclusive")))
inclusive = UnitAttr::get(parser.getBuilder().getContext());

// Parse step values.
if (parser.parseKeyword("step") ||
parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
return failure();

// Now parse the body.
loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
for (auto &iv : ivs)
iv.type = loopVarType;

SmallVector<OpAsmParser::Argument> regionArgs{ivs};
if (hasReduction)
llvm::copy(privates, std::back_inserter(regionArgs));

return parser.parseRegion(region, regionArgs);
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands,
reductionTypes, reductionSymbols,
privates)))
return failure();
}
return parser.parseRegion(region, privates);
}

void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange lowerBound, ValueRange upperBound, ValueRange steps,
TypeRange loopVarTypes, ValueRange reductionOperands,
TypeRange reductionTypes, ArrayAttr reductionSymbols,
UnitAttr inclusive) {
ValueRange reductionOperands, TypeRange reductionTypes,
ArrayAttr reductionSymbols) {
if (reductionSymbols) {
auto reductionArgs =
region.front().getArguments().drop_front(loopVarTypes.size());
auto reductionArgs = region.front().getArguments();
printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
reductionOperands, reductionTypes,
reductionSymbols);
}

p << " for ";
auto args = region.front().getArguments().drop_back(reductionOperands.size());
p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
<< ") to (" << upperBound << ") ";
if (inclusive)
p << "inclusive ";
p << "step (" << steps << ") ";
p.printRegion(region, /*printEntryBlockArgs=*/false);
}

/// Builds an `omp.wsloop` with every clause operand empty and every clause
/// attribute unset, then adds the caller-provided `attributes` to `state`.
///
/// \param builder    Builder forwarded to the delegated `build` overload.
/// \param state      Operation state being populated.
/// \param attributes Extra named attributes appended to `state`.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
// Delegate to another `build` overload, passing empty value ranges, null
// attributes and `false` flags for all clauses.
build(builder, state, /*linear_vars=*/ValueRange(),
/*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
/*reductions=*/nullptr, /*schedule_val=*/nullptr,
/*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
/*simd_modifier=*/false, /*nowait=*/false, /*byref=*/false,
/*ordered_val=*/nullptr, /*order_val=*/nullptr);
// The caller's attributes are added after the delegated build has run.
state.addAttributes(attributes);
}

/// Builds an `omp.wsloop` from a `WsloopClauseOps` structure by forwarding its
/// linear, reduction, schedule, nowait, ordered and order fields to another
/// `build` overload.
///
/// \param builder Builder whose context is used to create the reduction
///                symbol array attribute.
/// \param state   Operation state being populated.
/// \param clauses Clause operands and attributes to copy into the operation.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
const WsloopClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: allocateVars, allocatorVars, privateVars,
// privatizers.
// `reductionDeclSymbols` is passed through `makeArrayAttr` before being
// forwarded; the remaining clause fields are forwarded unchanged.
// NOTE(review): `reductionByRefAttr` is passed in the same position that the
// attributes-only builder labels `byref` -- confirm against the generated
// builder signature.
WsloopOp::build(
builder, state, clauses.linearVars, clauses.linearStepVars,
clauses.reductionVars, makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.scheduleValAttr, clauses.scheduleChunkVar,
clauses.scheduleModAttr, clauses.scheduleSimdAttr, clauses.nowaitAttr,
clauses.reductionByRefAttr, clauses.orderedAttr, clauses.orderAttr);
}

/// Verifies an `omp.wsloop` operation.
///
/// Checks are performed in this order:
///   1. the op must report itself as a loop wrapper;
///   2. if it directly nests another loop wrapper, that wrapper must be an
///      `omp.simd`;
///   3. the reduction symbols and reduction variables must pass
///      `verifyReductionVarList`.
LogicalResult WsloopOp::verify() {
  const bool wrapsLoop = isWrapper();
  if (!wrapsLoop)
    return emitOpError() << "must be a loop wrapper";

  // Only the leaf constructs allowed directly after DO/FOR in a composite
  // construct may be nested here.
  LoopWrapperInterface innerWrapper = getNestedWrapper();
  if (innerWrapper && !isa<SimdOp>(innerWrapper))
    return emitError() << "only supported nested wrapper is 'omp.simd'";

  return verifyReductionVarList(*this, getReductions(), getReductionVars());
}

//===----------------------------------------------------------------------===//
// Simd construct [2.9.3.1]
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1928,42 +1914,6 @@ void LoopNestOp::gatherWrappers(
}
}

//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//

/// Builds an `omp.wsloop` from loop bounds and steps only: all clause operands
/// are empty, all clause attributes are unset and `inclusive` is false. The
/// caller-provided `attributes` are then added to `state`.
///
/// \param builder    Builder forwarded to the delegated `build` overload.
/// \param state      Operation state being populated.
/// \param lowerBound Lower-bound values forwarded to the delegated builder.
/// \param upperBound Upper-bound values forwarded to the delegated builder.
/// \param step       Step values forwarded to the delegated builder.
/// \param attributes Extra named attributes appended to `state`.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
ValueRange lowerBound, ValueRange upperBound,
ValueRange step, ArrayRef<NamedAttribute> attributes) {
// Delegate to another `build` overload, leaving every clause empty/unset.
build(builder, state, lowerBound, upperBound, step,
/*linear_vars=*/ValueRange(),
/*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
/*reductions=*/nullptr, /*schedule_val=*/nullptr,
/*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
/*simd_modifier=*/false, /*nowait=*/false, /*byref=*/false,
/*ordered_val=*/nullptr,
/*order_val=*/nullptr, /*inclusive=*/false);
// The caller's attributes are added after the delegated build has run.
state.addAttributes(attributes);
}

/// Builds an `omp.wsloop` from a `WsloopClauseOps` structure, forwarding its
/// loop bound/step/inclusive fields together with its linear, reduction,
/// schedule, nowait, ordered and order fields to another `build` overload.
///
/// \param builder Builder whose context is used to create the reduction
///                symbol array attribute.
/// \param state   Operation state being populated.
/// \param clauses Clause operands and attributes to copy into the operation.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
const WsloopClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: allocateVars, allocatorVars, privateVars,
// privatizers.
// `reductionDeclSymbols` is passed through `makeArrayAttr` before being
// forwarded; the remaining clause fields are forwarded unchanged.
WsloopOp::build(
builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
clauses.linearVars, clauses.linearStepVars, clauses.reductionVars,
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.scheduleValAttr,
clauses.scheduleChunkVar, clauses.scheduleModAttr,
clauses.scheduleSimdAttr, clauses.nowaitAttr, clauses.reductionByRefAttr,
clauses.orderedAttr, clauses.orderAttr, clauses.loopInclusiveAttr);
}

/// Verifies an `omp.wsloop` by checking its reduction symbols against its
/// reduction variables via `verifyReductionVarList`.
LogicalResult WsloopOp::verify() {
  auto reductionSymbols = getReductions();
  auto reductionOperands = getReductionVars();
  return verifyReductionVarList(*this, reductionSymbols, reductionOperands);
}

//===----------------------------------------------------------------------===//
// Critical construct (2.17.1)
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 5 additions & 2 deletions mlir/test/CAPI/execution_engine.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,11 @@ void testOmpCreation(void) {
" %1 = arith.constant 1 : i32 \n"
" %2 = arith.constant 2 : i32 \n"
" omp.parallel { \n"
" omp.wsloop for (%3) : i32 = (%0) to (%2) step (%1) { \n"
" omp.yield \n"
" omp.wsloop { \n"
" omp.loop_nest (%3) : i32 = (%0) to (%2) step (%1) { \n"
" omp.yield \n"
" } \n"
" omp.terminator \n"
" } \n"
" omp.terminator \n"
" } \n"
Expand Down
64 changes: 36 additions & 28 deletions mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,18 @@ func.func @branch_loop() {
func.func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) {
// CHECK: omp.parallel
omp.parallel {
// CHECK: omp.wsloop for (%[[ARG6:.*]], %[[ARG7:.*]]) : i64 = (%[[ARG0]], %[[ARG1]]) to (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[ARG5]]) {
"omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ({
^bb0(%arg6: index, %arg7: index):
// CHECK-DAG: %[[CAST_ARG6:.*]] = builtin.unrealized_conversion_cast %[[ARG6]] : i64 to index
// CHECK-DAG: %[[CAST_ARG7:.*]] = builtin.unrealized_conversion_cast %[[ARG7]] : i64 to index
// CHECK: "test.payload"(%[[CAST_ARG6]], %[[CAST_ARG7]]) : (index, index) -> ()
"test.payload"(%arg6, %arg7) : (index, index) -> ()
omp.yield
}) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0>} : (index, index, index, index, index, index) -> ()
// CHECK: omp.wsloop {
"omp.wsloop"() ({
// CHECK: omp.loop_nest (%[[ARG6:.*]], %[[ARG7:.*]]) : i64 = (%[[ARG0]], %[[ARG1]]) to (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[ARG5]]) {
omp.loop_nest (%arg6, %arg7) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
// CHECK-DAG: %[[CAST_ARG6:.*]] = builtin.unrealized_conversion_cast %[[ARG6]] : i64 to index
// CHECK-DAG: %[[CAST_ARG7:.*]] = builtin.unrealized_conversion_cast %[[ARG7]] : i64 to index
// CHECK: "test.payload"(%[[CAST_ARG6]], %[[CAST_ARG7]]) : (index, index) -> ()
"test.payload"(%arg6, %arg7) : (index, index) -> ()
omp.yield
}
omp.terminator
}) : () -> ()
omp.terminator
}
return
Expand Down Expand Up @@ -323,12 +326,14 @@ llvm.func @_QPsb() {
// CHECK-LABEL: @_QPsimple_reduction
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %{{.*}} x i32 {bindc_name = "x", uniq_name = "_QFsimple_reductionEx"} : (i64) -> !llvm.ptr
// CHECK: omp.parallel
// CHECK: omp.wsloop reduction(@eqv_reduction %{{.+}} -> %[[PRV:.+]] : !llvm.ptr) for
// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
// CHECK: %[[CMP:.+]] = llvm.icmp "eq" %{{.*}}, %[[LPRV]] : i32
// CHECK: %[[ZEXT:.+]] = llvm.zext %[[CMP]] : i1 to i32
// CHECK: llvm.store %[[ZEXT]], %[[PRV]] : i32, !llvm.ptr
// CHECK: omp.yield
// CHECK: omp.wsloop reduction(@eqv_reduction %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
// CHECK-NEXT: omp.loop_nest {{.*}}{
// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
// CHECK: %[[CMP:.+]] = llvm.icmp "eq" %{{.*}}, %[[LPRV]] : i32
// CHECK: %[[ZEXT:.+]] = llvm.zext %[[CMP]] : i1 to i32
// CHECK: llvm.store %[[ZEXT]], %[[PRV]] : i32, !llvm.ptr
// CHECK: omp.yield
// CHECK: omp.terminator
// CHECK: omp.terminator
// CHECK: llvm.return

Expand All @@ -354,20 +359,23 @@ llvm.func @_QPsimple_reduction(%arg0: !llvm.ptr {fir.bindc_name = "y"}) {
%4 = llvm.alloca %3 x i32 {bindc_name = "x", uniq_name = "_QFsimple_reductionEx"} : (i64) -> !llvm.ptr
%5 = llvm.zext %2 : i1 to i32
llvm.store %5, %4 : i32, !llvm.ptr
omp.parallel {
omp.parallel {
%6 = llvm.alloca %3 x i32 {adapt.valuebyref, in_type = i32, operandSegmentSizes = array<i32: 0, 0>, pinned} : (i64) -> !llvm.ptr
omp.wsloop reduction(@eqv_reduction %4 -> %prv : !llvm.ptr) for (%arg1) : i32 = (%1) to (%0) inclusive step (%1) {
llvm.store %arg1, %6 : i32, !llvm.ptr
%7 = llvm.load %6 : !llvm.ptr -> i32
%8 = llvm.sext %7 : i32 to i64
%9 = llvm.sub %8, %3 : i64
%10 = llvm.getelementptr %arg0[0, %9] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<100 x i32>
%11 = llvm.load %10 : !llvm.ptr -> i32
%12 = llvm.load %prv : !llvm.ptr -> i32
%13 = llvm.icmp "eq" %11, %12 : i32
%14 = llvm.zext %13 : i1 to i32
llvm.store %14, %prv : i32, !llvm.ptr
omp.yield
omp.wsloop reduction(@eqv_reduction %4 -> %prv : !llvm.ptr) {
omp.loop_nest (%arg1) : i32 = (%1) to (%0) inclusive step (%1) {
llvm.store %arg1, %6 : i32, !llvm.ptr
%7 = llvm.load %6 : !llvm.ptr -> i32
%8 = llvm.sext %7 : i32 to i64
%9 = llvm.sub %8, %3 : i64
%10 = llvm.getelementptr %arg0[0, %9] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<100 x i32>
%11 = llvm.load %10 : !llvm.ptr -> i32
%12 = llvm.load %prv : !llvm.ptr -> i32
%13 = llvm.icmp "eq" %11, %12 : i32
%14 = llvm.zext %13 : i1 to i32
llvm.store %14, %prv : i32, !llvm.ptr
omp.yield
}
omp.terminator
}
omp.terminator
}
Expand Down
Loading
Loading