Skip to content

Commit

Permalink
Add Fisher's exact test for 2x3 tables
Browse files Browse the repository at this point in the history
  • Loading branch information
Soeren Sonnenburg committed Jul 7, 2011
1 parent 52f2f0a commit 5bd07dc
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 155 deletions.
345 changes: 194 additions & 151 deletions src/libshogun/lib/Mathematics.cpp
Expand Up @@ -77,88 +77,228 @@ CMath::~CMath()
#endif
}

#ifdef USE_LOGCACHE
int32_t CMath::determine_logrange()
namespace shogun
{
int32_t i;
float64_t acc=0;
for (i=0; i<50; i++)
{
acc=((float64_t)log(1+((float64_t)exp(-float64_t(i)))));
if (acc<=(float64_t)LOG_TABLE_PRECISION)
break;
}
/** Emit a uint8_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const uint8_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%d%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}

SG_SINFO( "determined range for x in table log(1+exp(-x)) is:%d (error:%G)\n",i,acc);
return i;
/** Emit an int32_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const int32_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%d%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}

int32_t CMath::determine_logaccuracy(int32_t range)
/** Emit an int64_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const int64_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);
    for (int32_t i=0; i<n; i++)
    {
        // int64_t is not necessarily long long (it is long on LP64 targets);
        // cast so the argument type matches the %lld conversion exactly
        SG_SPRINT("%lld%s", static_cast<long long>(vector[i]), i==n-1? "" : ",");
    }
    SG_SPRINT("]\n");
}

SGVector<float64_t> CMath::fishers_exact_test_for_multiple_3x2_tables(SGMatrix<float64_t> tables, float64_t epsilon)
/** Emit a uint64_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const uint64_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);
    for (int32_t i=0; i<n; i++)
    {
        // uint64_t is not necessarily unsigned long long (it is unsigned long on
        // LP64 targets); cast so the argument type matches %llu exactly
        SG_SPRINT("%llu%s", static_cast<unsigned long long>(vector[i]), i==n-1? "" : ",");
    }
    SG_SPRINT("]\n");
}

/** Emit a float32_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * Each element is formatted with "%10.10f".
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const float32_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%10.10f%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}

/** Emit a float64_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * Each element is formatted with "%10.10f".
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const float64_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%10.10f%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}

/** Emit an int32_t matrix through SG_SPRINT, one bracketed row per line.
 *
 * Entry (row, col) is read from matrix[col*rows+row], i.e. the buffer is
 * traversed as column-major storage.
 *
 * @param matrix buffer holding rows*cols entries
 * @param rows number of rows; asserted to be non-negative
 * @param cols number of columns; asserted to be non-negative
 * @param name label placed in front of the matrix
 */
template <>
void CMath::display_matrix(
        const int32_t* matrix, int32_t rows, int32_t cols, const char* name)
{
    ASSERT(rows>=0 && cols>=0);
    SG_SPRINT("%s=[\n", name);

    for (int32_t row=0; row<rows; ++row)
    {
        SG_SPRINT("[");
        for (int32_t col=0; col<cols; ++col)
        {
            // entries within a row are comma separated, except after the last one
            const char* entry_sep=(col==cols-1) ? "" : ",";
            SG_SPRINT("\t%d%s", matrix[col*rows+row], entry_sep);
        }

        // rows are comma separated as well, except after the last one
        const char* row_sep=(row==rows-1) ? "" : ",";
        SG_SPRINT("]%s\n", row_sep);
    }

    SG_SPRINT("]\n");
}

float64_t CMath::fishers_exact_test_for_3x2_table(SGMatrix<float64_t> table, float64_t epsilon)
/** Emit a float64_t matrix through SG_SPRINT, one bracketed row per line.
 *
 * Entry (row, col) is read from matrix[col*rows+row], i.e. the buffer is
 * traversed as column-major storage. Values are formatted with "%lf".
 *
 * @param matrix buffer holding rows*cols entries
 * @param rows number of rows; asserted to be non-negative
 * @param cols number of columns; asserted to be non-negative
 * @param name label placed in front of the matrix
 */
template <>
void CMath::display_matrix(
        const float64_t* matrix, int32_t rows, int32_t cols, const char* name)
{
    ASSERT(rows>=0 && cols>=0);
    SG_SPRINT("%s=[\n", name);

    for (int32_t row=0; row<rows; ++row)
    {
        SG_SPRINT("[");
        for (int32_t col=0; col<cols; ++col)
        {
            // entries within a row are comma separated, except after the last one
            const char* entry_sep=(col==cols-1) ? "" : ",";
            SG_SPRINT("\t%lf%s", (double) matrix[col*rows+row], entry_sep);
        }

        // rows are comma separated as well, except after the last one
        const char* row_sep=(row==rows-1) ? "" : ",";
        SG_SPRINT("]%s\n", row_sep);
    }

    SG_SPRINT("]\n");
}
}

