目录
- Tensor概述
- Tensor创建
- 常用属性和方法
- 修改形状
- 索引操作
- 广播机制
- 逐元素操作
- 归并操作
- 比较操作
- 矩阵操作
Tensor概述
Pytorch的数据结构用Tensor表示,它与Numpy相似,二者可以共享内存,且之间的转换非常方便和高效。不过它们也有不同之处,最大的区别就是Numpy会把ndarray放在CPU中进行加速运算,而由Torch产生的Tensor会放在GPU中进行加速运算。
从接口可分为两类:
- torch.function()
- tensor.function()
从修改方式分成两类:
- 不修改自身数据,返回新数据:x.add(y)
- 修改自身数据:x.add_(y)
Tensor创建
import torch
import numpy as np
#创建tensor
#1、直接传入数据
a_1 = torch.tensor([1,2,3])
a_2 = torch.tensor(np.array([1,2,3]))
a_3 = torch.tensor((1,2,3))
# 指定形状
b_1 = torch.Tensor(4,2)
#内建函数创建
temp = torch.Tensor(3,3)
c_1 = torch.eye(2,2) #单位矩阵
c_2 = torch.zeros(2,2) #零矩阵
c_3 = torch.ones(2,2) #全1矩阵
c_4 = torch.zeros_like(temp) #形状和temp相同的零矩阵
c_5 = torch.ones_like(temp) #形状和temp相同的全1矩阵
c_6 = torch.linspace(0,10,3) #等差矩阵,分成3块
c_7 = torch.logspace(0,10,3) #等差矩阵,分成3块,但终点值为1^end
c_8 = torch.rand(2,2) #0~1分布
c_9 = torch.randn(2,2) #正太分布
c_10 = torch.arange(0,10,3) #等差矩阵,设定步长
c_11 = torch.from_numpy(np.array([[1,2,3],[4,5,6]])) #将ndarray转
为tensor
c_11.numpy() #tensor转numpy数组
"""
torch.Tensor和torch.tensor的区别:
1、Tensor使用全局默认的dtype,即float。而tensor会根据数据自动推断类型
2、tensor(1)表示值为1的tensor,Tensor(1)表示大小为1的tensor,其初始值随机初始化
"""
常用属性和方法
#tensor对象常用属性、方法
#1、shape size()
c_1.shape == c_1.size()
#2、type() dtype
c_1.type()
c_1.dtype
#3、numel() 元素的个数
c_1.grad #grad用于存储累加的梯度值
c_1.grad_fn # grad_fn 确定该节点是否是通过运算得到的非叶子节点,其值表示梯度函数对象
c_1.is_leaf # is_leaf 确定这个节点是否是叶子节点,返回bool类型
c_1.require_grad # require_grad确定是否要对这个节点求导,返回bool类型
修改形状
c_1.numel()
#4、resize() 修改形状 返回一个新的tensor
c_1.resize(1,4)
#5、reshape 修改形状 返回一个新的tensor
c_2.reshape(1,4)
#6、view 修改形状 返回一个新的tensor 参数-1表示展平数据
c_2.view(1,4)
#7、unsqueeze(0) 在0位置增加一个维度
c_1.unsqueeze(0)
#8、squeeze(0) 在0位置压缩一个维度
c_1.unsqueeze(0).squeeze(0)
#9、item() 将单元素的tensor转为python标量
torch.tensor(1).item()
索引操作
#1、获取某一行数据
c_1[0,:]
#2、获取某一列数据
c_1[:,1]
#3、index_select(dim,index)在指定维度上选择行或列 ,index是tensor类型参数
c = torch.Tensor(2,2,2)
c.index_select(1,torch.tensor(0))
#4、nonzero() 获取非零元素的下标
c_1.nonzero()
#5、masked_select(mask) 获取满足条件的所有值
mask = c_1 == 1
c_1.masked_select(mask)
#6、gather(dim,index) 在指定维度选择数据,index类型为long、大小和输出形状一样
c = torch.randn(2,3)
c.gather(1,torch.LongTensor([[0,1,2],[1,1,1]])) #按行
广播机制
x = torch.tensor([1,2,3]) #shape [3]
y = torch.tensor([0,0,1,1]).reshape(4,1) #shape[4,1]
x+y #shape [4,3]
#上下等效
x = x.unsqueeze(0) #shape [1,3]
x = x.expand(4,3)
y = y.expand(4,3)
x+y
逐元素操作
大部分数学运算都属于逐元素操作,其输入与输出的形状相同。
#1、加减乘除
x = torch.tensor([1.0,2.0,3.0])
y = torch.tensor([1.0,1.0,1.0])
x.add(y) #加
x.add(y.neg()) #减
x.mul(y) #乘
x.div(y) #除
# 2、绝对值、求整
x.abs()
x.ceil() #向上取整
x.floor() #向下取整
#3、指数、对数、幂、开根号、取符号
x.exp() #以e为底的指数
x.log()
x.pow(2)
x.sqrt()
x.sign() #正的为1,负的为-1
#4、混合运算
torch.addcdiv(y,2,x,y) #(x/y)*2 + y
torch.addcmul(y,2,x,y) #(x*y)*2 + y
#5、将张量元素限制在指定范围 clamp(min,max)
x.clamp(1.0,2.0)
#6、激活函数
x.sigmoid()
x.tanh()
归并操作
#归并操作顾名思义,就是对输入进行归并或合计等操作,这类操作的输入输出形状一般并不相同,而且往往是输入大于输出形状。
#一个参数是dim,另一个参数是keepdim,说明输出结果中是否保留维度1,缺省情况是False
x = torch.tensor([[1.0,2.0,3.0],[4,5,6]])
#1、指定维度进行累加
x.sum(dim=1) #0按列,1按行
#2、求均值、方差、标准值、中位数、众数
x.mean(dim=0)
x.var(dim=0)
x.std(dim=0)
x.median(dim=0)
x.norm(dim=0)
#3、指定维度进行累乘
x.prod(dim=0)
#4、求范数
x.norm(p=2)
x.norm(dim=0,p=2)
#5、不改变输出大小,累积求和、求积
x.cumsum(dim=0)
x.cumprod(dim=0)
比较操作
#比较操作一般是进行逐元素比较,有些是按指定方向比较。
x = torch.tensor([[ 1., 2., 3.],
[ 4., 10., 18.]])
#1、取最大值、最小值 ,若指定axis,则额外返回下标
x.min(0) # (tensor([ 1., 2., 3.]), tensor([ 0, 0, 0]))
x.max() # tensor(18.)
#2、在某维度上取最高的几个值及其对应索引
x.topk(2,1) # (tensor([[ 3., 2.],[ 18., 10.]]), tensor([[ 2, 1],[ 2, 1]]))
#3、判断是否相等(前提是类型相同)
x.equal(torch.tensor([1.,2,3])) # False
x.eq(1.) #支持广播机制 tensor([[ 1, 0, 0],[ 0, 0, 0]], dtype=torch.uint8) 返回大小和input相同,值为0或1,代表对应元素是否相同
#4、比较大小 支持广播机制
x.ge(1.) #大于等于
x.le(1.)
x.gt(1.) #大于
x.lt(1.)
矩阵操作
x = torch.randn(3,1)
y = torch.randn(3)
#1、转置
x.t()
#2、内积(点积)操作(1D)
y.dot(torch.randn(3))
#3、 矩阵乘法mm(2D) bmm(3D)
x.mm(torch.randn(1,3)) #shape(3,3)
#4、矩阵和向量乘法
x.resize(1,3).mv(y)
#5、SVD分解
torch.svd(torch.tensor([[ 1., 2., 3.],
[ 4., 10., 18.]]))
"""
1、Torch的dot与Numpy的dot有点不同,Torch中的dot是对两个为1D张量进行点积运算,Numpy中的dot无此限制。
2、mm是对2D的矩阵进行点积,bmm对含batch的3D进行点积运算。
3、转置运算会导致存储空间不连续,需要调用contiguous方法转为连续。 contiguous一般与transpose,permute,view搭配使用
"""