MLOPS/SERVING

Triton client

개발허재 2022. 6. 8. 23:38

API 서버에서 GPU 서버(Triton Inference Server)로 HTTP 또는 gRPC 통신을 하기 위한 클라이언트 파일이다.

import sys
import cv2
import numpy as np
from io import BytesIO
from PIL import Image,ImageOps

import tritonclient.grpc as grpcclient
import tritonclient.grpc.model_config_pb2 as model_config
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype
from tritonclient.utils import InferenceServerException

from django.conf import settings

# Module-level placeholder. It is not referenced by any of the functions in
# this module. TODO(review): confirm nothing else imports it, then remove.
FLAGS = None


def parse_model_grpc(model_metadata, model_config):
    """Validate a gRPC model description and pull out what this client needs.

    This client feeds exactly one input tensor, so both the model metadata
    and the model configuration must declare a single input.

    Args:
        model_metadata: gRPC model-metadata response; ``.inputs`` and
            ``.outputs`` are read.
        model_config: gRPC model configuration; ``.input`` and
            ``.max_batch_size`` are read.

    Returns:
        tuple: ``(input_name, output_metadata, max_batch_size)``, where
        ``output_metadata`` is ``model_metadata.outputs`` unchanged.

    Raises:
        Exception: if the metadata or the configuration does not list
            exactly one input.
    """
    metadata_input_count = len(model_metadata.inputs)
    if metadata_input_count != 1:
        raise Exception("expecting 1 input, got {}".format(metadata_input_count))

    config_input_count = len(model_config.input)
    if config_input_count != 1:
        raise Exception(
            "expecting 1 input in model configuration, got {}".format(
                config_input_count))

    # Exactly one element is guaranteed by the check above.
    (sole_input,) = model_metadata.inputs
    outputs = model_metadata.outputs
    return sole_input.name, outputs, model_config.max_batch_size


def parse_model_http(model_metadata, model_config):
    """Validate an HTTP (JSON dict) model description and extract client inputs.

    This client feeds exactly one input tensor, so both the model metadata
    and the model configuration must declare a single input.

    Args:
        model_metadata: dict with ``'inputs'`` and ``'outputs'`` lists.
        model_config: dict with an ``'input'`` list and ``'max_batch_size'``.

    Returns:
        tuple: ``(input_name, output_metadata, max_batch_size)``, where
        ``output_metadata`` is ``model_metadata['outputs']`` unchanged.

    Raises:
        Exception: if the metadata or the configuration does not list
            exactly one input.
    """
    declared_inputs = model_metadata['inputs']
    if len(declared_inputs) != 1:
        raise Exception("expecting 1 input, got {}".format(len(declared_inputs)))

    configured_inputs = model_config['input']
    if len(configured_inputs) != 1:
        raise Exception(
            "expecting 1 input in model configuration, got {}".format(
                len(configured_inputs)))

    # Exactly one element is guaranteed by the check above.
    (sole_input,) = declared_inputs
    outputs = model_metadata['outputs']
    return sole_input['name'], outputs, model_config['max_batch_size']

def preprocess(image_file):
    """Decode an uploaded image into a ``(height, width, 3)`` uint8 RGB array.

    Args:
        image_file: file-like object whose ``read()`` returns the encoded
            image bytes.

    Returns:
        numpy.ndarray: array of shape ``(height, width, 3)`` and dtype
        ``uint8``, matching the ``"UINT8"`` input declared by the caller.
    """
    image = Image.open(BytesIO(image_file.read()))
    # Apply the EXIF Orientation tag (if any) so the pixels are upright.
    image = ImageOps.exif_transpose(image)
    # Let PIL normalise every mode (L, LA, P, RGBA, CMYK, I;16, 1, ...) to
    # 8-bit, 3-channel RGB. For RGB/RGBA/L this matches the old behavior:
    # alpha is dropped, gray is replicated.
    # The previous array-shape fix-ups had several problems:
    #   - they checked shape[0] (height) == 2, which mangled 2-px-tall images;
    #   - they left LA images with 2 channels;
    #   - they treated palette indices and CMYK ink values as colors.
    image = image.convert("RGB")
    return np.asarray(image)

def _decode_bytes(value):
    """Recursively UTF-8-decode ``bytes`` found in a (possibly nested) list."""
    if isinstance(value, bytes):
        return value.decode('utf-8')
    if isinstance(value, list):
        return [_decode_bytes(item) for item in value]
    return value


def postprocess(results, output_names, filenames, batch_size):
    """Convert the inference result of the last requested output into Python data.

    Args:
        results: inference result object exposing ``as_numpy(name)``.
        output_names: iterable of requested output names. Only the LAST
            name's tensor is returned, as before, so existing callers keep
            receiving a single value.
        filenames: unused; kept for interface compatibility.
        batch_size: unused; kept for interface compatibility.

    Returns:
        For an object-dtype (BYTES) tensor: its elements decoded as UTF-8
        ``str``, as a list nested like the tensor (a 0-d tensor yields a
        single ``str``). For any other dtype: the numpy array itself.
        ``[]`` when ``output_names`` is empty.
    """
    names = list(output_names)
    if not names:
        return []

    result = results.as_numpy(names[-1])
    if result.dtype.type == np.object_:
        # tolist() keeps the tensor's nesting. A batched output such as
        # (1, k) is then decoded element-wise, instead of calling .decode()
        # on a whole row, which raised AttributeError.
        return _decode_bytes(result.tolist())
    return result


def request_api(verbose, model_name, protocol, image_file):
    """Send one image to a Triton model and return the post-processed output.

    The server address is read from ``settings.TRITON_API_URL``. A fresh
    client is created for every call and closed before this function exits.

    Args:
        verbose: forwarded as ``verbose=`` to the Triton client constructor.
        model_name: name of the model to query and run.
        protocol: ``"grpc"`` (case-insensitive) selects the gRPC client; any
            other value falls back to the HTTP client.
        image_file: file-like object with ``read()``, decoded by ``preprocess``.

    Returns:
        The value produced by ``postprocess`` for the model's outputs.

    Raises:
        Exception: re-raised if the client cannot be created.
        InferenceServerException: re-raised if the model metadata or
            configuration cannot be retrieved. Errors from parsing,
            preprocessing or ``infer`` also propagate unchanged.
    """
    url = settings.TRITON_API_URL
    protocol = protocol.lower()
    use_grpc = protocol == "grpc"
    client_lib = grpcclient if use_grpc else httpclient

    # Failures are re-raised instead of calling sys.exit(1). This runs inside
    # the API server, and SystemExit (a BaseException, not an Exception)
    # would escape ordinary request error handling instead of failing only
    # this request.
    try:
        triton_client = client_lib.InferenceServerClient(url=url, verbose=verbose)
    except Exception as e:
        print("client creation failed: " + str(e))
        raise

    try:
        try:
            model_metadata = triton_client.get_model_metadata(model_name=model_name)
        except InferenceServerException as e:
            print("failed to retrieve the metadata: " + str(e))
            raise

        try:
            # Named config_response to avoid shadowing the module-level
            # `model_config` import.
            config_response = triton_client.get_model_config(model_name=model_name)
        except InferenceServerException as e:
            print("failed to retrieve the config: " + str(e))
            raise

        # The parsed max_batch_size is intentionally ignored: exactly one
        # image is sent per request.
        if use_grpc:
            input_name, output_metadata, _ = parse_model_grpc(
                model_metadata, config_response.config)
        else:
            input_name, output_metadata, _ = parse_model_http(
                model_metadata, config_response)
        batch_size = 1

        # Prepend a batch axis of size 1 to the preprocessed image.
        image_data = np.expand_dims(preprocess(image_file), axis=0)
        input_filenames = [image_file]

        # InferInput has the same signature in the grpc and http modules.
        infer_input = client_lib.InferInput(input_name, image_data.shape, "UINT8")
        infer_input.set_data_from_numpy(image_data)
        inputs = [infer_input]

        output_names = [
            output.name if use_grpc else output['name']
            for output in output_metadata
        ]
        if use_grpc:
            outputs = [grpcclient.InferRequestedOutput(name) for name in output_names]
        else:
            outputs = [
                httpclient.InferRequestedOutput(name, binary_data=True)
                for name in output_names
            ]

        response = triton_client.infer(model_name, inputs, outputs=outputs)
        result = postprocess(response, output_names, input_filenames, batch_size)
    finally:
        # The client is created per call and never reused, so close it here.
        triton_client.close()

    print("PASS")
    return result

 

preprocess 부분은 업로드된 이미지를 (높이, 너비, 3채널) 형태의 3차원 이미지 배열로 변환해 주는 간단한 전처리 부분이며,

 

나머지 부분은 Triton client의 기본 절차(클라이언트 생성 → 모델 메타데이터·설정 조회 → 입력/출력 구성 → 추론 요청 → 후처리)를 따른다.