// DeserializationTest.cpp 9.36 KB
//===- DeserializationTest.cpp - SPIR-V Deserialization Tests -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The purpose of this file is to provide negative deserialization tests.
// For positive deserialization tests, please use serialization and
// deserialization for roundtripping.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVModule.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "gmock/gmock.h"

#include <memory>

using namespace mlir;

/// Static registration object for the SPIR-V dialect. It is constructed during
/// static initialization, which presumably is what makes the dialect known to
/// this test binary.
/// NOTE(review): the test fixture also loads the dialect explicitly into its
/// own context -- confirm whether both mechanisms are still needed.
static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration;

using ::testing::StrEq;

//===----------------------------------------------------------------------===//
// Test Fixture
//===----------------------------------------------------------------------===//

/// A deserialization test fixture providing minimal SPIR-V building and
/// diagnostic checking utilities.
class DeserializationTest : public ::testing::Test {
protected:
  DeserializationTest() : context(/*loadAllDialects=*/false) {
    context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
    // Register a diagnostic handler to capture the diagnostic so that we can
    // check it later.
    context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
      diagnostic.reset(new Diagnostic(std::move(diag)));
    });
  }

  /// Performs deserialization and returns the constructed spv.module op.
  spirv::OwningSPIRVModuleRef deserialize() {
    return spirv::deserialize(binary, &context);
  }

  /// Checks there is a diagnostic generated with the given `errorMessage`.
  void expectDiagnostic(StringRef errorMessage) {
    ASSERT_NE(nullptr, diagnostic.get());

    // TODO: check error location too.
    EXPECT_THAT(diagnostic->str(), StrEq(std::string(errorMessage)));
  }

  //===--------------------------------------------------------------------===//
  // SPIR-V builder methods
  //===--------------------------------------------------------------------===//

  /// Adds the SPIR-V module header to `binary`.
  void addHeader() {
    spirv::appendModuleHeader(binary, spirv::Version::V_1_0, /*idBound=*/0);
  }

  /// Adds the SPIR-V instruction into `binary`.
  void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
    uint32_t wordCount = 1 + operands.size();
    binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
    binary.append(operands.begin(), operands.end());
  }

  uint32_t addVoidType() {
    auto id = nextID++;
    addInstruction(spirv::Opcode::OpTypeVoid, {id});
    return id;
  }

  uint32_t addIntType(uint32_t bitwidth) {
    auto id = nextID++;
    addInstruction(spirv::Opcode::OpTypeInt, {id, bitwidth, /*signedness=*/1});
    return id;
  }

  uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
    auto id = nextID++;
    SmallVector<uint32_t, 2> words;
    words.push_back(id);
    words.append(memberTypes.begin(), memberTypes.end());
    addInstruction(spirv::Opcode::OpTypeStruct, words);
    return id;
  }

  uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
    auto id = nextID++;
    SmallVector<uint32_t, 4> operands;
    operands.push_back(id);
    operands.push_back(retType);
    operands.append(paramTypes.begin(), paramTypes.end());
    addInstruction(spirv::Opcode::OpTypeFunction, operands);
    return id;
  }

  uint32_t addFunction(uint32_t retType, uint32_t fnType) {
    auto id = nextID++;
    addInstruction(spirv::Opcode::OpFunction,
                   {retType, id,
                    static_cast<uint32_t>(spirv::FunctionControl::None),
                    fnType});
    return id;
  }

  void addFunctionEnd() { addInstruction(spirv::Opcode::OpFunctionEnd, {}); }

  void addReturn() { addInstruction(spirv::Opcode::OpReturn, {}); }

protected:
  SmallVector<uint32_t, 5> binary;
  uint32_t nextID = 1;
  MLIRContext context;
  std::unique_ptr<Diagnostic> diagnostic;
};

//===----------------------------------------------------------------------===//
// Basics
//===----------------------------------------------------------------------===//

TEST_F(DeserializationTest, EmptyModuleFailure) {
  // `binary` is left untouched, so there is no header at all.
  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("SPIR-V binary module must have a 5-word header");
}

TEST_F(DeserializationTest, WrongMagicNumberFailure) {
  addHeader();
  // Corrupt the first header word, which carries the magic number.
  binary[0] = 0xdeadbeef;

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("incorrect magic number");
}

TEST_F(DeserializationTest, OnlyHeaderSuccess) {
  // A module consisting of nothing but the header should deserialize.
  addHeader();
  auto moduleRef = deserialize();
  EXPECT_TRUE(moduleRef);
}

TEST_F(DeserializationTest, ZeroWordCountFailure) {
  addHeader();
  // An all-zero instruction word: OpNop with a word count of zero.
  const uint32_t zeroCountNop = 0;
  binary.push_back(zeroCountNop);

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("word count cannot be zero");
}

TEST_F(DeserializationTest, InsufficientWordFailure) {
  addHeader();
  // Announce a two-word OpTypeVoid but emit only its first word, i.e. the
  // word that would hold the type <id> is deliberately left out.
  const uint32_t truncatedTypeVoid =
      (2u << 16) | static_cast<uint32_t>(spirv::Opcode::OpTypeVoid);
  binary.push_back(truncatedTypeVoid);

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("insufficient words for the last instruction");
}

//===----------------------------------------------------------------------===//
// Types
//===----------------------------------------------------------------------===//

TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
  addHeader();
  // Emit OpTypeInt with a result <id> and a bitwidth only; the signedness
  // operand is intentionally omitted.
  const uint32_t intTypeID = nextID++;
  addInstruction(spirv::Opcode::OpTypeInt, {intTypeID, 32});

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}

//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//

TEST_F(DeserializationTest, OpMemberNameSuccess) {
  addHeader();
  // Park the header in `typeDecl` so the type declarations can be assembled
  // in an empty `binary`. After the second swap below, `typeDecl` holds the
  // type declarations and `binary` holds the header again.
  SmallVector<uint32_t, 5> typeDecl;
  typeDecl.swap(binary);

  const uint32_t i32TypeID = addIntType(32);
  const uint32_t structTypeID = addStructType({i32TypeID, i32TypeID});
  typeDecl.swap(binary);

  // Name member 0 of the struct "i1".
  SmallVector<uint32_t, 5> firstMemberName = {structTypeID, 0};
  spirv::encodeStringLiteralInto(firstMemberName, "i1");
  addInstruction(spirv::Opcode::OpMemberName, firstMemberName);

  // Name member 1 of the struct "i2".
  SmallVector<uint32_t, 5> secondMemberName = {structTypeID, 1};
  spirv::encodeStringLiteralInto(secondMemberName, "i2");
  addInstruction(spirv::Opcode::OpMemberName, secondMemberName);

  // The type declarations go after the OpMemberName instructions.
  binary.append(typeDecl.begin(), typeDecl.end());
  auto moduleRef = deserialize();
  EXPECT_TRUE(moduleRef);
}

TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
  addHeader();
  // Park the header in `typeDecl` while the type declarations are assembled
  // in an empty `binary`; the second swap puts the header back in `binary`
  // and leaves the type declarations in `typeDecl`.
  SmallVector<uint32_t, 5> typeDecl;
  typeDecl.swap(binary);

  const uint32_t i32TypeID = addIntType(32);
  const uint32_t i64TypeID = addIntType(64);
  const uint32_t structTypeID = addStructType({i32TypeID, i64TypeID});
  typeDecl.swap(binary);

  // OpMemberName carrying only the struct <id>: no member index, no name.
  SmallVector<uint32_t, 5> truncatedOperands = {structTypeID};
  addInstruction(spirv::Opcode::OpMemberName, truncatedOperands);

  // The type declarations go after the OpMemberName instruction.
  binary.append(typeDecl.begin(), typeDecl.end());
  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("OpMemberName must have at least 3 operands");
}

TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
  addHeader();
  // Park the header in `typeDecl` while the type declarations are assembled
  // in an empty `binary`; the second swap puts the header back in `binary`
  // and leaves the type declarations in `typeDecl`.
  SmallVector<uint32_t, 5> typeDecl;
  typeDecl.swap(binary);

  const uint32_t i32TypeID = addIntType(32);
  const uint32_t structTypeID = addStructType({i32TypeID});
  typeDecl.swap(binary);

  // A well-formed OpMemberName for member 0, followed by one stray word.
  SmallVector<uint32_t, 5> memberNameWords = {structTypeID, 0};
  spirv::encodeStringLiteralInto(memberNameWords, "int32");
  memberNameWords.push_back(42);
  addInstruction(spirv::Opcode::OpMemberName, memberNameWords);

  // The type declarations go after the OpMemberName instruction.
  binary.append(typeDecl.begin(), typeDecl.end());
  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}

//===----------------------------------------------------------------------===//
// Functions
//===----------------------------------------------------------------------===//

TEST_F(DeserializationTest, FunctionMissingEndFailure) {
  addHeader();
  const uint32_t voidTypeID = addVoidType();
  const uint32_t fnTypeID = addFunctionType(voidTypeID, {});
  // Open a function but never emit the matching OpFunctionEnd.
  addFunction(voidTypeID, fnTypeID);

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("expected OpFunctionEnd instruction");
}

TEST_F(DeserializationTest, FunctionMissingParameterFailure) {
  addHeader();
  const uint32_t voidTypeID = addVoidType();
  const uint32_t i32TypeID = addIntType(32);
  // The function type declares one i32 parameter ...
  const uint32_t fnTypeID = addFunctionType(voidTypeID, {i32TypeID});
  // ... but no OpFunctionParameter follows the OpFunction.
  addFunction(voidTypeID, fnTypeID);

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("expected OpFunctionParameter instruction");
}

TEST_F(DeserializationTest, FunctionMissingLabelForFirstBlockFailure) {
  addHeader();
  const uint32_t voidTypeID = addVoidType();
  const uint32_t fnTypeID = addFunctionType(voidTypeID, {});
  addFunction(voidTypeID, fnTypeID);
  // The function body goes straight to OpReturn with no OpLabel before it.
  addReturn();
  addFunctionEnd();

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("a basic block must start with OpLabel");
}

TEST_F(DeserializationTest, FunctionMalformedLabelFailure) {
  addHeader();
  const uint32_t voidTypeID = addVoidType();
  const uint32_t fnTypeID = addFunctionType(voidTypeID, {});
  addFunction(voidTypeID, fnTypeID);
  // Malformed OpLabel: emitted without any operand, so it has no result <id>.
  addInstruction(spirv::Opcode::OpLabel, {});
  addReturn();
  addFunctionEnd();

  auto moduleRef = deserialize();
  ASSERT_FALSE(moduleRef);
  expectDiagnostic("OpLabel should only have result <id>");
}