diff --git a/builtins-c.c b/builtins-c.c index 30cbe106..f1cb35dd 100644 --- a/builtins-c.c +++ b/builtins-c.c @@ -133,6 +133,8 @@ void __do_print(const char *format, const char *types, int width, int mask, case 'V': PRINT_VECTOR("%llu", unsigned long long); case 'd': PRINT_SCALAR("%f", double); case 'D': PRINT_VECTOR("%f", double); + case 'p': PRINT_SCALAR("%p", void *); + case 'P': PRINT_VECTOR("%p", void *); default: printf("UNKNOWN TYPE "); putchar(*types); diff --git a/builtins.m4 b/builtins.m4 index 957af297..141f6ebd 100644 --- a/builtins.m4 +++ b/builtins.m4 @@ -1080,6 +1080,14 @@ define internal <$1 x i32> @__sext_varying_bool(<$1 x i32>) nounwind readnone al ret <$1 x i32> %0 } +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; count trailing zeros + +define internal i32 @__count_trailing_zeros(i32) nounwind readnone alwaysinline { + %c = call i32 @llvm.cttz.i32(i32 %0) + ret i32 %c +} + ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; AOS/SOA conversion primitives @@ -2457,7 +2465,7 @@ done: ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; reduce_equal -; count leading zeros +; count trailing zeros declare i32 @llvm.cttz.i32(i32) define(`reduce_equal_aux', ` diff --git a/ctx.cpp b/ctx.cpp index d7219ee2..4a5bc561 100644 --- a/ctx.cpp +++ b/ctx.cpp @@ -735,7 +735,7 @@ FunctionEmitContext::LaneMask(llvm::Value *v) { // There should be one with signed int signature, one unsigned int. 
assert(mm && mm->size() == 2); llvm::Function *fmm = (*mm)[0]->function; - return CallInst(fmm, v, "val_movmsk"); + return CallInst(fmm, AtomicType::UniformInt32, v, "val_movmsk"); } @@ -858,7 +858,7 @@ FunctionEmitContext::AddInstrumentationPoint(const char *note) { args.push_back(LaneMask(GetFullMask())); llvm::Function *finst = m->module->getFunction("ISPCInstrument"); - CallInst(finst, args, ""); + CallInst(finst, AtomicType::Void, args, ""); } @@ -976,12 +976,18 @@ lArrayVectorWidth(LLVM_TYPE_CONST llvm::Type *t) { if (arrayType == NULL) return 0; - // We shouldn't be seeing arrays of anything but vectors being passed - // to things like FunctionEmitContext::BinaryOperator() as operands + // We shouldn't be seeing arrays of anything but vectors or pointers + // (for == and !=) being passed to things like + // FunctionEmitContext::BinaryOperator() as operands LLVM_TYPE_CONST llvm::VectorType *vectorElementType = llvm::dyn_cast(arrayType->getElementType()); - assert(vectorElementType != NULL && - (int)vectorElementType->getNumElements() == g->target.vectorWidth); + LLVM_TYPE_CONST llvm::PointerType *pointerElementType = + llvm::dyn_cast(arrayType->getElementType()); + assert((vectorElementType != NULL && + (int)vectorElementType->getNumElements() == g->target.vectorWidth) || + (pointerElementType != NULL && + (int)arrayType->getNumElements() == g->target.vectorWidth)); + return (int)arrayType->getNumElements(); } @@ -1052,24 +1058,30 @@ FunctionEmitContext::NotOperator(llvm::Value *v, const char *name) { } -// Given the llvm Type that represents an ispc VectorType, return an -// equally-shaped type with boolean elements. (This is the type that will -// be returned from CmpInst with ispc VectorTypes). +// Given the llvm Type that represents an ispc VectorType (or array of +// pointers), return an equally-shaped type with boolean elements. (This +// is the type that will be returned from CmpInst with ispc VectorTypes). 
static LLVM_TYPE_CONST llvm::Type * lGetMatchingBoolVectorType(LLVM_TYPE_CONST llvm::Type *type) { LLVM_TYPE_CONST llvm::ArrayType *arrayType = llvm::dyn_cast(type); - // should only be called for vector typed stuff... assert(arrayType != NULL); - LLVM_TYPE_CONST llvm::VectorType *vectorElementType = + LLVM_TYPE_CONST llvm::VectorType *vectorElementType = llvm::dyn_cast(arrayType->getElementType()); - assert(vectorElementType != NULL && - (int)vectorElementType->getNumElements() == g->target.vectorWidth); - - LLVM_TYPE_CONST llvm::Type *base = - llvm::VectorType::get(LLVMTypes::BoolType, g->target.vectorWidth); - return llvm::ArrayType::get(base, arrayType->getNumElements()); + if (vectorElementType != NULL) { + assert((int)vectorElementType->getNumElements() == g->target.vectorWidth); + LLVM_TYPE_CONST llvm::Type *base = + llvm::VectorType::get(LLVMTypes::BoolType, g->target.vectorWidth); + return llvm::ArrayType::get(base, arrayType->getNumElements()); + } + else { + LLVM_TYPE_CONST llvm::PointerType *pointerElementType = + llvm::dyn_cast(arrayType->getElementType()); + assert(pointerElementType != NULL); + assert((int)arrayType->getNumElements() == g->target.vectorWidth); + return llvm::VectorType::get(LLVMTypes::BoolType, g->target.vectorWidth); + } } @@ -1213,7 +1225,7 @@ FunctionEmitContext::IntToPtrInst(llvm::Value *value, LLVM_TYPE_CONST llvm::Type LLVM_TYPE_CONST llvm::Type *valType = value->getType(); LLVM_TYPE_CONST llvm::ArrayType *at = llvm::dyn_cast(valType); - if (at && llvm::isa(at->getElementType())) { + if (at != NULL) { // varying lvalue -> apply int to ptr to the individual pointers assert((int)at->getNumElements() == g->target.vectorWidth); @@ -1252,6 +1264,48 @@ FunctionEmitContext::TruncInst(llvm::Value *value, LLVM_TYPE_CONST llvm::Type *t } +llvm::Value * +FunctionEmitContext::ArrayToVectorInst(llvm::Value *array) { + if (array == NULL) { + assert(m->errorCount > 0); + return NULL; + } + + LLVM_TYPE_CONST llvm::ArrayType *at = + 
llvm::dyn_cast(array->getType()); + assert(at != NULL); + + uint64_t count = at->getNumElements(); + LLVM_TYPE_CONST llvm::VectorType *vt = + llvm::VectorType::get(at->getElementType(), count); + llvm::Value *vec = llvm::UndefValue::get(vt); + for (uint64_t i = 0; i < count; ++i) + vec = InsertInst(vec, ExtractInst(array, i), i); + return vec; +} + + +llvm::Value * +FunctionEmitContext::VectorToArrayInst(llvm::Value *vector) { + if (vector == NULL) { + assert(m->errorCount > 0); + return NULL; + } + + LLVM_TYPE_CONST llvm::VectorType *vt = + llvm::dyn_cast(vector->getType()); + assert(vt != NULL); + + uint64_t count = vt->getNumElements(); + LLVM_TYPE_CONST llvm::ArrayType *at = + llvm::ArrayType::get(vt->getElementType(), count); + llvm::Value *array = llvm::UndefValue::get(at); + for (uint64_t i = 0; i < count; ++i) + array = InsertInst(array, ExtractInst(vector, i), i); + return array; +} + + llvm::Instruction * FunctionEmitContext::CastInst(llvm::Instruction::CastOps op, llvm::Value *value, LLVM_TYPE_CONST llvm::Type *type, const char *name) { @@ -1504,16 +1558,23 @@ FunctionEmitContext::gather(llvm::Value *lvalue, llvm::Value *mask, } return retValue; } - - // Otherwise we should just have a basic scalar type and we can go and - // do the actual gather + + // Otherwise we should just have a basic scalar or pointer type and we + // can go and do the actual gather AddInstrumentationPoint("gather"); llvm::Function *gather = NULL; // Figure out which gather function to call based on the size of // the elements. 
- if (retType == LLVMTypes::DoubleVectorType || - retType == LLVMTypes::Int64VectorType) + const PointerType *pt = dynamic_cast(type); + if (pt != NULL) { + if (g->target.is32bit) + gather = m->module->getFunction("__pseudo_gather_32"); + else + gather = m->module->getFunction("__pseudo_gather_64"); + } + else if (retType == LLVMTypes::DoubleVectorType || + retType == LLVMTypes::Int64VectorType) gather = m->module->getFunction("__pseudo_gather_64"); else if (retType == LLVMTypes::FloatVectorType || retType == LLVMTypes::Int32VectorType) @@ -1529,15 +1590,21 @@ FunctionEmitContext::gather(llvm::Value *lvalue, llvm::Value *mask, lvalue = addVaryingOffsetsIfNeeded(lvalue, type); llvm::Value *voidlvalue = BitCastInst(lvalue, LLVMTypes::VoidPointerType); - llvm::Instruction *call = CallInst(gather, voidlvalue, mask, name); + llvm::Value *call = CallInst(gather, type, voidlvalue, mask, name); + // Add metadata about the source file location so that the // optimization passes can print useful performance warnings if we // can't optimize out this gather addGSMetadata(call, currentPos); - llvm::Value *val = BitCastInst(call, retType, "gather_bitcast"); - - return val; + if (pt != NULL) { + LLVM_TYPE_CONST llvm::Type *ptrType = + pt->GetAsUniformType()->LLVMType(g->ctx); + return IntToPtrInst(VectorToArrayInst(call), ptrType, + "gather_bitcast"); + } + else + return BitCastInst(call, retType, "gather_bitcast"); } @@ -1546,7 +1613,11 @@ FunctionEmitContext::gather(llvm::Value *lvalue, llvm::Value *mask, function in opt.cpp. 
*/ void -FunctionEmitContext::addGSMetadata(llvm::Instruction *inst, SourcePos pos) { +FunctionEmitContext::addGSMetadata(llvm::Value *v, SourcePos pos) { + llvm::Instruction *inst = llvm::dyn_cast(v); + if (inst == NULL) + return; + llvm::Value *str = llvm::MDString::get(*g->ctx, pos.name); llvm::MDNode *md = llvm::MDNode::get(*g->ctx, str); inst->setMetadata("filename", md); @@ -1628,6 +1699,19 @@ FunctionEmitContext::maskedStore(llvm::Value *rvalue, llvm::Value *lvalue, return; } + const PointerType *pt = dynamic_cast(rvalueType); + if (pt != NULL) { + if (g->target.is32bit) { + rvalue = PtrToIntInst(rvalue, LLVMTypes::Int32Type, "ptr2int"); + rvalueType = AtomicType::VaryingInt32; + } + else { + rvalue = PtrToIntInst(rvalue, LLVMTypes::Int64Type, "ptr2int"); + rvalueType = AtomicType::VaryingInt64; + } + rvalue = ArrayToVectorInst(rvalue); + } + // We must have a regular atomic or enumerator type at this point assert(dynamic_cast(rvalueType) != NULL || dynamic_cast(rvalueType) != NULL); @@ -1673,7 +1757,7 @@ FunctionEmitContext::maskedStore(llvm::Value *rvalue, llvm::Value *lvalue, args.push_back(lvalue); args.push_back(rvalue); args.push_back(storeMask); - CallInst(maskedStoreFunc, args); + CallInst(maskedStoreFunc, AtomicType::Void, args); } @@ -1722,12 +1806,27 @@ FunctionEmitContext::scatter(llvm::Value *rvalue, llvm::Value *lvalue, // I think this should be impossible assert(dynamic_cast(rvalueType) == NULL); - // And everything should be atomic from here on out... - assert(dynamic_cast(rvalueType) != NULL); + const PointerType *pt = dynamic_cast(rvalueType); + + // And everything should be a pointer or atomic from here on out... 
+ assert(pt != NULL || + dynamic_cast(rvalueType) != NULL); llvm::Function *func = NULL; LLVM_TYPE_CONST llvm::Type *type = rvalue->getType(); - if (type == LLVMTypes::DoubleVectorType || + if (pt != NULL) { + if (g->target.is32bit) { + rvalue = PtrToIntInst(rvalue, LLVMTypes::Int32Type); + rvalue = ArrayToVectorInst(rvalue); + func = m->module->getFunction("__pseudo_scatter_32"); + } + else { + rvalue = PtrToIntInst(rvalue, LLVMTypes::Int64Type); + rvalue = ArrayToVectorInst(rvalue); + func = m->module->getFunction("__pseudo_scatter_64"); + } + } + else if (type == LLVMTypes::DoubleVectorType || type == LLVMTypes::Int64VectorType) { func = m->module->getFunction("__pseudo_scatter_64"); rvalue = BitCastInst(rvalue, LLVMTypes::Int64VectorType, "rvalue2int"); @@ -1752,7 +1851,7 @@ FunctionEmitContext::scatter(llvm::Value *rvalue, llvm::Value *lvalue, args.push_back(voidlvalue); args.push_back(rvalue); args.push_back(storeMask); - llvm::Instruction *inst = CallInst(func, args); + llvm::Value *inst = CallInst(func, AtomicType::Void, args); addGSMetadata(inst, currentPos); } @@ -1900,8 +1999,33 @@ FunctionEmitContext::SelectInst(llvm::Value *test, llvm::Value *val0, } -llvm::Instruction * -FunctionEmitContext::CallInst(llvm::Function *func, +/* Given a value representing a function to be called or possibly-varying + pointer to a function to be called, figure out how many arguments the + function has. */ +static unsigned int +lCalleeArgCount(llvm::Value *callee) { + LLVM_TYPE_CONST llvm::FunctionType *ft = + llvm::dyn_cast(callee->getType()); + if (ft == NULL) { + LLVM_TYPE_CONST llvm::PointerType *pt = + llvm::dyn_cast(callee->getType()); + if (pt == NULL) { + // varying... 
+ LLVM_TYPE_CONST llvm::ArrayType *at = + llvm::dyn_cast(callee->getType()); + assert(at != NULL); + pt = llvm::dyn_cast(at->getElementType()); + assert(pt != NULL); + } + ft = llvm::dyn_cast(pt->getElementType()); + } + assert(ft != NULL); + return ft->getNumParams(); +} + + +llvm::Value * +FunctionEmitContext::CallInst(llvm::Value *func, const Type *returnType, const std::vector &args, const char *name) { if (func == NULL) { @@ -1909,62 +2033,186 @@ FunctionEmitContext::CallInst(llvm::Function *func, return NULL; } + std::vector argVals = args; + // Most of the time, the mask is passed as the last argument. this + // isn't the case for things like intrinsics, builtins, and extern "C" + // functions from the application. Add the mask if it's needed. + unsigned int calleeArgCount = lCalleeArgCount(func); + assert(argVals.size() + 1 == calleeArgCount || + argVals.size() == calleeArgCount); + if (argVals.size() + 1 == calleeArgCount) + argVals.push_back(GetFullMask()); + + LLVM_TYPE_CONST llvm::Type *funcType = func->getType(); + LLVM_TYPE_CONST llvm::ArrayType *funcArrayType = + llvm::dyn_cast(funcType); + + if (funcArrayType == NULL) { + // Regular 'uniform' function call--just one function or function + // pointer, so just emit the IR directly. #if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn) - llvm::Instruction *ci = - llvm::CallInst::Create(func, args, name ? name : "", bblock); + llvm::Instruction *ci = + llvm::CallInst::Create(func, argVals, name ? name : "", bblock); #else - llvm::Instruction *ci = - llvm::CallInst::Create(func, args.begin(), args.end(), - name ? name : "", bblock); + llvm::Instruction *ci = + llvm::CallInst::Create(func, argVals.begin(), argVals.end(), + name ? name : "", bblock); #endif - AddDebugPos(ci); - return ci; + AddDebugPos(ci); + return ci; + } + else { + // Emit the code for a varying function call, where we have an + // array of function pointers, one for each program instance. 
The + // basic strategy is that we go through the function pointers, and + // for the executing program instances, for each unique function + // pointer that's in the array, call that function with a mask + // equal to the set of active program instances that also have that + // function pointer. When all unique function pointers have been + // called, we're done. + + llvm::BasicBlock *bbTest = CreateBasicBlock("varying_funcall_test"); + llvm::BasicBlock *bbCall = CreateBasicBlock("varying_funcall_call"); + llvm::BasicBlock *bbDone = CreateBasicBlock("varying_funcall_done"); + + llvm::Value *origMask = GetInternalMask(); + + // First allocate memory to accumulate the various lanes' return + // values... + LLVM_TYPE_CONST llvm::Type *llvmReturnType = returnType->LLVMType(g->ctx); + llvm::Value *resultPtr = NULL; + if (llvmReturnType->isVoidTy() == false) + resultPtr = AllocaInst(llvmReturnType); + + // Store the function pointers into an array so that we can index + // into them.. + llvm::Value *funcPtrArray = AllocaInst(funcType); + StoreInst(func, funcPtrArray); + + // The memory pointed to by maskPointer tracks the set of program + // instances for which we still need to call the function they are + // pointing to. It starts out initialized with the mask of + // currently running program instances. + llvm::Value *maskPtr = AllocaInst(LLVMTypes::MaskType); + StoreInst(GetFullMask(), maskPtr); + + // And now we branch to the test to see if there's more work to be + // done. + BranchInst(bbTest); + + // bbTest: are any lanes of the mask still on? If so, jump to + // bbCall + SetCurrentBasicBlock(bbTest); { + llvm::Value *maskLoad = LoadInst(maskPtr, NULL, NULL); + llvm::Value *any = Any(maskLoad); + BranchInst(bbCall, bbDone, any); + } + + // bbCall: this is the body of the loop that calls out to one of + // the active function pointer values. + SetCurrentBasicBlock(bbCall); { + // Figure out the first lane that still needs its function + // pointer to be called. 
+ llvm::Value *currentMask = LoadInst(maskPtr, NULL, NULL); + llvm::Function *cttz = m->module->getFunction("__count_trailing_zeros"); + assert(cttz != NULL); + llvm::Value *firstLane = CallInst(cttz, AtomicType::UniformInt32, + LaneMask(currentMask), "first_lane"); + + // Get the pointer to the function we're going to call this time through: + // ftpr = funcPtrArray[firstLane] + llvm::Value *fpOffset = + GetElementPtrInst(funcPtrArray, LLVMInt32(0), firstLane, + "func_offset_ptr"); + llvm::Value *fptr = LoadInst(fpOffset, NULL, NULL); + + // Smear it out into an array of function pointers + llvm::Value *fptrSmear = SmearScalar(fptr, "func_ptr"); + + // Now convert the smeared array of function pointers and the + // given array of function pointers to vectors of int32s or + // int64s, where the pointer has been cast to an int of the + // appropraite size for the compilation target. + LLVM_TYPE_CONST llvm::Type *ptrIntType = g->target.is32bit ? + LLVMTypes::Int32Type : LLVMTypes::Int64Type; + llvm::Value *fpSmearAsVec = + ArrayToVectorInst(PtrToIntInst(fptrSmear, ptrIntType)); + llvm::Value *fpOrigAsVec = + ArrayToVectorInst(PtrToIntInst(func, ptrIntType)); + + // fpOverlap = (fpSmearAsVec == fpOrigAsVec). This gives us a + // mask for the set of program instances that have the same + // value for their function pointer. + llvm::Value *fpOverlap = + CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, + fpSmearAsVec, fpOrigAsVec); + fpOverlap = I1VecToBoolVec(fpOverlap); + + // Figure out the mask to use when calling the function + // pointer: we need to AND the current execution mask to handle + // the case of any non-running program instances that happen to + // have this function pointer value. 
+ // callMask = (currentMask & fpOverlap) + llvm::Value *callMask = + BinaryOperator(llvm::Instruction::And, currentMask, fpOverlap, + "call_mask"); + + // Set the mask + SetInternalMask(callMask); + + // Call the function: callResult = call ftpr(args, args, call mask) + llvm::Value *callResult = CallInst(fptr, returnType, args, name); + + // Now, do a masked store into the memory allocated to + // accumulate the result using the call mask. + if (callResult != NULL) { + assert(resultPtr != NULL); + StoreInst(callResult, resultPtr, callMask, returnType); + } + else + assert(resultPtr == NULL); + + // Update the mask to turn off the program instances for which + // we just called the function. + // currentMask = currentMask & ~callmask + llvm::Value *notCallMask = + BinaryOperator(llvm::Instruction::Xor, callMask, LLVMMaskAllOn, + "~callMask"); + currentMask = BinaryOperator(llvm::Instruction::And, currentMask, + notCallMask, "currentMask&~callMask"); + StoreInst(currentMask, maskPtr); + + // And go back to the test to see if we need to do another + // call. + BranchInst(bbTest); + } + + // bbDone: We're all done; clean up and return the result we've + // accumulated in the result memory. 
+ SetCurrentBasicBlock(bbDone); + SetInternalMask(origMask); + return LoadInst(resultPtr, NULL, NULL); + } } -llvm::Instruction * -FunctionEmitContext::CallInst(llvm::Function *func, llvm::Value *arg, +llvm::Value * +FunctionEmitContext::CallInst(llvm::Value *func, const Type *returnType, + llvm::Value *arg, const char *name) { + std::vector args; + args.push_back(arg); + return CallInst(func, returnType, args, name); +} + + +llvm::Value * +FunctionEmitContext::CallInst(llvm::Value *func, const Type *returnType, + llvm::Value *arg0, llvm::Value *arg1, const char *name) { - if (func == NULL || arg == NULL) { - assert(m->errorCount > 0); - return NULL; - } - -#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn) - llvm::Instruction *ci = - llvm::CallInst::Create(func, arg, name ? name : "", bblock); -#else - llvm::Value *args[] = { arg }; - llvm::Instruction *ci = - llvm::CallInst::Create(func, &args[0], &args[1], name ? name : "", - bblock); -#endif - AddDebugPos(ci); - return ci; -} - - -llvm::Instruction * -FunctionEmitContext::CallInst(llvm::Function *func, llvm::Value *arg0, - llvm::Value *arg1, const char *name) { - if (func == NULL || arg0 == NULL || arg1 == NULL) { - assert(m->errorCount > 0); - return NULL; - } - - llvm::Value *args[] = { arg0, arg1 }; -#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn) - llvm::ArrayRef argArrayRef(&args[0], &args[2]); - llvm::Instruction *ci = - llvm::CallInst::Create(func, argArrayRef, name ? name : "", - bblock); -#else - llvm::Instruction *ci = - llvm::CallInst::Create(func, &args[0], &args[2], name ? 
name : "", - bblock); -#endif - AddDebugPos(ci); - return ci; + std::vector args; + args.push_back(arg0); + args.push_back(arg1); + return CallInst(func, returnType, args, name); } @@ -1993,8 +2241,8 @@ FunctionEmitContext::ReturnInst() { } -llvm::Instruction * -FunctionEmitContext::LaunchInst(llvm::Function *callee, +llvm::Value * +FunctionEmitContext::LaunchInst(llvm::Value *callee, std::vector &argVals, llvm::Value *launchCount) { if (callee == NULL) { @@ -2004,7 +2252,9 @@ FunctionEmitContext::LaunchInst(llvm::Function *callee, launchedTasks = true; - LLVM_TYPE_CONST llvm::Type *argType = callee->arg_begin()->getType(); + assert(llvm::isa(callee)); + LLVM_TYPE_CONST llvm::Type *argType = + (llvm::dyn_cast(callee))->arg_begin()->getType(); assert(llvm::PointerType::classof(argType)); LLVM_TYPE_CONST llvm::PointerType *pt = llvm::dyn_cast(argType); @@ -2020,7 +2270,7 @@ FunctionEmitContext::LaunchInst(llvm::Function *callee, allocArgs.push_back(launchGroupHandlePtr); allocArgs.push_back(SizeOf(argStructType)); allocArgs.push_back(LLVMInt32(align)); - llvm::Value *voidmem = CallInst(falloc, allocArgs, "args_ptr"); + llvm::Value *voidmem = CallInst(falloc, NULL, allocArgs, "args_ptr"); llvm::Value *argmem = BitCastInst(voidmem, pt); // Copy the values of the parameters into the appropriate place in @@ -2048,7 +2298,7 @@ FunctionEmitContext::LaunchInst(llvm::Function *callee, args.push_back(fptr); args.push_back(voidmem); args.push_back(launchCount); - return CallInst(flaunch, args, ""); + return CallInst(flaunch, AtomicType::Void, args, ""); } @@ -2067,7 +2317,7 @@ FunctionEmitContext::SyncInst() { llvm::Function *fsync = m->module->getFunction("ISPCSync"); if (fsync == NULL) FATAL("Couldn't find ISPCSync declaration?!"); - CallInst(fsync, launchGroupHandle, ""); + CallInst(fsync, AtomicType::Void, launchGroupHandle, ""); BranchInst(bPostSync); SetCurrentBasicBlock(bPostSync); @@ -2095,7 +2345,9 @@ FunctionEmitContext::addVaryingOffsetsIfNeeded(llvm::Value *ptr, 
const Type *typ // the data we're gathering from/scattering to is varying in memory. // If we have pointers to scalar types, e.g. [8 x float *], then the // data is uniform in memory and doesn't need any additional offsets. - if (llvm::isa(pt->getElementType()) == false) + if (pt->getElementType()->isIntegerTy() || + pt->getElementType()->isFloatingPointTy() || + pt->getElementType()->isPointerTy()) return ptr; llvm::Value *varyingOffsets = llvm::UndefValue::get(LLVMTypes::Int32VectorType); diff --git a/ctx.h b/ctx.h index e41bad90..27d030ee 100644 --- a/ctx.h +++ b/ctx.h @@ -367,8 +367,9 @@ public: instruction is added at the start of the function in the entry basic block; if it should be added to the current basic block, then the atEntryBlock parameter should be false. */ - llvm::Value *AllocaInst(LLVM_TYPE_CONST llvm::Type *llvmType, const char *name = NULL, - int align = 0, bool atEntryBlock = true); + llvm::Value *AllocaInst(LLVM_TYPE_CONST llvm::Type *llvmType, + const char *name = NULL, int align = 0, + bool atEntryBlock = true); /** Standard store instruction; for this variant, the lvalue must be a single pointer, not a varying lvalue. */ @@ -403,24 +404,28 @@ public: llvm::Instruction *SelectInst(llvm::Value *test, llvm::Value *val0, llvm::Value *val1, const char *name = NULL); - llvm::Instruction *CallInst(llvm::Function *func, - const std::vector &args, - const char *name = NULL); + /** Emits IR to do a function call with the given arguments. The + function return type must be provided in returnType. */ + llvm::Value *CallInst(llvm::Value *func, const Type *returnType, + const std::vector &args, + const char *name = NULL); + /** This is a convenience method that issues a call instruction to a function that takes just a single argument. 
*/ - llvm::Instruction *CallInst(llvm::Function *func, llvm::Value *arg, - const char *name = NULL); + llvm::Value *CallInst(llvm::Value *func, const Type *returnType, + llvm::Value *arg, const char *name = NULL); /** This is a convenience method that issues a call instruction to a function that takes two arguments. */ - llvm::Instruction *CallInst(llvm::Function *func, llvm::Value *arg0, - llvm::Value *arg1, const char *name = NULL); + llvm::Value *CallInst(llvm::Value *func, const Type *returnType, + llvm::Value *arg0, llvm::Value *arg1, + const char *name = NULL); /** Launch an asynchronous task to run the given function, passing it he given argument values. */ - llvm::Instruction *LaunchInst(llvm::Function *callee, - std::vector &argVals, - llvm::Value *launchCount); + llvm::Value *LaunchInst(llvm::Value *callee, + std::vector &argVals, + llvm::Value *launchCount); void SyncInst(); @@ -523,7 +528,7 @@ private: llvm::Value *launchGroupHandlePtr; llvm::Value *pointerVectorToVoidPointers(llvm::Value *value); - static void addGSMetadata(llvm::Instruction *inst, SourcePos pos); + static void addGSMetadata(llvm::Value *inst, SourcePos pos); bool ifsInLoopAllUniform() const; void jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target); llvm::Value *emitGatherCallback(llvm::Value *lvalue, llvm::Value *retPtr); diff --git a/decl.cpp b/decl.cpp index cfecf289..f6ee1e59 100644 --- a/decl.cpp +++ b/decl.cpp @@ -91,6 +91,7 @@ Declarator::Declarator(Symbol *s, SourcePos p) functionArgs = NULL; isFunction = false; initExpr = NULL; + pointerCount = 0; } @@ -104,6 +105,14 @@ Declarator::AddArrayDimension(int size) { void Declarator::InitFromDeclSpecs(DeclSpecs *ds) { sym->type = GetType(ds); + for (int i = 0; i < pointerCount; ++i) { + // Only function pointers for now... 
+ if (dynamic_cast(sym->type) == NULL) + Error(pos, "Only pointers to functions are currently allowed, " + "not pointers to \"%s\".", sym->type->GetString().c_str()); + else + sym->type = new PointerType(sym->type, true, false); + } sym->storageClass = ds->storageClass; } diff --git a/decl.h b/decl.h index a4d869ee..5ee05a21 100644 --- a/decl.h +++ b/decl.h @@ -156,6 +156,7 @@ public: /** Initialization expression for the variable. May be NULL. */ Expr *initExpr; bool isFunction; + int pointerCount; std::vector *functionArgs; }; diff --git a/docs/ispc.txt b/docs/ispc.txt index af34cbe2..2eac37af 100644 --- a/docs/ispc.txt +++ b/docs/ispc.txt @@ -63,6 +63,7 @@ Contents: + `Lexical Structure`_ + `Basic Types and Type Qualifiers`_ + + `Function Pointer Types`_ + `Enumeration Types`_ + `Short Vector Types`_ + `Struct and Array Types`_ @@ -82,6 +83,7 @@ Contents: + `Program Instance Convergence`_ + `Data Races`_ + `Uniform Variables and Varying Control Flow`_ + + `Function Pointers`_ + `Task Parallelism: Language Syntax`_ + `Task Parallelism: Runtime Requirements`_ @@ -620,7 +622,34 @@ results or modify existing variables. ++f; } -``ispc`` doesn't currently support pointer types. +``ispc`` doesn't currently support pointer types, except for functions, as +described below. + +Function Pointer Types +---------------------- + +``ispc`` does allow function pointers to be taken and used as in C and +C++. 
The syntax for declaring function pointer types is the same as in +those languages; it's generally easiest to use a ``typedef`` to help: + +:: + + int inc(int v) { return v+1; } + int dec(int v) { return v-1; } + + typedef int (*FPType)(int); + FPType fptr = inc; + +Given a function pointer, the function it points to can be called: + +:: + + int x = fptr(1); + +Note that ``ispc`` doesn't currently support the "address-of" operator +``&`` or the "dereference" operator ``*``, so it's not necessary to take the +address of a function to assign it to a function pointer or to dereference +it to call the function. Enumeration Types @@ -1439,6 +1468,26 @@ be modified in the above code even if *none* of the program instances evaluated a true value for the test, given the ``ispc`` execution model. +Function Pointers +----------------- + +As with other variables, a function pointer in ``ispc`` may be of +``uniform`` or ``varying`` type. If a function pointer is ``uniform``, it +has the same value for all of the executing program instances, and thus all +active program instances will call the same function if the function +pointer is used. + +If a function pointer is ``varying``, then it has a possibly-different +value for all running program instances. Given a call to a varying +function pointer, ``ispc`` maintains as much execution convergence as +possible; the code executed finds the set of unique function pointers over +the currently running program instances and calls each one just once, such +that the executing program instances when it is called are the set of +active program instances that had that function pointer value. The order +in which the various function pointers are called in this case is +undefined. 
+ + Task Parallelism: Language Syntax --------------------------------- diff --git a/expr.cpp b/expr.cpp index 613563f9..df8b262a 100644 --- a/expr.cpp +++ b/expr.cpp @@ -174,6 +174,43 @@ lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr, const AtomicType *toAtomicType = dynamic_cast(toType); const AtomicType *fromAtomicType = dynamic_cast(fromType); + const PointerType *fromPointerType = dynamic_cast(fromType); + const PointerType *toPointerType = dynamic_cast(toType); + if (fromPointerType != NULL) { + if (dynamic_cast(toType) != NULL && + toType->IsBoolType()) + // Allow implicit conversion of pointers to bools + goto typecast_ok; + + if (toPointerType == NULL) { + if (!failureOk) + Error(pos, "Can't convert between from pointer type " + "\"%s\" to non-pointer type \"%s\".", + fromType->GetString().c_str(), + toType->GetString().c_str()); + return false; + } + else if (Type::Equal(fromPointerType->GetAsUniformType()->GetAsConstType(), + PointerType::Void)) { + // void *s can be converted to any other pointer type + goto typecast_ok; + } + else if (!Type::Equal(fromPointerType->GetBaseType(), + toPointerType->GetBaseType())) { + if (!failureOk) + Error(pos, "Can't convert between incompatible pointer types " + "\"%s\" and \"%s\".", fromPointerType->GetString().c_str(), + toPointerType->GetString().c_str()); + return false; + } + + if (toType->IsVaryingType() && fromType->IsUniformType()) + goto typecast_ok; + + // Otherwise there's nothing to do + return true; + } + // Convert from type T -> const T; just return a TypeCast expr, which // can handle this if (Type::Equal(toType, fromType->GetAsConstType())) @@ -380,7 +417,8 @@ lMatchingBoolType(const Type *type) { if (vt != NULL) return new VectorType(boolBase, vt->GetElementCount()); else { - assert(dynamic_cast(type) != NULL); + assert(dynamic_cast(type) != NULL || + dynamic_cast(type) != NULL); return boolBase; } } @@ -1068,6 +1106,10 @@ BinaryExpr::GetType() const { if (type0 == NULL || type1 == 
NULL) return NULL; +#if 0 + // FIXME: I think these are redundant given the checks in + // BinaryExpr::TypeCheck(). They should either be removed or updated + // to handle the cases where pointer == and != tests are ok. if (!type0->IsBoolType() && !type0->IsNumericType()) { Error(arg0->pos, "First operand to binary operator \"%s\" is of invalid " "type \"%s\".", lOpString(op), type0->GetString().c_str()); @@ -1079,6 +1121,7 @@ BinaryExpr::GetType() const { "type \"%s\".", lOpString(op), type1->GetString().c_str()); return NULL; } +#endif const Type *promotedType = Type::MoreGeneralType(type0, type1, pos, lOpString(op)); @@ -1167,8 +1210,8 @@ lConstFoldBinLogicalOp(BinaryExpr::Op op, const T *v0, const T *v1, ConstExpr *c return NULL; } - const Type *rType = carg0->GetType()->IsUniformType() ? AtomicType::UniformBool : - AtomicType::VaryingBool; + const Type *rType = carg0->GetType()->IsUniformType() ? + AtomicType::UniformBool : AtomicType::VaryingBool; return new ConstExpr(rType, result, carg0->pos); } @@ -1478,19 +1521,23 @@ BinaryExpr::TypeCheck() { } case Equal: case NotEqual: { - if (!type0->IsBoolType() && !type0->IsNumericType()) { - Error(arg0->pos, - "First operand to equality operator \"%s\" is of " - "non-comparable type \"%s\".", lOpString(op), - type0->GetString().c_str()); - return NULL; - } - if (!type1->IsBoolType() && !type1->IsNumericType()) { - Error(arg1->pos, - "Second operand to equality operator \"%s\" is of " - "non-comparable type \"%s\".", lOpString(op), - type1->GetString().c_str()); - return NULL; + const PointerType *pt0 = dynamic_cast(type0); + const PointerType *pt1 = dynamic_cast(type1); + if (pt0 == NULL && pt1 == NULL) { + if (!type0->IsBoolType() && !type0->IsNumericType()) { + Error(arg0->pos, + "First operand to equality operator \"%s\" is of " + "non-comparable type \"%s\".", lOpString(op), + type0->GetString().c_str()); + return NULL; + } + if (!type1->IsBoolType() && !type1->IsNumericType()) { + Error(arg1->pos, + "Second 
operand to equality operator \"%s\" is of " + "non-comparable type \"%s\".", lOpString(op), + type1->GetString().c_str()); + return NULL; + } } const Type *promotedType = @@ -1750,9 +1797,16 @@ AssignExpr::GetType() const { Expr * AssignExpr::TypeCheck() { - bool lvalueIsReference = lvalue && + if (lvalue != NULL) + lvalue = lvalue->TypeCheck(); + if (rvalue != NULL) + rvalue = rvalue->TypeCheck(); + if (lvalue == NULL || rvalue == NULL) + return NULL; + + bool lvalueIsReference = dynamic_cast(lvalue->GetType()) != NULL; - bool rvalueIsReference = rvalue && + bool rvalueIsReference = dynamic_cast(rvalue->GetType()) != NULL; // hack to allow asigning array references e.g. in a struct... @@ -1761,12 +1815,25 @@ AssignExpr::TypeCheck() { dynamic_cast(rvalue->GetType()->GetReferenceTarget()))) lvalue = new DereferenceExpr(lvalue, lvalue->pos); - if (lvalue != NULL) - lvalue = lvalue->TypeCheck(); - if (rvalue != NULL) - rvalue = rvalue->TypeCheck(); - if (lvalue == NULL || rvalue == NULL) - return NULL; + FunctionSymbolExpr *fse; + if ((fse = dynamic_cast(rvalue)) != NULL) { + // Special case to use the type of the LHS to resolve function + // overloads when we're assigning a function pointer where the + // function is overloaded. 
+ const Type *lvalueType = lvalue->GetType(); + const FunctionType *ftype; + if (dynamic_cast(lvalueType) == NULL || + (ftype = dynamic_cast(lvalueType->GetBaseType())) == NULL) { + Error(pos, "Can't assign function pointer to type \"%s\".", + lvalue->GetType()->GetString().c_str()); + return NULL; + } + if (!fse->ResolveOverloads(ftype->GetArgumentTypes())) { + Error(pos, "Unable to find overloaded function for function " + "pointer assignment."); + return NULL; + } + } rvalue = TypeConvertExpr(rvalue, lvalue->GetType(), "assignment"); if (rvalue == NULL) @@ -1862,9 +1929,6 @@ SelectExpr::GetValue(FunctionEmitContext *ctx) const { testType->GetBaseType() == AtomicType::VaryingBool); const Type *type = expr1->GetType(); - // Type checking should also make sure this is the case - assert(Type::Equal(type->GetAsNonConstType(), - expr2->GetType()->GetAsNonConstType())); if (testType == AtomicType::UniformBool) { // Simple case of a single uniform bool test expression; we just @@ -2075,29 +2139,40 @@ FunctionCallExpr::FunctionCallExpr(Expr *f, ExprList *a, SourcePos p, } +static const FunctionType * +lGetFunctionType(Expr *func) { + if (func == NULL) + return NULL; + + const Type *type = func->GetType(); + if (type == NULL) + return NULL; + + const FunctionType *ftype = dynamic_cast(type); + if (ftype == NULL) { + // Not a regular function symbol--is it a function pointer? + if (dynamic_cast(type) != NULL) + ftype = dynamic_cast(type->GetBaseType()); + } + return ftype; +} + + llvm::Value * FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { - if (!func || !args) + if (func == NULL || args == NULL) return NULL; ctx->SetDebugPos(pos); - FunctionSymbolExpr *fse = dynamic_cast(func); - assert(fse != NULL); // should be caught during typechecking + llvm::Value *callee = func->GetValue(ctx); - Symbol *funSym = fse->GetMatchingFunction(); - if (funSym == NULL) - // No match was found; an error should have been issued earlier, so - // just return. 
- return NULL; - - llvm::Function *callee = funSym->function; if (callee == NULL) { - Error(pos, "Symbol \"%s\" is not a function.", funSym->name.c_str()); + assert(m->errorCount > 0); return NULL; } - const FunctionType *ft = dynamic_cast(funSym->type); + const FunctionType *ft = lGetFunctionType(func); assert(ft != NULL); bool isVoidFunc = (ft->GetReturnType() == AtomicType::Void); @@ -2127,8 +2202,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { // store the expr's value to alloca'ed memory and then // pass a reference to that... Error(pos, "Can't pass non-lvalue as \"reference\" parameter \"%s\" " - "to function \"%s\".", ft->GetArgumentName(i).c_str(), - funSym->name.c_str()); + "to function.", ft->GetArgumentName(i).c_str()); err = true; } else @@ -2215,18 +2289,9 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { if (launchCount != NULL) ctx->LaunchInst(callee, argVals, launchCount); } - else { - // Most of the time, the mask is passed as the last argument. this - // isn't the case for things like intrinsics, builtins, and extern - // "C" functions from the application. - assert(callargs.size() + 1 == callee->arg_size() || - callargs.size() == callee->arg_size()); - - if (callargs.size() + 1 == callee->arg_size()) - argVals.push_back(ctx->GetFullMask()); - - retVal = ctx->CallInst(callee, argVals, isVoidFunc ? "" : "calltmp"); - } + else + retVal = ctx->CallInst(callee, ft->GetReturnType(), argVals, + isVoidFunc ? 
"" : "calltmp"); // For anything we had to do as pass by value/result, copy the // corresponding reference values back out @@ -2253,17 +2318,8 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { const Type * FunctionCallExpr::GetType() const { - FunctionSymbolExpr *fse = dynamic_cast(func); - if (fse == NULL) - return NULL; - - Symbol *sym = fse->GetMatchingFunction(); - if (sym == NULL) - return NULL; - - const FunctionType *ft = dynamic_cast(sym->type); - assert(ft != NULL); - return ft->GetReturnType(); + const FunctionType *ftype = lGetFunctionType(func); + return ftype ? ftype->GetReturnType() : NULL; } @@ -2282,56 +2338,78 @@ FunctionCallExpr::Optimize() { Expr * FunctionCallExpr::TypeCheck() { + if (func != NULL) + func = func->TypeCheck(); if (args != NULL) args = args->TypeCheck(); if (args != NULL && func != NULL) { FunctionSymbolExpr *fse = dynamic_cast(func); + if (fse != NULL) { + std::vector argTypes; + for (unsigned int i = 0; i < args->exprs.size(); ++i) { + if (args->exprs[i] == NULL) + return NULL; + const Type *t = args->exprs[i]->GetType(); + if (t == NULL) + return NULL; + argTypes.push_back(t); + } - if (fse == NULL) { - Error(pos, "No valid function available for function call."); - return NULL; - } + if (fse->ResolveOverloads(argTypes) == true) { + func = fse->TypeCheck(); - std::vector argTypes; - for (unsigned int i = 0; i < args->exprs.size(); ++i) { - if (args->exprs[i] == NULL) - return NULL; - const Type *t = args->exprs[i]->GetType(); - if (t == NULL) - return NULL; - argTypes.push_back(t); - } + if (func != NULL) { + const PointerType *pt = + dynamic_cast(func->GetType()); + const FunctionType *ft = (pt == NULL) ? 
NULL : + dynamic_cast(pt->GetBaseType()); + if (ft != NULL) { + if (ft->isTask) { + if (!isLaunch) + Error(pos, "\"launch\" expression needed to call function " + "with \"task\" qualifier."); + if (!launchCountExpr) + return NULL; - if (fse->ResolveOverloads(argTypes) == true) { - func = fse->TypeCheck(); - - if (func != NULL) { - const FunctionType *ft = - dynamic_cast(func->GetType()); - if (ft != NULL) { - if (ft->isTask) { - if (!isLaunch) - Error(pos, "\"launch\" expression needed to call function " - "with \"task\" qualifier."); - if (!launchCountExpr) - return NULL; - - launchCountExpr = - TypeConvertExpr(launchCountExpr, AtomicType::UniformInt32, - "task launch count"); - if (launchCountExpr == NULL) - return NULL; - } - else { - if (isLaunch) - Error(pos, "\"launch\" expression illegal with non-\"task\"-" - "qualified function."); - assert(launchCountExpr == NULL); + launchCountExpr = + TypeConvertExpr(launchCountExpr, AtomicType::UniformInt32, + "task launch count"); + if (launchCountExpr == NULL) + return NULL; + } + else { + if (isLaunch) + Error(pos, "\"launch\" expression illegal with non-\"task\"-" + "qualified function."); + assert(launchCountExpr == NULL); + } } + else + Error(pos, "Valid function name must be used for function call."); + } + } + } + else { + const Type *funcType = func->GetType(); + if (funcType == NULL) + return NULL; + + if (dynamic_cast(funcType) == NULL || + dynamic_cast(funcType->GetBaseType()) == NULL) { + Error(pos, "Must provide function name or function pointer for " + "function call expression."); + return NULL; + } + + if (funcType->IsVaryingType()) { + const FunctionType *ft = + dynamic_cast(funcType->GetBaseType()); + if (ft->GetReturnType()->IsUniformType()) { + Error(pos, "Illegal to call a varying function pointer that " + "points to a function with a uniform return type."); + return NULL; } - else - Error(pos, "Valid function name must be used for function call."); } } } @@ -2344,8 +2422,25 @@ 
FunctionCallExpr::TypeCheck() { int FunctionCallExpr::EstimateCost() const { - return ((args ? args->EstimateCost() : 0) + - (isLaunch ? COST_TASK_LAUNCH : COST_FUNCALL)); + int callCost = 0; + if (isLaunch) + callCost = COST_TASK_LAUNCH; + else if (dynamic_cast(func) == NULL) { + // it's going through a function pointer + const Type *fpType = func->GetType(); + if (fpType != NULL) { + assert(dynamic_cast(fpType) != NULL); + if (fpType->IsUniformType()) + callCost = COST_FUNPTR_UNIFORM; + else + callCost = COST_FUNPTR_VARYING; + } + } + else + // regular function call + callCost = COST_FUNCALL; + + return (args ? args->EstimateCost() : 0) + callCost; } @@ -2544,7 +2639,8 @@ lAddVaryingOffsetsIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr, // If the result of the indexing isn't a varying atomic type, then // nothing to do here. if (returnType->IsVaryingType() == false || - dynamic_cast(returnType) == NULL) + (dynamic_cast(returnType) == NULL && + dynamic_cast(returnType) == NULL)) return ptr; // We should now have an array of pointer values, represing in a @@ -2562,7 +2658,9 @@ lAddVaryingOffsetsIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr, // above, and no additional offset is needed. Otherwise we have // pointers to varying atomic types--e.g. ptr->getType() == // [8 x <8 x float> *] - if (llvm::isa(pt->getElementType()) == false) + if (pt->getElementType()->isIntegerTy() || + pt->getElementType()->isFloatingPointTy() || + pt->getElementType()->isPointerTy()) return ptr; // But not so fast: if the reason we have a vector of pointers is that @@ -2785,7 +2883,7 @@ IndexExpr::Print() const { /** Map one character ids to vector element numbers. Allow a few different conventions--xyzw, rgba, uv. 
- */ +*/ static int lIdentifierToVectorElement(char id) { switch (id) { @@ -4643,6 +4741,22 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const { // an error should have been issued elsewhere in this case return NULL; + const PointerType *fromPointerType = dynamic_cast(fromType); + const PointerType *toPointerType = dynamic_cast(toType); + if (fromPointerType != NULL && toPointerType != NULL) { + llvm::Value *value = expr->GetValue(ctx); + if (value == NULL) + return NULL; + + // bitcast from NULL to actual pointer type... + value = ctx->BitCastInst(value, toType->GetAsUniformType()->LLVMType(g->ctx)); + + if (fromType->IsUniformType() && toType->IsVaryingType()) + return ctx->SmearScalar(value); + else + return value; + } + if (Type::Equal(toType->GetAsConstType(), fromType->GetAsConstType())) // There's nothing to do, just return the value. (LLVM's type // system doesn't worry about constiness.) @@ -4736,6 +4850,32 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const { toType = toEnum->IsUniformType() ? AtomicType::UniformUInt32 : AtomicType::VaryingUInt32; + if (fromPointerType != NULL) { + // convert pointer to bool + assert(dynamic_cast(toType) && + toType->IsBoolType()); + LLVM_TYPE_CONST llvm::Type *lfu = + fromType->GetAsUniformType()->LLVMType(g->ctx); + LLVM_TYPE_CONST llvm::PointerType *llvmFromUnifType = + llvm::dyn_cast(lfu); + + llvm::Value *nullPtrValue = llvm::ConstantPointerNull::get(llvmFromUnifType); + if (fromType->IsVaryingType()) + nullPtrValue = ctx->SmearScalar(nullPtrValue); + + llvm::Value *cmp = ctx->CmpInst(llvm::Instruction::ICmp, + llvm::CmpInst::ICMP_NE, + exprVal, nullPtrValue, "ptr_ne_NULL"); + + if (toType->IsVaryingType()) { + if (fromType->IsUniformType()) + cmp = ctx->SmearScalar(cmp); + cmp = ctx->I1VecToBoolVec(cmp); + } + + return cmp; + } + const AtomicType *fromAtomic = dynamic_cast(fromType); // at this point, coming from an atomic type is all that's left... 
assert(fromAtomic != NULL); @@ -4954,6 +5094,35 @@ TypeCastExpr::Print() const { } +llvm::Constant * +TypeCastExpr::GetConstant(const Type *constType) const { + // We don't need to worry about most the basic cases where the type + // cast can resolve to a constant here, since the + // TypeCastExpr::Optimize() method ends up doing the type conversion + // and returning a ConstExpr, which in turn will have its GetConstant() + // method called. Thus, the only case we do need to worry about here + // is converting a uniform function pointer to a varying function + // pointer of the same type. + assert(Type::Equal(constType, type)); + const FunctionType *ft = NULL; + if (dynamic_cast(type) == NULL || + (ft = dynamic_cast(type->GetBaseType())) == NULL) + return NULL; + + llvm::Constant *ec = expr->GetConstant(expr->GetType()); + if (ec == NULL) + return NULL; + + std::vector smear; + for (int i = 0; i < g->target.vectorWidth; ++i) + smear.push_back(ec); + LLVM_TYPE_CONST llvm::ArrayType *llvmVaryingType = + llvm::dyn_cast(type->LLVMType(g->ctx)); + assert(llvmVaryingType != NULL); + return llvm::ConstantArray::get(llvmVaryingType, smear); +} + + /////////////////////////////////////////////////////////////////////////// // ReferenceExpr @@ -5198,21 +5367,28 @@ FunctionSymbolExpr::FunctionSymbolExpr(const char *n, SourcePos p) : Expr(p) { name = n; - matchingFunc = NULL; candidateFunctions = candidates; + matchingFunc = (candidates && candidates->size() == 1) ? + (*candidates)[0] : NULL; + triedToResolve = false; } const Type * FunctionSymbolExpr::GetType() const { - return matchingFunc ? matchingFunc->type : NULL; + if (triedToResolve == false && matchingFunc == NULL) { + Error(pos, "Ambiguous use of overloaded function \"%s\".", + name.c_str()); + return NULL; + } + + return matchingFunc ? 
new PointerType(matchingFunc->type, true, true) : NULL; } llvm::Value * FunctionSymbolExpr::GetValue(FunctionEmitContext *ctx) const { - assert("!should not call FunctionSymbolExpr::GetValue()"); - return NULL; + return matchingFunc ? matchingFunc->function : NULL; } @@ -5251,6 +5427,18 @@ FunctionSymbolExpr::Print() const { } +llvm::Constant * +FunctionSymbolExpr::GetConstant(const Type *type) const { + assert(type->IsUniformType()); + assert(GetType()->IsUniformType()); + + if (Type::Equal(type->GetAsConstType(), + GetType()->GetAsConstType()) == false) + return NULL; + + return matchingFunc ? matchingFunc->function : NULL; +} + static std::string lGetFunctionDeclaration(const std::string &name, const FunctionType *type) { @@ -5617,6 +5805,8 @@ FunctionSymbolExpr::tryResolve(int (*matchFunc)(const Type *, const Type *), bool FunctionSymbolExpr::ResolveOverloads(const std::vector &argTypes) { + triedToResolve = true; + // Functions with names that start with "__" should only be various // builtins. 
For those, we'll demand an exact match, since we'll // expect whichever function in stdlib.ispc is calling out to one of @@ -5712,3 +5902,44 @@ Expr * SyncExpr::Optimize() { return this; } + + +/////////////////////////////////////////////////////////////////////////// +// NullPointerExpr + +llvm::Value * +NullPointerExpr::GetValue(FunctionEmitContext *ctx) const { + return llvm::ConstantPointerNull::get(LLVMTypes::VoidPointerType); +} + + +const Type * +NullPointerExpr::GetType() const { + return PointerType::Void; +} + + +Expr * +NullPointerExpr::TypeCheck() { + return this; +} + + +Expr * +NullPointerExpr::Optimize() { + return this; +} + + +void +NullPointerExpr::Print() const { + printf("NULL"); + pos.Print(); +} + + +int +NullPointerExpr::EstimateCost() const { + return 0; +} + diff --git a/expr.h b/expr.h index 54f2f7f3..17ec622b 100644 --- a/expr.h +++ b/expr.h @@ -490,6 +490,7 @@ public: Expr *TypeCheck(); Expr *Optimize(); int EstimateCost() const; + llvm::Constant *GetConstant(const Type *type) const; const Type *type; Expr *expr; @@ -568,6 +569,7 @@ public: Expr *Optimize(); void Print() const; int EstimateCost() const; + llvm::Constant *GetConstant(const Type *type) const; bool ResolveOverloads(const std::vector &argTypes); Symbol *GetMatchingFunction(); @@ -586,6 +588,8 @@ private: /** The actual matching function found after overload resolution. */ Symbol *matchingFunc; + + bool triedToResolve; }; @@ -604,6 +608,20 @@ public: }; +/** @brief An expression that represents a NULL pointer. */ +class NullPointerExpr : public Expr { +public: + NullPointerExpr(SourcePos p) : Expr(p) { } + + llvm::Value *GetValue(FunctionEmitContext *ctx) const; + const Type *GetType() const; + Expr *TypeCheck(); + Expr *Optimize(); + void Print() const; + int EstimateCost() const; +}; + + /** This function indicates whether it's legal to convert from fromType to toType. 
*/ diff --git a/ispc.h b/ispc.h index 54d81932..e31a5656 100644 --- a/ispc.h +++ b/ispc.h @@ -365,6 +365,8 @@ enum { COST_COMPLEX_ARITH_OP = 4, COST_DEREF = 4, COST_FUNCALL = 4, + COST_FUNPTR_UNIFORM = 12, + COST_FUNPTR_VARYING = 24, COST_GATHER = 8, COST_LOAD = 2, COST_REGULAR_BREAK_CONTINUE = 2, @@ -372,7 +374,7 @@ enum { COST_SELECT = 4, COST_SIMPLE_ARITH_LOGIC_OP = 1, COST_SYNC = 32, - COST_TASK_LAUNCH = 16, + COST_TASK_LAUNCH = 32, COST_TYPECAST_COMPLEX = 4, COST_TYPECAST_SIMPLE = 1, COST_UNIFORM_IF = 2, diff --git a/lex.ll b/lex.ll index f2487a34..b31315c4 100644 --- a/lex.ll +++ b/lex.ll @@ -110,6 +110,7 @@ int16 { return TOKEN_INT16; } int32 { return TOKEN_INT; } int64 { return TOKEN_INT64; } launch { return TOKEN_LAUNCH; } +NULL { return TOKEN_NULL; } print { return TOKEN_PRINT; } reference { return TOKEN_REFERENCE; } return { return TOKEN_RETURN; } diff --git a/parse.yy b/parse.yy index a944784d..d3363518 100644 --- a/parse.yy +++ b/parse.yy @@ -103,8 +103,8 @@ static const char *lBuiltinTokens[] = { "bool", "break", "case", "cbreak", "ccontinue", "cdo", "cfor", "cif", "cwhile", "const", "continue", "creturn", "default", "do", "double", "else", "enum", "export", "extern", "false", "float", "for", "goto", "if", - "inline", "int", "int8", "int16", "int32", "int64", "launch", "print", - "reference", "return", + "inline", "int", "int8", "int16", "int32", "int64", "launch", "NULL", + "print", "reference", "return", "static", "struct", "switch", "sync", "task", "true", "typedef", "uniform", "unsigned", "varying", "void", "while", NULL }; @@ -146,7 +146,7 @@ static const char *lParamListTokens[] = { %token TOKEN_INT32_CONSTANT TOKEN_UINT32_CONSTANT TOKEN_INT64_CONSTANT %token TOKEN_UINT64_CONSTANT TOKEN_FLOAT_CONSTANT TOKEN_STRING_C_LITERAL -%token TOKEN_IDENTIFIER TOKEN_STRING_LITERAL TOKEN_TYPE_NAME +%token TOKEN_IDENTIFIER TOKEN_STRING_LITERAL TOKEN_TYPE_NAME TOKEN_NULL %token TOKEN_PTR_OP TOKEN_INC_OP TOKEN_DEC_OP TOKEN_LEFT_OP TOKEN_RIGHT_OP %token 
TOKEN_LE_OP TOKEN_GE_OP TOKEN_EQ_OP TOKEN_NE_OP %token TOKEN_AND_OP TOKEN_OR_OP TOKEN_MUL_ASSIGN TOKEN_DIV_ASSIGN TOKEN_MOD_ASSIGN @@ -253,6 +253,9 @@ primary_expression | TOKEN_FALSE { $$ = new ConstExpr(AtomicType::UniformConstBool, false, @1); } + | TOKEN_NULL { + $$ = new NullPointerExpr(@1); + } /* | TOKEN_STRING_LITERAL { UNIMPLEMENTED }*/ | '(' expression ')' { $$ = $2; } @@ -458,8 +461,15 @@ constant_expression declaration_statement : declaration { - std::vector vars = $1->GetVariableDeclarations(); - $$ = new DeclStmt(vars, @1); + if ($1->declSpecs->storageClass == SC_TYPEDEF) { + for (unsigned int i = 0; i < $1->declarators.size(); ++i) { + m->AddTypeDef($1->declarators[i]->sym); + } + } + else { + std::vector vars = $1->GetVariableDeclarations(); + $$ = new DeclStmt(vars, @1); + } } ; @@ -857,6 +867,11 @@ type_qualifier declarator : direct_declarator + | '*' direct_declarator + { + $2->pointerCount++; + $$ = $2; + } ; int_constant @@ -1265,7 +1280,9 @@ lAddDeclaration(DeclSpecs *ds, Declarator *decl) { // Error happened earlier during parsing return; - if (decl->isFunction) { + if (ds->storageClass == SC_TYPEDEF) + m->AddTypeDef(decl->sym); + else if (decl->isFunction) { // function declaration const Type *t = decl->GetType(ds); const FunctionType *ft = dynamic_cast(t); @@ -1293,8 +1310,6 @@ lAddDeclaration(DeclSpecs *ds, Declarator *decl) { bool isInline = (ds->typeQualifier & TYPEQUAL_INLINE); m->AddFunctionDeclaration(funSym, args, isInline); } - else if (ds->storageClass == SC_TYPEDEF) - m->AddTypeDef(decl->sym); else m->AddGlobalVariable(decl->sym, decl->initExpr, (ds->typeQualifier & TYPEQUAL_CONST) != 0); diff --git a/stmt.cpp b/stmt.cpp index cb547a55..c75fd45d 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -120,24 +120,42 @@ DeclStmt::DeclStmt(const std::vector &v, SourcePos p) } +static bool +lPossiblyResolveFunctionOverloads(Expr *expr, const Type *type) { + FunctionSymbolExpr *fse = NULL; + const FunctionType *funcType = NULL; + if 
(dynamic_cast(type) != NULL && + (funcType = dynamic_cast(type->GetBaseType())) && + (fse = dynamic_cast(expr)) != NULL) { + // We're initializing a function pointer with a function symbol, + // which in turn may represent an overloaded function. So we need + // to try to resolve the overload based on the type of the symbol + // we're initializing here. + if (fse->ResolveOverloads(funcType->GetArgumentTypes()) == false) + return false; + } + return true; +} + + /** Utility routine that emits code to initialize a symbol given an initializer expression. @param lvalue Memory location of storage for the symbol's data @param symName Name of symbol (used in error messages) - @param type Type of variable being initialized + @param symType Type of variable being initialized @param initExpr Expression for the initializer @param ctx FunctionEmitContext to use for generating instructions @param pos Source file position of the variable being initialized */ static void -lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type, +lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *symType, Expr *initExpr, FunctionEmitContext *ctx, SourcePos pos) { if (initExpr == NULL) { // Initialize things without initializers to the undefined value. // To auto-initialize everything to zero, replace 'UndefValue' with // 'NullValue' in the below - LLVM_TYPE_CONST llvm::Type *ltype = type->LLVMType(g->ctx); + LLVM_TYPE_CONST llvm::Type *ltype = symType->LLVMType(g->ctx); ctx->StoreInst(llvm::UndefValue::get(ltype), lvalue); return; } @@ -146,7 +164,10 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type, // ExprList, then we'll see if we can type convert it to the type of // the variable. 
if (dynamic_cast(initExpr) == NULL) { - initExpr = TypeConvertExpr(initExpr, type, "initializer"); + if (lPossiblyResolveFunctionOverloads(initExpr, symType) == false) + return; + initExpr = TypeConvertExpr(initExpr, symType, "initializer"); + if (initExpr != NULL) { llvm::Value *initializerValue = initExpr->GetValue(ctx); if (initializerValue != NULL) @@ -159,16 +180,17 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type, // Atomic types and enums can't be initialized with { ... } initializer // expressions, so print an error and return if that's what we've got // here.. - if (dynamic_cast(type) != NULL || - dynamic_cast(type) != NULL) { + if (dynamic_cast(symType) != NULL || + dynamic_cast(symType) != NULL || + dynamic_cast(symType) != NULL) { if (dynamic_cast(initExpr) != NULL) Error(initExpr->pos, "Expression list initializers can't be used for " "variable \"%s\' with type \"%s\".", symName, - type->GetString().c_str()); + symType->GetString().c_str()); return; } - const ReferenceType *rt = dynamic_cast(type); + const ReferenceType *rt = dynamic_cast(symType); if (rt) { if (!Type::Equal(initExpr->GetType(), rt)) { Error(initExpr->pos, "Initializer for reference type \"%s\" must have same " @@ -190,14 +212,14 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type, // in which case the elements are initialized with the corresponding // values. const CollectionType *collectionType = - dynamic_cast(type); + dynamic_cast(symType); if (collectionType != NULL) { std::string name; - if (dynamic_cast(type) != NULL) + if (dynamic_cast(symType) != NULL) name = "struct"; - else if (dynamic_cast(type) != NULL) + else if (dynamic_cast(symType) != NULL) name = "array"; - else if (dynamic_cast(type) != NULL) + else if (dynamic_cast(symType) != NULL) name = "vector"; else FATAL("Unexpected CollectionType in lInitSymbol()"); @@ -291,10 +313,21 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const { // zero value. 
llvm::Constant *cinit = NULL; if (initExpr != NULL) { + if (lPossiblyResolveFunctionOverloads(initExpr, type) == false) + continue; + // FIXME: we only need this for function pointers; it was + // already done for atomic types and enums in + // DeclStmt::TypeCheck()... + initExpr = TypeConvertExpr(initExpr, type, "initializer"); + // FIXME: and this is only needed to re-establish + // constant-ness so that GetConstant below works for + // constant artithmetic expressions... + initExpr = initExpr->Optimize(); + cinit = initExpr->GetConstant(type); if (cinit == NULL) - Error(sym->pos, "Initializer for static variable \"%s\" must be a constant.", - sym->name.c_str()); + Error(initExpr->pos, "Initializer for static variable " + "\"%s\" must be a constant.", sym->name.c_str()); } if (cinit == NULL) cinit = llvm::Constant::getNullValue(llvmType); @@ -370,8 +403,10 @@ DeclStmt::TypeCheck() { if (vars[i].init == NULL) continue; vars[i].init = vars[i].init->TypeCheck(); - if (vars[i].init == NULL) + if (vars[i].init == NULL) { + encounteredError = true; continue; + } // get the right type for stuff like const float foo = 2; so that // the int->float type conversion is in there and we don't return @@ -506,39 +541,35 @@ IfStmt::EmitCode(FunctionEmitContext *ctx) const { Stmt * IfStmt::Optimize() { - if (test) + if (test != NULL) test = test->Optimize(); - if (trueStmts) + if (trueStmts != NULL) trueStmts = trueStmts->Optimize(); - if (falseStmts) + if (falseStmts != NULL) falseStmts = falseStmts->Optimize(); return this; } Stmt *IfStmt::TypeCheck() { - if (test) { + if (test != NULL) { test = test->TypeCheck(); - if (test) { + if (test != NULL) { const Type *testType = test->GetType(); - if (testType) { + if (testType != NULL) { bool isUniform = (testType->IsUniformType() && !g->opt.disableUniformControlFlow); - if (!testType->IsNumericType() && !testType->IsBoolType()) { - Error(test->pos, "Type \"%s\" can't be converted to boolean " - "for \"if\" test.", 
testType->GetString().c_str()); + test = TypeConvertExpr(test, isUniform ? AtomicType::UniformBool : + AtomicType::VaryingBool, + "\"if\" statement test"); + if (test == NULL) return NULL; - } - test = new TypeCastExpr(isUniform ? AtomicType::UniformBool : - AtomicType::VaryingBool, - test, false, test->pos); - assert(test); } } } - if (trueStmts) + if (trueStmts != NULL) trueStmts = trueStmts->TypeCheck(); - if (falseStmts) + if (falseStmts != NULL) falseStmts = falseStmts->TypeCheck(); return this; @@ -698,7 +729,8 @@ lSafeToRunWithAllLanesOff(Expr *expr) { if (dynamic_cast(expr) != NULL || dynamic_cast(expr) != NULL || - dynamic_cast(expr) != NULL) + dynamic_cast(expr) != NULL || + dynamic_cast(expr) != NULL) return true; FATAL("Unknown Expr type in lSafeToRunWithAllLanesOff()"); @@ -1659,6 +1691,12 @@ lEncodeType(const Type *t) { if (t == AtomicType::VaryingUInt64) return 'V'; if (t == AtomicType::UniformDouble) return 'd'; if (t == AtomicType::VaryingDouble) return 'D'; + if (dynamic_cast(t) != NULL) { + if (t->IsUniformType()) + return 'p'; + else + return 'P'; + } else return '\0'; } @@ -1788,13 +1826,14 @@ PrintStmt::EmitCode(FunctionEmitContext *ctx) const { llvm::Function *printFunc = m->module->getFunction("__do_print"); assert(printFunc); + llvm::Value *mask = ctx->GetFullMask(); // Set up the rest of the parameters to it args[0] = ctx->GetStringPtr(format); args[1] = ctx->GetStringPtr(argTypes); args[2] = LLVMInt32(g->target.vectorWidth); - args[3] = ctx->LaneMask(ctx->GetFullMask()); + args[3] = ctx->LaneMask(mask); std::vector argVec(&args[0], &args[5]); - ctx->CallInst(printFunc, argVec, ""); + ctx->CallInst(printFunc, AtomicType::Void, argVec, ""); } @@ -1874,7 +1913,7 @@ AssertStmt::EmitCode(FunctionEmitContext *ctx) const { args.push_back(ctx->GetStringPtr(errorString)); args.push_back(expr->GetValue(ctx)); args.push_back(ctx->GetFullMask()); - ctx->CallInst(assertFunc, args, ""); + ctx->CallInst(assertFunc, AtomicType::Void, args, ""); #ifndef 
ISPC_IS_WINDOWS free(errorString); diff --git a/type.cpp b/type.cpp index 90c713c2..5454db98 100644 --- a/type.cpp +++ b/type.cpp @@ -698,6 +698,169 @@ EnumType::GetEnumerator(int i) const { } +/////////////////////////////////////////////////////////////////////////// +// PointerType + +PointerType *PointerType::Void = new PointerType(AtomicType::Void, true, true); + +PointerType::PointerType(const Type *t, bool iu, bool ic) + : isUniform(iu), isConst(ic) { + baseType = t; +} + + +bool +PointerType::IsUniformType() const { + return isUniform; +} + + +bool +PointerType::IsBoolType() const { + return false; +} + + +bool +PointerType::IsFloatType() const { + return false; +} + + +bool +PointerType::IsIntType() const { + return false; +} + + +bool +PointerType::IsUnsignedType() const { + return false; +} + + +bool +PointerType::IsConstType() const { + return isConst; +} + + +const Type * +PointerType::GetBaseType() const { + return baseType; +} + + +const PointerType * +PointerType::GetAsVaryingType() const { + if (isUniform == false) + return this; + else + return new PointerType(baseType, false, isConst); +} + + +const PointerType * +PointerType::GetAsUniformType() const { + if (isUniform == true) + return this; + else + return new PointerType(baseType, true, isConst); +} + + +const Type * +PointerType::GetSOAType(int width) const { + FATAL("Unimplemented."); + return NULL; +} + + +const PointerType * +PointerType::GetAsConstType() const { + if (isConst == true) + return this; + else + return new PointerType(baseType, isUniform, true); +} + + +const PointerType * +PointerType::GetAsNonConstType() const { + if (isConst == false) + return this; + else + return new PointerType(baseType, isUniform, false); +} + + +std::string +PointerType::GetString() const { + if (baseType == NULL) + return ""; + + std::string ret; + if (isConst) ret += "const "; + if (isUniform) ret += "uniform "; + return ret + std::string("*") + baseType->GetString(); +} + + +std::string 
+PointerType::Mangle() const { + if (baseType == NULL) + return ""; + + return std::string("ptr<") + baseType->Mangle() + std::string(">"); +} + + +std::string +PointerType::GetCDeclaration(const std::string &name) const { + if (baseType == NULL) + return ""; + + std::string ret; + if (isConst) ret += "const "; + return ret + std::string("*") + baseType->GetCDeclaration(name); +} + + +LLVM_TYPE_CONST llvm::Type * +PointerType::LLVMType(llvm::LLVMContext *ctx) const { + if (baseType == NULL) + return NULL; + + LLVM_TYPE_CONST llvm::Type *ptype = NULL; + const FunctionType *ftype = dynamic_cast(baseType); + if (ftype != NULL) + // Get the type of the function variant that takes the mask as the + // last parameter--i.e. we don't allow taking function pointers of + // exported functions. + ptype = llvm::PointerType::get(ftype->LLVMFunctionType(ctx, true), 0); + else + ptype = llvm::PointerType::get(baseType->LLVMType(ctx), 0); + + if (isUniform) + return ptype; + else + // Varying pointers are represented as arrays of pointers since + // LLVM doesn't allow vectors of pointers. + return llvm::ArrayType::get(ptype, g->target.vectorWidth); +} + + +llvm::DIType +PointerType::GetDIType(llvm::DIDescriptor scope) const { + if (baseType == NULL) + return llvm::DIType(); + + llvm::DIType diTargetType = baseType->GetDIType(scope); + int bitsSize = g->target.is32bit ? 
32 : 64; + return m->diBuilder->createPointerType(diTargetType, bitsSize); +} + + /////////////////////////////////////////////////////////////////////////// // SequentialType @@ -1697,37 +1860,37 @@ FunctionType::FunctionType(const Type *r, const std::vector &a, bool FunctionType::IsUniformType() const { - return returnType->IsUniformType(); + return true; } bool FunctionType::IsFloatType() const { - return returnType->IsFloatType(); + return false; } bool FunctionType::IsIntType() const { - return returnType->IsIntType(); + return false; } bool FunctionType::IsBoolType() const { - return returnType->IsBoolType(); + return false; } bool FunctionType::IsUnsignedType() const { - return returnType->IsUnsignedType(); + return false; } bool FunctionType::IsConstType() const { - return returnType->IsConstType(); + return false; } @@ -1965,12 +2128,39 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char // Are they both the same type? If so, we're done, QED. if (Type::Equal(t0, t1)) return t0; + + // If they're function types, it's hopeless if they didn't match in the + // Type::Equal() call above. Fail here so that we don't get into + // trouble calling GetAsConstType()... + if (dynamic_cast(t0) || + dynamic_cast(t1)) { + Error(pos, "Incompatible function types \"%s\" and \"%s\" in %s.", + t0->GetString().c_str(), t1->GetString().c_str(), reason); + return NULL; + } // Not the same types, but only a const/non-const difference? Return // the non-const type as the more general one. 
if (Type::Equal(t0->GetAsConstType(), t1->GetAsConstType())) return t0->GetAsNonConstType(); + const PointerType *pt0 = dynamic_cast(t0); + const PointerType *pt1 = dynamic_cast(t1); + if (pt0 != NULL && pt1 != NULL) { + if (Type::Equal(pt0->GetAsUniformType()->GetAsConstType(), + PointerType::Void)) + return pt1; + else if (Type::Equal(pt1->GetAsUniformType()->GetAsConstType(), + PointerType::Void)) + return pt0; + else { + Error(pos, "Conversion between incompatible pointer types \"%s\" " + "and \"%s\" isn't possible.", t0->GetString().c_str(), + t1->GetString().c_str()); + return NULL; + } + } + const VectorType *vt0 = dynamic_cast(t0); const VectorType *vt1 = dynamic_cast(t1); if (vt0 && vt1) { @@ -2135,6 +2325,11 @@ Type::Equal(const Type *a, const Type *b) { if (!Equal(fta->GetReturnType(), ftb->GetReturnType())) return false; + if (fta->isTask != ftb->isTask || + fta->isExported != ftb->isExported || + fta->isExternC != ftb->isExternC) + return false; + const std::vector &aargs = fta->GetArgumentTypes(); const std::vector &bargs = ftb->GetArgumentTypes(); if (aargs.size() != bargs.size()) @@ -2145,5 +2340,12 @@ Type::Equal(const Type *a, const Type *b) { return true; } + const PointerType *pta = dynamic_cast(a); + const PointerType *ptb = dynamic_cast(b); + if (pta != NULL && ptb != NULL) + return (pta->IsConstType() == ptb->IsConstType() && + pta->IsUniformType() == ptb->IsUniformType() && + Type::Equal(pta->GetBaseType(), ptb->GetBaseType())); + return false; } diff --git a/type.h b/type.h index 56eacc95..2c6f6ba4 100644 --- a/type.h +++ b/type.h @@ -300,6 +300,40 @@ private: std::vector enumerators; }; +/** @brief Type implementation for pointers to other types + */ +class PointerType : public Type { +public: + PointerType(const Type *t, bool isUniform, bool isConst); + + bool IsUniformType() const; + bool IsBoolType() const; + bool IsFloatType() const; + bool IsIntType() const; + bool IsUnsignedType() const; + bool IsConstType() const; + + const Type 
*GetBaseType() const; + const PointerType *GetAsVaryingType() const; + const PointerType *GetAsUniformType() const; + const Type *GetSOAType(int width) const; + const PointerType *GetAsConstType() const; + const PointerType *GetAsNonConstType() const; + + std::string GetString() const; + std::string Mangle() const; + std::string GetCDeclaration(const std::string &name) const; + + LLVM_TYPE_CONST llvm::Type *LLVMType(llvm::LLVMContext *ctx) const; + llvm::DIType GetDIType(llvm::DIDescriptor scope) const; + + static PointerType *Void; + +private: + const bool isUniform, isConst; + const Type *baseType; +}; + /** @brief Abstract base class for types that represent collections of other types.