从自定义训练模型获取批量预测结果

本页面介绍了如何使用 Google Cloud 控制台或 Vertex AI API 从自定义训练模型获取批量预测结果。

如需发出批量预测请求,请指定输入源和 Vertex AI 存储批量预测结果的输出位置(Cloud StorageBigQuery)。

限制和要求

获取批量预测时,请考虑以下限制和要求:

  • 为了尽量减少处理时间,您的输入和输出位置必须位于相同的单区域或多区域。例如,如果输入位于 us-central1,则输出可以位于 us-central1US,但不能位于 europe-west4。如需了解详情,请参阅 Cloud Storage 位置BigQuery 位置
  • 您的输入和输出还必须与模型位于相同的单区域或多区域。
  • 不支持 Model Garden 模型。
  • BigQuery ML 模型不是自定义训练的模型。不过,在以下条件下,您可以使用本页面中的信息从 BigQuery ML 模型获取批量预测:
    • BigQuery ML 模型必须在 Vertex AI Model Registry 中注册。
    • 如需将 BigQuery 表用作输入,您必须使用 Vertex AI API 将 InstanceConfig.instanceType 设置为 "object"

输入数据要求

批量请求的输入指定要发送到模型进行预测的内容。我们支持以下输入格式:

JSON 行

使用 JSON 行文件指定用于进行预测的输入实例列表。将文件存储在 Cloud Storage 存储桶中。

示例 1

以下示例显示了一个 JSON 行文件,其中每行包含一个数组:

[1, 2, 3, 4]
[5, 6, 7, 8]

以下是 HTTP 请求正文中发送到预测容器的内容:

所有其他容器

{"instances": [ [1, 2, 3, 4], [5, 6, 7, 8] ]}

PyTorch 容器

{"instances": [
{ "data": [1, 2, 3, 4] },
{ "data": [5, 6, 7, 8] } ]}

示例 2

以下示例显示 JSON 行文件,其中每行包含一个对象。

{ "values": [1, 2, 3, 4], "key": 1 }
{ "values": [5, 6, 7, 8], "key": 2 }

以下是 HTTP 请求正文中发送到预测容器的内容。请注意,系统会将同一请求正文发送到所有容器。

{"instances": [
  { "values": [1, 2, 3, 4], "key": 1 },
  { "values": [5, 6, 7, 8], "key": 2 }
]}

示例 3

对于 PyTorch 预构建容器,请确保按照 TorchServe 默认处理程序的要求将每个实例封装在 data 字段中:Vertex AI 不会为您封装实例。例如:

{ "data": { "values": [1, 2, 3, 4], "key": 1 } }
{ "data": { "values": [5, 6, 7, 8], "key": 2 } }

以下是 HTTP 请求正文中发送到预测容器的内容:

{"instances": [
  { "data": { "values": [1, 2, 3, 4], "key": 1 } },
  { "data": { "values": [5, 6, 7, 8], "key": 2 } }
]}

TFRecord

TFRecord 格式保存输入实例。您可以选择使用 Gzip 压缩 TFRecord 文件。将 TFRecord 文件存储在 Cloud Storage 存储桶中。

Vertex AI 将 TFRecord 文件中的每个实例读取为二进制文件,然后使用名为 b64 的单个键对实例进行 base64 编码,使其成为 JSON 对象。

以下是 HTTP 请求正文中发送到预测容器的内容:

所有其他容器

{"instances": [
{ "b64": "b64EncodedASCIIString" },
{ "b64": "b64EncodedASCIIString" } ]}

PyTorch 容器

{"instances": [ { "data": {"b64": "b64EncodedASCIIString" } }, { "data": {"b64": "b64EncodedASCIIString" } }
]}

确保您的预测容器知道如何解码实例。

CSV

在 CSV 文件中每行指定一个输入实例。第一行必须是标题行。您必须将所有字符串用英文双引号 (") 括起来。Vertex AI 不接受包含换行符的单元值。不带英文引号的值会被读取为浮点数。

以下示例展示了包含两个输入实例的 CSV 文件:

"input1","input2","input3"
0.1,1.2,"cat1"
4.0,5.0,"cat2"

以下是 HTTP 请求正文中发送到预测容器的内容:

所有其他容器

{"instances": [ [0.1,1.2,"cat1"], [4.0,5.0,"cat2"] ]}

PyTorch 容器

{"instances": [
{ "data": [0.1,1.2,"cat1"] },
{ "data": [4.0,5.0,"cat2"] } ]}

文件列表

创建一个文本文件,其中每一行是一个文件的 Cloud Storage URI。Vertex AI 将每个文件的内容读取为二进制文件,然后使用名为 b64 的单个键对实例进行 base64 编码,使其成为 JSON 对象。

如果您计划使用 Google Cloud 控制台获取批量预测结果,请将文件列表直接粘贴到 Google Cloud 控制台中。否则,请将列表保存在 Cloud Storage 存储桶中。

以下示例展示了包含两个输入实例的文件列表:

gs://path/to/image/image1.jpg
gs://path/to/image/image2.jpg

以下是 HTTP 请求正文中发送到预测容器的内容:

所有其他容器

{ "instances": [
{ "b64": "b64EncodedASCIIString" },
{ "b64": "b64EncodedASCIIString" } ]}

PyTorch 容器

{ "instances": [ { "data": { "b64": "b64EncodedASCIIString" } }, { "data": { "b64": "b64EncodedASCIIString" } }
]}

确保您的预测容器知道如何解码实例。

BigQuery

将 BigQuery 表指定为 projectId.datasetId.tableId。 Vertex AI 将表中的每一行转换为 JSON 实例。

例如,如果您的表包含以下内容:

第 1 列 第 2 列 第 3 列
1.0 3.0 "Cat1"
2.0 4.0 "Cat2"

以下是 HTTP 请求正文中发送到预测容器的内容:

所有其他容器

{"instances": [ [1.0,3.0,"cat1"], [2.0,4.0,"cat2"] ]}

PyTorch 容器

{"instances": [
{ "data": [1.0,3.0,"cat1"] },
{ "data": [2.0,4.0,"cat2"] } ]}

以下是将 BigQuery 数据类型转换为 JSON 的方式:

BigQuery 类型 JSON 类型 示例值
字符串 字符串 "abc"
整数 整数 1
浮点数 浮点数 1.2
数字 浮点数 4925.000000000
布尔值 布尔值 true
TimeStamp 字符串 "2019-01-01 23:59:59.999999+00:00"
日期 字符串 "2018-12-31"
时间 字符串 "23:59:59.999999"
DateTime 字符串 "2019-01-01T00:00:00"
录制 对象 { "A": 1,"B": 2}
重复类型 数组类型 [1, 2]
嵌套记录 对象 {"A": {"a": 0}, "B": 1}

分区数据

批量预测使用 MapReduce 将输入分片到每个副本。为了使用 MapReduce 功能,输入应该是可分区的。

Vertex AI 会自动对 BigQuery文件列表JSON 行输入进行分区。

Vertex AI 不会自动对 CSV 文件进行分区,因为它们本身不适合分区。CSV 文件中的行不是自描述的,不是类型化的,而且可能包含新行。我们建议不要将 CSV 输入用于对吞吐量敏感的应用。

对于 TFRecord 输入,请务必将实例拆分为较小的文件,并使用通配符(例如 gs://my-bucket/*.tfrecord)将文件传递给作业,从而手动对数据进行分区。文件数量应至少为指定的副本数量。

过滤和转换输入数据

您可以通过在 BatchPredictionJob 请求中指定 instanceConfig 来过滤和转换批量输入。

借助过滤功能,您可以从预测请求中排除输入数据中的某些字段,也可以仅将输入数据中的部分字段添加到预测请求中,而无需在预测容器中执行任何自定义预处理或后处理。如果输入数据文件具有模型不需要的额外列(例如键或其他数据),则此功能非常有用。

借助转换功能,您可以以 JSON arrayobject 格式将实例发送到预测容器。如需了解详情,请参阅 instanceType

例如,如果您的输入表包含以下内容:

customerId col1 col2
1001 1 2
1002 5 6

并且您指定了以下 instanceConfig

{
  "name": "batchJob1",
  ...
  "instanceConfig": {
    "excludedFields":["customerId"]
    "instanceType":"object"
  }
}

然后,预测请求中的实例将作为 JSON 对象发送,并且不包含 customerId 列:

{"col1":1,"col2":2}
{"col1":5,"col2":6}

请注意,指定以下 instanceConfig 会产生相同的结果:

{
  "name": "batchJob1",
  ...
  "instanceConfig": {
    "includedFields": ["col1","col2"]
    "instanceType":"object"
  }
}

有关如何使用特征过滤条件的演示,请参阅使用特征过滤的自定义模型批量预测笔记本。

请求批量预测

对于批量预测请求,您可以使用 Google Cloud 控制台或 Vertex AI API。批量预测任务可能需要一些时间才能完成,具体取决于提交的输入数据项数量。

当您请求批量预测时,预测容器将以用户提供的自定义服务账号身份运行。读取/写入操作(例如从数据源读取预测实例或写入预测结果)均使用 Vertex AI 服务代理完成,默认情况下有权访问 BigQuery 和 Cloud Storage。

Google Cloud 控制台

使用 Google Cloud 控制台请求批量预测。

  1. 在 Google Cloud 控制台的 Vertex AI 部分中,前往批量预测页面。

前往“批量预测”页面

  1. 点击创建以打开新建批量预测窗口。

  2. 定义批量预测部分,完成以下步骤:

    1. 输入批量预测的名称。

    2. 对于模型名称,选择要用于此批量预测的模型的名称。

    3. 选择来源部分,选择适用于输入数据的来源:

      • 如果您已将输入设置为 JSON 行、CSV 或 TFRecord 格式,请选择 Cloud Storage 上的文件(JSON 行、CSV、TFRecord、TFRecord、Gzip)。然后在源路径字段中指定输入文件。
      • 如果您使用文件列表作为输入,请选择 Cloud Storage 上的文件(其他),然后将文件列表粘贴到下面的字段中。
      • 对于 BigQuery 输入,选择 BigQuery 路径。如果您选择 BigQuery 作为输入,则还必须选择 BigQuery 作为输出以及选择 Google 管理的加密密钥。BigQuery 不支持将客户管理的加密密钥 (CMEK) 用作输入/输出。
    4. 目标路径字段中,指定您希望 Vertex AI 存储批量预测输出的 Cloud Storage 目录。

    5. (可选)您可以选中为此模型启用特征归因,以便在批量预测响应中获取特征归因。然后点击修改配置说明设置。(如果您之前为模型配置了说明设置,则修改说明设置是可选的,否则需要这样做。)

    6. 为批量预测作业指定计算选项:计算节点数量机器类型,以及(可选)加速器类型加速器数量

  3. 可选:适用于批量预测的模型监控分析功能已推出预览版。请参阅将偏差检测配置添加到批量预测作业的前提条件

    1. 点击以为此批量预测启用模型监控

    2. 选择训练数据源。 输入您选择的训练数据源的数据路径或位置。

    3. 可选:在提醒阈值下,指定触发提醒的阈值。

    4. 通知电子邮件地址部分,输入一个或多个以逗号分隔的电子邮件地址,以便在模型超出提醒阈值时接收提醒。

    5. 可选:在通知渠道部分,添加 Cloud Monitoring 渠道,以便在模型超出提醒阈值时接收提醒。您可以选择现有的 Cloud Monitoring 渠道,也可以通过点击管理通知渠道来创建一个新的 Cloud Monitoring 渠道。Google Cloud 控制台支持 PagerDuty、Slack 和 Pub/Sub 通知渠道。

  4. 点击创建

API

使用 Vertex AI API 发送批量预测请求。根据您用于获取批量预测结果的工具选择相应的标签页。

REST

在使用任何请求数据之前,请先进行以下替换:

  • LOCATION_ID:存储模型和执行批量预测作业的区域。例如 us-central1

  • PROJECT_ID:您的项目 ID

  • BATCH_JOB_NAME:批量预测作业的显示名。

  • MODEL_ID:用于执行预测的模型的 ID。

  • INPUT_FORMAT输入数据格式jsonlcsvtf-recordtf-record-gzipfile-list

  • INPUT_URI:输入数据的 Cloud Storage URI。可能包含通配符。

  • OUTPUT_DIRECTORY:您希望 Vertex AI 用于保存输出的目录的 Cloud Storage URI。

  • MACHINE_TYPE:要用于此批量预测作业的机器资源

    您可以选择配置 machineSpec 字段使用加速器,但以下示例未展示这一设置。

  • BATCH_SIZE:每个预测请求中发送的实例数;默认值为 64。增加批次大小可以产生更高的吞吐量,但也可能会导致请求超时。

  • STARTING_REPLICA_COUNT:此批量预测作业的节点数。

HTTP 方法和网址:

POST https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs

请求 JSON 正文:

{
  "displayName": "BATCH_JOB_NAME",
  "model": "projects/PROJECT_ID/locations/LOCATION_ID/models/MODEL_ID",
  "inputConfig": {
    "instancesFormat": "INPUT_FORMAT",
    "gcsSource": {
      "uris": ["INPUT_URI"],
    },
  },
  "outputConfig": {
    "predictionsFormat": "jsonl",
    "gcsDestination": {
      "outputUriPrefix": "OUTPUT_DIRECTORY",
    },
  },
  "dedicatedResources" : {
    "machineSpec" : {
      "machineType": MACHINE_TYPE
    },
    "startingReplicaCount": STARTING_REPLICA_COUNT
  },
  "manualBatchTuningParameters": {
    "batch_size": BATCH_SIZE,
  }
}

如需发送请求,请选择以下方式之一:

curl

将请求正文保存在名为 request.json 的文件中,然后执行以下命令:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs"

PowerShell

将请求正文保存在名为 request.json 的文件中,然后执行以下命令:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION_ID-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION_ID/batchPredictionJobs" | Select-Object -Expand Content

您应该收到类似以下内容的 JSON 响应:

{
  "name": "projects/PROJECT_NUMBER/locations/LOCATION_ID/batchPredictionJobs/BATCH_JOB_ID",
  "displayName": "BATCH_JOB_NAME 202005291958",
  "model": "projects/PROJECT_ID/locations/LOCATION_ID/models/MODEL_ID",
  "inputConfig": {
    "instancesFormat": "jsonl",
    "gcsSource": {
      "uris": [
        "INPUT_URI"
      ]
    }
  },
  "outputConfig": {
    "predictionsFormat": "jsonl",
    "gcsDestination": {
      "outputUriPrefix": "OUTPUT_DIRECTORY"
    }
  },
  "state": "JOB_STATE_PENDING",
  "createTime": "2020-05-30T02:58:44.341643Z",
  "updateTime": "2020-05-30T02:58:44.341643Z",
}

Java

在尝试此示例之前,请按照《Vertex AI 快速入门:使用客户端库》中的 Java 设置说明执行操作。 如需了解详情,请参阅 Vertex AI Java API 参考文档

如需向 Vertex AI 进行身份验证,请设置应用默认凭据。 如需了解详情,请参阅为本地开发环境设置身份验证

在以下示例中,将 PREDICTIONS_FORMAT 替换为 jsonl。 如需了解如何替换其他占位符,请参阅本部分的 REST & CMD LINE 标签页。

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.AcceleratorType;
import com.google.cloud.aiplatform.v1.BatchDedicatedResources;
import com.google.cloud.aiplatform.v1.BatchPredictionJob;
import com.google.cloud.aiplatform.v1.GcsDestination;
import com.google.cloud.aiplatform.v1.GcsSource;
import com.google.cloud.aiplatform.v1.JobServiceClient;
import com.google.cloud.aiplatform.v1.JobServiceSettings;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.MachineSpec;
import com.google.cloud.aiplatform.v1.ModelName;
import com.google.protobuf.Value;
import java.io.IOException;

public class CreateBatchPredictionJobSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "PROJECT";
    String displayName = "DISPLAY_NAME";
    String modelName = "MODEL_NAME";
    String instancesFormat = "INSTANCES_FORMAT";
    String gcsSourceUri = "GCS_SOURCE_URI";
    String predictionsFormat = "PREDICTIONS_FORMAT";
    String gcsDestinationOutputUriPrefix = "GCS_DESTINATION_OUTPUT_URI_PREFIX";
    createBatchPredictionJobSample(
        project,
        displayName,
        modelName,
        instancesFormat,
        gcsSourceUri,
        predictionsFormat,
        gcsDestinationOutputUriPrefix);
  }

  static void createBatchPredictionJobSample(
      String project,
      String displayName,
      String model,
      String instancesFormat,
      String gcsSourceUri,
      String predictionsFormat,
      String gcsDestinationOutputUriPrefix)
      throws IOException {
    JobServiceSettings settings =
        JobServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();
    String location = "us-central1";

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (JobServiceClient client = JobServiceClient.create(settings)) {

      // Passing in an empty Value object for model parameters
      Value modelParameters = ValueConverter.EMPTY_VALUE;

      GcsSource gcsSource = GcsSource.newBuilder().addUris(gcsSourceUri).build();
      BatchPredictionJob.InputConfig inputConfig =
          BatchPredictionJob.InputConfig.newBuilder()
              .setInstancesFormat(instancesFormat)
              .setGcsSource(gcsSource)
              .build();
      GcsDestination gcsDestination =
          GcsDestination.newBuilder().setOutputUriPrefix(gcsDestinationOutputUriPrefix).build();
      BatchPredictionJob.OutputConfig outputConfig =
          BatchPredictionJob.OutputConfig.newBuilder()
              .setPredictionsFormat(predictionsFormat)
              .setGcsDestination(gcsDestination)
              .build();
      MachineSpec machineSpec =
          MachineSpec.newBuilder()
              .setMachineType("n1-standard-2")
              .setAcceleratorType(AcceleratorType.NVIDIA_TESLA_T4)
              .setAcceleratorCount(1)
              .build();
      BatchDedicatedResources dedicatedResources =
          BatchDedicatedResources.newBuilder()
              .setMachineSpec(machineSpec)
              .setStartingReplicaCount(1)
              .setMaxReplicaCount(1)
              .build();
      String modelName = ModelName.of(project, location, model).toString();
      BatchPredictionJob batchPredictionJob =
          BatchPredictionJob.newBuilder()
              .setDisplayName(displayName)
              .setModel(modelName)
              .setModelParameters(modelParameters)
              .setInputConfig(inputConfig)
              .setOutputConfig(outputConfig)
              .setDedicatedResources(dedicatedResources)
              .build();
      LocationName parent = LocationName.of(project, location);
      BatchPredictionJob response = client.createBatchPredictionJob(parent, batchPredictionJob);
      System.out.format("response: %s\n", response);
      System.out.format("\tName: %s\n", response.getName());
    }
  }
}

