19 Commits

Author SHA1 Message Date
5e6f06cf59 Fixed issue with aliasing local variables
ISPC++ now produces valid code, or an appropriate error message, for all
of my test cases.
2017-05-11 15:42:11 -04:00
bfe723e1b7 Actually copy the AST.
Type replacement works except for function parameters.
2017-05-11 03:09:38 -04:00
f65b3e6300 [WIP] Remove cases for ForeachStmt and SymbolExpr 2017-05-11 01:19:50 -04:00
2e28640860 Merge branch 'master' into copy_ast 2017-05-10 23:13:40 -04:00
d020107d91 Typechecking fixes, moved some printing behind debug flag 2017-05-10 23:12:48 -04:00
ab29965d75 force add cpp file for test 2017-05-10 14:25:39 -04:00
64e1e2b008 Generate overloaded function definitions 2017-05-10 14:21:09 -04:00
6a91c5d5ac Attempt to replicate AST when expanding polytypes 2017-05-10 11:11:39 -04:00
192b99f21d Translates polymorphic function to a single instance 2017-05-09 23:41:36 -04:00
871af918ad Remove trailing whitespace 2017-05-09 23:01:40 -04:00
7bb1741b9a [WIP] implement ReplacePolyType for stmts 2017-05-09 15:30:39 -04:00
aeb4c0b6f9 [WIP] replace polymorphic types from expressions 2017-05-09 01:46:36 -04:00
9c0f9be022 remove trailing whitespace 2017-05-09 01:46:33 -04:00
a5306eddc1 Merge branch 'codegen' of github.com:aarongut/ispc into codegen 2017-05-08 17:45:28 -04:00
0f17514eb0 remove trailing whitespace 2017-05-08 17:45:17 -04:00
8a1aeed55c remove trailing whitespace 2017-05-08 17:40:15 -04:00
05c9f63527 Remove trailing whitespace 2017-05-08 15:30:06 -04:00
c86c5097d7 remove trailing whitespace 2017-05-07 15:08:47 -04:00
46ed9bdb3c [WIP] Plumbing to expand polymorphic functions 2017-05-04 21:26:43 -04:00
23 changed files with 1100 additions and 265 deletions

1
.gitignore vendored
View File

@@ -13,6 +13,7 @@ tests*/*cpp
tests*/*run
tests*/*.o
tests_ispcpp/*.h
tests_ispcpp/*.out
tests_ispcpp/*pre*
logs/
notify_log.log

153
ast.cpp
View File

@@ -42,8 +42,11 @@
#include "func.h"
#include "stmt.h"
#include "sym.h"
#include "type.h"
#include "util.h"
#include <map>
///////////////////////////////////////////////////////////////////////////
// ASTNode
@@ -55,14 +58,18 @@ ASTNode::~ASTNode() {
// AST
void
AST::AddFunction(Symbol *sym, Stmt *code) {
AST::AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable) {
if (sym == NULL)
return;
Function *f = new Function(sym, code);
if (f->IsPolyFunction()) {
FATAL("This is a good start, but implement me!");
std::vector<Function *> *expanded = f->ExpandPolyArguments(symbolTable);
for (size_t i=0; i<expanded->size(); i++) {
functions.push_back((*expanded)[i]);
}
delete expanded;
} else {
functions.push_back(f);
}
@@ -79,7 +86,7 @@ AST::GenerateIR() {
ASTNode *
WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
void *data) {
void *data, ASTPostCallBackFunc preUpdate) {
if (node == NULL)
return node;
@@ -90,6 +97,10 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
return node;
}
if (preUpdate != NULL) {
node = preUpdate(node, data);
}
////////////////////////////////////////////////////////////////////////////
// Handle Statements
if (llvm::dyn_cast<Stmt>(node) != NULL) {
@@ -113,54 +124,54 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
UnmaskedStmt *ums;
if ((es = llvm::dyn_cast<ExprStmt>(node)) != NULL)
es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data);
es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data, preUpdate);
else if ((ds = llvm::dyn_cast<DeclStmt>(node)) != NULL) {
for (unsigned int i = 0; i < ds->vars.size(); ++i)
ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc,
postFunc, data);
postFunc, data, preUpdate);
}
else if ((is = llvm::dyn_cast<IfStmt>(node)) != NULL) {
is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data);
is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data, preUpdate);
is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc,
postFunc, data);
postFunc, data, preUpdate);
is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc,
postFunc, data);
postFunc, data, preUpdate);
}
else if ((dos = llvm::dyn_cast<DoStmt>(node)) != NULL) {
dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc,
postFunc, data);
postFunc, data, preUpdate);
dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc,
postFunc, data);
postFunc, data, preUpdate);
}
else if ((fs = llvm::dyn_cast<ForStmt>(node)) != NULL) {
fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data);
fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data);
fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data);
fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data);
fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data, preUpdate);
fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data, preUpdate);
fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data, preUpdate);
fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data, preUpdate);
}
else if ((fes = llvm::dyn_cast<ForeachStmt>(node)) != NULL) {
for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc,
postFunc, data);
postFunc, data, preUpdate);
for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc,
postFunc, data);
fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data);
postFunc, data, preUpdate);
fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data, preUpdate);
}
else if ((fas = llvm::dyn_cast<ForeachActiveStmt>(node)) != NULL) {
fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data);
fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data, preUpdate);
}
else if ((fus = llvm::dyn_cast<ForeachUniqueStmt>(node)) != NULL) {
fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data);
fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data);
fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data, preUpdate);
fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data, preUpdate);
}
else if ((cs = llvm::dyn_cast<CaseStmt>(node)) != NULL)
cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data);
cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data, preUpdate);
else if ((defs = llvm::dyn_cast<DefaultStmt>(node)) != NULL)
defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data);
defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data, preUpdate);
else if ((ss = llvm::dyn_cast<SwitchStmt>(node)) != NULL) {
ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data);
ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data);
ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data, preUpdate);
ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data, preUpdate);
}
else if (llvm::dyn_cast<BreakStmt>(node) != NULL ||
llvm::dyn_cast<ContinueStmt>(node) != NULL ||
@@ -168,22 +179,22 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
// nothing
}
else if ((ls = llvm::dyn_cast<LabeledStmt>(node)) != NULL)
ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data);
ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data, preUpdate);
else if ((rs = llvm::dyn_cast<ReturnStmt>(node)) != NULL)
rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data);
rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data, preUpdate);
else if ((sl = llvm::dyn_cast<StmtList>(node)) != NULL) {
std::vector<Stmt *> &sls = sl->stmts;
for (unsigned int i = 0; i < sls.size(); ++i)
sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data);
sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data, preUpdate);
}
else if ((ps = llvm::dyn_cast<PrintStmt>(node)) != NULL)
ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data);
ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data, preUpdate);
else if ((as = llvm::dyn_cast<AssertStmt>(node)) != NULL)
as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data);
as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data, preUpdate);
else if ((dels = llvm::dyn_cast<DeleteStmt>(node)) != NULL)
dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data);
dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data, preUpdate);
else if ((ums = llvm::dyn_cast<UnmaskedStmt>(node)) != NULL)
ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data);
ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data, preUpdate);
else
FATAL("Unhandled statement type in WalkAST()");
}
@@ -208,57 +219,57 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
NewExpr *newe;
if ((ue = llvm::dyn_cast<UnaryExpr>(node)) != NULL)
ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data);
ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data, preUpdate);
else if ((be = llvm::dyn_cast<BinaryExpr>(node)) != NULL) {
be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data);
be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data);
be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data, preUpdate);
be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data, preUpdate);
}
else if ((ae = llvm::dyn_cast<AssignExpr>(node)) != NULL) {
ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data);
ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data);
ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data, preUpdate);
ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data, preUpdate);
}
else if ((se = llvm::dyn_cast<SelectExpr>(node)) != NULL) {
se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data);
se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data);
se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data);
se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data, preUpdate);
se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data, preUpdate);
se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data, preUpdate);
}
else if ((el = llvm::dyn_cast<ExprList>(node)) != NULL) {
for (unsigned int i = 0; i < el->exprs.size(); ++i)
el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc,
postFunc, data);
postFunc, data, preUpdate);
}
else if ((fce = llvm::dyn_cast<FunctionCallExpr>(node)) != NULL) {
fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data);
fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data);
fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data, preUpdate);
fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data, preUpdate);
for (int k = 0; k < 3; k++)
fce->launchCountExpr[0] = (Expr *)WalkAST(fce->launchCountExpr[0], preFunc,
postFunc, data);
postFunc, data, preUpdate);
}
else if ((ie = llvm::dyn_cast<IndexExpr>(node)) != NULL) {
ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data);
ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data);
ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data, preUpdate);
ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data, preUpdate);
}
else if ((me = llvm::dyn_cast<MemberExpr>(node)) != NULL)
me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data);
me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data, preUpdate);
else if ((tce = llvm::dyn_cast<TypeCastExpr>(node)) != NULL)
tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data);
tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data, preUpdate);
else if ((re = llvm::dyn_cast<ReferenceExpr>(node)) != NULL)
re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data);
re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data, preUpdate);
else if ((ptrderef = llvm::dyn_cast<PtrDerefExpr>(node)) != NULL)
ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc,
data);
data, preUpdate);
else if ((refderef = llvm::dyn_cast<RefDerefExpr>(node)) != NULL)
refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc,
data);
data, preUpdate);
else if ((soe = llvm::dyn_cast<SizeOfExpr>(node)) != NULL)
soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data);
soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data, preUpdate);
else if ((aoe = llvm::dyn_cast<AddressOfExpr>(node)) != NULL)
aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data);
aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data, preUpdate);
else if ((newe = llvm::dyn_cast<NewExpr>(node)) != NULL) {
newe->countExpr = (Expr *)WalkAST(newe->countExpr, preFunc,
postFunc, data);
postFunc, data, preUpdate);
newe->initExpr = (Expr *)WalkAST(newe->initExpr, preFunc,
postFunc, data);
postFunc, data, preUpdate);
}
else if (llvm::dyn_cast<SymbolExpr>(node) != NULL ||
llvm::dyn_cast<ConstExpr>(node) != NULL ||
@@ -492,7 +503,7 @@ lCheckAllOffSafety(ASTNode *node, void *data) {
}
/*
Don't allow turning if/else to straight-line-code if we
Don't allow turning if/else to straight-line-code if we
assign to a uniform.
*/
AssignExpr *ae;
@@ -515,3 +526,35 @@ SafeToRunWithMaskAllOff(ASTNode *root) {
WalkAST(root, lCheckAllOffSafety, NULL, &safe);
return safe;
}
struct PolyData {
const PolyType *polyType;
const Type *replacement;
};
static ASTNode *
lTranslatePolyNode(ASTNode *node, void *d) {
struct PolyData *data = (struct PolyData*)d;
return node->ReplacePolyType(data->polyType, data->replacement);
}
static ASTNode *
lCopyNode(ASTNode *node, void *) {
return node->Copy();
}
ASTNode *
TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) {
struct PolyData data;
data.polyType = polyType;
data.replacement = replacement;
return WalkAST(root, NULL, lTranslatePolyNode, &data, lCopyNode);
}
ASTNode *
CopyAST(ASTNode *root) {
return WalkAST(root, NULL, NULL, NULL, lCopyNode);
}

