单标签写入读取
# -*- coding: utf-8 -*-
import lmdb
import caffe
from matplotlib import pyplot as plt
import numpy as np
def write_lmdb(filename, X, y):
    """Write samples and integer labels to an LMDB as serialized Caffe Datums.

    Args:
        filename: Directory of the LMDB environment to create or open.
        X: Array indexed as ``X[i, :, :, :]``; each slice is converted with
            ``caffe.io.array_to_datum``.
        y: Sequence of labels; ``y[i]`` is stored as ``int`` in the Datum of
            sample ``i``. Its length determines how many records are written.

    Records are keyed by the zero-padded 10-digit sample index, so the
    database order follows the sample order.
    """
    # LMDB needs an upper bound on the database size; 10x the raw array size
    # is used as that bound.
    map_size = X.nbytes * 10
    env = lmdb.open(filename, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for i, label in enumerate(y):
                datum = caffe.io.array_to_datum(X[i, :, :, :])
                datum.label = int(label)
                key = '{:0>10d}'.format(i).encode('ascii')
                txn.put(key, datum.SerializeToString())
    finally:
        # Always release the environment so the same path can be reopened
        # later in this process (e.g. for reading).
        env.close()
def read_lmdb(filename):
    """Read an LMDB of serialized Caffe Datums and return its last record.

    Every record is parsed in key order; the array and label of the final
    record are returned.

    Args:
        filename: Directory of the LMDB environment to open read-only.

    Returns:
        Tuple ``(x, y)``: ``x`` is the array produced by
        ``caffe.io.datum_to_array`` and ``y`` is the Datum's ``label``.

    Raises:
        ValueError: If the database contains no records.
    """
    env = lmdb.open(filename, readonly=True)
    try:
        x = y = None
        with env.begin(write=False) as txn:
            datum = caffe.proto.caffe_pb2.Datum()
            for _key, value in txn.cursor():
                datum.ParseFromString(value)
                x = caffe.io.datum_to_array(datum)
                y = datum.label
    finally:
        # Release the environment handle even if parsing fails.
        env.close()
    if x is None:
        raise ValueError('LMDB {!r} contains no records'.format(filename))
    return x, y
def main():
    """Build a synthetic 4-class dataset, write train/test LMDBs, read one back.

    Each class has 1000 samples of shape (3, 32, 32) filled with random
    integers in [1, 9] shifted by ``10 * class_id``, so the classes occupy
    disjoint value ranges. The shuffled data is split 80/20 into
    ``hbk_lmdb_train`` and ``hbk_lmdb_test``.
    """
    samples_per_class = 1000
    features = []
    labels = []
    for class_id in range(4):
        block = np.random.randint(1, 10, (samples_per_class, 3, 32, 32))
        features.append(block + 10 * class_id)
        labels.append(np.full(samples_per_class, class_id, dtype=np.int64))
    X = np.vstack(features)
    y = np.hstack(labels)

    idx = np.arange(len(y))
    np.random.shuffle(idx)
    train_num = 4 * len(y) // 5
    train_idx, test_idx = idx[:train_num], idx[train_num:]

    write_lmdb("hbk_lmdb_train", X[train_idx, :, :, :], y[train_idx])
    # The test set must pair the held-out images with the held-out labels;
    # reusing the training images here would mislabel every test record.
    write_lmdb("hbk_lmdb_test", X[test_idx, :, :, :], y[test_idx])

    X1, y1 = read_lmdb("hbk_lmdb_train")
    print(X1.shape, y1)
    print(np.mean(X))


if __name__ == '__main__':
    main()
多标签写入
import numpy as np
import lmdb
import caffe
def write_lmdb_data(filename, X):
    """Write each slice of a 4-D array to an LMDB as a serialized Caffe Datum.

    No label is attached to the Datums; the caller writes images and label
    tensors into separate databases by calling this function once per array.

    Args:
        filename: Directory of the LMDB environment to create or open.
        X: Array indexed as ``X[i, :, :, :]``; ``X.shape[0]`` records are
            written, each converted with ``caffe.io.array_to_datum``.

    Records are keyed by the zero-padded 10-digit sample index, so two
    databases written from equally ordered arrays stay aligned by key.
    """
    N = X.shape[0]
    # LMDB needs an upper bound on the database size; 10x the raw array size
    # is used as that bound.
    map_size = X.nbytes * 10
    env = lmdb.open(filename, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for i in range(N):
                datum = caffe.io.array_to_datum(X[i, :, :, :])
                key = '{:0>10d}'.format(i).encode('ascii')
                txn.put(key, datum.SerializeToString())
    finally:
        # Always release the environment so the path can be reopened later.
        env.close()
if __name__ == '__main__':
    # Synthetic multi-label dataset: four groups of N samples, each group
    # with its own value range and its own 10-slot multi-hot label pattern.
    N = 1000
    image_shape = (N, 3, 32, 32)
    label_shape = (N, 10, 1, 1)

    # Images are random integers in [1, 9], shifted per group.
    image_groups = [
        np.random.randint(1, 10, image_shape) + offset
        for offset in (0, 10, 20, 30)
    ]

    # Label slots set to 1 for each group:
    #   group 0 -> none, group 1 -> 1 and 3, group 2 -> 0 and 2, group 3 -> all
    active_slots = ((), (1, 3), (0, 2), tuple(range(10)))
    label_groups = []
    for slots in active_slots:
        group_labels = np.zeros(label_shape, dtype=np.int64)
        for slot in slots:
            group_labels[:, slot, :, :] = 1
        label_groups.append(group_labels)

    X = np.concatenate(image_groups, axis=0)
    y = np.concatenate(label_groups, axis=0)

    # Shuffle sample indices, then split 80/20 into train and test.
    idx = np.arange(len(y))
    np.random.shuffle(idx)
    TRAIN_NUM = int(4 * len(y) / 5)
    train_idx = idx[:TRAIN_NUM]
    test_idx = idx[TRAIN_NUM:]

    # Images and labels go to separate databases that share the same keys.
    write_lmdb_data("lmdb_train_data", X[train_idx])
    write_lmdb_data("lmdb_train_label", y[train_idx])
    write_lmdb_data("lmdb_test_data", X[test_idx])
    write_lmdb_data("lmdb_test_label", y[test_idx])

    print(np.mean(X))
说明:读取数据时报错,该问题尚未解决。