forked from hmmlearn/hmmlearn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
15a0634
commit 1a2da34
Showing
2 changed files
with
174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
""" | ||
================================== | ||
Demonstration of sampling from HMM | ||
================================== | ||
This script shows how to sample points from a Hiden Markov Model (HMM): | ||
we use a 4-components with specified mean and covariance. | ||
The plot show the sequence of observations generated with the transitions | ||
between them. We can see that, as specified by our transition matrix, | ||
there are no transition between component 1 and 3. | ||
.. warning:: | ||
The HMM module and its functions will be removed in 0.17 | ||
as it no longer falls within the project's scope and API. | ||
""" | ||
print(__doc__) | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
from hmmlearn import hmm | ||
|
||
############################################################## | ||
# Prepare parameters for a 3-components HMM | ||
# Initial population probability | ||
start_prob = np.array([0.6, 0.3, 0.1, 0.0]) | ||
# The transition matrix, note that there are no transitions possible | ||
# between component 1 and 4 | ||
trans_mat = np.array([[0.7, 0.2, 0.0, 0.1], | ||
[0.3, 0.5, 0.2, 0.0], | ||
[0.0, 0.3, 0.5, 0.2], | ||
[0.2, 0.0, 0.2, 0.6]]) | ||
# The means of each component | ||
means = np.array([[0.0, 0.0], | ||
[0.0, 11.0], | ||
[9.0, 10.0], | ||
[11.0, -1.0], | ||
]) | ||
# The covariance of each component | ||
covars = .5 * np.tile(np.identity(2), (4, 1, 1)) | ||
|
||
# Build an HMM instance and set parameters | ||
model = hmm.GaussianHMM(4, "full", start_prob, trans_mat, | ||
random_state=42) | ||
|
||
# Instead of fitting it from the data, we directly set the estimated | ||
# parameters, the means and covariance of the components | ||
model.means_ = means | ||
model.covars_ = covars | ||
############################################################### | ||
|
||
# Generate samples | ||
X, Z = model.sample(500) | ||
|
||
# Plot the sampled data | ||
plt.plot(X[:, 0], X[:, 1], "-o", label="observations", ms=6, | ||
mfc="orange", alpha=0.7) | ||
|
||
# Indicate the component numbers | ||
for i, m in enumerate(means): | ||
plt.text(m[0], m[1], 'Component %i' % (i + 1), | ||
size=17, horizontalalignment='center', | ||
bbox=dict(alpha=.7, facecolor='w')) | ||
plt.legend(loc='best') | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
""" | ||
========================== | ||
Gaussian HMM of stock data | ||
========================== | ||
This script shows how to use Gaussian HMM. | ||
It uses stock price data, which can be obtained from yahoo finance. | ||
For more information on how to get stock prices with matplotlib, please refer | ||
to date_demo1.py of matplotlib. | ||
.. warning:: | ||
The HMM module and its functions will be removed in 0.17 | ||
as it no longer falls within the project's scope and API. | ||
""" | ||
|
||
from __future__ import print_function | ||
|
||
import datetime | ||
import numpy as np | ||
import pylab as pl | ||
from matplotlib.finance import quotes_historical_yahoo | ||
from matplotlib.dates import YearLocator, MonthLocator, DateFormatter | ||
from hmmlearn.hmm import GaussianHMM | ||
|
||
|
||
print(__doc__) | ||
|
||
############################################################################### | ||
# Downloading the data | ||
date1 = datetime.date(1995, 1, 1) # start date | ||
date2 = datetime.date(2012, 1, 6) # end date | ||
# get quotes from yahoo finance | ||
quotes = quotes_historical_yahoo("INTC", date1, date2) | ||
if len(quotes) == 0: | ||
raise SystemExit | ||
|
||
# unpack quotes | ||
dates = np.array([q[0] for q in quotes], dtype=int) | ||
close_v = np.array([q[2] for q in quotes]) | ||
volume = np.array([q[5] for q in quotes])[1:] | ||
|
||
# take diff of close value | ||
# this makes len(diff) = len(close_t) - 1 | ||
# therefore, others quantity also need to be shifted | ||
diff = close_v[1:] - close_v[:-1] | ||
dates = dates[1:] | ||
close_v = close_v[1:] | ||
|
||
# pack diff and volume for training | ||
X = np.column_stack([diff, volume]) | ||
|
||
############################################################################### | ||
# Run Gaussian HMM | ||
print("fitting to HMM and decoding ...", end='') | ||
n_components = 5 | ||
|
||
# make an HMM instance and execute fit | ||
model = GaussianHMM(n_components, covariance_type="diag", n_iter=1000) | ||
|
||
model.fit([X]) | ||
|
||
# predict the optimal sequence of internal hidden state | ||
hidden_states = model.predict(X) | ||
|
||
print("done\n") | ||
|
||
############################################################################### | ||
# print trained parameters and plot | ||
print("Transition matrix") | ||
print(model.transmat_) | ||
print() | ||
|
||
print("means and vars of each hidden state") | ||
for i in range(n_components): | ||
print("%dth hidden state" % i) | ||
print("mean = ", model.means_[i]) | ||
print("var = ", np.diag(model.covars_[i])) | ||
print() | ||
|
||
years = YearLocator() # every year | ||
months = MonthLocator() # every month | ||
yearsFmt = DateFormatter('%Y') | ||
fig = pl.figure() | ||
ax = fig.add_subplot(111) | ||
|
||
for i in range(n_components): | ||
# use fancy indexing to plot data in each state | ||
idx = (hidden_states == i) | ||
ax.plot_date(dates[idx], close_v[idx], 'o', label="%dth hidden state" % i) | ||
ax.legend() | ||
|
||
# format the ticks | ||
ax.xaxis.set_major_locator(years) | ||
ax.xaxis.set_major_formatter(yearsFmt) | ||
ax.xaxis.set_minor_locator(months) | ||
ax.autoscale_view() | ||
|
||
# format the coords message box | ||
ax.fmt_xdata = DateFormatter('%Y-%m-%d') | ||
ax.fmt_ydata = lambda x: '$%1.2f' % x | ||
ax.grid(True) | ||
|
||
fig.autofmt_xdate() | ||
pl.show() |