### MatPlotLib Graph Wrapper
#### Written by Cal.W 2020, originally for MECH2700 but continually
#### expanded upon.
#### 2023 - Added UQ Colors
#### 2023 - Added pltKeyClose function
#### 2023 - Added UQ Default Colours to MatPlotLib

__author__ = "Cal Wing"
__version__ = "0.1.10"

from collections.abc import Iterator

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
from cycler import cycler

# Define the UQ Colours
UQ_COLOURS_DICT = {
    # Primary
    "purple": "#51247A",
    "white" : "#FFFFFF",
    "black" : "#000000",

    # Secondary
    "light_purple": "#962A8B",
    "red"         : "#E62645",
    "green"       : "#2EA836",
    "gold"        : "#BB9D65",
    "neutral"     : "#D7D1CC",
    "orange"      : "#EB602B",
    "yellow"      : "#FBB800",
    "blue"        : "#4085C6",
    "aqua"        : "#00A2C7",
    "dark_grey"   : "#999490"
}


class ColourValue(str):
    """A named colour that *is* its hex string but also offers conversions.

    The ``str`` content is the ``#rrggbbaa`` hex string produced by
    ``matplotlib.colors.to_hex(value, keep_alpha=True)``, so an instance can be
    passed anywhere matplotlib accepts a colour string.

    Attributes:
        name:  the name the colour was created with (e.g. ``"purple"``).
        value: the hex string (same text as the ``str`` itself).
    """

    def __new__(cls, name, value):
        hexValue = colors.to_hex(value, keep_alpha=True)
        instance = super().__new__(cls, hexValue)
        # Store on the instance. Assigning to the first argument of __new__
        # (the class) would make every ColourValue report the name/value of
        # the most recently created colour.
        instance.name = name
        instance.value = hexValue
        return instance

    def __str__(self) -> str:
        return self.value

    def __repr__(self) -> str:
        return self.name + " " + self.value + " " + str(self.rgba())

    def rgba(self) -> tuple[float, float, float, float]:
        """Return the colour as an (r, g, b, a) tuple of floats in [0, 1]."""
        return colors.to_rgba(self.value)

    def rgb(self) -> tuple[float, float, float]:
        """Return the colour as an (r, g, b) tuple of floats in [0, 1] (alpha dropped)."""
        return colors.to_rgb(self.value)

    def hex(self) -> str:
        """Return the stored hex string."""
        return self.value

    def hsv(self) -> np.ndarray:
        """Return the colour converted to HSV via ``matplotlib.colors.rgb_to_hsv``."""
        return colors.rgb_to_hsv(self.rgb())


# Define the UQ Colours in a nicer object
class ColourList(object):
    """A collection of named ``ColourValue`` objects.

    Each colour is reachable by item access (``UQ_COLOURS["light_purple"]``;
    spaces are accepted in place of underscores, e.g.
    ``UQ_COLOURS["light purple"]``) and as an attribute
    (``UQ_COLOURS.light_purple``).
    """

    def __init__(self, colours: dict) -> None:
        self.colours = {}
        for key, value in colours.items():
            self.colours[key] = ColourValue(key, value)
            setattr(self, key, self.colours[key])

    def __getitem__(self, key: str) -> str:
        # Prefer the underscored spelling when it names a known colour.
        underscored = key.replace(" ", "_")
        if underscored in self.colours:
            key = underscored
        return self.colours[key]

    def __contains__(self, key) -> bool:
        # Mirror __getitem__'s lookup rules so `name in colour_list` agrees with indexing.
        if not isinstance(key, str):
            return False
        return key in self.colours or key.replace(" ", "_") in self.colours

    def __iter__(self) -> Iterator[str]:
        """Iterate over the colour names."""
        return iter(self.colours)

    def __len__(self) -> int:
        return len(self.colours)

    def items(self):
        """Return (name, ColourValue) pairs, like ``dict.items``."""
        return self.colours.items()

    def __repr__(self) -> str:
        count = len(self.colours)
        return f"Colour List of {count} colour{'s' if count != 1 else ''}: " + str(list(self.colours.keys()))


UQ_COLOURS = ColourList(UQ_COLOURS_DICT)

# Load UQ Colours into MatPlotLib
# UQ colours are prefixed with 'uq:', so UQ red is 'uq:red'.
# Note: any name containing an underscore is also registered with spaces
# instead, so both "uq:light_purple" and "uq:light purple" work.
uq_colour_mapping = {'uq:' + name: value for name, value in UQ_COLOURS.items()}
uq_colour_mapping.update(
    {'uq:' + name.replace("_", " "): value for name, value in UQ_COLOURS.items() if "_" in name}
)
colors.get_named_colors_mapping().update(uq_colour_mapping)

