Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ArithmeticGate implementation #4702

Merged
merged 25 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
shor
  • Loading branch information
daxfohl committed Dec 23, 2021
commit 57fb38fc287eb32789ebd3c431def0808f022867
28 changes: 14 additions & 14 deletions examples/examples_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,45 +137,45 @@ def test_example_noisy_simulation():
def test_example_shor_modular_exp_register_size():
with pytest.raises(ValueError):
_ = examples.shor.ModularExp(
target=cirq.LineQubit.range(2), exponent=cirq.LineQubit.range(2, 5), base=4, modulus=5
target=[2, 2], exponent=[2, 2, 2], base=4, modulus=5
)


def test_example_shor_modular_exp_register_type():
operation = examples.shor.ModularExp(
target=cirq.LineQubit.range(3), exponent=cirq.LineQubit.range(3, 5), base=4, modulus=5
target=[2, 2, 2], exponent=[2, 2], base=4, modulus=5
)
with pytest.raises(ValueError):
_ = operation.with_registers(cirq.LineQubit.range(3))
_ = operation.with_registers([2, 2, 2])
with pytest.raises(ValueError):
_ = operation.with_registers(1, cirq.LineQubit.range(3, 6), 4, 5)
_ = operation.with_registers(1, [2, 2, 2], 4, 5)
with pytest.raises(ValueError):
_ = operation.with_registers(
cirq.LineQubit.range(3), cirq.LineQubit.range(3, 6), cirq.LineQubit.range(6, 9), 5
[2, 2, 2], [2, 2, 2], [2, 2, 2], 5
)
with pytest.raises(ValueError):
_ = operation.with_registers(
cirq.LineQubit.range(3), cirq.LineQubit.range(3, 6), 4, cirq.LineQubit.range(6, 9)
[2, 2, 2], [2, 2, 2], 4, [2, 2, 2]
)


def test_example_shor_modular_exp_registers():
target = cirq.LineQubit.range(3)
exponent = cirq.LineQubit.range(3, 5)
target = [2, 2, 2]
exponent = [2, 2]
operation = examples.shor.ModularExp(target, exponent, 4, 5)
assert operation.registers() == (target, exponent, 4, 5)

new_target = cirq.LineQubit.range(5, 8)
new_exponent = cirq.LineQubit.range(8, 12)
new_target = [2, 2, 2]
new_exponent = [2, 2, 2, 2]
new_operation = operation.with_registers(new_target, new_exponent, 6, 7)
assert new_operation.registers() == (new_target, new_exponent, 6, 7)


def test_example_shor_modular_exp_diagram():
target = cirq.LineQubit.range(3)
exponent = cirq.LineQubit.range(3, 5)
target = [2, 2, 2]
exponent = [2, 2]
operation = examples.shor.ModularExp(target, exponent, 4, 5)
circuit = cirq.Circuit(operation)
circuit = cirq.Circuit(operation.on(*cirq.LineQubit.range(5)))
daxfohl marked this conversation as resolved.
Show resolved Hide resolved
cirq.testing.assert_has_diagram(
circuit,
"""
Expand All @@ -192,7 +192,7 @@ def test_example_shor_modular_exp_diagram():
)

operation = operation.with_registers(target, 2, 4, 5)
circuit = cirq.Circuit(operation)
circuit = cirq.Circuit(operation.on(*cirq.LineQubit.range(3)))
cirq.testing.assert_has_diagram(
circuit,
"""
Expand Down
48 changes: 20 additions & 28 deletions examples/shor.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def naive_order_finder(x: int, n: int) -> Optional[int]:
return r


class ModularExp(cirq.ArithmeticOperation):
class ModularExp(cirq.ArithmeticGate):
"""Quantum modular exponentiation.

This class represents the unitary which multiplies base raised to exponent
Expand Down Expand Up @@ -130,8 +130,8 @@ class ModularExp(cirq.ArithmeticOperation):

def __init__(
self,
target: Sequence[cirq.Qid],
exponent: Union[int, Sequence[cirq.Qid]],
target: Sequence[int],
exponent: Union[int, Sequence[int]],
base: int,
modulus: int,
) -> None:
Expand All @@ -144,12 +144,12 @@ def __init__(
self.base = base
self.modulus = modulus

def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]:
def registers(self) -> Sequence[Union[int, Sequence[int]]]:
return self.target, self.exponent, self.base, self.modulus

def with_registers(
self,
*new_registers: Union[int, Sequence['cirq.Qid']],
*new_registers: Union[int, Sequence[int]],
) -> 'ModularExp':
if len(new_registers) != 4:
raise ValueError(
Expand Down Expand Up @@ -177,22 +177,12 @@ def _circuit_diagram_info_(
args: cirq.CircuitDiagramInfoArgs,
) -> cirq.CircuitDiagramInfo:
assert args.known_qubits is not None
wire_symbols: List[str] = []
t, e = 0, 0
for qubit in args.known_qubits:
if qubit in self.target:
if t == 0:
if isinstance(self.exponent, Sequence):
e_str = 'e'
else:
e_str = str(self.exponent)
wire_symbols.append(f'ModularExp(t*{self.base}**{e_str} % {self.modulus})')
else:
wire_symbols.append('t' + str(t))
t += 1
if isinstance(self.exponent, Sequence) and qubit in self.exponent:
wire_symbols.append('e' + str(e))
e += 1
wire_symbols = [f't{i}' for i in range(len(self.target))]
e_str = str(self.exponent)
if isinstance(self.exponent, Sequence):
e_str = 'e'
wire_symbols += [f'e{i}' for i in range(len(self.exponent))]
wire_symbols[0] = f'ModularExp(t*{self.base}**{e_str} % {self.modulus})'
return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))


Expand Down Expand Up @@ -225,14 +215,16 @@ def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:
Quantum circuit for finding the order of x modulo n
"""
L = n.bit_length()
target = cirq.LineQubit.range(L)
exponent = cirq.LineQubit.range(L, 3 * L + 3)
target = (2,) * L
target_q = cirq.LineQubit.range(L)
exponent = (2,) * (2 * L + 3)
exponent_q = cirq.LineQubit.range(L, 3 * L + 3)
return cirq.Circuit(
cirq.X(target[L - 1]),
cirq.H.on_each(*exponent),
ModularExp(target, exponent, x, n),
cirq.qft(*exponent, inverse=True),
cirq.measure(*exponent, key='exponent'),
cirq.X(target_q[L - 1]),
cirq.H.on_each(*exponent_q),
ModularExp(target, exponent, x, n).on(*(target_q + exponent_q)),
cirq.qft(*exponent_q, inverse=True),
cirq.measure(*exponent_q, key='exponent'),
)


Expand Down