ast_nodes.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright 2022 DeepMind Technologies Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Classes that roughly correspond to Clang AST node types."""
  16. import collections
  17. import dataclasses
  18. import re
  19. from typing import Dict, Optional, Sequence, Tuple, Union
  20. # We are relying on Clang to do the actual source parsing and are only doing
  21. # a little bit of extra parsing of function parameter type declarations here.
  22. # These patterns are here for sanity checking rather than actual parsing.
  23. VALID_TYPE_NAME_PATTERN = re.compile('(struct )?[A-Za-z_][A-Za-z0-9_]*')
  24. C_INVALID_TYPE_NAMES = frozenset([
  25. 'auto', 'break', 'case', 'const', 'continue', 'default', 'do', 'else',
  26. 'enum', 'extern', 'for', 'goto', 'if', 'inline', 'register', 'restrict',
  27. 'return', 'sizeof', 'static', 'struct', 'switch', 'typedef', 'union',
  28. 'volatile', 'while', '_Alignas', '_Atomic', '_Generic', '_Imaginary',
  29. '_Noreturn', '_Static_assert', '_Thread_local', '__attribute__', '_Pragma'])
  30. def _is_valid_integral_type(type_str: str):
  31. """Checks if a string is a valid integral type."""
  32. parts = re.split(r'\s+', type_str)
  33. counter = collections.defaultdict(lambda: 0)
  34. wildcard_counter = 0
  35. for part in parts:
  36. if part in ('signed', 'unsigned', 'short', 'long', 'int', 'char'):
  37. counter[part] += 1
  38. elif VALID_TYPE_NAME_PATTERN.fullmatch(part):
  39. # a non-keyword can be a typedef for int
  40. wildcard_counter += 1
  41. else:
  42. return False
  43. if (counter['signed'] + counter['unsigned'] > 1 or
  44. counter['short'] > 1 or counter['long'] > 2 or
  45. (counter['short'] and counter['long']) or
  46. ((counter['short'] or counter['long']) and counter['char']) or
  47. counter['char'] + counter['int'] + wildcard_counter > 1):
  48. return False
  49. else:
  50. return True
  51. @dataclasses.dataclass
  52. class ValueType:
  53. """Represents a C type that is neither a pointer type nor an array type."""
  54. name: str
  55. is_const: bool = False
  56. is_volatile: bool = False
  57. def __init__(self, name: str, is_const: bool = False,
  58. is_volatile: bool = False):
  59. is_valid_type_name = (
  60. name == 'void *(*)(void *)' or
  61. VALID_TYPE_NAME_PATTERN.fullmatch(name) or
  62. _is_valid_integral_type(name)) and name not in C_INVALID_TYPE_NAMES
  63. if not is_valid_type_name:
  64. raise ValueError(f'{name!r} is not a valid value type name')
  65. self.name = name
  66. self.is_const = is_const
  67. self.is_volatile = is_volatile
  68. def decl(self, name_or_decl: Optional[str] = None) -> str:
  69. parts = []
  70. if self.is_const:
  71. parts.append('const')
  72. if self.is_volatile:
  73. parts.append('volatile')
  74. parts.append(self.name)
  75. if name_or_decl:
  76. parts.append(name_or_decl)
  77. return ' '.join(parts)
  78. def __str__(self):
  79. return self.decl()
  80. @dataclasses.dataclass
  81. class ArrayType:
  82. """Represents a C array type."""
  83. inner_type: Union[ValueType, 'PointerType']
  84. extents: Tuple[int, ...]
  85. def __init__(self, inner_type: Union[ValueType, 'PointerType'],
  86. extents: Sequence[int]):
  87. self.inner_type = inner_type
  88. self.extents = tuple(extents)
  89. @property
  90. def _extents_str(self) -> str:
  91. return ''.join(f'[{n}]' for n in self.extents)
  92. def decl(self, name_or_decl: Optional[str] = None) -> str:
  93. name_or_decl = name_or_decl or ''
  94. return self.inner_type.decl(f'{name_or_decl}{self._extents_str}')
  95. def __str__(self):
  96. return self.decl()
  97. @dataclasses.dataclass
  98. class PointerType:
  99. """Represents a C pointer type."""
  100. inner_type: Union[ValueType, ArrayType, 'PointerType']
  101. is_const: bool = False
  102. is_volatile: bool = False
  103. is_restrict: bool = False
  104. def decl(self, name_or_decl: Optional[str] = None) -> str:
  105. """Creates a string that declares an object of this type."""
  106. parts = ['*']
  107. if self.is_const:
  108. parts.append('const')
  109. if self.is_volatile:
  110. parts.append('volatile')
  111. if self.is_restrict:
  112. parts.append('restrict')
  113. if name_or_decl:
  114. parts.append(name_or_decl)
  115. ptr_decl = ' '.join(parts)
  116. if isinstance(self.inner_type, ArrayType):
  117. ptr_decl = f'({ptr_decl})'
  118. return self.inner_type.decl(ptr_decl)
  119. def __str__(self):
  120. return self.decl()
  121. @dataclasses.dataclass
  122. class FunctionParameterDecl:
  123. """Represents a parameter in a function declaration.
  124. Note that according to the C language rule, a function parameter of array
  125. type undergoes array-to-pointer decay, and therefore appears as a pointer
  126. parameter in an actual C AST. We retain the arrayness of a parameter here
  127. since the array's extents are informative.
  128. """
  129. name: str
  130. type: Union[ValueType, ArrayType, PointerType]
  131. def __str__(self):
  132. return self.type.decl(self.name)
  133. @property
  134. def decltype(self) -> str:
  135. return self.type.decl()
  136. @dataclasses.dataclass
  137. class FunctionDecl:
  138. """Represents a function declaration."""
  139. name: str
  140. return_type: Union[ValueType, ArrayType, PointerType]
  141. parameters: Tuple[FunctionParameterDecl, ...]
  142. doc: str
  143. def __init__(self, name: str,
  144. return_type: Union[ValueType, ArrayType, PointerType],
  145. parameters: Sequence[FunctionParameterDecl],
  146. doc: str):
  147. self.name = name
  148. self.return_type = return_type
  149. self.parameters = tuple(parameters)
  150. self.doc = doc
  151. def __str__(self):
  152. param_str = ', '.join(str(p) for p in self.parameters)
  153. return f'{self.return_type} {self.name}({param_str})'
  154. @property
  155. def decltype(self) -> str:
  156. param_str = ', '.join(str(p.decltype) for p in self.parameters)
  157. return f'{self.return_type} ({param_str})'
  158. class _EnumDeclValues(Dict[str, int]):
  159. """A dict with modified stringified representation.
  160. The __repr__ method of this class adds a trailing comma to the list of values.
  161. This is done as a hint for code formatters to place one item per line when
  162. the stringified OrderedDict is used in generated Python code.
  163. """
  164. def __repr__(self):
  165. out = super().__repr__()
  166. if self:
  167. out = re.sub(r'\(\[(.+)\]\)\Z', r'([\1,])', out)
  168. return re.sub(r'\A_EnumDeclValues', 'dict', out)
  169. @dataclasses.dataclass
  170. class EnumDecl:
  171. """Represents an enum declaration."""
  172. name: str
  173. declname: str
  174. values: Dict[str, int]
  175. def __init__(self, name: str, declname: str, values: Dict[str, int]):
  176. self.name = name
  177. self.declname = declname
  178. self.values = _EnumDeclValues(values)
  179. @dataclasses.dataclass
  180. class StructFieldDecl:
  181. """Represents a field in a struct or union declaration."""
  182. name: str
  183. type: Union[
  184. ValueType,
  185. ArrayType,
  186. PointerType,
  187. 'AnonymousStructDecl',
  188. 'AnonymousUnionDecl',
  189. ]
  190. doc: str
  191. array_extent: Optional[Tuple[Union[str, int], ...]] = None
  192. def __str__(self):
  193. return self.type.decl(self.name)
  194. @property
  195. def decltype(self) -> str:
  196. return self.type.decl()
  197. @dataclasses.dataclass
  198. class AnonymousStructDecl:
  199. """Represents an anonymous struct declaration."""
  200. fields: Tuple[Union[StructFieldDecl, 'AnonymousUnionDecl'], ...]
  201. def __init__(self, fields: Sequence[StructFieldDecl]):
  202. self.fields = tuple(fields)
  203. def __str__(self):
  204. return self.decl()
  205. def _inner_decl(self):
  206. return '; '.join(str(field) for field in self.fields) + ';'
  207. def decl(self, name_or_decl: Optional[str] = None):
  208. parts = ['struct', f'{{{self._inner_decl()}}}']
  209. if name_or_decl:
  210. parts.append(name_or_decl)
  211. return ' '.join(parts)
  212. class AnonymousUnionDecl(AnonymousStructDecl):
  213. """Represents an anonymous union declaration."""
  214. def decl(self, name_or_decl: Optional[str] = None):
  215. parts = ['union', f'{{{self._inner_decl()}}}']
  216. if name_or_decl:
  217. parts.append(name_or_decl)
  218. return ' '.join(parts)
  219. @dataclasses.dataclass
  220. class StructDecl:
  221. """Represents a struct declaration."""
  222. name: str
  223. declname: str
  224. fields: Tuple[Union[StructFieldDecl, AnonymousUnionDecl], ...]
  225. def __init__(self, name: str,
  226. declname: str,
  227. fields: Sequence[Union[StructFieldDecl, AnonymousUnionDecl]]):
  228. self.name = name
  229. self.declname = declname
  230. self.fields = tuple(fields)
  231. def decl(self, name_or_decl: Optional[str] = None) -> str:
  232. parts = [self.name]
  233. if name_or_decl:
  234. parts.append(name_or_decl)
  235. return ' '.join(parts)