From c548c0ce5b2e7ca2784257966ebdd386e1f31218 Mon Sep 17 00:00:00 2001 From: Wojciech Kosior Date: Tue, 22 Sep 2020 16:00:53 +0200 Subject: perform type checking of translated instructions --- tools/parse_module.c | 172 +++++++++++++------------ tools/translate.c | 318 ++++++++++++++++++++++++++++++++++++++++++----- tools/translate_xmacro.h | 31 ++--- tools/wasm_compile.h | 10 +- 4 files changed, 403 insertions(+), 128 deletions(-) diff --git a/tools/parse_module.c b/tools/parse_module.c index 155fc9c..75329f8 100644 --- a/tools/parse_module.c +++ b/tools/parse_module.c @@ -68,6 +68,19 @@ int leb_u32(FILE *handle, uint32_t *result) return 0; } +void deinitialize_resulttype(struct resulttype *type) +{ + type->count = 0; + free(type->types); + type->types = NULL; +} + +void deinitialize_functype(struct functype *type) +{ + deinitialize_resulttype(&type->args); + deinitialize_resulttype(&type->results); +} + void free_module(struct module *module) { size_t i; @@ -76,7 +89,7 @@ void free_module(struct module *module) return; for (i = 0; i < module->functypes_count; i++) - free(module->functypes[i].arguments); + deinitialize_functype(module->functypes + i); free(module->functypes); @@ -113,37 +126,30 @@ static inline int safe_mul(size_t *factor1, uint32_t factor2) return 0; } -int parse_type_section(FILE *handle, struct module *module) +int parse_resulttype(FILE *handle, struct resulttype *type) { - uint32_t types_count; int readval; - size_t malloc_size; - struct functype *types = NULL; - uint32_t types_parsed = 0; - uint32_t args_count; - char *args; uint32_t i; - if (leb_u32(handle, &types_count)) { + if (leb_u32(handle, &type->count)) { PRERR(MSG_BAD_NUM); goto fail; } - malloc_size = sizeof(struct functype); + if (type->count) { + type->types = malloc(type->count); - if (safe_mul(&malloc_size, types_count)) { - PRERR(MSG_BAD_SIZE); - goto fail; + if (!type->types) { + PRERR(MSG_ALLOC_FAIL(type->count)); + goto fail; + } + } else { + type->types = NULL; } - types = malloc(malloc_size); - - if (!types) { - PRERR(MSG_ALLOC_FAIL(malloc_size)); - goto fail; - } + i = type->count; - while (types_parsed < types_count) { + while (i--) { readval = fgetc(handle); if (readval == EOF) { @@ -151,77 +157,78 @@ int parse_type_section(FILE *handle, struct module *module) goto fail; } - if (readval != 0x60) { - PRERR(MSG_BAD("functype starting byte (0x60)", - readval)); + if (!is_valid_valtype(readval)) { + PRERR(MSG_BAD("value type encoding", readval)); goto fail; } - if (leb_u32(handle, &args_count)) { - PRERR(MSG_BAD_NUM); - goto fail; - } + type->types[i] = readval; + } - if (args_count) { - args = malloc(args_count); + return 0; - if (!args) { - PRERR(MSG_ALLOC_FAIL(args_count)); - goto fail; - } - } else { - args = NULL; - } +fail: + deinitialize_resulttype(type); - types[types_parsed].arguments_count = args_count; - types[types_parsed].arguments = args; - /* Increment here, so that jump to fail: frees the args */ - types_parsed++; + return -1; +} - for (i = 0; i < args_count; i++) { - readval = fgetc(handle); +int parse_functype(FILE *handle, struct functype *type) +{ + int readval; - if (readval == EOF) { - PRERR(MSG_EOF); - goto fail; - } + readval = fgetc(handle); - if (!is_valid_valtype(readval)) { - PRERR(MSG_BAD("value type encoding", readval)); - goto fail; - } + if (readval == EOF) { + PRERR(MSG_EOF); + return -1; + } - args[i] = readval; - } + if (readval != 0x60) { + PRERR(MSG_BAD("functype starting byte (0x60)", readval)); + return -1; + } - readval = fgetc(handle); + if (parse_resulttype(handle, &type->args) || + parse_resulttype(handle, &type->results)) { + deinitialize_functype(type); + return -1; + } - if (readval == EOF) { - PRERR(MSG_EOF); - goto fail; - } + return 0; +} - if (readval == 0x00) { - types[types_parsed - 1].result = 0; - } else if (readval == 0x01) { +int parse_type_section(FILE *handle, struct module *module) +{ + uint32_t types_count; + size_t malloc_size; + struct functype *types = NULL; + uint32_t types_parsed = 0; - readval = fgetc(handle); + if (leb_u32(handle, &types_count)) { + PRERR(MSG_BAD_NUM); + goto fail; + } - if (readval == EOF) { - PRERR(MSG_EOF); - goto fail; - } + malloc_size = sizeof(struct functype); - if (!is_valid_valtype(readval)) { - PRERR(MSG_BAD("value type encoding", readval)); - goto fail; - } + if (safe_mul(&malloc_size, types_count)) { + PRERR(MSG_BAD_SIZE); + goto fail; + } - types[types_parsed - 1].result = readval; - } else { - PRERR(MSG_BAD("return values count", readval)); + types = malloc(malloc_size); + + if (!types) { + PRERR(MSG_ALLOC_FAIL(malloc_size)); + goto fail; + } + + while (types_parsed < types_count) { + if (parse_functype(handle, types + types_parsed)) goto fail; - } + + types_parsed++; } module->functypes_count = types_count; @@ -233,10 +240,8 @@ fail: PRERR("Couldn't parse function types section\n"); if (types) { - while (types_parsed) { - free(types[types_parsed - 1].arguments); - types_parsed--; - } + while (types_parsed--) + deinitialize_functype(types + types_parsed); free(types); } @@ -282,7 +287,12 @@ int parse_function_section(FILE *handle, struct module *module) goto fail; } - funcs[i].type = module->functypes + i; + if (module->functypes[type_idx].results.count > 1) { + PRERR("Function returning more than one value\n"); + goto fail; + } + + funcs[i].type = module->functypes + type_idx; funcs[i].translated_body = NULL; } @@ -444,7 +454,6 @@ static int parse_function_code(FILE *handle, struct function *function, uint32_t locals_blocks; uint32_t locals_count = 0; char *locals = NULL, *tmp; - char *body = NULL; uint32_t i; uint32_t locals_in_block; @@ -459,7 +468,8 @@ static int parse_function_code(FILE *handle, struct function *function, goto fail; } - if (locals_count + (uint64_t) locals_in_block > UINT32_MAX) { + if (locals_count + (uint64_t) locals_in_block + + (uint64_t) function->type->args.count > UINT32_MAX) { PRERR("Too many locals\n"); goto fail; } @@ -505,7 +515,6 @@ static int parse_function_code(FILE *handle, struct function *function, fail: free(locals); - free(body); return -1; } @@ -565,6 +574,7 @@ fail: free(module->functions[functions_parsed - 1].locals); free_expr(module->functions[functions_parsed - 1] .translated_body); + module->functions[functions_parsed - 1].translated_body = NULL; functions_parsed--; } diff --git a/tools/translate.c b/tools/translate.c index 97d48f7..b3d2afd 100644 --- a/tools/translate.c +++ b/tools/translate.c @@ -2,20 +2,42 @@ #include "wasm.h" #include "stack_machine_instruction.h" +struct types { + struct types *prev; + char type; /* should be one of VALTYPE_* constants from wasm.h */ +}; + struct translation { FILE *handle; struct function *function; struct module *module; + struct types *types_stack; }; -/* All functions, that go into function pointer array, start with _ */ +void free_types_stack(struct types *top) +{ + struct types *tmp; + + while (top) { + tmp = top->prev; + free(top); + top = tmp; + } +} + +/* All functions, that go into one of function pointer arrays, start with _ */ + +/** DEFINE TRANSLATION FUNCTIONS **/ /* Translate complex - those routines have to be defined manually */ -#define TC(wasm_opcode, name) +#define TC(wasm_opcode, name, argtypes, restype) + +static int typecheck_call(struct translation *data, struct function *callee); static int _translate_call(struct translation *data) { uint32_t funcidx; + struct function *func; struct instruction **target; if (leb_u32(data->handle, &funcidx)) { @@ -28,15 +50,22 @@ static int _translate_call(struct translation *data) return -1; } - target = &data->module->functions[funcidx].translated_body; + func = data->module->functions + funcidx; + + if (typecheck_call(data, func)) + return -1; + + target = &func->translated_body; return i_call(ptr(target), &data->function->translated_body); } +static int typecheck_local_get(struct translation *data, uint32_t localidx); + static int _translate_local_get(struct translation *data) { uint32_t localidx; - uint32_t args_count = data->function->type->arguments_count; + uint32_t args_count = data->function->type->args.count; uint32_t locals_count = data->function->locals_count; uint32_t all_locals_count = args_count + locals_count; uint64_t offset_on_frame; @@ -52,14 +81,19 @@ static int _translate_local_get(struct translation *data) return -1; } + if (typecheck_local_get(data, localidx)) + return -1; + offset_on_frame = all_locals_count - localidx + 1; if (localidx >= args_count) offset_on_frame -= 1; - return - i_load (im(STACK_FRAME_BACKUP_ADDR), expr) || - i_load_p(im(4 * offset_on_frame), expr); + if (i_load (im(STACK_FRAME_BACKUP_ADDR), expr) || + i_load_p(im(4 * offset_on_frame), expr)) + return -1; + + return 0; } static int _translate_const(struct translation *data) @@ -75,7 +109,7 @@ static int _translate_const(struct translation *data) } /* Translate Simple */ -#define TS(wasm_opcode, sm_instr) \ +#define TS(wasm_opcode, sm_instr, argtypes, restype) \ static int _translate_##sm_instr(struct translation *data) \ { \ return i_##sm_instr(&data->function->translated_body); \ @@ -99,7 +133,7 @@ static int translate_load_store(struct translation *data, return instr_routine(im(offset), &data->function->translated_body); } -#define TLS(wasm_opcode, sm_instr) \ +#define TLS(wasm_opcode, sm_instr, argtypes, restype) \ static int _translate_##sm_instr(struct translation *data) \ { \ return translate_load_store(data, i_##sm_instr); \ @@ -112,14 +146,19 @@ static int translate_load_store(struct translation *data, #undef TLS #undef TC +/** DEFINE TRANSLATION FUNCTIONS POINTER ARRAY **/ + /* Translate complex */ -#define TC(wasm_opcode, name) [wasm_opcode] = _translate_##name, +#define TC(wasm_opcode, name, argtypes, restype) \ + [wasm_opcode] = _translate_##name, /* Translate Simple */ -#define TS(wasm_opcode, sm_instr) TC(wasm_opcode, sm_instr) +#define TS(wasm_opcode, sm_instr, argtypes, restype) \ + TC(wasm_opcode, sm_instr, dummy, dummy) /* Translate load/store */ -#define TLS(wasm_opcode, sm_instr) TC(wasm_opcode, sm_instr) +#define TLS(wasm_opcode, sm_instr, argtypes, restype) \ + TC(wasm_opcode, sm_instr, dummy, dummy) /* The actual array of function pointers is defined here */ static int (*translation_routines[256]) (struct translation *) = { @@ -130,42 +169,262 @@ static int (*translation_routines[256]) (struct translation *) = { #undef TLS #undef TC -static int translate_expr(struct translation *data, struct functype *exprtype) +/** DEFINE ARGUMENT TYPECHECK FUNCTIONS **/ + +static int _argcheck_empty(struct types **types_stack) +{ + return 0; +} + +static int argcheck_generic_noremove(struct types *types_stack, char expected) +{ + char *name; + + name = + expected == VALTYPE_F64 ? "f64" : + expected == VALTYPE_F32 ? "f32" : + expected == VALTYPE_I64 ? "i64" : + "i32"; + + if (!types_stack) { + PRERR("Expected %s on stack, got nothing\n", name); + return -1; + } + + if (types_stack->type != VALTYPE_I32) { + PRERR("Expected %s (0x%02hhx) on stack, got 0x%02hhx\n", + name, expected, types_stack->type); + return -1; + } + + return 0; +} + +static int argcheck_generic(struct types **types_stack, char expected) +{ + struct types *top_type; + + if (argcheck_generic_noremove(*types_stack, expected)) + return -1; + + top_type = *types_stack; + *types_stack = top_type->prev; + free(top_type); + + return 0; +} + +static int _argcheck_i32(struct types **types_stack) +{ + return argcheck_generic(types_stack, VALTYPE_I32); +} + +static int _argcheck_i32_i32(struct types **types_stack) +{ + int i; + + for (i = 0; i < 2; i++) { + if (argcheck_generic(types_stack, VALTYPE_I32)) + return -1; + } + + return 0; +} + +static int _argcheck_custom(struct types **types_stack) +{ + return 0; /* Translation function will handle argument checks */ +} + +/** DEFINE RESULT TYPECHECK FUNCTIONS **/ + +static int _rescheck_empty(struct types **types_stack) +{ + return 0; +} + +static int rescheck_generic(struct types **types_stack, char returned) +{ + struct types *top; + + top = malloc(sizeof(struct types)); + + if (!top) { + PRERR(MSG_ALLOC_FAIL(sizeof(struct types))); + return -1; + } + + *top = (struct types) {.prev = *types_stack, .type = returned}; + *types_stack = top; + + return 0; +} + +static int _rescheck_i32(struct types **types_stack) +{ + return rescheck_generic(types_stack, VALTYPE_I32); +} + +static int _rescheck_custom(struct types **types_stack) +{ + return 0; +} + +/** DEFINE TYPECHECK FUNCTION POINTER ARRAY **/ + +/* Translate complex */ +#define TC(wasm_opcode, name, argtypes, restype) \ + [wasm_opcode] = {.argcheck = _argcheck_##argtypes, \ + .rescheck = _rescheck_##restype}, + +/* Translate Simple */ +#define TS(wasm_opcode, sm_instr, argtypes, restype) \ + TC(wasm_opcode, dummy, argtypes, restype) + +/* Translate load/store */ +#define TLS(wasm_opcode, sm_instr, argtypes, restype) \ + TC(wasm_opcode, dummy, argtypes, restype) + +struct typecheck { + int (*argcheck) (struct types **), (*rescheck) (struct types **); +}; + +/*static*/ struct typecheck typecheck_routines[256] = { +#include "translate_xmacro.h" +}; + +#undef TS +#undef TLS +#undef TC + +/** DEFINE CUSTOM TYPECHECK FUNCTIONS **/ + +static int typecheck_call(struct translation *data, struct function *callee) +{ + uint32_t i; + + i = callee->type->args.count; + + while (i--) { + if (argcheck_generic(&data->types_stack, + callee->type->args.types[i])) + return -1; + } + + for (i = 0; i < callee->type->results.count; i++) { + if (rescheck_generic(&data->types_stack, + callee->type->results.types[i])) + return -1; + } + + return 0; +} + +static int typecheck_local_get(struct translation *data, uint32_t localidx) +{ + uint32_t args_count = data->function->type->args.count; + char type = localidx < args_count ? + data->function->type->args.types[localidx] : + data->function->locals[localidx - args_count]; + + if (rescheck_generic(&data->types_stack, type)) + return -1; + + return 0; +} + +/** REST OF THE CODE **/ + +static int translate_instr(struct translation *data, uint8_t wasm_opcode) { + struct typecheck *tc_routines; + + if (!translation_routines[wasm_opcode]) { + PRERR("Unknown Wasm opcode: 0x%02x\n", wasm_opcode); + return -1; + } + + tc_routines = typecheck_routines + wasm_opcode; + + if (tc_routines->argcheck(&data->types_stack) || + tc_routines->rescheck(&data->types_stack) || + translation_routines[wasm_opcode](data)) + return -1; + + return 0; +} + +static int translate_expr(struct translation *data, struct resulttype *args, + struct resulttype *results) +{ + struct types **tmp, *types_stack_rest; + uint32_t i; int wasm_opcode; + tmp = &data->types_stack; + i = args ? args->count : 0; + + while (i--) { + if (argcheck_generic_noremove(*tmp, args->types[i])) + return -1; + + tmp = &(*tmp)->prev; + } + + types_stack_rest = *tmp; + *tmp = NULL; + while (1) { wasm_opcode = fgetc(data->handle); if (wasm_opcode == EOF) { PRERR(MSG_EOF); - return -1; + goto fail; } - if (wasm_opcode == WASM_END) { - return 0; - } + if (wasm_opcode == WASM_END) + break; - if (!translation_routines[wasm_opcode]) { - PRERR("Unknown Wasm opcode: 0x%02x\n", wasm_opcode); - return -1; - } + if (translate_instr(data, wasm_opcode)) + goto fail; + } - if (translation_routines[wasm_opcode](data)) - return -1; + tmp = &data->types_stack; + i = results ? results->count : 0; + + while (i--) { + if (argcheck_generic_noremove(*tmp, results->types[i])) + goto fail; + + tmp = &(*tmp)->prev; } + + if (*tmp) { + PRERR("Expression produces too many result values\n"); + goto fail; + } + + *tmp = types_stack_rest; + return 0; + +fail: + free_types_stack(data->types_stack); + data->types_stack = types_stack_rest; + + return -1; } int translate(FILE *handle, struct function *function, struct module *module) { struct instruction **expr = &function->translated_body; - uint32_t args_count = function->type->arguments_count; + uint32_t args_count = function->type->args.count; uint32_t locals_count = function->locals_count; uint32_t all_locals_count = args_count + locals_count; size_t i; struct translation data = {.handle = handle, .function = function, - .module = module}; + .module = module, + .types_stack = NULL}; if (locals_count + (uint64_t) args_count > STACK_TOP_ADDR * 4) { PRERR("Too many locals in a function\n"); @@ -186,11 +445,11 @@ int translate(FILE *handle, struct function *function, struct module *module) goto fail; /* actual function body */ - if (translate_expr(&data, function->type)) + if (translate_expr(&data, NULL, &function->type->results)) goto fail; /* function epilogue */ - if (function->type->result) { + if (function->type->results.count) { if (i_load (im(STACK_FRAME_BACKUP_ADDR), expr) || i_swap ( expr) || i_store_p(im(4 * (2 + all_locals_count - 1)), expr) || @@ -215,7 +474,10 @@ int translate(FILE *handle, struct function *function, struct module *module) } - i = locals_count + args_count + 2 + (function->type->result ? 0 : 1); + i = locals_count + args_count + 2; + + if (!function->type->results.count) + i++; while (i--) { if (i_drop(expr)) diff --git a/tools/translate_xmacro.h b/tools/translate_xmacro.h index 2fc4558..5e5bd79 100644 --- a/tools/translate_xmacro.h +++ b/tools/translate_xmacro.h @@ -1,19 +1,20 @@ /* X-macro-like definition of translation routines for each webasm opcode */ -TS(WASM_I32_ADD, add) -TS(WASM_I32_SUB, sub) -TS(WASM_I32_DIV_U, div) -TS(WASM_I32_MUL, mul) +/* wasm_opcode_______**translation_routine**argument_types**result_type*/ +TS (WASM_I32_ADD, add, i32_i32, i32) +TS (WASM_I32_SUB, sub, i32_i32, i32) +TS (WASM_I32_DIV_U, div, i32_i32, i32) +TS (WASM_I32_MUL, mul, i32_i32, i32) -TLS(WASM_I32_LOAD, load_p) -TLS(WASM_I32_LOAD8_S, loadbsx_p) -TLS(WASM_I32_LOAD8_U, loadbzx_p) -TLS(WASM_I32_LOAD16_S, loadwsx_p) -TLS(WASM_I32_LOAD16_U, loadwzx_p) -TLS(WASM_I32_STORE, store_p) -TLS(WASM_I32_STORE8, storeb_p) -TLS(WASM_I32_STORE16, storew_p) +TLS(WASM_I32_LOAD, load_p, i32, i32) +TLS(WASM_I32_LOAD8_S, loadbsx_p, i32, i32) +TLS(WASM_I32_LOAD8_U, loadbzx_p, i32, i32) +TLS(WASM_I32_LOAD16_S, loadwsx_p, i32, i32) +TLS(WASM_I32_LOAD16_U, loadwzx_p, i32, i32) +TLS(WASM_I32_STORE, store_p, i32_i32, empty) +TLS(WASM_I32_STORE8, storeb_p, i32_i32, empty) +TLS(WASM_I32_STORE16, storew_p, i32_i32, empty) -TC(WASM_CALL, call) -TC(WASM_LOCAL_GET, local_get) -TC(WASM_I32_CONST, const) +TC (WASM_CALL, call, custom, custom) +TC (WASM_LOCAL_GET, local_get, empty, custom) +TC (WASM_I32_CONST, const, empty, i32) diff --git a/tools/wasm_compile.h b/tools/wasm_compile.h index 3412b2d..e223ec6 100644 --- a/tools/wasm_compile.h +++ b/tools/wasm_compile.h @@ -24,11 +24,13 @@ #define PRERR(...) fprintf(stderr, __VA_ARGS__) -struct functype { - uint32_t arguments_count; - char *arguments; +struct resulttype { + uint32_t count; + char *types; +}; - char result; +struct functype { + struct resulttype args, results; }; struct function { -- cgit v1.2.3