Skip to content

Commit

Permalink
Fixed C{ROC,PRC}Evaluation classes
Browse files Browse the repository at this point in the history
  • Loading branch information
lisitsyn committed Jun 18, 2011
1 parent e8143c5 commit 80f768e
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 41 deletions.
16 changes: 8 additions & 8 deletions examples/undocumented/python_modular/graphical/prc.py
Expand Up @@ -30,23 +30,23 @@
# plot PRC for SVM
subplot(223)
PRC_evaluation=PRCEvaluation()
PRC_evaluation.evaluate(svm.classify(),labels)
PRC_evaluation.evaluate(svm.apply(),labels)
PRC = PRC_evaluation.get_PRC()
plot(PRC[:,0], PRC[:,1])
fill_between(PRC[:,0],PRC[:,1],0,alpha=0.1)
text(0.55,mean(PRC[:,1])/3,'auPRC = %.5f' % PRC_evaluation.get_auPRC())
plot(PRC[0], PRC[1])
fill_between(PRC[0],PRC[1],0,alpha=0.1)
text(0.55,mean(PRC[1])/3,'auPRC = %.5f' % PRC_evaluation.get_auPRC())
grid(True)
xlabel('Precision')
ylabel('Recall')
title('LibSVM (Gaussian kernel, C=%.3f) PRC curve' % svm.get_C1(),size=10)

# plot PRC for LDA
subplot(224)
PRC_evaluation.evaluate(lda.classify(),labels)
PRC_evaluation.evaluate(lda.apply(),labels)
PRC = PRC_evaluation.get_PRC()
plot(PRC[:,0], PRC[:,1])
fill_between(PRC[:,0],PRC[:,1],0,alpha=0.1)
text(0.55,mean(PRC[:,1])/3,'auPRC = %.5f' % PRC_evaluation.get_auPRC())
plot(PRC[0], PRC[1])
fill_between(PRC[0],PRC[1],0,alpha=0.1)
text(0.55,mean(PRC[1])/3,'auPRC = %.5f' % PRC_evaluation.get_auPRC())
grid(True)
xlabel('Precision')
ylabel('Recall')
Expand Down
17 changes: 9 additions & 8 deletions examples/undocumented/python_modular/graphical/roc.py
Expand Up @@ -30,23 +30,24 @@
# plot ROC for SVM
subplot(223)
ROC_evaluation=ROCEvaluation()
ROC_evaluation.evaluate(svm.classify(),labels)
ROC_evaluation.evaluate(svm.apply(),labels)
roc = ROC_evaluation.get_ROC()
plot(roc[:,0], roc[:,1])
fill_between(roc[:,0],roc[:,1],0,alpha=0.1)
text(mean(roc[:,0])/2,mean(roc[:,1])/2,'auROC = %.5f' % ROC_evaluation.get_auROC())
print roc
plot(roc[0], roc[1])
fill_between(roc[0],roc[1],0,alpha=0.1)
text(mean(roc[0])/2,mean(roc[1])/2,'auROC = %.5f' % ROC_evaluation.get_auROC())
grid(True)
xlabel('FPR')
ylabel('TPR')
title('LibSVM (Gaussian kernel, C=%.3f) ROC curve' % svm.get_C1(),size=10)

# plot ROC for LDA
subplot(224)
ROC_evaluation.evaluate(lda.classify(),labels)
ROC_evaluation.evaluate(lda.apply(),labels)
roc = ROC_evaluation.get_ROC()
plot(roc[:,0], roc[:,1])
fill_between(roc[:,0],roc[:,1],0,alpha=0.1)
text(mean(roc[:,0])/2,mean(roc[:,1])/2,'auROC = %.5f' % ROC_evaluation.get_auROC())
plot(roc[0], roc[1])
fill_between(roc[0],roc[1],0,alpha=0.1)
text(mean(roc[0])/2,mean(roc[1])/2,'auROC = %.5f' % ROC_evaluation.get_auROC())
grid(True)
xlabel('FPR')
ylabel('TPR')
Expand Down
11 changes: 4 additions & 7 deletions src/libshogun/evaluation/PRCEvaluation.cpp
Expand Up @@ -68,9 +68,9 @@ float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
tp += 1.0;

