Predicting Mortality of ICU Patients: The PhysioNet/Computing in Cardiology Challenge 2012 1.0.0

File: <base>/sources/alistairewj_at_gmail.com/entry6/NicForest_train.m (10,443 bytes)
function [ f_out  ] = NicForest_train(xtrain, ytrain, opt)
%   NICFOREST_TRAIN     Trains an ensemble of small trees by MCMC.
%
%   f_out = NicForest_train(xtrain,ytrain,opt)
%       Trains a forest on data in xtrain (N observations x P variables)
%       associated with targets in ytrain using parameters specified by opt.
%
%   Fields of opt read here:
%       Family       - 'binomial' uses the Bernoulli criterion (ll); any other
%                      value uses a Gaussian criterion (normll) on
%                      rank-normalised targets
%       Width        - scale of the random perturbations (defaults to 0.2
%                      for 'binomial' when empty)
%       Sigma        - initial noise level for the Gaussian criterion;
%                      estimated from the normalised targets when empty
%       Trees        - number of trees in the forest
%       Iterations   - number of MCMC iterations per restart
%       Save         - the current forest is stored every Save iterations
%       Resets       - number of MCMC restarts
%       UpdatedTrees - number of trees perturbed in each iteration
%       Group        - optional per-observation group label; observations are
%                      weighted by 1/(size of their group)
%       GroupPower   - optional exponent applied to those weights
%
%   Fields of f_out:
%       forests  - count x 11 x (Trees+1) saved forests; the last slice holds
%                  the saved intercept repeated across the 11 columns
%       priors   - count x 2 saved [intercept, sigma]
%       e1       - mean criterion value per save block, one column per restart
%       TrainNaN - per-variable number of non-NaN training values, plus 1
%       xtrain, Family and, for non-binomial families, ytrain/ynum/yrk

%	$LastChangedBy: alistair $
%	$LastChangedDate: 2012-05-30 12:21:30 +0100 (Wed, 30 May 2012) $
%	$Revision: 21 $
%	Originally written on MACI64 by Louis Mayaud, 25-April-2012 14:05:26
%	Contact: alistairewj@gmail.com

%% Initialise parameters
if nargin<3
    error('NicForest_train requires 3 inputs.');
end
[N,P]=size(xtrain);

%=== error check options
% opt = errorChk(opt);
width = opt.Width;
sigma = opt.Sigma;

if strcmp(opt.Family,'binomial')
    likelihoodFcn = @ll;
    %=== default intercept is logit(mean(ytrain)); sigma is not used by ll
    intercept = log(1/(-1+1/mean(ytrain)));
    sigma = 0;
    % width^2*Ntrees/2 ~ var(logit(Pi)) where Pi is the pred from reasonable model.
    if isempty(width)
        width=0.2;
    end
else
    %=== rank-normalise the targets to smooth out non-linearities
    f_out.ytrain = ytrain; % save values
    f_out.ynum = numel(ytrain)+1;
    
    ytrain = tiedrankrelative(ytrain);
    f_out.yrk = ytrain; % save ranks
    % scale ranks by numel+1 (assumes ranks lie in 1..numel(ytrain) - TODO confirm)
    ytrain = ytrain / f_out.ynum;
    
    ytrain = norminv(ytrain,0,1); % map the scaled ranks onto a standard normal
    
    likelihoodFcn = @normll;
    intercept = mean(ytrain);
    if isempty(sigma)
        sigma = sqrt(var(ytrain));
    end
end
    
%=== Extract all variables from opt
Ntrees = opt.Trees;
MCCite=opt.Iterations; % number of Iterations for each MCMC repetition
MCCsave=opt.Save;  % Save forest every MCCsave
MCCres=opt.Resets;  % Number of MCMC restarts
MCCnbr = MCCite*MCCres;
NtreesUpdate = opt.UpdatedTrees;

