Esta entrada es parte del curso de Deep learning con PyTorch.
El módulo torch.utils.data de PyTorh tiene clases muy útiles para la carga de datos necesaria en los procesos de entrenamiento y validación. En este post vamos a ver dos de las clases más importantes, torch.utils.data.Dataset para albergar los datos y torch.utils.data.DataLoader para cargar los datos.
Las clases Dataset permiten instanciar objetos con el conjunto de datos que se van a cargar. PyTorch permite crear dos tipos distintos de datasets:
- Map-style: Implementa los métodos getitem() and len() y representa un mapeo de claves/índices a valores del conjunto de datos. La clase Dataset sería un ejemplo y es el tipo de dataset que veremos.
- Iterable-style: Implementa el método iter() y representa un iterable sobre los datos. La clase IterableDataset es un ejemplo.
A continuación vamos a ver un ejemplo de definición de una clase que contiene un conjunto de datos. Esta clase hereda la clase Dataset y sobreescribe el método getitem() que permite obtener una muestra de los datos a partir de la clave/índice y el método len() que devuelve el tamaño de los datos. A modo de ejemplo creamos unos datos con una lista de 1000 valores como muestra y y otra lista de 1000 valores como etiquetas.
import torch from torch.utils.data import Dataset class NumbersDataset(Dataset): def __init__(self): self.samples = list(range(1, 1001)) self.labels = list(range(1000, 2001)) def __len__(self): return len(self.samples) def __getitem__(self, idx): return self.samples[idx], self.labels[idx]
Creamos una instancia de la clase que hemos definido. Vemos el tamaño de los datos y varias muestras de los datos.
dataset = NumbersDataset() print(len(dataset)) print(dataset[100]) print(dataset[122:125])
1000 (101, 1100) ([123, 124, 125], [1122, 1123, 1124])
La clase torch.utils.data.DataLoader es la clase principal para cargar los datos. A esta clase se le pasa como argumento un objeto Dataset (map-style o iterable style) y tiene varias opciones como:
- Definir el orden y la forma de cargar los datos.
- Batching automático: Se carga un batch de datos de manera automática o manual.
- Realizar la carga de datos en un proceso o en múltiples procesos.
- Ubicar los tensores en memoria page-locked para agilizar su transferencia a la GPU.
A continuación, creamos una instancia de la clase torch.utils.data.DataLoader a la que pasamos el objeto dataset que hemos creado. Definimos un tamaño de batch de 10 y shuffle=False para que no se cambie el orden de los datos en cada epoch (recorrido completo de los datos).
batch_size=10 train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False) for i, (numbers, labels) in enumerate(train_loader): if i<11: print('Batch number %d'%(i+1)) print(numbers, labels)
Batch number 1 tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) tensor([1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009]) Batch number 2 tensor([11, 12, 13, 14, 15, 16, 17, 18, 19, 20]) tensor([1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019]) Batch number 3 tensor([21, 22, 23, 24, 25, 26, 27, 28, 29, 30]) tensor([1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029]) Batch number 4 tensor([31, 32, 33, 34, 35, 36, 37, 38, 39, 40]) tensor([1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039]) Batch number 5 tensor([41, 42, 43, 44, 45, 46, 47, 48, 49, 50]) tensor([1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049]) Batch number 6 tensor([51, 52, 53, 54, 55, 56, 57, 58, 59, 60]) tensor([1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059]) Batch number 7 tensor([61, 62, 63, 64, 65, 66, 67, 68, 69, 70]) tensor([1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069]) Batch number 8 tensor([71, 72, 73, 74, 75, 76, 77, 78, 79, 80]) tensor([1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079]) Batch number 9 tensor([81, 82, 83, 84, 85, 86, 87, 88, 89, 90]) tensor([1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089]) Batch number 10 tensor([ 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]) tensor([1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099]) Batch number 11 tensor([101, 102, 103, 104, 105, 106, 107, 108, 109, 110]) tensor([1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108, 1109])
Ahora, creamos otra instancia pero con shuffle=True para que se cambie el orden de los datos en cada epoch.
train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True) for i, (numbers, labels) in enumerate(train_loader): if i<11: print('Batch number %d'%(i+1)) print(numbers, labels)
Batch number 1 tensor([498, 561, 641, 592, 870, 242, 441, 161, 717, 683]) tensor([1497, 1560, 1640, 1591, 1869, 1241, 1440, 1160, 1716, 1682]) Batch number 2 tensor([782, 549, 768, 810, 990, 478, 89, 813, 394, 897]) tensor([1781, 1548, 1767, 1809, 1989, 1477, 1088, 1812, 1393, 1896]) Batch number 3 tensor([282, 411, 430, 635, 453, 197, 786, 814, 624, 228]) tensor([1281, 1410, 1429, 1634, 1452, 1196, 1785, 1813, 1623, 1227]) Batch number 4 tensor([370, 900, 238, 816, 894, 724, 620, 147, 359, 554]) tensor([1369, 1899, 1237, 1815, 1893, 1723, 1619, 1146, 1358, 1553]) Batch number 5 tensor([239, 193, 662, 819, 435, 157, 968, 126, 415, 722]) tensor([1238, 1192, 1661, 1818, 1434, 1156, 1967, 1125, 1414, 1721]) Batch number 6 tensor([342, 313, 924, 982, 636, 861, 90, 933, 178, 527]) tensor([1341, 1312, 1923, 1981, 1635, 1860, 1089, 1932, 1177, 1526]) Batch number 7 tensor([246, 278, 921, 347, 954, 696, 955, 268, 82, 277]) tensor([1245, 1277, 1920, 1346, 1953, 1695, 1954, 1267, 1081, 1276]) Batch number 8 tensor([766, 169, 841, 794, 908, 808, 949, 448, 583, 424]) tensor([1765, 1168, 1840, 1793, 1907, 1807, 1948, 1447, 1582, 1423]) Batch number 9 tensor([595, 522, 93, 416, 604, 677, 503, 373, 423, 365]) tensor([1594, 1521, 1092, 1415, 1603, 1676, 1502, 1372, 1422, 1364]) Batch number 10 tensor([553, 723, 490, 128, 101, 546, 661, 777, 966, 582]) tensor([1552, 1722, 1489, 1127, 1100, 1545, 1660, 1776, 1965, 1581]) Batch number 11 tensor([354, 781, 930, 88, 468, 648, 141, 361, 98, 632]) tensor([1353, 1780, 1929, 1087, 1467, 1647, 1140, 1360, 1097, 1631])
Deja una respuesta