function [ConditionalMeans, ConditionalDistributions] = conditionalmeansanddistributions(dataset,supportOfOutcomeY,pi01)
%CONDITIONALMEANSANDDISTRIBUTIONS Trimmed conditional means of the outcome
%in each (instrument, treatment) cell, used for the analytical ATE bounds
%derived in the paper.
%
%For each cell, the empirical distribution of the outcome on
%supportOfOutcomeY is trimmed to a share q of its mass:
%  - keeping the upper q share (left trimming) gives the "max" mean;
%  - keeping the lower q share (right trimming) gives the "min" mean.
%The shares are
%  q11 = (P10-pi01)/P11,  q01 = (P10-pi01)/P10,
%  q10 = (P01-pi01)/P01,  q00 = (P01-pi01)/P00,
%each clipped to [0,1]. (Presumably these are the always-/never-taker shares
%within each cell - TODO confirm against the paper.)
%
%INPUT:  dataset - matrix with three columns (outcome, treatment, instrument);
%                  treatment and instrument are compared against 0 and 1
%        supportOfOutcomeY - row vector of outcome support points.
%                  NOTE(review): the trimming helpers assume it is sorted
%                  ascending with distinct values - confirm with callers.
%                  Observations whose outcome is not in the support are not
%                  counted.
%        pi01 - share of defiers
%OUTPUT: ConditionalMeans - struct with fields y11max, y10max, y01max,
%                  y00max, y11min, y10min, y01min, y00min
%        ConditionalDistributions - struct with fields vec11max, vec10max,
%                  vec01max, vec00max, vec11min, vec10min, vec01min, vec00min:
%                  probability vectors on supportOfOutcomeY' that yield the
%                  corresponding means
%CELL SUFFIXES: "11" = (z=1,t=1), "01" = (z=0,t=1), "10" = (z=1,t=0),
%               "00" = (z=0,t=0).
%LOCAL FUNCTIONS: defined after the end of this function.

y = dataset(:,1);   %Outcome.
t = dataset(:,2);   %Treatment.
z = dataset(:,3);   %Instrument.

%Ptz = probability of Treatment=t given Instrument=z. Note the index order
%(t,z) differs from the (z,t) cell suffixes used for the outcome below.
P11 = sum(z==1 & t==1)/sum(z==1);
P10 = sum(z==0 & t==1)/sum(z==0);
P01 = 1 - P11;
P00 = 1 - P10;

%All four cells share the same support (column vector for a row input).
x = supportOfOutcomeY';

%Empirical probability of each support point within each cell (zeros for
%an empty cell).
f11 = empiricalpmf(y(z==1 & t==1), x);
f01 = empiricalpmf(y(z==0 & t==1), x);
f10 = empiricalpmf(y(z==1 & t==0), x);
f00 = empiricalpmf(y(z==0 & t==0), x);

%Trimming shares. clipshare maps NaN (0/0) and -Inf to 0 and +Inf to 1.
%The previous arithmetic clipping left these as NaN, which crashed the
%trimming helpers whenever a cell probability was zero.
q11 = clipshare((P10 - pi01)/P11);
q01 = clipshare((P10 - pi01)/P10);
q10 = clipshare((P01 - pi01)/P01);
q00 = clipshare((P01 - pi01)/P00);

%Left trimming keeps the upper tail -> largest attainable mean.
vec11max = lefttrimmedmean(f11, x, q11);
vec01max = lefttrimmedmean(f01, x, q01);
vec10max = lefttrimmedmean(f10, x, q10);
vec00max = lefttrimmedmean(f00, x, q00);

%Right trimming keeps the lower tail -> smallest attainable mean.
vec11min = righttrimmedmean(f11, x, q11);
vec01min = righttrimmedmean(f01, x, q01);
vec10min = righttrimmedmean(f10, x, q10);
vec00min = righttrimmedmean(f00, x, q00);

ConditionalMeans.y11max = sum(x.*vec11max);
ConditionalMeans.y10max = sum(x.*vec10max);
ConditionalMeans.y01max = sum(x.*vec01max);
ConditionalMeans.y00max = sum(x.*vec00max);
ConditionalMeans.y11min = sum(x.*vec11min);
ConditionalMeans.y10min = sum(x.*vec10min);
ConditionalMeans.y01min = sum(x.*vec01min);
ConditionalMeans.y00min = sum(x.*vec00min);

ConditionalDistributions.vec11max = vec11max;
ConditionalDistributions.vec10max = vec10max;
ConditionalDistributions.vec01max = vec01max;
ConditionalDistributions.vec00max = vec00max;
ConditionalDistributions.vec11min = vec11min;
ConditionalDistributions.vec10min = vec10min;
ConditionalDistributions.vec01min = vec01min;
ConditionalDistributions.vec00min = vec00min;

end %end of the main function

%_____________________________________________________________________
%Local functions

function f = empiricalpmf(yy, xx)
%Share of the observations yy equal to each support point xx(k); same shape
%as xx. Returns all zeros when yy is empty (instead of NaN from mean([])).
f = zeros(size(xx));
nobs = numel(yy);
if nobs == 0
    return
end
for k = 1:numel(xx)
    f(k) = sum(yy == xx(k))/nobs;
end
end

function q = clipshare(q)
%Clip a trimming share to [0,1]. max() ignores NaN, so NaN maps to 0.
q = min(max(q, 0), 1);
end

function [fr] = righttrimmedmean(ff, yy, qq)
%Distribution keeping the lowest qq share of the mass of ff on the
%(assumed ascending) support yy, renormalized by qq. For qq below the
%degenerate threshold, returns a point mass on the first support point.
tol = 10e-6;                          % note: 10e-6 == 1e-5
if qq < tol
    fr = zeros(size(ff));
    fr(1) = 1;
    return
end
cf = incusum(ff);                     % cf(k) = sum(ff(k:end))
iq = cf >= (1 - qq);
qly = max(yy(iq));                    % highest support point kept
%Keep all mass strictly below qly plus just enough mass at qly to total qq.
fr = ff.*(yy < qly) + (min(cf(iq)) - (1 - qq))*(yy == qly);
fr = fr/qq;
end

function [fl] = lefttrimmedmean(ff, yy, qq)
%Distribution keeping the highest qq share of the mass of ff on the
%(assumed ascending) support yy, renormalized by qq. For qq below the
%degenerate threshold, returns a point mass on the last support point.
tol = 10e-6;                          % note: 10e-6 == 1e-5
if qq < tol
    fl = zeros(size(ff));
    fl(end) = 1;
    return
end
cf = cumsum(ff);                      % cf(k) = sum(ff(1:k))
iq = cf >= (1 - qq);
qly = min(yy(iq));                    % lowest support point kept
%Keep all mass strictly above qly plus just enough mass at qly to total qq.
fl = ff.*(yy > qly) + (min(cf(iq)) - (1 - qq))*(yy == qly);
fl = fl/qq;
end

function ic = incusum(x)
%Reversed cumulative sum of a vector: ic(k) = sum(x(k:end)), same shape
%as x. Mirror image of cumsum, computed in O(n).
ic = cumsum(x(end:-1:1));
ic = ic(end:-1:1);
end