diff --git a/lib/soothsayer.ex b/lib/soothsayer.ex index a87f5c9..7caf3dc 100644 --- a/lib/soothsayer.ex +++ b/lib/soothsayer.ex @@ -44,6 +44,11 @@ defmodule Soothsayer do end def predict(%Model{} = model, %Series{} = x) do + %{combined: combined} = predict_components(model, x) + combined + end + + def predict_components(%Model{} = model, %Series{} = x) do processed_x = Preprocessor.prepare_data(DataFrame.new(%{"ds" => x}), nil, "ds", model.config.seasonality) @@ -57,9 +62,10 @@ defmodule Soothsayer do x_normalized = normalize_with_params(x_input, model.config.normalization.x) predictions = Model.predict(model, x_normalized) - denormalized_predictions = denormalize(predictions, model.config.normalization.y) - Nx.to_flat_list(denormalized_predictions) + Map.new(predictions, fn {key, node} -> + {key, denormalize(node, model.config.normalization.y)} + end) end defp get_seasonality_input(data, seasonality, config) do diff --git a/lib/soothsayer/model.ex b/lib/soothsayer/model.ex index 16e31cd..1f2e986 100644 --- a/lib/soothsayer/model.ex +++ b/lib/soothsayer/model.ex @@ -34,7 +34,14 @@ defmodule Soothsayer.Model do Axon.constant(0) end - Axon.add([trend, yearly_seasonality, weekly_seasonality]) + combined = Axon.add([trend, yearly_seasonality, weekly_seasonality]) + + Axon.container(%{ + combined: combined, + trend: trend, + yearly_seasonality: yearly_seasonality, + weekly_seasonality: weekly_seasonality + }) end def fit(model, x, y, epochs) do @@ -44,7 +51,7 @@ defmodule Soothsayer.Model do trained_params = model.network |> Axon.Loop.trainer( - :mean_squared_error, + &Axon.Losses.huber(&1, &2.combined, reduction: :mean), Polaris.Optimizers.adam(learning_rate: model.config.learning_rate) ) |> Axon.Loop.run(Stream.repeatedly(fn -> {x, y} end), initial_params, diff --git a/test/soothsayer_test.exs b/test/soothsayer_test.exs index 08d9dd7..a7e3295 100644 --- a/test/soothsayer_test.exs +++ b/test/soothsayer_test.exs @@ -41,7 +41,7 @@ defmodule SoothsayerTest do future_end = Date.add(future_start, 29) future_dates = Date.range(future_start, future_end) x_test = Series.from_list(Enum.to_list(future_dates)) - predictions = Soothsayer.predict(fitted_model, x_test) + predictions = Soothsayer.predict(fitted_model, x_test) |> Nx.to_flat_list() # Check if predictions follow the trend (with tolerance for noise) Enum.zip(predictions, future_dates) @@ -89,7 +89,7 @@ defmodule SoothsayerTest do future_end = Date.add(future_start, 29) future_dates = Date.range(future_start, future_end) x_test = Series.from_list(Enum.to_list(future_dates)) - predictions = Soothsayer.predict(fitted_model, x_test) + predictions = Soothsayer.predict(fitted_model, x_test) |> Nx.to_flat_list() # Check if predictions follow the seasonality pattern (with tolerance for noise) Enum.zip(predictions, future_dates) @@ -140,7 +140,7 @@ defmodule SoothsayerTest do future_end = Date.add(future_start, 29) future_dates = Date.range(future_start, future_end) x_test = Series.from_list(Enum.to_list(future_dates)) - predictions = Soothsayer.predict(fitted_model, x_test) + predictions = Soothsayer.predict(fitted_model, x_test) |> Nx.to_flat_list() # Check if predictions follow the trend and seasonality (with tolerance for noise) Enum.zip(predictions, future_dates)