import torch
# Build a small 1-D example tensor and a mask that marks the entries
# strictly greater than 3.
input_tensor = torch.tensor([1, 2, 3, 4, 5])
mask = input_tensor > 3
print(mask)

# nonzero() yields one row per selected element, i.e. shape (num_selected, 1)
# for a 1-D mask.  Drop only that trailing axis: a bare squeeze() would also
# collapse the leading axis when exactly one element matches, producing a
# 0-d scalar tensor instead of a length-1 vector of indices.
indexes = mask.nonzero().squeeze(1)
print(indexes)
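# Output of the two print calls above (the mask is shown with the uint8
# dtype, as printed by the PyTorch version this was run on):
# tensor([0, 0, 0, 1, 1], dtype=torch.uint8)
# tensor([3, 4])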