1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "CheckTraceVisitor.h"
6 
7 #include <vector>
8 
9 #include "Config.h"
10 
11 using namespace clang;
12 
CheckTraceVisitor(CXXMethodDecl * trace,RecordInfo * info,RecordCache * cache)13 CheckTraceVisitor::CheckTraceVisitor(CXXMethodDecl* trace,
14                                      RecordInfo* info,
15                                      RecordCache* cache)
16     : trace_(trace), info_(info), cache_(cache) {}
17 
VisitMemberExpr(MemberExpr * member)18 bool CheckTraceVisitor::VisitMemberExpr(MemberExpr* member) {
19   // In weak callbacks, consider any occurrence as a correct usage.
20   // TODO: We really want to require that isAlive is checked on manually
21   // processed weak fields.
22   if (IsWeakCallback()) {
23     if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl()))
24       FoundField(field);
25   }
26   return true;
27 }
28 
VisitCallExpr(CallExpr * call)29 bool CheckTraceVisitor::VisitCallExpr(CallExpr* call) {
30   // In weak callbacks we don't check calls (see VisitMemberExpr).
31   if (IsWeakCallback())
32     return true;
33 
34   Expr* callee = call->getCallee();
35 
36   // Trace calls from a templated derived class result in a
37   // DependentScopeMemberExpr because the concrete trace call depends on the
38   // instantiation of any shared template parameters. In this case the call is
39   // "unresolved" and we resort to comparing the syntactic type names.
40   if (CXXDependentScopeMemberExpr* expr =
41       dyn_cast<CXXDependentScopeMemberExpr>(callee)) {
42     CheckCXXDependentScopeMemberExpr(call, expr);
43     return true;
44   }
45 
46   // A tracing call will have either a |visitor| or a |m_field| argument.
47   // A registerWeakMembers call will have a |this| argument.
48   if (call->getNumArgs() != 1)
49     return true;
50   Expr* arg = call->getArg(0);
51 
52   if (UnresolvedMemberExpr* expr = dyn_cast<UnresolvedMemberExpr>(callee)) {
53     // This could be a trace call of a base class, as explained in the
54     // comments of CheckTraceBaseCall().
55     if (CheckTraceBaseCall(call))
56       return true;
57 
58     if (expr->getMemberName().getAsString() == kRegisterWeakMembersName)
59       MarkAllWeakMembersTraced();
60 
61     QualType base = expr->getBaseType();
62     if (!base->isPointerType())
63       return true;
64     CXXRecordDecl* decl = base->getPointeeType()->getAsCXXRecordDecl();
65     if (decl)
66       CheckTraceFieldCall(expr->getMemberName().getAsString(), decl, arg);
67     return true;
68   }
69 
70   if (CXXMemberCallExpr* expr = dyn_cast<CXXMemberCallExpr>(call)) {
71     if (CheckTraceFieldMemberCall(expr) || CheckRegisterWeakMembers(expr))
72       return true;
73 
74   }
75 
76   CheckTraceBaseCall(call);
77   return true;
78 }
79 
IsTraceCallName(const std::string & name)80 bool CheckTraceVisitor::IsTraceCallName(const std::string& name) {
81   // Currently, a manually dispatched class cannot have mixin bases (having
82   // one would add a vtable which we explicitly check against). This means
83   // that we can only make calls to a trace method of the same name. Revisit
84   // this if our mixin/vtable assumption changes.
85   return name == trace_->getName();
86 }
87 
GetDependentTemplatedDecl(CXXDependentScopeMemberExpr * expr)88 CXXRecordDecl* CheckTraceVisitor::GetDependentTemplatedDecl(
89     CXXDependentScopeMemberExpr* expr) {
90   NestedNameSpecifier* qual = expr->getQualifier();
91   if (!qual)
92     return 0;
93 
94   const Type* type = qual->getAsType();
95   if (!type)
96     return 0;
97 
98   return RecordInfo::GetDependentTemplatedDecl(*type);
99 }
100 
101 namespace {
102 
103 class FindFieldVisitor : public RecursiveASTVisitor<FindFieldVisitor> {
104  public:
105   FindFieldVisitor();
106   MemberExpr* member() const;
107   FieldDecl* field() const;
108   bool TraverseMemberExpr(MemberExpr* member);
109 
110  private:
111   MemberExpr* member_;
112   FieldDecl* field_;
113 };
114 
FindFieldVisitor()115 FindFieldVisitor::FindFieldVisitor()
116     : member_(0),
117       field_(0) {
118 }
119 
member() const120 MemberExpr* FindFieldVisitor::member() const {
121   return member_;
122 }
123 
field() const124 FieldDecl* FindFieldVisitor::field() const {
125   return field_;
126 }
127 
TraverseMemberExpr(MemberExpr * member)128 bool FindFieldVisitor::TraverseMemberExpr(MemberExpr* member) {
129   if (FieldDecl* field = dyn_cast<FieldDecl>(member->getMemberDecl())) {
130     member_ = member;
131     field_ = field;
132     return false;
133   }
134   return true;
135 }
136 
137 }  // namespace
138 
CheckCXXDependentScopeMemberExpr(CallExpr * call,CXXDependentScopeMemberExpr * expr)139 void CheckTraceVisitor::CheckCXXDependentScopeMemberExpr(
140     CallExpr* call,
141     CXXDependentScopeMemberExpr* expr) {
142   std::string fn_name = expr->getMember().getAsString();
143 
144   // Check for VisitorDispatcher::trace(field) and
145   // VisitorDispatcher::registerWeakMembers.
146   if (!expr->isImplicitAccess()) {
147     if (DeclRefExpr* base_decl = dyn_cast<DeclRefExpr>(expr->getBase())) {
148       if (Config::IsVisitorDispatcherType(base_decl->getType())) {
149         if (call->getNumArgs() == 1 && fn_name == kTraceName) {
150           FindFieldVisitor finder;
151           finder.TraverseStmt(call->getArg(0));
152           if (finder.field())
153             FoundField(finder.field());
154 
155           return;
156         } else if (call->getNumArgs() == 1 &&
157                    fn_name == kRegisterWeakMembersName) {
158           MarkAllWeakMembersTraced();
159         }
160       }
161     }
162   }
163 
164   CXXRecordDecl* tmpl = GetDependentTemplatedDecl(expr);
165   if (!tmpl)
166     return;
167 
168   // Check for Super<T>::trace(visitor)
169   if (call->getNumArgs() == 1 && IsTraceCallName(fn_name)) {
170     RecordInfo::Bases::iterator it = info_->GetBases().begin();
171     for (; it != info_->GetBases().end(); ++it) {
172       if (it->first->getName() == tmpl->getName())
173         it->second.MarkTraced();
174     }
175   }
176 
177   // Check for TraceIfNeeded<T>::trace(visitor, &field)
178   if (call->getNumArgs() == 2 && fn_name == kTraceName &&
179       tmpl->getName() == kTraceIfNeededName) {
180     FindFieldVisitor finder;
181     finder.TraverseStmt(call->getArg(1));
182     if (finder.field())
183       FoundField(finder.field());
184   }
185 }
186 
CheckTraceBaseCall(CallExpr * call)187 bool CheckTraceVisitor::CheckTraceBaseCall(CallExpr* call) {
188   // Checks for "Base::trace(visitor)"-like calls.
189 
190   // Checking code for these two variables is shared among MemberExpr* case
191   // and UnresolvedMemberCase* case below.
192   //
193   // For example, if we've got "Base::trace(visitor)" as |call|,
194   // callee_record will be "Base", and func_name will be "trace".
195   CXXRecordDecl* callee_record = nullptr;
196   std::string func_name;
197 
198   if (MemberExpr* callee = dyn_cast<MemberExpr>(call->getCallee())) {
199     if (!callee->hasQualifier())
200       return false;
201 
202     FunctionDecl* trace_decl =
203         dyn_cast<FunctionDecl>(callee->getMemberDecl());
204     if (!trace_decl || !Config::IsTraceMethod(trace_decl))
205       return false;
206 
207     const Type* type = callee->getQualifier()->getAsType();
208     if (!type)
209       return false;
210 
211     callee_record = type->getAsCXXRecordDecl();
212     func_name = trace_decl->getName();
213   } else if (UnresolvedMemberExpr* callee =
214              dyn_cast<UnresolvedMemberExpr>(call->getCallee())) {
215     // Callee part may become unresolved if the type of the argument
216     // ("visitor") is a template parameter and the called function is
217     // overloaded.
218     //
219     // Here, we try to find a function that looks like trace() from the
220     // candidate overloaded functions, and if we find one, we assume it is
221     // called here.
222 
223     CXXMethodDecl* trace_decl = nullptr;
224     for (NamedDecl* named_decl : callee->decls()) {
225       if (CXXMethodDecl* method_decl = dyn_cast<CXXMethodDecl>(named_decl)) {
226         if (Config::IsTraceMethod(method_decl)) {
227           trace_decl = method_decl;
228           break;
229         }
230       }
231     }
232     if (!trace_decl)
233       return false;
234 
235     // Check if the passed argument is named "visitor".
236     if (call->getNumArgs() != 1)
237       return false;
238     DeclRefExpr* arg = dyn_cast<DeclRefExpr>(call->getArg(0));
239     if (!arg || arg->getNameInfo().getAsString() != kVisitorVarName)
240       return false;
241 
242     callee_record = trace_decl->getParent();
243     func_name = callee->getMemberName().getAsString();
244   }
245 
246   if (!callee_record)
247     return false;
248 
249   if (!IsTraceCallName(func_name))
250     return false;
251 
252   for (auto& base : info_->GetBases()) {
253     // We want to deal with omitted trace() function in an intermediary
254     // class in the class hierarchy, e.g.:
255     //     class A : public GarbageCollected<A> { trace() { ... } };
256     //     class B : public A { /* No trace(); have nothing to trace. */ };
257     //     class C : public B { trace() { B::trace(visitor); } }
258     // where, B::trace() is actually A::trace(), and in some cases we get
259     // A as |callee_record| instead of B. We somehow need to mark B as
260     // traced if we find A::trace() call.
261     //
262     // To solve this, here we keep going up the class hierarchy as long as
263     // they are not required to have a trace method. The implementation is
264     // a simple DFS, where |base_records| represents the set of base classes
265     // we need to visit.
266 
267     std::vector<CXXRecordDecl*> base_records;
268     base_records.push_back(base.first);
269 
270     while (!base_records.empty()) {
271       CXXRecordDecl* base_record = base_records.back();
272       base_records.pop_back();
273 
274       if (base_record == callee_record) {
275         // If we find a matching trace method, pretend the user has written
276         // a correct trace() method of the base; in the example above, we
277         // find A::trace() here and mark B as correctly traced.
278         base.second.MarkTraced();
279         return true;
280       }
281 
282       if (RecordInfo* base_info = cache_->Lookup(base_record)) {
283         if (!base_info->RequiresTraceMethod()) {
284           // If this base class is not required to have a trace method, then
285           // the actual trace method may be defined in an ancestor.
286           for (auto& inner_base : base_info->GetBases())
287             base_records.push_back(inner_base.first);
288         }
289       }
290     }
291   }
292 
293   return false;
294 }
295 
CheckTraceFieldMemberCall(CXXMemberCallExpr * call)296 bool CheckTraceVisitor::CheckTraceFieldMemberCall(CXXMemberCallExpr* call) {
297   return CheckTraceFieldCall(call->getMethodDecl()->getNameAsString(),
298                              call->getRecordDecl(),
299                              call->getArg(0));
300 }
301 
CheckTraceFieldCall(const std::string & name,CXXRecordDecl * callee,Expr * arg)302 bool CheckTraceVisitor::CheckTraceFieldCall(
303     const std::string& name,
304     CXXRecordDecl* callee,
305     Expr* arg) {
306   if (name != kTraceName || !Config::IsVisitor(callee->getName()))
307     return false;
308 
309   FindFieldVisitor finder;
310   finder.TraverseStmt(arg);
311   if (finder.field())
312     FoundField(finder.field());
313 
314   return true;
315 }
316 
CheckRegisterWeakMembers(CXXMemberCallExpr * call)317 bool CheckTraceVisitor::CheckRegisterWeakMembers(CXXMemberCallExpr* call) {
318   CXXMethodDecl* fn = call->getMethodDecl();
319   if (fn->getName() != kRegisterWeakMembersName)
320     return false;
321 
322   if (fn->isTemplateInstantiation()) {
323     const TemplateArgumentList& args =
324         *fn->getTemplateSpecializationInfo()->TemplateArguments;
325     // The second template argument is the callback method.
326     if (args.size() > 1 &&
327         args[1].getKind() == TemplateArgument::Declaration) {
328       if (FunctionDecl* callback =
329           dyn_cast<FunctionDecl>(args[1].getAsDecl())) {
330         if (callback->hasBody()) {
331           CheckTraceVisitor nested_visitor(nullptr, info_, nullptr);
332           nested_visitor.TraverseStmt(callback->getBody());
333         }
334       }
335       // TODO: mark all WeakMember<>s as traced even if
336       // the body isn't available?
337     }
338   }
339   return true;
340 }
341 
IsWeakCallback() const342 bool CheckTraceVisitor::IsWeakCallback() const {
343   return !trace_;
344 }
345 
MarkTraced(RecordInfo::Fields::iterator it)346 void CheckTraceVisitor::MarkTraced(RecordInfo::Fields::iterator it) {
347   // In a weak callback we can't mark strong fields as traced.
348   if (IsWeakCallback() && !it->second.edge()->IsWeakMember())
349     return;
350   it->second.MarkTraced();
351 }
352 
FoundField(FieldDecl * field)353 void CheckTraceVisitor::FoundField(FieldDecl* field) {
354   if (Config::IsTemplateInstantiation(info_->record())) {
355     // Pointer equality on fields does not work for template instantiations.
356     // The trace method refers to fields of the template definition which
357     // are different from the instantiated fields that need to be traced.
358     const std::string& name = field->getNameAsString();
359     for (RecordInfo::Fields::iterator it = info_->GetFields().begin();
360          it != info_->GetFields().end();
361          ++it) {
362       if (it->first->getNameAsString() == name) {
363         MarkTraced(it);
364         break;
365       }
366     }
367   } else {
368     RecordInfo::Fields::iterator it = info_->GetFields().find(field);
369     if (it != info_->GetFields().end())
370       MarkTraced(it);
371   }
372 }
373 
MarkAllWeakMembersTraced()374 void CheckTraceVisitor::MarkAllWeakMembersTraced() {
375   // If we find a call to registerWeakMembers which is unresolved we
376   // unsoundly consider all weak members as traced.
377   // TODO: Find out how to validate weak member tracing for unresolved call.
378   for (auto& field : info_->GetFields()) {
379     if (field.second.edge()->IsWeakMember())
380       field.second.MarkTraced();
381   }
382 }
383