Tensorflow: 如何在 Estimator API 中初始化嵌入层?

创建于 2018-01-12  ·  3评论  ·  资料来源: tensorflow/tensorflow

我正在尝试在 tensorflow 模型中使用现有的嵌入,嵌入的大小大于 2Gb,这使得我最初尝试这样做不成功:

embedding_var = tf.get_variable(
        "embeddings", 
        shape=GLOVE_MATRIX.shape, 
        initializer=tf.constant_initializer(np.array(GLOVE_MATRIX))
)

这给了我这个错误:

Cannot create a tensor proto whose content is larger than 2GB.

我正在使用基于 Estimator API 的 AWS SageMaker,并且会话中图形的实际运行发生在幕后,因此我不确定如何初始化一些占位符以进行嵌入。 如果有人能够分享如何在 EstimatorAPI 方面进行此类初始化的方式,将会很有帮助。


请前往 Stack Overflow 寻求帮助和支持:

https://stackoverflow.com/questions/tagged/tensorflow

如果您打开 GitHub 问题,这是我们的政策:

  1. 它必须是错误或功能请求。
  2. 必须填写以下表格。
  3. 这不应该是TensorBoard的问题。 那些去这里

这就是我们制定


系统信息

  • 我是否编写了自定义代码(而不是使用 TensorFlow 中提供的股票示例脚本)
  • 操作系统平台和发行版(例如,Linux Ubuntu 16.04)
  • 从(源代码或二进制文件)安装的 TensorFlow
  • TensorFlow 版本(使用下面的命令)
  • 蟒蛇版本
  • Bazel 版本(如果从源代码编译)
  • GCC /编译器版本(如果从源代码编译)
  • CUDA / cuDNN版本
  • GPU 型号和内存
  • 重现的确切命令

您可以使用我们的环境捕获脚本收集其中的一些信息:

https://github.com/tensorflow/tensorflow/tree/master/tools/tf_env_collect.sh

您可以使用以下命令获取 TensorFlow 版本

python -c“将tensorflow导入为tf;打印(tf.GIT_VERSION,tf.VERSION)”

描述问题

在这里清楚地描述问题。 请务必在此处说明为什么它是 TensorFlow 中的错误或功能请求。

源代码/日志

包括任何有助于诊断问题的日志或源代码。 如果包括回溯,请包括完整的回溯。 应附上大型日志和文件。 尝试提供一个可重现的测试用例,它是生成问题所需的最低限度。

awaiting response bug

最有用的评论

看起来用嵌入初始化变量的正确方法是使用tf.train.Scaffold 。 这是有关stackoverflow 的更多信息

所有3条评论

我认为这通常是“发送到 StackOverflow”(下面附有标准响应)类型的问题,但 2GB 的限制似乎在错误或功能请求的范围内。

@martinwicke @ispirmustafa 有什么建议吗?

这个问题最好在StackOverflow 上问,因为它不是错误或功能请求。 还有一个更大的社区在那里阅读问题。 谢谢!

我认为这与图形大小限制有关。 使用 constant_initializer 将 GLOVE_MATRIX 嵌入到图形中,从而增加了图形的大小。
你能尝试使用非常量初始化程序吗?

看起来用嵌入初始化变量的正确方法是使用tf.train.Scaffold 。 这是有关stackoverflow 的更多信息

此页面是否有帮助?
0 / 5 - 0 等级