专注Java领域技术
我们一直在努力

图像集存储成MNIST数据集格式实现

原文始发于:图像集存储成MNIST数据集格式实现

有时需要将一组图像存储成MNIST那样的数据格式,以便用于网络的训练和测试。例如MNIST中的测试集标签t10k-labels.idx1-ubyte和测试集图像t10k-images.idx3-ubyte,各包含10000个样本。这里以这两个测试集文件为例,详细说明实现过程:

http://yann.lecun.com/exdb/mnist/  中对MNIST的数据存放格式进行了介绍。文件中的所有整数都以大多数非Intel处理器所使用的MSB优先(大端,big-endian)格式存储,使用Intel处理器及其他小端(little-endian)机器的用户必须翻转文件头中整数的字节顺序(All the integers in the files are stored in the MSB first(high endian) format used by most non-Intel processors. Users of Intel processors and other low-endian machines must flip the bytes of the header.)。

t10k-labels.idx1-ubyte(训练集标签train-labels.idx1-ubyte与此存放格式完全相同):第1至第4个字节存放magic number(MSB first);第5至第8个字节存放标签数即10000;从第9个字节开始,每个字节存放一个标签值(label value),标签值的范围为0到9。

此处的magic number(MSB first)是一个4字节整数,用于标识IDX文件格式。第1、第2个字节总是0。第3个字节表示数据类型,如0x08表示unsigned byte,0x09表示signed byte,0x0B表示short(2 bytes),0x0C表示int(4 bytes),0x0D表示float(4 bytes),0x0E表示double(8 bytes);因为t10k-labels.idx1-ubyte中标签值范围为0到9,所以这里第3个字节为0x08。第4个字节表示向量/矩阵的维数,1表示向量,2表示矩阵,依此类推;这里的标签是一维向量,因此第4个字节为0x01。t10k-labels.idx1-ubyte的前8个字节是两个4字节整数:magic number和标签数。

打开t10k-labels.idx1-ubyte二进制文件,前8个字节数据是:00 00 08 01 00 00 27 10(0x2710即10000)。需要注意的是,magic number和标签数都是4字节int,在文件中按高字节在前、低字节在后的顺序存储。而在Intel等小端机器上,一次性读取4个字节到int时,靠前的字节会被当作低字节,与文件中的存储顺序正好相反。因此在读或写这些整数时需要做一次转换,即高低字节互换,实现见ReverseInt函数。

t10k-images.idx3-ubyte(训练集图像train-images.idx3-ubyte与此存放格式完全相同):第1至第4个字节存放magic number(MSB first);第5至第8个字节存放图像数即10000;第9至第12个字节存放每个图像的行数即高,这里为28;第13至第16个字节存放每个图像的列数即宽,这里为28;从第17个字节开始,每个字节存放一个像素值,像素值的范围为0到255,0表示背景,255表示前景,像素按行排列;每28*28个字节大小存放一幅图像数据。

此处的magic number(MSB first)也是一个4字节整数,用于标识IDX文件格式。第1、第2个字节总是0。第3个字节表示数据类型,如0x08表示unsigned byte,0x09表示signed byte,0x0B表示short(2 bytes),0x0C表示int(4 bytes),0x0D表示float(4 bytes),0x0E表示double(8 bytes);因为t10k-images.idx3-ubyte中图像像素值范围为0到255,所以这里第3个字节为0x08。第4个字节表示数据的维数,1表示向量,2表示矩阵,依此类推;这里的数据是三维的,即图像数×行数×列数(number×height×width),因此第4个字节为0x03。t10k-images.idx3-ubyte的前16个字节是四个4字节整数:magic number、图像数、行数和列数。打开t10k-images.idx3-ubyte二进制文件,前16个字节数据是:00 00 08 03 00 00 27 10 00 00 00 1c 00 00 00 1c(0x2710即10000,0x1c即28)。

测试代码如下:

#include "funset.hpp"
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <fstream>
#include <vector>
#include <memory>
#include <string>
#include <opencv2/opencv.hpp>