function [nonrand_p rand_p, prob_table] = fisher_p23_fast(obs, epsilon)
/** Run Fisher's exact test on several 2x3 tables stored side by side.
 *
 * The input is read as a sequence of tables of 2*3 consecutive entries each:
 * table i starts at tables.matrix[2*3*i]. The number of tables is
 * tables.num_cols/3 (integer division, so trailing columns that do not
 * form a complete table are ignored).
 *
 * @param tables matrix with 2 rows holding the concatenated tables
 * @return vector with one entry per table, each the result of
 *         fishers_exact_test_for_2x3_table on that table
 */
SGVector<float64_t> CMath::fishers_exact_test_for_multiple_2x3_tables(SGMatrix<float64_t> tables)
{
    // the 2*3*i stride below is only correct when every column has 2 entries
    ASSERT(tables.num_rows==2);

    // view onto one table at a time; it borrows the storage of 'tables'
    SGMatrix<float64_t> table(NULL,2,3);
    int32_t len=tables.num_cols/3;

    SGVector<float64_t> v(len);
    for (int32_t i=0; i<len; i++)
    {
        table.matrix=&tables.matrix[2*3*i];
        v.vector[i]=fishers_exact_test_for_2x3_table(table);
    }
    return v;
}

n = sum(marginals) / 2;
/** Fisher's exact test for a single 2x3 contingency table.
 *
 * All 2x3 tables that share the row and column sums of the observed table
 * are enumerated. The probability of each table is
 * exp(log_nom - sum(lgamma(entry+1))), where log_nom is built from the
 * marginals. The result is the sum of the probabilities of those tables
 * that are no more probable than the observed one.
 *
 * @param table observed table; must have 2 rows and 3 columns. Entry
 *        (r,c) is read from table.matrix[r+2*c].
 * @return summed probability of all tables with the same marginals whose
 *         probability does not exceed that of the observed table
 */
