From f3cfb3b6ed679b8b76f99f1db03a1c313dca3b31 Mon Sep 17 00:00:00 2001 From: paulruvolo Date: Sat, 25 Apr 2015 22:52:28 -0400 Subject: [PATCH] added kwargs --- ELLA.ipynb | 34 ++++------------------------------ ELLA.py | 5 +++-- 2 files changed, 7 insertions(+), 32 deletions(-) diff --git a/ELLA.ipynb b/ELLA.ipynb index de98abc..e7e59a3 100644 --- a/ELLA.ipynb +++ b/ELLA.ipynb @@ -12,6 +12,7 @@ "collapsed": false, "input": [ "%matplotlib inline\n", + "from sklearn.cross_validation import train_test_split\n", "\n", "def multi_task_train_test_split(Xs,Ys,train_size=0.5):\n", " Xs_train = []\n", @@ -66,24 +67,7 @@ ], "language": "python", "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "stream": "stdout", - "text": [ - "Average explained variance score 0.998366292378\n", - "Average classification accuracy" - ] - }, - { - "output_type": "stream", - "stream": "stdout", - "text": [ - " 0.905\n" - ] - } - ], - "prompt_number": 9 + "outputs": [] }, { "cell_type": "code", @@ -91,7 +75,6 @@ "input": [ "from scipy.io import loadmat\n", "from sklearn.metrics import roc_auc_score\n", - "from sklearn.cross_validation import train_test_split\n", "\n", "data = loadmat('landminedata.mat')\n", "\n", @@ -107,7 +90,7 @@ "\n", "Xs_lm_train, Xs_lm_test, Ys_lm_train, Ys_lm_test = multi_task_train_test_split(Xs_lm,Ys_lm,train_size=0.5) \n", "\n", - "model = ELLA(d,k,LogisticRegression,mu=1,lam=10**-5)\n", + "model = ELLA(d,k,LogisticRegression,{'C':10**-1},mu=1,lam=10**-5)\n", "for t in range(T):\n", " model.fit(Xs_lm_train[t], Ys_lm_train[t], t)\n", "print \"Average AUC:\", np.mean([roc_auc_score(Ys_lm_test[t],\n", @@ -116,16 +99,7 @@ ], "language": "python", "metadata": {}, - "outputs": [ - { - "output_type": "stream", - "stream": "stdout", - "text": [ - "Average AUC: 0.810932123838\n" - ] - } - ], - "prompt_number": 12 + "outputs": [] }, { "cell_type": "code", diff --git a/ELLA.py b/ELLA.py index 9b6f06c..93a8b91 100644 --- a/ELLA.py +++ b/ELLA.py @@ -11,7 +11,7 @@ class ELLA(object): """ The ELLA model """ - def __init__(self,d,k,base_learner,mu=1,lam=1): + def __init__(self,d,k,base_learner,base_learner_kwargs={},mu=1,lam=1): """ Initializes a new model for the given base_learner. d: the number of parameters for the base learner k: the number of latent model components @@ -39,6 +39,7 @@ def __init__(self,d,k,base_learner,mu=1,lam=1): raise Exception("Unsupported Base Learner") self.base_learner = base_learner + self.base_learner_kwargs = base_learner_kwargs def fit(self,X,y,task_id): """ Fit the model to a new batch of training data. The task_id must @@ -50,7 +51,7 @@ def fit(self,X,y,task_id): task_id: the id of the task """ self.T += 1 - single_task_model = self.base_learner(fit_intercept=False).fit(X,y) + single_task_model = self.base_learner(fit_intercept=False,**self.base_learner_kwargs).fit(X,y) D_t = self.get_hessian(single_task_model, X, y) D_t_sqrt = sqrtm(D_t) theta_t = single_task_model.coef_