Skip to content

Commit 59873f1

Browse files
authored
compute mean of a set of rotations (#160)
1 parent 2306628 commit 59873f1

File tree

4 files changed

+223
-29
lines changed

4 files changed

+223
-29
lines changed

spatialmath/pose3d.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -843,22 +843,22 @@ def Exp(
843843

844844
def UnitQuaternion(self) -> UnitQuaternion:
845845
"""
846-
SO3 as a unit quaternion instance
846+
SO3 as a unit quaternion instance
847847
848-
:return: a unit quaternion representation
849-
:rtype: UnitQuaternion instance
848+
:return: a unit quaternion representation
849+
:rtype: UnitQuaternion instance
850850
851-
``R.UnitQuaternion()`` is an ``UnitQuaternion`` instance representing the same rotation
852-
as the SO3 rotation ``R``.
851+
``R.UnitQuaternion()`` is an ``UnitQuaternion`` instance representing the same rotation
852+
as the SO3 rotation ``R``.
853853
854-
Example:
854+
Example:
855855
856-
.. runblock:: pycon
856+
.. runblock:: pycon
857857
858-
>>> from spatialmath import SO3
859-
>>> SO3.Rz(0.3).UnitQuaternion()
858+
>>> from spatialmath import SO3
859+
>>> SO3.Rz(0.3).UnitQuaternion()
860860
861-
"""
861+
"""
862862
# Function level import to avoid circular dependencies
863863
from spatialmath import UnitQuaternion
864864

@@ -931,6 +931,29 @@ def angdist(self, other: SO3, metric: int = 6) -> Union[float, ndarray]:
931931
else:
932932
return ad
933933

934+
def mean(self, tol: float = 20) -> SO3:
935+
"""Mean of a set of rotations
936+
937+
:param tol: iteration tolerance in units of eps, defaults to 20
938+
:type tol: float, optional
939+
:return: the mean rotation
940+
:rtype: :class:`SO3` instance.
941+
942+
Computes the Karcher mean of the set of rotations within the SO(3) instance.
943+
944+
:references:
945+
- `**Hartley, Trumpf** - "Rotation Averaging" - IJCV 2011 <https://users.cecs.anu.edu.au/~hartley/Papers/PDF/Hartley-Trumpf:Rotation-averaging:IJCV.pdf>`_, Algorithm 1, page 15.
946+
- `Karcher mean <https://en.wikipedia.org/wiki/Karcher_mean>`_
947+
"""
948+
949+
eta = tol * np.finfo(float).eps
950+
R_mean = self[0] # initial guess
951+
while True:
952+
r = np.dstack((R_mean.inv() * self).log()).mean(axis=2)
953+
if np.linalg.norm(r) < eta:
954+
return R_mean
955+
R_mean = R_mean @ self.Exp(r) # update estimate and normalize
956+
934957

935958
# ============================== SE3 =====================================#
936959

spatialmath/quaternion.py

+64-14
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,10 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True):
4545
r"""
4646
Construct a new quaternion
4747
48-
:param s: scalar
49-
:type s: float
50-
:param v: vector
51-
:type v: 3-element array_like
48+
:param s: scalar part
49+
:type s: float or ndarray(N)
50+
:param v: vector part
51+
:type v: ndarray(3), ndarray(Nx3)
5252
5353
- ``Quaternion()`` constructs a zero quaternion
5454
- ``Quaternion(s, v)`` construct a new quaternion from the scalar ``s``
@@ -78,7 +78,7 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True):
7878
super().__init__()
7979

8080
if s is None and smb.isvector(v, 4):
81-
v,s = (s,v)
81+
v, s = (s, v)
8282

8383
if v is None:
8484
# single argument
@@ -92,6 +92,11 @@ def __init__(self, s: Any = None, v=None, check: Optional[bool] = True):
9292
# Quaternion(s, v)
9393
self.data = [np.r_[s, smb.getvector(v)]]
9494

95+
elif (
96+
smb.isvector(s) and smb.ismatrix(v, (None, 3)) and s.shape[0] == v.shape[0]
97+
):
98+
# Quaternion(s, v) where s and v are arrays
99+
self.data = [np.r_[_s, _v] for _s, _v in zip(s, v)]
95100
else:
96101
raise ValueError("bad argument to Quaternion constructor")
97102

@@ -395,9 +400,23 @@ def log(self) -> Quaternion:
395400
:seealso: :meth:`Quaternion.exp` :meth:`Quaternion.log` :meth:`UnitQuaternion.angvec`
396401
"""
397402
norm = self.norm()
398-
s = math.log(norm)
399-
v = math.acos(np.clip(self.s / norm, -1, 1)) * smb.unitvec(self.v)
400-
return Quaternion(s=s, v=v)
403+
s = np.log(norm)
404+
if len(self) == 1:
405+
if smb.iszerovec(self._A[1:4]):
406+
v = np.zeros((3,))
407+
else:
408+
v = math.acos(np.clip(self._A[0] / norm, -1, 1)) * smb.unitvec(
409+
self._A[1:4]
410+
)
411+
return Quaternion(s=s, v=v)
412+
else:
413+
v = [
414+
np.zeros((3,))
415+
if smb.iszerovec(A[1:4])
416+
else math.acos(np.clip(A[0] / n, -1, 1)) * smb.unitvec(A[1:4])
417+
for A, n in zip(self._A, norm)
418+
]
419+
return Quaternion(s=s, v=np.array(v))
401420

402421
def exp(self, tol: float = 20) -> Quaternion:
403422
r"""
@@ -437,7 +456,11 @@ def exp(self, tol: float = 20) -> Quaternion:
437456
exp_s = math.exp(self.s)
438457
norm_v = smb.norm(self.v)
439458
s = exp_s * math.cos(norm_v)
440-
v = exp_s * self.v / norm_v * math.sin(norm_v)
459+
if smb.iszerovec(self.v, tol * _eps):
460+
# result will be a unit quaternion
461+
v = np.zeros((3,))
462+
else:
463+
v = exp_s * self.v / norm_v * math.sin(norm_v)
441464
if abs(self.s) < tol * _eps:
442465
# result will be a unit quaternion
443466
return UnitQuaternion(s=s, v=v)
@@ -1260,7 +1283,7 @@ def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternio
12601283
Construct a new unit quaternion from Euler angles
12611284
12621285
:param 𝚪: 3-vector of Euler angles
1263-
:type 𝚪: array_like
1286+
:type 𝚪: 3 floats, array_like(3) or ndarray(N,3)
12641287
:param unit: angular units: 'rad' [default], or 'deg'
12651288
:type unit: str
12661289
:return: unit-quaternion
@@ -1286,20 +1309,23 @@ def Eul(cls, *angles: List[float], unit: Optional[str] = "rad") -> UnitQuaternio
12861309
if len(angles) == 1:
12871310
angles = angles[0]
12881311

1289-
return cls(smb.r2q(smb.eul2r(angles, unit=unit)), check=False)
1312+
if smb.isvector(angles, 3):
1313+
return cls(smb.r2q(smb.eul2r(angles, unit=unit)), check=False)
1314+
else:
1315+
return cls([smb.r2q(smb.eul2r(a, unit=unit)) for a in angles], check=False)
12901316

12911317
@classmethod
12921318
def RPY(
12931319
cls,
1294-
*angles: List[float],
1320+
*angles,
12951321
order: Optional[str] = "zyx",
12961322
unit: Optional[str] = "rad",
12971323
) -> UnitQuaternion:
12981324
r"""
12991325
Construct a new unit quaternion from roll-pitch-yaw angles
13001326
13011327
:param 𝚪: 3-vector of roll-pitch-yaw angles
1302-
:type 𝚪: array_like
1328+
:type 𝚪: 3 floats, array_like(3) or ndarray(N,3)
13031329
:param unit: angular units: 'rad' [default], or 'deg'
13041330
:type unit: str
13051331
:param unit: rotation order: 'zyx' [default], 'xyz', or 'yxz'
@@ -1341,7 +1367,13 @@ def RPY(
13411367
if len(angles) == 1:
13421368
angles = angles[0]
13431369

1344-
return cls(smb.r2q(smb.rpy2r(angles, unit=unit, order=order)), check=False)
1370+
if smb.isvector(angles, 3):
1371+
return cls(smb.r2q(smb.rpy2r(angles, unit=unit, order=order)), check=False)
1372+
else:
1373+
return cls(
1374+
[smb.r2q(smb.rpy2r(a, unit=unit, order=order)) for a in angles],
1375+
check=False,
1376+
)
13451377

13461378
@classmethod
13471379
def OA(cls, o: ArrayLike3, a: ArrayLike3) -> UnitQuaternion:
@@ -1569,6 +1601,24 @@ def dotb(self, omega: ArrayLike3) -> R4:
15691601
"""
15701602
return smb.qdotb(self._A, omega)
15711603

1604+
# def mean(self, tol: float = 20) -> SO3:
1605+
# """Mean of a set of rotations
1606+
1607+
# :param tol: iteration tolerance in units of eps, defaults to 20
1608+
# :type tol: float, optional
1609+
# :return: the mean rotation
1610+
# :rtype: :class:`UnitQuaternion` instance.
1611+
1612+
# Computes the Karcher mean of the set of rotations within the unit quaternion instance.
1613+
1614+
# :references:
1615+
# - `**Hartley, Trumpf** - "Rotation Averaging" - IJCV 2011 <https://users.cecs.anu.edu.au/~hartley/Papers/PDF/Hartley-Trumpf:Rotation-averaging:IJCV.pdf>`_
1616+
# - `Karcher mean <https://en.wikipedia.org/wiki/Karcher_mean`_
1617+
# """
1618+
1619+
# R_mean = self.SO3().mean(tol=tol)
1620+
# return R_mean.UnitQuaternion()
1621+
15721622
def __mul__(
15731623
left, right: UnitQuaternion
15741624
) -> UnitQuaternion: # lgtm[py/not-named-self] pylint: disable=no-self-argument

tests/test_pose3d.py

+31
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,37 @@ def test_functions_lie(self):
717717
nt.assert_equal(R, SO3.EulerVec(R.eulervec()))
718718
np.testing.assert_equal((R.inv() * R).eulervec(), np.zeros(3))
719719

720+
R = SO3() # identity matrix case
721+
722+
# Check log and exponential map
723+
nt.assert_equal(R, SO3.Exp(R.log()))
724+
np.testing.assert_equal((R.inv() * R).log(), np.zeros([3, 3]))
725+
726+
# Check euler vector map
727+
nt.assert_equal(R, SO3.EulerVec(R.eulervec()))
728+
np.testing.assert_equal((R.inv() * R).eulervec(), np.zeros(3))
729+
730+
def test_mean(self):
731+
rpy = np.ones((100, 1)) @ np.c_[0.1, 0.2, 0.3]
732+
R = SO3.RPY(rpy)
733+
self.assertEqual(len(R), 100)
734+
m = R.mean()
735+
self.assertIsInstance(m, SO3)
736+
array_compare(m, R[0])
737+
738+
# range of angles, mean should be the middle one, index=25
739+
R = SO3.Rz(np.linspace(start=0.3, stop=0.7, num=51))
740+
m = R.mean()
741+
self.assertIsInstance(m, SO3)
742+
array_compare(m, R[25])
743+
744+
# now add noise
745+
rng = np.random.default_rng(0) # reproducible random numbers
746+
rpy += rng.normal(scale=0.00001, size=(100, 3))
747+
R = SO3.RPY(rpy)
748+
m = R.mean()
749+
array_compare(m, SO3.RPY(0.1, 0.2, 0.3))
750+
720751

721752
# ============================== SE3 =====================================#
722753

tests/test_quaternion.py

+95-5
Original file line numberDiff line numberDiff line change
@@ -257,25 +257,79 @@ def test_staticconstructors(self):
257257
UnitQuaternion.Rz(theta, "deg").R, rotz(theta, "deg")
258258
)
259259

