文章目录

  • 前言
  • 一、tensor打印配置
  • 二、numpy读取csv
  • 三、python内库读取csv
  • 四、numpy->tensor
  • 五、连续值 序数值 分类值
  • 六、tensor切分及类型转换
  • 七、独热编码
  • 八、规约(归一化)
  • 九、寻找阈值
  • 总结



前言

案例代码https://github.com/2012Netsky/pytorch_cnn/blob/main/3_tabular_wine.ipynb

一、tensor打印配置

#!/usr/bin/env python
# coding: utf-8
import numpy as np
import torch
torch.set_printoptions(edgeitems=2, threshold=50)
# precision是每一个元素的输出精度,默认是八位;
# threshold是输出时的阈值,当tensor中元素的个数大于该值时,进行缩略输出,默认时1000;
# edgeitems是输出的维度,默认是3;
# linewidth字面意思,每一行输出的长度;
# profile=None,修正默认设置(不太懂,感兴趣的可以试试)

pytorch 打印模型权重 pytorch如何打印tensor的值_pytorch 打印模型权重

二、numpy读取csv

# 处理csv文件的方法 python内库 numpy pandas(速度最快) 简单起见使用内置模块
# 使用它csv库读取csv文件特定内容
# 用于检查数据
import csv
wine_path = "../data/p1ch4/tabular-wine/winequality-white.csv"
# skiprows=1 跳过第一行
wineq_numpy = np.loadtxt(wine_path, dtype=np.float32, delimiter=";",
                         skiprows=1)
print(wineq_numpy.shape)
print(wineq_numpy)

pytorch 打印模型权重 pytorch如何打印tensor的值_python_02

三、python内库读取csv

# 使用python内库读取csv文件
# python open() 函数用于打开一个文件,创建一个 file 对象,相关的方法才可以调用它进行读写。
# csv.reader()返回一个reader对象,利用该对象遍历csv文件中的行。
# next() 返回迭代器的下一个项目。next() 函数要和生成迭代器的 iter() 函数一起使用。

# 检查是否读取了所有数据
col_list = next(csv.reader(open(wine_path), delimiter=';'))

print(open(wine_path))
print(csv.reader(open(wine_path), delimiter=';'))
print(col_list)
print(wineq_numpy.shape)

pytorch 打印模型权重 pytorch如何打印tensor的值_ai_03

四、numpy->tensor

# csv->numpy->tensor
# 前十一列为指标 最后一列表示质量评分的一列
wineq = torch.from_numpy(wineq_numpy)

wineq.shape, wineq.dtype

pytorch 打印模型权重 pytorch如何打印tensor的值_pytorch_04

五、连续值 序数值 分类值

# 连续值 序数值 分类值


# tensor切片
# 前11列
# 过滤掉最后一列   :-1   前面所有 后面第一列 索引从0开始
data = wineq[:, :-1] # <1>
data, data.shape

pytorch 打印模型权重 pytorch如何打印tensor的值_ai_05

六、tensor切分及类型转换

# tensor切片
# 第十二列
# 行所有 倒数最后一列
target = wineq[:, -1] # <2>
target, target.shape

# 打印最后一列
# 数据类型转换
target = wineq[:, -1].long()
target

pytorch 打印模型权重 pytorch如何打印tensor的值_python_06

七、独热编码

# 最后一列最为标签 需要编码
# 独热编码
target_onehot = torch.zeros(target.shape[0], 10)

print(target_onehot.shape)
#print(target_onehot)
print(target.unsqueeze(1).shape)
print((target_onehot.scatter_(1, target.unsqueeze(1), 1.0)).shape)
#print(target_onehot.scatter_(1, target.unsqueeze(1), 1.0))

pytorch 打印模型权重 pytorch如何打印tensor的值_ai_07

八、规约(归一化)

# 获取每列平均值 即每个指标平均值
data_mean = torch.mean(data, dim=0)
data_mean

# 获取每列标准差 即每个指标标准差
data_var = torch.var(data, dim=0)
data_var

# 规约(归一化)
data_normalized = (data - data_mean) / torch.sqrt(data_var)
data_normalized

pytorch 打印模型权重 pytorch如何打印tensor的值_ai_08

九、寻找阈值

# 寻找阈值
bad_indexes = target <= 3 # <1>
# 输出为布尔型
bad_indexes.shape, bad_indexes.dtype, bad_indexes.sum()


# tensor.sum()
# torch.sum()
# 按照指定维度求和
print(bad_indexes.shape)
print(bad_indexes)

bad_data = data[bad_indexes]
bad_data.shape

bad_data = data[target <= 3]
mid_data = data[(target > 3) & (target < 7)] # <1>
good_data = data[target >= 7]
print(data.shape)
print(bad_data.shape)
print(mid_data.shape)
print(good_data.shape)
print(good_data[0])

bad_mean = torch.mean(bad_data, dim=0)
mid_mean = torch.mean(mid_data, dim=0)
good_mean = torch.mean(good_data, dim=0)

print(bad_mean.shape)
print(mid_mean.shape)
print(good_mean.shape)
print(good_mean[0])

for i, args in enumerate(zip(col_list, bad_mean, mid_mean, good_mean)):
    print('{:2} {:20} {:6.2f} {:6.2f} {:6.2f}'.format(i, *args))


total_sulfur_threshold = 141.83
# 行所有 第六列
total_sulfur_data = data[:,6]
print(data.shape, total_sulfur_data.shape)

# 判断矩阵
predicted_indexes = torch.lt(total_sulfur_data, total_sulfur_threshold)

# ture索引之和
predicted_indexes.shape, predicted_indexes.dtype, predicted_indexes.sum()

actual_indexes = target > 5

actual_indexes.shape, actual_indexes.dtype, actual_indexes.sum()

n_matches = torch.sum(actual_indexes & predicted_indexes).item()
n_predicted = torch.sum(predicted_indexes).item()
n_actual = torch.sum(actual_indexes).item()

n_matches, n_matches / n_predicted, n_matches / n_actual

pytorch 打印模型权重 pytorch如何打印tensor的值_pytorch_09


pytorch 打印模型权重 pytorch如何打印tensor的值_pytorch_10


pytorch 打印模型权重 pytorch如何打印tensor的值_python_11

总结