由 Logan Kilpatrick - 高级技术倡导者,Harshith Padigela - 机器学习工程师,Syed Ashar Javed - 机器学习技术主管,Robert Egger - 生物医学数据科学家

PathAI 是领先的 AI 驱动病理技术工具和服务提供商(病理学是研究疾病的一门学科)。我们的平台旨在通过利用现代机器学习方法(如图像分割、图神经网络和多个实例学习)来显著提高复杂疾病的诊断准确性和治疗效果的测量。

传统手动病理学容易受到主观性和观察者变异性的影响,这可能会对诊断和药物开发试验产生负面影响。在我们深入探讨如何使用 PyTorch 改进我们的诊断工作流程之前,让我们首先概述没有机器学习的传统模拟病理学工作流程。

传统生物制药是如何运作的

生物制药公司通过多种途径来发现新的治疗药物或诊断方法。其中一条途径高度依赖于病理切片的分析来回答各种问题:特定的细胞通讯通路是如何工作的?特定的疾病状态能否与特定蛋白质的存在或缺失联系起来?为什么在临床试验中,某种药物对一些患者有效而对另一些患者无效?患者预后与新型生物标志物之间是否存在关联?

为了回答这些问题,生物制药公司依赖经验丰富的病理学家来分析切片并帮助他们评估可能的问题。

想象一下,要由一位专家认证的病理学家才能做出准确的解读和诊断。在一项研究中,将同一活检结果分发给 36 位不同的病理学家,结果有 18 种不同的诊断,严重程度从无需治疗到需要积极治疗不等。病理学家在处理边缘案例时也经常向同事征求反馈。鉴于问题的复杂性,即使在专家培训和协作的情况下,病理学家仍然可能难以做出正确的诊断。这种潜在的差异可能是药物能否获得批准与临床试验失败之间的区别。

PathAI 如何利用机器学习推动药物研发

PathAI 开发了机器学习模型,这些模型为药物研发、临床试验的推动以及诊断提供洞察。为此,PathAI 利用 PyTorch 进行切片级推理,采用包括图神经网络(GNN)以及多实例学习等多种方法。在此背景下,“切片”指的是全尺寸扫描的玻片图像,这些玻片是夹有薄层组织切片的玻璃片,经过染色以显示各种细胞形态。PyTorch 使我们使用这些不同方法论的团队能够共享一个足够鲁棒、适用于所有所需条件的通用框架。PyTorch 的高级、命令式和 Pythonic 语法使我们能够快速原型化模型,并在获得所需结果后将其扩展到规模。

在千兆图像上进行多实例学习

病理学中应用机器学习的独特挑战之一是图像的巨大尺寸。这些数字切片的分辨率通常可以达到 100,000 x 100,000 像素或更高,大小为几吉字节。在 GPU 内存中加载完整图像并应用传统的计算机视觉算法几乎是不可能的任务。此外,对完整切片图像(100k x 100k)进行标注也需要相当的时间和资源,尤其是在标注者需要是领域专家(认证病理学家)的情况下。我们通常构建模型来预测患者切片上的图像级标签,如癌症的存在,整个图像覆盖的像素只有几千个。癌变区域有时只是整个切片的一小部分,这使得机器学习问题类似于在干草堆里找针。另一方面,像预测某些组织学生物标志物这样的问题需要从整个切片中汇总信息,由于图像的大小,这同样很困难。所有这些因素在将机器学习技术应用于病理学问题时都增加了显著的算法、计算和后勤复杂性。

将图像分解成更小的块,学习块表示,然后将这些表示进行池化以预测图像级标签是解决该问题的一种方法,如图下所示。实现这一目标的一种流行方法是称为多实例学习(MIL)。每个块被视为一个“实例”,一组块构成一个“包”。将单个块表示池化在一起以预测最终的包级标签。从算法上讲,包中的单个块实例不需要标签,因此允许我们以弱监督的方式学习包级标签。它们还使用排列不变池化函数,这使得预测与块顺序无关,并允许高效地聚合信息。通常使用基于注意力的池化函数,这不仅允许高效聚合,还为包中的每个块提供注意力值。这些值表示对应块在预测中的重要性,可以可视化以更好地理解模型预测。 这个可解释性元素对于推动这些模型在现实世界中的应用至关重要,我们使用如加法 MIL 模型等变体来实现这种空间可解释性。从计算角度来看,MIL 模型绕过了将神经网络应用于大图像尺寸的问题,因为补丁表示与图像大小无关。

在 PathAI,我们使用基于深度网络的定制 MIL 模型来预测图像级标签。此过程的概述如下:

  1. 使用不同的采样方法从幻灯片中选择补丁。
  2. 根据随机采样或启发式规则构建补丁包。
  3. 基于预训练模型或大规模表示学习模型,为每个实例生成补丁表示。
  4. 应用排列不变池化函数以获取最终的幻灯片级别得分。

现在我们已经了解了 PyTorch 中 MIL 的一些高级细节,让我们看看一些代码,看看如何使用 PyTorch 从构思到生产代码的简单实现。我们首先定义一个采样器、转换和我们的 MIL 数据集:

# Create a bag sampler which randomly samples patches from a slide
bag_sampler = RandomBagSampler(bag_size=12)

# Setup the transformations
crop_transform = FlipRotateCenterCrop(use_flips=True)

# Create the dataset which loads patches for each bag
train_dataset = MILDataset(
  bag_sampler=bag_sampler,
  samples_loader=sample_loader,
  transform=crop_transform,
)

