UseTrailingReturnTypeCheck.cpp 18.7 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517
//===--- UseTrailingReturnTypeCheck.cpp - clang-tidy-----------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "UseTrailingReturnTypeCheck.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Tooling/FixIt.h"
#include "llvm/ADT/StringExtras.h"

#include <cctype>

using namespace clang::ast_matchers;

namespace clang {
namespace tidy {
namespace modernize {
namespace {
/// Walks a return type (and expressions nested inside it) looking for
/// unqualified names that are spelled exactly like one of the function's
/// parameters. Once the return type is moved behind the parameter list such a
/// name would be looked up with the parameters in scope, so a match is
/// recorded in \c Collision and the traversal is stopped.
struct UnqualNameVisitor : public RecursiveASTVisitor<UnqualNameVisitor> {
public:
  UnqualNameVisitor(const FunctionDecl &F) : F(F) {}

  /// Becomes true as soon as a colliding unqualified name has been seen.
  bool Collision = false;

  // Only the TypeLocs themselves are inspected; do not additionally recurse
  // into the types they refer to.
  bool shouldWalkTypesOfTypeLocs() const { return false; }

  /// Compares \p UnqualName against the identifier of every parameter of the
  /// function.
  /// \returns true (and sets \c Collision) if a parameter has the same name,
  /// false otherwise.
  bool VisitUnqualName(StringRef UnqualName) {
    // Check for collisions with function arguments.
    for (ParmVarDecl *Param : F.parameters())
      if (const IdentifierInfo *Ident = Param->getIdentifier())
        if (Ident->getName() == UnqualName) {
          Collision = true;
          return true;
        }
    return false;
  }

  /// Checks the name written for record, enum, template specialization and
  /// typedef type locations, unless \p Elaborated says the name is part of an
  /// elaborated type, and then continues with the base class traversal.
  /// \returns false to abort the traversal when a collision was found.
  bool TraverseTypeLoc(TypeLoc TL, bool Elaborated = false) {
    if (TL.isNull())
      return true;

    if (!Elaborated) {
      switch (TL.getTypeLocClass()) {
      case TypeLoc::Record:
        if (VisitUnqualName(
                TL.getAs<RecordTypeLoc>().getTypePtr()->getDecl()->getName()))
          return false;
        break;
      case TypeLoc::Enum:
        if (VisitUnqualName(
                TL.getAs<EnumTypeLoc>().getTypePtr()->getDecl()->getName()))
          return false;
        break;
      case TypeLoc::TemplateSpecialization: {
        // getAsTemplateDecl() returns null when the template name does not
        // refer to one specific declaration. There is no name to compare in
        // that case, so only check when a declaration is available.
        const auto *Template = TL.getAs<TemplateSpecializationTypeLoc>()
                                   .getTypePtr()
                                   ->getTemplateName()
                                   .getAsTemplateDecl();
        if (Template && VisitUnqualName(Template->getName()))
          return false;
        break;
      }
      case TypeLoc::Typedef:
        if (VisitUnqualName(
                TL.getAs<TypedefTypeLoc>().getTypePtr()->getDecl()->getName()))
          return false;
        break;
      default:
        break;
      }
    }

    return RecursiveASTVisitor<UnqualNameVisitor>::TraverseTypeLoc(TL);
  }

  // Replace the base method in order to call our own
  // TraverseTypeLoc().
  bool TraverseQualifiedTypeLoc(QualifiedTypeLoc TL) {
    return TraverseTypeLoc(TL.getUnqualifiedLoc());
  }

  // Replace the base version to inform TraverseTypeLoc that the type is
  // elaborated.
  bool TraverseElaboratedTypeLoc(ElaboratedTypeLoc TL) {
    if (TL.getQualifierLoc() &&
        !TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()))
      return false;
    return TraverseTypeLoc(TL.getNamedTypeLoc(), true);
  }

