Skip to content

Commit

Permalink
Add plot 3d potential field
Browse files Browse the repository at this point in the history
Signed-off-by: An Thai Le <an.thai.le97@gmail.com>
  • Loading branch information
anindex committed Aug 23, 2021
1 parent 2ed96cb commit 097f293
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
8 changes: 4 additions & 4 deletions scripts/test_tprmp_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
# parameters
oversteps = 50
dt = 0.01
NUM_COMP = 20
NUM_COMP = 10
alpha, beta = 0., 0.
stiff_scale = 1.
tau = 1.
potential_method = 'quadratic'
optimize_method = 'flow'
optimize_method = 'riemannian'
d_min = 0.
d_scale = 1.
energy = 0.
Expand All @@ -52,14 +52,14 @@
demo.add_frame_from_pose(traj[:, 0], 'start')
demo.add_frame_from_pose(traj[:, -1], 'end')
demos.append(demo)
plot_demo(demos, only_global=False, plot_quat=False, new_fig=True, new_ax=True, three_d=False, show=True)
# plot_demo(demos, only_global=False, plot_quat=False, new_fig=True, new_ax=True, three_d=False, show=True)
# train tprmp
sample = demos[0]
frames = sample.get_task_parameters()
model = TPRMP(num_comp=NUM_COMP, name=args.task, sigma=sigma, stiff_scale=stiff_scale, tau=tau, potential_method=potential_method, d_scale=d_scale)
model.train(demos, optimize_method=optimize_method, alpha=alpha, beta=beta, d_min=d_min, energy=energy, var_scale=var_scale, verbose=verbose)
model.model.plot_model(demos, tagging=False, three_d=False)
plot_potential_field(model, frames, only_global=False, margin=0.5, res=res, new_fig=True, show=True)
plot_potential_field(model, frames, only_global=False, margin=0.5, three_d=True, res=res, new_fig=True, show=True)
# execution
x0, dx0 = np.array([0.5, 2.5]), np.zeros(2)
visualize_rmp(model, frames, x0, dx0, sample.traj.shape[1] + oversteps, dt, sample=sample, x_limits=[0., 4.], vel_limits=[-10., 10.])
Expand Down
43 changes: 30 additions & 13 deletions tprmp/visualization/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def visualize_rmp(tprmp, frames, x0, dx0, T, dt, sample=None, x_limits=[0., 5.],

def plot_potential_field(tprmp, frames, **kwargs):
only_global = kwargs.get('only_global', True)
plot_gaussian = kwargs.get('plot_gaussian', 0.1)
plot_gaussian = kwargs.get('plot_gaussian', True)
three_d = kwargs.get('three_d', False)
margin = kwargs.get('margin', 0.5)
res = kwargs.get('res', 0.1)
new_fig = kwargs.get('new_fig', False)
Expand All @@ -60,37 +61,50 @@ def plot_potential_field(tprmp, frames, **kwargs):
mid_y = (frame_origins[:, 1].max() + frame_origins[:, 1].min()) * 0.5
if new_fig:
plt.figure()
_plot_potential_field_global(tprmp, frames, [mid_x, mid_y], ranges, plot_gaussian=plot_gaussian, res=res)
_plot_potential_field_global(tprmp, frames, [mid_x, mid_y], ranges, plot_gaussian=plot_gaussian, three_d=three_d, res=res)
if not only_global:
_plot_potential_field_frames(tprmp, frames, ranges, res=res)
if three_d:
plt.figure()
_plot_potential_field_frames(tprmp, frames, ranges, plot_gaussian=plot_gaussian, three_d=three_d, res=res)
if show:
plt.show()


def _plot_potential_field_global(tprmp, frames, mid, ranges, plot_gaussian=True, res=0.1, alpha=0.5):
ax = plt.subplot(111)
def _plot_potential_field_global(tprmp, frames, mid, ranges, plot_gaussian=True, three_d=False, res=0.1, alpha=0.7):
if three_d:
ax = plt.subplot(111, projection='3d')
else:
ax = plt.subplot(111)
x = np.arange(mid[0] - ranges, mid[0] + ranges, res)
y = np.arange(mid[1] - ranges, mid[1] + ranges, res)
X, Y = np.meshgrid(x, y)
Z = np.zeros_like(X)
for i in range(X.shape[0]):
for j in range(X.shape[1]):
Z[i, j] = tprmp.compute_potential_field(np.array([X[i, j], Y[i, j]]), frames)
c = ax.pcolormesh(X, Y, Z, cmap='RdBu', shading='auto', vmin=0., vmax=Z.max(), alpha=alpha)
if three_d:
c = ax.plot_surface(X, Y, Z, cmap='RdBu', vmin=0., vmax=Z.max(), alpha=alpha)
else:
c = ax.pcolormesh(X, Y, Z, cmap='RdBu', shading='auto', vmin=0., vmax=Z.max(), alpha=alpha)
ax.axes.set_aspect('equal')
ax.set_title('Global potential Phi')
plt.gcf().colorbar(c, ax=ax)
plot_frame_2d(frames.values())
ax.axes.set_aspect('equal')
if plot_gaussian:
_plot_gmm_global(tprmp.model, frames, three_d=False, new_ax=False)


def _plot_potential_field_frames(tprmp, frames, ranges, axs=None, plot_gaussian=True, res=0.1, alpha=0.5):
def _plot_potential_field_frames(tprmp, frames, ranges, axs=None, plot_gaussian=True, three_d=False, res=0.1, alpha=0.7):
if axs is None:
axs = {}
_, axes = plt.subplots(1, len(frames), figsize=(14, 6))
for i, frame in enumerate(frames):
axs[frame] = axes[i]
if three_d:
plt.clf()
for i, frame in enumerate(frames):
axs[frame] = plt.subplot(1, len(frames), i + 1, projection="3d")
else:
_, axes = plt.subplots(1, len(frames), figsize=(14, 6))
for i, frame in enumerate(frames):
axs[frame] = axes[i]
x = y = np.arange(-ranges * 2, ranges * 2, res)
X, Y = np.meshgrid(x, y)
Z = {}
Expand All @@ -102,11 +116,14 @@ def _plot_potential_field_frames(tprmp, frames, ranges, axs=None, plot_gaussian=
Z[f_key][i, j] = tprmp.compute_potential_field_frame(np.array([X[i, j], Y[i, j]]), f_key)
z_max = max(z_max, Z[f_key].max())
for f_key in frames:
c = axs[f_key].pcolormesh(X, Y, Z[f_key], cmap='RdBu', shading='auto', vmin=0., vmax=z_max, alpha=alpha)
if three_d:
c = axs[f_key].plot_surface(X, Y, Z[f_key], cmap='RdBu', vmin=0., vmax=z_max, alpha=alpha)
else:
c = axs[f_key].pcolormesh(X, Y, Z[f_key], cmap='RdBu', shading='auto', vmin=0., vmax=z_max, alpha=alpha)
axs[f_key].axes.set_aspect('equal')
axs[f_key].set_title(f'Frame {f_key}')
plt.gcf().colorbar(c, ax=axs[f_key])
axs[f_key].plot([0, 1 / 2], [0, 0], 'r')
axs[f_key].plot([0, 0], [0, 1 / 2], 'b')
axs[f_key].axes.set_aspect('equal')
if plot_gaussian:
_plot_gmm_frames(tprmp.model, frames, axs=axs, three_d=False)
8 changes: 6 additions & 2 deletions tprmp/visualization/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import mpl_toolkits.mplot3d.art3d as art3d
import numpy as np
import logging

Expand Down Expand Up @@ -88,8 +89,11 @@ def _plot_gaussian(mvn, color='b', three_d=True):
theta = float(np.degrees(np.arctan2(y, x)))
width, height = 2 * np.sqrt(w)
ellipse = Ellipse(mu[0:2], width, height, angle=theta, color=color, fill=True)
plt.gca().add_patch(ellipse)
plt.scatter([mu[0]], [mu[1]], marker='o', color=color, alpha=0.5)
ax = plt.gca()
ax.add_patch(ellipse)
if ax.name == '3d':
art3d.pathpatch_2d_to_3d(ellipse, z=0, zdir='z')
plt.plot([mu[0]], [mu[1]], marker='o', color=color)


def plot_hsmm(model, end_states=True, legend=True, duration=True, new_fig=False, show=False): # TODO: check plotting locations
Expand Down

0 comments on commit 097f293

Please sign in to comment.