加入星计划,您可以享受以下权益:

  • 创作内容快速变现
  • 行业影响力扩散
  • 作品版权保护
  • 300W+ 专业用户
  • 1.5W+ 优质创作者
  • 5000+ 长期合作伙伴
立即加入
  • 正文
    • 01、常见的模型格式
    • 02、建立模型并进行转换
    • 03、运行验证环节
  • 推荐器件
  • 相关推荐
  • 电子产业图谱
申请入驻 产业图谱

研发干货丨OK1052-C开发板运行 Tensorflow_lite模型

2021/04/01
713
阅读需 21 分钟
加入交流群
扫码加入
获取工程师必备礼包
参与热点资讯讨论

tensorFlowLite 是一款TensorFlow用于移动设备和嵌入式设备的轻量级解决方案。

我们知道TensorFlow可以在多个平台上运行,从机架式服务器到小型IoT设备。但是随着近年来机器学习模型的广泛使用,出现了在移动和嵌入式设备上部署它们的需求,而TensorFlowLite 允许设备端的机器学习模型的低延迟推断。

飞凌OK1052-C开发板以高性能处理器i.MXRT1052作为硬件躯体,再用tensorflow_lite武装软件大脑,也搭上了开往机器学习,边缘计算等朝阳行业领域的快车。

▲OK1052-C开发板接口图

 

OK1052-C的用户资料SDK中,已经有了关于tensorflow_lite使用的demo例程,但是这些例程使用的都是现成训练之后的实验模型,并不适用于我们实际的应用场景。我们在实际应用项目中,必然是需要使用适合本项目的模型,网上现成模型资源下载或者自己训练模型,然后搭载自己的应用程序,完成我们的应用项目需求。

由于模型的训练需要算力的支持,通常我们在计算机上进行模型的训练,我们将训练好的模型称为预训练模型。我们也可以在网上download预训练模型,然后通过格式转换转换为tflite格式的模型,也就是开发板可执行的模型格式;也可以自己搭建计算网络,通过训练之后,形成预训练模型,再转换为tflite格式,运行到开发板。

今天我们通过一个简单的例子,来介绍怎么建立计算网络和进行模型的转换。

01、常见的模型格式

我们常见的训练模型格式有.PB,cpkt,saveModel,H5等,若想将模型运用于tensor flowlite,需要转换为tflite格式,一般的转换过程是:

由上图可知,如果要将Checkpoints(cpkt)格式转换为tflite,需经过freeze_graph.py工具将cpkt格式模型转换为Frozen GraphDef(.pb)格式,然后再经过TFLite Converter转换工具转换为tflite。

02、建立模型并进行转换

现在通过例子介绍如何将模型转换为tflite格式,首先我们先建立两个网络流图,并分别生成.pb和cpkt格式的模型。

1)建立计算网络,并保存为.PB格式模型

# coding=UTF-8

import tensorflow as tf

import shutil

import os.path

from tensorflow.python.framework import graph_util

output_graph = "easy_model/add_model.pb"

#下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式

input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")

W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")

B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")

_y = (input_holder * W1) + B1

# predictions = tf.greater(_y, 50, name="predictions") #比50大返回true,否则返回false

predictions = tf.add(_y, 10,name="predictions") #做一个加法运算

init = tf.global_variables_initializer()

with tf.Session() as sess:

sess.run(init)

print ("predictions :", sess.run(predictions, feed_dict={input_holder: [10.0]}))

graph_def = tf.get_default_graph().as_graph_def() #得到当前的图的 GraphDef 部分,通过这个部分就可以完成重输入层到输出层的计算过程

output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定

sess,

graph_def,

["predictions"] #需要保存节点的名字

)

with tf.gfile.GFile(output_graph, "wb") as f: # 保存模型

f.write(output_graph_def.SerializeToString()) # 序列化输出

print("%d ops in the final graph." % len(output_graph_def.node))

print (predictions)

此计算网络只是一个简单的数学运算,不需要进行训练,该运算公式为:

predictions = (input_holder * W1) + B1 + 10

其中input_holder是网络的输入节点,predictions是网络的输出节点,W1,B1是两个变量,分别被赋值为5.0,,1.0

程序运行之后,会在easy_model/文件夹下生成add_model.pb模型文件。我们通过Netron软件(Netron是一个很方便的软件)来看一下add_model.pb模型文件中存储的网络图及参数信息:

 

2)建立网络图,并保存为.ckpt格式模型:

# coding=UTF-8 支持中文编码格式

import tensorflow as tf

import shutil

import os.path

MODEL_DIR = "easy_model"

