备注
点击此处下载完整示例代码
TorchRec 入门教程
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 10 月 2 日
TorchRec 是一个针对构建可扩展和高效的推荐系统而优化的 PyTorch 库。本教程将指导您完成安装过程,介绍嵌入的概念,并强调其在推荐系统中的重要性。它提供了使用 PyTorch 和 TorchRec 实现嵌入的实际演示,重点关注通过分布式训练和高级优化处理大型嵌入表。
嵌入的基本原理及其在推荐系统中的作用
如何在 PyTorch 环境中设置 TorchRec 以管理和实现嵌入
探索在多个 GPU 上分布大型嵌入表的先进技术
PyTorch v2.5 或更高版本,CUDA 11.8 或更高版本
Python 3.9 或更高版本
安装依赖项
在 Google Colab 或其他环境中运行本教程之前,请安装以下依赖项:
!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121
备注
如果你在 Google Colab 中运行此代码,请确保切换到 GPU 运行时类型。更多信息请参阅启用 CUDA
嵌入式
在构建推荐系统时,分类特征通常具有巨大的基数,如帖子、用户、广告等。
为了表示这些实体和建模这些关系,使用嵌入。在机器学习中,嵌入是高维空间中实数向量,用于表示复杂数据中的意义,如单词、图像或用户。
在推荐系统中的嵌入
现在你可能会想,这些嵌入是如何最初生成的呢?嗯,嵌入被表示为嵌入表中的单独行,也称为嵌入权重。这是因为嵌入或嵌入表权重就像模型的所有其他权重一样,通过梯度下降进行训练!
嵌入表只是一个用于存储嵌入的大矩阵,具有两个维度(B,N),其中:
B 是表中存储的嵌入数量
N 是每个嵌入的维度数(N 维嵌入)。
嵌入表的输入代表嵌入查找,用于检索特定索引或行的嵌入。在推荐系统中,如许多大型系统中使用的系统,唯一 ID 不仅用于特定用户,还用于帖子、广告等实体,作为查找嵌入表的索引!
嵌入在 RecSys 中通过以下过程进行训练:
输入/查找索引作为唯一 ID 输入到模型中,以防止 ID 大于行数时出现的问题,ID 会被散列到嵌入表的总大小。
嵌入随后被检索并汇总,例如取嵌入的总和或平均值。这是必需的,因为每个示例的嵌入数量可能不同,而模型期望一致的形状。
嵌入与模型的其余部分结合使用以生成预测,例如广告的点击通过率(CTR)。
使用预测和示例的标签计算损失,并通过梯度下降和反向传播更新模型的所有权重,包括与示例关联的嵌入权重。
这些嵌入对于表示分类特征至关重要,例如用户、帖子、广告,以捕捉关系并做出良好的推荐。深度学习推荐模型(DLRM)论文更多地讨论了在 RecSys 中使用嵌入表的技术细节。
本教程介绍了嵌入的概念,展示了 TorchRec 特定的模块和数据类型,并描述了 TorchRec 如何进行分布式训练。
import torch
PyTorch 中的嵌入 ¶
在 PyTorch 中,我们有以下类型的嵌入:
torch.nn.Embedding
:一个嵌入表,前向传递返回嵌入本身。嵌入表嵌入,其中前向传递返回嵌入,然后进行池化,例如求和或平均值,也称为池化嵌入。
在本节中,我们将简要介绍通过传递索引到表中执行嵌入查找。
num_embeddings, embedding_dim = 10, 4
# Initialize our embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights:", weights)
# Pass in pre-generated weights just for example, typically weights are randomly initialized
embedding_collection = torch.nn.Embedding(
num_embeddings, embedding_dim, _weight=weights
)
embedding_bag_collection = torch.nn.EmbeddingBag(
num_embeddings, embedding_dim, _weight=weights
)
# Print out the tables, we should see the same weights as above
print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)
# Lookup rows (ids for embedding ids) from the embedding tables
# 2D tensor with shape (batch_size, ids for each batch)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)
embeddings = embedding_collection(ids)
# Print out the embedding lookups
# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)
# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)
# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)
# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))
恭喜!现在您已经基本了解了如何使用嵌入表——现代推荐系统的基础之一!这些表表示实体及其关系。例如,给定用户与他们喜欢的页面和帖子之间的关系。
TorchRec 功能概述 ¶
在上面的部分,我们已经学习了如何使用嵌入表,这是现代推荐系统的基础之一!这些表代表实体和关系,例如用户、页面、帖子等。鉴于这些实体总是不断增加,通常会应用哈希函数以确保 ID 在某个嵌入表的范围内。然而,为了表示大量实体并减少哈希冲突,这些表可能会变得非常大(例如,想想广告的数量)。事实上,这些表可能会变得如此之大,以至于即使有 80G 的内存,也无法在 1 个 GPU 上容纳。
为了训练具有巨大嵌入表的模型,需要将这些表分片到多个 GPU 上,这随后引入了并行化和优化方面的一系列新问题和新机遇。幸运的是,我们有了 TorchRec 库,它已经遇到了、整合并解决了许多这些问题。TorchRec 是一个提供大规模分布式嵌入原语的库。
接下来,我们将探讨 TorchRec 库的主要功能。我们将从 torch.nn.Embedding
开始,并将其扩展到自定义 TorchRec 模块,探讨具有生成嵌入分片计划的分布式训练环境,查看 TorchRec 的固有优化,并将模型扩展到 C++推理准备就绪。以下是本节内容的简要概述:
TorchRec 模块和数据类型
分布式训练、分片和优化
推理
让我们从导入 TorchRec 开始:
import torchrec
本节介绍了 TorchRec 模块和数据类型,包括 EmbeddingCollection
和 EmbeddingBagCollection
, JaggedTensor
, KeyedJaggedTensor
, KeyedTensor
等实体。
从 EmbeddingBag
到 EmbeddingBagCollection
¶
我们已经探讨了 torch.nn.Embedding
和 torch.nn.EmbeddingBag
。TorchRec 通过创建包含多个嵌入表的集合来扩展这些模块,换句话说,就是可以拥有多个嵌入表的模块,通过 EmbeddingCollection
和 EmbeddingBagCollection
,我们将使用 EmbeddingBagCollection
来表示一组嵌入包。
在下面的示例代码中,我们创建了一个包含两个嵌入包的 EmbeddingBagCollection
(EBC),一个代表产品,一个代表用户。每个表, product_table
和 user_table
,都由一个 64 维嵌入表示,大小为 4096。
ebc = torchrec.EmbeddingBagCollection(
device="cpu",
tables=[
torchrec.EmbeddingBagConfig(
name="product_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["product"],
pooling=torchrec.PoolingType.SUM,
),
torchrec.EmbeddingBagConfig(
name="user_table",
embedding_dim=64,
num_embeddings=4096,
feature_names=["user"],
pooling=torchrec.PoolingType.SUM,
)
]
)
print(ebc.embedding_bags)
让我们来检查 EmbeddingBagCollection
的前向方法以及模块的输入和输出:
import inspect
# Let's look at the ``EmbeddingBagCollection`` forward method
# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?
print(inspect.getsource(ebc.forward))
TorchRec 输入/输出数据类型 ¶
TorchRec 为其模块的输入和输出定义了不同的数据类型: JaggedTensor
、 KeyedJaggedTensor
和 KeyedTensor
。现在你可能想知道,为什么要创建新的数据类型来表示稀疏特征?为了回答这个问题,我们必须理解稀疏特征在代码中的表示方式。
稀疏特征也被称为 id_list_feature
和 id_score_list_feature
,它们是作为索引嵌入表以检索该 ID 嵌入的 ID。为了给出一个非常简单的例子,想象一个单一的稀疏特征是用户互动过的广告。输入本身将是一组用户互动过的广告 ID,检索到的嵌入将是这些广告的语义表示。在代码中表示这些特征的难点在于,每个输入示例中 ID 的数量是可变的。有一天用户可能只互动了一个广告,而第二天他们可能互动了三个。
下面展示了一个简单的表示,其中我们有一个 lengths
张量表示一个批次中示例的索引数量,以及一个 values
张量包含索引本身。
# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])
接下来,让我们看看偏移量以及每个批次中包含的内容。
# Lengths can be converted to offsets for easy indexing of values
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print(
"Second Batch: ",
id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],
)
from torchrec import JaggedTensor
# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!
jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
# Automatically compute offsets from lengths
print("Offsets: ", jt.offsets())
# Convert to list of values
print("List of Values: ", jt.to_dense())
# ``__str__`` representation
print(jt)
from torchrec import KeyedJaggedTensor
# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])
)
user_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))
# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})
# Look at our feature keys for the ``KeyedJaggedTensor``
print("Keys: ", kjt.keys())
# Look at the overall lengths for the ``KeyedJaggedTensor``
print("Lengths: ", kjt.lengths())
# Look at all values for ``KeyedJaggedTensor``
print("Values: ", kjt.values())
# Can convert ``KeyedJaggedTensor`` to dictionary representation
print("to_dict: ", kjt.to_dict())
# ``KeyedJaggedTensor`` string representation
print(kjt)
# Q2: What are the offsets for the ``KeyedJaggedTensor``?
# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before
result = ebc(kjt)
result
# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results
print(result.keys())
# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined
# 128 for dimension of embedding. If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables "product" and "user" of dimension 64 each
# meaning embeddings for both features are of size 64. 64 + 64 = 128
print(result.values().shape)
# Nice to_dict method to determine the embeddings that belong to each feature
result_dict = result.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
恭喜!你现在已经理解了 TorchRec 模块和数据类型。给自己鼓掌吧,你已经走得很远了。接下来,我们将学习分布式训练和分片。
分布式训练和分片
现在我们已经掌握了 TorchRec 模块和数据类型,是时候将其提升到下一个层次了。
记住,TorchRec 的主要目的是提供分布式嵌入的原语。到目前为止,我们只在一个设备上处理嵌入表。鉴于嵌入表相对较小,这一直是可能的,但在生产环境中通常并非如此。嵌入表通常变得非常大,一个表无法适应单个 GPU,这就产生了对多个设备和分布式环境的需求。
在本节中,我们将探讨如何设置分布式环境,实际生产训练是如何进行的,以及如何使用 TorchRec 来分片嵌入表。
本节也将仅使用 1 个 GPU,尽管它将以分布式的方式处理。这仅是训练的限制,因为训练有一个与 GPU 相对应的过程。推理不会遇到这个要求。
以下示例代码中,我们设置了我们的 PyTorch 分布式环境。
警告
如果你在 Google Colab 中运行此代码,你只能调用此单元格一次,再次调用将导致错误,因为你只能初始化一次进程组。
import os
import torch.distributed as dist
# Set up environment variables for distributed training
# RANK is which GPU we are on, default 0
os.environ["RANK"] = "0"
# How many devices in our "world", colab notebook can only handle 1 process
os.environ["WORLD_SIZE"] = "1"
# Localhost as we are training locally
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"
# nccl backend is for GPUs, gloo is for CPUs
dist.init_process_group(backend="gloo")
print(f"Distributed environment initialized: {dist}")
分布式嵌入
我们已经与主要的 TorchRec 模块进行了合作: EmbeddingBagCollection
。我们已经研究了它的工作原理以及数据在 TorchRec 中的表示方式。然而,我们尚未探索 TorchRec 的主要部分之一,即分布式嵌入。
目前,GPU 是 ML 工作负载最受欢迎的选择,因为它们能够执行比 CPU 大得多的浮点运算次数/s(FLOPs)。然而,GPU 的限制是快速内存(HBM,类似于 CPU 的 RAM)稀缺,通常为数十 GB。
RecSys 模型可以包含超出 1 个 GPU 内存限制的嵌入表,因此需要将嵌入表分布到多个 GPU 上,这被称为模型并行。另一方面,数据并行是将整个模型复制到每个 GPU 上,每个 GPU 接收不同的数据批次进行训练,在反向传播过程中同步梯度。
需要较少计算但更多内存的部分(嵌入)使用模型并行进行分布,而需要更多计算和较少内存的部分(密集层、MLP 等)使用数据并行进行分布。
分片
为了分发嵌入表,我们将嵌入表分割成多个部分,并将这些部分放置在不同的设备上,这被称为“分片”。
分片嵌入表的方法有很多。最常见的方法有:
表分片:整个表被放置在一个设备上
列分片:嵌入表的列被分片
行内:嵌入表的行被分片
分片模块
虽然这一切看起来处理和实现起来似乎很多,但你很幸运。TorchRec 提供了所有原语以实现易于分布式训练和推理!实际上,TorchRec 模块有两个相应的类,用于在分布式环境中处理任何 TorchRec 模块:
模块分片器:此类公开了一个
shard
API,用于处理分片 TorchRec 模块,生成分片模块。* 对于EmbeddingBagCollection
,分片器是 EmbeddingBagCollectionSharder分片模块:此类是 TorchRec 模块的分片变体。它具有与常规 TorchRec 模块相同的输入/输出,但优化程度更高,且适用于分布式环境。* 对于
EmbeddingBagCollection
,分片变体是 ShardedEmbeddingBagCollection
每个 TorchRec 模块都有一个非分片和分片变体。
非分片版本旨在进行原型设计和实验。
分片版本旨在在分布式环境中用于分布式训练和推理。
火炬 Rec 模块的碎片化版本,例如 EmbeddingBagCollection
,将处理模型并行所需的所有内容,例如在 GPU 之间进行通信,将嵌入分布到正确的 GPU 上。
我们 EmbeddingBagCollection
模块的复习
ebc
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv
# Corresponding sharder for ``EmbeddingBagCollection`` module
sharder = EmbeddingBagCollectionSharder()
# ``ProcessGroup`` from torch.distributed initialized 2 cells above
pg = dist.GroupMember.WORLD
assert pg is not None, "Process group is not initialized"
print(f"Process Group: {pg}")
规划器
在我们展示碎片化如何工作之前,我们必须了解规划器,它帮助我们确定最佳的碎片化配置。
给定多个嵌入表和多个排名,存在许多可能的分片配置。例如,给定 2 个嵌入表和 2 个 GPU,你可以:
每个 GPU 放置 1 个表
将两个表都放置在一个 GPU 上,另一个 GPU 上不放置任何表
将某些行和列放置在每个 GPU 上
在所有这些可能性中,我们通常希望有一个针对性能最优的分区配置。
正是在这里,规划器发挥了作用。规划器能够根据嵌入表的数量和 GPU 的数量,确定最优配置。实际上,手动进行这项工作非常困难,工程师需要考虑众多因素以确保最优的分区计划。幸运的是,当使用规划器时,TorchRec 提供了一个自动规划器。
TorchRec 规划器:
评估硬件的内存限制
基于内存读取作为嵌入查找的估算
考虑数据特定因素
考虑其他硬件特定因素,如带宽以生成最佳分片计划
为了考虑所有这些变量,TorchRec 规划器可以接受各种数量的嵌入表、约束、硬件信息和拓扑数据,以帮助生成模型的最佳分片计划,该计划通常在各个堆栈中提供。
想了解更多关于分片的信息,请参阅我们的分片教程。
# In our case, 1 GPU and compute on CUDA device
planner = EmbeddingShardingPlanner(
topology=Topology(
world_size=1,
compute_device="cuda",
)
)
# Run planner to get plan for sharding
plan = planner.collective_plan(ebc, [sharder], pg)
print(f"Sharding Plan generated: {plan}")
规划结果 ¶
如上图所示,在运行规划器时,会有相当多的输出。我们可以看到许多统计数据,以及我们的表最终放置的位置。
运行规划器的结果是静态计划,可以重复用于分片!这使得生产模型中的分片可以保持静态,而不是每次都确定一个新的分片计划。下面,我们使用分片计划最终生成我们的 ShardedEmbeddingBagCollection
。
# The static plan that was generated
plan
env = ShardingEnv.from_process_group(pg)
# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))
print(f"Sharded EBC Module: {sharded_ebc}")
使用 LazyAwaitable
进行 GPU 训练 ¶
记住 TorchRec 是一个高度优化的分布式嵌入库。TorchRec 引入的一个概念,用于在 GPU 上实现更高的性能,称为 LazyAwaitable。您将看到 LazyAwaitable
作为各种分片 TorchRec 模块的输出。 LazyAwaitable
类型所做的只是尽可能延迟计算某些结果,它通过表现得像一个异步类型来实现。
from typing import List
from torchrec.distributed.types import LazyAwaitable
# Demonstrate a ``LazyAwaitable`` type:
class ExampleAwaitable(LazyAwaitable[torch.Tensor]):
def __init__(self, size: List[int]) -> None:
super().__init__()
self._size = size
def _wait_impl(self) -> torch.Tensor:
return torch.ones(self._size)
awaitable = ExampleAwaitable([3, 2])
awaitable.wait()
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?
print(output)
kt = output.wait()
# Now we have our ``KeyedTensor`` after calling ``.wait()``
# If you are confused as to why we have a ``KeyedTensor ``output,
# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module
print(type(kt))
print(kt.keys())
print(kt.values().shape)
# Same output format as unsharded ``EmbeddingBagCollection``
result_dict = kt.to_dict()
for key, embedding in result_dict.items():
print(key, embedding.shape)
分片 TorchRec 模块的解剖结构 ¶
现在我们已经成功地对一个 EmbeddingBagCollection
进行了分片,因为我们生成了一个分片计划!分片模块具有来自 TorchRec 的常用 API,这些 API 抽象化了多个 GPU 之间的分布式通信/计算。事实上,这些 API 在训练和推理方面都进行了高度优化。以下是 TorchRec 提供的三个用于分布式训练/推理的常用 API:
input_dist
:处理将输入从 GPU 分发到 GPU。使用 FBGEMM TBE 进行优化的批量查找实际嵌入。
处理将输出从 GPU 分发到 GPU 的过程。
输入和输出的分发是通过 NCCL 集体操作完成的,即所有 GPU 相互发送和接收数据。TorchRec 与 PyTorch 分布式集体操作接口集成,为最终用户提供干净的抽象,消除了对底层细节的担忧。
反向传播执行所有这些集体操作,但顺序相反,用于梯度分发。 input_dist
、 lookup
和 output_dist
都依赖于分片方案。由于我们按表的方式进行分片,这些 API 是由 TwPooledEmbeddingSharding 构建的模块。
sharded_ebc
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists
# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists
优化嵌入查找 ¶
在执行嵌入表集合的查找时,一个简单的解决方案是遍历所有 nn.EmbeddingBags
并逐个表进行查找。这正是标准未分片的 EmbeddingBagCollection
所做的事情。然而,虽然这个解决方案很简单,但它非常慢。
FBGEMM 是一个提供非常优化的 GPU 操作(也称为内核)的库。其中之一是一个名为表批处理嵌入(TBE)的内核,提供了两大优化:
表批处理,允许您通过一个内核调用查找多个嵌入。
优化融合,允许模块在给定规范 PyTorch 优化器和参数的情况下更新自身。
ShardedEmbeddingBagCollection
使用 FBGEMM TBE 作为查找,而不是传统的 nn.EmbeddingBags
以优化嵌入查找。
sharded_ebc._lookups
DistributedModelParallel
¶
我们现在已经探索了单个 EmbeddingBagCollection
的分片!我们能够将 EmbeddingBagCollectionSharder
和未分片的 EmbeddingBagCollection
结合起来生成一个 ShardedEmbeddingBagCollection
模块。这种工作流程是可行的,但在实现模型并行时,通常使用 DistributedModelParallel (DMP) 作为标准接口。当用 DMP 包装你的模型(在我们的例子中为 ebc
)时,以下情况将会发生:
决定如何分片模型。DMP 将收集可用的分片器并制定一个最优分片嵌入表(例如,
EmbeddingBagCollection
)的计划。实际上要分片模型。这包括在适当的设备(们)上为每个嵌入表分配内存。
DMP 接受我们刚刚实验过的所有内容,比如静态分片计划、分片器列表等。然而,它还有一些很好的默认设置,可以无缝分片 TorchRec 模型。在这个玩具示例中,因为我们有两个嵌入表和一个 GPU,TorchRec 会将它们都放置在单个 GPU 上。
ebc
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
out = model(kjt)
out.wait()
model
分片最佳实践 ¶
目前,我们的配置仅在 1 个 GPU(或 rank)上分片,这是微不足道的:只需将所有表放置在 1 个 GPU 的内存中即可。然而,在实际的生产用例中,嵌入表通常在数百个 GPU 上分片,使用不同的分片方法,如按表分片、按行分片和按列分片。在防止内存不足问题的同时,保持内存和计算平衡以实现最佳性能,确定适当的分片配置至关重要。
添加优化器 ¶
请记住,TorchRec 模块针对大规模分布式训练进行了超优化。其中一项重要的优化是关于优化器。
TorchRec 模块提供了一个无缝的 API,用于融合反向传播和优化步骤,从而在性能上提供了显著的优化,并减少了内存使用,同时还可以为不同的模型参数分配不同的优化器,提高了粒度。
优化器类 ¶
TorchRec 使用 CombinedOptimizer
,其中包含一个 KeyedOptimizers
集合。一个 CombinedOptimizer
有效地使得处理模型中各种子组的多个优化器变得简单。一个 KeyedOptimizer
扩展了 torch.optim.Optimizer
,并通过参数字典初始化,暴露了参数。模型中的每个 TBE
模块都将拥有自己的 KeyedOptimizer
,这些模块组合成一个 CombinedOptimizer
。
TorchRec 中的融合优化器 ¶
使用 DistributedModelParallel
,优化器被融合,这意味着优化器更新是在反向传播中完成的。这是 TorchRec 和 FBGEMM 中的优化,其中优化器嵌入梯度不会被实体化并直接应用于参数。这带来了显著的内存节省,因为嵌入梯度通常与参数大小相当。
然而,你可以选择使优化器 dense
不应用此优化,这样你可以检查嵌入梯度或按需对其应用计算。在这种情况下,密集优化器将是你的标准 PyTorch 模型训练循环中的优化器。
通过 DistributedModelParallel
创建优化器后,您仍然需要管理与 TorchRec 嵌入模块不相关的其他参数的优化器。要找到其他参数,请使用 in_backward_optimizer_filter(model.named_parameters())
。将这些参数应用优化器,就像应用正常的 Torch 优化器一样,并将这些与 model.fused_optimizer
合并为一个 CombinedOptimizer
,您可以在训练循环中使用它来 zero_grad
和 step
。
添加优化器到 EmbeddingBagCollection
¶
我们将以两种方式完成此操作,这两种方式等效,但根据您的偏好提供选择:
通过 sharder 中的
fused_params
传递优化器 kwargs。通过
apply_optimizer_in_backward
,将优化器参数转换为fused_params
,以传递给TBE
在EmbeddingBagCollection
或EmbeddingCollection
中。
# Option 1: Passing optimizer kwargs through fused parameters
from torchrec.optim.optimizers import in_backward_optimizer_filter
from fbgemm_gpu.split_embedding_configs import EmbOptimType
# We initialize the sharder with
fused_params = {
"optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
"learning_rate": 0.02,
"eps": 0.002,
}
# Initialize sharder with ``fused_params``
sharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
# We'll use same plan and unsharded EBC as before but this time with our new sharder
sharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[""], env, torch.device("cuda"))
# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.
# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied
print(f"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")
from torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward
import copy
# Option 2: Applying optimizer through apply_optimizer_in_backward
# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it
# We can achieve the same result as we did in the previous
ebc_apply_opt = copy.deepcopy(ebc)
optimizer_kwargs = {"lr": 0.5}
for name, param in ebc_apply_opt.named_parameters():
print(f"{name=}")
apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)
sharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[""], env, torch.device("cuda"))
# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted
print(sharded_ebc_apply_opt.fused_optimizer)
print(type(sharded_ebc_apply_opt.fused_optimizer))
# We can also check through the filter other parameters that aren't associated with the "fused" optimizer(s)
# Practically, just non TorchRec module parameters. Since our module is just a TorchRec EBC
# there are no other parameters that aren't associated with TorchRec
print("Non Fused Model Parameters:")
print(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())
# Here we do a dummy backwards call and see that parameter updates for fused
# optimizers happen as a result of the backward pass
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
print(f"First Iteration Loss: {loss}")
loss.backward()
ebc_output = sharded_ebc_fused_params(kjt).wait().values()
loss = torch.sum(torch.ones_like(ebc_output) - ebc_output)
# We don't call an optimizer.step(), so for the loss to have changed here,
# that means that the gradients were somehow updated, which is what the
# fused optimizer automatically handles for us
print(f"Second Iteration Loss: {loss}")
推理 ¶
现在我们能够训练分布式嵌入,那么我们如何将训练好的模型优化以用于推理呢?推理通常对模型的性能和大小非常敏感。仅在 Python 环境中运行训练好的模型效率极低。推理环境和训练环境之间有两个关键区别:
量化:推理模型通常进行量化,模型参数在预测中降低精度以实现更低的延迟和更小的模型大小。例如,将训练模型中的 FP32(4 字节)转换为每个嵌入权重为 INT8(1 字节)。鉴于嵌入表的大规模,这也是必要的,因为我们希望尽可能少地使用设备进行推理以最小化延迟。
C++环境:推理延迟非常重要,因此为了确保足够的性能,模型通常在 C++环境中运行,包括在没有 Python 运行时的情况,例如在设备上。
TorchRec 提供将 TorchRec 模型转换为推理就绪状态的接口:
提供量化模型的 API,自动引入 FBGEMM TBE 优化
对分布式推理进行分片嵌入
将模型编译为 TorchScript(与 C++ 兼容)
在本节中,我们将详细介绍以下整个工作流程:
模型量化
分片量化模型
将分片量化模型编译成 TorchScript
ebc
class InferenceModule(torch.nn.Module):
def __init__(self, ebc: torchrec.EmbeddingBagCollection):
super().__init__()
self.ebc_ = ebc
def forward(self, kjt: KeyedJaggedTensor):
return self.ebc_(kjt)
module = InferenceModule(ebc)
for name, param in module.named_parameters():
# Here, the parameters should still be FP32, as we are using a standard EBC
# FP32 is default, regularly used for training
print(name, param.shape, param.dtype)
量化
如上图所示,正常的 EBC 包含 FP32 精度的嵌入表权重(每个权重 32 位)。在这里,我们将使用 TorchRec 推理库将模型的嵌入权重量化为 INT8
from torch import quantization as quant
from torchrec.modules.embedding_configs import QuantConfig
from torchrec.quant.embedding_modules import (
EmbeddingBagCollection as QuantEmbeddingBagCollection,
)
quant_dtype = torch.int8
qconfig = QuantConfig(
# dtype of the result of the embedding lookup, post activation
# torch.float generally for compatibility with rest of the model
# as rest of the model here usually isn't quantized
activation=quant.PlaceholderObserver.with_args(dtype=torch.float),
# quantized type for embedding weights, aka parameters to actually quantize
weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),
)
qconfig_spec = {
# Map of module type to qconfig
torchrec.EmbeddingBagCollection: qconfig,
}
mapping = {
# Map of module type to quantized module type
torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,
}
module = InferenceModule(ebc)
# Quantize the module
qebc = quant.quantize_dynamic(
module,
qconfig_spec=qconfig_spec,
mapping=mapping,
inplace=False,
)
print(f"Quantized EBC: {qebc}")
kjt = kjt.to("cpu")
qebc(kjt)
# Once quantized, goes from parameters -> buffers, as no longer trainable
for name, buffer in qebc.named_buffers():
# The shapes of the tables should be the same but the dtype should be int8 now
# post quantization
print(name, buffer.shape, buffer.dtype)
分片
在这里,我们执行 TorchRec 量化模型的分片操作。这是为了确保我们通过 FBGEMM TBE 使用性能模块。在这里,我们使用一个设备以保持与训练的一致性(1 个 TBE)
from torchrec import distributed as trec_dist
from torchrec.distributed.shard import _shard_modules
sharded_qebc = _shard_modules(
module=qebc,
device=torch.device("cpu"),
env=trec_dist.ShardingEnv.from_local(
1,
0,
),
)
print(f"Sharded Quantized EBC: {sharded_qebc}")
sharded_qebc(kjt)
编译
现在我们有了优化的 TorchRec 推理模型。下一步是确保这个模型可以在 C++中加载,因为目前它只能在 Python 运行时中运行。
Meta 推荐的编译方法有两个方面:torch.fx 跟踪(生成模型的中间表示)和将结果转换为 TorchScript,其中 TorchScript 与 C++兼容。
from torchrec.fx import Tracer
tracer = Tracer(leaf_modules=["IntNBitTableBatchedEmbeddingBagsCodegen"])
graph = tracer.trace(sharded_qebc)
gm = torch.fx.GraphModule(sharded_qebc, graph)
print("Graph Module Created!")
print(gm.code)
scripted_gm = torch.jit.script(gm)
print("Scripted Graph Module Created!")
print(scripted_gm.code)
结论 ¶
在本教程中,您已经从训练分布式 RecSys 模型一直做到使其推理就绪。TorchRec 仓库提供了一个完整的示例,展示了如何将 TorchRec TorchScript 模型加载到 C++中进行推理。
如需更多信息,请参阅我们的 dlrm 示例,该示例包括使用在《深度学习推荐模型:个性化与推荐系统》中描述的方法在 Criteo 1TB 数据集上进行的多节点训练。
脚本总运行时间:(0 分钟 0.000 秒)