用tf.Session.run去运行opertions

tf.Session.run方法是tensorflow里去执行一个opertion或者对tensor求值的主要方式。你可以把一个或者多个opertaion或者tensor传递给session.run去执行。TensorFlow会执行这些operation和所有这个operation依赖的计算去得到结果。

session.run需要你指定一个fetch列表。一个fetch可以是tensor,operation,类tensor对象(variable)。这个fetch列表决定了session.run要返回的值。这些fetch决定了graph里哪些subgraph里需要被执行来得到结果。这个subgraph是fetch里指定的operation和这个operation依赖的operations。下边的例子演示了这一点:


x = tf.contant([[37.0,-23.0],[1.0,4.0]])
w = tf.Variable(tf.random_uniform([2,2]))
y = tf.matmul(x,w)
output = tf.nn.softmax(y)
init_op = w.initializer

with tf.Session() as sess:
    # 在w上运行initializer
    sess.run(init_op)
    # 对output进行求值,会返回一个ndarray作为结果
    print(sess.run(output))
    # 对y和output进行求值,注意y只会被求一次,它作为y的求值结果也会作为tf.nn.softmax的输入。返回同样都是ndarray
    y_val,output_val = sess.run([y,output])

tf.Session.run会接受一个可选的feeds dictionary。主要是用来完成从一个tensor(主要是placeholder)到值(主要是python的标量,list,ndarray)的映射。在执行时用值来替换tensor。比如:


# 定义一个placehoder,它期望的值是一个浮点型型有3个元素的向量。
x = tf.placeholder(tf.float32,shape=[3])
# 定义一个依赖这个placeholder的操作
y = tf.square(x)
with tf.Session() as sess:
    print(sess.run(y,{x:[1.0,2.0,3.0]})) #输出为[1.0,4.0,9.0]
    print(sess.run(y,{x:[0.0,0.0,5.0]})) #输出为[0.0,0.0,25.0]
        # 产生异常: `tf.errors.InvalidArgumentError`因为没有指定placeholder。
        sess.run(y)
        # 产生异常: `ValueError`, 因为shape不对。
        sess.run(y, {x: 37.0})

tf.Session.run也接受可选的options参数来指定参数。还有一个可选的run_metatdata 参数可以让你收集执行的参数。比如你可以用这些参数一起来记录执行的信息。


y = tf.matmul([[37.0,-23.0],[1.0,4.0]],tf.random_uniform([2,2]))
with tf.Session() as sess:
    # 给sess.run定义options
    options = tf.RunOptions()
    options.output_partition_graphs = True
    options.trace_level = tf.RunOptions.FULL_TRACE
   
    # 给返回的metadata定义一个容器
    metaData = tf.RunMetadata()
    sess.run(y,options = options,run_metadata=metadata)
   
    # 打印在每个硬件上执行的subgraph
    print(metadata.partition_graphs)
   
    # 打印每一个opration的执行时间
    print(metadata.step_stats)

可视化你的图


TensorFlow提供了可以让你理解你的graph的工具。graph visualizer是TensorBoard的一个模块。它可以在浏览器里可视化的输出你在代码里定义的graph。最简单的方式去创建一个可视化的图是创建一个tf.summary.FileWriter,并给他传一个tf.Graph


# 创建你的图
x = tf.constant([[37.0, -23.0], [1.0, 4.0]])
w = tf.Variable(tf.random_uniform([2, 2]))
y = tf.matmul(x, w)
# ...
loss = ...
train_op = tf.train.AdagradOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
  # `sess.graph` 是 `tf.Session`里对graph的引用对象.
  writer = tf.summary.FileWriter("/tmp/log/...", sess.graph)

  # 执行你的计算...
  for i in range(1000):
    sess.run(train_op)
    # ...

  writer.close()

注意:如果你用的是一个tf.estimator.Estimator对象。那么graph和任意的summaries都会被自动到放到你在创建estimator时指定的model_dir里。

你可以再TensorBoard里打开这个log。导航到Graph tab,可以看到高层次的你的图的结构。注意一个标准的tensorflow的graph –特别是用自动计算梯度的训练模型图– 一次看所有node太多了。这个graph的可视化工具用name scope把相关的操作封装到“super” 节点里。你可以点击橙色的“+”按钮去展开一个super节点去看里边的子图

多图编程


注意:当训练一个模型时,一种通用的做法是用一个graph去训练你的模型,用给另一个图去引用你的模型去验证你的模型的性能。很多情况下,验证的graph和训练的graph构造也不一样。比如droupout和batch normalization在不同的图里操作是不一样的。更进一步,默认情况下像tf.train.Server这样的工具类使用tf.Variable的名字(variable的名字是根据底层的tf.Operation命名的。)去识别一个保存的checkpoint的variable。你用不同的进程去执行不同的graph没有问题,你也可以在一个进程里执行多个图。这里我们要说的是第二种情况。

就像上边说的,TensorFlow提供了一个default graph,它会被隐式的传给在相同context下的所有API function。对于大多数程序,一个graph就够用了。但是Tensorflow也提供了操作default graph的方法。它可以帮你实现一些高级的应用。比如:
– 一个tf.Graph给所有tf.Operation定义了namespace。在单个graph里每个operation必须有一个唯一的名字。如果指定的名字已存在,tensorflow通过添加_1,_2..来让它唯一。通过创建多个graph,你可以更多的控制每个operation的名字。
– default的graph存储了关于每个Operation和Tensor的所有被设置的信息。如果你的程序创建了大量不连接的子图,最好用不同的tf.Graph去构建每个子图。这样不相关的状态可以被垃圾回收。

你可以创建一个不同的Graph作为default graph。可以使用tf.Graph.as_default context manager:


g_1 = tf.Graph()
with g_1.as_default():
    # 这里创建的operation会被加入到g_1
    c = tf.constant("node in g_1")
    # 这里创建的session 会返回g_1的operations
    sess_1 = tf.Session()
g_2 = tf.Graph()
with g_2.as_default():
    # 这里创建的operation会被加入到g_2
    d = tf.constant("Node in g_2")
# 创建一个session的时候你可以传入一个图。
# sess_2 会运行g_2里的操作
sess_2 = tf.Session(graph=g_2)

assert c.graph is g_1
assert sess_1.graph is g_1
assert d.graph is g_2
assert sess_2.graph is g_2

可以通过tf.get_default_graph来获取当前的default graph。这个方法返回了一个tf.Graph的对象。


# 打印在default graph里的所有操作
g = tf.get_default_graph()
print(g.get_operations())

发表评论

电子邮件地址不会被公开。 必填项已用*标注

%d 博主赞过: