使用廣度和深度學習訓練模型

本頁面說明如何使用廣度和深度學習適用的 Tabular Workflow,從表格資料集訓練分類或迴歸模型。

廣度和深度學習適用的 Tabular Workflow 有兩個版本:

  • HyperparameterTuningJob 會找出最適合用於模型訓練的超參數值組合。
  • CustomJob 可讓您指定用於模型訓練的超參數值。如果您確切知道需要哪些超參數值,可以指定這些值,不必搜尋,節省訓練資源。

如要瞭解這個工作流程使用的服務帳戶,請參閱表格工作流程的服務帳戶

Workflow API

這個工作流程會使用下列 API:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

使用 HyperparameterTuningJob 訓練模型

下列程式碼範例示範如何執行 HyperparameterTuningJob 管道:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

您可以在 pipeline_job.run() 中使用選用的 service_account 參數,將 Vertex AI Pipelines 服務帳戶設為您選擇的帳戶。

下列函式會定義管道和參數值。 訓練資料可以是 Cloud Storage 中的 CSV 檔案,也可以是 BigQuery 中的資料表。

template_path, parameter_values =  automl_tabular_utils.get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters(...)

以下是一部分的 get_wide_and_deep_hyperparameter_tuning_job_pipeline_and_parameters 參數:

參數名稱 類型 定義
data_source_csv_filenames 字串 儲存在 Cloud Storage 中的 CSV 檔案 URI。
data_source_bigquery_table_path 字串 BigQuery 資料表的 URI。
dataflow_service_account 字串 (選用) 用於執行 Dataflow 工作的自訂服務帳戶。您可以將 Dataflow 工作設為使用私有 IP 位址和特定虛擬私有雲子網路。這個參數會覆寫預設的 Dataflow 工作站服務帳戶。
study_spec_parameters_override List[Dict[String, Any]] (選用) 覆寫超參數調整。這個參數可以空白,也可以包含一或多個可能的超參數。如果未設定超參數值,Vertex AI 會使用超參數的預設調整範圍。

如要使用 study_spec_parameters_override 參數設定超參數,可以使用 Vertex AI 的輔助函式 get_wide_and_deep_study_spec_parameters_override。這個函式會傳回超參數和範圍的清單。

以下範例說明如何使用 get_wide_and_deep_study_spec_parameters_override 函式:

study_spec_parameters_override = automl_tabular_utils.get_wide_and_deep_study_spec_parameters_override()

使用 CustomJob 訓練模型

下列程式碼範例說明如何執行 CustomJob 管道:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

您可以在 pipeline_job.run() 中使用選用的 service_account 參數,將 Vertex AI Pipelines 服務帳戶設為您選擇的帳戶。

下列函式會定義管道和參數值。 訓練資料可以是 Cloud Storage 中的 CSV 檔案,也可以是 BigQuery 中的資料表。

template_path, parameter_values = automl_tabular_utils.get_wide_and_deep_trainer_pipeline_and_parameters(...)

以下是一部分的 get_wide_and_deep_trainer_pipeline_and_parameters 參數:

參數名稱 類型 定義
data_source_csv_filenames 字串 儲存在 Cloud Storage 中的 CSV 檔案 URI。
data_source_bigquery_table_path 字串 BigQuery 資料表的 URI。
dataflow_service_account 字串 (選用) 用於執行 Dataflow 工作的自訂服務帳戶。您可以將 Dataflow 工作設為使用私有 IP 位址和特定虛擬私有雲子網路。這個參數會覆寫預設的 Dataflow 工作站服務帳戶。

後續步驟

準備好使用分類或迴歸模型進行推論時,有兩種做法: