1. 简介

TensorSpace 是一套用于构建神经网络3D可视化应用的框架。开发者可以使用类 Keras 风格的 TensorSpace API, 轻松创建可视化网络、加载神经网络模型并在浏览器中基于已加载的模型进行3D可交互呈现。通过使用 TensorSpace,可以更直观的观察并理解基于 TensorFlow、Keras 或者 TensorFlow.js 等开发的神经网络模型。

使用cnn进行水果图片识别_2d

图1 ensorSpace 使用开发流程

2. 项目环境要求

TensorSpace官网给出的适配TensorSpace.js 浏览器有:

浏览器

版本号

Chrome

64+

Firefox

58+

Safari

12+

TensorSpace.js 需要以下依赖库:

名称

版本号

TensorFlow.js

1.0.0+

Three.js

r101+

Tween.js

17.2.0+

3. 安装

  1. 第一步:下载依赖库

依赖库文件

备注

TensorFlow.js

tf.min.js

Three.js

three.min.js

Tween.js

tween.min.js

TrackballControls

TrackballControls.js

  1. 第二步:下载 TensorSpace
    下载链接: Github.
<!-- 将”VERSION”替换成需要的版本 -->
<script src="tensorspace.min.js"></script>
  1. 第三步:在页面中引入库文件
<script src="tf.min.js"></script>
<script src="three.min.js"></script>
<script src="Tween.min.js"></script>
<script src="TrackballControls.js"></script>
<script src="tensorspace.min.js"></script>

4. 使用

  1. 第一步:模型预处理
  • 在我之前写的博客基于Fruits-360数据集构建CNN进行水果识别实验中,我们已经完成训练并得到了的.h5文件,接下来我们要做的就是将神经网络模型通过一系列过程转换至 TensorSpace 可以使用的格式,而这一过程就被称为模型预处理

图2 模型的预处理

  • 这里我们使用到的是一个名为TensorSpace-Converter的模型转换工具,它可以帮助我们快速完成 TensorSpace 预处理过程。由于我之前的实验使用的是Keras,并用它训练得到了一个Keras模型,并且其模型结构和权重保存在一个HDF5文件里,所以我们编写一个bash脚本来进行模型转化。
// An highlighted block
#!/usr/bin/env bash
tensorspacejs_converter \
    --input_model_from="keras" \
    --input_model_format="topology_weights_combined" \
    --output_layer_names="conv2d_1,max_pooling2d_1,conv2d_2,max_pooling2d_2,conv2d_3,max_pooling2d_3,conv2d_4,max_pooling2d_4,flatten_1,dense_1,dense_2" \
    ./model/model_demo.h5 \
    ./convertedModel\
  • 以上 TensorSpace-Converter预处理脚本将会在 convertedModel 文件夹中生成经过预处理的模型:
    (1)一份 model.json 文件:包含所得到的模型结构信息(包括中间层输出)。
    (2)一些权重文件:包含模型训练所得到的权重信息。权重文件的数量取决于模型的结构。

图3 模型转后的结果

使用cnn进行水果图片识别_可视化_02

图4 将模型 Layer 名取出并设置 output_layer_names

  • 我的实验中构建的神经网络模型结构如下所示:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_1 (Conv2D)            (None, 64, 64, 16)        448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 32, 32, 16)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 32)        4640      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 16, 16, 32)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 32)        9248      
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 8, 8, 32)          0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 8, 8, 64)          18496     
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 4, 4, 64)          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 1024)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               262400    
_________________________________________________________________
dense_2 (Dense)              (None, 131)               33667     
=================================================================
Total params: 328,899
Trainable params: 328,899
Non-trainable params: 0
_________________________________________________________________
  1. 第二步:使用 TensorSpace 可视化模型
  • 载入并可视化
    通过 TensorSpace API 构建 TensorSpace 可视化模型。
