import numpy as np
import tensorflow as tf
from string import punctuation
from collections import Counter

# Project overview: binary sentiment analysis of movie reviews with an embedding layer + LSTM (TensorFlow 1.x API).

with open('../datas/sentiment/reviews.txt', 'r') as f:
    reviews = f.read()
with open('../datas/sentiment/labels.txt', 'r') as f:
    labels = f.read()
print(reviews[:100])  # peek at the raw text (reviews is still one big string here)
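# Peek at the raw labels too (assumption: labels.txt holds one 'positive'/'negative'
# label per line, aligned with reviews.txt).
print(labels[:25])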


# Data preprocessing
# todo-1. Remove all punctuation (build a punctuation-free character list, then join it back into text)
all_text = ''.join([c for c in reviews if c not in punctuation])

# todo 2. Split the text into individual reviews on '\n'
reviews = all_text.split('\n')
all_text = ' '.join(reviews)
# Split the text into a list of individual words
words = all_text.split()


# todo 1. Build the vocabulary dict {word: integer}. We will later pad inputs with 0,
#         so the integer encoding starts at 1 (not 0).
#      2. Convert every review to integers and store the results in a new list: reviews_ints.

counts = Counter(words)
# Sort the vocabulary by count, most frequent first
vocab = sorted(counts, key=counts.get, reverse=True)
# Build the dict: {word: integer}, with ids starting at 1
vocab_to_int = {word: ii for ii, word in enumerate(vocab, 1)}
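# A tiny worked example of the mapping above (toy corpus, ids are illustrative):
_toy = Counter('the movie the best'.split())
_toy_vocab = sorted(_toy, key=_toy.get, reverse=True)   # ['the', 'movie', 'best']
print({w: i for i, w in enumerate(_toy_vocab, 1)})      # {'the': 1, 'movie': 2, 'best': 3}; 0 stays free for padding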

# Convert each review from a list of words to a list of integers
reviews_ints = []
for each in reviews:
    reviews_ints.append([vocab_to_int[word] for word in each.split()])

# todo-Encode the labels as numbers: positive == 1 and negative == 0
labels = labels.split('\n')
labels = np.array([1 if each == 'positive' else 0 for each in labels])
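# Sanity check (assumes the two files line up one-to-one, including the trailing
# empty entry each gets from the final '\n'):
assert len(labels) == len(reviews), 'each review needs exactly one label'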


# todo-One problem:
"""
One review has length 0, and the longest review has 2514 words, which is a bit too
long. So we normalize every review to a length of 200:
   1. Reviews shorter than 200 get left-padded with 0s.
   2. Reviews longer than 200 keep only their first 200 words.
"""
review_lens = Counter([len(x) for x in reviews_ints])
print("长度为0的评论数量: {}".format(review_lens[0]))
print("最大评论的长度为: {}".format(max(review_lens)))

# todo-Remove the zero-length reviews from reviews_ints.
# Collect the indices of the reviews with non-zero length
non_zero_idx = [ii for ii, review in enumerate(reviews_ints) if len(review) != 0]
# Keep only those reviews, and keep the labels aligned with them
reviews_ints = [reviews_ints[ii] for ii in non_zero_idx]
labels = np.array([labels[ii] for ii in non_zero_idx])
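# After the filter, every review is non-empty and the labels are still aligned.
assert min(len(r) for r in reviews_ints) > 0
assert len(reviews_ints) == len(labels)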


# todo-Exercise
"""
Goal: build an array `features` from the data in reviews_ints, where every row has
length 200: if a review is shorter than 200, left-pad it with 0s.
   Example: if a review is ['best', 'movie', 'ever'], i.e. [117, 18, 128] as integers,
       then after left-padding it should look like [0, 0, 0, ..., 0, 117, 18, 128];
       for reviews longer than 200 words, keep only the first 200.
"""
seq_len = 200
# Allocate an all-zero matrix of shape (len(reviews_ints), 200), roughly 25000 x 200.
features = np.zeros((len(reviews_ints), seq_len), dtype=int)
# Copy reviews_ints into features row by row (print a few rows to check).
# Note the trick: numpy clips the out-of-range negative slice, so rows longer than
# seq_len get truncated and shorter rows land right-aligned (left-padded with zeros).
for i, row in enumerate(reviews_ints):
    features[i, -len(row):] = np.array(row)[:seq_len]
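# A minimal sketch of the same trick on toy data (made-up ids, seq_len shrunk to 5):
_demo = np.zeros((2, 5), dtype=int)
for _i, _row in enumerate([[117, 18, 128], [1, 2, 3, 4, 5, 6, 7]]):
    _demo[_i, -len(_row):] = np.array(_row)[:5]
print(_demo)  # -> [[0 0 117 18 128], [1 2 3 4 5]]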


# todo-Exercise: build the training, validation, and test sets. split_frac (0.8-0.9)
#     is the fraction of the data kept for the training set.
split_frac = 0.8
split_idx = int(len(features) * split_frac)
train_x, val_x = features[:split_idx], features[split_idx:]
train_y, val_y = labels[:split_idx], labels[split_idx:]

test_idx = int(len(val_x)*0.5)
val_x, test_x = val_x[:test_idx], val_x[test_idx:]
val_y, test_y = val_y[:test_idx], val_y[test_idx:]
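# Sanity check: the three splits partition the data (80/10/10 with split_frac=0.8).
assert len(train_x) + len(val_x) + len(test_x) == len(features)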

print("\t\t\tFeature Shapes:")
print("Train set: \t\t{}".format(train_x.shape),
      "\nValidation set: \t{}".format(val_x.shape),
      "\nTest set: \t\t{}".format(test_x.shape))


# todo-Start building the model graph
lstm_size = 256
lstm_layers = 1
batch_size = 128
learning_rate = 0.001

n_words = len(vocab_to_int)

# todo-Exercise: create the placeholders

# Create the graph object
graph = tf.Graph()
# Add nodes to the graph:
with graph.as_default():
    inputs_ = tf.placeholder(tf.int32, [None, None], name='inputs')
    labels_ = tf.placeholder(tf.int32, [None, None], name='labels')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')
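    # Shapes: inputs_ takes [batch, seq_len] word ids; labels_ takes [batch, 1]
    # 0/1 targets (fed as y[:, None] later); keep_prob switches dropout on for
    # training (0.5) and off for validation/testing (1.0).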

# todo-Exercise: create the embedding layer
"""
The vocabulary has about 72,000 words, so feeding one-hot encodings straight into the
network would be far too inefficient. Instead we learn an embedding weight matrix
(trained here as part of the network; word2vec is another common way to obtain one).
Create the embedding lookup matrix with `tf.Variable`, then fetch the embedding
vectors with [`tf.nn.embedding_lookup`] and pass them into the LSTM cells. That
function takes the embedding matrix and an input tensor (the integer-encoded reviews)
and returns the embedded vectors. For example, with 200 embedding units it returns a
tensor of size [batch_size, seq_len, 200].
"""
# Size of the embedding vectors (i.e. the number of embedding units)
embed_size = 300
with graph.as_default():
    embedding = tf.Variable(tf.random_uniform((n_words, embed_size), -1, 1))
    embed = tf.nn.embedding_lookup(embedding, inputs_)
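    # Shape walk-through (a sketch; batch and seq_len stay None until feed time):
    #   inputs_  : [batch, seq_len]              int32 word ids
    #   embedding: [n_words, embed_size]         trainable lookup table
    #   embed    : [batch, seq_len, embed_size]  what the LSTM consumes
    print('embed tensor: {}'.format(embed))  # expect shape (?, ?, 300)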


with graph.as_default():
    def build_cell():
        # Create a basic LSTM cell
        lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size)
        # Add dropout to the cell
        return tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)

    # Stack multiple LSTM layers (one fresh cell per layer; sharing a single cell
    # object across layers breaks in newer TF 1.x releases)
    cell = tf.nn.rnn_cell.MultiRNNCell([build_cell() for _ in range(lstm_layers)])

    # Initialize all cells to the zero state.
    initial_state = cell.zero_state(batch_size, tf.float32)
    print('initial_state structure: {}'.format(initial_state))
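    # initial_state is a tuple with one LSTMStateTuple(c, h) per layer; each part
    # has shape [batch_size, lstm_size], i.e. (128, 256) here.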

# todo-Forward pass through the RNN
"""
Actually run the RNN nodes, using [`tf.nn.dynamic_rnn`]. It takes two arguments: the
multi-layered LSTM `cell` and the inputs:
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
We also pass in the `initial_state` defined above; this is the cell state that gets
carried between steps. `tf.nn.dynamic_rnn` does most of the work for us and returns
the output at every step plus the final hidden state.

> **Exercise:** Add the forward pass to the RNN with `tf.nn.dynamic_rnn`. Note that
the inputs we feed here are actually the output of the embedding layer: `embed`.
"""
with graph.as_default():
    outputs, final_state = tf.nn.dynamic_rnn(cell, embed,
                                             initial_state=initial_state)
    print('final_state is: {}'.format(final_state))
    print('hidden-layer outputs tensor: {}'.format(outputs))  # shape=(batch_size=128, ?, lstm_size=256)

# RNN output
"""
We only need the output at the final time step, grabbed with `outputs[:, -1, :]`;
the loss is then computed between it and `labels_`.
"""

loss_method = 'MSE'
with graph.as_default():
    if loss_method == 'MSE':
        # Option 1: mean squared error on the sigmoid output.
        predictions = tf.contrib.layers.fully_connected(outputs[:, -1, :], 1, activation_fn=tf.sigmoid)
        print('outputs[:, -1, :] tensor: {}'.format(outputs[:, -1, :]))  # shape=(batch=128, lstm_size=256)
        cost = tf.losses.mean_squared_error(tf.cast(labels_, tf.float32), predictions)
    else:
        # Option 2: sigmoid cross-entropy on the raw logits.
        logits = tf.contrib.layers.fully_connected(outputs[:, -1, :], 1, activation_fn=None)
        predictions = tf.nn.sigmoid(logits)
        cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(labels_, tf.float32), logits=logits))
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
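    # Design note: with sigmoid outputs in [0, 1] and 0/1 labels, MSE does work, but
    # sigmoid cross-entropy is the more standard choice for binary targets and tends
    # to give better-behaved gradients; flip loss_method to compare the two.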

# Validation accuracy
with graph.as_default():
    correct_pred = tf.equal(tf.cast(tf.round(predictions), tf.int32), labels_)
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))


# todo-Define the batching function. 1. Drop the last partial batch so that all
#     batches are full-sized. 2. Iterate over the `x` and `y` arrays in steps of
#     `batch_size`, yielding slices of both.
def get_batches(x, y, batch_size=128):
    n_batches = len(x) // batch_size
    x, y = x[:n_batches * batch_size], y[:n_batches * batch_size]
    for ii in range(0, len(x), batch_size):
        yield x[ii:ii + batch_size], y[ii:ii + batch_size]
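
# Quick usage check of get_batches (a sketch, safe to delete): the first batch
# should come out as (batch_size, seq_len) features and (batch_size,) labels.
_xb, _yb = next(get_batches(train_x, train_y, batch_size))
print('first batch:', _xb.shape, _yb.shape)  # expect (128, 200) (128,)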


# Training
def training():
    epochs = 20

    with graph.as_default():
        saver = tf.train.Saver()

    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        iteration = 1
        for e in range(epochs):
            # Evaluate initial_state first to get this epoch's starting state
            state = sess.run(initial_state)

            for ii, (x, y) in enumerate(get_batches(train_x, train_y, batch_size), 1):
                feed = {inputs_: x,
                        labels_: y[:, None],
                        keep_prob: 0.5,
                        initial_state: state}
                # todo-the state produced here is fed back to the model for the next batch.
                loss, state, _ = sess.run([cost, final_state, optimizer], feed_dict=feed)

                if iteration % 5 == 0:
                    print("Epoch: {}/{}".format(e, epochs),
                          "Iteration: {}".format(iteration),
                          "Train loss: {:.3f}".format(loss))

                if iteration % 25 == 0:
                    # Run the validation data
                    val_acc = []
                    val_state = sess.run(cell.zero_state(batch_size, tf.float32))
                    for x, y in get_batches(val_x, val_y, batch_size):
                        feed = {inputs_: x,
                                labels_: y[:, None],
                                keep_prob: 1,
                                initial_state: val_state}
                        batch_acc, val_state = sess.run([accuracy, final_state], feed_dict=feed)
                        val_acc.append(batch_acc)
                    print("Val acc: {:.5f}".format(np.mean(val_acc)))
                iteration += 1
        tf.gfile.MakeDirs('checkpoints')  # make sure the checkpoint directory exists
        saver.save(sess, "checkpoints/sentiment.ckpt")


# Testing:
def testing():
    test_acc = []
    with graph.as_default():
        saver = tf.train.Saver()
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
        test_state = sess.run(cell.zero_state(batch_size, tf.float32))
        for ii, (x, y) in enumerate(get_batches(test_x, test_y, batch_size), 1):
            feed = {inputs_: x,
                    labels_: y[:, None],
                    keep_prob: 1,
                    initial_state: test_state}
            batch_acc, test_state = sess.run([accuracy, final_state], feed_dict=feed)
            test_acc.append(batch_acc)
        print("Test accuracy: {:.5f}".format(np.mean(test_acc)))


if __name__ == '__main__':
    training()
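    # testing()  # run after training to evaluate the saved checkpoint on the test set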