Python

如需了解如何安装或更新 Vertex AI SDK for Python,请参阅安装 Vertex AI SDK for Python。 如需了解详情,请参阅 Python API 参考文档

def create_batch_prediction_job_dedicated_resources_sample(
    project: str,
    location: str,
    model_resource_name: str,
    job_display_name: str,
    gcs_source: Union[str, Sequence[str]],
    gcs_destination: str,
    instances_format: str = "jsonl",
    machine_type: str = "n1-standard-2",
    accelerator_count: int = 1,
    accelerator_type: Union[str, aiplatform_v1.AcceleratorType] = "NVIDIA_TESLA_K80",
    starting_replica_count: int = 1,
    max_replica_count: int = 1,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    my_model = aiplatform.Model(model_resource_name)

    batch_prediction_job = my_model.batch_predict(
        job_display_name=job_display_name,
        gcs_source=gcs_source,
        gcs_destination_prefix=gcs_destination,
        instances_format=instances_format,
        machine_type=machine_type,
        accelerator_count=accelerator_count,
        accelerator_type=accelerator_type,
        starting_replica_count=starting_replica_count,
        max_replica_count=max_replica_count,
        sync=sync,
    )

    batch_prediction_job.wait()

    print(batch_prediction_job.display_name)
    print(batch_prediction_job.resource_name)
    print(batch_prediction_job.state)
    return batch_prediction_job

BigQuery

上一个 REST 示例使用 Cloud Storage 作为源和目标。如需改用 BigQuery,请进行以下更改:

  • inputConfig 字段更改为以下内容:

    "inputConfig": {
       "instancesFormat": "bigquery",
       "bigquerySource": {
          "inputUri": "bq://SOURCE_PROJECT_ID.SOURCE_DATASET_NAME.SOURCE_TABLE_NAME"
       }
    }
    
  • outputConfig 字段更改为以下内容:

    "outputConfig": {
       "predictionsFormat":"bigquery",
       "bigqueryDestination":{
          "outputUri": "bq://DESTINATION_PROJECT_ID.DESTINATION_DATASET_NAME.DESTINATION_TABLE_NAME"
       }
     }
    
  • 替换以下内容:

    • SOURCE_PROJECT_ID:源 Google Cloud 项目的 ID
    • SOURCE_DATASET_NAME:来源 BigQuery 数据集的名称
    • SOURCE_TABLE_NAME:BigQuery 源表的名称
    • DESTINATION_PROJECT_ID:目标 Google Cloud 项目的 ID
    • DESTINATION_DATASET_NAME:目标 BigQuery 数据集的名称
    • DESTINATION_TABLE_NAME:BigQuery 目标表的名称

特征重要性

如果您希望返回预测结果的特征重要性值,请将 generateExplanation 属性设置为 true。请注意,预测模型不支持特征重要性,因此您无法在批量预测请求中包含特征重要性。

特征重要性(有时称为特征归因)是 Vertex Explainable AI 的一部分

只有在为说明配置了 Model 或指定了 BatchPredictionJobexplanationSpec 字段时,才能将 generateExplanation 设置为 true

选择机器类型和副本数量

与使用较大的机器类型相比,通过增加副本数量进行横向扩缩能够以更线性且可预测的方式提高吞吐量。

通常,我们建议您尽可能为作业指定最小的机器类型并增加副本的数量。

为了获得成本效益,我们建议您选择副本数量,以使批量预测作业至少运行 10 分钟。这是因为您需要按每个副本节点时付费,其中包括每个副本启动所需的大约 5 分钟。仅处理几秒钟然后关停并不经济实惠。

一般来说,如果实例数为几千,我们建议将 starting_replica_count 设置为几十。如果实例数为几百万,我们建议将 starting_replica_count 设置为几百。您还可以使用以下公式来估算副本的数量:

N / (T * (60 / Tb))

其中:

  • N:作业中的批次数量。例如,100 万个实例 / 100 批次大小 = 10,000 个批次。
  • T:批量预测作业的预期时间。例如,10 分钟。
  • Tb:副本处理单个批次所需的时间(以秒为单位)。例如,在 2 核机器类型上,处理每个批次需要 1 秒。

在我们的示例中,10,000 个批次 /(10 分钟 * [60 / 1 秒])的结果向上取整得到 17 个副本。

与在线预测不同,批量预测作业不会自动扩缩。由于预先知道所有输入数据,因此系统会在作业开始时将数据划分到每个副本。系统会使用 starting_replica_count 参数。系统会忽略 max_replica_count 参数。

这些建议均为近似指导建议。它们不一定能为每个模型提供最佳吞吐量。它们不会提供对处理时间和费用的准确预测。并且它们不一定能针对每种场景实现最佳费用/吞吐量权衡。您可将建议用作合理的起点,并根据需要进行调整。如需衡量模型的吞吐量等特征,请运行找到理想的机器类型笔记本。

适用于 GPU 或 TPU 加速机器

遵循上述准则(也适用于 CPU 专用模型),并注意以下其他事项:

  • 您可能需要更多的 CPU 和 GPU(例如,用于数据预处理)。
  • GPU 机器类型的启动需要更多时间(10 分钟),因此您可能希望批量预测作业定位更长的时间(例如,至少 20 分钟,而不是 10 分钟),以便将合理比例的时间和费用花在生成预测结果上。

检索批量预测结果

批量预测任务完成后,预测的输出存储在您在请求中指定的 Cloud Storage 存储桶或 BigQuery 位置中。

批量预测结果示例

输出文件夹包含一组 JSON 行文件。

这些文件的名称为 {gcs_path}/prediction.results-{file_number}-of-{number_of_files_generated}。由于批量预测的分布式特性,文件数量不确定。

文件中的每一行对应输入中的一个实例,并具有以下键值对:

  • prediction:包含预测容器返回的值。
  • instance:对于文件列表,包含 Cloud Storage URI。对于所有其他输入格式,它包含 HTTP 请求正文中发送到预测容器的值。

示例 1

如果 HTTP 请求包含:

{
  "instances": [
    [1, 2, 3, 4],
    [5, 6, 7, 8]
]}

并且预测容器返回:

{
  "predictions": [
    [0.1,0.9],
    [0.7,0.3]
  ],
}

然后,JSON 行输出文件为:

{ "instance": [1, 2, 3, 4], "prediction": [0.1,0.9]}
{ "instance": [5, 6, 7, 8], "prediction": [0.7,0.3]}

示例 2

如果 HTTP 请求包含:

{
  "instances": [
    {"values": [1, 2, 3, 4], "key": 1},
    {"values": [5, 6, 7, 8], "key": 2}
]}

并且预测容器返回:

{
  "predictions": [
    {"result":1},
    {"result":0}
  ],
}

然后,JSON 行输出文件为:

{ "instance": {"values": [1, 2, 3, 4], "key": 1}, "prediction": {"result":1}}
{ "instance": {"values": [5, 6, 7, 8], "key": 2}, "prediction": {"result":0}}

使用 Explainable AI

我们不建议对大量数据运行基于特征的说明。这是因为每个输入都可能会根据一组可能的特征值扇出到数千个请求,这可能会导致处理时间和费用大幅增加。通常,通过小型数据集足以了解特征重要性。

批量预测不支持基于样本的说明

笔记本

后续步骤