您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

使用来自Keras模型的TensorFlow计算图进行预测

使用来自Keras模型的TensorFlow计算图进行预测

@frankyjuang 给了我这个链接:

https://github.com/amir-abdi/keras_to_tensorflow

并将其与以下来源的代码结合起来:

https://github.com/metaflow-ai/blog/blob/master/tf-freeze/load.py

https://github.com/tensorflow/tensorflow/issues/675

我找到了一个既可以使用tf图进行预测、又可以计算雅可比矩阵(jacobian)的解决方案:

import tensorflow as tf
import numpy as np

# Create function to convert saved keras model to tensorflow graph
def convert_to_pb(weight_file, input_fld='', output_fld=''):
    """Freeze a saved Keras model into a binary TensorFlow ``.pb`` graph file.

    The model is loaded into the Keras backend session. Each exported output
    is wrapped in a ``tf.identity`` op named ``output_node<i>``. All variables
    are then folded into constants and the GraphDef is written to disk.

    Args:
        weight_file: file name of the saved Keras model (e.g. ``model.h5``),
            relative to ``input_fld``.
        input_fld: directory containing ``weight_file``.
        output_fld: directory the frozen ``.pb`` file is written to.

    Returns:
        Path of the written ``.pb`` file: ``output_fld`` joined with the file
        name. This is the same path that is printed and written by
        ``graph_io.write_graph``.
    """
    import os.path as osp
    from tensorflow.python.framework import graph_util
    from tensorflow.python.framework import graph_io
    from keras.models import load_model
    from keras import backend as K

    # Only the length of this list is used: it decides how many outputs are exported.
    output_node_names_of_input_network = ["pred0"]
    output_node_names_of_final_network = 'output_node'

    # Swap the model file's extension for .pb (model_file.h5 -> model_file.pb).
    # splitext handles any extension length, unlike slicing off two characters.
    output_graph_name = osp.splitext(weight_file)[0] + '.pb'
    weight_file_path = osp.join(input_fld, weight_file)

    net_model = load_model(weight_file_path)

    num_output = len(output_node_names_of_input_network)
    pred_node_names = [output_node_names_of_final_network + str(i)
                       for i in range(num_output)]

    # Add named identity ops so the outputs can be found by name after freezing.
    # NOTE(review): for a single-output Keras model, net_model.output is a single
    # tensor, so net_model.output[i] slices its first axis rather than selecting
    # the i-th model output; downstream code appears to depend on that — confirm.
    for i, node_name in enumerate(pred_node_names):
        tf.identity(net_model.output[i], name=node_name)

    sess = K.get_session()

    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), pred_node_names)
    graph_io.write_graph(constant_graph, output_fld, output_graph_name, as_text=False)

    # Build the returned path the same way the file location is built, so it is
    # correct even when output_fld has no trailing separator.
    output_path = osp.join(output_fld, output_graph_name)
    print('saved the constant graph (ready for inference) at: ', output_path)

    return output_path

调用:

# Freeze /model_dir/model_file.h5 into /model_dir/model_file.pb and keep its path.
tf_model_path = convert_to_pb(
    weight_file='model_file.h5',
    input_fld='/model_dir/',
    output_fld='/model_dir/',
)

创建一个函数, 将tf模型加载为计算图:

def load_graph(frozen_graph_filename):
    """Load a frozen GraphDef from disk into a fresh ``tf.Graph``.

    Every imported op is placed under the ``prefix/`` name scope.

    Args:
        frozen_graph_filename: path to a binary serialized GraphDef (``.pb``).

    Returns:
        ``(graph, input_name, output_name)``. ``input_name`` and
        ``output_name`` are the ``:0`` tensor names of the first and last
        operations in the imported graph.
        NOTE(review): treating the first/last op as the model input/output is
        a heuristic — verify it holds for the graph being loaded.
    """
    # Read the raw protobuf bytes, then deserialize them into a GraphDef.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        serialized = pb_file.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)

    # Import into a new, isolated graph rather than the global default one.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="prefix",
            op_dict=None,
            producer_op_list=None,
        )

    operations = graph.get_operations()
    first_tensor = operations[0].name + ':0'
    last_tensor = operations[-1].name + ':0'
    return graph, first_tensor, last_tensor

