Skip to content

Commit

Permalink
applied interface changes to the other cross-validation class
Browse files Browse the repository at this point in the history
  • Loading branch information
karlnapf committed Jul 15, 2011
1 parent b454101 commit ae9e30f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
Expand Up @@ -21,7 +21,7 @@
parameter_list = [[traindat,label_traindat]]

def modelselection_grid_search_simple(traindat=traindat, label_traindat=label_traindat):
from shogun.Evaluation import CrossValidation
from shogun.Evaluation import CrossValidation, CrossValidationResult
from shogun.Evaluation import ContingencyTableEvaluation, ACCURACY
from shogun.Evaluation import StratifiedCrossValidationSplitting
from shogun.Modelselection import GridSearchModelSelection
Expand Down Expand Up @@ -73,7 +73,8 @@ def modelselection_grid_search_simple(traindat=traindat, label_traindat=label_tr

# apply them and print result
best_parameters.apply_to_machine(classifier)
print "accuracy: " + repr(cross_validation.evaluate())
result=cross_validation.evaluate()
result.print_result()

if __name__=='__main__':
print 'GridSearchSimple'
Expand Down
12 changes: 6 additions & 6 deletions src/libshogun/modelselection/GridSearchModelSelection.cpp
Expand Up @@ -40,26 +40,26 @@ CParameterCombination* CGridSearchModelSelection::select_model()
CDynamicObjectArray<CParameterCombination>* combinations=
m_model_parameters->get_combinations();

float64_t best_result;
CrossValidationResult best_result;

CParameterCombination* best_combination=NULL;
if (m_cross_validation->get_evaluation_direction()==ED_MAXIMIZE)
best_result=CMath::ALMOST_NEG_INFTY;
best_result.value=CMath::ALMOST_NEG_INFTY;
else
best_result=CMath::ALMOST_INFTY;
best_result.value=CMath::ALMOST_INFTY;

/* apply all combinations and search for best one */
for (index_t i=0; i<combinations->get_num_elements(); ++i)
{
CParameterCombination* current_combination=combinations->get_element(i);
current_combination->apply_to_parameter(
m_cross_validation->get_machine_parameters());
float64_t result=m_cross_validation->evaluate();
CrossValidationResult result=m_cross_validation->evaluate();

/* check if current result is better, delete old combinations */
if (m_cross_validation->get_evaluation_direction()==ED_MAXIMIZE)
{
if (result>best_result)
if (result.value>best_result.value)
{
if (best_combination)
SG_UNREF(best_combination);
Expand All @@ -75,7 +75,7 @@ CParameterCombination* CGridSearchModelSelection::select_model()
}
else
{
if (result<best_result)
if (result.value<best_result.value)
{
if (best_combination)
SG_UNREF(best_combination);
Expand Down

0 comments on commit ae9e30f

Please sign in to comment.