Predicting Mortality of ICU Patients: The PhysioNet/Computing in Cardiology Challenge 2012 1.0.0

File: <base>/sources/alistairewj_at_gmail.com/entry8/NicForest_FeatureImportance.m (1,904 bytes)
function [ FI ] = NicForest_FeatureImportance(f_out,xtrain,xtest)
%NICFOREST_FEATUREIMPORTANCE	Feature importance for a NicForest ensemble,
%evaluated on a held-out data set.
%
%	[ FI ] = NicForest_FeatureImportance(f_out,xtrain,xtest) scores each
%	feature by how much the trees that use it vary their output across
%	the rows of xtest:
%	  - For every tree, the output of APPLY_TREE on the ranked test data
%	    is added to the column of the feature that tree uses.
%	  - The importance of a feature is the standard deviation, over the
%	    test rows, of its accumulated column.
%
%	Inputs:
%		f_out       - Bayesian ensemble output of NicForest_train. Only the
%		              field f_out.forests is read.
%		xtrain      - Data the forests were trained on. It serves as the
%		              reference set when ranking xtest, and is used to count
%		              non-missing (non-NaN) values per feature.
%		xtest       - Held-out data not used in training the forests. It
%		              must have the same number of columns (D) as xtrain.
%
%	Outputs:
%		FI          - 1xD vector of feature importance values.
%
%	Notes:
%		f_out.forests is accessed as forests(i,:,j). Element 3 of each such
%		row is used as the index of the feature for that tree.
%		The last slice along dimension 3 is not visited.
%		NOTE(review): presumably that slice does not hold a tree - confirm
%		against NicForest_train.
%
%	Example
%       load PhysionetDataSetA
%       xtrain=data(1:1500,:);
%       xtest=data(1501:end,:);
%       load forest.mat   % NOTE(review): must provide the NicForest_train struct
%		[ fi ] = NicForest_FeatureImportance(forests,xtrain,xtest);
%
%	See also NICFOREST_TRAIN APPLY_TREE TIEDRANKRELATIVE

%	$LastChangedBy: alistair $
%	$LastChangedDate: 2012-05-30 12:21:30 +0100 (Wed, 30 May 2012) $
%	$Revision: 21 $
%	Originally written on GLNXA64 by Alistair Johnson, 09-May-2012 14:58:24
%	Contact: alistairewj@gmail.com

allForests = f_out.forests;

% Stack test rows above training rows and rank them relative to the
% training rows only (the logical mask flags which rows are training).
nTest  = size(xtest,1);
nTrain = size(xtrain,1);
isTrainRow = [false(nTest,1); true(nTrain,1)];
rankedAll = tiedrankrelative([xtest;xtrain], isTrainRow);
testRanks = rankedAll(1:nTest,:);

% Non-missing training values per feature, offset by one
nonMissingCount = sum(~isnan(xtrain), 1) + 1;

% Loop bounds: first dimension indexes forests; all but the last slice of
% the third dimension are treated as trees
nForests = size(allForests,1);
nTreesPerForest = size(allForests,3) - 1;

% Per-sample, per-feature accumulation of tree outputs
contrib = zeros(size(xtest));
for f = 1:nForests
    for t = 1:nTreesPerForest
        tree = allForests(f,:,t);
        featIdx = tree(3);   % feature this tree is credited to
        contrib(:,featIdx) = contrib(:,featIdx) + apply_tree(tree, testRanks, nonMissingCount);
    end
    % Progress report every 10 forests
    if mod(f,10) == 0
        fprintf('%2.2f%% complete.\n', f/nForests*100);
    end
end

% Importance = spread of each feature's accumulated contribution over test rows
FI = std(contrib, [], 1);

end