Quantum Generative Adversarial Network

This demo constructs a Quantum Generative Adversarial Network (QGAN) (Lloyd and Weedbrook (2018), Dallaire-Demers and Killoran (2018)) using two subcircuits, a generator and a discriminator. The generator attempts to generate synthetic quantum data that matches a pattern of “real” data, while the discriminator tries to discern real data from fake data; see the image below. The gradient of the discriminator’s output provides a training signal for the generator to improve the data it generates.


[Figure: schematic of the QGAN, showing the generator and discriminator circuits in an adversarial loop (qgan.png)]

Imports

# As usual, we import PennyLane, the PennyLane-provided version of NumPy,
# and an optimizer.

import pennylane as qml
from pennylane import numpy as np
from pennylane.optimize import GradientDescentOptimizer

We also declare a 3-qubit device.

dev = qml.device("default.qubit", wires=3)

Classical and quantum nodes

In classical GANs, the starting point is to draw samples either from some “real data” distribution or from the generator, and to feed them to the discriminator. In this QGAN example, we will use a quantum circuit to generate the real data.

For this simple example, our real data will be a qubit that has been rotated (from the starting state \(\left|0\right\rangle\)) to some arbitrary, but fixed, state.
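Concretely, qml.Rot implements the general single-qubit rotation \(R(\phi, \theta, \omega) = RZ(\omega)\,RY(\theta)\,RZ(\phi)\), so three fixed angles suffice to reach an arbitrary qubit state from \(\left|0\right\rangle\) (up to a global phase).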

def real(phi, theta, omega):
    qml.Rot(phi, theta, omega, wires=0)

For the generator and discriminator, we will choose the same basic circuit structure, but acting on different wires.

Both the real data circuit and the generator will output on wire 0, which will be connected as an input to the discriminator. Wire 1 is provided as a workspace for the generator, while the discriminator’s output will be on wire 2.

def generator(w):
    # trainable rotations on wire 0 (the data qubit) and wire 1 (the workspace)
    qml.RX(w[0], wires=0)
    qml.RX(w[1], wires=1)
    qml.RY(w[2], wires=0)
    qml.RY(w[3], wires=1)
    qml.RZ(w[4], wires=0)
    qml.RZ(w[5], wires=1)
    # entangle the data qubit with the workspace
    qml.CNOT(wires=[0, 1])
    # final rotations on the data qubit, which carries the generated state
    qml.RX(w[6], wires=0)
    qml.RY(w[7], wires=0)
    qml.RZ(w[8], wires=0)


def discriminator(w):
    # trainable rotations on wire 0 (the input data) and wire 2 (the output qubit)
    qml.RX(w[0], wires=0)
    qml.RX(w[1], wires=2)
    qml.RY(w[2], wires=0)
    qml.RY(w[3], wires=2)
    qml.RZ(w[4], wires=0)
    qml.RZ(w[5], wires=2)
    # couple the generator's workspace wire to the output qubit
    qml.CNOT(wires=[1, 2])
    # final rotations on the output qubit, which is measured for the verdict
    qml.RX(w[6], wires=2)
    qml.RY(w[7], wires=2)
    qml.RZ(w[8], wires=2)

We create two QNodes: one where the real data source is wired up to the discriminator, and one where the generator is connected to the discriminator.

@qml.qnode(dev)
def real_disc_circuit(phi, theta, omega, disc_weights):
    real(phi, theta, omega)
    discriminator(disc_weights)
    return qml.expval(qml.PauliZ(2))


@qml.qnode(dev)
def gen_disc_circuit(gen_weights, disc_weights):
    generator(gen_weights)
    discriminator(disc_weights)
    return qml.expval(qml.PauliZ(2))
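
To verify the wiring, it can be helpful to render one of these circuits. This is an optional sketch; it assumes qml.draw is available in the installed PennyLane version, and the all-zero weights are placeholders:

print(qml.draw(gen_disc_circuit)(np.zeros(9), np.zeros(9)))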

Cost

There are two ingredients to the cost here. The first is the probability that the discriminator correctly classifies real data as real. The second ingredient is the probability that the discriminator classifies fake data (i.e., a state prepared by the generator) as real.

The discriminator’s objective is to maximize the probability of correctly classifying real data as real, while minimizing the probability of mistakenly classifying fake data as real.

The generator’s objective is to maximize the probability that the discriminator accepts fake data as real.
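
Both QNodes return the expectation value \(\langle Z \rangle \in [-1, 1]\) of the Pauli-Z operator on wire 2. The cost functions below convert this to a classification probability via

\(P(\text{real}) = \frac{\langle Z \rangle + 1}{2},\)

which is the probability of measuring the discriminator’s output qubit in the \(\left|0\right\rangle\) state.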

def prob_real_true(disc_weights):
    true_disc_output = real_disc_circuit(phi, theta, omega, disc_weights)
    # convert to probability
    prob_real_true = (true_disc_output + 1) / 2
    return prob_real_true


def prob_fake_true(gen_weights, disc_weights):
    fake_disc_output = gen_disc_circuit(gen_weights, disc_weights)
    # convert to probability
    prob_fake_true = (fake_disc_output + 1) / 2
    return prob_fake_true  # generator wants to minimize this prob


def disc_cost(disc_weights):
    # gen_weights is taken from the enclosing scope and held fixed during this update
    cost = prob_fake_true(gen_weights, disc_weights) - prob_real_true(disc_weights)
    return cost


def gen_cost(gen_weights):
    # disc_weights is taken from the enclosing scope and held fixed during this update
    return -prob_fake_true(gen_weights, disc_weights)

Optimization

We initialize the fixed angles of the “real data” circuit, as well as the initial parameters for both generator and discriminator. These are chosen so that the generator initially prepares a state on wire 0 that is very close to the \(\left| 1 \right\rangle\) state.

phi = np.pi / 6
theta = np.pi / 2
omega = np.pi / 7
np.random.seed(0)
eps = 1e-2
gen_weights = np.array([np.pi] + [0] * 8) + np.random.normal(scale=eps, size=[9])
disc_weights = np.random.normal(size=[9])
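
As a quick sanity check, we can verify that these initial weights indeed prepare a state on wire 0 close to \(\left|1\right\rangle\). The helper QNode below is added for illustration:

@qml.qnode(dev)
def gen_output_probs(gen_weights):
    generator(gen_weights)
    return qml.probs(wires=0)


# should be approximately [0, 1], i.e. (almost) certainly measuring |1>
print(gen_output_probs(gen_weights))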

We begin by creating the optimizer:

opt = GradientDescentOptimizer(0.1)

In the first stage of training, we optimize the discriminator while keeping the generator parameters fixed.

for it in range(50):
    disc_weights = opt.step(disc_cost, disc_weights)
    cost = disc_cost(disc_weights)
    if it % 5 == 0:
        print("Step {}: cost = {}".format(it + 1, cost))

Out:

Step 1: cost = -0.10942017805789106
Step 6: cost = -0.3899884226490309
Step 11: cost = -0.6660191175815626
Step 16: cost = -0.8550839212078469
Step 21: cost = -0.9454459581664483
Step 26: cost = -0.9805878247866402
Step 31: cost = -0.9931371328342746
Step 36: cost = -0.9974896764916585
Step 41: cost = -0.9989863506630712
Step 46: cost = -0.9995000463932012

At the discriminator’s optimum, the probability for the discriminator to correctly classify the real data should be close to one.

print(prob_real_true(disc_weights))

Out:

0.999897195184226

For comparison, we check how the discriminator classifies the generator’s (still unoptimized) fake data:

print(prob_fake_true(gen_weights, disc_weights))

Out:

0.00024278396180033024

In the adversarial game we now have to train the generator to better fool the discriminator. In principle, the two models can then be trained in an alternating fashion until the optimum of the two-player adversarial game is reached; a sketch of such an alternating loop appears at the end of this demo.

for it in range(200):
    gen_weights = opt.step(gen_cost, gen_weights)
    # negate the cost so the printed value is the probability of fooling the discriminator
    cost = -gen_cost(gen_weights)
    if it % 5 == 0:
        print("Step {}: cost = {}".format(it, cost))

Out:

Step 0: cost = 0.00026646913829941887
Step 5: cost = 0.0004266200858934477
Step 10: cost = 0.0006872486146980994
Step 15: cost = 0.0011111626380133632
Step 20: cost = 0.0018000510248330492
Step 25: cost = 0.0029179304125444006
Step 30: cost = 0.004727717539774023
Step 35: cost = 0.007646628881031792
Step 40: cost = 0.012325866735736213
Step 45: cost = 0.01975451893452712
Step 50: cost = 0.03136834673567185
Step 55: cost = 0.049097345993078134
Step 60: cost = 0.07520378135265482
Step 65: cost = 0.11169015288702167
Step 70: cost = 0.1591728633374102
Step 75: cost = 0.21566031343947312
Step 80: cost = 0.27637357210452673
Step 85: cost = 0.3354169186527467
Step 90: cost = 0.3883501266928635
Step 95: cost = 0.43371772120148644
Step 100: cost = 0.4728490188392828
Step 105: cost = 0.5087778323625848
Step 110: cost = 0.5451977336157682
Step 115: cost = 0.5856632916397914
Step 120: cost = 0.6327897835085243
Step 125: cost = 0.6872469221106658
Step 130: cost = 0.7468453348018477
Step 135: cost = 0.806641363758769
Step 140: cost = 0.8607353038575862
Step 145: cost = 0.904841399047886
Step 150: cost = 0.9376687441862369
Step 155: cost = 0.96041048602581
Step 160: cost = 0.9753711478478343
Step 165: cost = 0.9848746701785992
Step 170: cost = 0.9907765008479255
Step 175: cost = 0.9943898535953368
Step 180: cost = 0.9965827747752587
Step 185: cost = 0.9979065439048094
Step 190: cost = 0.998703060843953
Step 195: cost = 0.999181393560638

At the optimum of the generator, the probability for the discriminator to be fooled should be close to 1.

print(prob_fake_true(gen_weights, disc_weights))

Out:

0.9994220324420164

At the joint optimum, the discriminator cost will be close to zero.

print(disc_cost(disc_weights))

Out:

-0.00047516274220960053
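
As a further check, we can compare the Bloch vectors of the real state and the generated state on wire 0. The helper QNodes below are added for illustration; note that fooling the discriminator does not by itself guarantee that the two vectors coincide exactly:

obs = [qml.PauliX(0), qml.PauliY(0), qml.PauliZ(0)]


@qml.qnode(dev)
def bloch_vector_real(angles):
    real(*angles)
    return [qml.expval(o) for o in obs]


@qml.qnode(dev)
def bloch_vector_generator(gen_weights):
    generator(gen_weights)
    return [qml.expval(o) for o in obs]


print("Real Bloch vector:     ", bloch_vector_real([phi, theta, omega]))
print("Generator Bloch vector:", bloch_vector_generator(gen_weights))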

The generator has successfully learned to simulate the real data well enough to fool the discriminator.
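
If one wanted to continue the adversarial game mentioned earlier, the two players could be updated in turns. The following is a minimal sketch of such an alternating loop; the round and step counts are arbitrary illustrative choices, not tuned values:

for rnd in range(5):
    # update the discriminator while the generator is held fixed
    for _ in range(20):
        disc_weights = opt.step(disc_cost, disc_weights)
    # update the generator while the discriminator is held fixed
    for _ in range(20):
        gen_weights = opt.step(gen_cost, gen_weights)
    print("Round {}: disc_cost = {}".format(rnd, disc_cost(disc_weights)))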
