TensorFlow:tf.Session函数

2018-01-22 10:37 更新

tf.Session 函数

Session 类

定义在:tensorflow/python/client/session.py.

请参阅指南:运行图>会话管理

用于运行TensorFlow操作的类.

一个Session对象封装了Operation执行对象的环境,并对Tensor对象进行计算.例如:

# Build a graph.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b

# Launch the graph in a session.
sess = tf.Session()

# Evaluate the tensor `c`.
print(sess.run(c))

session可能拥有的资源,如:tf.Variable,tf.QueueBase和tf.ReaderBase.不再需要时释放这些资源是非常重要的.为此,请在session中调用tf.Session.close方法,或使用session作为上下文管理器.以下两个例子是等价的:

# Using the `close()` method.
sess = tf.Session()
sess.run(...)
sess.close()

# Using the context manager.
with tf.Session() as sess:
  sess.run(...)

ConfigProto协议缓存公开了用于session的各种配置选项.例如,要创建为设备放置使用软约束的session,并记录生成的放置决策,请按如下方式创建session:

# Launch the graph in a session that allows soft device placement and
# logs the placement decisions.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                        log_device_placement=True))

Session属性

  • graph
    本次session上发布的图表.
  • graph_def
    底层TensorFlow图形的可序列化版本.
    • 函数返回:
      • 包含底层TensorFlow图表中所有操作的节点的graph_pb2.GraphDef原型.
  • sess_str

Session 方法

__init__

__init__(
    target='',
    graph=None,
    config=None
)

创建一个新的TensorFlow session.

如果在构建session时没有指定graph参数,则将在session中启动默认关系图.如果使用多个图(在同一个过程中使用tf.Graph()创建,则必须为每个图使用不同的sessio,但是每个图都可以用于多个sessio中,在这种情况下,将图形显式地传递给sessio构造函数通常更清晰.

方法参数

  • target:(可选)要连接到的执行引擎.默认使用进程内引擎.有关更多示例,请参阅“Distributed TensorFlow”.
  • graph:(可选)将被启动的Graph(如上所述).
  • config:(可选)具有session配置选项的ConfigProto协议缓冲区.

__enter__

__enter__()

__exit__

__exit__(
    exec_type,
    exec_value,
    exec_tb
)

as_default

as_default()

返回使该对象成为默认session的上下文管理器.

与with关键字一起使用来指定在此session中调用tf.Operation.run或tf.Tensor.eval应执行的操作.

c = tf.constant(..)
sess = tf.Session()

with sess.as_default():
  assert tf.get_default_session() is sess
  print(c.eval())

要获取当前的默认session,请使用tf.get_default_session.

注意:退出上下文时,as_default上下文管理器不会关闭session,并且必须显式关闭session.

c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
  print(c.eval())
# ...
with sess.as_default():
  print(c.eval())

sess.close()

或者,您可以使用tf.Session():创建会在退出上下文时自动关闭的session,包括未捕获的异常发生时.

注意,默认session是当前线程的一个属性.如果你创建一个新的线程,并希望在该线程中使用默认的session,则必须明确地添加一个sess.as_default():到该线程的函数.

注意,输入一个sess.as_default():块不会影响当前的默认图形.如果您正在使用多个图表,并且与其sess.graph值不同,则tf.get_default_graph必须明确地输入一个带有sess.graph.as_default():块来创建sess.graph默认图形.

as_default()方法返回:

使用此session作为默认session的上下文管理器.

close

close()

关闭这个session.

调用此方法可释放与session关联的所有资源.

可能引发的异常

  • tf.errors.OpError:如果在关闭TensorFlow session时发生错误,则会有一个子类.

list_devices

list_devices()

列出此session中的可用设备.

devices = sess.list_devices()
for d in devices:
  print(d.name)

列表中的每个元素都具有以下属性:

  • name:具有设备全名的字符串.例如:/job:worker/replica:0/task:3/device:CPU:0
  • device_type:设备的类型(例如CPU,GPU,TPU) 
  • memory_limit:存储设备上可用的最大内存量.注意:取决于设备,可用内存可能会大大减少.

可能引发的异常

  • tf.errors.OpError:如果遇到错误(例如session处于无效状态,或发生网络错误).

list_devices()方法返回:

list_devices()方法将返回session中的设备列表.

make_callable

make_callable(
    fetches,
    feed_list=None,
    accept_options=False
)

返回运行特定步骤的Python可调用对象.

返回的可调用将采取 len (feed_list) 参数,其类型必须是feed_list各自元素的兼容feed值.例如,如果feed_list的元素i是一个tf.Tensor,则返回的可调用的第 i 参数必须是一个 numpy 的 ndarray(或可转化成ndarray的东西)具有匹配元素类型和形状.请参阅tf.Session.run允许的Feed键和值类型的详细信息.

返回的可调用将具有与tf.Session.run(fetches, ...).例如,如果fetches是tf.Tensor ,则可调用将返回一个numpy的ndarray; 如果fetches是一个tf.Operation,它会返回None.

方法参数

  • fetches:要获取的值或值列表.请参阅tf.Session.run允许的获取类型的详细信息.
  • feed_list:(可选)一个feed_dict键列表.请参阅tf.Session.run允许的Feed键类型的详细信息.
  • accept_options:(可选)如果为True,则返回的Callable将是能够接受tf.RunOptions和tf.RunMetadata可选关键字参数options,并且run_metadata分别使用与tf.Session.run相同的语法和语义,这对于某些使用情况很有用(分析和调试),但会导致可测量放缓的Callable的表现.默认为False.

方法返回

一个函数调用将执行由feed_list定义的步骤时,并在此会话中读取的函数.

可能引发的异常

  • TypeError:如果fetches或feed_list不能被解释为tf.Session.run的参数.

partial_run

partial_run(
    handle,
    fetches,
    feed_dict=None
)

通过更多的feed和fetche继续执行.

这是实验性的,可能会有变化.

要使用部分执行,用户首先调用partial_run_setup(),然后是一个序列partial_run().partial_run_setup指定将在随后的partial_run调用中使用的提要和提取列表.

可选feed_dict参数允许调用者覆盖图中张量的值.请参阅run()以获取更多信息.

下面是一个简单的例子:

a = array_ops.placeholder(dtypes.float32, shape=[])
b = array_ops.placeholder(dtypes.float32, shape=[])
c = array_ops.placeholder(dtypes.float32, shape=[])
r1 = math_ops.add(a, b)
r2 = math_ops.multiply(r1, c)

h = sess.partial_run_setup([r1, r2], [a, b, c])
res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
res = sess.partial_run(h, r2, feed_dict={c: res})

方法参数

  • handle:部分运行序列的处理器.
  • fetches:单个图形元素,图形元素列表或其值为图形元素或图形元素列表的字典(请参阅“run文档”).
  • feed_dict:将图表元素映射到值的字典(如上所述).

方法返回:

可以是单个值,如果fetches是单个图元素,或者列表值,如果fetches是列表,或者是具有与字典相同的键fetches的字典(请参阅“run文档”).

方法可能引发的异常

  • tf.errors.OpError:其中一个子类出错.

partial_run_setup

partial_run_setup(
    fetches,
    feeds=None
)

为部分运行设置一个带有feed和fetche的图形.

这是实验性的,可能会有变化.

请注意,与运行相反,feeds只能指定图形元素.张量将由随后的partial_run调用提供.

方法参数

  • fetches:一个图形元素,或者一个图元素列表.
  • feeds:一个图形元素,或者一个图元素列表.

方法返回:

局部运行的处理器.

可能引发的异常

  • RuntimeError:如果这Session是无效状态(例如已经关闭).
  • TypeError:如果fetches或者feed_dict键的类型不合适.
  • tf.errors.OpError:如果发生TensorFlow错误,或者它的一个子类.

reset

@staticmethod
reset(
    target,
    containers=None,
    config=None
)

在target上重置资源容器,并关闭所有连接的会话.

资源容器分布在同一个群集target中的所有工作人员.target重置资源容器时,与该容器关联的资源将被清除.尤其是,容器中的所有变量都将变得不确定:它们将失去其值和形状.

注意:(i)reset()目前仅用于分布式会话.(ii)任何名为target的主的session将被关闭.

如果没有提供资源容器,则所有的容器都被重置.

方法参数

  • target:连接到的执行引擎.
  • containers:资源容器名称字符串的列表,如果所有容器都将被重置,则为None.
  • config:(可选)具有配置选项的协议缓冲区.

可能引发的异常

  • tf.errors.OpError:或者如果在重置容器时发生错误,它的一个子类.

run

run(
    fetches,
    feed_dict=None,
    options=None,
    run_metadata=None
)

在fetches中运行操作和计算张量.

此方法运行一个TensorFlow计算的一个“步骤”,通过运行所需的图形片段来执行每个Operation和计算fetches中的每个Tensor,用 feed_dict 中的值替换相应的输入值.

所述fetches参数可以是一个单一的图形元素,或任意嵌套列表、元组、namedtuple、字典、或含有它的叶子图表元素OrderedDict.图形元素可以是以下类型之一:

  • 一个tf.Operation.相应的取值将会是None.
  • 一个tf.Tensor.相应的取值将是一个包含该张量值的numpy ndarray.
  • 一个tf.SparseTensor.相应的取值将是一个tf.SparseTensorValue包含稀疏张量的值.
  • 一个get_tensor_handle操作.相应的取值将是包含该张量句柄的numpy ndarray.
  • A string是图中张量或操作的名称.

run()返回的值具有与fetches参数相同的形状,叶子由TensorFlow返回的相应值替换.

示例:

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
# 'fetches' can be a singleton
v = session.run(a)
# v is the numpy array [10, 20]
# 'fetches' can be a list.
v = session.run([a, b])
# v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
# 1-D array [1.0, 2.0]
# 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
MyData = collections.namedtuple('MyData', ['a', 'b'])
v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
# v is a dict with
# v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
# 'b' (the numpy array [1.0, 2.0])
# v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
# [10, 20]

可选的feed_dict参数允许调用者在关系图中覆盖张量的值.feed_dict 中的每个键都可以是以下类型之一:

  • 如果键是a tf.Tensor,则值可以是可以转换为与dtype张量相同的Python标量,字符串,列表或numpy ndarray .此外,如果键是a tf.placeholder,则将检查值的形状是否与占位符兼容.
  • 如果键是a tf.SparseTensor,则值应该是a tf.SparseTensorValue.
  • 如果键是Tensors或SparseTensors 的嵌套元组,则该值应该是一个嵌套元组,其结构与映射到上面相应的值相同.

feed_dict 中的每个值必须可转换为相应键的 dtype 的 numpy 数组.

可选options参数需要一个[ RunOptions] 原型.这些选项允许控制此特定步骤的行为(例如,启用跟踪).

可选run_metadata参数需要一个[ RunMetadata] 原型.在适当的时候,这个步骤的非张量输出将被收集在那里.例如,当用户在options打开跟踪时,配置文件信息将被收集到该参数中并传回.

方法参数

  • fetches:单个图形元素,图形元素列表或其值为图元素或图元素列表(如上所述)的字典.
  • feed_dict:将图表元素映射到值的字典(如上所述).
  • options:一个[ RunOptions]协议缓冲区
  • run_metadata:一个[ RunMetadata]协议缓冲区

方法返回:

单个值如果fetches是单个图元素,或者值列表if fetches是列表,或者具有与fetches字典(如上所述)相同的关键字的字典.

可能发生的异常

  • RuntimeError:如果该Session是无效状态(例如已经关闭).
  • TypeError:如果fetches或者feed_dict键的类型不合适.
  • ValueError:如果fetches或者feed_dict键无效或者引用Tensor不存在的键.


以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号

pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy