Python接入指南

更新时间:

操作步骤

步骤一:安装依赖

pip install requests
pip install dataclass_wizard

步骤二:增加Client

新建 client.py。如需将其放入其他包(package)中,请按需修改包名,并同步修改后续代码中的 import 语句。

import time
import uuid
import hmac
import hashlib
import base64
import json
import io
import requests

class Client:
    """Signing HTTP client for an API gateway that uses digest authentication.

    Every request carries ``x-ca-*`` headers plus an HMAC-SHA256
    ``x-ca-signature`` computed with ``app_secret`` (see ``_generate_header``).
    """

    def __init__(self, endpoint: str, app_key: str, app_secret: str):
        """
        :param endpoint: gateway host name without scheme; requests go to
            ``https://{endpoint}{path}``.
        :param app_key: application key, sent in the ``x-ca-key`` header.
        :param app_secret: application secret, used only as the HMAC key (never sent).
        """
        self.endpoint = endpoint
        self.app_key = app_key
        self.app_secret = app_secret

    def invoke(self, path: str, params: dict = None, method='POST', headers: dict = None, **kwargs):
        """Send a signed request to ``https://{endpoint}{path}``.

        :param path: request path, starting with ``/``.
        :param params: for GET, sent as the query string (and signed as such);
            for any other method, serialized as the JSON request body.
        :param method: ``'GET'`` (case-insensitive) issues a GET; any other value issues a POST.
        :param headers: extra headers merged over the generated ones after signing.
        :param kwargs: forwarded to ``requests.get`` / ``requests.post`` (e.g. ``timeout``).
        :return: the ``requests.Response``.
        """
        url = f'https://{self.endpoint}{path}'
        if str(method).upper() == 'GET':
            # GET has no body: the parameters travel in the query string, so they
            # must be signed as part of the path rather than as a Content-MD5.
            gen_headers = self._generate_header('GET', path, None, headers, query=params)
            return requests.get(url, headers=gen_headers, params=params, **kwargs)

        gen_headers = self._generate_header('POST', path, params, headers)
        # Send exactly the bytes whose MD5 was signed instead of letting requests
        # re-serialize ``params`` with its own (possibly different) JSON encoder.
        payload = None if params is None else self._serialize_body(params)
        return requests.post(url, headers=gen_headers, data=payload, **kwargs)

    @staticmethod
    def _serialize_body(body) -> bytes:
        """Serialize a JSON body; the same bytes are both MD5-signed and sent."""
        return json.dumps(body).encode('utf-8')

    @staticmethod
    def _path_and_params(path: str, query: dict = None) -> str:
        """Build the PathAndParameters line of the string to sign.

        Query keys are sorted lexicographically. A key whose value is ``''`` is
        emitted without ``=``. Keys whose value is ``None`` are skipped, because
        ``requests`` drops them from the URL as well.
        """
        if not query:
            return path
        parts = []
        for key in sorted(query):
            value = query[key]
            if value is None:
                continue
            parts.append(key if value == '' else f'{key}={value}')
        return f'{path}?' + '&'.join(parts) if parts else path

    def _generate_header(self, http_method: str, path: str, body: dict = None, hdrs: dict = None,
                         query: dict = None):
        """Build the signed header set for one request.

        Reference: https://help.aliyun.com/zh/api-gateway/traditional-api-gateway/use-digest-authentication-to-call-an-api

        :param http_method: HTTP verb, placed on the first line of the string to sign.
        :param path: request path without query string.
        :param body: JSON body; when truthy, the base64 MD5 of its serialization is
            sent as ``content-md5`` and included in the signature.
        :param hdrs: extra headers merged over the result after signing (not signed).
        :param query: query-string parameters, included (sorted) in the signed path.
        :return: header dict including ``x-ca-signature``.
        """
        timestamp = time.time()
        date_str = time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime(timestamp)).replace('GMT', 'GMT+00:00')
        timestamp_str = str(int(timestamp * 1000))
        uuid_str = str(uuid.uuid4())
        json_header = 'application/json; charset=utf-8'

        headers = {
            'date': date_str,
            'x-ca-key': self.app_key,
            'x-ca-timestamp': timestamp_str,
            'x-ca-nonce': uuid_str,
            'x-ca-signature-method': 'HmacSHA256',
            'x-ca-signature-headers': 'x-ca-timestamp,x-ca-key,x-ca-nonce,x-ca-signature-method',
            'Content-Type': json_header,
            'Accept': json_header
        }

        content_md5 = ''
        if body:
            content_md5 = base64.b64encode(hashlib.md5(self._serialize_body(body)).digest()).decode('utf-8')
            headers['content-md5'] = content_md5

        # Layout: method, Accept, Content-MD5, Content-Type, Date, then the signed
        # x-ca-* headers in lexicographic key order, then path (+ sorted query).
        string_to_sign = '\n'.join([
            http_method,
            json_header,
            content_md5,
            json_header,
            date_str,
            f'x-ca-key:{self.app_key}',
            f'x-ca-nonce:{uuid_str}',
            'x-ca-signature-method:HmacSHA256',
            f'x-ca-timestamp:{timestamp_str}',
            self._path_and_params(path, query),
        ])

        digest = hmac.new(self.app_secret.encode('utf-8'), string_to_sign.encode('utf-8'), hashlib.sha256).digest()
        headers['x-ca-signature'] = base64.b64encode(digest).decode('utf-8')

        if hdrs:
            headers.update(hdrs)

        return headers

