aboutsummaryrefslogtreecommitdiff
path: root/gnu/packages/patches/python-pytorch-fix-codegen.patch
blob: cb246b25de75fc45fd7b25bdbca26547cad56b9b (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
This patch fixes some scripts for generating source files.  For
gen_jit_decompositions.py, gen_mobile_upgraders.py and
gen_jit_shape_functions.py, which depend on the compiled PyTorch library, the
option to generate "dummy" source files is added for the initial build, which
is later corrected.  codegen_external.py is patched to avoid duplicate
functions and add the static keyword as in the existing generated file.

diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh
index cc0263dbbf..ac34e84b82 100644
--- a/tools/gen_flatbuffers.sh
+++ b/tools/gen_flatbuffers.sh
@@ -1,13 +1,13 @@
 #!/bin/bash
 ROOT=$(pwd)
-FF_LOCATION="$ROOT/third_party/flatbuffers"
-cd "$FF_LOCATION" || exit
-mkdir build
-cd build || exit
-cmake ..
-cmake --build . --target flatc
-mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
-./flatc --cpp --gen-mutable --scoped-enums \
+#FF_LOCATION="$ROOT/third_party/flatbuffers"
+#cd "$FF_LOCATION" || exit
+#mkdir build
+#cd build || exit
+#cmake ..
+#cmake --build . --target flatc
+#mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
+flatc --cpp --gen-mutable --scoped-enums \
      -o "$ROOT/torch/csrc/jit/serialization" \
      -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs"
 echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h"
diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py
index bc69b05162..0f8df81de3 100644
--- a/torch/csrc/jit/tensorexpr/codegen_external.py
+++ b/torch/csrc/jit/tensorexpr/codegen_external.py
@@ -20,9 +20,14 @@ def gen_external(native_functions_path, tags_path, external_path):
     native_functions = parse_native_yaml(native_functions_path, tags_path)
     func_decls = []
     func_registrations = []
-    for func in native_functions:
+    done_names = set()
+    for func in native_functions[0]:
         schema = func.func
         name = schema.name.name.base
+        if name in done_names:
+            continue
+        else:
+            done_names.add(name)
         args = schema.arguments
         # Only supports extern calls for functions with out variants
         if not schema.is_out_fn():
@@ -62,7 +67,7 @@ def gen_external(native_functions_path, tags_path, external_path):
 
         # print(tensor_decls, name, arg_names)
         func_decl = f"""\
-void nnc_aten_{name}(
+static void nnc_aten_{name}(
     int64_t bufs_num,
     void** buf_data,
     int64_t* buf_ranks,
diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py
index 7cfbb803f9..2e69bb1868 100644
--- a/torchgen/decompositions/gen_jit_decompositions.py
+++ b/torchgen/decompositions/gen_jit_decompositions.py
@@ -1,8 +1,12 @@
 #!/usr/bin/env python3
 import os
 from pathlib import Path
+import sys
 
-from torch.jit._decompositions import decomposition_table
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    from torch.jit._decompositions import decomposition_table
+else:
+    decomposition_table = {}
 
 # from torchgen.code_template import CodeTemplate
 
@@ -85,7 +89,7 @@ def write_decomposition_util_file(path: str) -> None:
 
 
 def main() -> None:
-    pytorch_dir = Path(__file__).resolve().parents[3]
+    pytorch_dir = Path(__file__).resolve().parents[2]
     upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
     write_decomposition_util_file(str(upgrader_path))
 
diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py
index dab1568580..55c58715fc 100644
--- a/torchgen/operator_versions/gen_mobile_upgraders.py
+++ b/torchgen/operator_versions/gen_mobile_upgraders.py
@@ -2,10 +2,12 @@
 import os
 from enum import Enum
 from pathlib import Path
+import sys
 from typing import Any, Dict, List
 
-import torch
-from torch.jit.generate_bytecode import generate_upgraders_bytecode
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    import torch
+    from torch.jit.generate_bytecode import generate_upgraders_bytecode
 
 from torchgen.code_template import CodeTemplate
 from torchgen.operator_versions.gen_mobile_upgraders_constant import (
@@ -262,7 +264,10 @@ def construct_register_size(register_size_from_yaml: int) -> str:
 def construct_version_maps(
     upgrader_bytecode_function_to_index_map: Dict[str, Any]
 ) -> str:
-    version_map = torch._C._get_operator_version_map()
+    if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+        version_map = torch._C._get_operator_version_map()
+    else:
+        version_map = {}
     sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0])  # type: ignore[no-any-return]
     sorted_version_map = dict(sorted_version_map_)
 
@@ -378,7 +383,10 @@ def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 
 def main() -> None:
-    upgrader_list = generate_upgraders_bytecode()
+    if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+        upgrader_list = generate_upgraders_bytecode()
+    else:
+        upgrader_list = []
     sorted_upgrader_list = sort_upgrader(upgrader_list)
     for up in sorted_upgrader_list:
         print("after sort upgrader : ", next(iter(up)))
diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py
index c6336a6951..34e394d818 100644
--- a/torchgen/shape_functions/gen_jit_shape_functions.py
+++ b/torchgen/shape_functions/gen_jit_shape_functions.py
@@ -18,16 +18,20 @@ you are in the root directory of the Pytorch git repo"""
 if not file_path.exists():
     raise Exception(err_msg)
 
-spec = importlib.util.spec_from_file_location(module_name, file_path)
-assert spec is not None
-module = importlib.util.module_from_spec(spec)
-sys.modules[module_name] = module
-assert spec.loader is not None
-assert module is not None
-spec.loader.exec_module(module)
-
-bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
-shape_compute_graph_mapping = module.shape_compute_graph_mapping
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    assert spec is not None
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    assert spec.loader is not None
+    assert module is not None
+    spec.loader.exec_module(module)
+
+    bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
+    shape_compute_graph_mapping = module.shape_compute_graph_mapping
+else:
+    bounded_compute_graph_mapping = {}
+    shape_compute_graph_mapping = {}
 
 
 SHAPE_HEADER = r"""