Add support for function pointers.

Both uniform and varying function pointers are supported; when a function
is called through a varying function pointer, each unique function pointer
value across the running program instances is called once for the set of
active program instances that want to call it.
This commit is contained in:
Matt Pharr
2011-11-03 16:13:27 -07:00
parent f1d8ff96ce
commit afcd42028f
15 changed files with 1137 additions and 269 deletions

109
stmt.cpp
View File

@@ -120,24 +120,42 @@ DeclStmt::DeclStmt(const std::vector<VariableDeclaration> &v, SourcePos p)
}
static bool
lPossiblyResolveFunctionOverloads(Expr *expr, const Type *type) {
FunctionSymbolExpr *fse = NULL;
const FunctionType *funcType = NULL;
if (dynamic_cast<const PointerType *>(type) != NULL &&
(funcType = dynamic_cast<const FunctionType *>(type->GetBaseType())) &&
(fse = dynamic_cast<FunctionSymbolExpr *>(expr)) != NULL) {
// We're initializing a function pointer with a function symbol,
// which in turn may represent an overloaded function. So we need
// to try to resolve the overload based on the type of the symbol
// we're initializing here.
if (fse->ResolveOverloads(funcType->GetArgumentTypes()) == false)
return false;
}
return true;
}
/** Utility routine that emits code to initialize a symbol given an
initializer expression.
@param lvalue Memory location of storage for the symbol's data
@param symName Name of symbol (used in error messages)
@param type Type of variable being initialized
@param symType Type of variable being initialized
@param initExpr Expression for the initializer
@param ctx FunctionEmitContext to use for generating instructions
@param pos Source file position of the variable being initialized
*/
static void
lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type,
lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *symType,
Expr *initExpr, FunctionEmitContext *ctx, SourcePos pos) {
if (initExpr == NULL) {
// Initialize things without initializers to the undefined value.
// To auto-initialize everything to zero, replace 'UndefValue' with
// 'NullValue' in the below
LLVM_TYPE_CONST llvm::Type *ltype = type->LLVMType(g->ctx);
LLVM_TYPE_CONST llvm::Type *ltype = symType->LLVMType(g->ctx);
ctx->StoreInst(llvm::UndefValue::get(ltype), lvalue);
return;
}
@@ -146,7 +164,10 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type,
// ExprList, then we'll see if we can type convert it to the type of
// the variable.
if (dynamic_cast<ExprList *>(initExpr) == NULL) {
initExpr = TypeConvertExpr(initExpr, type, "initializer");
if (lPossiblyResolveFunctionOverloads(initExpr, symType) == false)
return;
initExpr = TypeConvertExpr(initExpr, symType, "initializer");
if (initExpr != NULL) {
llvm::Value *initializerValue = initExpr->GetValue(ctx);
if (initializerValue != NULL)
@@ -159,16 +180,17 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type,
// Atomic types and enums can't be initialized with { ... } initializer
// expressions, so print an error and return if that's what we've got
// here..
if (dynamic_cast<const AtomicType *>(type) != NULL ||
dynamic_cast<const EnumType *>(type) != NULL) {
if (dynamic_cast<const AtomicType *>(symType) != NULL ||
dynamic_cast<const EnumType *>(symType) != NULL ||
dynamic_cast<const PointerType *>(symType) != NULL) {
if (dynamic_cast<ExprList *>(initExpr) != NULL)
Error(initExpr->pos, "Expression list initializers can't be used for "
"variable \"%s\' with type \"%s\".", symName,
type->GetString().c_str());
symType->GetString().c_str());
return;
}
const ReferenceType *rt = dynamic_cast<const ReferenceType *>(type);
const ReferenceType *rt = dynamic_cast<const ReferenceType *>(symType);
if (rt) {
if (!Type::Equal(initExpr->GetType(), rt)) {
Error(initExpr->pos, "Initializer for reference type \"%s\" must have same "
@@ -190,14 +212,14 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type,
// in which case the elements are initialized with the corresponding
// values.
const CollectionType *collectionType =
dynamic_cast<const CollectionType *>(type);
dynamic_cast<const CollectionType *>(symType);
if (collectionType != NULL) {
std::string name;
if (dynamic_cast<const StructType *>(type) != NULL)
if (dynamic_cast<const StructType *>(symType) != NULL)
name = "struct";
else if (dynamic_cast<const ArrayType *>(type) != NULL)
else if (dynamic_cast<const ArrayType *>(symType) != NULL)
name = "array";
else if (dynamic_cast<const VectorType *>(type) != NULL)
else if (dynamic_cast<const VectorType *>(symType) != NULL)
name = "vector";
else
FATAL("Unexpected CollectionType in lInitSymbol()");
@@ -291,10 +313,21 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
// zero value.
llvm::Constant *cinit = NULL;
if (initExpr != NULL) {
if (lPossiblyResolveFunctionOverloads(initExpr, type) == false)
continue;
// FIXME: we only need this for function pointers; it was
// already done for atomic types and enums in
// DeclStmt::TypeCheck()...
initExpr = TypeConvertExpr(initExpr, type, "initializer");
// FIXME: and this is only needed to re-establish
// constant-ness so that GetConstant below works for
// constant artithmetic expressions...
initExpr = initExpr->Optimize();
cinit = initExpr->GetConstant(type);
if (cinit == NULL)
Error(sym->pos, "Initializer for static variable \"%s\" must be a constant.",
sym->name.c_str());
Error(initExpr->pos, "Initializer for static variable "
"\"%s\" must be a constant.", sym->name.c_str());
}
if (cinit == NULL)
cinit = llvm::Constant::getNullValue(llvmType);
@@ -370,8 +403,10 @@ DeclStmt::TypeCheck() {
if (vars[i].init == NULL)
continue;
vars[i].init = vars[i].init->TypeCheck();
if (vars[i].init == NULL)
if (vars[i].init == NULL) {
encounteredError = true;
continue;
}
// get the right type for stuff like const float foo = 2; so that
// the int->float type conversion is in there and we don't return
@@ -506,39 +541,35 @@ IfStmt::EmitCode(FunctionEmitContext *ctx) const {
Stmt *
IfStmt::Optimize() {
if (test)
if (test != NULL)
test = test->Optimize();
if (trueStmts)
if (trueStmts != NULL)
trueStmts = trueStmts->Optimize();
if (falseStmts)
if (falseStmts != NULL)
falseStmts = falseStmts->Optimize();
return this;
}
Stmt *IfStmt::TypeCheck() {
if (test) {
if (test != NULL) {
test = test->TypeCheck();
if (test) {
if (test != NULL) {
const Type *testType = test->GetType();
if (testType) {
if (testType != NULL) {
bool isUniform = (testType->IsUniformType() &&
!g->opt.disableUniformControlFlow);
if (!testType->IsNumericType() && !testType->IsBoolType()) {
Error(test->pos, "Type \"%s\" can't be converted to boolean "
"for \"if\" test.", testType->GetString().c_str());
test = TypeConvertExpr(test, isUniform ? AtomicType::UniformBool :
AtomicType::VaryingBool,
"\"if\" statement test");
if (test == NULL)
return NULL;
}
test = new TypeCastExpr(isUniform ? AtomicType::UniformBool :
AtomicType::VaryingBool,
test, false, test->pos);
assert(test);
}
}
}
if (trueStmts)
if (trueStmts != NULL)
trueStmts = trueStmts->TypeCheck();
if (falseStmts)
if (falseStmts != NULL)
falseStmts = falseStmts->TypeCheck();
return this;
@@ -698,7 +729,8 @@ lSafeToRunWithAllLanesOff(Expr *expr) {
if (dynamic_cast<SymbolExpr *>(expr) != NULL ||
dynamic_cast<FunctionSymbolExpr *>(expr) != NULL ||
dynamic_cast<SyncExpr *>(expr) != NULL)
dynamic_cast<SyncExpr *>(expr) != NULL ||
dynamic_cast<NullPointerExpr *>(expr) != NULL)
return true;
FATAL("Unknown Expr type in lSafeToRunWithAllLanesOff()");
@@ -1659,6 +1691,12 @@ lEncodeType(const Type *t) {
if (t == AtomicType::VaryingUInt64) return 'V';
if (t == AtomicType::UniformDouble) return 'd';
if (t == AtomicType::VaryingDouble) return 'D';
if (dynamic_cast<const PointerType *>(t) != NULL) {
if (t->IsUniformType())
return 'p';
else
return 'P';
}
else return '\0';
}
@@ -1788,13 +1826,14 @@ PrintStmt::EmitCode(FunctionEmitContext *ctx) const {
llvm::Function *printFunc = m->module->getFunction("__do_print");
assert(printFunc);
llvm::Value *mask = ctx->GetFullMask();
// Set up the rest of the parameters to it
args[0] = ctx->GetStringPtr(format);
args[1] = ctx->GetStringPtr(argTypes);
args[2] = LLVMInt32(g->target.vectorWidth);
args[3] = ctx->LaneMask(ctx->GetFullMask());
args[3] = ctx->LaneMask(mask);
std::vector<llvm::Value *> argVec(&args[0], &args[5]);
ctx->CallInst(printFunc, argVec, "");
ctx->CallInst(printFunc, AtomicType::Void, argVec, "");
}
@@ -1874,7 +1913,7 @@ AssertStmt::EmitCode(FunctionEmitContext *ctx) const {
args.push_back(ctx->GetStringPtr(errorString));
args.push_back(expr->GetValue(ctx));
args.push_back(ctx->GetFullMask());
ctx->CallInst(assertFunc, args, "");
ctx->CallInst(assertFunc, AtomicType::Void, args, "");
#ifndef ISPC_IS_WINDOWS
free(errorString);