|
@@ -88,7 +88,7 @@ class ImageColorizationPipeline(Pipeline): |
|
|
img = input.convert('LA').convert('RGB') |
|
|
img = input.convert('LA').convert('RGB') |
|
|
elif isinstance(input, np.ndarray): |
|
|
elif isinstance(input, np.ndarray): |
|
|
if len(input.shape) == 2: |
|
|
if len(input.shape) == 2: |
|
|
img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) |
|
|
|
|
|
|
|
|
input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) |
|
|
img = input[:, :, ::-1] # in rgb order |
|
|
img = input[:, :, ::-1] # in rgb order |
|
|
img = PIL.Image.fromarray(img).convert('LA').convert('RGB') |
|
|
img = PIL.Image.fromarray(img).convert('LA').convert('RGB') |
|
|
else: |
|
|
else: |
|
|