# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Signal reconstruction via overlapped addition of frames."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("signal.overlap_and_add")
def overlap_and_add(signal, frame_step, name=None):
  """Reconstructs a signal from a framed representation.

  Adds potentially overlapping frames of a signal with shape
  `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
  The resulting tensor has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar integer `Tensor` denoting overlap
      offsets. Must be positive (it is used as a divisor) and less than or
      equal to `frame_length`. Any integer dtype is accepted; it is cast to
      the dtype of `signal`'s dynamic shape before use.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, or `frame_step` is not a
      scalar integer.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)

    # Read the static value (if any) from the caller-supplied tensor before
    # casting; it drives the no-overlap fast path below.
    frame_step_static = tensor_util.constant_value(frame_step)

    signal_shape = array_ops.shape(signal)

    # `frame_step` is combined arithmetically with entries of `signal_shape`
    # (e.g. `frame_step * (frames - 1)`), and TensorFlow binary ops require
    # both operands to share a dtype. Without this cast, an integer dtype
    # other than the shape's (e.g. int64) passes the check above but then
    # fails inside the arithmetic.
    frame_step = math_ops.cast(frame_step, signal_shape.dtype)

    # All dimensions that are not part of the overlap-and-add. Can be empty for
    # rank 2 inputs.
    outer_dimensions = signal_shape[:-2]
    outer_rank = array_ops.size(outer_dimensions)

    def full_shape(inner_shape):
      # Prepends the untouched outer dimensions to `inner_shape`.
      return array_ops.concat([outer_dimensions, inner_shape], 0)

    frame_length = signal_shape[-1]
    frames = signal_shape[-2]

    # Compute output length.
    output_length = frame_length + frame_step * (frames - 1)

    # When frame_step is statically known and equals the static frame_length,
    # the frames do not overlap, so laying them end to end is a plain reshape.
    if (frame_step_static is not None and signal.shape.dims is not None and
        frame_step_static == signal.shape.dims[-1].value):
      output_shape = full_shape([output_length])
      return array_ops.reshape(signal, output_shape, name="fast_path")

    # The following code is documented using this example:
    #
    # frame_step = 2
    # signal.shape = (3, 5)
    # a b c d e
    # f g h i j
    # k l m n o

    # Compute the number of segments, per frame.
    segments = -(-frame_length // frame_step)  # Divide and round up.

    # Pad the frame_length dimension to a multiple of the frame step.
    # Pad the frames dimension by `segments` so that signal.shape = (6, 6)
    # a b c d e 0
    # f g h i j 0
    # k l m n o 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    paddings = [[0, segments], [0, segments * frame_step - frame_length]]
    outer_paddings = array_ops.zeros([outer_rank, 2], dtypes.int32)
    paddings = array_ops.concat([outer_paddings, paddings], 0)
    signal = array_ops.pad(signal, paddings)

    # Reshape so that signal.shape = (6, 3, 2)
    # ab cd e0
    # fg hi j0
    # kl mn o0
    # 00 00 00
    # 00 00 00
    # 00 00 00
    shape = full_shape([frames + segments, segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Transpose dimensions so that signal.shape = (3, 6, 2)
    # ab fg kl 00 00 00
    # cd hi mn 00 00 00
    # e0 j0 o0 00 00 00
    perm = array_ops.concat(
        [math_ops.range(outer_rank), outer_rank + [1, 0, 2]], 0)
    signal = array_ops.transpose(signal, perm)

    # Reshape so that signal.shape = (18, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 00 00 00
    shape = full_shape([(frames + segments) * segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate so that signal.shape = (15, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0
    signal = signal[..., :(frames + segments - 1) * segments, :]

    # Reshape so that signal.shape = (3, 5, 2)
    # ab fg kl 00 00
    # 00 cd hi mn 00
    # 00 00 e0 j0 o0
    shape = full_shape([segments, (frames + segments - 1), frame_step])
    signal = array_ops.reshape(signal, shape)

    # Now, reduce over the columns, to achieve the desired sum.
    # signal.shape = (5, 2)
    signal = math_ops.reduce_sum(signal, -3)

    # Flatten the array so that signal.shape = (10,)
    shape = full_shape([(frames + segments - 1) * frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate to final length: output_length = 5 + 2 * (3 - 1) = 9.
    signal = signal[..., :output_length]

    return signal