Entrenar un modelo de clasificación de imágenes

En esta página se explica cómo entrenar un modelo de clasificación de AutoML a partir de un conjunto de datos de imágenes mediante la Google Cloud consola o la API de Vertex AI.

Entrenar un modelo de AutoML

Google Cloud consola

  1. En la Google Cloud consola, en la sección Vertex AI, vaya a la página Conjuntos de datos.

    Ve a la página Conjuntos de datos.

  2. Haga clic en el nombre del conjunto de datos que quiera usar para entrenar su modelo y abra su página de detalles.

  3. Haz clic en Entrenar un modelo nuevo.

  4. En el método de entrenamiento, selecciona AutoML.

  5. Haz clic en Continuar.

  6. Introduzca un nombre para el modelo.

  7. Si quieres definir manualmente cómo se dividen los datos de entrenamiento, despliega Opciones avanzadas y selecciona una opción de división de datos. Más información

  8. Haz clic en Start Training (Iniciar entrenamiento).

    El entrenamiento del modelo puede llevar muchas horas, en función del tamaño y la complejidad de los datos, así como del presupuesto de entrenamiento, si has especificado uno. Puedes cerrar esta pestaña y volver a ella más adelante. Recibirás un correo cuando tu modelo haya terminado de entrenarse.

API

Seleccione la pestaña correspondiente a su objetivo:

Clasificación

Selecciona la pestaña correspondiente a tu idioma o entorno:

REST

Antes de usar los datos de la solicitud, haz las siguientes sustituciones:

  • LOCATION: región en la que se encuentra el conjunto de datos y se crea el modelo. Por ejemplo, us-central1.
  • PROJECT: tu ID de proyecto.
  • TRAININGPIPELINE_DISPLAYNAME: obligatorio. Nombre visible del recurso TrainingPipeline.
  • DATASET_ID: número de ID del conjunto de datos que se va a usar para el entrenamiento.
  • fractionSplit: opcional. Una de las varias opciones de uso de aprendizaje automático splitposibles para tus datos. En fractionSplit, los valores deben sumar 1. Por ejemplo:
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME*: nombre visible del modelo subido (creado) por TrainingPipeline.
  • MODEL_DESCRIPTION*: descripción del modelo.
  • modelToUpload.labels*: cualquier conjunto de pares clave-valor para organizar tus modelos. Por ejemplo:
    • "env": "prod"
    • "tier": "backend"
  • MODELTYPE: Tipo de modelo alojado en Cloud que se va a entrenar. Las opciones son:
    • CLOUD (predeterminado)
  • NODE_HOUR_BUDGET: El coste de formación real será igual o inferior a este valor. En el caso de los modelos de Cloud, el presupuesto debe ser de entre 8000 y 800.000 horas de nodo (ambos incluidos). El valor predeterminado es 192.000,que representa un día en tiempo real, suponiendo que se usen 8 nodos.
  • PROJECT_NUMBER: el número de proyecto generado automáticamente de tu proyecto

Método HTTP y URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Cuerpo JSON de la solicitud:

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "false",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

Para enviar tu solicitud, elige una de estas opciones:

curl

Guarda el cuerpo de la solicitud en un archivo llamado request.json y ejecuta el siguiente comando:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Guarda el cuerpo de la solicitud en un archivo llamado request.json y ejecuta el siguiente comando:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

La respuesta contiene información sobre las especificaciones, así como el elemento TRAININGPIPELINE_ID.

Java

Antes de probar este ejemplo, sigue las Java instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Java de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(false)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Antes de probar este ejemplo, sigue las Node.js instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Node.js de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME;
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: true,
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

Para saber cómo instalar o actualizar el SDK de Vertex AI para Python, consulta Instalar el SDK de Vertex AI para Python. Para obtener más información, consulta la documentación de referencia de la API Python.

def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Clasificación

Selecciona la pestaña correspondiente a tu idioma o entorno:

