diff --git a/doc/release/next_whats_new/3d_scales.rst b/doc/release/next_whats_new/3d_scales.rst new file mode 100644 index 000000000000..e92b8e0abb61 --- /dev/null +++ b/doc/release/next_whats_new/3d_scales.rst @@ -0,0 +1,35 @@ +Non-linear scales on 3D axes +---------------------------- + +Resolving a long-standing issue, 3D axes now support non-linear axis scales +such as 'log', 'symlog', 'logit', 'asinh', and custom 'function' scales, just +like 2D axes. Use `~.Axes3D.set_xscale`, `~.Axes3D.set_yscale`, and +`~.Axes3D.set_zscale` to set the scale for each axis independently. + +.. plot:: + :include-source: true + :alt: A 3D plot with a linear x-axis, logarithmic y-axis, and symlog z-axis. + + import matplotlib.pyplot as plt + import numpy as np + + # A sine chirp with increasing frequency and amplitude + x = np.linspace(0, 1, 400) # time + y = 10 ** (2 * x) # frequency, growing exponentially from 1 to 100 Hz + phase = 2 * np.pi * (10 ** (2 * x) - 1) / (2 * np.log(10)) + z = np.sin(phase) * x ** 2 * 10 # amplitude, growing quadratically + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.plot(x, y, z) + + ax.set_xlabel('Time (linear)') + ax.set_ylabel('Frequency, Hz (log)') + ax.set_zlabel('Amplitude (symlog)') + + ax.set_yscale('log') + ax.set_zscale('symlog') + + plt.show() + +See `matplotlib.scale` for details on all available scales and their parameters. diff --git a/galleries/examples/mplot3d/scales3d.py b/galleries/examples/mplot3d/scales3d.py new file mode 100644 index 000000000000..c0da42da85bd --- /dev/null +++ b/galleries/examples/mplot3d/scales3d.py @@ -0,0 +1,52 @@ +""" +================================ +Scales on 3D (Log, Symlog, etc.) +================================ + +Demonstrate how to use non-linear scales such as logarithmic scales on 3D axes. + +3D axes support the same axis scales as 2D plots: 'linear', 'log', 'symlog', +'logit', 'asinh', and custom 'function' scales. This example shows a mix of +scales: linear on X, log on Y, and symlog on Z. + +For a complete list of built-in scales, see `matplotlib.scale`. For an overview +of scale transformations, see :doc:`/gallery/scales/scales`. +""" + +import matplotlib.pyplot as plt +import numpy as np + +# A sine chirp with increasing frequency and amplitude +x = np.linspace(0, 1, 400) # time +y = 10 ** (2 * x) # frequency, growing exponentially from 1 to 100 Hz +phase = 2 * np.pi * (10 ** (2 * x) - 1) / (2 * np.log(10)) +z = np.sin(phase) * x **2 * 10 # amplitude, growing quadratically + +fig = plt.figure() +ax = fig.add_subplot(projection='3d') +ax.plot(x, y, z) + +ax.set_xlabel('Time (linear)') +ax.set_ylabel('Frequency, Hz (log)') +ax.set_zlabel('Amplitude (symlog)') + +ax.set_yscale('log') +ax.set_zscale('symlog') + +plt.show() + +# %% +# +# .. admonition:: References +# +# The use of the following functions, methods, classes and modules is shown +# in this example: +# +# - `mpl_toolkits.mplot3d.axes3d.Axes3D.set_xscale` +# - `mpl_toolkits.mplot3d.axes3d.Axes3D.set_yscale` +# - `mpl_toolkits.mplot3d.axes3d.Axes3D.set_zscale` +# - `matplotlib.scale` +# +# .. tags:: +# plot-type: 3D, +# level: beginner diff --git a/galleries/examples/scales/scales.py b/galleries/examples/scales/scales.py index 6c4556c9c1d3..071b564b23b1 100644 --- a/galleries/examples/scales/scales.py +++ b/galleries/examples/scales/scales.py @@ -5,8 +5,9 @@ Illustrate the scale transformations applied to axes, e.g. log, symlog, logit. -See `matplotlib.scale` for a full list of built-in scales, and -:doc:`/gallery/scales/custom_scale` for how to create your own scale. +See `matplotlib.scale` for a full list of built-in scales, +:doc:`/gallery/scales/custom_scale` for how to create your own scale, and +:doc:`/gallery/mplot3d/scales3d` for using scales on 3D axes. """ import matplotlib.pyplot as plt diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index d06d157db4ce..6898a8aaf4cf 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -79,7 +79,7 @@ def _viewlim_mask(xs, ys, zs, axes): Parameters ---------- xs, ys, zs : array-like - The points to mask. + The points to mask. These should be in data coordinates. axes : Axes3D The axes to use for the view limits. @@ -198,7 +198,10 @@ def draw(self, renderer): else: pos3d = np.array([self._x, self._y, self._z], dtype=float) - proj = proj3d._proj_trans_points([pos3d, pos3d + self._dir_vec], self.axes.M) + dir_end = pos3d + self._dir_vec + points = np.asarray([pos3d, dir_end]) + proj = proj3d._scale_proj_transform( + points[:, 0], points[:, 1], points[:, 2], self.axes) dx = proj[0][1] - proj[0][0] dy = proj[1][1] - proj[1][0] angle = math.degrees(math.atan2(dy, dx)) @@ -334,9 +337,7 @@ def draw(self, renderer): dtype=float, mask=mask).filled(np.nan) else: xs3d, ys3d, zs3d = self._verts3d - xs, ys, zs, tis = proj3d._proj_transform_clip(xs3d, ys3d, zs3d, - self.axes.M, - self.axes._focal_length) + xs, ys, zs, tis = proj3d._scale_proj_transform_clip(xs3d, ys3d, zs3d, self.axes) self.set_data(xs, ys) super().draw(renderer) self.stale = False @@ -427,7 +428,8 @@ def do_3d_projection(self): vs_list = [np.ma.array(vs, mask=np.broadcast_to( _viewlim_mask(*vs.T, self.axes), vs.shape)) for vs in vs_list] - xyzs_list = [proj3d.proj_transform(*vs.T, self.axes.M) for vs in vs_list] + xyzs_list = [proj3d._scale_proj_transform( + vs[:, 0], vs[:, 1], vs[:, 2], self.axes) for vs in vs_list] self._paths = [mpath.Path(np.ma.column_stack([xs, ys]), cs) for (xs, ys, _), (_, cs) in zip(xyzs_list, self._3dverts_codes)] zs = np.concatenate([zs for _, _, zs in xyzs_list]) @@ -497,6 +499,11 @@ def do_3d_projection(self): """ segments = np.asanyarray(self._segments3d) + # Handle empty segments + if segments.size == 0: + LineCollection.set_segments(self, []) + return np.nan + mask = False if np.ma.isMA(segments): mask = segments.mask @@ -511,8 +518,9 @@ def do_3d_projection(self): viewlim_mask = np.broadcast_to(viewlim_mask[..., np.newaxis], (*viewlim_mask.shape, 3)) mask = mask | viewlim_mask - xyzs = np.ma.array(proj3d._proj_transform_vectors(segments, self.axes.M), - mask=mask) + + xyzs = np.ma.array( + proj3d._scale_proj_transform_vectors(segments, self.axes), mask=mask) segments_2d = xyzs[..., 0:2] LineCollection.set_segments(self, segments_2d) @@ -595,9 +603,7 @@ def do_3d_projection(self): dtype=float, mask=mask).filled(np.nan) else: xs, ys, zs = zip(*s) - vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs, - self.axes.M, - self.axes._focal_length) + vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) self._path2d = mpath.Path(np.ma.column_stack([vxs, vys])) return min(vzs) @@ -657,9 +663,7 @@ def do_3d_projection(self): dtype=float, mask=mask).filled(np.nan) else: xs, ys, zs = zip(*s) - vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs, - self.axes.M, - self.axes._focal_length) + vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) self._path2d = mpath.Path(np.ma.column_stack([vxs, vys]), self._code3d) return min(vzs) @@ -802,9 +806,7 @@ def do_3d_projection(self): xs, ys, zs = np.ma.array(self._offsets3d, mask=mask) else: xs, ys, zs = self._offsets3d - vxs, vys, vzs, vis = proj3d._proj_transform_clip(xs, ys, zs, - self.axes.M, - self.axes._focal_length) + vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(xs, ys, zs, self.axes) self._vzs = vzs if np.ma.isMA(vxs): super().set_offsets(np.ma.column_stack([vxs, vys])) @@ -1020,9 +1022,7 @@ def do_3d_projection(self): xyzs = np.ma.array(self._offsets3d, mask=mask) else: xyzs = self._offsets3d - vxs, vys, vzs, vis = proj3d._proj_transform_clip(*xyzs, - self.axes.M, - self.axes._focal_length) + vxs, vys, vzs, vis = proj3d._scale_proj_transform_clip(*xyzs, self.axes) self._data_scale = _get_data_scale(vxs, vys, vzs) # Sort the points based on z coordinates # Performance optimization: Create a sorted index array and reorder @@ -1356,7 +1356,7 @@ def do_3d_projection(self): # Some faces might contain masked vertices, so we want to ignore any # errors that those might cause with np.errstate(invalid='ignore', divide='ignore'): - pfaces = proj3d._proj_transform_vectors(self._faces, self.axes.M) + pfaces = proj3d._scale_proj_transform_vectors(self._faces, self.axes) if self._axlim_clip: viewlim_mask = _viewlim_mask(self._faces[..., 0], self._faces[..., 1], diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index b9323897c4d3..c6e28ec71b78 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -8,6 +8,26 @@ Module containing Axes3D, an object which can plot 3D objects on a 2D matplotlib figure. + +Coordinate Systems +------------------ +3D plotting involves several coordinate transformations: + +1. **Data coordinates**: The user's raw x, y, z values. + +2. **Transformed coordinates**: Data coordinates after applying axis scale + transforms (log, symlog, etc.). For linear scales, these equal data + coordinates. Zoom/pan operations work in this space to ensure uniform + behavior with non-linear scales. + +3. **Normalized coordinates**: Transformed coordinates mapped to a [0, 1] + unit cube based on the current axis limits. + +4. **Projected coordinates**: 2D coordinates after applying the 3D to 2D + projection matrix, ready for display. + +Artists receive data in data coordinates, apply scale transforms internally +via ``do_3d_projection()``, then project to 2D for rendering. """ from collections import defaultdict @@ -233,8 +253,16 @@ def get_zaxis(self): get_zticklines = _axis_method_wrapper("zaxis", "get_ticklines") def _transformed_cube(self, vals): - """Return cube with limits from *vals* transformed by self.M.""" + """Return cube with limits from *vals* transformed by self.M. + + The vals are in data space and are first transformed through the + axis scale transforms before being projected. + """ minx, maxx, miny, maxy, minz, maxz = vals + # Transform from data space to transformed coordinates + minx, maxx = self.xaxis.get_transform().transform([minx, maxx]) + miny, maxy = self.yaxis.get_transform().transform([miny, maxy]) + minz, maxz = self.zaxis.get_transform().transform([minz, maxz]) xyzs = [(minx, miny, minz), (maxx, miny, minz), (maxx, maxy, minz), @@ -243,7 +271,22 @@ def _transformed_cube(self, vals): (maxx, miny, maxz), (maxx, maxy, maxz), (minx, maxy, maxz)] - return proj3d._proj_points(xyzs, self.M) + return np.column_stack(proj3d._proj_trans_points(xyzs, self.M)) + + def _update_transScale(self): + """ + Override transScale to always use identity transforms. + + In 2D axes, transScale applies scale transforms (log, symlog, etc.) to + convert data coordinates to display coordinates. In 3D axes, scale + transforms are applied to data coordinates before 3D projection via + each axis's transform. The projected 2D coordinates are already in + display space, so transScale must be identity to avoid double-scaling. + """ + self.transScale.set( + mtransforms.blended_transform_factory( + mtransforms.IdentityTransform(), + mtransforms.IdentityTransform())) @_api.delete_parameter("3.11", "share") @_api.delete_parameter("3.11", "anchor") @@ -603,6 +646,42 @@ def auto_scale_xyz(self, X, Y, Z=None, had_data=None): # Let autoscale_view figure out how to use this data. self.autoscale_view() + def _autoscale_axis(self, axis, v0, v1, minpos, margin, set_bound, _tight): + """ + Autoscale a single axis. + + Parameters + ---------- + axis : Axis + The axis to autoscale. + v0, v1 : float + Data interval limits. + minpos : float + Minimum positive value for log-scale handling. + margin : float + Margin to apply (e.g., self._xmargin). + set_bound : callable + Function to set the axis bound (e.g., self.set_xbound). + _tight : bool + Whether to use tight bounds. + """ + locator = axis.get_major_locator() + v0, v1 = locator.nonsingular(v0, v1) + # Validate limits for the scale (e.g., positive for log scale) + v0, v1 = axis._scale.limit_range_for_scale(v0, v1, minpos) + if margin > 0: + # Apply margin in transformed space to handle non-linear scales + transform = axis.get_transform() + inverse_trans = transform.inverted() + v0t, v1t = transform.transform([v0, v1]) + delta = (v1t - v0t) * margin + if not np.isfinite(delta): + delta = 0 + v0, v1 = inverse_trans.transform([v0t - delta, v1t + delta]) + if not _tight: + v0, v1 = locator.view_limits(v0, v1) + set_bound(v0, v1, self._view_margin) + def autoscale_view(self, tight=None, scalex=True, scaley=True, scalez=True): """ @@ -627,40 +706,22 @@ def autoscale_view(self, tight=None, _tight = self._tight = bool(tight) if scalex and self.get_autoscalex_on(): - x0, x1 = self.xy_dataLim.intervalx - xlocator = self.xaxis.get_major_locator() - x0, x1 = xlocator.nonsingular(x0, x1) - if self._xmargin > 0: - delta = (x1 - x0) * self._xmargin - x0 -= delta - x1 += delta - if not _tight: - x0, x1 = xlocator.view_limits(x0, x1) - self.set_xbound(x0, x1, self._view_margin) + self._autoscale_axis( + self.xaxis, *self.xy_dataLim.intervalx, + self.xy_dataLim.minposx, self._xmargin, + self.set_xbound, _tight) if scaley and self.get_autoscaley_on(): - y0, y1 = self.xy_dataLim.intervaly - ylocator = self.yaxis.get_major_locator() - y0, y1 = ylocator.nonsingular(y0, y1) - if self._ymargin > 0: - delta = (y1 - y0) * self._ymargin - y0 -= delta - y1 += delta - if not _tight: - y0, y1 = ylocator.view_limits(y0, y1) - self.set_ybound(y0, y1, self._view_margin) + self._autoscale_axis( + self.yaxis, *self.xy_dataLim.intervaly, + self.xy_dataLim.minposy, self._ymargin, + self.set_ybound, _tight) if scalez and self.get_autoscalez_on(): - z0, z1 = self.zz_dataLim.intervalx - zlocator = self.zaxis.get_major_locator() - z0, z1 = zlocator.nonsingular(z0, z1) - if self._zmargin > 0: - delta = (z1 - z0) * self._zmargin - z0 -= delta - z1 += delta - if not _tight: - z0, z1 = zlocator.view_limits(z0, z1) - self.set_zbound(z0, z1, self._view_margin) + self._autoscale_axis( + self.zaxis, *self.zz_dataLim.intervalx, + self.zz_dataLim.minposx, self._zmargin, + self.set_zbound, _tight) def get_w_lims(self): """Get 3D world limits.""" @@ -761,7 +822,8 @@ def set_zbound(self, lower=None, upper=None, view_margin=None): lower, upper, view_margin) def _set_lim3d(self, axis, lower=None, upper=None, *, emit=True, - auto=False, view_margin=None, axmin=None, axmax=None): + auto=False, view_margin=None, axmin=None, axmax=None, + minpos=np.inf): """ Set 3D axis limits. """ @@ -787,9 +849,20 @@ def _set_lim3d(self, axis, lower=None, upper=None, *, emit=True, view_margin = self._view_margin else: view_margin = 0 - delta = (upper - lower) * view_margin - lower -= delta - upper += delta + # Apply margin in transformed space to handle non-linear scales properly + if view_margin > 0 and hasattr(axis, '_scale') and axis._scale is not None: + transform = axis.get_transform() + inverse_trans = transform.inverted() + lower, upper = axis._scale.limit_range_for_scale(lower, upper, minpos) + lower_t, upper_t = transform.transform([lower, upper]) + delta = (upper_t - lower_t) * view_margin + if np.isfinite(delta): + new_range = [lower_t - delta, upper_t + delta] + lower, upper = inverse_trans.transform(new_range) + else: + delta = (upper - lower) * view_margin + lower -= delta + upper += delta return axis._set_lim(lower, upper, emit=emit, auto=auto) def set_xlim(self, left=None, right=None, *, emit=True, auto=False, @@ -862,7 +935,8 @@ def set_xlim(self, left=None, right=None, *, emit=True, auto=False, >>> set_xlim(5000, 0) """ return self._set_lim3d(self.xaxis, left, right, emit=emit, auto=auto, - view_margin=view_margin, axmin=xmin, axmax=xmax) + view_margin=view_margin, axmin=xmin, axmax=xmax, + minpos=self.xy_dataLim.minposx) def set_ylim(self, bottom=None, top=None, *, emit=True, auto=False, view_margin=None, ymin=None, ymax=None): @@ -934,7 +1008,8 @@ def set_ylim(self, bottom=None, top=None, *, emit=True, auto=False, >>> set_ylim(5000, 0) """ return self._set_lim3d(self.yaxis, bottom, top, emit=emit, auto=auto, - view_margin=view_margin, axmin=ymin, axmax=ymax) + view_margin=view_margin, axmin=ymin, axmax=ymax, + minpos=self.xy_dataLim.minposy) def set_zlim(self, bottom=None, top=None, *, emit=True, auto=False, view_margin=None, zmin=None, zmax=None): @@ -1006,7 +1081,8 @@ def set_zlim(self, bottom=None, top=None, *, emit=True, auto=False, >>> set_zlim(5000, 0) """ return self._set_lim3d(self.zaxis, bottom, top, emit=emit, auto=auto, - view_margin=view_margin, axmin=zmin, axmax=zmax) + view_margin=view_margin, axmin=zmin, axmax=zmax, + minpos=self.zz_dataLim.minposx) set_xlim3d = set_xlim set_ylim3d = set_ylim @@ -1044,25 +1120,81 @@ def get_zlim(self): get_zscale = _axis_method_wrapper("zaxis", "get_scale") - # Redefine all three methods to overwrite their docstrings. - set_xscale = _axis_method_wrapper("xaxis", "_set_axes_scale") - set_yscale = _axis_method_wrapper("yaxis", "_set_axes_scale") - set_zscale = _axis_method_wrapper("zaxis", "_set_axes_scale") - set_xscale.__doc__, set_yscale.__doc__, set_zscale.__doc__ = map( + # Custom scale setters that handle limit validation for non-linear scales + def _set_axis_scale(self, axis, value, **kwargs): + """ + Set scale for an axis and constrain limits to valid range. + + Parameters + ---------- + axis : Axis + The axis to set the scale on. + value : str + The scale name. + **kwargs + Forwarded to scale constructor. + """ + # For non-linear scales on the z-axis, switch from the [0, 1] + + # margin=0 representation to the same xymargin + margin=0.05 + # representation that x/y use. Both produce identical linear limits, + # but only the xymargin form has valid positive lower bounds for log + # etc. This must happen before _set_axes_scale because that triggers + # autoscale_view internally. + if (axis is self.zaxis and value != 'linear' + and np.array_equal(self.zz_dataLim.get_points(), [[0, 0], [1, 1]])): + xymargin = 0.05 * 10/11 + self.zz_dataLim = Bbox([[xymargin, xymargin], + [1 - xymargin, 1 - xymargin]]) + self._zmargin = self._xmargin + axis._set_axes_scale(value, **kwargs) + + def set_xscale(self, value, **kwargs): + """ + Set the x-axis scale. + + Parameters + ---------- + value : {"linear", "log", "symlog", "logit", ...} + The axis scale type to apply. See `~.scale.ScaleBase` for + the list of available scales. + + **kwargs + Keyword arguments are forwarded to the scale class. + For example, ``base=2`` can be passed when using a log scale. + """ + self._set_axis_scale(self.xaxis, value, **kwargs) + + def set_yscale(self, value, **kwargs): """ - Set the {}-axis scale. + Set the y-axis scale. Parameters ---------- - value : {{"linear"}} - The axis scale type to apply. 3D Axes currently only support - linear scales; other scales yield nonsensical results. + value : {"linear", "log", "symlog", "logit", ...} + The axis scale type to apply. See `~.scale.ScaleBase` for + the list of available scales. **kwargs - Keyword arguments are nominally forwarded to the scale class, but - none of them is applicable for linear scales. - """.format, - ["x", "y", "z"]) + Keyword arguments are forwarded to the scale class. + For example, ``base=2`` can be passed when using a log scale. + """ + self._set_axis_scale(self.yaxis, value, **kwargs) + + def set_zscale(self, value, **kwargs): + """ + Set the z-axis scale. + + Parameters + ---------- + value : {"linear", "log", "symlog", "logit", ...} + The axis scale type to apply. See `~.scale.ScaleBase` for + the list of available scales. + + **kwargs + Keyword arguments are forwarded to the scale class. + For example, ``base=2`` can be passed when using a log scale. + """ + self._set_axis_scale(self.zaxis, value, **kwargs) get_zticks = _axis_method_wrapper("zaxis", "get_ticklocs") set_zticks = _axis_method_wrapper("zaxis", "set_ticks") @@ -1210,17 +1342,81 @@ def _roll_to_vertical( else: return np.roll(arr, (self._vertical_axis - 2)) + def _get_scaled_limits(self): + """ + Get axis limits transformed through their respective scale transforms. + + Returns + ------- + tuple + (xmin_scaled, xmax_scaled, ymin_scaled, ymax_scaled, + zmin_scaled, zmax_scaled) + """ + xmin, xmax = self.xaxis.get_transform().transform(self.get_xlim3d()) + ymin, ymax = self.yaxis.get_transform().transform(self.get_ylim3d()) + zmin, zmax = self.zaxis.get_transform().transform(self.get_zlim3d()) + return xmin, xmax, ymin, ymax, zmin, zmax + + def _untransform_point(self, x, y, z): + """ + Convert a point from transformed coordinates to data coordinates. + + Parameters + ---------- + x, y, z : float + A single point in transformed coordinates. + + Returns + ------- + x_data, y_data, z_data : float + The point in data coordinates. + """ + x_data = self.xaxis.get_transform().inverted().transform([x])[0] + y_data = self.yaxis.get_transform().inverted().transform([y])[0] + z_data = self.zaxis.get_transform().inverted().transform([z])[0] + return x_data, y_data, z_data + + def _set_lims_from_transformed(self, xmin_t, xmax_t, ymin_t, ymax_t, + zmin_t, zmax_t): + """ + Set axis limits from transformed coordinates. + + Converts limits from transformed coordinates back to data coordinates, + applies limit_range_for_scale validation, and sets the axis limits. + + Parameters + ---------- + xmin_t, xmax_t, ymin_t, ymax_t, zmin_t, zmax_t : float + Axis limits in transformed coordinates. + """ + # Transform back to data space + xmin, xmax = self.xaxis.get_transform().inverted().transform([xmin_t, xmax_t]) + ymin, ymax = self.yaxis.get_transform().inverted().transform([ymin_t, ymax_t]) + zmin, zmax = self.zaxis.get_transform().inverted().transform([zmin_t, zmax_t]) + + # Validate limits for scale constraints (e.g., positive for log scale) + xmin, xmax = self.xaxis._scale.limit_range_for_scale( + xmin, xmax, self.xy_dataLim.minposx) + ymin, ymax = self.yaxis._scale.limit_range_for_scale( + ymin, ymax, self.xy_dataLim.minposy) + zmin, zmax = self.zaxis._scale.limit_range_for_scale( + zmin, zmax, self.zz_dataLim.minposx) + + # Set the new axis limits + self.set_xlim3d(xmin, xmax, auto=None) + self.set_ylim3d(ymin, ymax, auto=None) + self.set_zlim3d(zmin, zmax, auto=None) + def get_proj(self): """Create the projection matrix from the current viewing position.""" # Transform to uniform world coordinates 0-1, 0-1, 0-1 box_aspect = self._roll_to_vertical(self._box_aspect) - worldM = proj3d.world_transformation( - *self.get_xlim3d(), - *self.get_ylim3d(), - *self.get_zlim3d(), - pb_aspect=box_aspect, - ) + # For non-linear scales, we use the scaled limits so the world + # transformation maps transformed coordinates (not data coordinates) + # to the unit cube + scaled_limits = self._get_scaled_limits() + worldM = proj3d.world_transformation(*scaled_limits, pb_aspect=box_aspect) # Look into the middle of the world coordinates: R = 0.5 * box_aspect @@ -1453,7 +1649,7 @@ def _location_coords(self, xv, yv, renderer): def _get_camera_loc(self): """ - Returns the current camera location in data coordinates. + Returns the current camera location in transformed coordinates. """ cx, cy, cz, dx, dy, dz = self._get_w_centers_ranges() c = np.array([cx, cy, cz]) @@ -1477,17 +1673,13 @@ def _calc_coord(self, xv, yv, renderer=None): else: # perspective projection zv = -1 / self._focal_length - # Convert point on view plane to data coordinates p1 = np.array(proj3d.inv_transform(xv, yv, zv, self.invM)).ravel() # Get the vector from the camera to the point on the view plane vec = self._get_camera_loc() - p1 # Get the pane locations for each of the axes - pane_locs = [] - for axis in self._axis_map.values(): - xys, loc = axis.active_pane() - pane_locs.append(loc) + pane_locs_data = [axis.active_pane()[1] for axis in self._axis_map.values()] # Find the distance to the nearest pane by projecting the view vector scales = np.zeros(3) @@ -1495,12 +1687,15 @@ def _calc_coord(self, xv, yv, renderer=None): if vec[i] == 0: scales[i] = np.inf else: - scales[i] = (p1[i] - pane_locs[i]) / vec[i] + scales[i] = (pane_locs_data[i] - p1[i]) / vec[i] pane_idx = np.argmin(abs(scales)) scale = scales[pane_idx] # Calculate the point on the closest pane - p2 = p1 - scale*vec + p2 = p1 + scale * vec + + # Convert from transformed to data coordinates + p2 = np.array(self._untransform_point(p2[0], p2[1], p2[2])) return p2, pane_idx def _arcball(self, x: float, y: float) -> np.ndarray: @@ -1660,16 +1855,17 @@ def drag_pan(self, button, key, x, y): R = -R / self._box_aspect * self._dist duvw_projected = R.T @ np.array([du, dv, dw]) - # Calculate pan distance - minx, maxx, miny, maxy, minz, maxz = self.get_w_lims() + # Calculate pan distance in transformed coordinates for non-linear scales + minx, maxx, miny, maxy, minz, maxz = self._get_scaled_limits() dx = (maxx - minx) * duvw_projected[0] dy = (maxy - miny) * duvw_projected[1] dz = (maxz - minz) * duvw_projected[2] - # Set the new axis limits - self.set_xlim3d(minx + dx, maxx + dx, auto=None) - self.set_ylim3d(miny + dy, maxy + dy, auto=None) - self.set_zlim3d(minz + dz, maxz + dz, auto=None) + # Compute new limits in transformed coordinates + self._set_lims_from_transformed( + minx + dx, maxx + dx, + miny + dy, maxy + dy, + minz + dz, maxz + dz) def _calc_view_axes(self, eye): """ @@ -1785,6 +1981,9 @@ def _scale_axis_limits(self, scale_x, scale_y, scale_z): limits by scale factors. A scale factor > 1 zooms out and a scale factor < 1 zooms in. + For non-linear scales, the scaling happens in transformed coordinates to ensure + uniform zoom behavior. + Parameters ---------- scale_x : float @@ -1794,23 +1993,29 @@ def _scale_axis_limits(self, scale_x, scale_y, scale_z): scale_z : float Scale factor for the z data axis. """ - # Get the axis centers and ranges + # Get the axis centers and ranges in transformed coordinates cx, cy, cz, dx, dy, dz = self._get_w_centers_ranges() - # Set the scaled axis limits - self.set_xlim3d(cx - dx*scale_x/2, cx + dx*scale_x/2, auto=None) - self.set_ylim3d(cy - dy*scale_y/2, cy + dy*scale_y/2, auto=None) - self.set_zlim3d(cz - dz*scale_z/2, cz + dz*scale_z/2, auto=None) + # Compute new limits in transformed coordinates and set + self._set_lims_from_transformed( + cx - dx*scale_x/2, cx + dx*scale_x/2, + cy - dy*scale_y/2, cy + dy*scale_y/2, + cz - dz*scale_z/2, cz + dz*scale_z/2) def _get_w_centers_ranges(self): - """Get 3D world centers and axis ranges.""" - # Calculate center of axis limits - minx, maxx, miny, maxy, minz, maxz = self.get_w_lims() + """ + Get 3D world centers and axis ranges in transformed coordinates. + + For non-linear scales (log, symlog, etc.), centers and ranges are + computed in transformed coordinates to ensure uniform zoom/pan behavior. + """ + # Get limits in transformed coordinates for non-linear scale zoom/pan + minx, maxx, miny, maxy, minz, maxz = self._get_scaled_limits() cx = (maxx + minx)/2 cy = (maxy + miny)/2 cz = (maxz + minz)/2 - # Calculate range of axis limits + # Calculate range of axis limits in transformed coordinates dx = (maxx - minx) dy = (maxy - miny) dz = (maxz - minz) diff --git a/lib/mpl_toolkits/mplot3d/axis3d.py b/lib/mpl_toolkits/mplot3d/axis3d.py index fdd22b717f67..0ac2e50b1a1a 100644 --- a/lib/mpl_toolkits/mplot3d/axis3d.py +++ b/lib/mpl_toolkits/mplot3d/axis3d.py @@ -267,14 +267,15 @@ def get_rotate_label(self, text): return len(text) > 4 def _get_coord_info(self): - mins, maxs = np.array([ - self.axes.get_xbound(), - self.axes.get_ybound(), - self.axes.get_zbound(), - ]).T - - # Project the bounds along the current position of the cube: - bounds = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2] + # Get scaled limits directly from the axes helper + xmin, xmax, ymin, ymax, zmin, zmax = self.axes._get_scaled_limits() + mins = np.array([xmin, ymin, zmin]) + maxs = np.array([xmax, ymax, zmax]) + + # Get data-space bounds for _transformed_cube + bounds = (*self.axes.get_xbound(), + *self.axes.get_ybound(), + *self.axes.get_zbound()) bounds_proj = self.axes._transformed_cube(bounds) # Determine which one of the parallel planes are higher up: @@ -443,6 +444,10 @@ def _draw_ticks(self, renderer, edgep1, centers, deltas, highs, mins, maxs, tc, highs = self._get_coord_info() centers, deltas = self._calc_centers_deltas(maxs, mins) + # Get the scale transform for this axis to transform tick locations + axis = [self.axes.xaxis, self.axes.yaxis, self.axes.zaxis][index] + axis_trans = axis.get_transform() + # Draw ticks: tickdir = self._get_tickdir(pos) tickdelta = deltas[tickdir] if highs[tickdir] else -deltas[tickdir] @@ -457,10 +462,11 @@ def _draw_ticks(self, renderer, edgep1, centers, deltas, highs, default_label_offset = 8. # A rough estimate points = deltas_per_point * deltas + # All coordinates below are in transformed coordinates for proper projection for tick in ticks: # Get tick line positions pos = edgep1.copy() - pos[index] = tick.get_loc() + pos[index] = axis_trans.transform([tick.get_loc()])[0] pos[tickdir] = out_tickdir x1, y1, z1 = proj3d.proj_transform(*pos, self.axes.M) pos[tickdir] = in_tickdir @@ -468,7 +474,6 @@ def _draw_ticks(self, renderer, edgep1, centers, deltas, highs, # Get position of label labeldeltas = (tick.get_pad() + default_label_offset) * points - pos[tickdir] = edgep1_tickdir pos = _move_from_center(pos, centers, labeldeltas, self._axmask()) lx, ly, lz = proj3d.proj_transform(*pos, self.axes.M) @@ -642,10 +647,15 @@ def draw_grid(self, renderer): info = self._axinfo index = info["i"] + # Grid lines use data-space bounds (Line3DCollection applies transforms) mins, maxs, tc, highs = self._get_coord_info() - - minmax = np.where(highs, maxs, mins) - maxmin = np.where(~highs, maxs, mins) + xlim, ylim, zlim = (self.axes.get_xbound(), + self.axes.get_ybound(), + self.axes.get_zbound()) + data_mins = np.array([xlim[0], ylim[0], zlim[0]]) + data_maxs = np.array([xlim[1], ylim[1], zlim[1]]) + minmax = np.where(highs, data_maxs, data_mins) + maxmin = np.where(~highs, data_maxs, data_mins) # Grid points where the planes meet xyz0 = np.tile(minmax, (len(ticks), 1)) diff --git a/lib/mpl_toolkits/mplot3d/proj3d.py b/lib/mpl_toolkits/mplot3d/proj3d.py index 87c59ae05714..81a5aacbdded 100644 --- a/lib/mpl_toolkits/mplot3d/proj3d.py +++ b/lib/mpl_toolkits/mplot3d/proj3d.py @@ -131,6 +131,33 @@ def _ortho_transformation(zfront, zback): return proj_matrix +def _apply_scale_transforms(xs, ys, zs, axes): + """ + Apply axis scale transforms to 3D coordinates. + + Transforms data coordinates to transformed coordinates (applying log, + symlog, etc.) for 3D projection. Preserves masked arrays. + """ + def transform_coord(coord, axis): + coord = np.asanyarray(coord) + data = np.ma.getdata(coord).ravel() + return axis.get_transform().transform(data).reshape(coord.shape) + + xs_scaled = transform_coord(xs, axes.xaxis) + ys_scaled = transform_coord(ys, axes.yaxis) + zs_scaled = transform_coord(zs, axes.zaxis) + + # Preserve combined mask from any masked input + masks = [np.ma.getmask(a) for a in [xs, ys, zs]] + if any(m is not np.ma.nomask for m in masks): + combined = np.ma.mask_or(np.ma.mask_or(masks[0], masks[1]), masks[2]) + xs_scaled = np.ma.array(xs_scaled, mask=combined) + ys_scaled = np.ma.array(ys_scaled, mask=combined) + zs_scaled = np.ma.array(zs_scaled, mask=combined) + + return xs_scaled, ys_scaled, zs_scaled + + def _proj_transform_vec(vec, M): vecw = np.dot(M, vec.data) ts = vecw[0:3]/vecw[3] @@ -139,27 +166,24 @@ def _proj_transform_vec(vec, M): return ts[0], ts[1], ts[2] -def _proj_transform_vectors(vecs, M): +def _scale_proj_transform_vectors(vecs, axes): """ - Vectorized version of ``_proj_transform_vec``. + Apply scale transforms and project vectors. Parameters ---------- vecs : ... x 3 np.ndarray - Input vectors - M : 4 x 4 np.ndarray - Projection matrix + Input vectors. + axes : Axes3D + The 3D axes (used for scale transforms and projection matrix). """ - vecs_shape = vecs.shape - vecs = vecs.reshape(-1, 3).T - - vecs_pad = np.empty((vecs.shape[0] + 1,) + vecs.shape[1:]) - vecs_pad[:-1] = vecs - vecs_pad[-1] = 1 - product = np.dot(M, vecs_pad) + result_shape = vecs.shape + xs, ys, zs = _apply_scale_transforms( + vecs[..., 0], vecs[..., 1], vecs[..., 2], axes) + vec = _vec_pad_ones(xs.ravel(), ys.ravel(), zs.ravel()) + product = np.dot(axes.M, vec) tvecs = product[:3] / product[3] - - return tvecs.T.reshape(vecs_shape) + return tvecs.T.reshape(result_shape) def _proj_transform_vec_clip(vec, M, focal_length): @@ -213,24 +237,33 @@ def proj_transform(xs, ys, zs, M): @_api.deprecated("3.10") def proj_transform_clip(xs, ys, zs, M): - return _proj_transform_clip(xs, ys, zs, M, focal_length=np.inf) + vec = _vec_pad_ones(xs, ys, zs) + return _proj_transform_vec_clip(vec, M, focal_length=np.inf) -def _proj_transform_clip(xs, ys, zs, M, focal_length): +def _scale_proj_transform_clip(xs, ys, zs, axes): """ - Transform the points by the projection matrix - and return the clipping result - returns txs, tys, tzs, tis + Apply scale transforms, project, and return clipping result. + + Returns txs, tys, tzs, tis. """ + xs, ys, zs = _apply_scale_transforms(xs, ys, zs, axes) vec = _vec_pad_ones(xs, ys, zs) - return _proj_transform_vec_clip(vec, M, focal_length) - - -def _proj_points(points, M): - return np.column_stack(_proj_trans_points(points, M)) + return _proj_transform_vec_clip(vec, axes.M, axes._focal_length) def _proj_trans_points(points, M): points = np.asanyarray(points) xs, ys, zs = points[:, 0], points[:, 1], points[:, 2] return proj_transform(xs, ys, zs, M) + + +def _scale_proj_transform(xs, ys, zs, axes): + """ + Apply scale transforms and project. + + Combines `_apply_scale_transforms` and `proj_transform` into a single + call. Returns txs, tys, tzs. + """ + xs, ys, zs = _apply_scale_transforms(xs, ys, zs, axes) + return proj_transform(xs, ys, zs, axes.M) diff --git a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_all_scales.png b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_all_scales.png new file mode 100644 index 000000000000..af1411dbfc9c Binary files /dev/null and b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_all_scales.png differ diff --git a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_artists_log.png b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_artists_log.png new file mode 100644 index 000000000000..e5180b57fa9a Binary files /dev/null and b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_artists_log.png differ diff --git a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_log_bases.png b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_log_bases.png new file mode 100644 index 000000000000..875c91e07f67 Binary files /dev/null and b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_log_bases.png differ diff --git a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_symlog_params.png b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_symlog_params.png new file mode 100644 index 000000000000..73732dea1284 Binary files /dev/null and b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/scale3d_symlog_params.png differ diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index ec7295981969..b953fd415955 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -18,7 +18,7 @@ from matplotlib.patches import Circle, PathPatch from matplotlib.path import Path from matplotlib.text import Text -from matplotlib import _api +from matplotlib import _api import matplotlib.pyplot as plt import numpy as np @@ -2844,3 +2844,340 @@ def test_ctrl_rotation_snaps_to_5deg(): assert ax.roll == pytest.approx(expected_roll) plt.close(fig) + + +# ============================================================================= +# Tests for 3D scale transforms (log, symlog, logit, etc.) +# ============================================================================= + +def _make_log_data(): + """Data spanning 1 to ~1000 for log scale.""" + t = np.linspace(0, 2 * np.pi, 50) + x = 10 ** (t / 2) + y = 10 ** (1 + np.sin(t)) + z = 10 ** (2 * (1 + np.cos(t) / 2)) + return x, y, z + + +def _make_surface_log_data(): + """Grid data for surface with positive Z.""" + x = np.linspace(1, 10, 20) + y = np.linspace(1, 10, 20) + X, Y = np.meshgrid(x, y) + Z = X * Y + return X, Y, Z + + +def _make_triangulation_data(): + """Data for trisurf with positive values.""" + np.random.seed(42) + x = np.random.uniform(1, 100, 100) + y = np.random.uniform(1, 100, 100) + z = x * y / 10 + return x, y, z + + +@mpl3d_image_comparison(['scale3d_artists_log.png'], style='mpl20', + remove_text=False, tol=0.03) +def test_scale3d_artists_log(): + """Test all 3D artist types with log scale.""" + fig = plt.figure(figsize=(16, 12)) + log_kw = dict(xscale='log', yscale='log', zscale='log') + line_data = _make_log_data() + surf_X, surf_Y, surf_Z = _make_surface_log_data() + + # Row 1: plot, wireframe, scatter, bar3d + ax = fig.add_subplot(3, 4, 1, projection='3d') + ax.plot(*line_data) + ax.set(**log_kw, title='plot') + + ax = fig.add_subplot(3, 4, 2, projection='3d') + ax.plot_wireframe(surf_X, surf_Y, surf_Z, rstride=5, cstride=5) + ax.set(**log_kw, title='wireframe') + + ax = fig.add_subplot(3, 4, 3, projection='3d') + ax.scatter(*line_data, c=line_data[2], cmap='viridis') + ax.set(**log_kw, title='scatter') + + ax = fig.add_subplot(3, 4, 4, projection='3d') + bx, by = np.meshgrid([1, 10, 100], [1, 10, 100]) + bx, by = bx.flatten(), by.flatten() + ax.bar3d(bx, by, np.ones_like(bx, dtype=float), + bx * 0.3, by * 0.3, bx * by / 10, alpha=0.8) + ax.set(**log_kw, title='bar3d') + + # Row 2: surface, trisurf, contour, contourf + ax = fig.add_subplot(3, 4, 5, projection='3d') + ax.plot_surface(surf_X, surf_Y, surf_Z, cmap='viridis', alpha=0.8) + ax.set(**log_kw, title='surface') + + ax = fig.add_subplot(3, 4, 6, projection='3d') + tri_data = _make_triangulation_data() + ax.plot_trisurf(*tri_data, cmap='viridis', alpha=0.8) + ax.set(**log_kw, title='trisurf') + + ax = fig.add_subplot(3, 4, 7, projection='3d') + ax.contour(surf_X, surf_Y, surf_Z, levels=10) + ax.set(**log_kw, title='contour') + + ax = fig.add_subplot(3, 4, 8, projection='3d') + ax.contourf(surf_X, surf_Y, surf_Z, levels=10, alpha=0.8) + ax.set(**log_kw, title='contourf') + + # Row 3: stem, quiver, text + ax = fig.add_subplot(3, 4, 9, projection='3d') + ax.stem([1, 10, 100], [1, 10, 100], [10, 100, 1000], bottom=1) + ax.set(**log_kw, title='stem') + + ax = fig.add_subplot(3, 4, 10, projection='3d') + qxyz = np.array([1, 10, 100]) + ax.quiver(qxyz, qxyz, qxyz, qxyz * 0.5, qxyz * 0.5, qxyz * 0.5) + ax.set(**log_kw, title='quiver') + + ax = fig.add_subplot(3, 4, 11, projection='3d') + ax.text(1, 1, 1, "Point A") + ax.text(10, 10, 10, "Point B") + ax.text(100, 100, 100, "Point C") + ax.set(**log_kw, title='text', + xlim=(0.5, 200), ylim=(0.5, 200), zlim=(0.5, 200)) + + +@mpl3d_image_comparison(['scale3d_all_scales.png'], style='mpl20', remove_text=False) +def test_scale3d_all_scales(): + """Test all scale types with mixed scales on each axis.""" + fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'}, figsize=(10, 6)) + + # Data that works across all scale types + t = np.linspace(0.1, 0.9, 30) + # x: positive for log/asinh, y: spans neg/pos for symlog, z: (0,1) for logit + x = t * 100 # 10 to 90 + y = (t - 0.5) * 20 # -10 to 10 + z = t # 0.1 to 0.9 + + # Subplot 1: x=log, y=symlog, z=logit + axs[0].scatter(x, y, z) + axs[0].set(xscale='log', yscale='symlog', zscale='logit', + xlabel='log', ylabel='symlog', zlabel='logit') + + # Subplot 2: x=asinh, y=linear, z=function (square root) + axs[1].scatter(x, y, z) + axs[1].set_xscale('asinh') + axs[1].set_zscale('function', functions=(lambda v: v**0.5, lambda v: v**2)) + axs[1].set(xlabel='asinh', ylabel='linear', zlabel='function') + + +@pytest.mark.parametrize("scale, expected_lims", [ + ("linear", (-0.020833333333333332, 1.0208333333333333)), + ("log", (0.03640537388223389, 1.1918138759519783)), + ("symlog", (-0.020833333333333332, 1.0208333333333333)), + ("logit", (0.029640777806688817, 0.9703592221933112)), + ("asinh", (-0.020833333333333332, 1.0208333333333333)), +]) +@mpl.style.context("default") +def test_scale3d_default_limits(scale, expected_lims): + """Default axis limits on an empty plot should be correct for each scale.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set_xscale(scale) + ax.set_yscale(scale) + ax.set_zscale(scale) + fig.canvas.draw() + + for get_lim in (ax.get_xlim, ax.get_ylim, ax.get_zlim): + np.testing.assert_allclose(get_lim(), expected_lims) + + +@check_figures_equal() +@pytest.mark.filterwarnings("ignore:Data has no positive values") +def test_scale3d_all_clipped(fig_test, fig_ref): + """Fully clipped data (e.g. negative values on log) should look like an empty plot. + """ + lims = (0.1, 10) + for ax in [fig_test.add_subplot(projection='3d'), + fig_ref.add_subplot(projection='3d')]: + ax.set_xscale('log') + ax.set_yscale('log') + ax.set_zscale('log') + ax.set(xlim=lims, ylim=lims, zlim=lims) + + # All negative data — everything is invalid for log scale + fig_test.axes[0].plot([-1, -2, -3], [-4, -5, -6], [-7, -8, -9]) + + +@mpl3d_image_comparison(['scale3d_log_bases.png'], style='mpl20', remove_text=False) +def test_scale3d_log_bases(): + """Test log scale with different bases and subs.""" + fig, axs = plt.subplots(2, 2, subplot_kw={'projection': '3d'}, figsize=(10, 8)) + x, y, z = _make_log_data() + + for ax, base, title in [(axs[0, 0], 10, 'base=10'), + (axs[0, 1], 2, 'base=2'), + (axs[1, 0], np.e, 'base=e')]: + ax.scatter(x, y, z, s=10) + ax.set_xscale('log', base=base) + ax.set_yscale('log', base=base) + ax.set_zscale('log', base=base) + ax.set_title(title) + if base == np.e: + # Format tick labels as e^n instead of 2.718...^n + def fmt_e(x, pos=None): + if x <= 0: + return '' + exp = np.log(x) + if np.isclose(exp, round(exp)): + return r'$e^{%d}$' % round(exp) + return '' + ax.xaxis.set_major_formatter(fmt_e) + ax.yaxis.set_major_formatter(fmt_e) + ax.zaxis.set_major_formatter(fmt_e) + + # subs + axs[1, 1].scatter(x, y, z, s=10) + axs[1, 1].set_xscale('log', subs=[2, 5]) + axs[1, 1].set_yscale('log', subs=[2, 5]) + axs[1, 1].set_zscale('log', subs=[2, 5]) + axs[1, 1].set_title('subs=[2,5]') + + +@mpl3d_image_comparison(['scale3d_symlog_params.png'], style='mpl20', + remove_text=False) +def test_scale3d_symlog_params(): + """Test symlog scale with different linthresh values.""" + fig, axs = plt.subplots(1, 2, subplot_kw={'projection': '3d'}) + + # Data spanning negative, zero, and positive + t = np.linspace(-3, 3, 50) + x = np.sinh(t) * 10 + y = t ** 3 + z = np.sign(t) * np.abs(t) ** 2 + + for ax, linthresh in [(axs[0], 0.1), (axs[1], 10)]: + ax.scatter(x, y, z, c=np.abs(z), cmap='viridis', s=10) + ax.set_xscale('symlog', linthresh=linthresh) + ax.set_yscale('symlog', linthresh=linthresh) + ax.set_zscale('symlog', linthresh=linthresh) + ax.set_title(f'linthresh={linthresh}') + + +@pytest.mark.parametrize('scale_type,kwargs', [ + ('log', {'base': 10}), + ('log', {'base': 2}), + ('log', {'subs': [2, 5]}), + ('log', {'nonpositive': 'mask'}), + ('symlog', {'base': 2}), + ('symlog', {'linthresh': 1}), + ('symlog', {'linscale': 0.5}), + ('symlog', {'subs': [2, 5]}), + ('asinh', {'linear_width': 0.5}), + ('asinh', {'base': 2}), + ('logit', {'nonpositive': 'clip'}), +]) +def test_scale3d_keywords_accepted(scale_type, kwargs): + """Verify that scale keywords are accepted on all 3 axes.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + for setter in [ax.set_xscale, ax.set_yscale, ax.set_zscale]: + setter(scale_type, **kwargs) + assert (ax.get_xscale(), ax.get_yscale(), ax.get_zscale()) == (scale_type,) * 3 + + +@pytest.mark.parametrize('axis', ['x', 'y', 'z']) +def test_scale3d_limit_range_log(axis): + """Log scale should warn when setting non-positive limits.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + getattr(ax, f'set_{axis}scale')('log') + + # Setting non-positive limits should warn + with pytest.warns(UserWarning, match="non-positive"): + getattr(ax, f'set_{axis}lim')(-10, 100) + + +def test_scale3d_limit_range_logit(): + """Logit scale should constrain axis to (0, 1).""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set(xscale='logit', yscale='logit', zscale='logit', + xlim=(-0.5, 1.5), ylim=(-0.5, 1.5), zlim=(-0.5, 1.5)) + + # Limits should be constrained to (0, 1) + for name, lim in [('x', ax.get_xlim()), ('y', ax.get_ylim()), + ('z', ax.get_zlim())]: + assert lim[0] > 0, f"{name} lower limit should be > 0 for logit" + assert lim[1] < 1, f"{name} upper limit should be < 1 for logit" + + +@pytest.mark.parametrize('scale_type', ['log', 'symlog', 'logit', 'asinh']) +def test_scale3d_transform_roundtrip(scale_type): + """Forward/inverse transform should preserve values.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set(xscale=scale_type, yscale=scale_type, zscale=scale_type) + + # Use appropriate test values for each scale type + test_values = { + 'log': [1, 10, 100, 1000], + 'symlog': [-100, -1, 0, 1, 100], + 'asinh': [-100, -1, 0, 1, 100], + 'logit': [0.01, 0.1, 0.5, 0.9, 0.99], + }[scale_type] + test_values = np.array(test_values) + + # Test round-trip for each axis + for axis in [ax.xaxis, ax.yaxis, ax.zaxis]: + trans = axis.get_transform() + forward = trans.transform(test_values.reshape(-1, 1)) + inverse = trans.inverted().transform(forward) + np.testing.assert_allclose(inverse.flatten(), test_values, rtol=1e-10) + + +def test_scale3d_invalid_keywords_raise(): + """Invalid kwargs should raise TypeError.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + + with pytest.raises(TypeError): + ax.set_xscale('log', invalid_kwarg=True) + + with pytest.raises(TypeError): + ax.set_yscale('symlog', invalid_kwarg=True) + + with pytest.raises(TypeError): + ax.set_zscale('logit', invalid_kwarg=True) + + +def test_scale3d_persists_after_plot(): + """Scale should persist after adding plot data.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set(xscale='log', yscale='log', zscale='log') + ax.plot(*_make_log_data()) + assert (ax.get_xscale(), ax.get_yscale(), ax.get_zscale()) == ('log',) * 3 + + +def test_scale3d_autoscale_with_log(): + """Autoscale should work correctly with log scale.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.set(xscale='log', yscale='log', zscale='log') + ax.scatter([1, 10, 100], [1, 10, 100], [1, 10, 100]) + + # All limits should be positive + for name, lim in [('x', ax.get_xlim()), ('y', ax.get_ylim()), + ('z', ax.get_zlim())]: + assert lim[0] > 0, f"{name} lower limit should be positive" + assert lim[1] > 0, f"{name} upper limit should be positive" + + +def test_scale3d_calc_coord(): + """_calc_coord should return data coordinates with correct pane values.""" + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + ax.scatter([1, 10, 100], [1, 10, 100], [1, 10, 100]) + ax.set(xscale='log', yscale='log', zscale='log') + fig.canvas.draw() + + point, pane_idx = ax._calc_coord(0.5, 0.5) + # Pane coordinate should match axis limit (y-pane at max) + assert pane_idx == 1 + assert point[pane_idx] == pytest.approx(ax.get_ylim()[1])