// TODO: count read bytes ourselves instead of relying on ftell()
#include "wasm_compile.h"
#include "wasm.h"

/*
 * Read one LEB128-encoded 32-bit integer from `handle`.
 *
 * handle:    stream positioned at the first byte of the number
 * result:    receives the decoded value; a negative signed value is stored
 *            as its two's complement bit pattern
 * with_sign: decode as signed (sign-extend from bit 6 of the last byte)
 *            when true, as unsigned otherwise
 *
 * At most 5 bytes are consumed.  Returns 0 on success and -1 on EOF, on an
 * encoding longer than 5 bytes or on a value that does not fit in 32 bits.
 */
int leb_32(FILE *handle, uint32_t *result, bool with_sign)
{
	int i, j;
	int encoded[5];
	int64_t decoded;

	for (i = 0; i < 5; i++) {
		encoded[i] = fgetc(handle);

		if (encoded[i] == EOF) {
			PRERR(MSG_EOF);
			return -1;
		}

		/* A clear high bit marks the last byte of the number */
		if (encoded[i] < 128)
			break;

		if (i == 4) {
			PRERR(MSG_BAD_NUM_ENC);
			return -1;
		}
	}

	/*
	 * Assemble the 7-bit groups, most significant first.  The accumulator
	 * stays non-negative (below 2^35) here, so the left shifts are always
	 * well defined.
	 */
	decoded = 0;

	for (j = i; j >= 0; j--)
		decoded = (decoded << 7) | (encoded[j] & 0x7F);

	/*
	 * Sign-extend by subtracting 2^bits instead of shifting a negative
	 * number left, which would be undefined behaviour.
	 */
	if (with_sign && (encoded[i] & (1 << 6)))
		decoded -= (int64_t) 1 << (7 * (i + 1));

	if ((with_sign && decoded > INT32_MAX) ||
	    (with_sign && decoded < INT32_MIN) ||
	    (!with_sign && decoded > UINT32_MAX)) {
		PRERR(MSG_BAD_NUM_ENC);
		return -1;
	}

	*result = (uint32_t) decoded;

	return 0;
}

/*
 * Release the type array owned by `type` and leave the structure empty
 * (no types, count of zero), so that it can be deinitialized again safely.
 */
void deinitialize_resulttype(struct resulttype *type)
{
	free(type->types);

	type->types = NULL;
	type->count = 0;
}

/* Release both result types (arguments, then results) held by `type`. */
void deinitialize_functype(struct functype *type)
{
	struct resulttype *args = &type->args;
	struct resulttype *results = &type->results;

	deinitialize_resulttype(args);
	deinitialize_resulttype(results);
}

/*
 * Free `module` together with everything it owns: function types,
 * functions (locals and translated bodies), export names, targets and the
 * startup expression.  Passing NULL is allowed and does nothing.
 */
void free_module(struct module *module)
{
	size_t idx;

	if (module == NULL)
		return;

	/* Function types */
	idx = 0;

	while (idx < module->functypes_count) {
		deinitialize_functype(&module->functypes[idx]);
		idx++;
	}

	free(module->functypes);

	/* Functions: per-function data first, then the array itself */
	idx = 0;

	while (idx < module->functions_count) {
		free(module->functions[idx].locals);
		free_expr(module->functions[idx].translated_body);
		idx++;
	}

	free(module->functions);

	/* Exports */
	idx = 0;

	while (idx < module->exports_count) {
		free(module->exports[idx].name);
		idx++;
	}

	free(module->exports);

	/* Remaining module-level data */
	free_targets(module->targets);
	free_expr(module->startup);

	free(module);
}

/*
 * Multiply *factor1 by factor2 in place, guarding against size_t overflow.
 *
 * The check is done by division before multiplying, so it holds for any
 * width of size_t (a 64-bit intermediate product could itself wrap when
 * size_t is 64 bits wide).
 *
 * Returns 0 on success.  On overflow an error is reported, *factor1 is left
 * unchanged and -1 is returned.
 */
static inline int safe_mul(size_t *factor1, uint32_t factor2)
{
	if (factor2 && *factor1 > SIZE_MAX / factor2) {
		PRERR(MSG_SIZE_OVERFLOW);
		return -1;
	}

	*factor1 *= factor2;

	return 0;
}

/*
 * Parse a result type (a counted vector of value types) from `handle`
 * into `type`.
 *
 * On success type->count holds the number of value types and type->types
 * a malloc()ed array of that many bytes (NULL when the count is zero).
 * Note that the types are stored in reverse order of their appearance in
 * the file: the first one read ends up at index count - 1.
 *
 * Returns 0 on success.  On failure `type` is deinitialized and -1 is
 * returned.
 */
