1"""Targets for generating TensorFlow Python API __init__.py files."""
2
3load("//tensorflow/python/tools/api/generator:api_init_files.bzl", "TENSORFLOW_API_INIT_FILES")
4
def get_compat_files(
        file_paths,
        compat_api_version):
    """Returns file_paths relocated under the compat/v<compat_api_version>/ directory.

    Args:
      file_paths: list of relative file paths.
      compat_api_version: integer API version used to build the prefix.

    Returns:
      A new list with "compat/v<compat_api_version>/" prepended to each path,
      in the same order as file_paths.
    """
    prefix = "compat/v%d/" % compat_api_version
    prefixed_paths = []
    for path in file_paths:
        prefixed_paths.append("%s%s" % (prefix, path))
    return prefixed_paths
10
def get_nested_compat_files(compat_api_versions):
    """Lists the __init__.py paths that live under nested compat modules.

    For every ordered pair (N, K) of versions, a nested compat module has
    two __init__.py files:
      1. compat/vN/compat/vK/__init__.py
      2. compat/vN/compat/vK/compat/__init__.py

    Args:
      compat_api_versions: list of compat versions (integers).

    Returns:
      List of __init__.py file paths to include under nested compat modules.
      For each outer version N, all paths of form (1) come first (one per
      inner version K), followed by all paths of form (2).
    """
    nested_init_suffixes = ("__init__.py", "compat/__init__.py")
    nested_paths = []
    for outer_version in compat_api_versions:
        for suffix in nested_init_suffixes:
            for inner_version in compat_api_versions:
                nested_paths.append(
                    "compat/v%d/compat/v%d/%s" % (outer_version, inner_version, suffix),
                )
    return nested_paths
35
def gen_api_init_files(
        name,
        output_files = TENSORFLOW_API_INIT_FILES,
        root_init_template = None,
        srcs = [],
        api_name = "tensorflow",
        api_version = 2,
        compat_api_versions = [],
        compat_init_templates = [],
        packages = [
            "tensorflow.python",
            "tensorflow.lite.python.lite",
            "tensorflow.python.modules_with_exports",
        ],
        package_deps = [
            "//tensorflow/python:no_contrib",
            "//tensorflow/python:modules_with_exports",
        ],
        output_package = "tensorflow",
        output_dir = "",
        root_file_name = "__init__.py"):
    """Creates API directory structure and __init__.py files.

    Creates a py_binary wrapping create_python_api.py and a genrule that runs
    it to generate a directory structure with __init__.py files that import
    all exported modules (i.e. modules with tf_export decorators).

    Args:
      name: name of genrule to create.
      output_files: List of __init__.py files that should be generated.
        This list should include file name for every module exported using
        tf_export. For e.g. if an op is decorated with
        @tf_export('module1.module2', 'module3'). Then, output_files should
        include module1/module2/__init__.py and module3/__init__.py.
      root_init_template: Label of the Python init file that should be used as
        template for root __init__.py file. It is resolved to a path with
        $(location ...), so it must be included in srcs.
        "# API IMPORTS PLACEHOLDER" comment inside this template will be
        replaced with root imports collected by this genrule.
      srcs: genrule sources. If passing root_init_template, the template file
        must be included in sources.
      api_name: Name of the project that you want to generate API files for
        (e.g. "tensorflow" or "estimator").
      api_version: TensorFlow API version to generate. Must be either 1 or 2.
      compat_api_versions: Older TensorFlow API versions to generate under
        compat/ directory.
      compat_init_templates: Labels of Python init files that should be used as
        templates for top level __init__.py files under compat/vN directories.
        Each is resolved with $(location ...).
        "# API IMPORTS PLACEHOLDER" comment inside this
        template will be replaced with root imports collected by this genrule.
      packages: Python packages containing the @tf_export decorators you want to
        process. The first entry is used to name the generator binary target.
      package_deps: Python library target containing your packages.
      output_package: Package where generated API will be added to.
      output_dir: Subdirectory to output API to.
        If non-empty, must end with '/'.
      root_file_name: Name of the root file with all the root imports.
    """

    # output_dir is concatenated directly in front of each output file name,
    # so a missing trailing slash would silently produce mangled paths.
    if output_dir and not output_dir.endswith("/"):
        fail("output_dir must be empty or end with '/', got: %s" % output_dir)

    # The template is a label listed in srcs; expand it to a real path the
    # same way compat_init_templates are expanded below.
    root_init_template_flag = ""
    if root_init_template:
        root_init_template_flag = (
            "--root_init_template=$(location %s)" % root_init_template
        )

    primary_package = packages[0]
    api_gen_binary_target = ("create_" + primary_package + "_api_%s") % name
    native.py_binary(
        name = api_gen_binary_target,
        srcs = ["//tensorflow/python/tools/api/generator:create_python_api.py"],
        main = "//tensorflow/python/tools/api/generator:create_python_api.py",
        python_version = "PY3",
        srcs_version = "PY3",
        visibility = ["//visibility:public"],
        deps = package_deps + [
            "//tensorflow/python:util",
            "//tensorflow/python/tools/api/generator:doc_srcs",
        ],
    )

    # Replace name of root file with root_file_name.
    output_files = [
        root_file_name if f == "__init__.py" else f
        for f in output_files
    ]
    all_output_files = ["%s%s" % (output_dir, f) for f in output_files]

    # One repeated flag per compat version / compat template; each entry
    # carries its own leading space so the pieces can be joined directly.
    compat_api_version_flags = "".join([
        " --compat_apiversion=%d" % compat_api_version
        for compat_api_version in compat_api_versions
    ])
    compat_init_template_flags = "".join([
        " --compat_init_template=$(location %s)" % compat_init_template
        for compat_init_template in compat_init_templates
    ])

    # copybara:uncomment_begin(configurable API loading)
    # native.vardef("TF_API_INIT_LOADING", "default")
    # loading_flag = " --loading=$(TF_API_INIT_LOADING)"
    # copybara:uncomment_end_and_comment_begin
    loading_flag = " --loading=default"
    # copybara:comment_end

    native.genrule(
        name = name,
        outs = all_output_files,
        cmd = (
            "$(location :" + api_gen_binary_target + ") " +
            root_init_template_flag + " --apidir=$(@D)" + output_dir +
            " --apiname=" + api_name + " --apiversion=" + str(api_version) +
            compat_api_version_flags + " " + compat_init_template_flags +
            loading_flag + " --packages=" + ",".join(packages) +
            " --output_package=" + output_package +
            " --use_relative_imports=True $(OUTS)"
        ),
        srcs = srcs,
        tools = [":" + api_gen_binary_target],
        visibility = [
            "//tensorflow:__pkg__",
            "//tensorflow/tools/api/tests:__pkg__",
        ],
    )
153