%%
%Trigram Experiment. For 'Toward a Characterization of Loss Functions for
%Distribution Learning.'

%load word frequency list
fileID = fopen('unigram_freq.csv','r');
if fileID == -1
    error('Could not open unigram_freq.csv');
end
U = textscan(fileID, '%d %s %f','delimiter', ',', 'EmptyValue', -Inf);
%release the file handle as soon as the data has been read
fclose(fileID);

%first column contains word frequency rank. Second contains count of that
%word per 1 billion words.
%The rank column is converted to double BEFORE concatenating: textscan
%returns int32 for '%d', and [int32, double] produces an int32 matrix,
%which would round the counts and saturate them at intmax('int32').
A = [double(U{1}), U{3}];
%the list of words themselves. row i corresponds to row i in A.
words = char(U{1,2});


%%
%Count the trigram and unigram freqs using full dataset
%Trigram characters are 26 letters of alphabet, numbers, some punctuation, and the space character
%which starts and ends words. Capitals are converted to lowercase and
%accents removed in the word lists
alpha = ['abcdefghijklmnopqrstuvwxyz1234567890.+-#?/]','''',' '];

%read in ~10000 German and 10000 French words
fileID = fopen('german.txt','r');
if fileID == -1
    error('Could not open german.txt');
end
U = textscan(fileID, '%s','delimiter', ' ', 'EmptyValue', -Inf);
fclose(fileID);
germanWords = char(U{1,1});

fileID = fopen('french.txt','r');
if fileID == -1
    error('Could not open french.txt');
end
U = textscan(fileID, '%s %s','delimiter', '\t', 'EmptyValue', -Inf);
fclose(fileID);
frenchWords = char(U{1,2});

%right-pad all three character matrices with spaces to a common width so
%they can be stacked vertically (replaces a hard-coded 11-column pad that
%only worked for one specific pair of input files)
wordWidth = max([size(words,2), size(germanWords,2), size(frenchWords,2)]);
words = [words, repmat(' ', size(words,1), wordWidth - size(words,2))];
germanWords = [germanWords, repmat(' ', size(germanWords,1), wordWidth - size(germanWords,2))];
frenchWords = [frenchWords, repmat(' ', size(frenchWords,1), wordWidth - size(frenchWords,2))];

%stack the German/French ("poison") words, then append them to the English
%list; unique(...,'stable') keeps first occurrences, so a poison word that
%already exists in the English word list is dropped.
poisonWords = [germanWords; frenchWords];
poisonWords = unique(poisonWords,'rows','stable');
words = [words; poisonWords];
words = unique(words,'rows','stable');

%extend the frequency matrix so that each added German/French word gets
%rank 0 and a fixed frequency of 7000.
%NOTE(review): this assumes unique() removed no rows from the English part
%of the list; otherwise rows of A and words would no longer line up.
A = [A; repmat([0,7000],size(words,1)-size(A,1),1)];

%pad each word with two trailing spaces for use in trigram computation
words = [words, repmat(' ', size(words,1),2)];

%trigram counts: tri(a,b,c) accumulates the frequency of every word in
%which symbols alpha(a), alpha(b), alpha(c) appear consecutively
nSym = length(alpha);
tri = zeros(nSym, nSym, nSym);

%translate every (lower-cased) character to its position in alpha once,
%up front; characters that are not in alpha map to 0
[~, symIdx] = ismember(lower(words), alpha);
spaceIdx = find(alpha == ' ');

%run through each word
for j = 1:size(words,1)
    wordFreq = A(j,2);
    %the sliding window starts out as all spaces
    first = spaceIdx;
    second = spaceIdx;
    %run through each letter
    for k = 1:size(words,2)
        third = symIdx(j,k);
        %a trigram containing a character outside alpha is not counted
        if first > 0 && second > 0 && third > 0
            %update trigram count using word frequency
            tri(first,second,third) = tri(first,second,third) + wordFreq;
        end
        %shift the window by one position
        first = second;
        second = third;
    end
end
%manually remove overcounting of '   ' (space is the last symbol of alpha).
tri(end,end,end) = 0;

%%
%Count the trigram and unigram freqs using modified frequencies biased to
%the head of the distribution
ATop = A;
%raising counts to the power 1.4 biases towards higher frequency words.
ATop(:,2) = A(:,2).^1.4;

nSymTop = length(alpha);
triTop = zeros(nSymTop, nSymTop, nSymTop);

%position of every (lower-cased) character in alpha; 0 = not in alpha
[~, symIdxTop] = ismember(lower(words), alpha);
blankIdx = find(alpha == ' ');

%run through each word
for j = 1:size(words,1)
    biasedFreq = ATop(j,2);
    %the sliding window starts out as all spaces
    a = blankIdx;
    b = blankIdx;
    %run through each letter
    for k = 1:size(words,2)
        c = symIdxTop(j,k);
        %a trigram containing a character outside alpha is not counted
        if a > 0 && b > 0 && c > 0
            triTop(a,b,c) = triTop(a,b,c) + biasedFreq;
        end
        %shift the window by one position
        a = b;
        b = c;
    end
end

%triTop = triTop + .000000001*(tri~=0);

%manually remove overcounting of '   ';
triTop(end,end,end) = 0;

%%
%Generate words from different trigram models
%Each row pairs a heading with the trigram count array to sample from.
sampleModels = {'Random Samples from Full Model:', tri; ...
                'Random Samples from Top Words Model:', triTop};

for m = 1:size(sampleModels,1)
    display([sampleModels{m,1},''])
    modelCounts = sampleModels{m,2};
    for t = 1:20
        %every word starts from the two-space context
        cur = '  ';
        for i = 1:30
            %counts of each possible next symbol given the last two symbols.
            %Kept in its own variable (not p) so the word distribution p
            %used by the later cells is never overwritten when cells are
            %re-run out of order.
            nextCounts = modelCounts(find(alpha == cur(i)),find(alpha == cur(i+1)),:);
            nextLetter = randsample(length(alpha),1,true,nextCounts/sum(nextCounts));
            cur = [cur, alpha(nextLetter)];
            if(nextLetter == length(alpha)) %have a space so end
                break;
            end
        end
        display([cur,'']);
    end
end


%%
%Calculate Losses
%word probability distribution (target): normalized word frequencies
p = double(A(:,2))/sum(A(:,2));
%probability of each word under the two trigram models, built up as a
%product of conditional next-symbol probabilities
q = ones(size(p));
qTop = ones(size(p));   %semicolon added: previously echoed the whole vector

display('Random Samples from Target:')
samps = randsample(length(p),20,true,p);
for j=1:20
    display([words(samps(j),:),'']);
end

for j = 1:size(words,1)
    cur = '   ';
    for k = 1:size(words,2)
        %slide the trigram window one character to the right
        cur = [cur(2:3),lower(words(j,k))];
        %an all-space trigram means only trailing padding is left
        if strcmp(cur,'   ')
            break;
        end
        %look up the three symbol positions once per trigram
        i1 = find(alpha == cur(1));
        i2 = find(alpha == cur(2));
        i3 = find(alpha == cur(3));

        %conditional probability of the third symbol given the first two
        marginal = tri(i1,i2,:);
        marginal = marginal/sum(marginal);
        q(j) = q(j)* marginal(i3);

        marginalTop = triTop(i1,i2,:);
        marginalTop = marginalTop/sum(marginalTop);
        qTop(j) = qTop(j)* marginalTop(i3);
    end
end

%per-word surprisal log(1/prob) under the target and the two candidates
surpP = log(1./p);
surpQ = log(1./q);
surpQTop = log(1./qTop);

%plot log probabilities for top 1000 words
topRange = 1:1000;
figure
plot(surpP(topRange))
hold on
plot(surpQ(topRange))
plot(surpQTop(topRange))

%compute and display log loss, sqrt log loss, and log log loss for target
%and two candidates; each loss is an expectation under the target p
logLossP = sum(p.*surpP);
logLossQ = sum(p.*surpQ);
logLossQTop = sum(p.*surpQTop);
display(sprintf('Log Loss p: %f',logLossP));
display(sprintf('Log Loss q: %f',logLossQ));
display(sprintf('Log Loss qTop: %f',logLossQTop));

sqrtLossP = sum(p.*(surpP.^.5));
sqrtLossQ = sum(p.*(surpQ.^.5));
sqrtLossQTop = sum(p.*(surpQTop.^.5));
display(sprintf('Sqrt Log Loss p: %f',sqrtLossP));
display(sprintf('Sqrt Log Loss q: %f',sqrtLossQ));
display(sprintf('Sqrt Log Loss qTop: %f',sqrtLossQTop));

logLogLossP = sum(p.*log(surpP));
logLogLossQ = sum(p.*log(surpQ));
logLogLossQTop = sum(p.*log(surpQTop));
display(sprintf('Log Log Loss p: %f',logLogLossP));
display(sprintf('Log Log Loss q: %f',logLogLossQ));
display(sprintf('Log Log Loss qTop: %f',logLogLossQTop));

%compute cumulative probability mass over the word list (entry j is the
%sum of entries 1..j). cumsum does this in a single pass instead of
%re-summing the whole prefix for every j.
pCum = cumsum(p);
qCum = cumsum(q);
qTopCum = cumsum(qTop);

%display cumulative mass for top 1000 words.
figure
plot(pCum(1:1000),'linewidth',2);
hold on
plot(qCum(1:1000),'linewidth',2);
plot(qTopCum(1:1000),'linewidth',2);

%%
%How many distinct common word samples are generated

%draw 10000 word indices from the target distribution (with replacement)
%and keep each index only once
samps = unique(randsample(length(p),10000,true,p));
%count the sampled indices below the hard-coded cutoff 36664
%NOTE(review): the cutoff presumably marks the end of the English part of
%the word list - confirm against the size of unigram_freq.csv
display('Unique Words Generated By Target:')
nnz(samps < 36664)

%sampling from q (full model) and then qTop (top words model).
%The models are processed in this order so the sequence of random draws is
%the same as with two separate loops.
genModels = {tri, triTop};
genWords = cell(1,2);

for m = 1:2
    modelCounts = genModels{m};
    %one generated word per row, pre-filled with spaces
    %NOTE(review): a generated string can be up to 31 characters long; if
    %size(words,2)+2 is smaller than that the buffer grows on assignment -
    %confirm the word list is wide enough.
    buf = repmat(' ',10000,size(words,2)+2);
    for t = 1:10000
        %every word starts from the two-space context
        cur = '  ';
        for i = 1:29
            %counts of each possible next symbol given the last two symbols.
            %Stored in its own variable: this loop used to assign to p,
            %which destroyed the word distribution p and broke the
            %randsample(length(p),...,p) call when the cell was re-run.
            nextCounts = modelCounts(find(alpha == cur(i)),find(alpha == cur(i+1)),:);
            nextLetter = randsample(length(alpha),1,true,nextCounts/sum(nextCounts));
            cur = [cur, alpha(nextLetter)];
            if(nextLetter == length(alpha)) %have a space so end
                break;
            end
        end
        buf(t,1:length(cur)) = cur;
    end
    genWords{m} = buf;
end

%remove leading 2 spaces
wordsq = genWords{1}(:,3:end);
wordsqTop = genWords{2}(:,3:end);

%see if the generated words are actually in the word list.
%For each generated row, ismember(...,'rows') returns the row index of the
%matching entry in words, or 0 when there is no match. One vectorized call
%per model replaces four full-table scans per generated word (and the
%per-iteration echo of the loop counter).
[~, idq] = ismember(wordsq, words, 'rows');
[~, idqTop] = ismember(wordsqTop, words, 'rows');

%restrict to the English, not French/German words: indices at or above the
%cutoff are zeroed out (0 already means "not found in the word list").
%NOTE(review): 36664 is a hard-coded cutoff, presumably tied to the size of
%the English list - confirm against unigram_freq.csv
englishCutoff = 36664;
idqF = idq .* (idq < englishCutoff);
%fixed: this previously multiplied idq (the full-model indices) by the
%qTop mask, so the qTop count was computed from the wrong model's words
idqTopF = idqTop .* (idqTop < englishCutoff);
%unique English words generated.
display('Unique Words Generated By q:')
sum(unique(idqF) > 0)
display('Unique Words Generated By qTop:')
sum(unique(idqTopF) > 0)

