Traktor/myenv/Lib/site-packages/sympy/physics/quantum/circuitplot.py
2024-05-26 05:12:46 +02:00

371 lines
12 KiB
Python

"""Matplotlib based plotting of quantum circuits.
Todo:
* Optimize printing of large circuits.
* Get this to work with single gates.
* Do a better job checking the form of circuits to make sure it is a Mul of
Gates.
* Get multi-target gates plotting.
* Get initial and final states to plot.
* Get measurements to plot. Might need to rethink measurement as a gate
issue.
* Get scale and figsize to be handled in a better way.
* Write some tests/examples!
"""
from __future__ import annotations
from sympy.core.mul import Mul
from sympy.external import import_module
from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS
# Public names exported by this module.
__all__ = [
    'CircuitPlot', 'circuit_plot', 'labeller',
    'Mz', 'Mx',
    'CreateOneQubitGate', 'CreateCGate',
]

# numpy and matplotlib are optional: each name below is falsy when the
# package could not be imported, which CircuitPlot.__init__ checks before
# attempting to draw anything.
np = import_module('numpy')
matplotlib = import_module(
    'matplotlib',
    import_kwargs={'fromlist': ['pyplot']},
    # RuntimeError is raised in environments that have no display.
    catch=(RuntimeError,),
)

if np and matplotlib:
    # Short module-level aliases for the matplotlib objects used below.
    pyplot = matplotlib.pyplot
    Line2D = matplotlib.lines.Line2D
    Circle = matplotlib.patches.Circle

