StyleGAN2 model with Flask API generates weird results after the first request

Submitted by 喜你入骨 on 2020-12-12 08:50:29

Question


Here's what's happening. I have been using the StyleGAN2 model for a while, and I decided to build a website that lets the user enter the arguments the model uses to generate images. The model was trained with TensorFlow 1.15, and the code works perfectly: it generates all the expected outputs when I run the model directly on my machine from the command line. The problem appears when I do the same thing through a Flask API.

Here is all the code I use to generate the images. I have made very few changes to the original run_generator.py file.

run_generator.py

import argparse
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import re
import sys
import pretrained_networks

#----------------------------------------------------------------------------
def generate_images(network_pkl, seeds, truncation_psi):
    print('Loading networks from "%s"...' % network_pkl)
    _G, _D, Gs = pretrained_networks.load_networks(network_pkl)
    noise_vars = [var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]

    Gs_kwargs = dnnlib.EasyDict()
    Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    Gs_kwargs.randomize_noise = False
    if truncation_psi is not None:
        Gs_kwargs.truncation_psi = truncation_psi

    for seed_idx, seed in enumerate(seeds):
        print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
        rnd = np.random.RandomState(seed)
        z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]
        tflib.set_vars({var: rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]
        images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]
        PIL.Image.fromarray(images[0], 'RGB').save(dnnlib.make_run_dir_path('./results/Image_Generates/seed%04d.png' % seed))

#----------------------------------------------------------------------------

def _parse_num_range(s):
    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''

    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2))+1))
    vals = s.split(',')
    return [int(x) for x in vals]

#----------------------------------------------------------------------------

def call_generate_images(seed, trunc):
    # seed: the raw "Seeds" string from the form; trunc: truncation-psi, already a float
    networks = "./results/00000-pretrained/network-snapshot-10000.pkl"
    seeds = _parse_num_range(seed)
    generate_images(networks, seeds, trunc)

Here is the Flask API I am using to read the inputs from the webpage and pass them to the run_generator.py code above to generate the images.

app.py

from flask import Flask, render_template, request
from flask_cors import CORS
import requests
import tensorflow
import my_generator
app = Flask(__name__)
graph = tensorflow.get_default_graph()
@app.route('/',methods=['GET'])
def Home():
    return render_template('index.html')

@app.route("/predict", methods=['POST'])
def predict():
    global graph
    s = tensorflow.Session(graph=graph)
    with graph.as_default():
        with s as sess:
            sess.run(tensorflow.global_variables_initializer())
            sess.run(tensorflow.local_variables_initializer())
            if request.method == 'POST':
                Seeds=request.form['Seeds']
                Truncation=float(request.form['Trunc'])
                my_generator.call_generate_images(Seeds,Truncation)
    return render_template('index.html',generation_text="Images have been generated.")
if __name__=="__main__":
    app.run()
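As written, app.py opens a new `tensorflow.Session` and runs the variable initializers on every POST, and the network is loaded from inside the request handler. A frequently used alternative pattern is to do the expensive setup once, when the module is imported, and let every request reuse it. The sketch below shows only that structure in plain Python: `GeneratorService`, `load_fn` and `run_fn` are hypothetical names, and the TensorFlow-specific wiring (keeping one graph and one session and entering them in each request) is deliberately left out, so whether this alone fixes the issue is an untested assumption.

```python
import threading

class GeneratorService:
    """Hypothetical helper: build the model once, then serialize requests to it.

    load_fn stands in for the one-time setup (in the real app, something along the
    lines of creating the graph/session and calling pretrained_networks.load_networks);
    run_fn stands in for generating one image from the already-loaded model.
    """

    def __init__(self, load_fn, run_fn):
        self._model = load_fn()        # runs once, when the service is created at startup
        self._run_fn = run_fn
        self._lock = threading.Lock()  # Flask's development server can serve requests on several threads

    def generate(self, seeds, truncation_psi):
        # Only one request at a time touches the shared model.
        with self._lock:
            return [self._run_fn(self._model, seed, truncation_psi) for seed in seeds]


# Usage with stand-in functions: loading happens once, however many requests arrive.
load_calls = []

def fake_load():
    load_calls.append("load")
    return "Gs"

def fake_run(model, seed, psi):
    return "%s:seed%d:psi%.1f" % (model, seed, psi)

service = GeneratorService(fake_load, fake_run)
first = service.generate([1], 0.5)
second = service.generate([2, 3], 0.5)
print(first)            # ['Gs:seed1:psi0.5']
print(second)           # ['Gs:seed2:psi0.5', 'Gs:seed3:psi0.5']
print(len(load_calls))  # 1
```

The lock is a conservative choice: it trades throughput for the guarantee that two requests never drive the same model state at the same time.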

And here is the HTML for the very simple webpage I am using.

index.html

<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Document</title>
</head>

<body>

    <div style="color:blue">
        <form action="{{ url_for('predict')}}" method="post">
            <h2>Image Generator</h2>
            <h3>Seeds</h3><input id="second" name="Seeds" required="required">
            <h3>Truncation-psi</h3><input id="third" name="Trunc" required="required">
            <br><br><button id="sub" type="submit">Generate images</button>
            <br>
        </form>
        <br><h3>{{ generation_text }}</h3>
    </div>
    <style>
        body {
            background-color: lightslategray;
            text-align: center;
            padding: 0px;
        }
        
        #research {
            font-size: 18px;
            width: 100px;
            height: 23px;
            top: 23px;
        }
        
        #box {
            border-radius: 60px;
            border-color: 45px;
            border-style: solid;
            font-family: cursive;
            text-align: center;
            background-color: rgb(168, 131, 61);
            font-size: medium;
            position: absolute;
            width: 700px;
            bottom: 9%;
            height: 850px;
            right: 30%;
            padding: 0px;
            margin: 0px;
            font-size: 14px;
        }
        
        #fuel {
            width: 83px;
            height: 43px;
            text-align: center;
            border-radius: 14px;
            font-size: 20px;
        }
        
        #fuel:hover {
            background-color: coral;
        }
        
        #research {
            width: 99px;
            height: 43px;
            text-align: center;
            border-radius: 14px;
            font-size: 18px;
        }
        
        #research:hover {
            background-color: coral;
        }
        
        #resea {
            width: 99px;
            height: 43px;
            text-align: center;
            border-radius: 14px;
            font-size: 18px;
        }
        
        #resea:hover {
            background-color: coral;
        }
        
        #sub {
            width: 120px;
            height: 43px;
            text-align: center;
            border-radius: 14px;
            font-size: 18px;
        }
        
        #sub:hover {
            background-color: darkcyan;
        }
        
        #first {
            border-radius: 14px;
            height: 25px;
            font-size: 20px;
            text-align: center;
        }
        
        #second {
            border-radius: 14px;
            height: 25px;
            font-size: 20px;
            text-align: center;
        }
        
        #third {
            border-radius: 14px;
            height: 25px;
            font-size: 20px;
            text-align: center;
        }
        
        #fourth {
            border-radius: 14px;
            height: 25px;
            font-size: 20px;
            text-align: center;
        }
    </style>
</body>
</html>

As the code above shows, all I am doing is:

  1. Asking the user to enter the seed value and the truncation-psi the model should use to generate the images.
  2. The app.py Flask API receives these values through a POST request and passes them to the call_generate_images function in run_generator.py, which converts the seed string to a list of integers and loads the pretrained model. (The truncation-psi has already been converted to a float in app.py.)
  3. run_generator.py then generates all the images for the values entered and saves them in the ./results/Image_Generates folder.
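Step 3 writes exactly one file per seed, using the `'seed%04d.png'` name pattern from `generate_images`, so the Flask view could work out which files a request should have produced (for example, to check that they exist or to show them on the page). The helper below is a hypothetical sketch of that naming; it only reproduces the format string and does not model `dnnlib.make_run_dir_path`, which may place the files under a different base directory.

```python
def expected_output_paths(seeds, out_dir="./results/Image_Generates"):
    """Hypothetical helper: the file name generate_images would use for each seed.

    'seed%04d.png' zero-pads each seed to at least four digits; larger seeds
    simply use more digits.
    """
    return ["%s/seed%04d.png" % (out_dir, seed) for seed in seeds]

print(expected_output_paths([1, 2]))
# ['./results/Image_Generates/seed0001.png', './results/Image_Generates/seed0002.png']
print(expected_output_paths([12345]))
# ['./results/Image_Generates/seed12345.png']
```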

The problem is that the first request to the API always works: the model generates the images as expected. For example, if I enter 1 as the seed and 0.5 as the truncation-psi, the following image is generated:

[Image: image generated for the first API request]

The page then reloads with the text "Images have been generated." at the bottom.

However, when I then enter any other values, for example 2 as the seed and 0.5 as the truncation-psi, the following image is generated:

[Image: image generated on subsequent API requests]

I can't figure out what exactly is causing this. I know the code I have shown is rather shabby, almost spaghetti code, and probably not the best way to do what I am trying to do, but it is just a rough implementation for a college project. Please help me solve this, as I am not very experienced with TensorFlow or Flask in general.

P.S.: As you may have noticed, in app.py I get the default graph with tensorflow.get_default_graph() and run each request inside it. When I tried implementing it without doing that, I received the following error:

tensor("g_synthesis_1/noise0/new_value:0", shape=(1, 1, 4, 4), dtype=float32) must be from the same graph as tensor("g_synthesis_1/noise0:0", shape=(1, 1, 4, 4), dtype=float32_ref).

To get around this error, I wrote app.py the way shown above.

Please help me resolve the issue, and, as a humble request, please go a little easy on me :)

Any help is appreciated, and feel free to ask for any additional information you need to help me resolve the issue.

Source: https://stackoverflow.com/questions/64926275/stylegan2-model-with-flask-api-is-generating-weird-results-after-first-request
