"""
mppnp plotting tool v.0.0.0, 20JULY2025: NuGrid collaboration (Joshua Issa)

multizone_plot.py provides tools to plot the results of mppnp calculations and
is designed to work with multizone.py mppnp_reader objects.
"""

import re
from collections import defaultdict

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
from matplotlib import colors as mcolors
from matplotlib.ticker import MaxNLocator
from IPython.display import display
import ipywidgets as widgets
import matplotlib.patches as patches

from nugridpy import utils


#### helper functions

# Captures the two particles of a reaction label such as "(n,g)".
_REACTION_PARTICLES = re.compile(r'\((\w+),(\w+)\)')

# Single-character particle codes used in reaction labels -> mathtext symbols.
_RTYPE_SYMBOLS = str.maketrans({
    "a": r"$\alpha$",
    "g": r"$\gamma$",
    'p': r'$p$',
    'n': r'$n$',
    '+': r'$\beta^{+}$',
    '-': r'$\beta^{-}$',
})

# Short labels for light particles that appear inside reaction labels.
_PARTICLE_LABELS = {
    'H-1': r'$p$',
    'He-4': r'$\alpha$',
    'Neutron-1': r'$n$',
    'OOOOO': 'v',
}

# Isotope pairs (in either order) whose (v,v) link is labelled with a bare arrow.
_ARROW_PAIRS = (
    ("He-4", "C-12"),
    ("He-3", "He-4"),
    ("He-4", "B-8"),
    ("He-4", "Be-8"),
)


def format_iso(x):
    r'''Return a mathtext label for an isotope name.

    Parameters
    ----------
    x: string
        Isotope name of the form 'Element-A', e.g. 'Fe-56'.

    Returns
    -------
    string
        Label of the form '$^{56}\mathrm{Fe}$'.
    '''
    ele, A = x.split('-')
    return rf'$^{{{A}}}\mathrm{{{ele}}}$'


def format_bit(bit):
    '''Return the label for a particle or isotope appearing inside a reaction.

    H-1, He-4, Neutron-1 and the placeholder 'OOOOO' get short symbols
    (p, alpha, n and v); any other name is formatted with format_iso.
    '''
    if bit in _PARTICLE_LABELS:
        return _PARTICLE_LABELS[bit]
    return format_iso(bit)


def flip_reaction(reaction):
    '''Swap the two particles of every "(x,y)" group, e.g. "(n,g)" -> "(g,n)".'''
    return _REACTION_PARTICLES.sub(r'(\2,\1)', reaction)


def format_vv(x, row, flip):
    '''Format a "(v,v)" reaction label from the isotopes stored in *row*.

    Parameters
    ----------
    x: string
        The raw reaction type (not used; kept for call compatibility).
    row: flux-table row
        Must expose ISO1, ISO2, INPUT2 and OUTPUT2 attributes (presumably a
        pandas row from get_relevant_fluxes -- confirm against multizone.py).
    flip: boolean
        If True, label the reaction in the reverse direction.

    Returns
    -------
    string
        A bare arrow for the pairs listed in _ARROW_PAIRS, otherwise
        "(reactor,product)".
    '''
    isos = (row.ISO1, row.ISO2)
    if isos in _ARROW_PAIRS or isos[::-1] in _ARROW_PAIRS:
        return r'$\rightarrow$'

    if flip:
        reactor, product = format_iso(row.ISO2), format_bit(row.INPUT2)
    else:
        reactor, product = format_iso(row.ISO1), format_bit(row.OUTPUT2)
    return "(" + reactor + "," + product + ")"


def format_rtype(x, flip=False, info=None):
    '''Convert a reaction type such as "(n,g)" into a mathtext label.

    Parameters
    ----------
    x: string
        Reaction type, e.g. "(n,g)", "(+,g)" or "(v,v)".
    flip: boolean
        If True, describe the reverse reaction.
    info: flux-table row
        Only used for "(v,v)" reactions; forwarded to format_vv.

    Returns
    -------
    string
        The formatted label. Beta decays "(+,g)"/"(-,g)" collapse to "(+)"/"(-)".
    '''
    if "+,g" in x:
        x = '(+)'
    if "-,g" in x:
        x = '(-)'
    if flip:
        x = flip_reaction(x)
    s = x.translate(_RTYPE_SYMBOLS)
    if s == '(v,v)':
        s = format_vv(x, info, flip)
    return s


def _set_axis_ticks(set_ticks, lo, hi, minor, major):
    '''Place major/minor ticks on one axis through its set_xticks/set_yticks method.'''
    # Empty default so minor ticks still work when no major spacing is given.
    major_ticks = np.array([])
    if not np.isnan(major):
        major_ticks = np.arange(np.floor(lo), np.ceil(hi) + major, major)
        set_ticks(major_ticks)
    if not np.isnan(minor):
        # Drop minor ticks that coincide with a major tick.
        minor_ticks = np.setdiff1d(np.arange(np.floor(lo), np.ceil(hi) + minor, minor), major_ticks)
        set_ticks(minor_ticks, minor=True)


def plot_ticks(ax, ymin, ymax, xmin, xmax, yminor=np.nan, ymajor=np.nan, xminor=np.nan,
               xmajor=np.nan, noxlabel=True):
    '''Set the x and y ticks for your plot.

    Parameters
    ----------
    ax: matplotlib axis
        Axis object to set tick spacings on.
    ymin: float
        The minimum y value covered by the ticks.
    ymax: float
        The maximum y value covered by the ticks.
    xmin: float
        The minimum x value covered by the ticks.
    xmax: float
        The maximum x value covered by the ticks.
    yminor: float
        The spacing for minor ticks on the y-axis.
        Default: NaN (minor y ticks are left untouched).
    ymajor: float
        The spacing for major ticks on the y-axis.
        Default: NaN (matplotlib's automatic major y ticks are kept).
    xminor: float
        The spacing for minor ticks on the x-axis.
        Default: NaN (minor x ticks are left untouched).
    xmajor: float
        The spacing for major ticks on the x-axis.
        Default: NaN (matplotlib's automatic major x ticks are kept).
    noxlabel: boolean
        Currently ignored; kept for backward compatibility.

    Returns
    -------
    The same axis, with inward-pointing y ticks drawn on both sides.
    '''
    _set_axis_ticks(ax.set_yticks, ymin, ymax, yminor, ymajor)
    _set_axis_ticks(ax.set_xticks, xmin, xmax, xminor, xmajor)

    ax.yaxis.set_ticks_position('both')
    ax.yaxis.set_tick_params(which='major', direction='in', right=True, left=True)
    ax.yaxis.set_tick_params(which='minor', direction='in', right=True, left=True)
    return ax


### OP plot

def overproduction_plot(ifig, objs, cycle, isotopes, OP_type=None, x_as_massnumber=False,
                        lines=False, **kwargs):
    '''Plot overproduction factors for all isotopes.

    Parameters
    ----------
    ifig: integer
        The figure number.
    objs: mppnp_reader objects or list of mppnp_reader objects
        The mppnp simulations to plot read in as mppnp_reader objects from multizone.py
    cycle: float
        The cycle to plot the OP for. Used for both the surface abundances
        and the reference abundances.
    isotopes: string or list
        The isotopes to plot with OP.
    OP_type: string
        Possible inputs:
            solar: OP = log10( X/Xsolar )
            initial: OP = log10( X/Xinitial )
            ingest: OP = log10( X/Xingested )
        Default: None, plot as log10(X). Any other value raises ValueError.
    x_as_massnumber: boolean
        If False, isotopes will be plotted sequentially.
        If True, isotopes will be plotted according to their mass number A.
    lines: boolean
        If False, isotopes will be plotted individually.
        If True, isotopes of the same element will be plotted with a connecting line.
    kwargs:
        labels: The legend information for mppnp_reader objects.
        figsize: The size of the figure (tuple).
        label_fontsize: The fontsize of the x-axis and y-axis labels (integer).
        legend_fontsize: The fontsize of the legend (integer).
        isotope_fontsize: The fontsize of the isotope labels (integer).
        ymin: The lower y-axis limit (integer).
        ymax: The upper y-axis limit (integer).
        colours: Colours for the markers and lines (list).
        markers: Marker shapes (list).
        edgecolours: Edgecolour of the markers (list).
        linestyles: Linestyles for the connecting lines (list).
        size: Marker size (integer).
        alpha: Transparency of the markers and lines (float). Default: 0.8.

    When both ymin and ymax are given, points outside that range are pinned
    inside the axes and marked with an arrow and a label giving their true value.
    '''
    if not isinstance(objs, list):
        objs = [objs]
    if not isinstance(isotopes, list):
        isotopes = [isotopes]
    if OP_type not in (None, "solar", "initial", "ingest"):
        raise ValueError("OP_type must be None, 'solar', 'initial' or 'ingest', "
                         f"got {OP_type!r}")

    n_objs = len(objs)
    labels = kwargs.get('labels', n_objs * [''])
    figsize = kwargs.get('figsize', (20, 8))
    label_fontsize = kwargs.get('label_fontsize', 35)
    legend_fontsize = kwargs.get('legend_fontsize', 25)
    isotope_fontsize = kwargs.get('isotope_fontsize', 23)
    ymin = kwargs.get('ymin', None)
    ymax = kwargs.get('ymax', None)
    colours = kwargs.get('colours', [utils.linestylecb(i)[2] for i in range(n_objs)])
    markers = kwargs.get('markers', [utils.linestylecb(i)[1] for i in range(n_objs)])
    edgecolours = kwargs.get('edgecolours', ['none' for i in range(n_objs)])
    linestyles = kwargs.get('linestyles', [utils.linestylecb(i)[0] for i in range(n_objs)])
    size = kwargs.get('size', 150)
    alpha = kwargs.get('alpha', 0.8)

    if (ymin is None) != (ymax is None):
        ymin, ymax = None, None
        print("Both ymin and ymax must be set. Ignoring input.")
    if not isinstance(labels, list):
        labels = [labels]

    plt.close(ifig)
    plt.figure(ifig, figsize=figsize)

    ymins, ymaxs = [], []
    for j, obj in enumerate(objs):
        # Reference abundances for the OP normalisation (1 => plain log10(X)).
        X_OP = 1 if OP_type is None else obj.get(OP_type, cycle, isotopes)
        Xi_surf = obj.get("surf", cycle, isotopes)
        OP = np.log10(Xi_surf / X_OP)

        Zs, Ns, As, eles = obj.get_ZNAele(isotopes)
        # np.asarray so the boolean masks below also work if As is a plain list.
        xpos = np.asarray(As) if x_as_massnumber else np.arange(len(isotopes))
        OPplot = OP.copy()

        # Plot with arrows if out of bounds
        if ymin is not None and ymax is not None:
            adjust = (ymax - ymin) / 8
            toolarge, toosmall = OP > ymax, OP < ymin
            OPplot[toolarge] = ymax - adjust
            OPplot[toosmall] = ymin + adjust
            for x, y, op, big, tiny in zip(xpos, OPplot, OP, toolarge, toosmall):
                if tiny:
                    arrow = '-|>'
                    length = y - adjust / 2
                    txt = y + adjust * 0.25
                elif big:
                    arrow = '<-'
                    length = y
                    y = y + adjust / 2
                    txt = y - adjust * 0.75
                else:
                    continue
                plt.annotate(
                    '', xy=(x, length), xytext=(x, y),
                    arrowprops=dict(arrowstyle=arrow, color=colours[j], lw=4, alpha=0.5),
                    zorder=6,
                )
                plt.text(
                    x - 0.1, txt,
                    fr'${{{op:.1f}}}$' if op < 10 else fr'${{{round(op, 0)}}}$',
                    fontsize=isotope_fontsize - 9, verticalalignment='center',
                    horizontalalignment='center', zorder=6,
                )

        if lines:
            ele_arr = np.array(eles)
            for ele in sorted(set(eles)):
                idxs = ele_arr == ele
                if np.sum(idxs) == 1:
                    plt.scatter(xpos[idxs], OPplot[idxs], color=colours[j], edgecolor=edgecolours[j],
                                marker=markers[j], s=size, zorder=2, alpha=alpha)
                else:
                    plt.plot(xpos[idxs], OPplot[idxs], color=colours[j], marker=markers[j],
                             linestyle=linestyles[j], markeredgecolor=edgecolours[j],
                             markersize=np.sqrt(size), zorder=2, alpha=alpha)
        else:
            plt.scatter(xpos, OPplot, color=colours[j], edgecolor=edgecolours[j],
                        marker=markers[j], s=size, zorder=2, alpha=alpha)

        if ymin is None:
            this_ymin, this_ymax = obj.round_to_nearest(OPplot, 0.5)
            ymins.append(this_ymin)
            ymaxs.append(this_ymax)

    if ymin is None:
        ymin, ymax = np.min(ymins), np.max(ymaxs)
    xmin, xmax = obj.round_to_nearest(xpos, 5 if x_as_massnumber else 1)

    ax = plt.gca()
    ax = plot_ticks(ax, ymin, ymax, xmin, xmax, 0.25, 0.5, 5, 10)
    ax.set_ylim(ymin, ymax)
    ax.set_xlim(xmin - 1, xmax + 1)

    # Vertical offset used for isotope/element labels.
    adjust = ((ymax - ymin) / 8) * 0.4
    if not lines and not x_as_massnumber:
        adjust = (ymax - ymin) / 30
        # Alternate isotope labels between the top and bottom edges to avoid overlap.
        for idx, (x, isotope) in enumerate(zip(xpos, isotopes), 1):
            iso = format_iso(isotope)
            if idx % 2 == 1:
                plt.text(x - 0.1, ymax - adjust * 1.5, iso, ha='center', fontsize=isotope_fontsize)
                plt.axvline(x, color='lightgrey', lw=0.5, zorder=1)
            else:
                plt.text(x - 0.1, ymin + adjust * 0.3, iso, ha='center', fontsize=isotope_fontsize)

    if not lines and x_as_massnumber:
        for isotope, x, y in zip(isotopes, xpos, OPplot):
            plt.text(x, y - adjust, format_iso(isotope), fontsize=isotope_fontsize,
                     ha='center', va='center', zorder=3)

    if lines:
        # One element label below the first isotope of each element.
        ele_arr = np.array(eles)
        for ele in sorted(set(eles)):
            idxs = ele_arr == ele
            plt.text((xpos[idxs])[0], (OPplot[idxs])[0] - adjust, ele, fontsize=isotope_fontsize,
                     ha='center', va='center', zorder=3)

    if OP_type is not None:
        ax.set_ylabel("OP", fontsize=label_fontsize)
    else:
        ax.set_ylabel(r"$\log_{10}(X)$", fontsize=label_fontsize)
    if x_as_massnumber:
        ax.set_xlabel(r"$A$", fontsize=label_fontsize)
    ax.tick_params(axis='y', labelsize=legend_fontsize)
    if not x_as_massnumber:
        ax.set_xticks([])
    else:
        ax.tick_params(axis='x', labelsize=legend_fontsize)

    if not np.all(np.array(labels) == ''):
        legend_elements = [
            Line2D([], [], marker=markers[j], color=colours[j], markerfacecolor=colours[j],
                   markersize=np.sqrt(size), markeredgecolor=edgecolours[j],
                   linestyle=linestyles[j] if lines else 'None', label=labels[j], alpha=alpha)
            for j in range(n_objs)
        ]
        if not x_as_massnumber:
            ax.legend(handles=legend_elements, ncol=len(labels), frameon=False,
                      fontsize=legend_fontsize, loc='lower center', bbox_to_anchor=(0.5, -0.105),
                      handletextpad=0.5, columnspacing=0.5)
        else:
            ax.legend(handles=legend_elements, ncol=len(labels), fontsize=legend_fontsize,
                      handletextpad=0.5, columnspacing=0.5)

    if ymin < 0 < ymax:
        plt.axhline(0, color='grey', lw=3, zorder=1)
    plt.tight_layout()


#### reaction rates as a function of mass

def plot_reactions(ifig, obj, cycle, reactive_isotope, plot_isos=None, reaction_type="all",
                   minFij=None, maxFij=None, minX=None, maxX=None, **kwargs):
    '''Plot the reactions and mass fractions of an isotope as a function of mass.

    Parameters
    ----------
    ifig: integer
        The figure number.
    obj: mppnp_reader object
        The mppnp simulation to plot read in as a mppnp_reader object
    cycle: float
        The cycle to plot the reactions for (fluxes, mass grid and mass fractions).
    reactive_isotope: string
        The isotope to plot reactions for
    plot_isos: string or list of strings
        The isotopes to plot mass fractions for on a secondary axis.
        Default: None (no mass fractions).
    reaction_type: string
        Possible inputs:
            production: only reactions that produce reactive_isotope
            destruction: only reactions that destroy reactive_isotope
            all: all reactions involving reactive_isotope
        Default: all. Any other value raises ValueError.
    minFij: float
        Minimum log10(f_ij). Reactions entirely below it are not plotted.
    maxFij: float
        Maximum log10(f_ij). Only used for the y-limits together with minFij.
    minX: float
        Minimum log10(X_i)
    maxX: float
        Maximum log10(X_i)
    kwargs:
        figsize: The size of the figure (tuple).
        label_fontsize: The fontsize of the x-axis and y-axis labels (integer).
        legend_fontsize: The fontsize of the legend (integer).
        colours: Colours for the lines, one per plotted reaction (list).
        markers: Marker shapes, one per plotted reaction (list).
        edgecolours: Edgecolour of the markers, one per plotted reaction (list).
        linestyles: Linestyles, one per plotted reaction (list).
    '''
    if reaction_type not in ("production", "destruction", "all"):
        raise ValueError("reaction_type must be 'production', 'destruction' or 'all', "
                         f"got {reaction_type!r}")
    if plot_isos is None:
        plot_isos = []
    elif isinstance(plot_isos, str):
        plot_isos = [plot_isos]
    else:
        plot_isos = list(plot_isos)

    figsize = kwargs.get('figsize', (6, 4))
    label_fontsize = kwargs.get('label_fontsize', 14)
    legend_fontsize = kwargs.get('legend_fontsize', 12)
    colours = kwargs.get('colours', None)
    markers = kwargs.get('markers', None)
    edgecolours = kwargs.get('edgecolours', None)
    linestyles = kwargs.get('linestyles', None)

    plt.close(ifig)
    fig, ax = plt.subplots(num=ifig, figsize=figsize)
    handles = []
    labels = []

    relevant_f_ij_m = obj.get_relevant_fluxes(reactive_isotope, cycle)
    # Mass grid of the same cycle as the fluxes so the x and y arrays line up.
    mass = obj.get('out', cycle, 'mass')

    rtypes = ['(g,n)', '(n,g)', '(g,p)', '(p,g)', '(g,a)', '(a,g)', '(n,a)', '(a,n)',
              '(p,a)', '(a,p)', '(p,n)', '(n,p)', '(+,g)', '(-,g)']
    n_rtypes = len(rtypes)
    # Style table: production styles [0, n), destruction styles [n, 2n),
    # then (v,v) reactions from 2n upward.
    rstyle = [utils.linestylecb(i, a=max(3, int(len(mass) / 8)), b=5)
              for i in range(2 * n_rtypes + 20)]

    # Production (positive f_ij) is drawn before destruction (negative f_ij).
    destroy_flags = []
    if reaction_type != "destruction":
        destroy_flags.append(False)
    if reaction_type != "production":
        destroy_flags.append(True)

    vv = 0
    counter = 0
    for destroy in destroy_flags:
        for _, row in relevant_f_ij_m.iterrows():
            f_ij = -row["f_ij"] if destroy else row["f_ij"]
            with np.errstate(invalid='ignore', divide='ignore'):
                log_fij = np.log10(f_ij)
            valid = ~np.isnan(log_fij)
            if minFij is not None and np.all(log_fij[valid] < minFij):
                continue

            if row.TYPE != '(v,v)':
                q = rtypes.index(row.TYPE) + (n_rtypes if destroy else 0)
            else:
                q = 2 * n_rtypes + vv
                vv += 1

            colour = colours[counter] if colours is not None else rstyle[q][2]
            marker = markers[counter] if markers is not None else rstyle[q][1]
            linestyle = linestyles[counter] if linestyles is not None else rstyle[q][0]
            edgecolor = edgecolours[counter] if edgecolours is not None else 'none'

            if destroy:
                label = (format_iso(row.ISO2) + format_rtype(row.TYPE, flip=True, info=row)
                         + format_iso(row.ISO1))
            else:
                label = format_iso(row.ISO1) + format_rtype(row.TYPE, info=row) + format_iso(row.ISO2)

            h, = ax.plot(mass, log_fij, label=label, linestyle=linestyle, color=colour,
                         marker=marker, markerfacecolor=colour, markeredgecolor=edgecolor,
                         markevery=rstyle[q][3])
            handles.append(h)
            labels.append(label)
            counter += 1

    ax.set_xlim(mass[0].value, mass[-1].value)
    if minFij is not None and maxFij is not None:
        ax.set_ylim(minFij, maxFij)
    ax.set_xlabel(r"Mass ($\mathrm{M}_\odot$)", fontsize=label_fontsize)
    ax.set_ylabel(r"$\log_{10}(f_{ij})$", fontsize=label_fontsize)
    ax.tick_params(labelsize=legend_fontsize)

    line_axes = [ax]
    legend_ax = ax
    if plot_isos:
        ax2 = ax.twinx()
        Xs = np.log10(obj.get("out", cycle, plot_isos))
        for i, iso in enumerate(plot_isos):
            cbstyle = utils.linestylecb(i, a=max(3, int(len(mass) / 8)), b=5)
            label = r'$X$(' + format_iso(iso) + ')'
            # A single isotope comes back as a 1-D array, several as (mass, isotope).
            y = Xs if len(plot_isos) == 1 else Xs[:, i]
            h2, = ax2.plot(mass, y, color='black', label=label, linestyle=cbstyle[0],
                           marker=cbstyle[1], markevery=cbstyle[3])
            handles.append(h2)
            labels.append(label)
        if minX is not None and maxX is not None:
            ax2.set_ylim(minX, maxX)
        ax2.set_ylabel(r'$\log_{10}(X_i)$', fontsize=label_fontsize)
        ax2.tick_params(labelsize=legend_fontsize)
        line_axes.append(ax2)
        legend_ax = ax2

    legend_ax.legend(handles=handles, labels=labels, ncol=1, loc='lower right',
                     fontsize=legend_fontsize)
    # Rasterize the (dense) lines on every axis to keep vector output small.
    for line_ax in line_axes:
        for artist in line_ax.get_children():
            if isinstance(artist, Line2D):
                artist.set_rasterized(True)

    plt.tight_layout()