在我们定义了采样器和数据集之后,我们需要定义我们将使用该数据集训练的模型。PyTorch 熟悉的模型定义语法使得这变得简单易行,同时允许我们创建定制模型。

classifier = DefaultPooledClassifier(hidden_dims=[256, 256], input_dims=1024, output_dims=1)

pooling = DefaultAttentionModule(
  input_dims=1024,
  hidden_dims=[256, 256],
  output_activation=StableSoftmax()
)

# Define the model which is a composition of the featurizer, pooling module and a classifier
model = DefaultMILGraph(featurizer=ShuffleNetV2(), classifier=classifier, pooling = pooling)

由于这些模型是端到端训练的,它们提供了一种直接从数吉像素的全景切片图像到单个标签的强大方式。由于它们广泛适用于不同的生物学问题,它们的实现和部署的两个方面很重要:

  1. 对管道的每个部分的可配置控制,包括数据加载器、模型的模块化部分以及它们之间的交互。
  2. 能够快速迭代通过构思-实施-实验-产品化的循环。

PyTorch 在 MIL 建模方面具有各种优势。它提供了一种直观的方式来创建动态计算图,具有灵活的控制流,非常适合快速研究实验。地图样式的数据集、可配置的采样器和批采样器使我们能够自定义构建补丁包的方式,从而实现更快的实验。由于 MIL 模型 I/O 密集,数据并行性和 Python 数据加载器使得任务非常高效且用户友好。最后,PyTorch 的面向对象特性使得构建可重用模块成为可能,这些模块有助于快速实验、可维护的实现以及构建管道组合组件的便捷性。

在 PyTorch 中使用 GNN 探索空间组织结构

在健康组织和病变组织中,细胞的空間排列和结构往往与细胞本身一样重要。例如,在评估肺癌时,病理学家试图观察肿瘤细胞的整体分组和结构(它们是否形成实心层?或者是否以较小的、局部的簇出现?)以确定癌症是否属于具有截然不同预后的特定亚型。细胞与其他组织结构之间的这种空间关系可以通过图来建模,以同时捕捉组织拓扑和细胞组成。图神经网络(GNN)允许在学习这些图中的空间模式时,关联其他临床变量,例如某些癌症中基因的过表达。

在 2020 年底,当 PathAI 开始在组织样本上使用 GNN 时,PyTorch 通过 PyG 包提供了最佳且最成熟的 GNN 功能支持。鉴于我们知道 GNN 模型将是我们要探索的重要机器学习概念,因此 PyTorch 自然成为我们团队的首选。

在组织样本的背景下,GNN 的主要增值之一是图本身可以揭示仅通过视觉检查难以发现的时空关系。在我们的最近 AACR 出版物中,我们展示了通过使用 GNN,我们可以更好地理解免疫细胞聚集(特别是三级淋巴结构,或 TLS)在肿瘤微环境中的存在如何影响患者预后。在这种情况下,GNN 方法被用来预测与 TLS 存在相关的基因表达,并识别 TLS 区域本身之外与 TLS 相关的组织学特征。这些关于基因表达的见解在没有机器学习模型辅助的情况下,从组织样本图像中很难识别。

我们在 GNN 变体中取得成功的一个最有希望的例子是自注意力图池化。让我们看看我们如何使用 PyTorch 和 PyG 定义我们的自注意力图池化(SAGPool)模型:

class SAGPool(torch.nn.Module):
  def __init__(self, ...):
    super().__init__()
    self.conv1 = GraphConv(in_features, hidden_features, aggr='mean')
    self.convs = torch.nn.ModuleList()
    self.pools = torch.nn.ModuleList()
    self.convs.extend([GraphConv(hidden_features, hidden_features, aggr='mean') for i in range(num_layers - 1)])
    self.pools.extend([SAGPooling(hidden_features, ratio, GNN=GraphConv, min_score=min_score) for i in range((num_layers) // 2)])
    self.jump = JumpingKnowledge(mode='cat')
    self.lin1 = Linear(num_layers * hidden_features, hidden_features)
    self.lin2 = Linear(hidden_features, out_features)
    self.out_activation = out_activation
    self.dropout = dropout

在上面的代码中,我们首先定义了一个单个卷积图层,然后添加了两个模块列表层,这使得我们可以传递可变数量的层。然后,我们取我们的空模块列表,并追加可变数量的 GraphConv 层,然后追加可变数量的 SAGPooling 层。我们通过添加跳跃知识层、两个线性层、我们的激活函数和我们的 dropout 值来完成我们的 SAGPool 定义。PyTorch 直观的语法允许我们抽象出使用 SAG Poolings 等最先进方法的工作复杂性,同时保持我们熟悉的常见模型开发方法。

如上所述的我们的 SAG Pool 模型只是 GNNs 与 PyTorch 一起使我们能够探索新思想和创新的一个例子。我们最近还探索了多模态 CNN-GNN 混合模型,其准确率比传统的病理学家共识评分高出 20%。这些创新以及传统 CNN 和 GNN 之间的相互作用是由从研究到生产模型开发的短循环所实现的。

提高患者预后

为了实现我们利用人工智能病理学改善患者预后的使命,PathAI 需要依赖一个机器学习开发框架,该框架(1)在开发初期和探索阶段促进快速迭代和易于扩展(即模型配置为代码)(2)将模型训练和推理扩展到海量图像(3)轻松且稳健地为我们的产品生产使用(包括临床试验等)提供服务。正如我们所展示的,PyTorch 为我们提供了所有这些能力以及更多。我们对 PyTorch 的未来感到无比兴奋,迫不及待地想看看我们还能用这个框架解决哪些有影响力的挑战。