Add support for mask vectors of 8 and 16-bit element types.
There were a number of places throughout the system that assumed that the execution mask would only have either 32-bit or 1-bit elements. This commit makes it possible to have a target with an 8- or 16-bit mask.
This commit is contained in:
73
llvmutil.cpp
73
llvmutil.cpp
@@ -115,13 +115,25 @@ InitLLVMUtil(llvm::LLVMContext *ctx, Target& target) {
|
||||
LLVMTypes::FloatPointerType = llvm::PointerType::get(LLVMTypes::FloatType, 0);
|
||||
LLVMTypes::DoublePointerType = llvm::PointerType::get(LLVMTypes::DoubleType, 0);
|
||||
|
||||
if (target.getMaskBitCount() == 1)
|
||||
switch (target.getMaskBitCount()) {
|
||||
case 1:
|
||||
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
|
||||
llvm::VectorType::get(llvm::Type::getInt1Ty(*ctx), target.getVectorWidth());
|
||||
else {
|
||||
Assert(target.getMaskBitCount() == 32);
|
||||
break;
|
||||
case 8:
|
||||
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
|
||||
llvm::VectorType::get(llvm::Type::getInt8Ty(*ctx), target.getVectorWidth());
|
||||
break;
|
||||
case 16:
|
||||
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
|
||||
llvm::VectorType::get(llvm::Type::getInt16Ty(*ctx), target.getVectorWidth());
|
||||
break;
|
||||
case 32:
|
||||
LLVMTypes::MaskType = LLVMTypes::BoolVectorType =
|
||||
llvm::VectorType::get(llvm::Type::getInt32Ty(*ctx), target.getVectorWidth());
|
||||
break;
|
||||
default:
|
||||
FATAL("Unhandled mask width for initializing MaskType");
|
||||
}
|
||||
|
||||
LLVMTypes::Int1VectorType =
|
||||
@@ -154,12 +166,26 @@ InitLLVMUtil(llvm::LLVMContext *ctx, Target& target) {
|
||||
|
||||
std::vector<llvm::Constant *> maskOnes;
|
||||
llvm::Constant *onMask = NULL;
|
||||
if (target.getMaskBitCount() == 1)
|
||||
switch (target.getMaskBitCount()) {
|
||||
case 1:
|
||||
onMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 1,
|
||||
false /*unsigned*/); // 0x1
|
||||
else
|
||||
break;
|
||||
case 8:
|
||||
onMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), -1,
|
||||
true /*signed*/); // 0xff
|
||||
break;
|
||||
case 16:
|
||||
onMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), -1,
|
||||
true /*signed*/); // 0xffff
|
||||
break;
|
||||
case 32:
|
||||
onMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), -1,
|
||||
true /*signed*/); // 0xffffffff
|
||||
break;
|
||||
default:
|
||||
FATAL("Unhandled mask width for onMask");
|
||||
}
|
||||
|
||||
for (int i = 0; i < target.getVectorWidth(); ++i)
|
||||
maskOnes.push_back(onMask);
|
||||
@@ -167,13 +193,26 @@ InitLLVMUtil(llvm::LLVMContext *ctx, Target& target) {
|
||||
|
||||
std::vector<llvm::Constant *> maskZeros;
|
||||
llvm::Constant *offMask = NULL;
|
||||
if (target.getMaskBitCount() == 1)
|
||||
switch (target.getMaskBitCount()) {
|
||||
case 1:
|
||||
offMask = llvm::ConstantInt::get(llvm::Type::getInt1Ty(*ctx), 0,
|
||||
true /*signed*/);
|
||||
else
|
||||
break;
|
||||
case 8:
|
||||
offMask = llvm::ConstantInt::get(llvm::Type::getInt8Ty(*ctx), 0,
|
||||
true /*signed*/);
|
||||
break;
|
||||
case 16:
|
||||
offMask = llvm::ConstantInt::get(llvm::Type::getInt16Ty(*ctx), 0,
|
||||
true /*signed*/);
|
||||
break;
|
||||
case 32:
|
||||
offMask = llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0,
|
||||
true /*signed*/);
|
||||
|
||||
break;
|
||||
default:
|
||||
FATAL("Unhandled mask width for offMask");
|
||||
}
|
||||
for (int i = 0; i < target.getVectorWidth(); ++i)
|
||||
maskZeros.push_back(offMask);
|
||||
LLVMMaskAllOff = llvm::ConstantVector::get(maskZeros);
|
||||
@@ -444,9 +483,14 @@ LLVMBoolVector(bool b) {
|
||||
if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
|
||||
v = llvm::ConstantInt::get(LLVMTypes::Int32Type, b ? 0xffffffff : 0,
|
||||
false /*unsigned*/);
|
||||
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
|
||||
v = llvm::ConstantInt::get(LLVMTypes::Int16Type, b ? 0xffff : 0,
|
||||
false /*unsigned*/);
|
||||
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
|
||||
v = llvm::ConstantInt::get(LLVMTypes::Int8Type, b ? 0xff : 0,
|
||||
false /*unsigned*/);
|
||||
else {
|
||||
Assert(LLVMTypes::BoolVectorType->getElementType() ==
|
||||
llvm::Type::getInt1Ty(*g->ctx));
|
||||
Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
|
||||
v = b ? LLVMTrue : LLVMFalse;
|
||||
}
|
||||
|
||||
@@ -465,9 +509,14 @@ LLVMBoolVector(const bool *bvec) {
|
||||
if (LLVMTypes::BoolVectorType == LLVMTypes::Int32VectorType)
|
||||
v = llvm::ConstantInt::get(LLVMTypes::Int32Type, bvec[i] ? 0xffffffff : 0,
|
||||
false /*unsigned*/);
|
||||
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int16VectorType)
|
||||
v = llvm::ConstantInt::get(LLVMTypes::Int16Type, bvec[i] ? 0xffff : 0,
|
||||
false /*unsigned*/);
|
||||
else if (LLVMTypes::BoolVectorType == LLVMTypes::Int8VectorType)
|
||||
v = llvm::ConstantInt::get(LLVMTypes::Int8Type, bvec[i] ? 0xff : 0,
|
||||
false /*unsigned*/);
|
||||
else {
|
||||
Assert(LLVMTypes::BoolVectorType->getElementType() ==
|
||||
llvm::Type::getInt1Ty(*g->ctx));
|
||||
Assert(LLVMTypes::BoolVectorType == LLVMTypes::Int1VectorType);
|
||||
v = bvec[i] ? LLVMTrue : LLVMFalse;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user