您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

在TensorFlow中使用预训练的单词嵌入(word2vec或Glove)

在TensorFlow中使用预训练的单词嵌入(word2vec或Glove)

您可以通过多种方式在TensorFlow中使用预训练的嵌入。假设您将NemPy数组嵌入到embedding具有vocab_size行和embedding_dim列的NumPy数组中,并且想要创建一个W可用于调用的张量tf.nn.embedding_lookup()

W = tf.constant(embedding, name="W")

这是最简单的方法,但是由于a的值tf.constant()多次存储在内存中,因此内存使用效率不高。由于embedding可能很大,因此只应将这种方法用于玩具示例。

创建W为a,tf.Variable并通过NumPy数组对其进行初始化tf.placeholder()

W = tf.Variable(tf.constant(0.0, shape=[vocab_size, embedding_dim]),
            trainable=False, name="W")

embedding_placeholder = tf.placeholder(tf.float32, [vocab_size, embedding_dim]) embedding_init = W.assign(embedding_placeholder)

sess = tf.Session()

sess.run(embedding_init, Feed_dict={embedding_placeholder: embedding})

这样可以避免embedding在图表中存储的副本,但确实需要足够的内存才能一次在内存中保留矩阵的两个副本(一个用于NumPy数组,一个用于tf.Variable)。请注意,我假设您想在训练期间保持嵌入矩阵不变,因此W是使用创建的trainable=False

如果将嵌入训练为另一个TensorFlow模型的一部分,则可以使用tf.train.Saver从另一个模型的检查点文件中加载值。这意味着嵌入矩阵可以完全绕过Python。W按照选项2创建,然后执行以下操作:

W = tf.Variable(...)

embedding_saver = tf.train.Saver({“name_of_variable_in_other_model”: W})

sess = tf.Session() embedding_saver.restore(sess, “checkpoint_filename.ckpt”)

其他 2022/1/1 18:25:41 有456人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