ast_nodes_test.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright 2022 DeepMind Technologies Limited
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for ast_nodes.py."""
  16. from absl.testing import absltest
  17. from . import ast_nodes
  18. class AstNodesTest(absltest.TestCase):
  19. def test_value_type(self):
  20. value_type = ast_nodes.ValueType('int')
  21. self.assertEqual(str(value_type), 'int')
  22. self.assertEqual(value_type.decl('var'), 'int var')
  23. const_value_type = ast_nodes.ValueType('double', is_const=True)
  24. self.assertEqual(str(const_value_type), 'const double')
  25. self.assertEqual(const_value_type.decl('var2'), 'const double var2')
  26. def test_pointer_type(self):
  27. pointer_type = ast_nodes.PointerType(ast_nodes.ValueType('int'))
  28. self.assertEqual(str(pointer_type), 'int *')
  29. self.assertEqual(pointer_type.decl('var'), 'int * var')
  30. const_pointer_type = ast_nodes.PointerType(
  31. ast_nodes.ValueType('double'), is_const=True)
  32. self.assertEqual(str(const_pointer_type), 'double * const')
  33. self.assertEqual(const_pointer_type.decl('var2'), 'double * const var2')
  34. pointer_to_const_type = ast_nodes.PointerType(
  35. ast_nodes.ValueType('float', is_const=True))
  36. self.assertEqual(str(pointer_to_const_type), 'const float *')
  37. self.assertEqual(pointer_to_const_type.decl('var3'), 'const float * var3')
  38. restrict_volatile_pointer_to_const_type = ast_nodes.PointerType(
  39. ast_nodes.ValueType('char', is_const=True),
  40. is_volatile=True, is_restrict=True)
  41. self.assertEqual(str(restrict_volatile_pointer_to_const_type),
  42. 'const char * volatile restrict')
  43. self.assertEqual(
  44. restrict_volatile_pointer_to_const_type.decl('var4'),
  45. 'const char * volatile restrict var4')
  46. pointer_to_array_type = ast_nodes.PointerType(
  47. ast_nodes.ArrayType(ast_nodes.ValueType('long'), (3,)))
  48. self.assertEqual(str(pointer_to_array_type), 'long (*)[3]')
  49. self.assertEqual(pointer_to_array_type.decl('var5'), 'long (* var5)[3]')
  50. const_pointer_to_array_type = ast_nodes.PointerType(
  51. ast_nodes.ArrayType(ast_nodes.ValueType('unsigned int'), (4,)),
  52. is_const=True)
  53. self.assertEqual(
  54. str(const_pointer_to_array_type), 'unsigned int (* const)[4]')
  55. self.assertEqual(
  56. const_pointer_to_array_type.decl('var6'),
  57. 'unsigned int (* const var6)[4]')
  58. def test_array_type(self):
  59. array_type = ast_nodes.ArrayType(ast_nodes.ValueType('int'), (4,))
  60. self.assertEqual(str(array_type), 'int [4]')
  61. self.assertEqual(array_type.decl('var'), 'int var[4]')
  62. array_2d_type = ast_nodes.ArrayType(
  63. ast_nodes.ValueType('double', is_const=True), (2, 3))
  64. self.assertEqual(str(array_2d_type), 'const double [2][3]')
  65. self.assertEqual(array_2d_type.decl('var2'), 'const double var2[2][3]')
  66. array_to_pointer_type = ast_nodes.ArrayType(
  67. ast_nodes.PointerType(ast_nodes.ValueType('char', is_const=True)), (5,))
  68. self.assertEqual(str(array_to_pointer_type), 'const char * [5]')
  69. self.assertEqual(array_to_pointer_type.decl('var3'), 'const char * var3[5]')
  70. array_to_const_pointer_type = ast_nodes.ArrayType(
  71. ast_nodes.PointerType(ast_nodes.ValueType('float'), is_const=True),
  72. (7,))
  73. self.assertEqual(str(array_to_const_pointer_type), 'float * const [7]')
  74. self.assertEqual(
  75. array_to_const_pointer_type.decl('var4'), 'float * const var4[7]')
  76. def test_complex_type(self):
  77. complex_type = ast_nodes.ArrayType(
  78. extents=[9],
  79. inner_type=ast_nodes.PointerType(
  80. ast_nodes.PointerType(
  81. is_const=True,
  82. inner_type=ast_nodes.ArrayType(
  83. extents=[7],
  84. inner_type=ast_nodes.PointerType(
  85. is_const=True,
  86. inner_type=ast_nodes.PointerType(
  87. ast_nodes.ArrayType(
  88. extents=(3, 4),
  89. inner_type=ast_nodes.ValueType(
  90. 'unsigned int', is_const=True)
  91. )
  92. )
  93. )
  94. )
  95. )
  96. )
  97. )
  98. self.assertEqual(str(complex_type),
  99. 'const unsigned int (* * const (* const * [9])[7])[3][4]')
  100. self.assertEqual(
  101. complex_type.decl('var'),
  102. 'const unsigned int (* * const (* const * var[9])[7])[3][4]')
  103. def test_struct_decl(self):
  104. struct = ast_nodes.StructDecl(
  105. name='mystruct',
  106. declname='struct mystruct_',
  107. fields=[
  108. ast_nodes.StructFieldDecl(
  109. name='foo',
  110. type=ast_nodes.ValueType('int'),
  111. doc='',
  112. )
  113. ],
  114. )
  115. self.assertEqual(struct.decl('var'), 'mystruct var')
  116. def test_anonymous_struct_decl(self):
  117. struct = ast_nodes.AnonymousStructDecl(
  118. fields=[
  119. ast_nodes.StructFieldDecl(
  120. name='foo',
  121. type=ast_nodes.ValueType('int'),
  122. doc='',
  123. ),
  124. ast_nodes.StructFieldDecl(
  125. name='bar',
  126. type=ast_nodes.ArrayType(
  127. inner_type=ast_nodes.ValueType('float'), extents=(3,)
  128. ),
  129. doc='',
  130. ),
  131. ],
  132. )
  133. self.assertEqual(str(struct), 'struct {int foo; float bar[3];}')
  134. self.assertEqual(struct.decl('var'), 'struct {int foo; float bar[3];} var')
  135. self.assertEqual(struct.fields[0].decltype, 'int')
  136. self.assertEqual(struct.fields[1].decltype, 'float [3]')
  137. def test_anonymous_union_decl(self):
  138. union = ast_nodes.AnonymousUnionDecl(
  139. fields=[
  140. ast_nodes.StructFieldDecl(
  141. name='foo',
  142. type=ast_nodes.ValueType('int'),
  143. doc='',
  144. ),
  145. ast_nodes.StructFieldDecl(
  146. name='bar',
  147. type=ast_nodes.ArrayType(
  148. inner_type=ast_nodes.ValueType('float'), extents=(3,)
  149. ),
  150. doc='',
  151. ),
  152. ],
  153. )
  154. self.assertEqual(str(union), 'union {int foo; float bar[3];}')
  155. self.assertEqual(union.decl('var'), 'union {int foo; float bar[3];} var')
  156. if __name__ == '__main__':
  157. absltest.main()