Skip to content

Commit

Permalink
feat: Output components with combined value
Browse files Browse the repository at this point in the history
Co-authored-by: Paulo Valente <16843419+polvalente@users.noreply.github.com>
  • Loading branch information
georgeguimaraes and polvalente committed Aug 29, 2024
1 parent 40407f6 commit a8e04b2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
10 changes: 8 additions & 2 deletions lib/soothsayer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ defmodule Soothsayer do
end

def predict(%Model{} = model, %Series{} = x) do
%{combined: combined} = predict_components(model, x)
combined
end

def predict_components(%Model{} = model, %Series{} = x) do
processed_x =
Preprocessor.prepare_data(DataFrame.new(%{"ds" => x}), nil, "ds", model.config.seasonality)

Expand All @@ -57,9 +62,10 @@ defmodule Soothsayer do
x_normalized = normalize_with_params(x_input, model.config.normalization.x)

predictions = Model.predict(model, x_normalized)
denormalized_predictions = denormalize(predictions, model.config.normalization.y)

Nx.to_flat_list(denormalized_predictions)
Map.new(predictions, fn {key, node} ->
{key, denormalize(node, model.config.normalization.y)}
end)
end

defp get_seasonality_input(data, seasonality, config) do
Expand Down
11 changes: 9 additions & 2 deletions lib/soothsayer/model.ex
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ defmodule Soothsayer.Model do
Axon.constant(0)
end

Axon.add([trend, yearly_seasonality, weekly_seasonality])
combined = Axon.add([trend, yearly_seasonality, weekly_seasonality])

Axon.container(%{
combined: combined,
trend: trend,
yearly_seasonality: yearly_seasonality,
weekly_seasonality: weekly_seasonality
})
end

def fit(model, x, y, epochs) do
Expand All @@ -44,7 +51,7 @@ defmodule Soothsayer.Model do
trained_params =
model.network
|> Axon.Loop.trainer(
:mean_squared_error,
&Axon.Losses.huber(&1, &2.combined, reduction: :mean),
Polaris.Optimizers.adam(learning_rate: model.config.learning_rate)
)
|> Axon.Loop.run(Stream.repeatedly(fn -> {x, y} end), initial_params,
Expand Down
6 changes: 3 additions & 3 deletions test/soothsayer_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ defmodule SoothsayerTest do
future_end = Date.add(future_start, 29)
future_dates = Date.range(future_start, future_end)
x_test = Series.from_list(Enum.to_list(future_dates))
predictions = Soothsayer.predict(fitted_model, x_test)
predictions = Soothsayer.predict(fitted_model, x_test) |> Nx.to_flat_list()

# Check if predictions follow the trend (with tolerance for noise)
Enum.zip(predictions, future_dates)
Expand Down Expand Up @@ -89,7 +89,7 @@ defmodule SoothsayerTest do
future_end = Date.add(future_start, 29)
future_dates = Date.range(future_start, future_end)
x_test = Series.from_list(Enum.to_list(future_dates))
predictions = Soothsayer.predict(fitted_model, x_test)
predictions = Soothsayer.predict(fitted_model, x_test) |> Nx.to_flat_list()

# Check if predictions follow the seasonality pattern (with tolerance for noise)
Enum.zip(predictions, future_dates)
Expand Down Expand Up @@ -140,7 +140,7 @@ defmodule SoothsayerTest do
future_end = Date.add(future_start, 29)
future_dates = Date.range(future_start, future_end)
x_test = Series.from_list(Enum.to_list(future_dates))
predictions = Soothsayer.predict(fitted_model, x_test)
predictions = Soothsayer.predict(fitted_model, x_test) |> Nx.to_flat_list()

# Check if predictions follow the trend and seasonality (with tolerance for noise)
Enum.zip(predictions, future_dates)
Expand Down

0 comments on commit a8e04b2

Please sign in to comment.