1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "CheckTraceVisitor.h"
6 
7 #include <vector>
8 
9 #include "Config.h"
10 
11 using namespace clang;
12 
CheckTraceVisitor(CXXMethodDecl * trace,RecordInfo * info,RecordCache * cache)13 CheckTraceVisitor::CheckTraceVisitor(CXXMethodDecl* trace,
14                                      RecordInfo* info,
15                                      RecordCache* cache)
16     : trace_(trace),
17       info_(info),
18       cache_(cache),
19       delegates_to_traceimpl_(false) {
20 }
21 
delegates_to_traceimpl() const22 bool CheckTraceVisitor::delegates_to_traceimpl() const {
23   return delegates_to_traceimpl_;
24 }
25 
VisitMemberExpr(MemberExpr * member)26 bool CheckTraceVisitor::VisitMemberExpr(MemberExpr* member) {
27   // In weak callbacks, consider any occurrence as a correct usage.
28   // TODO: We really want to require that isAlive is checked on manually
29   // processed weak fields.
30   if (IsWeakCallback()) {
31     if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl()))
32       FoundField(field);
33   }
34   return true;
35 }
36 
VisitCallExpr(CallExpr * call)37 bool CheckTraceVisitor::VisitCallExpr(CallExpr* call) {
38   // In weak callbacks we don't check calls (see VisitMemberExpr).
39   if (IsWeakCallback())
40     return true;
41 
42   Expr* callee = call->getCallee();
43 
44   // Trace calls from a templated derived class result in a
45   // DependentScopeMemberExpr because the concrete trace call depends on the
46   // instantiation of any shared template parameters. In this case the call is
47   // "unresolved" and we resort to comparing the syntactic type names.
48   if (CXXDependentScopeMemberExpr* expr =
49       dyn_cast<CXXDependentScopeMemberExpr>(callee)) {
50     CheckCXXDependentScopeMemberExpr(call, expr);
51     return true;
52   }
53 
54   // A tracing call will have either a |visitor| or a |m_field| argument.
55   // A registerWeakMembers call will have a |this| argument.
56   if (call->getNumArgs() != 1)
57     return true;
58   Expr* arg = call->getArg(0);
59 
60   if (UnresolvedMemberExpr* expr = dyn_cast<UnresolvedMemberExpr>(callee)) {
61     // This could be a trace call of a base class, as explained in the
62     // comments of CheckTraceBaseCall().
63     if (CheckTraceBaseCall(call))
64       return true;
65 
66     if (expr->getMemberName().getAsString() == kRegisterWeakMembersName)
67       MarkAllWeakMembersTraced();
68 
69     QualType base = expr->getBaseType();
70     if (!base->isPointerType())
71       return true;
72     CXXRecordDecl* decl = base->getPointeeType()->getAsCXXRecordDecl();
73     if (decl)
74       CheckTraceFieldCall(expr->getMemberName().getAsString(), decl, arg);
75     if (Config::IsTraceImplName(expr->getMemberName().getAsString()))
76       delegates_to_traceimpl_ = true;
77     return true;
78   }
79 
80   if (CXXMemberCallExpr* expr = dyn_cast<CXXMemberCallExpr>(call)) {
81     if (CheckTraceFieldMemberCall(expr) || CheckRegisterWeakMembers(expr))
82       return true;
83 
84     if (Config::IsTraceImplName(expr->getMethodDecl()->getNameAsString())) {
85       delegates_to_traceimpl_ = true;
86       return true;
87     }
88   }
89 
90   CheckTraceBaseCall(call);
91   return true;
92 }
93 
IsTraceCallName(const std::string & name)94 bool CheckTraceVisitor::IsTraceCallName(const std::string& name) {
95   if (trace_->getName() == kTraceImplName)
96     return name == kTraceName;
97   if (trace_->getName() == kTraceAfterDispatchImplName)
98     return name == kTraceAfterDispatchName;
99   // Currently, a manually dispatched class cannot have mixin bases (having
100   // one would add a vtable which we explicitly check against). This means
101   // that we can only make calls to a trace method of the same name. Revisit
102   // this if our mixin/vtable assumption changes.
103   return name == trace_->getName();
104 }
105 
GetDependentTemplatedDecl(CXXDependentScopeMemberExpr * expr)106 CXXRecordDecl* CheckTraceVisitor::GetDependentTemplatedDecl(
107     CXXDependentScopeMemberExpr* expr) {
108   NestedNameSpecifier* qual = expr->getQualifier();
109   if (!qual)
110     return 0;
111 
112   const Type* type = qual->getAsType();
113   if (!type)
114     return 0;
115 
116   return RecordInfo::GetDependentTemplatedDecl(*type);
117 }
118 
119 namespace {
120 
121 class FindFieldVisitor : public RecursiveASTVisitor<FindFieldVisitor> {
122  public:
123   FindFieldVisitor();
124   MemberExpr* member() const;
125   FieldDecl* field() const;
126   bool TraverseMemberExpr(MemberExpr* member);
127 
128  private:
129   MemberExpr* member_;
130   FieldDecl* field_;
131 };
132 
FindFieldVisitor()133 FindFieldVisitor::FindFieldVisitor()
134     : member_(0),
135       field_(0) {
136 }
137 
member() const138 MemberExpr* FindFieldVisitor::member() const {
139   return member_;
140 }
141 
field() const142 FieldDecl* FindFieldVisitor::field() const {
143   return field_;
144 }
145 
TraverseMemberExpr(MemberExpr * member)146 bool FindFieldVisitor::TraverseMemberExpr(MemberExpr* member) {
147   if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl())) {
148     member_ = member;
149     field_ = field;
150     return false;
151   }
152   return true;
153 }
154 
155 }  // namespace
156 
CheckCXXDependentScopeMemberExpr(CallExpr * call,CXXDependentScopeMemberExpr * expr)157 void CheckTraceVisitor::CheckCXXDependentScopeMemberExpr(
158     CallExpr* call,
159     CXXDependentScopeMemberExpr* expr) {
160   std::string fn_name = expr->getMember().getAsString();
161 
162   // Check for VisitorDispatcher::trace(field) and
163   // VisitorDispatcher::registerWeakMembers.
164   if (!expr->isImplicitAccess()) {
165     if (DeclRefExpr* base_decl = dyn_cast<DeclRefExpr>(expr->getBase())) {
166       if (Config::IsVisitorDispatcherType(base_decl->getType())) {
167         if (call->getNumArgs() == 1 && fn_name == kTraceName) {
168           FindFieldVisitor finder;
169           finder.TraverseStmt(call->getArg(0));
170           if (finder.field())
171             FoundField(finder.field());
172 
173           return;
174         } else if (call->getNumArgs() == 1 &&
175                    fn_name == kRegisterWeakMembersName) {
176           MarkAllWeakMembersTraced();
177         }
178       }
179     }
180   }
181 
182   CXXRecordDecl* tmpl = GetDependentTemplatedDecl(expr);
183   if (!tmpl)
184     return;
185 
186   // Check for Super<T>::trace(visitor)
187   if (call->getNumArgs() == 1 && IsTraceCallName(fn_name)) {
188     RecordInfo::Bases::iterator it = info_->GetBases().begin();
189     for (; it != info_->GetBases().end(); ++it) {
190       if (it->first->getName() == tmpl->getName())
191         it->second.MarkTraced();
192     }
193   }
194 
195   // Check for TraceIfNeeded<T>::trace(visitor, &field)
196   if (call->getNumArgs() == 2 && fn_name == kTraceName &&
197       tmpl->getName() == kTraceIfNeededName) {
198     FindFieldVisitor finder;
199     finder.TraverseStmt(call->getArg(1));
200     if (finder.field())
201       FoundField(finder.field());
202   }
203 }
204 
CheckTraceBaseCall(CallExpr * call)205 bool CheckTraceVisitor::CheckTraceBaseCall(CallExpr* call) {
206   // Checks for "Base::trace(visitor)"-like calls.
207 
208   // Checking code for these two variables is shared among MemberExpr* case
209   // and UnresolvedMemberCase* case below.
210   //
211   // For example, if we've got "Base::trace(visitor)" as |call|,
212   // callee_record will be "Base", and func_name will be "trace".
213   CXXRecordDecl* callee_record = nullptr;
214   std::string func_name;
215 
216   if (MemberExpr* callee = dyn_cast<MemberExpr>(call->getCallee())) {
217     if (!callee->hasQualifier())
218       return false;
219 
220     FunctionDecl* trace_decl =
221         dyn_cast<FunctionDecl>(callee->getMemberDecl());
222     if (!trace_decl || !Config::IsTraceMethod(trace_decl))
223       return false;
224 
225     const Type* type = callee->getQualifier()->getAsType();
226     if (!type)
227       return false;
228 
229     callee_record = type->getAsCXXRecordDecl();
230     func_name = trace_decl->getName();
231   } else if (UnresolvedMemberExpr* callee =
232              dyn_cast<UnresolvedMemberExpr>(call->getCallee())) {
233     // Callee part may become unresolved if the type of the argument
234     // ("visitor") is a template parameter and the called function is
235     // overloaded (i.e. trace(Visitor*) and
236     // trace(InlinedGlobalMarkingVisitor)).
237     //
238     // Here, we try to find a function that looks like trace() from the
239     // candidate overloaded functions, and if we find one, we assume it is
240     // called here.
241 
242     CXXMethodDecl* trace_decl = nullptr;
243     for (NamedDecl* named_decl : callee->decls()) {
244       if (CXXMethodDecl* method_decl = dyn_cast<CXXMethodDecl>(named_decl)) {
245         if (Config::IsTraceMethod(method_decl)) {
246           trace_decl = method_decl;
247           break;
248         }
249       }
250     }
251     if (!trace_decl)
252       return false;
253 
254     // Check if the passed argument is named "visitor".
255     if (call->getNumArgs() != 1)
256       return false;
257     DeclRefExpr* arg = dyn_cast<DeclRefExpr>(call->getArg(0));
258     if (!arg || arg->getNameInfo().getAsString() != kVisitorVarName)
259       return false;
260 
261     callee_record = trace_decl->getParent();
262     func_name = callee->getMemberName().getAsString();
263   }
264 
265   if (!callee_record)
266     return false;
267 
268   if (!IsTraceCallName(func_name))
269     return false;
270 
271   for (auto& base : info_->GetBases()) {
272     // We want to deal with omitted trace() function in an intermediary
273     // class in the class hierarchy, e.g.:
274     //     class A : public GarbageCollected<A> { trace() { ... } };
275     //     class B : public A { /* No trace(); have nothing to trace. */ };
276     //     class C : public B { trace() { B::trace(visitor); } }
277     // where, B::trace() is actually A::trace(), and in some cases we get
278     // A as |callee_record| instead of B. We somehow need to mark B as
279     // traced if we find A::trace() call.
280     //
281     // To solve this, here we keep going up the class hierarchy as long as
282     // they are not required to have a trace method. The implementation is
283     // a simple DFS, where |base_records| represents the set of base classes
284     // we need to visit.
285 
286     std::vector<CXXRecordDecl*> base_records;
287     base_records.push_back(base.first);
288 
289     while (!base_records.empty()) {
290       CXXRecordDecl* base_record = base_records.back();
291       base_records.pop_back();
292 
293       if (base_record == callee_record) {
294         // If we find a matching trace method, pretend the user has written
295         // a correct trace() method of the base; in the example above, we
296         // find A::trace() here and mark B as correctly traced.
297         base.second.MarkTraced();
298         return true;
299       }
300 
301       if (RecordInfo* base_info = cache_->Lookup(base_record)) {
302         if (!base_info->RequiresTraceMethod()) {
303           // If this base class is not required to have a trace method, then
304           // the actual trace method may be defined in an ancestor.
305           for (auto& inner_base : base_info->GetBases())
306             base_records.push_back(inner_base.first);
307         }
308       }
309     }
310   }
311 
312   return false;
313 }
314 
CheckTraceFieldMemberCall(CXXMemberCallExpr * call)315 bool CheckTraceVisitor::CheckTraceFieldMemberCall(CXXMemberCallExpr* call) {
316   return CheckTraceFieldCall(call->getMethodDecl()->getNameAsString(),
317                              call->getRecordDecl(),
318                              call->getArg(0));
319 }
320 
CheckTraceFieldCall(const std::string & name,CXXRecordDecl * callee,Expr * arg)321 bool CheckTraceVisitor::CheckTraceFieldCall(
322     const std::string& name,
323     CXXRecordDecl* callee,
324     Expr* arg) {
325   if (name != kTraceName || !Config::IsVisitor(callee->getName()))
326     return false;
327 
328   FindFieldVisitor finder;
329   finder.TraverseStmt(arg);
330   if (finder.field())
331     FoundField(finder.field());
332 
333   return true;
334 }
335 
CheckRegisterWeakMembers(CXXMemberCallExpr * call)336 bool CheckTraceVisitor::CheckRegisterWeakMembers(CXXMemberCallExpr* call) {
337   CXXMethodDecl* fn = call->getMethodDecl();
338   if (fn->getName() != kRegisterWeakMembersName)
339     return false;
340 
341   if (fn->isTemplateInstantiation()) {
342     const TemplateArgumentList& args =
343         *fn->getTemplateSpecializationInfo()->TemplateArguments;
344     // The second template argument is the callback method.
345     if (args.size() > 1 &&
346         args[1].getKind() == TemplateArgument::Declaration) {
347       if (FunctionDecl* callback =
348           dyn_cast<FunctionDecl>(args[1].getAsDecl())) {
349         if (callback->hasBody()) {
350           CheckTraceVisitor nested_visitor(nullptr, info_, nullptr);
351           nested_visitor.TraverseStmt(callback->getBody());
352         }
353       }
354     }
355   }
356   return true;
357 }
358 
IsWeakCallback() const359 bool CheckTraceVisitor::IsWeakCallback() const {
360   return !trace_;
361 }
362 
MarkTraced(RecordInfo::Fields::iterator it)363 void CheckTraceVisitor::MarkTraced(RecordInfo::Fields::iterator it) {
364   // In a weak callback we can't mark strong fields as traced.
365   if (IsWeakCallback() && !it->second.edge()->IsWeakMember())
366     return;
367   it->second.MarkTraced();
368 }
369 
FoundField(FieldDecl * field)370 void CheckTraceVisitor::FoundField(FieldDecl* field) {
371   if (Config::IsTemplateInstantiation(info_->record())) {
372     // Pointer equality on fields does not work for template instantiations.
373     // The trace method refers to fields of the template definition which
374     // are different from the instantiated fields that need to be traced.
375     const std::string& name = field->getNameAsString();
376     for (RecordInfo::Fields::iterator it = info_->GetFields().begin();
377          it != info_->GetFields().end();
378          ++it) {
379       if (it->first->getNameAsString() == name) {
380         MarkTraced(it);
381         break;
382       }
383     }
384   } else {
385     RecordInfo::Fields::iterator it = info_->GetFields().find(field);
386     if (it != info_->GetFields().end())
387       MarkTraced(it);
388   }
389 }
390 
MarkAllWeakMembersTraced()391 void CheckTraceVisitor::MarkAllWeakMembersTraced() {
392   // If we find a call to registerWeakMembers which is unresolved we
393   // unsoundly consider all weak members as traced.
394   // TODO: Find out how to validate weak member tracing for unresolved call.
395   for (auto& field : info_->GetFields()) {
396     if (field.second.edge()->IsWeakMember())
397       field.second.MarkTraced();
398   }
399 }
400