replace_copyright.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. #!/usr/bin/env python3
  2. # Copyright (c) Microsoft Corporation.
  3. # SPDX-License-Identifier: Apache-2.0
  4. # DeepSpeed Team
  5. """
  6. USAGE:
  7. $ python3 script/replace_copyright.py --repo_dir ./
  8. """
  9. import os
  10. import argparse
  11. NEW_COPYRIGHT = ("Copyright (c) Microsoft Corporation.", "SPDX-License-Identifier: Apache-2.0", "", "DeepSpeed Team")
  12. PY_SL_COMMENT = "#"
  13. PY_ML_SINGLE = "'''"
  14. PY_ML_DOUBLE = '"""'
  15. PY_COMMENTS = (PY_SL_COMMENT, PY_ML_SINGLE, PY_ML_DOUBLE)
  16. C_SL_COMMENT = "//"
  17. C_ML_OPEN = "/*"
  18. C_ML_CLOSE = "*/"
  19. C_COMMENTS = (C_SL_COMMENT, C_ML_OPEN, C_ML_CLOSE)
  20. BASH_SL_COMMENT = "#"
  21. BASH_COMMENTS = (BASH_SL_COMMENT, )
  22. DELIM = "|/-\|/-\|BARRIER|/-\|/-\|" # noqa: W605
  23. def parser_args():
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument("--repo_dir", type=str, help="Repository directory")
  26. parser.add_argument("--python_style_ext",
  27. type=str,
  28. nargs="+",
  29. default=[".py"],
  30. help="File types to process with python-style comments")
  31. parser.add_argument("--bash_style_ext",
  32. type=str,
  33. nargs="+",
  34. default=[".sh"],
  35. help="File types to process with bash-style comments")
  36. parser.add_argument("--c_style_ext",
  37. type=str,
  38. nargs="+",
  39. default=[
  40. ".c",
  41. ".cpp",
  42. ".cu",
  43. ".h",
  44. ".hpp",
  45. ".cuh",
  46. ".cc",
  47. ".hip",
  48. ".tr",
  49. ],
  50. help="File types to process with C-style comments")
  51. args = parser.parse_args()
  52. return args
  53. # These get_header_* functions are ugly, but they work :)
  54. def get_header_py(fp):
  55. with open(fp, "r") as f:
  56. lines = iter(l for l in f.readlines())
  57. header = []
  58. rest = []
  59. in_multiline = False
  60. multiline_type = None
  61. while (l := next(lines, None)) is not None:
  62. l = l.strip()
  63. if l.startswith(PY_ML_SINGLE) or l.startswith(PY_ML_DOUBLE):
  64. # Detected multiline comment
  65. if in_multiline and multiline_type == l[:3]:
  66. # Ended a multiline comment
  67. in_multiline = False
  68. else:
  69. # Started a multiline comment
  70. in_multiline = True
  71. multiline_type = l[:3]
  72. if l.endswith(multiline_type) and len(l) >= 6:
  73. # Opened and closed multiline comment on single line
  74. in_multiline = False
  75. elif in_multiline and l.endswith(multiline_type):
  76. # Ended a multiline comment
  77. in_multiline = False
  78. elif not (in_multiline or l.startswith(PY_SL_COMMENT) or l == ""):
  79. # Not in a comment
  80. rest += [l + "\n"]
  81. break
  82. header.append(l)
  83. rest += list(lines)
  84. return header, rest
  85. def get_header_c(fp):
  86. with open(fp, "r") as f:
  87. lines = iter(l for l in f.readlines())
  88. header = []
  89. rest = []
  90. in_multiline = False
  91. while (l := next(lines, None)) is not None:
  92. l = l.strip()
  93. if l.startswith(C_ML_OPEN):
  94. # Detected multiline comment
  95. if not l.endswith(C_ML_CLOSE):
  96. # multiline comment not closed on same line
  97. in_multiline = True
  98. elif l.endswith(C_ML_CLOSE):
  99. # Ended a multline comment
  100. in_multiline = False
  101. elif not in_multiline or l.startswith(C_SL_COMMENT) or l.isspace():
  102. # Not in a comment
  103. rest += [l + "\n"]
  104. break
  105. header.append(l)
  106. rest += list(lines)
  107. return header, rest
  108. def get_header_bash(fp):
  109. with open(fp, "r") as f:
  110. lines = iter(l for l in f.readlines())
  111. header = []
  112. rest = []
  113. while (l := next(lines, None)) is not None:
  114. l = l.strip()
  115. if not l.startswith(BASH_SL_COMMENT) or l.isspace():
  116. # Not in a comment
  117. rest += [l + "\n"]
  118. break
  119. header.append(l)
  120. rest += list(lines)
  121. return header, rest
  122. def remove_comments(line, comment_strs):
  123. for cstr in comment_strs:
  124. line = line.replace(cstr, "")
  125. return line
  126. def format_multiline_comment(text, comment_type):
  127. if comment_type == PY_COMMENTS:
  128. text = f"\n{comment_type[2]}\n" + "\n".join(text) + f"{comment_type[2]}"
  129. if comment_type == C_COMMENTS:
  130. text = f"\n{comment_type[1]}\n" + "\n".join(text) + f"{comment_type[2]}"
  131. if comment_type == BASH_COMMENTS:
  132. text = "\n".join([f"{comment_type[0]}{l}" for l in text])
  133. return text
  134. def modify_file_header(fp, file_header, rest_of_file, preserve_text_store, comment_type):
  135. header_text = "\n".join(file_header)
  136. if not (header_text.strip() == "" or header_text in preserve_text_store):
  137. # Unique header, need to get user input
  138. print("\n", DELIM, "\n")
  139. for idx, line in enumerate(file_header):
  140. print(f"{idx}: {line}")
  141. print("\n", DELIM, "\n")
  142. print("\nIndicate the FIRST line of the Header to KEEP")
  143. print("(shebang #! lines will be automatically processed and should not be included).")
  144. keep_idx = input("Enter number (or leave blank if no lines should be preserved): ")
  145. preserve_text_store[header_text] = file_header[int(keep_idx):] if keep_idx != "" else ""
  146. # Identify any shebang lines in the file
  147. shebang = "\n".join([l for l in file_header if l.startswith("#!")])
  148. if shebang != "":
  149. shebang += "\n"
  150. # Get the text we should preserve in this file and process to remove comment characters
  151. text_to_preserve = preserve_text_store.get(header_text, [""])
  152. text_to_preserve = [remove_comments(l, comment_type) for l in text_to_preserve]
  153. # Format the text we want to keep into a new multiline comment
  154. if "".join(text_to_preserve) == "":
  155. text_to_preserve = ""
  156. else:
  157. text_to_preserve = format_multiline_comment(text_to_preserve, comment_type)
  158. # Generate the copyright text we will be adding
  159. copyright_text = "\n".join([f"{comment_type[0]} {l}" if l != "" else l for l in NEW_COPYRIGHT])
  160. # Assemble the new header
  161. new_header = shebang + copyright_text + text_to_preserve
  162. # Write out the new file
  163. new_file_contents = new_header + "\n" + "".join(rest_of_file)
  164. with open(fp, "w") as f:
  165. f.write(new_file_contents)
  166. return preserve_text_store # Return so we can reuse for future files
  167. def main(args):
  168. preserve_text_store = {} # Used to track header comments we should preserve
  169. for root, dirs, fnames in os.walk(args.repo_dir):
  170. # Walk across directory looking for all files with extensions we want to modify
  171. for ext in args.python_style_ext:
  172. fpaths = [os.path.join(root, fn) for fn in fnames if fn.endswith(ext)]
  173. for fp in fpaths:
  174. file_header, rest_of_file = get_header_py(fp)
  175. preserve_text_store = modify_file_header(fp, file_header, rest_of_file, preserve_text_store,
  176. PY_COMMENTS)
  177. for ext in args.c_style_ext:
  178. fpaths = [os.path.join(root, fn) for fn in fnames if fn.endswith(ext)]
  179. for fp in fpaths:
  180. file_header, rest_of_file = get_header_c(fp)
  181. preserve_text_store = modify_file_header(fp, file_header, rest_of_file, preserve_text_store,
  182. C_COMMENTS)
  183. for ext in args.bash_style_ext:
  184. fpaths = [os.path.join(root, fn) for fn in fnames if fn.endswith(ext)]
  185. for fp in fpaths:
  186. file_header, rest_of_file = get_header_bash(fp)
  187. preserve_text_store = modify_file_header(fp, file_header, rest_of_file, preserve_text_store,
  188. BASH_COMMENTS)
  189. if __name__ == "__main__":
  190. args = parser_args()
  191. main(args)