Predicting Mortality of ICU Patients: The PhysioNet/Computing in Cardiology Challenge 2012 1.0.0
(10,675 bytes)
function [ f_out ] = NicForest_train(xtrain, ytrain, opt)
% NICFOREST_TRAIN Trains an ensemble of small trees by Metropolis MCMC.
%
%   f_out = NicForest_train(xtrain, ytrain, opt)
%
%   Trains a forest on the data in xtrain associated with the targets in
%   ytrain, using the parameters specified by opt.
%
%   Inputs:
%       xtrain - N-by-P matrix of observations (rows) by variables (columns).
%                It is rank transformed column-wise (tiedrankrelative) and
%                mapped through norminv before being used.
%       ytrain - targets. For opt.Family == 'binomial' they are used as-is
%                and mean(ytrain) sets the initial intercept; otherwise they
%                are rank transformed and mapped through norminv.
%       opt    - structure with the fields read below:
%                Family, Width, Sigma, Trees, Iterations, Save, Resets,
%                UpdatedTrees, Group, GroupPower.
%
%   Output (fields of f_out set in this function):
%       xtrain                    - the original (untransformed) xtrain
%       ytrain, ynum, yrk         - non-binomial family only: original
%                                   targets, numel(ytrain)+1 and target ranks
%       forests                   - saved forests, (saved x 11 x Ntrees+1);
%                                   the last slice holds the saved intercept
%       priors                    - saved [intercept, sigma] rows
%       e1                        - mean of the current-forest likelihood
%                                   term per save window, one column per reset
%       TrainNaN                  - per-variable count of non-NaN values + 1
%       Family                    - copy of opt.Family
%       xtrain_rk_normalized      - normalised training data
%       xtrain_rk_normalized_sort - the same, sorted column-wise
%   f_out is passed through BRF_CompactForest (defined elsewhere) before
%   the last two fields are attached.
%
%   Side effects: clears and plots into figure 1, prints progress.
%
% $LastChangedBy: alistair $
% $LastChangedDate: 2012-05-30 12:21:30 +0100 (Wed, 30 May 2012) $
% $Revision: 21 $
% Originally written on MACI64 by Louis Mayaud, 25-April-2012 14:05:26
% Contact: alistairewj@gmail.com

%% Initialise parameters
% Validate the argument count before touching any input.
if nargin<3
    error('NicForest_train requires 3 inputs.');
end
[N,P]=size(xtrain);

%=== error check options
% opt = errorChk(opt);
width = opt.Width;
sigma = opt.Sigma;
if strcmp(opt.Family,'binomial')
    likelihoodFcn = @ll;
    %=== set default intercept+sigma
    % intercept is the logit of the mean target
    intercept = log(1/(-1+1/mean(ytrain)));
    sigma = 0; % sigma is not used by the binomial likelihood
    % width^2*Ntrees/2 ~ var(logit(Pi)) where Pi is the pred from reasonable model.
    if isempty(width)
        width=0.2;
    end
else
    %=== normalize training set to smooth out non-linearities
    f_out.ytrain = ytrain; % save values
    f_out.ynum = numel(ytrain)+1;
    ytrain = tiedrankrelative(ytrain);
    f_out.yrk = ytrain; % save ranks
    ytrain = ytrain / f_out.ynum;
    ytrain = norminv(ytrain,0,1);
    likelihoodFcn = @normll;
    intercept = mean(ytrain);
    if isempty(sigma)
        sigma = sqrt(var(ytrain));
    end
end

%=== Extract all variables from opt
Ntrees = opt.Trees;
MCCite=opt.Iterations; % number of iterations for each MCMC repetition
MCCsave=opt.Save; % save forest every MCCsave iterations
MCCres=opt.Resets; % number of MCMC resets (independent chains)
MCCnbr = MCCite*MCCres; % total number of iterations over all resets
NtreesUpdate = opt.UpdatedTrees; % trees proposed for change per iteration

%=== Hierarchical model parameters
if isempty(opt.Group)
    %=== no specified hierarchy/groupings
    group = ones(N,1);
