From ddb9b2fc473e3579da61bbad90e00326a59fd3c0 Mon Sep 17 00:00:00 2001 From: Evghenii Date: Fri, 24 Jan 2014 13:44:38 +0100 Subject: [PATCH] added basic printing from ptx --- builtins.cpp | 1 + builtins/builtins.c | 71 ++++++++++++++++++++++++++++++++++++++++ builtins/target-nvptx.ll | 10 +++--- stmt.cpp | 3 +- 4 files changed, 79 insertions(+), 6 deletions(-) diff --git a/builtins.cpp b/builtins.cpp index a72d20f4..57ed3808 100644 --- a/builtins.cpp +++ b/builtins.cpp @@ -415,6 +415,7 @@ lSetInternalFunctions(llvm::Module *module) { "__do_assert_uniform", "__do_assert_varying", "__do_print", + "__do_print_nvptx", "__doublebits_uniform_int64", "__doublebits_varying_int64", "__exclusive_scan_add_double", diff --git a/builtins/builtins.c b/builtins/builtins.c index f6c385fb..ee34ff54 100644 --- a/builtins/builtins.c +++ b/builtins/builtins.c @@ -185,6 +185,77 @@ void __do_print(const char *format, const char *types, int width, uint64_t mask, fflush(stdout); } +/* this is print for PTX target only */ +int __puts_nvptx(const char *); +void __do_print_nvptx(const char *format, const char *types, int width, uint64_t mask, + void **args) { + char printString[PRINT_BUF_SIZE+1]; // +1 for trailing NUL + char *bufp = &printString[0]; + char tmpBuf[256]; + + int argCount = 0; + while (*format && bufp < &printString[PRINT_BUF_SIZE]) { + // Format strings are just single percent signs. + if (*format != '%') { + *bufp++ = *format; + } + else { + if (*types) { + void *ptr = args[argCount++]; + // Based on the encoding in the types string, cast the + // value appropriately and print it with a reasonable + // printf() formatting string. + switch (*types) { + case 'b': { + sprintf(tmpBuf, "%s", *((Bool *)ptr) ? "true" : "false"); + APPEND(tmpBuf); + break; + } + case 'B': { + *bufp++ = '['; + if (bufp == &printString[PRINT_BUF_SIZE]) + break; + for (int i = 0; i < width; ++i) { + if (mask & (1ull << i)) { + sprintf(tmpBuf, "%s", ((Bool *)ptr)[i] ? "true" : "false"); + APPEND(tmpBuf); + } + else + APPEND("_________"); + *bufp++ = (i != width-1) ? ',' : ']'; + } + break; + } + case 'i': PRINT_SCALAR("%d", int); + case 'I': PRINT_VECTOR("%d", int); + case 'u': PRINT_SCALAR("%u", unsigned int); + case 'U': PRINT_VECTOR("%u", unsigned int); + case 'f': PRINT_SCALAR("%f", float); + case 'F': PRINT_VECTOR("%f", float); + case 'l': PRINT_SCALAR("%lld", long long); + case 'L': PRINT_VECTOR("%lld", long long); + case 'v': PRINT_SCALAR("%llu", unsigned long long); + case 'V': PRINT_VECTOR("%llu", unsigned long long); + case 'd': PRINT_SCALAR("%f", double); + case 'D': PRINT_VECTOR("%f", double); + case 'p': PRINT_SCALAR("%p", void *); + case 'P': PRINT_VECTOR("%p", void *); + default: + APPEND("UNKNOWN TYPE "); + *bufp++ = *types; + } + ++types; + } + } + ++format; + } + + done: + *bufp = '\n'; bufp++; + *bufp = '\0'; + __puts_nvptx(printString); +} + int __num_cores() { #if defined(_MSC_VER) || defined(__MINGW32__) diff --git a/builtins/target-nvptx.ll b/builtins/target-nvptx.ll index 4f523e54..abe2bbe7 100644 --- a/builtins/target-nvptx.ll +++ b/builtins/target-nvptx.ll @@ -1396,11 +1396,11 @@ define i32 @__puts_nvptx(i8*) alwaysinline %str = ptrtoint i8* %0 to i64 %parm = or i64 0, 0 %call = call i32 @vprintf(i64 %str, i64 %parm) - %cr = alloca <2 x i8> - store <2 x i8> , <2 x i8>* %cr - %cr1 = ptrtoint <2 x i8>* %cr to i64 - %call1 = call i32 @vprintf(i64 %cr1, i64 %parm) - ret i32 %call1; +;; %cr = alloca <3 x i8> +;; store <3 x i8> , <3 x i8>* %cr +;; %cr1 = ptrtoint <3 x i8>* %cr to i64 +;; %call1 = call i32 @vprintf(i64 %cr1, i64 %parm) + ret i32 %call; } define void @__abort_nvptx(i8* %str) noreturn { diff --git a/stmt.cpp b/stmt.cpp index 4c9a8f05..327bb4fc 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -3352,7 +3352,8 @@ PrintStmt::EmitCode(FunctionEmitContext *ctx) const { } // Now we can emit code to call __do_print() - llvm::Function *printFunc = m->module->getFunction("__do_print"); + llvm::Function *printFunc = g->target->getISA() != Target::NVPTX ? + m->module->getFunction("__do_print") : m->module->getFunction("__do_print_nvptx"); AssertPos(pos, printFunc); llvm::Value *mask = ctx->GetFullMask();