260+
def test_constructor_RPY(self):
260261
# 3 angle
262+
q = UnitQuaternion.RPY([0.1, 0.2, 0.3])
263+
self.assertIsInstance(q, UnitQuaternion)
264+
self.assertEqual(len(q), 1)
265+
nt.assert_array_almost_equal(q.R, rpy2r(0.1, 0.2, 0.3))
266+
q = UnitQuaternion.RPY(0.1, 0.2, 0.3)
267+
self.assertIsInstance(q, UnitQuaternion)
268+
self.assertEqual(len(q), 1)
269+
nt.assert_array_almost_equal(q.R, rpy2r(0.1, 0.2, 0.3))
270+
q = UnitQuaternion.RPY(np.r_[0.1, 0.2, 0.3])
271+
self.assertIsInstance(q, UnitQuaternion)
272+
self.assertEqual(len(q), 1)
273+
nt.assert_array_almost_equal(q.R, rpy2r(0.1, 0.2, 0.3))
274+
261275
nt.assert_array_almost_equal(
262-
UnitQuaternion.RPY([0.1, 0.2, 0.3]).R, rpy2r(0.1, 0.2, 0.3)
276+
UnitQuaternion.RPY([10, 20, 30], unit="deg").R,
277+
rpy2r(10, 20, 30, unit="deg"),
263278
)
264-
265279
nt.assert_array_almost_equal(
266-
UnitQuaternion.Eul([0.1, 0.2, 0.3]).R, eul2r(0.1, 0.2, 0.3)
280+
UnitQuaternion.RPY([0.1, 0.2, 0.3], order="xyz").R,
281+
rpy2r(0.1, 0.2, 0.3, order="xyz"),
282+
)
283+
284+
angles = np.array(
285+
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]]
267286
)
287+
q = UnitQuaternion.RPY(angles)
288+
self.assertIsInstance(q, UnitQuaternion)
289+
self.assertEqual(len(q), 4)
290+
for i in range(4):
291+
nt.assert_array_almost_equal(q[i].R, rpy2r(angles[i, :]))
292+
293+
q = UnitQuaternion.RPY(angles, order="xyz")
294+
self.assertIsInstance(q, UnitQuaternion)
295+
self.assertEqual(len(q), 4)
296+
for i in range(4):
297+
nt.assert_array_almost_equal(q[i].R, rpy2r(angles[i, :], order="xyz"))
268298

