加入星計劃,您可以享受以下權益:

  • 創(chuàng)作內(nèi)容快速變現(xiàn)
  • 行業(yè)影響力擴散
  • 作品版權保護
  • 300W+ 專業(yè)用戶
  • 1.5W+ 優(yōu)質(zhì)創(chuàng)作者
  • 5000+ 長期合作伙伴
立即加入
  • 正文
    • 01、常見的模型格式
    • 02、建立模型并進行轉(zhuǎn)換
    • 03、運行驗證環(huán)節(jié)
  • 相關推薦
申請入駐 產(chǎn)業(yè)圖譜

研發(fā)干貨丨OK1052-C開發(fā)板運行 Tensorflow_lite模型

2021/04/01
397
加入交流群
掃碼加入
獲取工程師必備禮包
參與熱點資訊討論

tensorFlowLite 是一款TensorFlow用于移動設備和嵌入式設備的輕量級解決方案。

我們知道TensorFlow可以在多個平臺上運行,從機架式服務器到小型IoT設備。但是隨著近年來機器學習模型的廣泛使用,出現(xiàn)了在移動和嵌入式設備上部署它們的需求,而TensorFlowLite 允許設備端的機器學習模型的低延遲推斷。

飛凌OK1052-C開發(fā)板以高性能處理器i.MXRT1052作為硬件軀體,再用tensorflow_lite武裝軟件大腦,也搭上了開往機器學習,邊緣計算等朝陽行業(yè)領域的快車。

▲OK1052-C開發(fā)板接口圖

 

OK1052-C的用戶資料SDK中,已經(jīng)有了關于tensorflow_lite使用的demo例程,但是這些例程使用的都是現(xiàn)成訓練之后的實驗模型,并不適用于我們實際的應用場景。我們在實際應用項目中,必然是需要使用適合本項目的模型,網(wǎng)上現(xiàn)成模型資源下載或者自己訓練模型,然后搭載自己的應用程序,完成我們的應用項目需求。

由于模型的訓練需要算力的支持,通常我們在計算機上進行模型的訓練,我們將訓練好的模型稱為預訓練模型。我們也可以在網(wǎng)上download預訓練模型,然后通過格式轉(zhuǎn)換轉(zhuǎn)換為tflite格式的模型,也就是開發(fā)板可執(zhí)行的模型格式;也可以自己搭建計算網(wǎng)絡,通過訓練之后,形成預訓練模型,再轉(zhuǎn)換為tflite格式,運行到開發(fā)板。

今天我們通過一個簡單的例子,來介紹怎么建立計算網(wǎng)絡和進行模型的轉(zhuǎn)換。

01、常見的模型格式

我們常見的訓練模型格式有.PB,cpkt,saveModel,H5等,若想將模型運用于tensor flowlite,需要轉(zhuǎn)換為tflite格式,一般的轉(zhuǎn)換過程是:

由上圖可知,如果要將Checkpoints(cpkt)格式轉(zhuǎn)換為tflite,需經(jīng)過freeze_graph.py工具將cpkt格式模型轉(zhuǎn)換為Frozen GraphDef(.pb)格式,然后再經(jīng)過TFLite Converter轉(zhuǎn)換工具轉(zhuǎn)換為tflite。

02、建立模型并進行轉(zhuǎn)換

現(xiàn)在通過例子介紹如何將模型轉(zhuǎn)換為tflite格式,首先我們先建立兩個網(wǎng)絡流圖,并分別生成.pb和cpkt格式的模型。

1)建立計算網(wǎng)絡,并保存為.PB格式模型

# coding=UTF-8

import tensorflow as tf

import shutil

import os.path

from tensorflow.python.framework import graph_util

output_graph = "easy_model/add_model.pb"

#下面的過程你可以替換成CNN、RNN等你想做的訓練過程,這里只是簡單的一個計算公式

input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")

W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")

B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")

_y = (input_holder * W1) + B1

# predictions = tf.greater(_y, 50, name="predictions") #比50大返回true,否則返回false

predictions = tf.add(_y, 10,name="predictions") #做一個加法運算

init = tf.global_variables_initializer()

with tf.Session() as sess:

sess.run(init)

print ("predictions :", sess.run(predictions, feed_dict={input_holder: [10.0]}))

graph_def = tf.get_default_graph().as_graph_def() #得到當前的圖的 GraphDef 部分,通過這個部分就可以完成重輸入層到輸出層的計算過程

output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變量值固定

sess,

graph_def,

["predictions"] #需要保存節(jié)點的名字

)

with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型

f.write(output_graph_def.SerializeToString()) # 序列化輸出

print("%d ops in the final graph." % len(output_graph_def.node))

print (predictions)

此計算網(wǎng)絡只是一個簡單的數(shù)學運算,不需要進行訓練,該運算公式為:

