torch.nn.functional.one_hot¶
- torch.nn.functional.one_hot(tensor, num_classes=- 1) LongTensor ¶
接受形状为
(*)
的 LongTensor,并返回形状为(*, num_classes)
的张量,其中除最后一个维度的索引与输入张量的对应值匹配的地方为 1,其他地方为 0。参见维基百科上的 One-hot
- 参数:
tensor(LongTensor)- 任意形状的类别值。
num_classes(int,可选)- 总类别数。如果设置为-1,类别数将推断为输入张量中最大类别值加一。默认:-1
- 返回值:
长整型张量,比输入张量多一个维度,在最后一个维度的索引处有 1,其他地方为 0。
示例
>>> F.one_hot(torch.arange(0, 5) % 3) tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0]]) >>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5) tensor([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 1, 0, 0, 0]]) >>> F.one_hot(torch.arange(0, 6).view(3,2) % 3) tensor([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [1, 0, 0]], [[0, 1, 0], [0, 0, 1]]])