# Created June 10, 2025 18:26
# Save alisterburt/75efcbfff6434e14e80ab8f09365fda4 to your computer and use it in GitHub Desktop.
#
# tomogram reconstruction
#
# This file contains hidden or bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that reveals
# hidden Unicode characters.
# Learn more about bidirectional Unicode characters
from itertools import combinations
from pathlib import Path

import einops
import numpy as np
import torch
import torch.nn.functional as F
from torch_affine_utils import homogenise_coordinates
from torch_affine_utils.transforms_3d import Rx, Ry, Rz, T
from torch_fourier_slice import backproject_2d_to_3d
from torch_grid_utils import dft_center
from torch_subpixel_crop import subpixel_crop_2d
class Tomogram: | |
def __init__( | |
self, | |
tilt_angles: torch.Tensor, | |
tilt_axis_angle: torch.Tensor, | |
sample_translations: torch.Tensor, | |
images: torch.Tensor | None = None, # (b, h, w) | |
): | |
self.images = torch.tensor(images).float() | |
self.tilt_angles = torch.tensor(tilt_angles).float() | |
self.tilt_axis_angle = torch.tensor(tilt_axis_angle).float() | |
self.sample_translations = torch.tensor(sample_translations).float() | |
@property | |
def projection_matrices(self) -> torch.Tensor: | |
"""Matrices that project points from 3D -> 2D""" | |
shifts_3d = F.pad(self.sample_translations, (0, 1), value=0) | |
r0 = Ry(self.tilt_angles, zyx=True) | |
r1 = Rz(self.tilt_axis_angle, zyx=True) | |
t2 = T(shifts_3d) | |
return t2 @ r1 @ r0 | |
def project_points(self, points_zyx: torch.Tensor) -> torch.Tensor: | |
"""Project points from 3D to 2D. | |
- points are 3D zyx coordinates | |
- points are positions relative to center of tomogram | |
- projected 2D points are relative to center of 2D image | |
""" | |
points_zyx = torch.tensor(points_zyx).float() | |
M_yx = self.projection_matrices[..., [1, 2], :] # (ntilts, 2, 4) | |
points_zyxw = homogenise_coordinates(points_zyx) | |
projected_yx = M_yx @ einops.rearrange(points_zyxw, "nparticles zyxw -> nparticles 1 zyxw 1") | |
projected_yx = einops.rearrange(projected_yx, "nparticles ntilts yx 1 -> nparticles ntilts yx") | |
return projected_yx # (points, tilts, yx) | |
def extract_particle_tilt_series( | |
self, points_zyx: torch.Tensor, sidelength: int | |
) -> torch.Tensor: | |
projected_yx = self.project_points(points_zyx) | |
projected_yx += dft_center(self.images.shape[-2:], rfft=False, fftshift=True) | |
images = subpixel_crop_2d( | |
image=self.images, | |
positions=projected_yx, | |
sidelength=sidelength, | |
) | |
return images | |
def reconstruct_subvolume(self, point_zyx: torch.Tensor, sidelength: int) -> torch.Tensor: | |
point_zyx = torch.tensor(point_zyx).float() | |
point_zyx = point_zyx.reshape((-1, 3)) | |
rotation_matrices = self.projection_matrices[:, :3, :3] | |
rotation_matrices = torch.linalg.pinv(rotation_matrices) | |
particle_tilt_series = self.extract_particle_tilt_series(point_zyx, sidelength=sidelength) | |
volume = backproject_2d_to_3d( | |
images=particle_tilt_series[0], | |
rotation_matrices=rotation_matrices, | |
zyx_matrices=True | |
) | |
return volume | |
def reconstruct_tomogram( | |
self, | |
volume_shape: tuple[int, int, int], | |
sidelength: int | |
) -> torch.Tensor: | |
d, h, w = volume_shape | |
r = sidelength // 2 | |
# setup grid points | |
z = torch.arange(start=r, end=d + r, step=sidelength) - d // 2 | |
y = torch.arange(start=r, end=h + r, step=sidelength) - h // 2 | |
x = torch.arange(start=r, end=w + r, step=sidelength) - w // 2 | |
# allocate whole volume | |
tomogram = torch.zeros(size=volume_shape, dtype=torch.float32) | |
for _z in z: | |
for _y in y: | |
for _x in x: | |
zyx = torch.tensor([_z, _y, _x]).float() | |
subvolume = self.reconstruct_subvolume(zyx, sidelength=sidelength) | |
_d, _h, _w = zyx + torch.tensor(volume_shape) // 2 | |
print(_d, _h, _w) | |
_d, _h, _w = int(_d), int(_h), int(_w) | |
d_min, d_max = _d - r, _d + r | |
h_min, h_max = _h - r, _h + r | |
w_min, w_max = _w - r, _w + r | |
tomogram[d_min:d_max, h_min:h_max, w_min:w_max] = subvolume | |
return tomogram | |
if __name__ == "__main__":
    # Demo: reconstruct one particle from a simulated tilt series and compare it
    # with a voxelised "ground truth" built from the simulated atom positions.
    from fast_histogram import histogramdd

    data_dir = Path("simulated_ts")
    sidelength = 32  # box size used for extraction, reconstruction and comparison

    tilt_series = np.load(data_dir / "tilt_series.npy")
    tilt_angles = np.load(data_dir / "tilt_angles.npy")
    tilt_axis_angle = np.load(data_dir / "tilt_axis_angle.npy")
    # Drop the last stored column; Tomogram pads a zero column back on.
    sample_translations = np.load(data_dir / "sample_translations.npy")[:, :-1]
    particle_positions = np.load(data_dir / "particle_positions.npy")
    atoms_3d = np.load(data_dir / "atoms_3d.npy")

    tomogram = Tomogram(
        images=tilt_series,
        tilt_angles=tilt_angles,
        tilt_axis_angle=tilt_axis_angle,
        sample_translations=sample_translations,
    )
    particle_tilt_series = tomogram.extract_particle_tilt_series(
        points_zyx=particle_positions,
        sidelength=sidelength,
    )
    ribo = tomogram.reconstruct_subvolume(point_zyx=particle_positions[0], sidelength=sidelength)
    # tomo = tomogram.reconstruct_tomogram(volume_shape=(512, 512, 512), sidelength=sidelength)

    # Simulate the ribo subvolume for comparison: histogram atoms into unit
    # voxels whose bin centers sit on integer indices 0..sidelength-1, with the
    # particle placed at index sidelength // 2.
    ribo_coords = atoms_3d - particle_positions[0] + sidelength // 2
    subvolume = histogramdd(
        sample=ribo_coords,
        bins=(sidelength, sidelength, sidelength),
        range=[[-0.5, sidelength - 0.5] for _ in range(3)],
    )

    import napari

    viewer = napari.Viewer()
    viewer.add_image(tilt_series)
    viewer.add_image(particle_tilt_series[0].numpy())
    viewer.add_image(ribo.numpy())
    viewer.add_image(subvolume)
    # viewer.add_image(tomo.numpy())
    napari.run()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.