备注
点击此处下载完整示例代码
基于 Ax 的多目标神经架构搜索(NAS)
创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日
作者:David Eriksson、Max Balandat 以及 Meta 的适应性实验团队
在本教程中,我们展示了如何使用 Ax 在流行的 MNIST 数据集上运行多目标神经架构搜索(NAS)来对简单的神经网络模型进行搜索。虽然底层方法通常用于更复杂的模型和更大的数据集,但我们选择了一个可以在 20 分钟内轻松从头到尾在笔记本电脑上运行的教程。
在许多网络附加存储(NAS)应用中,多个目标之间存在自然权衡。例如,当在设备上部署模型时,我们可能希望最大化模型性能(例如,准确性),同时最小化竞争指标,如功耗、推理延迟或模型大小,以满足部署限制。通常,我们可以通过接受模型性能略微降低来显著减少计算需求或预测延迟。探索此类权衡的原理性方法是实现可扩展和可持续人工智能的关键推动力,并在 Meta 有许多成功的应用 - 例如,请参阅我们关于自然语言理解模型的案例研究。
在本例中,我们将调整两个隐藏层的宽度、学习率、dropout 概率、批量大小和训练轮数。目标是权衡性能(验证集上的准确性)和模型大小(模型参数数量)。
本教程使用了以下 PyTorch 库:
PyTorch Lightning(指定模型和训练循环)
TorchX(用于远程/异步运行训练作业)
BoTorch(为 Ax 的算法提供动力的贝叶斯优化库)
定义 TorchX 应用程序 §
我们的目的是优化在 mnist_train_nas.py 中定义的 PyTorch Lightning 训练作业。为此,我们编写一个辅助函数,该函数接受训练作业的架构和超参数的值,并创建具有适当设置的 TorchX AppDef。
from pathlib import Path
import torchx
from torchx import specs
from torchx.components import utils
def trainer(
log_path: str,
hidden_size_1: int,
hidden_size_2: int,
learning_rate: float,
epochs: int,
dropout: float,
batch_size: int,
trial_idx: int = -1,
) -> specs.AppDef:
# define the log path so we can pass it to the TorchX ``AppDef``
if trial_idx >= 0:
log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()
return utils.python(
# command line arguments to the training script
"--log_path",
log_path,
"--hidden_size_1",
str(hidden_size_1),
"--hidden_size_2",
str(hidden_size_2),
"--learning_rate",
str(learning_rate),
"--epochs",
str(epochs),
"--dropout",
str(dropout),
"--batch_size",
str(batch_size),
# other config options
name="trainer",
script="mnist_train_nas.py",
image=torchx.version.TORCHX_IMAGE,
)
设置运行器
Ax 的运行器抽象允许编写对各种后端的接口。Ax 已经自带了 TorchX 的运行器,所以我们只需要进行配置。在本教程中,我们将以完全异步的方式在本地运行作业。
为了在集群上启动它们,您可以指定不同的 TorchX 调度器并相应地调整配置。例如,如果您有一个 Kubernetes 集群,只需将调度器从 local_cwd
更改为 kubernetes
)。
import tempfile
from ax.runners.torchx import TorchXRunner
# Make a temporary dir to log our results into
log_dir = tempfile.mkdtemp()
ax_runner = TorchXRunner(
tracker_base="/tmp/",
component=trainer,
# NOTE: To launch this job on a cluster instead of locally you can
# specify a different scheduler and adjust arguments appropriately.
scheduler="local_cwd",
component_const_params={"log_path": log_dir},
cfg={},
)
设置 SearchSpace
¶
首先,我们定义我们的搜索空间。Ax 支持整数和浮点类型的范围参数以及具有非数值类型(如字符串)的选择参数。我们将以范围参数调整隐藏层大小、学习率、dropout 和训练轮数,并将批处理大小作为有序选择参数进行调整,以确保其为 2 的幂。
from ax.core import (
ChoiceParameter,
ParameterType,
RangeParameter,
SearchSpace,
)
parameters = [
# NOTE: In a real-world setting, hidden_size_1 and hidden_size_2
# should probably be powers of 2, but in our simple example this
# would mean that ``num_params`` can't take on that many values, which
# in turn makes the Pareto frontier look pretty weird.
RangeParameter(
name="hidden_size_1",
lower=16,
upper=128,
parameter_type=ParameterType.INT,
log_scale=True,
),
RangeParameter(
name="hidden_size_2",
lower=16,
upper=128,
parameter_type=ParameterType.INT,
log_scale=True,
),
RangeParameter(
name="learning_rate",
lower=1e-4,
upper=1e-2,
parameter_type=ParameterType.FLOAT,
log_scale=True,
),
RangeParameter(
name="epochs",
lower=1,
upper=4,
parameter_type=ParameterType.INT,
),
RangeParameter(
name="dropout",
lower=0.0,
upper=0.5,
parameter_type=ParameterType.FLOAT,
),
ChoiceParameter( # NOTE: ``ChoiceParameters`` don't require log-scale
name="batch_size",
values=[32, 64, 128, 256],
parameter_type=ParameterType.INT,
is_ordered=True,
sort_values=True,
),
]
search_space = SearchSpace(
parameters=parameters,
# NOTE: In practice, it may make sense to add a constraint
# hidden_size_2 <= hidden_size_1
parameter_constraints=[],
)
设置指标
Ax 有一个指标的概念,它定义了结果属性以及如何获取这些结果。这允许例如编码从某些分布式执行后端获取数据的方式以及在进行处理之前将其作为输入传递给 Ax。
在本教程中,我们将使用多目标优化,目标是最大化验证准确率并最小化模型参数数量。后者代表模型延迟的简单代理,对于小型机器学习模型来说,准确估计其延迟是困难的(在实际应用中,我们会在设备上运行模型时进行延迟基准测试)。
在我们的示例中,TorchX 将以完全异步的方式在本地运行训练任务,并将结果写入基于试验索引的 log_dir
(参见上面的 trainer()
函数)。我们将定义一个了解该日志目录的度量类。通过继承 TensorboardCurveMetric,我们可以免费获得读取和解析 TensorBoard 日志的逻辑。
from ax.metrics.tensorboard import TensorboardMetric
from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer
class MyTensorboardMetric(TensorboardMetric):
# NOTE: We need to tell the new TensorBoard metric how to get the id /
# file handle for the TensorBoard logs from a trial. In this case
# our convention is to just save a separate file per trial in
# the prespecified log dir.
def _get_event_multiplexer_for_trial(self, trial):
mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)
mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)
mul.Reload()
return mul
# This indicates whether the metric is queryable while the trial is
# still running. We don't use this in the current tutorial, but Ax
# utilizes this to implement trial-level early-stopping functionality.
@classmethod
def is_available_while_running(cls):
return False
现在我们可以实例化准确性和模型参数数量的度量。curve_name 是 TensorBoard 日志中度量的名称,而 name 是 Ax 内部使用的度量名称。我们还指定 lower_is_better 以指示两个度量的有利方向。
val_acc = MyTensorboardMetric(
name="val_acc",
tag="val_acc",
lower_is_better=False,
)
model_num_params = MyTensorboardMetric(
name="num_params",
tag="num_params",
lower_is_better=True,
)
设置 OptimizationConfig
¶
通过 OptimizationConfig 来告诉 Ax 应该优化什么。这里我们使用 MultiObjectiveOptimizationConfig
,因为我们将会执行多目标优化。
此外,Ax 支持通过指定目标阈值来对不同的指标施加约束,这些阈值将限制我们想要探索的输出空间中感兴趣的区域。例如,我们将验证准确率限制在至少 0.94(94%),并将模型参数数量限制在最多 80,000 个。
from ax.core import MultiObjective, Objective, ObjectiveThreshold
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
opt_config = MultiObjectiveOptimizationConfig(
objective=MultiObjective(
objectives=[
Objective(metric=val_acc, minimize=False),
Objective(metric=model_num_params, minimize=True),
],
),
objective_thresholds=[
ObjectiveThreshold(metric=val_acc, bound=0.94, relative=False),
ObjectiveThreshold(metric=model_num_params, bound=80_000, relative=False),
],
)
创建 Ax 实验
在 Ax 中,实验对象是存储有关问题设置所有信息的对象。
from ax.core import Experiment
experiment = Experiment(
name="torchx_mnist",
search_space=search_space,
optimization_config=opt_config,
runner=ax_runner,
)
选择生成策略
生成策略是对我们希望如何执行优化的抽象表示。虽然这可以自定义(如果您想这样做,请参阅本教程),但在大多数情况下,Ax 可以根据搜索空间、优化配置以及我们想要运行的试验总数自动确定一个合适的策略。
通常,Ax 会选择在开始基于模型的贝叶斯优化策略之前评估一定数量的随机配置。
total_trials = 48 # total evaluation budget
from ax.modelbridge.dispatch_utils import choose_generation_strategy
gs = choose_generation_strategy(
search_space=experiment.search_space,
optimization_config=experiment.optimization_config,
num_trials=total_trials,
)
配置调度器
Scheduler
作为优化的循环控制。它与后端通信以启动试验、检查其状态和检索结果。在本教程的情况下,它只是读取和解析本地保存的日志。在远程执行环境中,它会调用 API。以下来自 Ax 调度器教程的插图总结了调度器如何与用于运行试验评估的外部系统交互:

