# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for quantizing a TensorFlow graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import googletest

conv2d = layers.conv2d
separable_conv2d = layers.separable_conv2d


class QuantizeTest(test_util.TensorFlowTestCase):

  def _RunTestOverParameters(self, test_fn):
    """Runs `test_fn(is_training)` for is_training in (True, False)."""
    params = [True, False]
    for is_training in params:
      test_fn(is_training)

  def testInsertQuantOpFailsWhenOpsNotConnected(self):
    # NOTE(review): this test is a no-op, so
    # _TestInsertQuantOpFailsWhenOpsNotConnected below is never executed.
    # TODO: confirm whether it should be re-enabled through
    # self._RunTestOverParameters.
    pass

  def _TestInsertQuantOpFailsWhenOpsNotConnected(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      inputs = array_ops.zeros((batch_size, height, width, depth))
      conv = conv2d(inputs, 32, [5, 5], stride=2, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=None, scope='test')
      relu = nn_ops.relu6(inputs)

      # Inserting a quantization op between two unconnected ops should fail
      # with ValueError.
      with self.assertRaises(ValueError) as err:
        quantize._InsertQuantOp('test', is_training, conv.op, [relu.op],
                                'FailingQuantOp')
      # The message is checked after the `with` block: a statement placed
      # after the raising call inside the block would never run.
      self.assertEqual(
          str(err.exception), 'Some inputs not quantized for ops: [Relu6]')

  def testInsertQuantOpForAddAfterConv2d(self):
    self._RunTestOverParameters(self._TestInsertQuantOpForAddAfterConv2d)

  def _TestInsertQuantOpForAddAfterConv2d(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      # Integer division keeps the shape integral; it matches the spatial size
      # of the stride-2, SAME-padded conv below.
      input2 = array_ops.zeros((batch_size, height // 2, width // 2, 32))
      conv = conv2d(input1, 32, [5, 5], stride=2, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=None, scope='test/test')
      node = math_ops.add(conv, input2, name='test/add')
      node = nn_ops.relu6(node, name='test/relu6')
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

      quantization_node_name = 'FakeQuantWithMinMaxVars'
      conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)

      # Scan through all FakeQuant operations, ensuring that the activation
      # isn't in the consumers of the operation. Since activations are folded
      # into the preceding operation during inference, the FakeQuant operation
      # after the activation is all that is needed.
      self._AssertNotConsumedByQuantOps(graph, quantization_node_name,
                                        'test/relu6')

  def testInsertQuantOpForAddAfterSeparableConv2d(self):
    self._RunTestOverParameters(
        self._TestInsertQuantOpForAddAfterSeparableConv2d)

  def _TestInsertQuantOpForAddAfterSeparableConv2d(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      input2 = array_ops.zeros((batch_size, height // 2, width // 2, depth))
      conv = separable_conv2d(input1, None, [5, 5], stride=2,
                              depth_multiplier=1.0, padding='SAME',
                              weights_initializer=self._WeightInit(0.09),
                              activation_fn=None, scope='test/test')
      node = math_ops.add(conv, input2, name='test/add')
      node = nn_ops.relu6(node, name='test/relu6')
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
      # Check if output of bias add is quantized
      quantization_node_name = 'FakeQuantWithMinMaxVars'
      conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)

      # Scan through all FakeQuant operations, ensuring that the activation
      # (test/relu6) isn't a consumer of any of them.
      self._AssertNotConsumedByQuantOps(graph, quantization_node_name,
                                        'test/relu6')

  def testInsertQuantOpInSeparableConv2d(self):
    self._RunTestOverParameters(self._TestInsertQuantOpInSeparableConv2d)

  def _TestInsertQuantOpInSeparableConv2d(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      input2 = array_ops.zeros((batch_size, height // 2, width // 2, depth))
      conv = separable_conv2d(
          input1,
          3, [5, 5],
          stride=2,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          scope='test/test')
      node = math_ops.add(conv, input2, name='test/add')
      node = nn_ops.relu6(node, name='test/relu6')
      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
      # Check if output of bias add is quantized
      quantization_node_name = 'FakeQuantWithMinMaxVars'
      conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
                                               quantization_node_name)
      self.assertEqual(conv_quant.type, quantization_node_name)

      # Check if weights for both convs inside the separable conv are
      # quantized.
      pointwise_weight_quant = graph.get_operation_by_name(
          'test/test/weights_quant/' + quantization_node_name)
      self.assertEqual(pointwise_weight_quant.type, quantization_node_name)
      depthwise_weight_quant = graph.get_operation_by_name(
          'test/test/separable_conv2d/weights_quant/' + quantization_node_name)
      self.assertEqual(depthwise_weight_quant.type, quantization_node_name)

      # Check if activations after first depthwise conv are quantized.
      depthwise_act_quant = graph.get_operation_by_name(
          'test/test/separable_conv2d/act_quant/' + quantization_node_name)
      self.assertEqual(depthwise_act_quant.type, quantization_node_name)

      # Scan through all FakeQuant operations, ensuring that the activation
      # (test/relu6) isn't a consumer of any of them.
      self._AssertNotConsumedByQuantOps(graph, quantization_node_name,
                                        'test/relu6')

  def testLayerActivationQuantized(self):
    self._RunTestOverParameters(self._TestLayerActivationQuantized)

  def _TestLayerActivationQuantized(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      _ = conv2d(
          input1,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=nn_ops.relu6,
          biases_initializer=None,
          scope='test')
      # Ensure that both weights and output of activations are quantized
      # when we have a conv->relu6 with no bias add
      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
      activation_op = graph.get_operation_by_name('test/Relu6')
      conv_op = graph.get_operation_by_name('test/Conv2D')
      self.assertIn('test/weights_quant/FakeQuantWithMinMaxVars:0',
                    [tensor_in.name for tensor_in in conv_op.inputs])
      self.assertIn('FakeQuantWithMinMaxVars',
                    [op.type for op in activation_op.outputs[0].consumers()])

  def testFinalLayerQuantized(self):
    self._RunTestOverParameters(self._TestFinalLayerQuantized)

  def _TestFinalLayerQuantized(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      _ = conv2d(
          input1,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          scope='test')
      # Ensure that a FakeQuant operation is in the outputs of the BiasAdd.
      bias_add_op = graph.get_operation_by_name('test/BiasAdd')
      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
      self.assertIn('FakeQuantWithMinMaxVars',
                    [op.type for op in bias_add_op.outputs[0].consumers()])

  def testPostActivationBypassQuantized(self):
    self._RunTestOverParameters(self._TestPostActivationBypassQuantized)

  def _TestPostActivationBypassQuantized(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      input2 = array_ops.zeros((batch_size, height // 2, width // 2, 32))
      conv = conv2d(
          input1,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=nn_ops.relu6,
          scope='test/test')
      bypass_tensor = math_ops.add(conv, input2, name='test/add')
      # The output of the post_activation bypass will be another layer.
      _ = conv2d(
          bypass_tensor,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=nn_ops.relu6,
          scope='test/unused')

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

      # Ensure that the bypass node is preceded by and followed by a
      # FakeQuantWithMinMaxVars operation, since the output of the Add isn't
      # an activation.
      self.assertIn('FakeQuantWithMinMaxVars',
                    [c.type for c in bypass_tensor.consumers()])
      self.assertIn('FakeQuantWithMinMaxVars',
                    [i.op.type for i in bypass_tensor.op.inputs])

  def testOverlappingPostActivationBypassQuantized(self):
    self._RunTestOverParameters(
        self._TestOverlappingPostActivationBypassQuantized)

  def _TestOverlappingPostActivationBypassQuantized(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      conv_input = array_ops.zeros((batch_size, height, width, depth))
      conv1 = conv2d(
          conv_input,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=nn_ops.relu6,
          scope='test/test1')

      # The bypass of this conv is the post activation bypass of the previous
      # conv.
      conv2 = conv2d(
          conv_input,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          scope='test/test2')

      bypass_tensor = math_ops.add(conv1, conv2, name='test/add')
      _ = nn_ops.relu6(bypass_tensor, name='test/output')

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

      # Ensure that the bypass node is preceded by a FakeQuantWithMinMaxVars
      # operation, and NOT followed by one.
      self.assertNotIn('FakeQuantWithMinMaxVars',
                       [c.type for c in bypass_tensor.consumers()])
      self.assertIn('FakeQuantWithMinMaxVars',
                    [i.op.type for i in bypass_tensor.op.inputs])

      # Ensure that all the convs and activations are quantized.
      op_names = [op.name for op in graph.get_operations()]
      self.assertIn('test/test1/weights_quant/FakeQuantWithMinMaxVars',
                    op_names)
      self.assertIn('test/test2/weights_quant/FakeQuantWithMinMaxVars',
                    op_names)
      self.assertIn('test/test1/act_quant/FakeQuantWithMinMaxVars', op_names)
      self.assertIn('test/act_quant/FakeQuantWithMinMaxVars', op_names)
      self.assertEqual(
          'Relu6',
          graph.get_operation_by_name(
              'test/test1/act_quant/FakeQuantWithMinMaxVars').inputs[0].op.type)
      self.assertEqual(
          'Relu6',
          graph.get_operation_by_name(
              'test/act_quant/FakeQuantWithMinMaxVars').inputs[0].op.type)

  def testWithNameScope(self):
    self._RunTestOverParameters(self._TestWithNameScope)

  def _TestWithNameScope(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      with graph.name_scope('name_scope'):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros((batch_size, height, width, depth))
        _ = conv2d(
            input1,
            32, [5, 5],
            stride=2,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=None,
            scope='test')

        quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

    # Quantizing inside a name scope must not duplicate that scope in the
    # names of the inserted ops.
    for op in graph.get_operations():
      self.assertFalse(op.name.startswith('name_scope/name_scope/'),
                       'Broken op: %s' % op.name)

  def testWithNullNameScope(self):
    self._RunTestOverParameters(self._TestWithNullNameScope)

  def _TestWithNullNameScope(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      with graph.name_scope(None):
        batch_size, height, width, depth = 5, 128, 128, 32
        input1 = array_ops.zeros((batch_size, height, width, depth))
        _ = conv2d(
            input1,
            32, [5, 5],
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=None,
            scope='test')

        quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
        # Passes if Quantize() does not crash.

  def testWithNonMatchingNameScope(self):
    self._RunTestOverParameters(self._testWithNonMatchingNameScope)

  def _testWithNonMatchingNameScope(self, is_training):
    graph = ops.Graph()
    with graph.as_default():
      with graph.name_scope('name_scope'):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros((batch_size, height, width, depth))
        _ = conv2d(
            input1,
            32, [5, 5],
            stride=2,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=None,
            scope='test')

      op_names_before_quantize = set(op.name for op in graph.get_operations())
      quantize.Quantize(
          graph, is_training, weight_bits=8, activation_bits=8,
          scope='NonExisting/')
      op_names_after_quantize = set(op.name for op in graph.get_operations())

      # No ops should be inserted or removed.
      self.assertEqual(op_names_before_quantize, op_names_after_quantize)

  def testSinglePartitionedVariable(self):
    self._RunTestOverParameters(self._testSinglePartitionedVariable)

  def _testSinglePartitionedVariable(self, is_training):
    # When weights are partitioned into a single partition, the weights variable
    # is followed by an identity -> identity (an additional identity node).
    partitioner = partitioned_variables.fixed_size_partitioner(1)
    graph = ops.Graph()
    with graph.as_default():
      with variable_scope.variable_scope('part', partitioner=partitioner):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros((batch_size, height, width, depth))
        input2 = array_ops.zeros((batch_size, height // 2, width // 2, 32))
        conv = conv2d(
            input1,
            32, [5, 5],
            stride=2,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=None,
            scope='test/test')
        node = math_ops.add(conv, input2, name='test/add')
        node = nn_ops.relu6(node, name='test/relu6')

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
      # Check that the weight's quant node was added.
      op_names = [op.name for op in graph.get_operations()]
      self.assertIn('part/test/test/weights_quant/FakeQuantWithMinMaxVars',
                    op_names)

  def testMultiplePartitionedVariables(self):
    self._RunTestOverParameters(self._testMultiplePartitionedVariables)

  def _testMultiplePartitionedVariables(self, is_training):
    # When weights are partitioned into multiple partitions the weights variable
    # is followed by an identity -> concat -> identity to group the partitions.
    partitioner = partitioned_variables.fixed_size_partitioner(2)
    graph = ops.Graph()
    with graph.as_default():
      with variable_scope.variable_scope('part', partitioner=partitioner):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros((batch_size, height, width, depth))
        input2 = array_ops.zeros((batch_size, height // 2, width // 2, 32))
        conv = conv2d(
            input1,
            32, [5, 5],
            stride=2,
            padding='SAME',
            weights_initializer=self._WeightInit(0.09),
            activation_fn=None,
            scope='test/test')
        node = math_ops.add(conv, input2, name='test/add')
        node = nn_ops.relu6(node, name='test/relu6')

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
      # Check that the weight's quant node was added.
      op_names = [op.name for op in graph.get_operations()]
      self.assertIn('part/test/test/weights_quant/FakeQuantWithMinMaxVars',
                    op_names)

  def testSkipReshapeQuantization(self):
    self._RunTestOverParameters(self._TestSkipReshapeQuantization)

  def _TestSkipReshapeQuantization(self, is_training):
    # Case 1: a FakeQuant already follows the reshape, so none should be
    # inserted before it.
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      conv = conv2d(
          input1,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=nn_ops.relu6,
          scope='test/test')

      reshape = array_ops.reshape(conv, (10, height // 2, width // 2, 16))

      # Insert a fake quant node after the reshape. We will check that one
      # isn't inserted before.
      array_ops.fake_quant_with_min_max_vars(reshape, -1, 1)

      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

      # Ensure that there isn't a FakeQuant added before the reshape.
      self.assertNotIn('FakeQuantWithMinMaxVars',
                       [i.op.type for i in reshape.op.inputs])

    # Case 2: nothing follows the reshape, so a FakeQuant should be inserted
    # before it.
    graph = ops.Graph()
    with graph.as_default():
      batch_size, height, width, depth = 5, 128, 128, 3
      input1 = array_ops.zeros((batch_size, height, width, depth))
      conv = conv2d(
          input1,
          32, [5, 5],
          stride=2,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=nn_ops.relu6,
          scope='test/test')

      reshape = array_ops.reshape(conv, (10, height // 2, width // 2, 16))

      # If no fake quant is added after the reshape, a FakeQuant should be
      # added before the reshape.
      quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)

      # Ensure that a FakeQuant is added before the reshape.
      self.assertIn('FakeQuantWithMinMaxVars',
                    [i.op.type for i in reshape.op.inputs])

  def testSeparableConvWithResourceVar(self):
    graph = ops.Graph()
    with graph.as_default():
      with variable_scope.variable_scope('', use_resource=True):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros((batch_size, height, width, depth))
        kernel_size, depth_multiplier = 3, 1
        depthwise_shape = [kernel_size, kernel_size, depth, depth_multiplier]
        depthwise_weights = variables.model_variable(
            'depthwise_weights', shape=depthwise_shape)
        strides = [1, 1, 1, 1]
        with variable_scope.variable_scope('depthwise_conv_1'):
          conv1 = nn.depthwise_conv2d(
              input1, depthwise_weights, strides, padding='SAME')
        with variable_scope.variable_scope('depthwise_conv_2'):
          conv2 = nn.depthwise_conv2d(
              conv1, depthwise_weights, strides, padding='SAME')
          math_ops.add(conv2, input1, name='add')

      quantize.Quantize(graph, True)

      # Test that the weights and activations of all convs have been quantized.
      quant_node_name = 'FakeQuantWithMinMaxVars'
      weights_quant = graph.get_operation_by_name(
          'depthwise_conv_1/weights_quant/' + quant_node_name)
      self.assertEqual(weights_quant.type, quant_node_name)
      act_quant = graph.get_operation_by_name('depthwise_conv_1/act_quant/' +
                                              quant_node_name)
      self.assertEqual(act_quant.type, quant_node_name)

      weights_quant = graph.get_operation_by_name(
          'depthwise_conv_2/weights_quant/' + quant_node_name)
      self.assertEqual(weights_quant.type, quant_node_name)
      act_quant = graph.get_operation_by_name('depthwise_conv_2/act_quant/' +
                                              quant_node_name)
      self.assertEqual(act_quant.type, quant_node_name)

  def _AssertNotConsumedByQuantOps(self, graph, quant_op_type, consumer_name):
    """Asserts that no op of type `quant_op_type` feeds op `consumer_name`.

    Args:
      graph: Graph whose operations are scanned.
      quant_op_type: Op type identifying the quantization ops, e.g.
        'FakeQuantWithMinMaxVars'.
      consumer_name: Name of the op that must not consume any output of a
        quantization op.
    """
    for op in graph.get_operations():
      if op.type != quant_op_type:
        continue
      consumer_names = [
          consumer.name
          for output in op.outputs
          for consumer in output.consumers()
      ]
      self.assertNotIn(consumer_name, consumer_names)

  def _WeightInit(self, stddev):
    """Returns truncated normal variable initializer.

    Function is defined purely to shorten the name so that it stops wrapping.

    Args:
      stddev: Standard deviation of normal variable.

    Returns:
      An initializer that initializes with a truncated normal variable.
    """
    return init_ops.truncated_normal_initializer(stddev=stddev)


if __name__ == '__main__':
  googletest.main()