分類モデルまたは回帰モデルをトレーニングする

このページでは、Google Cloud コンソールまたは Vertex AI API のどちらかを使用して、表形式のデータセットから分類モデルまたは回帰モデルをトレーニングする方法について説明します。

準備

モデルをトレーニングするには、次の作業を完了しておく必要があります。

モデルのトレーニング

Google Cloud コンソール

  1. Google Cloud コンソールの [Vertex AI] セクションで、[データセット] ページに移動します。

    [データセット] ページに移動

  2. モデルのトレーニングに使用するデータセットの名前をクリックして、詳細ページを開きます。

  3. 目的のデータ型でアノテーション セットが使用されている場合は、モデルで使用するアノテーション セットを選択します。

  4. [新しいモデルのトレーニング] をクリックします。

  5. [その他] を選択します。

  6. [新しいモデルをトレーニング] ページで、次の手順を行います。

    1. モデルのトレーニング方法を選択します。

      • AutoML は幅広いユースケースに適しています。

      [続行] をクリックします。

    2. 新しいモデルの表示名を入力します。

    3. ターゲット列を選択します。

      ターゲット列の値が、モデルによって予測されます。

      ターゲット列の要件をご覧ください。

    4. 省略可: テスト データセットを BigQuery にエクスポートするには、[Export test dataset to BigQuery] のチェックボックスをオンにして、テーブルの名前を指定します。

    5. 省略可: データをトレーニング セット、テストセット、検証セットに分割する方法を選択するには、[詳細オプション] を開きます。次のデータ分割オプションから選択できます。

      • ランダム(デフォルト): Vertex AI は各データセットに関連付けられた行をランダムに選択します。デフォルトでは、Vertex AI はデータ行の 80% をトレーニング セット、10% を検証セット、10% をテストセットに選択します。
      • 手動: Vertex AI は、データ分割列の値に基づいて各データセットのデータ行を選択します。データ分割列の名前を指定します。
      • 時系列: Vertex AI は、時間列のタイムスタンプに基づいてデータを分割します。時間列の名前を指定します。

      データ分割をご覧ください。

    6. [続行] をクリックします。

    7. 省略可: [統計情報を生成] をクリックします。統計情報を生成すると、[変換] プルダウン メニューにデータが入力されます。

    8. [トレーニング オプション] ページで、列のリストを確認して、モデルのトレーニングに使用すべきではない列を除外します。

    9. 含まれている特徴に対して選択された変換と、無効なデータが許可されているかどうかを確認して、必要な更新を行います。

      詳細については、変換無効なデータをご覧ください。

    10. 重み列の指定や最適化目標のデフォルトからの変更を行うには、[詳細オプション] を開いて、選択をします。

      詳しくは、重み列および最適化の目標をご覧ください。

    11. [続行] をクリックします。

    12. [コンピューティングと料金] ウィンドウで、次のように構成します。

      モデルのトレーニングの最大時間数を入力します。

      この設定により、トレーニング費用に上限を設定できます。新しいモデルの作成には他のオペレーションも必要なため、実際の経過時間がこの値より長くなることがあります。

      推奨されるトレーニング時間は、トレーニング データのサイズに応じて変わります。 次の表に、行数別の推奨トレーニング時間を示します。列の数が多い場合もトレーニング時間が長くなります。

      推奨トレーニング時間
      100,000 未満 1~3 時間
      100,000~1,000,000 1~6 時間
      1,000,000~10,000,000 1~12 時間
      10,000,000 以上 3~24 時間
      トレーニングの料金については、料金ページをご覧ください。

    13. [トレーニングを開始] をクリックします。

      データのサイズ、複雑さ、トレーニング予算(指定された場合)に応じて、モデルのトレーニングに何時間もかかることがあります。このタブを閉じて、後で戻ることもできます。モデルのトレーニングが完了すると、メールが送られてきます。

API

表形式データ型の目標を選択します。

分類

お使いの言語または環境に応じて、以下のタブを選択してください。

REST

trainingPipelines.create コマンドを使用してモデルをトレーニングします。

モデルをトレーニングします。

リクエストのデータを使用する前に、次のように置き換えます。

  • LOCATION: 使用するリージョン。
  • PROJECT: 実際のプロジェクト ID
  • TRAININGPIPELINE_DISPLAY_NAME: このオペレーション用に作成されたトレーニング パイプラインの表示名。
  • TARGET_COLUMN: このモデルに予測させる列(値)。
  • WEIGHT_COLUMN:(省略可)重み列。詳細
  • TRAINING_BUDGET: モデルをトレーニングするミリノード時間単位の最大時間(1,000 ミリノード時間は 1 ノード時間に等しい)。
  • OPTIMIZATION_OBJECTIVE: 予測タイプに対してデフォルトの最適化目標を指定しない場合にのみ必須となります。詳細
  • TRANSFORMATION_TYPE: モデルのトレーニングに使用される列ごとに変換タイプが指定されます。詳細
  • COLUMN_NAME: 列の名前および指定された変換タイプ。モデルのトレーニングに使用するすべての列を指定する必要があります。
  • MODEL_DISPLAY_NAME: 新しくトレーニングされたモデルの表示名。
  • DATASET_ID: トレーニング データセットの ID。
  • Split オブジェクトを指定すれば、データ分割を制御できます。データ分割の制御については、REST を使用したデータ分割の制御をご覧ください。
  • PROJECT_NUMBER: プロジェクトに自動生成されたプロジェクト番号

