vulkan-runtime-wrappers.cpp
6.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C runtime wrappers around the VulkanRuntime.
//
//===----------------------------------------------------------------------===//
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <numeric>

#include "VulkanRuntime.h"
namespace {

/// Serializes access to a single `VulkanRuntime` instance.
///
/// Every public method holds the internal mutex for its whole duration and
/// then forwards to the wrapped runtime.
class VulkanRuntimeManager {
public:
  VulkanRuntimeManager() = default;
  // Not copyable. The deleted copy-assignment operator is declared with the
  // conventional reference return type.
  VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
  VulkanRuntimeManager &operator=(const VulkanRuntimeManager &) = delete;
  ~VulkanRuntimeManager() = default;

  /// Forwards `memBuffer` to the runtime as the resource data for the given
  /// descriptor set index and binding index.
  void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
                       const VulkanHostMemoryBuffer &memBuffer) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
  }

  /// Forwards the entry point name to the runtime.
  void setEntryPoint(const char *entryPoint) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setEntryPoint(entryPoint);
  }

  /// Forwards the number of work groups to the runtime.
  void setNumWorkGroups(NumWorkGroups numWorkGroups) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setNumWorkGroups(numWorkGroups);
  }

  /// Forwards the shader binary pointer and its `size` to the runtime.
  void setShaderModule(uint8_t *shader, uint32_t size) {
    std::lock_guard<std::mutex> lock(mutex);
    vulkanRuntime.setShaderModule(shader, size);
  }

  /// Calls `initRuntime`, `run`, `updateHostMemoryBuffers` and `destroy` on the
  /// runtime, in that order, stopping at the first stage that fails. A failure
  /// is reported on stderr; nothing is returned to the caller.
  void runOnVulkan() {
    std::lock_guard<std::mutex> lock(mutex);
    // NOTE(review): `||` short-circuits, so `destroy()` is skipped whenever an
    // earlier stage fails. Confirm whether that leaves runtime resources
    // unreleased and whether `destroy()` is safe after a partial init.
    if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
        failed(vulkanRuntime.updateHostMemoryBuffers()) ||
        failed(vulkanRuntime.destroy())) {
      std::cerr << "runOnVulkan failed";
    }
  }

