• 教程 >
  • PyTorch 菜谱 >
  • 使用 Flask 进行部署
快捷键

使用 Flask 进行部署

创建时间:2025 年 4 月 1 日 | 最后更新时间:2025 年 4 月 1 日 | 最后验证:未验证

在本菜谱中,您将学习:

  • 如何将您的训练好的 PyTorch 模型封装在 Flask 容器中,并通过 Web API 进行暴露

  • 如何将传入的 Web 请求翻译成 PyTorch 张量以供您的模型使用

  • 如何打包您的模型输出以生成 HTTP 响应

需求

您需要一个安装了以下包(及其依赖项)的 Python 3 环境:

  • PyTorch 1.5

  • TorchVision 0.6.0

  • Flask 1.1

可选,要获取一些支持文件,您需要 git。

PyTorch 和 TorchVision 的安装说明可在 pytorch.org 找到。Flask 的安装说明可在 Flask 网站上找到。

什么是 Flask?

Flask 是一个用 Python 编写的轻量级 Web 服务器。它为您提供了一个方便的方式,可以快速设置一个用于从您的训练好的 PyTorch 模型进行预测的 Web API,无论是直接使用,还是作为更大系统中的 Web 服务。

设置和支持文件

我们将创建一个 Web 服务,该服务接收图像,并将它们映射到 ImageNet 数据集的 1000 个类别之一。为此,您需要一个用于测试的图像文件。可选地,您还可以获取一个文件,该文件将模型输出的类别索引映射到可读的类别名称。

选项 1:快速获取两个文件 ¶

您可以通过检出 TorchServe 仓库并将文件复制到您的工作文件夹来快速获取这两个支持文件。(注意:本教程不依赖于 TorchServe - 这只是获取文件的一个快捷方式。)从您的 shell 提示符执行以下命令:

git clone https://github.com/pytorch/serve
cp serve/examples/image_classifier/kitten.jpg .
cp serve/examples/image_classifier/index_to_name.json .

您已经得到了它们!

选项 2:自带图片 ¶

在下面的 Flask 服务中, index_to_name.json 文件是可选的。您可以使用自己的图像测试您的服务 - 只需确保它是一个 3 色 JPEG 即可。

构建您的 Flask 服务 ¶

本食谱末尾展示了 Flask 服务的完整 Python 脚本;您可以将其复制粘贴到自己的 app.py 文件中。以下我们将逐一查看各个部分,以使它们的功能更加清晰。

导入部分

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request

按顺序:

  • 我们将使用来自 torchvision.models 的预训练 DenseNet 模型

  • torchvision.transforms 包含用于操作您的图像数据的工具

  • Pillow( PIL )是我们最初用于加载图像文件的工具

  • 当然我们需要从 flask 获取类

预处理 ¶

def transform_image(infile):
    input_transforms = [transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)
    timg = my_transforms(image)
    timg.unsqueeze_(0)
    return timg

网络请求给我们提供了一个图像文件,但我们的模型期望一个形状为(N, 3, 224, 224)的 PyTorch 张量,其中 N 是输入批次的物品数量。(我们只有一个批次大小。)我们首先做的是组合一组 TorchVision 转换,用于调整图像大小和裁剪,将其转换为张量,然后对张量中的值进行归一化。(有关此归一化的更多信息,请参阅 torchvision.models_ 的文档。)

之后,我们打开文件并应用转换。转换返回一个形状为(3, 224, 224)的张量 - 一个 224x224 图像的 3 个颜色通道。因为我们需要将这张单独的图像变成一个批次,所以我们使用 unsqueeze_(0) 调用在原地修改张量,添加一个新的第一个维度。张量包含相同的数据,但现在形状为(1, 3, 224, 224)。

通常情况下,即使您不处理图像数据,也需要将 HTTP 请求的输入转换为 PyTorch 可以处理的张量。

推理 ¶

def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)
    _, y_hat = outputs.max(1)
    prediction = y_hat.item()
    return prediction

推理本身是最简单的部分:当我们把输入张量传递给模型时,我们得到一个张量,其中的值代表了模型估计该图像属于特定类别的可能性。 max() 调用找到最大可能性值的类别,并返回该值以及 ImageNet 类别索引。最后,我们使用 item() 调用从包含它的张量中提取该类别索引,并返回它。

后处理 ¶

def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name

render_prediction() 方法将预测的类别索引映射到可读的类别标签。通常,在从模型获得预测后,会进行后处理,以便将预测准备好供人类消费或供其他软件使用。

运行完整的 Flask 应用 ¶

将以下内容粘贴到名为 app.py 的文件中:

import io
import json
import os

import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)
model = models.densenet121(pretrained=True)               # Trained on 1000 classes from ImageNet
model.eval()                                              # Turns off autograd



img_class_map = None
mapping_file_path = 'index_to_name.json'                  # Human-readable names for Imagenet classes
if os.path.isfile(mapping_file_path):
    with open (mapping_file_path) as f:
        img_class_map = json.load(f)



# Transform input into the form our model expects
def transform_image(infile):
    input_transforms = [transforms.Resize(255),           # We use multiple TorchVision transforms to ready the image
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],       # Standard normalization for ImageNet model input
            [0.229, 0.224, 0.225])]
    my_transforms = transforms.Compose(input_transforms)
    image = Image.open(infile)                            # Open the image file
    timg = my_transforms(image)                           # Transform PIL image to appropriately-shaped PyTorch tensor
    timg.unsqueeze_(0)                                    # PyTorch models expect batched input; create a batch of 1
    return timg


# Get a prediction
def get_prediction(input_tensor):
    outputs = model.forward(input_tensor)                 # Get likelihoods for all ImageNet classes
    _, y_hat = outputs.max(1)                             # Extract the most likely class
    prediction = y_hat.item()                             # Extract the int value from the PyTorch tensor
    return prediction

# Make the prediction human-readable
def render_prediction(prediction_idx):
    stridx = str(prediction_idx)
    class_name = 'Unknown'
    if img_class_map is not None:
        if stridx in img_class_map is not None:
            class_name = img_class_map[stridx][1]

    return prediction_idx, class_name


@app.route('/', methods=['GET'])
def root():
    return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'})


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        file = request.files['file']
        if file is not None:
            input_tensor = transform_image(file)
            prediction_idx = get_prediction(input_tensor)
            class_id, class_name = render_prediction(prediction_idx)
            return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    app.run()

要从您的 shell 提示符启动服务器,请输入以下命令:

FLASK_APP=app.py flask run

默认情况下,您的 Flask 服务器正在监听 5000 端口。一旦服务器启动,请打开另一个终端窗口,测试您的新推理服务器:

curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "file=@kitten.jpg"

如果一切设置正确,你应该收到以下类似的响应:

{"class_id":285,"class_name":"Egyptian_cat"}

重要资源 ¶

  • 请访问 pytorch.org 获取安装说明,以及更多文档和教程

  • Flask 网站有一个快速入门指南,其中详细介绍了设置简单的 Flask 服务


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源