int parse_resulttype(FILE *handle, struct resulttype *type)
{
	int readval;
	uint32_t i;

	/*
	 * The fail path frees type->types, so give it a defined value before
	 * anything can fail; `type` may point into uninitialized memory.
	 */
	type->types = NULL;

	if (leb_u32(handle, &type->count)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	if (type->count) {
		type->types = malloc(type->count);

		if (!type->types) {
			PRERR(MSG_ALLOC_FAIL(type->count));
			goto fail;
		}
	}

	i = type->count;

	/* Fill the array from its end towards its beginning */
	while (i--) {
		readval = fgetc(handle);

		if (readval == EOF) {
			PRERR(MSG_EOF);
			goto fail;
		}

		if (!is_valid_valtype(readval)) {
			PRERR(MSG_BAD("value type encoding", readval));
			goto fail;
		}

		type->types[i] = readval;
	}

	return 0;

fail:
	deinitialize_resulttype(type);

	return -1;
}

/*
 * Parse one function type from `handle` into `type`: the 0x60 marker byte
 * followed by the argument result type and the returned result type.
 *
 * Returns 0 on success.  On failure -1 is returned and `type` owns no
 * memory, so the caller must not deinitialize it.
 */
int parse_functype(FILE *handle, struct functype *type)
{
	int readval;

	readval = fgetc(handle);

	if (readval == EOF) {
		PRERR(MSG_EOF);
		return -1;
	}

	if (readval != 0x60) {
		PRERR(MSG_BAD("functype starting byte (0x60)", readval));
		return -1;
	}

	/*
	 * `type` may point into uninitialized memory.  Start from an empty
	 * state so that the cleanup below never frees a garbage pointer,
	 * e.g. results.types when parsing of args already failed.
	 */
	type->args.count = 0;
	type->args.types = NULL;
	type->results.count = 0;
	type->results.types = NULL;

	if (parse_resulttype(handle, &type->args) ||
	    parse_resulttype(handle, &type->results)) {
		deinitialize_functype(type);
		return -1;
	}

	return 0;
}

/*
 * Parse the type section: a counted vector of function types.
 *
 * On success the parsed array and its length are stored in
 * module->functypes and module->functypes_count.  Returns 0 on success;
 * on failure everything allocated here is released, `module` is left
 * untouched and -1 is returned.
 */
int parse_type_section(FILE *handle, struct module *module)
{
	uint32_t types_count;
	size_t malloc_size;
	struct functype *types = NULL;
	uint32_t types_parsed = 0;

	if (leb_u32(handle, &types_count)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	malloc_size = sizeof(struct functype);

	if (safe_mul(&malloc_size, types_count)) {
		PRERR(MSG_BAD_SIZE);
		goto fail;
	}

	types = malloc(malloc_size);

	/*
	 * malloc(0) is allowed to return NULL, so a NULL result only means
	 * failure when at least one element was requested.
	 */
	if (!types && types_count) {
		PRERR(MSG_ALLOC_FAIL(malloc_size));
		goto fail;
	}

	while (types_parsed < types_count) {
		if (parse_functype(handle, types + types_parsed))
			goto fail;

		types_parsed++;
	}

	module->functypes_count = types_count;
	module->functypes = types;

	return 0;

fail:
	PRERR("Couldn't parse function types section\n");

	if (types) {
		/* Only fully parsed entries own memory */
		while (types_parsed--)
			deinitialize_functype(types + types_parsed);

		free(types);
	}

	return -1;
}

/*
 * Parse the function section: a counted vector of type indices, one per
 * function defined in the module.
 *
 * Every index must refer to an already parsed function type and that type
 * must not return more than one value.  On success the new function array
 * and its length are stored in module->functions and
 * module->functions_count; each entry has its type set and no locals or
 * translated body yet.
 *
 * Returns 0 on success, -1 on failure (nothing is stored in `module`).
 */
int parse_function_section(FILE *handle, struct module *module)
{
	uint32_t funcs_count;
	size_t malloc_size;
	struct function *funcs = NULL;
	uint32_t i;
	uint32_t type_idx;

	if (leb_u32(handle, &funcs_count)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	malloc_size = sizeof(struct function);

	if (safe_mul(&malloc_size, funcs_count)) {
		PRERR(MSG_BAD_SIZE);
		goto fail;
	}

	funcs = malloc(malloc_size);

	/* malloc(0) may return NULL; that is not a failure */
	if (!funcs && funcs_count) {
		PRERR(MSG_ALLOC_FAIL(malloc_size));
		goto fail;
	}

	for (i = 0; i < funcs_count; i++) {
		if (leb_u32(handle, &type_idx))  {
			PRERR(MSG_BAD_NUM);
			goto fail;
		}

		if (type_idx >= module->functypes_count) {
			PRERR("Nonexistent function type index used\n");
			goto fail;
		}

		if (module->functypes[type_idx].results.count > 1) {
			PRERR("Function returning more than one value\n");
			goto fail;
		}

		funcs[i].type = module->functypes + type_idx;
		funcs[i].translated_body = NULL;
		/*
		 * free_module() frees the locals of every function, also of
		 * those whose code never gets parsed, so the pointer must not
		 * be left uninitialized.
		 */
		funcs[i].locals = NULL;
		funcs[i].locals_count = 0;
	}

	module->functions_count = funcs_count;
	module->functions = funcs;

	return 0;

fail:
	PRERR("Couldn't parse functions section\n");

	free(funcs);

	return -1;
}

/*
 * Parse the memory section: a vector of at most one memory, described by
 * its limits (a type byte, the minimum size and, for type 0x01, also the
 * maximum size).
 *
 * On success module->memory_type, module->mem_min and, when present,
 * module->mem_max are filled in.  A section declaring zero memories is
 * accepted and leaves `module` untouched.
 *
 * Returns 0 on success.  On failure module->mem_min is reset to 0 and -1
 * is returned.
 */
static int parse_memory_section(FILE *handle, struct module *module)
{
	// TODO: move limits parsing to separate function?
	uint32_t memories_count;
	int limits_type;

	if (leb_u32(handle, &memories_count)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	if (memories_count > 1) {
		PRERR("More than one Wasm memory\n");
		goto fail;
	}

	/*
	 * An empty vector has no limits record following it; reading one
	 * anyway would consume bytes belonging to the next section.
	 */
	if (memories_count == 0)
		return 0;

	limits_type = fgetc(handle);

	if (limits_type == EOF) {
		PRERR(MSG_EOF);
		goto fail;
	}

	if (leb_u32(handle, &module->mem_min)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	if (limits_type == 0x00) {
		module->memory_type = MEM_MIN;
	} else if (limits_type == 0x01) {
		if (leb_u32(handle, &module->mem_max)) {
			PRERR(MSG_BAD_NUM);
			goto fail;
		}

		module->memory_type = MEM_MIN_MAX;
	} else {
		PRERR(MSG_BAD("limit type", limits_type));
		goto fail;
	}

	return 0;

fail:
	module->mem_min = 0;

	return -1;
}

/*
 * Parse the export section: a counted vector of exports, each made of a
 * length-prefixed name, an export descriptor byte and an index.
 *
 * Names are stored as malloc()ed, NUL-terminated strings.  On success the
 * array and its length are stored in module->exports and
 * module->exports_count.
 *
 * Returns 0 on success.  On failure everything allocated here is freed,
 * `module` is left untouched and -1 is returned.
 */
static int parse_export_section(FILE *handle, struct module *module)
{
	int readval;
	uint32_t exports_count;
	size_t malloc_size;
	struct export *exports = NULL;
	uint32_t exports_parsed = 0;
	uint32_t name_len;
	size_t name_size;
	char *name;

	if (leb_u32(handle, &exports_count)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	malloc_size = sizeof(struct export);

	if (safe_mul(&malloc_size, exports_count)) {
		PRERR(MSG_BAD_SIZE);
		goto fail;
	}

	exports = malloc(malloc_size);

	/* malloc(0) may return NULL; that is not a failure */
	if (!exports && exports_count) {
		PRERR(MSG_ALLOC_FAIL(malloc_size));
		goto fail;
	}

	while (exports_parsed < exports_count) {
		if (leb_u32(handle, &name_len)) {
			PRERR(MSG_BAD_NUM);
			goto fail;
		}

		/*
		 * Compute the buffer size in size_t: in 32-bit arithmetic
		 * name_len + 1 wraps to 0 for the largest length, which would
		 * allocate nothing and then read ~4GiB into it.  A result of 0
		 * means size_t itself is too narrow for the name.
		 */
		name_size = (size_t) name_len + 1;

		if (!name_size) {
			PRERR(MSG_BAD_SIZE);
			goto fail;
		}

		name = malloc(name_size);

		if (!name) {
			PRERR(MSG_ALLOC_FAIL(name_size));
			goto fail;
		}

		exports[exports_parsed].name = name;
		/* Increment here, so that jump to fail: frees the name */
		exports_parsed++;

		/*
		 * fread() returns 0 for a zero-sized item, so an empty name
		 * must not be read at all or it would look like EOF.
		 */
		if (name_len && fread(name, name_len, 1, handle) != 1) {
			PRERR(MSG_EOF);
			goto fail;
		}

		name[name_len] = '\0';

		readval = fgetc(handle);

		if (readval == EOF) {
			PRERR(MSG_EOF);
			goto fail;
		}

		if (!is_valid_exportdesc(readval)) {
			PRERR(MSG_BAD("exportdesc", readval));
			goto fail;
		}

		exports[exports_parsed - 1].desc = readval;

		if (leb_u32(handle, &exports[exports_parsed - 1].idx)) {
			PRERR(MSG_BAD_NUM);
			goto fail;
		}
	}

	module->exports_count = exports_count;
	module->exports = exports;

	return 0;

fail:
	PRERR("Couldn't parse exports section\n");

	if (exports) {
		while (exports_parsed) {
			free(exports[exports_parsed - 1].name);
			exports_parsed--;
		}

		free(exports);
	}

	return -1;
}

/*
 * Parse the body of one function: the run-length encoded declarations of
 * its locals followed by its code, which is handed over to translate().
 *
 * Each locals block is a count and one value type; it is expanded into one
 * byte per local in a malloc()ed array.  The total number of locals plus
 * the number of arguments of the function must fit in 32 bits.
 *
 * On success function->locals and function->locals_count describe the
 * expanded locals.  Returns 0 on success; on failure the locals array is
 * freed, function->locals is reset to NULL and -1 is returned.
 */
static int parse_function_code(FILE *handle, struct function *function,
			       struct module *module)
{
	int readval;
	uint32_t locals_blocks;
	uint32_t locals_count = 0;
	char *locals = NULL, *tmp;
	uint32_t i;
	uint32_t locals_in_block;

	if (leb_u32(handle, &locals_blocks)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	for (i = 0; i < locals_blocks; i++) {
		if (leb_u32(handle, &locals_in_block)) {
			PRERR(MSG_BAD_NUM);
			goto fail;
		}

		/* Sum in 64 bits so that the check itself cannot wrap */
		if (locals_count + (uint64_t) locals_in_block +
		    (uint64_t) function->type->args.count > UINT32_MAX) {
			PRERR("Too many locals\n");
			goto fail;
		}

		locals_count += locals_in_block;

		if (locals_in_block) {
			tmp = realloc(locals, locals_count);

			if (!tmp) {
				PRERR(MSG_ALLOC_FAIL(locals_count));
				goto fail;
			}

			locals = tmp;
		}

		readval = fgetc(handle);

		if (readval == EOF) {
			PRERR(MSG_EOF);
			goto fail;
		}

		if (!is_valid_valtype(readval)) {
			PRERR(MSG_BAD("value type encoding", readval));
			goto fail;
		}

		/* Fill the slots just appended for this block */
		while (locals_in_block)
			locals[locals_count - locals_in_block--] = readval;
	}

	function->translated_body = NULL;

	function->locals_count = locals_count;
	function->locals = locals;

	if (translate(handle, function, module))
		goto fail;

	return 0;

fail:
	free(locals);

	/*
	 * function->locals may already point at the buffer just freed;
	 * clear it so that later cleanup does not free it a second time.
	 */
	function->locals = NULL;
	function->locals_count = 0;

	return -1;
}

/*
 * Parse the code section: a counted vector of function bodies, each
 * prefixed with its size in bytes.
 *
 * The number of bodies must equal module->functions_count and every body
 * must occupy exactly as many bytes as its size prefix announces (checked
 * with ftell()).  Bodies are parsed into module->functions in order.
 *
 * Returns 0 on success.  On failure the locals and translated bodies of
 * the functions already fully processed are freed and reset, and -1 is
 * returned.
 */
int parse_code_section(FILE *handle, struct module *module)
{
	uint32_t functions_count;
	uint32_t functions_parsed = 0;
	uint32_t function_size;
	long function_start, function_end;

	if (leb_u32(handle, &functions_count)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	if (functions_count != module->functions_count) {
		PRERR("Number of function bodies doesn't match number of functions\n");
		goto fail;
	}

	while (functions_parsed < functions_count) {
		if (leb_u32(handle, &function_size)) {
			PRERR(MSG_BAD_NUM);
			goto fail;
		}

		function_start = ftell(handle);

		if (parse_function_code(handle,
					module->functions + functions_parsed,
					module)) {
			PRERR("Couldn't parse code of function %lu\n",
			      (unsigned long) functions_parsed);
			goto fail;
		}

		function_end = ftell(handle);

		if (function_end - function_size != function_start) {
			/* Argument order follows the message: expected end first */
			PRERR("Function %lu started at offset %ld and should end at %ld, but ended at %ld\n",
			      (unsigned long) functions_parsed,
			      function_start,
			      (long) (function_start + function_size),
			      function_end);
			goto fail;
		}

		functions_parsed++;
	}

	return 0;

fail:
	PRERR("Couldn't parse code section\n");

	while (functions_parsed) {
		free(module->functions[functions_parsed - 1].locals);
		/* Clear the pointer: free_module() frees it again otherwise */
		module->functions[functions_parsed - 1].locals = NULL;
		module->functions[functions_parsed - 1].locals_count = 0;

		free_expr(module->functions[functions_parsed - 1]
			  .translated_body);
		module->functions[functions_parsed - 1].translated_body = NULL;

		functions_parsed--;
	}

	return -1;
}

/* First four bytes of every Wasm binary: a NUL byte followed by "asm" */
static const char magic[] = {'\0', 'a', 's', 'm'};
/* The four version bytes accepted by the parser */
static const char version[] = {1, 0, 0, 0};

/*
 * Parse a whole Wasm binary module from `handle`.
 *
 * The 8-byte header (magic number and version) is verified first.  Then
 * sections are read until EOF: custom sections are skipped, all other
 * sections must appear in increasing order of their ids and must occupy
 * exactly as many bytes as their size field announces.  Only the type,
 * function, memory, export and code sections are supported; any other
 * section id is an error.
 *
 * Returns a newly allocated module, to be released with free_module(), or
 * NULL on failure.
 */
struct module *parse_module(FILE *handle)
{
	char initial[8];
	struct module *module = NULL;
	int section_id;
	int highest_section_id = 0;
	uint32_t section_size;
	uint32_t skipped;
	long section_start, section_end;
	int (*section_parser) (FILE*, struct module*);

	if (fread(initial, 8, 1, handle) != 1) {
		PRERR(MSG_EOF);
		goto fail;
	}

	/* check magic number */
	if (memcmp(initial, magic, 4)) {
		PRERR("Bad magic number\n");
		goto fail;
	}

	/* check version */
	if (memcmp(initial + 4, version, 4)) {
		PRERR("Unsupported Wasm version: 0x%02hhx 0x%02hhx 0x%02hhx 0x%02hhx\n",
		      initial[4], initial[5], initial[6], initial[7]);
		goto fail;
	}

	/* Zero-filled, so that free_module() is safe at any later point */
	module = calloc(1, sizeof(struct module));

	if (!module) {
		PRERR(MSG_ALLOC_FAIL(sizeof(struct module)));
		goto fail;
	}

	while (1) {
		section_id = fgetc(handle);

		/* End of file between sections is the normal way out */
		if (section_id == EOF)
			break;

		if (leb_u32(handle, &section_size)) {
			PRERR(MSG_BAD_NUM);
			goto fail;
		}

		section_start = ftell(handle);

		if (section_id == SECTION_CUSTOM) {
			/*
			 * Custom sections may appear anywhere and are ignored,
			 * but their contents have to be consumed, or they
			 * would be misread as the following section headers.
			 */
			for (skipped = 0; skipped < section_size; skipped++) {
				if (fgetc(handle) == EOF) {
					PRERR(MSG_EOF);
					goto fail;
				}
			}

			continue;
		}

		/* Sections are only allowed to appear in order */
		if (section_id <= highest_section_id) {
			PRERR("Sections out of order\n");
			goto fail;
		}

		highest_section_id = section_id;

		if (section_id == SECTION_TYPE) {
			section_parser = parse_type_section;
		} else if (section_id == SECTION_FUNCTION) {
			section_parser = parse_function_section;
		} else if (section_id == SECTION_MEMORY) {
			section_parser = parse_memory_section;
		} else if (section_id == SECTION_EXPORT) {
			section_parser = parse_export_section;
		} else if (section_id == SECTION_CODE) {
			section_parser = parse_code_section;
		} else {
			PRERR("Unknown section id: %d\n", section_id);
			goto fail;
		}

		if (section_parser(handle, module))
			goto fail;

		section_end = ftell(handle);

		if (section_end - section_size != section_start) {
			/* Argument order follows the message: expected end first */
			PRERR("Section %d started at offset %ld and should end at %ld, but ended at %ld\n",
			      section_id, section_start,
			      (long) (section_start + section_size),
			      section_end);
			goto fail;
		}
	}

	return module;

fail:
	PRERR("Parsing failed\n");

	free_module(module);

	return NULL;
}