# Source code for idstools.view.common

import logging
import os
import sys

import matplotlib


def _is_jupyter() -> bool:
    """Return True if running inside a Jupyter notebook/lab/Colab kernel."""
    try:
        from IPython import get_ipython

        shell = get_ipython()
        if shell is None:
            return False
        shell_class = shell.__class__.__name__
        # ZMQInteractiveShell: Jupyter Notebook/Lab
        # Shell: Google Colab
        return shell_class in ("ZMQInteractiveShell", "Shell")
    except ImportError:
        return False


# Select the matplotlib backend before pyplot is imported further below:
#   * Jupyter/Colab: the interactive ipympl ("widget") backend when ipympl is
#     installed, otherwise the non-interactive "agg". Nothing is changed if
#     pyplot was already imported (presumably a backend was already chosen,
#     e.g. via a %matplotlib magic).
#   * Windows, or an X display available ($DISPLAY set): TkAgg when tkinter
#     is importable, otherwise "agg".
#   * Anything else (e.g. a headless session): "agg".
if _is_jupyter():
    if "matplotlib.pyplot" not in sys.modules:
        try:
            import ipympl  # noqa: F401 - imported to check availability

            matplotlib.use("widget")
        except ImportError:
            matplotlib.use("agg")
        except ValueError:
            # Some matplotlib versions (presumably those before the backend
            # registry of 3.9) reject the "widget" alias; the explicit module
            # path of the ipympl backend is accepted as a "module://" name.
            matplotlib.use("module://ipympl.backend_nbagg")
elif sys.platform.startswith("win") or "DISPLAY" in os.environ:
    try:
        import tkinter  # noqa: F401 - imported to check availability

        # Kept inside the try: if pyplot is already imported, switching to a
        # GUI backend can itself raise ImportError, and we then fall back.
        matplotlib.use("TkAgg")
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        matplotlib.use("agg")
else:
    matplotlib.use("agg")

import matplotlib.pyplot as plt  # noqa: E402

# NOTE(review): the logger is named "module" rather than __name__; presumably
# other idstools modules share this name on purpose — confirm before changing.
logger = logging.getLogger("module")

# Locate the "scientific" matplotlib style sheet. A copy installed under the
# installation's ``share/styles`` directory is preferred. That directory is
# reached by walking up from this file, because sys.prefix does not reflect
# installs made with the --prefix option.
current_directory = os.path.abspath(os.path.dirname(__file__))
share_directory = os.path.abspath(os.path.join(current_directory, "../../../../../"))
mplstyle_filepath = os.path.join(share_directory, r"share/styles/scientific.mplstyle")

# Otherwise use the style sheet stored next to this module.
plt.style.use(
    mplstyle_filepath
    if os.path.exists(mplstyle_filepath)
    else os.path.join(current_directory, r"styles/scientific.mplstyle")
)


# Rich is optional: when it is missing, Terminal falls back to plain output.
try:
    from rich.console import Console
    from rich.panel import Panel
    from rich.pretty import Pretty, pprint
except ImportError:
    rich_available = False
else:
    rich_available = True


class PlotCanvas:
    """Wrapper around one matplotlib figure whose axes sit on a regular grid.

    The canvas owns a single figure (``self.fig``). Axes are placed on an
    ``nrows`` x ``ncols`` grid through :meth:`add_axes`. All figure-level
    helpers (saving, figure text, super-title, showing) operate on
    ``self.fig`` rather than on pyplot's "current" figure, so several canvases
    can be used side by side.

    See https://matplotlib.org/stable/tutorials/intermediate/arranging_axes.html
    """

    # Color cycles selectable through set_style(). Hex codes are written
    # without a leading "#". The "vibrant", "muted", "light", "contrast" and
    # "bright" schemes are color-blind safe and come from Paul Tol's website:
    # https://personal.sron.nl/~pault/
    _STYLE_COLORS = {
        # Standard SciencePlots cycle: blue, green, yellow, red, violet, dark gray, gray
        "default": ["0C5DA5", "00B945", "FF9500", "FF2C00", "845B97", "474747", "9e9e9e"],
        "vibrant": ["EE7733", "0077BB", "33BBEE", "EE3377", "CC3311", "009988", "BBBBBB"],
        "retro": ["4165c0", "e770a2", "5ac3be", "696969", "f79a1e", "ba7dcd"],
        "muted": [
            "CC6677",
            "332288",
            "DDCC77",
            "117733",
            "88CCEE",
            "882255",
            "44AA99",
            "999933",
            "AA4499",
            "DDDDDD",
        ],
        "light": [
            "77AADD",
            "EE8866",
            "EEDD88",
            "FFAABB",
            "99DDFF",
            "44BB99",
            "BBCC33",
            "AAAA00",
            "DDDDDD",
        ],
        # High visibility plots (i.e. bright colors); paired with line styles below.
        "high-vis": ["0d49fb", "e6091c", "26eb47", "8936df", "fec32d", "25d7fd"],
        "contrast": ["004488", "DDAA33", "BB5566"],
        "bright": ["4477AA", "EE6677", "228833", "CCBB44", "66CCEE", "AA3377", "BBBBBB"],
    }
    # Line styles cycled together with the "high-vis" colors (same length).
    _HIGH_VIS_LINESTYLES = ["-", "--", "-.", ":", "-", "--"]

    def __init__(self, nrows=1, ncols=1, *args, **kwargs) -> None:
        """Create the figure and remember the grid shape used by add_axes().

        Args:
            nrows (int, optional): Number of grid rows. Defaults to 1.
            ncols (int, optional): Number of grid columns. Defaults to 1.
            args: Positional arguments forwarded to ``plt.figure()``.
            kwargs: Keyword arguments forwarded to ``plt.figure()`` (e.g. figsize).
        """
        self.nrows = nrows
        self.ncols = ncols
        self.fig = plt.figure(*args, **kwargs)
        # Generous spacing between grid cells so neighbouring titles and labels do not collide.
        self.fig.subplots_adjust(hspace=0.5, wspace=0.5)
        # Sharing axes between subplots is described at:
        # https://matplotlib.org/stable/gallery/subplots_axes_and_figures/shared_axis_demo.html
        # https://matplotlib.org/stable/gallery/subplots_axes_and_figures/share_axis_lims_views.html

    def add_axes(
        self,
        title=None,
        xlabel=None,
        ylabel=None,
        row=0,
        col=0,
        rowspan=1,
        colspan=1,
        **kwargs,
    ):
        """
        Add a new subplot axes to the figure at a specified grid position.

        Creates a subplot at the given row/column location with optional spanning
        across multiple grid cells. Automatically sets title and axis labels if provided.

        Args:
            title (str, optional): Title for the axes. Defaults to None.
            xlabel (str, optional): Label for the x-axis. Defaults to None.
            ylabel (str, optional): Label for the y-axis. Defaults to None.
            row (int, optional): Row index in the grid (0-indexed). Defaults to 0.
            col (int, optional): Column index in the grid (0-indexed). Defaults to 0.
            rowspan (int, optional): Number of rows this axes spans. Defaults to 1.
            colspan (int, optional): Number of columns this axes spans. Defaults to 1.
            kwargs: Additional keyword arguments passed to plt.subplot2grid().

        Returns:
            matplotlib.axes.Axes: The created axes object for plotting.

        Examples:
            >>> canvas = PlotCanvas(nrows=2, ncols=2)
            >>> ax = canvas.add_axes(title="Main Plot", xlabel="X", ylabel="Y", row=0, col=0)
            >>> ax.plot([1, 2, 3], [1, 4, 9])
        """
        ax = plt.subplot2grid(
            shape=(self.nrows, self.ncols),
            loc=(row, col),
            rowspan=rowspan,
            colspan=colspan,
            fig=self.fig,
            **kwargs,
        )
        if title is not None:
            ax.set_title(title)
        if xlabel is not None:
            ax.set_xlabel(xlabel)
        if ylabel is not None:
            ax.set_ylabel(ylabel)
        return ax

    def save(
        self,
        fname,
        width=11.69,
        height=8.27,
        dpi="figure",
    ):
        """
        Save this canvas' figure to a file.

        The figure is resized to ``width`` x ``height`` inches and written in the
        format implied by the file extension (PDF, PNG, SVG, ...). Failures are
        logged at ERROR level instead of being raised.

        Args:
            fname (str): Output filename (e.g., 'figure.pdf', 'plot.png').
                The file extension determines the format.
            width (float, optional): Figure width in inches. Defaults to 11.69 (A4 width).
            height (float, optional): Figure height in inches. Defaults to 8.27 (A4 height).
            dpi (int or str, optional): Resolution in dots per inch. If 'figure',
                uses the figure's own DPI. Defaults to 'figure'.

        Returns:
            None

        Examples:
            >>> canvas = PlotCanvas()
            >>> ax = canvas.add_axes()
            >>> ax.plot([1, 2, 3], [1, 2, 3])
            >>> canvas.save('my_plot.pdf')  # Save as PDF
            >>> canvas.save('my_plot.png', dpi=300)  # Save as PNG with high resolution
        """
        # Use our own figure: plt.gcf() may belong to another canvas.
        self.fig.set_size_inches(width, height)
        try:
            self.fig.savefig(fname, dpi=dpi)
        except Exception as exc:
            # Best effort (callers do not expect an exception), but make the
            # failure visible instead of hiding it at DEBUG level.
            logger.error("Failed to save figure to %s: %s", fname, exc)
            return
        print(f"----> Figure saved to {fname}", file=sys.stderr)

    def set_text(self, x=0.001, y=0.985, text="", ha="left", fontsize=7):
        """
        Add text to this canvas' figure at a position in figure coordinates.

        The text is independent of any axes, which makes it useful for
        annotations, credits, or metadata that apply to the whole figure.

        Args:
            x (float, optional): X-coordinate in figure coordinates (0=left, 1=right).
                Defaults to 0.001 (near left edge).
            y (float, optional): Y-coordinate in figure coordinates (0=bottom, 1=top).
                Defaults to 0.985 (near top).
            text (str, optional): Text string to display. Defaults to empty string.
            ha (str, optional): Horizontal alignment ('left', 'center', 'right'). Defaults to 'left'.
            fontsize (int, optional): Font size in points. Defaults to 7.

        Returns:
            matplotlib.text.Text: The created text artist.

        Examples:
            >>> canvas = PlotCanvas()
            >>> canvas.set_text(x=0.5, y=0.95, text="Main Title", ha="center", fontsize=14)
        """
        return self.fig.text(
            x,
            y,
            text,
            ha=ha,
            fontsize=fontsize,
        )

    def set_sup_title(self, text="", *args, **kwargs):
        """
        Set the super-title (title spanning all subplots) of this canvas' figure.

        Args:
            text (str, optional): Super-title text. Defaults to empty string.
            args: Positional arguments forwarded to ``Figure.suptitle()``.
            kwargs: Keyword arguments (e.g., fontsize, color) forwarded to ``Figure.suptitle()``.

        Returns:
            matplotlib.text.Text: The super-title text artist.

        Examples:
            >>> canvas = PlotCanvas(nrows=2, ncols=2)
            >>> canvas.set_sup_title("Main Figure Title", fontsize=16, fontweight='bold')
        """
        return self.fig.suptitle(text, *args, **kwargs)

    def show(self, *args, **kwargs):
        """
        Display the figure, maximizing its window when the backend allows it.

        In a Jupyter/Colab kernel the figure is rendered with IPython's
        ``display``. Unless the interactive ipympl backend is active, the
        figure is then closed. Outside notebooks, the window is maximized if
        the backend exposes Tk window methods, and ``plt.show()`` is called.

        Args:
            args: Positional arguments passed to plt.show().
            kwargs: Keyword arguments passed to plt.show().

        Returns:
            None

        Notes:
            Window maximization relies on Tk window methods (TkAgg backend).
            Other backends (agg, Qt, ...) skip it and show the figure normally.

        Examples:
            >>> canvas = PlotCanvas()
            >>> ax = canvas.add_axes()
            >>> ax.plot([1, 2, 3], [1, 4, 9])
            >>> canvas.show()
        """
        backend = matplotlib.get_backend().lower()
        if _is_jupyter():
            try:
                from IPython.display import display
            except ImportError:
                return
            display(self.fig)
            # ipympl figures are live widgets and must stay open. Depending on
            # the matplotlib version the backend is reported as "widget",
            # "ipympl" or "module://ipympl.backend_nbagg".
            is_ipympl = backend in ("widget", "ipympl") or "ipympl" in backend
            if not is_ipympl:
                # Presumably closed so the inline backend does not render it a
                # second time at the end of the cell; only our own figure is
                # closed so other canvases are unaffected.
                plt.close(self.fig)
            return

        manager = self.fig.canvas.manager
        try:
            # Maximize the window; only Tk windows provide winfo_screen*().
            window = manager.window
            screen_height = window.winfo_screenheight()
            screen_width = window.winfo_screenwidth()
            manager.resize(screen_width, screen_height)
        except AttributeError:
            # Backend doesn't support window resizing (e.g., 'agg', 'Qt', ...),
            # or the figure has no manager at all (manager is None).
            logger.debug("Window resizing not supported for current matplotlib backend")
        plt.show(*args, **kwargs)

    def get_current_fig_manager(self):
        """
        Get pyplot's current figure manager.

        Note that this is the manager of pyplot's *current* figure, which is
        not necessarily ``self.fig``.

        Returns:
            matplotlib.backend_bases.FigureManagerBase: The current figure manager object.

        Examples:
            >>> canvas = PlotCanvas()
            >>> fig_mgr = canvas.get_current_fig_manager()
        """
        return plt.get_current_fig_manager()

    @staticmethod
    def is_axes_empty(ax):
        """
        Check whether a matplotlib Axes holds no content.

        An Axes is considered empty if it has no data, no lines, no patches,
        and no texts.

        Parameters:
            ax (matplotlib.axes.Axes): The Axes object to check.

        Returns:
            bool: True if the Axes object is empty, False otherwise.
        """
        return not ax.has_data() and len(ax.lines) == 0 and len(ax.patches) == 0 and len(ax.texts) == 0

    def remove_empty_axes(self):
        """
        Remove every axes of the figure for which ``is_axes_empty`` is True.

        Returns:
            None
        """
        # Iterate over a copy because delaxes() mutates self.fig.axes.
        for ax in self.fig.axes[:]:
            if PlotCanvas.is_axes_empty(ax):
                self.fig.delaxes(ax)

    def set_style(self, style="default"):
        """
        Select one of the predefined color cycles for subsequent plots.

        This sets the global ``matplotlib.rcParams["axes.prop_cycle"]``, so it
        applies to axes created afterwards in any figure, not only this canvas.

        Args:
            style (str, optional): One of "default", "vibrant", "retro",
                "muted", "light", "high-vis", "contrast" or "bright".
                "high-vis" additionally cycles line styles. An unknown name
                leaves the current cycle unchanged and logs a warning.
                Defaults to "default".
        """
        colors = self._STYLE_COLORS.get(style)
        if colors is None:
            logger.warning(
                "Unknown plot style %r; available styles: %s",
                style,
                ", ".join(self._STYLE_COLORS),
            )
            return
        cycle = matplotlib.cycler("color", colors)
        if style == "high-vis":
            cycle = cycle + matplotlib.cycler("ls", self._HIGH_VIS_LINESTYLES)
        matplotlib.rcParams["axes.prop_cycle"] = cycle

    def update_style(self, param_string=""):
        """
        Update matplotlib rcParams from a semicolon-separated string.

        Each ``key=value`` entry is applied to ``matplotlib.rcParams``. The
        value is evaluated as a Python expression, so numbers, booleans,
        tuples and quoted strings get their natural types. Entries that fail
        to parse or apply are reported on stdout and skipped; blank entries
        are ignored.

        Args:
            param_string (str, optional): e.g. "lines.linewidth=2;axes.titlesize=16".
                Defaults to "" (nothing to do).
        """
        if not param_string:
            return
        for item in param_string.split(";"):
            if not item.strip():
                continue
            try:
                key, value = item.split("=", 1)
                # NOTE(security): eval() runs arbitrary code. Passing {} as
                # globals does NOT sandbox it (Python injects __builtins__),
                # so only feed strings from trusted sources; consider
                # ast.literal_eval if untrusted input can reach this method.
                matplotlib.rcParams[key.strip()] = eval(value.strip(), {}, {})
            except Exception as e:
                print(f"Error applying rcParam '{item}': {e}")
