Skip to content

Commit

Permalink
added kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
paulruvolo committed Apr 26, 2015
1 parent a7fed6f commit f3cfb3b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 32 deletions.
34 changes: 4 additions & 30 deletions ELLA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"collapsed": false,
"input": [
"%matplotlib inline\n",
"from sklearn.cross_validation import train_test_split\n",
"\n",
"def multi_task_train_test_split(Xs,Ys,train_size=0.5):\n",
" Xs_train = []\n",
Expand Down Expand Up @@ -66,32 +67,14 @@
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"Average explained variance score 0.998366292378\n",
"Average classification accuracy"
]
},
{
"output_type": "stream",
"stream": "stdout",
"text": [
" 0.905\n"
]
}
],
"prompt_number": 9
"outputs": []
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from scipy.io import loadmat\n",
"from sklearn.metrics import roc_auc_score\n",
"from sklearn.cross_validation import train_test_split\n",
"\n",
"data = loadmat('landminedata.mat')\n",
"\n",
Expand All @@ -107,7 +90,7 @@
"\n",
"Xs_lm_train, Xs_lm_test, Ys_lm_train, Ys_lm_test = multi_task_train_test_split(Xs_lm,Ys_lm,train_size=0.5) \n",
"\n",
"model = ELLA(d,k,LogisticRegression,mu=1,lam=10**-5)\n",
"model = ELLA(d,k,LogisticRegression,{'C':10**-1},mu=1,lam=10**-5)\n",
"for t in range(T):\n",
" model.fit(Xs_lm_train[t], Ys_lm_train[t], t)\n",
"print \"Average AUC:\", np.mean([roc_auc_score(Ys_lm_test[t],\n",
Expand All @@ -116,16 +99,7 @@
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"Average AUC: 0.810932123838\n"
]
}
],
"prompt_number": 12
"outputs": []
},
{
"cell_type": "code",
Expand Down
5 changes: 3 additions & 2 deletions ELLA.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class ELLA(object):
""" The ELLA model """
def __init__(self,d,k,base_learner,mu=1,lam=1):
def __init__(self,d,k,base_learner,base_learner_kwargs={},mu=1,lam=1):
""" Initializes a new model for the given base_learner.
d: the number of parameters for the base learner
k: the number of latent model components
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(self,d,k,base_learner,mu=1,lam=1):
raise Exception("Unsupported Base Learner")

self.base_learner = base_learner
self.base_learner_kwargs = base_learner_kwargs

def fit(self,X,y,task_id):
""" Fit the model to a new batch of training data. The task_id must
Expand All @@ -50,7 +51,7 @@ def fit(self,X,y,task_id):
task_id: the id of the task
"""
self.T += 1
single_task_model = self.base_learner(fit_intercept=False).fit(X,y)
single_task_model = self.base_learner(fit_intercept=False,**self.base_learner_kwargs).fit(X,y)
D_t = self.get_hessian(single_task_model, X, y)
D_t_sqrt = sqrtm(D_t)
theta_t = single_task_model.coef_
Expand Down

0 comments on commit f3cfb3b

Please sign in to comment.