Understand textsum and seq2seq with Attention from TensorFlow Code

This article dissects the textsum algorithm and the seq2seq-with-attention mechanism in fine detail, based on the project in TensorFlow/models (tensorflow/models/research/textsum). Here is the code of my own implementation of seq2seq with attention, based on the official code; it turns out to converge quite well.

Convert Text into the .bin Format

This part is implemented in textsum_data_convert.py. The inputs are text files encoded in UTF-8. It has two functions:

  1. It outputs a .bin file, transforming the original unstructured text files into a structured binary format.
  2. It outputs a vocab file, which counts the frequency of each word in the text files and stores every word together with its frequency.

As for the first function, we have:
import struct

from nltk.tokenize import sent_tokenize
from tensorflow.core.example import example_pb2


def _convert_files_to_binary(input_filenames, output_filename):
  with open(output_filename, 'wb') as writer:
    for filename in input_filenames:
      with open(filename, 'r') as f:
        document = f.read()
      document_parts = document.split('\n', 1)
      assert len(document_parts) == 2
      title = '<d><p><s>' + document_parts[0] + '</s></p></d>'
      # Encode the title as UTF-8 bytes; otherwise
      # tf_example.features.feature[].bytes_list.value.extend() reports an error.
      title = title.encode('utf8')
      # In Python 3, str has no .decode() (AttributeError), so fall back to
      # handling the body as a plain string. -> by Murphy 02.Jan.18
      try:
        body = document_parts[1].decode('utf8').replace('\n', ' ').replace('\t', ' ')
      except AttributeError:
        body = document_parts[1].replace('\n', ' ').replace('\t', ' ')
      sentences = sent_tokenize(body)
      body = '<d><p>' + ' '.join(['<s>' + sentence + '</s>' for sentence in sentences]) + '</p></d>'
      body = body.encode('utf8')
      tf_example = example_pb2.Example()
      tf_example.features.feature['article'].bytes_list.value.extend([body])
      tf_example.features.feature['abstract'].bytes_list.value.extend([title])
      tf_example_str = tf_example.SerializeToString()
      str_len = len(tf_example_str)
      # Write a length-prefixed record: an 8-byte length, then the serialized Example.
      writer.write(struct.pack('q', str_len))
      writer.write(struct.pack('%ds' % str_len, tf_example_str))

It processes all the text files under the path “./data/cnn/stories”. These text files consist of two parts, “title” and “body” (also known as “abstract” and “article”), which are separated by the first “\n”. Both the title and the body are decorated with the prefix <d><p><s> and the postfix </s></p></d>. Moreover, the body is split into sentences by the function sent_tokenize from the nltk package, and every sentence in the body is decorated with the prefix <s> and the postfix </s>. The “\n”s and “\t”s in the body are replaced by spaces. Then, the processed title and body are written into the file in the format of example_pb2.Example(), which is imported from tensorflow.core.example. It can simply be perceived as a storage format without special meaning; if you like, you could write your own format to replace example_pb2.Example().
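To see that this format round-trips, here is a minimal sketch of a matching reader. The helper name _read_binary is my own invention, but the length-prefixed layout follows the writer above (data.ExampleGen plays the same role in the official textsum code):

import struct

from tensorflow.core.example import example_pb2


def _read_binary(filename):
  """Yield tf.Example protos from a length-prefixed .bin file."""
  with open(filename, 'rb') as reader:
    while True:
      len_bytes = reader.read(8)  # 'q' packs an 8-byte signed integer
      if not len_bytes:
        break  # end of file
      str_len = struct.unpack('q', len_bytes)[0]
      example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
      yield example_pb2.Example.FromString(example_str)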

import collections


def _text_to_vocabulary(input_directories, vocabulary_filename, max_words=200000):
  filenames = _get_filenames(input_directories)
  counter = collections.Counter()
  for filename in filenames:
    with open(filename, 'r') as f:
      document = f.read()
    words = document.split()
    counter.update(words)
  with open(vocabulary_filename, 'w') as writer:
    for word, count in counter.most_common(max_words - 2):
      writer.write(word + ' ' + str(count) + '\n')
    # Special tokens, written with a dummy frequency of 0.
    writer.write('<s> 0\n')
    writer.write('</s> 0\n')
    writer.write('<UNK> 0\n')
    writer.write('<PAD> 0\n')

collections.Counter() works as a collector (you can think of it as a dictionary) in the format of {“word1”: frequency1, “word2”: frequency2, …}. By calling its .update method and passing in a list of words, you let this class automatically update the word-frequency dictionary. At the end, the function appends some symbols with special meanings to the vocabulary file: <s> and </s> denote the start and the end of a sentence, <UNK> stands for out-of-vocabulary words, and <PAD> is used to pad sentences so that all of them have the same length.
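As a quick standalone illustration of how the Counter accumulates word frequencies (a toy example, not part of the project code):

import collections

counter = collections.Counter()
counter.update(['the', 'cat', 'sat'])   # pass in a list of words
counter.update(['the', 'cat', 'the'])   # counts accumulate across calls
print(counter.most_common(2))           # [('the', 3), ('cat', 2)]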

batch_reader: Management of Input Data

First, let’s take a glance at the major part of the main function:

vocab = data.Vocab(FLAGS.vocab_path, 1000000)
# Check for presence of required special tokens.
assert vocab.CheckVocab(data.PAD_TOKEN) > 0
assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
assert vocab.CheckVocab(data.SENTENCE_START) > 0
assert vocab.CheckVocab(data.SENTENCE_END) > 0

batch_size = 4
if FLAGS.mode == 'decode':
  batch_size = FLAGS.beam_size

hps = seq2seq_attention_model.HParams(
    mode=FLAGS.mode,  # train, eval, decode
    min_lr=0.001,  # min learning rate.
    lr=0.015,  # learning rate
    batch_size=batch_size,
    enc_layers=1,
    enc_timesteps=800,
    dec_timesteps=50,
    min_input_len=2,  # discard articles/summaries < than this
    num_hidden=256,  # for rnn cell
    emb_dim=128,  # If 0, don't use embedding
    max_grad_norm=2,
    num_softmax_samples=4096)  # If 0, no sampled softmax.

batcher = batch_reader.Batcher(
    FLAGS.data_path, vocab, hps, FLAGS.article_key,
    FLAGS.abstract_key, FLAGS.max_article_sentences,
    FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
    truncate_input=FLAGS.truncate_input)
tf.set_random_seed(FLAGS.random_seed)

if hps.mode == 'train':
  model = seq2seq_attention_model.Seq2SeqAttentionModel(
      hps, vocab, num_gpus=FLAGS.num_gpus)
  _Train(model, batcher)
elif hps.mode == 'eval':
  model = seq2seq_attention_model.Seq2SeqAttentionModel(
      hps, vocab, num_gpus=FLAGS.num_gpus)
  _Eval(model, batcher, vocab=vocab)
elif hps.mode == 'decode':
  decode_mdl_hps = hps
  # Only need to restore the 1st step and reuse it since
  # we keep and feed in state for each step's output.
  decode_mdl_hps = hps._replace(dec_timesteps=1)
  model = seq2seq_attention_model.Seq2SeqAttentionModel(
      decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
  decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
  decoder.DecodeLoop()

We can see that the main function performs the following steps:

  1. read in hyperparameters;
  2. use batch_reader to manage the data;
  3. execute training/evaluating/decoding process.
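One detail worth noting in the decode branch above: hps._replace(dec_timesteps=1) works because HParams is defined as a collections.namedtuple in the official code, and _replace returns a copy with the given fields changed. A minimal sketch, with a trimmed-down field list for brevity:

import collections

# Trimmed-down stand-in for the real HParams namedtuple.
HParams = collections.namedtuple('HParams', ['mode', 'dec_timesteps'])

hps = HParams(mode='decode', dec_timesteps=50)
decode_mdl_hps = hps._replace(dec_timesteps=1)  # new tuple; hps itself is unchanged
print(hps.dec_timesteps, decode_mdl_hps.dec_timesteps)  # 50 1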

The batch_reader module contains only one class, Batcher, which has the following methods:

class Batcher(object):
  def NextBatch(self): ...
  def _FillInputQueue(self): ...
  def _FillBucketInputQueue(self): ...
  def _WatchThreads(self): ...
  def _TextGenerator(self, example_gen): ...
  def _GetExFeatureText(self, ex, key): ...
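
Before diving in, here is a rough usage sketch: filler threads populate an internal queue in the background, and the training loop simply asks for the next ready-made batch. The unpacking below follows the return value of NextBatch in the official code, but treat the exact tuple as an assumption to verify against your checkout:

batcher = batch_reader.Batcher(
    FLAGS.data_path, vocab, hps, FLAGS.article_key,
    FLAGS.abstract_key, FLAGS.max_article_sentences,
    FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
    truncate_input=FLAGS.truncate_input)
# Each call pops one padded, batched set of examples off the queue.
(enc_batch, dec_batch, target_batch, enc_input_lens, dec_output_lens,
 loss_weights, origin_articles, origin_abstracts) = batcher.NextBatch()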

We are going to dissect this code piece by piece:

def _FillInputQueue(self):
  """Fill input queue with ModelInput."""
  start_id = self._vocab.WordToId(data.SENTENCE_START)
  end_id = self._vocab.WordToId(data.SENTENCE_END)
  pad_id = self._vocab.WordToId(data.PAD_TOKEN)
  input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
  while True:
    (article, abstract) = six.next(input_gen)
    article_sentences = [sent.strip() for sent in
                         data.ToSentences(article.decode('utf-8'), include_token=False)]
    abstract_sentences = [sent.strip() for sent in
                          data.ToSentences(abstract.decode('utf-8'), include_token=False)]
    enc_inputs = []
    # Use the <s> as the <GO> symbol for decoder inputs.
    dec_inputs = [start_id]

    # Convert first N sentences to word IDs, stripping existing <s> and </s>.
    for i in xrange(min(self._max_article_sentences,
                        len(article_sentences))):
      enc_inputs += data.GetWordIds(article_sentences[i], self._vocab)
    for i in xrange(min(self._max_abstract_sentences,
                        len(abstract_sentences))):
      dec_inputs += data.GetWordIds(abstract_sentences[i], self._vocab)

    # Filter out too-short input
    if (len(enc_inputs) < self._hps.min_input_len or
        len(dec_inputs) < self._hps.min_input_len):
      tf.logging.warning('Drop an example - too short.\nenc:%d\ndec:%d',
                         len(enc_inputs), len(dec_inputs))
      continue
    # If we're not truncating input, throw out too-long input
    if not self._truncate_input:
      if (len(enc_inputs) > self._hps.enc_timesteps or
          len(dec_inputs) > self._hps.dec_timesteps):
        tf.logging.warning('Drop an example - too long.\nenc:%d\ndec:%d',
                           len(enc_inputs), len(dec_inputs))
        continue
    # If we are truncating input, do so if necessary
    else:
      if len(enc_inputs) > self._hps.enc_timesteps:
        enc_inputs = enc_inputs[:self._hps.enc_timesteps]
      if len(dec_inputs) > self._hps.dec_timesteps:
        dec_inputs = dec_inputs[:self._hps.dec_timesteps]

    # targets is dec_inputs without <s> at beginning, plus </s> at end
    targets = dec_inputs[1:]
    targets.append(end_id)

    # Now len(enc_inputs) should be <= enc_timesteps, and
    # len(targets) = len(dec_inputs) should be <= dec_timesteps
    enc_input_len = len(enc_inputs)
    dec_output_len = len(targets)

    # Pad if necessary
    while len(enc_inputs) < self._hps.enc_timesteps:
      enc_inputs.append(pad_id)
    while len(dec_inputs) < self._hps.dec_timesteps:
      dec_inputs.append(end_id)
    while len(targets) < self._hps.dec_timesteps:
      targets.append(end_id)

    element = ModelInput(enc_inputs, dec_inputs, targets, enc_input_len,
                         dec_output_len, ' '.join(article_sentences),
                         ' '.join(abstract_sentences))
    self._input_queue.put(element)
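
To make the construction of enc_inputs, dec_inputs and targets concrete, here is a toy trace with made-up word IDs (assume <s>=2, </s>=3, <PAD>=0, enc_timesteps=6, dec_timesteps=4):

enc_inputs = [7, 8, 9]   # article word IDs
dec_inputs = [2, 5, 6]   # <s> followed by abstract word IDs
targets = [5, 6, 3]      # dec_inputs shifted left by one, plus </s>

# After the padding loops:
# enc_inputs -> [7, 8, 9, 0, 0, 0]  (padded with <PAD>)
# dec_inputs -> [2, 5, 6, 3]        (padded with </s>)
# targets    -> [5, 6, 3, 3]        (padded with </s>)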

self._vocab is a class defined in data.py. Briefly, you can view it as being composed of two dictionaries, _word_to_id and _id_to_word; the class also provides methods for the “word to id” and “id to word” conversions. As for the self._TextGenerator method, it uses yield instead of return to avoid huge memory usage. To master yield, you must understand that when you call the function, the code you have written in the function body does not run; the call only returns a generator object. For a deeper understanding, please refer to this article.
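A tiny experiment makes this behavior concrete:

def numbers():
    print('body starts running')
    yield 1
    yield 2

gen = numbers()   # nothing is printed: the body has not started yet
print(next(gen))  # prints 'body starts running', then 1
print(next(gen))  # prints 2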

def _TextGenerator(self, example_gen):
  """Generates article and abstract text from tf.Example."""
  while True:
    e = six.next(example_gen)
    try:
      # _GetExFeatureText returns ex.features.feature[key].bytes_list.value[0]
      article_text = self._GetExFeatureText(e, self._article_key)
      abstract_text = self._GetExFeatureText(e, self._abstract_key)
    except ValueError:
      tf.logging.error('Failed to get article or abstract from example')
      continue
    yield (article_text, abstract_text)

six.next(example_gen) yields one tf.Example of data in the following sample format. It’s long, but you only need a glimpse of its basic structure.

features {
  feature {
    key: "abstract"
    value {
      bytes_list {
        value: "<d><p><s>(CNN) -- Republican presidential contender Michele Bachmann defended her position on gay rights, the HPV vaccine and the debt ceiling as she made her debut on \"The Tonight Show with Jay Leno.\"</s></p></d>"
      }
    }
  }
  feature {
    key: "article"
    value {
      bytes_list {
        value: "<d><p><s> Bachmann has more usually been the butt of Leno\'s jokes -- a point he made as he thanked her for being a good sport.</s> <s>\"We\'ve done a million jokes.</s> <s>Hopefully, you haven\'t been ... watching any of them,\" he said, as she joined him on set.</s> <s>But Leno largely skipped the jokes Friday as he quizzed the Minnesota congresswoman on her political positions.</s> <s>First up was the issue of the HPV vaccine, a subject on which Bachmann hit fellow Republican contender and Texas Gov.</s> <s>Rick Perry hard in this week\'s CNN/Tea Party Republican Debate.</s> <s>Perry signed an executive order in 2007 that required Texas schoolgirls to receive vaccinations against the sexually transmitted HPV, although it wasn\'t implemented.</s> <s>Bachmann told Leno that Perry\'s action had been \"an abuse of executive power\" and had sparked concern over \"crony capitalism,\" an apparent reference to the fact that a former Perry aide was a top lobbyist for Merck, the manufacturer for the HPV vaccine.</s> <s>Merck donated to Perry\'s campaign fund.</s> <s>She added: \"The concern is that there\'s, you know, potentially side effects that can come with something like that.</s> <s>But it gives a false sense of assurance to a young woman when she has that that if she\'s sexually active that she doesn\'t have to worry about sexually transmitted diseases.\"</s> <s>Leno responded: \"Well, I don\'t know if it gives assurance.</s> <s>It can prevent cervical cancer; correct?\"</s> <s>He then pressed Bachmann over comments she made earlier this week in which she said a woman had approached the congresswoman to say her daughter had suffered \"mental retardation\" as a result of receiving the vaccination.</s> <s>There had been no recorded cases of such side effects despite 30 million people receiving the jab, Leno pointed out.</s> <s>\"I wasn\'t speaking as a doctor.</s> <s>I wasn\'t speaking as a scientist.</s> <s>I was just relating what this woman said,\" Bachmann replied.</s> <s>The former tax attorney and mother of five, who won the Iowa straw poll last month but has seen her poll ratings slide since Perry entered the race, also defended two clinics she runs with her husband, offering what she said was a Christian counseling service.</s> <s>The clinics have come under fire over claims they use a controversial therapy that encourages gay and lesbian patients to change their sexual orientation.</s> <s>Asking Bachmann why gay people shouldn\'t have the right to be happily married, Leno said: \"That whole \'pray the gay away\' thing, What?</s> <s>I don\'t get that.\"</s> <s>Bachmann said the clinics did not discriminate, but repeated her position that marriage should be between a man and a woman.</s> <s>Quizzed on her opposition to raising the debt ceiling, Bachmann said she would have taken the same position whether it had been President Barack Obama or George W. Bush in power.</s> <s>On Afghanistan, Bachmann failed to answer whether she thought American forces should withdraw, but paid tribute to the \"unbelievable job\" done by U.S. service men and women there.</s> <s>Leno\'s final question was whom Bachmann would pick as a running mate if she wins the Republican nomination, suggesting she might want someone more moderate to balance her views.</s> <s>She joked: \"Well, you\'re taken.</s> <s>You don\'t want a cut in pay, so what can I say?\"</s> <s>Leno replied: \"Well, we\'d probably have an argument over that gay thing.\"</s> <s>@highlight Bachmann continues her criticism of Rick Perry over the HPV vaccine @highlight Leno presses Bachmann on gay marriage and the counseling clinics she runs @highlight The presidential hopeful says her opposition to raising the debt ceiling was not political @highlight Bachmann declines to say who her running mate would be</s></p></d>"
      }
    }
  }
}