文章目录

  • 介绍
  • 制作数据集
  • 将信号保存成STFT图片
  • 划分数据集
  • 训练模型
  • 下载模板
  • 自定义数据加载器
  • 修改配置文件
  • 训练模型
  • 测试模型
  • 开启tensorboard监控


介绍

什么是时域信号?什么是频域信号?那么什么又是时频图?有什么区别,优缺点又如何

时域信号:

对于振动信号来说,时域信号是随着时间变化的振动幅度或者加速度,对于振动系统的变化有着最直接的表现

时域信号图片

时频图 java 时频图含义_经验分享

频域信号 :

由于振动系统发出的信号是经过很多振动源叠加的信号,那么时域信号就会难以分析。
离散傅里叶变换(DFT),是将离散的时域信号转换为频域信号,其采样率和采样点数决定了频域的分辨率,分辨率也就是说能够分辨信号频率的精度最高是多少,或者也可以用能识别的频率最小间隔来描述。

相应的频域信号图片,大小通常表示振幅

时频图 java 时频图含义_经验分享_02

时频域图像:

经过离散傅里叶变换,我们可以很好的观察到信号频率分布,但是频率信号丢失了时间信息,这是它最大的缺点。
短时傅里叶变换(STFT)就可以将时域信号转换为时频域信号,时频域信号能够反应信号等频率随着时间的变化。通常这也可以看系统是否稳定。

短时傅里叶变换:

1.分祯
将信号进行分祯,通常还会重叠一部分分割的地方。
2.加窗
当分祯取的信号不在一个周期时,会发生频谱泄露,出现不存在的频率,那么需要加窗平滑信号首位部分,让首尾是连续的,或者变化不那么大,常用的有汉明窗。
3.DFT计算
对每一个片段信号做DFT计算得到频谱信息。

短时傅里叶变换图像(时频图像)

时频图 java 时频图含义_时频图 java_03

从这副图可以看出,时域信号是较为稳定的。

制作数据集

将信号保存成STFT图片

我的数据集是由txt构成的,由于数据比较少,采用的是随机重复采样,一次截取20000个点,采样率为20000。数据集的标签体现在路径上。STFT的计算可以采scripy.signal的stft函数

def loadSingleFileWithReapetSampling(SampleNum:int,path:str,SampleLen=20000,SampleLate = 20000):

    signal,all_lenght,m =readfile(path)
    for i in range(SampleNum):
        random_start= np.random.randint(low=0,high=all_lenght - SampleLen)
        sliceSignal = signal[random_start:random_start+SampleLen]
        data_stft = STFT.GetFrequencyFeature4(sliceSignal,SampleLate) 
        newPath = path.split("/")
        newPath[-1] = "index-{}.jpg".format(i) #重构文件名
        newPath =  os.path.join(*newPath) #将列表展开成参数
        plt.imsave(newPath,data_stft)
loadSingleFileWithReapetSampling(5000, "data2/0/0hp-70%rpm/channel-0/xxx.txt")

划分数据集

这里采用把数据分别放到train目录和test目录下,其他文件结构不变。

import os
import glob
import numpy as np
from shutil import *
from torch.utils.data import random_split, Subset
for root,cDir,filename in os.walk("data2\\"):
    # print(root)
    if len(root.split("\\"))==4:
        paths = glob.glob(root+"\\*.jpg") #按匹配符加载文件路径,巨好用
        dataSize = len(paths)
        trainSize = int(0.8*dataSize)
        train_set = Subset(paths, range(trainSize)) #也可以用切片
        test_set = Subset(paths, range(trainSize, dataSize))
        for index in test_set:
            newPath = "data3\\test\\"+os.path.join(*index.split("\\")[1:-1])
            if not os.path.exists(newPath):
                os.makedirs(newPath)
            print("newPath:",newPath)
            print("oldPath:",index)
            copyfile(index,newPath+"\\"+os.path.basename(index))
        for index in train_set:
            newPath = "data3\\train\\"+os.path.join(*index.split("\\")[1:-1])
            if not os.path.exists(newPath):
                os.makedirs(newPath)

训练模型

下载模板

这里采用了模板框架github地址:github.com

直接使用git clone https://github.com/victoresque/pytorch-template拷贝到本地

得到结构如下,然后data3为数据集

时频图 java 时频图含义_机器学习_04

自定义数据加载器

因为数据集是自己的所以要自定义加载器(data_loader文件夹下):

class JJDataLoader(BaseDataLoader):
    #参数在config.json中配置
    def __init__(self,data_dir, batch_size, shuffle=True, validation_split=0.0, num_workers=1, training=True):
        trans = transforms.Compose([
            transforms.ToTensor()
        ])
        self.data_dir = data_dir
        self.dataset = JJDataSet(data_dir,trans)
        super().__init__(self.dataset, batch_size, shuffle, validation_split, num_workers)
class JJDataSet(Dataset):
    def __init__(self,dir:str,transforms):
        self.img_path = glob.glob(dir+"/*/*/*/*.jpg")
        print("len:",len(self.img_path))
        self.img_label =  np.array([index.split("/")[-4] for index in self.img_path])
        if transforms is not None:
            self.tramsforms= transforms
        else:
            self.tramsforms = None
    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('L')
        if self.tramsforms is not None:
            img = self.tramsforms(img)
        lbl = int(self.img_label[index])
        # print("lbl:",lbl)
        # print("path:",self.img_path[index])
        return img,torch.tensor(lbl)
    def __len__(self):
        return len(self.img_path)

修改配置文件

主要是修改加载器名称和训练集路径

时频图 java 时频图含义_机器学习_05

训练模型

注意训练的gpu数量和测试的数量一定要一致

python train.py -c config.json --resume saved/models/xxxx/model_best.pth

测试模型

这里主要修改的是配置文件下的数据集目录

时频图 java 时频图含义_经验分享_06

开启tensorboard监控

可以很好的看到训练精确度的变化和损失值的变化

tensorboard --logdir saved/log/

时频图 java 时频图含义_pytorch_07