使用 TensorFlow 实现 U-Net 的一个例子

[复制链接]
 楼主| keer_zu 发表于 2019-3-31 09:40 | 显示全部楼层 |阅读模式
Basic U-Net using TensorFlow


  1. import os
  2. import sys
  3. import numpy as np
  4. import tensorflow as tf
  5. import random
  6. import math
  7. import warnings
  8. import pandas as pd
  9. import cv2
  10. import matplotlib.pyplot as plt

  11. from tqdm import tqdm
  12. from itertools import chain
  13. from skimage.io import imread, imshow, imread_collection, concatenate_images
  14. from skimage.transform import resize
  15. from skimage.morphology import label

  16. warnings.filterwarnings('ignore', category=UserWarning, module='skimage')
  17. seed = 42
  18. random.seed = seed
  19. np.random.seed = seed

  20. # Set some parameters
  21. IMG_WIDTH = 128
  22. IMG_HEIGHT = 128
  23. IMG_CHANNELS = 3
  24. TRAIN_PATH = '../input/stage1_train/'
  25. TEST_PATH = '../input/stage1_test/'

  26. warnings.filterwarnings('ignore', category=UserWarning, module='skimage')
  27. seed = 42
  28. random.seed = seed
  29. np.random.seed = seed

  30. # Get train and test IDs
  31. train_ids = next(os.walk(TRAIN_PATH))[1]
  32. test_ids = next(os.walk(TEST_PATH))[1]


  33. # Get and resize train images and masks
  34. images = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
  35. labels = np.zeros((len(train_ids), IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool)
  36. print('Getting and resizing train images and masks ... ')
  37. sys.stdout.flush()
  38. for n, id_ in tqdm(enumerate(train_ids), total=len(train_ids)):
  39.     path = TRAIN_PATH + id_
  40.     img = imread(path + '/images/' + id_ + '.png')[:,:,:IMG_CHANNELS]
  41.     img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
  42.     images[n] = img
  43.     mask = np.zeros((IMG_HEIGHT, IMG_WIDTH, 1), dtype=np.bool)
  44.     for mask_file in next(os.walk(path + '/masks/'))[2]:
  45.         mask_ = imread(path + '/masks/' + mask_file)
  46.         mask_ = np.expand_dims(resize(mask_, (IMG_HEIGHT, IMG_WIDTH), mode='constant',
  47.                                       preserve_range=True), axis=-1)
  48.         mask = np.maximum(mask, mask_)
  49.     labels[n] = mask

  50. X_train = images
  51. Y_train = labels

  52. # Get and resize test images
  53. X_test = np.zeros((len(test_ids), IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), dtype=np.uint8)
  54. sizes_test = []
  55. print('Getting and resizing test images ... ')
  56. sys.stdout.flush()
  57. for n, id_ in tqdm(enumerate(test_ids), total=len(test_ids)):
  58.     path = TEST_PATH + id_
  59.     img = imread(path + '/images/' + id_ + '.png')[:,:,:IMG_CHANNELS]
  60.     sizes_test.append([img.shape[0], img.shape[1]])
  61.     img = resize(img, (IMG_HEIGHT, IMG_WIDTH), mode='constant', preserve_range=True)
  62.     X_test[n] = img

  63. print('Done!')





您需要登录后才可以回帖 登录 | 注册

本版积分规则

个人签名:qq群:49734243 Email:zukeqiang@gmail.com

1474

主题

12900

帖子

55

粉丝
快速回复 返回顶部 返回列表