本文共 2822 字,大约阅读时间需要 9 分钟。
data层用在训练或测试阶段,为模型提供数据接口,caffe可以接受的数据类型包括数据库类型(如LMDB、LevelDB)、hdf5、内存数据、图片数据等。
该类型数据必须指定数据库文件夹路径,该文件夹内包含一个data.mdb文件和一个lock.mdb文件;还需要指定batch_size.
可选参数包括:rand_skip: 在开始的时候,路过某个数据的输入。通常对异步的SGD很有用。
backend: 选择是采用LevelDB还是LMDB, 默认是LevelDB.
prototxt文件对应内容:layer { name: "mnist" type: "Data" top: "data" top: "label" include { phase: TRAIN } transform_param { scale: 0.003 mean_value: 52 } data_param { source: "examples/mnist/mnist_train_lmdb" batch_size: 16 backend: LMDB }}
用python API定义该层代码:
n.data, n.label = caffe.layers.Data(batch_size=16, source= "examples/mnist/mnist_train_lmdb", ntop=2, backend = P.Data.LMDB, include=dict(phase=caffe.TRAIN), transform_param=dict(scale=0.003, mean_value=52))
同样需要指定扩展名为h5的数据文件路径,此外,也可指定包含多个h5路径的文件;也需要指定batch_size.
prototxt文件对应内容:layer { name: "InputData" type: "HDF5Data" top: "data" top: "label" include { phase: TRAIN } hdf5_data_param { source: "./training_data_paths.txt" batch_size: 64 }}
python API代码:
net.data, net.label = caffe.layers.HDF5Data( name="InputData", source='./training_data_paths.txt', batch_size=64, include=dict(phase=caffe.TRAIN), ntop=2 )
这种类型经常用来做分类任务,图片数据文件每一行给出图片的路径及该图片对应的类别。
layer { name: "InputData" type: "ImageData" top: "data" top: "label" transform_param { mirror: true crop_size: 40 } image_data_param { source: "train.txt" batch_size: 32 shuffle: true new_height: 48 new_width: 48 is_color: true root_folder: "/" }}
python API代码:
net.data ,net.label = caffe.layers.ImageData( name="InputData" source="train.txt", batch_size=32, new_width=48, new_height=48, ntop=2, is_color=True, shuffle=True, root_folder='/', transform_param=dict(crop_size=40,mirror=True))
直接用内存中的数据训练模型,这类数据往往是ndarray型。
layer { name: "data" type: "ImageData" top: "data" top: "label" transform_param { scale: 0.00390625 mean_value: 20.0 } image_data_param { source: "img_list" } memory_data_param { batch_size: 16 channels: 1 height: 1 width: 230 }}
python API代码:
def conv_pool_net(): n = caffe.NetSpec() n.data, n.label = L.ImageData(source='img_list', memory_data_param=dict(batch_size=16, height=1, width=230, channels=1), ntop=2,transform_param=dict(scale=0.00390625, mean_value=20)) return n.to_proto()print(str(conv_pool_net()))
部署时,data层要做一下转换(图片来自网络):
最后来张网上的总结图片,总结的很好:参考:
转载地址:http://forti.baihongyu.com/