画像オブジェクト検出モデルをトレーニングする

このページでは、Google Cloud コンソールまたは Vertex AI API を使用して、画像データセットから AutoML オブジェクト検出モデルをトレーニングする方法について説明します。

AutoML モデルをトレーニングする

Google Cloud コンソール

  1. Google Cloud コンソールの [Vertex AI] セクションで、[データセット] ページに移動します。

    [データセット] ページに移動

  2. モデルのトレーニングに使用するデータセットの名前をクリックして、詳細ページを開きます。

  3. [新しいモデルのトレーニング] をクリックします。

  4. トレーニング メソッドとして [AutoML] を選択します。

  5. [モデルを使用する場所の選択] セクションで、モデルのホストの場所を選択します。クラウドエッジ、または Vertex AI ビジョン

  6. [続行] をクリックします。

  7. モデルの名前を入力します。

  8. トレーニング データの分割方法を手動で設定する場合は、[ADVANCED OPTIONS] を開き、データ分割オプションを選択します(詳細はこちら)。

  9. [トレーニングを開始] をクリックします。

    データのサイズ、複雑さ、トレーニング予算(指定された場合)に応じて、モデルのトレーニングに何時間もかかることがあります。このタブを閉じて、後で戻ることもできます。モデルのトレーニングが完了すると、メールが送られてきます。

API

お使いの言語または環境に応じて、以下のタブを選択してください。

REST

リクエストのデータを使用する前に、次のように置き換えます。

  • LOCATION: データセットが配置され、モデルが作成されるリージョン。例: us-central1
  • PROJECT: 実際のプロジェクト ID
  • TRAININGPIPELINE_DISPLAYNAME: 必須。trainingPipeline の表示名。
  • DATASET_ID: トレーニングに使用するデータセットの ID 番号。
  • fractionSplit: 省略可。ML が使用する可能性のあるデータ分割オプションの 1 つ。fractionSplit の場合、値の合計 1 でなければなりません。例:
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME*: TrainingPipeline によってアップロード(作成)されたモデルの表示名。
  • MODEL_DESCRIPTION*: モデルの説明。
  • modelToUpload.labels*: モデルを整理するための任意の Key-Value ペアのセット。例:
    • "env": "prod"
    • "tier": "backend"
  • MODELTYPE: トレーニングする Cloud ホスト型モデルのタイプ。次のオプションがあります。
    • CLOUD_1 - Google Cloud 内で使用するために最適化されたモデルです。このモデルはエクスポートできません。上記の CLOUD_HIGH_ACCURACY_1 モデルや CLOUD_LOW_LATENCY_1 モデルと比較すると、予測品質が高く、レイテンシが低いことが期待されます。
    • CLOUD_HIGH_ACCURACY_1 - Google Cloud 内で使用するために最適化されたモデルです。このモデルはエクスポートできません。このモデルでは、レイテンシが高くなることが予想されますが、他のクラウドモデルよりも予測品質が高くなります。
    • CLOUD_LOW_LATENCY_1 - Google Cloud 内で使用するために最適化されたモデルです。このモデルはエクスポートできません。このモデルはでは、レイテンシが低くなることが予想されますが、他のクラウドモデルよりも予測品質が低くなる可能性があります。
    その他のモデルタイプのオプションは、リファレンス ドキュメントに記載されています。
  • NODE_HOUR_BUDGET: 実際のトレーニング料金はこの値以下になります。Cloud モデルでは、予算は 20,000~900,000 ミリノード時間(上限、下限を含む)の範囲に収める必要があります。デフォルト値は 216,000 で、9 ノードが使用された場合の 1 日の経過時間に相当します。
  • PROJECT_NUMBER: プロジェクトに自動生成されたプロジェクト番号

HTTP メソッドと URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

