LayoutUtils.cpp
6.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
//===-- LayoutUtils.cpp - Decorate composite type with layout information -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to compute alignment and layout
// information for types in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/LayoutUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
using namespace mlir;
/// Returns a copy of `structType` whose members carry Vulkan layout offsets.
/// The size and alignment computed for the struct itself are not reported by
/// this overload; they are computed into locals and discarded.
spirv::StructType
VulkanLayoutUtils::decorateType(spirv::StructType structType) {
  Size discardedSize = 0;
  Size discardedAlignment = 1;
  return decorateType(structType, discardedSize, discardedAlignment);
}
/// Rebuilds `structType` with an Offset for every member, following the Vulkan
/// layout rules, and reports the struct's total size and alignment.
///
/// \param structType the struct to decorate.
/// \param size       out: the struct size rounded up to its alignment, or
///                   `std::numeric_limits<Size>::max()` when the last member
///                   has an unbounded size (a runtime array).
/// \param alignment  out: the largest alignment of any member.
/// \returns the decorated struct type; an empty struct is returned unchanged
///          and `size`/`alignment` are left as the caller initialized them.
spirv::StructType
VulkanLayoutUtils::decorateType(spirv::StructType structType,
                                VulkanLayoutUtils::Size &size,
                                VulkanLayoutUtils::Size &alignment) {
  if (structType.getNumElements() == 0) {
    return structType;
  }

  // Size reported for members whose size is not statically bounded (runtime
  // arrays, and structs ending with one).
  const Size unboundedSize = std::numeric_limits<Size>::max();

  SmallVector<Type, 4> memberTypes;
  SmallVector<spirv::StructType::OffsetInfo, 4> offsetInfo;
  SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
  Size structMemberOffset = 0;
  Size maxMemberAlignment = 1;
  bool hasUnboundedMember = false;
  for (uint32_t i = 0, e = structType.getNumElements(); i < e; ++i) {
    Size memberSize = 0;
    Size memberAlignment = 1;
    auto memberType =
        decorateType(structType.getElementType(i), memberSize, memberAlignment);
    structMemberOffset = llvm::alignTo(structMemberOffset, memberAlignment);
    memberTypes.push_back(memberType);
    offsetInfo.push_back(
        static_cast<spirv::StructType::OffsetInfo>(structMemberOffset));
    // If the member's size is unbounded, it must be the last member and it
    // must be a runtime array (or a struct that itself ends with one).
    assert(memberSize != unboundedSize ||
           (i + 1 == e &&
            (structType.getElementType(i).isa<spirv::RuntimeArrayType>() ||
             structType.getElementType(i).isa<spirv::StructType>())));
    // Adding the unbounded sentinel to a non-zero offset would wrap around and
    // yield a bogus small size, so remember it instead of accumulating it.
    if (memberSize == unboundedSize)
      hasUnboundedMember = true;
    else
      structMemberOffset += memberSize;
    // According to the Vulkan spec:
    // "A structure has a base alignment equal to the largest base alignment of
    // any of its members."
    maxMemberAlignment = std::max(maxMemberAlignment, memberAlignment);
  }
  // According to the Vulkan spec:
  // "The Offset decoration of a member must not place it between the end of a
  // structure or an array and the next multiple of the alignment of that
  // structure or array."
  if (hasUnboundedMember)
    size = unboundedSize;
  else
    size = llvm::alignTo(structMemberOffset, maxMemberAlignment);
  alignment = maxMemberAlignment;
  structType.getMemberDecorations(memberDecorations);
  return spirv::StructType::get(memberTypes, offsetInfo, memberDecorations);
}
/// Dispatches `type` to the layout routine for its concrete kind and reports
/// the resulting size and alignment through the out-parameters.
Type VulkanLayoutUtils::decorateType(Type type, VulkanLayoutUtils::Size &size,
                                     VulkanLayoutUtils::Size &alignment) {
  // Scalars need no decoration. The Vulkan spec does not specify any padding
  // for a scalar type, so its size equals its alignment.
  if (type.isa<spirv::ScalarType>()) {
    const Size scalarAlignment = getScalarTypeAlignment(type);
    alignment = scalarAlignment;
    size = scalarAlignment;
    return type;
  }

  // Composite kinds are rebuilt by their dedicated overloads.
  if (auto asStruct = type.dyn_cast<spirv::StructType>()) {
    return decorateType(asStruct, size, alignment);
  }
  if (auto asArray = type.dyn_cast<spirv::ArrayType>()) {
    return decorateType(asArray, size, alignment);
  }
  if (auto asVector = type.dyn_cast<VectorType>()) {
    return decorateType(asVector, size, alignment);
  }
  if (auto asRuntimeArray = type.dyn_cast<spirv::RuntimeArrayType>()) {
    // A runtime array has no static size; report the maximum value as its size.
    size = std::numeric_limits<Size>::max();
    return decorateType(asRuntimeArray, alignment);
  }

  llvm_unreachable("unhandled SPIR-V type");
}
/// Decorates the component type of `vectorType` and derives the vector's size
/// (component size times component count) and alignment.
Type VulkanLayoutUtils::decorateType(VectorType vectorType,
                                     VulkanLayoutUtils::Size &size,
                                     VulkanLayoutUtils::Size &alignment) {
  const auto componentCount = vectorType.getNumElements();

  Size componentSize = 0;
  Size componentAlignment = 1;
  Type decoratedComponent = decorateType(vectorType.getElementType(),
                                         componentSize, componentAlignment);

  // Vulkan spec:
  //  - "A two-component vector has a base alignment equal to twice its scalar
  //    alignment."
  //  - "A three- or four-component vector has a base alignment equal to four
  //    times its scalar alignment."
  const Size alignmentFactor = (componentCount == 2) ? 2 : 4;
  size = componentSize * componentCount;
  alignment = componentAlignment * alignmentFactor;
  return VectorType::get(componentCount, decoratedComponent);
}
/// Decorates the element type of `arrayType` and rebuilds the array with an
/// ArrayStride that satisfies the Vulkan layout rules.
///
/// \param size      out: stride times element count.
/// \param alignment out: the element's alignment.
Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
                                     VulkanLayoutUtils::Size &size,
                                     VulkanLayoutUtils::Size &alignment) {
  const auto numElements = arrayType.getNumElements();
  auto elementType = arrayType.getElementType();
  Size elementSize = 0;
  Size elementAlignment = 1;
  auto memberType = decorateType(elementType, elementSize, elementAlignment);
  // The Vulkan spec requires ArrayStride to be a multiple of the array's
  // alignment. An element's size is not always a multiple of its alignment:
  // a three-component vector has size 3*N but alignment 4*N. Round the stride
  // up so that every element starts on an aligned offset.
  const Size elementStride = llvm::alignTo(elementSize, elementAlignment);
  // According to the Vulkan spec:
  // "An array has a base alignment equal to the base alignment of its element
  // type."
  size = elementStride * numElements;
  alignment = elementAlignment;
  return spirv::ArrayType::get(memberType, numElements, elementStride);
}
/// Decorates the element type of a runtime array and rebuilds it with an
/// ArrayStride that satisfies the Vulkan layout rules. The array's size is not
/// computed here; the caller handles it.
///
/// \param alignment out: the element's alignment, which is also the array's.
Type VulkanLayoutUtils::decorateType(spirv::RuntimeArrayType arrayType,
                                     VulkanLayoutUtils::Size &alignment) {
  auto elementType = arrayType.getElementType();
  Size elementSize = 0;
  auto memberType = decorateType(elementType, elementSize, alignment);
  // ArrayStride must be a multiple of the array's alignment, which is the
  // element's alignment. Round the element size up, e.g. 12 -> 16 for a
  // three-component 32-bit vector. An empty struct element leaves `alignment`
  // as passed in, so guard against a zero divisor in alignTo.
  const Size strideAlignment = std::max<Size>(alignment, 1);
  const Size elementStride = llvm::alignTo(elementSize, strideAlignment);
  return spirv::RuntimeArrayType::get(memberType, elementStride);
}
/// Returns the alignment in bytes of an integer or floating-point scalar.
///
/// Vulkan spec:
///  1. "A scalar of size N has a scalar alignment of N."
///  2. "A scalar has a base alignment equal to its scalar alignment."
///  3. "A scalar, vector or matrix type has an extended alignment equal to its
///     base alignment."
VulkanLayoutUtils::Size
VulkanLayoutUtils::getScalarTypeAlignment(Type scalarType) {
  const auto bits = scalarType.getIntOrFloatBitWidth();
  // A 1-bit type is given a one-byte alignment. Any other width is divided by
  // 8 using integer division.
  return bits == 1 ? 1 : bits / 8;
}
/// Checks whether `type` satisfies the layout requirement enforced here.
/// The only rejected case is a pointer to a non-empty struct without member
/// offsets in the Uniform, StorageBuffer, PushConstant or PhysicalStorageBuffer
/// storage classes. Every other type is accepted.
bool VulkanLayoutUtils::isLegalType(Type type) {
  auto pointerType = type.dyn_cast<spirv::PointerType>();
  if (!pointerType)
    return true;

  auto pointee = pointerType.getPointeeType().dyn_cast<spirv::StructType>();
  if (!pointee)
    return true;

  switch (pointerType.getStorageClass()) {
  case spirv::StorageClass::Uniform:
  case spirv::StorageClass::StorageBuffer:
  case spirv::StorageClass::PushConstant:
  case spirv::StorageClass::PhysicalStorageBuffer:
    break;
  default:
    // No layout requirement is checked for other storage classes.
    return true;
  }

  const bool isEmptyStruct = pointee.getNumElements() == 0;
  return isEmptyStruct || pointee.hasOffset();
}