利用 Python 脚本生成 .h5 文件
1 import os, json, argparse 2 from threading import Thread 3 from Queue import Queue 4 5 import numpy as np 6 from scipy.misc import imread, imresize 7 import h5py 8 9 """ 10 Create an HDF5 file of images for training a feedforward style transfer model. 11 """ 12 13 parser = argparse.ArgumentParser() 14 parser.add_argument('--train_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/train2014') 15 parser.add_argument('--val_dir', default='/media/wangxiao/WangXiao_Dataset/CoCo/val2014') 16 parser.add_argument('--output_file', default='/media/wangxiao/WangXiao_Dataset/CoCo/coco-256.h5') 17 parser.add_argument('--height', type=int, default=256) 18 parser.add_argument('--width', type=int, default=256) 19 parser.add_argument('--max_images', type=int, default=-1) 20 parser.add_argument('--num_workers', type=int, default=2) 21 parser.add_argument('--include_val', type=int, default=1) 22 parser.add_argument('--max_resize', default=16, type=int) 23 args = parser.parse_args() 24 25 26 def add_data(h5_file, image_dir, prefix, args): 27 # Make a list of all images in the source directory 28 image_list = [] 29 image_extensions = {'.jpg', '.jpeg', '.JPG', '.JPEG', '.png', '.PNG'} 30 for filename in os.listdir(image_dir): 31 ext = os.path.splitext(filename)[1] 32 if ext in image_extensions: 33 image_list.append(os.path.join(image_dir, filename)) 34 num_images = len(image_list) 35 36 # Resize all images and copy them into the hdf5 file 37 # We'll bravely try multithreading 38 dset_name = os.path.join(prefix, 'images') 39 dset_size = (num_images, 3, args.height, args.width) 40 imgs_dset = h5_file.create_dataset(dset_name, dset_size, np.uint8) 41 42 # input_queue stores (idx, filename) tuples, 43 # output_queue stores (idx, resized_img) tuples 44 input_queue = Queue() 45 output_queue = Queue() 46 47 # Read workers pull images off disk and resize them 48 def read_worker(): 49 while True: 50 idx, filename = input_queue.get() 51 img = imread(filename) 52 try: 53 # First crop the image 
so its size is a multiple of max_resize 54 H, W = img.shape[0], img.shape[1] 55 H_crop = H - H % args.max_resize 56 W_crop = W - W % args.max_resize 57 img = img[:H_crop, :W_crop] 58 img = imresize(img, (args.height, args.width)) 59 except (ValueError, IndexError) as e: 60 print filename 61 print img.shape, img.dtype 62 print e 63 input_queue.task_done() 64 output_queue.put((idx, img)) 65 66 # Write workers write resized images to the hdf5 file 67 def write_worker(): 68 num_written = 0 69 while True: 70 idx, img = output_queue.get() 71 if img.ndim == 3: 72 # RGB image, transpose from H x W x C to C x H x W 73 imgs_dset[idx] = img.transpose(2, 0, 1) 74 elif img.ndim == 2: 75 # Grayscale image; it is H x W so broadcasting to C x H x W will just copy 76 # grayscale values into all channels. 77 imgs_dset[idx] = img 78 output_queue.task_done() 79 num_written = num_written + 1 80 if num_written % 100 == 0: 81 print 'Copied %d / %d images' % (num_written, num_images) 82 83 # Start the read workers. 84 for i in xrange(args.num_workers): 85 t = Thread(target=read_worker) 86 t.daemon = True 87 t.start() 88 89 # h5py locks internally, so we can only use a single write worker =( 90 t = Thread(target=write_worker) 91 t.daemon = True 92 t.start() 93 94 for idx, filename in enumerate(image_list): 95 if args.max_images > 0 and idx >= args.max_images: break 96 input_queue.put((idx, filename)) 97 98 input_queue.join() 99 output_queue.join() 100 101 102 103 if __name__ == '__main__': 104 105 with h5py.File(args.output_file, 'w') as f: 106 add_data(f, args.train_dir, 'train2014', args) 107 108 if args.include_val != 0: 109 add_data(f, args.val_dir, 'val2014', args)
时间: 2024-10-23 19:56:42