123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- # Copyright 2022 DeepMind Technologies Limited
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Tests for ast_nodes.py."""
- from absl.testing import absltest
- from . import ast_nodes
class AstNodesTest(absltest.TestCase):
  """Tests C type/declaration rendering of the `ast_nodes` classes.

  Each case checks both `str(node)` (the abstract type, with no name) and
  `node.decl(name)` (a full declaration of a variable with that type).
  """

  def test_value_type(self):
    """Scalar types render verbatim, with `const` prepended when requested."""
    value_type = ast_nodes.ValueType('int')
    self.assertEqual(str(value_type), 'int')
    self.assertEqual(value_type.decl('var'), 'int var')

    const_value_type = ast_nodes.ValueType('double', is_const=True)
    self.assertEqual(str(const_value_type), 'const double')
    self.assertEqual(const_value_type.decl('var2'), 'const double var2')

  def test_pointer_type(self):
    """Pointer qualifiers follow `*`; pointers to arrays are parenthesized."""
    pointer_type = ast_nodes.PointerType(ast_nodes.ValueType('int'))
    self.assertEqual(str(pointer_type), 'int *')
    self.assertEqual(pointer_type.decl('var'), 'int * var')

    # `is_const` on the PointerType qualifies the pointer itself...
    const_pointer_type = ast_nodes.PointerType(
        ast_nodes.ValueType('double'), is_const=True)
    self.assertEqual(str(const_pointer_type), 'double * const')
    self.assertEqual(const_pointer_type.decl('var2'), 'double * const var2')

    # ...whereas `is_const` on the inner type qualifies the pointee.
    pointer_to_const_type = ast_nodes.PointerType(
        ast_nodes.ValueType('float', is_const=True))
    self.assertEqual(str(pointer_to_const_type), 'const float *')
    self.assertEqual(pointer_to_const_type.decl('var3'), 'const float * var3')

    # Multiple pointer qualifiers are emitted in the order volatile, restrict.
    restrict_volatile_pointer_to_const_type = ast_nodes.PointerType(
        ast_nodes.ValueType('char', is_const=True),
        is_volatile=True, is_restrict=True)
    self.assertEqual(str(restrict_volatile_pointer_to_const_type),
                     'const char * volatile restrict')
    self.assertEqual(
        restrict_volatile_pointer_to_const_type.decl('var4'),
        'const char * volatile restrict var4')

    # A pointer to an array wraps `*` (and the name) in parentheses so that it
    # is not read as an array of pointers (compare test_array_type).
    pointer_to_array_type = ast_nodes.PointerType(
        ast_nodes.ArrayType(ast_nodes.ValueType('long'), (3,)))
    self.assertEqual(str(pointer_to_array_type), 'long (*)[3]')
    self.assertEqual(pointer_to_array_type.decl('var5'), 'long (* var5)[3]')

    const_pointer_to_array_type = ast_nodes.PointerType(
        ast_nodes.ArrayType(ast_nodes.ValueType('unsigned int'), (4,)),
        is_const=True)
    self.assertEqual(
        str(const_pointer_to_array_type), 'unsigned int (* const)[4]')
    self.assertEqual(
        const_pointer_to_array_type.decl('var6'),
        'unsigned int (* const var6)[4]')

  def test_array_type(self):
    """Array extents are appended after the declared name, one per dimension."""
    array_type = ast_nodes.ArrayType(ast_nodes.ValueType('int'), (4,))
    self.assertEqual(str(array_type), 'int [4]')
    self.assertEqual(array_type.decl('var'), 'int var[4]')

    array_2d_type = ast_nodes.ArrayType(
        ast_nodes.ValueType('double', is_const=True), (2, 3))
    self.assertEqual(str(array_2d_type), 'const double [2][3]')
    self.assertEqual(array_2d_type.decl('var2'), 'const double var2[2][3]')

    # An array of pointers needs no parentheses, unlike a pointer to an array.
    array_to_pointer_type = ast_nodes.ArrayType(
        ast_nodes.PointerType(ast_nodes.ValueType('char', is_const=True)), (5,))
    self.assertEqual(str(array_to_pointer_type), 'const char * [5]')
    self.assertEqual(array_to_pointer_type.decl('var3'), 'const char * var3[5]')

    array_to_const_pointer_type = ast_nodes.ArrayType(
        ast_nodes.PointerType(ast_nodes.ValueType('float'), is_const=True),
        (7,))
    self.assertEqual(str(array_to_const_pointer_type), 'float * const [7]')
    self.assertEqual(
        array_to_const_pointer_type.decl('var4'), 'float * const var4[7]')

  def test_complex_type(self):
    """Deeply nested pointer/array types render with correct parenthesization."""
    # From the outside in: an array of 9 pointers to const pointers to arrays
    # of 7 const pointers to pointers to 3x4 arrays of const unsigned int.
    complex_type = ast_nodes.ArrayType(
        extents=[9],
        inner_type=ast_nodes.PointerType(
            ast_nodes.PointerType(
                is_const=True,
                inner_type=ast_nodes.ArrayType(
                    extents=[7],
                    inner_type=ast_nodes.PointerType(
                        is_const=True,
                        inner_type=ast_nodes.PointerType(
                            ast_nodes.ArrayType(
                                extents=(3, 4),
                                inner_type=ast_nodes.ValueType(
                                    'unsigned int', is_const=True),
                            ),
                        ),
                    ),
                ),
            ),
        ),
    )
    self.assertEqual(
        str(complex_type),
        'const unsigned int (* * const (* const * [9])[7])[3][4]')
    self.assertEqual(
        complex_type.decl('var'),
        'const unsigned int (* * const (* const * var[9])[7])[3][4]')

  def test_struct_decl(self):
    """A named struct is declared using its `name`, not its `declname`."""
    struct = ast_nodes.StructDecl(
        name='mystruct',
        declname='struct mystruct_',
        fields=[
            ast_nodes.StructFieldDecl(
                name='foo',
                type=ast_nodes.ValueType('int'),
                doc='',
            )
        ],
    )
    self.assertEqual(struct.decl('var'), 'mystruct var')

  def test_anonymous_struct_decl(self):
    """An anonymous struct renders its fields inline inside `struct {...}`."""
    struct = ast_nodes.AnonymousStructDecl(
        fields=[
            ast_nodes.StructFieldDecl(
                name='foo',
                type=ast_nodes.ValueType('int'),
                doc='',
            ),
            ast_nodes.StructFieldDecl(
                name='bar',
                type=ast_nodes.ArrayType(
                    inner_type=ast_nodes.ValueType('float'), extents=(3,)
                ),
                doc='',
            ),
        ],
    )
    self.assertEqual(str(struct), 'struct {int foo; float bar[3];}')
    self.assertEqual(struct.decl('var'), 'struct {int foo; float bar[3];} var')
    # `decltype` is the field's type rendered without the field name.
    self.assertEqual(struct.fields[0].decltype, 'int')
    self.assertEqual(struct.fields[1].decltype, 'float [3]')

  def test_anonymous_union_decl(self):
    """An anonymous union renders its fields inline inside `union {...}`."""
    union = ast_nodes.AnonymousUnionDecl(
        fields=[
            ast_nodes.StructFieldDecl(
                name='foo',
                type=ast_nodes.ValueType('int'),
                doc='',
            ),
            ast_nodes.StructFieldDecl(
                name='bar',
                type=ast_nodes.ArrayType(
                    inner_type=ast_nodes.ValueType('float'), extents=(3,)
                ),
                doc='',
            ),
        ],
    )
    self.assertEqual(str(union), 'union {int foo; float bar[3];}')
    self.assertEqual(union.decl('var'), 'union {int foo; float bar[3];} var')
    # Same field declarations as the anonymous struct case, so the same
    # per-field `decltype` renderings are expected.
    self.assertEqual(union.fields[0].decltype, 'int')
    self.assertEqual(union.fields[1].decltype, 'float [3]')
# Run the tests via absltest when this module is executed directly.
if __name__ == '__main__':
  absltest.main()
|