Bi-invariant Geodesic Regression
Many real-world phenomena are best described through continuous transformations, such as shape variations in medical imaging or movements of robotic systems. To model the variability in such datasets, statistical analysis must be performed on Lie groups, that is, on mathematical structures that are simultaneously smooth manifolds and groups. Since Lie groups inherently represent the symmetries present in the data, it is desirable to use statistical methods that respect these symmetries. (In technical terms: the result of a statistical method should be equivariant or invariant under left and right translations of the data.) This ensures that analyses are not influenced by arbitrary choices, such as that of a reference coordinate system.
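For concreteness, the symmetry requirement can be written as follows (the notation is ours, not taken from Morphomatics). Let $(t_i, y_i)_{i=1}^n$ be data with $t_i \in \mathbb{R}$ and $y_i \in G$, let $\widehat{\gamma}[\cdot]$ denote the curve returned by a regression estimator, and let $S[\cdot]$ be a scalar statistic such as $\textnormal{R}^2$. Equivariance of the estimator and invariance of the statistic mean

$$
\widehat{\gamma}\big[(t_i,\, h\,y_i\,f)_{i=1}^n\big](s) = h\,\widehat{\gamma}\big[(t_i,\, y_i)_{i=1}^n\big](s)\,f
\qquad \text{and} \qquad
S\big[(t_i,\, h\,y_i\,f)_{i=1}^n\big] = S\big[(t_i,\, y_i)_{i=1}^n\big]
$$

for all $h, f \in G$ and all $s$. Setting $h$ to the identity gives equivariance under right translations, which is the property tested in this tutorial.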
Applying this idea to geodesic regression (see the tutorial on Riemannian regression), Morphomatics implements bi-invariant geodesic regression. For the definition, see Bi-invariant Geodesic Regression with Data from the Osteoarthritis Initiative. In this tutorial, we demonstrate the difference between bi-invariant and Riemannian geodesic regression when applied to data from the group SE(3) of rigid-body motions. For Riemannian regression, we use the product of the Frobenius inner product (on rotations) and the Euclidean inner product (on translations) as the Riemannian metric.
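To make the product metric concrete, here is a minimal NumPy sketch of the inner product at the identity. It assumes that tangent vectors are stored as 4x4 matrices $\begin{pmatrix}\Omega & v\\ 0 & 0\end{pmatrix}$ with skew-symmetric $\Omega$ and $v \in \mathbb{R}^3$, and that no scaling factor is applied to the Frobenius term (some conventions use a factor 1/2). The helper names `hat`, `se3_vector`, and `product_inner` are hypothetical and not part of Morphomatics.

```python
import numpy as np

def hat(omega: np.ndarray) -> np.ndarray:
    """Map a vector in R^3 to the corresponding skew-symmetric 3x3 matrix."""
    wx, wy, wz = omega
    return np.array([[0.0, -wz, wy],
                     [wz, 0.0, -wx],
                     [-wy, wx, 0.0]])

def se3_vector(omega, v) -> np.ndarray:
    """Assemble a 4x4 Lie algebra element from a rotational part omega and a translational part v."""
    X = np.zeros((4, 4))
    X[:3, :3] = hat(np.asarray(omega, dtype=float))
    X[:3, 3] = v
    return X

def product_inner(X: np.ndarray, Y: np.ndarray) -> float:
    """Frobenius inner product of the rotational blocks plus Euclidean inner product of the translations."""
    frob = np.sum(X[:3, :3] * Y[:3, :3])  # tr(A^T B)
    eucl = X[:3, 3] @ Y[:3, 3]
    return float(frob + eucl)

X = se3_vector([0.0, 0.0, 1.0], [1.0, 2.0, 0.0])
print(product_inner(X, X))  # 2.0 (rotation) + 5.0 (translation) = 7.0
```

Rotational and translational directions are orthogonal under this metric, which is what makes it a product metric.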
We start by creating random samples along a Riemannian geodesic.
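Under the product metric, geodesics have a simple closed form: the rotation follows $R_0 \exp(s \log(R_0^\top R_1))$ and the translation follows a straight line. The sketch below samples such a geodesic with NumPy and SciPy only, without Morphomatics and without the noise step. It assumes that rigid motions are 4x4 homogeneous matrices and that the relative rotation angle is below $\pi$ (so the minimal geodesic is unique); `rigid_motion` and `product_geopoint` are hypothetical helpers, and the random rotations are not uniformly distributed.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def rigid_motion(rotvec: np.ndarray, p: np.ndarray) -> np.ndarray:
    """4x4 homogeneous matrix of x -> R x + p, with R given as a rotation vector."""
    g = np.eye(4)
    g[:3, :3] = Rotation.from_rotvec(rotvec).as_matrix()
    g[:3, 3] = p
    return g

def product_geopoint(g0: np.ndarray, g1: np.ndarray, s: float) -> np.ndarray:
    """Point at parameter s on the product-metric geodesic from g0 to g1."""
    R0, p0 = g0[:3, :3], g0[:3, 3]
    R1, p1 = g1[:3, :3], g1[:3, 3]
    # rotational part: R0 exp(s log(R0^T R1)), computed via rotation vectors
    rel = Rotation.from_matrix(R0.T @ R1).as_rotvec()
    g = np.eye(4)
    g[:3, :3] = R0 @ Rotation.from_rotvec(s * rel).as_matrix()
    # translational part: straight line in R^3
    g[:3, 3] = (1 - s) * p0 + s * p1
    return g

rng = np.random.default_rng(3)
g0 = rigid_motion(rng.normal(size=3), rng.normal(size=3))
g1 = rigid_motion(rng.normal(size=3), rng.normal(size=3))
t = np.linspace(0, 1, 10)
gam = np.stack([product_geopoint(g0, g1, s) for s in t])  # shape (10, 4, 4)
```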
%env JAX_PLATFORM_NAME=cpu
import jax
import jax.numpy as jnp
import jax.random as rnd
import numpy as np
from morphomatics.manifold import SE3
from morphomatics.stats import RiemannianRegression
import pyvista as pv
G_riemannian = SE3(structure='CanonicalRiemannian')
key = rnd.key(3)
# random start and end points of geodesic
g0, g1 = jax.vmap(G_riemannian.rand)(rnd.split(key, 2))
# random samples along the geodesic
n_sample = 10
t = jnp.linspace(0, 1, n_sample)
gam = jax.vmap(G_riemannian.connec.geopoint, (None, None, 0))(g0, g1, t)
# noise vectors (represented in Lie algebra)
noise = jax.vmap(G_riemannian.group.coords_inv)(rnd.normal(key, (n_sample, G_riemannian.dim)))
# disturb data points by noise
Y = jax.vmap(G_riemannian.connec.exp)(gam, 0.1 * noise)
Now, we fit a Riemannian geodesic to the data points. The fitted geodesic should approximate the geodesic used to sample the data. We assess the quality of the fit with the $\textnormal{R}^2$-statistic, a number between 0 and 1 that measures how much of the variability of the data is explained by the fitted geodesic (1 being a perfect fit).
Then, we right-translate the data by 100 random group elements and perform a regression for each translated dataset.
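For reference, the $\textnormal{R}^2$-statistic in geodesic regression is usually defined as one minus the ratio of unexplained to total variance: $\textnormal{R}^2 = 1 - \sum_i d(\gamma(t_i), y_i)^2 / \sum_i d(\bar{y}, y_i)^2$, where $\bar{y}$ is a mean of the data. We assume Morphomatics' `R2statistic` follows this variance-ratio idea with distances and a mean on the group; we have not checked its implementation. The sketch below shows only the flat special case in $\mathbb{R}^3$, where geodesics are straight lines and the fit is ordinary least squares; `r2_statistic` is a hypothetical helper.

```python
import numpy as np

def r2_statistic(residual_dists: np.ndarray, dists_to_mean: np.ndarray) -> float:
    """R^2 = 1 - (unexplained variance) / (total variance)."""
    return 1.0 - np.sum(residual_dists**2) / np.sum(dists_to_mean**2)

# flat special case: noisy points along a straight line in R^3
rng = np.random.default_rng(0)
t = np.linspace(0, 1, 10)
p0, p1 = np.array([0.0, 0.0, 0.0]), np.array([1.0, 2.0, 3.0])
Y = (1 - t)[:, None] * p0 + t[:, None] * p1 + 0.1 * rng.normal(size=(10, 3))

# least-squares line y(t) = a + b t, fitted for all coordinates at once
A = np.stack([np.ones_like(t), t], axis=1)
coef, *_ = np.linalg.lstsq(A, Y, rcond=None)
fitted = A @ coef

r2 = r2_statistic(np.linalg.norm(fitted - Y, axis=1),
                  np.linalg.norm(Y.mean(axis=0) - Y, axis=1))
print(f"R2 = {r2:.2f}")
```

Because the fitted line can always fall back to the constant mean, the residual variance never exceeds the total variance, so the value stays between 0 and 1.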
regression = RiemannianRegression
reg = regression(G_riemannian, Y, t)
print(f'The R2-value of the geodesic that approximates the original data is {reg.R2statistic:.2f}.')
r2_values = []
for f in jax.vmap(G_riemannian.rand)(rnd.split(rnd.key(84), 100)):
    # right-translate data by a random element f
    Y_f = jax.vmap(G_riemannian.group.righttrans, (0, None))(Y, f)
    # perform regression for translated data
    reg_tilde = RiemannianRegression(G_riemannian, Y_f, t)
    # compute R2-statistic
    r2_values.append(reg_tilde.R2statistic)
r2_values = np.array(r2_values)
print(f'The minimal and average R2-values of the geodesic approximating the translated datasets are {r2_values.min():.2f} and {r2_values.mean():.2f}, respectively.')
The R2-value of the geodesic that approximates the original data is 0.93. The minimal and average R2-values of the geodesic approximating the translated datasets are 0.81 and 0.88, respectively.
Let us now check whether the Riemannian regression estimator is equivariant under right translations. To this end, we visually compare the right-translated geodesic $\gamma_R$ fitted to the original data with the geodesic $\widetilde{\gamma}_R$ that is the least-squares estimator for the translated data; for an equivariant estimator, the two would coincide. We visualize a sequence of rigid motions by applying each of them to a 3D object: the motions change its pose by translating and rotating it, so that the object traces out a path. We use a (discretized) paper plane and plot its trajectory under $\gamma_R$ (orange) and $\widetilde{\gamma}_R$ (red), together with the translated data points (green). Additionally, we plot the geodesic that connects the start and end points of $\gamma_R$ (yellow).
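Posing an object by a rigid motion amounts to mapping each vertex $x$ to $Rx + p$. The NumPy sketch below illustrates this on a toy triangle instead of the paper-plane mesh; it assumes rigid motions are 4x4 homogeneous matrices (the plotting helpers below delegate this step to PyVista's `transform`). The names `apply_rigid_motion`, `trace`, and `motion` are hypothetical.

```python
import numpy as np

def apply_rigid_motion(g: np.ndarray, points: np.ndarray) -> np.ndarray:
    """Apply a 4x4 homogeneous rigid motion g to an (n, 3) array of points: x -> R x + p."""
    R, p = g[:3, :3], g[:3, 3]
    return points @ R.T + p

def trace(gs: np.ndarray, points: np.ndarray) -> np.ndarray:
    """One posed copy of the object per motion in gs (shape (m, 4, 4))."""
    return np.stack([apply_rigid_motion(g, points) for g in gs])

def motion(a: float, shift: float) -> np.ndarray:
    """Rotation about the z-axis by angle a, followed by a shift along x."""
    c, s = np.cos(a), np.sin(a)
    g = np.eye(4)
    g[:3, :3] = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    g[:3, 3] = [shift, 0.0, 0.0]
    return g

# toy object: a triangle in the xy-plane
triangle = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.5, 0.0]])
path = np.linspace(0, 1, 5)
poses = trace(np.stack([motion(np.pi * s, 2.0 * s) for s in path]), triangle)
print(poses.shape)  # (5, 3, 3): 5 poses of 3 vertices in 3D
```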
# load paper plane
glyph = pv.read('../data/paper-plane.obj').scale(.2).rotate_x(90).rotate_y(-45)
# plotting function 1: a single posed copy of the glyph
def frame(g: np.ndarray):
    """Return a copy of the paper plane transformed by the rigid motion g."""
    return glyph.copy().transform(g, inplace=True)
# plotting function 2: tube along a polyline (not used below)
def tube_mesh(pts: np.ndarray):
    n = len(pts)-1
    curve = pv.PolyData(pts)
    # each line cell connects consecutive points: [2, i, i+1]
    curve.lines = np.c_[np.full(n,2), np.arange(n), np.arange(n)+1]
    curve.tube(radius=0.01, inplace=True)
    return curve
# plotting function 3: one posed copy of the glyph per rigid motion in g
def trace_frame(g: np.ndarray):
    return pv.MultiBlock([glyph.copy().transform(h, inplace=True) for h in g])
# translate the data points by a random element
f = G_riemannian.rand(rnd.key(1111111))
Y_f = jax.vmap(G_riemannian.group.righttrans, (0, None))(Y, f)
# regression geodesic of the original data, right-translated by f (gamma_R)
gamma_orig = regression(G_riemannian, Y, t).trend
gamma = lambda s: G_riemannian.group.righttrans(gamma_orig.eval(s), f)
# regression geodesic of the translated data (gamma_R tilde)
gamma_tilde = RiemannianRegression(G_riemannian, Y_f, t).trend
def main_plot(G, gam, gam_tilde, translated_data):
    plt = pv.Plotter()
    # translated data points (green)
    for g in translated_data:
        plt.add_mesh(frame(np.asarray(g)), color='lime')
    geo_args = {'smooth_shading': True, 'opacity': 1}
    # densely sampled, nearly transparent copies trace out the curves
    t_dense = jnp.linspace(0, 1, 200)
    a_1 = jax.vmap(gam)(t_dense)
    plt.add_mesh(trace_frame(np.asarray(a_1)), color='orange', opacity=.04)
    # fewer opaque copies at the sample times t
    a_2 = jax.vmap(gam)(t)
    plt.add_mesh(trace_frame(np.asarray(a_2)), color='orange', **geo_args)
    b_1 = jax.vmap(gam_tilde.eval)(t_dense)
    plt.add_mesh(trace_frame(np.asarray(b_1)), color='r', opacity=.04)
    b_2 = jax.vmap(gam_tilde.eval)(t)
    plt.add_mesh(trace_frame(np.asarray(b_2)), color='r', **geo_args)
    # geodesic of G (yellow) between the endpoints of the translated regression geodesic
    geo = jax.vmap(G.connec.geopoint, (None, None, 0))
    plt.add_mesh(trace_frame(np.asarray(geo(a_1[0], a_1[-1], t_dense))), color='yellow', opacity=.04)
    plt.add_mesh(trace_frame(np.asarray(geo(a_2[0], a_2[-1], t))), color='yellow', **geo_args)
    # set camera position
    plt.camera_position = [
        (5.0469224592322455, -1.4055172372092932, -2.072496421293708),
        (-0.397411972284317, 0.5756635069847107, -1.4490582346916199),
        (0.2040203291061177, 0.7562696348810438, -0.6216365052571513)
    ]
    plt.show(jupyter_backend='static')
main_plot(G_riemannian, gamma, gamma_tilde, Y_f)
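The two curves clearly differ. This is bad news: Lie group-valued data is often only well-defined up to translation (and inversion), so the results of statistical procedures should not depend on such transformations. Note also that, in general, the translated curve $\gamma_R$ (orange) is no longer a Riemannian geodesic, because right translations are not isometries of the chosen metric; this is why it need not coincide with the geodesic between its endpoints (yellow).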
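One way to see the mechanism: the product metric is invariant under left translations but not under right translations, so right-translating the data changes the distances that least-squares regression minimizes. The NumPy/SciPy sketch below checks this for two rigid motions. It assumes 4x4 homogeneous matrices and the unscaled Frobenius metric on SO(3), under which a rotation by angle $\theta$ has length $\sqrt{2}\,\theta$; `rigid` and `product_dist` are hypothetical helpers, not Morphomatics functions.

```python
import numpy as np
from scipy.spatial.transform import Rotation

def rigid(rotvec, p) -> np.ndarray:
    """4x4 homogeneous matrix of x -> R x + p."""
    g = np.eye(4)
    g[:3, :3] = Rotation.from_rotvec(rotvec).as_matrix()
    g[:3, 3] = p
    return g

def product_dist(g1: np.ndarray, g2: np.ndarray) -> float:
    """Distance for the product metric: sqrt(2) * angle on SO(3), Euclidean on R^3."""
    theta = np.linalg.norm(Rotation.from_matrix(g1[:3, :3].T @ g2[:3, :3]).as_rotvec())
    return float(np.sqrt(2 * theta**2 + np.sum((g1[:3, 3] - g2[:3, 3])**2)))

g1 = rigid([0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
g2 = rigid([0.0, 0.0, np.pi / 2], [0.0, 0.0, 0.0])
f = rigid([0.3, -0.2, 0.1], [1.0, 0.0, 0.0])

d = product_dist(g1, g2)
d_left = product_dist(f @ g1, f @ g2)   # left translation: distance preserved
d_right = product_dist(g1 @ f, g2 @ f)  # right translation: distance changes
print(d, d_left, d_right)
```

Under right translation, the translational part of $g f$ is $R\,p_f + p$, which depends on the rotation $R$ of $g$; this coupling is what breaks right invariance.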
Bi-invariant statistics offers methods that do not suffer from this drawback. It endows Lie groups with an affine connection whose geodesics are exactly the translated one-parameter subgroups, i.e., curves of the form $s \mapsto g \exp(sX)$. Let us initialize SE(3) with this structure and repeat the above experiment with the new tools.
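A key reason this structure respects both translations is that translating a curve $g\exp(sX)$ yields a curve of the same form: left translation by $h$ gives $(hg)\exp(sX)$, and right translation by $f$ gives $(gf)\exp(s\,f^{-1}Xf)$, since $\exp(s\,f^{-1}Xf) = f^{-1}\exp(sX)f$. The sketch below verifies these identities numerically with `scipy.linalg.expm`, assuming SE(3) elements and Lie algebra elements are 4x4 matrices; it is an illustration, not Morphomatics code, and `random_se3_algebra` is a hypothetical helper.

```python
import numpy as np
from scipy.linalg import expm

rng = np.random.default_rng(1)

def random_se3_algebra(rng: np.random.Generator) -> np.ndarray:
    """Random 4x4 element [[skew(w), v], [0, 0]] of the Lie algebra se(3)."""
    w = rng.normal(size=3)
    X = np.zeros((4, 4))
    X[:3, :3] = [[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]]
    X[:3, 3] = rng.normal(size=3)
    return X

# random group elements g, f, h and a random direction X
g, f, h = (expm(random_se3_algebra(rng)) for _ in range(3))
X = random_se3_algebra(rng)
f_inv = np.linalg.inv(f)

for s in np.linspace(0, 1, 5):
    curve = g @ expm(s * X)  # translated one-parameter subgroup
    # right translation: g exp(sX) f = (g f) exp(s f^{-1} X f)
    assert np.allclose(curve @ f, (g @ f) @ expm(s * (f_inv @ X @ f)))
    # left translation: h g exp(sX) = (h g) exp(sX)
    assert np.allclose(h @ curve, (h @ g) @ expm(s * X))
print("translations map translated one-parameter subgroups to curves of the same form")
```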
from morphomatics.stats import BiinvariantRegression
G_affine = SE3(structure='AffineGroup')
gamma_orig = BiinvariantRegression(G_affine, Y, t).trend
gamma = lambda s: G_affine.group.righttrans(gamma_orig.eval(s), f)
gamma_tilde = BiinvariantRegression(G_affine, Y_f, t).trend
main_plot(G_affine, gamma, gamma_tilde, Y_f)
Apart from small differences due to numerical rounding, the results are identical. This is very good news: it illustrates that bi-invariant regression is equivariant under right translations, so translating the data does not influence the quality of the result.