# Disabled: switch matplotlib text rendering to usetex.
#from matplotlib import rc
#rc('text',usetex=True)
class CircuitPlot:
    """A class for managing a circuit plot.

    Creating an instance draws the complete diagram: the wire/gate grid is
    computed, a matplotlib figure is created, and the wires and gates are
    plotted onto it.

    Parameters
    ==========

    c : circuit
        The circuit to plot; a ``Mul`` of ``Gate`` instances or a single
        ``Gate``.
    nqubits : int
        Number of wires to draw.
    **kwargs
        Stored on the instance, where they override the class-level drawing
        defaults below (for example ``labels``, ``inits`` or ``scale``).

    Raises
    ======

    ImportError
        If numpy or matplotlib is not available.
    """

    # Distance between neighbouring wires and between neighbouring gates.
    scale = 1.0
    # Text size used for wire labels and gate boxes.
    fontsize = 20.0
    # Line width used for every line and patch outline.
    linewidth = 1.0
    # Radius of a control dot (multiplied by ``scale`` when drawn).
    control_radius = 0.05
    # Radius of the circle drawn for a NOT target.
    not_radius = 0.15
    # Half-width of the cross drawn for a swap point.
    swap_delta = 0.05
    # Wire labels, indexed by wire; an empty list disables labelling.
    labels: list[str] = []
    # Maps a wire label to the initial state rendered next to it.
    inits: dict[str, str] = {}
    # Horizontal gap between the left end of a wire and its label.
    label_buffer = 0.5

    def __init__(self, c, nqubits, **kwargs):
        if not np or not matplotlib:
            raise ImportError('numpy or matplotlib not available.')
        self.circuit = c
        # NOTE(review): this counts every argument of the circuit, while
        # _gates() keeps only the Gate instances -- confirm this is intended
        # for circuits with non-gate factors or a single gate.
        self.ngates = len(self.circuit.args)
        self.nqubits = nqubits
        self.update(kwargs)
        self._create_grid()
        self._create_figure()
        self._plot_wires()
        self._plot_gates()
        self._finish()

    def update(self, kwargs):
        """Load the kwargs into the instance dict."""
        self.__dict__.update(kwargs)

    def _create_grid(self):
        """Create the grid of wires.

        Sets ``_wire_grid`` (one y-coordinate per qubit) and ``_gate_grid``
        (one x-coordinate per gate slot), both spaced ``scale`` apart and
        starting at zero.
        """
        scale = self.scale
        # Build integer indices first and scale afterwards: np.arange with a
        # floating point step can return one element too many because of
        # rounding, which would misplace the last wire/gate.
        self._wire_grid = np.arange(self.nqubits, dtype=float) * scale
        self._gate_grid = np.arange(self.ngates, dtype=float) * scale

    def _create_figure(self):
        """Create the main matplotlib figure."""
        self._figure = pyplot.figure(
            figsize=(self.ngates*self.scale, self.nqubits*self.scale),
            facecolor='w',
            edgecolor='w'
        )
        ax = self._figure.add_subplot(
            1, 1, 1,
            frameon=True
        )
        ax.set_axis_off()
        # Leave half a grid cell of margin around the outermost grid points.
        offset = 0.5*self.scale
        ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset)
        ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset)
        ax.set_aspect('equal')
        self._axes = ax

    def _plot_wires(self):
        """Plot the wires of the circuit diagram."""
        xstart = self._gate_grid[0]
        xstop = self._gate_grid[-1]
        # Wires extend one grid unit beyond the first and the last gate.
        xdata = (xstart - self.scale, xstop + self.scale)
        for i in range(self.nqubits):
            ydata = (self._wire_grid[i], self._wire_grid[i])
            line = Line2D(
                xdata, ydata,
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)
            if self.labels:
                # Labels that also show an initial state are moved further
                # left, presumably to make room for the longer text.
                init_label_buffer = 0
                if self.inits.get(self.labels[i]):
                    init_label_buffer = 0.25
                self._axes.text(
                    xdata[0] - self.label_buffer - init_label_buffer,
                    ydata[0],
                    render_label(self.labels[i], self.inits),
                    size=self.fontsize,
                    color='k', ha='center', va='center')
        self._plot_measured_wires()

    def _plot_measured_wires(self):
        """Draw a second, shifted line along measured wires.

        Each measured wire is doubled from its first measurement gate to the
        right end of the diagram, and so is the vertical line of every
        controlled gate that touches a measured wire after its measurement.
        """
        ismeasured = self._measurements()
        xstop = self._gate_grid[-1]
        dy = 0.04  # amount to shift wires when doubled
        # Plot doubled wires after they are measured
        for wire, gate_idx in ismeasured.items():
            xdata = (self._gate_grid[gate_idx], xstop + self.scale)
            ydata = (self._wire_grid[wire] + dy, self._wire_grid[wire] + dy)
            line = Line2D(
                xdata, ydata,
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)
        # Also double any controlled lines off these wires
        for i, g in enumerate(self._gates()):
            if not isinstance(g, (CGate, CGateS)):
                continue
            wires = g.controls + g.targets
            after_measurement = any(
                wire in ismeasured and
                self._gate_grid[i] > self._gate_grid[ismeasured[wire]]
                for wire in wires)
            if not after_measurement:
                continue
            # Convert wire indices to plot coordinates, as control_line()
            # does, so the doubled line stays aligned when scale != 1.
            ydata = (self._wire_grid[min(wires)], self._wire_grid[max(wires)])
            xdata = (self._gate_grid[i] - dy, self._gate_grid[i] - dy)
            line = Line2D(
                xdata, ydata,
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)

    def _gates(self):
        """Create a list of all gates in the circuit plot.

        For a ``Mul`` the ``Gate`` factors are returned in reversed argument
        order; a single ``Gate`` gives a one-element list; anything else
        gives an empty list.
        """
        gates = []
        if isinstance(self.circuit, Mul):
            for g in reversed(self.circuit.args):
                if isinstance(g, Gate):
                    gates.append(g)
        elif isinstance(self.circuit, Gate):
            gates.append(self.circuit)
        return gates

    def _plot_gates(self):
        """Iterate through the gates and plot each of them."""
        for i, gate in enumerate(self._gates()):
            gate.plot_gate(self, i)

    def _measurements(self):
        """Return a dict ``{i:j}`` where i is the index of the wire that has
        been measured, and j is the gate where the wire is measured.

        When a wire is measured more than once, the earliest gate is kept.
        """
        ismeasured = {}
        for i, g in enumerate(self._gates()):
            if getattr(g, 'measurement', False):
                for target in g.targets:
                    # Gate indices only increase, so the first index stored
                    # for a wire is already the smallest one.
                    ismeasured.setdefault(target, i)
        return ismeasured

    def _finish(self):
        """Apply final tweaks to every artist in the figure."""
        # Disable clipping to make panning work well for large circuits.
        for o in self._figure.findobj():
            o.set_clip_on(False)

    def one_qubit_box(self, t, gate_idx, wire_idx):
        """Draw a box for a single qubit gate.

        ``t`` is the text shown in the box; ``gate_idx`` and ``wire_idx``
        select the grid position.
        """
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        self._axes.text(
            x, y, t,
            color='k',
            ha='center',
            va='center',
            bbox={"ec": 'k', "fc": 'w', "fill": True, "lw": self.linewidth},
            size=self.fontsize
        )

    def two_qubit_box(self, t, gate_idx, wire_idx):
        """Draw a box for a two qubit gate. Does not work yet.

        Currently a placeholder: nothing is drawn, the gate and wire grids
        are only printed.
        """
        # x = self._gate_grid[gate_idx]
        # y = self._wire_grid[wire_idx]+0.5
        print(self._gate_grid)
        print(self._wire_grid)
        # unused:
        # obj = self._axes.text(
        #         x, y, t,
        #         color='k',
        #         ha='center',
        #         va='center',
        #         bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth),
        #         size=self.fontsize
        #     )

    def control_line(self, gate_idx, min_wire, max_wire):
        """Draw a vertical control line."""
        xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx])
        ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire])
        line = Line2D(
            xdata, ydata,
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(line)

    def control_point(self, gate_idx, wire_idx):
        """Draw a control point."""
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        radius = self.control_radius
        c = Circle(
            (x, y),
            radius*self.scale,
            ec='k',
            fc='k',
            fill=True,
            lw=self.linewidth
        )
        self._axes.add_patch(c)

    def not_point(self, gate_idx, wire_idx):
        """Draw a NOT gates as the circle with plus in the middle."""
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        radius = self.not_radius
        c = Circle(
            (x, y),
            radius,
            ec='k',
            fc='w',
            fill=False,
            lw=self.linewidth
        )
        self._axes.add_patch(c)
        # Vertical stroke through the circle.
        l = Line2D(
            (x, x), (y - radius, y + radius),
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(l)

    def swap_point(self, gate_idx, wire_idx):
        """Draw a swap point as a cross."""
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        d = self.swap_delta
        l1 = Line2D(
            (x - d, x + d),
            (y - d, y + d),
            color='k',
            lw=self.linewidth
        )
        l2 = Line2D(
            (x - d, x + d),
            (y + d, y - d),
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(l1)
        self._axes.add_line(l2)
def circuit_plot(c, nqubits, **kwargs):
    """Draw the circuit diagram for the circuit with nqubits.

    Convenience wrapper that builds a ``CircuitPlot`` and returns it.

    Parameters
    ==========

    c : circuit
        The circuit to plot. Should be a product of Gate instances.
    nqubits : int
        The number of qubits to include in the circuit. Must be at least
        as big as the largest ``min_qubits`` of the gates.
    **kwargs
        Passed through unchanged to ``CircuitPlot``.
    """
    plot = CircuitPlot(c, nqubits, **kwargs)
    return plot
def render_label(label, inits=None):
    """Slightly more flexible way to render labels.

    Returns the label as a LaTeX ket. When ``inits`` maps the label to a
    non-empty initial state, that state is appended as a second ket.

    Parameters
    ==========

    label : str
        The wire label.
    inits : dict, optional
        Maps wire labels to their initial states. ``None`` (the default)
        means that no initial states are known.

    >>> from sympy.physics.quantum.circuitplot import render_label
    >>> render_label('q0')
    '$\\\\left|q0\\\\right\\\\rangle$'
    >>> render_label('q0', {'q0':'0'})
    '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$'
    """
    # A None default avoids sharing one mutable dict between all calls.
    init = inits.get(label) if inits is not None else None
    if init:
        return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init)
    return r'$\left|%s\right\rangle$' % label
def labeller(n, symbol='q'):
    """Autogenerate labels for wires of quantum circuits.

    The labels are numbered downwards, from ``n - 1`` to ``0``.

    Parameters
    ==========

    n : int
        number of qubits in the circuit.
    symbol : string
        A character string to precede all wire labels. E.g. 'q_0', 'q_1', etc.

    >>> from sympy.physics.quantum.circuitplot import labeller
    >>> labeller(2)
    ['q_1', 'q_0']
    >>> labeller(3,'j')
    ['j_2', 'j_1', 'j_0']
    """
    names = []
    # Count down directly instead of deriving the index from an upward loop.
    for index in range(n - 1, -1, -1):
        names.append('%s_%d' % (symbol, index))
    return names
class Mz(OneQubitGate):
    """Mock-up of a z measurement gate.

    It lives in circuitplot instead of gate.py since it is not a real gate;
    it only exists so that a measurement can be drawn.
    """
    gate_name = 'Mz'
    gate_name_latex = 'M_z'
    # Read by CircuitPlot._measurements to find the measured wires.
    measurement = True
class Mx(OneQubitGate):
    """Mock-up of an x measurement gate.

    It lives in circuitplot instead of gate.py since it is not a real gate;
    it only exists so that a measurement can be drawn.
    """
    gate_name = 'Mx'
    gate_name_latex = 'M_x'
    # Read by CircuitPlot._measurements to find the measured wires.
    measurement = True
class CreateOneQubitGate(type):
    """Factory for simple one qubit gate classes.

    ``CreateOneQubitGate(name, latexname)`` returns a new subclass of
    ``OneQubitGate`` called ``name + "Gate"`` whose ``gate_name`` is ``name``
    and whose ``gate_name_latex`` is ``latexname`` (``name`` if no
    ``latexname`` is given).
    """
    def __new__(mcl, name, latexname=None):
        namespace = {
            'gate_name': name,
            # Fall back to the plain name when no LaTeX name is supplied.
            'gate_name_latex': latexname or name,
        }
        return type(name + "Gate", (OneQubitGate,), namespace)
def CreateCGate(name, latexname=None):
    """Use a lexical closure to make a controlled gate.

    Returns a function ``ControlledGate(ctrls, target)`` that builds a
    ``CGate`` applying a freshly created one qubit gate (named ``name``,
    with LaTeX name ``latexname`` or ``name``) to ``target``, controlled by
    the qubits in ``ctrls``.
    """
    gate_cls = CreateOneQubitGate(name, latexname if latexname else name)

    def ControlledGate(ctrls, target):
        controls = tuple(ctrls)
        return CGate(controls, gate_cls(target))

    return ControlledGate