使用 TabNet 訓練模型

本頁面說明如何使用 TabNet 適用的 Tabular Workflow,從表格資料集訓練分類或迴歸模型。

TabNet 適用的 Tabular Workflow 有兩個版本:

  • HyperparameterTuningJob 會找出最適合用於模型訓練的超參數值組合。
  • CustomJob 可讓您指定用於模型訓練的超參數值。如果您確切知道需要哪些超參數值,可以指定這些值,不必搜尋,節省訓練資源。

如要瞭解這個工作流程使用的服務帳戶,請參閱表格工作流程的服務帳戶

Workflow API

這個工作流程會使用下列 API:

  • Vertex AI
  • Dataflow
  • Compute Engine
  • Cloud Storage

使用 HyperparameterTuningJob 訓練模型

下列程式碼範例示範如何執行 HyperparameterTuningJob 管道:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

您可以在 pipeline_job.run() 中使用選用的 service_account 參數,將 Vertex AI Pipelines 服務帳戶設為您選擇的帳戶。

下列函式會定義管道和參數值。 訓練資料可以是 Cloud Storage 中的 CSV 檔案,也可以是 BigQuery 中的資料表。

template_path, parameter_values =  automl_tabular_utils.get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters(...)

以下是一部分的 get_tabnet_hyperparameter_tuning_job_pipeline_and_parameters 參數:

參數名稱 類型 定義
data_source_csv_filenames 字串 儲存在 Cloud Storage 中的 CSV 檔案 URI。
data_source_bigquery_table_path 字串 BigQuery 資料表的 URI。
dataflow_service_account 字串 (選用) 用於執行 Dataflow 工作的自訂服務帳戶。您可以將 Dataflow 工作設為使用私有 IP 位址和特定虛擬私有雲子網路。這個參數會覆寫預設的 Dataflow 工作站服務帳戶。
study_spec_parameters_override List[Dict[String, Any]] (選用) 覆寫超參數調整。這個參數可以空白,也可以包含一或多個可能的超參數。如果未設定超參數值,Vertex AI 會使用超參數的預設調整範圍。

如要使用 study_spec_parameters_override 參數設定超參數,可以使用 Vertex AI 的輔助函式 get_tabnet_study_spec_parameters_override。這個函式具有下列輸入內容:

  • dataset_size_bucket:資料集大小的 bucket
    • 'small': < 100 萬列
    • 「medium」:100 萬至 1 億列
    • 「large」:> 1 億列
  • training_budget_bucket:訓練預算的儲存空間
    • 'small': < $600
    • 'medium': $600 - $2400
    • 'large': > $2400
  • prediction_type:所需的推論類型

get_tabnet_study_spec_parameters_override 函式會傳回超參數和範圍的清單。

以下範例說明如何使用 get_tabnet_study_spec_parameters_override 函式:

study_spec_parameters_override = automl_tabular_utils.get_tabnet_study_spec_parameters_override(
    dataset_size_bucket="small",
    prediction_type="classification",
    training_budget_bucket="small",
)

使用 CustomJob 訓練模型

下列程式碼範例說明如何執行 CustomJob 管道:

pipeline_job = aiplatform.PipelineJob(
    ...
    template_path=template_path,
    parameter_values=parameter_values,
    ...
)
pipeline_job.run(service_account=SERVICE_ACCOUNT)

您可以在 pipeline_job.run() 中使用選用的 service_account 參數,將 Vertex AI Pipelines 服務帳戶設為您選擇的帳戶。

下列函式會定義管道和參數值。 訓練資料可以是 Cloud Storage 中的 CSV 檔案,也可以是 BigQuery 中的資料表。

template_path, parameter_values = automl_tabular_utils.get_tabnet_trainer_pipeline_and_parameters(...)

以下是一部分的 get_tabnet_trainer_pipeline_and_parameters 參數:

參數名稱 類型 定義
data_source_csv_filenames 字串 儲存在 Cloud Storage 中的 CSV 檔案 URI。
data_source_bigquery_table_path 字串 BigQuery 資料表的 URI。
dataflow_service_account 字串 (選用) 用於執行 Dataflow 工作的自訂服務帳戶。您可以將 Dataflow 工作設為使用私有 IP 位址和特定虛擬私有雲子網路。這個參數會覆寫預設的 Dataflow 工作站服務帳戶。

後續步驟

準備好使用分類或迴歸模型進行推論時,有兩種做法: