
Building and Learning (Pseudo)Metrics -- Fourier Case

A metric is a notion of distance generalized to arbitrary sets. It has four key properties:

  1. Nonnegativity: $d(x, y) \geq 0$

  2. Uniqueness: $d(x, y) = 0$ if and only if $x = y$

  3. Symmetry: $d(x, y) = d(y, x)$

  4. Triangle inequality: $d(x, z) \leq d(x, y) + d(y, z)$

A pseudometric relaxes the uniqueness property: it still has $d(x, x) = 0$, but it may have $x \neq y$ with $d(x, y) = 0$.
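To make the distinction concrete, here is a small sketch (the function name `first_coordinate_distance` is purely illustrative): on $\mathbb{R}^2$, $d(x, y) = |x_1 - y_1|$ satisfies nonnegativity, symmetry, and the triangle inequality, but it assigns distance zero to distinct points that share a first coordinate, so it is a pseudometric rather than a metric.

```python
import random


def first_coordinate_distance(x, y):
    """Pseudometric on R^2 that ignores the second coordinate (illustrative only)."""
    return abs(x[0] - y[0])


# Spot-check the pseudometric axioms on random triples of points
random.seed(0)
for _ in range(1000):
    x, y, z = [(random.uniform(-5, 5), random.uniform(-5, 5)) for _ in range(3)]
    assert first_coordinate_distance(x, y) >= 0                               # nonnegativity
    assert first_coordinate_distance(x, y) == first_coordinate_distance(y, x)  # symmetry
    assert first_coordinate_distance(x, z) <= (first_coordinate_distance(x, y)
                                               + first_coordinate_distance(y, z) + 1e-12)  # triangle inequality

# Uniqueness fails: two distinct points at distance zero
p, q = (1.0, 0.0), (1.0, 3.0)
print(p != q, first_coordinate_distance(p, q))  # True 0.0
```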

For neural localization, we need to compare points in two spaces: locations and sensory percepts. In general, the representations of these spaces don't work naturally with the Euclidean metric, especially locations described by grid cell codes. So in this notebook we first define an explicit metric for the Fourier space of grid cells, and in another we consider how to learn a metric for comparing sensory percepts that could also be applied to location space.

We need these metrics to build spatial memories over non-Euclidean spaces in support of TEM (the Tolman-Eichenbaum Machine).

Fourier Location Encoding

Grid cells in $d$ dimensions (usually $d = 2$) fire with a rate described by

$$r_j(x) = \left(\sum_{i=1}^{d+1} A_{ij} \cos\left(k_i\cdot x + \phi_{ij}\right)\right)_+$$

when the organism is at position $x$, where $A_{ij}$ is a fixed amplitude, $(\cdot)_+$ denotes the positive part $\max(\cdot, 0)$, and $k_i$ is a wave vector, one of $d + 1$ unit vectors whose convex hull forms a regular simplex. In 2-D, these vectors start at the origin and are radially separated by $120^\circ$. In 3-D, four wave vectors are needed, pointing to the vertices of a regular tetrahedron (a triangular pyramid).
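A minimal NumPy sketch of the 2-D case (the helpers `wave_vectors_2d` and `grid_rate` are illustrative, not part of the notebook): three unit wave vectors $120^\circ$ apart sum to zero, as the vertices of a simplex centered at the origin should, and the rectified sum of plane waves gives a firing rate.

```python
import numpy as np


def wave_vectors_2d(offset: float = 0.0) -> np.ndarray:
    """Three unit wave vectors separated by 120 degrees, stacked as the rows of a (3, 2) matrix."""
    angles = offset + 2.0 * np.pi * np.arange(3) / 3.0
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)


def grid_rate(x: np.ndarray, K: np.ndarray, A: np.ndarray, phi: np.ndarray) -> float:
    """Rectified sum of plane waves, r(x) = (sum_i A_i cos(k_i . x + phi_i))_+ (hypothetical helper)."""
    return max(0.0, float(np.sum(A * np.cos(K @ x + phi))))


K = wave_vectors_2d()
print(np.linalg.norm(K, axis=1))  # each row has unit length
print(K.sum(axis=0))              # ~[0, 0]: the simplex is centered at the origin
print(grid_rate(np.zeros(2), K, A=np.ones(3), phi=np.zeros(3)))  # 3.0, the peak at the origin
```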

We can model these cells via an assemblage of column vectors having the form

$$\eta_{ij}(x) = \begin{pmatrix} \cos(\alpha_j k_i\cdot x + \phi_j) \\ \sin(\alpha_j k_i\cdot x + \phi_j) \end{pmatrix}$$

where $\alpha_j$ is the spatial frequency of module $j$ and the phases $\{\phi_j\}$ are chosen to cover the interval $[0, 2\pi)$, i.e., the full circle.

Grid cell responses are periodic as the animal moves around its environment, and thanks to the identity

$$\cos a \cos b + \sin a \sin b = \cos(a - b) = \cos(b - a)$$

we can arrive at a relationship between displacement in physical space and rotation in Fourier space. To do this, we form a matrix $K$ whose rows are the $k_i$, so that $K$ is a $(d+1) \times d$ matrix. The rows of $K$ span $\mathbb{R}^d$, so the $d \times d$ matrix $K^TK$ is invertible, and the $d \times (d+1)$ pseudoinverse $K^\dagger = (K^TK)^{-1}K^T$ satisfies $K^\dagger K = I$. Then we have

$$\eta_{ij}(x)^T\,\eta_{ij}(x + \Delta x) = \cos(\alpha_j k_i\cdot \Delta x) \qquad\text{whence}\qquad \exists\, n_j \in \mathbb{Z}^{d+1}:\quad \alpha_j K \Delta x = 2\pi n_j + \arccos\left(\operatorname{diag}\left(H_j(x)\, H_j(x + \Delta x)^T\right)\right)$$

where $H_j(x)$ is the $(d+1) \times 2$ matrix whose $i$-th row is $\eta_{ij}(x)^T$, so the diagonal of $H_j(x)\,H_j(x + \Delta x)^T$ collects the inner products above, and $\arccos$ is applied elementwise. Because $\arccos$ returns values in $[0, \pi]$, it determines each angle only up to sign; in practice the signed angle is recovered with atan2, as done later in this notebook. Multiplying on the left by $K^\dagger$ and solving for $\Delta x$, we find

$$\forall j\;\; \exists\, n_j \in \mathbb{Z}^{d+1} \quad\text{s.t.}\quad \Delta x = \frac{2\pi K^\dagger n_j}{\alpha_j} + \frac{K^\dagger}{\alpha_j}\,\Delta\theta_j \qquad\text{given}\qquad \Delta\theta_j = \arccos\left(\operatorname{diag}\left(H_j(x)\, H_j(x + \Delta x)^T\right)\right)$$
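This relationship can be checked numerically. The sketch below (NumPy, with illustrative values of $\alpha$, $\phi$, $x$ and $\Delta x$ for a single module in $d = 2$) computes signed phase differences with atan2, confirms that $\alpha K\Delta x - \Delta\theta$ is an integer multiple of $2\pi$ in every component, and, for a displacement small enough that no component wraps, recovers $\Delta x = K^\dagger \Delta\theta/\alpha$.

```python
import numpy as np

# Simplex wave vectors for d = 2 (rows of K) and their pseudoinverse K^dagger
angles = 2.0 * np.pi * np.arange(3) / 3.0
K = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (3, 2)
K_dagger = np.linalg.pinv(K)                             # (2, 3)


def eta(x, alpha, phi):
    """Stack of eta_i(x) = (cos(alpha k_i.x + phi), sin(alpha k_i.x + phi)) as rows, shape (3, 2)."""
    theta = alpha * (K @ x) + phi
    return np.stack([np.cos(theta), np.sin(theta)], axis=1)


def signed_phase_difference(H1, H2):
    """Signed angle from each row of H1 to the matching row of H2, in (-pi, pi]."""
    cross = H1[:, 0] * H2[:, 1] - H1[:, 1] * H2[:, 0]
    dot = np.sum(H1 * H2, axis=1)
    return np.arctan2(cross, dot)


alpha, phi = 3.0, 0.7          # illustrative module frequency and phase
x = np.array([0.4, -1.1])
dx = np.array([2.5, 1.3])      # large enough that some components wrap

dtheta = signed_phase_difference(eta(x, alpha, phi), eta(x + dx, alpha, phi))
n = (alpha * K @ dx - dtheta) / (2.0 * np.pi)
print(np.round(n, 6))          # an integer vector n_j

small_dx = np.array([0.1, -0.2])  # every |alpha k_i . small_dx| < pi, so nothing wraps
dtheta_small = signed_phase_difference(eta(x, alpha, phi), eta(x + small_dx, alpha, phi))
print(K_dagger @ dtheta_small / alpha)  # ~[0.1, -0.2]
```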

With sufficiently many, carefully chosen $(\alpha_j, \phi_j)$, these modular constraints can be solved for a wide range of $\Delta x$, and $\Delta x$ can be identified exactly on this range.

So the firing rates of an assemblage of grid cells with different phases $\phi_j$ can be stably and reliably inverted to retrieve displacements, and hence the physical location $x$ relative to a reference point.

From One Location to the Next

We need a function to update locations based on actions, $\ell_t = h(\ell_{t-1}, a_{t-1})$.

We assume that each action $a_t$ affects location as a translation in physical space. With Fourier location codes (like $\eta_{ij}$ above), a displacement $\Delta x$ becomes a block-wise rotation of each $\eta_{ij}$. Recall that a rotation by angle $\theta$ is

$$R(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{bmatrix},$$

so writing $\eta_{ij} = (u, v)^T$ (to avoid a clash with the position $x$), we have

$$R(\theta_{ij})\,\eta_{ij} = \begin{pmatrix} u\cos\theta_{ij} - v\sin\theta_{ij} \\ u\sin\theta_{ij} + v\cos\theta_{ij} \end{pmatrix}.$$
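The key fact behind this update is that rotating a unit vector $(\cos\varphi, \sin\varphi)^T$ by $R(\theta)$ simply adds $\theta$ to its phase, so a translation that shifts every phase $\alpha_j k_i \cdot x$ by $\alpha_j k_i \cdot \Delta x$ is exactly a block-wise rotation. A quick NumPy check with arbitrary illustrative angles:

```python
import numpy as np


def R(theta: float) -> np.ndarray:
    """2-D counterclockwise rotation matrix by angle theta."""
    return np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])


phase, theta = 0.9, -2.3   # arbitrary illustrative angles
eta = np.array([np.cos(phase), np.sin(phase)])
rotated = R(theta) @ eta
expected = np.array([np.cos(phase + theta), np.sin(phase + theta)])
print(np.allclose(rotated, expected))  # True
```

Composing two rotations adds their angles, which is why successive actions accumulate as successive displacements.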

Beyond this rotational update, the Fourier coding can remain implicit in our location-learning scheme. Thus we do not need to know $\alpha_j$ or $\phi_{ij}$; instead we allow $\theta_{ij}$ to differ for each $ij$. Then we can learn a function $g_\phi$ with parameters $\phi$ (here $\phi$ denotes learned parameters, not phases) and apply it as a block-diagonal matrix multiply as follows:

$$\theta_{ij} = g_\phi(a_{t-1}) \qquad\text{and}\qquad \ell_t = \begin{bmatrix} R(\theta_{00}) & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & R(\theta_{IJ}) \end{bmatrix} \ell_{t-1}$$

This treats $\ell_t$ as a sequence of pairs of components $\eta_{ij}$ under some enumeration of $ij$. The following code implements this scheme more efficiently, without assembling the full block-diagonal matrix. Each application extends the input sequence by one.
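A minimal NumPy sketch of this update, under stated assumptions: $\ell_t$ is stored with each $(\cos, \sin)$ pair contiguous, so its last axis has length $2P$ for $P$ pairs, and a hypothetical linear map `W` stands in for the learned $g_\phi$. `rotate_blocks` rotates every pair by its own angle without building the block-diagonal matrix, and `extend` appends one new location to a sequence of shape $(B, T, 2P)$; both names are illustrative.

```python
import numpy as np


def rotate_blocks(ell: np.ndarray, thetas: np.ndarray) -> np.ndarray:
    """Rotate each (cos, sin) pair of ell (..., 2P) by the matching angle in thetas (..., P)."""
    pairs = ell.reshape(*ell.shape[:-1], -1, 2)
    u, v = pairs[..., 0], pairs[..., 1]
    c, s = np.cos(thetas), np.sin(thetas)
    rotated = np.stack([c * u - s * v, s * u + c * v], axis=-1)
    return rotated.reshape(ell.shape)


def extend(ells: np.ndarray, action: np.ndarray, W: np.ndarray) -> np.ndarray:
    """Append R(g(a)) ell_{T-1} to a (B, T, 2P) sequence; g(a) = a @ W is a hypothetical stand-in for g_phi."""
    thetas = action @ W                       # (B, P) rotation angles
    nxt = rotate_blocks(ells[:, -1], thetas)  # (B, 2P) next location code
    return np.concatenate([ells, nxt[:, None]], axis=1)


rng = np.random.default_rng(0)
B, P, A = 2, 6, 3                             # batch, number of pairs, action size
phases = rng.uniform(-np.pi, np.pi, size=(B, 1, P))
ells = np.stack([np.cos(phases), np.sin(phases)], axis=-1).reshape(B, 1, 2 * P)
W = rng.normal(size=(A, P))
out = extend(ells, rng.normal(size=(B, A)), W)
print(out.shape)  # (2, 2, 12)
```

Because each pair is rotated in place, every pair stays on the unit circle, and the cost is linear in $P$ rather than quadratic as for a dense block-diagonal multiply.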

Making Location Codes Physical

A location code represents a physical point. A system of location codes has to represent different points consistently. That restricts the codes that can simultaneously be valid.

When it comes to changes in location, simply updating the location codes by rotation does not guarantee a unique solution for the displacement $\Delta x$; different $\eta_{ij}$ can yield different displacements. Basically, these location codes have extra flexibility that might lead to a suboptimal spatial representation. The same position in real space could have many different codes, which will prevent us from being able to generate an accurate map of local space, or to use the map to go to areas of interest or to judge what objects are close to each other.

Note that in our geometric decoder, we have calculated $\Delta\theta_{ij} = g_\phi(a_t)$. In general, our location codes should adhere to

$$\Delta x = \frac{2\pi K^\dagger n_j}{\alpha_j} + \frac{1}{\alpha_j} K^\dagger \Delta\theta_j \qquad\text{for some } n_j \in \mathbb{Z}^{d+1},\; \alpha_j \in \mathbb{R}$$

based on the formulae above. Our problem is that different $j$ will give us different values of $\Delta x$; we use $\Delta x_j$ to denote the value given by the matrix $H_j$, which has size $(d+1) \times 2$.

Hence the mean $\Delta X = \frac{1}{J}\sum_j \Delta x_j$ is an estimator of the displacement, and we can minimize the sample variance of the per-module estimates around it,

$$\mathcal{L}_{\text{consistency}} = \frac{1}{J-1}\sum_j \left\|\Delta x_j - \Delta X\right\|^2$$

but first we must solve for $\Delta x_j$. To do this, we first canonicalize $\Delta\theta_j \in [-\pi, \pi)$, compute $\overline{\Delta x}_j = \frac{1}{\alpha_j} K^\dagger \Delta\theta_j$, and then note that

$$\Delta x_j \in \overline{\Delta x}_j + \Lambda_j \qquad\text{for}\qquad \Lambda_j = \frac{2\pi}{\alpha_j} K^\dagger\, \mathbb{Z}^{d+1} = \left\{ \frac{2\pi K^\dagger n_j}{\alpha_j} \;\middle|\; n_j \in \mathbb{Z}^{d+1} \right\}.$$

Our problem is that $n_j$ has $d+1$ integer components while $\Delta x$ has only $d$ dimensions, so the $d+1$ generators $\frac{2\pi}{\alpha_j} K^\dagger e_i$ of $\Lambda_j$ are linearly dependent. We need to eliminate one of them, and because the rows of $K$ are the vertices of a regular simplex, the structure of our problem lets us do that. We can choose a basis $B_j$ that generates $\Lambda_j$ as follows. Let

$$U = \left[e_1, \ldots, e_d\right] \in \mathbb{Z}^{(d+1)\times d}$$

where the $e_i$ are the standard basis vectors of $\mathbb{R}^{d+1}$, that is, $e_i$ is the $i^{th}$ column of the identity matrix $I_{d+1}$; $U$ is simply the first $d$ columns of $I_{d+1}$. For any $n_j \in \mathbb{Z}^{d+1}$, take $b_j = n_{d+1,j}$ (the last entry of $n_j$) and $m_j = (n_{1,j} - b_j, \ldots, n_{d,j} - b_j)^T$; both are integral, and

$$n_j = U m_j + b_j \mathbf{1} \qquad\text{wherefore}\qquad K^\dagger n_j = K^\dagger U m_j + b_j\, K^\dagger \mathbf{1} = K^\dagger U m_j$$

because $K^\dagger \mathbf{1} = 0$, which is a consequence of choosing $K$ to be the vertices of a simplex centered at the origin: the rows of $K$ sum to zero, so $K^T\mathbf{1} = 0$ and hence $K^\dagger\mathbf{1} = (K^TK)^{-1}K^T\mathbf{1} = 0$. Now, we define

$$B_j = \frac{2\pi}{\alpha_j} K^\dagger U \;\in\; \mathbb{R}^{d\times d} \qquad\text{which yields}\qquad \Delta x_j \in \overline{\Delta x}_j + B_j \mathbb{Z}^d$$

which reduces the number of integer unknowns from $d+1$ to $d$. Since the columns of $K^\dagger$ sum to zero and $K^\dagger$ has rank $d$, any $d$ of its columns are linearly independent, so $B_j$ is invertible and we can solve for the offsets.

Let's compute $K$ and $B_j$ for all $j$.

import math
import torch


def make_lattice_basis(alphas: torch.Tensor, dim: int = 2):
    """
    Construct lattice basis matrices B_j for each grid module j in d dimensions.

    Args:
        alphas: 1D tensor of shape (J,) with spatial frequencies α_j.
        dim: spatial dimension d (>=1).

    Returns:
        B: tensor of shape (J, d, d), where B[j] is the lattice basis B_j.
        K: tensor of shape (d+1, d) with simplex directions as rows.
        K_dagger: tensor of shape (d, d+1) with pseudoinverse of K, returned for convenience.
    """
    if dim < 1:
        raise ValueError("dimension must be >= 1")

    # alphas is expected to be a tensor of shape (J,)
    dtype = alphas.dtype
    device = alphas.device
    J = alphas.shape[0]

    # 1. Basis N for the subspace of R^(d+1) orthogonal to the all-ones vector
    N = torch.cat([torch.eye(dim, dtype=dtype, device=device), -torch.ones(1, dim, dtype=dtype, device=device)], dim=0)

    # 2. Orthonormalize; the rows of Q form a regular simplex in R^d, which we normalize
    Q, _ = torch.linalg.qr(N, mode="reduced")  # Q: (d+1, d)
    V = torch.eye(dim + 1, dtype=dtype, device=device) - (1.0 / (dim + 1))  # centering matrix
    K = V @ Q  # (d+1, d)
    K = K / K.norm(dim=1, keepdim=True)  # each row is now unit length

    # Pseudoinverse of K: K^† ∈ R^{d×(d+1)}
    K_dagger = torch.linalg.pinv(K)

    # Base (unscaled) lattice generator K^† U with U = [e_1, ..., e_d]: shape (d, d)
    U = torch.eye(dim + 1, dim, dtype=dtype, device=device)
    base = K_dagger @ U

    # 3. Scale by 2π / α_j for each module
    scale = (2.0 * math.pi) / alphas.view(J, 1, 1)  # (J, 1, 1)
    B = scale * base[None, ...]                  # (J, d, d)

    return B, K, K_dagger
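A small NumPy check of the lattice argument, independent of the helper above (all names here are illustrative): because $K^\dagger \mathbf{1} = 0$, every $K^\dagger n$ with $n \in \mathbb{Z}^{d+1}$ is an integer combination of the first $d$ columns of $K^\dagger$, with integer coordinates $m_i = n_i - n_{d+1}$.

```python
import numpy as np

d = 2
angles = 2.0 * np.pi * np.arange(d + 1) / (d + 1)
K = np.stack([np.cos(angles), np.sin(angles)], axis=1)   # (d+1, d) simplex directions for d = 2
K_dagger = np.linalg.pinv(K)                              # (d, d+1)
U = np.eye(d + 1)[:, :d]                                  # first d columns of the identity

rng = np.random.default_rng(1)
for _ in range(100):
    n = rng.integers(-5, 6, size=d + 1)                   # arbitrary integer vector in Z^{d+1}
    m = n[:d] - n[d]                                      # integer coordinates with respect to U
    assert np.allclose(K_dagger @ n, K_dagger @ U @ m)

print(np.allclose(K_dagger @ np.ones(d + 1), 0.0))       # True: K^dagger annihilates the all-ones vector
```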

To solve for $\Delta x_j$, we rely on the $\alpha_j$ being sorted from coarsest (small $\alpha_j$, long wavelength) to finest. We ensure this when initializing the $\alpha_j$, which are fixed rather than learned. Let's look at our initialization of $\alpha_j$ for the decoding process.

The choice of $\alpha_j$ determines the range of $\Delta x$ that can be recovered, so we use a scale parameter $s$. This parameter is, in effect, the maximum norm of displacement caused by any action. Hence actions determine the scale of the space.

Rather than generating the $\alpha_j$ directly, we instead focus on the wavelength $\lambda_j = 2\pi/\alpha_j$. Since our $\alpha_j$ are sorted, we want $\lambda_0 = 2s$. This allows us to recover displacements with norm less than $s$: for $\|\Delta x\| < s$, every phase difference $\alpha_0 k_i \cdot \Delta x$ of the coarsest module lies in $(-\pi, \pi)$. From there, each successive $j$ uses a smaller wavelength in a geometric pattern. The ratio $\rho$ of reduction is a parameter to our generator, and we have $\lambda_j = 2s\rho^{-j}$, whence $\alpha_j = \pi\rho^j / s$.
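A quick NumPy check of these formulas with illustrative values $s = 10$, $\rho = \sqrt{2}$ and six modules: the coarsest wavelength is $2s$, and each module is $\rho$ times finer than the previous one.

```python
import numpy as np

s, rho, J = 10.0, np.sqrt(2.0), 6          # illustrative scale, ratio and module count
j = np.arange(J)
alphas = np.pi * rho ** j / s               # alpha_j = pi rho^j / s
wavelengths = 2.0 * np.pi / alphas          # lambda_j = 2 pi / alpha_j

print(wavelengths[0])                                         # approximately 2 * s
print(np.allclose(wavelengths, 2 * s * rho ** (-j)))          # True
print(np.allclose(wavelengths[:-1] / wavelengths[1:], rho))   # True: each module is rho times finer
```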

import math
import torch


def make_alphas(location_dim: int, dim: int = 2, scale: float = 10.0,
                ratio: float = math.sqrt(2.0), dtype=torch.get_default_dtype(), device=torch.device("cpu")) -> torch.Tensor:
    """
    Choose module spatial frequencies alpha_j given a location code size and
    a target "safe" displacement scale.

    Args:
        location_dim: int, total size of the location code, 2 * J * (d+1)
                      (one cos and one sin channel per phase).
        dim: spatial dimension d.
        scale: radius s such that for ||Δx|| < s, the coarsest module
               is unambiguous (λ_0 / 2 = s).
        ratio: geometric ratio between successive periods (default √2).

    Returns:
        alphas: (J,) tensor of spatial frequencies α_j, with j=0 coarsest.
    """
    num_dirs = dim + 1
    assert location_dim % (2 * num_dirs) == 0,  f"location_dim={location_dim} must be divisible by 2*(dimension+1)={2*num_dirs}"
    J = location_dim // num_dirs // 2  # number of modules
    j_idx = torch.arange(J, dtype=dtype, device=device)
    return (math.pi / scale) * ratio ** j_idx

Finally, to solve for $\Delta x_j$, we treat each module $j$ in turn (the code below does this for all $j$ at once), choosing $n_0 = 0$ for the coarsest module on the assumption that $\|\Delta x\| < s$ for our scale parameter $s$. That is, $\Delta x_0 = \overline{\Delta x}_0$. Then we choose the value of $n_j$ that approximately minimizes the error $\|\Delta x_0 - \Delta x_j\|$ by rounding. That is, from $\Delta x_j = \overline{\Delta x}_j + B_j n_j$ we set $\Delta x_j = \Delta x_0$, which leads to

$$B_j n_j = \Delta x_0 - \overline{\Delta x}_j \equiv \delta_j$$

Now the linear system $B_j z = \delta_j$ can be solved for $z \in \mathbb{R}^d$, and we estimate $n_j = \mathrm{round}(z)$, which gives $\Delta x_j$.
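A toy NumPy illustration of this rounding step, with an illustrative $2 \times 2$ basis and a fixed small disagreement standing in for module noise: solving $B_j z = \delta_j$ and rounding each coordinate recovers the integer offset. This is coordinate-wise rounding in the lattice basis (sometimes called Babai's rounding technique); it recovers the right offset when the disagreement is small relative to the basis, but it is not guaranteed to find the closest lattice point in general.

```python
import numpy as np

B = np.array([[-1.2, 0.0],
              [0.7, -1.4]])                  # an illustrative invertible 2x2 lattice basis B_j
n_true = np.array([3, -2])                   # the unknown integer offset
noise = np.array([0.03, -0.04])              # small disagreement between modules
delta = B @ n_true + noise                   # delta_j = Delta x_0 - bar(Delta x)_j

z = np.linalg.solve(B, delta)                # real-valued solution of B z = delta
n_hat = np.round(z).astype(int)              # nearest integer vector, coordinate by coordinate
print(n_hat)                                 # recovers n_true
```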

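def solve_for_deltas(delta_thetas: torch.Tensor, K_dagger: torch.Tensor, lattice_basis: torch.Tensor, alphas: torch.Tensor):
    """
    Recover one displacement estimate Δx_j per module from phase differences.

    Args:
        delta_thetas: (..., J*(d+1)) phase differences Δθ_ij, grouped module by module.
        K_dagger: (d, d+1) pseudoinverse of the simplex matrix K.
        lattice_basis: (J, d, d) lattice bases B_j.
        alphas: (J,) module frequencies α_j, sorted coarsest first.

    Returns:
        (..., J, d) tensor of per-module displacements Δx_j.
    """
    d, dplus = K_dagger.shape
    J, _, _ = lattice_basis.shape
    shape = delta_thetas.shape[:-1] + (J, dplus)
    delta_thetas = delta_thetas.view(shape)[..., None]  # (..., J, d+1, 1)
    while K_dagger.ndim < delta_thetas.ndim:
        K_dagger = K_dagger[None, ...]
    # Unwrapped estimate bar(Δx)_j = K^† Δθ_j / α_j, shape (..., J, d)
    displacement_base = (K_dagger @ delta_thetas).view(shape[:-1] + (d,)) / alphas[None, :, None]
    # The coarsest module (j = 0) is the reference, assuming ||Δx|| < s so that n_0 = 0
    reference_displacement = displacement_base[..., 0, :]
    errors = reference_displacement[..., None, :] - displacement_base  # δ_j
    lattice_basis = lattice_basis.view(J, d, d)
    while lattice_basis.ndim < errors.ndim + 1:
        lattice_basis = lattice_basis[None, ...]
    # Solve B_j z = δ_j and round z to the nearest integer vector n_j
    offsets = torch.linalg.solve(lattice_basis.float(), errors[..., None].float()).round().to(delta_thetas.dtype)
    # Δx_j = bar(Δx)_j + B_j n_j
    deltas = displacement_base + (lattice_basis @ offsets).squeeze(-1)
    return deltas.view(shape[:-1] + (d,))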

Now we can check these functions for accuracy:

J = 20
d = 2

location_dim = J * (d+1) * 2
alphas = make_alphas(location_dim, d)
print(f"alphas: {alphas.detach().cpu().numpy().tolist()}")

lattice_basis, K, K_dagger = make_lattice_basis(alphas, d)
print(f"K: {K.detach().cpu().numpy().tolist()}")
print(f"K_dagger: {K_dagger.detach().cpu().numpy().tolist()}")
assert torch.allclose(K.sum(dim=0), torch.zeros(d))

phis = torch.linspace(0, 2*math.pi, J)

x0 = torch.randn(2, 5, d)
xf = torch.randn(2, 5, d)
deltas = xf - x0

theta0 = alphas[None, None, :, None, None] * (K[None, None, ...] @ x0.view(2, 5, d, 1))[:, :, None, :, :] + phis[None, None, :, None, None]
thetaf = alphas[None, None, :, None, None] * (K[None, None, ...] @ xf.view(2, 5, d, 1))[:, :, None, :, :] + phis[None, None, :, None, None]

print(theta0.shape)

# These phase differences are not wrapped into [-π, π), so this check covers the no-wrap case only
delta_thetas = (thetaf - theta0).view(2, 5, -1)
deltas_estimate = solve_for_deltas(delta_thetas, K_dagger, lattice_basis, alphas)

print("AVERAGE ERROR OF ESTIMATED DELTA FROM TRUE DELTA: ", torch.norm(deltas[:, :, None, :] - deltas_estimate, dim=-1).mean().item())
alphas: [0.3141592741012573, 0.44428831338882446, 0.6283184885978699, 0.8885766267776489, 1.2566369771957397, 1.7771530151367188, 2.5132739543914795, 3.5543060302734375, 5.026547908782959, 7.108611583709717, 10.053094863891602, 14.217223167419434, 20.106189727783203, 28.434444427490234, 40.21237564086914, 56.86888885498047, 80.42475128173828, 113.73777770996094, 160.84950256347656, 227.47552490234375]
K: [[-0.866025447845459, 0.4999999701976776], [1.5730305946703993e-08, -1.0], [0.866025447845459, 0.5]]
K_dagger: [[-0.5773501992225647, 1.0486870039017049e-08, 0.5773502588272095], [0.3333333432674408, -0.666666567325592, 0.3333333134651184]]
torch.Size([2, 5, 20, 3, 1])
AVERAGE ERROR OF ESTIMATED DELTA FROM TRUE DELTA:  1.7676826757906383e-07

Finally, we have a loss function to regularize over $\Delta\theta_j$!

def loss_for_deltas(delta_thetas: torch.Tensor, K_dagger: torch.Tensor, lattice_basis: torch.Tensor, alphas: torch.Tensor):
    deltas = solve_for_deltas(delta_thetas, K_dagger, lattice_basis, alphas)

    # deltas has shape (batch_size, time_steps, J, d)
    return deltas.var(dim=-2).mean()

delta_thetas_noise = delta_thetas + torch.randn_like(delta_thetas)

print("LOSS OF NOISY DELTA_THETAS: ", loss_for_deltas(delta_thetas_noise, K_dagger, lattice_basis, alphas).item())
LOSS OF NOISY DELTA_THETAS:  1.8048560619354248
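As a small worked example of $\mathcal{L}_{\text{consistency}}$ (NumPy, with made-up per-module estimates): the loss is the unbiased sample variance of the $J$ displacement estimates $\Delta x_j$ around their mean $\Delta X$, summed over spatial coordinates, so it vanishes exactly when all modules agree. `loss_for_deltas` computes `deltas.var(dim=-2).mean()`, which is the same quantity divided by $d$ and averaged over the batch.

```python
import numpy as np


def consistency_loss(deltas: np.ndarray) -> float:
    """deltas: (J, d) per-module displacement estimates; returns (1/(J-1)) * sum_j ||dx_j - mean||^2."""
    J = deltas.shape[0]
    mean = deltas.mean(axis=0)
    return float(np.sum((deltas - mean) ** 2) / (J - 1))


agreeing = np.tile([1.5, -0.5], (4, 1))               # all four modules give the same displacement
disagreeing = np.array([[1.5, -0.5], [1.5, -0.5], [1.5, -0.5], [3.5, -0.5]])
print(consistency_loss(agreeing))     # 0.0
print(consistency_loss(disagreeing))  # 1.0
```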

Creating and Verifying Location Codes

To work with location codes on a physical space of dimension $d$, we need to be able to create valid locations. In general, our location code is a tensor of shape $(J, d+1, 2)$, where the first dimension indexes the modules, the second indexes the vertices of the simplex, and the third contains the $\cos$ and $\sin$ values. To be valid, each $(\cos\theta, \sin\theta)$ pair must satisfy $\cos^2\theta + \sin^2\theta = 1$, that is, it must lie on the unit circle. But there is a deeper notion of validity as well: the code must have a unique underlying physical interpretation.

To begin, we need to sample a random location code. One way to do this is to sample a set of phase angles and generate the code from there. Most of the ingredients for this transformation are given above: we've computed the simplex vertices $K$ and the parameters $\alpha_j = \pi\rho^j / s$ for $\rho = \sqrt{2}$. The final piece is the phase $\phi_{i,j}$, which we define as a combination of two factors:

$$\phi_{i,j} = \tilde{\phi}_j + \xi_i$$

where $\tilde{\phi}_j$ is uniform on $[-\pi, \pi)$ and $\xi_i = \frac{2\pi i}{d+1}$. This makes the phases differ among all components $i, j$ of the code; in particular, the offsets $\xi_i$ space the phases of the $d+1$ wave vectors $k_i$ evenly around the circle. (The `sample` function below leaves the $\xi_i$ term commented out, so its phases are $\tilde{\phi}_j$ alone.) To say that $\phi_{i,j}$ are the phases assumes the sample represents the encoding of the origin $x = 0$, but this is acceptable because the term $\alpha_j K x$ for a general $x \neq 0$ merely shifts the phases of the components, and such a shift can be absorbed into the unknown phases. So for any set of $n$ samples taken by this method, we can pick one to be the origin and use it to extract a consistent set of phases $\phi_{i,j}$ and consistent physical locations for the other $n-1$ points.
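A minimal NumPy sketch of this sampling scheme, including the $\xi_i$ offsets and assuming the $(J, d+1, 2)$ layout above; `sample_code` and its arguments are illustrative names, not the notebook's `sample` function. It also checks the first notion of validity, that every $(\cos, \sin)$ pair lies on the unit circle.

```python
import numpy as np


def sample_code(K, alphas, x, phi_tilde):
    """Code of shape (J, d+1, 2) for position x, with phases phi_ij = phi_tilde_j + 2 pi i / (d+1)."""
    num_dirs = K.shape[0]
    xi = 2.0 * np.pi * np.arange(num_dirs) / num_dirs                                # xi_i, shape (d+1,)
    theta = alphas[:, None] * (K @ x)[None, :] + phi_tilde[:, None] + xi[None, :]    # (J, d+1)
    return np.stack([np.cos(theta), np.sin(theta)], axis=-1)


rng = np.random.default_rng(3)
angles = 2.0 * np.pi * np.arange(3) / 3.0
K = np.stack([np.cos(angles), np.sin(angles)], axis=1)           # d = 2
alphas = np.pi * np.sqrt(2.0) ** np.arange(4) / 10.0             # J = 4 modules
phi_tilde = rng.uniform(-np.pi, np.pi, size=4)                   # one random phase per module
code = sample_code(K, alphas, rng.normal(size=2), phi_tilde)
print(code.shape)                                                 # (4, 3, 2)
print(np.allclose(np.linalg.norm(code, axis=-1), 1.0))           # True
```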

from typing import Tuple

# One random phase per module: location_dim // 2 // (d + 1) = J phases
phis = torch.empty((location_dim // 2 // (d + 1),)).uniform_(-math.pi, math.pi)

def sample(location_dim: int, K: torch.Tensor, alphas: torch.Tensor, phis: torch.Tensor, shape: Tuple[int, ...] = torch.Size()):
    """
    Sample random physical positions x of the given batch shape and return their
    Fourier location codes, flattened to shape (*shape, location_dim).
    """
    physical_dim = K.shape[1]
    x = torch.randn(shape + (physical_dim, 1))  # (..., d, 1)
    print(f"sample picked x={x.detach().cpu().numpy().tolist()}")
    while K.ndim < x.ndim:
        K = K[None, ...]
    Kx = (K @ x).squeeze(-1)   # (..., d+1)
    alphas = alphas[..., None] # (J, 1)
    while alphas.ndim < Kx.ndim:
        alphas = alphas[None, ...]
    aKx = alphas * Kx[..., None, :]  # (..., J, d+1)

    phis = phis[..., None]  # (J, 1): the same phase for every simplex direction in a module
    while phis.ndim < aKx.ndim:
        phis = phis[None, ...]

    thetas = aKx + phis  # the ξ_i offsets are omitted here
    return torch.stack([torch.cos(thetas), torch.sin(thetas)], dim=-1).view(shape + (location_dim,))

l1 = sample(120, K, alphas, phis)
l2 = sample(120, K, alphas, phis)

print(f"l1.shape: {l1.shape} code: {l1.detach().cpu().numpy().tolist()}")
print(f"l2.shape: {l2.shape} code: {l2.detach().cpu().numpy().tolist()}")
sample picked x=[[2.1085588932037354], [-1.3823367357254028]]
sample picked x=[[0.08212031424045563], [1.2365736961364746]]
l1.shape: torch.Size([120]) code: [-0.7163147330284119, -0.6977773308753967, 0.4137597382068634, -0.9103861451148987, 0.34181222319602966, -0.9397682547569275, 0.6467894911766052, -0.7626685500144958, 0.6485604643821716, 0.7611631155014038, 0.7281549572944641, 0.6854125261306763, 0.8047559261322021, 0.5936058163642883, -0.9984386563301086, 0.05585923790931702, -0.9777466654777527, 0.20978902280330658, -0.9895644783973694, 0.14409072697162628, 0.984043538570404, 0.17792800068855286, 0.99916011095047, -0.04097708687186241, -0.16963337361812592, -0.9855072498321533, -0.9998469352722168, -0.017496973276138306, -0.9572534561157227, 0.28925055265426636, 0.017376990988850594, 0.9998490214347839, -0.588797390460968, 0.8082806468009949, -0.18869440257549286, 0.9820358753204346, 0.8791793584823608, 0.47649094462394714, -0.6428601741790771, -0.765983521938324, -0.9687423706054688, -0.24806904792785645, -0.9874424338340759, 0.15797913074493408, -0.4219028055667877, -0.9066410660743713, -0.9674668312072754, -0.25299787521362305, 0.9121777415275574, -0.4097948670387268, 0.9461743235588074, 0.32365739345550537, 0.6104447245597839, -0.7920588850975037, 0.9812467694282532, -0.19275572896003723, -0.7332056164741516, 0.6800069808959961, 0.8051489591598511, 0.5930726528167725, -0.87924724817276, 0.4763656556606293, -0.5342934727668762, -0.8452990651130676, -0.09026530385017395, 0.9959177374839783, 0.859729528427124, 0.5107495784759521, 0.8410654067993164, -0.5409334301948547, -0.5834247469902039, 0.8121671676635742, -0.9495474100112915, 0.3136235773563385, 0.8989977240562439, -0.43795329332351685, 0.6563684940338135, 0.7544404864311218, -0.306147962808609, -0.9519839286804199, -0.5771567821502686, 0.8166333436965942, 0.13694500923156738, 0.9905786514282227, 0.4679951071739197, 0.8837310671806335, 0.6857665777206421, 0.7278215289115906, -0.9583046436309814, -0.28574851155281067, -0.7178304195404053, 0.696217954158783, -0.46898889541625977, -0.8832040429115295, -0.9120305180549622, 
0.4101223349571228, 0.7137883305549622, 0.7003614902496338, 0.9716621041297913, 0.23637409508228302, 0.6879958510398865, -0.7257146239280701, -0.38272079825401306, 0.9238640666007996, 0.8172910809516907, -0.5762250423431396, -0.8856114745140076, 0.4644269049167633, 0.18837592005729675, 0.9820970296859741, 0.9533213376998901, 0.3019576072692871, -0.22200298309326172, -0.9750459790229797, -0.42945265769958496, -0.9030894041061401, 0.6317500472068787, -0.7751721143722534, 0.8111048340797424, -0.5849007964134216]
l2.shape: torch.Size([120]) code: [0.1634671688079834, -0.9865487813949585, -0.38589251041412354, -0.9225437641143799, 0.2073732614517212, -0.9782618880271912, 0.8804200291633606, 0.4741947054862976, 0.9557943940162659, -0.294035941362381, 0.8487162590026855, 0.5288484692573547, -0.8361117243766785, 0.5485591292381287, 0.13023075461387634, 0.9914836883544922, -0.8817344903945923, 0.4717460572719574, 0.8455346822738647, -0.5339204668998718, -0.545868992805481, -0.8378705382347107, 0.90609210729599, -0.4230804741382599, -0.5130985379219055, 0.858329713344574, 0.9913099408149719, -0.1315467804670334, -0.6575261950492859, 0.7534316778182983, 0.7543420791625977, 0.6564815640449524, -0.7726732492446899, -0.6348039507865906, 0.5661889910697937, 0.8242754340171814, -0.33759403228759766, 0.9412918090820312, -0.8398900032043457, -0.5427566766738892, -0.6456234455108643, 0.763655960559845, 0.2593931555747986, 0.9657717943191528, 0.3137822151184082, 0.9494949579238892, -0.2407698780298233, 0.9705822467803955, -0.7472859621047974, 0.6645026206970215, 0.9642717838287354, -0.26491492986679077, -0.9999295473098755, 0.011868349276483059, -0.920486330986023, 0.39077481627464294, -0.8702760934829712, 0.49256423115730286, -0.8198588490486145, -0.5725656747817993, -0.4484948515892029, 0.8937854170799255, -0.9823259711265564, 0.1871781349182129, -0.9479047656059265, -0.3185538947582245, 0.9926824569702148, 0.120754174888134, 0.9943798780441284, -0.10587082803249359, -0.5417043566703796, 0.8405690789222717, -0.03338322415947914, 0.9994426369667053, -0.9560878276824951, -0.2930803894996643, -0.24582013487815857, -0.9693154692649841, -0.9091147780418396, -0.41654571890830994, -0.9996976256370544, 0.024589691311120987, 0.23607660830020905, 0.9717344641685486, 0.21604879200458527, -0.9763825535774231, -0.6786207556724548, 0.734488844871521, -0.3388988971710205, -0.9408227801322937, 0.757413387298584, 0.6529356241226196, 0.9806544184684753, -0.1957469880580902, -0.8112629055976868, 
0.5846815705299377, -0.5766915082931519, 0.8169620037078857, -0.9949922561645508, -0.0999518483877182, 0.49022790789604187, 0.8715943098068237, 0.21265260875225067, -0.9771278500556946, -0.9994928240776062, 0.0318438820540905, -0.631900429725647, 0.7750495672225952, -0.4898953139781952, -0.8717812895774841, 0.9995584487915039, 0.029714152216911316, -0.3668069541454315, 0.9302970767021179, -0.7180269956588745, -0.6960152983665466, 0.9617215394973755, 0.2740286588668823, 0.13802990317344666, -0.9904280304908752]

Now, if we identify $\ell_1$ with the origin $x_1 = 0$, then we can describe $\ell_2$ by the block rotation required to get from $\ell_1$ to $\ell_2$, and from there we can get to the displacement $\Delta x$ such that $x_2 = x_1 + \Delta x$ represents $\ell_2$.

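To do this, we use atan2 and the formula

$$\Delta\theta_{ij} = \mathrm{atan2}\left(\eta^{(1)}_{ij} \times \eta^{(2)}_{ij},\; \eta^{(1)}_{ij}\cdot\eta^{(2)}_{ij}\right)$$

where $\eta^{(1)}_{ij}$ and $\eta^{(2)}_{ij}$ are the corresponding components of $\ell_1$ and $\ell_2$, and $\times$ denotes the signed scalar cross product $u_1 v_2 - v_1 u_2$ of two 2-D vectors. Writing $\eta^{(k)}_{ij} = (\cos\theta_k, \sin\theta_k)^T$, this resolves to

$$\Delta\theta_{ij} = \mathrm{atan2}\left(\cos\theta_1 \sin\theta_2 - \sin\theta_1\cos\theta_2,\; \cos\theta_1\cos\theta_2 + \sin\theta_1\sin\theta_2\right) = \mathrm{atan2}\left(\sin(\theta_2 - \theta_1),\; \cos(\theta_2 - \theta_1)\right)$$

after which we can call `solve_for_deltas` above.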
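A quick NumPy check that the atan2 formula returns the signed phase difference $\theta_2 - \theta_1$, already wrapped into $(-\pi, \pi]$ (the angles are arbitrary illustrative values):

```python
import numpy as np

theta1 = np.array([0.2, 3.0, -2.9, 1.0])
theta2 = np.array([1.1, -3.0, 2.9, 1.0])
c1, s1, c2, s2 = np.cos(theta1), np.sin(theta1), np.cos(theta2), np.sin(theta2)

# atan2(cross, dot) with cross = c1 s2 - s1 c2 and dot = c1 c2 + s1 s2
delta = np.arctan2(c1 * s2 - s1 * c2, c1 * c2 + s1 * s2)

# Reference: the raw difference theta2 - theta1 wrapped into (-pi, pi]
wrapped = np.angle(np.exp(1j * (theta2 - theta1)))
print(np.allclose(delta, wrapped))  # True
print(delta)  # the second and third entries wrap to about 0.28 and -0.48 instead of -6.0 and 5.8
```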

Note that we allow for codes that have no physical interpretation, which means that our results have shape $(\ldots, J, d)$: $J$ different physical interpretations, one per module. We also return the mean of these interpretations. For real physical locations, the variance across modules should be small.

def reshape_to_components(location: torch.Tensor, physical_dim: int):
    """Reshape a flat code (..., L) into (..., J, d+1, 2) (cos, sin) components."""
    shape = tuple(list(location.shape[:-1]) + [-1, physical_dim + 1, 2])

    return location.view(*shape)

def compute_angles(
    location1: torch.Tensor, 
    location2: torch.Tensor, 
    physical_dim: int=2
):
    """Signed phase differences Δθ_ij from location1 to location2, flattened to (..., J*(d+1))."""
    location1 = reshape_to_components(location1, physical_dim)  # Now has shape (..., J, d+1, 2)
    location2 = reshape_to_components(location2, physical_dim)  # Now has shape (..., J, d+1, 2)

    # sin(θ2 - θ1) = s2 c1 - c2 s1 and cos(θ2 - θ1) = c2 c1 + s2 s1
    s2c1 = location2[..., 1] * location1[..., 0].float()
    c2s1 = location2[..., 0] * location1[..., 1].float()
    c2c1 = location2[..., 0] * location1[..., 0].float()
    s2s1 = location2[..., 1] * location1[..., 1].float()
    # The 1e-6 added to the cosine term shifts each angle by at most about 1e-6 rad
    delta_thetas = torch.atan2(s2c1 - c2s1, c2c1 + s2s1 + 1e-6)

    delta_thetas = delta_thetas.view(delta_thetas.shape[:-2] + (-1,))
    return delta_thetas

def compute_displacements(
    location1: torch.Tensor, 
    location2: torch.Tensor, 
    K_dagger: torch.Tensor, 
    lattice_basis: torch.Tensor, 
    alphas: torch.Tensor, 
):
    """Per-module displacements (..., J, d) from location1 to location2, and their mean (..., d)."""
    physical_dim = K_dagger.shape[0]
    delta_thetas = compute_angles(location1, location2, physical_dim)

    # estimated displacements from thetas, made as small as possible solving across alphas -- but may not agree!
    # deltas has shape (..., J, d)
    deltas = solve_for_deltas(
        delta_thetas, 
        K_dagger, 
        lattice_basis, 
        alphas
    )
    mean_deltas = deltas.mean(dim=-2)

    return deltas.to(location1.dtype), mean_deltas.to(location1.dtype)

delta_thetas = compute_angles(l1, l2, 2)

print(f"delta_thetas.shape: {delta_thetas.shape}")
print(f"delta_thetas: {delta_thetas.detach().cpu().numpy().tolist()}")

deltas, mean_deltas = compute_displacements(l1, l2, K_dagger, lattice_basis, alphas)

print(f"deltas: {deltas.detach().cpu().numpy().tolist()}")
print(f"mean_deltas: {mean_deltas.detach().cpu().numpy().tolist()}")
delta_thetas.shape: torch.Size([60])
delta_thetas: [0.9627096652984619, -0.8227543830871582, -0.13995537161827087, 1.3614771366119385, -1.1635503768920898, -0.19792671501636505, 1.925419569015503, -1.6455087661743164, -0.27991050481796265, 2.7229557037353516, -2.327101945877075, -0.3958534300327301, -2.4323434829711914, 2.9921655654907227, -0.559821367263794, -0.8372727036476135, 1.6289799213409424, -0.7917068004608154, 1.4184958934783936, -0.2988537549972534, -1.1196424961090088, -1.674545407295227, -3.0252232551574707, -1.583414077758789, 2.836993455886841, -0.5977075695991516, -2.239285945892334, 2.934088945388794, 0.232738196849823, 3.1163547039031982, -0.6091985702514648, -1.1954123973846436, 1.8046104907989502, -0.41500645875930786, 0.46547645330429077, -0.050475697964429855, -1.218399167060852, -2.39082407951355, -2.6739628314971924, -0.8300091624259949, 0.9309605956077576, -0.10094953328371048, -2.436805486679077, 1.501538634300232, 0.9352618455886841, -1.6600226163864136, 1.8619219064712524, -0.20189906656742096, 1.4095646142959595, 3.0030789375305176, 1.8705166578292847, 2.963141918182373, -2.5593390464782715, -0.4038057327270508, 2.8191308975219727, -0.2770266532897949, -2.542149543762207, -0.3570764362812042, 1.164566159248352, -0.8075658679008484]
deltas: [[-2.0264368057250977, 2.618908166885376], [-2.0264365673065186, 2.618908166885376], [-2.0264370441436768, 2.618908643722534], [-2.026437520980835, 2.6189095973968506], [-2.0264387130737305, 2.618910551071167], [-2.0264384746551514, 2.618910789489746], [-2.026437759399414, 2.618910074234009], [-3.0470595359802246, 2.0296547412872314], [-2.0264382362365723, 2.618910074234009], [-1.5161278247833252, 2.9135377407073975], [-2.0264384746551514, 2.618910074234009], [-2.0264382362365723, 2.618910074234009], [-2.2068605422973633, 2.5147433280944824], [-2.0264387130737305, 2.618910312652588], [-2.0264387130737305, 2.618910551071167], [-2.0264384746551514, 2.618910312652588], [-1.9813331365585327, 2.6449520587921143], [-2.0264384746551514, 2.618910312652588], [-2.0264387130737305, 2.618910074234009], [-2.0264384746551514, 2.618910074234009]]
mean_deltas: [-2.0587196350097656, 2.6002724170684814]

A Metric for Fourier Codes

What we are really interested in is a way to compare location codes that respects the structure of the code. We could have used $\|\ell - \ell'\|$, which is defined on $[-1, 1]^L$, but in a Fourier location code the different modules $j$ have different wavelengths $\lambda_j = 2\pi/\alpha_j$, so different components should have different weights. Since we've sorted the modules from coarsest to finest, the early dimensions end up mattering much more than the later ones. We can leverage the foregoing formulae to compose a pseudometric over location codes: extract the block rotations that align the two locations, read off $\Delta\theta_{ij}$, compute $\Delta x$, and finally return $\|\Delta x\|$ as our distance. In other words, we return the norm of the physical displacement underlying the two codes. In terms of codes, I believe it is a full metric; but in terms of the underlying physical space it is a pseudometric, because finite-dimensional location codes cannot make distinctions that are too large or too small.

In addition to the distance, we can consider the case where a “location code” has varying physical interpretations across modules jj. In this case, we can compute a mean physical interpretation and add the variance to our distance as well.

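def pseudo_distance(
    location1: torch.Tensor,
    location2: torch.Tensor,
    K_dagger: torch.Tensor,
    lattice_basis: torch.Tensor,
    alphas: torch.Tensor,
    squared: bool = False,
    use_variance: bool = True,
):
    """Physical-displacement pseudometric between two Fourier location codes.

    Returns the norm of the mean per-module displacement, optionally adding the
    unbiased variance of the per-module displacements to the squared distance.
    """
    # deltas: per-module displacements (..., J, d); mean_deltas: their mean over modules (..., d)
    deltas, mean_deltas = compute_displacements(location1, location2, K_dagger, lattice_basis, alphas)
    J = deltas.shape[-2]
    assert J > 1, "J must be greater than 1"

    squared_distances = mean_deltas.square().sum(dim=-1)

    if use_variance:
        # unbiased variance across modules, summed over physical dimensions
        dev_deltas = deltas - mean_deltas[..., None, :]
        variances = dev_deltas.square().sum(dim=-1).mean(dim=-1) * (J / (J - 1))  # (...)
        final_squared_distances = squared_distances + variances
    else:
        final_squared_distances = squared_distances

    if squared:
        return final_squared_distances
    # the small constant keeps the gradient of sqrt finite at zero distance
    return (final_squared_distances + 1e-12).sqrt()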


dist = pseudo_distance(l1, l2, K_dagger, lattice_basis, alphas)
print(f"dist: {dist.detach().cpu().numpy().tolist()}")
dist: 3.330477714538574

Now, when implementing TEM variants, we need not just a single distance but the cross-distance among many items. That is, given $\{\ell_i\}$ and $\{\ell_j\}$, we want $\{d(\ell_i, \ell_j)\}$ for all pairs $i, j$. We only support the case where $\ell_i$ has shape $(B, T, L)$ and $\ell_j$ has shape $(B, S, L)$, which gives a result of shape $(B, T, S)$. This is easily done with shape tricks:
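The shape trick is ordinary broadcasting: insert a singleton axis on each side so that every $(t, s)$ pair is visited once. The sketch uses a plain Euclidean distance as a stand-in for pseudo_distance. This is an assumption; any function that reduces over the last axis and broadcasts over the rest behaves the same way. The helper name pairwise is hypothetical.

```python
import numpy as np


def pairwise(fn, a, b):
    """Apply fn(x, y) -> (...,) to every pair drawn from a (B, T, L) and b (B, S, L)."""
    # a[..., :, None, :] is (B, T, 1, L) and b[..., None, :, :] is (B, 1, S, L),
    # so broadcasting inside fn evaluates all (t, s) pairs and yields (B, T, S).
    return fn(a[..., :, None, :], b[..., None, :, :])


def euclidean(x, y):
    return np.sqrt(((x - y) ** 2).sum(axis=-1))


rng = np.random.default_rng(0)
a = rng.normal(size=(2, 10, 6))
b = rng.normal(size=(2, 8, 6))
D = pairwise(euclidean, a, b)
print(D.shape)  # (2, 10, 8)
```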

def cross_distance(
    location1: torch.Tensor,  # (B, T, L)
    location2: torch.Tensor,  # (B, S, L)
    K_dagger: torch.Tensor,
    lattice_basis: torch.Tensor,
    alphas: torch.Tensor,
    squared: bool = False,
):
    # (B, T, 1, L) against (B, 1, S, L): broadcasting inside pseudo_distance gives (B, T, S)
    return pseudo_distance(
        location1[..., :, None, :], location2[..., None, :, :],
        K_dagger, lattice_basis, alphas, squared=squared,
    )

B = 2
T = 10
S = 8
li = sample(120, K, alphas, phis, (B, T))
lj = sample(120, K, alphas, phis, (B, S))

cross_dist = cross_distance(li, lj, K_dagger, lattice_basis, alphas)
print(f"cross_dist: {cross_dist}")
print(f"Min dist = {cross_dist.min().item()} Mean dist = {cross_dist.mean().item()} Max dist = {cross_dist.max().item()}")
sample picked x=[[[[-0.8145899772644043], [-0.569575309753418]], [[0.32764700055122375], [-0.9747219085693359]], [[2.815774917602539], [-2.0576868057250977]], [[0.35869482159614563], [0.6466221213340759]], [[-0.3521421551704407], [-0.3835611641407013]], [[-0.8626407384872437], [-1.2137911319732666]], [[-0.7666817307472229], [-0.2413574904203415]], [[1.1334725618362427], [2.322641372680664]], [[0.5391432642936707], [0.39215847849845886]], [[-0.7231886386871338], [-0.7787725329399109]]], [[[1.922032356262207], [0.7315117716789246]], [[1.4502531290054321], [0.6809671521186829]], [[-0.43476399779319763], [-2.8658132553100586]], [[-0.7493019700050354], [0.9299961924552917]], [[-0.2567051351070404], [0.39329344034194946]], [[0.5350938439369202], [-0.032766442745923996]], [[0.04726826772093773], [-1.804565668106079]], [[0.8472246527671814], [0.23619519174098969]], [[-0.1825113743543625], [-0.9091907739639282]], [[-0.4195076525211334], [-1.171326994895935]]]]
sample picked x=[[[[2.3358025550842285], [-1.3003051280975342]], [[-1.249247431755066], [0.008154939860105515]], [[-0.17388926446437836], [0.19512693583965302]], [[-2.7206344604492188], [-0.8759704828262329]], [[-1.68034827709198], [0.7819029688835144]], [[-0.7264818549156189], [0.43826475739479065]], [[-0.6606448292732239], [-0.14920498430728912]], [[-0.44584688544273376], [2.2403321266174316]]], [[[-0.6501056551933289], [-0.5011352896690369]], [[0.7847572565078735], [-0.533992350101471]], [[-0.3528536260128021], [-1.6510188579559326]], [[-0.6515844464302063], [0.34739628434181213]], [[0.25842607021331787], [-1.6758942604064941]], [[-0.08893660455942154], [-0.366155207157135]], [[0.049642354249954224], [2.1746392250061035]], [[1.3555898666381836], [-0.771336019039154]]]]
cross_dist: tensor([[[3.4689, 0.7575, 0.9823, 1.8817, 1.6062, 1.1158, 0.4538, 3.1328],
         [2.2122, 2.0509, 1.2782, 3.1153, 2.9097, 1.7702, 1.4117, 3.4374],
         [0.8993, 4.9564, 4.0990, 6.1888, 5.8806, 4.6774, 4.3384, 5.4092],
         [2.9460, 1.7308, 0.6888, 3.4417, 2.2114, 1.0775, 1.2804, 1.7908],
         [3.0486, 1.0823, 0.6142, 2.3939, 1.8334, 0.9066, 0.4071, 2.8971],
         [3.2221, 1.2837, 1.6074, 2.0457, 2.2520, 1.7347, 1.1488, 3.6484],
         [3.4882, 0.5877, 0.7055, 2.0300, 1.4394, 0.7544, 0.1382, 2.6188],
         [4.0144, 3.2426, 2.4797, 4.8874, 3.1290, 2.6444, 2.9847, 1.5713],
         [2.5590, 1.8164, 0.7261, 3.4920, 2.4291, 1.2892, 1.2877, 2.0946],
         [3.3575, 0.9466, 1.1197, 1.9633, 1.8340, 1.3099, 0.7096, 3.3055]],

        [[2.7375, 1.6826, 3.1825, 2.5894, 2.8849, 2.2471, 2.4678, 1.7314],
         [2.3656, 1.3876, 2.8719, 2.0924, 2.6491, 1.8310, 2.1588, 1.5934],
         [2.4782, 2.6401, 1.2733, 3.3916, 1.3760, 2.6298, 5.6739, 2.7023],
         [1.5595, 2.2391, 2.8605, 0.6174, 2.9280, 1.5082, 1.4540, 2.9558],
         [1.0579, 1.4847, 2.2328, 0.3931, 2.2495, 0.8114, 2.0267, 2.1672],
         [1.2538, 0.5631, 1.8483, 1.3167, 1.7665, 0.7070, 2.3612, 1.1531],
         [1.4785, 1.4660, 0.4537, 2.3618, 0.2485, 1.5734, 4.3647, 1.6213],
         [1.6580, 0.8404, 2.2345, 1.6403, 2.1732, 1.0892, 2.1921, 1.1294],
         [0.6448, 1.0114, 0.8087, 1.4078, 0.8845, 0.5831, 3.3470, 1.5272],
         [0.7474, 1.3595, 0.5283, 1.6079, 0.8803, 0.8814, 3.5551, 1.8145]]])
Min dist = 0.13820236921310425 Mean dist = 2.040937900543213 Max dist = 6.188777446746826

Class Implementation and Relative Sampling

This code is implemented in tree_world.models.fourier_metric.FourierMetric, which stores $K$, $K^\dagger$, $\Lambda$, $\alpha$, and $\phi$ as buffers. It is a torch.nn.Module but does not generally have learnable parameters. This class can be used to sample and compare location codes.

from tree_world.models.fourier_metric import FourierMetric

metric = FourierMetric(location_dim=120, dim=2, scale=10.0, ratio=math.sqrt(2.0))

li = metric.sample((B, T))
lj = metric.sample((B, S))
dist = metric.cross_distance(li, lj)
print(f"dist: {dist}")
dist: tensor([[[0.9421, 1.5962, 2.3569, 2.3076, 1.9625, 1.9333, 1.5775, 1.8270],
         [1.5659, 1.7411, 1.1836, 2.5785, 0.4220, 1.1630, 2.9069, 0.4242],
         [2.2452, 2.4857, 1.7471, 3.4374, 0.5024, 1.8198, 3.4858, 1.0958],
         [1.2867, 1.6016, 1.3601, 2.5070, 0.6681, 1.4069, 2.5962, 0.7059],
         [3.0952, 2.6262, 1.7894, 2.9229, 3.1209, 1.9545, 4.0000, 2.4535],
         [1.5533, 1.5758, 0.0778, 2.3225, 1.2331, 0.4601, 2.7617, 0.6553],
         [0.3575, 0.2500, 1.4648, 1.1098, 2.0950, 0.9737, 1.3357, 1.5465],
         [1.5390, 1.2148, 0.7372, 1.7099, 1.9579, 0.5735, 2.4550, 1.3080],
         [1.5604, 1.7113, 0.8066, 2.5536, 0.5404, 0.8445, 2.9474, 0.0792],
         [1.2319, 0.6270, 1.8447, 0.5421, 2.5495, 1.4288, 1.4021, 2.0933]],

        [[2.3313, 1.7126, 0.9637, 1.1956, 1.2823, 1.0090, 1.0556, 1.2244],
         [3.3409, 2.9554, 1.1170, 0.9126, 2.2302, 2.1943, 2.3602, 0.4456],
         [1.9675, 1.8695, 0.5109, 0.5249, 0.7804, 1.2173, 2.4413, 1.1589],
         [3.3661, 2.8070, 1.0090, 0.8525, 2.0592, 1.9736, 2.0501, 0.2020],
         [4.1052, 3.4114, 2.2412, 2.3425, 2.9672, 2.6285, 1.2778, 1.8429],
         [2.4114, 1.7896, 1.2168, 1.5364, 1.5727, 1.1399, 0.7792, 1.3450],
         [3.1906, 3.2185, 1.4153, 1.1378, 2.0199, 2.5133, 3.2347, 1.4273],
         [1.9130, 2.1871, 1.3014, 1.2335, 1.1101, 1.7137, 3.3191, 1.8391],
         [3.3696, 2.8035, 0.9442, 0.7955, 2.0526, 1.9916, 2.1881, 0.2866],
         [3.3669, 2.7935, 2.3590, 2.4586, 2.6393, 2.1803, 0.6271, 2.2216]]])

The module tree_world.models.fourier_metric also defines a relative sampler in support of our TEM VAE. To understand this problem, suppose you have a reference location $\ell_0$ and you want to sample a neighboring location $\ell$ that preferentially lies in a neighborhood of scale $\epsilon > 0$ around it. Again, we can revert to physical space for our operations. We sample $\Delta x \sim \mathcal{N}(0, \epsilon^2 I)$, compute $\Delta\theta_{ij}$ and the block rotation $R(\Delta\theta_{ij})$, and apply this rotation to $\ell_0$. The result lies at a distance on the order of $\epsilon$ from $\ell_0$ under the pseudometric defined above. Here is the code:
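The sampling step can be sketched in numpy for a single reference code. The sketch makes these assumptions:

- The (J, d+1, 2) block layout used above.
- $d = 2$ simplex wave vectors.
- Made-up module frequencies.

The names rotate_blocks and sample_near are hypothetical, not the tree_world API. Like FourierCodeDistribution.sample, the sketch returns the displacement alongside the sample.

```python
import numpy as np

# Illustrative assumptions: 2-D simplex wave vectors and three module frequencies.
angles = np.deg2rad([90.0, 210.0, 330.0])
K = np.stack([np.cos(angles), np.sin(angles)], axis=1)   # (d+1, d)
alphas = np.array([1.0, np.sqrt(2.0), 2.0])              # (J,)


def rotate_blocks(code, dtheta):
    """Rotate each 2-D block of a (J, d+1, 2) code by the matching angle in dtheta (J, d+1)."""
    c, s = np.cos(dtheta), np.sin(dtheta)
    x, y = code[..., 0], code[..., 1]
    return np.stack([x * c - y * s, x * s + y * c], axis=-1)


def sample_near(code0, eps, rng):
    """Draw dx ~ N(0, eps^2 I), map it to block angles, and rotate the reference code."""
    dx = eps * rng.standard_normal(K.shape[1])           # (d,)
    dtheta = alphas[:, None] * (K @ dx)[None, :]         # (J, d+1): alpha_j k_i . dx
    return rotate_blocks(code0, dtheta), dx


rng = np.random.default_rng(0)
xi = rng.uniform(-np.pi, np.pi, size=(3, 3))             # reference phases xi_ij
code0 = np.stack([np.cos(xi), np.sin(xi)], axis=-1)      # (J, d+1, 2), unit-norm blocks
code, dx = sample_near(code0, eps=0.05, rng=rng)
```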

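import math
import torch
import torch.distributions as D
from typing import Optional, Tuple


class FourierCodeDistribution(D.Distribution):
    support = D.constraints.real
    has_rsample = True
    # no parameter constraints to validate; avoids torch's missing-arg_constraints warning
    arg_constraints = {}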

    def __init__(self, metric: FourierMetric, reference_location: torch.Tensor, scale: torch.Tensor, 
                 batch_lengths: Optional[torch.Tensor]=None, idx: Optional[torch.Tensor]=None, validate_args=None, tau: float=1e-6):
        self.metric = metric
        self.reference_location = reference_location
        self.batch_lengths = batch_lengths
        self.idx = idx
        self.dtype = reference_location.dtype
        self.device = reference_location.device
        self.scale = scale
        self.tau = tau

        batch_shape = reference_location.shape[:-1]

        assert scale.shape == () or scale.shape == batch_shape
        if scale.shape == ():
            self.scale = self.scale.expand(batch_shape)

        super().__init__(batch_shape=batch_shape, event_shape=(self.metric.location_dim,), validate_args=validate_args)

    def sample(self, sample_shape: Tuple[int, ...] = torch.Size()):
        # sample u ~ N(0, I_d); Delta x = scale * u is then N(0, scale^2 I_d)
        u = torch.randn(sample_shape + self.batch_shape + (self.metric.dim,), device=self.device, dtype=self.dtype)
        scale = self.scale[..., None]
        while scale.ndim < u.ndim:
            scale = scale[None, ...]
        deltas = scale * u
        locations = self.metric.apply_displacement(deltas, self.reference_location)
        return locations, deltas
    
    rsample = sample

    def log_det_jacobian(
        self, displacements: torch.Tensor
    ):
        # 0.5 * log det(J^T J) for the map Delta x -> l = R(Delta theta) l_0.
        # Each 2-D block (j, i) of l depends on Delta x only through alpha_j * (k_i . Delta x)
        # and has unit norm, so J^T J = (sum_j alpha_j^2) K^T K for every Delta x and l_0.
        output_shape = displacements.shape[:-1]

        K = self.metric.K.float()  # (dplus, d): one wave vector per row
        JTJ = self.metric.alphas.float().square().sum() * (K.transpose(-2, -1) @ K)  # (d, d)
        eye = torch.eye(self.metric.dim, device=JTJ.device, dtype=torch.float32)

        logdet = 0.5 * torch.logdet(JTJ + self.tau * eye)
        # the value is identical for every element, so broadcast it to the batch/sample shape
        return logdet.expand(output_shape).to(self.dtype)

    def log_prob(self, locations: torch.Tensor, displacements: Optional[torch.Tensor]=None):
        if displacements is None:
            # compute_displacements returns (per-module deltas, mean delta); use the mean
            _, displacements = self.metric.compute_displacements(self.reference_location, locations)

        scale = self.scale[..., None]
        while scale.ndim < displacements.ndim:
            scale = scale[None, ...]

        # standardize in a separate variable so the Jacobian term sees the raw displacement
        standardized = displacements / scale

        gaussian_log_prob = (
            - 0.5 * standardized.square().sum(dim=-1)
            - self.metric.dim * (0.5 * math.log(2.0 * math.pi) + torch.log(self.scale))
        )

        # change of variables from Delta x to the code: subtract 0.5 * log det(J^T J)
        return gaussian_log_prob - self.log_det_jacobian(displacements)

And, a demonstration:

l0 = metric.sample((B, T))
epsilon = torch.empty(B, T).uniform_(0.01, 0.1)

sampler = FourierCodeDistribution(metric, l0, epsilon)
l, deltas = sampler.sample()
print(f"l: {l.shape}")
dist = metric.pseudo_distance(l, l0)
print(f"Dist: {dist}")
print(f"Min dist = {dist.min().item()} Mean dist = {dist.mean().item()} Max dist = {dist.max().item()}")
print(f"Epsilon-scaled dist: {dist / epsilon}")
print(f"Epsilon-scaled dist min = {(dist / epsilon).min().item()} mean = {(dist / epsilon).mean().item()} max = {(dist / epsilon).max().item()}")

log_prob = sampler.log_prob(l, deltas)
print(f"log_prob: {log_prob.shape} min = {log_prob.min().item()} mean = {log_prob.mean().item()} +/- {log_prob.std().item()} max = {log_prob.max().item()}")
l: torch.Size([2, 10, 120])
Dist: tensor([[0.0567, 0.0542, 0.0332, 0.1041, 0.0506, 0.0575, 0.0133, 0.0610, 0.0355,
         0.0346],
        [0.0643, 0.0765, 0.0197, 0.2781, 0.1130, 0.0245, 0.0872, 0.0749, 0.0454,
         0.1082]])
Min dist = 0.013298843055963516 Mean dist = 0.0696159079670906 Max dist = 0.2781006693840027
Epsilon-scaled dist: tensor([[0.8619, 1.3081, 1.6165, 2.5906, 1.3057, 0.6437, 1.2321, 1.9036, 0.5559,
         0.3758],
        [1.0868, 1.2073, 1.1949, 3.0120, 1.1895, 1.0035, 1.1124, 1.0657, 2.9058,
         1.4086]])
Epsilon-scaled dist min = 0.37576478719711304 mean = 1.3790267705917358 max = 3.0119752883911133
log_prob: torch.Size([2, 10]) min = -12.616140365600586 mean = -8.568292617797852 +/- 1.6212801933288574 max = -5.113590717315674
/Users/alockett/dev/tree-world/venv/lib/python3.12/site-packages/torch/distributions/distribution.py:62: UserWarning: <class '__main__.FourierCodeDistribution'> does not define `arg_constraints`. Please set `arg_constraints = {}` or initialize the distribution with `validate_args=False` to turn off validation.
  warnings.warn(

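As we see, our relative sampler produces samples that are close to the reference point. After dividing by $\epsilon$, the distances have a mean of about 1.38. This is in line with expectations: for a 2-D Gaussian with per-axis standard deviation $\epsilon$, the expected norm is $\mathbb{E}\|\Delta x\| = \epsilon\sqrt{\pi/2} \approx 1.25\,\epsilon$, and 20 samples give only a noisy estimate of that mean.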
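For the true displacement, $\|\Delta x\|/\epsilon$ follows a Rayleigh distribution with mean $\sqrt{\pi/2}$. Here is a quick Monte Carlo check of that expected ratio, independent of the Fourier code:

```python
import numpy as np

# Draw many 2-D displacements with per-axis standard deviation eps and compare the
# mean eps-scaled norm with the Rayleigh mean sqrt(pi / 2).
rng = np.random.default_rng(0)
eps = 0.05
dx = eps * rng.standard_normal(size=(200_000, 2))
ratio = np.linalg.norm(dx, axis=-1).mean() / eps
print(ratio, np.sqrt(np.pi / 2))  # both close to 1.25
```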

The one piece that hasn't been explained is the log probability calculation, which is needed for the VAE loss in TEM. We'll cover that now.

Computing the Log Probability

In general, when we sample a variable $\ell$ by drawing a lower-dimensional latent $u$ and inverting a transform $u(\ell)$, the log probability is

$$\log p(\ell) = \log p(u) - \frac{1}{2} \log \det\left(J^T J\right)$$

where $J = J_u$ is the Jacobian matrix of the inverse map $u^{-1}$ (which produces $\ell$ from $u$) with respect to $u$. Because $J$ is not square, the usual $\left|\det J\right|$ is replaced by the volume factor $\sqrt{\det J^T J}$. The derivation for this is given in Building Pseudometrics -- Embedding Case.
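The volume factor $\sqrt{\det J^T J}$ can be sanity-checked on a linear map from $\mathbb{R}^2$ into $\mathbb{R}^3$. There it should equal the area of the parallelogram spanned by the two columns (Lagrange's identity), so a density pushed onto the image plane is divided by that area. The random matrix below is an arbitrary illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 2))                     # an arbitrary linear map R^2 -> R^3
gram_volume = np.sqrt(np.linalg.det(A.T @ A))   # sqrt(det(J^T J)) with J = A
cross_area = np.linalg.norm(np.cross(A[:, 0], A[:, 1]))  # area of the image of the unit square
print(gram_volume, cross_area)                  # equal up to rounding
```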

In our case, the inverse map is the composition of two functions: the block rotation of $\ell_0$ by the angles $\Delta\theta$, and the linear map from the physical displacement $\Delta x$ to those angles:

$$u^{-1}(\Delta\theta)_{ij} = R(\Delta\theta_{ij})\,\ell^0_{ij} = \begin{pmatrix} \ell^0_{ij0}\cos\Delta\theta_{ij} - \ell^0_{ij1}\sin\Delta\theta_{ij} \\ \ell^0_{ij0}\sin\Delta\theta_{ij} + \ell^0_{ij1}\cos\Delta\theta_{ij} \end{pmatrix} \quad\quad\text{and}\quad\quad \Delta\theta_{ij}(\Delta x) = \alpha_j \left(k_i \cdot \Delta x\right)$$

This yields Jacobians with components

$$J_u(\Delta\theta)_{ij} = \begin{pmatrix} -\ell^0_{ij0}\sin\Delta\theta_{ij} - \ell^0_{ij1}\cos\Delta\theta_{ij} \\ \ell^0_{ij0}\cos\Delta\theta_{ij} - \ell^0_{ij1}\sin\Delta\theta_{ij} \end{pmatrix} \quad\quad\text{and}\quad\quad J_{\Delta\theta_{ij}}(\Delta x) = \alpha_j k_i$$

and so, putting these components together,

$$J_u(\Delta x) = \left[\alpha_j k_i \begin{pmatrix} -\ell^0_{ij0}\sin\Delta\theta_{ij} - \ell^0_{ij1}\cos\Delta\theta_{ij} \\ \ell^0_{ij0}\cos\Delta\theta_{ij} - \ell^0_{ij1}\sin\Delta\theta_{ij} \end{pmatrix}\right]$$
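Here is a central finite-difference check of the composed Jacobian for a single block, using arbitrary values for $\alpha_j$, $k_i$, $\xi_{ij}$ and $\Delta x$. The analytic $2 \times d$ block should be the outer product of the rotated-derivative column $y_{ij}$ with the wave vector, scaled by $\alpha_j$.

```python
import numpy as np

# Illustrative assumptions: one block with frequency alpha, unit wave vector k, phase xi.
alpha = 1.7
k = np.array([np.cos(0.4), np.sin(0.4)])
xi = 0.9
l0 = np.array([np.cos(xi), np.sin(xi)])


def block(dx):
    """l_ij(dx) = R(alpha * k . dx) l0 for a single 2-D block."""
    t = alpha * (k @ dx)
    return np.array([l0[0] * np.cos(t) - l0[1] * np.sin(t),
                     l0[0] * np.sin(t) + l0[1] * np.cos(t)])


dx = np.array([0.2, -0.5])
t = alpha * (k @ dx)
y = np.array([-l0[0] * np.sin(t) - l0[1] * np.cos(t),
              l0[0] * np.cos(t) - l0[1] * np.sin(t)])
analytic = alpha * np.outer(y, k)                     # (2, d): alpha_j * y_ij * k_i^T

h = 1e-6
numeric = np.stack([(block(dx + h * e) - block(dx - h * e)) / (2 * h)
                    for e in np.eye(2)], axis=1)      # (2, d) central differences
print(np.max(np.abs(analytic - numeric)))             # tiny
```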

where the r.h.s. is interpreted as a tensor of shape $(J, d+1, 2, d)$ indexed by $(j, i, m, n)$ and reshaped to a matrix of shape $(L, d)$, where $L = 2J(d+1)$ flattens the first three dimensions. That means that $J^T J$ is a $d \times d$ matrix, and the only vectors involved in its two $d$-sized dimensions are the $k_i$ above.

Every row $(j, i, m)$ of $J$ is a multiple of a single wave vector $k_i$, so we can factor $J = MK$. Here $K$ is the $(d+1) \times d$ matrix whose rows are the $k_i$, and $M$ is the $L \times (d+1)$ matrix with entries $M_{(jim), a} = \alpha_j\, y_{ijm}\, \delta_{ai}$ (with $y_{ijm}$ defined below). Therefore $J^T J = K^T W K$, where the interaction matrix $W = M^T M$ is $(d+1)\times(d+1)$. Because row $(j, i, m)$ of $M$ is nonzero only in column $i$, two different wave vectors never share a row. Using $a$ to alias $i$,

$$W_{a,i} = \delta_{ai} \sum_{j} \alpha_j^2 \sum_{m\in \{0,1\}} y_{ajm}\, y_{ijm}$$

where

$$y_{ij0} = -\ell^0_{ij0}\sin\Delta\theta_{ij} - \ell^0_{ij1}\cos\Delta\theta_{ij}, \qquad y_{ij1} = \ell^0_{ij0}\cos\Delta\theta_{ij} - \ell^0_{ij1}\sin\Delta\theta_{ij}.$$

The products inside the sum expand to

$$y_{aj0}\, y_{ij0} = \ell^0_{aj0}\ell^0_{ij0}\sin\Delta\theta_{aj}\sin\Delta\theta_{ij} + \ell^0_{aj0}\ell^0_{ij1}\sin\Delta\theta_{aj}\cos\Delta\theta_{ij} + \ell^0_{aj1}\ell^0_{ij0}\cos\Delta\theta_{aj}\sin\Delta\theta_{ij} + \ell^0_{aj1}\ell^0_{ij1}\cos\Delta\theta_{aj}\cos\Delta\theta_{ij}$$

$$y_{aj1}\, y_{ij1} = \ell^0_{aj0}\ell^0_{ij0}\cos\Delta\theta_{aj}\cos\Delta\theta_{ij} - \ell^0_{aj0}\ell^0_{ij1}\cos\Delta\theta_{aj}\sin\Delta\theta_{ij} - \ell^0_{aj1}\ell^0_{ij0}\sin\Delta\theta_{aj}\cos\Delta\theta_{ij} + \ell^0_{aj1}\ell^0_{ij1}\sin\Delta\theta_{aj}\sin\Delta\theta_{ij}$$

The inner sum over $m$ allows us to use the angle formulae

$$\sin(\alpha \pm \beta) = \sin\alpha\cos\beta \pm \cos\alpha\sin\beta, \qquad \cos(\alpha \pm \beta) = \cos\alpha\cos\beta \mp \sin\alpha\sin\beta,$$

yielding

$$y_{aj0}y_{ij0} + y_{aj1}y_{ij1} = \left(\ell^0_{aj0}\ell^0_{ij0} + \ell^0_{aj1}\ell^0_{ij1}\right)\cos\left(\Delta\theta_{aj} - \Delta\theta_{ij}\right) + \left(\ell^0_{aj0}\ell^0_{ij1} - \ell^0_{aj1}\ell^0_{ij0}\right)\sin\left(\Delta\theta_{aj} - \Delta\theta_{ij}\right).$$

Recall that as a location code, $\ell^0_{ij0} = \cos\xi_{ij}$ and $\ell^0_{ij1} = \sin\xi_{ij}$ for some angle $\xi_{ij}$. Therefore,

$$\ell^0_{aj0}\ell^0_{ij0} + \ell^0_{aj1}\ell^0_{ij1} = \cos\left(\xi_{aj} - \xi_{ij}\right) \quad\quad\text{and}\quad\quad \ell^0_{aj0}\ell^0_{ij1} - \ell^0_{aj1}\ell^0_{ij0} = -\sin\left(\xi_{aj} - \xi_{ij}\right),$$

so

$$y_{aj0}y_{ij0} + y_{aj1}y_{ij1} = \cos\left(\xi_{aj} - \xi_{ij}\right)\cos\left(\Delta\theta_{aj} - \Delta\theta_{ij}\right) - \sin\left(\xi_{aj} - \xi_{ij}\right)\sin\left(\Delta\theta_{aj} - \Delta\theta_{ij}\right).$$

Applying the identities one last time, we find that

$$y_{aj0}y_{ij0} + y_{aj1}y_{ij1} = \cos\left(\left(\xi_{aj} - \xi_{ij}\right) + \left(\Delta\theta_{aj} - \Delta\theta_{ij}\right)\right),$$

which is the cosine of the angle between the two rotated blocks. The factor $\delta_{ai}$ keeps only $a = i$, where this cosine equals 1 (each block has unit norm). That leaves

$$W_{a,i} = \delta_{ai}\sum_j \alpha_j^2, \qquad\text{i.e.}\qquad J^T J = \Big(\sum_j \alpha_j^2\Big) K^T K.$$

Hence $\log\det\left|J^T J\right| = \log\det\left|K^T W K\right|$ does not depend on $\Delta\theta$ or on $\ell_0$. Intuitively, $\Delta x$ moves each block around its unit circle at a rate that does not depend on where the block starts, so the map stretches volume uniformly. For $d+1$ unit wave vectors forming a regular simplex, $K^T K = \frac{d+1}{d} I_d$, which gives

$$\frac{1}{2}\log\det J^T J = \frac{d}{2}\log\left(\frac{d+1}{d}\sum_j \alpha_j^2\right).$$
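Every row of $J$ involves a single $k_i$, and each block has unit norm, so one would expect $J^T J = \left(\sum_j \alpha_j^2\right) K^T K$ regardless of $\Delta x$ and $\ell_0$. The sketch below checks this numerically. It builds the full $(L, d)$ Jacobian of $\Delta x \mapsto \ell$ by central finite differences, for random reference phases and a random displacement. The wave vectors (a $d = 2$ regular simplex) and module frequencies are illustrative assumptions, not the values used by FourierMetric.

```python
import numpy as np

# Illustrative assumptions: d = 2 regular-simplex unit wave vectors, four modules.
angles = np.deg2rad([90.0, 210.0, 330.0])
K = np.stack([np.cos(angles), np.sin(angles)], axis=1)   # (d+1, d)
alphas = np.array([1.0, 1.5, 2.2, 3.1])                  # (J,)
rng = np.random.default_rng(1)
xi = rng.uniform(-np.pi, np.pi, size=(len(alphas), 3))   # reference phases xi_ij


def code(dx):
    """Flattened l(dx): blocks (cos, sin) of xi_ij + alpha_j k_i . dx, shape (L,) = (2 J (d+1),)."""
    phase = xi + alphas[:, None] * (K @ dx)[None, :]       # (J, d+1)
    return np.stack([np.cos(phase), np.sin(phase)], axis=-1).reshape(-1)


def jacobian(dx, h=1e-6):
    """Central finite-difference Jacobian of code() at dx, shape (L, d)."""
    cols = [(code(dx + h * e) - code(dx - h * e)) / (2 * h) for e in np.eye(K.shape[1])]
    return np.stack(cols, axis=1)


Jmat = jacobian(rng.normal(size=2))
predicted = (alphas ** 2).sum() * (K.T @ K)              # (sum_j alpha_j^2) K^T K
print(np.max(np.abs(Jmat.T @ Jmat - predicted)))         # tiny
print(np.allclose(K.T @ K, 1.5 * np.eye(2)))             # True: (d+1)/d = 3/2 for d = 2
```

Changing the seed or the displacement leaves $J^T J$ unchanged, which suggests the log-determinant could be precomputed once per metric.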

def calculate_jacobian_interaction(
    reference_location: torch.Tensor, alphas: torch.Tensor, delta_thetas: torch.Tensor,
    K: torch.Tensor, physical_dim: int, J: int, tau: float=1e-6
):
    # Returns 0.5 * log det(K^T W K) = 0.5 * log det(J^T J), one value per batch element.
    # reference_location has shape (..., L) with L = 2 * J * (physical_dim + 1).
    dplus = physical_dim + 1
    assert reference_location.shape[-1] == 2 * J * dplus
    batch_shape = reference_location.shape[:-1]

    # Row (j, i, m) of the Jacobian is alpha_j * y_ijm * k_i, so W = M^T M is diagonal, and
    # y_ij0^2 + y_ij1^2 = 1 because each block has unit norm: W = (sum_j alpha_j^2) I.
    # delta_thetas and the phases of reference_location therefore do not enter the result;
    # they are kept in the signature only for compatibility.
    W = alphas.square().sum() * torch.eye(dplus, device=K.device, dtype=K.dtype)
    KWK = K.transpose(-2, -1) @ W @ K  # K is (dplus, d), so KWK is (d, d)

    # tau * I regularizes the log-determinant
    eye = torch.eye(physical_dim, device=KWK.device, dtype=KWK.dtype)
    return (0.5 * torch.logdet(KWK + tau * eye)).expand(batch_shape)

Conclusion

We've built a metric for Fourier location codes that respects the underlying physical geometry. Together with the embedding pseudometric we'll build next, we can use this metric to implement an advanced spatial memory for TEM-t.

Addendum: Test encoding

from importlib import reload
import tree_world.fourier
import tree_world.models.fourier_metric
reload(tree_world.fourier)
reload(tree_world.models.fourier_metric)

metric = tree_world.models.fourier_metric.FourierMetric(location_dim=120, dim=2, scale=1000.0, ratio=math.sqrt(2.0))

locations = torch.randn(10, 2) * 10
encoded = metric.encode(locations)
decoded_with_j, decoded = metric.interpret(encoded)

print(f"decoded_with_j: {decoded_with_j.shape}")
print(f"decoded: {decoded.shape}")

diff = torch.norm(locations - decoded, dim=-1)

print(f"locations: {locations.detach().cpu().numpy().tolist()}")
print(f"decoded: {decoded.detach().cpu().numpy().tolist()}")
print(f"diff: {diff.detach().cpu().numpy().tolist()}")
print(f"decoded_with_j[base]: {decoded_with_j[..., 0, :].detach().cpu().numpy().tolist()}")
thetas (-1): torch.Size([10, 3]), [1.9679633378982544, -13.361233711242676, 11.393270492553711]
thetas (0): torch.Size([10, 20, 3]), [[-2.225909948348999, -2.2740678787231445, -2.1962993144989014], [-2.2488722801208496, -2.2824606895446777, -2.1649441719055176], [-2.225297451019287, -2.2031636238098145, -2.2678160667419434], [-2.238586902618408, -2.2492616176605225, -2.2084288597106934], [-2.2219350337982178, -2.2988901138305664, -2.1754519939422607], [-2.1928365230560303, -2.206786870956421, -2.2966537475585938], [-2.2440671920776367, -2.224576473236084, -2.227633476257324], [-2.2316999435424805, -2.250920057296753, -2.2136573791503906], [-2.2012805938720703, -2.268103837966919, -2.2268927097320557], [-2.2615439891815186, -2.205914258956909, -2.228818893432617]]
locations (1): torch.Size([10, 20, 3, 2]), [[[-0.6092493534088135, -0.7929787039756775], [-0.3557904064655304, -0.9345657825469971], [-0.8461326360702515, -0.5329723954200745], [0.9849703907966614, 0.1727232038974762], [0.4764881432056427, -0.879180908203125], [0.9761363863945007, 0.21715831756591797], [-0.6183295249938965, -0.7859189510345459], [0.36158695816993713, 0.9323384165763855], [-0.7824238538742065, -0.6227462887763977], [-0.789602518081665, 0.6136186718940735], [0.5790926218032837, -0.8152617812156677], [-0.7545367479324341, -0.6562577486038208], [0.5645982027053833, 0.8253659009933472], [-0.028533324599266052, 0.9995928406715393], [0.4315015375614166, 0.9021121859550476], [0.999570369720459, -0.029309246689081192], [-0.36351820826530457, -0.9315871000289917], [-0.8710347414016724, -0.49122142791748047], [0.32901322841644287, 0.94432532787323], [-0.15151476860046387, -0.9884549975395203]], [[-0.6272957921028137, -0.7787811160087585], [-0.3859463930130005, -0.9225212335586548], [-0.8697085976600647, -0.49356555938720703], [0.9941037893295288, 0.10843255370855331], [0.3938405215740204, -0.9191787838935852], [0.996041476726532, 0.08888974040746689], [-0.7514882683753967, -0.6597464680671692], [0.5889506340026855, 0.8081690073013306], [-0.9538923501968384, -0.30014896392822266], [-0.38072511553764343, 0.9246882796287537], [-0.11691465228796005, -0.9931419491767883], [-0.9481881856918335, 0.3177092671394348], [0.8781832456588745, -0.47832438349723816], [0.8874616026878357, -0.46088168025016785], [-0.24134846031665802, -0.970438539981842], [-0.5024516582489014, 0.8646053075790405], [0.032755788415670395, -0.9994633793830872], [-0.054159343242645264, 0.9985322952270508], [-0.45675885677337646, 0.8895905613899231], [0.8775416016578674, 0.47950056195259094]], [[-0.6087635159492493, -0.7933517098426819], [-0.3549808859825134, -0.9348735809326172], [-0.8454792499542236, -0.5340083241462708], [0.9846697449684143, 0.1744290292263031], [0.47864025831222534, 
-0.8780111074447632], [0.9753782749176025, 0.22053857147693634], [-0.6144718527793884, -0.7889387011528015], [0.35511866211891174, 0.9348212480545044], [-0.7762845158576965, -0.6303827166557312], [-0.7980292439460754, 0.6026186943054199], [0.5949568152427673, -0.8037576675415039], [-0.7360618114471436, -0.6769143342971802], [0.531823992729187, 0.8468549251556396], [-0.08386636525392532, 0.9964770078659058], [0.3595351576805115, 0.9331315755844116], [0.9966772794723511, 0.08145192265510559], [-0.21360936760902405, -0.9769191741943359], [-0.7416945099830627, -0.6707378029823303], [0.021709362044930458, 0.9997643232345581], [0.28721901774406433, -0.9578649401664734]], [[-0.6192526817321777, -0.7851917743682861], [-0.3724871277809143, -0.9280373454093933], [-0.8593721389770508, -0.5113506317138672], [0.9905291199684143, 0.13730277121067047], [0.43131333589553833, -0.9022021889686584], [0.9891870021820068, 0.146659716963768], [-0.6947205066680908, -0.7192797660827637], [0.4911356568336487, 0.8710830807685852], [-0.8918324112892151, -0.4523659944534302], [-0.5837292075157166, 0.8119484186172485], [0.21036912500858307, -0.9776220321655273], [-0.9899130463600159, -0.14167632162570953], [0.9873029589653015, 0.15884873270988464], [0.8995989561080933, 0.4367170035839081], [0.8785343170166016, -0.47767919301986694], [-0.6840437650680542, -0.7294409275054932], [0.4580058753490448, 0.8889491558074951], [0.594248354434967, -0.8042815923690796], [0.5164297819137573, 0.8563295602798462], [-0.09344228357076645, 0.9956247210502625]], [[-0.6060925126075745, -0.7953941226005554], [-0.35053154826164246, -0.9365509152412415], [-0.8418691754341125, -0.5396816730499268], [0.982966423034668, 0.18378528952598572], [0.49040526151657104, -0.8714945316314697], [0.9710074067115784, 0.23904940485954285], [-0.5930308103561401, -0.8051797747612], [0.3193094730377197, 0.9476504921913147], [-0.7412649393081665, -0.6712125539779663], [-0.8415241241455078, 0.5402194857597351], [0.677829384803772, 
-0.7352192401885986], [-0.6249541640281677, -0.780661404132843], [0.33872780203819275, 0.9408844113349915], [-0.3786039352416992, 0.9255588054656982], [-0.06256643682718277, 0.9980407953262329], [0.7711268067359924, 0.6366815567016602], [0.601586639881134, -0.7988075613975525], [0.3724956214427948, -0.9280339479446411], [-0.9916909337043762, -0.12864302098751068], [0.40382176637649536, 0.914837658405304]], [[-0.582694411277771, -0.812691330909729], [-0.31170493364334106, -0.9501789808273315], [-0.8090534806251526, -0.5877350568771362], [0.9645299911499023, 0.2639734148979187], [0.5882955193519592, -0.808646023273468], [0.9187104105949402, 0.3949318826198578], [-0.39128577709198, -0.9202691912651062], [-0.004213245119899511, 0.9999911189079285], [-0.3610319197177887, -0.9325534701347351], [-0.9961541891098022, -0.08761728554964066], [0.9944750070571899, 0.10497363656759262], [0.5986244082450867, -0.8010298609733582], [-0.9985398650169373, 0.05401952192187309], [-0.1193096786737442, -0.9928570985794067], [0.6017131209373474, -0.7987123131752014], [0.9474664330482483, -0.3198551535606384], [0.9711562991142273, 0.2384437471628189], [-0.9968783259391785, 0.07895282655954361], [0.7772918343544006, -0.6291401982307434], [-0.9728303551673889, -0.23151901364326477]], [[-0.6235464215278625, -0.781786322593689], [-0.3796687424182892, -0.9251224994659424], [-0.8649253845214844, -0.5019004940986633], [0.9925383925437927, 0.12193258106708527], [0.41143330931663513, -0.9114398956298828], [0.9932577013969421, 0.11592713743448257], [-0.7255787253379822, -0.6881391406059265], [0.5441682934761047, 0.8389760851860046], [-0.928022027015686, -0.3725253939628601], [-0.4788154363632202, 0.8779155611991882], [0.03656914457678795, -0.9993311166763306], [-0.9944016933441162, 0.10566579550504684], [0.9817719459533691, -0.19006264209747314], [0.999028205871582, -0.044075630605220795], [0.36281266808509827, -0.9318621158599854], [-0.9848017692565918, 0.17368227243423462], [0.9529551863670349, 
-0.30311119556427], [-0.975242018699646, -0.22114025056362152], [-0.20560991764068604, -0.9786340594291687], [-0.6691880226135254, -0.7430931329727173]], [[-0.6138304471969604, -0.7894378900527954], [-0.36343085765838623, -0.9316211938858032], [-0.8522475957870483, -0.523138701915741], [0.9876668453216553, 0.15657033026218414], [0.45600035786628723, -0.8899796009063721], [0.9827241897583008, 0.185076043009758], [-0.6540570259094238, -0.7564452290534973], [0.42184171080589294, 0.9066694974899292], [-0.8366773724555969, -0.5476961135864258], [-0.7026738524436951, 0.7115121483802795], [0.4189927279949188, -0.9079896211624146], [-0.898777425289154, -0.43840518593788147], [0.8251713514328003, 0.5648825168609619], [0.47548264265060425, 0.8797250986099243], [0.9273526668548584, 0.37418848276138306], [0.4736085534095764, -0.880735456943512], [-0.9600874185562134, 0.2797001302242279], [0.01191070955246687, 0.9999290704727173], [-0.15748026967048645, -0.9875221252441406], [0.933087944984436, 0.35964834690093994]], [[-0.5895360112190247, -0.8077421188354492], [-0.32302916049957275, -0.9463890194892883], [-0.8188635110855103, -0.5739882588386536], [0.9705588817596436, 0.24086399376392365], [0.5606520771980286, -0.8280514478683472], [0.9365200400352478, 0.35061413049697876], [-0.4525127410888672, -0.8917579650878906], [0.09119383245706558, 0.995833158493042], [-0.4833519160747528, -0.8754261136054993], [-0.9946653842926025, 0.10315409302711487], [0.9864112138748169, -0.16429509222507477], [0.2567394971847534, -0.9664806723594666], [-0.8284469246864319, 0.5600675344467163], [-0.7731962203979492, -0.6341668367385864], [-0.4215902090072632, -0.9067864418029785], [-0.2795475125312805, -0.9601318836212158], [-0.34301578998565674, -0.9393296241760254], [0.9999850392341614, 0.0054725054651498795], [0.28731897473335266, 0.9578349590301514], [-0.9200234413146973, -0.39186328649520874]], [[-0.637113630771637, -0.7707698941230774], [-0.40241557359695435, -0.9154571294784546], 
[-0.8819366097450256, -0.47136804461479187], [0.9973508715629578, 0.07274086028337479], [0.3467644155025482, -0.9379522800445557], [0.9998499155044556, 0.017324326559901237], [-0.8143966197967529, -0.5803086757659912], [0.6983744502067566, 0.7157325744628906], [-0.9947921633720398, -0.10192417353391647], [-0.10366562008857727, 0.9946122169494629], [-0.49920180439949036, -0.8664857745170593], [-0.6241377592086792, 0.7813143134117126], [0.258105993270874, -0.9661166071891785], [-0.055076178163290024, -0.9984821677207947], [-0.9568203091621399, 0.29068002104759216], [0.9807372093200684, -0.1953318864107132], [0.06954678893089294, 0.9975786805152893], [-0.984035313129425, -0.17797325551509857], [-0.26636478304862976, 0.9638723134994507], [-0.7319835424423218, -0.6813223361968994]]]
thetas (1): torch.Size([10, 20, 3]), [[-2.225909948348999, -2.2740678787231445, -2.1962993144989014], [-2.2488722801208496, -2.2824606895446777, -2.1649441719055176], [-2.225297451019287, -2.2031636238098145, -2.2678160667419434], [-2.238586902618408, -2.2492616176605225, -2.2084288597106934], [-2.2219350337982178, -2.2988901138305664, -2.1754519939422607], [-2.1928365230560303, -2.206786870956421, -2.2966537475585938], [-2.2440671920776367, -2.224576473236084, -2.227633476257324], [-2.2316999435424805, -2.250920057296753, -2.2136573791503906], [-2.2012805938720703, -2.268103837966919, -2.2268927097320557], [-2.2615439891815186, -2.205914258956909, -2.228818893432617]]
thetas (2): torch.Size([10, 20, 3]), [[1.9679292440414429, -13.361215591430664, 11.39328670501709], [-5.341207981109619, -16.032730102539062, 21.373937606811523], [2.162893295288086, 9.208309173583984, -11.37120246887207], [-2.067270517349243, -5.465137958526611, 7.532332420349121], [3.2331838607788086, -21.262378692626953, 18.029193878173828], [12.495527267456055, 8.054993629455566, -20.550521850585938], [-3.8117008209228516, 2.392387628555298, 1.4193133115768433], [0.12491656839847565, -5.993035316467285, 5.868042945861816], [9.807696342468262, -11.46280288696289, 1.655106544494629], [-9.374737739562988, 8.332755088806152, 1.0419832468032837]]
xs (1): torch.Size([10, 20, 2]), [[5.441732406616211, 13.361213684082031], [15.423995971679688, 16.03272819519043], [-7.813913345336914, -9.208308219909668], [5.542333126068115, 5.46511173248291], [8.54248046875, 21.26237678527832], [-19.079145431518555, -8.05499267578125], [3.020127058029175, -2.3923873901367188], [3.315795421600342, 5.993009090423584], [-4.706899166107178, 11.462801933288574], [6.014095783233643, -8.332754135131836]]
decoded_with_j: torch.Size([10, 20, 2])
decoded: torch.Size([10, 2])
locations: [[5.441702842712402, 13.361233711242676], [15.423995018005371, 16.032732009887695], [-7.813913822174072, -9.208306312561035], [5.542320251464844, 5.465102672576904], [8.54250717163086, 21.26236915588379], [-19.079151153564453, -8.054960250854492], [3.0201385021209717, -2.3924241065979004], [3.315791130065918, 5.993025302886963], [-4.7069244384765625, 11.462799072265625], [6.0141096115112305, -8.332731246948242]]
decoded: [[5.441984176635742, 13.354339599609375], [15.423995018005371, 16.025999069213867], [-7.813673496246338, -9.214279174804688], [5.548147678375244, 5.46878719329834], [8.536678314208984, 21.265735626220703], [-19.047008514404297, -8.08303451538086], [3.014796018600464, -2.36527419090271], [3.3157920837402344, 5.987553596496582], [-4.678760051727295, 11.453174591064453], [6.013868808746338, -8.338705062866211]]
diff: [0.006899849511682987, 0.006732940673828125, 0.005977695342153311, 0.006894533988088369, 0.0067311739549040794, 0.04267685115337372, 0.027670564129948616, 0.005471706390380859, 0.02976345643401146, 0.005978667177259922]
decoded_with_j[base]: [[5.441732406616211, 13.361213684082031], [15.423995971679688, 16.03272819519043], [-7.813913345336914, -9.208308219909668], [5.542333126068115, 5.46511173248291], [8.54248046875, 21.26237678527832], [-19.079145431518555, -8.05499267578125], [3.020127058029175, -2.3923873901367188], [3.315795421600342, 5.993009090423584], [-4.706899166107178, 11.462801933288574], [6.014095783233643, -8.332754135131836]]
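The printed values can be checked by hand against the Fourier encoding described earlier. Each row of `thetas (2)` appears to hold the three projections $k_i \cdot x$ for one sample. The three entries in each row sum to roughly zero, as they should when the wave vectors are unit vectors spaced $120^\circ$ apart. The specific directions used below ($150^\circ$, $270^\circ$, $30^\circ$, in that order) are an **assumption inferred from the printed numbers**, not taken from the notebook's code.

For three such vectors, $\sum_i k_i = 0$ and $\sum_i k_i k_i^\top = \tfrac{3}{2} I$. Position can therefore be recovered as $x = \tfrac{2}{3}\sum_i (k_i \cdot x)\, k_i$, which reproduces `xs (1)`. The `diff` row appears to be the per-sample Euclidean error $\lVert x - \hat{x} \rVert_2$ between `locations` and `decoded`. The sketch below is plain Python and ignores the module scales $\alpha_j$ and phase offsets $\phi_j$. The helper names are hypothetical.

```python
import math

# ASSUMPTION: wave-vector directions inferred from the printed output (150°, 270°, 30°),
# ordered to match the three columns of `thetas (2)`.
K = [(math.cos(math.radians(a)), math.sin(math.radians(a))) for a in (150.0, 270.0, 30.0)]

def project(x, ks=K):
    """Hypothetical helper: projections k_i . x (no alpha_j scale, no phi_j offset)."""
    return [kx * x[0] + ky * x[1] for kx, ky in ks]

def invert(thetas, ks=K):
    """Hypothetical helper: recover x via x = (2/3) * sum_i (k_i . x) k_i,
    valid because sum_i k_i k_i^T = (3/2) I for three unit vectors 120 degrees apart."""
    sx = sum(t * kx for t, (kx, _) in zip(thetas, ks))
    sy = sum(t * ky for t, (_, ky) in zip(thetas, ks))
    return (2.0 / 3.0 * sx, 2.0 / 3.0 * sy)

def errors(locations, decoded):
    """Per-sample Euclidean distance ||location - decoded||_2."""
    return [math.hypot(a[0] - b[0], a[1] - b[1]) for a, b in zip(locations, decoded)]

# Values copied from the printed output above.
THETAS_2 = [[1.9679292440414429, -13.361215591430664, 11.39328670501709], [-5.341207981109619, -16.032730102539062, 21.373937606811523],
            [2.162893295288086, 9.208309173583984, -11.37120246887207], [-2.067270517349243, -5.465137958526611, 7.532332420349121],
            [3.2331838607788086, -21.262378692626953, 18.029193878173828], [12.495527267456055, 8.054993629455566, -20.550521850585938],
            [-3.8117008209228516, 2.392387628555298, 1.4193133115768433], [0.12491656839847565, -5.993035316467285, 5.868042945861816],
            [9.807696342468262, -11.46280288696289, 1.655106544494629], [-9.374737739562988, 8.332755088806152, 1.0419832468032837]]
XS_1 = [[5.441732406616211, 13.361213684082031], [15.423995971679688, 16.03272819519043], [-7.813913345336914, -9.208308219909668],
        [5.542333126068115, 5.46511173248291], [8.54248046875, 21.26237678527832], [-19.079145431518555, -8.05499267578125],
        [3.020127058029175, -2.3923873901367188], [3.315795421600342, 5.993009090423584], [-4.706899166107178, 11.462801933288574],
        [6.014095783233643, -8.332754135131836]]
LOCATIONS = [[5.441702842712402, 13.361233711242676], [15.423995018005371, 16.032732009887695], [-7.813913822174072, -9.208306312561035],
             [5.542320251464844, 5.465102672576904], [8.54250717163086, 21.26236915588379], [-19.079151153564453, -8.054960250854492],
             [3.0201385021209717, -2.3924241065979004], [3.315791130065918, 5.993025302886963], [-4.7069244384765625, 11.462799072265625],
             [6.0141096115112305, -8.332731246948242]]
DECODED = [[5.441984176635742, 13.354339599609375], [15.423995018005371, 16.025999069213867], [-7.813673496246338, -9.214279174804688],
           [5.548147678375244, 5.46878719329834], [8.536678314208984, 21.265735626220703], [-19.047008514404297, -8.08303451538086],
           [3.014796018600464, -2.36527419090271], [3.3157920837402344, 5.987553596496582], [-4.678760051727295, 11.453174591064453],
           [6.013868808746338, -8.338705062866211]]
DIFF = [0.006899849511682987, 0.006732940673828125, 0.005977695342153311, 0.006894533988088369, 0.0067311739549040794,
        0.04267685115337372, 0.027670564129948616, 0.005471706390380859, 0.02976345643401146, 0.005978667177259922]

recovered = [invert(t) for t in THETAS_2]
errs = errors(LOCATIONS, DECODED)
print("max |recovered - xs(1)|:", max(max(abs(r[0] - x[0]), abs(r[1] - x[1])) for r, x in zip(recovered, XS_1)))
print("max |errs - diff|:", max(abs(e - d) for e, d in zip(errs, DIFF)))
```

In the notebook's own `torch` code, the error line would most likely be `torch.linalg.norm(locations - decoded, dim=-1)`. This is a guess at how `diff` was produced, but it is consistent with the printed values.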