文本目录
利用JavaCV+OpenCV的ANN_MLP神经网络训练识别MNIST手写数字
JavaCV是可以在java中使用OpenCV的一个库。OpenCV是一个跨平台的开源计算机视觉和机器学习软件库。白话就是一个处理图片和进行人工智能识别图片的一个软件库。
MNIST手写数字数据集
MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。
下载地址:http://yann.lecun.com/exdb/mnist/
做为许多神经网络学习的入门数据,一直没找到javaCV的相关例子。
用IDEA架设创建JAVACV的开发环境,请参考:在IDEA和Android Studio中用Gradle构建javacv开发环境
读取MNIST数据
MNIST数据有4个文件,分别为训练和测试的图片和标签。关于用java读取的方法可参考文章:
上面文章中介绍有MNIST文件的格式等信息,在这里就不再重复。
从MNIST生成训练图片Mat和标签Mat数据
JavaCV训练时的所有数据,都用Mat的形式提供。说白了就是一个float数组。注意神经网训练时,最好用float数据,MNIST数据集是一个byte数组,这里需要转换一下。
图片Mat的格式
x = new Mat(number, size, CvType.CV_32FC1);
number是样本数量,做为mat的行数
size是图片像素点数,即28*28。每个样本图片生成一个单行的数组放入Mat中
CV_32FC1是数据类型,为32位的float数据
完整的代码如下:
/** * 生成训练数据 * * @param fileName the file of 'train' or 'test' about image * @return one row show a `picture` */ public static Mat getTrainData(String fileName) { Mat x = null; try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) { byte[] bytes = new byte[4]; bin.read(bytes, 0, 4); if (!"00000803".equals(bytesToHex(bytes))) { // 读取魔数 throw new RuntimeException("Please select the correct file!"); } else { bin.read(bytes, 0, 4); // 读取样本总数 int number = Integer.parseInt(bytesToHex(bytes), 16); bin.read(bytes, 0, 4); // 读取每行所含像素点数 int xPixel = Integer.parseInt(bytesToHex(bytes), 16); bin.read(bytes, 0, 4); // 读取每列所含像素点数 int yPixel = Integer.parseInt(bytesToHex(bytes), 16); int l = xPixel*yPixel; x = new Mat(number, l, CvType.CV_32FC1); FloatIndexer indexer = x.createIndexer(); for (int i = 0; i < number; i++) { for(int j=0; j<l; j++){ indexer.put(i, j, bin.read()); } } } } catch (IOException e) { throw new RuntimeException(e); } return x; }
标签Mat数据格式
Mat x = new Mat(data.length, 10, CvType.CV_32FC1);
data.length是样本数量,做为mat的行数
10是每个标签的数据量,即为float[10]。每个标签成一个单行的数组放入Mat中
CV_32FC1是数据类型,为32位的float数据
完整的代码如下:
/** * 获取训练的标签 * 格式要求,每个标签为一个 float[10]数组,放在Mat的一行中 * @param fileName * @return */ public static Mat getTrainLabels(String fileName) { byte[] data = getLabels(fileName); Mat x = new Mat(data.length, 10, CvType.CV_32FC1); FloatIndexer indexer = x.createIndexer(); for(int i=0; i<data.length; i++){ byte b = (byte) data[i]; for(int j=0; j<10; j++){ if(j==b) indexer.put(i, j, 1); else indexer.put(i, j, 0); } } return x; } /** * 获取所有标签的数值 * * @param fileName the file of 'train' or 'test' about label * @return */ public static byte[] getLabels(String fileName) { byte[] y = null; try (BufferedInputStream bin = new BufferedInputStream(new FileInputStream(fileName))) { byte[] bytes = new byte[4]; bin.read(bytes, 0, 4); if (!"00000801".equals(bytesToHex(bytes))) { throw new RuntimeException("Please select the correct file!"); } else { bin.read(bytes, 0, 4); int number = Integer.parseInt(bytesToHex(bytes), 16); y = new byte[number]; byte c; for (int i = 0; i < number; i++) { c = (byte) bin.read(); y[i] = c; } } } catch (IOException e) { throw new RuntimeException(e); } return y; }
创建ANN_MLP神网络训练数据
创建了一个四层的神经网络,神经元个数分别为 { 28*28 , 512, 256, 10 } ,分别为:
输入层,对应着每个像素,所以是28*28
隐含层两个,神经元个数分别为 512 和 256
输出层,和训练的标签对应,神经元为10个,即数字 0123456789
具体代码如下:
/** * 训练数据 * @param xml 要保存的数据文件名 */ public static void train(String xml){ opencv_core.Mat trainData = MnistRead.getTrainData(TRAIN_IMAGES_FILE); opencv_core.Mat lables = MnistRead.getTrainLabels(TRAIN_LABELS_FILE); opencv_ml.ANN_MLP mlp= opencv_ml.ANN_MLP.create(); int image_cols = 28; //图片宽 int image_rows = 28; //图片高 int class_num = 10; //预测的结果,为 float[10] 数组 /* * 神经网络层 * */ int[] layer={ image_cols*image_rows , 512, 256, class_num}; opencv_core.Mat layerSizes=new opencv_core.Mat(1, layer.length, CV_32FC1); org.bytedeco.javacpp.indexer.FloatIndexer indexer = layerSizes.createIndexer(); for(int i=0;i<layer.length;i++){ indexer.put(i, layer[i]); } mlp.setLayerSizes(layerSizes); mlp.setActivationFunction(opencv_ml.ANN_MLP.SIGMOID_SYM); mlp.train(trainData, ROW_SAMPLE, lables); /* * 开始训练 * */ mlp.save(xml); mlp.clear(); System.out.println("训练结束"); }
正确率测试
数据格式和训练时一样,就不做解释了,代码如下:
/** * 使用测试数据,测试识别率 * @param xml 训练好的数据文件 */ public static void test(String xml){ opencv_ml.ANN_MLP ann = opencv_ml.ANN_MLP.load(xml); opencv_core.Mat predictData = MnistRead.getTrainData(TEST_IMAGES_FILE); byte[] predictLables = MnistRead.getLabels(TEST_LABELS_FILE); //正确计数 int rc = 0; for(int i=0; i<predictData.rows(); i++){ opencv_core.Mat sample = predictData.row(i); opencv_core.Mat predict = new opencv_core.Mat(); ann.predict(sample, predict, UPDATE_MODEL); if(predictLables[i] == getMaxIndex(predict)){ //预测正确 rc++; } } //计算正确率 double zql = rc*1.0/predictData.rows(); System.out.println("正确率:" + zql); }
源代码下载
本示例源代码已在GITEE上开源,大家可以免费下验证:
https://gitee.com/zizai/StudyJavaCV