Intelegentny_Pszczelarz/.venv/Lib/site-packages/jax/_src/scipy/spatial/transform.py

389 lines
15 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import re
import typing
import scipy.spatial.transform
import jax
import jax.numpy as jnp
from jax._src.numpy.util import _wraps
@_wraps(scipy.spatial.transform.Rotation)
class Rotation(typing.NamedTuple):
  """Rotation in 3 dimensions.

  A rotation is stored as a quaternion in scalar-last ``(x, y, z, w)`` order
  (the identity is ``[0, 0, 0, 1]``). ``quat`` is 1-dimensional for a single
  rotation; otherwise its leading axis indexes the individual rotations.
  """
  # Quaternion(s) backing this object; layout is described in the class docstring.
  quat: jax.Array

  @classmethod
  def concatenate(cls, rotations: typing.Sequence):
    """Concatenate a sequence of `Rotation` objects.

    Single rotations are promoted to a batch of one before joining, so the
    result always holds one quaternion per row (this mirrors SciPy).
    """
    return cls(jnp.concatenate([jnp.atleast_2d(rotation.quat) for rotation in rotations]))

  @classmethod
  def from_euler(cls, seq: str, angles: jax.Array, degrees: bool = False):
    """Initialize from Euler angles.

    Args:
      seq: one to three axis letters. All upper case (e.g. ``'XYZ'``) selects
        intrinsic rotations, all lower case (e.g. ``'xyz'``) extrinsic ones.
      angles: Euler angles, one per axis in ``seq``.
      degrees: if True, ``angles`` are interpreted as degrees, else radians.

    Raises:
      ValueError: if ``seq`` is empty or longer than three characters, uses
        letters outside one of the two axis sets, or repeats an axis
        consecutively.
    """
    num_axes = len(seq)
    if num_axes < 1 or num_axes > 3:
      raise ValueError("Expected axis specification to be a non-empty "
                       "string of upto 3 characters, got {}".format(seq))
    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from ['x', 'y', "
                       "'z'] or ['X', 'Y', 'Z'], got {}".format(seq))
    if any(seq[i] == seq[i+1] for i in range(num_axes - 1)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))
    angles = jnp.atleast_1d(angles)
    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    return cls(_elementary_quat_compose(angles, axes, intrinsic, degrees))

  @classmethod
  def from_matrix(cls, matrix: jax.Array):
    """Initialize from rotation matrix."""
    return cls(_from_matrix(matrix))

  @classmethod
  def from_mrp(cls, mrp: jax.Array):
    """Initialize from Modified Rodrigues Parameters (MRPs)."""
    return cls(_from_mrp(mrp))

  @classmethod
  def from_quat(cls, quat: jax.Array):
    """Initialize from quaternions; the input is normalized to unit length."""
    return cls(_normalize_quaternion(quat))

  @classmethod
  def from_rotvec(cls, rotvec: jax.Array, degrees: bool = False):
    """Initialize from rotation vectors.

    Args:
      rotvec: rotation vector(s).
      degrees: if True, the vector magnitudes are in degrees, else radians.
    """
    return cls(_from_rotvec(rotvec, degrees))

  @classmethod
  def identity(cls, num: typing.Optional[int] = None, dtype=float):
    """Get identity rotation(s).

    Args:
      num: number of identity rotations to create. ``None`` (the default)
        yields a single rotation; an integer yields a batch whose quaternion
        array has shape ``(num, 4)``.
      dtype: dtype of the quaternion array.
    """
    quat = jnp.array([0., 0., 0., 1.], dtype=dtype)
    if num is not None:
      # Stack `num` copies so that the result behaves as a batch of rotations.
      quat = jnp.tile(quat, (num, 1))
    return cls(quat)

  @classmethod
  def random(cls, random_key: jax.Array, num: typing.Optional[int] = None):
    """Generate uniformly distributed rotations (not implemented)."""
    # Need to implement scipy.stats.special_ortho_group for this to work...
    raise NotImplementedError

  def __getitem__(self, indexer):
    """Extract rotation(s) at given index(es) from object.

    Raises:
      TypeError: if this object holds a single rotation.
    """
    if self.single:
      raise TypeError("Single rotation is not subscriptable.")
    return Rotation(self.quat[indexer])

  def __len__(self):
    """Number of rotations contained in this object.

    Raises:
      TypeError: if this object holds a single rotation.
    """
    if self.single:
      raise TypeError('Single rotation has no len().')
    return self.quat.shape[0]

  def __mul__(self, other):
    """Compose this rotation with the other; the result is re-normalized."""
    return Rotation.from_quat(_compose_quat(self.quat, other.quat))

  def apply(self, vectors: jax.Array, inverse: bool = False) -> jax.Array:
    """Apply this rotation to one or more vectors.

    Args:
      vectors: vector(s) to rotate.
      inverse: if True, the transposed rotation matrix (i.e. the inverse
        rotation) is applied instead.
    """
    return _apply(self.as_matrix(), vectors, inverse)

  def as_euler(self, seq: str, degrees: bool = False):
    """Represent as Euler angles.

    Args:
      seq: exactly three axis letters, all upper case (intrinsic) or all
        lower case (extrinsic), with no axis repeated consecutively.
      degrees: if True, the returned angles are in degrees, else radians.

    Raises:
      ValueError: if ``seq`` violates the constraints above.
    """
    if len(seq) != 3:
      raise ValueError("Expected 3 axes, got {}.".format(seq))
    intrinsic = (re.match(r'^[XYZ]{1,3}$', seq) is not None)
    extrinsic = (re.match(r'^[xyz]{1,3}$', seq) is not None)
    if not (intrinsic or extrinsic):
      raise ValueError("Expected axes from `seq` to be from "
                       "['x', 'y', 'z'] or ['X', 'Y', 'Z'], "
                       "got {}".format(seq))
    if any(seq[i] == seq[i+1] for i in range(2)):
      raise ValueError("Expected consecutive axes to be different, "
                       "got {}".format(seq))
    axes = jnp.array([_elementary_basis_index(x) for x in seq.lower()])
    return _compute_euler_from_quat(self.quat, axes, extrinsic, degrees)

  def as_matrix(self) -> jax.Array:
    """Represent as rotation matrix."""
    return _as_matrix(self.quat)

  def as_mrp(self) -> jax.Array:
    """Represent as Modified Rodrigues Parameters (MRPs)."""
    return _as_mrp(self.quat)

  def as_rotvec(self, degrees: bool = False) -> jax.Array:
    """Represent as rotation vectors (magnitudes in degrees if `degrees`)."""
    return _as_rotvec(self.quat, degrees)

  def as_quat(self) -> jax.Array:
    """Represent as quaternions (returns the stored array unchanged)."""
    return self.quat

  def inv(self):
    """Invert this rotation."""
    return Rotation(_inv(self.quat))

  def magnitude(self) -> jax.Array:
    """Get the magnitude(s) of the rotation(s)."""
    return _magnitude(self.quat)

  def mean(self, weights: typing.Optional[jax.Array] = None):
    """Get the mean of the rotations.

    Args:
      weights: optional 1-D array with one weight per rotation. When omitted
        every rotation gets weight 1.

    Returns:
      A single `Rotation` whose quaternion is the eigenvector belonging to
      the largest eigenvalue of the weighted sum of quaternion outer products.

    Raises:
      ValueError: if ``weights`` is not 1-dimensional or its length differs
        from the number of rotations.
    """
    # Select the weights with plain Python control flow: `weights is None` is a
    # static fact, and routing it through `jnp.where` would both convert `None`
    # to an array and broadcast wrongly sized weights before they are checked.
    if weights is None:
      weights = jnp.ones(self.quat.shape[0], dtype=self.quat.dtype)
    else:
      weights = jnp.asarray(weights, dtype=self.quat.dtype)
    if weights.ndim != 1:
      raise ValueError("Expected `weights` to be 1 dimensional, got "
                       "shape {}.".format(weights.shape))
    if weights.shape[0] != len(self):
      raise ValueError("Expected `weights` to have number of values "
                       "equal to number of rotations, got "
                       "{} values and {} rotations.".format(weights.shape[0], len(self)))
    K = jnp.dot(weights[jnp.newaxis, :] * self.quat.T, self.quat)
    # `eigh` orders eigenvalues ascending, so the last column is the dominant one.
    _, v = jnp.linalg.eigh(K)
    return Rotation(v[:, -1])

  @property
  def single(self) -> bool:
    """Whether this instance represents a single rotation."""
    return self.quat.ndim == 1
