好了,我们现在有了可以将图片转化为Estimator需要的输入数据的方法,以及一个自定义的Estimator的方法,我们可以开始我们的模型训练了。

模型训练


chart_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir="E:\\char_3_Conv")
chart_classifier.train(input_fn=lambda:train_input_fn(images,labels),steps=200000)

创建自定义Estimator的时候,我们指定了我们定义的模型方法,还给出了模型路径。Estimator会自动把定期把log信息和模型参数存在这个路径下。这点非常有用,如果你中断了模型的训练后,下次再运行的时候,模型回先加载之前训练的模型参数。然后接着训练。不会从头开始。

模型验证


eval_results=chart_classifier.evaluate(input_fn=lambda:eval_input_fn(images_eval,labels_eval))
print(eval_results)

模型验证使用了我们的测试数据。

我在训练时发现经过20万次迭代,在验证集上,模型精度可以达到0.88左右。当然,你可以通过增加网络深度来尝试获取更佳的精度。

模型预测

下边我们就要用我们的模型来预测一下了。我自己写了4个字:

然后我们用我们的模型来预测一下:


import pickle
f=open('E:/data/hwdb/char_dict','br')
dict = pickle.load(f)
images_pred,lables_pred = read_data(predictDataPath)
index_to_char = {value:key for key,value in dict.items()}
for pred in chart_classifier.predict(input_fn=lambda:eval_input_fn(images_pred,lables_pred)):
    print(index_to_char[pred["classes"]])

因为我们存在char_dict文件里是汉字对应index,我们需要将index转回汉字,所以我们生成了一个index_to_char方法。还有,对于Estimator的predict方法,其实是不需要labels,但是为了重用前边的代码,我们还是传入labels,实际上并没有作用。
最终,我们得到的结果如下:


INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_task_id': 0, '_model_dir': 'E:\\char_3_Conv', '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000000006BD808D0>, '_save_summary_steps': 100, '_num_ps_replicas': 0, '_log_step_count_steps': 100, '_keep_checkpoint_max': 5, '_evaluation_master': '', '_keep_checkpoint_every_n_hours': 10000, '_is_chief': True, '_tf_random_seed': None, '_service': None, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_task_type': 'worker', '_session_config': None, '_master': '', '_num_worker_replicas': 1, '_global_id_in_cluster': 0}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from E:\char_3_Conv\model.ckpt-45933
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.



可以看出,我们的模型得到了准确的答案。

TensorBoard

在模型训练过程中,我们可以打开tensorboard来看模型的训练进度,包括当前的loss值,以及模型验证的准确度,还有当前的训练速度。
我的tensorboard在这个目录下:C:\ProgramData\Anaconda3\envs\tensorflow\Scripts
运行时,我们需要指定我们模型的文件目录


C:\ProgramData\Anaconda3\envs\tensorflow\Scripts&gt;tensorboard --logdir=E:\char_3_
Conv

然后你可以在你的浏览器里输入http://localhost:6006 来打开tensorboard
其中你可以在Graph里看到整个Graph的结构:

还有,就是可以看到loss,accuracy,以及global_step/sec的图:

1 对 “用TensorFlow实现一个手写汉字识别之三:训练,验证模型,以及用模型预测”的想法;

发表评论

电子邮件地址不会被公开。 必填项已用*标注

%d 博主赞过: