Skip to content

Commit

Permalink
add argument and reformat of predict
Browse files Browse the repository at this point in the history
  • Loading branch information
zzsqwq committed May 16, 2021
1 parent 231d862 commit c53630c
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
# Created by Zs on 21-5-1
#

import argparse

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import load_boston
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import r2_score
from torch import nn
import argparse


def performance_metric(y_true, y_predict):
""" Calculates and returns the performance score between
Expand Down Expand Up @@ -85,7 +87,7 @@ def draw_loss(self):
plt.title("The loss curve")
plt.xlabel("iteration step")
plt.ylabel("loss")
plt.savefig("Loss_curve.jpg",dpi=400)
plt.savefig("Loss_curve.jpg", dpi=400)
plt.show()

def predict(self, test_x):
Expand All @@ -102,48 +104,59 @@ def plot_tf(self, test_y, predict_y): # 绘制test和predict的图
plt.legend([line1, line2], ["y_predict", "y_groundtruth"])
plt.title("The curve of predict and groundtruth")
plt.ylabel("price")
plt.savefig('predict_groundtruth.png',dpi=400)
plt.savefig('predict_groundtruth.png', dpi=400)
plt.show()

def save_model(self,model_name='Boston.pt'):
torch.save(self.model,'Boston.pt')
def save_model(self, model_name='Boston.pt'):
torch.save(self.model, 'Boston.pt')

def load_model(self,weights_name='Boston.pt',learn_rate=0.1):
def load_model(self, weights_name='Boston.pt', learn_rate=0.1):
self.model = torch.load(weights_name)
self.criterion = nn.MSELoss()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learn_rate)


if __name__ == "__main__":

load_weights = False
load_cols = []
input_shape = 13
if len(load_cols)!=0:
input_shape=len(load_cols)
# load chosen cols of dataset,for example, choose RM PTRATIO LSTAT, load_cols = [5, 10, 12]
# load_cols = []
# input_shape = 13

parser = argparse.ArgumentParser()

parser.add_argument('--weights',type=str,default='Boston.pt',help='inital weights path')
parser.add_argument('--load_weights',action='store_true', help='load weights or not')
parser.add_argument('--weights', type=str, default='Boston.pt', help='inital weights path')
parser.add_argument('--load_weights', action='store_true', help='load weights or not')
parser.add_argument('--hidden_layer', type=int, default=1000, help="The dim of hidden_layer")
parser.add_argument('--learn_rate', type=float, default=0.01, help="The learning rate")
parser.add_argument("--input_shape", type=int, default=13,
help="The input_shape of networks,don't forget change load_cols")

parser.add_argument('--load_cols', nargs='+', type=int)
parser.add_argument('--epoch', type=int, default=10000, help="The epoch of train")

opt = parser.parse_args()

if len(opt.load_cols) != 0:
input_shape = len(opt.load_cols)
else:
input_shape = opt.input_shape

bos = boston()
x, y = bos.load_data(choose_col=load_cols)
x, y = bos.load_data(choose_col=opt.load_cols)
train_x, train_y, test_x, test_y = bos.split_data(x=x, y=y, split_size=0.2)

if not load_weights:
bos.init_model(input_layer=input_shape, hidden_layer=1000, learn_rate=0.01)
bos.train(train_x, train_y, epoch=10000)
bos.save_model(model_name='Boston1000.pt')
bos.init_model(input_layer=input_shape, hidden_layer=1, learn_rate=opt.learn_rate)
bos.train(train_x, train_y, epoch=opt.epoch)
bos.save_model(model_name='Boston.pt')
else:
bos.load_model(weights_name=opt.weights,learn_rate=0.01)
bos.load_model(weights_name=opt.weights, learn_rate=opt.learn_rate)

#predict_y = bos.predict(test_x)
#print(predict_y[:5].reshape(1, -1), '\n', test_y[:5].reshape(1, -1))
#print(bos.calc_loss(predict_y, test_y))
# predict_y = bos.predict(test_x)
# print(predict_y[:5].reshape(1, -1), '\n', test_y[:5].reshape(1, -1))
# print(bos.calc_loss(predict_y, test_y))
if not load_weights:
bos.draw_loss()
#bos.plot_tf(predict_y, test_y)
#print(performance_metric(predict_y,test_y))

# bos.plot_tf(predict_y, test_y)
# print(performance_metric(predict_y,test_y))

0 comments on commit c53630c

Please sign in to comment.