HTTP メソッドと URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

リクエストの本文(JSON):

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "classification",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

リクエストを送信するには、次のいずれかのオプションを展開します。

次のような JSON レスポンスが返されます。

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Java の設定手順を完了してください。詳細については、Vertex AI Java API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableClassification(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      Transformation transformation1 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build())
              .build();
      Transformation transformation2 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build())
              .build();
      Transformation transformation3 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build())
              .build();
      Transformation transformation4 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build())
              .build();

      ArrayList<Transformation> transformationArrayList = new ArrayList<>();
      transformationArrayList.add(transformation1);
      transformationArrayList.add(transformation2);
      transformationArrayList.add(transformation3);
      transformationArrayList.add(transformation4);

      AutoMlTablesInputs autoMlTablesInputs =
          AutoMlTablesInputs.newBuilder()
              .setTargetColumn(targetColumn)
              .setPredictionType("classification")
              .addAllTransformations(transformationArrayList)
              .setTrainBudgetMilliNodeHours(8000)
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Classification Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Node.js の設定手順を完了してください。詳細については、Vertex AI Node.js API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'sepal_width'}},
    {auto: {column_name: 'sepal_length'}},
    {auto: {column_name: 'petal_length'}},
    {auto: {column_name: 'petal_width'}},
  ];
  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    targetColumn: targetColumn,
    predictionType: 'classification',
    transformations: transformations,
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-log-loss',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesClassification();

Python

Vertex AI SDK for Python のインストールまたは更新の方法については、Vertex AI SDK for Python をインストールするをご覧ください。 詳細については、Python API リファレンス ドキュメントをご覧ください。

def create_training_pipeline_tabular_classification_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = None,
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_classification_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="classification"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_classification_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

回帰

お使いの言語または環境に応じて、以下のタブを選択してください。

REST

trainingPipelines.create コマンドを使用してモデルをトレーニングします。

モデルをトレーニングします。

リクエストのデータを使用する前に、次のように置き換えます。

  • LOCATION: 使用するリージョン。
  • PROJECT: 実際のプロジェクト ID
  • TRAININGPIPELINE_DISPLAY_NAME: このオペレーション用に作成されたトレーニング パイプラインの表示名。
  • TARGET_COLUMN: このモデルに予測させる列(値)。
  • WEIGHT_COLUMN:(省略可)重み列。詳細
  • TRAINING_BUDGET: モデルをトレーニングするミリノード時間単位の最大時間(1,000 ミリノード時間は 1 ノード時間に等しい)。
  • OPTIMIZATION_OBJECTIVE: 予測タイプに対してデフォルトの最適化目標を指定しない場合にのみ必須となります。詳細
  • TRANSFORMATION_TYPE: モデルのトレーニングに使用される列ごとに変換タイプが指定されます。詳細
  • COLUMN_NAME: 列の名前および指定された変換タイプ。モデルのトレーニングに使用するすべての列を指定する必要があります。
  • MODEL_DISPLAY_NAME: 新しくトレーニングされたモデルの表示名。
  • DATASET_ID: トレーニング データセットの ID。
  • Split オブジェクトを指定すれば、データ分割を制御できます。データ分割の制御については、REST を使用したデータ分割の制御をご覧ください。
  • PROJECT_NUMBER: プロジェクトに自動生成されたプロジェクト番号

HTTP メソッドと URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

リクエストの本文(JSON):

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "regression",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

リクエストを送信するには、次のいずれかのオプションを展開します。

