CallGraph.cpp 7.71 KB
//===- CallGraph.cpp - CallGraph analysis for MLIR ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains interfaces and analyses for defining a nested callgraph.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// CallGraphNode
//===----------------------------------------------------------------------===//

/// Returns true if this node refers to the indirect/external node.
bool CallGraphNode::isExternal() const { return !callableRegion; }

/// Return the callable region this node represents. This can only be called
/// on non-external nodes.
Region *CallGraphNode::getCallableRegion() const {
  assert(!isExternal() && "the external node has no callable region");
  return callableRegion;
}

/// Adds an reference edge to the given node. This is only valid on the
/// external node.
void CallGraphNode::addAbstractEdge(CallGraphNode *node) {
  assert(isExternal() && "abstract edges are only valid on external nodes");
  addEdge(node, Edge::Kind::Abstract);
}

/// Add an outgoing call edge from this node.
void CallGraphNode::addCallEdge(CallGraphNode *node) {
  addEdge(node, Edge::Kind::Call);
}

/// Adds a reference edge to the given child node.
void CallGraphNode::addChildEdge(CallGraphNode *child) {
  addEdge(child, Edge::Kind::Child);
}

/// Returns true if this node has any child edges.
bool CallGraphNode::hasChildren() const {
  return llvm::any_of(edges, [](const Edge &edge) { return edge.isChild(); });
}

/// Add an edge to 'node' with the given kind.
void CallGraphNode::addEdge(CallGraphNode *node, Edge::Kind kind) {
  edges.insert({node, kind});
}

//===----------------------------------------------------------------------===//
// CallGraph
//===----------------------------------------------------------------------===//

/// Recursively compute the callgraph edges for the given operation. Computed
/// edges are placed into the given callgraph object.
static void computeCallGraph(Operation *op, CallGraph &cg,
                             CallGraphNode *parentNode, bool resolveCalls) {
  if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
    // If there is no parent node, we ignore this operation. Even if this
    // operation was a call, there would be no callgraph node to attribute it
    // to.
    if (resolveCalls && parentNode)
      parentNode->addCallEdge(cg.resolveCallable(call));
    return;
  }

  // Compute the callgraph nodes and edges for each of the nested operations.
  if (CallableOpInterface callable = dyn_cast<CallableOpInterface>(op)) {
    if (auto *callableRegion = callable.getCallableRegion())
      parentNode = cg.getOrAddNode(callableRegion, parentNode);
    else
      return;
  }

  for (Region &region : op->getRegions())
    for (Operation &nested : region.getOps())
      computeCallGraph(&nested, cg, parentNode, resolveCalls);
}

CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
  // Make two passes over the graph, one to compute the callables and one to
  // resolve the calls. We split these up as we may have nested callable objects
  // that need to be reserved before the calls.
  computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
  computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
}

/// Get or add a call graph node for the given region.
CallGraphNode *CallGraph::getOrAddNode(Region *region,
                                       CallGraphNode *parentNode) {
  assert(region && isa<CallableOpInterface>(region->getParentOp()) &&
         "expected parent operation to be callable");
  std::unique_ptr<CallGraphNode> &node = nodes[region];
  if (!node) {
    node.reset(new CallGraphNode(region));

    // Add this node to the given parent node if necessary.
    if (parentNode)
      parentNode->addChildEdge(node.get());
    else
      // Otherwise, connect all callable nodes to the external node, this allows
      // for conservatively including all callable nodes within the graph.
      // FIXME(riverriddle) This isn't correct, this is only necessary for
      // callable nodes that *could* be called from external sources. This
      // requires extending the interface for callables to check if they may be
      // referenced externally.
      externalNode.addAbstractEdge(node.get());
  }
  return node.get();
}

/// Lookup a call graph node for the given region, or nullptr if none is
/// registered.
CallGraphNode *CallGraph::lookupNode(Region *region) const {
  auto it = nodes.find(region);
  return it == nodes.end() ? nullptr : it->second.get();
}

/// Resolve the callable for given callee to a node in the callgraph, or the
/// external node if a valid node was not resolved.
CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
  Operation *callable = call.resolveCallable();
  if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
    if (auto *node = lookupNode(callableOp.getCallableRegion()))
      return node;

  // If we don't have a valid direct region, this is an external call.
  return getExternalNode();
}

/// Erase the given node from the callgraph.
void CallGraph::eraseNode(CallGraphNode *node) {
  // Erase any children of this node first.
  if (node->hasChildren()) {
    for (const CallGraphNode::Edge &edge : llvm::make_early_inc_range(*node))
      if (edge.isChild())
        eraseNode(edge.getTarget());
  }
  // Erase any edges to this node from any other nodes.
  for (auto &it : nodes) {
    it.second->edges.remove_if([node](const CallGraphNode::Edge &edge) {
      return edge.getTarget() == node;
    });
  }
  nodes.erase(node->getCallableRegion());
}

//===----------------------------------------------------------------------===//
// Printing

/// Dump the graph in a human readable format.
void CallGraph::dump() const { print(llvm::errs()); }
void CallGraph::print(raw_ostream &os) const {
  os << "// ---- CallGraph ----\n";

  // Functor used to output the name for the given node.
  auto emitNodeName = [&](const CallGraphNode *node) {
    if (node->isExternal()) {
      os << "<External-Node>";
      return;
    }

    auto *callableRegion = node->getCallableRegion();
    auto *parentOp = callableRegion->getParentOp();
    os << "'" << callableRegion->getParentOp()->getName() << "' - Region #"
       << callableRegion->getRegionNumber();
    auto attrs = parentOp->getAttrDictionary();
    if (!attrs.empty())
      os << " : " << attrs;
  };

  for (auto &nodeIt : nodes) {
    const CallGraphNode *node = nodeIt.second.get();

    // Dump the header for this node.
    os << "// - Node : ";
    emitNodeName(node);
    os << "\n";

    // Emit each of the edges.
    for (auto &edge : *node) {
      os << "// -- ";
      if (edge.isCall())
        os << "Call";
      else if (edge.isChild())
        os << "Child";

      os << "-Edge : ";
      emitNodeName(edge.getTarget());
      os << "\n";
    }
    os << "//\n";
  }

  os << "// -- SCCs --\n";

  for (auto &scc : make_range(llvm::scc_begin(this), llvm::scc_end(this))) {
    os << "// - SCC : \n";
    for (auto &node : scc) {
      os << "// -- Node :";
      emitNodeName(node);
      os << "\n";
    }
    os << "\n";
  }

  os << "// -------------------\n";
}