// ops.mlir — round-trip parse/print tests for the MLIR shape dialect.
// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s
// Verify the printed output can be parsed.
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// Verify the generic form can be parsed.
// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s
// CHECK-LABEL: shape_num_elements
// Reduces over the dimensions of %shape, accumulating each extent with
// shape.add starting from const_size 0.
// NOTE(review): add-with-0-init sums the extents; a true element count would
// multiply — this test only exercises reduce's parse/print, so confirm
// whether the name or the accumulator op is the intended one.
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
%init = shape.const_size 0
%num_elements = shape.reduce(%shape, %init) -> !shape.size {
// Region signature: (current index, extent at that index, loop-carried acc).
^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
%acc = shape.add %lci, %dim
// yield feeds %acc back as the next iteration's %lci.
shape.yield %acc : !shape.size
}
return %num_elements : !shape.size
}
// Calls @shape_num_elements on a fully unknown shape; uses the generic form
// "shape.unknown_shape" / "shape.print" deliberately so the generic-printer
// RUN line exercises them.
func @test_shape_num_elements_unknown() {
%0 = "shape.unknown_shape"() : () -> !shape.shape
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
%2 = "shape.print"(%1) : (!shape.size) -> !shape.size
return
}
// Same as the _unknown variant but with a fully static constant shape.
func @test_shape_num_elements_fixed() {
%0 = shape.const_shape [1, 57, 92]
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
return
}
// Round-trips "shape.broadcastable" (generic form) on two static shapes of
// different rank. NOTE(review): operands [10,1,57,92] and [4,57,92] are not
// numpy-broadcast-compatible (1 vs 4 in the second-to-last dim is fine, but
// 10 vs implicit 1 broadcasts; 4 vs 1 also broadcasts) — this test only
// checks parsing, not broadcast legality.
func @test_broadcastable_fixed() {
%0 = shape.const_shape [10, 1, 57, 92]
%1 = shape.const_shape [4, 57, 92]
%2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
// Round-trips "shape.join" (generic form) on two identical static shapes.
func @test_shape_any_fixed() {
%0 = shape.const_shape [4, 57, 92]
%1 = shape.const_shape [4, 57, 92]
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
// Round-trips "shape.join" where each operand has a dynamic (-1) extent in a
// different position; the extents are complementary, so a join could refine
// both — this test only checks parse/print.
func @test_shape_any_unknown() {
%0 = shape.const_shape [4, -1, 92]
%1 = shape.const_shape [-1, 57, 92]
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
// Round-trips "shape.join" on statically mismatched shapes (4 vs 2 in dim 0);
// still expected to parse — mismatch handling is a runtime/folding concern.
func @test_shape_any_fixed_mismatch() {
%0 = shape.const_shape [4, 57, 92]
%1 = shape.const_shape [2, 57, 92]
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
// Exercises shape.const_shape's custom parser for both the empty (rank-0)
// extent list and a non-empty one.
func @test_parse_const_shape() {
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
return
}
// Round-trips shape.shape_of, which takes a ranked tensor value and yields
// its shape as a !shape.shape.
func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
%0 = shape.shape_of %arg0 : tensor<?xf32>
return %0 : !shape.shape
}
// Exercises the witness/constraint ops: cstr_broadcastable, cstr_eq,
// const_witness (both truth values), assuming_all to conjoin witnesses, and
// an assuming region guarded by the combined witness.
func @test_constraints() {
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
%w0 = shape.cstr_broadcastable %0, %1
%w1 = shape.cstr_eq %0, %1
%w2 = shape.const_witness true
%w3 = shape.const_witness false
// assuming_all folds the four witnesses into one.
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
// The region's body only executes under the assumption that %w4 holds;
// its result is produced via assuming_yield.
shape.assuming %w4 -> !shape.shape {
%2 = shape.any %0, %1
shape.assuming_yield %2 : !shape.shape
}
return
}
// Round-trips shape.mul on two !shape.size operands.
func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size {
%product = shape.mul %lhs, %rhs
return %product: !shape.size
}
// Checks that shape.const_size results get value-based SSA names (%c1, %c2)
// and that a duplicate constant is uniquified with a numeric suffix (%c2_0).
// The CHECK lines below pin the printer's naming; do not rename these values.
func @const_size() {
// CHECK: %c1 = shape.const_size 1
// CHECK: %c2 = shape.const_size 2
// CHECK: %c2_0 = shape.const_size 2
%0 = shape.const_size 1
%1 = shape.const_size 2
%2 = shape.const_size 2
return
}
// Round-trips shape.to_extent_tensor: !shape.shape -> static extent tensor.
func @test_to_extent_tensor(%arg: !shape.shape) -> tensor<3xindex> {
%0 = shape.to_extent_tensor %arg : tensor<3xindex>
return %0 : tensor<3xindex>
}
// Round-trips shape.from_extent_tensor: dynamic extent tensor -> !shape.shape.
func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
%0 = shape.from_extent_tensor %arg : tensor<?xindex>
return %0 : !shape.shape
}
// Round-trips shape.rank, yielding a shape's rank as a !shape.size.
func @rank(%shape : !shape.shape) -> !shape.size {
%rank = shape.rank %shape
return %rank : !shape.size
}