// MNIST (IDX file format) reading/writing helpers
namespace {

// IDX magic numbers: bytes 0-1 are zero, byte 2 is the element type
// (0x08 = unsigned byte), byte 3 is the number of dimensions.
constexpr int kLabelMagicNumber = 2049; // 0x00000801: unsigned byte, 1 dimension
constexpr int kImageMagicNumber = 2051; // 0x00000803: unsigned byte, 3 dimensions

// Swaps the byte order of a 32-bit integer (0x11223344 -> 0x44332211).
// IDX files store integers most-significant byte first; this file assumes a
// little-endian host (e.g. x86), where this swap converts between file order
// and native order in both directions.
int ReverseInt(int i)
{
	const std::uint32_t u = static_cast<std::uint32_t>(i);
	const std::uint32_t swapped = (u << 24) |
		((u & 0x0000FF00u) << 8) |
		((u >> 8) & 0x0000FF00u) |
		(u >> 24);
	return static_cast<int>(swapped);
}

// Reads one 4-byte big-endian integer into `value`.
// Returns false (leaving `value` untouched) if fewer than 4 bytes are available.
bool ReadBigEndianInt(std::istream& in, int& value)
{
	int raw = 0;
	if (!in.read(reinterpret_cast<char*>(&raw), sizeof(raw)))
		return false;
	value = ReverseInt(raw);
	return true;
}

// Writes `value` as a 4-byte big-endian integer; errors are left in the stream state.
void WriteBigEndianInt(std::ostream& out, int value)
{
	const int raw = ReverseInt(value);
	out.write(reinterpret_cast<const char*>(&raw), sizeof(raw));
}

// Reads an IDX3 image file (e.g. t10k-images.idx3-ubyte) and appends each
// image to `vec` as an n_rows x n_cols CV_8UC1 matrix.
// If the file cannot be opened, its header is invalid, or the pixel data is
// truncated, a message is printed to stderr and reading stops. Images read
// before the failure stay in `vec`, so callers should check vec.size().
void read_Mnist(std::string filename, std::vector<cv::Mat> &vec)
{
	std::ifstream file(filename, std::ios::binary);
	if (!file.is_open()) {
		fprintf(stderr, "Error: open file fail: %s\n", filename.c_str());
		return;
	}

	// 16-byte header: magic number, image count, rows, cols.
	int magic_number = 0;
	int number_of_images = 0;
	int n_rows = 0;
	int n_cols = 0;
	if (!ReadBigEndianInt(file, magic_number) || !ReadBigEndianInt(file, number_of_images) ||
		!ReadBigEndianInt(file, n_rows) || !ReadBigEndianInt(file, n_cols)) {
		fprintf(stderr, "Error: truncated IDX image header: %s\n", filename.c_str());
		return;
	}
	if (magic_number != kImageMagicNumber || number_of_images < 0 || n_rows <= 0 || n_cols <= 0) {
		fprintf(stderr, "Error: invalid IDX image header: %s\n", filename.c_str());
		return;
	}

	for (int i = 0; i < number_of_images; ++i) {
		// A freshly allocated Mat is continuous and row-major, matching the
		// row-by-row pixel order in the file, so one read fills the image.
		cv::Mat tp(n_rows, n_cols, CV_8UC1);
		if (!file.read(reinterpret_cast<char*>(tp.data), static_cast<std::streamsize>(tp.total()))) {
			fprintf(stderr, "Error: %s ends after %d of %d images\n", filename.c_str(), i, number_of_images);
			break;
		}
		vec.push_back(tp);
	}
}

// Reads an IDX1 label file (e.g. t10k-labels.idx1-ubyte) into `vec`.
// Label k is stored in vec[k]. If `vec` is shorter than the label count it is
// grown as labels are read. Existing entries beyond the count are left untouched.
// If the file cannot be opened, its header is invalid, or the data is truncated,
// a message is printed to stderr and reading stops.
void read_Mnist_Label(std::string filename, std::vector<int> &vec)
{
	std::ifstream file(filename, std::ios::binary);
	if (!file.is_open()) {
		fprintf(stderr, "Error: open file fail: %s\n", filename.c_str());
		return;
	}

	// 8-byte header: magic number, label count.
	int magic_number = 0;
	int number_of_labels = 0;
	if (!ReadBigEndianInt(file, magic_number) || !ReadBigEndianInt(file, number_of_labels)) {
		fprintf(stderr, "Error: truncated IDX label header: %s\n", filename.c_str());
		return;
	}
	if (magic_number != kLabelMagicNumber || number_of_labels < 0) {
		fprintf(stderr, "Error: invalid IDX label header: %s\n", filename.c_str());
		return;
	}

	for (int i = 0; i < number_of_labels; ++i) {
		unsigned char label = 0;
		if (!file.read(reinterpret_cast<char*>(&label), sizeof(label))) {
			fprintf(stderr, "Error: %s ends after %d of %d labels\n", filename.c_str(), i, number_of_labels);
			break;
		}
		// Never index past the end of a caller-supplied vector.
		if (static_cast<size_t>(i) < vec.size())
			vec[i] = label;
		else
			vec.push_back(label);
	}
}

// Builds a name of the form "<digit>_<count>", e.g. "3_00012".
// arr[number] is a per-class counter that is incremented first. The count is
// zero-padded to 5 digits; counters are assumed to be non-negative.
// `arr` must hold at least 10 elements. For `number` outside 0..9 no counter
// is touched and the result is "<number>_".
std::string GetImageName(int number, int arr[])
{
	std::string count;
	if (number >= 0 && number < 10) {
		++arr[number];
		count = std::to_string(arr[number]);
		if (count.size() < 5)
			count.insert(0, 5 - count.size(), '0');
	}

	return std::to_string(number) + "_" + count;
}

// Writes the first `image_number` images of `image_data` to `file_name` in
// IDX3 format. The file starts with a 16-byte big-endian header (magic number,
// image count, rows, cols), followed by each image's pixels, row by row.
// Every image written must be CV_8UC1 of size image_rows x image_cols.
// Returns 0 on success, -1 on invalid arguments or I/O failure.
int write_images_to_file(const std::string& file_name, const std::vector<cv::Mat>& image_data,
	int magic_number, int image_number, int image_rows, int image_cols)
{
	if (image_number < 0 || image_rows <= 0 || image_cols <= 0) {
		fprintf(stderr, "Error: invalid arguments: image_number: %d, image_rows: %d, image_cols: %d\n",
			image_number, image_rows, image_cols);
		return -1;
	}
	if (static_cast<size_t>(image_number) > image_data.size()) {
		fprintf(stderr, "Error: image_number > image_data.size(): image_number: %d, image_data.size: %zu\n",
			image_number, image_data.size());
		return -1;
	}
	// Validate every image before creating the output, so bad input never
	// leaves a partially written file or reads past a Mat's buffer.
	for (int i = 0; i < image_number; ++i) {
		const cv::Mat& img = image_data[i];
		if (img.type() != CV_8UC1 || img.rows != image_rows || img.cols != image_cols) {
			fprintf(stderr, "Error: image %d is not a %dx%d CV_8UC1 image\n", i, image_rows, image_cols);
			return -1;
		}
	}

	std::ofstream file(file_name, std::ios::binary);
	if (!file.is_open()) {
		fprintf(stderr, "Error: open file fail: %s\n", file_name.c_str());
		return -1;
	}

	WriteBigEndianInt(file, magic_number);
	WriteBigEndianInt(file, image_number);
	WriteBigEndianInt(file, image_rows);
	WriteBigEndianInt(file, image_cols);

	for (int i = 0; i < image_number; ++i) {
		const cv::Mat& img = image_data[i];
		// Row by row, so non-continuous Mats (e.g. ROIs) are written correctly.
		for (int r = 0; r < image_rows; ++r)
			file.write(reinterpret_cast<const char*>(img.ptr<uchar>(r)), image_cols);
	}

	file.close(); // flush now so write errors are visible below
	if (!file) {
		fprintf(stderr, "Error: write file fail: %s\n", file_name.c_str());
		return -1;
	}
	return 0;
}

// Writes the first `label_number` labels of `label_data` to `file_name` in
// IDX1 format. The file starts with an 8-byte big-endian header (magic number,
// label count), followed by one unsigned byte per label.
// Labels must be in 0..255. Returns 0 on success, -1 on invalid arguments or I/O failure.
int write_labels_to_file(const std::string& file_name, const std::vector<int>& label_data,
	int magic_number, int label_number)
{
	if (label_number < 0) {
		fprintf(stderr, "Error: invalid label_number: %d\n", label_number);
		return -1;
	}
	if (static_cast<size_t>(label_number) > label_data.size()) {
		fprintf(stderr, "Error: label_number > label_data.size(): label_number: %d, label_data.size: %zu\n",
			label_number, label_data.size());
		return -1;
	}

	std::vector<unsigned char> labels(static_cast<size_t>(label_number));
	for (int i = 0; i < label_number; ++i) {
		// Each label is stored as a single unsigned byte; reject values that would wrap.
		if (label_data[i] < 0 || label_data[i] > 255) {
			fprintf(stderr, "Error: label %d out of range [0, 255]: %d\n", i, label_data[i]);
			return -1;
		}
		labels[i] = static_cast<unsigned char>(label_data[i]);
	}

	std::ofstream file(file_name, std::ios::binary);
	if (!file.is_open()) {
		fprintf(stderr, "Error: open file fail: %s\n", file_name.c_str());
		return -1;
	}

	WriteBigEndianInt(file, magic_number);
	WriteBigEndianInt(file, label_number);
	file.write(reinterpret_cast<const char*>(labels.data()), static_cast<std::streamsize>(labels.size()));

	file.close(); // flush now so write errors are visible below
	if (!file) {
		fprintf(stderr, "Error: write file fail: %s\n", file_name.c_str());
		return -1;
	}
	return 0;
}
} // namespace

