import os
import glob
import SimpleITK as sitk
import numpy as np
import vtk
from vtkmodules.util import numpy_support
import cv2

# Input folders: one holds the slice images, the other the matching masks
slice_image_folder = R'path_to_imgs'
mask_image_folder = R'path_to_masks'

# Collect every image path, sorted by name (slices are *.jpg, masks are *.png)
slice_image_files = glob.glob(os.path.join(slice_image_folder, '*.jpg'))
slice_image_files.sort()
mask_image_files = glob.glob(os.path.join(mask_image_folder, '*.png'))
mask_image_files.sort()


# Read the slice images and masks and stack each set into a 3D array
def _read_grayscale(path):
    """Read *path* as a grayscale image, raising ValueError if it cannot be read."""
    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {path}")
    return image


def _binarize_mask(mask):
    """Zero out pixels equal to 50, then threshold the mask to 0/255."""
    mask[mask == 50] = 0
    _, binary_mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
    return binary_mask


def read_images(slice_files, mask_files):
    """Load grayscale slices and masks and stack each set along a third axis.

    Each mask has its pixels equal to 50 reset to 0 and is then binarized
    with ``cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)``.

    Args:
        slice_files: Paths of the slice images, in stacking order.
        mask_files: Paths of the mask images, in stacking order.

    Returns:
        A tuple ``(slices_array, masks_array)``; both arrays are built with
        ``np.stack(..., axis=2)``, so the last axis indexes the input files.

    Raises:
        ValueError: If the two inputs differ in length, if they are empty,
            or if any file cannot be read as an image.
    """
    # Materialize first so that any iterable (including generators) can be
    # counted before the comparatively expensive image decoding starts.
    slice_paths = list(slice_files)
    mask_paths = list(mask_files)

    if len(slice_paths) != len(mask_paths):
        raise ValueError("Number of slice images and mask images must be the same.")
    if not slice_paths:
        # np.stack would fail on an empty list with a far less helpful message.
        raise ValueError("No slice or mask images were provided.")

    slices = [_read_grayscale(path) for path in slice_paths]
    masks = [_binarize_mask(_read_grayscale(path)) for path in mask_paths]

    slices_array = np.stack(slices, axis=2)
    masks_array = np.stack(masks, axis=2)
    return slices_array, masks_array


# Convert a numpy array into a SimpleITK image
def numpy_to_sitk(image_array):
    """Return the image that ``sitk.GetImageFromArray`` builds from *image_array*."""
    return sitk.GetImageFromArray(image_array)


# Helper: assemble one renderable volume from image data and transfer-function points
def _make_volume(image_data, color_points, opacity_points):
    """Build a ``vtkVolume`` for *image_data*.

    Args:
        image_data: The ``vtkImageData`` fed to the volume mapper.
        color_points: Iterable of ``(scalar, r, g, b)`` control points.
        opacity_points: Iterable of ``(scalar, opacity)`` control points.

    Returns:
        A ``vtkVolume`` with shading enabled and linear interpolation.
    """
    color_tf = vtk.vtkColorTransferFunction()
    for scalar, red, green, blue in color_points:
        color_tf.AddRGBPoint(scalar, red, green, blue)

    opacity_tf = vtk.vtkPiecewiseFunction()
    for scalar, alpha in opacity_points:
        opacity_tf.AddPoint(scalar, alpha)

    mapper = vtk.vtkSmartVolumeMapper()
    mapper.SetInputData(image_data)

    appearance = vtk.vtkVolumeProperty()
    appearance.SetColor(color_tf)
    appearance.SetScalarOpacity(opacity_tf)
    appearance.ShadeOn()
    appearance.SetInterpolationTypeToLinear()

    result = vtk.vtkVolume()
    result.SetMapper(mapper)
    result.SetProperty(appearance)
    return result


# 3D visualization
def plot_3d(image_array, mask_array):
    """Volume-render *image_array* and *mask_array* in one interactive window.

    The image volume maps scalar 0 to black/transparent and 255 to
    white/opaque. The mask volume maps scalar 0 to black/transparent and
    1 to red/opaque.

    NOTE(review): masks produced by ``read_images`` hold 0/255, which lies
    beyond the mask's last control point at 1; the result then depends on
    how VTK treats out-of-range scalars (presumably clamping) - confirm.

    Args:
        image_array: Array whose ``shape`` is used as the VTK dimensions.
        mask_array: Array whose ``shape`` is used as the VTK dimensions of
            the mask volume.

    Returns:
        None. Any ``Exception`` raised while building or showing the scene
        is caught and reported with ``print``.
    """
    try:
        # Image volume data; Fortran-order flattening makes the first
        # array axis the fastest-varying one in the VTK scalar array.
        image_data = vtk.vtkImageData()
        image_data.SetDimensions(image_array.shape)
        image_data.SetSpacing((1.0, 1.0, 1.0))
        image_scalars = numpy_support.numpy_to_vtk(
            num_array=image_array.ravel(order='F'),
            deep=True,
            array_type=vtk.VTK_FLOAT,
        )
        image_data.GetPointData().SetScalars(image_scalars)

        # Mask volume data, flattened the same way as the image.
        mask_scalars = numpy_support.numpy_to_vtk(
            num_array=mask_array.ravel(order='F'),
            deep=True,
            array_type=vtk.VTK_FLOAT,
        )
        mask_data = vtk.vtkImageData()
        mask_data.SetDimensions(mask_array.shape)
        mask_data.SetSpacing((1.0, 1.0, 1.0))
        mask_data.GetPointData().SetScalars(mask_scalars)

        # Scene: grayscale image volume first, then the red mask volume.
        renderer = vtk.vtkRenderer()
        renderer.AddVolume(_make_volume(
            image_data,
            color_points=((0, 0.0, 0.0, 0.0), (255, 1.0, 1.0, 1.0)),
            opacity_points=((0, 0.0), (255, 1.0)),
        ))
        renderer.AddVolume(_make_volume(
            mask_data,
            color_points=((0, 0.0, 0.0, 0.0), (1, 1.0, 0.0, 0.0)),
            opacity_points=((0, 0.0), (1, 1.0)),
        ))

        # Render window plus interactor
        window = vtk.vtkRenderWindow()
        window.AddRenderer(renderer)

        interactor = vtk.vtkRenderWindowInteractor()
        interactor.SetRenderWindow(window)

        # Draw once, then hand control to the interactor
        window.Render()
        interactor.Start()
    except Exception as exc:
        print(f"An error occurred: {exc}")


# Entry point
def main():
    """Read the configured slice/mask files and show them in 3D.

    Any ``Exception`` raised along the way is caught and reported with
    ``print`` instead of propagating.
    """
    try:
        plot_3d(*read_images(slice_image_files, mask_image_files))
    except Exception as exc:
        print(f"An error occurred in main function: {exc}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()