### charts helper functions

def find_related_isotopes(obj, central_isotope, delta_A, pairs, fluxes_at_mass_coord, ttt,
                          minFij, maxFij, single=False):
    '''Find reaction fluxes connected to central_isotope.

    Starting from central_isotope, every reaction of the current isotope is
    oriented away from it (flux sign flipped for reversed pairs) and kept if:
        ttt == 'max': flux <= 0.1 * (minimum flux of that isotope) or
                      flux >= 0.1 * (maximum flux of that isotope);
        otherwise ('range'): flux <= minFij or flux >= maxFij, comparing the
                      raw flux values (not log10) with the limits.
    The product of every kept reaction whose mass number is within delta_A of
    central_isotope is visited next, unless single is True, in which case only
    central_isotope itself is processed.

    Returns
    -------
    Three parallel lists: input isotopes, output isotopes and oriented fluxes.

    Raises
    ------
    ValueError
        If ttt is not 'max' and minFij or maxFij is None.
    '''
    if ttt != 'max' and (minFij is None or maxFij is None):
        raise ValueError("ttt='range' requires both minFij and maxFij.")

    # Every pair is reachable from both of its isotopes, oriented away from the key.
    pair_dict = defaultdict(list)
    for pair, flux in zip(pairs, fluxes_at_mass_coord):
        pair_dict[pair[0]].append((pair, flux))
        pair_dict[pair[1]].append((pair[::-1], -flux))

    to_work = {central_isotope}
    been_done = set()
    input_iso, output_iso, selected_fluxes = [], [], []
    central_A = obj.get_ZNAele(central_isotope)[2]

    while to_work:
        current_isotope = to_work.pop()
        if current_isotope in been_done:
            continue
        relevant_pairs = pair_dict[current_isotope]
        if not relevant_pairs:
            continue

        if ttt == 'max':
            fluxes = [flux for _, flux in relevant_pairs]
            low, high = min(fluxes) * 0.1, max(fluxes) * 0.1
        else:
            low, high = minFij, maxFij

        for pair, flux in relevant_pairs:
            if flux <= low or flux >= high:
                input_iso.append(current_isotope)
                output_iso.append(pair[1])
                selected_fluxes.append(flux)
                if abs(obj.get_ZNAele(pair[1])[2] - central_A) <= delta_A and pair[1] not in been_done:
                    to_work.add(pair[1])

        been_done.add(current_isotope)
        if single:
            break

    return input_iso, output_iso, selected_fluxes


def find_isotopes_in_range(obj, Nrange, Zrange, pairs, fluxes_at_mass_coord, minFij, maxFij):
    '''Select reaction fluxes that touch any isotope in an (N, Z) window.

    Each reaction pair is reported at most once, even when both of its
    isotopes lie inside the window (reporting it twice would double its flux
    when the results are summed by pair).

    Parameters
    ----------
    obj: mppnp_reader object
        Used to name isotopes and check that they exist in the simulation.
    Nrange, Zrange: tuple
        Inclusive (min, max) neutron and proton numbers of the window.
    pairs: list of (iso1, iso2) tuples
        Reaction pairs, parallel to fluxes_at_mass_coord.
    fluxes_at_mass_coord: array
        Signed flux of each pair at one mass coordinate.
    minFij, maxFij: float or None
        Inclusive limits on log10(|f_ij|); None leaves that side unbounded.
        Fluxes with a non-finite log10(|f_ij|) (zero or NaN) are never selected.

    Returns
    -------
    Three parallel lists (input isotopes, output isotopes, fluxes), each entry
    oriented as (pair[0], pair[1]).
    '''
    N1, N2 = Nrange
    Z1, Z2 = Zrange

    # Index pair positions by isotope so each window isotope visits only its own reactions.
    pairs_by_iso = defaultdict(list)
    for k, pair in enumerate(pairs):
        for iso in set(pair):
            pairs_by_iso[iso].append(k)

    input_iso, output_iso, selected_fluxes = [], [], []
    used = set()
    for Z in range(Z1, Z2 + 1):
        ele = obj.get_element(Z)
        for N in range(N1, N2 + 1):
            isotope = ele + '-' + str(N + Z)
            if not obj.exists_in_sim(isotope):
                continue
            for k in pairs_by_iso.get(isotope, ()):
                if k in used:
                    continue
                used.add(k)
                flux = fluxes_at_mass_coord[k]
                with np.errstate(divide='ignore', invalid='ignore'):
                    logdflux = np.log10(np.abs(flux))
                if not np.isfinite(logdflux):
                    continue
                if minFij is not None and logdflux < minFij:
                    continue
                if maxFij is not None and logdflux > maxFij:
                    continue
                input_iso.append(pairs[k][0])
                output_iso.append(pairs[k][1])
                selected_fluxes.append(flux)

    return input_iso, output_iso, selected_fluxes


def get_color(value, minFij, maxFij, cmap=plt.cm.jet, get='color'):
    '''Map values onto a colour scale spanning [minFij, maxFij].

    Parameters
    ----------
    value: float or array
        Value(s) to colour; ignored when get == 'map'.
    minFij, maxFij: float
        Limits of the colour normalisation.
    cmap: matplotlib colormap
        Colormap to use. Default: jet.
    get: string
        'color': return RGBA colour(s); values below minFij are fully transparent.
        'map': return the (Normalize, colormap) pair, e.g. for a colourbar.

    Raises
    ------
    ValueError
        For any other value of get.
    '''
    norm = mcolors.Normalize(vmin=minFij, vmax=maxFij)
    if get == 'map':
        return norm, cmap
    if get != 'color':
        raise ValueError(f"get must be 'color' or 'map', got {get!r}")
    # np.array copies, so the in-place alpha edit never touches cmap's output.
    rgba_colors = np.array(cmap(norm(value)))
    rgba_colors[value < minFij, 3] = 0  # turn off if under minFij
    return rgba_colors


def scale_size(logdflux, minFij, maxFij, min_size=1, max_size=50):
    '''Map log10 flux linearly onto [min_size, max_size] (e.g. an arrow width).

    Values outside [minFij, maxFij] extrapolate beyond the size range. If
    minFij == maxFij the scale is degenerate and max_size is returned.
    '''
    span = maxFij - minFij
    if span == 0:
        return max_size
    normalized_flux = (logdflux - minFij) / span
    return normalized_flux * (max_size - min_size) + min_size


### chart plotting

