import glob

import tensorflow as tf

#### -------- Write many images into one TFRecord, splitting the whole dataset across multiple TFRecord shards --------
# max_files: number of examples written per shard; images: the full list of images.
# (The signature below is reconstructed; the gist shows only the function body.)
def write_images_to_tfr(images, labels, filename: str = "large_images",
                        max_files: int = 10, out_dir: str = "/content/"):
    splits = (len(images) // max_files) + 1
    if len(images) % max_files == 0:  # avoid creating an empty trailing shard
        splits -= 1
    file_count = 0
    for i in range(splits):
        # filename is the base name of each shard; out_dir is the directory where
        # all TFRecord shards are saved, e.g. /content/1_3large_images.tfrecords
        current_shard_name = "{}{}_{}{}.tfrecords".format(out_dir, i + 1, splits, filename)
        # create one writer per shard
        writer = tf.io.TFRecordWriter(current_shard_name)
        current_shard_count = 0
        # serialize each image belonging to the current shard as an Example;
        # max_files is the number of images per shard
        while current_shard_count < max_files:
            # global index of the current image within the current shard
            index = i * max_files + current_shard_count
            if index == len(images):
                break
            current_image = images[index]
            current_label = labels[index]
            # create the required Example representation
            out = parse_single_image(image=current_image, label=current_label)
            writer.write(out.SerializeToString())
            current_shard_count += 1
            file_count += 1
        writer.close()
    print(f"\nWrote {file_count} elements to TFRecord")
    return file_count
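# The writer above relies on a helper `parse_single_image` that packs one image
# and its label into a tf.train.Example, but the gist does not define it. The
# following is a minimal sketch, not the author's original code, assuming the
# images are uint8 arrays: it stores the serialized pixel buffer plus the shape
# so the reader can reconstruct the tensor later.
def _bytes_feature(value):
    # wrap a byte string in a BytesList feature
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    # wrap a scalar int in an Int64List feature
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def parse_single_image(image, label):
    # serialize the raw pixels together with their dimensions and the label
    data = {
        'height': _int64_feature(image.shape[0]),
        'width': _int64_feature(image.shape[1]),
        'depth': _int64_feature(image.shape[2]),
        'raw_image': _bytes_feature(tf.io.serialize_tensor(image).numpy()),
        'label': _int64_feature(label),
    }
    return tf.train.Example(features=tf.train.Features(feature=data))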
### -------------------- Recover the image data from multiple TFRecord files ---------------
def get_dataset_large(tfr_dir: str = "/content/", pattern: str = "*large_images.tfrecords"):
    files = glob.glob(tfr_dir + pattern, recursive=False)
    # create the dataset from all matching shard files
    dataset = tf.data.TFRecordDataset(files)
    # pass every single serialized Example through our mapping function
    dataset = dataset.map(parse_tfr_element)
    return dataset
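# `parse_tfr_element` is also referenced but not defined in this gist. Below is
# a matching sketch for the hypothetical Example layout used in
# `parse_single_image` above: decode the fixed-length features, deserialize the
# raw bytes, and restore the original image shape.
def parse_tfr_element(element):
    feature_description = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'raw_image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    content = tf.io.parse_single_example(element, feature_description)
    # undo tf.io.serialize_tensor and recover the (height, width, depth) layout
    image = tf.io.parse_tensor(content['raw_image'], out_type=tf.uint8)
    image = tf.reshape(image, shape=[content['height'], content['width'], content['depth']])
    return (image, content['label'])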
## Test: read one sample back and inspect its shapes
dataset_large = get_dataset_large()
for sample in dataset_large.take(1):
    print(sample[0].shape)
    print(sample[1].shape)
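# For training, the shard-backed dataset composes with the usual tf.data
# pipeline. A minimal sketch (buffer size and batch size are arbitrary choices,
# not from the gist):
train_ds = dataset_large.shuffle(buffer_size=1024).batch(32).prefetch(tf.data.AUTOTUNE)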