predictions = (input_holder * W1) + B1 + 10

其中input_holder是網(wǎng)絡的輸入節(jié)點,predictions是網(wǎng)絡的輸出節(jié)點,W1,B1是兩個變量,分別被賦值為5.0,,1.0

程序運行之后,會在easy_model/文件夾下生成add_model.pb模型文件。我們通過Netron軟件(Netron是一個很方便的軟件)來看一下add_model.pb模型文件中存儲的網(wǎng)絡圖及參數(shù)信息:

 

2)建立網(wǎng)絡圖,并保存為.ckpt格式模型:

# coding=UTF-8 支持中文編碼格式

import tensorflow as tf

import shutil

import os.path

MODEL_DIR = "easy_model"

MODEL_NAME = "model.ckpt"

if not tf.gfile.Exists(MODEL_DIR):

tf.gfile.MakeDirs(MODEL_DIR)

下面的過程你可以替換成CNN、RNN等你想做的訓練過程,這里只是簡單的一個計算公式

input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")

W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")

B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")

_y = (input_holder * W1) + B1

#predictions = tf.greater(_y, 50, name="predictions")

predictions = tf.add(_y, 10,name="predictions")

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:

sess.run(init)

print ("predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}))

saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME))

print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node))

for op in tf.get_default_graph().get_operations():

print (op.name, op.values())

 

此網(wǎng)絡圖同上一節(jié)一樣是一個簡單的數(shù)學運算。

不同的是,此程序最后保存為ckpt格式的模型文件:

checkpoint :記錄目錄下所有模型文件列表

ckpt.data :保存模型中每個變量的取值

ckpt.meta :保存整個計算網(wǎng)絡圖的結構

同樣可通過Netron軟件查看ckpt文件中存儲的網(wǎng)絡圖結構:

 

3)將生成的cpkt格式模型文件轉(zhuǎn)換為.pb文件:

import tensorflow as tf

from tensorflow.python.framework import graph_util

#from create_tf_record import *

resize_height = 100 # 指定圖片高度

resize_width = 100 # 指定圖片寬度

def freeze_graph(input_checkpoint, output_graph):

'''

:param input_checkpoint:

:param output_graph: PB 模型保存路徑

:return:

'''

# 檢查目錄下ckpt文件狀態(tài)是否可用

# checkpoint = tf.train.get_checkpoint_state(model_folder)

# 得ckpt文件路徑

# input_checkpoint = checkpoint.model_checkpoint_path

# 指定輸出的節(jié)點名稱,該節(jié)點名稱必須是元模型中存在的節(jié)點

output_node_names = "predictions"

saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

graph = tf.get_default_graph() # 獲得默認的圖

input_graph_def = graph.as_graph_def() # 返回一個序列化的圖代表當前的圖

with tf.Session() as sess:

saver.restore(sess, input_checkpoint) # 恢復圖并得到數(shù)據(jù)

# 模型持久化,將變量值固定

output_graph_def = graph_util.convert_variables_to_constants(

sess=sess,

# 等于:sess.graph_def

input_graph_def=input_graph_def,

# 如果有多個輸出節(jié)點,以逗號隔開

output_node_names=output_node_names.split(","))

# 保存模型

with tf.gfile.GFile(output_graph, "wb") as f:

f.write(output_graph_def.SerializeToString()) # 序列化輸出

# 得到當前圖有幾個操作節(jié)點

print("%d ops in the final graph." % len(output_graph_def.node))

input_checkpoint='easy_model/model.ckpt'

# 輸出pb模型的路徑

out_pb_path="easy_model/frozen_model.pb"

freeze_graph(input_checkpoint,out_pb_path)

這里需要輸入輸出節(jié)點名稱output_node_names= "predictions",

程序運行之后在easy_model文件夾下生成了frozen_model.pb文件,我們通過Netron軟件查看一下frozen_model.pb網(wǎng)絡圖,可以發(fā)現(xiàn)跟第一節(jié)中直接生成的.pb模型文件一樣的:

cpkt文件格式將模型保存為4個文件,pb文件格式為一個。ckpt模型持久化方式將圖結構與權重參數(shù)分開保存,多了模型更多的細節(jié),適合模型訓練階段;而pb持久化方式完成了從輸入到輸出的前向傳播,完成了端到端的形式,更適合離線使用。

 

4)最后,我們將.pb文件轉(zhuǎn)換為.tflite:

我們運行此段代碼:

function showSnackbar() {

var $snackbar = $('#snackbar');

$snackbar.addClass('show');

setTimeout(() => {

$snackbar.removeClass('show');

}, 3000);

}

注意這里需要寫上輸入輸出節(jié)點名稱,這個在構建網(wǎng)絡模型時,已經(jīng)定義。

運行之后,會報錯:

