partially working CUDA version

This commit is contained in:
Evghenii
2014-01-29 13:20:41 +01:00
parent 2f5eb9f6d3
commit 97253354ac
2 changed files with 31 additions and 40 deletions

View File

@@ -6,7 +6,7 @@
typedef long long Key; typedef long long Key;
__device__ int atomic_add_global(int* ptr, int value) __forceinline__ __device__ int atomic_add_global(int* ptr, int value)
{ {
return atomicAdd(ptr, value); return atomicAdd(ptr, value);
} }
@@ -24,13 +24,13 @@ static __device__ __forceinline__ int shfl_scan_add_step(int partial, int up_off
return result; return result;
} }
__device__ int exclusive_scan_add(int value) __forceinline__ __device__ int exclusive_scan_add(int value)
{ {
int mysum = value; int mysum = value;
#pragma unroll #pragma unroll
for(int i = 0; i < 5; ++i) for(int i = 0; i < 5; ++i)
mysum = shfl_scan_add_step(mysum, 1 << i); mysum = shfl_scan_add_step(mysum, 1 << i);
return mysum; return mysum - value;
} }
__global__ __global__
@@ -53,6 +53,7 @@ void countPass(
int * counts = countsAll + blkIdx*NUMDIGITS; int * counts = countsAll + blkIdx*NUMDIGITS;
const int nloc = min(numElements - blkIdx*blkDim, blkDim); const int nloc = min(numElements - blkIdx*blkDim, blkDim);
#pragma unroll 8
for (int digit = programIndex; digit < NUMDIGITS; digit += programCount) for (int digit = programIndex; digit < NUMDIGITS; digit += programCount)
counts[digit] = 0; counts[digit] = 0;
@@ -64,6 +65,7 @@ void countPass(
atomic_add_global(&counts[key], 1); atomic_add_global(&counts[key], 1);
} }
#pragma unroll 8
for (int digit = programIndex; digit < NUMDIGITS; digit += programCount) for (int digit = programIndex; digit < NUMDIGITS; digit += programCount)
atomic_add_global(&countsGlobal[digit], counts[digit]); atomic_add_global(&countsGlobal[digit], counts[digit]);
} }
@@ -90,12 +92,15 @@ void sortPass(
const int mask = (1 << NUMBITS) - 1; const int mask = (1 << NUMBITS) - 1;
const int unitScan = exclusive_scan_add(1);
/* copy digit offset from Gmem to Lmem */ /* copy digit offset from Gmem to Lmem */
__shared__ int digitOffsets[NUMDIGITS]; #if 1
__shared__ int digitOffsets_sh[NUMDIGITS*4];
int *digitOffsets = digitOffsets_sh + warpIdx*NUMDIGITS;
for (int digit = programIndex; digit < NUMDIGITS; digit += programCount) for (int digit = programIndex; digit < NUMDIGITS; digit += programCount)
digitOffsets[digit] = digitOffsetsAll[blkIdx*NUMDIGITS + digit]; digitOffsets[digit] = digitOffsetsAll[blkIdx*NUMDIGITS + digit];
#else
int *digitOffsets = &digitOffsetsAll[blkIdx*NUMDIGITS];
#endif
for (int i = programIndex; i < nloc; i += programCount) for (int i = programIndex; i < nloc; i += programCount)
@@ -104,8 +109,8 @@ void sortPass(
const int key = mask & ((unsigned int)keys[i] >> bit); const int key = mask & ((unsigned int)keys[i] >> bit);
int scatter; int scatter;
/* not a vector friendly loop */ /* not a vector friendly loop */
for (int lane = 0; lane < programCount; lane++) for (int iv = 0; iv < programCount; iv++)
if (programIndex == lane) if (programIndex == iv)
scatter = digitOffsets[key]++; scatter = digitOffsets[key]++;
sorted [scatter] = keys[i]; sorted [scatter] = keys[i];
} }
@@ -128,6 +133,7 @@ void partialScanLocal(
int (* excScanBlock)[NUMDIGITS] = ( int (*)[NUMDIGITS])excScanAll; int (* excScanBlock)[NUMDIGITS] = ( int (*)[NUMDIGITS])excScanAll;
int (* partialSum)[NUMDIGITS] = ( int (*)[NUMDIGITS])partialSumAll; int (* partialSum)[NUMDIGITS] = ( int (*)[NUMDIGITS])partialSumAll;
#pragma unroll 8
for (int digit = programIndex; digit < NUMDIGITS; digit += programCount) for (int digit = programIndex; digit < NUMDIGITS; digit += programCount)
{ {
int prev = bbeg == 0 ? excScanBlock[0][digit] : 0; int prev = bbeg == 0 ? excScanBlock[0][digit] : 0;
@@ -152,10 +158,10 @@ void partialScanGlobal(
const int digit = taskIndex; const int digit = taskIndex;
int carry = 0; int carry = 0;
for (int block = programIndex; block < numBlocks; block += programCount) for (int block = programIndex; block < numBlocks; block += programCount)
if (block < numBlocks)
{ {
const int value = partialSum[block][digit]; const int value = partialSum[block][digit];
const int scan = exclusive_scan_add(value); const int scan = exclusive_scan_add(value);
if (block < numBlocks)
prefixSum[block][digit] = scan + carry; prefixSum[block][digit] = scan + carry;
carry += __shfl(scan+value, programCount-1); carry += __shfl(scan+value, programCount-1);
} }
@@ -175,6 +181,7 @@ void completeScanGlobal(
int (* excScanBlock)[NUMDIGITS] = ( int (*)[NUMDIGITS])excScanAll; int (* excScanBlock)[NUMDIGITS] = ( int (*)[NUMDIGITS])excScanAll;
int (* carryValue)[NUMDIGITS] = ( int (*)[NUMDIGITS])carryValueAll; int (* carryValue)[NUMDIGITS] = ( int (*)[NUMDIGITS])carryValueAll;
#pragma unroll 8
for (int digit = programIndex; digit < NUMDIGITS; digit += programCount) for (int digit = programIndex; digit < NUMDIGITS; digit += programCount)
{ {
const int carry = carryValue[blkIdx][digit]; const int carry = carryValue[blkIdx][digit];
@@ -245,15 +252,6 @@ void radixSort_alloc___export(const int n)
if (programIndex == 0) if (programIndex == 0)
memoryPool = new int[nalloc]; memoryPool = new int[nalloc];
union {int* ptr; int val[2];} t;
t.ptr = memoryPool;
t.val[0] = __shfl(t.val[0], 0);
t.val[1] = __shfl(t.val[1], 0);
memoryPool = t.ptr;
sharedCounts = memoryPool; sharedCounts = memoryPool;
countsGlobal = sharedCounts + nSharedCounts; countsGlobal = sharedCounts + nSharedCounts;
excScan = countsGlobal + nCountsGlobal; excScan = countsGlobal + nCountsGlobal;
@@ -262,7 +260,7 @@ void radixSort_alloc___export(const int n)
prefixSum = partialSum + nPartialSum; prefixSum = partialSum + nPartialSum;
} }
extern "C" __global__ extern "C"
void radixSort_alloc(const int n) void radixSort_alloc(const int n)
{ {
radixSort_alloc___export<<<1,32>>>(n); radixSort_alloc___export<<<1,32>>>(n);
@@ -275,6 +273,7 @@ void radixSort_freeBufKeys()
{ {
if (numElementsBuf > 0) if (numElementsBuf > 0)
{ {
if (programIndex == 0)
delete bufKeys; delete bufKeys;
numElementsBuf = 0; numElementsBuf = 0;
} }
@@ -289,7 +288,7 @@ __global__ void radixSort_free___export()
radixSort_freeBufKeys(); radixSort_freeBufKeys();
} }
extern "C" __global__ extern "C"
void radixSort_free() void radixSort_free()
{ {
radixSort_free___export<<<1,32>>>(); radixSort_free___export<<<1,32>>>();
@@ -312,13 +311,6 @@ __global__ void radixSort___export(
numElementsBuf = numElements; numElementsBuf = numElements;
if (programIndex == 0) if (programIndex == 0)
bufKeys = new Key[numElementsBuf]; bufKeys = new Key[numElementsBuf];
union {Key* ptr; int val[2];} t;
t.ptr = bufKeys;
t.val[0] = __shfl(t.val[0], 0);
t.val[1] = __shfl(t.val[1], 0);
bufKeys = t.ptr;
} }
const int blkDim = (numElements + numBlocks - 1) / numBlocks; const int blkDim = (numElements + numBlocks - 1) / numBlocks;
@@ -336,6 +328,7 @@ __global__ void radixSort___export(
/* exclusive scan on global histogram */ /* exclusive scan on global histogram */
int carry = 0; int carry = 0;
excScan[0] = 0; excScan[0] = 0;
#pragma unroll 8
for (int digit = programIndex; digit < NUMDIGITS; digit += programCount) for (int digit = programIndex; digit < NUMDIGITS; digit += programCount)
{ {
const int value = countsGlobal[digit]; const int value = countsGlobal[digit];
@@ -357,14 +350,15 @@ __global__ void radixSort___export(
excScan); excScan);
sync; sync;
} }
} }
extern "C" __global__
extern "C"
void radixSort( void radixSort(
const int numElements, const int numElements,
Key keys[], Key keys[],
const int nBits) const int nBits)
{ {
cudaDeviceSetCacheConfig ( cudaFuncCachePreferEqual );
radixSort___export<<<1,32>>>(numElements, keys, nBits); radixSort___export<<<1,32>>>(numElements, keys, nBits);
sync; sync;
} }

View File

@@ -63,14 +63,11 @@ void sortPass(
const uniform int mask = (1 << NUMBITS) - 1; const uniform int mask = (1 << NUMBITS) - 1;
const int unitScan = exclusive_scan_add(1);
/* copy digit offset from Gmem to Lmem */ /* copy digit offset from Gmem to Lmem */
uniform int digitOffsets[NUMDIGITS]; uniform int digitOffsets[NUMDIGITS];
foreach (digit = 0 ... NUMDIGITS) foreach (digit = 0 ... NUMDIGITS)
digitOffsets[digit] = digitOffsetsAll[blockIdx*NUMDIGITS + digit]; digitOffsets[digit] = digitOffsetsAll[blockIdx*NUMDIGITS + digit];
foreach (i = 0 ... nloc) foreach (i = 0 ... nloc)
{ {
const int key = mask & ((unsigned int)keys[i] >> bit); const int key = mask & ((unsigned int)keys[i] >> bit);
@@ -237,7 +234,7 @@ export void radixSort_free()
delete memoryPool; delete memoryPool;
memoryPool = NULL; memoryPool = NULL;
radixSort_freeBufKeys; radixSort_freeBufKeys();
} }
export void radixSort( export void radixSort(