TensorRT Plugin
1. TensorRT Plugin
https://github.com/sunwayforever/hello_tensorrt
Implementing custom operators with a TensorRT caffe plugin (using convolution as an example).
1.1. patch tensorrt
We need to modify TensorRT's caffeParser.cpp so that it hands the Convolution operator off to a plugin:
const IBlobNameToTensor* CaffeParser::parse(INetworkDefinition& network, ...) {
    // ...
    else if (layerMsg.type() == "Convolution") {
        pluginName = "CONVOLUTION";
        f = parseConvolutionParam(layerMsg, weights, *mBlobNameToTensor);
    }
    // ...
}
Here parseConvolutionParam is responsible for packing the convolution parameters (kernel, stride, padding) and the weights into a std::vector<nvinfer1::PluginField>, which is handed to the plugin's constructor:
std::vector<nvinfer1::PluginField> CaffeParser::parseConvolutionParam(
    const trtcaffe::LayerParameter& msg, CaffeWeightFactory& weightFactory,
    BlobNameToTensor& tensors) {
    std::vector<nvinfer1::PluginField> f;
    const trtcaffe::ConvolutionParameter& p = msg.convolution_param();

    int* num_output = allocMemory<int>();
    *num_output = p.num_output();
    f.emplace_back("num_output", num_output, PluginFieldType::kINT32, 1);

    int* kernel_h = allocMemory<int>();
    int* kernel_w = allocMemory<int>();
    *kernel_h = p.has_kernel_h() ? p.kernel_h() : p.kernel_size(0);
    *kernel_w = p.has_kernel_w() ? p.kernel_w() : p.kernel_size(0);
    f.emplace_back("kernel_h", kernel_h, PluginFieldType::kINT32, 1);
    f.emplace_back("kernel_w", kernel_w, PluginFieldType::kINT32, 1);
    // ...
    Weights kernelWeights = weightFactory(msg.name(), WeightType::kGENERIC);
    Weights* kernel = allocMemory<Weights>();
    memcpy(kernel, &kernelWeights, sizeof(kernelWeights));
    f.emplace_back("kernel_weights", kernel, PluginFieldType::kUNKNOWN, 1);
    // ...
    return f;
}
1.2. plugin creator
class ConvolutionPluginCreator : public IPluginCreator {
public:
    const char* getPluginName() const noexcept override {
        return "CONVOLUTION";
    }
    const char* getPluginVersion() const noexcept override { return "1"; }
    // ...
    IPluginV2* createPlugin(
        const char* name, const PluginFieldCollection* fc) noexcept override {
        auto* plugin = new ConvolutionPlugin(*fc);
        mFieldCollection = *fc;
        mPluginName = name;
        return plugin;
    }
    IPluginV2* deserializePlugin(
        const char* name, const void* serialData,
        size_t serialLength) noexcept override {
        auto* plugin = new ConvolutionPlugin(serialData, serialLength);
        mPluginName = name;
        return plugin;
    }
    // ...
};

REGISTER_TENSORRT_PLUGIN(ConvolutionPluginCreator);
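REGISTER_TENSORRT_PLUGIN makes the creator discoverable through the global plugin registry, which is how the patched parser turns the pluginName string into a plugin instance. Below is a minimal sketch of that lookup (the variable names and the tensors.find call are illustrative, not the repo's exact code):

// Pack the fields returned by parseConvolutionParam.
PluginFieldCollection fc{};
fc.nbFields = static_cast<int32_t>(f.size());
fc.fields = f.data();

// Look up the creator registered above by name and version,
// then build the plugin and add it to the network.
auto* creator = getPluginRegistry()->getPluginCreator("CONVOLUTION", "1");
IPluginV2* plugin = creator->createPlugin(layerMsg.name().c_str(), &fc);
ITensor* input = tensors.find(layerMsg.bottom(0).c_str());  // hypothetical lookup
auto* layer = network.addPluginV2(&input, 1, *plugin);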
1.3. plugin
class ConvolutionPlugin : public IPluginV2IOExt {
public:
    ConvolutionPlugin(const PluginFieldCollection fc) {
        for (int i = 0; i < fc.nbFields; i++) {
            auto field = fc.fields[i];
            if (std::string(field.name) == "num_output") {
                this->mOutputChannel = *((int*)field.data);
            }
            if (std::string(field.name) == "kernel_weights") {
                this->mKernelWeights = *(Weights*)field.data;
            }
            if (std::string(field.name) == "bias_weights") {
                this->mBiasWeights = *(Weights*)field.data;
            }
            // ...
        }
    }

    ConvolutionPlugin(const void* data, size_t length) {
        mInputChannel = ((int*)data)[0];
        mOutputChannel = ((int*)data)[1];
        mH = ((int*)data)[2];
        mW = ((int*)data)[3];
        // ...
        memcpy(kernel, ((int*)data) + 15, kc * 4);
        memcpy(bias, ((int*)data) + 15 + kc, bc * 4);
        mKernelWeights = Weights{
            .type = DataType::kFLOAT,
            .values = kernel,
            .count = kc,
        };
        mBiasWeights = Weights{
            .type = DataType::kFLOAT,
            .values = bias,
            .count = bc,
        };
    }

public:
    int getNbOutputs() const noexcept override { return 1; }

    Dims getOutputDimensions(
        int index, const Dims* inputs, int nbInputDims) noexcept override {
        int channel = inputs->d[0];
        int h = inputs->d[1];
        int w = inputs->d[2];
        Dims3 outputDims;
        outputDims.nbDims = 3;
        outputDims.d[0] = mOutputChannel;
        // NOTE: `floor` for convolution
        outputDims.d[1] = floor(h + 2 * mPadH - mKernelH, mStrideH) + 1;
        outputDims.d[2] = floor(w + 2 * mPadW - mKernelW, mStrideW) + 1;
        return outputDims;
    }

    int enqueue(
        int batchSize, const void* const* inputs, void* const* outputs,
        void* workspace, cudaStream_t stream) noexcept override {
        if (mType == (int)DataType::kFLOAT) {
            float* dst = reinterpret_cast<float*>(outputs[0]);
            const float* src = reinterpret_cast<const float*>(inputs[0]);
            Convolution(
                dst, src, mInputChannel, mOutputChannel, mH, mW, mKernelH,
                mKernelW, mStrideH, mStrideW, mPadH, mPadW,
                (float*)mKernelWeights.values, (float*)mBiasWeights.values,
                stream);
        } else {
            int8_t* dst = reinterpret_cast<int8_t*>(outputs[0]);
            const int8_t* src = reinterpret_cast<const int8_t*>(inputs[0]);
            ConvolutionInt8(
                dst, src, mInputScale, mOutputScale, mInputChannel,
                mOutputChannel, mH, mW, mKernelH, mKernelW, mStrideH,
                mStrideW, mPadH, mPadW, (float*)mKernelWeights.values,
                (float*)mBiasWeights.values, stream);
        }
        return 0;
    }

    size_t getSerializationSize() const noexcept override {
        return (12 + 3 + mKernelWeights.count + mBiasWeights.count) * 4;
    }

    void serialize(void* buffer) const noexcept override {
        ((int*)buffer)[0] = mInputChannel;
        ((int*)buffer)[1] = mOutputChannel;
        ((int*)buffer)[2] = mH;
        ((int*)buffer)[3] = mW;
        // ...
        memcpy(
            ((int*)buffer) + 15, mKernelWeights.values,
            mKernelWeights.count * 4);
        memcpy(
            ((int*)buffer) + 15 + mKernelWeights.count, mBiasWeights.values,
            mBiasWeights.count * 4);
    }

    void configurePlugin(
        const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out,
        int nbOutput) noexcept override {
        mType = (int)in[0].type;
        mInputScale = in[0].scale;
        mOutputScale = out[0].scale;
        auto dims = in[0].dims;
        mInputChannel = dims.d[0];
        mH = dims.d[1];
        mW = dims.d[2];
    }

    bool supportsFormatCombination(
        int pos, const PluginTensorDesc* inOut, int nbInputs,
        int nbOutputs) const noexcept override {
        return inOut[pos].format == TensorFormat::kLINEAR &&
               (inOut[pos].type == DataType::kFLOAT ||
                inOut[pos].type == DataType::kINT8) &&
               inOut[pos].type == inOut[0].type;
    }

    DataType getOutputDataType(
        int index, const DataType* inputTypes,
        int nbInputs) const noexcept override {
        (void)index;
        return inputTypes[0];
    }

    const char* getPluginType() const noexcept override {
        return "CONVOLUTION";
    }
    const char* getPluginVersion() const noexcept override { return "1"; }
    // ...

private:
    int mOutputChannel;
    int mInputChannel;
    int mH;
    int mW;
    Weights mKernelWeights;
    Weights mBiasWeights;
    int mKernelH;
    int mKernelW;
    // ...
};
The main functions a plugin needs to define, in the order TensorRT calls them:
ConvolutionPlugin(const PluginFieldCollection fc)
Initialize the plugin from the parameters passed in by TensorRT.
getOutputDimensions
Compute the output dims from the input dims.
getOutputDimensions is called before supportsFormatCombination and receives no format-related parameters, so it only needs to compute the output dims in NCHW terms; it does not have to care about memory layout. TensorRT later decides each tensor's memory layout based on supportsFormatCombination. A worked example of the convolution output size follows below.
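As a concrete check of the formula in getOutputDimensions (the numbers are hypothetical), take a 32x32 input with a 3x3 kernel, stride 2, and padding 1:

int h = 32, pad = 1, kernel = 3, stride = 2;
// floor((32 + 2*1 - 3) / 2) + 1 = floor(31 / 2) + 1 = 16;
// C++ integer division already floors for non-negative operands.
int outputH = (h + 2 * pad - kernel) / stride + 1;  // 16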
supportsFormatCombination
Through this function the plugin tells TensorRT which input/output formats (i.e. memory layouts, e.g. N[C/32]HW[32]) and data types (fp32, fp16, int32, int8, …) it supports.
This is one of the most important functions in a plugin: TensorRT supports many formats and data types, and if supportsFormatCombination does not spell out what the plugin can handle, TensorRT may hand the plugin a format or data type it does not support.
For example, when implementing a softmax plugin, even though the softmax operator itself does not care about the input/output format, it still needs supportsFormatCombination to require that input and output use the same format, so that the plugin does not have to convert formats itself; see the sketch below.
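A minimal sketch of such a check, written as a member of a hypothetical SoftmaxPlugin with the same IPluginV2IOExt interface as above:

// Accept fp32/fp16 in whatever format TensorRT proposes, but require every
// input/output tensor to match tensor 0, so the plugin itself never has to
// convert between layouts.
bool supportsFormatCombination(
    int pos, const PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) const noexcept override {
    return inOut[pos].format == inOut[0].format &&
           inOut[pos].type == inOut[0].type &&
           (inOut[0].type == DataType::kFLOAT ||
            inOut[0].type == DataType::kHALF);
}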
configurePlugin
This function reports the operator's final input/output format, data type, and output dims, as determined by TensorRT from supportsFormatCombination and getOutputDimensions.
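The in[0].scale / out[0].scale recorded in configurePlugin are what the int8 enqueue path needs. TensorRT's quantization convention is real = int8 * scale; below is a minimal sketch of the resulting dequantize/requantize helpers (an assumption about what ConvolutionInt8 does internally, not the repo's exact code):

#include <cmath>
#include <cstdint>

// int8 -> real, using the input tensor's scale.
inline float Dequantize(int8_t x, float inputScale) { return x * inputScale; }

// real -> int8, using the output tensor's scale; round and
// saturate to [-128, 127].
inline int8_t Requantize(float acc, float outputScale) {
    float q = roundf(acc / outputScale);
    return static_cast<int8_t>(fminf(127.0f, fmaxf(-128.0f, q)));
}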
serialize
Convert all the information gathered in the previous steps into TensorRT's own serialized format, so that those steps do not have to be executed again later.
ConvolutionPlugin(const void* data, size_t length)
Reconstruct the plugin from the serialized data.
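For context, a minimal sketch of when these two paths run (standard TensorRT API; engine, gLogger, and error handling are assumed to exist):

#include <fstream>

// Build side: TensorRT calls getSerializationSize()/serialize() on every
// plugin while serializing the engine.
nvinfer1::IHostMemory* blob = engine->serialize();
std::ofstream("model.engine", std::ios::binary)
    .write(static_cast<const char*>(blob->data()), blob->size());

// Load side: the registry locates ConvolutionPluginCreator by plugin type
// and version, then deserializePlugin -> ConvolutionPlugin(data, length).
nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(gLogger);
nvinfer1::ICudaEngine* engine2 =
    runtime->deserializeCudaEngine(blob->data(), blob->size());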
enqueue
Run the actual inference; a sketch of the underlying CUDA kernel follows below.
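Convolution and ConvolutionInt8 above are the repo's CUDA kernels. As a rough illustration of what enqueue launches, here is a naive NCHW float kernel with the same parameter order as the Convolution call (a sketch, not the repo's implementation; a batch of 1 and device-resident weights are assumed):

__global__ void NaiveConvKernel(
    float* dst, const float* src, int inC, int outC, int h, int w, int kH,
    int kW, int sH, int sW, int pH, int pW, const float* weight,
    const float* bias) {
    int outH = (h + 2 * pH - kH) / sH + 1;
    int outW = (w + 2 * pW - kW) / sW + 1;
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= outC * outH * outW) return;
    // One thread per output element (oc, oh, ow).
    int oc = idx / (outH * outW);
    int oh = idx / outW % outH;
    int ow = idx % outW;
    float acc = bias[oc];
    for (int ic = 0; ic < inC; ic++)
        for (int r = 0; r < kH; r++)
            for (int c = 0; c < kW; c++) {
                int ih = oh * sH - pH + r;
                int iw = ow * sW - pW + c;
                if (ih < 0 || ih >= h || iw < 0 || iw >= w) continue;  // zero padding
                acc += src[(ic * h + ih) * w + iw] *
                       weight[((oc * inC + ic) * kH + r) * kW + c];
            }
    dst[idx] = acc;
}

void Convolution(
    float* dst, const float* src, int inC, int outC, int h, int w, int kH,
    int kW, int sH, int sW, int pH, int pW, float* weight, float* bias,
    cudaStream_t stream) {
    int outH = (h + 2 * pH - kH) / sH + 1;
    int outW = (w + 2 * pW - kW) / sW + 1;
    int total = outC * outH * outW;
    int block = 256;
    NaiveConvKernel<<<(total + block - 1) / block, block, 0, stream>>>(
        dst, src, inC, outC, h, w, kH, kW, sH, sW, pH, pW, weight, bias);
}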