class BasePlot:
    """Mixin-style base offering common annotations for plot classes."""

    def database_info(self, ax, title, hostdir, shot, run, t):
        """
        Add database and shot information as text annotation to a plot axes.

        Sets the axes title to ``title`` followed by the time point, and writes
        "<hostdir>-Shot:<shot>,<run>" vertically just outside the right edge of
        the axes.

        Args:
            ax (matplotlib.axes.Axes): The axes object to annotate.
            title (str): Title for the plot. Time information is appended to this.
            hostdir (str): Host directory or database name (e.g., 'mdsplus', 'localhost').
            shot (int): Tokamak shot number.
            run (int): Run number within the shot.
            t (float): Time point in seconds (shown with 3 decimals).

        Returns:
            None

        Examples:
            >>> ax = plt.gca()
            >>> plot = BasePlot()
            >>> plot.database_info(ax, "Plasma Profile", "mdsplus", 134174, 1, 0.5)
        """
        ax.set_title(f"{title} (t={t:.3f})")
        # Position in axes-fraction coordinates: just right of the right spine,
        # vertically centred. Unlike data coordinates this stays correct for
        # large or negative x-limits, inverted axes, log scales, and limits
        # that change after the call.
        ax.text(
            1.01,
            0.5,
            f"{hostdir}-Shot:{shot},{run}",
            transform=ax.transAxes,
            horizontalalignment="left",
            verticalalignment="center",
            rotation="vertical",
            fontsize=7,
        )