299+
angles *= 10
300+
q = UnitQuaternion.RPY(angles, unit="deg")
301+
self.assertIsInstance(q, UnitQuaternion)
302+
self.assertEqual(len(q), 4)
303+
for i in range(4):
304+
nt.assert_array_almost_equal(q[i].R, rpy2r(angles[i, :], unit="deg"))
305+
306+
def test_constructor_Eul(self):
269307
nt.assert_array_almost_equal(
270-
UnitQuaternion.RPY([10, 20, 30], unit="deg").R,
271-
rpy2r(10, 20, 30, unit="deg"),
308+
UnitQuaternion.Eul([0.1, 0.2, 0.3]).R, eul2r(0.1, 0.2, 0.3)
272309
)
273310

274311
nt.assert_array_almost_equal(
275312
UnitQuaternion.Eul([10, 20, 30], unit="deg").R,
276313
eul2r(10, 20, 30, unit="deg"),
277314
)
278315

316+
angles = np.array(
317+
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]]
318+
)
319+
q = UnitQuaternion.Eul(angles)
320+
self.assertIsInstance(q, UnitQuaternion)
321+
self.assertEqual(len(q), 4)
322+
for i in range(4):
323+
nt.assert_array_almost_equal(q[i].R, eul2r(angles[i, :]))
324+
325+
angles *= 10
326+
q = UnitQuaternion.Eul(angles, unit="deg")
327+
self.assertIsInstance(q, UnitQuaternion)
328+
self.assertEqual(len(q), 4)
329+
for i in range(4):
330+
nt.assert_array_almost_equal(q[i].R, eul2r(angles[i, :], unit="deg"))
331+
332+
def test_constructor_AngVec(self):
279333
# (theta, v)
280334
th = 0.2
281335
v = unitvec([1, 2, 3])
@@ -286,6 +340,7 @@ def test_staticconstructors(self):
286340
)
287341
nt.assert_array_almost_equal(UnitQuaternion.AngVec(th, -v).R, angvec2r(th, -v))
288342