else
    group=opt.Group;
    groups = unique(group); % get the group numbers
    if numel(groups)==1 % only one group, do not weight anything
        group = ones(N,1);
    else
        group_num = hist(group,groups); % get the number of obs in each group
        % Build the weights in a separate vector: writing them into "group"
        % while still comparing "group" against the labels lets a freshly
        % written weight (e.g. 1 for a singleton group) be mistaken for a
        % later group label and overwritten.
        groupWeight = double(group);
        for k=1:numel(groups)
            % weight = (1/number_observations_in_group)
            groupWeight(group==groups(k)) = 1/group_num(k);
        end
        group = groupWeight;
        %=== Add in group power if using
        if ~isempty(opt.GroupPower)
            group = group.^(opt.GroupPower);
        end
    end
end

e1_exact = zeros(MCCsave,1); % likelihood term of each iteration in the current save window
e1_average = zeros(floor(MCCite/MCCsave),MCCres); % window means, one column per reset
figure(1); clf;

%=== Rank and replace training data
f_out.xtrain=xtrain;
xtrain = tiedrankrelative(xtrain);
NaNNbre = sum( ~isnan(xtrain) , 1 ) + 1; %+1 is to symmetrize the ratio number of observation/NaNNbre

%% Normalize data
% squish data by assuming it is drawn from a normal distribution
%=== Scale ranks between 0->1
xtrain_normalized = bsxfun(@rdivide,xtrain,NaNNbre);
%=== Squish into normal distribution!
xtrain_normalized=norminv(xtrain_normalized,0,1);
xtrain_normalized_sort = sort(xtrain_normalized,1,'ascend');

%=== Initialize storage for the saved forests
% dimensions: saved forest index x 11 parameters per tree x tree index
forests=zeros( floor(MCCnbr/MCCsave*4/5), 11, Ntrees );
priors=zeros( floor(MCCnbr/MCCsave*4/5), 2 ); % saved [intercept, sigma]

%% Monte carlo
count=0; % number of forests saved so far
for r=1:MCCres
    %=========================%
    %=== RESET MONTE CARLO ===%
    %=========================%
    %=== Initialize forest
    [forest, adds, sums] = initialize_forest( Ntrees , P , ytrain );
    BurnInFlag = false;

    %=== Initialize priors + add to adds/sums
    prior = zeros(2,MCCite+1); % 2x(MCCite+1) - [intercept; sigma]
    prior(1,1) = intercept; % initial intercept
    prior(2,1) = sigma; % initial sigma

    %=== for efficiency, generate all the random variables used to update
    %priors before the MCMC loop
    % Row 1 holds the additive intercept steps and row 2 the multiplicative
    % sigma factors consumed inside the loop. The intercept steps were
    % previously written to row 2 and immediately overwritten, which left
    % the intercept frozen at its initial value.
    prior(1,2:end) = normrnd(0,width,1,size(prior,2)-1);
    prior(2,2:end) = gamrnd(1000000,1/1000000,1,size(prior,2)-1);
    % prior(2,2:end) = 0.02*rand(1,size(prior,2)-1) - 0.01; % uniform distribution

    adds(:,Ntrees+1)=prior(1,1); % last column of adds is the intercept contribution
    sums=sums+prior(1,1);

    %=== Keep track of the variables used in vector
    varUsed = zeros(3,P);
    for k=1:3
        varUsed(k,:) = hist(forest(k,:),1:P); % count number of variables used in forest
    end

    %=========================%
    %=== BEGIN MONTE CARLO ===%
    %=========================%
    for j=1:MCCite % for each MCMC iteration
        % Create the next forest candidate
        cForest = forest;
        prior(1,j+1)=prior(1,j)+prior(1,j+1); % randomize intercept
        prior(2,j+1)=abs(prior(2,j)*prior(2,j+1)); % randomize sigma

        v=zeros(N,NtreesUpdate); % contributions of the candidate trees
        tempsums=sums+prior(1,j+1)-prior(1,j); % swap in the candidate intercept
        UpdatedTreesNbr= randsample(Ntrees,NtreesUpdate,0); % trees to update, sampled without replacement

        % each iteration considers NtreesUpdate random trees to update
        for i=1:NtreesUpdate
            % Update tree in the forest
            cForest(:,UpdatedTreesNbr(i)) = update_tree( forest(:,UpdatedTreesNbr(i)) , .5 , width , P, varUsed) ;

            % Apply tree
            v(:,i) = apply_tree_quick( cForest(:,UpdatedTreesNbr(i)), NaNNbre, xtrain, xtrain, xtrain_normalized, xtrain_normalized_sort);
            % v(:,i) = apply_tree( cForest(:,UpdatedTreesNbr(i)) , xtrain , NaNNbre, xtrain_normalized );

            % candidate total: replace the old contribution (adds) of the
            % changed tree with its new contribution
            tempsums=tempsums+v(:,i)-adds(:,UpdatedTreesNbr(i));
        end

        % Evaluate the likelihood term for the current and candidate forest
        % (both likelihood functions in this file return a value where
        % smaller means a better fit)
        e1=feval(likelihoodFcn,sums,ytrain,prior(2,j),group); % current forest
        e2=feval(likelihoodFcn,tempsums,ytrain,prior(2,j+1),group); % candidate forest
        e1_exact(mod(j-1,MCCsave)+1) = e1;

        % Metropolis accepting step
        if(rand(1)<exp(e1-e2))
            varRemove = forest(1:3,UpdatedTreesNbr);
            varAdd = cForest(1:3,UpdatedTreesNbr);
            for i=1:length(UpdatedTreesNbr)
                % accept the new trees
                forest(:,UpdatedTreesNbr(i)) = cForest(:,UpdatedTreesNbr(i));
                adds(:,UpdatedTreesNbr(i))=v(:,i); %update adds

                % update the variable count tracking
                idxUpdate = sub2ind([3,P],[1;2;3],varRemove(:,i));
                varUsed(idxUpdate) = varUsed(idxUpdate) - 1;
                idxUpdate = sub2ind([3,P],[1;2;3],varAdd(:,i));
                varUsed(idxUpdate) = varUsed(idxUpdate) + 1;
            end
            adds(:,Ntrees+1)=prior(1,j+1); % keep intercept column in step with sums
            sums=tempsums; %update sums
        else
            %=== rejected, reset prior to previous value
            prior(:,j+1) = prior(:,j);
        end

        % save forests every MCCsave
        if(mod(j,MCCsave)==0)
            e1_average( floor(j/MCCsave), r ) = mean(e1_exact);
            set(0, 'CurrentFigure', 1); plot(e1_average);
            drawnow;

            if ~BurnInFlag && j>= (MCCite/5) % >= 20% into sampling
                % if MCCsave = 2000, MCCres = 40, then never in burn-in period
                BurnInFlag=true; % no longer in burn-in period.
            end

            if BurnInFlag
                count=count+1;
                forests(count,:,:)=forest;
                priors(count,:) = prior(:,j+1);
            end
            fprintf('%2.0f%% complete. MCMC %g: e1=%3.3f\n',(j+(r-1)*MCCite)/MCCnbr*100, r, e1_average( floor(j/MCCsave), r ));
        end
    end