## UQ Colour Cycler
# +-----------------------------+-----------------------------+
# |     | Default (Tab)         | UQ                          |
# +-----------------------------+-----------------------------+
# | C00 | #1f77b4 -> tab:blue   | #51247A -> uq:purple        |
# | C01 | #ff7f0e -> tab:orange | #4085C6 -> uq:blue          |
# | C02 | #2ca02c -> tab:green  | #2EA836 -> uq:green         |
# | C03 | #d62728 -> tab:red    | #E62645 -> uq:red           |
# | C04 | #9467bd -> tab:purple | #962A8B -> uq:light_purple  |
# | C05 | #8c564b -> tab:brown  | #999490 -> uq:dark_grey     |
# | C06 | #e377c2 -> tab:pink   | #EB602B -> uq:orange        |
# | C07 | #7f7f7f -> tab:grey   | #FBB800 -> uq:yellow        |
# | C08 | #bcbd22 -> tab:olive  | #00A2C7 -> uq:aqua          |
# | C09 | #17becf -> tab:cyan   | #BB9D65 -> uq:gold          |
# | C10 |                       | #D7D1CC -> uq:neutral       |
# +-----------------------------+-----------------------------+

# Build a colour cycler
uq_colour_cycler = cycler(color=[
    UQ_COLOURS["purple"],       #51247A -> C00 -> uq:purple
    UQ_COLOURS["blue"],         #4085C6 -> C01 -> uq:blue
    UQ_COLOURS["green"],        #2EA836 -> C02 -> uq:green
    UQ_COLOURS["red"],          #E62645 -> C03 -> uq:red
    UQ_COLOURS["light_purple"], #962A8B -> C04 -> uq:light_purple
    UQ_COLOURS["dark_grey"],    #999490 -> C05 -> uq:dark_grey
    UQ_COLOURS["orange"],       #EB602B -> C06 -> uq:orange
    UQ_COLOURS["yellow"],       #FBB800 -> C07 -> uq:yellow
    UQ_COLOURS["aqua"],         #00A2C7 -> C08 -> uq:aqua
    UQ_COLOURS["gold"],         #BB9D65 -> C09 -> uq:gold
    UQ_COLOURS["neutral"]       #D7D1CC -> C10 -> uq:neutral
])
# Tell MatPlotLib to use said cycler
plt.rc('axes', prop_cycle=uq_colour_cycler)

## UQ Colour Gradient (Not very good :( )
uq_colour_map_grad = [
    UQ_COLOURS["purple"],
    UQ_COLOURS["light_purple"],
    UQ_COLOURS["light_purple"],
    UQ_COLOURS["blue"],
    UQ_COLOURS["blue"],
    UQ_COLOURS["aqua"],
    UQ_COLOURS["green"],
    UQ_COLOURS["green"],
    UQ_COLOURS["green"],
    UQ_COLOURS["yellow"],
    UQ_COLOURS["yellow"]
]

# Convert to RGB values
uq_colour_map_grad = [colors.to_rgb(c) for c in uq_colour_map_grad]

# Populate the segment data: each channel gets an (offset, value, value) anchor
# per colour, evenly spaced over [0, 1], so the map blends through the list above.
uq_colour_dict = {
    "red": [],
    "green": [],
    "blue": [],
}
for i, c in enumerate(uq_colour_map_grad):
    offset = i / (len(uq_colour_map_grad) - 1)
    uq_colour_dict["red"].append((offset, c[0], c[0]))
    uq_colour_dict["green"].append((offset, c[1], c[1]))
    uq_colour_dict["blue"].append((offset, c[2], c[2]))

# Define & register the colour map itself as 'uq'. Registration is skipped if
# the name already exists (e.g. when this module is reloaded), because
# registering an existing name again raises a ValueError.
uq_cmap = colors.LinearSegmentedColormap('uq', segmentdata=uq_colour_dict)
if uq_cmap.name not in matplotlib.colormaps:
    matplotlib.colormaps.register(uq_cmap)

# Set the colour map - Not a very good default so not doing that
#plt.set_cmap("uq")


