1. 前言
本博文介绍的脚本能够较为方便地以点矢量为中心,在指定区域批量地将遥感影像裁剪成固定大小的切片,并为每个切片生成对应的标注矢量文件(shp)。
2. 样本准备
准备待裁剪的影像以及对应的点矢量(shp 文件)。每个点作为一个切片的中心,点矢量需与影像使用相同的坐标系。
3. 基于gdal的裁剪代码
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import glob
import math
import os
import sys
import time

import cv2
import numpy
import numpy as np
from osgeo import ogr
from osgeo import osr

try:
    # Recent GDAL releases only provide the bindings inside the osgeo package.
    from osgeo import gdal
except ImportError:  # legacy installs that still expose a top-level module
    import gdal
def del_file(path):
    """Recursively delete everything below *path* except the directories.

    Regular files (and any other non-directory entries) are removed; real
    sub-directories are descended into and left behind empty.  Symbolic
    links are removed themselves and never followed, so files outside
    *path* that a link points to are not touched.

    Args:
        path: directory whose contents should be cleared.
    """
    for name in os.listdir(path):
        child = os.path.join(path, name)
        # os.path.isdir follows symlinks, so a link to a directory must be
        # checked explicitly; otherwise we would empty the link's target.
        if os.path.isdir(child) and not os.path.islink(child):
            del_file(child)
        else:
            os.remove(child)
def _sample_dirname(n):
    """Return the 8-digit, zero-padded base name for sample number *n*."""
    # The original if/elif ladder aimed at 8 digits, but its
    # ``n >= 100 and n > 1000`` test mis-padded 100..999 and 1001..9999.
    return "{0:08d}".format(n)


def _sample_kind(sampletype):
    """Map *sampletype* to "LINE" or "POLY"; raise ValueError otherwise."""
    lowered = str(sampletype).lower()
    # An ambiguous value containing both words is treated as a polygon sample.
    if "poly" in lowered:
        return "POLY"
    if "line" in lowered:
        return "LINE"
    raise ValueError(
        "sampletype must contain 'line' or 'poly', got {0!r}".format(sampletype))


def _pixel_window(geotrans, x, y, size):
    """Return (col, row) of the top-left pixel of a size x size window centred on (x, y).

    Assumes a north-up geotransform (rotation terms geotrans[2]/[4] are ignored).
    """
    half = size / 2.0
    col = int(math.floor((x - geotrans[0]) / geotrans[1] - half))
    row = int(math.floor((y - geotrans[3]) / geotrans[5] - half))
    return col, row


def _write_tile_tif(path, pixels, bandscount, datatype, geotrans, proj):
    """Write *pixels* (as returned by Dataset.ReadAsArray) to a GeoTIFF at *path*."""
    rows, cols = pixels.shape[-2:]
    out = gdal.GetDriverByName("GTiff").Create(path, cols, rows, bandscount, datatype)
    if out is None:
        raise RuntimeError("Could not create {0}".format(path))
    out.SetProjection(proj)
    out.SetGeoTransform(geotrans)
    # ReadAsArray yields (bands, rows, cols) for multi-band rasters and
    # (rows, cols) for single-band ones.
    if bandscount > 1:
        for i in range(bandscount):
            out.GetRasterBand(i + 1).WriteArray(pixels[i])
    else:
        out.GetRasterBand(1).WriteArray(pixels)
    out.FlushCache()
    out = None  # drop the reference so GDAL closes the file


def _write_tile_shp(path, srs, kind, field_name, ring):
    """Create a one-feature shapefile at *path* holding the tile footprint *ring*."""
    drv = ogr.GetDriverByName("ESRI Shapefile")
    if drv is None:
        raise RuntimeError("ESRI Shapefile driver is not available")
    # Creating over an existing shapefile fails; replace it so re-runs work
    # (the GeoTIFF next to it is overwritten by GTiff Create as well).
    if os.path.exists(path):
        drv.DeleteDataSource(path)
    out = drv.CreateDataSource(path)
    if out is None:
        raise RuntimeError("Could not create {0}".format(path))
    geom_type = ogr.wkbLineString if kind == "LINE" else ogr.wkbPolygon
    out_layer = out.CreateLayer("TestPolygon", srs, geom_type, [])
    if out_layer is None:
        raise RuntimeError("Could not create a layer in {0}".format(path))
    field = ogr.FieldDefn(field_name, ogr.OFTString)
    field.SetWidth(50)
    out_layer.CreateField(field, 1)
    coords = ",".join("{0} {1}".format(px, py) for px, py in ring)
    # The geometry must match the layer type: a polygon cannot be stored in
    # a line shapefile, so "line" samples get the closed footprint outline.
    if kind == "LINE":
        wkt = "LINESTRING ({0})".format(coords)
    else:
        wkt = "POLYGON (({0}))".format(coords)
    feat = ogr.Feature(out_layer.GetLayerDefn())
    feat.SetGeometry(ogr.CreateGeometryFromWkt(wkt))
    if out_layer.CreateFeature(feat) != 0:
        raise RuntimeError("Could not write the footprint to {0}".format(path))
    feat = None
    out = None  # drop the reference so OGR flushes and closes the shapefile