343+
def test_constructor_EulerVec(self):
289344
# (theta, v)
290345
th = 0.2
291346
v = unitvec([1, 2, 3])
@@ -830,6 +885,20 @@ def test_log(self):
830885
nt.assert_array_almost_equal(q1.log().exp(), q1)
831886
nt.assert_array_almost_equal(q2.log().exp(), q2)
832887

888+
q = Quaternion([q1, q2, q1, q2])
889+
qlog = q.log()
890+
nt.assert_array_almost_equal(qlog[0].exp(), q1)
891+
nt.assert_array_almost_equal(qlog[1].exp(), q2)
892+
nt.assert_array_almost_equal(qlog[2].exp(), q1)
893+
nt.assert_array_almost_equal(qlog[3].exp(), q2)
894+
895+
q = UnitQuaternion() # identity
896+
qlog = q.log()
897+
nt.assert_array_almost_equal(qlog.vec, np.zeros(4))
898+
qq = qlog.exp()
899+
self.assertIsInstance(qq, UnitQuaternion)
900+
nt.assert_array_almost_equal(qq.vec, np.r_[1, 0, 0, 0])
901+
833902
def test_concat(self):
834903
u = Quaternion()
835904
uu = Quaternion([u, u, u, u])
@@ -1018,6 +1087,27 @@ def test_miscellany(self):
10181087
nt.assert_equal(q.inner(q), q.norm() ** 2)
10191088
nt.assert_equal(q.inner(u), np.dot(q.vec, u.vec))
10201089

