From 35a201cc8ef0c3f5b2df88d2e528aabee1048348 Mon Sep 17 00:00:00 2001 From: Wojtek Kosior Date: Fri, 30 Apr 2021 18:47:09 +0200 Subject: Initial/Final commit --- restore.c | 1705 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1705 insertions(+) create mode 100644 restore.c (limited to 'restore.c') diff --git a/restore.c b/restore.c new file mode 100644 index 0000000..2889655 --- /dev/null +++ b/restore.c @@ -0,0 +1,1705 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +#ifdef _WIN32 + +#include + +#else + +#include + +#endif + +#include "hashtable.h" +#include "string_buf.h" + +#define X(type, name) type, +#define Y(type, name) X(type, name) + +enum field_type { +#include "field_types.c" +}; + +const enum field_type field_type_ids[] = { +#include "field_types.c" +}; + +#undef X +#undef Y + +#define X(type, name) [type] = #type, +#define Y(type, name) X(type, name) + +static const char *const field_enum_names[] = { +#include "field_types.c" +}; + +#undef X +#undef Y + +#define X(type, name) [type] = name, +#define Y(type, name) X(type, name) + +static const char *const field_names[] = { +#include "field_types.c" +}; + +#undef X +#undef Y + +static inline bool is_rel_type(enum field_type type) +{ + return type == ONE_TO_ONE_REL || type == MANY_TO_ONE_REL || + type == MANY_TO_MANY_REL || type == BROKEN_TO_ONE_REL || + type == BROKEN_TO_MANY_REL; +} + +static inline bool is_text_type(enum field_type type) +{ + return type == CHAR_FIELD || type == FILE_FIELD || type == TEXT_FIELD; +} + +/* All keys and values are static, ht_destroy() is enough to free this. */ +static int make_field_names_ht(hashtable_t *ht) +{ + int ret; + int i; + + ret = ht_string_init(ht); + if (ret) + return ret; + + for (i = 0; i < sizeof(field_type_ids) / sizeof(*field_type_ids); i++) { + ret = ht_add(ht, field_names[i], &field_type_ids[i]); + if (ret) { + ht_destroy(ht); + break; + } + } + + return ret; +} + +inline static void *create_struct_with_string(size_t minimum_size, + size_t char_field_offset, + const char *string) +{ + size_t string_size = strlen(string) + 1; + size_t dynamic_size = char_field_offset + string_size; + char *mem; + + mem = malloc(dynamic_size > minimum_size ? dynamic_size : minimum_size); + if (mem) + memcpy(mem + char_field_offset, string, string_size); + + return mem; +} + +#define STRUCT_WITH_STRING_CREATOR(fun_name, type, char_field) \ + static type *fun_name(const char *string) \ + { \ + return (type*) create_struct_with_string \ + (sizeof(type), \ + (size_t) ((type*) 0)->char_field, \ + string); \ + } + +struct field { + struct field *next; + union field_data { + size_t max_length; + char *target_table; /* NULL means mtm table creation failure */ + bool array_has_data; + } spec_data; + uint8_t type; + char name[1]; +}; + +STRUCT_WITH_STRING_CREATOR(create_field, struct field, name) + +struct table { + struct field *fields; + size_t all_rows; + size_t inserted_rows; + size_t insertion_errors; + bool creation_error; + char name[1]; +}; + +STRUCT_WITH_STRING_CREATOR(create_table, struct table, name) + +struct restore { + size_t errors_count; + enum { + PASS_1 = 1, + PASS_2 = 2 + } pass; + int nesting; + bool fatal_error; + bool creating_table; + struct table *current_table; + struct field *current_field; + + char *characters_buf; + size_t characters_size; + size_t characters_count; /* in PASS_1 used without _buf and _size */ + + char *pk_buf; + size_t pk_size; + size_t pk_filled; + + char *stmt_buf; + size_t stmt_size; + size_t stmt_filled; + + hashtable_t field_names; + + hashtable_t tables; + + MYSQL *mysql; + + FILE *raport_stream; +}; + +/* + * I'm not sure how I should interpret the 'bigger' and 'smaller' when it comes + * to chars, that may or may not be signed. Anyways, no need to be coherent with + * the original version because in practice this function is only ever used + * to check equality. + */ +static int strcmp_replacing_dot(const char *s1, const char *s2) +{ + size_t i = 0; + int c1, c2; + + while (true) { + c1 = (int) (unsigned char) s1[i]; + c2 = (int) (unsigned char) s2[i]; + + if (c1 == '.') + c1 = '_'; + if (c2 == '.') + c2 = '_'; + + if (c1 != c2) + return c1 - c2; + + if (!c1) + return 0; + + i++; + } +} + +static size_t string_hash_replacing_dot(const void *key_) +{ + size_t i = 0, hash = (size_t) 0xa1bad2dead3beef4; + char c, shift; + const char *key = key_; + + do { + c = key[i++]; + + if (c == '.') + c = '_'; + + shift = ((unsigned char) c) % sizeof(size_t); + hash += ((hash >> shift) | + (hash << (sizeof(size_t) - shift))) ^ c; + } while (c); + + return hash; +} + +static void replace_dot(char *s) +{ + do { + if (*s == '.') + *s = '_'; + } while (*(s++)); +} + +static int init_restore_data(struct restore *data, const char *raport_path) +{ + data->errors_count = 0; + data->pass = PASS_1; + data->nesting = 0; + data->fatal_error = false; + data->creating_table = false; + data->current_table = NULL; + data->current_field = NULL; + + data->characters_buf = NULL; + data->characters_size = 0; + data->characters_count = 0; + + data->stmt_buf = NULL; + data->stmt_size = 0; + data->stmt_filled = 0; + + data->pk_buf = NULL; + data->pk_size = 0; + data->pk_filled = 0; + + if (make_field_names_ht(&data->field_names)) + return -1; + + if (ht_init(&data->tables, string_hash_replacing_dot, + (int (*)(const void*, const void*)) strcmp_replacing_dot)) { + ht_destroy(&data->field_names); + return -1; + } + + data->raport_stream = raport_path ? fopen(raport_path, "w+") : NULL; + + if (!data->raport_stream) { + if (raport_path) { + fprintf(stderr, + "Nie udalo sie utworzyc pliku raportu\n\n"); + } + data->raport_stream = stdout; + } + + data->mysql = NULL; + + return 0; +} + +static void free_field(struct field *field) +{ + if (!field) + return; + + if (is_rel_type(field->type)) + free(field->spec_data.target_table); + free(field); +} + +static void free_table(struct table *table) +{ + struct field *field, *tmp; + + if (!table) + return; + + field = table->fields; + while (field) { + tmp = field->next; + free_field(field); + field = tmp; + } + + free(table); +} + +static void free_table_entry(void *key, void *val, void *dummy_arg) +{ + free_table(val); +} + +static void destroy_restore_data(struct restore *data) +{ + ht_destroy(&data->field_names); + ht_map_destroy(&data->tables, NULL, free_table_entry); + free(data->characters_buf); + free(data->stmt_buf); + free(data->pk_buf); + if (data->raport_stream != stderr) + fclose(data->raport_stream); +} + +static void out_of_memory_error(struct restore *data) +{ + /* + * TODO: dodac informacje na jakim etapie dzialania stanelo + * i wypisać jakies informacje o dotychczas zrobionych rzeczach. + */ + data->fatal_error = true; + + fprintf(data->raport_stream, "Brak pamieci.\n"); +} + +static void semantic_error(struct restore *data, const char *msg, ...) +{ + va_list ap; + + data->fatal_error = true; + + fprintf(data->raport_stream, "Blad semantyczny xml: "); + va_start(ap, msg); + vfprintf(data->raport_stream, msg, ap); + va_end(ap); + fputc('\n', data->raport_stream); +} + +struct sb_mysql_arg { + MYSQL *mysql; + const char *string; +}; + +static int sb_mysql_escape(char **buf, size_t *buf_len, size_t *buf_filled, + void *arg_) +{ + struct sb_mysql_arg *arg = arg_; + size_t len = strlen(arg->string); + + if (extend_buf(buf, buf_len, buf_filled, len * 2)) + return -1; + + *buf_filled += mysql_real_escape_string(arg->mysql, *buf + *buf_filled, + arg->string, len); + + return 0; +} + +static int sb_name_escape(char **buf, size_t *buf_len, size_t *buf_filled, + void *name_) +{ + const unsigned char *name = name_; + size_t i; + + for (i = 0; name[i]; i++) { + if (name[i] == '`') { + if (sb_bytes(buf, buf_len, buf_filled, name, i + 1)) + return -1; + name += i; + i = 0; + } + } + + return sb_bytes(buf, buf_len, buf_filled, name, i); +} + +static bool is_field_compatible(struct restore *data, const char *name, + enum field_type type, const char *target_table) +{ + struct field *previous = data->current_field; + + return + !strcmp(previous->name, (const char*) name) && + previous->type == type && + (!is_rel_type(type) || + !strcmp_replacing_dot(target_table, + previous->spec_data.target_table)); +} + +static void new_field(struct restore *data, const char *name, + enum field_type type, const char *target_table) +{ + struct field *field = NULL, **dst; + size_t table_name_size; + + field = create_field(name); + if (!field) + goto fail; + + field->next = NULL; + field->type = type; + if (is_rel_type(type)) { + table_name_size = strlen(target_table) + 1; + field->spec_data.target_table = malloc(table_name_size); + if (!field->spec_data.target_table) + goto fail; + memcpy(field->spec_data.target_table, + target_table, table_name_size); + replace_dot(field->spec_data.target_table); + } else if (type == ARRAY_FIELD) { + field->spec_data.array_has_data = false; + } else { + field->spec_data.max_length = 0; + } + + dst = data->current_field ? &data->current_field->next : + &data->current_table->fields; + *dst = field; + + data->current_field = field; + return; + +fail: + free_field(field); + out_of_memory_error(data); +} + +static void new_table(struct restore *data, const char *table_name) +{ + struct table *table = NULL; + + table = create_table(table_name); + if (!table) + goto fail; + replace_dot(table->name); + + table->creation_error = false; + table->all_rows = 0; + table->inserted_rows = 0; + table->insertion_errors = 0; + table->fields = NULL; + + if (ht_add(&data->tables, table->name, table)) + goto fail; + + data->creating_table = true; + data->current_table = table; + data->current_field = NULL; + return; + +fail: + free_table(table); + out_of_memory_error(data); +} + +static int find_attrs(const char **attrs, int to_find, + const char *const *searched, const char **found) +{ + int i, j; /* `i' iterates through attrs[], `j' through searched */ + int found_count = 0; + + if (!attrs) + return to_find; + + for (j = 0; j < to_find; j++) + found[j] = NULL; + + for (i = 0; attrs[i]; i += 2) { + for (j = 0; j < to_find; j++) { + if (!found[j] && !strcmp(attrs[i], searched[j])) { + found_count++; + found[j] = attrs[i + 1]; + if (to_find == found_count) + return 0; + } + } + } + + return to_find - found_count; +} + +/* + * TODO: change handle_field(), handle_entry() and other functions below and + * above not to call exit(), but rather let control return to main(). + */ + +static void handle_field(struct restore *data, const char *name, + const char **attrs) +{ + const char *searched[4] = {"name", "type", "rel", "to"}; + const char *found[4]; + const char *type_name; + const enum field_type *type; + + data->characters_count = 0; + + if (data->pass == PASS_2) + return; + + if (strcmp(name, "field")) { + semantic_error(data, "tag '%s' zamiast 'field' na poziomie 2", + name); + return; + } + + find_attrs(attrs, 4, searched, found); + + if (!found[0]) { + semantic_error(data, "brak atrybutu 'name' w polu modelu `%s`", + data->current_table->name); + return; + } + + type_name = found[2] ? found[2] : found[1] ? found[1] : NULL; + if (!type_name) { + semantic_error(data, + "brak atrybutu 'type' w polu `%s` modelu `%s`", + found[0], data->current_table->name); + return; + } + + if (ht_get(&data->field_names, type_name, NULL, (void**) &type)) { + fprintf(data->raport_stream, + "nieznany typ '%s' kolumny `%s` tabeli `%s`.\n", + type_name, found[0], data->current_table->name); + type = &field_type_ids[UNKNOWN_FIELD]; + } + + if (is_rel_type(*type) && !found[3]) { + semantic_error(data, + "brak atrybutu 'to' w polu `%s` modelu `%s`", + found[0], data->current_table->name); + return; + } + + if (data->creating_table) { + new_field(data, found[0], *type, found[3]); + } else if (!data->current_field) { + semantic_error(data, + "zmienna liczba pol w rekordach modelu `%s`", + data->current_table->name); + return; + } else if (!is_field_compatible(data, found[0], *type, found[3])){ + semantic_error(data, "niekompatybilne rekordy modelu `%s`", + data->current_table->name); + return; + } + + return; +} + +static void handle_entry(struct restore *data, const char *name, + const char **attrs) +{ + const char *searched[2] = {"model", "pk"}; + const char *found[2]; + struct table *table; + struct sb_mysql_arg sb_arg; + int ret; + + if (strcmp(name, "object")) { + semantic_error(data, "tag '%s' zamiast 'object' na poziomie 1", + name); + return; + } + + find_attrs(attrs, 2, searched, found); + + if (!found[0]) { + semantic_error(data, "brak atrybutu 'model' w obiekcie"); + return; + } + if (!found[1]) { + semantic_error(data, "brak atrybutu 'pk' w obiekcie"); + return; + } + + ret = ht_get(&data->tables, found[0], NULL, (void**) &table); + + if (ret == HT_KEY_ABSENT) { + switch (data->pass) { + case PASS_1: + new_table(data, found[0]); + break; + case PASS_2: + goto fail; + } + } else { + data->creating_table = false; + if (data->pass == PASS_2 && data->current_table != table) + printf("Wypelnianie tabeli `%s`\n", table->name); + data->current_table = table; /* not really needed in PASS_1 */ + data->current_field = table->fields; + } + + if (data->pass != PASS_2) + return; + + data->pk_filled = 0; + if (sb_string(&data->pk_buf, &data->pk_size, + &data->pk_filled, found[1])) + goto fail; + + data->stmt_filled = 0; + sb_arg = (struct sb_mysql_arg) {data->mysql, found[1]}; + if (sb_sprintf(&data->stmt_buf, &data->stmt_size, + &data->stmt_filled, + "INSERT INTO `%_` VALUES(\n" + "'%_'", + sb_name_escape, table->name, sb_mysql_escape, &sb_arg)) + goto fail; + + return; + +fail: + data->fatal_error = true; +} + +static int insert_to_connector_table(struct restore *data, const char **attrs) +{ + int res = -1; + struct field *field = data->current_field; + const char *tab1 = data->current_table->name; + const char *tab2 = field->spec_data.target_table; + const char *pk_attr = "pk"; + const char *pk; + char *stmt_buf = NULL; + size_t stmt_size = 0, stmt_filled = 0; + struct sb_mysql_arg sb_arg1, sb_arg2; + + if (!tab2) /* Means error during creation of conenctor table */ + return 0; + + find_attrs(attrs, 1, &pk_attr, &pk); + if (!pk) { + semantic_error(data, + "Brak atrybutu 'pk' w tagu w relacji wiele-do-wielu.\n"); + goto fail; + } + + sb_arg1 = (struct sb_mysql_arg) {data->mysql, data->pk_buf}; + sb_arg2 = (struct sb_mysql_arg) {data->mysql, pk}; + if (sb_sprintf(&stmt_buf, &stmt_size, &stmt_filled, + "INSERT INTO `%__mtm_%_` VALUES(\n" + "'%_',\n" + "'%_');", + sb_name_escape, tab1, sb_name_escape, tab2, + sb_mysql_escape, &sb_arg1, + sb_mysql_escape, &sb_arg2)) + goto fail; + + if (mysql_real_query(data->mysql, stmt_buf, stmt_filled)) { + data->errors_count++; + if (!data->current_table->insertion_errors++) { + fprintf(data->raport_stream, + "Nie udalo sie wstawic elementu do tabeli lacznikowej:\n" + "Blad %u (%s): %s\n", + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql)); + } + } + + res = 0; + +fail: + free(stmt_buf); + return res; +} + +static void startElement_callback(void *user_data, const xmlChar *name_, + const xmlChar **attrs_) +{ + const char *name = (const char*) name_; + const char **attrs = (const char**) attrs_; + struct restore *data = user_data; + + if (data->fatal_error) + return; + + switch(data->nesting++) { + case 0: + /* We have no interest in the root element */ + return; + case 1: + handle_entry(data, name, attrs); + break; + case 2: + handle_field(data, name, attrs); + break; + case 3: + if (data->pass == PASS_2 && + (data->current_field->type == MANY_TO_MANY_REL || + data->current_field->type == BROKEN_TO_MANY_REL) && + insert_to_connector_table(data, attrs)) + goto fail; + default: + break; + } + + return; + +fail: + data->fatal_error = true; +} + +static int insert_CHAR_FIELD(struct restore *data) +{ + struct sb_mysql_arg sb_arg; + + if (!data->characters_count && + sb_string(&data->stmt_buf, &data->stmt_size, + &data->stmt_filled, ",\nNULL")) + return -1; + + sb_arg = (struct sb_mysql_arg) + {data->mysql, data->characters_buf}; + if (data->characters_count && + sb_sprintf(&data->stmt_buf, &data->stmt_size, + &data->stmt_filled, + ",\n'%_'", + sb_mysql_escape, &sb_arg)) + return -1; + + return 0; +} + +#define insert_FILE_FIELD insert_CHAR_FIELD +#define insert_TEXT_FIELD insert_CHAR_FIELD +#define insert_FILE_FIELD insert_CHAR_FIELD +#define insert_INT_FIELD insert_CHAR_FIELD +#define insert_SMALL_INT_FIELD insert_CHAR_FIELD +#define insert_POSITIVE_INT_FIELD insert_CHAR_FIELD +#define insert_BIG_INT_FIELD insert_CHAR_FIELD +#define insert_DECIMAL_FIELD insert_CHAR_FIELD +#define insert_TIME_FIELD insert_CHAR_FIELD +#define insert_DATETIME_FIELD insert_CHAR_FIELD +#define insert_DATE_FIELD insert_CHAR_FIELD +#define insert_JSON_FIELD insert_CHAR_FIELD +#define insert_ONE_TO_ONE_REL insert_CHAR_FIELD +#define insert_MANY_TO_ONE_REL insert_CHAR_FIELD +#define insert_BROKEN_TO_ONE_REL insert_CHAR_FIELD +#define insert_UNKNOWN_FIELD insert_CHAR_FIELD + +static int insert_BOOL_FIELD(struct restore *data) +{ + const char *value; + + if (data->characters_count && strstr(data->characters_buf, "False")) + value = ",\nFALSE"; + else if (data->characters_count && strstr(data->characters_buf, "True")) + value = ",\nTRUE"; + else + value = ",\nNULL"; + + if (sb_string(&data->stmt_buf, &data->stmt_size, + &data->stmt_filled, value)) + return -1; + + return 0; +} + +static int insert_array_element(struct restore *data, const char *element, + int *current_id) +{ + int res = -1; + char *stmt_buf = NULL; + size_t stmt_size = 0, stmt_filled = 0; + struct sb_mysql_arg sb_arg1, sb_arg2; + + sb_arg1 = (struct sb_mysql_arg) {data->mysql, data->pk_buf}; + sb_arg2 = (struct sb_mysql_arg) {data->mysql, element}; + if (sb_sprintf(&stmt_buf, &stmt_size, &stmt_filled, + "INSERT INTO `%__array_%_` VALUES(\n" + "%d,\n" + "'%_',\n" + "'%_');", + sb_name_escape, data->current_table->name, + sb_name_escape, data->current_field->name, + (*current_id)++, + sb_mysql_escape, &sb_arg1, + sb_mysql_escape, &sb_arg2)) + goto fail; + + if (mysql_real_query(data->mysql, stmt_buf, stmt_filled)) { + data->errors_count++; + if (!data->current_table->insertion_errors++) { + fprintf(data->raport_stream, + "Nie udalo sie wstawic elementu danej ARRAY:\n" + "Blad %u (%s): %s\n", + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql)); + } + } + + res = 0; + +fail: + free(stmt_buf); + return res; +} + +static int insert_ARRAY_FIELD(struct restore *data) +{ + char *pos, *string_start; + unsigned char c; + int current_id = 0; + enum { + NOT_STARTED, + BEFORE_FIRST_STRING, + BEFORE_STRING, + IN_STRING, + AFTER_STRING, + FINISHED + } state = NOT_STARTED; + + for (pos = data->characters_buf; *pos; pos++) { + c = *pos; + + switch (state) { + case NOT_STARTED: + if (isspace(c)) + break; + if (c == '[') { + state = BEFORE_FIRST_STRING; + break; + } + goto syntax_fail; + case BEFORE_FIRST_STRING: + case BEFORE_STRING: + if (isspace(c)) + break; + if (c == '"') { + state = IN_STRING; + string_start = pos + 1; + break; + } + if (c == ']' && state == BEFORE_FIRST_STRING) { + state = FINISHED; + break; + } + goto syntax_fail; + case IN_STRING: + if (c != '"' || pos[-1] == '\\') + break; + *pos = '\0'; + if (insert_array_element(data, string_start, + ¤t_id)) + goto fail; + state = AFTER_STRING; + break; + case AFTER_STRING: + if (isspace(c)) + break; + if (c == ',') { + state = BEFORE_STRING; + break; + } + if (c == ']') { + state = FINISHED; + break; + } + goto syntax_fail; + case FINISHED: + if (!isspace(c)) + goto syntax_fail; + } + } + + if (state != FINISHED) + goto fail; + + return 0; + +syntax_fail: + fprintf(data->raport_stream, + "Blad syntaktyczny w danej typu array:\n%s\n", + data->characters_buf); +fail: + return -1; +} + +static int no_insert_field(struct restore *data) +{ + return 0; +} + +/* Deliberately omitted */ +#define insert_EMPTY_ARRAY_FIELD no_insert_field + +/* Inserts to connector tables happen elsewhere */ +#define insert_MANY_TO_MANY_REL no_insert_field +#define insert_BROKEN_TO_MANY_REL no_insert_field + +#define X(type, name) [type] = insert_##type, +#define Y(type, name) X(type, name) + +static int (*field_inserters[])(struct restore *) = { +#include "field_types.c" +}; + +#undef X +#undef Y + +static int finalize_insert(struct restore *data) +{ + if (sb_string(&data->stmt_buf, &data->stmt_size, + &data->stmt_filled, ");")) + return -1; + + data->current_table->all_rows++; + if (mysql_real_query(data->mysql, data->stmt_buf, data->stmt_filled)) { + data->errors_count++; + if (!data->current_table->insertion_errors++) { + fprintf(data->raport_stream, + "Nie udalo sie wstawic rekordu do bazy:\n" + "Blad %u (%s): %s\n" + "Zapytanie:\n" + "%s\n", + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql), + data->stmt_buf); + } + } else { + data->current_table->inserted_rows++; + } + + return 0; +} + +static void endElement_callback(void *user_data, const xmlChar *name) +{ + struct restore *data = user_data; + struct field *field = data->current_field; + + if (data->fatal_error) + return; + + switch (data->nesting--) { + case 3: + break; + case 2: + if (data->pass == PASS_2 && + !data->current_table->creation_error && + finalize_insert(data)) + goto fail; + default: + return; + } + + switch (data->pass) { + case PASS_1: + if (!field) + goto out; + + if (is_text_type(field->type) && + field->spec_data.max_length < data->characters_count) + field->spec_data.max_length = data->characters_count; + + if (field->type == ARRAY_FIELD && + strchr(data->characters_buf, '"')) + field->spec_data.array_has_data = true; + + break; + case PASS_2: + if (!field) + goto fail; + + if (field_inserters[field->type](data)) + goto fail; + } + + if (!data->creating_table) + data->current_field = field->next; + +out: + return; + +fail: + data->fatal_error = true; +} + +static void characters_callback(void *user_data, const xmlChar *characters, + int len) +{ + struct restore *data = user_data; + + if (data->fatal_error) + return; + + /* Should not happen, but... */ + if (!data->current_field) { + if (data->pass == PASS_1) + fprintf(data->raport_stream, + "Znaki poza tagiem 'field': '%.*s'\n", + len, characters); + return; + } + + switch (data->pass) { + case PASS_1: + if (data->current_field->type != ARRAY_FIELD) { + data->characters_count += len; + break; + } + case PASS_2: + if (sb_bytes(&data->characters_buf, &data->characters_size, + &data->characters_count, characters, len)) + data->fatal_error = true; + } +} + +static void error_callback(void *user_data, const char *msg, ...) +{ + struct restore *data = user_data; + va_list ap; + int c; + + if (data->fatal_error) + return; + + fprintf(data->raport_stream, "Blad xml: "); + va_start(ap, msg); + vfprintf(data->raport_stream, msg, ap); + va_end(ap); + if (!data->errors_count++) { + fprintf(stderr, "Blad syntaktyczny xml, kontynuowac? y/N: "); + fflush(stderr); + c = getchar(); + if (c != 'y' && c != 'Y') + data->fatal_error = true; + } +} + +static int clear_database(struct restore *data, const char *dbname) +{ + int res = -1; + char *stmt_buf = NULL; + size_t buf_len = 0, buf_filled = 0; + struct sb_mysql_arg sb_arg; + MYSQL_RES *query_res = NULL; + size_t i; + my_ulonglong rows; + MYSQL_ROW *rows_array = NULL; + + sb_arg = (struct sb_mysql_arg) {data->mysql, dbname}; + if (sb_sprintf(&stmt_buf, &buf_len, &buf_filled, + "SELECT table_name " + "FROM information_schema.tables " + "WHERE table_schema = '%_';", + sb_mysql_escape, &sb_arg)) + goto fail; + + if (mysql_real_query(data->mysql, stmt_buf, buf_filled)) { + fprintf(data->raport_stream, + "Nie udalo sie pobrac nazw istniejacych tabel:\n" + "Blad %u (%s): %s\n" + "Zapytanie:\n" + "%s\n", + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql), stmt_buf); + goto fail; + } + + query_res = mysql_store_result(data->mysql); + if (!query_res) + goto fail; + + rows = mysql_num_rows(query_res); + + rows_array = malloc(sizeof(MYSQL_ROW) * rows); + if (!rows_array) + goto fail; + + for (i = 0; i < rows; i++) { + rows_array[i] = mysql_fetch_row(query_res); + if (!rows_array[i]) + goto fail; + } + + for (i = 0; i < rows; i++) { + printf("Usuwanie tabeli `%s`.\n", rows_array[i][0]); + + buf_filled = 0; + + /* Assume fetched table names are not malicious */ + if (sb_sprintf(&stmt_buf, &buf_len, &buf_filled, + "DROP TABLE IF EXISTS `%_`;", + sb_name_escape, rows_array[i][0])) + goto fail; + + if (mysql_real_query(data->mysql, stmt_buf, buf_filled)) { + fprintf(data->raport_stream, + "Nie udalo sie usunac instniejacej tabeli:\n" + "Blad %u (%s): %s\n", + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql)); + goto fail; + } + } + + putchar('\n'); + res = 0; + +fail: + free(rows_array); + if (query_res) + mysql_free_result(query_res); + free(stmt_buf); + return res; +} + +static int sb_field_name(char **stmt_buf, size_t *buf_len, + size_t *buf_filled, struct field *field) +{ + if (sb_sprintf(stmt_buf, buf_len, buf_filled, + ",\n" + "`%_` ", + sb_name_escape, field->name)) + return -1; + + return 0; +} + +#define SIMPLE_FIELD_DECLARATOR(type, sql_name) \ + static int sb_##type(char **stmt_buf, size_t *buf_len, \ + size_t *buf_filled, struct field *field) \ + { \ + if (sb_field_name(stmt_buf, buf_len, \ + buf_filled, field) || \ + sb_string(stmt_buf, buf_len, buf_filled, sql_name)) \ + return -1; \ + \ + return 0; \ + } + +SIMPLE_FIELD_DECLARATOR(TEXT_FIELD, "TEXT") +SIMPLE_FIELD_DECLARATOR(INT_FIELD, "INT") +SIMPLE_FIELD_DECLARATOR(SMALL_INT_FIELD, "SMALLINT") +SIMPLE_FIELD_DECLARATOR(POSITIVE_INT_FIELD, "INT UNSIGNED") +SIMPLE_FIELD_DECLARATOR(BIG_INT_FIELD, "BIGINT") +SIMPLE_FIELD_DECLARATOR(DECIMAL_FIELD, "DECIMAL") +SIMPLE_FIELD_DECLARATOR(BOOL_FIELD, "BOOL") +SIMPLE_FIELD_DECLARATOR(TIME_FIELD, "TIME") +SIMPLE_FIELD_DECLARATOR(DATETIME_FIELD, "DATETIME") +SIMPLE_FIELD_DECLARATOR(DATE_FIELD, "DATE") +SIMPLE_FIELD_DECLARATOR(JSON_FIELD, "TEXT") +/* We do foreign key constraints later, now we just create an int column */ +SIMPLE_FIELD_DECLARATOR(ONE_TO_ONE_REL, "INT") +SIMPLE_FIELD_DECLARATOR(MANY_TO_ONE_REL, "INT") +SIMPLE_FIELD_DECLARATOR(BROKEN_TO_ONE_REL, "INT") +SIMPLE_FIELD_DECLARATOR(UNKNOWN_FIELD, "TEXT") + +static int sb_CHAR_FIELD(char **stmt_buf, size_t *buf_len, + size_t *buf_filled, struct field *field) +{ + if (sb_field_name(stmt_buf, buf_len, buf_filled, field) || + sb_sprintf(stmt_buf, buf_len, buf_filled, "VARCHAR(%u)", + (unsigned) field->spec_data.max_length)) + return -1; + + return 0; +} + +#define sb_FILE_FIELD sb_CHAR_FIELD + +static int no_sb_field(char **stmt_buf, size_t *buf_len, + size_t *buf_filled, struct field *field) +{ + return 0; +} + +/* Deliberately omitted */ +#define sb_EMPTY_ARRAY_FIELD no_sb_field + +/* Realized in terms of separate tables */ +#define sb_ARRAY_FIELD no_sb_field +#define sb_MANY_TO_MANY_REL no_sb_field +#define sb_BROKEN_TO_MANY_REL no_sb_field + +#define X(type, name) [type] = sb_##type, +#define Y(type, name) X(type, name) + +static int (*field_declarators[])(char**, size_t*, size_t*, struct field*) = { +#include "field_types.c" +}; + +#undef X +#undef Y + +static void create_mysql_table(const void *key, void *val, void *data_) +{ + bool fatal_error = true; + struct table *table = val; + struct field *field; + struct restore *data = data_; + char *stmt_buf = NULL; + size_t buf_len = 0, buf_filled = 0; + + if (data->fatal_error) + return; + + printf("Tworzenie tabeli `%s`.\n", table->name); + + if (sb_sprintf(&stmt_buf, &buf_len, &buf_filled, + "CREATE TABLE `%_` (\n" + "ID INT PRIMARY KEY", + sb_name_escape, table->name)) + goto fail; + + for (field = table->fields; field; field = field->next) { + if (field_declarators[field->type](&stmt_buf, &buf_len, + &buf_filled, field)) + goto fail; + } + + if (sb_string(&stmt_buf, &buf_len, &buf_filled, ");")) + goto fail; + + if (mysql_real_query(data->mysql, stmt_buf, buf_filled)) { + data->errors_count++; + fprintf(data->raport_stream, + "Nie udalo sie utworzyc tabeli `%s`:\n" + "Blad %u (%s): %s\n" + "Zapytanie:\n" + "%s\n", + table->name, + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql), + stmt_buf); + table->creation_error = true; + } + + fatal_error = false; + +fail: + free(stmt_buf); + if (fatal_error) + data->fatal_error = true; +} + +static int create_mysql_tables(struct restore *data) +{ + ht_map(&data->tables, data, create_mysql_table); + if (data->fatal_error) + return -1; + putchar('\n'); + return 0; +} + +static void check_table_relations_correctness(const void *key, void *val, + void *data_) +{ + struct restore *data = data_; + struct table *table = val, *target_table; + struct field *field; + + for (field = table->fields; field; field = field->next) { + if (field->type == ARRAY_FIELD && + !field->spec_data.array_has_data) { + fprintf(data->raport_stream, + "Pole `%s` typu ARRAY tabeli `%s` nie zawiera danych, bedzie ono pominiete w odtworzonej bazie.\n", + field->name, table->name); + field->type = EMPTY_ARRAY_FIELD; + } + + if (!is_rel_type(field->type)) + continue; + + target_table = NULL; + + /* + * I wrote that hashtable. I can assume certain ht api functions + * are re-entrant... Can I? + */ + if (!ht_get(&data->tables, field->spec_data.target_table, + NULL, (void**) &target_table) && + !target_table->creation_error) + continue; + + fprintf(data->raport_stream, + "Relacja %s z tabeli `%s` (pole `%s`) do nie%s tabeli `%s`\n", + field_enum_names[field->type], + table->name, field->name, + target_table ? "dodanej" : "istniejacej", + field->spec_data.target_table); + + field->type = field->type == MANY_TO_MANY_REL ? + BROKEN_TO_MANY_REL : BROKEN_TO_ONE_REL; + } +} + +static void check_tables_relations(struct restore *data) +{ + /* It's a bit of a cheat - assuming I can call ht_get() while mapping */ + ht_map(&data->tables, data, check_table_relations_correctness); + fputc('\n', data->raport_stream); +} + +static void fprint_field(FILE *os, const struct field *field) +{ + fprintf(os, " `%s` %s", field->name, field_enum_names[field->type]); + + if (is_rel_type(field->type)) + fprintf(os, " to `%s`", field->spec_data.target_table); + + if (is_text_type(field->type)) { + fprintf(os, " max %lu", + (unsigned long) field->spec_data.max_length); + } + + fputc('\n', os); +} + +static void fprint_table(FILE *os, const struct table *table) +{ + struct field *field; + + fprintf(os, "Tabela `%s`\n", table->name); + for (field = table->fields; field; field = field->next) + fprint_field(os, field); + + putc('\n', os); +} + +static void fprint_table_entry(const void *key, void *val, void *os) +{ + fprint_table(os, val); +} + +static void fprint_all_tables(FILE *os, hashtable_t *tables) +{ + ht_map(tables, os, fprint_table_entry); +} + +static int add_fk_constraint(struct restore *data, struct table *table, + struct field *field) +{ + int res = -1; + char *stmt_buf = NULL; + size_t buf_len = 0, buf_filled = 0; + + printf("Nakladanie ograniczenia na klucz obcy `%s` w tabeli `%s`.\n", + field->name, table->name); + + if (sb_sprintf(&stmt_buf, &buf_len, &buf_filled, + "ALTER TABLE `%_`\n" + "ADD CONSTRAINT `%__%__FK` FOREIGN KEY (`%_`)\n" + "REFERENCES `%_`(id);", + sb_name_escape, table->name, sb_name_escape, table->name, + sb_name_escape, field->name, sb_name_escape, field->name, + sb_name_escape, field->spec_data.target_table)) + goto out; + + if (mysql_real_query(data->mysql, stmt_buf, buf_filled)) { + data->errors_count++; + fprintf(data->raport_stream, + "Nie udalo sie dodac ograniczenia klucza obcego na kolumne `%s` tabeli `%s`:\n" + "Blad %u (%s): %s\n", + field->name, table->name, + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql)); + } + + res = 0; + +out: + free(stmt_buf); + return res; +} + +static int sb_maybe_references(char **buf, size_t *buf_len, + size_t *buf_filled, void *table_name) +{ + if (!table_name) + return 0; + + return sb_sprintf(buf, buf_len, buf_filled, + " REFERENCES `%_`(id)", sb_name_escape, table_name); +} + +static int create_mysql_mtm_table(struct restore *data, struct table *table, + struct field *field) +{ + int res = -1; + char *stmt_buf = NULL; + size_t buf_len = 0, buf_filled = 0; + char *tab1 = table->name, *tab2 = field->spec_data.target_table; + + printf("Tworzenie tabeli lacznikowej `%s_mtm_%s`.\n", tab1, tab2); + + if (sb_sprintf(&stmt_buf, &buf_len, &buf_filled, + "CREATE TABLE `%__mtm_%_` (\n" + "`%__id` INT%_,\n" + "`%__%__id` INT%_,\n" + "PRIMARY KEY(`%__id`, `%__%__id`));", + sb_name_escape, tab1, sb_name_escape, tab2, + sb_name_escape, tab1, + sb_maybe_references, + table->creation_error ? NULL : tab1, + sb_name_escape, field->name, sb_name_escape, tab2, + sb_maybe_references, + field->type == BROKEN_TO_MANY_REL ? NULL : tab2, + sb_name_escape, tab1, + sb_name_escape, field->name, sb_name_escape, tab2)) + goto fail; + + if (mysql_real_query(data->mysql, stmt_buf, buf_filled)) { + data->errors_count++; + fprintf(data->raport_stream, + "Nie udalo sie utworzyc tabeli lacznikowej `%s_mtm_%s`:\n" + "Blad %u (%s): %s\n" + "Zapytanie:\n" + "%s\n", + tab1, tab2, + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql), + stmt_buf); + field->spec_data.target_table = NULL; + } + + res = 0; + +fail: + free(stmt_buf); + return res; +} + +static int create_mysql_array(struct restore *data, struct table *table, + struct field *field) +{ + int res = -1; + char *stmt_buf = NULL; + size_t buf_len = 0, buf_filled = 0; + + printf("Tworzenie tabeli `%s_array_%s` do zrealizowania pola ARRAY.\n", + table->name, field->name); + + if (sb_sprintf(&stmt_buf, &buf_len, &buf_filled, + "CREATE TABLE `%__array_%_` (\n" + "id_of_element INT,\n" + "`%__id` INT%_,\n" + "element VARCHAR(255),\n" + "PRIMARY KEY(`%__id`, id_of_element));", + sb_name_escape, table->name, sb_name_escape, field->name, + sb_name_escape, table->name, + sb_maybe_references, + table->creation_error ? NULL : table->name, + sb_name_escape, table->name)) + goto fail; + + if (mysql_real_query(data->mysql, stmt_buf, buf_filled)) { + data->errors_count++; + fprintf(data->raport_stream, + "Nie udalo sie utworzyc tabeli `%s_array%s` realizujacej pole typu ARRAY:\n" + "Blad %u (%s): %s\n" + "Zapytanie:\n" + "%s\n", + table->name, field->name, + mysql_errno(data->mysql), + mysql_sqlstate(data->mysql), + mysql_error(data->mysql), + stmt_buf); + field->type = EMPTY_ARRAY_FIELD; + } + + res = 0; + +fail: + free(stmt_buf); + return res; +} + +static void create_mysql_table_relations(const void *key, void *val, + void *data_) +{ + struct table *table = val; + struct field *field; + struct restore *data = data_; + + if (data->fatal_error) + return; + + for (field = table->fields; field; field = field->next) { + if ((field->type == MANY_TO_MANY_REL || + field->type == BROKEN_TO_MANY_REL) && + create_mysql_mtm_table(data, table, field)) + goto fail; + + if (field->type == ARRAY_FIELD && + create_mysql_array(data, table, field)) + goto fail; + + if (table->creation_error) + continue; + + if ((field->type == ONE_TO_ONE_REL || + field->type == MANY_TO_ONE_REL) && + add_fk_constraint(data, table, field)) + goto fail; + } + + return; + +fail: + data->fatal_error = true; +} + +static int complete_mysql_tables(struct restore *data) +{ + ht_map(&data->tables, data, create_mysql_table_relations); + if (data->fatal_error) + return -1; + putchar('\n'); + return 0; +} + +static void print_table_summary(const void *key, void *val, void *data_) +{ + struct table *table = val; + struct restore *data = data_; + + if (!table->all_rows) { + fprintf(data->raport_stream, + "Nie wstawiono zadnych rekordow do tabeli `table->name`.\n"); + return; + } + + fprintf(data->raport_stream, + "Rekordy wstawione do tabeli `%s`: %ld/%ld\n", + table->name, (long) table->inserted_rows, + (long) table->all_rows); + + if (!table->insertion_errors) + return; + + fprintf(data->raport_stream, + "Bledy przy wypelnianiu tabeli `%s`: %ld\n", + table->name, (long) table->insertion_errors); +} + +static void print_entire_summary(struct restore *data) +{ + ht_map(&data->tables, data, print_table_summary); + fputc('\n', data->raport_stream); +} + +static inline void error_exit(const char *msg, ...) +{ + va_list ap; + + va_start(ap, msg); + vfprintf(stderr, msg, ap); + va_end(ap); + fputc('\n', stderr); + + exit(-1); +} + +const char foreign_checks_off[] = "SET FOREIGN_KEY_CHECKS = 0;"; +const char foreign_checks_on[] = "SET FOREIGN_KEY_CHECKS = 1;"; +const char set_utf8[] = "SET NAMES 'utf8';"; + +int main(int argc, char **argv) +{ + const char *xml_path, *db_host, *db_user, *db_password, *db_name, + *raport_path = NULL; + int db_port; + xmlSAXHandler sax1, sax2; + struct restore data; + int i; + + if (argc < 7) { + error_exit("Zbyt malo argumentów. Oczekiwane:\n" + "restore SCIEZKA_PLIKU_XML ADRES_BAZY PORT UZYTKOWNIK HASLO NAZWA_BAZY [SCIEZKA_PLIKU_RAPORTU]"); + } + + xml_path = argv[1]; + db_host = argv[2]; + db_port = atoi(argv[3]); + db_user = argv[4]; + db_password = argv[5]; + db_name = argv[6]; + + if (argc >= 8) + raport_path = argv[7]; + + memset(&sax1, 0, sizeof(sax1)); + if (init_restore_data(&data, raport_path)) + error_exit("Blad inicjalizacji wewnetrznych struktur."); + + sax1.initialized = XML_SAX2_MAGIC; + sax1.startElement = startElement_callback; + sax1.endElement = endElement_callback; + sax1.characters = characters_callback; + + sax2 = sax1; + + sax1.error = error_callback; + + data.pass = PASS_1; + xmlSAXUserParseFile(&sax1, &data, xml_path); + if (data.fatal_error || !data.tables.entries) + error_exit("Nie udalo sie ustalic schematu bazy."); + if (data.errors_count) { + fprintf(data.raport_stream, + "Byly bledy podczas czytania pliku xml (%ld).\n", + (long) data.errors_count); + } + + putc('\n', data.raport_stream); + + /* initialize connection handler */ + data.mysql = mysql_init(NULL); + if (!data.mysql) + error_exit("Blad inicializacji libmysql (prawdopodobnie brak pamieci)."); + + /* connect to server */ + if (!mysql_real_connect(data.mysql, + db_host, db_user, db_password, db_name, db_port, + NULL, 0)) { + fprintf(data.raport_stream, + "Nie udalo sie polaczyc z baza.\n" + "Blad %u (%s): %s\n", + mysql_errno(data.mysql), + mysql_sqlstate(data.mysql), + mysql_error(data.mysql)); + mysql_close(data.mysql); + return -1; + } + + printf("Polaczono z baza\n\n"); + + /* Re-enabling foreign checks later is not needed. */ + if (mysql_real_query(data.mysql, foreign_checks_off, + sizeof(foreign_checks_off) - 1)) { + fprintf(data.raport_stream, + "Nie udalo sie wylaczyc sprawdzania spojnosci relacji.\n" + "Blad %u (%s): %s\n", + mysql_errno(data.mysql), + mysql_sqlstate(data.mysql), + mysql_error(data.mysql)); + error_exit("Blad podczas wylaczania sprawdzania spojnosci relacji."); + } + + /* Setting utf-8 for communication between database and the client. */ + if (mysql_real_query(data.mysql, set_utf8, sizeof(set_utf8) - 1)) { + data.errors_count++; + fprintf(data.raport_stream, + "Nie udalo sie ustawic kodowania utf8.\n" + "Blad %u (%s): %s\n", + mysql_errno(data.mysql), + mysql_sqlstate(data.mysql), + mysql_error(data.mysql)); + } + + if (clear_database(&data, "nowabaze")) { + error_exit("Blad podczas usuwania starej zawartosci bazy."); + } + + if (create_mysql_tables(&data)) + error_exit("Blad podczas tworzenia tabel."); + + check_tables_relations(&data); + fprint_all_tables(data.raport_stream, &data.tables); + + if (complete_mysql_tables(&data)) + error_exit("Blad podczas uzupelniania utworzonych tabel."); + + data.pass = 2; + xmlSAXUserParseFile(&sax2, &data, xml_path); + putchar('\n'); + fputc('\n', data.raport_stream); + print_entire_summary(&data); + if (data.fatal_error) + error_exit("Fatalny blad podczas umieszczania danych w bazie."); + +#define OS(n) ((n) == 0 ? data.raport_stream : stdout) + for (i = data.raport_stream == stdout ? 0 : 1; i >= 0; i--) { + fprintf(OS(i), "Operacja zakonczona"); + if (data.errors_count) { + fprintf(OS(i), "; bledow lacznie: %ld\n", + (long) data.errors_count); + } else { + fprintf(OS(i), " bez bledow.\n"); + } + } + + mysql_close(data.mysql); + + destroy_restore_data(&data); + + return 0; +} -- cgit v1.2.3