Skip to content

Commit

Permalink
Merge pull request Pyomo#2576 from emma58/fix-transformed-gdp-pickle
Browse files Browse the repository at this point in the history
Pickle transformed GDP models
  • Loading branch information
emma58 authored Nov 2, 2022
2 parents f681556 + da08461 commit f6606f9
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 87 deletions.
6 changes: 5 additions & 1 deletion pyomo/gdp/disjunct.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,14 @@ def process(arg):


class _DisjunctData(_BlockData):
__autoslot_mappers__ = {'_transformation_block': AutoSlots.weakref_mapper}

_Block_reserved_words = set()

@property
def transformation_block(self):
return self._transformation_block
return None if self._transformation_block is None else \
self._transformation_block()

def __init__(self, component):
_BlockData.__init__(self, component)
Expand Down Expand Up @@ -427,6 +429,7 @@ def active(self):

class _DisjunctionData(ActiveComponentData):
__slots__ = ('disjuncts', 'xor', '_algebraic_constraint')
__autoslot_mappers__ = {'_algebraic_constraint': AutoSlots.weakref_mapper}
_NoArgument = (0,)

@property
Expand Down Expand Up @@ -514,6 +517,7 @@ def set_value(self, expr):
@ModelComponentFactory.register("Disjunction expressions.")
class Disjunction(ActiveIndexedComponent):
_ComponentDataClass = _DisjunctionData
__autoslot_mappers__ = {'_algebraic_constraint': AutoSlots.weakref_mapper}

def __new__(cls, *args, **kwds):
if cls != Disjunction:
Expand Down
9 changes: 5 additions & 4 deletions pyomo/gdp/plugins/bigm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pyomo.core.base import Transformation, TransformationFactory, Reference
import pyomo.core.expr.current as EXPR
from pyomo.gdp import Disjunct, Disjunction, GDP_Error
from pyomo.gdp.transformed_disjunct import _TransformedDisjunct
from pyomo.gdp.util import (
is_child_of, get_src_disjunction, get_src_constraint, get_gdp_tree,
get_transformed_constraints, _get_constraint_transBlock, get_src_disjunct,
Expand Down Expand Up @@ -304,7 +305,7 @@ def _add_transformation_block(self, to_block):
'_pyomo_gdp_bigm_reformulation')
self._transformation_blocks[to_block] = transBlock = Block()
to_block.add_component(transBlockName, transBlock)
transBlock.relaxedDisjuncts = Block(NonNegativeIntegers)
transBlock.relaxedDisjuncts = _TransformedDisjunct(NonNegativeIntegers)
transBlock.lbub = Set(initialize=['lb', 'ub'])

return transBlock
Expand Down Expand Up @@ -394,7 +395,7 @@ def _transform_disjunct(self, obj, bigM, root_disjunct):
relaxationBlock.bigm_src = {}
relaxationBlock.localVarReferences = Block()
obj._transformation_block = weakref_ref(relaxationBlock)
relaxationBlock._srcDisjunct = weakref_ref(obj)
relaxationBlock._src_disjunct = weakref_ref(obj)

# This is crazy, but if the disjunction has been previously
# relaxed, the disjunct *could* be deactivated. This is a big
Expand Down Expand Up @@ -877,9 +878,9 @@ def get_all_M_values_by_constraint(self, model):
Disjunct,
active=None,
descend_into=(Block, Disjunct)):
transBlock = disj.transformation_block
# First check if it was transformed at all.
if disj.transformation_block is not None:
transBlock = disj.transformation_block()
if transBlock is not None:
# If it was transformed with BigM, we get the M values.
if hasattr(transBlock, 'bigm_src'):
for cons in transBlock.bigm_src:
Expand Down
4 changes: 2 additions & 2 deletions pyomo/gdp/plugins/cuttingplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,8 @@ def _get_disaggregated_vars(self, hull):
descend_into=(Disjunct,
Block)):
for disjunct in disjunction.disjuncts:
if disjunct.transformation_block is not None:
transBlock = disjunct.transformation_block()
transBlock = disjunct.transformation_block
if transBlock is not None:
for v in transBlock.disaggregatedVars.\
component_data_objects(Var):
disaggregatedVars.add(v)
Expand Down
5 changes: 3 additions & 2 deletions pyomo/gdp/plugins/hull.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pyomo.core.base.boolean_var import (
_DeprecatedImplicitAssociatedBinaryVariable)
from pyomo.gdp import Disjunct, Disjunction, GDP_Error
from pyomo.gdp.transformed_disjunct import _TransformedDisjunct
from pyomo.gdp.util import (
clone_without_expression_components, is_child_of, get_src_disjunction,
get_src_constraint, get_transformed_constraints,
Expand Down Expand Up @@ -309,7 +310,7 @@ def _add_transformation_block(self, instance):
transBlock = Block()
instance.add_component(transBlockName, transBlock)
self._transformation_blocks[instance] = transBlock
transBlock.relaxedDisjuncts = Block(NonNegativeIntegers)
transBlock.relaxedDisjuncts = _TransformedDisjunct(NonNegativeIntegers)
transBlock.lbub = Set(initialize = ['lb','ub','eq'])
# Map between disaggregated variables and their
# originals
Expand Down Expand Up @@ -599,7 +600,7 @@ def _transform_disjunct(self, obj, transBlock, varSet, localVars,

# add mappings to source disjunct (so we'll know we've relaxed)
obj._transformation_block = weakref_ref(relaxationBlock)
relaxationBlock._srcDisjunct = weakref_ref(obj)
relaxationBlock._src_disjunct = weakref_ref(obj)

# add the disaggregated variables and their bigm constraints
# to the relaxationBlock
Expand Down
89 changes: 63 additions & 26 deletions pyomo/gdp/tests/common_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

import pickle
from pyomo.common.dependencies import dill

from pyomo.environ import (
TransformationFactory, ConcreteModel, Constraint, Var, Objective,
Expand Down Expand Up @@ -109,7 +111,7 @@ def checkb0TargetsTransformed(self, m, transformation):
(1,1),
]
for i, j in pairs:
self.assertIs(m.b[0].disjunct[i].transformation_block(),
self.assertIs(m.b[0].disjunct[i].transformation_block,
disjBlock[j])
self.assertIs(trans.get_src_disjunct(disjBlock[j]),
m.b[0].disjunct[i])
Expand All @@ -131,7 +133,7 @@ def check_user_deactivated_disjuncts(self, transformation,
rBlock = m.component("_pyomo_gdp_%s_reformulation" % transformation)
disjBlock = rBlock.relaxedDisjuncts
self.assertEqual(len(disjBlock), 1)
self.assertIs(disjBlock[0], m.d[1].transformation_block())
self.assertIs(disjBlock[0], m.d[1].transformation_block)
self.assertIs(transform.get_src_disjunct(disjBlock[0]), m.d[1])

def check_improperly_deactivated_disjuncts(self, transformation, **kwargs):
Expand Down Expand Up @@ -486,7 +488,7 @@ def check_disjunct_mapping(self, transformation):
# the disjuncts will always be transformed in the same order,
# and d[0] goes first, so we can check in a loop.
for i in [0,1]:
self.assertIs(disjBlock[i]._srcDisjunct(), m.d[i])
self.assertIs(disjBlock[i]._src_disjunct(), m.d[i])
self.assertIs(trans.get_src_disjunct(disjBlock[i]), m.d[i])

# targets
Expand Down Expand Up @@ -535,7 +537,7 @@ def check_only_targets_get_transformed(self, transformation):
(1, 1)
]
for i, j in pairs:
self.assertIs(disjBlock[i], m.disjunct1[j].transformation_block())
self.assertIs(disjBlock[i], m.disjunct1[j].transformation_block)
self.assertIs(trans.get_src_disjunct(disjBlock[i]), m.disjunct1[j])

self.assertIsNone(m.disjunct2[0].transformation_block)
Expand Down Expand Up @@ -602,16 +604,16 @@ def check_indexedDisj_only_targets_transformed(self, transformation):
relaxedDisjuncts
self.assertEqual(len(disjBlock), 4)
self.assertIsInstance(
m.disjunct1[1,0].transformation_block().component("disjunct1[1,0].c"),
m.disjunct1[1,0].transformation_block.component("disjunct1[1,0].c"),
Constraint)
self.assertIsInstance(
m.disjunct1[1,1].transformation_block().component("disjunct1[1,1].c"),
m.disjunct1[1,1].transformation_block.component("disjunct1[1,1].c"),
Constraint)
self.assertIsInstance(
m.disjunct1[2,0].transformation_block().component("disjunct1[2,0].c"),
m.disjunct1[2,0].transformation_block.component("disjunct1[2,0].c"),
Constraint)
self.assertIsInstance(
m.disjunct1[2,1].transformation_block().component("disjunct1[2,1].c"),
m.disjunct1[2,1].transformation_block.component("disjunct1[2,1].c"),
Constraint)

# This relies on the disjunctions being transformed in the same order
Expand All @@ -635,7 +637,7 @@ def check_indexedDisj_only_targets_transformed(self, transformation):

for i, j in pairs:
self.assertIs(trans.get_src_disjunct(disjBlock[j]), m.disjunct1[i])
self.assertIs(disjBlock[j], m.disjunct1[i].transformation_block())
self.assertIs(disjBlock[j], m.disjunct1[i].transformation_block)

def check_warn_for_untransformed(self, transformation, **kwargs):
# Check that we complain if we find an untransformed Disjunct inside of
Expand Down Expand Up @@ -723,7 +725,7 @@ def check_disjData_only_targets_transformed(self, transformation):
((2,1), 1),
]
for i, j in pairs:
self.assertIs(m.disjunct1[i].transformation_block(), disjBlock[j])
self.assertIs(m.disjunct1[i].transformation_block, disjBlock[j])
self.assertIs(trans.get_src_disjunct(disjBlock[j]), m.disjunct1[i])

def check_indexedBlock_targets_inactive(self, transformation, **kwargs):
Expand Down Expand Up @@ -796,7 +798,7 @@ def check_indexedBlock_only_targets_transformed(self, transformation):
disjBlock = disjBlock1
if blocknum == 1:
disjBlock = disjBlock2
self.assertIs(original[i].transformation_block(), disjBlock[j])
self.assertIs(original[i].transformation_block, disjBlock[j])
self.assertIs(trans.get_src_disjunct(disjBlock[j]), original[i])

def check_blockData_targets_inactive(self, transformation, **kwargs):
Expand Down Expand Up @@ -1154,7 +1156,7 @@ def check_block_only_targets_transformed(self, transformation):
(1,1),
]
for i, j in pairs:
self.assertIs(m.b.disjunct[i].transformation_block(), disjBlock[j])
self.assertIs(m.b.disjunct[i].transformation_block, disjBlock[j])
self.assertIs(trans.get_src_disjunct(disjBlock[j]), m.b.disjunct[i])

# common error messages
Expand Down Expand Up @@ -1481,7 +1483,7 @@ def check_disjunct_only_targets_transformed(self, transformation):
transform.get_src_disjunct(disjBlock[j]))
self.assertIs(disjBlock[j],
m.simpledisjunct.component(
'innerdisjunct%d'%i).transformation_block())
'innerdisjunct%d'%i).transformation_block)

