Skip to content

Commit 9b63bdd

Browse files
authored
[mlir] Improve mlir-query tool by implementing getBackwardSlice and getForwardSlice matchers (#115670)
Improve mlir-query tool by implementing `getBackwardSlice` and `getForwardSlice` matchers. As an addition `SetQuery` also needed to be added to enable custom configuration for each query. e.g: `inclusive`, `omitUsesFromAbove`, `omitBlockArguments`. Note: backwardSlice and forwardSlice algoritms are the same as the ones in `mlir/lib/Analysis/SliceAnalysis.cpp` Example of current matcher. The query was made to the file: `mlir/test/mlir-query/complex-test.mlir` ```mlir ./mlir-query /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir -c "match getDefinitions(hasOpName(\"arith.add f\"),2)" Match #1: /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:5:8: %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) { ^ /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:7:10: note: "root" binds here %2 = arith.addf %in, %in : f32 ^ Match #2: /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:10:16: %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32> ^ /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:13:11: %c2 = arith.constant 2 : index ^ /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:14:18: %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32> ^ /home/dbudii/personal/llvm-project/mlir/test/mlir-query/complex-test.mlir:15:10: note: "root" binds here %2 = arith.addf %extracted, %extracted : f32 ^ 2 matches. ```
1 parent e01bdc1 commit 9b63bdd

File tree

14 files changed

+493
-57
lines changed

14 files changed

+493
-57
lines changed

mlir/include/mlir/Query/Matcher/Marshallers.h

+30
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,36 @@ struct ArgTypeTraits<llvm::StringRef> {
5050
}
5151
};
5252

53+
template <>
54+
struct ArgTypeTraits<int64_t> {
55+
static bool hasCorrectType(const VariantValue &value) {
56+
return value.isSigned();
57+
}
58+
59+
static unsigned get(const VariantValue &value) { return value.getSigned(); }
60+
61+
static ArgKind getKind() { return ArgKind::Signed; }
62+
63+
static std::optional<std::string> getBestGuess(const VariantValue &) {
64+
return std::nullopt;
65+
}
66+
};
67+
68+
template <>
69+
struct ArgTypeTraits<bool> {
70+
static bool hasCorrectType(const VariantValue &value) {
71+
return value.isBoolean();
72+
}
73+
74+
static unsigned get(const VariantValue &value) { return value.getBoolean(); }
75+
76+
static ArgKind getKind() { return ArgKind::Boolean; }
77+
78+
static std::optional<std::string> getBestGuess(const VariantValue &) {
79+
return std::nullopt;
80+
}
81+
};
82+
5383
template <>
5484
struct ArgTypeTraits<DynMatcher> {
5585

mlir/include/mlir/Query/Matcher/MatchFinder.h

+33-15
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,51 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains the MatchFinder class, which is used to find operations
10-
// that match a given matcher.
10+
// that match a given matcher and print them.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

1414
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1515
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1616

1717
#include "MatchersInternal.h"
18+
#include "mlir/Query/Query.h"
19+
#include "mlir/Query/QuerySession.h"
20+
#include "llvm/ADT/SetVector.h"
1821

1922
namespace mlir::query::matcher {
2023

21-
// MatchFinder is used to find all operations that match a given matcher.
24+
/// A class that provides utilities to find operations in the IR.
2225
class MatchFinder {
26+
2327
public:
24-
// Returns all operations that match the given matcher.
25-
static std::vector<Operation *> getMatches(Operation *root,
26-
DynMatcher matcher) {
27-
std::vector<Operation *> matches;
28-
29-
// Simple match finding with walk.
30-
root->walk([&](Operation *subOp) {
31-
if (matcher.match(subOp))
32-
matches.push_back(subOp);
33-
});
34-
35-
return matches;
36-
}
28+
/// A subclass which preserves the matching information. Each instance
29+
/// contains the `rootOp` along with the matching environment.
30+
struct MatchResult {
31+
MatchResult() = default;
32+
MatchResult(Operation *rootOp, std::vector<Operation *> matchedOps);
33+
34+
Operation *rootOp = nullptr;
35+
/// Contains the matching environment.
36+
std::vector<Operation *> matchedOps;
37+
};
38+
39+
/// Traverses the IR and returns a vector of `MatchResult` for each match of
40+
/// the `matcher`.
41+
std::vector<MatchResult> collectMatches(Operation *root,
42+
DynMatcher matcher) const;
43+
44+
/// Prints the matched operation.
45+
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op) const;
46+
47+
/// Labels the matched operation with the given binding (e.g., `"root"`) and
48+
/// prints it.
49+
void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op,
50+
const std::string &binding) const;
51+
52+
/// Flattens a vector of `MatchResult` into a vector of operations.
53+
std::vector<Operation *>
54+
flattenMatchedOps(std::vector<MatchResult> &matches) const;
3755
};
3856