REST

Antes de usar los datos de la solicitud, haz las siguientes sustituciones:

  • LOCATION: región en la que se encuentra el conjunto de datos y se crea el modelo. Por ejemplo, us-central1.
  • PROJECT: .
  • TRAININGPIPELINE_DISPLAYNAME: obligatorio. Nombre visible del recurso TrainingPipeline.
  • DATASET_ID: número de ID del conjunto de datos que se va a usar para el entrenamiento.
  • fractionSplit: opcional. Una de las varias opciones de uso de aprendizaje automático splitposibles para tus datos. En fractionSplit, los valores deben sumar 1. Por ejemplo:
    • {"trainingFraction": "0.7","validationFraction": "0.15","testFraction": "0.15"}
  • MODEL_DISPLAYNAME*: nombre visible del modelo subido (creado) por TrainingPipeline.
  • MODEL_DESCRIPTION*: descripción del modelo.
  • modelToUpload.labels*: cualquier conjunto de pares clave-valor para organizar tus modelos. Por ejemplo:
    • "env": "prod"
    • "tier": "backend"
  • MODELTYPE: Tipo de modelo alojado en Cloud que se va a entrenar. Las opciones son:
    • CLOUD (predeterminado)
  • NODE_HOUR_BUDGET: El coste de formación real será igual o inferior a este valor. En el caso de los modelos de Cloud, el presupuesto debe ser de entre 8000 y 800.000 horas de nodo (ambos incluidos). El valor predeterminado es 192.000,que representa un día en tiempo real, suponiendo que se usen 8 nodos.
  • PROJECT_NUMBER: el número de proyecto generado automáticamente de tu proyecto

Método HTTP y URL:

POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines

Cuerpo JSON de la solicitud:

{
  "displayName": "TRAININGPIPELINE_DISPLAYNAME",
  "inputDataConfig": {
    "datasetId": "DATASET_ID",
    "fractionSplit": {
      "trainingFraction": "DECIMAL",
      "validationFraction": "DECIMAL",
      "testFraction": "DECIMAL"
    }
  },
  "modelToUpload": {
    "displayName": "MODEL_DISPLAYNAME",
    "description": "MODEL_DESCRIPTION",
    "labels": {
      "KEY": "VALUE"
    }
  },
  "trainingTaskDefinition": "gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml",
  "trainingTaskInputs": {
    "multiLabel": "true",
    "modelType": ["MODELTYPE"],
    "budgetMilliNodeHours": NODE_HOUR_BUDGET
  }
}

Para enviar tu solicitud, elige una de estas opciones:

curl

Guarda el cuerpo de la solicitud en un archivo llamado request.json y ejecuta el siguiente comando:

curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
-d @request.json \
"https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines"

PowerShell

Guarda el cuerpo de la solicitud en un archivo llamado request.json y ejecuta el siguiente comando:

$cred = gcloud auth print-access-token
$headers = @{ "Authorization" = "Bearer $cred" }

