import json
import time
import uuid
from tqdm import tqdm
from json import JSONDecodeError
from flask import Flask, jsonify, request
import logging

from call_tokenize import get_tokenize_and_query

# Shared n-gram table: "w1 w2 ..." -> (probability column, backoff column),
# filled in by arpa_read() and read by _score().
kv_dict = {}

# Flask application exposing the scoring endpoints defined below.
app = Flask(__name__)


def arpa_read(files):
    """Load a tab-separated ARPA-style n-gram file into ``kv_dict``.

    Each data line is expected to look like ``<prob>\\t<ngram>`` or
    ``<prob>\\t<ngram>\\t<backoff>``.  The n-gram string becomes the key and
    ``(prob, backoff)`` the value; ``backoff`` is 0 when the third column is
    absent.  Lines with any other number of tab-separated fields (section
    headers, count lines, blank lines) and lines whose numeric columns do
    not parse as floats are skipped.

    Args:
        files: Path of the file to read (decoded as UTF-8).

    Returns:
        The module-level ``kv_dict``, updated in place.
    """
    # Read everything up front so the file is closed before the (slow) parse
    # and tqdm knows the total line count for its progress bar.
    with open(files, mode='r', encoding='UTF-8') as f1:
        arpa_lines = f1.readlines()

    for line in tqdm(arpa_lines):
        fields = line.rstrip("\n").split("\t")
        if len(fields) not in (2, 3):
            continue
        try:
            prob = float(fields[0])
            backoff = float(fields[2]) if len(fields) == 3 else 0
        except ValueError:
            # Non-numeric probability/backoff column: malformed line, skip it.
            continue
        kv_dict[fields[1]] = (prob, backoff)

    return kv_dict


# Location of the n-gram model file, relative to the working directory.
ARPA_PATH = "./1_9_arpa"
# Load the model into kv_dict once, when this module is imported.
arpa_read(ARPA_PATH)


def _score(sentence):

    def calculate_sentence_start(word_0, word_1):
        key = word_0 + " " + word_1
        if key in kv_dict:
            return kv_dict[key][0]
        else:
            return kv_dict[word_1][0] + kv_dict[word_0][1]

    def score_bigram_prob(w1,w2):
        s2 = 0
        if w1 == "<s>":
            calculate_sentence_start(w1,w2)
        else:
            key = w1 + " " + w2
            if key in kv_dict:
                s2 += kv_dict[key][0]
                return s2
            else:
                s2 += kv_dict[w2][0] + kv_dict[w1][1]
                return s2

    def score_trigram_prob(trigram_list):
        first, second, third = trigram_list

        tri_key = " ".join(trigram_list)
        if tri_key in kv_dict:
            return kv_dict[tri_key][0]
        # 需要回退
        else:
            bi_key = second + " " + third
            bi_bow_key = first + " " + second
            # 回退到bigram
            if bi_key in kv_dict:
                bi_prob = kv_dict[bi_key][0]
                bi_bow = 0
                # 后面的二元有, 前面上文的backoff可以查到
                if bi_bow_key in kv_dict:
                   bi_bow = kv_dict[bi_bow_key][1]
                bi_prob += bi_bow
                return bi_prob
            # 回退到unigram
            else:
                if third not in kv_dict:
                    raise ValueError
                bi_bow = 0
                uni_bow = 0
                # 前面上文的backoff可以查到
                if bi_bow_key in kv_dict:
                    bi_bow = kv_dict[bi_bow_key][1]
                if second in kv_dict:
                    uni_bow = kv_dict[second][1]
                uni_prob = kv_dict[third][0] + uni_bow + bi_bow
                return uni_prob

    try:
        sentence = ("<s> " + sentence).strip()
        wordArr = sentence.strip().split(" ")  # 对输入语料进行切分

        wordArrK = tuple(wordArr)  # 转换成元组  因为字典的k不能是list  将切分好的语料和字典的k匹配

        total_score = 0

        if len(wordArrK) == 1:  # <s>
            return kv_dict.get(wordArrK[0])[0]
        elif len(wordArrK) == 2:  # <s> 今天
            return calculate_sentence_start(wordArrK[0], wordArrK[1])
        else:
            total_score += calculate_sentence_start(wordArrK[0], wordArrK[1])
            for i in range(0, len(wordArrK) - 2):
                total_score += score_trigram_prob(wordArrK[i:i + 3])
            return total_score

    except Exception as e:
        app.logger.exception(e)
        return -1


@app.route('/ngram_score', methods=['POST'])
def score():
    """Score one sentence submitted as the form field ``sentence``.

    Responds with JSON ``{"score": <value>}``, where the value comes from
    ``_score`` (which itself yields -1 on scoring failure).  Any other error,
    including a missing ``sentence`` field, is logged and answered with
    ``('Internal error', 500)``.
    """
    try:
        parameters = request.form
        # Debug-level logging instead of print() so request data does not
        # go to stdout unconditionally.
        app.logger.debug("ngram_score form: %s", parameters)
        sentence = parameters['sentence']
        app.logger.debug("ngram_score sentence: %s", sentence)

        response = jsonify({"score": _score(sentence)})
    except Exception as e:
        # Log like input_score does, rather than hiding the failure.
        app.logger.exception(e)
        response = 'Internal error', 500

    return response


@app.route('/input_score', methods=['POST'])
def input_score():
    """Tokenize ``user_input`` and score every candidate segmentation.

    Reads form fields ``user_input`` and ``keyboard_type`` and passes them to
    ``get_tokenize_and_query``.  That function is expected to yield
    ``(words, query_score)`` pairs (inferred from the unpacking below). The
    response is a JSON list of ``[sentence, ngram_score, query_score]``,
    where ``sentence`` is the words joined by single spaces and
    ``ngram_score`` comes from ``_score``.  Any error is logged and answered
    with ``('Internal error', 500)``.
    """
    try:
        parameters = request.form
        user_input = parameters['user_input']
        keyboard_type = parameters['keyboard_type']

        tokenize_result = get_tokenize_and_query(user_input, keyboard_type)

        result = []
        # `query_score` (not `score`) so the module-level `score` view is not
        # shadowed inside this function.
        for words, query_score in tokenize_result:
            sent = " ".join(words)
            result.append((sent, _score(sent), query_score))

        response = jsonify(result)
    except Exception as e:
        app.logger.exception(e)
        response = 'Internal error', 500

    return response


if __name__ == '__main__':
    # Run Flask's built-in server, listening on all interfaces at port 9494.
    app.run(port=9494, host='0.0.0.0')