// precision (x)
m_PRC_graph[i] = tp/(i+1);
m_PRC_graph[2*i] = tp/(i+1);
// recall (y)
m_PRC_graph[length+i] = tp/pos_count;
m_PRC_graph[2*i+1] = tp/pos_count;
}

// calc auPRC using area under curve
Expand All @@ -83,17 +83,14 @@ float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
return m_auPRC;
}

void CPRCEvaluation::get_PRC(float64_t** result, int32_t* num, int32_t* dim)
SGMatrix<float64_t> CPRCEvaluation::get_PRC()
{
if (!m_computed)
SG_ERROR("Uninitialized, please call evaluate first");

ASSERT(m_PRC_graph);
*num = m_PRC_length;
*dim = 2;

*result = (float64_t*) SG_MALLOC(sizeof(float64_t)*m_PRC_length*2);
memcpy(*result, m_PRC_graph, m_PRC_length*2*sizeof(float64_t));
return SGMatrix<float64_t>(m_PRC_graph,2,m_PRC_length);
}

float64_t CPRCEvaluation::get_auPRC()
Expand Down
8 changes: 3 additions & 5 deletions src/libshogun/evaluation/PRCEvaluation.h
Expand Up @@ -50,12 +50,10 @@ class CPRCEvaluation: public CBinaryClassEvaluation
*/
float64_t get_auPRC();

/** get PRC (swig)
* @param result matrix of PRC graph
* @param num number of points in PRC graph
* @param dim dimensionality (always 2)
/** get PRC
* @return PRC graph matrix
*/
void get_PRC(float64_t** result, int32_t* num, int32_t* dim);
SGMatrix<float64_t> get_PRC();

protected:

Expand Down
13 changes: 5 additions & 8 deletions src/libshogun/evaluation/ROCEvaluation.cpp
Expand Up @@ -88,8 +88,8 @@ float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
if (label != threshold)
{
threshold = label;
m_ROC_graph[j] = fp/neg_count;
m_ROC_graph[j+diff_count+1] = tp/pos_count;
m_ROC_graph[2*j] = fp/neg_count;
m_ROC_graph[2*j+1] = tp/pos_count;
j++;
}

Expand All @@ -100,7 +100,7 @@ float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
}

// add (1,1) to ROC curve
m_ROC_graph[diff_count] = 1.0;
m_ROC_graph[2*diff_count] = 1.0;
m_ROC_graph[2*diff_count+1] = 1.0;

// set ROC length
Expand All @@ -114,17 +114,14 @@ float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
return m_auROC;
}

void CROCEvaluation::get_ROC(float64_t** result, int32_t* num, int32_t* dim)
SGMatrix<float64_t> CROCEvaluation::get_ROC()
{
if (!m_computed)
SG_ERROR("Uninitialized, please call evaluate first");

ASSERT(m_ROC_graph);
*num = m_ROC_length;
*dim = 2;

*result = (float64_t*) SG_MALLOC(sizeof(float64_t)*m_ROC_length*2);
memcpy(*result, m_ROC_graph, m_ROC_length*2*sizeof(float64_t));
return SGMatrix<float64_t>(m_ROC_graph,2,m_ROC_length);
}

float64_t CROCEvaluation::get_auROC()
Expand Down
8 changes: 3 additions & 5 deletions src/libshogun/evaluation/ROCEvaluation.h
Expand Up @@ -54,12 +54,10 @@ class CROCEvaluation: public CBinaryClassEvaluation
*/
float64_t get_auROC();

/** get ROC (swig)
* @param result matrix of ROC graph
* @param num number of points in ROC graph
* @param dim dimensionality (always 2)
/** get ROC
* @return ROC graph matrix
*/
void get_ROC(float64_t** result, int32_t* num, int32_t* dim);
SGMatrix<float64_t> get_ROC();

protected:

Expand Down

0 comments on commit 80f768e

Please sign in to comment.