end

if count ~= (floor(MCCnbr/MCCsave*4/5)+MCCres)
    warning('Missing forests here');
end

%=== Add intercept prior to end of forest matrix
% the saved intercept is replicated over all 11 parameter slots of an
% extra trailing "tree" slice
forests = cat(3,forests,repmat(priors(:,1),1,11));

%=== Remove extra forests which occur when parameters Ntrees + MCMC
%iterations don't match up
if count < size(forests,1)
    forests = forests(1:count,:,:);
end

f_out.forests = forests;
f_out.priors = priors;
f_out.e1 = e1_average;
f_out.TrainNaN = NaNNbre;
f_out.Family = opt.Family;

%=== Compact forest
[ f_out ] = BRF_CompactForest(f_out);
f_out.xtrain_rk_normalized = xtrain_normalized;
f_out.xtrain_rk_normalized_sort = xtrain_normalized_sort;
end
function [p] = ll(pred,tar,sigma,group) % negative log likelihood (binomial)
% LL Weighted negative Bernoulli log-likelihood for logit-scale predictions.
%
%   pred  - predictions on the logit scale
%   tar   - targets, same size as pred
%   sigma - unused; kept so that ll and normll share one calling signature
%   group - per-observation weights
%
%   Returns p = sum( (-tar.*log(q) - (1-tar).*log(1-q)) .* group ) with
%   q = 1./(1+exp(-pred)), evaluated without forming q explicitly.
%
%   Forming q first makes log(q) or log(1-q) equal to -Inf once q rounds
%   to 0 or 1 (|pred| larger than roughly 37), and the product 0*(-Inf)
%   then turns the whole sum into NaN even for a correctly predicted
%   sample. The identities
%       -log(q)   = max(-pred,0) + log1p(exp(-|pred|))
%       -log(1-q) = max( pred,0) + log1p(exp(-|pred|))
%   give the same value and stay finite for every finite pred.
softplusTail = log1p(exp(-abs(pred)));
negLogQ = max(-pred,0) + softplusTail;    % -log(q)
negLog1mQ = max(pred,0) + softplusTail;   % -log(1-q)
p = sum((tar.*negLogQ + (1-tar).*negLog1mQ).*group);
end
function [p] = normll(pred,tar,sigma,group)
% NORMLL Weighted Gaussian likelihood term (smaller means a better fit).
%
%   pred  - predictions
%   tar   - targets, same size as pred
%   sigma - standard deviation of the residuals
%   group - per-observation weights
%
%   Returns log(sigma)*sum(group) plus half the weighted sum of squared
%   residuals divided by sigma^2.
squaredResidual = (tar-pred).^2;
weightedSumSq = sum(squaredResidual.*group);
scaleTerm = log(sigma)*sum(group);
p = scaleTerm + 0.5*weightedSumSq/(sigma^2);
end
function [p] = invlogit(p)
% INVLOGIT Element-wise inverse logit: maps p to 1/(1+exp(-p)).
%   The output has the same size as the input.
denominator = 1+exp(-p);
p = 1./denominator;
end
function [forest, adds, sums] = initialize_forest( Ntrees , Nvar , ytrain )
% INITIALIZE_FOREST Create a starting forest and empty contribution buffers.
%
%   Ntrees - number of trees in the forest
%   Nvar   - number of candidate variables; node variables are drawn from 1:Nvar
%   ytrain - targets; only length(ytrain) is used
%
%   forest - 11-by-Ntrees matrix, one column per tree. Rows 1-3 receive
%            random variable indices, rows 7-9 are set to 0.5, and all
%            other rows are zero.
%   adds   - zeros, length(ytrain)-by-(Ntrees+1)
%   sums   - zeros, length(ytrain)-by-1
nObs = length(ytrain);

