# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builder for TensorFlow models specified using specs_ops.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six import exec_
from tensorflow.contrib.specs.python import params_ops
from tensorflow.contrib.specs.python import specs_lib
from tensorflow.contrib.specs.python import specs_ops
from tensorflow.python.util import tf_inspect


def _exec_in_namespace(source, namespace_module, environment):
  """Checks and executes `source`, returning the bindings it creates.

  Shared implementation of `eval_params` and `eval_spec`, which differ only
  in the module whose names are visible to the executed code.

  Args:
    source: code to execute, as a string
    namespace_module: module whose globals serve as the globals of `source`
    environment: optional mapping of input bindings; it is copied, never
      mutated

  Returns:
    A new dictionary with the bindings from `environment` plus every
    top-level name assigned by `source`.

  Raises:
    Exception: exceptions raised by `specs_lib.check_keywords` or during
      execution of `source`
  """
  specs_lib.check_keywords(source)
  bindings = {}
  if environment:
    bindings.update(environment)
  # `bindings` is passed as the locals mapping, so top-level assignments in
  # `source` are collected there instead of in the module's globals.
  exec_(source, vars(namespace_module), bindings)  # pylint: disable=exec-used
  return bindings


def eval_params(params, environment=None):
  """Evaluates a parameter specification and returns the environment.

  Args:
    params: parameter assignments as a string
    environment: a dictionary of input bindings

  Returns:
    Environment with additional bindings created by
    executing `params`

  Raises:
    Exception: other exceptions raised during execution of `params`
  """
  return _exec_in_namespace(params, params_ops, environment)


def eval_spec(spec, environment=None):
  """Evaluates a spec and returns the environment.

  This function allows you to use a spec to obtain multiple bindings
  in an environment. That is useful if you use the spec language to
  specify multiple components of a larger network, for example: "left
  = Cr(64, [5,5]); right = Fc(64)" Usually, you will want to use
  `create_net` or `create_net_fun` below.

  Args:
    spec: specification as a string
    environment: a dictionary of input bindings

  Returns:
    Environment with additional bindings created by spec.

  Raises:
    Exception: other exceptions raised during execution of `spec`

  """
  return _exec_in_namespace(spec, specs_ops, environment)


def create_net_fun(spec, environment=None):
  """Evaluates a spec and returns the binding of `net`.

  Specs are written in a DSL based on function composition. A spec
  like `net = Cr(64, [3, 3])` assigns an object that represents a
  single argument function capable of creating a network to
  the variable `net`.

  Args:
    spec: specification as a string, ending with a `net = ...` statement
    environment: a dictionary of input bindings

  Returns:
    A callable that instantiates the `net` binding.

  Raises:
    ValueError: spec failed to create a `net`
    Exception: other exceptions raised during execution of `spec`

  """
  bindings = eval_spec(spec, environment)
  net = bindings.get("net", None)
  if net is None:
    raise ValueError("spec failed to create 'net': %s" % (spec,))
  return net.funcall


def create_net(spec, inputs, environment=None):
  """Evaluates a spec and creates a network instance given the inputs.

  Args:
    spec: specification as a string, ending with a `net = ...` statement
    inputs: input that `net` is applied to
    environment: a dictionary of input bindings

  Returns:
    The result of applying the callable bound to `net` to `inputs`, i.e. the
    instantiated network.

  Raises:
    ValueError: spec failed to create a `net`
    Exception: other exceptions raised during execution of `spec`
  """
  return create_net_fun(spec, environment)(inputs)


class LocalImport(object):
  """A class that allows us to temporarily import something.

  On entry, every name in `names` is bound in the module globals of the code
  that executes the `with` statement. On exit, the previous bindings are
  restored, and names that were not bound before entry are removed again.

  An instance keeps the state of a single active `with` block in `frame` and
  `old`, so nesting `with` blocks on the same instance is not supported.

  Attributes:
    frame: the frame of the active `__enter__` call; its `f_back` is the
      frame whose globals are modified
    names: a dictionary containing the new bindings
    old: variable bindings that have been shadowed by the import
  """

  # Marks names that had no binding before entry. A private object cannot be
  # confused with a name that was genuinely bound to None.
  _UNBOUND = object()

  def __init__(self, names):
    """Create a context manager that binds the names in values.

    Args:
      names: A dictionary or module containing the bindings.
    """
    if not isinstance(names, dict):
      names = vars(names)
    self.names = names

  def __enter__(self):
    self.frame = tf_inspect.currentframe()
    bindings = self.frame.f_back.f_globals
    self.old = {k: bindings.get(k, self._UNBOUND) for k in self.names}
    bindings.update(self.names)
    return self

  def __exit__(self, some_type, value, traceback):
    del some_type, value, traceback
    bindings = self.frame.f_back.f_globals
    for k, v in self.old.items():
      if v is self._UNBOUND:
        # The name did not exist before entry. Tolerate the with-body having
        # already deleted it.
        bindings.pop(k, None)
      else:
        bindings[k] = v
    # Drop the frame reference so frame objects are not kept alive.
    del self.frame


# Usage: `with ops: net = Cr(64, [3, 3])` makes the specs_ops names available
# as globals inside the block.
ops = LocalImport(specs_ops)