class Terminal:
    """Console printer that uses Rich when it is installed and plain print() otherwise."""

    # Formatting constants; not used by the methods of this class itself.
    tabsize = 10
    TAB = " " * 16
    LINE = "-" * 8

    def __init__(self) -> None:
        # A Rich console is only created when the optional Rich import succeeded.
        if rich_available:
            self.console = Console()

    def print(self, text, style=None, panel=False, pretty=False):
        """
        Print text to the console, with Rich styling when available.

        Args:
            text (str or dict): Text to print, or a dictionary to pretty-print.
            style (str, optional): Rich style (e.g., 'green', 'bold red').
                None selects the default 'green'. Ignored without Rich.
            panel (bool, optional): If True, draw the text inside a Rich panel.
                Ignored if Rich is unavailable. Defaults to False.
            pretty (bool, optional): If True, wrap the text in Rich's Pretty
                renderable for complex objects. Ignored if Rich is unavailable.
                Defaults to False.

        Returns:
            None

        Notes:
            - Dictionaries (including dict subclasses) are always pretty-printed:
              with Rich's pprint (fully expanded) when Rich is installed, and with
              the standard-library pprint otherwise. style/panel/pretty do not
              apply to them.
            - Without Rich, everything else falls back to the built-in print().

        Examples:
            >>> terminal = Terminal()
            >>> terminal.print("Hello World", style="green")
            >>> terminal.print({"data": [1, 2, 3]})  # Pretty prints dict
            >>> terminal.print("Warning!", style="bold red", panel=True)
        """
        if isinstance(text, dict):
            if rich_available:
                pprint(text, expand_all=True)
            else:
                # Rich's pprint was never imported; use the stdlib equivalent
                # instead of failing with NameError.
                import pprint as std_pprint

                std_pprint.pprint(text)
            return

        if not rich_available:
            print(text)
            return

        if style is None:
            style = "green"
        if pretty:
            text = Pretty(text)
        if panel:
            text = Panel(text)
        self.console.print(text, style=style, highlight=False)