  /// Checks references to declarations that appear inside the return type.
  /// \returns true (continue traversal) if the reference is qualified, is not
  /// named by a plain identifier, or does not collide with a parameter name.
  bool VisitDeclRefExpr(DeclRefExpr *S) {
    DeclarationName Name = S->getNameInfo().getName();
    return S->getQualifierLoc() || !Name.isIdentifier() ||
           !VisitUnqualName(Name.getAsIdentifierInfo()->getName());
  }

private:
  /// The function whose parameter names are checked against.
  const FunctionDecl &F;
};
} // namespace

// Diagnostic text used for every report issued by this check in this file,
// both for the plain warnings and for the warning that carries the fix-its.
constexpr llvm::StringLiteral Message =
    "use a trailing return type for this function";

/// Maps \p Loc out of macro expansions: as long as the location is a macro
/// ID it is replaced by the begin of its immediate expansion range.
/// \returns a location that is not a macro ID; \p Loc itself if it was not a
/// macro ID to begin with.
static SourceLocation expandIfMacroId(SourceLocation Loc,
                                      const SourceManager &SM) {
  // Nested expansions are peeled off one level per iteration.
  while (Loc.isMacroID())
    Loc = SM.getImmediateExpansionRange(Loc).getBegin();
  assert(!Loc.isMacroID() &&
         "SourceLocation must not be a macro ID after recursive expansion");
  return Loc;
}

/// Determines where the " -> type" text has to be inserted for \p F.
///
/// If the function has a valid exception specification range, the position
/// right after its last token is used. Otherwise the search starts right
/// after the closing parenthesis of the parameter list and moves past any
/// directly following '&', '&&', 'const', 'volatile' and 'restrict' tokens.
///
/// \returns the insertion location, or an invalid location if the closing
/// parenthesis is located inside a macro.
SourceLocation UseTrailingReturnTypeCheck::findTrailingReturnTypeSourceLocation(
    const FunctionDecl &F, const FunctionTypeLoc &FTL, const ASTContext &Ctx,
    const SourceManager &SM, const LangOptions &LangOpts) {
  // An exception specification is written after the qualifiers, so its end is
  // taken as the insertion point without any further lexing.
  const SourceRange ExceptionSpec = F.getExceptionSpecSourceRange();
  if (ExceptionSpec.isValid())
    return Lexer::getLocForEndOfToken(ExceptionSpec.getEnd(), 0, SM, LangOpts);

  // Lexing the file buffer starting from a position inside a macro is
  // dangerous, so give up in that case.
  const SourceLocation RParen = FTL.getRParenLoc();
  if (RParen.isMacroID())
    return {};

  SourceLocation InsertLoc =
      Lexer::getLocForEndOfToken(RParen, 0, SM, LangOpts);

  // Raw-lex the file buffer from just behind the closing parenthesis.
  const std::pair<FileID, unsigned> Decomposed = SM.getDecomposedLoc(InsertLoc);
  const StringRef Buffer = SM.getBufferData(Decomposed.first);
  Lexer RawLexer(SM.getLocForStartOfFile(Decomposed.first), LangOpts,
                 Buffer.begin(), Buffer.data() + Decomposed.second,
                 Buffer.end());

  Token Tok;
  while (!RawLexer.LexFromRawLexer(Tok)) {
    // The raw lexer does not tell keywords apart from identifiers; look the
    // spelling up so that the token gets its real kind.
    if (Tok.is(tok::raw_identifier)) {
      IdentifierInfo &Info = Ctx.Idents.get(
          StringRef(SM.getCharacterData(Tok.getLocation()), Tok.getLength()));
      Tok.setIdentifierInfo(&Info);
      Tok.setKind(Info.getTokenID());
    }

    // Anything other than a CV or ref qualifier ends the scan.
    if (!Tok.isOneOf(tok::amp, tok::ampamp, tok::kw_const, tok::kw_volatile,
                     tok::kw_restrict))
      break;

    InsertLoc = Tok.getEndLoc();
  }
  return InsertLoc;
}

/// \returns true if \p T is one of the keywords 'const', 'volatile' or
/// 'restrict'.
static bool IsCVR(Token T) {
  return T.is(tok::kw_const) || T.is(tok::kw_volatile) ||
         T.is(tok::kw_restrict);
}

/// \returns true if \p T is one of the keywords 'constexpr', 'inline',
/// 'extern', 'static', 'friend' or 'virtual'.
static bool IsSpecifier(Token T) {
  switch (T.getKind()) {
  case tok::kw_constexpr:
  case tok::kw_inline:
  case tok::kw_extern:
  case tok::kw_static:
  case tok::kw_friend:
  case tok::kw_virtual:
    return true;
  default:
    return false;
  }
}

/// Classifies \p Tok as a CV qualifier, a specifier, or something else.
///
/// The token is fed through \p PP so that every token the preprocessor
/// produces for it is inspected. The resulting flags \c isQualifier and
/// \c isSpecifier are only left set if all produced tokens are of that kind.
///
/// \returns the classified token, or llvm::None if the produced tokens mix
/// more than one of the categories qualifier, specifier and other.
static llvm::Optional<ClassifiedToken>
classifyToken(const FunctionDecl &F, Preprocessor &PP, Token Tok) {
  ClassifiedToken Classified;
  Classified.T = Tok;
  Classified.isQualifier = true;
  Classified.isSpecifier = true;

  bool SawQualifier = false;
  bool SawSpecifier = false;
  bool SawOther = false;

  // Terminate the injected stream with an explicit eof token so the loop
  // below knows where to stop.
  Token EndMarker;
  EndMarker.startToken();
  EndMarker.setKind(tok::eof);
  SmallVector<Token, 2> Stream{Tok, EndMarker};

  // FIXME: do not report these token to Preprocessor.TokenWatcher.
  PP.EnterTokenStream(Stream, false, /*IsReinject=*/false);
  for (;;) {
    Token Lexed;
    PP.Lex(Lexed);
    if (Lexed.is(tok::eof))
      break;

    const bool IsQual = IsCVR(Lexed);
    const bool IsSpec = IsSpecifier(Lexed);
    if (!IsQual)
      Classified.isQualifier = false;
    if (!IsSpec)
      Classified.isSpecifier = false;
    SawQualifier = SawQualifier || IsQual;
    SawSpecifier = SawSpecifier || IsSpec;
    SawOther = SawOther || (!IsQual && !IsSpec);
  }

  // If the Token/Macro contains more than one type of tokens, we would need
  // to split the macro in order to move parts to the trailing return type.
  const int KindsSeen = static_cast<int>(SawQualifier) +
                        static_cast<int>(SawSpecifier) +
                        static_cast<int>(SawOther);
  if (KindsSeen > 1)
    return llvm::None;

  return Classified;
}

/// Raw-lexes everything from the beginning of the declaration of \p F up to
/// (not including) its name and classifies each token with classifyToken().
///
/// \returns the classified tokens in source order, or llvm::None after
/// emitting the diagnostic without a fix if a function-like macro (or a
/// macro without available info) is encountered or a token cannot be
/// classified.
llvm::Optional<SmallVector<ClassifiedToken, 8>>
UseTrailingReturnTypeCheck::classifyTokensBeforeFunctionName(
    const FunctionDecl &F, const ASTContext &Ctx, const SourceManager &SM,
    const LangOptions &LangOpts) {
  const SourceLocation DeclBegin = expandIfMacroId(F.getBeginLoc(), SM);
  const SourceLocation NameBegin = expandIfMacroId(F.getLocation(), SM);

  // Set up a raw lexer on the file buffer, positioned at the declaration's
  // first token.
  const std::pair<FileID, unsigned> Decomposed = SM.getDecomposedLoc(DeclBegin);
  const StringRef Buffer = SM.getBufferData(Decomposed.first);
  Lexer RawLexer(SM.getLocForStartOfFile(Decomposed.first), LangOpts,
                 Buffer.begin(), Buffer.data() + Decomposed.second,
                 Buffer.end());

  SmallVector<ClassifiedToken, 8> Result;
  Token Tok;
  while (!RawLexer.LexFromRawLexer(Tok) &&
         SM.isBeforeInTranslationUnit(Tok.getLocation(), NameBegin)) {
    if (Tok.is(tok::raw_identifier)) {
      // Resolve the raw identifier so keywords and macros can be recognized.
      IdentifierInfo &Info = Ctx.Idents.get(
          StringRef(SM.getCharacterData(Tok.getLocation()), Tok.getLength()));

      if (Info.hasMacroDefinition()) {
        const MacroInfo *MI = PP->getMacroInfo(&Info);
        // Cannot handle function style macros.
        if (!MI || MI->isFunctionLike()) {
          diag(F.getLocation(), Message);
          return llvm::None;
        }
      }

      Tok.setIdentifierInfo(&Info);
      Tok.setKind(Info.getTokenID());
    }

    llvm::Optional<ClassifiedToken> Classified = classifyToken(F, *PP, Tok);
    if (!Classified) {
      diag(F.getLocation(), Message);
      return llvm::None;
    }
    Result.push_back(*Classified);
  }

  return Result;
}

/// \returns true if \p Type itself has local qualifiers, or if it is a
/// pointer or reference type whose pointee type (checked recursively) has
/// any.
static bool hasAnyNestedLocalQualifiers(QualType Type) {
  if (Type.hasLocalQualifiers())
    return true;
  if (Type->isPointerType())
    return hasAnyNestedLocalQualifiers(
        Type->castAs<PointerType>()->getPointeeType());
  if (Type->isReferenceType())
    return hasAnyNestedLocalQualifiers(
        Type->castAs<ReferenceType>()->getPointeeType());
  return false;
}

/// Computes the source range of the return type of \p F as written,
/// including const/volatile/restrict qualifiers directly to its left and
/// right.
///
/// \param ReturnLoc the TypeLoc of the return type; used to detect
/// constrained 'auto', 'decltype(auto)' and 'decltype(...)' return types,
/// whose range has to be extended up to the function name.
///
/// \returns the (macro-free) range to be replaced by 'auto', or an invalid
/// range if the return type range is unknown or the tokens before the
/// function name could not be classified. The diagnostic without a fix is
/// emitted on those failure paths (here for the unknown range, and by
/// classifyTokensBeforeFunctionName() otherwise).
SourceRange UseTrailingReturnTypeCheck::findReturnTypeAndCVSourceRange(
    const FunctionDecl &F, const TypeLoc &ReturnLoc, const ASTContext &Ctx,
    const SourceManager &SM, const LangOptions &LangOpts) {

  // We start with the range of the return type and expand to neighboring
  // qualifiers (const, volatile and restrict).
  SourceRange ReturnTypeRange = F.getReturnTypeSourceRange();
  if (ReturnTypeRange.isInvalid()) {
    // Happens if e.g. clang cannot resolve all includes and the return type is
    // unknown.
    diag(F.getLocation(), Message);
    return {};
  }

  // If the return type is a constrained 'auto' or 'decltype(auto)', we need to
  // include the tokens after the concept. Unfortunately, the source range of an
  // AutoTypeLoc, if it is constrained, does not include the 'auto' or
  // 'decltype(auto)'. If the return type is a plain 'decltype(...)', the
  // source range only contains the first 'decltype' token.
  auto ATL = ReturnLoc.getAs<AutoTypeLoc>();
  if ((ATL && (ATL.isConstrained() ||
               ATL.getAutoKeyword() == AutoTypeKeyword::DecltypeAuto)) ||
      ReturnLoc.getAs<DecltypeTypeLoc>()) {
    SourceLocation End =
        expandIfMacroId(ReturnLoc.getSourceRange().getEnd(), SM);
    SourceLocation BeginNameF = expandIfMacroId(F.getLocation(), SM);

    // Extend the ReturnTypeRange until the last token before the function
    // name.
    std::pair<FileID, unsigned> Loc = SM.getDecomposedLoc(End);
    StringRef File = SM.getBufferData(Loc.first);
    const char *TokenBegin = File.data() + Loc.second;
    Lexer Lexer(SM.getLocForStartOfFile(Loc.first), LangOpts, File.begin(),
                TokenBegin, File.end());
    Token T;
    SourceLocation LastTLoc = End;
    while (!Lexer.LexFromRawLexer(T) &&
           SM.isBeforeInTranslationUnit(T.getLocation(), BeginNameF)) {
      LastTLoc = T.getLocation();
    }
    ReturnTypeRange.setEnd(LastTLoc);
  }

  // If the return type has no local qualifiers, its source range is accurate.
  if (!hasAnyNestedLocalQualifiers(F.getReturnType()))
    return ReturnTypeRange;

  // Include qualifiers to the left and right of the return type.
  llvm::Optional<SmallVector<ClassifiedToken, 8>> MaybeTokens =
      classifyTokensBeforeFunctionName(F, Ctx, SM, LangOpts);
  if (!MaybeTokens)
    return {};
  const SmallVector<ClassifiedToken, 8> &Tokens = *MaybeTokens;

  // The classified tokens carry expanded (non-macro) locations, so bring the
  // range into the same form before comparing against them.
  ReturnTypeRange.setBegin(expandIfMacroId(ReturnTypeRange.getBegin(), SM));
  ReturnTypeRange.setEnd(expandIfMacroId(ReturnTypeRange.getEnd(), SM));

  bool ExtendedLeft = false;
  for (size_t I = 0; I < Tokens.size(); I++) {
    // If we found the beginning of the return type, include left qualifiers.
    if (!SM.isBeforeInTranslationUnit(Tokens[I].T.getLocation(),
                                      ReturnTypeRange.getBegin()) &&
        !ExtendedLeft) {
      // Walk backwards over the run of qualifiers directly preceding token I.
      // J is always one past the token being looked at, which keeps the index
      // unsigned and avoids any narrowing conversion.
      for (size_t J = I; J > 0 && Tokens[J - 1].isQualifier; J--)
        ReturnTypeRange.setBegin(Tokens[J - 1].T.getLocation());
      ExtendedLeft = true;
    }
    // If we found the end of the return type, include right qualifiers.
    if (SM.isBeforeInTranslationUnit(ReturnTypeRange.getEnd(),
                                     Tokens[I].T.getLocation())) {
      for (size_t J = I; J < Tokens.size() && Tokens[J].isQualifier; J++)
        ReturnTypeRange.setEnd(Tokens[J].T.getLocation());
      break;
    }
  }

  assert(!ReturnTypeRange.getBegin().isMacroID() &&
         "Return type source range begin must not be a macro");
  assert(!ReturnTypeRange.getEnd().isMacroID() &&
         "Return type source range end must not be a macro");
  return ReturnTypeRange;
}

void UseTrailingReturnTypeCheck::keepSpecifiers(
    std::string &ReturnType, std::string &Auto, SourceRange ReturnTypeCVRange,
    const FunctionDecl &F, const FriendDecl *Fr, const ASTContext &Ctx,
    const SourceManager &SM, const LangOptions &LangOpts) {
  // Check if there are specifiers inside the return type. E.g. unsigned
  // inline int.
  const auto *M = dyn_cast<CXXMethodDecl>(&F);
  if (!F.isConstexpr() && !F.isInlineSpecified() &&
      F.getStorageClass() != SC_Extern && F.getStorageClass() != SC_Static &&
      !Fr && !(M && M->isVirtualAsWritten()))
    return;

  // Tokenize return type. If it contains macros which contain a mix of
  // qualifiers, specifiers and types, give up.
  llvm::Optional<SmallVector<ClassifiedToken, 8>> MaybeTokens =
      classifyTokensBeforeFunctionName(F, Ctx, SM, LangOpts);
  if (!MaybeTokens)
    return;

  // Find specifiers, remove them from the return type, add them to 'auto'.
  unsigned int ReturnTypeBeginOffset =
      SM.getDecomposedLoc(ReturnTypeCVRange.getBegin()).second;
  size_t InitialAutoLength = Auto.size();
  unsigned int DeletedChars = 0;
  for (ClassifiedToken CT : *MaybeTokens) {
    if (SM.isBeforeInTranslationUnit(CT.T.getLocation(),
                                     ReturnTypeCVRange.getBegin()) ||
        SM.isBeforeInTranslationUnit(ReturnTypeCVRange.getEnd(),
                                     CT.T.getLocation()))
      continue;
    if (!CT.isSpecifier)
      continue;

    // Add the token to 'auto' and remove it from the return type, including
    // any whitespace following the token.
    unsigned int TOffset = SM.getDecomposedLoc(CT.T.getLocation()).second;
    assert(TOffset >= ReturnTypeBeginOffset &&
           "Token location must be after the beginning of the return type");
    unsigned int TOffsetInRT = TOffset - ReturnTypeBeginOffset - DeletedChars;
    unsigned int TLengthWithWS = CT.T.getLength();
    while (TOffsetInRT + TLengthWithWS < ReturnType.size() &&
           llvm::isSpace(ReturnType[TOffsetInRT + TLengthWithWS]))
      TLengthWithWS++;
    std::string Specifier = ReturnType.substr(TOffsetInRT, TLengthWithWS);
    if (!llvm::isSpace(Specifier.back()))
      Specifier.push_back(' ');
    Auto.insert(Auto.size() - InitialAutoLength, Specifier);
    ReturnType.erase(TOffsetInRT, TLengthWithWS);
    DeletedChars += TLengthWithWS;
  }
}

/// Registers two matchers: one binding every candidate function as "Func",
/// and one binding friend declarations that contain such a function as
/// "Friend". Functions that already have a trailing return type, return
/// void, are conversion functions or are implicit methods are excluded.
void UseTrailingReturnTypeCheck::registerMatchers(MatchFinder *Finder) {
  auto Excluded = anyOf(hasTrailingReturn(), returns(voidType()),
                        cxxConversionDecl(), cxxMethodDecl(isImplicit()));
  auto Candidate = functionDecl(unless(Excluded)).bind("Func");

  Finder->addMatcher(Candidate, this);
  Finder->addMatcher(friendDecl(hasDescendant(Candidate)).bind("Friend"),
                     this);
}

/// Stores the preprocessor so that it is available when matches are checked;
/// \p SM and \p ModuleExpanderPP are not used.
void UseTrailingReturnTypeCheck::registerPPCallbacks(
    const SourceManager &SM, Preprocessor *PP, Preprocessor *ModuleExpanderPP) {
  // The parameter hides the member of the same name, hence the qualified
  // name on the left-hand side.
  UseTrailingReturnTypeCheck::PP = PP;
}

/// Handles one match: reports the function bound to "Func" and, where the
/// rewrite is possible, attaches fix-its that replace the written return type
/// by 'auto' (plus any specifiers that were inside it) and insert
/// " -> <return type>" after the parameter list and its qualifiers.
///
/// Functions that return plain 'auto' are skipped silently. Several cases are
/// reported without a fix: function or member pointer return types, a missing
/// FunctionTypeLoc, an unavailable insertion location, and a return type that
/// uses an unqualified name equal to a parameter name.
void UseTrailingReturnTypeCheck::check(const MatchFinder::MatchResult &Result) {
  assert(PP && "Expected registerPPCallbacks() to have been called before so "
               "preprocessor is available");

  const auto *Func = Result.Nodes.getNodeAs<FunctionDecl>("Func");
  const auto *Friend = Result.Nodes.getNodeAs<FriendDecl>("Friend");
  assert(Func && "Matcher is expected to find only FunctionDecls");

  const SourceLocation NameLoc = Func->getLocation();
  if (NameLoc.isInvalid())
    return;

  const QualType DeclaredReturnType = Func->getDeclaredReturnType();

  // Skip functions which return just 'auto'.
  if (const auto *AT = DeclaredReturnType->getAs<AutoType>()) {
    const bool IsPlainAuto =
        !AT->isConstrained() && AT->getKeyword() == AutoTypeKeyword::Auto &&
        !hasAnyNestedLocalQualifiers(DeclaredReturnType);
    if (IsPlainAuto)
      return;
  }

  // TODO: implement those
  if (DeclaredReturnType->isFunctionPointerType() ||
      DeclaredReturnType->isMemberFunctionPointerType() ||
      DeclaredReturnType->isMemberPointerType()) {
    diag(NameLoc, Message);
    return;
  }

  const ASTContext &Ctx = *Result.Context;
  const SourceManager &SM = *Result.SourceManager;
  const LangOptions &LangOpts = getLangOpts();

  const TypeSourceInfo *TSI = Func->getTypeSourceInfo();
  if (!TSI)
    return;

  FunctionTypeLoc FTL =
      TSI->getTypeLoc().IgnoreParens().getAs<FunctionTypeLoc>();
  if (!FTL) {
    // FIXME: This may happen if we have __attribute__((...)) on the function.
    // We abort for now. Remove this when the function type location gets
    // available in clang.
    diag(NameLoc, Message);
    return;
  }

  const SourceLocation InsertionLoc =
      findTrailingReturnTypeSourceLocation(*Func, FTL, Ctx, SM, LangOpts);
  if (InsertionLoc.isInvalid()) {
    diag(NameLoc, Message);
    return;
  }

  // The return type is taken from the source text rather than printed from
  // the declared type, because printing would discard the user's formatting
  // and the written order of const, volatile, type, whitespace, space
  // before & ... .
  const SourceRange ReturnTypeCVRange = findReturnTypeAndCVSourceRange(
      *Func, FTL.getReturnLoc(), Ctx, SM, LangOpts);
  if (ReturnTypeCVRange.isInvalid())
    return;

  // Check if unqualified names in the return type conflict with other entities
  // after the rewrite.
  // FIXME: this could be done better, by performing a lookup of all
  // unqualified names in the return type in the scope of the function. If the
  // lookup finds a different entity than the original entity identified by the
  // name, then we can either not perform a rewrite or explicitly qualify the
  // entity. Such entities could be function parameter names, (inherited) class
  // members, template parameters, etc.
  UnqualNameVisitor Visitor{*Func};
  Visitor.TraverseTypeLoc(FTL.getReturnLoc());
  if (Visitor.Collision) {
    diag(NameLoc, Message);
    return;
  }

  // Look at the single character behind the return type to decide whether
  // 'auto' needs its own trailing space.
  const SourceLocation ReturnTypeEnd =
      Lexer::getLocForEndOfToken(ReturnTypeCVRange.getEnd(), 0, SM, LangOpts);
  const StringRef CharAfterReturnType = Lexer::getSourceText(
      CharSourceRange::getCharRange(ReturnTypeEnd,
                                    ReturnTypeEnd.getLocWithOffset(1)),
      SM, LangOpts);
  const bool FollowedBySpace = !CharAfterReturnType.empty() &&
                               llvm::isSpace(CharAfterReturnType.front());

  std::string Auto = FollowedBySpace ? "auto" : "auto ";
  std::string ReturnType =
      std::string(tooling::fixit::getText(ReturnTypeCVRange, Ctx));
  keepSpecifiers(ReturnType, Auto, ReturnTypeCVRange, *Func, Friend, Ctx, SM,
                 LangOpts);

  diag(NameLoc, Message)
      << FixItHint::CreateReplacement(ReturnTypeCVRange, Auto)
      << FixItHint::CreateInsertion(InsertionLoc, " -> " + ReturnType);
}

} // namespace modernize
} // namespace tidy
} // namespace clang