A metric is a notion of distance generalized to arbitrary spaces. Formally, a metric on a set $X$ is a function $d: X \times X \to \mathbb{R}$ with four key properties:

Nonnegativity: $d(x, y) \ge 0$

Uniqueness: if $d(x, y) = 0$, then $x = y$

Symmetry: $d(x, y) = d(y, x)$

Triangle Inequality: $d(x, z) \le d(x, y) + d(y, z)$

A pseudometric lacks the uniqueness property: it has $x \ne y$ with $d(x, y) = 0$.
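As a quick illustration, here is a minimal check of these properties for the Euclidean metric on $\mathbb{R}^2$, alongside a simple pseudometric (distance in the first coordinate only, a made-up example) that violates uniqueness:

```python
import math

def euclidean(p, q):
    # A true metric on R^2: satisfies all four properties.
    return math.hypot(p[0] - q[0], p[1] - q[1])

def first_coord(p, q):
    # A pseudometric: distance depends only on the first coordinate,
    # so distinct points can sit at distance zero.
    return abs(p[0] - q[0])

x, y, z = (0.0, 0.0), (3.0, 4.0), (0.0, 7.0)
assert euclidean(x, y) >= 0                                  # nonnegativity
assert euclidean(x, x) == 0 and euclidean(x, y) > 0          # uniqueness
assert euclidean(x, y) == euclidean(y, x)                    # symmetry
assert euclidean(x, z) <= euclidean(x, y) + euclidean(y, z)  # triangle inequality
assert first_coord(x, z) == 0 and x != z                     # uniqueness fails
```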
For neural localization, we need to compare two spaces: locations and sensory percepts. In general, the representation spaces for these don’t naturally work with the Euclidean metric, especially for locations as described by grid cell codes. So we’ll first define an explicit metric for the Fourier space of grid cells in this notebook, and then, in another notebook, consider how to learn a metric that is useful for comparing sensory percepts and that could also be applied to location space.
We need these metrics to build spatial memories for non-Euclidean spaces to support TEM.
Fourier Location Encoding¶
Grid cells in $d$ dimensions (usually $d = 2$) fire with a rate described by

$$r(\mathbf{x}) = A \cos(\alpha\, \mathbf{k} \cdot \mathbf{x} + \phi)$$

when the organism is at position $\mathbf{x}$, where $A$ is a fixed amplitude and $\mathbf{k}$ is a wave vector, one of $d + 1$ unit vectors whose convex hull forms the simplex. In 2-D, these vectors start at the origin and are radially separated by $120°$. In 3-D, four wavevectors are needed, pointing to the vertices of a triangular pyramid.
We can model these cells via an assemblage of column vectors having the form

$$\begin{pmatrix} \cos(\alpha\, \mathbf{k} \cdot \mathbf{x} + \phi) \\ \sin(\alpha\, \mathbf{k} \cdot \mathbf{x} + \phi) \end{pmatrix}$$

where the phases $\phi$ are chosen to cover the interval $[-\pi, \pi)$. Sufficiently wide choices of $\phi$ will cover the full circle.
Grid cell responses are periodic as the animal moves around its environment, and thanks to the identity

$$\begin{pmatrix} \cos(\theta + \Delta\theta) \\ \sin(\theta + \Delta\theta) \end{pmatrix} = \begin{pmatrix} \cos\Delta\theta & -\sin\Delta\theta \\ \sin\Delta\theta & \cos\Delta\theta \end{pmatrix} \begin{pmatrix} \cos\theta \\ \sin\theta \end{pmatrix}$$

we can arrive at a relationship between displacement in physical space and rotation in Fourier space. To do this, we form a matrix $K$ whose rows range over the $\mathbf{k}_i^\top$, so that $K$ is a $(d+1) \times d$ matrix. The rows of $K$ span $\mathbb{R}^d$, so the matrix $K^\top K$ is invertible, and the pseudoinverse of this matrix is $K^\dagger = (K^\top K)^{-1} K^\top$, which has $K^\dagger K = I_d$. Then we have

$$\boldsymbol{\theta} = \alpha K \mathbf{x} + \boldsymbol{\phi} \pmod{2\pi}$$

where $\boldsymbol{\theta}$ is the vector whose components are the phase angles $\theta_i = \alpha\, \mathbf{k}_i \cdot \mathbf{x} + \phi_i$. Solving for $\mathbf{x}$, we find

$$\mathbf{x} \equiv \frac{1}{\alpha} K^\dagger (\boldsymbol{\theta} - \boldsymbol{\phi})$$

up to the wrap-around ambiguity in $\boldsymbol{\theta}$.
With sufficiently many modules $J$ and carefully chosen $\alpha_j$, these modular constraints can be solved for a wide range of $\mathbf{x}$, and $\mathbf{x}$ can be identified exactly on this range.
So the firing rates of an assemblage of grid cells with different phases can be stably and reliably inverted to retrieve the physical location .
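The encode/decode relationship can be sketched for a single module. This is a minimal illustration, not the full machinery below: it hard-codes three unit wavevectors at $120°$ separation for $d = 2$, picks an illustrative $\alpha$ and $\phi$, and assumes $\|\mathbf{x}\|$ is small enough that no phase wraps past $2\pi$, so no modular correction is needed.

```python
import math
import torch

# Three unit wavevectors separated by 120 degrees; their rows sum to zero.
K = torch.tensor([[math.cos(a), math.sin(a)]
                  for a in (math.pi / 2,
                            math.pi / 2 + 2 * math.pi / 3,
                            math.pi / 2 + 4 * math.pi / 3)])
alpha, phi = 0.5, 0.3            # illustrative module frequency and phase
K_dagger = torch.linalg.pinv(K)  # (2, 3), with K_dagger @ K == I_2

x = torch.tensor([0.7, -0.4])               # true location
theta = alpha * (K @ x) + phi               # encode: module phase angles
x_hat = (K_dagger @ (theta - phi)) / alpha  # decode: invert the linear map

print(torch.allclose(x, x_hat, atol=1e-5))  # True
```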
From One Location to the Next¶
We need a function $\ell' = g(\ell, a)$ to update locations $\ell$ based on actions $a$.
We will assume our actions act on location as translations in physical space. If we assume Fourier location codes (like above), then a displacement $\Delta\mathbf{x}$ becomes a block-wise rotation of each $(\cos\theta, \sin\theta)$ pair. Recall that a rotation by angle $\Delta\theta$ is

$$R(\Delta\theta) = \begin{pmatrix} \cos\Delta\theta & -\sin\Delta\theta \\ \sin\Delta\theta & \cos\Delta\theta \end{pmatrix}$$

so notating $\Delta\theta_{j,i} = \alpha_j\, \mathbf{k}_i \cdot \Delta\mathbf{x}$, we have

$$\begin{pmatrix} \cos(\theta_{j,i} + \Delta\theta_{j,i}) \\ \sin(\theta_{j,i} + \Delta\theta_{j,i}) \end{pmatrix} = R(\Delta\theta_{j,i}) \begin{pmatrix} \cos\theta_{j,i} \\ \sin\theta_{j,i} \end{pmatrix}$$
Beyond this rotational update, the Fourier coding can remain implicit in our location-learning scheme. Thus we do not need to know $\alpha_j$ or $\mathbf{k}_i$; instead we allow $\Delta\theta$ to differ for each pair $(j, i)$. Then we can learn a function $f_\psi(a)$ with parameters $\psi$ and apply it as a block-diagonal matrix multiply as follows:

$$\ell' = \operatorname{blockdiag}\big(R(f_\psi(a)_1), \ldots, R(f_\psi(a)_M)\big)\, \ell$$

This treats $\ell$ as a sequence of $M$ pairs of components under some enumeration of $(j, i)$. The following code implements this scheme more efficiently, without assembling the full block-diagonal matrix. Each application extends the input sequence by one.
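A minimal sketch of such a pairwise update (assuming the code stores interleaved $(\cos, \sin)$ pairs along the last dimension, as in the `sample` function below; `rotate_pairs` is a hypothetical helper name): the angle-addition formulas apply each $2\times 2$ rotation directly, without materializing the block-diagonal matrix.

```python
import torch

def rotate_pairs(location: torch.Tensor, delta_thetas: torch.Tensor) -> torch.Tensor:
    """Apply per-pair 2x2 rotations without building the block-diagonal matrix.

    location:     (..., 2*M) with interleaved (cos, sin) pairs.
    delta_thetas: (..., M) rotation angle for each pair.
    """
    pairs = location.view(*location.shape[:-1], -1, 2)
    c, s = pairs[..., 0], pairs[..., 1]
    cd, sd = torch.cos(delta_thetas), torch.sin(delta_thetas)
    # angle addition: cos(t + d) = c*cd - s*sd, sin(t + d) = s*cd + c*sd
    return torch.stack([c * cd - s * sd, s * cd + c * sd], dim=-1).view_as(location)

thetas = torch.tensor([0.1, 1.2, -0.5])
code = torch.stack([thetas.cos(), thetas.sin()], dim=-1).reshape(-1)  # (6,)
d = torch.tensor([0.3, -0.7, 2.0])
rotated = rotate_pairs(code, d)
expected = torch.stack([(thetas + d).cos(), (thetas + d).sin()], dim=-1).reshape(-1)
print(torch.allclose(rotated, expected, atol=1e-5))  # True
```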
Making Location Codes Physical¶
A location code represents a physical point. A system of location codes has to represent different points consistently. That restricts the codes that can simultaneously be valid.
When it comes to changes in location, simply updating the location codes by rotation does not guarantee a unique solution for the displacement $\Delta\mathbf{x}$; different modules $j$ can yield different displacements. In effect, these location codes have extra flexibility that can lead to a suboptimal spatial representation: the same position in real space could have many different codes, which would prevent us from generating an accurate map of local space, using the map to navigate to areas of interest, or judging which objects are close to each other.
Note that in our geometric decoder, we have calculated $\Delta\mathbf{x} = \frac{1}{\alpha_j} K^\dagger \Delta\boldsymbol{\theta}_j$. In general, our location codes should adhere to

$$\Delta\boldsymbol{\theta}_j \equiv \alpha_j K \Delta\mathbf{x} \pmod{2\pi}$$

based on the formulae above. Our problem is that different modules $j$ will give us different values of $\Delta\mathbf{x}$; we use $\Delta\mathbf{x}_j$ to represent the value given by the $j$-th row of the matrix $\Delta\boldsymbol{\theta}$, which has size $J \times (d+1)$. Hence we have $J$ estimators for a random variable $\Delta\mathbf{x}$, and we can minimize the variance of these estimators

$$\operatorname{Var}_j\!\left[\Delta\mathbf{x}_j\right]$$

but first we must solve for the $\Delta\mathbf{x}_j$. To do this, we first canonicalize $\Delta\boldsymbol{\theta} \in [-\pi, \pi)^{J \times (d+1)}$ and compute $\Delta\mathbf{x}_j^{\text{base}} = \frac{1}{\alpha_j} K^\dagger \Delta\boldsymbol{\theta}_j$, and then note that

$$\Delta\mathbf{x}_j = \Delta\mathbf{x}_j^{\text{base}} + \frac{2\pi}{\alpha_j} K^\dagger \mathbf{n}_j \quad \text{for some } \mathbf{n}_j \in \mathbb{Z}^{d+1}$$
Our problem is that we have too many variables for too few constraints ($d+1$ integer unknowns per module for a $d$-dimensional displacement). We need to eliminate a variable, and because we chose the $\mathbf{k}_i$ to be the vertices of the simplex, the structure of our problem lets us do that. We can choose a “basis” that generates the wrap-around lattice as follows. Let

$$U = \begin{pmatrix} \mathbf{e}_1 - \mathbf{e}_{d+1} & \cdots & \mathbf{e}_d - \mathbf{e}_{d+1} \end{pmatrix} = \begin{pmatrix} I_d \\ -\mathbf{1}^\top \end{pmatrix}$$

where the $\mathbf{e}_i$ are the standard basis vectors of $\mathbb{R}^{d+1}$, that is, $\mathbf{e}_i$ is the $i$-th row or column of the identity matrix $I_{d+1}$. For any $\mathbf{n} \in \mathbb{Z}^{d+1}$ and $c \in \mathbb{Z}$, we then have that

$$K^\dagger (\mathbf{n} + c\,\mathbf{1}) = K^\dagger \mathbf{n}$$

because $K^\dagger \mathbf{1} = \mathbf{0}$, which is a consequence of choosing the $\mathbf{k}_i$ to be the vertices of the simplex. Now, we define

$$B_j = \frac{2\pi}{\alpha_j} K^\dagger U$$

which reduces the number of integer unknowns by one, allowing us to find solutions.
Let’s compute $B_j$, $K$, and $K^\dagger$ for all modules $j$.
import math
import torch
def make_lattice_basis(alphas, dim: int = 2):
    """
    Construct lattice basis matrices B_j for each grid module j in d dimensions.

    Args:
        alphas: 1D tensor of shape (J,) with spatial frequencies α_j.
        dim: spatial dimension d (>=1).

    Returns:
        B: tensor of shape (J, d, d), where B[j] is the lattice basis B_j.
        K: tensor of shape (d+1, d) with simplex directions as rows.
        K_dagger: tensor of shape (d, d+1) with pseudoinverse of K, returned for convenience.
    """
    if dim < 1:
        raise ValueError("dimension must be >= 1")
    # Ensure alphas is a tensor of shape (J,)
    dtype = alphas.dtype
    device = alphas.device
    J = alphas.shape[0]
    # 1. Build a basis U for the zero-sum subspace (orthogonal to the all-ones vector)
    U = torch.cat([torch.eye(dim, dtype=dtype, device=device),
                   -torch.ones(1, dim, dtype=dtype, device=device)], dim=0)
    # 2. Build a regular simplex in R^(d+1) and normalize
    Q, _ = torch.linalg.qr(U, mode="reduced")  # Q: (d+1, d)
    V = torch.eye(dim + 1, dtype=dtype, device=device) - (1.0 / (dim + 1))
    K = V @ Q  # (d+1, d)
    K = K / K.norm(dim=1, keepdim=True)  # each row is now unit length
    # Pseudoinverse of K: K^† ∈ R^{d×(d+1)}
    K_dagger = torch.linalg.pinv(K)
    # Base (unscaled) lattice generator: shape (d, d)
    base = K_dagger @ U
    # 3. Scale by 2π / α_j for each module
    scale = (2.0 * math.pi) / alphas.view(J, 1, 1)  # (J, 1, 1)
    B = scale * base[None, ...]  # (J, d, d)
    return B, K, K_dagger

To solve for $\Delta\mathbf{x}$, we will rely on the $\alpha_j$ being sorted from coarsest ($\alpha_j$ small = long wavelength) to finest. We will assume this as a matter of initialization for $\boldsymbol{\alpha}$, which will be fixed and not learned. Let’s look at our initialization of $\boldsymbol{\alpha}$ for the decoding process.
The choice of $\boldsymbol{\alpha}$ will determine the range of $\Delta\mathbf{x}$ that can be recovered. So we will use a scale parameter $s$ ($s > 0$). This parameter is, in effect, the maximum norm of displacement caused by any action. Hence actions determine the scale of the space.
Rather than generating the $\alpha_j$ directly, we will instead focus on the wavelength $\lambda_j = 2\pi / \alpha_j$. Since our $\alpha_j$ will be sorted, we want $\lambda_0 = 2s$. This will allow us to recover actions with norms in the range $[0, s)$. From there, each successive $\lambda_j$ will use a smaller and smaller wavelength in a geometric pattern. The ratio $\rho$ ($\rho > 1$) of reduction is a parameter to our generator, and we will have $\lambda_j = \lambda_0 / \rho^j$, whence $\alpha_j = \frac{\pi}{s}\, \rho^j$.
import math
import torch
def make_alphas(location_dim: int, dim: int = 2, scale: float = 10.0,
                ratio: float = math.sqrt(2.0), dtype=torch.get_default_dtype(),
                device=torch.device("cpu")) -> torch.Tensor:
    """
    Choose module spatial frequencies alpha_j given a location code size and
    a target "safe" displacement scale.

    Args:
        location_dim: int, total code length = 2 * J * (d+1).
        dim: spatial dimension d.
        scale: radius such that for ||Δx|| < scale, the coarsest module
            is unambiguous (λ_0/2 ≈ scale).
        ratio: geometric ratio between successive periods (default √2).

    Returns:
        alphas: (J,) tensor of spatial frequencies α_j, with j=0 coarsest.
    """
    num_dirs = dim + 1
    assert location_dim % (2 * num_dirs) == 0, \
        f"location_dim={location_dim} must be divisible by 2*(dimension+1)={2*num_dirs}"
    J = location_dim // num_dirs // 2  # number of modules
    j_idx = torch.arange(J, dtype=dtype, device=device)
    return (math.pi / scale) * ratio ** j_idx

Finally, to solve for the offsets, we iterate over $j$, choosing $\Delta\mathbf{x}_0$ as our initial estimate on the assumption that $\|\Delta\mathbf{x}\| < s$ for our scale parameter $s$. That is, $\mathbf{m}_0 = \mathbf{0}$. Then, we choose the value for $\mathbf{m}_j$ that minimizes the error by rounding. That is, from $\Delta\mathbf{x}_0 \approx \Delta\mathbf{x}_j^{\text{base}} + B_j \mathbf{m}_j$ we equate the two estimates, which leads to

$$\mathbf{m}_j = \operatorname{round}\!\left(B_j^{-1}\left(\Delta\mathbf{x}_0 - \Delta\mathbf{x}_j^{\text{base}}\right)\right)$$
Now the matrix formula can be solved for $\mathbf{m}_j$, and we can estimate $\Delta\mathbf{x}_j = \Delta\mathbf{x}_j^{\text{base}} + B_j \mathbf{m}_j$, allowing us to compute the per-module displacements.
def solve_for_deltas(delta_thetas: torch.Tensor, K_dagger: torch.Tensor,
                     lattice_basis: torch.Tensor, alphas: torch.Tensor):
    d, dplus = K_dagger.shape
    J, _, _ = lattice_basis.shape
    shape = delta_thetas.shape[:-1] + (J, dplus)
    delta_thetas = delta_thetas.view(shape)[..., None]
    while K_dagger.ndim < delta_thetas.ndim:
        K_dagger = K_dagger[None, ...]
    # Per-module base estimate: Δx_j = K† Δθ_j / α_j
    displacement_base = (K_dagger @ delta_thetas).view(shape[:-1] + (d,)) / alphas[None, :, None]
    # Use the coarsest module (j = 0) as the reference estimate
    reference_displacement = displacement_base[..., 0, :]
    errors = reference_displacement[..., None, :] - displacement_base
    lattice_basis = lattice_basis.view(J, d, d)
    while lattice_basis.ndim < errors.ndim + 1:
        lattice_basis = lattice_basis[None, ...]
    # Round to the nearest lattice offset m_j for each module
    offsets = torch.linalg.solve(lattice_basis.float(), errors[..., None].float()).round().to(delta_thetas.dtype)
    deltas = displacement_base + (lattice_basis @ offsets).squeeze(-1)
    return deltas.view(shape[:-1] + (d,))
Now we can check these functions for accuracy:
J = 20
d = 2
location_dim = J * (d+1) * 2
alphas = make_alphas(location_dim, d)
print(f"alphas: {alphas.detach().cpu().numpy().tolist()}")
lattice_basis, K, K_dagger = make_lattice_basis(alphas, d)
print(f"K: {K.detach().cpu().numpy().tolist()}")
print(f"K_dagger: {K_dagger.detach().cpu().numpy().tolist()}")
assert torch.allclose(K.sum(dim=0), torch.zeros(d))
phis = torch.linspace(0, 2*math.pi, J)
x0 = torch.randn(2, 5, d)
xf = torch.randn(2, 5, d)
deltas = xf - x0
theta0 = alphas[None, None, :, None, None] * (K[None, None, ...] @ x0.view(2, 5, d, 1))[:, :, None, :, :] + phis[None, None, :, None, None]
thetaf = alphas[None, None, :, None, None] * (K[None, None, ...] @ xf.view(2, 5, d, 1))[:, :, None, :, :] + phis[None, None, :, None, None]
print(theta0.shape)
delta_thetas = (thetaf - theta0).view(2, 5, -1)
deltas_estimate = solve_for_deltas(delta_thetas, K_dagger, lattice_basis, alphas)
print("AVERAGE ERROR OF ESTIMATED DELTA FROM TRUE DELTA: ", torch.norm(deltas[:, :, None, :] - deltas_estimate, dim=-1).mean().item())
alphas: [0.3141592741012573, 0.44428831338882446, 0.6283184885978699, 0.8885766267776489, 1.2566369771957397, 1.7771530151367188, 2.5132739543914795, 3.5543060302734375, 5.026547908782959, 7.108611583709717, 10.053094863891602, 14.217223167419434, 20.106189727783203, 28.434444427490234, 40.21237564086914, 56.86888885498047, 80.42475128173828, 113.73777770996094, 160.84950256347656, 227.47552490234375]
K: [[-0.866025447845459, 0.4999999701976776], [1.5730305946703993e-08, -1.0], [0.866025447845459, 0.5]]
K_dagger: [[-0.5773501992225647, 1.0486870039017049e-08, 0.5773502588272095], [0.3333333432674408, -0.666666567325592, 0.3333333134651184]]
torch.Size([2, 5, 20, 3, 1])
AVERAGE ERROR OF ESTIMATED DELTA FROM TRUE DELTA: 1.7676826757906383e-07
Finally, we have a loss function with which to regularize $\Delta\boldsymbol{\theta}$!
def loss_for_deltas(delta_thetas: torch.Tensor, K_dagger: torch.Tensor,
                    lattice_basis: torch.Tensor, alphas: torch.Tensor):
    deltas = solve_for_deltas(delta_thetas, K_dagger, lattice_basis, alphas)
    # deltas has shape (batch_size, time_steps, J, d)
    return deltas.var(dim=-2).mean()
delta_thetas_noise = delta_thetas + torch.randn_like(delta_thetas)
print("LOSS OF NOISY DELTA_THETAS: ", loss_for_deltas(delta_thetas_noise, K_dagger, lattice_basis, alphas).item())
LOSS OF NOISY DELTA_THETAS: 1.8048560619354248
Creating and Verifying Location Codes¶
To work with location codes on a physical space of dimension $d$, we need to be able to create valid locations. In general, our location code is a tensor of shape $(J, d+1, 2)$ where the first dimension is the number of different modules, the second indexes over the vertices of the simplex, and the third contains the $\cos\theta$ and $\sin\theta$ values. To be valid, these last two must satisfy $\cos^2\theta + \sin^2\theta = 1$, that is, the last dimension must lie on the unit circle. But there is a deeper notion of validity as well: the code must have a unique underlying physical interpretation.
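The first, shallow condition is easy to check directly. Here is a minimal sketch (with a hypothetical helper name) that verifies only the unit-circle constraint on each $(\cos, \sin)$ pair, not the deeper physical-consistency condition:

```python
import torch

def is_on_unit_circles(location: torch.Tensor, atol: float = 1e-5) -> bool:
    # Interpret the last dimension as interleaved (cos, sin) pairs and
    # check that every pair has unit norm: cos^2 + sin^2 == 1.
    pairs = location.view(*location.shape[:-1], -1, 2)
    norms = pairs.square().sum(dim=-1)
    return bool(torch.allclose(norms, torch.ones_like(norms), atol=atol))

thetas = torch.rand(12) * 2 * torch.pi
valid = torch.stack([thetas.cos(), thetas.sin()], dim=-1).reshape(-1)
print(is_on_unit_circles(valid))            # True
print(is_on_unit_circles(torch.zeros(24)))  # False: pairs have norm 0
```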
To begin, we will need to sample a random location code. One way to do this is to sample a set of phase angles and generate the code from there. Most of the ingredients for this transformation are given above; we’ve computed the simplex vertices $K$ and the parameters $\alpha_j$. The final piece is the phase $\phi_j$, which we will define as a combination of two factors:

$$\phi_j = \psi_j + \xi_j$$

where $\psi_j$ is uniform on $[-\pi, \pi)$ and $\xi_j$ is a fixed per-module offset. This makes the phases different among all components of the code. In particular, it spaces out the wavevectors evenly over the circle. To say that these are the phases assumes the sample represents the encoding of the origin $\mathbf{x} = \mathbf{0}$, but this is acceptable because the code for a general $\mathbf{x}$ would merely represent a simultaneous shift across all components, which can be absorbed into the uniform sample $\psi_j$. So for any set of samples taken by this method, we can pick one to be the origin and use it to extract a consistent set of phases and consistent physical locations for the other points.
from typing import Tuple
phis = torch.empty((location_dim // 2 // (2 + 1),)).uniform_(-math.pi, math.pi)
def sample(location_dim: int, K: torch.Tensor, alphas: torch.Tensor, phis: torch.Tensor,
           shape: Tuple[int, ...] = torch.Size()):
    physical_dim = K.shape[1]
    x = torch.randn(shape + (physical_dim, 1))
    print(f"sample picked x={x.detach().cpu().numpy().tolist()}")
    while K.ndim < x.ndim:
        K = K[None, ...]
    Kx = (K @ x).squeeze(-1)  # (..., d+1)
    alphas = alphas[..., None]  # (J, 1)
    while alphas.ndim < Kx.ndim:
        alphas = alphas[None, ...]
    aKx = alphas * Kx[..., None, :]  # (..., J, d+1)
    phis = phis[..., None]
    while phis.ndim < aKx.ndim:
        phis = phis[None, ...]
    thetas = aKx + phis  # + xis
    return torch.stack([torch.cos(thetas), torch.sin(thetas)], dim=-1).view(shape + (location_dim,))
l1 = sample(120, K, alphas, phis)
l2 = sample(120, K, alphas, phis)
print(f"l1.shape: {l1.shape} code: {l1.detach().cpu().numpy().tolist()}")
print(f"l2.shape: {l2.shape} code: {l2.detach().cpu().numpy().tolist()}")
sample picked x=[[2.1085588932037354], [-1.3823367357254028]]
sample picked x=[[0.08212031424045563], [1.2365736961364746]]
l1.shape: torch.Size([120]) code: [-0.7163147330284119, -0.6977773308753967, 0.4137597382068634, -0.9103861451148987, 0.34181222319602966, -0.9397682547569275, 0.6467894911766052, -0.7626685500144958, 0.6485604643821716, 0.7611631155014038, 0.7281549572944641, 0.6854125261306763, 0.8047559261322021, 0.5936058163642883, -0.9984386563301086, 0.05585923790931702, -0.9777466654777527, 0.20978902280330658, -0.9895644783973694, 0.14409072697162628, 0.984043538570404, 0.17792800068855286, 0.99916011095047, -0.04097708687186241, -0.16963337361812592, -0.9855072498321533, -0.9998469352722168, -0.017496973276138306, -0.9572534561157227, 0.28925055265426636, 0.017376990988850594, 0.9998490214347839, -0.588797390460968, 0.8082806468009949, -0.18869440257549286, 0.9820358753204346, 0.8791793584823608, 0.47649094462394714, -0.6428601741790771, -0.765983521938324, -0.9687423706054688, -0.24806904792785645, -0.9874424338340759, 0.15797913074493408, -0.4219028055667877, -0.9066410660743713, -0.9674668312072754, -0.25299787521362305, 0.9121777415275574, -0.4097948670387268, 0.9461743235588074, 0.32365739345550537, 0.6104447245597839, -0.7920588850975037, 0.9812467694282532, -0.19275572896003723, -0.7332056164741516, 0.6800069808959961, 0.8051489591598511, 0.5930726528167725, -0.87924724817276, 0.4763656556606293, -0.5342934727668762, -0.8452990651130676, -0.09026530385017395, 0.9959177374839783, 0.859729528427124, 0.5107495784759521, 0.8410654067993164, -0.5409334301948547, -0.5834247469902039, 0.8121671676635742, -0.9495474100112915, 0.3136235773563385, 0.8989977240562439, -0.43795329332351685, 0.6563684940338135, 0.7544404864311218, -0.306147962808609, -0.9519839286804199, -0.5771567821502686, 0.8166333436965942, 0.13694500923156738, 0.9905786514282227, 0.4679951071739197, 0.8837310671806335, 0.6857665777206421, 0.7278215289115906, -0.9583046436309814, -0.28574851155281067, -0.7178304195404053, 0.696217954158783, -0.46898889541625977, -0.8832040429115295, -0.9120305180549622, 
0.4101223349571228, 0.7137883305549622, 0.7003614902496338, 0.9716621041297913, 0.23637409508228302, 0.6879958510398865, -0.7257146239280701, -0.38272079825401306, 0.9238640666007996, 0.8172910809516907, -0.5762250423431396, -0.8856114745140076, 0.4644269049167633, 0.18837592005729675, 0.9820970296859741, 0.9533213376998901, 0.3019576072692871, -0.22200298309326172, -0.9750459790229797, -0.42945265769958496, -0.9030894041061401, 0.6317500472068787, -0.7751721143722534, 0.8111048340797424, -0.5849007964134216]
l2.shape: torch.Size([120]) code: [0.1634671688079834, -0.9865487813949585, -0.38589251041412354, -0.9225437641143799, 0.2073732614517212, -0.9782618880271912, 0.8804200291633606, 0.4741947054862976, 0.9557943940162659, -0.294035941362381, 0.8487162590026855, 0.5288484692573547, -0.8361117243766785, 0.5485591292381287, 0.13023075461387634, 0.9914836883544922, -0.8817344903945923, 0.4717460572719574, 0.8455346822738647, -0.5339204668998718, -0.545868992805481, -0.8378705382347107, 0.90609210729599, -0.4230804741382599, -0.5130985379219055, 0.858329713344574, 0.9913099408149719, -0.1315467804670334, -0.6575261950492859, 0.7534316778182983, 0.7543420791625977, 0.6564815640449524, -0.7726732492446899, -0.6348039507865906, 0.5661889910697937, 0.8242754340171814, -0.33759403228759766, 0.9412918090820312, -0.8398900032043457, -0.5427566766738892, -0.6456234455108643, 0.763655960559845, 0.2593931555747986, 0.9657717943191528, 0.3137822151184082, 0.9494949579238892, -0.2407698780298233, 0.9705822467803955, -0.7472859621047974, 0.6645026206970215, 0.9642717838287354, -0.26491492986679077, -0.9999295473098755, 0.011868349276483059, -0.920486330986023, 0.39077481627464294, -0.8702760934829712, 0.49256423115730286, -0.8198588490486145, -0.5725656747817993, -0.4484948515892029, 0.8937854170799255, -0.9823259711265564, 0.1871781349182129, -0.9479047656059265, -0.3185538947582245, 0.9926824569702148, 0.120754174888134, 0.9943798780441284, -0.10587082803249359, -0.5417043566703796, 0.8405690789222717, -0.03338322415947914, 0.9994426369667053, -0.9560878276824951, -0.2930803894996643, -0.24582013487815857, -0.9693154692649841, -0.9091147780418396, -0.41654571890830994, -0.9996976256370544, 0.024589691311120987, 0.23607660830020905, 0.9717344641685486, 0.21604879200458527, -0.9763825535774231, -0.6786207556724548, 0.734488844871521, -0.3388988971710205, -0.9408227801322937, 0.757413387298584, 0.6529356241226196, 0.9806544184684753, -0.1957469880580902, -0.8112629055976868, 
0.5846815705299377, -0.5766915082931519, 0.8169620037078857, -0.9949922561645508, -0.0999518483877182, 0.49022790789604187, 0.8715943098068237, 0.21265260875225067, -0.9771278500556946, -0.9994928240776062, 0.0318438820540905, -0.631900429725647, 0.7750495672225952, -0.4898953139781952, -0.8717812895774841, 0.9995584487915039, 0.029714152216911316, -0.3668069541454315, 0.9302970767021179, -0.7180269956588745, -0.6960152983665466, 0.9617215394973755, 0.2740286588668823, 0.13802990317344666, -0.9904280304908752]
Now, if we identify $\ell_1$ with the origin $\mathbf{x} = \mathbf{0}$, then we can describe $\Delta\boldsymbol{\theta}$ as the block rotation required to get from $\ell_1$ to $\ell_2$, and from there we can get to the displacement $\Delta\mathbf{x}$ such that $\ell_2$ represents $\mathbf{x} = \Delta\mathbf{x}$.
To do this, we use atan2 and the formulas

$$\sin(\theta_2 - \theta_1) = \sin\theta_2 \cos\theta_1 - \cos\theta_2 \sin\theta_1, \qquad \cos(\theta_2 - \theta_1) = \cos\theta_2 \cos\theta_1 + \sin\theta_2 \sin\theta_1$$

which resolve to

$$\Delta\theta = \operatorname{atan2}\big(\sin\theta_2 \cos\theta_1 - \cos\theta_2 \sin\theta_1,\ \cos\theta_2 \cos\theta_1 + \sin\theta_2 \sin\theta_1\big)$$

after which we can call solve_for_deltas above.
Note that we will allow for the possibility of codes that have no single physical interpretation, which will mean that our results will be of shape $(\ldots, J, d)$, with $J$ different physical interpretations, one per module. We’ll also return the mean of these interpretations. For real physical locations, the variance should be small.
def reshape_to_components(location: torch.Tensor, physical_dim: int):
    shape = tuple(list(location.shape[:-1]) + [-1, physical_dim + 1, 2])
    return location.view(*shape)

def compute_angles(
    location1: torch.Tensor,
    location2: torch.Tensor,
    physical_dim: int = 2
):
    location1 = reshape_to_components(location1, physical_dim)  # Now has shape (..., J, d+1, 2)
    location2 = reshape_to_components(location2, physical_dim)  # Now has shape (..., J, d+1, 2)
    s2c1 = location2[..., 1] * location1[..., 0].float()
    c2s1 = location2[..., 0] * location1[..., 1].float()
    c2c1 = location2[..., 0] * location1[..., 0].float()
    s2s1 = location2[..., 1] * location1[..., 1].float()
    # The epsilon guards atan2(0, 0), at the cost of a slight bias in the angle
    delta_thetas = torch.atan2(s2c1 - c2s1, c2c1 + s2s1 + 1e-6)
    delta_thetas = delta_thetas.view(delta_thetas.shape[:-2] + (-1,))
    return delta_thetas
def compute_displacements(
    location1: torch.Tensor,
    location2: torch.Tensor,
    K_dagger: torch.Tensor,
    lattice_basis: torch.Tensor,
    alphas: torch.Tensor,
):
    physical_dim = K_dagger.shape[0]
    delta_thetas = compute_angles(location1, location2, physical_dim)
    # estimated displacements from thetas, made as small as possible solving across alphas -- but may not agree!
    # deltas has shape (..., J, d)
    deltas = solve_for_deltas(
        delta_thetas,
        K_dagger,
        lattice_basis,
        alphas
    )
    mean_deltas = deltas.mean(dim=-2)
    return deltas.to(location1.dtype), mean_deltas.to(location1.dtype)
delta_thetas = compute_angles(l1, l2, 2)
print(f"delta_thetas.shape: {delta_thetas.shape}")
print(f"delta_thetas: {delta_thetas.detach().cpu().numpy().tolist()}")
deltas, mean_deltas = compute_displacements(l1, l2, K_dagger, lattice_basis, alphas)
print(f"deltas: {deltas.detach().cpu().numpy().tolist()}")
print(f"mean_deltas: {mean_deltas.detach().cpu().numpy().tolist()}")
delta_thetas.shape: torch.Size([60])
delta_thetas: [0.9627096652984619, -0.8227543830871582, -0.13995537161827087, 1.3614771366119385, -1.1635503768920898, -0.19792671501636505, 1.925419569015503, -1.6455087661743164, -0.27991050481796265, 2.7229557037353516, -2.327101945877075, -0.3958534300327301, -2.4323434829711914, 2.9921655654907227, -0.559821367263794, -0.8372727036476135, 1.6289799213409424, -0.7917068004608154, 1.4184958934783936, -0.2988537549972534, -1.1196424961090088, -1.674545407295227, -3.0252232551574707, -1.583414077758789, 2.836993455886841, -0.5977075695991516, -2.239285945892334, 2.934088945388794, 0.232738196849823, 3.1163547039031982, -0.6091985702514648, -1.1954123973846436, 1.8046104907989502, -0.41500645875930786, 0.46547645330429077, -0.050475697964429855, -1.218399167060852, -2.39082407951355, -2.6739628314971924, -0.8300091624259949, 0.9309605956077576, -0.10094953328371048, -2.436805486679077, 1.501538634300232, 0.9352618455886841, -1.6600226163864136, 1.8619219064712524, -0.20189906656742096, 1.4095646142959595, 3.0030789375305176, 1.8705166578292847, 2.963141918182373, -2.5593390464782715, -0.4038057327270508, 2.8191308975219727, -0.2770266532897949, -2.542149543762207, -0.3570764362812042, 1.164566159248352, -0.8075658679008484]
deltas: [[-2.0264368057250977, 2.618908166885376], [-2.0264365673065186, 2.618908166885376], [-2.0264370441436768, 2.618908643722534], [-2.026437520980835, 2.6189095973968506], [-2.0264387130737305, 2.618910551071167], [-2.0264384746551514, 2.618910789489746], [-2.026437759399414, 2.618910074234009], [-3.0470595359802246, 2.0296547412872314], [-2.0264382362365723, 2.618910074234009], [-1.5161278247833252, 2.9135377407073975], [-2.0264384746551514, 2.618910074234009], [-2.0264382362365723, 2.618910074234009], [-2.2068605422973633, 2.5147433280944824], [-2.0264387130737305, 2.618910312652588], [-2.0264387130737305, 2.618910551071167], [-2.0264384746551514, 2.618910312652588], [-1.9813331365585327, 2.6449520587921143], [-2.0264384746551514, 2.618910312652588], [-2.0264387130737305, 2.618910074234009], [-2.0264384746551514, 2.618910074234009]]
mean_deltas: [-2.0587196350097656, 2.6002724170684814]
A Metric for Fourier Codes¶
What we are really interested in is a way to compare location codes in a way that respects the structure of the code. We could have used the Euclidean metric, which is defined on $\mathbb{R}^{2J(d+1)}$, but the structure of a Fourier location code means that the different modules have different wavelengths $\lambda_j$, which means that different components should have different weights. Since we’ve sorted our wavelengths, we end up with the early dimensions mattering much more than the later dimensions. We can leverage the foregoing formulae to compose a pseudometric over location codes: extract the block rotations $\Delta\boldsymbol{\theta}$ that make two location codes align, extract $\Delta\mathbf{x}_j$, compute the mean $\overline{\Delta\mathbf{x}}$, and finally return $\|\overline{\Delta\mathbf{x}}\|$ as our norm. In other words, we return the physical displacement underlying the two codes. In terms of codes, I believe it is a full metric; but in terms of the underlying physical space it is a pseudometric, because finite-dimensional location codes cannot make distinctions that are too large or too small.

In addition to the distance, we can consider the case where a “location code” has varying physical interpretations across modules $j$. In this case, we can compute a mean physical interpretation and add the variance to our distance as well.
def pseudo_distance(
    location1: torch.Tensor,
    location2: torch.Tensor,
    K_dagger: torch.Tensor,
    lattice_basis: torch.Tensor,
    alphas: torch.Tensor,
    squared: bool = False,
    use_variance: bool = True
):
    deltas, mean_deltas = compute_displacements(location1, location2, K_dagger, lattice_basis, alphas)
    J = deltas.shape[-2]
    assert J > 1, "J must be greater than 1"
    squared_distances = mean_deltas.square().sum(dim=-1)
    if use_variance:
        dev_deltas = deltas - mean_deltas[..., None, :]
        variances = dev_deltas.square().sum(dim=-1).mean(dim=-1) * (J / (J - 1))  # (...)
        final_squared_distances = squared_distances + variances
    else:
        final_squared_distances = squared_distances
    if squared:
        return final_squared_distances
    else:
        return (final_squared_distances + 1e-12).sqrt()
dist = pseudo_distance(l1, l2, K_dagger, lattice_basis, alphas)
print(f"dist: {dist.detach().cpu().numpy().tolist()}")
dist: 3.330477714538574
Now, when implementing TEM variants, what we are interested in is not just the distance but the cross-distance among many items. That is, given $\{\ell_i\}_{i=1}^T$ and $\{\ell_j\}_{j=1}^S$, we want to compute $d(\ell_i, \ell_j)$ for all pairs $(i, j)$. We are interested in the case where the first argument has shape $(B, T, \text{location\_dim})$ and the second has shape $(B, S, \text{location\_dim})$, and we will limit our flexibility to this case. This is easily done with shape tricks:
def cross_distance(location1: torch.Tensor, location2: torch.Tensor, K_dagger: torch.Tensor,
                   lattice_basis: torch.Tensor, alphas: torch.Tensor, squared: bool = False):
    return pseudo_distance(location1[..., :, None, :], location2[..., None, :, :],
                           K_dagger, lattice_basis, alphas, squared=squared)
B = 2
T = 10
S = 8
li = sample(120, K, alphas, phis, (B, T))
lj = sample(120, K, alphas, phis, (B, S))
cross_dist = cross_distance(li, lj, K_dagger, lattice_basis, alphas)
print(f"cross_dist: {cross_dist}")
print(f"Min dist = {cross_dist.min().item()} Mean dist = {cross_dist.mean().item()} Max dist = {cross_dist.max().item()}")
sample picked x=[[[[-0.8145899772644043], [-0.569575309753418]], [[0.32764700055122375], [-0.9747219085693359]], [[2.815774917602539], [-2.0576868057250977]], [[0.35869482159614563], [0.6466221213340759]], [[-0.3521421551704407], [-0.3835611641407013]], [[-0.8626407384872437], [-1.2137911319732666]], [[-0.7666817307472229], [-0.2413574904203415]], [[1.1334725618362427], [2.322641372680664]], [[0.5391432642936707], [0.39215847849845886]], [[-0.7231886386871338], [-0.7787725329399109]]], [[[1.922032356262207], [0.7315117716789246]], [[1.4502531290054321], [0.6809671521186829]], [[-0.43476399779319763], [-2.8658132553100586]], [[-0.7493019700050354], [0.9299961924552917]], [[-0.2567051351070404], [0.39329344034194946]], [[0.5350938439369202], [-0.032766442745923996]], [[0.04726826772093773], [-1.804565668106079]], [[0.8472246527671814], [0.23619519174098969]], [[-0.1825113743543625], [-0.9091907739639282]], [[-0.4195076525211334], [-1.171326994895935]]]]
sample picked x=[[[[2.3358025550842285], [-1.3003051280975342]], [[-1.249247431755066], [0.008154939860105515]], [[-0.17388926446437836], [0.19512693583965302]], [[-2.7206344604492188], [-0.8759704828262329]], [[-1.68034827709198], [0.7819029688835144]], [[-0.7264818549156189], [0.43826475739479065]], [[-0.6606448292732239], [-0.14920498430728912]], [[-0.44584688544273376], [2.2403321266174316]]], [[[-0.6501056551933289], [-0.5011352896690369]], [[0.7847572565078735], [-0.533992350101471]], [[-0.3528536260128021], [-1.6510188579559326]], [[-0.6515844464302063], [0.34739628434181213]], [[0.25842607021331787], [-1.6758942604064941]], [[-0.08893660455942154], [-0.366155207157135]], [[0.049642354249954224], [2.1746392250061035]], [[1.3555898666381836], [-0.771336019039154]]]]
cross_dist: tensor([[[3.4689, 0.7575, 0.9823, 1.8817, 1.6062, 1.1158, 0.4538, 3.1328],
[2.2122, 2.0509, 1.2782, 3.1153, 2.9097, 1.7702, 1.4117, 3.4374],
[0.8993, 4.9564, 4.0990, 6.1888, 5.8806, 4.6774, 4.3384, 5.4092],
[2.9460, 1.7308, 0.6888, 3.4417, 2.2114, 1.0775, 1.2804, 1.7908],
[3.0486, 1.0823, 0.6142, 2.3939, 1.8334, 0.9066, 0.4071, 2.8971],
[3.2221, 1.2837, 1.6074, 2.0457, 2.2520, 1.7347, 1.1488, 3.6484],
[3.4882, 0.5877, 0.7055, 2.0300, 1.4394, 0.7544, 0.1382, 2.6188],
[4.0144, 3.2426, 2.4797, 4.8874, 3.1290, 2.6444, 2.9847, 1.5713],
[2.5590, 1.8164, 0.7261, 3.4920, 2.4291, 1.2892, 1.2877, 2.0946],
[3.3575, 0.9466, 1.1197, 1.9633, 1.8340, 1.3099, 0.7096, 3.3055]],
[[2.7375, 1.6826, 3.1825, 2.5894, 2.8849, 2.2471, 2.4678, 1.7314],
[2.3656, 1.3876, 2.8719, 2.0924, 2.6491, 1.8310, 2.1588, 1.5934],
[2.4782, 2.6401, 1.2733, 3.3916, 1.3760, 2.6298, 5.6739, 2.7023],
[1.5595, 2.2391, 2.8605, 0.6174, 2.9280, 1.5082, 1.4540, 2.9558],
[1.0579, 1.4847, 2.2328, 0.3931, 2.2495, 0.8114, 2.0267, 2.1672],
[1.2538, 0.5631, 1.8483, 1.3167, 1.7665, 0.7070, 2.3612, 1.1531],
[1.4785, 1.4660, 0.4537, 2.3618, 0.2485, 1.5734, 4.3647, 1.6213],
[1.6580, 0.8404, 2.2345, 1.6403, 2.1732, 1.0892, 2.1921, 1.1294],
[0.6448, 1.0114, 0.8087, 1.4078, 0.8845, 0.5831, 3.3470, 1.5272],
[0.7474, 1.3595, 0.5283, 1.6079, 0.8803, 0.8814, 3.5551, 1.8145]]])
Min dist = 0.13820236921310425 Mean dist = 2.040937900543213 Max dist = 6.188777446746826
Class Implementation and Relative Sampling¶
This code is implemented in tree_world.models.fourier_metric.FourierMetric with $K$, $K^\dagger$, $B$, $\boldsymbol{\alpha}$, and $\boldsymbol{\phi}$ as buffer variables. It is implemented as a torch.nn.Module but does not generally have learnable parameters. This class can be used to sample and compare location codes.
from tree_world.models.fourier_metric import FourierMetric
metric = FourierMetric(location_dim=120, dim=2, scale=10.0, ratio=math.sqrt(2.0))
li = metric.sample((B, T))
lj = metric.sample((B, S))
dist = metric.cross_distance(li, lj)
print(f"dist: {dist}")
dist: tensor([[[0.9421, 1.5962, 2.3569, 2.3076, 1.9625, 1.9333, 1.5775, 1.8270],
[1.5659, 1.7411, 1.1836, 2.5785, 0.4220, 1.1630, 2.9069, 0.4242],
[2.2452, 2.4857, 1.7471, 3.4374, 0.5024, 1.8198, 3.4858, 1.0958],
[1.2867, 1.6016, 1.3601, 2.5070, 0.6681, 1.4069, 2.5962, 0.7059],
[3.0952, 2.6262, 1.7894, 2.9229, 3.1209, 1.9545, 4.0000, 2.4535],
[1.5533, 1.5758, 0.0778, 2.3225, 1.2331, 0.4601, 2.7617, 0.6553],
[0.3575, 0.2500, 1.4648, 1.1098, 2.0950, 0.9737, 1.3357, 1.5465],
[1.5390, 1.2148, 0.7372, 1.7099, 1.9579, 0.5735, 2.4550, 1.3080],
[1.5604, 1.7113, 0.8066, 2.5536, 0.5404, 0.8445, 2.9474, 0.0792],
[1.2319, 0.6270, 1.8447, 0.5421, 2.5495, 1.4288, 1.4021, 2.0933]],
[[2.3313, 1.7126, 0.9637, 1.1956, 1.2823, 1.0090, 1.0556, 1.2244],
[3.3409, 2.9554, 1.1170, 0.9126, 2.2302, 2.1943, 2.3602, 0.4456],
[1.9675, 1.8695, 0.5109, 0.5249, 0.7804, 1.2173, 2.4413, 1.1589],
[3.3661, 2.8070, 1.0090, 0.8525, 2.0592, 1.9736, 2.0501, 0.2020],
[4.1052, 3.4114, 2.2412, 2.3425, 2.9672, 2.6285, 1.2778, 1.8429],
[2.4114, 1.7896, 1.2168, 1.5364, 1.5727, 1.1399, 0.7792, 1.3450],
[3.1906, 3.2185, 1.4153, 1.1378, 2.0199, 2.5133, 3.2347, 1.4273],
[1.9130, 2.1871, 1.3014, 1.2335, 1.1101, 1.7137, 3.3191, 1.8391],
[3.3696, 2.8035, 0.9442, 0.7955, 2.0526, 1.9916, 2.1881, 0.2866],
[3.3669, 2.7935, 2.3590, 2.4586, 2.6393, 2.1803, 0.6271, 2.2216]]])
The module tree_world.models.fourier_metric also defines a relative sampler in support of our TEM VAE. To understand this problem, suppose you have a reference location $l_0$, and you want to sample a neighboring location $l$ that preferentially lies in some neighborhood of scale $\epsilon$. Again, we can revert to physical space for our operations. We can sample $u \sim \mathcal{N}(0, I_d)$ and then compute the displacement $\Delta x = \epsilon u$ and the block rotation $R(\Delta\theta)$ with $\Delta\theta = K \Delta x$. Applying this rotation to $l_0$ yields a location that is $\epsilon$-close to $l_0$ in terms of the pseudometric defined above. Here is the code:
import math
from typing import Optional, Tuple

import torch
import torch.distributions as D

class FourierCodeDistribution(D.Distribution):
    support = D.constraints.real
    has_rsample = True

    def __init__(self, metric: FourierMetric, reference_location: torch.Tensor, scale: torch.Tensor,
                 batch_lengths: Optional[torch.Tensor]=None, idx: Optional[torch.Tensor]=None,
                 validate_args=None, tau: float=1e-6):
        self.metric = metric
        self.reference_location = reference_location
        self.batch_lengths = batch_lengths
        self.idx = idx
        self.dtype = reference_location.dtype
        self.device = reference_location.device
        self.scale = scale
        self.tau = tau
        batch_shape = reference_location.shape[:-1]
        assert scale.shape == () or scale.shape == batch_shape
        if scale.shape == ():
            self.scale = self.scale.expand(batch_shape)
        super().__init__(batch_shape=batch_shape, event_shape=(self.metric.location_dim,),
                         validate_args=validate_args)

    def sample(self, sample_shape: Tuple[int, ...] = torch.Size()):
        # sample Delta x ~ N(0, I_d), then scale to N(0, epsilon^2 I_d)
        u = torch.randn(sample_shape + self.batch_shape + (self.metric.dim,),
                        device=self.device, dtype=self.dtype)
        scale = self.scale[..., None]
        while scale.ndim < u.ndim:
            scale = scale[None, ...]
        deltas = scale * u
        locations = self.metric.apply_displacement(deltas, self.reference_location)
        return locations, deltas

    rsample = sample

    def log_det_jacobian(self, displacements: torch.Tensor):
        # note: this won't work for sample shapes other than ()!
        output_shape = displacements.shape[:-1]
        delta_thetas = self.metric.compute_angles_from_displacements(displacements)
        dplus = self.metric.dim + 1
        reference_location = self.reference_location.view(-1, self.metric.J, dplus, 2)
        delta_thetas = delta_thetas.view(-1, self.metric.J, dplus)
        xi = torch.atan2(reference_location[..., 1], reference_location[..., 0])
        delta_xi = xi[..., None] - xi[..., None, :]
        delta_delta_theta = delta_thetas[..., None] - delta_thetas[..., None, :]
        # shape (flat_batch, J, dplus, dplus) (vs. K = (dplus, d))
        W = (torch.cos(delta_xi - delta_delta_theta) * self.metric.alphas.square()[None, :, None, None]).sum(dim=1)
        K = self.metric.K[None, None, ...]
        KWK = ((K.transpose(-2, -1) @ W) @ K).float()
        logdet = 0.5 * torch.logdet(KWK + self.tau * torch.eye(self.metric.dim, device=KWK.device, dtype=torch.float32))
        return logdet.view(output_shape).to(self.dtype)

    def log_prob(self, locations: torch.Tensor, displacements: Optional[torch.Tensor]=None):
        if displacements is None:
            _, displacements = self.metric.compute_displacements(self.reference_location, locations)
        scale = self.scale[..., None]
        while scale.ndim < displacements.ndim:
            scale = scale[None, ...]
        displacements = displacements / scale
        gaussian_log_prob = (
            - 0.5 * displacements.square().sum(dim=-1)
            - self.metric.dim * (0.5 * math.log(2.0 * math.pi) + torch.log(self.scale))
        )
        return gaussian_log_prob - self.log_det_jacobian(displacements)

And, a demonstration:
l0 = metric.sample((B, T))
epsilon = torch.empty(B, T).uniform_(0.01, 0.1)
sampler = FourierCodeDistribution(metric, l0, epsilon)
l, deltas = sampler.sample()
print(f"l: {l.shape}")
dist = metric.pseudo_distance(l, l0)
print(f"Dist: {dist}")
print(f"Min dist = {dist.min().item()} Mean dist = {dist.mean().item()} Max dist = {dist.max().item()}")
print(f"Epsilon-scaled dist: {dist / epsilon}")
print(f"Epsilon-scaled dist min = {(dist / epsilon).min().item()} mean = {(dist / epsilon).mean().item()} max = {(dist / epsilon).max().item()}")
log_prob = sampler.log_prob(l, deltas)
print(f"log_prob: {log_prob.shape} min = {log_prob.min().item()} mean = {log_prob.mean().item()} +/- {log_prob.std().item()} max = {log_prob.max().item()}")
l: torch.Size([2, 10, 120])
Dist: tensor([[0.0567, 0.0542, 0.0332, 0.1041, 0.0506, 0.0575, 0.0133, 0.0610, 0.0355,
0.0346],
[0.0643, 0.0765, 0.0197, 0.2781, 0.1130, 0.0245, 0.0872, 0.0749, 0.0454,
0.1082]])
Min dist = 0.013298843055963516 Mean dist = 0.0696159079670906 Max dist = 0.2781006693840027
Epsilon-scaled dist: tensor([[0.8619, 1.3081, 1.6165, 2.5906, 1.3057, 0.6437, 1.2321, 1.9036, 0.5559,
0.3758],
[1.0868, 1.2073, 1.1949, 3.0120, 1.1895, 1.0035, 1.1124, 1.0657, 2.9058,
1.4086]])
Epsilon-scaled dist min = 0.37576478719711304 mean = 1.3790267705917358 max = 3.0119752883911133
log_prob: torch.Size([2, 10]) min = -12.616140365600586 mean = -8.568292617797852 +/- 1.6212801933288574 max = -5.113590717315674
As we see, our referential sampler provides samples that are “close” to the reference point, and when we scale by epsilon, the mean epsilon-scaled distance is about 1, as we expect.
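For context on the scaled statistics: since $\Delta x = \epsilon u$ with $u \sim \mathcal{N}(0, I_2)$ and the pseudometric is locally Euclidean, dist$/\epsilon$ should roughly follow a chi distribution with mean $\sqrt{\pi/2} \approx 1.25$, in the vicinity of the observed mean of about 1.38. A quick Monte Carlo check of that expectation:

```python
import math
import torch

torch.manual_seed(0)
# dist / epsilon ~ ||u|| with u ~ N(0, I_2): a chi distribution
# whose mean is sqrt(pi / 2) ~= 1.25.
u = torch.randn(100_000, 2)
mean_norm = u.norm(dim=-1).mean().item()
print(mean_norm, math.sqrt(math.pi / 2.0))
```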
The one bit that hasn’t been explained is the log probability calculation, which is needed for the VAE loss on TEM. We’ll cover that now.
Computing the Log Probability¶
In general, when sampling a variable $y$ through a transform $y = f(x)$ of a base variable $x$, we have the log probability

$$\log p_Y(y) = \log p_X(x) - \frac{1}{2} \log \det\left(J^\top J\right),$$

where $J$ is the Jacobian matrix of $f$ with respect to $x$, assuming $J$ is not square. The derivation for this is given in Building Pseudometrics -- Embedding Case.

In this case, the transform from the displacement $\Delta x$ to the new location code is composed of

$$\Delta\theta_{j,i} = k_i^\top \Delta x, \qquad l_{j,i} = \alpha_j \left(\cos(\xi_{j,i} - \Delta\theta_{j,i}),\ \sin(\xi_{j,i} - \Delta\theta_{j,i})\right),$$

where $\xi_{j,i}$ are the phases of the reference code $l_0$. This yields Jacobians with components

$$\frac{\partial l_{j,i,1}}{\partial \Delta x_m} = \alpha_j \sin(\xi_{j,i} - \Delta\theta_{j,i})\, K_{im}, \qquad \frac{\partial l_{j,i,2}}{\partial \Delta x_m} = -\alpha_j \cos(\xi_{j,i} - \Delta\theta_{j,i})\, K_{im},$$

and so, putting these components together,

$$J_{(j,i,c),m} = \alpha_j\, g_c(\xi_{j,i} - \Delta\theta_{j,i})\, K_{im}, \qquad g_1 = \sin,\ g_2 = -\cos,$$

where the r.h.s. is interpreted as a tensor of shape $(J, d{+}1, 2, d)$ indexed by $(j, i, c, m)$, reshaped to a $2J(d{+}1) \times d$ matrix by flattening the first three dimensions. The only components involving the final dimension $m$ are the wavevectors $k_i$, so pulling $K$ out and conflating the remaining dimensions gives the form

$$J^\top J = K^\top W K,$$

where $W$ represents the interaction terms: $K$ is a $(d{+}1) \times d$ matrix, and $W$ is the “square” $B^\top B$ of the matrix $B$ obtained by pulling $K$ out of the equation above. Squaring the terms,

$$W_{i i'} = \sum_j \alpha_j^2 \left[\cos(\xi_{j,i} - \Delta\theta_{j,i}) \cos(\xi_{j,i'} - \Delta\theta_{j,i'}) + \sin(\xi_{j,i} - \Delta\theta_{j,i}) \sin(\xi_{j,i'} - \Delta\theta_{j,i'})\right].$$

The sum of products allows us to use the angle formula

$$\cos a \cos b + \sin a \sin b = \cos(a - b),$$

yielding

$$W_{i i'} = \sum_j \alpha_j^2 \cos\left((\xi_{j,i} - \xi_{j,i'}) - (\Delta\theta_{j,i} - \Delta\theta_{j,i'})\right).$$

Recall that $l_0$ is a location code, so $l_{0,j,i} = \alpha_j (\cos \xi_{j,i}, \sin \xi_{j,i})$ for some phases $\xi_{j,i}$, which are recovered as $\xi_{j,i} = \operatorname{atan2}(l_{0,j,i,2},\, l_{0,j,i,1})$, leaving

$$\log \det\left(J^\top J\right) = \log \det\left(K^\top W K\right).$$

In order to compute this log-determinant, we must compute $\Delta\theta$ from $\Delta x$ and then $W$, which is dependent on the input $\Delta x$.
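Because the Jacobian here is tall rather than square, the correction term is $\tfrac{1}{2}\log\det(J^\top J)$ rather than $\log|\det J|$. As a quick sanity check on a toy tall Jacobian (a stand-in, not the metric's actual Jacobian), this quantity equals the sum of the log singular values:

```python
import torch

torch.manual_seed(0)
# A toy tall Jacobian, standing in for the map from R^d into the code space.
Jac = torch.randn(6, 2)

# For non-square J, 1/2 log det(J^T J) equals the sum of log singular values.
half_logdet = 0.5 * torch.logdet(Jac.T @ Jac)
log_sv_sum = torch.linalg.svdvals(Jac).log().sum()
print(half_logdet.item(), log_sv_sum.item())
```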
def calculate_jacobian_interaction(
    reference_location: torch.Tensor, alphas: torch.Tensor, delta_thetas: torch.Tensor,
    K: torch.Tensor, physical_dim: int, J: int, tau: float=1e-6
):
    dplus = physical_dim + 1
    reference_location = reference_location.view(-1, J, dplus, 2)
    delta_thetas = delta_thetas.view(-1, J, dplus)
    # recover reference phases xi from the (cos, sin) pairs of the code
    xi = torch.atan2(reference_location[..., 1], reference_location[..., 0])
    delta_xi = xi[..., None] - xi[..., None, :]
    delta_delta_theta = delta_thetas[..., None] - delta_thetas[..., None, :]
    # shape (flat_batch, J, dplus, dplus) (vs. K = (dplus, d))
    W = (torch.cos(delta_xi - delta_delta_theta) * alphas.square()[None, :, None, None]).sum(dim=1)
    K = K[None, None, ...]
    KWK = (K.transpose(-2, -1) @ W) @ K
    return 0.5 * torch.logdet(KWK + tau * torch.eye(physical_dim, device=KWK.device, dtype=KWK.dtype))
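As a standalone sanity check of the interaction term (with hypothetical small sizes, $d = 2$ and $J = 3$ modules), we can build $W$ directly from random phases and amplitudes and confirm that $K^\top W K$ yields a finite $2 \times 2$ log-determinant after the $\tau$ regularizer:

```python
import math
import torch

torch.manual_seed(0)
d, J, tau = 2, 3, 1e-6
dplus = d + 1

# Unit wavevectors separated by 120 degrees (2-D simplex directions).
angles = torch.tensor([2.0 * math.pi * i / dplus for i in range(dplus)])
K = torch.stack([torch.cos(angles), torch.sin(angles)], dim=-1)   # (dplus, d)

alphas = torch.rand(J) + 0.5                     # per-module amplitudes
xi = torch.rand(1, J, dplus) * 2.0 * math.pi     # reference phases
delta_theta = 0.1 * torch.randn(1, J, dplus)     # small displacement angles

# W_{ii'} = sum_j alpha_j^2 cos((xi_i - xi_i') - (dtheta_i - dtheta_i'))
delta_xi = xi[..., None] - xi[..., None, :]
ddt = delta_theta[..., None] - delta_theta[..., None, :]
W = (torch.cos(delta_xi - ddt) * alphas.square()[None, :, None, None]).sum(dim=1)

KWK = K.T @ W[0] @ K                             # (d, d), positive semidefinite
logdet = 0.5 * torch.logdet(KWK + tau * torch.eye(d))
print(KWK.shape, logdet.item())
```

Since $W$ is a sum of outer products weighted by $\alpha_j^2$, the regularized $K^\top W K$ is positive definite and the log-determinant is always finite.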
Conclusion¶
We’ve built a metric for Fourier location codes that respects the underlying physical geometry. Together with the embedding pseudometric we’ll build next, we can use this metric to implement an advanced spatial memory for TEM-t.
Addendum: Test encoding¶
import math
import torch
from importlib import reload

import tree_world.fourier
import tree_world.models.fourier_metric

reload(tree_world.fourier)
reload(tree_world.models.fourier_metric)
metric = tree_world.models.fourier_metric.FourierMetric(location_dim=120, dim=2, scale=1000.0, ratio=math.sqrt(2.0))
locations = torch.randn(10, 2) * 10
encoded = metric.encode(locations)
decoded_with_j, decoded = metric.interpret(encoded)
print(f"decoded_with_j: {decoded_with_j.shape}")
print(f"decoded: {decoded.shape}")
diff = torch.norm(locations - decoded, dim=-1)
print(f"locations: {locations.detach().cpu().numpy().tolist()}")
print(f"decoded: {decoded.detach().cpu().numpy().tolist()}")
print(f"diff: {diff.detach().cpu().numpy().tolist()}")
print(f"decoded_with_j[base]: {decoded_with_j[..., 0, :].detach().cpu().numpy().tolist()}")
thetas (-1): torch.Size([10, 3]), [1.9679633378982544, -13.361233711242676, 11.393270492553711]
thetas (0): torch.Size([10, 20, 3]), [[-2.225909948348999, -2.2740678787231445, -2.1962993144989014], [-2.2488722801208496, -2.2824606895446777, -2.1649441719055176], [-2.225297451019287, -2.2031636238098145, -2.2678160667419434], [-2.238586902618408, -2.2492616176605225, -2.2084288597106934], [-2.2219350337982178, -2.2988901138305664, -2.1754519939422607], [-2.1928365230560303, -2.206786870956421, -2.2966537475585938], [-2.2440671920776367, -2.224576473236084, -2.227633476257324], [-2.2316999435424805, -2.250920057296753, -2.2136573791503906], [-2.2012805938720703, -2.268103837966919, -2.2268927097320557], [-2.2615439891815186, -2.205914258956909, -2.228818893432617]]
locations (1): torch.Size([10, 20, 3, 2]), [[[-0.6092493534088135, -0.7929787039756775], [-0.3557904064655304, -0.9345657825469971], [-0.8461326360702515, -0.5329723954200745], [0.9849703907966614, 0.1727232038974762], [0.4764881432056427, -0.879180908203125], [0.9761363863945007, 0.21715831756591797], [-0.6183295249938965, -0.7859189510345459], [0.36158695816993713, 0.9323384165763855], [-0.7824238538742065, -0.6227462887763977], [-0.789602518081665, 0.6136186718940735], [0.5790926218032837, -0.8152617812156677], [-0.7545367479324341, -0.6562577486038208], [0.5645982027053833, 0.8253659009933472], [-0.028533324599266052, 0.9995928406715393], [0.4315015375614166, 0.9021121859550476], [0.999570369720459, -0.029309246689081192], [-0.36351820826530457, -0.9315871000289917], [-0.8710347414016724, -0.49122142791748047], [0.32901322841644287, 0.94432532787323], [-0.15151476860046387, -0.9884549975395203]], [[-0.6272957921028137, -0.7787811160087585], [-0.3859463930130005, -0.9225212335586548], [-0.8697085976600647, -0.49356555938720703], [0.9941037893295288, 0.10843255370855331], [0.3938405215740204, -0.9191787838935852], [0.996041476726532, 0.08888974040746689], [-0.7514882683753967, -0.6597464680671692], [0.5889506340026855, 0.8081690073013306], [-0.9538923501968384, -0.30014896392822266], [-0.38072511553764343, 0.9246882796287537], [-0.11691465228796005, -0.9931419491767883], [-0.9481881856918335, 0.3177092671394348], [0.8781832456588745, -0.47832438349723816], [0.8874616026878357, -0.46088168025016785], [-0.24134846031665802, -0.970438539981842], [-0.5024516582489014, 0.8646053075790405], [0.032755788415670395, -0.9994633793830872], [-0.054159343242645264, 0.9985322952270508], [-0.45675885677337646, 0.8895905613899231], [0.8775416016578674, 0.47950056195259094]], [[-0.6087635159492493, -0.7933517098426819], [-0.3549808859825134, -0.9348735809326172], [-0.8454792499542236, -0.5340083241462708], [0.9846697449684143, 0.1744290292263031], [0.47864025831222534, 
-0.8780111074447632], [0.9753782749176025, 0.22053857147693634], [-0.6144718527793884, -0.7889387011528015], [0.35511866211891174, 0.9348212480545044], [-0.7762845158576965, -0.6303827166557312], [-0.7980292439460754, 0.6026186943054199], [0.5949568152427673, -0.8037576675415039], [-0.7360618114471436, -0.6769143342971802], [0.531823992729187, 0.8468549251556396], [-0.08386636525392532, 0.9964770078659058], [0.3595351576805115, 0.9331315755844116], [0.9966772794723511, 0.08145192265510559], [-0.21360936760902405, -0.9769191741943359], [-0.7416945099830627, -0.6707378029823303], [0.021709362044930458, 0.9997643232345581], [0.28721901774406433, -0.9578649401664734]], [[-0.6192526817321777, -0.7851917743682861], [-0.3724871277809143, -0.9280373454093933], [-0.8593721389770508, -0.5113506317138672], [0.9905291199684143, 0.13730277121067047], [0.43131333589553833, -0.9022021889686584], [0.9891870021820068, 0.146659716963768], [-0.6947205066680908, -0.7192797660827637], [0.4911356568336487, 0.8710830807685852], [-0.8918324112892151, -0.4523659944534302], [-0.5837292075157166, 0.8119484186172485], [0.21036912500858307, -0.9776220321655273], [-0.9899130463600159, -0.14167632162570953], [0.9873029589653015, 0.15884873270988464], [0.8995989561080933, 0.4367170035839081], [0.8785343170166016, -0.47767919301986694], [-0.6840437650680542, -0.7294409275054932], [0.4580058753490448, 0.8889491558074951], [0.594248354434967, -0.8042815923690796], [0.5164297819137573, 0.8563295602798462], [-0.09344228357076645, 0.9956247210502625]], [[-0.6060925126075745, -0.7953941226005554], [-0.35053154826164246, -0.9365509152412415], [-0.8418691754341125, -0.5396816730499268], [0.982966423034668, 0.18378528952598572], [0.49040526151657104, -0.8714945316314697], [0.9710074067115784, 0.23904940485954285], [-0.5930308103561401, -0.8051797747612], [0.3193094730377197, 0.9476504921913147], [-0.7412649393081665, -0.6712125539779663], [-0.8415241241455078, 0.5402194857597351], [0.677829384803772, 
-0.7352192401885986], [-0.6249541640281677, -0.780661404132843], [0.33872780203819275, 0.9408844113349915], [-0.3786039352416992, 0.9255588054656982], [-0.06256643682718277, 0.9980407953262329], [0.7711268067359924, 0.6366815567016602], [0.601586639881134, -0.7988075613975525], [0.3724956214427948, -0.9280339479446411], [-0.9916909337043762, -0.12864302098751068], [0.40382176637649536, 0.914837658405304]], [[-0.582694411277771, -0.812691330909729], [-0.31170493364334106, -0.9501789808273315], [-0.8090534806251526, -0.5877350568771362], [0.9645299911499023, 0.2639734148979187], [0.5882955193519592, -0.808646023273468], [0.9187104105949402, 0.3949318826198578], [-0.39128577709198, -0.9202691912651062], [-0.004213245119899511, 0.9999911189079285], [-0.3610319197177887, -0.9325534701347351], [-0.9961541891098022, -0.08761728554964066], [0.9944750070571899, 0.10497363656759262], [0.5986244082450867, -0.8010298609733582], [-0.9985398650169373, 0.05401952192187309], [-0.1193096786737442, -0.9928570985794067], [0.6017131209373474, -0.7987123131752014], [0.9474664330482483, -0.3198551535606384], [0.9711562991142273, 0.2384437471628189], [-0.9968783259391785, 0.07895282655954361], [0.7772918343544006, -0.6291401982307434], [-0.9728303551673889, -0.23151901364326477]], [[-0.6235464215278625, -0.781786322593689], [-0.3796687424182892, -0.9251224994659424], [-0.8649253845214844, -0.5019004940986633], [0.9925383925437927, 0.12193258106708527], [0.41143330931663513, -0.9114398956298828], [0.9932577013969421, 0.11592713743448257], [-0.7255787253379822, -0.6881391406059265], [0.5441682934761047, 0.8389760851860046], [-0.928022027015686, -0.3725253939628601], [-0.4788154363632202, 0.8779155611991882], [0.03656914457678795, -0.9993311166763306], [-0.9944016933441162, 0.10566579550504684], [0.9817719459533691, -0.19006264209747314], [0.999028205871582, -0.044075630605220795], [0.36281266808509827, -0.9318621158599854], [-0.9848017692565918, 0.17368227243423462], [0.9529551863670349, 
-0.30311119556427], [-0.975242018699646, -0.22114025056362152], [-0.20560991764068604, -0.9786340594291687], [-0.6691880226135254, -0.7430931329727173]], [[-0.6138304471969604, -0.7894378900527954], [-0.36343085765838623, -0.9316211938858032], [-0.8522475957870483, -0.523138701915741], [0.9876668453216553, 0.15657033026218414], [0.45600035786628723, -0.8899796009063721], [0.9827241897583008, 0.185076043009758], [-0.6540570259094238, -0.7564452290534973], [0.42184171080589294, 0.9066694974899292], [-0.8366773724555969, -0.5476961135864258], [-0.7026738524436951, 0.7115121483802795], [0.4189927279949188, -0.9079896211624146], [-0.898777425289154, -0.43840518593788147], [0.8251713514328003, 0.5648825168609619], [0.47548264265060425, 0.8797250986099243], [0.9273526668548584, 0.37418848276138306], [0.4736085534095764, -0.880735456943512], [-0.9600874185562134, 0.2797001302242279], [0.01191070955246687, 0.9999290704727173], [-0.15748026967048645, -0.9875221252441406], [0.933087944984436, 0.35964834690093994]], [[-0.5895360112190247, -0.8077421188354492], [-0.32302916049957275, -0.9463890194892883], [-0.8188635110855103, -0.5739882588386536], [0.9705588817596436, 0.24086399376392365], [0.5606520771980286, -0.8280514478683472], [0.9365200400352478, 0.35061413049697876], [-0.4525127410888672, -0.8917579650878906], [0.09119383245706558, 0.995833158493042], [-0.4833519160747528, -0.8754261136054993], [-0.9946653842926025, 0.10315409302711487], [0.9864112138748169, -0.16429509222507477], [0.2567394971847534, -0.9664806723594666], [-0.8284469246864319, 0.5600675344467163], [-0.7731962203979492, -0.6341668367385864], [-0.4215902090072632, -0.9067864418029785], [-0.2795475125312805, -0.9601318836212158], [-0.34301578998565674, -0.9393296241760254], [0.9999850392341614, 0.0054725054651498795], [0.28731897473335266, 0.9578349590301514], [-0.9200234413146973, -0.39186328649520874]], [[-0.637113630771637, -0.7707698941230774], [-0.40241557359695435, -0.9154571294784546], 
[-0.8819366097450256, -0.47136804461479187], [0.9973508715629578, 0.07274086028337479], [0.3467644155025482, -0.9379522800445557], [0.9998499155044556, 0.017324326559901237], [-0.8143966197967529, -0.5803086757659912], [0.6983744502067566, 0.7157325744628906], [-0.9947921633720398, -0.10192417353391647], [-0.10366562008857727, 0.9946122169494629], [-0.49920180439949036, -0.8664857745170593], [-0.6241377592086792, 0.7813143134117126], [0.258105993270874, -0.9661166071891785], [-0.055076178163290024, -0.9984821677207947], [-0.9568203091621399, 0.29068002104759216], [0.9807372093200684, -0.1953318864107132], [0.06954678893089294, 0.9975786805152893], [-0.984035313129425, -0.17797325551509857], [-0.26636478304862976, 0.9638723134994507], [-0.7319835424423218, -0.6813223361968994]]]
thetas (1): torch.Size([10, 20, 3]), [[-2.225909948348999, -2.2740678787231445, -2.1962993144989014], [-2.2488722801208496, -2.2824606895446777, -2.1649441719055176], [-2.225297451019287, -2.2031636238098145, -2.2678160667419434], [-2.238586902618408, -2.2492616176605225, -2.2084288597106934], [-2.2219350337982178, -2.2988901138305664, -2.1754519939422607], [-2.1928365230560303, -2.206786870956421, -2.2966537475585938], [-2.2440671920776367, -2.224576473236084, -2.227633476257324], [-2.2316999435424805, -2.250920057296753, -2.2136573791503906], [-2.2012805938720703, -2.268103837966919, -2.2268927097320557], [-2.2615439891815186, -2.205914258956909, -2.228818893432617]]
thetas (2): torch.Size([10, 20, 3]), [[1.9679292440414429, -13.361215591430664, 11.39328670501709], [-5.341207981109619, -16.032730102539062, 21.373937606811523], [2.162893295288086, 9.208309173583984, -11.37120246887207], [-2.067270517349243, -5.465137958526611, 7.532332420349121], [3.2331838607788086, -21.262378692626953, 18.029193878173828], [12.495527267456055, 8.054993629455566, -20.550521850585938], [-3.8117008209228516, 2.392387628555298, 1.4193133115768433], [0.12491656839847565, -5.993035316467285, 5.868042945861816], [9.807696342468262, -11.46280288696289, 1.655106544494629], [-9.374737739562988, 8.332755088806152, 1.0419832468032837]]
xs (1): torch.Size([10, 20, 2]), [[5.441732406616211, 13.361213684082031], [15.423995971679688, 16.03272819519043], [-7.813913345336914, -9.208308219909668], [5.542333126068115, 5.46511173248291], [8.54248046875, 21.26237678527832], [-19.079145431518555, -8.05499267578125], [3.020127058029175, -2.3923873901367188], [3.315795421600342, 5.993009090423584], [-4.706899166107178, 11.462801933288574], [6.014095783233643, -8.332754135131836]]
decoded_with_j: torch.Size([10, 20, 2])
decoded: torch.Size([10, 2])
locations: [[5.441702842712402, 13.361233711242676], [15.423995018005371, 16.032732009887695], [-7.813913822174072, -9.208306312561035], [5.542320251464844, 5.465102672576904], [8.54250717163086, 21.26236915588379], [-19.079151153564453, -8.054960250854492], [3.0201385021209717, -2.3924241065979004], [3.315791130065918, 5.993025302886963], [-4.7069244384765625, 11.462799072265625], [6.0141096115112305, -8.332731246948242]]
decoded: [[5.441984176635742, 13.354339599609375], [15.423995018005371, 16.025999069213867], [-7.813673496246338, -9.214279174804688], [5.548147678375244, 5.46878719329834], [8.536678314208984, 21.265735626220703], [-19.047008514404297, -8.08303451538086], [3.014796018600464, -2.36527419090271], [3.3157920837402344, 5.987553596496582], [-4.678760051727295, 11.453174591064453], [6.013868808746338, -8.338705062866211]]
diff: [0.006899849511682987, 0.006732940673828125, 0.005977695342153311, 0.006894533988088369, 0.0067311739549040794, 0.04267685115337372, 0.027670564129948616, 0.005471706390380859, 0.02976345643401146, 0.005978667177259922]
decoded_with_j[base]: [[5.441732406616211, 13.361213684082031], [15.423995971679688, 16.03272819519043], [-7.813913345336914, -9.208308219909668], [5.542333126068115, 5.46511173248291], [8.54248046875, 21.26237678527832], [-19.079145431518555, -8.05499267578125], [3.020127058029175, -2.3923873901367188], [3.315795421600342, 5.993009090423584], [-4.706899166107178, 11.462801933288574], [6.014095783233643, -8.332754135131836]]