#!/usr/bin/python2 -B
# -*- coding: utf-8; mode: Python; indent-tabs-mode: t; tab-width: 4; python-indent: 4 -*-

# Copyright (C) 2012  Olga Yakovleva <yakovleva.o.v@gmail.com>

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os.path
import tempfile
import codecs
import re
import argparse
import subprocess
import collections
import regex

# Lowercase Cyrillic vowel letters; used for vowel membership tests on words
# and pronunciations.
vowels={u"а",u"е",u"ё",u"и",u"о",u"у",u"ы",u"э",u"ю",u"я"}

def find_first_group(parsed_regex,source_string):
	"""Return the index of the subexpression that starts the first group.

	A subexpression qualifies when its recorded location is positive, the
	character of source_string just before that location is "(", and that
	parenthesis is the only "(" in source_string up to the location.

	parsed_regex  -- a regex.sequence_node (checked with an assert)
	source_string -- the pattern text the node locations refer to

	Returns the index into parsed_regex.subexpressions of the first match,
	or -1 when no subexpression qualifies.
	"""
	assert(isinstance(parsed_regex,regex.sequence_node))
	for position,node in enumerate(parsed_regex.subexpressions):
		start=node.location
		if start<=0:
			continue
		if source_string[start-1]!="(":
			continue
		if source_string[:start].count("(")==1:
			return position
	return -1

def get_vowel_positions(word):
	"""Return, in ascending order, the indices of the characters of word found in vowels."""
	return [index for index,letter in enumerate(word) if letter in vowels]

def find_prev_vowel(word,i):
	"""Return the index of the closest vowel at or before position i.

	The search starts at i, clamped to the last index of word, and walks
	towards the beginning of the word.  Returns -1 when no character in
	that range belongs to vowels.
	"""
	position=min(i,len(word)-1)
	while position>=0:
		if word[position] in vowels:
			return position
		position-=1
	return -1

def describe_transformation(word,pron):
	"""Describe how the vowels of word must change to spell out pron.

	The vowels of word and pron are paired up in order.  A vowel of pron
	counts as stressed when the next character of pron is "+".  The result
	is a list of [index_in_word, mark] pairs where mark is:

	  True  -- the vowel is unchanged in pron and stressed
	  u"ё"  -- the word has "е" where pron has "ё"
	  u"Э"  -- the word has "е" where pron has a stressed "э"
	  u"э"  -- the word has "е" where pron has an unstressed "э"

	Unchanged, unstressed vowels produce no entry.  Returns None when word
	has no vowels, when the vowel counts differ, or when a vowel differs
	in any way other than the "е" substitutions listed above.
	"""
	spelled=get_vowel_positions(word)
	if not spelled:
		return None
	spoken=get_vowel_positions(pron)
	if len(spelled)!=len(spoken):
		return None
	last=len(pron)-1
	changes=[]
	for i,j in zip(spelled,spoken):
		written=word[i]
		heard=pron[j]
		stressed=j<last and pron[j+1]=="+"
		if written==heard:
			if stressed:
				changes.append([i,True])
			continue
		# Only "е" may be pronounced as a different vowel.
		if written!=u"е":
			return None
		if heard==u"ё":
			changes.append([i,u"ё"])
		elif heard==u"э":
			changes.append([i,u"Э" if stressed else u"э"])
		else:
			return None
	return changes

def apply_transformation(word,pattern):
	"""Return word with the [index, mark] pairs of pattern applied.

	For each pair the character at index is handled as follows:

	  * "ё" is left untouched;
	  * "е" with a mark other than True is replaced by the mark;
	  * anything else is upper-cased.
	"""
	result=list(word)
	for position,mark in pattern:
		original=word[position]
		if original==u"ё":
			continue
		if mark is not True and original==u"е":
			result[position]=mark
		else:
			result[position]=original.upper()
	return u"".join(result)

