image_segmentation.py
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluator for segmentation."""

import functools

import big_vision.evaluators.common as c
import big_vision.pp.tokenizer
import big_vision.utils as u
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import PIL.Image
from tensorflow.io import gfile

# Temporary global flag to facilitate backwards compatibility. Will be removed
# by the end of year 2023.
API = 'jit'

def _inrange(a, min_value, max_value):
  return (np.clip(a, min_value, max_value) == a).all()


def _area(y1, x1, y2, x2):
  return max(x2 - x1, 0.0) * max(y2 - y1, 0.0)


class Evaluator:
  """Evaluator for instance segmentation."""

  def __init__(self, predict_fn, tokenizer,
               model='oi', det_ious=(0.5, 0.75),
               *, devices, **kw):
    self.get_data_iter, self.steps = c.eval_input_pipeline(
        keep_on_cpu={'prefix', 'suffix', 'objects/mask', 'objects/bbox'},
        devices=devices, **kw)

    self.tok = big_vision.pp.tokenizer.get_tokenizer(tokenizer)
    self.decode = functools.partial(
        predict_fn, devices=devices, eos_token=self.tok.eos_token)

    self.loc0 = np.array(self.tok.to_int('<loc0000>'))
    self.seg0 = np.array(self.tok.to_int('<seg000>'))
    # Verify the tokenizer has `tokensets=("loc", "seg")`.
    assert self.loc0.shape == (1,), self.loc0
    assert self.seg0.shape == (1,), self.seg0

    self.reconstruct_masks = get_reconstruct_masks(model)
    self.det_ious = det_ious

  def run(self, train_state):
    """Does one evaluation run, yields metrics."""
    ious = []  # NOTE: no point to split in s/m/l: all objects are L (>96px²)
    det_by_iou = {iou: [] for iou in self.det_ious}
    invalid = total = 0

    for _, batch in zip(range(self.steps), self.get_data_iter()):
      decoded = self.decode(train_state, batch)
      not_padding = u.get_local_slice_from_fsarray(batch['_mask'])
      decoded = u.get_local_slice_from_fsarray(decoded)[not_padding]
      # Note, gt masks are in full original image resolution.
      gt_masks = [gt[:, :, 0] > 0 for gt in batch['objects/mask'][not_padding]]
      gt_bbs = [gt for gt in batch['objects/bbox'][not_padding]]

      valid = []
      tokens = np.zeros([decoded.shape[0], 4 + 16], np.int32)
      for i, dec in enumerate(decoded):
        # TODO: b/andstein - do we need to optimize this loop?
        t = np.r_[dec[:4] - self.loc0, dec[4:4 + 16] - self.seg0]  # Ignore rest.
        if (
            len(t) == 4 + 16  # Full prediction.
            and _inrange(t[:4], 0, 1023)  # Valid box tokens.
            and _inrange(t[4:], 0, 127)  # Valid seg tokens.
            and t[2] > t[0] and t[3] > t[1]  # Valid box.
        ):
          valid.append(True)
          tokens[i] = t
        else:
          valid.append(False)

      tocpu = lambda x: jax.device_put(x, jax.local_devices(backend='cpu')[0])
      seg_indices = np.array(tokens[:, 4:])
      mask64 = jax.device_get(self.reconstruct_masks(tocpu(seg_indices)))
      mask64 = mask64[..., 0]
      bbox = tokens[:, :4] / 1023  # Back to [0.0 ... 1.0]
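      # Worked example (hypothetical numbers): a decoded prediction
      #   <loc0102><loc0205><loc0512><loc0614><seg003>...<seg127>
      # gives t[:4] = [102, 205, 512, 614] after subtracting loc0, i.e. a box
      # (y1, x1, y2, x2) ≈ (0.100, 0.200, 0.500, 0.600) once divided by 1023,
      # while t[4:] are the 16 codebook indices in [0, 127] consumed by the
      # mask decoder above.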
      for v, m64, gtm, bb, gtbb in zip(valid, mask64, gt_masks, bbox, gt_bbs):
        # TODO: b/andstein - do we need to optimize this loop?
        total += 1
        h, w = gtm.shape  # gt is full/original image resolution mask.

        # First, compute detection iou, in [0.0 ... 1.0] coordinate space.
        y1, x1, y2, x2 = bb
        gty1, gtx1, gty2, gtx2 = gtbb
        ibb = max(y1, gty1), max(x1, gtx1), min(y2, gty2), min(x2, gtx2)
        box_iou = _area(*ibb) / (_area(*bb) + _area(*gtbb) - _area(*ibb))
        for iou_thresh in det_by_iou:
          det_by_iou[iou_thresh].append(iou_thresh <= box_iou)

        # Next, we convert to pixel coordinates and compute mask iou.
        gt_area = gtm.sum()
        y1, x1, y2, x2 = map(int, (y1 * h, x1 * w, y2 * h, x2 * w))

        # Avoid compute-intensive mask stuff for invalid preds:
        if not v or x2 <= x1 or y2 <= y1:  # Can still happen after int().
          iou = 0.0
          invalid += 1
        else:
          mi = np.asarray(
              PIL.Image.fromarray(m64).resize(  # pytype: disable=wrong-arg-types  # pillow-102-upgrade
                  [x2 - x1, y2 - y1], resample=PIL.Image.BILINEAR  # pytype: disable=module-attr
              )
          )  # Predicted mask in box-sized image.
          mi = mi > 0.0  # Mask decoder output in [-1.0 ... 1.0].
          iarea = (gtm[y1:y2, x1:x2] & mi).sum()  # Intersection pixels.
          iou = iarea / (gt_area + mi.sum() - iarea)

        ious.append(iou)

    # Done going over all batches, now collect results from all processes.
    sum_ious, num_ious, sum_dets, num_dets, num_invalid, num = c.process_sum([
        sum(ious), len(ious),
        {k: sum(v) for k, v in det_by_iou.items()},
        {k: len(v) for k, v in det_by_iou.items()},
        invalid, total,
    ])

    yield 'miou', sum_ious / num_ious
    for k in sum_dets:
      yield f'boxacc/{k}', sum_dets[k] / num_dets[k]
    yield 'invalid', num_invalid
    yield 'total', num


_KNOWN_MODELS = {
    # Trained on open images.
    'oi': 'gs://big_vision/paligemma/vae-oid.npz',
}
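
# Note: `get_reconstruct_masks` below accepts either a key from `_KNOWN_MODELS`
# or a direct (gfile-readable) path to an .npz checkpoint, since it resolves the
# argument via `_KNOWN_MODELS.get(model, model)`.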


def _get_params(checkpoint):
  """Converts PyTorch checkpoint to Flax params."""

  def transp(kernel):
    # Move PyTorch's leading channel axes to the trailing positions expected
    # by Flax kernels: (O, I, kH, kW) -> (kH, kW, I, O).
    return np.transpose(kernel, (2, 3, 1, 0))

  def conv(name):
    return {
        'bias': checkpoint[name + '.bias'],
        'kernel': transp(checkpoint[name + '.weight']),
    }

  def resblock(name):
    return {
        'Conv_0': conv(name + '.0'),
        'Conv_1': conv(name + '.2'),
        'Conv_2': conv(name + '.4'),
    }

  return {
      '_embeddings': checkpoint['_vq_vae._embedding'],
      'Conv_0': conv('decoder.0'),
      'ResBlock_0': resblock('decoder.2.net'),
      'ResBlock_1': resblock('decoder.3.net'),
      'ConvTranspose_0': conv('decoder.4'),
      'ConvTranspose_1': conv('decoder.6'),
      'ConvTranspose_2': conv('decoder.8'),
      'ConvTranspose_3': conv('decoder.10'),
      'Conv_1': conv('decoder.12'),
  }


def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
  batch_size, num_tokens = codebook_indices.shape
  assert num_tokens == 16, codebook_indices.shape
  unused_num_embeddings, embedding_dim = embeddings.shape

  # Look up each of the 16 codebook indices and arrange the resulting embedding
  # vectors on a 4x4 spatial grid for the decoder.
  encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
  encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
  return encodings


class ResBlock(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    original_x = x
    x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
    x = nn.relu(x)
    x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
    x = nn.relu(x)
    x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
    return x + original_x


class Decoder(nn.Module):
  """Upscales quantized vectors to mask."""

  @nn.compact
  def __call__(self, x):
    num_res_blocks = 2
    dim = 128
    num_upsample_layers = 4

    x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
    x = nn.relu(x)

    for _ in range(num_res_blocks):
      x = ResBlock(features=dim)(x)

    for _ in range(num_upsample_layers):
      x = nn.ConvTranspose(
          features=dim,
          kernel_size=(4, 4),
          strides=(2, 2),
          padding=2,
          transpose_kernel=True,
      )(x)
      x = nn.relu(x)
      dim //= 2

    x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
    return x
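
# Shape note for `Decoder`: a [B, 4, 4, D] quantized input passes through four
# stride-2 transposed convolutions (4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64), and
# the final 1x1 conv produces single-channel [B, 64, 64, 1] mask scores, which
# the evaluator thresholds at 0 to binarize.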


@functools.cache
def get_reconstruct_masks(model):
  """Reconstructs masks from codebook indices.

  Based on code from https://arxiv.org/abs/2301.02229

  Verified in
  https://colab.research.google.com/drive/1AOr0cokOpM6-N9Z5HmxoeGxGj6jS37Vl

  Args:
    model: Model to use for conversion.

  Returns:
    A function that expects indices shaped `[B, 16]` of dtype int32, each
    ranging from 0 to 127 (inclusive), and that returns decoded masks shaped
    `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
  """
  def reconstruct_masks(codebook_indices):
    quantized = _quantized_values_from_codebook_indices(
        codebook_indices, params['_embeddings']
    )
    return Decoder().apply({'params': params}, quantized)

  with gfile.GFile(_KNOWN_MODELS.get(model, model), 'rb') as f:
    params = _get_params(dict(np.load(f)))

  return jax.jit(reconstruct_masks, backend='cpu')
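

# --- Hypothetical usage sketch (not part of the evaluator itself). ---
# A minimal sketch exercising the mask decoder in isolation, assuming the
# checkpoint in `_KNOWN_MODELS['oi']` is readable: feed [B, 16] codebook
# indices in [0, 127] and get back [B, 64, 64, 1] float32 masks, binarized by
# thresholding at 0 as `Evaluator.run` does. The random indices are purely
# illustrative.
if __name__ == '__main__':
  reconstruct = get_reconstruct_masks('oi')  # Loads the VQ-VAE decoder, jitted on CPU.
  fake_indices = np.random.randint(0, 128, size=(2, 16), dtype=np.int32)
  masks = jax.device_get(reconstruct(fake_indices))
  print(masks.shape, masks.dtype)  # Expected: (2, 64, 64, 1) float32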