Skip to content

Commit

Permalink
server: unify setting the variant
Browse files Browse the repository at this point in the history
Fix a mismatch between the variant of the server and the variant of the
SubiquityModel. POST requests to /meta/client_variant would modify the
variant for both the server and the SubiquityModel, while POST requests to
the source controller would update only the SubiquityModel.

Fix this by making all requests to change the variant go through a
single function in the server and let that be responsible for updating
itself and the SubiquityModel.

(Conceptually the Source controller should not be responsible for
updating the SubiquityModel.)
  • Loading branch information
Chris-Peterson444 committed Jan 21, 2025
1 parent b570dfb commit 0e5ef2f
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 4 deletions.
2 changes: 1 addition & 1 deletion subiquity/server/controllers/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def get_handler(
async def configured(self):
await super().configured()
self._configured = True
self.app.base_model.set_source_variant(self.model.current.variant)
self.app.set_source_variant(self.model.current.variant)

async def POST(self, source_id: str, search_drivers: bool = False) -> None:
# Marking the source model configured has an effect on many of the
Expand Down
19 changes: 19 additions & 0 deletions subiquity/server/controllers/tests/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,22 @@ def test_install_source_detection__autoinstall(self, catalog, ai_data, expected)
self._set_source_catalog(catalog)
self.controller.load_autoinstall_data(ai_data)
self.assertEqual(self.controller.model.current.variant, expected)

async def test_on_configure_update_variant(self):
"""Test update variant through server on configure.
Ensure the source controller doesn't update variant on the base
model directly.
"""
app = self.controller.app = unittest.mock.Mock()
model = self.controller.app.base_model = unittest.mock.Mock()

self.controller.model.current.variant = "mock-variant"

with unittest.mock.patch(
"subiquity.server.controller.SubiquityController.configured"
):
await self.controller.configured()

app.set_source_variant.assert_called_with("mock-variant")
model.set_source_variant.assert_not_called()
26 changes: 23 additions & 3 deletions subiquity/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ async def mark_configured_POST(self, endpoint_names: List[str]) -> None:
async def client_variant_POST(self, variant: str) -> None:
if variant not in self.app.supported_variants:
raise ValueError(f"unrecognized client variant {variant}")
self.app.base_model.set_source_variant(variant)
self.app.set_source_variant(variant)

async def client_variant_GET(self) -> str:
Expand Down Expand Up @@ -291,6 +290,10 @@ def make_model(self):
root = "/"
if self.opts.dry_run:
root = os.path.abspath(self.opts.output_base)
# TODO: Set the model source variant before returning it?
# This _will_ eventually get set by the source controller,
# but before then it's in a state that only requires the
# "default" models i.e., the base set all variants require.
return SubiquityModel(
root,
self.hub,
Expand All @@ -302,7 +305,7 @@ def make_model(self):
def __init__(self, opts, block_log_dir):
super().__init__(opts)
self.dr_cfg: Optional[DRConfig] = None
self.set_source_variant(self.supported_variants[0])
self._set_source_variant(self.supported_variants[0])
self.block_log_dir = block_log_dir
self.cloud_init_ok = None
self.state_event = asyncio.Event()
Expand Down Expand Up @@ -353,9 +356,26 @@ def __init__(self, opts, block_log_dir):

self.geoip = GeoIP(self, strategy=geoip_strategy)

def set_source_variant(self, variant):
def _set_source_variant(self, variant):
self.variant = variant

def set_source_variant(self, variant):
"""Set the source variant for the install.
This is the public interface for setting the variant for the install.
This ensures that both the server and the model's understanding of the
variant is updated in one place.
Any extra logic for updating the variant in the server should go into
the private method _set_source_variant. This is separated out because
the sever needs to seed the initial variant state during __init__
but the base_model isn't attached to the server object until the .Run()
method is called.
"""
self._set_source_variant(variant)

self.base_model.set_source_variant(variant)

def load_serialized_state(self):
for controller in self.controllers.instances:
controller.load_state()
Expand Down
15 changes: 15 additions & 0 deletions subiquity/server/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,3 +744,18 @@ def test_push_error_events(self, interactive):
(message,) = journal_send_mock.call_args.args
self.assertIn("message", message)
self.assertNotIn("description", message)


class TestVariantHandling(SubiTestCase):
async def asyncSetUp(self):
opts = Mock()
opts.dry_run = True
opts.output_base = self.tmp_dir()
opts.machine_config = NOPROBERARG
self.server = SubiquityServer(opts, None)

def test_set_source_variant(self):
self.server.base_model = Mock()
self.server.set_source_variant("mock-variant")
self.assertEqual(self.server.variant, "mock-variant")
self.server.base_model.set_source_variant.assert_called_with("mock-variant")
16 changes: 16 additions & 0 deletions subiquity/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2330,3 +2330,19 @@ async def test_supported_variants(self, variant, is_supported):
"unrecognized client variant foo-bar",
json.loads(cre.headers["x-error-msg"]),
)

async def test_post_source_update_server_variant(self):
"""Test POSTing to source will correctly update Server variant."""

extra_args = ["--source-catalog", "examples/sources/mixed.yaml"]
async with start_server(
"examples/machines/simple.json",
extra_args=extra_args,
) as inst:
resp = await inst.get("/meta/client_variant")
self.assertEqual(resp, "server")

await inst.post("/source", source_id="ubuntu-desktop")

resp = await inst.get("/meta/client_variant")
self.assertEqual(resp, "desktop")

0 comments on commit 0e5ef2f

Please sign in to comment.