12 Commits

19 changed files with 657 additions and 143 deletions

1
.gitignore vendored
View File

@@ -13,6 +13,7 @@ tests*/*cpp
tests*/*run
tests*/*.o
tests_ispcpp/*.h
tests_ispcpp/*.out
tests_ispcpp/*pre*
logs/
notify_log.log

33
ast.cpp
View File

@@ -42,8 +42,11 @@
#include "func.h"
#include "stmt.h"
#include "sym.h"
#include "type.h"
#include "util.h"
#include <map>
///////////////////////////////////////////////////////////////////////////
// ASTNode
@@ -55,14 +58,18 @@ ASTNode::~ASTNode() {
// AST
void
AST::AddFunction(Symbol *sym, Stmt *code) {
AST::AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable) {
if (sym == NULL)
return;
Function *f = new Function(sym, code);
if (f->IsPolyFunction()) {
FATAL("This is a good start, but implement me!");
std::vector<Function *> *expanded = f->ExpandPolyArguments(symbolTable);
for (size_t i=0; i<expanded->size(); i++) {
functions.push_back((*expanded)[i]);
}
delete expanded;
} else {
functions.push_back(f);
}
@@ -515,3 +522,25 @@ SafeToRunWithMaskAllOff(ASTNode *root) {
WalkAST(root, lCheckAllOffSafety, NULL, &safe);
return safe;
}
struct PolyData {
const PolyType *polyType;
const Type *replacement;
};
static ASTNode *
lTranslatePolyNode(ASTNode *node, void *d) {
struct PolyData *data = (struct PolyData*)d;
return node->ReplacePolyType(data->polyType, data->replacement);
}
ASTNode *
TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) {
struct PolyData data;
data.polyType = polyType;
data.replacement = replacement;
return WalkAST(root, NULL, lTranslatePolyNode, &data);
}

7
ast.h
View File

@@ -39,6 +39,7 @@
#define ISPC_AST_H 1
#include "ispc.h"
#include "type.h"
#include <vector>
/** @brief Abstract base class for nodes in the abstract syntax tree (AST).
@@ -67,6 +68,8 @@ public:
pointer in place of the original ASTNode *. */
virtual ASTNode *TypeCheck() = 0;
virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0;
/** Estimate the execution cost of the node (not including the cost of
the children. The value returned should be based on the COST_*
enumerant values defined in ispc.h. */
@@ -145,7 +148,7 @@ class AST {
public:
/** Add the AST for a function described by the given declaration
information and source code. */
void AddFunction(Symbol *sym, Stmt *code);
void AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable=NULL);
/** Generate LLVM IR for all of the functions into the current
module. */
@@ -204,6 +207,8 @@ extern Stmt *TypeCheck(Stmt *);
the given root. */
extern int EstimateCost(ASTNode *root);
extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement);
/** Returns true if it would be safe to run the given code with an "all
off" mask. */
extern bool SafeToRunWithMaskAllOff(ASTNode *root);

View File

