MLOPS/SERVING
triton client
개발허재
2022. 6. 8. 23:38
API server에서 GPU server로 HTTP 통신을 하기 위한 client 파일이다.
import sys
import cv2
import numpy as np
from io import BytesIO
from PIL import Image,ImageOps
import tritonclient.grpc as grpcclient
import tritonclient.grpc.model_config_pb2 as model_config
import tritonclient.http as httpclient
from tritonclient.utils import triton_to_np_dtype
from tritonclient.utils import InferenceServerException
from django.conf import settings
FLAGS = None
def parse_model_grpc(model_metadata, model_config):
    """Validate a gRPC-shaped model description for this image client.

    Ensures the model exposes exactly one input (as an image
    classification network is expected to), then extracts the pieces the
    caller needs.

    Args:
        model_metadata: gRPC metadata object with ``inputs`` / ``outputs``.
        model_config: gRPC config object with ``input`` / ``max_batch_size``.

    Returns:
        Tuple of (input tensor name, output metadata list, max batch size).

    Raises:
        Exception: if metadata or config declares more than one input.
    """
    n_meta_inputs = len(model_metadata.inputs)
    if n_meta_inputs != 1:
        raise Exception("expecting 1 input, got {}".format(n_meta_inputs))
    n_cfg_inputs = len(model_config.input)
    if n_cfg_inputs != 1:
        raise Exception(
            "expecting 1 input in model configuration, got {}".format(
                n_cfg_inputs))
    return (model_metadata.inputs[0].name, model_metadata.outputs,
            model_config.max_batch_size)
def parse_model_http(model_metadata, model_config):
    """Validate an HTTP-shaped (dict) model description for this client.

    Mirrors ``parse_model_grpc`` but for the JSON/dict structures the
    HTTP client returns. The model must expose exactly one input.

    Args:
        model_metadata: dict with ``'inputs'`` / ``'outputs'`` lists.
        model_config: dict with ``'input'`` list and ``'max_batch_size'``.

    Returns:
        Tuple of (input tensor name, output metadata list, max batch size).

    Raises:
        Exception: if metadata or config declares more than one input.
    """
    meta_inputs = model_metadata['inputs']
    if len(meta_inputs) != 1:
        raise Exception("expecting 1 input, got {}".format(len(meta_inputs)))
    cfg_inputs = model_config['input']
    if len(cfg_inputs) != 1:
        raise Exception(
            "expecting 1 input in model configuration, got {}".format(
                len(cfg_inputs)))
    return (meta_inputs[0]['name'], model_metadata['outputs'],
            model_config['max_batch_size'])
def preprocess(image_file):
    """Decode an uploaded image into a 3-channel ``(H, W, 3)`` numpy array.

    Reads the whole file-like object, applies the EXIF orientation tag,
    then normalizes the array so downstream code always sees three
    channels: grayscale is expanded to RGB and an alpha channel is
    dropped.
    """
    pil_image = ImageOps.exif_transpose(Image.open(BytesIO(image_file.read())))
    arr = np.asarray(pil_image)
    # NOTE(review): presumably strips a leading 2-element frame/pair axis
    # from some decoders, but it would also trigger on a 2-pixel-tall
    # image — confirm the intent.
    if arr.shape[0] == 2:
        arr = arr[0]
    # Grayscale (H, W) -> replicate into three RGB channels.
    if len(arr.shape) == 2:
        arr = cv2.cvtColor(arr, cv2.COLOR_GRAY2RGB)
    # RGBA -> drop the alpha channel.
    if arr.shape[2] == 4:
        arr = arr[:, :, :3]
    return arr
def postprocess(results, output_names, filenames, batch_size):
    """Convert a Triton inference response into a plain Python result.

    For each name in *output_names* the corresponding numpy array is
    pulled from *results*; byte-string (object dtype) arrays are decoded
    element-wise to UTF-8 strings, numeric arrays are passed through.

    NOTE(review): the loop rebinds the result each iteration, so only the
    LAST entry of *output_names* is ever returned, and *filenames* /
    *batch_size* are unused — confirm this is intentional for
    single-output models.
    """
    decoded = []
    for name in output_names:
        arr = results.as_numpy(name)
        if arr.dtype.type == np.object_:
            decoded = [raw.decode('utf-8') for raw in arr]
        else:
            decoded = arr
    return decoded
def request_api(verbose, model_name, protocol, image_file):
    """Send one image inference request to the Triton server.

    Args:
        verbose: forwarded to the Triton client constructor.
        model_name: name of the deployed model to query.
        protocol: ``"grpc"`` (case-insensitive) for gRPC; anything else
            falls back to HTTP.
        image_file: file-like object holding the uploaded image bytes.

    Returns:
        The post-processed model output (see ``postprocess``).
    """
    url = settings.TRITON_API_URL
    # Lower-case once and branch on a single flag; the original re-ran
    # protocol.lower() on every comparison.
    protocol = protocol.lower()
    use_grpc = protocol == "grpc"
    # NOTE(review): sys.exit() inside a Django-served helper terminates
    # the whole worker process; consider raising an exception instead.
    # Kept here to preserve existing behavior.
    try:
        if use_grpc:
            # Create gRPC client for communicating with the server.
            triton_client = grpcclient.InferenceServerClient(
                url=url, verbose=verbose)
        else:
            # Create HTTP client for communicating with the server.
            triton_client = httpclient.InferenceServerClient(
                url=url, verbose=verbose)
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)
    try:
        model_metadata = triton_client.get_model_metadata(model_name=model_name)
    except InferenceServerException as e:
        print("failed to retrieve the metadata: " + str(e))
        sys.exit(1)
    try:
        model_config = triton_client.get_model_config(model_name=model_name)
    except InferenceServerException as e:
        print("failed to retrieve the config: " + str(e))
        sys.exit(1)
    if use_grpc:
        input_name, output_metadata, batch_size = parse_model_grpc(
            model_metadata, model_config.config)
    else:
        input_name, output_metadata, _ = parse_model_http(
            model_metadata, model_config)
        # The HTTP path always forces batch size 1, discarding the
        # configured max_batch_size.
        batch_size = 1
    # Preprocess the image per model requirements and add the leading
    # batch dimension: (H, W, 3) -> (1, H, W, 3).
    image_data = np.expand_dims(preprocess(image_file), axis=0)
    input_filenames = [image_file]
    # Both client modules expose the same InferInput /
    # InferRequestedOutput API, so select the module once instead of
    # duplicating the construction code per protocol.
    client_mod = grpcclient if use_grpc else httpclient
    infer_input = client_mod.InferInput(input_name, image_data.shape, "UINT8")
    infer_input.set_data_from_numpy(image_data)
    inputs = [infer_input]
    # Output metadata is attribute-shaped for gRPC, dict-shaped for HTTP.
    output_names = [
        output.name if use_grpc else output['name']
        for output in output_metadata
    ]
    if use_grpc:
        outputs = [grpcclient.InferRequestedOutput(name)
                   for name in output_names]
    else:
        outputs = [httpclient.InferRequestedOutput(name, binary_data=True)
                   for name in output_names]
    # Send the request and convert the response for the caller.
    result = triton_client.infer(model_name, inputs, outputs=outputs)
    result = postprocess(result, output_names, input_filenames, batch_size)
    print("PASS")
    return result
preprocess 부분은 간단하게 이미지 텐서를 3차원으로 변환해주는 부분이며
나머지 부분은 triton client 프로세스를 따른다.