%=== Hierarchical model parameters (per-observation likelihood weights)
if isempty(opt.Group)
    %=== no specified hierarchy/groupings
    group = ones(N,1);
else
    group=opt.Group;
    groups = unique(group); % get the group numbers
    if numel(groups)==1 % only one group, do not weight anything
        group = ones(N,1);
    else
        group_num = hist(group,groups); % get the number of obs in each group
        
        for k=1:numel(groups)
            % group = (1/number_observations_in_group)
            group(group==groups(k)) = 1/group_num(k);
        end
        
        %=== Add in group power if using
        if ~isempty(opt.GroupPower)
            group = group.^(opt.GroupPower);
        end
    end
end

e1_exact = zeros(MCCsave,1);
e1_average = zeros(floor(MCCite/MCCsave),MCCres);
figure(1); clf; 
%=== Rank and replace training data
f_out.xtrain=xtrain;
xtrain = tiedrankrelative(xtrain);
NaNNbre = sum( ~isnan(xtrain) , 1 ) + 1; %+1 is to symmetrize the ratio number of observation/NaNNbre

%% Normalize data
% squish data by assuming it is drawn from a normal distribution
%=== Scale ranks between 0->1
xtrain_normalized = bsxfun(@rdivide,xtrain,NaNNbre);

%=== Squish into normal distribution!
xtrain_normalized=norminv(xtrain_normalized,0,1);

%=== Initialize storage for saved forests
% Expected number of saved (post burn-in) forests; the same figure is used
% by the consistency warning after sampling.
nSaveExpected = floor(MCCnbr/MCCsave*4/5)+MCCres;
forests=zeros( nSaveExpected, 11, Ntrees );
priors=zeros( nSaveExpected, 2 );
% dim 1: index of the saved forest
% dim 2: the 11 parameters of each tree
% dim 3: tree index (an intercept slice is appended after sampling)

%% Monte carlo
count=0;
for r=1:MCCres
    %=========================%
    %=== RESET MONTE CARLO ===%
    %=========================%
    
    %=== Initialize forest
    [forest adds sums] = initialize_forest( Ntrees , P , ytrain );
    BurnInFlag = false;
    
    %=== Initialize priors + add to adds/sums
    prior = zeros(2,MCCite+1); % 2x(MCCite+1) - [intercept; sigma]
    prior(1,1) = intercept; % initial intercept
    prior(2,1) = sigma; % initial sigma
    
    %=== for efficiency, generate all the random variables used to update
    %priors before the MCMC loop
    % Row 1: additive random-walk steps for the intercept (consumed at
    % "randomize intercept" below; previously these were written to row 2
    % and immediately overwritten, so the intercept never moved).
    prior(1,2:end) = normrnd(0,width,1,size(prior,2)-1);
    % Row 2: multiplicative factors (mean 1) for sigma
    prior(2,2:end) = gamrnd(1000000,1/1000000,1,size(prior,2)-1);
    
    adds(:,Ntrees+1)=prior(1,1);
    sums=sums+prior(1,1);
    
    %=== Keep track of the variables used in vector
    varUsed = zeros(3,P);
    for k=1:3
        varUsed(k,:) = hist(forest(k,:),1:P); % count number of variables used in forest
    end
    
    %=========================%
    %=== BEGIN MONTE CARLO ===%
    %=========================%
    for j=1:MCCite % for each MCMC iteration
        % Create the next forest candidate
        cForest = forest;
        prior(1,j+1)=prior(1,j)+prior(1,j+1); % randomize intercept
        prior(2,j+1)=abs(prior(2,j)*prior(2,j+1)); % randomize sigma
        
        
        v=zeros(N,NtreesUpdate);
        tempsums=sums+prior(1,j+1)-prior(1,j); % apply the change in intercept
        UpdatedTreesNbr= randsample(Ntrees,NtreesUpdate,0); % distinct random trees to update
        
        % each iteration considers NtreesUpdate random trees to update
        for i=1:NtreesUpdate
            % Update tree in the forest
            cForest(:,UpdatedTreesNbr(i)) = update_tree( forest(:,UpdatedTreesNbr(i)) , .5 , width , P, varUsed) ;
            
            % Apply tree
            v(:,i) = apply_tree_quick( cForest(:,UpdatedTreesNbr(i)), NaNNbre, xtrain, xtrain, xtrain_normalized, xtrain_normalized);