@@ -112,6 +112,11 @@ Expr::GetBaseSymbol() const {
return NULL;
}
Expr *
Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this;
}
#if 0
/** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost
@@ -4669,6 +4674,23 @@ IndexExpr::TypeCheck() {
return this;
}
Expr *
IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (index == NULL || baseExpr == NULL)
return NULL;
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
lvalueType = new PointerType(to, lvalueType->GetVariability(),
lvalueType->IsConstType());
}
return this;
}
int
IndexExpr::EstimateCost() const {
@@ -5311,6 +5333,23 @@ MemberExpr::Optimize() {
return expr ? this : NULL;
}
Expr *
MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL)
return NULL;
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
lvalueType = PolyType::ReplaceType(lvalueType, lvalueType);
}
return this;
}
int
MemberExpr::EstimateCost() const {
@@ -7113,6 +7152,9 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const {
else {
const AtomicType *toAtomic = CastType<AtomicType>(toType);
// typechecking should ensure this is the case
if (!toAtomic) {
fprintf(stderr, "I want %s to be atomic\n", toType->GetString().c_str());
}
AssertPos(pos, toAtomic != NULL);
return lTypeConvAtomic(ctx, exprVal, toAtomic, fromAtomic, pos);
@@ -7342,6 +7384,18 @@ TypeCastExpr::Optimize() {
return this;
}
Expr *
TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (type == NULL)
return NULL;
if (Type::EqualForReplacement(type->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
return this;
}
int
TypeCastExpr::EstimateCost() const {
@@ -8012,6 +8066,18 @@ SymbolExpr::Optimize() {
return this;
}
Expr *
SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!symbol)
return NULL;
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
symbol->type = PolyType::ReplaceType(symbol->type, to);
}
return this;
}
int
SymbolExpr::EstimateCost() const {
@@ -8810,6 +8876,18 @@ NewExpr::Optimize() {
return this;
}
Expr *
NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!allocType)
return this;
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
allocType = PolyType::ReplaceType(allocType, to);
}
return this;
}
void
NewExpr::Print() const {

9
expr.h
View File

@@ -96,6 +96,10 @@ public:
encountered, NULL should be returned. */
virtual Expr *TypeCheck() = 0;
/** This method replaces a polymorphic type with a specific atomic type */
Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement);
/** Prints the expression to standard output (used for debugging). */
virtual void Print() const = 0;
};
@@ -324,6 +328,7 @@ public:
Expr *Optimize();
Expr *TypeCheck();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
Expr *baseExpr, *index;
@@ -357,6 +362,7 @@ public:
void Print() const;
Expr *Optimize();
Expr *TypeCheck();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
virtual int getElementNumber() const = 0;
@@ -522,6 +528,7 @@ public:
void Print() const;
Expr *TypeCheck();
Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
Symbol *GetBaseSymbol() const;
llvm::Constant *GetConstant(const Type *type) const;
@@ -681,6 +688,7 @@ public:
Symbol *GetBaseSymbol() const;
Expr *TypeCheck();
Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const;
int EstimateCost() const;
@@ -809,6 +817,7 @@ public:
const Type *GetType() const;
Expr *TypeCheck();
Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const;
int EstimateCost() const;

View File

@@ -45,6 +45,7 @@
#include "sym.h"
#include "util.h"
#include <stdio.h>
#include <set>
#if ISPC_LLVM_VERSION == ISPC_LLVM_3_2 // 3.2
#ifdef ISPC_NVPTX_ENABLED
@@ -638,3 +639,46 @@ Function::IsPolyFunction() const {
return false;
}
std::vector<Function *> *
Function::ExpandPolyArguments(SymbolTable *symbolTable) const {
Assert(symbolTable != NULL);
std::vector<Function *> *expanded = new std::vector<Function *>();
std::vector<Symbol *> versions = symbolTable->LookupPolyFunction(sym->name.c_str());
const FunctionType *func = CastType<FunctionType>(sym->type);
printf("%s before replacing anything:\n", sym->name.c_str());
code->Print(0);
for (size_t i=0; i<versions.size(); i++) {
const FunctionType *ft = CastType<FunctionType>(versions[i]->type);
Stmt *ncode = code;
for (int j=0; j<ft->GetNumParameters(); j++) {
if (func->GetParameterType(j)->IsPolymorphicType()) {
const PolyType *from = CastType<PolyType>(
func->GetParameterType(j)->GetBaseType());
ncode = (Stmt*)TranslatePoly(ncode, from,
ft->GetParameterType(j)->GetBaseType());
printf("%s after replacing %s with %s:\n\n",
sym->name.c_str(), from->GetString().c_str(),
ft->GetParameterType(j)->GetBaseType()->GetString().c_str());
ncode->Print(0);
printf("------------------------------------------\n\n");
}
}
Symbol *s = symbolTable->LookupFunction(versions[i]->name.c_str(), ft);
expanded->push_back(new Function(s, ncode));
}
return expanded;
}

3
func.h
View File