## Colorbar Function by Joseph Long & Mike Lampton
# Retrieved from https://joseph-long.com/writing/colorbars/ on 31/10/2021
# Minor Modifications made by Cal.W 2021
def colorbar(mappable, size="5%", pad=0.05, lsize=None, lpad=None, lax=True, **kwargs):
    """Attach a colorbar on the right of the axes that *mappable* is drawn on.

    The colorbar axes is appended to the parent axes with
    ``make_axes_locatable``. When *lax* is True, a blank, frameless spacer
    axes of width *lsize* (default: *size*) and padding *lpad* (default:
    *pad*) is also appended on the left, presumably to balance the layout.

    Extra keyword arguments are forwarded to ``Figure.colorbar``. The pyplot
    "current axes" is restored before returning.

    Returns:
        The colorbar object returned by ``Figure.colorbar``.
    """
    last_axes = plt.gca()
    ax = mappable.axes
    fig = ax.figure
    divider = make_axes_locatable(ax)

    if lax:
        lsize = lsize if lsize is not None else size
        lpad = lpad if lpad is not None else pad
        dax = divider.append_axes("left", size=lsize, pad=lpad)
        dax.set_frame_on(False)
        dax.grid(False)
        dax.set_yticks([])
        dax.set_xticks([])

    cax = divider.append_axes("right", size=size, pad=pad)
    cbar = fig.colorbar(mappable, cax=cax, **kwargs)
    plt.sca(last_axes)
    return cbar


def _getNested(pData, key, key2, default=None):
    """Return pData[key][key2] if both levels are present, else *default*."""
    return pData[key][key2] if key in pData and key2 in pData[key] else default


def _getLogNorm(pData):
    """Return a LogNorm from pData["norm"] = (vmin, vmax), or None if no "norm" key."""
    if "norm" not in pData:
        return None
    return colors.LogNorm(vmin=pData["norm"][0], vmax=pData["norm"][1])


def _meshAxisRange(value):
    """Turn a pColourMesh "x"/"y" spec into a ``numpy.arange``.

    A bare number ``n`` means ``(0, n)``; a ``(start, stop)`` pair uses
    numpy's default step; ``(start, stop, step)`` is used as given.
    """
    if isinstance(value, (int, float, np.number)):
        value = (0, value, None)
    value = tuple(value)
    if len(value) < 3:
        value = (value[0], value[1], None)
    return np.arange(value[0], value[1], value[2])


def _axHasContent(ax) -> bool:
    """Return True if *ax* holds any lines, collections, images, patches or texts."""
    return bool(ax.lines or ax.collections or ax.images or ax.patches or ax.texts)


