Skip to content

Commit

Permalink
Fix bug in TimeSeasonality when pop_state = False (#407)
Browse files Browse the repository at this point in the history
Rename `pop_state` to `remove_first_state` for clarity.

Add test for `remove_first_state = False`
  • Loading branch information
jessegrabowski authored Jan 2, 2025
1 parent dcc2bec commit be77d9d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 9 deletions.
27 changes: 21 additions & 6 deletions pymc_extras/statespace/models/structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,12 @@ class TimeSeasonality(Component):
If None, states will be numbered ``[State_0, ..., State_s]``
remove_first_state: bool, default True
If True, the first state will be removed from the model. This is done because there are only n-1 degrees of
freedom in the seasonal component, and one state is not identified. If False, the first state will be
included in the model, but it will not be identified -- you will need to handle this in the priors (e.g. with
ZeroSumNormal).
Notes
-----
A seasonal effect is any pattern that repeats every fixed interval. Although there are many possible ways to
Expand Down Expand Up @@ -1163,7 +1169,7 @@ def __init__(
innovations: bool = True,
name: str | None = None,
state_names: list | None = None,
pop_state: bool = True,
remove_first_state: bool = True,
):
if name is None:
name = f"Seasonal[s={season_length}]"
Expand All @@ -1176,14 +1182,15 @@ def __init__(
)
state_names = state_names.copy()
self.innovations = innovations
self.pop_state = pop_state
self.remove_first_state = remove_first_state

if self.pop_state:
if self.remove_first_state:
# In traditional models, the first state isn't identified, so we can help out the user by automatically
# discarding it.
# TODO: Can this be stashed and reconstructed automatically somehow?
state_names.pop(0)
k_states = season_length - 1

k_states = season_length - int(self.remove_first_state)

super().__init__(
name=name,
Expand Down Expand Up @@ -1218,8 +1225,16 @@ def populate_component_properties(self):
self.shock_names = [f"{self.name}"]

def make_symbolic_graph(self) -> None:
T = np.eye(self.k_states, k=-1)
T[0, :] = -1
if self.remove_first_state:
# In this case, parameters are normalized to sum to zero, so the current state is the negative sum of
# all previous states.
T = np.eye(self.k_states, k=-1)
T[0, :] = -1
else:
# In this case we assume the user to be responsible for ensuring the states sum to zero, so T is just a
# circulant matrix that cycles between the states.
T = np.eye(self.k_states, k=1)
T[-1, 0] = 1

self.ssm["transition", :, :] = T
self.ssm["design", 0, 0] = 1
Expand Down
13 changes: 10 additions & 3 deletions tests/statespace/test_structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings

from collections import defaultdict
from copyreg import remove_extension
from typing import Optional

import numpy as np
Expand Down Expand Up @@ -592,13 +593,18 @@ def test_autoregressive_model(order, rng):

@pytest.mark.parametrize("s", [10, 25, 50])
@pytest.mark.parametrize("innovations", [True, False])
def test_time_seasonality(s, innovations, rng):
@pytest.mark.parametrize("remove_first_state", [True, False])
def test_time_seasonality(s, innovations, remove_first_state, rng):
def random_word(rng):
return "".join(rng.choice(list("abcdefghijklmnopqrstuvwxyz")) for _ in range(5))

state_names = [random_word(rng) for _ in range(s)]
mod = st.TimeSeasonality(
season_length=s, innovations=innovations, name="season", state_names=state_names
season_length=s,
innovations=innovations,
name="season",
state_names=state_names,
remove_first_state=remove_first_state,
)
x0 = np.zeros(mod.k_states, dtype=floatX)
x0[0] = 1
Expand All @@ -615,7 +621,8 @@ def random_word(rng):
# Check coords
mod.build(verbose=False)
_assert_basic_coords_correct(mod)
assert mod.coords["season_state"] == state_names[1:]
test_slice = slice(1, None) if remove_first_state else slice(None)
assert mod.coords["season_state"] == state_names[test_slice]


def get_shift_factor(s):
Expand Down

0 comments on commit be77d9d

Please sign in to comment.