该代码支持对多张图像进行批量裁剪。

处理流程：先将 tif 格式的图片转换为 png，

再对得到的多张 png 图片进行批量裁剪。

批量裁剪代码：

import os

# import gdal_makeData
import numpy as np
from osgeo import gdal
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = pow(2,40).__str__()
import cv2 as cv
from PIL import Image
from PIL import ImageFile

# Let PIL load images whose data stream is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable PIL's maximum-pixel (decompression-bomb) check so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

# NOTE(review): the note string below says single_cutImg's imread converts PNGs to
# 8-bit colour, but single_cutImg calls cv.imread(imgPath, -1) (IMREAD_UNCHANGED),
# which keeps the original bit depth and channel count.
'''
single_cutImg函数的imread可以将输入的png格式的图片转为8位彩色后,再进行裁剪

'''
def read_gdal(path):
    """Read a raster with GDAL and return its pixels plus georeferencing.

    :param path: path of the image to read (including extension)
    :type path: str
    :return: ``(im_data, im_proj, im_geotrans)``. ``im_data`` is a numpy array
        laid out as (height, width, bands); single-band rasters get a trailing
        band axis of length 1. ``im_proj`` is the string returned by
        ``GetProjection()`` and ``im_geotrans`` the value of ``GetGeoTransform()``.
        If GDAL cannot open the file, a message is printed and ``None`` is returned.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        print(path + "文件无法打开")
        return

    width = dataset.RasterXSize   # number of columns
    height = dataset.RasterYSize  # number of rows
    band_count = dataset.RasterCount
    projection = dataset.GetProjection()
    geotransform = dataset.GetGeoTransform()  # affine transform
    pixels = dataset.ReadAsArray(0, 0, width, height)

    # A single-band raster is read as 2-D (rows, cols); add a band axis so the
    # array is always (bands, rows, cols) before the band axis is moved last.
    if pixels.ndim == 2:
        pixels = np.expand_dims(pixels, axis=0)
    pixels = np.moveaxis(pixels, 0, -1)

    print("栅格矩阵的列数: ", width)
    print("栅格矩阵的行数: ", height)
    print("栅格矩阵的波段数: ", band_count)
    print("栅格矩阵的投影信息: ", projection)
    print("栅格矩阵的仿射矩阵信息: ", geotransform)
    print("ia_data形状:", pixels.shape)
    return pixels, projection, geotransform


def write_gdal(im_data, path, im_proj=None, im_geotrans=None):
    """Write an image array to a GeoTIFF, optionally attaching georeferencing.

    :param im_data: image array laid out as (height, width, bands); a 2-D
        (height, width) array is written as a single band.
    :type im_data: numpy.ndarray
    :param path: output file path (including extension)
    :type path: str
    :param im_proj: projection to set (default None); only written when
        ``im_geotrans`` is also given.
    :param im_geotrans: affine geotransform to set (default None); only written
        when ``im_proj`` is also given.
    :return: None
    :raises ValueError: if ``im_data`` is neither 2-D nor 3-D.
    :raises RuntimeError: if GDAL cannot create the output file.
    """
    im_data = np.asarray(im_data)
    # Normalise to GDAL's band-first layout (bands, height, width). The rank is
    # checked before transposing so that 2-D input no longer crashes.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]
    elif im_data.ndim == 3:
        im_data = im_data.transpose((2, 0, 1))
    else:
        raise ValueError(
            "im_data must be 2-D (h, w) or 3-D (h, w, c), got shape {}".format(im_data.shape))
    im_bands, im_height, im_width = im_data.shape

    # Substring matching as before: 'int8' also matches 'uint8', and 'int16' also
    # matches 'uint16'. Signed int16 is written as UInt16 (the PNG driver used for
    # later conversion rejects Int16), so negative values do not survive.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    elif 'float32' in dtype_name:
        datatype = gdal.GDT_Float32
    else:
        datatype = gdal.GDT_Float64

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create output file: {}".format(path))
    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # write the affine transform
        dataset.SetProjection(im_proj)  # write the projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Dropping the last reference flushes and closes the GDAL dataset.
    del dataset


def single_set_proj_trans(ori_path, target_path):
    """Copy one raster's projection and geotransform onto another raster.

    The pixels of ``target_path`` are read and written back to ``target_path``
    (through ``write_gdal``, i.e. as a GeoTIFF) together with the projection
    and geotransform read from ``ori_path``.

    :param ori_path: raster supplying the projection / geotransform
    :type ori_path: str
    :param target_path: raster (e.g. a binary mask) that receives them; overwritten
    :type target_path: str
    :return: None
    """
    _ignored_pixels, projection, geotransform = read_gdal(ori_path)
    target_pixels, _ignored_proj, _ignored_trans = read_gdal(target_path)
    write_gdal(target_pixels, target_path, im_proj=projection, im_geotrans=geotransform)


def tif_bands_math(tif_data):
    """Keep the first three bands of an image and stretch them to 8-bit.

    The three bands are scaled together by their common maximum so that the
    brightest pixel becomes 255, then truncated to ``uint8``. Negative values
    are clipped to 0, and an image whose maximum is <= 0 yields an all-zero
    result instead of NaN garbage.

    :param tif_data: image array laid out as (height, width, bands), bands >= 3
    :type tif_data: numpy.ndarray
    :return: ``uint8`` array of shape (height, width, 3)
    :raises ValueError: if ``tif_data`` is not 3-D or has fewer than 3 bands
    """
    print("获取函数传递的tif数据为:", tif_data.shape)
    if tif_data.ndim != 3 or tif_data.shape[2] < 3:
        raise ValueError(
            "tif_data must have shape (h, w, bands) with bands >= 3, got {}".format(tif_data.shape))

    new_tif_data = np.stack([tif_data[:, :, k] for k in range(3)])
    print("new_tif_data.shape::", new_tif_data.shape)
    new_tif_data = new_tif_data.transpose(1, 2, 0)
    print(new_tif_data.shape)

    peak = np.max(new_tif_data)
    if peak <= 0:
        # Dividing by a zero (or negative) maximum would give NaN/inf; return black.
        return np.zeros(new_tif_data.shape, dtype=np.uint8)
    scaled = new_tif_data / peak * 255
    # Clip so negative inputs (e.g. signed int16 data) become 0 instead of
    # wrapping around when cast to uint8.
    return np.clip(scaled, 0, 255).astype(np.uint8)


def tif_bands_math2(tif_data):
    """Keep the first three bands of a multi-band image and stretch them to 8-bit.

    Bands after the third are not used. The three kept bands are scaled
    together by their common maximum so the brightest pixel becomes 255, then
    truncated to ``uint8``. Negative values are clipped to 0, and an image
    whose maximum is <= 0 yields an all-zero result instead of NaN garbage.

    :param tif_data: image array laid out as (height, width, bands), bands >= 3
    :type tif_data: numpy.ndarray
    :return: ``uint8`` array of shape (height, width, 3)
    :raises ValueError: if ``tif_data`` is not 3-D or has fewer than 3 bands
    """
    # NOTE: an NDVI-based composite (bands 4-6) was sketched here in commented-out
    # code but never enabled; only the plain 3-band stretch is implemented.
    print("获取函数传递的tif数据为:", tif_data.shape)
    if tif_data.ndim != 3 or tif_data.shape[2] < 3:
        raise ValueError(
            "tif_data must have shape (h, w, bands) with bands >= 3, got {}".format(tif_data.shape))

    new_tif_data = np.stack([tif_data[:, :, k] for k in range(3)])
    print("new_tif_data.shape::", new_tif_data.shape)
    new_tif_data = new_tif_data.transpose(1, 2, 0)
    print(new_tif_data.shape)

    peak = np.max(new_tif_data)
    if peak <= 0:
        # Dividing by a zero (or negative) maximum would give NaN/inf; return black.
        return np.zeros(new_tif_data.shape, dtype=np.uint8)
    scaled = new_tif_data / peak * 255
    # Clip so negative inputs become 0 instead of wrapping when cast to uint8.
    return np.clip(scaled, 0, 255).astype(np.uint8)


# NOTE(review): the string below is never executed. It is a disabled draft of
# a per-pixel NDVI computation, ndvi = (a - b) / (a + b) with a = band6 and
# b = band3 (band6 presumably NIR and band3 red -- TODO confirm), whose result is
# min-max stretched to uint8 and stacked into three identical channels. If
# revived: the Python double loop is very slow (a vectorised numpy expression is
# equivalent), and the final stretch divides by zero when NDVI is constant.
'''
ndvi
    h, w = band3.shape

    ndvi = np.zeros((h, w)).astype(np.float32)

    for i in range(h):
        for j in range(w):
            # if (band6[i][j] + band3[i][j]) == 0:
            if band6[i][j] == 0 and band3[i][j] == 0:
                # ndvi[i][j] = 0
                continue
            else:
                a = band6[i][j].astype(np.float32)
                b = band3[i][j].astype(np.float32)

                # print(ndvi[i][j], band6[i][j], band3[i][j])
                # ndvi[i][j] = (band6[i][j] - band3[i][j]).astype(np.float32) / (band6[i][j] + band3[i][j]).astype(np.float32)
                ndvi[i][j] = (a - b).astype(np.float32) / (a + b).astype(np.float32)
            # print(ndvi[i][j])

    print("ndvi最大值", np.max(ndvi))
    print("ndvi最小值", np.min(ndvi))

    # for i in range(h):
    #     for j in range(w):
    #         ndvi[i][j] = ndvi / np.max(ndvi)
    # ndvi = ((ndvi - np.min(ndvi)) / (np.max(ndvi) - np.min(ndvi))) * 255
    #
    # band5 = ((band5 - np.min(band5)) / (np.max(band5) - np.min(band5))) * 255
    # band4 = ((band4 - np.min(band4)) / (np.max(band4) - np.min(band4))) * 255
    ndvi = ((ndvi - np.min(ndvi)) / (np.max(ndvi) - np.min(ndvi)) * 255).astype(np.uint8)


    new_tif_data = np.stack((ndvi, ndvi, ndvi))
    # new_tif_data = np.stack((ndvi, band4, band5))
    print("new_tif_data.shape::", new_tif_data.shape)
    new_tif_data = new_tif_data.transpose(1, 2, 0)
    print(new_tif_data.shape)
'''


def single_tif2pngORjpg(tifPath, savePath):
    """Convert one GDAL-readable raster (e.g. a .tif) to PNG with GDAL's PNG driver.

    The PNG driver only supports 8-bit (Byte) and 16-bit unsigned (UInt16)
    bands; for other band types (e.g. Int16) ``CreateCopy`` fails, which is now
    reported as an error instead of being silently ignored.

    :param tifPath: path of the source raster
    :type tifPath: str
    :param savePath: path of the PNG to write; its parent directory is created if needed
    :type savePath: str
    :return: None
    :raises RuntimeError: if the source cannot be opened or the PNG cannot be written
    """
    driver = gdal.GetDriverByName('PNG')
    src = gdal.Open(tifPath)
    if src is None:
        raise RuntimeError("GDAL could not open {}".format(tifPath))

    out_dir = os.path.dirname(savePath)
    if out_dir:
        # CreateCopy fails when the target directory does not exist.
        os.makedirs(out_dir, exist_ok=True)

    dst = driver.CreateCopy(savePath, src)
    if dst is None:
        raise RuntimeError("GDAL could not write PNG {}".format(savePath))
    dst = None  # release the output dataset so GDAL flushes it to disk

    print("路径{}中单个tif数据格式转换完成。。。。。。".format(tifPath))


def tif2pngORjpg(tiffilePath, tifsavePath):
    """Convert every ``.tif`` file in a directory to PNG.

    Each ``<name>.tif`` directly inside ``tiffilePath`` is written to
    ``tifsavePath`` as ``<name>.png`` using GDAL's PNG driver, which only
    supports Byte and UInt16 bands.

    :param tiffilePath: directory containing the source ``.tif`` files
    :type tiffilePath: str
    :param tifsavePath: directory that receives the PNGs (created if missing)
    :type tifsavePath: str
    :return: None
    :raises RuntimeError: if a source cannot be opened or a PNG cannot be written
    """
    file_path = tiffilePath
    save_path = tifsavePath

    driver = gdal.GetDriverByName('PNG')

    # Sorted so the processing order does not depend on the filesystem.
    files = sorted(f for f in os.listdir(file_path) if f.endswith('.tif'))
    os.makedirs(save_path, exist_ok=True)
    for each_file in files:
        src_path = os.path.join(file_path, each_file)
        file_name = os.path.splitext(each_file)[0]
        src = gdal.Open(src_path)
        if src is None:
            raise RuntimeError("GDAL could not open {}".format(src_path))
        dst_path = os.path.join(save_path, file_name + '.png')
        dst = driver.CreateCopy(dst_path, src)
        if dst is None:
            raise RuntimeError("GDAL could not write PNG {}".format(dst_path))
        dst = None  # release so GDAL flushes the PNG to disk
    print("路径{}中所有tif数据格式转换完成。。。。。。".format(file_path))


def single_cutImg(imgPath, savePath, w=512, h=512):
    """Cut one image into non-overlapping ``w`` x ``h`` tiles.

    Tiles are taken row by row (top to bottom, left to right) starting at the
    top-left corner; the right/bottom remainder that cannot hold a full tile is
    discarded. Tiles are numbered from 1 and named ``img_<id>``.

    The image is read with ``cv.imread(imgPath, -1)`` (IMREAD_UNCHANGED), so its
    bit depth and channel count are kept:

    * single-channel images are reopened and cropped with PIL and saved with
      the source file's extension;
    * multi-channel images are cropped from the OpenCV array and saved as ``.png``.

    :param imgPath: path of the image to cut
    :type imgPath: str
    :param savePath: directory that receives the tiles (created if missing)
    :type savePath: str
    :param w: tile width in pixels (default 512)
    :param h: tile height in pixels (default 512)
    :return: None
    :raises ValueError: if ``w`` or ``h`` is not positive
    :raises IOError: if the image cannot be read or a tile cannot be written
    """
    if w <= 0 or h <= 0:
        raise ValueError("tile size must be positive, got w={}, h={}".format(w, h))

    img = cv.imread(imgPath, -1)
    if img is None:
        # cv.imread reports a missing or undecodable file by returning None.
        raise IOError("cv.imread could not read image: {}".format(imgPath))

    imgLastName = os.path.splitext(imgPath)[1]
    height, width = img.shape[:2]
    os.makedirs(savePath, exist_ok=True)

    # Top-left corners of every full tile, in row-major order.
    corners = [(x, y)
               for y in range(0, height - h + 1, h)
               for x in range(0, width - w + 1, w)]

    if img.ndim == 2:
        # Single-channel images: crop and save through PIL.
        with Image.open(imgPath) as label_img:
            for tile_id, (x, y) in enumerate(corners, start=1):
                tile = label_img.crop((x, y, x + w, y + h))
                tile.save(os.path.join(savePath, 'img' + "_" + str(tile_id) + imgLastName))
    else:
        for tile_id, (x, y) in enumerate(corners, start=1):
            tile = img[y:y + h, x:x + w, :]
            out_path = os.path.join(savePath, "img_" + str(tile_id) + '.png')
            # cv.imwrite returns False instead of raising on failure.
            if not cv.imwrite(out_path, tile):
                raise IOError("cv.imwrite failed for {}".format(out_path))

def cutImg_Slide(imgPath, savePath, image_size, stride):
    """Cut one image into square tiles with a sliding window.

    Window origins are ``(i * stride, j * stride)`` for
    ``0 <= i < height // stride`` and ``0 <= j < width // stride``. Windows that
    would run past the bottom or right edge are skipped. Each kept tile is
    written to ``savePath`` as ``img_<id>.png`` with
    ``id = i * (width // stride) + j``, so the ids of skipped windows are unused
    and two images of the same size get the same id for the same window.

    NOTE: when ``stride > image_size``, a last row/column of windows that would
    still fit inside the image may not be visited.

    The image is read unchanged (``cv.imread(..., -1)``), so bit depth and
    channel count are preserved.

    :param imgPath: path of the image to cut
    :type imgPath: str
    :param savePath: output directory (created if missing)
    :type savePath: str
    :param image_size: tile edge length in pixels (> 0)
    :param stride: step between window origins in pixels (> 0)
    :return: None
    :raises ValueError: if ``image_size`` or ``stride`` is not positive
    :raises IOError: if the image cannot be read or a tile cannot be written
    """
    if image_size <= 0 or stride <= 0:
        raise ValueError("image_size and stride must be positive, got {} and {}".format(image_size, stride))

    img = cv.imread(imgPath, -1)
    if img is None:
        raise IOError("cv.imread could not read image: {}".format(imgPath))
    height, width = img.shape[:2]

    h_num = height // stride
    w_num = width // stride
    print(h_num)
    print(w_num)

    os.makedirs(savePath, exist_ok=True)
    for i in range(h_num):
        top = i * stride
        if top + image_size > height:
            continue
        for j in range(w_num):
            left = j * stride
            if left + image_size > width:
                continue
            # Slicing the first two axes works for both 2-D and 3-D images.
            tile = img[top:top + image_size, left:left + image_size]
            out_path = os.path.join(savePath, "img_" + str(i * w_num + j) + '.png')
            if not cv.imwrite(out_path, tile):
                raise IOError("cv.imwrite failed for {}".format(out_path))


def cutImg_Slide_2(file_path, save_path, image_size, stride):
    """Sliding-window cut of every ``.png`` file directly inside a directory.

    Files are processed in sorted name order and numbered from 1 (``num``).
    For each file, window origins are ``(i * stride, j * stride)`` with
    ``0 <= i < height // stride`` and ``0 <= j < width // stride``. Windows
    running past the image edge are skipped. A kept tile is written to
    ``save_path`` as ``img_<num>_img_<id>.png`` with
    ``id = i * (width // stride) + j``.

    Sorting matters because the image folder and the label folder are cut in
    separate calls and their tiles presumably pair up by ``num``. With sorted
    order the pairing follows the file names instead of the arbitrary
    ``os.listdir`` order.

    NOTE: when ``stride > image_size``, a last row/column of windows that would
    still fit may not be visited.

    :param file_path: directory containing the ``.png`` files
    :type file_path: str
    :param save_path: output directory (created if missing)
    :type save_path: str
    :param image_size: tile edge length in pixels (> 0)
    :param stride: step between window origins in pixels (> 0)
    :return: None
    :raises ValueError: if ``image_size`` or ``stride`` is not positive
    :raises IOError: if an image cannot be read or a tile cannot be written
    """
    if image_size <= 0 or stride <= 0:
        raise ValueError("image_size and stride must be positive, got {} and {}".format(image_size, stride))

    files = sorted(f for f in os.listdir(file_path) if f.endswith('.png'))
    os.makedirs(save_path, exist_ok=True)

    for _num, each_file in enumerate(files, start=1):
        src_path = os.path.join(file_path, each_file)
        img = cv.imread(src_path, -1)  # -1: keep bit depth and channels unchanged
        if img is None:
            raise IOError("cv.imread could not read image: {}".format(src_path))
        height, width = img.shape[:2]

        h_num = height // stride
        w_num = width // stride
        print(h_num)
        print(w_num)

        for i in range(h_num):
            top = i * stride
            if top + image_size > height:
                continue
            for j in range(w_num):
                left = j * stride
                if left + image_size > width:
                    continue
                # Slicing the first two axes works for both 2-D and 3-D images.
                tile = img[top:top + image_size, left:left + image_size]
                out_path = os.path.join(
                    save_path, "img_" + str(_num) + "_img_" + str(i * w_num + j) + '.png')
                if not cv.imwrite(out_path, tile):
                    raise IOError("cv.imwrite failed for {}".format(out_path))



def makeTIFF(tif_dir, tif_saveDir):
    """Convert a multi-band raster into a 3-band uint8 GeoTIFF.

    Reads ``tif_dir``, keeps and stretches its first three bands with
    ``tif_bands_math`` and writes the result to ``tif_saveDir`` with the source
    projection and geotransform.

    :param tif_dir: path of the source raster
    :param tif_saveDir: path of the GeoTIFF to write
    :return: None
    """
    pixels, projection, geotransform = read_gdal(tif_dir)
    print("读取的tif数据为:", pixels.shape)
    write_gdal(tif_bands_math(pixels), tif_saveDir, projection, geotransform)


if __name__ == '__main__':
    # Batch pipeline: for the image folder and then the label folder, convert
    # every .tif to PNG (into a "png" sub-folder) and cut the PNGs into
    # image_size x image_size tiles under dataset_path.
    # Optional preprocessing (disabled): makeTIFF(src_tif, dst_tif) reduces a
    # multi-band .tif to 3 bands first.
    # Keep paths free of Chinese characters for the tif -> png conversion.
    path = r'E:\New_gj\dataset\T_multi_22\data_pre'  # input root
    dataset_path = r'E:\New_gj\dataset\T_multi_22\cut'  # output root
    image_size = 512  # tile edge length in pixels
    stride = 512  # window step; equal to image_size -> non-overlapping tiles

    for src_sub, dst_sub in ((r'\images', r'\images_pre'), (r'\labels', r'\labels_pre')):
        tif_dir = path + src_sub
        png_dir = tif_dir + r'\png'
        tif2pngORjpg(tif_dir, png_dir)
        cutImg_Slide_2(png_dir, dataset_path + dst_sub, image_size, stride)

单张图片裁剪代码（对单个 tif 文件进行格式转换和裁剪）：

import os

# import gdal_makeData
import numpy as np
from osgeo import gdal
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = pow(2,40).__str__()
import cv2 as cv
from PIL import Image
from PIL import ImageFile

# Let PIL load images whose data stream is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable PIL's maximum-pixel (decompression-bomb) check so very large images can be opened.
Image.MAX_IMAGE_PIXELS = None

# NOTE(review): the note string below says single_cutImg's imread converts PNGs to
# 8-bit colour, but single_cutImg calls cv.imread(imgPath, -1) (IMREAD_UNCHANGED),
# which keeps the original bit depth and channel count.
'''
single_cutImg函数的imread可以将输入的png格式的图片转为8位彩色后,再进行裁剪

'''
def read_gdal(path):
    """Read a raster with GDAL and return its pixels plus georeferencing.

    :param path: path of the image to read (including extension)
    :type path: str
    :return: ``(im_data, im_proj, im_geotrans)``. ``im_data`` is a numpy array
        laid out as (height, width, bands); single-band rasters get a trailing
        band axis of length 1. ``im_proj`` is the string returned by
        ``GetProjection()`` and ``im_geotrans`` the value of ``GetGeoTransform()``.
        If GDAL cannot open the file, a message is printed and ``None`` is returned.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        print(path + "文件无法打开")
        return

    width = dataset.RasterXSize   # number of columns
    height = dataset.RasterYSize  # number of rows
    band_count = dataset.RasterCount
    projection = dataset.GetProjection()
    geotransform = dataset.GetGeoTransform()  # affine transform
    pixels = dataset.ReadAsArray(0, 0, width, height)

    # A single-band raster is read as 2-D (rows, cols); add a band axis so the
    # array is always (bands, rows, cols) before the band axis is moved last.
    if pixels.ndim == 2:
        pixels = np.expand_dims(pixels, axis=0)
    pixels = np.moveaxis(pixels, 0, -1)

    print("栅格矩阵的列数: ", width)
    print("栅格矩阵的行数: ", height)
    print("栅格矩阵的波段数: ", band_count)
    print("栅格矩阵的投影信息: ", projection)
    print("栅格矩阵的仿射矩阵信息: ", geotransform)
    print("ia_data形状:", pixels.shape)
    return pixels, projection, geotransform


def write_gdal(im_data, path, im_proj=None, im_geotrans=None):
    """Write an image array to a GeoTIFF, optionally attaching georeferencing.

    :param im_data: image array laid out as (height, width, bands); a 2-D
        (height, width) array is written as a single band.
    :type im_data: numpy.ndarray
    :param path: output file path (including extension)
    :type path: str
    :param im_proj: projection to set (default None); only written when
        ``im_geotrans`` is also given.
    :param im_geotrans: affine geotransform to set (default None); only written
        when ``im_proj`` is also given.
    :return: None
    :raises ValueError: if ``im_data`` is neither 2-D nor 3-D.
    :raises RuntimeError: if GDAL cannot create the output file.
    """
    im_data = np.asarray(im_data)
    # Normalise to GDAL's band-first layout (bands, height, width). The rank is
    # checked before transposing so that 2-D input no longer crashes.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]
    elif im_data.ndim == 3:
        im_data = im_data.transpose((2, 0, 1))
    else:
        raise ValueError(
            "im_data must be 2-D (h, w) or 3-D (h, w, c), got shape {}".format(im_data.shape))
    im_bands, im_height, im_width = im_data.shape

    # Substring matching as before: 'int8' also matches 'uint8', and 'int16' also
    # matches 'uint16'. Signed int16 is written as UInt16 (the PNG driver used for
    # later conversion rejects Int16), so negative values do not survive.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    elif 'float32' in dtype_name:
        datatype = gdal.GDT_Float32
    else:
        datatype = gdal.GDT_Float64

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create output file: {}".format(path))
    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # write the affine transform
        dataset.SetProjection(im_proj)  # write the projection
    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Dropping the last reference flushes and closes the GDAL dataset.
    del dataset


def single_set_proj_trans(ori_path, target_path):
    """Copy one raster's projection and geotransform onto another raster.

    The pixels of ``target_path`` are read and written back to ``target_path``
    (through ``write_gdal``, i.e. as a GeoTIFF) together with the projection
    and geotransform read from ``ori_path``.

    :param ori_path: raster supplying the projection / geotransform
    :type ori_path: str
    :param target_path: raster (e.g. a binary mask) that receives them; overwritten
    :type target_path: str
    :return: None
    """
    _ignored_pixels, projection, geotransform = read_gdal(ori_path)
    target_pixels, _ignored_proj, _ignored_trans = read_gdal(target_path)
    write_gdal(target_pixels, target_path, im_proj=projection, im_geotrans=geotransform)


def tif_bands_math(tif_data):
    """Keep the first three bands of an image and stretch them to 8-bit.

    The three bands are scaled together by their common maximum so that the
    brightest pixel becomes 255, then truncated to ``uint8``. Negative values
    are clipped to 0, and an image whose maximum is <= 0 yields an all-zero
    result instead of NaN garbage.

    :param tif_data: image array laid out as (height, width, bands), bands >= 3
    :type tif_data: numpy.ndarray
    :return: ``uint8`` array of shape (height, width, 3)
    :raises ValueError: if ``tif_data`` is not 3-D or has fewer than 3 bands
    """
    print("获取函数传递的tif数据为:", tif_data.shape)
    if tif_data.ndim != 3 or tif_data.shape[2] < 3:
        raise ValueError(
            "tif_data must have shape (h, w, bands) with bands >= 3, got {}".format(tif_data.shape))

    new_tif_data = np.stack([tif_data[:, :, k] for k in range(3)])
    print("new_tif_data.shape::", new_tif_data.shape)
    new_tif_data = new_tif_data.transpose(1, 2, 0)
    print(new_tif_data.shape)

    peak = np.max(new_tif_data)
    if peak <= 0:
        # Dividing by a zero (or negative) maximum would give NaN/inf; return black.
        return np.zeros(new_tif_data.shape, dtype=np.uint8)
    scaled = new_tif_data / peak * 255
    # Clip so negative inputs (e.g. signed int16 data) become 0 instead of
    # wrapping around when cast to uint8.
    return np.clip(scaled, 0, 255).astype(np.uint8)


def tif_bands_math2(tif_data):
    """Keep the first three bands of a multi-band image and stretch them to 8-bit.

    Bands after the third are not used. The three kept bands are scaled
    together by their common maximum so the brightest pixel becomes 255, then
    truncated to ``uint8``. Negative values are clipped to 0, and an image
    whose maximum is <= 0 yields an all-zero result instead of NaN garbage.

    :param tif_data: image array laid out as (height, width, bands), bands >= 3
    :type tif_data: numpy.ndarray
    :return: ``uint8`` array of shape (height, width, 3)
    :raises ValueError: if ``tif_data`` is not 3-D or has fewer than 3 bands
    """
    # NOTE: an NDVI-based composite (bands 4-6) was sketched here in commented-out
    # code but never enabled; only the plain 3-band stretch is implemented.
    print("获取函数传递的tif数据为:", tif_data.shape)
    if tif_data.ndim != 3 or tif_data.shape[2] < 3:
        raise ValueError(
            "tif_data must have shape (h, w, bands) with bands >= 3, got {}".format(tif_data.shape))

    new_tif_data = np.stack([tif_data[:, :, k] for k in range(3)])
    print("new_tif_data.shape::", new_tif_data.shape)
    new_tif_data = new_tif_data.transpose(1, 2, 0)
    print(new_tif_data.shape)

    peak = np.max(new_tif_data)
    if peak <= 0:
        # Dividing by a zero (or negative) maximum would give NaN/inf; return black.
        return np.zeros(new_tif_data.shape, dtype=np.uint8)
    scaled = new_tif_data / peak * 255
    # Clip so negative inputs become 0 instead of wrapping when cast to uint8.
    return np.clip(scaled, 0, 255).astype(np.uint8)


# NOTE(review): the string below is never executed. It is a disabled draft of
# a per-pixel NDVI computation, ndvi = (a - b) / (a + b) with a = band6 and
# b = band3 (band6 presumably NIR and band3 red -- TODO confirm), whose result is
# min-max stretched to uint8 and stacked into three identical channels. If
# revived: the Python double loop is very slow (a vectorised numpy expression is
# equivalent), and the final stretch divides by zero when NDVI is constant.
'''
ndvi
    h, w = band3.shape

    ndvi = np.zeros((h, w)).astype(np.float32)

    for i in range(h):
        for j in range(w):
            # if (band6[i][j] + band3[i][j]) == 0:
            if band6[i][j] == 0 and band3[i][j] == 0:
                # ndvi[i][j] = 0
                continue
            else:
                a = band6[i][j].astype(np.float32)
                b = band3[i][j].astype(np.float32)

                # print(ndvi[i][j], band6[i][j], band3[i][j])
                # ndvi[i][j] = (band6[i][j] - band3[i][j]).astype(np.float32) / (band6[i][j] + band3[i][j]).astype(np.float32)
                ndvi[i][j] = (a - b).astype(np.float32) / (a + b).astype(np.float32)
            # print(ndvi[i][j])

    print("ndvi最大值", np.max(ndvi))
    print("ndvi最小值", np.min(ndvi))

    # for i in range(h):
    #     for j in range(w):
    #         ndvi[i][j] = ndvi / np.max(ndvi)
    # ndvi = ((ndvi - np.min(ndvi)) / (np.max(ndvi) - np.min(ndvi))) * 255
    #
    # band5 = ((band5 - np.min(band5)) / (np.max(band5) - np.min(band5))) * 255
    # band4 = ((band4 - np.min(band4)) / (np.max(band4) - np.min(band4))) * 255
    ndvi = ((ndvi - np.min(ndvi)) / (np.max(ndvi) - np.min(ndvi)) * 255).astype(np.uint8)


    new_tif_data = np.stack((ndvi, ndvi, ndvi))
    # new_tif_data = np.stack((ndvi, band4, band5))
    print("new_tif_data.shape::", new_tif_data.shape)
    new_tif_data = new_tif_data.transpose(1, 2, 0)
    print(new_tif_data.shape)
'''


def single_tif2pngORjpg(tifPath, savePath):
    """Convert one GDAL-readable raster (e.g. a .tif) to PNG with GDAL's PNG driver.

    The PNG driver only supports 8-bit (Byte) and 16-bit unsigned (UInt16)
    bands; for other band types (e.g. Int16) ``CreateCopy`` fails, which is now
    reported as an error instead of being silently ignored.

    :param tifPath: path of the source raster
    :type tifPath: str
    :param savePath: path of the PNG to write; its parent directory is created if needed
    :type savePath: str
    :return: None
    :raises RuntimeError: if the source cannot be opened or the PNG cannot be written
    """
    driver = gdal.GetDriverByName('PNG')
    src = gdal.Open(tifPath)
    if src is None:
        raise RuntimeError("GDAL could not open {}".format(tifPath))

    out_dir = os.path.dirname(savePath)
    if out_dir:
        # CreateCopy fails when the target directory does not exist.
        os.makedirs(out_dir, exist_ok=True)

    dst = driver.CreateCopy(savePath, src)
    if dst is None:
        raise RuntimeError("GDAL could not write PNG {}".format(savePath))
    dst = None  # release the output dataset so GDAL flushes it to disk

    print("路径{}中单个tif数据格式转换完成。。。。。。".format(tifPath))


def tif2pngORjpg(tiffilePath, tifsavePath):
    """Convert every ``.tif`` file in a directory to PNG.

    Each ``<name>.tif`` directly inside ``tiffilePath`` is written to
    ``tifsavePath`` as ``zhuhai_<name>.png`` using GDAL's PNG driver, which only
    supports Byte and UInt16 bands. Paths are joined with ``os.path.join``, so
    the directory arguments may or may not end with a separator.

    :param tiffilePath: directory containing the source ``.tif`` files
    :type tiffilePath: str
    :param tifsavePath: directory that receives the PNGs (created if missing)
    :type tifsavePath: str
    :return: None
    :raises RuntimeError: if a source cannot be opened or a PNG cannot be written
    """
    file_path = tiffilePath
    save_path = tifsavePath

    driver = gdal.GetDriverByName('PNG')

    # Sorted so the processing order does not depend on the filesystem.
    files = sorted(f for f in os.listdir(file_path) if f.endswith('.tif'))
    os.makedirs(save_path, exist_ok=True)
    for each_file in files:
        src_path = os.path.join(file_path, each_file)
        fileName = os.path.splitext(each_file)[0]
        oridata = gdal.Open(src_path)
        if oridata is None:
            raise RuntimeError("GDAL could not open {}".format(src_path))
        # Convert inside the loop so that every file, not just the last one, is written.
        dst_path = os.path.join(save_path, 'zhuhai_' + fileName + '.png')
        data = driver.CreateCopy(dst_path, oridata)
        if data is None:
            raise RuntimeError("GDAL could not write PNG {}".format(dst_path))
        data = None  # release so GDAL flushes the PNG to disk
    print("路径{}中所有tif数据格式转换完成。。。。。。".format(file_path))


def single_cutImg(imgPath, savePath, w=512, h=512):
    """Cut one image into non-overlapping ``w`` x ``h`` tiles.

    Tiles are taken row by row (top to bottom, left to right) starting at the
    top-left corner; the right/bottom remainder that cannot hold a full tile is
    discarded. Tiles are numbered from 1 and named ``img_<id>``.

    The image is read with ``cv.imread(imgPath, -1)`` (IMREAD_UNCHANGED), so its
    bit depth and channel count are kept:

    * single-channel images are reopened and cropped with PIL and saved with
      the source file's extension;
    * multi-channel images are cropped from the OpenCV array and saved as ``.png``.

    :param imgPath: path of the image to cut
    :type imgPath: str
    :param savePath: directory that receives the tiles (created if missing)
    :type savePath: str
    :param w: tile width in pixels (default 512)
    :param h: tile height in pixels (default 512)
    :return: None
    :raises ValueError: if ``w`` or ``h`` is not positive
    :raises IOError: if the image cannot be read or a tile cannot be written
    """
    if w <= 0 or h <= 0:
        raise ValueError("tile size must be positive, got w={}, h={}".format(w, h))

    img = cv.imread(imgPath, -1)
    if img is None:
        # cv.imread reports a missing or undecodable file by returning None.
        raise IOError("cv.imread could not read image: {}".format(imgPath))

    imgLastName = os.path.splitext(imgPath)[1]
    height, width = img.shape[:2]
    os.makedirs(savePath, exist_ok=True)

    # Top-left corners of every full tile, in row-major order.
    corners = [(x, y)
               for y in range(0, height - h + 1, h)
               for x in range(0, width - w + 1, w)]

    if img.ndim == 2:
        # Single-channel images: crop and save through PIL.
        with Image.open(imgPath) as label_img:
            for tile_id, (x, y) in enumerate(corners, start=1):
                tile = label_img.crop((x, y, x + w, y + h))
                tile.save(os.path.join(savePath, 'img' + "_" + str(tile_id) + imgLastName))
    else:
        for tile_id, (x, y) in enumerate(corners, start=1):
            tile = img[y:y + h, x:x + w, :]
            out_path = os.path.join(savePath, "img_" + str(tile_id) + '.png')
            # cv.imwrite returns False instead of raising on failure.
            if not cv.imwrite(out_path, tile):
                raise IOError("cv.imwrite failed for {}".format(out_path))

def cutImg_Slide(imgPath, savePath, image_size, stride):
    """Cut one image into square tiles with a sliding window.

    Window origins are ``(i * stride, j * stride)`` for
    ``0 <= i < height // stride`` and ``0 <= j < width // stride``. Windows that
    would run past the bottom or right edge are skipped. Each kept tile is
    written to ``savePath`` as ``<stem>_img_<id>.png``, where ``<stem>`` is the
    input file name without directory or extension (previously the whole input
    path was embedded, producing an invalid output path) and
    ``id = i * (width // stride) + j``.

    NOTE: when ``stride > image_size``, a last row/column of windows that would
    still fit inside the image may not be visited.

    :param imgPath: path of the image to cut
    :type imgPath: str
    :param savePath: output directory (created if missing)
    :type savePath: str
    :param image_size: tile edge length in pixels (> 0)
    :param stride: step between window origins in pixels (> 0)
    :return: None
    :raises ValueError: if ``image_size`` or ``stride`` is not positive
    :raises IOError: if the image cannot be read or a tile cannot be written
    """
    if image_size <= 0 or stride <= 0:
        raise ValueError("image_size and stride must be positive, got {} and {}".format(image_size, stride))

    # Use only the file name (no directory) as the tile-name prefix.
    stem = os.path.splitext(os.path.basename(imgPath))[0]
    img = cv.imread(imgPath, -1)  # -1: keep bit depth and channels unchanged
    if img is None:
        raise IOError("cv.imread could not read image: {}".format(imgPath))
    height, width = img.shape[:2]

    h_num = height // stride
    w_num = width // stride
    print(h_num)
    print(w_num)

    os.makedirs(savePath, exist_ok=True)
    for i in range(h_num):
        top = i * stride
        if top + image_size > height:
            continue
        for j in range(w_num):
            left = j * stride
            if left + image_size > width:
                continue
            # Slicing the first two axes works for both 2-D and 3-D images.
            tile = img[top:top + image_size, left:left + image_size]
            out_path = os.path.join(savePath, stem + "_img_" + str(i * w_num + j) + '.png')
            if not cv.imwrite(out_path, tile):
                raise IOError("cv.imwrite failed for {}".format(out_path))


def makeTIFF(tif_dir, tif_saveDir):
    """Convert a multi-band raster into a 3-band uint8 GeoTIFF.

    Reads ``tif_dir``, keeps and stretches its first three bands with
    ``tif_bands_math`` and writes the result to ``tif_saveDir`` with the source
    projection and geotransform.

    :param tif_dir: path of the source raster
    :param tif_saveDir: path of the GeoTIFF to write
    :return: None
    """
    pixels, projection, geotransform = read_gdal(tif_dir)
    print("读取的tif数据为:", pixels.shape)
    write_gdal(tif_bands_math(pixels), tif_saveDir, projection, geotransform)


if __name__ == '__main__':
    # Single-image pipeline: convert images\image_2.tif and labels\label_2.tif to
    # PNG (into each folder's "png" sub-folder), then cut each PNG into
    # image_size x image_size tiles under dataset_path\images and dataset_path\labels.
    # Optional preprocessing (disabled): makeTIFF(src_tif, dst_tif) reduces a
    # multi-band .tif to 3 bands first.
    # Keep paths free of Chinese characters for the tif -> png conversion.
    path = r'E:\New_gj\dataset\T_multi_re_0.1\predict'  # input root
    dataset_path = r'E:\New_gj\dataset\T_multi_re_0.1\predict\cut'  # output root
    image_size = 512  # tile edge length in pixels
    stride = 512  # window step; equal to image_size -> non-overlapping tiles

    for sub_dir, stem in ((r'\images', 'image_2'), (r'\labels', 'label_2')):
        tif_file = path + sub_dir + '\\' + stem + '.tif'
        png_file = path + sub_dir + r'\png' + '\\' + stem + '.png'
        single_tif2pngORjpg(tif_file, png_file)
        cutImg_Slide(png_file, dataset_path + sub_dir, image_size, stride)