@_wraps(scipy.spatial.transform.Slerp)
class Slerp(typing.NamedTuple):
  """Spherical Linear Interpolation of Rotations.

  Instances are built with `Slerp.init` and then called with the times at
  which interpolated rotations are wanted.
  """
  # Key-frame timestamps, as passed to `init`.
  times: jnp.ndarray
  # Differences between consecutive timestamps.
  timedelta: jnp.ndarray
  # Start rotation of every interval (all key frames except the last one).
  rotations: Rotation
  # Rotation vector taking each interval's start rotation to its end rotation.
  rotvecs: jnp.ndarray

  @classmethod
  def init(cls, times: jax.Array, rotations: Rotation):
    """Build an interpolator from key-frame times and rotations.

    Args:
      times: 1-D array of timestamps, one per rotation.
      rotations: `Rotation` holding at least two rotations.

    Raises:
      TypeError: if `rotations` is not a `Rotation`.
      ValueError: if fewer than two rotations are given, `times` is not
        1-dimensional, or the number of timestamps and rotations differ.
    """
    if not isinstance(rotations, Rotation):
      raise TypeError("`rotations` must be a `Rotation` instance.")
    if rotations.single or len(rotations) == 1:
      raise ValueError("`rotations` must be a sequence of at least 2 rotations.")
    times = jnp.asarray(times, dtype=rotations.quat.dtype)
    if times.ndim != 1:
      raise ValueError("Expected times to be specified in a 1 "
                       "dimensional array, got {} "
                       "dimensions.".format(times.ndim))
    if times.shape[0] != len(rotations):
      raise ValueError("Expected number of rotations to be equal to "
                       "number of timestamps given, got {} rotations "
                       "and {} timestamps.".format(len(rotations), times.shape[0]))
    # The timestamps are not checked for being strictly increasing: testing
    # `jnp.any(timedelta <= 0)` in Python causes a concretization error.
    intervals = jnp.diff(times)
    key_quats = rotations.as_quat()
    interval_starts = Rotation(key_quats[:-1])
    interval_ends = Rotation(key_quats[1:])
    relative = interval_starts.inv() * interval_ends
    return cls(
        times=times,
        timedelta=intervals,
        rotations=interval_starts,
        rotvecs=relative.as_rotvec())

  def __call__(self, times: jax.Array):
    """Interpolate rotations.

    Args:
      times: scalar or 1-D array of query times.

    Returns:
      A single `Rotation` for a scalar query, otherwise one rotation per
      query time.

    Raises:
      ValueError: if `times` has more than one dimension.
    """
    query = jnp.asarray(times, dtype=self.times.dtype)
    if query.ndim > 1:
      raise ValueError("`times` must be at most 1-dimensional.")
    scalar_query = query.ndim == 0
    query = jnp.atleast_1d(query)
    # Index of the interval used for each query, never below the first one.
    segment = jnp.maximum(jnp.searchsorted(self.times, query) - 1, 0)
    # Position of each query inside its interval, as a fraction of its length.
    fraction = (query - self.times[segment]) / self.timedelta[segment]
    partial_step = Rotation.from_rotvec(self.rotvecs[segment] * fraction[:, None])
    interpolated = self.rotations[segment] * partial_step
    return interpolated[0] if scalar_query else interpolated
