from typing import Tuple, Union, List
import numpy.typing as npt
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cp2kdata.plots.colormaps
def get_mae(data: npt.NDArray[np.float64],
            ref: npt.NDArray[np.float64]
            ) -> float:
    """Return the mean absolute error between ``data`` and ``ref``.

    Args:
        data: Values to evaluate.
        ref: Reference values, subtracted element-wise from ``data``.

    Returns:
        The mean of ``|data - ref|`` taken over all elements.
    """
    abs_err = np.abs(data - ref)
    return np.mean(abs_err)
def get_rmse(data: npt.NDArray[np.float64],
             ref: npt.NDArray[np.float64]
             ) -> float:
    """Return the root-mean-square error between ``data`` and ``ref``.

    Args:
        data: Values to evaluate.
        ref: Reference values, subtracted element-wise from ``data``.

    Returns:
        The square root of the mean of ``(data - ref)**2`` over all elements.
    """
    residual = data - ref
    mean_sq = np.mean(np.square(residual))
    return np.sqrt(mean_sq)
def get_dptest_data(e_file: str,
                    f_file: str
                    ) -> Tuple[npt.NDArray[np.float64],
                               npt.NDArray[np.float64],
                               npt.NDArray[np.float64],
                               npt.NDArray[np.float64],
                               npt.NDArray[np.float64],
                               npt.NDArray[np.float64],
                               npt.NDArray[np.float64],
                               npt.NDArray[np.float64],
                               int
                               ]:
    """Load energies and force components from a pair of text files.

    NOTE(review): the layout matches what DeePMD-kit's ``dp test`` writes
    (``*.e.out`` / ``*.f.out``) -- confirm against the caller.

    Args:
        e_file: Text file with two columns (data, prediction) and one row
            per frame.
        f_file: Text file with six columns (data x, y, z followed by
            predicted x, y, z). Its row count must be a multiple of the
            number of rows in ``e_file``; otherwise the reshape raises
            ``ValueError``.

    Returns:
        ``(e_data, e_pred, fx_data, fy_data, fz_data, fx_pred, fy_pred,
        fz_pred, natom)``. Both energy arrays have the mean of ``e_data``
        subtracted. Every force array has shape ``(nframe, natom)``, where
        ``natom`` is the number of force rows per frame.
    """
    # ndmin=2 keeps the (column, row) layout even for a one-row file; without
    # it a single frame unpacks to scalars and ``len(e_data)`` fails.
    e_data, e_pred = np.loadtxt(e_file, unpack=True, ndmin=2)

    # Shift both sets by the same offset so only relative energies remain.
    e_mean = np.mean(e_data)
    e_data = e_data - e_mean
    e_pred = e_pred - e_mean
    nframe = len(e_data)

    force_columns = np.loadtxt(f_file, unpack=True, ndmin=2)
    # One row per frame; the per-frame length (natom) is inferred.
    fx_data, fy_data, fz_data, fx_pred, fy_pred, fz_pred = (
        column.reshape((nframe, -1)) for column in force_columns
    )
    natom = fx_data.shape[1]

    return e_data, e_pred, fx_data, fy_data, fz_data, fx_pred, fy_pred, fz_pred, natom
