/* Translation of WebAssembly function bodies into stack machine instructions */
#include "wasm_compile.h"
#include "wasm.h"
#include "stack_machine_instruction.h"

/*
 * Node of a singly linked list of jump targets. add_target() prepends new
 * nodes to module->targets; free_targets() releases a whole list.
 */
struct target {
	/* NULL on creation; assigned once the target's place in code is known */
	struct instruction *instr;
	/* Target added before this one (or NULL) */
	struct target *prev;
};

/*
 * Branch label of an expression currently being translated. Labels form
 * a stack (innermost first) rooted at translation->labels; see add_label().
 */
struct label {
	/* Label of the enclosing expression (or NULL for the outermost one) */
	struct label *prev;
	/* Jump target used by branches to this label */
	struct target *target;
	/* Types a branch to this label has to leave on stack (may be NULL) */
	struct resulttype *arity;
	/* Value count computed by add_label(): stack depth below this label */
	uint32_t values_on_stack;
};

/*
 * Node of the reference-counted stack of value types used for typechecking.
 * References are taken with get_type() and dropped with put_type().
 */
struct types {
	/* Node below this one on the stack (or NULL at the bottom) */
	struct types *prev;
	char type; /* should be one of VALTYPE_* constants from wasm.h */
	/* Reference count; the node is freed by put_type() when it drops to 0 */
	int refs;
};

/* State shared by all routines while a single function is translated */
struct translation {
	/* Stream the Wasm bytecode is read from */
	FILE *handle;
	/* Function being translated; code is appended to its translated_body */
	struct function *function;
	/* Module the function belongs to (functions, functypes, targets) */
	struct module *module;
	/* Top of the stack of value types at the current point of translation */
	struct types *types_stack;
	/* Innermost label currently in scope */
	struct label *labels;
};

/* Set of opcodes, any of which terminates an expression in translate_expr() */
struct end_markers {
	/* Number of entries in codes */
	int count;
	/* Opcodes that end the expression */
	const uint8_t *codes;
};

/* Free every node of the targets list that starts at top (NULL is allowed) */
void free_targets(struct target *top)
{
	struct target *next;

	for (; top; top = next) {
		next = top->prev;
		free(top);
	}
}

/* Take an additional reference to a types stack node; NULL is ignored */
inline static void get_type(struct types *type)
{
	if (!type)
		return;

	type->refs += 1;
}

/*
 * Drop one reference to a types stack node. When the count reaches zero the
 * node is freed and the reference it held on the node below is dropped too,
 * which may cascade further down the stack.
 */
static void put_type(struct types *type)
{
	struct types *below;

	while (type) {
		type->refs--;

		if (type->refs)
			break;

		below = type->prev;
		free(type);
		type = below;
	}
}

/* Count the nodes reachable from types_stack by following prev pointers */
static uint32_t stack_size(struct types *types_stack)
{
	uint32_t depth;

	for (depth = 0; types_stack; types_stack = types_stack->prev)
		depth++;

	return depth;
}

/*
 * Free all nodes of a types stack unconditionally, i.e. without looking at
 * their reference counts.
 */
static void free_types_stack(struct types *top)
{
	struct types *below;

	for (; top; top = below) {
		below = top->prev;
		free(top);
	}
}

/*
 * Forward declarations of routines defined near the end of this file, which
 * the instruction translation functions below need to call.
 */

/* Translates instructions read from data->handle until an end marker */
static int translate_expr(struct translation *data, struct resulttype *args,
			  struct resulttype *results,
			  const struct end_markers *end_markers,
			  char *marker_found, bool continuation_at_start);

/* Allocates a new jump target and links it into module->targets */
static struct target *add_target(struct module *module);

/* Reads a blocktype from handle and fills *args and *results accordingly */
static int parse_blocktype(FILE *handle, struct resulttype *args,
			   struct resulttype *results, char *storage,
			   struct module *module);

/* All functions that go into one of the function pointer arrays start with _ */

/** DEFINE ARGUMENT TYPECHECK FUNCTIONS **/

/* Argument check for instructions that take no operands: always succeeds */
static int _argcheck_empty(struct types **types_stack)
{
	(void) types_stack; /* nothing to inspect */

	return 0;
}

/*
 * Check that the value on top of types_stack has the type `expected` (one of
 * the VALTYPE_* constants), without removing it from the stack.
 *
 * Returns 0 if the type matches; prints a diagnostic and returns -1 if the
 * stack is empty or the top value has a different type.
 */
