Fix for a non-power-of-two number of elements
This commit is contained in:
@@ -2,7 +2,7 @@ PROG=mergeSort
|
||||
ISPC_SRC=mergeSort.ispc
|
||||
#CU_SRC=mergeSort.cu
|
||||
CXX_SRC=mergeSort.cpp mergeSort.cpp
|
||||
PTXCC_REGMAX=32
|
||||
PTXCC_REGMAX=64
|
||||
#PTXCC_FLAGS= -Xptxas=-O3
|
||||
|
||||
# LLVM_GPU=1
|
||||
|
||||
@@ -225,6 +225,8 @@ void generateSampleRanksKernel(
|
||||
uniform int totalProgramCount)
|
||||
{
|
||||
const int pos = taskIndex * programCount + programIndex;
|
||||
cif (pos >= totalProgramCount)
|
||||
return;
|
||||
|
||||
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
||||
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
||||
@@ -285,6 +287,8 @@ void mergeRanksAndIndicesKernel(
|
||||
uniform int totalProgramCount)
|
||||
{
|
||||
int pos = taskIndex * programCount + programIndex;
|
||||
cif (pos >= totalProgramCount)
|
||||
return;
|
||||
|
||||
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
||||
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
||||
@@ -477,8 +481,101 @@ void mergeElementaryIntervalsKernel(
|
||||
|
||||
}
|
||||
|
||||
// Merges elementary intervals of two adjacent segments from srcKey/srcVal
// into dstKey/dstVal, with the interval indices shared between the tasks.
//
// The mergePairs interval indices [0, mergePairs) are cut into contiguous
// chunks of ceil(mergePairs / taskCount) entries, and the task identified by
// taskIndex walks its own chunk one interval at a time. For each interval the
// program instances load up to one element per lane from source segment A
// and from source segment B, obtain a destination slot for every loaded
// element from binarySearchExclusive1 / binarySearchInclusive1 (defined
// elsewhere in this file), and scatter the key/value pairs to the output.
//
// Parameters:
//   mergePairs       number of interval indices to process in total
//   dstKey, dstVal   output key/value arrays
//   srcKey, srcVal   input key/value arrays
//   limitsA, limitsB per-interval start offsets inside segment A / segment B
//   stride           length of segment A; segment B starts `stride` elements
//                    after the segment base
//   N                total element count, used to clip the length of segment B
//
// NOTE(review): only lanes with programIndex < interval length load data, so
// this appears to rely on every interval holding at most programCount
// elements -- confirm against how limitsA/limitsB are produced.
task
void mergeElementaryIntervalsKernel(
    uniform int mergePairs,
    uniform Key_t dstKey[],
    uniform Val_t dstVal[],
    uniform Key_t srcKey[],
    uniform Val_t srcVal[],
    uniform int limitsA[],
    uniform int limitsB[],
    uniform int stride,
    uniform int N)
{
  // Contiguous chunk of interval indices owned by this task.
  const uniform int pairsPerTask = (mergePairs + taskCount - 1) / taskCount;
  const uniform int firstPair    = taskIndex * pairsPerTask;
  const uniform int pairLimit    = min(firstPair + pairsPerTask, mergePairs);

  for (uniform int pair = firstPair; pair < pairLimit; pair++)
  {
    // Index of this interval inside its segment pair, and the element offset
    // at which that segment pair starts.
    // NOTE(review): the mask form presumes (2 * stride) / SAMPLE_STRIDE is a
    // power of two -- confirm.
    const uniform int intervalI   = pair & ((2 * stride) / SAMPLE_STRIDE - 1);
    const uniform int segmentBase = (pair - intervalI) * SAMPLE_STRIDE;

    // Segment-wide parameters; segment B is clipped against N.
    const uniform int elemsA      = stride;
    const uniform int elemsB      = min(stride, N - segmentBase - stride);
    const uniform int samplesA    = getSampleCount(elemsA);
    const uniform int samplesB    = getSampleCount(elemsB);
    const uniform int sampleTotal = samplesA + samplesB;

    // Source range of this interval in each segment. The end offset is the
    // next limits entry unless this is the last interval of the segment pair,
    // in which case it is the segment length.
    const uniform int beginA = limitsA[pair];
    const uniform int beginB = limitsB[pair];
    uniform int finishA = elemsA;
    uniform int finishB = elemsB;
    if (intervalI + 1 < sampleTotal)
    {
      finishA = limitsA[pair + 1];
      finishB = limitsB[pair + 1];
    }
    const uniform int lenSrcA = finishA - beginA;
    const uniform int lenSrcB = finishB - beginB;

    // Output offsets (relative to segmentBase) of the two destination ranges.
    const uniform int outA = beginA + beginB;
    const uniform int outB = outA + lenSrcA;

    // Lanes that own an element of A / of B for this interval.
    const bool hasA = programIndex < lenSrcA;
    const bool hasB = programIndex < lenSrcB;

    // Load the input data; lanes without an element keep undefined values,
    // exactly as in a plain guarded load.
    Key_t keyA, keyB;
    Val_t valA, valB;
    if (hasA)
    {
      const int srcA = segmentBase + beginA + programIndex;
      keyA = srcKey[srcA];
      valA = srcVal[srcA];
    }
    if (hasB)
    {
      const int srcB = segmentBase + stride + beginB + programIndex;
      keyB = srcKey[srcB];
      valB = srcVal[srcB];
    }

    // Slot of each element inside the merged interval: the search result
    // against the other source plus the lane's own index.
    int dstPosA, dstPosB;
    if (hasA)
      dstPosA = binarySearchExclusive1(keyA, keyB, lenSrcB, SAMPLE_STRIDE) + programIndex;
    if (hasB)
      dstPosB = binarySearchInclusive1(keyB, keyA, lenSrcA, SAMPLE_STRIDE) + programIndex;

    // Turn slots into output indices. A slot below lenSrcA goes to the first
    // output range; the slot shifted down by lenSrcA is then tested against
    // lenSrcB for the second range, and that second assignment wins when both
    // tests pass. -1 marks "nothing to store".
    int dstA = -1, dstB = -1;
    if (hasA)
    {
      const int shiftedA = dstPosA - lenSrcA;
      if (dstPosA < lenSrcA)
        dstA = segmentBase + outA + dstPosA;
      if (shiftedA < lenSrcB)
        dstA = segmentBase + outB + shiftedA;
    }
    if (hasB)
    {
      const int shiftedB = dstPosB - lenSrcA;
      if (dstPosB < lenSrcA)
        dstB = segmentBase + outA + dstPosB;
      if (shiftedB < lenSrcB)
        dstB = segmentBase + outB + shiftedB;
    }

    // Store the merged data.
    if (dstA >= 0)
    {
      dstKey[dstA] = keyA;
      dstVal[dstA] = valA;
    }
    if (dstB >= 0)
    {
      dstKey[dstB] = keyB;
      dstVal[dstB] = valB;
    }
  }
}
|
||||
|
||||
|
||||
static inline
|
||||
void mergeElementaryIntervals(
|
||||
uniform int nTasks,
|
||||
uniform Key_t dstKey[],
|
||||
uniform Val_t dstVal[],
|
||||
uniform Key_t srcKey[],
|
||||
@@ -492,6 +589,7 @@ void mergeElementaryIntervals(
|
||||
const uniform int mergePairs = (lastSegmentElements > stride) ? getSampleCount(N) : (N - lastSegmentElements) / SAMPLE_STRIDE;
|
||||
|
||||
|
||||
#if 0
|
||||
launch [mergePairs] mergeElementaryIntervalsKernel(
|
||||
dstKey,
|
||||
dstVal,
|
||||
@@ -501,6 +599,18 @@ void mergeElementaryIntervals(
|
||||
limitsB,
|
||||
stride,
|
||||
N);
|
||||
#else
|
||||
launch [nTasks] mergeElementaryIntervalsKernel(
|
||||
mergePairs,
|
||||
dstKey,
|
||||
dstVal,
|
||||
srcKey,
|
||||
srcVal,
|
||||
limitsA,
|
||||
limitsB,
|
||||
stride,
|
||||
N);
|
||||
#endif
|
||||
sync;
|
||||
}
|
||||
|
||||
@@ -516,6 +626,9 @@ export
|
||||
void openMergeSort()
|
||||
{
|
||||
nTasks = num_cores()*4;
|
||||
#ifdef __NVPTX__
|
||||
nTasks = num_cores()*13;
|
||||
#endif
|
||||
MAX_SAMPLE_COUNT = 8*32 * 131072 / programCount;
|
||||
assert(memPool == NULL);
|
||||
const uniform int nalloc = MAX_SAMPLE_COUNT * 4;
|
||||
@@ -582,7 +695,7 @@ void mergeSort(
|
||||
mergeRanksAndIndices(limitsA, limitsB, ranksA, ranksB, stride, N);
|
||||
|
||||
//Merge elementary intervals
|
||||
mergeElementaryIntervals(oKey, oVal, iKey, iVal, limitsA, limitsB, stride, N);
|
||||
mergeElementaryIntervals(nTasks, oKey, oVal, iKey, iVal, limitsA, limitsB, stride, N);
|
||||
|
||||
if (lastSegmentElements <= stride)
|
||||
foreach (i = 0 ... lastSegmentElements)
|
||||
|
||||
Reference in New Issue
Block a user