Autograph

Table of Contents

1. Autograph

1.1. 先编译成静态图再执行

import tensorflow as tf
import numpy as np 

@tf.function(autograph=True)
def myadd(a,b):
    for i in tf.range(3):
        tf.print(i)
        print("tracing")
    return a + b
print("=== first")
print(myadd(tf.constant("hello"), tf.constant("world")))
print("=== second")
print(myadd(tf.constant("hello"), tf.constant("world")))

第一次执行时需要先编译成 static graph, 由于 tf.range 会被编译为 while tensor (见下文), 所以 "tracing" 只输出了一次.

第二次执行时不需要编译, 直接执行之前生成的 static graph, 所以不会再输出 "tracing", 因为 python 的 print 并不会转换为 static graph 中的某个节点, 但 tf.print 会以节点的形式存在于 graph 中, 所以每次都会打印出 0,1,2

= first tracing 0 1 2 tf.Tensor(b'helloworld', shape=(), dtype=string) = second 0 1 2 tf.Tensor(b'helloworld', shape=(), dtype=string)

1.2. 如何编译

编译的过程是: 把函数做为普通的 python 函数执行一次, 执行过程中碰到 tf.xxx 则生成对应的 tensor 嵌入到 static graph 中

1.2.1. tf.range 对应 while tensor

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/tf/")
tf.summary.trace_on(graph=True)

@tf.function(autograph=True)
def myadd(a,b):
    for i in tf.range(3):
        tf.print(i)
        print("tracing")
    return a + b

myadd(tf.constant("hello"), tf.constant("world"))

with writer.as_default():
    tf.summary.trace_export(name = "autograph",step=0)

autograph 把 myadd 中的的 tf.xxx 转换为相应的 node, tf.range 对应 graph 中的 "while_loop" tensor

tracing 0 1 2

autograph.png

1.2.2. python range

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/tf/")
tf.summary.trace_on(graph=True)

@tf.function(autograph=True)
def myadd(a,b):
    if True:
        for i in range(3):
            tf.print(i)
            print("tracing")
    return a + b

myadd(tf.constant("hello"), tf.constant("world"))

with writer.as_default():
    tf.summary.trace_export(name = "autograph",step=0)


这里用 range 代替了 tf.range, 导致 graph 中生成了 3 个 PrintV2 tensor, 而不是一个 while_loop tensor

tracing tracing tracing 0 1 2

autograph2.png

1.3. 限制

  1. python 代码 (而非 tf api) 只在编译时才会起作用, 导致一些意料之外的结果, 所以应尽可能使用 tf api.
  2. 若参数包含非 tensor 类型, 则每次都调用都会导致重新编译, 因为参数可能导致不同的控制流, 而这些控制流的变化无法嵌入到 graph 中
data = []

@tf.function(autograph=True)
def myadd(a,b, x):
    data.append(x)
    return a + b

myadd(tf.constant("hello"), tf.constant("world"),1)
print(data)
myadd(tf.constant("hello"), tf.constant("world"),2)
print(data)

[1] [1, 2]

Author: [email protected]
Date: 2020-07-23 Thu 00:00
Last updated: 2022-01-03 Mon 10:52

知识共享许可协议