1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Matrix functions: iterative methods for computing M^p."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.ops import control_flow_ops
22from tensorflow.python.ops import linalg_ops
23from tensorflow.python.ops import math_ops
24
25
def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4):
  """Approximates the square root of a symmetric PSD matrix iteratively.

  Implements the coupled iteration from "Stable iterations for the matrix
  square root" by Nicholas J. Higham (page 231, Eq 2.6b):
  http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf

  The input is damped with `ridge_epsilon * I` and scaled by its Frobenius
  norm before iterating. The loop stops after `iter_count` steps or as soon
  as the approximation error stops decreasing; the result is built from the
  Y iterate that preceded the last update.

  Args:
    mat_a: the symmetric PSD matrix whose square root is to be computed.
    mat_a_size: size of mat_a.
    iter_count: Maximum number of iterations.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    An approximation of (mat_a + ridge_epsilon * I)^0.5.
  """
  eye = linalg_ops.eye(math_ops.cast(mat_a_size, dtypes.int32))
  damped = mat_a + ridge_epsilon * eye
  # Frobenius norm of the damped input; used to scale the starting iterate.
  scale = math_ops.sqrt(math_ops.reduce_sum(damped * damped))

  def _keep_going(step, unused_y, unused_prev_y, unused_z, unused_prev_z, err,
                  prev_err):
    # Divergence has to be checked on every step for this method.
    within_budget = step < iter_count
    still_improving = err < prev_err
    return math_ops.logical_and(within_budget, still_improving)

  def _advance(step, y, unused_prev_y, z, unused_prev_z, err,
               unused_prev_err):
    correction = 0.5 * (3.0 * eye - math_ops.matmul(z, y))
    next_y = math_ops.matmul(y, correction)
    next_z = math_ops.matmul(correction, z)
    # Error of the rescaled iterate: ||damped - root^2||_F / ||damped||_F.
    root_estimate = next_y * math_ops.sqrt(scale)
    residual = damped - math_ops.matmul(root_estimate, root_estimate)
    next_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / scale
    # The current y, z and err are carried along as the "previous" values.
    return step + 1, next_y, y, next_z, z, next_err, err

  y_0 = damped / scale
  z_0 = eye
  err_0 = scale
  # The initial "previous error" is larger than the initial error so that the
  # divergence check passes on entry.
  loop_state = [0, y_0, y_0, z_0, z_0, err_0, err_0 + 1.0]
  final_state = control_flow_ops.while_loop(_keep_going, _advance, loop_state)

  # Slot 2 holds the Y iterate from before the last update.
  prev_y = final_state[2]
  return prev_y * math_ops.sqrt(scale)
74
75
def matrix_inverse_pth_root(mat_g,
                            mat_g_size,
                            alpha,
                            iter_count=100,
                            epsilon=1e-6,
                            ridge_epsilon=1e-6):
  """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.

  We use an iterative Schur-Newton method from equation 3.2 on page 9 of:

  A Schur-Newton Method for the Matrix p-th Root and its Inverse
  by Chun-Hua Guo and Nicholas J. Higham
  SIAM Journal on Matrix Analysis and Applications,
  2006, Vol. 28, No. 3 : pp. 788-804
  https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf

  When mat_g_size == 1 the result is computed directly with an element-wise
  power of (mat_g + ridge_epsilon) and alpha is not validated.

  Args:
    mat_g: the symmetric PSD matrix whose power is to be computed
    mat_g_size: size of mat_g.
    alpha: exponent, must be -1/p for p a positive integer.
    iter_count: Maximum number of iterations.
    epsilon: accuracy indicator, useful for early termination.
    ridge_epsilon: Ridge epsilon added to make the matrix positive definite.

  Returns:
    mat_g^alpha

  Raises:
    AssertionError: if mat_g_size != 1 and -1/alpha is not, up to floating
      point rounding, a positive integer.
  """

  identity = linalg_ops.eye(math_ops.cast(mat_g_size, dtypes.int32))

  if mat_g_size == 1:
    return math_ops.pow(mat_g + ridge_epsilon, alpha)

  # -1.0 / alpha is not always exactly integral in floating point even when
  # alpha was built as -1.0 / p (e.g. p = 49 yields 49.00000000000001), so
  # round to the nearest integer and only require the quotient to be close.
  inverse_alpha = -1.0 / alpha
  root_p = int(round(inverse_alpha))
  assert root_p > 0 and abs(inverse_alpha - root_p) <= 1e-6 * root_p, (
      "alpha must be -1/p for a positive integer p, got %r" % (alpha,))

  def mat_power(mat_m, p):
    """Computes mat_m^p, for p a positive integer.

    Power p is known at graph compile time, so no need for loop and cond.
    Uses binary exponentiation (repeated squaring).

    Args:
      mat_m: a square matrix
      p: a positive integer

    Returns:
      mat_m^p
    """
    assert p == int(p) and p > 0
    power = None
    while p > 0:
      if p % 2 == 1:
        power = math_ops.matmul(mat_m, power) if power is not None else mat_m
      p //= 2
      # Only square when another bit of p remains; the square computed after
      # the last bit would never be used.
      if p > 0:
        mat_m = math_ops.matmul(mat_m, mat_m)
    return power

  def _iter_condition(i, mat_m, _):
    # Continue while under the iteration budget and mat_m is still further
    # than epsilon (max absolute entry difference) from the identity.
    return math_ops.logical_and(
        i < iter_count,
        math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)

  def _iter_body(i, mat_m, mat_x):
    mat_m_i = (1 - alpha) * identity + alpha * mat_m
    return (i + 1, math_ops.matmul(mat_power(mat_m_i, root_p), mat_m),
            math_ops.matmul(mat_x, mat_m_i))

  damped_mat_g = mat_g + ridge_epsilon * identity
  z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
  # The best value for z is
  # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
  #                 (c_max^{1-alpha} - c_min^{1-alpha})
  # where c_max and c_min are the largest and smallest singular values of
  # damped_mat_g.
  # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
  # Can replace above line by the one below, but it is less accurate,
  # hence needs more iterations to converge.
  # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
  # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
  # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
  # extra iterations.
  _, _, mat_h = control_flow_ops.while_loop(
      _iter_condition, _iter_body,
      [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
  return mat_h
157