步骤三:新增 Proto 数据类(以 ComfyUI 生图服务为例)

新建 proto.py(该文件依赖步骤四中的 util.py)

# -*- coding: utf-8 -*-
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

from dataclass_wizard import DumpMeta, JSONWizard

from util import batch_download_images


class PredictResultStatusCode(Enum):
    """Task status strings reported by the ComfyUI service."""

    TASK_INPROGRESS = "running"
    TASK_FAILED = "failed"
    TASK_QUEUE = "waiting"
    TASK_FINISH = "succeeded"

    def finished(self):
        """Return True once the task is in a terminal state (failed or succeeded)."""
        return self is PredictResultStatusCode.TASK_FAILED or self is PredictResultStatusCode.TASK_FINISH

class JSONe(JSONWizard):
    """JSONWizard base class that configures snake_case key dumping on every subclass."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Attach the dump settings to each concrete subclass as it is defined.
        snake_case_dump = DumpMeta(key_transform='SNAKE')
        snake_case_dump.bind_to(cls)


@dataclass
class GatewayResponse(JSONe):
    """Envelope fields shared by the gateway responses; subclasses add a typed ``data`` field."""
    status: Optional[int] = 0  # the caller script treats status == 20 as a failed call
    err_code: Optional[str] = ""  # the caller script raises err_message when this is non-empty
    err_message: Optional[str] = ""
    sub_err_code: Optional[str] = ""
    sub_err_message: Optional[str] = ""
    api_invoke_id: Optional[str] = ""  # presumably a per-call trace id — verify against the API docs


@dataclass
class ComfyRequest(JSONe):
    """Request payload posted to ``/scc/comfy_prompt``."""
    workflow_id: str  # workflow id obtained from the console
    version_id: Optional[str] = None
    # Workflow input parameters. Annotated with ``typing.Any``: the builtin
    # ``any`` is a function, not a type, and breaks type-driven (de)serialization.
    inputs: Optional[Dict[str, Any]] = None
    alias_id: Optional[str] = None  # workflow alias configured in the console



@dataclass
class ComfyResponseData(JSONe):
    """``data`` payload of the ``/scc/comfy_prompt`` response."""
    task_id: str  # id of the submitted task; the caller polls progress/result with it
    status: Optional[PredictResultStatusCode] = PredictResultStatusCode.TASK_INPROGRESS


@dataclass
class ComfyResponse(GatewayResponse):
    """Response of ``/scc/comfy_prompt``: gateway envelope plus the submitted task's data."""
    data: Optional[ComfyResponseData] = None  # None when the response carries no data (presumably on errors)



