Skip to content

Commit

Permalink
Make PauliMeasurementGate respect sign of the pauli observable. (qu…
Browse files Browse the repository at this point in the history
…antumlib#4836)

Fixes quantumlib#4814 

Note that this is a breaking change because:
- Serialization of the `PauliMeasurementGate` is now different -- the serialized observable is `DensePauliString` instead of a tuple of Pauli's.
-  A DensePauliString with coefficient != +1/-1 will now raise a `ValueError` whereas earlier the coefficient was simply ignored.
  • Loading branch information
tanujkhattar authored and rht committed May 1, 2023
1 parent 6a63eb8 commit 16bf9d2
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 56 deletions.
53 changes: 38 additions & 15 deletions cirq-core/cirq/ops/pauli_measurement_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union
from typing import Any, Dict, FrozenSet, Iterable, Tuple, Sequence, TYPE_CHECKING, Union, cast

from cirq import protocols, value
from cirq.ops import (
raw_types,
measurement_gate,
op_tree,
dense_pauli_string,
dense_pauli_string as dps,
pauli_gates,
pauli_string_phasor,
)
Expand All @@ -38,25 +38,36 @@ class PauliMeasurementGate(raw_types.Gate):

def __init__(
self,
observable: Iterable['cirq.Pauli'],
observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']],
key: Union[str, 'cirq.MeasurementKey'] = '',
) -> None:
"""Inits PauliMeasurementGate.
Args:
observable: Pauli observable to measure. Any `Iterable[cirq.Pauli]`
is a valid Pauli observable, including `cirq.DensePauliString`
instances, which do not contain any identity gates.
is a valid Pauli observable (with a +1 coefficient by default).
If you wish to measure pauli observables with coefficient -1,
then pass a `cirq.DensePauliString` as observable.
key: The string key of the measurement.
Raises:
ValueError: If the observable is empty.
"""
if not observable:
raise ValueError(f'Pauli observable {observable} is empty.')
if not all(isinstance(p, pauli_gates.Pauli) for p in observable):
if not all(
isinstance(p, pauli_gates.Pauli) for p in cast(Iterable['cirq.Gate'], observable)
):
raise ValueError(f'Pauli observable {observable} must be Iterable[`cirq.Pauli`].')
self._observable = tuple(observable)
coefficient = (
observable.coefficient if isinstance(observable, dps.BaseDensePauliString) else 1
)
if coefficient not in [+1, -1]:
raise ValueError(
f'`cirq.DensePauliString` observable {observable} must have coefficient +1/-1.'
)

self._observable = dps.DensePauliString(observable, coefficient=coefficient)
self.key = key # type: ignore

@property
Expand Down Expand Up @@ -94,9 +105,15 @@ def _with_rescoped_keys_(
def _with_measurement_key_mapping_(self, key_map: Dict[str, str]) -> 'PauliMeasurementGate':
return self.with_key(protocols.with_measurement_key_mapping(self.mkey, key_map))

def with_observable(self, observable: Iterable['cirq.Pauli']) -> 'PauliMeasurementGate':
def with_observable(
self, observable: Union['cirq.BaseDensePauliString', Iterable['cirq.Pauli']]
) -> 'PauliMeasurementGate':
"""Creates a pauli measurement gate with the new observable and same key."""
if tuple(observable) == self._observable:
if (
observable
if isinstance(observable, dps.BaseDensePauliString)
else dps.DensePauliString(observable)
) == self._observable:
return self
return PauliMeasurementGate(observable, key=self.key)

Expand All @@ -111,24 +128,30 @@ def _measurement_key_obj_(self) -> 'cirq.MeasurementKey':

def observable(self) -> 'cirq.DensePauliString':
"""Pauli observable which should be measured by the gate."""
return dense_pauli_string.DensePauliString(self._observable)
return self._observable

def _decompose_(
self, qubits: Tuple['cirq.Qid', ...]
) -> 'protocols.decompose_protocol.DecomposeResult':
any_qubit = qubits[0]
to_z_ops = op_tree.freeze_op_tree(self.observable().on(*qubits).to_z_basis_ops())
to_z_ops = op_tree.freeze_op_tree(self._observable.on(*qubits).to_z_basis_ops())
xor_decomp = tuple(pauli_string_phasor.xor_nonlocal_decompose(qubits, any_qubit))
yield to_z_ops
yield xor_decomp
yield measurement_gate.MeasurementGate(1, self.mkey).on(any_qubit)
yield measurement_gate.MeasurementGate(
1, self.mkey, invert_mask=(self._observable.coefficient != 1,)
).on(any_qubit)
yield protocols.inverse(xor_decomp)
yield protocols.inverse(to_z_ops)

def _circuit_diagram_info_(
self, args: 'cirq.CircuitDiagramInfoArgs'
) -> 'cirq.CircuitDiagramInfo':
symbols = [f'M({g})' for g in self._observable]
coefficient = '' if self._observable.coefficient == 1 else '-'
symbols = [
f'M({"" if i else coefficient}{self._observable[i]})'
for i in range(len(self._observable))
]

# Mention the measurement key.
label_map = args.label_map or {}
Expand All @@ -141,14 +164,14 @@ def _circuit_diagram_info_(
return protocols.CircuitDiagramInfo(tuple(symbols))

def _op_repr_(self, qubits: Sequence['cirq.Qid']) -> str:
args = [repr(self.observable().on(*qubits))]
args = [repr(self._observable.on(*qubits))]
if self.key != _default_measurement_key(qubits):
args.append(f'key={self.mkey!r}')
arg_list = ', '.join(args)
return f'cirq.measure_single_paulistring({arg_list})'

def __repr__(self) -> str:
return f'cirq.PauliMeasurementGate(' f'{self._observable!r}, ' f'{self.mkey!r})'
return f'cirq.PauliMeasurementGate({self._observable!r}, {self.mkey!r})'

def _value_equality_values_(self) -> Any:
return self.key, self._observable
Expand Down
22 changes: 21 additions & 1 deletion cirq-core/cirq/ops/pauli_measurement_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_init(observable, key):
assert g.num_qubits() == len(observable)
assert g.key == 'a'
assert g.mkey == cirq.MeasurementKey('a')
assert g._observable == tuple(observable)
assert g._observable == cirq.DensePauliString(observable)
assert cirq.qid_shape(g) == (2,) * len(observable)


Expand Down Expand Up @@ -162,6 +162,9 @@ def test_bad_observable_raises():
with pytest.raises(ValueError, match=r'Pauli observable .* must be Iterable\[`cirq.Pauli`\]'):
_ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZI'))

with pytest.raises(ValueError, match=r'must have coefficient \+1/-1.'):
_ = cirq.PauliMeasurementGate(cirq.DensePauliString('XYZ', coefficient=1j))


def test_with_observable():
o1 = [cirq.Z, cirq.Y, cirq.X]
Expand All @@ -170,3 +173,20 @@ def test_with_observable():
g2 = cirq.PauliMeasurementGate(o2, key='a')
assert g1.with_observable(o2) == g2
assert g1.with_observable(o1) is g1


@pytest.mark.parametrize(
'rot, obs, out',
[
(cirq.I, cirq.DensePauliString("Z", coefficient=+1), 0),
(cirq.I, cirq.DensePauliString("Z", coefficient=-1), 1),
(cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=+1), 0),
(cirq.Y ** 0.5, cirq.DensePauliString("X", coefficient=-1), 1),
(cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=+1), 0),
(cirq.X ** -0.5, cirq.DensePauliString("Y", coefficient=-1), 1),
],
)
def test_pauli_measurement_gate_samples(rot, obs, out):
q = cirq.NamedQubit("q")
c = cirq.Circuit(rot(q), cirq.PauliMeasurementGate(obs, key='out').on(q))
assert cirq.Simulator().sample(c)['out'][0] == out
89 changes: 50 additions & 39 deletions cirq-core/cirq/protocols/json_test_data/PauliMeasurementGate.json
Original file line number Diff line number Diff line change
@@ -1,42 +1,53 @@
[{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
[
{
"cirq_type": "PauliMeasurementGate",
"observable": {
"cirq_type": "DensePauliString",
"pauli_mask": [
1,
2,
3
],
"coefficient": {
"cirq_type": "complex",
"real": 1.0,
"imag": 0.0
}
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
"key": "key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": {
"cirq_type": "DensePauliString",
"pauli_mask": [
1,
2,
3
],
"coefficient": {
"cirq_type": "complex",
"real": 1.0,
"imag": 0.0
}
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
"key": "p:q:key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": {
"cirq_type": "DensePauliString",
"pauli_mask": [
1,
2,
3
],
"coefficient": {
"cirq_type": "complex",
"real": -1.0,
"imag": 0.0
}
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "p:q:key"
}]
"key": "key"
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
[{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "key"
},
{
"cirq_type": "PauliMeasurementGate",
"observable": [
{
"cirq_type": "_PauliX",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliY",
"exponent": 1.0,
"global_shift": 0.0
},
{
"cirq_type": "_PauliZ",
"exponent": 1.0,
"global_shift": 0.0
}
],
"key": "p:q:key"
}]
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')),
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')),
cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ"), cirq.MeasurementKey(path=('p', 'q'), name='key')),
cirq.PauliMeasurementGate(cirq.DensePauliString("XYZ", coefficient=-1), cirq.MeasurementKey(name='key')),
]

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(name='key')),
cirq.PauliMeasurementGate((cirq.X, cirq.Y, cirq.Z), cirq.MeasurementKey(path=('p', 'q'), name='key')),
]

0 comments on commit 16bf9d2

Please sign in to comment.