#!/usr/bin/python2

import sys,os
from math import log
import openfst

# Frequency tables filled while scanning the input files:
#   chars  -- occurrences of each character seen as the last symbol of an n-gram
#   counts -- occurrences of each complete n-gram
chars = {}
counts = {}

# argv[1] (output file) and argv[2] (n-gram order) are both read below, so
# at least three argv entries are required before going any further.
if len(sys.argv) < 3:
    sys.stderr.write("usage: %s outfile n files...\n" % sys.argv[0])
    sys.stderr.write("    computes n-graph models (per-character n-grams) with p(c) removed\n")
    sys.exit(1)

# Explicit checks instead of assert: assertions are stripped under
# "python -O", which would silently allow overwriting an existing model.
outfile = sys.argv[1]
if os.path.exists(outfile):
    sys.exit("%s: output file already exists" % outfile)

try:
    n = int(sys.argv[2])
except ValueError:
    sys.exit("n must be an integer, got %r" % sys.argv[2])
if not 1 < n < 30:
    sys.exit("n must satisfy 1 < n < 30, got %d" % n)

# Accumulate n-gram and final-character counts over every input file.
for path in sys.argv[3:]:
    sys.stderr.write("[%s]\n" % path)
    # The with-block closes the file; iterating the handle streams it line
    # by line instead of loading everything with readlines().
    with open(path) as stream:
        for text in stream:
            # The line terminator is not part of the model.  Strip it
            # explicitly so that a final line lacking a newline does not
            # lose its last real character instead.
            if text.endswith("\n"):
                text = text[:-1]
            # Pad with n-1 start markers so the first character of the
            # line already has a full-length context.
            padded = "^" * (n - 1) + text
            for end in range(n, len(padded) + 1):
                ngram = padded[end - n:end]
                last = ngram[-1]
                chars[last] = chars.get(last, 0) + 1
                counts[ngram] = counts.get(ngram, 0) + 1

# Sorted list of the distinct n-grams; sorted() works whether keys() returns
# a list or a view.
ngrams = sorted(counts)
total = sum(counts.values())

# A literal "^" found in the text stays in chars.  The padding prefix is
# never counted there (only the last character of each n-gram window is
# recorded, and that character always lies past the padding), while removing
# the entry would make the chars[...] lookup in the arc-building loop fail
# for every n-gram ending in "^".
ctotal = sum(chars.values())
if ctotal <= 0:
    sys.exit("no characters found in the input files")

# Single pre-formatted argument: prints the same text whether print is the
# Python 2 statement or the Python 3 function.
print("n %s" % n)
print("total %s %s" % (total, ctotal))
print("distinct %s" % len(ngrams))
# Fraction of the len(chars)**n possible n-grams that were actually observed.
print("coverage %s" % (len(ngrams) * 1.0 / len(chars) ** n))

# Transducer that receives one arc per distinct n-gram.
fst = openfst.StdVectorFst()

# NOTE(review): nstates is not read anywhere in the visible code.
nstates = 0
# Maps a context string to the FST state id allocated for it (see getstate).
states = {}

def getstate(s):
    """Return the FST state id associated with key `s`.

    The first request for a given key allocates a fresh state with
    fst.AddState() and records it in the module-level `states` table;
    later requests for the same key return the recorded id.
    """
    known = states.get(s)
    if known is not None:
        return known
    created = fst.AddState()
    states[s] = created
    return created

# The start state is the all-padding context that precedes the first
# character of a line.
start = getstate("^" * (n - 1))
fst.SetStart(start)

# ctotal / total is the same for every arc, so compute it once.
scale = ctotal / float(total)

# One arc per distinct n-gram: it leaves the state keyed by the first n-1
# characters, is labelled with the last character, and enters the state
# keyed by the last n-1 characters.
for ngram in ngrams:
    c = ngram[-1]
    frm = getstate(ngram[:-1])
    to = getstate(ngram[1:])
    # (count(ngram) / total) / (count(c) / ctotal): the relative frequency of
    # the n-gram divided by the relative frequency of its last character
    # (the "p(c) removed" weighting from the usage text).
    arg = counts[ngram] / float(chars[c]) * scale
    cost = -log(arg)
    label = ord(c)
    fst.AddArc(frm, label, label, cost, to)
    # Every state that is the target of an arc is marked final with weight 0.
    fst.SetFinal(to, 0.0)

# Determinize into a fresh FST, minimize that result, and write it out.
det = openfst.StdVectorFst()
openfst.Determinize(fst, det)
openfst.Minimize(det)
det.Write(outfile)
