TensorRT Plugin

Table of Contents

1. TensorRT Plugin

https://github.com/sunwayforever/hello_tensorrt

通过 tensorrt caffe plugin 实现自定义算子 (以 convolution 为例)

1.1. patch tensorrt

需要修改 tensorrt 的 caffeParser.cpp, 让它把 convolution 算子交给 plugin 去处理

const IBlobNameToTensor* CaffeParser::parse(INetworkDefinition& network,...
  ...                                            
  // Route "Convolution" layers to the custom plugin instead of the
  // built-in parser: TensorRT will look the plugin up in the registry
  // by this name ("CONVOLUTION"), which must match what the creator's
  // getPluginName() returns.
  else if (layerMsg.type() == "Convolution")
  {
      pluginName = "CONVOLUTION";
      // parseConvolutionParam packs hyper-parameters and weights into
      // PluginFields; they are later handed to createPlugin().
      f = parseConvolutionParam(layerMsg, weights, *mBlobNameToTensor);
  }                                            
  ...

其中 parseConvolutionParam 需要负责把 convolution 的参数 (kernel, stride, padding) 和 weights 封装成 std::vector<nvinfer1::PluginField>, 交给 plugin 的构造函数

// Collect the convolution hyper-parameters (num_output, kernel, stride,
// padding, ...) and the trained weights from the caffe layer message, and
// wrap each one in an nvinfer1::PluginField so the plugin creator can
// rebuild the layer from them.
//
// NOTE(review): each field's data pointer must outlive this call;
// allocMemory<T>() presumably allocates from a parser-owned pool — verify
// its lifetime covers plugin creation.
std::vector<nvinfer1::PluginField> CaffeParser::parseConvolutionParam(
    const trtcaffe::LayerParameter& msg, CaffeWeightFactory& weightFactory,
    BlobNameToTensor& tensors) {
    std::vector<nvinfer1::PluginField> f;
    const trtcaffe::ConvolutionParameter& p = msg.convolution_param();

    int* num_output = allocMemory<int>();
    *num_output = p.num_output();
    f.emplace_back("num_output", num_output, PluginFieldType::kINT32, 1);

    int* kernel_h = allocMemory<int>();
    int* kernel_w = allocMemory<int>();
    // Caffe accepts either explicit kernel_h/kernel_w or the repeated
    // kernel_size field: a single entry means a square kernel, two entries
    // mean {h, w}.  The previous code used kernel_size(0) for both
    // dimensions, silently mis-parsing non-square kernels given as two
    // kernel_size entries.
    *kernel_h = p.has_kernel_h() ? p.kernel_h() : p.kernel_size(0);
    *kernel_w = p.has_kernel_w()
                    ? p.kernel_w()
                    : p.kernel_size(p.kernel_size_size() > 1 ? 1 : 0);
    f.emplace_back("kernel_h", kernel_h, PluginFieldType::kINT32, 1);
    f.emplace_back("kernel_w", kernel_w, PluginFieldType::kINT32, 1);

    // ...
    // Trained kernel weights: copy the Weights descriptor (type, values
    // pointer, count) into plugin-owned storage; the raw value buffer
    // itself stays inside the weight factory.  kUNKNOWN is used because
    // PluginField has no dedicated tag for an nvinfer1::Weights struct.
    Weights kernelWeights = weightFactory(msg.name(), WeightType::kGENERIC);
    Weights* kernel = allocMemory<Weights>();
    memcpy(kernel, &kernelWeights, sizeof(kernelWeights));
    f.emplace_back("kernel_weights", kernel, PluginFieldType::kUNKNOWN, 1);

    // ...
    return f;
}

1.2. plugin creator

// Factory that the TensorRT plugin registry uses to build
// ConvolutionPlugin instances.  getPluginName() must return the same
// string the parser set as pluginName ("CONVOLUTION") for the lookup to
// succeed.
class ConvolutionPluginCreator : public IPluginCreator {
   public:
    const char* getPluginName() const noexcept override {
        return "CONVOLUTION";
    }
    const char* getPluginVersion() const noexcept override { return "1"; }
    // ...
    // Build-time path: construct the plugin from the PluginFields the
    // caffe parser produced.  mFieldCollection/mPluginName are members
    // declared in the elided part of the class.
    IPluginV2* createPlugin(
        const char* name, const PluginFieldCollection* fc) noexcept override {
        auto* plugin = new ConvolutionPlugin(*fc);
        mFieldCollection = *fc;
        mPluginName = name;
        return plugin;
    }
    // Runtime path: reconstruct the plugin from the byte blob produced by
    // ConvolutionPlugin::serialize().
    IPluginV2* deserializePlugin(
        const char* name, const void* serialData,
        size_t serialLength) noexcept override {
        auto* plugin = new ConvolutionPlugin(serialData, serialLength);
        mPluginName = name;
        return plugin;
    }
    // ...

};

REGISTER_TENSORRT_PLUGIN(ConvolutionPluginCreator);

1.3. plugin

// Custom 2D convolution plugin (IPluginV2IOExt) supporting FP32 and INT8
// tensors in linear (NCHW) format.  Kernel launch itself is done by the
// Convolution/ConvolutionInt8 helpers defined elsewhere.
class ConvolutionPlugin : public IPluginV2IOExt {
   public:
    // Build-time constructor: extract hyper-parameters and weight
    // descriptors from the PluginFields produced by
    // parseConvolutionParam.  Fields are matched by name; unmatched
    // names are silently ignored.
    ConvolutionPlugin(const PluginFieldCollection fc) {
        for (int i = 0; i < fc.nbFields; i++) {
            auto field = fc.fields[i];
            if (std::string(field.name) == "num_output") {
                this->mOutputChannel = *((int*)field.data);
            }
            // Only the Weights descriptor (type/values/count) is copied;
            // the underlying value buffers are owned by the caffe weight
            // factory at this point.
            if (std::string(field.name) == "kernel_weights") {
                this->mKernelWeights = *(Weights*)field.data;
            }
            if (std::string(field.name) == "bias_weights") {
                this->mBiasWeights = *(Weights*)field.data;
            }
            // ...
        }
    }

    // Runtime constructor: rebuild the plugin from the blob written by
    // serialize().  Layout: a header of 15 ints (only the first 4 are
    // shown here; the rest, and the kernel/bias allocations kc/bc, live
    // in the elided code), followed by kernel values then bias values as
    // 4-byte floats.
    ConvolutionPlugin(const void* data, size_t length) {
        mInputChannel = ((int*)data)[0];
        mOutputChannel = ((int*)data)[1];
        mH = ((int*)data)[2];
        mW = ((int*)data)[3];
        // ...
        memcpy(kernel, ((int*)data) + 15, kc * 4);
        memcpy(bias, ((int*)data) + 15 + kc, bc * 4);
        mKernelWeights = Weights{
            .type = DataType::kFLOAT,
            .values = kernel,
            .count = kc,
        };
        mBiasWeights = Weights{
            .type = DataType::kFLOAT,
            .values = bias,
            .count = bc,
        };
    }

   public:
    // A convolution layer produces exactly one output tensor.
    int getNbOutputs() const noexcept override { return 1; }

    // Compute the output shape in CHW terms (batch dim excluded; format
    // is not known yet at this stage -- see supportsFormatCombination).
    Dims getOutputDimensions(
        int index, const Dims* inputs, int nbInputDims) noexcept override {
        int channel = inputs->d[0];
        int h = inputs->d[1];
        int w = inputs->d[2];

        Dims3 outputDims;
        outputDims.nbDims = 3;
        outputDims.d[0] = mOutputChannel;
        // NOTE: `floor` for convolution
        // NOTE(review): two-argument floor(a, b) is presumably a project
        // helper doing floor division a/b -- confirm; the formula is the
        // standard (h + 2*pad - kernel) / stride + 1.
        outputDims.d[1] = floor(h + 2 * mPadH - mKernelH, mStrideH) + 1;
        outputDims.d[2] = floor(w + 2 * mPadW - mKernelW, mStrideW) + 1;

        return outputDims;
    }

    // Launch the convolution kernel on `stream`, dispatching on the data
    // type that configurePlugin() recorded (FLOAT vs INT8; the scales are
    // only meaningful for the INT8 path).
    int enqueue(
        int batchSize, const void* const* inputs, void* const* outputs,
        void* workspace, cudaStream_t stream) noexcept override {
        if (mType == (int)DataType::kFLOAT) {
            float* dst = reinterpret_cast<float*>(outputs[0]);
            const float* src = reinterpret_cast<const float*>(inputs[0]);
            Convolution(
                dst, src, mInputChannel, mOutputChannel, mH, mW, mKernelH,
                mKernelW, mStrideH, mStrideW, mPadH, mPadW,
                (float*)mKernelWeights.values, (float*)mBiasWeights.values,
                stream);
        } else {
            int8_t* dst = reinterpret_cast<int8_t*>(outputs[0]);
            const int8_t* src = reinterpret_cast<const int8_t*>(inputs[0]);
            ConvolutionInt8(
                dst, src, mInputScale, mOutputScale, mInputChannel,
                mOutputChannel, mH, mW, mKernelH, mKernelW, mStrideH, mStrideW,
                mPadH, mPadW, (float*)mKernelWeights.values,
                (float*)mBiasWeights.values, stream);
        }

        return 0;
    }

    // (12 + 3) ints of header plus kernel/bias values, 4 bytes each --
    // must stay in sync with serialize() and the deserializing ctor.
    size_t getSerializationSize() const noexcept override {
        return (12 + 3 + mKernelWeights.count + mBiasWeights.count) * 4;
    }

    // Write the 15-int header (entries 4..14 in the elided code) followed
    // by kernel values then bias values; read back by
    // ConvolutionPlugin(const void*, size_t).
    void serialize(void* buffer) const noexcept override {
        ((int*)buffer)[0] = mInputChannel;
        ((int*)buffer)[1] = mOutputChannel;
        ((int*)buffer)[2] = mH;
        ((int*)buffer)[3] = mW;
        // ...
        memcpy(
            ((int*)buffer) + 15, mKernelWeights.values,
            mKernelWeights.count * 4);
        memcpy(
            ((int*)buffer) + 15 + mKernelWeights.count, mBiasWeights.values,
            mBiasWeights.count * 4);
    }

    // Record the type, quantization scales and input dims that TensorRT
    // finally selected (based on supportsFormatCombination and
    // getOutputDimensions); enqueue() relies on these.
    void configurePlugin(
        const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out,
        int nbOutput) noexcept override {
        mType = (int)in[0].type;
        mInputScale = in[0].scale;
        mOutputScale = out[0].scale;
        auto dims = in[0].dims;
        mInputChannel = dims.d[0];
        mH = dims.d[1];
        mW = dims.d[2];
    }

    // Accept only linear layout, FP32 or INT8, and require every tensor
    // to share the type of tensor 0 so no conversion is needed inside
    // the plugin.
    bool supportsFormatCombination(
        int pos, const PluginTensorDesc* inOut, int nbInputs,
        int nbOutputs) const noexcept override {
        return inOut[pos].format == TensorFormat::kLINEAR &&
               (inOut[pos].type == DataType::kFLOAT ||
                inOut[pos].type == DataType::kINT8) &&
               inOut[pos].type == inOut[0].type;
    }
    // Output type mirrors the input type.
    DataType getOutputDataType(
        int index, const DataType* inputTypes,
        int nbInputs) const noexcept override {
        (void)index;
        return inputTypes[0];
    }

    // Must match ConvolutionPluginCreator::getPluginName()/-Version().
    const char* getPluginType() const noexcept override {
        return "CONVOLUTION";
    }
    const char* getPluginVersion() const noexcept override { return "1"; }

    // ...

   private:
    int mOutputChannel;
    int mInputChannel;
    int mH;
    int mW;
    Weights mKernelWeights;
    Weights mBiasWeights;
    int mKernelH;
    int mKernelW;
    // ...
};

plugin 中需要定义的主要函数及其被调用的顺序:

  1. ConvolutionPlugin(const PluginFieldCollection fc)

    通过 tensorrt 传过来的参数来初始化 plugin

  2. getOutputDimensions

    根据 input dim 计算 output dim.

    getOutputDimensions 在 supportsFormatCombination 之前被调用, 且没有任何 format 相关的参数, 所以这个函数只需要以 nchw 格式计算出 output dim, 而不是涉及 memory layout. 后续 tensorrt 会根据 supportsFormatCombination 决定 tensor 的 memory layout

  3. supportsFormatCombination

    plugin 需要通过这个函数告诉 tensorrt 它支持的 input/output 的 format (即 memory layout, 例如 N[C/32]HW[32]) 和 data type (fp32, fp16, int32, int8, …)

    这个函数是 plugin 中非常重要的函数, 因为 tensorrt 支持不同的 format 和 data type, 若 supportsFormatCombination 没有指明, 有可能会导致 tensorrt 传递给 plugin 不支持的 format 和 data type.

    例如, 当实现 softmax plugin 时, 虽然 softmax 算子本身并不需要关注 input/output format, 但还是需要通过 supportsFormatCombination 要求 input/output 有相同的 format, 以免自己做 format 的转换.

  4. configurePlugin

    通过这个函数可以知道 tensorrt 根据 supportsFormatCombination 和 getOutputDimensions 确定的算子最终的 input/output 的 format, data type 和 output dim

  5. serialize

    把前面获得的所有信息通过 serialize 转换为 tensorrt 自己的格式, 以便后续不需要再执行前面的步骤

  6. ConvolutionPlugin(const void* data, size_t length)

    通过 serialize data 重新构造 plugin

  7. enqueue

    执行最终的 infer

Author: [email protected]
Date: 2022-06-26 Sun 18:13
Last updated: 2022-10-17 Mon 19:16

知识共享许可协议