1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operators for concise TensorFlow network models.
16
17This module is used as an environment for evaluating expressions
18in the "specs" DSL.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25from tensorflow.contrib.layers.python.layers import layers
26from tensorflow.contrib.specs.python import specs_lib
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import logging_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import nn
31from tensorflow.python.ops import nn_ops
32from tensorflow.python.ops import variable_scope
33
34# The following assignments don't appear to follow Google naming
35# conventions, but that's because these are functions defined by
36# higher-order function application, not "constants" and because they
37# are the commands of the DSL.
38# pylint: disable=invalid-name
39
40
class Idx(specs_lib.Composable):
  """Identity operator for network specifications.

  Useful wherever a spec object is required but the input should pass
  through untouched.
  """

  def funcall(self, x):
    """Returns the input `x` unchanged."""
    return x
46
47
class Conc(specs_lib.Composable):
  """Concatenation operator for network specifications.

  Every sub-spec is applied to the same input, and the resulting tensors
  are joined along a single dimension.
  """

  def __init__(self, dim, *args):
    """Sets up the concatenation.

    Args:
        dim: dimension along which the outputs are concatenated
        *args: spec objects (each providing `funcall`) whose outputs are
            concatenated
    """
    self.funs = args
    self.dim = dim

  def funcall(self, x):
    """Applies each sub-spec to `x` and concatenates the results."""
    branches = []
    for fun in self.funs:
      branches.append(fun.funcall(x))
    return array_ops.concat(branches, self.dim)
64
65
# Short aliases for specs_lib helpers, so they can be referred to by name
# in spec expressions evaluated with this module as their environment.
External = specs_lib.External
Import = specs_lib.Import
Fun = specs_lib.Function
debug = specs_lib.debug
# Print and Id are logging_ops.Print and array_ops.identity wrapped with Fun.
Print = Fun(logging_ops.Print)
Id = Fun(array_ops.identity)
72
73# TODO(tmb) add Assert
74
75# Two letter names for the most common layers.
76
# 2D convolutional layers with nonlinearities. The second letter of each
# name indicates the activation_fn passed to Fun together with layers.conv2d:
#   s: sigmoid, t: tanh, r: relu, m: softmax,
#   l: activation_fn=None (presumably "linear", i.e. no nonlinearity),
#   x: activation_fn is not specified here at all.
# TODO(tmb) add Cbs, Fbs etc. for batch norms

Cx = Fun(layers.conv2d)
Cs = Fun(layers.conv2d, activation_fn=math_ops.sigmoid)
Ct = Fun(layers.conv2d, activation_fn=math_ops.tanh)
Cr = Fun(layers.conv2d, activation_fn=nn_ops.relu)
Cm = Fun(layers.conv2d, activation_fn=nn_ops.softmax)
Cl = Fun(layers.conv2d, activation_fn=None)
86
# Fully connected layers with nonlinearities. The second letter of each
# name indicates the activation_fn passed to Fun together with
# layers.fully_connected:
#   s: sigmoid, t: tanh, r: relu, m: softmax,
#   l: activation_fn=None (presumably "linear", i.e. no nonlinearity),
#   x: activation_fn is not specified here at all.

Fx = Fun(layers.fully_connected)
Fs = Fun(layers.fully_connected, activation_fn=math_ops.sigmoid)
Ft = Fun(layers.fully_connected, activation_fn=math_ops.tanh)
Fr = Fun(layers.fully_connected, activation_fn=nn_ops.relu)
Fm = Fun(layers.fully_connected, activation_fn=nn_ops.softmax)
Fl = Fun(layers.fully_connected, activation_fn=None)
95
# Pooling: Mp wraps layers.max_pool2d, Ap wraps layers.avg_pool2d.

Mp = Fun(layers.max_pool2d)
Ap = Fun(layers.avg_pool2d)
100
# Batch manipulations: dropout (Do), batch normalization (Bn), local
# response normalization (Lrn), and unit normalization (Unit).

Do = Fun(layers.dropout)
Bn = Fun(layers.batch_norm)
Lrn = Fun(nn.local_response_normalization)
Unit = Fun(layers.unit_norm)
107
# Shape changes: wrappers around layers.flatten and the array_ops
# reshape / transpose / squeeze / expand_dims functions.

