fix for non-pow2 number of elements

This commit is contained in:
Evghenii
2014-01-30 09:05:27 +01:00
parent 4a17760a2d
commit d65e1b30ce
2 changed files with 115 additions and 2 deletions

View File

@@ -2,7 +2,7 @@ PROG=mergeSort
ISPC_SRC=mergeSort.ispc
#CU_SRC=mergeSort.cu
CXX_SRC=mergeSort.cpp mergeSort.cpp
PTXCC_REGMAX=32
PTXCC_REGMAX=64
#PTXCC_FLAGS= -Xptxas=-O3
# LLVM_GPU=1

View File

@@ -225,6 +225,8 @@ void generateSampleRanksKernel(
uniform int totalProgramCount)
{
const int pos = taskIndex * programCount + programIndex;
cif (pos >= totalProgramCount)
return;
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
@@ -285,6 +287,8 @@ void mergeRanksAndIndicesKernel(
uniform int totalProgramCount)
{
int pos = taskIndex * programCount + programIndex;
cif (pos >= totalProgramCount)
return;
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
@@ -477,8 +481,101 @@ void mergeElementaryIntervalsKernel(
}
task
void mergeElementaryIntervalsKernel(
uniform int mergePairs,
uniform Key_t dstKey[],
uniform Val_t dstVal[],
uniform Key_t srcKey[],
uniform Val_t srcVal[],
uniform int limitsA[],
uniform int limitsB[],
uniform int stride,
uniform int N)
{
const uniform int blockIdx = taskIndex;
const uniform int blockDim = (mergePairs + taskCount - 1)/taskCount;
const uniform int blockBeg = blockIdx * blockDim;
const uniform int blockEnd = min(blockBeg + blockDim, mergePairs);
for (uniform int block = blockBeg; block < blockEnd; block++)
{
const int uniform intervalI = block & ((2 * stride) / SAMPLE_STRIDE - 1);
const int uniform segmentBase = (block - intervalI) * SAMPLE_STRIDE;
//Set up threadblock-wide parameters
const uniform int segmentElementsA = stride;
const uniform int segmentElementsB = min(stride, N - segmentBase - stride);
const uniform int segmentSamplesA = getSampleCount(segmentElementsA);
const uniform int segmentSamplesB = getSampleCount(segmentElementsB);
const uniform int segmentSamples = segmentSamplesA + segmentSamplesB;
const uniform int startSrcA = limitsA[block];
const uniform int startSrcB = limitsB[block];
const uniform int endSrcA = (intervalI + 1 < segmentSamples) ? limitsA[block + 1] : segmentElementsA;
const uniform int endSrcB = (intervalI + 1 < segmentSamples) ? limitsB[block + 1] : segmentElementsB;
const uniform int lenSrcA = endSrcA - startSrcA;
const uniform int lenSrcB = endSrcB - startSrcB;
const uniform int startDstA = startSrcA + startSrcB;
const uniform int startDstB = startDstA + lenSrcA;
//Load main input data
Key_t keyA, keyB;
Val_t valA, valB;
if (programIndex < lenSrcA)
{
keyA = srcKey[segmentBase + startSrcA + programIndex];
valA = srcVal[segmentBase + startSrcA + programIndex];
}
if (programIndex < lenSrcB)
{
keyB = srcKey[segmentBase + stride + startSrcB + programIndex];
valB = srcVal[segmentBase + stride + startSrcB + programIndex];
}
// Compute destination addresses for merge data
int dstPosA, dstPosB;
if (programIndex < lenSrcA)
dstPosA = binarySearchExclusive1(keyA, keyB, lenSrcB, SAMPLE_STRIDE) + programIndex;
if (programIndex < lenSrcB)
dstPosB = binarySearchInclusive1(keyB, keyA, lenSrcA, SAMPLE_STRIDE) + programIndex;
int dstA = -1, dstB = -1;
if (programIndex < lenSrcA && dstPosA < lenSrcA)
dstA = segmentBase + startDstA + dstPosA;
if (programIndex < lenSrcB && dstPosB < lenSrcA)
dstB = segmentBase + startDstA + dstPosB;
dstPosA -= lenSrcA;
dstPosB -= lenSrcA;
if (programIndex < lenSrcA && dstPosA < lenSrcB)
dstA = segmentBase + startDstB + dstPosA;
if (programIndex < lenSrcB && dstPosB < lenSrcB)
dstB = segmentBase + startDstB + dstPosB;
// store merge data
if (dstA >= 0)
{
dstKey[dstA] = keyA;
dstVal[dstA] = valA;
}
if (dstB >= 0)
{
dstKey[dstB] = keyB;
dstVal[dstB] = valB;
}
}
}
static inline
void mergeElementaryIntervals(
uniform int nTasks,
uniform Key_t dstKey[],
uniform Val_t dstVal[],
uniform Key_t srcKey[],
@@ -492,6 +589,7 @@ void mergeElementaryIntervals(
const uniform int mergePairs = (lastSegmentElements > stride) ? getSampleCount(N) : (N - lastSegmentElements) / SAMPLE_STRIDE;
#if 0
launch [mergePairs] mergeElementaryIntervalsKernel(
dstKey,
dstVal,
@@ -501,6 +599,18 @@ void mergeElementaryIntervals(
limitsB,
stride,
N);
#else
launch [nTasks] mergeElementaryIntervalsKernel(
mergePairs,
dstKey,
dstVal,
srcKey,
srcVal,
limitsA,
limitsB,
stride,
N);
#endif
sync;
}
@@ -516,6 +626,9 @@ export
void openMergeSort()
{
nTasks = num_cores()*4;
#ifdef __NVPTX__
nTasks = num_cores()*13;
#endif
MAX_SAMPLE_COUNT = 8*32 * 131072 / programCount;
assert(memPool == NULL);
const uniform int nalloc = MAX_SAMPLE_COUNT * 4;
@@ -582,7 +695,7 @@ void mergeSort(
mergeRanksAndIndices(limitsA, limitsB, ranksA, ranksB, stride, N);
//Merge elementary intervals
mergeElementaryIntervals(oKey, oVal, iKey, iVal, limitsA, limitsB, stride, N);
mergeElementaryIntervals(nTasks, oKey, oVal, iKey, iVal, limitsA, limitsB, stride, N);
if (lastSegmentElements <= stride)
foreach (i = 0 ... lastSegmentElements)