Scheduler
需要使用 Experiment
和 GenerationStrategy
。可以通过 SchedulerOptions
传入一组选项。在这里,我们配置了总的评估次数以及 max_pending_trials
,即应同时运行的最多试验次数。在我们的本地设置中,这是作为独立进程运行的训练作业数量,而在远程执行设置中,这将是你想要并行使用的机器数量。
from ax.service.scheduler import Scheduler, SchedulerOptions
scheduler = Scheduler(
experiment=experiment,
generation_strategy=gs,
options=SchedulerOptions(
total_trials=total_trials, max_pending_trials=4
),
)
运行优化过程
一切配置完成后,我们可以让 Ax 以全自动方式运行优化。调度器将定期检查日志以获取所有当前运行试验的状态,如果试验完成,调度器将更新实验的状态并获取贝叶斯优化算法所需的观察结果。
scheduler.run_all_trials()
评估结果
我们现在可以使用 Ax 提供的辅助函数和可视化工具来检查优化的结果。
首先,我们生成一个包含实验结果摘要的数据框。数据框中的每一行对应一个试验(即运行过的训练作业),包含有关试验状态、评估的参数配置和观察到的指标值的信息。这为验证优化提供了一个简单的方法。
from ax.service.utils.report_utils import exp_to_df
df = exp_to_df(experiment)
df.head(10)
我们还可以可视化验证准确性和模型参数数量之间的权衡的帕累托前沿。
提示
Ax 使用 Plotly 生成交互式图表,允许您进行缩放、裁剪或悬停以查看图表组件的详细信息。您可以试一试,如果您想了解更多,可以查看可视化教程)。
下面的图中展示了最终的优化结果,其中颜色对应于每个试验的迭代次数。我们看到我们的方法能够成功探索权衡,并找到了具有高验证准确率的大模型以及具有相对较低验证准确率的小模型。
from ax.service.utils.report_utils import _pareto_frontier_scatter_2d_plotly
_pareto_frontier_scatter_2d_plotly(experiment)
为了更好地理解我们的代理模型对黑盒目标函数的学习情况,我们可以查看留一法交叉验证的结果。由于我们的模型是高斯过程,它们不仅提供点预测,还提供这些预测的不确定性估计。一个好的模型意味着预测的均值(图中的点)接近 45 度线,并且置信区间以预期的频率覆盖 45 度线(在这里我们使用 95%置信区间,因此我们预计它们会在 95%的时间内包含真实观测值)。
如下所示,模型大小( num_params
)指标比验证准确率( val_acc
)指标更容易建模。
from ax.modelbridge.cross_validation import compute_diagnostics, cross_validate
from ax.plot.diagnostic import interact_cross_validation_plotly
from ax.utils.notebook.plotting import init_notebook_plotting, render
cv = cross_validate(model=gs.model) # The surrogate model is stored on the ``GenerationStrategy``
compute_diagnostics(cv)
interact_cross_validation_plotly(cv)
我们还可以绘制轮廓图,以更好地理解不同目标如何依赖于两个输入参数。在下面的图中,我们展示了模型预测的验证准确率作为两个隐藏大小的函数。随着隐藏大小的增加,验证准确率明显提高。
from ax.plot.contour import interact_contour_plotly
interact_contour_plotly(model=gs.model, metric_name="val_acc")
类似地,我们在下面的图中展示了模型参数数量作为隐藏大小的函数,并观察到它也随着隐藏大小的增加而增加(对 hidden_size_1
的依赖性要大得多)。
interact_contour_plotly(model=gs.model, metric_name="num_params")
致谢
我们感谢 TorchX 团队(特别是 Kiuk Chung 和 Tristan Rice)在将 TorchX 与 Ax 集成方面的帮助。
脚本总运行时间:(0 分钟 0.000 秒)