在 DarkNet 模型中编译 YOLO-V2 和 YOLO-V3
备注
单击 此处 下载完整的示例代码
作者:Siju Samuel
本文介绍如何用 TVM 部署 DarkNet 模型。所有必需的模型和库都可通过脚本从 Internet 下载。此脚本运行带有边界框的 YOLO-V2 和 YOLO-V3 模型。DarkNet 解析依赖 CFFI 和 CV2 库，因此执行脚本前要安装这两个库。可通过以下命令安装：
pip install cffi opencv-python
# numpy 和 matplotlib
import numpy as np
import matplotlib.pyplot as plt
import sys
# tvm 和 relay
import tvm
from tvm import te
from tvm import relay
from ctypes import *
from tvm.contrib.download import download_testdata
from tvm.relay.testing.darknet import __darknetffi__
import tvm.relay.testing.yolo_detection
import tvm.relay.testing.darknet
选择模型
可选的模型有：‘yolov2’、‘yolov3’ 或 ‘yolov3-tiny’
# Name of the DarkNet model to compile; one of "yolov2", "yolov3" or "yolov3-tiny".
# The cfg and weights file names are derived from this value.
MODEL_NAME = "yolov3"
下载所需文件
首次运行时需要下载 cfg 和 weights 文件。
# File names of the network definition (.cfg) and the weights (.weights),
# both derived from the selected model name.
CFG_NAME = MODEL_NAME + ".cfg"
WEIGHTS_NAME = MODEL_NAME + ".weights"

# The cfg file and the DarkNet shared library are fetched from REPO_URL;
# the weights are fetched from pjreddie.com.
REPO_URL = "https://github.com/dmlc/web-data/blob/main/darknet/"
CFG_URL = REPO_URL + "cfg/" + CFG_NAME + "?raw=true"
WEIGHTS_URL = "https://pjreddie.com/media/files/" + WEIGHTS_NAME
cfg_path = download_testdata(CFG_URL, CFG_NAME, module="darknet")
weights_path = download_testdata(WEIGHTS_URL, WEIGHTS_NAME, module="darknet")

# Download and load the DarkNet library.
# Only Linux and macOS builds are referenced here; any other platform is rejected.
if sys.platform in ["linux", "linux2"]:
    DARKNET_LIB = "libdarknet2.0.so"
    DARKNET_URL = REPO_URL + "lib/" + DARKNET_LIB + "?raw=true"
elif sys.platform == "darwin":
    DARKNET_LIB = "libdarknet_mac2.0.so"
    DARKNET_URL = REPO_URL + "lib_osx/" + DARKNET_LIB + "?raw=true"
else:
    err = "Darknet lib is not supported on {} platform".format(sys.platform)
    raise NotImplementedError(err)
lib_path = download_testdata(DARKNET_URL, DARKNET_LIB, module="darknet")

# NOTE: from here on DARKNET_LIB is the opened library handle, no longer the file name.
DARKNET_LIB = __darknetffi__.dlopen(lib_path)
# The paths are passed to the library as UTF-8 encoded bytes.
net = DARKNET_LIB.load_network(cfg_path.encode("utf-8"), weights_path.encode("utf-8"), 0)

dtype = "float32"
batch_size = 1
# Placeholder array laid out as [batch, net.c, net.h, net.w]; only its shape is
# read in this section, so the uninitialised contents do not matter.
data = np.empty([batch_size, net.c, net.h, net.w], dtype)
shape_dict = {"data": data.shape}
print("Converting darknet to relay functions...")
mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=data.shape)
输出结果:
Converting darknet to relay functions...
将计算图导入到 Relay 中
编译模型:
# Build for the "llvm" target, with "llvm" as the host target too, on CPU device 0.
target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
# Placeholder input laid out as [batch, net.c, net.h, net.w]; only the shape is used here.
data = np.empty([batch_size, net.c, net.h, net.w], dtype)
shape = {"data": data.shape}
print("Compiling the model...")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# The last two entries of the input shape are the network height and width.
[neth, netw] = shape["data"][2:]  # current image shape is 608x608
输出结果:
Compiling the model...
/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
"target_host parameter is going to be deprecated. "
加载测试图像
# Download the sample picture, then load it with the network width and height
# obtained during compilation.
test_image = "dog.jpg"
img_url = "".join([REPO_URL, "data/", test_image, "?raw=true"])
print("Loading the test image...")
img_path = download_testdata(img_url, test_image, module="data")
data = tvm.relay.testing.darknet.load_image(img_path, netw, neth)
输出结果:
Loading the test image...