Add support for mask vectors of 8 and 16-bit element types.

There were a number of places throughout the system that assumed that the
execution mask would only have either 32-bit or 1-bit elements.  This
commit makes it possible to have a target with an 8- or 16-bit mask.
This commit is contained in:
Matt Pharr
2013-07-23 16:38:10 -07:00
parent 83e1630fbc
commit e7abf3f2ea
8 changed files with 284 additions and 133 deletions

View File

@@ -115,13 +115,25 @@ InitLLVMUtil(llvm::LLVMContext *ctx, Target& target) {
LLVMTypes::FloatPointerType = llvm::PointerType::get(LLVMTypes::FloatType, 0);
LLVMTypes::DoublePointerType = llvm::PointerType::get(LLVMTypes::DoubleType, 0);
if (target.getMaskBitCount() == 1)
switch (target.getMaskBitCount()) {
case 1:
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
llvm::VectorType::get(llvm::Type::getInt1Ty(*ctx), target.getVectorWidth());
else {
Assert(target.getMaskBitCount() == 32);
break;
case 8:
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
llvm::VectorType::get(llvm::Type::getInt8Ty(*ctx), target.getVectorWidth());
break;
case 16:
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
llvm::VectorType::get(llvm::Type::getInt16Ty(*ctx), target.getVectorWidth());
break;
case 32:
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
llvm::VectorType::get(llvm::Type::getInt32Ty(*ctx), target.getVectorWidth());
break;
default:
FATAL("Unhandled mask width for initializing MaskType");
}
LLVMTypes::Int1VectorType =
@@ -154,12 +166,26 @@ InitLLVMUtil(llvm::LLVMContext *ctx, Target& target) {
std::vector<llvm::Constant *> maskOnes;
llvm::Constant *onMask = NULL;
if (target.getMaskBitCount() == 1)
switch (target.getMaskBitCount()) {
case 1:
onMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 1,
false /*unsigned*/); // 0x1
else
break;
case 8:
onMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), -1,
true /*signed*/); // 0xff
break;
case 16:
onMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), -1,
true /*signed*/); // 0xffff
break;
case 32:
onMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), -1,
true /*signed*/); // 0xffffffff
break;
default:
FATAL("Unhandled mask width for onMask");
}
for (int i = 0; i < target.getVectorWidth(); ++i)
maskOnes.push_back(onMask);
@@ -167,13 +193,26 @@ InitLLVMUtil(llvm::LLVMContext *ctx, Target& target) {
std::vector<llvm::Constant *> maskZeros;
llvm::Constant *offMask = NULL;
if (target.getMaskBitCount() == 1)
switch (target.getMaskBitCount()) {
case 1:
offMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 0,
true /*signed*/);
else
break;
case 8:
offMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), 0,
true /*signed*/);
break;
case 16:
offMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), 0,
true /*signed*/);
break;
case 32:
offMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0,
true /*signed*/);
break;
default:
FATAL("Unhandled mask width for offMask");
}
for (int i = 0; i < target.getVectorWidth(); ++i)
maskZeros.push_back(offMask);
LLVMMaskAllOff = llvm::ConstantVector::get(maskZeros);
@@ -444,9 +483,14 @@ LLVMBoolVector(bool b) {
if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
v = llvm::ConstantInt::get(LLVMTypes::Int32Type, b ? 0xffffffff : 0,
false /*unsigned*/);
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
v = llvm::ConstantInt::get(LLVMTypes::Int16Type, b ? 0xffff : 0,
false /*unsigned*/);
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
v = llvm::ConstantInt::get(LLVMTypes::Int8Type, b ? 0xff : 0,
false /*unsigned*/);
else {
Assert(LLVMTypes::BoolVectorType->getElementType() ==
llvm::Type::getInt1Ty(*g->ctx));
Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
v = b ? LLVMTrue : LLVMFalse;
}
@@ -465,9 +509,14 @@ LLVMBoolVector(const bool *bvec) {
if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
v = llvm::ConstantInt::get(LLVMTypes::Int32Type, bvec[i] ? 0xffffffff : 0,
false /*unsigned*/);
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
v = llvm::ConstantInt::get(LLVMTypes::Int16Type, bvec[i] ? 0xffff : 0,
false /*unsigned*/);
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
v = llvm::ConstantInt::get(LLVMTypes::Int8Type, bvec[i] ? 0xff : 0,
false /*unsigned*/);
else {
Assert(LLVMTypes::BoolVectorType->getElementType() ==
llvm::Type::getInt1Ty(*g->ctx));
Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
v = bvec[i] ? LLVMTrue : LLVMFalse;
}