医学图像分类简介
Copyright © 2022 Institute for Quantum Computing, Baidu Inc. All Rights Reserved.
医学图像分类(Medical image classification)是计算机辅助诊断系统的关键技术。医学图像分类问题主要是如何从图像中提取特征并进行分类,从而识别和了解人体的哪些部位受到特定疾病的影响。在这里我们主要使用量子神经网络对公开数据集 MedMNIST 中的胸腔数据进行分类。其中数据形式如下图所示:
使用 QNNMIC 模型进行医学图像分类
QNNMIC 模型简介
QNNMIC 模型是一个可以用于医学图像分类的量子机器学习模型(Quantum Machine Learning,QML)。我们具体称其为一种量子神经网络 (Quantum Neural Network, QNN),它结合了参数化量子电路(Parameterized Quantum Circuit, PQC)和经典神经网络。对于医学图像数据,QNNMIC 可以达到 85% 以上的分类准确率。模型主要分为量子和经典两部分,结构图如下:
注:
- 通常我们使用主成分分析将图片数据进行降维处理,使其更容易通过编码电路将经典数据编码为量子态。
- 参数化电路的作用是特征提取,其电路参数可以在训练中调整。
- 量子测量由一组测量算子表示,是将量子态转化为经典数据的过程,我们可以对得到的经典数据做进一步处理。
如何使用
使用模型进行预测
这里,我们已经给出了一个训练好的模型,可以直接用于医学图片的预测。只需要在 test.toml
这个配置文件中进行对应的配置,然后输入命令 python qnn_medical_image.py --config test.toml
即可使用训练好的医学图片分类模型对输入的图片进行测试。
在线演示
这里,我们给出一个在线演示的版本,可以在线进行测试。首先定义配置文件的内容对测试集中图片进行预测:
import toml
test_toml = """
# 模型的整体配置文件。
# 图片的文件路径。
image_path = 'pneumoniamnist'
# 训练集中的数据个数,默认值为 -1 即使用全部数据。
num_samples = 20
# 训练好的模型参数文件的文件路径。
model_path = 'qnnmic.pdparams'
# 量子电路所包含的量子比特的数量。
num_qubits = [8, 8]
# 每一层量子电路中的电路深度。
num_depths = [2, 2]
# 量子电路中可观测量的设置。
observables = [['Z0', 'Z1', 'Z2', 'Z3'], ['X0', 'X1', 'X2', 'X3']]
"""
config = toml.loads(test_toml)
# 首先,我们需要在 AI Studio 中安装量桨环境。
!mkdir /home/aistudio/external-libraries
%pip install paddle-quantum -t /home/aistudio/external-libraries
import sys
sys.path.append('/home/aistudio/external-libraries')
from paddle_quantum.qml.qnnmic import inference
prediction, prob, label = inference(**config)
print(f"图片的预测结果分别为 {str(prediction)[1:-1]}")
print(f"图片的实际标签分别为 {str(label)[1:-1]}")
图片的预测结果分别为 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0
图片的实际标签分别为 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0
其中标签 0 代表肺部异常,标签 1 代表正常。
在 test_toml
配置文件中:
-
model_path
: 为训练好的模型,这里固定为qnnmic.pdparams
; -
num_qubits
、num_depths
、observables
三个参数应与训练好的模型qnnmic.pdparams
相匹配。num_qubits = [8,8]
表示量子部分一共两层电路;每层电路为 8 的量子比特;num_depths = [2,2]
表示每层参数化电路深度为 2;observables
表示每层测量算子的具体形式。
对于数据集中的某张肺部异常的图片:
# 使用模型进行预测并得到对应概率值
msg = f'对于上述输入的图片,模型有 {prob[10][1]:3.2%} 的置信度检测出肺部异常。'
print(msg)
对于上述输入的图片,模型有 98.30% 的置信度检测出肺部异常。
常。
注意事项
我们通常考虑调整 num_qubits
,num_depths
,observables
三个主要超参数,对模型的影响较大。
引用信息
@article{medmnistv2,
title={MedMNIST v2: A Large-Scale Lightweight Benchmark for 2D and 3D Biomedical Image Classification},
author={Yang, Jiancheng and Shi, Rui and Wei, Donglai and Liu, Zequan and Zhao, Lin and Ke, Bilian and Pfister, Hanspeter and Ni, Bingbing},
journal={arXiv preprint arXiv:2110.14795},
year={2021}
}
参考文献
[1] Yang, Jiancheng, et al. “Medmnist v2: A large-scale lightweight benchmark for 2d and 3d biomedical image classification.” arXiv preprint arXiv:2110.14795 (2021).