def _drawPlot(fig, ax1, ax2, pData, xData, yData):
    """Draw one entry of a "plots" list; return the (xData, yData) now in effect.

    *xData*/*yData* are the values left by the previous entry. An entry that
    omits "x" or "y" keeps using them, and "axvLine"/"axhLine" entries
    overwrite them with (0, 1) when the span is not given.
    """
    label = pData.get("label")
    colour = pData.get("colour") or pData.get("color")  # Accept either spelling
    optArgs = pData.get("args", {})  # Extra kwargs forwarded to the matplotlib call

    # Every entry starts on the primary axes; only "y2" data goes on the twin axes.
    ax = ax1
    if "x" in pData:
        xData = pData["x"]
    if "y" in pData:
        yData = pData["y"]
    elif "y2" in pData:
        yData = pData["y2"]
        ax = ax2

    plotType = pData.get("type", "plot")

    if plotType == "plot":
        ax.plot(xData, yData, label=label, color=colour, **optArgs)

    elif plotType == "point":
        ax.scatter(xData, yData, marker=pData.get("marker"), label=label, color=colour,
                   zorder=pData.get("zorder", 2), **optArgs)

    elif plotType == "hLine":
        ax.hlines(yData, *xData, label=label, color=colour, **optArgs)

    elif plotType == "vLine":
        ax.vlines(xData, *yData, label=label, color=colour, **optArgs)

    elif plotType == "axvLine":
        if "y" not in pData:
            yData = (0, 1)  # Span the whole graph
        ax.axvline(xData, *yData, label=label, color=colour, **optArgs)

    elif plotType == "axhLine":
        if "x" not in pData:
            xData = (0, 1)  # Span the whole graph
        ax.axhline(yData, *xData, label=label, color=colour, **optArgs)

    elif plotType == "scatter":
        ax.scatter(xData, yData, marker=pData.get("marker"), label=label, color=colour, **optArgs)

    elif plotType == "contour":
        cs = ax.contour(pData.get("x"), pData.get("y"), pData["z"], levels=pData.get("levels"),
                        colors=colour, **optArgs)
        if "label" in pData:
            if hasattr(cs, "collections"):
                cs.collections[0].set_label(label)
            else:
                # Newer matplotlib versions no longer expose ContourSet.collections;
                # label the set itself instead of crashing.
                cs.set_label(label)

    elif plotType == "matshow":
        ms = ax.matshow(pData["matrix"], origin=pData.get("origin"), label=label, **optArgs)
        if "colourBar" in pData:
            colorbar(ms, extend=_getNested(pData, "colourBar", "extend"))

    elif plotType == "pColourMesh":
        if "x" in pData or "y" in pData:
            # Build the mesh from (start, stop[, step]) ranges.
            X, Y = np.meshgrid(_meshAxisRange(xData), _meshAxisRange(yData))
            mesh = [X, Y, pData["Z"]]
        elif "X" in pData or "Y" in pData:
            # Explicit mesh coordinates.
            mesh = [pData["X"], pData["Y"], pData["Z"]]
        else:
            mesh = [pData["Z"]]

        pcMesh = ax.pcolormesh(*mesh, norm=_getLogNorm(pData), shading=pData.get("shading"),
                               label=label, **optArgs)
        if "colourBar" in pData:
            cBarOptArgs = pData["colourBar"]["optArgs"] if "optArgs" in pData["colourBar"] else {}
            fig.colorbar(pcMesh, ax=ax, extend=_getNested(pData, "colourBar", "extend"), **cBarOptArgs)

    elif plotType == "imshow":
        ims = ax.imshow(pData["data"], norm=_getLogNorm(pData), origin=pData.get("origin"),
                        label=label, **optArgs)
        if "colourBar" in pData:
            cBarOptArgs = pData["colourBar"]["optArgs"] if "optArgs" in pData["colourBar"] else {}
            colorbar(ims, extend=_getNested(pData, "colourBar", "extend"), **cBarOptArgs)

    elif plotType == "text":
        if "props" in pData:
            props = pData["props"]
        else:
            props = {
                "boxstyle" : pData.get("boxstyle", "round"),
                "facecolor": pData.get("facecolor", pData.get("facecolour", "wheat")),
                "alpha"    : pData.get("alpha", 0.5)
            }
        align = pData.get("align", (pData.get("valign"), pData.get("halign")))
        # Only pass alignments that were given: matplotlib's alignment setters
        # accept only named alignments, so None must not be forwarded.
        alignArgs = {}
        if align[0] is not None:
            alignArgs["va"] = align[0]
        if align[1] is not None:
            alignArgs["ha"] = align[1]
        ax.text(pData.get("x", 0.05), pData.get("y", 0.95), pData["text"], transform=ax.transAxes,
                fontsize=pData.get("fontsize"), bbox=props, **alignArgs)

    return xData, yData


