Klassifizierungs- oder Regressionsmodell trainieren

Auf dieser Seite erfahren Sie, wie Sie mithilfe der Google Cloud Console oder der Vertex AI API eine Klassifizierung oder Regression aus einem tabellarischen Dataset trainieren.

Vorbereitung

Bevor Sie ein Modell trainieren können, müssen Sie die folgenden Schritte ausführen:

Modell trainieren

Google Cloud Console

  1. Rufen Sie in der Google Cloud Console im Abschnitt "Vertex AI" die Seite Datasets auf.

    Zur Seite „Datasets“

  2. Klicken Sie auf den Namen des Datasets, das Sie zum Trainieren Ihres Modells verwenden möchten, um dessen Detailseite zu öffnen.

  3. Wenn Ihr Datentyp Annotationssätze verwendet, wählen Sie den Annotationssatz aus, den Sie für dieses Modell verwenden möchten.

  4. Klicken Sie auf Neues Modell trainieren.

  5. Wählen Sie Sonstiges aus.

  6. Führen Sie im Fenster Neues Modell trainieren die folgenden Schritte aus:

    1. Wählen Sie die Modelltrainingsmethode aus.

      • AutoML ist eine gute Wahl für eine Vielzahl von Anwendungsfällen.

      Klicken Sie auf Weiter.

    2. Geben Sie den Anzeigenamen für das neue Modell ein.

    3. Wählen Sie die Zielspalte aus.

      Die Zielspalte ist der Wert, den das Modell vorhersagt.

      Weitere Informationen zu den Anforderungen an Zielspalten.

    4. Optional: Markieren Sie Test-Dataset nach BigQuery exportieren und geben Sie den Namen der Tabelle an, um Ihr Test-Dataset nach BigQuery zu exportieren.

    5. Optional: Wenn Sie festlegen möchten, wie die Daten auf Trainings-, Test- und Validierungs-Datasets aufgeteilt werden, öffnen Sie die Erweiterten Optionen. Sie können zwischen den folgenden Datenaufteilungsoptionen wählen:

      • Zufällig (Standard): Vertex AI wählt nach dem Zufallsprinzip die Zeilen aus, die mit den einzelnen Datasets verknüpft sind. Standardmäßig wählt Vertex AI 80 % Ihrer Datenzeilen für das Trainings-Dataset, 10 % für das Validierungs-Dataset und 10 % für das Test-Dataset aus.
      • Manuell: Vertex AI wählt Datenzeilen für jedes Dataset anhand der Werte in einer Spalte für die Datenaufteilung aus. Geben Sie den Namen der Spalte für die Datenaufteilung an.
      • Chronologisch: Vertex AI teilt Daten basierend auf dem Zeitstempel in einer Zeitspalte auf. Geben Sie den Namen der Zeitspalte an.

      Weitere Informationen zur Datenaufteilung.

    6. Klicken Sie auf Weiter.

    7. Optional: Klicken Sie auf Statistiken generieren. Durch das Generieren von Statistiken werden die Drop-down-Menüs Transformation ausgefüllt.

    8. Prüfen Sie auf der Seite „Trainingsoptionen“ Ihre Spaltenliste und schließen Sie alle Spalten aus dem Training aus, die nicht zum Trainieren des Modells verwendet werden sollen.

    9. Prüfen Sie die für Ihre eingeschlossenen Features ausgewählten Transformationen und ob ungültige Daten zulässig sind, und nehmen Sie ggf. Aktualisierungen vor.

      Weitere Informationen zu Transformationen und ungültigen Daten.

    10. Wenn Sie eine Gewichtungsspalte angeben oder das Standardoptimierungsziel ändern möchten, öffnen Sie Erweiterte Optionen und treffen Sie Ihre Auswahl.

      Weitere Informationen finden Sie unter Gewichtungsspalten und Optimierungsziele.

    11. Klicken Sie auf Weiter.

    12. Konfigurieren Sie auf der Seite Computing und Preise Folgendes:

      Geben Sie an, wie viele Stunden das Modell maximal trainiert werden soll.

      Mit dieser Einstellung können Sie die Trainingskosten begrenzen. Die tatsächlich benötigte Zeit kann aber länger sein als dieser Wert, da auch noch andere Vorgänge am Erstellen eines neuen Modells beteiligt sind.

      Die empfohlene Trainingszeit hängt von der Größe der Trainingsdaten ab. Die folgende Tabelle zeigt die empfohlene Trainingszeit nach Zeilenanzahl. Eine große Anzahl von Spalten erhöht auch die Trainingszeit.

      Zeilen Vorgeschlagene Trainingszeit
      Unter 100.000 1–3 Stunden
      100.000–1.000.000 1–6 Stunden
      1.000.000–10.000.000 1–12 Stunden
      Über 10.000.000 3–24 Stunden
      Informationen zu den Preisen für Trainings finden Sie auf der Seite „Preise“.

    13. Klicken Sie auf START TRAINING (Training starten).

      Das Modelltraining kann viele Stunden dauern, je nach Größe und Komplexität Ihrer Daten und Ihres Trainingsbudgets, sofern Sie eines angegeben haben. Sie können diesen Tab schließen und später zu ihm zurückkehren. Wenn das Training für Ihr Modell abgeschlossen ist, erhalten eine E-Mail.

API

Wählen Sie ein tabellarisches Datentypziel aus.

Klassifizierung

Wählen Sie einen Tab für Ihre Sprache oder Ihre Umgebung aus:

REST

Sie verwenden den Befehl trainingPipelines.create, um ein Modell zu trainieren.

Modell trainieren

Ersetzen Sie diese Werte in den folgenden Anfragedaten:

  • LOCATION: Ihre Region.
  • PROJECT: Ihre Projekt-ID.
  • TRAININGPIPELINE_DISPLAY_NAME: Anzeigename für die Trainingspipeline, die für diesen Vorgang erstellt wurde.
  • TARGET_COLUMN: Die Spalte (Wert), für die das Modell Vorhersagen treffen soll.
  • WEIGHT_COLUMN: (Optional) Die Gewichtungsspalte. Weitere Informationen
  • TRAINING_BUDGET: Die maximale Zeit, die das Modell trainiert werden soll, in Milli-Knotenstunden (1.000 Milli-Knotenstunden entsprechen einer Knotenstunde).
  • OPTIMIZATION_OBJECTIVE: Nur erforderlich, wenn Sie das standardmäßige Optimierungsziel für Ihren Vorhersagetyp nicht verwenden möchten. Weitere Informationen
  • TRANSFORMATION_TYPE: Der Transformationstyp wird für jede Spalte bereitgestellt, die zum Trainieren des Modells verwendet wird. Weitere Informationen
  • COLUMN_NAME: Der Name der Spalte mit dem angegebenen Transformationstyp. Jede Spalte, die zum Trainieren des Modells verwendet wird, muss angegeben werden.
  • MODEL_DISPLAY_NAME: Anzeigename für das neu trainierte Modell.
  • DATASET_ID: ID für das Trainings-Dataset.
  • Sie können ein Split-Objekt zur Steuerung Ihrer Datenaufteilung bereitstellen. Weitere Informationen zur Datenaufteilung finden Sie unter Datenaufteilung mit REST steuern.
  • PROJECT_NUMBER: Die automatisch generierte Projektnummer Ihres Projekts.

HTTP-Methode und URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

JSON-Text anfordern:

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "classification",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

Wenn Sie die Anfrage senden möchten, maximieren Sie eine der folgenden Optionen:

Sie sollten eine JSON-Antwort ähnlich wie diese erhalten:

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

