Fix for a non-power-of-two number of elements
This commit is contained in:
@@ -2,7 +2,7 @@ PROG=mergeSort
|
|||||||
ISPC_SRC=mergeSort.ispc
|
ISPC_SRC=mergeSort.ispc
|
||||||
#CU_SRC=mergeSort.cu
|
#CU_SRC=mergeSort.cu
|
||||||
CXX_SRC=mergeSort.cpp mergeSort.cpp
|
CXX_SRC=mergeSort.cpp mergeSort.cpp
|
||||||
PTXCC_REGMAX=32
|
PTXCC_REGMAX=64
|
||||||
#PTXCC_FLAGS= -Xptxas=-O3
|
#PTXCC_FLAGS= -Xptxas=-O3
|
||||||
|
|
||||||
# LLVM_GPU=1
|
# LLVM_GPU=1
|
||||||
|
|||||||
@@ -225,6 +225,8 @@ void generateSampleRanksKernel(
|
|||||||
uniform int totalProgramCount)
|
uniform int totalProgramCount)
|
||||||
{
|
{
|
||||||
const int pos = taskIndex * programCount + programIndex;
|
const int pos = taskIndex * programCount + programIndex;
|
||||||
|
cif (pos >= totalProgramCount)
|
||||||
|
return;
|
||||||
|
|
||||||
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
||||||
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
||||||
@@ -285,6 +287,8 @@ void mergeRanksAndIndicesKernel(
|
|||||||
uniform int totalProgramCount)
|
uniform int totalProgramCount)
|
||||||
{
|
{
|
||||||
int pos = taskIndex * programCount + programIndex;
|
int pos = taskIndex * programCount + programIndex;
|
||||||
|
cif (pos >= totalProgramCount)
|
||||||
|
return;
|
||||||
|
|
||||||
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
||||||
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
||||||
@@ -477,8 +481,101 @@ void mergeElementaryIntervalsKernel(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Merges pairs of elementary intervals from srcKey/srcVal into dstKey/dstVal.
//
// The mergePairs blocks are split into contiguous ranges, one range per task:
// each task handles ceil(mergePairs / taskCount) consecutive blocks starting
// at taskIndex * blockDim, clamped to mergePairs. The number of launched tasks
// therefore does not have to equal mergePairs.
//
// Parameters:
//   mergePairs       - total number of blocks (interval pairs) to merge.
//   dstKey, dstVal   - output key/value arrays, written at merged positions.
//   srcKey, srcVal   - input key/value arrays.
//   limitsA, limitsB - per-block start offsets of the interval inside the
//                      A segment and the B segment, relative to segment start.
//   stride           - length of segment A; segment B starts stride elements
//                      after segmentBase.
//   N                - presumably the total element count; it is only used
//                      here to clamp the length of segment B (confirm with caller).
task
|
||||||

void mergeElementaryIntervalsKernel(
|
||||||

uniform int mergePairs,
|
||||||

uniform Key_t dstKey[],
|
||||||

uniform Val_t dstVal[],
|
||||||

uniform Key_t srcKey[],
|
||||||

uniform Val_t srcVal[],
|
||||||

uniform int limitsA[],
|
||||||

uniform int limitsB[],
|
||||||

uniform int stride,
|
||||||

uniform int N)
|
||||||

{
|
||||||

// Range of blocks [blockBeg, blockEnd) owned by this task.
const uniform int blockIdx = taskIndex;
|
||||||

// Ceiling division so that all mergePairs blocks are covered by taskCount tasks.
const uniform int blockDim = (mergePairs + taskCount - 1)/taskCount;
|
||||||

const uniform int blockBeg = blockIdx * blockDim;
|
||||||

// Clamp: trailing tasks may get fewer blocks, or none at all.
const uniform int blockEnd = min(blockBeg + blockDim, mergePairs);
|
||||||

|
||||||

for (uniform int block = blockBeg; block < blockEnd; block++)
|
||||||

{
|
||||||

// Index of this block within its segment pair; the mask form assumes
// (2 * stride) / SAMPLE_STRIDE is a power of two.
const int uniform intervalI = block & ((2 * stride) / SAMPLE_STRIDE - 1);
|
||||||

// Element offset of the first block of this segment pair.
const int uniform segmentBase = (block - intervalI) * SAMPLE_STRIDE;
|
||||||

|
||||||

// Set up the per-block parameters (same for all program instances)
|
||||||

|
||||||

const uniform int segmentElementsA = stride;
|
||||||

// Segment B is shorter than stride when fewer than stride elements remain before N.
const uniform int segmentElementsB = min(stride, N - segmentBase - stride);
|
||||||

const uniform int segmentSamplesA = getSampleCount(segmentElementsA);
|
||||||

const uniform int segmentSamplesB = getSampleCount(segmentElementsB);
|
||||||

const uniform int segmentSamples = segmentSamplesA + segmentSamplesB;
|
||||||

|
||||||

const uniform int startSrcA = limitsA[block];
|
||||||

const uniform int startSrcB = limitsB[block];
|
||||||

// The interval ends where the next block's interval starts; the last
// interval of the segment pair ends at the segment length instead.
const uniform int endSrcA = (intervalI + 1 < segmentSamples) ? limitsA[block + 1] : segmentElementsA;
|
||||||

const uniform int endSrcB = (intervalI + 1 < segmentSamples) ? limitsB[block + 1] : segmentElementsB;
|
||||||

const uniform int lenSrcA = endSrcA - startSrcA;
|
||||||

const uniform int lenSrcB = endSrcB - startSrcB;
|
||||||

// Output layout for this block: lenSrcA slots starting at startDstA,
// immediately followed by the slots starting at startDstB.
const uniform int startDstA = startSrcA + startSrcB;
|
||||||

const uniform int startDstB = startDstA + lenSrcA;
|
||||||

|
||||||

// Load main input data: each program instance loads at most one element
// from interval A and one from interval B.
|
||||||

|
||||||

Key_t keyA, keyB;
|
||||||

Val_t valA, valB;
|
||||||

if (programIndex < lenSrcA)
|
||||||

{
|
||||||

keyA = srcKey[segmentBase + startSrcA + programIndex];
|
||||||

valA = srcVal[segmentBase + startSrcA + programIndex];
|
||||||

}
|
||||||

|
||||||

if (programIndex < lenSrcB)
|
||||||

{
|
||||||

keyB = srcKey[segmentBase + stride + startSrcB + programIndex];
|
||||||

valB = srcVal[segmentBase + stride + startSrcB + programIndex];
|
||||||

}
|
||||||

|
||||||

// Compute destination addresses for merge data.
// NOTE(review): binarySearchExclusive1/binarySearchInclusive1 are defined
// outside this view; presumably they return the rank of the key within
// the other interval, so rank + programIndex is the merged position -- confirm.
|
||||||

int dstPosA, dstPosB;
|
||||||

if (programIndex < lenSrcA)
|
||||||

dstPosA = binarySearchExclusive1(keyA, keyB, lenSrcB, SAMPLE_STRIDE) + programIndex;
|
||||||

if (programIndex < lenSrcB)
|
||||||

dstPosB = binarySearchInclusive1(keyB, keyA, lenSrcA, SAMPLE_STRIDE) + programIndex;
|
||||||

|
||||||

|
||||||

// -1 marks "nothing to store" for this program instance.
int dstA = -1, dstB = -1;
|
||||||

// Merged positions that fall into the first lenSrcA output slots.
if (programIndex < lenSrcA && dstPosA < lenSrcA)
|
||||||

dstA = segmentBase + startDstA + dstPosA;
|
||||||

if (programIndex < lenSrcB && dstPosB < lenSrcA)
|
||||||

dstB = segmentBase + startDstA + dstPosB;
|
||||||

|
||||||

// Rebase the positions onto the second output range. Instances that did
// not pass the programIndex guards above hold unset dstPos values; the
// same guards below keep those values from being used.
dstPosA -= lenSrcA;
|
||||||

dstPosB -= lenSrcA;
|
||||||

// Because startDstB == startDstA + lenSrcA, an instance that already got
// an address above is reassigned the very same address here.
if (programIndex < lenSrcA && dstPosA < lenSrcB)
|
||||||

dstA = segmentBase + startDstB + dstPosA;
|
||||||

if (programIndex < lenSrcB && dstPosB < lenSrcB)
|
||||||

dstB = segmentBase + startDstB + dstPosB;
|
||||||

|
||||||

// Store merge data (skipped for instances whose destination is still -1)
|
||||||

if (dstA >= 0)
|
||||||

{
|
||||||

dstKey[dstA] = keyA;
|
||||||

dstVal[dstA] = valA;
|
||||||

}
|
||||||

if (dstB >= 0)
|
||||||

{
|
||||||

dstKey[dstB] = keyB;
|
||||||

dstVal[dstB] = valB;
|
||||||

}
|
||||||

}
|
||||||

|
||||||

}
|
||||||
|
|
||||||
|
|
||||||
static inline
|
static inline
|
||||||
void mergeElementaryIntervals(
|
void mergeElementaryIntervals(
|
||||||
|
uniform int nTasks,
|
||||||
uniform Key_t dstKey[],
|
uniform Key_t dstKey[],
|
||||||
uniform Val_t dstVal[],
|
uniform Val_t dstVal[],
|
||||||
uniform Key_t srcKey[],
|
uniform Key_t srcKey[],
|
||||||
@@ -492,6 +589,7 @@ void mergeElementaryIntervals(
|
|||||||
const uniform int mergePairs = (lastSegmentElements > stride) ? getSampleCount(N) : (N - lastSegmentElements) / SAMPLE_STRIDE;
|
const uniform int mergePairs = (lastSegmentElements > stride) ? getSampleCount(N) : (N - lastSegmentElements) / SAMPLE_STRIDE;
|
||||||
|
|
||||||
|
|
||||||
|
#if 0
|
||||||
launch [mergePairs] mergeElementaryIntervalsKernel(
|
launch [mergePairs] mergeElementaryIntervalsKernel(
|
||||||
dstKey,
|
dstKey,
|
||||||
dstVal,
|
dstVal,
|
||||||
@@ -501,6 +599,18 @@ void mergeElementaryIntervals(
|
|||||||
limitsB,
|
limitsB,
|
||||||
stride,
|
stride,
|
||||||
N);
|
N);
|
||||||
|
#else
|
||||||
|
launch [nTasks] mergeElementaryIntervalsKernel(
|
||||||
|
mergePairs,
|
||||||
|
dstKey,
|
||||||
|
dstVal,
|
||||||
|
srcKey,
|
||||||
|
srcVal,
|
||||||
|
limitsA,
|
||||||
|
limitsB,
|
||||||
|
stride,
|
||||||
|
N);
|
||||||
|
#endif
|
||||||
sync;
|
sync;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -516,6 +626,9 @@ export
|
|||||||
void openMergeSort()
|
void openMergeSort()
|
||||||
{
|
{
|
||||||
nTasks = num_cores()*4;
|
nTasks = num_cores()*4;
|
||||||
|
#ifdef __NVPTX__
|
||||||
|
nTasks = num_cores()*13;
|
||||||
|
#endif
|
||||||
MAX_SAMPLE_COUNT = 8*32 * 131072 / programCount;
|
MAX_SAMPLE_COUNT = 8*32 * 131072 / programCount;
|
||||||
assert(memPool == NULL);
|
assert(memPool == NULL);
|
||||||
const uniform int nalloc = MAX_SAMPLE_COUNT * 4;
|
const uniform int nalloc = MAX_SAMPLE_COUNT * 4;
|
||||||
@@ -582,7 +695,7 @@ void mergeSort(
|
|||||||
mergeRanksAndIndices(limitsA, limitsB, ranksA, ranksB, stride, N);
|
mergeRanksAndIndices(limitsA, limitsB, ranksA, ranksB, stride, N);
|
||||||
|
|
||||||
//Merge elementary intervals
|
//Merge elementary intervals
|
||||||
mergeElementaryIntervals(oKey, oVal, iKey, iVal, limitsA, limitsB, stride, N);
|
mergeElementaryIntervals(nTasks, oKey, oVal, iKey, iVal, limitsA, limitsB, stride, N);
|
||||||
|
|
||||||
if (lastSegmentElements <= stride)
|
if (lastSegmentElements <= stride)
|
||||||
foreach (i = 0 ... lastSegmentElements)
|
foreach (i = 0 ... lastSegmentElements)
|
||||||
|
|||||||
Reference in New Issue
Block a user