MODEL_NAME = "model.ckpt"

if not tf.gfile.Exists(MODEL_DIR):

tf.gfile.MakeDirs(MODEL_DIR)

下面的过程你可以替换成CNN、RNN等你想做的训练过程,这里只是简单的一个计算公式

input_holder = tf.placeholder(tf.float32, shape=[1], name="input_holder")

W1 = tf.Variable(tf.constant(5.0, shape=[1]), name="W1")

B1 = tf.Variable(tf.constant(1.0, shape=[1]), name="B1")

_y = (input_holder * W1) + B1

#predictions = tf.greater(_y, 50, name="predictions")

predictions = tf.add(_y, 10,name="predictions")

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:

sess.run(init)

print ("predictions : ", sess.run(predictions, feed_dict={input_holder: [10.0]}))

saver.save(sess, os.path.join(MODEL_DIR, MODEL_NAME))

print("%d ops in the final graph." % len(tf.get_default_graph().as_graph_def().node))

for op in tf.get_default_graph().get_operations():

print (op.name, op.values())

 

此网络图同上一节一样是一个简单的数学运算。

不同的是,此程序最后保存为ckpt格式的模型文件:

checkpoint :记录目录下所有模型文件列表

ckpt.data :保存模型中每个变量的取值

ckpt.meta :保存整个计算网络图的结构

同样可通过Netron软件查看ckpt文件中存储的网络图结构:

 

3)将生成的cpkt格式模型文件转换为.pb文件:

import tensorflow as tf

from tensorflow.python.framework import graph_util

#from create_tf_record import *

resize_height = 100 # 指定图片高度

resize_width = 100 # 指定图片宽度

def freeze_graph(input_checkpoint, output_graph):

'''

:param input_checkpoint:

:param output_graph: PB 模型保存路径

:return:

'''

# 检查目录下ckpt文件状态是否可用

# checkpoint = tf.train.get_checkpoint_state(model_folder)

# 得ckpt文件路径

# input_checkpoint = checkpoint.model_checkpoint_path

# 指定输出的节点名称,该节点名称必须是元模型中存在的节点

output_node_names = "predictions"

saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

graph = tf.get_default_graph() # 获得默认的图

input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图

with tf.Session() as sess:

saver.restore(sess, input_checkpoint) # 恢复图并得到数据

# 模型持久化,将变量值固定

output_graph_def = graph_util.convert_variables_to_constants(

sess=sess,

# 等于:sess.graph_def

input_graph_def=input_graph_def,

# 如果有多个输出节点,以逗号隔开

output_node_names=output_node_names.split(","))

# 保存模型

with tf.gfile.GFile(output_graph, "wb") as f:

f.write(output_graph_def.SerializeToString()) # 序列化输出

# 得到当前图有几个操作节点

print("%d ops in the final graph." % len(output_graph_def.node))

input_checkpoint='easy_model/model.ckpt'

# 输出pb模型的路径

out_pb_path="easy_model/frozen_model.pb"

freeze_graph(input_checkpoint,out_pb_path)

这里需要输入输出节点名称output_node_names= "predictions",

程序运行之后在easy_model文件夹下生成了frozen_model.pb文件,我们通过Netron软件查看一下frozen_model.pb网络图,可以发现跟第一节中直接生成的.pb模型文件一样的:

cpkt文件格式将模型保存为4个文件,pb文件格式为一个。ckpt模型持久化方式将图结构与权重参数分开保存,多了模型更多的细节,适合模型训练阶段;而pb持久化方式完成了从输入到输出的前向传播,完成了端到端的形式,更适合离线使用。

 

4)最后,我们将.pb文件转换为.tflite:

我们运行此段代码:

function showSnackbar() {

var $snackbar = $('#snackbar');

$snackbar.addClass('show');

setTimeout(() => {

$snackbar.removeClass('show');

}, 3000);

}

注意这里需要写上输入输出节点名称,这个在构建网络模型时,已经定义。

运行之后,会报错:

module 'tensorflow' has no attribute 'lite'

意思是目前的tensorflow版本不支持lite类,没关系,我们重新安装1.14版本的tensorflow即可成功运行。最后生成tflite格式的模型文件:easy_model.lite。

03、运行验证环节

将转换后模型运行到OK1052-C开发板进行验证

首先需要将将easy_model.lite转换为二进制数组的.h文件easy_frozen_0.h,转换完成之后,将其放入OK1052-C用户资料的,SDK的middlewareeiqtensorflow-liteexampleslabel_image目录下,

