compiles
This commit is contained in:
@@ -92,7 +92,7 @@ int binarySearchExclusive(
|
|||||||
// Bottom-level merge sort (binary search-based)
|
// Bottom-level merge sort (binary search-based)
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
task
|
task
|
||||||
void mergeSortSharedKernel(
|
void mergeSortGangKernel(
|
||||||
uniform int dstKey[],
|
uniform int dstKey[],
|
||||||
uniform int dstVal[],
|
uniform int dstVal[],
|
||||||
uniform int srcKey[],
|
uniform int srcKey[],
|
||||||
@@ -132,6 +132,18 @@ void mergeSortSharedKernel(
|
|||||||
dstVal[base + programIndex + programCount] = s_val[programIndex + programCount];
|
dstVal[base + programIndex + programCount] = s_val[programIndex + programCount];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Driver for the bottom-level sort: starts `batchSize` instances of
// mergeSortGangKernel, handing each the same four arrays, and does not
// return until all of them have completed.
//
//   dstKey, dstVal - output key/value arrays handed to the kernel
//   srcKey, srcVal - input key/value arrays handed to the kernel
//   batchSize      - number of tasks to launch
static inline void mergeSortGang(
    uniform int dstKey[],
    uniform int dstVal[],
    uniform int srcKey[],
    uniform int srcVal[],
    uniform int batchSize)
{
    launch [batchSize] mergeSortGangKernel(dstKey, dstVal, srcKey, srcVal);

    // Block until every launched task has finished.
    sync;
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// Merge step 1: generate sample ranks
|
// Merge step 1: generate sample ranks
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
@@ -140,42 +152,58 @@ void generateSampleRanksKernel(
|
|||||||
uniform int in_ranksA[],
|
uniform int in_ranksA[],
|
||||||
uniform int in_ranksB[],
|
uniform int in_ranksB[],
|
||||||
uniform int in_srcKey[],
|
uniform int in_srcKey[],
|
||||||
const uniform int stride,
|
uniform int stride,
|
||||||
const uniform int N,
|
uniform int N,
|
||||||
const int totalProgramCount)
|
uniform int totalProgramCount)
|
||||||
{
|
{
|
||||||
const int pos = taskIndex * programCount + programIndex;
|
const int pos = taskIndex * programCount + programIndex;
|
||||||
|
|
||||||
if (pos >= totalProgramCount)
|
if (pos >= totalProgramCount)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
const int i = pos & ((stride / SAMPLE_STRIDE) - 1);
|
||||||
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
const int segmentBase = (pos - i) * (2 * SAMPLE_STRIDE);
|
||||||
|
|
||||||
int * srcKey = in_srcKey + segmentBase;
|
int * srcKey = in_srcKey + segmentBase;
|
||||||
int * ranksA = in_ranksA + segmentBase / SAMPLE_STRIDE;
|
int * ranksA = in_ranksA + segmentBase / SAMPLE_STRIDE;
|
||||||
int * ranksB = in_ranksB + segmentBase / SAMPLE_STRIDE;
|
int * ranksB = in_ranksB + segmentBase / SAMPLE_STRIDE;
|
||||||
|
|
||||||
const int segmentElementsA = stride;
|
const int segmentElementsA = stride;
|
||||||
const int segmentElementsB = min(stride, N - segmentBase - stride);
|
const int segmentElementsB = min(stride, N - segmentBase - stride);
|
||||||
const int segmentSamplesA = getSampleCount(segmentElementsA);
|
const int segmentSamplesA = getSampleCount(segmentElementsA);
|
||||||
const int segmentSamplesB = getSampleCount(segmentElementsB);
|
const int segmentSamplesB = getSampleCount(segmentElementsB);
|
||||||
|
|
||||||
if (i < segmentSamplesA)
|
if (i < segmentSamplesA)
|
||||||
{
|
{
|
||||||
ranksA[i] = i * SAMPLE_STRIDE;
|
ranksA[i] = i * SAMPLE_STRIDE;
|
||||||
ranksB[i] = binarySearchExclusive(
|
ranksB[i] = binarySearchExclusive(
|
||||||
srcKey[i * SAMPLE_STRIDE], srcKey + stride,
|
srcKey[i * SAMPLE_STRIDE], srcKey + stride,
|
||||||
segmentElementsB, nextPowerOfTwo(segmentElementsB));
|
segmentElementsB, nextPowerOfTwo(segmentElementsB));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (i < segmentSamplesB)
|
if (i < segmentSamplesB)
|
||||||
{
|
{
|
||||||
ranksB[(stride / SAMPLE_STRIDE) + i] = i * SAMPLE_STRIDE;
|
ranksB[(stride / SAMPLE_STRIDE) + i] = i * SAMPLE_STRIDE;
|
||||||
ranksA[(stride / SAMPLE_STRIDE) + i] = binarySearchInclusive(
|
ranksA[(stride / SAMPLE_STRIDE) + i] = binarySearchInclusive(
|
||||||
srcKey[stride + i * SAMPLE_STRIDE], srcKey + 0,
|
srcKey[stride + i * SAMPLE_STRIDE], srcKey + 0,
|
||||||
segmentElementsA, nextPowerOfTwo(segmentElementsA));
|
segmentElementsA, nextPowerOfTwo(segmentElementsA));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge step 1 driver: works out how many program instances one merge stage
// needs, launches generateSampleRanksKernel over them, and waits for all
// tasks to finish.
//
//   ranksA, ranksB - rank arrays forwarded to the kernel
//   srcKey         - key array forwarded to the kernel
//   stride         - stage stride; elements are grouped in spans of 2 * stride
//   N              - total element count
static inline
void generateSampleRanks(
    uniform int ranksA[],
    uniform int ranksB[],
    uniform int srcKey[],
    uniform int stride,
    uniform int N)
{
    const uniform int pairSpan = 2 * stride;
    const uniform int tailElements = N % pairSpan;

    // Elements that take part in this stage: every complete span, plus one
    // more whole span when the tail is longer than `stride`. A tail of at
    // most `stride` elements contributes nothing.
    uniform int mergedElements = N - tailElements;
    if (tailElements > stride)
        mergedElements += pairSpan;

    // One program instance per 2 * SAMPLE_STRIDE elements, rounded up to
    // whole tasks; the kernel discards instances past threadCount.
    const uniform int threadCount = mergedElements / (2 * SAMPLE_STRIDE);
    const uniform int nTasks = (threadCount + programCount - 1) / programCount;

    launch [nTasks] generateSampleRanksKernel(ranksA, ranksB, srcKey, stride, N, threadCount);
    sync;
}
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
// Merge step 2: generate sample ranks and indices
|
// Merge step 2: generate sample ranks and indices
|
||||||
@@ -215,6 +243,35 @@ void mergeRanksAndIndicesKernel(
|
|||||||
limits[dstPos] = ranks[segmentSamplesA + i];
|
limits[dstPos] = ranks[segmentSamplesA + i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Merge step 2 driver: runs mergeRanksAndIndicesKernel twice for one merge
// stage, first producing limitsA from ranksA and then limitsB from ranksB.
// Each launch is followed by its own sync, so the second pass starts only
// after the first has completed.
//
//   limitsA, limitsB - output arrays, one per pass
//   ranksA, ranksB   - input rank arrays, one per pass
//   stride           - stage stride; elements are grouped in spans of 2 * stride
//   N                - total element count
static inline
void mergeRanksAndIndices(
    uniform int limitsA[],
    uniform int limitsB[],
    uniform int ranksA[],
    uniform int ranksB[],
    uniform int stride,
    uniform int N)
{
    const uniform int pairSpan = 2 * stride;
    const uniform int tailElements = N % pairSpan;

    // Every complete span takes part; a tail longer than `stride` is counted
    // as one more whole span, a shorter tail is left out.
    uniform int mergedElements = N - tailElements;
    if (tailElements > stride)
        mergedElements += pairSpan;

    const uniform int threadCount = mergedElements / (2 * SAMPLE_STRIDE);
    const uniform int nTasks = (threadCount + programCount - 1) / programCount;

    // Pass over the A side.
    launch [nTasks] mergeRanksAndIndicesKernel(limitsA, ranksA, stride, N, threadCount);
    sync;

    // Pass over the B side.
    launch [nTasks] mergeRanksAndIndicesKernel(limitsB, ranksB, stride, N, threadCount);
    sync;
}
|
||||||
|
|
||||||
static inline
|
static inline
|
||||||
void merge(
|
void merge(
|
||||||
@@ -340,3 +397,131 @@ void mergeElementaryIntervalsKernel(
|
|||||||
dstVal[startDstB + programIndex] = s_val[lenSrcA + programIndex];
|
dstVal[startDstB + programIndex] = s_val[lenSrcA + programIndex];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Merge step 3 driver: launches one mergeElementaryIntervalsKernel task per
// interval for the current stage and waits for all of them.
//
//   dstKey, dstVal   - output key/value arrays
//   srcKey, srcVal   - input key/value arrays
//   limitsA, limitsB - limit arrays forwarded to the kernel
//   stride           - stage stride; elements are grouped in spans of 2 * stride
//   N                - total element count
static inline
void mergeElementaryIntervals(
    uniform int dstKey[],
    uniform int dstVal[],
    uniform int srcKey[],
    uniform int srcVal[],
    uniform int limitsA[],
    uniform int limitsB[],
    uniform int stride,
    uniform int N)
{
    const uniform int tailElements = N % (2 * stride);

    // When the tail is longer than `stride` the task count comes from
    // getSampleCount(N); otherwise the tail is excluded and only the
    // elements before it are divided into SAMPLE_STRIDE-sized pieces.
    uniform int mergePairs;
    if (tailElements > stride)
        mergePairs = getSampleCount(N);
    else
        mergePairs = (N - tailElements) / SAMPLE_STRIDE;

    launch [mergePairs] mergeElementaryIntervalsKernel(
        dstKey, dstVal,
        srcKey, srcVal,
        limitsA, limitsB,
        stride, N);
    sync;
}
|
||||||
|
|
||||||
|
// File-scope scratch state for the merge stages.
// memPool is the single allocation made by openMergeSort and released by
// closeMergeSort; NULL means no pool is currently allocated.
// ranksA, ranksB, limitsA and limitsB are not separate allocations: they are
// set by openMergeSort to point at four consecutive MAX_SAMPLE_COUNT-sized
// sections inside memPool.
static uniform int * uniform memPool = NULL;
|
||||||
|
static uniform int * uniform ranksA;
|
||||||
|
static uniform int * uniform ranksB;
|
||||||
|
static uniform int * uniform limitsA;
|
||||||
|
static uniform int * uniform limitsB;
|
||||||
|
// Capacity, in ints, of each of the four sections; stays 0 until
// openMergeSort assigns it.
static uniform int MAX_SAMPLE_COUNT = 0;
|
||||||
|
|
||||||
|
// Allocates the scratch pool used by the merge stages and carves it into the
// four working arrays. Asserts that no pool is currently allocated, so it
// must be paired with closeMergeSort before being called again.
export void openMergeSort()
{
    MAX_SAMPLE_COUNT = 32 * 131072 / programCount;
    assert(memPool == NULL);

    // One allocation holding four equal sections, laid out back to back:
    // ranksA | ranksB | limitsA | limitsB
    memPool = uniform new uniform int[MAX_SAMPLE_COUNT * 4];

    ranksA  = memPool;
    ranksB  = memPool + MAX_SAMPLE_COUNT;
    limitsA = memPool + 2 * MAX_SAMPLE_COUNT;
    limitsB = memPool + 3 * MAX_SAMPLE_COUNT;
}
|
||||||
|
|
||||||
|
// Releases the scratch pool allocated by openMergeSort. Asserts that a pool
// is currently allocated, then resets memPool to NULL so that a later
// openMergeSort call passes its own check.
export void closeMergeSort()
{
    assert(memPool != NULL);

    delete memPool;
    memPool = NULL;
}
|
||||||
|
|
||||||
|
// Copies the first `size` ints of `src` into `dst`, element by element.
// A `size` of zero copies nothing.
export void copyKernel(uniform int dst[], uniform int src[], uniform int size)
{
    foreach (idx = 0 ... size)
    {
        dst[idx] = src[idx];
    }
}
|
||||||
|
|
||||||
|
// Sorts N key/value pairs from srcKey/srcVal, leaving the result in
// dstKey/dstVal.
//
// The work is done in two phases: mergeSortGang first sorts blocks of
// 2 * programCount elements, then a sequence of merge stages doubles the
// sorted run length (`stride`) until it reaches N. Each stage reads one
// buffer pair and writes the other, and the pairs swap roles afterwards.
//
//   dstKey, dstVal - receive the final output
//   bufKey, bufVal - scratch arrays of the same length, used for ping-ponging
//   srcKey, srcVal - input arrays
//   N              - element count; must be a multiple of 2 * programCount
//                    and at most SAMPLE_STRIDE * MAX_SAMPLE_COUNT
//
// openMergeSort must have been called beforehand: it allocates the
// ranks/limits scratch arrays used here and sets MAX_SAMPLE_COUNT.
export
void mergeSort(
    uniform int dstKey[],
    uniform int dstVal[],
    uniform int bufKey[],
    uniform int bufVal[],
    uniform int srcKey[],
    uniform int srcVal[],
    uniform int N)
{
    // Check the preconditions before touching any buffer.
    assert(N <= SAMPLE_STRIDE * MAX_SAMPLE_COUNT);
    assert(N % (programCount * 2) == 0);

    // Count the merge stages that will run so the ping-pong can be set up to
    // finish in dstKey/dstVal.
    uniform int stageCount = 0;
    for (uniform int stride = 2 * programCount; stride < N; stride <<= 1)
        stageCount++;

    uniform int * uniform iKey, * uniform oKey;
    uniform int * uniform iVal, * uniform oVal;

    // The block sort writes into the "input" pair and every stage writes
    // into the "output" pair before swapping. With an odd stage count the
    // last write lands in the pair that started as output, so that pair must
    // be dst; with an even count (including zero) dst must start as input.
    if (stageCount & 1)
    {
        iKey = bufKey;
        iVal = bufVal;
        oKey = dstKey;
        oVal = dstVal;
    }
    else
    {
        iKey = dstKey;
        iVal = dstVal;
        oKey = bufKey;
        oVal = bufVal;
    }

    mergeSortGang(iKey, iVal, srcKey, srcVal, N / (2 * programCount));

    for (uniform int stride = 2 * programCount; stride < N; stride <<= 1)
    {
        uniform int lastSegmentElements = N % (2 * stride);

        // Find sample ranks and prepare for limiters merge
        generateSampleRanks(ranksA, ranksB, iKey, stride, N);

        // Merge ranks and indices
        mergeRanksAndIndices(limitsA, limitsB, ranksA, ranksB, stride, N);

        // Merge elementary intervals
        mergeElementaryIntervals(oKey, oVal, iKey, iVal, limitsA, limitsB, stride, N);

        if (lastSegmentElements <= stride)
        {
            // Last merge segment consists of a single array which just needs
            // to be passed through. This branch is also taken when N is an
            // exact multiple of 2 * stride (lastSegmentElements == 0), in
            // which case the copies are empty. It is a normal path, so it
            // must not assert.
            copyKernel(oKey + (N - lastSegmentElements), iKey + (N - lastSegmentElements), lastSegmentElements);
            copyKernel(oVal + (N - lastSegmentElements), iVal + (N - lastSegmentElements), lastSegmentElements);
        }

        // Swap roles: this stage's output is the next stage's input.
        uniform int * uniform tmpKey = iKey;
        iKey = oKey;
        oKey = tmpKey;

        uniform int * uniform tmpVal = iVal;
        iVal = oVal;
        oVal = tmpVal;
    }
}
|
||||||
|
|||||||
Reference in New Issue
Block a user