pytorch中的torch.nonzero(),就是返回张量中元素不为0的元素的索引。
举例子如下:
import torch
x = torch.tensor([4,0,1,2,1,2,3])
result = 1==x
print(result)
print(result.nonzero()) #输出了不为0值的索引
print(result.nonzero().view(-1))#将结果转为一维的张量
pytorch中的torch.nonzero(),就是返回张量中元素不为0的元素的索引。
举例子如下:
import torch
x = torch.tensor([4,0,1,2,1,2,3])
result = 1==x
print(result)
print(result.nonzero()) #输出了不为0值的索引
print(result.nonzero().view(-1))#将结果转为一维的张量