def sampleClip(shp, tif, outputdir, sampletype, size, fieldName='cls', n=None):
    """Cut fixed-size GeoTIFF tiles around each point of a shapefile.

    For every point feature a folder ``<outputdir>/<NNNNNNNN>_V1`` is created
    (NNNNNNNN = 8-digit zero-padded sample number) containing:

    * ``<NNNNNNNN>.tif``: a size x size pixel window of *tif* centred on the
      point, with all bands, the data type of band 1, and a geotransform
      aligned to the pixels actually read;
    * ``<NNNNNNNN>_V1_LINE.shp`` or ``_V1_POLY.shp``: a one-feature
      shapefile holding the tile footprint (closed line or polygon) plus an
      empty string field *fieldName*.

    Points whose window would reach past the raster edge, and features
    without geometry, are reported and skipped without consuming a number.

    Args:
        shp: point shapefile; coordinates must be in the raster's CRS
            because they are mapped through the raster geotransform directly.
        tif: raster to clip; a north-up geotransform is assumed.
        outputdir: root output folder (created if missing).
        sampletype: must contain "line" or "poly" (case-insensitive).
        size: tile edge length in pixels.
        fieldName: name of the string field created in each label shapefile.
        n: number of the first tile; defaults to 1.

    Returns:
        The next unused sample number, so consecutive calls keep numbering.

    Raises:
        ValueError: unknown *sampletype* or non-positive *size*.
        IOError: *tif* cannot be opened.
        RuntimeError: an output file cannot be created or written.
        SystemExit: *shp* cannot be opened (behaviour kept from the script).
    """
    # Validate cheap arguments before touching any file.
    kind = _sample_kind(sampletype)
    size = int(size)
    if size <= 0:
        raise ValueError("size must be a positive number of pixels, got {0}".format(size))
    if n is None:
        n = 1
    start = time.perf_counter()  # time.clock() no longer exists on Python 3.8+

    gdal.AllRegister()
    raster = gdal.Open(tif)
    if raster is None:
        raise IOError("Could not open raster {0}".format(tif))
    im_width = raster.RasterXSize
    im_height = raster.RasterYSize
    geotrans = raster.GetGeoTransform()
    bandscount = raster.RasterCount
    proj = raster.GetProjection()
    # Keep the source pixel type instead of guessing from the numpy dtype
    # (the old mapping stored signed int16 as UInt16 and int32 as Float32).
    datatype = raster.GetRasterBand(1).DataType
    print(im_width, im_height)

    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    ogr.RegisterAll()
    points_ds = ogr.GetDriverByName('ESRI Shapefile').Open(shp, 0)
    if points_ds is None:
        print('Could not open ' + shp)
        sys.exit(1)
    layer = points_ds.GetLayer()
    # The output shapefiles reuse the point layer's spatial reference; read it
    # once instead of reopening the input for every sample.
    srs = layer.GetSpatialRef()
    print("tif_bands:{0},samples_nums:{1},sample_type:{2},sample_size:{3}*{3}".format(
        bandscount, layer.GetFeatureCount(), sampletype, size))

    # Settings the original script applied before writing each output shapefile.
    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "NO")
    gdal.SetConfigOption("SHAPE_ENCODING", "")

    layer.ResetReading()
    for feature in iter(layer.GetNextFeature, None):
        geometry = feature.GetGeometryRef()
        if geometry is None:
            print("feature {0} has no geometry, skipped".format(feature.GetFID()))
            continue
        x = geometry.GetX()
        y = geometry.GetY()

        col, row = _pixel_window(geotrans, x, y, size)
        if col < 0 or row < 0 or col + size > im_width or row + size > im_height:
            print("point ({0}, {1}) is too close to the raster edge, skipped".format(x, y))
            continue
        pixels = raster.ReadAsArray(col, row, size, size)
        if pixels is None:
            raise RuntimeError("Could not read a {0}x{0} window at ({1}, {2})".format(size, col, row))

        # Georeference the tile from the integer pixel offset actually read,
        # not from the exact point, so the tile lines up with the source.
        left = geotrans[0] + col * geotrans[1]
        top = geotrans[3] + row * geotrans[5]
        right = left + size * geotrans[1]
        bottom = top + size * geotrans[5]
        tile_geotrans = (left, geotrans[1], geotrans[2], top, geotrans[4], geotrans[5])
        ring = [(left, top), (right, top), (right, bottom), (left, bottom), (left, top)]

        dirname = _sample_dirname(n)
        dirpath = os.path.join(outputdir, dirname + "_V1")
        os.makedirs(dirpath, exist_ok=True)
        _write_tile_tif(os.path.join(dirpath, dirname + ".tif"),
                        pixels, bandscount, datatype, tile_geotrans, proj)
        shpname = "{0}_V1_{1}.shp".format(dirname, kind)
        _write_tile_shp(os.path.join(dirpath, shpname), srs, kind, fieldName, ring)
        print("{0} ok".format(dirname))
        n += 1

    points_ds = None  # release the input shapefile
    raster = None  # release the input raster
    print('Process Running time: %s min' % ((time.perf_counter() - start) / 60))
    return n
def mkdir(path):
    """Create the single directory *path* unless something already exists there.

    Parent directories are not created; an existing file or directory at
    *path* is left untouched.
    """
    if os.path.exists(path):
        return
    os.mkdir(path)
if __name__ == "__main__":
    # Script entry point: clip every raster in tifList around the points of shp.
    tifList = ['road_use_re.tif']  # input rasters
    outputdir = 'data2'  # output root directory
    mkdir(outputdir)
    sampletype = "line"  # sample type: "line" or "poly"
    size = 640  # tile edge length in pixels
    n = 1  # number of the first sample
    fieldName = 'cls'  # string field created in each label shapefile
    shp = 'train.shp'  # point layer with the tile centres, shared by all rasters
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(shp):
        raise FileNotFoundError('check your shp file: {0}'.format(shp))
    for tif in tifList:
        # Feed the returned counter back in so numbering continues across rasters.
        n = sampleClip(shp, tif, outputdir, sampletype, size, fieldName, n)
4. 效果预览
5. 封装好的 exe
仅支持 Windows 系统。
链接:https://pan.baidu.com/s/1R_7K4ObHRCvmbxsXEZoJiw
提取码:ec1m