Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run product state functions inplace to avoid copies where possible #6396

Merged
merged 4 commits into from
Feb 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions cirq-core/cirq/sim/simulation_product_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,19 @@ def split_untangled_states(self) -> bool:
return self._split_untangled_states

def create_merged_state(self) -> TSimulationState:
merged_state = self.sim_states[None]
if not self.split_untangled_states:
return self.sim_states[None]
final_args = self.sim_states[None]
for args in set([self.sim_states[k] for k in self.sim_states.keys() if k is not None]):
final_args = final_args.kronecker_product(args)
return final_args.transpose_to_qubit_order(self.qubits)
return merged_state
extra_states = set([self.sim_states[k] for k in self.sim_states.keys() if k is not None])
if not extra_states:
return merged_state

# This comes from a member variable so we need to copy it if we're going to modify inplace
# before returning. We're not running a step currently, so no need to copy buffers.
merged_state = merged_state.copy(deep_copy_buffers=False)
for state in extra_states:
merged_state.kronecker_product(state, inplace=True)
return merged_state.transpose_to_qubit_order(self.qubits, inplace=True)

def _act_on_fallback_(
self, action: Any, qubits: Sequence['cirq.Qid'], allow_decompose: bool = True
Expand Down Expand Up @@ -106,7 +113,7 @@ def _act_on_fallback_(
if op_args_opt is None:
op_args_opt = self.sim_states[q]
elif q not in op_args_opt.qubits:
op_args_opt = op_args_opt.kronecker_product(self.sim_states[q])
op_args_opt.kronecker_product(self.sim_states[q], inplace=True)
op_args = op_args_opt or self.sim_states[None]

# (Backfill the args map with the new value)
Expand All @@ -123,7 +130,7 @@ def _act_on_fallback_(
):
for q in qubits:
if op_args.allows_factoring and len(op_args.qubits) > 1:
q_args, op_args = op_args.factor((q,), validate=False)
q_args, _ = op_args.factor((q,), validate=False, inplace=True)
self._sim_states[q] = q_args

# (Backfill the args map with the new value)
Expand Down
Loading