def plot_error(ax,
               data: npt.NDArray,
               pred: npt.NDArray,
               type: str='energy',
               title: str='Energy',
               return_err: bool=False
               ) -> Union[None, Tuple[float, float]]:
    """Draw a data-vs-prediction scatter plot annotated with RMSE and MAE.

    Args:
        ax: Axes-like object to draw on; ``scatter``, ``plot``, ``get_xlim``,
            ``set_title``, ``set_xlabel``, ``set_ylabel``, ``text`` and
            ``transAxes`` are used.
        data: Values plotted on the x axis.
        pred: Values plotted on the y axis.
        type: ``'energy'`` or ``'force'``; selects the axis labels and the
            unit shown next to the error values.
        title: Title of the axes.
        return_err: If True, return the errors instead of ``None``.

    Returns:
        ``(rmse, mae)`` when ``return_err`` is True, otherwise ``None``. Both
        values are the raw errors multiplied by 1000, matching the milli-unit
        shown in the annotation.

    Raises:
        ValueError: If ``type`` is neither ``'energy'`` nor ``'force'``.
    """
    # Resolve labels before drawing so an unsupported ``type`` fails cleanly
    # instead of hitting an unbound ``unit`` after the axes were modified.
    if type == 'force':
        axis_label = r'Forces (eV/$\mathrm{\AA}$)'
        unit = r'meV/$\mathrm{\AA}$'
    elif type == 'energy':
        axis_label = 'Energies (eV/atom)'
        unit = 'meV/atom'
    else:
        raise ValueError(
            f"Unsupported type {type!r}; expected 'energy' or 'force'."
        )

    ax.scatter(data, pred, s=0.5)

    # Diagonal y = x across the x-range chosen by the scatter plot.
    xlim = ax.get_xlim()
    ax.plot([xlim[0], xlim[1]], [xlim[0], xlim[1]], color='red')

    # Scale by 1000 to report the errors in the milli-unit named by ``unit``.
    rmse = get_rmse(data, pred) * 1000
    mae = get_mae(data, pred) * 1000

    ax.set_title(title)
    ax.set_xlabel(axis_label)
    ax.set_ylabel(axis_label)
    ax.text(0.1, 0.8, f"RMSE: {rmse:.3f} {unit} ", transform=ax.transAxes)
    ax.text(0.1, 0.9, f"MAE: {mae:.3f} {unit} ", transform=ax.transAxes)

    if return_err:
        return rmse, mae
def plot_dptest(e_file: str,
                f_file: str,
                save_name: str="dptest.png",
                return_err: bool=False,
                frc_comp: bool=False
                ):
    """Plot energy and force parity panels from a pair of test-output files.

    The data is read with ``get_dptest_data`` and each panel is drawn with
    ``plot_error``. Energies are divided by the atom count before plotting.

    Args:
        e_file: Energy file passed to ``get_dptest_data``.
        f_file: Force file passed to ``get_dptest_data``.
        save_name: Output image path; a falsy value skips saving.
        return_err: If True, return the error values alongside the figure.
        frc_comp: If True, draw a 2x2 grid with one force panel per Cartesian
            component; otherwise draw a 1x2 grid with a single force panel
            holding all three components.

    Returns:
        The figure alone when ``return_err`` is False. Otherwise a tuple
        starting with the figure followed by ``rmse, mae`` pairs: energy
        first, then either the x, y, z force components (``frc_comp`` True)
        or the combined force panel.
    """
    # Applies the style globally for the rest of the session.
    plt.style.use("cp2kdata.matplotlibstyle.jcp")

    n_rows = 2 if frc_comp else 1
    n_cols = 2
    fig = plt.figure(figsize=(3.37*n_cols, 2.6*n_rows), dpi=200, facecolor='white')
    grid = fig.add_gridspec(n_rows, n_cols, hspace=0.5)

    (e_data, e_pred,
     fx_data, fy_data, fz_data,
     fx_pred, fy_pred, fz_pred,
     natom) = get_dptest_data(e_file=e_file, f_file=f_file)

    # Each entry is (data, prediction, panel title) for one force panel.
    if frc_comp:
        force_panels = [
            (fx_data, fx_pred, r'$\mathrm{F_{x}}$'),
            (fy_data, fy_pred, r'$\mathrm{F_{y}}$'),
            (fz_data, fz_pred, r'$\mathrm{F_{z}}$'),
        ]
    else:
        force_panels = [
            (np.array([fx_data, fy_data, fz_data]),
             np.array([fx_pred, fy_pred, fz_pred]),
             r'Force'),
        ]

    # Flat list of rmse, mae values in panel order (energy first).
    errors = []
    energy_ax = fig.add_subplot(grid[0])
    errors.extend(plot_error(energy_ax, e_data/natom, e_pred/natom,
                             type='energy', title='Energy', return_err=True))
    for slot, (f_data, f_pred, panel_title) in enumerate(force_panels, start=1):
        force_ax = fig.add_subplot(grid[slot])
        errors.extend(plot_error(force_ax, f_data, f_pred,
                                 type='force', title=panel_title, return_err=True))

    if save_name:
        fig.savefig(save_name, dpi=300)

    if not return_err:
        return fig
    return (fig, *errors)