Bevor Sie dieses Beispiel anwenden, folgen Sie den Java-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Java API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableClassification(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableClassification(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      Transformation transformation1 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_width").build())
              .build();
      Transformation transformation2 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("sepal_length").build())
              .build();
      Transformation transformation3 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_length").build())
              .build();
      Transformation transformation4 =
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("petal_width").build())
              .build();

      ArrayList<Transformation> transformationArrayList = new ArrayList<>();
      transformationArrayList.add(transformation1);
      transformationArrayList.add(transformation2);
      transformationArrayList.add(transformation3);
      transformationArrayList.add(transformation4);

      AutoMlTablesInputs autoMlTablesInputs =
          AutoMlTablesInputs.newBuilder()
              .setTargetColumn(targetColumn)
              .setPredictionType("classification")
              .addAllTransformations(transformationArrayList)
              .setTrainBudgetMilliNodeHours(8000)
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlTablesInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Classification Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Bevor Sie dieses Beispiel anwenden, folgen Sie den Node.js-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Node.js API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'sepal_width'}},
    {auto: {column_name: 'sepal_length'}},
    {auto: {column_name: 'petal_length'}},
    {auto: {column_name: 'petal_width'}},
  ];
  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    targetColumn: targetColumn,
    predictionType: 'classification',
    transformations: transformations,
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-log-loss',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesClassification();

Python

Informationen zur Installation des Vertex AI SDK for Python finden Sie unter Vertex AI SDK for Python installieren. Weitere Informationen finden Sie in der Referenzdokumentation zur Python API.

def create_training_pipeline_tabular_classification_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = None,
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_classification_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="classification"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_classification_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Regression

Wählen Sie einen Tab für Ihre Sprache oder Ihre Umgebung aus:

REST

Sie verwenden den Befehl trainingPipelines.create, um ein Modell zu trainieren.

Modell trainieren

Ersetzen Sie diese Werte in den folgenden Anfragedaten:

  • LOCATION: Ihre Region.
  • PROJECT: Ihre Projekt-ID.
  • TRAININGPIPELINE_DISPLAY_NAME: Anzeigename für die Trainingspipeline, die für diesen Vorgang erstellt wurde.
  • TARGET_COLUMN: Die Spalte (Wert), für die das Modell Vorhersagen treffen soll.
  • WEIGHT_COLUMN: (Optional) Die Gewichtungsspalte. Weitere Informationen
  • TRAINING_BUDGET: Die maximale Zeit, die das Modell trainiert werden soll, in Milli-Knotenstunden (1.000 Milli-Knotenstunden entsprechen einer Knotenstunde).
  • OPTIMIZATION_OBJECTIVE: Nur erforderlich, wenn Sie das standardmäßige Optimierungsziel für Ihren Vorhersagetyp nicht verwenden möchten. Weitere Informationen
  • TRANSFORMATION_TYPE: Der Transformationstyp wird für jede Spalte bereitgestellt, die zum Trainieren des Modells verwendet wird. Weitere Informationen
  • COLUMN_NAME: Der Name der Spalte mit dem angegebenen Transformationstyp. Jede Spalte, die zum Trainieren des Modells verwendet wird, muss angegeben werden.
  • MODEL_DISPLAY_NAME: Anzeigename für das neu trainierte Modell.
  • DATASET_ID: ID für das Trainings-Dataset.
  • Sie können ein Split-Objekt zur Steuerung Ihrer Datenaufteilung bereitstellen. Weitere Informationen zur Datenaufteilung finden Sie unter Datenaufteilung mit REST steuern.
  • PROJECT_NUMBER: Die automatisch generierte Projektnummer Ihres Projekts.

HTTP-Methode und URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

JSON-Text anfordern:

