{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "project_demo_en.ipynb",
"version": "0.3.2",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"metadata": {
"id": "wOO_TUhsCcbq",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Project work\n",
"\n",
"**Step 1:** register on Kaggle: https://www.kaggle.com\n",
"\n",
"**Step 2:** join the competition: https://www.kaggle.com/t/b0fc1fc485b146a2887ab6ab8b71c2a8\n",
"\n",
"**Step 3:** generate an API token (*My Account -> Create New API Token*), which downloads a `kaggle.json` file, and upload that file to the Colab machine with the next cell."
]
},
{
"metadata": {
"id": "46qTSOr5CgxZ",
"colab_type": "code",
"colab": {
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc3luYyBnZW5lcmF
0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhcmVudEVsZW1lbnQ
uYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkYXRhIHRvIGJlIHJ
lYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=",
"ok": true,
"headers": [
[
"content-type",
"application/javascript"
]
],
"status": 200,
"status_text": "OK"
}
},
"base_uri": "https://localhost:8080/",
"height": 101
},
"outputId": "32061b88-078c-40d4-9ce1-91ab5d3ba7df"
},
"cell_type": "code",
"source": [
"# Upload the kaggle.json API token (see step 3 above).\n",
"from google.colab import files\n",
"\n",
"# Bind the result instead of leaving it as the cell's last expression:\n",
"# files.upload() returns {filename: file bytes}, and displaying that dict\n",
"# would print the API key into the saved notebook output.\n",
"uploaded = files.upload()"
],
"execution_count": 3,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" \n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving kaggle.json to kaggle.json\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'kaggle.json': b'{\"username\":\"REDACTED\",\"key\":\"REDACTED\"}'}"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"metadata": {
"id": "aRQt17mACcb2",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "3f0fac17-0d66-4cd6-ae49-569101109bdf"
},
"cell_type": "code",
"source": [
"# Put the token where the kaggle client used below looks for it,\n",
"# readable by its owner only (chmod 600).\n",
"!mkdir -p /root/.kaggle\n",
"!cp kaggle.json /root/.kaggle/kaggle.json\n",
"!chmod 600 /root/.kaggle/kaggle.json\n",
"!ls -l /root/.kaggle"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"kaggle.json\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "zuBDBRw5CccR",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"# Download data"
]
},
{
"metadata": {
"scrolled": true,
"id": "6Sab2JLkCccX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 178
},
"outputId": "6539048d-61a9-4837-cd34-6fec5d070f5c"
},
"cell_type": "code",
"source": [
"# Authenticate with the token installed above, then download every\n",
"# competition archive into the current working directory.\n",
"import kaggle\n",
"\n",
"COMPETITION = 'artificial-neural-networks-and-their-applications'\n",
"\n",
"kaggle.api.authenticate()\n",
"!kaggle competitions download -c {COMPETITION}"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading extra_info.tar.gz to /content\n",
"\r 0% 0.00/608k [00:00, ?B/s]\n",
"100% 608k/608k [00:00<00:00, 39.9MB/s]\n",
"Downloading test_images.tar.gz to /content\n",
" 38% 5.00M/13.0M [00:00<00:00, 17.0MB/s]\n",
"100% 13.0M/13.0M [00:00<00:00, 37.4MB/s]\n",
"Downloading train.tar.gz to /content\n",
" 90% 119M/132M [00:01<00:00, 59.1MB/s]\n",
"100% 132M/132M [00:01<00:00, 74.9MB/s]\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "3sZpJo_LCccn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 71
},
"outputId": "dbb443e1-3075-42c3-afb7-0b0e7ec35e10"
},
"cell_type": "code",
"source": [
"# Gather the downloaded archives in project/ and unpack the training and test\n",
"# images there. `mkdir -p` and `tar -C` keep the cell re-runnable without\n",
"# changing the working directory. (extra_info.tar.gz is moved but not extracted.)\n",
"!mkdir -p project\n",
"!mv *.tar.gz project/\n",
"!ls project\n",
"\n",
"!tar -xzf project/train.tar.gz -C project\n",
"!tar -xzf project/test_images.tar.gz -C project"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/project\n",
"extra_info.tar.gz test_images.tar.gz train.tar.gz\n",
"/content\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "Dw8rLyzqCcc4",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Load the training data\n",
"Each subdirectory of `project/train` holds the images of one class, and the subdirectory name (e.g. `n01443537`) is the class label."
]
},
{
"metadata": {
"id": "Meuhj_RbCcc9",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 3590
},
"outputId": "640f5b28-a2cf-4d2a-c35d-dfa93db388b7"
},
"cell_type": "code",
"source": [
"import os\n",
"import numpy as np\n",
"from matplotlib.image import imread\n",
"\n",
"TRAIN_DIR = 'project/train'\n",
"NUM_FEATURES = 64*64*3  # flattened 64x64 RGB image\n",
"\n",
"# Sort class directories and file names so class ids (and row order) are\n",
"# reproducible: os.listdir returns entries in arbitrary order.\n",
"categories = sorted(os.listdir(TRAIN_DIR))\n",
"files_per_class = [sorted(os.listdir(os.path.join(TRAIN_DIR, category, 'images')))\n",
"                   for category in categories]\n",
"num_images = sum(len(names) for names in files_per_class)\n",
"\n",
"# Size the arrays from the real file count instead of a hard-coded 100000 that\n",
"# had to match train_y exactly, and fill the labels in place: np.append inside\n",
"# the loop copied the whole label array for every image (quadratic time).\n",
"# NOTE(review): float64 features take ~9.8 GB for 100k images; float32 would\n",
"# halve that, but the test features must then use the same dtype.\n",
"train_x = np.zeros((num_images, NUM_FEATURES))\n",
"train_y = np.zeros(num_images, dtype=np.int32)\n",
"class_mapping = {}  # class id -> class directory name\n",
"idx = 0\n",
"for class_id, (category, names) in enumerate(zip(categories, files_per_class)):\n",
"    for img in names:\n",
"        image = imread(os.path.join(TRAIN_DIR, category, 'images', img))\n",
"        if image.ndim == 2:\n",
"            # grayscale image: replicate the single channel into R, G and B\n",
"            image = np.stack([image] * 3, axis=-1)\n",
"        # scale pixel values by 1/255 (assumes 8-bit images) and flatten to one row\n",
"        train_x[idx, :] = np.float32(image / 255.0).reshape(-1)\n",
"        train_y[idx] = class_id\n",
"        idx += 1\n",
"    class_mapping[class_id] = category\n",
"\n",
"print('Loaded %d images from %d classes' % (idx, len(class_mapping)))"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"n03100240\n",
"n07734744\n",
"n04532106\n",
"n04070727\n",
"n01945685\n",
"n04417672\n",
"n03891332\n",
"n03706229\n",
"n03796401\n",
"n04118538\n",
"n04532670\n",
"n02226429\n",
"n01983481\n",
"n02231487\n",
"n03976657\n",
"n02094433\n",
"n02988304\n",
"n02281406\n",
"n04259630\n",
"n02132136\n",
"n01641577\n",
"n01910747\n",
"n02793495\n",
"n03201208\n",
"n07583066\n",
"n03355925\n",
"n02823428\n",
"n02125311\n",
"n03733131\n",
"n03902125\n",
"n09256479\n",
"n02841315\n",
"n01443537\n",
"n02917067\n",
"n07711569\n",
"n01644900\n",
"n03584254\n",
"n02236044\n",
"n04328186\n",
"n02279972\n",
"n02190166\n",
"n03980874\n",
"n04540053\n",
"n07875152\n",
"n03763968\n",
"n03930313\n",
"n03837869\n",
"n09246464\n",
"n03160309\n",
"n02843684\n",
"n03126707\n",
"n01984695\n",
"n03026506\n",
"n12267677\n",
"n02909870\n",
"n04149813\n",
"n09193705\n",
"n07920052\n",
"n04487081\n",
"n03983396\n",
"n04008634\n",
"n03617480\n",
"n07747607\n",
"n01774750\n",
"n02074367\n",
"n02814860\n",
"n03637318\n",
"n02437312\n",
"n02892201\n",
"n02403003\n",
"n02504458\n",
"n04251144\n",
"n02769748\n",
"n09332890\n",
"n02002724\n",
"n04501370\n",
"n04311004\n",
"n02410509\n",
"n07720875\n",
"n03770439\n",
"n03599486\n",
"n02099712\n",
"n01784675\n",
"n02669723\n",
"n03838899\n",
"n02791270\n",
"n07695742\n",
"n07579787\n",
"n02814533\n",
"n03814639\n",
"n02999410\n",
"n07614500\n",
"n01855672\n",
"n04133789\n",
"n04146614\n",
"n02123394\n",
"n02206856\n",
"n04465501\n",
"n02106662\n",
"n01917289\n",
"n02927161\n",
"n02666196\n",
"n02950826\n",
"n01944390\n",
"n02123045\n",
"n02730930\n",
"n04597913\n",
"n03662601\n",
"n03544143\n",
"n03042490\n",
"n02481823\n",
"n03670208\n",
"n07615774\n",
"n03970156\n",
"n01629819\n",
"n02364673\n",
"n02788148\n",
"n03447447\n",
"n04099969\n",
"n02480495\n",
"n01742172\n",
"n03854065\n",
"n03085013\n",
"n03649909\n",
"n02963159\n",
"n01770393\n",
"n04562935\n",
"n04074963\n",
"n01698640\n",
"n02977058\n",
"n03255030\n",
"n04254777\n",
"n04376876\n",
"n02948072\n",
"n04023962\n",
"n02268443\n",
"n03937543\n",
"n03424325\n",
"n04486054\n",
"n02233338\n",
"n07768694\n",
"n02099601\n",
"n04366367\n",
"n04560804\n",
"n07749582\n",
"n02085620\n",
"n04596742\n",
"n02837789\n",
"n02124075\n",
"n03014705\n",
"n03977966\n",
"n03250847\n",
"n02113799\n",
"n04507155\n",
"n02415577\n",
"n04356056\n",
"n02815834\n",
"n03179701\n",
"n04371430\n",
"n04399382\n",
"n02486410\n",
"n02165456\n",
"n04456115\n",
"n04179913\n",
"n09428293\n",
"n02795169\n",
"n01774384\n",
"n02699494\n",
"n03404251\n",
"n01882714\n",
"n02802426\n",
"n03992509\n",
"n03444034\n",
"n02056570\n",
"n01950731\n",
"n03804744\n",
"n04275548\n",
"n04285008\n",
"n03400231\n",
"n01768244\n",
"n02423022\n",
"n03089624\n",
"n07871810\n",
"n04067472\n",
"n07873807\n",
"n02509815\n",
"n03393912\n",
"n02808440\n",
"n02906734\n",
"n07753592\n",
"n03388043\n",
"n02321529\n",
"n02395406\n",
"n02883205\n",
"n07715103\n",
"n04398044\n",
"n02129165\n",
"n02058221\n",
"n04265275\n",
"n06596364\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "tnTcgieHCcdK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 53
},
"outputId": "16b8a757-8fe3-45fe-dba2-37a298146201"
},
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"\n",
"# The training input function needs exactly one label per feature row.\n",
"assert train_x.shape[0] == train_y.shape[0], (train_x.shape, train_y.shape)\n",
"print(train_x.shape, train_y.shape)\n",
"\n",
"# Check that TensorFlow can see a GPU.\n",
"tf.test.is_gpu_available()"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"(100000, 12288) (100000,)\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"metadata": {
"id": "xI-DyoHyCcdX",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Train a simple ANN"
]
},
{
"metadata": {
"id": "2BWxU9-zCcde",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 3503
},
"outputId": "e7f01663-cf6a-46a1-e4c9-44a68c314c4f"
},
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
"\n",
"\n",
"# Training parameters\n",
"learning_rate = 0.01\n",
"batch_size = 128\n",
"num_epochs = 10\n",
"# Integer number of steps for num_epochs passes over the data, rounded up\n",
"# (plain `/` produced a float, e.g. 7812.5).\n",
"num_steps = (train_x.shape[0] * num_epochs + batch_size - 1) // batch_size\n",
"display_step = 100\n",
"\n",
"# Network parameters\n",
"n_hidden_1 = 256  # 1st layer number of neurons\n",
"n_hidden_2 = 256  # 2nd layer number of neurons\n",
"num_input = 64*64*3  # flattened 64x64 RGB image\n",
"num_classes = len(class_mapping)  # one output per training class directory\n",
"\n",
"# Input function for training: shuffled batches repeated indefinitely\n",
"# (num_epochs=None); the training length is set by `steps` below.\n",
"input_fn = tf.estimator.inputs.numpy_input_fn(\n",
"    x={'images': train_x}, y=train_y,\n",
"    batch_size=batch_size, num_epochs=None, shuffle=True)\n",
"\n",
"\n",
"# Fully connected network: two ReLU hidden layers and a linear output layer.\n",
"# Returns unnormalized logits, one per class.\n",
"def neural_net(x_dict):\n",
"    # TF Estimator input is a dict, in case of multiple inputs.\n",
"    # Compute in float32: the numpy feature arrays default to float64, which\n",
"    # is much slower on the GPU.\n",
"    x = tf.cast(x_dict['images'], tf.float32)\n",
"    # The ReLU activations make the model non-linear; without them the\n",
"    # stacked dense layers collapse into a single linear map.\n",
"    layer_1 = tf.layers.dense(x, n_hidden_1, activation=tf.nn.relu)\n",
"    layer_2 = tf.layers.dense(layer_1, n_hidden_2, activation=tf.nn.relu)\n",
"    # Output fully connected layer with a neuron for each class\n",
"    out_layer = tf.layers.dense(layer_2, num_classes)\n",
"    return out_layer\n",
"\n",
"\n",
"# Model function following the TF Estimator template: builds the network and\n",
"# returns an EstimatorSpec for prediction, training or evaluation.\n",
"def model_fn(features, labels, mode):\n",
"    logits = neural_net(features)\n",
"    pred_classes = tf.argmax(logits, axis=1)\n",
"\n",
"    # Prediction mode has no labels: return the predicted class ids only.\n",
"    if mode == tf.estimator.ModeKeys.PREDICT:\n",
"        return tf.estimator.EstimatorSpec(mode, predictions=pred_classes)\n",
"\n",
"    # The sparse cross-entropy needs integer class ids.\n",
"    labels = tf.cast(labels, dtype=tf.int32)\n",
"    loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
"        logits=logits, labels=labels))\n",
"    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)\n",
"    train_op = optimizer.minimize(loss_op, global_step=tf.train.get_global_step())\n",
"\n",
"    # Accuracy metric, reported when the estimator is evaluated\n",
"    acc_op = tf.metrics.accuracy(labels=labels, predictions=pred_classes)\n",
"\n",
"    return tf.estimator.EstimatorSpec(\n",
"        mode=mode,\n",
"        predictions=pred_classes,\n",
"        loss=loss_op,\n",
"        train_op=train_op,\n",
"        eval_metric_ops={'accuracy': acc_op})\n",
"\n",
"\n",
"# Build the Estimator (no model_dir: checkpoints go to a temporary folder)\n",
"model = tf.estimator.Estimator(model_fn)\n",
"\n",
"# Train the model\n",
"model.train(input_fn, steps=num_steps)"
],
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Using default config.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpt8cbvqzw\n",
"INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpt8cbvqzw', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': , '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Colocations handled automatically by placer.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:62: QueueRunner.__init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"To construct input pipelines, use the `tf.data` module.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:500: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"To construct input pipelines, use the `tf.data` module.\n",
"INFO:tensorflow:Calling model_fn.\n",
"WARNING:tensorflow:From :27: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use keras.layers.dense instead.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/metrics_impl.py:455: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use tf.cast instead.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Create CheckpointSaverHook.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/monitored_session.py:809: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"To construct input pipelines, use the `tf.data` module.\n",
"INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpt8cbvqzw/model.ckpt.\n",
"INFO:tensorflow:loss = 5.532808293561921, step = 0\n",
"INFO:tensorflow:global_step/sec: 51.7325\n",
"INFO:tensorflow:loss = 2.7909423101056046, step = 100 (1.937 sec)\n",
"INFO:tensorflow:global_step/sec: 53.4774\n",
"INFO:tensorflow:loss = 7.145926990041952, step = 200 (1.869 sec)\n",
"INFO:tensorflow:global_step/sec: 53.4078\n",
"INFO:tensorflow:loss = 4.81749445879914, step = 300 (1.871 sec)\n",
"INFO:tensorflow:global_step/sec: 53.9997\n",
"INFO:tensorflow:loss = 3.3063380197173475, step = 400 (1.852 sec)\n",
"INFO:tensorflow:global_step/sec: 53.2796\n",
"INFO:tensorflow:loss = 3.1607801329321514, step = 500 (1.876 sec)\n",
"INFO:tensorflow:global_step/sec: 55.324\n",
"INFO:tensorflow:loss = 2.809981796259568, step = 600 (1.810 sec)\n",
"INFO:tensorflow:global_step/sec: 53.1836\n",
"INFO:tensorflow:loss = 5.004986839579538, step = 700 (1.881 sec)\n",
"INFO:tensorflow:global_step/sec: 54.3265\n",
"INFO:tensorflow:loss = 2.2008611628128465, step = 800 (1.838 sec)\n",
"INFO:tensorflow:global_step/sec: 54.2645\n",
"INFO:tensorflow:loss = 2.5357822317504786, step = 900 (1.842 sec)\n",
"INFO:tensorflow:global_step/sec: 54.5605\n",
"INFO:tensorflow:loss = 2.47889681700948, step = 1000 (1.833 sec)\n",
"INFO:tensorflow:global_step/sec: 54.0196\n",
"INFO:tensorflow:loss = 1.911211729391921, step = 1100 (1.855 sec)\n",
"INFO:tensorflow:global_step/sec: 53.9353\n",
"INFO:tensorflow:loss = 2.842434887888386, step = 1200 (1.850 sec)\n",
"INFO:tensorflow:global_step/sec: 54.737\n",
"INFO:tensorflow:loss = 2.84526290277222, step = 1300 (1.828 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9719\n",
"INFO:tensorflow:loss = 1.8235673739875082, step = 1400 (1.822 sec)\n",
"INFO:tensorflow:global_step/sec: 53.8968\n",
"INFO:tensorflow:loss = 2.5243631505851707, step = 1500 (1.855 sec)\n",
"INFO:tensorflow:global_step/sec: 52.5537\n",
"INFO:tensorflow:loss = 1.6244772137655281, step = 1600 (1.899 sec)\n",
"INFO:tensorflow:global_step/sec: 54.7475\n",
"INFO:tensorflow:loss = 2.1790337530184245, step = 1700 (1.827 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9307\n",
"INFO:tensorflow:loss = 2.336739694712284, step = 1800 (1.820 sec)\n",
"INFO:tensorflow:global_step/sec: 53.9247\n",
"INFO:tensorflow:loss = 2.8000953754910425, step = 1900 (1.854 sec)\n",
"INFO:tensorflow:global_step/sec: 54.2789\n",
"INFO:tensorflow:loss = 1.5331674927531205, step = 2000 (1.847 sec)\n",
"INFO:tensorflow:global_step/sec: 53.5025\n",
"INFO:tensorflow:loss = 2.372903232156319, step = 2100 (1.868 sec)\n",
"INFO:tensorflow:global_step/sec: 54.3026\n",
"INFO:tensorflow:loss = 1.7382900286608827, step = 2200 (1.840 sec)\n",
"INFO:tensorflow:global_step/sec: 52.5714\n",
"INFO:tensorflow:loss = 1.719172717983127, step = 2300 (1.901 sec)\n",
"INFO:tensorflow:global_step/sec: 50.926\n",
"INFO:tensorflow:loss = 2.751872650256067, step = 2400 (1.964 sec)\n",
"INFO:tensorflow:global_step/sec: 50.84\n",
"INFO:tensorflow:loss = 1.50519375466099, step = 2500 (1.967 sec)\n",
"INFO:tensorflow:global_step/sec: 54.2989\n",
"INFO:tensorflow:loss = 2.117233111010944, step = 2600 (1.841 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9087\n",
"INFO:tensorflow:loss = 2.082425132961795, step = 2700 (1.825 sec)\n",
"INFO:tensorflow:global_step/sec: 55.1301\n",
"INFO:tensorflow:loss = 1.711033338559454, step = 2800 (1.815 sec)\n",
"INFO:tensorflow:global_step/sec: 55.2942\n",
"INFO:tensorflow:loss = 1.8292836903882497, step = 2900 (1.806 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6781\n",
"INFO:tensorflow:loss = 1.3164687860586282, step = 3000 (1.796 sec)\n",
"INFO:tensorflow:global_step/sec: 54.8217\n",
"INFO:tensorflow:loss = 2.406928803685318, step = 3100 (1.821 sec)\n",
"INFO:tensorflow:global_step/sec: 53.9252\n",
"INFO:tensorflow:loss = 1.8713280006020296, step = 3200 (1.855 sec)\n",
"INFO:tensorflow:global_step/sec: 54.0503\n",
"INFO:tensorflow:loss = 1.9042599514221032, step = 3300 (1.851 sec)\n",
"INFO:tensorflow:global_step/sec: 55.2128\n",
"INFO:tensorflow:loss = 2.3807734041868858, step = 3400 (1.814 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9296\n",
"INFO:tensorflow:loss = 1.2133821863938887, step = 3500 (1.816 sec)\n",
"INFO:tensorflow:global_step/sec: 55.185\n",
"INFO:tensorflow:loss = 1.9066535737583297, step = 3600 (1.813 sec)\n",
"INFO:tensorflow:global_step/sec: 55.4097\n",
"INFO:tensorflow:loss = 1.8476262616453625, step = 3700 (1.807 sec)\n",
"INFO:tensorflow:global_step/sec: 54.7528\n",
"INFO:tensorflow:loss = 1.9074748344355903, step = 3800 (1.825 sec)\n",
"INFO:tensorflow:global_step/sec: 54.8959\n",
"INFO:tensorflow:loss = 1.4288656690465646, step = 3900 (1.823 sec)\n",
"INFO:tensorflow:global_step/sec: 53.6729\n",
"INFO:tensorflow:loss = 1.5607270268388962, step = 4000 (1.860 sec)\n",
"INFO:tensorflow:global_step/sec: 52.9551\n",
"INFO:tensorflow:loss = 1.8385261425464055, step = 4100 (1.889 sec)\n",
"INFO:tensorflow:global_step/sec: 52.8463\n",
"INFO:tensorflow:loss = 1.5581479858880107, step = 4200 (1.890 sec)\n",
"INFO:tensorflow:global_step/sec: 51.6655\n",
"INFO:tensorflow:loss = 1.8432191406185507, step = 4300 (1.937 sec)\n",
"INFO:tensorflow:global_step/sec: 51.5113\n",
"INFO:tensorflow:loss = 1.8550525821871628, step = 4400 (1.941 sec)\n",
"INFO:tensorflow:global_step/sec: 53.0868\n",
"INFO:tensorflow:loss = 1.4330184675697115, step = 4500 (1.888 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6238\n",
"INFO:tensorflow:loss = 1.9019120793849802, step = 4600 (1.797 sec)\n",
"INFO:tensorflow:global_step/sec: 55.4309\n",
"INFO:tensorflow:loss = 1.4712609788353173, step = 4700 (1.800 sec)\n",
"INFO:tensorflow:global_step/sec: 54.8557\n",
"INFO:tensorflow:loss = 1.5890422839681764, step = 4800 (1.823 sec)\n",
"INFO:tensorflow:global_step/sec: 53.1096\n",
"INFO:tensorflow:loss = 2.1563534627404204, step = 4900 (1.884 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9163\n",
"INFO:tensorflow:loss = 1.6195583668334823, step = 5000 (1.822 sec)\n",
"INFO:tensorflow:global_step/sec: 55.2709\n",
"INFO:tensorflow:loss = 1.848527393655266, step = 5100 (1.810 sec)\n",
"INFO:tensorflow:global_step/sec: 54.8537\n",
"INFO:tensorflow:loss = 1.7388532690388423, step = 5200 (1.821 sec)\n",
"INFO:tensorflow:global_step/sec: 55.7159\n",
"INFO:tensorflow:loss = 1.127423055431641, step = 5300 (1.799 sec)\n",
"INFO:tensorflow:global_step/sec: 54.8101\n",
"INFO:tensorflow:loss = 2.42203717206824, step = 5400 (1.820 sec)\n",
"INFO:tensorflow:global_step/sec: 54.0933\n",
"INFO:tensorflow:loss = 1.4206105865575551, step = 5500 (1.852 sec)\n",
"INFO:tensorflow:global_step/sec: 54.2863\n",
"INFO:tensorflow:loss = 1.5992331088497465, step = 5600 (1.842 sec)\n",
"INFO:tensorflow:global_step/sec: 53.6805\n",
"INFO:tensorflow:loss = 1.667188298756222, step = 5700 (1.864 sec)\n",
"INFO:tensorflow:global_step/sec: 55.02\n",
"INFO:tensorflow:loss = 1.7874781322400692, step = 5800 (1.815 sec)\n",
"INFO:tensorflow:global_step/sec: 54.4075\n",
"INFO:tensorflow:loss = 2.192092512840261, step = 5900 (1.839 sec)\n",
"INFO:tensorflow:global_step/sec: 54.5884\n",
"INFO:tensorflow:loss = 1.4784066057909628, step = 6000 (1.829 sec)\n",
"INFO:tensorflow:global_step/sec: 55.3408\n",
"INFO:tensorflow:loss = 1.6134632220405465, step = 6100 (1.807 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6674\n",
"INFO:tensorflow:loss = 1.919851346337701, step = 6200 (1.803 sec)\n",
"INFO:tensorflow:global_step/sec: 55.6387\n",
"INFO:tensorflow:loss = 1.762901685102014, step = 6300 (1.793 sec)\n",
"INFO:tensorflow:global_step/sec: 55.2594\n",
"INFO:tensorflow:loss = 1.6896174563685626, step = 6400 (1.812 sec)\n",
"INFO:tensorflow:global_step/sec: 53.8222\n",
"INFO:tensorflow:loss = 1.7089215104830981, step = 6500 (1.858 sec)\n",
"INFO:tensorflow:global_step/sec: 53.3423\n",
"INFO:tensorflow:loss = 2.1050756223781537, step = 6600 (1.870 sec)\n",
"INFO:tensorflow:global_step/sec: 55.1512\n",
"INFO:tensorflow:loss = 1.4146305906460062, step = 6700 (1.814 sec)\n",
"INFO:tensorflow:global_step/sec: 55.4751\n",
"INFO:tensorflow:loss = 1.3412139258939646, step = 6800 (1.806 sec)\n",
"INFO:tensorflow:global_step/sec: 55.3993\n",
"INFO:tensorflow:loss = 1.6216958076168058, step = 6900 (1.806 sec)\n",
"INFO:tensorflow:global_step/sec: 54.9556\n",
"INFO:tensorflow:loss = 1.4521656481953862, step = 7000 (1.818 sec)\n",
"INFO:tensorflow:global_step/sec: 54.8374\n",
"INFO:tensorflow:loss = 1.9592250804251852, step = 7100 (1.821 sec)\n",
"INFO:tensorflow:global_step/sec: 55.217\n",
"INFO:tensorflow:loss = 1.8456127490710568, step = 7200 (1.815 sec)\n",
"INFO:tensorflow:global_step/sec: 54.0894\n",
"INFO:tensorflow:loss = 1.7207501260958113, step = 7300 (1.845 sec)\n",
"INFO:tensorflow:global_step/sec: 53.2812\n",
"INFO:tensorflow:loss = 2.30412418130913, step = 7400 (1.876 sec)\n",
"INFO:tensorflow:global_step/sec: 53.5978\n",
"INFO:tensorflow:loss = 1.4743325174316617, step = 7500 (1.868 sec)\n",
"INFO:tensorflow:global_step/sec: 55.1629\n",
"INFO:tensorflow:loss = 1.6677467106786026, step = 7600 (1.815 sec)\n",
"INFO:tensorflow:global_step/sec: 54.327\n",
"INFO:tensorflow:loss = 1.8417745451713907, step = 7700 (1.839 sec)\n",
"INFO:tensorflow:global_step/sec: 54.2786\n",
"INFO:tensorflow:loss = 1.235292680354832, step = 7800 (1.839 sec)\n",
"INFO:tensorflow:Saving checkpoints for 7813 into /tmp/tmpt8cbvqzw/model.ckpt.\n",
"INFO:tensorflow:Loss for final step: 1.3179605742790623.\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
""
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"metadata": {
"id": "Igrv7r3LCcd2",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Load the test data"
]
},
{
"metadata": {
"id": "5zYQM910Ccd-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "48ce63d7-6d33-4d93-961a-833ee4650677"
},
"cell_type": "code",
"source": [
"TEST_DIR = 'project/test_images'\n",
"\n",
"# Size the array from the actual number of test files instead of a hard-coded\n",
"# 10000; sorting makes the row (and submission) order reproducible.\n",
"test_filenames = sorted(os.listdir(TEST_DIR))\n",
"test_x = np.zeros((len(test_filenames), 64*64*3))\n",
"for idx, img in enumerate(test_filenames):\n",
"    image = imread(os.path.join(TEST_DIR, img))\n",
"    if image.ndim == 2:\n",
"        # grayscale image: replicate the single channel into R, G and B\n",
"        image = np.stack([image] * 3, axis=-1)\n",
"    # same preprocessing as the training images: scale by 1/255 and flatten\n",
"    test_x[idx, :] = np.float32(image / 255.0).reshape(-1)\n",
"print(test_x.shape)"
],
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": [
"(10000, 12288)\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "ebxpqEKpCceP",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Generate submission"
]
},
{
"metadata": {
"id": "P6wYKxucCcea",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 198
},
"outputId": "0fc414e4-5ab1-46cc-9908-9bf4c492cd7b"
},
"cell_type": "code",
"source": [
"input_fn_test = tf.estimator.inputs.numpy_input_fn(\n",
"    x={'images': test_x},\n",
"    batch_size=batch_size, shuffle=False)\n",
"\n",
"# shuffle=False keeps the predictions in the same order as test_filenames.\n",
"preds = list(model.predict(input_fn_test))\n",
"assert len(preds) == len(test_filenames), (len(preds), len(test_filenames))\n",
"\n",
"# `with` closes the file even if a write or class lookup fails midway.\n",
"with open('submission.csv', 'w') as f:\n",
"    f.write('Id,Category\\n')  # write header\n",
"    for filename, pred_label in zip(test_filenames, preds):\n",
"        f.write('%s,%s\\n' % (filename, class_mapping[int(pred_label)]))"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Graph was finalized.\n",
"WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use standard file APIs to check for files with this prefix.\n",
"INFO:tensorflow:Restoring parameters from /tmp/tmpt8cbvqzw/model.ckpt-7813\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "sK-UnZtoCcev",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"**Don't forget to upload the submission.csv!**"
]
}
]
}