module 'tensorflow' has no attribute 'lite'

意思是目前的tensorflow版本不支持lite類,沒關系,我們重新安裝1.14版本的tensorflow即可成功運行。最后生成tflite格式的模型文件:easy_model.lite。

03、運行驗證環(huán)節(jié)

將轉(zhuǎn)換后模型運行到OK1052-C開發(fā)板進行驗證

首先需要將將easy_model.lite轉(zhuǎn)換為二進制數(shù)組的.h文件easy_frozen_0.h,轉(zhuǎn)換完成之后,將其放入OK1052-C用戶資料的,SDK的middlewareeiqtensorflow-liteexampleslabel_image目錄下,

打開SDK中boardsevkbimxrt1050eiq_examplestensorflow_lite_label_image下工程,在文件label_iamge.cpp中做如下修改:


 

#include "board.h"

#include "pin_mux.h"

#include "clock_config.h"

#include "fsl_debug_console.h"

#include

#include

#include

#include "timer.h"

#include "tensorflow/lite/kernels/register.h"

#include "tensorflow/lite/model.h"

#include "tensorflow/lite/optional_debug_tools.h"

#include "tensorflow/lite/string_util.h"

//#include "Sine_mode.h"

//#include "add_model.h"

#include "easy_frozen_0.h"

int inference_count = 0;

// This is a small number so that it's easy to read the logs

const int kInferencesPerCycle = 30;

const float kXrange = 2.f * 3.14159265359f;

#define LOG(x) std::cout

void RunInference()

{

std::unique_ptr model;

std::unique_ptr interpreter;

model = tflite::FlatBufferModel::BuildFromBuffer(mobilenet_model, mobilenet_model_len);

if (!model) {

LOG(FATAL) << "Failed to load modelrn";

exit(-1);

}

model->error_reporter();

tflite::ops::builtin::BuiltinOpResolver resolver;

tflite::InterpreterBuilder(*model, resolver)(&interpreter);

if (!interpreter) {

LOG(FATAL) << "Failed to construct interpreterrn";

exit(-1);

}

float input = interpreter->inputs()[0];

if (interpreter->AllocateTensors() != kTfLiteOk) {

LOG(FATAL) << "Failed to allocate tensors!rn";

}

while(true)

{

// Calculate an x value to feed into the model. We compare the current

// inference_count to the number of inferences per cycle to determine

// our position within the range of possible x values the model was

// trained on, and use this to calculate a value.

float position = static_cast(inference_count) /

static_cast(kInferencesPerCycle);

float x_val = position * kXrange;

float* input_tensor_data = interpreter->typed_tensor(input);

*input_tensor_data = x_val;

// Delay_time(1000);

// Run inference, and report any error

TfLiteStatus invoke_status = interpreter->Invoke();

if (invoke_status != kTfLiteOk)

{

LOG(FATAL) << "Failed to invoke tflite!rn";

return;

}

// Read the predicted y value from the model's output tensor

float* y_val = interpreter->typed_output_tensor(0);

PRINTF("rn x_value: %f, y_value: %f rn", x_val, y_val[0]);

//PRINTF("rn x_value: %d, y_value: %d rn", (int)x_val, (int)y_val[0]);

// Increment the inference_counter, and reset it if we have reached

// the total number per cycle

inference_count += 1;

if (inference_count >= kInferencesPerCycle) inference_count = 0;

}

}

/*

* @brief Application entry point.

*/

int main(void)

{

/* Init board hardware */

BOARD_ConfigMPU();

BOARD_InitPins();

BOARD_InitDEBUG_UARTPins();

BOARD_BootClockRUN();

BOARD_InitDebugConsole();

NVIC_SetPriorityGrouping(3);

InitTimer();

std::cout << "The hello_world demo of TensorFlow Lite modelrn";

RunInference();

std::flush(std::cout);

for (;;) {}

}

此工程運行之后打印信息如下:

可以看到,輸入的值x_value通過模型計算之后得到y(tǒng)_value

使用的就是計算網(wǎng)絡中的公式:

predictions = (input_holder * W1) + B1 + 10

由此可知,我們模型轉(zhuǎn)換成功。

相關推薦

登錄即可解鎖
  • 海量技術文章
  • 設計資源下載
  • 產(chǎn)業(yè)鏈客戶資源
  • 寫文章/發(fā)需求
立即登錄

秉承專業(yè)態(tài)度,專注智能設備核心平臺研發(fā)與制造,以技術研發(fā)創(chuàng)新為主導,以客戶實用化,產(chǎn)品化為目標,把握嵌入式行業(yè)的前沿發(fā)展需求,利用核心技術為客戶提供穩(wěn)定、可靠、功能優(yōu)異的高品質(zhì)產(chǎn)品。合作聯(lián)系:17713286011