where(condition, x=None, y=None)

Returns elements chosen from x or y depending on a boolean tensor condition, or the indices of entries satisfying the condition.

The input tensors condition, x, and y must all be broadcastable to the same shape.
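The sketch below shows the broadcasting requirement using NumPy's `np.where` directly. It assumes that NumPy inputs behave like NumPy's own function, which is an assumption and not something stated above. A column-shaped condition, a scalar `x` and a row-shaped `y` broadcast to one common 2-D shape:

```python
import numpy as np

# condition has shape (3, 1), x is a scalar, y has shape (1, 4)
condition = np.array([[True], [False], [True]])
x = 0.0
y = np.arange(4.0).reshape(1, 4)

# All three broadcast to shape (3, 4): rows where condition is True take x,
# the remaining row takes the corresponding entries of y.
out = np.where(condition, x, y)
print(out.shape)  # (3, 4)
```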

Parameters

  • condition (tensor_like[bool]) – A boolean tensor. Where it is True, elements are taken from x; otherwise they are taken from y. If both x and y are None, the indices of the entries where condition is True are returned instead.

  • x (tensor_like) – Values to choose from where condition is True.

  • y (tensor_like) – Values to choose from where condition is False.


Returns

If x and y are both None, a tensor or a tuple of tensors containing the indices where condition is True. Otherwise, a tensor containing the elements of x where condition is True and the elements of y elsewhere. In this case, the output has the common broadcast shape of condition, x and y.

Return type

tensor_like or tuple[tensor_like]
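For the tuple return type, there is one index array per dimension of condition. That means the tuple can be passed straight back as an index to gather the selected entries. The following NumPy sketch assumes NumPy inputs:

```python
import numpy as np

a = np.array([[0.6, 0.23, 1.7], [1.5, 0.7, -0.2]])
idx = np.where(a < 1)  # tuple with one index array per dimension

# (idx[0][k], idx[1][k]) are the coordinates of the k-th True entry,
# so indexing with the tuple gathers exactly those values in row-major order.
values = a[idx]
print(values.tolist())  # [0.6, 0.23, 0.7, -0.2]
```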

Example with three arguments

>>> import torch
>>> from pennylane import math
>>> a = torch.tensor([0.6, 0.23, 0.7, 1.5, 1.7], requires_grad=True)
>>> b = torch.tensor([-1., -2., -3., -4., -5.], requires_grad=True)
>>> math.where(a < 1, a, b)
tensor([ 0.6000,  0.2300,  0.7000, -4.0000, -5.0000], grad_fn=<SWhereBackward>)
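The selection rule is applied elementwise. The sketch below reproduces the values of the three-argument example with NumPy arrays in place of the Torch tensors. Because it uses NumPy, no gradient is tracked and there is no grad_fn. It also checks the result against a plain-Python loop that spells out the rule for inputs of equal shape:

```python
import numpy as np

a = np.array([0.6, 0.23, 0.7, 1.5, 1.7])
b = np.array([-1.0, -2.0, -3.0, -4.0, -5.0])

# Vectorised selection: take a where a < 1, otherwise b
selected = np.where(a < 1, a, b)

# Equivalent elementwise loop for equal-shape inputs
reference = [ai if ai < 1 else bi for ai, bi in zip(a, b)]

assert np.allclose(selected, reference)
print(selected.tolist())  # [0.6, 0.23, 0.7, -4.0, -5.0]
```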


The output format for x=None and y=None follows the convention of the respective interface, and it differs between TensorFlow and all other interfaces. For TensorFlow, the output is a single tensor of shape (num_true, len(condition.shape)), where num_true is the number of entries of condition that are True. The entry at position (i, j) is the j-th coordinate of the i-th index. For all other interfaces, the output is a tuple of len(condition.shape) tensor-like objects, and the j-th object holds the j-th coordinates of all indices. See also the examples below.
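The two index layouts carry the same information and can be converted into each other by stacking or transposing. The NumPy sketch below uses `np.nonzero`, which returns the tuple layout, and `np.argwhere`, which returns a `(num_true, ndim)` array matching the TensorFlow layout described above. The helpers `tuple_to_rows` and `rows_to_tuple` are hypothetical names used only for illustration:

```python
import numpy as np

def tuple_to_rows(indices):
    """Hypothetical helper: tuple layout -> (num_true, ndim) array."""
    return np.stack(indices, axis=1)

def rows_to_tuple(rows):
    """Hypothetical helper: (num_true, ndim) array -> tuple layout."""
    return tuple(rows.T)

condition = np.array([[0.6, 0.23, 1.7], [1.5, 0.7, -0.2]]) < 1

as_tuple = np.nonzero(condition)  # rows [0, 0, 1, 1], columns [0, 1, 1, 2]
as_rows = np.argwhere(condition)  # shape (4, 2), one row per True entry

# Round-trip between the two layouts
assert np.array_equal(tuple_to_rows(as_tuple), as_rows)
assert all(np.array_equal(p, q) for p, q in zip(rows_to_tuple(as_rows), as_tuple))
```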

Example with single argument

For Torch, Autograd, JAX and NumPy, the output formatting is as follows:

>>> a = [[0.6, 0.23, 1.7], [1.5, 0.7, -0.2]]
>>> math.where(torch.tensor(a) < 1)
(tensor([0, 0, 1, 1]), tensor([0, 1, 1, 2]))

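This is not a single tensor-like object, although taken together it corresponds to the shape (2, 4): there are two index tensors (one per dimension of condition), each with four entries (one per True entry). For TensorFlow, on the other hand: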

>>> import tensorflow as tf
>>> math.where(tf.constant(a) < 1)
tf.Tensor(
[[0 0]
 [0 1]
 [1 1]
 [1 2]], shape=(4, 2), dtype=int64)

As the output shows, the two dimensions are swapped relative to the other interfaces, and the result is a single tensor. Note that the number of dimensions of the output does not depend on the input shape: it is always two-dimensional.
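The claim about dimensionality can be checked with `np.argwhere`. This is NumPy's analogue of the TensorFlow layout, not TensorFlow itself, so treat it as an illustration under that assumption. The result is 2-D for both 1-D and 3-D conditions, with one column per input dimension:

```python
import numpy as np

cond_1d = np.array([True, False, True])
cond_3d = np.zeros((2, 3, 4), dtype=bool)
cond_3d[1, 2, 3] = True

rows_1d = np.argwhere(cond_1d)
rows_3d = np.argwhere(cond_3d)

# Always two-dimensional: (number of True entries, number of input dimensions)
print(rows_1d.shape)  # (2, 1)
print(rows_3d.shape)  # (1, 3)
print(rows_3d.tolist())  # [[1, 2, 3]]
```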