float64_t CMath::fishers_exact_test_for_2x3_table(SGMatrix<float64_t> table)
{
    ASSERT(table.num_rows==2);
    ASSERT(table.num_cols==3);

    // marginals: m[0], m[1] are the two row sums, m[2..4] the three column sums
    const int32_t m_len=3+2;
    float64_t m[3+2];
    m[0]=table.matrix[0]+table.matrix[2]+table.matrix[4];
    m[1]=table.matrix[1]+table.matrix[3]+table.matrix[5];
    m[2]=table.matrix[0]+table.matrix[1];
    m[3]=table.matrix[2]+table.matrix[3];
    m[4]=table.matrix[4]+table.matrix[5];

    // every entry is counted once in a row sum and once in a column sum
    float64_t n = CMath::sum(m, m_len) / 2.0;

    // log of the numerator shared by all tables with these marginals
    float64_t log_nom=-CMath::lgamma(n+1);
    for (int32_t i=0; i<3+2; i++)
        log_nom+=CMath::lgamma(m[i]+1);

    // log of the denominator of the observed table
    float64_t log_denom=0;
    for (int32_t i=0; i<3*2; i++)
        log_denom+=CMath::lgamma(table.matrix[i]+1);

    // probability of the observed table
    float64_t prob_table=CMath::exp(log_nom - log_denom);

    // k (entry (0,0)) runs over 0..dim1 and l (entry (0,1)) never leaves
    // 0..min(m[0],m[3]), so (dim1+1)*(max_l+1) tables is a safe upper bound.
    // Sizing the buffer by max(m)^2 is too small: e.g. [[1,0,0],[0,1,0]]
    // has max(m)=1 but two admissible tables.
    int32_t dim1 = static_cast<int32_t>(CMath::min(m[0], m[2]));
    int32_t max_l = static_cast<int32_t>(CMath::min(m[0], m[3]));
    int32_t max_tables = 0;
    if (dim1>=0 && max_l>=0)
        max_tables = (dim1+1)*(max_l+1);

    int32_t x_len=2*3*max_tables;
    float64_t* x = new float64_t[x_len];
    CMath::fill_vector(x, x_len, 0.0);

    //traverse all possible tables with given m
    int32_t counter = 0;
    for (int32_t k=0; k<=dim1; k++)
    {
        for (int32_t l=CMath::max(0.0,m[0]-m[4]-k); l<=CMath::min(m[0]-k, m[3]); l++)
        {
            // first row is chosen freely, second row follows from the column sums
            float64_t* cur = x + counter*2*3;
            cur[0 + 0*2] = k;
            cur[0 + 1*2] = l;
            cur[0 + 2*2] = m[0] - cur[0 + 0*2] - cur[0 + 1*2];
            cur[1 + 0*2] = m[2] - cur[0 + 0*2];
            cur[1 + 1*2] = m[3] - cur[0 + 1*2];
            cur[1 + 2*2] = m[4] - cur[0 + 2*2];

            counter++;
        }
    }

#ifdef DEBUG_FISHER_TABLE
    SG_SPRINT("log_denom=%g\n", log_denom);
    SG_SPRINT("log_nom=%g\n", log_nom);
    display_vector(m, m_len, "marginals");
    // each enumerated table contributes 2*3 entries
    display_vector(x, counter*2*3, "x");
#endif // DEBUG_FISHER_TABLE

    // probability of every enumerated table
    float64_t* log_denom_vec=new float64_t[counter];
    for (int32_t t=0; t<counter; t++)
    {
        float64_t acc=0.0;
        for (int32_t i=0; i<2; i++)
        {
            for (int32_t j=0; j<3; j++)
                acc+=CMath::lgamma(x[i + j*2 + t*2*3]+1);
        }
        log_denom_vec[t]=CMath::exp(log_nom-acc);
    }

#ifdef DEBUG_FISHER_TABLE
    display_vector(log_denom_vec, counter, "log_denom_vec");
#endif // DEBUG_FISHER_TABLE

    // The observed table is itself among the enumerated ones, but its
    // probability is accumulated in a different order above, so the two
    // values may differ by rounding. A small relative slack keeps it (and
    // exact ties) inside the sum.
    const float64_t rel_tol=1e-7;
    const float64_t threshold=prob_table*(1.0+rel_tol);

    float64_t nonrand_p=0.0;
    for (int32_t i=0; i<counter; i++)
    {
        if (log_denom_vec[i]<=threshold)
            nonrand_p += log_denom_vec[i];
    }

    delete[] log_denom_vec;
    delete[] x;

    return nonrand_p;
}


#ifdef USE_LOGCACHE
int32_t CMath::determine_logrange()
{
int32_t i;
float64_t acc=0;
for (i=0; i<50; i++)
{
acc=((float64_t)log(1+((float64_t)exp(-float64_t(i)))));
if (acc<=(float64_t)LOG_TABLE_PRECISION)
break;
}

SG_SINFO( "determined range for x in table log(1+exp(-x)) is:%d (error:%G)\n",i,acc);
return i;
}

int32_t CMath::determine_logaccuracy(int32_t range)
{
range=MAX_LOG_TABLE_SIZE/range/((int)sizeof(float64_t));
SG_SINFO( "determined accuracy for x in table log(1+exp(-x)) is:%d (error:%G)\n",range,1.0/(double) range);
return range;
}

