diff --git a/doc/api/next_api_changes/2018-11-14-AL-scatter.rst b/doc/api/next_api_changes/2018-11-14-AL-scatter.rst new file mode 100644 index 000000000000..97c6404dfc95 --- /dev/null +++ b/doc/api/next_api_changes/2018-11-14-AL-scatter.rst @@ -0,0 +1,14 @@ +PathCollections created with `~.Axes.scatter` now keep track of invalid points +`````````````````````````````````````````````````````````````````````````````` + +Previously, points with nonfinite (infinite or nan) coordinates would not be +included in the offsets (as returned by `PathCollection.get_offsets`) of a +`PathCollection` created by `~.Axes.scatter`, and points with nonfinite values +(as specified by the *c* kwarg) would not be included in the array (as returned +by `PathCollection.get_array`) + +Such points are now included, but masked out by returning a masked array. + +If the *plotnonfinite* kwarg to `~.Axes.scatter` is set, then points with +nonfinite values are plotted using the bad color of the `PathCollection`\ 's +colormap (as set by `Colormap.set_bad`). diff --git a/examples/units/basic_units.py b/examples/units/basic_units.py index 4f8b514e0de5..49eb823ffbe2 100644 --- a/examples/units/basic_units.py +++ b/examples/units/basic_units.py @@ -174,7 +174,10 @@ def get_compressed_copy(self, mask): def convert_to(self, unit): if unit == self.unit or not unit: return self - new_value = self.unit.convert_value_to(self.value, unit) + try: + new_value = self.unit.convert_value_to(self.value, unit) + except AttributeError: + new_value = self return TaggedValue(new_value, unit) def get_value(self): @@ -345,7 +348,20 @@ def convert(val, unit, axis): if units.ConversionInterface.is_numlike(val): return val if np.iterable(val): - return [thisval.convert_to(unit).get_value() for thisval in val] + if isinstance(val, np.ma.MaskedArray): + val = val.astype(float).filled(np.nan) + out = np.empty(len(val)) + for i, thisval in enumerate(val): + if np.ma.is_masked(thisval): + out[i] = np.nan + else: + try: + out[i] = thisval.convert_to(unit).get_value() + except AttributeError: + out[i] = thisval + return out + if np.ma.is_masked(val): + return np.nan else: return val.convert_to(unit).get_value() diff --git a/examples/units/units_scatter.py b/examples/units/units_scatter.py index 7850adef7e33..095065815f4a 100644 --- a/examples/units/units_scatter.py +++ b/examples/units/units_scatter.py @@ -27,9 +27,8 @@ ax2.scatter(xsecs, xsecs, yunits=hertz) ax2.axis([0, 10, 0, 1]) -ax3.scatter(xsecs, xsecs, yunits=hertz) -ax3.yaxis.set_units(minutes) -ax3.axis([0, 10, 0, 1]) +ax3.scatter(xsecs, xsecs, yunits=minutes) +ax3.axis([0, 10, 0, 0.2]) fig.tight_layout() plt.show() diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 8bb104487702..f7ac8a9d7d66 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -4180,7 +4180,7 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape, label_namer="y") def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, - verts=None, edgecolors=None, + verts=None, edgecolors=None, *, plotnonfinite=False, **kwargs): """ A scatter plot of *y* vs *x* with varying marker size and/or color. @@ -4257,6 +4257,10 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, For non-filled markers, the *edgecolors* kwarg is ignored and forced to 'face' internally. + plotnonfinite : boolean, optional, default: False + Set to plot points with nonfinite *c*, in conjunction with + `~matplotlib.colors.Colormap.set_bad`. + Returns ------- paths : `~matplotlib.collections.PathCollection` @@ -4310,11 +4314,14 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, c, edgecolors, kwargs, xshape, yshape, get_next_color_func=self._get_patches_for_fill.get_next_color) - # `delete_masked_points` only modifies arguments of the same length as - # `x`. - x, y, s, c, colors, edgecolors, linewidths =\ - cbook.delete_masked_points( - x, y, s, c, colors, edgecolors, linewidths) + if plotnonfinite and colors is None: + c = np.ma.masked_invalid(c) + x, y, s, edgecolors, linewidths = \ + cbook._combine_masks(x, y, s, edgecolors, linewidths) + else: + x, y, s, c, colors, edgecolors, linewidths = \ + cbook._combine_masks( + x, y, s, c, colors, edgecolors, linewidths) scales = s # Renamed for readability below. @@ -4340,7 +4347,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, edgecolors = 'face' linewidths = rcParams['lines.linewidth'] - offsets = np.column_stack([x, y]) + offsets = np.ma.column_stack([x, y]) collection = mcoll.PathCollection( (path,), scales, @@ -4358,7 +4365,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None, if norm is not None and not isinstance(norm, mcolors.Normalize): raise ValueError( "'norm' must be an instance of 'mcolors.Normalize'") - collection.set_array(np.asarray(c)) + collection.set_array(c) collection.set_cmap(cmap) collection.set_norm(norm) diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py index c02a3c455263..8f7e7398e4a2 100644 --- a/lib/matplotlib/cbook/__init__.py +++ b/lib/matplotlib/cbook/__init__.py @@ -1081,6 +1081,66 @@ def delete_masked_points(*args): return margs +def _combine_masks(*args): + """ + Find all masked and/or non-finite points in a set of arguments, + and return the arguments as masked arrays with a common mask. + + Arguments can be in any of 5 categories: + + 1) 1-D masked arrays + 2) 1-D ndarrays + 3) ndarrays with more than one dimension + 4) other non-string iterables + 5) anything else + + The first argument must be in one of the first four categories; + any argument with a length differing from that of the first + argument (and hence anything in category 5) then will be + passed through unchanged. + + Masks are obtained from all arguments of the correct length + in categories 1, 2, and 4; a point is bad if masked in a masked + array or if it is a nan or inf. No attempt is made to + extract a mask from categories 2 and 4 if :meth:`np.isfinite` + does not yield a Boolean array. Category 3 is included to + support RGB or RGBA ndarrays, which are assumed to have only + valid values and which are passed through unchanged. + + All input arguments that are not passed unchanged are returned + as masked arrays if any masked points are found, otherwise as + ndarrays. + + """ + if not len(args): + return () + if is_scalar_or_string(args[0]): + raise ValueError("First argument must be a sequence") + nrecs = len(args[0]) + margs = [] # Output args; some may be modified. + seqlist = [False] * len(args) # Flags: True if output will be masked. + masks = [] # List of masks. + for i, x in enumerate(args): + if is_scalar_or_string(x) or len(x) != nrecs: + margs.append(x) # Leave it unmodified. + else: + if isinstance(x, np.ma.MaskedArray) and x.ndim > 1: + raise ValueError("Masked arrays must be 1-D") + x = np.asanyarray(x) + if x.ndim == 1: + x = safe_masked_invalid(x) + seqlist[i] = True + if np.ma.is_masked(x): + masks.append(np.ma.getmaskarray(x)) + margs.append(x) # Possibly modified. + if len(masks): + mask = np.logical_or.reduce(masks) + for i, x in enumerate(margs): + if seqlist[i]: + margs[i] = np.ma.array(x, mask=mask) + return margs + + def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None, autorange=False): """ diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 92bbf6c7b112..73c43302cbce 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -2835,12 +2835,13 @@ def quiverkey(Q, X, Y, U, label, **kw): def scatter( x, y, s=None, c=None, marker=None, cmap=None, norm=None, vmin=None, vmax=None, alpha=None, linewidths=None, verts=None, - edgecolors=None, *, data=None, **kwargs): + edgecolors=None, *, plotnonfinite=False, data=None, **kwargs): __ret = gca().scatter( x, y, s=s, c=c, marker=marker, cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, alpha=alpha, linewidths=linewidths, - verts=verts, edgecolors=edgecolors, **({"data": data} if data - is not None else {}), **kwargs) + verts=verts, edgecolors=edgecolors, + plotnonfinite=plotnonfinite, **({"data": data} if data is not + None else {}), **kwargs) sci(__ret) return __ret diff --git a/lib/matplotlib/testing/decorators.py b/lib/matplotlib/testing/decorators.py index 989cc566bf5f..548a4b54909b 100644 --- a/lib/matplotlib/testing/decorators.py +++ b/lib/matplotlib/testing/decorators.py @@ -452,19 +452,37 @@ def decorator(func): _, result_dir = map(Path, _image_directories(func)) - @pytest.mark.parametrize("ext", extensions) - def wrapper(ext): - fig_test = plt.figure("test") - fig_ref = plt.figure("reference") - func(fig_test, fig_ref) - test_image_path = str( - result_dir / (func.__name__ + "." + ext)) - ref_image_path = str( - result_dir / (func.__name__ + "-expected." + ext)) - fig_test.savefig(test_image_path) - fig_ref.savefig(ref_image_path) - _raise_on_image_difference( - ref_image_path, test_image_path, tol=tol) + if len(inspect.signature(func).parameters) == 2: + # Free-standing function. + @pytest.mark.parametrize("ext", extensions) + def wrapper(ext): + fig_test = plt.figure("test") + fig_ref = plt.figure("reference") + func(fig_test, fig_ref) + test_image_path = str( + result_dir / (func.__name__ + "." + ext)) + ref_image_path = str( + result_dir / (func.__name__ + "-expected." + ext)) + fig_test.savefig(test_image_path) + fig_ref.savefig(ref_image_path) + _raise_on_image_difference( + ref_image_path, test_image_path, tol=tol) + + elif len(inspect.signature(func).parameters) == 3: + # Method. + @pytest.mark.parametrize("ext", extensions) + def wrapper(self, ext): + fig_test = plt.figure("test") + fig_ref = plt.figure("reference") + func(self, fig_test, fig_ref) + test_image_path = str( + result_dir / (func.__name__ + "." + ext)) + ref_image_path = str( + result_dir / (func.__name__ + "-expected." + ext)) + fig_test.savefig(test_image_path) + fig_ref.savefig(ref_image_path) + _raise_on_image_difference( + ref_image_path, test_image_path, tol=tol) return wrapper diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index f386e2f03881..ed7e0ac8168e 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -1749,6 +1749,34 @@ def test_scatter_color(self): with pytest.raises(ValueError): plt.scatter([1, 2, 3], [1, 2, 3], color=[1, 2, 3]) + @check_figures_equal(extensions=["png"]) + def test_scatter_invalid_color(self, fig_test, fig_ref): + ax = fig_test.subplots() + cmap = plt.get_cmap("viridis", 16) + cmap.set_bad("k", 1) + # Set a nonuniform size to prevent the last call to `scatter` (plotting + # the invalid points separately in fig_ref) from using the marker + # stamping fast path, which would result in slightly offset markers. + ax.scatter(range(4), range(4), + c=[1, np.nan, 2, np.nan], s=[1, 2, 3, 4], + cmap=cmap, plotnonfinite=True) + ax = fig_ref.subplots() + cmap = plt.get_cmap("viridis", 16) + ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap) + ax.scatter([1, 3], [1, 3], s=[2, 4], color="k") + + @check_figures_equal(extensions=["png"]) + def test_scatter_no_invalid_color(self, fig_test, fig_ref): + # With plotninfinite=False we plot only 2 points. + ax = fig_test.subplots() + cmap = plt.get_cmap("viridis", 16) + cmap.set_bad("k", 1) + ax.scatter(range(4), range(4), + c=[1, np.nan, 2, np.nan], s=[1, 2, 3, 4], + cmap=cmap, plotnonfinite=False) + ax = fig_ref.subplots() + ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap) + # Parameters for *test_scatter_c*. NB: assuming that the # scatter plot will have 4 elements. The tuple scheme is: # (*c* parameter case, exception regexp key or None if no exception) @@ -5743,21 +5771,6 @@ def test_color_length_mismatch(): ax.scatter(x, y, c=[c_rgb] * N) -def test_scatter_color_masking(): - x = np.array([1, 2, 3]) - y = np.array([1, np.nan, 3]) - colors = np.array(['k', 'w', 'k']) - linewidths = np.array([1, 2, 3]) - s = plt.scatter(x, y, color=colors, linewidths=linewidths) - - facecolors = s.get_facecolors() - linecolors = s.get_edgecolors() - linewidths = s.get_linewidths() - assert_array_equal(facecolors[1], np.array([0, 0, 0, 1])) - assert_array_equal(linecolors[1], np.array([0, 0, 0, 1])) - assert linewidths[1] == 3 - - def test_eventplot_legend(): plt.eventplot([1.0], label='Label') plt.legend() diff --git a/lib/matplotlib/tests/test_colorbar.py b/lib/matplotlib/tests/test_colorbar.py index 1c358c09ae78..f8d43f8e942f 100644 --- a/lib/matplotlib/tests/test_colorbar.py +++ b/lib/matplotlib/tests/test_colorbar.py @@ -197,9 +197,8 @@ def test_colorbar_single_scatter(): # the norm scaling within the colorbar must ensure a # finite range, otherwise a zero denominator will occur in _locate. plt.figure() - x = np.arange(4) - y = x.copy() - z = np.ma.masked_greater(np.arange(50, 54), 50) + x = y = [0] + z = [50] cmap = plt.get_cmap('jet', 16) cs = plt.scatter(x, y, z, c=z, cmap=cmap) plt.colorbar(cs)