private:
  VulkanRuntime vulkanRuntime;
  // Guards every access to `vulkanRuntime`.
  std::mutex mutex;
};

} // namespace
/// Plain aggregate describing a rank-`N` buffer of `T` elements.
///
/// NOTE(review): the member order presumably has to match the memref
/// descriptor that callers of the `_mlir_ciface_*` entry points pass in;
/// confirm before reordering or adding members.
template <typename T, int N>
struct MemRefDescriptor {
// Base pointer; this is the pointer the bind and fill helpers in this file use.
T *allocated;
// Not read in the visible part of this file. Presumably the aligned data
// pointer; TODO confirm.
T *aligned;
// Not read in the visible part of this file. Presumably an element offset
// relative to `aligned`; TODO confirm.
int64_t offset;
// Per-dimension extents; the helpers in this file multiply them together to
// get the element count.
int64_t sizes[N];
// Not read in the visible part of this file. Presumably per-dimension
// strides; TODO confirm.
int64_t strides[N];
};
/// Binds the buffer described by `ptr` to the given descriptor set index and
/// binding index of the `VulkanRuntimeManager` behind `vkRuntimeManager`.
///
/// The bound size is `sizeof(T)` multiplied by every entry of `ptr->sizes`, and
/// the bound pointer is `ptr->allocated`.
///
/// The rank parameter `S` is an `int` so that it has the same type as the rank
/// parameter of `MemRefDescriptor`; with mismatched types the rank cannot be
/// deduced from the `ptr` argument.
///
/// If the byte size does not fit in 32 bits, an error is printed to stderr and
/// nothing is bound, instead of binding a silently truncated size.
template <typename T, int S>
void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,
                BindingIndex bindIndex, MemRefDescriptor<T, S> *ptr) {
  // Accumulate in 64 bits so the truncation check below can see an overflow.
  uint64_t numBytes = sizeof(T);
  for (int dim = 0; dim < S; ++dim)
    numBytes *= static_cast<uint64_t>(ptr->sizes[dim]);

  // The host memory buffer is given a 32-bit size, as before.
  const auto size = static_cast<uint32_t>(numBytes);
  if (size != numBytes) {
    std::cerr << "bindMemRef failed: buffer size does not fit in 32 bits\n";
    return;
  }

  VulkanHostMemoryBuffer memBuffer{ptr->allocated, size};
  reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
      ->setResourceData(setIndex, bindIndex, memBuffer);
}
extern "C" {
/// Heap-allocates a fresh `VulkanRuntimeManager` and returns it as an opaque
/// pointer.
void *initVulkan() {
  auto *manager = new VulkanRuntimeManager();
  return manager;
}
/// Deinitializes `VulkanRuntimeManager` by the given pointer.
void deinitVulkan(void *vkRuntimeManager) {
delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
}
/// Calls `runOnVulkan` on the `VulkanRuntimeManager` behind the given opaque
/// pointer.
void runOnVulkan(void *vkRuntimeManager) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->runOnVulkan();
}
/// Passes the entry point name on to the `VulkanRuntimeManager` behind the
/// given opaque pointer.
void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setEntryPoint(entryPoint);
}
/// Passes the work group counts `x`, `y` and `z` on to the
/// `VulkanRuntimeManager` behind the given opaque pointer.
void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y,
                      uint32_t z) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setNumWorkGroups({x, y, z});
}
/// Passes the shader binary pointer and its `size` on to the
/// `VulkanRuntimeManager` behind the given opaque pointer.
void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
  auto *manager = static_cast<VulkanRuntimeManager *>(vkRuntimeManager);
  manager->setShaderModule(shader, size);
}
/// Binds the given memref to the given descriptor set and descriptor
/// index.
#define DECLARE_BIND_MEMREF(size, type, typeName) \
void bindMemRef##size##D##typeName( \
void *vkRuntimeManager, DescriptorSetIndex setIndex, \
BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) { \
bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr); \
}
DECLARE_BIND_MEMREF(1, float, Float)
DECLARE_BIND_MEMREF(2, float, Float)
DECLARE_BIND_MEMREF(3, float, Float)
DECLARE_BIND_MEMREF(1, int32_t, Int32)
DECLARE_BIND_MEMREF(2, int32_t, Int32)
DECLARE_BIND_MEMREF(3, int32_t, Int32)
DECLARE_BIND_MEMREF(1, int16_t, Int16)
DECLARE_BIND_MEMREF(2, int16_t, Int16)
DECLARE_BIND_MEMREF(3, int16_t, Int16)
DECLARE_BIND_MEMREF(1, int8_t, Int8)
DECLARE_BIND_MEMREF(2, int8_t, Int8)
DECLARE_BIND_MEMREF(3, int8_t, Int8)
DECLARE_BIND_MEMREF(1, int16_t, Half)
DECLARE_BIND_MEMREF(2, int16_t, Half)
DECLARE_BIND_MEMREF(3, int16_t, Half)
/// Writes `value` into the first `sizes[0]` elements of the `allocated` buffer
/// of a 1D float memref.
void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
                                      float value) {
  float *const data = ptr->allocated;
  const int64_t numElements = ptr->sizes[0];
  std::fill_n(data, numElements, value);
}
/// Writes `value` into the `allocated` buffer of a 2D float memref; the number
/// of elements written is the product of both entries of `sizes`.
void _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
                                      float value) {
  int64_t numElements = 1;
  for (const int64_t extent : ptr->sizes)
    numElements *= extent;
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the `allocated` buffer of a 3D float memref; the number
/// of elements written is the product of all three entries of `sizes`.
void _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
                                      float value) {
  int64_t numElements = 1;
  for (const int64_t extent : ptr->sizes)
    numElements *= extent;
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0]` elements of the `allocated` buffer
/// of a 1D int32 memref.
void _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
                                    int32_t value) {
  int32_t *const data = ptr->allocated;
  const int64_t numElements = ptr->sizes[0];
  std::fill_n(data, numElements, value);
}
/// Writes `value` into the `allocated` buffer of a 2D int32 memref; the number
/// of elements written is the product of both entries of `sizes`.
void _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
                                    int32_t value) {
  int64_t numElements = 1;
  for (const int64_t extent : ptr->sizes)
    numElements *= extent;
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the `allocated` buffer of a 3D int32 memref; the number
/// of elements written is the product of all three entries of `sizes`.
void _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
                                    int32_t value) {
  int64_t numElements = 1;
  for (const int64_t extent : ptr->sizes)
    numElements *= extent;
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the first `sizes[0]` elements of the `allocated` buffer
/// of a 1D int8 memref.
void _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
                                     int8_t value) {
  int8_t *const data = ptr->allocated;
  const int64_t numElements = ptr->sizes[0];
  std::fill_n(data, numElements, value);
}
/// Writes `value` into the `allocated` buffer of a 2D int8 memref; the number
/// of elements written is the product of both entries of `sizes`.
void _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
                                     int8_t value) {
  int64_t numElements = 1;
  for (const int64_t extent : ptr->sizes)
    numElements *= extent;
  std::fill_n(ptr->allocated, numElements, value);
}
/// Writes `value` into the `allocated` buffer of a 3D int8 memref; the number
/// of elements written is the product of all three entries of `sizes`.
void _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
                                     int8_t value) {
  int64_t numElements = 1;
  for (const int64_t extent : ptr->sizes)
    numElements *= extent;
  std::fill_n(ptr->allocated, numElements, value);
}
}