% contribution buffers start empty
sums = zeros(nObs,1);
adds = zeros(nObs,Ntrees+1);

% current forest: one 11-parameter column per tree
forest = zeros(11,Ntrees);
for nodeRow = 1:3
    % variable index of this node for every tree, sampled with replacement
    forest(nodeRow,:) = randsample(1:Nvar,Ntrees,true);
end
forest([7 8 9],:) = 0.5;
end
function cT = update_tree(cT,p,width,varNum,varUsed)
% UPDATE_TREE Propose a new version of a single tree.
%
%   cT      - tree parameter vector with the layout
%               1-3 : variable indices for the three nodes
%               4-5 : thresholds for nodes 1 and 2
%               6   : slope
%               7-9 : missing value parameters for nodes 1 to 3
%               10  : architecture type (1,2,3,4), i.e. location of the
%                     final node in the tree
%               11  : intercept
%   p       - probability of drawing an entirely new tree instead of
%             perturbing the existing one
%   width   - scale of the Gaussian draws for slope and intercept
%   varNum  - number of candidate variables (indices 1..varNum)
%   varUsed - 3-by-varNum usage counts, one row per node; used as
%             sampling weights for the node variables
%
%   Returns the proposed tree vector.
leakage = 0.1;                          % chance of ignoring the usage weights
splitWeights = sum(varUsed(1:2,:),1);   % both split nodes share pooled weights
unitIdx = [4:5 7:9];                    % parameters living in [0,1)
gaussIdx = [6 11];                      % slope and intercept

drawNewTree = rand(1) < p;
if ~drawNewTree
    % Stay close to the previous tree: jitter the [0,1) parameters with
    % wrap-around and shrink slope/intercept halfway towards zero before
    % adding fresh noise. Node variables and architecture are kept.
    jitter = (rand(5,1)-0.5)/2;
    cT(unitIdx) = mod(cT(unitIdx)+jitter,1);
    cT(gaussIdx) = 1/2*cT(gaussIdx)+randn(2,1)*sqrt(1-(1/2)^2)*width;
    return;
end

% Entirely new tree. Nodes 1-2 (splits) are weighted by the pooled split
% usage; node 3 (regression variable) is weighted independently by its
% own usage row.
for node = 1:3
    ignoreWeights = rand(1) < leakage;
    if ignoreWeights
        cT(node) = randsample(varNum,1,true);
    elseif node < 3
        cT(node) = randsample(varNum,1,true,splitWeights);
    else
        cT(node) = randsample(varNum,1,true,varUsed(3,:));
    end
end
cT(unitIdx) = rand(5,1);
cT(gaussIdx) = randn(2,1)*width;
cT(10) = randsample(1:4,1);
end