22
ast.h
View File

@@ -39,6 +39,7 @@
#define ISPC_AST_H 1
#include "ispc.h"
#include "type.h"
#include <vector>
/** @brief Abstract base class for nodes in the abstract syntax tree (AST).
@@ -67,6 +68,10 @@ public:
pointer in place of the original ASTNode *. */
virtual ASTNode *TypeCheck() = 0;
virtual ASTNode *Copy() = 0;
virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0;
/** Estimate the execution cost of the node (not including the cost of
the children. The value returned should be based on the COST_*
enumerant values defined in ispc.h. */
@@ -75,8 +80,8 @@ public:
/** All AST nodes must track the file position where they are
defined. */
SourcePos pos;
/** An enumeration for keeping track of the concrete subclass of Value
/** An enumeration for keeping track of the concrete subclass of Value
that is actually instantiated.*/
enum ASTNodeTy {
/* For classes inherited from Expr */
@@ -127,9 +132,9 @@ public:
SwitchStmtID,
UnmaskedStmtID
};
/** Return an ID for the concrete type of this object. This is used to
implement the classof checks. This should not be used for any
implement the classof checks. This should not be used for any
other purpose, as the values may change as ISPC evolves */
unsigned getValueID() const {
return SubclassID;
@@ -145,7 +150,7 @@ class AST {
public:
/** Add the AST for a function described by the given declaration
information and source code. */
void AddFunction(Symbol *sym, Stmt *code);
void AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable=NULL);
/** Generate LLVM IR for all of the functions into the current
module. */
@@ -174,7 +179,8 @@ typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data);
doing so, calls postFunc, at the node. The return value from the
postFunc call is ignored. */
extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc,
ASTPostCallBackFunc postFunc, void *data);
ASTPostCallBackFunc postFunc, void *data,
ASTPostCallBackFunc preUpdate = NULL);
/** Perform simple optimizations on the AST or portion thereof passed to
this function, returning the resulting AST. */
@@ -204,6 +210,10 @@ extern Stmt *TypeCheck(Stmt *);
the given root. */
extern int EstimateCost(ASTNode *root);
extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement);
extern ASTNode * CopyAST(ASTNode *root);
/** Returns true if it would be safe to run the given code with an "all
off" mask. */
extern bool SafeToRunWithMaskAllOff(ASTNode *root);

35
ctx.cpp
View File

