1#  Copyright 2016 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS-IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
def generate_files(injection_graph, generate_runtime_bench_code):
    """Generate all source files of the Boost.DI benchmark for a dependency graph.

    Args:
        injection_graph: directed graph object exposing nodes(), successors(node)
            and predecessors(node). NOTE(review): this looks like a NetworkX
            DiGraph; confirm against the caller. An edge A -> B is treated as
            "component A depends on component B".
        generate_runtime_bench_code: forwarded to _generate_main to select
            between the timing-loop main() and the minimal main().

    Returns:
        A dict mapping file name ('component<N>.h', 'component<N>.cpp', 'main.cpp')
        to the generated file content.

    Raises:
        ValueError: if the graph does not have exactly one node without predecessors.
    """
    file_content_by_name = dict()

    # nodes() is available in both NetworkX 1.x and 2.x; nodes_iter() is 1.x only.
    node_ids = list(injection_graph.nodes())

    for node_id in node_ids:
        # successors() may return a one-shot iterator (NetworkX 2.x); the dependency
        # list is consumed several times below, so materialize it once.
        deps = list(injection_graph.successors(node_id))
        file_content_by_name['component%s.h' % node_id] = _generate_component_header(node_id, deps)
        file_content_by_name['component%s.cpp' % node_id] = _generate_component_source(node_id, deps)

    # An iterator is always truthy, so test emptiness on a materialized list.
    toplevel_nodes = [node_id
                      for node_id in node_ids
                      if not list(injection_graph.predecessors(node_id))]
    if len(toplevel_nodes) != 1:
        raise ValueError(
            'Expected exactly one node without predecessors in the injection graph, found %d'
            % len(toplevel_nodes))
    [toplevel_node] = toplevel_nodes
    file_content_by_name['main.cpp'] = _generate_main(injection_graph, toplevel_node, generate_runtime_bench_code)

    return file_content_by_name
29
30def _generate_component_header(component_index, deps):
31    fields = ''.join(['std::shared_ptr<Interface%s> x%s;\n' % (dep, dep)
32                      for dep in deps])
33    component_deps = ''.join([', std::shared_ptr<Interface%s>' % dep for dep in deps])
34
35    include_directives = ''.join(['#include "component%s.h"\n' % index for index in deps])
36
37    template = """
38#ifndef COMPONENT{component_index}_H
39#define COMPONENT{component_index}_H
40
41#include <boost/di.hpp>
42#include <boost/di/extension/scopes/scoped_scope.hpp>
43#include <memory>
44
45// Example include that the code might use
46#include <vector>
47
48namespace di = boost::di;
49
50{include_directives}
51
52struct Interface{component_index} {{
53    virtual ~Interface{component_index}() = default;
54}};
55
56struct X{component_index} : public Interface{component_index} {{
57    {fields}
58
59    BOOST_DI_INJECT(X{component_index}{component_deps});
60
61    virtual ~X{component_index}() = default;
62}};
63
64auto x{component_index}Component = [] {{
65    return di::make_injector(di::bind<Interface{component_index}>().to<X{component_index}>().in(di::extension::scoped));
66}};
67
68#endif // COMPONENT{component_index}_H
69"""
70    return template.format(**locals())
71
72def _generate_component_source(component_index, deps):
73    param_initializers = ', '.join('x%s(x%s)' % (dep, dep)
74                                   for dep in deps)
75    if param_initializers:
76        param_initializers = ': ' + param_initializers
77    component_deps = ', '.join('std::shared_ptr<Interface%s> x%s' % (dep, dep)
78                               for dep in deps)
79
80    template = """
81#include "component{component_index}.h"
82
83X{component_index}::X{component_index}({component_deps})
84    {param_initializers} {{
85}}
86"""
87    return template.format(**locals())
88
89def _generate_main(injection_graph, toplevel_component, generate_runtime_bench_code):
90    include_directives = ''.join('#include "component%s.h"\n' % index
91                                 for index in injection_graph.nodes_iter())
92
93    injector_params = ', '.join('x%sComponent()' % index
94                                for index in injection_graph.nodes_iter())
95
96    if generate_runtime_bench_code:
97        template = """
98{include_directives}
99
100#include "component{toplevel_component}.h"
101#include <ctime>
102#include <iostream>
103#include <cstdlib>
104#include <iomanip>
105#include <chrono>
106
107using namespace std;
108
109void f() {{
110  auto injector = di::make_injector({injector_params});
111  injector.create<std::shared_ptr<Interface{toplevel_component}>>();
112}}
113
114int main(int argc, char* argv[]) {{
115  if (argc != 2) {{
116    std::cout << "Need to specify num_loops as argument." << std::endl;
117    exit(1);
118  }}
119  size_t num_loops = std::atoi(argv[1]);
120  double perRequestTime = 0;
121  std::chrono::high_resolution_clock::time_point start_time = std::chrono::high_resolution_clock::now();
122  for (size_t i = 0; i < num_loops; i++) {{
123    f();
124  }}
125  perRequestTime += std::chrono::duration_cast<std::chrono::duration<double>>(std::chrono::high_resolution_clock::now() - start_time).count();
126  std::cout << std::fixed;
127  std::cout << std::setprecision(15);
128  std::cout << "Total for setup            = " << 0 << std::endl;
129  std::cout << "Full injection time        = " << perRequestTime / num_loops << std::endl;
130  std::cout << "Total per request          = " << perRequestTime / num_loops << std::endl;
131  return 0;
132}}
133"""
134    else:
135        template = """
136{include_directives}
137
138#include "component{toplevel_component}.h"
139
140#include <iostream>
141
142int main() {{
143  auto injector = di::make_injector({injector_params});
144  injector.create<std::shared_ptr<Interface{toplevel_component}>>();
145  std::cout << "Hello, world" << std::endl;
146  return 0;
147}}
148"""
149    return template.format(**locals())
150