def check_disjunctData_targets_inactive(self, transformation, **kwargs):
m = models.makeNestedDisjunctions()
Expand Down Expand Up @@ -1528,33 +1530,33 @@ def check_disjunctData_only_targets_transformed(self, transformation):
for i, j in pairs:
self.assertIs(transform.get_src_disjunct(disjBlock[j]),
m.disjunct[1].innerdisjunct[i])
self.assertIs(m.disjunct[1].innerdisjunct[i].transformation_block(),
self.assertIs(m.disjunct[1].innerdisjunct[i].transformation_block,
disjBlock[j])

def check_all_components_transformed(self, m):
# checks that all the disjunctive components claim to be transformed in the
# makeNestedDisjunctions_NestedDisjuncts model.
self.assertIsInstance(m.disj.algebraic_constraint(), Constraint)
self.assertIsInstance(m.d1.disj2.algebraic_constraint(), Constraint)
self.assertIsInstance(m.d1.transformation_block(), _BlockData)
self.assertIsInstance(m.d2.transformation_block(), _BlockData)
self.assertIsInstance(m.d1.d3.transformation_block(), _BlockData)
self.assertIsInstance(m.d1.d4.transformation_block(), _BlockData)
self.assertIsInstance(m.d1.transformation_block, _BlockData)
self.assertIsInstance(m.d2.transformation_block, _BlockData)
self.assertIsInstance(m.d1.d3.transformation_block, _BlockData)
self.assertIsInstance(m.d1.d4.transformation_block, _BlockData)

def check_transformation_blocks_nestedDisjunctions(self, m, transformation):
disjunctionTransBlock = m.disj.algebraic_constraint().parent_block()
transBlocks = disjunctionTransBlock.relaxedDisjuncts
self.assertEqual(len(transBlocks), 4)
if transformation == 'bigm':
self.assertIs(transBlocks[0], m.d1.d3.transformation_block())
self.assertIs(transBlocks[1], m.d1.d4.transformation_block())
self.assertIs(transBlocks[2], m.d1.transformation_block())
self.assertIs(transBlocks[3], m.d2.transformation_block())
self.assertIs(transBlocks[0], m.d1.d3.transformation_block)
self.assertIs(transBlocks[1], m.d1.d4.transformation_block)
self.assertIs(transBlocks[2], m.d1.transformation_block)
self.assertIs(transBlocks[3], m.d2.transformation_block)
if transformation == 'hull':
self.assertIs(transBlocks[2], m.d1.d3.transformation_block())
self.assertIs(transBlocks[3], m.d1.d4.transformation_block())
self.assertIs(transBlocks[0], m.d1.transformation_block())
self.assertIs(transBlocks[1], m.d2.transformation_block())
self.assertIs(transBlocks[2], m.d1.d3.transformation_block)
self.assertIs(transBlocks[3], m.d1.d4.transformation_block)
self.assertIs(transBlocks[0], m.d1.transformation_block)
self.assertIs(transBlocks[1], m.d2.transformation_block)

def check_nested_disjunction_target(self, transformation):
m = models.makeNestedDisjunctions_NestedDisjuncts()
Expand Down Expand Up @@ -1666,3 +1668,38 @@ def check_solution_obeys_logical_constraints(self, transformation, m):
self.assertEqual(results.solver.termination_condition,
TerminationCondition.optimal)
self.assertAlmostEqual(value(m.x), 8)

# test pickling transformed models

def check_pprint_equal(self, m, unpickle):
# This is almost the same as in the diff_apply_to_and_create_using test but
# we don't have to transform in the middle or mess with seeds.
m_buf = StringIO()
m.pprint(ostream=m_buf)
m_output = m_buf.getvalue()

unpickle_buf = StringIO()
unpickle.pprint(ostream=unpickle_buf)
unpickle_output = unpickle_buf.getvalue()
self.assertMultiLineEqual(m_output, unpickle_output)

def check_transformed_model_pickles(self, transformation):
# Do a model where we'll have to call logical_to_linear too.
m = models.makeLogicalConstraintsOnDisjuncts_NonlinearConvex()
trans = TransformationFactory('gdp.%s' % transformation)
trans.apply_to(m)

# pickle and unpickle the transformed model
unpickle = pickle.loads(pickle.dumps(m))

check_pprint_equal(self, m, unpickle)

def check_transformed_model_pickles_with_dill(self, transformation):
m = models.makeLogicalConstraintsOnDisjuncts_NonlinearConvex()
trans = TransformationFactory('gdp.%s' % transformation)
trans.apply_to(m)

# pickle and unpickle the transformed model
unpickle = dill.loads(dill.dumps(m))

check_pprint_equal(self, m, unpickle)
Loading

0 comments on commit f6606f9

Please sign in to comment.