// MNIST round trip: reads the t10k test images and labels, then writes them back
// in IDX format as new_t10k-images.idx3-ubyte / new_t10k-labels.idx1-ubyte, so the
// outputs can be compared byte-for-byte (e.g. via md5) with the originals.
// Returns 0 on success, -1 on any read or write failure.
int ImageToMNIST()
{
	// read images
#ifdef _MSC_VER
	std::string filename_test_images = "E:/GitCode/NN_Test/data/database/MNIST/t10k-images.idx3-ubyte";
#else
	std::string filename_test_images = "data/database/MNIST/t10k-images.idx3-ubyte";
#endif
	const int number_of_test_images = 10000;
	std::vector<cv::Mat> vec_test_images;

	read_Mnist(filename_test_images, vec_test_images);
	if (vec_test_images.size() != static_cast<size_t>(number_of_test_images)) {
		fprintf(stderr, "Error: fail to parse t10k-images.idx3-ubyte file: %zu\n", vec_test_images.size());
		return -1;
	}

	// read labels
#ifdef _MSC_VER
	std::string filename_test_labels = "E:/GitCode/NN_Test/data/database/MNIST/t10k-labels.idx1-ubyte";
#else
	std::string filename_test_labels = "data/database/MNIST/t10k-labels.idx1-ubyte";
#endif
	// Start empty so the resulting size tells us how many labels were actually read.
	std::vector<int> vec_test_labels;

	read_Mnist_Label(filename_test_labels, vec_test_labels);
	if (vec_test_labels.size() != static_cast<size_t>(number_of_test_images)) {
		fprintf(stderr, "Error: fail to parse t10k-labels.idx1-ubyte file: %zu\n", vec_test_labels.size());
		return -1;
	}

	// write images
	const int image_rows = 28;
	const int image_cols = 28;
#ifdef _MSC_VER
	const std::string images_save_file_name = "E:/GitCode/NN_Test/data/new_t10k-images.idx3-ubyte";
#else
	const std::string images_save_file_name = "data/new_t10k-images.idx3-ubyte";
#endif

	if (write_images_to_file(images_save_file_name, vec_test_images, kImageMagicNumber,
		number_of_test_images, image_rows, image_cols) != 0) {
		fprintf(stderr, "Error: write images to file fail\n");
		return -1;
	}

	// write labels
#ifdef _MSC_VER
	const std::string labels_save_file_name = "E:/GitCode/NN_Test/data/new_t10k-labels.idx1-ubyte";
#else
	const std::string labels_save_file_name = "data/new_t10k-labels.idx1-ubyte";
#endif

	if (write_labels_to_file(labels_save_file_name, vec_test_labels, kLabelMagicNumber,
		number_of_test_images) != 0) {
		fprintf(stderr, "Error: write labels to file fail\n");
		return -1;
	}

	return 0;
}

新生成的两个数据文件为new_t10k-labels.idx1-ubyte和new_t10k-images.idx3-ubyte,通过md5可知,新生成的文件与原始文件完全相同,结果如下:

图像集存储成MNIST数据集格式实现

GitHub: https://github.com/fengbingchun/NN_Test 

赞(0) 打赏
未经允许不得转载:Java小咖秀 » 图像集存储成MNIST数据集格式实现
免责声明

抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址

专注Java技术 100年

联系我们联系我们

你默默的关注就是最好的打赏~

支付宝扫一扫打赏

微信扫一扫打赏