@@ -39,6 +39,7 @@
#define ISPC_FUNC_H 1
#include "ispc.h"
#include "sym.h"
#include <vector>
class Function {
@@ -54,6 +55,8 @@ public:
/** Checks if the function has polymorphic parameters */
const bool IsPolyFunction() const;
std::vector<Function *> *ExpandPolyArguments(SymbolTable *symbolTable) const;
private:
void emitCode(FunctionEmitContext *ctx, llvm::Function *function,
SourcePos firstStmtPos);

View File

@@ -1009,6 +1009,91 @@ Module::AddFunctionDeclaration(const std::string &name,
}
}
/* Handle Polymorphic functions
* a function
* int foo(number n, floating, f)
* will produce versions such as
* int foo(int n, float f)
*
* these functions will be overloaded if they are not exported, or mangled
* if exported */
std::set<const Type *, bool(*)(const Type*, const Type*)> toExpand(&PolyType::Less);
std::vector<const FunctionType *> expanded;
expanded.push_back(functionType);
for (int i=0; i<functionType->GetNumParameters(); i++) {
const Type *param = functionType->GetParameterType(i);
if (param->IsPolymorphicType() &&
!toExpand.count(param->GetBaseType())) {
toExpand.insert(param->GetBaseType());
}
}
std::vector<const FunctionType *> nextExpanded;
std::set<const Type*>::iterator iter;
for (iter = toExpand.begin(); iter != toExpand.end(); iter++) {
for (size_t j=0; j<expanded.size(); j++) {
const FunctionType *eft = expanded[j];
const PolyType *pt=CastType<PolyType>(*iter);
std::vector<AtomicType *>::iterator te;
for (te = pt->ExpandBegin(); te != pt->ExpandEnd(); te++) {
llvm::SmallVector<const Type *, 8> nargs;
llvm::SmallVector<std::string, 8> nargsn;
llvm::SmallVector<Expr *, 8> nargsd;
llvm::SmallVector<SourcePos, 8> nargsp;
for (size_t k=0; k<eft->GetNumParameters(); k++) {
if (Type::Equal(eft->GetParameterType(k)->GetBaseType(),
pt)) {
const Type *r;
r = PolyType::ReplaceType(eft->GetParameterType(k),*te);
nargs.push_back(r);
} else {
nargs.push_back(eft->GetParameterType(k));
}
nargsn.push_back(eft->GetParameterName(k));
nargsd.push_back(eft->GetParameterDefault(k));
nargsp.push_back(eft->GetParameterSourcePos(k));
}
nextExpanded.push_back(new FunctionType(eft->GetReturnType(),
nargs,
nargsn,
nargsd,
nargsp,
eft->isTask,
eft->isExported,
eft->isExternC,
eft->isUnmasked));
}
}
expanded.swap(nextExpanded);
nextExpanded.clear();
}
if (expanded.size() > 1) {
for (size_t i=0; i<expanded.size(); i++) {
std::string nname = name;
if (functionType->isExported || functionType->isExternC) {
for (int j=0; j<expanded[i]->GetNumParameters(); j++) {
nname += "_";
nname += expanded[i]->GetParameterType(j)->Mangle();
}
}
symbolTable->MapPolyFunction(name, nname, expanded[i]);
AddFunctionDeclaration(nname, expanded[i], storageClass,
isInline, pos);
}
return;
}
// Get the LLVM FunctionType
bool disableMask = (storageClass == SC_EXTERN_C);
llvm::FunctionType *llvmFunctionType =
@@ -1177,14 +1262,7 @@ Module::AddFunctionDefinition(const std::string &name, const FunctionType *type,
sym->pos = code->pos;
// FIXME: because we encode the parameter names in the function type,
// we need to override the function type here in case the function had
// earlier been declared with anonymous parameter names but is now
// defined with actual names. This is yet another reason we shouldn't
// include the names in FunctionType...
sym->type = type;
ast->AddFunction(sym, code);
ast->AddFunction(sym, code, symbolTable);
}
@@ -1899,6 +1977,27 @@ lPrintFunctionDeclarations(FILE *file, const std::vector<Symbol *> &funcs,
// fprintf(file, "#ifdef __cplusplus\n} /* end extern C */\n#endif // __cplusplus\n");
}
static void
lPrintPolyFunctionWrappers(FILE *file, const std::vector<std::string> &funcs) {
fprintf(file, "#if defined(__cplusplus)\n");
for (size_t i=0; i<funcs.size(); i++) {
std::vector<Symbol *> poly = m->symbolTable->LookupPolyFunction(funcs[i].c_str());
for (size_t j=0; j<poly.size(); j++) {
const FunctionType *ftype = CastType<FunctionType>(poly[j]->type);
Assert(ftype);
std::string decl = ftype->GetCDeclaration(funcs[i]);
fprintf(file, " %s {\n", decl.c_str());
std::string call = ftype->GetCCall(poly[j]->name);
fprintf(file, " return %s;\n }\n", call.c_str());
}
}
fprintf(file, "#endif // __cplusplus\n");
}
@@ -2275,8 +2374,10 @@ Module::writeHeader(const char *fn) {
// Collect single linear arrays of the exported and extern "C"
// functions
std::vector<Symbol *> exportedFuncs, externCFuncs;
std::vector<std::string> polyFuncs;
m->symbolTable->GetMatchingFunctions(lIsExported, &exportedFuncs);
m->symbolTable->GetMatchingFunctions(lIsExternC, &externCFuncs);
m->symbolTable->GetPolyFunctions(&polyFuncs);
// Get all of the struct, vector, and enumerant types used as function
// parameters. These vectors may have repeats.
@@ -2313,6 +2414,16 @@ Module::writeHeader(const char *fn) {
fprintf(f, "///////////////////////////////////////////////////////////////////////////\n");
lPrintFunctionDeclarations(f, exportedFuncs);
}
// emit wrappers for polymorphic functions
if (polyFuncs.size() > 0) {
fprintf(f, "\n");
fprintf(f, "///////////////////////////////////////////////////////////////////////////\n");
fprintf(f, "// Polymorphic function wrappers\n");
fprintf(f, "///////////////////////////////////////////////////////////////////////////\n");
lPrintPolyFunctionWrappers(f, polyFuncs);
}
#if 0
if (externCFuncs.size() > 0) {
fprintf(f, "\n");

View File

@@ -77,6 +77,11 @@ Stmt::Optimize() {
return this;
}
Stmt *
Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this;
}
///////////////////////////////////////////////////////////////////////////
// ExprStmt
@@ -494,6 +499,18 @@ DeclStmt::TypeCheck() {
return encounteredError ? NULL : this;
}
Stmt *
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
for (size_t i = 0; i < vars.size(); i++) {
Symbol *s = vars[i].sym;
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
s->type = PolyType::ReplaceType(s->type, to);
}
}
return this;
}
void
DeclStmt::Print(int indent) const {
@@ -2174,6 +2191,21 @@ ForeachStmt::TypeCheck() {
return anyErrors ? NULL : this;
}
Stmt *
ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) {
if (!stmts)
return NULL;
for (size_t i=0; i<dimVariables.size(); i++) {
const Type *t = dimVariables[i]->type;
if (Type::EqualForReplacement(t->GetBaseType(), from)) {
t = PolyType::ReplaceType(t, to);
}
}
return this;
}
int
ForeachStmt::EstimateCost() const {

3
stmt.h
View File

@@ -70,6 +70,7 @@ public:
// Stmts don't have anything to do here.
virtual Stmt *Optimize();
virtual Stmt *TypeCheck() = 0;
Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement);
};
@@ -117,6 +118,7 @@ public:
Stmt *Optimize();
Stmt *TypeCheck();
Stmt *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
std::vector<VariableDeclaration> vars;
@@ -281,6 +283,7 @@ public:
void Print(int indent) const;
Stmt *TypeCheck();
Stmt *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const;
std::vector<Symbol *> dimVariables;