3957
} // namespace mlir::query::matcher

mlir/include/mlir/Query/Matcher/MatchersInternal.h

+49-10
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
//
99
// Implements the base layer of the matcher framework.
1010
//
11-
// Matchers are methods that return a Matcher which provides a method
12-
// match(Operation *op)
11+
// Matchers are methods that return a Matcher which provides a method one of the
12+
// following methods: match(Operation *op), match(Operation *op,
13+
// SetVector<Operation *> &matchedOps)
1314
//
1415
// The matcher functions are defined in include/mlir/IR/Matchers.h.
1516
// This file contains the wrapper classes needed to construct matchers for
@@ -25,13 +26,39 @@
2526

2627
namespace mlir::query::matcher {
2728

29+
// Defaults to false if T has no match() method with the signature:
30+
// match(Operation* op).
31+
template <typename T, typename = void>
32+
struct has_simple_match : std::false_type {};
33+
34+
// Specialized type trait that evaluates to true if T has a match() method
35+
// with the signature: match(Operation* op).
36+
template <typename T>
37+
struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
38+
std::declval<Operation *>()))>>
39+
: std::true_type {};
40+
41+
// Defaults to false if T has no match() method with the signature:
42+
// match(Operation* op, SetVector<Operation*>&).
43+
template <typename T, typename = void>
44+
struct has_bound_match : std::false_type {};
45+
46+
// Specialized type trait that evaluates to true if T has a match() method
47+
// with the signature: match(Operation* op, SetVector<Operation*>&).
48+
template <typename T>
49+
struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
50+
std::declval<Operation *>(),
51+
std::declval<SetVector<Operation *> &>()))>>
52+
: std::true_type {};
53+
2854
// Generic interface for matchers on an MLIR operation.
2955
class MatcherInterface
3056
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
3157
public:
3258
virtual ~MatcherInterface() = default;
3359

3460
virtual bool match(Operation *op) = 0;
61+
virtual bool match(Operation *op, SetVector<Operation *> &matchedOps) = 0;
3562
};
3663

3764
// MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +67,25 @@ template <typename MatcherFn>
4067
class MatcherFnImpl : public MatcherInterface {
4168
public:
4269
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
43-
bool match(Operation *op) override { return matcherFn.match(op); }
70+
71+
bool match(Operation *op) override {
72+
if constexpr (has_simple_match<MatcherFn>::value)
73+
return matcherFn.match(op);
74+
return false;
75+
}
76+
77+
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
78+
if constexpr (has_bound_match<MatcherFn>::value)
79+
return matcherFn.match(op, matchedOps);
80+
return false;
81+
}
4482

4583
private:
4684
MatcherFn matcherFn;
4785
};
4886

