diff --git a/examples_ptx/radixSort/radixSort.cpp b/examples_ptx/radixSort/radixSort.cpp index af47b44d..f474ff2d 100644 --- a/examples_ptx/radixSort/radixSort.cpp +++ b/examples_ptx/radixSort/radixSort.cpp @@ -48,8 +48,8 @@ int main (int argc, char *argv[]) #pragma omp parallel for for (int i = 0; i < n; i++) { - keys[i].key = drand48() * (1<<30); -// keys[i].val = i; + keys[i].key = ((int)(drand48() * (1<<30))) & 0x00FFFFFF; + keys[i].val = i; } std::random_shuffle(keys, keys + n); @@ -70,7 +70,7 @@ int main (int argc, char *argv[]) { ispcMemcpy(keys, keys_orig, n*sizeof(Key)); reset_and_start_timer(); - ispc::radixSort(n, (int64_t*)keys); + ispc::radixSort(n, (int64_t*)keys, 32); tISPC2 = std::min(tISPC2, get_elapsed_msec()); if (argc != 3) progressbar (i, m); diff --git a/examples_ptx/radixSort/radixSort.ispc b/examples_ptx/radixSort/radixSort.ispc index 4e2a7452..6526a337 100644 --- a/examples_ptx/radixSort/radixSort.ispc +++ b/examples_ptx/radixSort/radixSort.ispc @@ -6,6 +6,7 @@ typedef int64 Key; task void countPass( const uniform Key keysAll[], + uniform Key sortedAll[], const uniform int bit, const uniform int numElements, uniform int countsAll[], @@ -17,7 +18,8 @@ void countPass( const uniform int mask = (1 << NUMBITS) - 1; - const uniform Key * uniform keys = keysAll + blockIdx*blockDim; + const uniform Key * uniform keys = keysAll + blockIdx*blockDim; + uniform Key * uniform sorted = sortedAll + blockIdx*blockDim; uniform int * uniform counts = countsAll + blockIdx*NUMDIGITS; const uniform int nloc = min(numElements - blockIdx*blockDim, blockDim); @@ -27,6 +29,7 @@ void countPass( #if 1 foreach (i = 0 ... nloc) { + sorted[i] = keys[i]; const int key = mask & ((unsigned int)keys[i] >> bit); uniform int skey; if (reduce_equal(key, &skey) == true) @@ -274,7 +277,8 @@ export void radixSort_free() export void radixSort( const uniform int numElements, - uniform Key keys[]) + uniform Key keys[], + const uniform int nBits) { #ifdef __NVPTX__ assert((numBlocks & 3) == 0); /* task granularity on Kepler is 4 */ @@ -290,14 +294,14 @@ export void radixSort( const uniform int blockDim = (numElements + numBlocks - 1) / numBlocks; - for (uniform int bit = 0; bit < 32; bit += NUMBITS) + for (uniform int bit = 0; bit < nBits; bit += NUMBITS) { /* initialize histogram for each digit */ foreach (digit = 0 ... NUMDIGITS) countsGlobal[digit] = 0; /* compute histogram for each digit */ - launch [numBlocks] countPass(keys, bit, numElements, counts, countsGlobal); + launch [numBlocks] countPass(keys, bufKeys, bit, numElements, counts, countsGlobal); sync; /* exclusive scan on global histogram */ @@ -317,17 +321,13 @@ export void radixSort( /* sorting */ launch [numBlocks] sortPass( - keys, bufKeys, + keys, bit, numElements, excScan, sharedCounts); sync; - - uniform Key * uniform tmp = keys; - keys = bufKeys; - bufKeys = tmp; } }