Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More small updates #205

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fixed cost function issue with minimization
  • Loading branch information
edyounis committed Nov 29, 2023
commit ba55047749d6cd18e59789de156be276e80cffda
55 changes: 51 additions & 4 deletions bqskit/ir/opt/instantiaters/minimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
from bqskit.ir.opt.instantiater import Instantiater
from bqskit.ir.opt.minimizer import Minimizer
from bqskit.ir.opt.minimizers.ceres import CeresMinimizer
from bqskit.qis.state.state import StateVector
from bqskit.qis.state.system import StateSystem
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
from bqskit.ir.opt.multistartgens.random import RandomStartGenerator
from bqskit.qis.state.state import StateLike, StateVector

if TYPE_CHECKING:
from bqskit.ir.circuit import Circuit
from bqskit.qis.state.system import StateSystem
from bqskit.qis.state.system import StateSystemLike
from bqskit.qis.unitary.unitarymatrix import UnitaryLike
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix


class Minimization(Instantiater):
Expand Down Expand Up @@ -61,7 +64,6 @@ def instantiate(
) -> npt.NDArray[np.float64]:
"""Instantiate `circuit`, see Instantiater for more info."""
cost = self.cost_fn_gen.gen_cost(circuit, target)
# print(x0, circuit.num_params, circuit.gate_counts)
return self.minimizer.minimize(cost, x0)

@staticmethod
Expand All @@ -88,3 +90,48 @@ def get_violation_report(circuit: Circuit) -> str:
def get_method_name() -> str:
"""Return the name of this method."""
return 'minimization'

def multi_start_instantiate_inplace(
self,
circuit: Circuit,
target: UnitaryLike | StateLike | StateSystemLike,
num_starts: int,
) -> None:
"""
Instantiate `circuit` to best implement `target` with multiple starts.

See Instantiater for more info.
"""
target = self.check_target(target)
start_gen = RandomStartGenerator()
starts = start_gen.gen_starting_points(num_starts, circuit, target)
cost_fn = self.cost_fn_gen.gen_cost(circuit, target)
params_list = [self.instantiate(circuit, target, x0) for x0 in starts]
params = sorted(params_list, key=lambda x: cost_fn(x))[0]
circuit.set_params(params)

async def multi_start_instantiate_async(
self,
circuit: Circuit,
target: UnitaryLike | StateLike | StateSystemLike,
num_starts: int,
) -> Circuit:
"""
Instantiate `circuit` to best implement `target` with multiple starts.

See Instantiater for more info.
"""
from bqskit.runtime import get_runtime
target = self.check_target(target)
start_gen = RandomStartGenerator()
starts = start_gen.gen_starting_points(num_starts, circuit, target)
cost_fn = self.cost_fn_gen.gen_cost(circuit, target)
params_list = await get_runtime().map(
self.instantiate,
[circuit] * num_starts,
[target] * num_starts,
starts,
)
params = sorted(params_list, key=lambda x: cost_fn(x))[0]
circuit.set_params(params)
return circuit