27
sym.cpp
View File

@@ -157,6 +157,14 @@ SymbolTable::AddFunction(Symbol *symbol) {
return true;
}
void
SymbolTable::MapPolyFunction(std::string name, std::string polyname,
const FunctionType *type) {
std::vector<Symbol *> &polyExpansions = polyFunctions[name];
SourcePos p;
polyExpansions.push_back(new Symbol(polyname, p, type, SC_NONE));
}
bool
SymbolTable::LookupFunction(const char *name, std::vector<Symbol *> *matches) {
@@ -184,9 +192,28 @@ SymbolTable::LookupFunction(const char *name, const FunctionType *type) {
return funcs[j];
}
}
// Try looking for a polymorphic function
if (polyFunctions[name].size() > 0) {
std::string n = name;
return new Symbol(name, polyFunctions[name][0]->pos, type);
}
return NULL;
}
std::vector<Symbol *>&
SymbolTable::LookupPolyFunction(const char *name) {
return polyFunctions[name];
}
void
SymbolTable::GetPolyFunctions(std::vector<std::string> *funcs) {
FunctionMapType::iterator it = polyFunctions.begin();
for (; it != polyFunctions.end(); it++) {
funcs->push_back(it->first);
}
}
bool
SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) {

15
sym.h
View File

@@ -108,6 +108,7 @@ public:
};
/** @brief Symbol table that holds all known symbols during parsing and compilation.
A single instance of a SymbolTable is stored in the Module class
@@ -159,6 +160,14 @@ public:
already present in the symbol table. */
bool AddFunction(Symbol *symbol);
/** Adds the given function to the list of polymorphic definitions for the
given name
@param name The name of the original function
@param type The expanded FunctionType */
void MapPolyFunction(std::string name, std::string polyname,
const FunctionType *type);
/** Looks for the function or functions with the given name in the
symbol name. If a function has been overloaded and multiple
definitions are present for a given function name, all of them will
@@ -174,6 +183,10 @@ public:
@return pointer to matching Symbol; NULL if none is found. */
Symbol *LookupFunction(const char *name, const FunctionType *type);
std::vector<Symbol *>& LookupPolyFunction(const char *name);
void GetPolyFunctions(std::vector<std::string> *funcs);
/** Returns all of the functions in the symbol table that match the given
predicate.
@@ -276,6 +289,8 @@ private:
typedef std::map<std::string, std::vector<Symbol *> > FunctionMapType;
FunctionMapType functions;
FunctionMapType polyFunctions;
/** Type definitions can't currently be scoped.
*/
typedef std::map<std::string, const Type *> TypeMapType;