@functools.partial(jnp.vectorize, signature='(m,m),(m),()->(m)')
def _apply(matrix: jax.Array, vector: jax.Array, inverse: bool) -> jax.Array:
  """Multiply `vector` by `matrix`, or by its transpose when `inverse` is set."""
  operator = jnp.where(inverse, matrix.T, matrix)
  return operator @ vector
@functools.partial(jnp.vectorize, signature='(m)->(n,n)')
def _as_matrix(quat: jax.Array) -> jax.Array:
  """Build the 3x3 rotation matrix of a scalar-last quaternion (x, y, z, w)."""
  x, y, z, w = quat[0], quat[1], quat[2], quat[3]
  # Squares and pairwise products that the matrix entries are assembled from.
  xx, yy, zz, ww = x * x, y * y, z * z, w * w
  xy, xz, yz = x * y, x * z, y * z
  xw, yw, zw = x * w, y * w, z * w
  first_row = [xx - yy - zz + ww, 2 * (xy - zw), 2 * (xz + yw)]
  second_row = [2 * (xy + zw), -xx + yy - zz + ww, 2 * (yz - xw)]
  third_row = [2 * (xz - yw), 2 * (yz + xw), -xx - yy + zz + ww]
  return jnp.array([first_row, second_row, third_row])
@functools.partial(jnp.vectorize, signature='(m)->(n)')
def _as_mrp(quat: jax.Array) -> jax.Array:
  """Convert a scalar-last quaternion to Modified Rodrigues Parameters."""
  # Flip quaternions with a negative scalar part so the denominator is >= 1.
  flip = jnp.where(quat[3] < 0, -1., 1.)
  numerator = flip * quat[:3]
  return numerator / (1. + flip * quat[3])
