take(tensor, indices, axis=None)
Gather elements from a tensor.
take(tensor, indices, axis=3) is equivalent to tensor[:, :, :, indices, ...] for frameworks that support NumPy-like fancy indexing.
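The fancy-indexing equivalence can be checked with NumPy. This sketch uses `np.take` as a stand-in for `take`, assuming the two share semantics for this case. The array shape and index list are arbitrary illustrative values.

```python
import numpy as np

# A 5-D array, so axis 3 has dimensions on both sides of it.
t = np.arange(2 * 3 * 4 * 5 * 6).reshape(2, 3, 4, 5, 6)
indices = [4, 0, 2]

# Gather along axis 3 with np.take (stand-in for take) ...
gathered = np.take(t, indices, axis=3)

# ... and with NumPy fancy indexing; the trailing Ellipsis keeps axis 4 intact.
fancy = t[:, :, :, indices, ...]

print(gathered.shape)                   # (2, 3, 4, 3, 6)
print(np.array_equal(gathered, fancy))  # True
```

Only the size of axis 3 changes: it becomes the number of indices.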
This function is roughly equivalent to tf.gather. For a 1-dimensional set of indices, it is also roughly equivalent to torch.index_select, but it deviates for multi-dimensional indices, which torch.index_select does not accept.
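The difference between 1-D and multi-dimensional indices comes down to the output shape. This sketch again uses NumPy's `np.take` as a stand-in, assuming matching semantics: the shape of the indices is spliced in at `axis`. With 1-D indices the number of dimensions is unchanged, which is the case torch.index_select covers. With 2-D indices the result gains a dimension.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])

# 1-D indices: the result keeps x.ndim dimensions (the index_select-like case).
one_d = np.take(x, [1, 0], axis=1)
print(one_d)        # [[2 1]
                    #  [4 3]]

# 2-D indices: the index shape replaces axis 1, so the result gains a dimension.
two_d = np.take(x, [[0, 0], [1, 0]], axis=1)
print(two_d.shape)  # (2, 2, 2)

# General shape rule: x.shape[:axis] + indices.shape + x.shape[axis + 1:]
idx = np.array([[0, 0], [1, 0]])
axis = 1
expected_shape = x.shape[:axis] + idx.shape + x.shape[axis + 1:]
print(expected_shape == two_d.shape)  # True
```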
tensor (tensor_like) – input tensor
indices (Sequence[int]) – the indices of the values to extract
axis – The axis over which to select the values. If not provided, the tensor is flattened before value extraction.
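When `axis` is omitted, the indices refer to positions in the flattened tensor. This NumPy sketch assumes row-major flattening, which is what `np.take` does with `axis=None`, and uses `np.take` as a stand-in.

```python
import numpy as np

x = np.array([[1, 2], [3, 4]])

# With axis=None the array is treated as its row-major flattening [1, 2, 3, 4],
# so index 0 selects 1 and index 3 selects 4.
flat_pick = np.take(x, [0, 3])
print(flat_pick)  # [1 4]

# The same selection written explicitly against the flattened view.
print(np.array_equal(flat_pick, x.ravel()[[0, 3]]))  # True
```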
>>> x = torch.tensor([[1, 2], [3, 4]])
>>> take(x, indices=[[0, 0], [1, 0]], axis=1)
tensor([[[1, 1],
         [2, 1]],

        [[3, 3],
         [4, 3]]])
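The example above gathers along axis 1 of a 2×2 tensor with a 2×2 index array. Each output element is `result[i, j, k] = x[i, indices[j, k]]`. The loop-based sketch below spells out that rule. `take_axis1_reference` is a hypothetical helper written for illustration only and is not part of any library. It uses NumPy rather than torch so that it runs without PyTorch installed.

```python
import numpy as np

def take_axis1_reference(x, indices):
    """Hypothetical naive gather along axis 1 of a 2-D array (illustration only)."""
    x = np.asarray(x)
    indices = np.asarray(indices)
    # Output shape: x.shape[:1] + indices.shape + x.shape[2:] (empty for 2-D x).
    out = np.empty((x.shape[0],) + indices.shape, dtype=x.dtype)
    for i in range(x.shape[0]):
        # Visit every position in the index array, whatever its shape.
        for pos in np.ndindex(*indices.shape):
            out[(i,) + pos] = x[i, indices[pos]]
    return out

x = [[1, 2], [3, 4]]
idx = [[0, 0], [1, 0]]
result = take_axis1_reference(x, idx)
print(result.tolist())  # [[[1, 1], [2, 1]], [[3, 3], [4, 3]]]

# Cross-check against NumPy's vectorised gather.
print(np.array_equal(result, np.take(np.array(x), idx, axis=1)))  # True
```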