1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Signal reconstruction via overlapped addition of frames."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.util.tf_export import tf_export
27
28
@tf_export("signal.overlap_and_add")
def overlap_and_add(signal, frame_step, name=None):
  """Reconstructs a signal from a framed representation.

  Adds potentially overlapping frames of a signal with shape
  `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
  The resulting tensor has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be
      less than or equal to `frame_length`. Any integer dtype is accepted; the
      value is cast to the dtype of `signal`'s dynamic shape (int32) so it can
      take part in the shape arithmetic.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, or `frame_step` is not a
      scalar integer.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)

    # Capture the statically known value (if any) before casting; it is only
    # used for the no-overlap fast path below.
    frame_step_static = tensor_util.constant_value(frame_step)

    signal_shape = array_ops.shape(signal)

    # All of the shape arithmetic below is done in the dtype of `signal_shape`
    # (int32 by default). The check above admits any integer dtype, so cast
    # `frame_step` to match; otherwise e.g. an int64 `frame_step` fails at the
    # first mixed-dtype multiply. `cast` is a no-op when the dtypes agree.
    frame_step = math_ops.cast(frame_step, signal_shape.dtype)

    # All dimensions that are not part of the overlap-and-add. Can be empty for
    # rank 2 inputs.
    outer_dimensions = signal_shape[:-2]
    outer_rank = array_ops.size(outer_dimensions)

    def full_shape(inner_shape):
      # Prepend the (possibly empty) outer dimensions to `inner_shape`.
      return array_ops.concat([outer_dimensions, inner_shape], 0)

    frame_length = signal_shape[-1]
    frames = signal_shape[-2]

    # Compute output length.
    output_length = frame_length + frame_step * (frames - 1)

    # If frame_length is equal to frame_step, there's no overlap so just
    # reshape the tensor.
    if (frame_step_static is not None and signal.shape.dims is not None and
        frame_step_static == signal.shape.dims[-1].value):
      output_shape = full_shape([output_length])
      return array_ops.reshape(signal, output_shape, name="fast_path")

    # The following code is documented using this example:
    #
    # frame_step = 2
    # signal.shape = (3, 5), i.e. frames = 3, frame_length = 5
    # a b c d e
    # f g h i j
    # k l m n o

    # Compute the number of segments, per frame (here ceil(5 / 2) = 3).
    segments = -(-frame_length // frame_step)  # Divide and round up.

    # Pad the frame_length dimension to a multiple of the frame step.
    # Pad the frames dimension by `segments` so that signal.shape = (6, 6)
    # a b c d e 0
    # f g h i j 0
    # k l m n o 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    # 0 0 0 0 0 0
    paddings = [[0, segments], [0, segments * frame_step - frame_length]]
    outer_paddings = array_ops.zeros([outer_rank, 2], dtypes.int32)
    paddings = array_ops.concat([outer_paddings, paddings], 0)
    signal = array_ops.pad(signal, paddings)

    # Reshape so that signal.shape = (6, 3, 2)
    # ab cd e0
    # fg hi j0
    # kl mn o0
    # 00 00 00
    # 00 00 00
    # 00 00 00
    shape = full_shape([frames + segments, segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Transpose dimensions so that signal.shape = (3, 6, 2)
    # ab fg kl 00 00 00
    # cd hi mn 00 00 00
    # e0 j0 o0 00 00 00
    perm = array_ops.concat(
        [math_ops.range(outer_rank), outer_rank + [1, 0, 2]], 0)
    signal = array_ops.transpose(signal, perm)

    # Reshape so that signal.shape = (18, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0 00 00 00
    shape = full_shape([(frames + segments) * segments, frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate so that signal.shape = (15, 2)
    # ab fg kl 00 00 00 cd hi mn 00 00 00 e0 j0 o0
    signal = signal[..., :(frames + segments - 1) * segments, :]

    # Reshape so that signal.shape = (3, 5, 2); each row is now shifted one
    # step further right than the row above it.
    # ab fg kl 00 00
    # 00 cd hi mn 00
    # 00 00 e0 j0 o0
    shape = full_shape([segments, (frames + segments - 1), frame_step])
    signal = array_ops.reshape(signal, shape)

    # Now, reduce over the segments axis (the rows above) to achieve the
    # desired sum: signal.shape = (5, 2).
    signal = math_ops.reduce_sum(signal, -3)

    # Flatten the array: signal.shape = (10,).
    shape = full_shape([(frames + segments - 1) * frame_step])
    signal = array_ops.reshape(signal, shape)

    # Truncate to final length (here 5 + 2 * (3 - 1) = 9).
    signal = signal[..., :output_length]

    return signal
156