From 7d37f7b63481073361e1fea9ad5ec93d21d8e860 Mon Sep 17 00:00:00 2001 From: evghenii Date: Tue, 7 Jan 2014 18:29:44 +0100 Subject: [PATCH] added separate function that deal with local pointers --- ctx.cpp | 41 ++++++++++++++++++++++++++++------------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/ctx.cpp b/ctx.cpp index 410deea0..ec0049c9 100644 --- a/ctx.cpp +++ b/ctx.cpp @@ -1832,6 +1832,34 @@ FunctionEmitContext::BitCastInst(llvm::Value *value, llvm::Type *type, return inst; } +static llvm::Value* lCorrectLocalPtr(FunctionEmitContext *ctx, llvm::Value* value) +{ + assert(value->getType()->isPointerTy()); + llvm::PointerType *pt = llvm::dyn_cast(value->getType()); + assert (pt->getAddressSpace() == 3); + + llvm::Function *func_tid_x = m->module->getFunction("__tid_x"); + llvm::Function *func_warpsz = m->module->getFunction("__warpsize"); + llvm::Value *__tid_x = ctx->CallInst(func_tid_x, NULL, std::vector(), "tidCorrectLocalPtr"); + llvm::Value *__warpsz = ctx->CallInst(func_warpsz, NULL, std::vector(), "warpSzCorrectLocaLPtr"); + llvm::Value *__warpid = ctx->BinaryOperator(llvm::Instruction::SDiv, __tid_x, __warpsz, "warpIdCorrectLocalPtr"); + return llvm::GetElementPtrInst::Create(value, __warpid, "__gepCorrectLocalPtr", ctx->GetCurrentBasicBlock()); +} + +static llvm::Value* lConvertLocalToGenericPtr(FunctionEmitContext *ctx, llvm::Value *value) +{ + if (!value->getType()->isPointerTy() || g->target->getISA() != Target::NVPTX) return value; + llvm::PointerType *pt = llvm::dyn_cast(value->getType()); + if (pt->getAddressSpace() != 3) return value; + + value = lCorrectLocalPtr(ctx, value); + llvm::PointerType *PointerTy = llvm::PointerType::get(LLVMTypes::Int64Type, 3); + llvm::Value *cast = ctx->BitCastInst(value, PointerTy, "__cvt_log2gen_i64ptr1_"); + llvm::Function *__cvt_loc2gen = m->module->getFunction("__cvt_loc2gen"); + std::vector __cvt_loc2gen_args; + __cvt_loc2gen_args.push_back(cast); + return ctx->CallInst(__cvt_loc2gen, NULL, __cvt_loc2gen_args, "__cvt_loc2gen1_"); +} llvm::Value * FunctionEmitContext::PtrToIntInst(llvm::Value *value, const char *name) { @@ -1847,19 +1875,6 @@ FunctionEmitContext::PtrToIntInst(llvm::Value *value, const char *name) { if (name == NULL) name = LLVMGetName(value, "_ptr2int"); - if (value->getType()->isPointerTy() && g->target->getISA() == Target::NVPTX) - { - llvm::PointerType *pt = llvm::dyn_cast(value->getType()); - if (pt->getAddressSpace() == 3) - { - llvm::PointerType *PointerTy3 = llvm::PointerType::get(LLVMTypes::Int64Type, 3); - llvm::Value *cast = BitCastInst(value, PointerTy3, "__cvt_log2gen_i64ptr1_"); - llvm::Function *__cvt_loc2gen = m->module->getFunction("__cvt_loc2gen"); - std::vector __cvt_loc2gen_args; - __cvt_loc2gen_args.push_back(cast); - value = CallInst(__cvt_loc2gen, NULL, __cvt_loc2gen_args, "__cvt_loc2gen1_"); - } - } llvm::Type *type = LLVMTypes::PointerIntType; llvm::Instruction *inst = new llvm::PtrToIntInst(value, type, name, bblock);