static int argcheck_generic_noremove(struct types *types_stack, char expected)
{
	const char *name;

	/* Human-readable name of the expected type, for diagnostics only */
	name =
		expected == VALTYPE_F64 ? "f64" :
		expected == VALTYPE_F32 ? "f32" :
		expected == VALTYPE_I64 ? "i64" :
		"i32";

	if (!types_stack) {
		PRERR("Expected %s on stack, got nothing\n", name);
		return -1;
	}

	/* Compare against the requested type, not a fixed one */
	if (types_stack->type != expected) {
		PRERR("Expected %s (0x%02hhx) on stack, got 0x%02hhx\n",
		      name, expected, types_stack->type);
		return -1;
	}

	return 0;
}

/*
 * Check that the top of *types_stack has type `expected` and pop it.
 * Returns 0 on success and -1 (stack left untouched) on type mismatch.
 */
static int argcheck_generic(struct types **types_stack, char expected)
{
	struct types *popped = *types_stack;

	if (argcheck_generic_noremove(popped, expected) != 0)
		return -1;

	/*
	 * The stack pointer now refers to the node below, so that node gains
	 * a reference (get_type() ignores NULL) before the popped node, which
	 * may hold the only other reference to it, is released.
	 */
	*types_stack = popped->prev;
	get_type(popped->prev);
	put_type(popped);

	return 0;
}

/* Argument check for instructions consuming a single i32 operand */
static int _argcheck_i32(struct types **types_stack)
{
	const char operand_type = VALTYPE_I32;

	return argcheck_generic(types_stack, operand_type);
}

/* Argument check for instructions consuming two i32 operands */
static int _argcheck_i32_i32(struct types **types_stack)
{
	/* The second pop is only attempted if the first one succeeded */
	if (argcheck_generic(types_stack, VALTYPE_I32) ||
	    argcheck_generic(types_stack, VALTYPE_I32))
		return -1;

	return 0;
}

/*
 * Placeholder argument check: the instruction's translation function
 * performs the argument checks itself, so nothing is done here.
 */
static int _argcheck_custom(struct types **types_stack)
{
	(void) types_stack;

	return 0;
}

/** DEFINE RESULT TYPECHECK FUNCTIONS **/

/* Result check for instructions that produce no value: always succeeds */
static int _rescheck_empty(struct types **types_stack)
{
	(void) types_stack; /* nothing to push */

	return 0;
}

/*
 * Push a value of type `returned` onto *types_stack. The new node takes over
 * the reference the stack pointer held on the previous top.
 * Returns 0 on success, -1 if memory could not be allocated.
 */
static int rescheck_generic(struct types **types_stack, char returned)
{
	struct types *pushed = malloc(sizeof *pushed);

	if (pushed == NULL) {
		PRERR(MSG_ALLOC_FAIL(sizeof(struct types)));
		return -1;
	}

	pushed->prev = *types_stack;
	pushed->type = returned;
	pushed->refs = 1;

	*types_stack = pushed;

	return 0;
}

/* Result check for instructions producing a single i32 value */
static int _rescheck_i32(struct types **types_stack)
{
	const char result_type = VALTYPE_I32;

	return rescheck_generic(types_stack, result_type);
}

/*
 * Placeholder result check for instructions declared with "custom" result
 * typechecking: does nothing and reports success.
 */
static int _rescheck_custom(struct types **types_stack)
{
	(void) types_stack;

	return 0;
}

/** DEFINE TYPECHECK FUNCTION POINTER ARRAY **/

/*
 * The three macros below are expanded by translate_xmacro.h. In this
 * expansion every entry becomes a designated initializer that stores the
 * argument and result check functions at index wasm_opcode of the table.
 */

