This commit is contained in:
Evghenii
2014-01-30 10:26:58 +01:00
parent 2f44b81d4f
commit 4e26a1b700
2 changed files with 50 additions and 28 deletions

View File

@@ -2,7 +2,7 @@
EXAMPLE=mergeSort EXAMPLE=mergeSort
CPP_SRC=mergeSort.cpp CPP_SRC=mergeSort.cpp
ISPC_SRC=mergeSort.ispc ISPC_SRC=mergeSort.ispc
ISPC_IA_TARGETS=avx1-i32x16 ISPC_IA_TARGETS=avx1-i32x8
ISPC_ARM_TARGETS=neon ISPC_ARM_TARGETS=neon
#ISPC_FLAGS=-DDEBUG -g #ISPC_FLAGS=-DDEBUG -g
CXXFLAGS=-g CXXFLAGS=-g

View File

@@ -161,43 +161,54 @@ int binarySearchExclusive1(
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
task task
void mergeSortGangKernel( void mergeSortGangKernel(
uniform int batchSize,
uniform Key_t dstKey[], uniform Key_t dstKey[],
uniform Val_t dstVal[], uniform Val_t dstVal[],
uniform Key_t srcKey[], uniform Key_t srcKey[],
uniform Val_t srcVal[]) uniform Val_t srcVal[])
{ {
const uniform int blockIdx = taskIndex;
const uniform int blockDim = (batchSize + taskCount - 1)/taskCount;
const uniform int blockBeg = blockIdx * blockDim;
const uniform int blockEnd = min(blockBeg + blockDim, batchSize);
uniform Key_t s_key[2*programCount]; uniform Key_t s_key[2*programCount];
uniform Val_t s_val[2*programCount]; uniform Val_t s_val[2*programCount];
const uniform int base = taskIndex * (programCount*2); for (uniform int block = blockBeg; block < blockEnd; block++)
s_key[programIndex + 0] = srcKey[base + programIndex + 0];
s_val[programIndex + 0] = srcVal[base + programIndex + 0];
s_key[programIndex + programCount] = srcKey[base + programIndex + programCount];
s_val[programIndex + programCount] = srcVal[base + programIndex + programCount];
for (uniform int stride = 1; stride < 2*programCount; stride <<= 1)
{ {
const int lPos = programIndex & (stride - 1); const uniform int base = block * (programCount*2);
uniform Key_t *baseKey = s_key + 2 * (programIndex - lPos); s_key[programIndex + 0] = srcKey[base + programIndex + 0];
uniform Val_t *baseVal = s_val + 2 * (programIndex - lPos); s_val[programIndex + 0] = srcVal[base + programIndex + 0];
s_key[programIndex + programCount] = srcKey[base + programIndex + programCount];
s_val[programIndex + programCount] = srcVal[base + programIndex + programCount];
Key_t keyA = baseKey[lPos + 0]; #if 1
Val_t valA = baseVal[lPos + 0]; for (uniform int stride = 1; stride < 2*programCount; stride <<= 1)
Key_t keyB = baseKey[lPos + stride]; {
Val_t valB = baseVal[lPos + stride]; const int lPos = programIndex & (stride - 1);
int posA = binarySearchExclusive(keyA, baseKey + stride, stride, stride) + lPos; uniform Key_t *baseKey = s_key + 2 * (programIndex - lPos);
int posB = binarySearchInclusive(keyB, baseKey + 0, stride, stride) + lPos; uniform Val_t *baseVal = s_val + 2 * (programIndex - lPos);
baseKey[posA] = keyA; Key_t keyA = baseKey[lPos + 0];
baseVal[posA] = valA; Val_t valA = baseVal[lPos + 0];
baseKey[posB] = keyB; Key_t keyB = baseKey[lPos + stride];
baseVal[posB] = valB; Val_t valB = baseVal[lPos + stride];
int posA = binarySearchExclusive(keyA, baseKey + stride, stride, stride) + lPos;
int posB = binarySearchInclusive(keyB, baseKey + 0, stride, stride) + lPos;
baseKey[posA] = keyA;
baseVal[posA] = valA;
baseKey[posB] = keyB;
baseVal[posB] = valB;
}
#endif
dstKey[base + programIndex + 0] = s_key[programIndex + 0];
dstVal[base + programIndex + 0] = s_val[programIndex + 0];
dstKey[base + programIndex + programCount] = s_key[programIndex + programCount];
dstVal[base + programIndex + programCount] = s_val[programIndex + programCount];
} }
dstKey[base + programIndex + 0] = s_key[programIndex + 0];
dstVal[base + programIndex + 0] = s_val[programIndex + 0];
dstKey[base + programIndex + programCount] = s_key[programIndex + programCount];
dstVal[base + programIndex + programCount] = s_val[programIndex + programCount];
} }
static inline static inline
@@ -208,7 +219,11 @@ void mergeSortGang(
uniform Val_t srcVal[], uniform Val_t srcVal[],
uniform int batchSize) uniform int batchSize)
{ {
launch [batchSize] mergeSortGangKernel(dstKey, dstVal, srcKey, srcVal); uniform int nTasks = num_cores()*4;
#ifdef __NVPTX__
nTasks = batchSize/4;
#endif
launch [nTasks] mergeSortGangKernel(batchSize, dstKey, dstVal, srcKey, srcVal);
sync; sync;
} }
@@ -536,8 +551,8 @@ void mergeElementaryIntervalsKernel(
valB = srcVal[segmentBase + stride + startSrcB + programIndex]; valB = srcVal[segmentBase + stride + startSrcB + programIndex];
} }
// Compute destination addresses for merge data
int dstPosA, dstPosB; int dstPosA, dstPosB;
// Compute destination addresses for merge data
if (programIndex < lenSrcA) if (programIndex < lenSrcA)
dstPosA = binarySearchExclusive1(keyA, keyB, lenSrcB, SAMPLE_STRIDE) + programIndex; dstPosA = binarySearchExclusive1(keyA, keyB, lenSrcB, SAMPLE_STRIDE) + programIndex;
if (programIndex < lenSrcB) if (programIndex < lenSrcB)
@@ -560,11 +575,13 @@ void mergeElementaryIntervalsKernel(
// store merge data // store merge data
if (dstA >= 0) if (dstA >= 0)
{ {
// int dstA = segmentBase + startSrcA + programIndex;
dstKey[dstA] = keyA; dstKey[dstA] = keyA;
dstVal[dstA] = valA; dstVal[dstA] = valA;
} }
if (dstB >= 0) if (dstB >= 0)
{ {
// int dstB = segmentBase + stride + startSrcB + programIndex;
dstKey[dstB] = keyB; dstKey[dstB] = keyB;
dstVal[dstB] = valB; dstVal[dstB] = valB;
} }
@@ -600,6 +617,9 @@ void mergeElementaryIntervals(
stride, stride,
N); N);
#else #else
#ifdef __NVPTX__
nTasks = mergePairs/(4*programCount);
#endif
launch [nTasks] mergeElementaryIntervalsKernel( launch [nTasks] mergeElementaryIntervalsKernel(
mergePairs, mergePairs,
dstKey, dstKey,
@@ -684,6 +704,7 @@ void mergeSort(
assert(N % (programCount*2) == 0); assert(N % (programCount*2) == 0);
mergeSortGang(iKey, iVal, srcKey, srcVal, N/(2*programCount)); mergeSortGang(iKey, iVal, srcKey, srcVal, N/(2*programCount));
#if 1
for (uniform int stride = 2*programCount; stride < N; stride <<= 1) for (uniform int stride = 2*programCount; stride < N; stride <<= 1)
{ {
const uniform int lastSegmentElements = N % (2 * stride); const uniform int lastSegmentElements = N % (2 * stride);
@@ -717,4 +738,5 @@ void mergeSort(
oVal = tmpVal; oVal = tmpVal;
} }
} }
#endif
} }