Created
December 20, 2020 19:15
-
-
Save deepak-karkala/f8d150b1c01361b2dbefb859e0975ca7 to your computer and use it in GitHub Desktop.
Logging custom data to TensorBoard
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### TensorBoard callbacks
# Log mean IoU over the validation set at the end of each epoch.
def log_epoch_metrics(epoch, logs):
    """Compute mean IoU on ``val_dataset`` and write it to TensorBoard.

    Meant to be used as the ``on_epoch_end`` hook of a ``LambdaCallback``.

    Args:
        epoch: Epoch index supplied by Keras; used as the TensorBoard step.
        logs: Metrics dict supplied by Keras. Not read here.

    Relies on names defined elsewhere in the script: ``tf``, ``datetime``,
    ``model``, ``val_dataset``, ``create_mask`` and ``LABEL_NAMES``.
    """
    # Create the summary writer once and reuse it on later calls. Building a
    # new timestamped directory on every call would put each epoch's IoU in
    # its own TensorBoard run, so the values could not be plotted as a curve.
    writer = getattr(log_epoch_metrics, "_file_writer", None)
    if writer is None:
        logdir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        writer = tf.summary.create_file_writer(logdir + "/fit")
        log_epoch_metrics._file_writer = writer

    # A freshly constructed metric starts empty, so no explicit reset is needed.
    iou_metric = tf.keras.metrics.MeanIoU(num_classes=len(LABEL_NAMES))

    # NOTE(review): this iterates val_dataset until it is exhausted; if the
    # dataset is repeated (infinite) the loop never ends — confirm it is finite.
    for image, expected_mask in val_dataset:
        # predict_on_batch returns numpy like predict(), without predict()'s
        # per-call setup overhead and progress bar for a single batch.
        predicted_mask = create_mask(model.predict_on_batch(image))
        # NOTE(review): assumes create_mask returns masks for the whole batch,
        # shaped like expected_mask — confirm against its definition.
        # MeanIoU.update_state takes (y_true, y_pred) in that order.
        iou_metric.update_state(expected_mask, predicted_mask)

    iou = iou_metric.result().numpy()
    # Write through this writer explicitly rather than installing it as the
    # process-wide default summary writer.
    with writer.as_default():
        tf.summary.scalar('IoU', data=iou, step=epoch)
    writer.flush()
# Per-epoch callback: run log_epoch_metrics at the end of every epoch.
# Built from tf.keras, the same Keras namespace used for the TensorBoard
# callback and the metric, rather than a separately imported `keras`.
iou_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_epoch_metrics)
# Basic TensorBoard callback: loss/metric curves, weight histograms every
# epoch (histogram_freq=1), and a profiler trace.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# Profile the batch range 1 to 5 (the value passed below).
tb_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch='1, 5')
# Fit model.
EPOCHS = 50
# Each epoch validates on roughly 1/VAL_SUBSPLITS of the validation images.
VAL_SUBSPLITS = 5
# Clamp to at least one step: with fewer than BATCH_SIZE * VAL_SUBSPLITS
# validation images the integer division yields 0, which is not a usable
# validation step count.
VALIDATION_STEPS = max(1, len(val_images) // BATCH_SIZE // VAL_SUBSPLITS)
model_history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    validation_data=val_dataset,
    callbacks=[tb_callback, iou_callback],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment