from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
from pyspark.sql import Row

import re
import numpy as np
from time import time
from sklearn.datasets import fetch_20newsgroups

from pyspark.ml.feature import CountVectorizer, HashingTF, IDF
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.ml.clustering import LDA

np.random.seed(0)

NUM_FEATURES = 8000  # vocabulary size
NUM_TOPICS = 20  # fixed for LDA


def clean_document(doc: str) -> str:
    """Return *doc* reduced to ASCII letters separated by single spaces.

    Every run of non-letter characters is collapsed into ONE space and the
    result is stripped. Collapsing matters because Spark's ``Tokenizer``
    splits on individual whitespace characters, so consecutive spaces would
    produce empty-string tokens that end up dominating the vocabulary.

    The result stays a ``str``. Encoding it to ``bytes`` and then calling
    ``str()`` on it (as an earlier version of this script did) yields the
    literal text ``"b'...'"`` in Python 3, which leaks bogus tokens such as
    ``b'if`` into the corpus.
    """
    return re.sub(r"[^A-Za-z]+", " ", doc).strip()


def main() -> None:
    """Fit a 20-topic LDA model on 20 newsgroups and print each topic's words.

    Pipeline: clean text -> tokenize -> remove stop words -> term counts
    (CountVectorizer) -> IDF weighting -> LDA. The describing terms of every
    topic are then mapped back to words through the CountVectorizer
    vocabulary and printed.
    """
    # A single SparkSession is enough; the SparkContext is obtained from it
    # instead of being created separately.
    spark = SparkSession.builder.master("local").appName("lda").getOrCreate()
    sc = spark.sparkContext

    try:
        # Load the 20 newsgroups dataset and report how long the load took.
        tic = time()
        dataset = fetch_20newsgroups(
            shuffle=True,
            random_state=0,
            remove=('headers', 'footers', 'quotes'),
        )
        train_corpus = dataset.data  # a list of 11314 documents / entries
        toc = time()
        print("elapsed time: %.4f sec" % (toc - tic))

        # Distribute the data, clean it, and drop documents that are empty
        # after cleaning (they would otherwise tokenize to a single '' token).
        corpus_rdd = (
            sc.parallelize(train_corpus)
            .map(clean_document)
            .filter(lambda doc: bool(doc))
        )

        rdd_row = corpus_rdd.map(lambda doc: Row(raw_corpus=doc))
        newsgroups = spark.createDataFrame(rdd_row)

        tokenizer = Tokenizer(inputCol="raw_corpus", outputCol="tokens")
        newsgroups = tokenizer.transform(newsgroups).drop('raw_corpus')

        stopwords = StopWordsRemover(inputCol="tokens", outputCol="tokens_filtered")
        newsgroups = stopwords.transform(newsgroups).drop('tokens')

        # CountVectorizer is used because its fitted vocabulary is needed
        # below to translate term indices back into words.
        count_vec = CountVectorizer(
            inputCol="tokens_filtered",
            outputCol="tf_features",
            vocabSize=NUM_FEATURES,
            minDF=2.0,
        )
        count_vec_model = count_vec.fit(newsgroups)
        vocab = count_vec_model.vocabulary
        newsgroups = count_vec_model.transform(newsgroups).drop('tokens_filtered')

        idf = IDF(inputCol="tf_features", outputCol="features")
        newsgroups = idf.fit(newsgroups).transform(newsgroups).drop('tf_features')

        lda = LDA(k=NUM_TOPICS, featuresCol="features", seed=0)
        model = lda.fit(newsgroups)

        topics = model.describeTopics()
        topics.show()

        # Translate each topic's term indices into the corresponding words.
        topics_words = (
            topics.rdd
            .map(lambda row: row['termIndices'])
            .map(lambda idx_list: [vocab[idx] for idx in idx_list])
            .collect()
        )

        for idx, topic in enumerate(topics_words):
            print("topic: ", idx)
            print("----------")
            for word in topic:
                print(word)
            print("----------")
    finally:
        # Always release the Spark resources, even if a stage fails.
        spark.stop()


if __name__ == "__main__":
    main()
elapsed time: 1.0284 sec
+-----+--------------------+--------------------+
|topic| termIndices| termWeights|
+-----+--------------------+--------------------+
| 0|[0, 552, 967, 108...|[0.01258332472159...|
| 1|[0, 1004, 76, 40,...|[0.08220619222238...|
| 2|[3, 0, 373, 18, 2...|[0.11591833022404...|
| 3|[3541, 1057, 2060...|[0.02274190214796...|
| 4|[0, 87, 364, 3645...|[0.01822486526972...|
| 5|[104, 0, 527, 188...|[0.01069089006155...|
| 6|[1, 4, 7, 16, 8, ...|[0.40037927605170...|
| 7|[0, 1079, 24, 325...|[0.01809503390053...|
| 8|[0, 50, 148, 19, ...|[0.01119972590376...|
| 9|[0, 261, 356, 340...|[0.00977327728361...|
| 10|[0, 182, 35, 743,...|[0.01214197116201...|
| 11|[0, 308, 706, 561...|[0.01776294690247...|
| 12|[0, 69, 38, 9, 35...|[0.01212285192631...|
| 13|[179, 1219, 0, 11...|[0.01490648021084...|
| 14|[0, 93, 133, 83, ...|[0.01778007036882...|
| 15|[569, 949, 0, 124...|[0.02115331163136...|
| 16|[755, 56, 303, 20...|[0.02666471302923...|
| 17|[171, 0, 299, 895...|[0.01153933315409...|
| 18|[0, 208, 574, 116...|[0.01336781044283...|
| 19|[831, 0, 2149, 87...|[0.00780105775382...|
+-----+--------------------+--------------------+

topic: 0
----------

monitor
printer
vga
pin
print
cable
apple
please
video
----------
topic: 1
----------

pts
com
edu
la
pt
vs
period
pp
w
----------
topic: 2
----------
x

entry
n
output
oname
c
entries
eof
file
----------
topic: 3
----------
den
dod
tank
b'if

accelerators
ctrl
radius
td
rc