Wednesday, February 28, 2007

Wrapper for svmtrain.exe

% model = libsvmtrain(training_set, labels)
% model = libsvmtrain(training_set, labels, options)
%
% Trains an SVM by dumping the data to a temporary file in LibSVM's text
% format ("label index:value index:value ..."), running the external
% svmtrain tool on that file and returning the text of the model file
% that svmtrain writes.
%
% inputs:
% training_set : numeric matrix with one sample per row and one feature
%                per column
% labels       : array with as many rows as training_set; the first column
%                holds the (integer) label of each sample
% options      : optional scalar struct. Each field listed below that is
%                present is passed to svmtrain as "-<field> <value>";
%                values must be numeric (or logical) scalars. Anything
%                that is not a struct is treated as "no options".
%
% output:
% model : char row vector holding the content of the model file with all
%         carriage returns removed. Empty ('') when svmtrain finished
%         successfully without writing a model file (svmtrain does not
%         write one in cross validation mode, -v).
%
% options:
% -s svm_type : set type of SVM (default 0)
% 0 -- C-SVC
% 1 -- nu-SVC
% 2 -- one-class SVM
% 3 -- epsilon-SVR
% 4 -- nu-SVR
% -t kernel_type : set type of kernel function (default 2)
% 0 -- linear: u'*v
% 1 -- polynomial: (gamma*u'*v + coef0)^degree
% 2 -- radial basis function: exp(-gamma*|u-v|^2)
% 3 -- sigmoid: tanh(gamma*u'*v + coef0)
% 4 -- precomputed kernel (kernel values in training_set_file)
% -d degree : set degree in kernel function (default 3)
% -g gamma : set gamma in kernel function (default 1/k)
% -r coef0 : set coef0 in kernel function (default 0)
% -c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
% -n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
% -p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
% -m cachesize : set cache memory size in MB (default 100)
% -e epsilon : set tolerance of termination criterion (default 0.001)
% -h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
% -b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
% -wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)
%     Use a field named w<label> for this, e.g. options.w1 = 5 produces
%     "-w1 5". Only non-negative integer labels can be expressed this way.
%     A field literally named 'wi' is still passed through as "-wi <value>"
%     so that existing callers keep working.
% -v n: n-fold cross validation mode
%
% example:
% model = libsvmtrain(X, y, struct('s', 0, 't', 2, 'g', 0.5, 'c', 10));
%
% NOTE1: This function actually executes LibSVM's svmtrain tool. For more
% info check out: http://www.csie.ntu.edu.tw/~cjlin/libsvm/
% NOTE2: Although LibSVM supports more types of labels, this wrapper is
% limited to integer labels only.
% NOTE3: This function assumes that the svmtrain executable is located in
% a directory that lies in your system's PATH.
function model = libsvmtrain(training_set, labels, varargin)
if size(training_set,1) ~= size(labels,1)
    error('The training_set and labels should have the same number of rows.');
end

% Normalise the optional argument: no argument, an empty value or a
% non-struct all mean "no options".
if isempty(varargin) || ~isstruct(varargin{1}) || isempty(varargin{1})
    options = struct([]);
else
    options = varargin{1};
    if numel(options) ~= 1
        error('The options argument should be a scalar struct.');
    end
end

% Dump the training set to a temporary file
datafile = tempname;
fid = fopen(datafile, 'wt');
if fid == -1
    error('Could not open the temporary data file %s for writing.', datafile);
end
numSamples = size(training_set, 1);
numFeatures = size(training_set, 2);
% Each row becomes: label, 1, value_1, 2, value_2, ...
dumpData = zeros(numSamples, 2*numFeatures + 1);
dumpData(:, 1) = labels(:, 1);
dumpData(:, 2:2:end) = repmat(1:numFeatures, numSamples, 1);
dumpData(:, 3:2:end) = training_set;
% %.17g keeps full double precision; %f would truncate to 6 decimals and
% turn small feature values into 0.
rowFormat = ['%d' repmat(' %d:%.17g', 1, numFeatures) '\n'];
% fprintf consumes the data column by column, hence the transpose.
fprintf(fid, rowFormat, dumpData');
fclose(fid);

% Build the svmtrain command line from all given options
command = 'svmtrain';
switchNames = {'s', 't', 'd', 'g', 'r', 'c', 'n', 'p', 'm', 'e', 'h', 'b', 'wi', 'v'};
for k = 1:numel(switchNames)
    name = switchNames{k};
    if isfield(options, name)
        command = appendSwitch(command, name, options.(name));
    end
end

% Per-class weights: fields named w<label>, e.g. w1, w2, ...
if isstruct(options)
    optionNames = fieldnames(options);
    for k = 1:numel(optionNames)
        name = optionNames{k};
        if ~isempty(regexp(name, '^w\d+$', 'once'))
            command = appendSwitch(command, name, options.(name));
        end
    end
end

% Run svmtrain and let it write the model to a temporary file. The paths
% are quoted because the temporary directory may contain spaces.
modelfile = tempname;
status = system(sprintf('%s "%s" "%s"', command, datafile, modelfile), '-echo');

% The data file is no longer needed, whatever the outcome of the run.
delete(datafile);

% Load the model temporary file (if svmtrain wrote one) and remove it
model = '';
modelUnreadable = false;
if exist(modelfile, 'file')
    fid = fopen(modelfile, 'r');
    if fid == -1
        modelUnreadable = true;
    else
        model = fread(fid, '*char')';
        fclose(fid);
    end
    delete(modelfile);
end

if status ~= 0
    error('svmtrain failed with exit status %d (command: %s).', status, command);
end
if modelUnreadable
    error('Could not open the model file %s for reading.', modelfile);
end

% The file is read in binary mode, so Windows line endings show up as
% "\r\n"; drop the carriage returns to keep plain "\n" newlines.
model = strrep(model, char(13), '');

return;

% Appends " -<name> <value>" to the command line after validating value.
function command = appendSwitch(command, name, value)
if ~((isnumeric(value) || islogical(value)) && isscalar(value))
    error('Option ''%s'' should be a numeric scalar.', name);
end
% %.17g prints integers without decimals and keeps small values such as
% 1e-7 intact (with %f they would be written as 0.000000).
command = sprintf('%s -%s %.17g', command, name, double(value));
return;

No comments: