Skip to content

Commit

Permalink
Override gate.controlled() for GlobalPhaseGate to return a ZPowGate (q…
Browse files Browse the repository at this point in the history
…uantumlib#6073)

* Override gate.controlled() for GlobalPhaseGate to return a ZPowGate

* Test unitary equivalence

* Override controlled only if gate is not parameterized

* Fix typo

* Fix type check

* another attempt at fixing types

* Add a comment and additional tests
  • Loading branch information
tanujkhattar authored May 10, 2023
1 parent 0346162 commit 1e3d8ce
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
30 changes: 28 additions & 2 deletions cirq/ops/global_phase_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
# limitations under the License.
"""A no-qubit global phase operation."""

from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union
from typing import AbstractSet, Any, cast, Dict, Sequence, Tuple, Union, Optional, Collection

import numpy as np
import sympy

import cirq
from cirq import value, protocols
from cirq.ops import raw_types
from cirq.ops import raw_types, controlled_gate, control_values as cv
from cirq.type_workarounds import NotImplementedType


Expand Down Expand Up @@ -91,6 +91,32 @@ def _resolve_parameters_(
coefficient = protocols.resolve_parameters(self.coefficient, resolver, recursive)
return GlobalPhaseGate(coefficient=coefficient)

def controlled(
self,
num_controls: Optional[int] = None,
control_values: Optional[
Union[cv.AbstractControlValues, Sequence[Union[int, Collection[int]]]]
] = None,
control_qid_shape: Optional[Tuple[int, ...]] = None,
) -> raw_types.Gate:
result = super().controlled(num_controls, control_values, control_qid_shape)
if (
not self._is_parameterized_()
and isinstance(result, controlled_gate.ControlledGate)
and isinstance(result.control_values, cv.ProductOfSums)
and result.control_values[-1] == (1,)
and result.control_qid_shape[-1] == 2
):
# A `GlobalPhaseGate` controlled on a qubit in state `|1>` is equivalent
# to applying a `ZPowGate`. This override ensures that `global_phase_gate.controlled()`
# returns a `ZPowGate` instead of a `ControlledGate(sub_gate=global_phase_gate)`.
coefficient = complex(self.coefficient)
exponent = float(np.angle(coefficient) / np.pi)
return cirq.ZPowGate(exponent=exponent).controlled(
result.num_controls() - 1, result.control_values[:-1], result.control_qid_shape[:-1]
)
return result


def global_phase_operation(
coefficient: 'cirq.TParamValComplex', atol: float = 1e-8
Expand Down
21 changes: 21 additions & 0 deletions cirq/ops/global_phase_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,24 @@ def test_resolve_error(resolve_fn):
gpt = cirq.GlobalPhaseGate(coefficient=t)
with pytest.raises(ValueError, match='Coefficient is not unitary'):
resolve_fn(gpt, {'t': -2})


@pytest.mark.parametrize(
'coeff, exp', [(-1, 1), (1j, 0.5), (-1j, -0.5), (1 / np.sqrt(2) * (1 + 1j), 0.25)]
)
def test_global_phase_gate_controlled(coeff, exp):
g = cirq.GlobalPhaseGate(coeff)
op = cirq.global_phase_operation(coeff)
q = cirq.LineQubit.range(3)
for num_controls, target_gate in zip(range(1, 4), [cirq.Z, cirq.CZ, cirq.CCZ]):
assert g.controlled(num_controls) == target_gate**exp
np.testing.assert_allclose(
cirq.unitary(cirq.ControlledGate(g, num_controls)),
cirq.unitary(g.controlled(num_controls)),
)
assert op.controlled_by(*q[:num_controls]) == target_gate(*q[:num_controls]) ** exp
assert g.controlled(control_values=[0]) == cirq.ControlledGate(g, control_values=[0])
xor_control_values = cirq.SumOfProducts(((0, 0), (1, 1)))
assert g.controlled(control_values=xor_control_values) == cirq.ControlledGate(
g, control_values=xor_control_values
)

0 comments on commit 1e3d8ce

Please sign in to comment.