Added a basic optimization pass that promotes uniform variables (non-array only) to varying for the NVPTX target
This commit is contained in:
@@ -630,7 +630,9 @@ lSetInternalFunctions(llvm::Module *module) {
|
|||||||
"__nctaid_y",
|
"__nctaid_y",
|
||||||
"__nctaid_z",
|
"__nctaid_z",
|
||||||
"__warpsize",
|
"__warpsize",
|
||||||
"__cvt_loc2gen"
|
"__cvt_loc2gen",
|
||||||
|
"__cvt_loc2gen_var",
|
||||||
|
"__cvt_const2gen"
|
||||||
};
|
};
|
||||||
|
|
||||||
int count = sizeof(names) / sizeof(names[0]);
|
int count = sizeof(names) / sizeof(names[0]);
|
||||||
|
|||||||
@@ -69,6 +69,11 @@ define i64* @__cvt_loc2gen(i64 addrspace(3)*) nounwind readnone alwaysinline
|
|||||||
%ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p3i64(i64 addrspace(3)* %0)
|
%ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p3i64(i64 addrspace(3)* %0)
|
||||||
ret i64* %ptr
|
ret i64* %ptr
|
||||||
}
|
}
|
||||||
|
define i64* @__cvt_loc2gen_var(i64 addrspace(3)*) nounwind readnone alwaysinline
|
||||||
|
{
|
||||||
|
%ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p3i64(i64 addrspace(3)* %0)
|
||||||
|
ret i64* %ptr
|
||||||
|
}
|
||||||
define i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline
|
define i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline
|
||||||
{
|
{
|
||||||
%ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p4i64(i64 addrspace(4)* %0)
|
%ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p4i64(i64 addrspace(4)* %0)
|
||||||
|
|||||||
@@ -4544,5 +4544,6 @@ declare i32 @__nctaid_y() nounwind readnone alwaysinline
|
|||||||
declare i32 @__nctaid_z() nounwind readnone alwaysinline
|
declare i32 @__nctaid_z() nounwind readnone alwaysinline
|
||||||
declare i64* @__cvt_loc2gen(i64 addrspace(3)*) nounwind readnone alwaysinline
|
declare i64* @__cvt_loc2gen(i64 addrspace(3)*) nounwind readnone alwaysinline
|
||||||
declare i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline
|
declare i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline
|
||||||
|
declare i64* @__cvt_loc2gen_var(i64 addrspace(3)*) nounwind readnone alwaysinline
|
||||||
')
|
')
|
||||||
|
|
||||||
|
|||||||
64
opt.cpp
64
opt.cpp
@@ -128,6 +128,7 @@ static llvm::Pass *CreateDebugPass(char * output);
|
|||||||
static llvm::Pass *CreateReplaceStdlibShiftPass();
|
static llvm::Pass *CreateReplaceStdlibShiftPass();
|
||||||
|
|
||||||
static llvm::Pass *CreateFixBooleanSelectPass();
|
static llvm::Pass *CreateFixBooleanSelectPass();
|
||||||
|
static llvm::Pass *CreatePromoteLocalToPrivatePass();
|
||||||
|
|
||||||
#define DEBUG_START_PASS(NAME) \
|
#define DEBUG_START_PASS(NAME) \
|
||||||
if (g->debugPrint && \
|
if (g->debugPrint && \
|
||||||
@@ -574,6 +575,9 @@ Optimize(llvm::Module *module, int optLevel) {
|
|||||||
optPM.add(llvm::createReassociatePass());
|
optPM.add(llvm::createReassociatePass());
|
||||||
optPM.add(llvm::createIPConstantPropagationPass());
|
optPM.add(llvm::createIPConstantPropagationPass());
|
||||||
optPM.add(CreateReplaceStdlibShiftPass(),229);
|
optPM.add(CreateReplaceStdlibShiftPass(),229);
|
||||||
|
if (g->target->getISA() == Target::NVPTX)
|
||||||
|
optPM.add(CreatePromoteLocalToPrivatePass());
|
||||||
|
#if 1
|
||||||
optPM.add(llvm::createDeadArgEliminationPass(),230);
|
optPM.add(llvm::createDeadArgEliminationPass(),230);
|
||||||
optPM.add(llvm::createInstructionCombiningPass());
|
optPM.add(llvm::createInstructionCombiningPass());
|
||||||
optPM.add(llvm::createCFGSimplificationPass());
|
optPM.add(llvm::createCFGSimplificationPass());
|
||||||
@@ -685,6 +689,7 @@ Optimize(llvm::Module *module, int optLevel) {
|
|||||||
|
|
||||||
// Should be the last
|
// Should be the last
|
||||||
optPM.add(CreateFixBooleanSelectPass(), 400);
|
optPM.add(CreateFixBooleanSelectPass(), 400);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@@ -5264,4 +5269,63 @@ CreateFixBooleanSelectPass() {
|
|||||||
return new FixBooleanSelectPass();
|
return new FixBooleanSelectPass();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
// Detect addrspace(3)
|
||||||
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
class PromoteLocalToPrivatePass: public llvm::BasicBlockPass
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
static char ID; // Pass identification, replacement for typeid
|
||||||
|
PromoteLocalToPrivatePass() : BasicBlockPass(ID) {}
|
||||||
|
|
||||||
|
bool runOnBasicBlock(llvm::BasicBlock &BB);
|
||||||
|
};
|
||||||
|
|
||||||
|
char PromoteLocalToPrivatePass::ID = 0;
|
||||||
|
|
||||||
|
bool
|
||||||
|
PromoteLocalToPrivatePass::runOnBasicBlock(llvm::BasicBlock &BB)
|
||||||
|
{
|
||||||
|
std::vector<llvm::AllocaInst*> Allocas;
|
||||||
|
|
||||||
|
bool modifiedAny = false;
|
||||||
|
|
||||||
|
llvm::Function *cvtFunc = m->module->getFunction("__cvt_loc2gen_var");
|
||||||
|
|
||||||
|
// Find allocas that are safe to promote, by looking at all instructions in
|
||||||
|
// the entry node
|
||||||
|
for (llvm::BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
|
||||||
|
{
|
||||||
|
llvm::Instruction *inst = &*I;
|
||||||
|
if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst))
|
||||||
|
{
|
||||||
|
llvm::Function *func = ci->getCalledFunction();
|
||||||
|
if (cvtFunc && (cvtFunc == func))
|
||||||
|
{
|
||||||
|
#if 0
|
||||||
|
fprintf(stderr , "--found cvt-- name= %s \n",
|
||||||
|
I->getName().str().c_str());
|
||||||
|
#endif
|
||||||
|
llvm::AllocaInst *alloca = new llvm::AllocaInst(LLVMTypes::Int64Type, "opt_loc2var", ci);
|
||||||
|
assert(alloca != NULL);
|
||||||
|
#if 0
|
||||||
|
const int align = 8; // g->target->getNativeVectorAlignment();
|
||||||
|
alloca->setAlignment(align);
|
||||||
|
#endif
|
||||||
|
ci->replaceAllUsesWith(alloca);
|
||||||
|
modifiedAny = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modifiedAny;
|
||||||
|
}
|
||||||
|
|
||||||
|
static llvm::Pass *
|
||||||
|
CreatePromoteLocalToPrivatePass() {
|
||||||
|
return new PromoteLocalToPrivatePass();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
18
stmt.cpp
18
stmt.cpp
@@ -142,7 +142,7 @@ lHasUnsizedArrays(const Type *type) {
|
|||||||
return lHasUnsizedArrays(at->GetElementType());
|
return lHasUnsizedArrays(at->GetElementType());
|
||||||
}
|
}
|
||||||
|
|
||||||
static llvm::Value* lConvertToGenericPtr(FunctionEmitContext *ctx, llvm::Value *value, const SourcePos &currentPos)
|
static llvm::Value* lConvertToGenericPtr(FunctionEmitContext *ctx, llvm::Value *value, const SourcePos &currentPos, const bool variable = false)
|
||||||
{
|
{
|
||||||
if (!value->getType()->isPointerTy() || g->target->getISA() != Target::NVPTX)
|
if (!value->getType()->isPointerTy() || g->target->getISA() != Target::NVPTX)
|
||||||
return value;
|
return value;
|
||||||
@@ -159,10 +159,11 @@ static llvm::Value* lConvertToGenericPtr(FunctionEmitContext *ctx, llvm::Value *
|
|||||||
|
|
||||||
/* convert i64* addrspace(3) to i64* */
|
/* convert i64* addrspace(3) to i64* */
|
||||||
llvm::Function *__cvt2gen = m->module->getFunction(
|
llvm::Function *__cvt2gen = m->module->getFunction(
|
||||||
addressSpace == 3 ? "__cvt_loc2gen" : "__cvt_const2gen");
|
addressSpace == 3 ? (variable ? "__cvt_loc2gen_var" : "__cvt_loc2gen") : "__cvt_const2gen");
|
||||||
|
|
||||||
std::vector<llvm::Value *> __cvt2gen_args;
|
std::vector<llvm::Value *> __cvt2gen_args;
|
||||||
__cvt2gen_args.push_back(value);
|
__cvt2gen_args.push_back(value);
|
||||||
value = llvm::CallInst::Create(__cvt2gen, __cvt2gen_args, "gep2gen_cvt", ctx->GetCurrentBasicBlock());
|
value = llvm::CallInst::Create(__cvt2gen, __cvt2gen_args, variable ? "gep2gen_cvt_var" : "gep2gen_cvt", ctx->GetCurrentBasicBlock());
|
||||||
|
|
||||||
/* compute offset */
|
/* compute offset */
|
||||||
if (addressSpace == 3)
|
if (addressSpace == 3)
|
||||||
@@ -333,7 +334,7 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
|
|||||||
* constant uniform are automatically promoted to varying
|
* constant uniform are automatically promoted to varying
|
||||||
*/
|
*/
|
||||||
!sym->type->IsConstType() &&
|
!sym->type->IsConstType() &&
|
||||||
#if 1
|
#if 0
|
||||||
sym->type->IsArrayType() &&
|
sym->type->IsArrayType() &&
|
||||||
#endif
|
#endif
|
||||||
g->target->getISA() == Target::NVPTX)
|
g->target->getISA() == Target::NVPTX)
|
||||||
@@ -345,18 +346,19 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
|
|||||||
/* with __shared__ memory everything must be an array */
|
/* with __shared__ memory everything must be an array */
|
||||||
int nel = 4;
|
int nel = 4;
|
||||||
ArrayType *nat;
|
ArrayType *nat;
|
||||||
|
bool variable = true;
|
||||||
if (sym->type->IsArrayType())
|
if (sym->type->IsArrayType())
|
||||||
{
|
{
|
||||||
const ArrayType *at = CastType<ArrayType>(sym->type);
|
const ArrayType *at = CastType<ArrayType>(sym->type);
|
||||||
nel = at->GetElementCount();
|
|
||||||
/* we must scale # elements by 4, because a thread-block will run 4 warps
|
/* we must scale # elements by 4, because a thread-block will run 4 warps
|
||||||
* or 128 threads.
|
* or 128 threads.
|
||||||
* ***note-to-me***:please define these value (128threads/4warps)
|
* ***note-to-me***:please define these value (128threads/4warps)
|
||||||
* in nvptx-target definition
|
* in nvptx-target definition
|
||||||
* instead of compile-time constants
|
* instead of compile-time constants
|
||||||
*/
|
*/
|
||||||
nel *= 4;
|
nel *= at->GetElementCount();
|
||||||
nat = new ArrayType(at->GetElementType(), nel);
|
nat = new ArrayType(at->GetElementType(), nel);
|
||||||
|
variable = false;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
nat = new ArrayType(sym->type, nel);
|
nat = new ArrayType(sym->type, nel);
|
||||||
@@ -375,7 +377,7 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
|
|||||||
NULL,
|
NULL,
|
||||||
llvm::GlobalVariable::NotThreadLocal,
|
llvm::GlobalVariable::NotThreadLocal,
|
||||||
/*AddressSpace=*/3);
|
/*AddressSpace=*/3);
|
||||||
sym->storagePtr = lConvertToGenericPtr(ctx, sym->storagePtr, sym->pos);
|
sym->storagePtr = lConvertToGenericPtr(ctx, sym->storagePtr, sym->pos, variable);
|
||||||
llvm::PointerType *ptrTy = llvm::PointerType::get(sym->type->LLVMType(g->ctx),0);
|
llvm::PointerType *ptrTy = llvm::PointerType::get(sym->type->LLVMType(g->ctx),0);
|
||||||
sym->storagePtr = ctx->BitCastInst(sym->storagePtr, ptrTy, "uniform_decl");
|
sym->storagePtr = ctx->BitCastInst(sym->storagePtr, ptrTy, "uniform_decl");
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user