Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - PyTorch: how to do gathers over multiple dimensions

I'm trying to find a way to do this without for loops.

Say I have a multi-dimensional tensor t0:

import torch

bs = 4
seq = 10
v = 16
sample = 5  # number of indices drawn per batch element
t0 = torch.rand((bs, seq, v))

This has shape: torch.Size([4, 10, 16])

I have another tensor, labels, that holds a batch of 5 random indices into the seq dimension:

labels = torch.randint(0, seq, size=[bs, sample])

So this has shape torch.Size([4, 5]). This is used to index the seq dimension of t0.
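Written element-wise, the result the question is after can be stated as follows. This is only a restatement of the brute-force loop the question goes on to show, using t0, t1, labels, bs, sample and v as defined in the question:

$$
t_1[b,\, i,\, k] = t_0\bigl[b,\ \mathrm{labels}[b, i],\ k\bigr], \qquad 0 \le b < \mathit{bs},\quad 0 \le i < \mathit{sample},\quad 0 \le k < v
$$

The index depends only on b and i and is shared by every k, which is why this amounts to a gather along dimension 1.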

What I want to do is loop over the batch dimension, doing gathers with the labels tensor. My brute-force solution is this:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    # copy row labels[b, idx0] of t0[b] into row idx0 of t1[b]
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]

This results in a tensor t1 with shape torch.Size([4, 5, 16]).
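Since t1[b, i, :] = t0[b, labels[b, i], :], the loop is a gather along dimension 1. A minimal sketch with torch.gather follows, assuming the sizes from the question (bs=4, seq=10, v=16, sample=5). torch.gather needs an index with the same number of dimensions as its input, so labels is first expanded across the last dimension. The names index, t1_gather and t1_loop are illustrative only; the loop is kept just to check that both produce the same values.

```python
import torch

bs, seq, v, sample = 4, 10, 16, 5
t0 = torch.rand((bs, seq, v))
labels = torch.randint(0, seq, size=(bs, sample))

# torch.gather needs an index with as many dims as t0, so repeat each label
# across the last (v) dimension: (bs, sample) -> (bs, sample, v)
index = labels.unsqueeze(-1).expand(-1, -1, v)
# t1_gather[b, i, k] = t0[b, index[b, i, k], k] = t0[b, labels[b, i], k]
t1_gather = torch.gather(t0, 1, index)

# Reference: the brute-force loop from the question
t1_loop = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1_loop[b, idx0, :] = t0[b, idx1, :]

print(t1_gather.shape)  # torch.Size([4, 5, 16])
assert torch.equal(t1_gather, t1_loop)
```

expand returns a view instead of copying labels v times, so building the index is cheap.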

Is there a more idiomatic way of doing this in PyTorch?

Question from: https://stackoverflow.com/questions/65894166/pytorch-how-to-do-gathers-over-multiple-dimensions


1 Reply


You could do it like this:

t1 = t0[[[b] for b in range(bs)], labels]
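The nested list [[0], [1], [2], [3]] acts as a batch index of shape (bs, 1). Under advanced indexing it broadcasts against labels, whose shape is (bs, sample), so entry [b, i] of the result is t0[b, labels[b, i]] and the whole result has shape (bs, sample, v). The sketch below builds the same batch index with torch.arange rather than a Python list comprehension, again assuming the question's sizes; batch_idx and t1_list are hypothetical names.

```python
import torch

bs, seq, v, sample = 4, 10, 16, 5
t0 = torch.rand((bs, seq, v))
labels = torch.randint(0, seq, size=(bs, sample))

# Batch index of shape (bs, 1); it broadcasts against labels' shape (bs, sample),
# so position [b, i] selects t0[b, labels[b, i], :]
batch_idx = torch.arange(bs).unsqueeze(1)
t1 = t0[batch_idx, labels]

# Same result as the nested-list index used in the answer
t1_list = t0[[[b] for b in range(bs)], labels]
assert t1.shape == (bs, sample, v)
assert torch.equal(t1, t1_list)
```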

or

t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])
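The torch.stack variant still iterates over the batch in Python. Another fully vectorized option, which is not from the original answer, is to flatten the first two dimensions of t0 and shift each batch's labels by b * seq so they address rows of the flattened view. This is a sketch under the question's sizes; flat_idx and t1_flat are illustrative names.

```python
import torch

bs, seq, v, sample = 4, 10, 16, 5
t0 = torch.rand((bs, seq, v))
labels = torch.randint(0, seq, size=(bs, sample))

# Row labels[b, i] of batch b sits at row b * seq + labels[b, i] of the (bs * seq, v) view
flat_idx = (labels + torch.arange(bs).unsqueeze(1) * seq).reshape(-1)
t1_flat = t0.reshape(bs * seq, v).index_select(0, flat_idx).reshape(bs, sample, v)

# Compare against the torch.stack version from the answer
t1_stack = torch.stack([t0[b, labels[b]] for b in range(bs)])
assert torch.equal(t1_flat, t1_stack)
```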


...