Skip to content

Commit 82c6eee

Browse files
authored
[MLIR] Add a second map for registered OperationName in MLIRContext (NFC) (#87170)
This speeds up registered op creation by 10-11% by allowing lookup by TypeID instead of StringRef. This can break your build/tests at runtime with an error that you're creating an unregistered operation that you have registered. If so you are likely using a class inheriting from the "real" operation. See for example in this patch the case of: class ConstantIndexOp : public arith::ConstantOp { If one is using `builder.create<ConstantIndexOp>()` they actually create an `arith.constant` operation, but the builder will fetch the TypeID for the `ConstantIndexOp` class which does not correspond to any registered operation. To fix it the `ConstantIndexOp` class got this addition: static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
1 parent 75f7d53 commit 82c6eee

File tree

6 files changed

+36
-15
lines changed

6 files changed

+36
-15
lines changed

mlir/include/mlir/Dialect/Arith/IR/Arith.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ namespace arith {
5353
class ConstantIntOp : public arith::ConstantOp {
5454
public:
5555
using arith::ConstantOp::ConstantOp;
56+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
5657

5758
/// Build a constant int op that produces an integer of the specified width.
5859
static void build(OpBuilder &builder, OperationState &result, int64_t value,
@@ -74,6 +75,7 @@ class ConstantIntOp : public arith::ConstantOp {
7475
class ConstantFloatOp : public arith::ConstantOp {
7576
public:
7677
using arith::ConstantOp::ConstantOp;
78+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
7779

7880
/// Build a constant float op that produces a float of the specified type.
7981
static void build(OpBuilder &builder, OperationState &result,
@@ -90,7 +92,7 @@ class ConstantFloatOp : public arith::ConstantOp {
9092
class ConstantIndexOp : public arith::ConstantOp {
9193
public:
9294
using arith::ConstantOp::ConstantOp;
93-
95+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
9496
/// Build a constant int op that produces an index.
9597
static void build(OpBuilder &builder, OperationState &result, int64_t value);
9698

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -252,21 +252,21 @@ class TransformDialectExtension
252252

253253
template <typename OpTy>
254254
void TransformDialect::addOperationIfNotRegistered() {
255-
StringRef name = OpTy::getOperationName();
256255
std::optional<RegisteredOperationName> opName =
257-
RegisteredOperationName::lookup(name, getContext());
256+
RegisteredOperationName::lookup(TypeID::get<OpTy>(), getContext());
258257
if (!opName) {
259258
addOperations<OpTy>();
260259
#ifndef NDEBUG
260+
StringRef name = OpTy::getOperationName();
261261
detail::checkImplementsTransformOpInterface(name, getContext());
262262
#endif // NDEBUG
263263
return;
264264
}
265265

266-
if (opName->getTypeID() == TypeID::get<OpTy>())
266+
if (LLVM_LIKELY(opName->getTypeID() == TypeID::get<OpTy>()))
267267
return;
268268

269-
reportDuplicateOpRegistration(name);
269+
reportDuplicateOpRegistration(OpTy::getOperationName());
270270
}
271271

272272
template <typename Type>

mlir/include/mlir/IR/Builders.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ class OpBuilder : public Builder {
490490
template <typename OpT>
491491
RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
492492
std::optional<RegisteredOperationName> opName =
493-
RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
493+
RegisteredOperationName::lookup(TypeID::get<OpT>(), ctx);
494494
if (LLVM_UNLIKELY(!opName)) {
495495
llvm::report_fatal_error(
496496
"Building op `" + OpT::getOperationName() +

mlir/include/mlir/IR/OpDefinition.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -1729,8 +1729,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
17291729
template <typename... Models>
17301730
static void attachInterface(MLIRContext &context) {
17311731
std::optional<RegisteredOperationName> info =
1732-
RegisteredOperationName::lookup(ConcreteType::getOperationName(),
1733-
&context);
1732+
RegisteredOperationName::lookup(TypeID::get<ConcreteType>(), &context);
17341733
if (!info)
17351734
llvm::report_fatal_error(
17361735
"Attempting to attach an interface to an unregistered operation " +

mlir/include/mlir/IR/OperationSupport.h

+5
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,11 @@ class RegisteredOperationName : public OperationName {
676676
static std::optional<RegisteredOperationName> lookup(StringRef name,
677677
MLIRContext *ctx);
678678

679+
/// Lookup the registered operation information for the given operation.
680+
/// Returns std::nullopt if the operation isn't registered.
681+
static std::optional<RegisteredOperationName> lookup(TypeID typeID,
682+
MLIRContext *ctx);
683+
679684
/// Register a new operation in a Dialect object.
680685
/// This constructor is used by Dialect objects when they register the list
681686
/// of operations they contain.

mlir/lib/IR/MLIRContext.cpp

+22-7
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ class MLIRContextImpl {
183183
llvm::StringMap<std::unique_ptr<OperationName::Impl>> operations;
184184

185185
/// A vector of operation info specifically for registered operations.
186-
llvm::StringMap<RegisteredOperationName> registeredOperations;
186+
llvm::DenseMap<TypeID, RegisteredOperationName> registeredOperations;
187+
llvm::StringMap<RegisteredOperationName> registeredOperationsByName;
187188

188189
/// This is a sorted container of registered operations for a deterministic
189190
/// and efficient `getRegisteredOperations` implementation.
@@ -780,8 +781,8 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
780781
// Check the registered info map first. In the overwhelmingly common case,
781782
// the entry will be in here and it also removes the need to acquire any
782783
// locks.
783-
auto registeredIt = ctxImpl.registeredOperations.find(name);
784-
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperations.end())) {
784+
auto registeredIt = ctxImpl.registeredOperationsByName.find(name);
785+
if (LLVM_LIKELY(registeredIt != ctxImpl.registeredOperationsByName.end())) {
785786
impl = registeredIt->second.impl;
786787
return;
787788
}
@@ -909,10 +910,19 @@ OperationName::UnregisteredOpModel::hashProperties(OpaqueProperties prop) {
909910
//===----------------------------------------------------------------------===//
910911

911912
std::optional<RegisteredOperationName>
912-
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
913+
RegisteredOperationName::lookup(TypeID typeID, MLIRContext *ctx) {
913914
auto &impl = ctx->getImpl();
914-
auto it = impl.registeredOperations.find(name);
915+
auto it = impl.registeredOperations.find(typeID);
915916
if (it != impl.registeredOperations.end())
917+
return it->second;
918+
return std::nullopt;
919+
}
920+
921+
std::optional<RegisteredOperationName>
922+
RegisteredOperationName::lookup(StringRef name, MLIRContext *ctx) {
923+
auto &impl = ctx->getImpl();
924+
auto it = impl.registeredOperationsByName.find(name);
925+
if (it != impl.registeredOperationsByName.end())
916926
return it->getValue();
917927
return std::nullopt;
918928
}
@@ -945,11 +955,16 @@ void RegisteredOperationName::insert(
945955

946956
// Update the registered info for this operation.
947957
auto emplaced = ctxImpl.registeredOperations.try_emplace(
948-
name, RegisteredOperationName(impl));
958+
impl->getTypeID(), RegisteredOperationName(impl));
949959
assert(emplaced.second && "operation name registration must be successful");
960+
auto emplacedByName = ctxImpl.registeredOperationsByName.try_emplace(
961+
name, RegisteredOperationName(impl));
962+
(void)emplacedByName;
963+
assert(emplacedByName.second &&
964+
"operation name registration must be successful");
950965

951966
// Add emplaced operation name to the sorted operations container.
952-
RegisteredOperationName &value = emplaced.first->getValue();
967+
RegisteredOperationName &value = emplaced.first->second;
953968
ctxImpl.sortedRegisteredOperations.insert(
954969
llvm::upper_bound(ctxImpl.sortedRegisteredOperations, value,
955970
[](auto &lhs, auto &rhs) {

0 commit comments

Comments
 (0)