49-
// Matcher wraps a MatcherInterface implementation and provides a match()
50-
// method that redirects calls to the underlying implementation.
87+
// Matcher wraps a MatcherInterface implementation and provides match()
88+
// methods that redirect calls to the underlying implementation.
5189
class DynMatcher {
5290
public:
5391
// Takes ownership of the provided implementation pointer.
@@ -62,12 +100,13 @@ class DynMatcher {
62100
}
63101

64102
bool match(Operation *op) const { return implementation->match(op); }
103+
bool match(Operation *op, SetVector<Operation *> &matchedOps) const {
104+
return implementation->match(op, matchedOps);
105+
}
65106

66-
void setFunctionName(StringRef name) { functionName = name.str(); };
67-
68-
bool hasFunctionName() const { return !functionName.empty(); };
69-
70-
StringRef getFunctionName() const { return functionName; };
107+
void setFunctionName(StringRef name) { functionName = name.str(); }
108+
bool hasFunctionName() const { return !functionName.empty(); }
109+
StringRef getFunctionName() const { return functionName; }
71110

72111
private:
73112
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
//===- SliceMatchers.h - Matchers for slicing analysis ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides matchers for MLIRQuery that peform slicing analysis
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
14+
#define MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
15+
16+
#include "mlir/Analysis/SliceAnalysis.h"
17+
18+
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
19+
/// Additionally, it limits the slice computation to a certain depth level using
20+
/// a custom filter.
21+
///
22+
/// Example: starting from node 9, assuming the matcher
23+
/// computes the slice for the first two depth levels:
24+
/// ============================
25+
/// 1 2 3 4
26+
/// |_______| |______|
27+
/// | | |
28+
/// | 5 6
29+
/// |___|_____________|
30+
/// | |
31+
/// 7 8
32+
/// |_______________|
33+
/// |
34+
/// 9
35+
///
36+
/// Assuming all local orders match the numbering order:
37+
/// {5, 7, 6, 8, 9}
38+
namespace mlir::query::matcher {
39+
40+
template <typename Matcher>
41+
class BackwardSliceMatcher {
42+
public:
43+
BackwardSliceMatcher(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
44+
bool omitBlockArguments, bool omitUsesFromAbove)
45+
: innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth),
46+
inclusive(inclusive), omitBlockArguments(omitBlockArguments),
47+
omitUsesFromAbove(omitUsesFromAbove) {}
48+
49+
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
50+
BackwardSliceOptions options;
51+
options.inclusive = inclusive;
52+
options.omitUsesFromAbove = omitUsesFromAbove;
53+
options.omitBlockArguments = omitBlockArguments;
54+
return (innerMatcher.match(rootOp) &&
55+
matches(rootOp, backwardSlice, options, maxDepth));
56+
}
57+
58+
private:
59+
bool matches(Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
60+
BackwardSliceOptions &options, int64_t maxDepth);
61+
62+
private:
63+
// The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
64+
// to determine whether we want to traverse the IR or not. For example, we
65+
// want to explore the IR only if the top-level operation name is
66+
// `"arith.addf"`.
67+
Matcher innerMatcher;
68+
// `maxDepth` specifies the maximum depth that the matcher can traverse the
69+
// IR. For example, if `maxDepth` is 2, the matcher will explore the defining
70+
// operations of the top-level op up to 2 levels.
71+
int64_t maxDepth;
72+
bool inclusive;
73+
bool omitBlockArguments;
74+
bool omitUsesFromAbove;
75+
};
76+
77+
template <typename Matcher>
78+
bool BackwardSliceMatcher<Matcher>::matches(
79+
Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
80+
BackwardSliceOptions &options, int64_t maxDepth) {
81+
backwardSlice.clear();
82+
llvm::DenseMap<Operation *, int64_t> opDepths;
83+
// Initializing the root op with a depth of 0
84+
opDepths[rootOp] = 0;
85+
options.filter = [&](Operation *subOp) {
86+
// If the subOp hasn't been recorded in opDepths, it is deeper than
87+
// maxDepth.
88+
if (!opDepths.contains(subOp))
89+
return false;
90+
// Examine subOp's operands to compute depths of their defining operations.
91+
for (auto operand : subOp->getOperands()) {
92+
int64_t newDepth = opDepths[subOp] + 1;
93+
// If the newDepth is greater than maxDepth, further computation can be
94+
// skipped.
95+
if (newDepth > maxDepth)
96+
continue;
97+
98+
if (auto definingOp = operand.getDefiningOp()) {
99+
// Registers the minimum depth
100+
if (!opDepths.contains(definingOp) || newDepth < opDepths[definingOp])
101+
opDepths[definingOp] = newDepth;
102+
} else {
103+
auto blockArgument = cast<BlockArgument>(operand);
104+
Operation *parentOp = blockArgument.getOwner()->getParentOp();
105+
if (!parentOp)
106+
continue;
107+
108+
if (!opDepths.contains(parentOp) || newDepth < opDepths[parentOp])
109+
opDepths[parentOp] = newDepth;
110+
}
111+
}
112+
return true;
113+
};
114+
getBackwardSlice(rootOp, &backwardSlice, options);
115+
return options.inclusive ? backwardSlice.size() > 1
116+
: backwardSlice.size() >= 1;
117+
}
118+
119+
/// Matches transitive defs of a top-level operation up to N levels.
120+
template <typename Matcher>
121+
inline BackwardSliceMatcher<Matcher>
122+
m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
123+
bool omitBlockArguments, bool omitUsesFromAbove) {
124+
assert(maxDepth >= 0 && "maxDepth must be non-negative");
125+
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth,
126+
inclusive, omitBlockArguments,
127+
omitUsesFromAbove);
128+
}
129+
130+
/// Matches all transitive defs of a top-level operation up to N levels
131+
template <typename Matcher>
132+
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
133+
int64_t maxDepth) {
134+
assert(maxDepth >= 0 && "maxDepth must be non-negative");
135+
return BackwardSliceMatcher<Matcher>(std::move(innerMatcher), maxDepth, true,
136+
false, false);
137+
}
138+
139+
} // namespace mlir::query::matcher
140+
141+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

0 commit comments

Comments
 (0)