Advanced Selection from Tensors in PyTorch
In some situations, you'll need to do some advanced indexing / selection with PyTorch, e.g. to answer the question: "how can I select elements from Tensor A following the indices specified in Tensor B?"
In this post we'll present the three most common methods for such tasks, namely torch.index_select, torch.gather and torch.take. We'll explain all of them in detail and contrast them with one another.

Admittedly, one motivation for this post was my repeatedly forgetting how and when to use which function, which ended with me googling, browsing Stack Overflow, and re-reading the (in my opinion) relatively brief and not too helpful official documentation. Thus, as mentioned, we here do a deep dive into these functions: we motivate when to use which, give examples in 2D and 3D, and show the resulting selections graphically.
I hope this post will bring clarity about said functions and remove the need for further exploration – thanks for reading!
And now, without further ado, let's dive into the functions one by one. For each, we first start with a 2D example and visualize the resulting selection, and then move to a somewhat more complex example in 3D. Further, we re-implement each operation in plain Python loops, so that you can read the code as another description of what these functions do. At the end, we summarize the functions and their differences in a table.
torch.index_select
torch.index_select selects elements along one dimension, while keeping the others unchanged. That is: keep all elements from all other dimensions, but pick elements in the target dimension following the index tensor. Let's demonstrate this with a 2D example, in which we select along dimension 1:
import torch
len_dim_0, len_dim_1 = 3, 4  # example sizes, chosen arbitrarily
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(num_picks,))
# [len_dim_0, num_picks]
picked = torch.index_select(values, 1, indices)
The resulting tensor has shape [len_dim_0, num_picks]: for every element along dimension 0, we have picked the same elements from dimension 1. Let's visualize this:

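To make the picked pattern concrete, here is a tiny numeric sketch (the values and indices are made up purely for illustration):
import torch
# a 3x4 matrix whose entries encode their flattened position
values = torch.arange(12).reshape(3, 4)
indices = torch.tensor([1, 3])
# picks columns 1 and 3 for every row -> shape [3, 2]
picked = torch.index_select(values, 1, indices)
print(picked)
# tensor([[ 1,  3],
#         [ 5,  7],
#         [ 9, 11]])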
Now we move to three dimensions. For this, we move closer to the world of Machine Learning / Data Science, and imagine a tensor of shape [batch_size, num_elements, num_features]: we thus have num_elements elements with num_features features each, and everything is batched. Using torch.index_select, we could pick the same elements for every batch / feature combination:
import torch
batch_size = 16
num_elements = 64
num_features = 1024
num_picks = 2
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(num_picks,))
# [batch_size, num_picks, num_features]
picked = torch.index_select(values, 1, indices)
Some might like to understand what index_select does in the form of code – thus, here's how one could re-implement this function using simple for loops:
# manual re-implementation: the same index j is used for every batch index i and feature index k
picked_manual = torch.zeros_like(picked)
for i in range(batch_size):
    for j in range(num_picks):
        for k in range(num_features):
            picked_manual[i, j, k] = values[i, indices[j], k]
assert torch.all(torch.eq(picked, picked_manual))
torch.gather
Next, we move to torch.gather. gather behaves similarly to index_select, but now the element selection in the desired dimension is dependent on the other dimensions – i.e., re-using our ML example: for every batch index, and for every feature, we can pick a different element from the "element" dimension – we pick elements from one tensor following the indices of another tensor.
I came across this use case quite frequently when working on ML projects. One concrete example would be selecting nodes from a tree based on some condition, where each node is described by a set of features: we then generate an index tensor holding, per batch index, the elements to select, and repeat these values along the feature dimension (see the sketch at the end of this section). That is: per batch index, we can select different elements based on some condition – in our example this condition depends only on the batch index, although it could also depend on the feature index.
But first, let's start again with a 2D example:
import torch
len_dim_0, len_dim_1 = 3, 4  # example sizes, chosen arbitrarily
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_1, size=(len_dim_0, num_picks))
# [len_dim_0, num_picks]
picked = torch.gather(values, 1, indices)
When visualizing this, we observe that the selection is no longer characterized by straight lines; instead, for each index along dimension 0, different elements in dimension 1 can be picked:

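Again, a small numeric sketch (values and indices chosen arbitrarily for illustration) shows how each row picks its own columns:
import torch
values = torch.arange(12).reshape(3, 4)
indices = torch.tensor([[0, 3],
                        [1, 1],
                        [2, 0]])
# each row picks its own columns -> shape [3, 2]
picked = torch.gather(values, 1, indices)
print(picked)
# tensor([[ 0,  3],
#         [ 5,  5],
#         [10,  8]])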
With that, let's move to three dimensions, and also show Python code to re-implement this selection:
import torch
batch_size = 16
num_elements = 64
num_features = 1024
num_picks = 5
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, num_elements, size=(batch_size, num_picks, num_features))
# [batch_size, num_picks, num_features]
picked = torch.gather(values, 1, indices)
picked_manual = torch.zeros_like(picked)
for i in range(batch_size):
    for j in range(num_picks):
        for k in range(num_features):
            # the picked index now also depends on the batch index i and feature index k
            picked_manual[i, j, k] = values[i, indices[i, j, k], k]
assert torch.all(torch.eq(picked, picked_manual))
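To connect this back to the use case described earlier – one set of picked elements per batch entry, repeated across the feature dimension – a minimal sketch could look as follows (the unsqueeze / expand construction of the index tensor is illustrative and not taken from the examples above):
import torch
batch_size, num_elements, num_features, num_picks = 16, 64, 1024, 5
values = torch.rand((batch_size, num_elements, num_features))
# one set of element indices per batch entry: [batch_size, num_picks]
per_batch_indices = torch.randint(0, num_elements, size=(batch_size, num_picks))
# repeat the per-batch choice along the feature dimension: [batch_size, num_picks, num_features]
gather_indices = per_batch_indices.unsqueeze(-1).expand(-1, -1, num_features)
# [batch_size, num_picks, num_features]
picked = torch.gather(values, 1, gather_indices)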
torch.take
torch.take might be the easiest of the three introduced functions to grasp: it essentially treats the input tensor as flattened, and then selects elements from this list. For example: when applying take to an input tensor of shape [4, 5] and selecting indices 6 and 19, we'll obtain the elements at positions 6 and 19 of the flattened tensor – that is (with 0-based indexing), the second element of the second row, and the very last element.
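As a quick sanity check of that claim, here is a throwaway sketch (not part of the original examples):
import torch
# entries equal their flattened index, so the result is easy to verify
values = torch.arange(20).reshape(4, 5)
picked = torch.take(values, torch.tensor([6, 19]))
print(picked)                       # tensor([ 6, 19])
print(values[1, 1], values[3, 4])   # the same two elements: tensor(6) tensor(19)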
The 2D example:
import torch
len_dim_0, len_dim_1 = 3, 4  # example sizes, chosen arbitrarily
num_picks = 2
values = torch.rand((len_dim_0, len_dim_1))
indices = torch.randint(0, len_dim_0 * len_dim_1, size=(num_picks,))
# [num_picks]
picked = torch.take(values, indices)
As we can see, we now only get two elements out:

The 3D selection with subsequent re-implementation is given below. Note that the indices tensor can now have an arbitrary shape, and the resulting selection has that same shape:
import torch
batch_size = 16
num_elements = 64
num_features = 1024
num_picks = (2, 5, 3)
values = torch.rand((batch_size, num_elements, num_features))
indices = torch.randint(0, batch_size * num_elements * num_features, size=num_picks)
# [2, 5, 3]
picked = torch.take(values, indices)
picked_manual = torch.zeros(num_picks)
for i in range(num_picks[0]):
    for j in range(num_picks[1]):
        for k in range(num_picks[2]):
            # take indexes into the flattened values tensor
            picked_manual[i, j, k] = values.flatten()[indices[i, j, k]]
assert torch.all(torch.eq(picked, picked_manual))
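For completeness, the same selection can be expressed without loops – a sketch that reuses values, indices, and picked from the snippet above – which makes the "flatten first" behaviour explicit:
# index the flattened tensor, then restore the shape of the index tensor
picked_alt = values.flatten()[indices.flatten()].reshape(indices.shape)
assert torch.all(torch.eq(picked, picked_alt))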
Conclusion
In this post we have seen three common selection methods in PyTorch: torch.index_select, torch.gather, and torch.take. With all of these, one can select / index elements from a Tensor based on some condition. For each, we started with a simple 2D example and visualized the resulting selection graphically. Then, we moved to a somewhat more complex and realistic 3D scenario, in which one selects from a Tensor of shape [batch_size, num_elements, num_features] – which could be a common use case in many ML projects.
To conclude this post, I'd like to summarize the differences between these functions in a table – containing a short description and sample shapes. The sample shapes are tailored to the previously mentioned 3D ML example, and list the required shape of the index tensor as well as the resulting output shape:

Function             Index shape (3D example)                Output shape                             Short description
torch.index_select   [num_picks]                             [batch_size, num_picks, num_features]    same picks along one dimension, for all other dimensions
torch.gather         [batch_size, num_picks, num_features]   [batch_size, num_picks, num_features]    picks can differ per batch / feature combination
torch.take           arbitrary, e.g. [2, 5, 3]               same as the index shape, e.g. [2, 5, 3]  picks from the flattened input tensor
Thanks for reading!