The TensorFlow seq2seq API (`tf.contrib.seq2seq`) now includes scheduled sampling:
https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper
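The example below may also help. It covers the case where you want to apply scheduled sampling separately at each decoding step. For each row in the batch, a random draw decides between two next inputs: the embedding of a token sampled from the decoder output, or the row's existing next input.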
import tensorflow as tf
import numpy as np
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import bernoulli
# Toy sizes for this standalone example.
batch_size = 64
vocab_size = 50000
emb_dim = 128

# Stand-in tensors filled with uniform [0, 1) float64 noise (same dtype and
# distribution as np.random.rand, which the example originally used):
#   output           -- logits over the vocabulary, [batch_size, vocab_size]
#   base_next_inputs -- next decoder inputs kept for rows that are NOT
#                       sampled (presumably the teacher-forced inputs),
#                       [batch_size, emb_dim]
#   embedding        -- token embedding matrix, [vocab_size, emb_dim]
# Use an initializer rather than tf.constant(np.random.rand(...)). The
# constant form serializes ~77 MB of random float64 values into the
# GraphDef as Const nodes; an initializer generates the values only when
# the variable initializer op runs.
_uniform_init = tf.random_uniform_initializer(minval=0.0, maxval=1.0)
output = tf.get_variable('output', shape=[batch_size, vocab_size],
                         dtype=tf.float64, initializer=_uniform_init)
base_next_inputs = tf.get_variable('input', shape=[batch_size, emb_dim],
                                   dtype=tf.float64, initializer=_uniform_init)
embedding = tf.get_variable('embedding', shape=[vocab_size, emb_dim],
                            dtype=tf.float64, initializer=_uniform_init)
# Scheduled-sampling decision, made independently for each batch row.
# `select_sample` is a boolean vector of length batch_size. Each entry is
# True with probability 0.99, meaning "feed back a sampled token for this row".
select_sampler = bernoulli.Bernoulli(probs=0.99, dtype=tf.bool)
select_sample = select_sampler.sample(sample_shape=batch_size, seed=123)

# Candidate token ids: one draw per row from the categorical distribution
# defined by the `output` logits. Rows whose draw above was False get the
# sentinel id -1 instead, marking them as "keep the existing next input".
sample_id_sampler = categorical.Categorical(logits=output)
_drawn_ids = sample_id_sampler.sample(seed=123)
_sentinel_ids = gen_array_ops.fill([batch_size], -1)
sample_ids = array_ops.where(select_sample, _drawn_ids, _sentinel_ids)
# Build the next inputs row by row from two sources. This uses the
# gather/scatter pattern of the TF r1.5 seq2seq helper that this example
# is adapted from.
base_shape = array_ops.shape(base_next_inputs)

# Complementary row masks: ids > -1 were sampled, the -1 sentinel was not.
# tf.where on a rank-1 boolean returns [num_true, 1] int64 coordinates;
# cast them to int32 for gather_nd / scatter_nd.
is_sampling = sample_ids > -1
is_not_sampling = sample_ids <= -1
where_sampling = math_ops.cast(array_ops.where(is_sampling), tf.int32)
where_not_sampling = math_ops.cast(array_ops.where(is_not_sampling), tf.int32)

# Sampled rows: embed the sampled token ids and scatter them back into
# their original row positions (all other rows are left as zeros).
sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
sampled_next_inputs = tf.nn.embedding_lookup(embedding, sample_ids_sampling)
result1 = array_ops.scatter_nd(
    indices=where_sampling, updates=sampled_next_inputs, shape=base_shape)

# Remaining rows: keep their existing next inputs, scattered the same way.
inputs_not_sampling = array_ops.gather_nd(base_next_inputs, where_not_sampling)
result2 = array_ops.scatter_nd(
    indices=where_not_sampling, updates=inputs_not_sampling, shape=base_shape)

# Each row is written by exactly one of the two scatters (the other leaves
# it zero), so the sum holds the chosen input for every row.
result = result1 + result2
I adapted this example from the TensorFlow source code (`tensorflow/contrib/seq2seq/python/ops/helper.py`, r1.5): https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/contrib/seq2seq/python/ops/helper.py