SDBMTest.cpp
15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SDBM/SDBM.h"
#include "mlir/Dialect/SDBM/SDBMDialect.h"
#include "mlir/Dialect/SDBM/SDBMExpr.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"
#include "llvm/ADT/DenseSet.h"
using namespace mlir;
/// Returns this thread's MLIRContext, shared by every test in the file.
/// The SDBM dialect is requested through getOrLoadDialect on every call.
static MLIRContext *ctx() {
  // Constructed once per thread on first use; the `false` constructor
  // argument is forwarded unchanged (see MLIRContext for its meaning).
  static thread_local MLIRContext context(false);
  MLIRContext *current = &context;
  current->getOrLoadDialect<SDBMDialect>();
  return current;
}
/// Returns the SDBM dialect loaded into ctx(), caching the pointer per
/// thread. While the cached pointer is null, the lookup is retried on each
/// call.
static SDBMDialect *dialect() {
  static thread_local SDBMDialect *cached = nullptr;
  if (cached == nullptr)
    cached = ctx()->getOrLoadDialect<SDBMDialect>();
  return cached;
}
/// Shorthand: the SDBM dimension expression at position `pos`.
static SDBMExpr dim(unsigned pos) {
  return static_cast<SDBMExpr>(SDBMDimExpr::get(dialect(), pos));
}
/// Shorthand: the SDBM symbol expression at position `pos`.
static SDBMExpr symb(unsigned pos) {
  return static_cast<SDBMExpr>(SDBMSymbolExpr::get(dialect(), pos));
}
namespace {
using namespace mlir::ops_assertions;
TEST(SDBMOperators, Add) {
  // Adding an integer to a dimension yields a sum of that dimension and the
  // constant.
  auto added = dim(0) + 42;
  auto asSum = added.dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(asSum);
  EXPECT_EQ(asSum.getLHS(), dim(0));
  EXPECT_EQ(asSum.getRHS().getValue(), 42);
}
TEST(SDBMOperators, AddFolding) {
  // Constant + constant folds into a single constant.
  auto constant = SDBMConstantExpr::get(dialect(), 2) + 42;
  auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(constantExpr);
  EXPECT_EQ(constantExpr.getValue(), 44);

  // Successive constant offsets on a variable are merged.
  auto expr = (dim(0) + 10) + 32;
  auto sumExpr = expr.dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(sumExpr);
  EXPECT_EQ(sumExpr.getRHS().getValue(), 42);

  // Adding a negated variable produces a difference expression.
  expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1));
  auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
  ASSERT_TRUE(diffExpr);
  EXPECT_EQ(diffExpr.getLHS(), dim(0));
  EXPECT_EQ(diffExpr.getRHS(), dim(1));
  // The same difference is obtained with the operands swapped.
  auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
  EXPECT_EQ(inverted, expr);

  // Check that opposite values cancel each other, and that we elide the zero
  // constant.
  expr = dim(0) + 42;
  auto onlyDim = expr - 42;
  EXPECT_EQ(onlyDim, dim(0));

  // Check that we can sink a constant under a negation:
  // -(d0 + 2) + 10 == -(d0 - 8).
  expr = -(dim(0) + 2);
  auto negatedSum = (expr + 10).dyn_cast<SDBMNegExpr>();
  ASSERT_TRUE(negatedSum);
  auto sum = negatedSum.getVar().dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(sum);
  EXPECT_EQ(sum.getRHS().getValue(), -8);

  // Sum with zero is the same as the original expression.
  EXPECT_EQ(dim(0) + 0, dim(0));

  // Sum of opposite differences is zero. Guard the cast before dereferencing
  // so that a missing fold is reported as a failure instead of crashing.
  auto diffOfDiffs =
      ((dim(0) - dim(1)) + (dim(1) - dim(0))).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(diffOfDiffs);
  EXPECT_EQ(diffOfDiffs.getValue(), 0);
}
TEST(SDBMOperators, AddNegativeTerms) {
  // Offsets of opposite signs, so that both folding directions are exercised.
  const int64_t kPos = 7;
  const int64_t kNeg = -5;
  auto d0 = SDBMDimExpr::get(dialect(), 0);
  auto d1 = SDBMDimExpr::get(dialect(), 1);

  // In each addition below one variable cancels out, and the result must
  // remain an SDBM expression.
  EXPECT_EQ(-(d0 + kPos) + ((d0 + kNeg) - d1), -(d1 + (kPos - kNeg)));
  EXPECT_EQ((d0 + kPos) + ((d1 + kNeg) - d0), (d1 + kNeg) + kPos);
  EXPECT_EQ(((d0 + kPos) - d1) + (-(d0 + kNeg)), -(d1 + (kNeg - kPos)));
  EXPECT_EQ(((d0 + kPos) - d1) + (d1 + kNeg), (d0 + kPos) + kNeg);
}
TEST(SDBMOperators, Diff) {
  // Subtracting two distinct dimensions yields a difference expression that
  // keeps both operands in order.
  auto difference = (dim(0) - dim(1)).dyn_cast<SDBMDiffExpr>();
  ASSERT_TRUE(difference);
  EXPECT_EQ(difference.getLHS(), dim(0));
  EXPECT_EQ(difference.getRHS(), dim(1));
}
TEST(SDBMOperators, DiffFolding) {
  // Constant - constant folds into a single constant.
  auto folded = SDBMConstantExpr::get(dialect(), 10) - 3;
  auto foldedCst = folded.dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(foldedCst);
  EXPECT_EQ(foldedCst.getValue(), 7);

  // Variable - constant becomes a sum with a negative offset.
  auto shifted = (dim(0) - 3).dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(shifted);
  EXPECT_EQ(shifted.getRHS().getValue(), -3);

  // A variable minus itself is the zero constant.
  auto selfDiff = (dim(0) - dim(0)).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(selfDiff);
  EXPECT_EQ(selfDiff.getValue(), 0);

  // Check that the constant terms in difference-of-sums are folded.
  // (d0 - 3) - (d1 - 5) = (d0 + 2) - d1
  auto diffOfSums = ((dim(0) - 3) - (dim(1) - 5)).dyn_cast<SDBMDiffExpr>();
  ASSERT_TRUE(diffOfSums);
  auto foldedLHS = diffOfSums.getLHS().dyn_cast<SDBMSumExpr>();
  ASSERT_TRUE(foldedLHS);
  EXPECT_EQ(foldedLHS.getLHS(), dim(0));
  EXPECT_EQ(foldedLHS.getRHS().getValue(), 2);
  EXPECT_EQ(diffOfSums.getRHS(), dim(1));

  // Check that identical dimensions with opposite signs cancel each other.
  auto offsetOnly = ((dim(0) + 42) - dim(0)).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(offsetOnly);
  EXPECT_EQ(offsetOnly.getValue(), 42);

  // Check that identical terms in sum of diffs cancel out.
  EXPECT_EQ(-dim(0) + (dim(0) - dim(1)), -dim(1));
  EXPECT_EQ((dim(0) - dim(1)) + (-dim(0)), -dim(1));
  EXPECT_EQ((dim(0) - dim(1)) + dim(1), dim(0));
  EXPECT_EQ(dim(0) + (dim(1) - dim(0)), dim(1));

  // A zero constant at the top level is an acceptable result.
  auto zeroSum = (-symb(1) + symb(1)).dyn_cast<SDBMConstantExpr>();
  ASSERT_TRUE(zeroSum);
  EXPECT_EQ(zeroSum.getValue(), 0);
}
TEST(SDBMOperators, Negate) {
  // Negating a sum wraps the unchanged sum in a negation expression.
  auto operand = dim(0) + 3;
  auto negation = (-operand).dyn_cast<SDBMNegExpr>();
  ASSERT_TRUE(negation);
  EXPECT_EQ(negation.getVar(), operand);
}
TEST(SDBMOperators, Stripe) {
  // The stripe() helper builds a stripe expression with the given factor.
  auto striped = stripe(dim(0), 3).dyn_cast<SDBMStripeExpr>();
  ASSERT_TRUE(striped);
  EXPECT_EQ(striped.getLHS(), dim(0));
  EXPECT_EQ(striped.getStripeFactor().getValue(), 3);
}
TEST(SDBM, RoundTripEqs) {
  // Build an SDBM defined by
  //
  //   d0 = s0 # 3 # 5
  //   s0 # 3 # 5 - d1 + 42 = 0
  //
  // and perform a double round-trip between the "list of equalities" and SDBM
  // representation. After the first round-trip, the equalities may be
  // different due to simplification or equivalent substitutions (e.g., the
  // second equality may become d0 - d1 + 42 = 0). However, there should not
  // be any further simplification after the second round-trip.

  // Build the SDBM from a pair of equalities and extract back the lists of
  // inequalities and equalities. Check that all equalities are properly
  // detected and none of them decayed into inequalities.
  auto s = stripe(stripe(symb(0), 3), 5);
  auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42});
  SmallVector<SDBMExpr, 4> eqs, ineqs;
  sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
  ASSERT_TRUE(ineqs.empty());

  // Do the second round-trip. The equalities obtained above must again be
  // recovered purely as equalities.
  auto sdbm2 = SDBM::get(llvm::None, eqs);
  SmallVector<SDBMExpr, 4> eqs2, ineqs2;
  sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2);
  EXPECT_TRUE(ineqs2.empty());
  ASSERT_EQ(eqs.size(), eqs2.size());

  // Check that the sets of equalities are equal, their order is not relevant.
  llvm::DenseSet<SDBMExpr> eqSet, eq2Set;
  eqSet.insert(eqs.begin(), eqs.end());
  eq2Set.insert(eqs2.begin(), eqs2.end());
  EXPECT_EQ(eqSet, eq2Set);
}
TEST(SDBMExpr, Constant) {
  // Constants can be created and their value read back.
  auto first = SDBMConstantExpr::get(dialect(), 42);
  EXPECT_EQ(first.getValue(), 42);

  // Constants are uniqued: creating the same value twice gives equal exprs.
  auto second = SDBMConstantExpr::get(dialect(), 42);
  EXPECT_EQ(first, second);

  // Viewed as a generic SDBMExpr, it is still recognized as a constant.
  auto asBase = static_cast<SDBMExpr>(first);
  EXPECT_TRUE(asBase.isa<SDBMConstantExpr>());
}
TEST(SDBMExpr, Dim) {
  // Dimension expressions can be created and their position read back.
  auto d0 = SDBMDimExpr::get(dialect(), 0);
  EXPECT_EQ(d0.getPosition(), 0u);

  // Dimensions with the same position are uniqued and compare equal.
  auto d0Again = SDBMDimExpr::get(dialect(), 0);
  EXPECT_EQ(d0, d0Again);

  // A dimension belongs to every level of the expected class hierarchy.
  auto asBase = static_cast<SDBMExpr>(d0);
  EXPECT_TRUE(asBase.isa<SDBMDimExpr>());
  EXPECT_TRUE(asBase.isa<SDBMInputExpr>());
  EXPECT_TRUE(asBase.isa<SDBMTermExpr>());
  EXPECT_TRUE(asBase.isa<SDBMDirectExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());

  // A dimension is distinct from the symbol at the same position.
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);
  EXPECT_NE(d0, s0);
  EXPECT_FALSE(d0.isa<SDBMSymbolExpr>());
}
TEST(SDBMExpr, Symbol) {
  // Symbol expressions can be created and their position read back.
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);
  EXPECT_EQ(s0.getPosition(), 0u);

  // Symbols with the same position are uniqued and compare equal.
  auto s0Again = SDBMSymbolExpr::get(dialect(), 0);
  EXPECT_EQ(s0, s0Again);

  // A symbol belongs to every level of the expected class hierarchy.
  auto asBase = static_cast<SDBMExpr>(s0);
  EXPECT_TRUE(asBase.isa<SDBMSymbolExpr>());
  EXPECT_TRUE(asBase.isa<SDBMInputExpr>());
  EXPECT_TRUE(asBase.isa<SDBMTermExpr>());
  EXPECT_TRUE(asBase.isa<SDBMDirectExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());

  // Symbols are not dimensions: the symbol differs from the dimension at the
  // same position.
  auto d0 = SDBMDimExpr::get(dialect(), 0);
  EXPECT_NE(s0, d0);
  EXPECT_FALSE(s0.isa<SDBMDimExpr>());
}
TEST(SDBMExpr, Stripe) {
  auto two = SDBMConstantExpr::get(dialect(), 2);
  auto zero = SDBMConstantExpr::get(dialect(), 0);
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);

  // Stripe expressions can be created and their operands read back.
  auto s0Stripe2 = SDBMStripeExpr::get(s0, two);
  EXPECT_EQ(s0Stripe2.getLHS(), s0);
  EXPECT_EQ(s0Stripe2.getStripeFactor(), two);

  // Building the same stripe again (from a fresh symbol) yields an equal expr.
  auto rebuilt = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), two);
  EXPECT_EQ(s0Stripe2, rebuilt);

  // A stripe may itself be striped.
  SDBMStripeExpr::get(s0Stripe2, SDBMConstantExpr::get(dialect(), 4));

  // A zero (non-positive) stripe factor is rejected.
  EXPECT_DEATH(SDBMStripeExpr::get(s0, zero), "non-positive");

  // The striped operand may be a sum.
  SDBMStripeExpr::get(SDBMSumExpr::get(s0, two), two);

  // A stripe belongs to every level of the expected class hierarchy.
  auto asBase = static_cast<SDBMExpr>(s0Stripe2);
  EXPECT_TRUE(asBase.isa<SDBMStripeExpr>());
  EXPECT_TRUE(asBase.isa<SDBMTermExpr>());
  EXPECT_TRUE(asBase.isa<SDBMDirectExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Neg) {
  auto two = SDBMConstantExpr::get(dialect(), 2);
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);
  auto s0Stripe2 = SDBMStripeExpr::get(s0, two);

  // Negations of a symbol and of a stripe keep their operand.
  auto negSymbol = SDBMNegExpr::get(s0);
  EXPECT_EQ(negSymbol.getVar(), s0);
  auto negStripe = SDBMNegExpr::get(s0Stripe2);
  EXPECT_EQ(negStripe.getVar(), s0Stripe2);

  // Negations are uniqued, so rebuilding one compares equal.
  EXPECT_EQ(negSymbol, SDBMNegExpr::get(s0));

  // A negation is a varying expression.
  auto asBase = static_cast<SDBMExpr>(negSymbol);
  EXPECT_TRUE(asBase.isa<SDBMNegExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Sum) {
  auto two = SDBMConstantExpr::get(dialect(), 2);
  auto s0 = SDBMSymbolExpr::get(dialect(), 0);
  auto s0Stripe2 = SDBMStripeExpr::get(s0, two);

  // Sums over a symbol and over a stripe keep both operands.
  auto symbolPlusTwo = SDBMSumExpr::get(s0, two);
  EXPECT_EQ(symbolPlusTwo.getLHS(), s0);
  EXPECT_EQ(symbolPlusTwo.getRHS(), two);
  auto stripePlusTwo = SDBMSumExpr::get(s0Stripe2, two);
  EXPECT_EQ(stripePlusTwo.getLHS(), s0Stripe2);
  EXPECT_EQ(stripePlusTwo.getRHS(), two);

  // Sums are uniqued, so rebuilding one compares equal.
  EXPECT_EQ(symbolPlusTwo, SDBMSumExpr::get(s0, two));

  // A sum is a direct, varying expression.
  auto asBase = static_cast<SDBMExpr>(symbolPlusTwo);
  EXPECT_TRUE(asBase.isa<SDBMSumExpr>());
  EXPECT_TRUE(asBase.isa<SDBMDirectExpr>());
  EXPECT_TRUE(asBase.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, Diff) {
  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
  auto var = SDBMSymbolExpr::get(dialect(), 0);
  auto stripe = SDBMStripeExpr::get(var, cst2);

  // We can create difference expressions and query them, in either operand
  // order.
  auto expr = SDBMDiffExpr::get(var, stripe);
  EXPECT_EQ(expr.getLHS(), var);
  EXPECT_EQ(expr.getRHS(), stripe);
  auto expr2 = SDBMDiffExpr::get(stripe, var);
  EXPECT_EQ(expr2.getLHS(), stripe);
  EXPECT_EQ(expr2.getRHS(), var);

  // Difference expressions are trivially comparable.
  EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));

  // Hierarchy is okay: a difference is a varying expression.
  auto generic = static_cast<SDBMExpr>(expr);
  EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
TEST(SDBMExpr, AffineRoundTrip) {
  // Build an expression (s0 - s0 # 2).
  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
  auto var = SDBMSymbolExpr::get(dialect(), 0);
  auto stripe = SDBMStripeExpr::get(var, cst2);
  auto expr = SDBMDiffExpr::get(var, stripe);

  // Check that it can be converted to AffineExpr and back, i.e. stripe
  // detection works correctly.
  Optional<SDBMExpr> roundtripped =
      SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(expr));

  // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
  // detection supports nested expressions.
  auto cst5 = SDBMConstantExpr::get(dialect(), 5);
  auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
  roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));

  // Check that ((s0 + 2) # 5) can be round-tripped through AffineExpr, i.e.
  // stripe detection supports sum expressions.
  auto inner = SDBMSumExpr::get(var, cst2);
  auto stripeSum = SDBMStripeExpr::get(inner, cst5);
  roundtripped = SDBMExpr::tryConvertAffineExpr(stripeSum.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(stripeSum));

  // Check that (s0 # 2 # 5 + 2) - s0 # 2 can be converted as an example of a
  // deeper expression tree.
  auto sum = SDBMSumExpr::get(outerStripe, cst2);
  auto diff = SDBMDiffExpr::get(sum, stripe);
  roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));

  // Check a nested stripe-sum combination:
  // ((s0 + 2) # 5 + 2) # 7 - s0 # 2.
  auto cst7 = SDBMConstantExpr::get(dialect(), 7);
  auto nestedStripe =
      SDBMStripeExpr::get(SDBMSumExpr::get(stripeSum, cst2), cst7);
  diff = SDBMDiffExpr::get(nestedStripe, stripe);
  roundtripped = SDBMExpr::tryConvertAffineExpr(diff.getAsAffineExpr());
  ASSERT_TRUE(roundtripped.hasValue());
  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(diff));
}
TEST(SDBMExpr, MatchStripeMulPattern) {
  // Make sure conversion from AffineExpr recognizes the multiplicative stripe
  // pattern (x floordiv B) * B == x # B. Here it is built as B * (x floordiv B)
  // with x = d0 and B = 42.
  auto cst = getAffineConstantExpr(42, ctx());
  // Named `d0` rather than `dim` to avoid shadowing the file-level dim() helper
  // that is used below.
  auto d0 = getAffineDimExpr(0, ctx());
  auto floor = d0.floorDiv(cst);
  auto mul = cst * floor;
  Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
  ASSERT_TRUE(converted.hasValue());
  // Check not only the kind of the result but also its operands.
  auto stripeExpr = converted->dyn_cast<SDBMStripeExpr>();
  ASSERT_TRUE(stripeExpr);
  EXPECT_EQ(stripeExpr.getLHS(), dim(0));
  EXPECT_EQ(stripeExpr.getStripeFactor().getValue(), 42);
}
TEST(SDBMExpr, NonSDBM) {
  auto d0 = getAffineDimExpr(0, ctx());
  auto d1 = getAffineDimExpr(1, ctx());
  auto two = getAffineConstantExpr(2, ctx());

  // Each affine expression below lies outside the SDBM fragment, so the
  // conversion must fail.

  // A sum of two variables.
  auto twoVars = d0 + d1;
  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(twoVars).hasValue());

  // A variable whose coefficient is neither 1 nor -1.
  auto scaled = d0 * two;
  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(scaled).hasValue());

  // A ceildiv expression.
  auto ceiled = d1.ceilDiv(two);
  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceiled).hasValue());
}
} // end namespace