1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python module for MLIR functions exported by pybind11."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
22from tensorflow.python import pywrap_tensorflow
23from tensorflow.python.eager import context
24from tensorflow.python._pywrap_mlir import *
25
26
def import_graphdef(graphdef, pass_pipeline, show_debug_info):
  """Imports a graph into MLIR through the native `ImportGraphDef` binding.

  Args:
    graphdef: The graph to import. It is converted with `str()` and UTF-8
      encoded before being handed to the binding (presumably a `GraphDef`
      proto, whose `str()` is its text format -- confirm against callers).
    pass_pipeline: Pass pipeline as a `str`; UTF-8 encoded before the call.
    show_debug_info: Forwarded unchanged to the binding.

  Returns:
    Whatever `ImportGraphDef` returns.
  """
  graph_bytes = str(graphdef).encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ImportGraphDef(graph_bytes, pipeline_bytes, show_debug_info)
31
32
def import_function(concrete_function, pass_pipeline, show_debug_info):
  """Imports a concrete function into MLIR via the native `ImportFunction`.

  Args:
    concrete_function: Object exposing a `function_def` attribute. That
      attribute is converted with `str()` and UTF-8 encoded before being
      passed to the binding.
    pass_pipeline: Pass pipeline as a `str`; UTF-8 encoded before the call.
    show_debug_info: Forwarded unchanged to the binding.

  Returns:
    Whatever `ImportFunction` returns.
  """
  eager_ctx = context.context()
  # Initialize the eager context before reading its handle for the binding.
  eager_ctx.ensure_initialized()
  ctx_handle = eager_ctx._handle  # pylint: disable=protected-access
  function_def_bytes = str(concrete_function.function_def).encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ImportFunction(ctx_handle, function_def_bytes, pipeline_bytes,
                        show_debug_info)
39
40
def experimental_convert_saved_model_to_mlir(saved_model_path, exported_names,
                                             show_debug_info):
  """Converts a SavedModel to MLIR via the native binding.

  Args:
    saved_model_path: Location of the SavedModel. It is passed through
      `str()`, so path-like objects are accepted.
    exported_names: Names to export. Pass either a single string, which is
      forwarded as-is, or a list/tuple/set of strings, which is joined with
      ',' into one string. Any other value is converted with `str()`, as
      before.
    show_debug_info: Forwarded unchanged to the native binding.

  Returns:
    Whatever `ExperimentalConvertSavedModelToMlir` returns.
  """
  # `str()` of a sequence yields its repr (e.g. "['a', 'b']"), which is not a
  # meaningful name string, so sequences are joined first.
  # NOTE(review): assumes the native binding splits this string on ',' --
  # confirm against the C++ implementation.
  if isinstance(exported_names, (list, tuple, set, frozenset)):
    exported_names = ','.join(exported_names)
  return ExperimentalConvertSavedModelToMlir(
      str(saved_model_path).encode('utf-8'),
      str(exported_names).encode('utf-8'), show_debug_info)
46
47
def experimental_convert_saved_model_v1_to_mlir_lite(saved_model_path, tags,
                                                     upgrade_legacy,
                                                     show_debug_info):
  """Converts a V1 SavedModel to MLIR via the native "lite" binding.

  Args:
    saved_model_path: Location of the SavedModel. It is passed through
      `str()`, so path-like objects are accepted.
    tags: Tags to load with (presumably the MetaGraphDef tags -- confirm).
      Pass either a single string, which is forwarded as-is, or a
      list/tuple/set of strings, which is joined with ','. Any other value is
      converted with `str()`, as before.
    upgrade_legacy: Forwarded unchanged to the native binding.
    show_debug_info: Forwarded unchanged to the native binding.

  Returns:
    Whatever `ExperimentalConvertSavedModelV1ToMlirLite` returns.
  """
  # `str()` of a sequence yields its repr (e.g. "['serve']"), which is not a
  # meaningful tag string, so sequences are joined first.
  # NOTE(review): assumes the native binding splits this string on ',' --
  # confirm against the C++ implementation.
  if isinstance(tags, (list, tuple, set, frozenset)):
    tags = ','.join(tags)
  return ExperimentalConvertSavedModelV1ToMlirLite(
      str(saved_model_path).encode('utf-8'),
      str(tags).encode('utf-8'), upgrade_legacy, show_debug_info)
54
55
def experimental_convert_saved_model_v1_to_mlir(saved_model_path, tags,
                                                lift_variables, upgrade_legacy,
                                                show_debug_info):
  """Converts a V1 SavedModel to MLIR via the native binding.

  Args:
    saved_model_path: Location of the SavedModel. It is passed through
      `str()`, so path-like objects are accepted.
    tags: Tags to load with (presumably the MetaGraphDef tags -- confirm).
      Pass either a single string, which is forwarded as-is, or a
      list/tuple/set of strings, which is joined with ','. Any other value is
      converted with `str()`, as before.
    lift_variables: Forwarded unchanged to the native binding.
    upgrade_legacy: Forwarded unchanged to the native binding.
    show_debug_info: Forwarded unchanged to the native binding.

  Returns:
    Whatever `ExperimentalConvertSavedModelV1ToMlir` returns.
  """
  # `str()` of a sequence yields its repr (e.g. "['serve']"), which is not a
  # meaningful tag string, so sequences are joined first.
  # NOTE(review): assumes the native binding splits this string on ',' --
  # confirm against the C++ implementation.
  if isinstance(tags, (list, tuple, set, frozenset)):
    tags = ','.join(tags)
  return ExperimentalConvertSavedModelV1ToMlir(
      str(saved_model_path).encode('utf-8'),
      str(tags).encode('utf-8'), lift_variables, upgrade_legacy,
      show_debug_info)
63
64
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
  """Runs a pass pipeline over MLIR text via the native binding.

  Args:
    mlir_txt: MLIR input as a `str`; UTF-8 encoded before the call.
    pass_pipeline: Pass pipeline as a `str`; UTF-8 encoded before the call.
    show_debug_info: Forwarded unchanged to the binding.

  Returns:
    Whatever `ExperimentalRunPassPipeline` returns.
  """
  module_bytes, pipeline_bytes = (mlir_txt.encode('utf-8'),
                                  pass_pipeline.encode('utf-8'))
  return ExperimentalRunPassPipeline(module_bytes, pipeline_bytes,
                                     show_debug_info)
68