-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
1,179 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
torch.set_default_dtype(torch.float) | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
class GPT(nn.Module): | ||
def __init__(self, layers, lmbda, eps, P, c_initial, xt_resid, IC_xt, | ||
BC_xt, IC_u, BC_u, f_hat, activation_resid, activation_IC, | ||
activation_BC, Pt_lPxx_eP_term): | ||
super().__init__() | ||
self.layers = layers | ||
|
||
self.lmbda = lmbda | ||
self.eps = eps | ||
|
||
self.loss_function = nn.MSELoss(reduction='mean') | ||
self.linears = nn.ModuleList([nn.Linear(layers[i], layers[i+1], bias=False) for i in range(len(layers)-1)]) | ||
self.activation = P | ||
|
||
self.activation_IC = activation_IC | ||
self.activation_BC = activation_BC | ||
self.activation_resid = activation_resid | ||
self.Pt_lPxx_eP_term = Pt_lPxx_eP_term | ||
|
||
self.IC_u = IC_u | ||
self.BC_u = BC_u | ||
self.f_hat = f_hat | ||
self.xt_resid = xt_resid | ||
self.IC_xt = IC_xt | ||
self.BC_xt = BC_xt | ||
|
||
self.linears[0].weight.data = torch.ones(self.layers[1], self.layers[0]) | ||
self.linears[1].weight.data = c_initial | ||
|
||
def forward(self, datatype=None, test_data=None): | ||
if test_data is not None: | ||
a = torch.Tensor().to(device) | ||
for i in range(0, self.layers[1]): | ||
a = torch.cat((a, self.activation[i](test_data)), 1) | ||
final_output = self.linears[-1](a) | ||
|
||
return final_output | ||
|
||
if datatype == 'residual': | ||
final_output = self.linears[-1](self.activation_resid).to(device) | ||
return final_output | ||
|
||
if datatype == 'initial': | ||
final_output = self.linears[-1](self.activation_IC).to(device) | ||
return final_output | ||
|
||
if datatype == 'boundary': | ||
final_output = self.linears[-1](self.activation_BC).to(device) | ||
return final_output | ||
|
||
def lossR(self): | ||
"""Residual loss function""" | ||
u = self.forward(datatype='residual') | ||
|
||
ut_luxx_eu = torch.matmul(self.Pt_lPxx_eP_term, self.linears[1].weight.data[0][:,None]) | ||
eu3 = torch.mul(self.eps, torch.pow(u,3)) | ||
f = torch.add(ut_luxx_eu, eu3) | ||
|
||
return self.loss_function(f, self.f_hat) | ||
|
||
def lossICBC(self, datatype): | ||
"""Initial and both boundary condition loss function""" | ||
if datatype=='initial': | ||
return self.loss_function(self.forward(datatype), self.IC_u) | ||
|
||
elif datatype=='boundary': | ||
return self.loss_function(self.forward(datatype), self.BC_u) | ||
|
||
def loss(self): | ||
"""Total Loss Function""" | ||
loss_R = self.lossR() | ||
loss_IC = self.lossICBC(datatype='initial') | ||
loss_BC = self.lossICBC(datatype='boundary') | ||
return loss_R + loss_IC + loss_BC |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch | ||
import torch.nn as nn | ||
torch.set_default_dtype(torch.float) | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
class P(nn.Module): | ||
"""GPT-PINN Activation Function""" | ||
def __init__(self, w1, w2, w3, w4, w5, b1, b2, b3, b4, b5): | ||
super().__init__() | ||
self.layers = [2, 128, 128, 128, 128, 1] | ||
self.linears = nn.ModuleList([nn.Linear(self.layers[i], self.layers[i+1]) for i in range(len(self.layers)-1)]) | ||
|
||
self.linears[0].weight.data = torch.Tensor(w1) | ||
self.linears[1].weight.data = torch.Tensor(w2) | ||
self.linears[2].weight.data = torch.Tensor(w3) | ||
self.linears[3].weight.data = torch.Tensor(w4) | ||
self.linears[4].weight.data = torch.Tensor(w5).view(1,self.layers[4]) | ||
|
||
self.linears[0].bias.data = torch.Tensor(b1) | ||
self.linears[1].bias.data = torch.Tensor(b2) | ||
self.linears[2].bias.data = torch.Tensor(b3) | ||
self.linears[3].bias.data = torch.Tensor(b4) | ||
self.linears[4].bias.data = torch.Tensor(b5).view(-1) | ||
|
||
self.activation = nn.Tanh() | ||
|
||
def forward(self, x): | ||
a = x | ||
for i in range(0, len(self.layers)-2): | ||
z = self.linears[i](a) | ||
a = self.activation(z) | ||
a = self.linears[-1](a) | ||
return a |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
torch.set_default_dtype(torch.float) | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
class grad_descent(object): | ||
def __init__(self, lmbda, eps, | ||
xt_resid, IC_xt, BC_xt, IC_u, BC_u, | ||
P_resid_values, P_IC_values, P_BC_values, Pt_lPxx_eP_term, lr_gpt): | ||
# PDE parameters | ||
self.lmbda = lmbda | ||
self.eps = eps | ||
|
||
# Data sizes | ||
self.N_R = xt_resid.shape[0] | ||
self.N_IC = IC_xt.shape[0] | ||
self.N_BC = BC_xt.shape[0] | ||
|
||
# Pre-computed terms | ||
self.P_resid_values = P_resid_values | ||
self.P_BC_values = P_BC_values | ||
self.P_IC_values = P_IC_values | ||
self.Pt_lPxx_eP_term = Pt_lPxx_eP_term | ||
|
||
# Training data | ||
self.IC_u = IC_u | ||
self.BC_u = BC_u | ||
|
||
# Optimizer data/parameter | ||
self.lr_gpt = lr_gpt | ||
|
||
def grad_loss(self, c): | ||
c = c.to(device) | ||
####################################################################### | ||
####################################################################### | ||
######################### Residual Gradient ######################### | ||
|
||
ut_luxx_eu = torch.matmul(self.Pt_lPxx_eP_term, c[:,None]) | ||
u = torch.matmul(self.P_resid_values, c[:,None]) | ||
eu3 = torch.mul(self.eps, torch.pow(u,3)) | ||
first_product = torch.add(ut_luxx_eu, eu3) | ||
|
||
term1 = torch.mul(3*self.eps, torch.mul(torch.square(u), self.P_resid_values)) | ||
second_product1 = torch.add(self.Pt_lPxx_eP_term, term1) | ||
grad_list = torch.mul(2/self.N_R, torch.sum(torch.mul(first_product, second_product1), axis=0)) | ||
|
||
####################################################################### | ||
####################################################################### | ||
################### Boundary and Initial Gradient ################### | ||
|
||
BC_term = torch.matmul(self.P_BC_values, c[:,None]) | ||
BC_term = torch.sub(BC_term, self.BC_u) | ||
|
||
IC_term = torch.matmul(self.P_IC_values, c[:,None]) | ||
IC_term = torch.sub(IC_term, self.IC_u) | ||
|
||
grad_list[:c.shape[0]] += torch.mul(2/self.N_BC, torch.sum(torch.mul(BC_term, self.P_BC_values), axis=0)) | ||
grad_list[:c.shape[0]] += torch.mul(2/self.N_IC, torch.sum(torch.mul(IC_term, self.P_IC_values), axis=0)) | ||
|
||
return grad_list | ||
|
||
def update(self, c): | ||
c = torch.sub(c, torch.mul(self.lr_gpt, self.grad_loss(c))) | ||
return c.expand(1,c.shape[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import torch | ||
import torch.autograd as autograd | ||
|
||
torch.set_default_dtype(torch.float) | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
def autograd_calculations(xt_resid, P): | ||
"""Compute graidents w.r.t t and xx for the residual data""" | ||
xt_resid = xt_resid.to(device).requires_grad_() | ||
Pi = P(xt_resid).to(device) | ||
P_xt = autograd.grad(Pi, xt_resid, torch.ones(xt_resid.shape[0], 1).to(device), create_graph=True)[0] | ||
P_xx_tt = autograd.grad(P_xt, xt_resid, torch.ones(xt_resid.shape).to(device), create_graph=True)[0] | ||
|
||
P_t = P_xt[:,[1]] | ||
P_xx = P_xx_tt[:,[0]] | ||
|
||
return P_t, P_xx | ||
|
||
def Pt_lPxx_eP(P_t, P_xx, P, lmbda, eps): | ||
"""Pt - lambda*Pxx - epsilon*P""" | ||
eP = torch.mul(-eps, P) | ||
lPxx = torch.mul(-lmbda, P_xx) | ||
pt_lPxx = torch.add(P_t, lPxx) | ||
return torch.add(pt_lPxx, eP) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
from AC_GPT_optimizer import grad_descent | ||
torch.set_default_dtype(torch.float) | ||
|
||
def gpt_train(GPT_PINN, lmbda, eps, xt_resid, IC_xt, BC_xt, IC_u, BC_u, | ||
P_resid_values, P_IC_values, | ||
P_BC_values, Pt_lPxx_eP_term, | ||
lr_gpt, epochs_gpt, largest_loss=None, largest_case=None, | ||
testing=False): | ||
|
||
GD = grad_descent(lmbda, eps, xt_resid, IC_xt, BC_xt, IC_u, BC_u, | ||
P_resid_values, P_IC_values, P_BC_values, Pt_lPxx_eP_term, lr_gpt) | ||
|
||
if (testing == False): | ||
loss_values = GPT_PINN.loss() | ||
for i in range(1, epochs_gpt+1): | ||
if (loss_values < largest_loss): | ||
break | ||
|
||
else: | ||
c = GPT_PINN.linears[1].weight.data.view(-1) | ||
GPT_PINN.linears[1].weight.data = GD.update(c) | ||
|
||
if (i == epochs_gpt): | ||
largest_case = [lmbda, eps] | ||
largest_loss = GPT_PINN.loss() | ||
|
||
loss_values = GPT_PINN.loss() | ||
return largest_loss, largest_case | ||
|
||
elif (testing): | ||
for i in range(1, epochs_gpt+1): | ||
c = GPT_PINN.linears[1].weight.data.view(-1) | ||
GPT_PINN.linears[1].weight.data = GD.update(c) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
plt.style.use(['science', 'notebook']) | ||
|
||
def AC_plot(t_test, x_test, u_test, title, cmap="rainbow", scale=150): | ||
"""Allen-Cahn Contour Plot""" | ||
|
||
shape = [int(np.sqrt(u_test.shape[0])), int(np.sqrt(u_test.shape[0]))] | ||
|
||
x = x_test.reshape(shape) | ||
t = t_test.reshape(shape) | ||
u = u_test.reshape(shape) | ||
|
||
fig, ax = plt.subplots(dpi=150, figsize=(10,8)) | ||
cp = ax.contourf(t, x, u, scale, cmap=cmap) | ||
cbar = fig.colorbar(cp) | ||
|
||
cbar.ax.tick_params(labelsize=18) | ||
ax.set_xlabel("$t$", fontsize=25) | ||
ax.set_ylabel("$x$", fontsize=25) | ||
ax.set_xticks(ticks=[0.0, 0.25, 0.5, 0.75, 1.0], labels=[0.0, 0.25, 0.5, 0.75, 1.0], fontsize=18) | ||
ax.set_yticks(ticks=[-1.0, -0.5, 0.0, 0.5, 1.0], labels=[-1.0, -0.5, 0.0, 0.5, 1.0], fontsize=18) | ||
ax.set_title(title, fontsize=20) | ||
|
||
plt.show() | ||
|
||
def loss_plot(epochs_adam_sa, epochs_lbfgs_sa, adam_loss, lbfgs_loss, title=None, dpi=150, figsize=(10,8)): | ||
"""Training losses""" | ||
|
||
x_adam = range(0,epochs_adam_sa+250,250) | ||
x_lbfgs = range(x_adam[-1]+5,epochs_adam_sa+epochs_lbfgs_sa+5,5) | ||
|
||
plt.figure(dpi=dpi, figsize=figsize) | ||
|
||
plt.vlines(x_adam[-1], lbfgs_loss[0], adam_loss[-1], linewidth=3, colors='r') | ||
|
||
plt.plot(x_adam, adam_loss, c="k", linewidth=3, label="ADAM") | ||
|
||
plt.plot(x_lbfgs, lbfgs_loss, linewidth=3, c='r', label="L-BFGS") | ||
|
||
plt.xlabel("Epoch", fontsize=22.5) | ||
plt.ylabel("SA-PINN Loss", fontsize=22.5) | ||
plt.grid(True) | ||
plt.xlim(0,epochs_adam_sa+epochs_lbfgs_sa) | ||
plt.yscale('log') | ||
|
||
if title is not None: | ||
plt.title(title) | ||
|
||
plt.show() |
Oops, something went wrong.