OpReducer.cpp 1.72 KB
//===- OpReducer.cpp - Operation Reducer ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the OpReducer class. It defines a variant generator method
// that produces different variants by eliminating a parameterizable type of
// operation from the parent module.
//
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/Passes/OpReducer.h"

using namespace mlir;

/// Construct the reducer implementation from \p getSpecificOps, a callback
/// that is given a ModuleOp and returns a vector of Operation pointers. The
/// callback is stored in the member of the same name and is invoked later by
/// initTransformSpace.
///
/// NOTE(review): llvm::function_ref is a non-owning reference to a callable.
/// If the member is also declared as a function_ref (the declaration lives in
/// the header and is not visible here), the callable passed in must outlive
/// this object -- confirm against the header and the call sites.
OpReducerImpl::OpReducerImpl(
    llvm::function_ref<std::vector<Operation *>(ModuleOp)> getSpecificOps)
    : getSpecificOps(getSpecificOps) {}

/// Human-readable name identifying this reducer class.
StringRef OpReducerImpl::getName() {
  // The string literal converts implicitly to the StringRef return type.
  return "High Level Operation Reduction";
}

/// Return the initial transform space for \p module.
///
/// The operations selected by the getSpecificOps callback are collected from
/// \p module, and their count is handed to
/// ReductionTreeUtils::createTransformSpace together with the module. The
/// vector produced by that helper is returned unchanged.
std::vector<bool> OpReducerImpl::initTransformSpace(ModuleOp module) {
  const std::vector<Operation *> ops = getSpecificOps(module);
  // The count has always been passed on as an int; keep that type but make
  // the size_t -> int narrowing explicit instead of relying on an implicit
  // conversion from the result of std::distance.
  const int numOps = static_cast<int>(ops.size());
  return ReductionTreeUtils::createTransformSpace(module, numOps);
}

/// Generate variants for \p parent by forwarding to
/// ReductionTreeUtils::createVariants. \p test, \p numVariants and the
/// \p transform callback are passed through unchanged.
///
/// As originally described, the variants are produced by removing opType
/// operations from the module in the parent, and they are linked as children
/// in the Reduction Tree Pass; that work is carried out by createVariants and
/// \p transform rather than in this function.
///
/// NOTE(review): the trailing `true` is a positional flag whose meaning is
/// defined by createVariants, which is declared outside this file. Confirm its
/// semantics there and consider annotating the argument with an inline
/// `/*paramName=*/` comment.
void OpReducerImpl::generateVariants(
    ReductionNode *parent, const Tester &test, int numVariants,
    llvm::function_ref<void(ModuleOp, int, int)> transform) {
  ReductionTreeUtils::createVariants(parent, test, numVariants, transform,
                                     true);
}