qml.transforms.broadcast_expand
broadcast_expand(tape)
Expand a broadcasted tape into multiple tapes and a function that stacks and squeezes the results.
- Parameters
tape (QuantumTape) – Broadcasted tape to be expanded
- Returns
A tuple containing a list of quantum tapes, each of which produces one of the results of the broadcasted tape, and a function that stacks and squeezes the tape execution results.
- Return type
tuple[list[QuantumTape], function]
This expansion function is used internally whenever a device does not support broadcasting.
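The idea behind the expansion can be sketched without PennyLane. The plain-Python example below is a minimal illustration under stated assumptions, not PennyLane's implementation: the names `expand_batch` and `expval_z` are hypothetical, a broadcasted parameter is modeled as a flat list, and the stand-in "device" is the closed-form expectation value of Z after RX(x) followed by RY(y), which only accepts one parameter set at a time.

```python
import math
from typing import List, Sequence, Union

# Hypothetical model: a parameter is a scalar, or a flat list of scalars
# when it is broadcasted.
Param = Union[float, List[float]]


def expand_batch(params: Sequence[Param]) -> List[List[float]]:
    """Split broadcasted parameters into one unbatched parameter list per batch index.

    Lists are treated as broadcasted along their single axis; scalars are
    copied unchanged into every unbatched parameter list.
    """
    sizes = {len(p) for p in params if isinstance(p, list)}
    if len(sizes) != 1:
        raise ValueError("expected broadcasted parameters sharing one axis length")
    (batch_size,) = sizes
    return [
        [p[i] if isinstance(p, list) else p for p in params]
        for i in range(batch_size)
    ]


def expval_z(x: float, y: float) -> float:
    """Closed-form <Z> after RX(x) then RY(y) on |0>.

    Stands in for a device that can only run unbatched circuits.
    """
    return math.cos(x) * math.cos(y)


# One broadcasted parameter (axis length 3) and one shared scalar.
batch = expand_batch([[0.2, 0.6, 1.0], 0.5])
# One "execution" per unbatched parameter set, as a non-broadcasting device would do.
results = [expval_z(*p) for p in batch]
```

Here `batch` holds three unbatched parameter sets, `[0.2, 0.5]`, `[0.6, 0.5]` and `[1.0, 0.5]`, so the scalar is repeated while the broadcasted list is sliced, and `results` has one entry per set.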
Example
We may use broadcast_expand on a QNode to separate it into multiple calculations. For this, we provide qml.RX with the ndim_params attribute, which allows the operation to detect broadcasting, and set up a simple QNode with a single operation and a returned expectation value:
>>> import pennylane as qml
>>> from pennylane import numpy as pnp
>>> qml.RX.ndim_params = (0,)
>>> dev = qml.device("default.qubit", wires=1)
>>> @qml.qnode(dev)
... def circuit(x):
...     qml.RX(x, wires=0)
...     return qml.expval(qml.PauliZ(0))
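The ndim_params attribute records how many dimensions each parameter of the operation normally has (0 for the scalar angle of qml.RX). As we understand the PennyLane Operator documentation, a parameter with one extra leading dimension marks the operation as broadcasted, and the length of that axis, which must agree across all broadcasted parameters, is the batch size. The pure-Python sketch below illustrates that rule on nested lists; `ndim` and `infer_batch_size` are hypothetical helpers, not PennyLane functions.

```python
from typing import Any, Optional, Sequence


def ndim(x: Any) -> int:
    """Number of nesting levels of a (rectangular) nested list; 0 for a scalar."""
    depth = 0
    while isinstance(x, (list, tuple)):
        depth += 1
        x = x[0] if x else None
    return depth


def infer_batch_size(params: Sequence[Any], ndim_params: Sequence[int]) -> Optional[int]:
    """Length of the broadcasting axis, or None if no parameter is broadcasted.

    A parameter counts as broadcasted when it has exactly one more dimension
    than its entry in ndim_params; all broadcasted parameters must agree on
    the length of that extra leading axis.
    """
    sizes = set()
    for p, expected in zip(params, ndim_params):
        actual = ndim(p)
        if actual == expected + 1:
            sizes.add(len(p))
        elif actual != expected:
            raise ValueError(f"got {actual} dims, expected {expected} or {expected + 1}")
    if len(sizes) > 1:
        raise ValueError(f"inconsistent broadcasting axis lengths: {sorted(sizes)}")
    return sizes.pop() if sizes else None
```

With ndim_params = (0,), a scalar angle such as `0.2` yields no batch size, while `[0.2, 0.6, 1.0]` has one extra dimension and yields a batch size of 3.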
We can then call broadcast_expand on the QNode and store the expanded QNode:
>>> expanded_circuit = qml.transforms.broadcast_expand(circuit)
Let’s use the expanded QNode and draw it for a broadcasted parameter, with a broadcasting axis of length 3, passed to qml.RX:
>>> x = pnp.array([0.2, 0.6, 1.0], requires_grad=True)
>>> print(qml.draw(expanded_circuit)(x))
0: ──RX(0.20)─┤ <Z>
0: ──RX(0.60)─┤ <Z>
0: ──RX(1.00)─┤ <Z>
Executing the expanded QNode results in three values, corresponding to the three parameters in the broadcasted input x:
>>> expanded_circuit(x)
tensor([0.98006658, 0.82533561, 0.54030231], requires_grad=True)
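These numbers can be checked by hand: RX(x) maps |0⟩ to cos(x/2)|0⟩ − i sin(x/2)|1⟩, so ⟨Z⟩ = cos²(x/2) − sin²(x/2) = cos x, and the three values are cos 0.2, cos 0.6 and cos 1.0. The plain-Python sketch below reproduces them from the RX matrix, independently of PennyLane; the helper names are our own.

```python
import math
from typing import List

Matrix = List[List[complex]]


def rx_matrix(theta: float) -> Matrix:
    """Matrix of RX(theta) = exp(-i * theta * X / 2)."""
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[complex(c, 0.0), complex(0.0, -s)],
            [complex(0.0, -s), complex(c, 0.0)]]


def expval_z_after_rx(theta: float) -> float:
    """<Z> for the state RX(theta)|0>, via an explicit matrix-vector product."""
    m = rx_matrix(theta)
    ket0 = [1 + 0j, 0j]
    state = [sum(m[r][k] * ket0[k] for k in range(2)) for r in range(2)]
    # Z = diag(1, -1), so <Z> = |state[0]|^2 - |state[1]|^2, which equals cos(theta).
    return abs(state[0]) ** 2 - abs(state[1]) ** 2


# One unbatched evaluation per entry of the broadcasted input.
values = [expval_z_after_rx(t) for t in (0.2, 0.6, 1.0)]
```

Rounded to eight decimals, `values` matches the tensor returned by the expanded QNode.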
We can also call the transform manually on a tape:
>>> with qml.tape.QuantumTape() as tape:
...     qml.RX(pnp.array([0.2, 0.6, 1.0], requires_grad=True), wires=0)
...     qml.expval(qml.PauliZ(0))
>>> tapes, fn = qml.transforms.broadcast_expand(tape)
>>> tapes
[<QuantumTape: wires=[0], params=1>, <QuantumTape: wires=[0], params=1>, <QuantumTape: wires=[0], params=1>]
>>> fn(qml.execute(tapes, qml.device("default.qubit", wires=1), None))
array([0.98006658, 0.82533561, 0.54030231])
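The post-processing function `fn` combines the per-tape results, which can be pictured as stacking them along a new leading axis and then removing axes of length 1. The sketch below mimics that on nested Python lists; `squeeze` and `stack_and_squeeze` are hypothetical helpers, and PennyLane itself operates on array-like results rather than lists.

```python
from typing import Any, List


def squeeze(x: Any) -> Any:
    """Drop every nesting level of length 1, like numpy.squeeze on a rectangular nested list."""
    if isinstance(x, list):
        if len(x) == 1:
            return squeeze(x[0])
        return [squeeze(e) for e in x]
    return x


def stack_and_squeeze(results: List[Any]) -> Any:
    """Stack per-tape results along a new leading axis, then squeeze."""
    stacked = list(results)  # leading axis has one entry per tape
    return squeeze(stacked)


# Three tapes, each assumed here to return a length-1 result.
combined = stack_and_squeeze([[0.98006658], [0.82533561], [0.54030231]])
```

Here `combined` is the flat list `[0.98006658, 0.82533561, 0.54030231]`. Note that in this sketch a batch of size 1 would also be squeezed away, leaving a bare scalar.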