import { Row, Col, Input, Button } from 'kenshin';
import JsonForm from '@/components/JsonForm';
import styles from './index.less';
import { getPayload, samplerList } from './_until';
import { useEffect, useRef } from 'react';
import { useReactive } from 'ahooks';
import { txt2img } from '../../../_serveice/openai/stablediffusion';

/** Edge length in pixels used for both the default width and the default height. */
const DEFAULT_IMAGE_SIZE = 512;

/**
 * Initial values written into the generation-settings form when the page mounts.
 * Each key corresponds to a `dataIndex` of one of the form columns.
 */
const defaultInfo = {
  width: DEFAULT_IMAGE_SIZE,
  height: DEFAULT_IMAGE_SIZE,
  sampler_index: 'Euler a',
  steps: 20,
  n_iter: 4,
};

const StableDiffusion = () => {
  const ref = useRef();
  const state = useReactive({
    prompt: '',
    imgUrl: ' ',
    loading: false,
    imgList: [],
  });

  useEffect(() => {
    const { form } = ref.current;
    form.setFieldsValue(defaultInfo);
  }, []);

  const edtaFormItem = [
    {
      title: '宽度',
      dataIndex: 'width',
      valueType: 'InputNumber',
      fieldProps: {
        placeholder: '请输入宽度',
        max: 1000,
        min: 0,
        stepType: 'inside',
      },
    },
    {
      title: '高度',
      dataIndex: 'height',
      valueType: 'InputNumber',
      fieldProps: {
        placeholder: '请输入高度',
        max: 1000,
        min: 0,
        stepType: 'inside',
      },
    },
    {
      title: '采样方法',
      dataIndex: 'sampler_index',
      valueType: 'Select',
      fieldProps: {
        options: samplerList.map((item) => ({ label: item, value: item })),
        placeholder: '请选择取样方法',
      },
    },
    {
      title: '迭代步数',
      dataIndex: 'steps',
      valueType: 'InputNumber',
      fieldProps: {
        placeholder: '请输入迭代步数',
        max: 1000,
        min: 0,
        stepType: 'inside',
      },
    },
    {
      title: '生成数量',
      dataIndex: 'n_iter',
      valueType: 'InputNumber',
      fieldProps: {
        placeholder: '请输入生成数量',
        max: 4,
        min: 1,
        stepType: 'inside',
      },
    },
  ];

  const handleCreateImg = () => {
    state.loading = true;
    state.imgUrl = '';
    const { form } = ref.current;
    form.validateFields().then((value) => {
      let payload = getPayload({
        ...value,
        prompt: state.prompt,
      });
      txt2img(payload).then((res) => {
        state.imgUrl = `data:image/png;base64,${res.images[0]}`;
        state.imgList =
          res.images?.map((item) => `data:image/png;base64,${item}`) || [];
        state.loading = false;
      });
    });
  };

  return (
    <Row style={{ position: 'relative', minHeight: '800px' }}>
      <Col flex="auto">
        <div className={styles.imgBox}>
          {state.imgUrl ? (
            <img src={state.imgUrl} style={{ height: '100%' }} />
          ) : (
            '正在生成图片。。。'
          )}
          <div className={styles.imgList}>
            {state.imgList?.map((item, index) => (
              <img
                src={item}
                key={index}
                style={{ width: 44, margin: '0 5px' }}
                onClick={() => (state.imgUrl = item)}
              />
            ))}
          </div>
        </div>
        <div className={styles.promptBox}>
          <Input.TextArea
            value={state.prompt}
            onChange={(e) => (state.prompt = e.target.value)}
          />
          <Button
            type="primary"
            disabled={state.loading}
            style={{ marginLeft: 20 }}
            onClick={handleCreateImg}
          >
            生成
          </Button>
        </div>
      </Col>
      <Col
        flex="350px"
        style={{ borderLeft: '1px solid #ccc', padding: '0 10px' }}
      >
        <JsonForm columns={edtaFormItem} ref={ref} />
      </Col>
    </Row>
  );
};

export default StableDiffusion;
