# Build a simple classifier for the Enron email (sub)dataset

# Imports first
import os, datetime, string, re, random

import MySQLdb

from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import WordPunctTokenizer
from nltk.collocations import BigramCollocationFinder
from nltk.metrics import BigramAssocMeasures
from nltk.classify import NaiveBayesClassifier
from nltk.classify.util import accuracy

# MySQL connection settings, passed as keyword arguments to MySQLdb.connect() by
# every function in this module.  The user/passwd values are placeholders.
dbsettings=dict(host = "localhost", user = "yourname", passwd = "yourpassword", db = "kiwipycon")



def rank_a_subject(resultfield,myparams):
	"""
	Assign a rank from 0 to 9 to each message based on one model's probability column.

	Rows of basedata_training for the selected set are read in ascending order
	of `resultfield`; each row's position in that ordering is scaled to a rank
	and written to the `<resultfield>_RANK` column.  Output can be used to look
	at lift curves.

	resultfield -- name of the probability column to rank on.  It is formatted
	               into the SQL text as an identifier, so it must be a trusted
	               column name.
	myparams    -- settings dict; only myparams['mode'] is read.  'train'
	               (case-insensitive) ranks rows with TrainingSet=1, any other
	               value ranks rows with TrainingSet=0.
	"""

	flag = 1 if myparams['mode'].lower()=='train' else 0

	# Column names cannot be bound as query parameters, so they are formatted
	# in; the values are left as %s placeholders for the driver to bind.
	select_sql = "select id, {0} from basedata_training where TrainingSet=%s order by {0} ".format(resultfield)
	update_sql = "update basedata_training set {0}=%s where id=%s".format(resultfield+'_RANK')

	conn = MySQLdb.connect (**dbsettings)
	try:
		cursor = conn.cursor(MySQLdb.cursors.DictCursor)
		try:
			cursor.execute(select_sql, (flag,))
			rank_set = cursor.fetchall()

			upper_limit = len(rank_set) + 1

			for idx, record in enumerate(rank_set):
				# Position in the sorted list scaled to 0..9; integer arithmetic
				# avoids the float round-trip and yields the same rank.
				rank = idx * 10 // upper_limit
				cursor.execute(update_sql, (rank, record['id']))

			# One commit for the whole batch rather than one per row.
			conn.commit()
		finally:
			cursor.close()
	finally:
		# Release the connection even when a query fails.
		conn.close()

	return


def _recipient_mix(total, internal):
	'''Label a recipient list by the share of its addresses that are internal.'''
	if not total > 0:
		return 'No one'
	share_internal = float(internal or 0) / float(total)
	if share_internal > 0.7:
		return 'Internals'
	if share_internal < 0.3:
		return 'Externals'
	return 'Mixed Group'


def extract_features(record,myparams,stemmer,stopset,tokenizer):
	'''
	Extract features to use in our classifier.

	Features include those defined already in the database, and those determined here
	using NLTK tools (stemmers, bigrammers) on the enron email body and subject texts.

	record    -- one row of basedata_training as a dict keyed by column name.
	myparams  -- settings dict.  'std_features' switches on the pre-extracted
	             database features; 'word_features' or 'bigrams' switches on
	             the word features; 'bigrams' additionally adds collocations.
	stemmer   -- object with a stem(word) method.
	stopset   -- collection of lowercase words to leave out.
	tokenizer -- object with a tokenize(text) method returning a list of tokens.

	Returns a dict mapping feature name to value: bucket labels and booleans
	for the standard features, and True for every stemmed word or bigram seen.
	'''

	features={}

	#------
	# Pre-extracted features first
	if myparams['std_features']:
		# NOTE(review): if from_internal is a BIT column that arrives as a
		# one-character string (as the SUBJ_* labels appear to), this is always
		# True -- confirm the column type.
		features['from_internal'] = bool(record['from_internal'])

		to_count = record['to_howmany']
		if to_count<=1:
			features['sent_to_howmany']='Single'
		elif to_count<=5:
			features['sent_to_howmany']='A few'
		else:
			features['sent_to_howmany']='Many'

		# The share is internal/total.  The earlier total/internal form was
		# never below 1, so every list with an internal recipient came out as
		# 'Internals' and 'Externals' could not occur.
		features['to_who'] = _recipient_mix(to_count, record['to_howmany_internal'])

		cc_count = record['cc_howmany']
		if cc_count==0:
			features['ccd_to_howmany']='No one'
		elif cc_count<=8:
			features['ccd_to_howmany']='A few'
		else:
			features['ccd_to_howmany']='Many'

		features['ccd_to_who'] = _recipient_mix(cc_count, record['cc_howmany_internal'])

		features['many_digits_in_text'] = record['perc_words_are_digits']>7
		features['many_capitals_in_text'] = record['perc_words_are_caps']>10

		body_words = record['num_words_in_body']
		if body_words<=20:
			features['message_length']='Very Short'
		elif body_words<=80:
			features['message_length']='Short'
		elif body_words<=300:
			features['message_length']='Medium'
		else:
			features['message_length']='Long'

	#------
	# Now words from message subject and body text
	if myparams['word_features'] or myparams['bigrams']:
		# A NULL subject or body arrives as None; treat it as empty text.
		text = (record['msg_subject'] or '') + " " + (record['msg_body'] or '')
		tokens = tokenizer.tokenize(text)

		if myparams['bigrams']:
			# Up to 500 best collocations by chi-squared, added as "w1 w2" tokens.
			bigram_finder = BigramCollocationFinder.from_words(tokens)
			bigrams = bigram_finder.nbest(BigramAssocMeasures.chi_sq, 500)
			tokens.extend("%s %s" % bigram_tuple for bigram_tuple in bigrams)

		for token in tokens:
			# Lowercase before the stopword test so capitalised stopwords
			# ("The", "And") are filtered too.
			lowered = token.lower()
			if len(token) > 1 and lowered not in stopset:
				features[stemmer.stem(lowered)] = True

	return features
	
	
