Scheduled sampling in TensorFlow

长情又很酷 2020-12-31 19:13

The latest TensorFlow seq2seq API now includes scheduled sampling:

https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTraini

3 answers
  •  耶瑟儿~
    2020-12-31 19:49

    This might also help. It covers the case where you want to apply scheduled sampling separately at each decoding step.

    # TensorFlow 1.x graph-mode code (tf.get_variable is
    # tf.compat.v1.get_variable in TF 2).
    import tensorflow as tf
    import numpy as np
    from tensorflow.python.ops import array_ops
    from tensorflow.python.ops import gen_array_ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.ops.distributions import categorical
    from tensorflow.python.ops.distributions import bernoulli

    batch_size = 64
    vocab_size = 50000
    emb_dim = 128

    # Random stand-ins for the decoder logits, the ground-truth next inputs,
    # and the embedding matrix.
    output = tf.get_variable(
        'output',
        initializer=tf.constant(np.random.rand(batch_size, vocab_size)))
    base_next_inputs = tf.get_variable(
        'input',
        initializer=tf.constant(np.random.rand(batch_size, emb_dim)))
    embedding = tf.get_variable(
        'embedding',
        initializer=tf.constant(np.random.rand(vocab_size, emb_dim)))

    # For each example, decide whether to feed back a sampled token
    # (True with probability 0.99).
    select_sampler = bernoulli.Bernoulli(probs=0.99, dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=batch_size, seed=123)

    # Sampled token ids where select_sample is True, -1 everywhere else.
    sample_id_sampler = categorical.Categorical(logits=output)
    sample_ids = array_ops.where(
        select_sample,
        sample_id_sampler.sample(seed=123),
        gen_array_ops.fill([batch_size], -1))

    # Batch indices of the examples that do / do not use a sampled token.
    where_sampling = math_ops.cast(
        array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)

    # Gather each group, embedding only the ids that were actually sampled.
    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
    inputs_not_sampling = array_ops.gather_nd(
        base_next_inputs, where_not_sampling)
    sampled_next_inputs = tf.nn.embedding_lookup(
        embedding, sample_ids_sampling)

    # Scatter both groups back into batch-shaped tensors. The two sets of
    # rows are disjoint, so adding the results merges them.
    base_shape = array_ops.shape(base_next_inputs)
    result1 = array_ops.scatter_nd(
        indices=where_sampling,
        updates=sampled_next_inputs, shape=base_shape)
    result2 = array_ops.scatter_nd(
        indices=where_not_sampling,
        updates=inputs_not_sampling, shape=base_shape)
    result = result1 + result2
    

    I based this example on the TensorFlow seq2seq helper code: https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/contrib/seq2seq/python/ops/helper.py
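
If the gather/scatter bookkeeping is hard to follow, the same per-step mixing can be sketched in plain NumPy. This is only an illustration, not the TensorFlow helper's implementation: the function name `mix_next_inputs`, its arguments, and the small shapes are hypothetical, and the categorical draw uses the Gumbel-max trick instead of a TensorFlow distribution.

```python
import numpy as np


def mix_next_inputs(logits, base_next_inputs, embedding, sampling_prob, rng):
    """Per-example scheduled sampling for one decoding step (NumPy sketch).

    logits:           [batch, vocab] decoder outputs for the current step
    base_next_inputs: [batch, emb_dim] ground-truth inputs for the next step
    embedding:        [vocab, emb_dim] embedding matrix
    sampling_prob:    probability of feeding back a sampled token
    rng:              a numpy.random.Generator
    """
    batch_size = logits.shape[0]

    # Bernoulli draw per example: True means "use the model's own sample".
    use_sample = rng.random(batch_size) < sampling_prob

    # Categorical draw per example from softmax(logits) via Gumbel-max.
    gumbel = rng.gumbel(size=logits.shape)
    sample_ids = np.argmax(logits + gumbel, axis=1)

    # Embed the sampled ids, then choose row by row between the two sources.
    sampled_inputs = embedding[sample_ids]
    next_inputs = np.where(use_sample[:, None], sampled_inputs, base_next_inputs)
    return next_inputs, use_sample, sample_ids


# Hypothetical small shapes, just to exercise the function.
rng = np.random.default_rng(123)
batch_size, vocab_size, emb_dim = 4, 10, 3
logits = rng.random((batch_size, vocab_size))
base = rng.random((batch_size, emb_dim))
emb = rng.random((vocab_size, emb_dim))
next_inputs, use_sample, sample_ids = mix_next_inputs(logits, base, emb, 0.5, rng)
```

Unlike the TensorFlow version, this sketch embeds every sampled id and discards the unused rows afterwards. The gather/scatter approach in the answer avoids those extra lookups by embedding only the rows that were actually selected.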
