diff --git a/spatialmath/base/animate.py b/spatialmath/base/animate.py index 3876a2ea..7654a5a0 100755 --- a/spatialmath/base/animate.py +++ b/spatialmath/base/animate.py @@ -104,6 +104,15 @@ def __init__( # ax.set_zlim(dims[4:6]) # # ax.set_aspect('equal') ax = smb.plotvol3(ax=ax, dim=dim) + if dim is not None: + dim = list(np.ndarray.flatten(np.array(dim))) + if len(dim) == 2: + dim = dim * 3 + elif len(dim) != 6: + raise ValueError(f"dim must have 2 or 6 elements, got {dim}. See docstring for details.") + ax.set_xlim(dim[0:2]) + ax.set_ylim(dim[2:4]) + ax.set_zlim(dim[4:]) self.ax = ax @@ -208,10 +217,12 @@ def update(frame, animation): if isinstance(frame, float): # passed a single transform, interpolate it T = smb.trinterp(start=self.start, end=self.end, s=frame) - else: - # assume it is an SO(3) or SE(3) + elif isinstance(frame, NDArray): + # type is SO3Array or SE3Array when Animate.trajectory is not None T = frame - # ensure result is SE(3) + else: + # [unlikely] other types are converted to np array + T = np.array(frame) if T.shape == (3, 3): T = smb.r2t(T) @@ -309,7 +320,7 @@ def __init__(self, anim: Animate, h, xs, ys, zs): self.anim = anim def draw(self, T): - p = T.A @ self.p + p = T @ self.p self.h.set_data(p[0, :], p[1, :]) self.h.set_3d_properties(p[2, :]) @@ -367,7 +378,7 @@ def __init__(self, anim, h): def draw(self, T): # import ipdb; ipdb.set_trace() - p = T.A @ self.p + p = T @ self.p # reshape it p = p[0:3, :].T.reshape(3, 2, 3) @@ -421,7 +432,7 @@ def __init__(self, anim, h, x, y, z): self.anim = anim def draw(self, T): - p = T.A @ self.p + p = T @ self.p # x2, y2, _ = proj3d.proj_transform( # p[0], p[1], p[2], self.anim.ax.get_proj()) # self.h.set_position((x2, y2)) @@ -546,8 +557,6 @@ def __init__( axes.set_xlim(dims[0:2]) axes.set_ylim(dims[2:4]) # ax.set_aspect('equal') - else: - axes.autoscale(enable=True, axis="both") self.ax = axes diff --git a/spatialmath/base/transforms2d.py b/spatialmath/base/transforms2d.py index c64ed036..682ea0ca 100644 --- a/spatialmath/base/transforms2d.py +++ b/spatialmath/base/transforms2d.py @@ -1510,12 +1510,9 @@ def tranimate2(T: Union[SO2Array, SE2Array], **kwargs): tranimate2(transl(1,2)@trot2(1), frame='A', arrow=False, dims=[0, 5]) tranimate2(transl(1,2)@trot2(1), frame='A', arrow=False, dims=[0, 5], movie='spin.mp4') """ - anim = smb.animate.Animate2(**kwargs) - try: - del kwargs["dims"] - except KeyError: - pass - + dims = kwargs.pop("dims", None) + ax = kwargs.pop("ax", None) + anim = smb.animate.Animate2(dims=dims, axes=ax, **kwargs) anim.trplot2(T, **kwargs) return anim.run(**kwargs) diff --git a/spatialmath/base/transforms3d.py b/spatialmath/base/transforms3d.py index ceff8732..3617f965 100644 --- a/spatialmath/base/transforms3d.py +++ b/spatialmath/base/transforms3d.py @@ -3409,12 +3409,9 @@ def tranimate(T: Union[SO3Array, SE3Array], **kwargs) -> str: :seealso: `trplot`, `plotvol3` """ - anim = Animate(**kwargs) - try: - del kwargs["dims"] - except KeyError: - pass - + dim = kwargs.pop("dims", None) + ax = kwargs.pop("ax", None) + anim = Animate(dim=dim, ax=ax, **kwargs) anim.trplot(T, **kwargs) return anim.run(**kwargs)