次のような JSON レスポンスが返されます。

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Java の設定手順を完了してください。詳細については、Vertex AI Java API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularRegressionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableRegression(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      ArrayList<Transformation> tranformations = new ArrayList<>();
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("TIMESTAMP_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("DATETIME_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"))
              .build());

      AutoMlTablesInputs trainingTaskInputs =
          AutoMlTablesInputs.newBuilder()
              .addAllTransformations(tranformations)
              .setTargetColumn(targetColumn)
              .setPredictionType("regression")
              .setTrainBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              // supported regression optimisation objectives: minimize-rmse,
              // minimize-mae, minimize-rmsle
              .setOptimizationObjective("minimize-rmse")
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Regression Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

このサンプルを試す前に、Vertex AI クイックスタート: クライアント ライブラリの使用にある Node.js の設定手順を完了してください。詳細については、Vertex AI Node.js API のリファレンス ドキュメントをご覧ください。

Vertex AI に対する認証を行うには、アプリケーションのデフォルト認証情報を設定します。詳細については、ローカル開発環境の認証を設定するをご覧ください。

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesRegression() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'BOOLEAN_2unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'TIMESTAMP_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'DATE_1unique_NULLABLE'}},
    {auto: {column_name: 'TIME_1unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'DATETIME_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'STRUCT_NULLABLE.STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.BOOLEAN_2unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE'}},
  ];

  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    transformations,
    targetColumn,
    predictionType: 'regression',
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-rmse',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular regression response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesRegression();

Python

Vertex AI SDK for Python のインストールまたは更新の方法については、Vertex AI SDK for Python をインストールするをご覧ください。 詳細については、Python API リファレンス ドキュメントをご覧ください。

def create_training_pipeline_tabular_regression_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_regression_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="regression"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_regression_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

REST を使用してデータ分割を制御する

トレーニング セット、検証セット、テストセットの間でトレーニング データをどのように分割するかを制御できます。Vertex AI API を使用する場合は、Split オブジェクトを使用してデータの分割を決定します。トレーニング データの分割に使用できるオブジェクト タイプの 1 つとして、Split オブジェクトを inputDataConfig オブジェクトに含めることができます。

データの分割に使用できるメソッドは、データ型によって異なります。

  • FractionSplit の最初の項目として追加します:

    • TRAINING_FRACTION: トレーニング セットに使用されるトレーニング データの割合。
    • VALIDATION_FRACTION: 検証セットに使用されるトレーニング データの割合。
    • TEST_FRACTION: テストセットに使用されるトレーニング データの割合。

    いずれか 1 つでも指定する場合は、すべてを指定する必要があります。割合の合計が 1.0 になるようにしてください。詳細

    "fractionSplit": {
    "trainingFraction": TRAINING_FRACTION,
    "validationFraction": VALIDATION_FRACTION,
    "testFraction": TEST_FRACTION
    },
    

  • PredefinedSplit:

    • DATA_SPLIT_COLUMN: データ分割値を格納する列(TRAINVALIDATIONTEST)。

    分割列を使用して、各行のデータ分割を手動で指定します。詳細

    "predefinedSplit": {
      "key": DATA_SPLIT_COLUMN
    },
    
  • TimestampSplit:

    • TRAINING_FRACTION: トレーニング セットに使用されるトレーニング データの割合。デフォルトは 0.80 です。
    • VALIDATION_FRACTION: 検証セットに使用されるトレーニング データの割合。デフォルトは 0.10 です。
    • TEST_FRACTION: テストセットに使用されるトレーニング データの割合。デフォルトは 0.10 です。
    • TIME_COLUMN: タイムスタンプを含む列。

    いずれか 1 つでも指定する場合は、すべてを指定する必要があります。割合の合計が 1.0 になるようにしてください。詳細

    "timestampSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION,
      "key": TIME_COLUMN
    }
    

分類モデルまたは回帰モデルの最適化目標

モデルをトレーニングするときに、Vertex AI はモデルタイプとターゲット列に使用されるデータタイプに基づいて、デフォルトの最適化目標を選択します。

分類モデルは、次の場合に最適です。
最適化の目標 API 値 この目標が適している問題
AUC ROC maximize-au-roc 受信者操作特性(ROC)曲線の下の面積を最大化する。クラスを区別する。バイナリ分類のデフォルト値。
ログ損失 minimize-log-loss 予測確率をできるだけ正確に維持する。マルチクラス分類でサポートされている目標のみ。
AUC PR maximize-au-prc PR(適合率 / 再現率)曲線の下の面積を最大化する。あまり一般的でないクラスの予測結果を最適化する。
再現率での適合率 maximize-precision-at-recall 特定の再現率の値で適合率を最適化する。
適合率での再現率 maximize-recall-at-precision 特定の適合率の値で再現率を最適化する。
回帰モデルは、次の場合に最適です。
最適化の目標 API 値 この目標が適している問題
RMSE minimize-rmse 二乗平均平方根誤差(RMSE)を最小化する。より極端な値を正確に取り込む。デフォルト値。
MAE minimize-mae 平均絶対誤差(MAE)を最小化する。モデルへの影響を抑えて、極端な値を外れ値として表示する。
RMSLE minimize-rmsle 二乗平均平方根対数誤差(RMSLE)を最小化する。絶対値ではなく、相対サイズに基づいてエラーにペナルティを適用する。予測値と実際の値の両方が非常に大きくなる可能性がある場合に有用。

次のステップ