1090+
# def test_mean(self):
1091+
# rpy = np.ones((100, 1)) @ np.c_[0.1, 0.2, 0.3]
1092+
# q = UnitQuaternion.RPY(rpy)
1093+
# self.assertEqual(len(q), 100)
1094+
# m = q.mean()
1095+
# self.assertIsInstance(m, UnitQuaternion)
1096+
# nt.assert_array_almost_equal(m.vec, q[0].vec)
1097+
1098+
# # range of angles, mean should be the middle one, index=25
1099+
# q = UnitQuaternion.Rz(np.linspace(start=0.3, stop=0.7, num=51))
1100+
# m = q.mean()
1101+
# self.assertIsInstance(m, UnitQuaternion)
1102+
# nt.assert_array_almost_equal(m.vec, q[25].vec)
1103+
1104+
# # now add noise
1105+
# rng = np.random.default_rng(0) # reproducible random numbers
1106+
# rpy += rng.normal(scale=0.1, size=(100, 3))
1107+
# q = UnitQuaternion.RPY(rpy)
1108+
# m = q.mean()
1109+
# nt.assert_array_almost_equal(m.vec, q.RPY(0.1, 0.2, 0.3).vec)
1110+
10211111

10221112
# ---------------------------------------------------------------------------------------#
10231113
if __name__ == "__main__":

0 commit comments

Comments
 (0)