Added a basic optimization pass that promotes non-array uniform variables to varying for the NVPTX target

This commit is contained in:
evghenii
2014-01-10 06:32:57 +01:00
parent b6b8855728
commit 9053eed4b4
5 changed files with 83 additions and 9 deletions

View File

@@ -630,7 +630,9 @@ lSetInternalFunctions(llvm::Module *module) {
"__nctaid_y", "__nctaid_y",
"__nctaid_z", "__nctaid_z",
"__warpsize", "__warpsize",
"__cvt_loc2gen" "__cvt_loc2gen",
"__cvt_loc2gen_var",
"__cvt_const2gen"
}; };
int count = sizeof(names) / sizeof(names[0]); int count = sizeof(names) / sizeof(names[0]);

View File

@@ -69,6 +69,11 @@ define i64* @__cvt_loc2gen(i64 addrspace(3)*) nounwind readnone alwaysinline
%ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p3i64(i64 addrspace(3)* %0) %ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p3i64(i64 addrspace(3)* %0)
ret i64* %ptr ret i64* %ptr
} }
;; Converts a pointer in addrspace(3) to a generic pointer through the
;; llvm.nvvm.ptr.shared.to.gen intrinsic.  The body is the same as
;; __cvt_loc2gen; the separate symbol name lets an optimization pass tell the
;; conversions emitted for non-array variables apart from the array ones.
define i64* @__cvt_loc2gen_var(i64 addrspace(3)*) nounwind readnone alwaysinline
{
  %generic = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p3i64(i64 addrspace(3)* %0)
  ret i64* %generic
}
define i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline define i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline
{ {
%ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p4i64(i64 addrspace(4)* %0) %ptr = tail call i64* @llvm.nvvm.ptr.shared.to.gen.p0i64.p4i64(i64 addrspace(4)* %0)

View File

@@ -4544,5 +4544,6 @@ declare i32 @__nctaid_y() nounwind readnone alwaysinline
declare i32 @__nctaid_z() nounwind readnone alwaysinline declare i32 @__nctaid_z() nounwind readnone alwaysinline
declare i64* @__cvt_loc2gen(i64 addrspace(3)*) nounwind readnone alwaysinline declare i64* @__cvt_loc2gen(i64 addrspace(3)*) nounwind readnone alwaysinline
declare i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline declare i64* @__cvt_const2gen(i64 addrspace(4)*) nounwind readnone alwaysinline
declare i64* @__cvt_loc2gen_var(i64 addrspace(3)*) nounwind readnone alwaysinline
') ')

64
opt.cpp
View File

@@ -128,6 +128,7 @@ static llvm::Pass *CreateDebugPass(char * output);
static llvm::Pass *CreateReplaceStdlibShiftPass(); static llvm::Pass *CreateReplaceStdlibShiftPass();
static llvm::Pass *CreateFixBooleanSelectPass(); static llvm::Pass *CreateFixBooleanSelectPass();
static llvm::Pass *CreatePromoteLocalToPrivatePass();
#define DEBUG_START_PASS(NAME) \ #define DEBUG_START_PASS(NAME) \
if (g->debugPrint && \ if (g->debugPrint && \
@@ -574,6 +575,9 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(llvm::createReassociatePass()); optPM.add(llvm::createReassociatePass());
optPM.add(llvm::createIPConstantPropagationPass()); optPM.add(llvm::createIPConstantPropagationPass());
optPM.add(CreateReplaceStdlibShiftPass(),229); optPM.add(CreateReplaceStdlibShiftPass(),229);
if (g->target->getISA() == Target::NVPTX)
optPM.add(CreatePromoteLocalToPrivatePass());
#if 1
optPM.add(llvm::createDeadArgEliminationPass(),230); optPM.add(llvm::createDeadArgEliminationPass(),230);
optPM.add(llvm::createInstructionCombiningPass()); optPM.add(llvm::createInstructionCombiningPass());
optPM.add(llvm::createCFGSimplificationPass()); optPM.add(llvm::createCFGSimplificationPass());
@@ -685,6 +689,7 @@ Optimize(llvm::Module *module, int optLevel) {
// Should be the last // Should be the last
optPM.add(CreateFixBooleanSelectPass(), 400); optPM.add(CreateFixBooleanSelectPass(), 400);
#endif
} }
#endif #endif
@@ -5264,4 +5269,63 @@ CreateFixBooleanSelectPass() {
return new FixBooleanSelectPass(); return new FixBooleanSelectPass();
} }
///////////////////////////////////////////////////////////////////////////////
// PromoteLocalToPrivatePass
//
// Basic-block pass that searches for calls to the "__cvt_loc2gen_var"
// builtin (the addrspace(3) -> generic pointer conversion) and redirects
// their users to a newly created i64 alloca.
///////////////////////////////////////////////////////////////////////////////
class PromoteLocalToPrivatePass : public llvm::BasicBlockPass {
public:
    PromoteLocalToPrivatePass()
        : BasicBlockPass(ID) {
    }

    /** Processes a single basic block; returns true when at least one call
        was rewritten. */
    bool runOnBasicBlock(llvm::BasicBlock &BB);

    /** Pass identification, replacement for typeid. */
    static char ID;
};

char PromoteLocalToPrivatePass::ID = 0;
/** Rewrites every call to "__cvt_loc2gen_var" found in the given basic
    block: an i64 alloca named "opt_loc2var" is inserted immediately before
    the call, and all users of the call's result are redirected to that
    alloca.

    The call instruction itself is not erased here; it is only left without
    users (presumably a later dead-code pass cleans it up -- the builtin is
    declared readnone).

    NOTE(review): the alloca is created at the call site, not in the
    function's entry block; confirm that downstream passes and the backend
    handle allocas in that position.

    @param BB  Basic block to scan and modify.
    @return    true if at least one call was replaced, false otherwise.
 */
bool
PromoteLocalToPrivatePass::runOnBasicBlock(llvm::BasicBlock &BB)
{
    // If the module does not contain the builtin, no call can match.
    llvm::Function *cvtFunc = m->module->getFunction("__cvt_loc2gen_var");
    if (cvtFunc == NULL)
        return false;

    bool modifiedAny = false;

    // Walk the whole block.  Iterating up to end() (instead of --end()) is
    // well defined for an empty block too.  Inserting the alloca before the
    // current instruction does not invalidate the instruction-list iterator.
    for (llvm::BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I)
    {
        llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(&*I);
        // getCalledFunction() yields NULL for indirect calls, which can
        // never compare equal to the (non-NULL) builtin.
        if (ci == NULL || ci->getCalledFunction() != cvtFunc)
            continue;

        // The alloca has type i64*, the same type the builtin returns, so it
        // can stand in for the call's result directly.
        llvm::AllocaInst *alloca =
            new llvm::AllocaInst(LLVMTypes::Int64Type, "opt_loc2var", ci);
        ci->replaceAllUsesWith(alloca);
        modifiedAny = true;
    }

    return modifiedAny;
}
/** Factory for the pass; hands back a newly heap-allocated
    PromoteLocalToPrivatePass as a generic llvm::Pass pointer. */
static llvm::Pass *
CreatePromoteLocalToPrivatePass() {
    llvm::Pass *promotePass = new PromoteLocalToPrivatePass();
    return promotePass;
}

View File

@@ -142,7 +142,7 @@ lHasUnsizedArrays(const Type *type) {
return lHasUnsizedArrays(at->GetElementType()); return lHasUnsizedArrays(at->GetElementType());
} }
static llvm::Value* lConvertToGenericPtr(FunctionEmitContext *ctx, llvm::Value *value, const SourcePos &currentPos) static llvm::Value* lConvertToGenericPtr(FunctionEmitContext *ctx, llvm::Value *value, const SourcePos &currentPos, const bool variable = false)
{ {
if (!value->getType()->isPointerTy() || g->target->getISA() != Target::NVPTX) if (!value->getType()->isPointerTy() || g->target->getISA() != Target::NVPTX)
return value; return value;
@@ -159,10 +159,11 @@ static llvm::Value* lConvertToGenericPtr(FunctionEmitContext *ctx, llvm::Value *
/* convert i64* addrspace(3) to i64* */ /* convert i64* addrspace(3) to i64* */
llvm::Function *__cvt2gen = m->module->getFunction( llvm::Function *__cvt2gen = m->module->getFunction(
addressSpace == 3 ? "__cvt_loc2gen" : "__cvt_const2gen"); addressSpace == 3 ? (variable ? "__cvt_loc2gen_var" : "__cvt_loc2gen") : "__cvt_const2gen");
std::vector<llvm::Value *> __cvt2gen_args; std::vector<llvm::Value *> __cvt2gen_args;
__cvt2gen_args.push_back(value); __cvt2gen_args.push_back(value);
value = llvm::CallInst::Create(__cvt2gen, __cvt2gen_args, "gep2gen_cvt", ctx->GetCurrentBasicBlock()); value = llvm::CallInst::Create(__cvt2gen, __cvt2gen_args, variable ? "gep2gen_cvt_var" : "gep2gen_cvt", ctx->GetCurrentBasicBlock());
/* compute offset */ /* compute offset */
if (addressSpace == 3) if (addressSpace == 3)
@@ -333,7 +334,7 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
* constant uniform are automatically promoted to varying * constant uniform are automatically promoted to varying
*/ */
!sym->type->IsConstType() && !sym->type->IsConstType() &&
#if 1 #if 0
sym->type->IsArrayType() && sym->type->IsArrayType() &&
#endif #endif
g->target->getISA() == Target::NVPTX) g->target->getISA() == Target::NVPTX)
@@ -345,18 +346,19 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
/* with __shared__ memory everything must be an array */ /* with __shared__ memory everything must be an array */
int nel = 4; int nel = 4;
ArrayType *nat; ArrayType *nat;
bool variable = true;
if (sym->type->IsArrayType()) if (sym->type->IsArrayType())
{ {
const ArrayType *at = CastType<ArrayType>(sym->type); const ArrayType *at = CastType<ArrayType>(sym->type);
nel = at->GetElementCount();
/* we must scale # elements by 4, because a thread-block will run 4 warps /* we must scale # elements by 4, because a thread-block will run 4 warps
* or 128 threads. * or 128 threads.
* ***note-to-me***:please define these value (128threads/4warps) * ***note-to-me***:please define these value (128threads/4warps)
* in nvptx-target definition * in nvptx-target definition
* instead of compile-time constants * instead of compile-time constants
*/ */
nel *= 4; nel *= at->GetElementCount();
nat = new ArrayType(at->GetElementType(), nel); nat = new ArrayType(at->GetElementType(), nel);
variable = false;
} }
else else
nat = new ArrayType(sym->type, nel); nat = new ArrayType(sym->type, nel);
@@ -375,7 +377,7 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
NULL, NULL,
llvm::GlobalVariable::NotThreadLocal, llvm::GlobalVariable::NotThreadLocal,
/*AddressSpace=*/3); /*AddressSpace=*/3);
sym->storagePtr = lConvertToGenericPtr(ctx, sym->storagePtr, sym->pos); sym->storagePtr = lConvertToGenericPtr(ctx, sym->storagePtr, sym->pos, variable);
llvm::PointerType *ptrTy = llvm::PointerType::get(sym->type->LLVMType(g->ctx),0); llvm::PointerType *ptrTy = llvm::PointerType::get(sym->type->LLVMType(g->ctx),0);
sym->storagePtr = ctx->BitCastInst(sym->storagePtr, ptrTy, "uniform_decl"); sym->storagePtr = ctx->BitCastInst(sym->storagePtr, ptrTy, "uniform_decl");