%             v(:,i) = apply_tree( cForest(:,UpdatedTreesNbr(i)) , xtrain , NaNNbre, xtrain_normalized );
            
            % temp score vector where only the changed trees are updated;
            % adds holds the old contribution of each tree
            tempsums=tempsums+v(:,i)-adds(:,UpdatedTreesNbr(i));
        end
        
        % Compute the criterion (negative log likelihood) for current and
        % candidate forest
        e1=feval(likelihoodFcn,sums,ytrain,prior(2,j),group); % previous forest
        e2=feval(likelihoodFcn,tempsums,ytrain,prior(2,j+1),group); % candidate forest
        
        e1_exact(mod(j-1,MCCsave)+1) = e1;
        
        % Metropolis accepting step
        if(rand(1)<exp(e1-e2))
            varRemove = forest(1:3,UpdatedTreesNbr);
            varAdd = cForest(1:3,UpdatedTreesNbr);
            for i=1:length(UpdatedTreesNbr)
                % accept the new forests
                forest(:,UpdatedTreesNbr(i)) = cForest(:,UpdatedTreesNbr(i));
                adds(:,UpdatedTreesNbr(i))=v(:,i); %update adds
                
                % update the variable count tracking
                idxUpdate = sub2ind([3,P],[1;2;3],varRemove(:,i));
                varUsed(idxUpdate) = varUsed(idxUpdate) - 1;
                idxUpdate = sub2ind([3,P],[1;2;3],varAdd(:,i));
                varUsed(idxUpdate) = varUsed(idxUpdate) + 1;
            end
            
            sums=tempsums; %update sums
        else
            %=== rejected, reset prior to previous value
            prior(:,j+1) = prior(:,j);
        end
        
        % save forests every MCCsave
        if(mod(j,MCCsave)==0)
            
            e1_average( floor(j/MCCsave), r ) = mean(e1_exact);
            set(0, 'CurrentFigure', 1); plot(e1_average);
            drawnow;
            if ~BurnInFlag && j>= (MCCite/5) % >= 20% into sampling
                % if MCCsave = 2000, MCCres = 40, then never in burn-in period
                BurnInFlag=true; % no longer in burn-in period.
            end
            if BurnInFlag
                count=count+1;
                forests(count,:,:)=forest;
                priors(count,:) = prior(:,j+1);
            end
            fprintf('%2.0f%% complete. MCMC %g: e1=%3.3f\n',(j+(r-1)*MCCite)/MCCnbr*100, r, e1_average( floor(j/MCCsave), r ));
        end
    end
end

if count ~= nSaveExpected
    warning('Missing forests here');
end
    
%=== Add intercept prior to end of forest matrix
forests = cat(3,forests,repmat(priors(:,1),1,11));
    

%=== Remove unused preallocated rows, which occur when parameters Ntrees +
%MCMC iterations don't match up (keep forests and priors aligned)
if count < size(forests,1)
    forests = forests(1:count,:,:);
    priors = priors(1:count,:);
end
f_out.forests = forests;
f_out.priors = priors;
f_out.e1 = e1_average;
f_out.TrainNaN = NaNNbre;
f_out.Family = opt.Family;
end

