获取文本嵌入

本页介绍了如何使用文本嵌入 API 创建文本嵌入。

Vertex AI 支持通过文本嵌入 API 在 Google Distributed Cloud (GDC) 气隙环境中进行文本嵌入。文本嵌入使用向量表示法。

文本嵌入可将以任何支持的语言编写的文本数据转换为数值向量。这些向量表示旨在捕获它们所表示字词的语义含义和上下文。文本嵌入模型可以为各种任务类型(例如文档检索、问答、分类和文本的事实验证)生成经过优化的嵌入。

如需详细了解文本嵌入所使用的关键概念,请参阅以下文档:

准备工作

在 GDC 项目中使用文本嵌入之前,请按以下步骤操作:

  1. 为 Vertex AI 设置项目
  2. 根据语言任务类型选择一个可用于文本嵌入的模型
  3. 启用 Text Embedding API 或 Text Embedding Multilingual API,具体取决于您要使用的模型。
  4. 向用户或服务账号授予对 Text Embedding 或 Text Embedding Multilingual 的相应访问权限。如需了解详情,请参阅以下文档:

  5. 安装 Vertex AI 客户端库

  6. 获取身份验证令牌

您必须为模型请求、服务账号和 IAM 角色绑定使用同一项目。

获取文本片段的文本嵌入

满足前提条件后,您可以使用 Text Embedding 或 Text Embedding Multilingual 模型,通过 API 或 Python 版 SDK 获取文本片段的文本嵌入。

以下示例使用 text-embedding-004 模型。

向文本嵌入 API 发出 REST 请求。否则,请通过 Python 脚本与模型互动,以获取文本嵌入。

REST

如需获取文本嵌入,请通过指定模型端点来发送 POST 请求。

如需提出要求,请按以下步骤操作:

  1. 将请求内容保存在名为 request.json 的 JSON 文件中。该文件必须类似于以下示例:

    {
      "instances": [
        {
          "content": "What is life?",
          "task_type": "",
          "title": ""
        }
      ]
    }
    
  2. 使用 curl 工具发出请求:

    curl -X POST \
    -H "Authorization: Bearer TOKEN"\
    -H "Content-Type: application/json; charset=utf-8" \
    -d @request.json \
    "https://ENDPOINT:443/v1/projects/PROJECT/locations/PROJECT/endpoints/MODEL:predict"
    

    替换以下内容:

    • TOKEN:您获得的身份验证令牌
    • ENDPOINT:贵组织使用的 Text Embedding 或 Text Embedding Multilingual 端点。如需了解详情,请查看服务状态和端点
    • PROJECT:您的项目名称。
    • MODEL:您要使用的模型。可用的值如下:

      • endpoint-text-embedding(适用于文本嵌入模型)。
      • endpoint-text-embedding-multilingual(适用于 Text Embedding Multilingual 模型)。

您必须获得类似以下内容的 JSON 响应:

{"predictions":[[-0.00668720435,3.20804138e-05,-0.0281705819,-0.00954890903,-0.0818724185,0.0150693133,-0.00677698106, …. ,0.0167487375,-0.0534791686,0.00208711182,0.032938987,-0.01491543]],"deployedModelId":"text-embedding","model":"models/text-embedding/1","modelDisplayName":"text-embedding","modelVersionId":"1"}

Python

请按照以下步骤从 Python 脚本获取文本嵌入:

  1. 安装 Vertex AI Platform 客户端库

  2. 将请求内容保存在名为 request.json 的 JSON 文件中。该文件必须类似于以下示例:

    {
      "instances": [
        {
          "content": "What is life?",
          "task_type": "",
          "title": ""
        }
      ]
    }
    
  3. 安装所需的 Python 库:

    pip install absl-py
    
  4. 创建一个名为 client.py 的 Python 文件。文件必须如下例所示:

    import json
    import os
    from typing import Sequence
    
    import grpc
    from absl import app
    from absl import flags
    
    from google.protobuf import json_format
    from google.protobuf.struct_pb2 import Value
    from google.cloud.aiplatform_v1.services import prediction_service
    
    _INPUT = flags.DEFINE_string("input", None, "input", required=True)
    _HOST = flags.DEFINE_string("host", None, "Prediction endpoint", required=True)
    _ENDPOINT_ID = flags.DEFINE_string("endpoint_id", None, "endpoint id", required=True)
    _TOKEN = flags.DEFINE_string("token", None, "STS token", required=True)
    
    # ENDPOINT_RESOURCE_NAME is a placeholder value that doesn't affect prediction behavior.
    ENDPOINT_RESOURCE_NAME="projects/PROJECT/locations/PROJECT/endpoints/MODEL"
    
    os.environ["GRPC_DEFAULT_SSL_ROOTS_FILE_PATH"] = CERT_NAME
    
    # predict_client_secure builds a client that requires TLS
    def predict_client_secure(host):
      with open(os.environ["GRPC_DEFAULT_SSL_ROOTS_FILE_PATH"], 'rb') as f:
          creds = grpc.ssl_channel_credentials(f.read())
    
      channel_opts = ()
      channel_opts += (('grpc.ssl_target_name_override', host),)
      client = prediction_service.PredictionServiceClient(
          transport=prediction_service.transports.grpc.PredictionServiceGrpcTransport(
              channel=grpc.secure_channel(target=host+":443", credentials=creds, options=channel_opts)))
      return client
    
    def predict_func(client, instances, token):
      resp = client.predict(
          endpoint=ENDPOINT_RESOURCE_NAME,
          instances=instances,
          metadata=[ ("x-vertex-ai-endpoint-id", _ENDPOINT_ID.value), ("authorization", "Bearer " + token),])
      print(resp)
    
    def main(argv: Sequence[str]):
      del argv  # Unused.
      with open(_INPUT.value) as json_file:
          data = json.load(json_file)
          instances = [json_format.ParseDict(s, Value()) for s in data["instances"]]
    
      client = predict_client_secure(_HOST.value,)
    
      predict_func(client=client, instances=instances, token=_TOKEN.value)
    
    if __name__=="__main__":
      app.run(main)
    

    替换以下内容:

    • PROJECT:您的项目名称。
    • MODEL:您要使用的模型。以下是可用的值:
      • endpoint-text-embedding(适用于文本嵌入模型)。
      • endpoint-text-embedding-multilingual(适用于 Text Embedding Multilingual 模型)。
    • CERT_NAME:证书授权机构 (CA) 证书文件的名称,例如 org-1-trust-bundle-ca.cert。只有在开发环境中,您才需要此值。否则,请省略此字段。
  5. 发送请求:

    python client.py --token=TOKEN --host=ENDPOINT --input=request.json --endpoint_id=MODEL
    

    替换以下内容:

    • TOKEN:您获得的身份验证令牌
    • ENDPOINT:贵组织使用的 Text Embedding 或 Text Embedding Multilingual 端点。如需了解详情,请查看服务状态和端点
    • MODEL:您要使用的模型。可用的值如下:

      • endpoint-text-embedding(适用于文本嵌入模型)。
      • endpoint-text-embedding-multilingual(适用于 Text Embedding Multilingual 模型)。

您必须获得类似以下内容的 JSON 响应:

{"predictions":[[-0.00668720435,3.20804138e-05,-0.0281705819,-0.00954890903,-0.0818724185,0.0150693133,-0.00677698106, …. ,0.0167487375,-0.0534791686,0.00208711182,0.032938987,-0.01491543]],"deployedModelId":"text-embedding","model":"models/text-embedding/1","modelDisplayName":"text-embedding","modelVersionId":"1"}