aboutsummaryrefslogtreecommitdiff
This patch fixes some scripts for generating source files.  For
gen_jit_decompositions.py, gen_mobile_upgraders.py and
gen_jit_shape_functions.py, which depend on the compiled PyTorch library, the
option to generate "dummy" source files is added for the initial build, which
is later corrected.  codegen_external.py is patched to avoid duplicate
functions and add the static keyword as in the existing generated file.

diff --git a/tools/gen_flatbuffers.sh b/tools/gen_flatbuffers.sh
index cc0263dbbf..ac34e84b82 100644
--- a/tools/gen_flatbuffers.sh
+++ b/tools/gen_flatbuffers.sh
@@ -1,13 +1,13 @@
 #!/bin/bash
 ROOT=$(pwd)
-FF_LOCATION="$ROOT/third_party/flatbuffers"
-cd "$FF_LOCATION" || exit
-mkdir build
-cd build || exit
-cmake ..
-cmake --build . --target flatc
-mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
-./flatc --cpp --gen-mutable --scoped-enums \
+#FF_LOCATION="$ROOT/third_party/flatbuffers"
+#cd "$FF_LOCATION" || exit
+#mkdir build
+#cd build || exit
+#cmake ..
+#cmake --build . --target flatc
+#mkdir -p "$ROOT/build/torch/csrc/jit/serialization"
+flatc --cpp --gen-mutable --scoped-enums \
      -o "$ROOT/torch/csrc/jit/serialization" \
      -c "$ROOT/torch/csrc/jit/serialization/mobile_bytecode.fbs"
 echo '// @generated' >> "$ROOT/torch/csrc/jit/serialization/mobile_bytecode_generated.h"
diff --git a/torch/csrc/jit/tensorexpr/codegen_external.py b/torch/csrc/jit/tensorexpr/codegen_external.py
index 120520b139..0c8587f02d 100644
--- a/torch/csrc/jit/tensorexpr/codegen_external.py
+++ b/torch/csrc/jit/tensorexpr/codegen_external.py
@@ -16,9 +16,14 @@ def gen_external(native_functions_path, tags_path, external_path):
     native_functions = parse_native_yaml(native_functions_path, tags_path)
     func_decls = []
     func_registrations = []
-    for func in native_functions:
+    done_names = set()
+    for func in native_functions[0]:
         schema = func.func
         name = schema.name.name.base
+        if name in done_names:
+            continue
+        else:
+            done_names.add(name)
         args = schema.arguments
         # Only supports extern calls for functions with out variants
         if not schema.is_out_fn():
@@ -48,7 +53,7 @@ def gen_external(native_functions_path, tags_path, external_path):
 
         # print(tensor_decls, name, arg_names)
         func_decl = f"""\
-void nnc_aten_{name}(
+static void nnc_aten_{name}(
     int64_t bufs_num,
     void** buf_data,
     int64_t* buf_ranks,
diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py
index 7cfbb803f9..2e69bb1868 100644
--- a/torchgen/decompositions/gen_jit_decompositions.py
+++ b/torchgen/decompositions/gen_jit_decompositions.py
@@ -1,8 +1,12 @@
 #!/usr/bin/env python3
 import os
 from pathlib import Path
+import sys
 
-from torch.jit._decompositions import decomposition_table
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    from torch.jit._decompositions import decomposition_table
+else:
+    decomposition_table = {}
 
 # from torchgen.code_template import CodeTemplate
 
@@ -85,7 +89,7 @@ def write_decomposition_util_file(path: str) -> None:
 
 
 def main() -> None:
-    pytorch_dir = Path(__file__).resolve().parents[3]
+    pytorch_dir = Path(__file__).resolve().parents[2]
     upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
     write_decomposition_util_file(str(upgrader_path))
 
diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py
index e5287cffc5..57f3c38096 100644
--- a/torchgen/operator_versions/gen_mobile_upgraders.py
+++ b/torchgen/operator_versions/gen_mobile_upgraders.py
@@ -2,10 +2,12 @@
 import os
 from enum import Enum
 from pathlib import Path
+import sys
 from typing import Any, Dict, List
 
-import torch
-from torch.jit.generate_bytecode import generate_upgraders_bytecode
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    import torch
+    from torch.jit.generate_bytecode import generate_upgraders_bytecode
 
 from torchgen.code_template import CodeTemplate
 from torchgen.operator_versions.gen_mobile_upgraders_constant import (
@@ -262,7 +264,10 @@ def construct_register_size(register_size_from_yaml: int) -> str:
 def construct_version_maps(
     upgrader_bytecode_function_to_index_map: Dict[str, Any]
 ) -> str:
-    version_map = torch._C._get_operator_version_map()
+    if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+        version_map = torch._C._get_operator_version_map()
+    else:
+        version_map = {}
     sorted_version_map_ = sorted(version_map.items(), key=lambda item: item[0])  # type: ignore[no-any-return]
     sorted_version_map = {name: lst for name, lst in sorted_version_map_}
 
@@ -379,7 +384,10 @@ def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
 
 def main() -> None:
 
-    upgrader_list = generate_upgraders_bytecode()
+    if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+        upgrader_list = generate_upgraders_bytecode()
+    else:
+        upgrader_list = []
     sorted_upgrader_list = sort_upgrader(upgrader_list)
     for up in sorted_upgrader_list:
         print("after sort upgrader : ", next(iter(up)))
diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py
index c6336a6951..34e394d818 100644
--- a/torchgen/shape_functions/gen_jit_shape_functions.py
+++ b/torchgen/shape_functions/gen_jit_shape_functions.py
@@ -18,16 +18,20 @@ you are in the root directory of the Pytorch git repo"""
 if not file_path.exists():
     raise Exception(err_msg)
 
-spec = importlib.util.spec_from_file_location(module_name, file_path)
-assert spec is not None
-module = importlib.util.module_from_spec(spec)
-sys.modules[module_name] = module
-assert spec.loader is not None
-assert module is not None
-spec.loader.exec_module(module)
-
-bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
-shape_compute_graph_mapping = module.shape_compute_graph_mapping
+if len(sys.argv) < 2 or sys.argv[1] != "dummy":
+    spec = importlib.util.spec_from_file_location(module_name, file_path)
+    assert spec is not None
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    assert spec.loader is not None
+    assert module is not None
+    spec.loader.exec_module(module)
+
+    bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
+    shape_compute_graph_mapping = module.shape_compute_graph_mapping
+else:
+    bounded_compute_graph_mapping = {}
+    shape_compute_graph_mapping = {}
 
 
 SHAPE_HEADER = r"""