打开SDK中boardsevkbimxrt1050eiq_examplestensorflow_lite_label_image下工程,在文件label_iamge.cpp中做如下修改:


 

#include "board.h"

#include "pin_mux.h"

#include "clock_config.h"

#include "fsl_debug_console.h"

#include

#include

#include

#include "timer.h"

#include "tensorflow/lite/kernels/register.h"

#include "tensorflow/lite/model.h"

#include "tensorflow/lite/optional_debug_tools.h"

#include "tensorflow/lite/string_util.h"

//#include "Sine_mode.h"

//#include "add_model.h"

#include "easy_frozen_0.h"

int inference_count = 0;

// This is a small number so that it's easy to read the logs

const int kInferencesPerCycle = 30;

const float kXrange = 2.f * 3.14159265359f;

#define LOG(x) std::cout

void RunInference()

{

std::unique_ptr model;

std::unique_ptr interpreter;

model = tflite::FlatBufferModel::BuildFromBuffer(mobilenet_model, mobilenet_model_len);

if (!model) {

LOG(FATAL) << "Failed to load modelrn";

exit(-1);

}

model->error_reporter();

tflite::ops::builtin::BuiltinOpResolver resolver;

tflite::InterpreterBuilder(*model, resolver)(&interpreter);

if (!interpreter) {

LOG(FATAL) << "Failed to construct interpreterrn";

exit(-1);

}

float input = interpreter->inputs()[0];

if (interpreter->AllocateTensors() != kTfLiteOk) {

LOG(FATAL) << "Failed to allocate tensors!rn";

}

while(true)

{

// Calculate an x value to feed into the model. We compare the current

// inference_count to the number of inferences per cycle to determine

// our position within the range of possible x values the model was

// trained on, and use this to calculate a value.

float position = static_cast(inference_count) /

static_cast(kInferencesPerCycle);

float x_val = position * kXrange;

float* input_tensor_data = interpreter->typed_tensor(input);

*input_tensor_data = x_val;

// Delay_time(1000);

// Run inference, and report any error

TfLiteStatus invoke_status = interpreter->Invoke();

if (invoke_status != kTfLiteOk)

{

LOG(FATAL) << "Failed to invoke tflite!rn";

return;

}

// Read the predicted y value from the model's output tensor

float* y_val = interpreter->typed_output_tensor(0);

PRINTF("rn x_value: %f, y_value: %f rn", x_val, y_val[0]);

//PRINTF("rn x_value: %d, y_value: %d rn", (int)x_val, (int)y_val[0]);

// Increment the inference_counter, and reset it if we have reached

// the total number per cycle

inference_count += 1;

if (inference_count >= kInferencesPerCycle) inference_count = 0;

}

}

/*

* @brief Application entry point.

*/

int main(void)

{

/* Init board hardware */

BOARD_ConfigMPU();

BOARD_InitPins();

BOARD_InitDEBUG_UARTPins();

BOARD_BootClockRUN();

BOARD_InitDebugConsole();

NVIC_SetPriorityGrouping(3);

InitTimer();

std::cout << "The hello_world demo of TensorFlow Lite modelrn";

RunInference();

std::flush(std::cout);

for (;;) {}

}

此工程运行之后打印信息如下:

可以看到,输入的值x_value通过模型计算之后得到y_value

使用的就是计算网络中的公式:

predictions = (input_holder * W1) + B1 + 10

由此可知,我们模型转换成功。

推荐器件

更多器件
器件型号 数量 器件厂商 器件描述 数据手册 ECAD模型 风险等级 参考价格 更多信息
NX3225SA-40.000MHZ-STD-CSR-1 1 Nihon Dempa Kogyo Co Ltd Parallel - Fundamental Quartz Crystal, 40MHz Nom,
暂无数据 查看
NC7WZ07P6X 1 Fairchild Semiconductor Corporation Buffer, LVC/LCX/Z Series, 2-Func, 1-Input, CMOS, PDSO6, 1.25 MM, EIAJ, SC-88, SC-70, 6 PIN
$0.28 查看
AFBR-720XPDZ 1 Foxconn Transceiver, 840nm Min, 860nm Max, 10000Mbps(Tx), 10000Mbps(Rx), LC Connector, Board/panel Mount, ROHS COMPLIANT PACKAGE-30
$164.34 查看

相关推荐

电子产业图谱

秉承专业态度,专注智能设备核心平台研发与制造,以技术研发创新为主导,以客户实用化,产品化为目标,把握嵌入式行业的前沿发展需求,利用核心技术为客户提供稳定、可靠、功能优异的高品质产品。合作联系:17713286011