リクエストの本文(JSON):

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml",
  "trainingTaskInputs": {
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

リクエストを送信するには、次のいずれかのオプションを選択します。

curl

リクエスト本文を request.json という名前のファイルに保存して、次のコマンドを実行します。

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

リクエスト本文を request.json という名前のファイルに保存して、次のコマンドを実行します。

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

レスポンスには、仕様と TRAININGPIPELINE_ID に関する情報が含まれています。

Java

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Java の設定手順を完了してください。詳細については、Vertex AI Java API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageObjectDetectionInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageObjectDetectionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageObjectDetectionSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageObjectDetectionSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_object_detection_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageObjectDetectionInputs autoMlImageObjectDetectionInputs =
          AutoMlImageObjectDetectionInputs.newBuilder()
              .setModelType(ModelType.CLOUD_HIGH_ACCURACY_1)
              .setBudgetMilliNodeHours(20000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageObjectDetectionInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Object Detection Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Node.js の設定手順を完了してください。詳細については、Vertex AI Node.js API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';

const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageObjectDetectionInputs.ModelType;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageObjectDetection() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const trainingTaskInputsObj =
    new definition.AutoMlImageObjectDetectionInputs({
      disableEarlyStopping: false,
      modelType: ModelType.CLOUD_1,
      budgetMilliNodeHours: 20000,
    });

  const trainingTaskInputs = trainingTaskInputsObj.toValue();
  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId: datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline image object detection response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineImageObjectDetection();

Vertex AI SDK for Python

Vertex AI SDK for Python のインストールまたは更新の方法については、Vertex AI SDK for Python をインストールするをご覧ください。 詳細については、Vertex AI SDK for Python API のリファレンス ドキュメントをご覧ください。

from google.cloud import aiplatform
from google.cloud.aiplatform.gapic.schema import trainingjob


def create_training_pipeline_image_object_detection_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    model_display_name: str,
    location: str = "us-central1",
    api_endpoint: str = "us-central1-aiplatform.googleapis.com",
):
    # The AI Platform services require regional API endpoints.
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)
    training_task_inputs = trainingjob.definition.AutoMlImageObjectDetectionInputs(
        model_type="CLOUD_HIGH_ACCURACY_1",
        budget_milli_node_hours=20000,
        disable_early_stopping=False,
    ).to_value()

    training_pipeline = {
        "display_name": display_name,
        "training_task_definition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_object_detection_1.0.0.yaml",
        "training_task_inputs": training_task_inputs,
        "input_data_config": {"dataset_id": dataset_id},
        "model_to_upload": {"display_name": model_display_name},
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)

REST を使用してデータ分割を制御する

トレーニング セット、検証セット、テストセットの間でトレーニング データをどのように分割するかを制御できます。Vertex AI API を使用する場合は、Split オブジェクトを使用してデータの分割を決定します。トレーニング データの分割に使用できるオブジェクト タイプの一つとして、Split オブジェクトを InputConfig オブジェクトに含めることができます。選択できるメソッドは 1 つのみです。

  • FractionSplit:
    • TRAINING_FRACTION: トレーニング セットに使用されるトレーニング データの割合。
    • VALIDATION_FRACTION: 検証セットに使用されるトレーニング データの割合。動画データに対しては使用されません。
    • TEST_FRACTION: テストセットに使用されるトレーニング データの割合。

    いずれか一つでも指定する場合は、すべてを指定する必要があります。割合の合計が 1.0 になるようにしてください。割合のデフォルト値は、データ型によって異なります。詳細については、こちらをご覧ください。

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
    
  • FilterSplit:
    • TRAINING_FILTER: このフィルタに一致するデータ項目がトレーニング セットに使用されます。
    • VALIDATION_FILTER: このフィルタに一致するデータ項目が検証セットに使用されます。動画データの場合は "-" にする必要があります。
    • TEST_FILTER: このフィルタに一致するデータ項目がテストセットに使用されます。

    これらのフィルタは、ml_use ラベル、またはデータに適用された任意のラベルとともに使用できます。ml-use ラベルその他のラベルを使用してデータをフィルタリングする方法をご確認ください。

    次の例は、検証セットを含む、ml_use ラベルを持つ filterSplit オブジェクトの使用方法を示しています。

    "filterSplit": {
    "trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
    "validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
    "testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
    }