@functools.partial(jnp.vectorize, signature='(m),()->(n)')
def _as_rotvec(quat: jax.Array, degrees: bool) -> jax.Array:
  """Convert a scalar-last quaternion to a rotation vector.

  The result is the quaternion's vector part scaled by ``angle / sin(angle / 2)``;
  when `degrees` is set the scale is converted from radians to degrees.
  """
  # Use the representative with w >= 0 so that 0 <= angle <= pi.
  quat = jnp.where(quat[3] < 0, -quat, quat)
  vector_part = quat[:3]
  angle = 2. * jnp.arctan2(_vector_norm(vector_part), quat[3])
  angle_sq = angle * angle
  # Series expansion of angle / sin(angle / 2), used close to zero.
  series_scale = 2 + angle_sq / 12 + 7 * angle_sq * angle_sq / 2880
  exact_scale = angle / jnp.sin(angle / 2)
  scale = jnp.where(angle <= 1e-3, series_scale, exact_scale)
  scale = jnp.where(degrees, jnp.rad2deg(scale), scale)
  return scale * jnp.array(vector_part)
@functools.partial(jnp.vectorize, signature='(n),(n)->(n)')
def _compose_quat(p: jax.Array, q: jax.Array) -> jax.Array:
  """Return the quaternion product ``p * q`` for scalar-last quaternions."""
  p_vec, p_w = p[:3], p[3]
  q_vec, q_w = q[:3], q[3]
  # Vector part: both scalar-weighted vector parts plus their cross product.
  vector = p_w * q_vec + q_w * p_vec + jnp.cross(p_vec, q_vec)
  # Scalar part: product of the scalars minus the dot product of the vectors.
  scalar = p_w * q_w - p[0] * q[0] - p[1] * q[1] - p[2] * q[2]
  return jnp.array([vector[0], vector[1], vector[2], scalar])
