Tensor的索引与变形

先行变量:a = torch.Tensor([[0,1],[2,3]])

  1. 根据下标进行索引(类似数组索引)
    例:a[0], a[0,1]
  2. 设置条件对tensor内的元素进行判断,符合条件的置True,否则置False
    例:b = a > 1
  3. 选择符合条件的元素并返回
    例:a[a > 1]
  4. 选择非0元素的坐标,并返回
    torch.nonzero()
  5. 满足condition的位置输出x,否则输出y
    torch.where(condition, torch.full_like(input, x), y)
  6. 限制Tensor元素在[x, y]范围内,小于x的元素被设置成x,大于y的元素被设置成y,其余的不变
    input.clamp(x, y)
  7. view()、resize()和reshape()函数可以在不改变Tensor数据的前提下任意改变Tensor的形状,必须保证调整前后的元素总数相同,并且调整前后共享内存,三者的作用基本相同。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    import torch
    a = torch.Tensor([[0,1],[2,3]])
    print(a.view(1,4))
    print(a.resize(2,2))
    print(a.reshape(4,1))

    '''
    result:

    tensor([[0., 1., 2., 3.]])
    tensor([[0., 1.],
    [2., 3.]])
    tensor([[0.],
    [1.],
    [2.],
    [3.]])
    '''
  8. transpose()函数可以将指定的两个维度的元素进行转置,而permute()函数则可以按照给定的维度进行维度变换。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import torch
    a = torch.Tensor([[0,1],[2,3]])
    print(a)
    print(a.transpose(0,1))
    print(a.permute(1,0))
    '''
    result:

    tensor([[0., 1.],
    [2., 3.]])
    tensor([[0., 2.],
    [1., 3.]])
    tensor([[0., 2.],
    [1., 3.]])
    '''
  9. 使用squeeze()与unsqueeze()函数,前者用于去除size为1的维度,而后者则是将指定的维度的size变为1。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    import torch
    a = torch.Tensor([[0,1],[2,3]])
    print(a)
    print(a.unsqueeze(2))
    print(a.unsqueeze(2).squeeze(2))
    '''
    result:

    tensor([[0., 1.],
    [2., 3.]])
    tensor([[[0.],
    [1.]],

    [[2.],
    [3.]]])
    tensor([[0., 1.],
    [2., 3.]])
    '''
  10. expand()函数将size为1的维度复制扩展为指定大小,也可以使用expand_as()函数指定为示例Tensor的维度。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch
    a = torch.Tensor([[0],[3]])
    print(a)
    print(a.expand(2,4))

    '''
    result:

    tensor([[0.],
    [3.]])
    tensor([[0., 0., 0., 0.],
    [3., 3., 3., 3.]])
    '''

Tensor的排序与取极值

  1. 函数sort(),选择沿着指定维度进行排序,返回排序后的Tensor及对应的索引位置。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    import torch
    a = torch.randn(3,3)
    print(a)
    # 按照第0维进行按列排序,True代表降序,False代表升序
    print(a.sort(0,True))

    '''
    result:

    tensor([[ 0.6448, 0.5900, 2.2122],
    [ 0.1974, 2.0291, -0.1883],
    [-0.6540, -1.2901, 0.8186]])
    torch.return_types.sort(
    values=tensor([[ 0.6448, 2.0291, 2.2122],
    [ 0.1974, 0.5900, 0.8186],
    [-0.6540, -1.2901, -0.1883]]),
    indices=tensor([[0, 1, 0],
    [1, 0, 2],
    [2, 2, 1]]))
  2. max()与min()函数则是沿着指定维度选择最大与最小元素,返回该元素及对应的索引位置。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    import torch
    a = torch.randn(3,3)
    print(a)
    # 选出每一列的最大值
    print(a.max(0))
    '''
    result:

    tensor([[ 0.8609, 2.1089, 0.5477],
    [-0.7530, 0.6850, -0.9778],
    [-2.0674, 0.6929, 1.4075]])
    torch.return_types.max(
    values=tensor([0.8609, 2.1089, 1.4075]),
    indices=tensor([0, 0, 2]))
    '''

    Tensor的自动广播机制

    不同形状的Tensor进行计算时,可自动扩展到较大的相同形状,再进行计算。广播机制的前提是任一个Tensor至少有一个维度,且从尾部遍历Tensor维度时,两者维度必须相等,其中一个要么是1要么不存在。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    import torch
    a = torch.ones(3,1,2)
    b = torch.ones(3,1)
    print(a)
    print(b)
    # 从尾部遍历维度,1对应3,2对应1,3对应不存在,因此满足广播条件,最后求和后的维度为[3,3,2]
    print(a+b)
    '''
    result:

    tensor([[[1., 1.]],

    [[1., 1.]],

    [[1., 1.]]])
    tensor([[1.],
    [1.],
    [1.]])
    tensor([[[2., 2.],
    [2., 2.],
    [2., 2.]],

    [[2., 2.],
    [2., 2.],
    [2., 2.]],

    [[2., 2.],
    [2., 2.],
    [2., 2.]]])
    '''

    Tensor的内存共享和转换

  3. 原地操作符
    PyTorch对于一些操作通过加后缀“”实现了原地操作,如add()和resize_()等,这种操作只要被执行,本身的Tensor则会被改变。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    import torch
    a = torch.ones(2,2)
    print(a)
    a.add_(a)
    print(a)
    '''
    result:

    tensor([[1., 1.],
    [1., 1.]])
    tensor([[2., 2.],
    [2., 2.]])
    '''
  4. Tensor与NumPy转换
    Tensor与NumPy可以高效地进行转换,并且转换前后的变量共享内存。在进行PyTorch不支持的操作时,甚至可以曲线救国,将Tensor转换为NumPy类型,操作后再转为Tensor。
    例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    import torch
    a = torch.ones(2,2)
    # 转numpy
    b = a.numpy()
    # 转tensor
    c = torch.from_numpy(b)
    # 转list
    d = a.tolist()
    print(a)
    print(b)
    print(c)
    print(d)
    '''
    result:

    tensor([[1., 1.],
    [1., 1.]])

    [[1. 1.]
    [1. 1.]]

    tensor([[1., 1.],
    [1., 1.]])

    [[1.0, 1.0], [1.0, 1.0]]
    '''

Reference

深度学习之PyTorch物体检测实战 - 董洪义