# Source code for pennylane.optimize.momentum

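# Copyright 2018 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.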
"""Momentum optimizer"""
from pennylane.utils import _flatten, unflatten
from .gradient_descent import GradientDescentOptimizer

# NOTE(review): the class header and the ``apply_grad`` signature were missing
# from this copy of the module; both were restored from the public PennyLane
# API -- confirm against upstream before merging.
class MomentumOptimizer(GradientDescentOptimizer):
    r"""Gradient-descent optimizer with momentum.

    Each step subtracts an accumulator term from the current parameters:

    .. math:: x^{(t+1)} = x^{(t)} - a^{(t+1)}.

    The accumulator term :math:`a` is updated as follows:

    .. math:: a^{(t+1)} = m a^{(t)} + \eta \nabla f(x^{(t)}),

    with user defined parameters:

    * :math:`\eta`: the step size
    * :math:`m`: the momentum

    Args:
        stepsize (float): user-defined hyperparameter :math:`\eta`
        momentum (float): user-defined hyperparameter :math:`m`
    """

    def __init__(self, stepsize=0.01, momentum=0.9):
        super().__init__(stepsize)
        self.momentum = momentum
        # No accumulator exists until the first step is taken (or after
        # ``reset``); ``apply_grad`` checks for ``None`` to detect this.
        self.accumulation = None

    def apply_grad(self, grad, x):
        r"""Update the variables x to take a single optimization step. Flattens and unflattens
        the inputs to maintain nested iterables as the parameters of the optimization.

        Args:
            grad (array): the gradient of the objective
                function at point :math:`x^{(t)}`: :math:`\nabla f(x^{(t)})`
            x (array): the current value of the variables :math:`x^{(t)}`

        Returns:
            array: the new values :math:`x^{(t+1)}`
        """
        x_flat = _flatten(x)
        # Flatten the gradient the same way as ``x`` so that the entries can
        # be paired up element by element below.
        grad_flat = list(_flatten(grad))

        # ``self._stepsize`` is not assigned in this class; it is expected to
        # be set by the base-class initializer called from ``__init__``.
        if self.accumulation is None:
            # First step: there is no previous accumulator to carry over.
            self.accumulation = [self._stepsize * g for g in grad_flat]
        else:
            self.accumulation = [
                self.momentum * a + self._stepsize * g
                for a, g in zip(self.accumulation, grad_flat)
            ]

        x_new_flat = [e - a for a, e in zip(self.accumulation, x_flat)]

        # Restore the nesting of ``x`` on the updated values.
        return unflatten(x_new_flat, x)

    def reset(self):
        """Reset optimizer by erasing memory of past steps."""
        self.accumulation = None