def correct_transformation(word,pattern):
	"""Adapt a pattern of [index, mark] pairs so that it fits word.

	Each index of pattern is remapped as follows (first matching rule wins):

	  1. word has a vowel at the index: keep the index; keep the mark only
	     when that vowel is "е", otherwise use True.
	  2. the index is 0: give up and return None.
	  3. the index equals len(word) and word matches the "...ек" regular
	     expression below: mark the nearest preceding vowel with "ё".
	  4. the nearest preceding vowel is the last vowel of word: mark it
	     with True.
	  5. the four characters starting one before the index match one of
	     the consonant-cluster expressions below: mark the vowel one or
	     two positions after the index with True.
	  6. otherwise: mark the nearest preceding vowel with True.

	Returns the new list of pairs, or None when word has no vowels or
	rule 2 applies.
	"""
	final_vowel=find_prev_vowel(word,len(word))
	if final_vowel==-1:
		return None
	word_length=len(word)
	corrected=[]
	for pos,mark in pattern:
		nearest=find_prev_vowel(word,pos)
		if nearest==pos:
			corrected.append([pos,mark if word[pos]==u"е" else True])
			continue
		if pos==0:
			return None
		if pos==word_length and re.match(u".*[адезилнорстуюя]ек$",word):
			corrected.append([nearest,u"ё"])
			continue
		if nearest==final_vowel:
			corrected.append([nearest,True])
			continue
		window=word[pos-1:pos+3]
		if re.match(u"[бвмпф]лю",window) or re.match(u"ст[ия]",window):
			target=pos+1
		elif re.match(u".[йь]ц[аеоуы]",window):
			target=pos+2
		else:
			# NOTE(review): nearest is -1 here when no vowel precedes pos;
			# the recorded index -1 then addresses the last character of
			# word when the pattern is applied -- confirm this is intended.
			target=nearest
		corrected.append([target,True])
	return corrected

def load_dict(inpath):
	"""Read a KOI8-R encoded, whitespace-separated dictionary file.

	Every line is lower-cased and split on whitespace.  A line with one
	token maps that token to "", a line with two tokens maps the first to
	the second, and lines with any other number of tokens are ignored.

	Returns a collections.OrderedDict in file order; a repeated key keeps
	its first position and takes the value seen last.
	"""
	entries=collections.OrderedDict()
	with codecs.open(inpath,"r","koi8-r") as source:
		for raw_line in source:
			fields=raw_line.lower().split()
			if not 1<=len(fields)<=2:
				continue
			entries[fields[0]]=fields[1] if len(fields)==2 else ""
	return entries

def save_dict(lex,outpath,name):
	"""Compile a word -> pronunciation mapping into a foma binary file.

	The entries are first dumped into chunk files named
	"<outpath>.<n>.txt" (n counts from 1, at most 500000 entries per
	file).  Every entry takes two lines, the word and then the
	pronunciation, each with its characters separated by single spaces.
	A "foma -p" process is then fed commands that read every chunk with
	"read spaced-text", union them, define the result as name and save the
	definitions to "<outpath>.bin".

	lex     -- mapping whose iteritems() yields (word, pron) pairs
	outpath -- prefix of the chunk files and of the resulting .bin file
	name    -- identifier under which foma defines the network
	"""
	max_count=500000
	number=0
	count=0
	f_out=None
	# try/finally makes sure the chunk currently being written is closed
	# even when a write fails half-way through.
	try:
		for word,pron in lex.iteritems():
			if f_out is None:
				# Start the next chunk file.
				number+=1
				count=0
				file_name="{}.{}.txt".format(outpath,number)
				f_out=codecs.open(file_name,"w","utf-8")
			f_out.write(u" ".join(word))
			f_out.write("\n")
			f_out.write(u" ".join(pron))
			f_out.write("\n")
			count+=1
			if count==max_count:
				f_out.close()
				f_out=None
	finally:
		if f_out is not None:
			f_out.close()
	# NOTE(review): when lex is empty no chunk file is written, yet foma is
	# still told to read "<outpath>.1.txt"; confirm callers never pass an
	# empty mapping.
	foma=subprocess.Popen(["foma","-p"],stdin=subprocess.PIPE)
	foma.stdin.write("read spaced-text {}.1.txt\n".format(outpath))
	for i in xrange(2,number+1):
		foma.stdin.write("read spaced-text {}.{}.txt\n".format(outpath,i))
		foma.stdin.write("union\n")
	foma.stdin.write("define {}\n".format(name))
	foma.stdin.write("save defined {}.bin\n".format(outpath))
	foma.stdin.write("quit\n")
	foma.wait()