def _applyAxisOptions(ax1, ax2, axGraphData):
    """Apply one subplot's options (labels, positions, scales, limits, tick maps, legend) to its axes."""
    ax = ax1
    if "xLabel" in axGraphData: ax.set_xlabel(axGraphData["xLabel"])  # Add an x-label to the axes.
    if "yLabel" in axGraphData: ax.set_ylabel(axGraphData["yLabel"])  # Add a y-label to the axes.
    # A figure-level "y2Label" is copied into every subplot, so only apply it
    # where a twin axis actually exists.
    if "y2Label" in axGraphData and ax2 is not None: ax2.set_ylabel(axGraphData["y2Label"])
    if "title" in axGraphData: ax.set_title(axGraphData["title"])  # Add a title to the axes.
    if "axis" in axGraphData: ax.axis(axGraphData["axis"])  # Set the axis type
    if "grid" in axGraphData: ax.grid(axGraphData["grid"])  # Add grids to the graph

    if "xPos" in axGraphData:  # Move the x axis label and ticks together
        ax.xaxis.set_label_position(axGraphData["xPos"])
        ax.xaxis.set_ticks_position(axGraphData["xPos"])
    if "yPos" in axGraphData:  # Move the y axis label and ticks together
        ax.yaxis.set_label_position(axGraphData["yPos"])
        ax.yaxis.set_ticks_position(axGraphData["yPos"])

    if "xLabelPos" in axGraphData: ax.xaxis.set_label_position(axGraphData["xLabelPos"])
    if "yLabelPos" in axGraphData: ax.yaxis.set_label_position(axGraphData["yLabelPos"])
    if "xTickPos" in axGraphData: ax.xaxis.set_ticks_position(axGraphData["xTickPos"])
    if "yTickPos" in axGraphData: ax.yaxis.set_ticks_position(axGraphData["yTickPos"])

    if "xScale" in axGraphData: ax.set_xscale(axGraphData["xScale"])
    if "yScale" in axGraphData: ax.set_yscale(axGraphData["yScale"])

    # A bare number limit n means (0, n).
    if "xLim" in axGraphData:
        xLimit = axGraphData["xLim"]
        ax.set_xlim((0, xLimit) if isinstance(xLimit, (int, float)) else xLimit)
    if "yLim" in axGraphData:
        yLimit = axGraphData["yLim"]
        ax.set_ylim((0, yLimit) if isinstance(yLimit, (int, float)) else yLimit)

    # Tick maps transform the displayed tick values. Bind the callables to
    # locals so each subplot's formatter uses its own map at draw time.
    if "xTickMap" in axGraphData:
        xTickMap = axGraphData["xTickMap"]
        ax.xaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda x, pos: '{0:g}'.format(xTickMap(x))))
    if "yTickMap" in axGraphData:
        yTickMap = axGraphData["yTickMap"]
        ax.yaxis.set_major_formatter(matplotlib.ticker.FuncFormatter(lambda y, pos: '{0:g}'.format(yTickMap(y))))

    # Build a single legend covering both the primary and the twin axes.
    if any("label" in pData for pData in axGraphData.get("plots", [])):
        locPoint = axGraphData.get("ledgLoc")
        lines1, labels1 = ax1.get_legend_handles_labels()
        if ax2:
            lines2, labels2 = ax2.get_legend_handles_labels()
            ax2.legend(lines1 + lines2, labels1 + labels2, loc=locPoint)
        else:
            ax1.legend(lines1, labels1, loc=locPoint)

    if "ticklabel" in axGraphData:
        tickLabel = axGraphData["ticklabel"]
        style = tickLabel["style"] if "style" in tickLabel else ""
        axis = tickLabel["axis"] if "axis" in tickLabel else "both"
        limits = tickLabel["limits"] if "limits" in tickLabel else None
        optArgs = tickLabel["optArgs"] if "optArgs" in tickLabel else {}
        ax.ticklabel_format(axis=axis, style=style, scilimits=limits, **optArgs)


## Make Graph Function
def makeGraph(graphData, showPlot=True, doProgramBlock=True, figSavePath=None, hideEmptyAxis=False) -> tuple[matplotlib.figure.Figure, tuple[matplotlib.axes.Axes, ...]]:
    """
    Generate a matplotlib graph based on a simple dictionary object

    Input:
        dict(graphData): The dictionary containing all the graph data - see example for more info
        bool(showPlot[True]): Should the function display the plot
        bool(doProgramBlock[True]): Should the function block the main python thread
        str(figSavePath[None]): The path to save a copy of the figure, calls fig.savefig if not None
        bool(hideEmptyAxis[False]): Turn off (set_axis_off) every subplot with nothing
            drawn on it, and every axes in the grid that has no entry in "subPlots"

    Figure-level keys of graphData:
        "subPlots": list of per-axes dicts. If absent, graphData itself describes one axes.
        "plotDim":  subplot grid shape (only used with "subPlots"; default (1, len(subPlots)))
        "figSize":  passed to plt.subplots(figsize=...)
        "figTitle", "figTitleFontSize": figure suptitle and its font size
        Per-axes keys set on graphData ("xLabel", "yLabel", "title", "plots", ...) are
        copied into (i.e. written onto) each subplot dict that does not set them itself.

    Returns:
        The the figure and axes from matplotlib.pyplot.subplots()
        From 'matplotlib.pyplot.subplots():
            fig : `matplotlib.figure.Figure`
            ax : `matplotlib.axes.Axes` or array of Axes
                *ax* can be either a single `~matplotlib.axes.Axes` object or an
                array of Axes objects if more than one subplot was created.

    Example:
        makeGraph({
            "title": "Simple Plot",
            "xLabel": "x label",
            "yLabel": "y label",
            "plots": [
                {"x":[0,1,2,3,4], "y":[0,1,2,3,4], "label":"Linear"},
                {"x":[0,1,2,3,4], "y":[5,5,5,5,5]},
                {"x":[4,3,2,1,0], "y":[4,3,2,1,0], "label":"Linear2"},
                {"x":0, "type":"axvLine", "label":"Red Vertical Line", "color":"red"},
                {"y":6, "type":"axhLine", "label":"Dashed Horizontal Line", "args":{"linestyle":"--"}},
                {"type":"scatter", "x":4, "y":4, "label":"A Random Point", "colour":"purple", "args":{"zorder":2}}
            ]
        })
    """
    # Per-axes keys that may be given once at figure level and inherited by every subplot.
    loopKeys = [
        "xLabel", "yLabel", "title", "axis", "grid", "xPos", "yPos", "xLabelPos", "yLabelPos",
        "xTickPos", "yTickPos", "xScale", "yScale", "xTickMap", "yTickMap", "plots", "xLim",
        "yLim", "ledgLoc", "y2Label", "ticklabel"
    ]

    if "subPlots" in graphData:
        subPlots = graphData["subPlots"]
        plotDim = graphData["plotDim"] if "plotDim" in graphData else (1, len(subPlots))
        for key in loopKeys:
            if key in graphData:
                for axGraphData in subPlots:
                    axGraphData.setdefault(key, graphData[key])
    else:
        # A single plot: the whole dict describes the only axes.
        subPlots = [graphData]
        plotDim = (1,)

    figSize = graphData["figSize"] if "figSize" in graphData else None
    fig, axes = plt.subplots(*plotDim, figsize=figSize)  # Create a figure and an axes.

    # Flatten the axes grid so subplot i maps to flatAxes[i], e.g. ((1,2),(3,4)) -> (1,2,3,4)
    flatAxes = np.array(axes).flatten().tolist()

    # Plot entries that omit "x"/"y" reuse the previous entry's values (also across subplots).
    xData = yData = None
    for i, axGraphData in enumerate(subPlots):
        ax1 = flatAxes[i]
        plots = axGraphData.get("plots", [])
        # Only create a twin y axis when some entry targets it via "y2".
        ax2 = ax1.twinx() if any("y2" in pData for pData in plots) else None

        for pData in plots:
            xData, yData = _drawPlot(fig, ax1, ax2, pData, xData, yData)

        _applyAxisOptions(ax1, ax2, axGraphData)

        if hideEmptyAxis and not _axHasContent(ax1) and (ax2 is None or not _axHasContent(ax2)):
            ax1.set_axis_off()

    if "title" in graphData and "figTitle" not in graphData:
        fig.canvas.manager.set_window_title(graphData["title"].replace("\n", " "))

    # Set the figure title correctly
    if "figTitle" in graphData:
        fig.suptitle(graphData["figTitle"], fontsize=graphData.get("figTitleFontSize"))
        fig.canvas.manager.set_window_title(graphData["figTitle"].replace("\n", " "))

    fig.tight_layout()  # Fix labels being cut off sometimes

    # Grid cells without a matching subplot entry are never drawn on; hide them.
    if hideEmptyAxis:
        for unusedAx in flatAxes[len(subPlots):]:
            unusedAx.set_axis_off()

    if figSavePath:
        fig.savefig(figSavePath)

    if showPlot:
        plt.show(block=doProgramBlock)  # Show the plot and optionally block the program

    # Doing things OO style allows for more flexible programs
    return fig, axes


# [TODO] Make this Async so the closure of all graphs exits
def pltKeyClose():
    '''Show all plots and wait for user input, then close every open figure.'''
    plt.show(block=False)
    input('Press Enter to close all graphs...')  # input() only returns on Enter
    plt.close('all')


if __name__ == '__main__':
    # This is an example of drawing 4 plots by generating them
    graphData = {
        "figTitle": "Simple Plot",
        "figTitleFontSize": 16,
        "figSize": (8,8),  # Yay America, this is in inches :/ # Note: cm = 1/2.54
        "xLabel": "x label",
        "yLabel": "y label",
        "plotDim": (2,2),
        "subPlots":[]
    }

    # Create 4 identical plots with different names
    for i in range(4):
        newPlot = {
            "title": f"Graph {i+1}",
            "plots": [
                {"x":[0,1,2,3,4], "y":[0,1,2,3,4], "label":"Linear"},
                {"x":[0,1,2,3,4], "y":[5,5,5,5,5]},
                {"x":[4,3,2,1,0], "y":[4,3,2,1,0], "label":"Linear2"},
                {"x":0, "type":"axvLine", "label":"Red Vertical Line", "color":"uq:red"},
                {"y":6, "type":"axhLine", "label":"Dashed Horizontal Line", "args":{"linestyle":"--"}},
                {"type":"point", "x":4, "y":4, "label":"A Random Point", "colour":"uq:purple"}
            ]
        }
        graphData["subPlots"].append(newPlot)

    makeGraph(graphData)