def train_model(subject,resultfield,myparams):
	"""
	Train a Naive Bayes classifier for one subject label and store its probabilities.

	All rows with TrainingSet=1 are loaded and turned into feature sets labelled
	"IS" / "IS NOT" from the `subject` column.  What happens next depends on
	myparams['mode'] (case-insensitive):

	'train' -- shuffle, train on two thirds, print accuracy on the remaining
	           third, then write P("IS") for every TrainingSet=1 row.
	'score' -- train on all TrainingSet=1 rows, then write P("IS") for every
	           TrainingSet=0 row.

	Any other mode trains nothing.  In every case rank_a_subject() is run on
	`resultfield` afterwards.

	subject     -- name of the label column.
	resultfield -- name of the column that receives the probability.  It is
	               formatted into the SQL text as an identifier, so it must be
	               a trusted column name.
	myparams    -- settings dict, also passed through to extract_features().
	"""

	stemmer = PorterStemmer()
	# NOTE(review): 'english_alternate' is presumably a locally installed
	# stopword list -- confirm it exists in the NLTK data directory.
	stopset = set(stopwords.words('english_alternate'))
	tokenizer = WordPunctTokenizer()

	mode = myparams['mode'].lower()

	# The column name is formatted in (identifiers cannot be bound); the
	# values are left as %s placeholders for the driver to bind.
	update_sql = "update basedata_training set {0}=%s where id=%s".format(resultfield)

	conn = MySQLdb.connect (**dbsettings)
	try:
		cursor = conn.cursor(MySQLdb.cursors.DictCursor)
		try:
			cursor.execute("select * from basedata_training where TrainingSet=1 order by randnum")
			training_set = cursor.fetchall()

			# Extract features once per row; the same dicts are reused when the
			# training rows are scored below.
			training_features = [extract_features(record,myparams,stemmer,stopset,tokenizer) for record in training_set]

			# ord() implies the label arrives as a single character
			# (presumably a BIT(1) column) -- TODO confirm.
			dev_set = [
				(features, "IS" if ord(record[subject]) else "IS NOT")
				for record, features in zip(training_set, training_features)
			]

			if mode=='train':
				random.shuffle(dev_set)
				# Floor division keeps the cutoff an integer under true division too.
				cutoff = len(dev_set)*2//3
				train_set=dev_set[:cutoff]
				test_set=dev_set[cutoff:]

				print('training for %s on %d instances, test on %d instances' % (subject,len(train_set), len(test_set)))
				classifier = NaiveBayesClassifier.train(train_set)

				# See the most informative words the classifier will use, and its accuracy
				print('accuracy for >  %s : %s' % (subject, accuracy(classifier, test_set)))
				classifier.show_most_informative_features(10)

				# Need a confusion matrix, decided to build up as data in MySQL - will interrogate from Access.
				for record, features in zip(training_set, training_features):
					prob_is_class = classifier.prob_classify(features).prob('IS')
					cursor.execute(update_sql, (prob_is_class, record['id']))
				# One commit for the whole batch rather than one per row.
				conn.commit()

			# Score the base, if asked for.
			if mode=='score':
				print('training for %s on all %d instances' % (subject,len(dev_set)))
				classifier = NaiveBayesClassifier.train(dev_set)
				classifier.show_most_informative_features(10)

				cursor.execute("select * from basedata_training where TrainingSet=0")
				scoring_set = cursor.fetchall()
				for record in scoring_set:
					features = extract_features(record,myparams,stemmer,stopset,tokenizer)
					prob_is_class = classifier.prob_classify(features).prob('IS')
					cursor.execute(update_sql, (prob_is_class, record['id']))
				conn.commit()
		finally:
			cursor.close()
	finally:
		# Release the connection even when training or a query fails.
		conn.close()

	# Quick addon: Run 'rank a subject' %)
	rank_a_subject(resultfield,myparams)

	return

	