function [p] = ll(pred,tar,sigma,group) % negative log likelihood (Bernoulli)
% LL  Weighted Bernoulli negative log-likelihood of logit-scale predictions.
%   p = ll(pred,tar,sigma,group) returns
%       sum( group .* ( -tar.*log(s) - (1-tar).*log(1-s) ) ),  s = 1./(1+exp(-pred))
%   computed in the algebraically identical form
%       sum( group .* ( log(1+exp(pred)) - tar.*pred ) )
%   using a numerically stable softplus. Large |pred| therefore gives a
%   large finite penalty instead of log(0) = -Inf, which previously turned
%   into NaN through 0*Inf.
%   sigma is unused; it is accepted so that ll and normll share a signature.
softplus = max(pred,0) + log1p(exp(-abs(pred))); % log(1+exp(pred)) without overflow
p = sum((softplus - tar.*pred).*group);

end

function [p] = normll(pred,tar,sigma,group) % negative log likelihood (Gaussian)
% NORMLL  Weighted Gaussian negative log-likelihood (constant term dropped).
%   p = sum(group)*log(sigma) + 0.5*sum(group.*(tar-pred).^2)/sigma^2
resid = tar - pred;
weightedSSE = sum((resid.^2).*group);
p = log(sigma)*sum(group) + 0.5*weightedSSE/(sigma^2);
end

function [s] = invlogit(z) % inverse logit
% INVLOGIT  Element-wise logistic sigmoid: s = 1./(1+exp(-z)).
s = 1 ./ (1 + exp(-z));
end

function [forest adds sums] = initialize_forest( Ntrees , Nvar , ytrain )
% INITIALIZE_FOREST  Build a random starting forest and zeroed score buffers.
%   forest : 11 x Ntrees; rows 1-3 hold variable indices sampled with
%            replacement from 1:Nvar, rows 7-9 are 0.5, all other rows 0
%   adds   : length(ytrain) x (Ntrees+1) zeros
%   sums   : length(ytrain) x 1 zeros
nObs = length(ytrain);
forest = zeros(11,Ntrees);
for row = 1:3
    forest(row,1:Ntrees) = randsample(1:Nvar,Ntrees,1);
end
forest(7:9,:) = 0.5;
adds = zeros(nObs,Ntrees+1);
sums = zeros(nObs,1);
end

function cT = update_tree(cT,p,width,varNum,varUsed)
% UPDATE_TREE  Propose a perturbed version of one tree for the MCMC step.
%   cT      : tree parameter vector
%               1-3  variable indices for the first three nodes
%               4-5  thresholds for nodes 1 and 2
%               6    slope
%               7-9  missing-value parameters for nodes 1 to 3
%               10   architecture type (1..4), i.e. location of the final node
%               11   intercept
%   p       : probability of drawing an entirely new tree; otherwise the
%             current tree is jittered
%   width   : scale of the Gaussian draws for slope and intercept
%   varNum  : number of candidate variables
%   varUsed : 3 x varNum usage counts, used as sampling weights
leakage = 0.1; % probability of ignoring the usage weights for a variable draw

if ~(rand(1) < p)
    % --- small move: jitter the current tree ---
    cT([4:5 7:9]) = mod(cT([4:5 7:9])+(rand(5,1)-0.5)/2,1);
    cT([6 11]) = 1/2*cT([6 11])+randn(2,1)*sqrt(1-(1/2)^2)*width;
    return
end

% --- large move: draw an entirely new tree ---
% the two split nodes share one weighting built from both split rows
splitWeights = sum(varUsed(1:2,:),1);
for node = 1:2
    if rand(1) < leakage
        cT(node) = randsample(varNum,1,1);
    else
        cT(node) = randsample(varNum,1,1,splitWeights);
    end
end

% the regression node is weighted independently
if rand(1) < leakage
    cT(3) = randsample(varNum,1,1);
else
    cT(3) = randsample(varNum,1,1,varUsed(3,:));
end

cT([4:5 7:9]) = rand(5,1);
cT([6 11]) = randn(2,1)*width;
cT(10) = randsample(1:4,1);
end