Module.cpp 3.58 KB
//===- Module.cpp - MLIR Module Operation ---------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Module.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Module Operation.
//===----------------------------------------------------------------------===//

/// Populate `result` with the state of a module operation: a single body
/// region (given a terminator via ensureTerminator) and, when `name` is
/// provided, the symbol-name attribute holding that name.
void ModuleOp::build(OpBuilder &builder, OperationState &result,
                     Optional<StringRef> name) {
  Region *bodyRegion = result.addRegion();
  ensureTerminator(*bodyRegion, builder, result.location);

  // Anonymous modules carry no symbol-name attribute at all.
  if (!name)
    return;

  StringAttr symNameAttr = builder.getStringAttr(*name);
  result.attributes.push_back(builder.getNamedAttr(
      mlir::SymbolTable::getSymbolAttrName(), symNameAttr));
}

/// Construct a new module operation at `loc`, using the MLIRContext that owns
/// that location, optionally giving it the symbol name `name`. The returned
/// operation is not inserted into any block here.
ModuleOp ModuleOp::create(Location loc, Optional<StringRef> name) {
  OpBuilder builder(loc->getContext());
  OperationState state(loc, "module");
  ModuleOp::build(builder, state, name);

  Operation *moduleOp = Operation::create(state);
  return cast<ModuleOp>(moduleOp);
}

/// Parse the custom module form produced by ModuleOp::print:
///
///   module <symbol-name>? (attributes {...})? <region>
///
/// Both the symbol name and the `attributes` dictionary are optional. After
/// the region is parsed, ensureTerminator is used so the module has a valid
/// terminator.
ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
  // A missing symbol name is not an error, so the outcome of this attempt is
  // deliberately discarded. The parsed name, if any, is recorded in
  // `result.attributes` under the symbol-name key.
  StringAttr symNameAttr;
  (void)parser.parseOptionalSymbolName(
      symNameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);

  // Optional `attributes {...}` clause.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // The body region is parsed without any entry block arguments.
  Region &bodyRegion = *result.addRegion();
  if (parser.parseRegion(bodyRegion, llvm::None, llvm::None))
    return failure();

  // The printer elides block terminators, so reinstate one here.
  ensureTerminator(bodyRegion, parser.getBuilder(), result.location);
  return success();
}

/// Print the module in the custom form accepted by ModuleOp::parse. The
/// output consists of the `module` keyword, the optional symbol name, the
/// remaining attributes under the `attributes` keyword, and the body region.
/// The region is printed without its entry block arguments and without
/// block terminators.
void ModuleOp::print(OpAsmPrinter &p) {
  p << "module";

  Optional<StringRef> symName = getName();
  if (symName.hasValue()) {
    p << ' ';
    p.printSymbolName(symName.getValue());
  }

  // The symbol name has already been printed above, so elide it from the
  // attribute dictionary.
  p.printOptionalAttrDictWithKeyword(
      getAttrs(), /*elidedAttrs=*/{mlir::SymbolTable::getSymbolAttrName()});

  Region &bodyRegion = getOperation()->getRegion(0);
  p.printRegion(bodyRegion, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
}

/// Verify the structural invariants of a module operation:
///  * the body region contains exactly one block;
///  * that block has no arguments;
///  * every attribute is dialect-specific (its name contains a '.'), except
///    for the symbol name and symbol visibility attributes.
LogicalResult ModuleOp::verify() {
  Region &bodyRegion = getOperation()->getRegion(0);

  // The body must contain a single basic block.
  if (!llvm::hasSingleElement(bodyRegion))
    return emitOpError("expected body region to have a single block");

  // The body block must not take arguments. print() always passes
  // printEntryBlockArgs=false and parse() never declares region arguments,
  // so any such arguments would be silently lost when the module is
  // round-tripped through its textual form.
  if (bodyRegion.front().getNumArguments() != 0)
    return emitOpError("expected body to have no arguments");

  // Check that none of the attributes are non-dialect attributes, except for
  // the symbol related attributes.
  const StringRef allowedNonDialectAttrs[] = {
      mlir::SymbolTable::getSymbolAttrName(),
      mlir::SymbolTable::getVisibilityAttrName()};
  for (auto attr : getOperation()->getMutableAttrDict().getAttrs()) {
    StringRef attrName = attr.first.strref();
    if (!attrName.contains('.') &&
        !llvm::is_contained(allowedNonDialectAttrs, attrName))
      return emitOpError(
                 "can only contain dialect-specific attributes, found: '")
             << attr.first << "'";
  }

  return success();
}

/// Return the region holding the body of this module (region #0).
Region &ModuleOp::getBodyRegion() {
  Operation *op = getOperation();
  return op->getRegion(0);
}
/// Return the first block of the body region. ModuleOp::verify requires that
/// region to contain exactly one block.
Block *ModuleOp::getBody() {
  Region &bodyRegion = getBodyRegion();
  return &bodyRegion.front();
}

/// Return the symbol name of this module. Returns None when the module has
/// no symbol-name attribute of type StringAttr.
Optional<StringRef> ModuleOp::getName() {
  StringAttr symNameAttr =
      getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName());
  if (!symNameAttr)
    return llvm::None;
  return symNameAttr.getValue();
}