## model deviation
def _hist_and_stat_model_devi(model_devi_files: List[List[str]],
trust_lo: float,
trust_hi: float,
**kwargs
) -> Tuple[List[Tuple[npt.NDArray, npt.NDArray]],
List[List[float]]
]:
hist_model_devi = []
stat_model_devi = []
for _iter_files in model_devi_files:
_model_devi = None
for _model_devi_file in _iter_files:
if _model_devi is None:
_model_devi = np.loadtxt(_model_devi_file, usecols=4)
else:
_model_devi = np.concatenate((_model_devi, np.loadtxt(_model_devi_file, usecols=4)))
ntotal = len(_model_devi)
naccurate = np.sum((_model_devi < trust_lo))
ncandidate = np.sum((_model_devi > trust_lo) &(_model_devi < trust_hi))
npoor = np.sum((_model_devi > trust_hi))
ratio_accurate = naccurate / ntotal
ratio_candidate = ncandidate / ntotal
ratio_poor = npoor / ntotal
stat_model_devi.append([ratio_accurate, ratio_candidate, ratio_poor])
hist, bin_edges = np.histogram(_model_devi, **kwargs)
hist_model_devi.append([hist, bin_edges])
return hist_model_devi, stat_model_devi
def plot_model_devi(model_devi_files: List[List[str]],
                    trust_lo: float,
                    trust_hi: float,
                    save_name: str="model_devi.png",
                    **kwargs
                    ):
    """Plot one model-deviation histogram per iteration, stacked vertically.

    Each panel shows the histogram, dashed vertical lines at both trust
    levels, and the accurate/candidate/poor percentages computed by
    ``_hist_and_stat_model_devi``.

    Args:
        model_devi_files: One list of file paths per iteration.
        trust_lo: Lower trust level.
        trust_hi: Upper trust level.
        save_name: Output image path; a falsy value skips saving.
        **kwargs: Forwarded to ``np.histogram`` through
            ``_hist_and_stat_model_devi``.

    Returns:
        The created figure.

    Raises:
        ValueError: If ``model_devi_files`` is empty.
    """
    hist_model_devi, stat_model_devi = _hist_and_stat_model_devi(
        model_devi_files, trust_lo, trust_hi, **kwargs
    )
    if not hist_model_devi:
        # Without any panel there is no axes left to label further down.
        raise ValueError("model_devi_files is empty; nothing to plot.")

    # Shared upper y-limit so that all panels are directly comparable.
    max_hist = max(0, *(_hist.max() for _hist, _ in hist_model_devi))

    # Both calls below change global matplotlib state.
    plt.style.use('cp2kdata.matplotlibstyle.jcp')
    niters = len(hist_model_devi)
    cp2kdata_cb_lscmap = mpl.colormaps['cp2kdata_cb_lscmap']
    # One colour per iteration, sampled evenly from the colormap.
    plt.rcParams["axes.prop_cycle"] = plt.cycler(
        "color", cp2kdata_cb_lscmap(np.linspace(0, 1, niters))
    )

    row = niters
    col = 1
    fig = plt.figure(figsize=(3.37*col, 1.00*row), dpi=300, facecolor="white")
    gs = fig.add_gridspec(row, col, hspace=0.0)

    # Label and vertical text position (axes coordinates) for each ratio.
    stat_rows = (("Accurate:", 0.3), ("Candidate:", 0.2), ("Poor:", 0.1))

    for i_iter, (_hist, _bin_edges) in enumerate(hist_model_devi):
        ax = fig.add_subplot(gs[i_iter])
        ax.fill(_bin_edges[:-1], _hist, label=f"iter-{i_iter:03d}", alpha=0.5, color=f"C{i_iter}")
        ax.axvline(trust_lo, color="black", linestyle="--", label="trust_lo")
        ax.axvline(trust_hi, color="black", linestyle="--", label="trust_hi")

        for (name, y_pos), ratio in zip(stat_rows, stat_model_devi[i_iter]):
            # Format as text first so the number can be zero-padded to width 5.
            percent = f"{ratio * 100:4.2f}".zfill(5)
            ax.text(1.35, y_pos, f"{name:>10} {percent}%", transform=ax.transAxes, ha="right")

        ax.set_ylim(-0.05, max_hist + 0.05)
        ax.legend(bbox_to_anchor=(1.37, 1.00))

    # The label goes on the last (bottom) axes: fig.supxlabel would be
    # preferable, but its y position scales with the figure size.
    ax.set_xlabel("Model deviation " + r"(eV/$\mathrm{\AA}$)")
    fig.supylabel("Density")

    if save_name:
        fig.savefig(save_name, bbox_inches="tight")
    return fig