diff --git a/examples_ptx/mergeSort/mergeSort.ispc b/examples_ptx/mergeSort/mergeSort.ispc index 19633839..2d65c501 100644 --- a/examples_ptx/mergeSort/mergeSort.ispc +++ b/examples_ptx/mergeSort/mergeSort.ispc @@ -65,6 +65,50 @@ int binarySearchExclusive( return pos; } +/* binarySearchInclusive1: stride-halving search over values held in the lanes of the varying 'data', fetched with shuffle(). Returns 0 at once when L == 0. Otherwise starts at pos = 0 and, for stride, stride/2, ..., 1, moves pos to newPos = min(pos + stride, L) whenever lane (newPos - 1) of 'data' holds a value <= val; the final pos is returned. NOTE(review): presumably lanes [0, L) of 'data' are sorted ascending and 'stride' is a power of two no smaller than L, so the result is the number of elements <= val - confirm against the caller. */ +static inline +int binarySearchInclusive1( + const int val, + int data, + const uniform int L, + uniform int stride) +{ + if (L == 0) + return 0; + + int pos = 0; + for (; stride > 0; stride >>= 1) + { + int newPos = min(pos + stride, L); + + if (shuffle(data,newPos - 1) <= val) + pos = newPos; + } + + return pos; +} + +/* binarySearchExclusive1: same stride-halving search as binarySearchInclusive1, but pos advances only when lane (newPos - 1) of 'data' holds a value strictly less than val. Returns 0 at once when L == 0. NOTE(review): under the same sortedness assumption this would be the number of elements < val - confirm. */ +static inline +int binarySearchExclusive1( + const int val, + int data, + const uniform int L, + uniform int stride) +{ + if (L == 0) + return 0; + + int pos = 0; + for (; stride > 0; stride >>= 1) + { + int newPos = min(pos + stride, L); + + if (shuffle(data,newPos - 1) < val) + pos = newPos; + } + + return pos; +} + //////////////////////////////////////////////////////////////////////////////// // Bottom-level merge sort (binary search-based) //////////////////////////////////////////////////////////////////////////////// @@ -296,6 +340,33 @@ void merge( } } +/* merge (varying-argument definition): each program instance carries at most one (keyA, valA) pair and one (keyB, valB) pair. Instances with programIndex < lenA store their A pair into dstKey/dstVal at binarySearchExclusive1(keyA, keyB, lenB, nPowTwoLenB) + programIndex; instances with programIndex < lenB store their B pair at binarySearchInclusive1(keyB, keyA, lenA, nPowTwoLenA) + programIndex. The searches only shuffle from lanes below lenB and lenA respectively, so lanes at or beyond those lengths are never read. NOTE(review): the strict search for A paired with the non-strict search for B presumably places equal keys from A ahead of those from B - confirm both inputs are sorted ascending. */ +static inline +void merge( + uniform int dstKey[], + uniform int dstVal[], + int keyA, int valA, + int keyB, int valB, + uniform int lenA, + uniform int nPowTwoLenA, + uniform int lenB, + uniform int nPowTwoLenB) +{ + if (programIndex < lenA) + { + const int dstPosA = binarySearchExclusive1(keyA, keyB, lenB, nPowTwoLenB) + programIndex; + dstKey[dstPosA] = keyA; + dstVal[dstPosA] = valA; + } + + if (programIndex < lenB) + { + const int dstPosB = binarySearchInclusive1(keyB, keyA, lenA, nPowTwoLenA) + programIndex; + dstKey[dstPosB] = keyB; + dstVal[dstPosB] = valB; + } +} + + task void mergeElementaryIntervalsKernel( uniform int dstKey[], @@ -312,53 +383,45 @@ void mergeElementaryIntervalsKernel( const int uniform intervalI = taskIndex & ((2 * stride) / SAMPLE_STRIDE - 1); const int uniform segmentBase = (taskIndex - 
intervalI) * SAMPLE_STRIDE; - srcKey += segmentBase; - srcVal += segmentBase; - dstKey += segmentBase; - dstVal += segmentBase; //Set up threadblock-wide parameters - uniform int startSrcA, startSrcB, lenSrcA, lenSrcB, startDstA, startDstB; - { - uniform int segmentElementsA = stride; - uniform int segmentElementsB = min(stride, N - segmentBase - stride); - uniform int segmentSamplesA = getSampleCount(segmentElementsA); - uniform int segmentSamplesB = getSampleCount(segmentElementsB); - uniform int segmentSamples = segmentSamplesA + segmentSamplesB; + const uniform int segmentElementsA = stride; + const uniform int segmentElementsB = min(stride, N - segmentBase - stride); + const uniform int segmentSamplesA = getSampleCount(segmentElementsA); + const uniform int segmentSamplesB = getSampleCount(segmentElementsB); + const uniform int segmentSamples = segmentSamplesA + segmentSamplesB; - startSrcA = limitsA[taskIndex]; - startSrcB = limitsB[taskIndex]; - uniform int endSrcA = (intervalI + 1 < segmentSamples) ? limitsA[taskIndex + 1] : segmentElementsA; - uniform int endSrcB = (intervalI + 1 < segmentSamples) ? limitsB[taskIndex + 1] : segmentElementsB; - lenSrcA = endSrcA - startSrcA; - lenSrcB = endSrcB - startSrcB; - startDstA = startSrcA + startSrcB; - startDstB = startDstA + lenSrcA; - } + const uniform int startSrcA = limitsA[taskIndex]; + const uniform int startSrcB = limitsB[taskIndex]; + const uniform int endSrcA = (intervalI + 1 < segmentSamples) ? limitsA[taskIndex + 1] : segmentElementsA; + const uniform int endSrcB = (intervalI + 1 < segmentSamples) ? 
limitsB[taskIndex + 1] : segmentElementsB; + const uniform int lenSrcA = endSrcA - startSrcA; + const uniform int lenSrcB = endSrcB - startSrcB; + const uniform int startDstA = startSrcA + startSrcB; + const uniform int startDstB = startDstA + lenSrcA; //Load main input data + int keyA, valA, keyB, valB; if (programIndex < lenSrcA) { - s_key[programIndex + 0] = srcKey[0 + startSrcA + programIndex]; - s_val[programIndex + 0] = srcVal[0 + startSrcA + programIndex]; + keyA = srcKey[segmentBase + startSrcA + programIndex]; + valA = srcVal[segmentBase + startSrcA + programIndex]; } if (programIndex < lenSrcB) { - s_key[programIndex + SAMPLE_STRIDE] = srcKey[stride + startSrcB + programIndex]; - s_val[programIndex + SAMPLE_STRIDE] = srcVal[stride + startSrcB + programIndex]; + keyB = srcKey[segmentBase + stride + startSrcB + programIndex]; + valB = srcVal[segmentBase + stride + startSrcB + programIndex]; } //Merge data in shared memory merge( s_key, s_val, - s_key + 0, - s_val + 0, - s_key + SAMPLE_STRIDE, - s_val + SAMPLE_STRIDE, + keyA, valA, + keyB, valB, lenSrcA, SAMPLE_STRIDE, lenSrcB, SAMPLE_STRIDE ); @@ -367,16 +430,17 @@ void mergeElementaryIntervalsKernel( if (programIndex < lenSrcA) { - dstKey[startDstA + programIndex] = s_key[programIndex]; - dstVal[startDstA + programIndex] = s_val[programIndex]; + dstKey[segmentBase + startDstA + programIndex] = s_key[programIndex]; + dstVal[segmentBase + startDstA + programIndex] = s_val[programIndex]; } if (programIndex < lenSrcB) { - dstKey[startDstB + programIndex] = s_key[lenSrcA + programIndex]; - dstVal[startDstB + programIndex] = s_val[lenSrcA + programIndex]; + dstKey[segmentBase + startDstB + programIndex] = s_key[lenSrcA + programIndex]; + dstVal[segmentBase + startDstB + programIndex] = s_val[lenSrcA + programIndex]; } } + static inline void mergeElementaryIntervals( uniform int dstKey[], @@ -409,11 +473,13 @@ static uniform int * uniform ranksA; static uniform int * uniform ranksB; static uniform int * uniform 
limitsA; static uniform int * uniform limitsB; +static uniform int nTasks; static uniform int MAX_SAMPLE_COUNT = 0; export void openMergeSort() { + nTasks = num_cores()*4; MAX_SAMPLE_COUNT = 8*32 * 131072 / programCount; assert(memPool == NULL); const uniform int nalloc = MAX_SAMPLE_COUNT * 4;