# Copyright 2018 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pennylane.utils import _flatten, unflatten

Base class for other gradient-descent-based optimizers.

A step of the gradient descent optimizer computes the new values via the rule

.. math::

    x^{(t+1)} = x^{(t)} - \eta \nabla f(x^{(t)}),

where :math:`\eta` is a user-defined hyperparameter corresponding to step size.

Args:
    stepsize (float): the user-defined hyperparameter :math:`\eta`
"""
def __init__(self, stepsize=0.01):
    r"""Create a gradient-descent optimizer with a fixed step size.

    Args:
        stepsize (float): the user-defined hyperparameter :math:`\eta`
            used by the update rule
    """
    # Stored privately; callers adjust it through update_stepsize().
    self._stepsize = stepsize

def update_stepsize(self, stepsize):
    r"""Update the initialized stepsize value :math:`\eta`.

    This allows for techniques such as learning rate scheduling.

    Args:
        stepsize (float): the user-defined hyperparameter :math:`\eta`
    """
    # Takes effect on the next call to step().
    self._stepsize = stepsize

def step(self, objective_fn, x, grad_fn=None):
    r"""Update x with one step of the optimizer.

    Args:
        objective_fn (function): the objective function for optimization
        x (array): NumPy array containing the current values of the variables
            to be updated
        grad_fn (function): optional gradient function of the
            objective function with respect to the variables x.
            If None, the gradient function is computed automatically.

    Returns:
        array: the new variable values :math:`x^{(t+1)}`
    """
    # Gradient of the objective at the current point (delegates the choice
    # between a user-supplied grad_fn and automatic differentiation).
    g = self.compute_grad(objective_fn, x, grad_fn=grad_fn)

    # Apply the gradient-descent update rule x - eta * g.
    x_out = self.apply_grad(g, x)

    return x_out

# NOTE(review): this block was damaged during documentation extraction.
# The "[docs]" token below is a Sphinx viewcode HTML artifact (not valid
# Python), the `def compute_grad(objective_fn, x, grad_fn=None):` signature
# line is missing, and the `if grad_fn is not None:` branch that the dangling
# `else:` belongs to was lost. Restore the full method from upstream before
# using this file — presumably the missing branch calls the user-supplied
# grad_fn, and the else-branch computes the gradient automatically (TODO
# confirm against the upstream source).
[docs]    @staticmethod
r"""Compute gradient of the objective_fn at the point x.

Args:
objective_fn (function): the objective function for optimization
x (array): NumPy array containing the current values of the variables to be updated
objective function with respect to the variables x.
If None, the gradient function is computed automatically.

Returns:
array: NumPy array containing the gradient :math:\nabla f(x^{(t)})
"""
# Dangling branch — its `if` counterpart was lost in extraction (see note above).
else:
return g

r"""Update the variables x to take a single optimization step. Flattens and unflattens
the inputs to maintain nested iterables as the parameters of the optimization.

Args:
function at point :math:x^{(t)}: :math:\nabla f(x^{(t)})
x (array): the current value of the variables :math:x^{(t)}

Returns:
array: the new values :math:x^{(t+1)}
"""

x_flat = _flatten(x)