医学图像分类简介

Copyright © 2022 Institute for Quantum Computing, Baidu Inc. All Rights Reserved.

医学图像分类(Medical image classification)是计算机辅助诊断系统的关键技术。医学图像分类问题主要是如何从图像中提取特征并进行分类,从而识别和了解人体的哪些部位受到特定疾病的影响。在这里我们主要使用量子神经网络对公开数据集 MedMNIST 中的胸腔数据进行分类。其中数据形式如下图所示:

医学图像分类常用数据集 医学图像分类模型_人工智能

使用 QNNMIC 模型进行医学图像分类

QNNMIC 模型简介

QNNMIC 模型是一个可以用于医学图像分类的量子机器学习模型(Quantum Machine Learning,QML)。我们具体称其为一种量子神经网络 (Quantum Neural Network, QNN),它结合了参数化量子电路(Parameterized Quantum Circuit, PQC)和经典神经网络。对于医学图像数据,QNNMIC 可以达到 85% 以上的分类准确率。模型主要分为量子和经典两部分,结构图如下:

医学图像分类常用数据集 医学图像分类模型_数据_02

注:

  • 通常我们使用主成分分析将图片数据进行降维处理,使其更容易通过编码电路将经典数据编码为量子态。
  • 参数化电路的作用是特征提取,其电路参数可以在训练中调整。
  • 量子测量由一组测量算子表示,是将量子态转化为经典数据的过程,我们可以对得到的经典数据做进一步处理。

如何使用

使用模型进行预测

这里,我们已经给出了一个训练好的模型,可以直接用于医学图片的预测。只需要在 test.toml 这个配置文件中进行对应的配置,然后输入命令 python qnn_medical_image.py --config test.toml 即可使用训练好的医学图片分类模型对输入的图片进行测试。

在线演示

这里,我们给出一个在线演示的版本,可以在线进行测试。首先定义配置文件的内容对测试集中图片进行预测:

import toml

test_toml = """
# 模型的整体配置文件。
# 图片的文件路径。
image_path = 'pneumoniamnist'

# 训练集中的数据个数,默认值为 -1 即使用全部数据。
num_samples = 20

# 训练好的模型参数文件的文件路径。
model_path = 'qnnmic.pdparams'

# 量子电路所包含的量子比特的数量。
num_qubits = [8, 8]

# 每一层量子电路中的电路深度。
num_depths = [2, 2]

# 量子电路中可观测量的设置。
observables = [['Z0', 'Z1', 'Z2', 'Z3'], ['X0', 'X1', 'X2', 'X3']]
"""

config = toml.loads(test_toml)
# 首先,我们需要在 AI Studio 中安装量桨环境。

!mkdir /home/aistudio/external-libraries
%pip install paddle-quantum -t /home/aistudio/external-libraries
import sys
sys.path.append('/home/aistudio/external-libraries')
from paddle_quantum.qml.qnnmic import inference

prediction, prob, label = inference(**config)
print(f"图片的预测结果分别为 {str(prediction)[1:-1]}")
print(f"图片的实际标签分别为 {str(label)[1:-1]}")
图片的预测结果分别为 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0
图片的实际标签分别为 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0

其中标签 0 代表肺部异常,标签 1 代表正常。

test_toml 配置文件中:

  • model_path: 为训练好的模型,这里固定为 qnnmic.pdparams
  • num_qubitsnum_depthsobservables 三个参数应与训练好的模型 qnnmic.pdparams 相匹配。num_qubits = [8,8] 表示量子部分一共两层电路;每层电路为 8 的量子比特;num_depths = [2,2] 表示每层参数化电路深度为 2;observables 表示每层测量算子的具体形式。

对于数据集中的某张肺部异常的图片:

医学图像分类常用数据集 医学图像分类模型_数据_03

# 使用模型进行预测并得到对应概率值
msg = f'对于上述输入的图片,模型有 {prob[10][1]:3.2%} 的置信度检测出肺部异常。'
print(msg)
对于上述输入的图片,模型有 98.30% 的置信度检测出肺部异常。

常。

注意事项

我们通常考虑调整 num_qubitsnum_depthsobservables 三个主要超参数,对模型的影响较大。

引用信息

@article{medmnistv2,
    title={MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification},
    author={Yang, Jiancheng and Shi, Rui and Wei, Donglai and Liu, Zequan and Zhao, Lin and Ke, Bilian and Pfister, Hanspeter and Ni, Bingbing},
    journal={arXiv preprint arXiv:2110.14795},
    year={2021}
}

参考文献

[1] Yang, Jiancheng, et al. “Medmnist v2: A large-scale lightweight benchmark for 2d and 3d biomedical image classification.” arXiv preprint arXiv:2110.14795 (2021).