创建一个函数, 使用tf图进行模型预测:

def predict(model_path, input_data):
    """Run the frozen graph at ``model_path`` on every sample of ``input_data``.

    Each sample is fed separately as ``input_data[i:i+1]`` (a batch of one).

    Args:
        model_path: path to a frozen ``.pb`` graph, loaded via ``load_graph``.
        input_data: array whose first axis indexes samples.

    Returns:
        ``np.ndarray`` of shape ``(input_data.shape[0], num_outputs)``.
        ``num_outputs`` is the static first dimension of the graph's output
        tensor.
    """
    # load tf graph
    tf_model, tf_input, tf_output = load_graph(model_path)

    # Look up the model's input and output tensors by name.
    x = tf_model.get_tensor_by_name(tf_input)
    y = tf_model.get_tensor_by_name(tf_output)

    # Number of model outputs
    num_outputs = y.shape.as_list()[0]
    predictions = np.zeros((input_data.shape[0], num_outputs))

    # One session serves every sample: the graph does not change between runs,
    # so opening a new session per sample only adds overhead.
    with tf.Session(graph=tf_model) as sess:
        for i in range(input_data.shape[0]):
            # Session.run's keyword is ``feed_dict``; ``Feed_dict`` raises TypeError.
            predictions[i] = sess.run(y, feed_dict={x: input_data[i:i + 1]})

    return predictions

作出预测:

# Predict every sample in test_data (assumed to be defined by the reader) with the frozen graph.
tf_predictions = predict(model_path=tf_model_path, input_data=test_data)

计算雅可比矩阵的函数:

def compute_jacobian(model_path, input_data):
    """Compute d(output)/d(input) of the frozen graph for each sample.

    Args:
        model_path: path to a frozen ``.pb`` graph, loaded via ``load_graph``.
        input_data: array whose first axis indexes samples; each sample is
            fed separately as ``input_data[i:i+1]``.

    Returns:
        ``np.ndarray`` of shape
        ``(num_outputs, input_data.shape[0], input_data.shape[1])``. Entry
        ``[k, i, :]`` is the gradient of output component ``k`` with respect
        to sample ``i``.
    """
    tf_model, tf_input, tf_output = load_graph(model_path)

    x = tf_model.get_tensor_by_name(tf_input)
    y = tf_model.get_tensor_by_name(tf_output)
    num_outputs = y.shape.as_list()[0]
    jacobian = np.zeros((num_outputs, input_data.shape[0], input_data.shape[1]))

    with tf.Session(graph=tf_model) as sess:
        # Build one gradient op per output component ONCE. Calling tf.gradients
        # inside the sample loop would add new ops to the graph every iteration.
        grad_ops = [tf.gradients(y_, x)[0] for y_ in tf.unstack(y)]
        for i in range(input_data.shape[0]):
            # Session.run's keyword is ``feed_dict``; ``Feed_dict`` raises TypeError.
            jac_temp = np.asarray(sess.run(grad_ops, feed_dict={x: input_data[i:i + 1]}))
            # NOTE(review): the trailing [..., 0] assumes x is rank 3 with a last
            # axis of size 1, e.g. (1, d, 1) per sample — confirm for your model.
            jacobian[:, i:i + 1, :] = jac_temp[:, :, :, 0]
    return jacobian

计算雅可比矩阵:

# Jacobian of every model output with respect to each sample of test_data.
jacobians = compute_jacobian(model_path=tf_model_path, input_data=test_data)
其他 2022/1/1 18:33:17 有555人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