Tensorflow可以使用训练好的模型对新的数据进行测试,有两种方法:第一种方法是调用模型和训练在同一个py文件中,中情况比较简单;第二种是训练过程和调用模型过程分别在两个py文件中。本文将讲解第二种方法。

模型的保存
tensorflow提供可保存训练模型的接口,使用起来也不是很难,直接上代码讲解:
#网络结构
w1 = tf.Variable(tf.truncated_normal([in_units, h2_units], stddev=0.1))
b1 = tf.Variable(tf.zeros([h2_units]))
y = tf.nn.softmax(tf.matmul(w1, x) + b1)
tf.add_to_collection('network-output', y)
x = tf.placeholder(tf.float32, [None, in_units], name='x')
y_ = tf.placeholder(tf.float32, [None, 10], name='y_')
#损失函数与优化函数
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(rate).minimize(cross_entropy)
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
saver.save(sess,"save/model.ckpt")
train_step.run({x: train_x, y_: train_y})