//init log table of form log(1+exp(x))
Expand Down Expand Up @@ -330,100 +470,3 @@ float64_t* CMath::pinv(
return target;
}
#endif
namespace shogun
{
/** Emit a uint8_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const uint8_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%d%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}
/** Emit an int32_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const int32_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%d%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}
/** Emit an int64_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const int64_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);
    for (int32_t i=0; i<n; i++)
    {
        // int64_t is not necessarily long long (it is long on LP64 targets);
        // cast so the argument type matches the %lld conversion exactly
        SG_SPRINT("%lld%s", static_cast<long long>(vector[i]), i==n-1? "" : ",");
    }
    SG_SPRINT("]\n");
}
/** Emit a uint64_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const uint64_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);
    for (int32_t i=0; i<n; i++)
    {
        // uint64_t is not necessarily unsigned long long (it is unsigned long on
        // LP64 targets); cast so the argument type matches %llu exactly
        SG_SPRINT("%llu%s", static_cast<unsigned long long>(vector[i]), i==n-1? "" : ",");
    }
    SG_SPRINT("]\n");
}
/** Emit a float32_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * Each element is formatted with "%10.10f".
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const float32_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%10.10f%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}
/** Emit a float64_t vector through SG_SPRINT as "name=[v0,v1,...]" followed by a newline.
 *
 * Each element is formatted with "%10.10f".
 *
 * @param vector elements to show (n entries are read)
 * @param n number of elements; asserted to be non-negative
 * @param name label placed in front of the bracketed list
 */
template <>
void CMath::display_vector(const float64_t* vector, int32_t n, const char* name)
{
    ASSERT(n>=0);
    SG_SPRINT("%s=[", name);

    const int32_t last=n-1;
    for (int32_t idx=0; idx<n; ++idx)
    {
        // no separator after the final element
        const char* sep=(idx==last) ? "" : ",";
        SG_SPRINT("%10.10f%s", vector[idx], sep);
    }

    SG_SPRINT("]\n");
}
/** Emit an int32_t matrix through SG_SPRINT, one bracketed row per line.
 *
 * Entry (row, col) is read from matrix[col*rows+row], i.e. the buffer is
 * traversed as column-major storage.
 *
 * @param matrix buffer holding rows*cols entries
 * @param rows number of rows; asserted to be non-negative
 * @param cols number of columns; asserted to be non-negative
 * @param name label placed in front of the matrix
 */
template <>
void CMath::display_matrix(
        const int32_t* matrix, int32_t rows, int32_t cols, const char* name)
{
    ASSERT(rows>=0 && cols>=0);
    SG_SPRINT("%s=[\n", name);

    for (int32_t row=0; row<rows; ++row)
    {
        SG_SPRINT("[");
        for (int32_t col=0; col<cols; ++col)
        {
            // entries within a row are comma separated, except after the last one
            const char* entry_sep=(col==cols-1) ? "" : ",";
            SG_SPRINT("\t%d%s", matrix[col*rows+row], entry_sep);
        }

        // rows are comma separated as well, except after the last one
        const char* row_sep=(row==rows-1) ? "" : ",";
        SG_SPRINT("]%s\n", row_sep);
    }

    SG_SPRINT("]\n");
}
/** Emit a float64_t matrix through SG_SPRINT, one bracketed row per line.
 *
 * Entry (row, col) is read from matrix[col*rows+row], i.e. the buffer is
 * traversed as column-major storage. Values are formatted with "%lf".
 *
 * @param matrix buffer holding rows*cols entries
 * @param rows number of rows; asserted to be non-negative
 * @param cols number of columns; asserted to be non-negative
 * @param name label placed in front of the matrix
 */
template <>
void CMath::display_matrix(
        const float64_t* matrix, int32_t rows, int32_t cols, const char* name)
{
    ASSERT(rows>=0 && cols>=0);
    SG_SPRINT("%s=[\n", name);

    for (int32_t row=0; row<rows; ++row)
    {
        SG_SPRINT("[");
        for (int32_t col=0; col<cols; ++col)
        {
            // entries within a row are comma separated, except after the last one
            const char* entry_sep=(col==cols-1) ? "" : ",";
            SG_SPRINT("\t%lf%s", (double) matrix[col*rows+row], entry_sep);
        }

        // rows are comma separated as well, except after the last one
        const char* row_sep=(row==rows-1) ? "" : ",";
        SG_SPRINT("]%s\n", row_sep);
    }

    SG_SPRINT("]\n");
}
}

0 comments on commit 5bd07dc

Please sign in to comment.