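Counting Network Parameters and FLOPs in TensorFlow
Author: Hangzhou Dianzi University (杭州电子科技大学) - School of Automation - Intelligent Systems and Robotics Research Center - Jolen Xie
First, import TensorFlow. The code below uses the TensorFlow 1.x graph API (tf.trainable_variables, tf.get_default_graph, tf.profiler).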
import tensorflow as tf
1. Count parameters
def count_param():  # count the network's trainable parameters
    total_parameters = 0
    for v in tf.trainable_variables():
        shape = v.get_shape()
        variable_parameters = 1
        for dim in shape:  # elements in one variable = product of all its dimensions
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('Total network parameters:', total_parameters)
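The counting rule in count_param is simply "multiply the dimensions of each variable's shape, then sum over all variables". The plain-Python sketch below applies the same rule to a list of shapes, so it can be checked without building a graph. The function name count_params_from_shapes and the 784-256-10 MLP shapes are hypothetical, chosen only for illustration.

```python
from math import prod


def count_params_from_shapes(shapes):
    """Sum the element counts of all variable shapes (same rule as count_param)."""
    total = 0
    for shape in shapes:
        # Elements in one variable = product of its dimensions; prod(()) == 1 for a scalar.
        total += prod(shape)
    return total


# Hypothetical 784-256-10 MLP: two weight matrices and two bias vectors.
mlp_shapes = [(784, 256), (256,), (256, 10), (10,)]
print(count_params_from_shapes(mlp_shapes))  # 203530
```

Here 784*256 + 256 + 256*10 + 10 = 203530; a scalar variable (empty shape) contributes 1, just as the inner loop in count_param leaves variable_parameters at 1.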
2. Count FLOPs (floating-point operations)
def count_flops(graph):
    # Profile the graph and sum the float-operation counts of its ops
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    print('FLOPs: {}'.format(flops.total_float_ops))
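total_float_ops is a sum of per-op statistics that TensorFlow registers for each op type. As far as I know, the TF1 Conv2D entry counts 2 FLOPs (one multiply plus one add) per multiply-accumulate, i.e. output elements × kernel height × kernel width × input channels × 2; a bias add is a separate op and is counted separately. The plain-Python sketch below reproduces that formula, assuming NHWC layout; the function name conv2d_flops and the layer sizes are hypothetical.

```python
from math import prod


def conv2d_flops(output_shape, filter_shape):
    """FLOPs of one Conv2D op, counting 2 per multiply-accumulate.

    output_shape: (batch, out_h, out_w, out_channels)   # NHWC layout
    filter_shape: (filter_h, filter_w, in_channels, out_channels)
    """
    filter_h, filter_w, in_channels, _ = filter_shape
    output_count = prod(output_shape)  # number of output elements
    return 2 * output_count * filter_h * filter_w * in_channels


# Hypothetical layer: 1x32x32x3 input, 16 filters of size 3x3, stride 1, SAME padding,
# so the output shape is 1x32x32x16.
print(conv2d_flops((1, 32, 32, 16), (3, 3, 3, 16)))  # 884736
```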
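3. Compute both at once
def stats_graph(graph):
    # Total floating-point operations in the graph
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    # Total number of trainable parameters in the graph
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
Usage of the combined function (call it after the model has been built in the default graph):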
...
graph = tf.get_default_graph()
stats_graph(graph)
...
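When calling stats_graph, keep in mind that the FLOPs figure is computed from static tensor shapes. It therefore scales with the batch dimension, and, to my understanding, the TF1 profiler skips ops whose shapes are not fully defined (for example, a placeholder declared with batch size None). Building the graph with a batch size of 1 gives a per-sample figure. The plain-Python sketch below illustrates this for a hypothetical dense layer; the name dense_flops is illustrative only, and it counts 2 FLOPs per multiply-add, returning None to mimic a skipped op.

```python
from typing import Optional, Sequence


def dense_flops(input_shape: Sequence[Optional[int]], units: int) -> Optional[int]:
    """FLOPs of x @ W for x of shape (batch, k) and W of shape (k, units).

    Returns None if any input dimension is unknown, mimicking an op
    that the profiler cannot count because its shape is incomplete.
    """
    if any(d is None for d in input_shape):
        return None
    batch, k = input_shape
    return 2 * batch * k * units  # one multiply + one add per term


print(dense_flops((None, 784), 256))  # None: unknown batch size
print(dense_flops((1, 784), 256))     # 401408 (per sample)
print(dense_flops((32, 784), 256))    # 12845056 (32 times larger)
```

The parameter count, by contrast, does not depend on the batch size.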