@functools.partial(jnp.vectorize, signature='(m),(l),(),()->(n)')
def _compute_euler_from_quat(quat: jax.Array, axes: jax.Array, extrinsic: bool, degrees: bool) -> jax.Array:
  """Convert one quaternion into three Euler angles.

  Args:
    quat: quaternion; component 3 is used as the scalar part and components
      0-2 are addressed through the axis indices.
    axes: three axis indices in the order of the requested sequence.
    extrinsic: selects which output slots receive the first and third angle
      and whether `axes` is used as given or reversed.
    degrees: if true, the result is converted from radians to degrees.

  Returns:
    Array of three angles, wrapped to ``[-pi, pi)`` before the optional
    conversion to degrees.
  """
  # Output slots of the first/third angle; they swap for the non-extrinsic case.
  angle_first = jnp.where(extrinsic, 0, 2)
  angle_third = jnp.where(extrinsic, 2, 0)
  # A non-extrinsic sequence is processed with its axis order reversed.
  axes = jnp.where(extrinsic, axes, axes[::-1])
  i = axes[0]
  j = axes[1]
  k = axes[2]
  # Sequences whose first and third axis coincide (e.g. "zxz").
  symmetric = i == k
  # For symmetric sequences use the one axis index not yet present (0+1+2 == 3).
  k = jnp.where(symmetric, 3 - i - j, k)
  # +1 when (i, j, k) is a cyclic permutation of (0, 1, 2), -1 otherwise.
  sign = jnp.array((i - j) * (j - k) * (k - i) // 2, dtype=quat.dtype)
  # Tolerance for treating the second angle as exactly 0 or pi.
  eps = 1e-7
  a = jnp.where(symmetric, quat[3], quat[3] - quat[j])
  b = jnp.where(symmetric, quat[i], quat[i] + quat[k] * sign)
  c = jnp.where(symmetric, quat[j], quat[j] + quat[3])
  d = jnp.where(symmetric, quat[k] * sign, quat[k] * sign - quat[i])
  # NOTE(review): slot 2 is not written in the degenerate cases below, so its
  # value there is whatever `jnp.empty` produced - confirm this is intended.
  angles = jnp.empty(3, dtype=quat.dtype)
  angles = angles.at[1].set(2 * jnp.arctan2(jnp.hypot(c, d), jnp.hypot(a, b)))
  # case: 0 = regular, 1 = second angle within eps of 0, 2 = within eps of pi.
  case = jnp.where(jnp.abs(angles[1] - jnp.pi) <= eps, 2, 0)
  case = jnp.where(jnp.abs(angles[1]) <= eps, 1, case)
  half_sum = jnp.arctan2(b, a)
  half_diff = jnp.arctan2(d, c)
  # Slot 0 first receives the degenerate-case value; in the regular case the
  # two statements after it overwrite the first/third slots.
  angles = angles.at[0].set(jnp.where(case == 1, 2 * half_sum, 2 * half_diff * jnp.where(extrinsic, -1, 1))) # any degenerate case
  angles = angles.at[angle_first].set(jnp.where(case == 0, half_sum - half_diff, angles[angle_first]))
  angles = angles.at[angle_third].set(jnp.where(case == 0, half_sum + half_diff, angles[angle_third]))
  # Corrections that only apply to sequences with three distinct axes.
  angles = angles.at[angle_third].set(jnp.where(symmetric, angles[angle_third], angles[angle_third] * sign))
  angles = angles.at[1].set(jnp.where(symmetric, angles[1], angles[1] - jnp.pi / 2))
  # Wrap every angle into [-pi, pi).
  angles = (angles + jnp.pi) % (2 * jnp.pi) - jnp.pi
  return jnp.where(degrees, jnp.rad2deg(angles), angles)
def _elementary_basis_index(axis: str) -> int:
  """Map a lower-case axis letter to its index: 'x' -> 0, 'y' -> 1, 'z' -> 2.

  Raises:
    ValueError: if `axis` is none of the three letters.
  """
  for index, letter in enumerate(('x', 'y', 'z')):
    if axis == letter:
      return index
  raise ValueError("Expected axis to be from ['x', 'y', 'z'], got {}".format(axis))
@functools.partial(jnp.vectorize, signature=('(m),(m),(),()->(n)'))
def _elementary_quat_compose(angles: jax.Array, axes: jax.Array, intrinsic: bool, degrees: bool) -> jax.Array:
  """Compose single-axis rotations into one quaternion.

  Args:
    angles: rotation angle for each entry of `axes`.
    axes: axis indices of the elementary rotations, in sequence order.
    intrinsic: if true each new rotation is multiplied on the right of the
      running product, otherwise on the left.
    degrees: if true, `angles` are converted from degrees to radians first.
  """
  angles = jnp.where(degrees, jnp.deg2rad(angles), angles)
  composed = _make_elementary_quat(axes[0], angles[0])
  for position in range(1, axes.shape[0]):
    step = _make_elementary_quat(axes[position], angles[position])
    on_the_right = _compose_quat(composed, step)
    on_the_left = _compose_quat(step, composed)
    composed = jnp.where(intrinsic, on_the_right, on_the_left)
  return composed
@functools.partial(jnp.vectorize, signature=('(m),()->(n)'))
def _from_rotvec(rotvec: jax.Array, degrees: bool) -> jax.Array:
  """Convert a rotation vector to a scalar-last quaternion.

  Args:
    rotvec: rotation vector; its norm is used as the rotation angle.
    degrees: if true, `rotvec` is converted from degrees to radians first.

  Returns:
    Quaternion ``[scale * rotvec, cos(angle / 2)]`` with
    ``scale = sin(angle / 2) / angle``.
  """
  rotvec = jnp.where(degrees, jnp.deg2rad(rotvec), rotvec)
  angle = _vector_norm(rotvec)
  angle2 = angle * angle
  # Series expansion of sin(angle / 2) / angle, used close to zero where the
  # closed form divides by a vanishing angle.
  small_scale = 0.5 - angle2 / 48 + angle2 * angle2 / 3840
  large_scale = jnp.sin(angle / 2) / angle
  scale = jnp.where(angle <= 1e-3, small_scale, large_scale)
  return jnp.hstack([scale * rotvec, jnp.cos(angle / 2)])
@functools.partial(jnp.vectorize, signature=('(m,m)->(n)'))
def _from_matrix(matrix: jax.Array) -> jax.Array:
  """Convert a 3x3 rotation matrix to a normalized scalar-last quaternion.

  The largest of the three diagonal entries and the trace decides which
  quaternion component the others are derived from.
  """
  trace = matrix[0, 0] + matrix[1, 1] + matrix[2, 2]
  candidates = jnp.array([matrix[0, 0], matrix[1, 1], matrix[2, 2], trace], dtype=matrix.dtype)
  pivot = jnp.argmax(candidates)

  # Candidate used when one of the diagonal entries is the largest value.
  i = pivot
  j = (i + 1) % 3
  k = (j + 1) % 3
  from_diagonal = jnp.empty(4, dtype=matrix.dtype)
  from_diagonal = from_diagonal.at[i].set(1 - candidates[3] + 2 * matrix[i, i])
  from_diagonal = from_diagonal.at[j].set(matrix[j, i] + matrix[i, j])
  from_diagonal = from_diagonal.at[k].set(matrix[k, i] + matrix[i, k])
  from_diagonal = from_diagonal.at[3].set(matrix[k, j] - matrix[j, k])

  # Candidate used when the trace is the largest value.
  from_trace = jnp.array([matrix[2, 1] - matrix[1, 2],
                          matrix[0, 2] - matrix[2, 0],
                          matrix[1, 0] - matrix[0, 1],
                          1 + candidates[3]], dtype=matrix.dtype)

  selected = jnp.where(pivot != 3, from_diagonal, from_trace)
  return _normalize_quaternion(selected)
@functools.partial(jnp.vectorize, signature='(m)->(n)')
def _from_mrp(mrp: jax.Array) -> jax.Array:
  """Convert Modified Rodrigues Parameters to a scalar-last quaternion."""
  norm_sq_plus_one = jnp.dot(mrp, mrp) + 1
  vector_part = 2 * mrp[:3]
  # Equals 1 - |mrp|^2.
  scalar_part = 2 - norm_sq_plus_one
  return jnp.hstack([vector_part, scalar_part]) / norm_sq_plus_one
@functools.partial(jnp.vectorize, signature='(n)->(n)')
def _inv(quat: jax.Array) -> jax.Array:
  """Invert a scalar-last quaternion by negating its scalar component."""
  negated_scalar = -quat[3]
  return quat.at[3].set(negated_scalar)
@functools.partial(jnp.vectorize, signature='(n)->()')
def _magnitude(quat: jax.Array) -> jax.Array:
  """Return the rotation angle of a scalar-last quaternion.

  The absolute value of the scalar part is used, so `quat` and `-quat` give
  the same result.
  """
  vector_length = _vector_norm(quat[:3])
  scalar_size = jnp.abs(quat[3])
  return 2. * jnp.arctan2(vector_length, scalar_size)
@functools.partial(jnp.vectorize, signature='(),()->(n)')
def _make_elementary_quat(axis: int, angle: jax.Array) -> jax.Array:
  """Build the scalar-last quaternion for a rotation by `angle` about one axis.

  `axis` is the index (0, 1 or 2) of the component that receives
  ``sin(angle / 2)``; the scalar component receives ``cos(angle / 2)``.
  """
  half_angle = angle / 2.
  blank = jnp.zeros(4, dtype=angle.dtype)
  with_scalar = blank.at[3].set(jnp.cos(half_angle))
  return with_scalar.at[axis].set(jnp.sin(half_angle))
@functools.partial(jnp.vectorize, signature='(n)->(n)')
def _normalize_quaternion(quat: jax.Array) -> jax.Array:
  """Scale a quaternion to unit length by dividing it by its norm."""
  length = _vector_norm(quat)
  return quat / length
@functools.partial(jnp.vectorize, signature='(n)->()')
def _vector_norm(vector: jax.Array) -> jax.Array:
  """Return the Euclidean length of `vector` (square root of its self dot product)."""
  squared_length = jnp.dot(vector, vector)
  return jnp.sqrt(squared_length)