# Didn't use this in the end - needs to be modified as this approach to looking across models won't work.	
def pick_a_subject(myparams):
	"""
	Store, for each message, the name of the probability column with the highest value.

	Reads the eight prob_* columns for the selected set and writes the name of
	the largest one to likely_label_subject.  When several columns tie, the one
	listed first in `probfields` is chosen.

	myparams -- settings dict; only myparams['mode'] is read.  'train'
	            (case-insensitive) selects rows with TrainingSet=1, any other
	            value selects rows with TrainingSet=0.
	"""

	probfields=['prob_social_pers', 'prob_admin_planning', 'prob_human_rec','prob_regulatory_accounting',
		'prob_external_relations','prob_deal_trading','prob_info_tech','prob_other_unsure']

	flag = 1 if myparams['mode'].lower()=='train' else 0

	# Column names are fixed above and formatted in; the values are left as
	# %s placeholders for the driver to bind.
	select_sql = "select id, {0} from basedata_training where TrainingSet=%s".format(", ".join(probfields))
	update_sql = "update basedata_training set likely_label_subject=%s where id=%s"

	conn = MySQLdb.connect (**dbsettings)
	try:
		cursor = conn.cursor(MySQLdb.cursors.DictCursor)
		try:
			cursor.execute(select_sql, (flag,))
			pick_set = cursor.fetchall()

			for record in pick_set:
				# Iterating the field list (not the row dict) leaves 'id' out
				# and makes tie-breaking independent of dict ordering.
				pick = max(probfields, key=record.get)
				cursor.execute(update_sql, (pick, record['id']))

			# One commit for the whole batch rather than one per row.
			conn.commit()
		finally:
			cursor.close()
	finally:
		# Release the connection even when a query fails.
		conn.close()

	return



def main():
	"""Train the currently enabled subject models in 'train' mode with all feature groups on."""
	myparams = dict(std_features=True, word_features=True, bigrams=True, mode='train')

	# One (label column, probability column) pair per model.
	# Uncomment an entry to train that model as well.
	models = [
		#('SUBJ_social_or_personal', 'prob_social_pers'),
		#('SUBJ_admin_or_planning', 'prob_admin_planning'),
		#('SUBJ_human_resources', 'prob_human_rec'),
		#('SUBJ_regulatory_or_accounting', 'prob_regulatory_accounting'),
		#('SUBJ_public_relations', 'prob_external_relations'),
		#('SUBJ_deal_or_trading', 'prob_deal_trading'),
		('SUBJ_info_technology', 'prob_info_tech'),
		#('SUBJ_other_unsure', 'prob_other_unsure'),
	]

	for label_column, prob_column in models:
		train_model(label_column, prob_column, myparams)

	#pick_a_subject(myparams)


# Only run when executed as a script, not when imported.
if __name__ == '__main__':
	main()


"""
select * from kiwipycon.basedata_training where TrainingSet=1 order by prob_deal_trading desc;

WORKING STUFF - IGNORE
show columns from `kiwipycon`.`basedata_training`;
SELECT column_name
FROM information_schema.columns
WHERE table_name = 'basedata_training'

id
person
folder
msg_subject
msg_body
from_internal
to_howmany
to_howmany_internal
to_howmany_external
cc_howmany
cc_howmany_internal
cc_howmany_external
perc_words_are_digits
perc_words_are_caps
num_words_in_body
is_forwarded
is_reply
randnum
message_type
SUBJ_social_or_personal
SUBJ_admin_or_planning
SUBJ_human_resources
SUBJ_regulatory_or_accounting
SUBJ_public_relations
SUBJ_deal_or_trading
SUBJ_info_technology
SUBJ_other_unsure
TrainingSet
prob_social_pers
prob_admin_planning
prob_human_rec
prob_regulatory_accounting
prob_external_relations
prob_deal_trading
prob_info_tech
prob_other_unsure
likely_label_subject
likely_label_type
prob_social_pers_RANK
prob_admin_planning_RANK
prob_human_rec_RANK
prob_regulatory_accounting_RANK
prob_external_relations_RANK
prob_deal_trading_RANK
prob_info_tech_RANK
prob_other_unsure_RANK

"""
