import numpy.typing as npt
import numpy as np
import matplotlib.pyplot as plt
def get_mae(data: npt.ArrayLike,
            ref: npt.ArrayLike
            ) -> float:
    """Return the mean absolute error between ``data`` and ``ref``.

    Parameters
    ----------
    data, ref
        Array-likes (NumPy arrays, lists, scalars, ...) that broadcast
        against each other.

    Returns
    -------
    float
        ``mean(|data - ref|)`` over all elements, as a builtin ``float``.
    """
    # Coerce to float64 before subtracting: this accepts plain sequences and
    # keeps unsigned-integer inputs from wrapping around on negative
    # differences.
    data = np.asarray(data, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    return float(np.mean(np.abs(data - ref)))
def get_rmse(data: npt.ArrayLike,
             ref: npt.ArrayLike
             ) -> float:
    """Return the root-mean-square error between ``data`` and ``ref``.

    Parameters
    ----------
    data, ref
        Array-likes (NumPy arrays, lists, scalars, ...) that broadcast
        against each other.

    Returns
    -------
    float
        ``sqrt(mean((data - ref)**2))`` over all elements, as a builtin
        ``float``.
    """
    # Coerce to float64 first: this accepts plain sequences and prevents
    # integer wrap-around in both the subtraction and the squaring.
    data = np.asarray(data, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    return float(np.sqrt(np.mean((data - ref) ** 2)))
def get_dptest_data(e_file, f_file):
    """Load energy and force comparison tables and split forces per frame.

    Parameters
    ----------
    e_file
        Anything ``numpy.loadtxt`` accepts. Must contain two columns
        (``data``, ``pred``) with one row per frame.
        NOTE(review): presumably the ``*.e.out`` detail file written by
        DeePMD-kit's ``dp test`` -- confirm against the caller.
    f_file
        Anything ``numpy.loadtxt`` accepts. Must contain six columns
        (``fx, fy, fz`` data followed by ``fx, fy, fz`` pred). The number of
        rows must be a multiple of the number of frames in ``e_file``.

    Returns
    -------
    tuple
        ``(e_data, e_pred, fx_data, fy_data, fz_data, fx_pred, fy_pred,
        fz_pred, natom)``. Both energy arrays are 1-D and shifted by the mean
        of the ``data`` energies; each force array has shape
        ``(nframe, natom)``.

    Raises
    ------
    ValueError
        If a file has the wrong number of columns, or the force rows cannot
        be divided evenly among the frames.
    """
    # ndmin=2 keeps a single-row file as shape (1, ncol); with unpack=True a
    # one-frame file would produce scalars and break len()/reshape below.
    energies = np.loadtxt(e_file, ndmin=2)
    if energies.shape[1] != 2:
        raise ValueError(
            f"energy file must have 2 columns, got {energies.shape[1]}"
        )
    e_mean = np.mean(energies[:, 0])
    # Both columns are shifted by the mean of the *data* energies so that the
    # difference between data and prediction is preserved.
    e_data = energies[:, 0] - e_mean
    e_pred = energies[:, 1] - e_mean
    nframe = len(e_data)

    forces = np.loadtxt(f_file, ndmin=2)
    if forces.shape[1] != 6:
        raise ValueError(
            f"force file must have 6 columns, got {forces.shape[1]}"
        )
    if nframe == 0 or forces.shape[0] % nframe != 0:
        raise ValueError(
            f"cannot split {forces.shape[0]} force rows evenly into "
            f"{nframe} frame(s)"
        )

    # One (nframe, natom) array per column, in file column order.
    fx_data, fy_data, fz_data, fx_pred, fy_pred, fz_pred = (
        forces[:, col].reshape((nframe, -1)) for col in range(6)
    )
    natom = fx_data.shape[1]
    return e_data, e_pred, fx_data, fy_data, fz_data, fx_pred, fy_pred, fz_pred, natom
def plot_error(ax, data, pred, type='energy', title='Energy', return_err=False):
    """Draw a data-vs-prediction scatter plot on ``ax`` and annotate its errors.

    Parameters
    ----------
    ax
        Axes-like object to draw on (presumably a ``matplotlib.axes.Axes``).
    data, pred
        Values plotted on the x and y axes respectively; also the inputs to
        the RMSE/MAE computation.
    type : {'energy', 'force'}
        Selects the axis labels and the unit shown next to the errors.
    title : str
        Title of the panel.
    return_err : bool
        When true, return the errors instead of ``None``.

    Returns
    -------
    tuple or None
        ``(rmse, mae)``, each multiplied by 1000, if ``return_err`` is true;
        otherwise ``None``.

    Raises
    ------
    ValueError
        If ``type`` is neither ``'energy'`` nor ``'force'``.
    """
    # Resolve labels before touching the axes so an unsupported ``type`` fails
    # immediately with a clear message instead of leaving ``unit`` unbound
    # after the plot has already been drawn.
    if type == 'force':
        axis_label = r'Forces (eV/$\mathrm{\AA}$)'
        unit = r'meV/$\mathrm{\AA}$'
    elif type == 'energy':
        axis_label = 'Energies (eV/atom)'
        unit = 'meV/atom'
    else:
        raise ValueError(
            f"unsupported type {type!r}; expected 'energy' or 'force'"
        )

    ax.scatter(data, pred, s=0.5)
    # Diagonal y = x reference line spanning the x-range chosen for the scatter.
    xlim = ax.get_xlim()
    ax.plot([xlim[0], xlim[1]], [xlim[0], xlim[1]], color='red')

    # The axes are labelled in eV-based units while the annotations use
    # meV-based units, hence the factor of 1000.
    rmse = get_rmse(data, pred) * 1000
    mae = get_mae(data, pred) * 1000

    ax.set_title(title)
    ax.set_xlabel(axis_label)
    ax.set_ylabel(axis_label)
    ax.text(0.1, 0.8, f"RMSE: {rmse:.3f} {unit} ", transform=ax.transAxes)
    ax.text(0.1, 0.9, f"MAE: {mae:.3f} {unit} ", transform=ax.transAxes)
    if return_err:
        return rmse, mae
def plot_dptest(e_file: str, f_file: str, save_name: str="dptest.png", return_err=False):
    """Plot a 2x2 figure of energy and force panels and save it to disk.

    The data are read with ``get_dptest_data``; the first panel shows
    energies divided by the atom count, and the remaining three show the x, y
    and z force components. Each panel is drawn by ``plot_error``.

    Parameters
    ----------
    e_file, f_file
        Energy and force files forwarded to ``get_dptest_data``.
    save_name
        Output path passed to ``Figure.savefig``.
    return_err
        When true, return the error values collected from every panel.

    Returns
    -------
    tuple or None
        ``(rmse_e, mae_e, rmse_fx, mae_fx, rmse_fy, mae_fy, rmse_fz, mae_fz)``
        if ``return_err`` is true; otherwise ``None``.
    """
    # NOTE(review): this changes the process-wide matplotlib style and is not
    # undone when the function returns.
    plt.style.use("cp2kdata.matplotlibstyle.jcp")

    n_rows, n_cols = 2, 2
    fig = plt.figure(
        figsize=(3.37*n_cols, 2.6*n_rows), dpi=200, facecolor='white'
    )
    grid = fig.add_gridspec(n_rows, n_cols, hspace=0.5)

    (e_data, e_pred,
     fx_data, fy_data, fz_data,
     fx_pred, fy_pred, fz_pred,
     natom) = get_dptest_data(e_file=e_file, f_file=f_file)

    # Panels in grid order: (type, title, x-values, y-values).
    panels = (
        ('energy', 'Energy', e_data/natom, e_pred/natom),
        ('force', r'$\mathrm{F_{x}}$', fx_data, fx_pred),
        ('force', r'$\mathrm{F_{y}}$', fy_data, fy_pred),
        ('force', r'$\mathrm{F_{z}}$', fz_data, fz_pred),
    )

    errors = []
    for cell, (kind, panel_title, ref_values, pred_values) in enumerate(panels):
        panel_ax = fig.add_subplot(grid[cell])
        # Each call yields (rmse, mae); flatten them in panel order.
        errors.extend(
            plot_error(panel_ax, ref_values, pred_values,
                       type=kind, title=panel_title, return_err=True)
        )

    # NOTE(review): the figure is left open after saving.
    fig.savefig(save_name, dpi=400)
    if return_err:
        return tuple(errors)