13
tests_ispcpp/Makefile Normal file
View File

@@ -0,0 +1,13 @@
CXX=g++
CXXFLAGS=-std=c++11
ISPC=../ispc
ISPCFLAGS=--target=sse4-x2 -O2 --arch=x86-64
%.out : %.cpp %.o
$(CXX) $(CXXFLAGS) -o $@ $^
$ : $.o
%.o : %.ispc
$(ISPC) $(ISPCFLAGS) -h $*.h -o $*.o $<

View File

@@ -1,4 +1,4 @@
export void foo(uniform int N, floating$1 X[])
export void foo(uniform int N, uniform floating$1 X[])
{
foreach (i = 0 ... N) {
X[i] = X[i] + 1.0;

135
type.cpp
View File

@@ -249,6 +249,15 @@ Type::IsVoidType() const {
bool
Type::IsPolymorphicType() const {
const FunctionType *ft = CastType<FunctionType>(this);
if (ft) {
for (int i=0; i<ft->GetNumParameters(); i++) {
if (ft->GetParameterType(i)->IsPolymorphicType())
return true;
}
return false;
}
return (CastType<PolyType>(GetBaseType()) != NULL);
}
@@ -694,16 +703,68 @@ const PolyType *PolyType::UniformNumber =
const PolyType *PolyType::VaryingNumber =
new PolyType(PolyType::TYPE_NUMBER, Variability::Varying, false);
const Type *
PolyType::ReplaceType(const Type *from, const Type *to) {
const Type *t = to;
if (from->IsPointerType()) {
t = new PointerType(to,
from->GetVariability(),
from->IsConstType());
} else if (from->IsArrayType()) {
t = new ArrayType(to,
CastType<ArrayType>(from)->GetElementCount());
} else if (from->IsReferenceType()) {
t = new ReferenceType(to);
}
if (from->IsVaryingType())
t = t->GetAsVaryingType();
fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n",
from->GetString().c_str(),
t->GetString().c_str());
return t;
}
bool
PolyType::Less(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
if (!pa || !pb) {
char buf[1024];
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
"\"%s\" and \"%s\"\n",
a->GetString().c_str(), b->GetString().c_str());
FATAL(buf);
}
if (pa->restriction < pb->restriction)
return true;
if (pa->restriction > pb->restriction)
return false;
if (pa->GetQuant() < pb->GetQuant())
return true;
return false;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic, int q)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(q) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
@@ -814,6 +875,39 @@ PolyType::GetAsUniformType() const {
return asUniformType;
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandBegin() const {
if (expandedTypes)
return expandedTypes->begin();
expandedTypes = new std::vector<AtomicType *>();
if (restriction == TYPE_INTEGER || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT64, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT64, variability, isConst));
}
if (restriction == TYPE_FLOATING || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_FLOAT, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_DOUBLE, variability, isConst));
}
return expandedTypes->begin();
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandEnd() const {
Assert(expandedTypes != NULL);
return expandedTypes->end();
}
const PolyType *
PolyType::GetAsUnboundVariabilityType() const {
@@ -3500,6 +3594,34 @@ FunctionType::GetCDeclaration(const std::string &fname) const {
return ret;
}
std::string
FunctionType::GetCCall(const std::string &fname) const {
std::string ret;
ret += fname;
ret += "(";
for (unsigned int i = 0; i < paramTypes.size(); ++i) {
const Type *type = paramTypes[i];
// Convert pointers to arrays to unsized arrays, which are more clear
// to print out for multidimensional arrays (i.e. "float foo[][4] "
// versus "float (foo *)[4]").
const PointerType *pt = CastType<PointerType>(type);
if (pt != NULL &&
CastType<ArrayType>(pt->GetBaseType()) != NULL) {
type = new ArrayType(pt->GetBaseType(), 0);
}
if (paramNames[i] != "")
ret += paramNames[i];
else
FATAL("Exporting a polymorphic function with incomplete arguments");
if (i != paramTypes.size() - 1)
ret += ", ";
}
ret += ")";
return ret;
}
std::string
FunctionType::GetCDeclarationForDispatch(const std::string &fname) const {
@@ -4080,3 +4202,16 @@ bool
Type::EqualIgnoringConst(const Type *a, const Type *b) {
return lCheckTypeEquality(a, b, true);
}
bool
Type::EqualForReplacement(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a);
const PolyType *pb = CastType<PolyType>(b);
if (!pa || !pb)
return false;
return pa->restriction == pb->restriction &&
pa->GetQuant() == pb->GetQuant();
}

13
type.h
View File

@@ -244,6 +244,8 @@ public:
the same (ignoring const-ness of the type), false otherwise. */
static bool EqualIgnoringConst(const Type *a, const Type *b);
static bool EqualForReplacement(const Type *a, const Type *b);
/** Given two types, returns the least general Type that is more general
than both of them. (i.e. that can represent their values without
any loss of data.) If there is no such Type, return NULL.
@@ -360,10 +362,10 @@ public:
static const AtomicType *UniformDouble, *VaryingDouble;
static const AtomicType *Void;
AtomicType(BasicType basicType, Variability v, bool isConst);
private:
const Variability variability;
const bool isConst;
AtomicType(BasicType basicType, Variability v, bool isConst);
mutable const AtomicType *asOtherConstType, *asUniformType, *asVaryingType;
};
@@ -413,6 +415,10 @@ public:
const PolyRestriction restriction;
static const Type * ReplaceType(const Type *from, const Type *to);
static bool Less(const Type *a, const Type *b);
static const PolyType *UniformInteger, *VaryingInteger;
static const PolyType *UniformFloating, *VaryingFloating;
@@ -420,7 +426,8 @@ public:
// Returns the list of AtomicTypes that are valid instantiations of the
// polymorphic type
const std::vector<AtomicType *> GetEnumeratedTypes() const;
const std::vector<AtomicType *>::iterator ExpandBegin() const;
const std::vector<AtomicType *>::iterator ExpandEnd() const;
private:
const Variability variability;
@@ -430,6 +437,7 @@ private:
PolyType(PolyRestriction type, Variability v, bool isConst, int quant);
mutable const PolyType *asOtherConstType, *asUniformType, *asVaryingType;
mutable std::vector<AtomicType *> *expandedTypes;
};
@@ -980,6 +988,7 @@ public:
std::string GetString() const;
std::string Mangle() const;
std::string GetCDeclaration(const std::string &fname) const;
std::string GetCCall(const std::string &fname) const;
std::string GetCDeclarationForDispatch(const std::string &fname) const;
llvm::Type *LLVMType(llvm::LLVMContext *ctx) const;