# Source code for pennylane.optimize.nesterov_momentum

# Copyright 2018 Xanadu Quantum Technologies Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Nesterov momentum optimizer"""
from pennylane.utils import _flatten, unflatten
from .momentum import MomentumOptimizer

class NesterovMomentumOptimizer(MomentumOptimizer):
    r"""Gradient-descent optimizer with Nesterov momentum.

Nesterov Momentum works like the :class:`Momentum optimizer <.pennylane.optimize.MomentumOptimizer>`,
but shifts the current input by the momentum term when computing the gradient of the objective function:

.. math:: a^{(t+1)} = m a^{(t)} + \eta \nabla f(x^{(t)} - m a^{(t)}).

The user defined parameters are:

* :math:`\eta`: the step size
* :math:`m`: the momentum

Args:
stepsize (float): user-defined hyperparameter :math:`\eta`
momentum (float): user-defined hyperparameter :math:`m`
"""
r"""Compute gradient of the objective_fn at
the shifted point :math:`(x - m\times\text{accumulation})`.

Args:
objective_fn (function): the objective function for optimization
x (array): NumPy array containing the current values of the variables to be updated
grad_fn (function): Optional gradient function of the
    objective function with respect to the variables x.
    If None, the gradient function is computed automatically.

Returns:
array: NumPy array containing the gradient :math:`\nabla f(x^{(t)})`
"""

x_flat = _flatten(x)

if self.accumulation is None:
shifted_x_flat = list(x_flat)
else:
shifted_x_flat = [e - self.momentum * a for a, e in zip(self.accumulation, x_flat)]

shifted_x = unflatten(shifted_x_flat, x)