{
    "displayName": "TRAININGPIPELINE_DISPLAY_NAME",
    "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
    "trainingTaskInputs": {
        "targetColumn": "TARGET_COLUMN",
        "weightColumn": "WEIGHT_COLUMN",
        "predictionType": "regression",
        "trainBudgetMilliNodeHours": TRAINING_BUDGET,
        "optimizationObjective": "OPTIMIZATION_OBJECTIVE",
        "transformations": [
            {"TRANSFORMATION_TYPE_1":  {"column_name" : "COLUMN_NAME_1"} },
            {"TRANSFORMATION_TYPE_2":  {"column_name" : "COLUMN_NAME_2"} },
            ...
    },
    "modelToUpload": {"displayName": "MODEL_DISPLAY_NAME"},
    "inputDataConfig": {
      "datasetId": "DATASET_ID",
    }
}

Wenn Sie die Anfrage senden möchten, maximieren Sie eine der folgenden Optionen:

Sie sollten eine JSON-Antwort ähnlich wie diese erhalten:

{
  "name": "projects/PROJECT_NUMBER/locations/us-central1/trainingPipelines/4567",
  "displayName": "myModelName",
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tabular_1.0.0.yaml",
  "modelToUpload": {
    "displayName": "myModelName"
  },
  "state": "PIPELINE_STATE_PENDING",
  "createTime": "2020-08-18T01:22:57.479336Z",
  "updateTime": "2020-08-18T01:22:57.479336Z"
}

Java

Bevor Sie dieses Beispiel anwenden, folgen Sie den Java-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Java API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.


import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.AutoTransformation;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlTablesInputs.Transformation.TimestampTransformation;
import com.google.rpc.Status;
import java.io.IOException;
import java.util.ArrayList;

public class CreateTrainingPipelineTabularRegressionSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String project = "YOUR_PROJECT_ID";
    String modelDisplayName = "YOUR_DATASET_DISPLAY_NAME";
    String datasetId = "YOUR_DATASET_ID";
    String targetColumn = "TARGET_COLUMN";
    createTrainingPipelineTableRegression(project, modelDisplayName, datasetId, targetColumn);
  }

  static void createTrainingPipelineTableRegression(
      String project, String modelDisplayName, String datasetId, String targetColumn)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      LocationName locationName = LocationName.of(project, location);
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml";

      // Set the columns used for training and their data types
      ArrayList<Transformation> tranformations = new ArrayList<>();
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("BOOLEAN_2unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("TIMESTAMP_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("DATE_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(AutoTransformation.newBuilder().setColumnName("TIME_1unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setTimestamp(
                  TimestampTransformation.newBuilder()
                      .setColumnName("DATETIME_1unique_NULLABLE")
                      .setInvalidValuesAllowed(true))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.STRING_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.FLOAT_5000unique_REPEATED"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE"))
              .build());
      tranformations.add(
          Transformation.newBuilder()
              .setAuto(
                  AutoTransformation.newBuilder()
                      .setColumnName("STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE"))
              .build());

      AutoMlTablesInputs trainingTaskInputs =
          AutoMlTablesInputs.newBuilder()
              .addAllTransformations(tranformations)
              .setTargetColumn(targetColumn)
              .setPredictionType("regression")
              .setTrainBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              // supported regression optimisation objectives: minimize-rmse,
              // minimize-mae, minimize-rmsle
              .setOptimizationObjective("minimize-rmse")
              .build();

      FractionSplit fractionSplit =
          FractionSplit.newBuilder()
              .setTrainingFraction(0.8)
              .setValidationFraction(0.1)
              .setTestFraction(0.1)
              .build();

      InputDataConfig inputDataConfig =
          InputDataConfig.newBuilder()
              .setDatasetId(datasetId)
              .setFractionSplit(fractionSplit)
              .build();
      Model modelToUpload = Model.newBuilder().setDisplayName(modelDisplayName).build();

      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(modelDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(trainingTaskInputs))
              .setInputDataConfig(inputDataConfig)
              .setModelToUpload(modelToUpload)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Tabular Regression Response");
      System.out.format("\tName: %s\n", trainingPipelineResponse.getName());
      System.out.format("\tDisplay Name: %s\n", trainingPipelineResponse.getDisplayName());
      System.out.format(
          "\tTraining Task Definition: %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "\tTraining Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "\tTraining Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());

      System.out.format("\tState: %s\n", trainingPipelineResponse.getState());
      System.out.format("\tCreate Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("\tStart Time: %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("\tEnd Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("\tUpdate Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("\tLabels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfigResponse = trainingPipelineResponse.getInputDataConfig();
      System.out.println("\tInput Data Config");
      System.out.format("\t\tDataset Id: %s\n", inputDataConfigResponse.getDatasetId());
      System.out.format(
          "\t\tAnnotations Filter: %s\n", inputDataConfigResponse.getAnnotationsFilter());

      FractionSplit fractionSplitResponse = inputDataConfigResponse.getFractionSplit();
      System.out.println("\t\tFraction Split");
      System.out.format(
          "\t\t\tTraining Fraction: %s\n", fractionSplitResponse.getTrainingFraction());
      System.out.format(
          "\t\t\tValidation Fraction: %s\n", fractionSplitResponse.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", fractionSplitResponse.getTestFraction());

      FilterSplit filterSplit = inputDataConfigResponse.getFilterSplit();
      System.out.println("\t\tFilter Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", filterSplit.getTrainingFilter());
      System.out.format("\t\t\tValidation Fraction: %s\n", filterSplit.getValidationFilter());
      System.out.format("\t\t\tTest Fraction: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfigResponse.getPredefinedSplit();
      System.out.println("\t\tPredefined Split");
      System.out.format("\t\t\tKey: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfigResponse.getTimestampSplit();
      System.out.println("\t\tTimestamp Split");
      System.out.format("\t\t\tTraining Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("\t\t\tValidation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("\t\t\tTest Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("\t\t\tKey: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("\tModel To Upload");
      System.out.format("\t\tName: %s\n", modelResponse.getName());
      System.out.format("\t\tDisplay Name: %s\n", modelResponse.getDisplayName());
      System.out.format("\t\tDescription: %s\n", modelResponse.getDescription());
      System.out.format("\t\tMetadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("\t\tMeta Data: %s\n", modelResponse.getMetadata());
      System.out.format("\t\tTraining Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("\t\tArtifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "\t\tSupported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList().toString());
      System.out.format(
          "\t\tSupported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList().toString());
      System.out.format(
          "\t\tSupported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList().toString());

      System.out.format("\t\tCreate Time: %s\n", modelResponse.getCreateTime());
      System.out.format("\t\tUpdate Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("\t\tLables: %s\n", modelResponse.getLabelsMap());
      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();

      System.out.println("\tPredict Schemata");
      System.out.format("\t\tInstance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format(
          "\t\tParameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format(
          "\t\tPrediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (Model.ExportFormat supportedExportFormat :
          modelResponse.getSupportedExportFormatsList()) {
        System.out.println("\tSupported Export Format");
        System.out.format("\t\tId: %s\n", supportedExportFormat.getId());
      }
      ModelContainerSpec containerSpec = modelResponse.getContainerSpec();

      System.out.println("\tContainer Spec");
      System.out.format("\t\tImage Uri: %s\n", containerSpec.getImageUri());
      System.out.format("\t\tCommand: %s\n", containerSpec.getCommandList());
      System.out.format("\t\tArgs: %s\n", containerSpec.getArgsList());
      System.out.format("\t\tPredict Route: %s\n", containerSpec.getPredictRoute());
      System.out.format("\t\tHealth Route: %s\n", containerSpec.getHealthRoute());

      for (EnvVar envVar : containerSpec.getEnvList()) {
        System.out.println("\t\tEnv");
        System.out.format("\t\t\tName: %s\n", envVar.getName());
        System.out.format("\t\t\tValue: %s\n", envVar.getValue());
      }

      for (Port port : containerSpec.getPortsList()) {
        System.out.println("\t\tPort");
        System.out.format("\t\t\tContainer Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("\tDeployed Model");
        System.out.format("\t\tEndpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("\t\tDeployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("\tError");
      System.out.format("\t\tCode: %s\n", status.getCode());
      System.out.format("\t\tMessage: %s\n", status.getMessage());
    }
  }
}

Node.js

Bevor Sie dieses Beispiel anwenden, folgen Sie den Node.js-Einrichtungsschritten in der Vertex AI-Kurzanleitung zur Verwendung von Clientbibliotheken. Weitere Informationen finden Sie in der Referenzdokumentation zur Vertex AI Node.js API.

Richten Sie zur Authentifizierung bei Vertex AI Standardanmeldedaten für Anwendungen ein. Weitere Informationen finden Sie unter Authentifizierung für eine lokale Entwicklungsumgebung einrichten.

/**
 * TODO(developer): Uncomment these variables before running the sample.\
 * (Not necessary if passing values as arguments)
 */

// const datasetId = 'YOUR_DATASET_ID';
// const modelDisplayName = 'YOUR_MODEL_DISPLAY_NAME';
// const trainingPipelineDisplayName = 'YOUR_TRAINING_PIPELINE_DISPLAY_NAME';
// const targetColumn = 'YOUR_TARGET_COLUMN';
// const project = 'YOUR_PROJECT_ID';
// const location = 'YOUR_PROJECT_LOCATION';
const aiplatform = require('@google-cloud/aiplatform');
const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;

// Imports the Google Cloud Pipeline Service Client library
const {PipelineServiceClient} = aiplatform.v1;
// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineTablesRegression() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  const transformations = [
    {auto: {column_name: 'STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'BOOLEAN_2unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'TIMESTAMP_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'DATE_1unique_NULLABLE'}},
    {auto: {column_name: 'TIME_1unique_NULLABLE'}},
    {
      timestamp: {
        column_name: 'DATETIME_1unique_NULLABLE',
        invalid_values_allowed: true,
      },
    },
    {auto: {column_name: 'STRUCT_NULLABLE.STRING_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.INTEGER_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REQUIRED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.FLOAT_5000unique_REPEATED'}},
    {auto: {column_name: 'STRUCT_NULLABLE.NUMERIC_5000unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.BOOLEAN_2unique_NULLABLE'}},
    {auto: {column_name: 'STRUCT_NULLABLE.TIMESTAMP_1unique_NULLABLE'}},
  ];

  const trainingTaskInputsObj = new definition.AutoMlTablesInputs({
    transformations,
    targetColumn,
    predictionType: 'regression',
    trainBudgetMilliNodeHours: 8000,
    disableEarlyStopping: false,
    optimizationObjective: 'minimize-rmse',
  });
  const trainingTaskInputs = trainingTaskInputsObj.toValue();

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {
    datasetId: datasetId,
    fractionSplit: {
      trainingFraction: 0.8,
      validationFraction: 0.1,
      testFraction: 0.1,
    },
  };
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition:
      'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_tables_1.0.0.yaml',
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {
    parent,
    trainingPipeline,
  };

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline tabular regression response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}
createTrainingPipelineTablesRegression();

Python

Informationen zur Installation des Vertex AI SDK for Python finden Sie unter Vertex AI SDK for Python installieren. Weitere Informationen finden Sie in der Referenzdokumentation zur Python API.

def create_training_pipeline_tabular_regression_sample(
    project: str,
    display_name: str,
    dataset_id: str,
    location: str = "us-central1",
    model_display_name: str = "my_model",
    target_column: str = "target_column",
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    tabular_regression_job = aiplatform.AutoMLTabularTrainingJob(
        display_name=display_name, optimization_prediction_type="regression"
    )

    my_tabular_dataset = aiplatform.TabularDataset(dataset_name=dataset_id)

    model = tabular_regression_job.run(
        dataset=my_tabular_dataset,
        target_column=target_column,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        model_display_name=model_display_name,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Datenaufteilung mit REST steuern

Sie können steuern, wie Ihre Trainingsdaten auf die Trainings-, Validierungs- und Test-Datasets aufgeteilt werden. Wenn Sie die Vertex AI API verwenden, nutzen Sie das Objekt Split, um Ihre Datenaufteilung anzugeben. Das Split-Objekt kann als einer von mehreren Objekttypen in das inputDataConfig-Objekt aufgenommen werden. Jedes Objekt bietet eine andere Möglichkeit zur Aufteilung der Trainingsdaten.

Welche Methoden zum Aufteilen Ihrer Daten verwendet werden können, hängt vom Datentyp ab:

  • FractionSplit:

    • TRAINING_FRACTION: Der Anteil der Trainingsdaten, die für das Trainings-Dataset verwendet werden sollen.
    • VALIDATION_FRACTION: Der Anteil der Trainingsdaten, die für das Validierungs-Dataset verwendet werden sollen.
    • TEST_FRACTION: Der Anteil der Trainingsdaten, die für das Trainings-Dataset verwendet werden sollen.

    Wenn Bruchzahlen angegeben werden, müssen alle angegeben werden. Die Bruchwerte müssen zusammengenommen 1,0 ergeben. Weitere Informationen.

    "fractionSplit": {
    "trainingFraction": TRAINING_FRACTION,
    "validationFraction": VALIDATION_FRACTION,
    "testFraction": TEST_FRACTION
    },
    

  • PredefinedSplit:

    • DATA_SPLIT_COLUMN: die Spalte mit den Datenaufteilungswerten (TRAIN, VALIDATION, TEST).

    Geben Sie die Datenaufteilung für jede Zeile manuell mithilfe einer geteilten Spalte an. Weitere Informationen.

    "predefinedSplit": {
      "key": DATA_SPLIT_COLUMN
    },
    
  • TimestampSplit:

    • TRAINING_FRACTION: Der Prozentsatz der Trainingsdaten, die für das Trainings-Dataset verwendet werden sollen. Der Standardfaktor ist 0,80.
    • VALIDATION_FRACTION: Der Prozentsatz der Trainingsdaten, die für das Validierungs-Dataset verwendet werden sollen. Der Standardwert ist 0,10.
    • TEST_FRACTION: Der Prozentsatz der Trainingsdaten, die für das Test-Dataset verwendet werden sollen. Der Standardwert ist 0,10.
    • TIME_COLUMN: Die Spalte mit den Zeitstempeln.

    Wenn Bruchzahlen angegeben werden, müssen alle angegeben werden. Die Bruchwerte müssen zusammengenommen 1,0 ergeben. Weitere Informationen.

    "timestampSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION,
      "key": TIME_COLUMN
    }
    

Optimierungsziele für Klassifizierungs- oder Regressionsmodelle

Wenn Sie ein Modell trainieren, wählt Vertex AI ein Standardoptimierungsziel basierend auf Ihrem Modelltyp und dem für die Zielspalte verwendeten Datentyp aus.

Klassifikationsmodelle eignen sich am besten für:
Optimierungsziel API-Wert Zweck
AUC ROC maximize-au-roc Den Bereich unter der Grenzwertoptimierungskurve (Receiver Operating Curve, ROC) maximieren. Sie können zwischen Klassen unterscheiden. Standardwert für die binäre Klassifizierung.
Logarithmischer Verlust minimize-log-loss Möglichst genaue Vorhersagewahrscheinlichkeiten erzielen. Nur unterstütztes Ziel für die Klassifizierung mehrerer Klassen.
AUC PR maximize-au-prc Den Bereich unter der Genauigkeits-/Trefferquotenkurve maximieren. Optimiert die Ergebnisse für Vorhersagen für die weniger gängige Klasse.
Präzision von Trefferquote maximize-precision-at-recall Optimieren Sie die Präzision bei einem bestimmten Trefferquotenwert.
Trefferquote von Präzision maximize-recall-at-precision Optimieren Sie die Trefferquote bei einem bestimmten Präzisionswert.
Regressionsmodelle eignen sich am besten für:
Optimierungsziel API-Wert Zweck
RMSE minimize-rmse Wurzel des mittleren quadratischen Fehlers (RMSE) minimieren. Mehr Extremwerte genau erfassen. Standardwert.
MAE minimize-mae Mittleren absoluten Fehler (MAE) minimieren. Extremwerte als Ausreißer mit geringerem Einfluss auf das Modell ansehen.
RMSLE minimize-rmsle Wurzel des mittleren quadratischen Logfehlers (RMSLE) minimieren. Abzüge für Fehler nach der relativen Größe statt des absoluten Werts vornehmen. Dies ist hilfreich, wenn sowohl die vorhergesagten als auch die tatsächlichen Werte sehr groß werden können.

Nächste Schritte