1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
def evaluate(input_sentence):
    """Decode a translation for ``input_sentence`` one token at a time.

    The sentence is preprocessed, mapped to ids with ``input_tokenizer``,
    post-padded to ``max_length_input`` and fed through ``encoder``. The
    decoder is then stepped at most ``max_length_output`` times, each time
    taking the arg-max token of the previous step as its next input, until
    the ``'<end>'`` token is produced.

    Args:
        input_sentence: raw sentence; it is passed through
            ``preprocess_sentence`` and then split on single spaces.

    Returns:
        A tuple ``(results, input_sentence, attention_matrix)`` where
        ``results`` is the predicted tokens, each followed by one space
        (``'<end>'`` included when it was produced), ``input_sentence`` is
        the preprocessed sentence, and ``attention_matrix`` is a NumPy array
        of shape ``(max_length_output, max_length_input)`` whose row ``t``
        holds the flattened attention weights of decoding step ``t``
        (rows for steps that never ran stay zero).

    Raises:
        KeyError: if a token of the preprocessed sentence is missing from
            ``input_tokenizer.word_index`` or a predicted id is missing from
            ``out_tokenizer.index_word``.
    """
    attention_matrix = np.zeros((max_length_output, max_length_input))

    input_sentence = preprocess_sentence(input_sentence)
    token_ids = [
        input_tokenizer.word_index[word]
        for word in input_sentence.split(' ')
    ]
    padded_ids = keras.preprocessing.sequence.pad_sequences(
        [token_ids], maxlen=max_length_input, padding='post')
    encoder_input = tf.convert_to_tensor(padded_ids)

    # The encoder starts from an all-zero hidden state for a batch of one.
    initial_hidden = tf.zeros((1, units))
    encoding_outputs, state = encoder(encoder_input, initial_hidden)

    # The decoder continues from the encoder's final hidden state and is
    # primed with the '<start>' token.
    step_input = tf.expand_dims([out_tokenizer.word_index['<start>']], 0)

    predicted_words = []
    for step in range(max_length_output):
        predictions, state, attention_weights = decoder(
            step_input, state, encoding_outputs)

        # Flatten this step's attention weights into one row of the matrix.
        attention_matrix[step] = tf.reshape(attention_weights, (-1,)).numpy()

        predicted_id = tf.argmax(predictions[0]).numpy()
        word = out_tokenizer.index_word[predicted_id]
        predicted_words.append(word)
        if word == '<end>':
            break

        # Feed the prediction back in as the next decoder input.
        step_input = tf.expand_dims([predicted_id], 0)

    # Every token, including the last one, is followed by a single space.
    results = ''.join(word + ' ' for word in predicted_words)
    return results, input_sentence, attention_matrix
def plot_attention(attention_matrix, input_sentence, predicted_sentence):
    """Display an attention matrix as a heat map labelled with tokens.

    Args:
        attention_matrix: 2-D array-like handed to ``Axes.matshow``.
        input_sentence: list of token strings used as x-axis (column)
            labels; expected to have one entry per matrix column.
        predicted_sentence: list of token strings used as y-axis (row)
            labels; expected to have one entry per matrix row.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention_matrix, cmap='viridis')

    font_dict = {'fontsize': 14}

    # Pin exactly one tick per token before assigning labels; otherwise the
    # labels are attached to whatever ticks the automatic locator happens to
    # choose and can end up on the wrong rows/columns.
    ax.set_xticks(list(range(len(input_sentence))))
    ax.set_yticks(list(range(len(predicted_sentence))))

    ax.set_xticklabels(input_sentence, fontdict=font_dict, rotation=90)
    # Fixed: the original called the non-existent ``ax.sey_yticklables``,
    # which raised AttributeError before the figure could be shown.
    ax.set_yticklabels(predicted_sentence, fontdict=font_dict)

    plt.show()
def translate(input_sentence):
    """Translate a sentence, print the result and plot its attention map.

    Args:
        input_sentence: raw sentence, forwarded to ``evaluate``.

    Returns:
        None. The preprocessed input and the predicted translation are
        printed, and the attention matrix is shown via ``plot_attention``.
    """
    results, input_sentence, attention_matrix = evaluate(input_sentence)

    print("Input: %s" % (input_sentence))
    print("Predicted translation: %s" % (results))

    input_tokens = input_sentence.split(' ')
    # ``results`` ends with a space, so a plain ``split(' ')`` yields a
    # trailing empty token; dropping it keeps the slice and the y-axis labels
    # to the decoding steps that actually produced a token.
    predicted_tokens = [token for token in results.split(' ') if token]

    attention_matrix = attention_matrix[:len(predicted_tokens),
                                        :len(input_tokens)]
    plot_attention(attention_matrix, input_tokens, predicted_tokens)
|