def plot_chart(ifig, obj, cycle, mass_idx, Nrange=(10, 20), Zrange=(10, 20), central_isotope='',
               delta_A=10, relevant=False, ttt='max', minFij=None, maxFij=None, minX=None,
               maxX=None, network=True, abundances=False, split=False, **kwargs):
    '''Plot a nuclear reaction network or abundance plot.

    Parameters
    ----------
    ifig: integer
        The figure number.
    obj: mppnp_reader object
        The mppnp simulation to plot read in as a mppnp_reader object
    cycle: float
        The cycle to plot the chart for (fluxes and abundances).
    mass_idx: integer
        Mass index to plot the chart for.
    Nrange: tuple
        Minimum and maximum N to look at (use instead of central_isotope and delta_A).
    Zrange: tuple
        Minimum and maximum Z to look at (use instead of central_isotope and delta_A).
    central_isotope: string
        Plot isotopes around central_isotope. Use this with delta_A.
    delta_A: integer
        Range of A around central_isotope to look for isotopes.
    relevant: boolean
        Whether to plot only fluxes that are related to central_isotope.
    ttt: string
        Flux selection rule used when relevant is True (see find_related_isotopes):
            'max': for each visited isotope keep fluxes <= 0.1 * its minimum or
                   >= 0.1 * its maximum flux;
            'range': keep fluxes <= minFij or >= maxFij (raw flux values,
                   not log10); requires both limits.
        Default: 'max'.
    minFij: float
        Minimum log10(f_ij). Default: maxFij - 10.
    maxFij: float
        Maximum log10(f_ij). Default: the ceiling of the largest selected log10(|f_ij|).
    minX: float
        Minimum log10(X_i). Default: maxX - 10.
    maxX: float
        Maximum log10(X_i). Default: 0.
    network: boolean
        Whether to plot the reaction network.
    abundances: boolean
        Whether to plot abundances.
    split: boolean
        Whether to separate network and abundance plots.
    kwargs:
        figsize: The size of the figure (tuple).
        label_fontsize: The fontsize of the x-axis and y-axis labels (integer).
        tick_fontsize: The fontsize of the x-axis and y-axis ticks (integer).
        isotope_fontsize: The fontsize of the isotopes (integer).
        max_arrow_size: The maximum size of the arrow

    Warning: with relevant=False, fluxes whose log10(|f_ij|) lies outside
    [minFij, maxFij] are omitted, so anything greater than maxFij is not drawn.

    Raises
    ------
    ValueError
        If network is True and no flux survives the selection.
    '''
    figsize = kwargs.get('figsize', (12, 8))
    label_fontsize = kwargs.get('label_fontsize', 35)
    tick_fontsize = kwargs.get('tick_fontsize', 20)
    isotope_fontsize = kwargs.get('isotope_fontsize', 15)
    max_arrow_size = kwargs.get('max_arrow_size', 8)

    # Abundances of the same cycle as the fluxes, so mass_idx refers to the same zone.
    iso_massf = obj.get("out", cycle, "iso_massf")

    def get_abundance(isotope, idx):
        '''Mass fraction of isotope in zone idx.'''
        j = np.where(obj.all_isos == isotope)[0][0]
        return iso_massf[idx, j]

    def setup_figure():
        '''Create one axis, or two side by side when split.'''
        plt.close(ifig)
        if split:
            fig, (ax1, ax2) = plt.subplots(1, 2, num=ifig, figsize=figsize)
            return fig, [ax1, ax2]
        fig, ax = plt.subplots(num=ifig, figsize=figsize)
        return fig, [ax]

    def get_isotope_data():
        '''Select fluxes at mass_idx, either related to central_isotope or inside the N/Z window.'''
        fluxes = obj.get_balanced_fluxes(cycle)
        fluxes_at_mass_coord = np.stack(fluxes['f_ij'].to_numpy())[:, mass_idx]
        pairs = list(zip(fluxes['ISO1'], fluxes['ISO2']))
        if relevant:
            return find_related_isotopes(obj, central_isotope, delta_A, pairs,
                                         fluxes_at_mass_coord, ttt, minFij, maxFij)
        return find_isotopes_in_range(obj, Nrange, Zrange, pairs, fluxes_at_mass_coord,
                                      minFij, maxFij)

    def combine_fluxes(input_iso, output_iso, fluxes):
        '''Net flux per unordered pair; positive means from the alphabetically first isotope.'''
        combined_fluxes = {}
        for iso1, iso2, flux in zip(input_iso, output_iso, fluxes):
            if iso1 == iso2:
                continue
            pair_key = tuple(sorted([iso1, iso2]))
            flux_direction = 1 if (iso1, iso2) == pair_key else -1
            combined_fluxes[pair_key] = combined_fluxes.get(pair_key, 0) + (flux * flux_direction)
        return combined_fluxes

    def plot_network_arrows(axes, combined_fluxes, minFij, maxFij):
        '''Draw one arrow per net flux >= minFij; return the Z and N of every drawn endpoint.'''
        Zs, Ns = [], []
        for (iso1, iso2), total_flux in combined_fluxes.items():
            logdflux = np.log10(np.abs(total_flux))
            if logdflux < minFij:
                continue
            Z, N, *_ = obj.get_ZNAele([iso1, iso2])
            Z1, Z2 = Z
            N1, N2 = N
            Zs.extend([Z1, Z2])
            Ns.extend([N1, N2])

            start, end = (N1, Z1), (N2, Z2)
            # Arrow runs start -> end for positive net flux, end -> start otherwise.
            pos = [start, end][::-1 * int(np.sign(total_flux))]
            arrow_width = scale_size(logdflux, minFij, maxFij, min_size=0.1, max_size=max_arrow_size)
            arrow_color = get_color(logdflux, minFij, maxFij)
            axes[0].annotate('', xy=pos[0], xytext=pos[1],
                             arrowprops=dict(facecolor=arrow_color, edgecolor=arrow_color,
                                             arrowstyle='->', linewidth=arrow_width,
                                             zorder=int(arrow_width) + 5))
        return Zs, Ns

    def plot_isotope_boxes(axes, Zs, Ns, minX, maxX):
        '''Draw a box for every simulated isotope in the N/Z span, plus element labels.'''
        minZ, maxZ = min(Zs), max(Zs)
        minN, maxN = min(Ns), max(Ns)
        for Z in range(minZ, maxZ + 1):
            flag = False
            highest_N = 0
            for N in range(200):
                isotope = f"{obj.get_element(Z)}-{Z+N}"
                if not obj.exists_in_sim(isotope):
                    if flag:
                        break
                    continue
                if not (minN <= N <= maxN):
                    continue
                flag = True
                highest_N = N
                for ax_idx, ax in enumerate(axes):
                    # Abundance colours go only in the second panel when split.
                    should_color = abundances and (split and ax_idx == 1 or not split)
                    plot_single_isotope_box(ax, N, Z, isotope, should_color, minX, maxX,
                                            obj.is_stable(isotope))
            if flag:
                element_symbol = obj.get_element(Z)
                for ax in axes:
                    ax.text(highest_N + 1, Z, element_symbol, fontsize=isotope_fontsize,
                            ha='center', va='center', color='black', fontweight='medium',
                            zorder=101)

    def plot_single_isotope_box(ax, N, Z, isotope, should_color, minX, maxX, is_stable):
        '''Draw one isotope box labelled with A; stable isotopes get a thick edge.'''
        if should_color:
            logxi = np.log10(get_abundance(isotope, mass_idx))
            facecolor = get_color(logxi, minX, maxX, cmap=plt.cm.pink)
        else:
            facecolor = 'none'

        ax.text(N, Z, str(N + Z), fontsize=isotope_fontsize, ha='center', va='center',
                color='black', zorder=101)
        rect_props = {
            'xy': (N - 0.5, Z - 0.5),
            'width': 1,
            'height': 1,
            'facecolor': facecolor,
            'edgecolor': 'black',
            'zorder': 2 if is_stable else 1,
        }
        if is_stable:
            rect_props['lw'] = 4
        ax.add_patch(plt.Rectangle(**rect_props))

    def setup_colorbars(fig, axes, minFij, maxFij, minX, maxX):
        '''Add the flux colourbar (network) and/or the abundance colourbar.'''
        if network:
            norm, cmap = get_color(-1, minFij, maxFij, get='map')
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=axes[0], ticks=np.linspace(minFij, maxFij, 6))
            cbar.set_label(r'$\log_{10}(f_{ij})$', fontsize=label_fontsize)
            cbar.ax.tick_params(labelsize=tick_fontsize)

        if abundances:
            norm, cmap = get_color(-1, minX, maxX, get='map', cmap=plt.cm.pink)
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            ax = axes[1] if split else axes[0]
            cbar = fig.colorbar(sm, ax=ax, ticks=np.linspace(minX, maxX, 6))
            cbar.set_label(r'$\log_{10}(X_i)$', fontsize=label_fontsize)
            cbar.ax.tick_params(labelsize=tick_fontsize)

    def setup_axes(axes, Ns, Zs):
        '''Label the axes and fit the limits to the plotted N/Z span.'''
        for ax in axes:
            ax.set_xlabel('Neutron Number', fontsize=label_fontsize)
            ax.set_ylabel('Proton Number', fontsize=label_fontsize)
            ax.set_xlim(min(Ns) - 1., max(Ns) + 1.5)
            ax.set_ylim(min(Zs) - 1., max(Zs) + 1.)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            ax.tick_params(axis='both', labelsize=tick_fontsize)

    # Main execution
    fig, axes = setup_figure()

    if network:
        input_iso, output_iso, fluxes = get_isotope_data()
        if len(fluxes) == 0:
            raise ValueError("No reaction fluxes were selected; widen Nrange/Zrange "
                             "(or delta_A) or the minFij/maxFij limits.")
        # Default flux limits: round the largest flux up so it stays within maxFij.
        if maxFij is None:
            maxFij = int(np.ceil(np.max(np.log10(np.abs(fluxes)))))
        if minFij is None:
            minFij = maxFij - 10

    # Set abundance limits
    if maxX is None:
        maxX = 0
    if minX is None:
        minX = maxX - 10

    if network:
        combined_fluxes = combine_fluxes(input_iso, output_iso, fluxes)
        Zs, Ns = plot_network_arrows(axes, combined_fluxes, minFij, maxFij)
        if not Zs:
            raise ValueError("No net reaction flux is above minFij; lower minFij.")
    elif central_isotope == '':
        Zs, Ns = Zrange, Nrange
    else:
        Z, N, _, _ = obj.get_ZNAele(central_isotope)
        Zs = [Z - delta_A, Z + delta_A]
        Ns = [N - delta_A, N + delta_A]

    plot_isotope_boxes(axes, Zs, Ns, minX, maxX)
    setup_colorbars(fig, axes, minFij, maxFij, minX, maxX)
    setup_axes(axes, Ns, Zs)
    plt.tight_layout()


def plot_fourway_live(ifig, obj, cycle):
    '''Plot four mini network charts + mass fraction as a function of mass in an interactive plot.

    Parameters
    ----------
    ifig: integer
        The figure number.
    obj: mppnp_reader object
        The mppnp simulation to plot read in as a mppnp_reader object
    cycle: float
        The cycle to plot the fourway for.

    Widgets choose the central isotope (default 'Se-74' when present in the
    run, otherwise the first isotope), the mass zone of each of the four
    panels, and the min/max log10(f_ij) limits.
    '''
    plt.close(ifig)
    fig = plt.figure(num=ifig, figsize=(8, 12))
    # Rows 0-1: 2x2 network panels; row 2: spacer; row 3: X(central isotope) vs mass.
    gs = plt.GridSpec(4, 2, height_ratios=[1, 1, 0.1, 1], hspace=0.13, wspace=0.0)
    zone_color = ['green', 'peru', 'magenta', 'k']
    mass = obj.get('out', cycle, 'mass').value

    def update_plot(central_isotope, z1, z2, z3, z4, minimum, maximum):
        '''Redraw every panel for the current widget values.'''
        fig.clear()
        fig.suptitle('Central Isotope: ' + central_isotope, fontsize=16, y=0.9)

        zone_indices = [z1, z2, z3, z4]
        axes = []
        # N/Z of every drawn arrow, accumulated over the panels drawn so far.
        Ns, Zs = [], []

        # The flux table does not depend on the panel, so build it once.
        balanced_fluxes = obj.get_balanced_fluxes(cycle)
        pairs = list(zip(balanced_fluxes['ISO1'], balanced_fluxes['ISO2']))
        all_fluxes = np.stack(balanced_fluxes['f_ij'].to_numpy())

        for i, zone_idx in enumerate(zone_indices):
            ax = fig.add_subplot(gs[i // 2, i % 2])
            axes.append(ax)
            ax.text(0.05, 0.95, 'Mass Zone: ' + str(round(mass[zone_idx], 3)),
                    transform=ax.transAxes, fontsize=12, verticalalignment='top',
                    color=zone_color[i])

            input_iso, output_iso, fluxes = find_related_isotopes(
                obj, central_isotope, 1, pairs, all_fluxes[:, zone_idx], 'range',
                minimum, maximum, single=True)

            for iso1, iso2, flux in zip(input_iso, output_iso, fluxes):
                if iso1 == iso2:
                    continue
                logdflux = np.log10(np.abs(flux))
                if logdflux < minimum:
                    continue
                # Narrower width range than plot_chart for the small panels.
                arrow_width = scale_size(logdflux, minimum, maximum, min_size=0.5, max_size=4)
                color = get_color(logdflux, minimum, maximum)

                Z, N, *_ = obj.get_ZNAele([iso1, iso2])
                Z1, Z2 = Z
                N1, N2 = N
                Ns.extend([N1, N2])
                Zs.extend([Z1, Z2])

                start, end = (N1, Z1), (N2, Z2)
                # Arrow runs start -> end for positive flux, end -> start otherwise.
                pos = [start, end][::-1 * int(np.sign(flux))]
                ax.annotate('', xy=pos[0], xytext=pos[1],
                            arrowprops=dict(facecolor=color, edgecolor=color, arrowstyle='->',
                                            linewidth=arrow_width, zorder=int(arrow_width) + 5))

            highest_N = None
            if Zs:
                minN, maxN = min(Ns), max(Ns)
                minZ, maxZ = min(Zs), max(Zs)
                for Z in range(minZ, maxZ + 1):
                    ele = obj.get_element(Z)
                    flag = False
                    for N in range(200):
                        isotope = ele + '-' + str(Z + N)
                        if obj.exists_in_sim(isotope) and minN <= N <= maxN:
                            ax.text(N, Z, str(N + Z), fontsize=8, ha='center', va='center',
                                    color='black', zorder=101)
                            if obj.is_stable(isotope):
                                ax.add_patch(plt.Rectangle((N - 0.5, Z - 0.5), 1, 1, fill=False,
                                                           edgecolor='black', lw=4, zorder=2))
                            else:
                                ax.add_patch(plt.Rectangle((N - 0.5, Z - 0.5), 1, 1, fill=False,
                                                           edgecolor='black', zorder=1))
                            flag = True
                            highest_N = N
                        elif flag:
                            break

                # Element symbols in one column to the right of this panel's boxes.
                if highest_N is not None:
                    for Z in range(minZ, maxZ + 1):
                        ax.text(highest_N + 1, Z, f'{obj.get_element(Z)}', fontsize=8,
                                ha='center', va='center', color='black', fontweight='medium',
                                zorder=101)

            if i in (0, 2):
                ax.set_ylabel('Proton Number')
            if i in (2, 3):
                ax.set_xlabel('Neutron Number')

        # Create a single shared colorbar for the 2x2 subplots
        norm, cmap = get_color(-1, minimum, maximum, get='map')
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=axes, orientation='vertical', fraction=0.02, pad=0.02)
        cbar.set_label(r'$\log_{10}(f_{ij})$')

        # Share x and y axes among the 2x2 subplots
        for jj, ax in enumerate(axes):
            if Zs:
                ax.set_ylim(np.min(Zs) - 1.5, np.max(Zs) + 1.5)
                ax.set_xlim(np.min(Ns) - 1.5, np.max(Ns) + 1.5)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
            ax.xaxis.set_tick_params(labelbottom=True)
            if jj in (1, 3):
                ax.set_yticklabels([])

        # Bottom panel: mass fraction of the central isotope, with each panel's zone marked.
        ax5 = fig.add_subplot(gs[3, :])
        Xi = obj.get('out', cycle, central_isotope)
        ax5.semilogy(mass, Xi)
        for colour, zone_idx in zip(zone_color, zone_indices):
            ax5.axvline(mass[zone_idx], color=colour)
        ax5.set_xlim(mass[0], mass[-1])
        ax5.set_xlabel(r'Mass ($\mathrm{M}_\odot$)')
        ax5.set_ylabel(r'$X$' + '(' + format_iso(central_isotope) + ')')

    all_isos = list(obj.all_isos)
    # Dropdown rejects a default that is not among its options.
    default_iso = 'Se-74' if 'Se-74' in all_isos else all_isos[0]
    isotope_dropdown = widgets.Dropdown(
        options=obj.all_isos,
        value=default_iso,
        description='Isotope:')

    n_zones = mass.size
    ms1 = widgets.IntSlider(min=0, max=n_zones - 1, step=1, value=int(n_zones / 10), description='(0,0):')
    ms2 = widgets.IntSlider(min=0, max=n_zones - 1, step=1, value=int(n_zones / 5), description='(0,1):')
    ms3 = widgets.IntSlider(min=0, max=n_zones - 1, step=1, value=int(n_zones / 2), description='(1,0):')
    ms4 = widgets.IntSlider(min=0, max=n_zones - 1, step=1, value=int(n_zones / 1.5), description='(1,1):')
    fs1 = widgets.IntSlider(min=-99, max=100, step=1, value=-20, description=r'$\min f$')
    fs2 = widgets.IntSlider(min=-99, max=100, step=1, value=0, description=r'$\max f$')

    out = widgets.interactive_output(update_plot, {'central_isotope': isotope_dropdown,
                                                   'z1': ms1, 'z2': ms2, 'z3': ms3, 'z4': ms4,
                                                   'minimum': fs1, 'maximum': fs2})
    ui = widgets.VBox([isotope_dropdown, ms1, ms2, ms3, ms4, fs1, fs2])
    display(ui, out)