@dataclass
class PredictResult(JSONe):
    """``data`` payload of the ``/scc/comfy_get_result`` response."""
    task_id: str
    images: Optional[List[str]] = None  # image URLs, fetched by batch_download_images
    info: Optional[Dict[str, str]] = None
    parameters: Optional[Dict[str, str]] = None
    status: Optional[PredictResultStatusCode] = PredictResultStatusCode.TASK_INPROGRESS
    # Filled locally with the downloaded image contents. batch_download_images
    # returns ``response.content``, i.e. raw ``bytes``, not ``str``.
    imgs_bytes: Optional[List[bytes]] = None
    result: Optional[Dict] = None

@dataclass
class PredictResultResponse(GatewayResponse):
    """Response of ``/scc/comfy_get_result``: gateway envelope plus the task result."""
    data: Optional[PredictResult] = None

    def download_images(self):
        """Download every URL in ``data.images`` into ``data.imgs_bytes``.

        Does nothing when the response has no ``data`` (e.g. an error response)
        or no image URLs. Downloads that keep failing are omitted by
        ``batch_download_images``, so ``imgs_bytes`` may be shorter than ``images``.
        """
        if self.data is None or not self.data.images:
            return
        self.data.imgs_bytes = batch_download_images(self.data.images)


@dataclass
class ProgressData(JSONe):
    """``data`` payload of the ``/scc/comfy_get_progress`` response for one task."""
    task_id: str
    progress: float  # NOTE(review): scale (0-1 vs 0-100) not visible here — confirm with the API docs
    eta_relative: int  # presumably the estimated remaining time; unit unverified
    message: Optional[str] = ""
    status: Optional[PredictResultStatusCode] = PredictResultStatusCode.TASK_INPROGRESS  # polled until finished()


@dataclass
class ProgressResponse(GatewayResponse):
    """Response of ``/scc/comfy_get_progress``: gateway envelope plus progress data."""
    data: Optional[ProgressData] = None  # may be None when the response carries no data

步骤四:新增工具函数(以 ComfyUI 生图服务为例)

增加 util.py

from multiprocessing.pool import ThreadPool
import logging
import requests
from dataclass_wizard.utils.string_conv import to_camel_case


# Module-level logger; batch_download_images reports failed downloads through it.
logger = logging.getLogger(__name__)


def batch_download_images(image_links, attempts=3, timeout=100):
    """Download images concurrently and return their raw contents.

    :param image_links: iterable of image URLs.
    :param attempts: how many times each URL is tried before giving up.
    :param timeout: per-request timeout in seconds, passed to ``requests.get``.
    :return: list of ``bytes`` in the order of ``image_links``; URLs that still
        fail after all attempts are omitted.
    """
    image_links = list(image_links)
    if not image_links:
        return []

    def _download(image_link):
        for attempt in range(1, attempts + 1):
            try:
                response = requests.get(image_link, timeout=timeout)
                # Treat HTTP errors (404, 5xx, ...) as failures rather than
                # returning an error page's body as if it were an image.
                response.raise_for_status()
                return response.content
            except Exception:
                # Deliberately best-effort: one bad URL must not abort the batch.
                logger.warning("Failed to download image %s (attempt %d/%d)",
                               image_link, attempt, attempts, exc_info=True)
        return None

    # Threads suit this I/O-bound work. The context manager terminates the pool
    # (no leaked worker threads); map() has already collected every result.
    with ThreadPool() as pool:
        results = pool.map(_download, image_links)
    return [content for content in results if content is not None]


def convert_to_camel_case(data_dict):
    """Recursively rename dict keys to camelCase, descending into nested dicts and lists.

    Any value that is neither a dict nor a list is returned unchanged.
    """
    if isinstance(data_dict, list):
        return list(map(convert_to_camel_case, data_dict))
    if not isinstance(data_dict, dict):
        return data_dict
    converted = {}
    for key, value in data_dict.items():
        converted[to_camel_case(key)] = convert_to_camel_case(value)
    return converted

步骤五:编写调用脚本,填写应用 AK/SK、调用路径和调用参数