let modelContainer = document.getElementById("container");
            let model = new TSP.models.Sequential( modelContainer );

    		model.add( new TSP.layers.RGBInput({ shape: [64, 64, 3] }) );
    		model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 16, strides: 1 }));
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
    		model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 32, strides: 1 }) );
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
            model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 32, strides: 1 }) );
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
    		model.add( new TSP.layers.Conv2d({ kernelSize: 3, filters: 64, strides: 1 }) );
    		model.add( new TSP.layers.Pooling2d({ poolSize: [2, 2], strides: [2, 2] }) );
    		model.add( new TSP.layers.Dense({ units: 256 }) );
    		model.add( new TSP.layers.Dense({ units: 131 }) );
    		model.add( new TSP.layers.Output1d({
                units :  131 ,
    			outputs: ['Apple Braeburn', 'Apple Crimson Snow', 'Apple Golden 1', 'Apple Golden 2', 'Apple Golden 3',
                    'Apple Granny Smith', 'Apple Pink Lady', 'Apple Red 1', 'Apple Red 2', 'Apple Red 3',
                    'Apple Red Delicious', 'Apple Red Yellow 1', 'Apple Red Yellow 2', 'Apricot', 'Avocado',
                    'Avocado ripe', 'Banana', 'Banana Lady Finger', 'Banana Red', 'Beetroot',
                    'Blueberry', 'Cactus fruit', 'Cantaloupe 1', 'Cantaloupe 2', 'Carambula',
                    'Cauliflower', 'Cherry 1', 'Cherry 2', 'Cherry Rainier', 'Cherry Wax Black',
                    'Cherry Wax Red', 'Cherry Wax Yellow', 'Chestnut', 'Clementine', 'Cocos',
                    'Corn', 'Corn Husk', 'Cucumber Ripe', 'Cucumber Ripe 2', 'Dates',
                    'Eggplant', 'Fig', 'Ginger Root', 'Granadilla', 'Grape Blue',
                    'Grape Pink', 'Grape White', 'Grape White 2', 'Grape White 3', 'Grape White 4',
                    'Grapefruit Pink', 'Grapefruit White', 'Guava', 'Hazelnut', 'Huckleberry',
                    'Kaki', 'Kiwi', 'Kohlrabi', 'Kumquats', 'Lemon',
                    'Lemon Meyer', 'Limes', 'Lychee', 'Mandarine',
                    'Mango', 'Mango Red', 'Mangostan', 'Maracuja', 'Melon Piel de Sapo',
                    'Mulberry', 'Nectarine', 'Nectarine Flat', 'Nut Forest', 'Nut Pecan',
                    'Onion Red', 'Onion Red Peeled', 'Onion White', 'Orange', 'Papaya',
                    'Passion Fruit', 'Peach', 'Peach 2', 'Peach Flat', 'Pear',
                    'Pear 2', 'Pear Abate', 'Pear Forelle', 'Pear Kaiser', 'Pear Monster',
                    'Pear Red', 'Pear Stone', 'Pear Williams', 'Pepino', 'Pepper Green',
                    'Pepper Orange', 'Pepper Red', 'Pepper Yellow', 'Physalis', 'Physalis with Husk',
                    'Pineapple', 'Pineapple Mini', 'Pitahaya Red', 'Plum', 'Plum 2',
                    'Plum 3', 'Pomegranate', 'Pomelo Sweetie', 'Potato Red', 'Potato Red Washed',
                    'Potato Sweet', 'Potato White', 'Quince', 'Rambutan', 'Raspberry',
                    'Redcurrant', 'Salak', 'Strawberry', 'Strawberry Wedge', 'Tamarillo',
                    'Tangelo', 'Tomato 1', 'Tomato 2', 'Tomato 3', 'Tomato 4',
                    'Tomato Cherry Red', 'Tomato Heart', 'Tomato Maroon', 'Tomato not Ripened', 'Tomato Yellow',
                    'Walnut', 'Watermelon']

    		}) );
  • 载入经过 TensorSpace-Converter 预处理的模型,然后将模型进行初始化:
model.load({
    			type: "tfjs",
    			url: "convertedModel/model.json",
    			onComplete: function() {
    				console.log( "\"Hello World!\" from TensorSpace Loader." );
    			}
    		});

    		model.init( function() {
                $.ajax({
    				url: "json/banana_107_100.json",
    				type: 'GET',
    				async: true,
    				dataType: 'json',
    				success: function (d) {
    					model.predict( d);
    					console.log( d);
    				}
    			});
            } );
  1. 可视化结果展示
    展示结果(以banana图像为例)可进行拖拽、展开、旋转,可以详细地了解到神经网络的每一层。

图5 可视化展示1

使用cnn进行水果图片识别_使用cnn进行水果图片识别_03

图6 可视化展示2

使用cnn进行水果图片识别_神经网络_04

图7 可视化展示3

使用cnn进行水果图片识别_深度学习_05

图8 可视化展示4

使用cnn进行水果图片识别_使用cnn进行水果图片识别_06

图9 可视化展示5

最后一张图(图9)展示的即为我们的TensorSpace可视化最终的预测结果,我们的输入为banana的一张图片(已转化为banana_107_100.json),可以看到最终的Output层将其输出为banana这一分类。

5. 结语

通过基于 TensorSpace 所开发的3D可视化神经网络模型实例,我们可以体验不同的可交互模型,包括但不限于:物体分类、物体探测、图片生成等。通过展示这些模型实例,我们能更好、更直观地体现 TensorSpace 的应用场景、操作方法以及展示效果。