def import_explicit_dict(args):
	"""Convert "explicit.dict" from args.datadir into the ExplicitDict network.

	Every (word, pronunciation) pair that describe_transformation can
	express is turned into a marked-up spelling with apply_transformation;
	the non-empty results are handed to save_dict under the file prefix
	"explicit_dict" and the foma name "ExplicitDict".
	"""
	entries=load_dict(os.path.join(args.datadir,"explicit.dict"))
	converted=collections.OrderedDict()
	for word,pron in entries.iteritems():
		pattern=describe_transformation(word,pron)
		if pattern is None:
			continue
		marked=apply_transformation(word,pattern)
		if not marked:
			continue
		converted[word]=marked
	save_dict(converted,"explicit_dict","ExplicitDict")

def import_lexicon_rules(args):
	"""Translate "lexicon.rules" into a foma script and run it.

	Steps:
	  1. write the keys of "implicit.dict" (from args.datadir) to
	     "base_forms.txt", one per line;
	  2. for each (pattern, base form ending) entry of "lexicon.rules",
	     parse the pattern, split its subexpressions after the first group
	     and emit the foma definitions for RuleN into "lexicon_rules.foma";
	  3. emit a final regex joining all rules with ".P." and a
	     "save stack lexicon_rules.bin" command;
	  4. run "foma -f lexicon_rules.foma" (raises
	     subprocess.CalledProcessError on a non-zero exit status).

	All output files are created in the current working directory.
	"""
	base_forms=load_dict(os.path.join(args.datadir,"implicit.dict"))
	with codecs.open("base_forms.txt","w","utf-8") as listing:
		for base_form in base_forms.iterkeys():
			listing.write(base_form)
			listing.write("\n")
	rules=load_dict(os.path.join(args.datadir,"lexicon.rules"))
	with codecs.open("lexicon_rules.foma","w","utf-8") as script:
		script.write("read text base_forms.txt\n")
		script.write("define BaseForms\n")
		for number,(pattern,base_form_ending) in enumerate(rules.iteritems(),1):
			print("Parsing rule {}".format(number))
			parse_result=regex.parse(pattern)
			tree=parse_result.root
			split_at=find_first_group(tree,pattern)
			assert(split_at!=-1)
			# Everything up to and including the first group is the suffix,
			# the remainder is the ending; "0" stands in for an empty part.
			head=tree.subexpressions[:split_at+1]
			tail=tree.subexpressions[split_at+1:]
			suffix=u" ".join([expr.format_as_foma_regex() for expr in head]) or "0"
			ending=u" ".join([expr.format_as_foma_regex() for expr in tail]) or "0"
			definitions=(
				u"define BaseFormEnding {} ; \n".format(u" ".join(base_form_ending) or "0"),
				u"define Suffix {} ; \n".format(suffix),
				u"define BaseForm BaseForms & [?* Suffix BaseFormEnding] ; \n",
				u"define Stem [BaseForm .o. [BaseFormEnding -> 0 || _ .#. ]].l ; \n",
				u"define Ending {} ; \n".format(ending),
				"define Rule{index} Stem [Ending:BaseFormEnding] ; \n".format(index=number),
			)
			for definition in definitions:
				script.write(definition)
		script.write("regex \n")
		script.write(u" .P. \n".join(u"Rule{index}".format(index=i) for i in xrange(1,len(rules)+1)))
		script.write(" ; \n")
		script.write("save stack lexicon_rules.bin\n")
	subprocess.check_call(["foma","-f","lexicon_rules.foma"])