Flat = Fun(layers.flatten)
Reshape = Fun(array_ops.reshape)
Transpose = Fun(array_ops.transpose)
Squeeze = Fun(array_ops.squeeze)
Expand = Fun(array_ops.expand_dims)
115
# Nonlinearities as standalone operators (rarely needed on their own,
# since the two-letter layer names above already pass an activation_fn).

Relu = Fun(nn_ops.relu)
Sig = Fun(math_ops.sigmoid)
Tanh = Fun(math_ops.tanh)
Smax = Fun(nn_ops.softmax)
122
123
def Dws(n):
  """Convolution along depth with a [1, 1] kernel + sigmoid (used after LSTM).

  Args:
      n: First argument forwarded to `Cs` (presumably the number of outputs).

  Returns:
      The spec object produced by `Cs(n, [1, 1])`.
  """
  kernel_size = [1, 1]
  return Cs(n, kernel_size)
127
128
def Dwm(n):
  """Convolution along depth with a [1, 1] kernel + softmax (used after LSTM).

  Args:
      n: First argument forwarded to `Cm` (presumably the number of outputs).

  Returns:
      The spec object produced by `Cm(n, [1, 1])`.
  """
  kernel_size = [1, 1]
  return Cm(n, kernel_size)
132
133# Sharing of Variables
134
135
def Var(name, *args, **kw):
  """Builds a spec operator that yields a variable.

  Still experimental: only use it to generate a single variable
  instance per name.

  Args:
      name: Name of the variable.
      *args: Additional positional arguments for get_variable.
      **kw: Additional keyword arguments for get_variable.

  Returns:
      A specs object for generating a variable.
  """
  # The operator's input is deliberately ignored; the variable is looked up
  # (or created) by name each time the operator is invoked.
  return specs_lib.Callable(
      lambda _: variable_scope.get_variable(name, *args, **kw))
156
157
class Shared(specs_lib.Composable):
  """Wraps a scope with variable reuse around the subnetwork.

  This function is still experimental.

  Attributes:
      subnet: The shared subnetwork.
      name: A name for the shared scope (auto-generated as "Shared_<n>" when
          not supplied).
      scope: The variable scope shared by all invocations, or None if no
          scope was supplied and the operator has not been invoked yet.
  """

  # Class-wide counter used to generate unique default scope names.
  shared_number = 1

  def __init__(self, subnet, name=None, scope=None):
    """Create the Shared operator.

    Use this as:

        f = Shared(Cr(100, 3))
        g = f | f | f

    Ordinarily, you do not need to provide either a name or a scope.
    Providing a name is useful if you want a well-defined namespace
    for the variables (e.g., for saving a subnet).

    Args:
        subnet: Definition of the shared network.
        name: Optional name for the shared context.
        scope: Optional shared scope (must be a Scope, not a string).

    Raises:
        ValueError: Scope is not of type tf.Scope, name is not
        of type string, or both scope and name are given together.
    """
    if scope is not None and not isinstance(scope,
                                            variable_scope.VariableScope):
      raise ValueError("scope must be None or a VariableScope")
    # Validate `name` itself. Checking `scope` here (as the code previously
    # did) rejected every explicitly supplied name.
    if name is not None and not isinstance(name, str):
      raise ValueError("name must be None or a string")
    if scope is not None and name is not None:
      raise ValueError("cannot provide both a name and a scope")
    if name is None:
      name = "Shared_%d" % Shared.shared_number
      Shared.shared_number += 1
    self.subnet = subnet
    self.name = name
    self.scope = scope

  def funcall(self, x):
    """Apply the shared operator to an input.

    This wraps a variable scope around the creation of the subnet. The
    first invocation without a stored scope opens a new scope under
    `self.name` and remembers it; every invocation with a stored scope
    (including a scope passed to the constructor) re-enters that scope
    with `reuse=True`.

    Args:
        x: The input argument on which the subnet is invoked.

    Returns:
        The output tensor from invoking the subnet constructor.
    """
    if self.scope is None:
      with variable_scope.variable_scope(self.name, values=[x]) as scope:
        # Remember the scope so later invocations share its variables.
        self.scope = scope
        return self.subnet.funcall(x)
    else:
      with variable_scope.variable_scope(self.scope, values=[x], reuse=True):
        return self.subnet.funcall(x)
224