You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_file.py 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import tempfile
  4. import unittest
  5. from requests import HTTPError
  6. from modelscope.fileio.file import File, HTTPStorage, LocalStorage
  7. class FileTest(unittest.TestCase):
  8. def test_local_storage(self):
  9. storage = LocalStorage()
  10. temp_name = tempfile.gettempdir() + '/' + next(
  11. tempfile._get_candidate_names())
  12. binary_content = b'12345'
  13. storage.write(binary_content, temp_name)
  14. self.assertEqual(binary_content, storage.read(temp_name))
  15. content = '12345'
  16. storage.write_text(content, temp_name)
  17. self.assertEqual(content, storage.read_text(temp_name))
  18. os.remove(temp_name)
  19. def test_http_storage(self):
  20. storage = HTTPStorage()
  21. url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/texts/data.txt'
  22. content = 'this is test data'
  23. self.assertEqual(content.encode('utf8'), storage.read(url))
  24. self.assertEqual(content, storage.read_text(url))
  25. with storage.as_local_path(url) as local_file:
  26. with open(local_file, 'r') as infile:
  27. self.assertEqual(content, infile.read())
  28. with self.assertRaises(NotImplementedError):
  29. storage.write('dfad', url)
  30. with self.assertRaises(HTTPError):
  31. storage.read(url + 'df')
  32. def test_file(self):
  33. url = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/texts/data.txt'
  34. content = 'this is test data'
  35. self.assertEqual(content.encode('utf8'), File.read(url))
  36. with File.as_local_path(url) as local_file:
  37. with open(local_file, 'r') as infile:
  38. self.assertEqual(content, infile.read())
  39. with self.assertRaises(NotImplementedError):
  40. File.write('dfad', url)
  41. with self.assertRaises(HTTPError):
  42. File.read(url + 'df')
  43. temp_name = tempfile.gettempdir() + '/' + next(
  44. tempfile._get_candidate_names())
  45. binary_content = b'12345'
  46. File.write(binary_content, temp_name)
  47. self.assertEqual(binary_content, File.read(temp_name))
  48. os.remove(temp_name)
  49. if __name__ == '__main__':
  50. unittest.main()