def import_general_rules(args):
	"""Translate "general.rules" into a foma script and run it.

	Each key of "general.rules" (from args.datadir) is parsed as a regular
	expression and split after its first group.  The rule written to
	"general_rules.foma" inserts a "+" between the two parts; it is
	prefixed with "0" when the parse result contains "start_of_string"
	and with "?*" otherwise.  The final regex joins all rules with ".P."
	and composes them with a rewrite written verbatim below, then the
	stack is saved to "general_rules.bin".  Finally
	"foma -f general_rules.foma" is run (raises
	subprocess.CalledProcessError on a non-zero exit status).
	"""
	rules=load_dict(os.path.join(args.datadir,"general.rules"))
	with codecs.open("general_rules.foma","w","utf-8") as script:
		for number,pattern in enumerate(rules,1):
			print("Parsing rule {}".format(number))
			parse_result=regex.parse(pattern)
			tree=parse_result.root
			split_at=find_first_group(tree,pattern)
			assert(split_at!=-1)
			head=tree.subexpressions[:split_at+1]
			tail=tree.subexpressions[split_at+1:]
			before=u" ".join([expr.format_as_foma_regex() for expr in head]) or "0"
			if "start_of_string" in parse_result:
				prefix="0"
			else:
				prefix="?*"
			after=u" ".join([expr.format_as_foma_regex() for expr in tail]) or "0"
			script.write(u"define Rule{index} {prefix} [{before}] 0:%+ [{after}] ; \n".format(index=number,prefix=prefix,before=before,after=after))
		script.write("regex [\n")
		script.write(u" .P. \n".join("Rule{}".format(number) for number in xrange(1,len(rules)+1)))
		script.write("] .o. \n")
		script.write(u'[[а:А|е:Е|ё:ё|и:И|о:О|у:У|ы:Ы|э:Э|ю:Ю|я:Я] "+":0 -> || _ ] ; \n')
		script.write("save stack general_rules.bin\n")
	subprocess.check_call(["foma","-f","general_rules.foma"])

def expand_implicit_dict(args):
	"""Derive stressed word forms from "implicit.dict" and save them.

	Steps:
	  1. load "implicit.dict" from args.datadir and keep, per base form,
	     the pattern produced by describe_transformation (entries it
	     cannot describe are skipped);
	  2. feed all base forms to "flookup lexicon_rules.bin" and collect
	     its output in a temporary file;
	  3. parse the output lines as "<base form> <word>" pairs; when a word
	     is reported for several base forms the longest base form wins;
	  4. adapt the base form's pattern to each word with
	     correct_transformation and apply it with apply_transformation;
	  5. pass the result to save_dict under the file prefix
	     "implicit_dict" and the foma name "ImplicitDict".
	"""
	source=load_dict(os.path.join(args.datadir,"implicit.dict"))
	patterns=dict()
	for keyword,keyword_pron in source.iteritems():
		pattern=describe_transformation(keyword,keyword_pron)
		if pattern is not None:
			patterns[keyword]=pattern
	rules=collections.OrderedDict()
	# The context manager closes the temporary file as soon as the flookup
	# output has been parsed instead of leaving it open until exit.
	with tempfile.TemporaryFile() as f_out:
		flookup=subprocess.Popen(["flookup","lexicon_rules.bin"],stdin=subprocess.PIPE,stdout=f_out)
		flookup.communicate(u"\n".join(source.keys()).encode("utf-8"))
		f_out.seek(0)
		for line in f_out:
			if line.isspace():
				continue
			# Lines containing "?" are skipped -- presumably flookup's
			# marker for input it could not analyse (TODO confirm).
			if "?" in line:
				continue
			base_form,word=line.decode("utf-8").split()
			if word in rules:
				# Prefer the longest base form reported for a word.
				if len(base_form)>len(rules[word]):
					rules[word]=base_form
			else:
				rules[word]=base_form
	result=collections.OrderedDict()
	for word,base_form in rules.iteritems():
		if base_form not in patterns:
			continue
		pattern=correct_transformation(word,patterns[base_form])
		if pattern is not None:
			result[word]=apply_transformation(word,pattern)
	save_dict(result,"implicit_dict","ImplicitDict")

