Skip to content

Instantly share code, notes, and snippets.

@bigsnarfdude
Created April 16, 2025 15:23
Show Gist options
  • Save bigsnarfdude/f89033f08a784422e9504bad43260aff to your computer and use it in GitHub Desktop.
image_segmentation.py
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluator for segmentation."""
import functools
import big_vision.evaluators.common as c
import big_vision.pp.tokenizer
import big_vision.utils as u
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import PIL.Image
from tensorflow.io import gfile
# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = 'jit'
def _inrange(a, min_value, max_value):
return (np.clip(a, min_value, max_value) == a).all()
def _area(y1, x1, y2, x2):
return max(x2 - x1, 0.0) * max(y2 - y1, 0.0)
class Evaluator:
  """Evaluator for instance segmentation.

  The model is expected to emit 4 `<locXXXX>` box tokens followed by 16
  `<segXXX>` mask-codebook tokens per example; the codebook tokens are decoded
  to a 64x64 soft mask by the VQ-VAE decoder from `get_reconstruct_masks`.
  """

  def __init__(self, predict_fn, tokenizer,
               model='oi', det_ious=(0.5, 0.75),
               *, devices, **kw):
    """Initializes the evaluator.

    Args:
      predict_fn: Decoding function `(train_state, batch) -> token ids`.
      tokenizer: Tokenizer spec; must include `tokensets=("loc", "seg")`.
      model: Mask-decoder checkpoint name/path (see `_KNOWN_MODELS`).
      det_ious: IoU thresholds at which detection accuracy is reported.
      devices: JAX devices for the input pipeline and decoding.
      **kw: Forwarded to `c.eval_input_pipeline`.
    """
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={'prefix', 'suffix', 'objects/mask', 'objects/bbox'},
        devices=devices, **kw)
    # Bug fix: the tokenizer was previously constructed twice (once for
    # `self.tok`, once for a local `tok`); build it once and reuse it.
    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)
    # Token ids of the first location / segmentation tokens, used as offsets
    # to map token ids back to integer coordinates / codebook indices.
    self.loc0 = np.array(self.tok.to_int('<loc0000>'))
    self.seg0 = np.array(self.tok.to_int('<seg000>'))
    # Verify tokenizer has `tokensets=("loc", "seg")`.
    assert self.loc0.shape == (1,), self.loc0
    assert self.seg0.shape == (1,), self.seg0
    self.reconstruct_masks = get_reconstruct_masks(model)
    self.det_ious = det_ious

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""
    ious = []  # NOTE: no point to split in s/m/l: all objects are L (>96px²)
    det_by_iou = {iou: [] for iou in self.det_ious}
    invalid = total = 0
    for _, batch in zip(range(self.steps), self.get_data_iter()):
      decoded = self.decode(train_state, batch)
      not_padding = u.get_local_slice_from_fsarray(batch['_mask'])
      decoded = u.get_local_slice_from_fsarray(decoded)[not_padding]
      # Note, gt masks are in full original image resolution.
      gt_masks = [gt[:, :, 0] > 0 for gt in batch['objects/mask'][not_padding]]
      gt_bbs = [gt for gt in batch['objects/bbox'][not_padding]]
      valid = []
      tokens = np.zeros([decoded.shape[0], 4 + 16], np.int32)
      for i, dec in enumerate(decoded):
        # TODO: b/andstein - do we need to optimize this loop?
        t = np.r_[dec[:4] - self.loc0, dec[4:4 + 16] - self.seg0]  # Ignore rest
        if (
            len(t) == 4 + 16  # Full prediction
            and _inrange(t[:4], 0, 1023)  # Valid box tokens
            and _inrange(t[4:], 0, 127)  # Valid seg tokens
            and t[2] > t[0] and t[3] > t[1]  # Valid box
        ):
          valid.append(True)
          tokens[i] = t
        else:
          valid.append(False)
      # Mask reconstruction runs jitted on CPU (see get_reconstruct_masks).
      tocpu = lambda x: jax.device_put(x, jax.local_devices(backend='cpu')[0])
      seg_indices = np.array(tokens[:, 4:])
      mask64 = jax.device_get(self.reconstruct_masks(tocpu(seg_indices)))
      mask64 = mask64[..., 0]
      bbox = tokens[:, :4] / 1023  # Back to [0.0 ... 1.0]
      for v, m64, gtm, bb, gtbb in zip(valid, mask64, gt_masks, bbox, gt_bbs):
        # TODO: b/andstein - do we need to optimize this loop?
        total += 1
        h, w = gtm.shape  # gt is full/original image resolution mask.
        # First, compute detection iou, in [0.0 ... 1.0] coordinate space.
        y1, x1, y2, x2 = bb
        gty1, gtx1, gty2, gtx2 = gtbb
        ibb = max(y1, gty1), max(x1, gtx1), min(y2, gty2), min(x2, gtx2)
        box_iou = _area(*ibb) / (_area(*bb) + _area(*gtbb) - _area(*ibb))
        for iou_thresh in det_by_iou:
          det_by_iou[iou_thresh].append(iou_thresh <= box_iou)
        # Next, we convert to pixel coordinates and compute mask iou.
        gt_area = gtm.sum()
        y1, x1, y2, x2 = map(int, (y1 * h, x1 * w, y2 * h, x2 * w))
        # Avoid compute-intensive mask stuff for invalid preds:
        if not v or x2 <= x1 or y2 <= y1:  # Can still happen after int().
          iou = 0.0
          invalid += 1
        else:
          mi = np.asarray(
              PIL.Image.fromarray(m64).resize(  # pytype: disable=wrong-arg-types # pillow-102-upgrade
                  [x2 - x1, y2 - y1], resample=PIL.Image.BILINEAR  # pytype: disable=module-attr
              )
          )  # Predicted mask in box-sized image.
          mi = mi > 0.0  # Mask decoder output in [-1.0 ... 1.0]
          iarea = (gtm[y1:y2, x1:x2] & mi).sum()  # Intersection pixels.
          iou = iarea / (gt_area + mi.sum() - iarea)
        ious.append(iou)
    # Done going over all batches, now collect results from all processes.
    sum_ious, num_ious, sum_dets, num_dets, num_invalid, num = c.process_sum([
        sum(ious), len(ious),
        {k: sum(v) for k, v in det_by_iou.items()},
        {k: len(v) for k, v in det_by_iou.items()},
        invalid, total
    ])
    yield 'miou', sum_ious / num_ious
    for k in sum_dets:
      yield f'boxacc/{k}', sum_dets[k] / num_dets[k]
    yield 'invalid', num_invalid
    yield 'total', num
# Known VQ-VAE mask-decoder checkpoints, keyed by short model name. Anything
# not in this dict is treated as a literal checkpoint path by
# `get_reconstruct_masks`.
_KNOWN_MODELS = {
    # Trained on open images.
    'oi': 'gs://big_vision/paligemma/vae-oid.npz',
}
def _get_params(checkpoint):
"""Converts PyTorch checkpoint to Flax params."""
def transp(kernel):
return np.transpose(kernel, (2, 3, 1, 0))
def conv(name):
return {
'bias': checkpoint[name + '.bias'],
'kernel': transp(checkpoint[name + '.weight']),
}
def resblock(name):
return {
'Conv_0': conv(name + '.0'),
'Conv_1': conv(name + '.2'),
'Conv_2': conv(name + '.4'),
}
return {
'_embeddings': checkpoint['_vq_vae._embedding'],
'Conv_0': conv('decoder.0'),
'ResBlock_0': resblock('decoder.2.net'),
'ResBlock_1': resblock('decoder.3.net'),
'ConvTranspose_0': conv('decoder.4'),
'ConvTranspose_1': conv('decoder.6'),
'ConvTranspose_2': conv('decoder.8'),
'ConvTranspose_3': conv('decoder.10'),
'Conv_1': conv('decoder.12'),
}
def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
batch_size, num_tokens = codebook_indices.shape
assert num_tokens == 16, codebook_indices.shape
unused_num_embeddings, embedding_dim = embeddings.shape
encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
return encodings
class ResBlock(nn.Module):
  """Residual block: two 3x3 convs + relu, a 1x1 conv, then a skip add."""
  features: int

  @nn.compact
  def __call__(self, x):
    residual = x
    # Two 3x3 conv + relu stages at constant width.
    for _ in range(2):
      x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
      x = nn.relu(x)
    x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
    return x + residual
class Decoder(nn.Module):
  """Upscales quantized vectors to mask."""

  @nn.compact
  def __call__(self, x):
    channels = 128
    x = nn.Conv(features=channels, kernel_size=(1, 1), padding=0)(x)
    x = nn.relu(x)
    # Two residual blocks at full channel width.
    for _ in range(2):
      x = ResBlock(features=channels)(x)
    # Four stride-2 transposed-conv stages; channel width halves after each
    # upsampling step (128 -> 64 -> 32 -> 16 output channels).
    for _ in range(4):
      x = nn.ConvTranspose(
          features=channels,
          kernel_size=(4, 4),
          strides=(2, 2),
          padding=2,
          transpose_kernel=True,
      )(x)
      x = nn.relu(x)
      channels //= 2
    return nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
@functools.cache
def get_reconstruct_masks(model):
  """Reconstructs masks from codebook indices.

  Based on code from https://arxiv.org/abs/2301.02229

  Verified in
  https://colab.research.google.com/drive/1AOr0cokOpM6-N9Z5HmxoeGxGj6jS37Vl

  Args:
    model: Model to use for conversion.

  Returns:
    A function that expects indices shaped `[B, 16]` of dtype int32, each
    ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
    `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
  """
  # `model` is either a key into _KNOWN_MODELS or a literal checkpoint path.
  checkpoint_path = _KNOWN_MODELS.get(model, model)
  with gfile.GFile(checkpoint_path, 'rb') as f:
    params = _get_params(dict(np.load(f)))

  def reconstruct_masks(codebook_indices):
    quantized = _quantized_values_from_codebook_indices(
        codebook_indices, params['_embeddings'])
    return Decoder().apply({'params': params}, quantized)

  return jax.jit(reconstruct_masks, backend='cpu')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment