mlir_test_cblas_interface.cpp
4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
//===- mlir_test_cblas_interface.cpp - Simple Blas subset interface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Simple Blas subset interface implementation.
//
//===----------------------------------------------------------------------===//
#include "include/mlir_test_cblas_interface.h"
#include "include/mlir_test_cblas.h"
#include <assert.h>
#include <iostream>
/// Stores the scalar `f` into the one element described by the
/// `StridedMemRefType<float, 0>` descriptor `X`: the element located
/// `X->offset` elements past `X->data`.
extern "C" void
_mlir_ciface_linalg_fill_viewf32_f32(StridedMemRefType<float, 0> *X, float f) {
  *(X->data + X->offset) = f;
}
/// Fills a one-dimensional strided view with the value `f`.
///
/// For each index `idx` in `[0, X->sizes[0])` the element at
/// `X->data[X->offset + idx * X->strides[0]]` is overwritten with `f`.
/// Nothing is written when `X->sizes[0]` is not positive.
extern "C" void
_mlir_ciface_linalg_fill_viewsxf32_f32(StridedMemRefType<float, 1> *X,
                                       float f) {
  // NOTE(review): the counter is `unsigned`; confirm that sizes can never
  // exceed its range.
  unsigned idx = 0;
  while (idx < X->sizes[0]) {
    X->data[X->offset + idx * X->strides[0]] = f;
    ++idx;
  }
}
/// Fills a two-dimensional strided view with the value `f`.
///
/// Every element `(row, col)` with `row < X->sizes[0]` and
/// `col < X->sizes[1]` is overwritten; its linear position is
/// `X->offset + row * X->strides[0] + col * X->strides[1]`. Elements are
/// visited with `row` as the outer index and `col` as the inner index.
extern "C" void
_mlir_ciface_linalg_fill_viewsxsxf32_f32(StridedMemRefType<float, 2> *X,
                                         float f) {
  for (unsigned row = 0; row < X->sizes[0]; ++row) {
    // Linear position of element (row, 0); shared by the whole inner loop.
    const auto rowStart = X->offset + row * X->strides[0];
    for (unsigned col = 0; col < X->sizes[1]; ++col)
      X->data[rowStart + col * X->strides[1]] = f;
  }
}
/// Copies the single element described by `I` into the single element
/// described by `O`. Each element is addressed as `data[offset]` of its
/// `StridedMemRefType<float, 0>` descriptor.
extern "C" void
_mlir_ciface_linalg_copy_viewf32_viewf32(StridedMemRefType<float, 0> *I,
                                         StridedMemRefType<float, 0> *O) {
  const auto value = *(I->data + I->offset);
  *(O->data + O->offset) = value;
}
/// Copies a one-dimensional strided view `I` into the view `O`,
/// element by element.
///
/// The two views must have the same `sizes[0]`. If they do not, a
/// diagnostic and the metadata of both descriptors are written to
/// `std::cerr` and nothing is copied. Strides may differ between the two
/// views; each side is addressed with its own offset and stride.
extern "C" void
_mlir_ciface_linalg_copy_viewsxf32_viewsxf32(StridedMemRefType<float, 1> *I,
                                             StridedMemRefType<float, 1> *O) {
  const bool sameLength = I->sizes[0] == O->sizes[0];
  if (!sameLength) {
    std::cerr << "Incompatible strided memrefs\n";
    printMemRefMetaData(std::cerr, *I);
    printMemRefMetaData(std::cerr, *O);
    return;
  }

  unsigned idx = 0;
  while (idx < I->sizes[0]) {
    const auto value = I->data[I->offset + idx * I->strides[0]];
    O->data[O->offset + idx * O->strides[0]] = value;
    ++idx;
  }
}
/// Copies a two-dimensional strided view `I` into the view `O`,
/// element by element.
///
/// Both `sizes[0]` and `sizes[1]` must match between the two views. If
/// they do not, a diagnostic and the metadata of both descriptors are
/// written to `std::cerr` and nothing is copied. Each side is addressed
/// with its own offset and strides, so the two layouts may differ.
extern "C" void _mlir_ciface_linalg_copy_viewsxsxf32_viewsxsxf32(
    StridedMemRefType<float, 2> *I, StridedMemRefType<float, 2> *O) {
  const bool sameShape =
      I->sizes[0] == O->sizes[0] && I->sizes[1] == O->sizes[1];
  if (!sameShape) {
    std::cerr << "Incompatible strided memrefs\n";
    printMemRefMetaData(std::cerr, *I);
    printMemRefMetaData(std::cerr, *O);
    return;
  }

  for (unsigned row = 0; row < I->sizes[0]; ++row) {
    // Linear positions of element (row, 0) in the source and destination.
    const auto srcRow = I->offset + row * I->strides[0];
    const auto dstRow = O->offset + row * O->strides[0];
    for (unsigned col = 0; col < I->sizes[1]; ++col)
      O->data[dstRow + col * O->strides[1]] =
          I->data[srcRow + col * I->strides[1]];
  }
}
/// Adds the result of `mlir_test_cblas_sdot` over the views `X` and `Y`
/// to the single element described by `Z` (i.e. `Z += sdot(X, Y)`).
///
/// The call is only made when both inputs have a stride of exactly 1 and
/// the same `sizes[0]`. Otherwise a diagnostic and the metadata of all
/// three descriptors are written to `std::cerr` and `Z` is left untouched.
extern "C" void _mlir_ciface_linalg_dot_viewsxf32_viewsxf32_viewf32(
    StridedMemRefType<float, 1> *X, StridedMemRefType<float, 1> *Y,
    StridedMemRefType<float, 0> *Z) {
  const bool unitStrides = X->strides[0] == 1 && Y->strides[0] == 1;
  const bool compatible = unitStrides && X->sizes[0] == Y->sizes[0];
  if (!compatible) {
    std::cerr << "Incompatible strided memrefs\n";
    printMemRefMetaData(std::cerr, *X);
    printMemRefMetaData(std::cerr, *Y);
    printMemRefMetaData(std::cerr, *Z);
    return;
  }

  // Compute the product first, then accumulate it into the output element.
  const auto product =
      mlir_test_cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
                           Y->data + Y->offset, Y->strides[0]);
  Z->data[Z->offset] += product;
}
/// Forwards the two-dimensional views `A`, `B` and `C` to
/// `mlir_test_cblas_sgemm` in row-major, no-transpose mode with both scalar
/// factors set to 1.0f (presumably C = A * B + C, following the usual CBLAS
/// sgemm convention -- confirm against mlir_test_cblas.h).
///
/// The dimensions passed are `C->sizes[0]`, `C->sizes[1]` and `A->sizes[1]`;
/// each operand's `strides[0]` is passed as its leading dimension.
///
/// The call is only made when all of the following hold:
///   * the innermost stride (`strides[1]`) of A, B and C is exactly 1;
///   * `sizes[0]` of A, B and C is not smaller than that innermost stride,
///     which, with unit strides, means each operand has at least one row;
///   * the shapes agree: C is `A->sizes[0]` x `B->sizes[1]` and
///     `A->sizes[1] == B->sizes[0]`.
/// Otherwise a diagnostic and the metadata of all three descriptors are
/// written to `std::cerr` and C is left untouched.
extern "C" void _mlir_ciface_linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
    StridedMemRefType<float, 2> *A, StridedMemRefType<float, 2> *B,
    StridedMemRefType<float, 2> *C) {
  // NOTE(review): the `sizes[0] < strides[1]` comparisons look like they may
  // have been meant as a leading-dimension check (`strides[0]` against
  // `sizes[1]`); the condition is kept as-is -- confirm the intent.
  if (A->strides[1] != B->strides[1] || A->strides[1] != C->strides[1] ||
      A->strides[1] != 1 || A->sizes[0] < A->strides[1] ||
      B->sizes[0] < B->strides[1] || C->sizes[0] < C->strides[1] ||
      C->sizes[0] != A->sizes[0] || C->sizes[1] != B->sizes[1] ||
      A->sizes[1] != B->sizes[0]) {
    // Emit the same leading diagnostic as the copy and dot entry points so
    // that a rejected call is recognisable in the error stream.
    std::cerr << "Incompatible strided memrefs\n";
    printMemRefMetaData(std::cerr, *A);
    printMemRefMetaData(std::cerr, *B);
    printMemRefMetaData(std::cerr, *C);
    return;
  }
  mlir_test_cblas_sgemm(
      CBLAS_ORDER::CblasRowMajor, CBLAS_TRANSPOSE::CblasNoTrans,
      CBLAS_TRANSPOSE::CblasNoTrans, C->sizes[0], C->sizes[1], A->sizes[1],
      1.0f, A->data + A->offset, A->strides[0], B->data + B->offset,
      B->strides[0], 1.0f, C->data + C->offset, C->strides[0]);
}