/* Translate complex - `name` is not used in this expansion */
#define TC(wasm_opcode, name, argtypes, restype)		\
	[wasm_opcode] = {.argcheck = _argcheck_##argtypes,	\
			 .rescheck = _rescheck_##restype},

/* Translate Simple - typechecked the same way as TC entries */
#define TS(wasm_opcode, sm_instr, argtypes, restype)	\
	TC(wasm_opcode, dummy, argtypes, restype)

/* Translate load/store - typechecked the same way as TC entries */
#define TLS(wasm_opcode, sm_instr, argtypes, restype)	\
	TC(wasm_opcode, dummy, argtypes, restype)

/* Pair of typecheck functions run before and after translating an opcode */
struct typecheck {
	int (*argcheck) (struct types **), (*rescheck) (struct types **);
};

/* Indexed by Wasm opcode; entries not listed in the xmacro file stay zeroed */
static struct typecheck typecheck_routines[256] = {
#include "translate_xmacro.h"
};

#undef TS
#undef TLS
#undef TC

/** DEFINE CUSTOM TYPECHECK FUNCTIONS **/

/*
 * Each of these is called by its respective translate function. In some cases
 * it was not feasible to move an instruction's typecheck to a separate
 * function, so there are fewer functions here than instructions declared
 * with "custom" typechecking.
 */

/*
 * Typecheck a call to `callee`: pop its arguments from the types stack (last
 * argument first, as it is on top) and push its results in declaration order.
 * Returns 0 on success, -1 on failure.
 */
static int typecheck_call(struct translation *data, struct function *callee)
{
	uint32_t remaining, pushed;

	for (remaining = callee->type->args.count; remaining > 0; remaining--) {
		if (argcheck_generic(&data->types_stack,
				     callee->type->args.types[remaining - 1]))
			return -1;
	}

	for (pushed = 0; pushed < callee->type->results.count; pushed++) {
		if (rescheck_generic(&data->types_stack,
				     callee->type->results.types[pushed]))
			return -1;
	}

	return 0;
}

/*
 * Typecheck local.get/local.set of local number localidx. Function arguments
 * come first in the index space, followed by declared locals. The caller is
 * expected to have validated localidx already.
 *
 * A set pops a value of the local's type, a get pushes one.
 */
static int typecheck_local_get_set(struct translation *data, uint32_t localidx,
				   bool set)
{
	uint32_t args_count = data->function->type->args.count;
	char type;

	if (localidx < args_count)
		type = data->function->type->args.types[localidx];
	else
		type = data->function->locals[localidx - args_count];

	if (set)
		return argcheck_generic(&data->types_stack, type);

	return rescheck_generic(&data->types_stack, type);
}

/** DEFINE INSTRUCTION TRANSLATION FUNCTIONS **/

/*
 * Translate complex - those routines have to be defined manually, so TC
 * expands to nothing when translate_xmacro.h is included to generate
 * the translation functions.
 */
#define TC(wasm_opcode, name, argtypes, restype)

/*
 * Common translation of the Wasm block and loop instructions: parse the
 * blocktype, then translate the nested expression up to its WASM_END.
 *
 * `loop` selects where branches to the construct's label land: at its start
 * (loop) or right after its end (block) - it is passed to translate_expr()
 * as continuation_at_start.
 *
 * Returns 0 on success, -1 on failure.
 */
static int translate_block_loop(struct translation *data, bool loop)
{
	struct resulttype block_args, block_results;
	char type_storage; /* backs block_results for single-valtype blocktypes */
	static const uint8_t loop_end_marker_code = WASM_END;
	static const struct end_markers loop_end_markers = {
		.count = 1,
		.codes = &loop_end_marker_code
	};

	if (parse_blocktype(data->handle, &block_args, &block_results,
			    &type_storage, data->module))
		goto fail;

	if (translate_expr(data, &block_args, &block_results,
			   &loop_end_markers, NULL, loop))
		goto fail;

	return 0;

fail:
	/* Name the instruction that actually failed */
	PRERR("Couldn't translate %s instruction\n", loop ? "loop" : "block");
	return -1;
}

/* Translation routine for the Wasm block instruction */
static int _translate_block(struct translation *data)
{
	const bool is_loop = false;

	return translate_block_loop(data, is_loop);
}

/* Translation routine for the Wasm loop instruction */
static int _translate_loop(struct translation *data)
{
	const bool is_loop = true;

	return translate_block_loop(data, is_loop);
}

/*
 * Translation routine for the Wasm if instruction (with optional else).
 *
 * Both branches are translated with translate_expr() against the same types
 * stack. Emitted code has the shape:
 *   conditional jump to if_end; <then branch>; jump to else_end;
 *   if_end: <else branch>; else_end:
 *
 * Returns 0 on success, -1 on failure.
 */
static int _translate_if(struct translation *data)
{
	struct types *backed_stack;
	struct resulttype block_args, block_results;
	char type_storage;
	struct target *if_end, *else_end;
	struct instruction **expr = &data->function->translated_body;
	static const uint8_t if_end_markers_codes[2] = {WASM_ELSE, WASM_END};
	/* The first branch may end either with else or with end */
	static const struct end_markers if_end_markers = {
		.count = 2,
		.codes = if_end_markers_codes
	};
	/* The second branch may only end with end */
	static const struct end_markers else_end_markers = {
		.count = 1,
		.codes = if_end_markers_codes + 1
	};
	char marker_found;
	int retval;

	if (parse_blocktype(data->handle, &block_args, &block_results,
			    &type_storage, data->module))
		goto fail;

	if_end = add_target(data->module);
	else_end = add_target(data->module);

	if (!if_end || !else_end)
		goto fail;

	/* NOTE(review): presumably jumps when the condition is zero - confirm */
	if (i_cond_jump_n(ptr_after(&if_end->instr), expr))
		goto fail;

	/*
	 * Keep the current types stack alive (extra reference), so that it can
	 * be restored for typechecking the second branch.
	 */
	backed_stack = data->types_stack;
	get_type(backed_stack);

	retval = translate_expr(data, &block_args, &block_results,
				&if_end_markers, &marker_found, false);

	/* Discard whatever the first branch left and restore the saved stack */
	put_type(data->types_stack);
	data->types_stack = backed_stack;

	if (retval)
		goto fail;

	/* At the end of the first branch, skip over the second one */
	if (i_jump(ptr_after(&else_end->instr), expr))
		goto fail;

	if_end->instr = data->function->translated_body->prev;

	/*
	 * There was no else; put the end opcode back, so that the second
	 * translate_expr() call sees an empty branch ending immediately.
	 */
	if (marker_found == WASM_END)
		ungetc(WASM_END, data->handle);

	if (translate_expr(data, &block_args, &block_results,
			   &else_end_markers, NULL, false))
		goto fail;

	else_end->instr = data->function->translated_body->prev;

	return 0;

fail:
	PRERR("Couldn't translate if-else instruction\n");

	return -1;
}

/*
 * Translation routine for the Wasm br instruction (also used by br_if).
 *
 * Reads the label index, typechecks (and pops from the types stack) the
 * values required by the label's arity, emits code for the case where values
 * have to be discarded from between the branch values and the label's stack
 * level, and finally emits the jump to the label's target.
 *
 * Returns 0 on success, -1 on failure.
 */
static int _translate_br(struct translation *data)
{
	uint32_t labelidx, i;
	struct label *label;
	uint32_t arity, values_on_stack;
	uint32_t shift, offset_src, offset_dst;
	struct instruction **expr = &data->function->translated_body;

	if (leb_u32(data->handle, &labelidx)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	/* Label 0 is the innermost one; walk labelidx labels outwards */
	label = data->labels;
	i = labelidx;

	while (i--) {
		label = label->prev;

		if (!label) {
			PRERR(MSG_BAD_IDX("labelidx"));
			goto fail;
		}
	}

	values_on_stack = stack_size(data->types_stack);
	arity = label->arity ? label->arity->count : 0;

	if (arity > values_on_stack) {
		PRERR("Need %lu values on stack to branch, only have %lu\n",
		      (unsigned long) arity, (unsigned long) values_on_stack);
		goto fail;
	}

	/* Pop the branch values, checking their types (topmost is the last) */
	i = arity;

	while (i--) {
		if (argcheck_generic(&data->types_stack,
				     label->arity->types[i]))
			goto fail;
	}

	/*
	 * shift is the number of values lying between the target label's
	 * stack level and the `arity` values carried by the branch, i.e.
	 * those that have to be discarded. values_on_stack was sampled before
	 * the pops above, so it still includes the branch values.
	 */
	shift = data->labels->values_on_stack + values_on_stack -
		(label->values_on_stack + arity);
	offset_dst = label->values_on_stack + 2;
	offset_src = offset_dst + shift;

	/* Nothing to discard - values are already where they should be */
	if (!shift)
		goto values_moved;

	/*
	 * NOTE(review): offsets below are negative multiples of 4 applied to
	 * the address loaded from STACK_FRAME_BACKUP_ADDR; the exact frame
	 * layout this sequence relies on is not visible here - confirm against
	 * the stack machine documentation.
	 */
	if (i_load  (im(STACK_FRAME_BACKUP_ADDR), expr) ||
	    i_tee   (                             expr) ||
	    i_load_p(im(-4 * (offset_dst - 1)),   expr) ||
	    i_swap  (                             expr) ||
	    i_load_p(im(-4 * offset_dst),         expr) ||
	    i_add_sp(im(4 * (shift + 2)),         expr))
		goto fail;

	/*
	 * One load per branch value; the frame address is loaded once (first
	 * iteration) and kept for the next iteration with tee/swap, except
	 * after the last value.
	 */
	for (i = 0; i < arity; i++) {
		if (i == 0 && i_load(im(STACK_FRAME_BACKUP_ADDR), expr))
			goto fail;

		if (i + 1 != arity && (i_tee(expr)))
			goto fail;

		if (i_load_p(im(-4 * (offset_src + i)), expr))
			goto fail;

		if (i + 1 != arity && (i_swap(expr)))
			goto fail;
	}

values_moved:
	if (i_jump(ptr_after(&label->target->instr), expr))
		goto fail;

	return 0;

fail:
	PRERR("Couldn't translate br instruction\n");
	return -1;
}

/*
 * Translation routine for the Wasm br_if instruction: a conditional jump
 * over an ordinary br. Since execution may continue past the branch, the
 * types stack from before the br is restored afterwards.
 *
 * Returns 0 on success, -1 on failure.
 */
static int _translate_br_if(struct translation *data)
{
	struct instruction **expr = &data->function->translated_body;
	struct target *skip_branch = add_target(data->module);
	struct types *saved_stack;
	int br_failed;

	if (!skip_branch ||
	    i_cond_jump_n(ptr_after(&skip_branch->instr), expr)) {
		PRERR("Couldn't translate br_if instruction\n");
		return -1;
	}

	/* Hold a reference, so that br's pops don't free the saved stack */
	saved_stack = data->types_stack;
	get_type(saved_stack);

	br_failed = _translate_br(data);

	if (!br_failed)
		skip_branch->instr = data->function->translated_body->prev;

	/* With an empty saved stack there is nothing to restore */
	if (saved_stack) {
		put_type(data->types_stack);
		data->types_stack = saved_stack;
	}

	if (br_failed) {
		PRERR("Couldn't translate br_if instruction\n");
		return -1;
	}

	return 0;
}

/*
 * Translation routine for the Wasm call instruction: read and validate the
 * function index, typecheck the call against the callee's signature and emit
 * a call referring to the callee's translated body.
 */
static int _translate_call(struct translation *data)
{
	struct function *callee;
	uint32_t funcidx;

	if (leb_u32(data->handle, &funcidx)) {
		PRERR(MSG_BAD_NUM);
		return -1;
	}

	if (funcidx >= data->module->functions_count) {
		PRERR(MSG_BAD_IDX("funcidx"));
		return -1;
	}

	callee = &data->module->functions[funcidx];

	if (typecheck_call(data, callee))
		return -1;

	return i_call(ptr(&callee->translated_body),
		      &data->function->translated_body);
}

/*
 * Common translation of local.get (set == false) and local.set (set == true).
 *
 * Reads and validates the local index, typechecks the access and emits code
 * that reaches the local through the address stored at
 * STACK_FRAME_BACKUP_ADDR.
 */
static int translate_local_get_set(struct translation *data, bool set)
{
	struct instruction **expr = &data->function->translated_body;
	uint32_t args_count = data->function->type->args.count;
	uint32_t all_locals_count = args_count + data->function->locals_count;
	uint32_t localidx;
	uint64_t slot;

	if (leb_u32(data->handle, &localidx)) {
		PRERR(MSG_BAD_NUM);
		return -1;
	}

	if (localidx >= all_locals_count) {
		PRERR(MSG_BAD_IDX("localidx"));
		return -1;
	}

	if (typecheck_local_get_set(data, localidx, set))
		return -1;

	/* Declared locals lie one slot closer to the frame base than arguments */
	slot = all_locals_count - localidx + 1;

	if (localidx >= args_count)
		slot -= 1;

	if (i_load(im(STACK_FRAME_BACKUP_ADDR), expr))
		return -1;

	if (!set)
		return i_load_p(im(4 * slot), expr);

	/* Bring the value to be stored back on top of the loaded address */
	return i_swap(expr) || i_store_p(im(4 * slot), expr);
}

/* Translation routine for the Wasm local.get instruction */
static int _translate_local_get(struct translation *data)
{
	const bool is_set = false;

	return translate_local_get_set(data, is_set);
}

/* Translation routine for the Wasm local.set instruction */
static int _translate_local_set(struct translation *data)
{
	const bool is_set = true;

	return translate_local_get_set(data, is_set);
}

/*
 * Translation routine for a constant instruction: read a signed 32-bit LEB
 * immediate from the input and emit a const instruction carrying it.
 */
static int _translate_const(struct translation *data)
{
	struct instruction **expr = &data->function->translated_body;
	int32_t value;

	if (leb_s32(data->handle, &value) != 0) {
		PRERR(MSG_BAD_NUM);
		return -1;
	}

	return i_const(im(value), expr);
}

/*
 * Translate Simple - for each such entry of translate_xmacro.h define
 * a translation function that emits the single stack machine instruction
 * sm_instr, which takes no immediate operand.
 */
#define TS(wasm_opcode, sm_instr, argtypes, restype)			\
	static int _translate_##sm_instr(struct translation *data)	\
	{								\
		return i_##sm_instr(&data->function->translated_body);	\
	}

/* Translate load/store */
/*
 * Common part of all load/store translations: read the instruction's memarg
 * (alignment and offset, both unsigned LEB numbers), rebase the offset by
 * MEMORY_BOTTOM_ADDR and emit the instruction through instr_routine.
 * The alignment value is read but not used.
 */
static int translate_load_store(struct translation *data,
				int (*instr_routine) (struct instruction_data,
						      struct instruction **))
{
	uint32_t align, offset;

	if (leb_u32(data->handle, &align)) {
		PRERR(MSG_BAD_NUM);
		return -1;
	}

	if (leb_u32(data->handle, &offset)) {
		PRERR(MSG_BAD_NUM);
		return -1;
	}

	offset += MEMORY_BOTTOM_ADDR;

	return instr_routine(im(offset), &data->function->translated_body);
}

/*
 * For each load/store entry of translate_xmacro.h define a translation
 * function that delegates to translate_load_store() with the matching
 * instruction-emitting routine.
 */
#define TLS(wasm_opcode, sm_instr, argtypes, restype)			\
	static int _translate_##sm_instr(struct translation *data)	\
	{								\
		return translate_load_store(data, i_##sm_instr);	\
	}

/* This inclusion defines functions using macros above */
#include "translate_xmacro.h"

/* Make room for the next set of definitions of the same macros */
#undef TS
#undef TLS
#undef TC

/** DEFINE TRANSLATION FUNCTIONS POINTER ARRAY **/

/*
 * In this expansion every entry of translate_xmacro.h becomes a designated
 * initializer storing the entry's translation function at index wasm_opcode.
 */

/* Translate complex - refers to a manually written _translate_<name> */
#define TC(wasm_opcode, name, argtypes, restype)	\
	[wasm_opcode] = _translate_##name,

/* Translate Simple - refers to the function generated by the earlier TS */
#define TS(wasm_opcode, sm_instr, argtypes, restype)	\
	TC(wasm_opcode, sm_instr, dummy, dummy)

/* Translate load/store - refers to the function generated by the earlier TLS */
#define TLS(wasm_opcode, sm_instr, argtypes, restype)	\
	TC(wasm_opcode, sm_instr, dummy, dummy)

/*
 * The actual array of function pointers is defined here. Opcodes absent from
 * the xmacro file are left NULL, which translate_instr() reports as unknown.
 */
static int (*translation_routines[256]) (struct translation *) = {
#include "translate_xmacro.h"
};

#undef TS
#undef TLS
#undef TC

/** REST OF THE CODE **/

static int translate_instr(struct translation *data, uint8_t wasm_opcode)
{
	struct typecheck *tc_routines;

	if (!translation_routines[wasm_opcode]) {
		PRERR("Unknown Wasm opcode: 0x%02x\n", wasm_opcode);
		return -1;
	}

	tc_routines = typecheck_routines + wasm_opcode;

	return
		tc_routines->argcheck(&data->types_stack) ||
		translation_routines[wasm_opcode](data) ||
		tc_routines->rescheck(&data->types_stack);
}

static int parse_blocktype(FILE *handle, struct resulttype *args,
			   struct resulttype *results, char *storage,
			   struct module *module)
{
	int readval;
	uint32_t typeidx;

	readval = fgetc(handle);

	if (readval == EOF) {
		PRERR(MSG_EOF);
		return -1;
	}

	if (readval == 0x40) {
		/* Blocktype is empty (no arguments, no result values) */
		*args = (struct resulttype) {.count = 0, .types = NULL};
		*results = *args;
		return 0;
	}

	/*
	 * A nonnegative array index encoded as signed number in LEB
	 * shall have 0 as the second (most significant) bit of the first byte.
	 * Otherwise, it can't be array index, but might be a simple value type.
	 */
	if (readval & (1 << 6)) {
		if (!is_valid_valtype(readval))
			goto fail;

		*args = (struct resulttype) {.count = 0, .types = NULL};
		*storage = readval;
		*results = (struct resulttype) {.count = 1, .types = storage};
		return 0;
	}

	/*
	 * We know for sure it's a nonnegative number, we can just use leb_u32
	 * decoding function (encoding as signed or unsigned is the same in this
	 * particular case).
	 */
	ungetc(readval, handle);

	if (leb_u32(handle, &typeidx)) {
		PRERR(MSG_BAD_NUM);
		goto fail;
	}

	if (typeidx >= module->functypes_count) {
		PRERR(MSG_BAD_IDX("type index"));
		goto fail;
	}

	*args = module->functypes[typeidx].args;
	*results = module->functypes[typeidx].results;
	return 0;

fail:
	PRERR("Couldn't parse blocktype\n");
	return -1;
}

static struct target *add_target(struct module *module)
{
	struct target *tgt;

	tgt = malloc(sizeof(struct target));

	if (!tgt) {
		PRERR(MSG_ALLOC_FAIL(sizeof(struct target)));
		return NULL;
	}

	tgt->instr = NULL;
	tgt->prev = module->targets;
	module->targets = tgt;
	return tgt;
}

/*
 * Allocate a label and make it the innermost one in data->labels.
 *
 * The label's values_on_stack is the enclosing label's value (0 if there is
 * none) plus the current size of the types stack, minus values_popped.
 *
 * Returns the new label or NULL on allocation failure.
 */
static struct label *add_label(struct translation *data, struct target *target,
			       uint32_t values_popped,
			       struct resulttype *arity)
{
	struct label *fresh = malloc(sizeof *fresh);
	uint32_t depth;

	if (fresh == NULL) {
		PRERR(MSG_ALLOC_FAIL(sizeof(struct label)));
		return NULL;
	}

	depth = data->labels ? data->labels->values_on_stack : 0;
	depth += stack_size(data->types_stack);
	depth -= values_popped;

	fresh->prev = data->labels;
	fresh->target = target;
	fresh->arity = arity;
	fresh->values_on_stack = depth;

	data->labels = fresh;

	return fresh;
}

/*
 * Translate a sequence of instructions read from data->handle, up to the
 * first opcode listed in end_markers.
 *
 * args and results describe the values the expression consumes and produces
 * (either may be NULL, meaning none). A label is in scope for the duration of
 * the translation; branches to it jump to the start of the expression when
 * continuation_at_start is true and past its end otherwise. If marker_found
 * is not NULL, the terminating opcode is stored there.
 *
 * Returns 0 on success, -1 on failure.
 */
static int translate_expr(struct translation *data, struct resulttype *args,
			  struct resulttype *results,
			  const struct end_markers *end_markers,
			  char *marker_found, bool continuation_at_start)
{
	struct target *continuation;
	struct label *label = NULL;
	struct types **tmp, *types_stack_rest = NULL;
	uint32_t i, rescount = results ? results->count : 0;
	int wasm_opcode;
	bool last_instruction_was_branch = false;
	int retval = -1;

	continuation = add_target(data->module);

	if (!continuation)
		goto fail;

	if (continuation_at_start)
		continuation->instr = data->function->translated_body->prev;

	/* A branch to the start has to supply args, one to the end - results */
	label = add_label(data, continuation, args ? args->count : 0,
			  continuation_at_start ? args : results);

	if (!label)
		goto fail;

	/* Verify arguments are on stack, walking down from the top */
	tmp = &data->types_stack;
	i = args ? args->count : 0;

	while (i--) {
		if (argcheck_generic_noremove(*tmp, args->types[i]))
			goto fail;

		tmp = &(*tmp)->prev;
	}

	/*
	 * Cut the stack just below the arguments, so that the expression only
	 * sees its own values; the rest is reattached at the end.
	 */
	types_stack_rest = *tmp;
	*tmp = NULL;

	while (1) {
		wasm_opcode = fgetc(data->handle);

		if (wasm_opcode == EOF) {
			PRERR(MSG_EOF);
			goto fail;
		}

		i = end_markers->count;

		while (i--) {
			if (wasm_opcode == end_markers->codes[i])
				goto block_end;
		}

		if (translate_instr(data, wasm_opcode))
			goto fail;

		/* WASM_BR_TABLE will also appear here once implemented */
		last_instruction_was_branch =
			wasm_opcode == WASM_BR;
	}

block_end:
	if (marker_found)
		*marker_found = wasm_opcode;

	if (!continuation_at_start)
		continuation->instr = data->function->translated_body->prev;

	/*
	 * Types on stack don't seem to matter, if last instruction was an
	 * unconditional branch anyway. However, we need to make the types stack
	 * appear ok to our caller.
	 */
	if (last_instruction_was_branch) {
		put_type(data->types_stack);
		data->types_stack = NULL;

		for (i = 0; i < rescount; i++) {
			if (rescheck_generic(&data->types_stack,
					     results->types[i]))
				goto fail;
		}
	}

	/* Verify results are on stack and that there is nothing below them */
	tmp = &data->types_stack;
	i = results ? results->count : 0;

	while (i--) {
		if (argcheck_generic_noremove(*tmp, results->types[i]))
			goto fail;

		tmp = &(*tmp)->prev;
	}

	if (*tmp) {
		PRERR("Expression produces too many result values\n");
		goto fail;
	}

	/* Reattach the part of the stack that was cut off at the beginning */
	*tmp = types_stack_rest;
	types_stack_rest = NULL;
	retval = 0;

fail:
	/* Reached on success as well - the label goes out of scope either way */
	if (label) {
		data->labels = label->prev;
		free(label);
	}

	/* Only non-NULL here on failure after the stack has been cut */
	if (types_stack_rest) {
		put_type(data->types_stack);
		data->types_stack = types_stack_rest;
	}

	return retval;
}

/*
 * Translate the body of `function`, read from `handle`, into stack machine
 * instructions appended to function->translated_body.
 *
 * Emitted code consists of a prologue (zero-filled slots and a frame setup
 * using STACK_FRAME_BACKUP_ADDR), the translated body and an epilogue that
 * restores the previous frame, leaves the result (if any) for the caller and
 * returns.
 *
 * Returns 0 on success. On failure returns -1, frees whatever code was
 * emitted and sets function->translated_body to NULL.
 */
int translate(FILE *handle, struct function *function, struct module *module)
{
	struct instruction **expr = &function->translated_body;
	uint32_t args_count = function->type->args.count;
	uint32_t locals_count = function->locals_count;
	uint32_t all_locals_count = args_count + locals_count;
	size_t i;
	static const uint8_t function_end_marker_code = WASM_END;
	static const struct end_markers function_end_markers = {
		.count = 1,
		.codes = &function_end_marker_code
	};
	struct translation data = {.handle = handle,
				   .function = function,
				   .module = module,
				   .types_stack = NULL,
				   .labels = NULL};
	int retval = -1;

	/*
	 * Every local occupies a 4-byte slot (frame offsets are multiples
	 * of 4), so the number of slots that fit below STACK_TOP_ADDR is that
	 * address divided by 4. The sum is computed in 64 bits, because
	 * all_locals_count may have wrapped around.
	 */
	if (locals_count + (uint64_t) args_count > STACK_TOP_ADDR / 4) {
		PRERR("Too many locals in a function\n");
		goto fail;
	}

	/* Zero-initialized slots for declared locals plus 3 extra ones */
	for (i = locals_count + 3; i; i--) {
		if (i_const(im(0), expr))
			goto fail;
	}

	/* function prologue */
	if (i_get_frame(                             expr) ||
	    i_tee      (                             expr) ||
	    i_load     (im(STACK_FRAME_BACKUP_ADDR), expr) ||
	    i_store_p  (im(0x0),                     expr) ||
	    i_store    (im(STACK_FRAME_BACKUP_ADDR), expr))
		goto fail;

	/* actual function body */
	if (translate_expr(&data, NULL, &function->type->results,
			   &function_end_markers, NULL, false))
		goto fail;

	/* function epilogue */
	if (i_load   (im(STACK_FRAME_BACKUP_ADDR),        expr))
		goto fail;

	if (function->type->results.count) {
		if (i_swap   (                                    expr) ||
		    i_store_p(im(4 * (2 + all_locals_count - 1)), expr) ||
		    i_load   (im(STACK_FRAME_BACKUP_ADDR),        expr) ||
		    i_tee    (                                    expr) ||
		    i_tee    (                                    expr) ||
		    i_load_p (im(4 * (1 + locals_count)),         expr) ||
		    i_store_p(im(4 * (2 + all_locals_count - 2)), expr) ||
		    i_load_p (im(0),                              expr) ||
		    i_store  (im(STACK_FRAME_BACKUP_ADDR),        expr))
			goto fail;
	} else {
		/* It's a bit shorter if we don't return anything */
		if (i_tee    (                                    expr) ||
		    i_tee    (                                    expr) ||
		    i_load_p (im(4 * (1 + locals_count)),         expr) ||
		    i_store_p(im(4 * (2 + all_locals_count - 1)), expr) ||
		    i_load_p (im(0),                              expr) ||
		    i_store  (im(STACK_FRAME_BACKUP_ADDR),        expr))
			goto fail;
	}

	/* One more slot has to be dropped when no result stays on stack */
	i = locals_count + args_count + 2;

	if (!function->type->results.count)
		i++;

	while (i--) {
		if (i_drop(expr))
			goto fail;
	}

	if (i_ret(expr))
		goto fail;

	retval = 0;

fail:
	/* Reached on success as well; the types stack is no longer needed */
	free_types_stack(data.types_stack);

	if (!retval)
		return retval;

	PRERR("Couldn't translate function to stack machine\n");

	free_expr(*expr);
	function->translated_body = NULL;

	return retval;
}