Source code for ectoolkits.plots.band_align

import matplotlib.pyplot as plt


def vac_scale(she_scale):
    """Convert an SHE-scale potential (V) to a vacuum-scale energy (eV).

    Applies ``E_vac = -U_SHE - 4.44`` using the 4.44 offset hard-coded
    here. Only unary minus and subtraction are used, so array-like inputs
    are converted element-wise. This is the inverse of ``she_scale``.
    """
    offset = 4.44
    return -she_scale - offset
def she_scale(vac_scale):
    """Convert a vacuum-scale energy (eV) to an SHE-scale potential (V).

    Applies ``U_SHE = -E_vac - 4.44`` using the 4.44 offset hard-coded
    here. Only unary minus and subtraction are used, so array-like inputs
    are converted element-wise. This is the inverse of ``vac_scale``.
    """
    offset = 4.44
    return -vac_scale - offset
def be_data_to_she(ba_dict):
    """Return a copy of *ba_dict* with band edges converted to the SHE scale.

    Each value of *ba_dict* must provide ``"vbm"`` and ``"cbm"`` entries on
    the vacuum scale. Both are passed through ``she_scale``. The result
    keeps the outer keys, but each inner dict contains only ``"vbm"`` and
    ``"cbm"``; any other per-system entries are dropped. The input mapping
    is not modified.
    """
    return {
        name: {
            "vbm": she_scale(edges["vbm"]),
            "cbm": she_scale(edges["cbm"]),
        }
        for name, edges in ba_dict.items()
    }
def plot_one_ba_sc(ax, idx, vbm, cbm, width):
    """Draw one system's band edges as two thick horizontal bars.

    Both bars span ``width`` in x, centred on ``idx``. The VBM bar uses
    colour ``'C0'`` and the CBM bar uses ``'C3'``; both have line width 7.
    """
    half_width = width/2
    left, right = idx - half_width, idx + half_width
    for level, colour in ((vbm, 'C0'), (cbm, 'C3')):
        ax.hlines(level, left, right, color=colour, linewidth=7)
def plot_multiple_ba(ax, ba_dict, width=0.3):
    """Draw the VBM/CBM bars of every system in *ba_dict*.

    Systems are placed at x = 0, 1, 2, ... in the iteration order of
    *ba_dict*, and each one is drawn with ``plot_one_ba_sc``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    ba_dict : dict
        Maps system name -> mapping with ``"vbm"`` and ``"cbm"`` entries.
    width : float, optional
        Horizontal extent of each bar in x-axis units. Defaults to 0.3,
        the value that was previously hard-coded.
    """
    # Only the per-system data is needed; the names are not used here.
    for idx, be_data in enumerate(ba_dict.values()):
        plot_one_ba_sc(
            ax, idx, vbm=be_data["vbm"], cbm=be_data["cbm"], width=width)
def plot_one_ba_diff(ax, idx, vbm, cbm, ref_vbm, ref_cbm):
    """Annotate how far one system's band edges sit from reference edges.

    For each edge (VBM in ``'C0'``, CBM in ``'C3'``) a double-headed
    arrow is drawn at x = ``idx`` between the system's value and the
    reference value. The absolute difference is then written next to it
    with two decimals. All arrows are added before all labels.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    idx : int or float
        x position of the system.
    vbm, cbm : float
        Band-edge values of this system.
    ref_vbm, ref_cbm : float
        Band-edge values of the reference system.
    """
    edges = ((vbm, ref_vbm, 'C0'), (cbm, ref_cbm, 'C3'))

    for value, ref, colour in edges:
        ax.annotate('', xy=(idx, ref), xytext=(idx, value),
                    arrowprops=dict(arrowstyle='<->', color=colour))

    for value, ref, _colour in edges:
        centre = (ref + value)/2
        gap = abs(ref - value)
        shifted = centre+0.1
        # NOTE(review): this compares |shifted| with the raw reference
        # value. When ref <= 0 the test can never pass, so the label is
        # always put at the reference level. Confirm this is intended; it
        # may have been meant to check that the label lies inside the arrow.
        if abs(shifted) < ref:
            label_y = shifted
        else:
            label_y = ref
        ax.text(idx, label_y, f"{gap:1.2f}", fontsize='medium')
def plot_multiple_ba_diff(ax, ba_dict, key_list, ref_key):
    """Draw band-edge offsets from a reference system for selected systems.

    The reference edges are read from ``ba_dict[ref_key]``. Every system
    whose name appears in *key_list* is annotated with
    ``plot_one_ba_diff`` at its x position. That position is its index in
    the iteration order of *ba_dict*, which matches the placement used when
    plotting the bars.
    """
    reference = ba_dict[ref_key]
    ref_vbm, ref_cbm = reference['vbm'], reference['cbm']

    for position, (name, edges) in enumerate(ba_dict.items()):
        if name not in key_list:
            continue
        plot_one_ba_diff(
            ax, position, edges["vbm"], edges["cbm"],
            ref_vbm=ref_vbm, ref_cbm=ref_cbm)
def plot_band_alignment(ba_dict, show_diff=False, vac_value=False):
    """Plot the band alignment of several systems and return the figure.

    Each system is drawn as a pair of short horizontal bars (VBM in
    ``'C0'``, CBM in ``'C3'``) at its own x position. Dashed horizontal
    lines mark the band edges of the first system in *ba_dict*, which acts
    as the reference. The left y axis shows the SHE potential. A
    secondary right y axis shows the corresponding vacuum-scale energy.

    Note: the ``plt.rc`` calls below change matplotlib's global rcParams
    (font sizes, line widths, tick sizes). The change persists after this
    function returns.

    Parameters
    ----------
    ba_dict : dict
        Maps system name -> mapping with ``"vbm"`` and ``"cbm"`` entries.
        Iteration order sets the x positions, and the first entry is the
        reference. Must be non-empty.
    show_diff : bool, optional
        If True, annotate every non-reference system with the offsets of
        its VBM/CBM from the reference.
    vac_value : bool, optional
        If True, the input energies are on the vacuum scale and are
        converted to SHE with ``be_data_to_she`` before plotting.
        Otherwise they are plotted as-is on the SHE axis.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt.rc('font', size=18)  # controls default text size
    plt.rc('axes', titlesize=23)  # fontsize of the title
    plt.rc('axes', labelsize=20)  # fontsize of the x and y labels
    plt.rc('xtick', labelsize=18)  # fontsize of the x tick labels
    plt.rc('ytick', labelsize=18)  # fontsize of the y tick labels
    plt.rc('legend', fontsize=16)  # fontsize of the legend
    plt.rc('lines', linewidth=2, markersize=10)  # default line width / marker size
    plt.rc('axes', linewidth=2)  # width of the axes frame
    plt.rc('xtick.major', size=10, width=2)
    plt.rc('ytick.major', size=10, width=2)

    numb_sys = len(ba_dict)
    if vac_value:
        # The primary axis is SHE, so vacuum-scale inputs are converted first.
        ba_dict = be_data_to_she(ba_dict)

    fig = plt.figure(figsize=(16, 9), dpi=200)
    ax = fig.add_subplot(111)
    plot_multiple_ba(ax, ba_dict=ba_dict)

    # The first system is the reference: dashed lines at its band edges.
    name_list = list(ba_dict)
    first_key = name_list[0]
    ax.axhline(ba_dict[first_key]["vbm"], linestyle="--", color='C0')
    ax.axhline(ba_dict[first_key]["cbm"], linestyle="--", color='C3')

    ax.set_xticks(range(numb_sys))
    ax.set_xticklabels(name_list)
    ax.tick_params(direction='in')
    # A more positive SHE potential is a lower vacuum energy. Inverting the
    # axis makes energy increase upwards.
    ax.invert_yaxis()
    # Right-hand axis on the vacuum scale: vac_scale maps SHE -> vacuum and
    # she_scale is its inverse.
    y2 = ax.secondary_yaxis('right', functions=(vac_scale, she_scale))
    y2.set_ylabel("E(Vac.) [eV]")
    y2.tick_params(direction='in')
    ax.set_ylabel("U(SHE) [V]")

    if show_diff:
        plot_multiple_ba_diff(
            ax, ba_dict, key_list=name_list[1:], ref_key=first_key)
    return fig
def enumerate_subplots(fig):
    """Tag every axes of *fig* with a panel letter: "(a)", "(b)", ...

    Letters follow the order of ``fig.get_axes()``. Each label is an
    unframed ``AnchoredText`` whose lower-left corner is pinned to the
    axes' top-left corner, so it sits just above the plot. Only 26 letters
    are available; a figure with more axes raises ``IndexError``.
    """
    from matplotlib.offsetbox import AnchoredText
    import string

    letters = string.ascii_lowercase
    for position, axes in enumerate(fig.get_axes()):
        label = AnchoredText(
            f"({letters[position]})",
            loc='lower left',
            prop=dict(size=25),
            frameon=False,
            bbox_to_anchor=(0., 1.),
            bbox_transform=axes.transAxes,
        )
        axes.add_artist(label)