You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils_test.cc 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "abstract/utils.h"
  17. #include "common/common_test.h"
  18. #include "pipeline/jit/static_analysis/static_analysis.h"
  19. namespace mindspore {
  20. namespace abstract {
  21. class TestUtils : public UT::Common {
  22. public:
  23. TestUtils() {}
  24. virtual void SetUp() {}
  25. virtual void TearDown() {}
  26. };
  27. TEST_F(TestUtils, test_join) {
  28. // AbstractScalar
  29. AbstractBasePtr abs_s1 = FromValue(static_cast<int64_t>(1), false);
  30. AbstractBasePtr abs_s2 = FromValue(static_cast<int64_t>(2), false);
  31. AbstractBasePtr abs_s_anything = FromValue(static_cast<int64_t>(2), true);
  32. AbstractBasePtr res_s1 = abs_s1->Join(abs_s2);
  33. ASSERT_EQ(*res_s1, *abs_s_anything);
  34. // AbstractTuple join;
  35. std::vector<int64_t> list1 = {1, 2, 3, 4, 5};
  36. std::vector<int64_t> list2 = {5, 4, 3, 2, 1};
  37. AbstractBasePtr abs_t1 = FromValue(list1, true);
  38. AbstractBasePtr abs_t2 = FromValue(list2, true);
  39. AbstractBasePtr res_t1 = abs_t1->Join(abs_t2);
  40. ASSERT_EQ(res_t1, abs_t1);
  41. abs_s1 = FromValue(static_cast<int64_t>(1), false);
  42. AbstractBasePtr t1 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
  43. AbstractBasePtr t2 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s1, abs_s_anything}));
  44. AbstractBasePtr t3 = std::make_shared<AbstractTuple>(AbstractBasePtrList({abs_s_anything, abs_s_anything}));
  45. res_t1 = t1->Join(t2);
  46. ASSERT_EQ(res_t1, t1);
  47. res_t1 = t1->Join(t3);
  48. ASSERT_EQ(*res_t1, *t3);
  49. res_t1 = t3->Join(t1);
  50. ASSERT_EQ(res_t1, t3);
  51. }
  52. } // namespace abstract
  53. } // namespace mindspore