@@ -631,7 +631,7 @@ FunctionEmitContext::EndIf() {
breakLanes, "|break_lanes");
}
llvm::Value *notBreakOrContinue =
llvm::Value *notBreakOrContinue =
BinaryOperator(llvm::Instruction::Xor,
bcLanes, LLVMMaskAllOn,
"!(break|continue)_lanes");
@@ -942,7 +942,7 @@ FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) {
finishedLanes = BinaryOperator(llvm::Instruction::Or, finishedLanes,
continued, "returned|breaked|continued");
}
finishedLanes = BinaryOperator(llvm::Instruction::And,
finishedLanes, GetFunctionMask(),
"finished&func");
@@ -1446,7 +1446,7 @@ FunctionEmitContext::None(llvm::Value *mask) {
llvm::Value *
FunctionEmitContext::LaneMask(llvm::Value *v) {
#ifdef ISPC_NVPTX_ENABLED
/* this makes mandelbrot example slower with "nvptx" target.
/* this makes mandelbrot example slower with "nvptx" target.
* Needs further investigation. */
const char *__movmsk = g->target->getISA() == Target::NVPTX ? "__movmsk_ptx" : "__movmsk";
#else
@@ -1494,7 +1494,7 @@ FunctionEmitContext::Insert(llvm::Value *vector, llvm::Value *lane, llvm::Value
std::string funcName = "__insert";
assert(lAppendInsertExtractName(vector, funcName));
assert(lane->getType() == LLVMTypes::Int32Type);
llvm::Function *func = m->module->getFunction(funcName.c_str());
assert(func != NULL);
std::vector<llvm::Value *> args;
@@ -1511,7 +1511,7 @@ FunctionEmitContext::Extract(llvm::Value *vector, llvm::Value *lane)
std::string funcName = "__extract";
assert(lAppendInsertExtractName(vector, funcName));
assert(lane->getType() == LLVMTypes::Int32Type);
llvm::Function *func = m->module->getFunction(funcName.c_str());
assert(func != NULL);
std::vector<llvm::Value *> args;
@@ -1927,6 +1927,11 @@ FunctionEmitContext::BinaryOperator(llvm::Instruction::BinaryOps inst,
return NULL;
}
if (v0->getType() != v1->getType()) {
v0->dump();
printf("\n\n");
v1->dump();
}
AssertPos(currentPos, v0->getType() == v1->getType());
llvm::Type *type = v0->getType();
int arraySize = lArrayVectorWidth(type);
@@ -2823,7 +2828,7 @@ FunctionEmitContext::loadUniformFromSOA(llvm::Value *ptr, llvm::Value *mask,
llvm::Value *
FunctionEmitContext::LoadInst(llvm::Value *ptr, llvm::Value *mask,
const Type *ptrRefType, const char *name,
const Type *ptrRefType, const char *name,
bool one_elem) {
if (ptr == NULL) {
AssertPos(currentPos, m->errorCount > 0);
@@ -3285,8 +3290,8 @@ FunctionEmitContext::scatter(llvm::Value *value, llvm::Value *ptr,
const PointerType *pt = CastType<PointerType>(valueType);
// And everything should be a pointer or atomic (or enum) from here on out...
AssertPos(currentPos,
pt != NULL
AssertPos(currentPos,
pt != NULL
|| CastType<AtomicType>(valueType) != NULL
|| CastType<EnumType>(valueType) != NULL);
@@ -3887,7 +3892,7 @@ FunctionEmitContext::LaunchInst(llvm::Value *callee,
llvm::Function *F = llvm::dyn_cast<llvm::Function>(callee);
const unsigned int nArgs = F->arg_size();
llvm::Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
for (; I != E; ++I)
for (; I != E; ++I)
argTypes.push_back(I->getType());
llvm::Type *st = llvm::StructType::get(*g->ctx, argTypes);
llvm::StructType *argStructType = static_cast<llvm::StructType *>(st);
@@ -3908,24 +3913,24 @@ FunctionEmitContext::LaunchInst(llvm::Value *callee,
llvm::BasicBlock* if_true = CreateBasicBlock("if_true");
llvm::BasicBlock* if_false = CreateBasicBlock("if_false");
/* check if the pointer returned by ISPCAlloc is not NULL
/* check if the pointer returned by ISPCAlloc is not NULL
* --------------
* this is a workaround for not checking the value of programIndex
* this is a workaround for not checking the value of programIndex
* because ISPCAlloc will return NULL pointer for all programIndex > 0
* of course, if ISPAlloc fails to get parameter buffer, the pointer for programIndex = 0
* will also be NULL
* This check must be added, and also rewrite the code to make it less opaque
* This check must be added, and also rewrite the code to make it less opaque
*/
llvm::Value* cmp1 = CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE, voidi64, LLVMInt64(0), "cmp1");
BranchInst(if_true, if_false, cmp1);
/**********************/
bblock = if_true;
bblock = if_true;
// label_if_then block:
llvm::Type *pt = llvm::PointerType::getUnqual(st);
llvm::Value *argmem = BitCastInst(voidmem, pt);
for (unsigned int i = 0; i < argVals.size(); ++i)
for (unsigned int i = 0; i < argVals.size(); ++i)
{
llvm::Value *ptr = AddElementOffset(argmem, i, NULL, "funarg");
// don't need to do masked store here, I think
@@ -4027,7 +4032,7 @@ FunctionEmitContext::LaunchInst(llvm::Value *callee,
void
FunctionEmitContext::SyncInst() {
#ifdef ISPC_NVPTX_ENABLED
#ifdef ISPC_NVPTX_ENABLED
if (g->target->getISA() == Target::NVPTX)
{
llvm::Value *launchGroupHandle = LoadInst(launchGroupHandlePtr);

10
ctx.h
View File

@@ -195,10 +195,10 @@ public:
'continue' statement when going through the loop body in the
previous iteration. */
void RestoreContinuedLanes();
/** This method is called by code emitting IR for a loop. It clears
/** This method is called by code emitting IR for a loop. It clears
any lanes that contained a break since the mask has been updated to take
them into account. This is necessary as all the bail out checks for
them into account. This is necessary as all the bail out checks for
breaks are meant to only deal with lanes breaking on the current iteration.
*/
void ClearBreakLanes();
@@ -312,7 +312,7 @@ public:
llvm::Value* Insert(llvm::Value *vector, llvm::Value *lane, llvm::Value *scalar);
/** Issues a call to __extract_int8/int16/int32/int64/float/double */
llvm::Value* Extract(llvm::Value *vector, llvm::Value *lane);
#endif
#endif
/** Given a string, create an anonymous global variable to hold its
value and return the pointer to the string. */
@@ -481,7 +481,7 @@ public:
pointer values given by the lvalue. If the lvalue is not varying,
then both the mask pointer and the type pointer may be NULL. */
llvm::Value *LoadInst(llvm::Value *ptr, llvm::Value *mask,
const Type *ptrType, const char *name = NULL,
const Type *ptrType, const char *name = NULL,
bool one_elem = false);
llvm::Value *LoadInst(llvm::Value *ptr, const char *name = NULL);

244
expr.cpp
View File

@@ -112,6 +112,85 @@ Expr::GetBaseSymbol() const {
return NULL;
}
Expr *
Expr::Copy() {
Expr *copy;
switch (getValueID()) {
case AddressOfExprID:
copy = (Expr*)new AddressOfExpr(*(AddressOfExpr*)this);
break;
case AssignExprID:
copy = (Expr*)new AssignExpr(*(AssignExpr*)this);
break;
case BinaryExprID:
copy = (Expr*)new BinaryExpr(*(BinaryExpr*)this);
break;
case ConstExprID:
copy = (Expr*)new ConstExpr(*(ConstExpr*)this);
break;
case PtrDerefExprID:
copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this);
break;
case RefDerefExprID:
copy = (Expr*)new RefDerefExpr(*(RefDerefExpr*)this);
break;
case ExprListID:
copy = (Expr*)new ExprList(*(ExprList*)this);
break;
case FunctionCallExprID:
copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this);
break;
case FunctionSymbolExprID:
copy = (Expr*)new FunctionSymbolExpr(*(FunctionSymbolExpr*)this);
break;
case IndexExprID:
copy = (Expr*)new IndexExpr(*(IndexExpr*)this);
break;
case StructMemberExprID:
copy = (Expr*)new StructMemberExpr(*(StructMemberExpr*)this);
break;
case VectorMemberExprID:
copy = (Expr*)new VectorMemberExpr(*(VectorMemberExpr*)this);
break;
case NewExprID:
copy = (Expr*)new NewExpr(*(NewExpr*)this);
break;
case NullPointerExprID:
copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this);
break;
case ReferenceExprID:
copy = (Expr*)new ReferenceExpr(*(ReferenceExpr*)this);
break;
case SelectExprID:
copy = (Expr*)new SelectExpr(*(SelectExpr*)this);
break;
case SizeOfExprID:
copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this);
break;
case SymbolExprID:
copy = (Expr*)new SymbolExpr(*(SymbolExpr*)this);
break;
case SyncExprID:
copy = (Expr*)new SyncExpr(*(SyncExpr*)this);
break;
case TypeCastExprID:
copy = (Expr*)new TypeCastExpr(*(TypeCastExpr*)this);
break;
case UnaryExprID:
copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this);
break;
default:
FATAL("Unmatched case in Expr::Copy");
copy = this; // just to silence the compiler
}
return copy;
}
Expr *
Expr::ReplacePolyType(const PolyType *, const Type *) {
return this;
}
#if 0
/** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost
@@ -556,6 +635,7 @@ lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr,
"\"%s\" for %s", fromType->GetString().c_str(),
toPolyType->GetString().c_str(), errorMsgBase);
}
return false;
}
}
@@ -3199,7 +3279,7 @@ static llvm::Value *
lEmitVaryingSelect(FunctionEmitContext *ctx, llvm::Value *test,
llvm::Value *expr1, llvm::Value *expr2,
const Type *type) {
llvm::Value *resultPtr = ctx->AllocaInst(expr1->getType(), "selectexpr_tmp");
// Don't need to worry about masking here
ctx->StoreInst(expr2, resultPtr);
@@ -3699,7 +3779,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const {
ctx->SetDebugPos(pos);
if (ft->isTask) {
AssertPos(pos, launchCountExpr[0] != NULL);
llvm::Value *launchCount[3] =
llvm::Value *launchCount[3] =
{ launchCountExpr[0]->GetValue(ctx),
launchCountExpr[1]->GetValue(ctx),
launchCountExpr[2]->GetValue(ctx) };
@@ -3768,7 +3848,7 @@ FunctionCallExpr::GetType() const {
const Type *
FunctionCallExpr::GetLValueType() const {
const FunctionType *ftype = lGetFunctionType(func);
if (ftype && (ftype->GetReturnType()->IsPointerType()
if (ftype && (ftype->GetReturnType()->IsPointerType()
|| ftype->GetReturnType()->IsReferenceType())) {
return ftype->GetReturnType();
}
@@ -4309,7 +4389,7 @@ IndexExpr::GetValue(FunctionEmitContext *ctx) const {
}
else {
Symbol *baseSym = GetBaseSymbol();
if (llvm::dyn_cast<FunctionCallExpr>(baseExpr) == NULL &&
if (llvm::dyn_cast<FunctionCallExpr>(baseExpr) == NULL &&
llvm::dyn_cast<BinaryExpr>(baseExpr) == NULL) {
// Don't check if we're doing a function call or pointer arith
AssertPos(pos, baseSym != NULL);
@@ -4669,6 +4749,23 @@ IndexExpr::TypeCheck() {
return this;
}
Expr *
IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (index == NULL || baseExpr == NULL)
return NULL;
if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
if (Type::EqualForReplacement(GetLValueType()->GetBaseType(), from)) {
lvalueType = new PointerType(to, lvalueType->GetVariability(),
lvalueType->IsConstType());
}
return this;
}
int
IndexExpr::EstimateCost() const {
@@ -4734,27 +4831,6 @@ lIdentifierToVectorElement(char id) {
//////////////////////////////////////////////////
// StructMemberExpr
class StructMemberExpr : public MemberExpr
{
public:
StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(StructMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == StructMemberExprID;
}
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const StructType *getStructType() const;
};
StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue)
: MemberExpr(e, id, p, idpos, derefLValue, StructMemberExprID) {
@@ -4906,31 +4982,6 @@ StructMemberExpr::getStructType() const {
//////////////////////////////////////////////////
// VectorMemberExpr
class VectorMemberExpr : public MemberExpr
{
public:
VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(VectorMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == VectorMemberExprID;
}
llvm::Value *GetValue(FunctionEmitContext* ctx) const;
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const VectorType *exprVectorType;
const VectorType *memberType;
};
VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue)
: MemberExpr(e, id, p, idpos, derefLValue, VectorMemberExprID) {
@@ -5163,7 +5214,7 @@ MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos,
}
if (CastType<StructType>(exprType) != NULL) {
const StructType *st = CastType<StructType>(exprType);
if (st->IsDefined()) {
if (st->IsDefined()) {
return new StructMemberExpr(e, id, p, idpos, derefLValue);
}
else {
@@ -5311,6 +5362,19 @@ MemberExpr::Optimize() {
return expr ? this : NULL;
}
Expr *
MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL)
return NULL;
if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
return this;
}
int
MemberExpr::EstimateCost() const {
@@ -7109,10 +7173,16 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const {
return NULL;
return ctx->IntToPtrInst(exprVal, llvmToType, "int_to_ptr");
}
else {
} else if (CastType<PolyType>(toType)) {
Error(pos, "Unexpected polymorphic type cast to \"%s\"",
toType->GetString().c_str());
return NULL;
} else {
const AtomicType *toAtomic = CastType<AtomicType>(toType);
// typechecking should ensure this is the case
if (!toAtomic) {
fprintf(stderr, "I want %s to be atomic\n", toType->GetString().c_str());
}
AssertPos(pos, toAtomic != NULL);
return lTypeConvAtomic(ctx, exprVal, toAtomic, fromAtomic, pos);
@@ -7219,7 +7289,7 @@ TypeCastExpr::TypeCheck() {
// Issues #721
return this;
}
const AtomicType *fromAtomic = CastType<AtomicType>(fromType);
const AtomicType *toAtomic = CastType<AtomicType>(toType);
const EnumType *fromEnum = CastType<EnumType>(fromType);
@@ -7342,6 +7412,18 @@ TypeCastExpr::Optimize() {
return this;
}
Expr *
TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (type == NULL)
return NULL;
if (Type::EqualForReplacement(type->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
return this;
}
int
TypeCastExpr::EstimateCost() const {
@@ -8012,6 +8094,24 @@ SymbolExpr::Optimize() {
return this;
}
Expr *
SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!symbol)
return NULL;
Symbol *tmp = m->symbolTable->LookupVariable(symbol->name.c_str());
if (tmp) {
tmp->parentFunction = symbol->parentFunction;
symbol = tmp;
}
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
symbol->type = PolyType::ReplaceType(symbol->type, to);
}
return this;
}
int
SymbolExpr::EstimateCost() const {
@@ -8081,6 +8181,14 @@ FunctionSymbolExpr::Optimize() {
return this;
}
Expr *
FunctionSymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
// force re-evaluation of overloaded type
this->triedToResolve = false;
return this;
}
int
FunctionSymbolExpr::EstimateCost() const {
@@ -8327,6 +8435,16 @@ FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype,
cost[i] += 8 * costScale;
continue;
}
if (callTypeNC->IsPolymorphicType()) {
const PolyType *callTypeP =
CastType<PolyType>(callTypeNC->GetBaseType());
if (callTypeP->CanBeType(fargTypeNC->GetBaseType()) &&
callTypeNC->IsArrayType() == fargTypeNC->IsArrayType() &&
callTypeNC->IsPointerType() == fargTypeNC->IsPointerType()){
cost[i] += 8 * costScale;
continue;
}
}
if (fargType->IsVaryingType() && callType->IsUniformType()) {
// Here we deal with brodcasting uniform to varying.
// callType - varying and fargType - uniform is forbidden.
@@ -8453,6 +8571,12 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
return true;
}
else if (matches.size() > 1) {
for (size_t i=0; i<argTypes.size(); i++) {
if (argTypes[i]->IsPolymorphicType()) {
matchingFunc = matches[0];
return true;
}
}
// Multiple matches: ambiguous
std::string candidateMessage =
lGetOverloadCandidateMessage(matches, argTypes, argCouldBeNULL);
@@ -8810,6 +8934,18 @@ NewExpr::Optimize() {
return this;
}
Expr *
NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!allocType)
return this;
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
allocType = PolyType::ReplaceType(allocType, to);
}
return this;
}
void
NewExpr::Print() const {

96
expr.h
View File

@@ -96,6 +96,11 @@ public:
encountered, NULL should be returned. */
virtual Expr *TypeCheck() = 0;
Expr *Copy();
/** This method replaces a polymorphic type with a specific atomic type */
Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement);
/** Prints the expression to standard output (used for debugging). */
virtual void Print() const = 0;
};
@@ -162,12 +167,12 @@ public:
};
BinaryExpr(Op o, Expr *a, Expr *b, SourcePos p);
static inline bool classof(BinaryExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == BinaryExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
const Type *GetLValueType() const;
@@ -205,7 +210,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == AssignExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
void Print() const;
@@ -231,7 +236,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == SelectExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
void Print() const;
@@ -258,7 +263,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == ExprListID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
void Print() const;
@@ -276,14 +281,14 @@ public:
class FunctionCallExpr : public Expr {
public:
FunctionCallExpr(Expr *func, ExprList *args, SourcePos p,
bool isLaunch = false,
bool isLaunch = false,
Expr *launchCountExpr[3] = NULL);
static inline bool classof(FunctionCallExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == FunctionCallExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
@@ -314,7 +319,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == IndexExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
@@ -324,6 +329,7 @@ public:
Expr *Optimize();
Expr *TypeCheck();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
Expr *baseExpr, *index;
@@ -349,7 +355,7 @@ public:
return ((N->getValueID() == StructMemberExprID) ||
(N->getValueID() == VectorMemberExprID));
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
@@ -357,6 +363,7 @@ public:
void Print() const;
Expr *Optimize();
Expr *TypeCheck();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
virtual int getElementNumber() const = 0;
@@ -379,6 +386,51 @@ protected:
mutable const Type *type, *lvalueType;
};
class StructMemberExpr : public MemberExpr
{
public:
StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(StructMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == StructMemberExprID;
}
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const StructType *getStructType() const;
};
class VectorMemberExpr : public MemberExpr
{
public:
VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(VectorMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == VectorMemberExprID;
}
llvm::Value *GetValue(FunctionEmitContext* ctx) const;
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const VectorType *exprVectorType;
const VectorType *memberType;
};
/** @brief Expression representing a compile-time constant value.
@@ -452,7 +504,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == ConstExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
void Print() const;
@@ -514,7 +566,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == TypeCastExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
@@ -522,6 +574,7 @@ public:
void Print() const;
Expr *TypeCheck();
Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
Symbol *GetBaseSymbol() const;
llvm::Constant *GetConstant(const Type *type) const;
@@ -541,7 +594,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == ReferenceExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
const Type *GetLValueType() const;
@@ -567,7 +620,7 @@ public:
(N->getValueID() == PtrDerefExprID) ||
(N->getValueID() == RefDerefExprID));
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
const Type *GetLValueType() const;
@@ -588,7 +641,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == PtrDerefExprID;
}
const Type *GetType() const;
void Print() const;
Expr *TypeCheck();
@@ -606,7 +659,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == RefDerefExprID;
}
const Type *GetType() const;
void Print() const;
Expr *TypeCheck();
@@ -649,7 +702,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == SizeOfExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
void Print() const;
@@ -673,7 +726,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == SymbolExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
@@ -681,6 +734,7 @@ public:
Symbol *GetBaseSymbol() const;
Expr *TypeCheck();
Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const;
int EstimateCost() const;
@@ -701,12 +755,13 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == FunctionSymbolExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
Symbol *GetBaseSymbol() const;
Expr *TypeCheck();
Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const;
int EstimateCost() const;
llvm::Constant *GetConstant(const Type *type) const;
@@ -762,7 +817,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == SyncExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
Expr *TypeCheck();
@@ -781,7 +836,7 @@ public:
static inline bool classof(ASTNode const* N) {
return N->getValueID() == NullPointerExprID;
}
llvm::Value *GetValue(FunctionEmitContext *ctx) const;
const Type *GetType() const;
Expr *TypeCheck();
@@ -809,6 +864,7 @@ public:
const Type *GetType() const;
Expr *TypeCheck();
Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const;
int EstimateCost() const;

105
func.cpp
View File

@@ -45,6 +45,7 @@
#include "sym.h"
#include "util.h"
#include <stdio.h>
#include <set>
#if ISPC_LLVM_VERSION == ISPC_LLVM_3_2 // 3.2
#ifdef ISPC_NVPTX_ENABLED
@@ -89,7 +90,7 @@
#endif
#include <llvm/Support/ToolOutputFile.h>
Function::Function(Symbol *s, Stmt *c) {
Function::Function(Symbol *s, Stmt *c, bool typecheck) {
sym = s;
code = c;
@@ -97,13 +98,15 @@ Function::Function(Symbol *s, Stmt *c) {
Assert(maskSymbol != NULL);
if (code != NULL) {
code = TypeCheck(code);
if (typecheck) {
code = TypeCheck(code);
if (code != NULL && g->debugPrint) {
printf("After typechecking function \"%s\":\n",
sym->name.c_str());
code->Print(0);
printf("---------------------\n");
if (code != NULL && g->debugPrint) {
printf("After typechecking function \"%s\":\n",
sym->name.c_str());
code->Print(0);
printf("---------------------\n");
}
}
if (code != NULL) {
@@ -134,13 +137,14 @@ Function::Function(Symbol *s, Stmt *c) {
args.push_back(sym);
const Type *t = type->GetParameterType(i);
if (sym != NULL && CastType<ReferenceType>(t) == NULL)
sym->parentFunction = this;
}
if (type->isTask
#ifdef ISPC_NVPTX_ENABLED
&& (g->target->getISA() != Target::NVPTX)
&& (g->target->getISA() != Target::NVPTX)
#endif
){
threadIndexSym = m->symbolTable->LookupVariable("threadIndex");
@@ -260,8 +264,8 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function,
Assert(type != NULL);
if (type->isTask == true
#ifdef ISPC_NVPTX_ENABLED
&& (g->target->getISA() != Target::NVPTX)
#endif
&& (g->target->getISA() != Target::NVPTX)
#endif
){
// For tasks, there should always be three parameters: the
// pointer to the structure that holds all of the arguments, the
@@ -322,14 +326,14 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function,
taskCountSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount");
ctx->StoreInst(taskCount, taskCountSym->storagePtr);
taskIndexSym0->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex0");
ctx->StoreInst(taskIndex0, taskIndexSym0->storagePtr);
taskIndexSym1->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex1");
ctx->StoreInst(taskIndex1, taskIndexSym1->storagePtr);
taskIndexSym2->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex2");
ctx->StoreInst(taskIndex2, taskIndexSym2->storagePtr);
taskCountSym0->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount0");
ctx->StoreInst(taskCount0, taskCountSym0->storagePtr);
taskCountSym1->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount1");
@@ -570,7 +574,7 @@ Function::GenerateIR() {
av.push_back(function);
av.push_back(llvm::MDString::get(*g->ctx, "kernel"));
av.push_back(llvm::ConstantInt::get(llvm::IntegerType::get(*g->ctx,32), 1));
annotations->addOperand(llvm::MDNode::get(*g->ctx, av));
annotations->addOperand(llvm::MDNode::get(*g->ctx, av));
#endif
}
#endif /* ISPC_NVPTX_ENABLED */
@@ -611,7 +615,7 @@ Function::GenerateIR() {
av.push_back(llvm::ValueAsMetadata::get(appFunction));
av.push_back(llvm::MDString::get(*g->ctx, "kernel"));
av.push_back(llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(llvm::IntegerType::get(*g->ctx,32), 1)));
annotations->addOperand(llvm::MDNode::get(*g->ctx, llvm::ArrayRef<llvm::Metadata*>(av)));
annotations->addOperand(llvm::MDNode::get(*g->ctx, llvm::ArrayRef<llvm::Metadata*>(av)));
#else
llvm::SmallVector<llvm::Value*, 3> av;
av.push_back(appFunction);
@@ -638,3 +642,76 @@ Function::IsPolyFunction() const {
return false;
}
std::vector<Function *> *
Function::ExpandPolyArguments(SymbolTable *symbolTable) const {
Assert(symbolTable != NULL);
std::vector<Function *> *expanded = new std::vector<Function *>();
std::vector<Symbol *> versions = symbolTable->LookupPolyFunction(sym->name.c_str());
const FunctionType *func = CastType<FunctionType>(sym->type);
for (size_t i=0; i<versions.size(); i++) {
if (g->debugPrint) {
printf("%s before replacing anything:\n", sym->name.c_str());
code->Print(0);
}
const FunctionType *ft = CastType<FunctionType>(versions[i]->type);
symbolTable->PushScope();
Symbol *s = symbolTable->LookupFunction(versions[i]->name.c_str(), ft);
Stmt *ncode = (Stmt*)CopyAST(code);
Function *f = new Function(s, ncode, false);
for (size_t j=0; j<args.size(); j++) {
f->args[j] = new Symbol(*args[j]);
symbolTable->AddVariable(f->args[j], false);
}
for (int j=0; j<ft->GetNumParameters(); j++) {
if (func->GetParameterType(j)->IsPolymorphicType()) {
const PolyType *from = CastType<PolyType>(
func->GetParameterType(j)->GetBaseType());
f->code = (Stmt*)TranslatePoly(f->code, from,
ft->GetParameterType(j)->GetBaseType());
if (g->debugPrint) {
printf("%s after replacing %s with %s:\n\n",
sym->name.c_str(), from->GetString().c_str(),
ft->GetParameterType(j)->GetBaseType()->GetString().c_str());
f->code->Print(0);
printf("------------------------------------------\n\n");
}
}
}
// we didn't typecheck before, now we can
f->code = TypeCheck(f->code);
f->code = Optimize(f->code);
if (g->debugPrint) {
printf("After optimizing expanded function \"%s\":\n",
f->sym->name.c_str());
f->code->Print(0);
printf("---------------------\n");
}
symbolTable->PopScope();
expanded->push_back(f);
}
return expanded;
}

5
func.h
View File

@@ -39,11 +39,12 @@
#define ISPC_FUNC_H 1
#include "ispc.h"
#include "sym.h"
#include <vector>
class Function {
public:
Function(Symbol *sym, Stmt *code);
Function(Symbol *sym, Stmt *code, bool typecheck=true);
const Type *GetReturnType() const;
const FunctionType *GetType() const;
@@ -54,6 +55,8 @@ public:
/** Checks if the function has polymorphic parameters */
const bool IsPolyFunction() const;
std::vector<Function *> *ExpandPolyArguments(SymbolTable *symbolTable) const;
private:
void emitCode(FunctionEmitContext *ctx, llvm::Function *function,
SourcePos firstStmtPos);

View File

@@ -471,7 +471,7 @@ Target::Target(const char *arch, const char *cpu, const char *isa, bool pic, boo
m_is32Bit(true),
m_cpu(""),
m_attributes(""),
#if ISPC_LLVM_VERSION >= ISPC_LLVM_3_3
#if ISPC_LLVM_VERSION >= ISPC_LLVM_3_3
m_tf_attributes(NULL),
#endif
m_nativeVectorWidth(-1),
@@ -733,7 +733,7 @@ Target::Target(const char *arch, const char *cpu, const char *isa, bool pic, boo
else if (!strcasecmp(isa, "generic-16") ||
!strcasecmp(isa, "generic-x16") ||
// We treat *-generic-16 as generic-16, but with special name mangling
strstr(isa, "-generic-16") ||
strstr(isa, "-generic-16") ||
strstr(isa, "-generic-x16")) {
this->m_isa = Target::GENERIC;
if (strstr(isa, "-generic-16") ||

View File

@@ -349,7 +349,7 @@ lStripUnusedDebugInfo(llvm::Module *module) {
// And now we can go and stuff it into the unit with some
// confidence...
llvm::MDNode *replNode = llvm::MDNode::get(module->getContext(),
llvm::MDNode *replNode = llvm::MDNode::get(module->getContext(),
llvm::ArrayRef<llvm::Metadata *>(usedSubprograms));
cu.replaceSubprograms(llvm::DIArray(replNode));
#else // LLVM 3.7+
@@ -589,7 +589,7 @@ Module::AddGlobalVariable(const std::string &name, const Type *type, Expr *initE
}
#ifdef ISPC_NVPTX_ENABLED
if (g->target->getISA() == Target::NVPTX &&
if (g->target->getISA() == Target::NVPTX &&
#if 0
!type->IsConstType() &&
#endif
@@ -609,7 +609,7 @@ Module::AddGlobalVariable(const std::string &name, const Type *type, Expr *initE
* or 128 threads.
* ***note-to-me***:please define these value (128threads/4warps)
* in nvptx-target definition
* instead of compile-time constants
* instead of compile-time constants
*/
nel *= at->GetElementCount();
assert (!type->IsSOAType());
@@ -830,7 +830,7 @@ lRecursiveCheckValidParamType(const Type *t, bool vectorOk) {
if (pt != NULL) {
// Only allow exported uniform pointers
// Uniform pointers to varying data, however, are ok.
if (pt->IsVaryingType())
if (pt->IsVaryingType())
return false;
else
return lRecursiveCheckValidParamType(pt->GetBaseType(), true);
@@ -838,7 +838,7 @@ lRecursiveCheckValidParamType(const Type *t, bool vectorOk) {
if (t->IsVaryingType() && !vectorOk)
return false;
else
else
return true;
}
@@ -871,7 +871,7 @@ lCheckExportedParameterTypes(const Type *type, const std::string &name,
static void
lCheckTaskParameterTypes(const Type *type, const std::string &name,
SourcePos pos) {
if (g->target->getISA() != Target::NVPTX)
if (g->target->getISA() != Target::NVPTX)
return;
if (lRecursiveCheckValidParamType(type, false) == false) {
if (CastType<VectorType>(type))
@@ -1009,6 +1009,102 @@ Module::AddFunctionDeclaration(const std::string &name,
}
}
/* Handle Polymorphic functions
* a function
* int foo(number n, floating, f)
* will produce versions such as
* int foo(int n, float f)
*
* these functions will be overloaded if they are not exported, or mangled
* if exported */
std::set<const Type *, bool(*)(const Type*, const Type*)> toExpand(&PolyType::Less);
std::vector<const FunctionType *> expanded;
expanded.push_back(functionType);
for (int i=0; i<functionType->GetNumParameters(); i++) {
const Type *param = functionType->GetParameterType(i);
if (param->IsPolymorphicType() &&
!toExpand.count(param->GetBaseType())) {
toExpand.insert(param->GetBaseType());
}
}
std::vector<const FunctionType *> nextExpanded;
for (auto iter = toExpand.begin(); iter != toExpand.end(); iter++) {
for (size_t j=0; j<expanded.size(); j++) {
const FunctionType *eft = expanded[j];
const PolyType *pt=CastType<PolyType>(*iter);
std::vector<AtomicType *>::iterator te;
for (te = pt->ExpandBegin(); te != pt->ExpandEnd(); te++) {
llvm::SmallVector<const Type *, 8> nargs;
llvm::SmallVector<std::string, 8> nargsn;
llvm::SmallVector<Expr *, 8> nargsd;
llvm::SmallVector<SourcePos, 8> nargsp;
for (size_t k=0; k<eft->GetNumParameters(); k++) {
if (Type::Equal(eft->GetParameterType(k)->GetBaseType(),
pt)) {
const Type *r;
r = PolyType::ReplaceType(eft->GetParameterType(k),*te);
nargs.push_back(r);
} else {
nargs.push_back(eft->GetParameterType(k));
}
nargsn.push_back(eft->GetParameterName(k));
nargsd.push_back(eft->GetParameterDefault(k));
nargsp.push_back(eft->GetParameterSourcePos(k));
}
const Type *ret = eft->GetReturnType();
if (Type::EqualForReplacement(ret, pt)) {
printf("Replaced return type %s\n",
ret->GetString().c_str());
ret = PolyType::ReplaceType(ret, *te);
}
nextExpanded.push_back(new FunctionType(ret,
nargs,
nargsn,
nargsd,
nargsp,
eft->isTask,
eft->isExported,
eft->isExternC,
eft->isUnmasked));
}
}
expanded.swap(nextExpanded);
nextExpanded.clear();
}
if (expanded.size() > 1) {
for (size_t i=0; i<expanded.size(); i++) {
if (expanded[i]->GetReturnType()->IsPolymorphicType()) {
Error(pos, "Unexpected polymorphic return type \"%s\"",
expanded[i]->GetReturnType()->GetString().c_str());
return;
}
std::string nname = name;
if (functionType->isExported || functionType->isExternC) {
for (int j=0; j<expanded[i]->GetNumParameters(); j++) {
nname += "_";
nname += expanded[i]->GetParameterType(j)->Mangle();
}
}
symbolTable->MapPolyFunction(name, nname, expanded[i]);
AddFunctionDeclaration(nname, expanded[i], storageClass,
isInline, pos);
}
return;
}
// Get the LLVM FunctionType
bool disableMask = (storageClass == SC_EXTERN_C);
llvm::FunctionType *llvmFunctionType =
@@ -1026,7 +1122,7 @@ Module::AddFunctionDeclaration(const std::string &name,
functionName += functionType->Mangle();
// If we treat generic as smth, we should have appropriate mangling
if (g->mangleFunctionsWithTarget) {
if (g->target->getISA() == Target::GENERIC &&
if (g->target->getISA() == Target::GENERIC &&
!g->target->getTreatGenericAsSmth().empty())
functionName += g->target->getTreatGenericAsSmth();
else
@@ -1177,14 +1273,7 @@ Module::AddFunctionDefinition(const std::string &name, const FunctionType *type,
sym->pos = code->pos;
// FIXME: because we encode the parameter names in the function type,
// we need to override the function type here in case the function had
// earlier been declared with anonymous parameter names but is now
// defined with actual names. This is yet another reason we shouldn't
// include the names in FunctionType...
sym->type = type;
ast->AddFunction(sym, code);
ast->AddFunction(sym, code, symbolTable);
}
@@ -1326,7 +1415,7 @@ Module::writeOutput(OutputType outputType, const char *outFileName,
#ifdef ISPC_NVPTX_ENABLED
typedef std::vector<std::string> vecString_t;
static vecString_t
static vecString_t
lSplitString(const std::string &s)
{
std::stringstream ss(s);
@@ -1335,7 +1424,7 @@ lSplitString(const std::string &s)
return vecString_t(begin,end);
}
static void
static void
lFixAttributes(const vecString_t &src, vecString_t &dst)
{
dst.clear();
@@ -1434,7 +1523,7 @@ Module::writeBitcode(llvm::Module *module, const char *outFileName) {
#ifdef ISPC_NVPTX_ENABLED
if (g->target->getISA() == Target::NVPTX)
{
/* when using "nvptx" target, emit patched/hacked assembly
/* when using "nvptx" target, emit patched/hacked assembly
* NVPTX only accepts 3.2-style LLVM assembly, where attributes
* must be inlined, rather then referenced by #attribute_d
* As soon as NVVM support 3.3,3.4 style assembly this fix won't be needed
@@ -1506,7 +1595,7 @@ Module::writeObjectFileOrAssembly(llvm::TargetMachine *targetMachine,
#if ISPC_LLVM_VERSION <= ISPC_LLVM_3_5
std::string error;
#else // LLVM 3.6+
#else // LLVM 3.6+
std::error_code error;
#endif
@@ -1518,7 +1607,7 @@ Module::writeObjectFileOrAssembly(llvm::TargetMachine *targetMachine,
#if ISPC_LLVM_VERSION <= ISPC_LLVM_3_5
if (error.size()) {
#else // LLVM 3.6+
#else // LLVM 3.6+
if (error) {
#endif
@@ -1603,7 +1692,7 @@ static void
lEmitStructDecl(const StructType *st, std::vector<const StructType *> *emittedStructs,
FILE *file, bool emitUnifs=true) {
// if we're emitting this for a generic dispatch header file and it's
// if we're emitting this for a generic dispatch header file and it's
// struct that only contains uniforms, don't bother if we're emitting uniforms
if (!emitUnifs && !lContainsPtrToVarying(st)) {
return;
@@ -1626,7 +1715,7 @@ lEmitStructDecl(const StructType *st, std::vector<const StructType *> *emittedSt
// And now it's safe to declare this one
emittedStructs->push_back(st);
fprintf(file, "#ifndef __ISPC_STRUCT_%s__\n",st->GetCStructName().c_str());
fprintf(file, "#define __ISPC_STRUCT_%s__\n",st->GetCStructName().c_str());
@@ -1848,7 +1937,7 @@ lGetExportedTypes(const Type *type,
lGetExportedTypes(ftype->GetParameterType(j), exportedStructTypes,
exportedEnumTypes, exportedVectorTypes);
}
else
else
Assert(CastType<AtomicType>(type) != NULL);
}
@@ -1899,6 +1988,27 @@ lPrintFunctionDeclarations(FILE *file, const std::vector<Symbol *> &funcs,
// fprintf(file, "#ifdef __cplusplus\n} /* end extern C */\n#endif // __cplusplus\n");
}
static void
lPrintPolyFunctionWrappers(FILE *file, const std::vector<std::string> &funcs) {
fprintf(file, "#if defined(__cplusplus)\n");
for (size_t i=0; i<funcs.size(); i++) {
std::vector<Symbol *> poly = m->symbolTable->LookupPolyFunction(funcs[i].c_str());
for (size_t j=0; j<poly.size(); j++) {
const FunctionType *ftype = CastType<FunctionType>(poly[j]->type);
Assert(ftype);
std::string decl = ftype->GetCDeclaration(funcs[i]);
fprintf(file, " %s {\n", decl.c_str());
std::string call = ftype->GetCCall(poly[j]->name);
fprintf(file, " return %s;\n }\n", call.c_str());
}
}
fprintf(file, "#endif // __cplusplus\n");
}
@@ -2275,8 +2385,10 @@ Module::writeHeader(const char *fn) {
// Collect single linear arrays of the exported and extern "C"
// functions
std::vector<Symbol *> exportedFuncs, externCFuncs;
std::vector<std::string> polyFuncs;
m->symbolTable->GetMatchingFunctions(lIsExported, &exportedFuncs);
m->symbolTable->GetMatchingFunctions(lIsExternC, &externCFuncs);
m->symbolTable->GetPolyFunctions(&polyFuncs);
// Get all of the struct, vector, and enumerant types used as function
// parameters. These vectors may have repeats.
@@ -2313,6 +2425,16 @@ Module::writeHeader(const char *fn) {
fprintf(f, "///////////////////////////////////////////////////////////////////////////\n");
lPrintFunctionDeclarations(f, exportedFuncs);
}
// emit wrappers for polymorphic functions
if (polyFuncs.size() > 0) {
fprintf(f, "\n");
fprintf(f, "///////////////////////////////////////////////////////////////////////////\n");
fprintf(f, "// Polymorphic function wrappers\n");
fprintf(f, "///////////////////////////////////////////////////////////////////////////\n");
lPrintPolyFunctionWrappers(f, polyFuncs);
}
#if 0
if (externCFuncs.size() > 0) {
fprintf(f, "\n");
@@ -2349,7 +2471,7 @@ struct DispatchHeaderInfo {
bool
Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
FILE *f = DHI->file;
if (DHI->EmitFrontMatter) {
fprintf(f, "//\n// %s\n// (Header automatically generated by the ispc compiler.)\n", DHI->fn);
fprintf(f, "// DO NOT EDIT THIS FILE.\n//\n\n");
@@ -2392,10 +2514,10 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
std::vector<Symbol *> exportedFuncs, externCFuncs;
m->symbolTable->GetMatchingFunctions(lIsExported, &exportedFuncs);
m->symbolTable->GetMatchingFunctions(lIsExternC, &externCFuncs);
int programCount = g->target->getVectorWidth();
if ((DHI->Emit4 && (programCount == 4)) ||
if ((DHI->Emit4 && (programCount == 4)) ||
(DHI->Emit8 && (programCount == 8)) ||
(DHI->Emit16 && (programCount == 16))) {
// Get all of the struct, vector, and enumerant types used as function
@@ -2407,7 +2529,7 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
&exportedEnumTypes, &exportedVectorTypes);
lGetExportedParamTypes(externCFuncs, &exportedStructTypes,
&exportedEnumTypes, &exportedVectorTypes);
// Go through the explicitly exported types
for (int i = 0; i < (int)exportedTypes.size(); ++i) {
if (const StructType *st = CastType<StructType>(exportedTypes[i].first))
@@ -2420,19 +2542,19 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
FATAL("Unexpected type in export list");
}
// And print them
if (DHI->EmitUnifs) {
lEmitVectorTypedefs(exportedVectorTypes, f);
lEmitEnumDecls(exportedEnumTypes, f);
}
lEmitStructDecls(exportedStructTypes, f, DHI->EmitUnifs);
// Update flags
DHI->EmitUnifs = false;
if (programCount == 4) {
DHI->Emit4 = false;
}
}
else if (programCount == 8) {
DHI->Emit8 = false;
}
@@ -2457,12 +2579,12 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
// end namespace
fprintf(f, "\n");
fprintf(f, "\n#ifdef __cplusplus\n} /* namespace */\n#endif // __cplusplus\n");
// end guard
fprintf(f, "\n#endif // %s\n", guard.c_str());
DHI->EmitBackMatter = false;
}
return true;
}
@@ -2477,17 +2599,17 @@ Module::execPreprocessor(const char *infilename, llvm::raw_string_ostream *ostre
clang::DiagnosticOptions *diagOptions = new clang::DiagnosticOptions();
clang::TextDiagnosticPrinter *diagPrinter =
new clang::TextDiagnosticPrinter(stderrRaw, diagOptions);
llvm::IntrusiveRefCntPtr<clang::DiagnosticIDs> diagIDs(new clang::DiagnosticIDs);
clang::DiagnosticsEngine *diagEngine =
new clang::DiagnosticsEngine(diagIDs, diagOptions, diagPrinter);
inst.setDiagnostics(diagEngine);
#if ISPC_LLVM_VERSION <= ISPC_LLVM_3_4 // 3.2, 3.3, 3.4
clang::TargetOptions &options = inst.getTargetOpts();
#else // LLVM 3.5+
const std::shared_ptr< clang::TargetOptions > &options =
const std::shared_ptr< clang::TargetOptions > &options =
std::make_shared< clang::TargetOptions >(inst.getTargetOpts());
#endif
@@ -2654,7 +2776,7 @@ lGetTargetFileName(const char *outFileName, const char *isaString, bool forceCXX
strcpy(targetOutFileName, outFileName);
strcat(targetOutFileName, "_");
strcat(targetOutFileName, isaString);
// Append ".cpp" suffix to the original file if it is *-generic target
if (forceCXX)
strcat(targetOutFileName, ".cpp");
@@ -2760,11 +2882,11 @@ lGetVaryingDispatchType(FunctionTargetVariants &funcs) {
}
}
}
// We should've found at least one variant here
// or else something fishy is going on.
Assert(resultFuncTy);
return resultFuncTy;
}
@@ -2847,7 +2969,7 @@ lCreateDispatchFunction(llvm::Module *module, llvm::Function *setISAFunc,
// dispatchNum is needed to separate generic from *-generic target
int dispatchNum = i;
if ((Target::ISA)(i == Target::GENERIC) &&
if ((Target::ISA)(i == Target::GENERIC) &&
!g->target->getTreatGenericAsSmth().empty()) {
if (g->target->getTreatGenericAsSmth() == "knl_generic")
dispatchNum = Target::KNL_AVX512;
@@ -2879,7 +3001,7 @@ lCreateDispatchFunction(llvm::Module *module, llvm::Function *setISAFunc,
args.push_back(&*argIter);
}
else {
llvm::CastInst *argCast =
llvm::CastInst *argCast =
llvm::CastInst::CreatePointerCast(&*argIter, targsIter->getType(),
"dpatch_arg_bitcast", callBBlock);
args.push_back(argCast);
@@ -3053,7 +3175,7 @@ lExtractOrCheckGlobals(llvm::Module *msrc, llvm::Module *mdst, bool check) {
}
#ifdef ISPC_NVPTX_ENABLED
static std::string lCBEMangle(const std::string &S)
static std::string lCBEMangle(const std::string &S)
{
std::string Result;
@@ -3102,7 +3224,7 @@ Module::CompileAndOutput(const char *srcFile,
if (m->CompileFile() == 0) {
#ifdef ISPC_NVPTX_ENABLED
/* NVPTX:
* for PTX target replace '.' with '_' in all global variables
* for PTX target replace '.' with '_' in all global variables
* a PTX identifier name must match [a-zA-Z$_][a-zA-Z$_0-9]*
*/
if (g->target->getISA() == Target::NVPTX)
@@ -3135,7 +3257,7 @@ Module::CompileAndOutput(const char *srcFile,
}
}
else if (outputType == Asm || outputType == Object) {
if (target != NULL &&
if (target != NULL &&
(strncmp(target, "generic-", 8) == 0 || strstr(target, "-generic-") != NULL)) {
Error(SourcePos(), "When using a \"generic-*\" compilation target, "
"%s output can not be used.",
@@ -3212,7 +3334,7 @@ Module::CompileAndOutput(const char *srcFile,
std::map<std::string, FunctionTargetVariants> exportedFunctions;
int errorCount = 0;
// Handle creating a "generic" header file for multiple targets
// that use exported varyings
DispatchHeaderInfo DHI;
@@ -3234,7 +3356,7 @@ Module::CompileAndOutput(const char *srcFile,
}
// Variable is needed later for approptiate dispatch function.
// It indicates if we have *-generic target.
// It indicates if we have *-generic target.
std::string treatGenericAsSmth = "";
for (unsigned int i = 0; i < targets.size(); ++i) {
@@ -3272,9 +3394,9 @@ Module::CompileAndOutput(const char *srcFile,
if (outFileName != NULL) {
std::string targetOutFileName;
// We always generate cpp file for *-generic target during multitarget compilation
if (g->target->getISA() == Target::GENERIC &&
if (g->target->getISA() == Target::GENERIC &&
!g->target->getTreatGenericAsSmth().empty()) {
targetOutFileName = lGetTargetFileName(outFileName,
targetOutFileName = lGetTargetFileName(outFileName,
g->target->getTreatGenericAsSmth().c_str(), true);
if (!m->writeOutput(CXX, targetOutFileName.c_str(), includeFileName))
return 1;
@@ -3299,14 +3421,14 @@ Module::CompileAndOutput(const char *srcFile,
// only print backmatter on the last target.
DHI.EmitBackMatter = true;
}
const char *isaName;
if (g->target->getISA() == Target::GENERIC &&
!g->target->getTreatGenericAsSmth().empty())
isaName = g->target->getTreatGenericAsSmth().c_str();
else
else
isaName = g->target->GetISAString();
std::string targetHeaderFileName =
std::string targetHeaderFileName =
lGetTargetFileName(headerFileName, isaName, false);
// write out a header w/o target name for the first target only
if (!m->writeOutput(Module::Header, headerFileName, "", &DHI)) {

148
stmt.cpp
View File

@@ -35,6 +35,7 @@
@brief File with definitions classes related to statements in the language
*/
#include "ast.h"
#include "stmt.h"
#include "ctx.h"
#include "util.h"
@@ -77,6 +78,85 @@ Stmt::Optimize() {
return this;
}
Stmt *
Stmt::Copy() {
Stmt *copy;
switch (getValueID()) {
case AssertStmtID:
copy = (Stmt*)new AssertStmt(*(AssertStmt*)this);
break;
case BreakStmtID:
copy = (Stmt*)new BreakStmt(*(BreakStmt*)this);
break;
case CaseStmtID:
copy = (Stmt*)new CaseStmt(*(CaseStmt*)this);
break;
case ContinueStmtID:
copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this);
break;
case DeclStmtID:
copy = (Stmt*)new DeclStmt(*(DeclStmt*)this);
break;
case DefaultStmtID:
copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this);
break;
case DeleteStmtID:
copy = (Stmt*)new DeleteStmt(*(DeleteStmt*)this);
break;
case DoStmtID:
copy = (Stmt*)new DoStmt(*(DoStmt*)this);
break;
case ExprStmtID:
copy = (Stmt*)new ExprStmt(*(ExprStmt*)this);
break;
case ForeachActiveStmtID:
copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this);
break;
case ForeachStmtID:
copy = (Stmt*)new ForeachStmt(*(ForeachStmt*)this);
break;
case ForeachUniqueStmtID:
copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this);
break;
case ForStmtID:
copy = (Stmt*)new ForStmt(*(ForStmt*)this);
break;
case GotoStmtID:
copy = (Stmt*)new GotoStmt(*(GotoStmt*)this);
break;
case IfStmtID:
copy = (Stmt*)new IfStmt(*(IfStmt*)this);
break;
case LabeledStmtID:
copy = (Stmt*)new LabeledStmt(*(LabeledStmt*)this);
break;
case PrintStmtID:
copy = (Stmt*)new PrintStmt(*(PrintStmt*)this);
break;
case ReturnStmtID:
copy = (Stmt*)new ReturnStmt(*(ReturnStmt*)this);
break;
case StmtListID:
copy = (Stmt*)new StmtList(*(StmtList*)this);
break;
case SwitchStmtID:
copy = (Stmt*)new SwitchStmt(*(SwitchStmt*)this);
break;
case UnmaskedStmtID:
copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this);
break;
default:
FATAL("Unmatched case in Stmt::Copy");
copy = this; // just to silence the compiler
}
return copy;
}
Stmt *
Stmt::ReplacePolyType(const PolyType *, const Type *) {
return this;
}
///////////////////////////////////////////////////////////////////////////
// ExprStmt
@@ -145,11 +225,11 @@ lHasUnsizedArrays(const Type *type) {
#ifdef ISPC_NVPTX_ENABLED
static llvm::Value* lConvertToGenericPtr(FunctionEmitContext *ctx, llvm::Value *value, const SourcePos &currentPos, const bool variable = false)
{
if (!value->getType()->isPointerTy() || g->target->getISA() != Target::NVPTX)
if (!value->getType()->isPointerTy() || g->target->getISA() != Target::NVPTX)
return value;
llvm::PointerType *pt = llvm::dyn_cast<llvm::PointerType>(value->getType());
const int addressSpace = pt->getAddressSpace();
if (addressSpace != 3 && addressSpace != 4)
if (addressSpace != 3 && addressSpace != 4)
return value;
llvm::Type *elTy = pt->getElementType();
@@ -271,17 +351,17 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
#ifdef ISPC_NVPTX_ENABLED
if (g->target->getISA() == Target::NVPTX && !sym->type->IsConstType())
{
Error(sym->pos,
Error(sym->pos,
"Non-constant static variable ""\"%s\" is not supported with ""\"nvptx\" target.",
sym->name.c_str());
return;
}
if (g->target->getISA() == Target::NVPTX && sym->type->IsVaryingType())
PerformanceWarning(sym->pos,
PerformanceWarning(sym->pos,
"\"const static varying\" variable ""\"%s\" is stored in __global address space with ""\"nvptx\" target.",
sym->name.c_str());
if (g->target->getISA() == Target::NVPTX && sym->type->IsUniformType())
PerformanceWarning(sym->pos,
PerformanceWarning(sym->pos,
"\"const static uniform\" variable ""\"%s\" is stored in __constant address space with ""\"nvptx\" target.",
sym->name.c_str());
#endif /* ISPC_NVPTX_ENABLED */
@@ -346,11 +426,11 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
#ifdef ISPC_NVPTX_ENABLED
else if ((sym->type->IsUniformType() || sym->type->IsSOAType()) &&
/* NVPTX:
* only non-constant uniform data types are stored in shared memory
* constant uniform are automatically promoted to varying
* only non-constant uniform data types are stored in shared memory
* constant uniform are automatically promoted to varying
*/
!sym->type->IsConstType() &&
#if 1
#if 1
sym->type->IsArrayType() &&
#endif
g->target->getISA() == Target::NVPTX)
@@ -370,7 +450,7 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
* or 128 threads.
* ***note-to-me***:please define these value (128threads/4warps)
* in nvptx-target definition
* instead of compile-time constants
* instead of compile-time constants
*/
nel *= at->GetElementCount();
if (sym->type->IsSOAType())
@@ -387,9 +467,9 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
sym->storagePtr =
new llvm::GlobalVariable(*m->module, llvmTypeUn,
sym->type->IsConstType(),
llvm::GlobalValue::InternalLinkage,
llvm::GlobalValue::InternalLinkage,
cinit,
llvm::Twine("local_") +
llvm::Twine("local_") +
llvm::Twine(sym->pos.first_line) +
llvm::Twine("_") + sym->name.c_str(),
NULL,
@@ -479,7 +559,8 @@ DeclStmt::TypeCheck() {
// an int as the constValue later...
const Type *type = vars[i].sym->type;
if (CastType<AtomicType>(type) != NULL ||
CastType<EnumType>(type) != NULL) {
CastType<EnumType>(type) != NULL ||
CastType<PolyType>(type) != NULL) {
// If it's an expr list with an atomic type, we'll later issue
// an error. Need to leave vars[i].init as is in that case so
// it is in fact caught later, though.
@@ -494,6 +575,24 @@ DeclStmt::TypeCheck() {
return encounteredError ? NULL : this;
}
Stmt *
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
for (size_t i = 0; i < vars.size(); i++) {
vars[i].sym = new Symbol(*vars[i].sym);
m->symbolTable->AddVariable(vars[i].sym, false);
Symbol *s = vars[i].sym;
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
s->type = PolyType::ReplaceType(s->type, to);
// this typecast *should* be valid after typechecking
vars[i].init = TypeConvertExpr(vars[i].init, s->type,
"initializer");
}
}
return this;
}
void
DeclStmt::Print(int indent) const {
@@ -590,7 +689,7 @@ IfStmt::EmitCode(FunctionEmitContext *ctx) const {
#if 0
if (!isUniform && g->target->getISA() == Target::NVPTX)
{
/* With "nvptx" target, SIMT hardware takes care of non-uniform
/* With "nvptx" target, SIMT hardware takes care of non-uniform
* control flow. We trick ISPC to generate uniform control flow.
*/
testValue = ctx->ExtractInst(testValue, 0);
@@ -1470,6 +1569,17 @@ ForeachStmt::ForeachStmt(const std::vector<Symbol *> &lvs,
stmts(s) {
}
/*
ForeachStmt::ForeachStmt(ForeachStmt *base)
: Stmt(base->pos, ForeachStmtID) {
dimVariables = base->dimVariables;
startExprs = base->startExprs;
endExprs = base->endExprs;
isTiled = base->isTiled;
stmts = base->stmts;
}
*/
/* Given a uniform counter value in the memory location pointed to by
uniformCounterPtr, compute the corresponding set of varying counter
@@ -1495,9 +1605,9 @@ lUpdateVaryingCounter(int dim, int nDims, FunctionEmitContext *ctx,
// (0,1,2,3,0,1,2,3), and for the outer dimension we want
// (0,0,0,0,1,1,1,1).
int32_t delta[ISPC_MAX_NVEC];
const int vecWidth = 32;
const int vecWidth = 32;
std::vector<llvm::Constant*> constDeltaList;
for (int i = 0; i < vecWidth; ++i)
for (int i = 0; i < vecWidth; ++i)
{
int d = i;
// First, account for the effect of any dimensions at deeper
@@ -1694,7 +1804,7 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
std::vector<int> span(nDims, 0);
#ifdef ISPC_NVPTX_ENABLED
const int vectorWidth =
const int vectorWidth =
g->target->getISA() == Target::NVPTX ? 32 : g->target->getVectorWidth();
lGetSpans(nDims-1, nDims, vectorWidth, isTiled, &span[0]);
#else /* ISPC_NVPTX_ENABLED */
@@ -1713,8 +1823,10 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
// Start and end value for this loop dimension
llvm::Value *sv = startExprs[i]->GetValue(ctx);
llvm::Value *ev = endExprs[i]->GetValue(ctx);
if (sv == NULL || ev == NULL)
if (sv == NULL || ev == NULL) {
fprintf(stderr, "ev is NULL again :(\n");
return;
}
startVals.push_back(sv);
endVals.push_back(ev);
@@ -3338,7 +3450,7 @@ lProcessPrintArg(Expr *expr, FunctionEmitContext *ctx, std::string &argTypes) {
}
else {
if (Type::Equal(baseType, AtomicType::UniformBool)) {
// Blast bools to ints, but do it here to preserve encoding for
// Blast bools to ints, but do it here to preserve encoding for
// printing 'true' or 'false'
expr = new TypeCastExpr(type->IsUniformType() ? AtomicType::UniformInt32 :
AtomicType::VaryingInt32,

3
stmt.h
View File

@@ -70,6 +70,8 @@ public:
// Stmts don't have anything to do here.
virtual Stmt *Optimize();
virtual Stmt *TypeCheck() = 0;
Stmt *Copy();
Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement);
};
@@ -117,6 +119,7 @@ public:
Stmt *Optimize();
Stmt *TypeCheck();
Stmt *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
std::vector<VariableDeclaration> vars;

44
sym.cpp
View File

@@ -95,14 +95,14 @@ SymbolTable::PopScope() {
bool
SymbolTable::AddVariable(Symbol *symbol) {
SymbolTable::AddVariable(Symbol *symbol, bool issueScopeWarning) {
Assert(symbol != NULL);
// Check to see if a symbol of the same name has already been declared.
for (int i = (int)variables.size() - 1; i >= 0; --i) {
SymbolMapType &sm = *(variables[i]);
if (sm.find(symbol->name) != sm.end()) {
if (i == (int)variables.size()-1) {
if (i == (int)variables.size()-1 && issueScopeWarning) {
// If a symbol of the same name was declared in the
// same scope, it's an error.
Error(symbol->pos, "Ignoring redeclaration of symbol \"%s\".",
@@ -112,9 +112,11 @@ SymbolTable::AddVariable(Symbol *symbol) {
else {
// Otherwise it's just shadowing something else, which
// is legal but dangerous..
Warning(symbol->pos,
"Symbol \"%s\" shadows symbol declared in outer scope.",
symbol->name.c_str());
if (issueScopeWarning) {
Warning(symbol->pos,
"Symbol \"%s\" shadows symbol declared in outer scope.",
symbol->name.c_str());
}
(*variables.back())[symbol->name] = symbol;
return true;
}
@@ -147,7 +149,7 @@ bool
SymbolTable::AddFunction(Symbol *symbol) {
const FunctionType *ft = CastType<FunctionType>(symbol->type);
Assert(ft != NULL);
if (LookupFunction(symbol->name.c_str(), ft) != NULL)
if (LookupFunction(symbol->name.c_str(), ft, true) != NULL)
// A function of the same name and type has already been added to
// the symbol table
return false;
@@ -157,6 +159,14 @@ SymbolTable::AddFunction(Symbol *symbol) {
return true;
}
void
SymbolTable::MapPolyFunction(std::string name, std::string polyname,
const FunctionType *type) {
std::vector<Symbol *> &polyExpansions = polyFunctions[name];
SourcePos p;
polyExpansions.push_back(new Symbol(polyname, p, type, SC_NONE));
}
bool
SymbolTable::LookupFunction(const char *name, std::vector<Symbol *> *matches) {
@@ -175,7 +185,8 @@ SymbolTable::LookupFunction(const char *name, std::vector<Symbol *> *matches) {
Symbol *
SymbolTable::LookupFunction(const char *name, const FunctionType *type) {
SymbolTable::LookupFunction(const char *name, const FunctionType *type,
bool ignorePoly) {
FunctionMapType::iterator iter = functions.find(name);
if (iter != functions.end()) {
std::vector<Symbol *> funcs = iter->second;
@@ -184,9 +195,28 @@ SymbolTable::LookupFunction(const char *name, const FunctionType *type) {
return funcs[j];
}
}
// Try looking for a polymorphic function
if (!ignorePoly && polyFunctions[name].size() > 0) {
std::string n = name;
return new Symbol(name, polyFunctions[name][0]->pos, type);
}
return NULL;
}
std::vector<Symbol *>&
SymbolTable::LookupPolyFunction(const char *name) {
return polyFunctions[name];
}
void
SymbolTable::GetPolyFunctions(std::vector<std::string> *funcs) {
FunctionMapType::iterator it = polyFunctions.begin();
for (; it != polyFunctions.end(); it++) {
funcs->push_back(it->first);
}
}
bool
SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) {

22
sym.h
View File

@@ -108,6 +108,7 @@ public:
};
/** @brief Symbol table that holds all known symbols during parsing and compilation.
A single instance of a SymbolTable is stored in the Module class
@@ -140,7 +141,7 @@ public:
with a symbol defined at the same scope. (Symbols may shaodow
symbols in outer scopes; a warning is issued in this case, but this
method still returns true.) */
bool AddVariable(Symbol *symbol);
bool AddVariable(Symbol *symbol, bool issueScopeWarning=true);
/** Looks for a variable with the given name in the symbol table. This
method searches outward from the innermost scope to the outermost,
@@ -159,6 +160,14 @@ public:
already present in the symbol table. */
bool AddFunction(Symbol *symbol);
/** Adds the given function to the list of polymorphic definitions for the
given name
@param name The name of the original function
@param type The expanded FunctionType */
void MapPolyFunction(std::string name, std::string polyname,
const FunctionType *type);
/** Looks for the function or functions with the given name in the
symbol name. If a function has been overloaded and multiple
definitions are present for a given function name, all of them will
@@ -172,7 +181,12 @@ public:
in the symbol table.
@return pointer to matching Symbol; NULL if none is found. */
Symbol *LookupFunction(const char *name, const FunctionType *type);
Symbol *LookupFunction(const char *name, const FunctionType *type,
bool ignorePoly = false);
std::vector<Symbol *>& LookupPolyFunction(const char *name);
void GetPolyFunctions(std::vector<std::string> *funcs);
/** Returns all of the functions in the symbol table that match the given
predicate.
@@ -219,7 +233,7 @@ public:
@return Pointer to the Type, if found; otherwise NULL is returned.
*/
const Type *LookupType(const char *name) const;
/** Look for a type given a pointer.
@return True if found, False otherwise.
@@ -276,6 +290,8 @@ private:
typedef std::map<std::string, std::vector<Symbol *> > FunctionMapType;
FunctionMapType functions;
FunctionMapType polyFunctions;
/** Type definitions can't currently be scoped.
*/
typedef std::map<std::string, const Type *> TypeMapType;

13
tests_ispcpp/Makefile Normal file
View File

@@ -0,0 +1,13 @@
CXX=g++
CXXFLAGS=-std=c++11 -O2
ISPC=../ispc
ISPCFLAGS=--target=sse4-x2 -O2 --arch=x86-64
%.out : %.cpp %.o
$(CXX) $(CXXFLAGS) -o $@ $^
$ : $.o
%.o : %.ispc
$(ISPC) $(ISPCFLAGS) -h $*.h -o $*.o $<

View File

@@ -1,7 +1,7 @@
#include <stdlib.h>
#include <stdio.h>
#include "hello.ispc.h"
#include "hello.h"
int main() {
float A[100];

20
tests_ispcpp/simple.cpp Normal file
View File

@@ -0,0 +1,20 @@
#include <stdlib.h>
#include <stdio.h>
#include "simple.h"
int main() {
double A[256];
for (int i=0; i<256; i++) {
A[i] = i / 11.;
}
ispc::foo(256, (double*)&A);
for (int i=0; i<256; i++) {
printf("%f\n", A[i]);
}
return 0;
}

View File

@@ -1,4 +1,4 @@
export void foo(uniform int N, floating$1 X[])
export void foo(uniform int N, uniform floating$1 X[])
{
foreach (i = 0 ... N) {
X[i] = X[i] + 1.0;

27
tests_ispcpp/varying.cpp Normal file
View File

@@ -0,0 +1,27 @@
#include <stdlib.h>
#include <stdio.h>
#include "varying.h"
int main() {
float A[256];
double B[256];
double outA[256];
double outB[256];
for (int i=0; i<256; i++) {
A[i] = 1. / (i+1);
B[i] = 1. / (i+1);
}
ispc::square(256, (float*)&A, (double*)&outA);
ispc::square(256, (double*)&B, (double*)&outB);
for (int i=0; i<256; i++) {
printf("float: %.16f\tdouble: %.16f\n", outA[i], outB[i]);
}
return 0;
}

14
tests_ispcpp/varying.ispc Normal file
View File

@@ -0,0 +1,14 @@
floating foo(const uniform int a, floating b) {
floating out = b;
for (int i = 1; i<a; i++) {
out *= b;
}
return out;
}
export void square(uniform int N, uniform floating b[], uniform double out[]) {
foreach (i = 0 ... N) {
out[i] = foo(2, b[i]);
}
}

158
type.cpp
View File

@@ -249,6 +249,15 @@ Type::IsVoidType() const {
bool
Type::IsPolymorphicType() const {
const FunctionType *ft = CastType<FunctionType>(this);
if (ft) {
for (int i=0; i<ft->GetNumParameters(); i++) {
if (ft->GetParameterType(i)->IsPolymorphicType())
return true;
}
return false;
}
return (CastType<PolyType>(GetBaseType()) != NULL);
}
@@ -694,16 +703,70 @@ const PolyType *PolyType::UniformNumber =
const PolyType *PolyType::VaryingNumber =
new PolyType(PolyType::TYPE_NUMBER, Variability::Varying, false);
const Type *
PolyType::ReplaceType(const Type *from, const Type *to) {
const Type *t = to;
if (from->IsPointerType()) {
t = new PointerType(to,
from->GetVariability(),
from->IsConstType());
} else if (from->IsArrayType()) {
t = new ArrayType(to,
CastType<ArrayType>(from)->GetElementCount());
} else if (from->IsReferenceType()) {
t = new ReferenceType(to);
}
if (from->IsVaryingType())
t = t->GetAsVaryingType();
if (g->debugPrint) {
fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n",
from->GetString().c_str(),
t->GetString().c_str());
}
return t;
}
bool
PolyType::Less(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
if (!pa || !pb) {
char buf[1024];
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
"\"%s\" and \"%s\"\n",
a->GetString().c_str(), b->GetString().c_str());
FATAL(buf);
}
if (pa->restriction < pb->restriction)
return true;
if (pa->restriction > pb->restriction)
return false;
if (pa->GetQuant() < pb->GetQuant())
return true;
return false;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic, int q)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(q) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
@@ -814,6 +877,39 @@ PolyType::GetAsUniformType() const {
return asUniformType;
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandBegin() const {
if (expandedTypes)
return expandedTypes->begin();
expandedTypes = new std::vector<AtomicType *>();
if (restriction == TYPE_INTEGER || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT64, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT64, variability, isConst));
}
if (restriction == TYPE_FLOATING || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_FLOAT, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_DOUBLE, variability, isConst));
}
return expandedTypes->begin();
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandEnd() const {
Assert(expandedTypes != NULL);
return expandedTypes->end();
}
const PolyType *
PolyType::GetAsUnboundVariabilityType() const {
@@ -887,7 +983,7 @@ PolyType::GetString() const {
case TYPE_NUMBER: ret += "number"; break;
default: FATAL("Logic error in PolyType::GetString()");
}
if (quant >= 0) {
ret += "$";
ret += std::to_string(quant);
@@ -1584,9 +1680,9 @@ PointerType::GetCDeclaration(const std::string &name) const {
}
std::string ret = baseType->GetCDeclaration("");
bool baseIsBasicVarying = (IsBasicType(baseType)) && (baseType->IsVaryingType());
if (baseIsBasicVarying) ret += std::string("(");
ret += std::string(" *");
if (isConst) ret += " const";
@@ -2428,7 +2524,7 @@ StructType::StructType(const std::string &n, const llvm::SmallVector<const Type
}
}
const std::string
const std::string
StructType::GetCStructName() const {
// only return mangled name for varying structs for backwards
// compatibility...
@@ -3488,7 +3584,7 @@ FunctionType::GetCDeclaration(const std::string &fname) const {
CastType<ArrayType>(pt->GetBaseType()) != NULL) {
type = new ArrayType(pt->GetBaseType(), 0);
}
if (paramNames[i] != "")
ret += type->GetCDeclaration(paramNames[i]);
else
@@ -3500,6 +3596,34 @@ FunctionType::GetCDeclaration(const std::string &fname) const {
return ret;
}
std::string
FunctionType::GetCCall(const std::string &fname) const {
std::string ret;
ret += fname;
ret += "(";
for (unsigned int i = 0; i < paramTypes.size(); ++i) {
const Type *type = paramTypes[i];
// Convert pointers to arrays to unsized arrays, which are more clear
// to print out for multidimensional arrays (i.e. "float foo[][4] "
// versus "float (foo *)[4]").
const PointerType *pt = CastType<PointerType>(type);
if (pt != NULL &&
CastType<ArrayType>(pt->GetBaseType()) != NULL) {
type = new ArrayType(pt->GetBaseType(), 0);
}
if (paramNames[i] != "")
ret += paramNames[i];
else
FATAL("Exporting a polymorphic function with incomplete arguments");
if (i != paramTypes.size() - 1)
ret += ", ";
}
ret += ")";
return ret;
}
std::string
FunctionType::GetCDeclarationForDispatch(const std::string &fname) const {
@@ -3519,11 +3643,11 @@ FunctionType::GetCDeclarationForDispatch(const std::string &fname) const {
CastType<ArrayType>(pt->GetBaseType()) != NULL) {
type = new ArrayType(pt->GetBaseType(), 0);
}
// Change pointers to varying thingies to void *
if (pt != NULL && pt->GetBaseType()->IsVaryingType()) {
PointerType *t = PointerType::Void;
if (paramNames[i] != "")
ret += t->GetCDeclaration(paramNames[i]);
else
@@ -3655,10 +3779,10 @@ FunctionType::LLVMFunctionType(llvm::LLVMContext *ctx, bool removeMask) const {
llvmArgTypes.push_back(LLVMTypes::MaskType);
std::vector<llvm::Type *> callTypes;
if (isTask
if (isTask
#ifdef ISPC_NVPTX_ENABLED
&& (g->target->getISA() != Target::NVPTX)
#endif
#endif
){
// Tasks take three arguments: a pointer to a struct that holds the
// actual task arguments, the thread index, and the total number of
@@ -3956,7 +4080,8 @@ bool
Type::IsBasicType(const Type *type) {
return (CastType<AtomicType>(type) != NULL ||
CastType<EnumType>(type) != NULL ||
CastType<PointerType>(type) != NULL);
CastType<PointerType>(type) != NULL ||
CastType<PolyType>(type) != NULL);
}
@@ -4080,3 +4205,16 @@ bool
Type::EqualIgnoringConst(const Type *a, const Type *b) {
return lCheckTypeEquality(a, b, true);
}
bool
Type::EqualForReplacement(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a);
const PolyType *pb = CastType<PolyType>(b);
if (!pa || !pb)
return false;
return pa->restriction == pb->restriction &&
pa->GetQuant() == pb->GetQuant();
}

13
type.h
View File

@@ -244,6 +244,8 @@ public:
the same (ignoring const-ness of the type), false otherwise. */
static bool EqualIgnoringConst(const Type *a, const Type *b);
static bool EqualForReplacement(const Type *a, const Type *b);
/** Given two types, returns the least general Type that is more general
than both of them. (i.e. that can represent their values without
any loss of data.) If there is no such Type, return NULL.
@@ -360,10 +362,10 @@ public:
static const AtomicType *UniformDouble, *VaryingDouble;
static const AtomicType *Void;
AtomicType(BasicType basicType, Variability v, bool isConst);
private:
const Variability variability;
const bool isConst;
AtomicType(BasicType basicType, Variability v, bool isConst);
mutable const AtomicType *asOtherConstType, *asUniformType, *asVaryingType;
};
@@ -413,6 +415,10 @@ public:
const PolyRestriction restriction;
static const Type * ReplaceType(const Type *from, const Type *to);
static bool Less(const Type *a, const Type *b);
static const PolyType *UniformInteger, *VaryingInteger;
static const PolyType *UniformFloating, *VaryingFloating;
@@ -420,7 +426,8 @@ public:
// Returns the list of AtomicTypes that are valid instantiations of the
// polymorphic type
const std::vector<AtomicType *> GetEnumeratedTypes() const;
const std::vector<AtomicType *>::iterator ExpandBegin() const;
const std::vector<AtomicType *>::iterator ExpandEnd() const;
private:
const Variability variability;
@@ -430,6 +437,7 @@ private:
PolyType(PolyRestriction type, Variability v, bool isConst, int quant);
mutable const PolyType *asOtherConstType, *asUniformType, *asVaryingType;
mutable std::vector<AtomicType *> *expandedTypes;
};
@@ -980,6 +988,7 @@ public:
std::string GetString() const;
std::string Mangle() const;
std::string GetCDeclaration(const std::string &fname) const;
std::string GetCCall(const std::string &fname) const;
std::string GetCDeclarationForDispatch(const std::string &fname) const;
llvm::Type *LLVMType(llvm::LLVMContext *ctx) const;