def merge_dicts(args):
	"""Combine the explicit and implicit dictionaries into "dict.bin".

	Feeds a "foma -p" process the commands listed below: both saved
	definition files are loaded, Dict is defined as the ".P." combination
	of ExplicitDict and ImplicitDict, the regex built from it is pushed on
	the stack and the stack is saved to "dict.bin".  args is not used.
	"""
	commands=(
		"load defined explicit_dict.bin\n",
		"load defined implicit_dict.bin\n",
		"define Dict ExplicitDict .P. ImplicitDict;\n",
		"regex _notid(Dict) .o. Dict ; \n",
		"save stack dict.bin\n",
		"quit\n",
	)
	foma=subprocess.Popen(["foma","-p"],stdin=subprocess.PIPE)
	for command in commands:
		foma.stdin.write(command)
	foma.wait()

def test(args):
	"""Check the compiled transducers against "lexicon.test".

	The expected marked-up spelling of every test word is computed with
	describe_transformation and apply_transformation; words that cannot
	be described or whose spelling stays unchanged are not tested.  The
	remaining words are sent to "flookup -i dict.bin"; words it answers
	are compared and dropped, and whatever is left is sent to
	"flookup -i general_rules.bin" in the same way.  Finally every word
	that got a wrong answer, followed by every word that got no answer at
	all, is printed on its own line.
	"""
	reference=load_dict(os.path.join(args.datadir,"lexicon.test"))
	pending=collections.OrderedDict()
	for word,pron in reference.iteritems():
		pattern=describe_transformation(word,pron)
		if pattern is None:
			continue
		expected=apply_transformation(word,pattern)
		if expected!=word:
			pending[word]=expected
	print("Testing {} words".format(len(pending)))
	errors=[]
	for fst in ("dict.bin","general_rules.bin"):
		flookup=subprocess.Popen(["flookup","-i",fst],stdin=subprocess.PIPE,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
		query=u"\n".join(pending.iterkeys()).encode("utf-8")
		output,_stderr=flookup.communicate(query)
		for line in output.decode("utf-8").split("\n"):
			# Skip blank lines and lines containing "?" (presumably
			# flookup's marker for unanswered input -- TODO confirm).
			if not line or "?" in line:
				continue
			word,pron=line.split()
			if pron!=pending[word]:
				errors.append(word)
			# Answered words are not passed on to the next transducer.
			del pending[word]
	errors.extend(pending.iterkeys())
	for word in errors:
		print(word)

if __name__=="__main__":
	arg_parser=argparse.ArgumentParser(description="Import RuLex")
	arg_parser.add_argument("--datadir",required=True,help="Path to RuLex data directory")
	arg_subparsers=arg_parser.add_subparsers()
	# Sub-command name -> handler; each handler receives the parsed arguments.
	for command,handler in (
		("import-lexicon-rules",import_lexicon_rules),
		("import-general-rules",import_general_rules),
		("import-explicit-dict",import_explicit_dict),
		("expand-implicit-dict",expand_implicit_dict),
		("merge-dicts",merge_dicts),
		("test",test),
	):
		arg_subparsers.add_parser(command).set_defaults(func=handler)
	args=arg_parser.parse_args()
	args.func(args)