Invoke-WebRequest `
-Method POST `
-Headers $headers `
-ContentType: "application/json; charset=utf-8" `
-InFile request.json `
-Uri "https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT/locations/LOCATION/trainingPipelines" | Select-Object -Expand Content

La respuesta contiene información sobre las especificaciones, así como el elemento TRAININGPIPELINE_ID.

Java

Antes de probar este ejemplo, sigue las Java instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Java de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.DeployedModelRef;
import com.google.cloud.aiplatform.v1.EnvVar;
import com.google.cloud.aiplatform.v1.FilterSplit;
import com.google.cloud.aiplatform.v1.FractionSplit;
import com.google.cloud.aiplatform.v1.InputDataConfig;
import com.google.cloud.aiplatform.v1.LocationName;
import com.google.cloud.aiplatform.v1.Model;
import com.google.cloud.aiplatform.v1.Model.ExportFormat;
import com.google.cloud.aiplatform.v1.ModelContainerSpec;
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
import com.google.cloud.aiplatform.v1.Port;
import com.google.cloud.aiplatform.v1.PredefinedSplit;
import com.google.cloud.aiplatform.v1.PredictSchemata;
import com.google.cloud.aiplatform.v1.TimestampSplit;
import com.google.cloud.aiplatform.v1.TrainingPipeline;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
import com.google.cloud.aiplatform.v1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
import com.google.rpc.Status;
import java.io.IOException;

public class CreateTrainingPipelineImageClassificationSample {

  public static void main(String[] args) throws IOException {
    // TODO(developer): Replace these variables before running the sample.
    String trainingPipelineDisplayName = "YOUR_TRAINING_PIPELINE_DISPLAY_NAME";
    String project = "YOUR_PROJECT_ID";
    String datasetId = "YOUR_DATASET_ID";
    String modelDisplayName = "YOUR_MODEL_DISPLAY_NAME";
    createTrainingPipelineImageClassificationSample(
        project, trainingPipelineDisplayName, datasetId, modelDisplayName);
  }

  static void createTrainingPipelineImageClassificationSample(
      String project, String trainingPipelineDisplayName, String datasetId, String modelDisplayName)
      throws IOException {
    PipelineServiceSettings pipelineServiceSettings =
        PipelineServiceSettings.newBuilder()
            .setEndpoint("us-central1-aiplatform.googleapis.com:443")
            .build();

    // Initialize client that will be used to send requests. This client only needs to be created
    // once, and can be reused for multiple requests. After completing all of your requests, call
    // the "close" method on the client to safely clean up any remaining background resources.
    try (PipelineServiceClient pipelineServiceClient =
        PipelineServiceClient.create(pipelineServiceSettings)) {
      String location = "us-central1";
      String trainingTaskDefinition =
          "gs://google-cloud-aiplatform/schema/trainingjob/definition/"
              + "automl_image_classification_1.0.0.yaml";
      LocationName locationName = LocationName.of(project, location);

      AutoMlImageClassificationInputs autoMlImageClassificationInputs =
          AutoMlImageClassificationInputs.newBuilder()
              .setModelType(ModelType.CLOUD)
              .setMultiLabel(false)
              .setBudgetMilliNodeHours(8000)
              .setDisableEarlyStopping(false)
              .build();

      InputDataConfig trainingInputDataConfig =
          InputDataConfig.newBuilder().setDatasetId(datasetId).build();
      Model model = Model.newBuilder().setDisplayName(modelDisplayName).build();
      TrainingPipeline trainingPipeline =
          TrainingPipeline.newBuilder()
              .setDisplayName(trainingPipelineDisplayName)
              .setTrainingTaskDefinition(trainingTaskDefinition)
              .setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
              .setInputDataConfig(trainingInputDataConfig)
              .setModelToUpload(model)
              .build();

      TrainingPipeline trainingPipelineResponse =
          pipelineServiceClient.createTrainingPipeline(locationName, trainingPipeline);

      System.out.println("Create Training Pipeline Image Classification Response");
      System.out.format("Name: %s\n", trainingPipelineResponse.getName());
      System.out.format("Display Name: %s\n", trainingPipelineResponse.getDisplayName());

      System.out.format(
          "Training Task Definition %s\n", trainingPipelineResponse.getTrainingTaskDefinition());
      System.out.format(
          "Training Task Inputs: %s\n", trainingPipelineResponse.getTrainingTaskInputs());
      System.out.format(
          "Training Task Metadata: %s\n", trainingPipelineResponse.getTrainingTaskMetadata());
      System.out.format("State: %s\n", trainingPipelineResponse.getState());

      System.out.format("Create Time: %s\n", trainingPipelineResponse.getCreateTime());
      System.out.format("StartTime %s\n", trainingPipelineResponse.getStartTime());
      System.out.format("End Time: %s\n", trainingPipelineResponse.getEndTime());
      System.out.format("Update Time: %s\n", trainingPipelineResponse.getUpdateTime());
      System.out.format("Labels: %s\n", trainingPipelineResponse.getLabelsMap());

      InputDataConfig inputDataConfig = trainingPipelineResponse.getInputDataConfig();
      System.out.println("Input Data Config");
      System.out.format("Dataset Id: %s", inputDataConfig.getDatasetId());
      System.out.format("Annotations Filter: %s\n", inputDataConfig.getAnnotationsFilter());

      FractionSplit fractionSplit = inputDataConfig.getFractionSplit();
      System.out.println("Fraction Split");
      System.out.format("Training Fraction: %s\n", fractionSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", fractionSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", fractionSplit.getTestFraction());

      FilterSplit filterSplit = inputDataConfig.getFilterSplit();
      System.out.println("Filter Split");
      System.out.format("Training Filter: %s\n", filterSplit.getTrainingFilter());
      System.out.format("Validation Filter: %s\n", filterSplit.getValidationFilter());
      System.out.format("Test Filter: %s\n", filterSplit.getTestFilter());

      PredefinedSplit predefinedSplit = inputDataConfig.getPredefinedSplit();
      System.out.println("Predefined Split");
      System.out.format("Key: %s\n", predefinedSplit.getKey());

      TimestampSplit timestampSplit = inputDataConfig.getTimestampSplit();
      System.out.println("Timestamp Split");
      System.out.format("Training Fraction: %s\n", timestampSplit.getTrainingFraction());
      System.out.format("Validation Fraction: %s\n", timestampSplit.getValidationFraction());
      System.out.format("Test Fraction: %s\n", timestampSplit.getTestFraction());
      System.out.format("Key: %s\n", timestampSplit.getKey());

      Model modelResponse = trainingPipelineResponse.getModelToUpload();
      System.out.println("Model To Upload");
      System.out.format("Name: %s\n", modelResponse.getName());
      System.out.format("Display Name: %s\n", modelResponse.getDisplayName());
      System.out.format("Description: %s\n", modelResponse.getDescription());

      System.out.format("Metadata Schema Uri: %s\n", modelResponse.getMetadataSchemaUri());
      System.out.format("Metadata: %s\n", modelResponse.getMetadata());
      System.out.format("Training Pipeline: %s\n", modelResponse.getTrainingPipeline());
      System.out.format("Artifact Uri: %s\n", modelResponse.getArtifactUri());

      System.out.format(
          "Supported Deployment Resources Types: %s\n",
          modelResponse.getSupportedDeploymentResourcesTypesList());
      System.out.format(
          "Supported Input Storage Formats: %s\n",
          modelResponse.getSupportedInputStorageFormatsList());
      System.out.format(
          "Supported Output Storage Formats: %s\n",
          modelResponse.getSupportedOutputStorageFormatsList());

      System.out.format("Create Time: %s\n", modelResponse.getCreateTime());
      System.out.format("Update Time: %s\n", modelResponse.getUpdateTime());
      System.out.format("Labels: %sn\n", modelResponse.getLabelsMap());

      PredictSchemata predictSchemata = modelResponse.getPredictSchemata();
      System.out.println("Predict Schemata");
      System.out.format("Instance Schema Uri: %s\n", predictSchemata.getInstanceSchemaUri());
      System.out.format("Parameters Schema Uri: %s\n", predictSchemata.getParametersSchemaUri());
      System.out.format("Prediction Schema Uri: %s\n", predictSchemata.getPredictionSchemaUri());

      for (ExportFormat exportFormat : modelResponse.getSupportedExportFormatsList()) {
        System.out.println("Supported Export Format");
        System.out.format("Id: %s\n", exportFormat.getId());
      }

      ModelContainerSpec modelContainerSpec = modelResponse.getContainerSpec();
      System.out.println("Container Spec");
      System.out.format("Image Uri: %s\n", modelContainerSpec.getImageUri());
      System.out.format("Command: %s\n", modelContainerSpec.getCommandList());
      System.out.format("Args: %s\n", modelContainerSpec.getArgsList());
      System.out.format("Predict Route: %s\n", modelContainerSpec.getPredictRoute());
      System.out.format("Health Route: %s\n", modelContainerSpec.getHealthRoute());

      for (EnvVar envVar : modelContainerSpec.getEnvList()) {
        System.out.println("Env");
        System.out.format("Name: %s\n", envVar.getName());
        System.out.format("Value: %s\n", envVar.getValue());
      }

      for (Port port : modelContainerSpec.getPortsList()) {
        System.out.println("Port");
        System.out.format("Container Port: %s\n", port.getContainerPort());
      }

      for (DeployedModelRef deployedModelRef : modelResponse.getDeployedModelsList()) {
        System.out.println("Deployed Model");
        System.out.format("Endpoint: %s\n", deployedModelRef.getEndpoint());
        System.out.format("Deployed Model Id: %s\n", deployedModelRef.getDeployedModelId());
      }

      Status status = trainingPipelineResponse.getError();
      System.out.println("Error");
      System.out.format("Code: %s\n", status.getCode());
      System.out.format("Message: %s\n", status.getMessage());
    }
  }
}

Node.js

Antes de probar este ejemplo, sigue las Node.js instrucciones de configuración de la guía de inicio rápido de Vertex AI con bibliotecas de cliente. Para obtener más información, consulta la documentación de referencia de la API Node.js de Vertex AI.

Para autenticarte en Vertex AI, configura las credenciales predeterminadas de la aplicación. Para obtener más información, consulta el artículo Configurar la autenticación en un entorno de desarrollo local.

/**
 * TODO(developer): Uncomment these variables before running the sample.
 * (Not necessary if passing values as arguments)
 */
/*
const datasetId = 'YOUR DATASET';
const modelDisplayName = 'NEW MODEL NAME;
const trainingPipelineDisplayName = 'NAME FOR TRAINING PIPELINE';
const project = 'YOUR PROJECT ID';
const location = 'us-central1';
  */
// Imports the Google Cloud Pipeline Service Client library
const aiplatform = require('@google-cloud/aiplatform');

const {definition} =
  aiplatform.protos.google.cloud.aiplatform.v1.schema.trainingjob;
const ModelType = definition.AutoMlImageClassificationInputs.ModelType;

// Specifies the location of the api endpoint
const clientOptions = {
  apiEndpoint: 'us-central1-aiplatform.googleapis.com',
};

// Instantiates a client
const {PipelineServiceClient} = aiplatform.v1;
const pipelineServiceClient = new PipelineServiceClient(clientOptions);

async function createTrainingPipelineImageClassification() {
  // Configure the parent resource
  const parent = `projects/${project}/locations/${location}`;

  // Values should match the input expected by your model.
  const trainingTaskInputsMessage =
    new definition.AutoMlImageClassificationInputs({
      multiLabel: true,
      modelType: ModelType.CLOUD,
      budgetMilliNodeHours: 8000,
      disableEarlyStopping: false,
    });

  const trainingTaskInputs = trainingTaskInputsMessage.toValue();

  const trainingTaskDefinition =
    'gs://google-cloud-aiplatform/schema/trainingjob/definition/automl_image_classification_1.0.0.yaml';

  const modelToUpload = {displayName: modelDisplayName};
  const inputDataConfig = {datasetId};
  const trainingPipeline = {
    displayName: trainingPipelineDisplayName,
    trainingTaskDefinition,
    trainingTaskInputs,
    inputDataConfig,
    modelToUpload,
  };
  const request = {parent, trainingPipeline};

  // Create training pipeline request
  const [response] =
    await pipelineServiceClient.createTrainingPipeline(request);

  console.log('Create training pipeline image classification response');
  console.log(`Name : ${response.name}`);
  console.log('Raw response:');
  console.log(JSON.stringify(response, null, 2));
}

createTrainingPipelineImageClassification();

Python

Para saber cómo instalar o actualizar el SDK de Vertex AI para Python, consulta Instalar el SDK de Vertex AI para Python. Para obtener más información, consulta la documentación de referencia de la API Python.

def create_training_pipeline_image_classification_sample(
    project: str,
    location: str,
    display_name: str,
    dataset_id: str,
    model_display_name: Optional[str] = None,
    model_type: str = "CLOUD",
    multi_label: bool = False,
    training_fraction_split: float = 0.8,
    validation_fraction_split: float = 0.1,
    test_fraction_split: float = 0.1,
    budget_milli_node_hours: int = 8000,
    disable_early_stopping: bool = False,
    sync: bool = True,
):
    aiplatform.init(project=project, location=location)

    job = aiplatform.AutoMLImageTrainingJob(
        display_name=display_name,
        model_type=model_type,
        prediction_type="classification",
        multi_label=multi_label,
    )

    my_image_ds = aiplatform.ImageDataset(dataset_id)

    model = job.run(
        dataset=my_image_ds,
        model_display_name=model_display_name,
        training_fraction_split=training_fraction_split,
        validation_fraction_split=validation_fraction_split,
        test_fraction_split=test_fraction_split,
        budget_milli_node_hours=budget_milli_node_hours,
        disable_early_stopping=disable_early_stopping,
        sync=sync,
    )

    model.wait()

    print(model.display_name)
    print(model.resource_name)
    print(model.uri)
    return model

Controlar la división de datos mediante REST

Puedes controlar cómo se divide tu conjunto de datos de entrenamiento entre los conjuntos de entrenamiento, validación y prueba. Cuando uses la API de Vertex AI, usa el objeto Split para determinar la división de tus datos. El objeto Split se puede incluir en el objeto InputConfig como uno de los varios tipos de objetos, cada uno de los cuales proporciona una forma diferente de dividir los datos de entrenamiento. Solo puedes seleccionar un método.

  • FractionSplit:
    • TRAINING_FRACTION: la fracción de los datos de entrenamiento que se va a usar en el conjunto de entrenamiento.
    • VALIDATION_FRACTION: la fracción de los datos de entrenamiento que se va a usar en el conjunto de validación. No se usa para los datos de vídeo.
    • TEST_FRACTION: la fracción de los datos de entrenamiento que se va a usar en el conjunto de pruebas.

    Si se especifica alguna de las fracciones, deben especificarse todas. La suma de las fracciones debe ser 1,0. Los valores predeterminados de las fracciones varían en función del tipo de datos. Más información

    "fractionSplit": {
      "trainingFraction": TRAINING_FRACTION,
      "validationFraction": VALIDATION_FRACTION,
      "testFraction": TEST_FRACTION
    },
    
  • FilterSplit:
    • TRAINING_FILTER: los elementos de datos que coincidan con este filtro se usarán en el conjunto de entrenamiento.
    • VALIDATION_FILTER: los elementos de datos que coinciden con este filtro se usan en el conjunto de validación. Debe ser "-" para los datos de vídeo.
    • TEST_FILTER: los elementos de datos que coinciden con este filtro se usan en el conjunto de prueba.

    Estos filtros se pueden usar con la etiqueta ml_use o con cualquier etiqueta que aplique a sus datos. Consulte más información sobre cómo usar la etiqueta ml-use y otras etiquetas para filtrar sus datos.

    En el siguiente ejemplo se muestra cómo usar el objeto filterSplit con la etiqueta ml_use y el conjunto de validación incluido:

    "filterSplit": {
    "trainingFilter": "labels.aiplatform.googleapis.com/ml_use=training",
    "validationFilter": "labels.aiplatform.googleapis.com/ml_use=validation",
    "testFilter": "labels.aiplatform.googleapis.com/ml_use=test"
    }