"""Size and uniqueness checks for the predefined RLBench ``*_V1`` task sets.

Every ``FS*_V1`` set is indexed with ``'train'`` and ``'test'`` keys; both
halves must have the expected lengths, no task may appear twice in their
concatenation, and no train task may also be a test task.  Every ``MT*_V1``
set is indexed with ``'train'`` only; it must have the expected length and
contain no repeated task.
"""

import unittest

from rlbench.tasks import (
    FS10_V1, FS25_V1, FS50_V1, FS95_V1,
    MT15_V1, MT30_V1, MT55_V1, MT100_V1,
)

# Each entry: (task set, expected number of train tasks, expected number of
# test tasks).
FS_V1 = [
    (FS10_V1, 10, 5),
    (FS25_V1, 25, 5),
    (FS50_V1, 50, 5),
    (FS95_V1, 95, 5),
]

# Each entry: (task set, expected number of train tasks).
MT_V1 = [
    (MT15_V1, 15),
    (MT30_V1, 30),
    (MT55_V1, 55),
    (MT100_V1, 100),
]


class TestTaskSet(unittest.TestCase):
    """Validate the FS_V1 and MT_V1 task-set tables declared above."""

    def test_fs_v1(self):
        # One subTest per FS set, labelled e.g. 'FS10_V1', so a failure in
        # one set does not hide failures in the others.
        for task_set, n_train, n_test in FS_V1:
            with self.subTest(task_set=f'FS{n_train}_V1'):
                train_tasks = task_set['train']
                test_tasks = task_set['test']

                self.assertEqual(len(train_tasks), n_train)
                self.assertEqual(len(test_tasks), n_test)

                # Collapsing into a set drops repeats, so equal lengths mean
                # every task across both halves is distinct (this requires
                # the tasks to be hashable).
                combined = train_tasks + test_tasks
                self.assertEqual(len(combined), len(set(combined)))

                # Explicit disjointness check between the two halves.
                shared = any(task in test_tasks for task in train_tasks)
                self.assertFalse(shared)

    def test_mt_v1(self):
        # One subTest per MT set, labelled e.g. 'MT15_V1'.
        for task_set, n_train in MT_V1:
            with self.subTest(task_set=f'MT{n_train}_V1'):
                train_tasks = task_set['train']

                self.assertEqual(len(train_tasks), n_train)

                # Same set-based check: no task is listed twice.
                self.assertEqual(len(train_tasks), len(set(train_tasks)))