from client import Client
from proto import ComfyRequest, ComfyResponse, PredictResultResponse, ProgressResponse
import time
import json

# Shared signed client for the calls below. Credentials are read from the
# ALIYUN_APP_KEY / ALIYUN_APP_SECRET environment variables when set, so real
# secrets need not be written into (and committed with) this file; otherwise
# the placeholders must be replaced with the application's AK/SK.
import os

cli = Client(
    endpoint="openai.edu-aliyun.com",
    app_key=os.environ.get("ALIYUN_APP_KEY", "应用AK"),
    app_secret=os.environ.get("ALIYUN_APP_SECRET", "应用SK"),
)


# Raw call helper (returns the decoded JSON body, not a typed response).
def call(url, body, method='POST', headers=None):
    """Invoke ``url`` through the shared signed client and return the parsed JSON body.

    A ``requests.Response`` is falsy for 4xx/5xx status codes; in that case
    ``None`` is returned and the body is not parsed.
    """
    response = cli.invoke(url, body, method, headers)
    if not response:
        return None
    return response.json()


def comfy_prompt(prompt: ComfyRequest, custom_resource_config_id='default') -> PredictResultResponse:
    """Submit a ComfyUI workflow, poll its progress once per second, then fetch the result.

    :param prompt: workflow request posted to ``/scc/comfy_prompt``.
    :param custom_resource_config_id: when truthy, sent as the
        ``X-SP-RESOURCE-CONFIG-ID`` header on the submit call.
    :return: the parsed ``/scc/comfy_get_result`` response.
    :raises Exception: if submission fails or returns an error / no task,
        a progress response has status 20, the result cannot be fetched,
        or the task is not finished after 1200 polls.
    """
    print(prompt.to_dict())

    headers = {}
    if custom_resource_config_id:
        headers['X-SP-RESOURCE-CONFIG-ID'] = custom_resource_config_id

    raw_submit = call("/scc/comfy_prompt", prompt.to_dict(), headers=headers)
    if not raw_submit:
        # call() returns None for non-2xx responses; from_dict(None) would crash obscurely.
        raise Exception("Failed to submit ComfyUI prompt")
    submit = ComfyResponse.from_dict(raw_submit)
    print(submit.to_dict())
    if submit.err_code:
        raise Exception(submit.err_message)
    if submit.data is None:
        raise Exception("No task returned: %s" % json.dumps(raw_submit, ensure_ascii=False))
    # Keep the task id in its own variable: poll responses must not overwrite
    # it, and they may carry no ``data`` at all.
    task_id = submit.data.task_id

    for _ in range(1200):
        # Query progress.
        raw_res = call("/scc/comfy_get_progress", {"taskId": task_id}, headers=None)
        print(raw_res)
        if raw_res:
            progress = ProgressResponse.from_dict(raw_res)
            if progress.status == 20:
                pretty_json_str = json.dumps(raw_res, indent=2, ensure_ascii=False)
                print(pretty_json_str)
                raise Exception("Failed to call , error: %s" % pretty_json_str)
            status = progress.data.status if progress.data is not None else None
            if status is not None and status.finished():
                # Query the final result.
                raw_result = call("/scc/comfy_get_result", {"taskId": task_id})
                if not raw_result:
                    raise Exception("Failed to fetch result for task %s" % task_id)
                return PredictResultResponse.from_dict(raw_result)
        time.sleep(1)
    raise Exception("1200s Timeout")


if __name__ == '__main__':
    # Demo run: submit one workflow, then print the result and the elapsed time.
    started_at = time.time()
    workflow_inputs = {
        "prompt": "A man is walking on the street."
    }
    request = ComfyRequest(
        alias_id="配置的工作流别名",  # workflow alias configured in the console
        workflow_id="控制台获取的工作流id",  # workflow id obtained from the console
        inputs=workflow_inputs,
    )
    result = comfy_prompt(request)
    print("生图结果:" + str(result))
    elapsed = time.time() - started_at
    print("时间消耗: %.2fs" % elapsed)