Deprecate fetch_mldata (#11466)
* API Deprecate fetch_mldata and update examples
* Use pytest's filterwarnings
* Rm unused import
* Remove broken doctest
* Refer user to openml URL
* DOC whatsnew tweak
jnothman authored and rth committed Aug 18, 2018
1 parent 4752ea7 commit fc56da5
Showing 11 changed files with 100 additions and 144 deletions.
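The "Use pytest's filterwarnings" item in the message above refers to pytest's built-in warning-filter mark, which lets individual tests silence the new deprecation warning instead of filtering globally; a minimal sketch (test name and body hypothetical)::

    import pytest

    # filter strings follow 'action:message-prefix:category';
    # 'ignore::DeprecationWarning' silences that category for this test only
    @pytest.mark.filterwarnings('ignore::DeprecationWarning')
    def test_legacy_loader():
        ...  # would exercise the deprecated fetch_mldata here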
83 changes: 0 additions & 83 deletions doc/datasets/index.rst
@@ -351,89 +351,6 @@ features::

_`Faster API-compatible implementation`: https://github.com/mblondel/svmlight-loader

..
For doctests:
>>> import numpy as np
>>> import os
>>> import tempfile
>>> # Create a temporary folder for the data fetcher
>>> custom_data_home = tempfile.mkdtemp()
>>> os.makedirs(os.path.join(custom_data_home, 'mldata'))


.. _mldata:

Downloading datasets from the mldata.org repository
---------------------------------------------------

`mldata.org <http://mldata.org>`_ is a public repository for machine learning
data, supported by the `PASCAL network <http://www.pascal-network.org>`_ .

The ``sklearn.datasets`` package is able to directly download data
sets from the repository using the function
:func:`sklearn.datasets.fetch_mldata`.

For example, to download the MNIST digit recognition database::

>>> from sklearn.datasets import fetch_mldata
>>> mnist = fetch_mldata('MNIST original', data_home=custom_data_home)

The MNIST database contains a total of 70000 examples of handwritten digits
of size 28x28 pixels, labeled from 0 to 9::

>>> mnist.data.shape
(70000, 784)
>>> mnist.target.shape
(70000,)
>>> np.unique(mnist.target)
array([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])

After the first download, the dataset is cached locally in the path
specified by the ``data_home`` keyword argument, which defaults to
``~/scikit_learn_data/``::

>>> os.listdir(os.path.join(custom_data_home, 'mldata'))
['mnist-original.mat']

Data sets in `mldata.org <http://mldata.org>`_ do not adhere to a strict
naming or formatting convention. :func:`sklearn.datasets.fetch_mldata` is
able to make sense of the most common cases, but allows tailoring the
defaults to individual datasets:

* The data arrays in `mldata.org <http://mldata.org>`_ are most often
shaped as ``(n_features, n_samples)``. This is the opposite of the
``scikit-learn`` convention, so :func:`sklearn.datasets.fetch_mldata`
transposes the matrix by default. The ``transpose_data`` keyword controls
this behavior::

>>> iris = fetch_mldata('iris', data_home=custom_data_home)
>>> iris.data.shape
(150, 4)
>>> iris = fetch_mldata('iris', transpose_data=False,
... data_home=custom_data_home)
>>> iris.data.shape
(4, 150)

* For datasets with multiple columns, :func:`sklearn.datasets.fetch_mldata`
tries to identify the target and data columns and rename them to ``target``
and ``data``. This is done by looking for arrays named ``label`` and
``data`` in the dataset, and failing that by choosing the first array to be
``target`` and the second to be ``data``. This behavior can be changed with
the ``target_name`` and ``data_name`` keywords, setting them to a specific
name or index number (the names and order of the columns of a dataset
can be found on its `mldata.org <http://mldata.org>`_ page under the "Data" tab)::

>>> iris2 = fetch_mldata('datasets-UCI iris', target_name=1, data_name=0,
... data_home=custom_data_home)
>>> iris3 = fetch_mldata('datasets-UCI iris', target_name='class',
... data_name='double0', data_home=custom_data_home)


..
>>> import shutil
>>> shutil.rmtree(custom_data_home)
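The section above is removed wholesale because mldata.org is no longer operational; the examples below are migrated to :func:`sklearn.datasets.fetch_openml` instead. A minimal sketch of the equivalent MNIST fetch, with the dataset name and shapes taken from the updated examples (network access to openml.org assumed)::

    >>> from sklearn.datasets import fetch_openml
    >>> mnist = fetch_openml('mnist_784', version=1)  # https://www.openml.org/d/554
    >>> mnist.data.shape
    (70000, 784)
    >>> mnist.target.shape
    (70000,)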
.. _external_datasets:

Loading from external datasets
2 changes: 1 addition & 1 deletion doc/modules/classes.rst
@@ -257,7 +257,6 @@ Loaders
datasets.fetch_kddcup99
datasets.fetch_lfw_pairs
datasets.fetch_lfw_people
datasets.fetch_mldata
datasets.fetch_olivetti_faces
datasets.fetch_openml
datasets.fetch_rcv1
@@ -1513,6 +1512,7 @@ To be removed in 0.22
:template: deprecated_function.rst

covariance.graph_lasso
datasets.fetch_mldata


To be removed in 0.21
4 changes: 4 additions & 0 deletions doc/whats_new/v0.20.rst
@@ -209,6 +209,10 @@ Support for Python 3.3 has been officially dropped.
data points could be generated. :issue:`10045` by :user:`Christian Braune
<christianbraune79>`.

- |API| Deprecated :func:`sklearn.datasets.fetch_mldata` to be removed in
version 0.22. mldata.org is no longer operational. Until removal it will
remain possible to load cached datasets. :issue:`11466` by `Joel Nothman`_.

:mod:`sklearn.decomposition`
............................

49 changes: 42 additions & 7 deletions examples/gaussian_process/plot_gpr_co2.py
@@ -8,7 +8,7 @@
hyperparameter optimization using gradient ascent on the
log-marginal-likelihood. The data consists of the monthly average atmospheric
CO2 concentrations (in parts per million by volume (ppmv)) collected at the
Mauna Loa Observatory in Hawaii, between 1958 and 1997. The objective is to
Mauna Loa Observatory in Hawaii, between 1958 and 2001. The objective is to
model the CO2 concentration as a function of the time t.
The kernel is composed of several terms that are responsible for explaining
@@ -57,24 +57,59 @@
explained by the model. The figure also shows that the model makes very
confident predictions until around 2015.
"""
print(__doc__)

# Authors: Jan Hendrik Metzen <jhm@informatik.uni-bremen.de>
#
# License: BSD 3 clause

from __future__ import division, print_function

import numpy as np

from matplotlib import pyplot as plt

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels \
import RBF, WhiteKernel, RationalQuadratic, ExpSineSquared
from sklearn.datasets import fetch_mldata
try:
from urllib.request import urlopen
except ImportError:
# Python 2
from urllib2 import urlopen

print(__doc__)


data = fetch_mldata('mauna-loa-atmospheric-co2').data
X = data[:, [1]]
y = data[:, 0]
def load_mauna_loa_atmospheric_c02():
url = ('http://cdiac.ess-dive.lbl.gov/'
'ftp/trends/co2/sio-keel-flask/maunaloa_c.dat')
months = []
ppmv_sums = []
counts = []
for line in urlopen(url):
line = line.decode('utf8')
if not line.startswith('MLO'):
# ignore headers
continue
station, date, weight, flag, ppmv = line.split()
y = date[:2]
m = date[2:4]
month_float = (int(('20' if y < '20' else '19') + y) +
(int(m) - 1) / 12)
if not months or month_float != months[-1]:
months.append(month_float)
ppmv_sums.append(float(ppmv))
counts.append(1)
else:
# aggregate monthly sum to produce average
ppmv_sums[-1] += float(ppmv)
counts[-1] += 1

months = np.asarray(months).reshape(-1, 1)
avg_ppmvs = np.asarray(ppmv_sums) / counts
return months, avg_ppmvs


X, y = load_mauna_loa_atmospheric_c02()

# Kernel with parameters given in GPML book
k1 = 66.0**2 * RBF(length_scale=67.0) # long term smooth rising trend
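A quick check of the date convention used by the loader above: a two-digit year compared as a string against '20' selects the century, and the month index becomes a fractional year (true division is in effect via the ``__future__`` import)::

    # '5803' -> March 1958 -> 1958 + 2/12
    y, m = '5803'[:2], '5803'[2:4]
    month_float = int(('20' if y < '20' else '19') + y) + (int(m) - 1) / 12
    assert abs(month_float - (1958 + 2 / 12)) < 1e-9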
7 changes: 4 additions & 3 deletions examples/linear_model/plot_sgd_early_stopping.py
@@ -47,7 +47,7 @@
import matplotlib.pyplot as plt

from sklearn import linear_model
from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.utils.testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
@@ -56,9 +56,10 @@
print(__doc__)


def load_mnist(n_samples=None, class_0=0, class_1=8):
def load_mnist(n_samples=None, class_0='0', class_1='8'):
"""Load MNIST, select two classes, shuffle and return only n_samples."""
mnist = fetch_mldata('MNIST original')
# Load data from http://openml.org/d/554
mnist = fetch_openml('mnist_784', version=1)

# take only two classes for binary classification
mask = np.logical_or(mnist.target == class_0, mnist.target == class_1)
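The switch from ``class_0=0`` to ``class_0='0'`` follows from :func:`fetch_openml` returning the MNIST targets as strings rather than integers; a minimal sketch of the resulting selection mask (toy array, not the real targets)::

    import numpy as np

    # fetch_openml('mnist_784') targets look like '0' ... '9'
    target = np.array(['0', '8', '3', '8', '0'])
    mask = np.logical_or(target == '0', target == '8')
    binary_subset = target[mask]  # array(['0', '8', '8', '0'], dtype='<U1')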
@@ -20,7 +20,7 @@
import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_openml
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
@@ -35,9 +35,11 @@
t0 = time.time()
train_samples = 5000

mnist = fetch_mldata('MNIST original')
X = mnist.data.astype('float64')
# Load data from https://www.openml.org/d/554
mnist = fetch_openml('mnist_784', version=1)
X = mnist.data
y = mnist.target

random_state = check_random_state(0)
permutation = random_state.permutation(X.shape[0])
X = X[permutation]
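The shuffle above derives a reproducible row permutation from the seeded ``RandomState`` and applies it with fancy indexing; a minimal sketch::

    import numpy as np
    from sklearn.utils import check_random_state

    random_state = check_random_state(0)       # int seed -> RandomState
    permutation = random_state.permutation(4)  # some reordering of 0..3
    X = np.arange(8).reshape(4, 2)
    X = X[permutation]                         # rows shuffled, columns intact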
14 changes: 7 additions & 7 deletions examples/multioutput/plot_classifier_chain_yeast.py
@@ -32,24 +32,24 @@
with randomly ordered chains).
"""

print(__doc__)

# Author: Adam Kleczewski
# License: BSD 3 clause

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.multioutput import ClassifierChain
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import jaccard_similarity_score
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import fetch_mldata

# Load a multi-label dataset
yeast = fetch_mldata('yeast')
X = yeast['data']
Y = yeast['target'].transpose().toarray()
print(__doc__)

# Load a multi-label dataset from https://www.openml.org/d/40597
yeast = fetch_openml('yeast', version=4)
X = yeast.data
Y = yeast.target == 'TRUE'
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2,
random_state=0)

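``yeast.target == 'TRUE'`` works because the OpenML copy of the yeast dataset encodes each of its labels as the strings 'TRUE'/'FALSE'; the elementwise comparison then yields the boolean label matrix that ``ClassifierChain`` and ``OneVsRestClassifier`` expect. A minimal sketch (toy matrix, not the real labels)::

    import numpy as np

    target = np.array([['TRUE', 'FALSE', 'TRUE'],
                       ['FALSE', 'FALSE', 'TRUE']])
    Y = target == 'TRUE'  # boolean indicator matrix, shape (2, 3)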
13 changes: 8 additions & 5 deletions examples/neural_networks/plot_mnist_filters.py
@@ -20,15 +20,18 @@
for a very short time. Training longer would result in weights with a much
smoother spatial appearance.
"""
print(__doc__)

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_mldata
from sklearn.datasets import fetch_openml
from sklearn.neural_network import MLPClassifier

mnist = fetch_mldata("MNIST original")
print(__doc__)

# Load data from https://www.openml.org/d/554
mnist = fetch_openml('mnist_784', version=1)
X = mnist.data
y = mnist.target

# rescale the data, use the traditional train/test split
X, y = mnist.data / 255., mnist.target
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

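Dividing the pixel intensities by 255 maps them from 0-255 into [0, 1], a scaling that generally helps gradient-based training of the MLP; the 60000/10000 slice reproduces the traditional MNIST train/test split. A minimal sketch of the rescaling::

    import numpy as np

    X = np.array([[0., 127.5, 255.]])
    X = X / 255.  # array([[0. , 0.5, 1. ]])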
47 changes: 13 additions & 34 deletions sklearn/datasets/mldata.py
@@ -25,13 +25,19 @@

from .base import get_data_home
from ..utils import Bunch
from ..utils import deprecated

MLDATA_BASE_URL = "http://mldata.org/repository/data/download/matlab/%s"


@deprecated('mldata_filename was deprecated in version 0.20 and will be '
'removed in version 0.22')
def mldata_filename(dataname):
"""Convert a raw name for a data set in a mldata.org filename.
.. deprecated:: 0.20
Will be removed in version 0.22
Parameters
----------
dataname : str
@@ -46,10 +52,14 @@ def mldata_filename(dataname):
return re.sub(r'[().]', '', dataname)


@deprecated('fetch_mldata was deprecated in version 0.20 and will be removed '
'in version 0.22')
def fetch_mldata(dataname, target_name='label', data_name='data',
transpose_data=True, data_home=None):
"""Fetch an mldata.org data set
mldata.org is no longer operational.
If the file does not exist yet, it is downloaded from mldata.org.
mldata.org does not have an enforced convention for storing data or
@@ -70,6 +80,9 @@ def fetch_mldata(dataname, target_name='label', data_name='data',
mldata.org data sets may have multiple columns, which are stored in the
Bunch object with their original name.
.. deprecated:: 0.20
Will be removed in version 0.22
Parameters
----------
@@ -99,40 +112,6 @@ def fetch_mldata(dataname, target_name='label', data_name='data',
'data', the data to learn, 'target', the classification labels,
'DESCR', the full description of the dataset, and
'COL_NAMES', the original names of the dataset columns.
Examples
--------
Load the 'iris' dataset from mldata.org:
>>> from sklearn.datasets.mldata import fetch_mldata
>>> import tempfile
>>> test_data_home = tempfile.mkdtemp()
>>> iris = fetch_mldata('iris', data_home=test_data_home)
>>> iris.target.shape
(150,)
>>> iris.data.shape
(150, 4)
Load the 'leukemia' dataset from mldata.org, which needs to be transposed
to respect the scikit-learn axes convention:
>>> leuk = fetch_mldata('leukemia', transpose_data=True,
... data_home=test_data_home)
>>> leuk.data.shape
(72, 7129)
Load an alternative 'iris' dataset, which has different names for the
columns:
>>> iris2 = fetch_mldata('datasets-UCI iris', target_name=1,
... data_name=0, data_home=test_data_home)
>>> iris3 = fetch_mldata('datasets-UCI iris',
... target_name='class', data_name='double0',
... data_home=test_data_home)
>>> import shutil
>>> shutil.rmtree(test_data_home)
"""

# normalize dataset name
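The ``@deprecated`` decorator applied above comes from ``sklearn.utils``; it wraps the function so each call emits a deprecation warning containing the given text. A minimal sketch of that behavior (function name hypothetical; the exact message format and warning category may vary between versions)::

    import warnings
    from sklearn.utils import deprecated

    @deprecated('my_fetch was deprecated in version 0.20 and will be '
                'removed in version 0.22')
    def my_fetch():
        return 42

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        my_fetch()
    print(caught[-1].message)  # includes the text passed to @deprecated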
