快捷键

嵌入 ¶

class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)[source][source]

一个简单的查找表,存储固定字典和大小的嵌入。

此模块通常用于存储词嵌入并使用索引检索它们。模块的输入是一个索引列表,输出是对应的词嵌入。

参数:
  • num_embeddings (int) – 嵌入字典的大小

  • embedding_dim (int) – 每个嵌入向量的尺寸

  • padding_idx (int, optional) – 如果指定, padding_idx 的条目不会对梯度产生影响;因此, padding_idx 的嵌入向量在训练期间不会更新,即它保持为固定的“填充”。对于新构建的嵌入, padding_idx 的嵌入向量将默认为全零,但可以更新为其他值用作填充向量。

  • max_norm (float, optional) – 如果给定,则将范数大于 max_norm 的每个嵌入向量重新归一化,使其范数为 max_norm

  • norm_type (float, optional) – 计算选项 max_norm 的 p-norm 的 p。默认为 2

  • scale_grad_by_freq (bool, 可选) – 如果提供,则将梯度按 mini-batch 中单词的频率的倒数进行缩放。默认为 False

  • sparse (bool, 可选) – 如果为 True ,则相对于 weight 矩阵的梯度将是一个稀疏张量。有关稀疏梯度的更多详细信息,请参阅注释。

变量:

weight (Tensor) – 模块的 learnable 权重,形状为 (num_embeddings, embedding_dim),从 N(0,1)\mathcal{N}(0, 1) 初始化。

形状:
  • 输入: ()(*) ,任意形状的 IntTensor 或 LongTensor,包含要提取的索引

  • 输出: (,H)(*, H) ,其中 * 是输入形状, H=embedding_dimH=\text{embedding\_dim}

注意

请注意,只有有限数量的优化器支持稀疏梯度:目前是 optim.SGD (CUDA 和 CPU)、 optim.SparseAdam (CUDA 和 CPU)以及 optim.Adagrad (CPU)

注意

max_norm 不是 None 时, Embedding 的前向方法将就地修改 weight 张量。由于梯度计算所需的张量不能就地修改,在调用 Embedding 的前向方法之前对 Embedding.weight 执行可微分操作需要克隆 Embedding.weight ,当 max_norm 不是 None 时。例如:

n, d, m = 3, 5, 7
embedding = nn.Embedding(n, d, max_norm=1.0)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
b = embedding(idx) @ W.t()  # modifies weight in-place
out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()

示例:

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])


>>> # example with padding_idx
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0, 2, 0, 5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1535, -2.0309,  0.9315],
         [ 0.0000,  0.0000,  0.0000],
         [-0.1655,  0.9897,  0.0635]]])

>>> # example of changing `pad` vector
>>> padding_idx = 0
>>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
>>> embedding.weight
Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [-0.7895, -0.7089, -0.0364],
        [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
>>> with torch.no_grad():
...     embedding.weight[padding_idx] = torch.ones(3)
>>> embedding.weight
Parameter containing:
tensor([[ 1.0000,  1.0000,  1.0000],
        [-0.7895, -0.7089, -0.0364],
        [ 0.6778,  0.5803,  0.2678]], requires_grad=True)
classmethod from_pretrained(embeddings, freeze=True, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)[source][source]

从给定的二维 FloatTensor 创建嵌入实例。

参数:
  • 嵌入(Tensor)- 包含嵌入权重的 FloatTensor。第一维作为 num_embeddings 传递给嵌入,第二维作为 embedding_dim 传递。

  • 冻结(bool,可选)- 如果为 True ,则张量在训练过程中不会被更新。相当于 embedding.weight.requires_grad = False 。默认值: True

  • padding_idx(int,可选)- 如果指定,则 padding_idx 处的条目不会对梯度产生影响;因此,在训练过程中, padding_idx 处的嵌入向量不会被更新,即它保持为固定的“填充”。

  • max_norm (浮点数,可选) – 请参阅模块初始化文档。

  • norm_type (浮点数,可选) – 请参阅模块初始化文档。默认值 2

  • scale_grad_by_freq (布尔值,可选) – 请参阅模块初始化文档。默认值 False

  • sparse (布尔值,可选) – 请参阅模块初始化文档。

示例:

>>> # FloatTensor containing pretrained weights
>>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> # Get embeddings for index 1
>>> input = torch.LongTensor([1])
>>> embedding(input)
tensor([[ 4.0000,  5.1000,  6.3000]])

© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源